Commit 805548c

Merge pull request #53 from ShangkunLi/memref-builtin-lower
Enable end2end affine-to-neura lowering
2 parents 38ca969 + 93620bc commit 805548c

24 files changed, +452 -219 lines

include/Conversion/ConversionPasses.h

Lines changed: 1 addition & 0 deletions
@@ -20,6 +20,7 @@ namespace mlir {
 std::unique_ptr<mlir::Pass> createLowerArithToNeuraPass();
 std::unique_ptr<mlir::Pass> createLowerLlvmToNeuraPass();
 std::unique_ptr<mlir::Pass> createLowerMemRefToNeuraPass();
+std::unique_ptr<mlir::Pass> createLowerBuiltinToNeuraPass();
 
 #define GEN_PASS_REGISTRATION
 #include "Conversion/ConversionPasses.h.inc"

include/Conversion/ConversionPasses.td

Lines changed: 6 additions & 0 deletions
@@ -26,4 +26,10 @@ def LowerMemRefToNeura : Pass<"lower-memref-to-neura", "ModuleOp">{
   let constructor = "mlir::createLowerMemRefToNeuraPass()";
 }
 
+def LowerBuiltinToNeura : Pass<"lower-builtin-to-neura", "ModuleOp">{
+  let summary = "Lower Builtin to Neura dialect";
+  let description = [{Lower Builtin operations to Neura dialect operations.}];
+  let constructor = "mlir::createLowerBuiltinToNeuraPass()";
+}
+
 #endif // CONVERSION_PASSES_TD
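
With the pass declaration and its TableGen definition in place, the four lowerings can be chained into the end-to-end affine-to-neura flow this PR enables. Below is a minimal C++ sketch of such a pipeline; the helper name and the pass ordering are assumptions for illustration, not code from this commit.

  // Hypothetical pipeline sketch. Assumes the create* declarations from
  // include/Conversion/ConversionPasses.h above; ordering is an assumption.
  #include "Conversion/ConversionPasses.h"
  #include "mlir/Pass/PassManager.h"

  void buildNeuraLoweringPipeline(mlir::PassManager &pm) {
    // Lowers arith, memref, and llvm ops first, then cleans up the remaining
    // builtin ops (e.g., unrealized_conversion_cast) introduced along the way.
    pm.addPass(mlir::createLowerArithToNeuraPass());
    pm.addPass(mlir::createLowerMemRefToNeuraPass());
    pm.addPass(mlir::createLowerLlvmToNeuraPass());
    pm.addPass(mlir::createLowerBuiltinToNeuraPass());
  }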

include/NeuraDialect/NeuraOps.td

Lines changed: 30 additions & 2 deletions
@@ -116,6 +116,34 @@ def Neura_StoreOp : Op<NeuraDialect, "store"> {
   // let assemblyFormat = "$value `,` $addr `,` $predicate attr-dict";
 }
 
+// Defines a load operation with integrated address calculation.
+def Neura_LoadIndexedOp: Op<NeuraDialect, "load_indexed", [AttrSizedOperandSegments]>{
+  let summary = "Load with integrated address calculation for multi-dimensional arrays";
+  let description = [{
+    Calculates the address using the base address and indices.
+    Loads the value at the calculated address.
+    Example:
+      %value = neura.load_indexed %base [%arg1, %arg2] : f32
+  }];
+  let arguments = (ins Arg<AnyMemRef, "the load operation">:$base, Variadic<AnyType>:$indices, Optional<AnyType>:$predicate);
+  let results = (outs AnyType:$result);
+  let assemblyFormat = "$base `[` $indices `:` type($indices) `]` type($base) ($predicate^ `:` type($predicate))? attr-dict `:` type($result)";
+}
+
+// Defines a store operation with integrated address calculation.
+def Neura_StoreIndexedOp: Op<NeuraDialect, "store_indexed", [AttrSizedOperandSegments]> {
+  let summary = "Store with integrated address calculation for multi-dimensional arrays";
+  let description = [{
+    Calculates the address using the base address and indices.
+    Stores the value at the calculated address.
+    Example:
+      neura.store_indexed %value, %base [%arg1, %arg2] : f32
+  }];
+  let arguments = (ins AnyType:$value, Arg<AnyMemRef, "the store operation">:$base, Variadic<AnyType>:$indices, Optional<AnyType>:$predicate);
+  let results = (outs);
+  let assemblyFormat = "$value `to` $base `[` $indices `:` type($indices) `]` type($base) ($predicate^ `:` type($predicate))? attr-dict `:` type($value)";
+}
+
 // Defines a pointer computation operation.
 def Neura_GEP : Op<NeuraDialect, "gep"> {
   let summary = "Pointer computation using offset indices";
@@ -131,14 +159,14 @@ def Neura_CondBr : Op<NeuraDialect, "cond_br", [Terminator, AttrSizedOperandSegm
                      Variadic<AnyType>:$trueArgs,
                      Variadic<AnyType>:$falseArgs);
   let successors = (successor AnySuccessor:$trueDest, AnySuccessor:$falseDest);
-  let assemblyFormat = "$condition `:` type($condition) ($predicate^ `:` type($predicate))? `then` ($trueArgs^)? `:` type($trueArgs) `to` $trueDest `else` ($falseArgs^)? `:` type($falseArgs) `to` $falseDest attr-dict";
+  let assemblyFormat = "$condition `:` type($condition) ($predicate^ `:` type($predicate))? `then` ($trueArgs^ `:` type($trueArgs))? `to` $trueDest `else` ($falseArgs^ `:` type($falseArgs))? `to` $falseDest attr-dict";
 }
 
 // Defines an unconditional branch operation.
 def Neura_Br : Op<NeuraDialect, "br", [Terminator]> {
   let arguments = (ins Variadic<AnyType>:$args);
   let successors = (successor AnySuccessor:$dest);
-  let assemblyFormat = "($args^)? `:` type($args) `to` $dest attr-dict";
+  let assemblyFormat = "($args^ `:` type($args))? `to` $dest attr-dict";
 }
 
 def Neura_SelOp : Op<NeuraDialect, "sel"> {
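
The new load_indexed/store_indexed ops fold the address arithmetic of multi-dimensional accesses into the memory op itself. To illustrate how a lowering might target them, here is a hedged C++ sketch that rewrites a memref.load into neura.load_indexed; the pattern itself and the ODS-generated builder signature (result type, base, indices, optional predicate) are assumptions inferred from the argument list above, not the lowering shipped in this PR.

  // Hypothetical sketch: fold a memref.load and its subscripts into a single
  // neura.load_indexed. The builder signature is assumed from the ODS args.
  struct MemRefLoadToNeuraLoadIndexed
      : public mlir::OpRewritePattern<mlir::memref::LoadOp> {
    using OpRewritePattern::OpRewritePattern;

    mlir::LogicalResult
    matchAndRewrite(mlir::memref::LoadOp op,
                    mlir::PatternRewriter &rewriter) const override {
      // Keeps the memref as the base and forwards the subscripts unchanged;
      // the optional predicate is left null, matching the other lowerings.
      rewriter.replaceOpWithNewOp<mlir::neura::LoadIndexedOp>(
          op, op.getResult().getType(), op.getMemRef(), op.getIndices(),
          /*predicate=*/mlir::Value());
      return mlir::success();
    }
  };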

include/NeuraDialect/NeuraPasses.td

Lines changed: 0 additions & 1 deletion
@@ -57,5 +57,4 @@ def MapToAccelerator : Pass<"map-to-accelerator", "ModuleOp"> {
   }];
   let constructor = "neura::createMapToAcceleratorPass()";
 }
-
 #endif // NEURA_PASSES_TD

lib/Conversion/ArithToNeura/ArithToNeuraPass.cpp

Lines changed: 54 additions & 31 deletions
@@ -1,3 +1,4 @@
+#include "Common/AcceleratorAttrs.h"
 #include "Conversion/ConversionPasses.h"
 #include "NeuraDialect/NeuraDialect.h"
 #include "NeuraDialect/NeuraOps.h"
@@ -8,6 +9,7 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/StringRef.h"
 
 namespace mlir {
 namespace neura {
@@ -24,9 +26,6 @@ using namespace mlir;
 using namespace mlir::func;
 using namespace mlir::neura;
 
-#define GEN_PASS_DEF_LOWERARITHTONEURA
-#include "NeuraDialect/NeuraPasses.h.inc"
-
 namespace {
 
 struct ArithConstantToNeuraConstant
@@ -35,10 +34,10 @@ struct ArithConstantToNeuraConstant
 
   LogicalResult matchAndRewrite(arith::ConstantOp op,
                                 PatternRewriter &rewriter) const override {
-    // Converts arith constant to Neura constant
+    // Converts arith constant to Neura constant.
     Type result_type = op.getType();
     Attribute value = op.getValue();
-    // Optional predicate parameter can be null
+    // Optional predicate parameter can be null.
     rewriter.replaceOpWithNewOp<neura::ConstantOp>(op, result_type, value,
                                                    nullptr);
     return success();
@@ -54,7 +53,7 @@ struct ArithAddIToNeuraAdd : public OpRewritePattern<mlir::arith::AddIOp> {
     Value rhs = op.getRhs();
     Type result_type = op.getType();
 
-    // Optional predicate: default to null
+    // Optional predicate: default to null.
     rewriter.replaceOpWithNewOp<neura::AddOp>(op, result_type, lhs, rhs,
                                               nullptr);
     return success();
@@ -70,7 +69,7 @@ struct ArithFAddToNeuraFAdd : public OpRewritePattern<mlir::arith::AddFOp> {
     Value rhs = op.getRhs();
     Type result_type = op.getType();
 
-    // Optional predicate: default to null
+    // Optional predicate: default to null.
    rewriter.replaceOpWithNewOp<neura::FAddOp>(op, result_type, lhs, rhs,
                                               nullptr);
     return success();
@@ -86,7 +85,7 @@ struct ArithSubIToNeuraSub : public OpRewritePattern<mlir::arith::SubIOp> {
     Value rhs = op.getRhs();
     Type result_type = op.getType();
 
-    // Optional predicate: default to null
+    // Optional predicate: default to null.
     rewriter.replaceOpWithNewOp<neura::SubOp>(op, result_type, lhs, rhs,
                                               nullptr);
     return success();
@@ -102,7 +101,7 @@ struct ArithSubFToNeuraFSub : public OpRewritePattern<mlir::arith::SubFOp> {
     Value rhs = op.getRhs();
     Type result_type = op.getType();
 
-    // Optional predicate: default to null
+    // Optional predicate: default to null.
     rewriter.replaceOpWithNewOp<neura::FSubOp>(op, result_type, lhs, rhs,
                                                nullptr);
     return success();
@@ -118,7 +117,7 @@ struct ArithMulFToNeuraFMul : public OpRewritePattern<mlir::arith::MulFOp> {
     Value rhs = op.getRhs();
     Type result_type = op.getType();
 
-    // Optional predicate: default to null
+    // Optional predicate: default to null.
     rewriter.replaceOpWithNewOp<neura::FMulOp>(op, result_type, lhs, rhs,
                                                nullptr);
     return success();
@@ -134,7 +133,7 @@ struct ArithFDivToNeuraFDiv : public OpRewritePattern<mlir::arith::DivFOp> {
     Value rhs = op.getRhs();
     Type result_type = op.getType();
 
-    // Optional predicate: default to null
+    // Optional predicate: default to null.
     rewriter.replaceOpWithNewOp<neura::FDivOp>(op, result_type, lhs, rhs,
                                                nullptr);
     return success();
@@ -185,8 +184,8 @@ struct ArithCmpiToNeuraICmp : public OpRewritePattern<mlir::arith::CmpIOp> {
       return rewriter.notifyMatchFailure(op, "Unsupported arith CmpIOp type");
     }
 
-    // Convert arith CmpIOp to Neura ICmpOp
-    // Optional predicate: default to null
+    // Converts arith CmpIOp to Neura ICmpOp.
+    // Optional predicate: default to null.
     rewriter.replaceOpWithNewOp<neura::ICmpOp>(
         op, result_type, lhs, rhs, nullptr, rewriter.getStringAttr(cmp_type));
     return success();
@@ -203,7 +202,7 @@ struct ArithSelectToNeuraSel : public OpRewritePattern<mlir::arith::SelectOp> {
     Value false_value = op.getFalseValue();
     Type result_type = op.getType();
 
-    // Convert arith SelectOp to Neura SelOp
+    // Converts arith SelectOp to Neura SelOp.
     rewriter.replaceOpWithNewOp<neura::SelOp>(op, result_type, true_value,
                                               false_value, condition);
     return success();
@@ -218,8 +217,8 @@ struct ArithExtUIToNeuraCast : public OpRewritePattern<mlir::arith::ExtUIOp> {
     Value input = op.getIn();
     Type result_type = op.getType();
 
-    // Convert arith ExtUIOp to Neura cast operation
-    // Optional predicate: default to null
+    // Converts arith ExtUIOp to Neura cast operation.
+    // Optional predicate: default to null.
     rewriter.replaceOpWithNewOp<neura::CastOp>(
         op, result_type, input, rewriter.getStringAttr("extui"), nullptr);
     return success();
@@ -234,8 +233,8 @@ struct ArithExtfToNeuraCast : public OpRewritePattern<mlir::arith::ExtFOp> {
     Value input = op.getIn();
     Type result_type = op.getType();
 
-    // Convert arith ExtFOp to Neura cast operation
-    // Optional predicate: default to null
+    // Converts arith ExtFOp to Neura cast operation.
+    // Optional predicate: default to null.
     rewriter.replaceOpWithNewOp<neura::CastOp>(
         op, result_type, input, rewriter.getStringAttr("extf"), nullptr);
     return success();
@@ -250,11 +249,23 @@ struct ArithIndexCastToNeuraCast
                                 PatternRewriter &rewriter) const override {
     Value input = op.getIn();
     Type result_type = op.getType();
+    Type in_type = input.getType();
+    StringRef cast_string;
+
+    // The isa<IntegerType> check is generic and handles any integer bit width
+    // (e.g., i32, i64).
+    if (in_type.isIndex() && isa<IntegerType>(result_type)) {
+      cast_string = "index_to_int";
+    } else if (isa<IntegerType>(in_type) && result_type.isIndex()) {
+      cast_string = "int_to_index";
+    } else {
+      return rewriter.notifyMatchFailure(op, "index_cast");
+    }
 
-    // Convert arith IndexCastOp to Neura cast operation
-    // Optional predicate: default to null
+    // Converts arith IndexCastOp to Neura cast operation.
+    // Optional predicate: default to null.
     rewriter.replaceOpWithNewOp<neura::CastOp>(
-        op, result_type, input, rewriter.getStringAttr("indexCast"), nullptr);
+        op, result_type, input, rewriter.getStringAttr(cast_string), nullptr);
     return success();
   }
 };
@@ -274,16 +285,28 @@ struct LowerArithToNeuraPass
   }
 
   void runOnOperation() override {
-    RewritePatternSet patterns(&getContext());
-    mlir::neura::arith2neura::populateWithGenerated(patterns);
-    patterns
-        .add<ArithFAddToNeuraFAdd, ArithConstantToNeuraConstant,
-             ArithAddIToNeuraAdd, ArithCmpiToNeuraICmp, ArithSelectToNeuraSel,
-             ArithExtUIToNeuraCast, ArithIndexCastToNeuraCast,
-             ArithFDivToNeuraFDiv, ArithExtfToNeuraCast, ArithMulFToNeuraFMul, ArithSubIToNeuraSub, ArithSubFToNeuraFSub>(&getContext());
-    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
-      signalPassFailure();
-    }
+    ModuleOp module_op = getOperation();
+    MLIRContext *context = &getContext();
+    module_op.walk([&](func::FuncOp func_op) {
+      if (func_op->hasAttr(mlir::accel::kAcceleratorAttr)) {
+        auto target =
+            func_op->getAttrOfType<StringAttr>(mlir::accel::kAcceleratorAttr);
+        if (target && target.getValue() == mlir::accel::kNeuraTarget) {
+          RewritePatternSet patterns(&getContext());
+          mlir::neura::arith2neura::populateWithGenerated(patterns);
+          patterns.add<ArithFAddToNeuraFAdd, ArithConstantToNeuraConstant,
+                       ArithAddIToNeuraAdd, ArithCmpiToNeuraICmp,
+                       ArithSelectToNeuraSel, ArithExtUIToNeuraCast,
+                       ArithIndexCastToNeuraCast, ArithFDivToNeuraFDiv,
+                       ArithExtfToNeuraCast, ArithMulFToNeuraFMul,
+                       ArithSubIToNeuraSub, ArithSubFToNeuraFSub>(context);
+          if (failed(
+                  applyPatternsGreedily(getOperation(), std::move(patterns)))) {
+            signalPassFailure();
+          }
+        }
+      }
+    });
   }
 };
 } // namespace
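
Note that the rewritten runOnOperation only applies the patterns to functions tagged for the Neura target. A minimal sketch of that tagging is shown below, assuming the attribute key and value constants from Common/AcceleratorAttrs.h used in the pass; where the tagging actually happens (front end, an earlier pass, or hand-written test IR) is outside this diff.

  // Hedged sketch: mark a function so the accelerator-gated patterns fire.
  // The helper itself is hypothetical; the constants come from
  // Common/AcceleratorAttrs.h as referenced above.
  void tagFuncForNeura(mlir::func::FuncOp func_op, mlir::OpBuilder &builder) {
    func_op->setAttr(mlir::accel::kAcceleratorAttr,
                     builder.getStringAttr(mlir::accel::kNeuraTarget));
  }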
BuiltinToNeuraPass.cpp (new file)

Lines changed: 88 additions & 0 deletions
@@ -0,0 +1,88 @@
+#include "Common/AcceleratorAttrs.h"
+#include "Conversion/ConversionPasses.h"
+#include "NeuraDialect/NeuraDialect.h"
+#include "NeuraDialect/NeuraOps.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace mlir;
+using namespace mlir::neura;
+
+namespace {
+
+struct BuiltinUnrealizedConversionCastToNeuraCast
+    : public OpRewritePattern<mlir::UnrealizedConversionCastOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(mlir::UnrealizedConversionCastOp op,
+                                PatternRewriter &rewriter) const override {
+    // Only handles simple 1:1 casts.
+    // TODO: Handle more complex casts if needed.
+    if (op.getInputs().size() == 1 && op.getResults().size() == 1) {
+      Value input = op.getInputs()[0];
+      Type result_type = op.getResults()[0].getType();
+      Type input_type = input.getType();
+
+      StringRef cast_type;
+      if (input_type.isIndex() && isa<IntegerType>(result_type)) {
+        cast_type = "index_to_int";
+      } else if (isa<IntegerType>(input_type) && result_type.isIndex()) {
+        cast_type = "int_to_index";
+      } else {
+        return rewriter.notifyMatchFailure(op, "unsupported cast");
+      }
+
+      // Optional predicate: default to null.
+      rewriter.replaceOpWithNewOp<neura::CastOp>(
+          op, result_type, input, rewriter.getStringAttr(cast_type), nullptr);
+      return success();
+    }
+    return failure();
+  }
+};
+
+struct LowerBuiltinToNeuraPass
+    : public PassWrapper<LowerBuiltinToNeuraPass, OperationPass<ModuleOp>> {
+
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LowerBuiltinToNeuraPass)
+
+  StringRef getArgument() const override { return "lower-builtin-to-neura"; }
+  StringRef getDescription() const override {
+    return "Lower Builtin operations to Neura dialect operations";
+  }
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<mlir::neura::NeuraDialect>();
+  }
+
+  void runOnOperation() override {
+    ModuleOp module_op = getOperation();
+    MLIRContext *context = &getContext();
+    RewritePatternSet patterns(&getContext());
+    patterns.add<BuiltinUnrealizedConversionCastToNeuraCast>(context);
+    module_op.walk([&](func::FuncOp func_op) {
+      if (func_op->hasAttr(mlir::accel::kAcceleratorAttr)) {
+        auto target =
+            func_op->getAttrOfType<StringAttr>(mlir::accel::kAcceleratorAttr);
+        if (target && target.getValue() == mlir::accel::kNeuraTarget) {
+          if (failed(applyPatternsGreedily(func_op, std::move(patterns)))) {
+            return signalPassFailure();
+          }
+        }
+      }
+    });
+  }
+};
+} // namespace
+
+std::unique_ptr<Pass> mlir::createLowerBuiltinToNeuraPass() {
+  return std::make_unique<LowerBuiltinToNeuraPass>();
+}
CMakeLists.txt for MLIRNeuraBuiltinToNeuraPass (new file)

Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
+include_directories(${CMAKE_CURRENT_BINARY_DIR})
+
+add_mlir_conversion_library(MLIRNeuraBuiltinToNeuraPass
+  BuiltinToNeuraPass.cpp
+
+  DEPENDS
+  MLIRConversionIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRArithDialect
+  MLIRFuncDialect
+  MLIRLLVMDialect
+  MLIRIR
+  MLIRPass
+  MLIRTransforms
+  MLIRNeura
+  MLIRSupport
+)
