Skip to content

Commit 6c226e1

Browse files
committed
Support @Property getters for warp.struct
Signed-off-by: Alexandre Ghelfi <[email protected]>
1 parent eb3e96a commit 6c226e1

File tree

2 files changed

+99
-1
lines changed

2 files changed

+99
-1
lines changed

warp/_src/codegen.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -460,6 +460,7 @@ def __init__(self, key: str, cls: type, module: warp._src.context.Module):
460460
self.cls = cls
461461
self.module = module
462462
self.vars: dict[str, Var] = {}
463+
self.properties: dict[str, warp._src.context.Function] = {}
463464

464465
if isinstance(self.cls, Sequence):
465466
raise RuntimeError("Warp structs must be defined as base classes")
@@ -482,6 +483,12 @@ def __init__(self, key: str, cls: type, module: warp._src.context.Module):
482483
warp.init()
483484
fields.append((label, var.type._type_))
484485

486+
# Collect properties, but postpone Function creation until after native_name is set
487+
property_members = []
488+
for name, item in inspect.getmembers(self.cls):
489+
if isinstance(item, property):
490+
property_members.append((name, item))
491+
485492
class StructType(ctypes.Structure):
486493
# if struct is empty, add a dummy field to avoid launch errors on CPU device ("ffi_prep_cif failed")
487494
_fields_ = fields or [("_dummy_", ctypes.c_byte)]
@@ -502,12 +509,51 @@ class StructType(ctypes.Structure):
502509
if isinstance(type_hint, Struct):
503510
ch.update(type_hint.hash)
504511

505-
self.hash = ch.digest()
512+
# Hash property names (to ensure layout/identity stability if names change)
513+
for name, _ in property_members:
514+
ch.update(bytes(name, "utf-8"))
506515

516+
self.hash = ch.digest()
507517
# generate unique identifier for structs in native code
508518
hash_suffix = f"{self.hash.hex()[:8]}"
509519
self.native_name = f"{self.key}_{hash_suffix}"
510520

521+
# Extract properties and create Functions
522+
# self.native_name is now defined, so Function() can resolve 'self' type code.
523+
for name, item in property_members:
524+
# We currently support only getters
525+
if item.fset is not None:
526+
raise TypeError("Struct properties with setters are not supported")
527+
if item.fdel is not None:
528+
raise TypeError("Struct properties with deleters are not supported")
529+
getter = item.fget
530+
# We need to add 'self' as the first argument, with the type of the struct itself.
531+
# This allows overload resolution to match the struct instance to the 'self' argument.
532+
if not hasattr(getter, "__annotations__"):
533+
getter.__annotations__ = {}
534+
535+
# Find the name of the first argument (conventionally 'self')
536+
argspec = get_full_arg_spec(getter)
537+
if len(argspec.args) > 0:
538+
self_arg = argspec.args[0]
539+
getter.__annotations__[self_arg] = self
540+
541+
# Create the Warp Function.
542+
# We pass 'func=getter' so that input_types and return_types are inferred.
543+
# We set 'namespace=""' and a unique 'native_func' to generate a free function
544+
# in C++ that takes the struct as the first argument (e.g., StructName_propName(struct_inst)).
545+
p_func = warp._src.context.Function(
546+
func=getter,
547+
key=f"{self.key}.{name}",
548+
namespace="",
549+
module=module,
550+
)
551+
552+
# Ensure the C++ function name is unique and predictable
553+
p_func.native_func = f"{self.native_name}_{name}"
554+
555+
self.properties[name] = p_func
556+
511557
# create default constructor (zero-initialize)
512558
self.default_constructor = warp._src.context.Function(
513559
func=None,
@@ -2260,6 +2306,11 @@ def emit_Attribute(adj, node):
22602306
else:
22612307
return adj.add_builtin_call("transform_get_rotation", [aggregate])
22622308

2309+
elif isinstance(aggregate_type, Struct) and node.attr in aggregate_type.properties:
2310+
# property access
2311+
prop = aggregate_type.properties[node.attr]
2312+
return adj.add_call(prop, (aggregate,), {}, {})
2313+
22632314
else:
22642315
attr_var = aggregate_type.vars[node.attr]
22652316

warp/tests/test_struct.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -896,6 +896,53 @@ def test_nested_vec_assignment(self):
896896
)
897897

898898

899+
@wp.struct
900+
class StructWithProperty:
901+
value: float
902+
903+
@property
904+
def neg_value(self) -> float:
905+
return -self.value
906+
907+
908+
@wp.kernel
909+
def kernel_struct_property(s: StructWithProperty, out: wp.array(dtype=float)):
910+
out[0] = s.neg_value
911+
912+
913+
def test_struct_property(test, device):
914+
"""Tests that structs with properties (getters) are supported."""
915+
s = StructWithProperty()
916+
s.value = 42.0
917+
918+
out = wp.zeros(1, dtype=float, device=device)
919+
920+
wp.launch(kernel_struct_property, dim=1, inputs=[s, out], device=device)
921+
922+
assert_np_equal(out.numpy(), np.array([-42.0]))
923+
924+
925+
def test_struct_property_with_setter(test, device):
926+
"""Tests that structs with properties (setters) are not supported."""
927+
with test.assertRaisesRegex(TypeError, "Struct properties with setters are not supported"):
928+
929+
@wp.struct
930+
class StructWithPropertySetter:
931+
value: float
932+
933+
@property
934+
def neg_value(self) -> float:
935+
return -self.value
936+
937+
@neg_value.setter
938+
def neg_value(self, value: float):
939+
self.value = -value
940+
941+
942+
add_function_test(TestStruct, "test_struct_property", test_struct_property, devices=devices)
943+
add_function_test(TestStruct, "test_struct_property_with_setter", test_struct_property_with_setter, devices=devices)
944+
945+
899946
if __name__ == "__main__":
900947
wp.clear_kernel_cache()
901948
unittest.main(verbosity=2)

0 commit comments

Comments
 (0)