Skip to content

Commit 956df7a

Browse files
committed
outline functions
1 parent 69ace3f commit 956df7a

File tree

1 file changed

+167
-131
lines changed

1 file changed

+167
-131
lines changed

xdsl/transforms/convert_pdl_to_pdl_interp/conversion.py

Lines changed: 167 additions & 131 deletions
Original file line numberDiff line numberDiff line change
@@ -786,6 +786,38 @@ class PredicateSplit:
786786
] = field(default_factory=lambda: [])
787787

788788

789+
def _get_position_operation_dependencies(pos: Position) -> set[OperationPosition]:
790+
"""Get all operation position dependencies for a position."""
791+
operations: set[OperationPosition] = set()
792+
worklist: deque[Position] = deque([pos])
793+
visited: set[Position] = set()
794+
795+
while worklist:
796+
current = worklist.popleft()
797+
if current in visited:
798+
continue
799+
visited.add(current)
800+
801+
# If this is a ConstraintPosition, add its argument positions
802+
if isinstance(current, ConstraintPosition):
803+
worklist.extend(current.constraint.arg_positions)
804+
805+
# Get the base operation and all ancestors
806+
op = current.get_base_operation()
807+
while op:
808+
operations.add(op)
809+
if op.parent:
810+
parent_op = op.parent.get_base_operation()
811+
if parent_op:
812+
op = parent_op
813+
else:
814+
break
815+
else:
816+
break
817+
818+
return operations
819+
820+
789821
@dataclass
790822
class OperationPositionTree:
791823
"""Node in the tree representing an OperationPosition."""
@@ -847,59 +879,78 @@ def build_operation_position_tree(
847879
root = OperationPositionTree(operation=roots[0])
848880
pattern_paths: list[list[int]] = [[] for _ in pattern_operations]
849881

850-
# Build tree recursively
851-
def build_subtree(
852-
node: OperationPositionTree,
853-
prefix: set[OperationPosition],
854-
remaining_indices: list[int],
855-
current_paths: dict[int, list[int]],
856-
):
857-
if not remaining_indices:
858-
return
859-
860-
# Split patterns into covered and remaining
861-
covered: list[int] = []
862-
still_needed: list[int] = []
863-
for i in remaining_indices:
864-
uncovered = pattern_operations[i] - prefix
865-
if not uncovered:
866-
covered.append(i)
867-
else:
868-
still_needed.append(i)
869-
870-
node.covered_patterns.update(covered)
871-
872-
if not still_needed:
873-
return
874-
875-
# Group patterns by next operation
876-
next_ops: dict[OperationPosition, list[int]] = defaultdict(list)
877-
for i in still_needed:
878-
candidates = pattern_operations[i] - prefix
879-
if candidates:
880-
# Pick operation with highest score (appears in most patterns, shallow depth)
881-
best_op = max(
882-
candidates,
883-
key=lambda op: (
884-
sum(1 for j in still_needed if op in pattern_operations[j]),
885-
-op.get_operation_depth(),
886-
),
887-
)
888-
next_ops[best_op].append(i)
882+
# Build tree using the helper method
883+
OperationPositionTree._build_subtree(
884+
root,
885+
{roots[0]},
886+
list(range(len(pattern_operations))),
887+
{},
888+
pattern_operations,
889+
pattern_paths,
890+
)
891+
892+
return root, pattern_paths, predicate_dependencies
893+
894+
@staticmethod
895+
def _build_subtree(
896+
node: "OperationPositionTree",
897+
prefix: set[OperationPosition],
898+
remaining_indices: list[int],
899+
current_paths: dict[int, list[int]],
900+
pattern_operations: list[set[OperationPosition]],
901+
pattern_paths: list[list[int]],
902+
) -> None:
903+
"""Helper method to recursively build the operation position tree."""
904+
if not remaining_indices:
905+
return
889906

890-
# Create children
891-
for child_index, (op, indices) in enumerate(next_ops.items()):
892-
child = OperationPositionTree(operation=op)
893-
node.children.append(child)
907+
# Split patterns into covered and remaining
908+
covered: list[int] = []
909+
still_needed: list[int] = []
910+
for i in remaining_indices:
911+
uncovered = pattern_operations[i] - prefix
912+
if not uncovered:
913+
covered.append(i)
914+
else:
915+
still_needed.append(i)
894916

895-
child_paths: dict[int, list[int]] = {}
896-
for idx in indices:
897-
child_paths[idx] = current_paths.get(idx, []) + [child_index]
898-
pattern_paths[idx] = child_paths[idx]
899-
build_subtree(child, prefix | {op}, indices, child_paths)
917+
node.covered_patterns.update(covered)
900918

901-
build_subtree(root, {roots[0]}, list(range(len(pattern_operations))), {})
902-
return root, pattern_paths, predicate_dependencies
919+
if not still_needed:
920+
return
921+
922+
# Group patterns by next operation
923+
next_ops: dict[OperationPosition, list[int]] = defaultdict(list)
924+
for i in still_needed:
925+
candidates = pattern_operations[i] - prefix
926+
if candidates:
927+
# Pick operation with highest score (appears in most patterns, shallow depth)
928+
best_op = max(
929+
candidates,
930+
key=lambda op: (
931+
sum(1 for j in still_needed if op in pattern_operations[j]),
932+
-op.get_operation_depth(),
933+
),
934+
)
935+
next_ops[best_op].append(i)
936+
937+
# Create children
938+
for child_index, (op, indices) in enumerate(next_ops.items()):
939+
child = OperationPositionTree(operation=op)
940+
node.children.append(child)
941+
942+
child_paths: dict[int, list[int]] = {}
943+
for idx in indices:
944+
child_paths[idx] = current_paths.get(idx, []) + [child_index]
945+
pattern_paths[idx] = child_paths[idx]
946+
OperationPositionTree._build_subtree(
947+
child,
948+
prefix | {op},
949+
indices,
950+
child_paths,
951+
pattern_operations,
952+
pattern_paths,
953+
)
903954

904955
def build_predicate_tree_from_operation_tree(
905956
self,
@@ -920,109 +971,94 @@ def build_predicate_tree_from_operation_tree(
920971
Returns:
921972
List of predicates with PredicateSplits representing the tree structure
922973
"""
974+
# Start building from root
975+
root_prefix = {self.operation}
976+
return self._build_predicate_subtree(
977+
self,
978+
root_prefix,
979+
set(),
980+
ordered_predicates,
981+
pattern_predicates,
982+
predicate_dependencies,
983+
)
923984

924-
def build_predicate_subtree(
925-
node: OperationPositionTree,
926-
prefix: set[OperationPosition],
927-
parent_prefix: set[OperationPosition],
928-
) -> list[OrderedPredicate | PredicateSplit]:
929-
"""Build predicate tree for a subtree of the operation position tree."""
930-
931-
# Collect predicates whose dependencies are satisfied by current prefix
932-
# but weren't satisfied by parent prefix (newly satisfied)
933-
node_predicates: dict[tuple[Position, Question], OrderedPredicate] = {}
934-
935-
for pattern_preds, pred_deps in zip(
936-
pattern_predicates, predicate_dependencies, strict=False
937-
):
938-
for pred in pattern_preds:
939-
deps = pred_deps.get((pred.position, pred.q))
940-
if deps is None:
941-
continue # Skip if no dependencies recorded
942-
# Check if all dependencies are satisfied by current prefix
943-
# but not all were satisfied by parent prefix
944-
if deps.issubset(prefix) and not deps.issubset(parent_prefix):
945-
key = (pred.position, pred.q)
946-
if key in ordered_predicates:
947-
node_predicates[key] = ordered_predicates[key]
948-
949-
# Sort predicates for this node
950-
sorted_node_preds = cast(
951-
list[OrderedPredicate | PredicateSplit],
952-
sorted(node_predicates.values()),
953-
)
985+
@staticmethod
986+
def _build_predicate_subtree(
987+
node: "OperationPositionTree",
988+
prefix: set[OperationPosition],
989+
parent_prefix: set[OperationPosition],
990+
ordered_predicates: dict[tuple[Position, Question], OrderedPredicate],
991+
pattern_predicates: list[list[PositionalPredicate]],
992+
predicate_dependencies: list[
993+
dict[tuple[Position, Question], set[OperationPosition]]
994+
],
995+
) -> list[OrderedPredicate | PredicateSplit]:
996+
"""Build predicate tree for a subtree of the operation position tree."""
954997

955-
# If there are children, create a PredicateSplit
956-
if node.children:
957-
splits: list[
958-
tuple[OperationPosition, list[OrderedPredicate | PredicateSplit]]
959-
] = []
998+
# Collect predicates whose dependencies are satisfied by current prefix
999+
# but weren't satisfied by parent prefix (newly satisfied)
1000+
node_predicates: dict[tuple[Position, Question], OrderedPredicate] = {}
9601001

961-
for child in node.children:
962-
# Recursively build predicate tree for child
963-
child_preds = build_predicate_subtree(
964-
child, prefix | {child.operation}, prefix
965-
)
966-
splits.append((child.operation, child_preds))
1002+
for pattern_preds, pred_deps in zip(
1003+
pattern_predicates, predicate_dependencies, strict=False
1004+
):
1005+
for pred in pattern_preds:
1006+
deps = pred_deps.get((pred.position, pred.q))
1007+
if deps is None:
1008+
continue # Skip if no dependencies recorded
1009+
# Check if all dependencies are satisfied by current prefix
1010+
# but not all were satisfied by parent prefix
1011+
if deps.issubset(prefix) and not deps.issubset(parent_prefix):
1012+
key = (pred.position, pred.q)
1013+
if key in ordered_predicates:
1014+
node_predicates[key] = ordered_predicates[key]
1015+
1016+
# Sort predicates for this node
1017+
sorted_node_preds = cast(
1018+
list[OrderedPredicate | PredicateSplit],
1019+
sorted(node_predicates.values()),
1020+
)
9671021

968-
sorted_node_preds.append(PredicateSplit(splits))
1022+
# If there are children, create a PredicateSplit
1023+
if node.children:
1024+
splits: list[
1025+
tuple[OperationPosition, list[OrderedPredicate | PredicateSplit]]
1026+
] = []
1027+
1028+
for child in node.children:
1029+
# Recursively build predicate tree for child
1030+
child_preds = OperationPositionTree._build_predicate_subtree(
1031+
child,
1032+
prefix | {child.operation},
1033+
prefix,
1034+
ordered_predicates,
1035+
pattern_predicates,
1036+
predicate_dependencies,
1037+
)
1038+
splits.append((child.operation, child_preds))
9691039

970-
return sorted_node_preds
1040+
sorted_node_preds.append(PredicateSplit(splits))
9711041

972-
# Start building from root
973-
root_prefix = {self.operation}
974-
return build_predicate_subtree(self, root_prefix, set())
1042+
return sorted_node_preds
9751043

9761044
@staticmethod
9771045
def get_predicate_operation_dependencies(
9781046
pred: PositionalPredicate,
9791047
) -> set[OperationPosition]:
9801048
"""Get all operation position dependencies for a predicate."""
981-
982-
def get_position_dependencies(pos: Position) -> set[OperationPosition]:
983-
"""Get all operation position dependencies for a position."""
984-
operations: set[OperationPosition] = set()
985-
worklist: deque[Position] = deque([pos])
986-
visited: set[Position] = set()
987-
988-
while worklist:
989-
current = worklist.popleft()
990-
if current in visited:
991-
continue
992-
visited.add(current)
993-
994-
# If this is a ConstraintPosition, add its argument positions
995-
if isinstance(current, ConstraintPosition):
996-
worklist.extend(current.constraint.arg_positions)
997-
998-
# Get the base operation and all ancestors
999-
op = current.get_base_operation()
1000-
while op:
1001-
operations.add(op)
1002-
if op.parent:
1003-
parent_op = op.parent.get_base_operation()
1004-
if parent_op:
1005-
op = parent_op
1006-
else:
1007-
break
1008-
else:
1009-
break
1010-
1011-
return operations
1012-
10131049
deps: set[OperationPosition] = set()
10141050

10151051
# Add dependencies from the predicate position
1016-
deps.update(get_position_dependencies(pred.position))
1052+
deps.update(_get_position_operation_dependencies(pred.position))
10171053

10181054
# Handle EqualToQuestion - add the other position
10191055
if isinstance(pred.q, EqualToQuestion):
1020-
deps.update(get_position_dependencies(pred.q.other_position))
1056+
deps.update(_get_position_operation_dependencies(pred.q.other_position))
10211057

10221058
# Handle ConstraintQuestion - add all argument positions
10231059
if isinstance(pred.q, ConstraintQuestion):
10241060
for arg_pos in pred.q.arg_positions:
1025-
deps.update(get_position_dependencies(arg_pos))
1061+
deps.update(_get_position_operation_dependencies(arg_pos))
10261062

10271063
return deps
10281064

0 commit comments

Comments
 (0)