Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions DENOISING_DIFFUSION/src/models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
"""Model components for EXXA denoising diffusion pipeline."""

from .blocks import ResidualBlock, AttentionBlock, SinusoidalTimeEmbedding
from .unet import UNet
from .noise_scheduler import NoiseScheduler
from .ddpm import DDPM

__all__ = [
"ResidualBlock",
"AttentionBlock",
"SinusoidalTimeEmbedding",
"UNet",
"NoiseScheduler",
"DDPM",
]
96 changes: 96 additions & 0 deletions DENOISING_DIFFUSION/src/models/blocks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
"""Building blocks for the U-Net backbone."""

import torch
import torch.nn as nn
from typing import Optional


class SinusoidalTimeEmbedding(nn.Module):
    """
    Sinusoidal positional embedding for diffusion timestep conditioning.

    Encodes scalar timestep t into a fixed-size vector that gets injected
    into each ResidualBlock of the U-Net. Uses a Transformer-style sin/cos
    frequency bank: the first half of the output holds sines, the second
    half cosines; an odd ``dim`` is handled by zero-padding one column.

    Args:
        dim: Embedding dimension (should match U-Net base channels * 4)
    """

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.dim = dim

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        """
        Args:
            t: Timestep tensor, shape (B,)

        Returns:
            Embedding tensor, shape (B, dim)
        """
        half = self.dim // 2
        # Geometric frequency ladder from 1 down to ~1/10000, built on the
        # same device as the incoming timesteps.
        freqs = 10000.0 ** (
            -torch.arange(half, dtype=torch.float32, device=t.device) / half
        )
        args = t.float()[:, None] * freqs[None, :]
        emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
        if self.dim % 2 == 1:
            # Pad one zero column so the output width is exactly `dim`.
            emb = nn.functional.pad(emb, (0, 1))
        return emb


class ResidualBlock(nn.Module):
    """
    Residual block with GroupNorm, SiLU activation, and time conditioning.

    Used in both encoder and decoder of the U-Net. Injects the timestep
    embedding via a linear projection added to the hidden features.

    Structure: norm -> SiLU -> conv, add projected time embedding,
    norm -> SiLU -> dropout -> conv, plus a (possibly 1x1-projected) skip.

    Args:
        in_channels: Number of input channels
        out_channels: Number of output channels
        time_emb_dim: Dimension of the time embedding vector
        dropout: Dropout probability
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        time_emb_dim: int,
        dropout: float = 0.1,
    ) -> None:
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        # GroupNorm requires num_groups to divide the channel count;
        # fall back to a single group (LayerNorm-like) otherwise.
        def groups(c: int) -> int:
            return 32 if c % 32 == 0 else 1

        self.act = nn.SiLU()
        self.norm1 = nn.GroupNorm(groups(in_channels), in_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.time_proj = nn.Linear(time_emb_dim, out_channels)
        self.norm2 = nn.GroupNorm(groups(out_channels), out_channels)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        # 1x1 projection keeps the residual sum well-defined when the
        # channel count changes between input and output.
        self.skip = (
            nn.Identity()
            if in_channels == out_channels
            else nn.Conv2d(in_channels, out_channels, kernel_size=1)
        )

    def forward(self, x: torch.Tensor, time_emb: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Input feature map, shape (B, in_channels, H, W)
            time_emb: Time embedding, shape (B, time_emb_dim)

        Returns:
            Output feature map, shape (B, out_channels, H, W)
        """
        h = self.conv1(self.act(self.norm1(x)))
        # Broadcast the projected time embedding over the spatial dims.
        h = h + self.time_proj(self.act(time_emb))[:, :, None, None]
        h = self.conv2(self.dropout(self.act(self.norm2(h))))
        return h + self.skip(x)


