Commit e8b93e0

apriel2 modeling bug
1 parent 691d8b2 commit e8b93e0

2 files changed: +121 −10 lines changed
fast_llm_external_models/apriel2/modeling_apriel2.py

Lines changed: 32 additions & 10 deletions
@@ -1323,8 +1323,17 @@ def _recurrent_gated_delta_rule(self, query, key, value, g, beta, state):
         """Single-step recurrent update for cached inference.

         Input shapes: [batch, seq=1, heads, dim]
-        Need shapes: [batch, heads, dim] for einsum operations
+        State shape: [batch, heads, key_dim, value_dim]
+
+        Implements the delta rule recurrence:
+        1. Decay state: S = S * exp(g)
+        2. Retrieve memory: mem = S @ k
+        3. Compute delta: delta = (v - mem) * beta
+        4. Update state: S = S + k ⊗ delta
+        5. Output: o = S @ q (scaled)
         """
+        input_dtype = query.dtype
+
         # Transpose from [batch, seq, heads, dim] to [batch, heads, seq, dim]
         query = query.transpose(1, 2)
         key = key.transpose(1, 2)
@@ -1334,25 +1343,38 @@ def _recurrent_gated_delta_rule(self, query, key, value, g, beta, state):
         query = _l2norm(query, dim=-1, eps=1e-6)
         key = _l2norm(key, dim=-1, eps=1e-6)

+        # Apply query scaling (matches chunked mode)
+        scale = 1.0 / (query.shape[-1] ** 0.5)
+        query = query * scale
+
         # Reshape for computation: [batch, heads, 1, dim] -> [batch, heads, dim]
         query = query.squeeze(2)
         key = key.squeeze(2)
         value = value.squeeze(2)
         g = g.squeeze(1)
         beta = beta.squeeze(1)

-        # Update state: S = exp(g) * S + beta * k^T @ v
-        # Keep everything in the same dtype as input (exp() returns float32, need to convert back)
-        input_dtype = query.dtype
+        # 1. Decay state: S = S * exp(g)
         decay = g.exp().to(input_dtype).unsqueeze(-1).unsqueeze(-1)  # [batch, heads, 1, 1]
-        k_outer_v = torch.einsum("bhk,bhv->bhkv", key * beta.unsqueeze(-1), value)
-        state = decay * state + k_outer_v
+        state = state * decay
+
+        # 2. Retrieve memory: mem = S @ k = (S * k.unsqueeze(-1)).sum(dim=-2)
+        # state: [batch, heads, key_dim, value_dim], key: [batch, heads, key_dim]
+        kv_mem = (state * key.unsqueeze(-1)).sum(dim=-2)  # [batch, heads, value_dim]
+
+        # 3. Compute delta: delta = (v - mem) * beta
+        delta = (value - kv_mem) * beta.unsqueeze(-1)  # [batch, heads, value_dim]
+
+        # 4. Update state: S = S + k ⊗ delta
+        # k.unsqueeze(-1): [batch, heads, key_dim, 1]
+        # delta.unsqueeze(-2): [batch, heads, 1, value_dim]
+        state = state + key.unsqueeze(-1) * delta.unsqueeze(-2)

-        # Output: o = q @ S
-        output = torch.einsum("bhk,bhkv->bhv", query, state)
-        output = output.unsqueeze(2)  # [batch, heads, 1, v_dim]
+        # 5. Output: o = S @ q = (S * q.unsqueeze(-1)).sum(dim=-2)
+        output = (state * query.unsqueeze(-1)).sum(dim=-2)  # [batch, heads, value_dim]
+        output = output.unsqueeze(2)  # [batch, heads, 1, value_dim]

-        # Transpose back to [batch, seq=1, heads, v_dim]
+        # Transpose back to [batch, seq=1, heads, value_dim]
         output = output.transpose(1, 2)

         # Ensure state matches output dtype

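For reference, the five-step recurrence documented in the new docstring can be reproduced outside the model with plain tensor ops. The sketch below is a minimal standalone version, not the module's code: the helper name `recurrent_delta_step` and the toy shapes are illustrative assumptions, and it folds in the 1/sqrt(key_dim) query scaling that this commit adds to match chunked mode.

```python
# Minimal sketch of the single-step gated delta rule (steps 1-5 in the docstring above).
# The helper name and shapes are illustrative assumptions, not part of modeling_apriel2.
import torch


def recurrent_delta_step(query, key, value, g, beta, state):
    """query/key: [b, h, k_dim]; value: [b, h, v_dim]; g, beta: [b, h];
    state: [b, h, k_dim, v_dim]. query and key are assumed already L2-normalized."""
    query = query * query.shape[-1] ** -0.5              # query scaling (matches chunked mode)
    state = state * g.exp()[..., None, None]              # 1. decay state: S = S * exp(g)
    mem = (state * key[..., None]).sum(dim=-2)            # 2. retrieve memory: mem = S @ k
    delta = (value - mem) * beta[..., None]               # 3. delta: (v - mem) * beta
    state = state + key[..., None] * delta[..., None, :]  # 4. rank-1 update: S = S + k ⊗ delta
    output = (state * query[..., None]).sum(dim=-2)       # 5. output: o = S @ q
    return output, state


if __name__ == "__main__":
    b, h, k_dim, v_dim = 2, 4, 8, 16
    q = torch.nn.functional.normalize(torch.randn(b, h, k_dim), dim=-1)
    k = torch.nn.functional.normalize(torch.randn(b, h, k_dim), dim=-1)
    v = torch.randn(b, h, v_dim)
    g, beta = -torch.rand(b, h), torch.rand(b, h)         # log-decay (exp(g) <= 1), write strength
    state = torch.zeros(b, h, k_dim, v_dim)
    out, state = recurrent_delta_step(q, k, v, g, beta, state)
    print(out.shape, state.shape)  # torch.Size([2, 4, 16]) torch.Size([2, 4, 8, 16])
```

Compared with the code being removed, the essential change is step 3: the update subtracts the retrieved memory S @ k from v before the rank-1 write (the delta correction), rather than adding beta * k ⊗ v directly, and the query is scaled by 1/sqrt(key_dim) as in the chunked path.
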
fast_llm_external_models/tests/test_apriel2/test_mixer_equivalence.py

Lines changed: 89 additions & 0 deletions
@@ -811,6 +811,95 @@ def test_vs_qwen3next(
             msg=f"Apriel2GatedDeltaNet vs Qwen3NextGatedDeltaNet (batch={batch_size}, seq={seq_len})",
         )

+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="GDN requires CUDA")
+    @pytest.mark.parametrize("seed", [42, 123, 456])
+    @pytest.mark.parametrize("prefill_len", [4, 8, 16])
+    def test_chunked_vs_recurrent(
+        self,
+        gdn_config,
+        seed,
+        prefill_len,
+    ):
+        """Verify GDN recurrent mode (decode) matches chunked mode (prefill).
+
+        This tests the inference path: after prefilling N tokens with chunked mode,
+        subsequent single-token decodes using recurrent mode should produce the same
+        output as if we had run the full sequence through chunked mode.
+        """
+        from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2GatedDeltaNet
+
+        value_heads, key_heads, key_head_dim, value_head_dim = gdn_config
+        hidden_size = 256
+        batch_size = 2
+        total_len = prefill_len + 4  # Prefill + 4 decode steps
+
+        config_dict = {
+            "type": "gdn",
+            "value_heads": value_heads,
+            "key_heads": key_heads,
+            "key_head_dim": key_head_dim,
+            "value_head_dim": value_head_dim,
+            "convolution_layer": {"kernel_size": 4},
+            "norm_eps": 1e-5,
+        }
+
+        # Create model
+        torch.manual_seed(seed)
+        model = Apriel2GatedDeltaNet(hidden_size, config_dict, layer_idx=0)
+        model = model.cuda()
+        model.eval()
+
+        # Create input sequence
+        torch.manual_seed(seed + 1)
+        full_hidden_states = torch.randn(batch_size, total_len, hidden_size, device="cuda")
+
+        # === Reference: Run full sequence through chunked mode ===
+        with torch.no_grad():
+            reference_output = model(full_hidden_states)[0]
+
+        # === Test: Prefill + decode ===
+        # Create a simple cache object to hold conv and recurrent states
+        class SimpleCache:
+            def __init__(self):
+                self.conv_states = {0: None}
+                self.recurrent_states = {0: None}
+
+        cache = SimpleCache()
+
+        # Prefill phase
+        prefill_input = full_hidden_states[:, :prefill_len, :]
+        with torch.no_grad():
+            prefill_output = model(
+                prefill_input,
+                past_key_values=cache,
+                cache_position=torch.arange(prefill_len, device="cuda"),
+            )[0]
+
+        # Decode phase - one token at a time
+        decode_outputs = []
+        for i in range(prefill_len, total_len):
+            decode_input = full_hidden_states[:, i : i + 1, :]
+            with torch.no_grad():
+                decode_output = model(
+                    decode_input,
+                    past_key_values=cache,
+                    cache_position=torch.tensor([i], device="cuda"),
+                )[0]
+            decode_outputs.append(decode_output)
+
+        # Concatenate prefill + decode outputs
+        test_output = torch.cat([prefill_output] + decode_outputs, dim=1)
+
+        # Use looser tolerance for chunked vs recurrent comparison
+        # (different processing order leads to numerical differences)
+        assert_close(
+            test_output,
+            reference_output,
+            rtol=1e-3,
+            atol=1e-3,
+            msg=f"GDN chunked vs recurrent mode (prefill={prefill_len}, total={total_len})",
+        )
+

 # =============================================================================
 # SECTION 2: EQUIVALENCE TESTS - KimiDeltaAttention

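For anyone exercising the same prefill-then-decode path outside the test, the pattern above condenses to the loop below. This is a sketch grounded in the test, not a public API: `SimpleCache` is the test's stand-in cache (a real setup would use the model's own cache class), and the helper name `generate_hidden_outputs` is hypothetical.

```python
# Condensed prefill + decode loop mirroring the test above.
# SimpleCache and the helper name are assumptions lifted from the test, not a public API.
import torch


class SimpleCache:
    def __init__(self):
        self.conv_states = {0: None}       # per-layer causal-conv state
        self.recurrent_states = {0: None}  # per-layer delta-rule state S


def generate_hidden_outputs(model, hidden_states, prefill_len):
    """Prefill the first `prefill_len` positions in chunked mode, then decode the rest
    one token at a time in recurrent mode, reusing the same cache throughout."""
    cache = SimpleCache()
    device = hidden_states.device
    outputs = []
    with torch.no_grad():
        # Prefill: chunked mode over the prompt.
        outputs.append(
            model(
                hidden_states[:, :prefill_len],
                past_key_values=cache,
                cache_position=torch.arange(prefill_len, device=device),
            )[0]
        )
        # Decode: recurrent mode, one token at a time.
        for i in range(prefill_len, hidden_states.shape[1]):
            outputs.append(
                model(
                    hidden_states[:, i : i + 1],
                    past_key_values=cache,
                    cache_position=torch.tensor([i], device=device),
                )[0]
            )
    return torch.cat(outputs, dim=1)
```

The test compares this concatenated output against a single chunked pass over the full sequence with rtol=atol=1e-3; the looser tolerance is deliberate, since the two modes accumulate floating-point error in a different order. On a CUDA machine the new case can be selected with `pytest fast_llm_external_models/tests/test_apriel2/test_mixer_equivalence.py -k test_chunked_vs_recurrent`; without CUDA it is skipped by the `skipif` marker.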