1+ #include " Common/AcceleratorAttrs.h"
12#include " Conversion/ConversionPasses.h"
23#include " NeuraDialect/NeuraDialect.h"
34#include " NeuraDialect/NeuraOps.h"
89#include " mlir/IR/PatternMatch.h"
910#include " mlir/Pass/Pass.h"
1011#include " mlir/Transforms/GreedyPatternRewriteDriver.h"
12+ #include " llvm/ADT/StringRef.h"
1113
1214namespace mlir {
1315namespace neura {
@@ -24,9 +26,6 @@ using namespace mlir;
2426using namespace mlir ::func;
2527using namespace mlir ::neura;
2628
27- #define GEN_PASS_DEF_LOWERARITHTONEURA
28- #include " NeuraDialect/NeuraPasses.h.inc"
29-
3029namespace {
3130
3231struct ArithConstantToNeuraConstant
@@ -35,10 +34,10 @@ struct ArithConstantToNeuraConstant
3534
3635 LogicalResult matchAndRewrite (arith::ConstantOp op,
3736 PatternRewriter &rewriter) const override {
38- // Converts arith constant to Neura constant
37+ // Converts arith constant to Neura constant.
3938 Type result_type = op.getType ();
4039 Attribute value = op.getValue ();
41- // Optional predicate parameter can be null
40+ // Optional predicate parameter can be null.
4241 rewriter.replaceOpWithNewOp <neura::ConstantOp>(op, result_type, value,
4342 nullptr );
4443 return success ();
@@ -54,7 +53,7 @@ struct ArithAddIToNeuraAdd : public OpRewritePattern<mlir::arith::AddIOp> {
5453 Value rhs = op.getRhs ();
5554 Type result_type = op.getType ();
5655
57- // Optional predicate: default to null
56+ // Optional predicate: default to null.
5857 rewriter.replaceOpWithNewOp <neura::AddOp>(op, result_type, lhs, rhs,
5958 nullptr );
6059 return success ();
@@ -70,7 +69,7 @@ struct ArithFAddToNeuraFAdd : public OpRewritePattern<mlir::arith::AddFOp> {
7069 Value rhs = op.getRhs ();
7170 Type result_type = op.getType ();
7271
73- // Optional predicate: default to null
72+ // Optional predicate: default to null.
7473 rewriter.replaceOpWithNewOp <neura::FAddOp>(op, result_type, lhs, rhs,
7574 nullptr );
7675 return success ();
@@ -86,7 +85,7 @@ struct ArithSubIToNeuraSub : public OpRewritePattern<mlir::arith::SubIOp> {
8685 Value rhs = op.getRhs ();
8786 Type result_type = op.getType ();
8887
89- // Optional predicate: default to null
88+ // Optional predicate: default to null.
9089 rewriter.replaceOpWithNewOp <neura::SubOp>(op, result_type, lhs, rhs,
9190 nullptr );
9291 return success ();
@@ -102,7 +101,7 @@ struct ArithSubFToNeuraFSub : public OpRewritePattern<mlir::arith::SubFOp> {
102101 Value rhs = op.getRhs ();
103102 Type result_type = op.getType ();
104103
105- // Optional predicate: default to null
104+ // Optional predicate: default to null.
106105 rewriter.replaceOpWithNewOp <neura::FSubOp>(op, result_type, lhs, rhs,
107106 nullptr );
108107 return success ();
@@ -118,7 +117,7 @@ struct ArithMulFToNeuraFMul : public OpRewritePattern<mlir::arith::MulFOp> {
118117 Value rhs = op.getRhs ();
119118 Type result_type = op.getType ();
120119
121- // Optional predicate: default to null
120+ // Optional predicate: default to null.
122121 rewriter.replaceOpWithNewOp <neura::FMulOp>(op, result_type, lhs, rhs,
123122 nullptr );
124123 return success ();
@@ -134,7 +133,7 @@ struct ArithFDivToNeuraFDiv : public OpRewritePattern<mlir::arith::DivFOp> {
134133 Value rhs = op.getRhs ();
135134 Type result_type = op.getType ();
136135
137- // Optional predicate: default to null
136+ // Optional predicate: default to null.
138137 rewriter.replaceOpWithNewOp <neura::FDivOp>(op, result_type, lhs, rhs,
139138 nullptr );
140139 return success ();
@@ -185,8 +184,8 @@ struct ArithCmpiToNeuraICmp : public OpRewritePattern<mlir::arith::CmpIOp> {
185184 return rewriter.notifyMatchFailure (op, " Unsupported arith CmpIOp type" );
186185 }
187186
188- // Convert arith CmpIOp to Neura ICmpOp
189- // Optional predicate: default to null
187+ // Converts arith CmpIOp to Neura ICmpOp.
188+ // Optional predicate: default to null.
190189 rewriter.replaceOpWithNewOp <neura::ICmpOp>(
191190 op, result_type, lhs, rhs, nullptr , rewriter.getStringAttr (cmp_type));
192191 return success ();
@@ -203,7 +202,7 @@ struct ArithSelectToNeuraSel : public OpRewritePattern<mlir::arith::SelectOp> {
203202 Value false_value = op.getFalseValue ();
204203 Type result_type = op.getType ();
205204
206- // Convert arith SelectOp to Neura SelOp
205+ // Converts arith SelectOp to Neura SelOp.
207206 rewriter.replaceOpWithNewOp <neura::SelOp>(op, result_type, true_value,
208207 false_value, condition);
209208 return success ();
@@ -218,8 +217,8 @@ struct ArithExtUIToNeuraCast : public OpRewritePattern<mlir::arith::ExtUIOp> {
218217 Value input = op.getIn ();
219218 Type result_type = op.getType ();
220219
221- // Convert arith ExtUIOp to Neura cast operation
222- // Optional predicate: default to null
220+ // Converts arith ExtUIOp to Neura cast operation.
221+ // Optional predicate: default to null.
223222 rewriter.replaceOpWithNewOp <neura::CastOp>(
224223 op, result_type, input, rewriter.getStringAttr (" extui" ), nullptr );
225224 return success ();
@@ -234,8 +233,8 @@ struct ArithExtfToNeuraCast : public OpRewritePattern<mlir::arith::ExtFOp> {
234233 Value input = op.getIn ();
235234 Type result_type = op.getType ();
236235
237- // Convert arith ExtFOp to Neura cast operation
238- // Optional predicate: default to null
236+ // Converts arith ExtFOp to Neura cast operation.
237+ // Optional predicate: default to null.
239238 rewriter.replaceOpWithNewOp <neura::CastOp>(
240239 op, result_type, input, rewriter.getStringAttr (" extf" ), nullptr );
241240 return success ();
@@ -250,11 +249,23 @@ struct ArithIndexCastToNeuraCast
250249 PatternRewriter &rewriter) const override {
251250 Value input = op.getIn ();
252251 Type result_type = op.getType ();
252+ Type in_type = input.getType ();
253+ StringRef cast_string;
254+
255+ // The isa<IntegerType> check is generic and handles any integer bit width.
256+ // (e.g., i32, i64).
257+ if (in_type.isIndex () && isa<IntegerType>(result_type)) {
258+ cast_string = " index_to_int" ;
259+ } else if (isa<IntegerType>(in_type) && result_type.isIndex ()) {
260+ cast_string = " int_to_index" ;
261+ } else {
262+ return rewriter.notifyMatchFailure (op, " index_cast" );
263+ }
253264
254- // Convert arith IndexCastOp to Neura cast operation
255- // Optional predicate: default to null
265+ // Converts arith IndexCastOp to Neura cast operation.
266+ // Optional predicate: default to null.
256267 rewriter.replaceOpWithNewOp <neura::CastOp>(
257- op, result_type, input, rewriter.getStringAttr (" indexCast " ), nullptr );
268+ op, result_type, input, rewriter.getStringAttr (cast_string ), nullptr );
258269 return success ();
259270 }
260271};
@@ -274,16 +285,28 @@ struct LowerArithToNeuraPass
274285 }
275286
276287 void runOnOperation () override {
277- RewritePatternSet patterns (&getContext ());
278- mlir::neura::arith2neura::populateWithGenerated (patterns);
279- patterns
280- .add <ArithFAddToNeuraFAdd, ArithConstantToNeuraConstant,
281- ArithAddIToNeuraAdd, ArithCmpiToNeuraICmp, ArithSelectToNeuraSel,
282- ArithExtUIToNeuraCast, ArithIndexCastToNeuraCast,
283- ArithFDivToNeuraFDiv, ArithExtfToNeuraCast, ArithMulFToNeuraFMul, ArithSubIToNeuraSub, ArithSubFToNeuraFSub>(&getContext ());
284- if (failed (applyPatternsGreedily (getOperation (), std::move (patterns)))) {
285- signalPassFailure ();
286- }
288+ ModuleOp module_op = getOperation ();
289+ MLIRContext *context = &getContext ();
290+ module_op.walk ([&](func::FuncOp func_op) {
291+ if (func_op->hasAttr (mlir::accel::kAcceleratorAttr )) {
292+ auto target =
293+ func_op->getAttrOfType <StringAttr>(mlir::accel::kAcceleratorAttr );
294+ if (target && target.getValue () == mlir::accel::kNeuraTarget ) {
295+ RewritePatternSet patterns (&getContext ());
296+ mlir::neura::arith2neura::populateWithGenerated (patterns);
297+ patterns.add <ArithFAddToNeuraFAdd, ArithConstantToNeuraConstant,
298+ ArithAddIToNeuraAdd, ArithCmpiToNeuraICmp,
299+ ArithSelectToNeuraSel, ArithExtUIToNeuraCast,
300+ ArithIndexCastToNeuraCast, ArithFDivToNeuraFDiv,
301+ ArithExtfToNeuraCast, ArithMulFToNeuraFMul,
302+ ArithSubIToNeuraSub, ArithSubFToNeuraFSub>(context);
303+ if (failed (
304+ applyPatternsGreedily (getOperation (), std::move (patterns)))) {
305+ signalPassFailure ();
306+ }
307+ }
308+ }
309+ });
287310 }
288311};
289312} // namespace
0 commit comments