Skip to content

Commit 63918c8

Browse files
committed
[water] Add ordered_syms attribute for memref read/write lowering
Add the `ordered_syms` attribute to the `wave.read` and `wave.write` operations to preserve dimension-ordering information when the memory operand is a `MemRefType`. This is needed because the index dictionary keys are dimension names, and their ordering must be known for correct lowering.

The ResolveDistributedAllocations pass now sets this attribute on read/write ops that use resolved allocate results, extracting the ordered symbols from the original `WaveTensorType` before it becomes a `MemRefType`. The LowerReadWriteOps code now handles both `WaveTensorType` (using `getShape()`) and `MemRefType` (using the `ordered_syms` attribute) memory operands.

Signed-off-by: tyb0807 <[email protected]>
1 parent a909927 commit 63918c8

File tree

4 files changed

+81
-8
lines changed

4 files changed

+81
-8
lines changed

water/include/water/Dialect/Wave/IR/WaveOps.td

Lines changed: 16 additions & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -310,14 +310,21 @@ def ReadOp : WaveOp<"read", [
310310
let description = [{
311311
Moves data from a memory-resident tensor to a register-resident tensor
312312
preserving the shape.
313+
314+
When the memory operand is a MemRefType (after ResolveDistributedAllocations),
315+
the `ordered_syms` attribute provides the ordered dimension symbols that
316+
correspond to the memref dimensions. This is needed because the index
317+
dictionary keys are dimension names, and we need to know their ordering.
313318
}];
314319

315320
let arguments = !con((ins
316321
Arg<WaveMemoryType, "Memory to read from">:$memory,
317322
Arg<OptionalAttr<I64Attr>,
318323
"Number of elements processed by each thread">:$elements_per_thread,
319324
Arg<OptionalAttr<WaveReadWriteBoundsAttr>,
320-
"Bound expressions for each symbolic dimension">:$bounds
325+
"Bound expressions for each symbolic dimension">:$bounds,
326+
Arg<OptionalAttr<ArrayAttr>,
327+
"Ordered dimension symbols for memref operands">:$ordered_syms
321328
), commonArguments);
322329

323330
let results = (outs
@@ -364,6 +371,11 @@ def WriteOp : WaveOp<"write", [
364371
let description = [{
365372
Moves data from a register-resident tensor into a memory-resident tensor
366373
preserving the shape.
374+
375+
When the memory operand is a MemRefType (after ResolveDistributedAllocations),
376+
the `ordered_syms` attribute provides the ordered dimension symbols that
377+
correspond to the memref dimensions. This is needed because the index
378+
dictionary keys are dimension names, and we need to know their ordering.
367379
}];
368380

369381
let arguments = !con((ins
@@ -372,7 +384,9 @@ def WriteOp : WaveOp<"write", [
372384
Arg<OptionalAttr<I64Attr>,
373385
"Number of elements processed by each thread">:$elements_per_thread,
374386
Arg<OptionalAttr<WaveReadWriteBoundsAttr>,
375-
"Bound expressions for each symbolic dimension">:$bounds
387+
"Bound expressions for each symbolic dimension">:$bounds,
388+
Arg<OptionalAttr<ArrayAttr>,
389+
"Ordered dimension symbols for memref operands">:$ordered_syms
376390
), commonArguments);
377391

378392
let assemblyFormat =

water/lib/Dialect/Wave/Transforms/LowerReadWriteOps.cpp

Lines changed: 20 additions & 6 deletions
Original file line number · Diff line number · Diff line change
@@ -432,14 +432,28 @@ static FailureOr<MemAccessInfo>
432432
createMemoryIndicesAndMask(ConversionPatternRewriter &rewriter,
433433
const TypeConverter *typeConverter, OpTy op,
434434
Type memoryTypeArg, VectorType vectorType) {
435-
auto memoryType = dyn_cast<wave::WaveTensorType>(memoryTypeArg);
436-
if (!memoryType)
437-
return rewriter.notifyMatchFailure(
438-
op, "lowering with MemRefType memory not yet implemented");
439-
440435
int64_t elementsPerThread = vectorType.getNumElements();
441436

442-
ArrayRef<wave::WaveSymbolAttr> orderedSyms = memoryType.getShape();
437+
// Get ordered symbols from either WaveTensorType or ordered_syms attribute.
438+
SmallVector<wave::WaveSymbolAttr> orderedSymsStorage;
439+
ArrayRef<wave::WaveSymbolAttr> orderedSyms;
440+
441+
if (auto memoryType = dyn_cast<wave::WaveTensorType>(memoryTypeArg)) {
442+
orderedSyms = memoryType.getShape();
443+
} else if (isa<MemRefType>(memoryTypeArg)) {
444+
// For MemRefType, get ordered symbols from the ordered_syms attribute.
445+
ArrayAttr orderedSymsAttr = op.getOrderedSymsAttr();
446+
if (!orderedSymsAttr)
447+
return rewriter.notifyMatchFailure(
448+
op, "MemRefType memory requires ordered_syms attribute");
449+
orderedSymsStorage.reserve(orderedSymsAttr.size());
450+
for (Attribute attr : orderedSymsAttr)
451+
orderedSymsStorage.push_back(cast<wave::WaveSymbolAttr>(attr));
452+
orderedSyms = orderedSymsStorage;
453+
} else {
454+
return rewriter.notifyMatchFailure(
455+
op, "unsupported memory type for lowering");
456+
}
443457

444458
wave::WaveReadWriteBoundsAttr boundsDict = op.getBoundsAttr();
445459
wave::WaveHyperparameterAttr hyper =

water/lib/Dialect/Wave/Transforms/ResolveDistributedAllocations.cpp

Lines changed: 17 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -55,6 +55,23 @@ struct ResolveDistributedAllocations
5555
return;
5656
}
5757

58+
// Extract ordered dimension symbols from the original tensor type.
59+
ArrayRef<WaveSymbolAttr> shapeSyms = tensorType.getShape();
60+
SmallVector<Attribute> orderedSyms(shapeSyms.begin(), shapeSyms.end());
61+
ArrayAttr orderedSymsAttr =
62+
ArrayAttr::get(allocateOp.getContext(), orderedSyms);
63+
64+
// Update read/write ops that use this allocate result to include
65+
// ordered_syms, since they will need it for lowering after the memory
66+
// operand becomes a MemRefType.
67+
for (Operation *user : allocateOp.getResult().getUsers()) {
68+
if (auto readOp = dyn_cast<ReadOp>(user))
69+
readOp.setOrderedSymsAttr(orderedSymsAttr);
70+
else if (auto writeOp = dyn_cast<WriteOp>(user))
71+
writeOp.setOrderedSymsAttr(orderedSymsAttr);
72+
}
73+
74+
// Update the result type in place.
5875
allocateOp.getResult().setType(memrefType);
5976
});
6077
return result;

water/test/Dialect/Wave/lower-wave-to-mlir.mlir

Lines changed: 28 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -867,3 +867,31 @@ module attributes {wave.normal_form = #wave.normal_form<full_types,memory_only_t
867867
return
868868
}
869869
}
870+
871+
// -----
872+
873+
// Test read/write lowering with MemRefType memory operand.
874+
// This simulates the state after ResolveDistributedAllocations pass has run.
875+
module attributes {wave.normal_form = #wave.normal_form<full_types,memory_only_types,resolved_allocations>} {
876+
// CHECK-LABEL: @lower_read_write_memref
877+
func.func @lower_read_write_memref(%mem: memref<64x64xf16, #gpu.address_space<workgroup>>)
878+
attributes {wave.hyperparameters = #wave.hyperparameters<{BLOCK_M = 64, BLOCK_N = 64}>} {
879+
// Test ReadOp lowering when memory operand is already a MemRefType
880+
// and ordered_syms attribute provides the dimension ordering.
881+
// CHECK: %[[READ:.*]] = vector.load %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x64xf16, #gpu.address_space<workgroup>>, vector<8xf16>
882+
%0 = wave.read %mem index [{
883+
BLOCK_M : [#wave.index_symbol<T0>, #wave.symbol<"BLOCK_M">] -> (T0 mod 64, 1, 64),
884+
BLOCK_N : [#wave.index_symbol<T1>, #wave.symbol<"BLOCK_N">] -> (T1 * 8, 8, 1)
885+
}] { ordered_syms = [#wave.symbol<"BLOCK_M">, #wave.symbol<"BLOCK_N">] }
886+
: (memref<64x64xf16, #gpu.address_space<workgroup>>) -> vector<8xf16>
887+
888+
// Test WriteOp lowering when memory operand is already a MemRefType.
889+
// CHECK: vector.store %[[READ]], %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x64xf16, #gpu.address_space<workgroup>>, vector<8xf16>
890+
wave.write %0, %mem index [{
891+
BLOCK_M : [#wave.index_symbol<T0>, #wave.symbol<"BLOCK_M">] -> (T0 mod 64, 1, 64),
892+
BLOCK_N : [#wave.index_symbol<T1>, #wave.symbol<"BLOCK_N">] -> (T1 * 8, 8, 1)
893+
}] { ordered_syms = [#wave.symbol<"BLOCK_M">, #wave.symbol<"BLOCK_N">] }
894+
: vector<8xf16>, memref<64x64xf16, #gpu.address_space<workgroup>>
895+
return
896+
}
897+
}

0 commit comments

Comments (0)