1 change: 0 additions & 1 deletion causal_testing/main.py
@@ -563,7 +563,6 @@ def parse_args(args: Optional[Sequence[str]] = None) -> argparse.Namespace:
     args = main_parser.parse_args(args)
 
     # Assume the user wants test adequacy if they're setting bootstrap_size
-    print(args)
     if getattr(args, "bootstrap_size", None) is not None:
         args.adequacy = True
     if getattr(args, "adequacy", False) and getattr(args, "bootstrap_size", None) is None:
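The hunk above drops a leftover debug print; the logic that remains makes setting a bootstrap size imply adequacy checking. A minimal runnable sketch of that implied-flag pattern follows; the flag spellings and the fallback default are assumptions for illustration, not the framework's actual CLI.

    # Sketch of the implied-flag pattern; flag names and the fallback default
    # are assumed, not taken from causal_testing/main.py.
    import argparse
    from typing import Optional, Sequence

    def parse_args(args: Optional[Sequence[str]] = None) -> argparse.Namespace:
        parser = argparse.ArgumentParser()
        parser.add_argument("--bootstrap-size", type=int, default=None)
        parser.add_argument("--adequacy", action="store_true")
        parsed = parser.parse_args(args)
        # Assume the user wants test adequacy if they're setting bootstrap_size
        if getattr(parsed, "bootstrap_size", None) is not None:
            parsed.adequacy = True
        # The diff elides this branch's body; a default size is assumed here
        if getattr(parsed, "adequacy", False) and getattr(parsed, "bootstrap_size", None) is None:
            parsed.bootstrap_size = 100
        return parsed

    print(parse_args(["--bootstrap-size", "50"]).adequacy)  # True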
20 changes: 14 additions & 6 deletions causal_testing/testing/causal_test_adequacy.py
@@ -83,7 +83,8 @@ def __init__(
     ):
         self.test_case = test_case
         self.kurtosis = None
-        self.outcomes = None
+        self.passing = None
+        self.results = None
         self.successful = None
         self.bootstrap_size = bootstrap_size
         self.group_by = group_by
@@ -93,6 +94,7 @@ def measure_adequacy(self):
         Calculate the adequacy measurement, and populate the data_adequacy field.
         """
         results = []
+        outcomes = []
         for i in range(self.bootstrap_size):
             estimator = deepcopy(self.test_case.estimator)
 
@@ -103,7 +105,9 @@
             else:
                 estimator.df = estimator.df.sample(len(estimator.df), replace=True, random_state=i)
             try:
-                results.append(self.test_case.execute_test(estimator))
+                result = self.test_case.execute_test(estimator)
+                outcomes.append(self.test_case.expected_causal_effect.apply(result))
+                results.append(result.effect_estimate.to_df())
             except LinAlgError:
                 logger.warning("Adequacy LinAlgError")
                 continue
@@ -113,19 +117,23 @@
             except ValueError as e:
                 logger.warning(f"Adequacy ValueError: {e}")
                 continue
-        outcomes = [self.test_case.expected_causal_effect.apply(c) for c in results]
-        results = pd.concat([c.effect_estimate.to_df() for c in results])
+        # outcomes = [self.test_case.expected_causal_effect.apply(c) for c in results]
+        # results = pd.concat([c.effect_estimate.to_df() for c in results])
+        results = pd.concat(results)
         results["var"] = results.index
+        results["passed"] = outcomes
 
+        self.results = results
         self.kurtosis = results.groupby("var")["effect_estimate"].apply(lambda x: x.kurtosis())
-        self.outcomes = sum(filter(lambda x: x is not None, outcomes))
+        self.passing = sum(filter(lambda x: x is not None, outcomes))
         self.successful = sum(x is not None for x in outcomes)
 
     def to_dict(self):
         """Returns the adequacy object as a dictionary."""
         return {
             "kurtosis": self.kurtosis.to_dict(),
             "bootstrap_size": self.bootstrap_size,
-            "passing": self.outcomes,
+            "passing": self.passing,
             "successful": self.successful,
+            "results": self.results.reset_index(drop=True).to_dict(),
         }
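The key change in measure_adequacy is that each bootstrap iteration now records its oracle verdict and its effect estimate together inside the try block, so an iteration skipped by LinAlgError or ValueError can no longer desynchronise the two lists; passing (verdicts that held) is also reported separately from successful (iterations that produced a verdict). A self-contained toy sketch of that loop follows: the estimate and oracle callables stand in for execute_test and expected_causal_effect.apply, and the group_by resampling branch is elided, as in the hunk above.

    # Toy sketch of the per-iteration pairing; estimate/oracle stand in for
    # test_case.execute_test and test_case.expected_causal_effect.apply.
    import numpy as np
    import pandas as pd

    def measure_adequacy(df, estimate, oracle, bootstrap_size=100):
        results, outcomes = [], []
        for i in range(bootstrap_size):
            sample = df.sample(len(df), replace=True, random_state=i)
            try:
                effect = estimate(sample)
                # Recording the verdict next to its estimate keeps the lists
                # aligned even when an iteration raises and is skipped.
                outcomes.append(oracle(effect))
                results.append(pd.DataFrame({"effect_estimate": [effect]}, index=["x"]))
            except np.linalg.LinAlgError:
                continue
        frame = pd.concat(results)
        frame["var"] = frame.index
        kurtosis = frame.groupby("var")["effect_estimate"].apply(lambda x: x.kurtosis())
        passing = sum(filter(lambda x: x is not None, outcomes))   # verdicts that held
        successful = sum(x is not None for x in outcomes)          # iterations with a verdict
        return kurtosis, passing, successful

    df = pd.DataFrame({"x": np.random.default_rng(0).normal(size=50)})
    print(measure_adequacy(df, lambda s: s["x"].mean(), lambda e: abs(e) < 0.5))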
36 changes: 27 additions & 9 deletions tests/testing_tests/test_causal_test_adequacy.py
@@ -48,10 +48,17 @@ def test_data_adequacy_numeric(self):
         )
         adequacy_metric = DataAdequacy(causal_test_case)
         adequacy_metric.measure_adequacy()
+
+        self.assertEqual(
+            adequacy_metric.kurtosis["test_input"],
+            0,
+            f"Expected kurtosis 0 not {adequacy_metric.kurtosis['test_input']}",
+        )
         self.assertEqual(
-            adequacy_metric.to_dict(),
-            {"kurtosis": {"test_input": 0.0}, "bootstrap_size": 100, "passing": 100, "successful": 100},
+            adequacy_metric.bootstrap_size, 100, f"Expected bootstrap size 100 not {adequacy_metric.bootstrap_size}"
         )
+        self.assertEqual(adequacy_metric.passing, 100, f"Expected passing 100 not {adequacy_metric.passing}")
+        self.assertEqual(adequacy_metric.successful, 100, f"Expected successful 100 not {adequacy_metric.successful}")
 
     def test_data_adequacy_categorical(self):
         base_test_case = BaseTestCase(
@@ -68,10 +75,17 @@ def test_data_adequacy_categorical(self):
         )
         adequacy_metric = DataAdequacy(causal_test_case)
         adequacy_metric.measure_adequacy()
+
         self.assertEqual(
-            adequacy_metric.to_dict(),
-            {"kurtosis": {"test_input_no_dist[T.b]": 0.0}, "bootstrap_size": 100, "passing": 100, "successful": 100},
+            adequacy_metric.kurtosis["test_input_no_dist[T.b]"],
+            0,
+            f"Expected kurtosis 0 not {adequacy_metric.kurtosis['test_input_no_dist[T.b]']}",
         )
+        self.assertEqual(
+            adequacy_metric.bootstrap_size, 100, f"Expected bootstrap size 100 not {adequacy_metric.bootstrap_size}"
+        )
+        self.assertEqual(adequacy_metric.passing, 100, f"Expected passing 100 not {adequacy_metric.passing}")
+        self.assertEqual(adequacy_metric.successful, 100, f"Expected successful 100 not {adequacy_metric.successful}")
 
     def test_data_adequacy_group_by(self):
         timesteps_per_intervention = 1
@@ -102,13 +116,17 @@
         )
         adequacy_metric = DataAdequacy(causal_test_case, group_by="id")
         adequacy_metric.measure_adequacy()
-        adequacy_dict = adequacy_metric.to_dict()
-        self.assertEqual(round(adequacy_dict["kurtosis"]["trtrand"], 3), -0.857)
-        adequacy_dict.pop("kurtosis")
+
+        self.assertEqual(
+            round(adequacy_metric.kurtosis["trtrand"], 3),
+            -0.857,
+            f"Expected kurtosis -0.857 not {round(adequacy_metric.kurtosis['trtrand'], 3)}",
+        )
         self.assertEqual(
-            adequacy_dict,
-            {"bootstrap_size": 100, "passing": 32, "successful": 100},
+            adequacy_metric.bootstrap_size, 100, f"Expected bootstrap size 100 not {adequacy_metric.bootstrap_size}"
         )
+        self.assertEqual(adequacy_metric.passing, 32, f"Expected passing 32 not {adequacy_metric.passing}")
+        self.assertEqual(adequacy_metric.successful, 100, f"Expected successful 100 not {adequacy_metric.successful}")
 
     def test_dag_adequacy_dependent(self):
         base_test_case = BaseTestCase(
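The group_by run expects passing=32 alongside successful=100, which is exactly the split the reworked attributes encode: successful counts iterations whose oracle returned a verdict, passing counts the verdicts that held. A toy illustration, where the outcome values are assumed and only their 32-of-100 shape is taken from the test above:

    # Toy illustration of the passing/successful split; outcome values assumed.
    outcomes = [True] * 32 + [False] * 68
    passing = sum(filter(lambda x: x is not None, outcomes))
    successful = sum(x is not None for x in outcomes)
    print(passing, successful)  # 32 100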