Skip to content

Commit 5a4a3e3

Browse files
committed
[water] Add lowering pattern using LDS promotion for write group
Signed-off-by: tyb0807 <[email protected]>
1 parent 015d79c commit 5a4a3e3

File tree

7 files changed

+1148
-39
lines changed

7 files changed

+1148
-39
lines changed

water/include/water/Dialect/Wave/IR/WaveOps.td

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,9 @@ def WriteOp : WaveOp<"write", [
306306
Arg<OptionalAttr<I64Attr>,
307307
"Number of elements processed by each thread">:$elements_per_thread,
308308
Arg<OptionalAttr<WaveReadWriteBoundsAttr>,
309-
"Bound expressions for each symbolic dimension">:$bounds
309+
"Bound expressions for each symbolic dimension">:$bounds,
310+
Arg<OptionalAttr<WaveMemoryAccessPatternAttr>,
311+
"Memory access pattern controlling LDS promotion">:$memory_access_pattern
310312
), commonArguments);
311313

312314
let assemblyFormat =

water/lib/Dialect/Wave/IR/WaveAttrs.cpp

Lines changed: 45 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -460,24 +460,24 @@ DeviceConstraintAttr::verify(function_ref<InFlightDiagnostic()> emitError,
460460
// WaveMemoryAccessPatternAttr
461461
//===----------------------------------------------------------------------===//
462462

463-
LogicalResult
464-
WaveMemoryAccessPatternAttr::verify(function_ref<InFlightDiagnostic()> emitError,
465-
bool use_lds_promotion, StringRef group_id,
466-
WaveExprListAttr lds_block_global_base,
467-
WaveExprListAttr lds_block_shape,
468-
WaveExprListAttr lds_load_indices,
469-
WaveExprListAttr lds_load_vector_sizes,
470-
WaveExprListAttr global_store_indices) {
463+
LogicalResult WaveMemoryAccessPatternAttr::verify(
464+
function_ref<InFlightDiagnostic()> emitError, bool use_lds_promotion,
465+
StringRef group_id, WaveExprListAttr lds_block_global_base,
466+
WaveExprListAttr lds_block_shape, WaveExprListAttr lds_load_indices,
467+
WaveExprListAttr lds_load_vector_sizes,
468+
WaveExprListAttr global_store_indices) {
471469
// Validate group_id is not empty
472470
if (group_id.empty()) {
473471
return emitError() << "group_id cannot be empty";
474472
}
475473

476-
// When LDS promotion is disabled, no LDS-related parameters should be specified
474+
// When LDS promotion is disabled, no LDS-related parameters should be
475+
// specified
477476
if (!use_lds_promotion) {
478477
if (lds_block_global_base || lds_block_shape || lds_load_indices ||
479478
lds_load_vector_sizes || global_store_indices) {
480-
return emitError() << "LDS promotion parameters should not be specified when use_lds_promotion=false";
479+
return emitError() << "LDS promotion parameters should not be specified "
480+
"when use_lds_promotion=false";
481481
}
482482
return success();
483483
}
@@ -492,10 +492,12 @@ WaveMemoryAccessPatternAttr::verify(function_ref<InFlightDiagnostic()> emitError
492492
// Check for partial specification - either all or none should be provided
493493
if (hasLdsBase || hasLdsShape || hasLdsLoadIndices || hasLdsLoadVectorSizes ||
494494
hasGlobalStoreIndices) {
495-
if (!hasLdsBase || !hasLdsShape || !hasLdsLoadIndices || !hasLdsLoadVectorSizes ||
496-
!hasGlobalStoreIndices) {
497-
return emitError() << "when LDS promotion is enabled, all LDS parameters must be specified: "
498-
"lds_block_global_base, lds_block_shape, lds_load_indices, lds_load_vector_sizes, "
495+
if (!hasLdsBase || !hasLdsShape || !hasLdsLoadIndices ||
496+
!hasLdsLoadVectorSizes || !hasGlobalStoreIndices) {
497+
return emitError() << "when LDS promotion is enabled, all LDS parameters "
498+
"must be specified: "
499+
"lds_block_global_base, lds_block_shape, "
500+
"lds_load_indices, lds_load_vector_sizes, "
499501
"global_store_indices";
500502
}
501503
}
@@ -504,13 +506,15 @@ WaveMemoryAccessPatternAttr::verify(function_ref<InFlightDiagnostic()> emitError
504506
if (hasLdsBase && hasLdsShape && hasLdsLoadIndices && hasLdsLoadVectorSizes &&
505507
hasGlobalStoreIndices) {
506508

507-
// Validate that lds_block_global_base and lds_block_shape have consistent ranks
509+
// Validate that lds_block_global_base and lds_block_shape have consistent
510+
// ranks
508511
unsigned ldsBaseRank = lds_block_global_base.getRank();
509512
unsigned ldsShapeRank = lds_block_shape.getRank();
510513

511514
if (ldsBaseRank != ldsShapeRank) {
512515
return emitError() << "lds_block_global_base rank (" << ldsBaseRank
513-
<< ") must match lds_block_shape rank (" << ldsShapeRank << ")";
516+
<< ") must match lds_block_shape rank ("
517+
<< ldsShapeRank << ")";
514518
}
515519

516520
// Validate that load indices and vector sizes have consistent ranks
@@ -520,58 +524,68 @@ WaveMemoryAccessPatternAttr::verify(function_ref<InFlightDiagnostic()> emitError
520524

521525
if (ldsLoadIndicesRank != ldsLoadVectorSizesRank) {
522526
return emitError() << "lds_load_indices rank (" << ldsLoadIndicesRank
523-
<< ") must match lds_load_vector_sizes rank (" << ldsLoadVectorSizesRank << ")";
527+
<< ") must match lds_load_vector_sizes rank ("
528+
<< ldsLoadVectorSizesRank << ")";
524529
}
525530

526531
if (ldsBaseRank != ldsLoadIndicesRank) {
527532
return emitError() << "LDS block rank (" << ldsBaseRank
528-
<< ") must match LDS load indices rank (" << ldsLoadIndicesRank << ")";
533+
<< ") must match LDS load indices rank ("
534+
<< ldsLoadIndicesRank << ")";
529535
}
530536

531537
if (ldsBaseRank != globalStoreIndicesRank) {
532538
return emitError() << "LDS block rank (" << ldsBaseRank
533-
<< ") must match global store indices rank (" << globalStoreIndicesRank << ")";
539+
<< ") must match global store indices rank ("
540+
<< globalStoreIndicesRank << ")";
534541
}
535542

536543
// Validate that all symbols are WaveSymbolAttr or WaveIndexSymbolAttr
537544
if (!llvm::all_of(lds_block_global_base.getSymbols(),
538545
llvm::IsaPred<WaveSymbolAttr, WaveIndexSymbolAttr>)) {
539-
return emitError() << "lds_block_global_base must only contain WaveSymbolAttr or WaveIndexSymbolAttr";
546+
return emitError() << "lds_block_global_base must only contain "
547+
"WaveSymbolAttr or WaveIndexSymbolAttr";
540548
}
541549

542550
if (!llvm::all_of(lds_block_shape.getSymbols(),
543551
llvm::IsaPred<WaveSymbolAttr, WaveIndexSymbolAttr>)) {
544-
return emitError() << "lds_block_shape must only contain WaveSymbolAttr or WaveIndexSymbolAttr";
552+
return emitError() << "lds_block_shape must only contain WaveSymbolAttr "
553+
"or WaveIndexSymbolAttr";
545554
}
546555

547556
if (!llvm::all_of(lds_load_indices.getSymbols(),
548557
llvm::IsaPred<WaveSymbolAttr, WaveIndexSymbolAttr>)) {
549-
return emitError() << "lds_load_indices must only contain WaveSymbolAttr or WaveIndexSymbolAttr";
558+
return emitError() << "lds_load_indices must only contain WaveSymbolAttr "
559+
"or WaveIndexSymbolAttr";
550560
}
551561

552562
if (!llvm::all_of(lds_load_vector_sizes.getSymbols(),
553563
llvm::IsaPred<WaveSymbolAttr, WaveIndexSymbolAttr>)) {
554-
return emitError() << "lds_load_vector_sizes must only contain WaveSymbolAttr or WaveIndexSymbolAttr";
564+
return emitError() << "lds_load_vector_sizes must only contain "
565+
"WaveSymbolAttr or WaveIndexSymbolAttr";
555566
}
556567

557568
if (!llvm::all_of(global_store_indices.getSymbols(),
558569
llvm::IsaPred<WaveSymbolAttr, WaveIndexSymbolAttr>)) {
559-
return emitError() << "global_store_indices must only contain WaveSymbolAttr or WaveIndexSymbolAttr";
570+
return emitError() << "global_store_indices must only contain "
571+
"WaveSymbolAttr or WaveIndexSymbolAttr";
560572
}
561573

562574
// Validate that mappings have at least one dimension
563575
if (ldsBaseRank == 0) {
564576
return emitError() << "LDS block must have at least one dimension";
565577
}
566578

567-
// Note: We cannot validate that the ranks match the original global memory tensor rank here
568-
// because this attribute verification doesn't have access to the WriteOp's memory operand.
569-
// This validation should be performed in the WriteOp's verifier where both the attribute
570-
// and the memory operand type are available.
579+
// Note: We cannot validate that the ranks match the original global memory
580+
// tensor rank here because this attribute verification doesn't have access
581+
// to the WriteOp's memory operand. This validation should be performed in
582+
// the WriteOp's verifier where both the attribute and the memory operand
583+
// type are available.
571584
//
572-
// Additionally, data coverage verification (ensuring that the collective workgroup access
573-
// pattern covers exactly the same elements before and after LDS promotion) should be
574-
// performed in the WriteOp verifier where access to the original index mapping is available.
585+
// Additionally, data coverage verification (ensuring that the collective
586+
// workgroup access pattern covers exactly the same elements before and
587+
// after LDS promotion) should be performed in the WriteOp verifier where
588+
// access to the original index mapping is available.
575589
}
576590

577591
return success();

0 commit comments

Comments
 (0)