Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 13 additions & 18 deletions causal_testing/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from causal_testing.estimation.logistic_regression_estimator import LogisticRegressionEstimator
from causal_testing.specification.causal_dag import CausalDAG
from causal_testing.specification.causal_specification import CausalSpecification
from causal_testing.specification.scenario import Scenario
from causal_testing.specification.variable import Input, Output
from causal_testing.testing.base_test_case import BaseTestCase
from causal_testing.testing.causal_effect import Negative, NoEffect, Positive, SomeEffect
Expand Down Expand Up @@ -105,13 +104,12 @@ def __init__(self, paths: CausalTestingPaths, ignore_cycles: bool = False, query
self.dag: Optional[CausalDAG] = None
self.data: Optional[pd.DataFrame] = None
self.variables: Dict[str, Any] = {"inputs": {}, "outputs": {}, "metas": {}}
self.scenario: Optional[Scenario] = None
self.causal_specification: Optional[CausalSpecification] = None
self.test_cases: Optional[List[CausalTestCase]] = None

def setup(self) -> None:
"""
Set up the framework by loading DAG, runtime csv data, creating the scenario and causal specification.
Set up the framework by loading DAG, runtime csv data, creating the causal specification.

:raises: FileNotFoundError if required files are missing
"""
Expand All @@ -130,8 +128,12 @@ def setup(self) -> None:
# Create variables from DAG
self.create_variables()

# Create scenario and specification
self.create_scenario_and_specification()
# Create causal specification
self.causal_specification = CausalSpecification(
variables=list(self.variables["inputs"].values()) + list(self.variables["outputs"].values()),
causal_dag=self.dag,
constraints={self.query} if self.query else None,
)

logger.info("Setup completed successfully")

Expand Down Expand Up @@ -187,18 +189,6 @@ def create_variables(self) -> None:
if self.dag.in_degree(node_name) > 0:
self.variables["outputs"][node_name] = Output(name=node_name, datatype=dtype)

def create_scenario_and_specification(self) -> None:
"""Create scenario and causal specification objects from loaded data."""
# Create scenario
all_variables = list(self.variables["inputs"].values()) + list(self.variables["outputs"].values())
self.scenario = Scenario(variables=all_variables)

# Set up treatment variables
self.scenario.setup_treatment_variables()

# Create causal specification
self.causal_specification = CausalSpecification(scenario=self.scenario, causal_dag=self.dag)

def load_tests(self) -> None:
"""
Load and prepare test configurations from file.
Expand Down Expand Up @@ -316,7 +306,12 @@ def create_causal_test(self, test: dict, base_test: BaseTestCase) -> CausalTestC
base_test_case=base_test,
treatment_value=test.get("treatment_value"),
control_value=test.get("control_value"),
adjustment_set=test.get("adjustment_set", self.causal_specification.causal_dag.identification(base_test)),
adjustment_set=test.get(
"adjustment_set",
self.causal_specification.causal_dag.identification(
base_test, self.causal_specification.hidden_variables()
),
),
df=filtered_df,
effect_modifiers=None,
formula=test.get("formula"),
Expand Down
38 changes: 7 additions & 31 deletions causal_testing/specification/causal_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@

from causal_testing.testing.base_test_case import BaseTestCase

from .scenario import Scenario
from .variable import Output
from .variable import Variable

Node = Union[str, int] # Node type hint: A node is a string or an int

Expand Down Expand Up @@ -489,37 +488,12 @@ def get_backdoor_graph(self, treatments: list[str]) -> CausalDAG:
backdoor_graph.add_edges_from(filter(lambda x: x not in outgoing_edges, self.edges))
return backdoor_graph

def depends_on_outputs(self, node: Node, scenario: Scenario) -> bool:
"""Check whether a given node in a given scenario is or depends on a
model output in the given scenario. That is, whether or not the model
needs to be run to determine its value.

NOTE: The graph must be acyclic for this to terminate.

:param node: The node in the DAG representing the variable of interest.
:param scenario: The modelling scenario.

:return: Whether the given variable is or depends on an output.
"""
if isinstance(scenario.variables[node], Output):
return True
return any((self.depends_on_outputs(n, scenario) for n in self.predecessors(node)))

@staticmethod
def remove_hidden_adjustment_sets(minimal_adjustment_sets: list[str], scenario: Scenario):
"""Remove variables labelled as hidden from adjustment set(s)

:param minimal_adjustment_sets: list of minimal adjustment set(s) to have hidden variables removed from
:param scenario: The modelling scenario which informs the variables that are hidden
"""
return [adj for adj in minimal_adjustment_sets if all(not scenario.variables.get(x).hidden for x in adj)]

def identification(self, base_test_case: BaseTestCase, scenario: Scenario = None):
def identification(self, base_test_case: BaseTestCase, avoid_variables: set[Variable] = None):
"""Identify and return the minimum adjustment set

:param base_test_case: A base test case instance containing the outcome_variable and the
treatment_variable required for identification.
:param scenario: The modelling scenario relating to the tests
:param avoid_variables: Variables not to be adjusted for (e.g. hidden variables).

:return: The smallest set of variables which can be adjusted for to obtain a causal
estimate as opposed to a purely associational estimate.
Expand All @@ -539,8 +513,10 @@ def identification(self, base_test_case: BaseTestCase, scenario: Scenario = None
else:
raise ValueError("Causal effect should be 'total' or 'direct'")

if scenario is not None:
minimal_adjustment_sets = self.remove_hidden_adjustment_sets(minimal_adjustment_sets, scenario)
if avoid_variables is not None:
minimal_adjustment_sets = [
adj for adj in minimal_adjustment_sets if not {x.name for x in avoid_variables}.intersection(adj)
]

minimal_adjustment_set = min(minimal_adjustment_sets, key=len, default=set())
return set(minimal_adjustment_set)
Expand Down
39 changes: 31 additions & 8 deletions causal_testing/specification/causal_specification.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,45 @@
"""This module holds the abstract CausalSpecification data class, which holds a Scenario and CausalDag"""
"""This module holds the CausalSpecification data class."""

from dataclasses import dataclass
from typing import Union
from collections.abc import Iterable

from causal_testing.specification.causal_dag import CausalDAG
from causal_testing.specification.scenario import Scenario

Node = Union[str, int] # Node type hint: A node is a string or an int
from .variable import Variable


@dataclass
class CausalSpecification:
    """
    Data class storing the causal specification, made up of the modelling scenario and the causal DAG.

    A scenario defines the setting by listing the endogenous variables, their
    datatypes, distributions, and any constraints over them. This is a common
    practice in causal inference and is analogous to an investigator specifying
    "we are interested in individuals over 40 who regularly eat cheese". A
    scenario, here, is not a specific test case; it just defines the population
    of interest, in our case "runs of the model with parameters meeting the
    constraints". The model may have other inputs/outputs which the investigator
    may choose to leave out. These are then exogenous variables and behave accordingly.

    :param variables: The endogenous variables. They are stored in ``self.variables``
                      as a dict keyed by each variable's ``name``.
    :param causal_dag: The causal DAG.
    :param constraints: The set of constraints relating the endogenous variables.
                        ``None`` (the default) is stored as an empty set.
    """

    # NOTE(review): no dataclass fields are declared on this class, so the
    # @dataclass-generated __eq__/__repr__ do not look at the attributes set in
    # __init__ -- confirm that is intended.

    # Project types are written as forward-reference strings so the annotations
    # are not evaluated when the class body is executed.
    def __init__(self, variables: Iterable["Variable"], causal_dag: "CausalDAG", constraints: set[str] = None):
        self.variables = {v.name: v for v in variables}
        self.causal_dag = causal_dag
        # Always store a fresh set so the caller's collection is never aliased.
        self.constraints = set(constraints) if constraints is not None else set()

    def hidden_variables(self) -> set["Variable"]:
        """Get the set of hidden variables.

        :return: The variables whose ``hidden`` attribute is truthy.
        :rtype: {Variable}
        """
        return {v for v in self.variables.values() if v.hidden}

    def __str__(self):
        """Render the scenario (variables and constraints) followed by the causal DAG.

        The scenario is described from ``self.variables`` and ``self.constraints``,
        which are the attributes ``__init__`` actually sets; there is no
        ``self.scenario`` attribute on this class.
        """
        variables = ", ".join(str(v) for v in self.variables.values())
        # Sort so the output is deterministic despite set iteration order.
        constraints = ", ".join(sorted(str(c) for c in self.constraints))
        return f"Scenario: variables=[{variables}], constraints=[{constraints}]\nCausal DAG:\n{self.causal_dag}"
157 changes: 0 additions & 157 deletions causal_testing/specification/scenario.py

This file was deleted.

Loading