103 changes: 99 additions & 4 deletions taps/transformer/_proxy.py
@@ -1,9 +1,15 @@
from __future__ import annotations

import dataclasses
import json
import pathlib
import sys
from typing import Any
from typing import cast
from typing import Dict
from typing import Literal
from typing import TypeVar
from typing import Union

if sys.version_info >= (3, 11): # pragma: >=3.11 cover
from typing import Self
@@ -16,18 +22,30 @@
from proxystore.store import Store
from proxystore.store.config import ConnectorConfig
from proxystore.store.utils import resolve_async
from pydantic import ConfigDict
from pydantic import Field
from pydantic import model_validator

from taps.plugins import register
from taps.transformer._protocol import TransformerConfig

T = TypeVar('T')
JSON = Union[int, float, str, Dict[str, 'JSON']]
_PROXYSTORE_DIR = 'proxystore'
_PROXYSTORE_AGGREGATE_FILE = 'aggregated.json'
_PROXYSTORE_STATS_FILE = 'stats.jsonl'


@register('transformer')
class ProxyTransformerConfig(TransformerConfig):
"""[`ProxyTransformer`][taps.transformer.ProxyTransformer] plugin configuration.""" # noqa: E501
"""[`ProxyTransformer`][taps.transformer.ProxyTransformer] plugin configuration.

Note:
Extra arguments provided to this config will be passed as parameters
to the [`Store`][proxystore.store.Store].
""" # noqa: E501

model_config = ConfigDict(extra='allow') # type: ignore[misc]

name: Literal['proxystore'] = Field(
'proxystore',
@@ -36,21 +54,25 @@ class ProxyTransformerConfig(TransformerConfig):
connector: ConnectorConfig = Field(
description='Connector configuration.',
)
cache_size: int = Field(16, description='cache size')
async_resolve: bool = Field(
False,
description=(
'Asynchronously resolve proxies. Not compatible with '
'extract_target=True.'
),
)
cache_size: int = Field(16, description='cache size')
extract_target: bool = Field(
False,
description=(
'Extract the target from the proxy when resolving the identifier. '
'Not compatible with async_resolve=True.'
),
)
metrics: bool = Field(
False,
description='Enable recording operation metrics.',
)
populate_target: bool = Field(
True,
description='Populate target objects of newly created proxies.',
@@ -68,16 +90,26 @@ def _validate_mutex_options(self) -> Self:
def get_transformer(self) -> ProxyTransformer:
"""Create a transformer from the configuration."""
connector = self.connector.get_connector()

# Want register=True to be the default unless the user config
# has explicitly disabled it.
extra: dict[str, Any] = {'register': True}
# Guaranteed when config.extra is set to "allow"
assert self.model_extra is not None
extra.update(self.model_extra)

return ProxyTransformer(
store=Store(
'proxy-transformer',
connector=connector,
cache_size=self.cache_size,
metrics=self.metrics,
populate_target=self.populate_target,
register=True,
**extra,
),
async_resolve=self.async_resolve,
extract_target=self.extract_target,
metrics_dir=_PROXYSTORE_DIR if self.metrics else None,
)


@@ -95,6 +127,12 @@ class ProxyTransformer:
will return the target object. Otherwise, the proxy is returned
since a proxy can act as the target object. Not compatible
with `async_resolve=True`.
metrics_dir: If metrics recording is enabled on `store`, write the
recorded metrics to this directory when this transformer is closed.
Typically, `close()` is only called on the transformer instance in
the main TaPS process (i.e., `close()` is not called in worker
processes), so only metrics from the main process are recorded.
"""

def __init__(
@@ -103,6 +141,7 @@ def __init__(
*,
async_resolve: bool = False,
extract_target: bool = False,
metrics_dir: str | None = None,
) -> None:
if async_resolve and extract_target:
raise ValueError(
@@ -113,19 +152,26 @@
self.store = store
self.async_resolve = async_resolve
self.extract_target = extract_target
self.metrics_dir = (
pathlib.Path(metrics_dir).resolve()
if metrics_dir is not None
else None
)

def __repr__(self) -> str:
ctype = type(self).__name__
store = f'store={self.store}'
async_ = f'async_resolve={self.async_resolve}'
extract = f'extract_target={self.extract_target}'
return f'{ctype}({store}, {async_}, {extract})'
metrics = f'metrics_dir={self.metrics_dir}'
return f'{ctype}({store}, {async_}, {extract}, {metrics})'

def __getstate__(self) -> dict[str, Any]:
return {
'config': self.store.config(),
'async_resolve': self.async_resolve,
'extract_target': self.extract_target,
'metrics_dir': self.metrics_dir,
}

def __setstate__(self, state: dict[str, Any]) -> None:
@@ -136,11 +182,19 @@ def __setstate__(self, state: dict[str, Any]) -> None:
self.store = Store.from_config(state['config'])
self.async_resolve = state['async_resolve']
self.extract_target = state['extract_target']
self.metrics_dir = state['metrics_dir']

def close(self) -> None:
"""Close the transformer."""
self.store.close()

if self.metrics_dir is not None:
_write_metrics(
self.store,
self.metrics_dir / _PROXYSTORE_AGGREGATE_FILE,
self.metrics_dir / _PROXYSTORE_STATS_FILE,
)

def is_identifier(self, obj: Any) -> bool:
"""Check if the object is an identifier instance."""
return isinstance(obj, Proxy)
@@ -171,3 +225,44 @@ def resolve(self, identifier: Proxy[T]) -> T | Proxy[T]:
if self.async_resolve:
resolve_async(identifier)
return identifier


def _format_metrics(
store: Store[Any],
) -> tuple[dict[str, JSON], list[JSON]] | None:
if store.metrics is None:
return None

aggregated = {
key: cast(JSON, dataclasses.asdict(times))
for key, times in store.metrics.aggregate_times().items()
}

metrics = store.metrics._metrics.values()
jsonified = map(dataclasses.asdict, metrics)

return aggregated, list(jsonified)


def _write_metrics(
store: Store[Any],
aggregated_path: pathlib.Path,
stats_path: pathlib.Path,
) -> None:
metrics = _format_metrics(store)
if metrics is None:
return

aggregated, individual = metrics
if len(individual) == 0:
return

aggregated_path.parent.mkdir(parents=True, exist_ok=True)
with open(aggregated_path, 'w') as f:
json.dump(aggregated, f, indent=4)

stats_path.parent.mkdir(parents=True, exist_ok=True)
with open(stats_path, 'a') as f:
for stats in individual:
json.dump(stats, f)
f.write('\n')
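
To make the extra-argument passthrough above concrete, here is a minimal usage sketch, modeled on the new tests below. `register` is not a declared field on `ProxyTransformerConfig`, so `extra='allow'` collects it in `model_extra`, and `get_transformer()` forwards it to the `Store` constructor, overriding the `register=True` default.

from proxystore.store.config import ConnectorConfig

from taps.transformer import ProxyTransformerConfig

# 'register' is an extra field: it is not declared on the config, so
# it is captured via model_extra and passed through to Store().
config = ProxyTransformerConfig(
    connector=ConnectorConfig(kind='local'),
    metrics=True,
    register=False,
)

transformer = config.get_transformer()
proxy = transformer.transform('value')
resolved = transformer.resolve(proxy)  # a transparent proxy of 'value'
# close() writes aggregated.json and stats.jsonl under ./proxystore/
# because metrics=True sets metrics_dir on the transformer.
transformer.close()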
50 changes: 50 additions & 0 deletions tests/transformer/proxy_test.py
@@ -1,5 +1,6 @@
from __future__ import annotations

import os
import pathlib
import pickle

@@ -12,8 +13,10 @@
from proxystore.store.config import ConnectorConfig
from pydantic import ValidationError

from taps.run.utils import change_cwd
from taps.transformer import ProxyTransformer
from taps.transformer import ProxyTransformerConfig
from taps.transformer._proxy import _write_metrics


def test_file_config(tmp_path: pathlib.Path) -> None:
@@ -27,6 +30,18 @@ def test_file_config(tmp_path: pathlib.Path) -> None:
transformer.close()


def test_file_config_extras(tmp_path: pathlib.Path) -> None:
config = ProxyTransformerConfig(
connector=ConnectorConfig(
kind='file',
options={'store_dir': str(tmp_path)},
),
register=False,
)
transformer = config.get_transformer()
transformer.close()


def test_config_validation_error(tmp_path: pathlib.Path) -> None:
with pytest.raises(
ValidationError,
@@ -102,3 +117,38 @@ def test_proxy_transformer_pickling() -> None:
assert get_store(name) is not None

transformer.close()


def test_metrics_recording(tmp_path: pathlib.Path) -> None:
with change_cwd(tmp_path):
config = ProxyTransformerConfig(
connector=ConnectorConfig(kind='local'),
metrics=True,
register=False,
)

transformer = config.get_transformer()
obj = transformer.transform('value')
transformer.resolve(obj)
transformer.close()

assert isinstance(transformer.metrics_dir, pathlib.Path)
files = list(transformer.metrics_dir.iterdir())
assert len(files) == 2 # noqa: PLR2004


@pytest.mark.parametrize('metrics', (True, False))
def test_write_metrics_empty(metrics: bool, tmp_path: pathlib.Path) -> None:
with Store(
'test-write-metrics-disabled',
LocalConnector(),
metrics=metrics,
register=False,
) as store:
_write_metrics(
store,
tmp_path / 'aggregated.json',
tmp_path / 'stats.jsonl',
)

assert len(os.listdir(tmp_path)) == 0
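
For reference, a short sketch of reading back the two files that `_write_metrics` produces, assuming the default `proxystore` output directory: `aggregated.json` holds a single JSON object mapping each timing key to its aggregated times, while `stats.jsonl` holds one JSON object per line, one per recorded operation. The reader below is hypothetical and not part of this change.

import json
import pathlib

metrics_dir = pathlib.Path('proxystore')

# A single JSON object of aggregated times (written with json.dump).
with open(metrics_dir / 'aggregated.json') as f:
    aggregated = json.load(f)

# JSON lines: one serialized metrics dataclass per line.
with open(metrics_dir / 'stats.jsonl') as f:
    individual = [json.loads(line) for line in f]

print(f'{len(individual)} individual operation metrics recorded')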