diff --git a/src/openenv/core/env_server/interfaces.py b/src/openenv/core/env_server/interfaces.py index ecf6da57c..24b5342bf 100644 --- a/src/openenv/core/env_server/interfaces.py +++ b/src/openenv/core/env_server/interfaces.py @@ -5,10 +5,13 @@ # LICENSE file in the root directory of this source tree. from abc import ABC, abstractmethod -from typing import Any, Generic, Optional, Protocol, TypedDict, TypeVar +from typing import Any, Generic, Optional, Protocol, TypedDict, TypeVar, TYPE_CHECKING from .types import Action, Observation, State, EnvironmentMetadata +if TYPE_CHECKING: + from openenv.core.rubrics import Rubric + ActT = TypeVar("ActT", bound=Action) ObsT = TypeVar("ObsT", bound=Observation) StateT = TypeVar("StateT", bound=State) @@ -94,6 +97,8 @@ class Environment(ABC, Generic[ActT, ObsT, StateT]): Args: transform: Optional transform to apply to observations + rubric: Optional rubric for reward computation. When provided, the + rubric's output can be used to set the observation's reward in step(). Class Attributes: SUPPORTS_CONCURRENT_SESSIONS: Whether this environment supports concurrent sessions. @@ -105,13 +110,30 @@ class Environment(ABC, Generic[ActT, ObsT, StateT]): - The environment uses proper session isolation (e.g., unique working dirs) - No shared mutable state exists between instances - External resources (databases, APIs) can handle concurrent access + + Attributes: + rubric: Optional rubric for computing rewards. Environments can set this + in __init__ and use it in step() to compute observation rewards. + Training infrastructure can access it for introspection: + for name, r in env.rubric.named_rubrics(): + print(f"{name}: {r.last_score}") + + See RFC 004 for rubric design: rfcs/004-rubrics.md """ # Class-level flag indicating whether this environment supports concurrent sessions SUPPORTS_CONCURRENT_SESSIONS: bool = False - def __init__(self, transform: Optional[Transform[ObsT]] = None): + # Optional rubric for reward computation + rubric: Optional["Rubric"] + + def __init__( + self, + transform: Optional[Transform[ObsT]] = None, + rubric: Optional["Rubric"] = None, + ): self.transform = transform + self.rubric = rubric @abstractmethod def reset( @@ -185,6 +207,40 @@ def _apply_transform(self, observation: ObsT) -> ObsT: return self.transform(observation) return observation + def _apply_rubric(self, action: ActT, observation: ObsT) -> float: + """Apply rubric if one is provided. + + Args: + action: The action taken by the agent. + observation: The resulting observation. + + Returns: + Reward value from the rubric, or 0.0 if no rubric is set. + + Usage in step(): + def step(self, action: MyAction, ...) -> MyObservation: + # ... execute action and create observation ... + observation.reward = self._apply_rubric(action, observation) + return observation + """ + if self.rubric is not None: + return self.rubric(action, observation) + return 0.0 + + def _reset_rubric(self) -> None: + """Reset the rubric state if one is provided. + + Call this in reset() to clear any trajectory state in the rubric. + + Usage in reset(): + def reset(self, ...) -> MyObservation: + self._reset_rubric() + # ... create initial observation ... + return observation + """ + if self.rubric is not None: + self.rubric.reset() + def close(self) -> None: """Clean up resources used by the environment. 
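Note: the following is a minimal usage sketch, not part of this diff, showing how a concrete environment might wire the rubric hooks added to Environment above. The names (EchoEnv, EchoAction, EchoObservation, ExactMatch) and the `text` fields on the action/observation are assumptions for illustration only.

    # Hypothetical example; only __init__/ _apply_rubric / _reset_rubric come from this diff.
    from typing import Any, Optional

    from openenv.core.env_server.interfaces import Environment
    from openenv.core.env_server.types import Action, Observation, State
    from openenv.core.rubrics import Rubric


    class EchoAction(Action):
        text: str = ""


    class EchoObservation(Observation):
        text: str = ""


    class ExactMatch(Rubric):
        # Reward 1.0 when the observation echoes the action text, else 0.0.
        def forward(self, action: Any, observation: Any) -> float:
            return 1.0 if observation.text == action.text else 0.0


    class EchoEnv(Environment[EchoAction, EchoObservation, State]):
        def __init__(self, rubric: Optional[Rubric] = None):
            super().__init__(rubric=rubric or ExactMatch())
            self._state = State()

        def reset(self, seed=None, episode_id=None, **kwargs) -> EchoObservation:
            self._reset_rubric()  # clear any trajectory state held by the rubric
            return EchoObservation(text="")

        def step(self, action: EchoAction, timeout_s=None, **kwargs) -> EchoObservation:
            obs = EchoObservation(text=action.text)
            obs.reward = self._apply_rubric(action, obs)  # rubric supplies the reward
            return obs

        @property
        def state(self) -> State:
            return self._state

The same pattern applies to trajectory rubrics: the environment only calls _apply_rubric() and _reset_rubric(); credit assignment stays inside the rubric.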
diff --git a/src/openenv/core/rubrics/__init__.py b/src/openenv/core/rubrics/__init__.py new file mode 100644 index 000000000..0c29556f2 --- /dev/null +++ b/src/openenv/core/rubrics/__init__.py @@ -0,0 +1,37 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Rubrics for reward computation. + +See RFC 004 for full design: rfcs/004-rubrics.md +""" + +from openenv.core.rubrics.base import Rubric +from openenv.core.rubrics.containers import ( + Sequential, + Gate, + WeightedSum, + RubricList, + RubricDict, +) +from openenv.core.rubrics.trajectory import ( + TrajectoryRubric, + ExponentialDiscountingTrajectoryRubric, +) + +__all__ = [ + # Base + "Rubric", + # Containers + "Sequential", + "Gate", + "WeightedSum", + "RubricList", + "RubricDict", + # Trajectory + "TrajectoryRubric", + "ExponentialDiscountingTrajectoryRubric", +] diff --git a/src/openenv/core/rubrics/base.py b/src/openenv/core/rubrics/base.py new file mode 100644 index 000000000..8041e654e --- /dev/null +++ b/src/openenv/core/rubrics/base.py @@ -0,0 +1,164 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Base Rubric class for reward computation. + +Rubrics compute rewards from actions and observations. The API is modeled +after PyTorch's nn.Module: users implement forward(), and the framework +handles child registration and hooks. + +See RFC 004 for full design: rfcs/004-rubrics.md +""" + +from abc import ABC, abstractmethod +from typing import Any, Dict, Iterator, List, Optional, Tuple, Callable + + +class Rubric(ABC): + """Abstract base class for reward computation. + + A Rubric computes a reward signal from an action and observation. + Subclasses implement forward() to define the reward logic. + + Usage: + class MyRubric(Rubric): + def forward(self, action, observation) -> float: + return 1.0 if action.valid else 0.0 + + rubric = MyRubric() + reward = rubric(action, observation) + + Child rubrics are auto-registered when assigned as attributes, + enabling hierarchical composition and introspection. + """ + + _rubric_children: Dict[str, "Rubric"] + _forward_hooks: List[Callable] + _forward_pre_hooks: List[Callable] + last_score: Optional[float] + + def __init__(self): + # Use object.__setattr__ to avoid triggering __setattr__ during init + object.__setattr__(self, "_rubric_children", {}) + object.__setattr__(self, "_forward_hooks", []) + object.__setattr__(self, "_forward_pre_hooks", []) + object.__setattr__(self, "last_score", None) + + def __setattr__(self, name: str, value: Any) -> None: + # Auto-register child rubrics when assigned as attributes + if isinstance(value, Rubric): + self._rubric_children[name] = value + object.__setattr__(self, name, value) + + def __call__(self, action: Any, observation: Any) -> float: + """Evaluate the rubric with hooks. + + Args: + action: The action taken by the agent. + observation: The resulting observation. + + Returns: + Reward value (typically 0.0 to 1.0). 
+ """ + # Pre-forward hooks + for hook in self._forward_pre_hooks: + hook(self, action, observation) + + # Compute reward + result = self.forward(action, observation) + self.last_score = result + + # Post-forward hooks + for hook in self._forward_hooks: + hook(self, action, observation, result) + + return result + + @abstractmethod + def forward(self, action: Any, observation: Any) -> float: + """Compute the reward. Implement this in subclasses. + + Args: + action: The action taken by the agent. + observation: The resulting observation. + + Returns: + Reward value (typically 0.0 to 1.0). + """ + raise NotImplementedError + + def register_forward_hook( + self, hook: Callable[["Rubric", Any, Any, float], None] + ) -> None: + """Register a hook called after forward(). + + Args: + hook: Callable with signature (rubric, action, observation, result). + """ + self._forward_hooks.append(hook) + + def register_forward_pre_hook( + self, hook: Callable[["Rubric", Any, Any], None] + ) -> None: + """Register a hook called before forward(). + + Args: + hook: Callable with signature (rubric, action, observation). + """ + self._forward_pre_hooks.append(hook) + + def children(self) -> Iterator["Rubric"]: + """Iterate over immediate child rubrics.""" + yield from self._rubric_children.values() + + def named_children(self) -> Iterator[Tuple[str, "Rubric"]]: + """Iterate over immediate child rubrics with names.""" + yield from self._rubric_children.items() + + def rubrics(self) -> Iterator["Rubric"]: + """Iterate over all descendant rubrics (depth-first).""" + for child in self._rubric_children.values(): + yield child + yield from child.rubrics() + + def named_rubrics(self, prefix: str = "") -> Iterator[Tuple[str, "Rubric"]]: + """Iterate over all descendant rubrics with dot-separated names.""" + for name, child in self._rubric_children.items(): + full_name = f"{prefix}.{name}" if prefix else name + yield full_name, child + yield from child.named_rubrics(full_name) + + def get_rubric(self, path: str) -> "Rubric": + """Access a nested rubric by dot-separated path. + + Args: + path: Dot-separated path (e.g., "code.syntax"). + + Returns: + The rubric at the specified path. + + Raises: + KeyError: If the path does not exist. + """ + parts = path.split(".") + current = self + for part in parts: + if part not in current._rubric_children: + raise KeyError(f"Rubric path not found: {path}") + current = current._rubric_children[part] + return current + + def reset(self) -> None: + """Reset any internal state. Override in subclasses if needed.""" + pass + + def state_dict(self) -> Dict[str, Any]: + """Serialize rubric configuration for checkpointing.""" + return {} + + def load_state_dict(self, state: Dict[str, Any]) -> None: + """Load rubric configuration from checkpoint.""" + pass diff --git a/src/openenv/core/rubrics/containers.py b/src/openenv/core/rubrics/containers.py new file mode 100644 index 000000000..19d37e6e3 --- /dev/null +++ b/src/openenv/core/rubrics/containers.py @@ -0,0 +1,275 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Container rubrics for composing reward computations. + +These containers provide common aggregation patterns for rubrics, +similar to how PyTorch provides nn.Sequential alongside nn.Module. 
+ +See RFC 004 for full design: rfcs/004-rubrics.md +""" + +from typing import Any, Dict, Iterator, List, Mapping, Tuple, Union + +from openenv.core.rubrics.base import Rubric + + +class Sequential(Rubric): + """Run rubrics in order, fail-fast on zero. + + Runs child rubrics in order. If any returns 0, stops immediately + and returns 0. This implements hierarchical gating patterns where + syntax checks run before execution checks. + + Usage: + rubric = Sequential( + Gate(Compiles()), + Gate(PassesTests(), threshold=0.5), + WeightedSum([PassesTests(), StyleRubric()], weights=[0.7, 0.3]) + ) + """ + + def __init__(self, *rubrics: Rubric): + """Initialize with rubrics to run in sequence. + + Args: + *rubrics: Rubrics to run in order. Stops and returns 0 if any + child returns 0. + """ + super().__init__() + for i, rubric in enumerate(rubrics): + setattr(self, f"rubric_{i}", rubric) + self._rubric_list = list(rubrics) + + def forward(self, action: Any, observation: Any) -> float: + """Run rubrics in order, return 0 if any returns 0.""" + result = 1.0 + for rubric in self._rubric_list: + score = rubric(action, observation) + if score == 0.0: + return 0.0 + result = score # Return last non-zero score + return result + + def __len__(self) -> int: + return len(self._rubric_list) + + def __getitem__(self, index: int) -> Rubric: + return self._rubric_list[index] + + +class Gate(Rubric): + """Threshold wrapper - returns 0 if child score is below threshold. + + Useful for hard constraints like "must pass 50% of tests". + + Usage: + rubric = Gate(PassesTests(), threshold=0.5) + # Returns PassesTests() score if >= 0.5, else 0.0 + """ + + def __init__(self, rubric: Rubric, threshold: float = 1.0): + """Initialize with a rubric and threshold. + + Args: + rubric: The rubric to gate. + threshold: Minimum score required. If child returns less than + this, Gate returns 0. Default is 1.0 (must pass completely). + """ + super().__init__() + self.rubric = rubric + self.threshold = threshold + + def forward(self, action: Any, observation: Any) -> float: + """Return child score if >= threshold, else 0.""" + score = self.rubric(action, observation) + if score < self.threshold: + return 0.0 + return score + + +class WeightedSum(Rubric): + """Weighted combination of child rubrics. + + Standard aggregation pattern for multi-criteria evaluation. + + Usage: + rubric = WeightedSum( + [PassesTests(), StyleRubric()], + weights=[0.7, 0.3] + ) + """ + + def __init__(self, rubrics: List[Rubric], weights: List[float]): + """Initialize with rubrics and weights. + + Args: + rubrics: List of rubrics to combine. + weights: Weight for each rubric. Must sum to 1.0. + + Raises: + ValueError: If lengths don't match or weights don't sum to 1.0. 
+ """ + super().__init__() + if len(rubrics) != len(weights): + raise ValueError( + f"Number of rubrics ({len(rubrics)}) must match " + f"number of weights ({len(weights)})" + ) + if abs(sum(weights) - 1.0) > 1e-6: + raise ValueError(f"Weights must sum to 1.0, got {sum(weights)}") + + for i, rubric in enumerate(rubrics): + setattr(self, f"rubric_{i}", rubric) + self._rubric_list = list(rubrics) + self._weights = list(weights) + + def forward(self, action: Any, observation: Any) -> float: + """Return weighted sum of child scores.""" + total = 0.0 + for rubric, weight in zip(self._rubric_list, self._weights): + score = rubric(action, observation) + total += score * weight + return total + + @property + def weights(self) -> List[float]: + """Get the weights (read-only copy).""" + return list(self._weights) + + +class RubricList(Rubric): + """Container for dynamic lists of rubrics. + + Analogous to nn.ModuleList. Does not define aggregation - use within + a parent rubric that implements custom logic. + + Usage: + class MultiGameRubric(Rubric): + def __init__(self, games: List[str]): + super().__init__() + self.games = RubricList([GameRubric(g) for g in games]) + + def forward(self, action, obs) -> float: + return self.games[obs.game_index](action, obs) + """ + + def __init__(self, rubrics: List[Rubric] = None): + """Initialize with optional list of rubrics. + + Args: + rubrics: Optional list of rubrics to start with. + """ + super().__init__() + self._rubrics: List[Rubric] = [] + if rubrics is not None: + for i, rubric in enumerate(rubrics): + self.append(rubric) + + def forward(self, action: Any, observation: Any) -> float: + """RubricList does not define aggregation - override in parent.""" + raise NotImplementedError( + "RubricList.forward() is not implemented. " + "Use RubricList within a parent rubric that defines aggregation." + ) + + def append(self, rubric: Rubric) -> None: + """Add a rubric to the list.""" + index = len(self._rubrics) + setattr(self, f"rubric_{index}", rubric) + self._rubrics.append(rubric) + + def extend(self, rubrics: List[Rubric]) -> None: + """Add multiple rubrics to the list.""" + for rubric in rubrics: + self.append(rubric) + + def __len__(self) -> int: + return len(self._rubrics) + + def __getitem__(self, index: int) -> Rubric: + return self._rubrics[index] + + def __iter__(self) -> Iterator[Rubric]: + return iter(self._rubrics) + + +class RubricDict(Rubric): + """Container for named rubrics with keyed access. + + Analogous to nn.ModuleDict. Enables keyed access for multi-task + environments where different tasks require different rubrics. + + Usage: + class AtariRubric(Rubric): + def __init__(self): + super().__init__() + self.games = RubricDict({ + "pong": PongRubric(), + "breakout": BreakoutRubric(), + "space_invaders": SpaceInvadersRubric(), + }) + + def forward(self, action, obs) -> float: + return self.games[obs.game_id](action, obs) + + # Access: env.rubric.games["pong"] + """ + + def __init__(self, rubrics: Dict[str, Rubric] = None): + """Initialize with optional dictionary of rubrics. + + Args: + rubrics: Optional dictionary mapping names to rubrics. + """ + super().__init__() + self._rubric_dict: Dict[str, Rubric] = {} + if rubrics is not None: + for name, rubric in rubrics.items(): + self[name] = rubric + + def forward(self, action: Any, observation: Any) -> float: + """RubricDict does not define aggregation - override in parent.""" + raise NotImplementedError( + "RubricDict.forward() is not implemented. 
" + "Use RubricDict within a parent rubric that defines aggregation." + ) + + def __setitem__(self, key: str, rubric: Rubric) -> None: + """Add a rubric with the given key.""" + setattr(self, key, rubric) + self._rubric_dict[key] = rubric + + def __getitem__(self, key: str) -> Rubric: + """Get rubric by key.""" + return self._rubric_dict[key] + + def __contains__(self, key: str) -> bool: + """Check if key exists.""" + return key in self._rubric_dict + + def __len__(self) -> int: + return len(self._rubric_dict) + + def __iter__(self) -> Iterator[str]: + return iter(self._rubric_dict) + + def keys(self) -> Iterator[str]: + """Iterate over keys.""" + return iter(self._rubric_dict.keys()) + + def values(self) -> Iterator[Rubric]: + """Iterate over rubrics.""" + return iter(self._rubric_dict.values()) + + def items(self) -> Iterator[Tuple[str, Rubric]]: + """Iterate over (key, rubric) pairs.""" + return iter(self._rubric_dict.items()) + + def update(self, rubrics: Union[Dict[str, Rubric], Mapping[str, Rubric]]) -> None: + """Update with rubrics from a dictionary.""" + for name, rubric in rubrics.items(): + self[name] = rubric diff --git a/src/openenv/core/rubrics/trajectory.py b/src/openenv/core/rubrics/trajectory.py new file mode 100644 index 000000000..b3bb9aa91 --- /dev/null +++ b/src/openenv/core/rubrics/trajectory.py @@ -0,0 +1,203 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Trajectory-based rubrics for delayed reward computation. + +These rubrics accumulate trajectory data and compute rewards based on +episode outcomes rather than individual steps. This supports scenarios +where reward signals depend on future events: + +- Terminal games (chess, Go): Win/loss known only at game end +- Plan execution: Plan quality depends on execution success +- Multi-agent games: One player's action quality depends on opponent response + +See RFC 004 "Delayed Rewards" section for design rationale. +""" + +from abc import abstractmethod +from typing import Any, Dict, List, Tuple + +from openenv.core.rubrics.base import Rubric + + +class TrajectoryRubric(Rubric): + """Abstract base for rubrics that score based on full trajectories. + + Subclasses implement: + - score_trajectory(): Compute final score from trajectory + - compute_step_rewards(): Define credit assignment strategy + + The __call__ method accumulates steps and returns rewards according + to the subclass's implementation. + + IMPORTANT: Trajectories are stored in CPU memory to avoid GPU pressure. + Environments with GPU tensors in observations must move them to CPU + before returning from step(). + + Known limitation: Very long episodes (thousands of steps) may consume + significant CPU memory. For such cases, consider streaming rubrics. 
+ + Usage: + class WinLossRubric(TrajectoryRubric): + def score_trajectory(self, trajectory): + _, final_obs = trajectory[-1] + return 1.0 if final_obs.metadata.get('won') else 0.0 + + def compute_step_rewards(self): + # Equal credit to all steps + score = self.score_trajectory(self._trajectory) + return [score] * len(self._trajectory) + + rubric = WinLossRubric() + for action, obs in episode: + reward = rubric(action, obs) # 0.0 until done + step_rewards = rubric.compute_step_rewards() # Credit assignment + """ + + _trajectory: List[Tuple[Any, Any]] + intermediate_reward: float + + def __init__(self, intermediate_reward: float = 0.0): + """Initialize trajectory rubric. + + Args: + intermediate_reward: Value to return for non-terminal steps. + Defaults to 0.0. + """ + super().__init__() + self.intermediate_reward = intermediate_reward + self._trajectory = [] + + def forward(self, action: Any, observation: Any) -> float: + """Accumulate step and return reward. + + Returns intermediate_reward until done, then computes trajectory score. + + Args: + action: The action taken. + observation: The resulting observation. Must have a 'done' attribute. + + Returns: + intermediate_reward if not done, else score_trajectory() result. + """ + self._trajectory.append((action, observation)) + + if getattr(observation, "done", False): + return self.score_trajectory(self._trajectory) + else: + return self.intermediate_reward + + @abstractmethod + def score_trajectory(self, trajectory: List[Tuple[Any, Any]]) -> float: + """Score the complete trajectory. Return 0.0-1.0. + + Called when observation.done=True. + + Args: + trajectory: List of (action, observation) tuples. + + Returns: + Final trajectory score (typically 0.0 to 1.0). + """ + raise NotImplementedError + + @abstractmethod + def compute_step_rewards(self) -> List[float]: + """Compute per-step rewards from the accumulated trajectory. + + Returns: + List of rewards, one per step. Length matches len(trajectory). + + Define your credit assignment strategy here (e.g., discounting, + assigning all credit to specific steps, etc.). + """ + raise NotImplementedError + + def reset(self) -> None: + """Clear accumulated trajectory. Call on env.reset().""" + self._trajectory = [] + + @property + def trajectory(self) -> List[Tuple[Any, Any]]: + """Current trajectory (read-only copy).""" + return list(self._trajectory) + + def state_dict(self) -> Dict[str, Any]: + """Serialize configuration (not trajectory data).""" + return {"intermediate_reward": self.intermediate_reward} + + def load_state_dict(self, state: Dict[str, Any]) -> None: + """Load configuration from checkpoint.""" + if "intermediate_reward" in state: + self.intermediate_reward = state["intermediate_reward"] + + +class ExponentialDiscountingTrajectoryRubric(TrajectoryRubric): + """TrajectoryRubric with exponential discounting for credit assignment. + + Per-step reward: r_t = gamma^(T-1-t) * R_final + + With gamma=0.99, later steps get higher reward (they're "closer" to the outcome). + With gamma=1.0, all steps get equal reward. + With gamma=0.0, only the final step gets reward. + + This is the standard temporal discounting used in reinforcement learning, + applied retroactively once the episode outcome is known. 
+ + Usage: + class ChessRubric(ExponentialDiscountingTrajectoryRubric): + def score_trajectory(self, trajectory): + _, final_obs = trajectory[-1] + outcome = final_obs.metadata.get('winner') + if outcome == 'agent': return 1.0 + elif outcome == 'opponent': return 0.0 + else: return 0.5 # Draw + + rubric = ChessRubric(gamma=0.99) + reward = rubric(action, obs) # 0.0 until done, then final score + step_rewards = rubric.compute_step_rewards() # Discounted per-step rewards + """ + + gamma: float + + def __init__(self, gamma: float = 0.99, intermediate_reward: float = 0.0): + """Initialize with discount factor. + + Args: + gamma: Discount factor in [0, 1]. Higher values give more credit + to early moves. 0.99 is a common choice. + intermediate_reward: Value to return for non-terminal steps. + """ + super().__init__(intermediate_reward=intermediate_reward) + if not 0.0 <= gamma <= 1.0: + raise ValueError(f"gamma must be in [0, 1], got {gamma}") + self.gamma = gamma + + def compute_step_rewards(self) -> List[float]: + """Apply exponential discounting from final reward. + + Returns: + List of discounted rewards. step_rewards[t] = gamma^(T-1-t) * R_final + where T is the trajectory length and R_final is score_trajectory(). + """ + if not self._trajectory: + return [] + + final_score = self.score_trajectory(self._trajectory) + T = len(self._trajectory) + return [final_score * (self.gamma ** (T - 1 - t)) for t in range(T)] + + def state_dict(self) -> Dict[str, Any]: + """Serialize configuration.""" + state = super().state_dict() + state["gamma"] = self.gamma + return state + + def load_state_dict(self, state: Dict[str, Any]) -> None: + """Load configuration from checkpoint.""" + super().load_state_dict(state) + if "gamma" in state: + self.gamma = state["gamma"] diff --git a/tests/core/test_rubrics/__init__.py b/tests/core/test_rubrics/__init__.py new file mode 100644 index 000000000..2e41cd717 --- /dev/null +++ b/tests/core/test_rubrics/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/tests/core/test_rubrics/test_base_rubric.py b/tests/core/test_rubrics/test_base_rubric.py new file mode 100644 index 000000000..570163f33 --- /dev/null +++ b/tests/core/test_rubrics/test_base_rubric.py @@ -0,0 +1,206 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +"""Tests for the base Rubric class.""" + +import pytest +from typing import Any + +from openenv.core.rubrics.base import Rubric + + +class SimpleRubric(Rubric): + """Concrete rubric that returns a fixed score.""" + + def __init__(self, score: float = 1.0): + super().__init__() + self.score = score + + def forward(self, action: Any, observation: Any) -> float: + return self.score + + +class CompositeRubric(Rubric): + """Rubric with child rubrics.""" + + def __init__(self): + super().__init__() + self.child1 = SimpleRubric(0.5) + self.child2 = SimpleRubric(0.7) + + def forward(self, action: Any, observation: Any) -> float: + return (self.child1(action, observation) + self.child2(action, observation)) / 2 + + +class TestRubricBasics: + """Test basic Rubric functionality.""" + + def test_forward_is_abstract(self): + """Cannot instantiate Rubric directly.""" + with pytest.raises(TypeError): + Rubric() + + def test_simple_rubric_call(self): + """Calling a rubric invokes forward().""" + rubric = SimpleRubric(0.8) + result = rubric("action", "observation") + assert result == 0.8 + + def test_last_score_tracked(self): + """last_score is updated after each call.""" + rubric = SimpleRubric(0.6) + assert rubric.last_score is None + + rubric("action", "observation") + assert rubric.last_score == 0.6 + + +class TestChildRegistration: + """Test auto-registration of child rubrics.""" + + def test_children_registered(self): + """Child rubrics are registered when assigned as attributes.""" + rubric = CompositeRubric() + + children = list(rubric.children()) + assert len(children) == 2 + assert rubric.child1 in children + assert rubric.child2 in children + + def test_named_children(self): + """named_children returns name-rubric pairs.""" + rubric = CompositeRubric() + + named = dict(rubric.named_children()) + assert "child1" in named + assert "child2" in named + assert named["child1"].score == 0.5 + assert named["child2"].score == 0.7 + + def test_rubrics_recursive(self): + """rubrics() returns all descendants.""" + + class NestedRubric(Rubric): + def __init__(self): + super().__init__() + self.inner = CompositeRubric() + + def forward(self, action, observation): + return self.inner(action, observation) + + rubric = NestedRubric() + + all_rubrics = list(rubric.rubrics()) + # inner, inner.child1, inner.child2 + assert len(all_rubrics) == 3 + + def test_named_rubrics_paths(self): + """named_rubrics() returns dot-separated paths.""" + + class NestedRubric(Rubric): + def __init__(self): + super().__init__() + self.inner = CompositeRubric() + + def forward(self, action, observation): + return self.inner(action, observation) + + rubric = NestedRubric() + + paths = dict(rubric.named_rubrics()) + assert "inner" in paths + assert "inner.child1" in paths + assert "inner.child2" in paths + + def test_get_rubric_by_path(self): + """get_rubric() navigates dot-separated paths.""" + + class NestedRubric(Rubric): + def __init__(self): + super().__init__() + self.inner = CompositeRubric() + + def forward(self, action, observation): + return self.inner(action, observation) + + rubric = NestedRubric() + + assert rubric.get_rubric("inner") is rubric.inner + assert rubric.get_rubric("inner.child1") is rubric.inner.child1 + + def test_get_rubric_invalid_path(self): + """get_rubric() raises KeyError for invalid paths.""" + rubric = CompositeRubric() + + with pytest.raises(KeyError): + rubric.get_rubric("nonexistent") + + +class TestHooks: + """Test forward hook functionality.""" + + def test_forward_hook_called(self): + 
"""Forward hooks are called after forward().""" + rubric = SimpleRubric(0.9) + hook_calls = [] + + def hook(r, action, obs, result): + hook_calls.append((action, obs, result)) + + rubric.register_forward_hook(hook) + rubric("my_action", "my_obs") + + assert len(hook_calls) == 1 + assert hook_calls[0] == ("my_action", "my_obs", 0.9) + + def test_forward_pre_hook_called(self): + """Pre-forward hooks are called before forward().""" + rubric = SimpleRubric(0.9) + hook_calls = [] + + def pre_hook(r, action, obs): + hook_calls.append((action, obs)) + + rubric.register_forward_pre_hook(pre_hook) + rubric("my_action", "my_obs") + + assert len(hook_calls) == 1 + assert hook_calls[0] == ("my_action", "my_obs") + + def test_multiple_hooks(self): + """Multiple hooks can be registered.""" + rubric = SimpleRubric(0.5) + results = [] + + rubric.register_forward_hook(lambda r, a, o, res: results.append(1)) + rubric.register_forward_hook(lambda r, a, o, res: results.append(2)) + + rubric("action", "obs") + + assert results == [1, 2] + + +class TestReset: + """Test reset functionality.""" + + def test_default_reset_is_noop(self): + """Default reset() does nothing (for stateless rubrics).""" + rubric = SimpleRubric(0.5) + rubric.reset() # Should not raise + + +class TestStateDictSerialization: + """Test state_dict serialization.""" + + def test_default_state_dict_empty(self): + """Default state_dict returns empty dict.""" + rubric = SimpleRubric(0.5) + assert rubric.state_dict() == {} + + def test_load_state_dict_accepts_empty(self): + """load_state_dict accepts empty dict.""" + rubric = SimpleRubric(0.5) + rubric.load_state_dict({}) # Should not raise diff --git a/tests/core/test_rubrics/test_containers.py b/tests/core/test_rubrics/test_containers.py new file mode 100644 index 000000000..5edc38b23 --- /dev/null +++ b/tests/core/test_rubrics/test_containers.py @@ -0,0 +1,414 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +"""Tests for container rubrics: Sequential, Gate, WeightedSum, RubricList, RubricDict.""" + +import pytest +from typing import Any + +from openenv.core.rubrics.base import Rubric +from openenv.core.rubrics.containers import ( + Sequential, + Gate, + WeightedSum, + RubricList, + RubricDict, +) + + +class FixedRubric(Rubric): + """Concrete rubric that returns a fixed score.""" + + def __init__(self, score: float = 1.0): + super().__init__() + self.score = score + + def forward(self, action: Any, observation: Any) -> float: + return self.score + + +class CountingRubric(Rubric): + """Rubric that counts how many times it's called.""" + + def __init__(self, score: float = 1.0): + super().__init__() + self.score = score + self.call_count = 0 + + def forward(self, action: Any, observation: Any) -> float: + self.call_count += 1 + return self.score + + +class TestSequential: + """Test Sequential container.""" + + def test_empty_sequential(self): + """Empty sequential returns 1.0.""" + rubric = Sequential() + result = rubric("action", "obs") + assert result == 1.0 + + def test_single_rubric(self): + """Single rubric returns its score.""" + rubric = Sequential(FixedRubric(0.8)) + result = rubric("action", "obs") + assert result == 0.8 + + def test_multiple_rubrics_all_pass(self): + """Multiple passing rubrics return last score.""" + rubric = Sequential( + FixedRubric(1.0), + FixedRubric(0.8), + FixedRubric(0.9), + ) + result = rubric("action", "obs") + assert result == 0.9 + + def test_fail_fast_on_zero(self): + """Stops immediately when a rubric returns 0.""" + r1 = CountingRubric(1.0) + r2 = CountingRubric(0.0) # Fails + r3 = CountingRubric(1.0) + + rubric = Sequential(r1, r2, r3) + result = rubric("action", "obs") + + assert result == 0.0 + assert r1.call_count == 1 + assert r2.call_count == 1 + assert r3.call_count == 0 # Should not be called + + def test_children_registered(self): + """Child rubrics are auto-registered.""" + r1 = FixedRubric(0.5) + r2 = FixedRubric(0.7) + + rubric = Sequential(r1, r2) + + children = list(rubric.children()) + assert len(children) == 2 + assert r1 in children + assert r2 in children + + def test_len_and_getitem(self): + """__len__ and __getitem__ work correctly.""" + r1 = FixedRubric(0.5) + r2 = FixedRubric(0.7) + + rubric = Sequential(r1, r2) + + assert len(rubric) == 2 + assert rubric[0] is r1 + assert rubric[1] is r2 + + +class TestGate: + """Test Gate container.""" + + def test_gate_passes_above_threshold(self): + """Returns child score when above threshold.""" + rubric = Gate(FixedRubric(0.8), threshold=0.5) + result = rubric("action", "obs") + assert result == 0.8 + + def test_gate_fails_below_threshold(self): + """Returns 0 when child score is below threshold.""" + rubric = Gate(FixedRubric(0.4), threshold=0.5) + result = rubric("action", "obs") + assert result == 0.0 + + def test_gate_passes_at_threshold(self): + """Returns score when exactly at threshold.""" + rubric = Gate(FixedRubric(0.5), threshold=0.5) + result = rubric("action", "obs") + assert result == 0.5 + + def test_gate_default_threshold(self): + """Default threshold is 1.0.""" + # Passes only with perfect score + rubric = Gate(FixedRubric(1.0)) + assert rubric("action", "obs") == 1.0 + + rubric2 = Gate(FixedRubric(0.99)) + assert rubric2("action", "obs") == 0.0 + + def test_gate_child_registered(self): + """Child rubric is auto-registered.""" + child = FixedRubric(0.5) + rubric = Gate(child, threshold=0.3) + + children = list(rubric.children()) + assert len(children) == 1 + assert child in 
children + + +class TestWeightedSum: + """Test WeightedSum container.""" + + def test_single_rubric_weight_one(self): + """Single rubric with weight 1.0.""" + rubric = WeightedSum([FixedRubric(0.8)], [1.0]) + result = rubric("action", "obs") + assert result == 0.8 + + def test_two_rubrics_equal_weights(self): + """Two rubrics with equal weights.""" + rubric = WeightedSum( + [FixedRubric(0.6), FixedRubric(0.8)], + [0.5, 0.5], + ) + result = rubric("action", "obs") + assert result == pytest.approx(0.7) + + def test_weighted_combination(self): + """Weighted combination with different weights.""" + rubric = WeightedSum( + [FixedRubric(1.0), FixedRubric(0.0)], + [0.7, 0.3], + ) + result = rubric("action", "obs") + assert result == pytest.approx(0.7) + + def test_weights_must_sum_to_one(self): + """Raises error if weights don't sum to 1.0.""" + with pytest.raises(ValueError, match="must sum to 1.0"): + WeightedSum([FixedRubric(0.5), FixedRubric(0.5)], [0.5, 0.3]) + + def test_lengths_must_match(self): + """Raises error if lengths don't match.""" + with pytest.raises(ValueError, match="must match"): + WeightedSum([FixedRubric(0.5), FixedRubric(0.5)], [1.0]) + + def test_children_registered(self): + """Child rubrics are auto-registered.""" + r1 = FixedRubric(0.5) + r2 = FixedRubric(0.7) + + rubric = WeightedSum([r1, r2], [0.5, 0.5]) + + children = list(rubric.children()) + assert len(children) == 2 + assert r1 in children + assert r2 in children + + def test_weights_property(self): + """weights property returns copy of weights.""" + rubric = WeightedSum([FixedRubric(0.5)], [1.0]) + + weights = rubric.weights + assert weights == [1.0] + + # Modifying copy shouldn't affect internal state + weights.append(0.5) + assert rubric.weights == [1.0] + + +class TestRubricList: + """Test RubricList container.""" + + def test_empty_list(self): + """Empty list has length 0.""" + rubric = RubricList() + assert len(rubric) == 0 + + def test_init_with_rubrics(self): + """Initialize with list of rubrics.""" + r1 = FixedRubric(0.5) + r2 = FixedRubric(0.7) + + rubric = RubricList([r1, r2]) + + assert len(rubric) == 2 + assert rubric[0] is r1 + assert rubric[1] is r2 + + def test_append(self): + """Append adds rubric to list.""" + rubric = RubricList() + r1 = FixedRubric(0.5) + + rubric.append(r1) + + assert len(rubric) == 1 + assert rubric[0] is r1 + + def test_extend(self): + """Extend adds multiple rubrics.""" + rubric = RubricList() + r1 = FixedRubric(0.5) + r2 = FixedRubric(0.7) + + rubric.extend([r1, r2]) + + assert len(rubric) == 2 + + def test_iteration(self): + """Can iterate over rubrics.""" + r1 = FixedRubric(0.5) + r2 = FixedRubric(0.7) + + rubric = RubricList([r1, r2]) + + items = list(rubric) + assert items == [r1, r2] + + def test_children_registered(self): + """Child rubrics are auto-registered.""" + r1 = FixedRubric(0.5) + r2 = FixedRubric(0.7) + + rubric = RubricList([r1, r2]) + + children = list(rubric.children()) + assert len(children) == 2 + assert r1 in children + assert r2 in children + + def test_forward_not_implemented(self): + """forward() raises NotImplementedError.""" + rubric = RubricList([FixedRubric(0.5)]) + + with pytest.raises(NotImplementedError): + rubric("action", "obs") + + +class TestRubricDict: + """Test RubricDict container.""" + + def test_empty_dict(self): + """Empty dict has length 0.""" + rubric = RubricDict() + assert len(rubric) == 0 + + def test_init_with_dict(self): + """Initialize with dictionary of rubrics.""" + r1 = FixedRubric(0.5) + r2 = FixedRubric(0.7) + + rubric 
= RubricDict({"game1": r1, "game2": r2}) + + assert len(rubric) == 2 + assert rubric["game1"] is r1 + assert rubric["game2"] is r2 + + def test_setitem_and_getitem(self): + """__setitem__ and __getitem__ work.""" + rubric = RubricDict() + r1 = FixedRubric(0.5) + + rubric["game1"] = r1 + + assert rubric["game1"] is r1 + + def test_contains(self): + """__contains__ works.""" + rubric = RubricDict({"game1": FixedRubric(0.5)}) + + assert "game1" in rubric + assert "game2" not in rubric + + def test_keys_values_items(self): + """keys(), values(), items() work.""" + r1 = FixedRubric(0.5) + r2 = FixedRubric(0.7) + + rubric = RubricDict({"game1": r1, "game2": r2}) + + assert set(rubric.keys()) == {"game1", "game2"} + assert set(rubric.values()) == {r1, r2} + assert set(rubric.items()) == {("game1", r1), ("game2", r2)} + + def test_iteration(self): + """Can iterate over keys.""" + rubric = RubricDict({"game1": FixedRubric(0.5), "game2": FixedRubric(0.7)}) + + keys = list(rubric) + assert set(keys) == {"game1", "game2"} + + def test_update(self): + """update() adds rubrics from dict.""" + rubric = RubricDict({"game1": FixedRubric(0.5)}) + rubric.update({"game2": FixedRubric(0.7)}) + + assert len(rubric) == 2 + assert "game2" in rubric + + def test_children_registered(self): + """Child rubrics are auto-registered.""" + r1 = FixedRubric(0.5) + r2 = FixedRubric(0.7) + + rubric = RubricDict({"game1": r1, "game2": r2}) + + children = list(rubric.children()) + assert len(children) == 2 + assert r1 in children + assert r2 in children + + def test_forward_not_implemented(self): + """forward() raises NotImplementedError.""" + rubric = RubricDict({"game1": FixedRubric(0.5)}) + + with pytest.raises(NotImplementedError): + rubric("action", "obs") + + +class TestContainerComposition: + """Test composing containers together.""" + + def test_sequential_of_gates(self): + """Sequential of Gate rubrics for hierarchical gating.""" + rubric = Sequential( + Gate(FixedRubric(1.0)), # Must pass completely + Gate(FixedRubric(0.6), threshold=0.5), # Must be >= 0.5 + FixedRubric(0.9), # Final score + ) + result = rubric("action", "obs") + assert result == 0.9 + + def test_sequential_fails_early(self): + """Sequential stops when Gate fails.""" + r3 = CountingRubric(0.9) + + rubric = Sequential( + Gate(FixedRubric(0.3), threshold=0.5), # Fails + r3, + ) + result = rubric("action", "obs") + + assert result == 0.0 + assert r3.call_count == 0 + + def test_weighted_sum_of_gates(self): + """WeightedSum with Gate rubrics.""" + rubric = WeightedSum( + [ + Gate(FixedRubric(0.8), threshold=0.5), # Passes: 0.8 + Gate(FixedRubric(0.3), threshold=0.5), # Fails: 0.0 + ], + [0.6, 0.4], + ) + result = rubric("action", "obs") + # 0.8 * 0.6 + 0.0 * 0.4 = 0.48 + assert result == pytest.approx(0.48) + + def test_nested_named_rubrics(self): + """Can traverse nested rubrics with named_rubrics().""" + inner = Sequential( + Gate(FixedRubric(1.0), threshold=0.5), + FixedRubric(0.8), + ) + outer = RubricDict({"task": inner}) + + paths = dict(outer.named_rubrics()) + + # Should have paths for all nested rubrics + assert "task" in paths + assert "task.rubric_0" in paths # Gate + assert "task.rubric_1" in paths # FixedRubric + # Gate's child + assert "task.rubric_0.rubric" in paths diff --git a/tests/core/test_rubrics/test_environment_integration.py b/tests/core/test_rubrics/test_environment_integration.py new file mode 100644 index 000000000..ab6fe6d2f --- /dev/null +++ b/tests/core/test_rubrics/test_environment_integration.py @@ -0,0 +1,256 @@ +# 
Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Tests for rubric integration with Environment base class.""" + +import pytest +from typing import Any, Optional, List, Tuple + +from openenv.core.env_server.interfaces import Environment +from openenv.core.env_server.types import Action, Observation, State +from openenv.core.rubrics import Rubric, TrajectoryRubric + + +# Test fixtures - using Pydantic models (not dataclasses) + + +class MockAction(Action): + """Simple action for testing.""" + + content: str = "test" + + +class MockObservation(Observation): + """Simple observation for testing.""" + + content: str = "" + + +class MockState(State): + """Simple state for testing.""" + + pass + + +class FixedRubric(Rubric): + """Rubric that returns a fixed score.""" + + def __init__(self, score: float = 1.0): + super().__init__() + self.score = score + + def forward(self, action: Any, observation: Any) -> float: + return self.score + + +class CountingRubric(Rubric): + """Rubric that counts calls and returns action-dependent score.""" + + def __init__(self): + super().__init__() + self.call_count = 0 + + def forward(self, action: Any, observation: Any) -> float: + self.call_count += 1 + # Return score based on action content + if hasattr(action, "content"): + if action.content == "good": + return 1.0 + elif action.content == "bad": + return 0.0 + return 0.5 + + +class MockTrajectoryRubric(TrajectoryRubric): + """Trajectory rubric for testing reset behavior.""" + + def score_trajectory(self, trajectory: List[Tuple[Any, Any]]) -> float: + return 1.0 if trajectory else 0.0 + + def compute_step_rewards(self) -> List[float]: + return [1.0] * len(self._trajectory) + + +class SimpleEnvironment(Environment[MockAction, MockObservation, MockState]): + """Minimal environment implementation for testing.""" + + def __init__(self, rubric: Optional[Rubric] = None): + super().__init__(rubric=rubric) + self._state = MockState() + + def reset( + self, + seed: Optional[int] = None, + episode_id: Optional[str] = None, + **kwargs: Any, + ) -> MockObservation: + self._reset_rubric() # Reset rubric state + self._state = MockState(episode_id=episode_id) + return MockObservation(content="initial") + + def step( + self, + action: MockAction, + timeout_s: Optional[float] = None, + **kwargs: Any, + ) -> MockObservation: + obs = MockObservation(content=f"response to {action.content}") + obs.reward = self._apply_rubric(action, obs) + return obs + + @property + def state(self) -> MockState: + return self._state + + +class TestEnvironmentRubricIntegration: + """Test rubric integration with Environment base class.""" + + def test_environment_without_rubric(self): + """Environment works without a rubric.""" + env = SimpleEnvironment() + assert env.rubric is None + + obs = env.reset() + assert obs.content == "initial" + + obs = env.step(MockAction(content="test")) + assert obs.reward == 0.0 # Default when no rubric + + def test_environment_with_rubric(self): + """Environment uses rubric for reward computation.""" + rubric = FixedRubric(0.75) + env = SimpleEnvironment(rubric=rubric) + + assert env.rubric is rubric + + env.reset() + obs = env.step(MockAction(content="test")) + + assert obs.reward == 0.75 + + def test_rubric_called_each_step(self): + """Rubric is called on each step.""" + rubric = CountingRubric() + env = SimpleEnvironment(rubric=rubric) + + env.reset() 
+ assert rubric.call_count == 0 + + env.step(MockAction(content="a")) + assert rubric.call_count == 1 + + env.step(MockAction(content="b")) + assert rubric.call_count == 2 + + def test_rubric_receives_action_and_observation(self): + """Rubric receives both action and observation.""" + rubric = CountingRubric() + env = SimpleEnvironment(rubric=rubric) + + env.reset() + + obs = env.step(MockAction(content="good")) + assert obs.reward == 1.0 + + obs = env.step(MockAction(content="bad")) + assert obs.reward == 0.0 + + def test_rubric_reset_on_env_reset(self): + """Rubric state is reset when environment resets.""" + rubric = MockTrajectoryRubric() + env = SimpleEnvironment(rubric=rubric) + + env.reset() + env.step(MockAction(content="a")) + env.step(MockAction(content="b")) + + assert len(rubric._trajectory) == 2 + + env.reset() + assert len(rubric._trajectory) == 0 # Reset clears trajectory + + def test_rubric_introspection(self): + """Can introspect rubric from environment.""" + + class CompositeRubric(Rubric): + def __init__(self): + super().__init__() + self.child1 = FixedRubric(0.5) + self.child2 = FixedRubric(0.8) + + def forward(self, action, obs): + return (self.child1(action, obs) + self.child2(action, obs)) / 2 + + rubric = CompositeRubric() + env = SimpleEnvironment(rubric=rubric) + + env.reset() + env.step(MockAction(content="test")) + + # Can introspect child scores + assert env.rubric is not None + named = dict(env.rubric.named_children()) + assert "child1" in named + assert "child2" in named + assert named["child1"].last_score == 0.5 + assert named["child2"].last_score == 0.8 + + def test_apply_rubric_without_rubric(self): + """_apply_rubric returns 0.0 when no rubric is set.""" + env = SimpleEnvironment() + action = MockAction(content="test") + obs = MockObservation(content="result") + + result = env._apply_rubric(action, obs) + assert result == 0.0 + + def test_reset_rubric_without_rubric(self): + """_reset_rubric is safe when no rubric is set.""" + env = SimpleEnvironment() + env._reset_rubric() # Should not raise + + +class TestEnvironmentRubricLifecycle: + """Test rubric lifecycle with multiple episodes.""" + + def test_multiple_episodes(self): + """Rubric handles multiple episodes correctly.""" + rubric = MockTrajectoryRubric() + env = SimpleEnvironment(rubric=rubric) + + # Episode 1 + env.reset() + env.step(MockAction(content="a")) + env.step(MockAction(content="b")) + ep1_len = len(rubric._trajectory) + + # Episode 2 + env.reset() + env.step(MockAction(content="c")) + ep2_len = len(rubric._trajectory) + + assert ep1_len == 2 + assert ep2_len == 1 # Reset cleared previous episode + + def test_rubric_hooks_work(self): + """Rubric hooks work through environment.""" + rubric = FixedRubric(0.9) + env = SimpleEnvironment(rubric=rubric) + + hook_calls = [] + + def hook(r, action, obs, result): + hook_calls.append(result) + + rubric.register_forward_hook(hook) + + env.reset() + env.step(MockAction(content="a")) + env.step(MockAction(content="b")) + + assert len(hook_calls) == 2 + assert hook_calls == [0.9, 0.9] diff --git a/tests/core/test_rubrics/test_trajectory_rubric.py b/tests/core/test_rubrics/test_trajectory_rubric.py new file mode 100644 index 000000000..60a6f5eec --- /dev/null +++ b/tests/core/test_rubrics/test_trajectory_rubric.py @@ -0,0 +1,362 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +"""Tests for TrajectoryRubric and ExponentialDiscountingTrajectoryRubric.""" + +import pytest +from dataclasses import dataclass +from typing import Any, List, Tuple + +from openenv.core.rubrics.trajectory import ( + TrajectoryRubric, + ExponentialDiscountingTrajectoryRubric, +) + + +@dataclass +class MockObservation: + """Mock observation for testing.""" + + done: bool = False + metadata: dict = None + + def __post_init__(self): + if self.metadata is None: + self.metadata = {} + + +@dataclass +class MockAction: + """Mock action for testing.""" + + value: str = "move" + metadata: dict = None + + def __post_init__(self): + if self.metadata is None: + self.metadata = {} + + +class WinLossRubric(ExponentialDiscountingTrajectoryRubric): + """Example rubric that scores 1.0 for win, 0.0 for loss, 0.5 for draw.""" + + def score_trajectory(self, trajectory: List[Tuple[Any, Any]]) -> float: + if not trajectory: + return 0.0 + _, final_obs = trajectory[-1] + outcome = getattr(final_obs, "metadata", {}).get("outcome") + if outcome == "win": + return 1.0 + elif outcome == "loss": + return 0.0 + else: + return 0.5 + + +class EqualCreditRubric(TrajectoryRubric): + """Rubric that gives equal credit to all steps.""" + + def score_trajectory(self, trajectory: List[Tuple[Any, Any]]) -> float: + if not trajectory: + return 0.0 + _, final_obs = trajectory[-1] + return final_obs.metadata.get("score", 0.0) + + def compute_step_rewards(self) -> List[float]: + if not self._trajectory: + return [] + score = self.score_trajectory(self._trajectory) + return [score] * len(self._trajectory) + + +class TestTrajectoryRubricBasics: + """Test basic TrajectoryRubric functionality.""" + + def test_abstract_methods_required(self): + """Cannot instantiate TrajectoryRubric without implementing abstract methods.""" + with pytest.raises(TypeError): + TrajectoryRubric() + + def test_returns_intermediate_until_done(self): + """Returns intermediate_reward for non-terminal steps.""" + rubric = EqualCreditRubric(intermediate_reward=0.0) + + obs1 = MockObservation(done=False) + result = rubric(MockAction(), obs1) + + assert result == 0.0 + assert len(rubric._trajectory) == 1 + + def test_returns_score_when_done(self): + """Returns trajectory score when done=True.""" + rubric = EqualCreditRubric(intermediate_reward=0.0) + + obs1 = MockObservation(done=False) + obs2 = MockObservation(done=True, metadata={"score": 0.8}) + + rubric(MockAction(), obs1) + result = rubric(MockAction(), obs2) + + assert result == 0.8 + assert len(rubric._trajectory) == 2 + + def test_custom_intermediate_reward(self): + """Intermediate reward can be customized.""" + rubric = EqualCreditRubric(intermediate_reward=0.1) + + obs = MockObservation(done=False) + result = rubric(MockAction(), obs) + + assert result == 0.1 + + def test_reset_clears_trajectory(self): + """reset() clears the accumulated trajectory.""" + rubric = EqualCreditRubric() + + rubric(MockAction(), MockObservation(done=False)) + rubric(MockAction(), MockObservation(done=False)) + assert len(rubric._trajectory) == 2 + + rubric.reset() + assert len(rubric._trajectory) == 0 + + def test_trajectory_property_returns_copy(self): + """trajectory property returns a copy.""" + rubric = EqualCreditRubric() + + rubric(MockAction(), MockObservation(done=False)) + trajectory = rubric.trajectory + + # Modifying the copy should not affect internal state + trajectory.clear() + assert len(rubric._trajectory) == 1 + + +class TestExponentialDiscounting: + """Test ExponentialDiscountingTrajectoryRubric.""" + + 
def test_gamma_validation(self): + """Gamma must be in [0, 1].""" + with pytest.raises(ValueError): + WinLossRubric(gamma=-0.1) + + with pytest.raises(ValueError): + WinLossRubric(gamma=1.5) + + def test_gamma_one_equal_credit(self): + """With gamma=1.0, all steps get equal credit.""" + rubric = WinLossRubric(gamma=1.0) + + # Simulate 3-step episode with win + rubric(MockAction(), MockObservation(done=False)) + rubric(MockAction(), MockObservation(done=False)) + rubric(MockAction(), MockObservation(done=True, metadata={"outcome": "win"})) + + step_rewards = rubric.compute_step_rewards() + + assert len(step_rewards) == 3 + assert step_rewards[0] == 1.0 + assert step_rewards[1] == 1.0 + assert step_rewards[2] == 1.0 + + def test_gamma_zero_final_only(self): + """With gamma=0.0, only final step gets reward.""" + rubric = WinLossRubric(gamma=0.0) + + rubric(MockAction(), MockObservation(done=False)) + rubric(MockAction(), MockObservation(done=False)) + rubric(MockAction(), MockObservation(done=True, metadata={"outcome": "win"})) + + step_rewards = rubric.compute_step_rewards() + + assert step_rewards == [0.0, 0.0, 1.0] + + def test_gamma_discounting_pattern(self): + """With 0 < gamma < 1, later steps get higher reward.""" + rubric = WinLossRubric(gamma=0.5) + + rubric(MockAction(), MockObservation(done=False)) + rubric(MockAction(), MockObservation(done=False)) + rubric(MockAction(), MockObservation(done=True, metadata={"outcome": "win"})) + + step_rewards = rubric.compute_step_rewards() + + # r_t = gamma^(T-1-t) * R_final, T=3, R_final=1.0 + # t=0: 0.5^2 = 0.25 + # t=1: 0.5^1 = 0.5 + # t=2: 0.5^0 = 1.0 + assert step_rewards[0] == pytest.approx(0.25) + assert step_rewards[1] == pytest.approx(0.5) + assert step_rewards[2] == pytest.approx(1.0) + + def test_gamma_099_standard_discounting(self): + """With gamma=0.99, standard RL discounting pattern.""" + rubric = WinLossRubric(gamma=0.99) + + # 5-step episode with win + for _ in range(4): + rubric(MockAction(), MockObservation(done=False)) + rubric(MockAction(), MockObservation(done=True, metadata={"outcome": "win"})) + + step_rewards = rubric.compute_step_rewards() + + # Verify discounting order: later steps get more + for i in range(len(step_rewards) - 1): + assert step_rewards[i] < step_rewards[i + 1] + + # Final step gets full reward + assert step_rewards[-1] == pytest.approx(1.0) + + def test_loss_outcome(self): + """Loss returns 0.0 for all steps.""" + rubric = WinLossRubric(gamma=0.99) + + rubric(MockAction(), MockObservation(done=False)) + rubric(MockAction(), MockObservation(done=True, metadata={"outcome": "loss"})) + + step_rewards = rubric.compute_step_rewards() + + assert step_rewards == [0.0, 0.0] + + def test_draw_outcome(self): + """Draw returns 0.5 for all steps (with discounting).""" + rubric = WinLossRubric(gamma=1.0) + + rubric(MockAction(), MockObservation(done=False)) + rubric(MockAction(), MockObservation(done=True, metadata={"outcome": "draw"})) + + step_rewards = rubric.compute_step_rewards() + + assert step_rewards == [0.5, 0.5] + + def test_empty_trajectory(self): + """compute_step_rewards() returns empty list for empty trajectory.""" + rubric = WinLossRubric(gamma=0.99) + + step_rewards = rubric.compute_step_rewards() + + assert step_rewards == [] + + +class TestTrajectoryRubricStateSerialization: + """Test state_dict serialization for trajectory rubrics.""" + + def test_trajectory_rubric_state_dict(self): + """TrajectoryRubric serializes intermediate_reward.""" + rubric = EqualCreditRubric(intermediate_reward=0.2) + + 
state = rubric.state_dict() + + assert state["intermediate_reward"] == 0.2 + + def test_trajectory_rubric_load_state_dict(self): + """TrajectoryRubric loads intermediate_reward from state.""" + rubric = EqualCreditRubric(intermediate_reward=0.0) + + rubric.load_state_dict({"intermediate_reward": 0.3}) + + assert rubric.intermediate_reward == 0.3 + + def test_exponential_discounting_state_dict(self): + """ExponentialDiscountingTrajectoryRubric serializes gamma.""" + rubric = WinLossRubric(gamma=0.95, intermediate_reward=0.1) + + state = rubric.state_dict() + + assert state["gamma"] == 0.95 + assert state["intermediate_reward"] == 0.1 + + def test_exponential_discounting_load_state_dict(self): + """ExponentialDiscountingTrajectoryRubric loads gamma from state.""" + rubric = WinLossRubric(gamma=0.99) + + rubric.load_state_dict({"gamma": 0.9, "intermediate_reward": 0.2}) + + assert rubric.gamma == 0.9 + assert rubric.intermediate_reward == 0.2 + + +class TestTrajectoryRubricHooks: + """Test that hooks work with trajectory rubrics.""" + + def test_hooks_called_each_step(self): + """Forward hooks are called on each step.""" + rubric = EqualCreditRubric() + hook_calls = [] + + def hook(r, action, obs, result): + hook_calls.append(result) + + rubric.register_forward_hook(hook) + + rubric(MockAction(), MockObservation(done=False)) + rubric(MockAction(), MockObservation(done=True, metadata={"score": 0.7})) + + assert len(hook_calls) == 2 + assert hook_calls[0] == 0.0 # intermediate + assert hook_calls[1] == 0.7 # final + + +class TestTrajectoryRubricEdgeCases: + """Test edge cases.""" + + def test_single_step_episode(self): + """Single-step episode (immediately done).""" + rubric = WinLossRubric(gamma=0.99) + + rubric(MockAction(), MockObservation(done=True, metadata={"outcome": "win"})) + + step_rewards = rubric.compute_step_rewards() + + assert step_rewards == [1.0] + + def test_very_long_episode(self): + """Long episode (100 steps).""" + rubric = WinLossRubric(gamma=0.99) + + for _ in range(99): + rubric(MockAction(), MockObservation(done=False)) + rubric(MockAction(), MockObservation(done=True, metadata={"outcome": "win"})) + + step_rewards = rubric.compute_step_rewards() + + assert len(step_rewards) == 100 + # First step should have gamma^99 reward + assert step_rewards[0] == pytest.approx(0.99**99) + # Last step should have full reward + assert step_rewards[-1] == 1.0 + + def test_observation_without_done_attribute(self): + """Handles observations without done attribute (defaults to False).""" + rubric = EqualCreditRubric() + + class NoDoneObs: + pass + + obs = NoDoneObs() + result = rubric(MockAction(), obs) + + # Should treat as not done + assert result == 0.0 + assert len(rubric._trajectory) == 1 + + def test_multiple_episodes_with_reset(self): + """Multiple episodes with reset between them.""" + rubric = WinLossRubric(gamma=1.0) + + # Episode 1: win + rubric(MockAction(), MockObservation(done=False)) + rubric(MockAction(), MockObservation(done=True, metadata={"outcome": "win"})) + ep1_rewards = rubric.compute_step_rewards() + + rubric.reset() + + # Episode 2: loss + rubric(MockAction(), MockObservation(done=True, metadata={"outcome": "loss"})) + ep2_rewards = rubric.compute_step_rewards() + + assert ep1_rewards == [1.0, 1.0] + assert ep2_rewards == [0.0]
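For reference, a brief end-to-end sketch of the delayed-reward workflow these tests exercise. GameObs, WinLoss, and the "winner" metadata key are invented for illustration and are not part of this diff; only ExponentialDiscountingTrajectoryRubric and its compute_step_rewards() contract come from the code above.

    # Hypothetical usage sketch of the trajectory rubric added in this diff.
    from dataclasses import dataclass, field
    from typing import Any, List, Tuple

    from openenv.core.rubrics import ExponentialDiscountingTrajectoryRubric


    @dataclass
    class GameObs:
        done: bool = False
        metadata: dict = field(default_factory=dict)


    class WinLoss(ExponentialDiscountingTrajectoryRubric):
        def score_trajectory(self, trajectory: List[Tuple[Any, Any]]) -> float:
            _, final_obs = trajectory[-1]
            return 1.0 if final_obs.metadata.get("winner") == "agent" else 0.0


    rubric = WinLoss(gamma=0.9)
    rubric("e4", GameObs())                                           # 0.0 (intermediate step)
    rubric("Qh5#", GameObs(done=True, metadata={"winner": "agent"}))  # 1.0 (episode done)
    print(rubric.compute_step_rewards())                              # [0.9, 1.0]

With gamma=1.0 both steps would receive 1.0, and with gamma=0.0 only the final step would, matching the per-step formula r_t = gamma^(T-1-t) * R_final documented above.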