60 changes: 47 additions & 13 deletions openequivariance/extlib/__init__.py
@@ -5,12 +5,16 @@
 import sysconfig
 from pathlib import Path
 
+global torch
+import torch
 
 from openequivariance.benchmark.logging_utils import getLogger
 
 oeq_root = str(Path(__file__).parent.parent)
 
 build_ext = True
+TORCH_COMPILE = True
+TORCH_VERSION_CUDA_OR_HIP = torch.version.cuda or torch.version.hip
 torch_module, generic_module = None, None
 postprocess_kernel = lambda kernel: kernel  # noqa : E731

@@ -38,12 +42,9 @@
     import openequivariance.extlib.generic_module
 
     generic_module = openequivariance.extlib.generic_module
-else:
+elif TORCH_VERSION_CUDA_OR_HIP:
     from torch.utils.cpp_extension import library_paths, include_paths
 
-    global torch
-    import torch
-
     extra_cflags = ["-O3"]
     generic_sources = ["generic_module.cpp"]
     torch_sources = ["libtorch_tp_jit.cpp"]
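Taken together with the hunk below, the module-level branching after this PR reads as a three-way switch on how the extension gets loaded. A condensed sketch of the resulting control flow (only `TORCH_VERSION_CUDA_OR_HIP` is verbatim; the first branch's condition is a paraphrase from context, not the PR's code). Note that `torch.version.cuda` and `torch.version.hip` are version strings on CUDA/ROCm builds of PyTorch and `None` on CPU-only wheels, so the flag is falsy exactly when no GPU toolchain is present:

import torch

TORCH_VERSION_CUDA_OR_HIP = torch.version.cuda or torch.version.hip  # None on CPU-only wheels

if prebuilt_extension_found:  # paraphrase of the original first branch
    generic_module = openequivariance.extlib.generic_module
elif TORCH_VERSION_CUDA_OR_HIP:  # GPU torch: JIT-compile the C++/CUDA extension
    from torch.utils.cpp_extension import library_paths, include_paths
    ...  # build generic_module and the libtorch wrapper, as before
else:  # CPU-only torch: skip the build and mark torch.compile support off
    TORCH_COMPILE = False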
@@ -128,13 +129,46 @@ def postprocess(kernel):
"Could not compile integrated PyTorch wrapper. Falling back to Pybind11"
+ f", but JITScript, compile fullgraph, and export will fail.\n {torch_compile_exception}"
)
else:
TORCH_COMPILE = False


def _raise_import_error_helper(import_target: str):
if not TORCH_VERSION_CUDA_OR_HIP:
raise ImportError(
f"Could not import {import_target}: OpenEquivariance's torch extension was not built because torch.version.cuda || torch.version.hip is false"
)


if TORCH_VERSION_CUDA_OR_HIP:
from generic_module import (
JITTPImpl,
JITConvImpl,
GroupMM_F32,
GroupMM_F64,
DeviceProp,
DeviceBuffer,
GPUTimer,
)
else:

def JITTPImpl(*args, **kwargs):
_raise_import_error_helper("JITTPImpl")

def JITConvImpl(*args, **kwargs):
_raise_import_error_helper("JITConvImpl")

def GroupMM_F32(*args, **kwargs):
_raise_import_error_helper("GroupMM_F32")

def GroupMM_F64(*args, **kwargs):
_raise_import_error_helper("GroupMM_F64")

def DeviceProp(*args, **kwargs):
_raise_import_error_helper("DeviceProp")

def DeviceBuffer(*args, **kwargs):
_raise_import_error_helper("DeviceBuffer")

from generic_module import (
JITTPImpl,
JITConvImpl,
GroupMM_F32,
GroupMM_F64,
DeviceProp,
DeviceBuffer,
GPUTimer,
)
def GPUTimer(*args, **kwargs):
_raise_import_error_helper("GPUTimer")
1 change: 1 addition & 0 deletions pyproject.toml
@@ -14,6 +14,7 @@ authors = [
 description = "A fast GPU JIT kernel generator for the Clebsch-Gordon Tensor Product"
 requires-python = ">=3.10"
 dependencies = [
+    "setuptools",
     "ninja",
     "jinja2",
     "numpy",
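The new `setuptools` entry is presumably needed because `torch.utils.cpp_extension`, which extlib/__init__.py imports to JIT-build its extension, itself depends on setuptools, and Python 3.12+ virtual environments no longer ship it by default. A quick check in a fresh environment (an illustrative assumption, not part of this PR):

# Without setuptools installed, this import is the first thing to fail at runtime:
from torch.utils.cpp_extension import load, library_paths  # noqa: F401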
75 changes: 75 additions & 0 deletions tests/examples_test.py
@@ -0,0 +1,75 @@
+def test_tutorial():
+    import torch
+    import e3nn.o3 as o3
+
+    gen = torch.Generator(device="cuda")
+
+    batch_size = 1000
+    X_ir, Y_ir, Z_ir = o3.Irreps("1x2e"), o3.Irreps("1x3e"), o3.Irreps("1x2e")
+    X = torch.rand(batch_size, X_ir.dim, device="cuda", generator=gen)
+    Y = torch.rand(batch_size, Y_ir.dim, device="cuda", generator=gen)
+
+    instructions = [(0, 0, 0, "uvu", True)]
+
+    tp_e3nn = o3.TensorProduct(
+        X_ir, Y_ir, Z_ir, instructions, shared_weights=False, internal_weights=False
+    ).to("cuda")
+    W = torch.rand(batch_size, tp_e3nn.weight_numel, device="cuda", generator=gen)
+
+    Z = tp_e3nn(X, Y, W)
+    print(torch.norm(Z))
+    # ===============================
+
+    # ===============================
+    import openequivariance as oeq
+
+    problem = oeq.TPProblem(
+        X_ir, Y_ir, Z_ir, instructions, shared_weights=False, internal_weights=False
+    )
+    tp_fast = oeq.TensorProduct(problem, torch_op=True)
+
+    Z = tp_fast(X, Y, W)  # Reuse X, Y, W from earlier
+    print(torch.norm(Z))
+    # ===============================
+
+    # Graph Convolution
+    # ===============================
+    from torch_geometric import EdgeIndex
+
+    node_ct, nonzero_ct = 3, 4
+
+    # Receiver, sender indices for message passing GNN
+    edge_index = EdgeIndex(
+        [
+            [0, 1, 1, 2],  # Receiver
+            [1, 0, 2, 1],
+        ],  # Sender
+        device="cuda",
+        dtype=torch.long,
+    )
+
+    X = torch.rand(node_ct, X_ir.dim, device="cuda", generator=gen)
+    Y = torch.rand(nonzero_ct, Y_ir.dim, device="cuda", generator=gen)
+    W = torch.rand(nonzero_ct, problem.weight_numel, device="cuda", generator=gen)
+
+    tp_conv = oeq.TensorProductConv(
+        problem, torch_op=True, deterministic=False
+    )  # Reuse problem from earlier
+    Z = tp_conv.forward(
+        X, Y, W, edge_index[0], edge_index[1]
+    )  # Z has shape [node_ct, z_ir.dim]
+    print(torch.norm(Z))
+    # ===============================
+
+    # ===============================
+    _, sender_perm = edge_index.sort_by("col")  # Sort by sender index
+    edge_index, receiver_perm = edge_index.sort_by("row")  # Sort by receiver index
+
+    # Now we can use the faster deterministic algorithm
+    tp_conv = oeq.TensorProductConv(problem, torch_op=True, deterministic=True)
+    Z = tp_conv.forward(
+        X, Y[receiver_perm], W[receiver_perm], edge_index[0], edge_index[1], sender_perm
+    )
+    print(torch.norm(Z))
+    # ===============================
+    assert True
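One caveat with the relocated tutorial test: it hard-codes device="cuda" throughout, so on a machine without a GPU it errors rather than skips. A hedged sketch of a guard one could wrap around it, using the standard pytest skipif marker (an illustration, not something this PR adds):

import pytest
import torch

# Skip cleanly on CPU-only runners instead of failing in torch.Generator(device="cuda"):
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires a CUDA/ROCm device")
def test_tutorial():
    ...  # body as above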
77 changes: 0 additions & 77 deletions tests/import_test.py
@@ -7,80 +7,3 @@ def test_import():
     assert openequivariance.__version__ is not None
     assert openequivariance.__version__ != "0.0.0"
     assert openequivariance.__version__ == version("openequivariance")
-
-
-def test_tutorial():
-    import torch
-    import e3nn.o3 as o3
-
-    gen = torch.Generator(device="cuda")
-
-    batch_size = 1000
-    X_ir, Y_ir, Z_ir = o3.Irreps("1x2e"), o3.Irreps("1x3e"), o3.Irreps("1x2e")
-    X = torch.rand(batch_size, X_ir.dim, device="cuda", generator=gen)
-    Y = torch.rand(batch_size, Y_ir.dim, device="cuda", generator=gen)
-
-    instructions = [(0, 0, 0, "uvu", True)]
-
-    tp_e3nn = o3.TensorProduct(
-        X_ir, Y_ir, Z_ir, instructions, shared_weights=False, internal_weights=False
-    ).to("cuda")
-    W = torch.rand(batch_size, tp_e3nn.weight_numel, device="cuda", generator=gen)
-
-    Z = tp_e3nn(X, Y, W)
-    print(torch.norm(Z))
-    # ===============================
-
-    # ===============================
-    import openequivariance as oeq
-
-    problem = oeq.TPProblem(
-        X_ir, Y_ir, Z_ir, instructions, shared_weights=False, internal_weights=False
-    )
-    tp_fast = oeq.TensorProduct(problem, torch_op=True)
-
-    Z = tp_fast(X, Y, W)  # Reuse X, Y, W from earlier
-    print(torch.norm(Z))
-    # ===============================
-
-    # Graph Convolution
-    # ===============================
-    from torch_geometric import EdgeIndex
-
-    node_ct, nonzero_ct = 3, 4
-
-    # Receiver, sender indices for message passing GNN
-    edge_index = EdgeIndex(
-        [
-            [0, 1, 1, 2],  # Receiver
-            [1, 0, 2, 1],
-        ],  # Sender
-        device="cuda",
-        dtype=torch.long,
-    )
-
-    X = torch.rand(node_ct, X_ir.dim, device="cuda", generator=gen)
-    Y = torch.rand(nonzero_ct, Y_ir.dim, device="cuda", generator=gen)
-    W = torch.rand(nonzero_ct, problem.weight_numel, device="cuda", generator=gen)
-
-    tp_conv = oeq.TensorProductConv(
-        problem, torch_op=True, deterministic=False
-    )  # Reuse problem from earlier
-    Z = tp_conv.forward(
-        X, Y, W, edge_index[0], edge_index[1]
-    )  # Z has shape [node_ct, z_ir.dim]
-    print(torch.norm(Z))
-    # ===============================
-
-    # ===============================
-    _, sender_perm = edge_index.sort_by("col")  # Sort by sender index
-    edge_index, receiver_perm = edge_index.sort_by("row")  # Sort by receiver index
-
-    # Now we can use the faster deterministic algorithm
-    tp_conv = oeq.TensorProductConv(problem, torch_op=True, deterministic=True)
-    Z = tp_conv.forward(
-        X, Y[receiver_perm], W[receiver_perm], edge_index[0], edge_index[1], sender_perm
-    )
-    print(torch.norm(Z))
-    # ===============================
-    assert True