Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 58 additions & 2 deletions src/openenv/core/env_server/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,13 @@
# LICENSE file in the root directory of this source tree.

from abc import ABC, abstractmethod
from typing import Any, Generic, Optional, Protocol, TypedDict, TypeVar
from typing import Any, Generic, Optional, Protocol, TypedDict, TypeVar, TYPE_CHECKING

from .types import Action, Observation, State, EnvironmentMetadata

if TYPE_CHECKING:
from openenv.core.rubrics import Rubric

ActT = TypeVar("ActT", bound=Action)
ObsT = TypeVar("ObsT", bound=Observation)
StateT = TypeVar("StateT", bound=State)
Expand Down Expand Up @@ -94,6 +97,8 @@ class Environment(ABC, Generic[ActT, ObsT, StateT]):

Args:
transform: Optional transform to apply to observations
rubric: Optional rubric for reward computation. When provided, the
rubric's output can be used to set the observation's reward in step().

Class Attributes:
SUPPORTS_CONCURRENT_SESSIONS: Whether this environment supports concurrent sessions.
Expand All @@ -105,13 +110,30 @@ class Environment(ABC, Generic[ActT, ObsT, StateT]):
- The environment uses proper session isolation (e.g., unique working dirs)
- No shared mutable state exists between instances
- External resources (databases, APIs) can handle concurrent access

Attributes:
rubric: Optional rubric for computing rewards. Environments can set this
in __init__ and use it in step() to compute observation rewards.
Training infrastructure can access it for introspection:
for name, r in env.rubric.named_rubrics():
print(f"{name}: {r.last_score}")

See RFC 004 for rubric design: rfcs/004-rubrics.md
"""

# Class-level flag indicating whether this environment supports concurrent sessions
SUPPORTS_CONCURRENT_SESSIONS: bool = False

def __init__(self, transform: Optional[Transform[ObsT]] = None):
# Optional rubric for reward computation
rubric: Optional["Rubric"]

def __init__(
self,
transform: Optional[Transform[ObsT]] = None,
rubric: Optional["Rubric"] = None,
):
self.transform = transform
self.rubric = rubric

@abstractmethod
def reset(
Expand Down Expand Up @@ -185,6 +207,40 @@ def _apply_transform(self, observation: ObsT) -> ObsT:
return self.transform(observation)
return observation

def _apply_rubric(self, action: ActT, observation: ObsT) -> float:
"""Apply rubric if one is provided.

Args:
action: The action taken by the agent.
observation: The resulting observation.

Returns:
Reward value from the rubric, or 0.0 if no rubric is set.

Usage in step():
def step(self, action: MyAction, ...) -> MyObservation:
# ... execute action and create observation ...
observation.reward = self._apply_rubric(action, observation)
return observation
"""
if self.rubric is not None:
return self.rubric(action, observation)
return 0.0

def _reset_rubric(self) -> None:
"""Reset the rubric state if one is provided.

Call this in reset() to clear any trajectory state in the rubric.

Usage in reset():
def reset(self, ...) -> MyObservation:
self._reset_rubric()
# ... create initial observation ...
return observation
"""
if self.rubric is not None:
self.rubric.reset()

def close(self) -> None:
"""Clean up resources used by the environment.

Expand Down
37 changes: 37 additions & 0 deletions src/openenv/core/rubrics/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Rubrics for reward computation.

See RFC 004 for full design: rfcs/004-rubrics.md
"""

from openenv.core.rubrics.base import Rubric
from openenv.core.rubrics.containers import (
Sequential,
Gate,
WeightedSum,
RubricList,
RubricDict,
)
from openenv.core.rubrics.trajectory import (
TrajectoryRubric,
ExponentialDiscountingTrajectoryRubric,
)

__all__ = [
# Base
"Rubric",
# Containers
"Sequential",
"Gate",
"WeightedSum",
"RubricList",
"RubricDict",
# Trajectory
"TrajectoryRubric",
"ExponentialDiscountingTrajectoryRubric",
]
164 changes: 164 additions & 0 deletions src/openenv/core/rubrics/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Base Rubric class for reward computation.

Rubrics compute rewards from actions and observations. The API is modeled
after PyTorch's nn.Module: users implement forward(), and the framework
handles child registration and hooks.

See RFC 004 for full design: rfcs/004-rubrics.md
"""

