Skip to content

Commit b7b0760

Browse files
committed
fix(cp): send decode tokens only to last rank to avoid empty chunks
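During the decode phase (typically 1 token), splitting the sequence across ranks leaves some ranks with 0 tokens, which causes reshape errors. Whenever the token count does not exceed the number of ranks, the tokens are now sent directly to the last rank only.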
1 parent e0f10b1 commit b7b0760

File tree

2 files changed

+83
-1
lines changed

2 files changed

+83
-1
lines changed

dnet-tui

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Subproject commit bcb47a606115c0f069de79726ee5d771eac0e40f

src/dnet/api/strategies/context_parallel.py

Lines changed: 82 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -389,14 +389,35 @@ async def _send_tokens_multi_rank(
389389
self.num_ranks,
390390
)
391391

392+
# For decode (num_tokens <= num_ranks, typically a single token), send only to last rank
393+
# This avoids empty chunks when splitting 1 token across multiple ranks
394+
if num_tokens <= self.num_ranks:
395+
# Decode mode: only last rank gets the token
396+
last_rank = self.num_ranks - 1
397+
await self._send_chunk_to_rank(
398+
last_rank,
399+
nonce,
400+
tokens,
401+
callback_addr,
402+
logprobs,
403+
top_logprobs,
404+
decoding_config,
405+
num_tokens,
406+
)
407+
return
408+
392409
async def send_to_rank(rank: int) -> None:
393-
# Shard the sequence for this rank
410+
# Shard the sequence for this rank (prefill mode)
394411
chunk_mx, indices = shard_for_mode(
395412
mx.array(full_tokens), self.num_ranks, rank, "prefill"
396413
)
397414
chunk = np.array(chunk_mx, dtype=np.int32)
398415
chunk_bytes = chunk.tobytes()
399416

417+
if len(chunk) == 0:
418+
logger.debug("CP rank %d: skipping empty chunk", rank)
419+
return
420+
400421
logger.debug(
401422
"CP rank %d: sending %d tokens (indices %d-%d)",
402423
rank,
@@ -450,6 +471,66 @@ async def send_to_rank(rank: int) -> None:
450471
# Send to all ranks in parallel
451472
await asyncio.gather(*[send_to_rank(r) for r in range(self.num_ranks)])
452473

474+
async def _send_chunk_to_rank(
    self,
    rank: int,
    nonce: str,
    tokens: bytes,
    callback_addr: str,
    logprobs: bool,
    top_logprobs: int,
    decoding_config: Optional[Any],
    num_tokens: int,
) -> None:
    """Queue an unsharded token payload on one rank's activation stream.

    Unlike the sharded prefill path, the payload is forwarded as-is to a
    single rank. The caller visible in this file uses it for the decode
    phase and passes the last rank.

    Args:
        rank: Index into ``self.rank_stubs`` / ``self._streams_by_rank``.
        nonce: Request identifier; also the key used to look up the stream.
        tokens: Serialized token payload handed to ``to_proto``.
        callback_addr: Address embedded in the ``grpc://`` callback URL.
        logprobs: Forwarded as ``req_logprobs``.
        top_logprobs: Forwarded as ``req_top_logprobs``.
        decoding_config: Source of sampling parameters; when falsy the
            defaults listed in the body are used instead.
        num_tokens: Token count; becomes the message shape ``(num_tokens,)``.

    Raises:
        RuntimeError: If no open stream could be obtained for the rank.
    """
    logger.debug(
        "CP decode: sending %d tokens directly to rank %d (last rank)",
        num_tokens,
        rank,
    )

    # Resolve sampling parameters once: either everything comes from the
    # supplied config, or everything falls back to the defaults.
    if decoding_config:
        sampling = {
            "temperature": decoding_config.temperature,
            "top_p": decoding_config.top_p,
            "top_k": decoding_config.top_k,
            "repetition_penalty": decoding_config.repetition_penalty,
            "min_p": decoding_config.min_p,
            "min_tokens_to_keep": decoding_config.min_tokens_to_keep,
        }
    else:
        sampling = {
            "temperature": 1.0,
            "top_p": 1.0,
            "top_k": -1,
            "repetition_penalty": 1.0,
            "min_p": 0.0,
            "min_tokens_to_keep": 1,
        }

    message = ActivationMessage(
        nonce=nonce,
        pool_id=-1,
        batch_size=1,
        shape=(num_tokens,),
        dtype="tokens",
        layer_id=-1,
        timestamp=utc_epoch_now(),
        node_origin="api",
        callback_url=f"grpc://{callback_addr}",
        req_logprobs=logprobs,
        req_top_logprobs=top_logprobs,
        **sampling,
    )
    request = message.to_proto(tokens)

    rank_stub = self.rank_stubs[rank]
    rank_streams = self._streams_by_rank[rank]

    def open_stream(frame_iter):
        # Invoked by the stream manager only when a new stream is needed.
        return rank_stub.StreamActivations(frame_iter)

    stream_ctx = await rank_streams.get_or_create_stream(nonce, open_stream)
    if not (stream_ctx and stream_ctx.open):
        raise RuntimeError(
            f"Failed to create stream for rank {rank}, nonce {nonce}"
        )

    # Advance the per-stream sequence counter before building the frame so
    # the frame carries the new value.
    next_seq = stream_ctx.last_seq + 1
    stream_ctx.last_seq = next_seq
    frame = pb2.ActivationFrame(
        request=request, seq=next_seq, end_of_request=False
    )
    await stream_ctx.queue.put(frame)
    stream_ctx.last_activity_t = asyncio.get_running_loop().time()
533+
453534
async def await_token(self, nonce: str, timeout_s: float) -> TokenResult:
454535
fut = asyncio.get_running_loop().create_future()
455536
self._pending[nonce] = fut

0 commit comments

Comments
 (0)