Skip to content

Commit d3c4ef4

Browse files
committed
Enhance Graph.update() and add whole-graph update tests
- Extend Graph.update() to accept both GraphBuilder and GraphDef sources - Surface CUgraphExecUpdateResultInfo details on update failure instead of a generic CUDA_ERROR_GRAPH_EXEC_UPDATE_FAILURE message - Release the GIL during cuGraphExecUpdate via nogil block - Add parametrized happy-path test covering both GraphBuilder and GraphDef - Add error-case tests: unfinished builder, topology mismatch, wrong type Made-with: Cursor
1 parent 1633628 commit d3c4ef4

File tree

2 files changed

+144
-17
lines changed

2 files changed

+144
-17
lines changed

cuda_core/cuda/core/_graph/_graph_builder.pyx

Lines changed: 31 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import weakref
66
from dataclasses import dataclass
77

8+
from libc.stdint cimport intptr_t
9+
810
from cuda.bindings cimport cydriver
911

1012
from cuda.core._graph._utils cimport _attach_host_callback_to_graph
@@ -14,6 +16,7 @@ from cuda.core._utils.cuda_utils cimport HANDLE_RETURN
1416
from cuda.core._utils.version cimport cy_binding_version, cy_driver_version
1517

1618
from cuda.core._utils.cuda_utils import (
19+
CUDAError,
1720
driver,
1821
handle_return,
1922
)
@@ -783,24 +786,41 @@ class Graph:
783786
"""
784787
return self._mnff.graph
785788

786-
def update(self, builder: GraphBuilder):
787-
"""Update the graph using new build configuration from the builder.
789+
def update(self, source: "GraphBuilder | GraphDef") -> None:
790+
"""Update the graph using a new graph definition.
788791

789-
The topology of the provided builder must be identical to this graph.
792+
The topology of the provided source must be identical to this graph.
790793

791794
Parameters
792795
----------
793-
builder : :obj:`~_graph.GraphBuilder`
794-
The builder to update the graph with.
796+
source : :obj:`~_graph.GraphBuilder` or :obj:`~_graph._graphdef.GraphDef`
797+
The graph definition to update from. A GraphBuilder must have
798+
finished building.
795799

796800
"""
797-
if not builder._building_ended:
798-
raise ValueError("Graph has not finished building.")
801+
from cuda.core._graph._graphdef import GraphDef
802+
803+
cdef cydriver.CUgraph cu_graph
804+
cdef cydriver.CUgraphExec cu_exec = <cydriver.CUgraphExec><intptr_t>int(self._mnff.graph)
799805

800-
# Update the graph with the new nodes from the builder
801-
exec_update_result = handle_return(driver.cuGraphExecUpdate(self._mnff.graph, builder._mnff.graph))
802-
if exec_update_result.result != driver.CUgraphExecUpdateResult.CU_GRAPH_EXEC_UPDATE_SUCCESS:
803-
raise RuntimeError(f"Failed to update graph: {exec_update_result.result()}")
806+
if isinstance(source, GraphBuilder):
807+
if not source._building_ended:
808+
raise ValueError("Graph has not finished building.")
809+
cu_graph = <cydriver.CUgraph><intptr_t>int(source._mnff.graph)
810+
elif isinstance(source, GraphDef):
811+
cu_graph = <cydriver.CUgraph><intptr_t>int(source.handle)
812+
else:
813+
raise TypeError(
814+
f"expected GraphBuilder or GraphDef, got {type(source).__name__}")
815+
816+
cdef cydriver.CUgraphExecUpdateResultInfo result_info
817+
cdef cydriver.CUresult err
818+
with nogil:
819+
err = cydriver.cuGraphExecUpdate(cu_exec, cu_graph, &result_info)
820+
if err != cydriver.CUresult.CUDA_SUCCESS:
821+
reason = driver.CUgraphExecUpdateResult(result_info.result)
822+
msg = f"Graph update failed: {reason.__doc__.strip()} ({reason.name})"
823+
raise CUDAError(msg)
804824

