Skip to content

Commit 3a1d920

Browse files
committed
Address comments
Signed-off-by: tyb0807 <[email protected]>
1 parent 9a6276f commit 3a1d920

File tree

1 file changed

+30
-68
lines changed

1 file changed

+30
-68
lines changed

lit_tests/kernel/wave/mlir_converter_e2e.py

Lines changed: 30 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -88,26 +88,24 @@ def matrix_add(
8888
trace = compiled_kernel.compiled_graph
8989
constraints = matrix_add.constraints
9090

91-
# Emit Wave dialect MLIR
91+
# Emit Wave dialect MLIR.
9292
wave_dialect_mlir, diagnostics, _ = emit_wave_dialect(
9393
trace, constraints, options_mlir
9494
)
9595

96-
# Apply Water PassManager lowering
96+
# Apply Water middle-end pipeline.
9797
lowered_mlir = apply_water_middle_end_passes(wave_dialect_mlir)
9898

9999
print(lowered_mlir)
100100

101-
# Create test tensors
102101
shape = (128, 128)
103102
a_tensor = device_randn(*shape, dtype=torch.float16)
104103
b_tensor = device_randn(*shape, dtype=torch.float16)
105104
c_tensor = device_zeros(*shape, dtype=torch.float16)
106105

107-
# Expected result (CPU computation)
108106
expected = a_tensor + b_tensor
109107

110-
# Test execution with lowered MLIR
108+
# Test execution with lowered MLIR.
111109
options_e2e = WaveCompileOptions(
112110
subs=subs,
113111
canonicalize=True,
@@ -139,105 +137,69 @@ def matrix_add(
139137
@run_test
140138
def test_matmul_water_e2e():
141139
"""Test Water PassManager with matmul kernel and e2e execution."""
142-
torch.manual_seed(0)
143-
144-
# Input sizes
145-
M = tkl.sym.M
146-
N = tkl.sym.N
147-
K = tkl.sym.K
148-
# Workgroup tile sizes
149-
BLOCK_M = tkl.sym.BLOCK_M
150-
BLOCK_N = tkl.sym.BLOCK_N
151-
BLOCK_K = tkl.sym.BLOCK_K
152-
# Address space (for GPU, shared(1) or global(0))
153-
ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE
154-
dtype = tkl.f16
155-
156-
# Define constraints for matmul
157-
constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
158-
constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]
159-
constraints += [tkw.TilingConstraint(K, BLOCK_K)]
160-
constraints += [tkw.WaveConstraint(M, sympy.floor(BLOCK_M / 2))]
161-
constraints += [tkw.WaveConstraint(N, sympy.floor(BLOCK_N / 2))]
162-
constraints += [
163-
tkw.HardwareConstraint(threads_per_wave=64, mma_type=MMAType.F32_32x32x8_F16)
164-
]
165-
166-
@tkw.wave(constraints)
167-
def matmul(
168-
a: tkl.Memory[M, K, ADDRESS_SPACE, dtype],
169-
b: tkl.Memory[N, K, ADDRESS_SPACE, dtype],
170-
c: tkl.Memory[M, N, GLOBAL_ADDRESS_SPACE, tkl.f32],
171-
):
172-
c_reg = tkl.Register[M, N, tkl.f32](0.0)
140+
from wave_lang.kernel.wave.templates.gemm import get_gemm_kernel
173141

174-
@tkw.iterate(K, init_args=[c_reg])
175-
def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
176-
a_reg = tkw.read(a)
177-
b_reg = tkw.read(b)
178-
acc = tkw.mma(a_reg, b_reg, acc)
179-
return acc
180-
181-
tkw.write(repeat, c)
142+
torch.manual_seed(0)
182143

144+
# Matrix dimensions.
183145
m = 1024
184146
n = 5120
185147
k = 640
186-
# Set parameters for compilation
187-
subs: dict[str | IndexSymbol, Any] = {
188-
ADDRESS_SPACE: SHARED_ADDRESS_SPACE,
189-
BLOCK_M: 64,
190-
BLOCK_N: 64,
191-
BLOCK_K: 32,
192-
M: m,
193-
N: n,
194-
K: k,
195-
}
148+
149+
# Get GEMM kernel from template.
150+
gemm, hyperparams, _ = get_gemm_kernel(
151+
shape=(m, n, k),
152+
dynamic_dims=False,
153+
mfma_variant=MMAType.F32_32x32x8_F16,
154+
block_shape=(64, 64, 32),
155+
waves_per_block=(2, 2),
156+
)
196157

197158
options_mlir = WaveCompileOptions(
198-
subs=subs,
159+
subs=hyperparams,
199160
compile_to_mlir=True,
200161
location_capture_config=LocationCaptureConfig(level=LocationCaptureLevel.NONE),
201162
enforce_locations=False,
202-
print_mlir=True,
203163
)
204164
options_mlir = set_default_run_config(options_mlir)
205165

206-
compiled_kernel = wave_compile(options_mlir, matmul)
166+
compiled_kernel = wave_compile(options_mlir, gemm)
207167
trace = compiled_kernel.compiled_graph
208-
constraints = matmul.constraints
168+
constraints = gemm.constraints
209169

210-
# Emit Wave dialect MLIR
211-
wave_dialect_mlir, diagnostics, _ = emit_wave_dialect(trace, constraints, options_mlir)
170+
# Emit Wave dialect MLIR.
171+
wave_dialect_mlir, diagnostics, _ = emit_wave_dialect(
172+
trace, constraints, options_mlir
173+
)
212174

213-
# Apply Water PassManager lowering
175+
# Apply Water middle-end pipeline.
214176
lowered_mlir = apply_water_middle_end_passes(wave_dialect_mlir)
215177

216178
print(lowered_mlir)
217179

218-
# Create test tensors
180+
# Create test tensors on device.
219181
a_tensor = device_randn(m, k, dtype=torch.float16)
220182
b_tensor = device_randn(n, k, dtype=torch.float16) # Note: transposed in matmul
221183
c_tensor = device_zeros(m, n, dtype=torch.float32)
222184

223-
# Expected result (CPU computation)
224-
expected = torch.matmul(a_tensor.float(), b_tensor.T.float())
185+
# Expected result using PyTorch reference.
186+
expected = torch.matmul(a_tensor, b_tensor.T).float()
225187

226-
# Test execution with lowered MLIR
188+
# Test execution with lowered MLIR.
227189
options_e2e = WaveCompileOptions(
228-
subs=subs,
190+
subs=hyperparams,
229191
canonicalize=True,
230192
location_capture_config=LocationCaptureConfig(level=LocationCaptureLevel.NONE),
231193
enforce_locations=False,
232194
override_mlir=lowered_mlir,
233195
)
234196
options_e2e = set_default_run_config(options_e2e)
235197

236-
compiled_e2e = wave_compile(options_e2e, matmul)
198+
compiled_e2e = wave_compile(options_e2e, gemm)
237199

238200
compiled_e2e(a_tensor, b_tensor, c_tensor)
239201

240-
assert_close(c_tensor, expected, rtol=1e-3, atol=1e-3)
202+
assert_close(c_tensor, expected)
241203

242204

243205
# CHECK-LABEL: test_matmul_water_e2e

0 commit comments

Comments (0)