
Commit a291782

optimize _update_per_bucket_p2p logic (#28)
1 parent bbc83db commit a291782

3 files changed: +260 -111 lines changed

README.md

Lines changed: 17 additions & 10 deletions
@@ -14,8 +14,8 @@ updating our [Kimi-K2](https://github.com/MoonshotAI/Kimi-K2) model (1 Trillion
The core weight update logic is in the `ParameterServer` class, a service colocated with the inference engines. It provides two implementations of weight update: Broadcast and P2P.

- - **Broadcast**: Used when a large number of inference instances need to update weights synchronously. This is the fastest implementation and should be used as the default update method. See `_update_per_bucket`.
- - **P2P**: Used when new inference instances are dynamically added (due to restarts or dynamic availability) while the existing instances are already serving requests. In this scenario, to avoid affecting the workloads on existing instances, we use the [`mooncake-transfer-engine`](https://github.com/kvcache-ai/Mooncake?tab=readme-ov-file#use-python-package) to send weights P2P from CPUs in existing instances to GPUs in new instances. See `_update_per_bucket_p2p`.
+ - **Broadcast**: Used when a large number of inference instances need to update weights synchronously. This is the fastest implementation and should be used as the default update method. See `_update_per_bucket` with `ranks == None or []`.
+ - **P2P**: Used when new inference instances are dynamically added (due to restarts or dynamic availability) while the existing instances are already serving requests. In this scenario, to avoid affecting the workloads on existing instances, we use the [`mooncake-transfer-engine`](https://github.com/kvcache-ai/Mooncake?tab=readme-ov-file#use-python-package) to send weights P2P from CPUs in existing instances to GPUs in new instances. See `_update_per_bucket` with `ranks` specified.
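The new wording folds both modes into a single `_update_per_bucket` entry point selected by the `ranks` argument. A minimal usage sketch, based only on the `ParameterServer.update(ranks=range(0, 16))` call quoted in the benchmark notes below; the import path and constructor arguments are illustrative assumptions, not the exact API:

```python
# Illustrative only: module path and constructor arguments are assumptions.
from checkpoint_engine.ps import ParameterServer  # assumed import location

ps = ParameterServer()  # colocated with the inference engines

# Broadcast path: ranks omitted/empty -> every instance updates synchronously
# (_update_per_bucket with ranks == None or []).
ps.update(ranks=None)

# P2P path: only the newly added instances (e.g. the first two nodes, 16 GPUs)
# receive weights, so existing instances keep serving requests
# (_update_per_bucket with ranks specified).
ps.update(ranks=range(0, 16))
```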
### Optimized Weight Broadcast

In the *Broadcast* implementation, the checkpoint-engine holds references to sharded weights in CPU memory and needs to efficiently broadcast them to a cluster of inference instances, often under a different sharding pattern.
@@ -36,23 +36,30 @@ It then executes the transfer, where it controls the inference engine through a
Pipelining naturally requires more GPU memory. When memory is not enough, checkpoint-engine falls back to serial execution.
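As a rough illustration of this overlap (and not the project's actual implementation), a double-buffered loop can upload bucket i+1 from pinned host memory while bucket i is being broadcast. The function name, bucket layout, and the plain `torch.distributed.broadcast` call below are simplifying assumptions:

```python
# Sketch of a double-buffered H2D -> broadcast pipeline. Assumes torch.distributed
# is already initialized with the NCCL backend and CUDA is available.
import torch
import torch.distributed as dist


def pipelined_broadcast(bucket_sizes, cpu_buckets, src_rank=0):
    """bucket_sizes: bytes per bucket, known on every rank.
    cpu_buckets: pinned CPU uint8 tensors on the source rank, None elsewhere."""
    device = torch.device("cuda")
    copy_stream = torch.cuda.Stream()
    comm_stream = torch.cuda.Stream()
    # Two staging buffers: bucket i+1 uploads while bucket i broadcasts.
    buffers = [torch.empty(max(bucket_sizes), dtype=torch.uint8, device=device)
               for _ in range(2)]
    free_events = [torch.cuda.Event() for _ in range(2)]

    for i, nbytes in enumerate(bucket_sizes):
        buf, evt = buffers[i % 2], free_events[i % 2]
        if dist.get_rank() == src_rank:
            with torch.cuda.stream(copy_stream):
                copy_stream.wait_event(evt)  # buffer last used two iterations ago
                buf[:nbytes].copy_(cpu_buckets[i], non_blocking=True)
        with torch.cuda.stream(comm_stream):
            comm_stream.wait_stream(copy_stream)  # upload must finish first
            dist.broadcast(buf[:nbytes], src=src_rank)
            evt.record(comm_stream)  # buffer is reusable once the broadcast completes
        # A real engine would now hand the received bucket to the inference engine;
        # that step, and the fallback to serial execution, is omitted here.
    torch.cuda.synchronize()
```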
+ ### Optimized P2P Bucket Assignment
+ In the *P2P* implementation, checkpoint-engine needs to send weights from existing instances to new instances.
+ To minimize the overall transfer time, checkpoint-engine optimizes the bucket assignment for each sender-receiver pair.
+ The optimization goal is to make full use of the available network bandwidth for each sender and receiver.
+ See [issue #25](https://github.com/MoonshotAI/checkpoint-engine/issues/25).
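The linked issue discusses the assignment in detail. As a toy illustration of the objective only, not the heuristic this commit implements, a greedy balancer can spread buckets across senders so that no single sender link carries much more traffic than the others:

```python
# Toy illustration of the load-balancing objective, not the commit's actual logic.
from collections import defaultdict


def assign_buckets(buckets, senders):
    """buckets: list of (receiver_id, nbytes) to transfer; senders: candidate sender ids.
    Greedy heuristic: hand each bucket (largest first) to the least-loaded sender."""
    sender_load = defaultdict(int)  # bytes already scheduled per sender
    plan = []
    for receiver, nbytes in sorted(buckets, key=lambda b: b[1], reverse=True):
        sender = min(senders, key=lambda s: sender_load[s])
        sender_load[sender] += nbytes
        plan.append((sender, receiver, nbytes))
    # Transfer time is roughly max(sender_load) / per-link bandwidth, so balancing
    # per-sender byte counts keeps every link busy for about the same duration.
    return plan


# Example: four buckets for two new receivers, three existing senders.
print(assign_buckets([("r0", 8), ("r0", 6), ("r1", 5), ("r1", 3)], ["s0", "s1", "s2"]))
```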
## Benchmark

| Model | Device Info | GatherMetas | Update (Broadcast) | Update (P2P) |
| :----------------------------------- | :----------- | :---------- |:-------------------| :---------------------- |
- | GLM-4.5-Air (BF16) | 8xH800 TP8 | 0.17s | 3.94s (1.42GiB) | 8.83s (4.77GiB) |
- | Qwen3-235B-A22B-Instruct-2507 (BF16) | 8xH800 TP8 | 0.46s | 6.75s (2.69GiB) | 16.47s (4.05GiB) |
- | DeepSeek-V3.1 (FP8) | 16xH20 TP16 | 1.44s | 12.22s (2.38GiB) | 25.77s (3.61GiB) |
- | Kimi-K2-Instruct (FP8) | 16xH20 TP16 | 1.81s | 15.45s (2.93GiB) | 36.24s (4.46GiB) |
- | DeepSeek-V3.1 (FP8) | 256xH20 TP16 | 1.40s | 13.88s (2.54GiB) | 33.30s (3.86GiB) |
- | Kimi-K2-Instruct (FP8) | 256xH20 TP16 | 1.88s | 21.50s (2.99GiB) | 34.49s (4.57GiB) |
+ | GLM-4.5-Air (BF16) | 8xH800 TP8 | 0.12s | 3.47s (3.02GiB) | 4.12s (3.02GiB) |
+ | Qwen3-235B-A22B-Instruct-2507 (BF16) | 8xH800 TP8 | 0.33s | 6.22s (2.67GiB) | 7.10s (2.68GiB) |
+ | DeepSeek-V3.1 (FP8) | 16xH20 TP16 | 1.17s | 10.19s (5.39GiB) | 11.80s (5.41GiB) |
+ | Kimi-K2-Instruct (FP8) | 16xH20 TP16 | 1.33s | 14.36s (5.89GiB) | 17.49s (5.91GiB) |
+ | DeepSeek-V3.1 (FP8) | 256xH20 TP16 | 0.80s | 11.33s (8.00GiB) | 11.81s (8.00GiB) |
+ | Kimi-K2-Instruct (FP8) | 256xH20 TP16 | 1.22s | 16.04s (8.00GiB) | 16.75s (8.00GiB) |

All results above are tested by [`examples/update.py`](./examples/update.py) and use [vLLM v0.10.2rc1](https://github.com/vllm-project/vllm/tree/v0.10.2rc1) as the inference engine. Some notes:

* The FP8 test needs additional vLLM patches, see [FP8 quantization](#fp8-quantization).
* Device Info: we tested various combinations of devices and parallelism setups. For example, a 256-GPU TP16 setup means that we deploy 16 vLLM instances, each with 16-way tensor parallelism.
* Since update duration is related to the IPC bucket size, we provide the bucket size in the table.
* The P2P times were measured when updating no more than two nodes (16 GPUs) (`ParameterServer.update(ranks=range(0, 16))`) out of the entire cluster.
+ * We bind each GPU to its corresponding NUMA node to ensure stable H2D transfer speeds.

## Installation

@@ -68,7 +75,7 @@ Use the flexible P2P implementation, notice this will install `mooncake-transfer
pip install 'checkpoint-engine[p2p]'
```

- If the `NCCL_IB_HCA` env is set, checkpoint-engine will use it to auto-select net devices for different ranks. If not set, it will read all RDMA devices and try to divide them among the ranks.
+ If the `NCCL_IB_HCA` env is set, checkpoint-engine will use it to auto-select net devices for different ranks. Available patterns can be found in the [NCCL documentation](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#id8). If not set, it will read all RDMA devices and try to divide them among the ranks.
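For example, a launcher script might pin the HCA selection before starting the parameter server; the device names below are placeholders for whatever your cluster exposes, and the pattern syntax is the one described in the NCCL documentation linked above:

```python
# Hypothetical launcher snippet: the HCA names are placeholders for your cluster.
import os

# Comma-separated device prefixes, per NCCL_IB_HCA conventions; checkpoint-engine
# will divide the selected devices across ranks. Leave the variable unset to let
# it probe all RDMA devices instead.
os.environ.setdefault("NCCL_IB_HCA", "mlx5_0,mlx5_1")
```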
## Getting Started

@@ -141,11 +148,11 @@ Run a simple correctness test for checkpoint_engine
torchrun --nproc-per-node 8 tests/test_update.py
```

+ Other unit tests can be run with `pytest`.

## Limitations and Future Work

- This project is currently only tested with vLLM, but it is easy to integrate with other frameworks like SGLang.
- The perfect three-stage pipeline mentioned in our paper is currently not implemented. This could be useful for architectures where H2D and broadcast do not conflict on PCIe.
- - The P2P update method is currently not the optimal implementation, since it receives data only on rank 0 and then broadcasts it to the other ranks synchronously. This is a potential future optimization.

## Acknowledgments
