Discover extreme patterns in your data with GPU/TPU-accelerated Archetypal Analysis, high-performance convex hull optimization, and interpretable matrix factorization.
- Overview
- Features
- Installation
- Quick Start
- Import Patterns
- Documentation
- Examples
- How It Works
- Changelog
- Citation
- License
- Contributing
- Community
archetypax is a high-performance implementation of Archetypal Analysis (AA) leveraging JAX for GPU acceleration.
Archetypal Analysis is a powerful matrix factorization technique representing data points
as convex combinations of extreme points (archetypes) found within the data's convex hull.
Unlike traditional dimensionality reduction techniques like PCA, which finds abstract orthogonal components,
AA discovers interpretable extremal points often corresponding to meaningful prototypes.
This makes it valuable for applications requiring both dimensionality reduction and human-interpretable insights,
such as market segmentation, document analysis, and anomaly detection.
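The convex-combination idea can be made concrete with a small NumPy sketch. It is not archetypax code, and the three archetypes are hand-picked for illustration. With exactly d + 1 archetypes in d dimensions and a point inside their hull, the weights are the point's barycentric coordinates and come from a single linear solve. archetypax instead learns both the archetypes and the weights by optimization.

```python
import numpy as np

# Three hand-picked archetypes in 2D (one per row): corners of a triangle.
archetypes = np.array([[0.0, 0.0],
                       [1.0, 0.0],
                       [0.5, 0.866]])

# A data point inside the triangle, i.e. inside the archetypes' convex hull.
x = np.array([0.5, 0.3])

# Solve w @ archetypes == x together with sum(w) == 1.
# With d + 1 = 3 archetypes in d = 2 dimensions this is a square 3x3 system.
system = np.vstack([archetypes.T, np.ones(3)])
rhs = np.append(x, 1.0)
w = np.linalg.solve(system, rhs)

print("weights:", np.round(w, 3))
print("reconstruction:", w @ archetypes)
```

All three weights are non-negative here, so the point is read as a mixture of the three archetypes. A point outside the triangle would need a negative weight, and archetypal analysis rules that out.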
Performance & Stability
- GPU/TPU acceleration using JAX
- Smart initialization (k-means++, directional)
- Numerical stability & convergence techniques
Usability & Compatibility
- scikit-learn compatible API (fit/transform)
- Thorough documentation
Interpretability & Visualization
- Meaningful interpretable archetypes
- Advanced tracking & optimization trajectory monitoring
- Comprehensive evaluation & visualization tooling
ArchetypAX can be used alongside or compared with these related approaches:
- PCA: Principal Component Analysis finds orthogonal directions of maximum variance
- NMF: Non-negative Matrix Factorization decomposes data into non-negative components
- k-means: Clustering technique that partitions data into k clusters
- JAX Ecosystem: Compatible with JAX-based machine learning frameworks like Flax
- scikit-learn: Follows similar API conventions, allowing easy integration
Install with pip, uv, or poetry:
```bash
# For each tool: the first line installs the released package,
# the second installs the latest code directly from GitHub.

# pip
pip install archetypax
pip install git+https://github.com/lv416e/archetypax.git

# uv
uv pip install archetypax
uv pip install git+https://github.com/lv416e/archetypax.git

# poetry
poetry add archetypax
poetry add git+https://github.com/lv416e/archetypax.git
```

Install optional dependencies:

```bash
pip install "archetypax[dev]"       # Development dependencies
pip install "archetypax[examples]"  # Example dependencies
pip install "archetypax[docs]"      # Documentation dependencies
```

| Type | Dependency | Version | Description |
|---|---|---|---|
| Core | Python | >=3.10 | Required for modern language features and compatibility with JAX |
| Core | JAX | >=0.4.0 | Powers the hardware acceleration and automatic differentiation |
| Core | NumPy | >=1.20.0 | Handles core numerical operations and array manipulations |
| Core | optax | >=0.1.0 | JAX-based optimization framework for gradient-based updates |
| Core | pandas | >=1.3.0 | Data manipulation and analysis library |
| Core | scikit-learn | >=1.0.0 | Provides machine learning utilities and compatible interfaces |
| Examples | jupyter | >=1.0.0 | Interactive computing environment for notebooks |
| Examples | matplotlib | >=3.7.5 | Required for visualization functionality |
| Examples | seaborn | >=0.13.2 | Statistical data visualization |
| Dev | black | ==23.7.0 | Code formatter |
| Dev | mypy | >=1.8.0 | Static type checker |
| Dev | pytest | >=7.0.0 | Testing framework |
| Dev | ruff | >=0.9.0 | Fast Python linter and formatter |
```python
import numpy as np
from archetypax import ImprovedArchetypalAnalysis

# Generate sample data
np.random.seed(42)
X = np.random.rand(1000, 10)

# Initialize and fit the model
model = ImprovedArchetypalAnalysis(n_archetypes=5)
weights = model.fit_transform(X)

# Get the archetypes
archetypes = model.archetypes

# Reconstruct the data
X_reconstructed = model.reconstruct()

# Calculate reconstruction error
mse = np.mean((X - X_reconstructed) ** 2)
print(f"Reconstruction MSE: {mse:.6f}")
```

ArchetypAX supports multiple import patterns for flexibility:
```python
# Top-level imports
from archetypax import ArchetypalAnalysis, ImprovedArchetypalAnalysis, BiarchetypalAnalysis, ArchetypeTracker
```

```python
# Explicit module paths
from archetypax.models.base import ArchetypalAnalysis
from archetypax.models.biarchetypes import BiarchetypalAnalysis
from archetypax.tools.evaluation import ArchetypalAnalysisEvaluator
from archetypax.tools.tracker import ArchetypeTracker
```

```python
# Subpackage imports
from archetypax.models import ArchetypalAnalysis
from archetypax.tools import ArchetypalAnalysisVisualizer, ArchetypeTracker
```

For a detailed list of changes and version history, please see the CHANGELOG.md file.
| Parameter | Type | Default | Description |
|---|---|---|---|
| n_archetypes | int | - | Number of archetypes to find |
| max_iter | int | 500 | Maximum number of iterations |
| tol | float | 1e-6 | Convergence tolerance |
| random_seed | int | 42 | Random seed for initialization |
| learning_rate | float | 0.001 | Learning rate for optimizer |
| lambda_reg | float | 0.01 | Regularization strength for weight distribution |
| normalize | bool | False | Whether to normalize features before fitting |
| projection_method | str | "cbap" | Method for projecting archetypes ("cbap", "convex_hull", "knn") |
| projection_alpha | float | 0.1 | Blending coefficient for boundary projection |
| archetype_init_method | str | "directional" | Initialization strategy ("directional", "kmeans++", "qhull") |
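One of the options for archetype_init_method is "kmeans++". For reference, standard k-means++ seeding is sketched below in NumPy: it picks the first seed uniformly at random and each later seed with probability proportional to its squared distance from the nearest seed already chosen. The function name kmeans_pp_init is hypothetical, and archetypax's own "kmeans++" initializer may differ in detail.

```python
import numpy as np

def kmeans_pp_init(X, k, seed=42):
    """Pick k initial rows of X with standard k-means++ seeding (illustrative)."""
    rng = np.random.default_rng(seed)
    n = X.shape[0]
    # First seed: a uniformly random data point.
    indices = [int(rng.integers(n))]
    # Squared distance from every point to its nearest chosen seed.
    d2 = np.sum((X - X[indices[0]]) ** 2, axis=1)
    for _ in range(1, k):
        # Sample the next seed with probability proportional to d2,
        # so points far from existing seeds are favoured.
        probs = d2 / d2.sum()
        idx = int(rng.choice(n, p=probs))
        indices.append(idx)
        d2 = np.minimum(d2, np.sum((X - X[idx]) ** 2, axis=1))
    indices = np.array(indices)
    return X[indices], indices

X = np.random.default_rng(0).random((200, 4))
init_archetypes, init_idx = kmeans_pp_init(X, k=5)
print(init_archetypes.shape)
```

Points already chosen have zero distance to a seed, so they cannot be drawn again, and the k seeds are distinct whenever the rows of X are.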
| Parameter | Type | Default | Description |
|---|---|---|---|
| n_row_archetypes | int | - | Number of archetypes in observation space |
| n_col_archetypes | int | - | Number of archetypes in feature space |
| max_iter | int | 500 | Maximum number of iterations |
| tol | float | 1e-6 | Convergence tolerance |
| random_seed | int | 42 | Random seed for initialization |
| learning_rate | float | 0.001 | Learning rate for optimizer |
| projection_method | str | "default" | Method for projecting archetypes |
| lambda_reg | float | 0.01 | Regularization strength for entropy terms |
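The lambda_reg parameter weights an entropy term on the archetype weights. The exact term archetypax uses is not spelled out here, so the NumPy sketch below assumes a simple form: reconstruction MSE plus lambda_reg times the mean Shannon entropy of the weight rows. With the penalty added this way, minimizing the loss favours weights concentrated on fewer archetypes. Both function names are hypothetical.

```python
import numpy as np

def weight_entropy(W, eps=1e-12):
    """Mean Shannon entropy (in nats) of the rows of a weight matrix."""
    W = np.clip(W, eps, 1.0)  # keep log() finite for zero weights
    return float(np.mean(-np.sum(W * np.log(W), axis=1)))

def regularized_loss(X, W, A, lambda_reg=0.01):
    """Assumed objective: reconstruction MSE + lambda_reg * mean weight entropy."""
    mse = float(np.mean((X - W @ A) ** 2))
    return mse + lambda_reg * weight_entropy(W)

k = 3
uniform = np.full((4, k), 1.0 / k)   # every sample spread evenly over all archetypes
one_hot = np.eye(k)[[0, 1, 2, 0]]    # every sample assigned to a single archetype

A = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])  # toy archetypes
X = one_hot @ A  # data that the one-hot weights reconstruct exactly

print(weight_entropy(uniform), weight_entropy(one_hot))
print(regularized_loss(X, one_hot, A), regularized_loss(X, uniform, A))
```

Uniform weights reach the maximum entropy log(k), while one-hot weights have entropy close to zero.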
| Method | Returns | Description |
|---|---|---|
| fit(X) | model | Fit the model to the data |
| transform(X) | array | Transform new data to archetype weights |
| fit_transform(X) | array | Fit the model and transform the data |
| reconstruct(X) | array | Reconstruct data from archetype weights |
| get_loss_history() | array | Get the loss history from training |
| get_all_archetypes() | tuple | Get both sets of archetypes (BiarchetypalAnalysis only) |
| get_all_weights() | tuple | Get both sets of weights (BiarchetypalAnalysis only) |
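transform(X) maps new data to archetype weights while the archetypes stay fixed. One way to compute such weights is projected gradient descent. Each iteration takes a gradient step on the squared reconstruction error, then projects every row of W back onto the probability simplex using the standard sort-based projection. The NumPy sketch below works this way. It is an assumption for illustration, not archetypax's implementation (the library lists JAX and optax as its core dependencies), and the function names are hypothetical.

```python
import numpy as np

def project_rows_to_simplex(V):
    """Euclidean projection of each row of V onto the probability simplex."""
    n, k = V.shape
    U = -np.sort(-V, axis=1)                  # rows sorted in descending order
    css = np.cumsum(U, axis=1) - 1.0          # cumulative sums minus target sum 1
    ind = np.arange(1, k + 1)
    cond = U - css / ind > 0                  # entries that remain positive
    rho = k - 1 - np.argmax(cond[:, ::-1], axis=1)  # last positive index per row
    theta = css[np.arange(n), rho] / (rho + 1)      # per-row shift
    return np.maximum(V - theta[:, None], 0.0)

def transform_weights(X, A, n_iter=500):
    """Row-stochastic W minimizing ||X - W A||_F^2 with archetypes A held fixed."""
    n, k = X.shape[0], A.shape[0]
    W = np.full((n, k), 1.0 / k)              # start from uniform weights
    # Step size 1/L, where L = 2 * largest eigenvalue of A A^T is the
    # Lipschitz constant of the gradient below.
    step = 1.0 / (2.0 * np.linalg.norm(A @ A.T, 2))
    for _ in range(n_iter):
        grad = 2.0 * (W @ A - X) @ A.T        # gradient of the squared error
        W = project_rows_to_simplex(W - step * grad)
    return W

A = np.array([[0.0, 0.0], [1.0, 0.0], [0.5, 0.866]])  # fixed toy archetypes
rng = np.random.default_rng(0)
W_true = rng.dirichlet(np.ones(3), size=50)           # known mixing weights
X_new = W_true @ A                                    # points inside the hull
W = transform_weights(X_new, A)
print("max weight error:", np.abs(W - W_true).max())
```

Because these points lie inside the triangle, their weights are unique, so the recovered W should match W_true closely. The projection step is what keeps every iterate non-negative with rows summing to 1.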
```python
import numpy as np
import matplotlib.pyplot as plt
from archetypax import ImprovedArchetypalAnalysis
from archetypax.tools.visualization import ArchetypalAnalysisVisualizer

# Generate some interesting 2D data (a triangle with points inside)
n_samples = 500
vertices = np.array([[0, 0], [1, 0], [0.5, 0.866]])
weights = np.random.dirichlet(np.ones(3), size=n_samples)
X = weights @ vertices

# Fit archetypal analysis with 3 archetypes
model = ImprovedArchetypalAnalysis(n_archetypes=3, archetype_init_method="directional")
model.fit(X)

# Plot original data and archetypes
plt.figure(figsize=(10, 8))
ArchetypalAnalysisVisualizer.plot_archetypes_2d(model, X)
plt.title("Archetypal Analysis of 2D Data")
plt.show()
```

```python
import numpy as np
import matplotlib.pyplot as plt
from archetypax import BiarchetypalAnalysis
from archetypax.tools.visualization import ArchetypalAnalysisVisualizer

# Generate synthetic data
np.random.seed(42)
X = np.random.rand(500, 5)

# Initialize and fit the model with row and column archetypes
model = BiarchetypalAnalysis(
    n_row_archetypes=2,  # Number of archetypes in observation space
    n_col_archetypes=2,  # Number of archetypes in feature space
    max_iter=500,
    random_seed=42
)
model.fit(X)

# Get both sets of archetypes
row_archetypes, col_archetypes = model.get_all_archetypes()
print("Row archetypes shape:", row_archetypes.shape)
print("Column archetypes shape:", col_archetypes.shape)

# Get both sets of weights
row_weights, col_weights = model.get_all_weights()
print("Row weights shape:", row_weights.shape)
print("Column weights shape:", col_weights.shape)

# Reconstruct data using biarchetypes
X_reconstructed = model.reconstruct()
mse = np.mean((X - X_reconstructed) ** 2)
print(f"Reconstruction MSE: {mse:.6f}")
```

```python
import numpy as np
import matplotlib.pyplot as plt
from archetypax import ArchetypeTracker

# Generate sample data
np.random.seed(42)
X = np.random.rand(1000, 10)

# Initialize the tracker
tracker = ArchetypeTracker(
    n_archetypes=3,
    max_iter=300,
    random_seed=42
)

# Fit the model while tracking archetype movement
tracker.fit(X)

# Visualize the archetype movement trajectory
tracker.visualize_movement()

# Visualize boundary proximity over iterations
tracker.visualize_boundary_proximity()
```

Archetypal Analysis solves the following optimization problem:
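Given a data matrix $\mathbf{X} \in \mathbb{R}^{n \times d}$, find archetypes $\mathbf{A} \in \mathbb{R}^{k \times d}$ and weights $\mathbf{W} \in \mathbb{R}^{n \times k}$ that solve

$$
\min_{\mathbf{W},\,\mathbf{A}} \; \lVert \mathbf{X} - \mathbf{W}\mathbf{A} \rVert_F^2
$$

subject to:

- $\mathbf{W}$ is non-negative
- Each row of $\mathbf{W}$ sums to 1 (simplex constraint)
- $\mathbf{A}$ lies within the convex hull of $\mathbf{X}$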
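The archetypal analysis objective is the squared reconstruction error ||X − WA||_F^2 under these constraints, and it can be evaluated directly. The NumPy sketch below enforces the convex-hull constraint the way the original archetypal analysis formulation does: it writes A = BX with a row-stochastic matrix B, so every archetype is a convex combination of data points. The matrix B and the helper names are illustrative and not part of the archetypax API.

```python
import numpy as np

def aa_objective(X, W, B):
    """Return ||X - W A||_F^2 and the archetypes A = B X.

    Building A as B @ X with a row-stochastic B makes each archetype a convex
    combination of data points, so A lies in the convex hull of X.
    """
    A = B @ X
    return float(np.sum((X - W @ A) ** 2)), A

def is_row_stochastic(M, tol=1e-8):
    """True if M is non-negative and every row sums to 1."""
    return bool(np.all(M >= -tol) and np.allclose(M.sum(axis=1), 1.0, atol=tol))

rng = np.random.default_rng(0)
n, d, k = 100, 4, 3
X = rng.random((n, d))
W = rng.dirichlet(np.ones(k), size=n)   # n x k weights, one simplex row per sample
B = rng.dirichlet(np.ones(n), size=k)   # k x n, each archetype mixes data points
loss, A = aa_objective(X, W, B)
print(f"loss = {loss:.4f}, W ok: {is_row_stochastic(W)}, B ok: {is_row_stochastic(B)}")
```

A quick consequence of the hull constraint: every coordinate of every archetype stays between that feature's minimum and maximum over the data.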
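The biarchetypal extension (BiarchetypalAnalysis) solves a more complex factorization that learns archetypes along both axes of the data: row archetypes in observation space and column archetypes in feature space.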
This implementation uses JAX's automatic differentiation and optimization tools to solve these problems efficiently on GPUs. It also incorporates several enhancements:
- Initialization strategies, including directional initialization, k-means++ style seeding, and convex hull approximation
- Regularization (lambda_reg) that promotes interpretable weight distributions
- Projection methods, including adaptive convex boundary approximation (CBAP)
- Numerical stability safeguards throughout the optimization process
- Trajectory tracking (ArchetypeTracker) for monitoring convergence dynamics
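When optimizing with plain gradient-based tools such as optax, a common way to satisfy the simplex constraint on weights is to optimize unconstrained logits and pass them through a softmax. This sketch is an assumption: the text does not say whether archetypax uses this parameterization, projection, or both. It also shows a typical numerical-stability safeguard, which is to subtract each row's maximum before exponentiating.

```python
import numpy as np

def stable_softmax(logits, axis=-1):
    """Row-wise softmax computed stably by subtracting the row maximum first."""
    shifted = logits - np.max(logits, axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / np.sum(exp, axis=axis, keepdims=True)

# Unconstrained logits can be updated by any gradient-based optimizer;
# the softmax maps them to weights that are non-negative and sum to 1.
logits = np.array([[1000.0, 999.0, 998.0],   # naive np.exp() would overflow here
                   [0.0, 0.0, 0.0]])
W = stable_softmax(logits)
print(W.sum(axis=1))
```

Subtracting the maximum leaves the softmax output mathematically unchanged while keeping every exponent at or below zero, so large logits no longer produce inf or nan.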
Contributions are welcome and highly encouraged! Before submitting a pull request, please review the following resources:
- Code of Conduct: Guidelines for community participation
- Security Policy: Vulnerability reporting and handling procedures
To contribute to the project:
- Fork the repository
- Create a feature branch (`git checkout -b feature/amazing-feature`)
- Commit your changes (`git commit -m 'Add some amazing feature'`)
- Push to the branch (`git push origin feature/amazing-feature`)
- Open a Pull Request
- Issues: Report bugs and request features
- Discussions: Questions and general community interactions
If you use this package in your research, please cite:
```bibtex
@software{archetypax2025,
  author = {mary},
  title = {archetypax: GPU-accelerated Archetypal Analysis using JAX},
  year = {2025},
  url = {https://github.com/lv416e/archetypax}
}
```
This project is licensed under the Apache License 2.0 - see the LICENSE file for details.