From 1633628cdfa86ac692c3e14c1011853e398e7ef1 Mon Sep 17 00:00:00 2001 From: Andy Jost Date: Tue, 31 Mar 2026 09:48:55 -0700 Subject: [PATCH 1/7] Reorganize graph test files for clarity Rename test files to reflect what they actually test: - test_basic -> test_graph_builder (stream capture tests) - test_conditional -> test_graph_builder_conditional - test_advanced -> test_graph_update (moved child_graph and stream_lifetime tests into test_graph_builder) - test_capture_alloc -> test_graph_memory_resource - test_explicit* -> test_graphdef* Made-with: Cursor --- cuda_core/tests/graph/test_device_launch.py | 9 +- .../{test_basic.py => test_graph_builder.py} | 84 ++++++++++++++++- ...l.py => test_graph_builder_conditional.py} | 2 +- ...alloc.py => test_graph_memory_resource.py} | 2 +- ...{test_advanced.py => test_graph_update.py} | 92 +------------------ .../{test_explicit.py => test_graphdef.py} | 2 +- ...icit_errors.py => test_graphdef_errors.py} | 7 +- ...ration.py => test_graphdef_integration.py} | 26 +----- ..._lifetime.py => test_graphdef_lifetime.py} | 7 +- 9 files changed, 93 insertions(+), 138 deletions(-) rename cuda_core/tests/graph/{test_basic.py => test_graph_builder.py} (70%) rename cuda_core/tests/graph/{test_conditional.py => test_graph_builder_conditional.py} (99%) rename cuda_core/tests/graph/{test_capture_alloc.py => test_graph_memory_resource.py} (99%) rename cuda_core/tests/graph/{test_advanced.py => test_graph_update.py} (50%) rename cuda_core/tests/graph/{test_explicit.py => test_graphdef.py} (99%) rename cuda_core/tests/graph/{test_explicit_errors.py => test_graphdef_errors.py} (96%) rename cuda_core/tests/graph/{test_explicit_integration.py => test_graphdef_integration.py} (93%) rename cuda_core/tests/graph/{test_explicit_lifetime.py => test_graphdef_lifetime.py} (98%) diff --git a/cuda_core/tests/graph/test_device_launch.py b/cuda_core/tests/graph/test_device_launch.py index d302978028..3a5d12c28b 100644 --- 
a/cuda_core/tests/graph/test_device_launch.py +++ b/cuda_core/tests/graph/test_device_launch.py @@ -1,14 +1,7 @@ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE -"""Device-side graph launch tests. - -Device-side graph launch allows a kernel running on the GPU to launch a CUDA graph. -This feature requires: -- CUDA 12.0+ -- Hopper architecture (sm_90+) -- The kernel calling cudaGraphLaunch() must itself be launched from within a graph -""" +"""Tests for device-side graph launch (GPU kernel launching a CUDA graph).""" import numpy as np import pytest diff --git a/cuda_core/tests/graph/test_basic.py b/cuda_core/tests/graph/test_graph_builder.py similarity index 70% rename from cuda_core/tests/graph/test_basic.py rename to cuda_core/tests/graph/test_graph_builder.py index af1c744dbf..d5906128d8 100644 --- a/cuda_core/tests/graph/test_basic.py +++ b/cuda_core/tests/graph/test_graph_builder.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE -"""Basic graph construction and topology tests.""" +"""GraphBuilder stream capture tests.""" import numpy as np import pytest @@ -205,3 +205,85 @@ def read_byte(data): launch_stream.sync() assert result[0] == 0xAB + + +@pytest.mark.skipif(tuple(int(i) for i in np.__version__.split(".")[:2]) < (2, 1), reason="need numpy 2.1.0+") +def test_graph_child_graph(init_cuda): + mod = compile_common_kernels() + add_one = mod.get_kernel("add_one") + + # Allocate memory + launch_stream = Device().create_stream() + mr = LegacyPinnedMemoryResource() + b = mr.allocate(8) + arr = np.from_dlpack(b).view(np.int32) + arr[0] = 0 + arr[1] = 0 + + # Capture the child graph + gb_child = Device().create_graph_builder().begin_building() + launch(gb_child, LaunchConfig(grid=1, block=1), add_one, arr[1:].ctypes.data) + launch(gb_child, LaunchConfig(grid=1, block=1), add_one, arr[1:].ctypes.data) + launch(gb_child, LaunchConfig(grid=1, block=1), add_one, arr[1:].ctypes.data) + gb_child.end_building() + + # Capture the parent graph + gb_parent = Device().create_graph_builder().begin_building() + launch(gb_parent, LaunchConfig(grid=1, block=1), add_one, arr.ctypes.data) + + ## Add child + try: + gb_parent.add_child(gb_child) + except NotImplementedError as e: + with pytest.raises( + NotImplementedError, + match="^Launching child graphs is not implemented for versions older than CUDA 12", + ): + raise e + gb_parent.end_building() + b.close() + pytest.skip("Launching child graphs is not implemented for versions older than CUDA 12") + + launch(gb_parent, LaunchConfig(grid=1, block=1), add_one, arr.ctypes.data) + graph = gb_parent.end_building().complete() + + # Parent updates first value, child updates second value + assert arr[0] == 0 + assert arr[1] == 0 + graph.launch(launch_stream) + launch_stream.sync() + assert arr[0] == 2 + assert arr[1] == 3 + + b.close() + + +def test_graph_stream_lifetime(init_cuda): + mod = 
compile_common_kernels() + empty_kernel = mod.get_kernel("empty_kernel") + + # Create simple graph from device + gb = Device().create_graph_builder().begin_building() + launch(gb, LaunchConfig(grid=1, block=1), empty_kernel) + graph = gb.end_building().complete() + + # Destroy simple graph and builder + gb.close() + graph.close() + + # Create simple graph from stream + stream = Device().create_stream() + gb = stream.create_graph_builder().begin_building() + launch(gb, LaunchConfig(grid=1, block=1), empty_kernel) + graph = gb.end_building().complete() + + # Destroy simple graph and builder + gb.close() + graph.close() + + # Verify the stream can still launch work + launch(stream, LaunchConfig(grid=1, block=1), empty_kernel) + stream.sync() + + # Destroy the stream + stream.close() diff --git a/cuda_core/tests/graph/test_conditional.py b/cuda_core/tests/graph/test_graph_builder_conditional.py similarity index 99% rename from cuda_core/tests/graph/test_conditional.py rename to cuda_core/tests/graph/test_graph_builder_conditional.py index 157d23e4f5..480179c4fc 100644 --- a/cuda_core/tests/graph/test_conditional.py +++ b/cuda_core/tests/graph/test_graph_builder_conditional.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE -"""Conditional graph node tests (if, if-else, switch, while).""" +"""Tests for GraphBuilder conditional node capture (if, if-else, switch, while).""" import ctypes diff --git a/cuda_core/tests/graph/test_capture_alloc.py b/cuda_core/tests/graph/test_graph_memory_resource.py similarity index 99% rename from cuda_core/tests/graph/test_capture_alloc.py rename to cuda_core/tests/graph/test_graph_memory_resource.py index 5cb23fd022..fe47ef2d68 100644 --- a/cuda_core/tests/graph/test_capture_alloc.py +++ b/cuda_core/tests/graph/test_graph_memory_resource.py @@ -2,7 +2,7 @@ # # SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE -"""Graph memory resource tests.""" +"""Tests for GraphMemoryResource allocation and attributes during graph capture.""" import pytest from helpers import IS_WINDOWS, IS_WSL diff --git a/cuda_core/tests/graph/test_advanced.py b/cuda_core/tests/graph/test_graph_update.py similarity index 50% rename from cuda_core/tests/graph/test_advanced.py rename to cuda_core/tests/graph/test_graph_update.py index 9d4f1b3040..baa9a50313 100644 --- a/cuda_core/tests/graph/test_advanced.py +++ b/cuda_core/tests/graph/test_graph_update.py @@ -1,68 +1,15 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE -"""Advanced graph feature tests (child graphs, update, stream lifetime).""" +"""Tests for whole-graph update (Graph.update).""" import numpy as np import pytest -from helpers.graph_kernels import compile_common_kernels, compile_conditional_kernels +from helpers.graph_kernels import compile_conditional_kernels from cuda.core import Device, LaunchConfig, LegacyPinnedMemoryResource, launch -@pytest.mark.skipif(tuple(int(i) for i in np.__version__.split(".")[:2]) < (2, 1), reason="need numpy 2.1.0+") -def test_graph_child_graph(init_cuda): - mod = compile_common_kernels() - add_one = mod.get_kernel("add_one") - - # Allocate memory - launch_stream = Device().create_stream() - mr = LegacyPinnedMemoryResource() - b = mr.allocate(8) - arr = np.from_dlpack(b).view(np.int32) - arr[0] = 0 - arr[1] = 0 - - # Capture the child graph - gb_child = Device().create_graph_builder().begin_building() - launch(gb_child, LaunchConfig(grid=1, block=1), add_one, arr[1:].ctypes.data) - launch(gb_child, LaunchConfig(grid=1, block=1), add_one, arr[1:].ctypes.data) - launch(gb_child, LaunchConfig(grid=1, block=1), add_one, arr[1:].ctypes.data) - gb_child.end_building() - - # Capture the parent graph - gb_parent = Device().create_graph_builder().begin_building() - launch(gb_parent, LaunchConfig(grid=1, block=1), add_one, arr.ctypes.data) - - ## Add child - try: - gb_parent.add_child(gb_child) - except NotImplementedError as e: - with pytest.raises( - NotImplementedError, - match="^Launching child graphs is not implemented for versions older than CUDA 12", - ): - raise e - gb_parent.end_building() - b.close() - pytest.skip("Launching child graphs is not implemented for versions older than CUDA 12") - - launch(gb_parent, LaunchConfig(grid=1, block=1), add_one, arr.ctypes.data) - graph = gb_parent.end_building().complete() - - # Parent updates first value, child updates second value - assert arr[0] == 0 - assert arr[1] == 0 - 
graph.launch(launch_stream) - launch_stream.sync() - assert arr[0] == 2 - assert arr[1] == 3 - - # Close the memory resource now because the garbage collected might - # de-allocate it during the next graph builder process - b.close() - - @pytest.mark.skipif(tuple(int(i) for i in np.__version__.split(".")[:2]) < (2, 1), reason="need numpy 2.1.0+") def test_graph_update(init_cuda): mod = compile_conditional_kernels(int) @@ -151,37 +98,4 @@ def build_graph(condition_value): assert arr[1] == 3 assert arr[2] == 3 - # Close the memory resource now because the garbage collected might - # de-allocate it during the next graph builder process b.close() - - -def test_graph_stream_lifetime(init_cuda): - mod = compile_common_kernels() - empty_kernel = mod.get_kernel("empty_kernel") - - # Create simple graph from device - gb = Device().create_graph_builder().begin_building() - launch(gb, LaunchConfig(grid=1, block=1), empty_kernel) - graph = gb.end_building().complete() - - # Destroy simple graph and builder - gb.close() - graph.close() - - # Create simple graph from stream - stream = Device().create_stream() - gb = stream.create_graph_builder().begin_building() - launch(gb, LaunchConfig(grid=1, block=1), empty_kernel) - graph = gb.end_building().complete() - - # Destroy simple graph and builder - gb.close() - graph.close() - - # Verify the stream can still launch work - launch(stream, LaunchConfig(grid=1, block=1), empty_kernel) - stream.sync() - - # Destroy the stream - stream.close() diff --git a/cuda_core/tests/graph/test_explicit.py b/cuda_core/tests/graph/test_graphdef.py similarity index 99% rename from cuda_core/tests/graph/test_explicit.py rename to cuda_core/tests/graph/test_graphdef.py index 33826cb5fd..30a7f05c98 100644 --- a/cuda_core/tests/graph/test_explicit.py +++ b/cuda_core/tests/graph/test_graphdef.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE -"""Tests for explicit CUDA graph construction (GraphDef and GraphNode).""" +"""Tests for GraphDef topology, node types, instantiation, and execution.""" from collections.abc import Callable from dataclasses import dataclass, field diff --git a/cuda_core/tests/graph/test_explicit_errors.py b/cuda_core/tests/graph/test_graphdef_errors.py similarity index 96% rename from cuda_core/tests/graph/test_explicit_errors.py rename to cuda_core/tests/graph/test_graphdef_errors.py index 53e9d52bad..09c3bf8ec4 100644 --- a/cuda_core/tests/graph/test_explicit_errors.py +++ b/cuda_core/tests/graph/test_graphdef_errors.py @@ -1,12 +1,7 @@ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE -"""Tests for error handling, input validation, and edge cases in explicit graphs. - -These tests verify that the explicit graph API properly validates inputs, -raises appropriate exceptions for misuse, and handles boundary conditions -correctly. -""" +"""Tests for GraphDef input validation, error handling, and edge cases.""" import ctypes diff --git a/cuda_core/tests/graph/test_explicit_integration.py b/cuda_core/tests/graph/test_graphdef_integration.py similarity index 93% rename from cuda_core/tests/graph/test_explicit_integration.py rename to cuda_core/tests/graph/test_graphdef_integration.py index 1af975fb44..bb7eab0f8e 100644 --- a/cuda_core/tests/graph/test_explicit_integration.py +++ b/cuda_core/tests/graph/test_graphdef_integration.py @@ -1,31 +1,7 @@ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE -"""Integration tests for explicit CUDA graph construction. 
- -Three test scenarios exercise complementary subsets of node types: - -test_heat_diffusion - 1D heat bar evolving toward steady state via finite differences. - Exercises: AllocNode, FreeNode, MemsetNode, ChildGraphNode, - EmptyNode, EventRecordNode, EventWaitNode, WhileNode, KernelNode, - MemcpyNode, HostCallbackNode. - -test_bisection_root - Find sqrt(2) by bisecting f(x) = x^2 - 2 on [0, 2], with an - optional Newton polish step. - Exercises: IfElseNode (interval halving), IfNode (refinement - guard), WhileNode, KernelNode, AllocNode, MemsetNode, MemcpyNode, - HostCallbackNode, FreeNode, EmptyNode. - -test_switch_dispatch - Apply one of four element-wise transforms selected at graph - creation time via a switch condition. - Exercises: SwitchNode, KernelNode, AllocNode, MemsetNode, - MemcpyNode, FreeNode. - -Together the three tests cover all 14 explicit-graph node types. -""" +"""End-to-end integration tests exercising all GraphDef node types in realistic scenarios.""" import ctypes diff --git a/cuda_core/tests/graph/test_explicit_lifetime.py b/cuda_core/tests/graph/test_graphdef_lifetime.py similarity index 98% rename from cuda_core/tests/graph/test_explicit_lifetime.py rename to cuda_core/tests/graph/test_graphdef_lifetime.py index e087e014d9..1fa1c025c2 100644 --- a/cuda_core/tests/graph/test_explicit_lifetime.py +++ b/cuda_core/tests/graph/test_graphdef_lifetime.py @@ -1,12 +1,7 @@ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE -"""Tests for resource lifetime management in explicit CUDA graphs. - -These tests verify that the RAII mechanism in GraphHandle correctly -prevents dangling references when parent Python objects are deleted -while child/body graph references remain alive. 
-""" +"""Tests for GraphDef resource lifetime management and RAII correctness.""" import gc From d3c4ef4fdd08f89336b4fe8a14f0ee5a9abb7631 Mon Sep 17 00:00:00 2001 From: Andy Jost Date: Tue, 31 Mar 2026 11:24:50 -0700 Subject: [PATCH 2/7] Enhance Graph.update() and add whole-graph update tests - Extend Graph.update() to accept both GraphBuilder and GraphDef sources - Surface CUgraphExecUpdateResultInfo details on update failure instead of a generic CUDA_ERROR_GRAPH_EXEC_UPDATE_FAILURE message - Release the GIL during cuGraphExecUpdate via nogil block - Add parametrized happy-path test covering both GraphBuilder and GraphDef - Add error-case tests: unfinished builder, topology mismatch, wrong type Made-with: Cursor --- cuda_core/cuda/core/_graph/_graph_builder.pyx | 42 +++++-- cuda_core/tests/graph/test_graph_update.py | 119 +++++++++++++++++- 2 files changed, 144 insertions(+), 17 deletions(-) diff --git a/cuda_core/cuda/core/_graph/_graph_builder.pyx b/cuda_core/cuda/core/_graph/_graph_builder.pyx index 3ec3d158eb..e67efb120f 100644 --- a/cuda_core/cuda/core/_graph/_graph_builder.pyx +++ b/cuda_core/cuda/core/_graph/_graph_builder.pyx @@ -5,6 +5,8 @@ import weakref from dataclasses import dataclass +from libc.stdint cimport intptr_t + from cuda.bindings cimport cydriver from cuda.core._graph._utils cimport _attach_host_callback_to_graph @@ -14,6 +16,7 @@ from cuda.core._utils.cuda_utils cimport HANDLE_RETURN from cuda.core._utils.version cimport cy_binding_version, cy_driver_version from cuda.core._utils.cuda_utils import ( + CUDAError, driver, handle_return, ) @@ -783,24 +786,41 @@ class Graph: """ return self._mnff.graph - def update(self, builder: GraphBuilder): - """Update the graph using new build configuration from the builder. + def update(self, source: "GraphBuilder | GraphDef") -> None: + """Update the graph using a new graph definition. - The topology of the provided builder must be identical to this graph. 
+ The topology of the provided source must be identical to this graph. Parameters ---------- - builder : :obj:`~_graph.GraphBuilder` - The builder to update the graph with. + source : :obj:`~_graph.GraphBuilder` or :obj:`~_graph._graphdef.GraphDef` + The graph definition to update from. A GraphBuilder must have + finished building. """ - if not builder._building_ended: - raise ValueError("Graph has not finished building.") + from cuda.core._graph._graphdef import GraphDef + + cdef cydriver.CUgraph cu_graph + cdef cydriver.CUgraphExec cu_exec = <cydriver.CUgraphExec><intptr_t>int(self._mnff.graph) - # Update the graph with the new nodes from the builder - exec_update_result = handle_return(driver.cuGraphExecUpdate(self._mnff.graph, builder._mnff.graph)) - if exec_update_result.result != driver.CUgraphExecUpdateResult.CU_GRAPH_EXEC_UPDATE_SUCCESS: - raise RuntimeError(f"Failed to update graph: {exec_update_result.result()}") + if isinstance(source, GraphBuilder): + if not source._building_ended: + raise ValueError("Graph has not finished building.") + cu_graph = <cydriver.CUgraph><intptr_t>int(source._mnff.graph) + elif isinstance(source, GraphDef): + cu_graph = <cydriver.CUgraph><intptr_t>int(source.handle) + else: + raise TypeError( + f"expected GraphBuilder or GraphDef, got {type(source).__name__}") + + cdef cydriver.CUgraphExecUpdateResultInfo result_info + cdef cydriver.CUresult err + with nogil: + err = cydriver.cuGraphExecUpdate(cu_exec, cu_graph, &result_info) + if err != cydriver.CUresult.CUDA_SUCCESS: + reason = driver.CUgraphExecUpdateResult(result_info.result) + msg = f"Graph update failed: {reason.__doc__.strip()} ({reason.name})" + raise CUDAError(msg) def upload(self, stream: Stream): """Uploads the graph in a stream.
diff --git a/cuda_core/tests/graph/test_graph_update.py b/cuda_core/tests/graph/test_graph_update.py index baa9a50313..80bf7edc53 100644 --- a/cuda_core/tests/graph/test_graph_update.py +++ b/cuda_core/tests/graph/test_graph_update.py @@ -5,17 +5,66 @@ import numpy as np import pytest -from helpers.graph_kernels import compile_conditional_kernels +from helpers.graph_kernels import compile_common_kernels, compile_conditional_kernels from cuda.core import Device, LaunchConfig, LegacyPinnedMemoryResource, launch +from cuda.core._graph._graphdef import GraphDef +from cuda.core._utils.cuda_utils import CUDAError +@pytest.mark.parametrize("builder", ["GraphBuilder", "GraphDef"]) @pytest.mark.skipif(tuple(int(i) for i in np.__version__.split(".")[:2]) < (2, 1), reason="need numpy 2.1.0+") -def test_graph_update(init_cuda): +def test_graph_update_kernel_args(init_cuda, builder): + """Update redirects a kernel to write to a different pointer.""" + mod = compile_common_kernels() + add_one = mod.get_kernel("add_one") + + launch_stream = Device().create_stream() + mr = LegacyPinnedMemoryResource() + b = mr.allocate(8) + arr = np.from_dlpack(b).view(np.int32) + arr[0] = 0 + arr[1] = 0 + + if builder == "GraphBuilder": + + def build(ptr): + gb = Device().create_graph_builder().begin_building() + launch(gb, LaunchConfig(grid=1, block=1), add_one, ptr) + launch(gb, LaunchConfig(grid=1, block=1), add_one, ptr) + finished = gb.end_building() + return finished.complete(), finished + elif builder == "GraphDef": + + def build(ptr): + g = GraphDef() + g.launch(LaunchConfig(grid=1, block=1), add_one, ptr) + g.launch(LaunchConfig(grid=1, block=1), add_one, ptr) + return g.instantiate(), g + + graph, _ = build(arr[0:].ctypes.data) + _, source1 = build(arr[1:].ctypes.data) + + graph.launch(launch_stream) + launch_stream.sync() + assert arr[0] == 2 + assert arr[1] == 0 + + graph.update(source1) + graph.launch(launch_stream) + launch_stream.sync() + assert arr[0] == 2 + assert arr[1] == 2 + + 
b.close() + + +@pytest.mark.skipif(tuple(int(i) for i in np.__version__.split(".")[:2]) < (2, 1), reason="need numpy 2.1.0+") +def test_graph_update_conditional(init_cuda): + """Update swaps conditional switch graphs with matching topology.""" mod = compile_conditional_kernels(int) add_one = mod.get_kernel("add_one") - # Allocate memory launch_stream = Device().create_stream() mr = LegacyPinnedMemoryResource() b = mr.allocate(12) @@ -72,9 +121,6 @@ def build_graph(condition_value): pytest.skip("Driver does not support conditional switch") # Launch the first graph - assert arr[0] == 0 - assert arr[1] == 0 - assert arr[2] == 0 graph = graph_variants[0].complete() graph.launch(launch_stream) launch_stream.sync() @@ -98,4 +144,65 @@ def build_graph(condition_value): assert arr[1] == 3 assert arr[2] == 3 + # Close the memory resource now because the garbage collector might + # de-allocate it during the next graph builder process b.close() + + +# ============================================================================= +# Error cases +# ============================================================================= + + +def test_graph_update_unfinished_builder(init_cuda): + """Update with an unfinished GraphBuilder raises ValueError.""" + mod = compile_common_kernels() + empty_kernel = mod.get_kernel("empty_kernel") + + gb_finished = Device().create_graph_builder().begin_building() + launch(gb_finished, LaunchConfig(grid=1, block=1), empty_kernel) + graph = gb_finished.end_building().complete() + + gb_unfinished = Device().create_graph_builder().begin_building() + launch(gb_unfinished, LaunchConfig(grid=1, block=1), empty_kernel) + + with pytest.raises(ValueError, match="Graph has not finished building"): + graph.update(gb_unfinished) + + gb_unfinished.end_building() + + +def test_graph_update_topology_mismatch(init_cuda): + """Update with a different topology raises CUDAError.""" + mod = compile_common_kernels() + empty_kernel = mod.get_kernel("empty_kernel") + + #
Two-node graph + gb1 = Device().create_graph_builder().begin_building() + launch(gb1, LaunchConfig(grid=1, block=1), empty_kernel) + launch(gb1, LaunchConfig(grid=1, block=1), empty_kernel) + graph = gb1.end_building().complete() + + # Three-node graph (different topology) + gb2 = Device().create_graph_builder().begin_building() + launch(gb2, LaunchConfig(grid=1, block=1), empty_kernel) + launch(gb2, LaunchConfig(grid=1, block=1), empty_kernel) + launch(gb2, LaunchConfig(grid=1, block=1), empty_kernel) + gb2.end_building() + + expected = r"Graph update failed: The update failed because the topology changed \(CU_GRAPH_EXEC_UPDATE_ERROR_TOPOLOGY_CHANGED\)" + with pytest.raises(CUDAError, match=expected): + graph.update(gb2) + + +def test_graph_update_wrong_type(init_cuda): + """Update with an invalid type raises TypeError.""" + mod = compile_common_kernels() + empty_kernel = mod.get_kernel("empty_kernel") + + gb = Device().create_graph_builder().begin_building() + launch(gb, LaunchConfig(grid=1, block=1), empty_kernel) + graph = gb.end_building().complete() + + with pytest.raises(TypeError, match="expected GraphBuilder or GraphDef"): + graph.update("not a graph") From 51cfcb0a3df6dc9ffd7c83a8bef4f1c9303946be Mon Sep 17 00:00:00 2001 From: Andy Jost Date: Tue, 31 Mar 2026 12:16:56 -0700 Subject: [PATCH 3/7] Fix GraphDef test race condition; correct handle property annotations - Chain GraphDef kernel launches sequentially (n.launch instead of g.launch) to avoid concurrent writes to the same memory location - Update GraphDef.handle and GraphNode.handle annotations to reflect that as_py returns driver types (CUgraph, CUgraphNode), not int Made-with: Cursor --- cuda_core/cuda/core/_graph/_graphdef.pyx | 8 ++++---- cuda_core/tests/graph/test_graph_update.py | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/cuda_core/cuda/core/_graph/_graphdef.pyx b/cuda_core/cuda/core/_graph/_graphdef.pyx index e924540281..58074f9c1e 100644 --- 
a/cuda_core/cuda/core/_graph/_graphdef.pyx +++ b/cuda_core/cuda/core/_graph/_graphdef.pyx @@ -529,8 +529,8 @@ cdef class GraphDef: ) @property - def handle(self) -> int: - """Return the underlying CUgraph handle.""" + def handle(self): + """Return the underlying driver CUgraph handle.""" return as_py(self._h_graph) @@ -624,8 +624,8 @@ cdef class GraphNode: return GraphDef._from_handle(graph_node_get_graph(self._h_node)) @property - def handle(self) -> int | None: - """Return the underlying CUgraphNode handle as an int. + def handle(self): + """Return the underlying driver CUgraphNode handle. Returns None for the entry node. """ diff --git a/cuda_core/tests/graph/test_graph_update.py b/cuda_core/tests/graph/test_graph_update.py index 80bf7edc53..caf9ea4304 100644 --- a/cuda_core/tests/graph/test_graph_update.py +++ b/cuda_core/tests/graph/test_graph_update.py @@ -38,8 +38,8 @@ def build(ptr): def build(ptr): g = GraphDef() - g.launch(LaunchConfig(grid=1, block=1), add_one, ptr) - g.launch(LaunchConfig(grid=1, block=1), add_one, ptr) + n = g.launch(LaunchConfig(grid=1, block=1), add_one, ptr) + n.launch(LaunchConfig(grid=1, block=1), add_one, ptr) return g.instantiate(), g graph, _ = build(arr[0:].ctypes.data) From 72821c7882d0eba6376f09fe9ec5e8dd64895d4c Mon Sep 17 00:00:00 2001 From: Andy Jost Date: Tue, 31 Mar 2026 13:55:36 -0700 Subject: [PATCH 4/7] Split _graphdef.pyx into _graph_def/ subpackage for maintainability The monolithic _graphdef.pyx (2000+ lines) is split into three focused modules under _graph_def/: _graph_def.pyx (Condition, GraphAllocOptions, GraphDef), _graph_node.pyx (GraphNode base class and builder methods), and _subclasses.pyx (all concrete node subclasses). Long method bodies in GraphNode are factored into cdef inline GN_* helpers following existing codebase conventions. Handle property annotations updated to use driver.* types consistently. 
Made-with: Cursor --- .../cuda/core/_graph/_graph_def/__init__.pxd | 23 + .../cuda/core/_graph/_graph_def/__init__.py | 51 + .../core/_graph/_graph_def/_graph_def.pxd | 21 + .../core/_graph/_graph_def/_graph_def.pyx | 381 +++ .../core/_graph/_graph_def/_graph_node.pxd | 17 + .../core/_graph/_graph_def/_graph_node.pyx | 980 ++++++++ .../_subclasses.pxd} | 50 +- .../core/_graph/_graph_def/_subclasses.pyx | 755 ++++++ cuda_core/cuda/core/_graph/_graphdef.pyx | 2053 ----------------- cuda_core/tests/graph/test_graphdef.py | 2 +- cuda_core/tests/graph/test_graphdef_errors.py | 2 +- .../tests/graph/test_graphdef_integration.py | 2 +- .../tests/graph/test_graphdef_lifetime.py | 2 +- cuda_core/tests/test_object_protocols.py | 2 +- 14 files changed, 2236 insertions(+), 2105 deletions(-) create mode 100644 cuda_core/cuda/core/_graph/_graph_def/__init__.pxd create mode 100644 cuda_core/cuda/core/_graph/_graph_def/__init__.py create mode 100644 cuda_core/cuda/core/_graph/_graph_def/_graph_def.pxd create mode 100644 cuda_core/cuda/core/_graph/_graph_def/_graph_def.pyx create mode 100644 cuda_core/cuda/core/_graph/_graph_def/_graph_node.pxd create mode 100644 cuda_core/cuda/core/_graph/_graph_def/_graph_node.pyx rename cuda_core/cuda/core/_graph/{_graphdef.pxd => _graph_def/_subclasses.pxd} (80%) create mode 100644 cuda_core/cuda/core/_graph/_graph_def/_subclasses.pyx delete mode 100644 cuda_core/cuda/core/_graph/_graphdef.pyx diff --git a/cuda_core/cuda/core/_graph/_graph_def/__init__.pxd b/cuda_core/cuda/core/_graph/_graph_def/__init__.pxd new file mode 100644 index 0000000000..cfd0367876 --- /dev/null +++ b/cuda_core/cuda/core/_graph/_graph_def/__init__.pxd @@ -0,0 +1,23 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# +# SPDX-License-Identifier: Apache-2.0 + +from cuda.core._graph._graph_def._graph_def cimport Condition, GraphDef +from cuda.core._graph._graph_def._graph_node cimport GraphNode +from cuda.core._graph._graph_def._subclasses cimport ( + AllocNode, + ChildGraphNode, + ConditionalNode, + EmptyNode, + EventRecordNode, + EventWaitNode, + FreeNode, + HostCallbackNode, + IfElseNode, + IfNode, + KernelNode, + MemcpyNode, + MemsetNode, + SwitchNode, + WhileNode, +) diff --git a/cuda_core/cuda/core/_graph/_graph_def/__init__.py b/cuda_core/cuda/core/_graph/_graph_def/__init__.py new file mode 100644 index 0000000000..472cdbde74 --- /dev/null +++ b/cuda_core/cuda/core/_graph/_graph_def/__init__.py @@ -0,0 +1,51 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Explicit CUDA graph construction — GraphDef, GraphNode, and node subclasses.""" + +from cuda.core._graph._graph_def._graph_def import ( + Condition, + GraphAllocOptions, + GraphDef, +) +from cuda.core._graph._graph_def._graph_node import GraphNode +from cuda.core._graph._graph_def._subclasses import ( + AllocNode, + ChildGraphNode, + ConditionalNode, + EmptyNode, + EventRecordNode, + EventWaitNode, + FreeNode, + HostCallbackNode, + IfElseNode, + IfNode, + KernelNode, + MemcpyNode, + MemsetNode, + SwitchNode, + WhileNode, +) + +__all__ = [ + "AllocNode", + "ChildGraphNode", + "Condition", + "ConditionalNode", + "EmptyNode", + "EventRecordNode", + "EventWaitNode", + "FreeNode", + "GraphAllocOptions", + "GraphDef", + "GraphNode", + "HostCallbackNode", + "IfElseNode", + "IfNode", + "KernelNode", + "MemcpyNode", + "MemsetNode", + "SwitchNode", + "WhileNode", +] diff --git a/cuda_core/cuda/core/_graph/_graph_def/_graph_def.pxd b/cuda_core/cuda/core/_graph/_graph_def/_graph_def.pxd new file mode 100644 index 0000000000..19c4f08031 --- /dev/null +++ b/cuda_core/cuda/core/_graph/_graph_def/_graph_def.pxd @@ -0,0 +1,21 @@ 
# ============================================================================
# cuda_core/cuda/core/_graph/_graph_def/_graph_def.pxd
# ============================================================================
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0

