
Conversation


@tyb0807 tyb0807 commented Jan 2, 2026

When accessing shared memory allocated by wave.allocate, the read/write ops are supposed to operate on the distributed shape and not the logical shape. This PR implements the logic to handle this when lowering read/write ops to vector.load/store, transforming logical indices to distributed indices.

Fixes #659.


@ftynse ftynse left a comment


Have you thought of alternatives that don't require jumping through hoops in lowering?

Comment on lines +282 to +300
mlir::ArrayAttr indexAttr = getIndexAttr();
if (indexAttr && !indexAttr.empty())
return false; // Regular allocations have index expressions.

wave::WaveExprListAttr distributedShape = getDistributedShape();

  // The distributed shape must have an empty symbol list.
if (!distributedShape.getSymbols().empty())
return false;

// Distributed shape rank must be 1 (flattened).
if (distributedShape.getMap().getNumResults() != 1)
return false;

// Distributed shape must be constant.
if (!distributedShape.getMap().isConstant())
return false;

return true;

Don't put more than a couple of lines of code inline in a .td.
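A minimal sketch of how this could be restructured, assuming the snippet is the body of a predicate on the allocate op (the class and method names below are hypothetical): keep only a declaration in the .td and move the body into the dialect's .cpp.

```cpp
// Sketch only: AllocateOp and hasFlattenedConstantDistributedShape are
// placeholder names; the body mirrors the inlined .td code quoted above.
bool AllocateOp::hasFlattenedConstantDistributedShape() {
  // Regular allocations carry index expressions.
  mlir::ArrayAttr indexAttr = getIndexAttr();
  if (indexAttr && !indexAttr.empty())
    return false;

  // Require an empty symbol list and a rank-1 (flattened), constant shape.
  wave::WaveExprListAttr distributedShape = getDistributedShape();
  return distributedShape.getSymbols().empty() &&
         distributedShape.getMap().getNumResults() == 1 &&
         distributedShape.getMap().isConstant();
}
```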

Comment on lines +100 to +104
if (!llvm::hasSingleElement(indexAttr.getValue())) {
return op->emitError() << "'index' attribute must contain exactly one "
"dictionary for this op, got "
<< indexAttr.size();
}

I don't see a test for this error message.


auto indexDict = dyn_cast<DictionaryAttr>(indexAttr[0]);
if (!indexDict)
return success(); // Empty dictionary is valid

This logic contradicts the message above, which claims the attribute "must contain exactly one" dictionary, whereas this code clearly accepts there being zero dictionaries.

if (!indexDict)
return success(); // Empty dictionary is valid

// Check count matches

Systematically use full stops at the end of the sentence.

Comment on lines +112 to +115
return op->emitError() << "number of index expressions ("
<< indexDict.size()
<< ") must match logical shape rank ("
<< tensorShape.size() << ")";

This is a low-value diagnostic that suppresses the higher-value diagnostic below. If there is a missing symbol, it will just say "size mismatch" without saying what is missing, leaving the user to figure out something the code should have done for them. You can instead keep the diagnostic below and add a diagnostic for symbols that are present in the index expression but not used. Arguably, this should be an error and not a warning, since nothing bad will happen due to the unnecessary symbol being present.
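A sketch of the suggested direction, assuming a hypothetical logicalSymbolNames container holding the symbol names of the logical shape (not code from this PR): keep the per-symbol check and report each unused symbol by name instead of a bare size comparison.

```cpp
// Sketch only: logicalSymbolNames stands in for however the logical shape's
// symbol names are obtained; report unused symbols explicitly.
for (mlir::NamedAttribute entry : indexDict) {
  if (!llvm::is_contained(logicalSymbolNames, entry.getName().strref()))
    return op->emitError() << "index expression refers to symbol '"
                           << entry.getName().strref()
                           << "' that is not used by the logical shape";
}
```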

Comment on lines +457 to +469
// Check if the source is a MemRefType with shared memory.
if (auto sourceMemrefType =
dyn_cast<MemRefType>(sourceValue.getType())) {
if (auto addrSpace = sourceMemrefType.getMemorySpace()) {
if (auto gpuAddrSpace =
dyn_cast<gpu::AddressSpaceAttr>(addrSpace)) {
if (gpuAddrSpace.getValue() == gpu::AddressSpace::Workgroup) {
SmallVector<int64_t> shape(sourceMemrefType.getShape().begin(),
sourceMemrefType.getShape().end());
return shape;
}
}
}

This long block is similar to the one above and could have been turned into a function/lambda.
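For illustration, one possible shape of such a helper (a sketch; the function name is made up, the logic mirrors the nested block above):

```cpp
// Sketch only: returns the static shape of a memref value if it lives in GPU
// workgroup (shared) memory, std::nullopt otherwise.
static std::optional<SmallVector<int64_t>> getWorkgroupMemrefShape(Value value) {
  auto memrefType = dyn_cast<MemRefType>(value.getType());
  if (!memrefType)
    return std::nullopt;
  auto addrSpace = llvm::dyn_cast_if_present<gpu::AddressSpaceAttr>(
      memrefType.getMemorySpace());
  if (!addrSpace || addrSpace.getValue() != gpu::AddressSpace::Workgroup)
    return std::nullopt;
  return SmallVector<int64_t>(memrefType.getShape().begin(),
                              memrefType.getShape().end());
}
```

The second occurrence further down (which only assigns memoryOperand) could reuse the same helper and branch on whether it returns a value.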

Comment on lines +570 to +580
}
}
}
}
}
}
}
}
}
}
}

I counted 11 levels of indentation???????????

func.func @lower_alloc_view() attributes {wave.hyperparameters = #wave.hyperparameters<{BLOCK_M = 4, BLOCK_K = 28, M = 128, N=128, K= 128}>} {
// CHECK: %[[BUFF:.*]] = memref.alloc() : memref<256xi8, #gpu.address_space<workgroup>>
%parent = wave.allocate { distributed_shape = #wave.expr_list<[] -> (256)> }
// CHECK: %[[BUFF:.*]] = memref.alloc() : memref<2097152xi8, #gpu.address_space<workgroup>>

I would much rather add a new test.

Comment on lines +313 to +319
// Child allocations are exempt from rank constraints (this should be valid)
// distributed_shape has rank 1, logical shape has rank 2
// CHECK: %[[PARENT:.*]] = wave.allocate {distributed_shape = #wave.expr_list<[] -> (256)>}
// CHECK: wave.allocate in %[[PARENT]]
%buf = wave.allocate in %parent : !wave.tensor<[@M, @K] of i8, <shared>>
{ distributed_shape = #wave.expr_list<[#wave.symbol<"BLOCK_M">] -> (BLOCK_M)>, offset = 128}
: !wave.tensor<[@M, @K] of bf16, <shared>>

This doesn't make sense to me; why is it okay to have a lower-rank distributed shape here?

Comment on lines +562 to +569
if (auto sourceMemrefType =
dyn_cast<MemRefType>(sourceValue.getType())) {
if (auto addrSpace = sourceMemrefType.getMemorySpace()) {
if (auto gpuAddrSpace =
dyn_cast<gpu::AddressSpaceAttr>(addrSpace)) {
if (gpuAddrSpace.getValue() ==
gpu::AddressSpace::Workgroup) {
memoryOperand = sourceValue;

This yet again looks identical to the code I've seen above.


@ftynse ftynse left a comment


.


tyb0807 commented Jan 5, 2026

> Have you thought of alternatives that don't require jumping through hoops in lowering?

Yes, and I came to the conclusion that doing this transformation during lowering might be the least bad option.

Just to clarify the context/problem: when lowering wave.read/wave.write ops that access LDS memory from wave.allocate, the allocated memref has a distributed shape that is different from the logical shape (WaveTensorType shape). The type converter creates an unrealized_conversion_cast to bridge this mismatch, but the lowered read/write ops need to use the distributed memref for correct memory access.

The alternatives I considered are:

  1. Create a pre-lowering pass that propagates the distributed_shape attribute from wave.allocate to the read/write ops that access the allocated LDS memory. This could use dataflow analysis or not, though in this case I think there's exactly one source (the wave.allocate) and no conflicts are possible, so dataflow analysis would be overkill. Anyway, the problem is not "how to propagate the attribute": even with the attribute, the lowering still needs to look through the cast to get the actual memref value. So the attribute becomes just a "signal" that the cast exists, which is not really useful, and we still need to "jump through hoops" to get that memref value in this approach (see the sketch after this list).

  2. Create a post-lowering pass to reconcile the casts, i.e. walk the unrealized_conversion_cast ops and replace uses with the source memrefs. This is also messy: we can't just call replaceAllUsesWith, since we need to recreate each vector op with the correct memref type. Furthermore, we would need to handle the different flavors of vector read/write ops (e.g. vector.load, vector.store, vector.transfer_read, vector.transfer_write, vector.maskedload, vector.maskedstore). I don't think there's a trait/interface that groups these (and only these) ops, right?

  3. Make wave.read/write ops accept a memref-typed memory operand. Then we still need to split the lowering into phases: convert wave.allocate to memref.alloc first, then have the main lowering pass convert the read/write ops, which now take memref memory operands. This doesn't sound like any less work, tbh.
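To make the "look through the cast" step above concrete, here is roughly what it amounts to (an illustrative helper, not code from this PR):

```cpp
// Sketch only: recover the original memref feeding the
// unrealized_conversion_cast that the type converter inserted.
static Value lookThroughConversionCast(Value value) {
  if (auto castOp = value.getDefiningOp<mlir::UnrealizedConversionCastOp>())
    if (castOp.getInputs().size() == 1)
      return castOp.getInputs().front();
  return value;
}
```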

What other alternatives do you have in mind? Please shed some light on this matter!

@tyb0807 tyb0807 closed this Jan 5, 2026
@tyb0807 tyb0807 reopened this Jan 5, 2026

tyb0807 commented Jan 5, 2026

Ok, the more I think about this, the more I feel the cleanest solution would be:

  • Create a pre-lowering pass ResolveDistributedAllocations that transforms wave.allocate to produce a memref-typed result.
  • Make wave.read/write ops accept a memref-typed memory operand, so the pass above creates valid IR.
  • Bonus: create a wave.view op that takes a memory value created by wave.allocate and returns a memref of its distributed shape. This can only handle child allocations though, so we'll still need wave.allocate to be able to produce a memref-typed result for normal allocations. Or we can have a no-op wave.view for normal allocations, where the distributed shape equals the logical shape.

WDYT?


ftynse commented Jan 6, 2026

> Yes, and I came to the conclusion that doing this transformation during lowering might be the least bad option.

Perhaps. This is something that a commit/PR message can explain, so the reviewers and future contributors can understand the reasoning, not guess it.

> What other alternatives do you have in mind? Please shed some light on this matter!

I did not really think about it; I'm just reviewing what I see. Given the conjunction of conditions in the code that are required for the lowering to work, it looks excessively fragile, which makes me worried about its scalability and longer-term health. Hence the question: it may well be the best available approach, or there may be others; I just don't know and don't have enough information to judge, so I'm requesting that information. If I had to figure it out myself, I might as well write the code at that point.

The specific part that appears problematic to me, though it was already present in the code, is having to "look through" the unrealized conversion cast. This feels like leaking the abstraction of the dialect conversion, which may be the wrong abstraction here, as we have the same wave tensor type converted to different upstream types depending on usage.

> Ok, the more I think about this, the more I feel the cleanest solution would be: [...]

Indeed, this looks like the least "internally complex" approach. It can be similar to, or even integrate with, the one we have for register-resident tensors. Note that it is possible to call setType on a value; the only time you need to re-create an operation is when its number of results changes. Note also that calling setType from within dialect conversion will likely cause deep problems with the infra, so it will have to be a separate pass.
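To illustrate the setType point, a minimal sketch under the assumption of a standalone pass (allocOp and distributedMemrefType are placeholders):

```cpp
// Sketch only: outside of dialect conversion, the result type can be updated
// in place; the op is not re-created because its number of results is
// unchanged.
allocOp.getResult().setType(distributedMemrefType);
```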

> This doesn't sound like any less work, tbh.

It may well be the same amount of work, but it is consistent with the approach taken for register-resident tensors that are converted to vectors, so there is less overall complexity in the system (one approach rather than two). I don't necessarily insist on one approach or another; I haven't thought too much about the problem. But the more complex the logic is, the stronger the arguments behind it should be.

@tyb0807 tyb0807 closed this Jan 8, 2026

tyb0807 commented Jan 8, 2026

Superseded by #677, #684 and #686.

@tyb0807 tyb0807 deleted the dist_shape branch January 8, 2026 06:45