Skip to content

Commit dd05eae

Browse files
committed
feat: split files into multiple modules
1 parent 77e2d61 commit dd05eae

File tree

2 files changed

+37
-2
lines changed

2 files changed

+37
-2
lines changed

checkpoint_engine/__init__.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,39 @@
22
from ._version import __version__
33
except ImportError:
44
__version__ = "dev"
5+
6+
from .api import request_inference_to_update
7+
from .data_types import (
8+
BucketRange,
9+
DataToGather,
10+
H2DBucket,
11+
MemoryBuffer,
12+
MemoryBufferMetaList,
13+
MemoryBufferMetas,
14+
ParameterMeta,
15+
)
16+
from .device_utils import DeviceManager, get_ip, npu_generate_uuid
17+
from .p2p_store import P2PStore
18+
from .ps import ParameterServer
19+
from .worker import FlattenedTensorMetadata, VllmColocateWorkerExtension, update_weights_from_ipc
20+
21+
22+
__all__ = [
23+
"BucketRange",
24+
"DataToGather",
25+
"DeviceManager",
26+
"FlattenedTensorMetadata",
27+
"H2DBucket",
28+
"MemoryBuffer",
29+
"MemoryBufferMetaList",
30+
"MemoryBufferMetas",
31+
"P2PStore",
32+
"ParameterMeta",
33+
"ParameterServer",
34+
"VllmColocateWorkerExtension",
35+
"__version__",
36+
"get_ip",
37+
"npu_generate_uuid",
38+
"request_inference_to_update",
39+
"update_weights_from_ipc",
40+
]

checkpoint_engine/ps.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from loguru import logger
1313
from torch.multiprocessing.reductions import reduce_tensor
1414

15-
from checkpoint_engine.api import _init_api
1615
from checkpoint_engine.data_types import (
1716
BucketRange,
1817
DataToGather,
@@ -22,7 +21,6 @@
2221
MemoryBufferMetas,
2322
ParameterMeta,
2423
)
25-
2624
from checkpoint_engine.device_utils import DeviceManager, get_ip, npu_generate_uuid
2725
from checkpoint_engine.p2p_store import P2PStore
2826
from checkpoint_engine.pin_memory import _ALIGN_SIZE, _register_checkpoint
@@ -31,6 +29,7 @@
3129
if TYPE_CHECKING:
3230
from checkpoint_engine.data_types import T
3331

32+
3433
def _to_named_tensor(metas: list[ParameterMeta], offset: int = 0) -> list[dict]:
3534
ret = []
3635
for meta in metas:

0 commit comments

Comments
 (0)