Commit 8a60e65

feat: support configurable gpu count and memory fraction (#29)
Signed-off-by: Cruz Zhao <[email protected]>
1 parent 490c222 commit 8a60e65

File tree

1 file changed: +20 -4 lines changed

checkpoint_engine/ps.py

Lines changed: 20 additions & 4 deletions
@@ -592,25 +592,41 @@ def batch_transfer_sync_read(
 
 class ParameterServer:
     def __init__(
-        self, *, rank: int | None = None, world_size: int | None = None, auto_pg: bool = False
+        self,
+        *,
+        rank: int | None = None,
+        world_size: int | None = None,
+        auto_pg: bool = False,
+        gpu_count: int | None = None,
+        mem_fraction: float | None = None,
     ):
         """
         Initialize the parameter server. env RANK, WORLD_SIZE and MASTER_ADDR must be set.
 
         Args:
             auto_pg: Whether to automatically initialize the process group.
                 Notice that if auto_pg is True, will destroy the process group after update.
+            mem_fraction: The proportion (as a fraction) of the current free CUDA memory for allocation.
         """
         self._rank = rank or int(os.environ.get("RANK", None))
         self._world_size = world_size or int(os.environ.get("WORLD_SIZE", None))
-        self._gpu_count = torch.cuda.device_count()
+        self._gpu_count = gpu_count or torch.cuda.device_count()
         self._local_rank = self._rank % self._gpu_count
         self._auto_pg = auto_pg
         self._all_hosts = []
         self._global_device_uuids: list[str] = []
+        self._mem_fraction = mem_fraction or 0.9
 
         assert self._rank is not None and self._rank >= 0, self._rank
         assert self._world_size and self._world_size > 0, self._world_size
+        assert (
+            self._gpu_count is not None
+            and self._gpu_count > 0
+            and self._gpu_count <= torch.cuda.device_count()
+        ), self._gpu_count
+        assert (
+            self._mem_fraction is not None and self._mem_fraction > 0 and self._mem_fraction <= 1
+        ), self._mem_fraction
 
         self._zmq_ctx = zmq.Context()
         self._zmq_addr_counter = 0
@@ -834,8 +850,8 @@ def _detect_bucket_size(self, *, disable_h2d_buffer: bool = False) -> tuple[int,
         # auto detect bucket size
         tensor = torch.tensor(
             [
-                # 90% of current cuda free memory bytes
-                int(float(torch.cuda.mem_get_info()[0]) * 0.9),
+                # proportion of current cuda free memory bytes
+                int(float(torch.cuda.mem_get_info()[0]) * self._mem_fraction),
                 # we use negative value to reuse allreduce min operation
                 # for getting the max value of zmq_addr_counter in all ranks
                 -self._zmq_addr_counter,
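
The second hunk threads self._mem_fraction into the auto-detected bucket budget, replacing the hard-coded 0.9. A quick sketch of that expression, with hypothetical numbers:

import torch

# mem_get_info() returns (free_bytes, total_bytes) for the current device.
free_bytes, _total = torch.cuda.mem_get_info()
# With e.g. ~40 GiB free and mem_fraction=0.8, the budget is ~32 GiB.
budget = int(float(free_bytes) * 0.8)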

0 commit comments
