1+ #include " Conversion/ConversionPasses.h"
12#include " NeuraDialect/NeuraDialect.h"
23#include " NeuraDialect/NeuraOps.h"
3- #include " mlir/Dialect/Arith/IR/Arith.h"
44#include " NeuraDialect/NeuraPasses.h"
5+ #include " mlir/Dialect/Arith/IR/Arith.h"
56#include " mlir/Dialect/Func/IR/FuncOps.h"
7+ #include " mlir/IR/Attributes.h"
68#include " mlir/IR/PatternMatch.h"
79#include " mlir/Pass/Pass.h"
810#include " mlir/Transforms/GreedyPatternRewriteDriver.h"
9- #include " Conversion/ConversionPasses.h"
1011
1112namespace mlir {
1213namespace neura {
@@ -26,7 +27,39 @@ using namespace mlir::neura;
2627#define GEN_PASS_DEF_LOWERARITHTONEURA
2728#include " NeuraDialect/NeuraPasses.h.inc"
2829
29- namespace {
30+ namespace {
31+
32+ struct ArithConstantToNeuraConstant
33+ : public OpRewritePattern<mlir::arith::ConstantOp> {
34+ using OpRewritePattern::OpRewritePattern;
35+
36+ LogicalResult matchAndRewrite (arith::ConstantOp op,
37+ PatternRewriter &rewriter) const override {
38+ // Converts arith constant to Neura constant
39+ Type result_type = op.getType ();
40+ Attribute value = op.getValue ();
41+ // Optional predicate parameter can be null
42+ rewriter.replaceOpWithNewOp <neura::ConstantOp>(op, result_type, value,
43+ nullptr );
44+ return success ();
45+ }
46+ };
47+
48+ struct ArithAddIToNeuraAdd : public OpRewritePattern <mlir::arith::AddIOp> {
49+ using OpRewritePattern::OpRewritePattern;
50+
51+ LogicalResult matchAndRewrite (arith::AddIOp op,
52+ PatternRewriter &rewriter) const override {
53+ Value lhs = op.getLhs ();
54+ Value rhs = op.getRhs ();
55+ Type result_type = op.getType ();
56+
57+ // Optional predicate: default to null
58+ rewriter.replaceOpWithNewOp <neura::AddOp>(op, result_type, lhs, rhs,
59+ nullptr );
60+ return success ();
61+ }
62+ };
3063
3164struct ArithFAddToNeuraFAdd : public OpRewritePattern <mlir::arith::AddFOp> {
3265 using OpRewritePattern::OpRewritePattern;
@@ -35,16 +68,199 @@ struct ArithFAddToNeuraFAdd : public OpRewritePattern<mlir::arith::AddFOp> {
3568 PatternRewriter &rewriter) const override {
3669 Value lhs = op.getLhs ();
3770 Value rhs = op.getRhs ();
38- Type resultType = op.getType ();
71+ Type result_type = op.getType ();
72+
73+ // Optional predicate: default to null
74+ rewriter.replaceOpWithNewOp <neura::FAddOp>(op, result_type, lhs, rhs,
75+ nullptr );
76+ return success ();
77+ }
78+ };
79+
80+ struct ArithSubIToNeuraSub : public OpRewritePattern <mlir::arith::SubIOp> {
81+ using OpRewritePattern::OpRewritePattern;
82+
83+ LogicalResult matchAndRewrite (arith::SubIOp op,
84+ PatternRewriter &rewriter) const override {
85+ Value lhs = op.getLhs ();
86+ Value rhs = op.getRhs ();
87+ Type result_type = op.getType ();
88+
89+ // Optional predicate: default to null
90+ rewriter.replaceOpWithNewOp <neura::SubOp>(op, result_type, lhs, rhs,
91+ nullptr );
92+ return success ();
93+ }
94+ };
95+
96+ struct ArithSubFToNeuraFSub : public OpRewritePattern <mlir::arith::SubFOp> {
97+ using OpRewritePattern::OpRewritePattern;
98+
99+ LogicalResult matchAndRewrite (arith::SubFOp op,
100+ PatternRewriter &rewriter) const override {
101+ Value lhs = op.getLhs ();
102+ Value rhs = op.getRhs ();
103+ Type result_type = op.getType ();
104+
105+ // Optional predicate: default to null
106+ rewriter.replaceOpWithNewOp <neura::FSubOp>(op, result_type, lhs, rhs,
107+ nullptr );
108+ return success ();
109+ }
110+ };
111+
112+ struct ArithMulFToNeuraFMul : public OpRewritePattern <mlir::arith::MulFOp> {
113+ using OpRewritePattern::OpRewritePattern;
114+
115+ LogicalResult matchAndRewrite (arith::MulFOp op,
116+ PatternRewriter &rewriter) const override {
117+ Value lhs = op.getLhs ();
118+ Value rhs = op.getRhs ();
119+ Type result_type = op.getType ();
120+
121+ // Optional predicate: default to null
122+ rewriter.replaceOpWithNewOp <neura::FMulOp>(op, result_type, lhs, rhs,
123+ nullptr );
124+ return success ();
125+ }
126+ };
127+
128+ struct ArithFDivToNeuraFDiv : public OpRewritePattern <mlir::arith::DivFOp> {
129+ using OpRewritePattern::OpRewritePattern;
130+
131+ LogicalResult matchAndRewrite (arith::DivFOp op,
132+ PatternRewriter &rewriter) const override {
133+ Value lhs = op.getLhs ();
134+ Value rhs = op.getRhs ();
135+ Type result_type = op.getType ();
136+
137+ // Optional predicate: default to null
138+ rewriter.replaceOpWithNewOp <neura::FDivOp>(op, result_type, lhs, rhs,
139+ nullptr );
140+ return success ();
141+ }
142+ };
143+ struct ArithCmpiToNeuraICmp : public OpRewritePattern <mlir::arith::CmpIOp> {
144+ using OpRewritePattern::OpRewritePattern;
145+
146+ LogicalResult matchAndRewrite (arith::CmpIOp op,
147+ PatternRewriter &rewriter) const override {
148+ Value lhs = op.getLhs ();
149+ Value rhs = op.getRhs ();
150+ Type result_type = op.getType ();
151+ arith::CmpIPredicate arith_cmp_type = op.getPredicate ();
152+ StringRef cmp_type;
153+ switch (arith_cmp_type) {
154+ case arith::CmpIPredicate::eq:
155+ cmp_type = " eq" ; // ==
156+ break ;
157+ case arith::CmpIPredicate::ne:
158+ cmp_type = " ne" ; // !=
159+ break ;
160+ case arith::CmpIPredicate::slt:
161+ cmp_type = " slt" ; // <
162+ break ;
163+ case arith::CmpIPredicate::sle:
164+ cmp_type = " sle" ; // <=
165+ break ;
166+ case arith::CmpIPredicate::sgt:
167+ cmp_type = " sgt" ; // >
168+ break ;
169+ case arith::CmpIPredicate::sge:
170+ cmp_type = " sge" ; // >=
171+ break ;
172+ case arith::CmpIPredicate::ult:
173+ cmp_type = " ult" ; // unsigned <
174+ break ;
175+ case arith::CmpIPredicate::ule:
176+ cmp_type = " ule" ; // unsigned <=
177+ break ;
178+ case arith::CmpIPredicate::ugt:
179+ cmp_type = " ugt" ; // unsigned >
180+ break ;
181+ case arith::CmpIPredicate::uge:
182+ cmp_type = " uge" ; // unsigned >=
183+ break ;
184+ default :
185+ return rewriter.notifyMatchFailure (op, " Unsupported arith CmpIOp type" );
186+ }
187+
188+ // Convert arith CmpIOp to Neura ICmpOp
189+ // Optional predicate: default to null
190+ rewriter.replaceOpWithNewOp <neura::ICmpOp>(
191+ op, result_type, lhs, rhs, nullptr , rewriter.getStringAttr (cmp_type));
192+ return success ();
193+ }
194+ };
195+
196+ struct ArithSelectToNeuraSel : public OpRewritePattern <mlir::arith::SelectOp> {
197+ using OpRewritePattern::OpRewritePattern;
198+
199+ LogicalResult matchAndRewrite (arith::SelectOp op,
200+ PatternRewriter &rewriter) const override {
201+ Value condition = op.getCondition ();
202+ Value true_value = op.getTrueValue ();
203+ Value false_value = op.getFalseValue ();
204+ Type result_type = op.getType ();
205+
206+ // Convert arith SelectOp to Neura SelOp
207+ rewriter.replaceOpWithNewOp <neura::SelOp>(op, result_type, true_value,
208+ false_value, condition);
209+ return success ();
210+ }
211+ };
212+
213+ struct ArithExtUIToNeuraCast : public OpRewritePattern <mlir::arith::ExtUIOp> {
214+ using OpRewritePattern::OpRewritePattern;
215+
216+ LogicalResult matchAndRewrite (arith::ExtUIOp op,
217+ PatternRewriter &rewriter) const override {
218+ Value input = op.getIn ();
219+ Type result_type = op.getType ();
220+
221+ // Convert arith ExtUIOp to Neura cast operation
222+ // Optional predicate: default to null
223+ rewriter.replaceOpWithNewOp <neura::CastOp>(
224+ op, result_type, input, rewriter.getStringAttr (" extui" ), nullptr );
225+ return success ();
226+ }
227+ };
228+
229+ struct ArithExtfToNeuraCast : public OpRewritePattern <mlir::arith::ExtFOp> {
230+ using OpRewritePattern::OpRewritePattern;
231+
232+ LogicalResult matchAndRewrite (arith::ExtFOp op,
233+ PatternRewriter &rewriter) const override {
234+ Value input = op.getIn ();
235+ Type result_type = op.getType ();
236+
237+ // Convert arith ExtFOp to Neura cast operation
238+ // Optional predicate: default to null
239+ rewriter.replaceOpWithNewOp <neura::CastOp>(
240+ op, result_type, input, rewriter.getStringAttr (" extf" ), nullptr );
241+ return success ();
242+ }
243+ };
244+
245+ struct ArithIndexCastToNeuraCast
246+ : public OpRewritePattern<mlir::arith::IndexCastOp> {
247+ using OpRewritePattern::OpRewritePattern;
248+
249+ LogicalResult matchAndRewrite (arith::IndexCastOp op,
250+ PatternRewriter &rewriter) const override {
251+ Value input = op.getIn ();
252+ Type result_type = op.getType ();
39253
40- // Optional predicate: default to 'none'
41- rewriter.replaceOpWithNewOp <neura::FAddOp>(op, resultType, lhs, rhs, Value ());
254+ // Convert arith IndexCastOp to Neura cast operation
255+ // Optional predicate: default to null
256+ rewriter.replaceOpWithNewOp <neura::CastOp>(
257+ op, result_type, input, rewriter.getStringAttr (" indexCast" ), nullptr );
42258 return success ();
43259 }
44260};
45261
46262struct LowerArithToNeuraPass
47- : public PassWrapper<LowerArithToNeuraPass, OperationPass<func::FuncOp >> {
263+ : public PassWrapper<LowerArithToNeuraPass, OperationPass<ModuleOp >> {
48264
49265 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID (LowerArithToNeuraPass)
50266
@@ -60,7 +276,11 @@ struct LowerArithToNeuraPass
60276 void runOnOperation () override {
61277 RewritePatternSet patterns (&getContext ());
62278 mlir::neura::arith2neura::populateWithGenerated (patterns);
63- patterns.add <ArithFAddToNeuraFAdd>(&getContext ());
279+ patterns
280+ .add <ArithFAddToNeuraFAdd, ArithConstantToNeuraConstant,
281+ ArithAddIToNeuraAdd, ArithCmpiToNeuraICmp, ArithSelectToNeuraSel,
282+ ArithExtUIToNeuraCast, ArithIndexCastToNeuraCast,
283+ ArithFDivToNeuraFDiv, ArithExtfToNeuraCast, ArithMulFToNeuraFMul, ArithSubIToNeuraSub, ArithSubFToNeuraFSub>(&getContext ());
64284 if (failed (applyPatternsGreedily (getOperation (), std::move (patterns)))) {
65285 signalPassFailure ();
66286 }
0 commit comments