OpenArchX is a Python library for accelerator-oriented array computation and program transformation, designed for high-performance numerical computing and large-scale machine learning.
OpenArchX can automatically differentiate native Python and NumPy functions. It can differentiate through loops, branches, recursion, and closures, and it can take derivatives of derivatives of derivatives. It supports reverse-mode differentiation (a.k.a. backpropagation) via openarchx.grad as well as forward-mode differentiation, and the two can be composed arbitrarily to any order.
OpenArchX uses XLA to compile and scale your NumPy programs on TPUs, GPUs, and other hardware accelerators. You can compile your own pure functions with openarchx.jit. Compilation and automatic differentiation can be composed arbitrarily.
OpenArchX is an extensible system for composable function transformations at scale.
Note: The current version of OpenArchX is architecturally similar to JAX, with targeted modifications for advanced neural models. TNSA will actively contribute to expanding the feature set, optimizing performance, and introducing novel architectures, with the goal of making OpenArchX a premier platform for open research and large-scale AI.
This is a research project. Expect sharp edges. Please help by trying it out, reporting bugs, and letting us know what you think!
For example:

```python
import openarchx as ax
import openarchx.numpy as anp

def predict(params, inputs):
    for W, b in params:
        outputs = anp.dot(inputs, W) + b
        inputs = anp.tanh(outputs)  # inputs to the next layer
    return outputs                  # no activation on last layer

def loss(params, inputs, targets):
    preds = predict(params, inputs)
    return anp.sum((preds - targets)**2)

grad_loss = ax.jit(ax.grad(loss))  # compiled gradient evaluation function
perex_grads = ax.jit(ax.vmap(grad_loss, in_axes=(None, 0, 0)))  # fast per-example grads
```

Contents:

- Transformations
- Scaling
- New Architectures
- GPU Optimizations
- Installation
- Citing OpenArchX
- Reference documentation
At its core, OpenArchX is an extensible system for transforming numerical functions. Here are three core transformations: openarchx.grad, openarchx.jit, and openarchx.vmap.
Use openarchx.grad to efficiently compute reverse-mode gradients:
```python
import openarchx as ax
import openarchx.numpy as anp

def tanh(x):
    y = anp.exp(-2.0 * x)
    return (1.0 - y) / (1.0 + y)

grad_tanh = ax.grad(tanh)
print(grad_tanh(1.0))
# prints 0.4199743
```

You can differentiate to any order with grad:
```python
print(ax.grad(ax.grad(ax.grad(tanh)))(1.0))
# prints 0.62162673
```

Use XLA to compile your functions end-to-end with jit, used either as an @jit decorator or as a higher-order function.
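The decorator form looks like the following sketch; it uses only the ax.jit and openarchx.numpy operations shown elsewhere in this README, and the particular function being compiled is illustrative:

```python
import openarchx as ax
import openarchx.numpy as anp

@ax.jit  # compile the decorated function end-to-end with XLA
def sum_of_squares(x):
    return anp.sum(x * x)

x = anp.ones((1000, 1000))
print(sum_of_squares(x))  # runs the compiled version
```

The higher-order form compiles an existing function after the fact: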
```python
import openarchx as ax
import openarchx.numpy as anp

def slow_f(x):
    # Element-wise ops see a large benefit from fusion
    return x * x + x * 2.0

x = anp.ones((5000, 5000))
fast_f = ax.jit(slow_f)
```

vmap maps a function along array axes. Instead of looping, it pushes the loop down onto the function’s primitive operations (e.g., turning matrix-vector multiplies into matrix-matrix multiplies).
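As a small illustration of that claim, here is a sketch that batches a matrix-vector product with vmap; it uses only the ax.vmap, anp.dot, and anp.ones calls shown above, and assumes vmap's default behavior of mapping over the leading axis:

```python
import openarchx as ax
import openarchx.numpy as anp

W = anp.ones((150, 100))   # fixed weight matrix
xs = anp.ones((10, 100))   # a batch of 10 input vectors

def apply_matrix(x):
    return anp.dot(W, x)   # a single matrix-vector product

# vmap maps apply_matrix over the leading axis of xs, so the ten
# matrix-vector products are carried out as one matrix-matrix product.
batched = ax.vmap(apply_matrix)(xs)
print(batched.shape)  # (10, 150)
```

vmap also composes with itself, which the nested pairwise-distance example below relies on: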
```python
import openarchx as ax
import openarchx.numpy as anp

def l1_distance(x, y):
    assert x.ndim == y.ndim == 1  # only works on 1D inputs
    return anp.sum(anp.abs(x - y))

def pairwise_distances(dist1D, xs):
    return ax.vmap(ax.vmap(dist1D, (0, None)), (None, 0))(xs, xs)

xs = ax.random.normal(ax.random.key(0), (100, 3))
dists = pairwise_distances(l1_distance, xs)
```

To scale computations across thousands of devices, you can use any composition of these:
| Mode | View | Explicit sharding? | Explicit Collectives? |
|---|---|---|---|
| Auto | Global | No | No |
| Explicit | Global | Yes | No |
| Manual | Per-device | Yes | Yes |
```python
import openarchx as ax
from openarchx.sharding import set_mesh, AxisType, PartitionSpec as P

mesh = ax.make_mesh((8,), ('data',), axis_types=(AxisType.Explicit,))
set_mesh(mesh)

# shard data for batch parallelism:
inputs, targets = ax.device_put((inputs, targets), P('data'))

# evaluate gradients, automatically parallelized!
gradfun = ax.jit(ax.grad(loss))
param_grads = gradfun(params, inputs, targets)
```

OpenArchX provides integrated support for state-of-the-art neural architectures.
Mamba: an optimized selective-scan implementation for linear-time sequence modeling.

```python
from openarchx.nn import Mamba, MambaConfig

config = MambaConfig(d_model=256, d_state=16)
model = Mamba(config)
```
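A forward pass might look like the sketch below; the direct call convention and the (batch, sequence length, d_model) input layout are assumptions for illustration, not documented behavior:

```python
import openarchx as ax
from openarchx.nn import Mamba, MambaConfig

config = MambaConfig(d_model=256, d_state=16)
model = Mamba(config)

# Assumed usage: apply the model to a batch of sequences whose feature
# dimension matches d_model; output shape is assumed to match the input.
x = ax.random.normal(ax.random.key(0), (4, 1024, 256))
y = model(x)
```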
Flash attention: native, cross-platform, memory-efficient attention (O(N) memory complexity).

```python
from openarchx.nn import dot_product_attention

output = dot_product_attention(q, k, v, use_flash=True)
```
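A self-contained call could look like the following sketch; the (batch, sequence length, num_heads, head_dim) layout for q, k, and v mirrors jax.nn.dot_product_attention and is an assumption here:

```python
import openarchx as ax
from openarchx.nn import dot_product_attention

# Dummy inputs; the (batch, seq_len, num_heads, head_dim) layout is assumed.
key = ax.random.key(0)
q = ax.random.normal(key, (2, 1024, 8, 64))
k = ax.random.normal(key, (2, 1024, 8, 64))
v = ax.random.normal(key, (2, 1024, 8, 64))

out = dot_product_attention(q, k, v, use_flash=True)
print(out.shape)  # (2, 1024, 8, 64)
```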
Architectures for irregular time-series and autonomous control:

- LTC (Liquid Time-Constant): Recurrent models with ODE-based synaptic dynamics.
- CfC (Closed-form Continuous): Liquid ODE properties with closed-form solver efficiency.
- NCP (Neural Circuit Policy): Sparse, biologically-inspired networks for explainable control.
The latest updates bring significant GPU performance benefits:
- Memory Pooling: Pre-allocation and managed memory pools to reduce latency.
- Precision Contexts: Specialized support for FP8, BF16, and stable arithmetic paths.
```python
from openarchx.gpu import set_memory_preallocation, PrecisionContext

set_memory_preallocation(fraction=0.8)

with PrecisionContext(precision="mixed"):
    # High-performance computation here
    pass
```

Supported platforms:

| Platform | Linux x86_64 | Linux aarch64 | Mac aarch64 | Windows x86_64 | Windows WSL2 |
|---|---|---|---|---|---|
| CPU | Yes | Yes | Yes | Yes | Yes |
| NVIDIA GPU | Yes | Yes | n/a | Yes | Yes |
| Google TPU | Yes | n/a | n/a | n/a | n/a |
Installation:

| Platform | Instructions |
|---|---|
| CPU | pip install -U openarchx |
| NVIDIA GPU | pip install -U "openarchx[cuda]" |
| Google TPU | pip install -U "openarchx[tpu]" |
OpenArchX is an independent project derived from the Apache-2.0 licensed JAX framework. It is not affiliated with or endorsed by Google.
To cite this repository:
```bibtex
@software{openarchx2026,
  author = {OpenArchX Contributors},
  title = {OpenArchX: Composable transformations of Python+NumPy programs},
  url = {http://github.com/tnsaai/openarchx},
  version = {1.0.0},
  year = {2026},
}
```

For details about the OpenArchX API, see the reference documentation.
For getting started as an OpenArchX developer, see the developer documentation.