from cuda.bindings cimport cydriver
from cuda.core._resource_handles cimport GraphHandle


cdef class Condition:
    cdef:
        cydriver.CUgraphConditionalHandle _c_handle
        object __weakref__


cdef class GraphDef:
    cdef:
        GraphHandle _h_graph
        object __weakref__

    @staticmethod
    cdef GraphDef _from_handle(GraphHandle h_graph)


# ============================================================================
# cuda_core/cuda/core/_graph/_graph_def/_graph_def.pyx
# ============================================================================
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0

"""GraphDef: explicit CUDA graph definition."""

from __future__ import annotations

from libc.stddef cimport size_t

from libcpp.vector cimport vector

from cuda.bindings cimport cydriver

from cuda.core._graph._graph_def._graph_node cimport GraphNode
from cuda.core._resource_handles cimport (
    GraphHandle,
    as_cu,
    as_intptr,
    as_py,
    create_graph_handle,
    create_graph_handle_ref,
    create_graph_node_handle,
)
from cuda.core._utils.cuda_utils cimport HANDLE_RETURN

from dataclasses import dataclass

from cuda.core._utils.cuda_utils import driver, handle_return


cdef class Condition:
    """Wraps a CUgraphConditionalHandle.

    Created by :meth:`GraphDef.create_condition` and passed to
    conditional-node builder methods (``if_cond``, ``if_else``,
    ``while_loop``, ``switch``). The underlying value is set at
    runtime by device code via ``cudaGraphSetConditional``.
    """

    def __repr__(self) -> str:
        # NOTE(review): reconstructed — the original f-string was mangled by
        # angle-bracket stripping in the patch text.
        return f"<Condition handle=0x{self._c_handle:x}>"

    def __eq__(self, other) -> bool:
        if not isinstance(other, Condition):
            return NotImplemented
        # Cast restored: the <Condition> cast was stripped from the patch text.
        return self._c_handle == (<Condition>other)._c_handle

    def __hash__(self) -> int:
        return hash(self._c_handle)

    @property
    def handle(self) -> driver.CUgraphConditionalHandle:
        """The raw CUgraphConditionalHandle as an int."""
        return self._c_handle


@dataclass
class GraphAllocOptions:
    """Options for graph memory allocation nodes.

    Attributes
    ----------
    device : int or Device, optional
        The device on which to allocate memory. If None (default),
        uses the current CUDA context's device.
    memory_type : str, optional
        Type of memory to allocate. One of:

        - ``"device"`` (default): Pinned device memory, optimal for GPU kernels.
        - ``"host"``: Pinned host memory, accessible from both host and device.
          Useful for graphs containing host callback nodes. Note: may not be
          supported on all systems/drivers.
        - ``"managed"``: Managed/unified memory that automatically migrates
          between host and device. Useful for mixed host/device access patterns.

    peer_access : list of int or Device, optional
        List of devices that should have read-write access to the
        allocated memory. If None (default), only the allocating
        device has access.

    Notes
    -----
    - IPC (inter-process communication) is not supported for graph
      memory allocation nodes per CUDA documentation.
    - The allocation uses the device's default memory pool.
    """

    device: int | "Device" | None = None
    memory_type: str = "device"
    peer_access: list | None = None


cdef class GraphDef:
    """Represents a CUDA graph definition (CUgraph).

    A GraphDef is used to construct a graph explicitly by adding nodes
    and specifying dependencies. Once construction is complete, call
    instantiate() to obtain an executable Graph.
    """

    def __init__(self):
        """Create a new empty graph definition."""
        cdef cydriver.CUgraph graph = NULL
        with nogil:
            HANDLE_RETURN(cydriver.cuGraphCreate(&graph, 0))
        self._h_graph = create_graph_handle(graph)

    @staticmethod
    cdef GraphDef _from_handle(GraphHandle h_graph):
        """Create a GraphDef from an existing GraphHandle (internal use)."""
        cdef GraphDef g = GraphDef.__new__(GraphDef)
        g._h_graph = h_graph
        return g

    def __repr__(self) -> str:
        # NOTE(review): reconstructed — the original f-string was mangled by
        # angle-bracket stripping in the patch text.
        return f"<GraphDef handle=0x{as_intptr(self._h_graph):x}>"

    def __eq__(self, other) -> bool:
        if not isinstance(other, GraphDef):
            return NotImplemented
        # Cast restored: the <GraphDef> cast was stripped from the patch text.
        return as_intptr(self._h_graph) == as_intptr((<GraphDef>other)._h_graph)

    def __hash__(self) -> int:
        return hash(as_intptr(self._h_graph))

    @property
    def _entry(self) -> "GraphNode":
        """Return the internal entry-point GraphNode (no dependencies)."""
        cdef GraphNode n = GraphNode.__new__(GraphNode)
        n._h_node = create_graph_node_handle(NULL, self._h_graph)
        return n

    def alloc(self, size_t size, options: GraphAllocOptions | None = None) -> "AllocNode":
        """Add an entry-point memory allocation node (no dependencies).

        See :meth:`GraphNode.alloc` for full documentation.
        """
        return self._entry.alloc(size, options)

    def free(self, dptr) -> "FreeNode":
        """Add an entry-point memory free node (no dependencies).

        See :meth:`GraphNode.free` for full documentation.
        """
        return self._entry.free(dptr)

    def memset(self, dst, value, size_t width, size_t height=1, size_t pitch=0) -> "MemsetNode":
        """Add an entry-point memset node (no dependencies).

        See :meth:`GraphNode.memset` for full documentation.
        """
        return self._entry.memset(dst, value, width, height, pitch)

    def launch(self, config, kernel, *args) -> "KernelNode":
        """Add an entry-point kernel launch node (no dependencies).

        See :meth:`GraphNode.launch` for full documentation.
        """
        return self._entry.launch(config, kernel, *args)

    def join(self, *nodes) -> "EmptyNode":
        """Create an empty node that depends on all given nodes.

        Parameters
        ----------
        *nodes : GraphNode
            Nodes to merge.

        Returns
        -------
        EmptyNode
            A new EmptyNode that depends on all input nodes.
        """
        return self._entry.join(*nodes)

    def memcpy(self, dst, src, size_t size) -> "MemcpyNode":
        """Add an entry-point memcpy node (no dependencies).

        See :meth:`GraphNode.memcpy` for full documentation.
        """
        return self._entry.memcpy(dst, src, size)

    def embed(self, child: GraphDef) -> "ChildGraphNode":
        """Add an entry-point child graph node (no dependencies).

        See :meth:`GraphNode.embed` for full documentation.
        """
        return self._entry.embed(child)

    def record_event(self, event) -> "EventRecordNode":
        """Add an entry-point event record node (no dependencies).

        See :meth:`GraphNode.record_event` for full documentation.
        """
        return self._entry.record_event(event)

    def wait_event(self, event) -> "EventWaitNode":
        """Add an entry-point event wait node (no dependencies).

        See :meth:`GraphNode.wait_event` for full documentation.
        """
        return self._entry.wait_event(event)

    def callback(self, fn, *, user_data=None) -> "HostCallbackNode":
        """Add an entry-point host callback node (no dependencies).

        See :meth:`GraphNode.callback` for full documentation.
        """
        return self._entry.callback(fn, user_data=user_data)

    def create_condition(self, default_value: int | None = None) -> Condition:
        """Create a condition variable for use with conditional nodes.

        The returned :class:`Condition` object is passed to conditional-node
        builder methods. Its value is controlled at runtime by device code
        via ``cudaGraphSetConditional``.

        Parameters
        ----------
        default_value : int, optional
            The default value to assign to the condition.
            If None, no default is assigned.

        Returns
        -------
        Condition
            A condition variable for controlling conditional execution.
        """
        cdef cydriver.CUgraphConditionalHandle c_handle
        cdef unsigned int flags = 0
        cdef unsigned int default_val = 0

        if default_value is not None:
            default_val = default_value
            flags = cydriver.CU_GRAPH_COND_ASSIGN_DEFAULT

        cdef cydriver.CUcontext ctx = NULL
        with nogil:
            HANDLE_RETURN(cydriver.cuCtxGetCurrent(&ctx))
            HANDLE_RETURN(cydriver.cuGraphConditionalHandleCreate(
                &c_handle, as_cu(self._h_graph), ctx, default_val, flags))

        cdef Condition cond = Condition.__new__(Condition)
        cond._c_handle = c_handle
        return cond

    def if_cond(self, condition: Condition) -> "IfNode":
        """Add an entry-point if-conditional node (no dependencies).

        See :meth:`GraphNode.if_cond` for full documentation.
        """
        return self._entry.if_cond(condition)

    def if_else(self, condition: Condition) -> "IfElseNode":
        """Add an entry-point if-else conditional node (no dependencies).

        See :meth:`GraphNode.if_else` for full documentation.
        """
        return self._entry.if_else(condition)

    def while_loop(self, condition: Condition) -> "WhileNode":
        """Add an entry-point while-loop conditional node (no dependencies).

        See :meth:`GraphNode.while_loop` for full documentation.
        """
        return self._entry.while_loop(condition)

    def switch(self, condition: Condition, unsigned int count) -> "SwitchNode":
        """Add an entry-point switch conditional node (no dependencies).

        See :meth:`GraphNode.switch` for full documentation.
        """
        return self._entry.switch(condition, count)

    def instantiate(self, options=None):
        """Instantiate the graph definition into an executable Graph.

        Parameters
        ----------
        options : :obj:`~_graph.GraphCompleteOptions`, optional
            Customizable dataclass for graph instantiation options.

        Returns
        -------
        Graph
            An executable graph that can be launched on a stream.
        """
        # Imported lazily to avoid a circular import with the builder module.
        from cuda.core._graph._graph_builder import _instantiate_graph

        return _instantiate_graph(
            driver.CUgraph(as_intptr(self._h_graph)), options)

    def debug_dot_print(self, path: str, options=None) -> None:
        """Write a GraphViz DOT representation of the graph to a file.

        Parameters
        ----------
        path : str
            File path for the DOT output.
        options : GraphDebugPrintOptions, optional
            Customizable options for the debug print.
        """
        from cuda.core._graph._graph_builder import GraphDebugPrintOptions

        cdef unsigned int flags = 0
        if options is not None:
            if not isinstance(options, GraphDebugPrintOptions):
                raise TypeError("options must be a GraphDebugPrintOptions instance")
            flags = options._to_flags()

        cdef bytes path_bytes = path.encode('utf-8')
        cdef const char* c_path = path_bytes
        with nogil:
            HANDLE_RETURN(cydriver.cuGraphDebugDotPrint(as_cu(self._h_graph), c_path, flags))

    def nodes(self) -> tuple:
        """Return all nodes in the graph.

        Returns
        -------
        tuple of GraphNode
            All nodes in the graph.
        """
        cdef size_t num_nodes = 0

        # First call queries the count; second fills the buffer.
        with nogil:
            HANDLE_RETURN(cydriver.cuGraphGetNodes(as_cu(self._h_graph), NULL, &num_nodes))

        if num_nodes == 0:
            return ()

        cdef vector[cydriver.CUgraphNode] nodes_vec
        nodes_vec.resize(num_nodes)
        with nogil:
            HANDLE_RETURN(cydriver.cuGraphGetNodes(as_cu(self._h_graph), nodes_vec.data(), &num_nodes))

        return tuple(GraphNode._create(self._h_graph, nodes_vec[i]) for i in range(num_nodes))

    def edges(self) -> tuple:
        """Return all edges in the graph as (from_node, to_node) pairs.

        Returns
        -------
        tuple of tuple
            Each element is a (from_node, to_node) pair representing
            a dependency edge in the graph.
        """
        cdef size_t num_edges = 0

        # CUDA 13 added an edge-data out-parameter to cuGraphGetEdges; pass
        # NULL for it since edge annotations are not surfaced here.
        with nogil:
            IF CUDA_CORE_BUILD_MAJOR >= 13:
                HANDLE_RETURN(cydriver.cuGraphGetEdges(as_cu(self._h_graph), NULL, NULL, NULL, &num_edges))
            ELSE:
                HANDLE_RETURN(cydriver.cuGraphGetEdges(as_cu(self._h_graph), NULL, NULL, &num_edges))

        if num_edges == 0:
            return ()

        cdef vector[cydriver.CUgraphNode] from_nodes
        cdef vector[cydriver.CUgraphNode] to_nodes
        from_nodes.resize(num_edges)
        to_nodes.resize(num_edges)
        with nogil:
            IF CUDA_CORE_BUILD_MAJOR >= 13:
                HANDLE_RETURN(cydriver.cuGraphGetEdges(
                    as_cu(self._h_graph), from_nodes.data(), to_nodes.data(), NULL, &num_edges))
            ELSE:
                HANDLE_RETURN(cydriver.cuGraphGetEdges(
                    as_cu(self._h_graph), from_nodes.data(), to_nodes.data(), &num_edges))

        return tuple(
            (GraphNode._create(self._h_graph, from_nodes[i]),
             GraphNode._create(self._h_graph, to_nodes[i]))
            for i in range(num_edges)
        )

    @property
    def handle(self) -> driver.CUgraph:
        """Return the underlying driver CUgraph handle."""
        return as_py(self._h_graph)


