diff --git a/src/l0/_structured.py b/src/l0/_structured.py
index 0dbe4f1..0d11bfc 100644
--- a/src/l0/_structured.py
+++ b/src/l0/_structured.py
@@ -265,19 +265,15 @@ async def buffering_iterator() -> RawStream:
     for attempt in range(max_attempts):
         try:
             # _internal_run expects a callable factory
-            # Handle both direct async iterators and factory functions
-            def make_stream_factory(
-                src: AwaitableStreamSource,
-            ) -> AwaitableStreamFactory:
-                if callable(src) and not hasattr(src, "__anext__"):
-                    # It's already a factory
-                    return src
-                else:
-                    # It's a direct async iterator - wrap in factory
-                    # Note: This only works once per stream!
-                    return lambda: cast(RawStream, src)
-
-            stream_factory = make_stream_factory(stream_source)
+            # For factory functions, pass them directly so _internal_run can call fresh on retries
+            # For direct async iterators (already wrapped in buffering factory above),
+            # wrap in a lambda - the buffering factory handles replay
+            if callable(stream_source) and not hasattr(stream_source, "__anext__"):
+                # It's a factory - pass it directly to _internal_run
+                stream_factory = cast(AwaitableStreamFactory, stream_source)
+            else:
+                # It's a direct async iterator (wrapped in buffering factory)
+                stream_factory = lambda src=stream_source: cast(RawStream, src)
 
             # Run through L0 runtime
             result = await _internal_run(
@@ -890,17 +886,18 @@ async def buffering_iterator() -> RawStream:
     for stream_source in all_streams:
         for attempt in range(max_attempts):
             try:
+                # _internal_run expects a callable factory
+                # For factory functions, pass them directly so _internal_run can call fresh on retries
+                # For direct async iterators (already wrapped in buffering factory above),
+                # wrap in a lambda - the buffering factory handles replay
+                if callable(stream_source) and not hasattr(stream_source, "__anext__"):
+                    # It's a factory - pass it directly to _internal_run
+                    stream_factory = cast(AwaitableStreamFactory, stream_source)
+                else:
+                    # It's a direct async iterator (wrapped in buffering factory)
+                    stream_factory = lambda src=stream_source: cast(RawStream, src)
-                def make_stream_factory(
-                    src: AwaitableStreamSource,
-                ) -> AwaitableStreamFactory:
-                    if callable(src) and not hasattr(src, "__anext__"):
-                        return src
-                    else:
-                        return lambda: cast(RawStream, src)
-
-                stream_factory = make_stream_factory(stream_source)
-
+
                 # Run through L0 runtime
                 result = await _internal_run(
                     stream=stream_factory,
                     on_event=on_event,
diff --git a/tests/test_structured.py b/tests/test_structured.py
index 266c43e..41d02d8 100644
--- a/tests/test_structured.py
+++ b/tests/test_structured.py
@@ -681,6 +681,195 @@ async def json_stream():
         assert "list[SimpleModel]" in result.telemetry.schema_name
 
 
+class TestStreamFactoryRetryBehavior:
+    """Test that factory functions are called fresh on each retry attempt.
+
+    This tests the fix for the "stream already consumed" bug where factory
+    functions were not being called on each retry, causing streams to be
+    reused and fail with "ReadableStream is locked" or similar errors.
+    """
+
+    @pytest.mark.asyncio
+    async def test_factory_called_fresh_on_each_retry_structured(self):
+        """Test that factory is called fresh on each retry in structured()."""
+        factory_call_count = 0
+
+        def stream_factory():
+            nonlocal factory_call_count
+            factory_call_count += 1
+
+            async def gen():
+                if factory_call_count == 1:
+                    # First attempt: invalid JSON that will fail validation
+                    yield Event(type=EventType.TOKEN, text='{"wrong": "field"}')
+                else:
+                    # Subsequent attempts: valid JSON
+                    yield Event(type=EventType.TOKEN, text='{"value": "success"}')
+                yield Event(type=EventType.COMPLETE)
+
+            return gen()
+
+        result = await structured(
+            schema=SimpleModel,
+            stream=stream_factory,
+            retry=Retry(attempts=3),
+        )
+
+        assert result.data.value == "success"
+        assert factory_call_count == 2  # Called twice: first failed, second succeeded
+
+    @pytest.mark.asyncio
+    async def test_factory_called_fresh_on_each_retry_structured_array(self):
+        """Test that factory is called fresh on each retry in structured_array()."""
+        factory_call_count = 0
+
+        def stream_factory():
+            nonlocal factory_call_count
+            factory_call_count += 1
+
+            async def gen():
+                if factory_call_count == 1:
+                    # First attempt: invalid JSON that will fail validation
+                    yield Event(type=EventType.TOKEN, text='[{"wrong": "field"}]')
+                else:
+                    # Subsequent attempts: valid JSON
+                    yield Event(type=EventType.TOKEN, text='[{"value": "success"}]')
+                yield Event(type=EventType.COMPLETE)
+
+            return gen()
+
+        result = await structured_array(
+            item_schema=SimpleModel,
+            stream=stream_factory,
+            retry=Retry(attempts=3),
+        )
+
+        assert len(result.data) == 1
+        assert result.data[0].value == "success"
+        assert factory_call_count == 2  # Called twice: first failed, second succeeded
+
+    @pytest.mark.asyncio
+    async def test_factory_exhausts_all_retries_on_persistent_failure(self):
+        """Test that all retry attempts are used when validation keeps failing."""
+        factory_call_count = 0
+
+        def stream_factory():
+            nonlocal factory_call_count
+            factory_call_count += 1
+
+            async def gen():
+                # Always return invalid JSON
+                yield Event(type=EventType.TOKEN, text='{"wrong": "field"}')
+                yield Event(type=EventType.COMPLETE)
+
+            return gen()
+
+        with pytest.raises(ValueError, match="Schema validation failed"):
+            await structured(
+                schema=SimpleModel,
+                stream=stream_factory,
+                retry=Retry(attempts=3),
+            )
+
+        # Factory should have been called 3 times (once per retry attempt)
+        assert factory_call_count == 3
+
+    @pytest.mark.asyncio
+    async def test_factory_not_called_after_success(self):
+        """Test that factory is not called again after successful validation."""
+        factory_call_count = 0
+
+        def stream_factory():
+            nonlocal factory_call_count
+            factory_call_count += 1
+
+            async def gen():
+                # Always return valid JSON
+                yield Event(type=EventType.TOKEN, text='{"value": "test"}')
+                yield Event(type=EventType.COMPLETE)
+
+            return gen()
+
+        result = await structured(
+            schema=SimpleModel,
+            stream=stream_factory,
+            retry=Retry(attempts=5),
+        )
+
+        assert result.data.value == "test"
+        assert factory_call_count == 1  # Only called once since first attempt succeeded
+
+    @pytest.mark.asyncio
+    async def test_fallback_factory_called_fresh_on_retry(self):
+        """Test that fallback factory functions are also called fresh on retry."""
+        main_call_count = 0
+        fallback_call_count = 0
+
+        def main_factory():
+            nonlocal main_call_count
+            main_call_count += 1
+
+            async def gen():
+                # Main always fails
+                yield Event(type=EventType.TOKEN, text='{"wrong": "field"}')
+                yield Event(type=EventType.COMPLETE)
+
+            return gen()
+
+        def fallback_factory():
+            nonlocal fallback_call_count
+            fallback_call_count += 1
+
+            async def gen():
+                if fallback_call_count == 1:
+                    # First fallback attempt fails
+                    yield Event(type=EventType.TOKEN, text='{"also_wrong": "field"}')
+                else:
+                    # Second fallback attempt succeeds
+                    yield Event(type=EventType.TOKEN, text='{"value": "from_fallback"}')
+                yield Event(type=EventType.COMPLETE)
+
+            return gen()
+
+        result = await structured(
+            schema=SimpleModel,
+            stream=main_factory,
+            fallbacks=[fallback_factory],
+            retry=Retry(attempts=2),
+        )
+
+        assert result.data.value == "from_fallback"
+        assert main_call_count == 2  # Main tried twice
+        assert fallback_call_count == 2  # Fallback tried twice, second succeeded
+
+    @pytest.mark.asyncio
+    async def test_async_factory_called_fresh_on_retry(self):
+        """Test that async factory functions are called fresh on each retry."""
+        factory_call_count = 0
+
+        async def async_stream_factory():
+            nonlocal factory_call_count
+            factory_call_count += 1
+
+            async def gen():
+                if factory_call_count == 1:
+                    yield Event(type=EventType.TOKEN, text='{"wrong": "field"}')
+                else:
+                    yield Event(type=EventType.TOKEN, text='{"value": "async_success"}')
+                yield Event(type=EventType.COMPLETE)
+
+            return gen()
+
+        result = await structured(
+            schema=SimpleModel,
+            stream=async_stream_factory,
+            retry=Retry(attempts=3),
+        )
+
+        assert result.data.value == "async_success"
+        assert factory_call_count == 2
+
+
 class TestStructuredStrictMode:
     """Test strict_mode parameter for rejecting extra fields."""
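
Note for reviewers (not part of the patch): the short sketch below uses no l0 APIs and only illustrates the Python behavior the fix relies on. A direct async generator is exhausted after a single pass, so a retry loop that re-reads the same object gets nothing back; calling a factory on every attempt produces a fresh stream each time. The names gen and consume are illustrative only.

    import asyncio


    async def gen():
        # A direct async generator: it can be consumed exactly once.
        yield "token-1"
        yield "token-2"


    async def consume():
        stream = gen()
        first = [chunk async for chunk in stream]   # ['token-1', 'token-2']
        second = [chunk async for chunk in stream]  # [] - already exhausted, nothing left to retry against

        # A factory avoids the problem: every attempt gets a fresh iterator.
        attempt_1 = [chunk async for chunk in gen()]  # ['token-1', 'token-2']
        attempt_2 = [chunk async for chunk in gen()]  # ['token-1', 'token-2']
        print(first, second, attempt_1, attempt_2)


    asyncio.run(consume())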