Skip to content

Commit 814aa13

Browse files
committed
[fix] fix the logic for no-value cond_br edges
1 parent 78de86a commit 814aa13

File tree

4 files changed

+110
-80
lines changed

4 files changed

+110
-80
lines changed

lib/NeuraDialect/Transforms/TransformCtrlToDataFlowPass.cpp

Lines changed: 57 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -264,25 +264,27 @@ void createReserveAndPhiOps(func::FuncOp &func, ControlFlowInfo &ctrl_info,
264264
DenseMap<BlockArgument, SmallVector<Value>> arg_to_phi_operands;
265265

266266
for (auto &edge : ctrl_info.all_edges) {
267-
Block *source = edge->source;
268267
Block *target = edge->target;
269268

270269
// Type 1 & 2: Backward cond_br/br edges with values.
271270
if (edge->is_back_edge && !edge->passed_values.empty()) {
272271
if (edge->passed_values.size() != target->getNumArguments()) {
273-
llvm::errs() << "[ctrl2data] Error: Number of passed values does not match "
274-
"target block arguments.\n";
272+
llvm::errs()
273+
<< "[ctrl2data] Error: Number of passed values does not match "
274+
"target block arguments.\n";
275275
assert(false);
276276
}
277277
for (BlockArgument arg : target->getArguments()) {
278278
backward_value_edges[arg].push_back(edge.get());
279279
}
280280
}
281281
// Type 3 & 4: Forward cond_br/br edges with values.
282-
else if (!edge->is_back_edge && !edge->passed_values.empty()) {;
282+
else if (!edge->is_back_edge && !edge->passed_values.empty()) {
283+
;
283284
if (edge->passed_values.size() != target->getNumArguments()) {
284-
llvm::errs() << "[ctrl2data] Error: Number of passed values does not match "
285-
"target block arguments.\n";
285+
llvm::errs()
286+
<< "[ctrl2data] Error: Number of passed values does not match "
287+
"target block arguments.\n";
286288
assert(false);
287289
}
288290
for (BlockArgument arg : target->getArguments()) {
@@ -373,15 +375,14 @@ void createReserveAndPhiOps(func::FuncOp &func, ControlFlowInfo &ctrl_info,
373375
// No need to create a phi operation if there's only one operand.
374376

375377
if (phi_operands.size() == 1) {
376-
llvm::errs() << phi_operands[0] << "\n";
377378
arg_to_phi_result[arg] = phi_operands[0];
378379
arg.replaceAllUsesWith(phi_operands[0]);
379380
}
380381
continue;
381382
}
382383

383-
384-
// Handles the blcockargument with/without reserve seperately (different insertion points).
384+
// Handles the block arguments with/without reserve separately (different
385+
// insertion points).
385386
if (arg_to_reserve.count(arg)) {
386387
Value reserve_value = arg_to_reserve[arg];
387388
builder.setInsertionPointAfter(reserve_value.getDefiningOp());
@@ -403,6 +404,7 @@ void createReserveAndPhiOps(func::FuncOp &func, ControlFlowInfo &ctrl_info,
403404
}
404405
}
405406

407+
llvm::errs() << func << "\n";
406408
// ================================================
407409
// Step 5: Handles Forward cond_br edges without values.
408410
// ================================================
@@ -423,34 +425,61 @@ void createReserveAndPhiOps(func::FuncOp &func, ControlFlowInfo &ctrl_info,
423425
conditions.push_back(condition);
424426
}
425427

426-
427428
// Unsupported case: multiple conditions for a single block.
428429
// TODO: Adds support if needed.
429430
if (conditions.size() > 1) {
430-
llvm::errs()
431-
<< "[ctrl2data] Unsupported case: multiple conditions for a single block: "
432-
<< *target << "\n";
431+
llvm::errs() << "[ctrl2data] Unsupported case: multiple conditions for a "
432+
"single block: "
433+
<< *target << "\n";
433434
assert(false);
434435
}
435436

436437
if (target->getArguments().empty()) {
438+
// Grants predicate for all the live-in values in the target block.
439+
DenseSet<Value> live_in_values;
437440
for (Operation &op : target->getOperations()) {
438-
if (op.hasTrait<OpTrait::IsTerminator>() ||
439-
isa<neura::PhiOp, neura::ReserveOp, neura::CtrlMovOp,
440-
neura::GrantPredicateOp, neura::GrantAlwaysOp,
441-
neura::GrantOnceOp, neura::NotOp>(op)) {
442-
continue;
441+
for (Value operand : op.getOperands()) {
442+
if (operand.getDefiningOp() &&
443+
operand.getDefiningOp()->getBlock() != target &&
444+
!isa<neura::ReserveOp>(operand.getDefiningOp())) {
445+
live_in_values.insert(operand);
446+
}
443447
}
448+
}
444449

445-
builder.setInsertionPointAfter(&op);
450+
// Applies grant_predicate to each live-in value.
451+
for (Value live_in_value : live_in_values) {
452+
// Finds the earliest use of the live-in value.
453+
Operation *earliest_use = nullptr;
454+
for (Operation &op : target->getOperations()) {
455+
for (Value operand : op.getOperands()) {
456+
if (operand == live_in_value) {
457+
earliest_use = &op;
458+
break;
459+
}
460+
}
461+
if (earliest_use) {
462+
break;
463+
}
464+
}
465+
466+
if (earliest_use) {
467+
builder.setInsertionPoint(earliest_use);
468+
} else {
469+
builder.setInsertionPointToStart(target);
470+
}
446471

447-
for (Value result : op.getResults()) {
448-
if (isa<neura::PredicatedValue>(result.getType())) {
449-
// Creates a predicated value for the operation result.
450-
Value predicated_value = builder.create<neura::GrantPredicateOp>(
451-
op.getLoc(), result.getType(), result, conditions.front());
452-
result.replaceAllUsesExcept(predicated_value,
453-
predicated_value.getDefiningOp());
472+
// Creates a predicated version of the live-in value.
473+
Value predicated_value = builder.create<neura::GrantPredicateOp>(
474+
live_in_value.getLoc(), live_in_value.getType(), live_in_value,
475+
conditions[0]);
476+
477+
// Replaces uses of the live-in value within this block only.
478+
for (OpOperand &use :
479+
llvm::make_early_inc_range(live_in_value.getUses())) {
480+
if (use.getOwner()->getBlock() == target &&
481+
use.getOwner() != predicated_value.getDefiningOp()) {
482+
use.set(predicated_value);
454483
}
455484
}
456485
}
@@ -494,7 +523,8 @@ void transformControlFlowToDataFlow(func::FuncOp &func,
494523
}
495524
}
496525

497-
// Moves all operations from blocks to the entry block before the terminator.
526+
// Moves all operations from blocks to the entry block before the
527+
// terminator.
498528
for (Block *block : blocks_to_flatten) {
499529
auto &ops = block->getOperations();
500530
while (!ops.empty()) {

test/affine2neura/bert/bert_node1/bert_node1.mlir

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -59,25 +59,26 @@ module attributes {} {
5959
// CTRL2DATA-NEXT: %10 = "neura.cast"(%9) <{cast_type = "int_to_index"}> : (!neura.data<i64, i1>) -> !neura.data<index, i1>
6060
// CTRL2DATA-NEXT: %11 = "neura.icmp"(%10, %3) <{cmpType = "slt"}> : (!neura.data<index, i1>, !neura.data<index, i1>) -> !neura.data<i1, i1>
6161
// CTRL2DATA-NEXT: %12 = "neura.not"(%11) : (!neura.data<i1, i1>) -> !neura.data<i1, i1>
62-
// CTRL2DATA-NEXT: %13 = "neura.cast"(%5) <{cast_type = "index_to_int"}> : (!neura.data<index, i1>) -> !neura.data<i64, i1>
63-
// CTRL2DATA-NEXT: %14 = neura.grant_predicate %13, %11 : !neura.data<i64, i1>, !neura.data<i1, i1> -> !neura.data<i64, i1>
62+
// CTRL2DATA-NEXT: %13 = neura.grant_predicate %5, %11 : !neura.data<index, i1>, !neura.data<i1, i1> -> !neura.data<index, i1>
63+
// CTRL2DATA-NEXT: %14 = "neura.cast"(%13) <{cast_type = "index_to_int"}> : (!neura.data<index, i1>) -> !neura.data<i64, i1>
6464
// CTRL2DATA-NEXT: %15 = neura.reserve : !neura.data<i64, i1>
6565
// CTRL2DATA-NEXT: %16 = "neura.phi"(%15, %14) : (!neura.data<i64, i1>, !neura.data<i64, i1>) -> !neura.data<i64, i1>
6666
// CTRL2DATA-NEXT: %17 = "neura.cast"(%16) <{cast_type = "int_to_index"}> : (!neura.data<i64, i1>) -> !neura.data<index, i1>
6767
// CTRL2DATA-NEXT: %18 = "neura.icmp"(%17, %3) <{cmpType = "slt"}> : (!neura.data<index, i1>, !neura.data<index, i1>) -> !neura.data<i1, i1>
6868
// CTRL2DATA-NEXT: %19 = "neura.not"(%18) : (!neura.data<i1, i1>) -> !neura.data<i1, i1>
69-
// CTRL2DATA-NEXT: %20 = neura.load_indexed %arg0[%5, %5, %5, %5, %5, %17 : !neura.data<index, i1>, !neura.data<index, i1>, !neura.data<index, i1>, !neura.data<index, i1>, !neura.data<index, i1>, !neura.data<index, i1>] memref<?x1x1x1x1x128xi8> : !neura.data<i8, i1>
70-
// CTRL2DATA-NEXT: %21 = neura.grant_predicate %20, %18 : !neura.data<i8, i1>, !neura.data<i1, i1> -> !neura.data<i8, i1>
71-
// CTRL2DATA-NEXT: neura.store_indexed %21 to %arg1[%5, %5, %10, %5, %5, %17 : !neura.data<index, i1>, !neura.data<index, i1>, !neura.data<index, i1>, !neura.data<index, i1>, !neura.data<index, i1>, !neura.data<index, i1>] memref<?x1x128x1x1x128xi8> : !neura.data<i8, i1>
72-
// CTRL2DATA-NEXT: %22 = "neura.add"(%17, %1) : (!neura.data<index, i1>, !neura.data<index, i1>) -> !neura.data<index, i1>
73-
// CTRL2DATA-NEXT: %23 = neura.grant_predicate %22, %18 : !neura.data<index, i1>, !neura.data<i1, i1> -> !neura.data<index, i1>
74-
// CTRL2DATA-NEXT: %24 = "neura.cast"(%23) <{cast_type = "index_to_int"}> : (!neura.data<index, i1>) -> !neura.data<i64, i1>
75-
// CTRL2DATA-NEXT: %25 = neura.grant_predicate %24, %18 : !neura.data<i64, i1>, !neura.data<i1, i1> -> !neura.data<i64, i1>
76-
// CTRL2DATA-NEXT: neura.ctrl_mov %25 -> %15 : !neura.data<i64, i1> !neura.data<i64, i1>
77-
// CTRL2DATA-NEXT: %26 = "neura.add"(%10, %1) : (!neura.data<index, i1>, !neura.data<index, i1>) -> !neura.data<index, i1>
78-
// CTRL2DATA-NEXT: %27 = neura.grant_predicate %26, %19 : !neura.data<index, i1>, !neura.data<i1, i1> -> !neura.data<index, i1>
79-
// CTRL2DATA-NEXT: %28 = "neura.cast"(%27) <{cast_type = "index_to_int"}> : (!neura.data<index, i1>) -> !neura.data<i64, i1>
80-
// CTRL2DATA-NEXT: %29 = neura.grant_predicate %28, %19 : !neura.data<i64, i1>, !neura.data<i1, i1> -> !neura.data<i64, i1>
81-
// CTRL2DATA-NEXT: neura.ctrl_mov %29 -> %8 : !neura.data<i64, i1> !neura.data<i64, i1>
69+
// CTRL2DATA-NEXT: %20 = neura.grant_predicate %{{[0-9]+}}, %18 : !neura.data<index, i1>, !neura.data<i1, i1> -> !neura.data<index, i1>
70+
// CTRL2DATA-NEXT: %21 = neura.grant_predicate %{{[0-9]+}}, %18 : !neura.data<index, i1>, !neura.data<i1, i1> -> !neura.data<index, i1>
71+
// CTRL2DATA-NEXT: %22 = neura.load_indexed %arg0[%{{[0-9]+}}, %{{[0-9]+}}, %{{[0-9]+}}, %{{[0-9]+}}, %{{[0-9]+}}, %{{[0-9]+}} : !neura.data<index, i1>, !neura.data<index, i1>, !neura.data<index, i1>, !neura.data<index, i1>, !neura.data<index, i1>, !neura.data<index, i1>] memref<?x1x1x1x1x128xi8> : !neura.data<i8, i1>
72+
// CTRL2DATA-NEXT: %23 = neura.grant_predicate %10, %18 : !neura.data<index, i1>, !neura.data<i1, i1> -> !neura.data<index, i1>
73+
// CTRL2DATA-NEXT: neura.store_indexed %22 to %arg1[%{{[0-9]+}}, %{{[0-9]+}}, %23, %{{[0-9]+}}, %{{[0-9]+}}, %{{[0-9]+}} : !neura.data<index, i1>, !neura.data<index, i1>, !neura.data<index, i1>, !neura.data<index, i1>, !neura.data<index, i1>, !neura.data<index, i1>] memref<?x1x128x1x1x128xi8> : !neura.data<i8, i1>
74+
// CTRL2DATA-NEXT: %24 = neura.grant_predicate %1, %18 : !neura.data<index, i1>, !neura.data<i1, i1> -> !neura.data<index, i1>
75+
// CTRL2DATA-NEXT: %25 = "neura.add"(%{{[0-9]+}}, %24) : (!neura.data<index, i1>, !neura.data<index, i1>) -> !neura.data<index, i1>
76+
// CTRL2DATA-NEXT: %26 = "neura.cast"(%25) <{cast_type = "index_to_int"}> : (!neura.data<index, i1>) -> !neura.data<i64, i1>
77+
// CTRL2DATA-NEXT: neura.ctrl_mov %26 -> %15 : !neura.data<i64, i1> !neura.data<i64, i1>
78+
// CTRL2DATA-NEXT: %27 = neura.grant_predicate %{{[0-9]+}}, %19 : !neura.data<index, i1>, !neura.data<i1, i1> -> !neura.data<index, i1>
79+
// CTRL2DATA-NEXT: %28 = neura.grant_predicate %{{[0-9]+}}, %19 : !neura.data<index, i1>, !neura.data<i1, i1> -> !neura.data<index, i1>
80+
// CTRL2DATA-NEXT: %29 = "neura.add"(%{{[0-9]+}}, %{{[0-9]+}}) : (!neura.data<index, i1>, !neura.data<index, i1>) -> !neura.data<index, i1>
81+
// CTRL2DATA-NEXT: %30 = "neura.cast"(%29) <{cast_type = "index_to_int"}> : (!neura.data<index, i1>) -> !neura.data<i64, i1>
82+
// CTRL2DATA-NEXT: neura.ctrl_mov %30 -> %8 : !neura.data<i64, i1> !neura.data<i64, i1>
8283
// CTRL2DATA-NEXT: "neura.return"() : () -> ()
8384
// CTRL2DATA-NEXT: }

0 commit comments

Comments
 (0)