diff --git a/cuda_core/cuda/core/_memory/_managed_memory_resource.pyx b/cuda_core/cuda/core/_memory/_managed_memory_resource.pyx index 4f24bd8d11..8ae1633c6a 100644 --- a/cuda_core/cuda/core/_memory/_managed_memory_resource.pyx +++ b/cuda_core/cuda/core/_memory/_managed_memory_resource.pyx @@ -11,6 +11,7 @@ from cuda.core._utils.cuda_utils cimport ( HANDLE_RETURN, check_or_create_options, ) +from cuda.core._utils.cuda_utils import CUDAError from dataclasses import dataclass import threading @@ -226,12 +227,25 @@ cdef inline _MMR_init(ManagedMemoryResource self, options): ) if opts is None: - MP_init_current_pool( - self, - loc_type, - loc_id, - cydriver.CUmemAllocationType.CU_MEM_ALLOCATION_TYPE_MANAGED, - ) + try: + MP_init_current_pool( + self, + loc_type, + loc_id, + cydriver.CUmemAllocationType.CU_MEM_ALLOCATION_TYPE_MANAGED, + ) + except CUDAError as e: + if "CUDA_ERROR_NOT_SUPPORTED" in str(e): + from .._device import Device + if not Device().properties.concurrent_managed_access: + raise RuntimeError( + "The default memory pool on this device does not support " + "managed allocations (concurrent managed access is not " + "available). Use " + "ManagedMemoryResource(options=ManagedMemoryResourceOptions(...)) " + "to create a dedicated managed pool." + ) from e + raise else: MP_init_create_pool( self, diff --git a/cuda_core/cuda/core/_memory/_memory_pool.pyx b/cuda_core/cuda/core/_memory/_memory_pool.pyx index a37ea17ab3..a24f75e027 100644 --- a/cuda_core/cuda/core/_memory/_memory_pool.pyx +++ b/cuda_core/cuda/core/_memory/_memory_pool.pyx @@ -257,7 +257,9 @@ cdef int MP_init_current_pool( self._h_pool = create_mempool_handle_ref(pool) self._mempool_owned = False ELSE: - raise RuntimeError("not supported") + raise RuntimeError( + "Getting the current memory pool requires CUDA 13.0 or later" + ) return 0 diff --git a/cuda_core/tests/test_managed_memory_warning.py b/cuda_core/tests/test_managed_memory_warning.py index 1f13f06f30..78015978e7 100644 --- a/cuda_core/tests/test_managed_memory_warning.py +++ b/cuda_core/tests/test_managed_memory_warning.py @@ -44,6 +44,13 @@ def device_without_concurrent_managed_access(init_cuda): return device +@requires_cuda_13 +def test_default_pool_error_without_concurrent_access(device_without_concurrent_managed_access): + """ManagedMemoryResource() raises RuntimeError when the default pool doesn't support managed.""" + with pytest.raises(RuntimeError, match="does not support managed allocations"): + ManagedMemoryResource() + + @requires_cuda_13 def test_warning_emitted(device_without_concurrent_managed_access): """ManagedMemoryResource emits a warning when concurrent managed access is unsupported."""