55
66import numpy as np
77import pytest
8- from helpers .graph_kernels import compile_conditional_kernels
8+ from helpers .graph_kernels import compile_common_kernels , compile_conditional_kernels
99
1010from cuda .core import Device , LaunchConfig , LegacyPinnedMemoryResource , launch
11+ from cuda .core ._graph ._graphdef import GraphDef
12+ from cuda .core ._utils .cuda_utils import CUDAError
1113
1214
@pytest.mark.parametrize("builder", ["GraphBuilder", "GraphDef"])
@pytest.mark.skipif(tuple(int(i) for i in np.__version__.split(".")[:2]) < (2, 1), reason="need numpy 2.1.0+")
def test_graph_update_kernel_args(init_cuda, builder):
    """Update redirects a kernel to write to a different pointer.

    An 8-byte buffer from ``LegacyPinnedMemoryResource`` is viewed as two
    ``int32`` slots. Two graphs with the same shape (two ``add_one`` launches)
    are built: one targeting slot 0, one targeting slot 1. The executable graph
    built for slot 0 is launched, updated from the slot-1 source, and launched
    again; the assertions check that only the targeted slot changes each time.

    Args:
        init_cuda: Fixture; not referenced directly in the body.
        builder: Which graph construction API to exercise, ``"GraphBuilder"``
            or ``"GraphDef"``.
    """
    mod = compile_common_kernels()
    add_one = mod.get_kernel("add_one")

    launch_stream = Device().create_stream()
    mr = LegacyPinnedMemoryResource()
    b = mr.allocate(8)
    # Release the allocation even when a launch or an assertion fails.
    try:
        arr = np.from_dlpack(b).view(np.int32)
        arr[0] = 0
        arr[1] = 0

        if builder == "GraphBuilder":

            def build(ptr):
                gb = Device().create_graph_builder().begin_building()
                launch(gb, LaunchConfig(grid=1, block=1), add_one, ptr)
                launch(gb, LaunchConfig(grid=1, block=1), add_one, ptr)
                finished = gb.end_building()
                # (launchable graph, object usable as an update source)
                return finished.complete(), finished

        elif builder == "GraphDef":

            def build(ptr):
                g = GraphDef()
                g.launch(LaunchConfig(grid=1, block=1), add_one, ptr)
                g.launch(LaunchConfig(grid=1, block=1), add_one, ptr)
                # (launchable graph, object usable as an update source)
                return g.instantiate(), g

        else:
            # Fail with a clear message instead of an unbound ``build`` below.
            raise ValueError(f"unknown builder kind: {builder!r}")

        graph, _ = build(arr[0:].ctypes.data)
        _, source1 = build(arr[1:].ctypes.data)

        # Before the update both launches target slot 0.
        graph.launch(launch_stream)
        launch_stream.sync()
        assert arr[0] == 2
        assert arr[1] == 0

        # After the update the same executable graph targets slot 1.
        graph.update(source1)
        graph.launch(launch_stream)
        launch_stream.sync()
        assert arr[0] == 2
        assert arr[1] == 2
    finally:
        b.close()
60+
61+
62+ @pytest .mark .skipif (tuple (int (i ) for i in np .__version__ .split ("." )[:2 ]) < (2 , 1 ), reason = "need numpy 2.1.0+" )
63+ def test_graph_update_conditional (init_cuda ):
64+ """Update swaps conditional switch graphs with matching topology."""
1565 mod = compile_conditional_kernels (int )
1666 add_one = mod .get_kernel ("add_one" )
1767
18- # Allocate memory
1968 launch_stream = Device ().create_stream ()
2069 mr = LegacyPinnedMemoryResource ()
2170 b = mr .allocate (12 )
@@ -72,9 +121,6 @@ def build_graph(condition_value):
72121 pytest .skip ("Driver does not support conditional switch" )
73122
74123 # Launch the first graph
75- assert arr [0 ] == 0
76- assert arr [1 ] == 0
77- assert arr [2 ] == 0
78124 graph = graph_variants [0 ].complete ()
79125 graph .launch (launch_stream )
80126 launch_stream .sync ()
@@ -98,4 +144,65 @@ def build_graph(condition_value):
98144 assert arr [1 ] == 3
99145 assert arr [2 ] == 3
100146
147+ # Close the memory resource now because the garbage collector might
148+ # de-allocate it during the next graph builder process
101149 b .close ()
150+
151+
152+ # =============================================================================
153+ # Error cases
154+ # =============================================================================
155+
156+
def test_graph_update_unfinished_builder(init_cuda):
    """Update with an unfinished GraphBuilder raises ValueError.

    A one-launch graph is built and completed, then ``update`` is called on it
    with a second builder on which ``end_building()`` has not yet been called.

    Args:
        init_cuda: Fixture; not referenced directly in the body.
    """
    mod = compile_common_kernels()
    empty_kernel = mod.get_kernel("empty_kernel")

    gb_finished = Device().create_graph_builder().begin_building()
    launch(gb_finished, LaunchConfig(grid=1, block=1), empty_kernel)
    graph = gb_finished.end_building().complete()

    gb_unfinished = Device().create_graph_builder().begin_building()
    try:
        launch(gb_unfinished, LaunchConfig(grid=1, block=1), empty_kernel)

        with pytest.raises(ValueError, match="Graph has not finished building"):
            graph.update(gb_unfinished)
    finally:
        # Always finish the build, even if the expectation above fails, so the
        # builder is not left in the building state.
        gb_unfinished.end_building()
173+
174+
def test_graph_update_topology_mismatch(init_cuda):
    """Updating from a source with a different node count raises CUDAError.

    The executable graph is built from two launches of the empty kernel, and
    the update source from three, so the update is expected to be rejected
    with the topology-changed error message.
    """
    kernel = compile_common_kernels().get_kernel("empty_kernel")

    def begin_with_launches(count):
        # Start a fresh builder and record `count` launches of the empty
        # kernel; the caller decides when to end the build.
        builder = Device().create_graph_builder().begin_building()
        for _ in range(count):
            launch(builder, LaunchConfig(grid=1, block=1), kernel)
        return builder

    # Executable graph: two nodes.
    two_node_builder = begin_with_launches(2)
    executable = two_node_builder.end_building().complete()

    # Update source: three nodes, i.e. a different topology.
    three_node_builder = begin_with_launches(3)
    three_node_builder.end_building()

    message_pattern = r"Graph update failed: The update failed because the topology changed \(CU_GRAPH_EXEC_UPDATE_ERROR_TOPOLOGY_CHANGED\)"
    with pytest.raises(CUDAError, match=message_pattern):
        executable.update(three_node_builder)
196+
197+
def test_graph_update_wrong_type(init_cuda):
    """Passing a non-graph object to ``update`` raises TypeError.

    A one-launch graph is built and completed, then ``update`` is called with
    a plain string instead of a graph source.
    """
    kernel = compile_common_kernels().get_kernel("empty_kernel")

    builder = Device().create_graph_builder().begin_building()
    single_thread = LaunchConfig(grid=1, block=1)
    launch(builder, single_thread, kernel)
    executable = builder.end_building().complete()

    # The error message is expected to name the accepted source types.
    with pytest.raises(TypeError, match="expected GraphBuilder or GraphDef"):
        executable.update("not a graph")
0 commit comments