Conversation

@tyb0807 (Contributor) commented Jan 7, 2026

After ResolveDistributedAllocations converts WaveTensorType to MemRefType,
read/write ops need to determine dimension ordering for correct lowering.

With IndexExprsSpecified as a precondition for LowerWaveToMLIR, read/write
ops are guaranteed to have index expressions. Since DictAttr is internally
an ArrayRef&lt;NamedAttribute&gt;, the index dictionary keys are ordered and can
be used directly for dimension ordering.

Fixes #659.
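The ordering argument above can be sketched in plain Python, whose dicts preserve insertion order the same way a DictAttr (internally an ArrayRef&lt;NamedAttribute&gt;) preserves key order. This is an illustrative sketch, not the actual water/IREE code; the function and symbol names are hypothetical.

```python
# Hypothetical sketch (not the actual water/IREE lowering code):
# once WaveTensorType has been converted to MemRefType, the ordered
# keys of the op's index dictionary can supply dimension ordering.

def dim_order_from_index(index_exprs: dict[str, str]) -> list[str]:
    """Return dimension names in the order the index dict stores them.

    Mirrors DictAttr, which is internally an ArrayRef<NamedAttribute>
    and therefore has a stable, well-defined key order.
    """
    return list(index_exprs.keys())

# A read op's index attribute, mapping dimension symbols to expressions
# (expressions abbreviated as strings for illustration):
index = {"M": "workgroup_id_0 * 64 + thread_id_0", "K": "induction_var"}
assert dim_order_from_index(index) == ["M", "K"]
```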

@ftynse (Contributor) commented Jan 8, 2026

> This is needed because the index dictionary keys are dimension names, and we need to know their ordering for correct lowering.

DictAttr is an ArrayRef&lt;NamedAttribute&gt; internally, so it is always ordered (otherwise we would have had massive flakiness in all tests without custom syntax for attributes). So do we actually need this? Maybe we rather need a verifier that checks the index attribute is present on memory operations that use memref types?

@tyb0807 (Contributor, Author) commented Jan 8, 2026

Awesome. I hate having to introduce a new attribute solely to communicate this information. I think at this point we can have a normal form requiring all memory ops to have an index attribute before reaching the lowering pass, instead of a rigid verification only for memref-typed memory ops. WDYT?

@ftynse (Contributor) commented Jan 8, 2026

Normal forms are not free, neither cognitively nor in compile time. Do we have a situation where we want operations with memref types and without index expressions? If not, I'd rather not add complexity to the system to support that and have it either in the main verifier or in the resolved-distribution form.

Separately, I think the index expressions we compute follow the order of dimensions in shapes, but this is purely an artifact of the mechanics: we just happen to follow the tensor shape in the loop computing the index expression. For index expressions converted from Python, this may not be guaranteed. So we may need to check/reorder them here, and potentially consider verifying this at some point.
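The check/reorder suggested above could look like the following. This is a hedged sketch, not actual water code; `shape_syms` and `index_exprs` are illustrative names, and the actual implementation would operate on MLIR attributes rather than Python dicts.

```python
# Hypothetical sketch: rebuild an index dictionary so its key order
# follows the dimension symbols of the tensor/memref shape. Not the
# actual water API; names are illustrative.

def reorder_index_to_shape(shape_syms: list[str],
                           index_exprs: dict[str, str]) -> dict[str, str]:
    """Return index_exprs with keys reordered to match shape_syms.

    Raises if the two disagree on the set of dimensions, which would
    indicate a malformed op that a verifier should reject.
    """
    if set(shape_syms) != set(index_exprs):
        raise ValueError("index expressions do not cover the shape dims")
    return {sym: index_exprs[sym] for sym in shape_syms}

# Index built from Python in an arbitrary order:
idx = {"K": "induction_var", "M": "thread_id_0"}
fixed = reorder_index_to_shape(["M", "K"], idx)
assert list(fixed) == ["M", "K"]
```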

@tyb0807 (Contributor, Author) commented Jan 8, 2026

No, I wanted to have a normal form that checks that index expressions are present for ALL memory ops in the IR, not only for memref-typed ones. This does not hold for IR before running index analysis, so we cannot just put it in the op verifier, right?

Agreed on the second point. I guess we'll then need to ensure the order matches between index expressions and the tensor shape WHEN creating the memory ops from Python, and then make sure that still holds when resolving distributed allocations and rewriting memory ops to use memref types?

@ftynse (Contributor) commented Jan 8, 2026

> No I wanted to have a normal form that checks that index expressions are present for ALL memory ops in the IR, not only for memref-typed ones.

We already have something, `NORMAL_FORM_INDEX_EXPRS`; is this not enough?

@tyb0807 tyb0807 changed the title [water] Add ordered_syms attribute for memref read/write lowering [water] Support read/write lowering with MemRefType memory operands Jan 12, 2026
Signed-off-by: tyb0807 <[email protected]>
@tyb0807 tyb0807 merged commit b38f8de into iree-org:main Jan 12, 2026
14 of 15 checks passed
@tyb0807 tyb0807 deleted the ordered_syms_attr branch January 12, 2026 14:11
Successfully merging this pull request may close these issues.

Use distributed_shape for Read/WriteOp accessing memory from AllocateOp
