❗ This repository is heavily based on the TableShift GitHub repository. Our code is built on top of TableShift and code from Ricardo Sandoval and Hardt & Kim (2023).
This is code to reproduce experiments in the paper:
Vivian Y. Nastl and Moritz Hardt. "Do causal predictors generalize better to new domains?", 2024.
Simply clone the repo, enter the root directory and create a local conda environment.
```bash
git clone https://github.com/socialfoundations/causal-features.git

# set up the environment
conda env create -f environment.yml
```

Run the following commands to test the local execution environment:

```bash
conda env create -f environment.yml
conda activate tableshift

# test the install by running the training script
python examples/run_expt.py
```

The final line above will print some detailed logging output as the script executes. When you see `training completed! test accuracy: 0.6221`, your environment is ready to go! (The exact accuracy may vary slightly due to randomness.)
The training script we run is located at `experiments_causal/run_experiment.py`.
It takes the following arguments:
- `experiment` (experiment to run)
- `model` (model to use)
- `cache_dir` (directory to cache raw data files to)
- `save_dir` (directory to save result files to)
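For example, a single run might look like the following minimal sketch. The flag names are assumed to mirror the argument names above (`--experiment` is confirmed further below), and the `--cache_dir` and `--save_dir` values are placeholders:

```bash
# train XGBoost on the Income task, caching data in ./tmp and writing results to ./results
python experiments_causal/run_experiment.py \
  --experiment acsincome \
  --model xgb \
  --cache_dir tmp \
  --save_dir results
```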
The full list of model names is given below. For more details on each algorithm, see TableShift.
| Model | Name in TableShift |
|---|---|
| XGBoost | xgb |
| LightGBM | lightgbm |
| SAINT | saint |
| NODE | node |
| Group DRO | group_dro |
| MLP | mlp |
| Tabular ResNet | resnet |
| Adversarial Label DRO | aldro |
| CORAL | deepcoral |
| MMD | mmd |
| DRO | dro |
| DANN | dann |
| TabTransformer | tabtransformer |
| MixUp | mixup |
| Label Group DRO | label_group_dro |
| IRM | irm |
| VREX | vrex |
| FT-Transformer | ft_transformer |
| IB-IRM | ib_irm |
| CausIRL CORAL | causirl_coral |
| CausIRL MMD | causirl_mmd |
| AND-Mask | and_mask |
All experiments were run as jobs submitted to a centralized cluster running the open-source HTCondor scheduler.
The script that launches these jobs is located at `experiments_causal/launch_experiments.py`.
We provide the raw results of our experiments in the folder `experiments_causal/results/`. It contains a single JSON file for each combination of task, feature selection, and trained model.
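To take a quick look at an individual result file, any JSON pretty-printer works. The commands below are only a sketch and make no assumptions about the internal layout of the results folder:

```bash
# list a few result files, then pretty-print the first one found
find experiments_causal/results -name '*.json' | head -n 5
python -m json.tool "$(find experiments_causal/results -name '*.json' | head -n 1)"
```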
Use the following Python scripts:
- Main result:
  - Figure in introduction: `experiments_causal/plot_paper_introduction_figure.py`
  - Figures in section "Empirical results": `experiments_causal/plot_paper_figures.py`
- Appendix:
  - Main results: `experiments_causal/plot_paper_appendix_figures.py`, `experiments_causal/plot_paper_appendix_figures_extra.py`, `experiments_causal/plot_paper_appendix_figures_extra2.py`
  - Anti-causal features: `experiments_causal/plot_paper_appendix_figures.py`
  - Causal machine learning: `experiments_causal/plot_add_on_causalml.py`
  - Causal discovery: `experiments_causal/plot_add_on_causal_discovery.py`
  - Random subsets: `experiments_causal/plot_add_on_random_subsets.py`
  - Ablation study: `experiments_causal/plot_experiment_ablation.py`
  - Empirical results across machine learning models: `experiments_causal/plot_add_on_models.py`
  - Synthetic experiments: `experiments_causal/synthetic_experiments.ipynb`
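For example, the figures from the main body of the paper can be regenerated as sketched below, assuming the plotting scripts read the provided results folder and take no command-line arguments:

```bash
# reproduce the introduction figure and the figures of the "Empirical results" section
python experiments_causal/plot_paper_introduction_figure.py
python experiments_causal/plot_paper_figures.py
```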
The datasets in our paper are either publicly available or provide open credentialized access.
The datasets with open credentialized access require signing a data use agreement. For the tasks ICU Mortality and ICU Length of Stay, you must additionally complete the CITI "Data or Specimens Only Research" training, as these datasets contain sensitive personal information.
Hence, these datasets must be manually fetched and stored locally.
A list of datasets, their names in our code, and the corresponding access levels is given below. The string identifier is the value that should be passed to the `--experiment` flag of `experiments_causal/run_experiment.py`.
The causal, arguably causal, and anti-causal feature sets are obtained by appending `_causal`, `_arguablycausal`, and `_anticausal` to the string identifier. Combined causal and anti-causal features use the suffix `_causal_anticausal`. Where available, the estimated parents from causal discovery algorithms are obtained by appending the algorithm's abbreviation in lowercase letters, for example `acsincome_pc`. Random subsets are indexed from 0 to 500 and are called via the suffix `_random_test_{index}`.
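As an illustration, the Income task can be run with several feature selections by varying only the experiment identifier. The flag names follow the argument list above, and the `--cache_dir` and `--save_dir` values are placeholders:

```bash
# run XGBoost on the Income task with different feature selections
for exp in acsincome acsincome_causal acsincome_arguablycausal acsincome_anticausal acsincome_random_test_0; do
  python experiments_causal/run_experiment.py --experiment "$exp" --model xgb --cache_dir tmp --save_dir results
done
```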
| Tasks | String Identifier | Availability | Source | Preprocessing |
|---|---|---|---|---|
| Voting | `anes` | Public Credentialized Access (source) | American National Election Studies (ANES) | TableShift |
| ASSISTments | `assistments` | Public | Kaggle | TableShift |
| Childhood Lead | `nhanes_lead` | Public | National Health and Nutrition Examination Survey (NHANES) | TableShift |
| College Scorecard | `college_scorecard` | Public | College Scorecard | TableShift |
| Diabetes | `brfss_diabetes` | Public | Behavioral Risk Factor Surveillance System (BRFSS) | TableShift |
| Food Stamps | `acsfoodstamps` | Public | American Community Survey (via folktables) | |
| Hospital Readmission | `diabetes_readmission` | Public | UCI | TableShift |
| Hypertension | `brfss_blood_pressure` | Public | Behavioral Risk Factor Surveillance System (BRFSS) | TableShift |
| ICU Length of Stay | `mimic_extract_los_3` | Public Credentialized Access (source) | MIMIC-III via MIMIC-Extract | TableShift |
| ICU Mortality | `mimic_extract_mort_hosp` | Public Credentialized Access (source) | MIMIC-III via MIMIC-Extract | TableShift |
| Income | `acsincome` | Public | American Community Survey (via folktables) | TableShift |
| Public Health Insurance | `acspubcov` | Public | American Community Survey (via folktables) | TableShift |
| Sepsis | `physionet` | Public | PhysioNet | TableShift |
| Unemployment | `acsunemployment` | Public | American Community Survey (via folktables) | TableShift |
| Utilization | `meps` | Public (source) | Medical Expenditure Panel Survey (MEPS) | Hardt & Kim (2023) |
| Poverty | `sipp` | Public (source, source) | Survey of Income and Program Participation (SIPP) | Hardt & Kim (2023) |
TableShift includes the preprocessing of the data files in its implementation. For the tasks Utilization and Poverty, follow the instructions provided by Hardt & Kim (2023) in `backward_predictor/README.md`.
We list below the files/folders we changed for our experiments:
- created folder `experiments_causal` with Python scripts to run experiments, launch experiments on a cluster, and plot figures for the paper
- created folder `backward_predictor` with preprocessing files adapted from Hardt & Kim (2023), with `backward_predictor/sipp/data/data_cleaning.ipynb` © Ricardo Sandoval, 2024
- added tasks `meps` and `sipp`, as well as feature selections of all tasks, in their respective Python scripts in the folder `tableshift/datasets`
- added data sources for `meps` and `sipp` in `tableshift/core/data_source.py`
- added tasks `meps` and `sipp`, as well as feature selections of all tasks, in `tableshift/core/tasks.py`
- added configurations for tasks and their feature selections in `tableshift/configs/non_benchmark_configs.py`
- added models `ib_erm`, `ib_irm`, `causirl_coral`, `causirl_mmd` and `and_mask` in `tableshift/models`, adapted from Gulrajani & Lopez-Paz (2021)
- added configurations for the hyperparameters of the added models in `tableshift/configs/hparams.py`
- added computation of balanced accuracy in `tableshift/models/torchutils.py` and adapted `tableshift/models/compat.py` accordingly
- minor fixes in `tableshift/core/features.py`, `tableshift/core/tabular_dataset.py` and `tableshift/models/training.py`
- added the packages `paretoset==1.2.3` and `seaborn==0.13.0` to `requirements.txt`
This repository contains code and supplementary materials for the following paper:
```bibtex
@inproceedings{nastl2024predictors,
 author = {Nastl, Vivian and Hardt, Moritz},
 booktitle = {Advances in Neural Information Processing Systems},
 editor = {A. Globerson and L. Mackey and D. Belgrave and A. Fan and U. Paquet and J. Tomczak and C. Zhang},
 pages = {31202--31315},
 publisher = {Curran Associates, Inc.},
 title = {Do causal predictors generalize better to new domains?},
 url = {https://proceedings.neurips.cc/paper_files/paper/2024/file/3792ddbf94b68ff4369f510f7a3e1777-Paper-Conference.pdf},
 volume = {37},
 year = {2024}
}
```