Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 127 additions & 17 deletions plexe/internal/models/entities/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@

from enum import Enum
from functools import total_ordering
from typing import Optional
from weakref import WeakValueDictionary


class ComparisonMethod(Enum):
Expand Down Expand Up @@ -98,31 +100,138 @@ def compare(self, value1: float, value2: float) -> int:
raise ValueError("Invalid comparison method.")


# todo: this class is a mess as it mixes concerns of a metric and a metric value; needs refactoring
# Internal cache for sharing MetricComparator instances across all metrics
# This ensures only one comparator object exists per unique (method, target, epsilon) combination
_comparator_cache: WeakValueDictionary = WeakValueDictionary()


def _get_shared_comparator(
    comparison_method: ComparisonMethod, target: Optional[float] = None, epsilon: float = 1e-9
) -> MetricComparator:
    """
    Get or create a shared MetricComparator instance.

    Comparators are keyed by ``(comparison_method, target, epsilon)``. A request with the same
    parameters receives the cached object while that object is still alive; the cache only holds
    its values weakly, so an unused comparator may be dropped and recreated later.

    :param comparison_method: The comparison method.
    :param target: Optional target value for TARGET_IS_BETTER.
    :param epsilon: Tolerance for floating-point comparisons.
    :return: A shared MetricComparator instance.
    """
    cache_key = (comparison_method, target, epsilon)

    # Fetch in a single step: because values are weakly referenced, an entry that passes an
    # ``in`` test can disappear before it is indexed, which would surface as a KeyError.
    comparator = _comparator_cache.get(cache_key)
    if comparator is None:
        comparator = MetricComparator(comparison_method, target, epsilon)
        _comparator_cache[cache_key] = comparator
    return comparator


class _MetricDefinition:
    """
    Internal class representing a metric type definition.

    This separates the metric definition (what it is) from the metric value (a measurement).
    A definition exposes its name and comparator as read-only properties and can be shared
    across multiple metric values.

    This is an internal implementation detail - users should not interact with this class directly.
    """

    def __init__(self, name: str, comparator: "MetricComparator"):
        """
        Initialize a metric definition.

        :param name: The name of the metric.
        :param comparator: The shared comparator instance.
        """
        self._name = name
        self._comparator = comparator

    @property
    def name(self) -> str:
        """The name of the metric."""
        return self._name

    @property
    def comparator(self) -> "MetricComparator":
        """The shared comparator instance."""
        return self._comparator

    def __eq__(self, other) -> bool:
        """
        Check if two metric definitions are equal.

        Two definitions are equal when they have the same name and their comparators agree on
        comparison method, target and epsilon.

        :param other: The object to compare against.
        :return: True if both definitions describe the same metric type, False otherwise.
        """
        if not isinstance(other, _MetricDefinition):
            return False
        mine, theirs = self.comparator, other.comparator
        return (
            self.name == other.name
            and mine.comparison_method == theirs.comparison_method
            and mine.target == theirs.target
            # Tolerance is part of the comparison semantics: with different epsilons the
            # result of comparing two values could depend on which side's comparator is used.
            and mine.epsilon == theirs.epsilon
        )
Comment on lines 165 to 174
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

logic: epsilon is missing from equality check - two metric definitions with different epsilon values will be considered equal, causing incorrect behavior when comparing metrics with different epsilon tolerances

Suggested change
def __eq__(self, other) -> bool:
"""Check if two metric definitions are equal."""
if not isinstance(other, _MetricDefinition):
return False
return (
self.name == other.name
and self.comparator.comparison_method == other.comparator.comparison_method
and self.comparator.target == other.comparator.target
)
def __eq__(self, other) -> bool:
"""Check if two metric definitions are equal."""
if not isinstance(other, _MetricDefinition):
return False
return (
self.name == other.name
and self.comparator.comparison_method == other.comparator.comparison_method
and self.comparator.target == other.comparator.target
and self.comparator.epsilon == other.comparator.epsilon
)
Prompt To Fix With AI
This is a comment left during a code review.
Path: plexe/internal/models/entities/metric.py
Line: 163:171

Comment:
**logic:** `epsilon` is missing from equality check - two metric definitions with different epsilon values will be considered equal, causing incorrect behavior when comparing metrics with different epsilon tolerances

```suggestion
    def __eq__(self, other) -> bool:
        """Check if two metric definitions are equal."""
        if not isinstance(other, _MetricDefinition):
            return False
        return (
            self.name == other.name
            and self.comparator.comparison_method == other.comparator.comparison_method
            and self.comparator.target == other.comparator.target
            and self.comparator.epsilon == other.comparator.epsilon
        )
```

How can I resolve this? If you propose a fix, please make it concise.


def __hash__(self) -> int:
    """Return a hash built from the name, comparison method and target."""
    # Only fields that __eq__ also compares are used, so equal definitions hash equally.
    comparator = self.comparator
    identity = (self.name, comparator.comparison_method, comparator.target)
    return hash(identity)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

