From 65acce3fffb98deddb11e4dd181747e52c0791ca Mon Sep 17 00:00:00 2001 From: Austin Glover Date: Wed, 29 Oct 2025 16:19:02 -0700 Subject: [PATCH 1/9] add setuptools to cover for torch cpp extension --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index f07fc32..6f7b877 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,6 +14,7 @@ authors = [ description = "A fast GPU JIT kernel generator for the Clebsch-Gordon Tensor Product" requires-python = ">=3.10" dependencies = [ + "setuptools", "ninja", "jinja2", "numpy", From 739706994686d5e9f386ad1df0cec8ec1327ffc6 Mon Sep 17 00:00:00 2001 From: Austin Glover Date: Wed, 29 Oct 2025 16:24:28 -0700 Subject: [PATCH 2/9] only import the extlib if the cuda is available --- openequivariance/__init__.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/openequivariance/__init__.py b/openequivariance/__init__.py index 9fb67d0..2a76e5c 100644 --- a/openequivariance/__init__.py +++ b/openequivariance/__init__.py @@ -2,11 +2,11 @@ import sys import torch import numpy as np - -try: - import openequivariance.extlib -except Exception as e: - raise ImportError(f"Unable to load OpenEquivariance extension library:\n{e}") +if torch.cuda.is_available(): + try: + import openequivariance.extlib + except Exception as e: + raise ImportError(f"Unable to load OpenEquivariance extension library:\n{e}") from pathlib import Path from importlib.metadata import version From 7f545e87936b31ce722a7c2d45c4b581c2f0d419 Mon Sep 17 00:00:00 2001 From: asglover <140220574+asglover@users.noreply.github.com> Date: Thu, 30 Oct 2025 12:27:15 -0700 Subject: [PATCH 3/9] revert ineffective change --- openequivariance/__init__.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/openequivariance/__init__.py b/openequivariance/__init__.py index 2a76e5c..9fb67d0 100644 --- a/openequivariance/__init__.py +++ b/openequivariance/__init__.py @@ -2,11 +2,11 @@ import sys import torch import numpy as np -if torch.cuda.is_available(): - try: - import openequivariance.extlib - except Exception as e: - raise ImportError(f"Unable to load OpenEquivariance extension library:\n{e}") + +try: + import openequivariance.extlib +except Exception as e: + raise ImportError(f"Unable to load OpenEquivariance extension library:\n{e}") from pathlib import Path from importlib.metadata import version From 613ad75e9102224c4f4367d330f95d88a91f6c92 Mon Sep 17 00:00:00 2001 From: asglover <140220574+asglover@users.noreply.github.com> Date: Thu, 30 Oct 2025 12:27:47 -0700 Subject: [PATCH 4/9] add pathway for when cuda.is.available is false, stubs for better errors --- openequivariance/extlib/__init__.py | 58 ++++++++++++++++++++++------- 1 file changed, 45 insertions(+), 13 deletions(-) diff --git a/openequivariance/extlib/__init__.py b/openequivariance/extlib/__init__.py index b5e2537..ed5695b 100644 --- a/openequivariance/extlib/__init__.py +++ b/openequivariance/extlib/__init__.py @@ -5,12 +5,16 @@ import sysconfig from pathlib import Path +global torch +import torch + from openequivariance.benchmark.logging_utils import getLogger oeq_root = str(Path(__file__).parent.parent) build_ext = True TORCH_COMPILE = True +TORCH_CUDA_AVAILABLE = torch.cuda.is_available() torch_module, generic_module = None, None postprocess_kernel = lambda kernel: kernel # noqa : E731 @@ -38,12 +42,9 @@ import openequivariance.extlib.generic_module generic_module = openequivariance.extlib.generic_module -else: +elif TORCH_CUDA_AVAILABLE: from torch.utils.cpp_extension import library_paths, include_paths - global torch - import torch - extra_cflags = ["-O3"] generic_sources = ["generic_module.cpp"] torch_sources = ["libtorch_tp_jit.cpp"] @@ -128,13 +129,44 @@ def postprocess(kernel): "Could not compile integrated PyTorch wrapper. Falling back to Pybind11" + f", but JITScript, compile fullgraph, and export will fail.\n {torch_compile_exception}" ) +else: + TORCH_COMPILE = False + +def _raise_import_error_helper(import_target: str): + if not TORCH_CUDA_AVAILABLE: + raise ImportError( + f"Could not import {import_target}: OpenEquivariance's torch extension was not built because torch.cuda.is_available() is false" + ) + + +if TORCH_CUDA_AVAILABLE: + from generic_module import ( + JITTPImpl, + JITConvImpl, + GroupMM_F32, + GroupMM_F64, + DeviceProp, + DeviceBuffer, + GPUTimer, + ) +else: + def JITTPImpl(*args, **kwargs): + _raise_import_error_helper("JITTPImpl") + + def JITConvImpl(*args, **kwargs): + _raise_import_error_helper("JITConvImpl") + + def GroupMM_F32(*args, **kwargs): + _raise_import_error_helper("GroupMM_F32") + + def GroupMM_F64(*args, **kwargs): + _raise_import_error_helper("GroupMM_F64") + + def DeviceProp(*args, **kwargs): + _raise_import_error_helper("DeviceProp") + + def DeviceBuffer(*args, **kwargs): + _raise_import_error_helper("DeviceBuffer") -from generic_module import ( - JITTPImpl, - JITConvImpl, - GroupMM_F32, - GroupMM_F64, - DeviceProp, - DeviceBuffer, - GPUTimer, -) + def GPUTimer(*args, **kwargs): + _raise_import_error_helper("GPUTimer") From 6b87c69b5c818082dd835b1fabaf5f10383e1c1e Mon Sep 17 00:00:00 2001 From: asglover <140220574+asglover@users.noreply.github.com> Date: Fri, 31 Oct 2025 14:20:22 -0700 Subject: [PATCH 5/9] move example tests to another file so they don't get loaded during the import test --- tests/examples_test.py | 75 ++++++++++++++++++++++++++++++++++++++++ tests/import_test.py | 77 ------------------------------------------ 2 files changed, 75 insertions(+), 77 deletions(-) create mode 100644 tests/examples_test.py diff --git a/tests/examples_test.py b/tests/examples_test.py new file mode 100644 index 0000000..3beaabb --- /dev/null +++ b/tests/examples_test.py @@ -0,0 +1,75 @@ +def test_tutorial(): + import torch + import e3nn.o3 as o3 + + gen = torch.Generator(device="cuda") + + batch_size = 1000 + X_ir, Y_ir, Z_ir = o3.Irreps("1x2e"), o3.Irreps("1x3e"), o3.Irreps("1x2e") + X = torch.rand(batch_size, X_ir.dim, device="cuda", generator=gen) + Y = torch.rand(batch_size, Y_ir.dim, device="cuda", generator=gen) + + instructions = [(0, 0, 0, "uvu", True)] + + tp_e3nn = o3.TensorProduct( + X_ir, Y_ir, Z_ir, instructions, shared_weights=False, internal_weights=False + ).to("cuda") + W = torch.rand(batch_size, tp_e3nn.weight_numel, device="cuda", generator=gen) + + Z = tp_e3nn(X, Y, W) + print(torch.norm(Z)) + # =============================== + + # =============================== + import openequivariance as oeq + + problem = oeq.TPProblem( + X_ir, Y_ir, Z_ir, instructions, shared_weights=False, internal_weights=False + ) + tp_fast = oeq.TensorProduct(problem, torch_op=True) + + Z = tp_fast(X, Y, W) # Reuse X, Y, W from earlier + print(torch.norm(Z)) + # =============================== + + # Graph Convolution + # =============================== + from torch_geometric import EdgeIndex + + node_ct, nonzero_ct = 3, 4 + + # Receiver, sender indices for message passing GNN + edge_index = EdgeIndex( + [ + [0, 1, 1, 2], # Receiver + [1, 0, 2, 1], + ], # Sender + device="cuda", + dtype=torch.long, + ) + + X = torch.rand(node_ct, X_ir.dim, device="cuda", generator=gen) + Y = torch.rand(nonzero_ct, Y_ir.dim, device="cuda", generator=gen) + W = torch.rand(nonzero_ct, problem.weight_numel, device="cuda", generator=gen) + + tp_conv = oeq.TensorProductConv( + problem, torch_op=True, deterministic=False + ) # Reuse problem from earlier + Z = tp_conv.forward( + X, Y, W, edge_index[0], edge_index[1] + ) # Z has shape [node_ct, z_ir.dim] + print(torch.norm(Z)) + # =============================== + + # =============================== + _, sender_perm = edge_index.sort_by("col") # Sort by sender index + edge_index, receiver_perm = edge_index.sort_by("row") # Sort by receiver index + + # Now we can use the faster deterministic algorithm + tp_conv = oeq.TensorProductConv(problem, torch_op=True, deterministic=True) + Z = tp_conv.forward( + X, Y[receiver_perm], W[receiver_perm], edge_index[0], edge_index[1], sender_perm + ) + print(torch.norm(Z)) + # =============================== + assert True diff --git a/tests/import_test.py b/tests/import_test.py index a527783..3cfbb14 100644 --- a/tests/import_test.py +++ b/tests/import_test.py @@ -7,80 +7,3 @@ def test_import(): assert openequivariance.__version__ is not None assert openequivariance.__version__ != "0.0.0" assert openequivariance.__version__ == version("openequivariance") - - -def test_tutorial(): - import torch - import e3nn.o3 as o3 - - gen = torch.Generator(device="cuda") - - batch_size = 1000 - X_ir, Y_ir, Z_ir = o3.Irreps("1x2e"), o3.Irreps("1x3e"), o3.Irreps("1x2e") - X = torch.rand(batch_size, X_ir.dim, device="cuda", generator=gen) - Y = torch.rand(batch_size, Y_ir.dim, device="cuda", generator=gen) - - instructions = [(0, 0, 0, "uvu", True)] - - tp_e3nn = o3.TensorProduct( - X_ir, Y_ir, Z_ir, instructions, shared_weights=False, internal_weights=False - ).to("cuda") - W = torch.rand(batch_size, tp_e3nn.weight_numel, device="cuda", generator=gen) - - Z = tp_e3nn(X, Y, W) - print(torch.norm(Z)) - # =============================== - - # =============================== - import openequivariance as oeq - - problem = oeq.TPProblem( - X_ir, Y_ir, Z_ir, instructions, shared_weights=False, internal_weights=False - ) - tp_fast = oeq.TensorProduct(problem, torch_op=True) - - Z = tp_fast(X, Y, W) # Reuse X, Y, W from earlier - print(torch.norm(Z)) - # =============================== - - # Graph Convolution - # =============================== - from torch_geometric import EdgeIndex - - node_ct, nonzero_ct = 3, 4 - - # Receiver, sender indices for message passing GNN - edge_index = EdgeIndex( - [ - [0, 1, 1, 2], # Receiver - [1, 0, 2, 1], - ], # Sender - device="cuda", - dtype=torch.long, - ) - - X = torch.rand(node_ct, X_ir.dim, device="cuda", generator=gen) - Y = torch.rand(nonzero_ct, Y_ir.dim, device="cuda", generator=gen) - W = torch.rand(nonzero_ct, problem.weight_numel, device="cuda", generator=gen) - - tp_conv = oeq.TensorProductConv( - problem, torch_op=True, deterministic=False - ) # Reuse problem from earlier - Z = tp_conv.forward( - X, Y, W, edge_index[0], edge_index[1] - ) # Z has shape [node_ct, z_ir.dim] - print(torch.norm(Z)) - # =============================== - - # =============================== - _, sender_perm = edge_index.sort_by("col") # Sort by sender index - edge_index, receiver_perm = edge_index.sort_by("row") # Sort by receiver index - - # Now we can use the faster deterministic algorithm - tp_conv = oeq.TensorProductConv(problem, torch_op=True, deterministic=True) - Z = tp_conv.forward( - X, Y[receiver_perm], W[receiver_perm], edge_index[0], edge_index[1], sender_perm - ) - print(torch.norm(Z)) - # =============================== - assert True From 1d6f7aca29b369276738d65d8e628e1eb50eb086 Mon Sep 17 00:00:00 2001 From: asglover <140220574+asglover@users.noreply.github.com> Date: Fri, 31 Oct 2025 14:33:35 -0700 Subject: [PATCH 6/9] format --- openequivariance/extlib/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/openequivariance/extlib/__init__.py b/openequivariance/extlib/__init__.py index ed5695b..d21d507 100644 --- a/openequivariance/extlib/__init__.py +++ b/openequivariance/extlib/__init__.py @@ -132,6 +132,7 @@ def postprocess(kernel): else: TORCH_COMPILE = False + def _raise_import_error_helper(import_target: str): if not TORCH_CUDA_AVAILABLE: raise ImportError( @@ -150,6 +151,7 @@ def _raise_import_error_helper(import_target: str): GPUTimer, ) else: + def JITTPImpl(*args, **kwargs): _raise_import_error_helper("JITTPImpl") From 6075ad543f86cd3beba3c471491b0040dfdfde9a Mon Sep 17 00:00:00 2001 From: Austin Glover Date: Fri, 31 Oct 2025 16:44:12 -0700 Subject: [PATCH 7/9] attempt to force compilation when no devices are present --- openequivariance/extlib/__init__.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/openequivariance/extlib/__init__.py b/openequivariance/extlib/__init__.py index d21d507..527c4ab 100644 --- a/openequivariance/extlib/__init__.py +++ b/openequivariance/extlib/__init__.py @@ -14,7 +14,7 @@ build_ext = True TORCH_COMPILE = True -TORCH_CUDA_AVAILABLE = torch.cuda.is_available() +TORCH_VERSION_CUDA_OR_HIP = torch.version.cuda or torch.version.hip torch_module, generic_module = None, None postprocess_kernel = lambda kernel: kernel # noqa : E731 @@ -42,7 +42,7 @@ import openequivariance.extlib.generic_module generic_module = openequivariance.extlib.generic_module -elif TORCH_CUDA_AVAILABLE: +elif TORCH_VERSION_CUDA_OR_HIP: from torch.utils.cpp_extension import library_paths, include_paths extra_cflags = ["-O3"] @@ -134,13 +134,13 @@ def postprocess(kernel): def _raise_import_error_helper(import_target: str): - if not TORCH_CUDA_AVAILABLE: + if not TORCH_VERSION_CUDA_OR_HIP: raise ImportError( - f"Could not import {import_target}: OpenEquivariance's torch extension was not built because torch.cuda.is_available() is false" + f"Could not import {import_target}: OpenEquivariance's torch extension was not built because torch.version.cuda || torch.version.hip is false" ) -if TORCH_CUDA_AVAILABLE: +if TORCH_VERSION_CUDA_OR_HIP: from generic_module import ( JITTPImpl, JITConvImpl, From c8ec0cf73f2bd6cb8c66df8f33c8f4cbebac91e5 Mon Sep 17 00:00:00 2001 From: Austin Glover Date: Fri, 31 Oct 2025 16:50:48 -0700 Subject: [PATCH 8/9] bad commit! --- openequivariance/extension/group_mm_cuda.hpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/openequivariance/extension/group_mm_cuda.hpp b/openequivariance/extension/group_mm_cuda.hpp index 95f1412..70ef13d 100644 --- a/openequivariance/extension/group_mm_cuda.hpp +++ b/openequivariance/extension/group_mm_cuda.hpp @@ -26,7 +26,8 @@ class GroupMMCUDA { beta(0.0) { stat = cublasCreate(&handle); if (stat != CUBLAS_STATUS_SUCCESS) { - throw std::logic_error("CUBLAS initialization failed"); + throw std::log + ic_error("CUBLAS initialization failed"); } } From eab3c432ea32853354969c61d3ef47e9b93e18c9 Mon Sep 17 00:00:00 2001 From: Austin Glover Date: Fri, 31 Oct 2025 17:06:26 -0700 Subject: [PATCH 9/9] revert bad commit --- openequivariance/extension/group_mm_cuda.hpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/openequivariance/extension/group_mm_cuda.hpp b/openequivariance/extension/group_mm_cuda.hpp index 70ef13d..95f1412 100644 --- a/openequivariance/extension/group_mm_cuda.hpp +++ b/openequivariance/extension/group_mm_cuda.hpp @@ -26,8 +26,7 @@ class GroupMMCUDA { beta(0.0) { stat = cublasCreate(&handle); if (stat != CUBLAS_STATUS_SUCCESS) { - throw std::log - ic_error("CUBLAS initialization failed"); + throw std::logic_error("CUBLAS initialization failed"); } }