# ============================================================================
# cuda_core/cuda/core/_graph/_graph_def/_graph_node.pxd
# ============================================================================
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0

from cuda.bindings cimport cydriver
from cuda.core._resource_handles cimport GraphHandle, GraphNodeHandle


cdef class GraphNode:
    cdef:
        GraphNodeHandle _h_node
        tuple _pred_cache
        tuple _succ_cache
        object __weakref__

    @staticmethod
    cdef GraphNode _create(GraphHandle h_graph, cydriver.CUgraphNode node)


# ============================================================================
# cuda_core/cuda/core/_graph/_graph_def/_graph_node.pyx
# ============================================================================
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0

"""GraphNode base class — factory, properties, and builder methods."""

from __future__ import annotations

from libc.stddef cimport size_t
from libc.stdint cimport uintptr_t
from libc.string cimport memset as c_memset

from libcpp.vector cimport vector

from cuda.bindings cimport cydriver

from cuda.core._event cimport Event
from cuda.core._kernel_arg_handler cimport ParamHolder
from cuda.core._launch_config cimport LaunchConfig
from cuda.core._module cimport Kernel
from cuda.core._graph._graph_def._graph_def cimport Condition, GraphDef
from cuda.core._graph._graph_def._subclasses cimport (
    AllocNode,
    ChildGraphNode,
    ConditionalNode,
    EmptyNode,
    EventRecordNode,
    EventWaitNode,
    FreeNode,
    HostCallbackNode,
    IfElseNode,
    IfNode,
    KernelNode,
    MemcpyNode,
    MemsetNode,
    SwitchNode,
    WhileNode,
)
from cuda.core._resource_handles cimport (
    EventHandle,
    GraphHandle,
    KernelHandle,
    GraphNodeHandle,
    as_cu,
    as_intptr,
    as_py,
    create_event_handle_ref,
    create_graph_handle_ref,
    create_graph_node_handle,
    graph_node_get_graph,
)
from cuda.core._utils.cuda_utils cimport HANDLE_RETURN, _parse_fill_value

from cuda.core._graph._utils cimport (
    _attach_host_callback_to_graph,
    _attach_user_object,
)

from cuda.core import Device
from cuda.core._utils.cuda_utils import driver, handle_return


cdef class GraphNode:
    """Base class for all graph nodes.

    Nodes are created by calling builder methods on GraphDef (for
    entry-point nodes with no dependencies) or on other Nodes (for
    nodes that depend on a predecessor).
    """

    @staticmethod
    cdef GraphNode _create(GraphHandle h_graph, cydriver.CUgraphNode node):
        """Factory: dispatch to the right subclass based on node type."""
        return GN_create(h_graph, node)

    def __repr__(self) -> str:
        cdef cydriver.CUgraphNode node = as_cu(self._h_node)
        # NOTE(review): both return strings reconstructed — the originals were
        # mangled by angle-bracket stripping in the patch text.
        if node == NULL:
            return "<GraphNode (entry)>"
        return f"<GraphNode handle=0x{<uintptr_t>node:x}>"

    def __eq__(self, other) -> bool:
        if not isinstance(other, GraphNode):
            return NotImplemented
        cdef GraphNode o = other
        # Two nodes are equal only if both the node pointer and the owning
        # graph match.
        cdef GraphHandle self_graph = graph_node_get_graph(self._h_node)
        cdef GraphHandle other_graph = graph_node_get_graph(o._h_node)
        return (as_intptr(self._h_node) == as_intptr(o._h_node)
                and as_intptr(self_graph) == as_intptr(other_graph))

    def __hash__(self) -> int:
        cdef GraphHandle g = graph_node_get_graph(self._h_node)
        return hash((as_intptr(self._h_node), as_intptr(g)))

    @property
    def type(self):
        """Return the CUDA graph node type.

        Returns
        -------
        CUgraphNodeType or None
            The node type enum value, or None for the entry node.
        """
        cdef cydriver.CUgraphNode node = as_cu(self._h_node)
        if node == NULL:
            return None
        cdef cydriver.CUgraphNodeType node_type
        with nogil:
            HANDLE_RETURN(cydriver.cuGraphNodeGetType(node, &node_type))
        return driver.CUgraphNodeType(node_type)

    @property
    def graph(self) -> "GraphDef":
        """Return the GraphDef this node belongs to."""
        return GraphDef._from_handle(graph_node_get_graph(self._h_node))

    @property
    def handle(self) -> driver.CUgraphNode:
        """Return the underlying driver CUgraphNode handle.

        Returns None for the entry node.
        """
        return as_py(self._h_node)

    @property
    def pred(self) -> tuple:
        """Return the predecessor nodes (dependencies) of this node.

        Results are cached since a node's dependencies are immutable
        once created.

        Returns
        -------
        tuple of GraphNode
            The nodes that this node depends on.
        """
        return GN_pred(self)

    @property
    def succ(self) -> tuple:
        """Return the successor nodes (dependents) of this node.

        Results are cached and automatically invalidated when new
        dependent nodes are added via builder methods.

        Returns
        -------
        tuple of GraphNode
            The nodes that depend on this node.
        """
        return GN_succ(self)

    def launch(self, config: LaunchConfig, kernel: Kernel, *args) -> KernelNode:
        """Add a kernel launch node depending on this node.

        Parameters
        ----------
        config : LaunchConfig
            Launch configuration (grid, block, shared memory, etc.)
        kernel : Kernel
            The kernel to launch.
        *args
            Kernel arguments.

        Returns
        -------
        KernelNode
            A new KernelNode representing the kernel launch.
        """
        return GN_launch(self, config, kernel, ParamHolder(args))

    def join(self, *nodes: GraphNode) -> EmptyNode:
        """Create an empty node that depends on this node and all given nodes.

        This is used to synchronize multiple branches of execution.

        Parameters
        ----------
        *nodes : GraphNode
            Additional nodes to depend on.

        Returns
        -------
        EmptyNode
            A new EmptyNode that depends on all input nodes.
        """
        return GN_join(self, nodes)

    def alloc(self, size_t size, options=None) -> AllocNode:
        """Add a memory allocation node depending on this node.

        Parameters
        ----------
        size : int
            Number of bytes to allocate.
        options : GraphAllocOptions, optional
            Allocation options. If None, allocates on the current device.

        Returns
        -------
        AllocNode
            A new AllocNode representing the allocation. Access the allocated
            device pointer via the dptr property.
        """
        return GN_alloc(self, size, options)

    def free(self, dptr: int) -> FreeNode:
        """Add a memory free node depending on this node.

        Parameters
        ----------
        dptr : int
            Device pointer to free (typically from AllocNode.dptr).

        Returns
        -------
        FreeNode
            A new FreeNode representing the free operation.
        """
        return GN_free(self, dptr)

    def memset(self, dst: int, value, size_t width, size_t height=1, size_t pitch=0) -> MemsetNode:
        """Add a memset node depending on this node.

        Parameters
        ----------
        dst : int
            Destination device pointer.
        value : int or buffer-protocol object
            Fill value. int for 1-byte fill (range [0, 256)),
            or buffer-protocol object of 1, 2, or 4 bytes.
        width : int
            Width of the row in elements.
        height : int, optional
            Number of rows (default 1).
        pitch : int, optional
            Pitch of destination in bytes (default 0, unused if height is 1).

        Returns
        -------
        MemsetNode
            A new MemsetNode representing the memset operation.
        """
        cdef unsigned int val
        cdef unsigned int elem_size
        val, elem_size = _parse_fill_value(value)
        return GN_memset(self, dst, val, elem_size, width, height, pitch)

    def memcpy(self, dst: int, src: int, size_t size) -> MemcpyNode:
        """Add a memcpy node depending on this node.

        Copies ``size`` bytes from ``src`` to ``dst``. Memory types are
        auto-detected via the driver, so both device and pinned host
        pointers are supported.

        Parameters
        ----------
        dst : int
            Destination pointer (device or pinned host).
        src : int
            Source pointer (device or pinned host).
        size : int
            Number of bytes to copy.

        Returns
        -------
        MemcpyNode
            A new MemcpyNode representing the copy operation.
        """
        return GN_memcpy(self, dst, src, size)

    def embed(self, child: GraphDef) -> ChildGraphNode:
        """Add a child graph node depending on this node.

        Embeds a clone of the given graph definition as a sub-graph node.
        The child graph must not contain allocation, free, or conditional
        nodes.

        Parameters
        ----------
        child : GraphDef
            The graph definition to embed (will be cloned).

        Returns
        -------
        ChildGraphNode
            A new ChildGraphNode representing the embedded sub-graph.
        """
        return GN_embed(self, child)

    def record_event(self, event: Event) -> EventRecordNode:
        """Add an event record node depending on this node.

        Parameters
        ----------
        event : Event
            The event to record.

        Returns
        -------
        EventRecordNode
            A new EventRecordNode representing the event record operation.
        """
        return GN_record_event(self, event)

    def wait_event(self, event: Event) -> EventWaitNode:
        """Add an event wait node depending on this node.

        Parameters
        ----------
        event : Event
            The event to wait for.

        Returns
        -------
        EventWaitNode
            A new EventWaitNode representing the event wait operation.
        """
        return GN_wait_event(self, event)

    def callback(self, fn, *, user_data=None) -> HostCallbackNode:
        """Add a host callback node depending on this node.

        The callback runs on the host CPU when the graph reaches this node.
        Two modes are supported:

        - **Python callable**: Pass any callable. The GIL is acquired
          automatically. The callable must take no arguments; use closures
          or ``functools.partial`` to bind state.
        - **ctypes function pointer**: Pass a ``ctypes.CFUNCTYPE`` instance.
          The function receives a single ``void*`` argument (the
          ``user_data``). The caller must keep the ctypes wrapper alive
          for the lifetime of the graph.

        .. warning::

            Callbacks must not call CUDA API functions. Doing so may
            deadlock or corrupt driver state.

        Parameters
        ----------
        fn : callable or ctypes function pointer
            The callback function.
        user_data : int or bytes-like, optional
            Only for ctypes function pointers. If ``int``, passed as a raw
            pointer (caller manages lifetime). If bytes-like, the data is
            copied and its lifetime is tied to the graph.

        Returns
        -------
        HostCallbackNode
            A new HostCallbackNode representing the callback.
        """
        return GN_callback(self, fn, user_data)

    def if_cond(self, condition: Condition) -> IfNode:
        """Add an if-conditional node depending on this node.

        The body graph executes only when the condition evaluates to
        a non-zero value at runtime.

        Parameters
        ----------
        condition : Condition
            Condition from :meth:`GraphDef.create_condition`.

        Returns
        -------
        IfNode
            A new IfNode with one branch accessible via ``.then``.
        """
        return _make_conditional_node(
            self, condition,
            cydriver.CU_GRAPH_COND_TYPE_IF, 1, IfNode)

    def if_else(self, condition: Condition) -> IfElseNode:
        """Add an if-else conditional node depending on this node.

        Two body graphs: the first executes when the condition is
        non-zero, the second when it is zero.

        Parameters
        ----------
        condition : Condition
            Condition from :meth:`GraphDef.create_condition`.

        Returns
        -------
        IfElseNode
            A new IfElseNode with branches accessible via
            ``.then`` and ``.else_``.
        """
        return _make_conditional_node(
            self, condition,
            cydriver.CU_GRAPH_COND_TYPE_IF, 2, IfElseNode)

    def while_loop(self, condition: Condition) -> WhileNode:
        """Add a while-loop conditional node depending on this node.

        The body graph executes repeatedly while the condition
        evaluates to a non-zero value.

        Parameters
        ----------
        condition : Condition
            Condition from :meth:`GraphDef.create_condition`.

        Returns
        -------
        WhileNode
            A new WhileNode with body accessible via ``.body``.
        """
        return _make_conditional_node(
            self, condition,
            cydriver.CU_GRAPH_COND_TYPE_WHILE, 1, WhileNode)

    def switch(self, condition: Condition, unsigned int count) -> SwitchNode:
        """Add a switch conditional node depending on this node.

        The condition value selects which branch to execute. If the
        value is out of range, no branch executes.

        Parameters
        ----------
        condition : Condition
            Condition from :meth:`GraphDef.create_condition`.
        count : int
            Number of switch cases (branches).

        Returns
        -------
        SwitchNode
            A new SwitchNode with branches accessible via ``.branches``.
        """
        return _make_conditional_node(
            self, condition,
            cydriver.CU_GRAPH_COND_TYPE_SWITCH, count, SwitchNode)


cdef void _destroy_event_handle_copy(void* ptr) noexcept nogil:
    # Cast restored: <EventHandle*> was stripped from the patch text; without
    # it the void* assignment does not compile.
    cdef EventHandle* p = <EventHandle*>ptr
    del p


cdef void _destroy_kernel_handle_copy(void* ptr) noexcept nogil:
    # Cast restored: <KernelHandle*> was stripped from the patch text.
    cdef KernelHandle* p = <KernelHandle*>ptr
    del p


cdef inline ConditionalNode _make_conditional_node(
        GraphNode pred,
        Condition condition,
        cydriver.CUgraphConditionalNodeType cond_type,
        unsigned int size,
        type node_cls):
    if not isinstance(condition, Condition):
        raise TypeError(
            f"condition must be a Condition object (from "
            f"GraphDef.create_condition()), got {type(condition).__name__}")
    cdef cydriver.CUgraphNodeParams params
    cdef cydriver.CUgraphNode new_node = NULL

    c_memset(&params, 0, sizeof(params))
    params.type = cydriver.CU_GRAPH_NODE_TYPE_CONDITIONAL
    params.conditional.handle = condition._c_handle
    params.conditional.type = cond_type
    params.conditional.size = size

    cdef cydriver.CUcontext ctx = NULL
    cdef GraphHandle h_graph = graph_node_get_graph(pred._h_node)
    cdef cydriver.CUgraphNode pred_node = as_cu(pred._h_node)
    cdef cydriver.CUgraphNode* deps = NULL
    cdef size_t num_deps = 0

    # The entry node has a NULL handle and contributes no dependency.
    if pred_node != NULL:
        deps = &pred_node
        num_deps = 1

    with nogil:
        HANDLE_RETURN(cydriver.cuCtxGetCurrent(&ctx))
    params.conditional.ctx = ctx

    # CUDA 13 added an edge-data parameter to cuGraphAddNode.
    with nogil:
        IF CUDA_CORE_BUILD_MAJOR >= 13:
            HANDLE_RETURN(cydriver.cuGraphAddNode(
                &new_node, as_cu(h_graph), deps, NULL, num_deps, &params))
        ELSE:
            HANDLE_RETURN(cydriver.cuGraphAddNode(
                &new_node, as_cu(h_graph), deps, num_deps, &params))

    # Wrap the driver-created body graphs; they are owned by the parent graph,
    # hence the "ref" handle keeping h_graph alive.
    cdef list branch_list = []
    cdef unsigned int i
    cdef cydriver.CUgraph bg
    cdef GraphHandle h_branch
    for i in range(size):
        bg = params.conditional.phGraph_out[i]
        h_branch = create_graph_handle_ref(bg, h_graph)
        branch_list.append(GraphDef._from_handle(h_branch))
    cdef tuple branches = tuple(branch_list)

    # Cast restored: <ConditionalNode> was stripped from the patch text.
    cdef ConditionalNode n = <ConditionalNode>node_cls.__new__(node_cls)
    n._h_node = create_graph_node_handle(new_node, h_graph)
    n._condition = condition
    n._cond_type = cond_type
    n._branches = branches

    # The predecessor gained a successor; drop its cached succ tuple.
    pred._succ_cache = None
    return n


cdef inline GraphNode GN_create(GraphHandle h_graph, cydriver.CUgraphNode node):
    # NULL means the synthetic entry node; it has no driver node type.
    if node == NULL:
        n = GraphNode.__new__(GraphNode)
        # Cast restored: <GraphNode> was stripped from the patch text.
        (<GraphNode>n)._h_node = create_graph_node_handle(node, h_graph)
        return n

    cdef GraphNodeHandle h_node = create_graph_node_handle(node, h_graph)
    cdef cydriver.CUgraphNodeType node_type
    with nogil:
        HANDLE_RETURN(cydriver.cuGraphNodeGetType(node, &node_type))

    if node_type == cydriver.CU_GRAPH_NODE_TYPE_EMPTY:
        return EmptyNode._create_impl(h_node)
    elif node_type == cydriver.CU_GRAPH_NODE_TYPE_KERNEL:
        return KernelNode._create_from_driver(h_node)
    elif node_type == cydriver.CU_GRAPH_NODE_TYPE_MEM_ALLOC:
        return AllocNode._create_from_driver(h_node)
    elif node_type == cydriver.CU_GRAPH_NODE_TYPE_MEM_FREE:
        return FreeNode._create_from_driver(h_node)
    elif node_type == cydriver.CU_GRAPH_NODE_TYPE_MEMSET:
        return MemsetNode._create_from_driver(h_node)
    elif node_type == cydriver.CU_GRAPH_NODE_TYPE_MEMCPY:
        return MemcpyNode._create_from_driver(h_node)
    elif node_type == cydriver.CU_GRAPH_NODE_TYPE_GRAPH:
        return ChildGraphNode._create_from_driver(h_node)
    elif node_type == cydriver.CU_GRAPH_NODE_TYPE_EVENT_RECORD:
        return EventRecordNode._create_from_driver(h_node)
    elif node_type == cydriver.CU_GRAPH_NODE_TYPE_WAIT_EVENT:
        return EventWaitNode._create_from_driver(h_node)
    elif node_type == cydriver.CU_GRAPH_NODE_TYPE_HOST:
        return HostCallbackNode._create_from_driver(h_node)
    elif node_type == cydriver.CU_GRAPH_NODE_TYPE_CONDITIONAL:
        return ConditionalNode._create_from_driver(h_node)
    else:
        # Unknown/unsupported node type: fall back to the plain base class.
        n = GraphNode.__new__(GraphNode)
        (<GraphNode>n)._h_node = h_node
        return n


cdef inline tuple GN_pred(GraphNode self):
    if self._pred_cache is not None:
        return self._pred_cache

    cdef cydriver.CUgraphNode node = as_cu(self._h_node)
    if node == NULL:
        self._pred_cache = ()
        return self._pred_cache

    cdef size_t num_deps = 0
    with nogil:
        IF CUDA_CORE_BUILD_MAJOR >= 13:
            HANDLE_RETURN(cydriver.cuGraphNodeGetDependencies(node, NULL, NULL, &num_deps))
        ELSE:
            HANDLE_RETURN(cydriver.cuGraphNodeGetDependencies(node, NULL, &num_deps))

    if num_deps == 0:
        self._pred_cache = ()
        return self._pred_cache

    cdef vector[cydriver.CUgraphNode] deps
    deps.resize(num_deps)
    with nogil:
        IF CUDA_CORE_BUILD_MAJOR >= 13:
            HANDLE_RETURN(cydriver.cuGraphNodeGetDependencies(node, deps.data(), NULL, &num_deps))
        ELSE:
            HANDLE_RETURN(cydriver.cuGraphNodeGetDependencies(node, deps.data(), &num_deps))

    cdef GraphHandle h_graph = graph_node_get_graph(self._h_node)
    self._pred_cache = tuple(GraphNode._create(h_graph, deps[i]) for i in range(num_deps))
    return self._pred_cache


cdef inline tuple GN_succ(GraphNode self):
    if self._succ_cache is not None:
        return self._succ_cache

    cdef cydriver.CUgraphNode node = as_cu(self._h_node)
    if node == NULL:
        self._succ_cache = ()
        return self._succ_cache

    cdef size_t num_deps = 0
    with nogil:
        IF CUDA_CORE_BUILD_MAJOR >= 13:
            HANDLE_RETURN(cydriver.cuGraphNodeGetDependentNodes(node, NULL, NULL, &num_deps))
        ELSE:
            HANDLE_RETURN(cydriver.cuGraphNodeGetDependentNodes(node, NULL, &num_deps))

    if num_deps == 0:
        self._succ_cache = ()
        return self._succ_cache

    cdef vector[cydriver.CUgraphNode] deps
    deps.resize(num_deps)
    with nogil:
        IF CUDA_CORE_BUILD_MAJOR >= 13:
            HANDLE_RETURN(cydriver.cuGraphNodeGetDependentNodes(node, deps.data(), NULL, &num_deps))
        ELSE:
            HANDLE_RETURN(cydriver.cuGraphNodeGetDependentNodes(node, deps.data(), &num_deps))

    cdef GraphHandle h_graph = graph_node_get_graph(self._h_node)
    self._succ_cache = tuple(GraphNode._create(h_graph, deps[i]) for i in range(num_deps))
    return self._succ_cache


cdef inline KernelNode GN_launch(GraphNode self, LaunchConfig conf, Kernel ker, ParamHolder ker_args):
    cdef cydriver.CUDA_KERNEL_NODE_PARAMS node_params
    cdef cydriver.CUgraphNode new_node = NULL
    cdef GraphHandle h_graph = graph_node_get_graph(self._h_node)
    cdef cydriver.CUgraphNode pred_node = as_cu(self._h_node)
    cdef cydriver.CUgraphNode* deps = NULL
    cdef size_t num_deps = 0

    if pred_node != NULL:
        deps = &pred_node
        num_deps = 1

    node_params.kern = as_cu(ker._h_kernel)
    node_params.func = NULL
    node_params.gridDimX = conf.grid[0]
    node_params.gridDimY = conf.grid[1]
    node_params.gridDimZ = conf.grid[2]
    node_params.blockDimX = conf.block[0]
    node_params.blockDimY = conf.block[1]
    node_params.blockDimZ = conf.block[2]
    node_params.sharedMemBytes = conf.shmem_size
    # Cast restored: <void**> was stripped from the patch text — TODO confirm
    # the exact cast against the original source.
    node_params.kernelParams = <void**><uintptr_t>ker_args.ptr
    node_params.extra = NULL
    node_params.ctx = NULL

    with nogil:
        HANDLE_RETURN(cydriver.cuGraphAddKernelNode(
            &new_node, as_cu(h_graph), deps, num_deps, &node_params))

    # Keep the kernel alive as long as the graph: attach a heap-allocated
    # handle copy as a user object, freed by _destroy_kernel_handle_copy.
    # Cast restored: <void*> was stripped from the patch text.
    _attach_user_object(as_cu(h_graph), <void*>new KernelHandle(ker._h_kernel),
                        _destroy_kernel_handle_copy)

    self._succ_cache = None
    return KernelNode._create_with_params(
        create_graph_node_handle(new_node, h_graph),
        conf.grid, conf.block, conf.shmem_size,
        ker._h_kernel)


