
Commit 8d30ca2

FlashTP Benchmark, Documentation Updates, and Release Prep (#153)
* Added FlashTP benchmarking.
* Linted.
* Updated documentation with FlashTP benchmarking instructions.
* Minor changes to ensure that docs build.
* Updated changelog for release prep.
* Linted.
1 parent 4df7dd6 commit 8d30ca2

File tree

9 files changed: +153 −42 lines changed

CHANGELOG.md

Lines changed: 33 additions & 0 deletions

@@ -1,5 +1,38 @@
 ## Latest Changes
 
+### v0.4.0 (2025-08-14)
+This release adds a benchmark against
+FlashTP, exposes weight reordering functions
+for e3nn compatibility, adds input validation,
+and provides rudimentary support for PyTorch
+automatic mixed precision (AMP). Our fused,
+JIT-compiled kernels exhibit up to 2x speedup
+over FlashTP!
+
+**Added**:
+1. Both `TensorProduct` and `TensorProductConv`
+now have the methods `reorder_weights_from_e3nn`
+and `reorder_weights_to_e3nn`. These convert
+the buffer of trainable weights from / to e3nn's
+canonical ordering. See the API page for usage
+details.
+2. If you have FlashTP installed, see our
+documentation ("Tests and Benchmarks" page)
+to benchmark FlashTP against OpenEquivariance.
+3. Tensor product inputs with incorrect sizes or
+datatypes now trigger clear errors in advance of
+execution.
+4. OpenEquivariance now has some support for
+automatic mixed precision (AMP), but only if
+`TensorProduct` / `TensorProductConv` objects
+are constructed with `float32` precision for
+both `irrep_dtype` and `weight_dtype`.
+
+**Fixed / Enhanced**:
+1. Added additional fake functions to remove
+warnings from TorchBind.
+2. Removed bloat from benchmarking code.
+
 ### v0.3.0 (2025-06-22)
 This release includes bugfixes and new opaque operations that
 compose with `torch.compile`

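The new reordering methods lend themselves to a short illustration. A minimal sketch, assuming a constructed `TensorProduct` named `tp` and an e3nn-ordered weight tensor `w_e3nn` (both placeholders; only the two method names come from the changelog entry above, and whether they mutate the internal buffer or return a new tensor is documented on the API page, not here):

    # Hypothetical sketch: `tp` and `w_e3nn` are placeholders, not
    # confirmed API usage from this commit.
    w_oeq = tp.reorder_weights_from_e3nn(w_e3nn)   # e3nn canonical order -> OEQ order
    w_back = tp.reorder_weights_to_e3nn(w_oeq)     # OEQ order -> e3nn canonical order
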
docs/supported_ops.rst

Lines changed: 4 additions & 1 deletion

@@ -117,4 +117,7 @@ toplevel. You can use our implementation by running
 
 .. code-block::
 
-from openequivariance.implementations.symmetric_contraction import SymmetricContraction as OEQSymmetricContraction
+   from openequivariance.implementations.symmetric_contraction import SymmetricContraction as OEQSymmetricContraction
+
+Some GitHub users report weak performance for the
+symmetric contraction backward pass; your mileage may vary.

docs/tests_and_benchmarks.rst

Lines changed: 10 additions & 0 deletions

@@ -67,6 +67,16 @@ For GPUs besides the NVIDIA A100, the roofline slope / peak will be incorrect.
 The plots for the convolution fusion experiments also require a GPU
 with a minimum of 40GB of memory.
 
+We recently added a benchmark against
+`FlashTP <https://github.com/SNU-ARC/flashTP>`_. To replicate it
+on your system, install FlashTP via ``pip`` and run
+
+.. code-block:: bash
+
+   python tests/benchmark.py -o outputs/conv conv --plot --data data/molecular_structures -i cue_unfused oeq_scattersum flashtp cue_fused oeq_det oeq_atomic
+
+OpenEquivariance exhibits up to 2x speedup over FlashTP's fused kernels.
+
 List of GPUs Tested
 --------------------------------
 OpenEquivariance has been tested successfully on the following GPUs. Submit a pull
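The docs hunk above leaves the exact ``pip`` invocation to the reader. One plausible form, assuming FlashTP installs straight from the linked repository (an assumption; defer to FlashTP's own README for the supported install path):

    # Assumed install command; not stated anywhere in this commit.
    pip install git+https://github.com/SNU-ARC/flashTP.git
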

openequivariance/benchmark/plotting/plot_convolution.py

Lines changed: 3 additions & 8 deletions

@@ -19,13 +19,8 @@ def plot_convolution(data_folder):
     data_folder = pathlib.Path(data_folder)
     benchmarks, metadata = load_benchmarks(data_folder)
 
-    implementations = [
-        "CUEConvolution",
-        "CUEConvolutionFused",
-        "LoopUnrollConvScatterSum",
-        "LoopUnrollConvAtomic",
-        "LoopUnrollConvDeterministic",
-    ]
+    implementations = metadata["implementations"]
+    assert "CUEConvolution" in implementations
 
     graphs = ["1drf_radius6.0", "covid_spike_radius3.0", "carbon_lattice_radius6.0"]
     graph_lmap = {
@@ -81,7 +76,7 @@ def plot_convolution(data_folder):
         rotate_xlabels=True,
         colormap=colormap,
         hatchmap=hatchmap,
-        group_spacing=6.0,
+        group_spacing=7.0,
     )
 
     axes[i][j].set_xlabel(dtype_labelmap[dtype])

openequivariance/benchmark/plotting/plotting_utils.py

Lines changed: 2 additions & 1 deletion

@@ -392,8 +392,9 @@ def set_size(w, h, ax=None):
     "CUEConvolutionFused": "cuE-fused",
     "LoopUnrollConvDeterministic": "fast-fused-det",
     "LoopUnrollConvAtomic": "fast-fused-atomic",
+    "FlashTPConv": "flashtp",
 }
-colormap = {"e3nn": "lightblue", "cuE": "orange", "ours": "g"}
+colormap = {"e3nn": "lightblue", "cuE": "orange", "ours": "g", "flashtp": "purple"}
 
 for key in ["fast-scattersum", "fast-fused-det", "fast-fused-atomic"]:
     colormap[key] = colormap["ours"]
openequivariance/implementations/convolution/FlashTPConv.py

Lines changed: 44 additions & 0 deletions

@@ -0,0 +1,44 @@
+__all__ = [
+    "FlashTPConv",
+]
+
+import torch
+import numpy as np
+from openequivariance.implementations.convolution.ConvolutionBase import ConvolutionBase
+from openequivariance.implementations.utils import oeq_to_torch_dtype
+
+
+class FlashTPConv(ConvolutionBase):
+    def __init__(self, config, *, idx_dtype=np.int64, torch_op=True):
+        super().__init__(config, idx_dtype=idx_dtype, torch_op=torch_op)
+        from flashTP_e3nn import uvu_TP
+
+        instructions = [
+            (
+                inst.i_in1,
+                inst.i_in2,
+                inst.i_out,
+                inst.connection_mode,
+                inst.has_weight,
+                inst.path_weight,
+            )
+            for inst in config.instructions
+        ]
+
+        self.internal = uvu_TP(
+            config.irreps_in1,
+            config.irreps_in2,
+            config.irreps_out,
+            instructions,
+            device="cuda",
+            dtype=oeq_to_torch_dtype(config.irrep_dtype),
+        )
+
+    def forward(self, L1_in, L2_in, weights, rows, cols, transpose_perm=None):
+        return self.internal(
+            L1_in, L2_in, weights, rows.to(torch.int), cols.to(torch.int)
+        )
+
+    @staticmethod
+    def name():
+        return "FlashTPConv"

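The list comprehension in `__init__` flattens e3nn-style `Instruction` objects into the plain tuples FlashTP's `uvu_TP` consumes. A standalone sketch of the same conversion using e3nn directly (assumes e3nn is installed; the irreps strings are arbitrary examples, not taken from this commit):

    # Illustration only: e3nn's o3.TensorProduct is not part of this commit.
    from e3nn import o3

    tp = o3.TensorProduct(
        "2x0e",                    # irreps_in1
        "1x0e",                    # irreps_in2
        "2x0e",                    # irreps_out
        [(0, 0, 0, "uvu", True)],  # one channelwise ("uvu") path
    )
    tuples = [
        (i.i_in1, i.i_in2, i.i_out, i.connection_mode, i.has_weight, i.path_weight)
        for i in tp.instructions
    ]
    # Each element is a tuple like (0, 0, 0, 'uvu', True, <path weight>).
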
openequivariance/implementations/dtype_enum.py

Lines changed: 22 additions & 25 deletions

@@ -1,5 +1,4 @@
 from enum import IntEnum
-from typing import Mapping
 from types import MappingProxyType
 import numpy as np
 import torch
@@ -13,33 +12,31 @@ class DTypeEnum(IntEnum):
     UINT8 = 5
 
 
-dtype_to_enum: Mapping[torch.dtype | type[np.generic] | np.dtype, DTypeEnum] = (
-    MappingProxyType(
-        {
-            torch.float32: DTypeEnum.FLOAT32,
-            torch.float64: DTypeEnum.FLOAT64,
-            torch.int32: DTypeEnum.INT32,
-            torch.int64: DTypeEnum.INT64,
-            torch.uint8: DTypeEnum.UINT8,
-            # torch
-            np.float32: DTypeEnum.FLOAT32,
-            np.float64: DTypeEnum.FLOAT64,
-            np.int32: DTypeEnum.INT32,
-            np.int64: DTypeEnum.INT64,
-            np.uint8: DTypeEnum.UINT8,
-            # numpy generic
-            np.dtype(np.float32): DTypeEnum.FLOAT32,
-            np.dtype(np.float64): DTypeEnum.FLOAT64,
-            np.dtype(np.int32): DTypeEnum.INT32,
-            np.dtype(np.int64): DTypeEnum.INT64,
-            np.dtype(np.uint8): DTypeEnum.UINT8,
-            # numpy dtype
-        }
-    )
+dtype_to_enum = MappingProxyType(
+    {
+        torch.float32: DTypeEnum.FLOAT32,
+        torch.float64: DTypeEnum.FLOAT64,
+        torch.int32: DTypeEnum.INT32,
+        torch.int64: DTypeEnum.INT64,
+        torch.uint8: DTypeEnum.UINT8,
+        # torch
+        np.float32: DTypeEnum.FLOAT32,
+        np.float64: DTypeEnum.FLOAT64,
+        np.int32: DTypeEnum.INT32,
+        np.int64: DTypeEnum.INT64,
+        np.uint8: DTypeEnum.UINT8,
+        # numpy generic
+        np.dtype(np.float32): DTypeEnum.FLOAT32,
+        np.dtype(np.float64): DTypeEnum.FLOAT64,
+        np.dtype(np.int32): DTypeEnum.INT32,
+        np.dtype(np.int64): DTypeEnum.INT64,
+        np.dtype(np.uint8): DTypeEnum.UINT8,
+        # numpy dtype
+    }
 )
 
 
-enum_to_torch_dtype: Mapping[DTypeEnum, torch.dtype] = MappingProxyType(
+enum_to_torch_dtype = MappingProxyType(
     {
         DTypeEnum.FLOAT32: torch.float32,
         DTypeEnum.FLOAT64: torch.float64,

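As the entries above show, `dtype_to_enum` accepts torch dtypes, NumPy scalar types, and NumPy dtype objects interchangeably. A quick check, grounded in the mapping's own keys:

    import numpy as np
    import torch
    from openequivariance.implementations.dtype_enum import DTypeEnum, dtype_to_enum

    # All three spellings of float32 resolve to the same enum member.
    assert dtype_to_enum[torch.float32] is DTypeEnum.FLOAT32
    assert dtype_to_enum[np.float32] is DTypeEnum.FLOAT32
    assert dtype_to_enum[np.dtype(np.float32)] is DTypeEnum.FLOAT32
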
openequivariance/implementations/utils.py

Lines changed: 12 additions & 0 deletions

@@ -106,6 +106,18 @@ def torch_to_oeq_dtype(torch_dtype) -> type[np.generic]:
     raise ValueError("Unsupported torch dtype!")
 
 
+def oeq_to_torch_dtype(oeq_dtype: type[np.generic]):
+    global torch
+    import torch
+
+    if oeq_dtype == np.float32:
+        return torch.float32
+    elif oeq_dtype == np.float64:
+        return torch.float64
+    else:
+        raise ValueError("Unsupported numpy dtype!")
+
+
 def benchmark(func, num_warmup, num_iter, mode="gpu_time", kernel_names=[]):
     """
     mode=gpu_time may include PyTorch overhead

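The new `oeq_to_torch_dtype` is the inverse of the existing `torch_to_oeq_dtype` for the two supported float types (torch is imported lazily inside the function, so callers need not import it themselves):

    import numpy as np
    from openequivariance.implementations.utils import oeq_to_torch_dtype

    # Maps the NumPy scalar types OEQ uses internally to torch dtypes.
    print(oeq_to_torch_dtype(np.float32))  # torch.float32
    print(oeq_to_torch_dtype(np.float64))  # torch.float64
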
tests/benchmark.py

Lines changed: 23 additions & 7 deletions

@@ -38,6 +38,7 @@
 )
 
 from openequivariance.implementations.convolution.CUEConv import CUEConv, CUEConvFused
+from openequivariance.implementations.convolution.FlashTPConv import FlashTPConv
 from openequivariance.benchmark.ConvBenchmarkSuite import ConvBenchmarkSuite, load_graph
 
 from openequivariance.benchmark.problems import (
@@ -54,13 +55,22 @@
 CTPP = ChannelwiseTPP
 FCTPP = FullyConnectedTPProblem
 
-implementation_map = {
+implementation_map_tp = {
     "e3nn": E3NNTensorProductCompiledMaxAutotuneCUDAGraphs,
     "e3nn_uncompiled": E3NNTensorProduct,
     "cue": CUETensorProduct,
     "oeq": TensorProduct,
 }
 
+implementation_map_conv = {
+    "cue_unfused": CUEConv,
+    "oeq_scattersum": TensorProductConvScatterSum,
+    "flashtp": FlashTPConv,
+    "cue_fused": CUEConvFused,
+    "oeq_det": TensorProductConvDeterministic,
+    "oeq_atomic": TensorProductConvAtomic,
+}
+
 datatype_map = {"float32": np.float32, "float64": np.float64}
 
 roofline_configs = [
@@ -87,7 +97,7 @@ def benchmark_uvu(params):
         problem.weight_dtype = np.float64
     problems = mace_problems() + nequip_problems() + float64_problems
 
-    implementations = [implementation_map[impl] for impl in params.implementations]
+    implementations = [implementation_map_tp[impl] for impl in params.implementations]
     directions = params.directions
 
     tests = [
@@ -289,11 +299,7 @@ def benchmark_convolution(params):
     bench = ConvBenchmarkSuite(configs, test_name="convolution")
 
     implementations = [
-        TensorProductConvScatterSum,
-        CUEConv,
-        CUEConvFused,
-        TensorProductConvDeterministic,
-        TensorProductConvAtomic,
+        implementation_map_conv[impl] for impl in params.implementations
    ]
 
     if params.limited_memory:
@@ -496,6 +502,16 @@ def plot(params):
         help="Disable tests requiring large amounts of memory.",
     )
     parser_conv.add_argument("--plot", action="store_true", help="Plot the results.")
+    parser_conv.add_argument(
+        "--implementations",
+        "-i",
+        type=str,
+        nargs="+",
+        default=["cue_unfused", "oeq_scattersum", "cue_fused", "oeq_atomic", "oeq_det"],
+        help="Implementations to benchmark",
+        choices=list(implementation_map_conv.keys()),
+    )
+
     parser_conv.set_defaults(func=benchmark_convolution)
 
     parser_uvw = subparsers.add_parser(
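With the new `--implementations` / `-i` flag, the convolution benchmark can target any subset of backends. For example, the following restricts the command documented in docs/tests_and_benchmarks.rst to two implementations (flag choices come from `implementation_map_conv`):

    python tests/benchmark.py -o outputs/conv conv --plot --data data/molecular_structures -i oeq_scattersum flashtp

Omitting `-i` falls back to the default list, which leaves out `flashtp`, presumably so the stock benchmark runs without FlashTP installed.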
