6 changes: 3 additions & 3 deletions README.md
@@ -113,16 +113,16 @@ We currently don't support fine-grained manual control over the learning rate, n

## Distributed training

We support distributed training via PyTorch's `torchrun` command. By default we use the Distributed Data Parallel method, which means that the weights of each SAE are replicated on every GPU.
We support distributed training out of the box using Torch Distributed Elastic. By default, we use Distributed Data Parallel across all visible GPUs, which means that the weights of each SAE are replicated on every GPU.

```bash
torchrun --nproc_per_node gpu -m sparsify meta-llama/Meta-Llama-3-8B --batch_size 1 --layers 16 24 --k 192 --grad_acc_steps 8 --ctx_len 2048
python -m sparsify meta-llama/Meta-Llama-3-8B --batch_size 1 --layers 16 24 --k 192 --grad_acc_steps 8 --ctx_len 2048
```

This is simple, but very memory inefficient. If you want to train SAEs for many layers of a model, we recommend using the `--distribute_modules` flag, which allocates the SAEs for different layers to different GPUs. Currently, we require that the number of GPUs evenly divides the number of layers you're training SAEs for.

```bash
torchrun --nproc_per_node gpu -m sparsify meta-llama/Meta-Llama-3-8B --distribute_modules --batch_size 1 --layer_stride 2 --grad_acc_steps 8 --ctx_len 2048 --k 192 --load_in_8bit --micro_acc_steps 2
python -m sparsify meta-llama/Meta-Llama-3-8B --distribute_modules --batch_size 1 --layer_stride 2 --grad_acc_steps 8 --ctx_len 2048 --k 192 --load_in_8bit --micro_acc_steps 2
```

The above command trains an SAE for every _even_ layer of Llama 3 8B, using all available GPUs. It accumulates gradients over 8 minibatches, and splits each minibatch into 2 microbatches before feeding them into the SAE encoder, thus saving a lot of memory. It also loads the model in 8-bit precision using `bitsandbytes`. This command requires no more than 48GB of memory per GPU on an 8 GPU node.
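The launcher spawns one worker per visible device, so if you only want to use a subset of the GPUs on a node, restricting visibility with `CUDA_VISIBLE_DEVICES` should be sufficient. A minimal sketch, reusing the flags from the first example above:

```bash
# Hypothetical: train on the first four GPUs only
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m sparsify meta-llama/Meta-Llama-3-8B --batch_size 1 --layers 16 24 --k 192 --grad_acc_steps 8 --ctx_len 2048
```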
198 changes: 1 addition & 197 deletions sparsify/__main__.py
@@ -1,200 +1,4 @@
import os
from contextlib import nullcontext, redirect_stdout
from dataclasses import dataclass
from datetime import timedelta
from multiprocessing import cpu_count

import torch
import torch.distributed as dist
from datasets import Dataset, load_dataset
from safetensors.torch import load_model
from simple_parsing import field, parse
from transformers import (
AutoModel,
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
PreTrainedModel,
)

from .data import MemmapDataset, chunk_and_tokenize
from .trainer import TrainConfig, Trainer
from .utils import simple_parse_args_string


@dataclass
class RunConfig(TrainConfig):
model: str = field(
default="HuggingFaceTB/SmolLM2-135M",
positional=True,
)
"""Name of the model to train."""

dataset: str = field(
default="EleutherAI/SmolLM2-135M-10B",
positional=True,
)
"""Path to the dataset to use for training."""

split: str = "train"
"""Dataset split to use for training."""

ctx_len: int = 2048
"""Context length to use for training."""

# Use a dummy encoding function to prevent the token from being saved
# to disk in plain text
hf_token: str | None = field(default=None, encoding_fn=lambda _: None)
"""Huggingface API token for downloading models."""

revision: str | None = None
"""Model revision to use for training."""

load_in_8bit: bool = False
"""Load the model in 8-bit mode."""

max_examples: int | None = None
"""Maximum number of examples to use for training."""

resume: bool = False
"""Whether to try resuming from the checkpoint present at `checkpoints/run_name`."""

text_column: str = "text"
"""Column name to use for text data."""

shuffle_seed: int = 42
"""Random seed for shuffling the dataset."""

data_preprocessing_num_proc: int = field(
default_factory=lambda: cpu_count() // 2,
)
"""Number of processes to use for preprocessing data"""

data_args: str = field(
default="",
)
"""Arguments to pass to the HuggingFace dataset constructor in the
format 'arg1=val1,arg2=val2'."""


def load_artifacts(
args: RunConfig, rank: int
) -> tuple[PreTrainedModel, Dataset | MemmapDataset]:
if args.load_in_8bit:
dtype = torch.float16
elif torch.cuda.is_bf16_supported():
dtype = torch.bfloat16
else:
dtype = "auto"

# End-to-end training requires a model with a causal LM head
model_cls = AutoModel if args.loss_fn == "fvu" else AutoModelForCausalLM
model = model_cls.from_pretrained(
args.model,
device_map={"": f"cuda:{rank}"},
quantization_config=(
BitsAndBytesConfig(load_in_8bit=args.load_in_8bit)
if args.load_in_8bit
else None
),
revision=args.revision,
torch_dtype=dtype,
token=args.hf_token,
)

# For memmap-style datasets
if args.dataset.endswith(".bin"):
dataset = MemmapDataset(args.dataset, args.ctx_len, args.max_examples)
else:
# For Huggingface datasets
try:
kwargs = simple_parse_args_string(args.data_args)
dataset = load_dataset(args.dataset, split=args.split, **kwargs)
except ValueError as e:
# Automatically use load_from_disk if appropriate
if "load_from_disk" in str(e):
dataset = Dataset.load_from_disk(args.dataset, keep_in_memory=False)
else:
raise e

assert isinstance(dataset, Dataset)
if "input_ids" not in dataset.column_names:
tokenizer = AutoTokenizer.from_pretrained(args.model, token=args.hf_token)
dataset = chunk_and_tokenize(
dataset,
tokenizer,
max_seq_len=args.ctx_len,
num_proc=args.data_preprocessing_num_proc,
text_key=args.text_column,
)
else:
print("Dataset already tokenized; skipping tokenization.")

print(f"Shuffling dataset with seed {args.shuffle_seed}")
dataset = dataset.shuffle(args.shuffle_seed)

dataset = dataset.with_format("torch")
if limit := args.max_examples:
dataset = dataset.select(range(limit))

return model, dataset


def run():
local_rank = os.environ.get("LOCAL_RANK")
ddp = local_rank is not None
rank = int(local_rank) if ddp else 0

if ddp:
torch.cuda.set_device(int(local_rank))

# Increase the default timeout in order to account for slow downloads
# and data preprocessing on the main rank
dist.init_process_group(
"nccl", device_id=torch.device(rank), timeout=timedelta(minutes=20)
)

if rank == 0:
print(f"Using DDP across {dist.get_world_size()} GPUs.")

args = parse(RunConfig)

# Prevent ranks other than 0 from printing
with nullcontext() if rank == 0 else redirect_stdout(None):
# Awkward hack to prevent other ranks from duplicating data preprocessing
if not ddp or rank == 0:
model, dataset = load_artifacts(args, rank)
if ddp:
dist.barrier()
if rank != 0:
model, dataset = load_artifacts(args, rank)

# Drop examples that are indivisible across processes to prevent deadlock
remainder_examples = len(dataset) % dist.get_world_size()
dataset = dataset.select(range(len(dataset) - remainder_examples))

dataset = dataset.shard(dist.get_world_size(), rank)

# Drop examples that are indivisible across processes to prevent deadlock
remainder_examples = len(dataset) % dist.get_world_size()
dataset = dataset.select(range(len(dataset) - remainder_examples))

print(f"Training on '{args.dataset}' (split '{args.split}')")
print(f"Storing model weights in {model.dtype}")

trainer = Trainer(args, dataset, model)
if args.resume:
trainer.load_state(f"checkpoints/{args.run_name}" or "checkpoints/unnamed")
elif args.finetune:
for name, sae in trainer.saes.items():
load_model(
sae,
f"{args.finetune}/{name}/sae.safetensors",
device=str(model.device),
)

trainer.fit()

from sparsify.train import run

if __name__ == "__main__":
run()
2 changes: 1 addition & 1 deletion sparsify/data.py
@@ -18,7 +18,7 @@ def chunk_and_tokenize(
tokenizer: PreTrainedTokenizerBase,
*,
format: str = "torch",
num_proc: int = cpu_count() // 2,
num_proc: int = min(64, cpu_count() // 2),
text_key: str = "text",
max_seq_len: int = 2048,
return_final_batch: bool = False,
78 changes: 78 additions & 0 deletions sparsify/distributed.py
@@ -0,0 +1,78 @@
import socket
from typing import Any, Callable

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed.elastic.multiprocessing import DefaultLogsSpecs, start_processes


def dist_worker(
worker: Callable,
*worker_args,
):
try:
worker(*worker_args)
finally:
if dist.is_initialized():
try:
dist.barrier()
except Exception as e:
print(f"Barrier failed during cleanup: {e}")
pass

dist.destroy_process_group()


def handle_distribute(process_name: str, worker, const_worker_args: list[Any]):
"""
Launch a distributed multi-process job over all visible CUDA devices.

Parameters
----------
process_name : str
Label used by Torch Elastic to tag logs and processes.
worker : Callable
Function that will be executed on every spawned process. It must accept
``(rank, world_size, *const_worker_args)`` in that order.
const_worker_args : list
Arguments passed verbatim to every worker invocation after ``rank`` and
``world_size``. These are typically configuration or shared datasets.
"""
world_size = torch.cuda.device_count()
if world_size <= 1:
# Run the worker directly if no distributed training is needed. This is great
# for debugging purposes.
worker(0, 1, *const_worker_args)
else:
# Set up multiprocessing and distributed training
mp.set_sharing_strategy("file_system")

# Find an available port for distributed training
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("", 0))
_, port = s.getsockname()

ctx = None
try:
ctx = start_processes(
process_name,
dist_worker,
args={
i: (worker, i, world_size, *const_worker_args)
for i in range(world_size)
},
envs={
i: {
"LOCAL_RANK": str(i),
"MASTER_ADDR": "localhost",
"MASTER_PORT": str(port),
}
for i in range(world_size)
},
logs_specs=DefaultLogsSpecs(),
)
ctx.wait()
finally:
if ctx is not None:
ctx.close() # Kill any processes that are still running
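For reviewers, here is a minimal sketch of the calling convention `handle_distribute` expects. The worker name, its extra argument, and the explicit `init_process_group` call are illustrative assumptions; the PR's actual worker lives in `sparsify.train` (imported by the new `__main__.py` above) and is not shown in this diff.

```python
# Illustrative sketch, not part of this PR.
import torch
import torch.distributed as dist

from sparsify.distributed import handle_distribute


def example_worker(rank: int, world_size: int, message: str):
    # handle_distribute calls the worker as (rank, world_size, *const_worker_args).
    if torch.cuda.is_available():
        torch.cuda.set_device(rank)
    if world_size > 1:
        # MASTER_ADDR / MASTER_PORT are exported by handle_distribute, so the
        # default env:// rendezvous works once rank and world_size are passed.
        dist.init_process_group("nccl", rank=rank, world_size=world_size)
    print(f"[rank {rank}/{world_size}] {message}")
    # Training would go here; dist_worker's finally block takes care of the
    # trailing barrier and destroy_process_group.


if __name__ == "__main__":
    handle_distribute("example", example_worker, ["hello from every rank"])
```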