Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions checkpoint/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Added

- Add colocated runtime helpers for Pathways MTC.

## [0.11.36] - 2026-04-14

### Added
Expand Down
26 changes: 19 additions & 7 deletions checkpoint/orbax/checkpoint/_src/futures/signaling_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,18 +357,30 @@ def wait_at_barrier(
)


def mark_pathways_colocated_runtime_active() -> None:
"""Marks the current Python process as the Pathways colocated runtime."""
multihost.mark_pathways_colocated_runtime_active()
get_signaling_client.cache_clear()


@functools.lru_cache()
def get_signaling_client() -> SignalingClient:
"""Returns the signaling client to use for the current environment."""
if multihost.is_jax_distributed_client_initialized():
logging.info("Using JaxDistributedSignalingClient")
return JaxDistributedSignalingClient()
else:
process_count = multihost.process_count()
if process_count > 1:
raise RuntimeError(
"ThreadSafeKeyValueSignalingClient should only be used in a single"
f" controller setup, process count: {process_count}."
)

# Verify that we are either in a Pathways backend, Pathways colocated
# runtime, or single process environment.
if (
multihost.is_pathways_backend()
or multihost.is_pathways_colocated_runtime_active()
or (process_count := multihost.process_count()) == 1
):
logging.info("Using ThreadSafeKeyValueSignalingClient")
return ThreadSafeKeyValueSignalingClient()

raise RuntimeError(
"ThreadSafeKeyValueSignalingClient should only be used in a single"
f" controller setup, process count: {process_count}."
)
90 changes: 88 additions & 2 deletions checkpoint/orbax/checkpoint/_src/futures/signaling_client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,10 +285,12 @@ class TestGetSignalingClient(absltest.TestCase):

def setUp(self):
super().setUp()
multihost._PATHWAYS_COLOCATED_RUNTIME_ACTIVE = False # pylint: disable=protected-access
signaling_client.get_signaling_client.cache_clear()

def tearDown(self):
super().tearDown()
multihost._PATHWAYS_COLOCATED_RUNTIME_ACTIVE = False # pylint: disable=protected-access
signaling_client.get_signaling_client.cache_clear()

@mock.patch.object(multihost, "is_jax_distributed_client_initialized")
Expand All @@ -312,13 +314,19 @@ def test_returns_jax_client_when_initialized(
client2 = signaling_client.get_signaling_client()
self.assertIs(client, client2)

@mock.patch.object(multihost, "is_pathways_backend", return_value=False)
@mock.patch.object(multihost, "process_count", return_value=1)
@mock.patch.object(multihost, "is_jax_distributed_client_initialized")
def test_returns_thread_safe_client_when_not_initialized(self, mock_is_init):
def test_returns_thread_safe_client_when_single_process_not_initialized(
self, mock_is_init, mock_process_count, mock_is_pathways_backend
):
mock_is_init.return_value = False

client = signaling_client.get_signaling_client()

mock_is_init.assert_called_once()
mock_is_pathways_backend.assert_called_once()
mock_process_count.assert_called_once()
self.assertIsInstance(
client, signaling_client.ThreadSafeKeyValueSignalingClient
)
Expand All @@ -330,9 +338,85 @@ def test_returns_thread_safe_client_when_not_initialized(self, mock_is_init):
@mock.patch.object(
multihost, "is_jax_distributed_client_initialized", return_value=False
)
@mock.patch.object(multihost, "is_pathways_backend", return_value=True)
@mock.patch.object(multihost, "process_count", return_value=2)
def test_returns_thread_safe_client_when_pathways_active(
self, mock_process_count, mock_is_pathways_backend, mock_is_init
):
client = signaling_client.get_signaling_client()

self.assertIsInstance(
client, signaling_client.ThreadSafeKeyValueSignalingClient
)
mock_is_init.assert_called_once()
mock_is_pathways_backend.assert_called_once()
mock_process_count.assert_not_called()

@mock.patch.object(
multihost, "is_jax_distributed_client_initialized", return_value=False
)
@mock.patch.object(multihost, "is_pathways_backend", return_value=False)
@mock.patch.object(multihost, "process_count", return_value=16)
def test_returns_thread_safe_client_when_explicit_colocated_runtime_active(
self,
mock_process_count,
mock_is_pathways_backend,
mock_is_init,
):
multihost._PATHWAYS_COLOCATED_RUNTIME_ACTIVE = True # pylint: disable=protected-access

client = signaling_client.get_signaling_client()

self.assertIsInstance(
client, signaling_client.ThreadSafeKeyValueSignalingClient
)
mock_is_init.assert_called_once()
mock_is_pathways_backend.assert_called_once()
mock_process_count.assert_not_called()

@mock.patch.object(multihost, "is_jax_distributed_client_initialized")
@mock.patch.object(multihost, "process_count", return_value=2)
@mock.patch.object(multihost, "is_pathways_backend", return_value=False)
@mock.patch.object(multihost, "get_jax_distributed_client")
def test_mark_pathways_colocated_runtime_active_clears_cached_client(
self,
mock_get_jax_client,
mock_is_pathways_backend,
mock_process_count,
mock_is_init,
):
del mock_is_pathways_backend, mock_process_count
mock_is_init.return_value = True
mock_get_jax_client.return_value = mock.Mock()

client = signaling_client.get_signaling_client()
self.assertIsInstance(
client, signaling_client.JaxDistributedSignalingClient
)

mock_is_init.return_value = False
signaling_client.mark_pathways_colocated_runtime_active()

client = signaling_client.get_signaling_client()
self.assertIsInstance(
client, signaling_client.ThreadSafeKeyValueSignalingClient
)
self.assertTrue(multihost.is_pathways_colocated_runtime_active())

@mock.patch.object(
multihost, "is_jax_distributed_client_initialized", return_value=False
)
@mock.patch.object(multihost, "is_pathways_backend", return_value=False)
@mock.patch.object(
multihost, "is_pathways_colocated_runtime_active", return_value=False
)
@mock.patch.object(multihost, "process_count", return_value=2)
def test_raises_error_when_multiprocess_and_not_initialized(
self, mock_is_init, mock_process_count
self,
mock_process_count,
mock_is_pathways_colocated_runtime_active,
mock_is_pathways_backend,
mock_is_init,
):

with self.assertRaisesRegex(
Expand All @@ -342,6 +426,8 @@ def test_raises_error_when_multiprocess_and_not_initialized(
):
signaling_client.get_signaling_client()
mock_is_init.assert_called_once()
mock_is_pathways_backend.assert_called_once()
mock_is_pathways_colocated_runtime_active.assert_called_once()
mock_process_count.assert_called_once()


Expand Down
Loading
Loading