Commit d36f0e6

Merge pull request #811 from ngc92/llama-fixes
Llama fixes
2 parents 49cef1d + 76a7cce commit d36f0e6

File tree

7 files changed (+55 −52 lines)

.github/workflows/ci_gpu.yml

Lines changed: 4 additions & 8 deletions

@@ -104,20 +104,14 @@ jobs:
           git clone https://github.com/NVIDIA/cudnn-frontend.git

       - name: Build with cuDNN
-        run: USE_CUDNN=1 make test_gpt2cu train_gpt2cu test_gpt2fp32cu train_gpt2fp32cu
+        run: USE_CUDNN=1 make test_gpt2cu train_gpt2cu

       - name: Train model with cuDNN
         run: ./train_gpt2cu

-      - name: Train model fp32 with cuDNN
-        run: ./train_gpt2fp32cu
-
       - name: Execute testing program with cuDNN
         run: ./test_gpt2cu

-      - name: Execute testing program fp32 with cuDNN
-        run: ./test_gpt2fp32cu
-
   build-and-test-llama3:
     runs-on: ubicloud-gpu-standard-1-latest
     env:

@@ -137,7 +131,9 @@ jobs:
         run: python dev/data/tinyshakespeare.py --model_desc llama-3

       - name: Train model
-        run: python train_llama3.py --write_tensors 1 --dtype float32 --offload 1
+        # use the first 10 layers, so that everything fits into the 20GB of
+        # the A4000 Ada that we have in CI
+        run: python train_llama3.py --write_tensors 1 --dtype float32 --depth 10

       - name: Build FP32 precision
         run: PRECISION=FP32 make test_llama3cu

llmc/attention.cuh

Lines changed: 1 addition & 1 deletion

@@ -263,7 +263,7 @@ void attention_backward(floatX* dinp, floatX* dqkvr, floatX* datt, floatX* scrat
     matmul_cublaslt(dv, scratch, att, nullptr, HS, T, T, stream, false, true, B * NH, T * HS, T * T, T * HS);
     const float scale = 1.0f / sqrtf((float)HS);
     // backward into preatt. this is an in-place operation; datt turns into dpreatt here
-    softmax_autoregressive_backward_inplace_kernel<<<dim3(T / 4, B * NH), 256>>>(datt, att, B, T, C, scale);
+    softmax_autoregressive_backward_inplace_kernel<<<dim3(T / 4, B * NH), 256, 0, stream>>>(datt, att, B, T, C, scale);
     const floatX* dpreatt = datt;
     // backward into q
     matmul_cublaslt(dq, k, dpreatt, nullptr, HS, T, T, stream, false, false, B * NH, T * HS, T * T, T * HS);
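Editorial aside (not part of the commit): the third and fourth values in a CUDA launch configuration are the dynamic shared-memory size and the stream. The old line omitted them, so the softmax backward kernel went onto the default stream rather than the `stream` used by the surrounding cuBLASLt calls; depending on how the streams were created that either serializes unrelated work or, for non-blocking streams, leaves the kernel unordered with respect to them. Passing `0, stream` keeps the whole attention backward on one stream, and the repkv.cuh and train_llama3.cu changes below thread the stream through in the same way. A minimal, self-contained sketch of the pattern (names here are illustrative only):

// explicit_stream.cu — sketch of launching work on an explicit stream
#include <cuda_runtime.h>
#include <cstdio>

__global__ void scale_kernel(float* data, float scale, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) { data[i] *= scale; }
}

int main() {
    const int n = 1024;
    float h[n];
    for (int i = 0; i < n; i++) { h[i] = 1.0f; }

    float* d;
    cudaMalloc(&d, n * sizeof(float));
    cudaStream_t stream;
    cudaStreamCreate(&stream);

    // copy and kernel are enqueued on the same stream, so they run in order
    cudaMemcpyAsync(d, h, n * sizeof(float), cudaMemcpyHostToDevice, stream);
    // <<<grid, block, dynamicSharedMemBytes, stream>>>; omitting the last two
    // arguments (as the old llm.c line did) means 0 bytes and the default stream
    scale_kernel<<<(n + 255) / 256, 256, 0, stream>>>(d, 2.0f, n);
    cudaMemcpyAsync(h, d, n * sizeof(float), cudaMemcpyDeviceToHost, stream);
    cudaStreamSynchronize(stream);

    printf("h[0] = %f (expect 2.0)\n", h[0]);
    cudaStreamDestroy(stream);
    cudaFree(d);
    return 0;
}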

llmc/repkv.cuh

Lines changed: 3 additions & 3 deletions

@@ -50,7 +50,7 @@ __global__ void repkv_forward_kernel1(floatX* replicated_qkv,

 __global__ void repkv_backward_kernel1(floatX* dinp, const floatX* dout,
                                        int B, int N, int NH, int replicate_factor, int HD) {
-    // we have a single tensor dout of shapae of (B, N 3 * NH * HD)
+    // we have a single tensor dout of shape of (B, N 3 * NH * HD)
     // we want to reduce sum (for K and V) into (B, N, (NH + 2*(NH/replicate_factor)) * HD)
     int idx = blockIdx.x * blockDim.x + threadIdx.x;
     if (idx >= B * N * 3 * NH * HD) { return;}

@@ -111,11 +111,11 @@ void repkv_forward(floatX* out, const floatX* inp, int B, int T, int NH, int NH_
 }

 void repkv_backward(floatX* dinp, const floatX* dout,
-                    const int B, const int T, const int NH, const int NH_KV, const int d) {
+                    const int B, const int T, const int NH, const int NH_KV, const int d, cudaStream_t stream) {
     const int block_size = 128;
     int total_threads = B * T * (3 * NH) * d;
     int num_blocks = CEIL_DIV(total_threads, block_size);
     int replicate_factor = NH / NH_KV;
-    repkv_backward_kernel1<<<num_blocks, block_size>>>(dinp, dout, B, T, NH, replicate_factor, d);
+    repkv_backward_kernel1<<<num_blocks, block_size, 0, stream>>>(dinp, dout, B, T, NH, replicate_factor, d);
     cudaCheck(cudaGetLastError());
 }

requirements.txt

Lines changed: 1 addition & 2 deletions

@@ -4,5 +4,4 @@ torch
 tiktoken
 transformers
 datasets
-requests
-torchao
+requests

test_llama3.cu

Lines changed: 7 additions & 16 deletions

@@ -128,17 +128,20 @@ int main(int argc, char *argv[]) {
     int state_header[256];
     freadCheck(state_header, sizeof(int), 256, state_file);
     if (state_header[0] != 20240803) { fprintf(stderr, "Bad magic state file\n"); exit(EXIT_FAILURE); }
-    if (state_header[1] != 2) {
+    if (state_header[1] != 3) {
         fprintf(stderr, "Bad version in state file: %d\n", state_header[1]);
         fprintf(stderr, "---> HINT: try to re-run `python train_llama3.py`\n");
         exit(EXIT_FAILURE);
     }
     int B = state_header[2]; // batch size, e.g. 4
     int T = state_header[3]; // time / sequence length (e.g. 64, up to maxT)
+    int steps = state_header[4];
+    float* expected_losses = (float*) malloc(steps * sizeof(float));
     assert(0 <= T && T <= maxT);
     printf("[State]\n");
     printf("batch_size: %d\n", B);
     printf("seq_len: %d\n", T);
+    printf("steps: %d\n", steps);

     set_zero_configs(&multi_gpu_config, 0, model.num_parameters);

@@ -157,6 +160,7 @@ int main(int argc, char *argv[]) {
     FloatParameterTensors expected_grads; // will be read from file. right now: all in fp32
     float* expected_grads_memory = float_cpu_malloc_and_point_parameters(&expected_grads, model.param_elements);
     freadCheck(expected_grads_memory, sizeof(float), model.num_parameters, state_file);
+    freadCheck(expected_losses, sizeof(float), steps, state_file);
     fcloseCheck(state_file);

     // this memory will be used to do one single copy of all (mixed precision) GPU grads to CPU grads

@@ -290,27 +294,13 @@ int main(int argc, char *argv[]) {
         llama3_update(&model, 1e-5f, 0.9f, 0.95f, 1e-8f, 0.0f, grad_scale, step+1, &multi_gpu_config);

         // print the timing information at the end
-        printf("step %d: loss %f (took %f ms)\n", step+1, model.mean_loss, time_elapsed_s * 1000);
+        printf("step %d: loss %f norm %f (took %f ms)\n", step+1, model.mean_loss, grad_norm, time_elapsed_s * 1000);
         // the expected losses from PyTorch were copied over after the print formatting rounded
         // them to 6 decimal places, so we do the same here
         float rounded_loss = roundf(model.mean_loss * 1000000) / 1000000;
         losses[step] = rounded_loss;
     }

-    // expected losses are as follows, from Python (without CPUOffload)
-    float expected_losses[10] = {
-        4.849688f,
-        3.070303f,
-        1.711614f,
-        1.056311f,
-        0.593335f,
-        0.428291f,
-        0.372275f,
-        0.360507f,
-        0.355562f,
-        0.334824f
-    };
-
     // compare
     for (int i = 0; i < 10; i++) {
         if (fabsf(losses[i] - expected_losses[i]) >= loss_diff_threshold) {

@@ -377,6 +367,7 @@ int main(int argc, char *argv[]) {
     common_free(model);
     free(x);
     free(y);
+    free(expected_losses);
     free(logits_cpu_raw);
     free(logits_cpu);
     free(expected_logits);
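Editorial aside (not part of the commit): the version bump to 3 marks a new debug-state layout in which the reference losses come from the file instead of a hard-coded array. Piecing together this diff and the train_llama3.py change below, the header fields the test now consumes can be sketched roughly as follows; the struct and its field names are descriptive, not identifiers from the repo:

/* Sketch of the version-3 debug-state header as read by test_llama3.cu. */
typedef struct {
    int magic;        /* header[0] == 20240803 */
    int version;      /* header[1] == 3 (the test now rejects version 2) */
    int B;            /* header[2], batch size */
    int T;            /* header[3], sequence length */
    int steps;        /* header[4], number of recorded training steps (new) */
    int unused[251];  /* remainder of the 256-int header */
} DebugStateHeader;

/* After the header comes the payload that write_state already produced,
 * ending in the float32 reference grads the test reads, followed by:
 *   float32 expected_losses[steps];  // read here, replacing the old array
 *   float32 norms[steps];            // also appended, not read by this test */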

train_llama3.cu

Lines changed: 1 addition & 1 deletion

@@ -923,7 +923,7 @@ void llama3_backward_and_reduce(LLama3 *model, int* inputs, const int* targets,
         // backward rope (this can be done in-place)
         rope_backward_inplace(dl_bt4c, dl_bt4c, model->freqs_cis, B, T, NH, hd, main_stream);
         // backward repkv (use scratchX as gradient buffer here)
-        repkv_backward(dl_bt4c2, dl_bt4c, B, T, NH, n_kv_head, hd);
+        repkv_backward(dl_bt4c2, dl_bt4c, B, T, NH, n_kv_head, hd, main_stream);
         // backward QKV projection
         if(model->recompute >= 2) {
             rmsnorm_forward(l_ln1, l_ln1_rstd, residual, l_ln1w, B, T, C, main_stream);

train_llama3.py

Lines changed: 38 additions & 21 deletions

@@ -472,7 +472,7 @@ def from_pretrained_llama3_meta(cls, ckpt_dir, tokenizer_path):
         model.tokenizer = tokenizer
         return model

-    def configure_optimizers(self, weight_decay, learning_rate, betas, device_type, zero_stage, offload):
+    def configure_optimizers(self, weight_decay, learning_rate, betas, device_type, zero_stage):
         # start with all of the candidate parameters
         param_dict = {pn: p for pn, p in self.named_parameters()}
         # filter out those that do not require grad

@@ -494,14 +494,10 @@ def configure_optimizers(self, weight_decay, learning_rate, betas, device_type,
         use_fused = fused_available and device_type == 'cuda'
         print0(f"using fused AdamW: {use_fused}")
         if zero_stage == 1:
-            assert not offload
             print0("using ZeroRedundancyOptimizer")
             optimizer = ZeroRedundancyOptimizer(**optim_groups[0], optimizer_class=torch.optim.AdamW,
                                                 lr=learning_rate, betas=betas, fused=use_fused)
             optimizer.add_param_group(optim_groups[1])
-        elif offload:
-            from torchao.prototype.low_bit_optim import CPUOffloadOptimizer
-            optimizer = CPUOffloadOptimizer(optim_groups, torch.optim.AdamW, lr=learning_rate, betas=betas, fused=use_fused)
         else:
             print0("using regular AdamW")
             optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, fused=use_fused)

@@ -980,9 +976,10 @@ def write_state(model, x, y, logits, loss, filename):
     # this can be used for checking the computation correctness in C
     header = torch.zeros(256, dtype=torch.int32)
     header[0] = 20240803 # magic
-    header[1] = 2 # version
+    header[1] = 3 # version
     header[2] = x.size(0) # batch size of the batch, B
     header[3] = x.size(1) # temporal extent of the batch, T
+    header[4] = 0
     grads = {name: param.grad.cpu() for name, param in model.named_parameters()}
     with open(filename, "wb") as file:
         # header

@@ -999,6 +996,22 @@ def write_state(model, x, y, logits, loss, filename):
         write_tensors(grads, model.config.n_layer, file, "float32")
     print(f"wrote {filename}")

+
+def write_training_history(losses, norms, filename):
+    # amends the state file with the sequence of losses and grad norms
+    assert len(norms) == len(losses)
+    with open(filename, "r+b") as f:
+        header = np.frombuffer(f.read(256*4), dtype=np.int32).copy()
+        header[4] = len(losses)
+        f.seek(0, os.SEEK_SET)
+        f.write(header.tobytes())
+        f.seek(0, os.SEEK_END)
+        # write the losses and norms at the end of the file
+        f.write(np.asarray(losses).astype(np.float32).tobytes())
+        f.write(np.asarray(norms).astype(np.float32).tobytes())
+
+    print(f"updated {filename}")
+
 # -----------------------------------------------------------------------------
 # int main


@@ -1022,6 +1035,7 @@ def print0(*args, **kwargs):
     parser.add_argument("--input_val_bin", type=str, default="", help="input .bin to eval validation loss on")
     parser.add_argument("--output_dir", type=str, default="", help="output directory to which to write logs and checkpoints")
     parser.add_argument("--model", type=str, default="meta-llama/Llama-3.2-1B", help="chose the llama model")
+    parser.add_argument("--depth", type=int, default=-1, help="load only a subset of the model's layers")
     # token layout for each step of the optimization
     parser.add_argument("--batch_size", type=int, default=4, help="batch size, in units of #batch dimensions")
     parser.add_argument("--sequence_length", type=int, default=64, help="sequence length")

@@ -1048,7 +1062,6 @@ def print0(*args, **kwargs):
     parser.add_argument("--compile", type=int, default=0, help="torch.compile the model")
     parser.add_argument("--dtype", type=str, default="bfloat16", help="float32|float16|bfloat16")
     parser.add_argument("--zero_stage", type=int, default=0, help="zero redundancy optimizer stage (0/1/2/3)")
-    parser.add_argument("--offload", type=int, default=0, help="offload optimizer to CPU")
     # python -> C bridge
     parser.add_argument("--write_tensors", type=int, default=0, help="write tensors to disk")
     args = parser.parse_args()

@@ -1133,9 +1146,16 @@ def print0(*args, **kwargs):
         assert args.tokenizer_path is not None and os.path.exists(args.tokenizer_path), f"llama3 tokenizer path {args.tokenizer_path} does not exist"
         model = LLaMA.from_pretrained_llama3_meta(args.ckpt_dir, args.tokenizer_path)

-    # convert the model to the desired precision
-    if args.dtype == "float32":
-        model = model.to(torch.float32)
+    if args.depth > 0:
+        assert args.depth < len(model.transformer.h), f"invalid depth {args.depth}, model has {len(model.transformer.h)} blocks"
+        model.transformer.h = model.transformer.h[0:args.depth]
+        model.config.n_layer = args.depth
+
+    # PT optimizer doesn't do stochastic rounding, so we
+    # really want the model to be in fp32 precision:
+    # --dtype should only enable AMP
+    # as the original checkpoints are in 16 bit, we need to convert
+    model = model.to(torch.float32)

     model = model.to(device)
     model.train()

@@ -1185,7 +1205,7 @@ def print0(*args, **kwargs):
     # init the optimizer
     optimizer = raw_model.configure_optimizers(weight_decay=args.weight_decay,
                                                learning_rate=args.learning_rate, betas=(0.9, 0.95),
-                                               device_type=device, zero_stage=zero_stage, offload=args.offload)
+                                               device_type=device, zero_stage=zero_stage)

     # learning rate decay scheduler (cosine with warmup)
     def get_lr(it):

@@ -1205,6 +1225,8 @@ def get_lr(it):
     if device == "cuda":
         torch.cuda.reset_peak_memory_stats()
     timings = []
+    losses = []
+    norms = []
     norm = -1.0 # dummy value to print in inference-only mode
     for step in range(args.num_iterations + 1):
         t0 = time.time()

@@ -1298,16 +1320,6 @@ def get_lr(it):
             dist.all_reduce(lossf, op=dist.ReduceOp.AVG)
             lossf = lossf.item()
         norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
-        if args.offload:
-            # CPUOffloadOptimizer is *not* compatible with gradient clipping and will *silently*
-            # give wrong results. So we
-            # a) explicitly wait for it to finish its gradients transfers
-            # b) overwrite the CPU gradients with the clipped GPU gradients.
-            # This is terribly inefficient, but correct and lets us run CI on
-            # small(ish) GPUs
-            torch.cuda.synchronize()
-            for gpu, cpu in optimizer.param_d2h_map.items():
-                cpu.grad[...] = gpu.grad[...]

         # determine and set the learning rate for this iteration
         lr = get_lr(step)

@@ -1327,6 +1339,8 @@ def get_lr(it):
         t1 = time.time()
         # the 0th iteration is often an outlier (much slower) => skip logging it
         tokens_per_second = grad_accum_steps * ddp_world_size * B * T / (t1-t0)
+        losses.append(lossf)
+        norms.append(norm.item())
         print0(f"step {step+1:4d}/{args.num_iterations} | train loss {lossf:.6f} | norm {norm:.4f} | lr {lr:.2e} | ({(t1-t0)*1000:.2f} ms | {tokens_per_second:.0f} tok/s)")
         # log to logile
         if master_process and logfile is not None:

@@ -1337,6 +1351,9 @@ def get_lr(it):
         if step > 0 and step > args.num_iterations - 20:
             timings.append(t1-t0)

+    if master_process and args.write_tensors and (not args.inference_only):
+        write_training_history(losses, norms, f"llama3_{model_size_str}_debug_state.bin")
+
     # print the average of the last 20 timings, to get something smooth-ish
     timings = timings[-20:]
     print0(f"final {len(timings)} iters avg: {np.mean(timings)*1000:.3f}ms")
