@@ -88,26 +88,24 @@ def matrix_add(
     trace = compiled_kernel.compiled_graph
     constraints = matrix_add.constraints

-    # Emit Wave dialect MLIR
+    # Emit Wave dialect MLIR.
     wave_dialect_mlir, diagnostics, _ = emit_wave_dialect(
         trace, constraints, options_mlir
     )

-    # Apply Water PassManager lowering
+    # Apply Water middle-end pipeline.
     lowered_mlir = apply_water_middle_end_passes(wave_dialect_mlir)

     print(lowered_mlir)

-    # Create test tensors
     shape = (128, 128)
     a_tensor = device_randn(*shape, dtype=torch.float16)
     b_tensor = device_randn(*shape, dtype=torch.float16)
     c_tensor = device_zeros(*shape, dtype=torch.float16)

-    # Expected result (CPU computation)
     expected = a_tensor + b_tensor

-    # Test execution with lowered MLIR
+    # Test execution with lowered MLIR.
     options_e2e = WaveCompileOptions(
         subs=subs,
         canonicalize=True,
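Both hunks in this commit exercise the same round trip: compile the kernel to Wave dialect MLIR, lower it through the Water middle-end pipeline, then hand the lowered module back to wave_compile via override_mlir. A minimal sketch of that pattern, assuming the helpers already imported in this test file (emit_wave_dialect, apply_water_middle_end_passes, wave_compile, WaveCompileOptions, set_default_run_config) and a kernel plus subs dict defined as in the tests:

    # Sketch only: `kernel` and `subs` stand in for the per-test kernel
    # and hyperparameter map; everything else is from this file's imports.
    options = set_default_run_config(
        WaveCompileOptions(subs=subs, compile_to_mlir=True)
    )
    compiled = wave_compile(options, kernel)

    # Emit Wave dialect MLIR from the traced graph and constraints.
    wave_mlir, diagnostics, _ = emit_wave_dialect(
        compiled.compiled_graph, kernel.constraints, options
    )

    # Lower through the Water middle-end pipeline.
    lowered = apply_water_middle_end_passes(wave_mlir)

    # Recompile with the lowered module substituted for default codegen.
    options_e2e = set_default_run_config(
        WaveCompileOptions(subs=subs, canonicalize=True, override_mlir=lowered)
    )
    executable = wave_compile(options_e2e, kernel)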
@@ -139,105 +137,69 @@ def matrix_add(
 @run_test
 def test_matmul_water_e2e():
     """Test Water PassManager with matmul kernel and e2e execution."""
-    torch.manual_seed(0)
-
-    # Input sizes
-    M = tkl.sym.M
-    N = tkl.sym.N
-    K = tkl.sym.K
-    # Workgroup tile sizes
-    BLOCK_M = tkl.sym.BLOCK_M
-    BLOCK_N = tkl.sym.BLOCK_N
-    BLOCK_K = tkl.sym.BLOCK_K
-    # Address space (for GPU, shared(1) or global(0))
-    ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE
-    dtype = tkl.f16
-
-    # Define constraints for matmul
-    constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
-    constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]
-    constraints += [tkw.TilingConstraint(K, BLOCK_K)]
-    constraints += [tkw.WaveConstraint(M, sympy.floor(BLOCK_M / 2))]
-    constraints += [tkw.WaveConstraint(N, sympy.floor(BLOCK_N / 2))]
-    constraints += [
-        tkw.HardwareConstraint(threads_per_wave=64, mma_type=MMAType.F32_32x32x8_F16)
-    ]
-
-    @tkw.wave(constraints)
-    def matmul(
-        a: tkl.Memory[M, K, ADDRESS_SPACE, dtype],
-        b: tkl.Memory[N, K, ADDRESS_SPACE, dtype],
-        c: tkl.Memory[M, N, GLOBAL_ADDRESS_SPACE, tkl.f32],
-    ):
-        c_reg = tkl.Register[M, N, tkl.f32](0.0)
+    from wave_lang.kernel.wave.templates.gemm import get_gemm_kernel

-        @tkw.iterate(K, init_args=[c_reg])
-        def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
-            a_reg = tkw.read(a)
-            b_reg = tkw.read(b)
-            acc = tkw.mma(a_reg, b_reg, acc)
-            return acc
-
-        tkw.write(repeat, c)
+    torch.manual_seed(0)

+    # Matrix dimensions.
     m = 1024
     n = 5120
     k = 640
-    # Set parameters for compilation
-    subs: dict[str | IndexSymbol, Any] = {
-        ADDRESS_SPACE: SHARED_ADDRESS_SPACE,
-        BLOCK_M: 64,
-        BLOCK_N: 64,
-        BLOCK_K: 32,
-        M: m,
-        N: n,
-        K: k,
-    }
+
+    # Get GEMM kernel from template.
+    gemm, hyperparams, _ = get_gemm_kernel(
+        shape=(m, n, k),
+        dynamic_dims=False,
+        mfma_variant=MMAType.F32_32x32x8_F16,
+        block_shape=(64, 64, 32),
+        waves_per_block=(2, 2),
+    )

     options_mlir = WaveCompileOptions(
-        subs=subs,
+        subs=hyperparams,
         compile_to_mlir=True,
         location_capture_config=LocationCaptureConfig(level=LocationCaptureLevel.NONE),
         enforce_locations=False,
-        print_mlir=True,
     )
     options_mlir = set_default_run_config(options_mlir)

-    compiled_kernel = wave_compile(options_mlir, matmul)
+    compiled_kernel = wave_compile(options_mlir, gemm)
     trace = compiled_kernel.compiled_graph
-    constraints = matmul.constraints
+    constraints = gemm.constraints

-    # Emit Wave dialect MLIR
-    wave_dialect_mlir, diagnostics, _ = emit_wave_dialect(trace, constraints, options_mlir)
+    # Emit Wave dialect MLIR.
+    wave_dialect_mlir, diagnostics, _ = emit_wave_dialect(
+        trace, constraints, options_mlir
+    )

-    # Apply Water PassManager lowering
+    # Apply Water middle-end pipeline.
     lowered_mlir = apply_water_middle_end_passes(wave_dialect_mlir)

     print(lowered_mlir)

-    # Create test tensors
+    # Create test tensors on device.
     a_tensor = device_randn(m, k, dtype=torch.float16)
     b_tensor = device_randn(n, k, dtype=torch.float16)  # Note: transposed in matmul
     c_tensor = device_zeros(m, n, dtype=torch.float32)

-    # Expected result (CPU computation)
-    expected = torch.matmul(a_tensor.float(), b_tensor.T.float())
+    # Expected result using PyTorch reference.
+    expected = torch.matmul(a_tensor, b_tensor.T).float()

-    # Test execution with lowered MLIR
+    # Test execution with lowered MLIR.
     options_e2e = WaveCompileOptions(
-        subs=subs,
+        subs=hyperparams,
         canonicalize=True,
         location_capture_config=LocationCaptureConfig(level=LocationCaptureLevel.NONE),
         enforce_locations=False,
         override_mlir=lowered_mlir,
     )
     options_e2e = set_default_run_config(options_e2e)

-    compiled_e2e = wave_compile(options_e2e, matmul)
+    compiled_e2e = wave_compile(options_e2e, gemm)

     compiled_e2e(a_tensor, b_tensor, c_tensor)

-    assert_close(c_tensor, expected, rtol=1e-3, atol=1e-3)
+    assert_close(c_tensor, expected)


 # CHECK-LABEL: test_matmul_water_e2e
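The substantive change in the second test is swapping the hand-written matmul kernel and constraint list for the shared GEMM template, which returns the kernel together with a matching hyperparameter map. A sketch of that piece in isolation, using only names that appear in the diff above (the third return value is unused here):

    from wave_lang.kernel.wave.templates.gemm import get_gemm_kernel

    # Template returns (kernel, hyperparameter subs, extra); the shape and
    # tiling values mirror those used in the test.
    gemm, hyperparams, _ = get_gemm_kernel(
        shape=(1024, 5120, 640),
        dynamic_dims=False,
        mfma_variant=MMAType.F32_32x32x8_F16,
        block_shape=(64, 64, 32),
        waves_per_block=(2, 2),
    )
    # `hyperparams` then replaces the hand-built `subs` dict in both
    # WaveCompileOptions instances.

Note the PyTorch reference: the template's kernel reads b with shape (N, K), so the reference transposes it, and the f16 product is cast to f32 to match the kernel's f32 accumulator output.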