@@ -94,7 +94,6 @@ def test_pc_e2e_auto_mpg(self):
9494 self .assertEqual (pc .graph .edge_exists ("weight" , "horsepower" ), True )
9595 self .assertEqual (pc .graph .edge_exists ("displacement" , "cylinders" ), True )
9696 self .assertEqual (pc .graph .edge_exists ("displacement" , "acceleration" ), True )
97- # here
9897 self .assertEqual (pc .graph .edge_exists ("displacement" , "horsepower" ), True )
9998 self .assertEqual (pc .graph .edge_exists ("horsepower" , "acceleration" ), True )
10099
@@ -107,7 +106,7 @@ def test_pc_e2e_auto_mpg(self):
107106 self .assertEqual (pc .graph .edge_exists ("acceleration" , "cylinders" ), False )
108107 self .assertEqual (pc .graph .edge_exists ("horsepower" , "cylinders" ), False )
109108
110- # directtions
109+ # directions
111110 self .assertEqual (pc .graph .edge_of_type_exists ("mpg" , "weight" , UndirectedEdge ()), True )
112111 self .assertEqual (pc .graph .edge_of_type_exists ("weight" , "horsepower" , DirectedEdge ()), True )
113112 self .assertEqual (pc .graph .edge_of_type_exists ("weight" , "displacement" , DirectedEdge ()), True )
@@ -118,10 +117,68 @@ def test_pc_e2e_auto_mpg(self):
118117 self .assertEqual (pc .graph .edge_of_type_exists ("horsepower" , "displacement" , DirectedEdge ()), True )
119118
120119
120+ def test_pc_collider_rule_on_auto_mpg (self ):
121+ script_dir = os .path .dirname (os .path .abspath (__file__ ))
122+ folder_auto_mpg = os .path .join (script_dir , "fixtures/auto_mpg/" )
123+ with open (f"{ folder_auto_mpg } auto_mpg.json" , "r" ) as f :
124+ auto_mpg_data_set = json .load (f )
125+ PC_LOCAL = graph_model_factory (
126+ Algorithm (
127+ pipeline_steps = [
128+ CalculatePearsonCorrelations (display_name = "Calculate Pearson Correlations" ),
129+ CorrelationCoefficientTest (
130+ threshold = VariableReference (name = "threshold" ),
131+ display_name = "Correlation Coefficient Test" ,
132+ ),
133+ PartialCorrelationTest (
134+ threshold = VariableReference (name = "threshold" ),
135+ display_name = "Partial Correlation Test" ,
136+ ),
137+ ExtendedPartialCorrelationTestMatrix (
138+ threshold = VariableReference (name = "threshold" ),
139+ display_name = "Extended Partial Correlation Test Matrix" ,
140+ ),
141+ ColliderTest (display_name = "Collider Test" ),
142+ ],
143+ edge_types = PC_EDGE_TYPES ,
144+ extensions = [PC_GRAPH_UI_EXTENSION ],
145+ name = "PC" ,
146+ variables = [FloatVariable (name = "threshold" , value = 0.05 )],
147+ )
148+ )
149+ pc = PC_LOCAL ()
150+ pc .create_graph_from_data (auto_mpg_data_set )
151+ pc .create_all_possible_edges ()
152+ pc .execute_pipeline_steps ()
121153
154+ # skeleton
155+ self .assertEqual (pc .graph .edge_exists ("mpg" , "weight" ), True )
156+ self .assertEqual (pc .graph .edge_exists ("mpg" , "horsepower" ), True )
157+ self .assertEqual (pc .graph .edge_exists ("weight" , "displacement" ), True )
158+ self .assertEqual (pc .graph .edge_exists ("weight" , "horsepower" ), True )
159+ self .assertEqual (pc .graph .edge_exists ("displacement" , "cylinders" ), True )
160+ self .assertEqual (pc .graph .edge_exists ("displacement" , "acceleration" ), True )
161+ self .assertEqual (pc .graph .edge_exists ("displacement" , "horsepower" ), True )
162+ self .assertEqual (pc .graph .edge_exists ("horsepower" , "acceleration" ), True )
122163
164+ # assert all other edges are not present
165+ self .assertEqual (pc .graph .edge_exists ("mpg" , "displacement" ), False )
166+ self .assertEqual (pc .graph .edge_exists ("mpg" , "cylinders" ), False )
167+ self .assertEqual (pc .graph .edge_exists ("mpg" , "acceleration" ), False )
168+ self .assertEqual (pc .graph .edge_exists ("weight" , "cylinders" ), False )
169+ self .assertEqual (pc .graph .edge_exists ("weight" , "acceleration" ), False )
170+ self .assertEqual (pc .graph .edge_exists ("acceleration" , "cylinders" ), False )
171+ self .assertEqual (pc .graph .edge_exists ("horsepower" , "cylinders" ), False )
123172
124-
173+ # after collider rule
174+ self .assertEqual (pc .graph .edge_of_type_exists ("mpg" , "weight" , UndirectedEdge ()), True )
175+ self .assertEqual (pc .graph .edge_of_type_exists ("weight" , "horsepower" , DirectedEdge ()), True )
176+ self .assertEqual (pc .graph .edge_of_type_exists ("weight" , "displacement" , DirectedEdge ()), True )
177+ self .assertEqual (pc .graph .edge_of_type_exists ("mpg" , "horsepower" , DirectedEdge ()), True )
178+ self .assertEqual (pc .graph .edge_of_type_exists ("acceleration" , "horsepower" , DirectedEdge ()), True )
179+ self .assertEqual (pc .graph .edge_of_type_exists ("acceleration" , "displacement" , DirectedEdge ()), True )
180+ self .assertEqual (pc .graph .edge_of_type_exists ("displacement" , "cylinders" , UndirectedEdge ()), True )
181+ self .assertEqual (pc .graph .edge_of_type_exists ("horsepower" , "displacement" , UndirectedEdge ()), True )
125182
126183 def test_pc_number_of_all_proposed_actions_two_nodes (self ):
127184 """
@@ -602,3 +659,26 @@ def test_noncollider_triple_rule_e2e(self):
602659 self .assertEqual (tst .graph .edge_of_type_exists ("X" , "Y" , DirectedEdge ()), True )
603660 self .assertEqual (tst .graph .edge_of_type_exists ("Z" , "Y" , DirectedEdge ()), True )
604661 self .assertEqual (tst .graph .edge_of_type_exists ("Y" , "W" , DirectedEdge ()), True )
662+
663+
664+ def test_five_node_example_e2e (self ):
665+ rdnv = self .seeded_random .normalvariate
666+ sample_generator = IIDSampleGenerator (
667+ edges = [
668+ SampleEdge (NodeReference ("X" ), NodeReference ("Z" ), 1 ),
669+ SampleEdge (NodeReference ("Y" ), NodeReference ("Z" ), 1 ),
670+ SampleEdge (NodeReference ("Z" ), NodeReference ("V" ), 1 ),
671+ SampleEdge (NodeReference ("Z" ), NodeReference ("W" ), 1 ),
672+ ],
673+ random = lambda : rdnv (0 , 1 ),
674+ )
675+ test_data , graph = sample_generator .generate (10000 )
676+ tst = PCClassic ()
677+ tst .create_graph_from_data (test_data )
678+ tst .create_all_possible_edges ()
679+ tst .execute_pipeline_steps ()
680+
681+ self .assertEqual (tst .graph .edge_of_type_exists ("X" , "Z" , DirectedEdge ()), True )
682+ self .assertEqual (tst .graph .edge_of_type_exists ("Y" , "Z" , DirectedEdge ()), True )
683+ self .assertEqual (tst .graph .edge_of_type_exists ("Z" , "W" , DirectedEdge ()), True )
684+ self .assertEqual (tst .graph .edge_of_type_exists ("Z" , "V" , DirectedEdge ()), True )
0 commit comments