@@ -175,6 +175,8 @@ def __init__(
175175 auto_pg : bool = True ,
176176 gpu_count : int | None = None ,
177177 mem_fraction : float | None = None ,
178+ master_addr : str | None = None ,
179+ master_port : int | None = None ,
178180 ):
179181 """
180182 Initialize the parameter server. env RANK, WORLD_SIZE and MASTER_ADDR must be set.
@@ -228,6 +230,17 @@ def __init__(
228230 self ._device_uuid = _get_physical_gpu_id (self .device_manager , device_index )
229231 self ._rdma_device = None if self ._p2p_store is None else self ._p2p_store .device
230232
233+ master_addr = master_addr or os .getenv ("MASTER_ADDR" )
234+ assert master_addr , "master_addr is required"
235+ self ._store = torch .distributed .TCPStore (
236+ master_addr ,
237+ _get_master_port (master_port ),
238+ self ._world_size ,
239+ timeout = timedelta (minutes = 10 ),
240+ is_master = self ._rank == 0 ,
241+ )
242+ self ._store_counter = 0
243+
231244 def _get_memory_pool (self , checkpoint_name : str ) -> list [MemoryBuffer ]:
232245 if checkpoint_name == self ._current_shared_memory_pool_user :
233246 assert self ._memory_pool [self .shared_memory_pool_name ], (
@@ -487,8 +500,6 @@ def gather_metas(self, checkpoint_name: str):
487500 def init_process_group (
488501 self ,
489502 * ,
490- master_addr : str | None = None ,
491- master_port : int | None = None ,
492503 timeout : timedelta = timedelta (minutes = 10 ),
493504 ):
494505 """
@@ -498,27 +509,18 @@ def init_process_group(
498509 master_port: The specified port of the master node. If not set, will use _get_master_port to get the port.
499510 timeout: The timeout of the process group.
500511 """
501- master_addr = master_addr or os .getenv ("MASTER_ADDR" )
502- assert master_addr , "master_addr is required"
503- store = dist .TCPStore (
504- master_addr ,
505- _get_master_port (master_port ),
506- self ._world_size ,
507- timeout = timeout ,
508- is_master = self ._rank == 0 ,
509- )
512+ self ._store_counter += 1
513+ sub_store = torch .distributed .PrefixStore (f"prefix-{ self ._store_counter } " , self ._store )
510514 dist .init_process_group (
511515 backend = self .device_manager .backend ,
512516 world_size = self ._world_size ,
513517 rank = self ._rank ,
514518 timeout = timeout ,
515- store = store ,
519+ store = sub_store ,
516520 )
517521 logger .info (f"[rank{ self ._rank } ] init process group successfully." )
518522
519- def store_based_barrier (
520- self , store : dist .TCPStore , timeout : timedelta = timedelta (minutes = 5 )
521- ) -> None :
523+ def store_based_barrier (self , timeout : timedelta = timedelta (minutes = 5 )) -> None :
522524 """
523525 Perform a store-based barrier synchronization across all ranks.
524526
@@ -531,7 +533,7 @@ def store_based_barrier(
531533 """
532534 dist .distributed_c10d ._store_based_barrier (
533535 rank = self ._rank ,
534- store = store ,
536+ store = self . _store ,
535537 group_name = "parameter_server_barrier" ,
536538 rendezvous_count = self ._world_size ,
537539 timeout = timeout ,
@@ -544,8 +546,6 @@ def update(
544546 * ,
545547 timeout : timedelta = timedelta (minutes = 10 ),
546548 ranks : list [int ] | None = None ,
547- master_addr : str | None = None ,
548- master_port : int | None = None ,
549549 ) -> None :
550550 """
551551 Update the checkpoint to inference engine. This function should be called after gather_metas.
@@ -566,28 +566,12 @@ def update(
566566 assert req_func is not None , "req_func is required"
567567 ranks_group = None
568568 try :
569- master_addr = os .getenv ("MASTER_ADDR" ) or master_addr
570- assert master_addr , "master_addr is required"
571- if self ._auto_pg :
572- if not dist .is_initialized ():
573- self .init_process_group (
574- timeout = timeout , master_addr = master_addr , master_port = master_port
575- )
576- manager_store = dist .distributed_c10d ._get_default_store ()
577- else :
578- # HACK: MASTER_PORT+2 for barrier store if master_port is not provided, _get_master_port() returns MASTER_PORT+1
579- # If master_port is provided, use master_port+1 for barrier store
580- manager_store = dist .TCPStore (
581- master_addr ,
582- _get_master_port (master_port ) + 1 ,
583- self ._world_size ,
584- timeout = timeout ,
585- is_master = self ._rank == 0 ,
586- )
569+ if self ._auto_pg and not dist .is_initialized ():
570+ self .init_process_group (timeout = timeout )
587571 # if ranks is None or [], it will use fully broadcast to update to all ranks
588572 ranks_group = dist .new_group (ranks ) if ranks else None
589573 self ._update_per_bucket (checkpoint_name , req_func , ranks_group , ranks )
590- self .store_based_barrier (manager_store )
574+ self .store_based_barrier ()
591575 except Exception as e :
592576 logger .exception (
593577 f"[rank{ self ._rank } ] update checkpoint { checkpoint_name } with ranks { ranks } error { e } "
0 commit comments