diff --git a/openequivariance/extension/util/backend_cuda.hpp b/openequivariance/extension/util/backend_cuda.hpp index c869dfe..364186f 100644 --- a/openequivariance/extension/util/backend_cuda.hpp +++ b/openequivariance/extension/util/backend_cuda.hpp @@ -6,6 +6,7 @@ #include #include #include +#include using namespace std; using Stream = cudaStream_t; @@ -160,6 +161,8 @@ class __attribute__((visibility("default"))) CUJITKernel { CUlibrary library; + vector supported_archs; + vector kernel_names; vector kernels; @@ -167,6 +170,15 @@ class __attribute__((visibility("default"))) CUJITKernel { string kernel_plaintext; CUJITKernel(string plaintext) : kernel_plaintext(plaintext) { + + int num_supported_archs; + NVRTC_SAFE_CALL( + nvrtcGetNumSupportedArchs(&num_supported_archs)); + + supported_archs.resize(num_supported_archs); + NVRTC_SAFE_CALL( + nvrtcGetSupportedArchs(supported_archs.data())); + NVRTC_SAFE_CALL( nvrtcCreateProgram( &prog, // prog @@ -196,6 +208,21 @@ class __attribute__((visibility("default"))) CUJITKernel { throw std::logic_error("Kernel names and template parameters must have the same size!"); } + int device_arch = cu_major * 10 + cu_minor; + if (std::find(supported_archs.begin(), supported_archs.end(), device_arch) == supported_archs.end()){ + int nvrtc_version_major, nvrtc_version_minor; + NVRTC_SAFE_CALL( + nvrtcVersion(&nvrtc_version_major, &nvrtc_version_minor)); + + throw std::runtime_error("NVRTC version " + + std::to_string(nvrtc_version_major) + + "." + + std::to_string(nvrtc_version_minor) + + " does not support device architecture " + + std::to_string(device_arch) + ); + } + for(unsigned int kernel = 0; kernel < kernel_names_i.size(); kernel++) { string kernel_name = kernel_names_i[kernel]; vector &template_params = template_param_list[kernel];