diff --git a/docs/environments.md b/docs/environments.md index fa80464df..782b52745 100644 --- a/docs/environments.md +++ b/docs/environments.md @@ -468,6 +468,25 @@ The OpenEnv community has built a catalog of ready-to-run environments that cove +
+
+ dm_control +

+ MuJoCo-based continuous control tasks from DeepMind's dm_control suite, including cartpole, hopper, quadruped, and walker.

+
+ +
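For a quick local try-out without Docker, the client added in this PR also exposes a `from_direct` helper that launches an embedded uvicorn server in a subprocess and returns a connected client. A minimal sketch, assuming the dm_control server dependencies are installed and the package layout introduced in this PR:

```python
# Hedged sketch based on DMControlEnv.from_direct (envs/dm_control_env/client.py).
# from_direct starts a local uvicorn subprocess and returns a connected client.
from envs.dm_control_env import DMControlAction, DMControlEnv

client = DMControlEnv.from_direct(domain_name="cartpole", task_name="balance", port=8765)
try:
    result = client.reset()
    for _ in range(50):
        # cartpole has a 1-dimensional action (force on the cart)
        result = client.step(DMControlAction(values=[0.0]))
        if result.done:
            result = client.reset()
finally:
    client.close()  # clean up, as in the from_direct docstring example
```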
+ diff --git a/docs/environments/dm-control.md b/docs/environments/dm-control.md new file mode 100644 index 000000000..835153b35 --- /dev/null +++ b/docs/environments/dm-control.md @@ -0,0 +1 @@ +--8<-- "../../envs/dm_control_env/README.md" diff --git a/envs/dm_control_env/README.md b/envs/dm_control_env/README.md new file mode 100644 index 000000000..c8069760d --- /dev/null +++ b/envs/dm_control_env/README.md @@ -0,0 +1,167 @@ +--- +title: dm_control Environment Server +emoji: 🤖 +colorFrom: green +colorTo: blue +sdk: docker +pinned: false +app_port: 8000 +base_path: /web +tags: + - openenv +--- + +# dm_control OpenEnv Environment + +A generic OpenEnv environment for [dm_control.suite](https://github.com/google-deepmind/dm_control), providing access to all MuJoCo-based continuous control tasks. + +

+  <img src="assets/cartpole.png" alt="Cartpole Balance" />
+  <img src="assets/quadruped.png" alt="Quadruped Walk" />

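The table in the next section lists the main domains and tasks. To enumerate every benchmark domain/task pair yourself, you can iterate `dm_control.suite.BENCHMARKING` directly; this is a condensed sketch of what `examples/list_environments.py` in this PR does:

```python
# Condensed sketch of examples/list_environments.py: group the benchmark
# suite's (domain, task) pairs by domain and print them.
from dm_control import suite

domains: dict[str, list[str]] = {}
for domain, task in suite.BENCHMARKING:
    domains.setdefault(domain, []).append(task)

for domain in sorted(domains):
    print(f"{domain}: {', '.join(sorted(domains[domain]))}")
```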
+ +## Supported Environments + +| Domain | Tasks | +|--------|-------| +| cartpole | balance, swingup, swingup_sparse | +| walker | stand, walk, run | +| humanoid | stand, walk, run | +| cheetah | run | +| hopper | stand, hop | +| reacher | easy, hard | +| pendulum | swingup | +| finger | spin, turn_easy, turn_hard | +| fish | upright, swim | +| ball_in_cup | catch | +| And more... | See `dm_control.suite.BENCHMARKING` | + +## Quick Start + +### Using the Client + +```python +from envs.dm_control_env import DMControlEnv, DMControlAction + +# Connect to a running server +with DMControlEnv(base_url="http://localhost:8000") as env: + # Reset with default (cartpole/balance) + result = env.reset() + print(f"Observations: {result.observation.observations.keys()}") + + # Take actions + for _ in range(100): + action = DMControlAction(values=[0.5]) # Push cart right + result = env.step(action) + print(f"Reward: {result.reward}, Done: {result.done}") + + if result.done: + result = env.reset() +``` + +### Switching Environments + +```python +# Start with cartpole +result = env.reset(domain_name="cartpole", task_name="balance") + +# Switch to walker (on next reset) +result = env.reset(domain_name="walker", task_name="walk") +# Note: walker has 6 action dimensions +action = DMControlAction(values=[0.0] * 6) +result = env.step(action) +``` + +### Running the Server + +```bash +# From OpenEnv root +cd envs/dm_control_env +uvicorn server.app:app --host 0.0.0.0 --port 8000 + +# Or using uv +uv run --project . server +``` + +### Using Docker + +```bash +# Build +docker build -t dm_control:latest -f server/Dockerfile . + +# Run +docker run -p 8000:8000 dm_control:latest +``` + +## API + +### Action + +```python +class DMControlAction(Action): + values: List[float] # Continuous action values +``` + +Action dimensions vary by environment: +- cartpole: 1 (force on cart) +- walker: 6 (joint torques) +- humanoid: 21 (joint torques) + +### Observation + +```python +class DMControlObservation(Observation): + observations: Dict[str, List[float]] # Named observation arrays + pixels: Optional[str] # Base64 PNG (if render=True) + reward: float + done: bool +``` + +### State + +```python +class DMControlState(State): + domain_name: str + task_name: str + action_spec: Dict[str, Any] + observation_spec: Dict[str, Any] + physics_timestep: float + control_timestep: float + episode_id: str + step_count: int +``` + +## Examples + +See the `examples/` directory: +- `cartpole_control.py` - Interactive cartpole control with arrow keys +- `hopper_control.py` - Interactive hopper control with spacebar for random forces +- `quadruped_control.py` - Interactive quadruped control with spacebar for random forces +- `list_environments.py` - Print all available environments + +All examples support consistent CLI arguments: + +```bash +# Default: interactive mode with minimal pygame window +python examples/cartpole_control.py + +# Visual mode with rendered MuJoCo frames +python examples/cartpole_control.py --visual + +# Headless mode (no pygame, automated control) +python examples/cartpole_control.py --headless --max-steps 500 + +# Select a different task +python examples/cartpole_control.py --task swingup +python examples/hopper_control.py --task stand +python examples/quadruped_control.py --task run +``` + +## Environment Variables + +| Variable | Default | Description | +|----------|---------|-------------| +| `DMCONTROL_DOMAIN` | cartpole | Default domain | +| `DMCONTROL_TASK` | balance | Default task | +| `DMCONTROL_RENDER_HEIGHT` | 
480 | Render height | +| `DMCONTROL_RENDER_WIDTH` | 640 | Render width | diff --git a/envs/dm_control_env/__init__.py b/envs/dm_control_env/__init__.py new file mode 100644 index 000000000..4de4129cd --- /dev/null +++ b/envs/dm_control_env/__init__.py @@ -0,0 +1,14 @@ +"""dm_control OpenEnv Environment. + +A generic OpenEnv environment for dm_control.suite supporting all domains/tasks. +""" + +from .models import DMControlAction, DMControlObservation, DMControlState +from .client import DMControlEnv + +__all__ = [ + "DMControlAction", + "DMControlObservation", + "DMControlState", + "DMControlEnv", +] diff --git a/envs/dm_control_env/assets/cartpole.png b/envs/dm_control_env/assets/cartpole.png new file mode 100644 index 000000000..c4982fc31 Binary files /dev/null and b/envs/dm_control_env/assets/cartpole.png differ diff --git a/envs/dm_control_env/assets/quadruped.png b/envs/dm_control_env/assets/quadruped.png new file mode 100644 index 000000000..967027853 Binary files /dev/null and b/envs/dm_control_env/assets/quadruped.png differ diff --git a/envs/dm_control_env/client.py b/envs/dm_control_env/client.py new file mode 100644 index 000000000..cc77d9fe0 --- /dev/null +++ b/envs/dm_control_env/client.py @@ -0,0 +1,375 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +dm_control Environment Client. + +This module provides the client for connecting to a dm_control +Environment server via WebSocket for persistent sessions. +""" + +from typing import Any, Dict, List, Optional, Tuple + +try: + from openenv.core.client_types import StepResult + from openenv.core.env_client import EnvClient + + from .models import ( + AVAILABLE_ENVIRONMENTS, + DMControlAction, + DMControlObservation, + DMControlState, + ) +except ImportError: + from openenv.core.client_types import StepResult + from openenv.core.env_client import EnvClient + + try: + from models import ( + AVAILABLE_ENVIRONMENTS, + DMControlAction, + DMControlObservation, + DMControlState, + ) + except ImportError: + try: + from dm_control_env.models import ( + AVAILABLE_ENVIRONMENTS, + DMControlAction, + DMControlObservation, + DMControlState, + ) + except ImportError: + from envs.dm_control_env.models import ( + AVAILABLE_ENVIRONMENTS, + DMControlAction, + DMControlObservation, + DMControlState, + ) + + +class DMControlEnv(EnvClient[DMControlAction, DMControlObservation, DMControlState]): + """ + Client for dm_control.suite environments. + + This client maintains a persistent WebSocket connection to the environment + server, enabling efficient multi-step interactions with lower latency. + + Supported Environments (via dm_control.suite): + - cartpole: balance, swingup, swingup_sparse + - walker: stand, walk, run + - humanoid: stand, walk, run + - cheetah: run + - hopper: stand, hop + - reacher: easy, hard + - And many more... + + Example: + >>> # Connect to a running server + >>> with DMControlEnv(base_url="http://localhost:8000") as client: + ... result = client.reset() + ... print(f"Observations: {result.observation.observations.keys()}") + ... + ... # Take action (cartpole: push right) + ... result = client.step(DMControlAction(values=[0.5])) + ... 
print(f"Reward: {result.reward}") + + Example switching environments: + >>> client = DMControlEnv(base_url="http://localhost:8000") + >>> # Start with cartpole balance + >>> result = client.reset(domain_name="cartpole", task_name="balance") + >>> # ... train on cartpole ... + >>> # Switch to walker walk + >>> result = client.reset(domain_name="walker", task_name="walk") + >>> # ... train on walker ... + """ + + def __init__( + self, + base_url: str, + connect_timeout_s: float = 10.0, + message_timeout_s: float = 60.0, + provider: Optional[Any] = None, + ): + """ + Initialize dm_control environment client. + + Args: + base_url: Base URL of the environment server (http:// or ws://). + connect_timeout_s: Timeout for establishing WebSocket connection. + message_timeout_s: Timeout for receiving responses. + provider: Optional container/runtime provider for lifecycle management. + """ + super().__init__( + base_url=base_url, + connect_timeout_s=connect_timeout_s, + message_timeout_s=message_timeout_s, + provider=provider, + ) + + def _step_payload(self, action: DMControlAction) -> Dict: + """ + Convert DMControlAction to JSON payload for step request. + + Args: + action: DMControlAction instance + + Returns: + Dictionary representation suitable for JSON encoding + """ + payload: Dict[str, Any] = {"values": action.values} + + if action.metadata: + payload["metadata"] = action.metadata + + return payload + + def _parse_result(self, payload: Dict) -> StepResult[DMControlObservation]: + """ + Parse server response into StepResult[DMControlObservation]. + + Args: + payload: JSON response from server + + Returns: + StepResult with DMControlObservation + """ + obs_data = payload.get("observation", {}) + + observation = DMControlObservation( + observations=obs_data.get("observations", {}), + pixels=obs_data.get("pixels"), + done=payload.get("done", False), + reward=payload.get("reward"), + metadata=obs_data.get("metadata", {}), + ) + + return StepResult( + observation=observation, + reward=payload.get("reward"), + done=payload.get("done", False), + ) + + def _parse_state(self, payload: Dict) -> DMControlState: + """ + Parse server response into DMControlState object. + + Args: + payload: JSON response from /state endpoint + + Returns: + DMControlState object with environment information + """ + return DMControlState( + episode_id=payload.get("episode_id"), + step_count=payload.get("step_count", 0), + domain_name=payload.get("domain_name", ""), + task_name=payload.get("task_name", ""), + action_spec=payload.get("action_spec", {}), + observation_spec=payload.get("observation_spec", {}), + physics_timestep=payload.get("physics_timestep", 0.002), + control_timestep=payload.get("control_timestep", 0.02), + ) + + def reset( + self, + domain_name: Optional[str] = None, + task_name: Optional[str] = None, + seed: Optional[int] = None, + render: bool = False, + **kwargs, + ) -> StepResult[DMControlObservation]: + """ + Reset the environment. + + Args: + domain_name: Optionally switch to a different domain. + task_name: Optionally switch to a different task. + seed: Random seed for reproducibility. + render: If True, include pixel observations in response. + **kwargs: Additional arguments passed to server. + + Returns: + StepResult with initial observation. 
+ """ + reset_kwargs = dict(kwargs) + if domain_name is not None: + reset_kwargs["domain_name"] = domain_name + if task_name is not None: + reset_kwargs["task_name"] = task_name + if seed is not None: + reset_kwargs["seed"] = seed + reset_kwargs["render"] = render + + return super().reset(**reset_kwargs) + + def step( + self, + action: DMControlAction, + render: bool = False, + **kwargs, + ) -> StepResult[DMControlObservation]: + """ + Execute one step in the environment. + + Args: + action: DMControlAction with continuous action values. + render: If True, include pixel observations in response. + **kwargs: Additional arguments passed to server. + + Returns: + StepResult with new observation, reward, and done flag. + """ + # Note: render flag needs to be passed differently + # For now, the server remembers the render setting from reset + return super().step(action, **kwargs) + + @staticmethod + def available_environments() -> List[Tuple[str, str]]: + """ + List available dm_control environments. + + Returns: + List of (domain_name, task_name) tuples. + """ + return AVAILABLE_ENVIRONMENTS + + @classmethod + def from_direct( + cls, + domain_name: str = "cartpole", + task_name: str = "balance", + render_height: int = 480, + render_width: int = 640, + port: int = 8765, + ) -> "DMControlEnv": + """ + Create a dm_control environment client with an embedded local server. + + This method starts a local uvicorn server in a subprocess and returns + a client connected to it. + + Args: + domain_name: Default domain to use. + task_name: Default task to use. + render_height: Height of rendered images. + render_width: Width of rendered images. + port: Port for the local server. + + Returns: + DMControlEnv client connected to the local server. + + Example: + >>> client = DMControlEnv.from_direct(domain_name="walker", task_name="walk") + >>> try: + ... result = client.reset() + ... for _ in range(100): + ... result = client.step(DMControlAction(values=[0.0] * 6)) + ... finally: + ... 
client.close() + """ + import os + import subprocess + import sys + import time + + import requests + + try: + from pathlib import Path + + client_dir = Path(__file__).parent + server_app = "envs.dm_control_env.server.app:app" + cwd = client_dir.parent.parent + + if not (cwd / "envs" / "dm_control_env" / "server" / "app.py").exists(): + if (client_dir / "server" / "app.py").exists(): + server_app = "server.app:app" + cwd = client_dir + except Exception: + server_app = "envs.dm_control_env.server.app:app" + cwd = None + + env = { + **os.environ, + "DMCONTROL_DOMAIN": domain_name, + "DMCONTROL_TASK": task_name, + "DMCONTROL_RENDER_HEIGHT": str(render_height), + "DMCONTROL_RENDER_WIDTH": str(render_width), + "NO_PROXY": "localhost,127.0.0.1", + "no_proxy": "localhost,127.0.0.1", + } + + if cwd: + src_path = str(cwd / "src") + existing_path = env.get("PYTHONPATH", "") + env["PYTHONPATH"] = ( + f"{src_path}:{cwd}:{existing_path}" + if existing_path + else f"{src_path}:{cwd}" + ) + + cmd = [ + sys.executable, + "-m", + "uvicorn", + server_app, + "--host", + "127.0.0.1", + "--port", + str(port), + ] + + server_process = subprocess.Popen( + cmd, + env=env, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + cwd=str(cwd) if cwd else None, + ) + + base_url = f"http://127.0.0.1:{port}" + healthy = False + for _ in range(30): + try: + response = requests.get( + f"{base_url}/health", + timeout=2, + proxies={"http": None, "https": None}, + ) + if response.status_code == 200: + healthy = True + break + except requests.exceptions.RequestException: + pass + time.sleep(1) + + if not healthy: + server_process.kill() + raise RuntimeError( + f"Failed to start local dm_control server on port {port}. " + "Check that the port is available and dependencies are installed." + ) + + class DirectModeProvider: + """Provider that manages the embedded server subprocess.""" + + def __init__(self, process: subprocess.Popen): + self._process = process + + def stop(self): + """Stop the embedded server.""" + if self._process: + self._process.terminate() + try: + self._process.wait(timeout=10) + except subprocess.TimeoutExpired: + self._process.kill() + self._process = None + + provider = DirectModeProvider(server_process) + client = cls(base_url=base_url, provider=provider) + return client diff --git a/envs/dm_control_env/examples/__init__.py b/envs/dm_control_env/examples/__init__.py new file mode 100644 index 000000000..8ac151820 --- /dev/null +++ b/envs/dm_control_env/examples/__init__.py @@ -0,0 +1 @@ +"""dm_control examples.""" diff --git a/envs/dm_control_env/examples/cartpole_control.py b/envs/dm_control_env/examples/cartpole_control.py new file mode 100644 index 000000000..bd4653227 --- /dev/null +++ b/envs/dm_control_env/examples/cartpole_control.py @@ -0,0 +1,333 @@ +#!/usr/bin/env python3 +"""Interactive cartpole control via OpenEnv. + +This example demonstrates using the dm_control OpenEnv client with +the cartpole environment. Use arrow keys to control the cart. + +Controls: + LEFT/RIGHT arrows: Apply force to move cart + R: Reset environment + ESC or Q: Quit + +Requirements: + pip install pygame + +Usage: + 1. Start the server: uvicorn server.app:app --host 0.0.0.0 --port 8000 + 2. 
Run this script: python examples/cartpole_control.py + + For visual mode (requires working MuJoCo rendering): + python examples/cartpole_control.py --visual +""" + +import argparse +import random +import sys +from pathlib import Path + +# Add parent directory to path for imports +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from client import DMControlEnv +from models import DMControlAction + + +def run_headless(env: DMControlEnv, task: str = "balance", max_steps: int = 500): + """Run cartpole control in headless mode.""" + print("\n=== Headless Mode (OpenEnv Step/Observation Pattern) ===") + print("This mode demonstrates the OpenEnv API with the cartpole.\n") + + # Reset environment using OpenEnv pattern + result = env.reset(domain_name="cartpole", task_name=task) + print(f"Initial observations: {list(result.observation.observations.keys())}") + print(f" position: {result.observation.observations.get('position', [])}") + print(f" velocity: {result.observation.observations.get('velocity', [])}") + + total_reward = 0.0 + step_count = 0 + + print("\nRunning with random actions to demonstrate step/observation pattern...\n") + + while not result.done and step_count < max_steps: + # Random action in [-1, 1] + action_value = random.uniform(-1.0, 1.0) + + # Step the environment using OpenEnv pattern + action = DMControlAction(values=[action_value]) + result = env.step(action) + + # Access observation and reward from result + total_reward += result.reward or 0.0 + step_count += 1 + + # Print progress periodically + if step_count % 50 == 0: + pos = result.observation.observations.get("position", []) + vel = result.observation.observations.get("velocity", []) + print( + f"Step {step_count}: reward={result.reward:.3f}, " + f"total={total_reward:.2f}, done={result.done}" + ) + print(f" position={pos}, velocity={vel}") + + print(f"\nEpisode finished: {step_count} steps, total reward: {total_reward:.2f}") + + +def run_interactive(env: DMControlEnv, task: str = "balance"): + """Run interactive control with keyboard input via pygame.""" + import pygame + + print("\n=== Interactive Mode (OpenEnv Step/Observation Pattern) ===") + print("Use LEFT/RIGHT arrows to control cart, R to reset, ESC to quit.\n") + + # Reset environment using OpenEnv pattern + result = env.reset(domain_name="cartpole", task_name=task) + print(f"Initial observations: {list(result.observation.observations.keys())}") + + # Initialize pygame for keyboard input (minimal window) + pygame.init() + screen = pygame.display.set_mode((400, 100)) + pygame.display.set_caption("Cartpole Control - Arrow keys to move, R to reset") + clock = pygame.time.Clock() + + # Font for display + font = pygame.font.Font(None, 24) + + running = True + total_reward = 0.0 + step_count = 0 + + print("\nControls:") + print(" LEFT/RIGHT arrows: Move cart") + print(" R: Reset environment") + print(" ESC or Q: Quit\n") + + while running: + # Handle events + for event in pygame.event.get(): + if event.type == pygame.QUIT: + running = False + elif event.type == pygame.KEYDOWN: + if event.key in (pygame.K_ESCAPE, pygame.K_q): + running = False + elif event.key == pygame.K_r: + result = env.reset(domain_name="cartpole", task_name=task) + total_reward = 0.0 + step_count = 0 + print("Environment reset") + + # Check for held keys (for continuous control) + keys = pygame.key.get_pressed() + if keys[pygame.K_LEFT]: + action_value = -1.0 + elif keys[pygame.K_RIGHT]: + action_value = 1.0 + else: + action_value = 0.0 + + # Step the environment using OpenEnv pattern + 
action = DMControlAction(values=[action_value]) + result = env.step(action) + + # Track reward from result + total_reward += result.reward or 0.0 + step_count += 1 + + # Check if episode is done + if result.done: + print( + f"Episode finished! Steps: {step_count}, " + f"Total reward: {total_reward:.2f}" + ) + # Auto-reset on done + result = env.reset(domain_name="cartpole", task_name=task) + total_reward = 0.0 + step_count = 0 + + # Update display + direction = ( + "<--" if action_value < 0 else ("-->" if action_value > 0 else "---") + ) + screen.fill((30, 30, 30)) + text = font.render( + f"Step: {step_count} | Reward: {total_reward:.1f} | {direction}", + True, + (255, 255, 255), + ) + screen.blit(text, (10, 40)) + pygame.display.flip() + + # Print progress periodically + if step_count % 200 == 0 and step_count > 0: + print(f"Step {step_count}: Total reward: {total_reward:.2f}") + + # Cap at 30 FPS + clock.tick(30) + + pygame.quit() + print(f"Session ended. Final reward: {total_reward:.2f}") + + +def run_visual(env: DMControlEnv, task: str = "balance"): + """Run with pygame visualization showing rendered frames.""" + import base64 + import io + + import pygame + + print("\n=== Visual Mode (OpenEnv Step/Observation Pattern) ===") + + # Reset environment with rendering enabled + result = env.reset(domain_name="cartpole", task_name=task, render=True) + print(f"Initial observations: {list(result.observation.observations.keys())}") + + # Get first frame to determine window size + if result.observation.pixels is None: + print("Error: Server did not return rendered pixels.") + print("Make sure the server supports render=True") + print("\nTry running in interactive mode (default) instead.") + sys.exit(1) + + # Decode base64 PNG to pygame surface + png_data = base64.b64decode(result.observation.pixels) + frame = pygame.image.load(io.BytesIO(png_data)) + frame_size = frame.get_size() + + # Initialize pygame + pygame.init() + screen = pygame.display.set_mode(frame_size) + pygame.display.set_caption( + "Cartpole (OpenEnv) - Arrow Keys to Move, R to Reset, ESC to Quit" + ) + clock = pygame.time.Clock() + + print("Controls:") + print(" LEFT/RIGHT arrows: Move cart") + print(" R: Reset environment") + print(" ESC or Q: Quit") + + running = True + total_reward = 0.0 + step_count = 0 + + while running: + # Handle events + for event in pygame.event.get(): + if event.type == pygame.QUIT: + running = False + elif event.type == pygame.KEYDOWN: + if event.key in (pygame.K_ESCAPE, pygame.K_q): + running = False + elif event.key == pygame.K_r: + result = env.reset( + domain_name="cartpole", task_name=task, render=True + ) + total_reward = 0.0 + step_count = 0 + print("Environment reset") + + # Check for held keys (for continuous control) + keys = pygame.key.get_pressed() + if keys[pygame.K_LEFT]: + action_value = -1.0 + elif keys[pygame.K_RIGHT]: + action_value = 1.0 + else: + action_value = 0.0 + + # Step the environment using OpenEnv pattern + action = DMControlAction(values=[action_value]) + result = env.step(action, render=True) + + # Track reward from result + total_reward += result.reward or 0.0 + step_count += 1 + + # Check if episode is done + if result.done: + print( + f"Episode finished! 
Steps: {step_count}, " + f"Total reward: {total_reward:.2f}" + ) + result = env.reset(domain_name="cartpole", task_name=task, render=True) + total_reward = 0.0 + step_count = 0 + + # Render the frame from observation pixels + if result.observation.pixels: + png_data = base64.b64decode(result.observation.pixels) + frame = pygame.image.load(io.BytesIO(png_data)) + screen.blit(frame, (0, 0)) + pygame.display.flip() + + # Print progress periodically + if step_count % 200 == 0 and step_count > 0: + print(f"Step {step_count}: Total reward: {total_reward:.2f}") + + # Cap at 30 FPS + clock.tick(30) + + pygame.quit() + print(f"Session ended. Final reward: {total_reward:.2f}") + + +def main(): + parser = argparse.ArgumentParser( + description="Interactive cartpole control via OpenEnv" + ) + parser.add_argument( + "--visual", + action="store_true", + help="Enable pygame visualization with rendered frames", + ) + parser.add_argument( + "--headless", + action="store_true", + help="Run in headless mode (no pygame, automated control)", + ) + parser.add_argument( + "--max-steps", + type=int, + default=500, + help="Maximum steps for headless mode (default: 500)", + ) + parser.add_argument( + "--task", + type=str, + default="balance", + choices=["balance", "balance_sparse", "swingup", "swingup_sparse"], + help="Cartpole task (default: balance)", + ) + args = parser.parse_args() + + server_url = "http://localhost:8000" + print(f"Connecting to {server_url}...") + + try: + with DMControlEnv(base_url=server_url) as env: + print("Connected!") + + # Get environment state + state = env.state() + print(f"Domain: {state.domain_name}, Task: {state.task_name}") + print(f"Action spec: {state.action_spec}") + + if args.headless: + run_headless(env, task=args.task, max_steps=args.max_steps) + elif args.visual: + run_visual(env, task=args.task) + else: + run_interactive(env, task=args.task) + + except ConnectionError as e: + print(f"Failed to connect: {e}") + print("\nMake sure the server is running:") + print(" cd OpenEnv") + print( + " PYTHONPATH=src:envs uvicorn envs.dm_control_env.server.app:app --port 8000" + ) + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/envs/dm_control_env/examples/hopper_control.py b/envs/dm_control_env/examples/hopper_control.py new file mode 100644 index 000000000..a63c82733 --- /dev/null +++ b/envs/dm_control_env/examples/hopper_control.py @@ -0,0 +1,374 @@ +#!/usr/bin/env python3 +"""Interactive hopper control via OpenEnv. + +This example demonstrates using the dm_control OpenEnv client with +the hopper environment. Press SPACE to apply random forces to the joints. + +Controls: + SPACE: Apply random force to all joints + R: Reset environment + ESC or Q: Quit + +Requirements: + pip install pygame + +Usage: + 1. Start the server: uvicorn server.app:app --host 0.0.0.0 --port 8000 + 2. 
Run this script: python examples/hopper_control.py + + For visual mode (requires working MuJoCo rendering): + python examples/hopper_control.py --visual +""" + +import argparse +import random +import sys +from pathlib import Path + +# Add parent directory to path for imports +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from client import DMControlEnv +from models import DMControlAction + + +def get_action_dim(env: DMControlEnv) -> int: + """Get the action dimension from the environment state.""" + state = env.state() + action_spec = state.action_spec + if action_spec and "shape" in action_spec: + shape = action_spec["shape"] + if isinstance(shape, list) and len(shape) > 0: + return shape[0] + # Hopper default: 4 actuators (hip, knee, ankle, toe) + return 4 + + +def generate_random_action(action_dim: int, magnitude: float = 1.0) -> DMControlAction: + """Generate a random action with values in [-magnitude, magnitude].""" + values = [random.uniform(-magnitude, magnitude) for _ in range(action_dim)] + return DMControlAction(values=values) + + +def generate_zero_action(action_dim: int) -> DMControlAction: + """Generate a zero action (no force applied).""" + return DMControlAction(values=[0.0] * action_dim) + + +def run_headless(env: DMControlEnv, task: str = "hop", max_steps: int = 1000): + """Run hopper control in headless mode.""" + print("\n=== Headless Mode (OpenEnv Step/Observation Pattern) ===") + print("This mode demonstrates the OpenEnv API with the hopper.\n") + + # Reset environment using OpenEnv pattern + result = env.reset(domain_name="hopper", task_name=task) + print(f"Initial observations: {list(result.observation.observations.keys())}") + + # Get action dimension + action_dim = get_action_dim(env) + print(f"Action dimension: {action_dim}") + + total_reward = 0.0 + step_count = 0 + + print("\nRunning with periodic random forces...") + print("Every 30 steps, a random force burst is applied.\n") + + while not result.done and step_count < max_steps: + # Apply random force every 30 steps, otherwise zero action + if step_count % 30 < 5: + # Random force burst for 5 steps + action = generate_random_action(action_dim, magnitude=0.8) + else: + # No force + action = generate_zero_action(action_dim) + + # Step the environment using OpenEnv pattern + result = env.step(action) + + # Access observation and reward from result + total_reward += result.reward or 0.0 + step_count += 1 + + # Print progress periodically + if step_count % 100 == 0: + # Get some observation values + position = result.observation.observations.get("position", []) + velocity = result.observation.observations.get("velocity", []) + print( + f"Step {step_count}: reward={result.reward:.3f}, " + f"total={total_reward:.2f}, done={result.done}" + ) + if position: + print(f" position: {position[:3]}") + if velocity: + print(f" velocity: {velocity[:3]}") + + print(f"\nEpisode finished: {step_count} steps, total reward: {total_reward:.2f}") + + +def run_interactive(env: DMControlEnv, task: str = "hop"): + """Run interactive control with keyboard input via pygame.""" + import pygame + + print("\n=== Interactive Mode (OpenEnv Step/Observation Pattern) ===") + print("Press SPACE to apply random force, R to reset, ESC to quit.\n") + + # Reset environment using OpenEnv pattern + result = env.reset(domain_name="hopper", task_name=task) + print(f"Initial observations: {list(result.observation.observations.keys())}") + + # Get action dimension + action_dim = get_action_dim(env) + print(f"Action dimension: {action_dim}") + + # 
Initialize pygame for keyboard input (minimal window) + pygame.init() + screen = pygame.display.set_mode((400, 100)) + pygame.display.set_caption("Hopper Control - SPACE for random force, R to reset") + clock = pygame.time.Clock() + + # Font for display + font = pygame.font.Font(None, 24) + + running = True + total_reward = 0.0 + step_count = 0 + apply_random_force = False + + print("\nControls:") + print(" SPACE: Apply random force to joints") + print(" R: Reset environment") + print(" ESC or Q: Quit\n") + + while running: + # Handle events + for event in pygame.event.get(): + if event.type == pygame.QUIT: + running = False + elif event.type == pygame.KEYDOWN: + if event.key in (pygame.K_ESCAPE, pygame.K_q): + running = False + elif event.key == pygame.K_r: + result = env.reset(domain_name="hopper", task_name=task) + total_reward = 0.0 + step_count = 0 + print("Environment reset") + + # Check for held keys + keys = pygame.key.get_pressed() + apply_random_force = keys[pygame.K_SPACE] + + # Generate action based on input + if apply_random_force: + action = generate_random_action(action_dim, magnitude=2.0) + else: + action = generate_zero_action(action_dim) + + # Step the environment using OpenEnv pattern + result = env.step(action) + + # Track reward from result + total_reward += result.reward or 0.0 + step_count += 1 + + # Check if episode is done + if result.done: + print( + f"Episode finished! Steps: {step_count}, " + f"Total reward: {total_reward:.2f}" + ) + # Auto-reset on done + result = env.reset(domain_name="hopper", task_name=task) + total_reward = 0.0 + step_count = 0 + + # Update display + screen.fill((30, 30, 30)) + status = "FORCE!" if apply_random_force else "idle" + text = font.render( + f"Step: {step_count} | Reward: {total_reward:.1f} | {status}", + True, + (255, 255, 255), + ) + screen.blit(text, (10, 40)) + pygame.display.flip() + + # Print progress periodically + if step_count % 200 == 0 and step_count > 0: + print(f"Step {step_count}: Total reward: {total_reward:.2f}") + + # Cap at 30 FPS + clock.tick(30) + + pygame.quit() + print(f"Session ended. 
Final reward: {total_reward:.2f}") + + +def run_visual(env: DMControlEnv, task: str = "hop"): + """Run with pygame visualization showing rendered frames.""" + import base64 + import io + + import pygame + + print("\n=== Visual Mode (OpenEnv Step/Observation Pattern) ===") + + # Reset environment with rendering enabled + result = env.reset(domain_name="hopper", task_name=task, render=True) + print(f"Initial observations: {list(result.observation.observations.keys())}") + + # Get action dimension + action_dim = get_action_dim(env) + print(f"Action dimension: {action_dim}") + + # Get first frame to determine window size + if result.observation.pixels is None: + print("Error: Server did not return rendered pixels.") + print("Make sure the server supports render=True") + print("\nTry running in interactive mode (default) instead.") + sys.exit(1) + + # Decode base64 PNG to pygame surface + png_data = base64.b64decode(result.observation.pixels) + frame = pygame.image.load(io.BytesIO(png_data)) + frame_size = frame.get_size() + + # Initialize pygame + pygame.init() + screen = pygame.display.set_mode(frame_size) + pygame.display.set_caption( + "Hopper (OpenEnv) - SPACE for random force, R to Reset, ESC to Quit" + ) + clock = pygame.time.Clock() + + print("Controls:") + print(" SPACE: Apply random force to joints") + print(" R: Reset environment") + print(" ESC or Q: Quit") + + running = True + total_reward = 0.0 + step_count = 0 + + while running: + # Handle events + for event in pygame.event.get(): + if event.type == pygame.QUIT: + running = False + elif event.type == pygame.KEYDOWN: + if event.key in (pygame.K_ESCAPE, pygame.K_q): + running = False + elif event.key == pygame.K_r: + result = env.reset( + domain_name="hopper", task_name=task, render=True + ) + total_reward = 0.0 + step_count = 0 + print("Environment reset") + + # Check for held keys + keys = pygame.key.get_pressed() + apply_random_force = keys[pygame.K_SPACE] + + # Generate action based on input + if apply_random_force: + action = generate_random_action(action_dim, magnitude=2.0) + else: + action = generate_zero_action(action_dim) + + # Step the environment using OpenEnv pattern + result = env.step(action, render=True) + + # Track reward from result + total_reward += result.reward or 0.0 + step_count += 1 + + # Check if episode is done + if result.done: + print( + f"Episode finished! Steps: {step_count}, " + f"Total reward: {total_reward:.2f}" + ) + result = env.reset(domain_name="hopper", task_name=task, render=True) + total_reward = 0.0 + step_count = 0 + + # Render the frame from observation pixels + if result.observation.pixels: + png_data = base64.b64decode(result.observation.pixels) + frame = pygame.image.load(io.BytesIO(png_data)) + screen.blit(frame, (0, 0)) + pygame.display.flip() + + # Print progress periodically + if step_count % 200 == 0 and step_count > 0: + print(f"Step {step_count}: Total reward: {total_reward:.2f}") + + # Cap at 30 FPS + clock.tick(30) + + pygame.quit() + print(f"Session ended. 
Final reward: {total_reward:.2f}") + + +def main(): + parser = argparse.ArgumentParser( + description="Interactive hopper control via OpenEnv" + ) + parser.add_argument( + "--visual", + action="store_true", + help="Enable pygame visualization with rendered frames", + ) + parser.add_argument( + "--headless", + action="store_true", + help="Run in headless mode (no pygame, automated control)", + ) + parser.add_argument( + "--max-steps", + type=int, + default=1000, + help="Maximum steps for headless mode (default: 1000)", + ) + parser.add_argument( + "--task", + type=str, + default="hop", + choices=["stand", "hop"], + help="Hopper task (default: hop)", + ) + args = parser.parse_args() + + server_url = "http://localhost:8000" + print(f"Connecting to {server_url}...") + + try: + with DMControlEnv(base_url=server_url) as env: + print("Connected!") + + # Get environment state + state = env.state() + print(f"Domain: {state.domain_name}, Task: {state.task_name}") + print(f"Action spec: {state.action_spec}") + + if args.headless: + run_headless(env, task=args.task, max_steps=args.max_steps) + elif args.visual: + run_visual(env, task=args.task) + else: + run_interactive(env, task=args.task) + + except ConnectionError as e: + print(f"Failed to connect: {e}") + print("\nMake sure the server is running:") + print(" cd OpenEnv") + print( + " PYTHONPATH=src:envs uvicorn envs.dm_control_env.server.app:app --port 8000" + ) + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/envs/dm_control_env/examples/list_environments.py b/envs/dm_control_env/examples/list_environments.py new file mode 100644 index 000000000..6cf5213ff --- /dev/null +++ b/envs/dm_control_env/examples/list_environments.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python3 +"""List all available dm_control.suite environments. + +This utility prints all available domain/task combinations from dm_control.suite. +""" + +from dm_control import suite + + +def main(): + print("Available dm_control.suite environments:") + print("=" * 50) + + # Group by domain + domains = {} + for domain, task in suite.BENCHMARKING: + if domain not in domains: + domains[domain] = [] + domains[domain].append(task) + + for domain in sorted(domains.keys()): + tasks = sorted(domains[domain]) + print(f"\n{domain}:") + for task in tasks: + # Load env to get action spec + try: + env = suite.load(domain_name=domain, task_name=task) + action_spec = env.action_spec() + action_dim = action_spec.shape[0] + obs_keys = list(env.observation_spec().keys()) + env.close() + print(f" - {task:20s} (action_dim={action_dim}, obs={obs_keys})") + except Exception as e: + print(f" - {task:20s} (error: {e})") + + print("\n" + "=" * 50) + print(f"Total: {len(suite.BENCHMARKING)} environments") + + +if __name__ == "__main__": + main() diff --git a/envs/dm_control_env/examples/quadruped_control.py b/envs/dm_control_env/examples/quadruped_control.py new file mode 100644 index 000000000..ebc7f7c3f --- /dev/null +++ b/envs/dm_control_env/examples/quadruped_control.py @@ -0,0 +1,373 @@ +#!/usr/bin/env python3 +"""Interactive quadruped control via OpenEnv. + +This example demonstrates using the dm_control OpenEnv client with +the quadruped environment. Press SPACE to apply random forces to the joints. + +Controls: + SPACE: Apply random force to all joints + R: Reset environment + ESC or Q: Quit + +Requirements: + pip install pygame + +Usage: + 1. Start the server: uvicorn server.app:app --host 0.0.0.0 --port 8000 + 2. 
Run this script: python examples/quadruped_control.py + + For visual mode (requires working MuJoCo rendering): + python examples/quadruped_control.py --visual +""" + +import argparse +import random +import sys +from pathlib import Path + +# Add parent directory to path for imports +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from client import DMControlEnv +from models import DMControlAction + + +def get_action_dim(env: DMControlEnv) -> int: + """Get the action dimension from the environment state.""" + state = env.state() + action_spec = state.action_spec + if action_spec and "shape" in action_spec: + shape = action_spec["shape"] + if isinstance(shape, list) and len(shape) > 0: + return shape[0] + # Quadruped default: 12 actuators (3 per leg x 4 legs) + return 12 + + +def generate_random_action(action_dim: int, magnitude: float = 1.0) -> DMControlAction: + """Generate a random action with values in [-magnitude, magnitude].""" + values = [random.uniform(-magnitude, magnitude) for _ in range(action_dim)] + return DMControlAction(values=values) + + +def generate_zero_action(action_dim: int) -> DMControlAction: + """Generate a zero action (no force applied).""" + return DMControlAction(values=[0.0] * action_dim) + + +def run_headless(env: DMControlEnv, max_steps: int = 1000): + """Run quadruped control in headless mode.""" + print("\n=== Headless Mode (OpenEnv Step/Observation Pattern) ===") + print("This mode demonstrates the OpenEnv API with the quadruped.\n") + + # Reset environment using OpenEnv pattern + result = env.reset(domain_name="quadruped", task_name="walk") + print(f"Initial observations: {list(result.observation.observations.keys())}") + + # Get action dimension + action_dim = get_action_dim(env) + print(f"Action dimension: {action_dim}") + + total_reward = 0.0 + step_count = 0 + + print("\nRunning with periodic random forces...") + print("Every 50 steps, a random force burst is applied.\n") + + while not result.done and step_count < max_steps: + # Apply random force every 50 steps, otherwise zero action + if step_count % 50 < 10: + # Random force burst for 10 steps + action = generate_random_action(action_dim, magnitude=0.5) + else: + # No force + action = generate_zero_action(action_dim) + + # Step the environment using OpenEnv pattern + result = env.step(action) + + # Access observation and reward from result + total_reward += result.reward or 0.0 + step_count += 1 + + # Print progress periodically + if step_count % 100 == 0: + # Get some observation values + egocentric_state = result.observation.observations.get( + "egocentric_state", [] + ) + print( + f"Step {step_count}: reward={result.reward:.3f}, " + f"total={total_reward:.2f}, done={result.done}" + ) + if egocentric_state: + print(f" egocentric_state (first 5): {egocentric_state[:5]}") + + print(f"\nEpisode finished: {step_count} steps, total reward: {total_reward:.2f}") + + +def run_interactive(env: DMControlEnv): + """Run interactive control with keyboard input via pygame.""" + import pygame + + print("\n=== Interactive Mode (OpenEnv Step/Observation Pattern) ===") + print("Press SPACE to apply random force, R to reset, ESC to quit.\n") + + # Reset environment using OpenEnv pattern + result = env.reset(domain_name="quadruped", task_name="walk") + print(f"Initial observations: {list(result.observation.observations.keys())}") + + # Get action dimension + action_dim = get_action_dim(env) + print(f"Action dimension: {action_dim}") + + # Initialize pygame for keyboard input (minimal window) + pygame.init() + 
screen = pygame.display.set_mode((400, 100)) + pygame.display.set_caption("Quadruped Control - SPACE for random force, R to reset") + clock = pygame.time.Clock() + + # Draw instructions on the window + font = pygame.font.Font(None, 24) + + running = True + total_reward = 0.0 + step_count = 0 + apply_random_force = False + + print("\nControls:") + print(" SPACE: Apply random force to joints") + print(" R: Reset environment") + print(" ESC or Q: Quit\n") + + while running: + # Handle events + for event in pygame.event.get(): + if event.type == pygame.QUIT: + running = False + elif event.type == pygame.KEYDOWN: + if event.key in (pygame.K_ESCAPE, pygame.K_q): + running = False + elif event.key == pygame.K_r: + result = env.reset(domain_name="quadruped", task_name="walk") + total_reward = 0.0 + step_count = 0 + print("Environment reset") + + # Check for held keys + keys = pygame.key.get_pressed() + apply_random_force = keys[pygame.K_SPACE] + + # Generate action based on input + if apply_random_force: + action = generate_random_action(action_dim, magnitude=2.0) + else: + action = generate_zero_action(action_dim) + + # Step the environment using OpenEnv pattern + result = env.step(action) + + # Track reward from result + total_reward += result.reward or 0.0 + step_count += 1 + + # Check if episode is done + if result.done: + print( + f"Episode finished! Steps: {step_count}, " + f"Total reward: {total_reward:.2f}" + ) + # Auto-reset on done + result = env.reset(domain_name="quadruped", task_name="walk") + total_reward = 0.0 + step_count = 0 + + # Update display + screen.fill((30, 30, 30)) + status = "FORCE!" if apply_random_force else "idle" + text = font.render( + f"Step: {step_count} | Reward: {total_reward:.1f} | {status}", + True, + (255, 255, 255), + ) + screen.blit(text, (10, 40)) + pygame.display.flip() + + # Print progress periodically + if step_count % 200 == 0 and step_count > 0: + print(f"Step {step_count}: Total reward: {total_reward:.2f}") + + # Cap at 30 FPS + clock.tick(30) + + pygame.quit() + print(f"Session ended. 
Final reward: {total_reward:.2f}") + + +def run_visual(env: DMControlEnv): + """Run with pygame visualization showing rendered frames.""" + import base64 + import io + + import pygame + + print("\n=== Visual Mode (OpenEnv Step/Observation Pattern) ===") + + # Reset environment with rendering enabled + result = env.reset(domain_name="quadruped", task_name="walk", render=True) + print(f"Initial observations: {list(result.observation.observations.keys())}") + + # Get action dimension + action_dim = get_action_dim(env) + print(f"Action dimension: {action_dim}") + + # Get first frame to determine window size + if result.observation.pixels is None: + print("Error: Server did not return rendered pixels.") + print("Make sure the server supports render=True") + print("\nTry running in interactive mode (default) instead.") + sys.exit(1) + + # Decode base64 PNG to pygame surface + png_data = base64.b64decode(result.observation.pixels) + frame = pygame.image.load(io.BytesIO(png_data)) + frame_size = frame.get_size() + + # Initialize pygame + pygame.init() + screen = pygame.display.set_mode(frame_size) + pygame.display.set_caption( + "Quadruped (OpenEnv) - SPACE for random force, R to Reset, ESC to Quit" + ) + clock = pygame.time.Clock() + + print("Controls:") + print(" SPACE: Apply random force to joints") + print(" R: Reset environment") + print(" ESC or Q: Quit") + + running = True + total_reward = 0.0 + step_count = 0 + + while running: + # Handle events + for event in pygame.event.get(): + if event.type == pygame.QUIT: + running = False + elif event.type == pygame.KEYDOWN: + if event.key in (pygame.K_ESCAPE, pygame.K_q): + running = False + elif event.key == pygame.K_r: + result = env.reset( + domain_name="quadruped", task_name="walk", render=True + ) + total_reward = 0.0 + step_count = 0 + print("Environment reset") + + # Check for held keys + keys = pygame.key.get_pressed() + apply_random_force = keys[pygame.K_SPACE] + + # Generate action based on input + if apply_random_force: + action = generate_random_action(action_dim, magnitude=2.0) + else: + action = generate_zero_action(action_dim) + + # Step the environment using OpenEnv pattern + result = env.step(action, render=True) + + # Track reward from result + total_reward += result.reward or 0.0 + step_count += 1 + + # Check if episode is done + if result.done: + print( + f"Episode finished! Steps: {step_count}, " + f"Total reward: {total_reward:.2f}" + ) + result = env.reset(domain_name="quadruped", task_name="walk", render=True) + total_reward = 0.0 + step_count = 0 + + # Render the frame from observation pixels + if result.observation.pixels: + png_data = base64.b64decode(result.observation.pixels) + frame = pygame.image.load(io.BytesIO(png_data)) + screen.blit(frame, (0, 0)) + pygame.display.flip() + + # Print progress periodically + if step_count % 200 == 0 and step_count > 0: + print(f"Step {step_count}: Total reward: {total_reward:.2f}") + + # Cap at 30 FPS + clock.tick(30) + + pygame.quit() + print(f"Session ended. 
Final reward: {total_reward:.2f}") + + +def main(): + parser = argparse.ArgumentParser( + description="Interactive quadruped control via OpenEnv" + ) + parser.add_argument( + "--visual", + action="store_true", + help="Enable pygame visualization with rendered frames", + ) + parser.add_argument( + "--headless", + action="store_true", + help="Run in headless mode (no pygame, automated control)", + ) + parser.add_argument( + "--max-steps", + type=int, + default=1000, + help="Maximum steps for headless mode (default: 1000)", + ) + parser.add_argument( + "--task", + type=str, + default="walk", + choices=["walk", "run", "escape", "fetch"], + help="Quadruped task (default: walk)", + ) + args = parser.parse_args() + + server_url = "http://localhost:8000" + print(f"Connecting to {server_url}...") + + try: + with DMControlEnv(base_url=server_url) as env: + print("Connected!") + + # Get environment state + state = env.state() + print(f"Domain: {state.domain_name}, Task: {state.task_name}") + print(f"Action spec: {state.action_spec}") + + if args.headless: + run_headless(env, max_steps=args.max_steps) + elif args.visual: + run_visual(env) + else: + run_interactive(env) + + except ConnectionError as e: + print(f"Failed to connect: {e}") + print("\nMake sure the server is running:") + print(" cd OpenEnv") + print( + " PYTHONPATH=src:envs uvicorn envs.dm_control_env.server.app:app --port 8000" + ) + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/envs/dm_control_env/models.py b/envs/dm_control_env/models.py new file mode 100644 index 000000000..a4421f537 --- /dev/null +++ b/envs/dm_control_env/models.py @@ -0,0 +1,186 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Data models for the dm_control OpenEnv Environment. + +This environment wraps dm_control.suite, providing access to all MuJoCo-based +continuous control tasks (cartpole, walker, humanoid, cheetah, etc.). +""" + +from typing import Any, Dict, List, Optional + +from pydantic import Field + +try: + from openenv.core.env_server.types import Action, Observation, State +except ImportError: + from openenv.core.env_server.types import Action, Observation, State + + +class DMControlAction(Action): + """ + Action for dm_control environments. + + All dm_control.suite environments use continuous actions represented as + a list of float values. The size and bounds depend on the specific + domain/task combination. + + Example (cartpole - 1D action): + >>> action = DMControlAction(values=[0.5]) # Push cart right + + Example (walker - 6D action): + >>> action = DMControlAction(values=[0.1, -0.2, 0.3, 0.0, -0.1, 0.2]) + + Attributes: + values: List of continuous action values. Shape and bounds depend on + the loaded environment's action_spec. + """ + + values: List[float] = Field( + default_factory=list, + description="Continuous action values matching the environment's action_spec", + ) + + +class DMControlObservation(Observation): + """ + Observation from dm_control environments. + + dm_control environments return observations as a dictionary of named arrays. + Common observation keys include 'position', 'velocity', 'orientations', etc. + The exact keys depend on the domain/task combination. 
+ + Example observation keys by domain: + - cartpole: 'position' (cos/sin of angle), 'velocity' + - walker: 'orientations', 'height', 'velocity' + - humanoid: 'joint_angles', 'head_height', 'extremities', 'torso_vertical', 'com_velocity' + + Attributes: + observations: Dictionary mapping observation names to their values. + Each value is a flattened list of floats. + pixels: Optional base64-encoded PNG image of the rendered scene. + Only included when render=True is passed to reset/step. + """ + + observations: Dict[str, List[float]] = Field( + default_factory=dict, + description="Named observation arrays from the environment", + ) + pixels: Optional[str] = Field( + default=None, + description="Base64-encoded PNG image (when render=True)", + ) + + +class DMControlState(State): + """ + Extended state for dm_control environments. + + Provides metadata about the currently loaded environment including + the domain/task names and action/observation specifications. + + Attributes: + episode_id: Unique identifier for the current episode. + step_count: Number of steps taken in the current episode. + domain_name: The dm_control domain (e.g., 'cartpole', 'walker'). + task_name: The specific task (e.g., 'balance', 'walk'). + action_spec: Specification of the action space including shape and bounds. + observation_spec: Specification of the observation space. + physics_timestep: The physics simulation timestep in seconds. + control_timestep: The control timestep (time between actions) in seconds. + """ + + domain_name: str = Field( + default="cartpole", + description="The dm_control domain name", + ) + task_name: str = Field( + default="balance", + description="The task name within the domain", + ) + action_spec: Dict[str, Any] = Field( + default_factory=dict, + description="Specification of the action space (shape, dtype, bounds)", + ) + observation_spec: Dict[str, Any] = Field( + default_factory=dict, + description="Specification of the observation space", + ) + physics_timestep: float = Field( + default=0.002, + description="Physics simulation timestep in seconds", + ) + control_timestep: float = Field( + default=0.02, + description="Control timestep (time between actions) in seconds", + ) + + +# Available dm_control.suite environments +# Format: (domain_name, task_name) +AVAILABLE_ENVIRONMENTS = [ + # Cartpole + ("cartpole", "balance"), + ("cartpole", "balance_sparse"), + ("cartpole", "swingup"), + ("cartpole", "swingup_sparse"), + # Pendulum + ("pendulum", "swingup"), + # Point mass + ("point_mass", "easy"), + ("point_mass", "hard"), + # Reacher + ("reacher", "easy"), + ("reacher", "hard"), + # Ball in cup + ("ball_in_cup", "catch"), + # Finger + ("finger", "spin"), + ("finger", "turn_easy"), + ("finger", "turn_hard"), + # Fish + ("fish", "upright"), + ("fish", "swim"), + # Cheetah + ("cheetah", "run"), + # Walker + ("walker", "stand"), + ("walker", "walk"), + ("walker", "run"), + # Hopper + ("hopper", "stand"), + ("hopper", "hop"), + # Swimmer + ("swimmer", "swimmer6"), + ("swimmer", "swimmer15"), + # Humanoid + ("humanoid", "stand"), + ("humanoid", "walk"), + ("humanoid", "run"), + # Manipulator + ("manipulator", "bring_ball"), + ("manipulator", "bring_peg"), + ("manipulator", "insert_ball"), + ("manipulator", "insert_peg"), + # Acrobot + ("acrobot", "swingup"), + ("acrobot", "swingup_sparse"), + # Stacker + ("stacker", "stack_2"), + ("stacker", "stack_4"), + # Dog + ("dog", "stand"), + ("dog", "walk"), + ("dog", "trot"), + ("dog", "run"), + ("dog", "fetch"), + # Quadruped + ("quadruped", 
"walk"), + ("quadruped", "run"), + ("quadruped", "escape"), + ("quadruped", "fetch"), +] diff --git a/envs/dm_control_env/openenv.yaml b/envs/dm_control_env/openenv.yaml new file mode 100644 index 000000000..3ec5d28d6 --- /dev/null +++ b/envs/dm_control_env/openenv.yaml @@ -0,0 +1,6 @@ +spec_version: 1 +name: dm_control_env +type: space +runtime: fastapi +app: server.app:app +port: 8000 diff --git a/envs/dm_control_env/pyproject.toml b/envs/dm_control_env/pyproject.toml new file mode 100644 index 000000000..e6612969d --- /dev/null +++ b/envs/dm_control_env/pyproject.toml @@ -0,0 +1,48 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +[build-system] +requires = ["setuptools>=45", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "openenv-dmcontrol-env" +version = "0.1.0" +description = "dm_control Environment for OpenEnv - wraps MuJoCo-based continuous control tasks (cartpole, walker, humanoid, etc.)" +requires-python = ">=3.10" +dependencies = [ + # Core OpenEnv dependencies + "openenv-core[core] @ git+https://github.com/meta-pytorch/OpenEnv.git", + "fastapi>=0.115.0", + "pydantic>=2.0.0", + "uvicorn>=0.24.0", + "requests>=2.31.0", + # dm_control dependencies + "mujoco>=3.0.0", + "dm_control>=1.0.0", + "numpy>=1.20.0", + # Optional: for pixel observations + "pillow>=9.0.0", +] + +[project.optional-dependencies] +dev = [ + "pytest>=8.0.0", + "pytest-cov>=4.0.0", +] +interactive = [ + # For interactive examples with keyboard control + "pygame>=2.0.0", +] + +[project.scripts] +# Server entry point - enables running via: uv run --project . server +server = "dm_control_env.server.app:main" + +[tool.setuptools] +include-package-data = true +packages = ["dm_control_env", "dm_control_env.server"] +package-dir = { "dm_control_env" = ".", "dm_control_env.server" = "server" } diff --git a/envs/dm_control_env/server/Dockerfile b/envs/dm_control_env/server/Dockerfile new file mode 100644 index 000000000..d3b1c64ae --- /dev/null +++ b/envs/dm_control_env/server/Dockerfile @@ -0,0 +1,73 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# Multi-stage build for dm_control environment +# Uses pip for package installation + +FROM python:3.11-slim AS builder + +WORKDIR /app + +# Install build dependencies including OpenGL for MuJoCo +RUN apt-get update && apt-get install -y --no-install-recommends \ + build-essential \ + git \ + libgl1 \ + libglx-mesa0 \ + libglew-dev \ + libosmesa6-dev \ + libgl1-mesa-dev \ + libglfw3 \ + patchelf \ + && rm -rf /var/lib/apt/lists/* + +# Copy environment code +COPY . /app/env + +WORKDIR /app/env + +# Install dependencies using pip +RUN pip install --upgrade pip && \ + pip install --no-cache-dir -e . 
+ +# Final runtime stage +FROM python:3.11-slim + +WORKDIR /app + +# Install runtime dependencies (OpenGL for MuJoCo rendering, curl for healthcheck) +RUN apt-get update && apt-get install -y --no-install-recommends \ + curl \ + libgl1 \ + libglx-mesa0 \ + libglew-dev \ + libosmesa6-dev \ + libglfw3 \ + && rm -rf /var/lib/apt/lists/* + +# Copy installed packages from builder +COPY --from=builder /usr/local/lib/python3.11/site-packages /usr/local/lib/python3.11/site-packages +COPY --from=builder /usr/local/bin /usr/local/bin + +# Copy the environment code +COPY . /app/env + +# Set PYTHONPATH so imports work correctly +ENV PYTHONPATH="/app/env" + +# Set MuJoCo to use OSMesa for headless rendering +ENV MUJOCO_GL="osmesa" + +# Expose port +EXPOSE 8000 + +# Health check +HEALTHCHECK --interval=30s --timeout=3s --start-period=10s --retries=3 \ + CMD curl -f http://localhost:8000/health || exit 1 + +# Run the FastAPI server +# Use exec to replace the shell with uvicorn so it receives SIGINT/SIGTERM directly +CMD ["sh", "-c", "cd /app/env && exec uvicorn server.app:app --host 0.0.0.0 --port 8000"] diff --git a/envs/dm_control_env/server/__init__.py b/envs/dm_control_env/server/__init__.py new file mode 100644 index 000000000..bb6bf5bc4 --- /dev/null +++ b/envs/dm_control_env/server/__init__.py @@ -0,0 +1,5 @@ +"""dm_control OpenEnv server module.""" + +from .dm_control_environment import DMControlEnvironment + +__all__ = ["DMControlEnvironment"] diff --git a/envs/dm_control_env/server/app.py b/envs/dm_control_env/server/app.py new file mode 100644 index 000000000..d1d88f3e9 --- /dev/null +++ b/envs/dm_control_env/server/app.py @@ -0,0 +1,78 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +FastAPI application for the dm_control Environment. + +This module creates an HTTP server that exposes dm_control.suite environments +over HTTP and WebSocket endpoints, compatible with EnvClient. + +Usage: + # Development (with auto-reload): + uvicorn server.app:app --reload --host 0.0.0.0 --port 8000 + + # Production: + uvicorn server.app:app --host 0.0.0.0 --port 8000 + + # Or run directly: + uv run --project . server +""" + +try: + from openenv.core.env_server.http_server import create_app + + from ..models import DMControlAction, DMControlObservation + from .dm_control_environment import DMControlEnvironment +except ImportError: + from openenv.core.env_server.http_server import create_app + + try: + import sys + from pathlib import Path + + _parent = str(Path(__file__).parent.parent) + if _parent not in sys.path: + sys.path.insert(0, _parent) + from models import DMControlAction, DMControlObservation + from server.dm_control_environment import DMControlEnvironment + except ImportError: + try: + from dm_control_env.models import DMControlAction, DMControlObservation + from dm_control_env.server.dm_control_environment import ( + DMControlEnvironment, + ) + except ImportError: + from envs.dm_control_env.models import DMControlAction, DMControlObservation + from envs.dm_control_env.server.dm_control_environment import ( + DMControlEnvironment, + ) + +# Create the app with web interface +# Pass the class (factory) for concurrent session support +app = create_app( + DMControlEnvironment, + DMControlAction, + DMControlObservation, + env_name="dm_control_env", +) + + +def main(): + """ + Entry point for direct execution via uv run or python -m. 
+ + This function enables running the server without Docker: + uv run --project . server + python -m envs.dm_control_env.server.app + openenv serve dm_control_env + """ + import uvicorn + + uvicorn.run(app, host="0.0.0.0", port=8000) + + +if __name__ == "__main__": + main() diff --git a/envs/dm_control_env/server/dm_control_environment.py b/envs/dm_control_env/server/dm_control_environment.py new file mode 100644 index 000000000..7e454f685 --- /dev/null +++ b/envs/dm_control_env/server/dm_control_environment.py @@ -0,0 +1,428 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +dm_control Environment Implementation. + +Wraps dm_control.suite environments (cartpole, walker, humanoid, etc.) +with the OpenEnv interface for standardized reinforcement learning. +""" + +import base64 +import io +import os +import sys +from typing import Any, Dict, Optional +from uuid import uuid4 + +# Configure MuJoCo rendering backend before importing dm_control +# On macOS, we don't set MUJOCO_GL - use default (glfw) which works +# when running synchronously in the main thread (see reset_async/step_async) +# On Linux, use egl for headless rendering +if "MUJOCO_GL" not in os.environ and sys.platform != "darwin": + os.environ.setdefault("MUJOCO_GL", "egl") + +import numpy as np + +try: + from openenv.core.env_server.interfaces import Environment + + from ..models import DMControlAction, DMControlObservation, DMControlState +except ImportError: + from openenv.core.env_server.interfaces import Environment + + try: + import sys + from pathlib import Path + + _parent = str(Path(__file__).parent.parent) + if _parent not in sys.path: + sys.path.insert(0, _parent) + from models import DMControlAction, DMControlObservation, DMControlState + except ImportError: + try: + from dm_control_env.models import ( + DMControlAction, + DMControlObservation, + DMControlState, + ) + except ImportError: + from envs.dm_control_env.models import ( + DMControlAction, + DMControlObservation, + DMControlState, + ) + + +class DMControlEnvironment(Environment): + """ + Wraps dm_control.suite environments with the OpenEnv interface. + + This environment supports all dm_control.suite domains and tasks including + cartpole, walker, humanoid, cheetah, and more. + + Features: + - Dynamic environment switching via reset(domain_name="...", task_name="...") + - Support for all continuous control tasks + - Optional visual observations (base64-encoded images) + - Configurable via constructor or environment variables + + Example: + >>> env = DMControlEnvironment() + >>> obs = env.reset() # Default: cartpole/balance + >>> print(obs.observations) + >>> + >>> # Take an action + >>> obs = env.step(DMControlAction(values=[0.5])) # Push cart right + >>> print(obs.reward) + + Example with different environment: + >>> env = DMControlEnvironment(domain_name="walker", task_name="walk") + >>> obs = env.reset() + >>> + >>> # Or switch environment on reset + >>> obs = env.reset(domain_name="cheetah", task_name="run") + """ + + # dm_control environments are isolated and thread-safe + SUPPORTS_CONCURRENT_SESSIONS = True + + def __init__( + self, + domain_name: Optional[str] = None, + task_name: Optional[str] = None, + render_height: Optional[int] = None, + render_width: Optional[int] = None, + ): + """ + Initialize the dm_control environment. + + Args: + domain_name: The dm_control domain to load. 
+ Env var: DMCONTROL_DOMAIN (default: cartpole) + task_name: The task within the domain. + Env var: DMCONTROL_TASK (default: balance) + render_height: Height of rendered images (when render=True). + Env var: DMCONTROL_RENDER_HEIGHT (default: 480) + render_width: Width of rendered images (when render=True). + Env var: DMCONTROL_RENDER_WIDTH (default: 640) + """ + self._env = None + + self._domain_name = domain_name or os.environ.get( + "DMCONTROL_DOMAIN", "cartpole" + ) + self._task_name = task_name or os.environ.get("DMCONTROL_TASK", "balance") + self._render_height = ( + render_height + if render_height is not None + else int(os.environ.get("DMCONTROL_RENDER_HEIGHT", "480")) + ) + self._render_width = ( + render_width + if render_width is not None + else int(os.environ.get("DMCONTROL_RENDER_WIDTH", "640")) + ) + self._include_pixels = False + + self._state = DMControlState( + episode_id=str(uuid4()), + step_count=0, + domain_name=self._domain_name, + task_name=self._task_name, + ) + + def _load_environment(self, domain_name: str, task_name: str) -> None: + """Load or switch to a dm_control environment.""" + if self._env is not None: + try: + self._env.close() + except Exception: + pass + + try: + from dm_control import suite + except ImportError as e: + raise ImportError( + "dm_control is required. Install with: pip install dm_control" + ) from e + except Exception as e: + # MuJoCo/OpenGL initialization can fail on macOS + error_msg = str(e) + if sys.platform == "darwin": + raise RuntimeError( + f"Failed to import dm_control (MuJoCo error): {error_msg}\n\n" + "On macOS, try one of these solutions:\n" + "1. Install osmesa: brew install mesa\n" + "2. Run with MUJOCO_GL=glfw (requires display)\n" + "3. Run with MUJOCO_GL=egl (if EGL is available)" + ) from e + raise + + try: + self._env = suite.load(domain_name=domain_name, task_name=task_name) + except Exception as e: + error_msg = str(e).lower() + # Check for MuJoCo/OpenGL errors + if "gl" in error_msg or "render" in error_msg or "display" in error_msg: + if sys.platform == "darwin": + raise RuntimeError( + f"MuJoCo initialization failed: {e}\n\n" + "On macOS, try one of these solutions:\n" + "1. Install osmesa: brew install mesa\n" + "2. Run with MUJOCO_GL=glfw (requires display)\n" + "3. Set PYOPENGL_PLATFORM=osmesa" + ) from e + # Check if it's an invalid environment error + try: + available = [(d, t) for d, t in suite.BENCHMARKING] + raise ValueError( + f"Failed to load {domain_name}/{task_name}. " + f"Available environments: {available[:10]}... 
" + f"(use dm_control.suite.BENCHMARKING for full list)" + ) from e + except Exception: + raise + + self._domain_name = domain_name + self._task_name = task_name + + self._state.domain_name = domain_name + self._state.task_name = task_name + self._state.action_spec = self._get_action_spec_info() + self._state.observation_spec = self._get_observation_spec_info() + self._state.physics_timestep = self._env.physics.timestep() + self._state.control_timestep = self._env.control_timestep() + + def _get_action_spec_info(self) -> Dict[str, Any]: + """Get information about the action space.""" + spec = self._env.action_spec() + return { + "shape": list(spec.shape), + "dtype": str(spec.dtype), + "minimum": spec.minimum.tolist(), + "maximum": spec.maximum.tolist(), + "name": spec.name, + } + + def _get_observation_spec_info(self) -> Dict[str, Any]: + """Get information about the observation space.""" + specs = self._env.observation_spec() + obs_info = {} + for name, spec in specs.items(): + obs_info[name] = { + "shape": list(spec.shape), + "dtype": str(spec.dtype), + } + return obs_info + + def _get_observation( + self, + time_step, + include_pixels: bool = False, + ) -> DMControlObservation: + """Convert dm_control TimeStep to DMControlObservation.""" + import dm_env + + observations = {} + for name, value in time_step.observation.items(): + observations[name] = np.asarray(value).flatten().tolist() + + pixels = None + if include_pixels: + try: + frame = self._env.physics.render( + height=self._render_height, + width=self._render_width, + camera_id=0, + ) + from PIL import Image + + img = Image.fromarray(frame) + buffer = io.BytesIO() + img.save(buffer, format="PNG") + pixels = base64.b64encode(buffer.getvalue()).decode("utf-8") + except Exception: + pass + + done = time_step.step_type == dm_env.StepType.LAST + reward = float(time_step.reward) if time_step.reward is not None else 0.0 + + return DMControlObservation( + observations=observations, + pixels=pixels, + reward=reward, + done=done, + ) + + def reset( + self, + domain_name: Optional[str] = None, + task_name: Optional[str] = None, + seed: Optional[int] = None, + render: bool = False, + **kwargs, + ) -> DMControlObservation: + """ + Reset the environment and return initial observation. + + Args: + domain_name: Optionally switch to a different domain. + task_name: Optionally switch to a different task. + seed: Random seed for reproducibility. + render: If True, include pixel observations. + **kwargs: Additional arguments (ignored). + + Returns: + DMControlObservation with initial state. + """ + self._include_pixels = render + + target_domain = domain_name or self._domain_name + target_task = task_name or self._task_name + + if ( + self._env is None + or target_domain != self._domain_name + or target_task != self._task_name + ): + self._load_environment(target_domain, target_task) + + if seed is not None: + np.random.seed(seed) + + time_step = self._env.reset() + + self._state = DMControlState( + episode_id=str(uuid4()), + step_count=0, + domain_name=self._domain_name, + task_name=self._task_name, + action_spec=self._state.action_spec, + observation_spec=self._state.observation_spec, + physics_timestep=self._state.physics_timestep, + control_timestep=self._state.control_timestep, + ) + + return self._get_observation(time_step, include_pixels=render) + + def step( + self, + action: DMControlAction, + render: bool = False, + **kwargs, + ) -> DMControlObservation: + """ + Execute one step in the environment. 
+ + Args: + action: DMControlAction with continuous action values. + render: If True, include pixel observations. + + Returns: + DMControlObservation with new state, reward, and done flag. + """ + if self._env is None: + raise RuntimeError("Environment not initialized. Call reset() first.") + + action_array = np.array(action.values, dtype=np.float64) + + action_spec = self._env.action_spec() + expected_shape = action_spec.shape + if action_array.shape != expected_shape: + if action_array.size == np.prod(expected_shape): + action_array = action_array.reshape(expected_shape) + else: + raise ValueError( + f"Action shape {action_array.shape} doesn't match " + f"expected shape {expected_shape}" + ) + + action_array = np.clip(action_array, action_spec.minimum, action_spec.maximum) + + time_step = self._env.step(action_array) + self._state.step_count += 1 + + return self._get_observation( + time_step, include_pixels=render or self._include_pixels + ) + + async def reset_async( + self, + domain_name: Optional[str] = None, + task_name: Optional[str] = None, + seed: Optional[int] = None, + render: bool = False, + **kwargs, + ) -> DMControlObservation: + """Async version of reset. + + On macOS, runs synchronously to avoid MuJoCo threading crashes. + On other platforms, runs in a thread pool. + """ + if sys.platform == "darwin": + # On macOS, MuJoCo crashes when run in a background thread + # Run synchronously (blocks event loop but avoids crash) + return self.reset( + domain_name=domain_name, + task_name=task_name, + seed=seed, + render=render, + **kwargs, + ) + else: + import asyncio + + return await asyncio.to_thread( + self.reset, + domain_name=domain_name, + task_name=task_name, + seed=seed, + render=render, + **kwargs, + ) + + async def step_async( + self, + action: DMControlAction, + render: bool = False, + **kwargs, + ) -> DMControlObservation: + """Async version of step. + + On macOS, runs synchronously to avoid MuJoCo threading crashes. + On other platforms, runs in a thread pool. + """ + if sys.platform == "darwin": + # On macOS, MuJoCo crashes when run in a background thread + # Run synchronously (blocks event loop but avoids crash) + return self.step(action, render=render, **kwargs) + else: + import asyncio + + return await asyncio.to_thread(self.step, action, render=render, **kwargs) + + @property + def state(self) -> DMControlState: + """Get the current environment state.""" + return self._state + + def close(self) -> None: + """Close the dm_control environment.""" + env = getattr(self, "_env", None) + if env is not None: + try: + env.close() + except Exception: + pass + self._env = None + + def __del__(self): + """Cleanup on deletion.""" + try: + self.close() + except Exception: + pass
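+
+
+if __name__ == "__main__":
+    # Minimal local smoke test (illustrative only, not used by the server):
+    # assumes dm_control and MuJoCo are installed and that the default
+    # cartpole/balance task loads headlessly on this machine.
+    env = DMControlEnvironment()
+    obs = env.reset()
+    print("observation keys:", sorted(obs.observations.keys()))
+    obs = env.step(DMControlAction(values=[0.0]))
+    print("reward:", obs.reward, "done:", obs.done)
+    env.close()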