Skip to content
Merged
43 changes: 32 additions & 11 deletions cuda_core/cuda/core/_graph/_graph_builder.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import weakref
from dataclasses import dataclass

from libc.stdint cimport intptr_t

from cuda.bindings cimport cydriver

from cuda.core._graph._utils cimport _attach_host_callback_to_graph
Expand All @@ -14,6 +16,7 @@ from cuda.core._utils.cuda_utils cimport HANDLE_RETURN
from cuda.core._utils.version cimport cy_binding_version, cy_driver_version

from cuda.core._utils.cuda_utils import (
CUDAError,
driver,
handle_return,
)
Expand Down Expand Up @@ -783,24 +786,42 @@ class Graph:
"""
return self._mnff.graph

def update(self, builder: GraphBuilder):
"""Update the graph using new build configuration from the builder.
def update(self, source: "GraphBuilder | GraphDef") -> None:
"""Update the graph using a new graph definition.

The topology of the provided builder must be identical to this graph.
The topology of the provided source must be identical to this graph.

Parameters
----------
builder : :obj:`~_graph.GraphBuilder`
The builder to update the graph with.
source : :obj:`~_graph.GraphBuilder` or :obj:`~_graph._graph_def.GraphDef`
The graph definition to update from. A GraphBuilder must have
finished building.

"""
if not builder._building_ended:
raise ValueError("Graph has not finished building.")
from cuda.core._graph._graph_def import GraphDef

cdef cydriver.CUgraph cu_graph
cdef cydriver.CUgraphExec cu_exec = <cydriver.CUgraphExec><intptr_t>int(self._mnff.graph)

# Update the graph with the new nodes from the builder
exec_update_result = handle_return(driver.cuGraphExecUpdate(self._mnff.graph, builder._mnff.graph))
if exec_update_result.result != driver.CUgraphExecUpdateResult.CU_GRAPH_EXEC_UPDATE_SUCCESS:
raise RuntimeError(f"Failed to update graph: {exec_update_result.result()}")
if isinstance(source, GraphBuilder):
if not source._building_ended:
raise ValueError("Graph has not finished building.")
cu_graph = <cydriver.CUgraph><intptr_t>int(source._mnff.graph)
elif isinstance(source, GraphDef):
cu_graph = <cydriver.CUgraph><intptr_t>int(source.handle)
else:
raise TypeError(
f"expected GraphBuilder or GraphDef, got {type(source).__name__}")

cdef cydriver.CUgraphExecUpdateResultInfo result_info
cdef cydriver.CUresult err
with nogil:
err = cydriver.cuGraphExecUpdate(cu_exec, cu_graph, &result_info)
if err == cydriver.CUresult.CUDA_ERROR_GRAPH_EXEC_UPDATE_FAILURE:
reason = driver.CUgraphExecUpdateResult(result_info.result)
msg = f"Graph update failed: {reason.__doc__.strip()} ({reason.name})"
raise CUDAError(msg)
HANDLE_RETURN(err)

def upload(self, stream: Stream):
"""Uploads the graph in a stream.
Expand Down
23 changes: 23 additions & 0 deletions cuda_core/cuda/core/_graph/_graph_def/__init__.pxd
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0

from cuda.core._graph._graph_def._graph_def cimport Condition, GraphDef
from cuda.core._graph._graph_def._graph_node cimport GraphNode
from cuda.core._graph._graph_def._subclasses cimport (
AllocNode,
ChildGraphNode,
ConditionalNode,
EmptyNode,
EventRecordNode,
EventWaitNode,
FreeNode,
HostCallbackNode,
IfElseNode,
IfNode,
KernelNode,
MemcpyNode,
MemsetNode,
SwitchNode,
WhileNode,
)
51 changes: 51 additions & 0 deletions cuda_core/cuda/core/_graph/_graph_def/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0

"""Explicit CUDA graph construction — GraphDef, GraphNode, and node subclasses."""

from cuda.core._graph._graph_def._graph_def import (
Condition,
GraphAllocOptions,
GraphDef,
)
from cuda.core._graph._graph_def._graph_node import GraphNode
from cuda.core._graph._graph_def._subclasses import (
AllocNode,
ChildGraphNode,
ConditionalNode,
EmptyNode,
EventRecordNode,
EventWaitNode,
FreeNode,
HostCallbackNode,
IfElseNode,
IfNode,
KernelNode,
MemcpyNode,
MemsetNode,
SwitchNode,
WhileNode,
)

__all__ = [
"AllocNode",
"ChildGraphNode",
"Condition",
"ConditionalNode",
"EmptyNode",
"EventRecordNode",
"EventWaitNode",
"FreeNode",
"GraphAllocOptions",
"GraphDef",
"GraphNode",
"HostCallbackNode",
"IfElseNode",
"IfNode",
"KernelNode",
"MemcpyNode",
"MemsetNode",
"SwitchNode",
"WhileNode",
]
21 changes: 21 additions & 0 deletions cuda_core/cuda/core/_graph/_graph_def/_graph_def.pxd
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0

from cuda.bindings cimport cydriver
from cuda.core._resource_handles cimport GraphHandle


cdef class Condition:
cdef:
cydriver.CUgraphConditionalHandle _c_handle
object __weakref__


cdef class GraphDef:
cdef:
GraphHandle _h_graph
object __weakref__

@staticmethod
cdef GraphDef _from_handle(GraphHandle h_graph)
Loading
Loading