@@ -592,25 +592,41 @@ def batch_transfer_sync_read(

 class ParameterServer:
     def __init__(
-        self, *, rank: int | None = None, world_size: int | None = None, auto_pg: bool = False
+        self,
+        *,
+        rank: int | None = None,
+        world_size: int | None = None,
+        auto_pg: bool = False,
+        gpu_count: int | None = None,
+        mem_fraction: float | None = None,
     ):
         """
         Initialize the parameter server. env RANK, WORLD_SIZE and MASTER_ADDR must be set.

         Args:
             auto_pg: Whether to automatically initialize the process group.
                 Notice that if auto_pg is True, will destroy the process group after update.
+            mem_fraction: The proportion (as a fraction) of the current free CUDA memory for allocation.
         """
         self._rank = rank or int(os.environ.get("RANK", None))
         self._world_size = world_size or int(os.environ.get("WORLD_SIZE", None))
-        self._gpu_count = torch.cuda.device_count()
+        self._gpu_count = gpu_count or torch.cuda.device_count()
         self._local_rank = self._rank % self._gpu_count
         self._auto_pg = auto_pg
         self._all_hosts = []
         self._global_device_uuids: list[str] = []
+        self._mem_fraction = mem_fraction or 0.9

         assert self._rank is not None and self._rank >= 0, self._rank
         assert self._world_size and self._world_size > 0, self._world_size
+        assert (
+            self._gpu_count is not None
+            and self._gpu_count > 0
+            and self._gpu_count <= torch.cuda.device_count()
+        ), self._gpu_count
+        assert (
+            self._mem_fraction is not None and self._mem_fraction > 0 and self._mem_fraction <= 1
+        ), self._mem_fraction

         self._zmq_ctx = zmq.Context()
         self._zmq_addr_counter = 0
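
For reference, a minimal usage sketch of the new keyword-only arguments. Only the `gpu_count` and `mem_fraction` parameters and their validation come from this hunk; the import path and the environment setup shown here are assumptions, not part of the diff:

```python
import os

# Hypothetical import path; adjust to wherever ParameterServer lives in this repo.
from checkpoint_engine.ps import ParameterServer

# RANK, WORLD_SIZE and MASTER_ADDR are normally set by the launcher (e.g. torchrun);
# defaults here are only so the sketch runs standalone on a single CUDA host.
os.environ.setdefault("RANK", "0")
os.environ.setdefault("WORLD_SIZE", "1")
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")

ps = ParameterServer(
    auto_pg=True,      # initialize (and later destroy) the process group automatically
    gpu_count=1,       # must be > 0 and <= torch.cuda.device_count(), asserted in __init__
    mem_fraction=0.8,  # fraction of currently free CUDA memory to use; must be in (0, 1]
)
```

Omitting `gpu_count` and `mem_fraction` keeps the previous behavior: all visible GPUs and the old hard-coded 0.9 fraction.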
@@ -834,8 +850,8 @@ def _detect_bucket_size(self, *, disable_h2d_buffer: bool = False) -> tuple[int,
         # auto detect bucket size
         tensor = torch.tensor(
             [
-                # 90% of current cuda free memory bytes
-                int(float(torch.cuda.mem_get_info()[0]) * 0.9),
+                # proportion of current cuda free memory bytes
+                int(float(torch.cuda.mem_get_info()[0]) * self._mem_fraction),
                 # we use negative value to reuse allreduce min operation
                 # for getting the max value of zmq_addr_counter in all ranks
                 -self._zmq_addr_counter,
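
To illustrate the effect of `mem_fraction` on the auto-detected bucket size, here is a small standalone sketch. The helper name is ours, and the "minimum across ranks" behavior is inferred from the allreduce-min comment in the hunk rather than shown directly in it:

```python
import torch

def proposed_bucket_bytes(mem_fraction: float = 0.9) -> int:
    # Mirrors the local proposal in _detect_bucket_size: take the currently free
    # CUDA memory in bytes and scale it by mem_fraction.
    free_bytes, _total_bytes = torch.cuda.mem_get_info()
    return int(float(free_bytes) * mem_fraction)

# e.g. with ~40 GiB free and mem_fraction=0.8 the local proposal is ~32 GiB;
# the allreduce with the min op would then make all ranks agree on the smallest
# proposal across the cluster.
print(proposed_bucket_bytes(0.8))
```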