diff --git a/docker/patch/latest/sglang.patch b/docker/patch/latest/sglang.patch index ceee2a1f9..490d692d9 100644 --- a/docker/patch/latest/sglang.patch +++ b/docker/patch/latest/sglang.patch @@ -1,5 +1,5 @@ diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py -index aa10cb0..d41c31a 100644 +index aa10cb08d..d41c31a09 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -268,6 +268,12 @@ class ModelConfig: @@ -16,10 +16,18 @@ index aa10cb0..d41c31a 100644 self.hf_config.architectures[0] = "Glm4MoeForCausalLMNextN" diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py -index 51af676..661ea6f 100644 +index 51af67636..3ec1778ed 100644 --- a/python/sglang/srt/disaggregation/decode.py +++ b/python/sglang/srt/disaggregation/decode.py -@@ -315,6 +315,16 @@ class DecodePreallocQueue: +@@ -21,6 +21,7 @@ Life cycle of a request in the decode server + from __future__ import annotations + + import logging ++import os + import time + from collections import deque + from dataclasses import dataclass +@@ -315,6 +316,16 @@ class DecodePreallocQueue: ) return kv_manager @@ -36,8 +44,128 @@ index 51af676..661ea6f 100644 def add(self, req: Req, is_retracted: bool = False) -> None: """Add a request to the pending queue.""" if self._check_if_req_exceed_kv_capacity(req): +@@ -419,12 +430,37 @@ class DecodePreallocQueue: + [decode_req.kv_receiver for decode_req in self.queue], self.gloo_group + ) + ++ # Bootstrap timeout: if a request has been stuck in Bootstrapping for too long, treat it as failed. ++ bootstrap_timeout = float( ++ os.environ.get("SGLANG_DISAGGREGATION_TRANSFER_TIMEOUT", "600") ++ ) ++ now = time.perf_counter() ++ + for i, (decode_req, poll) in enumerate(zip(self.queue, polls)): + if rids_to_check is not None and decode_req.req.rid not in rids_to_check: + continue + + if poll == KVPoll.Bootstrapping: +- pass ++ # Check for bootstrap timeout ++ entry_time = getattr( ++ decode_req.req.time_stats, ++ "decode_prealloc_queue_entry_time", ++ None, ++ ) ++ if entry_time is not None and (now - entry_time) > bootstrap_timeout: ++ error_message = ( ++ f"Decode bootstrap timed out after {now - entry_time:.1f}s " ++ f"for request rank={self.tp_rank} " ++ f"{decode_req.req.rid=} {decode_req.req.bootstrap_room=}" ++ ) ++ logger.error(error_message) ++ prepare_abort( ++ decode_req.req, ++ error_message, ++ status_code=HTTPStatus.GATEWAY_TIMEOUT, ++ ) ++ if self.scheduler.enable_metrics: ++ self.scheduler.metrics_collector.increment_bootstrap_failed_reqs() + elif poll == KVPoll.WaitingForInput: + decode_req.waiting_for_input = True + elif poll == KVPoll.Failed: +@@ -776,6 +812,13 @@ class DecodeTransferQueue: + [decode_req.kv_receiver for decode_req in self.queue], self.gloo_group + ) + ++ # Transfer timeout: if a request has been in the transfer queue for too long ++ # (e.g., stuck in Bootstrapping/WaitingForInput/Transferring), treat it as failed. ++ transfer_timeout = float( ++ os.environ.get("SGLANG_DISAGGREGATION_TRANSFER_TIMEOUT", "600") ++ ) ++ now = time.perf_counter() ++ + transferred_reqs = [] + indices_to_remove = set() + for i, (decode_req, poll) in enumerate(zip(self.queue, polls)): +@@ -811,7 +854,33 @@ class DecodeTransferQueue: + KVPoll.WaitingForInput, + KVPoll.Transferring, + ]: +- pass ++ # Check for transfer timeout ++ entry_time = getattr( ++ decode_req.req.time_stats, ++ "decode_transfer_queue_entry_time", ++ None, ++ ) ++ if entry_time is not None and (now - entry_time) > transfer_timeout: ++ error_message = ( ++ f"Decode transfer timed out after {now - entry_time:.1f}s " ++ f"(state={poll}) for request rank={self.tp_rank} " ++ f"{decode_req.req.rid=} {decode_req.req.bootstrap_room=}" ++ ) ++ logger.error(error_message) ++ prepare_abort( ++ decode_req.req, ++ error_message, ++ status_code=HTTPStatus.GATEWAY_TIMEOUT, ++ ) ++ self.scheduler.stream_output( ++ [decode_req.req], decode_req.req.return_logprob ++ ) ++ release_kv_cache( ++ decode_req.req, self.tree_cache, is_insert=False ++ ) ++ indices_to_remove.add(i) ++ if self.scheduler.enable_metrics: ++ self.scheduler.metrics_collector.increment_transfer_failed_reqs() + else: + raise ValueError(f"Unexpected poll case: {poll}") + +@@ -827,6 +896,14 @@ class DecodeTransferQueue: + + return transferred_reqs + ++ def release_memory_occupation(self): ++ """Clean up all in-flight transfers before releasing GPU memory.""" ++ self.queue.clear() ++ ++ def resume_memory_occupation(self): ++ """Resume after GPU memory re-allocation. Queue was already cleared on release.""" ++ pass ++ + + class SchedulerDisaggregationDecodeMixin: + +@@ -1004,7 +1081,15 @@ class SchedulerDisaggregationDecodeMixin: + resumed_reqs = self.disagg_decode_prealloc_queue.resume_retracted_reqs() + self.waiting_queue.extend(resumed_reqs) + if len(self.disagg_decode_prealloc_queue.retracted_queue) > 0: +- # if there are still retracted requests, we do not allocate new requests ++ # Still have retracted requests that couldn't resume (not enough memory). ++ # Don't accept new requests (pop_preallocated) — they would consume memory ++ # that retracted requests need. ++ # But DO drain completed transfers: their KV is already committed, and ++ # moving them to waiting_queue frees the reserved-decode-token budget ++ # in _allocatable_tokens(), which may unblock resume on the next iteration. ++ # Without this, completed transfers hold memory indefinitely → deadlock. ++ alloc_reqs = self.disagg_decode_transfer_queue.pop_transferred() ++ self.waiting_queue.extend(alloc_reqs) + return + + if not hasattr(self, "polling_count"): diff --git a/python/sglang/srt/disaggregation/mooncake/conn.py b/python/sglang/srt/disaggregation/mooncake/conn.py -index 32e8c0b..dc93c5c 100644 +index 32e8c0b69..dc93c5c5f 100644 --- a/python/sglang/srt/disaggregation/mooncake/conn.py +++ b/python/sglang/srt/disaggregation/mooncake/conn.py @@ -253,6 +253,19 @@ class MooncakeKVManager(CommonKVManager): @@ -61,10 +189,59 @@ index 32e8c0b..dc93c5c 100644 if not transfer_blocks: return 0 diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py -index a6eed74..24a72ca 100644 +index a6eed743a..191b0977f 100644 --- a/python/sglang/srt/disaggregation/prefill.py +++ b/python/sglang/srt/disaggregation/prefill.py -@@ -306,6 +306,15 @@ class PrefillBootstrapQueue: +@@ -20,6 +20,7 @@ Life cycle of a request in the prefill server + from __future__ import annotations + + import logging ++import os + import time + from collections import deque + from http import HTTPStatus +@@ -250,6 +251,12 @@ class PrefillBootstrapQueue: + [req.disagg_kv_sender for req in self.queue], self.gloo_group + ) + ++ # Bootstrap timeout: if a request has been stuck in Bootstrapping for too long, treat it as failed. ++ bootstrap_timeout = float( ++ os.environ.get("SGLANG_DISAGGREGATION_TRANSFER_TIMEOUT", "600") ++ ) ++ now = time.perf_counter() ++ + for i, (req, poll) in enumerate(zip(self.queue, polls)): + if rids_to_check is not None: + # if req not in reqs_info_to_check, skip +@@ -257,6 +264,27 @@ class PrefillBootstrapQueue: + continue + + if poll == KVPoll.Bootstrapping: ++ # Check for bootstrap timeout ++ entry_time = getattr( ++ req.time_stats, ++ "prefill_bootstrap_queue_entry_time", ++ None, ++ ) ++ if entry_time is not None and (now - entry_time) > bootstrap_timeout: ++ error_message = ( ++ f"Prefill bootstrap timed out after {now - entry_time:.1f}s " ++ f"for request rank={self.tp_rank} " ++ f"{req.rid=} {req.bootstrap_room=}" ++ ) ++ logger.error(error_message) ++ prepare_abort( ++ req, error_message, status_code=HTTPStatus.GATEWAY_TIMEOUT ++ ) ++ self.scheduler.stream_output([req], req.return_logprob) ++ indices_to_remove.add(i) ++ failed_reqs.append(req) ++ if self.scheduler.enable_metrics: ++ self.scheduler.metrics_collector.increment_bootstrap_failed_reqs() + continue + elif poll == KVPoll.Failed: + error_message = f"Prefill bootstrap failed for request rank={self.tp_rank} {req.rid=} {req.bootstrap_room=}" +@@ -306,6 +334,15 @@ class PrefillBootstrapQueue: else: return bootstrapped_reqs, failed_reqs @@ -80,8 +257,54 @@ index a6eed74..24a72ca 100644 class SchedulerDisaggregationPrefillMixin: """ +@@ -535,6 +572,13 @@ class SchedulerDisaggregationPrefillMixin: + self.attn_tp_cpu_group, + ) + ++ # Transfer timeout: if a request has been in the inflight queue for too long ++ # (e.g., stuck in WaitingForInput/Transferring), treat it as failed. ++ transfer_timeout = float( ++ os.environ.get("SGLANG_DISAGGREGATION_TRANSFER_TIMEOUT", "600") ++ ) ++ now = time.perf_counter() ++ + undone_reqs: List[Req] = [] + # Check .poll() for the reqs in disagg_prefill_inflight_queue. If Success, respond to the client and remove it from the queue + for req, poll in zip(self.disagg_prefill_inflight_queue, polls): +@@ -547,7 +591,30 @@ class SchedulerDisaggregationPrefillMixin: + assert poll == KVPoll.Success or poll == KVPoll.Failed + + if poll in [KVPoll.WaitingForInput, KVPoll.Transferring]: +- undone_reqs.append(req) ++ # Check for transfer timeout ++ entry_time = getattr( ++ req.time_stats, ++ "prefill_transfer_queue_entry_time", ++ None, ++ ) ++ if entry_time is not None and (now - entry_time) > transfer_timeout: ++ error_message = ( ++ f"Prefill transfer timed out after {now - entry_time:.1f}s " ++ f"(state={poll}) for request rank={self.tp_rank} " ++ f"{req.rid=} {req.bootstrap_room=}" ++ ) ++ logger.error(error_message) ++ release_kv_cache(req, self.tree_cache) # unlock the tree ++ prepare_abort( ++ req, error_message, status_code=HTTPStatus.GATEWAY_TIMEOUT ++ ) ++ if hasattr(req.disagg_kv_sender, "clear"): ++ req.disagg_kv_sender.clear() ++ done_reqs.append(req) ++ if self.enable_metrics: ++ self.metrics_collector.increment_transfer_failed_reqs() ++ else: ++ undone_reqs.append(req) + elif poll == KVPoll.Success: # transfer done + release_kv_cache(req, self.tree_cache) # unlock the tree + req.finished_reason = FINISH_LENGTH(length=0) diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py -index 0478526..cfb1aa6 100644 +index 0478526ef..cfb1aa669 100644 --- a/python/sglang/srt/distributed/parallel_state.py +++ b/python/sglang/srt/distributed/parallel_state.py @@ -1797,7 +1797,10 @@ def get_tensor_model_parallel_world_size(): @@ -97,7 +320,7 @@ index 0478526..cfb1aa6 100644 def get_pipeline_model_parallel_world_size(): diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py -index 6f69fd1..da20ac2 100644 +index 6f69fd19b..da20ac2ed 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -49,6 +49,7 @@ from sglang.srt.managers.io_struct import ( @@ -134,7 +357,7 @@ index 6f69fd1..da20ac2 100644 """Get weights by parameter name.""" obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size) diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py -index 88705cc..c8dc052 100644 +index 88705cc35..c8dc052f1 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -107,6 +107,7 @@ from sglang.srt.managers.io_struct import ( @@ -168,7 +391,7 @@ index 88705cc..c8dc052 100644 @app.post("/update_weight_version") async def update_weight_version(obj: UpdateWeightVersionReqInput, request: Request): diff --git a/python/sglang/srt/layers/attention/nsa/index_buf_accessor.py b/python/sglang/srt/layers/attention/nsa/index_buf_accessor.py -index d6c499d..5650042 100644 +index d6c499df0..565004260 100644 --- a/python/sglang/srt/layers/attention/nsa/index_buf_accessor.py +++ b/python/sglang/srt/layers/attention/nsa/index_buf_accessor.py @@ -613,7 +613,6 @@ def _get_k_and_s_triton( @@ -188,7 +411,7 @@ index d6c499d..5650042 100644 buf_numel_per_page: tl.constexpr, index_head_dim: tl.constexpr, diff --git a/python/sglang/srt/layers/attention/nsa/nsa_indexer.py b/python/sglang/srt/layers/attention/nsa/nsa_indexer.py -index c9e82e4..f258454 100644 +index c9e82e4b1..f2584546a 100644 --- a/python/sglang/srt/layers/attention/nsa/nsa_indexer.py +++ b/python/sglang/srt/layers/attention/nsa/nsa_indexer.py @@ -3,6 +3,7 @@ from __future__ import annotations @@ -229,7 +452,7 @@ index c9e82e4..f258454 100644 if enable_dual_stream: current_stream = torch.cuda.current_stream() diff --git a/python/sglang/srt/layers/communicator.py b/python/sglang/srt/layers/communicator.py -index 15df851..1636ed7 100644 +index 15df851eb..1636ed706 100644 --- a/python/sglang/srt/layers/communicator.py +++ b/python/sglang/srt/layers/communicator.py @@ -371,10 +371,13 @@ class LayerCommunicator: @@ -276,7 +499,7 @@ index 15df851..1636ed7 100644 hidden_states = self._communicate_simple_fn( diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py -index 7bef9d2..5926ff7 100644 +index 7bef9d2ab..5926ff7f5 100644 --- a/python/sglang/srt/layers/layernorm.py +++ b/python/sglang/srt/layers/layernorm.py @@ -83,15 +83,12 @@ class RMSNorm(MultiPlatformOp): @@ -640,7 +863,7 @@ index 7bef9d2..5926ff7 100644 return output diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py -index fa74310..cd33ea7 100644 +index fa7431048..cd33ea735 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -878,11 +878,6 @@ class LogitsProcessor(nn.Module): @@ -656,7 +879,7 @@ index fa74310..cd33ea7 100644 logits = torch.matmul( hidden_states.to(lm_head.weight.dtype), lm_head.weight.T diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py -index a1885fa..14d6923 100644 +index a1885fade..14d692365 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py @@ -14,6 +14,7 @@ import torch.nn.functional as F @@ -680,7 +903,7 @@ index a1885fa..14d6923 100644 intermediate_cache3.view(*intermediate_cache3.shape), out_hidden_states[begin_chunk_idx:end_chunk_idx], diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py -index 8394635..7948779 100644 +index 839463518..7948779aa 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -647,7 +647,7 @@ class FusedMoE(torch.nn.Module): @@ -693,7 +916,7 @@ index 8394635..7948779 100644 ) diff --git a/python/sglang/srt/layers/moe/routed_experts_capturer.py b/python/sglang/srt/layers/moe/routed_experts_capturer.py -index 00bd687..5a3ca8a 100644 +index 00bd68755..5a3ca8a67 100644 --- a/python/sglang/srt/layers/moe/routed_experts_capturer.py +++ b/python/sglang/srt/layers/moe/routed_experts_capturer.py @@ -1,5 +1,6 @@ @@ -764,7 +987,7 @@ index 00bd687..5a3ca8a 100644 def get_routed_experts( diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py -index b4bdc41..3b895ff 100644 +index b4bdc41b3..3b895ff6a 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py @@ -442,7 +442,7 @@ class CompressedTensorsConfig(QuantizationConfig): @@ -777,7 +1000,7 @@ index b4bdc41..3b895ff 100644 def _get_scheme_from_parts( self, weight_quant: BaseModel, input_quant: BaseModel diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py -index c5e5a11..c46526e 100644 +index c5e5a11fc..c46526ecc 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -30,7 +30,10 @@ from sglang.srt.layers.quantization.fp8_utils import ( @@ -965,7 +1188,7 @@ index c5e5a11..c46526e 100644 is_k_full=self.is_k_full, routed_scaling_factor=self.moe_runner_config.routed_scaling_factor, diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py -index 480579e..dd8ca7d 100644 +index 480579e01..dd8ca7d4f 100644 --- a/python/sglang/srt/layers/rotary_embedding.py +++ b/python/sglang/srt/layers/rotary_embedding.py @@ -136,9 +136,7 @@ class RotaryEmbedding(MultiPlatformOp): @@ -990,7 +1213,7 @@ index 480579e..dd8ca7d 100644 assert ( fused_set_kv_buffer_arg is None diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py -index 55bef56..35ad68b 100644 +index 55bef5652..35ad68b1c 100644 --- a/python/sglang/srt/layers/sampler.py +++ b/python/sglang/srt/layers/sampler.py @@ -108,16 +108,11 @@ class Sampler(nn.Module): @@ -1014,7 +1237,7 @@ index 55bef56..35ad68b 100644 if not get_global_server_args().sampling_backend == "ascend" or ( return_logprob and not SGLANG_RETURN_ORIGINAL_LOGPROB diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py -index 2ecd854..89ef820 100644 +index 2ecd8542f..2a2e011ea 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -1292,6 +1292,19 @@ class UpdateWeightsFromIPCReqOutput(BaseReq): @@ -1037,8 +1260,19 @@ index 2ecd854..89ef820 100644 @dataclass class InitWeightsSendGroupForRemoteInstanceReqOutput(BaseReq): +@@ -1667,6 +1680,10 @@ class GetLoadReqOutput(BaseReq): + num_waiting_reqs: int + num_tokens: int + ts_tic: float ++ # Per-queue breakdown: list of {name, num_reqs, num_tokens, reqs: [{rid, seqlen, input_len, output_len}]} ++ queue_details: Optional[List[Dict[str, Any]]] = None ++ # Running batch info ++ running_details: Optional[Dict[str, Any]] = None + + + @dataclass diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py -index d423e61..d1f54a8 100644 +index d423e61d7..d1f54a832 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -1779,7 +1779,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): @@ -1054,7 +1288,7 @@ index d423e61..d1f54a8 100644 break diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py -index 92d2868..43bfab6 100644 +index 92d286897..43bfab691 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -98,6 +98,7 @@ from sglang.srt.managers.io_struct import ( @@ -1073,8 +1307,61 @@ index 92d2868..43bfab6 100644 (GetWeightsByNameReqInput, self.get_weights_by_name), (ReleaseMemoryOccupationReqInput, self.release_memory_occupation), (ResumeMemoryOccupationReqInput, self.resume_memory_occupation), +diff --git a/python/sglang/srt/managers/scheduler_metrics_mixin.py b/python/sglang/srt/managers/scheduler_metrics_mixin.py +index d44ff6027..3fad54598 100644 +--- a/python/sglang/srt/managers/scheduler_metrics_mixin.py ++++ b/python/sglang/srt/managers/scheduler_metrics_mixin.py +@@ -553,12 +553,48 @@ class SchedulerMetricsMixin: + num_tokens += sum(req.seqlen for queue in waiting_queues for req in queue) + num_waiting_reqs = sum(len(queue) for queue in waiting_queues) + ++ # Collect per-queue details ++ queue_names = ["waiting_queue"] ++ if self.disaggregation_mode == DisaggregationMode.PREFILL: ++ queue_names.append("bootstrap_queue") ++ elif self.disaggregation_mode == DisaggregationMode.DECODE: ++ queue_names.append("prealloc_queue") ++ queue_names.append("transfer_queue") ++ queue_names.append("retracted_queue") ++ ++ queue_details = [] ++ for name, queue in zip(queue_names, waiting_queues): ++ reqs_info = [] ++ for req in queue: ++ reqs_info.append({ ++ "seqlen": req.seqlen, ++ }) ++ queue_details.append({ ++ "name": name, ++ "num_reqs": len(queue), ++ "num_tokens": sum(r["seqlen"] for r in reqs_info), ++ "reqs": reqs_info, ++ }) ++ ++ # Collect running batch details ++ running_reqs_info = [] ++ for req in self.running_batch.reqs: ++ running_reqs_info.append({ ++ "seqlen": req.seqlen, ++ }) ++ running_details = { ++ "num_reqs": len(self.running_batch.reqs), ++ "reqs": running_reqs_info, ++ } ++ + return GetLoadReqOutput( + dp_rank=self.dp_rank, + num_reqs=len(self.running_batch.reqs) + num_waiting_reqs, + num_waiting_reqs=num_waiting_reqs, + num_tokens=num_tokens, + ts_tic=time.perf_counter(), ++ queue_details=queue_details, ++ running_details=running_details, + ) + + @contextmanager diff --git a/python/sglang/srt/managers/scheduler_output_processor_mixin.py b/python/sglang/srt/managers/scheduler_output_processor_mixin.py -index e40586c..243e2b0 100644 +index e40586c24..243e2b0c2 100644 --- a/python/sglang/srt/managers/scheduler_output_processor_mixin.py +++ b/python/sglang/srt/managers/scheduler_output_processor_mixin.py @@ -10,6 +10,7 @@ from sglang.srt.disaggregation.utils import DisaggregationMode @@ -1095,7 +1382,7 @@ index e40586c..243e2b0 100644 return diff --git a/python/sglang/srt/managers/scheduler_update_weights_mixin.py b/python/sglang/srt/managers/scheduler_update_weights_mixin.py -index 293a843..c3a618b 100644 +index 293a84350..8ee36c794 100644 --- a/python/sglang/srt/managers/scheduler_update_weights_mixin.py +++ b/python/sglang/srt/managers/scheduler_update_weights_mixin.py @@ -1,6 +1,7 @@ @@ -1137,11 +1424,13 @@ index 293a843..c3a618b 100644 def get_weights_by_name(self: Scheduler, recv_req: GetWeightsByNameReqInput): parameter = self.tp_worker.get_weights_by_name(recv_req) return GetWeightsByNameReqOutput(parameter) -@@ -137,6 +148,13 @@ class SchedulerUpdateWeightsMixin: +@@ -137,6 +148,15 @@ class SchedulerUpdateWeightsMixin: self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_KV_CACHE) self.flush_cache() + if self.disaggregation_mode == DisaggregationMode.DECODE: ++ if hasattr(self, "disagg_decode_transfer_queue"): ++ self.disagg_decode_transfer_queue.release_memory_occupation() + if hasattr(self, "disagg_decode_prealloc_queue"): + self.disagg_decode_prealloc_queue.release_memory_occupation() + elif self.disaggregation_mode == DisaggregationMode.PREFILL: @@ -1151,11 +1440,13 @@ index 293a843..c3a618b 100644 if GPU_MEMORY_TYPE_WEIGHTS in tags: self.stashed_model_static_state = _export_static_state( self.tp_worker.model_runner.model -@@ -177,6 +195,13 @@ class SchedulerUpdateWeightsMixin: +@@ -177,6 +197,15 @@ class SchedulerUpdateWeightsMixin: if GPU_MEMORY_TYPE_KV_CACHE in tags: self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_KV_CACHE) + if self.disaggregation_mode == DisaggregationMode.DECODE: ++ if hasattr(self, "disagg_decode_transfer_queue"): ++ self.disagg_decode_transfer_queue.resume_memory_occupation() + if hasattr(self, "disagg_decode_prealloc_queue"): + self.disagg_decode_prealloc_queue.resume_memory_occupation() + elif self.disaggregation_mode == DisaggregationMode.PREFILL: @@ -1166,7 +1457,7 @@ index 293a843..c3a618b 100644 def check_weights(self: Scheduler, recv_req: CheckWeightsReqInput): diff --git a/python/sglang/srt/managers/tokenizer_communicator_mixin.py b/python/sglang/srt/managers/tokenizer_communicator_mixin.py -index e5d42be..412293b 100644 +index e5d42bed8..412293b30 100644 --- a/python/sglang/srt/managers/tokenizer_communicator_mixin.py +++ b/python/sglang/srt/managers/tokenizer_communicator_mixin.py @@ -49,6 +49,8 @@ from sglang.srt.managers.io_struct import ( @@ -1218,7 +1509,7 @@ index e5d42be..412293b 100644 self, obj: InitWeightsSendGroupForRemoteInstanceReqInput, diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py -index 49f63a1..e4cd0ff 100644 +index 49f63a198..e4cd0ff2b 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -27,6 +27,7 @@ from sglang.srt.managers.io_struct import ( @@ -1242,7 +1533,7 @@ index 49f63a1..e4cd0ff 100644 parameter = self.model_runner.get_weights_by_name( recv_req.name, recv_req.truncate_size diff --git a/python/sglang/srt/mem_cache/allocator.py b/python/sglang/srt/mem_cache/allocator.py -index eaf2962..bf74cbd 100644 +index eaf29628b..bf74cbd12 100644 --- a/python/sglang/srt/mem_cache/allocator.py +++ b/python/sglang/srt/mem_cache/allocator.py @@ -287,6 +287,85 @@ def alloc_decode_kernel( @@ -1389,7 +1680,7 @@ index eaf2962..bf74cbd 100644 if self.debug_mode: assert len(torch.unique(out_indices)) == len(out_indices) diff --git a/python/sglang/srt/mem_cache/hiradix_cache.py b/python/sglang/srt/mem_cache/hiradix_cache.py -index f6cfca8..5d3cad0 100644 +index f6cfca8b6..5d3cad059 100644 --- a/python/sglang/srt/mem_cache/hiradix_cache.py +++ b/python/sglang/srt/mem_cache/hiradix_cache.py @@ -11,10 +11,15 @@ import torch @@ -1436,7 +1727,7 @@ index f6cfca8..5d3cad0 100644 self.tp_group = params.tp_cache_group self.tp_world_size = torch.distributed.get_world_size(group=self.tp_group) diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py -index 65d562a..8517837 100644 +index 65d562a27..fe5547d7b 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -1678,7 +1678,8 @@ class NSATokenToKVPool(MLATokenToKVPool): @@ -1461,8 +1752,59 @@ index 65d562a..8517837 100644 self._finalize_allocation_log(size) def get_index_k_with_scale_buffer(self, layer_id: int) -> torch.Tensor: +@@ -1775,6 +1781,50 @@ class NSATokenToKVPool(MLATokenToKVPool): + ] + return data_ptrs, data_lens, item_lens + ++ def get_cpu_copy(self, indices): ++ # First, save the kv_buffer (inherited from MLATokenToKVPool) ++ kv_cache_cpu = super().get_cpu_copy(indices) ++ ++ # Additionally, save the index_k_with_scale_buffer (page-indexed) ++ page_indices = indices[:: self.page_size] // self.page_size ++ torch.cuda.synchronize() ++ index_k_cpu = [] ++ chunk_size = self.cpu_offloading_chunk_size ++ # Convert chunk_size from token-level to page-level ++ page_chunk_size = max(1, chunk_size // self.page_size) ++ for layer_id in range(self.layer_num): ++ index_k_cpu.append([]) ++ for i in range(0, len(page_indices), page_chunk_size): ++ chunk_page_indices = page_indices[i : i + page_chunk_size] ++ idx_cpu = self.index_k_with_scale_buffer[layer_id][ ++ chunk_page_indices ++ ].to("cpu", non_blocking=True) ++ index_k_cpu[-1].append(idx_cpu) ++ torch.cuda.synchronize() ++ ++ return {"kv": kv_cache_cpu, "index_k": index_k_cpu} ++ ++ def load_cpu_copy(self, kv_cache_cpu_dict, indices): ++ # Restore the kv_buffer (inherited from MLATokenToKVPool) ++ super().load_cpu_copy(kv_cache_cpu_dict["kv"], indices) ++ ++ # Restore the index_k_with_scale_buffer (page-indexed) ++ page_indices = indices[:: self.page_size] // self.page_size ++ index_k_cpu = kv_cache_cpu_dict["index_k"] ++ torch.cuda.synchronize() ++ chunk_size = self.cpu_offloading_chunk_size ++ page_chunk_size = max(1, chunk_size // self.page_size) ++ for layer_id in range(self.layer_num): ++ for i in range(0, len(page_indices), page_chunk_size): ++ chunk_page_indices = page_indices[i : i + page_chunk_size] ++ idx_cpu = index_k_cpu[layer_id][i // page_chunk_size] ++ assert idx_cpu.shape[0] == len(chunk_page_indices) ++ idx_chunk = idx_cpu.to( ++ self.index_k_with_scale_buffer[0].device, non_blocking=True ++ ) ++ self.index_k_with_scale_buffer[layer_id][chunk_page_indices] = idx_chunk ++ torch.cuda.synchronize() ++ + def get_kv_size_bytes(self): + kv_size_bytes = super().get_kv_size_bytes() + for index_k_cache in self.index_k_with_scale_buffer: diff --git a/python/sglang/srt/mem_cache/memory_pool_host.py b/python/sglang/srt/mem_cache/memory_pool_host.py -index 4639415..d99cc4b 100644 +index 46394158f..d99cc4b3b 100644 --- a/python/sglang/srt/mem_cache/memory_pool_host.py +++ b/python/sglang/srt/mem_cache/memory_pool_host.py @@ -15,7 +15,12 @@ from sglang.jit_kernel.hicache import ( @@ -1680,7 +2022,7 @@ index 4639415..d99cc4b 100644 + device_pool, host_indices, device_indices, io_backend + ) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py -index 1d69c05..d984c2e 100644 +index 1d69c0582..d984c2e12 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -558,7 +558,8 @@ class ModelRunner(ModelRunnerKVCacheMixin): @@ -1761,7 +2103,7 @@ index 1d69c05..d984c2e 100644 def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tensor]]): params_dict = dict(model.named_parameters()) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py -index ed8cc7a..b8f1026 100644 +index ed8cc7ada..b8f1026dd 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -159,6 +159,7 @@ from sglang.srt.utils import ( @@ -1795,7 +2137,7 @@ index ed8cc7a..b8f1026 100644 if ( self.current_attention_backend == "fa3" diff --git a/python/sglang/srt/models/qwen2.py b/python/sglang/srt/models/qwen2.py -index a7dbade..c83a413 100644 +index a7dbadec6..c83a41338 100644 --- a/python/sglang/srt/models/qwen2.py +++ b/python/sglang/srt/models/qwen2.py @@ -90,9 +90,6 @@ class Qwen2MLP(nn.Module): @@ -1833,7 +2175,7 @@ index a7dbade..c83a413 100644 if get_global_server_args().rl_on_policy_target is not None else {} diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py -index 3ad9f67..0b9c7f4 100644 +index 3ad9f6736..0b9c7f499 100644 --- a/python/sglang/srt/models/qwen2_moe.py +++ b/python/sglang/srt/models/qwen2_moe.py @@ -586,7 +586,17 @@ class Qwen2MoeModel(nn.Module): @@ -1856,7 +2198,7 @@ index 3ad9f67..0b9c7f4 100644 self.norm = PPMissingLayer(return_tuple=True) diff --git a/python/sglang/srt/models/qwen3.py b/python/sglang/srt/models/qwen3.py -index 9220831..2b8303b 100644 +index 9220831f6..2b8303b54 100644 --- a/python/sglang/srt/models/qwen3.py +++ b/python/sglang/srt/models/qwen3.py @@ -90,8 +90,8 @@ class Qwen3Attention(nn.Module): @@ -1899,7 +2241,7 @@ index 9220831..2b8303b 100644 if hidden_states.shape[0] != 0: hidden_states = self.self_attn( diff --git a/python/sglang/srt/models/qwen3_moe.py b/python/sglang/srt/models/qwen3_moe.py -index e11678a..e277d46 100644 +index e11678a9e..e277d46f2 100644 --- a/python/sglang/srt/models/qwen3_moe.py +++ b/python/sglang/srt/models/qwen3_moe.py @@ -22,6 +22,7 @@ import math @@ -2009,7 +2351,7 @@ index e11678a..e277d46 100644 self.layer_communicator = LayerCommunicator( diff --git a/python/sglang/srt/models/qwen3_vl.py b/python/sglang/srt/models/qwen3_vl.py -index 079f458..218e323 100644 +index 079f45843..218e32362 100644 --- a/python/sglang/srt/models/qwen3_vl.py +++ b/python/sglang/srt/models/qwen3_vl.py @@ -397,28 +397,68 @@ class Qwen3VLMoeVisionModel(nn.Module, RotaryPosMixin): @@ -2125,7 +2467,7 @@ index 079f458..218e323 100644 positions, hidden_states, diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py -index a2b26e0..72db298 100644 +index a2b26e0e0..72db29801 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -527,6 +527,7 @@ class ServerArgs: @@ -2149,7 +2491,7 @@ index a2b26e0..72db298 100644 "--disable-cuda-graph-padding", action="store_true", diff --git a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py -index 5fe4508..c95fbd0 100644 +index 5fe45086c..c95fbd0f6 100644 --- a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +++ b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py @@ -341,7 +341,10 @@ class EAGLEDraftCudaGraphRunner: @@ -2176,7 +2518,7 @@ index 5fe4508..c95fbd0 100644 self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices) diff --git a/python/sglang/srt/speculative/eagle_info.py b/python/sglang/srt/speculative/eagle_info.py -index 1bf3816..b5b41db 100644 +index 1bf3816e9..b5b41dba4 100644 --- a/python/sglang/srt/speculative/eagle_info.py +++ b/python/sglang/srt/speculative/eagle_info.py @@ -778,6 +778,10 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin): @@ -2219,7 +2561,7 @@ index 1bf3816..b5b41db 100644 @dataclass diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py -index a702df4..61d9ae3 100644 +index a702df4f8..61d9ae366 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -231,7 +231,7 @@ class EAGLEWorker(TpModelWorker): @@ -2232,7 +2574,7 @@ index a702df4..61d9ae3 100644 Device2DraftCudaGraphRunner = { diff --git a/python/sglang/srt/utils/common.py b/python/sglang/srt/utils/common.py -index 8560246..13db860 100644 +index 8560246c6..13db860dc 100644 --- a/python/sglang/srt/utils/common.py +++ b/python/sglang/srt/utils/common.py @@ -2224,6 +2224,8 @@ class SafeUnpickler(pickle.Unpickler):