Skip to content
23 changes: 23 additions & 0 deletions examples/start_pose.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from stretch_mujoco.stretch_mujoco_simulator import StretchMujocoSimulator
from scipy.spatial.transform import Rotation as R



if __name__ == "__main__":

translation_m = [0.1, -0.2, 0]

# Convert euler angles to quaternion
euler_angles_degrees = [0, 0, -90] # Example: roll, pitch, yaw (x, y, z)
rotation_obj = R.from_euler('xyz', euler_angles_degrees, degrees=True)
rotation_quat = rotation_obj.as_quat(scalar_first=True)


sim = StretchMujocoSimulator(
start_translation=translation_m,
start_rotation_quat=rotation_quat
)

sim.start(headless=False)

while sim.is_running(): ...
33 changes: 27 additions & 6 deletions stretch_mujoco/mujoco_server.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import contextlib
from dataclasses import dataclass
from multiprocessing.managers import DictProxy, SyncManager
import os
import signal
import threading
import time
Expand Down Expand Up @@ -192,20 +193,39 @@ def launch_server(
stop_mujoco_process_event: threading.Event,
data_proxies: MujocoServerProxies,
cameras_to_use: list[StretchCameras],
start_translation: list,
start_rotation_quat: list
):
server = cls(scene_xml_path, model, stop_mujoco_process_event, data_proxies)
server = cls(scene_xml_path, model, stop_mujoco_process_event, data_proxies,start_translation , start_rotation_quat)
server.run(
show_viewer_ui=show_viewer_ui,
camera_hz=camera_hz,
cameras_to_use=cameras_to_use,
)

def change_start_pose(self,xml_path: str, translation: list, rotation_quat: list):
"""Edit the MjSpec and recompile it before loading the model. Mujoco does not allow us to edit body positions at runtime:"""
spec = mujoco.MjSpec.from_file(xml_path)

current_file_path = os.path.abspath(__file__)
current_directory = os.path.dirname(current_file_path)
spec.meshdir = f"{current_directory}/models/assets/"
spec.texturedir = spec.meshdir

spec.find_body("base_link").pos = translation
spec.find_body("base_link").quat = rotation_quat
spec.compile()
return MjModel.from_xml_string(spec.to_xml())


def __init__(
self,
scene_xml_path: str | None,
model: MjModel | None,
stop_mujoco_process_event: threading.Event,
data_proxies: MujocoServerProxies,
start_translation: list,
start_rotation_quat: list
):
"""
Initialize the Simulator handle with a scene
Expand All @@ -215,11 +235,12 @@ def __init__(
"""
if scene_xml_path is None:
scene_xml_path = utils.default_scene_xml_path
self.mjmodel = MjModel.from_xml_path(scene_xml_path)
elif model is None:
self.mjmodel = MjModel.from_xml_path(scene_xml_path)
if model is not None:
self.mjmodel = model

if model is None:
model = self.change_start_pose(translation = start_translation, rotation_quat=start_rotation_quat, xml_path=scene_xml_path)

self.mjmodel = model

self.mjdata = MjData(self.mjmodel)

self._base_in_pos_motion = False
Expand Down
2 changes: 1 addition & 1 deletion stretch_mujoco/mujoco_server_passive.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def _run_ui_simulation(self, show_viewer_ui: bool):
https://mujoco.readthedocs.io/en/stable/python.html#passive-viewer
"""
self.viewer = mujoco.viewer.launch_passive(
self.mjmodel, self.mjdata, show_left_ui=show_viewer_ui, show_right_ui=show_viewer_ui
model=self.mjmodel, data=self.mjdata, show_left_ui=show_viewer_ui, show_right_ui=show_viewer_ui
)

self.viewer._opt.flags[mujoco._enums.mjtVisFlag.mjVIS_RANGEFINDER] = False # Disables the lidar yellow lines.
Expand Down
6 changes: 6 additions & 0 deletions stretch_mujoco/stretch_mujoco_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,17 @@ def __init__(
model: MjModel | None = None,
camera_hz: float = 30,
cameras_to_use: list[StretchCameras] = [],
start_translation: list = [0,0,0],
start_rotation_quat: list = [1,0,0,0]
) -> None:
self.scene_xml_path = scene_xml_path
self.model = model
self.camera_hz = camera_hz
self.urdf_model = utils.URDFmodel()
self._server_process = None
self._cameras_to_use = cameras_to_use
self._start_translation = start_translation
self._start_rotation_quat = start_rotation_quat

self.is_stop_called = False

Expand Down Expand Up @@ -106,6 +110,8 @@ def start(
self._stop_mujoco_process_event,
self.data_proxies,
self._cameras_to_use,
self._start_translation,
self._start_rotation_quat,
),
daemon=False, # We're gonna handle terminating this in stop_mujoco_process()
)
Expand Down
39 changes: 22 additions & 17 deletions tests/test_one_process.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,29 @@
from multiprocessing import Manager
import signal
import threading
from multiprocessing import Manager

from stretch_mujoco.mujoco_server import MujocoServer, MujocoServerProxies

from stretch_mujoco.mujoco_server import MujocoServer
def main():
_manager = Manager()
data_proxies = MujocoServerProxies.default(_manager)

from stretch_mujoco.mujoco_server import MujocoServerProxies
event = threading.Event()
signal.signal(signal.SIGTERM, lambda num, frame: event.set())
signal.signal(signal.SIGINT, lambda num, frame: event.set())

_manager = Manager()
data_proxies = MujocoServerProxies.default(_manager)
MujocoServer.launch_server(
scene_xml_path=None,
model=None,
camera_hz=30,
show_viewer_ui=True,
stop_mujoco_process_event=event,
data_proxies=data_proxies,
cameras_to_use=[],
start_translation=[0, 0, 0],
start_rotation_quat=[1, 0, 0, 0],
)

event = threading.Event()
signal.signal(signal.SIGTERM, lambda num, frame: event.set())
signal.signal(signal.SIGINT, lambda num, frame: event.set())

MujocoServer.launch_server(
scene_xml_path=None,
model=None,
camera_hz=30,
show_viewer_ui=True,
stop_mujoco_process_event=event,
data_proxies=data_proxies,
cameras_to_use=[]
)
if __name__ == "__main__":
main()
36 changes: 22 additions & 14 deletions tests/test_one_process_passive.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,31 @@
from multiprocessing import Manager
import signal
import threading
from multiprocessing import Manager

from stretch_mujoco.mujoco_server import MujocoServerProxies
from stretch_mujoco.mujoco_server_passive import MujocoServerPassive

_manager = Manager()
data_proxies = MujocoServerProxies.default(_manager)
def main():
_manager = Manager()
data_proxies = MujocoServerProxies.default(_manager)

event = threading.Event()
signal.signal(signal.SIGTERM, lambda num, frame: event.set())
signal.signal(signal.SIGINT, lambda num, frame: event.set())
event = threading.Event()
signal.signal(signal.SIGTERM, lambda num, frame: event.set())
signal.signal(signal.SIGINT, lambda num, frame: event.set())

MujocoServerPassive.launch_server(
scene_xml_path=None,
model=None,
camera_hz=30,
show_viewer_ui=True,
stop_mujoco_process_event=event,
data_proxies=data_proxies,
cameras_to_use=[]
MujocoServerPassive.launch_server(
scene_xml_path=None,
model=None,
camera_hz=30,
show_viewer_ui=True,
stop_mujoco_process_event=event,
data_proxies=data_proxies,
cameras_to_use=[],
start_translation=[0, 0, 0],
start_rotation_quat=[1, 0, 0, 0],
)



if __name__ == "__main__":
main()