hello! we are seeing the following failure pattern leading to timeouts in our medium-sized (512-node) training runs on GPUs
- everyone finishes the train step around the same time, and calls manager.save around the same time
- save involves slice operations. when jax hits XLA ops in eager mode, each one compiles down to a single-op graph
- so we get 10s-100s of slice calls, each of which goes through compilation and autotuning
- compilation cache access (in our case, and probably for others too) funnels through a single NFS path, so some ranks get starved waiting on file locks
- the starved ranks don't make it to the barrier in time, which causes the timeout
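To make the "single-op graph per slice" point concrete, here's a minimal sketch (CPU backend; shapes are illustrative, not from our run) showing that each eagerly called lax.slice with new static bounds takes its own trip through the compile-or-get-cached path:

```python
import os
os.environ["JAX_PLATFORMS"] = "cpu"

import jax
import jax.numpy as jnp

# surfaces one "Compiling ..." log line per distinct eager slice program
jax.config.update("jax_log_compiles", True)

x = jnp.arange(32.0).reshape(8, 4)

# each eager slice with new static sizes is a fresh single-op XLA program,
# i.e. a separate compilation and a separate compilation-cache access
a = jax.lax.slice_in_dim(x, 0, 4)  # one program for the (4, 4) slice
b = jax.lax.slice_in_dim(x, 0, 2)  # a second program for the (2, 4) slice
```

With hundreds of parameters of distinct shapes, this multiplies into the 10s-100s of compilations described above.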
example callstack from a starved rank:
Thread 3266587 (idle): "MainThread"
backend_compile_and_load (jax/_src/compiler.py:362)
wrapper (jax/_src/profiler.py:384)
_compile_and_write_cache (jax/_src/compiler.py:746)
compile_or_get_cached (jax/_src/compiler.py:478)
_cached_compilation (jax/_src/interpreters/pxla.py:2843)
from_hlo (jax/_src/interpreters/pxla.py:3066)
compile (jax/_src/interpreters/pxla.py:2515)
_pjit_call_impl_python (jax/_src/pjit.py:1207)
_run_python_pjit (jax/_src/pjit.py:140)
cache_miss (jax/_src/pjit.py:255)
reraise_with_filtered_traceback (jax/_src/traceback_util.py:197)
apply_primitive (jax/_src/dispatch.py:91)
_slice_impl (jax/_src/lax/slicing.py:1484)
process_primitive (jax/_src/core.py:1208)
bind_with_trace (jax/_src/core.py:664)
_true_bind (jax/_src/core.py:652)
bind (jax/_src/core.py:636)
slice (jax/_src/lax/slicing.py:113)
slice_in_dim (jax/_src/lax/slicing.py:1016)
data (orbax/checkpoint/_src/serialization/replica_slices.py:100)
async_transfer_slice (orbax/checkpoint/_src/serialization/replica_slices.py:444)
transfer_arrays_to_host (orbax/checkpoint/_src/serialization/replica_slices.py:471)
_serialize_arrays_batches_without_dispatcher (orbax/checkpoint/_src/serialization/jax_array_handlers.py:363)
_serialize_arrays (orbax/checkpoint/_src/serialization/jax_array_handlers.py:447)
serialize (orbax/checkpoint/_src/serialization/jax_array_handlers.py:1073)
_logging_serialize (orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py:149)
repro script
import os

os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"
os.environ["JAX_PLATFORMS"] = "cpu"

import tempfile

import jax
import jax.numpy as jnp
import numpy as np
import orbax.checkpoint as ocp

jax.config.update("jax_log_compiles", True)

mesh = jax.sharding.Mesh(np.array(jax.devices()).reshape(2, 2), ("data", "model"))
sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec("data", None))

pytree = {
    f"p{i}": jax.device_put(jnp.ones(s), sharding)
    for i, s in enumerate(
        [(8, 4), (16, 8), (32, 16), (64, 32), (128, 64), (24, 12), (48, 6)]
    )
}

with tempfile.TemporaryDirectory() as tmpdir:
    with ocp.CheckpointManager(tmpdir) as mgr:
        mgr.save(0, args=ocp.args.StandardSave(pytree))
for now we're working around this by monkeypatching orbax so that it batches/jits the slices by mesh: https://gist.github.com/jfc4050/27b8817e497b70a467e13720b5de20af.
Maybe there's a way to make jax skip compilation for eagerly called slice ops. Otherwise, happy to open a PR if that would be helpful
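For reference, the core idea of the workaround looks roughly like this (a simplified sketch, not the gist verbatim; the function name and the halving slice are illustrative): trace all the slices inside one jitted function so XLA compiles a single program instead of N single-op graphs.

```python
import jax
import jax.numpy as jnp

# Sketch of the batching idea: instead of N eager slice ops
# (N compilations, N compilation-cache accesses), trace them together
# under one jit so XLA sees a single program and the cache is hit once.
@jax.jit
def slice_all(arrays):
    # slice bounds are static at trace time, so one compilation
    # covers every array in the batch
    return [a[: a.shape[0] // 2] for a in arrays]

halves = slice_all([jnp.ones((8, 4)), jnp.ones((16, 8))])
```

This trades per-op caching granularity for one compilation per batch of array shapes, which sidesteps the NFS lock contention entirely.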