Commit 2dd1684

Cross-Platform Torch Save / Load Support, Doc Updates, Release Prep (#141)

* Torch save / load implemented.
* Linted.
* Updated documentation and changelog.
* Updated documentation.

1 parent 3fb5703, commit 2dd1684

File tree: 6 files changed, +125 / -2 lines

CHANGELOG.md

Lines changed: 54 additions & 0 deletions

@@ -0,0 +1,54 @@

## Latest Changes

### v0.3.0 (2025-06-22)

This release includes bugfixes and new opaque operations that compose with `torch.compile` for PT2.4-2.7. These will be unnecessary for PT2.8+.

**Added**:

1. Opaque variants of major operations via PyTorch `custom_op` declarations. These functions cannot be traced through and fail for JITScript / AOTI. They are shims that enable composition with `torch.compile` pre-PT2.8.
2. `torch.load` / `torch.save` functionality that, without `torch.compile`, is portable across GPU architectures.
3. `.to()` support to move `TensorProduct` and `TensorProductConv` between devices or change datatypes.

**Fixed**:

1. Gracefully records an error if `libpython.so` is not linked against the C++ extension.
2. Resolves Kahan summation and various other bugs for HIP at the O3 compiler-optimization level.
3. Stops multiple contexts from being spawned on GPU 0 when multiple devices are used.
4. Zero-initializes gradient buffers to prevent garbage accumulation in the backward pass.

### v0.2.0 (2025-06-09)

Our first stable release, **v0.2.0**, introduces several new features. Highlights include:

1. Full HIP support for all kernels.
2. Support for `torch.compile`, JITScript, and export; preliminary support for AOTI.
3. Faster double-backward performance for training.
4. Ability to install versioned releases from PyPI.
5. Support for CUDA streams and multiple devices.
6. An extensive test suite and newly released [documentation](https://passionlab.github.io/OpenEquivariance/).

If you successfully run OpenEquivariance on a GPU model not listed [here](https://passionlab.github.io/OpenEquivariance/tests_and_benchmarks/), let us know! We can add your name to the list.

---

**Known issues:**

- Kahan summation is broken on HIP; a fix is planned.
- FX + Export + Compile has trouble with PyTorch Dynamo; a fix is planned.
- AOTI is broken on PT < 2.8; you need a nightly build due to incomplete support for TorchBind in prior versions.

### v0.1.0 (2025-01-23)

Initial GitHub release with preprint.

docs/supported_ops.rst

Lines changed: 20 additions & 1 deletion

@@ -55,6 +55,26 @@ We do not (yet) support:

 If you have a use case for any of the unsupported features above, let us know.

+Torch Save / Load
+---------------------------------------------------
+OpenEquivariance's ``TensorProduct`` / ``TensorProductConv`` modules
+can be saved via ``torch.save`` and restored via ``torch.load``.
+You must call ``import openequivariance`` before attempting to load, i.e.
+
+.. code-block:: python
+
+   import torch
+   import openequivariance
+
+   module = torch.load("my_module_with_tp.pt")
+
+If you do NOT use ``torch.compile`` or ``torch.export``, these modules
+can be loaded on a platform with a GPU architecture distinct from that of
+the saving platform. In this case, kernels are recompiled dynamically. After
+compilation, a module may only be used on a platform with a GPU architecture
+identical to that of the machine that saved it.
+
 Compilation with JITScript, Export, and AOTInductor
 ---------------------------------------------------

@@ -72,7 +92,6 @@ unless you are using a Nightly
 build of PyTorch past 4/10/2025 due to incomplete support for
 TorchBind in earlier versions.

-
 Multiple Devices and Streams
 ----------------------------
 OpenEquivariance compiles kernels based on the compute capability of the

openequivariance/__init__.py

Lines changed: 23 additions & 1 deletion

@@ -1,5 +1,7 @@
 # ruff: noqa: F401
 import sys
+import torch
+import numpy as np

 try:
     import openequivariance.extlib

@@ -8,7 +10,13 @@
 from pathlib import Path
 from importlib.metadata import version

-from openequivariance.implementations.e3nn_lite import TPProblem, Irreps
+from openequivariance.implementations.e3nn_lite import (
+    TPProblem,
+    Irrep,
+    Irreps,
+    _MulIr,
+    Instruction,
+)
 from openequivariance.implementations.TensorProduct import TensorProduct
 from openequivariance.implementations.convolution.TensorProductConv import (
     TensorProductConv,

@@ -41,6 +49,20 @@ def torch_ext_so_path():
     return openequivariance.extlib.torch_module.__file__


+torch.serialization.add_safe_globals(
+    [
+        TensorProduct,
+        TensorProductConv,
+        TPProblem,
+        Irrep,
+        Irreps,
+        _MulIr,
+        Instruction,
+        np.float32,
+        np.float64,
+    ]
+)
+
 LINKED_LIBPYTHON = openequivariance.extlib.LINKED_LIBPYTHON
 LINKED_LIBPYTHON_ERROR = openequivariance.extlib.LINKED_LIBPYTHON_ERROR
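The `add_safe_globals` registration above is what lets `torch.load` reconstruct these classes when PyTorch's `weights_only` loading mode (the default in recent releases) restricts unpickling to an allowlist. A minimal stdlib sketch of the same allowlisting idea, using `pickle.Unpickler.find_class` with a hypothetical `Point` class (no torch dependency, so this is an analogy rather than the library's actual mechanism):

```python
import io
import pickle


class Point:
    """Toy class standing in for an allowlisted type such as TPProblem."""

    def __init__(self, x, y):
        self.x, self.y = x, y


# Allowlist keyed by (module, name), the lookup pickle performs on load.
SAFE_GLOBALS = {(Point.__module__, "Point"): Point}


class AllowlistUnpickler(pickle.Unpickler):
    # Resolve only allowlisted classes and reject everything else,
    # mirroring the policy of torch.load with weights_only=True.
    def find_class(self, module, name):
        try:
            return SAFE_GLOBALS[(module, name)]
        except KeyError:
            raise pickle.UnpicklingError(f"{module}.{name} is not allowlisted")


payload = pickle.dumps(Point(1, 2))
restored = AllowlistUnpickler(io.BytesIO(payload)).load()
```

Any class not registered in the allowlist fails to load with an `UnpicklingError`, which is why the package registers its public types (and the numpy dtypes they reference) at import time.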

openequivariance/implementations/TensorProduct.py

Lines changed: 8 additions & 0 deletions

@@ -52,6 +52,14 @@ def to(self, *args, **kwargs):
         torch.nn.Module.to(self, *args, **kwargs)
         return self

+    def __getstate__(self):
+        return self.input_args
+
+    def __setstate__(self, state):
+        torch.nn.Module.__init__(self)
+        self.input_args = state
+        self._init_class()
+
     @staticmethod
     def name():
         return LoopUnrollTP.name()
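The `__getstate__` / `__setstate__` pair persists only the constructor arguments and rebuilds the compiled state on load, which is what makes the saved module portable across GPU architectures. A self-contained sketch of this pattern with stdlib `pickle`; the class name `ConfiguredModule` and the string "kernel" are hypothetical stand-ins, not OpenEquivariance APIs:

```python
import pickle


class ConfiguredModule:
    """Hypothetical stand-in for a module whose heavy state is rebuilt on load."""

    def __init__(self, irreps, dtype="float32"):
        self.input_args = (irreps, dtype)
        self._init_class()

    def _init_class(self):
        # Stand-in for expensive, non-picklable setup (e.g. JIT kernel
        # compilation); derived entirely from the saved constructor args.
        self.kernel = f"compiled<{self.input_args[0]}, {self.input_args[1]}>"

    def __getstate__(self):
        # Persist only the lightweight constructor arguments, not the kernel.
        return self.input_args

    def __setstate__(self, state):
        # Restore the arguments, then recompile for the current platform.
        self.input_args = state
        self._init_class()


m = pickle.loads(pickle.dumps(ConfiguredModule("32x1e", "float64")))
```

Because the pickle payload never contains the compiled artifact, the load side is free to regenerate it for whatever hardware it runs on, matching the cross-platform behavior described in the docs above.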

openequivariance/implementations/convolution/TensorProductConv.py

Lines changed: 8 additions & 0 deletions

@@ -84,6 +84,14 @@ def to(self, *args, **kwargs):
         torch.nn.Module.to(self, *args, **kwargs)
         return self

+    def __getstate__(self):
+        return self.input_args
+
+    def __setstate__(self, state):
+        torch.nn.Module.__init__(self)
+        self.input_args = state
+        self._init_class()
+
     def forward(
         self,
         X: torch.Tensor,

tests/export_test.py

Lines changed: 12 additions & 0 deletions

@@ -86,6 +86,18 @@ def tp_and_inputs(request, problem_and_irreps):
     )


+def test_torch_load(tp_and_inputs):
+    tp, inputs = tp_and_inputs
+    original_result = tp.forward(*inputs)
+
+    with tempfile.NamedTemporaryFile(suffix=".pt") as tmp_file:
+        torch.save(tp, tmp_file.name)
+        loaded_tp = torch.load(tmp_file.name)
+
+    reloaded_result = loaded_tp(*inputs)
+    assert torch.allclose(original_result, reloaded_result, atol=1e-5)
+
+
 def test_jitscript(tp_and_inputs):
     tp, inputs = tp_and_inputs
     uncompiled_result = tp.forward(*inputs)
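The test above uses the save-to-temp-file, load-by-path round trip. The same pattern with stdlib `pickle` in place of `torch.save` / `torch.load` (a sketch only, so it runs without torch; the payload dict is invented for illustration):

```python
import pickle
import tempfile

payload = {"weights": [0.1, 0.2, 0.3], "version": "0.3.0"}

# Save to a named temp file, then load it back by path, mirroring the
# round trip in test_torch_load. Note that reopening a NamedTemporaryFile
# by name while it is still open is POSIX-only behavior.
with tempfile.NamedTemporaryFile(suffix=".pkl") as tmp_file:
    with open(tmp_file.name, "wb") as f:
        pickle.dump(payload, f)
    with open(tmp_file.name, "rb") as f:
        reloaded = pickle.load(f)
```

The context manager deletes the file on exit, so the test leaves no artifacts behind while still exercising a real on-disk round trip.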
