Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 20 additions & 23 deletions src/l0/_structured.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,19 +265,15 @@ async def buffering_iterator() -> RawStream:
for attempt in range(max_attempts):
try:
# _internal_run expects a callable factory
# Handle both direct async iterators and factory functions
def make_stream_factory(
src: AwaitableStreamSource,
) -> AwaitableStreamFactory:
if callable(src) and not hasattr(src, "__anext__"):
# It's already a factory
return src
else:
# It's a direct async iterator - wrap in factory
# Note: This only works once per stream!
return lambda: cast(RawStream, src)

stream_factory = make_stream_factory(stream_source)
# For factory functions, pass them directly so _internal_run can call fresh on retries
# For direct async iterators (already wrapped in buffering factory above),
# wrap in a lambda - the buffering factory handles replay
if callable(stream_source) and not hasattr(stream_source, "__anext__"):
# It's a factory - pass it directly to _internal_run
stream_factory = cast(AwaitableStreamFactory, stream_source)
else:
# It's a direct async iterator (wrapped in buffering factory)
stream_factory = lambda src=stream_source: cast(RawStream, src)

# Run through L0 runtime
result = await _internal_run(
Expand Down Expand Up @@ -890,17 +886,18 @@ async def buffering_iterator() -> RawStream:
for stream_source in all_streams:
for attempt in range(max_attempts):
try:
# _internal_run expects a callable factory
# For factory functions, pass them directly so _internal_run can call fresh on retries
# For direct async iterators (already wrapped in buffering factory above),
# wrap in a lambda - the buffering factory handles replay
if callable(stream_source) and not hasattr(stream_source, "__anext__"):
# It's a factory - pass it directly to _internal_run
stream_factory = cast(AwaitableStreamFactory, stream_source)
else:
# It's a direct async iterator (wrapped in buffering factory)
stream_factory = lambda src=stream_source: cast(RawStream, src)

def make_stream_factory(
src: AwaitableStreamSource,
) -> AwaitableStreamFactory:
if callable(src) and not hasattr(src, "__anext__"):
return src
else:
return lambda: cast(RawStream, src)

stream_factory = make_stream_factory(stream_source)

# Run through L0 runtime
result = await _internal_run(
stream=stream_factory,
on_event=on_event,
Expand Down
189 changes: 189 additions & 0 deletions tests/test_structured.py
Original file line number Diff line number Diff line change
Expand Up @@ -681,6 +681,195 @@ async def json_stream():
assert "list[SimpleModel]" in result.telemetry.schema_name


class TestStreamFactoryRetryBehavior:
    """Verify stream factories are invoked afresh on every retry attempt.

    Regression coverage for the "stream already consumed" bug: when a
    factory was not re-called per attempt, a spent stream got reused and
    failed with errors such as "ReadableStream is locked".
    """

    @pytest.mark.asyncio
    async def test_factory_called_fresh_on_each_retry_structured(self):
        """structured() must call the factory once per retry attempt."""
        factory_call_count = 0

        def stream_factory():
            nonlocal factory_call_count
            factory_call_count += 1

            async def events():
                # Attempt 1 emits JSON that fails validation; later
                # attempts emit a payload matching the schema.
                payload = (
                    '{"wrong": "field"}'
                    if factory_call_count == 1
                    else '{"value": "success"}'
                )
                yield Event(type=EventType.TOKEN, text=payload)
                yield Event(type=EventType.COMPLETE)

            return events()

        result = await structured(
            schema=SimpleModel,
            stream=stream_factory,
            retry=Retry(attempts=3),
        )

        assert result.data.value == "success"
        # Two invocations: first attempt failed, second succeeded.
        assert factory_call_count == 2

    @pytest.mark.asyncio
    async def test_factory_called_fresh_on_each_retry_structured_array(self):
        """structured_array() must call the factory once per retry attempt."""
        factory_call_count = 0

        def stream_factory():
            nonlocal factory_call_count
            factory_call_count += 1

            async def events():
                # First attempt: array item fails validation; after that,
                # a valid item is produced.
                payload = (
                    '[{"wrong": "field"}]'
                    if factory_call_count == 1
                    else '[{"value": "success"}]'
                )
                yield Event(type=EventType.TOKEN, text=payload)
                yield Event(type=EventType.COMPLETE)

            return events()

        result = await structured_array(
            item_schema=SimpleModel,
            stream=stream_factory,
            retry=Retry(attempts=3),
        )

        assert len(result.data) == 1
        assert result.data[0].value == "success"
        # Two invocations: first attempt failed, second succeeded.
        assert factory_call_count == 2

    @pytest.mark.asyncio
    async def test_factory_exhausts_all_retries_on_persistent_failure(self):
        """Every retry attempt is consumed when validation never passes."""
        factory_call_count = 0

        def stream_factory():
            nonlocal factory_call_count
            factory_call_count += 1

            async def events():
                # Persistently invalid payload — validation can never succeed.
                yield Event(type=EventType.TOKEN, text='{"wrong": "field"}')
                yield Event(type=EventType.COMPLETE)

            return events()

        with pytest.raises(ValueError, match="Schema validation failed"):
            await structured(
                schema=SimpleModel,
                stream=stream_factory,
                retry=Retry(attempts=3),
            )

        # One factory invocation per retry attempt.
        assert factory_call_count == 3

    @pytest.mark.asyncio
    async def test_factory_not_called_after_success(self):
        """No further factory calls once validation has succeeded."""
        factory_call_count = 0

        def stream_factory():
            nonlocal factory_call_count
            factory_call_count += 1

            async def events():
                # Valid on the very first attempt.
                yield Event(type=EventType.TOKEN, text='{"value": "test"}')
                yield Event(type=EventType.COMPLETE)

            return events()

        result = await structured(
            schema=SimpleModel,
            stream=stream_factory,
            retry=Retry(attempts=5),
        )

        assert result.data.value == "test"
        # Single invocation — the first attempt already succeeded.
        assert factory_call_count == 1

    @pytest.mark.asyncio
    async def test_fallback_factory_called_fresh_on_retry(self):
        """Fallback factories are also re-invoked on each retry attempt."""
        main_call_count = 0
        fallback_call_count = 0

        def main_factory():
            nonlocal main_call_count
            main_call_count += 1

            async def events():
                # The primary source never produces valid output.
                yield Event(type=EventType.TOKEN, text='{"wrong": "field"}')
                yield Event(type=EventType.COMPLETE)

            return events()

        def fallback_factory():
            nonlocal fallback_call_count
            fallback_call_count += 1

            async def events():
                # Fallback fails once, then succeeds on its second attempt.
                payload = (
                    '{"also_wrong": "field"}'
                    if fallback_call_count == 1
                    else '{"value": "from_fallback"}'
                )
                yield Event(type=EventType.TOKEN, text=payload)
                yield Event(type=EventType.COMPLETE)

            return events()

        result = await structured(
            schema=SimpleModel,
            stream=main_factory,
            fallbacks=[fallback_factory],
            retry=Retry(attempts=2),
        )

        assert result.data.value == "from_fallback"
        # Main source exhausted both attempts before falling back.
        assert main_call_count == 2
        # Fallback tried twice; the second attempt succeeded.
        assert fallback_call_count == 2

    @pytest.mark.asyncio
    async def test_async_factory_called_fresh_on_retry(self):
        """Async (awaitable) factories are likewise re-called per attempt."""
        factory_call_count = 0

        async def async_stream_factory():
            nonlocal factory_call_count
            factory_call_count += 1

            async def events():
                payload = (
                    '{"wrong": "field"}'
                    if factory_call_count == 1
                    else '{"value": "async_success"}'
                )
                yield Event(type=EventType.TOKEN, text=payload)
                yield Event(type=EventType.COMPLETE)

            return events()

        result = await structured(
            schema=SimpleModel,
            stream=async_stream_factory,
            retry=Retry(attempts=3),
        )

        assert result.data.value == "async_success"
        assert factory_call_count == 2


class TestStructuredStrictMode:
"""Test strict_mode parameter for rejecting extra fields."""

Expand Down