The dominant paradigms in sequence transduction - Recurrent Neural Networks and deep Transformer architectures - rely on complex, multi-layered structures to achieve performance, often at the cost of interpretability and computational transparency. In this work, we introduce the Shallow Auto-Regressive Attention Network (SARAN), a minimalist architecture that reduces the Transformer decoder to its fundamental components. SARAN is defined by a strictly linear, 15-stage computational graph that maps input embeddings directly to output probabilities via a single, isolated block of masked self-attention. We present a "first principles" derivation of the network's training dynamics, explicitly defining the manual backpropagation algorithm through the attention mechanism without reliance on automatic differentiation engines. By stripping away deep layer stacking and feed-forward networks, SARAN demonstrates that a solitary attention block is sufficient to mechanically derive autoregressive properties, providing a transparent and rigorous baseline for understanding the mechanics of attention-based sequence modeling.
SARAN consists of four Python scripts for the complete training pipeline:
Download and tokenize the OpenWebText dataset for pre-training:
# Activate virtual environment
source venv/bin/activate
# Download and tokenize OpenWebText (~30-60 min, streams from HuggingFace)
python get_dataset.py

Output: openwebtext_tokens.jsonl and openwebtext_offsets.npy
Pre-train the model on a large text corpus (OpenWebText) to learn general language understanding:
# Activate virtual environment
source venv/bin/activate
# Run pre-training (50k iterations, ~4-8 hours on GPU/MPS)
python saran_mlv.py

Output: saran_mlv_best.pt (best weights saved during training)
Two-phase fine-tuning for professional chatbot capabilities:
Phase 1: Supervised Fine-Tuning (SFT)
# Run fine-tuning pipeline
python saran_mlv_ft.py

Phase 2: Direct Preference Optimization (DPO)
- Automatically runs after SFT if `dpo.enabled=true` in config
- Trains on preference data (Anthropic HH-RLHF or synthetic)
- Creates reference model (frozen copy) for stable training
Output: saran_mlv_dpo_best.pt
Run the interactive chatbot with agentic web search:
# Start the chat interface
python saran_mlv_c.py

Features:
- Agentic Web Search + LLM Synthesis: Questions automatically search the web via DuckDuckGo, then the LLM synthesizes the results into a coherent response. Falls back to raw web results if synthesis fails.
- Garbage Detection: Low-quality model outputs are caught and replaced with web results or "I don't know"
- Conversation History: Maintains context from the last 10 turns
- Configurable: All parameters loaded from config.json
Commands:
- quit/exit/q - Exit the chat
- clear/reset - Clear conversation history
The web search agent provides real-time information retrieval:
import web
result = web.search("What is the capital of France?")
# Returns: "Paris is the capital and largest city of France..."

Features:
- DuckDuckGo Instant Answer API (primary)
- HTML scrape fallback for reliability
- Automatic text cleaning (HTML entities, non-ASCII)
- Truncation to complete sentences
All hyperparameters are centralized in config.json:
{
"model": {
"block_size": 512,
"n_embd": 1536,
"n_layer": 24,
"vocab_size": 50304,
"dropout": 0.1
},
"training": {
"batch_size": 2,
"grad_accum_steps": 32,
"max_iters": 50000,
"eval_interval": 1000,
"eval_iters": 100,
"learning_rate": 3e-4,
"warmup_iters": 2000,
"grad_clip": 1.0,
"weight_decay": 0.1
},
"finetuning": {
"batch_size": 2,
"grad_accum_steps": 8,
"max_iters": 50000,
"eval_interval": 200,
"learning_rate": 3e-5,
"grad_clip": 1.0,
"weight_decay": 0.01,
"patience": 5
},
"dpo": {
"enabled": true,
"beta": 0.1,
"learning_rate": 1e-6,
"max_iters": 5000,
"eval_interval": 100,
"batch_size": 1,
"grad_accum_steps": 16,
"grad_clip": 1.0,
"weight_decay": 0.01,
"patience": 10,
"dataset": "Anthropic/hh-rlhf"
},
"generation": {
"max_new_tokens": 1024,
"temperature": 0.7,
"top_k": 40,
"repetition_penalty": 1.3,
"debug": false
},
"agents": {
"web": {
"enabled": true,
"module": "web",
"function": "search"
}
},
"search_triggers": ["what", "who", "where", "when", "why", "how", ...]
}

| Section | Description |
|---|---|
| `model` | Architecture hyperparameters (shared by all scripts) |
| `training` | Pre-training hyperparameters (saran_mlv.py) |
| `finetuning` | SFT hyperparameters (saran_mlv_ft.py Phase 1) |
| `dpo` | DPO hyperparameters (saran_mlv_ft.py Phase 2) |
| `generation` | Inference settings for chat (temperature, top_k, etc.) |
| `agents` | Agentic capabilities (web search, future agents) |
| `search_triggers` | Words that trigger automatic web search |
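As a sketch of how these values are consumed by the scripts (assuming only the standard-library json module and the key names listed above):

```python
import json

# Load the centralized configuration (key names as listed in the table above)
with open("config.json") as f:
    cfg = json.load(f)

block_size = cfg["model"]["block_size"]        # 512
n_embd     = cfg["model"]["n_embd"]            # 1536
lr         = cfg["training"]["learning_rate"]  # 3e-4

print(f"block_size={block_size}, n_embd={n_embd}, lr={lr}")
```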
# Complete training pipeline
cd /path/to/a-gpt
source venv/bin/activate
# Step 0: Download and tokenize dataset
python get_dataset.py
# Step 1: Pre-train (creates language model)
python saran_mlv.py
# Step 2: Fine-tune (teaches instruction following)
python saran_mlv_ft.py
# Step 3: Chat with agentic web search!
python saran_mlv_c.py

| File | Purpose |
|---|---|
| `saran_mlv.py` | Pre-training on OpenWebText |
| `saran_mlv_ft.py` | Fine-tuning on Alpaca instructions |
| `saran_mlv_c.py` | Chat interface with web search |
| `web.py` | DuckDuckGo search agent |
| `config.json` | Centralized configuration |
| `get_dataset.py` | Dataset download and tokenization |
- SARAN: Shallow Auto-Regressive Attention Network
- Quick Start
- Table of Contents
- 1. SARAN vs GPT: Key Innovations
- 2. Configuration & Hyperparameters
- 3. Data Pipeline
- 4. Execution Flow Overview
- 5. The SARAN Class: The Heart of the Model
- 6. Token Embeddings
- 7. Positional Embeddings
- 8. The Transformer Block
- 9. RMSNorm (Root Mean Square Normalization)
- 10. Single-Head Attention (SARAN's Simplicity Innovation)
- 11. Feed-Forward Network (4x Expansion)
- 12. Weight Tying
- 13. Output Projection & Loss
- 14. Text Generation
- 15. Training Loop
- 16. Direct Preference Optimization (DPO)
- 17. Parameter Count
- Complete Forward Pass Example
- Summary
SARAN introduces three key architectural simplifications compared to standard GPT:
| Feature | GPT | SARAN | Benefit |
|---|---|---|---|
| Attention Heads | 12 multi-head | 1 single-head | Simpler, more interpretable |
| FFN Expansion | 4× (1536→6144) | 4× (1536→6144) | Same capacity for synthesis |
| Normalization | LayerNorm | RMSNorm | Faster computation |
| Activation | GELU | SiLU (Swish) | Modern, smooth gradients |
| Weight Tying | No | Yes (embed = output) | Fewer parameters |
| Biases | Yes | No (in Linear layers) | Fewer parameters |
| Precision | float32 | bfloat16 (mixed) | ~2x faster, 50% less memory |
| Compilation | No | torch.compile (CUDA) | ~1.5-2x faster on GPU |
| Flash Attention | No | Yes (SDPA) | O(T) memory, ~2-4x faster |
| Parameters | ~125M (GPT-2) | ~760M | Better synthesis quality |
These changes result in a more parameter-efficient model while maintaining competitive performance.
The model is configured with these key hyperparameters:
B, T, C, L = 2, 512, 1536, 24

| Symbol | Name | Value | Description |
|---|---|---|---|
| $B$ | Batch Size | 2 | Sequences per micro-batch |
| $T$ | Context Length | 512 | Maximum sequence length (tokens) |
| $C$ | Embedding Dimension | 1536 | Size of token/positional embeddings |
| $L$ | Number of Layers | 24 | Transformer blocks stacked |
| $V$ | Vocabulary Size | 50,304 | Padded for GPU efficiency |
Note: SARAN has no
Model Size: ~760M parameters (with 4x FFN expansion for better synthesis quality)
Additional training hyperparameters:
| Parameter | Value | Description |
|---|---|---|
| `grad_accum_steps` | 32 | Gradient accumulation steps |
| `lr` | 3e-4 | Learning rate (with 2000-step warmup) |
| `warmup_iters` | 2000 | Linear warmup steps before cosine decay |
| `grad_clip` | 1.0 | Gradient clipping threshold |
| `dropout` | 0.1 | Dropout for regularization |
| `weight_decay` | 0.1 | AdamW weight decay |
Effective batch size: 2 × 32 = 64 sequences per optimizer step (64 × 512 = 32,768 tokens).
Text is converted to integer tokens using the GPT-2 BPE (Byte Pair Encoding) tokenizer:
enc = tiktoken.get_encoding("gpt2")
# Round vocab to nearest 64 for GPU efficiency (Karpathy's nanoGPT optimization)
# GPT-2 has 50257 tokens, but 50304 = 64 * 786 aligns with tensor core block sizes
vocab_size = 50304
encode = lambda s: enc.encode(s)
decode = lambda l: enc.decode(list(l))

Why 50,304 instead of 50,257?
- GPT-2 tokenizer has exactly 50,257 tokens
- 50,304 = 64 × 786 — aligns with GPU tensor core block sizes
- Provides ~5-15% speedup on embedding/output layers
- Only 47 extra "padding" tokens (never predicted)
- Credit: Andrej Karpathy's nanoGPT optimization
Example:
"Hello world" → [15496, 995]
SARAN uses a memory-mapped approach to load pre-tokenized OpenWebText data:
offsets = np.load("openwebtext_offsets.npy")
tokens_f = open("openwebtext_tokens.jsonl", "rb")
def get_batch(split):
split_idx = int(0.9 * num_examples)
start_i, end_i = (0, split_idx) if split == "train" else (split_idx, num_examples)
x_list, y_list = [], []
while len(x_list) < B:
ex_id = start_i + np.random.randint(0, end_i - start_i)
tokens_f.seek(int(offsets[ex_id]))
toks = json.loads(tokens_f.readline().decode("utf-8"))
if len(toks) <= T:
continue
start = np.random.randint(0, len(toks) - T)
x_list.append(toks[start : start + T])
y_list.append(toks[start + 1 : start + T + 1])
return torch.tensor(x_list, device=device), torch.tensor(y_list, device=device)

This approach:
- Uses file offsets to randomly access documents
- Samples a random starting position within each document
- Extracts input/target pairs for next-token prediction
Shape Example:
- Input shape: $(B, T) = (2, 512)$
- Target shape: $(B, T) = (2, 512)$
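To make the input/target offset concrete, here is a tiny standalone sketch (toy token list, T = 4, and a hand-picked start index) showing that each target is simply the input shifted by one position:

```python
# Toy example: targets are the inputs shifted left by one token
toks = [10, 11, 12, 13, 14, 15, 16]
T = 4
start = 1  # pretend np.random.randint picked this

x = toks[start : start + T]          # [11, 12, 13, 14]
y = toks[start + 1 : start + T + 1]  # [12, 13, 14, 15]

# At position i the model sees x[: i + 1] and must predict y[i], the token after x[i]
print(x, y)
```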
When training, execution follows this path:
┌─────────────────────────────────────────────────────┐
│ SARAN Class │
├─────────────────────────────────────────────────────┤
Input Tokens ──► Token Embed ──► + ──► Block ──► Block ──► ... ──► Block ──► RMSNorm ──► Linear ──► Logits
(B, T) (B,T,C) │ ×1 ×2 ×24 (B,T,C) (B,T,V)
│ ↑
Pos Embed ──┘ │
(T, C) (weight tying) ────┘
Each Block contains:
Input ──► RMSNorm ──► Single-Head Attention ──► + ──► RMSNorm ──► FFN (4x) ──► + ──► Output
│ │ │
└────────────── (residual) ───────────────────┘────────── (residual) ────────┘
Key differences from GPT:
- RMSNorm instead of LayerNorm
- Single-head attention instead of 12-head MHA
- SiLU activation instead of GELU
class SARAN(nn.Module):
def __init__(self):
super().__init__()
self.tok = nn.Embedding(vocab_size, C) # Token embeddings
self.pos = nn.Embedding(T, C) # Positional embeddings
self.blocks = nn.Sequential(*[Block() for _ in range(L)]) # 24 transformer blocks
self.ln = RMSNorm(C) # Final RMSNorm
self.head = nn.Linear(C, vocab_size, bias=False) # Output projection
self.tok.weight = self.head.weight # Weight tying!
self.apply(self._init_weights)

The forward pass:
def forward(self, idx, tgt=None):
x = self.tok(idx) + self.pos(torch.arange(idx.shape[1], device=device))
logits = self.head(self.ln(self.blocks(x)))
return logits, F.cross_entropy(...) if tgt is not None else None

Let's trace a concrete example through the entire network.
The token embedding layer maps each token ID to a dense vector:
For input tokens $\mathrm{idx} \in \{0, \dots, V-1\}^{B \times T}$, the embedding layer looks up rows of $\mathbf{E}_{tok} \in \mathbb{R}^{V \times C}$, giving $\mathbf{X}_{tok} = \mathbf{E}_{tok}[\mathrm{idx}] \in \mathbb{R}^{B \times T \times C}$.
Concrete Example:
Suppose our input is the tokens [15496, 995, 0] representing "Hello world" plus a padding token.
For token ID 15496:
- Look up row 15496 in $\mathbf{E}_{tok}$
- Retrieve a 1536-dimensional vector, e.g.: $[0.02, -0.15, 0.08, ..., 0.11]$
Each of the 50,257 possible tokens has its own learned 1536-dimensional representation.
Memory: 50,304 × 1536 = 77,266,944 parameters (but shared with the output head via weight tying!)
Transformers have no inherent notion of sequence order. Positional embeddings inject position information:
For each position $t \in \{0, \dots, T-1\}$ there is a learned vector $\mathbf{E}_{pos}[t] \in \mathbb{R}^{C}$.

The combined embedding is: $\mathbf{x}_t = \mathbf{E}_{tok}[\mathrm{idx}_t] + \mathbf{E}_{pos}[t]$
Concrete Example:

Position 0 adds $\mathbf{E}_{pos}[0]$, position 1 adds $\mathbf{E}_{pos}[1]$, and so on. If token "Hello" at position 0 has embedding $\mathbf{e}_{15496}$, the vector entering the first block is $\mathbf{e}_{15496} + \mathbf{E}_{pos}[0]$.
Memory: 512 × 1536 = 786,432 parameters
Each of the 24 blocks applies the same structure with different learned weights:
class Block(nn.Module):
def __init__(self):
super().__init__()
self.ln1, self.ln2, self.attn, self.ffn = RMSNorm(C), RMSNorm(C), Attn(), FFN()
def forward(self, x):
x = x + self.attn(self.ln1(x)) # Attention with residual
return x + self.ffn(self.ln2(x)) # FFN with residual

Mathematically, for block $\ell$:

$$\mathbf{x}' = \mathbf{x} + \mathrm{Attn}(\mathrm{RMSNorm}_1(\mathbf{x})), \qquad \mathbf{x}'' = \mathbf{x}' + \mathrm{FFN}(\mathrm{RMSNorm}_2(\mathbf{x}'))$$
This is the Pre-Norm formulation, where normalization is applied before each sublayer.
The + operations are residual (skip) connections. They:
- Allow gradients to flow directly backward through the network
- Enable the model to learn identity mappings easily
- Stabilize training of deep networks
Without residuals, training a 24-layer network would be extremely difficult due to vanishing gradients.
SARAN uses RMSNorm instead of LayerNorm. RMSNorm is simpler and faster:
class RMSNorm(nn.Module):
def __init__(self, dim, eps=1e-6):
super().__init__()
self.eps, self.weight = eps, nn.Parameter(torch.ones(dim))
def forward(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight

Mathematically:

$$\mathrm{RMSNorm}(\mathbf{x}) = \frac{\mathbf{x}}{\mathrm{RMS}(\mathbf{x})} \odot \mathbf{g}$$

Where:

$$\mathrm{RMS}(\mathbf{x}) = \sqrt{\frac{1}{C} \sum_{i=1}^{C} x_i^2 + \epsilon}$$

And $\mathbf{g} \in \mathbb{R}^{C}$ is the learned scale vector (initialized to ones).
Concrete Example (illustrative, 4-dimensional for brevity):

For a single position's embedding $\mathbf{x} = [1, 2, 3, 4]$:

$$\mathrm{RMS}(\mathbf{x}) = \sqrt{\tfrac{1 + 4 + 9 + 16}{4}} = \sqrt{7.5} \approx 2.74$$

Normalized (assuming $\mathbf{g} = \mathbf{1}$): $\mathbf{x} / 2.74 \approx [0.37, 0.73, 1.10, 1.46]$
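A minimal sketch to check these numbers with PyTorch, using the same computation as RMSNorm.forward above (toy 4-dimensional input; eps was omitted from the hand calculation for clarity):

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0, 4.0])
weight = torch.ones(4)  # the learned scale g, initialized to ones
eps = 1e-6

# Same expression as RMSNorm.forward
out = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight
print(out)  # tensor([0.3651, 0.7303, 1.0954, 1.4606])
```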
| Property | LayerNorm | RMSNorm |
|---|---|---|
| Centering | Yes (subtracts mean) | No |
| Scaling | By std deviation | By RMS |
| Parameters | $\gamma, \beta$ ($2C$ values) | $\mathbf{g}$ ($C$ values) |
| Speed | Slower | ~15% faster |
RMSNorm removes the mean-centering step, which empirically doesn't hurt performance but speeds up computation.
Parameters per RMSNorm: $C = 1{,}536$ (the scale vector $\mathbf{g}$)
Unlike GPT's multi-head attention, SARAN uses single-head attention operating on the full embedding dimension:
class Attn(nn.Module):
def __init__(self):
super().__init__()
self.qkv = nn.Linear(C, 3 * C, bias=False) # Fused Q, K, V projection
self.proj = nn.Linear(C, C, bias=False) # Output projection
self.register_buffer("mask", torch.triu(torch.ones(T, T), diagonal=1).bool())
def forward(self, x):
_, t, _ = x.shape
q, k, v = self.qkv(x).split(C, dim=-1)
w = (q @ k.transpose(-2, -1) * C**-0.5).masked_fill(self.mask[:t, :t], float("-inf"))
return self.proj(F.softmax(w, dim=-1) @ v)

Key difference from GPT:
- GPT: 12 heads, each with $d_k = 64$ dimensions
- SARAN: 1 head with $d_k = 1536$ dimensions (full embedding)
The input $\mathbf{X} \in \mathbb{R}^{B \times T \times C}$ is projected in one fused operation:

$$[\mathbf{Q}, \mathbf{K}, \mathbf{V}] = \mathbf{X} \mathbf{W}_{qkv}$$

Where $\mathbf{W}_{qkv} \in \mathbb{R}^{C \times 3C}$ (no bias). The result is then split into three tensors, each of shape $(B, T, C)$.
Concrete Example (simplified):

Let's trace a tiny example with $T = 3$ positions, using illustrative attention scores (the exact $\mathbf{Q}$, $\mathbf{K}$, $\mathbf{V}$ values don't matter here; the masking and softmax steps do).

Note: SARAN scales by $\frac{1}{\sqrt{C}} = \frac{1}{\sqrt{1536}}$, whereas GPT scales each head by $\frac{1}{\sqrt{d_k}} = \frac{1}{\sqrt{64}}$.

Scaling by the square root of the key dimension keeps the dot products in a range where the softmax does not saturate, preserving useful gradients.
The causal mask prevents attending to future positions. SARAN uses torch.triu with diagonal=1:

$$\mathbf{M} = \begin{bmatrix} 0 & 1 & 1 \\ 0 & 0 & 1 \\ 0 & 0 & 0 \end{bmatrix}$$

Where 1 means "mask out" (set the score to $-\infty$).
Apply softmax row-wise. Since masked entries are $-\infty$, they receive exactly zero weight. With illustrative scaled scores of $[1.0]$ for row 0, $[0.5, 0.5]$ for row 1, and $[1.0, 0.5, 0.0]$ for row 2:

Row 0: $[1.0, \; 0, \; 0]$

Row 1: $[0.5, \; 0.5, \; 0]$

Row 2:
- $e^{1.0} = 2.72$, $e^{0.5} = 1.65$, $e^{0.0} = 1.0$
- Sum $= 5.37$
- Weights $= [0.51, 0.31, 0.19]$
Interpretation:
- Position 0 only sees itself (attention weight 1.0 on position 0)
- Position 1 sees positions 0 and 1 equally (0.5 each)
- Position 2 sees all previous positions with decaying attention
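A small self-contained sketch reproducing this 3-position example (illustrative pre-computed scores; same masking logic as the Attn class above):

```python
import torch
import torch.nn.functional as F

T = 3
# Illustrative already-scaled attention scores for 3 positions
w = torch.tensor([[1.0, 0.9, 0.4],
                  [0.5, 0.5, 0.2],
                  [1.0, 0.5, 0.0]])

# Causal mask: True above the diagonal means "may not attend"
mask = torch.triu(torch.ones(T, T), diagonal=1).bool()
w = w.masked_fill(mask, float("-inf"))

attn = F.softmax(w, dim=-1)
print(attn)
# row 0 -> [1.00, 0.00, 0.00]
# row 1 -> [0.50, 0.50, 0.00]
# row 2 -> [0.51, 0.31, 0.19]
```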
Unlike GPT, where the concatenated head outputs are projected, SARAN directly projects the single-head output:

$$\mathbf{Y} = \left(\mathrm{softmax}\!\left(\frac{\mathbf{Q}\mathbf{K}^\top}{\sqrt{C}} + \mathbf{M}'\right)\mathbf{V}\right)\mathbf{W}_{proj}$$

Where $\mathbf{M}'$ sets masked positions to $-\infty$, and $\mathbf{W}_{proj} \in \mathbb{R}^{C \times C}$ (no bias).
Multi-head attention was designed to let the model attend to different aspects in parallel. However:
- With sufficient depth (24 layers), single-head attention can learn diverse patterns across layers
- Full $C$-dimensional attention captures richer relationships per layer
- Simpler architecture means easier optimization and interpretation
- Fewer parameters without significant quality loss
SARAN uses standard GPT-style 4x expansion for maximum synthesis capacity:
class FFN(nn.Module):
def __init__(self, dim, dropout=0.0):
super().__init__()
hidden = dim * 4 # 4x expansion
self.w1 = nn.Linear(dim, hidden, bias=False)
self.w2 = nn.Linear(hidden, dim, bias=False)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
return self.dropout(self.w2(F.silu(self.w1(x))))

Mathematically (followed by dropout during training):

$$\mathrm{FFN}(\mathbf{x}) = \mathbf{W}_2 \, \mathrm{SiLU}(\mathbf{W}_1 \mathbf{x})$$

Where:
- $\mathbf{W}_1 \in \mathbb{R}^{6144 \times 1536}$ (expansion to 4×)
- $\mathbf{W}_2 \in \mathbb{R}^{1536 \times 6144}$ (projection back)
- No biases!
SARAN uses SiLU (Sigmoid Linear Unit), also known as Swish:

$$\mathrm{SiLU}(x) = x \cdot \sigma(x)$$

Example values:

| $x$ | $\sigma(x)$ | SiLU($x$) |
|---|---|---|
| -2.0 | 0.12 | -0.24 |
| -1.0 | 0.27 | -0.27 |
| 0.0 | 0.50 | 0.0 |
| 1.0 | 0.73 | 0.73 |
| 2.0 | 0.88 | 1.76 |
SiLU is smooth, non-monotonic (has a small negative region), and has been shown to work well in modern architectures like LLaMA and PaLM.
SiLU vs GELU comparison:
| Property | GELU | SiLU |
|---|---|---|
| Formula | $x \cdot \Phi(x)$ | $x \cdot \sigma(x)$ |
| Min value | ~-0.17 at x≈-0.75 | ~-0.28 at x≈-1.28 |
| Computation | Slower (erf) | Faster (sigmoid) |
| Usage | GPT, BERT | LLaMA, SARAN |
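A quick check of the example values using PyTorch's built-in `torch.sigmoid` and `F.silu` (outputs shown as comments):

```python
import torch
import torch.nn.functional as F

x = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0])
print(torch.sigmoid(x))  # tensor([0.1192, 0.2689, 0.5000, 0.7311, 0.8808])
print(F.silu(x))         # tensor([-0.2384, -0.2689, 0.0000, 0.7311, 1.7616])
```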
SARAN applies dropout (default 0.1) at multiple points to prevent overfitting:
- After embedding summation — regularizes input representations
- In attention (SDPA) — `dropout_p` parameter during training
- After FFN output — before residual addition
# Attention with dropout
out = F.scaled_dot_product_attention(
q, k, v, is_causal=True, dropout_p=self.dropout if self.training else 0.0
)
return self.resid_dropout(self.out_proj(out))
# FFN with dropout
return self.dropout(self.w2(F.silu(self.w1(x))))

Dropout is disabled during inference (model.eval()) for deterministic outputs.
SARAN ties the token embedding matrix to the output projection matrix:
self.tok = nn.Embedding(vocab_size, C)
self.head = nn.Linear(C, vocab_size, bias=False)
self.tok.weight = self.head.weight # Weight tying!

This means:
The same matrix is used for:
- Encoding: token ID → embedding vector
- Decoding: hidden state → vocabulary logits
Benefits:
- Fewer parameters: Saves $V \times C \approx 77.3M$ parameters
- Semantic consistency: Similar tokens have similar embeddings AND similar output distributions
- Regularization effect: Constrains the model's representation space
Memory savings: 50,304 × 1536 = 77,266,944 parameters (~77.3M, roughly 10% of the total).
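A minimal standalone sketch (toy sizes, not SARAN's) showing that tying makes the two layers share one parameter tensor, so there is only one copy to store and update:

```python
import torch
import torch.nn as nn

vocab, dim = 10, 4  # toy sizes for illustration
tok = nn.Embedding(vocab, dim)
head = nn.Linear(dim, vocab, bias=False)
tok.weight = head.weight  # weight tying: both modules now reference the same tensor

print(tok.weight.data_ptr() == head.weight.data_ptr())  # True (same storage)

params = {id(p): p for p in list(tok.parameters()) + list(head.parameters())}
print(sum(p.numel() for p in params.values()))  # 40 unique parameters, not 80
```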
After all transformer blocks, we project to vocabulary logits:
logits = self.head(self.ln(self.blocks(x)))

Where the head's weight matrix is the token embedding matrix $\mathbf{E}_{tok} \in \mathbb{R}^{V \times C}$ (weight tying), so $\mathrm{logits} = \mathbf{x} \, \mathbf{E}_{tok}^{\top}$.
Output shape: $(B, T, V) = (2, 512, 50304)$
Each position produces a 50,304-dimensional vector of logits (unnormalized log-probabilities); the 47 padding slots are never used as targets.
For training, we compute cross-entropy loss between predictions and targets:
$$\mathcal{L} = -\frac{1}{B \cdot T} \sum_{b=1}^{B} \sum_{t=1}^{T} \log p_\theta\!\left(y_{b,t} \mid x_{b,\le t}\right)$$

Where $y_{b,t}$ is the target (next) token at position $t$ and $p_\theta$ is the softmax over the model's logits.
Concrete Example:
For a single position predicting token 42:
- Logits: $[1.2, 0.5, ..., 3.8_{(42)}, ..., 0.1]$ (50,257 values)
- Softmax: $[0.001, 0.0005, ..., 0.15_{(42)}, ..., 0.0003]$
- Loss: $-\log(0.15) = 1.90$
A loss of ~2.0 means the model assigns roughly $e^{-2} \approx 13.5\%$ probability to the correct token on average.
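A tiny sketch checking the per-token arithmetic with F.cross_entropy (toy 5-token vocabulary instead of 50,257; illustrative logits):

```python
import torch
import torch.nn.functional as F

# Toy logits over a 5-token vocabulary; the correct class is index 2
logits = torch.tensor([[1.0, 0.5, 2.5, 0.2, -0.3]])
target = torch.tensor([2])

loss = F.cross_entropy(logits, target)
prob = F.softmax(logits, dim=-1)[0, 2]
print(prob, loss, -torch.log(prob))  # loss equals -log(softmax prob of the target)
```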
Generation uses autoregressive sampling:
def generate(self, idx, n, temp=0.8, top_k=40):
for _ in range(n):
logits = self(idx[:, -T:])[0][:, -1, :] / temp
if top_k:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < v[:, [-1]]] = float("-inf")
idx = torch.cat([idx, torch.multinomial(F.softmax(logits, -1), 1)], 1)
return idx

Temperature $\tau$ divides the logits before the softmax: $p_i \propto \exp(z_i / \tau)$.
| Temperature | Effect |
|---|---|
| $\tau < 1$ | Sharper distribution, more deterministic |
| $\tau = 1$ | Original distribution |
| $\tau > 1$ | Flatter distribution, more random |
Example with logits approximately $[2.0, 1.0, 0.0]$ (probabilities rounded):

| Temperature | Probabilities |
|---|---|
| 0.5 | [0.88, 0.10, 0.02] |
| 1.0 | [0.67, 0.24, 0.09] |
| 2.0 | [0.51, 0.31, 0.18] |
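A sketch of temperature scaling on illustrative logits $[2.0, 1.0, 0.0]$ (assumed values; this reproduces the $\tau = 1$ and $\tau = 2$ rows above up to rounding):

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([2.0, 1.0, 0.0])  # illustrative values
for temp in (0.5, 1.0, 2.0):
    probs = F.softmax(logits / temp, dim=-1)
    print(temp, probs)
# 0.5 -> ~[0.87, 0.12, 0.02]  (sharper)
# 1.0 -> ~[0.67, 0.24, 0.09]
# 2.0 -> ~[0.51, 0.31, 0.19]  (flatter)
```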
SARAN uses top-k=40 by default (GPT uses 50):
- Find the 40th largest logit value
- Set all logits below this threshold to $-\infty$
- Renormalize with softmax
- Sample from this truncated distribution
Example: If top-k=3 and the logits are $[5, 3, 2, x_4, x_5]$ with $x_4, x_5 < 2$:
- Top 3 values: $[5, 3, 2]$
- Threshold: 2
- After masking: $[5, 3, 2, -\infty, -\infty]$
- After softmax: $[0.84, 0.11, 0.04, 0, 0]$
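A sketch of the same top-k filtering trick used in generate(), on illustrative logits $[5, 3, 2, 1, 0]$ (the two smallest values are assumptions; anything below the threshold gives the same result):

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[5.0, 3.0, 2.0, 1.0, 0.0]])  # illustrative
top_k = 3

# Same masking trick as generate(): keep the k largest logits, drop the rest
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < v[:, [-1]]] = float("-inf")

print(F.softmax(logits, dim=-1))  # ~[0.84, 0.11, 0.04, 0.00, 0.00]
```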
SARAN's training loop includes gradient accumulation, cosine annealing, and gradient clipping:
for i in range(max_iters):
opt.zero_grad(set_to_none=True)
for _ in range(grad_accum_steps):
loss = model(*get_batch("train"))[1]
(loss / grad_accum_steps).backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
opt.step()
sched.step()

With batch size 2 and 32 gradient-accumulation steps, the effective batch is $2 \times 32 = 64$ sequences ($64 \times 512 = 32{,}768$ tokens) per optimizer step.
This allows training with a large effective batch size on limited GPU memory.
SARAN uses AdamW with specific hyperparameters:
opt = torch.optim.AdamW(model.parameters(), lr=lr, betas=(0.9, 0.95), weight_decay=0.1)

Where:
- $\alpha = 3 \times 10^{-4}$ (peak learning rate, with warmup)
- $\lambda = 0.1$ (weight decay)
- $\beta_1 = 0.9$, $\beta_2 = 0.95$ (momentum terms — note $\beta_2$ is lower than the typical 0.999)
SARAN uses linear warmup followed by cosine decay:
def get_lr(it):
# Linear warmup for warmup_iters steps
if it < warmup_iters:
return learning_rate * (it + 1) / warmup_iters
# Cosine decay after warmup
decay_ratio = (it - warmup_iters) / (max_iters - warmup_iters)
coeff = 0.5 * (1.0 + np.cos(np.pi * decay_ratio))
return learning_rate / 10 + coeff * (learning_rate - learning_rate / 10)

The learning rate schedule:
- Warmup (steps 0-2000): Linear increase from 0 to $\eta_{max}$
- Cosine decay (steps 2000-50000): Smooth decay to $\eta_{min}$
- Warmup: 2,000 steps
- Peak: $\eta_{max} = 3 \times 10^{-4}$
- Minimum: $\eta_{min} = 3 \times 10^{-5}$
- Total: 50,000 iterations
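Sample values from get_lr (reproduced here as a self-contained sketch, assuming learning_rate = 3e-4, warmup_iters = 2000, max_iters = 50000), confirming the warmup peak and the cosine minimum:

```python
import numpy as np

learning_rate, warmup_iters, max_iters = 3e-4, 2000, 50000

def get_lr(it):
    if it < warmup_iters:
        return learning_rate * (it + 1) / warmup_iters
    decay_ratio = (it - warmup_iters) / (max_iters - warmup_iters)
    coeff = 0.5 * (1.0 + np.cos(np.pi * decay_ratio))
    return learning_rate / 10 + coeff * (learning_rate - learning_rate / 10)

for it in (0, 1000, 2000, 26000, 49999):
    print(it, f"{get_lr(it):.2e}")
# 0     -> 1.50e-07  (start of warmup)
# 1000  -> 1.50e-04  (halfway through warmup)
# 2000  -> 3.00e-04  (peak)
# 26000 -> 1.65e-04  (middle of cosine decay)
# 49999 -> ~3.00e-05 (minimum)
```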
torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)

If the total gradient norm exceeds grad_clip (1.0), all gradients are rescaled so the norm equals 1.0.
This prevents exploding gradients and stabilizes training.
SARAN uses automatic mixed precision (AMP) for faster training on GPU/MPS:
from torch.amp import autocast
# Device-aware dtype selection
use_amp = device in ("mps", "cuda")
amp_dtype = torch.bfloat16 if use_amp else torch.float32
# Training loop with autocast
with autocast(device_type=device, dtype=amp_dtype, enabled=use_amp):
_, loss = model(xb, yb)
(loss / grad_accum_steps).backward()

Precision Comparison:
| Precision | Bits | Memory | Speed | Stability |
|---|---|---|---|---|
| float32 | 32 | 100% | 1x | Best |
| float16 | 16 | 50% | ~2x | Needs loss scaling |
| bfloat16 | 16 | 50% | ~1.5-2x | Very stable |
Why bfloat16?
- Same exponent range as float32 (8 bits) — no overflow issues
- Reduced mantissa (7 bits vs 23) — slightly less precision
- No loss scaling required (unlike float16)
- Native support on Apple Silicon (MPS) and modern NVIDIA GPUs
- Typical loss difference: < 0.01-0.05 (negligible)
Memory savings: activations and gradients take roughly half the memory of float32 training.
PyTorch 2.0+ offers torch.compile for Just-In-Time (JIT) compilation, fusing operations for faster execution:
# Compile model for faster execution (CUDA only)
if hasattr(torch, "compile") and device == "cuda":
print("Compiling model with torch.compile...")
model = torch.compile(model, mode="reduce-overhead")

Compilation Modes:
| Mode | Speedup | Compile Time | Best For |
|---|---|---|---|
| default | ~1.5x | Moderate | General use |
| reduce-overhead | ~1.5-2x | Longer | Small batches, LLMs |
| max-autotune | ~2x+ | Very long | Max throughput needed |
How it works:
- Graph capture: Traces model execution into a graph
- Fusion: Combines multiple operations (e.g., matmul + add + activation)
- Codegen: Generates optimized CUDA kernels via Triton
Platform support:
- CUDA: Full support, ~1.5-2x speedup typical
- MPS: Limited/experimental, ~1.1-1.3x speedup
- CPU: Works but minimal gains
Why CUDA-only in SARAN:
- MPS backend still experimental with torch.compile
- CPU gains are negligible and add startup overhead
- CUDA's Triton backend provides the most reliable speedups
SARAN uses PyTorch 2.0's F.scaled_dot_product_attention for efficient attention computation:
# Flash Attention - replaces manual attention computation
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

What it replaces:
# Old manual implementation (O(T²) memory)
scores = (q @ k.transpose(-2, -1)) * scale
scores = scores.masked_fill(causal_mask, float("-inf"))
attn = F.softmax(scores, dim=-1)
out = attn @ v

Backend selection (automatic):
| Backend | Device | Memory | Speed | Requirements |
|---|---|---|---|---|
| Flash Attention v2 | CUDA | O(T) | Fastest | Ampere+ GPU |
| Memory-Efficient (xFormers) | CUDA | O(T) | Fast | Any CUDA GPU |
| Math Fallback | All | O(T²) | Baseline | Always available |
Benefits:
- O(T) memory vs O(T²) — enables much longer sequences
- ~2-4x faster on CUDA for T > 256
- Fused kernel — no intermediate tensors for scores/softmax
- is_causal=True — handles causal masking internally (no manual mask needed)
- Single-head compatible — works perfectly with SARAN's 1-head design
Memory comparison at T=512: the explicit score matrix has $T^2 = 262{,}144$ entries per sequence (about 0.5 MB in bfloat16), which the fused kernel never materializes.
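A small sketch checking that SDPA matches the manual masked-softmax attention for the single-head case (random toy tensors; SDPA's default scale is $1/\sqrt{C}$, matching the manual code above):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, T, C = 1, 5, 8  # toy sizes
q, k, v = torch.randn(3, B, T, C).unbind(0)

# Fused kernel with built-in causal masking
out_sdpa = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Manual O(T^2) reference: explicit scores, mask, softmax
scores = (q @ k.transpose(-2, -1)) * C**-0.5
mask = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)
scores = scores.masked_fill(mask, float("-inf"))
out_manual = F.softmax(scores, dim=-1) @ v

print(torch.allclose(out_sdpa, out_manual, atol=1e-5))  # True
```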
After Supervised Fine-Tuning (SFT) on Alpaca, SARAN applies Direct Preference Optimization (DPO) to align the model with human preferences. This is a simpler alternative to RLHF that doesn't require a separate reward model.
Traditional RLHF has three stages:
┌─────────────┐ ┌──────────────┐ ┌─────────────┐
│ SFT │ ──▶ │ Reward Model │ ──▶ │ RLHF │
│ (Alpaca) │ │ Training │ │ (PPO) │
└─────────────┘ └──────────────┘ └─────────────┘
↓ ↓ ↓
Base Model Scores Outputs Final Model
Problems with RLHF:
- Requires training a separate reward model
- PPO is unstable and hyperparameter-sensitive
- Complex multi-stage pipeline
DPO directly optimizes preferences without a reward model:
┌─────────────┐ ┌─────────────┐
│ SFT │ ──▶ │ DPO │
│ (Alpaca) │ │ (Preference)│
└─────────────┘ └─────────────┘
↓ ↓
SFT Model Final Model
Given a preference pair (chosen $y_w$, rejected $y_l$) for a prompt $x$, the DPO loss is:

$$\mathcal{L}_{\text{DPO}} = -\log \sigma\!\left(\beta \left[\log \frac{\pi_\theta(y_w \mid x)}{\pi_{\text{ref}}(y_w \mid x)} - \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\text{ref}}(y_l \mid x)}\right]\right)$$

Where:
- $\pi_\theta$ = policy model (being trained)
- $\pi_{\text{ref}}$ = reference model (frozen copy of SFT model)
- $\beta$ = temperature parameter (0.1 in our config)
- $\sigma$ = sigmoid function
Intuition: The loss encourages the policy to:
- Increase probability of chosen responses (relative to reference)
- Decrease probability of rejected responses (relative to reference)
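A tiny numeric sketch of the loss (illustrative sequence log-probabilities; same arithmetic as the formula and the training-loop code further below):

```python
import torch
import torch.nn.functional as F

beta = 0.1
# Illustrative sequence log-probs under the policy and the frozen reference
pi_chosen, pi_rejected = torch.tensor(-12.0), torch.tensor(-15.0)
ref_chosen, ref_rejected = torch.tensor(-13.0), torch.tensor(-14.0)

log_ratio = (pi_chosen - pi_rejected) - (ref_chosen - ref_rejected)  # 3 - 1 = 2
loss = -F.logsigmoid(beta * log_ratio)
print(loss)  # -log(sigmoid(0.2)) ≈ 0.598
```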
A frozen copy of the SFT model serves as the reference:
# Create frozen reference model
ref_model = copy.deepcopy(model)
ref_model.eval()
for param in ref_model.parameters():
param.requires_grad = False

Purpose: The reference model prevents the policy from deviating too far from the SFT model, which could lead to reward hacking or degenerate outputs.
DPO requires pairs of (chosen, rejected) responses:
# Example preference pair
{
"chosen": "User: What is 2+2?\nAssistant: 2+2 equals 4.",
"rejected": "User: What is 2+2?\nAssistant: 2+2 is probably 5..."
}

Data sources:
- Anthropic HH-RLHF (default) — Human preference data for helpfulness/harmlessness
- Synthetic (fallback) — Generated from Alpaca by creating lower-quality alternatives
| Parameter | Value | Description |
|---|---|---|
| `beta` | 0.1 | Temperature controlling reference deviation |
| `learning_rate` | 1e-6 | Very small LR for stable preference learning |
| `max_iters` | 5000 | Training iterations |
| `batch_size` | 1 | Preference pairs per batch |
| `grad_accum` | 16 | Effective batch size = 16 |
| `patience` | 10 | Early stopping patience |
Why lower LR? DPO is a fine-tuning of a fine-tuning — the model should only adjust preferences, not forget its instruction-following abilities.
for it in range(dpo_max_iters):
# Get preference pair
chosen, rejected = get_dpo_batch("train")
# Compute DPO loss
pi_chosen = get_log_probs(model, chosen)
pi_rejected = get_log_probs(model, rejected)
ref_chosen = get_log_probs(ref_model, chosen)
ref_rejected = get_log_probs(ref_model, rejected)
log_ratio = (pi_chosen - pi_rejected) - (ref_chosen - ref_rejected)
loss = -F.logsigmoid(beta * log_ratio).mean()
# Optimize
optimizer.zero_grad(set_to_none=True)
loss.backward()
optimizer.step()

After the complete fine-tuning pipeline:
Final model: saran_mlv_dpo_best.pt — Used by saran_mlv_c.py for chat
Let's count all parameters:
| Component | Calculation | Parameters |
|---|---|---|
| Token Embedding | 50,304 × 1536 | 77,266,944 |
| Position Embedding | 512 × 1536 | 786,432 |
| Per Transformer Block: | | |
| → RMSNorm 1 | | 1,536 |
| → RMSNorm 2 | | 1,536 |
| → Attention QKV | 1536 × 4608 | 7,077,888 |
| → Attention Output | 1536 × 1536 | 2,359,296 |
| → FFN Layer 1 | 1536 × 6144 | 9,437,184 |
| → FFN Layer 2 | 6144 × 1536 | 9,437,184 |
| Block Total | | ~28,314,624 |
| All 24 Blocks | 24 × 28,314,624 | ~679,550,976 |
| Final RMSNorm | | 1,536 |
| Output Head | (tied with embedding) | 0 |
Total: ~757.6 Million Parameters
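A quick arithmetic check of the total, in plain Python using the values from the table:

```python
V, T, C, L = 50304, 512, 1536, 24

embed = V * C            # token embedding (tied with the output head)
pos   = T * C            # positional embedding
block = 2 * C + (C * 3 * C) + (C * C) + 2 * (C * 4 * C)  # 2 RMSNorms + QKV + proj + FFN
total = embed + pos + L * block + C  # + final RMSNorm

print(f"{block:,}")  # 28,314,624 per block
print(f"{total:,}")  # 757,605,888 (~757.6M)
```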
Let's trace "Hello" through the entire network:
1. Input: "Hello" → token [15496] → tensor shape $(1, 1)$
2. Token Embedding: Look up row 15496 → $(1, 1, 1536)$
3. Position Embedding: Look up position 0 → $(1, 1, 1536)$
4. Sum + Dropout: Token + Position + Dropout → $(1, 1, 1536)$
5. Through 24 Blocks:
- Each block: RMSNorm → Single-Head Attn → Add → RMSNorm → FFN(4x) → Add
- Shape stays $(1, 1, 1536)$ throughout
6. Final RMSNorm: normalize the hidden state → $(1, 1, 1536)$
7. Output Head: Linear projection (tied weights) → $(1, 1, 50304)$ logits
8. Softmax + Sample: Probability distribution over 50,257 tokens → sample next token
9. Repeat: Append new token, process again for next prediction
The SARAN architecture makes strategic simplifications to the GPT design:
| Component | GPT | SARAN | Tradeoff |
|---|---|---|---|
| Attention | 12 heads × 64d | 1 head × 1536d | Simpler, full-rank attention |
| FFN | 4x expansion | 4x expansion | Maximum synthesis capacity |
| Normalization | LayerNorm | RMSNorm | Faster, fewer params |
| Activation | GELU | SiLU | Modern, smooth |
| Output | Separate head | Weight tied | Fewer params |
| Biases | Yes | No | Fewer params |
| Dropout | Variable | 0.1 | Regularization |
| LR Schedule | Various | Warmup+Cosine | Stable training |
Key Insights:
- Single-head attention can be as effective as multi-head when the model is deep enough
- 4x FFN expansion provides maximum capacity for knowledge storage and synthesis
- Weight tying enforces semantic consistency and saves parameters
- RMSNorm is faster without sacrificing quality
- Dropout (0.1) prevents overfitting during pre-training
- LR warmup (2000 steps) stabilizes early training for large models
The result is a ~760M parameter model optimized for synthesis quality through architectural choices and scale.