@@ -786,6 +786,38 @@ class PredicateSplit:
786786 ] = field (default_factory = lambda : [])
787787
788788
789+ def _get_position_operation_dependencies (pos : Position ) -> set [OperationPosition ]:
790+ """Get all operation position dependencies for a position."""
791+ operations : set [OperationPosition ] = set ()
792+ worklist : deque [Position ] = deque ([pos ])
793+ visited : set [Position ] = set ()
794+
795+ while worklist :
796+ current = worklist .popleft ()
797+ if current in visited :
798+ continue
799+ visited .add (current )
800+
801+ # If this is a ConstraintPosition, add its argument positions
802+ if isinstance (current , ConstraintPosition ):
803+ worklist .extend (current .constraint .arg_positions )
804+
805+ # Get the base operation and all ancestors
806+ op = current .get_base_operation ()
807+ while op :
808+ operations .add (op )
809+ if op .parent :
810+ parent_op = op .parent .get_base_operation ()
811+ if parent_op :
812+ op = parent_op
813+ else :
814+ break
815+ else :
816+ break
817+
818+ return operations
819+
820+
789821@dataclass
790822class OperationPositionTree :
791823 """Node in the tree representing an OperationPosition."""
@@ -847,59 +879,78 @@ def build_operation_position_tree(
847879 root = OperationPositionTree (operation = roots [0 ])
848880 pattern_paths : list [list [int ]] = [[] for _ in pattern_operations ]
849881
850- # Build tree recursively
851- def build_subtree (
852- node : OperationPositionTree ,
853- prefix : set [OperationPosition ],
854- remaining_indices : list [int ],
855- current_paths : dict [int , list [int ]],
856- ):
857- if not remaining_indices :
858- return
859-
860- # Split patterns into covered and remaining
861- covered : list [int ] = []
862- still_needed : list [int ] = []
863- for i in remaining_indices :
864- uncovered = pattern_operations [i ] - prefix
865- if not uncovered :
866- covered .append (i )
867- else :
868- still_needed .append (i )
869-
870- node .covered_patterns .update (covered )
871-
872- if not still_needed :
873- return
874-
875- # Group patterns by next operation
876- next_ops : dict [OperationPosition , list [int ]] = defaultdict (list )
877- for i in still_needed :
878- candidates = pattern_operations [i ] - prefix
879- if candidates :
880- # Pick operation with highest score (appears in most patterns, shallow depth)
881- best_op = max (
882- candidates ,
883- key = lambda op : (
884- sum (1 for j in still_needed if op in pattern_operations [j ]),
885- - op .get_operation_depth (),
886- ),
887- )
888- next_ops [best_op ].append (i )
882+ # Build tree using the helper method
883+ OperationPositionTree ._build_subtree (
884+ root ,
885+ {roots [0 ]},
886+ list (range (len (pattern_operations ))),
887+ {},
888+ pattern_operations ,
889+ pattern_paths ,
890+ )
891+
892+ return root , pattern_paths , predicate_dependencies
893+
894+ @staticmethod
895+ def _build_subtree (
896+ node : "OperationPositionTree" ,
897+ prefix : set [OperationPosition ],
898+ remaining_indices : list [int ],
899+ current_paths : dict [int , list [int ]],
900+ pattern_operations : list [set [OperationPosition ]],
901+ pattern_paths : list [list [int ]],
902+ ) -> None :
903+ """Helper method to recursively build the operation position tree."""
904+ if not remaining_indices :
905+ return
889906
890- # Create children
891- for child_index , (op , indices ) in enumerate (next_ops .items ()):
892- child = OperationPositionTree (operation = op )
893- node .children .append (child )
907+ # Split patterns into covered and remaining
908+ covered : list [int ] = []
909+ still_needed : list [int ] = []
910+ for i in remaining_indices :
911+ uncovered = pattern_operations [i ] - prefix
912+ if not uncovered :
913+ covered .append (i )
914+ else :
915+ still_needed .append (i )
894916
895- child_paths : dict [int , list [int ]] = {}
896- for idx in indices :
897- child_paths [idx ] = current_paths .get (idx , []) + [child_index ]
898- pattern_paths [idx ] = child_paths [idx ]
899- build_subtree (child , prefix | {op }, indices , child_paths )
917+ node .covered_patterns .update (covered )
900918
901- build_subtree (root , {roots [0 ]}, list (range (len (pattern_operations ))), {})
902- return root , pattern_paths , predicate_dependencies
919+ if not still_needed :
920+ return
921+
922+ # Group patterns by next operation
923+ next_ops : dict [OperationPosition , list [int ]] = defaultdict (list )
924+ for i in still_needed :
925+ candidates = pattern_operations [i ] - prefix
926+ if candidates :
927+ # Pick operation with highest score (appears in most patterns, shallow depth)
928+ best_op = max (
929+ candidates ,
930+ key = lambda op : (
931+ sum (1 for j in still_needed if op in pattern_operations [j ]),
932+ - op .get_operation_depth (),
933+ ),
934+ )
935+ next_ops [best_op ].append (i )
936+
937+ # Create children
938+ for child_index , (op , indices ) in enumerate (next_ops .items ()):
939+ child = OperationPositionTree (operation = op )
940+ node .children .append (child )
941+
942+ child_paths : dict [int , list [int ]] = {}
943+ for idx in indices :
944+ child_paths [idx ] = current_paths .get (idx , []) + [child_index ]
945+ pattern_paths [idx ] = child_paths [idx ]
946+ OperationPositionTree ._build_subtree (
947+ child ,
948+ prefix | {op },
949+ indices ,
950+ child_paths ,
951+ pattern_operations ,
952+ pattern_paths ,
953+ )
903954
904955 def build_predicate_tree_from_operation_tree (
905956 self ,
@@ -920,109 +971,94 @@ def build_predicate_tree_from_operation_tree(
920971 Returns:
921972 List of predicates with PredicateSplits representing the tree structure
922973 """
974+ # Start building from root
975+ root_prefix = {self .operation }
976+ return self ._build_predicate_subtree (
977+ self ,
978+ root_prefix ,
979+ set (),
980+ ordered_predicates ,
981+ pattern_predicates ,
982+ predicate_dependencies ,
983+ )
923984
924- def build_predicate_subtree (
925- node : OperationPositionTree ,
926- prefix : set [OperationPosition ],
927- parent_prefix : set [OperationPosition ],
928- ) -> list [OrderedPredicate | PredicateSplit ]:
929- """Build predicate tree for a subtree of the operation position tree."""
930-
931- # Collect predicates whose dependencies are satisfied by current prefix
932- # but weren't satisfied by parent prefix (newly satisfied)
933- node_predicates : dict [tuple [Position , Question ], OrderedPredicate ] = {}
934-
935- for pattern_preds , pred_deps in zip (
936- pattern_predicates , predicate_dependencies , strict = False
937- ):
938- for pred in pattern_preds :
939- deps = pred_deps .get ((pred .position , pred .q ))
940- if deps is None :
941- continue # Skip if no dependencies recorded
942- # Check if all dependencies are satisfied by current prefix
943- # but not all were satisfied by parent prefix
944- if deps .issubset (prefix ) and not deps .issubset (parent_prefix ):
945- key = (pred .position , pred .q )
946- if key in ordered_predicates :
947- node_predicates [key ] = ordered_predicates [key ]
948-
949- # Sort predicates for this node
950- sorted_node_preds = cast (
951- list [OrderedPredicate | PredicateSplit ],
952- sorted (node_predicates .values ()),
953- )
985+ @staticmethod
986+ def _build_predicate_subtree (
987+ node : "OperationPositionTree" ,
988+ prefix : set [OperationPosition ],
989+ parent_prefix : set [OperationPosition ],
990+ ordered_predicates : dict [tuple [Position , Question ], OrderedPredicate ],
991+ pattern_predicates : list [list [PositionalPredicate ]],
992+ predicate_dependencies : list [
993+ dict [tuple [Position , Question ], set [OperationPosition ]]
994+ ],
995+ ) -> list [OrderedPredicate | PredicateSplit ]:
996+ """Build predicate tree for a subtree of the operation position tree."""
954997
955- # If there are children, create a PredicateSplit
956- if node .children :
957- splits : list [
958- tuple [OperationPosition , list [OrderedPredicate | PredicateSplit ]]
959- ] = []
998+ # Collect predicates whose dependencies are satisfied by current prefix
999+ # but weren't satisfied by parent prefix (newly satisfied)
1000+ node_predicates : dict [tuple [Position , Question ], OrderedPredicate ] = {}
9601001
961- for child in node .children :
962- # Recursively build predicate tree for child
963- child_preds = build_predicate_subtree (
964- child , prefix | {child .operation }, prefix
965- )
966- splits .append ((child .operation , child_preds ))
1002+ for pattern_preds , pred_deps in zip (
1003+ pattern_predicates , predicate_dependencies , strict = False
1004+ ):
1005+ for pred in pattern_preds :
1006+ deps = pred_deps .get ((pred .position , pred .q ))
1007+ if deps is None :
1008+ continue # Skip if no dependencies recorded
1009+ # Check if all dependencies are satisfied by current prefix
1010+ # but not all were satisfied by parent prefix
1011+ if deps .issubset (prefix ) and not deps .issubset (parent_prefix ):
1012+ key = (pred .position , pred .q )
1013+ if key in ordered_predicates :
1014+ node_predicates [key ] = ordered_predicates [key ]
1015+
1016+ # Sort predicates for this node
1017+ sorted_node_preds = cast (
1018+ list [OrderedPredicate | PredicateSplit ],
1019+ sorted (node_predicates .values ()),
1020+ )
9671021
968- sorted_node_preds .append (PredicateSplit (splits ))
1022+ # If there are children, create a PredicateSplit
1023+ if node .children :
1024+ splits : list [
1025+ tuple [OperationPosition , list [OrderedPredicate | PredicateSplit ]]
1026+ ] = []
1027+
1028+ for child in node .children :
1029+ # Recursively build predicate tree for child
1030+ child_preds = OperationPositionTree ._build_predicate_subtree (
1031+ child ,
1032+ prefix | {child .operation },
1033+ prefix ,
1034+ ordered_predicates ,
1035+ pattern_predicates ,
1036+ predicate_dependencies ,
1037+ )
1038+ splits .append ((child .operation , child_preds ))
9691039
970- return sorted_node_preds
1040+ sorted_node_preds . append ( PredicateSplit ( splits ))
9711041
972- # Start building from root
973- root_prefix = {self .operation }
974- return build_predicate_subtree (self , root_prefix , set ())
1042+ return sorted_node_preds
9751043
9761044 @staticmethod
9771045 def get_predicate_operation_dependencies (
9781046 pred : PositionalPredicate ,
9791047 ) -> set [OperationPosition ]:
9801048 """Get all operation position dependencies for a predicate."""
981-
982- def get_position_dependencies (pos : Position ) -> set [OperationPosition ]:
983- """Get all operation position dependencies for a position."""
984- operations : set [OperationPosition ] = set ()
985- worklist : deque [Position ] = deque ([pos ])
986- visited : set [Position ] = set ()
987-
988- while worklist :
989- current = worklist .popleft ()
990- if current in visited :
991- continue
992- visited .add (current )
993-
994- # If this is a ConstraintPosition, add its argument positions
995- if isinstance (current , ConstraintPosition ):
996- worklist .extend (current .constraint .arg_positions )
997-
998- # Get the base operation and all ancestors
999- op = current .get_base_operation ()
1000- while op :
1001- operations .add (op )
1002- if op .parent :
1003- parent_op = op .parent .get_base_operation ()
1004- if parent_op :
1005- op = parent_op
1006- else :
1007- break
1008- else :
1009- break
1010-
1011- return operations
1012-
10131049 deps : set [OperationPosition ] = set ()
10141050
10151051 # Add dependencies from the predicate position
1016- deps .update (get_position_dependencies (pred .position ))
1052+ deps .update (_get_position_operation_dependencies (pred .position ))
10171053
10181054 # Handle EqualToQuestion - add the other position
10191055 if isinstance (pred .q , EqualToQuestion ):
1020- deps .update (get_position_dependencies (pred .q .other_position ))
1056+ deps .update (_get_position_operation_dependencies (pred .q .other_position ))
10211057
10221058 # Handle ConstraintQuestion - add all argument positions
10231059 if isinstance (pred .q , ConstraintQuestion ):
10241060 for arg_pos in pred .q .arg_positions :
1025- deps .update (get_position_dependencies (arg_pos ))
1061+ deps .update (_get_position_operation_dependencies (arg_pos ))
10261062
10271063 return deps
10281064
0 commit comments