From 7529e83a8a09b7bed579bd14ade6de92a6a0bab5 Mon Sep 17 00:00:00 2001 From: aumpatel Date: Fri, 28 Nov 2025 18:45:13 -0500 Subject: [PATCH 1/3] refactor: separate metric definition from metric value and share comparator instances - Add _MetricDefinition class to separate metric type (definition) from metric value (measurement) - Implement comparator caching via _get_shared_comparator() to ensure only one comparator object exists per unique (method, target, epsilon) combination - All solutions using the same comparison method now share the same comparator instance, reducing memory usage - Maintain full backward compatibility - no changes to Metric class API - Remove TODO comment as the refactoring addresses the concern about mixing metric and metric value concerns --- plexe/internal/models/entities/metric.py | 144 ++++++++++++++++++++--- 1 file changed, 127 insertions(+), 17 deletions(-) diff --git a/plexe/internal/models/entities/metric.py b/plexe/internal/models/entities/metric.py index f5355984..078954e9 100644 --- a/plexe/internal/models/entities/metric.py +++ b/plexe/internal/models/entities/metric.py @@ -22,6 +22,8 @@ from enum import Enum from functools import total_ordering +from typing import Optional +from weakref import WeakValueDictionary class ComparisonMethod(Enum): @@ -98,31 +100,138 @@ def compare(self, value1: float, value2: float) -> int: raise ValueError("Invalid comparison method.") -# todo: this class is a mess as it mixes concerns of a metric and a metric value; needs refactoring +# Internal cache for sharing MetricComparator instances across all metrics +# This ensures only one comparator object exists per unique (method, target, epsilon) combination +_comparator_cache: WeakValueDictionary = WeakValueDictionary() + + +def _get_shared_comparator(comparison_method: ComparisonMethod, target: Optional[float] = None, epsilon: float = 1e-9) -> MetricComparator: + """ + Get or create a shared MetricComparator instance. + + This function ensures that identical comparators are reused across all Metric instances, + reducing memory usage and ensuring consistency. + + :param comparison_method: The comparison method. + :param target: Optional target value for TARGET_IS_BETTER. + :param epsilon: Tolerance for floating-point comparisons. + :return: A shared MetricComparator instance. + """ + # Create a cache key from the comparator parameters + cache_key = (comparison_method, target, epsilon) + + # Try to get existing comparator from cache + if cache_key in _comparator_cache: + return _comparator_cache[cache_key] + + # Create new comparator and cache it + comparator = MetricComparator(comparison_method, target, epsilon) + _comparator_cache[cache_key] = comparator + return comparator + + +class _MetricDefinition: + """ + Internal class representing a metric type definition. + + This separates the metric definition (what it is) from the metric value (a measurement). + Metric definitions are immutable and can be shared across multiple metric values. + + This is an internal implementation detail - users should not interact with this class directly. + """ + + def __init__(self, name: str, comparator: MetricComparator): + """ + Initialize a metric definition. + + :param name: The name of the metric. + :param comparator: The shared comparator instance. + """ + self._name = name + self._comparator = comparator + + @property + def name(self) -> str: + """The name of the metric.""" + return self._name + + @property + def comparator(self) -> MetricComparator: + """The shared comparator instance.""" + return self._comparator + + def __eq__(self, other) -> bool: + """Check if two metric definitions are equal.""" + if not isinstance(other, _MetricDefinition): + return False + return ( + self.name == other.name + and self.comparator.comparison_method == other.comparator.comparison_method + and self.comparator.target == other.comparator.target + ) + + def __hash__(self) -> int: + """Hash the metric definition.""" + return hash((self.name, self.comparator.comparison_method, self.comparator.target)) + + @total_ordering class Metric: """ Represents a metric with a name, a value, and a comparator for determining which metric is better. + This class internally separates the metric definition (type) from the metric value (measurement), + and automatically shares comparator instances to reduce memory usage. + Attributes: name (str): The name of the metric (e.g., 'accuracy', 'loss'). value (float): The numeric value of the metric. - comparator (MetricComparator): The comparison logic for the metric. + comparator (MetricComparator): The comparison logic for the metric (shared instance). """ def __init__(self, name: str, value: float = None, comparator: MetricComparator = None, is_worst: bool = False): """ Initializes a Metric object. + The comparator instance is automatically shared with other metrics that have the same + comparison method, target, and epsilon values, reducing memory usage. + :param name: The name of the metric. :param value: The numeric value of the metric. :param comparator: An instance of MetricComparator for comparison logic. :param is_worst: Indicates if the metric value is the worst possible value. """ - self.name = name + # Store the metric value (dynamic, instance-specific) self.value = value - self.comparator = comparator self.is_worst = is_worst or value is None + + # Get or create a shared comparator instance + if comparator is not None: + # Use the shared comparator cache to ensure we reuse identical comparators + # This is the key optimization: identical comparators are shared across all metrics + shared_comparator = _get_shared_comparator( + comparison_method=comparator.comparison_method, + target=comparator.target, + epsilon=comparator.epsilon + ) + else: + # If no comparator provided, raise an error as it's required for a valid metric + # This maintains the same behavior as before + raise ValueError("Metric requires a comparator. Provide a MetricComparator instance.") + + # Create internal metric definition (separates type from value) + # This is the key separation: definition (what it is) vs value (measurement) + self._definition = _MetricDefinition(name=name, comparator=shared_comparator) + + @property + def name(self) -> str: + """The name of the metric (for backward compatibility).""" + return self._definition.name + + @property + def comparator(self) -> MetricComparator: + """The shared comparator instance (for backward compatibility).""" + return self._definition.comparator def __gt__(self, other) -> bool: """ @@ -141,17 +250,18 @@ def __gt__(self, other) -> bool: if other.is_worst: return True - if self.name != other.name: - raise ValueError("Cannot compare metrics with different names.") - - if self.comparator.comparison_method != other.comparator.comparison_method: - raise ValueError("Cannot compare metrics with different comparison methods.") - - if ( - self.comparator.comparison_method == ComparisonMethod.TARGET_IS_BETTER - and self.comparator.target != other.comparator.target - ): - raise ValueError("Cannot compare 'TARGET_IS_BETTER' metrics with different target values.") + # Compare using definitions - this is cleaner and ensures consistency + if self._definition != other._definition: + # Provide detailed error message for backward compatibility + if self.name != other.name: + raise ValueError("Cannot compare metrics with different names.") + if self.comparator.comparison_method != other.comparator.comparison_method: + raise ValueError("Cannot compare metrics with different comparison methods.") + if ( + self.comparator.comparison_method == ComparisonMethod.TARGET_IS_BETTER + and self.comparator.target != other.comparator.target + ): + raise ValueError("Cannot compare 'TARGET_IS_BETTER' metrics with different target values.") return self.comparator.compare(self.value, other.value) < 0 @@ -171,9 +281,9 @@ def __eq__(self, other) -> bool: if self.is_worst or other.is_worst: return False + # Use definition equality for cleaner comparison return ( - self.name == other.name - and self.comparator.comparison_method == other.comparator.comparison_method + self._definition == other._definition and self.comparator.compare(self.value, other.value) == 0 ) From f413dd3405c609c3553bc6ba3db913f00dd9ba7b Mon Sep 17 00:00:00 2001 From: aumpatel Date: Fri, 28 Nov 2025 18:55:00 -0500 Subject: [PATCH 2/3] fix: include epsilon in MetricDefinition equality and hash methods - Add epsilon to __eq__ method to ensure metrics with different epsilon values are correctly differentiated - Add epsilon to __hash__ method to maintain hash contract (must include all fields used in __eq__) - Fix redundant condition in Metric.__gt__ method This fixes critical bugs identified in code review that could cause incorrect metric comparisons. --- plexe/internal/models/entities/metric.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/plexe/internal/models/entities/metric.py b/plexe/internal/models/entities/metric.py index 078954e9..0478b6db 100644 --- a/plexe/internal/models/entities/metric.py +++ b/plexe/internal/models/entities/metric.py @@ -168,11 +168,12 @@ def __eq__(self, other) -> bool: self.name == other.name and self.comparator.comparison_method == other.comparator.comparison_method and self.comparator.target == other.comparator.target + and self.comparator.epsilon == other.comparator.epsilon ) def __hash__(self) -> int: """Hash the metric definition.""" - return hash((self.name, self.comparator.comparison_method, self.comparator.target)) + return hash((self.name, self.comparator.comparison_method, self.comparator.target, self.comparator.epsilon)) @total_ordering @@ -244,7 +245,7 @@ def __gt__(self, other) -> bool: if not isinstance(other, Metric): return NotImplemented - if self.is_worst or (self.is_worst and other.is_worst): + if self.is_worst: return False if other.is_worst: From ac5284323911b43066ef2177e2f112aaf5ecf429 Mon Sep 17 00:00:00 2001 From: aumpatel Date: Fri, 28 Nov 2025 19:03:11 -0500 Subject: [PATCH 3/3] style: format metric.py with black - Reformat code to comply with black formatting standards - Fixes CI formatting check failure --- plexe/internal/models/entities/metric.py | 45 +++++++++++------------- 1 file changed, 21 insertions(+), 24 deletions(-) diff --git a/plexe/internal/models/entities/metric.py b/plexe/internal/models/entities/metric.py index 0478b6db..ce313f05 100644 --- a/plexe/internal/models/entities/metric.py +++ b/plexe/internal/models/entities/metric.py @@ -105,13 +105,15 @@ def compare(self, value1: float, value2: float) -> int: _comparator_cache: WeakValueDictionary = WeakValueDictionary() -def _get_shared_comparator(comparison_method: ComparisonMethod, target: Optional[float] = None, epsilon: float = 1e-9) -> MetricComparator: +def _get_shared_comparator( + comparison_method: ComparisonMethod, target: Optional[float] = None, epsilon: float = 1e-9 +) -> MetricComparator: """ Get or create a shared MetricComparator instance. - + This function ensures that identical comparators are reused across all Metric instances, reducing memory usage and ensuring consistency. - + :param comparison_method: The comparison method. :param target: Optional target value for TARGET_IS_BETTER. :param epsilon: Tolerance for floating-point comparisons. @@ -119,11 +121,11 @@ def _get_shared_comparator(comparison_method: ComparisonMethod, target: Optional """ # Create a cache key from the comparator parameters cache_key = (comparison_method, target, epsilon) - + # Try to get existing comparator from cache if cache_key in _comparator_cache: return _comparator_cache[cache_key] - + # Create new comparator and cache it comparator = MetricComparator(comparison_method, target, epsilon) _comparator_cache[cache_key] = comparator @@ -133,33 +135,33 @@ def _get_shared_comparator(comparison_method: ComparisonMethod, target: Optional class _MetricDefinition: """ Internal class representing a metric type definition. - + This separates the metric definition (what it is) from the metric value (a measurement). Metric definitions are immutable and can be shared across multiple metric values. - + This is an internal implementation detail - users should not interact with this class directly. """ - + def __init__(self, name: str, comparator: MetricComparator): """ Initialize a metric definition. - + :param name: The name of the metric. :param comparator: The shared comparator instance. """ self._name = name self._comparator = comparator - + @property def name(self) -> str: """The name of the metric.""" return self._name - + @property def comparator(self) -> MetricComparator: """The shared comparator instance.""" return self._comparator - + def __eq__(self, other) -> bool: """Check if two metric definitions are equal.""" if not isinstance(other, _MetricDefinition): @@ -170,7 +172,7 @@ def __eq__(self, other) -> bool: and self.comparator.target == other.comparator.target and self.comparator.epsilon == other.comparator.epsilon ) - + def __hash__(self) -> int: """Hash the metric definition.""" return hash((self.name, self.comparator.comparison_method, self.comparator.target, self.comparator.epsilon)) @@ -205,30 +207,28 @@ def __init__(self, name: str, value: float = None, comparator: MetricComparator # Store the metric value (dynamic, instance-specific) self.value = value self.is_worst = is_worst or value is None - + # Get or create a shared comparator instance if comparator is not None: # Use the shared comparator cache to ensure we reuse identical comparators # This is the key optimization: identical comparators are shared across all metrics shared_comparator = _get_shared_comparator( - comparison_method=comparator.comparison_method, - target=comparator.target, - epsilon=comparator.epsilon + comparison_method=comparator.comparison_method, target=comparator.target, epsilon=comparator.epsilon ) else: # If no comparator provided, raise an error as it's required for a valid metric # This maintains the same behavior as before raise ValueError("Metric requires a comparator. Provide a MetricComparator instance.") - + # Create internal metric definition (separates type from value) # This is the key separation: definition (what it is) vs value (measurement) self._definition = _MetricDefinition(name=name, comparator=shared_comparator) - + @property def name(self) -> str: """The name of the metric (for backward compatibility).""" return self._definition.name - + @property def comparator(self) -> MetricComparator: """The shared comparator instance (for backward compatibility).""" @@ -283,10 +283,7 @@ def __eq__(self, other) -> bool: return False # Use definition equality for cleaner comparison - return ( - self._definition == other._definition - and self.comparator.compare(self.value, other.value) == 0 - ) + return self._definition == other._definition and self.comparator.compare(self.value, other.value) == 0 def __repr__(self) -> str: """