A JAX-based library for building interpretable Bayesian models using piecewise linear regression and gradient-flow adaptive importance sampling for leave-one-out cross-validation.
Bayesianquilts provides tools for building truly interpretable input-output maps based on the principle of piecewise linearity. Rather than using black-box neural networks, this library combines representation learning, clustering, and multilevel linear regression modeling to create transparent, interpretable models suitable for high-stakes applications like healthcare and scientific research.
The library includes two major research contributions:
- Piecewise Linear Regression Models: An additive decomposition approach where parameter values arise as sums of contributions at different length scales
- Gradient-Flow Adaptive Importance Sampling (AIS): Gradient-flow transformations of posterior samples that stabilize importance-sampling estimates for Bayesian leave-one-out cross-validation
Key features:
- Interpretable by Design: Models are constructed to be inherently interpretable, not just post-hoc explainable
- Parameter Decomposition: Additive decomposition of parameters across interaction dimensions
- Flexible Model Types: Supports classification, regression, and matrix factorization
- Advanced Cross-Validation: Gradient-flow adaptive importance sampling for LOO-CV
- Bayesian Inference: Full support for variational inference (ADVI) and importance sampling
- JAX-Accelerated: Built on JAX for GPU/TPU acceleration and automatic differentiation
- Robust Training: Includes gradient clipping, learning rate scheduling, NaN recovery, and checkpointing
Install from PyPI:

```bash
pip install bayesianquilts
```

Or install from source:

```bash
git clone https://github.com/mederrata/bayesianquilts.git
cd bayesianquilts
pip install -r requirements.txt
pip install -e .
```

Requirements:
- Python >= 3.8
- JAX >= 0.7.1
- TensorFlow Probability >= 0.25.0
- Flax >= 0.11.2
- NumPy, Pandas, SciPy, Scikit-learn
- Optax (optimization)
- Orbax (checkpointing)
- ArviZ (Bayesian diagnostics)
See requirements.txt for the complete dependency list.
Quick start:

```python
import jax
import jax.numpy as jnp

from bayesianquilts.predictors.classification import LogisticBayesianquilt
from bayesianquilts.util import training_loop

# Prepare your data (replace ... with your own arrays)
X_train = jnp.array(...)  # Features, shape (n_samples, n_features)
y_train = jnp.array(...)  # Labels, shape (n_samples,)

# Create model
model = LogisticBayesianquilt(
    num_features=X_train.shape[1],
    num_classes=2,
)

# Initialize parameters with an explicit JAX PRNG key
random_key = jax.random.PRNGKey(0)
params = model.initialize(random_key)

# Train with built-in utilities
losses, trained_params = training_loop(
    initial_values=params,
    loss_fn=lambda p: model.loss(p, X_train, y_train),
    num_epochs=100,
    learning_rate=0.01,
    clip_norm=1.0,  # gradient clipping
    patience=10,    # early stopping
)
```

Leave-one-out cross-validation with AIS:

```python
from bayesianquilts.metrics.ais import (
    AdaptiveImportanceSampler,
    LogisticRegressionLikelihood,
)

# Define likelihood function
likelihood_fn = LogisticRegressionLikelihood()

# Create AIS sampler; prior_fn and surrogate_fn are user-supplied
# log-density functions (see test_ais_framework.py for examples)
ais_sampler = AdaptiveImportanceSampler(
    likelihood_fn,
    prior_log_prob_fn=prior_fn,
    surrogate_log_prob_fn=surrogate_fn,
)

# Compute LOO-CV with multiple transformation strategies
results = ais_sampler.adaptive_is_loo(
    data={'X': X_train, 'y': y_train},
    params=trained_params,
    hbar=1.0,
    variational=False,
    transformations=['ll', 'kl', 'var', 'identity'],
)

# Access results
print(f"LOO log-likelihood: {results['ll_loo_psis']}")
print(f"Effective parameters (p_loo): {results['p_loo_psis']}")
print(f"PSIS k-hat diagnostic: {results['khat']}")
```

The fundamental innovation is an additive decomposition of model parameters:
θ_effective = θ_global + θ_group1 + θ_group2 + ... + θ_local
Each parameter value arises as a sum of contributions at different hierarchical levels (length scales), enabling:
- Automatic regularization through hierarchical priors
- Interpretable multi-level effects
- Interaction modeling across categorical and continuous variables
See notebooks/decomposition.ipynb for detailed examples.
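The additive decomposition can be illustrated with a small NumPy sketch. It does not use the library's Decomposed or Interactions classes (their API is not shown here); the grouping variables "site" and "age_band", their level counts, and the offset scales are all hypothetical. Each row's coefficient vector is the global vector plus one offset per grouping level, so the linear predictor is linear within a group cell and changes piecewise across cells.

```python
import numpy as np

# Hypothetical toy setup: 3 features and two categorical grouping variables,
# "site" (2 levels) and "age_band" (3 levels). The shrinking offset scales
# (0.5, 0.25) are illustrative only, loosely mimicking priors that pull
# finer-level contributions toward zero.
rng = np.random.default_rng(0)
num_features = 3
theta_global = rng.normal(size=num_features)            # shared by all rows
theta_site = 0.5 * rng.normal(size=(2, num_features))   # one offset per site
theta_age = 0.25 * rng.normal(size=(3, num_features))   # one offset per age band


def effective_coefficients(site, age_band):
    """theta_effective = theta_global + theta_site[site] + theta_age[age_band]."""
    return theta_global + theta_site[site] + theta_age[age_band]


def linear_predictor(X, site, age_band):
    """Row-wise x . theta_effective: linear within a group, piecewise across groups."""
    return np.sum(X * effective_coefficients(site, age_band), axis=1)


X = rng.normal(size=(4, num_features))
site = np.array([0, 1, 0, 1])
age_band = np.array([0, 0, 2, 1])
eta = linear_predictor(X, site, age_band)  # shape (4,)
```

Because every row in the same (site, age_band) cell shares one coefficient vector, each effective coefficient can be read off level by level as a sum of interpretable contributions.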
The AIS framework implements gradient-flow transformations for stable LOO-CV:
- T_ll: Likelihood descent using negative log-likelihood gradients
- T_kl: KL-divergence weighted gradients using posterior weights
- T_var: Variance-based adaptation using Hessian curvature
- T_I: Identity (baseline, no transformation)
These transformations are combined with Pareto smoothed importance sampling (PSIS) for robust importance-weight estimation.
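To make the likelihood-descent idea concrete, here is a NumPy/SciPy sketch for a toy logistic model. It is not the library's AdaptiveImportanceSampler: the data, the N(0, 2^2) prior, the widened Laplace proposal, and the step size h are all assumptions. Proposal samples are pushed by the map theta -> theta - h * grad log l_i(theta), and the log Jacobian determinant enters the importance weights for the leave-one-out posterior. Raw self-normalized weights are used here; PSIS would smooth them. Setting h = 0 recovers the identity baseline T_I.

```python
import numpy as np
from scipy.special import expit, logsumexp

rng = np.random.default_rng(1)

# Hypothetical toy data: n points, d features, labels in {0, 1}.
n, d = 20, 2
X = rng.normal(size=(n, d))
y = (X @ np.array([1.0, -0.5]) + rng.normal(size=n) > 0).astype(float)
PRIOR_SCALE = 2.0


def log_lik(theta, X, y):
    """Bernoulli-logit log-likelihood, shape (num_samples, num_points)."""
    z = theta @ X.T
    return y * z - np.logaddexp(0.0, z)


def log_prior(theta):
    """Independent N(0, PRIOR_SCALE^2) prior, up to a constant."""
    return -0.5 * np.sum((theta / PRIOR_SCALE) ** 2, axis=-1)


# Proposal q0: a widened Laplace approximation to the full-data posterior.
theta_map = np.zeros(d)
for _ in range(25):  # Newton iterations for the MAP estimate
    p = expit(X @ theta_map)
    grad = X.T @ (y - p) - theta_map / PRIOR_SCALE**2
    hess = -(X.T * (p * (1 - p))) @ X - np.eye(d) / PRIOR_SCALE**2
    theta_map = theta_map - np.linalg.solve(hess, grad)
cov = 2.0 * np.linalg.inv(-hess)  # inflated so q0 covers the posterior tails
theta = rng.multivariate_normal(theta_map, cov, size=4000)
diff = theta - theta_map
log_q0 = -0.5 * np.einsum("sd,de,se->s", diff, np.linalg.inv(cov), diff)  # up to a constant


def t_ll(theta, x_i, y_i, h):
    """Map theta -> theta - h * grad log l_i(theta); also return log|det J|."""
    s = expit(theta @ x_i)
    theta_new = theta - h * (y_i - s)[:, None] * x_i[None, :]
    # J = I + h s(1-s) x_i x_i^T, so det J = 1 + h s(1-s) |x_i|^2 (determinant lemma).
    return theta_new, np.log1p(h * s * (1 - s) * (x_i @ x_i))


def loo_elpd_i(i, h):
    """Self-normalized IS estimate of log p(y_i | data without point i)."""
    theta_t, log_det = t_ll(theta, X[i], y[i], h)
    keep = np.arange(n) != i
    log_target = log_prior(theta_t) + log_lik(theta_t, X[keep], y[keep]).sum(axis=1)
    # The transformed proposal has density q0(theta) / |det J| at theta_t.
    log_w = log_target + log_det - log_q0
    log_w -= logsumexp(log_w)
    ll_i = log_lik(theta_t, X[i : i + 1], y[i : i + 1])[:, 0]
    return logsumexp(log_w + ll_i), log_w
```

Because the map only moves samples along x_i and its derivative in that direction is always positive, it is invertible, which is what makes the Jacobian correction valid.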
Classification:
- LogisticBayesianquilt: Piecewise linear logistic regression
- LogisticRegression: Standard Bayesian logistic regression with decomposition
- LogisticRelunet: ReLU neural network classifier
- LogisticGamiNet: Generalized additive model with neural networks
- LogisticRidge: Ridge-regularized logistic regression

Regression:
- RegressionQuilt: Piecewise linear regression
- HierarchicalAttention: Attention-based regression

Matrix factorization:
- GaussianFactorization: Continuous latent factor models
- PoissonFactorization: Count data factorization
- BernoulliFactorization: Binary data factorization

Neural network components:
- DenseHorseshoe: Dense layers with horseshoe priors
- DenseGaussian: Dense layers with Gaussian priors
- GamiNetUnivariate: Univariate shape functions
- GamiNetPairwise: Pairwise interaction networks
The util.py module provides robust training infrastructure:
```python
from bayesianquilts.util import training_loop

# initial_params, loss_function, and data_batches are placeholders for your own
# initial parameters, loss callable, and minibatch iterator
losses, params = training_loop(
    initial_values=initial_params,
    loss_fn=loss_function,
    data_iterator=data_batches,
    steps_per_epoch=100,
    num_epochs=50,
    learning_rate=0.01,
    clip_norm=1.0,             # Gradient clipping
    patience=10,               # Early stopping
    lr_decay_factor=0.5,       # Learning rate decay
    checkpoint_dir="./ckpts",  # Automatic checkpointing
    recover_from_nan=True,     # NaN recovery strategies
)
```

Features:
- Gradient clipping for stability
- Learning rate scheduling with decay
- Early stopping with patience
- Automatic checkpointing with Orbax
- NaN/Inf detection and recovery
- Progress tracking with tqdm
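The mechanisms listed above can be sketched without any dependencies beyond NumPy. This is not the library's training_loop (its Orbax checkpointing and tqdm progress bars are omitted); the `train` function and its plateau rule, which decays the learning rate on every non-improving epoch, are illustrative assumptions. It shows how global-norm clipping, learning-rate decay, early stopping with patience, and NaN rollback fit together in one loop.

```python
import numpy as np


def clip_by_global_norm(grad, max_norm):
    """Rescale grad so that its L2 norm is at most max_norm."""
    norm = np.linalg.norm(grad)
    return grad * (max_norm / norm) if norm > max_norm else grad


def train(loss_and_grad, params, num_epochs=200, learning_rate=0.1,
          clip_norm=1.0, patience=10, lr_decay_factor=0.5, tol=1e-8):
    """Full-batch gradient descent with clipping, decay, early stopping, NaN rollback."""
    best_loss, best_params, bad_epochs = np.inf, params.copy(), 0
    losses = []
    for _ in range(num_epochs):
        loss, grad = loss_and_grad(params)
        if not (np.isfinite(loss) and np.all(np.isfinite(grad))):
            # NaN/Inf recovery: roll back to the best parameters and shrink the step.
            params = best_params.copy()
            learning_rate *= lr_decay_factor
            continue
        losses.append(loss)
        if loss < best_loss - tol:
            best_loss, best_params, bad_epochs = loss, params.copy(), 0
        else:
            bad_epochs += 1
            learning_rate *= lr_decay_factor  # decay whenever the loss stalls
            if bad_epochs >= patience:        # early stopping
                break
        params = params - learning_rate * clip_by_global_norm(grad, clip_norm)
    return np.array(losses), best_params


# Usage on a quadratic loss with minimum at `target`
target = np.array([3.0, -2.0])
losses, best = train(lambda w: (np.sum((w - target) ** 2), 2 * (w - target)), np.zeros(2))
```

Returning the best parameters seen, rather than the last iterate, keeps a late NaN or a stalled run from discarding earlier progress.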
Bayesianquilts includes several custom probability distributions:
- GeneralizedGamma: Flexible shape for positive continuous data
- PiecewiseExponential: For survival/duration modeling
- TransformedHorseshoe: Sparsity-inducing priors
- TransformedCauchy: Heavy-tailed priors
- TransformedInverseGamma: Scale parameter priors
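As an illustration of what a piecewise exponential survival distribution typically computes, the NumPy sketch below evaluates the standard piecewise-constant-hazard log density, log f(t) = log h(t) - H(t), where H is the cumulative hazard. The function name and its arguments are hypothetical and do not reflect the signature of the library's PiecewiseExponential class.

```python
import numpy as np


def piecewise_exponential_log_prob(t, breakpoints, rates):
    """Log density of a piecewise-constant-hazard model, for t >= 0.

    breakpoints: increasing interval starts [0, c1, ..., c_{K-1}]; the last
    interval extends to infinity. rates: hazard rate in each of the K intervals.
    """
    t = np.asarray(t, dtype=float)
    starts = np.asarray(breakpoints, dtype=float)
    rates = np.asarray(rates, dtype=float)
    ends = np.append(starts[1:], np.inf)
    # Time spent in each interval before t, shape (..., K)
    exposure = np.clip(t[..., None] - starts, 0.0, ends - starts)
    cumulative_hazard = exposure @ rates  # H(t)
    # Hazard at t is the rate of the interval containing t
    k = np.searchsorted(starts, t, side="right") - 1
    return np.log(rates[k]) - cumulative_hazard
```

With a single interval this reduces to the exponential log density log(rate) - rate * t.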
Examples:
- notebooks/decomposition.ipynb: Parameter decomposition methodology
- notebooks/ovarian/: Medical claims modeling examples
- notebooks/roach/: Logistic regression case studies
- notebooks/enset/: Model comparison demonstrations
- test_ais_framework.py: Complete AIS usage examples
- test_gradient_clipping.py: Training utilities demonstration
This library implements methods from:
- Chang TL, Xia H, Mahajan S, Mahajan R, Maisog J, et al. (2024). Interpretable (not just posthoc-explainable) medical claims modeling for discharge placement to reduce preventable all-cause readmissions or death. PLOS ONE 19(5): e0302871. https://doi.org/10.1371/journal.pone.0302871
- Xia H, Chang JC, Nowak S, Mahajan S, Mahajan R, Chang TL, Chow CC (2023). Proceedings of the 8th Machine Learning for Healthcare Conference, PMLR 219:884-905.
- Chang JC, Li X, Xu S, Yao HR, Porcino J, Chow CC (2024). Gradient-flow adaptive importance sampling for Bayesian leave one out cross-validation with application to sigmoidal classification models. ArXiv [Preprint] 2402.08151v2. PMID: 38711425; PMCID: PMC11071546. https://arxiv.org/abs/2402.08151
Project structure:

```text
bayesianquilts/
├── bayesianquilts/
│   ├── model.py              # Base BayesianModel class
│   ├── util.py               # Training loops and utilities
│   ├── features.py           # Feature engineering
│   ├── jax/
│   │   └── parameter.py      # Parameter decomposition (Decomposed, Interactions)
│   ├── predictors/
│   │   ├── classification/   # Classification models
│   │   ├── regression/       # Regression models
│   │   ├── nn/               # Neural network components
│   │   └── factorization/    # Matrix factorization
│   ├── metrics/
│   │   ├── ais.py            # Adaptive importance sampling
│   │   ├── psis.py           # Pareto smoothed IS
│   │   └── nppsis.py         # NumPy/JAX PSIS
│   ├── vi/
│   │   ├── advi.py           # ADVI implementation
│   │   └── minibatch.py      # Minibatch VI
│   ├── distributions/        # Custom distributions
│   └── plotting/             # Visualization utilities
├── notebooks/                # Example notebooks
├── requirements.txt          # Dependencies
└── setup.py                  # Package setup
```
The API is currently evolving as we prepare manuscripts on the methodology and theory. We will stabilize the API in future releases. For production use, please pin to specific versions.
Contributions are welcome! Please:
- Fork the repository
- Create a feature branch
- Add tests for new functionality
- Ensure all tests pass
- Submit a pull request
MIT License - see LICENSE file for details.
- Organization: Mederrata Research LLC (501(c)(3) non-profit)
- Email: [email protected]
- Repository: https://github.com/mederrata/bayesianquilts
Mederrata Research LLC is a 501(c)(3) non-profit organization. Tax-deductible monetary contributions are welcome and help support the development and maintenance of open-source tools for interpretable machine learning in healthcare and scientific research.
To make a contribution or learn more, please contact us at [email protected].
If you use this library in your research, please cite:
```bibtex
@article{chang2024interpretable,
  title={Interpretable (not just posthoc-explainable) medical claims modeling for discharge placement to reduce preventable all-cause readmissions or death},
  author={Chang, Ted L and Xia, Hongjing and Mahajan, Sonya and Mahajan, Rohit and Maisog, Jose and others},
  journal={PLOS ONE},
  volume={19},
  number={5},
  pages={e0302871},
  year={2024},
  publisher={Public Library of Science}
}

@article{chang2024gradient,
  title={Gradient-flow adaptive importance sampling for Bayesian leave one out cross-validation with application to sigmoidal classification models},
  author={Chang, Joshua C and Li, Xu and Xu, Shuang and Yao, Howard R and Porcino, John and Chow, Carson C},
  journal={arXiv preprint arXiv:2402.08151},
  year={2024}
}
```

This work was developed by the Mederrata Research team with support from the research community. Special thanks to all contributors and users who have provided feedback and helped improve the library.