@@ -67,6 +67,8 @@ GPT-2 Transformer Neural Net training loop. See README.md for usage.
 #include "llmc/repkv.cuh"
 // defines: precompute_freqs_cis, rope_forward
 #include "llmc/rope.cuh"
+// defines: swiglu_forward
+#include "llmc/swiglu.cuh"
 // ----------- Multi-GPU support -----------
 // defines: ncclFloatX, ncclCheck, MultiGpuConfig, ShardInfo
 // defines: printf0, multi_gpu_config
@@ -252,6 +254,13 @@ void fill_in_activation_sizes(const ActivationTensors* data, TensorSpec (&tensor
     const size_t n_kv_head = config.num_kv_heads; // num key and value heads
     const size_t hd = C / n_head; // the size of each head
     const size_t qkv_channels = (n_head + 2*n_kv_head) * hd; // Q, K, V channels
+    // SwiGLU-related calculation to determine the number of channels here
+    size_t hidden_dim = 4 * C;
+    hidden_dim = (2 * hidden_dim) / 3;
+    hidden_dim = config.ffn_dim_multiplier * hidden_dim;
+    hidden_dim = config.multiple_of * ((hidden_dim + config.multiple_of - 1) / config.multiple_of);
+    size_t ffn_channels = hidden_dim * 2; // c_fc + c_fc2 concatenated
+    size_t ffn_channels_post_gelu = hidden_dim; // swiglu will halve the channels

     tensors[0] = TENSOR_SPEC(data->encoded, B * T * C);
     // if recompute >= 1 then we will recompute the layernorm forward activation during backward pass
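
As a side note on the sizing arithmetic added above: with Llama-3-8B-style values (C = 4096, ffn_dim_multiplier = 1.3, multiple_of = 1024 — assumed here purely for illustration, these constants are not part of this diff), the rounding lands on the familiar 14336 FFN width. A minimal standalone sketch that mirrors the same integer/float arithmetic:

#include <stdio.h>
#include <stddef.h>

int main(void) {
    // assumed example config (Llama-3-8B-like); not taken from this diff
    size_t C = 4096;
    float  ffn_dim_multiplier = 1.3f;
    size_t multiple_of = 1024;

    size_t hidden_dim = 4 * C;                                             // 16384
    hidden_dim = (2 * hidden_dim) / 3;                                     // 10922 (integer division)
    hidden_dim = (size_t)(ffn_dim_multiplier * hidden_dim);                // 14198 (truncated on assignment)
    hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) / multiple_of); // round up -> 14336

    printf("hidden_dim = %zu, ffn_channels = %zu, ffn_channels_post_gelu = %zu\n",
           hidden_dim, hidden_dim * 2, hidden_dim);                        // 14336, 28672, 14336
    return 0;
}

So ffn_channels is the width the fused c_fc/c_fc2 matmul writes, and ffn_channels_post_gelu is what swiglu_forward leaves behind.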
@@ -270,9 +279,9 @@ void fill_in_activation_sizes(const ActivationTensors* data, TensorSpec (&tensor
     tensors[7] = TENSOR_SPEC(data->ln2, (recompute < 2) ? L * B * T * C : 0);
     tensors[8] = TENSOR_SPEC(data->ln2_mean, L * B * T);
     tensors[9] = TENSOR_SPEC(data->ln2_rstd, L * B * T);
-    tensors[10] = TENSOR_SPEC(data->fch, L * B * T * 4*C);
+    tensors[10] = TENSOR_SPEC(data->fch, L * B * T * ffn_channels);
     // if recompute >= 1 then we will recompute gelu_forward during backward and use this as scratch buffer
-    tensors[11] = TENSOR_SPEC(data->fch_gelu, (recompute < 1) ? L * B * T * 4*C : B * T * 4*C);
+    tensors[11] = TENSOR_SPEC(data->fch_gelu, (recompute < 1) ? L * B * T * ffn_channels_post_gelu : B * T * ffn_channels_post_gelu);
     tensors[12] = TENSOR_SPEC(data->residual3, L * B * T * C);
     tensors[13] = TENSOR_SPEC(data->lnf, B * T * C);
     tensors[14] = TENSOR_SPEC(data->lnf_mean, B * T);
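
To make the two new activation sizes concrete: fch is sized for the concatenated c_fc/c_fc2 matmul output (ffn_channels = 2 * hidden_dim), while fch_gelu only needs the post-SwiGLU activation (ffn_channels_post_gelu = hidden_dim). A minimal CPU sketch of what swiglu_forward is assumed to compute — the ordering of the two halves along the channel dimension is an assumption here; the actual kernel in llmc/swiglu.cuh may use the opposite convention:

#include <math.h>
#include <stddef.h>

// sketch: CPU reference for the assumed swiglu_forward semantics
// inp has (B*T, 2*hidden) channels laid out as [x1 | x2]; out has (B*T, hidden) channels
static void swiglu_forward_cpu(float* out, const float* inp, size_t BT, size_t hidden) {
    for (size_t i = 0; i < BT; i++) {
        const float* x1 = inp + i * 2 * hidden;          // gated branch (assumed to be the c_fc half)
        const float* x2 = x1 + hidden;                   // linear branch (assumed to be the c_fc2 half)
        float* o = out + i * hidden;
        for (size_t j = 0; j < hidden; j++) {
            float silu = x1[j] / (1.0f + expf(-x1[j]));  // SiLU(x) = x * sigmoid(x)
            o[j] = silu * x2[j];                         // SwiGLU halves the channel count
        }
    }
}

Whichever half is the gate, the output has half the channels of the input, which is why fch_gelu is sized with ffn_channels_post_gelu.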
@@ -621,6 +630,12 @@ void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) {
     const size_t n_kv_head = model->config.num_kv_heads;
     const size_t hd = C / n_head; // head dimension
     const size_t qkv_channels = (n_head + 2*n_kv_head) * hd; // Q, K, V channels
+    size_t hidden_dim = 4 * C;
+    hidden_dim = (2 * hidden_dim) / 3;
+    hidden_dim = model->config.ffn_dim_multiplier * hidden_dim;
+    hidden_dim = model->config.multiple_of * ((hidden_dim + model->config.multiple_of - 1) / model->config.multiple_of);
+    size_t ffn_channels = hidden_dim * 2; // c_fc + c_fc2 concatenated
+    size_t ffn_channels_post_gelu = hidden_dim; // swiglu halves the channels

     // validate B,T are not larger than the values used at initialisation
     // (smaller B,T are okay for inference only)
@@ -653,9 +668,9 @@ void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) {
         floatX* l_attprojw = params.attprojw + l * C * C;
         floatX* l_attprojb = params.attprojb + l * C;
         floatX* l_ln2w = params.ln2w + l * C;
-        floatX* l_fcw = params.fcw + l * 4*C * C;
-        floatX* l_fcb = params.fcb + l * 4*C;
-        floatX* l_fcprojw = params.fcprojw + l * C * 4*C;
+        floatX* l_fcw = params.fcw + l * ffn_channels * C;
+        floatX* l_fcb = params.fcb + l * ffn_channels;
+        floatX* l_fcprojw = params.fcprojw + l * C * ffn_channels_post_gelu;
         floatX* l_fcprojb = params.fcprojb + l * C;

         // get the pointers of the activations for this layer
@@ -665,10 +680,10 @@ void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) {
         floatX* l_residual2 = acts.residual2 + l * B * T * C;
         floatX* l_ln2 = (model->recompute < 2) ? acts.ln2 + l * B * T * C : acts.lnf;
         float* l_ln2_rstd = acts.ln2_rstd + l * B * T;
-        floatX* l_fch = acts.fch + l * B * T * 4*C;
+        floatX* l_fch = acts.fch + l * B * T * ffn_channels;
         // reuse the same activation buffer at each layer, as we'll re-compute the gelu during backward
         // very useful because we dramatically reduce VRAM usage, and may be able to fit larger batch size
-        floatX* l_fch_gelu = (model->recompute < 1) ? acts.fch_gelu + l * B * T * 4*C : acts.fch_gelu;
+        floatX* l_fch_gelu = (model->recompute < 1) ? acts.fch_gelu + l * B * T * ffn_channels_post_gelu : acts.fch_gelu;
         floatX* l_residual3 = acts.residual3 + l * B * T * C;
         floatX* scratch = (floatX*)acts.output; // used for non-cudnn attention, fcproj, attproj, etc.
         floatX* qkv_rep_scratch = (floatX*)acts.scratch_bt4c; // we can use the BT4C scratch for qkv replication
@@ -687,20 +702,23 @@ void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) {
         // 1) projection to QKV vectors (note k,v may be fewer heads than q)
         matmul_forward_cublaslt(scratch, l_ln1, l_qkvw, l_qkvb, B, T, C, qkv_channels, main_stream);
         // 2) replicate k,v so that all of q,k,v have the same number of heads. done for simplicity, for now
-        repkv_forward(qkv_rep_scratch, scratch, B, T, n_head, n_kv_head, hd);
+        repkv_forward(qkv_rep_scratch, scratch, B, T, n_head, n_kv_head, hd, main_stream);
         // 3) apply RoPE to q,k in place
-        rope_forward(qkv_rep_scratch, qkv_rep_scratch, model->freqs_cis, B, T, n_head, hd);
+        rope_forward(qkv_rep_scratch, qkv_rep_scratch, model->freqs_cis, B, T, n_head, hd, main_stream);
         // 4) attention: att <- softmax(qk^T)v
         attention_forward(l_atty, l_qkvr, l_att, qkv_rep_scratch, B, T, C, NH, main_stream);
         #endif

         matmul_forward_cublaslt(scratch, l_atty, l_attprojw, l_attprojb, B, T, C, C, main_stream);
         fused_residual_rmsnorm_forward5(l_residual2, l_ln2, l_ln2_rstd, residual, scratch, l_ln2w, B*T, C, main_stream);
+        matmul_forward_cublaslt(l_fch, l_ln2, l_fcw, l_fcb, B, T, C, ffn_channels, main_stream);
+        swiglu_forward(l_fch_gelu, l_fch, B, T, ffn_channels_post_gelu, main_stream);
+        matmul_forward_cublaslt(scratch, l_fch_gelu, l_fcprojw, l_fcprojb, B, T, ffn_channels_post_gelu, C, main_stream);

         // ------------------------------------------------------------------------
         // DEBUGGING: we only work until this point right now, so exit here
         // transfer the first 32 elements to CPU and print them
-        floatX* output = l_ln2;
+        floatX* output = scratch;
         floatX* cpu = (floatX*)mallocCheck(32 * sizeof(floatX));
         cudaCheck(cudaMemcpy(cpu, output, 32 * sizeof(floatX), cudaMemcpyDeviceToHost));
         for (int i = 0; i < 32; i++) {
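
For context on step 2 in the attention path of this hunk: repkv_forward expands the grouped-query K/V heads so that Q, K and V all have n_head heads before attention_forward runs. A CPU sketch of the assumed semantics — the packed [Q | K | V] channel layout and head-major ordering are assumptions, not taken from llmc/repkv.cuh itself:

#include <string.h>
#include <stddef.h>

// sketch: CPU reference for the assumed repkv_forward semantics (GQA -> MHA expansion)
// inp: (B, T, (n_head + 2*n_kv_head) * hd), packed as [Q | K | V] along channels
// out: (B, T, 3 * n_head * hd), with each K/V head repeated n_head/n_kv_head times
static void repkv_forward_cpu(float* out, const float* inp,
                              size_t B, size_t T,
                              size_t n_head, size_t n_kv_head, size_t hd) {
    size_t rep  = n_head / n_kv_head;            // replication factor (assumed to divide evenly)
    size_t Cin  = (n_head + 2 * n_kv_head) * hd;
    size_t Cout = 3 * n_head * hd;
    for (size_t bt = 0; bt < B * T; bt++) {
        const float* q = inp + bt * Cin;
        const float* k = q + n_head * hd;
        const float* v = k + n_kv_head * hd;
        float* o = out + bt * Cout;
        memcpy(o, q, n_head * hd * sizeof(float));  // Q is copied unchanged
        for (size_t h = 0; h < n_head; h++) {       // each K/V head serves `rep` query heads
            memcpy(o + (n_head + h) * hd,     k + (h / rep) * hd, hd * sizeof(float));
            memcpy(o + (2 * n_head + h) * hd, v + (h / rep) * hd, hd * sizeof(float));
        }
    }
}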
@@ -716,8 +734,6 @@ void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) {
         exit(0);
         // ------------------------------------------------------------------------

-        matmul_forward_cublaslt(l_fch_gelu, l_ln2, l_fcw, l_fcb, B, T, C, 4*C, main_stream, l_fch, model->gelu_fusion);
-        matmul_forward_cublaslt(scratch, l_fch_gelu, l_fcprojw, l_fcprojb, B, T, 4*C, C, main_stream);
         // OK, fusion across blocks.
         if (l+1 != L) {
             floatX* l_ln1 = (model->recompute < 2) ? acts.ln1 + (l + 1) * B * T * C : acts.lnf;
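
And for step 3 in the attention path: rope_forward applies rotary position embeddings in place to the Q and K blocks of the replicated QKV tensor, using angles precomputed by precompute_freqs_cis. A CPU sketch under two assumptions — that freqs_cis stores interleaved (cos, sin) pairs per position and rotated dimension, and that each head rotates consecutive (even, odd) channel pairs; neither layout detail is confirmed by this diff:

#include <stddef.h>

// sketch: CPU reference for the assumed rope_forward semantics (rotates Q and K in place, skips V)
// qkv:       (B, T, 3 * n_head * hd), as produced by the repkv step
// freqs_cis: (T, hd/2, 2) holding interleaved (cos, sin) per rotated pair -- assumed layout
static void rope_forward_cpu(float* qkv, const float* freqs_cis,
                             size_t B, size_t T, size_t n_head, size_t hd) {
    for (size_t b = 0; b < B; b++) {
        for (size_t t = 0; t < T; t++) {
            for (size_t qk = 0; qk < 2; qk++) {          // 0 = Q block, 1 = K block; V (block 2) is skipped
                for (size_t h = 0; h < n_head; h++) {
                    float* x = qkv + (((b * T + t) * 3 + qk) * n_head + h) * hd;
                    for (size_t i = 0; i < hd / 2; i++) {
                        float c = freqs_cis[(t * (hd / 2) + i) * 2 + 0];
                        float s = freqs_cis[(t * (hd / 2) + i) * 2 + 1];
                        float x0 = x[2 * i], x1 = x[2 * i + 1];
                        x[2 * i]     = x0 * c - x1 * s;  // complex rotation of the (even, odd) pair
                        x[2 * i + 1] = x0 * s + x1 * c;
                    }
                }
            }
        }
    }
}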