cdef inline EmptyNode GN_join(GraphNode self, tuple nodes):
    cdef vector[cydriver.CUgraphNode] deps
    cdef cydriver.CUgraphNode new_node = NULL
    cdef GraphHandle h_graph = graph_node_get_graph(self._h_node)
    cdef GraphNode other
    cdef cydriver.CUgraphNode* deps_ptr = NULL
    cdef size_t num_deps = 0
    cdef cydriver.CUgraphNode pred_node = as_cu(self._h_node)

    # Collect every non-entry node (entry nodes have NULL handles).
    if pred_node != NULL:
        deps.push_back(pred_node)
    for other in nodes:
        if as_cu((<GraphNode>other)._h_node) != NULL:
            deps.push_back(as_cu((<GraphNode>other)._h_node))

    num_deps = deps.size()
    if num_deps > 0:
        deps_ptr = deps.data()

    with nogil:
        HANDLE_RETURN(cydriver.cuGraphAddEmptyNode(
            &new_node, as_cu(h_graph), deps_ptr, num_deps))

    # Every predecessor gained a successor; invalidate their caches.
    self._succ_cache = None
    for other in nodes:
        (<GraphNode>other)._succ_cache = None
    return EmptyNode._create_impl(create_graph_node_handle(new_node, h_graph))


cdef inline AllocNode GN_alloc(GraphNode self, size_t size, object options):
    cdef int device_id
    cdef cydriver.CUdevice dev

    # Default to the current context's device when no option is given.
    if options is None or options.device is None:
        with nogil:
            HANDLE_RETURN(cydriver.cuCtxGetDevice(&dev))
        device_id = dev
    else:
        device_id = getattr(options.device, 'device_id', options.device)

    cdef cydriver.CUDA_MEM_ALLOC_NODE_PARAMS alloc_params
    cdef cydriver.CUgraphNode new_node = NULL
    cdef GraphHandle h_graph = graph_node_get_graph(self._h_node)
    cdef cydriver.CUgraphNode pred_node = as_cu(self._h_node)
    cdef cydriver.CUgraphNode* deps = NULL
    cdef size_t num_deps = 0

    if pred_node != NULL:
        deps = &pred_node
        num_deps = 1

    cdef vector[cydriver.CUmemAccessDesc] access_descs
    cdef int peer_id
    cdef list peer_ids = []

    if options is not None and options.peer_access is not None:
        for peer_dev in options.peer_access:
            peer_id = getattr(peer_dev, 'device_id', peer_dev)
            peer_ids.append(peer_id)
            access_descs.push_back(cydriver.CUmemAccessDesc_st(
                cydriver.CUmemLocation_st(
                    cydriver.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE,
                    peer_id
                ),
                cydriver.CUmemAccess_flags.CU_MEM_ACCESS_FLAGS_PROT_READWRITE
            ))

    cdef str memory_type = "device"
    if options is not None and options.memory_type is not None:
        memory_type = options.memory_type

    c_memset(&alloc_params, 0, sizeof(alloc_params))
    # IPC is unsupported for graph alloc nodes, so no exportable handle type.
    alloc_params.poolProps.handleTypes = cydriver.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_NONE
    alloc_params.bytesize = size

    if memory_type == "device":
        alloc_params.poolProps.allocType = cydriver.CUmemAllocationType.CU_MEM_ALLOCATION_TYPE_PINNED
        alloc_params.poolProps.location.type = cydriver.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE
        alloc_params.poolProps.location.id = device_id
    elif memory_type == "host":
        alloc_params.poolProps.allocType = cydriver.CUmemAllocationType.CU_MEM_ALLOCATION_TYPE_PINNED
        alloc_params.poolProps.location.type = cydriver.CUmemLocationType.CU_MEM_LOCATION_TYPE_HOST
        alloc_params.poolProps.location.id = 0
    elif memory_type == "managed":
        IF CUDA_CORE_BUILD_MAJOR >= 13:
            alloc_params.poolProps.allocType = cydriver.CUmemAllocationType.CU_MEM_ALLOCATION_TYPE_MANAGED
            alloc_params.poolProps.location.type = cydriver.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE
            alloc_params.poolProps.location.id = device_id
        ELSE:
            raise ValueError("memory_type='managed' requires CUDA 13.0 or later")
    else:
        raise ValueError(f"Invalid memory_type: {memory_type!r}. "
                         "Must be 'device', 'host', or 'managed'.")

    if access_descs.size() > 0:
        alloc_params.accessDescs = access_descs.data()
        alloc_params.accessDescCount = access_descs.size()

    with nogil:
        HANDLE_RETURN(cydriver.cuGraphAddMemAllocNode(
            &new_node, as_cu(h_graph), deps, num_deps, &alloc_params))

    self._succ_cache = None
    # alloc_params.dptr is an out-field filled by the driver.
    return AllocNode._create_with_params(
        create_graph_node_handle(new_node, h_graph), alloc_params.dptr, size,
        device_id, memory_type, tuple(peer_ids))


cdef inline FreeNode GN_free(GraphNode self, cydriver.CUdeviceptr c_dptr):
    cdef cydriver.CUgraphNode new_node = NULL
    cdef GraphHandle h_graph = graph_node_get_graph(self._h_node)
    cdef cydriver.CUgraphNode pred_node = as_cu(self._h_node)
    cdef cydriver.CUgraphNode* deps = NULL
    cdef size_t num_deps = 0

    if pred_node != NULL:
        deps = &pred_node
        num_deps = 1

    with nogil:
        HANDLE_RETURN(cydriver.cuGraphAddMemFreeNode(
            &new_node, as_cu(h_graph), deps, num_deps, c_dptr))

    self._succ_cache = None
    return FreeNode._create_with_params(create_graph_node_handle(new_node, h_graph), c_dptr)


cdef inline MemsetNode GN_memset(
        GraphNode self, cydriver.CUdeviceptr c_dst,
        unsigned int val, unsigned int elem_size,
        size_t width, size_t height, size_t pitch):
    cdef cydriver.CUDA_MEMSET_NODE_PARAMS memset_params
    cdef cydriver.CUgraphNode new_node = NULL
    cdef GraphHandle h_graph = graph_node_get_graph(self._h_node)
    cdef cydriver.CUgraphNode pred_node = as_cu(self._h_node)
    cdef cydriver.CUgraphNode* deps = NULL
    cdef size_t num_deps = 0

    if pred_node != NULL:
        deps = &pred_node
        num_deps = 1

    cdef cydriver.CUcontext ctx = NULL
    with nogil:
        HANDLE_RETURN(cydriver.cuCtxGetCurrent(&ctx))

    c_memset(&memset_params, 0, sizeof(memset_params))
    memset_params.dst = c_dst
    memset_params.value = val
    memset_params.elementSize = elem_size
    memset_params.width = width
    memset_params.height = height
    memset_params.pitch = pitch

    with nogil:
        HANDLE_RETURN(cydriver.cuGraphAddMemsetNode(
            &new_node, as_cu(h_graph), deps, num_deps,
            &memset_params, ctx))

    self._succ_cache = None
    return MemsetNode._create_with_params(
        create_graph_node_handle(new_node, h_graph), c_dst,
        val, elem_size, width, height, pitch)


cdef inline MemcpyNode GN_memcpy(
        GraphNode self, cydriver.CUdeviceptr c_dst,
        cydriver.CUdeviceptr c_src, size_t size):
    # Auto-detect memory types; CUDA_ERROR_INVALID_VALUE means a plain host
    # pointer the driver does not track, which is tolerated here.
    cdef unsigned int dst_mem_type = cydriver.CU_MEMORYTYPE_DEVICE
    cdef unsigned int src_mem_type = cydriver.CU_MEMORYTYPE_DEVICE
    cdef cydriver.CUresult ret
    with nogil:
        ret = cydriver.cuPointerGetAttribute(
            &dst_mem_type,
            cydriver.CU_POINTER_ATTRIBUTE_MEMORY_TYPE,
            c_dst)
        if ret != cydriver.CUDA_SUCCESS and ret != cydriver.CUDA_ERROR_INVALID_VALUE:
            HANDLE_RETURN(ret)
        ret = cydriver.cuPointerGetAttribute(
            &src_mem_type,
            cydriver.CU_POINTER_ATTRIBUTE_MEMORY_TYPE,
            c_src)
        if ret != cydriver.CUDA_SUCCESS and ret != cydriver.CUDA_ERROR_INVALID_VALUE:
            HANDLE_RETURN(ret)

    cdef
