Skip to content

Commit d3efc8b

Browse files
authored
transformations: (convert-pdl-to-pdl-interp) Make get_value_at use implicit insertion block (#5591)
1 parent 8355dc2 commit d3efc8b

File tree

2 files changed

+91
-72
lines changed

2 files changed

+91
-72
lines changed

tests/transforms/test_convert_pdl_to_pdl_interp.py

Lines changed: 39 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1791,7 +1791,7 @@ def test_get_value_at_operation_position():
17911791
# Manually add root to cache
17921792
generator.values[root_pos] = root_val
17931793

1794-
result = generator.get_value_at(block, root_pos)
1794+
result = generator.get_value_at(root_pos)
17951795
assert result is root_val
17961796

17971797
# Test case 2: Operand defining op position
@@ -1801,7 +1801,7 @@ def test_get_value_at_operation_position():
18011801
# First get the operand value
18021802
generator.values[operand_pos] = root_val # Mock operand value
18031803

1804-
result = generator.get_value_at(block, defining_op_pos)
1804+
result = generator.get_value_at(defining_op_pos)
18051805

18061806
# Should create GetDefiningOpOp
18071807
get_def_ops = [op for op in block.ops if isinstance(op, pdl_interp.GetDefiningOpOp)]
@@ -1833,7 +1833,7 @@ def test_get_value_at_operand_position():
18331833

18341834
# Get operand at index 2
18351835
operand_pos = root_pos.get_operand(2)
1836-
result = generator.get_value_at(block, operand_pos)
1836+
result = generator.get_value_at(operand_pos)
18371837

18381838
# Should create GetOperandOp with index 2
18391839
get_operand_ops = [
@@ -1868,7 +1868,7 @@ def test_get_value_at_result_position():
18681868

18691869
# Get result at index 1
18701870
result_pos = root_pos.get_result(1)
1871-
result = generator.get_value_at(block, result_pos)
1871+
result = generator.get_value_at(result_pos)
18721872

18731873
# Should create GetResultOp with index 1
18741874
get_result_ops = [op for op in block.ops if isinstance(op, pdl_interp.GetResultOp)]
@@ -1901,7 +1901,7 @@ def test_get_value_at_result_group_position():
19011901

19021902
# Test variadic result group
19031903
result_group_pos = root_pos.get_result_group(0, is_variadic=True)
1904-
result = generator.get_value_at(block, result_group_pos)
1904+
result = generator.get_value_at(result_group_pos)
19051905

19061906
# Should create GetResultsOp
19071907
get_results_ops = [
@@ -1915,7 +1915,7 @@ def test_get_value_at_result_group_position():
19151915

19161916
# Test non-variadic result group
19171917
result_group_pos2 = root_pos.get_result_group(1, is_variadic=False)
1918-
result2 = generator.get_value_at(block, result_group_pos2)
1918+
result2 = generator.get_value_at(result_group_pos2)
19191919

19201920
get_results_ops = [
19211921
op for op in block.ops if isinstance(op, pdl_interp.GetResultsOp)
@@ -1949,7 +1949,7 @@ def test_get_value_at_attribute_position():
19491949

19501950
# Get attribute named "test_attr"
19511951
attr_pos = root_pos.get_attribute("test_attr")
1952-
result = generator.get_value_at(block, attr_pos)
1952+
result = generator.get_value_at(attr_pos)
19531953

19541954
# Should create GetAttributeOp
19551955
get_attr_ops = [op for op in block.ops if isinstance(op, pdl_interp.GetAttributeOp)]
@@ -1983,7 +1983,7 @@ def test_get_value_at_attribute_literal_position():
19831983
const_attr = IntegerAttr(42, i32)
19841984
attr_literal_pos = AttributeLiteralPosition(value=const_attr, parent=None)
19851985

1986-
result = generator.get_value_at(block, attr_literal_pos)
1986+
result = generator.get_value_at(attr_literal_pos)
19871987

19881988
# Should create CreateAttributeOp
19891989
create_attr_ops = [
@@ -2021,7 +2021,7 @@ def test_get_value_at_type_position():
20212021
generator.values[result_pos] = result_val
20222022

20232023
type_pos = result_pos.get_type()
2024-
result = generator.get_value_at(block, type_pos)
2024+
result = generator.get_value_at(type_pos)
20252025

20262026
# Should create GetValueTypeOp
20272027
get_type_ops = [op for op in block.ops if isinstance(op, pdl_interp.GetValueTypeOp)]
@@ -2050,7 +2050,7 @@ def test_get_value_at_type_literal_position():
20502050

20512051
# Test case 1: Single type literal
20522052
type_literal_pos = TypeLiteralPosition.get_type_literal(value=i32)
2053-
result = generator.get_value_at(block, type_literal_pos)
2053+
result = generator.get_value_at(type_literal_pos)
20542054

20552055
# Should create CreateTypeOp
20562056
create_type_ops = [
@@ -2063,7 +2063,7 @@ def test_get_value_at_type_literal_position():
20632063
# Test case 2: Multiple types (ArrayAttr)
20642064
types_array = ArrayAttr([i32, f32])
20652065
types_literal_pos = TypeLiteralPosition.get_type_literal(value=types_array)
2066-
result2 = generator.get_value_at(block, types_literal_pos)
2066+
result2 = generator.get_value_at(types_literal_pos)
20672067

20682068
# Should create CreateTypesOp
20692069
create_types_ops = [
@@ -2118,7 +2118,7 @@ def test_get_value_at_constraint_position():
21182118

21192119
# Get constraint result at index 1
21202120
constraint_pos = ConstraintPosition.get_constraint(constraint_q, result_index=1)
2121-
result = generator.get_value_at(block, constraint_pos)
2121+
result = generator.get_value_at(constraint_pos)
21222122

21232123
# Should return the second result of the constraint op
21242124
assert result == constraint_op.results[1]
@@ -2147,8 +2147,8 @@ def test_get_value_at_caching():
21472147

21482148
# Get operand twice
21492149
operand_pos = root_pos.get_operand(0)
2150-
result1 = generator.get_value_at(block, operand_pos)
2151-
result2 = generator.get_value_at(block, operand_pos)
2150+
result1 = generator.get_value_at(operand_pos)
2151+
result2 = generator.get_value_at(operand_pos)
21522152

21532153
# Should return the same value (cached)
21542154
assert result1 is result2
@@ -2182,17 +2182,16 @@ def test_get_value_at_unimplemented_positions():
21822182

21832183
root_pos = OperationPosition(None, depth=0)
21842184
generator.values[root_pos] = matcher_func.body.block.args[0]
2185-
block = matcher_func.body.block
21862185

21872186
# Test UsersPosition
21882187
users_pos = UsersPosition(parent=root_pos, use_representative=True)
21892188
with pytest.raises(NotImplementedError, match="UsersPosition"):
2190-
generator.get_value_at(block, users_pos)
2189+
generator.get_value_at(users_pos)
21912190

21922191
# Test ForEachPosition
21932192
foreach_pos = ForEachPosition(parent=root_pos, id=0)
21942193
with pytest.raises(NotImplementedError, match="ForEachPosition"):
2195-
generator.get_value_at(block, foreach_pos)
2194+
generator.get_value_at(foreach_pos)
21962195

21972196

21982197
def test_get_value_at_operand_group_position():
@@ -2218,7 +2217,7 @@ def test_get_value_at_operand_group_position():
22182217

22192218
# Test variadic operand group
22202219
operand_group_pos = root_pos.get_operand_group(0, is_variadic=True)
2221-
result = generator.get_value_at(block, operand_group_pos)
2220+
result = generator.get_value_at(operand_group_pos)
22222221

22232222
# Should create GetOperandsOp
22242223
get_operands_ops = [
@@ -2232,7 +2231,7 @@ def test_get_value_at_operand_group_position():
22322231

22332232
# Test non-variadic operand group
22342233
operand_group_pos2 = root_pos.get_operand_group(1, is_variadic=False)
2235-
result2 = generator.get_value_at(block, operand_group_pos2)
2234+
result2 = generator.get_value_at(operand_group_pos2)
22362235

22372236
get_operands_ops = [
22382237
op for op in block.ops if isinstance(op, pdl_interp.GetOperandsOp)
@@ -2298,7 +2297,7 @@ def test_get_value_at_operation_position_passthrough():
22982297
op_pos_with_parent = OperationPosition(parent=constraint_pos, depth=1)
22992298

23002299
# Get the value - should hit the passthrough branch
2301-
result = generator.get_value_at(block, op_pos_with_parent)
2300+
result = generator.get_value_at(op_pos_with_parent)
23022301

23032302
# Should return the constraint's operation result (passthrough from parent)
23042303
assert result == constraint_op.results[0]
@@ -2308,7 +2307,7 @@ def test_get_value_at_operation_position_passthrough():
23082307
assert generator.values[op_pos_with_parent] is result
23092308

23102309
# Getting it again should return the cached value
2311-
result2 = generator.get_value_at(block, op_pos_with_parent)
2310+
result2 = generator.get_value_at(op_pos_with_parent)
23122311
assert result2 is result
23132312

23142313

@@ -2347,7 +2346,7 @@ def test_generate_bool_node_is_not_null():
23472346
bool_node = BoolNode(question=question, answer=answer)
23482347

23492348
# Generate the bool node
2350-
generator.generate_bool_node(bool_node, block, val)
2349+
generator.generate_bool_node(bool_node, val)
23512350

23522351
# Check that IsNotNullOp was created
23532352
check_ops = [op for op in block.ops if isinstance(op, pdl_interp.IsNotNullOp)]
@@ -2397,7 +2396,7 @@ def test_generate_bool_node_operation_name():
23972396
bool_node = BoolNode(question=question, answer=answer)
23982397

23992398
# Generate the bool node
2400-
generator.generate_bool_node(bool_node, block, val)
2399+
generator.generate_bool_node(bool_node, val)
24012400

24022401
# Check that CheckOperationNameOp was created
24032402
check_ops = [
@@ -2444,7 +2443,7 @@ def test_generate_bool_node_operand_count():
24442443
bool_node = BoolNode(question=question, answer=answer)
24452444

24462445
# Generate the bool node
2447-
generator.generate_bool_node(bool_node, block, val)
2446+
generator.generate_bool_node(bool_node, val)
24482447

24492448
# Check that CheckOperandCountOp was created
24502449
check_ops = [
@@ -2492,7 +2491,7 @@ def test_generate_bool_node_result_count_at_least():
24922491
bool_node = BoolNode(question=question, answer=answer)
24932492

24942493
# Generate the bool node
2495-
generator.generate_bool_node(bool_node, block, val)
2494+
generator.generate_bool_node(bool_node, val)
24962495

24972496
# Check that CheckResultCountOp was created
24982497
check_ops = [
@@ -2544,7 +2543,7 @@ def test_generate_bool_node_equal_to():
25442543
bool_node = BoolNode(question=question, answer=answer)
25452544

25462545
# Generate the bool node
2547-
generator.generate_bool_node(bool_node, block, val1)
2546+
generator.generate_bool_node(bool_node, val1)
25482547

25492548
# Check that AreEqualOp was created
25502549
check_ops = [op for op in block.ops if isinstance(op, pdl_interp.AreEqualOp)]
@@ -2590,7 +2589,7 @@ def test_generate_bool_node_attribute_constraint():
25902589
bool_node = BoolNode(question=question, answer=answer)
25912590

25922591
# Generate the bool node
2593-
generator.generate_bool_node(bool_node, block, val)
2592+
generator.generate_bool_node(bool_node, val)
25942593

25952594
# Check that CheckAttributeOp was created
25962595
check_ops = [op for op in block.ops if isinstance(op, pdl_interp.CheckAttributeOp)]
@@ -2635,7 +2634,7 @@ def test_generate_bool_node_type_constraint():
26352634
bool_node = BoolNode(question=question, answer=answer)
26362635

26372636
# Generate the bool node
2638-
generator.generate_bool_node(bool_node, block, val)
2637+
generator.generate_bool_node(bool_node, val)
26392638

26402639
# Check that CheckTypeOp was created
26412640
check_ops = [op for op in block.ops if isinstance(op, pdl_interp.CheckTypeOp)]
@@ -2683,7 +2682,7 @@ def test_generate_bool_node_native_constraint():
26832682
bool_node = BoolNode(question=question, answer=answer)
26842683

26852684
# Generate the bool node
2686-
generator.generate_bool_node(bool_node, block, val)
2685+
generator.generate_bool_node(bool_node, val)
26872686

26882687
# Check that ApplyConstraintOp was created
26892688
check_ops = [op for op in block.ops if isinstance(op, pdl_interp.ApplyConstraintOp)]
@@ -2734,7 +2733,7 @@ def test_generate_bool_node_operand_count_at_least():
27342733
bool_node = BoolNode(question=question, answer=answer)
27352734

27362735
# Generate the bool node
2737-
generator.generate_bool_node(bool_node, block, val)
2736+
generator.generate_bool_node(bool_node, val)
27382737

27392738
# Check that CheckOperandCountOp was created with compareAtLeast=True
27402739
check_ops = [
@@ -3086,7 +3085,7 @@ def test_generate_bool_node_with_success_node_calls_generate_matcher():
30863085
bool_node = BoolNode(question=question, answer=answer, success_node=success_node)
30873086

30883087
# Generate the bool node
3089-
generator.generate_bool_node(bool_node, block, val)
3088+
generator.generate_bool_node(bool_node, val)
30903089

30913090
# Check that IsNotNullOp was created
30923091
check_ops = [op for op in block.ops if isinstance(op, pdl_interp.IsNotNullOp)]
@@ -3149,7 +3148,7 @@ def test_generate_switch_node_operation_name():
31493148
mock_blocks = [Block(), Block(), Block()]
31503149
with patch.object(generator, "generate_matcher", side_effect=mock_blocks):
31513150
# Generate the switch node
3152-
generator.generate_switch_node(switch_node, block, val)
3151+
generator.generate_switch_node(switch_node, val)
31533152

31543153
# Check that SwitchOperationNameOp was created
31553154
switch_ops = [
@@ -3219,7 +3218,7 @@ def test_generate_switch_node_attribute_constraint():
32193218
mock_blocks = [Block(), Block(), Block()]
32203219
with patch.object(generator, "generate_matcher", side_effect=mock_blocks):
32213220
# Generate the switch node
3222-
generator.generate_switch_node(switch_node, block, val)
3221+
generator.generate_switch_node(switch_node, val)
32233222

32243223
# Check that SwitchAttributeOp was created
32253224
switch_ops = [
@@ -3289,7 +3288,7 @@ def test_generate_switch_node_with_none_child():
32893288
mock_blocks = [Block(), Block()]
32903289
with patch.object(generator, "generate_matcher", side_effect=mock_blocks):
32913290
# Generate the switch node
3292-
generator.generate_switch_node(switch_node, block, val)
3291+
generator.generate_switch_node(switch_node, val)
32933292

32943293
# Check that SwitchOperationNameOp was created
32953294
switch_ops = [
@@ -3339,7 +3338,7 @@ def test_generate_switch_node_empty_children():
33393338
switch_node = SwitchNode(question=question, children=children)
33403339

33413340
# Generate the switch node
3342-
generator.generate_switch_node(switch_node, block, val)
3341+
generator.generate_switch_node(switch_node, val)
33433342

33443343
# Check that SwitchOperationNameOp was created even with empty cases
33453344
switch_ops = [
@@ -3401,7 +3400,7 @@ def test_generate_switch_node_operand_count_not_implemented():
34013400
mock_blocks = [Block(), Block()]
34023401
with patch.object(generator, "generate_matcher", side_effect=mock_blocks):
34033402
# Generate the switch node
3404-
generator.generate_switch_node(switch_node, block, val)
3403+
generator.generate_switch_node(switch_node, val)
34053404

34063405
# Check that SwitchOperandCountOp was created
34073406
switch_ops = [
@@ -3467,7 +3466,7 @@ def test_generate_switch_node_result_count_not_implemented():
34673466
mock_blocks = [Block(), Block()]
34683467
with patch.object(generator, "generate_matcher", side_effect=mock_blocks):
34693468
# Generate the switch node
3470-
generator.generate_switch_node(switch_node, block, val)
3469+
generator.generate_switch_node(switch_node, val)
34713470

34723471
# Check that SwitchResultCountOp was created
34733472
switch_ops = [
@@ -3533,7 +3532,7 @@ def test_generate_switch_node_type_constraint_not_implemented():
35333532
mock_blocks = [Block(), Block()]
35343533
with patch.object(generator, "generate_matcher", side_effect=mock_blocks):
35353534
# Generate the switch node
3536-
generator.generate_switch_node(switch_node, block, val)
3535+
generator.generate_switch_node(switch_node, val)
35373536

35383537
# Check that SwitchTypeOp was created (val.type is pdl.TypeType, not RangeType)
35393538
switch_ops = [op for op in block.ops if isinstance(op, pdl_interp.SwitchTypeOp)]
@@ -3597,7 +3596,7 @@ def test_generate_switch_node_unhandled_question():
35973596
with patch.object(generator, "generate_matcher", side_effect=mock_blocks):
35983597
# Should raise NotImplementedError
35993598
with pytest.raises(NotImplementedError, match="Unhandled question type"):
3600-
generator.generate_switch_node(switch_node, block, val)
3599+
generator.generate_switch_node(switch_node, val)
36013600

36023601

36033602
@pytest.mark.parametrize(
@@ -3641,7 +3640,7 @@ def test_generate_switch_node_at_least_question(
36413640
switch_node = SwitchNode(question=question, children=children)
36423641

36433642
# 3. Call the method under test
3644-
generator.generate_switch_node(switch_node, block, val)
3643+
generator.generate_switch_node(switch_node, val)
36453644

36463645
# 4. Verify the generated IR
36473646
# The logic creates a chain starting with the LOWEST count (1)

0 commit comments

Comments
 (0)