@@ -312,6 +312,8 @@ def __init__(self, config):
             ln_f = RMSNorm(config.n_embd, config.norm_eps),
         ))
         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+        if config.tied_embeddings:
+            self.transformer.wte.weight = self.lm_head.weight
 
         # init all weights, use a torch rng object to be very careful
         self.init_rng = torch.Generator()
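
For context, tying just means the token embedding and the LM head share one `nn.Parameter`, so the matrix is stored and updated once; a minimal sketch with a hypothetical `TinyLM` module (not from this file):

```python
import torch.nn as nn

class TinyLM(nn.Module):
    # minimal sketch of tied embeddings; vocab_size/n_embd/tied stand in for config values
    def __init__(self, vocab_size, n_embd, tied=True):
        super().__init__()
        self.wte = nn.Embedding(vocab_size, n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)
        if tied:
            # share a single Parameter: gradients from both uses accumulate into it
            self.wte.weight = self.lm_head.weight

m = TinyLM(128, 32)
assert m.wte.weight is m.lm_head.weight                     # same tensor object
assert sum(p.numel() for p in m.parameters()) == 128 * 32   # counted once
```
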
@@ -433,10 +435,16 @@ def unpermute(w, n_heads, dim1, dim2):
         return checkpoint
 
     @classmethod
-    def from_pretrained_llama3_hf(cls, model_id):
+    def from_pretrained_llama3_hf(cls, model_id, untie):
         """Loads pretrained LLaMA model weights from HuggingFace"""
         from transformers import AutoModelForCausalLM, AutoTokenizer
         model_args = MODEL_DICT[model_id]
+        if untie:
+            if not model_args.tied_embeddings:
+                print("Model embeddings are not tied, --untie has no effect.")
+            else:
+                print("Untying token embeddings and LM head.")
+                model_args.tied_embeddings = False
 
         model = AutoModelForCausalLM.from_pretrained(model_id)
         checkpoint = LLaMA.adapt_llama_state_dict_keys_hf(model.state_dict(), model_args)
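
When the checkpoint ties the two matrices but training should proceed untied, the LM head needs its own copy of the embedding weights so the two can diverge; a hedged sketch (the HF key names `model.embed_tokens.weight` / `lm_head.weight` and the helper itself are assumptions, not necessarily what `adapt_llama_state_dict_keys_hf` does):

```python
def untie_state_dict(sd):
    # give lm_head an independent copy of the (tied) token embedding matrix;
    # some tied checkpoints omit lm_head.weight entirely, so fall back to the embedding
    src = sd.get("lm_head.weight", sd["model.embed_tokens.weight"])
    sd["lm_head.weight"] = src.clone()
    return sd
```
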
@@ -876,7 +884,7 @@ def write_bf16(tensor, file):
     b = t.numpy().tobytes()
     file.write(b)
 
-def write_tensors(model_tensors, L, file, dtype):
+def write_tensors(model_tensors, L, tied, file, dtype):
     # writes LLaMA 3 model's weights to a binary file
     # things get a bit more complicated though:
     # 1) We want to maintain the ability to finetune just the biases in the C code
@@ -894,7 +902,8 @@ def write_tensors(model_tensors, L, file, dtype):
     assert dtype in {"float32", "bfloat16"}
     write_fun = write_fp32 if dtype == "float32" else write_bf16
     write_fun(model_tensors["transformer.wte.weight"], file) # (V, C)
-    write_fun(model_tensors["lm_head.weight"], file) # (V, C) # <--- hack (3) here!
+    if not tied:
+        write_fun(model_tensors["lm_head.weight"], file) # (V, C) # <--- hack (3) here!
     for i in range(L): # (L, C)
         write_fun(model_tensors[f"transformer.h.{i}.ln_1.weight"], file)
     for i in range(L): # (L, C)
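
Whatever consumes this file has to mirror the same branch on the read side; a minimal sketch, assuming tensors are stored back-to-back in the order written here and that `read_tensor` is a hypothetical helper that reads one tensor of the expected shape and dtype:

```python
def read_tensors(file, L, tied, read_tensor):
    # mirror of write_tensors: the separate lm_head block only exists when untied
    tensors = {}
    tensors["transformer.wte.weight"] = read_tensor(file)             # (V, C)
    if tied:
        tensors["lm_head.weight"] = tensors["transformer.wte.weight"]
    else:
        tensors["lm_head.weight"] = read_tensor(file)                 # (V, C)
    for i in range(L):                                                # (L, C)
        tensors[f"transformer.h.{i}.ln_1.weight"] = read_tensor(file)
    # ... the remaining per-layer tensors follow in the same order as the writer
    return tensors
```
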
@@ -954,8 +963,9 @@ def write_model(model, filename, dtype):
     header_int[7] = model.config.n_embd
     header_int[8] = model.config.multiple_of
     header_int[9] = int(model.config.use_scaled_rope)
-    header_int[10] = int(model.config.version.split('.')[0]) # major version
-    header_int[11] = int(model.config.version.split('.')[1]) # minor version
+    header_int[10] = int(model.config.tied_embeddings)
+    header_int[11] = int(model.config.version.split('.')[0]) # major version
+    header_int[12] = int(model.config.version.split('.')[1]) # minor version
     # float section of the header
     header_float = torch.zeros(256, dtype=torch.float32)
     header_float[0] = model.config.ffn_dim_multiplier
@@ -967,7 +977,7 @@ def write_model(model, filename, dtype):
     with open(filename, "wb") as file:
         file.write(header_int.numpy().tobytes()) # int header
         file.write(header_float.numpy().tobytes()) # float header
-        write_tensors(params, model.config.n_layer, file, dtype) # params
+        write_tensors(params, model.config.n_layer, model.config.tied_embeddings, file, dtype) # params
     print(f"wrote {filename}")
 
 def write_state(model, x, y, logits, loss, filename):
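
Because the tied flag now occupies int slot 10, the version numbers shift to slots 11/12, and any header reader has to use the new offsets; a sketch, assuming the int header is 256 int32 entries (mirroring the 256-entry float header above):

```python
import numpy as np

def read_header_flags(filename):
    # read only the fields visible in this diff; the rest of the header is ignored here
    with open(filename, "rb") as f:
        header_int = np.frombuffer(f.read(256 * 4), dtype=np.int32)
    tied = bool(header_int[10])                           # new field
    version = (int(header_int[11]), int(header_int[12]))  # shifted by one slot
    return tied, version
```
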
@@ -993,7 +1003,7 @@ def write_state(model, x, y, logits, loss, filename):
         # loss (single float, result of the cross entropy loss)
         write_fp32(loss.cpu(), file)
         # gradients
-        write_tensors(grads, model.config.n_layer, file, "float32")
+        write_tensors(grads, model.config.n_layer, model.config.tied_embeddings, file, "float32")
     print(f"wrote {filename}")
 
 
@@ -1036,6 +1046,7 @@ def print0(*args, **kwargs):
     parser.add_argument("--output_dir", type=str, default="", help="output directory to which to write logs and checkpoints")
     parser.add_argument("--model", type=str, default="meta-llama/Llama-3.2-1B", help="chose the llama model")
     parser.add_argument("--depth", type=int, default=-1, help="load only a subset of the model's layers")
+    parser.add_argument("--untie", type=int, default=0, help="untie token embeddings and LM head, even if they are tied in the checkpoint")
     # token layout for each step of the optimization
     parser.add_argument("--batch_size", type=int, default=4, help="batch size, in units of #batch dimensions")
     parser.add_argument("--sequence_length", type=int, default=64, help="sequence length")
@@ -1140,7 +1151,7 @@ def print0(*args, **kwargs):
 
     # init the model
     if args.use_hf:
-        model = LLaMA.from_pretrained_llama3_hf(args.model)
+        model = LLaMA.from_pretrained_llama3_hf(args.model, args.untie)
     else: # use Meta's checkpoint
         assert args.ckpt_dir is not None and os.path.exists(args.ckpt_dir), f"llama3 ckpt dir {args.ckpt_dir} does not exist"
         assert args.tokenizer_path is not None and os.path.exists(args.tokenizer_path), f"llama3 tokenizer path {args.tokenizer_path} does not exist"