Skip to content

Commit 339d8ea

Browse files
asgloverAustin Glover
andauthored
No import error when loading the package without cuda (#165)
* add setuptools to cover for torch cpp extension * only import the extlib if the cuda is available * revert ineffective change * add pathway for when cuda.is.available is false, stubs for better errors * move example tests to another file so they don't get loaded during the import test * format * attempt to force compilation when no devices are present * bad commit! * revert bad commit --------- Co-authored-by: Austin Glover <[email protected]>
1 parent f1f453f commit 339d8ea

File tree

4 files changed

+123
-90
lines changed

4 files changed

+123
-90
lines changed

openequivariance/extlib/__init__.py

Lines changed: 47 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,16 @@
55
import sysconfig
66
from pathlib import Path
77

8+
global torch
9+
import torch
10+
811
from openequivariance.benchmark.logging_utils import getLogger
912

1013
oeq_root = str(Path(__file__).parent.parent)
1114

1215
build_ext = True
1316
TORCH_COMPILE = True
17+
TORCH_VERSION_CUDA_OR_HIP = torch.version.cuda or torch.version.hip
1418
torch_module, generic_module = None, None
1519
postprocess_kernel = lambda kernel: kernel # noqa : E731
1620

@@ -38,12 +42,9 @@
3842
import openequivariance.extlib.generic_module
3943

4044
generic_module = openequivariance.extlib.generic_module
41-
else:
45+
elif TORCH_VERSION_CUDA_OR_HIP:
4246
from torch.utils.cpp_extension import library_paths, include_paths
4347

44-
global torch
45-
import torch
46-
4748
extra_cflags = ["-O3"]
4849
generic_sources = ["generic_module.cpp"]
4950
torch_sources = ["libtorch_tp_jit.cpp"]
@@ -128,13 +129,46 @@ def postprocess(kernel):
128129
"Could not compile integrated PyTorch wrapper. Falling back to Pybind11"
129130
+ f", but JITScript, compile fullgraph, and export will fail.\n {torch_compile_exception}"
130131
)
132+
else:
133+
TORCH_COMPILE = False
134+
135+
136+
def _raise_import_error_helper(import_target: str):
137+
if not TORCH_VERSION_CUDA_OR_HIP:
138+
raise ImportError(
139+
f"Could not import {import_target}: OpenEquivariance's torch extension was not built because torch.version.cuda || torch.version.hip is false"
140+
)
141+
142+
143+
if TORCH_VERSION_CUDA_OR_HIP:
144+
from generic_module import (
145+
JITTPImpl,
146+
JITConvImpl,
147+
GroupMM_F32,
148+
GroupMM_F64,
149+
DeviceProp,
150+
DeviceBuffer,
151+
GPUTimer,
152+
)
153+
else:
154+
155+
def JITTPImpl(*args, **kwargs):
156+
_raise_import_error_helper("JITTPImpl")
157+
158+
def JITConvImpl(*args, **kwargs):
159+
_raise_import_error_helper("JITConvImpl")
160+
161+
def GroupMM_F32(*args, **kwargs):
162+
_raise_import_error_helper("GroupMM_F32")
163+
164+
def GroupMM_F64(*args, **kwargs):
165+
_raise_import_error_helper("GroupMM_F64")
166+
167+
def DeviceProp(*args, **kwargs):
168+
_raise_import_error_helper("DeviceProp")
169+
170+
def DeviceBuffer(*args, **kwargs):
171+
_raise_import_error_helper("DeviceBuffer")
131172

