Skip to content

Commit 3db7004

Browse files
committed
Fix wp.static() capturing global variables instead of loop variables
When a global Python variable had the same name as a kernel for-loop variable, wp.static() incorrectly used the global variable's value instead of the loop variable's. The fix tracks loop variables during AST traversal and defers wp.static() evaluation when the expression references a loop variable.

Fixes GH-1139

Signed-off-by: Eric Shi <[email protected]>
1 parent eb3e96a commit 3db7004

File tree

3 files changed

+120
-0
lines changed

3 files changed

+120
-0
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515

1616
### Fixed
1717

18+
- Fix `wp.static()` incorrectly capturing global Python variables instead of loop variables when used inside for-loops
19+
in kernels ([GH-1139](https://github.com/NVIDIA/warp/issues/1139)).
1820
- Fix `--llvm-path` build option to use existing LLVM installation when building `warp-clang` library instead of
1921
downloading from packman.
2022
- Fix excessive memory usage in CUDA graphs with multiple allocations/deallocations

warp/_src/codegen.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3736,9 +3736,37 @@ def evaluate_static_expression(adj, node) -> tuple[Any, str]:
37363736
# expression can be evaluated
37373737
def replace_static_expressions(adj):
37383738
class StaticExpressionReplacer(ast.NodeTransformer):
3739+
def __init__(self):
3740+
# Track loop variable names from enclosing for loops. This prevents
3741+
# wp.static() from capturing a global variable that shadows a loop variable.
3742+
# Uses a counter (not a set) to handle nested loops that reuse the same variable name.
3743+
self.loop_vars = {}
3744+
3745+
def visit_For(self, node):
3746+
# Track loop variable while visiting loop body (simple names only;
3747+
# tuple unpacking like `for x, y in ...` is rare in Warp kernels)
3748+
var_name = node.target.id if isinstance(node.target, ast.Name) else None
3749+
if var_name:
3750+
self.loop_vars[var_name] = self.loop_vars.get(var_name, 0) + 1
3751+
result = self.generic_visit(node)
3752+
if var_name:
3753+
self.loop_vars[var_name] -= 1
3754+
if self.loop_vars[var_name] == 0:
3755+
del self.loop_vars[var_name]
3756+
return result
3757+
37393758
def visit_Call(self, node):
37403759
func, _ = adj.resolve_static_expression(node.func, eval_types=False)
37413760
if adj.is_static_expression(func):
3761+
# If the static expression references an enclosing loop variable,
3762+
# defer evaluation to codegen time when the loop constant is available
3763+
expr_node = node.args[0] if node.args else (node.keywords[0].value if node.keywords else None)
3764+
if expr_node:
3765+
referenced = {n.id for n in ast.walk(expr_node) if isinstance(n, ast.Name)}
3766+
if referenced & self.loop_vars.keys():
3767+
adj.has_unresolved_static_expressions = True
3768+
return self.generic_visit(node)
3769+
37423770
try:
37433771
# the static expression will execute as long as the static expression is valid and
37443772
# only depends on global or captured variables

warp/tests/test_static.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -659,6 +659,90 @@ def test_unresolved_static_expression(test, device):
659659
test.assertEqual(output2.numpy()[0], 2)
660660

661661

662+
# Global variables used to test wp.static() loop variable handling
663+
_global_test_idx = 999
664+
_global_test_j = 888
665+
666+
667+
@wp.kernel
668+
def static_loop_var_kernel(results: wp.array(dtype=int)):
669+
"""Kernel where wp.static() should capture the loop variable, not the global."""
670+
for _global_test_idx in range(3):
671+
results[_global_test_idx] = wp.static(_global_test_idx)
672+
673+
674+
@wp.kernel
675+
def static_loop_var_in_expr_kernel(results: wp.array(dtype=int)):
676+
"""Kernel where loop variable is used in an arithmetic expression."""
677+
for _global_test_idx in range(3):
678+
# Even in complex expressions, the loop variable should be used
679+
results[_global_test_idx] = wp.static(_global_test_idx * 2 + 1)
680+
681+
682+
@wp.kernel
683+
def static_nested_loop_kernel(results: wp.array(dtype=int)):
684+
"""Kernel with nested loops - both loop variables should be protected."""
685+
for _global_test_idx in range(2):
686+
for _global_test_j in range(2):
687+
idx = _global_test_idx * 2 + _global_test_j
688+
results[idx] = wp.static(_global_test_idx * 10 + _global_test_j)
689+
690+
691+
@wp.kernel
692+
def static_nested_loop_same_var_kernel(results: wp.array(dtype=int)):
693+
"""Kernel with nested loops reusing the same variable name.
694+
695+
Tests counter-based loop variable tracking: when inner and outer loops
696+
use the same variable name, the global should still not be captured.
697+
Per Python semantics, after the inner loop the variable has the inner
698+
loop's final value.
699+
"""
700+
idx = 0
701+
for _global_test_idx in range(2):
702+
for _global_test_idx in range(3): # intentional shadowing for test
703+
pass
704+
# Per Python semantics, _global_test_idx is now 2 (inner loop's final value)
705+
# Key: we should NOT capture the global value (999)
706+
results[idx] = wp.static(_global_test_idx)
707+
idx += 1
708+
709+
710+
def test_static_loop_variable_not_shadowed_by_global(test, device):
711+
"""Test that wp.static() inside a for loop correctly captures the loop variable.
712+
713+
When a global Python variable exists with the same name as a kernel loop variable,
714+
wp.static() should use the loop variable's compile-time constant value (0, 1, 2, ...),
715+
not the unrelated global variable. This prevents confusing behavior where the
716+
presence of a global variable silently changes the kernel's output.
717+
"""
718+
with wp.ScopedDevice(device):
719+
# Test 1: Simple loop variable
720+
results = wp.zeros(3, dtype=int)
721+
wp.launch(static_loop_var_kernel, dim=1, inputs=[results])
722+
np.testing.assert_array_equal(results.numpy(), np.array([0, 1, 2]), err_msg="Simple loop variable test failed")
723+
724+
# Test 2: Loop variable in arithmetic expression
725+
results2 = wp.zeros(3, dtype=int)
726+
wp.launch(static_loop_var_in_expr_kernel, dim=1, inputs=[results2])
727+
np.testing.assert_array_equal(
728+
results2.numpy(), np.array([1, 3, 5]), err_msg="Loop variable in expression test failed"
729+
)
730+
731+
# Test 3: Nested loops - both variables protected
732+
results3 = wp.zeros(4, dtype=int)
733+
wp.launch(static_nested_loop_kernel, dim=1, inputs=[results3])
734+
np.testing.assert_array_equal(results3.numpy(), np.array([0, 1, 10, 11]), err_msg="Nested loop test failed")
735+
736+
# Test 4: Nested loops reusing the same variable name
737+
# Tests counter-based tracking: global should not be captured even with shadowing
738+
results4 = wp.zeros(2, dtype=int)
739+
wp.launch(static_nested_loop_same_var_kernel, dim=1, inputs=[results4])
740+
# Per Python semantics: inner loop shadows outer, final value (2) persists
741+
np.testing.assert_array_equal(
742+
results4.numpy(), np.array([2, 2]), err_msg="Nested loop with same variable test failed"
743+
)
744+
745+
662746
devices = get_test_devices()
663747

664748

@@ -689,6 +773,12 @@ def test_static_python_call(self):
689773
add_function_test(TestStatic, "test_static_constant_hash", test_static_constant_hash, devices=None)
690774
add_function_test(TestStatic, "test_static_function_hash", test_static_function_hash, devices=None)
691775
add_function_test(TestStatic, "test_static_len_query", test_static_len_query, devices=None)
776+
add_function_test(
777+
TestStatic,
778+
"test_static_loop_variable_not_shadowed_by_global",
779+
test_static_loop_variable_not_shadowed_by_global,
780+
devices=devices,
781+
)
692782

693783

694784
if __name__ == "__main__":

0 commit comments

Comments
 (0)