[water] Use distributed shape when accessing allocated LDS #667
Conversation
When accessing shared memory allocated by wave.allocate, the read/write ops are supposed to operate on the distributed shape and not the logical shape. This PR implements the logic to handle this when lowering read/write ops to vector.load/store, transforming logical indices to distributed indices. Fixes #659.

Signed-off-by: tyb0807 <[email protected]>
ftynse left a comment:
Have you thought of alternatives that don't require jumping through hoops in lowering?
mlir::ArrayAttr indexAttr = getIndexAttr();
if (indexAttr && !indexAttr.empty())
  return false; // Regular allocations have index expressions.

wave::WaveExprListAttr distributedShape = getDistributedShape();

// Empty symbol list in distributed_shape.
if (!distributedShape.getSymbols().empty())
  return false;

// Distributed shape rank must be 1 (flattened).
if (distributedShape.getMap().getNumResults() != 1)
  return false;

// Distributed shape must be constant.
if (!distributedShape.getMap().isConstant())
  return false;

return true;
Don't put more than a couple of LoC as inlined code in a .td.
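As an illustration of this suggestion, the predicate body could move into the op's C++ implementation file, with the .td only declaring it via extraClassDeclaration. A minimal sketch, assuming the op class is named AllocateOp and the method name isParentAllocation is freely chosen (both are assumptions, not the actual water code):

// Hypothetical AllocateOp implementation excerpt: the predicate from the .td,
// moved into C++ so TableGen only carries the declaration.
bool AllocateOp::isParentAllocation() {
  // Regular allocations have index expressions.
  mlir::ArrayAttr indexAttr = getIndexAttr();
  if (indexAttr && !indexAttr.empty())
    return false;

  // Parent allocations carry a symbol-free, rank-1 (flattened), constant
  // distributed shape.
  wave::WaveExprListAttr distributedShape = getDistributedShape();
  return distributedShape.getSymbols().empty() &&
         distributedShape.getMap().getNumResults() == 1 &&
         distributedShape.getMap().isConstant();
}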
if (!llvm::hasSingleElement(indexAttr.getValue())) {
  return op->emitError() << "'index' attribute must contain exactly one "
                            "dictionary for this op, got "
                         << indexAttr.size();
}
I don't see a test for this error message.
auto indexDict = dyn_cast<DictionaryAttr>(indexAttr[0]);
if (!indexDict)
  return success(); // Empty dictionary is valid
This logic contradicts the message above that claims "must contain exactly one" whereas this clearly accepts there being zero dictionaries.
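One way to make the check consistent with the diagnostic, sketched under the assumption that the intent is to require the single entry to be a dictionary (relaxing the error message instead would be the other option):

// Keep the "exactly one dictionary" contract: a single entry that is not a
// DictionaryAttr is rejected instead of silently accepted.
auto indexDict = dyn_cast<DictionaryAttr>(indexAttr[0]);
if (!indexDict)
  return op->emitError()
         << "'index' attribute entry must be a dictionary, got "
         << indexAttr[0];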
if (!indexDict)
  return success(); // Empty dictionary is valid

// Check count matches
Systematically use full stops at the end of sentences.
return op->emitError() << "number of index expressions ("
                       << indexDict.size()
                       << ") must match logical shape rank ("
                       << tensorShape.size() << ")";
This is a low-value diagnostic that suppresses the higher-value diagnostic below. If there is a missing symbol, it will just say "size mismatch" without saying what is missing, leaving the user to figure out something that the code should have done. You can instead keep the diagnostic below and add a diagnostic for symbols that are present in the index expression but not used. Arguably, that could be a warning rather than an error, since nothing bad will happen due to the unnecessary symbol being present.
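A sketch of the more targeted diagnostics suggested here; logicalSymbols stands in for however the op exposes the symbol names of its logical shape, which is not visible in this hunk:

// Collect the symbols used in the index dictionary, report any logical
// dimension that is missing an expression, then report leftovers that do not
// correspond to any logical dimension.
llvm::SmallDenseSet<llvm::StringRef> unused;
for (mlir::NamedAttribute entry : indexDict)
  unused.insert(entry.getName().getValue());

for (llvm::StringRef sym : logicalSymbols) { // hypothetical accessor
  if (!indexDict.get(sym))
    return op->emitError() << "missing index expression for symbol '" << sym
                           << "'";
  unused.erase(sym);
}
if (!unused.empty())
  return op->emitError() << "index expression given for symbol '"
                         << *unused.begin()
                         << "' which is not part of the logical shape";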
// Check if the source is a MemRefType with shared memory.
if (auto sourceMemrefType =
        dyn_cast<MemRefType>(sourceValue.getType())) {
  if (auto addrSpace = sourceMemrefType.getMemorySpace()) {
    if (auto gpuAddrSpace =
            dyn_cast<gpu::AddressSpaceAttr>(addrSpace)) {
      if (gpuAddrSpace.getValue() == gpu::AddressSpace::Workgroup) {
        SmallVector<int64_t> shape(sourceMemrefType.getShape().begin(),
                                   sourceMemrefType.getShape().end());
        return shape;
      }
    }
  }
This long block is similar to the one above and could have been turned into a function/lambda.
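For instance, the shared-memory check could become a small helper built from exactly the calls shown above (the function name is made up here):

// Return the static shape of 'value' if it is a memref in GPU workgroup
// (shared) memory, std::nullopt otherwise.
static std::optional<SmallVector<int64_t>>
getWorkgroupMemRefShape(Value value) {
  auto memrefType = dyn_cast<MemRefType>(value.getType());
  if (!memrefType)
    return std::nullopt;
  auto gpuAddrSpace =
      dyn_cast_or_null<gpu::AddressSpaceAttr>(memrefType.getMemorySpace());
  if (!gpuAddrSpace || gpuAddrSpace.getValue() != gpu::AddressSpace::Workgroup)
    return std::nullopt;
  return SmallVector<int64_t>(memrefType.getShape().begin(),
                              memrefType.getShape().end());
}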
                    }
                  }
                }
              }
            }
          }
        }
      }
    }
  }
}
I counted 11 levels of indentation???????????
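With a helper along the lines of the getWorkgroupMemRefShape sketch above, the lowering could use early exits and stay at one or two indentation levels, roughly:

// Early exit instead of nesting one 'if' per condition.
std::optional<SmallVector<int64_t>> sharedShape =
    getWorkgroupMemRefShape(sourceValue);
if (!sharedShape)
  return failure(); // or whatever fallback the surrounding pattern uses
// ... use *sharedShape as the distributed shape for index transformation ...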
func.func @lower_alloc_view() attributes {wave.hyperparameters = #wave.hyperparameters<{BLOCK_M = 4, BLOCK_K = 28, M = 128, N=128, K= 128}>} {
  // CHECK: %[[BUFF:.*]] = memref.alloc() : memref<256xi8, #gpu.address_space<workgroup>>
  %parent = wave.allocate { distributed_shape = #wave.expr_list<[] -> (256)> }
  // CHECK: %[[BUFF:.*]] = memref.alloc() : memref<2097152xi8, #gpu.address_space<workgroup>>
I would much rather add a new test.
// Child allocations are exempt from rank constraints (this should be valid)
// distributed_shape has rank 1, logical shape has rank 2
// CHECK: %[[PARENT:.*]] = wave.allocate {distributed_shape = #wave.expr_list<[] -> (256)>}
// CHECK: wave.allocate in %[[PARENT]]
%buf = wave.allocate in %parent : !wave.tensor<[@M, @K] of i8, <shared>>
  { distributed_shape = #wave.expr_list<[#wave.symbol<"BLOCK_M">] -> (BLOCK_M)>, offset = 128}
  : !wave.tensor<[@M, @K] of bf16, <shared>>
This doesn't make sense to me, why is it okay to have a lower-rank distributed shape here?
if (auto sourceMemrefType =
        dyn_cast<MemRefType>(sourceValue.getType())) {
  if (auto addrSpace = sourceMemrefType.getMemorySpace()) {
    if (auto gpuAddrSpace =
            dyn_cast<gpu::AddressSpaceAttr>(addrSpace)) {
      if (gpuAddrSpace.getValue() ==
          gpu::AddressSpace::Workgroup) {
        memoryOperand = sourceValue;
This yet again looks identical to the code I've seen above.
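The same hypothetical helper would reduce this site to a single condition:

// Identical shared-memory check as above, expressed through the helper.
if (getWorkgroupMemRefShape(sourceValue))
  memoryOperand = sourceValue;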
ftynse left a comment.
Yes, and I came to the conclusion that doing this transformation during lowering might be the least bad option. Just to clarify the context/problem: when lowering read/write ops that access shared memory allocated by wave.allocate, the indices are expressed against the logical shape while the allocated buffer has the distributed shape. Considered alternatives are:
What other alternatives do you have in mind? Please shed some light on this matter!
Ok, the more I think about this, the more I feel the cleanest solution would be:
WDYT?
Perhaps. This is something that a commit/PR message can explain, so the reviewers and future contributors can understand the reasoning, not guess it.
I did not really think about it, just reviewing what I see. Given the conjunction of conditions in the code that are required for the lowering to work, it looks excessively fragile, which makes me worried about its scalability and longer-term health. Hence the question: it may well be the best available approach, or there may be others, I just don't know and don't have enough information to judge, so I'm requesting it. If I had to figure it out myself, I might as well write the code at that point. The specific part that appears problematic to me, though it was present in the code already, is having to "look through" the unrealized conversion cast. This feels like leaking the abstraction of the dialect conversion, which may be the wrong one here as we have the same wave tensor type converted to different upstream types based on usage.
Indeed, this looks like the least "internally complex" approach. It can be similar to, or even integrate with, the one we have for register-resident tensors. Note that it is possible to call
It may well be the same amount of work, but it is consistent with the approach taken for register-resident tensors that are converted to vectors, so there is less overall complexity in the system (one approach rather than two). I don't necessarily insist on one approach or another; I haven't thought too much about the problem, but the more complex the logic is, the stronger the arguments behind it should be.
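For context on the "looking through the cast" concern, here is a generic sketch of how a dialect-conversion pattern normally sidesteps it: the adaptor already carries the operand in its converted type, so if the type converter produces a memref with the distributed shape, nothing needs to peel off an unrealized_conversion_cast. The op name, adaptor accessor, and overall fit with water are assumptions, not the actual implementation:

struct ReadOpLowering : OpConversionPattern<wave::ReadOp> { // hypothetical op
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(wave::ReadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // adaptor.getMemory() is already type-converted; with a converter that
    // maps shared wave tensors to distributed-shape memrefs, its type is the
    // memref to load from.
    auto memrefType = dyn_cast<MemRefType>(adaptor.getMemory().getType());
    if (!memrefType)
      return rewriter.notifyMatchFailure(op, "expected converted memref");
    // ... transform logical indices to distributed ones and emit vector.load
    // using memrefType.getShape() ...
    return failure(); // rewrite elided in this sketch
  }
};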