diff --git a/gemma/gm/text/_chat_sampler.py b/gemma/gm/text/_chat_sampler.py index 302ce050..3fff0606 100644 --- a/gemma/gm/text/_chat_sampler.py +++ b/gemma/gm/text/_chat_sampler.py @@ -219,16 +219,31 @@ def chat( object.__setattr__(self, 'last_state', out.state) return model_output - def _print_stream( out: Iterator[_sampler.SamplerOutput], + *, + fallback_state: _sampler_loop.SamplingState | None = None, ) -> _sampler.SamplerOutput: - """Prints the streaming output.""" - text_tokens = [] + """Prints the streaming output and returns the final SamplerOutput. + + If the iterator yields no elements (e.g. `max_new_tokens=0` or an early stop + before the first token), this function returns a dummy `SamplerOutput` with an + empty `text` field and the provided `fallback_state`, instead of raising an + internal error. + """ + text_tokens: list[str] = [] + last_state: _sampler.SamplerOutput | None = None + for state in out: + last_state = state text_tokens.append(state.text) if state.text == '': # Last token is not printed. continue print(state.text, end='', flush=True) - out = dataclasses.replace(state, text=''.join(text_tokens)) # pylint: disable=undefined-variable,undefined-loop-variable - return out + + if last_state is None: + # No tokens were streamed; return a dummy SamplerOutput rather than + # propagating an internal UnboundLocalError. + return _sampler.SamplerOutput(text='', state=fallback_state) + + return dataclasses.replace(last_state, text=''.join(text_tokens))