diff --git a/taps/transformer/_proxy.py b/taps/transformer/_proxy.py
index b6653c92..845296a3 100644
--- a/taps/transformer/_proxy.py
+++ b/taps/transformer/_proxy.py
@@ -1,9 +1,15 @@
 from __future__ import annotations
 
+import dataclasses
+import json
+import pathlib
 import sys
 from typing import Any
+from typing import cast
+from typing import Dict
 from typing import Literal
 from typing import TypeVar
+from typing import Union
 
 if sys.version_info >= (3, 11):  # pragma: >=3.11 cover
     from typing import Self
@@ -16,6 +22,7 @@
 from proxystore.store import Store
 from proxystore.store.config import ConnectorConfig
 from proxystore.store.utils import resolve_async
+from pydantic import ConfigDict
 from pydantic import Field
 from pydantic import model_validator
 
@@ -23,11 +30,22 @@
 from taps.transformer._protocol import TransformerConfig
 
 T = TypeVar('T')
+JSON = Union[int, float, str, Dict[str, 'JSON']]
+_PROXYSTORE_DIR = 'proxystore'
+_PROXYSTORE_AGGREGATE_FILE = 'aggregated.json'
+_PROXYSTORE_STATS_FILE = 'stats.jsonl'
 
 
 @register('transformer')
 class ProxyTransformerConfig(TransformerConfig):
-    """[`ProxyTransformer`][taps.transformer.ProxyTransformer] plugin configuration."""  # noqa: E501
+    """[`ProxyTransformer`][taps.transformer.ProxyTransformer] plugin configuration.
+
+    Note:
+        Extra arguments provided to this config will be passed as parameters
+        to the [`Store`][proxystore.store.Store].
+    """  # noqa: E501
+
+    model_config = ConfigDict(extra='allow')  # type: ignore[misc]
 
     name: Literal['proxystore'] = Field(
         'proxystore',
@@ -36,7 +54,6 @@
     connector: ConnectorConfig = Field(
         description='Connector configuration.',
     )
-    cache_size: int = Field(16, description='cache size')
     async_resolve: bool = Field(
         False,
         description=(
@@ -44,6 +61,7 @@
             'extract_target=True.'
         ),
     )
+    cache_size: int = Field(16, description='Cache size.')
     extract_target: bool = Field(
         False,
         description=(
@@ -51,6 +69,10 @@
             'Not compatible with async_resolve=True.'
         ),
     )
+    metrics: bool = Field(
+        False,
+        description='Enable recording operation metrics.',
+    )
     populate_target: bool = Field(
         True,
         description='Populate target objects of newly created proxies.',
@@ -68,16 +90,26 @@ def _validate_mutex_options(self) -> Self:
     def get_transformer(self) -> ProxyTransformer:
         """Create a transformer from the configuration."""
        connector = self.connector.get_connector()
+
+        # Default to register=True unless the user config has
+        # explicitly disabled it.
+        extra: dict[str, Any] = {'register': True}
+        # model_extra is guaranteed to be non-None when extra='allow'.
+        assert self.model_extra is not None
+        extra.update(self.model_extra)
+
         return ProxyTransformer(
             store=Store(
                 'proxy-transformer',
                 connector=connector,
                 cache_size=self.cache_size,
+                metrics=self.metrics,
                 populate_target=self.populate_target,
-                register=True,
+                **extra,
             ),
             async_resolve=self.async_resolve,
             extract_target=self.extract_target,
+            metrics_dir=_PROXYSTORE_DIR if self.metrics else None,
         )
 
 
@@ -95,6 +127,12 @@ class ProxyTransformer:
             will return the target object. Otherwise, the proxy is
             returned since a proxy can act as the target object. Not
             compatible with `async_resolve=True`.
+        metrics_dir: If metrics recording is enabled on `store`, write the
+            recorded metrics to this directory when the transformer is
+            closed. Typically, `close()` is only called on the transformer
+            instance in the main TaPS process (i.e., `close()` is not called
+            in worker processes), so only the metrics from the main process
+            will be recorded.
     """
 
     def __init__(
@@ -103,6 +141,7 @@
         *,
         async_resolve: bool = False,
         extract_target: bool = False,
+        metrics_dir: str | None = None,
     ) -> None:
         if async_resolve and extract_target:
             raise ValueError(
@@ -113,19 +152,26 @@
         self.store = store
         self.async_resolve = async_resolve
         self.extract_target = extract_target
+        self.metrics_dir = (
+            pathlib.Path(metrics_dir).resolve()
+            if metrics_dir is not None
+            else None
+        )
 
     def __repr__(self) -> str:
         ctype = type(self).__name__
         store = f'store={self.store}'
         async_ = f'async_resolve={self.async_resolve}'
         extract = f'extract_target={self.extract_target}'
-        return f'{ctype}({store}, {async_}, {extract})'
+        metrics = f'metrics_dir={self.metrics_dir}'
+        return f'{ctype}({store}, {async_}, {extract}, {metrics})'
 
     def __getstate__(self) -> dict[str, Any]:
         return {
             'config': self.store.config(),
             'async_resolve': self.async_resolve,
             'extract_target': self.extract_target,
+            'metrics_dir': self.metrics_dir,
         }
 
     def __setstate__(self, state: dict[str, Any]) -> None:
@@ -136,11 +182,19 @@
         self.store = Store.from_config(state['config'])
         self.async_resolve = state['async_resolve']
         self.extract_target = state['extract_target']
+        self.metrics_dir = state['metrics_dir']
 
     def close(self) -> None:
         """Close the transformer."""
         self.store.close()
 
+        if self.metrics_dir is not None:
+            _write_metrics(
+                self.store,
+                self.metrics_dir / _PROXYSTORE_AGGREGATE_FILE,
+                self.metrics_dir / _PROXYSTORE_STATS_FILE,
+            )
+
     def is_identifier(self, obj: Any) -> bool:
         """Check if the object is an identifier instance."""
         return isinstance(obj, Proxy)
@@ -171,3 +225,44 @@
         if self.async_resolve:
             resolve_async(identifier)
         return identifier
+
+
+def _format_metrics(
+    store: Store[Any],
+) -> tuple[dict[str, JSON], list[JSON]] | None:
+    if store.metrics is None:
+        return None
+
+    aggregated = {
+        key: cast(JSON, dataclasses.asdict(times))
+        for key, times in store.metrics.aggregate_times().items()
+    }
+
+    metrics = store.metrics._metrics.values()
+    jsonified = map(dataclasses.asdict, metrics)
+
+    return aggregated, list(jsonified)
+
+
+def _write_metrics(
+    store: Store[Any],
+    aggregated_path: pathlib.Path,
+    stats_path: pathlib.Path,
+) -> None:
+    metrics = _format_metrics(store)
+    if metrics is None:
+        return
+
+    aggregated, individual = metrics
+    if len(individual) == 0:
+        return
+
+    aggregated_path.parent.mkdir(parents=True, exist_ok=True)
+    with open(aggregated_path, 'w') as f:
+        json.dump(aggregated, f, indent=4)
+
+    stats_path.parent.mkdir(parents=True, exist_ok=True)
+    with open(stats_path, 'a') as f:
+        for stats in individual:
+            json.dump(stats, f)
+            f.write('\n')
diff --git a/tests/transformer/proxy_test.py b/tests/transformer/proxy_test.py
index ea2c0aee..b25650db 100644
--- a/tests/transformer/proxy_test.py
+++ b/tests/transformer/proxy_test.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import os
 import pathlib
 import pickle
 
@@ -12,8 +13,10 @@
 from proxystore.store.config import ConnectorConfig
 from pydantic import ValidationError
 
+from taps.run.utils import change_cwd
 from taps.transformer import ProxyTransformer
 from taps.transformer import ProxyTransformerConfig
+from taps.transformer._proxy import _write_metrics
 
 
 def test_file_config(tmp_path: pathlib.Path) -> None:
@@ -27,6 +30,18 @@ def test_file_config(tmp_path: pathlib.Path) -> None:
     transformer.close()
 
 
+def test_file_config_extras(tmp_path: pathlib.Path) -> None:
+    config = ProxyTransformerConfig(
+        connector=ConnectorConfig(
+            kind='file',
+            options={'store_dir': str(tmp_path)},
+        ),
+        register=False,
+    )
+    transformer = config.get_transformer()
+    transformer.close()
+
+
 def test_config_validation_error(tmp_path: pathlib.Path) -> None:
     with pytest.raises(
         ValidationError,
@@ -102,3 +117,38 @@ def test_proxy_transformer_pickling() -> None:
     assert get_store(name) is not None
 
     transformer.close()
+
+
+def test_metrics_recording(tmp_path: pathlib.Path) -> None:
+    with change_cwd(tmp_path):
+        config = ProxyTransformerConfig(
+            connector=ConnectorConfig(kind='local'),
+            metrics=True,
+            register=False,
+        )
+
+        transformer = config.get_transformer()
+        obj = transformer.transform('value')
+        transformer.resolve(obj)
+        transformer.close()
+
+        assert isinstance(transformer.metrics_dir, pathlib.Path)
+        files = list(transformer.metrics_dir.iterdir())
+        assert len(files) == 2  # noqa: PLR2004
+
+
+@pytest.mark.parametrize('metrics', (True, False))
+def test_write_metrics_empty(metrics: bool, tmp_path: pathlib.Path) -> None:
+    with Store(
+        'test-write-metrics-disabled',
+        LocalConnector(),
+        metrics=metrics,
+        register=False,
+    ) as store:
+        _write_metrics(
+            store,
+            tmp_path / 'aggregated.json',
+            tmp_path / 'stats.jsonl',
+        )
+
+    assert len(os.listdir(tmp_path)) == 0
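
Usage sketch (not part of the patch): the snippet below mirrors `test_metrics_recording` and shows the two behaviors this diff adds — the new `metrics` field and the pass-through of extra config arguments (here `register=False`) to the underlying `Store`. The `local` connector kind and the output file names come from the patch itself; everything else is illustrative.

```python
from proxystore.store.config import ConnectorConfig

from taps.transformer import ProxyTransformerConfig

config = ProxyTransformerConfig(
    connector=ConnectorConfig(kind='local'),
    metrics=True,    # new field: enables metrics recording on the Store
    register=False,  # extra argument: forwarded to Store(**extra)
)

transformer = config.get_transformer()
proxy = transformer.transform('value')        # store the object, get a Proxy
assert transformer.resolve(proxy) == 'value'  # a Proxy compares equal to its target

# Because metrics=True, close() writes aggregated.json and stats.jsonl
# to ./proxystore/ (resolved against the CWD at construction time).
transformer.close()
assert transformer.metrics_dir is not None
print(sorted(p.name for p in transformer.metrics_dir.iterdir()))
```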