25 changes: 20 additions & 5 deletions gemma/gm/text/_chat_sampler.py
@@ -219,16 +219,31 @@ def chat(
     object.__setattr__(self, 'last_state', out.state)
     return model_output
 
 
 def _print_stream(
     out: Iterator[_sampler.SamplerOutput],
+    *,
+    fallback_state: _sampler_loop.SamplingState | None = None,
 ) -> _sampler.SamplerOutput:
-  """Prints the streaming output."""
-  text_tokens = []
+  """Prints the streaming output and returns the final SamplerOutput.
+
+  If the iterator yields no elements (e.g. `max_new_tokens=0` or an early stop
+  before the first token), this function returns a dummy `SamplerOutput` with an
+  empty `text` field and the provided `fallback_state`, instead of raising an
+  internal error.
+  """
+  text_tokens: list[str] = []
+  last_state: _sampler.SamplerOutput | None = None
+
   for state in out:
+    last_state = state
     text_tokens.append(state.text)
     if state.text == '<end_of_turn>':  # Last token is not printed.
       continue
     print(state.text, end='', flush=True)
-  out = dataclasses.replace(state, text=''.join(text_tokens))  # pylint: disable=undefined-variable,undefined-loop-variable
-  return out
+
+  if last_state is None:
+    # No tokens were streamed; return a dummy SamplerOutput rather than
+    # propagating an internal UnboundLocalError.
+    return _sampler.SamplerOutput(text='', state=fallback_state)
+
+  return dataclasses.replace(last_state, text=''.join(text_tokens))
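
Why this matters: before the change, the post-loop `dataclasses.replace(state, ...)` referenced the loop variable `state`, so an iterator that yielded nothing raised `UnboundLocalError` (the reason for the pylint `undefined-variable,undefined-loop-variable` disables). Below is a minimal, self-contained sketch of the patched control flow. `StreamState` and `print_stream` are hypothetical stand-ins for gemma's internal `_sampler.SamplerOutput` and `_print_stream`, not the library's API; the sketch only mirrors the logic introduced by this diff.

    import dataclasses
    from typing import Any, Iterator, Optional


    @dataclasses.dataclass
    class StreamState:
      """Hypothetical stand-in for gemma's internal SamplerOutput."""
      text: str
      state: Any = None


    def print_stream(
        out: Iterator[StreamState],
        *,
        fallback_state: Any = None,
    ) -> StreamState:
      """Mirrors the patched flow: tolerates an iterator that yields nothing."""
      text_tokens: list[str] = []
      last_state: Optional[StreamState] = None
      for state in out:
        last_state = state
        text_tokens.append(state.text)
        if state.text == '<end_of_turn>':  # Last token is not printed.
          continue
        print(state.text, end='', flush=True)
      if last_state is None:
        # Empty stream (e.g. max_new_tokens=0): return a dummy output instead
        # of referencing the loop variable `state`, which would be unbound.
        return StreamState(text='', state=fallback_state)
      return dataclasses.replace(last_state, text=''.join(text_tokens))


    # The empty-iterator case that previously crashed now degrades gracefully:
    result = print_stream(iter(()))
    assert result.text == ''

Under these assumptions, the non-empty path is unchanged (tokens are printed as they stream and the final output carries the joined text), while the empty path returns a dummy output carrying `fallback_state`.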