Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions include/NeuraDialect/NeuraOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -210,11 +210,11 @@ def Neura_PhiOp : Op<NeuraDialect, "phi"> {
neura.ctrl_mov %next to %v // Connect next iteration
}];

let arguments = (ins AnyType:$init_val, AnyType:$loop_val);
let arguments = (ins Variadic<AnyType>:$inputs);
let results = (outs AnyType:$result);

// Explicitly specify types for operands in the assembly format
let assemblyFormat = "$init_val `:` type($init_val) `,` $loop_val `:` type($loop_val) attr-dict `:` type($result)";
// let assemblyFormat = "$init_val `:` type($init_val) `,` $loop_val `:` type($loop_val) attr-dict `,` type($result)";
}

// Control movement extending base move but with different signature.
Expand Down
73 changes: 22 additions & 51 deletions lib/NeuraDialect/Transforms/LeveragePredicatedValuePass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,51 +13,6 @@ using namespace mlir;
#include "NeuraDialect/NeuraPasses.h.inc"

namespace {
struct applyPredicatedDataType : public RewritePattern {
applyPredicatedDataType(MLIRContext *context)
: RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}

LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override {
llvm::errs() << "Processing op: " << *op << "\n";

// Skips if not a Neura op or already using predicated values.
if (op->getDialect()->getNamespace() != "neura") {
llvm::errs() << "Skipping non-Neura op\n";
return failure();
}

if (llvm::any_of(op->getResultTypes(),
[](Type t) { return mlir::isa<mlir::neura::PredicatedValue>(t); })) {
llvm::errs() << "Skipping already predicated op\n";
return failure();
}

// Converts result types to predicated form.
SmallVector<Type> newResults;
for (Type t : op->getResultTypes()) {
auto predicatedTy = mlir::neura::PredicatedValue::get(
op->getContext(),
t,
rewriter.getI1Type());
newResults.push_back(predicatedTy);
}

// Clones the operation with new result types.
OperationState state(op->getLoc(), op->getName());
state.addOperands(op->getOperands());
state.addTypes(newResults);
state.addAttributes(op->getAttrs());
Operation *newOp = rewriter.create(state);

// Replaces the old op with the new one.
rewriter.replaceOp(op, newOp->getResults());
llvm::errs() << "Converted op to predicated form: " << *newOp << "\n";
if (!newResults.empty()) {
assert(false);
}
return success();
}
};

struct LeveragePredicatedValuePass
: public PassWrapper<LeveragePredicatedValuePass, OperationPass<ModuleOp>> {
Expand All @@ -77,7 +32,26 @@ struct LeveragePredicatedValuePass

// Processes each function.
module.walk([&](func::FuncOp func) {
// Get operations in topological order (operands before users)
// Converts block argument types to predicated values.
func.walk([&](Block *block) {
// skips the entry (first) block of the function.
if (block == &block->getParent()->front()) {
return;
}
for (BlockArgument arg : block->getArguments()) {
Type origType = arg.getType();

// Avoid double-wrapping if already predicated
if (llvm::isa<neura::PredicatedValue>(origType))
continue;

auto predicated_type = neura::PredicatedValue::get(
func.getContext(), origType, IntegerType::get(func.getContext(), 1));
arg.setType(predicated_type);
}
});

// Gets operations in topological order (operands before users).
SmallVector<Operation*> orderedOps;
getOperationsInTopologicalOrder(func, orderedOps);

Expand Down Expand Up @@ -122,11 +96,8 @@ struct LeveragePredicatedValuePass

// Converts a single operation to use predicated values.
LogicalResult applyPredicatedDataType(Operation *op) {
llvm::errs() << "Processing op: " << *op << "\n";

// Skips if not a Neura op.
if (op->getDialect()->getNamespace() != "neura") {
llvm::errs() << "Skipping non-Neura op\n";
return success();
}

Expand All @@ -141,11 +112,11 @@ struct LeveragePredicatedValuePass
OpBuilder builder(op);
SmallVector<Type> newResults;
for (Type t : op->getResultTypes()) {
auto predicatedTy = mlir::neura::PredicatedValue::get(
auto predicated_type = mlir::neura::PredicatedValue::get(
op->getContext(),
t,
builder.getI1Type());
newResults.push_back(predicatedTy);
newResults.push_back(predicated_type);
}

// Clones with new result types.
Expand Down
Loading