
Commit 1826752

add swiglu yay, git add -u!

committed · 1 parent bb3c92d · commit 1826752

File tree: llmc/repkv.cuh, llmc/rope.cuh, train_llama3.cu, train_llama3.py

4 files changed: +37 −21 lines changed


llmc/repkv.cuh

Lines changed: 2 additions & 2 deletions
@@ -48,14 +48,14 @@ __global__ void repkv_forward_kernel1(floatX* replicated_qkv,
         replicated_qkv[idx_flat] = __ldcs(&gqa_qkv[inp_idx]);
 }
 
-void repkv_forward(floatX* out, const floatX* inp, int B, int T, int NH, int NH_KV, int HD) {
+void repkv_forward(floatX* out, const floatX* inp, int B, int T, int NH, int NH_KV, int HD, cudaStream_t stream) {
     // NH = number of query heads, NH_KV = number of key and value heads, HD = head dimension
     const int block_size = 128;
     int total_threads = B * T * (3 * NH) * HD; // one thread per output element
     int num_blocks = CEIL_DIV(total_threads, block_size);
     int replicate_factor = NH / NH_KV;
     if (replicate_factor > 1) {
-        repkv_forward_kernel1<<<num_blocks, block_size>>>(out, inp, B, T, NH, replicate_factor, HD);
+        repkv_forward_kernel1<<<num_blocks, block_size, 0, stream>>>(out, inp, B, T, NH, replicate_factor, HD);
     } else {
         cudaMemcpy(out, inp, total_threads * sizeof(floatX), cudaMemcpyDeviceToDevice);
     }
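For reference, a CPU sketch of the replication repkv_forward performs. This is illustrative only and rests on assumptions: the input rows are packed as (NH + 2*NH_KV)*HD channels with Q first, then K, then V, and the output as 3*NH*HD channels with each K/V head repeated NH/NH_KV times; the actual indexing in repkv_forward_kernel1 may differ in detail.

#include <string.h>

// CPU reference for grouped-query K/V head replication (sketch, not the kernel).
void repkv_forward_cpu(float* out, const float* inp, int B, int T, int NH, int NH_KV, int HD) {
    int rep = NH / NH_KV;                      // how many times each K/V head is repeated
    int in_channels  = (NH + 2 * NH_KV) * HD;  // qkv_channels on the grouped-query side
    int out_channels = 3 * NH * HD;            // fully replicated Q, K, V
    for (int bt = 0; bt < B * T; bt++) {
        const float* in_bt  = inp + (size_t)bt * in_channels;
        float*       out_bt = out + (size_t)bt * out_channels;
        for (int h = 0; h < NH; h++) {
            int kv = h / rep;  // which source K/V head feeds this query head
            memcpy(out_bt + h * HD,            in_bt + h * HD,                  HD * sizeof(float)); // Q passes through
            memcpy(out_bt + (NH + h) * HD,     in_bt + (NH + kv) * HD,          HD * sizeof(float)); // K replicated
            memcpy(out_bt + (2 * NH + h) * HD, in_bt + (NH + NH_KV + kv) * HD,  HD * sizeof(float)); // V replicated
        }
    }
}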

llmc/rope.cuh

Lines changed: 2 additions & 2 deletions
@@ -70,14 +70,14 @@ __global__ void rope_forward_kernel1(floatX *out, const floatX *inp, const float
         out[idxi + 1] = x_real * freqs_sin + x_imag * freqs_cos;
 }
 
-void rope_forward(floatX *out, const floatX *inp, const floatX *freqs_cis, int B, int T, int n_head, int head_dim) {
+void rope_forward(floatX *out, const floatX *inp, const floatX *freqs_cis, int B, int T, int n_head, int head_dim, cudaStream_t stream) {
     // the input and output to this kernel are (B, T, 3, NH, HD) where the 3 is q,k,v
     // we are going to launch exactly one thread per element of the output,
     // except divide by two because the work is in "tuples"
     // so this single kernel launch will do RoPE for both q and k, and the threads for v will be a no-op
     const int block_size = 128;
     int total_threads = B * T * 3 * n_head * head_dim / 2;
     int num_blocks = CEIL_DIV(total_threads, block_size);
-    rope_forward_kernel1<<<num_blocks, block_size>>>(out, inp, freqs_cis, B, T, n_head, head_dim);
+    rope_forward_kernel1<<<num_blocks, block_size, 0, stream>>>(out, inp, freqs_cis, B, T, n_head, head_dim);
     cudaCheck(cudaGetLastError());
 }
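The context line above is the imaginary half of the rotation; for completeness, a minimal sketch of the per-tuple update (standard RoPE; the real half is assumed to mirror the visible kernel line):

// Rotate one (real, imag) pair by the precomputed angle. rope_forward_kernel1 applies
// this with one thread per tuple over the q and k sections and leaves v untouched.
void rope_rotate_pair(float* out, const float* x, float freqs_cos, float freqs_sin) {
    float x_real = x[0];
    float x_imag = x[1];
    out[0] = x_real * freqs_cos - x_imag * freqs_sin;  // real part (assumed)
    out[1] = x_real * freqs_sin + x_imag * freqs_cos;  // imaginary part, as in the kernel line above
}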

train_llama3.cu

Lines changed: 28 additions & 12 deletions
@@ -67,6 +67,8 @@ GPT-2 Transformer Neural Net training loop. See README.md for usage.
 #include "llmc/repkv.cuh"
 // defines: precompute_freqs_cis, rope_forward
 #include "llmc/rope.cuh"
+// defines: swiglu_forward
+#include "llmc/swiglu.cuh"
 // ----------- Multi-GPU support -----------
 // defines: ncclFloatX, ncclCheck, MultiGpuConfig, ShardInfo
 // defines: printf0, multi_gpu_config
@@ -252,6 +254,13 @@ void fill_in_activation_sizes(const ActivationTensors* data, TensorSpec (&tensor
     const size_t n_kv_head = config.num_kv_heads; // num key and value heads
     const size_t hd = C / n_head; // the size of each head
     const size_t qkv_channels = (n_head + 2*n_kv_head) * hd; // Q, K, V channels
+    // SwiGLU-related calculation to determine the number of channels here
+    size_t hidden_dim = 4 * C;
+    hidden_dim = (2 * hidden_dim) / 3;
+    hidden_dim = config.ffn_dim_multiplier * hidden_dim;
+    hidden_dim = config.multiple_of * ((hidden_dim + config.multiple_of - 1) / config.multiple_of);
+    size_t ffn_channels = hidden_dim * 2; // c_fc + c_fc2 concatenated
+    size_t ffn_channels_post_gelu = hidden_dim; // swiglu will halve the channels
 
     tensors[0] = TENSOR_SPEC(data->encoded, B * T * C);
     // if recompute >= 1 then we will recompute the layernorm forward activation during backward pass
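As a sanity check on the channel math added above, plugging in the published Llama 3 8B hyperparameters (C = 4096, ffn_dim_multiplier = 1.3, multiple_of = 1024; these values come from the 8B config, not from this commit) reproduces the familiar 14336-wide FFN:

#include <stdio.h>
#include <stddef.h>

int main(void) {
    size_t C = 4096;                  // model width (assumed Llama 3 8B)
    double ffn_dim_multiplier = 1.3;  // assumed 8B config value
    size_t multiple_of = 1024;        // assumed 8B config value

    size_t hidden_dim = 4 * C;                                                  // 16384
    hidden_dim = (2 * hidden_dim) / 3;                                          // 10922
    hidden_dim = (size_t)(ffn_dim_multiplier * hidden_dim);                     // 14198
    hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) / multiple_of);  // round up -> 14336
    size_t ffn_channels = hidden_dim * 2;        // 28672: c_fc and c_fc2 concatenated
    size_t ffn_channels_post_gelu = hidden_dim;  // 14336: swiglu halves the channels
    printf("hidden_dim=%zu ffn_channels=%zu ffn_channels_post_gelu=%zu\n",
           hidden_dim, ffn_channels, ffn_channels_post_gelu);
    return 0;
}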
@@ -270,9 +279,9 @@ void fill_in_activation_sizes(const ActivationTensors* data, TensorSpec (&tensor
     tensors[7] = TENSOR_SPEC(data->ln2, (recompute < 2) ? L * B * T * C : 0);
     tensors[8] = TENSOR_SPEC(data->ln2_mean, L * B * T);
     tensors[9] = TENSOR_SPEC(data->ln2_rstd, L * B * T);
-    tensors[10] = TENSOR_SPEC(data->fch, L * B * T * 4*C);
+    tensors[10] = TENSOR_SPEC(data->fch, L * B * T * ffn_channels);
     // if recompute >= 1 then we will recompute gelu_forward during backward and use this as scratch buffer
-    tensors[11] = TENSOR_SPEC(data->fch_gelu, (recompute < 1) ? L * B * T * 4*C : B * T * 4*C);
+    tensors[11] = TENSOR_SPEC(data->fch_gelu, (recompute < 1) ? L * B * T * ffn_channels_post_gelu : B * T * ffn_channels_post_gelu);
     tensors[12] = TENSOR_SPEC(data->residual3, L * B * T * C);
     tensors[13] = TENSOR_SPEC(data->lnf, B * T * C);
     tensors[14] = TENSOR_SPEC(data->lnf_mean, B * T);
@@ -621,6 +630,12 @@ void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) {
     const size_t n_kv_head = model->config.num_kv_heads;
     const size_t hd = C / n_head; // head dimension
     const size_t qkv_channels = (n_head + 2*n_kv_head) * hd; // Q, K, V channels
+    size_t hidden_dim = 4 * C;
+    hidden_dim = (2 * hidden_dim) / 3;
+    hidden_dim = model->config.ffn_dim_multiplier * hidden_dim;
+    hidden_dim = model->config.multiple_of * ((hidden_dim + model->config.multiple_of - 1) / model->config.multiple_of);
+    size_t ffn_channels = hidden_dim * 2; // c_fc + c_fc2 concatenated
+    size_t ffn_channels_post_gelu = hidden_dim; // swiglu halves the channels
 
     // validate B,T are not larger than the values used at initialisation
     // (smaller B,T are okay for inference only)
@@ -653,9 +668,9 @@ void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) {
         floatX* l_attprojw = params.attprojw + l * C * C;
         floatX* l_attprojb = params.attprojb + l * C;
         floatX* l_ln2w = params.ln2w + l * C;
-        floatX* l_fcw = params.fcw + l * 4*C * C;
-        floatX* l_fcb = params.fcb + l * 4*C;
-        floatX* l_fcprojw = params.fcprojw + l * C * 4*C;
+        floatX* l_fcw = params.fcw + l * ffn_channels * C;
+        floatX* l_fcb = params.fcb + l * ffn_channels;
+        floatX* l_fcprojw = params.fcprojw + l * C * ffn_channels_post_gelu;
         floatX* l_fcprojb = params.fcprojb + l * C;
 
         // get the pointers of the activations for this layer
@@ -665,10 +680,10 @@ void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) {
         floatX* l_residual2 = acts.residual2 + l * B * T * C;
         floatX* l_ln2 = (model->recompute < 2) ? acts.ln2 + l * B * T * C : acts.lnf;
         float* l_ln2_rstd = acts.ln2_rstd + l * B * T;
-        floatX* l_fch = acts.fch + l * B * T * 4*C;
+        floatX* l_fch = acts.fch + l * B * T * ffn_channels;
         // reuse the same activation buffer at each layer, as we'll re-compute the gelu during backward
         // very useful because we dramatically reduce VRAM usage, and may be able to fit larger batch size
-        floatX* l_fch_gelu = (model->recompute < 1) ? acts.fch_gelu + l * B * T * 4*C : acts.fch_gelu;
+        floatX* l_fch_gelu = (model->recompute < 1) ? acts.fch_gelu + l * B * T * ffn_channels_post_gelu : acts.fch_gelu;
         floatX* l_residual3 = acts.residual3 + l * B * T * C;
         floatX* scratch = (floatX*)acts.output; // used for non-cudnn attention, fcproj, attproj, etc.
         floatX* qkv_rep_scratch = (floatX*)acts.scratch_bt4c; // we can use the BT4C scratch for qkv replication
@@ -687,20 +702,23 @@ void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) {
         // 1) projection to QKV vectors (note k,v may be fewer heads than q)
         matmul_forward_cublaslt(scratch, l_ln1, l_qkvw, l_qkvb, B, T, C, qkv_channels, main_stream);
         // 2) replicate k,v so that all of q,k,v have the same number of heads. done for simplicity, for now
-        repkv_forward(qkv_rep_scratch, scratch, B, T, n_head, n_kv_head, hd);
+        repkv_forward(qkv_rep_scratch, scratch, B, T, n_head, n_kv_head, hd, main_stream);
         // 3) apply RoPE to q,k in place
-        rope_forward(qkv_rep_scratch, qkv_rep_scratch, model->freqs_cis, B, T, n_head, hd);
+        rope_forward(qkv_rep_scratch, qkv_rep_scratch, model->freqs_cis, B, T, n_head, hd, main_stream);
         // 4) attention: att <- softmax(qk^T)v
         attention_forward(l_atty, l_qkvr, l_att, qkv_rep_scratch, B, T, C, NH, main_stream);
         #endif
 
         matmul_forward_cublaslt(scratch, l_atty, l_attprojw, l_attprojb, B, T, C, C, main_stream);
         fused_residual_rmsnorm_forward5(l_residual2, l_ln2, l_ln2_rstd, residual, scratch, l_ln2w, B*T, C, main_stream);
+        matmul_forward_cublaslt(l_fch, l_ln2, l_fcw, l_fcb, B, T, C, ffn_channels, main_stream);
+        swiglu_forward(l_fch_gelu, l_fch, B, T, ffn_channels_post_gelu, main_stream);
+        matmul_forward_cublaslt(scratch, l_fch_gelu, l_fcprojw, l_fcprojb, B, T, ffn_channels_post_gelu, C, main_stream);
 
         // ------------------------------------------------------------------------
         // DEBUGGING: we only work until this point right now, so exit here
         // transfer the first 32 elements to CPU and print them
-        floatX* output = l_ln2;
+        floatX* output = scratch;
         floatX* cpu = (floatX*)mallocCheck(32 * sizeof(floatX));
         cudaCheck(cudaMemcpy(cpu, output, 32 * sizeof(floatX), cudaMemcpyDeviceToHost));
         for (int i = 0; i < 32; i++) {
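swiglu_forward (from llmc/swiglu.cuh) fuses the gating between the two matmuls added above. A hedged sketch of the elementwise operation it computes, assuming each row of l_fch holds the c_fc output in its first hidden_dim channels and the c_fc2 output in the second half (matching the Python reference in train_llama3.py below), and ignoring the repo's floatX/vectorization details:

// One thread per output element: out[row, col] = c_fc(x) * SiLU(c_fc2(x)).
// inp is (N, 2*hidden), out is (N, hidden), N = B*T.
__global__ void swiglu_forward_sketch(float* out, const float* inp, int N, int hidden) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= N * hidden) { return; }
    int row = idx / hidden;
    int col = idx % hidden;
    float x1 = inp[row * 2 * hidden + col];           // c_fc branch
    float x2 = inp[row * 2 * hidden + hidden + col];  // c_fc2 branch
    float silu_x2 = x2 / (1.0f + expf(-x2));          // SiLU(x2) = x2 * sigmoid(x2)
    out[idx] = x1 * silu_x2;
}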
@@ -716,8 +734,6 @@ void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) {
         exit(0);
         // ------------------------------------------------------------------------
 
-        matmul_forward_cublaslt(l_fch_gelu, l_ln2, l_fcw, l_fcb, B, T, C, 4*C, main_stream, l_fch, model->gelu_fusion);
-        matmul_forward_cublaslt(scratch, l_fch_gelu, l_fcprojw, l_fcprojb, B, T, 4*C, C, main_stream);
         // OK, fusion across blocks.
         if(l+1 != L) {
             floatX* l_ln1 = (model->recompute < 2) ? acts.ln1 + (l + 1) * B * T * C : acts.lnf;

train_llama3.py

Lines changed: 5 additions & 5 deletions
@@ -216,6 +216,11 @@ def __init__(self, config):
 
     def forward(self, x):
         # SwiGLU self.c_proj(F.silu(self.c_fc2(x)) * self.c_fc(x)) <-- 3. difference compared to GPT-2
+        x1 = self.c_fc(x)
+        x2 = self.c_fc2(x)
+        x2 = F.silu(x2)
+        x = x1 * x2
+        x = self.c_proj(x)
 
         # ---------------------------------------------------------------------
         # DEBUGGING: print first 32 elements of x
@@ -227,11 +232,6 @@ def forward(self, x):
         breakpoint()
         # ---------------------------------------------------------------------
 
-        x1 = self.c_fc(x)
-        x2 = self.c_fc2(x)
-        x2 = F.silu(x2)
-        x = x1 * x2
-        x = self.c_proj(x)
         return x
 
 class Block(nn.Module):