logic: epsilon is missing from hash calculation - must include all fields used in __eq__ to maintain hash contract

Suggested change
def __hash__(self) -> int:
"""Hash the metric definition."""
return hash((self.name, self.comparator.comparison_method, self.comparator.target))
def __hash__(self) -> int:
"""Hash the metric definition."""
return hash((self.name, self.comparator.comparison_method, self.comparator.target, self.comparator.epsilon))
Prompt To Fix With AI
This is a comment left during a code review.
Path: plexe/internal/models/entities/metric.py
Line: 173:175

Comment:
**logic:** `epsilon` is missing from hash calculation - must include all fields used in `__eq__` to maintain hash contract

```suggestion
    def __hash__(self) -> int:
        """Hash the metric definition."""
        return hash((self.name, self.comparator.comparison_method, self.comparator.target, self.comparator.epsilon))
```

How can I resolve this? If you propose a fix, please make it concise.



@total_ordering
class Metric:
"""
Represents a metric with a name, a value, and a comparator for determining which metric is better.

This class internally separates the metric definition (type) from the metric value (measurement),
and automatically shares comparator instances to reduce memory usage.

Attributes:
name (str): The name of the metric (e.g., 'accuracy', 'loss').
value (float): The numeric value of the metric.
comparator (MetricComparator): The comparison logic for the metric.
comparator (MetricComparator): The comparison logic for the metric (shared instance).
"""

def __init__(
    self,
    name: str,
    value: Optional[float] = None,
    comparator: Optional[MetricComparator] = None,
    is_worst: bool = False,
):
    """
    Initializes a Metric object.

    The comparator instance is automatically shared with other metrics that have the same
    comparison method, target, and epsilon values, reducing memory usage.

    :param name: The name of the metric.
    :param value: The numeric value of the metric. A value of None marks the metric as worst.
    :param comparator: An instance of MetricComparator for comparison logic.
    :param is_worst: Indicates if the metric value is the worst possible value.
    :raises ValueError: If no comparator is provided.
    """
    # A comparator is mandatory; reject the call before any state is set on the instance.
    if comparator is None:
        raise ValueError("Metric requires a comparator. Provide a MetricComparator instance.")

    # Store the metric value (dynamic, instance-specific)
    self.value = value
    self.is_worst = is_worst or value is None

    # `name` and `comparator` are read-only properties backed by the definition, so they are
    # not assigned directly here (doing so would raise AttributeError).
    # Resolve the caller's comparator to the shared instance with identical parameters.
    shared_comparator = _get_shared_comparator(
        comparison_method=comparator.comparison_method,
        target=comparator.target,
        epsilon=comparator.epsilon,
    )

    # Create internal metric definition (separates type from value)
    self._definition = _MetricDefinition(name=name, comparator=shared_comparator)

@property
def name(self) -> str:
    """Name of the metric, read from the internal definition (kept for backward compatibility)."""
    definition = self._definition
    return definition.name

@property
def comparator(self) -> "MetricComparator":
    """Shared comparator, read from the internal definition (kept for backward compatibility)."""
    definition = self._definition
    return definition.comparator

def __gt__(self, other) -> bool:
"""
Expand All @@ -141,17 +250,18 @@ def __gt__(self, other) -> bool:
if other.is_worst:
return True

if self.name != other.name:
raise ValueError("Cannot compare metrics with different names.")

if self.comparator.comparison_method != other.comparator.comparison_method:
raise ValueError("Cannot compare metrics with different comparison methods.")

if (
self.comparator.comparison_method == ComparisonMethod.TARGET_IS_BETTER
and self.comparator.target != other.comparator.target
):
raise ValueError("Cannot compare 'TARGET_IS_BETTER' metrics with different target values.")
# Compare using definitions - this is cleaner and ensures consistency
if self._definition != other._definition:
# Provide detailed error message for backward compatibility
if self.name != other.name:
raise ValueError("Cannot compare metrics with different names.")
if self.comparator.comparison_method != other.comparator.comparison_method:
raise ValueError("Cannot compare metrics with different comparison methods.")
if (
self.comparator.comparison_method == ComparisonMethod.TARGET_IS_BETTER
and self.comparator.target != other.comparator.target
):
raise ValueError("Cannot compare 'TARGET_IS_BETTER' metrics with different target values.")

return self.comparator.compare(self.value, other.value) < 0

Expand All @@ -171,9 +281,9 @@ def __eq__(self, other) -> bool:
if self.is_worst or other.is_worst:
return False

# Use definition equality for cleaner comparison
return (
self.name == other.name
and self.comparator.comparison_method == other.comparator.comparison_method
self._definition == other._definition
and self.comparator.compare(self.value, other.value) == 0
)

Expand Down
Loading