You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
- Add support for dynamic control flow in CUDA graphs, see `wp.capture_if()` and `wp.capture_while()` ([GH-597](https://github.com/NVIDIA/warp/issues/597)).
Copy file name to clipboardExpand all lines: docs/modules/runtime.rst
+223Lines changed: 223 additions & 0 deletions
Display the source diff
Display the rich diff
Original file line number
Diff line number
Diff line change
@@ -921,9 +921,232 @@ Note that only launch calls are recorded in the graph; any Python executed outsi
921
921
Typically it is only beneficial to use CUDA graphs when the graph will be reused or launched multiple times, as
922
922
there is a graph-creation overhead.
923
923
924
+
Conditional Execution
925
+
#####################
926
+
927
+
CUDA 12.4+ supports conditional graph nodes that enable dynamic control flow in CUDA graphs.
928
+
929
+
:func:`wp.capture_if <capture_if>` creates a dynamic branch based on a condition. The condition value is read from a single-element ``int`` array, where a non-zero value means that the condition is True.
930
+
931
+
.. code:: python
932
+
933
+
# create condition
934
+
cond = wp.zeros(1, dtype=int)
935
+
936
+
with wp.ScopedCapture() as capture:
937
+
wp.launch(foo, ...)
938
+
939
+
# execute a branch based on the condition value
940
+
wp.capture_if(cond,
941
+
on_true=...,
942
+
on_false=...)
943
+
944
+
wp.launch(bar, ...)
945
+
946
+
The condition value can be updated by kernels launched prior to ``capture_if()`` in the same graph (e.g. kernel ``foo`` above) or it can be updated by other means before the graph is launched. Note that during graph capture, the value of the condition is ignored. It is only used when the graph is launched, making dynamic control flow possible.
947
+
948
+
.. code:: python
949
+
950
+
# this will execute the `on_true` branch
951
+
cond.fill_(1)
952
+
wp.capture_launch(capture.graph)
953
+
954
+
# this will execute the `on_false` branch
955
+
cond.fill_(0)
956
+
wp.capture_launch(capture.graph)
957
+
958
+
The ``on_true`` and ``on_false`` callbacks can be previously captured :class:`Graph` objects or Python callback functions. These callbacks are captured as child graphs of the enclosing graph. It's possible to specify only one or both callbacks, as needed. When the parent graph is launched, the correct child graph is executed based on the value of the condition. This is done efficiently on the device without involving the CPU.
959
+
960
+
Here is an example that uses previously captured graphs:
961
+
962
+
.. code:: python
963
+
964
+
@wp.kernel
965
+
defhello_kernel():
966
+
print("Hello")
967
+
968
+
@wp.kernel
969
+
defgoodbye_kernel():
970
+
print("Goodbye")
971
+
972
+
@wp.kernel
973
+
defyes_kernel():
974
+
print("Yes!")
975
+
976
+
@wp.kernel
977
+
defno_kernel():
978
+
print("No!")
979
+
980
+
981
+
# create condition
982
+
cond = wp.zeros(1, dtype=int)
983
+
984
+
# capture the on_true branch
985
+
with wp.ScopedCapture() as yes_capture:
986
+
wp.launch(yes_kernel, dim=1)
987
+
988
+
# capture the on_false branch
989
+
with wp.ScopedCapture() as no_capture:
990
+
wp.launch(no_kernel, dim=1)
991
+
992
+
# capture the main graph
993
+
with wp.ScopedCapture() as capture:
994
+
wp.launch(hello_kernel, dim=1)
995
+
996
+
# specify branches using subgraphs
997
+
wp.capture_if(cond,
998
+
on_true=yes_capture.graph,
999
+
on_false=no_capture.graph)
1000
+
1001
+
wp.launch(goodbye_kernel, dim=1)
1002
+
1003
+
# execute on_true branch
1004
+
cond.fill_(1)
1005
+
wp.capture_launch(capture.graph)
1006
+
1007
+
# execute on_false branch
1008
+
cond.fill_(0)
1009
+
wp.capture_launch(capture.graph)
1010
+
1011
+
wp.synchronize_device()
1012
+
1013
+
Here is an example that uses Python callback functions. These callbacks will be captured as child graphs of the main graph:
1014
+
1015
+
.. code:: python
1016
+
1017
+
@wp.kernel
1018
+
defhello_kernel():
1019
+
print("Hello")
1020
+
1021
+
@wp.kernel
1022
+
defgoodbye_kernel():
1023
+
print("Goodbye")
1024
+
1025
+
@wp.kernel
1026
+
defyes_kernel():
1027
+
print("Yes!")
1028
+
1029
+
@wp.kernel
1030
+
defno_kernel():
1031
+
print("No!")
1032
+
1033
+
1034
+
# create condition
1035
+
cond = wp.zeros(1, dtype=int)
1036
+
1037
+
# Python callback for the on_true branch
1038
+
defyes_callback():
1039
+
wp.launch(yes_kernel, dim=1)
1040
+
1041
+
# Python callback for the on_false branch
1042
+
defno_callback():
1043
+
wp.launch(no_kernel, dim=1)
1044
+
1045
+
# capture the main graph
1046
+
with wp.ScopedCapture() as capture:
1047
+
wp.launch(hello_kernel, dim=1)
1048
+
1049
+
# specify branches using Python callback functions
1050
+
wp.capture_if(cond,
1051
+
on_true=yes_callback,
1052
+
on_false=no_callback)
1053
+
1054
+
wp.launch(goodbye_kernel, dim=1)
1055
+
1056
+
# execute on_true branch
1057
+
cond.fill_(1)
1058
+
wp.capture_launch(capture.graph)
1059
+
1060
+
# execute on_false branch
1061
+
cond.fill_(0)
1062
+
wp.capture_launch(capture.graph)
1063
+
1064
+
wp.synchronize_device()
1065
+
1066
+
When using Python callback functions, any extra keyword arguments to :func:`wp.capture_if <capture_if>` are forwarded to the callbacks.
1067
+
1068
+
:func:`wp.capture_while <capture_while>` creates a dynamic loop based on a condition. Similarly to :func:`wp.capture_if <capture_if>`, the condition value is read from a single-element ``int`` array, where a non-zero value means that the condition is True.
1069
+
1070
+
.. code:: python
1071
+
1072
+
# create condition
1073
+
cond = wp.zeros(1, dtype=int)
1074
+
1075
+
with wp.ScopedCapture() as capture:
1076
+
wp.launch(foo, ...)
1077
+
1078
+
# execute the while_body while the condition is true
1079
+
wp.capture_while(cond, while_body=...)
1080
+
1081
+
wp.launch(bar, ...)
1082
+
1083
+
The ``while_body`` callback will be executed as long as the condition is non-zero. The callback is responsible for updating the condition value so that the loop eventually terminates. The ``while_body`` argument can be a previously captured graph or a Python callback function. Here is an example that will run some number of iterations, using the condition value as a counter:
Conditional graph node support is only available if Warp is built using CUDA Toolkit 12.4+ and the NVIDIA driver supports CUDA 12.4+.
1133
+
1134
+
.. note::
1135
+
Due to a current CUDA limitation, graphs with conditional nodes cannot be used as child graphs. It means that it's not possible to create nested conditional constructs using previously captured graphs. If nesting is required, using Python callback functions is the way to go.
1136
+
1137
+
.. note::
1138
+
:func:`wp.capture_if <capture_if>` and :func:`wp.capture_while <capture_while>` will work even without graph capture on any device. If there is no active capture, the condition will be evaluated on the CPU and the correct branch will be executed immediately. This makes it possible to write code that works similarly with and without graph capture.
0 commit comments