140 changes: 140 additions & 0 deletions recipe/one_step_off_policy/ckpt_engine_worker.py
@@ -0,0 +1,140 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates
# Copyright 2025 Meituan Ltd. and/or its affiliates
# Copyright 2025 Huawei Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import time

import httpx
import torch
import torch.distributed
from checkpoint_engine.ps import ParameterServer, request_inference_to_update
from omegaconf import DictConfig, OmegaConf

from verl.single_controller.base import Worker
from verl.single_controller.base.decorator import Dispatch, register
from verl.utils.device import (
get_device_name,
)

logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))

device_name = get_device_name()


class CkptEngineWorker(Worker):
def __init__(self, rank_offset, ps_world_size, inference_parallel_size, rollout_name):
super().__init__()
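# Shift the local worker rank into the shared parameter-server rank space, which also contains the actor ranks.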
rank = self.rank + rank_offset
self.ps_rank = rank
self.ps_rank_offset = rank_offset
self.ps_world_size = ps_world_size
self.inference_parallel_size = inference_parallel_size
self.rollout_name = rollout_name
self.ps = ParameterServer(rank=rank, world_size=ps_world_size)
self.index = 0

def _init_process_group(self):
os.environ["HCCL_NPU_SOCKET_PORT_RANGE"] = "61020"
self.ps.init_process_group(device_index=0, master_port=60010)
del os.environ["HCCL_NPU_SOCKET_PORT_RANGE"]

def check_vllm_ready(self, uds: str | None = None):
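# Only the lead rank of each inference-parallel group polls the server; the remaining ranks return immediately.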
if self.ps_rank != self.ps_rank // self.inference_parallel_size * self.inference_parallel_size:
return
retry_num = 0
transport = None
if uds is not None:
transport = httpx.HTTPTransport(uds=uds)
while True:
try:
response = httpx.Client(transport=transport).get(f"{self.endpoint}/health", timeout=10)
response.raise_for_status()
break
except (httpx.ConnectError, httpx.HTTPStatusError) as e:
retry_num += 1
logger.warning(f"fail to check vllm ready, retry {retry_num} times, error: {e}")
time.sleep(5)
Comment on lines +59 to +71
Contributor

high

The while True loop for checking vLLM readiness can run indefinitely if the server fails to start, causing the worker to hang. It's much safer to implement a timeout mechanism with a maximum number of retries. This ensures that the worker will eventually fail with a clear error message instead of getting stuck. I've also moved the httpx.Client instantiation out of the loop for efficiency.

Suggested change (original):

retry_num = 0
transport = None
if uds is not None:
transport = httpx.HTTPTransport(uds=uds)
while True:
try:
response = httpx.Client(transport=transport).get(f"{self.endpoint}/health", timeout=10)
response.raise_for_status()
break
except (httpx.ConnectError, httpx.HTTPStatusError) as e:
retry_num += 1
logger.warning(f"fail to check vllm ready, retry {retry_num} times, error: {e}")
time.sleep(5)

Suggested change (replacement):

retry_num = 0
max_retries = 60 # e.g., 5 minutes
transport = httpx.HTTPTransport(uds=uds) if uds is not None else None
client = httpx.Client(transport=transport)
while retry_num < max_retries:
try:
response = client.get(f"{self.endpoint}/health", timeout=10)
response.raise_for_status()
logger.info("vLLM server is ready.")
return
except (httpx.ConnectError, httpx.HTTPStatusError) as e:
retry_num += 1
logger.warning(f"fail to check vllm ready, retry {retry_num}/{max_retries} times, error: {e}")
time.sleep(5)
raise RuntimeError(f"vLLM server not ready after {max_retries} retries.")


def check_sglang_ready(self, uds: str | None = None):
if self.ps_rank != self.ps_rank // self.inference_parallel_size * self.inference_parallel_size:
return
retry_num = 0
transport = None
if uds is not None:
transport = httpx.HTTPTransport(uds=uds)
with httpx.Client(transport=transport) as client:
while True:
try:
response = client.get(f"{self.endpoint}/ping", timeout=10)
response.raise_for_status()
break
except (httpx.ConnectError, httpx.HTTPStatusError) as e:
if retry_num % 10 == 0:
logger.warning(
f"fail to check sglang ready, retry {retry_num} times, error: {e}"
)
retry_num += 1
time.sleep(0.1)

@register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)
def set_server_addresses(self, server_addresses: list[str]):
# TODO: support multiple API servers
self.endpoint = f"http://{server_addresses[0]}"
if self.rollout_name == "sglang":
self.check_sglang_ready()
elif self.rollout_name == "vllm":
self.check_vllm_ready()

@register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)
def sync_rollout_weights_by_ckpt_engine(self):
rank = self.rank
src = rank // self.inference_parallel_size * self.inference_parallel_size
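# Only the lead rank (src) of each inference-parallel group issues the HTTP request that triggers the rollout server to pull the new weights.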

def vllm_req_func(socket_paths: list[tuple[str, str]]) -> None:
if rank == src:
request_inference_to_update(
url=f"{self.endpoint}/collective_rpc",
socket_paths=dict(socket_paths),
)

def sglang_req_func(socket_paths: list[tuple[str, str]]) -> None:
if rank == src:
with httpx.Client(transport=httpx.HTTPTransport()) as client:
resp = client.post(
f"{self.endpoint}/update_weights_from_ipc",
json={
"zmq_handles": dict(socket_paths),
"flush_cache": True,
"weight_version": None,
},
timeout=300.0,
)
resp.raise_for_status()

if self.rollout_name == "sglang":
req_func = sglang_req_func
elif self.rollout_name == "vllm":
req_func = vllm_req_func

self._init_process_group()
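# Register this sync round as a named checkpoint, gather tensor metadata across ranks, then push the update to the checkpoint-engine ranks via the parameter server.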
checkpoint_name = f"sync_{self.index}"
self.ps.register_checkpoint(checkpoint_name=checkpoint_name)
self.ps.gather_metas(checkpoint_name)
self.ps.update(checkpoint_name, req_func, ranks=list(range(self.ps_rank_offset, self.ps_world_size)))
self.index += 1
@@ -20,9 +20,12 @@ actor_rollout_ref:
free_cache_engine: False
# Must be enabled! Otherwise, log_probs cannot be calculated.
calculate_log_probs: True
engine_kwargs:
vllm:
worker_extension_cls: checkpoint_engine.worker.VllmColocateWorkerExtension

# Only then will the use of log probs be correct.
# And it can be used in conjunction with other rollout_correction algorithms.
algorithm:
rollout_correction:
bypass_mode: True
bypass_mode: True
69 changes: 55 additions & 14 deletions recipe/one_step_off_policy/fsdp_workers.py
@@ -18,11 +18,11 @@

import torch
import torch.distributed
from checkpoint_engine.ps import ParameterServer
from omegaconf import DictConfig
from ray.util.collective import collective
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from recipe.one_step_off_policy.distributed_util import vllm_stateless_init_process_group
from verl.single_controller.base.decorator import Dispatch, register
from verl.utils.device import (
get_device_name,
@@ -53,17 +53,6 @@ class DetachSync(AsyncActorRolloutRefWorker):
def _get_actor_params(self):
pass

@register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)
def create_weight_sync_group(self, master_address, master_port, rank_offset, world_size):
rank = torch.distributed.get_rank() + rank_offset
self._weight_sync_group = vllm_stateless_init_process_group(
master_address,
master_port,
rank,
world_size,
get_torch_device().current_device(),
)

@register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)
def sync_rollout_weights(self):
assert (self._is_actor or self._is_rollout) and not self.config.hybrid_engine
@@ -127,6 +116,59 @@ async def update_weights(self, inference_engine, params):


class DetachActorWorker(DetachSync):
def __init__(self, config: DictConfig, role: str, **kwargs):
ActorRolloutRefWorker.__init__(self, config, role)

if role == "actor":
self.ps_rank_offset = kwargs.get("rank_offset", self.rank)
self.ps_world_size = kwargs.get("ps_world_size", self.world_size)
self.ps = ParameterServer(rank=self.rank, world_size=self.ps_world_size)
self.index = 0

def init_process_group(self):
os.environ["HCCL_NPU_SOCKET_PORT_RANGE"] = "61020"
self.ps.init_process_group(device_index=0, master_port=60010)
Comment on lines +129 to +130
Contributor

critical

Hardcoding network ports and port ranges can lead to conflicts in a multi-user or multi-job environment. It's better to make these configurable or use a dynamic port allocation mechanism.
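As an illustration only and not part of this change, a minimal sketch of how the hardcoded master port (60010) could become configurable or dynamically chosen; the get_free_port helper and the ckpt_engine_master_port key are hypothetical names, not APIs from this PR:

import socket

from omegaconf import DictConfig


def get_free_port() -> int:
    # Ask the OS for an unused TCP port by binding to port 0, then release it.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", 0))
        return s.getsockname()[1]


def resolve_master_port(config: DictConfig) -> int:
    # Prefer an explicit value from the job config; otherwise fall back to a runtime-chosen free port.
    configured = config.get("ckpt_engine_master_port", None)  # hypothetical config key
    return int(configured) if configured is not None else get_free_port()

Note that every rank in the process group must agree on the same master port, so a dynamically chosen value would still have to be broadcast from rank 0 (or published through the driver) before init_process_group is called.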

del os.environ["HCCL_NPU_SOCKET_PORT_RANGE"]

def split_tensors(self) -> dict[str, torch.Tensor]:
assert self._is_actor and not self.config.hybrid_engine
assert hasattr(self, "_weights_info") and self._weights_info is not None

if self._is_actor and self._is_offload_param:
load_fsdp_model_to_gpu(self.actor_module_fsdp)
params = self._get_actor_params()

named_tensors = {}

world_size = self.world_size
rank = self.rank

weights_per_rank = (len(self._weights_info) + world_size - 1) // world_size
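# Every rank materializes each full FSDP tensor, but keeps only the contiguous slice of the weight list assigned to it, staged on CPU.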
for index, (key, _, _) in enumerate(self._weights_info):
assert key in params
tensor = params[key].full_tensor()
if index >= rank * weights_per_rank and index < (rank + 1) * weights_per_rank:
named_tensors[key] = tensor.to("cpu", non_blocking=True)

get_torch_device().synchronize()

return named_tensors

@register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)
def sync_rollout_weights_by_ckpt_engine(self):
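# On the actor side req_func is a no-op: the HTTP request that triggers the inference engine to reload weights is issued by the checkpoint-engine workers.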
def req_func(socket_paths: list[tuple[str, str]]):
return

self.init_process_group()
named_tensors = self.split_tensors()
checkpoint_name = f"sync_{self.index}"

self.ps.register_checkpoint(checkpoint_name=checkpoint_name, named_tensors=named_tensors)
self.ps.gather_metas(checkpoint_name)
self.ps.update(checkpoint_name, req_func, ranks=list(range(self.ps_rank_offset, self.ps_world_size)))

self.index += 1

def _get_actor_params(self):
assert self._is_actor
params = self.actor_module_fsdp.state_dict()
@@ -159,8 +201,7 @@ def get_actor_weights_info(self):


class DetachAsyncRolloutWorker(DetachSync):
def __init__(self, config: DictConfig, role: str):
print(f"[DetachAsyncRolloutWorker] {DetachAsyncRolloutWorker.__mro__}")
def __init__(self, config: DictConfig, role: str, **kwargs):
ActorRolloutRefWorker.__init__(self, config, role)

@register(dispatch_mode=Dispatch.ONE_TO_ALL)
14 changes: 14 additions & 0 deletions recipe/one_step_off_policy/main_ppo.py
@@ -32,6 +32,8 @@
from verl.utils.config import validate_config
from verl.utils.device import auto_set_ascend_device_name

from .ckpt_engine_worker import CkptEngineWorker


def create_resource_pool_manager(config, roles: list) -> ResourcePoolManager:
"""
@@ -69,6 +71,14 @@ def create_resource_pool_manager(config, roles: list) -> ResourcePoolManager:
resource_pool_spec["rollout_pool"] = rollout_pool
mapping[Role.Rollout] = "rollout_pool"

if Role.CkptEngine in roles:
assert config.rollout.n_gpus_per_node > 0, "ckpt_engine config.rollout.n_gpus_per_node must be greater than 0"
assert config.rollout.nnodes > 0, "ckpt_engine config.rollout.nnodes must be greater than 0"
# the same as rollout pool
ckpt_engine_pool = [config.rollout.n_gpus_per_node] * config.rollout.nnodes
resource_pool_spec["ckpt_engine_pool"] = ckpt_engine_pool
mapping[Role.CkptEngine] = "ckpt_engine_pool"

return ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)


@@ -111,6 +121,7 @@ def create_role_worker_mapping(config):
Role.Actor: ray.remote(DetachActorWorker),
Role.Rollout: ray.remote(DetachAsyncRolloutWorker),
Role.Critic: ray.remote(CriticWorker),
Role.CkptEngine: ray.remote(CkptEngineWorker),
}

if config.reward_model.enable:
@@ -140,6 +151,9 @@ def run(self, config):

from verl.utils.fs import copy_to_local

if os.environ.get("ASCEND_RT_VISIBLE_DEVICES", None) is not None:
del os.environ["ASCEND_RT_VISIBLE_DEVICES"]

print(f"TaskRunner hostname: {socket.gethostname()}, PID: {os.getpid()}")

pprint(OmegaConf.to_container(config, resolve=True))
Expand Down
55 changes: 54 additions & 1 deletion recipe/one_step_off_policy/megatron_workers.py
@@ -18,6 +18,7 @@

import torch
import torch.distributed
from checkpoint_engine.ps import ParameterServer
from omegaconf import DictConfig
from ray.util.collective import collective

@@ -120,6 +121,58 @@ async def update_weights(self, inference_engine, params):


class DetachActorWorker(DetachSync):
def __init__(self, config: DictConfig, role: str, **kwargs):
ActorRolloutRefWorker.__init__(self, config, role)

if role == "actor":
self.ps_rank_offset = kwargs.get("rank_offset", self.rank)
self.ps_world_size = kwargs.get("ps_world_size", self.world_size)
self.ps = ParameterServer(rank=self.rank, world_size=self.ps_world_size)
self.index = 0

def init_process_group(self):
os.environ["HCCL_NPU_SOCKET_PORT_RANGE"] = "61020"
self.ps.init_process_group(device_index=0, master_port=60010)
del os.environ["HCCL_NPU_SOCKET_PORT_RANGE"]

def split_tensors(self) -> dict[str, torch.Tensor]:
assert self._is_actor and not self.config.hybrid_engine
assert hasattr(self, "_weights_info") and self._weights_info is not None

params_generator = self._get_actor_params_generator() if self._is_actor else None

if self._is_actor and self._is_offload_param:
load_megatron_model_to_gpu(self.actor_module)

named_tensors = {}

world_size = self.world_size
rank = self.rank

weights_per_rank = (len(self._weights_info) + world_size - 1) // world_size
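# Each rank keeps only its contiguous slice of the streamed Megatron parameters and stages it on CPU.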
for index, (key, tensor) in enumerate(params_generator):
if index >= rank * weights_per_rank and index < (rank + 1) * weights_per_rank:
named_tensors[key] = tensor.to("cpu", non_blocking=True)

get_torch_device().synchronize()

return named_tensors

@register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)
def sync_rollout_weights_by_ckpt_engine(self):
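# As in the FSDP worker, req_func is a no-op on the actor side; the checkpoint-engine workers issue the inference-update request.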
def req_func(socket_paths: list[tuple[str, str]]):
return

self.init_process_group()
named_tensors = self.split_tensors()
checkpoint_name = f"sync_{self.index}"

self.ps.register_checkpoint(checkpoint_name=checkpoint_name, named_tensors=named_tensors)
self.ps.gather_metas(checkpoint_name)
self.ps.update(checkpoint_name, req_func, ranks=list(range(self.ps_rank_offset, self.ps_world_size)))

self.index += 1

@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def _get_actor_params_generator(self):
assert self._is_actor
@@ -160,7 +213,7 @@ def get_actor_weights_info(self):


class DetachAsyncRolloutWorker(DetachSync):
def __init__(self, config: DictConfig, role: str):
def __init__(self, config: DictConfig, role: str, **kwargs):
print(f"[DetachAsyncRolloutWorker] {DetachAsyncRolloutWorker.__mro__}")
ActorRolloutRefWorker.__init__(self, config, role)
