Skip to content

Commit 322f684

Browse files
blahgeekspecture724Yikai Zhao
authored
misc: split ps.py file into multiple files [resolved conflict] (#75)
* feat: split p2p store and functions related to RDMA devices * feat: split data types into a separate file * feat: split pin_memory.py from ps.py * feat: split api.py from ps.py * feat: split __main__.py from ps.py * feat: split files into multiple modules * feat: add entrypoint in ps.py for compatibility * fix: request to register error --------- Co-authored-by: specture724 <[email protected]> Co-authored-by: Yikai Zhao <[email protected]>
1 parent 4fc8d6f commit 322f684

File tree

8 files changed

+894
-796
lines changed

8 files changed

+894
-796
lines changed

checkpoint_engine/__init__.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,39 @@
22
from ._version import __version__
33
except ImportError:
44
__version__ = "dev"
5+
6+
# Re-export the public API from the submodules that ps.py was split into,
# so `from checkpoint_engine import ...` keeps working after the refactor.
from .api import request_inference_to_update
from .data_types import (
    BucketRange,
    DataToGather,
    H2DBucket,
    MemoryBuffer,
    MemoryBufferMetaList,
    MemoryBufferMetas,
    ParameterMeta,
)
from .device_utils import DeviceManager, get_ip, npu_generate_uuid
from .p2p_store import P2PStore
from .ps import ParameterServer
from .worker import FlattenedTensorMetadata, VllmColocateWorkerExtension, update_weights_from_ipc


# Explicit public API (kept alphabetically sorted).
__all__ = [
    "BucketRange",
    "DataToGather",
    "DeviceManager",
    "FlattenedTensorMetadata",
    "H2DBucket",
    "MemoryBuffer",
    "MemoryBufferMetaList",
    "MemoryBufferMetas",
    "P2PStore",
    "ParameterMeta",
    "ParameterServer",
    "VllmColocateWorkerExtension",
    "__version__",
    "get_ip",
    "npu_generate_uuid",
    "request_inference_to_update",
    "update_weights_from_ipc",
]

checkpoint_engine/__main__.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import argparse
2+
import os
3+
4+
from loguru import logger
5+
6+
from checkpoint_engine.api import _init_api
7+
from checkpoint_engine.ps import ParameterServer
8+
9+
10+
@logger.catch(reraise=True)
def run_from_cli():
    """CLI entrypoint: parse ``--uds`` and serve the parameter-server API over it.

    Raises (via assert) when no Unix-domain-socket path is given.
    """
    # Imported lazily so importing this module does not require uvicorn.
    import uvicorn

    cli = argparse.ArgumentParser(description="Parameter Server")
    cli.add_argument("--uds", type=str)
    # NOTE: the variable must stay named `args` -- the `{args=}` f-string below
    # embeds the variable name into the log line.
    args = cli.parse_args()

    logger.info(
        f"Parameter Server {args=}, master addr: {os.getenv('MASTER_ADDR')}, master port {os.getenv('MASTER_PORT')}"
    )
    # A Unix domain socket path is mandatory for serving.
    assert args.uds and len(args.uds) > 0, args.uds

    server = ParameterServer(auto_pg=True)
    uvicorn.run(_init_api(server), uds=args.uds, timeout_keep_alive=60)


if __name__ == "__main__":
    run_from_cli()

checkpoint_engine/api.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
from collections.abc import Callable
2+
from typing import Any
3+
4+
import fastapi
5+
import httpx
6+
from fastapi import Request
7+
from fastapi.responses import JSONResponse, Response
8+
from loguru import logger
9+
from pydantic import BaseModel
10+
11+
from checkpoint_engine.ps import ParameterServer
12+
13+
14+
def request_inference_to_update(
    url: str,
    socket_paths: dict[str, str],
    timeout: float = 300.0,
    uds: str | None = None,
):
    """Send an inference update request to inference server via HTTP or Unix socket.

    Args:
        url (str): The HTTP URL or request path (e.g., "http://localhost:19730/inference") to send the request to.
        socket_paths (dict[str, str]): A dictionary containing device uuid and IPC socket paths for updating weights.
        timeout (float, optional): Request timeout in seconds. Defaults to 300.0.
        uds (str, optional): Path to a Unix domain socket. If provided, the request
            will be sent via the Unix socket instead of HTTP. Defaults to None.

    Raises:
        httpx.HTTPStatusError: If the response contains an HTTP error status.
        httpx.RequestError: If there was an issue while making the request.
    """
    # Use the client as a context manager so its connection pool is closed
    # deterministically; the previous code created an httpx.Client and never
    # closed it, leaking the underlying socket until garbage collection.
    with httpx.Client(transport=httpx.HTTPTransport(uds=uds)) as client:
        resp = client.post(
            url,
            json={
                "method": "update_weights_from_ipc",
                "args": [socket_paths],
                "timeout": timeout,
            },
            timeout=timeout,
        )
        resp.raise_for_status()
43+
44+
45+
def _init_api(ps: ParameterServer) -> Any:
    """Build the FastAPI app exposing checkpoint operations on *ps*.

    Endpoints:
        POST   /v1/checkpoints/{name}/files         -- register checkpoint files
        DELETE /v1/checkpoints/{name}               -- unregister a checkpoint
        GET    /v1/healthz                          -- liveness probe
        POST   /v1/checkpoints/{name}/gather-metas  -- gather parameter metas
        POST   /v1/checkpoints/{name}/update        -- push weights to inference

    Returns:
        The configured ``fastapi.FastAPI`` application.
    """
    app = fastapi.FastAPI()

    class RegisterRequest(BaseModel):
        # Checkpoint file paths to register.
        files: list[str]

    class UpdateRequest(BaseModel):
        # Parameter-server ranks taking part in the update; empty means all.
        ranks: list[int] = []
        # Inference-server URL to notify; if None the update is local only.
        update_url: str | None = None
        # Subset of socket paths (by index) forwarded to the inference group.
        inference_group_ranks: list[int] = []
        timeout: float = 300.0
        # Optional Unix domain socket for reaching the inference server.
        uds: str | None = None

    def wrap_exception(func: Callable[[], None]) -> Response:
        # Run *func*, mapping success to HTTP 200 and any exception to a
        # logged HTTP 500 whose body is the stringified error.
        try:
            func()
        except Exception as e:  # noqa: BLE001
            logger.exception(f"wrap exception {func} failed")
            return JSONResponse(content=str(e), status_code=500)
        return Response(status_code=200)

    # NOTE: the unused `raw: Request` parameter was removed from register_files;
    # FastAPI injected it but the handler never read it.
    @app.post("/v1/checkpoints/{checkpoint_name}/files")
    async def register_files(checkpoint_name: str, req: RegisterRequest) -> Response:
        return wrap_exception(lambda: ps.register_checkpoint(checkpoint_name, files=req.files))

    @app.delete("/v1/checkpoints/{checkpoint_name}")
    async def unregister_checkpoint(checkpoint_name: str) -> Response:
        return wrap_exception(lambda: ps.unregister_checkpoint(checkpoint_name))

    @app.get("/v1/healthz")
    async def healthz() -> Response:
        return Response(status_code=200)

    @app.post("/v1/checkpoints/{checkpoint_name}/gather-metas")
    async def gather_metas(checkpoint_name: str) -> Response:
        return wrap_exception(lambda: ps.gather_metas(checkpoint_name))

    @app.post("/v1/checkpoints/{checkpoint_name}/update")
    async def update(checkpoint_name: str, req: UpdateRequest) -> Response:
        def update_func(socket_paths: list[tuple[str, str]]):
            # Called by ps.update with (device_uuid, socket_path) pairs; forwards
            # them to the inference server when an update_url was supplied.
            if req.update_url is None:
                return
            if req.inference_group_ranks:
                socket_paths = [socket_paths[i] for i in req.inference_group_ranks]
            request_inference_to_update(
                req.update_url, dict(socket_paths), timeout=req.timeout, uds=req.uds
            )

        return wrap_exception(lambda: ps.update(checkpoint_name, update_func, ranks=req.ranks))

    return app

checkpoint_engine/data_types.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
from typing import TYPE_CHECKING, Annotated, Any, NamedTuple
2+
3+
import torch
4+
from pydantic import BaseModel, PlainSerializer, PlainValidator, WithJsonSchema
5+
6+
7+
if TYPE_CHECKING:
    # Static-analysis-only names: these imports and definitions never exist at
    # runtime, so they may not be referenced outside annotations.
    from typing import TypeVar

    from typing_extensions import TypedDict

    # Shape of one parameter entry as read from a checkpoint file.
    class FileMeta(TypedDict):
        key: str  # parameter name
        dtype: torch.dtype
        shape: torch.Size
        type: type
        tp_concat_dim: int  # presumably the tensor-parallel concat dimension -- confirm with users

    T = TypeVar("T")
20+
21+
22+
def _dt_validate(value: Any) -> torch.dtype:
23+
if isinstance(value, str):
24+
if not value.startswith("torch."):
25+
raise ValueError(f"dtype {value} should start with torch.")
26+
try:
27+
value = getattr(torch, value.split(".")[1])
28+
except AttributeError as e:
29+
raise ValueError(f"unknown dtype: {value}") from e
30+
if not isinstance(value, torch.dtype):
31+
raise TypeError(f"dtype {value} should be torch.dtype, got {type(value)}")
32+
return value
33+
34+
35+
# Pydantic-compatible torch.dtype: validates dtype objects or "torch.<name>"
# strings, and serializes back to the "torch.<name>" string form (declared as
# a plain string in the JSON schema).
_TorchDtype = Annotated[
    torch.dtype,
    PlainValidator(_dt_validate),
    PlainSerializer(lambda x: str(x), return_type=str),
    WithJsonSchema({"type": "string"}, mode="serialization"),
]
41+
42+
43+
def _size_validate(value: Any) -> torch.Size:
44+
if isinstance(value, list | tuple):
45+
return torch.Size(value)
46+
if not isinstance(value, torch.Size):
47+
raise TypeError(f"size {value} should be torch.Size, got {type(value)}")
48+
return value
49+
50+
51+
# Pydantic-compatible torch.Size: accepts lists/tuples of ints or Size objects,
# and serializes as a plain tuple (JSON array of integers).
_TorchSize = Annotated[
    torch.Size,
    PlainValidator(_size_validate),
    PlainSerializer(lambda x: tuple(x), return_type=tuple),
    WithJsonSchema({"type": "array", "items": {"type": "integer"}}, mode="serialization"),
]
57+
58+
59+
def _tensor_validate(value: Any) -> torch.Tensor:
60+
if isinstance(value, torch.Tensor):
61+
return value
62+
raise TypeError(f"tensor {value} should be torch.Tensor, got {type(value)}")
63+
64+
65+
# Pydantic-compatible torch.Tensor: pass-through validation only. No serializer
# is attached -- tensors are not meant to be JSON-serialized.
_TorchTensor = Annotated[
    torch.Tensor,
    PlainValidator(_tensor_validate),
]
69+
70+
71+
class ParameterMeta(BaseModel):
    """Serializable metadata describing a single model parameter."""

    name: str  # parameter name
    dtype: _TorchDtype  # element dtype, serialized as "torch.<name>"
    shape: _TorchSize  # tensor shape, serialized as a tuple of ints
    aligned_size: int  # aligned allocation size -- presumably bytes; confirm at the producer
76+
77+
78+
class BucketRange(NamedTuple):
    """A contiguous sub-range within one bucket of the memory pool."""

    idx: int  # bucket_idx of MemoryBucket in memory_pool
    offset: int  # start offset within that bucket -- presumably bytes; verify against users
    size: int  # length of the range, same unit as offset
82+
83+
84+
class H2DBucket(BaseModel):
    """One host-to-device transfer unit: source ranges plus the parameters they carry."""

    size: int  # total size of the bucket
    ranges: list[BucketRange]  # source ranges in the host memory pool
    items: list[ParameterMeta]  # parameters packed into this bucket
88+
89+
90+
class MemoryBufferMetas(BaseModel):
    """Parameter metadata plus the raw address/size of one backing buffer."""

    metas: list[ParameterMeta]  # parameters stored in the buffer
    ptr: int  # base address of the buffer -- presumably tensor.data_ptr(); confirm
    size: int  # buffer size
94+
95+
96+
class MemoryBuffer(BaseModel):
    """A host-side buffer tensor together with the parameters stored in it."""

    buffer: _TorchTensor  # backing tensor (validated to be a torch.Tensor)
    size: int  # buffer size
    metas: list[ParameterMeta]  # parameters laid out in the buffer
    manually_pinned: bool = False  # NOTE(review): presumably set when pinned via pin_memory.py -- confirm
101+
102+
103+
class MemoryBufferMetaList(BaseModel):
    """Per-worker collection of buffer metadata, exchanged between ranks."""

    p2p_store_addr: str | None  # address of the worker's P2PStore, or None if disabled
    memory_buffer_metas_list: list[MemoryBufferMetas]  # one entry per memory buffer
    rdma_device: str  # RDMA device name of the worker
107+
108+
109+
class DataToGather(MemoryBufferMetaList):
    """MemoryBufferMetaList extended with host identity, for gathering across ranks."""

    host_ip: str  # IP address of the owning host
    device_uuid: str  # UUID of the owning device

0 commit comments

Comments
 (0)