@@ -93,6 +93,14 @@ def step(self, action: np.ndarray):
9393 """
9494 pass
9595
96+ @abstractmethod
97+ def reset (self , reset_state : np .ndarray ):
98+ """Reset state of the robot accordingly to the action mode.
99+
100+ :param reset_state: Target reset state of robot actuators.
101+ """
102+ pass
103+
96104
97105class TorqueActionMode (ActionMode ):
98106 """Control all joints through torque control.
@@ -140,6 +148,27 @@ def step(self, action: np.ndarray):
140148 self ._robot .grippers [side ].set_control (action )
141149 self ._mojo .step ()
142150
151+ def reset (self , reset_state : np .ndarray ):
152+ """See base."""
153+ if len (reset_state ) != len (self ._robot .limb_actuators ):
154+ raise ValueError (
155+ f"Mismatch between reset_state length "
156+ f"({ len (reset_state )} ) "
157+ f"and number of actuators ({ len (self ._robot .limb_actuators )} ). "
158+ f"Ensure reset_state matches the actuators count in the model."
159+ )
160+ for value , actuator in zip (reset_state , self ._robot .limb_actuators ):
161+ if actuator .joint :
162+ joint = self ._mojo .physics .bind (actuator .joint )
163+ joint .qpos = value
164+ joint .qvel *= 0
165+ joint .qacc *= 0
166+ elif actuator .tendon :
167+ warnings .warn (
168+ f"Tendon actuators are not fully supported "
169+ f"for { self .__class__ .__name__ } action mode."
170+ )
171+
143172
144173class JointPositionActionMode (ActionMode ):
145174 """Control all joints through joint position.
@@ -220,6 +249,36 @@ def step(self, action: np.ndarray):
220249 else :
221250 self ._mojo .step ()
222251
252+ def reset (self , reset_state : np .ndarray ):
253+ """See base."""
254+ if len (reset_state ) != len (self ._robot .limb_actuators ):
255+ raise ValueError (
256+ f"Mismatch between reset_state length "
257+ f"({ len (reset_state )} ) "
258+ f"and number of actuators ({ len (self ._robot .limb_actuators )} ). "
259+ f"Ensure reset_state matches the actuators count in the model."
260+ )
261+ for value , actuator in zip (reset_state , self ._robot .limb_actuators ):
262+ if actuator .joint :
263+ bound_joint = self ._mojo .physics .bind (actuator .joint )
264+ bound_joint .qpos = value
265+ bound_joint .qvel *= 0
266+ bound_joint .qacc *= 0
267+ elif actuator .tendon :
268+ if actuator .tendon .joint is None or len (actuator .tendon .joint ) == 0 :
269+ raise RuntimeError (
270+ "Currently only fixed tendons with joints are supported."
271+ )
272+ joint_value = value / len (actuator .tendon .joint )
273+ for tendon_joint in actuator .tendon .joint :
274+ value_coefficient = tendon_joint .coef
275+ bound_joint = self ._mojo .physics .bind (tendon_joint .joint )
276+ bound_joint .qpos = joint_value / value_coefficient
277+ bound_joint .qvel *= 0
278+ bound_joint .qacc *= 0
279+ bound_actuator = self ._mojo .physics .bind (actuator )
280+ bound_actuator .ctrl = value
281+
223282 def _step_until_reached (self ):
224283 """Step physics until the target position is reached."""
225284 steps_counter = 0
0 commit comments