Skip to content

Commit 85de3f6

Browse files
Merge branch 'ccrouzet/gh-656-struct-inheritance-error-msg' into 'main'
Inform About Class Inheritance Not Being Supported for `wp.struct` (GH-656) See merge request omniverse/warp!1287
2 parents b2b52a7 + 451100c commit 85de3f6

File tree

9 files changed

+57
-37
lines changed

9 files changed

+57
-37
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@
3939
averaging the shape's and the cloth's coefficients.
4040
- Add damping terms for collisions in `wp.sim.VBDIntegrator`, whose strength is controlled by `Model.soft_contact_kd`.
4141
- Exposed new `warp.fem` operators: `node_count`, `node_index`, `element_coordinates`, `element_closest_point`.
42+
- Inform about class inheritance not being supported for `wp.struct`
43+
([GH-656](https://github.com/NVIDIA/warp/issues/656)).
4244

4345
### Fixed
4446

docs/limitations.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ Structs
6363
-------
6464

6565
* Structs cannot have generic members, i.e. of type ``typing.Any``.
66+
* Structs do not support inheritance. Consider using composition instead.
6667

6768
Volumes
6869
-------

warp/codegen.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -413,12 +413,15 @@ def numpy_value(self):
413413
class Struct:
414414
hash: bytes
415415

416-
def __init__(self, cls: type, key: str, module: warp.context.Module):
416+
def __init__(self, key: str, cls: type, module: warp.context.Module):
417+
self.key = key
417418
self.cls = cls
418419
self.module = module
419-
self.key = key
420420
self.vars: dict[str, Var] = {}
421421

422+
if isinstance(self.cls, Sequence):
423+
raise RuntimeError("Warp structs must be defined as base classes")
424+
422425
annotations = get_annotations(self.cls)
423426
for label, type in annotations.items():
424427
self.vars[label] = Var(label, type)
@@ -489,7 +492,7 @@ class StructType(ctypes.Structure):
489492

490493
self.default_constructor.add_overload(self.value_constructor)
491494

492-
if module:
495+
if isinstance(module, warp.context.Module):
493496
module.register_struct(self)
494497

495498
# Define class for instances of this struct

warp/context.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1106,7 +1106,7 @@ def wrapper(f, *args, **kwargs):
11061106
# decorator to register struct, @struct
11071107
def struct(c: type):
11081108
m = get_module(c.__module__)
1109-
s = warp.codegen.Struct(cls=c, key=warp.codegen.make_full_qualified_name(c), module=m)
1109+
s = warp.codegen.Struct(key=warp.codegen.make_full_qualified_name(c), cls=c, module=m)
11101110
s = functools.update_wrapper(s, c)
11111111
return s
11121112

warp/fem/cache.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,8 +145,8 @@ def get_struct(struct: type, suffix: str):
145145
if key not in _struct_cache:
146146
module = wp.get_module(struct.__module__)
147147
_struct_cache[key] = wp.codegen.Struct(
148-
cls=struct,
149148
key=key,
149+
cls=struct,
150150
module=module,
151151
)
152152

warp/sparse.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -273,8 +273,8 @@ class BsrMatrixTyped(BsrMatrix):
273273

274274
if key not in _struct_cache:
275275
_struct_cache[key] = wp.codegen.Struct(
276-
cls=BsrMatrixTyped,
277276
key=key,
277+
cls=BsrMatrixTyped,
278278
module=module,
279279
)
280280

warp/tests/test_builtins_resolution.py

Lines changed: 30 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -189,73 +189,73 @@ def test_mat_mat_args_support(test, device, dtype):
189189
else:
190190
with test.assertRaisesRegex(
191191
RuntimeError,
192-
r"Couldn't find a function 'ddot' compatible with " r"the arguments 'mat_t, tuple'$",
192+
r"Couldn't find a function 'ddot' compatible with the arguments 'mat_t, tuple'$",
193193
):
194194
wp.ddot(mat_cls(*a_values), b_values)
195195

196196
with test.assertRaisesRegex(
197197
RuntimeError,
198-
r"Couldn't find a function 'ddot' compatible with " r"the arguments 'tuple, tuple'$",
198+
r"Couldn't find a function 'ddot' compatible with the arguments 'tuple, tuple'$",
199199
):
200200
wp.ddot(wpv(dtype, a_values), b_values)
201201

202202
with test.assertRaisesRegex(
203203
RuntimeError,
204-
r"Couldn't find a function 'ddot' compatible with " r"the arguments 'tuple, tuple'$",
204+
r"Couldn't find a function 'ddot' compatible with the arguments 'tuple, tuple'$",
205205
):
206206
wp.ddot(wpm(dtype, 3, a_values), b_values)
207207

208208
with test.assertRaisesRegex(
209209
RuntimeError,
210-
r"Couldn't find a function 'ddot' compatible with " r"the arguments 'tuple, tuple'$",
210+
r"Couldn't find a function 'ddot' compatible with the arguments 'tuple, tuple'$",
211211
):
212212
wp.ddot(npv(np_type, a_values), b_values)
213213

214214
with test.assertRaisesRegex(
215215
RuntimeError,
216-
r"Couldn't find a function 'ddot' compatible with " r"the arguments 'tuple, tuple'$",
216+
r"Couldn't find a function 'ddot' compatible with the arguments 'tuple, tuple'$",
217217
):
218218
wp.ddot(npm(np_type, 3, a_values), b_values)
219219

220220
with test.assertRaisesRegex(
221221
RuntimeError,
222-
r"Couldn't find a function 'ddot' compatible with " r"the arguments 'ndarray, tuple'$",
222+
r"Couldn't find a function 'ddot' compatible with the arguments 'ndarray, tuple'$",
223223
):
224224
wp.ddot(np.array(npv(np_type, a_values)), b_values)
225225

226226
with test.assertRaisesRegex(
227227
RuntimeError,
228-
r"Couldn't find a function 'ddot' compatible with " r"the arguments 'tuple, mat_t'$",
228+
r"Couldn't find a function 'ddot' compatible with the arguments 'tuple, mat_t'$",
229229
):
230230
wp.ddot(a_values, mat_cls(*b_values))
231231

232232
with test.assertRaisesRegex(
233233
RuntimeError,
234-
r"Couldn't find a function 'ddot' compatible with " r"the arguments 'tuple, tuple'$",
234+
r"Couldn't find a function 'ddot' compatible with the arguments 'tuple, tuple'$",
235235
):
236236
wp.ddot(a_values, wpv(dtype, b_values))
237237

238238
with test.assertRaisesRegex(
239239
RuntimeError,
240-
r"Couldn't find a function 'ddot' compatible with " r"the arguments 'tuple, tuple'$",
240+
r"Couldn't find a function 'ddot' compatible with the arguments 'tuple, tuple'$",
241241
):
242242
wp.ddot(a_values, wpm(dtype, 3, b_values))
243243

244244
with test.assertRaisesRegex(
245245
RuntimeError,
246-
r"Couldn't find a function 'ddot' compatible with " r"the arguments 'tuple, tuple'$",
246+
r"Couldn't find a function 'ddot' compatible with the arguments 'tuple, tuple'$",
247247
):
248248
wp.ddot(a_values, npv(np_type, b_values))
249249

250250
with test.assertRaisesRegex(
251251
RuntimeError,
252-
r"Couldn't find a function 'ddot' compatible with " r"the arguments 'tuple, tuple'$",
252+
r"Couldn't find a function 'ddot' compatible with the arguments 'tuple, tuple'$",
253253
):
254254
wp.ddot(a_values, npm(np_type, 3, b_values))
255255

256256
with test.assertRaisesRegex(
257257
RuntimeError,
258-
r"Couldn't find a function 'ddot' compatible with " r"the arguments 'tuple, ndarray'$",
258+
r"Couldn't find a function 'ddot' compatible with the arguments 'tuple, ndarray'$",
259259
):
260260
wp.ddot(a_values, np.array(npv(np_type, b_values)))
261261

@@ -300,37 +300,37 @@ def test_mat_float_args_support(test, device, dtype):
300300
else:
301301
with test.assertRaisesRegex(
302302
RuntimeError,
303-
r"Couldn't find a function 'mul' compatible with " r"the arguments 'mat_t, float'$",
303+
r"Couldn't find a function 'mul' compatible with the arguments 'mat_t, float'$",
304304
):
305305
wp.mul(mat_cls(*a_values), b_value)
306306

307307
with test.assertRaisesRegex(
308308
RuntimeError,
309-
r"Couldn't find a function 'mul' compatible with " r"the arguments 'tuple, float'$",
309+
r"Couldn't find a function 'mul' compatible with the arguments 'tuple, float'$",
310310
):
311311
wp.mul(wpv(dtype, a_values), b_value)
312312

313313
with test.assertRaisesRegex(
314314
RuntimeError,
315-
r"Couldn't find a function 'mul' compatible with " r"the arguments 'tuple, float'$",
315+
r"Couldn't find a function 'mul' compatible with the arguments 'tuple, float'$",
316316
):
317317
wp.mul(wpm(dtype, 3, a_values), b_value)
318318

319319
with test.assertRaisesRegex(
320320
RuntimeError,
321-
r"Couldn't find a function 'mul' compatible with " r"the arguments 'tuple, float'$",
321+
r"Couldn't find a function 'mul' compatible with the arguments 'tuple, float'$",
322322
):
323323
wp.mul(npv(np_type, a_values), b_value)
324324

325325
with test.assertRaisesRegex(
326326
RuntimeError,
327-
r"Couldn't find a function 'mul' compatible with " r"the arguments 'tuple, float'$",
327+
r"Couldn't find a function 'mul' compatible with the arguments 'tuple, float'$",
328328
):
329329
wp.mul(npm(np_type, 3, a_values), b_value)
330330

331331
with test.assertRaisesRegex(
332332
RuntimeError,
333-
r"Couldn't find a function 'mul' compatible with " r"the arguments 'ndarray, float'$",
333+
r"Couldn't find a function 'mul' compatible with the arguments 'ndarray, float'$",
334334
):
335335
wp.mul(np.array(npv(np_type, a_values)), b_value)
336336

@@ -401,49 +401,49 @@ def test_vec_vec_args_support(test, device, dtype):
401401
else:
402402
with test.assertRaisesRegex(
403403
RuntimeError,
404-
r"Couldn't find a function 'dot' compatible with " r"the arguments 'vec_t, tuple'$",
404+
r"Couldn't find a function 'dot' compatible with the arguments 'vec_t, tuple'$",
405405
):
406406
wp.dot(vec_cls(*a_values), b_values)
407407

408408
with test.assertRaisesRegex(
409409
RuntimeError,
410-
r"Couldn't find a function 'dot' compatible with " r"the arguments 'tuple, tuple'$",
410+
r"Couldn't find a function 'dot' compatible with the arguments 'tuple, tuple'$",
411411
):
412412
wp.dot(wpv(dtype, a_values), b_values)
413413

414414
with test.assertRaisesRegex(
415415
RuntimeError,
416-
r"Couldn't find a function 'dot' compatible with " r"the arguments 'tuple, tuple'$",
416+
r"Couldn't find a function 'dot' compatible with the arguments 'tuple, tuple'$",
417417
):
418418
wp.dot(npv(np_type, a_values), b_values)
419419

420420
with test.assertRaisesRegex(
421421
RuntimeError,
422-
r"Couldn't find a function 'dot' compatible with " r"the arguments 'ndarray, tuple'$",
422+
r"Couldn't find a function 'dot' compatible with the arguments 'ndarray, tuple'$",
423423
):
424424
wp.dot(np.array(npv(np_type, a_values)), b_values)
425425

426426
with test.assertRaisesRegex(
427427
RuntimeError,
428-
r"Couldn't find a function 'dot' compatible with " r"the arguments 'tuple, vec_t'$",
428+
r"Couldn't find a function 'dot' compatible with the arguments 'tuple, vec_t'$",
429429
):
430430
wp.dot(a_values, vec_cls(*b_values))
431431

432432
with test.assertRaisesRegex(
433433
RuntimeError,
434-
r"Couldn't find a function 'dot' compatible with " r"the arguments 'tuple, tuple'$",
434+
r"Couldn't find a function 'dot' compatible with the arguments 'tuple, tuple'$",
435435
):
436436
wp.dot(a_values, wpv(dtype, b_values))
437437

438438
with test.assertRaisesRegex(
439439
RuntimeError,
440-
r"Couldn't find a function 'dot' compatible with " r"the arguments 'tuple, tuple'$",
440+
r"Couldn't find a function 'dot' compatible with the arguments 'tuple, tuple'$",
441441
):
442442
wp.dot(a_values, npv(np_type, b_values))
443443

444444
with test.assertRaisesRegex(
445445
RuntimeError,
446-
r"Couldn't find a function 'dot' compatible with " r"the arguments 'tuple, ndarray'$",
446+
r"Couldn't find a function 'dot' compatible with the arguments 'tuple, ndarray'$",
447447
):
448448
wp.dot(a_values, np.array(npv(np_type, b_values)))
449449

@@ -480,25 +480,25 @@ def test_vec_float_args_support(test, device, dtype):
480480
else:
481481
with test.assertRaisesRegex(
482482
RuntimeError,
483-
r"Couldn't find a function 'mul' compatible with " r"the arguments 'vec_t, float'$",
483+
r"Couldn't find a function 'mul' compatible with the arguments 'vec_t, float'$",
484484
):
485485
wp.mul(vec_cls(*a_values), b_value)
486486

487487
with test.assertRaisesRegex(
488488
RuntimeError,
489-
r"Couldn't find a function 'mul' compatible with " r"the arguments 'tuple, float'$",
489+
r"Couldn't find a function 'mul' compatible with the arguments 'tuple, float'$",
490490
):
491491
wp.mul(wpv(dtype, a_values), b_value)
492492

493493
with test.assertRaisesRegex(
494494
RuntimeError,
495-
r"Couldn't find a function 'mul' compatible with " r"the arguments 'tuple, float'$",
495+
r"Couldn't find a function 'mul' compatible with the arguments 'tuple, float'$",
496496
):
497497
wp.mul(npv(np_type, a_values), b_value)
498498

499499
with test.assertRaisesRegex(
500500
RuntimeError,
501-
r"Couldn't find a function 'mul' compatible with " r"the arguments 'ndarray, float'$",
501+
r"Couldn't find a function 'mul' compatible with the arguments 'ndarray, float'$",
502502
):
503503
wp.mul(np.array(npv(np_type, a_values)), b_value)
504504

warp/tests/test_func.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -421,7 +421,7 @@ def test_native_function_error_resolution(self):
421421
b = wp.mat22d(1.0, 2.0, 3.0, 4.0)
422422
with self.assertRaisesRegex(
423423
RuntimeError,
424-
r"^Couldn't find a function 'mul' compatible with " r"the arguments 'mat22f, mat22d'$",
424+
r"^Couldn't find a function 'mul' compatible with the arguments 'mat22f, mat22d'$",
425425
):
426426
a * b
427427

warp/tests/test_struct.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,18 @@ def kernel(foo: Foo):
248248
)
249249

250250

251+
def test_struct_inheritance_error(test, device):
252+
with test.assertRaisesRegex(RuntimeError, r"Warp structs must be defined as base classes$"):
253+
254+
@wp.struct
255+
class Parent:
256+
x: int
257+
258+
@wp.struct
259+
class Child(Parent):
260+
y: int
261+
262+
251263
@wp.kernel
252264
def test_struct_instantiate(data: wp.array(dtype=int)):
253265
baz = Baz(data, wp.vec3(0.0, 0.0, 26.0))
@@ -682,6 +694,8 @@ def test_nested_vec_assignment(self):
682694
)
683695
add_kernel_test(TestStruct, kernel=test_return, name="test_return", dim=1, inputs=[], devices=devices)
684696
add_function_test(TestStruct, "test_nested_struct", test_nested_struct, devices=devices)
697+
add_function_test(TestStruct, "test_struct_attribute_error", test_struct_attribute_error, devices=devices)
698+
add_function_test(TestStruct, "test_struct_inheritance_error", test_struct_inheritance_error, devices=devices)
685699
add_function_test(TestStruct, "test_nested_array_struct", test_nested_array_struct, devices=devices)
686700
add_function_test(TestStruct, "test_convert_to_device", test_convert_to_device, devices=devices)
687701
add_function_test(TestStruct, "test_nested_empty_struct", test_nested_empty_struct, devices=devices)

0 commit comments

Comments
 (0)