@@ -1323,8 +1323,17 @@ def _recurrent_gated_delta_rule(self, query, key, value, g, beta, state):
         """Single-step recurrent update for cached inference.
 
         Input shapes: [batch, seq=1, heads, dim]
-        Need shapes: [batch, heads, dim] for einsum operations
+        State shape: [batch, heads, key_dim, value_dim]
+
+        Implements the delta rule recurrence:
+            1. Decay state: S = S * exp(g)
+            2. Retrieve memory: mem = S @ k
+            3. Compute delta: delta = (v - mem) * beta
+            4. Update state: S = S + k ⊗ delta
+            5. Output: o = S @ q (scaled)
         """
+        input_dtype = query.dtype
+
         # Transpose from [batch, seq, heads, dim] to [batch, heads, seq, dim]
         query = query.transpose(1, 2)
         key = key.transpose(1, 2)
@@ -1334,25 +1343,38 @@ def _recurrent_gated_delta_rule(self, query, key, value, g, beta, state):
         query = _l2norm(query, dim=-1, eps=1e-6)
         key = _l2norm(key, dim=-1, eps=1e-6)
 
+        # Apply query scaling (matches chunked mode)
+        scale = 1.0 / (query.shape[-1] ** 0.5)
+        query = query * scale
+
         # Reshape for computation: [batch, heads, 1, dim] -> [batch, heads, dim]
         query = query.squeeze(2)
         key = key.squeeze(2)
         value = value.squeeze(2)
         g = g.squeeze(1)
         beta = beta.squeeze(1)
 
-        # Update state: S = exp(g) * S + beta * k^T @ v
-        # Keep everything in the same dtype as input (exp() returns float32, need to convert back)
-        input_dtype = query.dtype
+        # 1. Decay state: S = S * exp(g)
         decay = g.exp().to(input_dtype).unsqueeze(-1).unsqueeze(-1)  # [batch, heads, 1, 1]
-        k_outer_v = torch.einsum("bhk,bhv->bhkv", key * beta.unsqueeze(-1), value)
-        state = decay * state + k_outer_v
+        state = state * decay
+
+        # 2. Retrieve memory: mem = S @ k = (S * k.unsqueeze(-1)).sum(dim=-2)
+        # state: [batch, heads, key_dim, value_dim], key: [batch, heads, key_dim]
+        kv_mem = (state * key.unsqueeze(-1)).sum(dim=-2)  # [batch, heads, value_dim]
+
+        # 3. Compute delta: delta = (v - mem) * beta
+        delta = (value - kv_mem) * beta.unsqueeze(-1)  # [batch, heads, value_dim]
+
+        # 4. Update state: S = S + k ⊗ delta
+        # k.unsqueeze(-1): [batch, heads, key_dim, 1]
+        # delta.unsqueeze(-2): [batch, heads, 1, value_dim]
+        state = state + key.unsqueeze(-1) * delta.unsqueeze(-2)
 
-        # Output: o = q @ S
-        output = torch.einsum("bhk,bhkv->bhv", query, state)
-        output = output.unsqueeze(2)  # [batch, heads, 1, v_dim]
+        # 5. Output: o = S @ q = (S * q.unsqueeze(-1)).sum(dim=-2)
+        output = (state * query.unsqueeze(-1)).sum(dim=-2)  # [batch, heads, value_dim]
+        output = output.unsqueeze(2)  # [batch, heads, 1, value_dim]
 
-        # Transpose back to [batch, seq=1, heads, v_dim]
+        # Transpose back to [batch, seq=1, heads, value_dim]
         output = output.transpose(1, 2)
 
         # Ensure state matches output dtype
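
For reference, here is a minimal standalone sketch of the five-step recurrence this commit introduces, plus a check that the broadcast-and-sum reads match the einsum formulation being removed. The helper name gated_delta_step and the toy shapes are illustrative, not part of the diff; the sketch assumes q and k arrive L2-normalized with q pre-scaled, as in the function above.

import torch
import torch.nn.functional as F

def gated_delta_step(q, k, v, g, beta, state):
    # q, k: [batch, heads, key_dim]; v: [batch, heads, value_dim]
    # g, beta: [batch, heads]; state: [batch, heads, key_dim, value_dim]
    decay = g.exp().to(state.dtype).unsqueeze(-1).unsqueeze(-1)
    state = state * decay                                   # 1. decay state
    mem = (state * k.unsqueeze(-1)).sum(dim=-2)             # 2. retrieve memory
    delta = (v - mem) * beta.unsqueeze(-1)                  # 3. scaled prediction error
    state = state + k.unsqueeze(-1) * delta.unsqueeze(-2)   # 4. rank-1 state update
    out = (state * q.unsqueeze(-1)).sum(dim=-2)             # 5. read out
    return out, state

b, h, dk, dv = 2, 4, 16, 32
q = torch.randn(b, h, dk)
S = torch.randn(b, h, dk, dv)

# The broadcast-and-sum read equals the einsum this commit removes:
assert torch.allclose(
    (S * q.unsqueeze(-1)).sum(dim=-2),
    torch.einsum("bhk,bhkv->bhv", q, S),
    atol=1e-5,
)

# Sanity check of the delta rule itself: with beta = 1, no decay, and a
# unit-norm key, querying with that same (unscaled) key retrieves v exactly,
# since k @ (S + k ⊗ (v - k @ S)) = v when k·k = 1.
k = F.normalize(torch.randn(b, h, dk), dim=-1)
v = torch.randn(b, h, dv)
out, _ = gated_delta_step(k, k, v, torch.zeros(b, h), torch.ones(b, h), S)
assert torch.allclose(out, v, atol=1e-5)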