Skip to content

Commit 9e5ef20

Browse files
committed
Add run options
1 parent fd306b7 commit 9e5ef20

File tree

3 files changed

+21
-15
lines changed

3 files changed

+21
-15
lines changed

python/restate/ext/pydantic/_agent.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from contextlib import AbstractAsyncContextManager, asynccontextmanager, contextmanager
55
from typing import Any, overload
66

7-
from restate import TerminalError
7+
from restate import RunOptions, TerminalError
88
from restate.ext.pydantic._utils import state_context
99
from restate.extensions import current_context
1010

@@ -93,6 +93,7 @@ def __init__(
9393
*,
9494
event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
9595
disable_auto_wrapping_tools: bool = False,
96+
run_options: RunOptions | None = None,
9697
):
9798
super().__init__(wrapped)
9899
if not isinstance(wrapped.model, Model):
@@ -102,12 +103,16 @@ def __init__(
102103

103104
self._event_stream_handler = event_stream_handler
104105
self._disable_auto_wrapping_tools = disable_auto_wrapping_tools
105-
self._model = RestateModelWrapper(wrapped.model, event_stream_handler=event_stream_handler, max_attempts=3)
106+
107+
if run_options is None:
108+
run_options = RunOptions(max_attempts=3)
109+
110+
self._model = RestateModelWrapper(wrapped.model, run_options, event_stream_handler=event_stream_handler)
106111

107112
def set_context(toolset: AbstractToolset[AgentDepsT]) -> AbstractToolset[AgentDepsT]:
108113
"""Set the Restate context for the toolset, wrapping tools if needed."""
109114
if isinstance(toolset, FunctionToolset) and not disable_auto_wrapping_tools:
110-
return RestateContextRunToolSet(toolset)
115+
return RestateContextRunToolSet(toolset, run_options)
111116
try:
112117
from pydantic_ai.mcp import MCPServer
113118

@@ -116,7 +121,7 @@ def set_context(toolset: AbstractToolset[AgentDepsT]) -> AbstractToolset[AgentDe
116121
pass
117122
else:
118123
if isinstance(toolset, MCPServer):
119-
return RestateMCPServer(toolset)
124+
return RestateMCPServer(toolset, run_options)
120125

121126
return toolset
122127

python/restate/ext/pydantic/_model.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from contextlib import asynccontextmanager
33
from datetime import datetime
44
from typing import Any
5+
import dataclasses
56

67
from restate import RunOptions, SdkInternalBaseException
78
from restate.ext.pydantic._utils import current_state
@@ -60,11 +61,12 @@ class RestateModelWrapper(WrapperModel):
6061
def __init__(
6162
self,
6263
wrapped: Model,
64+
run_options: RunOptions,
6365
event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
64-
max_attempts: int | None = None,
6566
):
6667
super().__init__(wrapped)
67-
self._options = RunOptions(serde=MODEL_RESPONSE_SERDE, max_attempts=max_attempts)
68+
69+
self._options = dataclasses.replace(run_options, serde=MODEL_RESPONSE_SERDE)
6870
self._event_stream_handler = event_stream_handler
6971

7072
async def request(self, *args: Any, **kwargs: Any) -> ModelResponse:

python/restate/ext/pydantic/_toolset.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
from collections.abc import Callable
4-
from dataclasses import dataclass
4+
from dataclasses import dataclass, replace
55
from typing import Any, Literal
66

77
from restate import RunOptions, SdkInternalBaseException, TerminalError
@@ -55,9 +55,9 @@ class RestateMCPToolRunResult:
5555
class RestateContextRunToolSet(WrapperToolset[AgentDepsT]):
5656
"""A toolset that automatically wraps tool calls with restate's `ctx.run_typed()`."""
5757

58-
def __init__(self, wrapped: AbstractToolset[AgentDepsT]):
58+
def __init__(self, wrapped: AbstractToolset[AgentDepsT], run_options: RunOptions):
5959
super().__init__(wrapped)
60-
self.options = RunOptions[RestateContextRunResult](serde=CONTEXT_RUN_SERDE)
60+
self.options = replace(run_options, serde=CONTEXT_RUN_SERDE)
6161

6262
async def call_tool(
6363
self, name: str, tool_args: dict[str, Any], ctx: RunContext[AgentDepsT], tool: ToolsetTool[AgentDepsT]
@@ -119,9 +119,11 @@ def visit_and_replace(
119119
class RestateMCPServer(WrapperToolset[AgentDepsT]):
120120
"""A wrapper for MCPServer that integrates with restate."""
121121

122-
def __init__(self, wrapped: MCPServer):
122+
def __init__(self, wrapped: MCPServer, run_options: RunOptions):
123123
super().__init__(wrapped)
124124
self._wrapped = wrapped
125+
self.get_tools_options = replace(run_options, serde=MCP_GET_TOOLS_SERDE)
126+
self.run_tools_options = replace(run_options, serde=MCP_RUN_SERDE)
125127

126128
def visit_and_replace(
127129
self, visitor: Callable[[AbstractToolset[AgentDepsT]], AbstractToolset[AgentDepsT]]
@@ -136,16 +138,14 @@ async def get_tools_in_context() -> RestateMCPGetToolsContextRunResult:
136138
# so we just return the ToolDefinitions and wrap them in ToolsetTool outside of the activity.
137139
return RestateMCPGetToolsContextRunResult(output={name: tool.tool_def for name, tool in res.items()})
138140

139-
options = RunOptions(serde=MCP_GET_TOOLS_SERDE)
140-
141141
context = current_context()
142142
if context is None:
143143
raise UserError(
144144
"A toolset cannot be used without a Restate context. Make sure to run it within an agent or a run context."
145145
)
146146

147147
try:
148-
tool_defs = await context.run_typed("get mcp tools", get_tools_in_context, options)
148+
tool_defs = await context.run_typed("get mcp tools", get_tools_in_context, self.get_tools_options)
149149
return {name: self.tool_for_tool_def(tool_def) for name, tool_def in tool_defs.output.items()}
150150
except SdkInternalBaseException as e:
151151
raise Exception("Internal error during get_tools call") from e
@@ -165,14 +165,13 @@ async def call_tool_in_context() -> RestateMCPToolRunResult:
165165
res = await self._wrapped.call_tool(name, tool_args, ctx, tool)
166166
return RestateMCPToolRunResult(output=res)
167167

168-
options = RunOptions(serde=MCP_RUN_SERDE)
169168
context = current_context()
170169
if context is None:
171170
raise UserError(
172171
"A toolset cannot be used without a Restate context. Make sure to run it within an agent or a run context."
173172
)
174173
try:
175-
res = await context.run_typed(f"Calling mcp tool {name}", call_tool_in_context, options)
174+
res = await context.run_typed(f"Calling mcp tool {name}", call_tool_in_context, self.run_tools_options)
176175
except SdkInternalBaseException as e:
177176
raise Exception("Internal error during tool call") from e
178177

0 commit comments

Comments
 (0)