Skip to content

Commit 4cb62c3

Browse files
committed
[water] Support read/write lowering with MemRefType memory operands
After the ResolveDistributedAllocations pass converts WaveTensorType to MemRefType, read/write ops must still determine a dimension ordering for correct lowering. Because IndexExprsSpecified is a precondition for LowerWaveToMLIR, read/write ops are guaranteed to carry index expressions at this point. Since DictionaryAttr stores its entries as an ArrayRef<NamedAttribute>, the index dictionary keys have a deterministic order and can be used directly as the dimension ordering. Signed-off-by: tyb0807 <[email protected]>
1 parent 74441e3 commit 4cb62c3

File tree

3 files changed

+36
-7
lines changed

3 files changed

+36
-7
lines changed

water/lib/Dialect/Wave/Transforms/LowerReadWriteOps.cpp

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -432,15 +432,8 @@ static FailureOr<MemAccessInfo>
432432
createMemoryIndicesAndMask(ConversionPatternRewriter &rewriter,
433433
const TypeConverter *typeConverter, OpTy op,
434434
Type memoryTypeArg, VectorType vectorType) {
435-
auto memoryType = dyn_cast<wave::WaveTensorType>(memoryTypeArg);
436-
if (!memoryType)
437-
return rewriter.notifyMatchFailure(
438-
op, "lowering with MemRefType memory not yet implemented");
439-
440435
int64_t elementsPerThread = vectorType.getNumElements();
441436

442-
ArrayRef<wave::WaveSymbolAttr> orderedSyms = memoryType.getShape();
443-
444437
wave::WaveReadWriteBoundsAttr boundsDict = op.getBoundsAttr();
445438
wave::WaveHyperparameterAttr hyper =
446439
static_cast<const wave::WaveTypeConverter &>(*typeConverter)
@@ -459,6 +452,15 @@ createMemoryIndicesAndMask(ConversionPatternRewriter &rewriter,
459452
assert(llvm::hasSingleElement(indexArr.getValue()) &&
460453
"'index' must be an array with exactly one dictionary");
461454
DictionaryAttr indexDict = cast<DictionaryAttr>(indexArr[0]);
455+
456+
// Get ordered symbols from the index dictionary keys.
457+
// DictionaryAttr stores its entries as an ArrayRef<NamedAttribute>, so the
// keys have a deterministic order we can rely on for dimension ordering.
458+
SmallVector<wave::WaveSymbolAttr> orderedSymsStorage;
459+
orderedSymsStorage.reserve(indexDict.size());
460+
for (NamedAttribute namedAttr : indexDict)
461+
orderedSymsStorage.push_back(wave::WaveSymbolAttr::get(
462+
op.getContext(), namedAttr.getName().strref()));
463+
ArrayRef<wave::WaveSymbolAttr> orderedSyms = orderedSymsStorage;
462464
std::optional<int64_t> vectorizedDim =
463465
wave::getPositionOfVectorizedDim(orderedSyms, indexDict, hyper);
464466

water/lib/Dialect/Wave/Transforms/ResolveDistributedAllocations.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ struct ResolveDistributedAllocations
5555
return;
5656
}
5757

58+
// Update the result type in place.
5859
allocateOp.getResult().setType(memrefType);
5960
});
6061
return result;

water/test/Dialect/Wave/lower-wave-to-mlir.mlir

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -867,3 +867,29 @@ module attributes {wave.normal_form = #wave.normal_form<full_types,memory_only_t
867867
return
868868
}
869869
}
870+
871+
// -----
872+
873+
// Test read/write lowering with MemRefType memory operand.
874+
// This simulates the state after ResolveDistributedAllocations pass has run.
875+
// Dimension ordering is derived from the index dictionary keys
// (DictionaryAttr keys have a deterministic order).
876+
module attributes {wave.normal_form = #wave.normal_form<full_types,index_exprs,memory_only_types,resolved_allocations>} {
877+
// CHECK-LABEL: @lower_read_write_memref
878+
func.func @lower_read_write_memref(%mem: memref<64x64xf16, #gpu.address_space<workgroup>>)
879+
attributes {wave.hyperparameters = #wave.hyperparameters<{BLOCK_M = 64, BLOCK_N = 64}>} {
880+
// CHECK: %[[READ:.*]] = vector.load %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x64xf16, #gpu.address_space<workgroup>>, vector<8xf16>
881+
%0 = wave.read %mem index [{
882+
BLOCK_M : [#wave.index_symbol<T0>, #wave.symbol<"BLOCK_M">] -> (T0 mod 64, 1, 64),
883+
BLOCK_N : [#wave.index_symbol<T1>, #wave.symbol<"BLOCK_N">] -> (T1 * 8, 8, 1)
884+
}]
885+
: (memref<64x64xf16, #gpu.address_space<workgroup>>) -> vector<8xf16>
886+
887+
// CHECK: vector.store %[[READ]], %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x64xf16, #gpu.address_space<workgroup>>, vector<8xf16>
888+
wave.write %0, %mem index [{
889+
BLOCK_M : [#wave.index_symbol<T0>, #wave.symbol<"BLOCK_M">] -> (T0 mod 64, 1, 64),
890+
BLOCK_N : [#wave.index_symbol<T1>, #wave.symbol<"BLOCK_N">] -> (T1 * 8, 8, 1)
891+
}]
892+
: vector<8xf16>, memref<64x64xf16, #gpu.address_space<workgroup>>
893+
return
894+
}
895+
}

0 commit comments

Comments
 (0)