
Commit 0f36bca

docs: update mla explanation in mla.md
1 parent 028ee99 commit 0f36bca

notes/mla.md

Lines changed: 31 additions & 10 deletions
@@ -271,6 +271,8 @@ content:
$46 = {512, 1, 1, 1}
```

And then we create a view into the kv_cmpr_pe tensor, which holds the RoPE
information:
```c++
// and {n_embd_head_qk_rope, 1, n_tokens}
ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_cmpr_pe, n_embd_head_qk_rope, 1, n_tokens,
```
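
As an aside on the mechanics: ggml_view_3d does not copy any data, it just
records new dimensions, the parent's strides, and a byte offset into the
parent's buffer. A minimal standalone sketch (the sizes and the [nope | rope]
layout here are illustrative, not values read from the model):
```c++
#include "ggml.h"

int main(void) {
    // small scratch context (the size is arbitrary for this sketch)
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16 * 1024 * 1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    // pretend layout along dim 0: [512 nope | 64 rope], one row per token
    const int64_t n_nope = 512, n_rope = 64, n_tokens = 4;
    struct ggml_tensor * kv = ggml_new_tensor_3d(ctx, GGML_TYPE_F32,
                                                 n_nope + n_rope, 1, n_tokens);

    // view of the trailing rope slice: same strides as the parent,
    // byte offset skips past the nope part of each row
    struct ggml_tensor * rope_part = ggml_view_3d(ctx, kv, n_rope, 1, n_tokens,
            kv->nb[1], kv->nb[2],
            n_nope * ggml_element_size(kv));
    (void) rope_part;

    ggml_free(ctx);
    return 0;
}
```
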
@@ -284,60 +286,79 @@ $46 = {512, 1, 1, 1}
$47 = {64, 1, 1, 1}
```

And then we apply RoPE to the query position embedding:
```c++
q_pe = ggml_rope_ext(ctx0, q_pe, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
    ext_factor, attn_factor, beta_fast, beta_slow);
cb(q_pe, "q_pe", il);
```
So this is where we apply the rotation, passing in the query position embedding
tensor and the input position tensor.
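
As a refresher on what ggml_rope_ext computes (this is standard RoPE; I'm
ignoring the YaRN context-extension parameters ext_factor, attn_factor,
beta_fast and beta_slow): each pair of dimensions $(x_{2i}, x_{2i+1})$ at
position $p$ is rotated by a frequency-dependent angle:

$$
\begin{pmatrix} x'_{2i} \\ x'_{2i+1} \end{pmatrix}
=
\begin{pmatrix} \cos(p\theta_i) & -\sin(p\theta_i) \\ \sin(p\theta_i) & \cos(p\theta_i) \end{pmatrix}
\begin{pmatrix} x_{2i} \\ x_{2i+1} \end{pmatrix},
\qquad
\theta_i = \texttt{freq\_base}^{-2i/\texttt{n\_rot}}
$$

Since the same rotation is applied to queries and keys, their dot product ends
up depending only on the relative position between the two tokens.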

And then we also apply RoPE for the key position embedding:
```c++
k_pe = ggml_rope_ext(ctx0, k_pe, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
    ext_factor, attn_factor, beta_fast, beta_slow);
cb(k_pe, "k_pe", il);
```

Next we apply RMS normalization to the compressed kv tensor:
```c++
kv_cmpr = build_norm(kv_cmpr, model.layers[il].attn_kv_a_norm, nullptr, LLM_NORM_RMS, il);
cb(kv_cmpr, "kv_cmpr", il);
```
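
For reference, LLM_NORM_RMS is RMSNorm: each vector is divided by its root
mean square and scaled elementwise by the learned attn_kv_a_norm weight $g$:

$$
\mathrm{RMSNorm}(x)_i = \frac{x_i}{\sqrt{\frac{1}{d}\sum_{j=1}^{d} x_j^2 + \varepsilon}}\, g_i
$$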

And then we get to the MLA part:
```c++
if (is_mla) {
    q_nope = ggml_permute(ctx0, q_nope, 0, 2, 1, 3);
    cb(q_nope, "q_nope_perm", il);
```

Following that we have the absorption part, where we multiply the query nope
tensor with wk_b, the per-head key decompression matrix, to absorb the key
expansion into the query:
```c++
    ggml_tensor * q_nope_absorbed = ggml_mul_mat(ctx0, model.layers[il].wk_b, q_nope);
    cb(q_nope_absorbed, "q_nope_absorbed", il);

    // {kv_lora_rank, n_head, n_tokens}
    q_nope_absorbed = ggml_permute(ctx0, q_nope_absorbed, 0, 2, 1, 3);
    cb(q_nope_absorbed, "q_nope_absorbed_perm", il);
```
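
To see why this absorption is valid: the nope part of each cached key is the
up-projection of the stored latent, $k_{\text{nope}} = W^{UK} c^{KV}$, and the
attention dot product can be regrouped so the up-projection lands on the query
instead:

$$
q_{\text{nope}}^{\top} k_{\text{nope}}
= q_{\text{nope}}^{\top} \left( W^{UK} c^{KV} \right)
= \left( (W^{UK})^{\top} q_{\text{nope}} \right)^{\top} c^{KV}
$$

So the scores can be computed directly against the compressed cache; as I read
it, wk_b stores the per-head $(W^{UK})^{\top}$.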

And then we concatenate the rope and nope parts of the query together:
```c++
    ggml_tensor * Qcur = ggml_concat(ctx0, q_pe, q_nope_absorbed, 0);
    cb(Qcur, "Qcur", il);
```
And then we do the same concatenation for the key:
```c++
    kv_cmpr = ggml_reshape_3d(ctx0, kv_cmpr, kv_lora_rank, 1, n_tokens);
    cb(kv_cmpr, "kv_cmpr_reshape", il);

    ggml_tensor * Kcur = ggml_concat(ctx0, k_pe, kv_cmpr, 0);
    cb(Kcur, "Kcur", il);
```
And notice that Vcur is just the compressed kv tensor:
```c++
    // {kv_lora_rank, 1, n_tokens}
    ggml_tensor * Vcur = kv_cmpr;
    cb(Vcur, "Vcur", il);
```
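
The value side uses the same trick: each value would be
$v_t = W^{UV} c_t^{KV}$, and the up-projection commutes with the
attention-weighted sum,

$$
\sum_t a_t v_t = \sum_t a_t \left( W^{UV} c_t^{KV} \right) = W^{UV} \sum_t a_t\, c_t^{KV}
$$

so V can stay in compressed form and the decompression is applied once to the
attended output, which is presumably why wv_b is passed into build_attn below.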

And if the model uses attention temperature scaling, that is applied here:
```c++
    if (inp_attn_scale) {
        // apply llama 4 temperature scaling
        Qcur = ggml_mul(ctx0, Qcur, inp_attn_scale);
        cb(Qcur, "Qcur_attn_temp_scaled", il);
    }
```
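
From what I understand of Llama 4's attention temperature tuning, this scale
grows logarithmically with the token position, something like (the symbol
names are my own shorthand, so treat the exact form as an assumption rather
than the actual implementation):

$$
s(p) = 1 + \beta \log\!\left(1 + \left\lfloor \frac{p}{n} \right\rfloor\right)
$$

where $p$ is the token position, $n$ a floor scale, and $\beta$ the tuning
scale; inp_attn_scale would then hold these values precomputed per position.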

And then we finally have the attention operation itself:
```c++
    // note: MLA with the absorption optimization converts into MQA (ie: GQA with 1 group)
    cur = build_attn(inp_attn,
            model.layers[il].wo, NULL,
            Qcur, Kcur, Vcur, nullptr, nullptr, model.layers[il].wv_b, kq_scale, il);
```
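
Note the shapes going into build_attn: Qcur is
{n_embd_head_qk_rope + kv_lora_rank, n_head, n_tokens}, while Kcur is
{n_embd_head_qk_rope + kv_lora_rank, 1, n_tokens} and Vcur is
{kv_lora_rank, 1, n_tokens}. All n_head query heads attend to the same single
k/v stream, which is exactly MQA, with wv_b handling the value decompression
inside build_attn.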
