Commit da88cb1

Merge pull request #803 from ngc92/ngc92/llama3-tied-weights
Tied and untied weights for LLama3
2 parents 5c17e4e + 9caeceb commit da88cb1

4 files changed, +91 -19 lines

.github/workflows/ci_gpu.yml

Lines changed: 39 additions & 4 deletions
@@ -113,6 +113,7 @@ jobs:
        run: ./test_gpt2cu

  build-and-test-llama3:
+    name: Build and test LLama3.2 1B
    runs-on: ubicloud-gpu-standard-1-latest
    env:
      HF_TOKEN: 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
@@ -150,18 +151,52 @@ jobs:
      - name: Build BF16 precision
        run: PRECISION=BF16 make train_llama3cu test_llama3cu

-      - name: Run default
+      - name: Run default (BF16)
        run: ./test_llama3cu

-      - name: Run no recompute GeLU
+      - name: Run no recompute GeLU (BF16)
        run: ./test_llama3cu -r 0

-      - name: Run no master weights
+      - name: Run no master weights (BF16)
        run: ./test_llama3cu -w 0

-      - name: Run recompute LN
+      - name: Run recompute LN (BF16)
        run: ./test_llama3cu -r 2

+  build-and-test-llama3-untied:
+    name: Build and test LLama3.2 1B with untie weights
+    runs-on: ubicloud-gpu-standard-1-latest
+    env:
+      HF_TOKEN: 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
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v4
+      - run: echo "::add-mask::$HF_TOKEN"
+
+      - name: Install OpenMP
+        run: sudo apt-get update && sudo apt-get install -y libomp-dev
+
+      - name: Install dependencies
+        run: pip install -r requirements.txt
+
+      - name: Run preprocessing
+        run: python dev/data/tinyshakespeare.py --model_desc llama-3
+
+      - name: Train model
+        run: python train_llama3.py --write_tensors 1 --dtype float32 --untie 1 --depth 10
+
+      - name: Build FP32 precision
+        run: PRECISION=FP32 make test_llama3cu
+
+      - name: Run default
+        run: ./test_llama3cu
+
+      - name: Build BF16 precision
+        run: PRECISION=BF16 make train_llama3cu test_llama3cu
+
+      - name: Run default
+        run: ./test_llama3cu
+
  unit-tests-gpu:
    runs-on: ubicloud-gpu-standard-1-latest

test_llama3.cu

Lines changed: 3 additions & 0 deletions
@@ -85,6 +85,9 @@ float* float_cpu_malloc_and_point_parameters(FloatParameterTensors* params, size
        *(ptrs[i]) = params_memory_iterator;
        params_memory_iterator += param_sizes[i];
    }
+    if(param_sizes[1] == 0) {
+        params->wlmhead = nullptr;
+    }
    return params_memory;
}
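A note on the hunk above: the CPU reference keeps all parameters in one flat buffer and points each named tensor at its slice. With tied weights the lm_head slot has zero elements, so its pointer is set to nullptr explicitly rather than being left aliasing whatever tensor comes next. A minimal Python sketch of the same bookkeeping, with a hypothetical helper and NumPy standing in for the raw buffer:

import numpy as np

def point_parameters(flat, names, sizes):
    # Carve a flat buffer into named views; a zero-sized slot (tied lm_head) becomes None.
    params, offset = {}, 0
    for name, size in zip(names, sizes):
        params[name] = flat[offset:offset + size] if size > 0 else None
        offset += size
    return params

Vp, C = 8, 4                              # toy (padded) vocab and channel sizes
tied = True
sizes = [Vp * C, 0 if tied else Vp * C]   # [wte, wlmhead]
flat = np.zeros(sum(sizes), dtype=np.float32)
params = point_parameters(flat, ["wte", "wlmhead"], sizes)
assert params["wlmhead"] is None          # tied: no separate classifier weights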

train_llama3.cu

Lines changed: 30 additions & 7 deletions
@@ -106,6 +106,7 @@ typedef struct {
    float norm_eps; // epsilon used in layernorm, e.g. 1e-5
    float rope_theta; // theta used in ROPE attention, e.g. 500000.0 (<-- new in Llama 3)
    bool use_biases; // we always allocate memory for biases; to match llama3 they are not used
+    bool tied_weights; // untied for large models (3.1 8B/70B/405B), tied for small (3.2 1B/3B)
} LLama3Config;

// the parameters of the model
@@ -153,7 +154,12 @@ void fill_in_parameter_sizes(size_t* param_sizes, size_t* param_sizeof, LLama3Co
    size_t ffn_channels = hidden_dim * 2; // c_fc + c_fc2 concatenated
    // now populate the parameter sizes
    param_sizes[0] = Vp * C; // wte
-    param_sizes[1] = Vp * C; // (3) lm_head (final classifier layer weights)
+    if(config.tied_weights) {
+        param_sizes[1] = 0; // no lm_head with tied weights
+    } else {
+        param_sizes[1] = Vp * C; // (3) lm_head (final classifier layer weights)
+    }
+
    param_sizes[2] = L * C; // ln1w
    param_sizes[3] = L * C; // ln1b; (1) all biases are zero it's ok
    param_sizes[4] = L * (qkv_channels) * C; // qkvw
@@ -195,6 +201,10 @@ void* malloc_and_point_parameters(ParameterTensors* params, size_t* param_elemen
        *(ptrs[i]) = (floatX*)params_memory_iterator;
        params_memory_iterator += param_elements[i] * param_sizeof[i];
    }
+    // tied weights?
+    if(param_elements[1] == 0) {
+        params->wlmhead = nullptr;
+    }
    return params_memory;
}

@@ -506,8 +516,9 @@ void llama3_write_to_checkpoint(LLama3 *model, const char* checkpoint_path) {
    model_header[7] = model->config.channels;
    model_header[8] = model->config.multiple_of;
    model_header[9] = model->config.use_scaled_rope;
-    model_header[10] = 3;
-    model_header[11] = 1;
+    model_header[10] = model->config.tied_weights;
+    model_header[11] = 3;
+    model_header[12] = model->config.tied_weights ? 2 : 1;
    fwriteCheck(model_header, sizeof(int), 256, model_file);
    float float_header[256];
    float_header[0] = model->config.ffn_dim_multiplier;
@@ -580,8 +591,9 @@ void llama3_build_from_checkpoint(LLama3 *model, const char* checkpoint_path, bo
    model->config.multiple_of = header_int[8];
    model->config.use_scaled_rope = header_int[9];
    model->config.use_biases = false;
-    int major_version = header_int[10]; // currently unused, e.g. 3
-    int minor_version = header_int[11]; // currently unused, e.g. 1 (so Llama 3.1)
+    model->config.tied_weights = header_int[10];
+    int major_version = header_int[11]; // currently unused, e.g. 3
+    int minor_version = header_int[12]; // 1 or 2
    // now the float section
    model->config.ffn_dim_multiplier = header_float[0];
    model->config.norm_eps = header_float[1];
@@ -740,7 +752,9 @@ void llama3_forward(LLama3 *model, const int* inputs, size_t B, size_t T) {
        }
    }

-    matmul_forward_cublaslt(acts.output, acts.lnf, params.wlmhead, NULL, B, T, C, Vp, main_stream);
+    floatX* lm_head = model->config.tied_weights ? params.wte : params.wlmhead;
+    matmul_forward_cublaslt(acts.output, acts.lnf, lm_head, NULL, B, T, C, Vp, main_stream);
+
    cudaCheck(cudaDeviceSynchronize());
}

@@ -836,7 +850,10 @@ void llama3_backward_and_reduce(LLama3 *model, int* inputs, const int* targets,
    // technically that is a small, inline backward() pass of calculating
    // total, final loss as the mean over all losses over all (B,T) positions in the batch
    // next: backward the classifier matmul
-    matmul_backward(model->acts.scratch_bt4c, grads.wlmhead, NULL, acts.output, acts.lnf, params.wlmhead, NULL, B, T, C, Vp, main_stream);
+    floatX* w_lm_head = model->config.tied_weights ? params.wte : params.wlmhead;
+    floatX* g_lm_head = model->config.tied_weights ? grads.wte : grads.wlmhead;
+
+    matmul_backward(model->acts.scratch_bt4c, g_lm_head, NULL, acts.output, acts.lnf, w_lm_head, NULL, B, T, C, Vp, main_stream);
    // backward the final layernorm
    floatX* residual = acts.residual3 + (L-1) * B * T * C; // last residual is in residual3
    rmsnorm_backward(dresidual, grads.lnfw, scratchF, model->acts.scratch_bt4c, residual, params.lnfw, acts.lnf_rstd, B, T, C, main_stream);
@@ -1076,6 +1093,8 @@ void llama3_update(LLama3 *model, float learning_rate, float beta1, float beta2,
        }

        ShardInfo tensor = llama3_get_tensor_at_layer(model, 0, i);
+        if(tensor.size == 0)
+            continue;
        ShardInfo shard = multi_gpu_get_shard_offset(tensor.size, multi_gpu_config, 1);
        ptrdiff_t local_offset_full = tensor.offset + shard.offset;
        ptrdiff_t local_offset_partial = tensor.offset / multi_gpu_config->num_processes;
@@ -1144,6 +1163,10 @@ float llama3_estimate_mfu(LLama3 *model, int num_tokens, float dt) {
    second is the attention matmul, which is also usually a small contribution.
    */
    size_t N = model->num_parameters;
+    if(!model->config.tied_weights) {
+        N -= model->param_elements[0]; // remove embedding parameters, which can be significant at 128k vocab
+    }
+
    int L = model->config.num_layers;
    int C = model->config.channels;
    int T = model->seq_len;
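A note on the MFU hunk just above: the token embedding is a pure table lookup and contributes no matmul FLOPs, whereas the classifier matmul does. With tied weights the single shared matrix is counted once through its lm_head use, but an untied model carries a separate wte copy that should not be counted, so it is subtracted from N. The correction is not small at a 128K vocabulary; a rough back-of-the-envelope, where the 2048-channel figure for LLama3.2 1B is an assumption and not taken from this diff:

# Approximate size of the token-embedding table for a 128K-vocab model.
Vp, C = 128256, 2048                       # padded vocab; C ~2048 assumed for the 1B model
wte_params = Vp * C                        # about 2.6e8 parameters
print(f"embedding table ~= {wte_params / 1e6:.0f}M params")  # roughly a fifth of a ~1.2B model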

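The checkpoint header change is mirrored on the C and Python sides: int slot 10 now carries the tied-weights flag and the version digits shift to slots 11 and 12 (the C writer derives the minor version from the flag, 2 for tied, 1 for untied). A small reader sketch, assuming the llm.c layout of a 256-int32 header followed by a 256-float32 header, decoding only the slots visible in this diff:

import struct

def read_llama3_header(path):
    # Header layout per write_model / llama3_write_to_checkpoint above.
    with open(path, "rb") as f:
        header_int = struct.unpack("256i", f.read(256 * 4))
    return {
        "use_scaled_rope": bool(header_int[9]),
        "tied_weights": bool(header_int[10]),          # new in this PR
        "version": (header_int[11], header_int[12]),   # e.g. (3, 2) for the tied 3.2 models
    }

Checkpoints written before this change store the major version (3) in slot 10, so a reader can tell the two layouts apart by whether slot 10 holds 0/1 or 3.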
train_llama3.py

Lines changed: 19 additions & 8 deletions
@@ -312,6 +312,8 @@ def __init__(self, config):
            ln_f = RMSNorm(config.n_embd, config.norm_eps),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+        if config.tied_embeddings:
+            self.transformer.wte.weight = self.lm_head.weight

        # init all weights, use a torch rng object to be very careful
        self.init_rng = torch.Generator()
@@ -433,10 +435,16 @@ def unpermute(w, n_heads, dim1, dim2):
        return checkpoint

    @classmethod
-    def from_pretrained_llama3_hf(cls, model_id):
+    def from_pretrained_llama3_hf(cls, model_id, untie):
        """Loads pretrained LLaMA model weights from HuggingFace"""
        from transformers import AutoModelForCausalLM, AutoTokenizer
        model_args = MODEL_DICT[model_id]
+        if untie:
+            if not model_args.tied_embeddings:
+                print("Model embeddings are not tied, --untie has no effect.")
+            else:
+                print("Untying token embeddings and LM head.")
+                model_args.tied_embeddings = False

        model = AutoModelForCausalLM.from_pretrained(model_id)
        checkpoint = LLaMA.adapt_llama_state_dict_keys_hf(model.state_dict(), model_args)
@@ -876,7 +884,7 @@ def write_bf16(tensor, file):
    b = t.numpy().tobytes()
    file.write(b)

-def write_tensors(model_tensors, L, file, dtype):
+def write_tensors(model_tensors, L, tied, file, dtype):
    # writes LLaMA 3 model's weights to a binary file
    # things get a bit more complicated though:
    # 1) We want to maintain the ability to finetune just the biases in the C code
@@ -894,7 +902,8 @@ def write_tensors(model_tensors, L, file, dtype):
    assert dtype in {"float32", "bfloat16"}
    write_fun = write_fp32 if dtype == "float32" else write_bf16
    write_fun(model_tensors["transformer.wte.weight"], file) # (V, C)
-    write_fun(model_tensors["lm_head.weight"], file) # (V, C) # <--- hack (3) here!
+    if not tied:
+        write_fun(model_tensors["lm_head.weight"], file) # (V, C) # <--- hack (3) here!
    for i in range(L): # (L, C)
        write_fun(model_tensors[f"transformer.h.{i}.ln_1.weight"], file)
    for i in range(L): # (L, C)
@@ -954,8 +963,9 @@ def write_model(model, filename, dtype):
    header_int[7] = model.config.n_embd
    header_int[8] = model.config.multiple_of
    header_int[9] = int(model.config.use_scaled_rope)
-    header_int[10] = int(model.config.version.split('.')[0]) # major version
-    header_int[11] = int(model.config.version.split('.')[1]) # minor version
+    header_int[10] = int(model.config.tied_embeddings)
+    header_int[11] = int(model.config.version.split('.')[0]) # major version
+    header_int[12] = int(model.config.version.split('.')[1]) # minor version
    # float section of the header
    header_float = torch.zeros(256, dtype=torch.float32)
    header_float[0] = model.config.ffn_dim_multiplier
@@ -967,7 +977,7 @@ def write_model(model, filename, dtype):
    with open(filename, "wb") as file:
        file.write(header_int.numpy().tobytes()) # int header
        file.write(header_float.numpy().tobytes()) # float header
-        write_tensors(params, model.config.n_layer, file, dtype) # params
+        write_tensors(params, model.config.n_layer, model.config.tied_embeddings, file, dtype) # params
    print(f"wrote {filename}")

def write_state(model, x, y, logits, loss, filename):
@@ -993,7 +1003,7 @@ def write_state(model, x, y, logits, loss, filename):
        # loss (single float, result of the cross entropy loss)
        write_fp32(loss.cpu(), file)
        # gradients
-        write_tensors(grads, model.config.n_layer, file, "float32")
+        write_tensors(grads, model.config.n_layer, model.config.tied_embeddings, file, "float32")
    print(f"wrote {filename}")


@@ -1036,6 +1046,7 @@ def print0(*args, **kwargs):
    parser.add_argument("--output_dir", type=str, default="", help="output directory to which to write logs and checkpoints")
    parser.add_argument("--model", type=str, default="meta-llama/Llama-3.2-1B", help="chose the llama model")
    parser.add_argument("--depth", type=int, default=-1, help="load only a subset of the model's layers")
+    parser.add_argument("--untie", type=int, default=False, help="Untie token embeddings and LM-head, even if they are tied in the checkpoint.")
    # token layout for each step of the optimization
    parser.add_argument("--batch_size", type=int, default=4, help="batch size, in units of #batch dimensions")
    parser.add_argument("--sequence_length", type=int, default=64, help="sequence length")
@@ -1140,7 +1151,7 @@ def print0(*args, **kwargs):

    # init the model
    if args.use_hf:
-        model = LLaMA.from_pretrained_llama3_hf(args.model)
+        model = LLaMA.from_pretrained_llama3_hf(args.model, args.untie)
    else: # use Meta's checkpoint
        assert args.ckpt_dir is not None and os.path.exists(args.ckpt_dir), f"llama3 ckpt dir {args.ckpt_dir} does not exist"
        assert args.tokenizer_path is not None and os.path.exists(args.tokenizer_path), f"llama3 tokenizer path {args.tokenizer_path} does not exist"

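On the PyTorch side, tying is plain parameter sharing: when config.tied_embeddings is set, the token-embedding weight becomes the very same tensor object as the LM-head weight, so one (V, C) matrix serves both ends of the model and accumulates gradients from both uses (the CUDA code mirrors this by pointing g_lm_head at grads.wte). A stripped-down sketch with a hypothetical toy module, not the LLaMA class above:

import torch.nn as nn

class TinyLM(nn.Module):
    # Toy model illustrating tied vs. untied embeddings.
    def __init__(self, vocab_size=16, n_embd=8, tied=True):
        super().__init__()
        self.wte = nn.Embedding(vocab_size, n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)
        if tied:
            # share one (V, C) matrix between embedding and classifier
            self.wte.weight = self.lm_head.weight

    def forward(self, idx):
        return self.lm_head(self.wte(idx))  # stand-in for the full transformer stack

tied, untied = TinyLM(tied=True), TinyLM(tied=False)
print(sum(p.numel() for p in tied.parameters()))    # V*C: the shared matrix is counted once
print(sum(p.numel() for p in untied.parameters()))  # 2*V*C: wte and lm_head are separate

With --untie 1 the two matrices stay separate, write_tensors emits lm_head.weight as its own tensor right after wte, and the new CI job exercises exactly that path.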