diff --git a/DENOISING_DIFFUSION/src/models/__init__.py b/DENOISING_DIFFUSION/src/models/__init__.py
new file mode 100644
index 0000000..641410e
--- /dev/null
+++ b/DENOISING_DIFFUSION/src/models/__init__.py
@@ -0,0 +1,15 @@
+"""Model components for EXXA denoising diffusion pipeline."""
+
+from .blocks import ResidualBlock, AttentionBlock, SinusoidalTimeEmbedding
+from .unet import UNet
+from .noise_scheduler import NoiseScheduler
+from .ddpm import DDPM
+
+__all__ = [
+    "ResidualBlock",
+    "AttentionBlock",
+    "SinusoidalTimeEmbedding",
+    "UNet",
+    "NoiseScheduler",
+    "DDPM",
+]
diff --git a/DENOISING_DIFFUSION/src/models/blocks.py b/DENOISING_DIFFUSION/src/models/blocks.py
new file mode 100644
index 0000000..c938bf4
--- /dev/null
+++ b/DENOISING_DIFFUSION/src/models/blocks.py
@@ -0,0 +1,96 @@
+"""Building blocks for the U-Net backbone."""
+
+import torch
+import torch.nn as nn
+from typing import Optional
+
+
+class SinusoidalTimeEmbedding(nn.Module):
+    """
+    Sinusoidal positional embedding for diffusion timestep conditioning.
+
+    Encodes scalar timestep t into a fixed-size vector that gets injected
+    into each ResidualBlock of the U-Net.
+
+    Args:
+        dim: Embedding dimension (should match U-Net base channels * 4)
+    """
+
+    def __init__(self, dim: int) -> None:
+        super().__init__()
+        self.dim = dim
+
+    def forward(self, t: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            t: Timestep tensor, shape (B,)
+
+        Returns:
+            Embedding tensor, shape (B, dim)
+        """
+        raise NotImplementedError
+
+
+class ResidualBlock(nn.Module):
+    """
+    Residual block with GroupNorm, SiLU activation, and time conditioning.
+
+    Used in both encoder and decoder of the U-Net. Injects the timestep
+    embedding via a linear projection added to the hidden features.
+
+    Args:
+        in_channels: Number of input channels
+        out_channels: Number of output channels
+        time_emb_dim: Dimension of the time embedding vector
+        dropout: Dropout probability
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        time_emb_dim: int,
+        dropout: float = 0.1,
+    ) -> None:
+        super().__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+
+    def forward(self, x: torch.Tensor, time_emb: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            x: Input feature map, shape (B, in_channels, H, W)
+            time_emb: Time embedding, shape (B, time_emb_dim)
+
+        Returns:
+            Output feature map, shape (B, out_channels, H, W)
+        """
+        raise NotImplementedError
+
+
+class AttentionBlock(nn.Module):
+    """
+    Self-attention block for capturing global context.
+
+    Applied at deeper U-Net levels where spatial resolution is low.
+    Uses multi-head self-attention with GroupNorm and residual connection.
+
+    Args:
+        channels: Number of input/output channels
+        num_heads: Number of attention heads
+    """
+
+    def __init__(self, channels: int, num_heads: int = 4) -> None:
+        super().__init__()
+        self.channels = channels
+        self.num_heads = num_heads
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            x: Input feature map, shape (B, channels, H, W)
+
+        Returns:
+            Output feature map, shape (B, channels, H, W)
+        """
+        raise NotImplementedError
diff --git a/DENOISING_DIFFUSION/src/models/ddpm.py b/DENOISING_DIFFUSION/src/models/ddpm.py
new file mode 100644
index 0000000..93d2512
--- /dev/null
+++ b/DENOISING_DIFFUSION/src/models/ddpm.py
@@ -0,0 +1,77 @@
+"""DDPM wrapper combining U-Net backbone and noise scheduler."""
+
+import torch
+import torch.nn as nn
+from typing import Tuple
+
+from .unet import UNet
+from .noise_scheduler import NoiseScheduler
+
+
+class DDPM(nn.Module):
+    """
+    Denoising Diffusion Probabilistic Model.
+
+    Wraps the U-Net backbone and noise scheduler into a single module.
+    Exposes training loss computation and inference sampling.
+
+    Args:
+        unet: U-Net model that predicts noise given (x_t, t)
+        scheduler: NoiseScheduler managing beta/alpha values
+        loss_type: Loss function to use — "l1" or "l2"
+
+    Example:
+        >>> unet = UNet(in_channels=1, out_channels=1)
+        >>> scheduler = NoiseScheduler(timesteps=1000)
+        >>> model = DDPM(unet, scheduler)
+        >>> x0 = torch.randn(2, 1, 64, 64)
+        >>> loss = model.training_loss(x0)
+    """
+
+    def __init__(
+        self,
+        unet: UNet,
+        scheduler: NoiseScheduler,
+        loss_type: str = "l2",
+    ) -> None:
+        super().__init__()
+        self.unet = unet
+        self.scheduler = scheduler
+        self.loss_type = loss_type
+
+    def training_loss(self, x0: torch.Tensor) -> torch.Tensor:
+        """
+        Compute training loss for a batch of clean images.
+
+        Samples random timesteps and noise, runs forward diffusion,
+        predicts noise with U-Net, and computes loss against true noise.
+
+        Args:
+            x0: Clean image batch, shape (B, C, H, W)
+
+        Returns:
+            Scalar loss tensor
+        """
+        raise NotImplementedError
+
+    @torch.no_grad()
+    def sample(
+        self,
+        shape: Tuple[int, ...],
+        device: torch.device,
+    ) -> torch.Tensor:
+        """
+        Generate a clean image from pure Gaussian noise.
+
+        Args:
+            shape: Output shape (B, C, H, W)
+            device: Target device
+
+        Returns:
+            Generated image, shape (B, C, H, W)
+        """
+        raise NotImplementedError
+
+    def forward(self, x0: torch.Tensor) -> torch.Tensor:
+        """Alias for training_loss — used during the training loop."""
+        return self.training_loss(x0)
diff --git a/DENOISING_DIFFUSION/src/models/noise_scheduler.py b/DENOISING_DIFFUSION/src/models/noise_scheduler.py
new file mode 100644
index 0000000..bdd800a
--- /dev/null
+++ b/DENOISING_DIFFUSION/src/models/noise_scheduler.py
@@ -0,0 +1,99 @@
+"""Noise scheduler for the forward and reverse diffusion process."""
+
+import torch
+from typing import Tuple
+
+
+class NoiseScheduler:
+    """
+    Manages the noise schedule for the diffusion process.
+
+    Precomputes beta, alpha, and alpha_cumprod values used in both
+    the forward (noising) and reverse (denoising) diffusion steps.
+
+    Args:
+        timesteps: Total number of diffusion steps T
+        beta_schedule: Schedule type — "linear" or "cosine"
+        beta_start: Starting beta value (used for linear schedule)
+        beta_end: Ending beta value (used for linear schedule)
+
+    Example:
+        >>> scheduler = NoiseScheduler(timesteps=1000, beta_schedule="linear")
+        >>> x0 = torch.randn(2, 1, 64, 64)
+        >>> t = torch.tensor([100, 500])
+        >>> xt, noise = scheduler.q_sample(x0, t)
+        >>> xt.shape
+        torch.Size([2, 1, 64, 64])
+    """
+
+    def __init__(
+        self,
+        timesteps: int = 1000,
+        beta_schedule: str = "linear",
+        beta_start: float = 1e-4,
+        beta_end: float = 2e-2,
+    ) -> None:
+        self.timesteps = timesteps
+        self.beta_schedule = beta_schedule
+        self.beta_start = beta_start
+        self.beta_end = beta_end
+
+    def q_sample(
+        self,
+        x0: torch.Tensor,
+        t: torch.Tensor,
+        noise: torch.Tensor | None = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Forward diffusion: add noise to clean image at timestep t.
+
+        Samples x_t ~ q(x_t | x_0) using the closed-form expression:
+            x_t = sqrt(alpha_cumprod_t) * x_0 + sqrt(1 - alpha_cumprod_t) * eps
+
+        Args:
+            x0: Clean image, shape (B, C, H, W)
+            t: Timestep indices, shape (B,)
+            noise: Optional pre-sampled noise; sampled from N(0,I) if None
+
+        Returns:
+            (x_t, noise) tuple — noisy image and the noise that was added
+        """
+        raise NotImplementedError
+
+    def p_sample(
+        self,
+        model: torch.nn.Module,
+        xt: torch.Tensor,
+        t: torch.Tensor,
+    ) -> torch.Tensor:
+        """
+        Reverse diffusion: denoise one step from x_t to x_{t-1}.
+
+        Args:
+            model: U-Net that predicts noise given (x_t, t)
+            xt: Noisy image at timestep t, shape (B, C, H, W)
+            t: Current timestep indices, shape (B,)
+
+        Returns:
+            Denoised image x_{t-1}, shape (B, C, H, W)
+        """
+        raise NotImplementedError
+
+    def p_sample_loop(
+        self,
+        model: torch.nn.Module,
+        shape: Tuple[int, ...],
+        device: torch.device,
+    ) -> torch.Tensor:
+        """
+        Full reverse diffusion loop from pure noise to clean image.
+
+        Args:
+            model: Trained U-Net
+            shape: Output shape (B, C, H, W)
+            device: Target device
+
+        Returns:
+            Generated clean image, shape (B, C, H, W)
+        """
+        raise NotImplementedError
diff --git a/DENOISING_DIFFUSION/src/models/unet.py b/DENOISING_DIFFUSION/src/models/unet.py
new file mode 100644
index 0000000..84ecd80
--- /dev/null
+++ b/DENOISING_DIFFUSION/src/models/unet.py
@@ -0,0 +1,63 @@
+"""U-Net backbone for the diffusion model."""
+
+import torch
+import torch.nn as nn
+from typing import Tuple
+
+
+class UNet(nn.Module):
+    """
+    U-Net architecture for predicting noise in the diffusion process.
+
+    Encoder-decoder structure with skip connections and time conditioning.
+    Attention is applied at deeper levels where spatial resolution is low.
+
+    Args:
+        in_channels: Number of input image channels (1 for grayscale)
+        out_channels: Number of output channels (same as in_channels)
+        base_channels: Base feature channels, doubled at each encoder level
+        channel_multipliers: Per-level channel multipliers, e.g. (1, 2, 4, 8)
+        num_res_blocks: Number of residual blocks per encoder/decoder level
+        attention_levels: Which levels (0-indexed) to apply self-attention
+        dropout: Dropout probability in residual blocks
+
+    Example:
+        >>> model = UNet(in_channels=1, out_channels=1, base_channels=64)
+        >>> x = torch.randn(2, 1, 64, 64)
+        >>> t = torch.randint(0, 1000, (2,))
+        >>> out = model(x, t)
+        >>> out.shape
+        torch.Size([2, 1, 64, 64])
+    """
+
+    def __init__(
+        self,
+        in_channels: int = 1,
+        out_channels: int = 1,
+        base_channels: int = 64,
+        channel_multipliers: Tuple[int, ...] = (1, 2, 4, 8),
+        num_res_blocks: int = 2,
+        attention_levels: Tuple[int, ...] = (2, 3),
+        dropout: float = 0.1,
+    ) -> None:
+        super().__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.base_channels = base_channels
+        self.channel_multipliers = channel_multipliers
+        self.num_res_blocks = num_res_blocks
+        self.attention_levels = attention_levels
+        self.dropout = dropout
+
+    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
+        """
+        Predict noise given noisy image and timestep.
+
+        Args:
+            x: Noisy image tensor, shape (B, in_channels, H, W)
+            t: Diffusion timestep, shape (B,)
+
+        Returns:
+            Predicted noise, shape (B, out_channels, H, W)
+        """
+        raise NotImplementedError
diff --git a/DENOISING_DIFFUSION/tests/test_models_skeleton.py b/DENOISING_DIFFUSION/tests/test_models_skeleton.py
new file mode 100644
index 0000000..f590dad
--- /dev/null
+++ b/DENOISING_DIFFUSION/tests/test_models_skeleton.py
@@ -0,0 +1,133 @@
+"""Tests for model skeleton — imports, instantiation, and interface contracts."""
+
+import pytest
+import torch
+from pathlib import Path
+import sys
+
+sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
+
+from models import DDPM, UNet, NoiseScheduler, ResidualBlock, AttentionBlock, SinusoidalTimeEmbedding
+
+
+class TestImports:
+    def test_all_classes_importable(self):
+        assert UNet is not None
+        assert NoiseScheduler is not None
+        assert DDPM is not None
+        assert ResidualBlock is not None
+        assert AttentionBlock is not None
+        assert SinusoidalTimeEmbedding is not None
+
+
+class TestUNet:
+    def test_instantiation_defaults(self):
+        model = UNet()
+        assert model.in_channels == 1
+        assert model.out_channels == 1
+        assert model.base_channels == 64
+
+    def test_instantiation_custom(self):
+        model = UNet(in_channels=1, out_channels=1, base_channels=32, channel_multipliers=(1, 2, 4))
+        assert model.base_channels == 32
+        assert model.channel_multipliers == (1, 2, 4)
+
+    def test_forward_not_implemented(self):
+        model = UNet()
+        x = torch.randn(2, 1, 64, 64)
+        t = torch.randint(0, 1000, (2,))
+        with pytest.raises(NotImplementedError):
+            model(x, t)
+
+    def test_is_nn_module(self):
+        assert isinstance(UNet(), torch.nn.Module)
+
+
+class TestNoiseScheduler:
+    def test_instantiation_defaults(self):
+        scheduler = NoiseScheduler()
+        assert scheduler.timesteps == 1000
+        assert scheduler.beta_schedule == "linear"
+
+    def test_instantiation_custom(self):
+        scheduler = NoiseScheduler(timesteps=500, beta_schedule="cosine")
+        assert scheduler.timesteps == 500
+        assert scheduler.beta_schedule == "cosine"
+
+    def test_q_sample_not_implemented(self):
+        scheduler = NoiseScheduler()
+        x0 = torch.randn(2, 1, 64, 64)
+        t = torch.tensor([100, 500])
+        with pytest.raises(NotImplementedError):
+            scheduler.q_sample(x0, t)
+
+    def test_p_sample_not_implemented(self):
+        scheduler = NoiseScheduler()
+        model = UNet()
+        xt = torch.randn(2, 1, 64, 64)
+        t = torch.tensor([100, 500])
+        with pytest.raises(NotImplementedError):
+            scheduler.p_sample(model, xt, t)
+
+
+class TestDDPM:
+    @pytest.fixture
+    def ddpm(self):
+        return DDPM(unet=UNet(), scheduler=NoiseScheduler())
+
+    def test_instantiation(self, ddpm):
+        assert ddpm.loss_type == "l2"
+        assert isinstance(ddpm.unet, UNet)
+        assert isinstance(ddpm.scheduler, NoiseScheduler)
+
+    def test_is_nn_module(self, ddpm):
+        assert isinstance(ddpm, torch.nn.Module)
+
+    def test_training_loss_not_implemented(self, ddpm):
+        x0 = torch.randn(2, 1, 64, 64)
+        with pytest.raises(NotImplementedError):
+            ddpm.training_loss(x0)
+
+    def test_sample_not_implemented(self, ddpm):
+        with pytest.raises(NotImplementedError):
+            ddpm.sample(shape=(2, 1, 64, 64), device=torch.device("cpu"))
+
+    def test_forward_delegates_to_training_loss(self, ddpm):
+        x0 = torch.randn(2, 1, 64, 64)
+        with pytest.raises(NotImplementedError):
+            ddpm(x0)
+
+
+class TestBlocks:
+    def test_residual_block_instantiation(self):
+        block = ResidualBlock(in_channels=64, out_channels=64, time_emb_dim=256)
+        assert block.in_channels == 64
+        assert block.out_channels == 64
+
+    def test_residual_block_forward_not_implemented(self):
+        block = ResidualBlock(in_channels=64, out_channels=64, time_emb_dim=256)
+        x = torch.randn(2, 64, 32, 32)
+        t = torch.randn(2, 256)
+        with pytest.raises(NotImplementedError):
+            block(x, t)
+
+    def test_attention_block_instantiation(self):
+        block = AttentionBlock(channels=64, num_heads=4)
+        assert block.channels == 64
+        assert block.num_heads == 4
+
+    def test_attention_block_forward_not_implemented(self):
+        block = AttentionBlock(channels=64)
+        x = torch.randn(2, 64, 16, 16)
+        with pytest.raises(NotImplementedError):
+            block(x)
+
+    def test_time_embedding_instantiation(self):
+        emb = SinusoidalTimeEmbedding(dim=256)
+        assert emb.dim == 256
+
+    def test_time_embedding_forward_not_implemented(self):
+        emb = SinusoidalTimeEmbedding(dim=256)
+        t = torch.randint(0, 1000, (2,))
+        with pytest.raises(NotImplementedError):
+            emb(t)