You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: README.md
+17-10Lines changed: 17 additions & 10 deletions
Display the source diff
Display the rich diff
Original file line number
Diff line number
Diff line change
@@ -14,8 +14,8 @@ updating our [Kimi-K2](https://github.com/MoonshotAI/Kimi-K2) model (1 Trillion
14
14
15
15
The core weight update logic is in `ParameterServer` class, a service colocated with inference engines. It provides two implementations of weight update: Broadcast and P2P.
16
16
17
-
-**Broadcast**: Used when a large number of inference instances need to update weights in synchronous. This is the fastest implementation and should be used as the default update method. See `_update_per_bucket`.
18
-
-**P2P**: Used when new inference instances are dynamically added (due to restarts or dynamic availability) while the existing instances are already serving requests. Under this scenario, to avoid affecting the workloads on existing instances, we use the [`mooncake-transfer-engine`](https://github.com/kvcache-ai/Mooncake?tab=readme-ov-file#use-python-package) to P2P send weights from CPUs in existing instances to GPUs in new instances. See `_update_per_bucket_p2p`.
17
+
-**Broadcast**: Used when a large number of inference instances need to update weights in synchronous. This is the fastest implementation and should be used as the default update method. See `_update_per_bucket` with `ranks == None or []`.
18
+
-**P2P**: Used when new inference instances are dynamically added (due to restarts or dynamic availability) while the existing instances are already serving requests. Under this scenario, to avoid affecting the workloads on existing instances, we use the [`mooncake-transfer-engine`](https://github.com/kvcache-ai/Mooncake?tab=readme-ov-file#use-python-package) to P2P send weights from CPUs in existing instances to GPUs in new instances. See `_update_per_bucket` with `ranks` specified.
19
19
20
20
### Optimized Weight Broadcast
21
21
In the *Broadcast* implementation, the checkpoint-engine holds references to sharded weights in CPU memory, and need to efficiently broadcast them to a cluster of inference instances, often under a different sharding pattern.
@@ -36,23 +36,30 @@ It then executes the transfer, where it controls the inference engine through a
36
36
37
37
Pipelining naturally requires more GPU memory. When memory is not enough, checkpoint-engine will fallback to serial execution.
38
38
39
+
### Optimized P2P Bucket Assignment
40
+
In the *P2P* implementation, checkpoint-engine needs to send weights from existing instances to new instances.
41
+
To minimize the overall transfer time, checkpoint-engine optimizes the bucket assignment for each sender-receiver pair.
42
+
The optimization goal is to make full use of the available network bandwidth for each sender and receiver.
43
+
See [issue #25](https://github.com/MoonshotAI/checkpoint-engine/issues/25)
44
+
39
45
## Benchmark
40
46
41
47
| Model | Device Info | GatherMetas | Update (Broadcast) | Update (P2P) |
All results above are tested by [`examples/update.py`](./examples/update.py) and use [vLLM v0.10.2rc1](https://github.com/vllm-project/vllm/tree/v0.10.2rc1) as inference engine. Some notes:
51
57
52
58
* FP8 test needs additional vLLM patches, see [FP8 quantization](#fp8-quantization).
53
59
* Device Info: we tested various combination of devices and parallelism setups. For example, a 256-GPU TP16 setup means that we deploy 16 vLLM instances, each with 16-way tensor parallelism.
54
60
* Since update duration is related to IPC bucket size, we provide the bucket size in the table.
55
61
* The P2P time were tested for updating no more than two nodes (16 GPUs) (`ParameterServer.update(ranks=range(0, 16))`) out of the entire cluster.
62
+
* We bind each GPU to its corresponding NUMA node to ensure stable H2D transfer speeds.
56
63
57
64
## Installation
58
65
@@ -68,7 +75,7 @@ Use the flexible P2P implementation, notice this will install `mooncake-transfer
68
75
pip install 'checkpoint-engine[p2p]'
69
76
```
70
77
71
-
If set `NCCL_IB_HCA` env, checkpoint-engine will use it to auto select net devices for different ranks. If not set, it will read all RDMA devices and try to divide them into each rank.
78
+
If set `NCCL_IB_HCA` env, checkpoint-engine will use it to auto select net devices for different ranks. Available patterns can be found from [NCCL documentation](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#id8). If not set, it will read all RDMA devices and try to divide them into each rank.
72
79
73
80
## Getting Started
74
81
@@ -141,11 +148,11 @@ Run a simple correctness test for checkpoint_engine
141
148
torchrun --nproc-per-node 8 tests/test_update.py
142
149
```
143
150
151
+
Other unit tests can be done with pytest.
144
152
## Limitations and Future Work
145
153
146
154
- This project is currently only tested with vLLM. But it is easy to integrate with other frameworks like SGLang.
147
155
- The perfect three-stage pipeline mentioned in our paper is currently not implemented. This could be useful for architectures where H2D and broadcast do not conflict in PCIE.
148
-
- The P2P update method is currently not the optimal implementation since it will receive data only in rank 0 and broadcast to others synchronizely. This is a potential optimization in the future.
0 commit comments