Commit 7403cff

Switch from torchrun to elastic library

1 parent b1ca472 commit 7403cff

3 files changed: +125 -34

README.md

Lines changed: 3 additions & 3 deletions
````diff
@@ -113,16 +113,16 @@ We currently don't support fine-grained manual control over the learning rate, n
 
 ## Distributed training
 
-We support distributed training via PyTorch's `torchrun` command. By default we use the Distributed Data Parallel method, which means that the weights of each SAE are replicated on every GPU.
+We support distributed training out of the box using Torch Distributed Elastic. By default we use Distributed Data Parallel with all visible GPUs, which means that the weights of each SAE are replicated on every GPU.
 
 ```bash
-torchrun --nproc_per_node gpu -m sparsify meta-llama/Meta-Llama-3-8B --batch_size 1 --layers 16 24 --k 192 --grad_acc_steps 8 --ctx_len 2048
+python -m sparsify meta-llama/Meta-Llama-3-8B --batch_size 1 --layers 16 24 --k 192 --grad_acc_steps 8 --ctx_len 2048
 ```
 
 This is simple, but very memory inefficient. If you want to train SAEs for many layers of a model, we recommend using the `--distribute_modules` flag, which allocates the SAEs for different layers to different GPUs. Currently, we require that the number of GPUs evenly divides the number of layers you're training SAEs for.
 
 ```bash
-torchrun --nproc_per_node gpu -m sparsify meta-llama/Meta-Llama-3-8B --distribute_modules --batch_size 1 --layer_stride 2 --grad_acc_steps 8 --ctx_len 2048 --k 192 --load_in_8bit --micro_acc_steps 2
+python -m sparsify meta-llama/Meta-Llama-3-8B --distribute_modules --batch_size 1 --layer_stride 2 --grad_acc_steps 8 --ctx_len 2048 --k 192 --load_in_8bit --micro_acc_steps 2
 ```
 
 The above command trains an SAE for every _even_ layer of Llama 3 8B, using all available GPUs. It accumulates gradients over 8 minibatches, and splits each minibatch into 2 microbatches before feeding them into the SAE encoder, thus saving a lot of memory. It also loads the model in 8-bit precision using `bitsandbytes`. This command requires no more than 48GB of memory per GPU on an 8 GPU node.
````
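With the elastic launcher, `python -m sparsify` itself spawns one process per visible CUDA device, so there is no `--nproc_per_node` flag to pass. A minimal illustrative sketch (not part of the commit), assuming only standard CUDA and PyTorch behaviour: the world size comes from `torch.cuda.device_count()`, which honours `CUDA_VISIBLE_DEVICES`, so restricting visibility restricts the run.

```python
# Illustrative only (not part of the commit): the launcher in
# sparsify/distributed.py sizes the job from torch.cuda.device_count(),
# which respects CUDA_VISIBLE_DEVICES. Launching with, e.g.,
#   CUDA_VISIBLE_DEVICES=0,1 python -m sparsify ...
# therefore trains on just those two GPUs.
import torch

if __name__ == "__main__":
    print("world size the launcher would use:", torch.cuda.device_count())
```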

sparsify/__main__.py

Lines changed: 44 additions & 31 deletions
```diff
@@ -18,6 +18,7 @@
 )
 
 from .data import MemmapDataset, chunk_and_tokenize
+from .distributed import handle_distribute
 from .trainer import TrainConfig, Trainer
 from .utils import simple_parse_args_string
 
@@ -77,31 +78,7 @@ class RunConfig(TrainConfig):
     format 'arg1=val1,arg2=val2'."""
 
 
-def load_artifacts(
-    args: RunConfig, rank: int
-) -> tuple[PreTrainedModel, Dataset | MemmapDataset]:
-    if args.load_in_8bit:
-        dtype = torch.float16
-    elif torch.cuda.is_bf16_supported():
-        dtype = torch.bfloat16
-    else:
-        dtype = "auto"
-
-    # End-to-end training requires a model with a causal LM head
-    model_cls = AutoModel if args.loss_fn == "fvu" else AutoModelForCausalLM
-    model = model_cls.from_pretrained(
-        args.model,
-        device_map={"": f"cuda:{rank}"},
-        quantization_config=(
-            BitsAndBytesConfig(load_in_8bit=args.load_in_8bit)
-            if args.load_in_8bit
-            else None
-        ),
-        revision=args.revision,
-        torch_dtype=dtype,
-        token=args.hf_token,
-    )
-
+def load_data(args: RunConfig):
     # For memmap-style datasets
     if args.dataset.endswith(".bin"):
         dataset = MemmapDataset(args.dataset, args.ctx_len, args.max_examples)
@@ -137,10 +114,36 @@ def load_artifacts(
     if limit := args.max_examples:
         dataset = dataset.select(range(limit))
 
-    return model, dataset
+    return dataset
 
 
-def run():
+def load_model_artifact(args: RunConfig, rank: int) -> PreTrainedModel:
+    if args.load_in_8bit:
+        dtype = torch.float16
+    elif torch.cuda.is_bf16_supported():
+        dtype = torch.bfloat16
+    else:
+        dtype = "auto"
+
+    # End-to-end training requires a model with a causal LM head
+    model_cls = AutoModel if args.loss_fn == "fvu" else AutoModelForCausalLM
+    model = model_cls.from_pretrained(
+        args.model,
+        device_map={"": f"cuda:{rank}"},
+        quantization_config=(
+            BitsAndBytesConfig(load_in_8bit=args.load_in_8bit)
+            if args.load_in_8bit
+            else None
+        ),
+        revision=args.revision,
+        torch_dtype=dtype,
+        token=args.hf_token,
+    )
+
+    return model
+
+
+def worker(args: RunConfig, dataset: Dataset | MemmapDataset):
     local_rank = os.environ.get("LOCAL_RANK")
     ddp = local_rank is not None
     rank = int(local_rank) if ddp else 0
@@ -157,17 +160,15 @@ def run():
         if rank == 0:
             print(f"Using DDP across {dist.get_world_size()} GPUs.")
 
-    args = parse(RunConfig)
-
     # Prevent ranks other than 0 from printing
     with nullcontext() if rank == 0 else redirect_stdout(None):
         # Awkward hack to prevent other ranks from duplicating data preprocessing
         if not ddp or rank == 0:
-            model, dataset = load_artifacts(args, rank)
+            model = load_model_artifact(args, rank)
         if ddp:
            dist.barrier()
            if rank != 0:
-                model, dataset = load_artifacts(args, rank)
+                model = load_model_artifact(args, rank)
 
         # Drop examples that are indivisible across processes to prevent deadlock
         remainder_examples = len(dataset) % dist.get_world_size()
@@ -196,5 +197,17 @@
         trainer.fit()
 
 
+def run():
+    args = parse(RunConfig)
+
+    dataset = load_data(args)
+
+    handle_distribute(
+        process_name="sparsify",
+        worker=worker,
+        const_worker_args=[args, dataset],
+    )
+
+
 if __name__ == "__main__":
     run()
```
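The refactored entry point loads the dataset once in the launching process and hands it, together with the parsed config, to every worker; each worker then recovers its rank from the `LOCAL_RANK` variable that the new launcher injects. A minimal sketch of that bookkeeping, mirroring the lines in the hunks above; the helper name here is hypothetical:

```python
# Minimal sketch (not the commit's code): how a worker process decides whether
# it is running under the elastic launcher. handle_distribute sets LOCAL_RANK
# for every spawned process; a plain single-GPU run leaves it unset.
import os


def detect_rank() -> tuple[bool, int]:
    """Hypothetical helper mirroring the bookkeeping at the top of worker()."""
    local_rank = os.environ.get("LOCAL_RANK")
    ddp = local_rank is not None
    rank = int(local_rank) if ddp else 0
    return ddp, rank


if __name__ == "__main__":
    print(detect_rank())  # (False, 0) when run outside the launcher
```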

sparsify/distributed.py

Lines changed: 78 additions & 0 deletions
@@ -0,0 +1,78 @@

```python
import socket
from typing import Any, Callable

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed.elastic.multiprocessing import DefaultLogsSpecs, start_processes


def dist_worker(
    worker: Callable,
    *worker_args,
):
    try:
        worker(*worker_args)
    finally:
        if dist.is_initialized():
            try:
                dist.barrier()
            except Exception as e:
                print(f"Barrier failed during cleanup: {e}")
                pass

            dist.destroy_process_group()


def handle_distribute(process_name: str, worker, const_worker_args: list[Any]):
    """
    Launch a distributed multi-process job over all visible CUDA devices.

    Parameters
    ----------
    process_name : str
        Label used by Torch Elastic to tag logs and processes.
    worker : Callable
        Function that will be executed on every spawned process. It must accept
        ``(rank, world_size, *const_worker_args)`` in that order.
    const_worker_args : list
        Arguments passed verbatim to every worker invocation after ``rank`` and
        ``world_size``. These are typically configuration or shared datasets.
    """
    world_size = torch.cuda.device_count()
    if world_size <= 1:
        # Run the worker directly if no distributed training is needed. This is great
        # for debugging purposes.
        worker(0, 1, *const_worker_args)
    else:
        # Set up multiprocessing and distributed training
        mp.set_sharing_strategy("file_system")

        # Find an available port for distributed training
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.bind(("", 0))
            _, port = s.getsockname()

        ctx = None
        try:
            ctx = start_processes(
                process_name,
                dist_worker,
                args={
                    i: (worker, i, world_size, *const_worker_args)
                    for i in range(world_size)
                },
                envs={
                    i: {
                        "LOCAL_RANK": str(i),
                        "MASTER_ADDR": "localhost",
                        "MASTER_PORT": str(port),
                    }
                    for i in range(world_size)
                },
                logs_specs=DefaultLogsSpecs(),
            )
            ctx.wait()
        finally:
            if ctx is not None:
                ctx.close()  # Kill any processes that are still running
```
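For reference, a self-contained usage sketch of the new helper; the toy worker and its message argument are hypothetical, but the call shape follows the docstring above: the worker receives `(rank, world_size, *const_worker_args)` whether it is called directly (single GPU) or spawned once per device by Torch Elastic.

```python
# Usage sketch with a hypothetical worker (not part of the commit).
# With <= 1 visible GPU the worker is called in-process as worker(0, 1, ...);
# otherwise one process per device is started and LOCAL_RANK, MASTER_ADDR and
# MASTER_PORT are injected into each process's environment.
from sparsify.distributed import handle_distribute


def toy_worker(rank: int, world_size: int, message: str):
    # const_worker_args are appended after rank and world_size.
    print(f"[rank {rank}/{world_size}] {message}")


if __name__ == "__main__":
    handle_distribute(
        process_name="toy",
        worker=toy_worker,
        const_worker_args=["hello from every process"],
    )
```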
