Skip to content

Commit b6b175b

Browse files
committed
add bert-kernel test & lower arith2neura
1 parent 67ecd5f commit b6b175b

File tree

26 files changed

+4087
-9
lines changed

26 files changed

+4087
-9
lines changed

include/Conversion/ConversionPasses.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ namespace mlir {
1919
// Conversion passes.
2020
std::unique_ptr<mlir::Pass> createLowerArithToNeuraPass();
2121
std::unique_ptr<mlir::Pass> createLowerLlvmToNeuraPass();
22+
std::unique_ptr<mlir::Pass> createLowerMemRefToNeuraPass();
2223

2324
#define GEN_PASS_REGISTRATION
2425
#include "Conversion/ConversionPasses.h.inc"

include/Conversion/ConversionPasses.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,10 @@ def LowerLlvmToNeura : Pass<"lower-llvm-to-neura", "ModuleOp">{
2020
let constructor = "mlir::createLowerLlvmToNeuraPass()";
2121
}
2222

23+
def LowerMemRefToNeura : Pass<"lower-memref-to-neura", "ModuleOp">{
24+
let summary = "Lower MemRef to Neura dialect";
25+
let description = [{Lower MemRef operations to Neura dialect operations.}];
26+
let constructor = "mlir::createLowerMemRefToNeuraPass()";
27+
}
28+
2329
#endif // CONVERSION_PASSES_TD

include/NeuraDialect/NeuraOps.td

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,14 @@ def Neura_AddOp : Op<NeuraDialect, "add"> {
2424
let traits = [SameOperandsAndResultElementType];
2525
}
2626

27+
def Neura_SubOp : Op<NeuraDialect, "sub"> {
28+
let summary = "Integer substraction operation";
29+
let arguments = (ins AnyType:$lhs, AnyType:$rhs, Optional<AnyType>:$predicate);
30+
let results = (outs AnyType:$result);
31+
// let assemblyFormat = "$lhs `,` $rhs `,` $predicate attr-dict `:` type($result)";
32+
let traits = [SameOperandsAndResultElementType];
33+
}
34+
2735
// Defines a floating-point addition operation.
2836
def Neura_FAddOp : Op<NeuraDialect, "fadd"> {
2937
let summary = "Floating addition operation";
@@ -38,7 +46,7 @@ def Neura_FAddOp : Op<NeuraDialect, "fadd"> {
3846
def Neura_FSubOp: Op<NeuraDialect, "fsub"> {
3947
let summary = "Floating substraction operation";
4048
let opName = "fsub";
41-
let arguments = (ins AnyFloat:$lhs, AnyFloat:$rhs);
49+
let arguments = (ins AnyFloat:$lhs, AnyFloat:$rhs, Optional<AnyType>:$predicate);
4250
let results = (outs AnyFloat:$result);
4351
// let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($result)";
4452
let traits = [SameOperandsAndResultElementType];
@@ -54,6 +62,13 @@ def Neura_FMulOp : Op<NeuraDialect, "fmul"> {
5462
// let traits = [SameOperandsAndResultElementType];
5563
}
5664

65+
def Neura_FDivOp : Op<NeuraDialect, "fdiv"> {
66+
let summary = "Floating division operation";
67+
let arguments = (ins AnyFloat:$lhs, AnyFloat:$rhs, Optional<AnyType>:$predicate);
68+
let results = (outs AnyFloat:$result);
69+
// let assemblyFormat = "$lhs `,` $rhs `,` $predicate attr-dict `:` type($result)";
70+
}
71+
5772
// Defines a bitwise OR operation.
5873
def Neura_OrOp : Op<NeuraDialect, "or"> {
5974
let summary = "Bitwise OR operation";
@@ -144,6 +159,14 @@ def Neura_ReturnOp : Op<NeuraDialect, "return", [Terminator]> {
144159
// let assemblyFormat = "($values^)? `,` $predicate attr-dict";
145160
}
146161

162+
// Defines a cast operation for type conversion.
163+
def Neura_CastOp : Op<NeuraDialect, "cast">{
164+
let summary = "Generic type conversion operation";
165+
let arguments = (ins AnyType:$input, StrAttr:$cast_type, Optional<AnyType>:$predicate);
166+
let results = (outs AnyType:$result);
167+
// let assemblyFormat = "$input type($input) `->` type($output) `,` $predicate attr-dict";
168+
}
169+
147170
// ----------------------------------------------------
148171
// Defines vector operations.
149172

lib/Conversion/ArithToNeura/ArithToNeuraPass.cpp

Lines changed: 228 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
1+
#include "Conversion/ConversionPasses.h"
12
#include "NeuraDialect/NeuraDialect.h"
23
#include "NeuraDialect/NeuraOps.h"
3-
#include "mlir/Dialect/Arith/IR/Arith.h"
44
#include "NeuraDialect/NeuraPasses.h"
5+
#include "mlir/Dialect/Arith/IR/Arith.h"
56
#include "mlir/Dialect/Func/IR/FuncOps.h"
7+
#include "mlir/IR/Attributes.h"
68
#include "mlir/IR/PatternMatch.h"
79
#include "mlir/Pass/Pass.h"
810
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
9-
#include "Conversion/ConversionPasses.h"
1011

1112
namespace mlir {
1213
namespace neura {
@@ -26,7 +27,39 @@ using namespace mlir::neura;
2627
#define GEN_PASS_DEF_LOWERARITHTONEURA
2728
#include "NeuraDialect/NeuraPasses.h.inc"
2829

29-
namespace{
30+
namespace {
31+
32+
struct ArithConstantToNeuraConstant
33+
: public OpRewritePattern<mlir::arith::ConstantOp> {
34+
using OpRewritePattern::OpRewritePattern;
35+
36+
LogicalResult matchAndRewrite(arith::ConstantOp op,
37+
PatternRewriter &rewriter) const override {
38+
// Converts arith constant to Neura constant
39+
Type result_type = op.getType();
40+
Attribute value = op.getValue();
41+
// Optional predicate parameter can be null
42+
rewriter.replaceOpWithNewOp<neura::ConstantOp>(op, result_type, value,
43+
nullptr);
44+
return success();
45+
}
46+
};
47+
48+
struct ArithAddIToNeuraAdd : public OpRewritePattern<mlir::arith::AddIOp> {
49+
using OpRewritePattern::OpRewritePattern;
50+
51+
LogicalResult matchAndRewrite(arith::AddIOp op,
52+
PatternRewriter &rewriter) const override {
53+
Value lhs = op.getLhs();
54+
Value rhs = op.getRhs();
55+
Type result_type = op.getType();
56+
57+
// Optional predicate: default to null
58+
rewriter.replaceOpWithNewOp<neura::AddOp>(op, result_type, lhs, rhs,
59+
nullptr);
60+
return success();
61+
}
62+
};
3063

3164
struct ArithFAddToNeuraFAdd : public OpRewritePattern<mlir::arith::AddFOp> {
3265
using OpRewritePattern::OpRewritePattern;
@@ -35,16 +68,199 @@ struct ArithFAddToNeuraFAdd : public OpRewritePattern<mlir::arith::AddFOp> {
3568
PatternRewriter &rewriter) const override {
3669
Value lhs = op.getLhs();
3770
Value rhs = op.getRhs();
38-
Type resultType = op.getType();
71+
Type result_type = op.getType();
72+
73+
// Optional predicate: default to null
74+
rewriter.replaceOpWithNewOp<neura::FAddOp>(op, result_type, lhs, rhs,
75+
nullptr);
76+
return success();
77+
}
78+
};
79+
80+
struct ArithSubIToNeuraSub : public OpRewritePattern<mlir::arith::SubIOp> {
81+
using OpRewritePattern::OpRewritePattern;
82+
83+
LogicalResult matchAndRewrite(arith::SubIOp op,
84+
PatternRewriter &rewriter) const override {
85+
Value lhs = op.getLhs();
86+
Value rhs = op.getRhs();
87+
Type result_type = op.getType();
88+
89+
// Optional predicate: default to null
90+
rewriter.replaceOpWithNewOp<neura::SubOp>(op, result_type, lhs, rhs,
91+
nullptr);
92+
return success();
93+
}
94+
};
95+
96+
struct ArithSubFToNeuraFSub : public OpRewritePattern<mlir::arith::SubFOp> {
97+
using OpRewritePattern::OpRewritePattern;
98+
99+
LogicalResult matchAndRewrite(arith::SubFOp op,
100+
PatternRewriter &rewriter) const override {
101+
Value lhs = op.getLhs();
102+
Value rhs = op.getRhs();
103+
Type result_type = op.getType();
104+
105+
// Optional predicate: default to null
106+
rewriter.replaceOpWithNewOp<neura::FSubOp>(op, result_type, lhs, rhs,
107+
nullptr);
108+
return success();
109+
}
110+
};
111+
112+
struct ArithMulFToNeuraFMul : public OpRewritePattern<mlir::arith::MulFOp> {
113+
using OpRewritePattern::OpRewritePattern;
114+
115+
LogicalResult matchAndRewrite(arith::MulFOp op,
116+
PatternRewriter &rewriter) const override {
117+
Value lhs = op.getLhs();
118+
Value rhs = op.getRhs();
119+
Type result_type = op.getType();
120+
121+
// Optional predicate: default to null
122+
rewriter.replaceOpWithNewOp<neura::FMulOp>(op, result_type, lhs, rhs,
123+
nullptr);
124+
return success();
125+
}
126+
};
127+
128+
struct ArithFDivToNeuraFDiv : public OpRewritePattern<mlir::arith::DivFOp> {
129+
using OpRewritePattern::OpRewritePattern;
130+
131+
LogicalResult matchAndRewrite(arith::DivFOp op,
132+
PatternRewriter &rewriter) const override {
133+
Value lhs = op.getLhs();
134+
Value rhs = op.getRhs();
135+
Type result_type = op.getType();
136+
137+
// Optional predicate: default to null
138+
rewriter.replaceOpWithNewOp<neura::FDivOp>(op, result_type, lhs, rhs,
139+
nullptr);
140+
return success();
141+
}
142+
};
143+
struct ArithCmpiToNeuraICmp : public OpRewritePattern<mlir::arith::CmpIOp> {
144+
using OpRewritePattern::OpRewritePattern;
145+
146+
LogicalResult matchAndRewrite(arith::CmpIOp op,
147+
PatternRewriter &rewriter) const override {
148+
Value lhs = op.getLhs();
149+
Value rhs = op.getRhs();
150+
Type result_type = op.getType();
151+
arith::CmpIPredicate arith_cmp_type = op.getPredicate();
152+
StringRef cmp_type;
153+
switch (arith_cmp_type) {
154+
case arith::CmpIPredicate::eq:
155+
cmp_type = "eq"; // ==
156+
break;
157+
case arith::CmpIPredicate::ne:
158+
cmp_type = "ne"; // !=
159+
break;
160+
case arith::CmpIPredicate::slt:
161+
cmp_type = "slt"; // <
162+
break;
163+
case arith::CmpIPredicate::sle:
164+
cmp_type = "sle"; // <=
165+
break;
166+
case arith::CmpIPredicate::sgt:
167+
cmp_type = "sgt"; // >
168+
break;
169+
case arith::CmpIPredicate::sge:
170+
cmp_type = "sge"; // >=
171+
break;
172+
case arith::CmpIPredicate::ult:
173+
cmp_type = "ult"; // unsigned <
174+
break;
175+
case arith::CmpIPredicate::ule:
176+
cmp_type = "ule"; // unsigned <=
177+
break;
178+
case arith::CmpIPredicate::ugt:
179+
cmp_type = "ugt"; // unsigned >
180+
break;
181+
case arith::CmpIPredicate::uge:
182+
cmp_type = "uge"; // unsigned >=
183+
break;
184+
default:
185+
return rewriter.notifyMatchFailure(op, "Unsupported arith CmpIOp type");
186+
}
187+
188+
// Convert arith CmpIOp to Neura ICmpOp
189+
// Optional predicate: default to null
190+
rewriter.replaceOpWithNewOp<neura::ICmpOp>(
191+
op, result_type, lhs, rhs, nullptr, rewriter.getStringAttr(cmp_type));
192+
return success();
193+
}
194+
};
195+
196+
struct ArithSelectToNeuraSel : public OpRewritePattern<mlir::arith::SelectOp> {
197+
using OpRewritePattern::OpRewritePattern;
198+
199+
LogicalResult matchAndRewrite(arith::SelectOp op,
200+
PatternRewriter &rewriter) const override {
201+
Value condition = op.getCondition();
202+
Value true_value = op.getTrueValue();
203+
Value false_value = op.getFalseValue();
204+
Type result_type = op.getType();
205+
206+
// Convert arith SelectOp to Neura SelOp
207+
rewriter.replaceOpWithNewOp<neura::SelOp>(op, result_type, true_value,
208+
false_value, condition);
209+
return success();
210+
}
211+
};
212+
213+
struct ArithExtUIToNeuraCast : public OpRewritePattern<mlir::arith::ExtUIOp> {
214+
using OpRewritePattern::OpRewritePattern;
215+
216+
LogicalResult matchAndRewrite(arith::ExtUIOp op,
217+
PatternRewriter &rewriter) const override {
218+
Value input = op.getIn();
219+
Type result_type = op.getType();
220+
221+
// Convert arith ExtUIOp to Neura cast operation
222+
// Optional predicate: default to null
223+
rewriter.replaceOpWithNewOp<neura::CastOp>(
224+
op, result_type, input, rewriter.getStringAttr("extui"), nullptr);
225+
return success();
226+
}
227+
};
228+
229+
struct ArithExtfToNeuraCast : public OpRewritePattern<mlir::arith::ExtFOp> {
230+
using OpRewritePattern::OpRewritePattern;
231+
232+
LogicalResult matchAndRewrite(arith::ExtFOp op,
233+
PatternRewriter &rewriter) const override {
234+
Value input = op.getIn();
235+
Type result_type = op.getType();
236+
237+
// Convert arith ExtFOp to Neura cast operation
238+
// Optional predicate: default to null
239+
rewriter.replaceOpWithNewOp<neura::CastOp>(
240+
op, result_type, input, rewriter.getStringAttr("extf"), nullptr);
241+
return success();
242+
}
243+
};
244+
245+
struct ArithIndexCastToNeuraCast
246+
: public OpRewritePattern<mlir::arith::IndexCastOp> {
247+
using OpRewritePattern::OpRewritePattern;
248+
249+
LogicalResult matchAndRewrite(arith::IndexCastOp op,
250+
PatternRewriter &rewriter) const override {
251+
Value input = op.getIn();
252+
Type result_type = op.getType();
39253

40-
// Optional predicate: default to 'none'
41-
rewriter.replaceOpWithNewOp<neura::FAddOp>(op, resultType, lhs, rhs, Value());
254+
// Convert arith IndexCastOp to Neura cast operation
255+
// Optional predicate: default to null
256+
rewriter.replaceOpWithNewOp<neura::CastOp>(
257+
op, result_type, input, rewriter.getStringAttr("indexCast"), nullptr);
42258
return success();
43259
}
44260
};
45261

46262
struct LowerArithToNeuraPass
47-
: public PassWrapper<LowerArithToNeuraPass, OperationPass<func::FuncOp>> {
263+
: public PassWrapper<LowerArithToNeuraPass, OperationPass<ModuleOp>> {
48264

49265
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LowerArithToNeuraPass)
50266

@@ -60,7 +276,11 @@ struct LowerArithToNeuraPass
60276
void runOnOperation() override {
61277
RewritePatternSet patterns(&getContext());
62278
mlir::neura::arith2neura::populateWithGenerated(patterns);
63-
patterns.add<ArithFAddToNeuraFAdd>(&getContext());
279+
patterns
280+
.add<ArithFAddToNeuraFAdd, ArithConstantToNeuraConstant,
281+
ArithAddIToNeuraAdd, ArithCmpiToNeuraICmp, ArithSelectToNeuraSel,
282+
ArithExtUIToNeuraCast, ArithIndexCastToNeuraCast,
283+
ArithFDivToNeuraFDiv, ArithExtfToNeuraCast, ArithMulFToNeuraFMul, ArithSubIToNeuraSub, ArithSubFToNeuraFSub>(&getContext());
64284
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
65285
signalPassFailure();
66286
}

lib/Conversion/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
22

33
add_subdirectory(ArithToNeura)
44
add_subdirectory(LlvmToNeura)
5+
add_subdirectory(MemRefToNeura)
56

67
# add_mlir_library(
78
# MLIRNeuraConversion
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
include_directories(${CMAKE_CURRENT_BINARY_DIR})
2+
3+
add_mlir_conversion_library(MLIRNeuraMemRefToNeuraPass
4+
MemRefToNeuraPass.cpp
5+
6+
DEPENDS
7+
MLIRConversionIncGen
8+
9+
LINK_LIBS PUBLIC
10+
MLIRArithDialect
11+
MLIRFuncDialect
12+
MLIRLLVMDialect
13+
MLIRIR
14+
MLIRPass
15+
MLIRTransforms
16+
MLIRNeura
17+
MLIRSupport
18+
)

0 commit comments

Comments
 (0)