Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions causy/causal_discovery/constraint/orientation_rules/pc.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,13 +249,15 @@ def process(
# It cannot be a collider because we have already oriented all unshielded triples that contain colliders.
for z in potential_zs:
z = graph.nodes[z]
print(f"x: {x.name}, y: {y.name}, z: {z.name}")
breakflag = False

if graph.only_directed_edge_exists(x, z) and graph.undirected_edge_exists(
z, y
):
for node in graph.nodes:
if graph.only_directed_edge_exists(graph.nodes[node], y) and not graph.edge_exists(graph.nodes[node], z):
if graph.only_directed_edge_exists(
graph.nodes[node], y
) and not graph.edge_exists(graph.nodes[node], z):
breakflag = True
break
if breakflag is True:
Expand All @@ -274,19 +276,17 @@ def process(
action=TestResultAction.REMOVE_EDGE_DIRECTED,
data={"between": {"x": x, "y": y, "z": z}},
)
print(graph.only_directed_edge_exists(y, z))
print(graph.undirected_edge_exists(z, x))
if z.name in graph.edges:
if x.name in graph.edges[z.name]:
print(f"{z.name} -> {x.name}")

if graph.only_directed_edge_exists(y, z) and graph.undirected_edge_exists(
z, x
):
for node in graph.nodes:
if graph.only_directed_edge_exists(graph.nodes[node], x):
if graph.only_directed_edge_exists(
graph.nodes[node], x
) and not graph.edge_exists(graph.nodes[node], z):
breakflag = True
break
print(f"breakflag: {breakflag}")

if breakflag is True:
return TestResult(
u=x,
Expand All @@ -301,7 +301,7 @@ def process(
u=x,
v=z,
action=TestResultAction.REMOVE_EDGE_DIRECTED,
data={},
data={"between": {"x": x, "y": y, "z": z}},
)


Expand Down
1 change: 1 addition & 0 deletions causy/common_pipeline_steps/logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def execute(
current_actions_taken,
all_proposed_actions,
) = graph_model_instance_.execute_pipeline_step(pipeline_step)

steps.append(
ActionHistoryStep(
name=pipeline_step.name,
Expand Down
94 changes: 69 additions & 25 deletions tests/test_pc_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@

class PCTestTestCase(CausyTestCase):
SEED = 1

def _sample_generator(self):
rdnv = self.seeded_random.normalvariate
return IIDSampleGenerator(
Expand All @@ -53,7 +54,9 @@ def test_pc_e2e_auto_mpg(self):
PC_LOCAL = graph_model_factory(
Algorithm(
pipeline_steps=[
CalculatePearsonCorrelations(display_name="Calculate Pearson Correlations"),
CalculatePearsonCorrelations(
display_name="Calculate Pearson Correlations"
),
CorrelationCoefficientTest(
threshold=VariableReference(name="threshold"),
display_name="Correlation Coefficient Test",
Expand Down Expand Up @@ -82,11 +85,6 @@ def test_pc_e2e_auto_mpg(self):
pc.create_all_possible_edges()
pc.execute_pipeline_steps()

for s in pc.graph.action_history:
print(s.name)
for a in s.actions:
print(a.u.name, a.v.name, a.action, a.data.keys())

# skeleton
self.assertEqual(pc.graph.edge_exists("mpg", "weight"), True)
self.assertEqual(pc.graph.edge_exists("mpg", "horsepower"), True)
Expand All @@ -107,15 +105,36 @@ def test_pc_e2e_auto_mpg(self):
self.assertEqual(pc.graph.edge_exists("horsepower", "cylinders"), False)

# directions
self.assertEqual(pc.graph.edge_of_type_exists("mpg", "weight", UndirectedEdge()), True)
self.assertEqual(pc.graph.edge_of_type_exists("weight", "horsepower", DirectedEdge()), True)
self.assertEqual(pc.graph.edge_of_type_exists("weight", "displacement", DirectedEdge()), True)
self.assertEqual(pc.graph.edge_of_type_exists("mpg", "horsepower", DirectedEdge()), True)
self.assertEqual(pc.graph.edge_of_type_exists("acceleration", "horsepower", DirectedEdge()), True)
self.assertEqual(pc.graph.edge_of_type_exists("acceleration", "displacement", DirectedEdge()), True)
self.assertEqual(pc.graph.edge_of_type_exists("displacement", "cylinders", DirectedEdge()), True)
self.assertEqual(pc.graph.edge_of_type_exists("horsepower", "displacement", DirectedEdge()), True)

self.assertEqual(
pc.graph.edge_of_type_exists("mpg", "weight", UndirectedEdge()), True
)
self.assertEqual(
pc.graph.edge_of_type_exists("weight", "horsepower", DirectedEdge()), True
)
self.assertEqual(
pc.graph.edge_of_type_exists("weight", "displacement", DirectedEdge()), True
)
self.assertEqual(
pc.graph.edge_of_type_exists("mpg", "horsepower", DirectedEdge()), True
)
self.assertEqual(
pc.graph.edge_of_type_exists("acceleration", "horsepower", DirectedEdge()),
True,
)
self.assertEqual(
pc.graph.edge_of_type_exists(
"acceleration", "displacement", DirectedEdge()
),
True,
)
self.assertEqual(
pc.graph.edge_of_type_exists("displacement", "cylinders", DirectedEdge()),
True,
)
self.assertEqual(
pc.graph.edge_of_type_exists("horsepower", "displacement", DirectedEdge()),
True,
)

def test_pc_collider_rule_on_auto_mpg(self):
script_dir = os.path.dirname(os.path.abspath(__file__))
Expand All @@ -125,7 +144,9 @@ def test_pc_collider_rule_on_auto_mpg(self):
PC_LOCAL = graph_model_factory(
Algorithm(
pipeline_steps=[
CalculatePearsonCorrelations(display_name="Calculate Pearson Correlations"),
CalculatePearsonCorrelations(
display_name="Calculate Pearson Correlations"
),
CorrelationCoefficientTest(
threshold=VariableReference(name="threshold"),
display_name="Correlation Coefficient Test",
Expand Down Expand Up @@ -171,14 +192,38 @@ def test_pc_collider_rule_on_auto_mpg(self):
self.assertEqual(pc.graph.edge_exists("horsepower", "cylinders"), False)

# after collider rule
self.assertEqual(pc.graph.edge_of_type_exists("mpg", "weight", UndirectedEdge()), True)
self.assertEqual(pc.graph.edge_of_type_exists("weight", "horsepower", DirectedEdge()), True)
self.assertEqual(pc.graph.edge_of_type_exists("weight", "displacement", DirectedEdge()), True)
self.assertEqual(pc.graph.edge_of_type_exists("mpg", "horsepower", DirectedEdge()), True)
self.assertEqual(pc.graph.edge_of_type_exists("acceleration", "horsepower", DirectedEdge()), True)
self.assertEqual(pc.graph.edge_of_type_exists("acceleration", "displacement", DirectedEdge()), True)
self.assertEqual(pc.graph.edge_of_type_exists("displacement", "cylinders", UndirectedEdge()), True)
self.assertEqual(pc.graph.edge_of_type_exists("horsepower", "displacement", UndirectedEdge()), True)
self.assertEqual(
pc.graph.edge_of_type_exists("mpg", "weight", UndirectedEdge()), True
)
self.assertEqual(
pc.graph.edge_of_type_exists("weight", "horsepower", DirectedEdge()), True
)
self.assertEqual(
pc.graph.edge_of_type_exists("weight", "displacement", DirectedEdge()), True
)
self.assertEqual(
pc.graph.edge_of_type_exists("mpg", "horsepower", DirectedEdge()), True
)
self.assertEqual(
pc.graph.edge_of_type_exists("acceleration", "horsepower", DirectedEdge()),
True,
)
self.assertEqual(
pc.graph.edge_of_type_exists(
"acceleration", "displacement", DirectedEdge()
),
True,
)
self.assertEqual(
pc.graph.edge_of_type_exists("displacement", "cylinders", UndirectedEdge()),
True,
)
self.assertEqual(
pc.graph.edge_of_type_exists(
"horsepower", "displacement", UndirectedEdge()
),
True,
)

def test_pc_number_of_all_proposed_actions_two_nodes(self):
"""
Expand Down Expand Up @@ -660,7 +705,6 @@ def test_noncollider_triple_rule_e2e(self):
self.assertEqual(tst.graph.edge_of_type_exists("Z", "Y", DirectedEdge()), True)
self.assertEqual(tst.graph.edge_of_type_exists("Y", "W", DirectedEdge()), True)


def test_five_node_example_e2e(self):
rdnv = self.seeded_random.normalvariate
sample_generator = IIDSampleGenerator(
Expand Down
Loading