36 changes: 36 additions & 0 deletions checkpoint_engine/__init__.py
@@ -2,3 +2,39 @@
    from ._version import __version__
except ImportError:
    __version__ = "dev"

from .api import request_inference_to_update
from .data_types import (
    BucketRange,
    DataToGather,
    H2DBucket,
    MemoryBuffer,
    MemoryBufferMetaList,
    MemoryBufferMetas,
    ParameterMeta,
)
from .device_utils import DeviceManager, get_ip, npu_generate_uuid
from .p2p_store import P2PStore
from .ps import ParameterServer
from .worker import FlattenedTensorMetadata, VllmColocateWorkerExtension, update_weights_from_ipc


__all__ = [
    "BucketRange",
    "DataToGather",
    "DeviceManager",
    "FlattenedTensorMetadata",
    "H2DBucket",
    "MemoryBuffer",
    "MemoryBufferMetaList",
    "MemoryBufferMetas",
    "P2PStore",
    "ParameterMeta",
    "ParameterServer",
    "VllmColocateWorkerExtension",
    "__version__",
    "get_ip",
    "npu_generate_uuid",
    "request_inference_to_update",
    "update_weights_from_ipc",
]
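With these exports, downstream code can drive the parameter server without touching submodules. A minimal consumer sketch, assuming a single-node setup; the checkpoint name and file path are placeholders:

# Placeholder checkpoint name and file path; auto_pg=True mirrors __main__.py below.
from checkpoint_engine import ParameterServer

ps = ParameterServer(auto_pg=True)
ps.register_checkpoint("demo-ckpt", files=["/data/model-00001.safetensors"])
ps.gather_metas("demo-ckpt")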
28 changes: 28 additions & 0 deletions checkpoint_engine/__main__.py
@@ -0,0 +1,28 @@
import argparse
import os

from loguru import logger

from checkpoint_engine.api import _init_api
from checkpoint_engine.ps import ParameterServer


@logger.catch(reraise=True)
def run_from_cli():
    import uvicorn

    parser = argparse.ArgumentParser(description="Parameter Server")
    parser.add_argument("--uds", type=str)

    args = parser.parse_args()
    logger.info(
        f"Parameter Server {args=}, master addr: {os.getenv('MASTER_ADDR')}, master port: {os.getenv('MASTER_PORT')}"
    )

    assert args.uds and len(args.uds) > 0, args.uds
    ps = ParameterServer(auto_pg=True)
    uvicorn.run(_init_api(ps), uds=args.uds, timeout_keep_alive=60)


if __name__ == "__main__":
    run_from_cli()
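For reference, a hedged launch sketch; the MASTER_ADDR/MASTER_PORT values and the socket path are assumptions, not defaults of this module:

# Hypothetical launcher; rendezvous env vars and socket path are assumptions.
import os
import subprocess

env = dict(os.environ, MASTER_ADDR="127.0.0.1", MASTER_PORT="29500")
subprocess.run(
    ["python", "-m", "checkpoint_engine", "--uds", "/tmp/checkpoint_engine.sock"],
    env=env,
    check=True,
)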
95 changes: 95 additions & 0 deletions checkpoint_engine/api.py
@@ -0,0 +1,95 @@
from collections.abc import Callable
from typing import Any

import fastapi
import httpx
from fastapi import Request
from fastapi.responses import JSONResponse, Response
from loguru import logger
from pydantic import BaseModel

from checkpoint_engine.ps import ParameterServer


def request_inference_to_update(
    url: str,
    socket_paths: dict[str, str],
    timeout: float = 300.0,
    uds: str | None = None,
) -> None:
    """Send a weight-update request to the inference server via HTTP or a Unix domain socket.

    Args:
        url (str): The HTTP URL or request path (e.g., "http://localhost:19730/inference") to send the request to.
        socket_paths (dict[str, str]): A mapping from device UUID to the IPC socket path used for updating weights.
        timeout (float, optional): Request timeout in seconds. Defaults to 300.0.
        uds (str, optional): Path to a Unix domain socket. If provided, the request
            is sent over the Unix socket instead of TCP. Defaults to None.

    Raises:
        httpx.HTTPStatusError: If the response contains an HTTP error status.
        httpx.RequestError: If there was an issue while making the request.
    """
    with httpx.Client(transport=httpx.HTTPTransport(uds=uds)) as client:
        resp = client.post(
            url,
            json={
                "method": "update_weights_from_ipc",
                "args": [socket_paths],
                "timeout": timeout,
            },
            timeout=timeout,
        )
        resp.raise_for_status()
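A hedged call sketch; the device UUID and IPC socket path are placeholders, and the URL reuses the docstring's example:

# Placeholder device uuid and socket path; URL taken from the docstring example above.
request_inference_to_update(
    "http://localhost:19730/inference",
    {"GPU-0a1b2c3d": "/tmp/ipc_rank0.sock"},  # device uuid -> IPC socket path
    timeout=120.0,
)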


def _init_api(ps: ParameterServer) -> Any:
    app = fastapi.FastAPI()

    class RegisterRequest(BaseModel):
        files: list[str]

    class UpdateRequest(BaseModel):
        ranks: list[int] = []
        update_url: str | None = None
        inference_group_ranks: list[int] = []
        timeout: float = 300.0
        uds: str | None = None

    def wrap_exception(func: Callable[[], None]) -> Response:
        try:
            func()
        except Exception as e:  # noqa: BLE001
            logger.exception(f"wrapped call {func} failed")
            return JSONResponse(content=str(e), status_code=500)
        return Response(status_code=200)

    @app.post("/v1/checkpoints/{checkpoint_name}/files")
    async def register_files(checkpoint_name: str, req: RegisterRequest, raw: Request) -> Response:
        return wrap_exception(lambda: ps.register_checkpoint(checkpoint_name, files=req.files))

    @app.delete("/v1/checkpoints/{checkpoint_name}")
    async def unregister_checkpoint(checkpoint_name: str) -> Response:
        return wrap_exception(lambda: ps.unregister_checkpoint(checkpoint_name))

    @app.get("/v1/healthz")
    async def healthz() -> Response:
        return Response(status_code=200)

    @app.post("/v1/checkpoints/{checkpoint_name}/gather-metas")
    async def gather_metas(checkpoint_name: str) -> Response:
        return wrap_exception(lambda: ps.gather_metas(checkpoint_name))

    @app.post("/v1/checkpoints/{checkpoint_name}/update")
    async def update(checkpoint_name: str, req: UpdateRequest) -> Response:
        def update_func(socket_paths: list[tuple[str, str]]) -> None:
            if req.update_url is None:
                return
            if req.inference_group_ranks:
                socket_paths = [socket_paths[i] for i in req.inference_group_ranks]
            request_inference_to_update(
                req.update_url, dict(socket_paths), timeout=req.timeout, uds=req.uds
            )

        return wrap_exception(lambda: ps.update(checkpoint_name, update_func, ranks=req.ranks))

    return app
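The endpoint flow end to end, as a hedged client sketch; the UDS path, checkpoint name, and file path are placeholders:

# Placeholder UDS/checkpoint/file values; route paths come from _init_api above.
import httpx

transport = httpx.HTTPTransport(uds="/tmp/checkpoint_engine.sock")
with httpx.Client(transport=transport, base_url="http://localhost") as client:
    client.post(
        "/v1/checkpoints/demo/files",
        json={"files": ["/data/model-00001.safetensors"]},
    ).raise_for_status()
    client.post("/v1/checkpoints/demo/gather-metas").raise_for_status()
    client.post(
        "/v1/checkpoints/demo/update",
        json={"update_url": "http://localhost:19730/inference", "timeout": 300.0},
    ).raise_for_status()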
111 changes: 111 additions & 0 deletions checkpoint_engine/data_types.py
@@ -0,0 +1,111 @@
from typing import TYPE_CHECKING, Annotated, Any, NamedTuple

import torch
from pydantic import BaseModel, PlainSerializer, PlainValidator, WithJsonSchema


if TYPE_CHECKING:
    from typing import TypeVar

    from typing_extensions import TypedDict

    class FileMeta(TypedDict):
        key: str  # parameter name
        dtype: torch.dtype
        shape: torch.Size
        type: type
        tp_concat_dim: int

    T = TypeVar("T")


def _dt_validate(value: Any) -> torch.dtype:
    if isinstance(value, str):
        if not value.startswith("torch."):
            raise ValueError(f"dtype {value} should start with torch.")
        try:
            value = getattr(torch, value.split(".")[1])
        except AttributeError as e:
            raise ValueError(f"unknown dtype: {value}") from e
    if not isinstance(value, torch.dtype):
        raise TypeError(f"dtype {value} should be torch.dtype, got {type(value)}")
    return value


_TorchDtype = Annotated[
    torch.dtype,
    PlainValidator(_dt_validate),
    PlainSerializer(lambda x: str(x), return_type=str),
    WithJsonSchema({"type": "string"}, mode="serialization"),
]


def _size_validate(value: Any) -> torch.Size:
    if isinstance(value, list | tuple):
        return torch.Size(value)
    if not isinstance(value, torch.Size):
        raise TypeError(f"size {value} should be torch.Size, got {type(value)}")
    return value


_TorchSize = Annotated[
    torch.Size,
    PlainValidator(_size_validate),
    PlainSerializer(lambda x: tuple(x), return_type=tuple),
    WithJsonSchema({"type": "array", "items": {"type": "integer"}}, mode="serialization"),
]


def _tensor_validate(value: Any) -> torch.Tensor:
    if isinstance(value, torch.Tensor):
        return value
    raise TypeError(f"tensor {value} should be torch.Tensor, got {type(value)}")


_TorchTensor = Annotated[
    torch.Tensor,
    PlainValidator(_tensor_validate),
]


class ParameterMeta(BaseModel):
    name: str
    dtype: _TorchDtype
    shape: _TorchSize
    aligned_size: int
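A round-trip sketch for the annotated torch fields above; the values are arbitrary:

# "torch.float16" is validated back into a real torch.dtype; the list becomes a torch.Size.
meta = ParameterMeta(name="w", dtype="torch.float16", shape=[1024, 1024], aligned_size=1 << 21)
assert meta.dtype is torch.float16 and meta.shape == torch.Size((1024, 1024))
print(meta.model_dump_json())  # {"name":"w","dtype":"torch.float16","shape":[1024,1024],...}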


class BucketRange(NamedTuple):
    idx: int  # bucket_idx of MemoryBucket in memory_pool
    offset: int
    size: int


class H2DBucket(BaseModel):
    size: int
    ranges: list[BucketRange]
    items: list[ParameterMeta]


class MemoryBufferMetas(BaseModel):
    metas: list[ParameterMeta]
    ptr: int
    size: int


class MemoryBuffer(BaseModel):
    buffer: _TorchTensor
    size: int
    metas: list[ParameterMeta]
    manually_pinned: bool = False


class MemoryBufferMetaList(BaseModel):
    p2p_store_addr: str | None
    memory_buffer_metas_list: list[MemoryBufferMetas]
    rdma_device: str


class DataToGather(MemoryBufferMetaList):
    host_ip: str
    device_uuid: str
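A hedged composition sketch tying these models together; all concrete values are placeholders:

# Placeholder values; a single range standing in for one slot of the memory pool.
meta = ParameterMeta(
    name="model.layers.0.mlp.gate.weight",
    dtype=torch.bfloat16,
    shape=(4096, 4096),
    aligned_size=4096 * 4096 * 2,
)
bucket = H2DBucket(
    size=meta.aligned_size,
    ranges=[BucketRange(idx=0, offset=0, size=meta.aligned_size)],
    items=[meta],
)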