README.md: 3 additions & 3 deletions
@@ -113,16 +113,16 @@ We currently don't support fine-grained manual control over the learning rate, n
## Distributed training
- We support distributed training via PyTorch's `torchrun` command. By default we use the Distributed Data Parallel method, which means that the weights of each SAE are replicated on every GPU.
+ We support distributed training out of the box using Torch Distributed Elastic. By default we use Distributed Data Parallel with all visible GPUs, which means that the weights of each SAE are replicated on every GPU.
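As a rough sketch of a plain data-parallel launch on a single node (the `-m sae` entry point, model name, and `--batch_size` flag are illustrative assumptions, not taken from this diff):

```bash
# Sketch only: the `sae` module name and flag names are assumptions.
# torchrun launches one process per GPU; with plain Distributed Data
# Parallel, every GPU holds a full copy of each SAE's weights.
torchrun --nproc_per_node 8 -m sae meta-llama/Meta-Llama-3-8B --batch_size 1
```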
This is simple, but very memory inefficient. If you want to train SAEs for many layers of a model, we recommend using the `--distribute_modules` flag, which allocates the SAEs for different layers to different GPUs. Currently, we require that the number of GPUs evenly divides the number of layers you're training SAEs for.
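A hedged sketch of the kind of launch the next paragraph describes, training SAEs on even-numbered layers only, with gradient accumulation, microbatching, and 8-bit model loading; apart from `--distribute_modules`, every flag name here is an assumption based on that description:

```bash
# Sketch only: flag names other than --distribute_modules are illustrative
# assumptions.
# --distribute_modules : put the SAEs for different layers on different GPUs
# --layers             : train on even-numbered layers only (assumed syntax)
# --grad_acc_steps     : accumulate gradients over 8 minibatches
# --micro_acc_steps    : split each minibatch into 2 microbatches
# --load_in_8bit       : load the base model in 8-bit via bitsandbytes
torchrun --nproc_per_node 8 -m sae meta-llama/Meta-Llama-3-8B \
  --distribute_modules \
  --layers $(seq 0 2 30) \
  --grad_acc_steps 8 \
  --micro_acc_steps 2 \
  --load_in_8bit
```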
The above command trains an SAE for every _even_ layer of Llama 3 8B, using all available GPUs. It accumulates gradients over 8 minibatches, and splits each minibatch into 2 microbatches before feeding them into the SAE encoder, thus saving a lot of memory. It also loads the model in 8-bit precision using `bitsandbytes`. This command requires no more than 48GB of memory per GPU on an 8 GPU node.