diff --git a/CHANGELOG.md b/CHANGELOG.md
index bfa830ce..01834c41 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,6 +1,13 @@
 # CHANGELOG
 
+## v1.1.4 (2025-07-06)
+
+### Features
+
+- Add new initialization strategy for the transcoder.
+
+
 ## v1.1.3 (2025-04-17)
 
 ### Bug Fixes
diff --git a/pyproject.toml b/pyproject.toml
index aa31af56..f68e56c3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -21,7 +21,7 @@ dependencies = [
     "torch",
     "transformers",
 ]
-version = "1.1.3"
+version = "1.1.4"
 [project.optional-dependencies]
 dev = [
     "pre-commit",
diff --git a/sparsify/__init__.py b/sparsify/__init__.py
index acf1d5e3..865b6d6f 100644
--- a/sparsify/__init__.py
+++ b/sparsify/__init__.py
@@ -1,4 +1,4 @@
-__version__ = "1.1.3"
+__version__ = "1.1.4"
 
 from .config import SaeConfig, SparseCoderConfig, TrainConfig, TranscoderConfig
 from .sparse_coder import Sae, SparseCoder
diff --git a/sparsify/config.py b/sparsify/config.py
index 1e348ef5..425ea9f0 100644
--- a/sparsify/config.py
+++ b/sparsify/config.py
@@ -35,6 +35,9 @@ class SparseCoderConfig(Serializable):
     transcode: bool = False
     """Whether we want to predict the output of a module given its input."""
 
+    init_method: Literal["linear", "mlp"] = "linear"
+    """Initialization method to use for the transcoder."""
+
 
 # Support different naming conventions for the same configuration
 SaeConfig = SparseCoderConfig
diff --git a/sparsify/fused_encoder.py b/sparsify/fused_encoder.py
index 272d47db..a94068d1 100644
--- a/sparsify/fused_encoder.py
+++ b/sparsify/fused_encoder.py
@@ -26,7 +26,11 @@ def forward(
         bias: (M,)
         k: int (number of top elements to select along dim=1)
        """
-        preacts = F.relu(F.linear(input, weight, bias))
+        # The bias can be None (e.g. for encoders created with bias=False)
+        preacts = F.linear(input, weight)
+        if bias is not None:
+            preacts = preacts + bias
+        preacts = F.relu(preacts)
 
         # Get top-k values and indices for each row
         if activation == "topk":
@@ -88,11 +92,11 @@ def backward(ctx, grad_values, grad_indices, grad_preacts):
 
 
 def fused_encoder(
-    input,
-    weight,
-    bias,
+    input: Tensor,
+    weight: Tensor,
+    bias: Tensor | None,
     k: int,
-    activation: Literal["groupmax", "topk"],
+    activation: Literal["groupmax", "topk"] = "topk",
 ) -> EncoderOutput:
     """
     Convenience wrapper that performs an nn.Linear followed by `activation` with
diff --git a/sparsify/sparse_coder.py b/sparsify/sparse_coder.py
index 2a8a8cd7..fb9a2ddb 100644
--- a/sparsify/sparse_coder.py
+++ b/sparsify/sparse_coder.py
@@ -49,28 +49,52 @@ def __init__(
         self.d_in = d_in
         self.num_latents = cfg.num_latents or d_in * cfg.expansion_factor
 
-        self.encoder = nn.Linear(d_in, self.num_latents, device=device, dtype=dtype)
-        self.encoder.bias.data.zero_()
+        if cfg.init_method == "linear":
+            self.encoder = nn.Linear(
+                d_in, self.num_latents, device=device, dtype=dtype
+            )
+            self.encoder.bias.data.zero_()
+            self.W_dec = None
+        elif cfg.init_method == "mlp":
+            # The encoder and decoder weights are copied from the MLP's input and
+            # output projections later, via `init_from_mlp`.
+            self.encoder = nn.Linear(
+                d_in, self.num_latents, bias=False, device=device, dtype=dtype
+            )
+            if decoder:
+                self.W_dec = nn.Parameter(
+                    torch.zeros(self.num_latents, d_in, device=device, dtype=dtype)
+                )
+            else:
+                self.W_dec = None
+        else:
+            raise ValueError(f"Unknown init_method: {cfg.init_method}")
 
-        if decoder:
+        if decoder and self.W_dec is None:
             # Transcoder initialization: use zeros
             if cfg.transcode:
-                self.W_dec = nn.Parameter(torch.zeros_like(self.encoder.weight.data))
+                if cfg.init_method == "linear":
+                    self.W_dec = nn.Parameter(
+                        torch.zeros_like(self.encoder.weight.data)
+                    )
             # Sparse autoencoder initialization: use the transpose of encoder weights
             else:
                 self.W_dec = nn.Parameter(self.encoder.weight.data.clone())
 
             if self.cfg.normalize_decoder:
                 self.set_decoder_norm_to_unit_norm()
+
+        if self.cfg.init_method == "linear":
+            self.b_dec = nn.Parameter(torch.zeros(d_in, dtype=dtype, device=device))
         else:
-            self.W_dec = None
+            self.b_dec = None
 
-        self.b_dec = nn.Parameter(torch.zeros(d_in, dtype=dtype, device=device))
-        self.W_skip = (
-            nn.Parameter(torch.zeros(d_in, d_in, device=device, dtype=dtype))
-            if cfg.skip_connection
-            else None
-        )
+        if cfg.skip_connection and cfg.init_method == "linear":
+            self.W_skip = nn.Parameter(
+                torch.zeros(d_in, d_in, device=device, dtype=dtype)
+            )
+        else:
+            self.W_skip = None
 
     @staticmethod
     def load_many(
@@ -168,6 +192,41 @@ def save_to_disk(self, path: Path | str):
                 f,
             )
 
+    @torch.no_grad()
+    def init_from_mlp(
+        self, model: nn.Module, layer_idx: int, hookpoint: str
+    ):
+        """Initialize the transcoder weights from a (Llama-style) MLP layer."""
+        assert self.cfg.init_method == "mlp"
+
+        # Get the MLP layer
+        mlp = model.get_submodule(hookpoint)
+
+        # The encoder is analogous to the MLP's input (gate) projection
+        in_proj_weight = mlp.gate_proj.weight.data
+
+        # The decoder is analogous to the MLP's down-projection
+        down_proj_weight = mlp.down_proj.weight.data
+
+        # Check how many copies of the MLP weights we need to stack
+        mlp_hidden_dim = in_proj_weight.shape[0]
+        if self.num_latents % mlp_hidden_dim != 0:
+            raise ValueError(
+                f"Number of latents ({self.num_latents}) must be a multiple of the "
+                f"MLP hidden dimension ({mlp_hidden_dim}) for MLP initialization."
+            )
+
+        num_stacks = self.num_latents // mlp_hidden_dim
+
+        # Initialize the encoder with stacked copies of the MLP's input projection
+        stacked_in_proj = in_proj_weight.repeat(num_stacks, 1)
+        self.encoder.weight.data.copy_(stacked_in_proj)
+
+        # Each decoder row is the output direction of the corresponding MLP hidden
+        # unit, i.e. a column of the down-projection, stacked num_stacks times
+        stacked_down_proj_t = down_proj_weight.t().repeat(num_stacks, 1)
+        self.W_dec.data.copy_(stacked_down_proj_t)
+
     @property
     def device(self):
         return self.encoder.weight.device
@@ -178,18 +237,21 @@ def dtype(self):
 
     def encode(self, x: Tensor) -> EncoderOutput:
         """Encode the input and select the top-k latents."""
-        if not self.cfg.transcode:
+        if not self.cfg.transcode and self.b_dec is not None:
             x = x - self.b_dec
 
+        # The encoder bias may be None when init_method == "mlp"
         return fused_encoder(
             x, self.encoder.weight, self.encoder.bias, self.cfg.k, self.cfg.activation
         )
 
     def decode(self, top_acts: Tensor, top_indices: Tensor) -> Tensor:
         assert self.W_dec is not None, "Decoder weight was not initialized."
 
         y = decoder_impl(top_indices, top_acts.to(self.dtype), self.W_dec.mT)
-        return y + self.b_dec
+        if self.b_dec is not None:
+            y = y + self.b_dec
+        return y
 
     # Wrapping the forward in bf16 autocast improves performance by almost 2x
     @torch.autocast(
diff --git a/sparsify/trainer.py b/sparsify/trainer.py
index 8093cdab..f270055e 100644
--- a/sparsify/trainer.py
+++ b/sparsify/trainer.py
@@ -85,6 +85,8 @@ def __init__(
             self.saes[name] = SparseCoder(
                 input_widths[hook], cfg.sae, device, dtype=torch.float32
             )
+            if cfg.sae.init_method == "mlp":
+                self.saes[name].init_from_mlp(model.base_model, 0, hook)
 
         assert isinstance(dataset, Sized)
         num_batches = len(dataset) // cfg.batch_size
@@ -377,13 +379,14 @@ def hook(module: nn.Module, inputs, outputs):
             # Ensure the preactivations are centered at initialization
             # This is mathematically equivalent to Anthropic's proposal of
             # subtracting the decoder bias
-            if self.cfg.sae.transcode:
+            if self.cfg.sae.transcode and self.cfg.sae.init_method == "linear":
                 mean = self.maybe_all_reduce(inputs.mean(0)).to(raw.dtype)
                 mean_image = -mean @ raw.encoder.weight.data.T
                 raw.encoder.bias.data = mean_image
 
-            mean = self.maybe_all_reduce(outputs.mean(0))
-            raw.b_dec.data = mean.to(raw.dtype)
+            if raw.b_dec is not None:
+                mean = self.maybe_all_reduce(outputs.mean(0))
+                raw.b_dec.data = mean.to(raw.dtype)
 
             # Make sure the W_dec is still unit-norm if we're autoencoding
             if raw.cfg.normalize_decoder and not self.cfg.sae.transcode:
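For reference, here is a minimal sketch of how the new `init_method="mlp"` option can be exercised directly, outside the trainer. It assumes a Llama-style checkpoint whose MLP blocks expose `gate_proj`/`down_proj` (the same assumption `init_from_mlp` makes); the model name, hookpoint, and device below are illustrative, not prescribed by this change.

```python
import torch
from transformers import AutoModelForCausalLM

from sparsify import SparseCoder, SparseCoderConfig

# Illustrative checkpoint; any Llama-style model with gate_proj/down_proj MLPs works
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B", torch_dtype=torch.bfloat16
)

cfg = SparseCoderConfig(
    transcode=True,      # predict the MLP's output from its input
    init_method="mlp",   # copy the MLP's own weights as the starting point
    # num_latents must be a multiple of the MLP hidden width for MLP initialization
    num_latents=model.config.intermediate_size,
)

coder = SparseCoder(
    model.config.hidden_size, cfg, device="cpu", dtype=torch.float32
)
# Same call pattern the trainer uses: base model, (unused) layer index, hookpoint
coder.init_from_mlp(model.base_model, 0, "layers.0.mlp")
```

When training through the usual `TrainConfig` path, setting `init_method="mlp"` on the sparse coder config triggers the same call automatically when the `Trainer` is constructed, as shown in the `sparsify/trainer.py` hunk above.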