diff --git a/genmetaballs/src/cuda/bindings.cu b/genmetaballs/src/cuda/bindings.cu index d525ef4..fa5c018 100644 --- a/genmetaballs/src/cuda/bindings.cu +++ b/genmetaballs/src/cuda/bindings.cu @@ -23,6 +23,8 @@ template void bind_image(nb::module_& m, const char* name); template void bind_image_view(nb::module_& m, const char* name); +template +void bind_fmb_scene(nb::module_& m, const char* name); NB_MODULE(_genmetaballs_bindings, m) { @@ -66,6 +68,8 @@ NB_MODULE(_genmetaballs_bindings, m) { "apply the inverse covariance matrix to the given vector", nb::arg("vec")) .def("quadratic_form", &FMB::quadratic_form, "Evaluate the associated quadratic form at the given vector", nb::arg("vec")); + bind_fmb_scene(fmb, "CPUFMBScene"); + bind_fmb_scene(fmb, "GPUFMBScene"); /* * Geometry module bindings @@ -244,3 +248,16 @@ void bind_image(nb::module_& m, const char* name) { return nb::str("{}(height={}, width={})").format(name, img.num_rows(), img.num_cols()); }); } + +template +void bind_fmb_scene(nb::module_& m, const char* name) { + nb::class_>(m, name) + .def(nb::init(), nb::arg("size")) + .def_prop_ro("size", &FMBScene::size) + .def("__len__", &FMBScene::size) + .def("__getitem__", &FMBScene::get_fmb, nb::arg("idx"), + "Get the (FMB, log_weight) tuple at index i") + .def("__repr__", [=](const FMBScene& scene) { + return nb::str("{}(size={})").format(name, scene.size()); + }); +} diff --git a/genmetaballs/src/genmetaballs/core/__init__.py b/genmetaballs/src/genmetaballs/core/__init__.py index e007205..530f526 100644 --- a/genmetaballs/src/genmetaballs/core/__init__.py +++ b/genmetaballs/src/genmetaballs/core/__init__.py @@ -10,6 +10,7 @@ TwoParameterConfidence, ZeroParameterConfidence, ) +from genmetaballs._genmetaballs_bindings.fmb import CPUFMBScene, GPUFMBScene from genmetaballs._genmetaballs_bindings.image import CPUImage, GPUImage from genmetaballs._genmetaballs_bindings.utils import CPUFloatArray2D, GPUFloatArray2D, sigmoid @@ -47,6 +48,21 @@ def make_image(height: int, width: int, device: DeviceType) -> CPUImage | GPUIma raise ValueError(f"Unsupported device type: {device}") +def make_fmb_scene(size: int, device: DeviceType) -> CPUFMBScene | GPUFMBScene: + """Create an FMBScene on the specified device. + + Args: + size: The number of FMBs in the scene. + device: 'cpu' or 'gpu' to specify the target device. + """ + if device == "cpu": + return CPUFMBScene(size) + elif device == "gpu": + return GPUFMBScene(size) + else: + raise ValueError(f"Unsupported device type: {device}") + + __all__ = [ "array2d_float", "ZeroParameterConfidence", @@ -60,4 +76,5 @@ def make_image(height: int, width: int, device: DeviceType) -> CPUImage | GPUIma "FourParameterBlender", "ThreeParameterBlender", "make_image", + "make_fmb_scene", ] diff --git a/tests/python_tests/test_fmb.py b/tests/python_tests/test_fmb.py index fd020ef..7044234 100644 --- a/tests/python_tests/test_fmb.py +++ b/tests/python_tests/test_fmb.py @@ -3,7 +3,7 @@ from scipy.spatial.distance import mahalanobis from scipy.spatial.transform import Rotation as Rot -from genmetaballs.core import fmb, geometry +from genmetaballs.core import fmb, geometry, make_fmb_scene FMB = fmb.FMB Pose, Vec3D, Rotation = geometry.Pose, geometry.Vec3D, geometry.Rotation @@ -38,3 +38,13 @@ def test_fmb_quadratic_form(rng): FMB(pose, *extent).quadratic_form(Vec3D(*vec)), mahalanobis(vec, tran, np.linalg.inv(cov)) ** 2, ) + + +def test_fmb_scene_creation(): + cpu_scene = make_fmb_scene(10, device="cpu") + assert isinstance(cpu_scene, fmb.CPUFMBScene) + assert len(cpu_scene) == 10 + + gpu_scene = make_fmb_scene(20, device="gpu") + assert isinstance(gpu_scene, fmb.GPUFMBScene) + assert len(gpu_scene) == 20