GPU memory optimization with larger angular momentum #161
As is well known, the main headaches with CGTP (the Clebsch-Gordan tensor product) are computational inefficiency and GPU memory overhead. OpenEquivariance achieves excellent acceleration over e3nn by optimizing GPU utilization, as reported in the paper. However, algorithmic methods like the SO(2) convolution and the Gaunt tensor product can reduce the complexity from O(L^6) to between O(L^3) and O(L^2 log L), which makes them very fast while remaining friendly to GPU memory. So, has OpenEquivariance been tested with respect to (1) GPU memory optimization, especially when the angular momentum L is very large, and (2) acceleration in scenarios with larger angular momentum L? Thanks.
How large do you want to make L? We've tested up to L=7,7,7 interactions with high GPU utilization; see line 87 of tests/benchmark.py. You can benchmark it yourself by modifying the irreps in the example at the top of the README to use high L values.

We haven't benchmarked against those methods (but you are welcome to). In terms of memory for our implementation: the nonzeros of the TP are coded into the instruction stream, so as L increases we don't really run into memory constraints (though the I-cache will eventually spill, resulting in a slowdown). I'd say we are still pretty memory- and compute-efficient overall :)

The SO(2) convolution is very clever, but you have to rotate the irreps by a unique transform on each edge of the atomic graph, so it's much more complicated from an HPC perspective. See page 38, bottom-most paragraph of my dissertation for an analysis of the tradeoffs and a full explanation of their method. Against a simple PyTorch implementation of the SO(2) convolution, I think we'd be very competitive (if not outright better).

The Gaunt TP is potentially very fast, but it's a distinct operation (see the analysis here) that may sacrifice some expressive capability. And again, if you wish to implement optimizations like kernel fusion with the graph convolution, as we do, a simple PyTorch implementation probably won't cut it: you would need to fuse the FFTs / multiplications at the kernel level with the node / edge aggregation, which is a lot of engineering. But happy to see numbers to the contrary on these; these are back-of-the-envelope guesses.
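As a rough illustration of the kind of high-L benchmark suggested above, here is a minimal timing sketch using e3nn's `o3.FullyConnectedTensorProduct` as a reference point. This is not the project's own benchmark script: the irreps strings, batch size, and timing loop are illustrative assumptions, and the OpenEquivariance README example can be modified analogously with the same high-L irreps for a side-by-side comparison.

```python
# Minimal sketch (assumed setup, not the library's own benchmark):
# time one fully connected tensor product at high angular momentum in e3nn.
import time
import torch
from e3nn import o3

# Illustrative high-L irreps; raise the L values here to stress the TP.
irreps_in1 = o3.Irreps("32x7e")
irreps_in2 = o3.Irreps("1x7e")
irreps_out = o3.Irreps("32x7e")

device = "cuda" if torch.cuda.is_available() else "cpu"
tp = o3.FullyConnectedTensorProduct(irreps_in1, irreps_in2, irreps_out).to(device)

# Random batched inputs with the correct flattened irreps dimension.
batch = 10_000
x = irreps_in1.randn(batch, -1).to(device)
y = irreps_in2.randn(batch, -1).to(device)

# Warm up, then time a single batched tensor product.
for _ in range(3):
    tp(x, y)
if device == "cuda":
    torch.cuda.synchronize()
start = time.perf_counter()
out = tp(x, y)
if device == "cuda":
    torch.cuda.synchronize()
print(f"output shape {tuple(out.shape)}, elapsed {time.perf_counter() - start:.4f} s")
```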