Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 61 additions & 34 deletions src/simplify_algebra.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1286,8 +1286,7 @@
/** Get argument index that has the split instruction for a group of instructions
* If instructions in a group have different split indexes, return -1.
*/
int get_binary_op_split_idx(std::vector<instruction_ref> group,
std::vector<instruction_ref> splits) const
int get_split_idx(std::vector<instruction_ref> group, std::vector<instruction_ref> splits) const
{
auto first_group_inputs = group.front()->inputs();
auto arg_it =
Expand Down Expand Up @@ -1361,7 +1360,7 @@
return false;
}

void apply(module& m, const match::matcher_result& r) const

Check warning on line 1363 in src/simplify_algebra.cpp

View workflow job for this annotation

GitHub Actions / tidy

function 'apply' exceeds recommended size/complexity thresholds [readability-function-size,-warnings-as-errors]
{
auto ins = r.result;
auto splits = get_splits(ins);
Expand Down Expand Up @@ -1393,7 +1392,7 @@
{
c = m.insert_instruction(std::next(ins), op, {ins}, start->module_inputs());
}
else if(start->inputs().size() == 2)
else if(start->inputs().size() >= 2)
{
assert(not std::none_of(start->inputs().begin(), start->inputs().end(), [](auto i) {
return i->name() == "slice";
Expand All @@ -1407,47 +1406,75 @@
if(not concat_const_foldable(group.begin(), group.end(), concat_axis))
return;

split_idx = get_binary_op_split_idx(group, splits);
assert(split_idx < 2);
size_t data_idx;
if(split_idx < 0 and op.attributes().contains("commutative"))
split_idx = get_split_idx(group, splits);
if(split_idx < 0)
{
split_idx = 0;
data_idx = 1;
align_commutative_op_args(m, group, splits, split_idx);
}
else if(split_idx < 0)
{
return;
// For binary commutative operations, try swapping arguments
if(start->inputs().size() == 2 and op.attributes().contains("commutative"))
{
split_idx = 0;
align_commutative_op_args(m, group, splits, split_idx);
}
else
{
return;
}
}
else
assert(split_idx >= 0 and split_idx < start->inputs().size());

// Collect all data arguments for each position (excluding split position)
std::vector<std::vector<instruction_ref>> all_data_args;
for(size_t arg_idx = 0; arg_idx < start->inputs().size(); ++arg_idx)
{
data_idx = split_idx == 0 ? 1 : 0;
}
if(arg_idx == split_idx)
continue;

std::vector<instruction_ref> data_args;
std::transform(group.begin(),
group.end(),
std::back_inserter(data_args),
[&](auto i) { return i->inputs()[data_idx]; });
std::vector<instruction_ref> data_args;
std::transform(group.begin(),
group.end(),
std::back_inserter(data_args),
[&](auto i) { return i->inputs()[arg_idx]; });

// Data arguments must be a constant
if(std::any_of(data_args.begin(), data_args.end(), [](auto i) {
return not i->can_eval();
}))
return;
// All data arguments must be constants
if(std::any_of(data_args.begin(), data_args.end(), [](auto i) {
return not i->can_eval();
}))
return;

move_instructions_back(m, ins, data_args);
all_data_args.push_back(std::move(data_args));
}

// TODO: Check if axises match
auto concat = m.insert_instruction(
ins, make_op("concat", {{"axis", concat_axis}}), data_args);
// Move all data arguments back to ensure they're available
for(const auto& data_args : all_data_args)
{
move_instructions_back(m, ins, data_args);
}

// Create concatenations for each data argument position
std::vector<instruction_ref> concat_args;
for(const auto& data_args : all_data_args)
{
// TODO: Check if axises match
auto concat = m.insert_instruction(
ins, make_op("concat", {{"axis", concat_axis}}), data_args);
concat_args.push_back(concat);
}

// Build the final argument list
std::vector<instruction_ref> args;
args.resize(2);
args.resize(start->inputs().size());
args[split_idx] = ins;
args[data_idx] = concat;
c = m.insert_instruction(std::next(ins), op, {args}, start->module_inputs());

size_t concat_idx = 0;
for(size_t arg_idx = 0; arg_idx < start->inputs().size(); ++arg_idx)
{
if(arg_idx != split_idx)
{
args[arg_idx] = concat_args[concat_idx++];
}
}

c = m.insert_instruction(std::next(ins), op, args, start->module_inputs());
}
if(c != m.end())
{
Expand Down
67 changes: 67 additions & 0 deletions test/simplify_algebra_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2585,6 +2585,73 @@ TEST_CASE(find_splits_for_pointwise2)
EXPECT(p1 == p2);
}

TEST_CASE(find_splits_for_multi_arg_ops)
{
auto s = migraphx::shape{migraphx::shape::float_type, {3, 2, 4}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto input = mm->add_parameter("input", s);
auto x = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
auto y = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), input);

// Create constants for the multi-argument operations
auto c1 = mm->add_literal(
migraphx::literal{migraphx::shape{migraphx::shape::float_type, {3, 1, 4}}, {1.0f}});
auto c2 = mm->add_literal(
migraphx::literal{migraphx::shape{migraphx::shape::float_type, {3, 1, 4}}, {2.0f}});
auto c3 = mm->add_literal(
migraphx::literal{migraphx::shape{migraphx::shape::float_type, {3, 1, 4}}, {3.0f}});
auto c4 = mm->add_literal(
migraphx::literal{migraphx::shape{migraphx::shape::float_type, {3, 1, 4}}, {4.0f}});

// Create multi-argument pointwise operations
auto multi_op_x =
add_pointwise(p1, "main:pointwise0", {x, c1, c2}, [=](auto* pm, const auto& inputs) {
auto add1 = pm->add_instruction(migraphx::make_op("add"), inputs[0], inputs[1]);
return pm->add_instruction(migraphx::make_op("mul"), add1, inputs[2]);
});
auto multi_op_y =
add_pointwise(p1, "main:pointwise1", {y, c3, c4}, [=](auto* pm, const auto& inputs) {
auto add1 = pm->add_instruction(migraphx::make_op("add"), inputs[0], inputs[1]);
return pm->add_instruction(migraphx::make_op("mul"), add1, inputs[2]);
});
mm->add_return({multi_op_x, multi_op_y});
}

// Store the original instruction count before transformation
auto original_pointwise_count =
std::count_if(p1.get_main_module()->begin(),
p1.get_main_module()->end(),
[](const auto& ins) { return ins.name() == "pointwise"; });

migraphx::run_passes(p1, {migraphx::simplify_algebra{}, migraphx::dead_code_elimination{}});

// After transformation, there should be only one pointwise operation
auto final_pointwise_count =
std::count_if(p1.get_main_module()->begin(),
p1.get_main_module()->end(),
[](const auto& ins) { return ins.name() == "pointwise"; });

// Check that the transformation reduced the number of pointwise operations
EXPECT(original_pointwise_count == 2);
EXPECT(final_pointwise_count == 1);

// Check that we have concat operations (indicating fusion happened)
auto concat_count = std::count_if(p1.get_main_module()->begin(),
p1.get_main_module()->end(),
[](const auto& ins) { return ins.name() == "concat"; });
EXPECT(concat_count == 2); // Should have 2 concat operations for the 2 constant arguments

// Check that we still have the slice operations after the fused pointwise
auto slice_count = std::count_if(p1.get_main_module()->begin(),
p1.get_main_module()->end(),
[](const auto& ins) { return ins.name() == "slice"; });
EXPECT(slice_count == 2); // Should have 2 slice operations for the results
}

TEST_CASE(simplify_slice_different_axis)
{
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4, 2}};
Expand Down
Loading