cydriver.CUmemorytype c_dst_type = dst_mem_type + cdef cydriver.CUmemorytype c_src_type = src_mem_type + + cdef cydriver.CUDA_MEMCPY3D params + c_memset(¶ms, 0, sizeof(params)) + + params.srcMemoryType = c_src_type + params.dstMemoryType = c_dst_type + if c_src_type == cydriver.CU_MEMORYTYPE_HOST: + params.srcHost = c_src + else: + params.srcDevice = c_src + if c_dst_type == cydriver.CU_MEMORYTYPE_HOST: + params.dstHost = c_dst + else: + params.dstDevice = c_dst + params.WidthInBytes = size + params.Height = 1 + params.Depth = 1 + + cdef cydriver.CUgraphNode new_node = NULL + cdef GraphHandle h_graph = graph_node_get_graph(self._h_node) + cdef cydriver.CUgraphNode pred_node = as_cu(self._h_node) + cdef cydriver.CUgraphNode* deps = NULL + cdef size_t num_deps = 0 + + if pred_node != NULL: + deps = &pred_node + num_deps = 1 + + cdef cydriver.CUcontext ctx = NULL + with nogil: + HANDLE_RETURN(cydriver.cuCtxGetCurrent(&ctx)) + HANDLE_RETURN(cydriver.cuGraphAddMemcpyNode( + &new_node, as_cu(h_graph), deps, num_deps, ¶ms, ctx)) + + self._succ_cache = None + return MemcpyNode._create_with_params( + create_graph_node_handle(new_node, h_graph), c_dst, c_src, size, + c_dst_type, c_src_type) + + +cdef inline ChildGraphNode GN_embed(GraphNode self, GraphDef child_def): + cdef cydriver.CUgraphNode new_node = NULL + cdef GraphHandle h_graph = graph_node_get_graph(self._h_node) + cdef cydriver.CUgraphNode pred_node = as_cu(self._h_node) + cdef cydriver.CUgraphNode* deps = NULL + cdef size_t num_deps = 0 + + if pred_node != NULL: + deps = &pred_node + num_deps = 1 + + with nogil: + HANDLE_RETURN(cydriver.cuGraphAddChildGraphNode( + &new_node, as_cu(h_graph), deps, num_deps, as_cu(child_def._h_graph))) + + cdef cydriver.CUgraph embedded_graph = NULL + with nogil: + HANDLE_RETURN(cydriver.cuGraphChildGraphNodeGetGraph( + new_node, &embedded_graph)) + + cdef GraphHandle h_embedded = create_graph_handle_ref(embedded_graph, h_graph) + + self._succ_cache = None + return 
ChildGraphNode._create_with_params( + create_graph_node_handle(new_node, h_graph), h_embedded) + + +cdef inline EventRecordNode GN_record_event(GraphNode self, Event ev): + cdef cydriver.CUgraphNode new_node = NULL + cdef GraphHandle h_graph = graph_node_get_graph(self._h_node) + cdef cydriver.CUgraphNode pred_node = as_cu(self._h_node) + cdef cydriver.CUgraphNode* deps = NULL + cdef size_t num_deps = 0 + + if pred_node != NULL: + deps = &pred_node + num_deps = 1 + + with nogil: + HANDLE_RETURN(cydriver.cuGraphAddEventRecordNode( + &new_node, as_cu(h_graph), deps, num_deps, as_cu(ev._h_event))) + + _attach_user_object(as_cu(h_graph), new EventHandle(ev._h_event), + _destroy_event_handle_copy) + + self._succ_cache = None + return EventRecordNode._create_with_params( + create_graph_node_handle(new_node, h_graph), ev._h_event) + + +cdef inline EventWaitNode GN_wait_event(GraphNode self, Event ev): + cdef cydriver.CUgraphNode new_node = NULL + cdef GraphHandle h_graph = graph_node_get_graph(self._h_node) + cdef cydriver.CUgraphNode pred_node = as_cu(self._h_node) + cdef cydriver.CUgraphNode* deps = NULL + cdef size_t num_deps = 0 + + if pred_node != NULL: + deps = &pred_node + num_deps = 1 + + with nogil: + HANDLE_RETURN(cydriver.cuGraphAddEventWaitNode( + &new_node, as_cu(h_graph), deps, num_deps, as_cu(ev._h_event))) + + _attach_user_object(as_cu(h_graph), new EventHandle(ev._h_event), + _destroy_event_handle_copy) + + self._succ_cache = None + return EventWaitNode._create_with_params( + create_graph_node_handle(new_node, h_graph), ev._h_event) + + +cdef inline HostCallbackNode GN_callback(GraphNode self, object fn, object user_data): + import ctypes as ct + + cdef cydriver.CUDA_HOST_NODE_PARAMS node_params + cdef cydriver.CUgraphNode new_node = NULL + cdef GraphHandle h_graph = graph_node_get_graph(self._h_node) + cdef cydriver.CUgraphNode pred_node = as_cu(self._h_node) + cdef cydriver.CUgraphNode* deps = NULL + cdef size_t num_deps = 0 + + if pred_node != NULL: + 
deps = &pred_node + num_deps = 1 + + _attach_host_callback_to_graph( + as_cu(h_graph), fn, user_data, + &node_params.fn, &node_params.userData) + + with nogil: + HANDLE_RETURN(cydriver.cuGraphAddHostNode( + &new_node, as_cu(h_graph), deps, num_deps, &node_params)) + + cdef object callable_obj = fn if not isinstance(fn, ct._CFuncPtr) else None + self._succ_cache = None + return HostCallbackNode._create_with_params( + create_graph_node_handle(new_node, h_graph), callable_obj, + node_params.fn, node_params.userData) diff --git a/cuda_core/cuda/core/_graph/_graphdef.pxd b/cuda_core/cuda/core/_graph/_graph_def/_subclasses.pxd similarity index 80% rename from cuda_core/cuda/core/_graph/_graphdef.pxd rename to cuda_core/cuda/core/_graph/_graph_def/_subclasses.pxd index 0657115c34..90ca228ec9 100644 --- a/cuda_core/cuda/core/_graph/_graphdef.pxd +++ b/cuda_core/cuda/core/_graph/_graph_def/_subclasses.pxd @@ -5,55 +5,11 @@ from libc.stddef cimport size_t from cuda.bindings cimport cydriver +from cuda.core._graph._graph_def._graph_def cimport Condition +from cuda.core._graph._graph_def._graph_node cimport GraphNode from cuda.core._resource_handles cimport EventHandle, GraphHandle, GraphNodeHandle, KernelHandle -cdef class Condition -cdef class GraphDef -cdef class GraphNode -cdef class EmptyNode(GraphNode) -cdef class KernelNode(GraphNode) -cdef class AllocNode(GraphNode) -cdef class FreeNode(GraphNode) -cdef class MemsetNode(GraphNode) -cdef class MemcpyNode(GraphNode) -cdef class ChildGraphNode(GraphNode) -cdef class EventRecordNode(GraphNode) -cdef class EventWaitNode(GraphNode) -cdef class HostCallbackNode(GraphNode) -cdef class ConditionalNode(GraphNode) -cdef class IfNode(ConditionalNode) -cdef class IfElseNode(ConditionalNode) -cdef class WhileNode(ConditionalNode) -cdef class SwitchNode(ConditionalNode) - - -cdef class Condition: - cdef: - cydriver.CUgraphConditionalHandle _c_handle - object __weakref__ - - -cdef class GraphDef: - cdef: - GraphHandle _h_graph - 
object __weakref__ - - @staticmethod - cdef GraphDef _from_handle(GraphHandle h_graph) - - -cdef class GraphNode: - cdef: - GraphNodeHandle _h_node - tuple _pred_cache - tuple _succ_cache - object __weakref__ - - @staticmethod - cdef GraphNode _create(GraphHandle h_graph, cydriver.CUgraphNode node) - - cdef class EmptyNode(GraphNode): @staticmethod cdef EmptyNode _create_impl(GraphNodeHandle h_node) @@ -196,7 +152,7 @@ cdef class ConditionalNode(GraphNode): cdef: Condition _condition cydriver.CUgraphConditionalNodeType _cond_type - tuple _branches # tuple of GraphDef (non-owning wrappers) + tuple _branches @staticmethod cdef ConditionalNode _create_from_driver(GraphNodeHandle h_node) diff --git a/cuda_core/cuda/core/_graph/_graph_def/_subclasses.pyx b/cuda_core/cuda/core/_graph/_graph_def/_subclasses.pyx new file mode 100644 index 0000000000..2c78b3b0ac --- /dev/null +++ b/cuda_core/cuda/core/_graph/_graph_def/_subclasses.pyx @@ -0,0 +1,755 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# +# SPDX-License-Identifier: Apache-2.0 + +"""GraphNode subclasses — EmptyNode through SwitchNode.""" + +from __future__ import annotations + +from libc.stddef cimport size_t +from libc.stdint cimport uintptr_t + +from cuda.bindings cimport cydriver + +from cuda.core._event cimport Event +from cuda.core._launch_config cimport LaunchConfig +from cuda.core._module cimport Kernel +from cuda.core._graph._graph_def._graph_def cimport Condition, GraphDef +from cuda.core._graph._graph_def._graph_node cimport GraphNode +from cuda.core._resource_handles cimport ( + EventHandle, + GraphHandle, + KernelHandle, + GraphNodeHandle, + as_cu, + as_intptr, + create_event_handle_ref, + create_graph_handle_ref, + create_kernel_handle_ref, + create_graph_node_handle, + graph_node_get_graph, +) +from cuda.core._utils.cuda_utils cimport HANDLE_RETURN + +from cuda.core._graph._utils cimport _is_py_host_trampoline + +from cuda.core._utils.cuda_utils import driver, handle_return + + +cdef bint _has_cuGraphNodeGetParams = False +cdef bint _version_checked = False + +cdef bint _check_node_get_params(): + global _has_cuGraphNodeGetParams, _version_checked + if not _version_checked: + from cuda.core._utils.version import driver_version + _has_cuGraphNodeGetParams = driver_version() >= (13, 2, 0) + _version_checked = True + return _has_cuGraphNodeGetParams + + +cdef class EmptyNode(GraphNode): + """A synchronization / join node with no operation.""" + + @staticmethod + cdef EmptyNode _create_impl(GraphNodeHandle h_node): + cdef EmptyNode n = EmptyNode.__new__(EmptyNode) + n._h_node = h_node + return n + + def __repr__(self) -> str: + cdef Py_ssize_t n = len(self.pred) + return f"" + + +cdef class KernelNode(GraphNode): + """A kernel launch node. + + Properties + ---------- + grid : tuple of int + Grid dimensions (gridDimX, gridDimY, gridDimZ). + block : tuple of int + Block dimensions (blockDimX, blockDimY, blockDimZ). + shmem_size : int + Dynamic shared memory size in bytes. 
+ kernel : Kernel + The kernel object for this launch node. + config : LaunchConfig + A LaunchConfig reconstructed from this node's parameters. + """ + + @staticmethod + cdef KernelNode _create_with_params(GraphNodeHandle h_node, + tuple grid, tuple block, unsigned int shmem_size, + KernelHandle h_kernel): + """Create from known params (called by launch() builder).""" + cdef KernelNode n = KernelNode.__new__(KernelNode) + n._h_node = h_node + n._grid = grid + n._block = block + n._shmem_size = shmem_size + n._h_kernel = h_kernel + return n + + @staticmethod + cdef KernelNode _create_from_driver(GraphNodeHandle h_node): + """Create by fetching params from the driver (called by _create factory).""" + cdef cydriver.CUgraphNode node = as_cu(h_node) + cdef cydriver.CUDA_KERNEL_NODE_PARAMS params + with nogil: + HANDLE_RETURN(cydriver.cuGraphKernelNodeGetParams(node, ¶ms)) + cdef KernelHandle h_kernel = create_kernel_handle_ref(params.kern) + return KernelNode._create_with_params( + h_node, + (params.gridDimX, params.gridDimY, params.gridDimZ), + (params.blockDimX, params.blockDimY, params.blockDimZ), + params.sharedMemBytes, + h_kernel) + + def __repr__(self) -> str: + return (f"") + + @property + def grid(self) -> tuple: + """Grid dimensions as a 3-tuple (gridDimX, gridDimY, gridDimZ).""" + return self._grid + + @property + def block(self) -> tuple: + """Block dimensions as a 3-tuple (blockDimX, blockDimY, blockDimZ).""" + return self._block + + @property + def shmem_size(self) -> int: + """Dynamic shared memory size in bytes.""" + return self._shmem_size + + @property + def kernel(self) -> Kernel: + """The Kernel object for this launch node.""" + return Kernel._from_handle(self._h_kernel) + + @property + def config(self) -> LaunchConfig: + """A LaunchConfig reconstructed from this node's grid, block, and shmem_size. + + Note: cluster dimensions and cooperative_launch are not preserved + by the CUDA driver's kernel node params, so they are not included. 
+ """ + return LaunchConfig(grid=self._grid, block=self._block, + shmem_size=self._shmem_size) + + +cdef class AllocNode(GraphNode): + """A memory allocation node. + + Properties + ---------- + dptr : int + The device pointer for the allocation. + bytesize : int + The number of bytes allocated. + device_id : int + The device on which the allocation was made. + memory_type : str + The type of memory allocated (``"device"``, ``"host"``, or ``"managed"``). + peer_access : tuple of int + Device IDs that have read-write access to this allocation. + options : GraphAllocOptions + A GraphAllocOptions reconstructed from this node's parameters. + """ + + @staticmethod + cdef AllocNode _create_with_params(GraphNodeHandle h_node, + cydriver.CUdeviceptr dptr, size_t bytesize, + int device_id, str memory_type, tuple peer_access): + """Create from known params (called by alloc() builder).""" + cdef AllocNode n = AllocNode.__new__(AllocNode) + n._h_node = h_node + n._dptr = dptr + n._bytesize = bytesize + n._device_id = device_id + n._memory_type = memory_type + n._peer_access = peer_access + return n + + @staticmethod + cdef AllocNode _create_from_driver(GraphNodeHandle h_node): + """Create by fetching params from the driver (called by _create factory).""" + cdef cydriver.CUgraphNode node = as_cu(h_node) + cdef cydriver.CUDA_MEM_ALLOC_NODE_PARAMS params + with nogil: + HANDLE_RETURN(cydriver.cuGraphMemAllocNodeGetParams(node, ¶ms)) + + cdef str memory_type + if params.poolProps.allocType == cydriver.CUmemAllocationType.CU_MEM_ALLOCATION_TYPE_PINNED: + if params.poolProps.location.type == cydriver.CUmemLocationType.CU_MEM_LOCATION_TYPE_HOST: + memory_type = "host" + else: + memory_type = "device" + else: + IF CUDA_CORE_BUILD_MAJOR >= 13: + if params.poolProps.allocType == cydriver.CUmemAllocationType.CU_MEM_ALLOCATION_TYPE_MANAGED: + memory_type = "managed" + else: + memory_type = "device" + ELSE: + memory_type = "device" + + cdef list peer_ids = [] + cdef size_t i + for i in 
range(params.accessDescCount): + peer_ids.append(params.accessDescs[i].location.id) + + return AllocNode._create_with_params( + h_node, params.dptr, params.bytesize, + params.poolProps.location.id, memory_type, tuple(peer_ids)) + + def __repr__(self) -> str: + return f"" + + @property + def dptr(self) -> int: + """The device pointer for the allocation.""" + return self._dptr + + @property + def bytesize(self) -> int: + """The number of bytes allocated.""" + return self._bytesize + + @property + def device_id(self) -> int: + """The device on which the allocation was made.""" + return self._device_id + + @property + def memory_type(self) -> str: + """The type of memory: ``"device"``, ``"host"``, or ``"managed"``.""" + return self._memory_type + + @property + def peer_access(self) -> tuple: + """Device IDs with read-write access to this allocation.""" + return self._peer_access + + @property + def options(self): + """A GraphAllocOptions reconstructed from this node's parameters.""" + from cuda.core._graph._graph_def._graph_def import GraphAllocOptions + return GraphAllocOptions( + device=self._device_id, + memory_type=self._memory_type, + peer_access=list(self._peer_access) if self._peer_access else None, + ) + + +cdef class FreeNode(GraphNode): + """A memory free node. + + Properties + ---------- + dptr : int + The device pointer being freed. 
+ """ + + @staticmethod + cdef FreeNode _create_with_params(GraphNodeHandle h_node, + cydriver.CUdeviceptr dptr): + """Create from known params (called by free() builder).""" + cdef FreeNode n = FreeNode.__new__(FreeNode) + n._h_node = h_node + n._dptr = dptr + return n + + @staticmethod + cdef FreeNode _create_from_driver(GraphNodeHandle h_node): + """Create by fetching params from the driver (called by _create factory).""" + cdef cydriver.CUgraphNode node = as_cu(h_node) + cdef cydriver.CUdeviceptr dptr + with nogil: + HANDLE_RETURN(cydriver.cuGraphMemFreeNodeGetParams(node, &dptr)) + return FreeNode._create_with_params(h_node, dptr) + + def __repr__(self) -> str: + return f"" + + @property + def dptr(self) -> int: + """The device pointer being freed.""" + return self._dptr + + +cdef class MemsetNode(GraphNode): + """A memory set node. + + Properties + ---------- + dptr : int + The destination device pointer. + value : int + The fill value. + element_size : int + Element size in bytes (1, 2, or 4). + width : int + Width of the row in elements. + height : int + Number of rows. + pitch : int + Pitch in bytes (unused if height is 1). 
+ """ + + @staticmethod + cdef MemsetNode _create_with_params(GraphNodeHandle h_node, + cydriver.CUdeviceptr dptr, unsigned int value, + unsigned int element_size, size_t width, + size_t height, size_t pitch): + """Create from known params (called by memset() builder).""" + cdef MemsetNode n = MemsetNode.__new__(MemsetNode) + n._h_node = h_node + n._dptr = dptr + n._value = value + n._element_size = element_size + n._width = width + n._height = height + n._pitch = pitch + return n + + @staticmethod + cdef MemsetNode _create_from_driver(GraphNodeHandle h_node): + """Create by fetching params from the driver (called by _create factory).""" + cdef cydriver.CUgraphNode node = as_cu(h_node) + cdef cydriver.CUDA_MEMSET_NODE_PARAMS params + with nogil: + HANDLE_RETURN(cydriver.cuGraphMemsetNodeGetParams(node, ¶ms)) + return MemsetNode._create_with_params( + h_node, params.dst, params.value, + params.elementSize, params.width, params.height, params.pitch) + + def __repr__(self) -> str: + return (f"") + + @property + def dptr(self) -> int: + """The destination device pointer.""" + return self._dptr + + @property + def value(self) -> int: + """The fill value.""" + return self._value + + @property + def element_size(self) -> int: + """Element size in bytes (1, 2, or 4).""" + return self._element_size + + @property + def width(self) -> int: + """Width of the row in elements.""" + return self._width + + @property + def height(self) -> int: + """Number of rows.""" + return self._height + + @property + def pitch(self) -> int: + """Pitch in bytes (unused if height is 1).""" + return self._pitch + + +cdef class MemcpyNode(GraphNode): + """A memory copy node. + + Properties + ---------- + dst : int + The destination pointer. + src : int + The source pointer. + size : int + The number of bytes copied. 
+ """ + + @staticmethod + cdef MemcpyNode _create_with_params(GraphNodeHandle h_node, + cydriver.CUdeviceptr dst, cydriver.CUdeviceptr src, + size_t size, cydriver.CUmemorytype dst_type, + cydriver.CUmemorytype src_type): + """Create from known params (called by memcpy() builder).""" + cdef MemcpyNode n = MemcpyNode.__new__(MemcpyNode) + n._h_node = h_node + n._dst = dst + n._src = src + n._size = size + n._dst_type = dst_type + n._src_type = src_type + return n + + @staticmethod + cdef MemcpyNode _create_from_driver(GraphNodeHandle h_node): + """Create by fetching params from the driver (called by _create factory).""" + cdef cydriver.CUgraphNode node = as_cu(h_node) + cdef cydriver.CUDA_MEMCPY3D params + with nogil: + HANDLE_RETURN(cydriver.cuGraphMemcpyNodeGetParams(node, ¶ms)) + + cdef cydriver.CUdeviceptr dst + cdef cydriver.CUdeviceptr src + if params.dstMemoryType == cydriver.CU_MEMORYTYPE_HOST: + dst = params.dstHost + else: + dst = params.dstDevice + if params.srcMemoryType == cydriver.CU_MEMORYTYPE_HOST: + src = params.srcHost + else: + src = params.srcDevice + + return MemcpyNode._create_with_params( + h_node, dst, src, params.WidthInBytes, + params.dstMemoryType, params.srcMemoryType) + + def __repr__(self) -> str: + cdef str dt = "H" if self._dst_type == cydriver.CU_MEMORYTYPE_HOST else "D" + cdef str st = "H" if self._src_type == cydriver.CU_MEMORYTYPE_HOST else "D" + return (f"") + + @property + def dst(self) -> int: + """The destination pointer.""" + return self._dst + + @property + def src(self) -> int: + """The source pointer.""" + return self._src + + @property + def size(self) -> int: + """The number of bytes copied.""" + return self._size + + +cdef class ChildGraphNode(GraphNode): + """A child graph (sub-graph) node. + + Properties + ---------- + child_graph : GraphDef + The embedded graph definition (non-owning wrapper). 
+ """ + + @staticmethod + cdef ChildGraphNode _create_with_params(GraphNodeHandle h_node, + GraphHandle h_child_graph): + """Create from known params (called by embed() builder).""" + cdef ChildGraphNode n = ChildGraphNode.__new__(ChildGraphNode) + n._h_node = h_node + n._h_child_graph = h_child_graph + return n + + @staticmethod + cdef ChildGraphNode _create_from_driver(GraphNodeHandle h_node): + """Create by fetching params from the driver (called by _create factory).""" + cdef cydriver.CUgraphNode node = as_cu(h_node) + cdef cydriver.CUgraph child_graph = NULL + with nogil: + HANDLE_RETURN(cydriver.cuGraphChildGraphNodeGetGraph(node, &child_graph)) + cdef GraphHandle h_graph = graph_node_get_graph(h_node) + cdef GraphHandle h_child = create_graph_handle_ref(child_graph, h_graph) + return ChildGraphNode._create_with_params(h_node, h_child) + + def __repr__(self) -> str: + cdef cydriver.CUgraph g = as_cu(self._h_child_graph) + cdef size_t num_nodes = 0 + with nogil: + HANDLE_RETURN(cydriver.cuGraphGetNodes(g, NULL, &num_nodes)) + cdef Py_ssize_t n = num_nodes + return f"" + + @property + def child_graph(self) -> "GraphDef": + """The embedded graph definition (non-owning wrapper).""" + return GraphDef._from_handle(self._h_child_graph) + + +cdef class EventRecordNode(GraphNode): + """An event record node. + + Properties + ---------- + event : Event + The event being recorded. 
+ """ + + @staticmethod + cdef EventRecordNode _create_with_params(GraphNodeHandle h_node, + EventHandle h_event): + """Create from known params (called by record_event() builder).""" + cdef EventRecordNode n = EventRecordNode.__new__(EventRecordNode) + n._h_node = h_node + n._h_event = h_event + return n + + @staticmethod + cdef EventRecordNode _create_from_driver(GraphNodeHandle h_node): + """Create by fetching params from the driver (called by _create factory).""" + cdef cydriver.CUgraphNode node = as_cu(h_node) + cdef cydriver.CUevent event + with nogil: + HANDLE_RETURN(cydriver.cuGraphEventRecordNodeGetEvent(node, &event)) + cdef EventHandle h_event = create_event_handle_ref(event) + return EventRecordNode._create_with_params(h_node, h_event) + + def __repr__(self) -> str: + return f"" + + @property + def event(self) -> Event: + """The event being recorded.""" + return Event._from_handle(self._h_event) + + +cdef class EventWaitNode(GraphNode): + """An event wait node. + + Properties + ---------- + event : Event + The event being waited on. 
+ """ + + @staticmethod + cdef EventWaitNode _create_with_params(GraphNodeHandle h_node, + EventHandle h_event): + """Create from known params (called by wait_event() builder).""" + cdef EventWaitNode n = EventWaitNode.__new__(EventWaitNode) + n._h_node = h_node + n._h_event = h_event + return n + + @staticmethod + cdef EventWaitNode _create_from_driver(GraphNodeHandle h_node): + """Create by fetching params from the driver (called by _create factory).""" + cdef cydriver.CUgraphNode node = as_cu(h_node) + cdef cydriver.CUevent event + with nogil: + HANDLE_RETURN(cydriver.cuGraphEventWaitNodeGetEvent(node, &event)) + cdef EventHandle h_event = create_event_handle_ref(event) + return EventWaitNode._create_with_params(h_node, h_event) + + def __repr__(self) -> str: + return f"" + + @property + def event(self) -> Event: + """The event being waited on.""" + return Event._from_handle(self._h_event) + + +cdef class HostCallbackNode(GraphNode): + """A host callback node. + + Properties + ---------- + callback_fn : callable or None + The Python callable (None for ctypes function pointer callbacks). 
+ """ + + @staticmethod + cdef HostCallbackNode _create_with_params(GraphNodeHandle h_node, + object callable_obj, cydriver.CUhostFn fn, + void* user_data): + """Create from known params (called by callback() builder).""" + cdef HostCallbackNode n = HostCallbackNode.__new__(HostCallbackNode) + n._h_node = h_node + n._callable = callable_obj + n._fn = fn + n._user_data = user_data + return n + + @staticmethod + cdef HostCallbackNode _create_from_driver(GraphNodeHandle h_node): + """Create by fetching params from the driver (called by _create factory).""" + cdef cydriver.CUgraphNode node = as_cu(h_node) + cdef cydriver.CUDA_HOST_NODE_PARAMS params + with nogil: + HANDLE_RETURN(cydriver.cuGraphHostNodeGetParams(node, ¶ms)) + + cdef object callable_obj = None + if _is_py_host_trampoline(params.fn): + callable_obj = params.userData + + return HostCallbackNode._create_with_params( + h_node, callable_obj, params.fn, params.userData) + + def __repr__(self) -> str: + if self._callable is not None: + name = getattr(self._callable, '__name__', '?') + return f"" + return f"self._fn:x}>" + + @property + def callback_fn(self): + """The Python callable, or None for ctypes function pointer callbacks.""" + return self._callable + + +cdef class ConditionalNode(GraphNode): + """Base class for conditional graph nodes. + + When created via builder methods (if_cond, if_else, while_loop, switch), + a specific subclass (IfNode, IfElseNode, WhileNode, SwitchNode) is + returned. When reconstructed from the driver on CUDA 13.2+, the + correct subclass is determined via cuGraphNodeGetParams. On older + drivers, this base class is used as a fallback. + + Properties + ---------- + condition : Condition or None + The condition variable controlling execution (None pre-13.2). + cond_type : str or None + The conditional type ("if", "while", or "switch"; None pre-13.2). + branches : tuple of GraphDef + The body graphs for each branch (empty pre-13.2). 
+ """ + + @staticmethod + cdef ConditionalNode _create_from_driver(GraphNodeHandle h_node): + cdef ConditionalNode n + if not _check_node_get_params(): + n = ConditionalNode.__new__(ConditionalNode) + n._h_node = h_node + n._condition = None + n._cond_type = cydriver.CU_GRAPH_COND_TYPE_IF + n._branches = () + return n + + cdef cydriver.CUgraphNode node = as_cu(h_node) + params = handle_return(driver.cuGraphNodeGetParams( + node)) + cond_params = params.conditional + cdef int cond_type_int = int(cond_params.type) + cdef unsigned int size = int(cond_params.size) + + cdef Condition condition = Condition.__new__(Condition) + condition._c_handle = ( + int(cond_params.handle)) + + cdef GraphHandle h_graph = graph_node_get_graph(h_node) + cdef list branch_list = [] + cdef unsigned int i + cdef GraphHandle h_branch + if cond_params.phGraph_out is not None: + for i in range(size): + h_branch = create_graph_handle_ref( + int(cond_params.phGraph_out[i]), + h_graph) + branch_list.append(GraphDef._from_handle(h_branch)) + cdef tuple branches = tuple(branch_list) + + cdef type cls + if cond_type_int == cydriver.CU_GRAPH_COND_TYPE_IF: + if size == 1: + cls = IfNode + else: + cls = IfElseNode + elif cond_type_int == cydriver.CU_GRAPH_COND_TYPE_WHILE: + cls = WhileNode + else: + cls = SwitchNode + + n = cls.__new__(cls) + n._h_node = h_node + n._condition = condition + n._cond_type = cond_type_int + n._branches = branches + return n + + def __repr__(self) -> str: + return "" + + @property + def condition(self) -> Condition | None: + """The condition variable controlling execution.""" + return self._condition + + @property + def cond_type(self) -> str | None: + """The conditional type as a string: 'if', 'while', or 'switch'. + + Returns None when reconstructed from the driver pre-CUDA 13.2, + as the conditional type cannot be determined. 
+ """ + if self._condition is None: + return None + if self._cond_type == cydriver.CU_GRAPH_COND_TYPE_IF: + return "if" + elif self._cond_type == cydriver.CU_GRAPH_COND_TYPE_WHILE: + return "while" + else: + return "switch" + + @property + def branches(self) -> tuple: + """The body graphs for each branch as a tuple of GraphDef. + + Returns an empty tuple when reconstructed from the driver + pre-CUDA 13.2. + """ + return self._branches + + +cdef class IfNode(ConditionalNode): + """An if-conditional node (1 branch, executes when condition is non-zero).""" + + def __repr__(self) -> str: + return f"self._condition._c_handle:x}>" + + @property + def then(self) -> "GraphDef": + """The 'then' branch graph.""" + return self._branches[0] + + +cdef class IfElseNode(ConditionalNode): + """An if-else conditional node (2 branches).""" + + def __repr__(self) -> str: + return f"self._condition._c_handle:x}>" + + @property + def then(self) -> "GraphDef": + """The 'then' branch graph (executed when condition is non-zero).""" + return self._branches[0] + + @property + def else_(self) -> "GraphDef": + """The 'else' branch graph (executed when condition is zero).""" + return self._branches[1] + + +cdef class WhileNode(ConditionalNode): + """A while-loop conditional node (1 branch, repeats while condition is non-zero).""" + + def __repr__(self) -> str: + return f"self._condition._c_handle:x}>" + + @property + def body(self) -> "GraphDef": + """The loop body graph.""" + return self._branches[0] + + +cdef class SwitchNode(ConditionalNode): + """A switch conditional node (N branches, selected by condition value).""" + + def __repr__(self) -> str: + cdef Py_ssize_t n = len(self._branches) + return (f"self._condition._c_handle:x}" + f" with {n} {'branch' if n == 1 else 'branches'}>") diff --git a/cuda_core/cuda/core/_graph/_graphdef.pyx b/cuda_core/cuda/core/_graph/_graphdef.pyx deleted file mode 100644 index 58074f9c1e..0000000000 --- a/cuda_core/cuda/core/_graph/_graphdef.pyx +++ 
/dev/null @@ -1,2053 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# SPDX-License-Identifier: Apache-2.0 - -""" -Private module for explicit CUDA graph construction. - -This module provides GraphDef and a GraphNode class hierarchy for building CUDA -graphs explicitly (as opposed to stream capture). Both approaches produce -the same public Graph type for execution. - -GraphNode hierarchy: - GraphNode (base — also used for the internal entry point) - ├── EmptyNode (synchronization / join point) - ├── KernelNode (kernel launch) - ├── AllocNode (memory allocation, exposes dptr and bytesize) - ├── FreeNode (memory free, exposes dptr) - ├── MemsetNode (memory set, exposes dptr, value, element_size, etc.) - ├── MemcpyNode (memory copy, exposes dst, src, size) - ├── ChildGraphNode (embedded sub-graph) - ├── EventRecordNode (record an event) - ├── EventWaitNode (wait for an event) - ├── HostCallbackNode (host CPU callback) - └── ConditionalNode (conditional execution — base for reconstruction) - ├── IfNode (if-then conditional, 1 branch) - ├── IfElseNode (if-then-else conditional, 2 branches) - ├── WhileNode (while-loop conditional, 1 branch) - └── SwitchNode (switch conditional, N branches) -""" - -from __future__ import annotations - -from libc.stddef cimport size_t -from libc.stdint cimport uintptr_t -from libc.stdlib cimport malloc, free -from libc.string cimport memset as c_memset, memcpy as c_memcpy - -from libcpp.vector cimport vector - -from cuda.bindings cimport cydriver - -from cuda.core._event cimport Event -from cuda.core._kernel_arg_handler cimport ParamHolder -from cuda.core._launch_config cimport LaunchConfig -from cuda.core._module cimport Kernel -from cuda.core._resource_handles cimport ( - EventHandle, - GraphHandle, - KernelHandle, - GraphNodeHandle, - as_cu, - as_intptr, - as_py, - create_event_handle_ref, - create_graph_handle, - create_graph_handle_ref, - create_kernel_handle_ref, - 
create_graph_node_handle, - graph_node_get_graph, -) -from cuda.core._utils.cuda_utils cimport HANDLE_RETURN, _parse_fill_value - -from dataclasses import dataclass - -from cuda.core import Device -from cuda.core._utils.cuda_utils import driver, handle_return - -__all__ = [ - "Condition", - "GraphAllocOptions", - "GraphDef", - "GraphNode", - "EmptyNode", - "KernelNode", - "AllocNode", - "FreeNode", - "MemsetNode", - "MemcpyNode", - "ChildGraphNode", - "EventRecordNode", - "EventWaitNode", - "HostCallbackNode", - "ConditionalNode", - "IfNode", - "IfElseNode", - "WhileNode", - "SwitchNode", -] - - -cdef bint _has_cuGraphNodeGetParams = False -cdef bint _version_checked = False - -cdef bint _check_node_get_params(): - global _has_cuGraphNodeGetParams, _version_checked - if not _version_checked: - from cuda.core._utils.version import driver_version - _has_cuGraphNodeGetParams = driver_version() >= (13, 2, 0) - _version_checked = True - return _has_cuGraphNodeGetParams - - -from cuda.core._graph._utils cimport ( - _attach_host_callback_to_graph, - _attach_user_object, - _is_py_host_trampoline, -) - - -cdef void _destroy_event_handle_copy(void* ptr) noexcept nogil: - cdef EventHandle* p = ptr - del p - - -cdef void _destroy_kernel_handle_copy(void* ptr) noexcept nogil: - cdef KernelHandle* p = ptr - del p - - - - -cdef class Condition: - """Wraps a CUgraphConditionalHandle. - - Created by :meth:`GraphDef.create_condition` and passed to - conditional-node builder methods (``if_cond``, ``if_else``, - ``while_loop``, ``switch``). The underlying value is set at - runtime by device code via ``cudaGraphSetConditional``. 
- """ - - def __repr__(self) -> str: - return f"self._c_handle:x}>" - - def __eq__(self, other) -> bool: - if not isinstance(other, Condition): - return NotImplemented - return self._c_handle == (other)._c_handle - - def __hash__(self) -> int: - return hash(self._c_handle) - - @property - def handle(self) -> int: - """The raw CUgraphConditionalHandle as an int.""" - return self._c_handle - - -cdef ConditionalNode _make_conditional_node( - GraphNode pred, - Condition condition, - cydriver.CUgraphConditionalNodeType cond_type, - unsigned int size, - type node_cls): - if not isinstance(condition, Condition): - raise TypeError( - f"condition must be a Condition object (from " - f"GraphDef.create_condition()), got {type(condition).__name__}") - cdef cydriver.CUgraphNodeParams params - cdef cydriver.CUgraphNode new_node = NULL - - c_memset(¶ms, 0, sizeof(params)) - params.type = cydriver.CU_GRAPH_NODE_TYPE_CONDITIONAL - params.conditional.handle = condition._c_handle - params.conditional.type = cond_type - params.conditional.size = size - - cdef cydriver.CUcontext ctx = NULL - cdef GraphHandle h_graph = graph_node_get_graph(pred._h_node) - cdef cydriver.CUgraphNode pred_node = as_cu(pred._h_node) - cdef cydriver.CUgraphNode* deps = NULL - cdef size_t num_deps = 0 - - if pred_node != NULL: - deps = &pred_node - num_deps = 1 - - with nogil: - HANDLE_RETURN(cydriver.cuCtxGetCurrent(&ctx)) - params.conditional.ctx = ctx - - with nogil: - IF CUDA_CORE_BUILD_MAJOR >= 13: - HANDLE_RETURN(cydriver.cuGraphAddNode( - &new_node, as_cu(h_graph), deps, NULL, num_deps, ¶ms)) - ELSE: - HANDLE_RETURN(cydriver.cuGraphAddNode( - &new_node, as_cu(h_graph), deps, num_deps, ¶ms)) - - # cuGraphAddNode sets phGraph_out to an internal array of body - # graphs (it replaces the pointer, not writing into a caller array). 
- cdef list branch_list = [] - cdef unsigned int i - cdef cydriver.CUgraph bg - cdef GraphHandle h_branch - for i in range(size): - bg = params.conditional.phGraph_out[i] - h_branch = create_graph_handle_ref(bg, h_graph) - branch_list.append(GraphDef._from_handle(h_branch)) - cdef tuple branches = tuple(branch_list) - - cdef ConditionalNode n = node_cls.__new__(node_cls) - n._h_node = create_graph_node_handle(new_node, h_graph) - n._condition = condition - n._cond_type = cond_type - n._branches = branches - - pred._succ_cache = None - return n - - -@dataclass -class GraphAllocOptions: - """Options for graph memory allocation nodes. - - Attributes - ---------- - device : int or Device, optional - The device on which to allocate memory. If None (default), - uses the current CUDA context's device. - memory_type : str, optional - Type of memory to allocate. One of: - - - ``"device"`` (default): Pinned device memory, optimal for GPU kernels. - - ``"host"``: Pinned host memory, accessible from both host and device. - Useful for graphs containing host callback nodes. Note: may not be - supported on all systems/drivers. - - ``"managed"``: Managed/unified memory that automatically migrates - between host and device. Useful for mixed host/device access patterns. - - peer_access : list of int or Device, optional - List of devices that should have read-write access to the - allocated memory. If None (default), only the allocating - device has access. - - Notes - ----- - - IPC (inter-process communication) is not supported for graph - memory allocation nodes per CUDA documentation. - - The allocation uses the device's default memory pool. - """ - - device: int | Device | None = None - memory_type: str = "device" - peer_access: list | None = None - - -cdef class GraphDef: - """Represents a CUDA graph definition (CUgraph). - - A GraphDef is used to construct a graph explicitly by adding nodes - and specifying dependencies. 
Once construction is complete, call - instantiate() to obtain an executable Graph. - """ - - def __init__(self): - """Create a new empty graph definition.""" - cdef cydriver.CUgraph graph = NULL - with nogil: - HANDLE_RETURN(cydriver.cuGraphCreate(&graph, 0)) - self._h_graph = create_graph_handle(graph) - - @staticmethod - cdef GraphDef _from_handle(GraphHandle h_graph): - """Create a GraphDef from an existing GraphHandle (internal use).""" - cdef GraphDef g = GraphDef.__new__(GraphDef) - g._h_graph = h_graph - return g - - def __repr__(self) -> str: - return f"" - - def __eq__(self, other) -> bool: - if not isinstance(other, GraphDef): - return NotImplemented - return as_intptr(self._h_graph) == as_intptr((other)._h_graph) - - def __hash__(self) -> int: - return hash(as_intptr(self._h_graph)) - - @property - def _entry(self) -> GraphNode: - """Return the internal entry-point GraphNode (no dependencies).""" - cdef GraphNode n = GraphNode.__new__(GraphNode) - n._h_node = create_graph_node_handle(NULL, self._h_graph) - return n - - def alloc(self, size_t size, options: GraphAllocOptions | None = None) -> AllocNode: - """Add an entry-point memory allocation node (no dependencies). - - See :meth:`GraphNode.alloc` for full documentation. - """ - return self._entry.alloc(size, options) - - def free(self, dptr) -> FreeNode: - """Add an entry-point memory free node (no dependencies). - - See :meth:`GraphNode.free` for full documentation. - """ - return self._entry.free(dptr) - - def memset(self, dst, value, size_t width, size_t height=1, size_t pitch=0) -> MemsetNode: - """Add an entry-point memset node (no dependencies). - - See :meth:`GraphNode.memset` for full documentation. - """ - return self._entry.memset(dst, value, width, height, pitch) - - def launch(self, config, kernel, *args) -> KernelNode: - """Add an entry-point kernel launch node (no dependencies). - - See :meth:`GraphNode.launch` for full documentation. 
- """ - return self._entry.launch(config, kernel, *args) - - def join(self, *nodes) -> EmptyNode: - """Create an empty node that depends on all given nodes. - - Parameters - ---------- - *nodes : GraphNode - Nodes to merge. - - Returns - ------- - EmptyNode - A new EmptyNode that depends on all input nodes. - """ - return self._entry.join(*nodes) - - def memcpy(self, dst, src, size_t size) -> MemcpyNode: - """Add an entry-point memcpy node (no dependencies). - - See :meth:`GraphNode.memcpy` for full documentation. - """ - return self._entry.memcpy(dst, src, size) - - def embed(self, child: GraphDef) -> ChildGraphNode: - """Add an entry-point child graph node (no dependencies). - - See :meth:`GraphNode.embed` for full documentation. - """ - return self._entry.embed(child) - - def record_event(self, event: Event) -> EventRecordNode: - """Add an entry-point event record node (no dependencies). - - See :meth:`GraphNode.record_event` for full documentation. - """ - return self._entry.record_event(event) - - def wait_event(self, event: Event) -> EventWaitNode: - """Add an entry-point event wait node (no dependencies). - - See :meth:`GraphNode.wait_event` for full documentation. - """ - return self._entry.wait_event(event) - - def callback(self, fn, *, user_data=None) -> HostCallbackNode: - """Add an entry-point host callback node (no dependencies). - - See :meth:`GraphNode.callback` for full documentation. - """ - return self._entry.callback(fn, user_data=user_data) - - def create_condition(self, default_value: int | None = None) -> Condition: - """Create a condition variable for use with conditional nodes. - - The returned :class:`Condition` object is passed to conditional-node - builder methods. Its value is controlled at runtime by device code - via ``cudaGraphSetConditional``. - - Parameters - ---------- - default_value : int, optional - The default value to assign to the condition. - If None, no default is assigned. 
- - Returns - ------- - Condition - A condition variable for controlling conditional execution. - """ - cdef cydriver.CUgraphConditionalHandle c_handle - cdef unsigned int flags = 0 - cdef unsigned int default_val = 0 - - if default_value is not None: - default_val = default_value - flags = cydriver.CU_GRAPH_COND_ASSIGN_DEFAULT - - cdef cydriver.CUcontext ctx = NULL - with nogil: - HANDLE_RETURN(cydriver.cuCtxGetCurrent(&ctx)) - HANDLE_RETURN(cydriver.cuGraphConditionalHandleCreate( - &c_handle, as_cu(self._h_graph), ctx, default_val, flags)) - - cdef Condition cond = Condition.__new__(Condition) - cond._c_handle = c_handle - return cond - - def if_cond(self, condition: Condition) -> IfNode: - """Add an entry-point if-conditional node (no dependencies). - - See :meth:`GraphNode.if_cond` for full documentation. - """ - return self._entry.if_cond(condition) - - def if_else(self, condition: Condition) -> IfElseNode: - """Add an entry-point if-else conditional node (no dependencies). - - See :meth:`GraphNode.if_else` for full documentation. - """ - return self._entry.if_else(condition) - - def while_loop(self, condition: Condition) -> WhileNode: - """Add an entry-point while-loop conditional node (no dependencies). - - See :meth:`GraphNode.while_loop` for full documentation. - """ - return self._entry.while_loop(condition) - - def switch(self, condition: Condition, unsigned int count) -> SwitchNode: - """Add an entry-point switch conditional node (no dependencies). - - See :meth:`GraphNode.switch` for full documentation. - """ - return self._entry.switch(condition, count) - - def instantiate(self, options=None): - """Instantiate the graph definition into an executable Graph. - - Parameters - ---------- - options : :obj:`~_graph.GraphCompleteOptions`, optional - Customizable dataclass for graph instantiation options. - - Returns - ------- - Graph - An executable graph that can be launched on a stream. 
- """ - from cuda.core._graph._graph_builder import _instantiate_graph - - return _instantiate_graph( - driver.CUgraph(as_intptr(self._h_graph)), options) - - def debug_dot_print(self, path: str, options=None) -> None: - """Write a GraphViz DOT representation of the graph to a file. - - Parameters - ---------- - path : str - File path for the DOT output. - options : GraphDebugPrintOptions, optional - Customizable options for the debug print. - """ - from cuda.core._graph._graph_builder import GraphDebugPrintOptions - - cdef unsigned int flags = 0 - if options is not None: - if not isinstance(options, GraphDebugPrintOptions): - raise TypeError("options must be a GraphDebugPrintOptions instance") - flags = options._to_flags() - - cdef bytes path_bytes = path.encode('utf-8') - cdef const char* c_path = path_bytes - with nogil: - HANDLE_RETURN(cydriver.cuGraphDebugDotPrint(as_cu(self._h_graph), c_path, flags)) - - def nodes(self) -> tuple: - """Return all nodes in the graph. - - Returns - ------- - tuple of GraphNode - All nodes in the graph. - """ - cdef size_t num_nodes = 0 - - with nogil: - HANDLE_RETURN(cydriver.cuGraphGetNodes(as_cu(self._h_graph), NULL, &num_nodes)) - - if num_nodes == 0: - return () - - cdef vector[cydriver.CUgraphNode] nodes_vec - nodes_vec.resize(num_nodes) - with nogil: - HANDLE_RETURN(cydriver.cuGraphGetNodes(as_cu(self._h_graph), nodes_vec.data(), &num_nodes)) - - return tuple(GraphNode._create(self._h_graph, nodes_vec[i]) for i in range(num_nodes)) - - def edges(self) -> tuple: - """Return all edges in the graph as (from_node, to_node) pairs. - - Returns - ------- - tuple of tuple - Each element is a (from_node, to_node) pair representing - a dependency edge in the graph. 
- """ - cdef size_t num_edges = 0 - - with nogil: - IF CUDA_CORE_BUILD_MAJOR >= 13: - HANDLE_RETURN(cydriver.cuGraphGetEdges(as_cu(self._h_graph), NULL, NULL, NULL, &num_edges)) - ELSE: - HANDLE_RETURN(cydriver.cuGraphGetEdges(as_cu(self._h_graph), NULL, NULL, &num_edges)) - - if num_edges == 0: - return () - - cdef vector[cydriver.CUgraphNode] from_nodes - cdef vector[cydriver.CUgraphNode] to_nodes - from_nodes.resize(num_edges) - to_nodes.resize(num_edges) - with nogil: - IF CUDA_CORE_BUILD_MAJOR >= 13: - HANDLE_RETURN(cydriver.cuGraphGetEdges( - as_cu(self._h_graph), from_nodes.data(), to_nodes.data(), NULL, &num_edges)) - ELSE: - HANDLE_RETURN(cydriver.cuGraphGetEdges( - as_cu(self._h_graph), from_nodes.data(), to_nodes.data(), &num_edges)) - - return tuple( - (GraphNode._create(self._h_graph, from_nodes[i]), - GraphNode._create(self._h_graph, to_nodes[i])) - for i in range(num_edges) - ) - - @property - def handle(self): - """Return the underlying driver CUgraph handle.""" - return as_py(self._h_graph) - - -cdef class GraphNode: - """Base class for all graph nodes. - - Nodes are created by calling builder methods on GraphDef (for - entry-point nodes with no dependencies) or on other Nodes (for - nodes that depend on a predecessor). 
- """ - - @staticmethod - cdef GraphNode _create(GraphHandle h_graph, cydriver.CUgraphNode node): - """Factory: dispatch to the right subclass based on node type.""" - if node == NULL: - n = GraphNode.__new__(GraphNode) - (n)._h_node = create_graph_node_handle(node, h_graph) - return n - - cdef GraphNodeHandle h_node = create_graph_node_handle(node, h_graph) - cdef cydriver.CUgraphNodeType node_type - with nogil: - HANDLE_RETURN(cydriver.cuGraphNodeGetType(node, &node_type)) - - if node_type == cydriver.CU_GRAPH_NODE_TYPE_EMPTY: - return EmptyNode._create_impl(h_node) - elif node_type == cydriver.CU_GRAPH_NODE_TYPE_KERNEL: - return KernelNode._create_from_driver(h_node) - elif node_type == cydriver.CU_GRAPH_NODE_TYPE_MEM_ALLOC: - return AllocNode._create_from_driver(h_node) - elif node_type == cydriver.CU_GRAPH_NODE_TYPE_MEM_FREE: - return FreeNode._create_from_driver(h_node) - elif node_type == cydriver.CU_GRAPH_NODE_TYPE_MEMSET: - return MemsetNode._create_from_driver(h_node) - elif node_type == cydriver.CU_GRAPH_NODE_TYPE_MEMCPY: - return MemcpyNode._create_from_driver(h_node) - elif node_type == cydriver.CU_GRAPH_NODE_TYPE_GRAPH: - return ChildGraphNode._create_from_driver(h_node) - elif node_type == cydriver.CU_GRAPH_NODE_TYPE_EVENT_RECORD: - return EventRecordNode._create_from_driver(h_node) - elif node_type == cydriver.CU_GRAPH_NODE_TYPE_WAIT_EVENT: - return EventWaitNode._create_from_driver(h_node) - elif node_type == cydriver.CU_GRAPH_NODE_TYPE_HOST: - return HostCallbackNode._create_from_driver(h_node) - elif node_type == cydriver.CU_GRAPH_NODE_TYPE_CONDITIONAL: - return ConditionalNode._create_from_driver(h_node) - else: - n = GraphNode.__new__(GraphNode) - (n)._h_node = h_node - return n - - def __repr__(self) -> str: - cdef cydriver.CUgraphNode node = as_cu(self._h_node) - if node == NULL: - return "" - return f"node:x}>" - - def __eq__(self, other) -> bool: - if not isinstance(other, GraphNode): - return NotImplemented - cdef GraphNode o = other - 
cdef GraphHandle self_graph = graph_node_get_graph(self._h_node) - cdef GraphHandle other_graph = graph_node_get_graph(o._h_node) - return (as_intptr(self._h_node) == as_intptr(o._h_node) - and as_intptr(self_graph) == as_intptr(other_graph)) - - def __hash__(self) -> int: - cdef GraphHandle g = graph_node_get_graph(self._h_node) - return hash((as_intptr(self._h_node), as_intptr(g))) - - @property - def type(self): - """Return the CUDA graph node type. - - Returns - ------- - CUgraphNodeType or None - The node type enum value, or None for the entry node. - """ - cdef cydriver.CUgraphNode node = as_cu(self._h_node) - if node == NULL: - return None - cdef cydriver.CUgraphNodeType node_type - with nogil: - HANDLE_RETURN(cydriver.cuGraphNodeGetType(node, &node_type)) - return driver.CUgraphNodeType(node_type) - - @property - def graph(self) -> GraphDef: - """Return the GraphDef this node belongs to.""" - return GraphDef._from_handle(graph_node_get_graph(self._h_node)) - - @property - def handle(self): - """Return the underlying driver CUgraphNode handle. - - Returns None for the entry node. - """ - return as_py(self._h_node) - - @property - def pred(self) -> tuple: - """Return the predecessor nodes (dependencies) of this node. - - Results are cached since a node's dependencies are immutable - once created. - - Returns - ------- - tuple of GraphNode - The nodes that this node depends on. 
- """ - if self._pred_cache is not None: - return self._pred_cache - - cdef cydriver.CUgraphNode node = as_cu(self._h_node) - if node == NULL: - self._pred_cache = () - return self._pred_cache - - cdef size_t num_deps = 0 - - with nogil: - IF CUDA_CORE_BUILD_MAJOR >= 13: - HANDLE_RETURN(cydriver.cuGraphNodeGetDependencies(node, NULL, NULL, &num_deps)) - ELSE: - HANDLE_RETURN(cydriver.cuGraphNodeGetDependencies(node, NULL, &num_deps)) - - if num_deps == 0: - self._pred_cache = () - return self._pred_cache - - cdef vector[cydriver.CUgraphNode] deps - deps.resize(num_deps) - with nogil: - IF CUDA_CORE_BUILD_MAJOR >= 13: - HANDLE_RETURN(cydriver.cuGraphNodeGetDependencies(node, deps.data(), NULL, &num_deps)) - ELSE: - HANDLE_RETURN(cydriver.cuGraphNodeGetDependencies(node, deps.data(), &num_deps)) - - cdef GraphHandle h_graph = graph_node_get_graph(self._h_node) - self._pred_cache = tuple(GraphNode._create(h_graph, deps[i]) for i in range(num_deps)) - return self._pred_cache - - @property - def succ(self) -> tuple: - """Return the successor nodes (dependents) of this node. - - Results are cached and automatically invalidated when new - dependent nodes are added via builder methods. - - Returns - ------- - tuple of GraphNode - The nodes that depend on this node. 
- """ - if self._succ_cache is not None: - return self._succ_cache - - cdef cydriver.CUgraphNode node = as_cu(self._h_node) - if node == NULL: - self._succ_cache = () - return self._succ_cache - - cdef size_t num_deps = 0 - - with nogil: - IF CUDA_CORE_BUILD_MAJOR >= 13: - HANDLE_RETURN(cydriver.cuGraphNodeGetDependentNodes(node, NULL, NULL, &num_deps)) - ELSE: - HANDLE_RETURN(cydriver.cuGraphNodeGetDependentNodes(node, NULL, &num_deps)) - - if num_deps == 0: - self._succ_cache = () - return self._succ_cache - - cdef vector[cydriver.CUgraphNode] deps - deps.resize(num_deps) - with nogil: - IF CUDA_CORE_BUILD_MAJOR >= 13: - HANDLE_RETURN(cydriver.cuGraphNodeGetDependentNodes(node, deps.data(), NULL, &num_deps)) - ELSE: - HANDLE_RETURN(cydriver.cuGraphNodeGetDependentNodes(node, deps.data(), &num_deps)) - - cdef GraphHandle h_graph = graph_node_get_graph(self._h_node) - self._succ_cache = tuple(GraphNode._create(h_graph, deps[i]) for i in range(num_deps)) - return self._succ_cache - - def launch(self, config: LaunchConfig, kernel: Kernel, *args) -> KernelNode: - """Add a kernel launch node depending on this node. - - Parameters - ---------- - config : LaunchConfig - Launch configuration (grid, block, shared memory, etc.) - kernel : Kernel - The kernel to launch. - *args - Kernel arguments. - - Returns - ------- - KernelNode - A new KernelNode representing the kernel launch. 
- """ - cdef LaunchConfig conf = config - cdef Kernel ker = kernel - cdef ParamHolder ker_args = ParamHolder(args) - - cdef cydriver.CUDA_KERNEL_NODE_PARAMS node_params - cdef cydriver.CUgraphNode new_node = NULL - cdef GraphHandle h_graph = graph_node_get_graph(self._h_node) - cdef cydriver.CUgraphNode pred_node = as_cu(self._h_node) - cdef cydriver.CUgraphNode* deps = NULL - cdef size_t num_deps = 0 - - if pred_node != NULL: - deps = &pred_node - num_deps = 1 - - node_params.kern = as_cu(ker._h_kernel) - node_params.func = NULL - node_params.gridDimX = conf.grid[0] - node_params.gridDimY = conf.grid[1] - node_params.gridDimZ = conf.grid[2] - node_params.blockDimX = conf.block[0] - node_params.blockDimY = conf.block[1] - node_params.blockDimZ = conf.block[2] - node_params.sharedMemBytes = conf.shmem_size - node_params.kernelParams = (ker_args.ptr) - node_params.extra = NULL - node_params.ctx = NULL - - with nogil: - HANDLE_RETURN(cydriver.cuGraphAddKernelNode( - &new_node, as_cu(h_graph), deps, num_deps, &node_params)) - - _attach_user_object(as_cu(h_graph), new KernelHandle(ker._h_kernel), - _destroy_kernel_handle_copy) - - self._succ_cache = None - return KernelNode._create_with_params( - create_graph_node_handle(new_node, h_graph), - conf.grid, conf.block, conf.shmem_size, - ker._h_kernel) - - def join(self, *nodes: GraphNode) -> EmptyNode: - """Create an empty node that depends on this node and all given nodes. - - This is used to synchronize multiple branches of execution. - - Parameters - ---------- - *nodes : GraphNode - Additional nodes to depend on. - - Returns - ------- - EmptyNode - A new EmptyNode that depends on all input nodes. 
- """ - cdef vector[cydriver.CUgraphNode] deps - cdef cydriver.CUgraphNode new_node = NULL - cdef GraphHandle h_graph = graph_node_get_graph(self._h_node) - cdef GraphNode other - cdef cydriver.CUgraphNode* deps_ptr = NULL - cdef size_t num_deps = 0 - cdef cydriver.CUgraphNode pred_node = as_cu(self._h_node) - - if pred_node != NULL: - deps.push_back(pred_node) - for other in nodes: - if as_cu((other)._h_node) != NULL: - deps.push_back(as_cu((other)._h_node)) - - num_deps = deps.size() - if num_deps > 0: - deps_ptr = deps.data() - - with nogil: - HANDLE_RETURN(cydriver.cuGraphAddEmptyNode( - &new_node, as_cu(h_graph), deps_ptr, num_deps)) - - self._succ_cache = None - for other in nodes: - (other)._succ_cache = None - return EmptyNode._create_impl(create_graph_node_handle(new_node, h_graph)) - - def alloc(self, size_t size, options: GraphAllocOptions | None = None) -> AllocNode: - """Add a memory allocation node depending on this node. - - Parameters - ---------- - size : int - Number of bytes to allocate. - options : GraphAllocOptions, optional - Allocation options. If None, allocates on the current device. - - Returns - ------- - AllocNode - A new AllocNode representing the allocation. Access the allocated - device pointer via the dptr property. 
- """ - cdef int device_id - cdef cydriver.CUdevice dev - - if options is None or options.device is None: - with nogil: - HANDLE_RETURN(cydriver.cuCtxGetDevice(&dev)) - device_id = dev - else: - device_id = getattr(options.device, 'device_id', options.device) - - cdef cydriver.CUDA_MEM_ALLOC_NODE_PARAMS alloc_params - cdef cydriver.CUgraphNode new_node = NULL - cdef GraphHandle h_graph = graph_node_get_graph(self._h_node) - cdef cydriver.CUgraphNode pred_node = as_cu(self._h_node) - cdef cydriver.CUgraphNode* deps = NULL - cdef size_t num_deps = 0 - - if pred_node != NULL: - deps = &pred_node - num_deps = 1 - - cdef vector[cydriver.CUmemAccessDesc] access_descs - cdef int peer_id - cdef list peer_ids = [] - - if options is not None and options.peer_access is not None: - for peer_dev in options.peer_access: - peer_id = getattr(peer_dev, 'device_id', peer_dev) - peer_ids.append(peer_id) - access_descs.push_back(cydriver.CUmemAccessDesc_st( - cydriver.CUmemLocation_st( - cydriver.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE, - peer_id - ), - cydriver.CUmemAccess_flags.CU_MEM_ACCESS_FLAGS_PROT_READWRITE - )) - - cdef str memory_type = "device" - if options is not None and options.memory_type is not None: - memory_type = options.memory_type - - c_memset(&alloc_params, 0, sizeof(alloc_params)) - alloc_params.poolProps.handleTypes = cydriver.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_NONE - alloc_params.bytesize = size - - if memory_type == "device": - alloc_params.poolProps.allocType = cydriver.CUmemAllocationType.CU_MEM_ALLOCATION_TYPE_PINNED - alloc_params.poolProps.location.type = cydriver.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE - alloc_params.poolProps.location.id = device_id - elif memory_type == "host": - alloc_params.poolProps.allocType = cydriver.CUmemAllocationType.CU_MEM_ALLOCATION_TYPE_PINNED - alloc_params.poolProps.location.type = cydriver.CUmemLocationType.CU_MEM_LOCATION_TYPE_HOST - alloc_params.poolProps.location.id = 0 - elif memory_type == 
"managed": - IF CUDA_CORE_BUILD_MAJOR >= 13: - alloc_params.poolProps.allocType = cydriver.CUmemAllocationType.CU_MEM_ALLOCATION_TYPE_MANAGED - alloc_params.poolProps.location.type = cydriver.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE - alloc_params.poolProps.location.id = device_id - ELSE: - raise ValueError("memory_type='managed' requires CUDA 13.0 or later") - else: - raise ValueError(f"Invalid memory_type: {memory_type!r}. " - "Must be 'device', 'host', or 'managed'.") - - if access_descs.size() > 0: - alloc_params.accessDescs = access_descs.data() - alloc_params.accessDescCount = access_descs.size() - - with nogil: - HANDLE_RETURN(cydriver.cuGraphAddMemAllocNode( - &new_node, as_cu(h_graph), deps, num_deps, &alloc_params)) - - self._succ_cache = None - return AllocNode._create_with_params( - create_graph_node_handle(new_node, h_graph), alloc_params.dptr, size, - device_id, memory_type, tuple(peer_ids)) - - def free(self, dptr: int) -> FreeNode: - """Add a memory free node depending on this node. - - Parameters - ---------- - dptr : int - Device pointer to free (typically from AllocNode.dptr). - - Returns - ------- - FreeNode - A new FreeNode representing the free operation. - """ - cdef cydriver.CUgraphNode new_node = NULL - cdef GraphHandle h_graph = graph_node_get_graph(self._h_node) - cdef cydriver.CUgraphNode pred_node = as_cu(self._h_node) - cdef cydriver.CUgraphNode* deps = NULL - cdef size_t num_deps = 0 - cdef cydriver.CUdeviceptr c_dptr = dptr - - if pred_node != NULL: - deps = &pred_node - num_deps = 1 - - with nogil: - HANDLE_RETURN(cydriver.cuGraphAddMemFreeNode( - &new_node, as_cu(h_graph), deps, num_deps, c_dptr)) - - self._succ_cache = None - return FreeNode._create_with_params(create_graph_node_handle(new_node, h_graph), c_dptr) - - def memset(self, dst: int, value, size_t width, size_t height=1, size_t pitch=0) -> MemsetNode: - """Add a memset node depending on this node. - - Parameters - ---------- - dst : int - Destination device pointer. 
- value : int or buffer-protocol object - Fill value. int for 1-byte fill (range [0, 256)), - or buffer-protocol object of 1, 2, or 4 bytes. - width : int - Width of the row in elements. - height : int, optional - Number of rows (default 1). - pitch : int, optional - Pitch of destination in bytes (default 0, unused if height is 1). - - Returns - ------- - MemsetNode - A new MemsetNode representing the memset operation. - """ - cdef unsigned int val - cdef unsigned int elem_size - val, elem_size = _parse_fill_value(value) - - cdef cydriver.CUDA_MEMSET_NODE_PARAMS memset_params - cdef cydriver.CUgraphNode new_node = NULL - cdef GraphHandle h_graph = graph_node_get_graph(self._h_node) - cdef cydriver.CUgraphNode pred_node = as_cu(self._h_node) - cdef cydriver.CUgraphNode* deps = NULL - cdef size_t num_deps = 0 - - if pred_node != NULL: - deps = &pred_node - num_deps = 1 - - cdef cydriver.CUdeviceptr c_dst = dst - cdef cydriver.CUcontext ctx = NULL - with nogil: - HANDLE_RETURN(cydriver.cuCtxGetCurrent(&ctx)) - - c_memset(&memset_params, 0, sizeof(memset_params)) - memset_params.dst = c_dst - memset_params.value = val - memset_params.elementSize = elem_size - memset_params.width = width - memset_params.height = height - memset_params.pitch = pitch - - with nogil: - HANDLE_RETURN(cydriver.cuGraphAddMemsetNode( - &new_node, as_cu(h_graph), deps, num_deps, - &memset_params, ctx)) - - self._succ_cache = None - return MemsetNode._create_with_params( - create_graph_node_handle(new_node, h_graph), c_dst, - val, elem_size, width, height, pitch) - - def memcpy(self, dst: int, src: int, size_t size) -> MemcpyNode: - """Add a memcpy node depending on this node. - - Copies ``size`` bytes from ``src`` to ``dst``. Memory types are - auto-detected via the driver, so both device and pinned host - pointers are supported. - - Parameters - ---------- - dst : int - Destination pointer (device or pinned host). - src : int - Source pointer (device or pinned host). 
- size : int - Number of bytes to copy. - - Returns - ------- - MemcpyNode - A new MemcpyNode representing the copy operation. - """ - cdef cydriver.CUdeviceptr c_dst = dst - cdef cydriver.CUdeviceptr c_src = src - - cdef unsigned int dst_mem_type = cydriver.CU_MEMORYTYPE_DEVICE - cdef unsigned int src_mem_type = cydriver.CU_MEMORYTYPE_DEVICE - cdef cydriver.CUresult ret - with nogil: - ret = cydriver.cuPointerGetAttribute( - &dst_mem_type, - cydriver.CU_POINTER_ATTRIBUTE_MEMORY_TYPE, - c_dst) - if ret != cydriver.CUDA_SUCCESS and ret != cydriver.CUDA_ERROR_INVALID_VALUE: - HANDLE_RETURN(ret) - ret = cydriver.cuPointerGetAttribute( - &src_mem_type, - cydriver.CU_POINTER_ATTRIBUTE_MEMORY_TYPE, - c_src) - if ret != cydriver.CUDA_SUCCESS and ret != cydriver.CUDA_ERROR_INVALID_VALUE: - HANDLE_RETURN(ret) - - cdef cydriver.CUmemorytype c_dst_type = dst_mem_type - cdef cydriver.CUmemorytype c_src_type = src_mem_type - - cdef cydriver.CUDA_MEMCPY3D params - c_memset(¶ms, 0, sizeof(params)) - - params.srcMemoryType = c_src_type - params.dstMemoryType = c_dst_type - if c_src_type == cydriver.CU_MEMORYTYPE_HOST: - params.srcHost = c_src - else: - params.srcDevice = c_src - if c_dst_type == cydriver.CU_MEMORYTYPE_HOST: - params.dstHost = c_dst - else: - params.dstDevice = c_dst - params.WidthInBytes = size - params.Height = 1 - params.Depth = 1 - - cdef cydriver.CUgraphNode new_node = NULL - cdef GraphHandle h_graph = graph_node_get_graph(self._h_node) - cdef cydriver.CUgraphNode pred_node = as_cu(self._h_node) - cdef cydriver.CUgraphNode* deps = NULL - cdef size_t num_deps = 0 - - if pred_node != NULL: - deps = &pred_node - num_deps = 1 - - cdef cydriver.CUcontext ctx = NULL - with nogil: - HANDLE_RETURN(cydriver.cuCtxGetCurrent(&ctx)) - HANDLE_RETURN(cydriver.cuGraphAddMemcpyNode( - &new_node, as_cu(h_graph), deps, num_deps, ¶ms, ctx)) - - self._succ_cache = None - return MemcpyNode._create_with_params( - create_graph_node_handle(new_node, h_graph), c_dst, c_src, size, - 
c_dst_type, c_src_type) - - def embed(self, child: GraphDef) -> ChildGraphNode: - """Add a child graph node depending on this node. - - Embeds a clone of the given graph definition as a sub-graph node. - The child graph must not contain allocation, free, or conditional - nodes. - - Parameters - ---------- - child : GraphDef - The graph definition to embed (will be cloned). - - Returns - ------- - ChildGraphNode - A new ChildGraphNode representing the embedded sub-graph. - """ - cdef GraphDef child_def = child - cdef cydriver.CUgraphNode new_node = NULL - cdef GraphHandle h_graph = graph_node_get_graph(self._h_node) - cdef cydriver.CUgraphNode pred_node = as_cu(self._h_node) - cdef cydriver.CUgraphNode* deps = NULL - cdef size_t num_deps = 0 - - if pred_node != NULL: - deps = &pred_node - num_deps = 1 - - with nogil: - HANDLE_RETURN(cydriver.cuGraphAddChildGraphNode( - &new_node, as_cu(h_graph), deps, num_deps, as_cu(child_def._h_graph))) - - cdef cydriver.CUgraph embedded_graph = NULL - with nogil: - HANDLE_RETURN(cydriver.cuGraphChildGraphNodeGetGraph( - new_node, &embedded_graph)) - - cdef GraphHandle h_embedded = create_graph_handle_ref(embedded_graph, h_graph) - - self._succ_cache = None - return ChildGraphNode._create_with_params( - create_graph_node_handle(new_node, h_graph), h_embedded) - - def record_event(self, event: Event) -> EventRecordNode: - """Add an event record node depending on this node. - - Parameters - ---------- - event : Event - The event to record. - - Returns - ------- - EventRecordNode - A new EventRecordNode representing the event record operation. 
- """ - cdef Event ev = event - cdef cydriver.CUgraphNode new_node = NULL - cdef GraphHandle h_graph = graph_node_get_graph(self._h_node) - cdef cydriver.CUgraphNode pred_node = as_cu(self._h_node) - cdef cydriver.CUgraphNode* deps = NULL - cdef size_t num_deps = 0 - - if pred_node != NULL: - deps = &pred_node - num_deps = 1 - - with nogil: - HANDLE_RETURN(cydriver.cuGraphAddEventRecordNode( - &new_node, as_cu(h_graph), deps, num_deps, as_cu(ev._h_event))) - - _attach_user_object(as_cu(h_graph), new EventHandle(ev._h_event), - _destroy_event_handle_copy) - - self._succ_cache = None - return EventRecordNode._create_with_params( - create_graph_node_handle(new_node, h_graph), ev._h_event) - - def wait_event(self, event: Event) -> EventWaitNode: - """Add an event wait node depending on this node. - - Parameters - ---------- - event : Event - The event to wait for. - - Returns - ------- - EventWaitNode - A new EventWaitNode representing the event wait operation. - """ - cdef Event ev = event - cdef cydriver.CUgraphNode new_node = NULL - cdef GraphHandle h_graph = graph_node_get_graph(self._h_node) - cdef cydriver.CUgraphNode pred_node = as_cu(self._h_node) - cdef cydriver.CUgraphNode* deps = NULL - cdef size_t num_deps = 0 - - if pred_node != NULL: - deps = &pred_node - num_deps = 1 - - with nogil: - HANDLE_RETURN(cydriver.cuGraphAddEventWaitNode( - &new_node, as_cu(h_graph), deps, num_deps, as_cu(ev._h_event))) - - _attach_user_object(as_cu(h_graph), new EventHandle(ev._h_event), - _destroy_event_handle_copy) - - self._succ_cache = None - return EventWaitNode._create_with_params( - create_graph_node_handle(new_node, h_graph), ev._h_event) - - def callback(self, fn, *, user_data=None) -> HostCallbackNode: - """Add a host callback node depending on this node. - - The callback runs on the host CPU when the graph reaches this node. - Two modes are supported: - - - **Python callable**: Pass any callable. The GIL is acquired - automatically. 
The callable must take no arguments; use closures - or ``functools.partial`` to bind state. - - **ctypes function pointer**: Pass a ``ctypes.CFUNCTYPE`` instance. - The function receives a single ``void*`` argument (the - ``user_data``). The caller must keep the ctypes wrapper alive - for the lifetime of the graph. - - .. warning:: - - Callbacks must not call CUDA API functions. Doing so may - deadlock or corrupt driver state. - - Parameters - ---------- - fn : callable or ctypes function pointer - The callback function. - user_data : int or bytes-like, optional - Only for ctypes function pointers. If ``int``, passed as a raw - pointer (caller manages lifetime). If bytes-like, the data is - copied and its lifetime is tied to the graph. - - Returns - ------- - HostCallbackNode - A new HostCallbackNode representing the callback. - """ - import ctypes as ct - - cdef cydriver.CUDA_HOST_NODE_PARAMS node_params - cdef cydriver.CUgraphNode new_node = NULL - cdef GraphHandle h_graph = graph_node_get_graph(self._h_node) - cdef cydriver.CUgraphNode pred_node = as_cu(self._h_node) - cdef cydriver.CUgraphNode* deps = NULL - cdef size_t num_deps = 0 - - if pred_node != NULL: - deps = &pred_node - num_deps = 1 - - _attach_host_callback_to_graph( - as_cu(h_graph), fn, user_data, - &node_params.fn, &node_params.userData) - - with nogil: - HANDLE_RETURN(cydriver.cuGraphAddHostNode( - &new_node, as_cu(h_graph), deps, num_deps, &node_params)) - - cdef object callable_obj = fn if not isinstance(fn, ct._CFuncPtr) else None - self._succ_cache = None - return HostCallbackNode._create_with_params( - create_graph_node_handle(new_node, h_graph), callable_obj, - node_params.fn, node_params.userData) - - def if_cond(self, condition: Condition) -> IfNode: - """Add an if-conditional node depending on this node. - - The body graph executes only when the condition evaluates to - a non-zero value at runtime. 
- - Parameters - ---------- - condition : Condition - Condition from :meth:`GraphDef.create_condition`. - - Returns - ------- - IfNode - A new IfNode with one branch accessible via ``.then``. - """ - return _make_conditional_node( - self, condition, - cydriver.CU_GRAPH_COND_TYPE_IF, 1, IfNode) - - def if_else(self, condition: Condition) -> IfElseNode: - """Add an if-else conditional node depending on this node. - - Two body graphs: the first executes when the condition is - non-zero, the second when it is zero. - - Parameters - ---------- - condition : Condition - Condition from :meth:`GraphDef.create_condition`. - - Returns - ------- - IfElseNode - A new IfElseNode with branches accessible via - ``.then`` and ``.else_``. - """ - return _make_conditional_node( - self, condition, - cydriver.CU_GRAPH_COND_TYPE_IF, 2, IfElseNode) - - def while_loop(self, condition: Condition) -> WhileNode: - """Add a while-loop conditional node depending on this node. - - The body graph executes repeatedly while the condition - evaluates to a non-zero value. - - Parameters - ---------- - condition : Condition - Condition from :meth:`GraphDef.create_condition`. - - Returns - ------- - WhileNode - A new WhileNode with body accessible via ``.body``. - """ - return _make_conditional_node( - self, condition, - cydriver.CU_GRAPH_COND_TYPE_WHILE, 1, WhileNode) - - def switch(self, condition: Condition, unsigned int count) -> SwitchNode: - """Add a switch conditional node depending on this node. - - The condition value selects which branch to execute. If the - value is out of range, no branch executes. - - Parameters - ---------- - condition : Condition - Condition from :meth:`GraphDef.create_condition`. - count : int - Number of switch cases (branches). - - Returns - ------- - SwitchNode - A new SwitchNode with branches accessible via ``.branches``. 
- """ - return _make_conditional_node( - self, condition, - cydriver.CU_GRAPH_COND_TYPE_SWITCH, count, SwitchNode) - - -# ============================================================================= -# GraphNode subclasses -# ============================================================================= - - -cdef class EmptyNode(GraphNode): - """A synchronization / join node with no operation.""" - - @staticmethod - cdef EmptyNode _create_impl(GraphNodeHandle h_node): - cdef EmptyNode n = EmptyNode.__new__(EmptyNode) - n._h_node = h_node - return n - - def __repr__(self) -> str: - cdef Py_ssize_t n = len(self.pred) - return f"" - - -cdef class KernelNode(GraphNode): - """A kernel launch node. - - Properties - ---------- - grid : tuple of int - Grid dimensions (gridDimX, gridDimY, gridDimZ). - block : tuple of int - Block dimensions (blockDimX, blockDimY, blockDimZ). - shmem_size : int - Dynamic shared memory size in bytes. - kernel : Kernel - The kernel object for this launch node. - config : LaunchConfig - A LaunchConfig reconstructed from this node's parameters. 
- """ - - @staticmethod - cdef KernelNode _create_with_params(GraphNodeHandle h_node, - tuple grid, tuple block, unsigned int shmem_size, - KernelHandle h_kernel): - """Create from known params (called by launch() builder).""" - cdef KernelNode n = KernelNode.__new__(KernelNode) - n._h_node = h_node - n._grid = grid - n._block = block - n._shmem_size = shmem_size - n._h_kernel = h_kernel - return n - - @staticmethod - cdef KernelNode _create_from_driver(GraphNodeHandle h_node): - """Create by fetching params from the driver (called by _create factory).""" - cdef cydriver.CUgraphNode node = as_cu(h_node) - cdef cydriver.CUDA_KERNEL_NODE_PARAMS params - with nogil: - HANDLE_RETURN(cydriver.cuGraphKernelNodeGetParams(node, ¶ms)) - cdef KernelHandle h_kernel = create_kernel_handle_ref(params.kern) - return KernelNode._create_with_params( - h_node, - (params.gridDimX, params.gridDimY, params.gridDimZ), - (params.blockDimX, params.blockDimY, params.blockDimZ), - params.sharedMemBytes, - h_kernel) - - def __repr__(self) -> str: - return (f"") - - @property - def grid(self) -> tuple: - """Grid dimensions as a 3-tuple (gridDimX, gridDimY, gridDimZ).""" - return self._grid - - @property - def block(self) -> tuple: - """Block dimensions as a 3-tuple (blockDimX, blockDimY, blockDimZ).""" - return self._block - - @property - def shmem_size(self) -> int: - """Dynamic shared memory size in bytes.""" - return self._shmem_size - - @property - def kernel(self) -> Kernel: - """The Kernel object for this launch node.""" - return Kernel._from_handle(self._h_kernel) - - @property - def config(self) -> LaunchConfig: - """A LaunchConfig reconstructed from this node's grid, block, and shmem_size. - - Note: cluster dimensions and cooperative_launch are not preserved - by the CUDA driver's kernel node params, so they are not included. 
- """ - return LaunchConfig(grid=self._grid, block=self._block, - shmem_size=self._shmem_size) - - -cdef class AllocNode(GraphNode): - """A memory allocation node. - - Properties - ---------- - dptr : int - The device pointer for the allocation. - bytesize : int - The number of bytes allocated. - device_id : int - The device on which the allocation was made. - memory_type : str - The type of memory allocated (``"device"``, ``"host"``, or ``"managed"``). - peer_access : tuple of int - Device IDs that have read-write access to this allocation. - options : GraphAllocOptions - A GraphAllocOptions reconstructed from this node's parameters. - """ - - @staticmethod - cdef AllocNode _create_with_params(GraphNodeHandle h_node, - cydriver.CUdeviceptr dptr, size_t bytesize, - int device_id, str memory_type, tuple peer_access): - """Create from known params (called by alloc() builder).""" - cdef AllocNode n = AllocNode.__new__(AllocNode) - n._h_node = h_node - n._dptr = dptr - n._bytesize = bytesize - n._device_id = device_id - n._memory_type = memory_type - n._peer_access = peer_access - return n - - @staticmethod - cdef AllocNode _create_from_driver(GraphNodeHandle h_node): - """Create by fetching params from the driver (called by _create factory).""" - cdef cydriver.CUgraphNode node = as_cu(h_node) - cdef cydriver.CUDA_MEM_ALLOC_NODE_PARAMS params - with nogil: - HANDLE_RETURN(cydriver.cuGraphMemAllocNodeGetParams(node, ¶ms)) - - cdef str memory_type - if params.poolProps.allocType == cydriver.CUmemAllocationType.CU_MEM_ALLOCATION_TYPE_PINNED: - if params.poolProps.location.type == cydriver.CUmemLocationType.CU_MEM_LOCATION_TYPE_HOST: - memory_type = "host" - else: - memory_type = "device" - else: - IF CUDA_CORE_BUILD_MAJOR >= 13: - if params.poolProps.allocType == cydriver.CUmemAllocationType.CU_MEM_ALLOCATION_TYPE_MANAGED: - memory_type = "managed" - else: - memory_type = "device" - ELSE: - memory_type = "device" - - cdef list peer_ids = [] - cdef size_t i - for i in 
range(params.accessDescCount): - peer_ids.append(params.accessDescs[i].location.id) - - return AllocNode._create_with_params( - h_node, params.dptr, params.bytesize, - params.poolProps.location.id, memory_type, tuple(peer_ids)) - - def __repr__(self) -> str: - return f"" - - @property - def dptr(self) -> int: - """The device pointer for the allocation.""" - return self._dptr - - @property - def bytesize(self) -> int: - """The number of bytes allocated.""" - return self._bytesize - - @property - def device_id(self) -> int: - """The device on which the allocation was made.""" - return self._device_id - - @property - def memory_type(self) -> str: - """The type of memory: ``"device"``, ``"host"``, or ``"managed"``.""" - return self._memory_type - - @property - def peer_access(self) -> tuple: - """Device IDs with read-write access to this allocation.""" - return self._peer_access - - @property - def options(self) -> GraphAllocOptions: - """A GraphAllocOptions reconstructed from this node's parameters.""" - return GraphAllocOptions( - device=self._device_id, - memory_type=self._memory_type, - peer_access=list(self._peer_access) if self._peer_access else None, - ) - - -cdef class FreeNode(GraphNode): - """A memory free node. - - Properties - ---------- - dptr : int - The device pointer being freed. 
- """ - - @staticmethod - cdef FreeNode _create_with_params(GraphNodeHandle h_node, - cydriver.CUdeviceptr dptr): - """Create from known params (called by free() builder).""" - cdef FreeNode n = FreeNode.__new__(FreeNode) - n._h_node = h_node - n._dptr = dptr - return n - - @staticmethod - cdef FreeNode _create_from_driver(GraphNodeHandle h_node): - """Create by fetching params from the driver (called by _create factory).""" - cdef cydriver.CUgraphNode node = as_cu(h_node) - cdef cydriver.CUdeviceptr dptr - with nogil: - HANDLE_RETURN(cydriver.cuGraphMemFreeNodeGetParams(node, &dptr)) - return FreeNode._create_with_params(h_node, dptr) - - def __repr__(self) -> str: - return f"" - - @property - def dptr(self) -> int: - """The device pointer being freed.""" - return self._dptr - - -cdef class MemsetNode(GraphNode): - """A memory set node. - - Properties - ---------- - dptr : int - The destination device pointer. - value : int - The fill value. - element_size : int - Element size in bytes (1, 2, or 4). - width : int - Width of the row in elements. - height : int - Number of rows. - pitch : int - Pitch in bytes (unused if height is 1). 
- """ - - @staticmethod - cdef MemsetNode _create_with_params(GraphNodeHandle h_node, - cydriver.CUdeviceptr dptr, unsigned int value, - unsigned int element_size, size_t width, - size_t height, size_t pitch): - """Create from known params (called by memset() builder).""" - cdef MemsetNode n = MemsetNode.__new__(MemsetNode) - n._h_node = h_node - n._dptr = dptr - n._value = value - n._element_size = element_size - n._width = width - n._height = height - n._pitch = pitch - return n - - @staticmethod - cdef MemsetNode _create_from_driver(GraphNodeHandle h_node): - """Create by fetching params from the driver (called by _create factory).""" - cdef cydriver.CUgraphNode node = as_cu(h_node) - cdef cydriver.CUDA_MEMSET_NODE_PARAMS params - with nogil: - HANDLE_RETURN(cydriver.cuGraphMemsetNodeGetParams(node, ¶ms)) - return MemsetNode._create_with_params( - h_node, params.dst, params.value, - params.elementSize, params.width, params.height, params.pitch) - - def __repr__(self) -> str: - return (f"") - - @property - def dptr(self) -> int: - """The destination device pointer.""" - return self._dptr - - @property - def value(self) -> int: - """The fill value.""" - return self._value - - @property - def element_size(self) -> int: - """Element size in bytes (1, 2, or 4).""" - return self._element_size - - @property - def width(self) -> int: - """Width of the row in elements.""" - return self._width - - @property - def height(self) -> int: - """Number of rows.""" - return self._height - - @property - def pitch(self) -> int: - """Pitch in bytes (unused if height is 1).""" - return self._pitch - - -cdef class MemcpyNode(GraphNode): - """A memory copy node. - - Properties - ---------- - dst : int - The destination pointer. - src : int - The source pointer. - size : int - The number of bytes copied. 
- """ - - @staticmethod - cdef MemcpyNode _create_with_params(GraphNodeHandle h_node, - cydriver.CUdeviceptr dst, cydriver.CUdeviceptr src, - size_t size, cydriver.CUmemorytype dst_type, - cydriver.CUmemorytype src_type): - """Create from known params (called by memcpy() builder).""" - cdef MemcpyNode n = MemcpyNode.__new__(MemcpyNode) - n._h_node = h_node - n._dst = dst - n._src = src - n._size = size - n._dst_type = dst_type - n._src_type = src_type - return n - - @staticmethod - cdef MemcpyNode _create_from_driver(GraphNodeHandle h_node): - """Create by fetching params from the driver (called by _create factory).""" - cdef cydriver.CUgraphNode node = as_cu(h_node) - cdef cydriver.CUDA_MEMCPY3D params - with nogil: - HANDLE_RETURN(cydriver.cuGraphMemcpyNodeGetParams(node, ¶ms)) - - cdef cydriver.CUdeviceptr dst - cdef cydriver.CUdeviceptr src - if params.dstMemoryType == cydriver.CU_MEMORYTYPE_HOST: - dst = params.dstHost - else: - dst = params.dstDevice - if params.srcMemoryType == cydriver.CU_MEMORYTYPE_HOST: - src = params.srcHost - else: - src = params.srcDevice - - return MemcpyNode._create_with_params( - h_node, dst, src, params.WidthInBytes, - params.dstMemoryType, params.srcMemoryType) - - def __repr__(self) -> str: - cdef str dt = "H" if self._dst_type == cydriver.CU_MEMORYTYPE_HOST else "D" - cdef str st = "H" if self._src_type == cydriver.CU_MEMORYTYPE_HOST else "D" - return (f"") - - @property - def dst(self) -> int: - """The destination pointer.""" - return self._dst - - @property - def src(self) -> int: - """The source pointer.""" - return self._src - - @property - def size(self) -> int: - """The number of bytes copied.""" - return self._size - - -cdef class ChildGraphNode(GraphNode): - """A child graph (sub-graph) node. - - Properties - ---------- - child_graph : GraphDef - The embedded graph definition (non-owning wrapper). 
- """ - - @staticmethod - cdef ChildGraphNode _create_with_params(GraphNodeHandle h_node, - GraphHandle h_child_graph): - """Create from known params (called by embed() builder).""" - cdef ChildGraphNode n = ChildGraphNode.__new__(ChildGraphNode) - n._h_node = h_node - n._h_child_graph = h_child_graph - return n - - @staticmethod - cdef ChildGraphNode _create_from_driver(GraphNodeHandle h_node): - """Create by fetching params from the driver (called by _create factory).""" - cdef cydriver.CUgraphNode node = as_cu(h_node) - cdef cydriver.CUgraph child_graph = NULL - with nogil: - HANDLE_RETURN(cydriver.cuGraphChildGraphNodeGetGraph(node, &child_graph)) - cdef GraphHandle h_graph = graph_node_get_graph(h_node) - cdef GraphHandle h_child = create_graph_handle_ref(child_graph, h_graph) - return ChildGraphNode._create_with_params(h_node, h_child) - - def __repr__(self) -> str: - cdef cydriver.CUgraph g = as_cu(self._h_child_graph) - cdef size_t num_nodes = 0 - with nogil: - HANDLE_RETURN(cydriver.cuGraphGetNodes(g, NULL, &num_nodes)) - cdef Py_ssize_t n = num_nodes - return f"" - - @property - def child_graph(self) -> GraphDef: - """The embedded graph definition (non-owning wrapper).""" - return GraphDef._from_handle(self._h_child_graph) - - -cdef class EventRecordNode(GraphNode): - """An event record node. - - Properties - ---------- - event : Event - The event being recorded. 
- """ - - @staticmethod - cdef EventRecordNode _create_with_params(GraphNodeHandle h_node, - EventHandle h_event): - """Create from known params (called by record_event() builder).""" - cdef EventRecordNode n = EventRecordNode.__new__(EventRecordNode) - n._h_node = h_node - n._h_event = h_event - return n - - @staticmethod - cdef EventRecordNode _create_from_driver(GraphNodeHandle h_node): - """Create by fetching params from the driver (called by _create factory).""" - cdef cydriver.CUgraphNode node = as_cu(h_node) - cdef cydriver.CUevent event - with nogil: - HANDLE_RETURN(cydriver.cuGraphEventRecordNodeGetEvent(node, &event)) - cdef EventHandle h_event = create_event_handle_ref(event) - return EventRecordNode._create_with_params(h_node, h_event) - - def __repr__(self) -> str: - return f"" - - @property - def event(self) -> Event: - """The event being recorded.""" - return Event._from_handle(self._h_event) - - -cdef class EventWaitNode(GraphNode): - """An event wait node. - - Properties - ---------- - event : Event - The event being waited on. 
- """ - - @staticmethod - cdef EventWaitNode _create_with_params(GraphNodeHandle h_node, - EventHandle h_event): - """Create from known params (called by wait_event() builder).""" - cdef EventWaitNode n = EventWaitNode.__new__(EventWaitNode) - n._h_node = h_node - n._h_event = h_event - return n - - @staticmethod - cdef EventWaitNode _create_from_driver(GraphNodeHandle h_node): - """Create by fetching params from the driver (called by _create factory).""" - cdef cydriver.CUgraphNode node = as_cu(h_node) - cdef cydriver.CUevent event - with nogil: - HANDLE_RETURN(cydriver.cuGraphEventWaitNodeGetEvent(node, &event)) - cdef EventHandle h_event = create_event_handle_ref(event) - return EventWaitNode._create_with_params(h_node, h_event) - - def __repr__(self) -> str: - return f"" - - @property - def event(self) -> Event: - """The event being waited on.""" - return Event._from_handle(self._h_event) - - -cdef class HostCallbackNode(GraphNode): - """A host callback node. - - Properties - ---------- - callback_fn : callable or None - The Python callable (None for ctypes function pointer callbacks). 
- """ - - @staticmethod - cdef HostCallbackNode _create_with_params(GraphNodeHandle h_node, - object callable_obj, cydriver.CUhostFn fn, - void* user_data): - """Create from known params (called by callback() builder).""" - cdef HostCallbackNode n = HostCallbackNode.__new__(HostCallbackNode) - n._h_node = h_node - n._callable = callable_obj - n._fn = fn - n._user_data = user_data - return n - - @staticmethod - cdef HostCallbackNode _create_from_driver(GraphNodeHandle h_node): - """Create by fetching params from the driver (called by _create factory).""" - cdef cydriver.CUgraphNode node = as_cu(h_node) - cdef cydriver.CUDA_HOST_NODE_PARAMS params - with nogil: - HANDLE_RETURN(cydriver.cuGraphHostNodeGetParams(node, ¶ms)) - - cdef object callable_obj = None - if _is_py_host_trampoline(params.fn): - callable_obj = params.userData - - return HostCallbackNode._create_with_params( - h_node, callable_obj, params.fn, params.userData) - - def __repr__(self) -> str: - if self._callable is not None: - name = getattr(self._callable, '__name__', '?') - return f"" - return f"self._fn:x}>" - - @property - def callback_fn(self): - """The Python callable, or None for ctypes function pointer callbacks.""" - return self._callable - - -cdef class ConditionalNode(GraphNode): - """Base class for conditional graph nodes. - - When created via builder methods (if_cond, if_else, while_loop, switch), - a specific subclass (IfNode, IfElseNode, WhileNode, SwitchNode) is - returned. When reconstructed from the driver on CUDA 13.2+, the - correct subclass is determined via cuGraphNodeGetParams. On older - drivers, this base class is used as a fallback. - - Properties - ---------- - condition : Condition or None - The condition variable controlling execution (None pre-13.2). - cond_type : str or None - The conditional type ("if", "while", or "switch"; None pre-13.2). - branches : tuple of GraphDef - The body graphs for each branch (empty pre-13.2). 
- """ - - @staticmethod - cdef ConditionalNode _create_from_driver(GraphNodeHandle h_node): - cdef ConditionalNode n - if not _check_node_get_params(): - n = ConditionalNode.__new__(ConditionalNode) - n._h_node = h_node - n._condition = None - n._cond_type = cydriver.CU_GRAPH_COND_TYPE_IF - n._branches = () - return n - - cdef cydriver.CUgraphNode node = as_cu(h_node) - params = handle_return(driver.cuGraphNodeGetParams( - node)) - cond_params = params.conditional - cdef int cond_type_int = int(cond_params.type) - cdef unsigned int size = int(cond_params.size) - - cdef Condition condition = Condition.__new__(Condition) - condition._c_handle = ( - int(cond_params.handle)) - - cdef GraphHandle h_graph = graph_node_get_graph(h_node) - cdef list branch_list = [] - cdef unsigned int i - cdef GraphHandle h_branch - if cond_params.phGraph_out is not None: - for i in range(size): - h_branch = create_graph_handle_ref( - int(cond_params.phGraph_out[i]), - h_graph) - branch_list.append(GraphDef._from_handle(h_branch)) - cdef tuple branches = tuple(branch_list) - - cdef type cls - if cond_type_int == cydriver.CU_GRAPH_COND_TYPE_IF: - if size == 1: - cls = IfNode - else: - cls = IfElseNode - elif cond_type_int == cydriver.CU_GRAPH_COND_TYPE_WHILE: - cls = WhileNode - else: - cls = SwitchNode - - n = cls.__new__(cls) - n._h_node = h_node - n._condition = condition - n._cond_type = cond_type_int - n._branches = branches - return n - - def __repr__(self) -> str: - return "" - - @property - def condition(self) -> Condition | None: - """The condition variable controlling execution.""" - return self._condition - - @property - def cond_type(self) -> str | None: - """The conditional type as a string: 'if', 'while', or 'switch'. - - Returns None when reconstructed from the driver pre-CUDA 13.2, - as the conditional type cannot be determined. 
- """ - if self._condition is None: - return None - if self._cond_type == cydriver.CU_GRAPH_COND_TYPE_IF: - return "if" - elif self._cond_type == cydriver.CU_GRAPH_COND_TYPE_WHILE: - return "while" - else: - return "switch" - - @property - def branches(self) -> tuple: - """The body graphs for each branch as a tuple of GraphDef. - - Returns an empty tuple when reconstructed from the driver - pre-CUDA 13.2. - """ - return self._branches - - -cdef class IfNode(ConditionalNode): - """An if-conditional node (1 branch, executes when condition is non-zero).""" - - def __repr__(self) -> str: - return f"self._condition._c_handle:x}>" - - @property - def then(self) -> GraphDef: - """The 'then' branch graph.""" - return self._branches[0] - - -cdef class IfElseNode(ConditionalNode): - """An if-else conditional node (2 branches).""" - - def __repr__(self) -> str: - return f"self._condition._c_handle:x}>" - - @property - def then(self) -> GraphDef: - """The 'then' branch graph (executed when condition is non-zero).""" - return self._branches[0] - - @property - def else_(self) -> GraphDef: - """The 'else' branch graph (executed when condition is zero).""" - return self._branches[1] - - -cdef class WhileNode(ConditionalNode): - """A while-loop conditional node (1 branch, repeats while condition is non-zero).""" - - def __repr__(self) -> str: - return f"self._condition._c_handle:x}>" - - @property - def body(self) -> GraphDef: - """The loop body graph.""" - return self._branches[0] - - -cdef class SwitchNode(ConditionalNode): - """A switch conditional node (N branches, selected by condition value).""" - - def __repr__(self) -> str: - cdef Py_ssize_t n = len(self._branches) - return (f"self._condition._c_handle:x}" - f" with {n} {'branch' if n == 1 else 'branches'}>") diff --git a/cuda_core/tests/graph/test_graphdef.py b/cuda_core/tests/graph/test_graphdef.py index 30a7f05c98..3412d71847 100644 --- a/cuda_core/tests/graph/test_graphdef.py +++ b/cuda_core/tests/graph/test_graphdef.py 
@@ -12,7 +12,7 @@ from cuda.core import Device, LaunchConfig from cuda.core._graph import GraphCompleteOptions, GraphDebugPrintOptions -from cuda.core._graph._graphdef import ( +from cuda.core._graph._graph_def import ( AllocNode, ChildGraphNode, ConditionalNode, diff --git a/cuda_core/tests/graph/test_graphdef_errors.py b/cuda_core/tests/graph/test_graphdef_errors.py index 09c3bf8ec4..9c6a870562 100644 --- a/cuda_core/tests/graph/test_graphdef_errors.py +++ b/cuda_core/tests/graph/test_graphdef_errors.py @@ -10,7 +10,7 @@ from helpers.misc import try_create_condition from cuda.core import Device, LaunchConfig -from cuda.core._graph._graphdef import ( +from cuda.core._graph._graph_def import ( Condition, EmptyNode, GraphDef, diff --git a/cuda_core/tests/graph/test_graphdef_integration.py b/cuda_core/tests/graph/test_graphdef_integration.py index bb7eab0f8e..d66b60f450 100644 --- a/cuda_core/tests/graph/test_graphdef_integration.py +++ b/cuda_core/tests/graph/test_graphdef_integration.py @@ -9,7 +9,7 @@ import pytest from cuda.core import Device, EventOptions, LaunchConfig, Program, ProgramOptions -from cuda.core._graph._graphdef import GraphDef +from cuda.core._graph._graph_def import GraphDef from cuda.core._utils.cuda_utils import driver, handle_return SIZEOF_FLOAT = 4 diff --git a/cuda_core/tests/graph/test_graphdef_lifetime.py b/cuda_core/tests/graph/test_graphdef_lifetime.py index 1fa1c025c2..133f2c7ca1 100644 --- a/cuda_core/tests/graph/test_graphdef_lifetime.py +++ b/cuda_core/tests/graph/test_graphdef_lifetime.py @@ -10,7 +10,7 @@ from helpers.misc import try_create_condition from cuda.core import Device, EventOptions, Kernel, LaunchConfig -from cuda.core._graph._graphdef import ( +from cuda.core._graph._graph_def import ( ChildGraphNode, ConditionalNode, GraphDef, diff --git a/cuda_core/tests/test_object_protocols.py b/cuda_core/tests/test_object_protocols.py index bd92ad0696..ef4f1337d1 100644 --- a/cuda_core/tests/test_object_protocols.py +++ 
b/cuda_core/tests/test_object_protocols.py @@ -27,7 +27,7 @@ Stream, system, ) -from cuda.core._graph._graphdef import GraphDef +from cuda.core._graph._graph_def import GraphDef from cuda.core._program import _can_load_generated_ptx From 60964fdbf8c2e9570ea849a78a7ee19cf41f6615 Mon Sep 17 00:00:00 2001 From: Andy Jost Date: Tue, 31 Mar 2026 15:38:09 -0700 Subject: [PATCH 5/7] Fix stale _graphdef import paths after subpackage rename Update two references that still used _graphdef instead of _graph_def after the subpackage split. Made-with: Cursor --- cuda_core/cuda/core/_graph/_graph_builder.pyx | 4 ++-- cuda_core/tests/graph/test_graph_update.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/cuda_core/cuda/core/_graph/_graph_builder.pyx b/cuda_core/cuda/core/_graph/_graph_builder.pyx index e67efb120f..2f1301d8f8 100644 --- a/cuda_core/cuda/core/_graph/_graph_builder.pyx +++ b/cuda_core/cuda/core/_graph/_graph_builder.pyx @@ -793,12 +793,12 @@ class Graph: Parameters ---------- - source : :obj:`~_graph.GraphBuilder` or :obj:`~_graph._graphdef.GraphDef` + source : :obj:`~_graph.GraphBuilder` or :obj:`~_graph._graph_def.GraphDef` The graph definition to update from. A GraphBuilder must have finished building. 
""" - from cuda.core._graph._graphdef import GraphDef + from cuda.core._graph._graph_def import GraphDef cdef cydriver.CUgraph cu_graph cdef cydriver.CUgraphExec cu_exec = int(self._mnff.graph) diff --git a/cuda_core/tests/graph/test_graph_update.py b/cuda_core/tests/graph/test_graph_update.py index caf9ea4304..e4716d5601 100644 --- a/cuda_core/tests/graph/test_graph_update.py +++ b/cuda_core/tests/graph/test_graph_update.py @@ -8,7 +8,7 @@ from helpers.graph_kernels import compile_common_kernels, compile_conditional_kernels from cuda.core import Device, LaunchConfig, LegacyPinnedMemoryResource, launch -from cuda.core._graph._graphdef import GraphDef +from cuda.core._graph._graph_def import GraphDef from cuda.core._utils.cuda_utils import CUDAError From e924fdebec6e2d1393d139944d9b4c098480c35e Mon Sep 17 00:00:00 2001 From: Andy Jost Date: Wed, 1 Apr 2026 17:38:09 -0700 Subject: [PATCH 6/7] Assert specific error code from cuGraphExecUpdate Made-with: Cursor --- cuda_core/cuda/core/_graph/_graph_builder.pyx | 1 + 1 file changed, 1 insertion(+) diff --git a/cuda_core/cuda/core/_graph/_graph_builder.pyx b/cuda_core/cuda/core/_graph/_graph_builder.pyx index 2f1301d8f8..f9d2143d04 100644 --- a/cuda_core/cuda/core/_graph/_graph_builder.pyx +++ b/cuda_core/cuda/core/_graph/_graph_builder.pyx @@ -818,6 +818,7 @@ class Graph: with nogil: err = cydriver.cuGraphExecUpdate(cu_exec, cu_graph, &result_info) if err != cydriver.CUresult.CUDA_SUCCESS: + assert err == cydriver.CUresult.CUDA_ERROR_GRAPH_EXEC_UPDATE_FAILURE reason = driver.CUgraphExecUpdateResult(result_info.result) msg = f"Graph update failed: {reason.__doc__.strip()} ({reason.name})" raise CUDAError(msg) From fe04a910a89be39f9d3c197a6394f131922d7b1f Mon Sep 17 00:00:00 2001 From: Andy Jost Date: Wed, 1 Apr 2026 17:44:42 -0700 Subject: [PATCH 7/7] Handle all cuGraphExecUpdate error codes, not just update failure Check for CUDA_ERROR_GRAPH_EXEC_UPDATE_FAILURE first to provide the rich error message with the update 
result reason, then fall through to HANDLE_RETURN for any other error code (CUDA_ERROR_INVALID_VALUE, CUDA_ERROR_NOT_SUPPORTED, etc.) or success. Made-with: Cursor --- cuda_core/cuda/core/_graph/_graph_builder.pyx | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cuda_core/cuda/core/_graph/_graph_builder.pyx b/cuda_core/cuda/core/_graph/_graph_builder.pyx index f9d2143d04..1d3b48435b 100644 --- a/cuda_core/cuda/core/_graph/_graph_builder.pyx +++ b/cuda_core/cuda/core/_graph/_graph_builder.pyx @@ -817,11 +817,11 @@ class Graph: cdef cydriver.CUresult err with nogil: err = cydriver.cuGraphExecUpdate(cu_exec, cu_graph, &result_info) - if err != cydriver.CUresult.CUDA_SUCCESS: - assert err == cydriver.CUresult.CUDA_ERROR_GRAPH_EXEC_UPDATE_FAILURE + if err == cydriver.CUresult.CUDA_ERROR_GRAPH_EXEC_UPDATE_FAILURE: reason = driver.CUgraphExecUpdateResult(result_info.result) msg = f"Graph update failed: {reason.__doc__.strip()} ({reason.name})" raise CUDAError(msg) + HANDLE_RETURN(err) def upload(self, stream: Stream): """Uploads the graph in a stream.