Commit f579744

Merge branch 'experimental_cuda_graph_branching' into 'main'
Conditional graph support (GH-597)

See merge request omniverse/warp!1192
2 parents 38eb215 + f6419a9 commit f579744

9 files changed: +2126 -21 lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -23,6 +23,7 @@
   the points as USD spheres using a point instancer, or as simple USD points otherwise.
 - Add `wp.tile_squeeze()` ([GH-662](https://github.com/NVIDIA/warp/issues/662)).
 - Add `wp.tile_reshape()` ([GH-663](https://github.com/NVIDIA/warp/issues/663)).
+- Add support for dynamic control flow in CUDA graphs, see `wp.capture_if()` and `wp.capture_while()` ([GH-597](https://github.com/NVIDIA/warp/issues/597)).
 - Support tile inplace add/subtract operations
   ([GH-518](https://github.com/NVIDIA/warp/issues/518)).
 - Add support for in-place tile component addition and subtraction

docs/modules/runtime.rst

Lines changed: 223 additions & 0 deletions
@@ -921,9 +921,232 @@ Note that only launch calls are recorded in the graph; any Python executed outsi
 Typically it is only beneficial to use CUDA graphs when the graph will be reused or launched multiple times, as
 there is a graph-creation overhead.
 
+Conditional Execution
+#####################
+
+CUDA 12.4+ supports conditional graph nodes that enable dynamic control flow in CUDA graphs.
+
+:func:`wp.capture_if <capture_if>` creates a dynamic branch based on a condition. The condition value is read from a single-element ``int`` array, where a non-zero value means that the condition is True.
+
+.. code:: python
+
+    # create condition
+    cond = wp.zeros(1, dtype=int)
+
+    with wp.ScopedCapture() as capture:
+        wp.launch(foo, ...)
+
+        # execute a branch based on the condition value
+        wp.capture_if(cond,
+                      on_true=...,
+                      on_false=...)
+
+        wp.launch(bar, ...)
+
+The condition value can be updated by kernels launched prior to ``capture_if()`` in the same graph (e.g. kernel ``foo`` above) or it can be updated by other means before the graph is launched. Note that during graph capture, the value of the condition is ignored. It is only used when the graph is launched, making dynamic control flow possible.
+
+.. code:: python
+
+    # this will execute the `on_true` branch
+    cond.fill_(1)
+    wp.capture_launch(capture.graph)
+
+    # this will execute the `on_false` branch
+    cond.fill_(0)
+    wp.capture_launch(capture.graph)
+
+The ``on_true`` and ``on_false`` callbacks can be previously captured :class:`Graph` objects or Python callback functions. These callbacks are captured as child graphs of the enclosing graph. It's possible to specify only one or both callbacks, as needed. When the parent graph is launched, the correct child graph is executed based on the value of the condition. This is done efficiently on the device without involving the CPU.
+
+Here is an example that uses previously captured graphs:
+
+.. code:: python
+
+    @wp.kernel
+    def hello_kernel():
+        print("Hello")
+
+    @wp.kernel
+    def goodbye_kernel():
+        print("Goodbye")
+
+    @wp.kernel
+    def yes_kernel():
+        print("Yes!")
+
+    @wp.kernel
+    def no_kernel():
+        print("No!")
+
+
+    # create condition
+    cond = wp.zeros(1, dtype=int)
+
+    # capture the on_true branch
+    with wp.ScopedCapture() as yes_capture:
+        wp.launch(yes_kernel, dim=1)
+
+    # capture the on_false branch
+    with wp.ScopedCapture() as no_capture:
+        wp.launch(no_kernel, dim=1)
+
+    # capture the main graph
+    with wp.ScopedCapture() as capture:
+        wp.launch(hello_kernel, dim=1)
+
+        # specify branches using subgraphs
+        wp.capture_if(cond,
+                      on_true=yes_capture.graph,
+                      on_false=no_capture.graph)
+
+        wp.launch(goodbye_kernel, dim=1)
+
+    # execute on_true branch
+    cond.fill_(1)
+    wp.capture_launch(capture.graph)
+
+    # execute on_false branch
+    cond.fill_(0)
+    wp.capture_launch(capture.graph)
+
+    wp.synchronize_device()
+
+Here is an example that uses Python callback functions. These callbacks will be captured as child graphs of the main graph:
+
+.. code:: python
+
+    @wp.kernel
+    def hello_kernel():
+        print("Hello")
+
+    @wp.kernel
+    def goodbye_kernel():
+        print("Goodbye")
+
+    @wp.kernel
+    def yes_kernel():
+        print("Yes!")
+
+    @wp.kernel
+    def no_kernel():
+        print("No!")
+
+
+    # create condition
+    cond = wp.zeros(1, dtype=int)
+
+    # Python callback for the on_true branch
+    def yes_callback():
+        wp.launch(yes_kernel, dim=1)
+
+    # Python callback for the on_false branch
+    def no_callback():
+        wp.launch(no_kernel, dim=1)
+
+    # capture the main graph
+    with wp.ScopedCapture() as capture:
+        wp.launch(hello_kernel, dim=1)
+
+        # specify branches using Python callback functions
+        wp.capture_if(cond,
+                      on_true=yes_callback,
+                      on_false=no_callback)
+
+        wp.launch(goodbye_kernel, dim=1)
+
+    # execute on_true branch
+    cond.fill_(1)
+    wp.capture_launch(capture.graph)
+
+    # execute on_false branch
+    cond.fill_(0)
+    wp.capture_launch(capture.graph)
+
+    wp.synchronize_device()
+
+When using Python callback functions, any extra keyword arguments to :func:`wp.capture_if <capture_if>` are forwarded to the callbacks.
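
The keyword forwarding described in the added paragraph can be pictured with a small host-side stand-in. This is only a sketch in plain Python, not Warp code: it assumes a one-element list in place of the single-element ``int`` array, and ``capture_if_model``, ``log``, ``label``, and ``scale`` are hypothetical names introduced purely for illustration.

```python
from typing import Callable, List, Optional


def capture_if_model(
    cond: List[int],
    on_true: Optional[Callable[..., None]] = None,
    on_false: Optional[Callable[..., None]] = None,
    **kwargs,
) -> None:
    """Hypothetical host-side model of the branching (not Warp's implementation)."""
    # A non-zero first element means the condition is true.
    branch = on_true if cond[0] != 0 else on_false
    # Either branch may be omitted, in which case nothing runs.
    if branch is not None:
        # Extra keyword arguments are passed through to the callback unchanged.
        branch(**kwargs)


log: List[str] = []


def yes_callback(label: str, scale: int) -> None:
    log.append(f"yes:{label}:{scale}")


def no_callback(label: str, scale: int) -> None:
    log.append(f"no:{label}:{scale}")


cond = [1]
capture_if_model(cond, on_true=yes_callback, on_false=no_callback, label="step", scale=2)

cond[0] = 0
capture_if_model(cond, on_true=yes_callback, on_false=no_callback, label="step", scale=2)

# Only on_true is given and the condition is zero, so this call does nothing.
capture_if_model(cond, on_true=yes_callback, label="step", scale=2)

print(log)
```

Both callbacks receive the same keyword arguments, so in this model each branch has to accept every keyword that is passed.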
+
+:func:`wp.capture_while <capture_while>` creates a dynamic loop based on a condition. Similarly to :func:`wp.capture_if <capture_if>`, the condition value is read from a single-element ``int`` array, where a non-zero value means that the condition is True.
+
+.. code:: python
+
+    # create condition
+    cond = wp.zeros(1, dtype=int)
+
+    with wp.ScopedCapture() as capture:
+        wp.launch(foo, ...)
+
+        # execute the while_body while the condition is true
+        wp.capture_while(cond, while_body=...)
+
+        wp.launch(bar, ...)
+
+The ``while_body`` callback will be executed as long as the condition is non-zero. The callback is responsible for updating the condition value so that the loop eventually terminates. The ``while_body`` argument can be a previously captured graph or a Python callback function. Here is an example that will run some number of iterations, using the condition value as a counter:
+
+.. code:: python
+
+    @wp.kernel
+    def hello_kernel():
+        print("Hello")
+
+    @wp.kernel
+    def goodbye_kernel():
+        print("Goodbye")
+
+    @wp.kernel
+    def body_kernel(cond: wp.array(dtype=int)):
+        tid = wp.tid()
+        print(cond[0])
+        # decrement the condition counter
+        if tid == 0:
+            cond[0] -= 1
+
+
+    # create condition
+    cond = wp.zeros(1, dtype=int)
+
+    # capture the while_body
+    with wp.ScopedCapture() as body_capture:
+        wp.launch(body_kernel, dim=1, inputs=[cond])
+
+    # capture the main graph
+    with wp.ScopedCapture() as capture:
+        wp.launch(hello_kernel, dim=1)
+
+        # dynamic loop
+        wp.capture_while(cond, while_body=body_capture.graph)
+
+        wp.launch(goodbye_kernel, dim=1)
+
+    # loop 5 times
+    cond.fill_(5)
+    wp.capture_launch(capture.graph)
+
+    # loop 2 times
+    cond.fill_(2)
+    wp.capture_launch(capture.graph)
+
+    wp.synchronize_device()
+
+
+.. note::
+    Conditional graph node support is only available if Warp is built using CUDA Toolkit 12.4+ and the NVIDIA driver supports CUDA 12.4+.
+
+.. note::
+    Due to a current CUDA limitation, graphs with conditional nodes cannot be used as child graphs. It means that it's not possible to create nested conditional constructs using previously captured graphs. If nesting is required, using Python callback functions is the way to go.
+
+.. note::
+    :func:`wp.capture_if <capture_if>` and :func:`wp.capture_while <capture_while>` will work even without graph capture on any device. If there is no active capture, the condition will be evaluated on the CPU and the correct branch will be executed immediately. This makes it possible to write code that works similarly with and without graph capture.
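
The no-capture behaviour mentioned in the last note, applied to a loop, might look roughly like the following host-side model. It is a sketch under stated assumptions rather than Warp's implementation: it assumes the condition is simply re-read on the host before each iteration, it uses a one-element list instead of a Warp array, and ``capture_while_model`` and ``seen`` are hypothetical names.

```python
from typing import Callable, List


def capture_while_model(cond: List[int], while_body: Callable[[], None]) -> int:
    """Hypothetical host-side model of the loop; returns the iteration count."""
    iterations = 0
    # Keep running the body for as long as the condition is non-zero.
    while cond[0] != 0:
        while_body()
        iterations += 1
    return iterations


cond = [0]
seen: List[int] = []


def body() -> None:
    # Record the counter, then decrement it so that the loop terminates.
    seen.append(cond[0])
    cond[0] -= 1


# Loop 5 times, then 2 times, mirroring the counter example in the diff.
cond[0] = 5
first = capture_while_model(cond, body)

cond[0] = 2
second = capture_while_model(cond, body)

print(first, second, seen)
```

As in the documented example, the body is responsible for driving the condition to zero; a body that never updates ``cond`` would make this model loop forever.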
+
+
+
+Graph API Reference
+###################
+
 .. autofunction:: capture_begin
 .. autofunction:: capture_end
 .. autofunction:: capture_launch
+.. autofunction:: capture_if
+.. autofunction:: capture_while
 
 .. autoclass:: ScopedCapture
     :members:

warp/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -78,7 +78,7 @@
     event_from_ipc_handle,
 )
 from warp.context import set_module_options, get_module_options, get_module
-from warp.context import capture_begin, capture_end, capture_launch
+from warp.context import capture_begin, capture_end, capture_launch, capture_if, capture_while
 from warp.context import Kernel, Function, Launch
 from warp.context import Stream, get_stream, set_stream, wait_stream, synchronize_stream
 from warp.context import Event, record_event, wait_event, synchronize_event, get_event_elapsed_time
