11from __future__ import annotations
22
33from collections .abc import Callable
4- from dataclasses import dataclass
4+ from dataclasses import dataclass , replace
55from typing import Any , Literal
66
77from restate import RunOptions , SdkInternalBaseException , TerminalError
@@ -55,9 +55,9 @@ class RestateMCPToolRunResult:
5555class RestateContextRunToolSet (WrapperToolset [AgentDepsT ]):
5656 """A toolset that automatically wraps tool calls with restate's `ctx.run_typed()`."""
5757
58- def __init__ (self , wrapped : AbstractToolset [AgentDepsT ]):
58+ def __init__ (self , wrapped : AbstractToolset [AgentDepsT ], run_options : RunOptions ):
5959 super ().__init__ (wrapped )
60- self .options = RunOptions [ RestateContextRunResult ]( serde = CONTEXT_RUN_SERDE )
60+ self .options = replace ( run_options , serde = CONTEXT_RUN_SERDE )
6161
6262 async def call_tool (
6363 self , name : str , tool_args : dict [str , Any ], ctx : RunContext [AgentDepsT ], tool : ToolsetTool [AgentDepsT ]
@@ -119,9 +119,11 @@ def visit_and_replace(
119119class RestateMCPServer (WrapperToolset [AgentDepsT ]):
120120 """A wrapper for MCPServer that integrates with restate."""
121121
122- def __init__ (self , wrapped : MCPServer ):
122+ def __init__ (self , wrapped : MCPServer , run_options : RunOptions ):
123123 super ().__init__ (wrapped )
124124 self ._wrapped = wrapped
125+ self .get_tools_options = replace (run_options , serde = MCP_GET_TOOLS_SERDE )
126+ self .run_tools_options = replace (run_options , serde = MCP_RUN_SERDE )
125127
126128 def visit_and_replace (
127129 self , visitor : Callable [[AbstractToolset [AgentDepsT ]], AbstractToolset [AgentDepsT ]]
@@ -136,16 +138,14 @@ async def get_tools_in_context() -> RestateMCPGetToolsContextRunResult:
136138 # so we just return the ToolDefinitions and wrap them in ToolsetTool outside of the activity.
137139 return RestateMCPGetToolsContextRunResult (output = {name : tool .tool_def for name , tool in res .items ()})
138140
139- options = RunOptions (serde = MCP_GET_TOOLS_SERDE )
140-
141141 context = current_context ()
142142 if context is None :
143143 raise UserError (
144144 "A toolset cannot be used without a Restate context. Make sure to run it within an agent or a run context."
145145 )
146146
147147 try :
148- tool_defs = await context .run_typed ("get mcp tools" , get_tools_in_context , options )
148+ tool_defs = await context .run_typed ("get mcp tools" , get_tools_in_context , self . get_tools_options )
149149 return {name : self .tool_for_tool_def (tool_def ) for name , tool_def in tool_defs .output .items ()}
150150 except SdkInternalBaseException as e :
151151 raise Exception ("Internal error during get_tools call" ) from e
@@ -165,14 +165,13 @@ async def call_tool_in_context() -> RestateMCPToolRunResult:
165165 res = await self ._wrapped .call_tool (name , tool_args , ctx , tool )
166166 return RestateMCPToolRunResult (output = res )
167167
168- options = RunOptions (serde = MCP_RUN_SERDE )
169168 context = current_context ()
170169 if context is None :
171170 raise UserError (
172171 "A toolset cannot be used without a Restate context. Make sure to run it within an agent or a run context."
173172 )
174173 try :
175- res = await context .run_typed (f"Calling mcp tool { name } " , call_tool_in_context , options )
174+ res = await context .run_typed (f"Calling mcp tool { name } " , call_tool_in_context , self . run_tools_options )
176175 except SdkInternalBaseException as e :
177176 raise Exception ("Internal error during tool call" ) from e
178177
0 commit comments