Note
This project is now released on PyPI. You can install it with `pip install wsffn`.
Warning: This is an alpha version. Please report any issues you encounter.
Feedback and contributions welcome!
WsFFN is a drop-in replacement for the standard SwiGLU feed-forward network (FFN) used in modern Transformer language models. It introduces two ideas:
- A parallelized z-head projection that yields Balanced Multi-head structure while keeping the compute profile of a dense FFN.
- An auxiliary latent space with regularization and contrastive structure to encourage world-model-like representations.
This repository provides a compact PyTorch implementation intended to be embedded inside Transformer blocks.
- SwiGLU FFN core: WsFFN keeps the strong empirical performance of SwiGLU.
- Head-wise partition of the FFN hidden dimension $d_{\text{ffn}}$ into $n_{\text{head}}$ groups ("heads"). Each head operates on $z_{\text{dim}} = d_{\text{ffn}} / n_{\text{head}}$ channels.
- Grouped, block-diagonal z-projection that is computed in a single dense matmul for efficiency, yet preserves per-head independence.
- Auxiliary training losses on the z-head outputs that are fully batched and parallel, adding negligible overhead.
The core idea and architectural design of WsFFN originated from my initial insight: forming and regularizing a point-wise latent space within the FFN. Subsequent brainstorming refined this idea into its current form, which aims to achieve Balanced Multi-head functionality using dense computation. The final PyTorch implementation was written with the assistance of Claude Sonnet and Gemini.
Given input $\mathbf{x} \in \mathbb{R}^{d_{\text{model}}}$, the gated hidden state is
$$\mathbf{h} = \mathrm{SiLU}(\mathbf{x}\mathbf{W}_1) \odot (\mathbf{x}\mathbf{W}_3),$$
where $\mathbf{W}_1, \mathbf{W}_3 \in \mathbb{R}^{d_{\text{model}} \times d_{\text{ffn}}}$.
The output is
$$\mathbf{y} = \mathbf{h}\mathbf{W}_2,$$
where $\mathbf{W}_2 \in \mathbb{R}^{d_{\text{ffn}} \times d_{\text{model}}}$.
This recovers the standard SwiGLU FFN.
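For reference, here is a minimal PyTorch sketch of this SwiGLU core (module and weight names such as `SwiGLUCore`, `w1`, `w3`, `w2` are illustrative, not the repository's identifiers):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLUCore(nn.Module):
    """Minimal SwiGLU FFN: y = (SiLU(x W1) * (x W3)) W2."""
    def __init__(self, d_model: int, d_ffn: int):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ffn, bias=False)  # gate branch (W1)
        self.w3 = nn.Linear(d_model, d_ffn, bias=False)  # value branch (W3)
        self.w2 = nn.Linear(d_ffn, d_model, bias=False)  # output projection (W2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = F.silu(self.w1(x)) * self.w3(x)  # gated hidden state in R^{d_ffn}
        return self.w2(h)                    # project back to R^{d_model}
```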
We split the hidden state $\mathbf{h} \in \mathbb{R}^{d_{\text{ffn}}}$ along its channel dimension into $n_{\text{head}}$ groups $\mathbf{h}_1, \dots, \mathbf{h}_{n_{\text{head}}}$,
where each $\mathbf{h}_i \in \mathbb{R}^{z_{\text{dim}}}$ and $z_{\text{dim}} = d_{\text{ffn}} / n_{\text{head}}$.
We apply a grouped linear projection
$$\mathbf{z} = \mathbf{h}\mathbf{W}_z,$$
where $\mathbf{W}_z \in \mathbb{R}^{d_{\text{ffn}} \times d_{\text{ffn}}}$ is block-diagonal across heads.
We then reshape back to per-head tensors to obtain per-head latent vectors $\mathbf{z}_{b,t,h} \in \mathbb{R}^{z_{\text{dim}}}$ for batch element $b$, position $t$, and head $h$.
Let $B$ denote the batch size and $L$ the sequence length.
Implementation detail: We flatten the batch and sequence dimensions so that the grouped z-projection runs as a single dense matmul before reshaping into heads.
We average the per-head latents over the sequence dimension to obtain per-head context vectors $\mathbf{z}_{\text{ctx}} \in \mathbb{R}^{B \times n_{\text{head}} \times z_{\text{dim}}}$.
Flattening over batch and head yields $\mathbf{Z}_{\text{ctx\_flat}} \in \mathbb{R}^{(B \cdot n_{\text{head}}) \times z_{\text{dim}}}$.
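As a minimal sketch of this pipeline (shapes and variable names such as `h`, `w_z`, `z_ctx_flat` are illustrative assumptions, not the repository's exact code):

```python
import torch
import torch.nn as nn

B, L, d_ffn, n_head = 2, 128, 4096, 8
z_dim = d_ffn // n_head                        # channels per head

h = torch.randn(B, L, d_ffn)                   # gated SwiGLU hidden state

# Grouped z-projection: per-head in spirit, one dense matmul in practice.
w_z = nn.Linear(d_ffn, d_ffn, bias=False)      # weight initialized block-diagonally
z = w_z(h).view(B, L, n_head, z_dim)           # per-head latent vectors

# Per-head context vectors: average over the sequence dimension.
z_ctx = z.mean(dim=1)                          # (B, n_head, z_dim)
z_ctx_flat = z_ctx.reshape(B * n_head, z_dim)  # (B * n_head, z_dim)
```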
We compute cosine similarities between all pairs of rows of $\mathbf{Z}_{\text{ctx\_flat}}$ and apply an InfoNCE-style contrastive loss with temperature $\tau$, using each element's self-similarity as the positive.
In practice, we use $\tau = 0.07$ and compute the full similarity matrix in a single batched operation.
These losses are designed to be fully batched and parallelized, adding minimal per-step latency.
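A hedged sketch of how such losses can be computed in a fully batched way (the exact formulation inside `wsFFN` may differ; the function name `aux_losses` and the choice of an L2 penalty for the latent regularizer are assumptions):

```python
import torch
import torch.nn.functional as F

def aux_losses(z: torch.Tensor, lambda_z: float = 1e-5,
               lambda_c: float = 5e-3, tau: float = 0.07) -> torch.Tensor:
    """z: per-head latents of shape (B, L, n_head, z_dim)."""
    B, L, n_head, z_dim = z.shape

    # L_Z: regularization of the latent space (here a simple L2 penalty).
    l_z = z.pow(2).mean()

    # Per-head context vectors, flattened over batch and head.
    z_ctx_flat = z.mean(dim=1).reshape(B * n_head, z_dim)

    # Cosine similarities between all batch x head elements.
    z_norm = F.normalize(z_ctx_flat, dim=-1)
    sim = z_norm @ z_norm.t() / tau            # (B*n_head, B*n_head)

    # InfoNCE with self-similarity as the positive (diagonal targets).
    targets = torch.arange(B * n_head, device=z.device)
    l_c = F.cross_entropy(sim, targets)

    return lambda_z * l_z + lambda_c * l_c

# Example: aux = aux_losses(torch.randn(2, 128, 8, 512))
```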
- Efficiency: A single dense `Linear(d_ffn → d_ffn)` achieves high utilization on modern accelerators.
- Independence: Initializing with block-diagonal weights (sketched below) assigns disjoint channel groups to different heads, preserving head-wise specialization pressure.
- Flexibility: Heads can learn to share information by drifting from the initial block structure if beneficial, but start with clear separation.
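One way to realize this initialization, shown here as a sketch (the repository's actual init routine may differ; `init_block_diagonal` is a hypothetical helper):

```python
import torch
import torch.nn as nn

def init_block_diagonal(linear: nn.Linear, n_head: int) -> None:
    """Zero out cross-head blocks of a square (d_ffn x d_ffn) weight matrix."""
    d_ffn = linear.weight.shape[0]
    z_dim = d_ffn // n_head
    mask = torch.zeros(d_ffn, d_ffn)
    for h in range(n_head):
        s = h * z_dim
        mask[s:s + z_dim, s:s + z_dim] = 1.0   # keep only the per-head diagonal block
    with torch.no_grad():
        linear.weight.mul_(mask)               # training may later drift off-block

w_z = nn.Linear(4096, 4096, bias=False)
init_block_diagonal(w_z, n_head=8)
```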
Class: `wsFFN`
Config: `d_model`, `d_ffn`, `n_head`, `lambda_z`, `lambda_c`, `lambda_logits_z` (reserved); `use_aux_loss` toggles the auxiliary losses at call time.
Important: `wsFFN` returns a tuple `(output, aux_loss)`; during training, accumulate `aux_loss` from every layer into the total loss (see the example below).
```python
import torch
import torch.nn as nn
from typing import Optional, Any, Tuple
from WsFFN import wsFFN, Config # Assuming wsFFN, Config are available
# --- CONCEPTUAL CLASS: Transformer Block (Wraps wsFFN) ---
# A simplified layer that integrates wsFFN and passes through the training status.
class TransformerLayer(nn.Module):
def __init__(self, config: Config):
super().__init__()
# Simplified: just the wsFFN part for demonstration
self.ffn = wsFFN(config)
self.attn = nn.Identity()
# Note: This forward method must match the call in the full model: h, _, aux_loss = layer(h, use_cache=False)
    def forward(self, x: torch.Tensor, use_cache: bool = False) -> Tuple[torch.Tensor, Any, Optional[torch.Tensor]]:
# Assume self.training is correctly set (via model.train() or passed down from full model)
output, aux_loss = self.ffn(x)
# The full layer would also include Attention and Normalization, but we return a dummy cache slot
return output, None, aux_loss
# --- CONCEPTUAL CLASS: Full Model (The one you designed) ---
class Model(nn.Module):
def __init__(self, config: Config, num_layers: int = 4, vocab_size: int = 32000):
super().__init__()
self.config = config
self.token_embeddings = nn.Embedding(vocab_size, config.d_model)
# Create a stack of TransformerLayers, each containing a wsFFN
self.layers = nn.ModuleList([TransformerLayer(config) for _ in range(num_layers)])
self.norm_f = nn.LayerNorm(config.d_model)
self.lm_head = nn.Linear(config.d_model, vocab_size, bias=False)
# Simplified mock method for main loss calculation
def compute_main_loss(self, logits, labels, objective_id):
# In a real model, this would be Cross-Entropy or similar
return torch.randn(1) * 10 # Mock loss
def forward(self, input_ids: torch.Tensor, labels: Optional[torch.Tensor] = None,
objective_id: Optional[torch.Tensor] = None) -> Any:
# Custom training check based on inputs being provided
is_training = labels is not None and objective_id is not None
# 1. Embeddings
h = self.token_embeddings(input_ids)
total_wsffn_aux_loss = torch.tensor(0.0, device=h.device)
# 2. Iterate and Accumulate wsFFN Loss
for layer in self.layers:
# kv_cache is ignored for training loop
# aux_loss contains L_Z + L_C from the wsFFN in that layer
h, _, aux_loss = layer(h, use_cache=False)
if is_training and aux_loss is not None:
total_wsffn_aux_loss += aux_loss
# 3. Final Projection
h = self.norm_f(h)
logits = self.lm_head(h)
if is_training:
# Calculate Main Loss (e.g., cross-entropy)
main_loss = self.compute_main_loss(logits, labels, objective_id)
# Calculate Logits Z-Loss (The third component)
logits_for_z = logits.detach()
logits_z_loss = self.config.lambda_logits_z * torch.logsumexp(logits_for_z, dim=-1).pow(2).mean()
# FINAL TOTAL LOSS = Main Loss + wsFFN Aux Loss (L_Z+L_C) + Logits Z-Loss
total_loss = main_loss + total_wsffn_aux_loss + logits_z_loss
# Prepare loss dictionary for logging/return
loss_dict = {
'total_loss': total_loss.item(),
'main_loss': main_loss.item(),
'wsffn_aux_loss': total_wsffn_aux_loss.item(),
'z_loss': logits_z_loss.item(),
}
# Simplified return, ignoring the NaN/Inf handling for clarity
return total_loss, logits, loss_dict
return logits
# --- USAGE DEMONSTRATION ---
# Configuration
VOCAB_SIZE = 32000
B, L = 2, 128
cfg = Config(d_model=1024, d_ffn=4096, n_head=8,
lambda_z=1e-5, lambda_c=5e-3, lambda_logits_z=1e-4, use_aux_loss=True)
# Instantiate the full model
full_model = Model(cfg)
# 1. TRAINING MODE (All Losses Calculated)
# Note: The model's loss calculation relies on the presence of labels/objective_id
# We must ensure model.train() is called for the internal wsFFN layers
full_model.train()
# Prepare dummy inputs for training
input_ids_train = torch.randint(0, VOCAB_SIZE, (B, L))
labels_train = torch.randint(0, VOCAB_SIZE, (B, L)) # Required to signal training
objective_id_train = torch.ones(B, L) # Required to signal training
print("\n--- Training (Auxiliary Losses Enabled) ---")
total_loss, logits, loss_dict = full_model(
input_ids=input_ids_train,
labels=labels_train,
objective_id=objective_id_train
)
print(f"Total Loss (L_Main + L_wsFFN + L_LogitsZ): {total_loss.item():.4f}")
print(f"Loss Components: {loss_dict}")
# 2. INFERENCE MODE (Only Logits Returned)
# Note: While full_model.eval() is good practice, loss is avoided by omitting inputs
full_model.eval()
# Prepare dummy inputs for inference (omit labels/objective_id)
input_ids_eval = torch.randint(0, VOCAB_SIZE, (B, L))
print("\n--- Inference (Loss Calculation Skipped) ---")
with torch.no_grad():
logits_eval = full_model(input_ids=input_ids_eval)
print(f"Output: Logits tensor of shape {logits_eval.shape}")Replace the standard FFN with wsFFN inside your Transformer layer, keeping the same residual structure. For example:
- Pre-LN: `x = x + Dropout(wsFFN(LayerNorm(x)))`
- Post-LN: `x = LayerNorm(x + Dropout(wsFFN(x)))`
The module is purely feed-forward; there is no change to attention modules.
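For instance, a pre-LN residual wrapper could look like the sketch below (assuming `wsFFN` returns `(output, aux_loss)` as in the example above; the dropout rate is illustrative):

```python
import torch
import torch.nn as nn
from WsFFN import wsFFN, Config  # assuming these imports, as in the example above

class PreLNFFNBlock(nn.Module):
    """Pre-LN residual wrapper: x + Dropout(wsFFN(LayerNorm(x)))."""
    def __init__(self, config: Config, dropout: float = 0.1):
        super().__init__()
        self.norm = nn.LayerNorm(config.d_model)
        self.ffn = wsFFN(config)
        self.drop = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor):
        out, aux_loss = self.ffn(self.norm(x))  # wsFFN returns (output, aux_loss)
        return x + self.drop(out), aux_loss
```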
- Pretraining: Enable auxiliary losses (`model.train()`) and include `aux` in the total loss with weights $\lambda_z$, $\lambda_c$.
- Finetuning: Optionally disable auxiliary losses (`model.eval()`) to reduce overhead.
- Typical hyperparameters: $\lambda_z \approx 10^{-5}$, $\lambda_c \approx 5 \times 10^{-3}$, $\tau = 0.07$.
- $d_{\text{ffn}}$ must be divisible by $n_{\text{head}}$; common setting: $d_{\text{ffn}} = 4 \times d_{\text{model}}$, $n_{\text{head}} \in \{4, 8, 16\}$.
Soft MoE encourages specialization without discrete routing by creating pressure for different subspaces to model different features. WsFFN achieves a similar effect by:
- Structuring the FFN hidden space into heads.
- Applying per-head projections and contrastive pressure across batch×head elements.
- Keeping the compute dense and fully parallel, preserving throughput and hardware efficiency.
- The InfoNCE variant here uses self-similarity as the positive. It is simple and fully parallel, but you may experiment with other positives (e.g., augmentations, multi-view encoders) for stronger semantic structure.
- $\lambda_{\text{logits\_z}}$ is included in the config for future extensions where a $\mathbf{z} \rightarrow \text{logits}$ head is added; it is unused in the current implementation.
- The z-head projection is initialized block-diagonally; training may alter this structure.
Requires PyTorch.
`pip install torch`

Then import the module from this repository.
If you use WsFFN in your research, please cite:
vmintf. (2025). vmintf/WsFFN: World-Structured Feed-Forward Network for Efficient Language Models (Version v0.0.1-alpha.2) [Computer software]. https://doi.org/10.5281/zenodo.17640630
Key references:
- Soft MoE: Puigcerver, J., et al. (2023). "From Sparse to Soft Mixtures of Experts." arXiv:2308.00951.
- SwiGLU: Shazeer, N. (2020). "GLU Variants Improve Transformer." arXiv:2002.05202.
- Contrastive Learning: Oord, A., et al. (2018). "Representation Learning with Contrastive Predictive Coding." arXiv:1807.03748.
MIT