class AttentionBlock(nn.Module):
    """
    Self-attention block for capturing global context.

    Applied at deeper U-Net levels where spatial resolution is low.
    Uses multi-head self-attention with GroupNorm and residual connection.

    Args:
        channels: Number of input/output channels
        num_heads: Number of attention heads

    Raises:
        ValueError: If channels is not divisible by num_heads.
    """

    def __init__(self, channels: int, num_heads: int = 4) -> None:
        super().__init__()
        if channels % num_heads != 0:
            raise ValueError(
                f"channels ({channels}) must be divisible by num_heads ({num_heads})"
            )
        self.channels = channels
        self.num_heads = num_heads
        # Single-group fallback keeps GroupNorm valid for any channel count.
        self.norm = nn.GroupNorm(32 if channels % 32 == 0 else 1, channels)
        # 1x1 convs act as per-pixel linear projections for Q, K, V and output.
        self.qkv = nn.Conv2d(channels, channels * 3, kernel_size=1)
        self.proj = nn.Conv2d(channels, channels, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Input feature map, shape (B, channels, H, W)

        Returns:
            Output feature map, shape (B, channels, H, W)
        """
        b, c, h, w = x.shape
        head_dim = c // self.num_heads
        q, k, v = self.qkv(self.norm(x)).chunk(3, dim=1)
        # Flatten space and split channels into heads: (B, heads, d, H*W).
        q = q.reshape(b, self.num_heads, head_dim, h * w)
        k = k.reshape(b, self.num_heads, head_dim, h * w)
        v = v.reshape(b, self.num_heads, head_dim, h * w)
        # Scaled dot-product attention over spatial positions.
        attn = torch.softmax(
            torch.einsum("bhdn,bhdm->bhnm", q, k) * head_dim**-0.5, dim=-1
        )
        out = torch.einsum("bhnm,bhdm->bhdn", attn, v).reshape(b, c, h, w)
        # Residual connection preserves the input signal path.
        return x + self.proj(out)
77 changes: 77 additions & 0 deletions DENOISING_DIFFUSION/src/models/ddpm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
"""DDPM wrapper combining U-Net backbone and noise scheduler."""

import torch
import torch.nn as nn
from typing import Tuple

from .unet import UNet
from .noise_scheduler import NoiseScheduler


class DDPM(nn.Module):
    """
    Denoising Diffusion Probabilistic Model.

    Wraps the U-Net backbone and noise scheduler into a single module.
    Exposes training loss computation and inference sampling.

    Args:
        unet: U-Net model that predicts noise given (x_t, t)
        scheduler: NoiseScheduler managing beta/alpha values
        loss_type: Loss function to use — "l1" or "l2"

    Raises:
        ValueError: If loss_type is neither "l1" nor "l2".

    Example:
        >>> unet = UNet(in_channels=1, out_channels=1)
        >>> scheduler = NoiseScheduler(timesteps=1000)
        >>> model = DDPM(unet, scheduler)
        >>> x0 = torch.randn(2, 1, 64, 64)
        >>> loss = model.training_loss(x0)
    """

    def __init__(
        self,
        unet: UNet,
        scheduler: NoiseScheduler,
        loss_type: str = "l2",
    ) -> None:
        super().__init__()
        # Fail fast on an unsupported loss instead of at the first
        # training step.
        if loss_type not in ("l1", "l2"):
            raise ValueError(f"loss_type must be 'l1' or 'l2', got {loss_type!r}")
        self.unet = unet
        self.scheduler = scheduler
        self.loss_type = loss_type

    def training_loss(self, x0: torch.Tensor) -> torch.Tensor:
        """
        Compute training loss for a batch of clean images.

        Samples random timesteps and noise, runs forward diffusion,
        predicts noise with U-Net, and computes loss against true noise.

        Args:
            x0: Clean image batch, shape (B, C, H, W)

        Returns:
            Scalar loss tensor
        """
        batch_size = x0.shape[0]
        # One uniformly random timestep per sample in the batch.
        t = torch.randint(
            0, self.scheduler.timesteps, (batch_size,), device=x0.device
        )
        xt, noise = self.scheduler.q_sample(x0, t)
        predicted = self.unet(xt, t)
        if self.loss_type == "l1":
            return nn.functional.l1_loss(predicted, noise)
        return nn.functional.mse_loss(predicted, noise)

    @torch.no_grad()
    def sample(
        self,
        shape: Tuple[int, ...],
        device: torch.device,
    ) -> torch.Tensor:
        """
        Generate a clean image from pure Gaussian noise.

        Delegates the full reverse diffusion loop to the scheduler.

        Args:
            shape: Output shape (B, C, H, W)
            device: Target device

        Returns:
            Generated image, shape (B, C, H, W)
        """
        return self.scheduler.p_sample_loop(self.unet, shape, device)

    def forward(self, x0: torch.Tensor) -> torch.Tensor:
        """Alias for training_loss — used during the training loop."""
        return self.training_loss(x0)
99 changes: 99 additions & 0 deletions DENOISING_DIFFUSION/src/models/noise_scheduler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
"""Noise scheduler for the forward and reverse diffusion process."""

import torch
from typing import Tuple


class NoiseScheduler:
"""
Manages the noise schedule for the diffusion process.

Precomputes beta, alpha, and alpha_cumprod values used in both
the forward (noising) and reverse (denoising) diffusion steps.

Args:
timesteps: Total number of diffusion steps T
beta_schedule: Schedule type — "linear" or "cosine"
beta_start: Starting beta value (used for linear schedule)
beta_end: Ending beta value (used for linear schedule)

Example:
>>> scheduler = NoiseScheduler(timesteps=1000, beta_schedule="linear")
>>> x0 = torch.randn(2, 1, 64, 64)
>>> t = torch.tensor([100, 500])
>>> xt, noise = scheduler.q_sample(x0, t)
>>> xt.shape
torch.Size([2, 1, 64, 64])
"""

def __init__(
self,
timesteps: int = 1000,
beta_schedule: str = "linear",
beta_start: float = 1e-4,
beta_end: float = 2e-2,
) -> None:
self.timesteps = timesteps
self.beta_schedule = beta_schedule
self.beta_start = beta_start
self.beta_end = beta_end

def q_sample(
self,
x0: torch.Tensor,
t: torch.Tensor,
noise: torch.Tensor | None = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Forward diffusion: add noise to clean image at timestep t.

Samples x_t ~ q(x_t | x_0) using the closed-form expression:
x_t = sqrt(alpha_cumprod_t) * x_0 + sqrt(1 - alpha_cumprod_t) * eps

Args:
x0: Clean image, shape (B, C, H, W)
t: Timestep indices, shape (B,)
noise: Optional pre-sampled noise; sampled from N(0,I) if None

Returns:
(x_t, noise) tuple — noisy image and the noise that was added
"""
raise NotImplementedError

def p_sample(
self,
model: torch.nn.Module,
xt: torch.Tensor,
t: torch.Tensor,
) -> torch.Tensor:
"""
Reverse diffusion: denoise one step from x_t to x_{t-1}.

Args:
model: U-Net that predicts noise given (x_t, t)
xt: Noisy image at timestep t, shape (B, C, H, W)
t: Current timestep indices, shape (B,)

Returns:
Denoised image x_{t-1}, shape (B, C, H, W)
"""
raise NotImplementedError

def p_sample_loop(
self,
model: torch.nn.Module,
shape: Tuple[int, ...],
device: torch.device,
) -> torch.Tensor:
"""
Full reverse diffusion loop from pure noise to clean image.

Args:
model: Trained U-Net
shape: Output shape (B, C, H, W)
device: Target device

Returns:
Generated clean image, shape (B, C, H, W)
"""
raise NotImplementedError
63 changes: 63 additions & 0 deletions DENOISING_DIFFUSION/src/models/unet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
"""U-Net backbone for the diffusion model."""

import torch
import torch.nn as nn
from typing import Tuple


class UNet(nn.Module):
    """
    U-Net backbone that predicts the noise component of a diffused image.

    Encoder/decoder with skip connections and per-stage timestep
    conditioning; self-attention is enabled at the (low-resolution)
    levels listed in ``attention_levels``.

    Args:
        in_channels: Number of input image channels (1 for grayscale)
        out_channels: Number of output channels (same as in_channels)
        base_channels: Base feature channels, doubled at each encoder level
        channel_multipliers: Per-level channel multipliers, e.g. (1, 2, 4, 8)
        num_res_blocks: Number of residual blocks per encoder/decoder level
        attention_levels: Which levels (0-indexed) to apply self-attention
        dropout: Dropout probability in residual blocks

    Example:
        >>> model = UNet(in_channels=1, out_channels=1, base_channels=64)
        >>> x = torch.randn(2, 1, 64, 64)
        >>> t = torch.randint(0, 1000, (2,))
        >>> out = model(x, t)
        >>> out.shape
        torch.Size([2, 1, 64, 64])
    """

    def __init__(
        self,
        in_channels: int = 1,
        out_channels: int = 1,
        base_channels: int = 64,
        channel_multipliers: Tuple[int, ...] = (1, 2, 4, 8),
        num_res_blocks: int = 2,
        attention_levels: Tuple[int, ...] = (2, 3),
        dropout: float = 0.1,
    ) -> None:
        super().__init__()
        # Record the architecture configuration. Layer construction is
        # deferred — the network is not implemented yet.
        self.dropout = dropout
        self.attention_levels = attention_levels
        self.num_res_blocks = num_res_blocks
        self.channel_multipliers = channel_multipliers
        self.base_channels = base_channels
        self.out_channels = out_channels
        self.in_channels = in_channels

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Predict the noise present in ``x`` at diffusion step ``t``.

        Args:
            x: Noisy image tensor, shape (B, in_channels, H, W)
            t: Diffusion timestep, shape (B,)

        Returns:
            Predicted noise, shape (B, out_channels, H, W)

        Raises:
            NotImplementedError: The forward pass is not implemented yet.
        """
        raise NotImplementedError
Loading