SARAN: Shallow Auto-Regressive Attention Network

The dominant paradigms in sequence transduction - Recurrent Neural Networks and deep Transformer architectures - rely on complex, multi-layered structures to achieve performance, often at the cost of interpretability and computational transparency. In this work, we introduce the Shallow Auto-Regressive Attention Network (SARAN), a minimalist architecture that reduces the Transformer decoder to its fundamental components. SARAN is defined by a strictly linear, 15-stage computational graph that maps input embeddings directly to output probabilities via a single, isolated block of masked self-attention. We present a "first principles" derivation of the network's training dynamics, explicitly defining the manual backpropagation algorithm through the attention mechanism without reliance on automatic differentiation engines. By stripping away deep layer stacking and feed-forward networks, SARAN demonstrates that a solitary attention block is sufficient to mechanically derive autoregressive properties, providing a transparent and rigorous baseline for understanding the mechanics of attention-based sequence modeling.


Quick Start

SARAN's complete training and inference pipeline is organized around five Python scripts plus a shared config.json:

0. Dataset Preparation (get_dataset.py)

Download and tokenize the OpenWebText dataset for pre-training:

# Activate virtual environment
source venv/bin/activate

# Download and tokenize OpenWebText (~30-60 min, streams from HuggingFace)
python get_dataset.py

Output: openwebtext_tokens.jsonl and openwebtext_offsets.npy

1. Pre-training (saran_mlv.py)

Pre-train the model on a large text corpus (OpenWebText) to learn general language understanding:

# Activate virtual environment
source venv/bin/activate

# Run pre-training (50k iterations, ~4-8 hours on GPU/MPS)
python saran_mlv.py

Output: saran_mlv_best.pt (best weights saved during training)

2. Fine-tuning (saran_mlv_ft.py)

Two-phase fine-tuning for professional chatbot capabilities:

Phase 1: Supervised Fine-Tuning (SFT)

# Run fine-tuning pipeline
python saran_mlv_ft.py

Phase 2: Direct Preference Optimization (DPO)

  • Automatically runs after SFT if dpo.enabled=true in config
  • Trains on preference data (Anthropic HH-RLHF or synthetic)
  • Creates reference model (frozen copy) for stable training

Output: saran_mlv_dpo_best.pt

3. Chat Inference (saran_mlv_c.py)

Run the interactive chatbot with agentic web search:

# Start the chat interface
python saran_mlv_c.py

Features:

  • Agentic Web Search + LLM Synthesis: Questions automatically search the web via DuckDuckGo, then the LLM synthesizes the results into a coherent response. Falls back to raw web results if synthesis fails.
  • Garbage Detection: Low-quality model outputs are caught and replaced with web results or "I don't know"
  • Conversation History: Maintains context from the last 10 turns
  • Configurable: All parameters loaded from config.json

Commands:

  • quit / exit / q - Exit the chat
  • clear / reset - Clear conversation history

4. Web Search Agent (web.py)

The web search agent provides real-time information retrieval:

import web
result = web.search("What is the capital of France?")
# Returns: "Paris is the capital and largest city of France..."

Features:

  • DuckDuckGo Instant Answer API (primary)
  • HTML scrape fallback for reliability
  • Automatic text cleaning (HTML entities, non-ASCII)
  • Truncation to complete sentences

5. Configuration (config.json)

All hyperparameters are centralized in config.json:

{
    "model": {
        "block_size": 512,
        "n_embd": 1536,
        "n_layer": 24,
        "vocab_size": 50304,
        "dropout": 0.1
    },
    "training": {
        "batch_size": 2,
        "grad_accum_steps": 32,
        "max_iters": 50000,
        "eval_interval": 1000,
        "eval_iters": 100,
        "learning_rate": 3e-4,
        "warmup_iters": 2000,
        "grad_clip": 1.0,
        "weight_decay": 0.1
    },
    "finetuning": {
        "batch_size": 2,
        "grad_accum_steps": 8,
        "max_iters": 50000,
        "eval_interval": 200,
        "learning_rate": 3e-5,
        "grad_clip": 1.0,
        "weight_decay": 0.01,
        "patience": 5
    },
    "dpo": {
        "enabled": true,
        "beta": 0.1,
        "learning_rate": 1e-6,
        "max_iters": 5000,
        "eval_interval": 100,
        "batch_size": 1,
        "grad_accum_steps": 16,
        "grad_clip": 1.0,
        "weight_decay": 0.01,
        "patience": 10,
        "dataset": "Anthropic/hh-rlhf"
    },
    "generation": {
        "max_new_tokens": 1024,
        "temperature": 0.7,
        "top_k": 40,
        "repetition_penalty": 1.3,
        "debug": false
    },
    "agents": {
        "web": {
            "enabled": true,
            "module": "web",
            "function": "search"
        }
    },
    "search_triggers": ["what", "who", "where", "when", "why", "how", ...]
}

| Section | Description |
|---------|-------------|
| model | Architecture hyperparameters (shared by all scripts) |
| training | Pre-training hyperparameters (saran_mlv.py) |
| finetuning | SFT hyperparameters (saran_mlv_ft.py Phase 1) |
| dpo | DPO hyperparameters (saran_mlv_ft.py Phase 2) |
| generation | Inference settings for chat (temperature, top_k, etc.) |
| agents | Agentic capabilities (web search, future agents) |
| search_triggers | Words that trigger automatic web search |
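
Every script reads this shared file at startup. A minimal sketch of what that loading might look like, assuming only the standard library (the actual loader code in the scripts may differ slightly):

import json

# Load the shared configuration and pull out a few hyperparameters
with open("config.json") as f:
    cfg = json.load(f)

B = cfg["training"]["batch_size"]        # 2
T = cfg["model"]["block_size"]           # 512
C = cfg["model"]["n_embd"]               # 1536
L = cfg["model"]["n_layer"]              # 24
vocab_size = cfg["model"]["vocab_size"]  # 50304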

Full Pipeline Example

# Complete training pipeline
cd /path/to/a-gpt
source venv/bin/activate

# Step 0: Download and tokenize dataset
python get_dataset.py

# Step 1: Pre-train (creates language model)
python saran_mlv.py

# Step 2: Fine-tune (teaches instruction following)
python saran_mlv_ft.py

# Step 3: Chat with agentic web search!
python saran_mlv_c.py

File Overview

| File | Purpose |
|------|---------|
| saran_mlv.py | Pre-training on OpenWebText |
| saran_mlv_ft.py | Fine-tuning on Alpaca instructions |
| saran_mlv_c.py | Chat interface with web search |
| web.py | DuckDuckGo search agent |
| config.json | Centralized configuration |
| get_dataset.py | Dataset download and tokenization |

1. SARAN vs GPT: Key Innovations

SARAN makes several architectural simplifications and engineering optimizations compared to a standard GPT-2-style decoder:

| Feature | GPT | SARAN | Benefit |
|---------|-----|-------|---------|
| Attention Heads | 12 multi-head | 1 single-head | Simpler, more interpretable |
| FFN Expansion | 4× (768→3072) | 4× (1536→6144) | Same expansion ratio, larger capacity for synthesis |
| Normalization | LayerNorm | RMSNorm | Faster computation |
| Activation | GELU | SiLU (Swish) | Modern, smooth gradients |
| Weight Tying | No | Yes (embed = output) | Fewer parameters |
| Biases | Yes | No (in Linear layers) | Fewer parameters |
| Precision | float32 | bfloat16 (mixed) | ~2x faster, 50% less memory |
| Compilation | No | torch.compile (CUDA) | ~1.5-2x faster on GPU |
| Flash Attention | No | Yes (SDPA) | O(T) memory, ~2-4x faster |
| Parameters | ~125M (GPT-2) | ~760M | Better synthesis quality |

These changes yield a simpler, more transparent architecture: SARAN trades a larger parameter count for single-head, full-width attention and faster per-step computation while maintaining competitive performance.


2. Configuration & Hyperparameters

The model is configured with these key hyperparameters:

B, T, C, L = 2, 512, 1536, 24

| Symbol | Name | Value | Description |
|--------|------|-------|-------------|
| $B$ | Batch Size | 2 | Sequences per micro-batch |
| $T$ | Context Length | 512 | Maximum sequence length (tokens) |
| $C$ | Embedding Dimension | 1536 | Size of token/positional embeddings |
| $L$ | Number of Layers | 24 | Transformer blocks stacked |
| $V$ | Vocabulary Size | 50,304 | Padded for GPU efficiency |

Note: SARAN has no $H$ (heads) parameter because it uses single-head attention. The full embedding dimension $C = 1536$ is used for attention, not split across heads.

Model Size: ~760M parameters (with 4x FFN expansion for better synthesis quality)

Additional training hyperparameters:

| Parameter | Value | Description |
|-----------|-------|-------------|
| grad_accum_steps | 32 | Gradient accumulation steps |
| lr | 3e-4 | Learning rate (with 2000-step warmup) |
| warmup_iters | 2000 | Linear warmup steps before cosine decay |
| grad_clip | 1.0 | Gradient clipping threshold |
| dropout | 0.1 | Dropout for regularization |
| weight_decay | 0.1 | AdamW weight decay |

Effective batch size: $B \times G = 2 \times 32 = 64$ (where $G$ = gradient accumulation steps)


3. Data Pipeline

Tokenization

Text is converted to integer tokens using the GPT-2 BPE (Byte Pair Encoding) tokenizer:

enc = tiktoken.get_encoding("gpt2")
# Round vocab to nearest 64 for GPU efficiency (Karpathy's nanoGPT optimization)
# GPT-2 has 50257 tokens, but 50304 = 64 * 786 aligns with tensor core block sizes
vocab_size = 50304
encode = lambda s: enc.encode(s)
decode = lambda l: enc.decode(list(l))

Why 50,304 instead of 50,257?

  • GPT-2 tokenizer has exactly 50,257 tokens
  • 50,304 = 64 × 786 — aligns with GPU tensor core block sizes
  • Provides ~5-15% speedup on embedding/output layers
  • Only 47 extra "padding" tokens (never predicted)
  • Credit: Andrej Karpathy's nanoGPT optimization

Example:

"Hello world" → [15496, 995]

Batch Creation (OpenWebText)

SARAN uses a memory-mapped approach to load pre-tokenized OpenWebText data:

offsets = np.load("openwebtext_offsets.npy")
tokens_f = open("openwebtext_tokens.jsonl", "rb")
num_examples = len(offsets)  # one offset per tokenized document

def get_batch(split):
    split_idx = int(0.9 * num_examples)
    start_i, end_i = (0, split_idx) if split == "train" else (split_idx, num_examples)
    x_list, y_list = [], []
    while len(x_list) < B:
        ex_id = start_i + np.random.randint(0, end_i - start_i)
        tokens_f.seek(int(offsets[ex_id]))
        toks = json.loads(tokens_f.readline().decode("utf-8"))
        if len(toks) <= T:
            continue
        start = np.random.randint(0, len(toks) - T)
        x_list.append(toks[start : start + T])
        y_list.append(toks[start + 1 : start + T + 1])
    return torch.tensor(x_list, device=device), torch.tensor(y_list, device=device)

This approach:

  1. Uses file offsets to randomly access documents
  2. Samples a random starting position within each document
  3. Extracts input/target pairs for next-token prediction

Shape Example:

  • Input shape: $(B, T) = (2, 512)$
  • Target shape: $(B, T) = (2, 512)$

4. Execution Flow Overview

When training, execution follows this path:

                              ┌─────────────────────────────────────────────────────┐
                              │                   SARAN Class                       │
                              ├─────────────────────────────────────────────────────┤
Input Tokens ──► Token Embed ──► + ──► Block ──► Block ──► ... ──► Block ──► RMSNorm ──► Linear ──► Logits
  (B, T)           (B,T,C)      │       ×1        ×2              ×24         (B,T,C)      (B,T,V)
                                │                                                            ↑
                    Pos Embed ──┘                                                            │
                      (T, C)                                              (weight tying) ────┘

Each Block contains:

Input ──► RMSNorm ──► Single-Head Attention ──► + ──► RMSNorm ──► FFN (4x) ──► + ──► Output
  │                                             │                              │
  └────────────── (residual) ───────────────────┘────────── (residual) ────────┘

Key differences from GPT:

  • RMSNorm instead of LayerNorm
  • Single-head attention instead of 12-head MHA
  • SiLU activation instead of GELU

5. The SARAN Class: The Heart of the Model

class SARAN(nn.Module):
    def __init__(self):
        super().__init__()
        self.tok = nn.Embedding(vocab_size, C)                     # Token embeddings
        self.pos = nn.Embedding(T, C)                              # Positional embeddings
        self.blocks = nn.Sequential(*[Block() for _ in range(L)])  # 24 transformer blocks
        self.ln = RMSNorm(C)                                       # Final RMSNorm
        self.head = nn.Linear(C, vocab_size, bias=False)           # Output projection
        self.tok.weight = self.head.weight                         # Weight tying!
        self.apply(self._init_weights)

The forward pass:

def forward(self, idx, tgt=None):
    x = self.tok(idx) + self.pos(torch.arange(idx.shape[1], device=device))
    logits = self.head(self.ln(self.blocks(x)))
    loss = F.cross_entropy(logits.view(-1, logits.size(-1)), tgt.view(-1)) if tgt is not None else None
    return logits, loss

Let's trace a concrete example through the entire network.


6. Token Embeddings

The token embedding layer maps each token ID to a dense vector:

$$\mathbf{E}_{tok} \in \mathbb{R}^{V \times C} = \mathbb{R}^{50304 \times 1536}$$

For input tokens $\mathbf{x} \in \mathbb{Z}^{B \times T}$:

$$\mathbf{X}_{tok} = \text{Embedding}(\mathbf{x}) \in \mathbb{R}^{B \times T \times C}$$

Concrete Example:

Suppose our input is the tokens [15496, 995, 0] representing "Hello world" plus one extra token (ID 0).

For token ID 15496:

  • Look up row 15496 in $\mathbf{E}_{tok}$
  • Retrieve a 1536-dimensional vector, e.g.: $[0.02, -0.15, 0.08, ..., 0.11]$

Each of the 50,257 possible tokens has its own learned 1536-dimensional representation.

Memory (but shared with output via weight tying!): $$50304 \times 1536 = 77,266,944 \text{ parameters} \approx 77.3\text{M}$$


7. Positional Embeddings

Transformers have no inherent notion of sequence order. Positional embeddings inject position information:

$$\mathbf{E}_{pos} \in \mathbb{R}^{T \times C} = \mathbb{R}^{512 \times 1536}$$

For each position $t \in \{0, 1, \dots, 511\}$, we retrieve a learned 1536-dimensional vector.

The combined embedding is:

$$\mathbf{X} = \mathbf{X}_{tok} + \mathbf{E}_{pos}$$

Concrete Example:

Position 0 embedding: $\mathbf{p}_0 = [0.01, 0.03, -0.02, ..., 0.05]$
Position 1 embedding: $\mathbf{p}_1 = [-0.02, 0.01, 0.04, ..., -0.03]$

If token "Hello" at position 0 has embedding $[0.02, -0.15, 0.08, ...]$:

$$\mathbf{x}_0 = [0.02 + 0.01, -0.15 + 0.03, 0.08 + (-0.02), ...] = [0.03, -0.12, 0.06, ...]$$

Memory: $$512 \times 1536 = 786,432 \text{ parameters} \approx 0.8\text{M}$$
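
A minimal sketch of how the two lookups combine, using toy sizes instead of the real $C = 1536$; the point is the broadcast of the $(T, C)$ positional table across the batch:

import torch
import torch.nn as nn

B, T, C, vocab_size = 2, 4, 8, 16           # toy sizes for illustration
tok = nn.Embedding(vocab_size, C)           # token embedding table (V, C)
pos = nn.Embedding(T, C)                    # positional embedding table (T, C)

idx = torch.randint(0, vocab_size, (B, T))  # batch of token IDs, shape (B, T)
x = tok(idx) + pos(torch.arange(T))         # (B, T, C) + (T, C) broadcasts over B
print(x.shape)                              # torch.Size([2, 4, 8])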


8. The Transformer Block

Each of the 24 blocks applies the same structure with different learned weights:

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.ln1, self.ln2, self.attn, self.ffn = RMSNorm(C), RMSNorm(C), Attn(), FFN(C)

    def forward(self, x):
        x = x + self.attn(self.ln1(x))   # Attention with residual
        return x + self.ffn(self.ln2(x)) # FFN with residual

Mathematically, for block $\ell$:

$$\mathbf{X}^{(\ell)} = \mathbf{X}^{(\ell-1)} + \text{Attn}(\text{RMSNorm}(\mathbf{X}^{(\ell-1)}))$$

$$\mathbf{X}^{(\ell)} = \mathbf{X}^{(\ell)} + \text{FFN}(\text{RMSNorm}(\mathbf{X}^{(\ell)}))$$

This is the Pre-Norm formulation, where normalization is applied before each sublayer.

Residual Connections

The + operations are residual (skip) connections. They:

  1. Allow gradients to flow directly backward through the network
  2. Enable the model to learn identity mappings easily
  3. Stabilize training of deep networks

Without residuals, training a 24-layer network would be extremely difficult due to vanishing gradients.


9. RMSNorm (Root Mean Square Normalization)

SARAN uses RMSNorm instead of LayerNorm. RMSNorm is simpler and faster:

class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps, self.weight = eps, nn.Parameter(torch.ones(dim))

    def forward(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight

Mathematically:

$$\text{RMSNorm}(\mathbf{x}) = \gamma \odot \frac{\mathbf{x}}{\text{RMS}(\mathbf{x})}$$

Where: $$\text{RMS}(\mathbf{x}) = \sqrt{\frac{1}{C} \sum_{i=1}^{C} x_i^2 + \epsilon}$$

And $\gamma \in \mathbb{R}^C$ is a learned scale parameter (initialized to ones).

Concrete Example:

For a single position's embedding $\mathbf{x} = [2.0, 4.0, 6.0, 8.0]$ (simplified to 4D):

$$\text{RMS} = \sqrt{\frac{2^2 + 4^2 + 6^2 + 8^2}{4}} = \sqrt{\frac{4 + 16 + 36 + 64}{4}} = \sqrt{30} \approx 5.48$$

Normalized (assuming $\gamma = [1,1,1,1]$):

$$\hat{\mathbf{x}} = \frac{[2, 4, 6, 8]}{5.48} = [0.37, 0.73, 1.09, 1.46]$$
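
The worked example can be verified in a couple of lines (ignoring the tiny $\epsilon$ term):

import torch

x = torch.tensor([2.0, 4.0, 6.0, 8.0])
rms = x.pow(2).mean().sqrt()   # sqrt((4 + 16 + 36 + 64) / 4) = sqrt(30) ≈ 5.477
print(x / rms)                 # tensor([0.3651, 0.7303, 1.0954, 1.4606])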

RMSNorm vs LayerNorm

| Property | LayerNorm | RMSNorm |
|----------|-----------|---------|
| Centering | Yes (subtracts mean) | No |
| Scaling | By std deviation | By RMS |
| Parameters | $\gamma$ and $\beta$ (2C) | $\gamma$ only (C) |
| Speed | Slower | ~15% faster |

RMSNorm removes the mean-centering step, which empirically doesn't hurt performance but speeds up computation.

Parameters per RMSNorm: $C = 1536$ (only scale, no shift)


10. Single-Head Attention (SARAN's Simplicity Innovation)

Unlike GPT's multi-head attention, SARAN uses single-head attention operating on the full embedding dimension:

class Attn(nn.Module):
    def __init__(self):
        super().__init__()
        self.qkv = nn.Linear(C, 3 * C, bias=False)  # Fused Q, K, V projection
        self.proj = nn.Linear(C, C, bias=False)     # Output projection
        self.register_buffer("mask", torch.triu(torch.ones(T, T), diagonal=1).bool())

    def forward(self, x):
        _, t, _ = x.shape
        q, k, v = self.qkv(x).split(C, dim=-1)
        w = (q @ k.transpose(-2, -1) * C**-0.5).masked_fill(self.mask[:t, :t], float("-inf"))
        return self.proj(F.softmax(w, dim=-1) @ v)

Key difference from GPT:

  • GPT: 12 heads, each with $d_k = 64$ dimensions
  • SARAN: 1 head with $d_k = 1536$ dimensions (full embedding)

Step-by-Step Breakdown

Step 1: Compute Q, K, V

The input $\mathbf{X} \in \mathbb{R}^{B \times T \times C}$ is projected:

$$[\mathbf{Q}, \mathbf{K}, \mathbf{V}] = \mathbf{X}\mathbf{W}^{QKV}$$

Where $\mathbf{W}^{QKV} \in \mathbb{R}^{C \times 3C} = \mathbb{R}^{1536 \times 4608}$ (no bias!)

Then split into three tensors, each $\in \mathbb{R}^{B \times T \times C}$

Concrete Example (simplified):

Let's trace a tiny example with $T=3$ positions and $C = 4$:

Input at 3 positions:

$$ \mathbf{X} = \begin{bmatrix} x_0 \\ x_1 \\ x_2 \end{bmatrix} \in \mathbb{R}^{3 \times 4} $$

After projection (assume weights give these results):

$$ \mathbf{Q} = \begin{bmatrix} 1 & 0 & 1 & 0 \\ 0 & 1 & 0 & 1 \\ 1 & 1 & 0 & 0 \end{bmatrix}, \quad \mathbf{K} = \begin{bmatrix} 1 & 1 & 0 & 0 \\ 0 & 1 & 1 & 0 \\ 0 & 0 & 1 & 1 \end{bmatrix}, \quad \mathbf{V} = \begin{bmatrix} 1 & 0 & 0 & 0 \\ 0 & 1 & 0 & 0 \\ 0 & 0 & 1 & 0 \end{bmatrix} $$

Step 2: Compute Attention Scores

$$\text{scores} = \frac{\mathbf{Q}\mathbf{K}^T}{\sqrt{C}}$$

Note: SARAN scales by $\sqrt{C} = \sqrt{1536} \approx 39.2$, not $\sqrt{d_k} = \sqrt{64} = 8$ as in GPT.

$$ \mathbf{Q}\mathbf{K}^T = \begin{bmatrix} 1 & 0 & 1 & 0 \\ 0 & 1 & 0 & 1 \\ 1 & 1 & 0 & 0 \end{bmatrix} \begin{bmatrix} 1 & 0 & 0 \\ 1 & 1 & 0 \\ 0 & 1 & 1 \\ 0 & 0 & 1 \end{bmatrix} = \begin{bmatrix} 1 & 1 & 1 \\ 1 & 1 & 1 \\ 2 & 1 & 0 \end{bmatrix} $$

Scaling by $\frac{1}{\sqrt{4}} = 0.5$ (in our simplified 4D example):

$$ \text{scores} = \begin{bmatrix} 0.5 & 0.5 & 0.5 \\ 0.5 & 0.5 & 0.5 \\ 1.0 & 0.5 & 0.0 \end{bmatrix} $$

Step 3: Apply Causal Mask

The causal mask prevents attending to future positions. SARAN uses torch.triu with diagonal=1:

$$ \text{mask} = \begin{bmatrix} 0 & 1 & 1 \\ 0 & 0 & 1 \\ 0 & 0 & 0 \end{bmatrix} $$

Where 1 means "mask out" (set to $-\infty$):

$$ \mathbf{S}_{\text{masked}} = \begin{bmatrix} 0.5 & -\infty & -\infty \\ 0.5 & 0.5 & -\infty \\ 1.0 & 0.5 & 0.0 \end{bmatrix} $$

Step 4: Softmax

Apply softmax row-wise. Since $e^{-\infty} = 0$:

Row 0: $\text{softmax}([0.5, -\infty, -\infty]) = [1.0, 0, 0]$

Row 1: $\text{softmax}([0.5, 0.5, -\infty]) = [0.5, 0.5, 0]$

Row 2: $\text{softmax}([1.0, 0.5, 0.0])$:

  • $e^{1.0} = 2.72$, $e^{0.5} = 1.65$, $e^{0.0} = 1.0$
  • Sum = 5.37
  • $= [0.51, 0.31, 0.19]$

$$ \text{attn} = \begin{bmatrix} 1.0 & 0 & 0 \\ 0.5 & 0.5 & 0 \\ 0.51 & 0.31 & 0.19 \end{bmatrix} $$

Step 5: Weighted Sum of Values

$$\mathbf{A}_{\text{out}} = \mathbf{A} \times \mathbf{V}$$

$$ = \begin{bmatrix} 1.0 & 0 & 0 \\ 0.5 & 0.5 & 0 \\ 0.51 & 0.31 & 0.19 \end{bmatrix} \begin{bmatrix} 1 & 0 & 0 & 0 \\ 0 & 1 & 0 & 0 \\ 0 & 0 & 1 & 0 \end{bmatrix} = \begin{bmatrix} 1.0 & 0 & 0 & 0 \\ 0.5 & 0.5 & 0 & 0 \\ 0.51 & 0.31 & 0.19 & 0 \end{bmatrix} $$

Interpretation:

  • Position 0 only sees itself (attention weight 1.0 on position 0)
  • Position 1 sees positions 0 and 1 equally (0.5 each)
  • Position 2 sees all previous positions with decaying attention
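
The whole toy calculation (Steps 2-5) can be reproduced with a few tensor operations; the Q, K, V values below are the assumed matrices from Step 1, not real model activations:

import torch
import torch.nn.functional as F

Q = torch.tensor([[1., 0., 1., 0.], [0., 1., 0., 1.], [1., 1., 0., 0.]])
K = torch.tensor([[1., 1., 0., 0.], [0., 1., 1., 0.], [0., 0., 1., 1.]])
V = torch.tensor([[1., 0., 0., 0.], [0., 1., 0., 0.], [0., 0., 1., 0.]])

scores = (Q @ K.T) * 4 ** -0.5                     # scale by 1/sqrt(C) with C = 4
mask = torch.triu(torch.ones(3, 3), diagonal=1).bool()
scores = scores.masked_fill(mask, float("-inf"))   # causal mask
attn = F.softmax(scores, dim=-1)                   # rows: [1, 0, 0], [0.5, 0.5, 0], [0.51, 0.31, 0.19]
print(attn @ V)                                    # matches the matrix in Step 5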

Step 6: Output Projection

Unlike GPT where concatenation of head outputs is projected, SARAN directly projects the single-head output:

$$\mathbf{O} = \mathbf{A}_{\text{out}} \cdot \mathbf{W}^{O}$$

Where $\mathbf{W}^{O} \in \mathbb{R}^{C \times C} = \mathbb{R}^{1536 \times 1536}$ (no bias!)

The Attention Formula (Complete)

$$\text{Attention}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \text{softmax}\left(\frac{\mathbf{Q}\mathbf{K}^T}{\sqrt{C}} + \mathbf{M}\right)\mathbf{V}$$

Where $\mathbf{M}$ is the causal mask ($0$ for allowed positions, $-\infty$ for masked).

Why Single-Head Works

Multi-head attention was designed to let the model attend to different aspects in parallel. However:

  1. With sufficient depth (24 layers), single-head attention can learn diverse patterns across layers
  2. Full $C$-dimensional attention captures richer relationships per layer
  3. Simpler architecture means easier optimization and interpretation
  4. Fewer parameters without significant quality loss

11. Feed-Forward Network (4x Expansion)

SARAN uses standard GPT-style 4x expansion for maximum synthesis capacity:

class FFN(nn.Module):
    def __init__(self, dim, dropout=0.0):
        super().__init__()
        hidden = dim * 4  # 4x expansion
        self.w1 = nn.Linear(dim, hidden, bias=False)
        self.w2 = nn.Linear(hidden, dim, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.dropout(self.w2(F.silu(self.w1(x))))

Mathematically:

$$\text{FFN}(\mathbf{x}) = \text{Dropout}(\mathbf{W}_2 \cdot \text{SiLU}(\mathbf{W}_1 \mathbf{x}))$$

Where:

  • $\mathbf{W}_1 \in \mathbb{R}^{6144 \times 1536}$ (expansion to 4×)
  • $\mathbf{W}_2 \in \mathbb{R}^{1536 \times 6144}$ (projection back)
  • No biases!

SiLU Activation

SARAN uses SiLU (Sigmoid Linear Unit), also known as Swish:

$$\text{SiLU}(x) = x \cdot \sigma(x) = \frac{x}{1 + e^{-x}}$$

Example values:

| $x$ | $\sigma(x)$ | SiLU($x$) |
|-----|-------------|-----------|
| -2.0 | 0.12 | -0.24 |
| -1.0 | 0.27 | -0.27 |
| 0.0 | 0.50 | 0.0 |
| 1.0 | 0.73 | 0.73 |
| 2.0 | 0.88 | 1.76 |

SiLU is smooth, non-monotonic (has a small negative region), and has been shown to work well in modern architectures like LLaMA and PaLM.
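
The table values can be reproduced directly with F.silu:

import torch
import torch.nn.functional as F

x = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0])
print(F.silu(x))  # tensor([-0.2384, -0.2689,  0.0000,  0.7311,  1.7616])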

SiLU vs GELU comparison:

| Property | GELU | SiLU |
|----------|------|------|
| Formula | $x \cdot \Phi(x)$ | $x \cdot \sigma(x)$ |
| Min value | ~-0.17 at x≈-0.75 | ~-0.28 at x≈-1.28 |
| Computation | Slower (erf) | Faster (sigmoid) |
| Usage | GPT, BERT | LLaMA, SARAN |

Dropout Regularization

SARAN applies dropout (default 0.1) at multiple points to prevent overfitting:

  1. After embedding summation — regularizes input representations
  2. Inside attention (SDPA), via the dropout_p parameter during training
  3. After attention output projection — residual dropout
  4. After FFN output — before residual addition

# Attention with dropout
out = F.scaled_dot_product_attention(
    q, k, v, is_causal=True, dropout_p=self.dropout if self.training else 0.0
)
return self.resid_dropout(self.out_proj(out))

# FFN with dropout
return self.dropout(self.w2(F.silu(self.w1(x))))

Dropout is disabled during inference (model.eval()) for deterministic outputs.


12. Weight Tying

SARAN ties the token embedding matrix to the output projection matrix:

self.tok = nn.Embedding(vocab_size, C)
self.head = nn.Linear(C, vocab_size, bias=False)
self.tok.weight = self.head.weight  # Weight tying!

This means:

$$\mathbf{E}_{\text{tok}} = \mathbf{W}_{\text{out}}^T$$

The same matrix is used for:

  1. Encoding: token ID → embedding vector
  2. Decoding: hidden state → vocabulary logits

Benefits:

  • Fewer parameters: Saves $V \times C \approx 77.3\text{M}$ parameters
  • Semantic consistency: Similar tokens have similar embeddings AND similar output distributions
  • Regularization effect: Constrains the model's representation space

Memory savings: $$50304 \times 1536 = 77,266,944 \text{ parameters saved}$$
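
Tying is easy to confirm: both modules reference the same tensor, and PyTorch counts the shared matrix only once (a sketch assuming the SARAN class from section 5):

model = SARAN()

# The embedding table and the output head are literally the same tensor object
print(model.tok.weight is model.head.weight)       # True

# parameters() de-duplicates shared tensors, so the tied matrix is counted once
print(sum(p.numel() for p in model.parameters()))  # ≈ 757.6M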


13. Output Projection & Loss

After all transformer blocks, we project to vocabulary logits:

logits = self.head(self.ln(self.blocks(x)))

$$\text{logits} = \mathbf{W}_{out} \cdot \text{RMSNorm}(\mathbf{X}^{(L)})$$

Where $\mathbf{W}_{out} \in \mathbb{R}^{V \times C} = \mathbb{R}^{50304 \times 1536}$ (shared with embedding!)

Output shape: $(B, T, V) = (2, 512, 50304)$

Each position produces a 50,304-dimensional vector of logits (unnormalized log-probabilities); only the 50,257 real GPT-2 tokens ever appear as targets.

Cross-Entropy Loss

For training, we compute cross-entropy loss between predictions and targets:

$$\mathcal{L} = -\frac{1}{BT}\sum_{b=1}^{B}\sum_{t=1}^{T} \log P(y_{b,t} \mid x_{b,1:t-1})$$

Where $P(y \mid x) = \text{softmax}(\text{logits})_y$

Concrete Example:

For a single position predicting token 42:

  • Logits: $[1.2, 0.5, ..., 3.8_{(42)}, ..., 0.1]$ (50,257 values)
  • Softmax: $[0.001, 0.0005, ..., 0.15_{(42)}, ..., 0.0003]$
  • Loss: $-\log(0.15) = 1.90$

A loss of ~2.0 means the model assigns roughly $e^{-2} \approx 13.5\%$ probability to the correct token on average.
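
The relationship between the loss and the probability assigned to the correct token can be checked directly; the logits here are random toy values, not model outputs:

import torch
import torch.nn.functional as F

logits = torch.randn(1, 50304)          # one position, vocabulary-sized logits
target = torch.tensor([42])             # the "correct" next token

loss = F.cross_entropy(logits, target)
prob = F.softmax(logits, dim=-1)[0, 42]
print(loss.item(), -prob.log().item())  # identical: loss = -log P(target)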


14. Text Generation

Generation uses autoregressive sampling:

def generate(self, idx, n, temp=0.8, top_k=40):
    for _ in range(n):
        logits = self(idx[:, -T:])[0][:, -1, :] / temp
        if top_k:
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < v[:, [-1]]] = float("-inf")
        idx = torch.cat([idx, torch.multinomial(F.softmax(logits, -1), 1)], 1)
    return idx

Temperature Scaling

Temperature $\tau$ controls randomness:

$$P(w_i) = \frac{e^{z_i / \tau}}{\sum_j e^{z_j / \tau}}$$

| Temperature | Effect |
|-------------|--------|
| $\tau < 1$ | Sharper distribution, more deterministic |
| $\tau = 1$ | Original distribution |
| $\tau > 1$ | Flatter distribution, more random |

Example with logits $[2.0, 1.0, 0.5]$:

| $\tau$ | Probabilities |
|--------|---------------|
| 0.5 | [0.84, 0.11, 0.04] |
| 1.0 | [0.63, 0.23, 0.14] |
| 2.0 | [0.48, 0.29, 0.23] |
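
These probabilities follow directly from a softmax over the scaled logits:

import torch
import torch.nn.functional as F

logits = torch.tensor([2.0, 1.0, 0.5])
for temp in (0.5, 1.0, 2.0):
    print(temp, F.softmax(logits / temp, dim=-1))
# 0.5 -> [0.844, 0.114, 0.042]
# 1.0 -> [0.629, 0.231, 0.140]
# 2.0 -> [0.481, 0.292, 0.227]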

Top-k Sampling

SARAN uses top-k=40 by default (GPT uses 50):

  1. Find the 40th largest logit value
  2. Set all logits below this threshold to $-\infty$
  3. Renormalize with softmax
  4. Sample from this truncated distribution

Example: If top-k=3 and logits are $[5, 3, 2, 1, 0.5]$:

  • Top 3 values: $[5, 3, 2]$
  • Threshold: 2
  • After masking: $[5, 3, 2, -\infty, -\infty]$
  • After softmax: $[0.84, 0.11, 0.04, 0, 0]$

15. Training Loop

SARAN's training loop includes gradient accumulation, cosine annealing, and gradient clipping:

for i in range(max_iters):
    opt.zero_grad(set_to_none=True)
    for _ in range(grad_accum_steps):
        loss = model(*get_batch("train"))[1]
        (loss / grad_accum_steps).backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    opt.step()
    sched.step()

Gradient Accumulation

With batch size 2 and 32 accumulation steps:

$$\text{Effective Batch Size} = B \times G = 2 \times 32 = 64$$

This allows training with a large effective batch size on limited GPU memory.

AdamW Optimizer

SARAN uses AdamW with specific hyperparameters:

opt = torch.optim.AdamW(model.parameters(), lr=lr, betas=(0.9, 0.95), weight_decay=0.1)

$$m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t$$ $$v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2$$ $$\hat{m}_t = \frac{m_t}{1 - \beta_1^t}, \quad \hat{v}_t = \frac{v_t}{1 - \beta_2^t}$$ $$\theta_t = \theta_{t-1} - \alpha \left( \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} + \lambda \theta_{t-1} \right)$$

Where:

  • $\alpha = 3 \times 10^{-4}$ (peak learning rate, with warmup)
  • $\lambda = 0.1$ (weight decay)
  • $\beta_1 = 0.9$, $\beta_2 = 0.95$ (momentum terms — note $\beta_2$ is lower than typical 0.999)

Learning Rate Schedule (Warmup + Cosine Decay)

SARAN uses linear warmup followed by cosine decay:

def get_lr(it):
    # Linear warmup for warmup_iters steps
    if it < warmup_iters:
        return learning_rate * (it + 1) / warmup_iters
    # Cosine decay after warmup
    decay_ratio = (it - warmup_iters) / (max_iters - warmup_iters)
    coeff = 0.5 * (1.0 + np.cos(np.pi * decay_ratio))
    return learning_rate / 10 + coeff * (learning_rate - learning_rate / 10)

The learning rate schedule:

  1. Warmup (steps 0-2000): Linear increase from 0 to $\eta_{max}$
  2. Cosine decay (steps 2000-50000): Smooth decay to $\eta_{min}$

$$\eta_t = \begin{cases} \eta_{max} \cdot \frac{t+1}{T_{warmup}} & t < T_{warmup} \\ \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 + \cos\left(\frac{t - T_{warmup}}{T_{max} - T_{warmup}}\pi\right)\right) & t \geq T_{warmup} \end{cases}$$

  • Warmup: 2,000 steps
  • Peak: $\eta_{max} = 3 \times 10^{-4}$
  • Minimum: $\eta_{min} = 3 \times 10^{-5}$
  • Total: 50,000 iterations
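
Plugging a few iterations into get_lr above shows the shape of the schedule (using the same warmup_iters = 2000, max_iters = 50000, learning_rate = 3e-4):

for it in (0, 1000, 2000, 26000, 49999):
    print(it, f"{get_lr(it):.2e}")
# 0     -> 1.50e-07  (start of warmup)
# 1000  -> 1.50e-04  (halfway through warmup)
# 2000  -> 3.00e-04  (peak learning rate)
# 26000 -> 1.65e-04  (mid-decay, cosine coefficient = 0.5)
# 49999 -> ~3.00e-05 (approaching the floor of lr / 10)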

Gradient Clipping

torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)

If $\|\nabla\| > 1.0$, gradients are scaled down:

$$\nabla' = \nabla \cdot \frac{1.0}{\|\nabla\|}$$

This prevents exploding gradients and stabilizes training.

Mixed Precision (bfloat16)

SARAN uses automatic mixed precision (AMP) for faster training on GPU/MPS:

from torch.amp import autocast

# Device-aware dtype selection
use_amp = device in ("mps", "cuda")
amp_dtype = torch.bfloat16 if use_amp else torch.float32

# Training loop with autocast
with autocast(device_type=device, dtype=amp_dtype, enabled=use_amp):
    _, loss = model(xb, yb)
(loss / grad_accum_steps).backward()

Precision Comparison:

| Precision | Bits | Memory | Speed | Stability |
|-----------|------|--------|-------|-----------|
| float32 | 32 | 100% | 1x | Best |
| float16 | 16 | 50% | ~2x | Needs loss scaling |
| bfloat16 | 16 | 50% | ~1.5-2x | Very stable |

Why bfloat16?

  • Same exponent range as float32 (8 bits) — no overflow issues
  • Reduced mantissa (7 bits vs 23) — slightly less precision
  • No loss scaling required (unlike float16)
  • Native support on Apple Silicon (MPS) and modern NVIDIA GPUs
  • Typical loss difference: < 0.01-0.05 (negligible)

Memory savings: $$\sim 760\text{M params} \times 2\text{ bytes} \approx 1.5\text{ GB} \quad \text{(vs} \approx 3\text{ GB with float32)}$$

JIT Compilation (torch.compile)

PyTorch 2.0+ offers torch.compile for Just-In-Time (JIT) compilation, fusing operations for faster execution:

# Compile model for faster execution (CUDA only)
if hasattr(torch, "compile") and device == "cuda":
    print("Compiling model with torch.compile...")
    model = torch.compile(model, mode="reduce-overhead")

Compilation Modes:

| Mode | Speedup | Compile Time | Best For |
|------|---------|--------------|----------|
| default | ~1.5x | Moderate | General use |
| reduce-overhead | ~1.5-2x | Longer | Small batches, LLMs |
| max-autotune | ~2x+ | Very long | Max throughput needed |

How it works:

  1. Graph capture: Traces model execution into a graph
  2. Fusion: Combines multiple operations (e.g., matmul + add + activation)
  3. Codegen: Generates optimized CUDA kernels via Triton

Platform support:

  • CUDA: Full support, ~1.5-2x speedup typical
  • MPS: Limited/experimental, ~1.1-1.3x speedup
  • CPU: Works but minimal gains

Why CUDA-only in SARAN:

  • MPS backend still experimental with torch.compile
  • CPU gains are negligible and add startup overhead
  • CUDA's Triton backend provides the most reliable speedups

Flash Attention (scaled_dot_product_attention)

SARAN uses PyTorch 2.0's F.scaled_dot_product_attention for efficient attention computation:

# Flash Attention - replaces manual attention computation
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

What it replaces:

# Old manual implementation (O(T²) memory)
scores = (q @ k.transpose(-2, -1)) * scale
scores = scores.masked_fill(causal_mask, float("-inf"))
attn = F.softmax(scores, dim=-1)
out = attn @ v

Backend selection (automatic):

| Backend | Device | Memory | Speed | Requirements |
|---------|--------|--------|-------|--------------|
| Flash Attention v2 | CUDA | O(T) | Fastest | Ampere+ GPU |
| Memory-Efficient (xFormers) | CUDA | O(T) | Fast | Any CUDA GPU |
| Math Fallback | All | O(T²) | Baseline | Always available |

Benefits:

  • O(T) memory vs O(T²) — enables much longer sequences
  • ~2-4x faster on CUDA for T > 256
  • Fused kernel — no intermediate tensors for scores/softmax
  • is_causal=True — handles causal masking internally (no manual mask needed)
  • Single-head compatible — works perfectly with SARAN's 1-head design

Memory comparison at T=512: $$\text{Manual: } 512 \times 512 \times 4 \text{ bytes} = 1\text{ MB per batch}$$ $$\text{Flash: } O(T) \approx 2\text{ KB per batch}$$
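
The fused kernel computes the same result as the manual path (up to floating-point precision), which is easy to confirm on random tensors:

import torch
import torch.nn.functional as F

B, T, C = 2, 16, 32                               # toy sizes
q, k, v = (torch.randn(B, T, C) for _ in range(3))

fused = F.scaled_dot_product_attention(q, k, v, is_causal=True)

scores = (q @ k.transpose(-2, -1)) * C ** -0.5    # manual O(T²) path
mask = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)
manual = F.softmax(scores.masked_fill(mask, float("-inf")), dim=-1) @ v

print(torch.allclose(fused, manual, atol=1e-5))   # True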


16. Direct Preference Optimization (DPO)

After Supervised Fine-Tuning (SFT) on Alpaca, SARAN applies Direct Preference Optimization (DPO) to align the model with human preferences. This is a simpler alternative to RLHF that doesn't require a separate reward model.

The RLHF Pipeline (What We're Simplifying)

Traditional RLHF has three stages:

┌─────────────┐     ┌──────────────┐     ┌─────────────┐
│    SFT      │ ──▶ │ Reward Model │ ──▶ │    RLHF     │
│ (Alpaca)    │     │  Training    │     │   (PPO)     │
└─────────────┘     └──────────────┘     └─────────────┘
     ↓                    ↓                    ↓
  Base Model      Scores Outputs          Final Model

Problems with RLHF:

  • Requires training a separate reward model
  • PPO is unstable and hyperparameter-sensitive
  • Complex multi-stage pipeline

DPO: A Simpler Alternative

DPO directly optimizes preferences without a reward model:

┌─────────────┐     ┌─────────────┐
│    SFT      │ ──▶ │    DPO      │
│ (Alpaca)    │     │ (Preference)│
└─────────────┘     └─────────────┘
     ↓                    ↓
  SFT Model         Final Model

The DPO Loss Function

Given a preference pair (chosen $y_w$, rejected $y_l$) for prompt $x$:

$$\mathcal{L}_{\text{DPO}} = -\log \sigma \left( \beta \cdot \left[ \log \frac{\pi_\theta(y_w|x)}{\pi_{\text{ref}}(y_w|x)} - \log \frac{\pi_\theta(y_l|x)}{\pi_{\text{ref}}(y_l|x)} \right] \right)$$

Where:

  • $\pi_\theta$ = policy model (being trained)
  • $\pi_{\text{ref}}$ = reference model (frozen copy of SFT model)
  • $\beta$ = temperature parameter (0.1 in our config)
  • $\sigma$ = sigmoid function

Intuition: The loss encourages the policy to:

  1. Increase probability of chosen responses (relative to reference)
  2. Decrease probability of rejected responses (relative to reference)
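
A toy computation shows the mechanics; the log-probabilities below are made-up numbers, not real model outputs. When the policy prefers the chosen response more strongly than the reference does, the margin is positive and the loss is small:

import torch
import torch.nn.functional as F

beta = 0.1

# Made-up sequence log-probabilities under the policy and the frozen reference
pi_chosen, pi_rejected = torch.tensor(-12.0), torch.tensor(-20.0)
ref_chosen, ref_rejected = torch.tensor(-14.0), torch.tensor(-18.0)

margin = (pi_chosen - pi_rejected) - (ref_chosen - ref_rejected)  # 8 - 4 = 4
loss = -F.logsigmoid(beta * margin)
print(loss)  # -log(sigmoid(0.4)) ≈ 0.513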

Reference Model

A frozen copy of the SFT model serves as the reference:

# Create frozen reference model
ref_model = copy.deepcopy(model)
ref_model.eval()
for param in ref_model.parameters():
    param.requires_grad = False

Purpose: The reference model prevents the policy from deviating too far from the SFT model, which could lead to reward hacking or degenerate outputs.

Preference Data

DPO requires pairs of (chosen, rejected) responses:

# Example preference pair
{
    "chosen": "User: What is 2+2?\nAssistant: 2+2 equals 4.",
    "rejected": "User: What is 2+2?\nAssistant: 2+2 is probably 5..."
}

Data sources:

  1. Anthropic HH-RLHF (default) — Human preference data for helpfulness/harmlessness
  2. Synthetic (fallback) — Generated from Alpaca by creating lower-quality alternatives

DPO Hyperparameters

| Parameter | Value | Description |
|-----------|-------|-------------|
| beta | 0.1 | Temperature controlling reference deviation |
| learning_rate | 1e-6 | Very small LR for stable preference learning |
| max_iters | 5000 | Training iterations |
| batch_size | 1 | Preference pairs per batch |
| grad_accum | 16 | Effective batch size = 16 |
| patience | 10 | Early stopping patience |

Why lower LR? DPO is a fine-tuning of a fine-tuning — the model should only adjust preferences, not forget its instruction-following abilities.

Training Flow

for it in range(dpo_max_iters):
    # Get a preference pair
    chosen, rejected = get_dpo_batch("train")

    # Sequence log-probabilities under the policy and the frozen reference
    pi_chosen = get_log_probs(model, chosen)
    pi_rejected = get_log_probs(model, rejected)
    ref_chosen = get_log_probs(ref_model, chosen)
    ref_rejected = get_log_probs(ref_model, rejected)

    # DPO loss
    log_ratio = (pi_chosen - pi_rejected) - (ref_chosen - ref_rejected)
    loss = -F.logsigmoid(beta * log_ratio).mean()

    # Optimize
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

Output

After the complete fine-tuning pipeline:

Final model: saran_mlv_dpo_best.pt — Used by saran_mlv_c.py for chat


17. Parameter Count

Let's count all parameters:

| Component | Calculation | Parameters |
|-----------|-------------|------------|
| Token Embedding | $V \times C$ | 50,304 × 1536 = 77,266,944 |
| Position Embedding | $T \times C$ | 512 × 1536 = 786,432 |
| Per Transformer Block: | | |
| → RMSNorm 1 | $C$ | 1,536 |
| → RMSNorm 2 | $C$ | 1,536 |
| → Attention QKV | $C \times 3C$ | 1536 × 4608 = 7,077,888 |
| → Attention Output | $C \times C$ | 1536 × 1536 = 2,359,296 |
| → FFN Layer 1 | $C \times 4C$ | 1536 × 6144 = 9,437,184 |
| → FFN Layer 2 | $4C \times C$ | 6144 × 1536 = 9,437,184 |
| Block Total | | ~28,314,624 |
| All 24 Blocks | $24 \times$ | ~679,550,976 |
| Final RMSNorm | $C$ | 1,536 |
| Output Head | (tied with embedding) | 0 |

Total: ~757.6 Million Parameters
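
The same total can be reproduced arithmetically from the hyperparameters alone:

V, T, C, L = 50304, 512, 1536, 24

embed = V * C                                        # 77,266,944 (tied with the output head)
pos = T * C                                          # 786,432
block = 2 * C + C * 3 * C + C * C + 2 * (C * 4 * C)  # 28,314,624 per block
total = embed + pos + L * block + C                  # + final RMSNorm
print(f"{total:,}")                                  # 757,605,888 ≈ 757.6M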


Complete Forward Pass Example

Let's trace "Hello" through the entire network:

1. Input: "Hello" → token [15496] → tensor shape $(1, 1)$

2. Token Embedding: Look up row 15496 → $(1, 1, 1536)$

3. Position Embedding: Look up position 0 → $(1, 1536)$, broadcast to $(1, 1, 1536)$

4. Sum + Dropout: Token + Position + Dropout → $(1, 1, 1536)$

5. Through 24 Blocks:

  • Each block: RMSNorm → Single-Head Attn → Add → RMSNorm → FFN(4x) → Add
  • Shape stays $(1, 1, 1536)$ throughout

6. Final RMSNorm: $(1, 1, 1536)$

7. Output Head: Linear projection (tied weights) → $(1, 1, 50304)$

8. Softmax + Sample: Probability distribution over 50,257 tokens → sample next token

9. Repeat: Append new token, process again for next prediction


Summary

The SARAN architecture makes strategic simplifications to the GPT design:

| Component | GPT | SARAN | Tradeoff |
|-----------|-----|-------|----------|
| Attention | 12 heads × 64d | 1 head × 1536d | Simpler, full-rank attention |
| FFN | 4x expansion | 4x expansion | Maximum synthesis capacity |
| Normalization | LayerNorm | RMSNorm | Faster, fewer params |
| Activation | GELU | SiLU | Modern, smooth |
| Output | Separate head | Weight tied | Fewer params |
| Biases | Yes | No | Fewer params |
| Dropout | Variable | 0.1 | Regularization |
| LR Schedule | Various | Warmup+Cosine | Stable training |

Key Insights:

  1. Single-head attention can be as effective as multi-head when the model is deep enough
  2. 4x FFN expansion provides maximum capacity for knowledge storage and synthesis
  3. Weight tying enforces semantic consistency and saves parameters
  4. RMSNorm is faster without sacrificing quality
  5. Dropout (0.1) prevents overfitting during pre-training
  6. LR warmup (2000 steps) stabilizes early training for large models

The result is a ~760M parameter model optimized for synthesis quality through architectural choices and scale.