7 changes: 7 additions & 0 deletions CHANGELOG.md
@@ -1,6 +1,13 @@
# CHANGELOG


## v1.1.4 (2025-07-06)

### Features

- Add new initialization strategy for the transcoder.


## v1.1.3 (2025-04-17)

### Bug Fixes
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -21,7 +21,7 @@ dependencies = [
"torch",
"transformers",
]
version = "1.1.3"
version = "1.1.4"
[project.optional-dependencies]
dev = [
"pre-commit",
2 changes: 1 addition & 1 deletion sparsify/__init__.py
@@ -1,4 +1,4 @@
__version__ = "1.1.3"
__version__ = "1.1.4"

from .config import SaeConfig, SparseCoderConfig, TrainConfig, TranscoderConfig
from .sparse_coder import Sae, SparseCoder
3 changes: 3 additions & 0 deletions sparsify/config.py
@@ -35,6 +35,9 @@ class SparseCoderConfig(Serializable):
transcode: bool = False
"""Whether we want to predict the output of a module given its input."""

init_method: Literal["linear", "mlp"] = "linear"
"""Initialization method to use for the transcoder."""


# Support different naming conventions for the same configuration
SaeConfig = SparseCoderConfig
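A minimal sketch of how the new flag is meant to be selected. Only `transcode` and `init_method` come from this diff; the remaining `SparseCoderConfig` fields are assumed to keep their existing defaults, and the constructor call mirrors the one in trainer.py below.

```python
import torch
from sparsify import SparseCoder, SparseCoderConfig

# Request the new MLP-based initialization for a transcoder.
cfg = SparseCoderConfig(transcode=True, init_method="mlp")

# d_in and device/dtype are illustrative; the (d_in, cfg, device, dtype) call
# shape follows the SparseCoder construction in trainer.py below.
coder = SparseCoder(256, cfg, "cpu", dtype=torch.float32)
```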
11 changes: 6 additions & 5 deletions sparsify/fused_encoder.py
@@ -26,7 +26,8 @@ def forward(
bias: (M,)
k: int (number of top elements to select along dim=1)
"""
pre_acts = F.linear(input, weight)
if b_enc is not None:
    pre_acts = pre_acts + b_enc
pre_acts = F.relu(pre_acts)

# Get top-k values and indices for each row
if activation == "topk":
@@ -88,11 +89,11 @@ def backward(ctx, grad_values, grad_indices, grad_preacts):


def fused_encoder(
    x: Tensor,
    W_enc: Tensor,
    b_enc: Tensor | None,
    k: int,
    activation: Literal["groupmax", "topk"] = "topk",
) -> EncoderOutput:
"""
Convenience wrapper that performs an nn.Linear followed by `activation` with
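A short sketch of calling the reworked wrapper now that the bias is optional. The tensor shapes and the `EncoderOutput` field names are assumptions for illustration, not taken from this diff.

```python
import torch
from sparsify.fused_encoder import fused_encoder

x = torch.randn(8, 512)           # (batch, d_in), made-up shapes
W_enc = torch.randn(2048, 512)    # (num_latents, d_in)

# A bias-free encoder (e.g. the MLP-initialized transcoder) can now pass None
# instead of a zero bias tensor.
out = fused_encoder(x, W_enc, None, k=32, activation="topk")
top_acts, top_indices = out.top_acts, out.top_indices  # assumed field names
```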
90 changes: 77 additions & 13 deletions sparsify/sparse_coder.py
@@ -49,28 +49,53 @@ def __init__(
self.d_in = d_in
self.num_latents = cfg.num_latents or d_in * cfg.expansion_factor

# W_dec may be created by the "mlp" initialization below; start from None so
# the decoder branch can tell whether it still has to allocate it.
self.W_dec = None

if cfg.init_method == "linear":
    self.encoder = nn.Linear(
        d_in, self.num_latents, device=device, dtype=dtype
    )
    self.encoder.bias.data.zero_()
elif cfg.init_method == "mlp":
    # Initialize the encoder with the weights of the first layer of the MLP
    # and the decoder with the weights of the second layer.
    self.encoder = nn.Linear(
        d_in, self.num_latents, bias=False, device=device, dtype=dtype
    )
    self.W_dec = nn.Parameter(
        torch.zeros(self.num_latents, d_in, device=device, dtype=dtype)
    )
else:
    raise ValueError(f"Unknown init_method: {cfg.init_method}")

if decoder and self.W_dec is None:
    # Transcoder initialization: use zeros
    if cfg.transcode:
        self.W_dec = nn.Parameter(
            torch.zeros_like(self.encoder.weight.data)
        )

    # Sparse autoencoder initialization: use the transpose of encoder weights
    else:
        self.W_dec = nn.Parameter(self.encoder.weight.data.clone())
        if self.cfg.normalize_decoder:
            self.set_decoder_norm_to_unit_norm()

# The "mlp" initialization uses no decoder bias and no skip connection
if cfg.init_method == "linear":
    self.b_dec = nn.Parameter(torch.zeros(d_in, dtype=dtype, device=device))
else:
    self.b_dec = None

if cfg.skip_connection and cfg.init_method == "linear":
    self.W_skip = nn.Parameter(
        torch.zeros(d_in, d_in, device=device, dtype=dtype)
    )
else:
    self.W_skip = None

@staticmethod
def load_many(
@@ -168,6 +193,41 @@ def save_to_disk(self, path: Path | str):
f,
)

@torch.no_grad()
def init_from_mlp(
self, model: nn.Module, layer_idx: int, hookpoint: str
):
"""Initialize the transcoder from an MLP layer."""
assert self.cfg.init_method == "mlp"

# Get the MLP layer
mlp = model.get_submodule(hookpoint)

# The encoder of the sparse coder plays the role of the MLP's input projection;
# for a LLaMA-style MLP this is the gate projection, whose output feeds the
# nonlinearity
up_proj_weight = mlp.gate_proj.weight.data

# The decoder of the sparse coder is analogous to the down-projection in the MLP
down_proj_weight = mlp.down_proj.weight.data

# Check if we need to stack the weights
mlp_hidden_dim = up_proj_weight.shape[0]
if self.num_latents % mlp_hidden_dim != 0:
raise ValueError(
f"Number of latents ({self.num_latents}) must be a multiple of the "
f"MLP hidden dimension ({mlp_hidden_dim}) for MLP initialization."
)

num_stacks = self.num_latents // mlp_hidden_dim

# Initialize the encoder with stacked copies of the MLP's up-projection
stacked_up_proj = up_proj_weight.repeat(num_stacks, 1)
self.encoder.weight.data.copy_(stacked_up_proj)

# The decoder weights are the transpose of the down-projection, stacked along
# the latent dimension to match the encoder stacking above
stacked_down_proj_t = down_proj_weight.t().repeat(num_stacks, 1)
self.W_dec.data.copy_(stacked_down_proj_t)

@property
def device(self):
return self.encoder.weight.device
@@ -178,18 +238,22 @@ def dtype(self):

def encode(self, x: Tensor) -> EncoderOutput:
"""Encode the input and select the top-k latents."""
if not self.cfg.transcode and self.b_dec is not None:
    x = x - self.b_dec

# fused_encoder now accepts a missing bias, so the bias-free encoder created
# by the "mlp" initialization can be passed through directly.
return fused_encoder(
    x, self.encoder.weight, self.encoder.bias, self.cfg.k, self.cfg.activation
)

def decode(self, top_acts: Tensor, top_indices: Tensor) -> Tensor:
assert self.W_dec is not None, "Decoder weight was not initialized."

y = decoder_impl(top_indices, top_acts.to(self.dtype), self.W_dec.mT)
if self.b_dec is not None:
    y = y + self.b_dec
return y

# Wrapping the forward in bf16 autocast improves performance by almost 2x
@torch.autocast(
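To make the weight stacking in `init_from_mlp` above easier to follow, here is a self-contained sketch of the intended shapes: the encoder receives `num_stacks` copies of the MLP input projection along the latent dimension, and `W_dec` receives matching copies of the transposed down-projection. The dimensions are made up for illustration.

```python
import torch

# Toy shapes, chosen only for illustration.
d_in, d_mlp, num_stacks = 16, 64, 4     # num_latents = num_stacks * d_mlp = 256

up_proj = torch.randn(d_mlp, d_in)      # nn.Linear weight layout: (out, in)
down_proj = torch.randn(d_in, d_mlp)

# Encoder: stack copies of the input projection along the latent dimension.
encoder_weight = up_proj.repeat(num_stacks, 1)   # (num_stacks * d_mlp, d_in)

# Decoder: each latent decodes with the matching column of down_proj, so the
# transpose is stacked the same way.
W_dec = down_proj.t().repeat(num_stacks, 1)      # (num_stacks * d_mlp, d_in)

assert encoder_weight.shape == W_dec.shape == (num_stacks * d_mlp, d_in)
```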
9 changes: 6 additions & 3 deletions sparsify/trainer.py
@@ -85,6 +85,8 @@ def __init__(
self.saes[name] = SparseCoder(
input_widths[hook], cfg.sae, device, dtype=torch.float32
)
if cfg.sae.init_method == "mlp":
self.saes[name].init_from_mlp(model.base_model, 0, hook)

assert isinstance(dataset, Sized)
num_batches = len(dataset) // cfg.batch_size
@@ -377,13 +379,14 @@ def hook(module: nn.Module, inputs, outputs):
# Ensure the preactivations are centered at initialization
# This is mathematically equivalent to Anthropic's proposal of
# subtracting the decoder bias
if self.cfg.sae.transcode and self.cfg.sae.init_method == "linear":
mean = self.maybe_all_reduce(inputs.mean(0)).to(raw.dtype)
mean_image = -mean @ raw.encoder.weight.data.T
raw.encoder.bias.data = mean_image

if raw.b_dec is not None:
    mean = self.maybe_all_reduce(outputs.mean(0))
    raw.b_dec.data = mean.to(raw.dtype)

# Make sure the W_dec is still unit-norm if we're autoencoding
if raw.cfg.normalize_decoder and not self.cfg.sae.transcode:
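The hook above re-centers the encoder bias only for the linear initialization, since the MLP-initialized encoder has no bias to adjust. The identity it relies on is that setting the bias to `-mean @ W_enc.T` reproduces encoding a mean-centered input; a small numerical check with made-up shapes:

```python
import torch

# Check the centering identity used in the hook above:
# x @ W_enc.T + (-mean @ W_enc.T) == (x - mean) @ W_enc.T
d_in, num_latents = 32, 128
W_enc = torch.randn(num_latents, d_in)
x = torch.randn(8, d_in)
mean = x.mean(0)

bias = -mean @ W_enc.T
torch.testing.assert_close(x @ W_enc.T + bias, (x - mean) @ W_enc.T)
```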