OpenArchX

Transformable numerical computing at scale

Build Status Python 3.11+ License: Apache 2.0

Transformations | Scaling | Install guide | Change logs | Reference docs


What is OpenArchX?

OpenArchX is a Python library for accelerator-oriented array computation and program transformation, designed for high-performance numerical computing and large-scale machine learning.

OpenArchX can automatically differentiate native Python and NumPy functions. It can differentiate through loops, branches, recursion, and closures, and it can take derivatives of derivatives of derivatives. It supports reverse-mode differentiation (a.k.a. backpropagation) via openarchx.grad as well as forward-mode differentiation, and the two can be composed arbitrarily to any order.

OpenArchX uses XLA to compile and scale your NumPy programs on TPUs, GPUs, and other hardware accelerators. You can compile your own pure functions with openarchx.jit. Compilation and automatic differentiation can be composed arbitrarily.

OpenArchX is an extensible system for composable function transformations at scale.

Note: The current version of OpenArchX is architecturally similar to JAX, with targeted modifications for advanced neural models. TNSA will actively contribute to expanding the feature set, optimizing performance, and introducing novel architectures that set OpenArchX apart as a premier platform for open research and large-scale AI.

This is a research project. Expect sharp edges. Please help by trying it out, reporting bugs, and letting us know what you think!

import openarchx as ax
import openarchx.numpy as anp

def predict(params, inputs):
  for W, b in params:
    outputs = anp.dot(inputs, W) + b
    inputs = anp.tanh(outputs)  # inputs to the next layer
  return outputs                # no activation on last layer

def loss(params, inputs, targets):
  preds = predict(params, inputs)
  return anp.sum((preds - targets)**2)

grad_loss = ax.jit(ax.grad(loss))  # compiled gradient evaluation function
perex_grads = ax.jit(ax.vmap(grad_loss, in_axes=(None, 0, 0)))  # fast per-example grads
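
For concreteness, here is a minimal sketch of how these functions might be called. The layer sizes, the plain-NumPy random initialization, and the batch size of 32 are illustrative assumptions, not part of the API above:

import numpy as np

# hypothetical 2-layer MLP parameters and a batch of 32 examples (shapes are illustrative)
layer_sizes = [(3, 8), (8, 1)]
params = [(np.random.randn(m, n), np.random.randn(n)) for m, n in layer_sizes]
inputs = np.random.randn(32, 3)
targets = np.random.randn(32, 1)

batch_grads = grad_loss(params, inputs, targets)    # gradients of the summed loss over the batch
per_example = perex_grads(params, inputs, targets)  # same pytree, each leaf gains a leading axis of 32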

Contents

  • Transformations
  • Scaling
  • New Architectures
  • GPU Optimizations
  • Installation
  • Citing OpenArchX
  • Reference documentation


Transformations

At its core, OpenArchX is an extensible system for transforming numerical functions. Here are three core transformations: openarchx.grad, openarchx.jit, and openarchx.vmap.

Automatic differentiation with grad

Use openarchx.grad to efficiently compute reverse-mode gradients:

import openarchx as ax
import openarchx.numpy as anp

def tanh(x):
  y = anp.exp(-2.0 * x)
  return (1.0 - y) / (1.0 + y)

grad_tanh = ax.grad(tanh)
print(grad_tanh(1.0))
# prints 0.4199743

You can differentiate to any order with grad:

print(ax.grad(ax.grad(ax.grad(tanh)))(1.0))
# prints 0.62162673

Compilation with jit

Use XLA to compile your functions end-to-end with jit, used either as an @jit decorator or as a higher-order function.

import openarchx as ax
import openarchx.numpy as anp

def slow_f(x):
  # Element-wise ops see a large benefit from fusion
  return x * x + x * 2.0

x = anp.ones((5000, 5000))
fast_f = ax.jit(slow_f)
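
The compiled function is called exactly like the original. A minimal sketch: the first call traces and compiles the function with XLA, and later calls reuse the compiled executable.

slow_f(x)  # runs the operations one by one
fast_f(x)  # first call compiles and fuses the element-wise ops; subsequent calls are fast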

Auto-vectorization with vmap

vmap maps a function along array axes. Instead of looping, it pushes the loop down onto the function’s primitive operations (e.g., turning matrix-vector multiplies into matrix-matrix multiplies).

import openarchx as ax
import openarchx.numpy as anp

def l1_distance(x, y):
  assert x.ndim == y.ndim == 1  # only works on 1D inputs
  return anp.sum(anp.abs(x - y))

def pairwise_distances(dist1D, xs):
  return ax.vmap(ax.vmap(dist1D, (0, None)), (None, 0))(xs, xs)

xs = ax.random.normal(ax.random.key(0), (100, 3))
dists = pairwise_distances(l1_distance, xs)
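Here dists is a (100, 100) array of pairwise L1 distances, computed without any Python-level loops.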

Scaling

To scale computations across thousands of devices, you can use any composition of these:

Mode       View         Explicit sharding?   Explicit collectives?
Auto       Global       No                   No
Explicit   Global       Yes                  No
Manual     Per-device   Yes                  Yes

import openarchx as ax
from openarchx.sharding import set_mesh, AxisType, PartitionSpec as P

mesh = ax.make_mesh((8,), ('data',), axis_types=(AxisType.Explicit,))
set_mesh(mesh)

# shard data for batch parallelism:
inputs, targets = ax.device_put((inputs, targets), P('data'))

# evaluate gradients, automatically parallelized!
gradfun = ax.jit(ax.grad(loss))
param_grads = gradfun(params, inputs, targets)

New Architectures

OpenArchX provides integrated support for state-of-the-art neural architectures.

Mamba (State Space Models)

Optimized selective scan implementation for linear-time sequence modeling.

from openarchx.nn import Mamba, MambaConfig

config = MambaConfig(d_model=256, d_state=16)
model = Mamba(config)

Flash Attention

Native cross-platform memory-efficient attention (O(N) memory complexity).

from openarchx.nn import dot_product_attention

output = dot_product_attention(q, k, v, use_flash=True)
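
For concreteness, a minimal sketch of calling it. The [batch, seq_len, num_heads, head_dim] layout is an assumption based on common attention conventions, not a documented contract:

import openarchx as ax
from openarchx.nn import dot_product_attention

# assumed layout: [batch, seq_len, num_heads, head_dim]
q = ax.random.normal(ax.random.key(0), (2, 1024, 8, 64))
k = ax.random.normal(ax.random.key(1), (2, 1024, 8, 64))
v = ax.random.normal(ax.random.key(2), (2, 1024, 8, 64))

output = dot_product_attention(q, k, v, use_flash=True)  # expected to match q's shape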

Continuous-Time Neural Networks

Architectures for irregularly sampled time series and autonomous control (a usage sketch follows the list):

  • LTC (Liquid Time-Constant): Recurrent models with ODE-based synaptic dynamics.
  • CfC (Closed-form Continuous): Liquid ODE properties with closed-form solver efficiency.
  • NCP (Neural Circuit Policy): Sparse, biologically-inspired networks for explainable control.
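
A minimal usage sketch, assuming these models follow the same config/module pattern as the Mamba example above. The names CfC and CfCConfig and their fields are hypothetical and may differ from the actual openarchx.nn API:

from openarchx.nn import CfC, CfCConfig  # hypothetical names, mirroring the Mamba pattern

config = CfCConfig(d_model=128)  # assumed field, by analogy with MambaConfig
model = CfC(config)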

GPU Optimizations

The latest updates bring significant GPU performance benefits:

  • Memory Pooling: Pre-allocation and managed memory pools to reduce latency.
  • Precision Contexts: Specialized support for FP8, BF16, and stable arithmetic paths.

from openarchx.gpu import set_memory_preallocation, PrecisionContext

set_memory_preallocation(fraction=0.8)

with PrecisionContext(precision="mixed"):
    # High-performance computation here
    pass

Installation

Supported platforms

Platform     Linux x86_64   Linux aarch64   Mac aarch64   Windows x86_64   Windows WSL2
CPU          Yes            Yes             Yes           Yes              Yes
NVIDIA GPU   Yes            Yes             n/a           Yes              Yes
Google TPU   Yes            n/a             n/a           n/a              n/a

Instructions

Platform     Instructions
CPU          pip install -U openarchx
NVIDIA GPU   pip install -U "openarchx[cuda]"
Google TPU   pip install -U "openarchx[tpu]"
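
After installing, a quick sanity check using only APIs shown above:

python -c "import openarchx as ax; import openarchx.numpy as anp; print(anp.ones(3))"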

OpenArchX is an independent project derived from the Apache-2.0 licensed JAX framework. It is not affiliated with or endorsed by Google.


Citing OpenArchX

To cite this repository:

@software{openarchx2026,
  author = {OpenArchX Contributors},
  title = {OpenArchX: Composable transformations of Python+NumPy programs},
  url = {http://github.com/tnsaai/openarchx},
  version = {1.0.0},
  year = {2026},
}

Reference documentation

For details about the OpenArchX API, see the reference documentation.

For getting started as an OpenArchX developer, see the developer documentation.