@@ -389,14 +389,35 @@ async def _send_tokens_multi_rank(
389389 self .num_ranks ,
390390 )
391391
392+ # For decode (single token), send only to last rank
393+ # This avoids empty chunks when splitting 1 token across multiple ranks
394+ if num_tokens <= self .num_ranks :
395+ # Decode mode: only last rank gets the token
396+ last_rank = self .num_ranks - 1
397+ await self ._send_chunk_to_rank (
398+ last_rank ,
399+ nonce ,
400+ tokens ,
401+ callback_addr ,
402+ logprobs ,
403+ top_logprobs ,
404+ decoding_config ,
405+ num_tokens ,
406+ )
407+ return
408+
392409 async def send_to_rank (rank : int ) -> None :
393- # Shard the sequence for this rank
410+ # Shard the sequence for this rank (prefill mode)
394411 chunk_mx , indices = shard_for_mode (
395412 mx .array (full_tokens ), self .num_ranks , rank , "prefill"
396413 )
397414 chunk = np .array (chunk_mx , dtype = np .int32 )
398415 chunk_bytes = chunk .tobytes ()
399416
417+ if len (chunk ) == 0 :
418+ logger .debug ("CP rank %d: skipping empty chunk" , rank )
419+ return
420+
400421 logger .debug (
401422 "CP rank %d: sending %d tokens (indices %d-%d)" ,
402423 rank ,
@@ -450,6 +471,66 @@ async def send_to_rank(rank: int) -> None:
450471 # Send to all ranks in parallel
451472 await asyncio .gather (* [send_to_rank (r ) for r in range (self .num_ranks )])
452473
474+ async def _send_chunk_to_rank (
475+ self ,
476+ rank : int ,
477+ nonce : str ,
478+ tokens : bytes ,
479+ callback_addr : str ,
480+ logprobs : bool ,
481+ top_logprobs : int ,
482+ decoding_config : Optional [Any ],
483+ num_tokens : int ,
484+ ) -> None :
485+ """Send tokens directly to a specific rank (for decode phase)."""
486+ logger .debug (
487+ "CP decode: sending %d tokens directly to rank %d (last rank)" ,
488+ num_tokens ,
489+ rank ,
490+ )
491+
492+ msg = ActivationMessage (
493+ nonce = nonce ,
494+ pool_id = - 1 ,
495+ batch_size = 1 ,
496+ shape = (num_tokens ,),
497+ dtype = "tokens" ,
498+ layer_id = - 1 ,
499+ timestamp = utc_epoch_now (),
500+ node_origin = "api" ,
501+ callback_url = f"grpc://{ callback_addr } " ,
502+ req_logprobs = logprobs ,
503+ req_top_logprobs = top_logprobs ,
504+ temperature = decoding_config .temperature if decoding_config else 1.0 ,
505+ top_p = decoding_config .top_p if decoding_config else 1.0 ,
506+ top_k = decoding_config .top_k if decoding_config else - 1 ,
507+ repetition_penalty = (
508+ decoding_config .repetition_penalty if decoding_config else 1.0
509+ ),
510+ min_p = decoding_config .min_p if decoding_config else 0.0 ,
511+ min_tokens_to_keep = (
512+ decoding_config .min_tokens_to_keep if decoding_config else 1
513+ ),
514+ )
515+ req = msg .to_proto (tokens )
516+
517+ stub = self .rank_stubs [rank ]
518+ streams = self ._streams_by_rank [rank ]
519+ ctx = await streams .get_or_create_stream (
520+ nonce ,
521+ lambda it : stub .StreamActivations (it ),
522+ )
523+ if not ctx or not ctx .open :
524+ raise RuntimeError (
525+ f"Failed to create stream for rank { rank } , nonce { nonce } "
526+ )
527+
528+ ctx .last_seq += 1
529+ await ctx .queue .put (
530+ pb2 .ActivationFrame (request = req , seq = ctx .last_seq , end_of_request = False )
531+ )
532+ ctx .last_activity_t = asyncio .get_running_loop ().time ()
533+
453534 async def await_token (self , nonce : str , timeout_s : float ) -> TokenResult :
454535 fut = asyncio .get_running_loop ().create_future ()
455536 self ._pending [nonce ] = fut
0 commit comments