diff --git a/cuda_core/cuda/core/typing.py b/cuda_core/cuda/core/typing.py new file mode 100644 index 0000000000..f516e04554 --- /dev/null +++ b/cuda_core/cuda/core/typing.py @@ -0,0 +1,25 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Public type aliases and protocols used in cuda.core API signatures.""" + +from cuda.core._memory._buffer import DevicePointerT +from cuda.core._memory._virtual_memory_resource import ( + VirtualMemoryAccessTypeT, + VirtualMemoryAllocationTypeT, + VirtualMemoryGranularityT, + VirtualMemoryHandleTypeT, + VirtualMemoryLocationTypeT, +) +from cuda.core._stream import IsStreamT + +__all__ = [ + "DevicePointerT", + "IsStreamT", + "VirtualMemoryAccessTypeT", + "VirtualMemoryAllocationTypeT", + "VirtualMemoryGranularityT", + "VirtualMemoryHandleTypeT", + "VirtualMemoryLocationTypeT", +] diff --git a/cuda_core/docs/source/api_private.rst b/cuda_core/docs/source/api_private.rst index 0aa88d1d64..0c36914c5e 100644 --- a/cuda_core/docs/source/api_private.rst +++ b/cuda_core/docs/source/api_private.rst @@ -16,12 +16,12 @@ CUDA runtime .. autosummary:: :toctree: generated/ - _memory._buffer.DevicePointerT - _memory._virtual_memory_resource.VirtualMemoryAllocationTypeT - _memory._virtual_memory_resource.VirtualMemoryLocationTypeT - _memory._virtual_memory_resource.VirtualMemoryGranularityT - _memory._virtual_memory_resource.VirtualMemoryAccessTypeT - _memory._virtual_memory_resource.VirtualMemoryHandleTypeT + typing.DevicePointerT + typing.VirtualMemoryAllocationTypeT + typing.VirtualMemoryLocationTypeT + typing.VirtualMemoryGranularityT + typing.VirtualMemoryAccessTypeT + typing.VirtualMemoryHandleTypeT _module.KernelAttributes _module.KernelOccupancy _module.ParamInfo @@ -41,4 +41,4 @@ CUDA protocols :toctree: generated/ :template: protocol.rst - _stream.IsStreamT + typing.IsStreamT diff --git a/cuda_core/tests/test_typing_imports.py b/cuda_core/tests/test_typing_imports.py new file mode 100644 index 0000000000..8e4ab78d39 --- /dev/null +++ b/cuda_core/tests/test_typing_imports.py @@ -0,0 +1,60 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for cuda.core.typing public type aliases and protocols.""" + + +def test_typing_module_imports(): + """All type aliases and protocols are importable from cuda.core.typing.""" + from cuda.core.typing import ( + DevicePointerT, + IsStreamT, + VirtualMemoryAccessTypeT, + VirtualMemoryAllocationTypeT, + VirtualMemoryGranularityT, + VirtualMemoryHandleTypeT, + VirtualMemoryLocationTypeT, + ) + + # Verify they are not None (sanity check) + for name, obj in ( + ("DevicePointerT", DevicePointerT), + ("IsStreamT", IsStreamT), + ("VirtualMemoryAccessTypeT", VirtualMemoryAccessTypeT), + ("VirtualMemoryAllocationTypeT", VirtualMemoryAllocationTypeT), + ("VirtualMemoryGranularityT", VirtualMemoryGranularityT), + ("VirtualMemoryHandleTypeT", VirtualMemoryHandleTypeT), + ("VirtualMemoryLocationTypeT", VirtualMemoryLocationTypeT), + ): + assert obj is not None, f"{name} should not be None" + + +def test_typing_matches_private_definitions(): + """cuda.core.typing re-exports match the original private definitions.""" + from cuda.core._memory._buffer import DevicePointerT as _DevicePointerT + from cuda.core._memory._virtual_memory_resource import ( + VirtualMemoryAccessTypeT as _VirtualMemoryAccessTypeT, + VirtualMemoryAllocationTypeT as _VirtualMemoryAllocationTypeT, + VirtualMemoryGranularityT as _VirtualMemoryGranularityT, + VirtualMemoryHandleTypeT as _VirtualMemoryHandleTypeT, + VirtualMemoryLocationTypeT as _VirtualMemoryLocationTypeT, + ) + from cuda.core._stream import IsStreamT as _IsStreamT + from cuda.core.typing import ( + DevicePointerT, + IsStreamT, + VirtualMemoryAccessTypeT, + VirtualMemoryAllocationTypeT, + VirtualMemoryGranularityT, + VirtualMemoryHandleTypeT, + VirtualMemoryLocationTypeT, + ) + + assert DevicePointerT is _DevicePointerT + assert IsStreamT is _IsStreamT + assert VirtualMemoryAccessTypeT is _VirtualMemoryAccessTypeT + assert VirtualMemoryAllocationTypeT is _VirtualMemoryAllocationTypeT + assert VirtualMemoryGranularityT is _VirtualMemoryGranularityT + assert VirtualMemoryHandleTypeT is _VirtualMemoryHandleTypeT + assert VirtualMemoryLocationTypeT is _VirtualMemoryLocationTypeT