11 changes: 9 additions & 2 deletions README.md
@@ -145,10 +145,17 @@ A [PR](https://github.com/vllm-project/vllm/pull/24488) is opened to the vLLM pr
Run a simple correctness test for checkpoint_engine

```bash
torchrun --nproc-per-node 8 tests/test_update.py
pytest tests/test_update.py
```

`test_update.py` is designed to be run only with `pytest`. Please don't run it directly with `torchrun`.

The other unit tests can also be run with `pytest`. Only `test_update.py` requires GPUs; the rest can run on CPUs. To run only the CPU tests, use:

```bash
pytest tests/ -m "not gpu"
```
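
For context, the `gpu` marker used with `-m "not gpu"` has to be registered with pytest. A minimal sketch of one common way to declare such a marker in a `conftest.py` (illustrative only; the repository may register it differently, e.g. via `pyproject.toml`):

```python
# conftest.py -- illustrative sketch, not necessarily how this repo registers the marker
def pytest_configure(config):
    # Register the "gpu" marker so `-m "not gpu"` cleanly deselects GPU-only tests
    # instead of emitting an unknown-marker warning.
    config.addinivalue_line("markers", "gpu: tests that require GPUs")


# GPU-only tests would then be tagged with @pytest.mark.gpu.
```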

Other unit tests can be done with pytest.
## SGLang Integration

Checkpoint Engine provides efficient distributed checkpoint loading for SGLang inference servers, significantly reducing model loading time for large models and multi-node setups.
127 changes: 71 additions & 56 deletions checkpoint_engine/ps.py
@@ -753,7 +753,7 @@ def __init__(
Args:
auto_pg: Whether to automatically initialize the process group.
Notice that if auto_pg is True, will destroy the process group after update.
mem_fraction: The proportion (as a fraction) of the current free CUDA memory for allocation.
mem_fraction: The proportion (as a fraction) of the current free device memory for allocation.
"""
self._rank = rank or int(os.environ.get("RANK", None))
self._world_size = world_size or int(os.environ.get("WORLD_SIZE", None))
@@ -988,21 +988,22 @@ def update(
if self._rank not in ranks:
return
self._update_per_bucket(checkpoint_name, req_func, ranks)
if self._auto_pg:

except Exception as e:
logger.exception(
f"[rank{self._rank}] update checkpoint {checkpoint_name} with ranks {ranks} error {e}"
)
raise
finally:
if self._auto_pg and (not ranks or self._rank in ranks):
dist.destroy_process_group()

self.device_manager.device_module.empty_cache()

logger.info(
f"[rank{self._rank}] update checkpoint {checkpoint_name} with ranks {ranks} done. "
f"Current CUDA allocated {self.device_manager.device_module.memory_allocated() / 1024 / 1024} MB, "
f"Current device allocated {self.device_manager.device_module.memory_allocated() / 1024 / 1024} MB, "
f"reserved {self.device_manager.device_module.memory_reserved() / 1024 / 1024} MB."
)
except Exception as e:
logger.exception(
f"[rank{self._rank}] update checkpoint {checkpoint_name} with ranks {ranks} error {e}"
)
raise

def _bind_zmq_socket(self) -> tuple[zmq.Socket, list[tuple[str, str]]]:
def zmq_handle(device_uuid: str) -> str:
@@ -1019,7 +1020,7 @@ def _detect_bucket_size(self, *, disable_h2d_buffer: bool = False) -> tuple[int,
# auto detect bucket size
tensor = torch.tensor(
[
# proportion of current cuda free memory bytes
# proportion of current device free memory bytes
int(
float(self.device_manager.device_module.mem_get_info()[0]) * self._mem_fraction
),
@@ -1183,7 +1184,7 @@ def _update_per_bucket(

if not need_update:
return
# first execute a barrier to avoid subsequent cuda oom
# first execute a barrier to avoid subsequent device oom
dist.barrier()

bucket_size, disable_h2d_buffer = self._detect_bucket_size()
@@ -1232,52 +1233,66 @@ def _update_per_bucket(
socket.send_pyobj(handle)

gidx = 0
ret_code = torch.zeros((), device=self.device_manager.device_type, dtype=torch.int64)
bcast_rank_map = _get_bcast_rank_map(self._world_size, ranks)
for i in range(max_len):
if i < len(receiver_rank_buckets) and not disable_h2d_buffer:
self._copy_to_buffer(
checkpoint_name,
receiver_rank_buckets[i][1],
h2d_buffer,
receiver_rank_buckets[i][0] if ranks else None,
)
for receiver_rank, _buckets in buckets_by_receiver_rank.items():
if i >= len(_buckets):
continue
bucket = _buckets[i]
alloc, reserved = (
self.device_manager.device_module.memory_allocated() / 1024 / 1024,
self.device_manager.device_module.memory_reserved() / 1024 / 1024,
)
self._logger_rank0(
f"[rank{self._rank}] begin to update bucket {gidx + 1}/{len(buckets)} receiver_rank {receiver_rank} in checkpoint {checkpoint_name}, bucket_size: {bucket.size / 1024 / 1024:.2f}MiB, length: {len(bucket.items)}. "
f"Current CUDA allocated {alloc:.2f} MB, "
f"reserved {reserved:.2f} MB."
)
start = gidx % 2 * bucket_size
buffer_b: torch.Tensor = buffer[start : start + bucket.size]
if receiver_rank == self._rank:
if disable_h2d_buffer:
self._copy_to_buffer(checkpoint_name, bucket, buffer_b)
else:
buffer_b.data.copy_(h2d_buffer[: bucket.size])
brank = bcast_rank_map[receiver_rank]
dist.broadcast(buffer_b, src=brank)
socket.recv()
dist.barrier()
socket.send_pyobj(_to_named_tensor(bucket.items, gidx % 2 * bucket_size))
gidx += 1

socket.recv()
socket.send_pyobj(None)
socket.recv()
req_thread.join()
dist.barrier()
socket.close()
if ranks and h2d_buffer is not None:
self._p2p_store.unregister_named_tensors([h2d_buffer_name])

self.device_manager.device_module.empty_cache()
try:
for i in range(max_len):
if i < len(receiver_rank_buckets) and not disable_h2d_buffer:
self._copy_to_buffer(
checkpoint_name,
receiver_rank_buckets[i][1],
h2d_buffer,
receiver_rank_buckets[i][0] if ranks else None,
)
for receiver_rank, _buckets in buckets_by_receiver_rank.items():
if i >= len(_buckets):
continue
bucket = _buckets[i]
alloc, reserved = (
self.device_manager.device_module.memory_allocated() / 1024 / 1024,
self.device_manager.device_module.memory_reserved() / 1024 / 1024,
)
self._logger_rank0(
f"[rank{self._rank}] begin to update bucket {gidx + 1}/{len(buckets)} receiver_rank {receiver_rank} in checkpoint {checkpoint_name}, bucket_size: {bucket.size / 1024 / 1024:.2f}MiB, length: {len(bucket.items)}. "
f"Current device allocated {alloc:.2f} MB, "
f"reserved {reserved:.2f} MB."
)
start = gidx % 2 * bucket_size
buffer_b: torch.Tensor = buffer[start : start + bucket.size]
if receiver_rank == self._rank:
if disable_h2d_buffer:
self._copy_to_buffer(checkpoint_name, bucket, buffer_b)
else:
buffer_b.data.copy_(h2d_buffer[: bucket.size])
brank = bcast_rank_map[receiver_rank]
dist.broadcast(buffer_b, src=brank)
resp = socket.recv()
if resp != b"":
exception_obj = pickle.loads(resp)
logger.error(
f"[rank{self._rank}] receive error response '{type(exception_obj).__name__}: {exception_obj}' from rank {receiver_rank} for bucket {gidx} in checkpoint {checkpoint_name}"
)
ret_code.fill_(1)
dist.all_reduce(ret_code, op=dist.ReduceOp.SUM)
self.device_manager.device_module.synchronize()
if ret_code.item() != 0:
# quit early if any rank failed
socket.send_pyobj(RuntimeError("Some workers failed to update weights"))
raise RuntimeError("Failed to update weights due to remote errors")
socket.send_pyobj(_to_named_tensor(bucket.items, gidx % 2 * bucket_size))
gidx += 1

socket.recv()
socket.send_pyobj(None)
socket.recv()
finally:
req_thread.join()
dist.barrier()
socket.close()
if ranks and h2d_buffer is not None:
self._p2p_store.unregister_named_tensors([h2d_buffer_name])

self.device_manager.device_module.empty_cache()


def _init_api(ps: ParameterServer) -> Any:
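The new `ret_code` handling above follows a common collective error-flag pattern: every rank contributes 0 on success or 1 on failure, the flags are summed with `all_reduce`, and all ranks abort together if any of them failed. A minimal standalone sketch of that pattern (illustrative only, assuming an already-initialized process group; the function name is not the repository's API):

```python
import torch
import torch.distributed as dist


def all_ranks_succeeded(local_ok: bool, device: str = "cuda") -> bool:
    """Return True only if every rank reports success; all ranks get the same answer."""
    flag = torch.zeros((), dtype=torch.int64, device=device)
    if not local_ok:
        flag.fill_(1)
    # Summing the flags makes any single failure visible to every rank.
    dist.all_reduce(flag, op=dist.ReduceOp.SUM)
    return flag.item() == 0


# Usage sketch inside a transfer loop:
#   if not all_ranks_succeeded(worker_ack == b""):
#       raise RuntimeError("Failed to update weights due to remote errors")
```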
65 changes: 40 additions & 25 deletions checkpoint_engine/worker.py
@@ -55,32 +55,47 @@ def update_weights_from_ipc(
socket = zmq_ctx.socket(zmq.REP)
socket.connect(zmq_handle)
buffer: torch.Tensor | None = None
device_mananger = DeviceManager()
while True:
payload: tuple[Callable, tuple] | list[FlattenedTensorMetadata] | None = socket.recv_pyobj()
if payload is None:
# means the update is done
if post_hook is not None:
post_hook()
device_mananger.device_module.synchronize()
socket.send(b"")
break
if isinstance(payload, tuple):
# an ipc handle that vLLM can use `func, args = handle`
# and `func(*args)` to rebuild GPU tensor.
buffer = _rebuild_ipc(payload, device_id)
assert buffer.dtype == torch.uint8
socket.send(b"")
continue
assert isinstance(payload, list)
run(_extract_weights(payload, buffer))
device_mananger.device_module.synchronize()
device_manager = DeviceManager()
try:
ipc_handle: tuple[Callable, tuple] = socket.recv_pyobj()
assert isinstance(ipc_handle, tuple)
buffer = _rebuild_ipc(ipc_handle, device_id)
assert buffer.dtype == torch.uint8
socket.send(b"")

socket.close()
del buffer
gc.collect()
device_mananger.device_module.empty_cache()
except Exception as e:
socket.send_pyobj(e)
socket.recv() # wait for ack
raise
try:
while True:
payload: list[FlattenedTensorMetadata] | Exception | None = socket.recv_pyobj()
if payload is None: # done signal
if post_hook is not None:
post_hook()
device_manager.device_module.synchronize()
socket.send(b"")
break
if isinstance(payload, list): # still updating weights
try:
run(_extract_weights(payload, buffer))
device_manager.device_module.synchronize()
socket.send(b"")
except Exception as e: # noqa: BLE001
# Send the exception back to the parameter server.
# Don't raise here: all workers should quit in the same way, by receiving the exception the PS broadcasts back.
socket.send_pyobj(e)
elif isinstance(
payload, Exception
): # error occurred, got force quit signal from Parameter Server
raise payload
else:
raise TypeError(f"Unexpected payload type: {type(payload)}")

finally:
socket.close()
del buffer
gc.collect()
device_manager.device_module.empty_cache()


class VllmColocateWorkerExtension:
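Taken together, the two files above implement a simple acknowledgement protocol over the ZeroMQ REP socket: an empty reply means the bucket was applied, a pickled exception means the worker failed, and an `Exception` payload pushed by the parameter server tells every worker to quit. A condensed, standalone sketch of the worker side (names such as `worker_update_loop` and `apply_bucket` are assumptions for illustration, not the repository's API):

```python
import zmq


def worker_update_loop(zmq_handle: str, apply_bucket) -> None:
    """Illustrative sketch: b"" acks success, a pickled exception reports failure,
    and an Exception payload from the parameter server means "quit now"."""
    ctx = zmq.Context()
    sock = ctx.socket(zmq.REP)
    sock.connect(zmq_handle)
    try:
        while True:
            payload = sock.recv_pyobj()
            if payload is None:            # normal completion signal
                sock.send(b"")
                break
            if isinstance(payload, Exception):
                raise payload              # force-quit pushed by the parameter server
            try:
                apply_bucket(payload)      # e.g. copy flattened tensors into the model
                sock.send(b"")             # empty reply == success
            except Exception as e:         # noqa: BLE001
                sock.send_pyobj(e)         # non-empty reply == failure; PS aborts all ranks
    finally:
        sock.close()
```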