[wip][BREAKING][recipe, ckpt] add checkpoint engine for one step off policy #4601
base: main
Conversation
x1314aq seems not to be a GitHub user. You need a GitHub account to be able to sign the CLA. If you already have a GitHub account, please add the email address used for this commit to your account. You have signed the CLA already but the status is still pending? Let us recheck it.
Code Review
This pull request introduces a new CkptEngineWorker to manage weight synchronization using a ParameterServer for vLLM-based rollout workers, replacing the previous Ray collective group mechanism. The CkptEngineWorker is responsible for initializing its own process group, checking vLLM readiness, setting server addresses, and synchronizing rollout weights by registering and updating checkpoints via the ParameterServer. The DetachActorWorker is also updated to integrate with this new checkpoint engine, including its own ParameterServer initialization and a split_tensors method to prepare actor weights for synchronization.

The PPOTrainer now calculates rank_offset and ps_world_size to configure the ParameterServer instances and explicitly creates a resource pool for the CkptEngine role, allocating specific CPU/NPU resources. The weight synchronization logic in ray_trainer.py is refactored to use the new sync_rollout_weights_by_ckpt_engine methods on both the actor and checkpoint-engine worker groups. Additionally, the RayResourcePool class gains a custom_bundle parameter to allow more flexible resource allocation, particularly for NPU devices, and a new shell script grpo_0.6b_gsm8k_fsdp2_2_6_ckpt_engine.sh demonstrates the new setup.

Review comments highlight critical issues: hardcoded network ports and port ranges in CkptEngineWorker and DetachActorWorker should be made configurable to prevent conflicts, and the check_vllm_ready loop in CkptEngineWorker needs a maximum retry mechanism to prevent indefinite hanging, along with incrementing self.index after each checkpoint update to ensure unique checkpoint names.
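To make the refactor concrete, the trainer-side call described above might look roughly like the following (a sketch inferred from this summary; the worker-group names and call signatures are assumptions, not code from the PR):

```python
# Sketch only: per the summary, both worker groups expose a
# sync_rollout_weights_by_ckpt_engine method; names here are illustrative.
def sync_rollout_weights(actor_wg, ckpt_engine_wg):
    # Actor workers split and publish their weights to the ParameterServer...
    actor_wg.sync_rollout_weights_by_ckpt_engine()
    # ...while checkpoint-engine workers register/update the checkpoint and
    # push the new weights to the vLLM rollout workers over HTTP. In the PR
    # these presumably execute concurrently across the two worker groups.
    ckpt_engine_wg.sync_rollout_weights_by_ckpt_engine()
```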
The first flagged hardcoded-port site:

```python
os.environ["HCCL_NPU_SOCKET_PORT_RANGE"] = "61020-61050"
self.ps.init_process_group(device_index=0, master_port=60010)
```
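The review summary above flags these hardcoded values as a conflict risk. A minimal sketch of reading them from configuration instead, assuming a hypothetical ckpt_engine config section (these keys are not part of the PR):

```python
# Hypothetical config keys on the worker's config object; the defaults
# mirror the hardcoded values above.
port_range = config.ckpt_engine.get("socket_port_range", "61020-61050")
master_port = config.ckpt_engine.get("master_port", 60010)
os.environ["HCCL_NPU_SOCKET_PORT_RANGE"] = port_range
self.ps.init_process_group(device_index=0, master_port=master_port)
```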
The checkpoint-update sequence referenced by the following comment:

```python
self.ps.register_checkpoint(checkpoint_name=checkpoint_name)
self.ps.gather_metas(checkpoint_name)
ranks = list(range(self.ps_rank_offset, self.ps_world_size))
self.ps.update(checkpoint_name, req_func, ranks=ranks)
```
To fix the checkpoint naming issue, you should increment self.index after each update. This ensures that a unique checkpoint name is used for every synchronization.
Suggested change:

```diff
  self.ps.update(checkpoint_name, req_func, ranks=ranks)
+ self.index += 1
```
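Assuming checkpoint_name is derived from the counter (the naming scheme below is illustrative, not from the PR), the increment gives every synchronization a distinct checkpoint:

```python
# Illustrative: a per-sync counter keeps checkpoint names unique.
checkpoint_name = f"sync_ckpt_{self.index}"  # sync_ckpt_0, sync_ckpt_1, ...
```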
The second flagged site, with the same hardcoded master port:

```python
os.environ["HCCL_NPU_SOCKET_PORT_RANGE"] = "61020"
self.ps.init_process_group(device_index=0, master_port=60010)
```
The same update sequence at the second location:

```python
self.ps.gather_metas(checkpoint_name)
ranks = list(range(self.ps_rank_offset, self.ps_world_size))
self.ps.update(checkpoint_name, req_func, ranks=ranks)
```
To fix the checkpoint naming issue, you should increment self.index after each update. This ensures that a unique checkpoint name is used for every synchronization.
Suggested change:

```diff
  self.ps.update(checkpoint_name, req_func, ranks=ranks)
+ self.index += 1
```
The current readiness check:

```python
retry_num = 0
transport = None
if uds is not None:
    transport = httpx.HTTPTransport(uds=uds)
while True:
    try:
        response = httpx.Client(transport=transport).get(f"{self.endpoint}/health", timeout=10)
        response.raise_for_status()
        break
    except (httpx.ConnectError, httpx.HTTPStatusError) as e:
        retry_num += 1
        logger.warning(f"fail to check vllm ready, retry {retry_num} times, error: {e}")
        time.sleep(5)
```
The while True loop for checking vLLM readiness can run indefinitely if the server fails to start, causing the worker to hang. It's much safer to implement a timeout mechanism with a maximum number of retries. This ensures that the worker will eventually fail with a clear error message instead of getting stuck. I've also moved the httpx.Client instantiation out of the loop for efficiency.
Suggested change:

```python
retry_num = 0
max_retries = 60  # e.g., 5 minutes
transport = httpx.HTTPTransport(uds=uds) if uds is not None else None
client = httpx.Client(transport=transport)
while retry_num < max_retries:
    try:
        response = client.get(f"{self.endpoint}/health", timeout=10)
        response.raise_for_status()
        logger.info("vLLM server is ready.")
        return
    except (httpx.ConnectError, httpx.HTTPStatusError) as e:
        retry_num += 1
        logger.warning(f"fail to check vllm ready, retry {retry_num}/{max_retries} times, error: {e}")
        time.sleep(5)
raise RuntimeError(f"vLLM server not ready after {max_retries} retries.")
```
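With max_retries = 60 and a 5-second sleep between attempts, the bounded loop gives the server roughly 60 × 5 s = 5 minutes to come up before the worker fails loudly; reusing one httpx.Client across attempts also allows connection pooling instead of rebuilding a client on every probe.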
What does this PR do?
This PR introduces checkpoint-engine to achieve efficient parameter synchronization between the trainer and the rollouter.
This PR is somewhat similar to #4427, but employs a completely different implementation approach. The main differences are as follows:
- It uses checkpoint-engine as a dependency, rather than re-implementing its core logic.
- checkpoint-engine runs as an independent process, not within the same process as the rollouter, and updates the rollouter's weights via HTTP requests.
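Pieced together from the reviewed snippets, one synchronization round on the checkpoint-engine side goes roughly like this (a sketch, not verbatim PR code; the import path follows the upstream checkpoint-engine package, and the counter/naming scheme and placeholder values are assumptions):

```python
from checkpoint_engine.ps import ParameterServer  # the external dependency

ps = ParameterServer()
ps.init_process_group(device_index=0, master_port=60010)

# req_func is the callback that issues the HTTP weight-update requests to
# the vLLM servers; its real definition lives in the PR's worker code.
def req_func(meta):
    ...

ps_rank_offset, ps_world_size = 2, 4  # illustrative values from PPOTrainer
index = 0                             # incremented after every sync

checkpoint_name = f"sync_ckpt_{index}"  # assumed naming scheme
ps.register_checkpoint(checkpoint_name=checkpoint_name)
ps.gather_metas(checkpoint_name)
ranks = list(range(ps_rank_offset, ps_world_size))  # the rollout ranks
ps.update(checkpoint_name, req_func, ranks=ranks)
index += 1  # keep the next checkpoint name unique
```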
Checklist Before Starting

- Format the PR title as `[{modules}] {type}: {description}` (this will be checked by the CI).
  - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`, `cfg`, `reward`, like `[megatron, fsdp, doc]`.
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`.
  - If this PR involves a breaking change, add `[BREAKING]` to the beginning of the title, e.g. `[BREAKING][fsdp, megatron] feat: dynamic batching`.

Test
Testing depends on a pending checkpoint-engine PR and a hixl issue before this can run on an Ascend Atlas A2/A3 server.
Work on Megatron and SGLang is currently in progress and will be completed soon.
API and Usage Example
```python
# Add code snippet or script demonstrating how to use this
```

Design & Code Changes
Checklist Before Submitting
Important
Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.
- Apply pre-commit checks: `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- Once your PR is ready for CI, send a message in the `ci-request` channel in the `verl` Slack workspace. (If not accessible, please try the Feishu group (飞书群).)