132-
from generic_module import (
133-
JITTPImpl,
134-
JITConvImpl,
135-
GroupMM_F32,
136-
GroupMM_F64,
137-
DeviceProp,
138-
DeviceBuffer,
139-
GPUTimer,
140-
)
173+
def GPUTimer(*args, **kwargs):
174+
_raise_import_error_helper("GPUTimer")

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ authors = [
1414
description = "A fast GPU JIT kernel generator for the Clebsch-Gordon Tensor Product"
1515
requires-python = ">=3.10"
1616
dependencies = [
17+
"setuptools",
1718
"ninja",
1819
"jinja2",
1920
"numpy",

tests/examples_test.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
def test_tutorial():
2+
import torch
3+
import e3nn.o3 as o3
4+
5+
gen = torch.Generator(device="cuda")
6+
7+
batch_size = 1000
8+
X_ir, Y_ir, Z_ir = o3.Irreps("1x2e"), o3.Irreps("1x3e"), o3.Irreps("1x2e")
9+
X = torch.rand(batch_size, X_ir.dim, device="cuda", generator=gen)
10+
Y = torch.rand(batch_size, Y_ir.dim, device="cuda", generator=gen)
11+
12+
instructions = [(0, 0, 0, "uvu", True)]
13+
14+
tp_e3nn = o3.TensorProduct(
15+
X_ir, Y_ir, Z_ir, instructions, shared_weights=False, internal_weights=False
16+
).to("cuda")
17+
W = torch.rand(batch_size, tp_e3nn.weight_numel, device="cuda", generator=gen)
18+
19+
Z = tp_e3nn(X, Y, W)
20+
print(torch.norm(Z))
21+
# ===============================
22+
23+
# ===============================
24+
import openequivariance as oeq
25+
26+
problem = oeq.TPProblem(
27+
X_ir, Y_ir, Z_ir, instructions, shared_weights=False, internal_weights=False
28+
)
29+
tp_fast = oeq.TensorProduct(problem, torch_op=True)
30+
31+
Z = tp_fast(X, Y, W) # Reuse X, Y, W from earlier
32+
print(torch.norm(Z))
33+
# ===============================
34+
35+
# Graph Convolution
36+
# ===============================
37+
from torch_geometric import EdgeIndex
38+
39+
node_ct, nonzero_ct = 3, 4
40+
41+
# Receiver, sender indices for message passing GNN
42+
edge_index = EdgeIndex(
43+
[
44+
[0, 1, 1, 2], # Receiver
45+
[1, 0, 2, 1],
46+
], # Sender
47+
device="cuda",
48+
dtype=torch.long,
49+
)
50+
51+
X = torch.rand(node_ct, X_ir.dim, device="cuda", generator=gen)
52+
Y = torch.rand(nonzero_ct, Y_ir.dim, device="cuda", generator=gen)
53+
W = torch.rand(nonzero_ct, problem.weight_numel, device="cuda", generator=gen)
54+
55+
tp_conv = oeq.TensorProductConv(
56+
problem, torch_op=True, deterministic=False
57+
) # Reuse problem from earlier
58+
Z = tp_conv.forward(
59+
X, Y, W, edge_index[0], edge_index[1]
60+
) # Z has shape [node_ct, z_ir.dim]
61+
print(torch.norm(Z))
62+
# ===============================
63+
64+
# ===============================
65+
_, sender_perm = edge_index.sort_by("col") # Sort by sender index
66+
edge_index, receiver_perm = edge_index.sort_by("row") # Sort by receiver index
67+
68+
# Now we can use the faster deterministic algorithm
69+
tp_conv = oeq.TensorProductConv(problem, torch_op=True, deterministic=True)
70+
Z = tp_conv.forward(
71+
X, Y[receiver_perm], W[receiver_perm], edge_index[0], edge_index[1], sender_perm
72+
)
73+
print(torch.norm(Z))
74+
# ===============================
75+
assert True

tests/import_test.py

Lines changed: 0 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -7,80 +7,3 @@ def test_import():
77
assert openequivariance.__version__ is not None
88
assert openequivariance.__version__ != "0.0.0"
99
assert openequivariance.__version__ == version("openequivariance")
10-
11-
12-
def test_tutorial():
13-
import torch
14-
import e3nn.o3 as o3
15-
16-
gen = torch.Generator(device="cuda")
17-
18-
batch_size = 1000
19-
X_ir, Y_ir, Z_ir = o3.Irreps("1x2e"), o3.Irreps("1x3e"), o3.Irreps("1x2e")
20-
X = torch.rand(batch_size, X_ir.dim, device="cuda", generator=gen)
21-
Y = torch.rand(batch_size, Y_ir.dim, device="cuda", generator=gen)
22-
23-
instructions = [(0, 0, 0, "uvu", True)]
24-
25-
tp_e3nn = o3.TensorProduct(
26-
X_ir, Y_ir, Z_ir, instructions, shared_weights=False, internal_weights=False
27-
).to("cuda")
28-
W = torch.rand(batch_size, tp_e3nn.weight_numel, device="cuda", generator=gen)
29-
30-
Z = tp_e3nn(X, Y, W)
31-
print(torch.norm(Z))
32-
# ===============================
33-
34-
# ===============================
35-
import openequivariance as oeq
36-
37-
problem = oeq.TPProblem(
38-
X_ir, Y_ir, Z_ir, instructions, shared_weights=False, internal_weights=False
39-
)
40-
tp_fast = oeq.TensorProduct(problem, torch_op=True)
41-
42-
Z = tp_fast(X, Y, W) # Reuse X, Y, W from earlier
43-
print(torch.norm(Z))
44-
# ===============================
45-
46-
# Graph Convolution
47-
# ===============================
48-
from torch_geometric import EdgeIndex
49-
50-
node_ct, nonzero_ct = 3, 4
51-
52-
# Receiver, sender indices for message passing GNN
53-
edge_index = EdgeIndex(
54-
[
55-
[0, 1, 1, 2], # Receiver
56-
[1, 0, 2, 1],
57-
], # Sender
58-
device="cuda",
59-
dtype=torch.long,
60-
)
61-
62-
X = torch.rand(node_ct, X_ir.dim, device="cuda", generator=gen)
63-
Y = torch.rand(nonzero_ct, Y_ir.dim, device="cuda", generator=gen)
64-
W = torch.rand(nonzero_ct, problem.weight_numel, device="cuda", generator=gen)
65-
66-
tp_conv = oeq.TensorProductConv(
67-
problem, torch_op=True, deterministic=False
68-
) # Reuse problem from earlier
69-
Z = tp_conv.forward(
70-
X, Y, W, edge_index[0], edge_index[1]
71-
) # Z has shape [node_ct, z_ir.dim]
72-
print(torch.norm(Z))
73-
# ===============================
74-
75-
# ===============================
76-
_, sender_perm = edge_index.sort_by("col") # Sort by sender index
77-
edge_index, receiver_perm = edge_index.sort_by("row") # Sort by receiver index
78-
79-
# Now we can use the faster deterministic algorithm
80-
tp_conv = oeq.TensorProductConv(problem, torch_op=True, deterministic=True)
81-
Z = tp_conv.forward(
82-
X, Y[receiver_perm], W[receiver_perm], edge_index[0], edge_index[1], sender_perm
83-
)
84-
print(torch.norm(Z))
85-
# ===============================
86-
assert True

0 commit comments

Comments
 (0)