Skip to content

Commit 74441e3

Browse files
authored
Skip swizzling when read/gather row phases are inconsistent (#699)
Add validation to detect row phase mismatch between reads and gathers in gather_to_shared_swizzling and gracefully skip unsupported cases. Add F32_32x32x16_F16 MMA test for CDNA4 which fails if this validation is not present. Signed-off-by: harsh-nod <[email protected]>
1 parent 55e9491 commit 74441e3

File tree

4 files changed

+25
-21
lines changed

4 files changed

+25
-21
lines changed

lit_tests/kernel/wave/gather_to_shared.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -329,21 +329,11 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
329329
print(scaled_gemm.asm)
330330

331331
# CHECK-LABEL: test_gather_to_shared_scaled_dims
332-
# CHECK: #[[map1:.*]] = affine_map<()[s0] -> ((s0 floordiv 8) mod 8)>
333-
# CHECK: #[[map2:.*]] = affine_map<()[s0] -> (s0 mod 8)>
334-
# CHECK: #[[map6:.*]] = affine_map<()[s0] -> ((s0 mod 64) floordiv 16)>
335-
# CHECK: #[[map7:.*]] = affine_map<()[s0] -> ((s0 mod 64) floordiv 16 + 4)>
336332
# CHECK: func.func @scaled_gemm
337333
# CHECK: %[[thread_id_x:.*]] = gpu.thread_id x
338334
# CHECK-COUNT-1: memref.alloc()
339-
# Check some swizzling was done
340-
# CHECK: %[[col:.*]] = affine.apply #[[map1]]()[%[[thread_id_x]]]
341-
# CHECK: %[[row:.*]] = affine.apply #[[map2]]()[%[[thread_id_x]]]
342-
# CHECK: %{{.*}} = arith.xori %[[row]], %[[col]] : index
343-
# CHECK: %[[row_swizzled:.*]] = affine.apply #[[map6]]()[%[[thread_id_x]]]
344-
# CHECK: %[[row_swizzled_2:.*]] = affine.apply #[[map7]]()[%[[thread_id_x]]]
345-
# CHECK: %{{.*}} = arith.xori %[[row_swizzled]], %[[row]] : index
346-
# CHECK: %{{.*}} = arith.xori %[[row_swizzled_2]], %[[row]] : index
335+
# Note: Swizzling is disabled for this test due to row phase inconsistency
336+
# between reads and gathers.
347337
# CHECK: scf.for
348338
# CHECK: amdgpu.lds_barrier
349339
# CHECK-COUNT-4: amdgpu.gather_to_lds {{.*}}

tests/kernel/wave_gemm_mxfp_test.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -397,31 +397,29 @@ def testScaledBatchedGemmMXFP4Codegen(use_water_backend: bool, tmp_path: Path):
397397
# We encode the exact registers and wait counts as we want to know if
398398
# they suddenly change due to backend or upstream MLIR changes.
399399
if use_water_backend:
400-
vgpr_count = 148
400+
vgpr_count = 154
401401
vgpr_spill_count = 0
402402
sgpr_count = 58
403403
sgpr_spill_count = 0
404404
waitcounts = [
405405
"s_waitcnt lgkmcnt(0)",
406406
"s_waitcnt vmcnt(0)",
407-
"s_waitcnt lgkmcnt(10)",
408-
"s_waitcnt lgkmcnt(1)",
409-
"s_waitcnt lgkmcnt(0)",
407+
"s_waitcnt lgkmcnt(8)",
410408
"s_waitcnt lgkmcnt(1)",
411409
"s_waitcnt lgkmcnt(1)",
412410
"s_waitcnt vmcnt(0) lgkmcnt(0)",
413411
"s_waitcnt vmcnt(0)",
414-
"s_waitcnt lgkmcnt(8)",
412+
"s_waitcnt lgkmcnt(7)",
415413
"s_waitcnt lgkmcnt(6)",
416414
"s_waitcnt lgkmcnt(5)",
415+
"s_waitcnt lgkmcnt(4)",
417416
"s_waitcnt lgkmcnt(3)",
418-
"s_waitcnt lgkmcnt(1)",
419417
"s_waitcnt lgkmcnt(2)",
420418
"s_waitcnt lgkmcnt(1)",
421419
"s_waitcnt lgkmcnt(0)",
422420
]
423421
else:
424-
vgpr_count = 162
422+
vgpr_count = 160
425423
vgpr_spill_count = 0
426424
sgpr_count = 59
427425
sgpr_spill_count = 0
@@ -430,10 +428,10 @@ def testScaledBatchedGemmMXFP4Codegen(use_water_backend: bool, tmp_path: Path):
430428
"s_waitcnt vmcnt(0)",
431429
"s_waitcnt vmcnt(0) lgkmcnt(0)",
432430
"s_waitcnt vmcnt(0)",
433-
"s_waitcnt lgkmcnt(7)",
434-
"s_waitcnt lgkmcnt(6)",
435431
"s_waitcnt lgkmcnt(5)",
432+
"s_waitcnt lgkmcnt(4)",
436433
"s_waitcnt lgkmcnt(3)",
434+
"s_waitcnt lgkmcnt(2)",
437435
"s_waitcnt lgkmcnt(1)",
438436
"s_waitcnt lgkmcnt(0)",
439437
]

tests/kernel/wave_gemm_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -352,6 +352,7 @@ def testPureGemm(
352352
[
353353
pytest.param(MMAType.F32_16x16x16_F16, 64, marks=require_cdna_3_or_4),
354354
pytest.param(MMAType.F32_32x32x8_F16, 64, marks=require_cdna_3_or_4),
355+
pytest.param(MMAType.F32_32x32x16_F16, 64, marks=require_cdna4),
355356
pytest.param(MMAType.GFX1250_F32_16x16x32_F16, 32, marks=require_gfx1250),
356357
],
357358
)

wave_lang/kernel/wave/gather_to_shared.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -629,6 +629,21 @@ def gather_to_shared_swizzling(
629629

630630
max_phase = 8
631631

632+
# Check row phase inconsistency between reads and gathers.
633+
gather_local_index = remove_global_indexing(gather.src_index, constraints)
634+
read_local_index = remove_global_indexing(read.index, constraints)
635+
gather_row_expr = sympy.simplify(
636+
subs_idxc(gather_local_index[row_dim].start) % max_phase
637+
)
638+
read_row_expr = sympy.simplify(
639+
subs_idxc(read_local_index[row_dim].start) % max_phase
640+
)
641+
if gather_row_expr != read_row_expr:
642+
logger.info(
643+
f"row phase inconsistency between reads and gathers: {gather_row_expr} != {read_row_expr}. Skipping swizzling as it is not supported."
644+
)
645+
continue
646+
632647
for read in reads:
633648
index = remove_global_indexing(read.index, constraints)
634649
col_seq = index[col_dim]

0 commit comments

Comments
 (0)