@@ -472,7 +472,7 @@ def from_pretrained_llama3_meta(cls, ckpt_dir, tokenizer_path):
         model.tokenizer = tokenizer
         return model
 
-    def configure_optimizers(self, weight_decay, learning_rate, betas, device_type, zero_stage, offload):
+    def configure_optimizers(self, weight_decay, learning_rate, betas, device_type, zero_stage):
         # start with all of the candidate parameters
         param_dict = {pn: p for pn, p in self.named_parameters()}
         # filter out those that do not require grad
@@ -494,14 +494,10 @@ def configure_optimizers(self, weight_decay, learning_rate, betas, device_type,
         use_fused = fused_available and device_type == 'cuda'
         print0(f"using fused AdamW: {use_fused}")
         if zero_stage == 1:
-            assert not offload
            print0("using ZeroRedundancyOptimizer")
            optimizer = ZeroRedundancyOptimizer(**optim_groups[0], optimizer_class=torch.optim.AdamW,
                                                lr=learning_rate, betas=betas, fused=use_fused)
            optimizer.add_param_group(optim_groups[1])
-        elif offload:
-            from torchao.prototype.low_bit_optim import CPUOffloadOptimizer
-            optimizer = CPUOffloadOptimizer(optim_groups, torch.optim.AdamW, lr=learning_rate, betas=betas, fused=use_fused)
         else:
            print0("using regular AdamW")
            optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, fused=use_fused)
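Note: the ZeRO-1 branch above hands only the first parameter group to the ZeroRedundancyOptimizer constructor and registers the second via add_param_group, since the constructor accepts a single group. A minimal sketch of the optim_groups layout this assumes (the decay/no-decay split is built earlier in configure_optimizers, outside this hunk; names are illustrative):

    decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]   # matmul weights, embeddings
    nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]  # biases, norm scales
    optim_groups = [
        {'params': decay_params, 'weight_decay': weight_decay},
        {'params': nodecay_params, 'weight_decay': 0.0},
    ]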
@@ -980,9 +976,10 @@ def write_state(model, x, y, logits, loss, filename):
     # this can be used for checking the computation correctness in C
     header = torch.zeros(256, dtype=torch.int32)
     header[0] = 20240803 # magic
-    header[1] = 2 # version
+    header[1] = 3 # version
     header[2] = x.size(0) # batch size of the batch, B
     header[3] = x.size(1) # temporal extent of the batch, T
+    header[4] = 0 # number of training-history steps, filled in later by write_training_history
     grads = {name: param.grad.cpu() for name, param in model.named_parameters()}
     with open(filename, "wb") as file:
         # header
@@ -999,6 +996,22 @@ def write_state(model, x, y, logits, loss, filename):
         write_tensors(grads, model.config.n_layer, file, "float32")
     print(f"wrote {filename}")
 
+
+def write_training_history(losses, norms, filename):
+    # amends the state file with the sequence of losses and grad norms
+    assert len(norms) == len(losses)
+    with open(filename, "r+b") as f:
+        header = np.frombuffer(f.read(256 * 4), dtype=np.int32).copy()
+        header[4] = len(losses)
+        f.seek(0, os.SEEK_SET)
+        f.write(header.tobytes())
+        f.seek(0, os.SEEK_END)
+        # write the losses and norms at the end of the file
+        f.write(np.asarray(losses).astype(np.float32).tobytes())
+        f.write(np.asarray(norms).astype(np.float32).tobytes())
+
+    print(f"updated {filename}")
+
 # -----------------------------------------------------------------------------
 # int main
 
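For context, a minimal sketch of how the amended debug-state file could be read back, assuming only what the patch itself establishes: the 256-int32 header written by write_state (magic 20240803, version 3, step count in header[4]) and the two float32 arrays appended at the end by write_training_history. The helper name is illustrative and not part of the patch:

    import os
    import numpy as np

    def read_training_history(filename):
        with open(filename, "rb") as f:
            header = np.frombuffer(f.read(256 * 4), dtype=np.int32)
            assert header[0] == 20240803 and header[1] == 3
            n = int(header[4])                  # number of recorded steps
            f.seek(-2 * n * 4, os.SEEK_END)     # losses first, then norms, float32 each
            losses = np.frombuffer(f.read(n * 4), dtype=np.float32)
            norms = np.frombuffer(f.read(n * 4), dtype=np.float32)
        return losses, norms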
@@ -1022,6 +1035,7 @@ def print0(*args, **kwargs):
     parser.add_argument("--input_val_bin", type=str, default="", help="input .bin to eval validation loss on")
     parser.add_argument("--output_dir", type=str, default="", help="output directory to which to write logs and checkpoints")
     parser.add_argument("--model", type=str, default="meta-llama/Llama-3.2-1B", help="choose the llama model")
+    parser.add_argument("--depth", type=int, default=-1, help="load only a subset of the model's layers")
     # token layout for each step of the optimization
     parser.add_argument("--batch_size", type=int, default=4, help="batch size, in units of #batch dimensions")
     parser.add_argument("--sequence_length", type=int, default=64, help="sequence length")
@@ -1048,7 +1062,6 @@ def print0(*args, **kwargs):
     parser.add_argument("--compile", type=int, default=0, help="torch.compile the model")
     parser.add_argument("--dtype", type=str, default="bfloat16", help="float32|float16|bfloat16")
     parser.add_argument("--zero_stage", type=int, default=0, help="zero redundancy optimizer stage (0/1/2/3)")
-    parser.add_argument("--offload", type=int, default=0, help="offload optimizer to CPU")
     # python -> C bridge
     parser.add_argument("--write_tensors", type=int, default=0, help="write tensors to disk")
     args = parser.parse_args()
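A hypothetical invocation that exercises the new --depth flag together with the existing tensor-dump path (the script name and checkpoint paths are illustrative, not taken from the patch):

    python train_llama3.py \
        --ckpt_dir Llama3.2-1B --tokenizer_path Llama3.2-1B/tokenizer.model \
        --depth 4 --write_tensors 1 --num_iterations 10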
@@ -1133,9 +1146,16 @@ def print0(*args, **kwargs):
         assert args.tokenizer_path is not None and os.path.exists(args.tokenizer_path), f"llama3 tokenizer path {args.tokenizer_path} does not exist"
         model = LLaMA.from_pretrained_llama3_meta(args.ckpt_dir, args.tokenizer_path)
 
-    # convert the model to the desired precision
-    if args.dtype == "float32":
-        model = model.to(torch.float32)
+    if args.depth > 0:
+        assert args.depth < len(model.transformer.h), f"invalid depth {args.depth}, model has {len(model.transformer.h)} blocks"
+        model.transformer.h = model.transformer.h[0:args.depth]
+        model.config.n_layer = args.depth
+
+    # the PyTorch optimizer doesn't do stochastic rounding, so we
+    # really want the model parameters to be in fp32 precision;
+    # --dtype should only enable AMP.
+    # the original checkpoints are in 16 bit, so we need to convert.
+    model = model.to(torch.float32)
 
     model = model.to(device)
     model.train()
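Since the model now always lives in fp32 and --dtype only selects the mixed-precision mode, the forward pass is presumably run under an autocast context along these lines (the training-loop code itself is outside this diff, so treat this as a sketch):

    from contextlib import nullcontext
    ptdtype = {"float32": torch.float32, "float16": torch.float16, "bfloat16": torch.bfloat16}[args.dtype]
    ctx = torch.autocast(device_type="cuda", dtype=ptdtype) if "cuda" in device else nullcontext()
    with ctx:
        _, loss = model(x, y)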
@@ -1185,7 +1205,7 @@ def print0(*args, **kwargs):
     # init the optimizer
     optimizer = raw_model.configure_optimizers(weight_decay=args.weight_decay,
                                                learning_rate=args.learning_rate, betas=(0.9, 0.95),
-                                               device_type=device, zero_stage=zero_stage, offload=args.offload)
+                                               device_type=device, zero_stage=zero_stage)
 
     # learning rate decay scheduler (cosine with warmup)
     def get_lr(it):
@@ -1205,6 +1225,8 @@ def get_lr(it):
     if device == "cuda":
         torch.cuda.reset_peak_memory_stats()
     timings = []
+    losses = []
+    norms = []
     norm = -1.0 # dummy value to print in inference-only mode
     for step in range(args.num_iterations + 1):
         t0 = time.time()
@@ -1298,16 +1320,6 @@ def get_lr(it):
             dist.all_reduce(lossf, op=dist.ReduceOp.AVG)
         lossf = lossf.item()
         norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
-        if args.offload:
-            # CPUOffloadOptimizer is *not* compatible with gradient clipping and will *silently*
-            # give wrong results. So we
-            # a) explicitly wait for it to finish its gradient transfers
-            # b) overwrite the CPU gradients with the clipped GPU gradients.
-            # This is terribly inefficient, but correct and lets us run CI on
-            # small(ish) GPUs
-            torch.cuda.synchronize()
-            for gpu, cpu in optimizer.param_d2h_map.items():
-                cpu.grad[...] = gpu.grad[...]
 
         # determine and set the learning rate for this iteration
         lr = get_lr(step)
@@ -1327,6 +1339,8 @@ def get_lr(it):
         t1 = time.time()
         # the 0th iteration is often an outlier (much slower) => skip logging it
         tokens_per_second = grad_accum_steps * ddp_world_size * B * T / (t1-t0)
+        losses.append(lossf)
+        norms.append(norm.item())
         print0(f"step {step+1:4d}/{args.num_iterations} | train loss {lossf:.6f} | norm {norm:.4f} | lr {lr:.2e} | ({(t1-t0)*1000:.2f} ms | {tokens_per_second:.0f} tok/s)")
         # log to logfile
         if master_process and logfile is not None:
@@ -1337,6 +1351,9 @@ def get_lr(it):
         if step > 0 and step > args.num_iterations - 20:
             timings.append(t1-t0)
 
+    if master_process and args.write_tensors and (not args.inference_only):
+        write_training_history(losses, norms, f"llama3_{model_size_str}_debug_state.bin")
+
     # print the average of the last 20 timings, to get something smooth-ish
     timings = timings[-20:]
     print0(f"final {len(timings)} iters avg: {np.mean(timings)*1000:.3f}ms")