from abc import ABC, abstractmethod
from typing import Any, Dict, Iterator, List, Optional, Tuple, Callable


class Rubric(ABC):
"""Abstract base class for reward computation.

A Rubric computes a reward signal from an action and observation.
Subclasses implement forward() to define the reward logic.

Usage:
class MyRubric(Rubric):
def forward(self, action, observation) -> float:
return 1.0 if action.valid else 0.0

rubric = MyRubric()
reward = rubric(action, observation)

Child rubrics are auto-registered when assigned as attributes,
enabling hierarchical composition and introspection.
"""

_rubric_children: Dict[str, "Rubric"]
_forward_hooks: List[Callable]
_forward_pre_hooks: List[Callable]
last_score: Optional[float]

def __init__(self):
# Use object.__setattr__ to avoid triggering __setattr__ during init
object.__setattr__(self, "_rubric_children", {})
object.__setattr__(self, "_forward_hooks", [])
object.__setattr__(self, "_forward_pre_hooks", [])
object.__setattr__(self, "last_score", None)

def __setattr__(self, name: str, value: Any) -> None:
# Auto-register child rubrics when assigned as attributes
if isinstance(value, Rubric):
self._rubric_children[name] = value
object.__setattr__(self, name, value)

def __call__(self, action: Any, observation: Any) -> float:
"""Evaluate the rubric with hooks.

Args:
action: The action taken by the agent.
observation: The resulting observation.

Returns:
Reward value (typically 0.0 to 1.0).
"""
# Pre-forward hooks
for hook in self._forward_pre_hooks:
hook(self, action, observation)

# Compute reward
result = self.forward(action, observation)
self.last_score = result

# Post-forward hooks
for hook in self._forward_hooks:
hook(self, action, observation, result)

return result

@abstractmethod
def forward(self, action: Any, observation: Any) -> float:
"""Compute the reward. Implement this in subclasses.

Args:
action: The action taken by the agent.
observation: The resulting observation.

Returns:
Reward value (typically 0.0 to 1.0).
"""
raise NotImplementedError

def register_forward_hook(
self, hook: Callable[["Rubric", Any, Any, float], None]
) -> None:
"""Register a hook called after forward().

Args:
hook: Callable with signature (rubric, action, observation, result).
"""
self._forward_hooks.append(hook)

def register_forward_pre_hook(
self, hook: Callable[["Rubric", Any, Any], None]
) -> None:
"""Register a hook called before forward().

Args:
hook: Callable with signature (rubric, action, observation).
"""
self._forward_pre_hooks.append(hook)

def children(self) -> Iterator["Rubric"]:
"""Iterate over immediate child rubrics."""
yield from self._rubric_children.values()

def named_children(self) -> Iterator[Tuple[str, "Rubric"]]:
"""Iterate over immediate child rubrics with names."""
yield from self._rubric_children.items()

def rubrics(self) -> Iterator["Rubric"]:
"""Iterate over all descendant rubrics (depth-first)."""
for child in self._rubric_children.values():
yield child
yield from child.rubrics()

def named_rubrics(self, prefix: str = "") -> Iterator[Tuple[str, "Rubric"]]:
"""Iterate over all descendant rubrics with dot-separated names."""
for name, child in self._rubric_children.items():
full_name = f"{prefix}.{name}" if prefix else name
yield full_name, child
yield from child.named_rubrics(full_name)

def get_rubric(self, path: str) -> "Rubric":
"""Access a nested rubric by dot-separated path.

Args:
path: Dot-separated path (e.g., "code.syntax").

Returns:
The rubric at the specified path.

Raises:
KeyError: If the path does not exist.
"""
parts = path.split(".")
current = self
for part in parts:
if part not in current._rubric_children:
raise KeyError(f"Rubric path not found: {path}")
current = current._rubric_children[part]
return current

def reset(self) -> None:
"""Reset any internal state. Override in subclasses if needed."""
pass

def state_dict(self) -> Dict[str, Any]:
"""Serialize rubric configuration for checkpointing."""
return {}

def load_state_dict(self, state: Dict[str, Any]) -> None:
"""Load rubric configuration from checkpoint."""
pass
Loading