805825
def upload(self, stream: Stream):
806826
"""Uploads the graph in a stream.

cuda_core/tests/graph/test_graph_update.py

Lines changed: 113 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,66 @@
55

66
import numpy as np
77
import pytest
8-
from helpers.graph_kernels import compile_conditional_kernels
8+
from helpers.graph_kernels import compile_common_kernels, compile_conditional_kernels
99

1010
from cuda.core import Device, LaunchConfig, LegacyPinnedMemoryResource, launch
11+
from cuda.core._graph._graphdef import GraphDef
12+
from cuda.core._utils.cuda_utils import CUDAError
1113

1214

15+
@pytest.mark.parametrize("builder", ["GraphBuilder", "GraphDef"])
1316
@pytest.mark.skipif(tuple(int(i) for i in np.__version__.split(".")[:2]) < (2, 1), reason="need numpy 2.1.0+")
14-
def test_graph_update(init_cuda):
17+
def test_graph_update_kernel_args(init_cuda, builder):
18+
"""Update redirects a kernel to write to a different pointer."""
19+
mod = compile_common_kernels()
20+
add_one = mod.get_kernel("add_one")
21+
22+
launch_stream = Device().create_stream()
23+
mr = LegacyPinnedMemoryResource()
24+
b = mr.allocate(8)
25+
arr = np.from_dlpack(b).view(np.int32)
26+
arr[0] = 0
27+
arr[1] = 0
28+
29+
if builder == "GraphBuilder":
30+
31+
def build(ptr):
32+
gb = Device().create_graph_builder().begin_building()
33+
launch(gb, LaunchConfig(grid=1, block=1), add_one, ptr)
34+
launch(gb, LaunchConfig(grid=1, block=1), add_one, ptr)
35+
finished = gb.end_building()
36+
return finished.complete(), finished
37+
elif builder == "GraphDef":
38+
39+
def build(ptr):
40+
g = GraphDef()
41+
g.launch(LaunchConfig(grid=1, block=1), add_one, ptr)
42+
g.launch(LaunchConfig(grid=1, block=1), add_one, ptr)
43+
return g.instantiate(), g
44+
45+
graph, _ = build(arr[0:].ctypes.data)
46+
_, source1 = build(arr[1:].ctypes.data)
47+
48+
graph.launch(launch_stream)
49+
launch_stream.sync()
50+
assert arr[0] == 2
51+
assert arr[1] == 0
52+
53+
graph.update(source1)
54+
graph.launch(launch_stream)
55+
launch_stream.sync()
56+
assert arr[0] == 2
57+
assert arr[1] == 2
58+
59+
b.close()
60+
61+
62+
@pytest.mark.skipif(tuple(int(i) for i in np.__version__.split(".")[:2]) < (2, 1), reason="need numpy 2.1.0+")
63+
def test_graph_update_conditional(init_cuda):
64+
"""Update swaps conditional switch graphs with matching topology."""
1565
mod = compile_conditional_kernels(int)
1666
add_one = mod.get_kernel("add_one")
1767

18-
# Allocate memory
1968
launch_stream = Device().create_stream()
2069
mr = LegacyPinnedMemoryResource()
2170
b = mr.allocate(12)
@@ -72,9 +121,6 @@ def build_graph(condition_value):
72121
pytest.skip("Driver does not support conditional switch")
73122

74123
# Launch the first graph
75-
assert arr[0] == 0
76-
assert arr[1] == 0
77-
assert arr[2] == 0
78124
graph = graph_variants[0].complete()
79125
graph.launch(launch_stream)
80126
launch_stream.sync()
@@ -98,4 +144,65 @@ def build_graph(condition_value):
98144
assert arr[1] == 3
99145
assert arr[2] == 3
100146

147+
# Close the memory resource now because the garbage collected might
148+
# de-allocate it during the next graph builder process
101149
b.close()
150+
151+
152+
# =============================================================================
153+
# Error cases
154+
# =============================================================================
155+
156+
157+
def test_graph_update_unfinished_builder(init_cuda):
158+
"""Update with an unfinished GraphBuilder raises ValueError."""
159+
mod = compile_common_kernels()
160+
empty_kernel = mod.get_kernel("empty_kernel")
161+
162+
gb_finished = Device().create_graph_builder().begin_building()
163+
launch(gb_finished, LaunchConfig(grid=1, block=1), empty_kernel)
164+
graph = gb_finished.end_building().complete()
165+
166+
gb_unfinished = Device().create_graph_builder().begin_building()
167+
launch(gb_unfinished, LaunchConfig(grid=1, block=1), empty_kernel)
168+
169+
with pytest.raises(ValueError, match="Graph has not finished building"):
170+
graph.update(gb_unfinished)
171+
172+
gb_unfinished.end_building()
173+
174+
175+
def test_graph_update_topology_mismatch(init_cuda):
176+
"""Update with a different topology raises CUDAError."""
177+
mod = compile_common_kernels()
178+
empty_kernel = mod.get_kernel("empty_kernel")
179+
180+
# Two-node graph
181+
gb1 = Device().create_graph_builder().begin_building()
182+
launch(gb1, LaunchConfig(grid=1, block=1), empty_kernel)
183+
launch(gb1, LaunchConfig(grid=1, block=1), empty_kernel)
184+
graph = gb1.end_building().complete()
185+
186+
# Three-node graph (different topology)
187+
gb2 = Device().create_graph_builder().begin_building()
188+
launch(gb2, LaunchConfig(grid=1, block=1), empty_kernel)
189+
launch(gb2, LaunchConfig(grid=1, block=1), empty_kernel)
190+
launch(gb2, LaunchConfig(grid=1, block=1), empty_kernel)
191+
gb2.end_building()
192+
193+
expected = r"Graph update failed: The update failed because the topology changed \(CU_GRAPH_EXEC_UPDATE_ERROR_TOPOLOGY_CHANGED\)"
194+
with pytest.raises(CUDAError, match=expected):
195+
graph.update(gb2)
196+
197+
198+
def test_graph_update_wrong_type(init_cuda):
199+
"""Update with an invalid type raises TypeError."""
200+
mod = compile_common_kernels()
201+
empty_kernel = mod.get_kernel("empty_kernel")
202+
203+
gb = Device().create_graph_builder().begin_building()
204+
launch(gb, LaunchConfig(grid=1, block=1), empty_kernel)
205+
graph = gb.end_building().complete()
206+
207+
with pytest.raises(TypeError, match="expected GraphBuilder or GraphDef"):
208+
graph.update("not a graph")

0 commit comments

Comments
 (0)