Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 20 additions & 6 deletions cuda_core/cuda/core/_memory/_managed_memory_resource.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ from cuda.core._utils.cuda_utils cimport (
HANDLE_RETURN,
check_or_create_options,
)
from cuda.core._utils.cuda_utils import CUDAError

from dataclasses import dataclass
import threading
Expand Down Expand Up @@ -226,12 +227,25 @@ cdef inline _MMR_init(ManagedMemoryResource self, options):
)

if opts is None:
MP_init_current_pool(
self,
loc_type,
loc_id,
cydriver.CUmemAllocationType.CU_MEM_ALLOCATION_TYPE_MANAGED,
)
try:
MP_init_current_pool(
self,
loc_type,
loc_id,
cydriver.CUmemAllocationType.CU_MEM_ALLOCATION_TYPE_MANAGED,
)
except CUDAError as e:
if "CUDA_ERROR_NOT_SUPPORTED" in str(e):
from .._device import Device
if not Device().properties.concurrent_managed_access:
raise RuntimeError(
"The default memory pool on this device does not support "
"managed allocations (concurrent managed access is not "
"available). Use "
"ManagedMemoryResource(options=ManagedMemoryResourceOptions(...)) "
"to create a dedicated managed pool."
) from e
raise
else:
MP_init_create_pool(
self,
Expand Down
4 changes: 3 additions & 1 deletion cuda_core/cuda/core/_memory/_memory_pool.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,9 @@ cdef int MP_init_current_pool(
self._h_pool = create_mempool_handle_ref(pool)
self._mempool_owned = False
ELSE:
raise RuntimeError("not supported")
raise RuntimeError(
"Getting the current memory pool requires CUDA 13.0 or later"
)
Comment on lines +260 to +262
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unrelated fix, but a needed improvement.

return 0


Expand Down
7 changes: 7 additions & 0 deletions cuda_core/tests/test_managed_memory_warning.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,13 @@ def device_without_concurrent_managed_access(init_cuda):
return device


@requires_cuda_13
def test_default_pool_error_without_concurrent_access(device_without_concurrent_managed_access):
"""ManagedMemoryResource() raises RuntimeError when the default pool doesn't support managed."""
with pytest.raises(RuntimeError, match="does not support managed allocations"):
ManagedMemoryResource()


@requires_cuda_13
def test_warning_emitted(device_without_concurrent_managed_access):
"""ManagedMemoryResource emits a warning when concurrent managed access is unsupported."""
Expand Down
Loading