@@ -460,24 +460,24 @@ DeviceConstraintAttr::verify(function_ref<InFlightDiagnostic()> emitError,
460460// WaveMemoryAccessPatternAttr
461461// ===----------------------------------------------------------------------===//
462462
463- LogicalResult
464- WaveMemoryAccessPatternAttr::verify (function_ref<InFlightDiagnostic()> emitError,
465- bool use_lds_promotion, StringRef group_id,
466- WaveExprListAttr lds_block_global_base,
467- WaveExprListAttr lds_block_shape,
468- WaveExprListAttr lds_load_indices,
469- WaveExprListAttr lds_load_vector_sizes,
470- WaveExprListAttr global_store_indices) {
463+ LogicalResult WaveMemoryAccessPatternAttr::verify (
464+ function_ref<InFlightDiagnostic()> emitError, bool use_lds_promotion,
465+ StringRef group_id, WaveExprListAttr lds_block_global_base,
466+ WaveExprListAttr lds_block_shape, WaveExprListAttr lds_load_indices,
467+ WaveExprListAttr lds_load_vector_sizes,
468+ WaveExprListAttr global_store_indices) {
471469 // Validate group_id is not empty
472470 if (group_id.empty ()) {
473471 return emitError () << " group_id cannot be empty" ;
474472 }
475473
476- // When LDS promotion is disabled, no LDS-related parameters should be specified
474+ // When LDS promotion is disabled, no LDS-related parameters should be
475+ // specified
477476 if (!use_lds_promotion) {
478477 if (lds_block_global_base || lds_block_shape || lds_load_indices ||
479478 lds_load_vector_sizes || global_store_indices) {
480- return emitError () << " LDS promotion parameters should not be specified when use_lds_promotion=false" ;
479+ return emitError () << " LDS promotion parameters should not be specified "
480+ " when use_lds_promotion=false" ;
481481 }
482482 return success ();
483483 }
@@ -492,10 +492,12 @@ WaveMemoryAccessPatternAttr::verify(function_ref<InFlightDiagnostic()> emitError
492492 // Check for partial specification - either all or none should be provided
493493 if (hasLdsBase || hasLdsShape || hasLdsLoadIndices || hasLdsLoadVectorSizes ||
494494 hasGlobalStoreIndices) {
495- if (!hasLdsBase || !hasLdsShape || !hasLdsLoadIndices || !hasLdsLoadVectorSizes ||
496- !hasGlobalStoreIndices) {
497- return emitError () << " when LDS promotion is enabled, all LDS parameters must be specified: "
498- " lds_block_global_base, lds_block_shape, lds_load_indices, lds_load_vector_sizes, "
495+ if (!hasLdsBase || !hasLdsShape || !hasLdsLoadIndices ||
496+ !hasLdsLoadVectorSizes || !hasGlobalStoreIndices) {
497+ return emitError () << " when LDS promotion is enabled, all LDS parameters "
498+ " must be specified: "
499+ " lds_block_global_base, lds_block_shape, "
500+ " lds_load_indices, lds_load_vector_sizes, "
499501 " global_store_indices" ;
500502 }
501503 }
@@ -504,13 +506,15 @@ WaveMemoryAccessPatternAttr::verify(function_ref<InFlightDiagnostic()> emitError
504506 if (hasLdsBase && hasLdsShape && hasLdsLoadIndices && hasLdsLoadVectorSizes &&
505507 hasGlobalStoreIndices) {
506508
507- // Validate that lds_block_global_base and lds_block_shape have consistent ranks
509+ // Validate that lds_block_global_base and lds_block_shape have consistent
510+ // ranks
508511 unsigned ldsBaseRank = lds_block_global_base.getRank ();
509512 unsigned ldsShapeRank = lds_block_shape.getRank ();
510513
511514 if (ldsBaseRank != ldsShapeRank) {
512515 return emitError () << " lds_block_global_base rank (" << ldsBaseRank
513- << " ) must match lds_block_shape rank (" << ldsShapeRank << " )" ;
516+ << " ) must match lds_block_shape rank ("
517+ << ldsShapeRank << " )" ;
514518 }
515519
516520 // Validate that load indices and vector sizes have consistent ranks
@@ -520,58 +524,68 @@ WaveMemoryAccessPatternAttr::verify(function_ref<InFlightDiagnostic()> emitError
520524
521525 if (ldsLoadIndicesRank != ldsLoadVectorSizesRank) {
522526 return emitError () << " lds_load_indices rank (" << ldsLoadIndicesRank
523- << " ) must match lds_load_vector_sizes rank (" << ldsLoadVectorSizesRank << " )" ;
527+ << " ) must match lds_load_vector_sizes rank ("
528+ << ldsLoadVectorSizesRank << " )" ;
524529 }
525530
526531 if (ldsBaseRank != ldsLoadIndicesRank) {
527532 return emitError () << " LDS block rank (" << ldsBaseRank
528- << " ) must match LDS load indices rank (" << ldsLoadIndicesRank << " )" ;
533+ << " ) must match LDS load indices rank ("
534+ << ldsLoadIndicesRank << " )" ;
529535 }
530536
531537 if (ldsBaseRank != globalStoreIndicesRank) {
532538 return emitError () << " LDS block rank (" << ldsBaseRank
533- << " ) must match global store indices rank (" << globalStoreIndicesRank << " )" ;
539+ << " ) must match global store indices rank ("
540+ << globalStoreIndicesRank << " )" ;
534541 }
535542
536543 // Validate that all symbols are WaveSymbolAttr or WaveIndexSymbolAttr
537544 if (!llvm::all_of (lds_block_global_base.getSymbols (),
538545 llvm::IsaPred<WaveSymbolAttr, WaveIndexSymbolAttr>)) {
539- return emitError () << " lds_block_global_base must only contain WaveSymbolAttr or WaveIndexSymbolAttr" ;
546+ return emitError () << " lds_block_global_base must only contain "
547+ " WaveSymbolAttr or WaveIndexSymbolAttr" ;
540548 }
541549
542550 if (!llvm::all_of (lds_block_shape.getSymbols (),
543551 llvm::IsaPred<WaveSymbolAttr, WaveIndexSymbolAttr>)) {
544- return emitError () << " lds_block_shape must only contain WaveSymbolAttr or WaveIndexSymbolAttr" ;
552+ return emitError () << " lds_block_shape must only contain WaveSymbolAttr "
553+ " or WaveIndexSymbolAttr" ;
545554 }
546555
547556 if (!llvm::all_of (lds_load_indices.getSymbols (),
548557 llvm::IsaPred<WaveSymbolAttr, WaveIndexSymbolAttr>)) {
549- return emitError () << " lds_load_indices must only contain WaveSymbolAttr or WaveIndexSymbolAttr" ;
558+ return emitError () << " lds_load_indices must only contain WaveSymbolAttr "
559+ " or WaveIndexSymbolAttr" ;
550560 }
551561
552562 if (!llvm::all_of (lds_load_vector_sizes.getSymbols (),
553563 llvm::IsaPred<WaveSymbolAttr, WaveIndexSymbolAttr>)) {
554- return emitError () << " lds_load_vector_sizes must only contain WaveSymbolAttr or WaveIndexSymbolAttr" ;
564+ return emitError () << " lds_load_vector_sizes must only contain "
565+ " WaveSymbolAttr or WaveIndexSymbolAttr" ;
555566 }
556567
557568 if (!llvm::all_of (global_store_indices.getSymbols (),
558569 llvm::IsaPred<WaveSymbolAttr, WaveIndexSymbolAttr>)) {
559- return emitError () << " global_store_indices must only contain WaveSymbolAttr or WaveIndexSymbolAttr" ;
570+ return emitError () << " global_store_indices must only contain "
571+ " WaveSymbolAttr or WaveIndexSymbolAttr" ;
560572 }
561573
562574 // Validate that mappings have at least one dimension
563575 if (ldsBaseRank == 0 ) {
564576 return emitError () << " LDS block must have at least one dimension" ;
565577 }
566578
567- // Note: We cannot validate that the ranks match the original global memory tensor rank here
568- // because this attribute verification doesn't have access to the WriteOp's memory operand.
569- // This validation should be performed in the WriteOp's verifier where both the attribute
570- // and the memory operand type are available.
579+ // Note: We cannot validate that the ranks match the original global memory
580+ // tensor rank here because this attribute verification doesn't have access
581+ // to the WriteOp's memory operand. This validation should be performed in
582+ // the WriteOp's verifier where both the attribute and the memory operand
583+ // type are available.
571584 //
572- // Additionally, data coverage verification (ensuring that the collective workgroup access
573- // pattern covers exactly the same elements before and after LDS promotion) should be
574- // performed in the WriteOp verifier where access to the original index mapping is available.
585+ // Additionally, data coverage verification (ensuring that the collective
586+ // workgroup access pattern covers exactly the same elements before and
587+ // after LDS promotion) should be performed in the WriteOp verifier where
588+ // access to the original index mapping is available.
575589 }
576590
577591 return success ();
0 commit comments