Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 39 additions & 40 deletions tests/transforms/test_convert_pdl_to_pdl_interp.py
Original file line number Diff line number Diff line change
Expand Up @@ -1791,7 +1791,7 @@ def test_get_value_at_operation_position():
# Manually add root to cache
generator.values[root_pos] = root_val

result = generator.get_value_at(block, root_pos)
result = generator.get_value_at(root_pos)
assert result is root_val

# Test case 2: Operand defining op position
Expand All @@ -1801,7 +1801,7 @@ def test_get_value_at_operation_position():
# First get the operand value
generator.values[operand_pos] = root_val # Mock operand value

result = generator.get_value_at(block, defining_op_pos)
result = generator.get_value_at(defining_op_pos)

# Should create GetDefiningOpOp
get_def_ops = [op for op in block.ops if isinstance(op, pdl_interp.GetDefiningOpOp)]
Expand Down Expand Up @@ -1833,7 +1833,7 @@ def test_get_value_at_operand_position():

# Get operand at index 2
operand_pos = root_pos.get_operand(2)
result = generator.get_value_at(block, operand_pos)
result = generator.get_value_at(operand_pos)

# Should create GetOperandOp with index 2
get_operand_ops = [
Expand Down Expand Up @@ -1868,7 +1868,7 @@ def test_get_value_at_result_position():

# Get result at index 1
result_pos = root_pos.get_result(1)
result = generator.get_value_at(block, result_pos)
result = generator.get_value_at(result_pos)

# Should create GetResultOp with index 1
get_result_ops = [op for op in block.ops if isinstance(op, pdl_interp.GetResultOp)]
Expand Down Expand Up @@ -1901,7 +1901,7 @@ def test_get_value_at_result_group_position():

# Test variadic result group
result_group_pos = root_pos.get_result_group(0, is_variadic=True)
result = generator.get_value_at(block, result_group_pos)
result = generator.get_value_at(result_group_pos)

# Should create GetResultsOp
get_results_ops = [
Expand All @@ -1915,7 +1915,7 @@ def test_get_value_at_result_group_position():

# Test non-variadic result group
result_group_pos2 = root_pos.get_result_group(1, is_variadic=False)
result2 = generator.get_value_at(block, result_group_pos2)
result2 = generator.get_value_at(result_group_pos2)

get_results_ops = [
op for op in block.ops if isinstance(op, pdl_interp.GetResultsOp)
Expand Down Expand Up @@ -1949,7 +1949,7 @@ def test_get_value_at_attribute_position():

# Get attribute named "test_attr"
attr_pos = root_pos.get_attribute("test_attr")
result = generator.get_value_at(block, attr_pos)
result = generator.get_value_at(attr_pos)

# Should create GetAttributeOp
get_attr_ops = [op for op in block.ops if isinstance(op, pdl_interp.GetAttributeOp)]
Expand Down Expand Up @@ -1983,7 +1983,7 @@ def test_get_value_at_attribute_literal_position():
const_attr = IntegerAttr(42, i32)
attr_literal_pos = AttributeLiteralPosition(value=const_attr, parent=None)

result = generator.get_value_at(block, attr_literal_pos)
result = generator.get_value_at(attr_literal_pos)

# Should create CreateAttributeOp
create_attr_ops = [
Expand Down Expand Up @@ -2021,7 +2021,7 @@ def test_get_value_at_type_position():
generator.values[result_pos] = result_val

type_pos = result_pos.get_type()
result = generator.get_value_at(block, type_pos)
result = generator.get_value_at(type_pos)

# Should create GetValueTypeOp
get_type_ops = [op for op in block.ops if isinstance(op, pdl_interp.GetValueTypeOp)]
Expand Down Expand Up @@ -2050,7 +2050,7 @@ def test_get_value_at_type_literal_position():

# Test case 1: Single type literal
type_literal_pos = TypeLiteralPosition.get_type_literal(value=i32)
result = generator.get_value_at(block, type_literal_pos)
result = generator.get_value_at(type_literal_pos)

# Should create CreateTypeOp
create_type_ops = [
Expand All @@ -2063,7 +2063,7 @@ def test_get_value_at_type_literal_position():
# Test case 2: Multiple types (ArrayAttr)
types_array = ArrayAttr([i32, f32])
types_literal_pos = TypeLiteralPosition.get_type_literal(value=types_array)
result2 = generator.get_value_at(block, types_literal_pos)
result2 = generator.get_value_at(types_literal_pos)

# Should create CreateTypesOp
create_types_ops = [
Expand Down Expand Up @@ -2118,7 +2118,7 @@ def test_get_value_at_constraint_position():

# Get constraint result at index 1
constraint_pos = ConstraintPosition.get_constraint(constraint_q, result_index=1)
result = generator.get_value_at(block, constraint_pos)
result = generator.get_value_at(constraint_pos)

# Should return the second result of the constraint op
assert result == constraint_op.results[1]
Expand Down Expand Up @@ -2147,8 +2147,8 @@ def test_get_value_at_caching():

# Get operand twice
operand_pos = root_pos.get_operand(0)
result1 = generator.get_value_at(block, operand_pos)
result2 = generator.get_value_at(block, operand_pos)
result1 = generator.get_value_at(operand_pos)
result2 = generator.get_value_at(operand_pos)

# Should return the same value (cached)
assert result1 is result2
Expand Down Expand Up @@ -2182,17 +2182,16 @@ def test_get_value_at_unimplemented_positions():

root_pos = OperationPosition(None, depth=0)
generator.values[root_pos] = matcher_func.body.block.args[0]
block = matcher_func.body.block

# Test UsersPosition
users_pos = UsersPosition(parent=root_pos, use_representative=True)
with pytest.raises(NotImplementedError, match="UsersPosition"):
generator.get_value_at(block, users_pos)
generator.get_value_at(users_pos)

# Test ForEachPosition
foreach_pos = ForEachPosition(parent=root_pos, id=0)
with pytest.raises(NotImplementedError, match="ForEachPosition"):
generator.get_value_at(block, foreach_pos)
generator.get_value_at(foreach_pos)


def test_get_value_at_operand_group_position():
Expand All @@ -2218,7 +2217,7 @@ def test_get_value_at_operand_group_position():

# Test variadic operand group
operand_group_pos = root_pos.get_operand_group(0, is_variadic=True)
result = generator.get_value_at(block, operand_group_pos)
result = generator.get_value_at(operand_group_pos)

# Should create GetOperandsOp
get_operands_ops = [
Expand All @@ -2232,7 +2231,7 @@ def test_get_value_at_operand_group_position():

# Test non-variadic operand group
operand_group_pos2 = root_pos.get_operand_group(1, is_variadic=False)
result2 = generator.get_value_at(block, operand_group_pos2)
result2 = generator.get_value_at(operand_group_pos2)

get_operands_ops = [
op for op in block.ops if isinstance(op, pdl_interp.GetOperandsOp)
Expand Down Expand Up @@ -2298,7 +2297,7 @@ def test_get_value_at_operation_position_passthrough():
op_pos_with_parent = OperationPosition(parent=constraint_pos, depth=1)

# Get the value - should hit the passthrough branch
result = generator.get_value_at(block, op_pos_with_parent)
result = generator.get_value_at(op_pos_with_parent)

# Should return the constraint's operation result (passthrough from parent)
assert result == constraint_op.results[0]
Expand All @@ -2308,7 +2307,7 @@ def test_get_value_at_operation_position_passthrough():
assert generator.values[op_pos_with_parent] is result

# Getting it again should return the cached value
result2 = generator.get_value_at(block, op_pos_with_parent)
result2 = generator.get_value_at(op_pos_with_parent)
assert result2 is result


Expand Down Expand Up @@ -2347,7 +2346,7 @@ def test_generate_bool_node_is_not_null():
bool_node = BoolNode(question=question, answer=answer)

# Generate the bool node
generator.generate_bool_node(bool_node, block, val)
generator.generate_bool_node(bool_node, val)

# Check that IsNotNullOp was created
check_ops = [op for op in block.ops if isinstance(op, pdl_interp.IsNotNullOp)]
Expand Down Expand Up @@ -2397,7 +2396,7 @@ def test_generate_bool_node_operation_name():
bool_node = BoolNode(question=question, answer=answer)

# Generate the bool node
generator.generate_bool_node(bool_node, block, val)
generator.generate_bool_node(bool_node, val)

# Check that CheckOperationNameOp was created
check_ops = [
Expand Down Expand Up @@ -2444,7 +2443,7 @@ def test_generate_bool_node_operand_count():
bool_node = BoolNode(question=question, answer=answer)

# Generate the bool node
generator.generate_bool_node(bool_node, block, val)
generator.generate_bool_node(bool_node, val)

# Check that CheckOperandCountOp was created
check_ops = [
Expand Down Expand Up @@ -2492,7 +2491,7 @@ def test_generate_bool_node_result_count_at_least():
bool_node = BoolNode(question=question, answer=answer)

# Generate the bool node
generator.generate_bool_node(bool_node, block, val)
generator.generate_bool_node(bool_node, val)

# Check that CheckResultCountOp was created
check_ops = [
Expand Down Expand Up @@ -2544,7 +2543,7 @@ def test_generate_bool_node_equal_to():
bool_node = BoolNode(question=question, answer=answer)

# Generate the bool node
generator.generate_bool_node(bool_node, block, val1)
generator.generate_bool_node(bool_node, val1)

# Check that AreEqualOp was created
check_ops = [op for op in block.ops if isinstance(op, pdl_interp.AreEqualOp)]
Expand Down Expand Up @@ -2590,7 +2589,7 @@ def test_generate_bool_node_attribute_constraint():
bool_node = BoolNode(question=question, answer=answer)

# Generate the bool node
generator.generate_bool_node(bool_node, block, val)
generator.generate_bool_node(bool_node, val)

# Check that CheckAttributeOp was created
check_ops = [op for op in block.ops if isinstance(op, pdl_interp.CheckAttributeOp)]
Expand Down Expand Up @@ -2635,7 +2634,7 @@ def test_generate_bool_node_type_constraint():
bool_node = BoolNode(question=question, answer=answer)

# Generate the bool node
generator.generate_bool_node(bool_node, block, val)
generator.generate_bool_node(bool_node, val)

# Check that CheckTypeOp was created
check_ops = [op for op in block.ops if isinstance(op, pdl_interp.CheckTypeOp)]
Expand Down Expand Up @@ -2683,7 +2682,7 @@ def test_generate_bool_node_native_constraint():
bool_node = BoolNode(question=question, answer=answer)

# Generate the bool node
generator.generate_bool_node(bool_node, block, val)
generator.generate_bool_node(bool_node, val)

# Check that ApplyConstraintOp was created
check_ops = [op for op in block.ops if isinstance(op, pdl_interp.ApplyConstraintOp)]
Expand Down Expand Up @@ -2734,7 +2733,7 @@ def test_generate_bool_node_operand_count_at_least():
bool_node = BoolNode(question=question, answer=answer)

# Generate the bool node
generator.generate_bool_node(bool_node, block, val)
generator.generate_bool_node(bool_node, val)

# Check that CheckOperandCountOp was created with compareAtLeast=True
check_ops = [
Expand Down Expand Up @@ -3086,7 +3085,7 @@ def test_generate_bool_node_with_success_node_calls_generate_matcher():
bool_node = BoolNode(question=question, answer=answer, success_node=success_node)

# Generate the bool node
generator.generate_bool_node(bool_node, block, val)
generator.generate_bool_node(bool_node, val)

# Check that IsNotNullOp was created
check_ops = [op for op in block.ops if isinstance(op, pdl_interp.IsNotNullOp)]
Expand Down Expand Up @@ -3149,7 +3148,7 @@ def test_generate_switch_node_operation_name():
mock_blocks = [Block(), Block(), Block()]
with patch.object(generator, "generate_matcher", side_effect=mock_blocks):
# Generate the switch node
generator.generate_switch_node(switch_node, block, val)
generator.generate_switch_node(switch_node, val)

# Check that SwitchOperationNameOp was created
switch_ops = [
Expand Down Expand Up @@ -3219,7 +3218,7 @@ def test_generate_switch_node_attribute_constraint():
mock_blocks = [Block(), Block(), Block()]
with patch.object(generator, "generate_matcher", side_effect=mock_blocks):
# Generate the switch node
generator.generate_switch_node(switch_node, block, val)
generator.generate_switch_node(switch_node, val)

# Check that SwitchAttributeOp was created
switch_ops = [
Expand Down Expand Up @@ -3289,7 +3288,7 @@ def test_generate_switch_node_with_none_child():
mock_blocks = [Block(), Block()]
with patch.object(generator, "generate_matcher", side_effect=mock_blocks):
# Generate the switch node
generator.generate_switch_node(switch_node, block, val)
generator.generate_switch_node(switch_node, val)

# Check that SwitchOperationNameOp was created
switch_ops = [
Expand Down Expand Up @@ -3339,7 +3338,7 @@ def test_generate_switch_node_empty_children():
switch_node = SwitchNode(question=question, children=children)

# Generate the switch node
generator.generate_switch_node(switch_node, block, val)
generator.generate_switch_node(switch_node, val)

# Check that SwitchOperationNameOp was created even with empty cases
switch_ops = [
Expand Down Expand Up @@ -3401,7 +3400,7 @@ def test_generate_switch_node_operand_count_not_implemented():
mock_blocks = [Block(), Block()]
with patch.object(generator, "generate_matcher", side_effect=mock_blocks):
# Generate the switch node
generator.generate_switch_node(switch_node, block, val)
generator.generate_switch_node(switch_node, val)

# Check that SwitchOperandCountOp was created
switch_ops = [
Expand Down Expand Up @@ -3467,7 +3466,7 @@ def test_generate_switch_node_result_count_not_implemented():
mock_blocks = [Block(), Block()]
with patch.object(generator, "generate_matcher", side_effect=mock_blocks):
# Generate the switch node
generator.generate_switch_node(switch_node, block, val)
generator.generate_switch_node(switch_node, val)

# Check that SwitchResultCountOp was created
switch_ops = [
Expand Down Expand Up @@ -3533,7 +3532,7 @@ def test_generate_switch_node_type_constraint_not_implemented():
mock_blocks = [Block(), Block()]
with patch.object(generator, "generate_matcher", side_effect=mock_blocks):
# Generate the switch node
generator.generate_switch_node(switch_node, block, val)
generator.generate_switch_node(switch_node, val)

# Check that SwitchTypeOp was created (val.type is pdl.TypeType, not RangeType)
switch_ops = [op for op in block.ops if isinstance(op, pdl_interp.SwitchTypeOp)]
Expand Down Expand Up @@ -3597,7 +3596,7 @@ def test_generate_switch_node_unhandled_question():
with patch.object(generator, "generate_matcher", side_effect=mock_blocks):
# Should raise NotImplementedError
with pytest.raises(NotImplementedError, match="Unhandled question type"):
generator.generate_switch_node(switch_node, block, val)
generator.generate_switch_node(switch_node, val)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -3641,7 +3640,7 @@ def test_generate_switch_node_at_least_question(
switch_node = SwitchNode(question=question, children=children)

# 3. Call the method under test
generator.generate_switch_node(switch_node, block, val)
generator.generate_switch_node(switch_node, val)

# 4. Verify the generated IR
# The logic creates a chain starting with the LOWEST count (1)
Expand Down
Loading