Skip to content

Commit a57c672

Browse files
authored
fix: complete schema v0.11.2 follow-ups (#81)
1 parent 2af0e7c commit a57c672

File tree

9 files changed

+317
-47
lines changed

9 files changed

+317
-47
lines changed

src/acp/agent/router.py

Lines changed: 37 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@
22

33
from typing import Any
44

5+
from pydantic import BaseModel
6+
57
from ..exceptions import RequestError
68
from ..interfaces import Agent
79
from ..meta import AGENT_METHODS
8-
from ..router import MessageRouter
10+
from ..router import MessageRouter, Route, _resolve_handler, _warn_legacy_handler
911
from ..schema import (
1012
AuthenticateRequest,
1113
CancelNotification,
@@ -17,15 +19,41 @@
1719
NewSessionRequest,
1820
PromptRequest,
1921
ResumeSessionRequest,
22+
SetSessionConfigOptionBooleanRequest,
2023
SetSessionConfigOptionSelectRequest,
2124
SetSessionModelRequest,
2225
SetSessionModeRequest,
2326
)
24-
from ..utils import normalize_result
27+
from ..utils import model_to_kwargs, normalize_result
2528

2629
__all__ = ["build_agent_router"]
2730

2831

32+
_SET_CONFIG_OPTION_MODELS = (SetSessionConfigOptionBooleanRequest, SetSessionConfigOptionSelectRequest)
33+
34+
35+
def _validate_set_config_option_request(params: Any) -> BaseModel:
36+
if isinstance(params, dict) and params.get("type") == "boolean":
37+
return SetSessionConfigOptionBooleanRequest.model_validate(params)
38+
return SetSessionConfigOptionSelectRequest.model_validate(params)
39+
40+
41+
def _make_set_config_option_handler(agent: Agent) -> Any:
42+
func, attr, legacy_api = _resolve_handler(agent, "set_config_option")
43+
if func is None:
44+
return None
45+
46+
async def wrapper(params: Any) -> Any:
47+
if legacy_api:
48+
_warn_legacy_handler(agent, attr)
49+
request = _validate_set_config_option_request(params)
50+
if legacy_api:
51+
return await func(request)
52+
return await func(**model_to_kwargs(request, _SET_CONFIG_OPTION_MODELS))
53+
54+
return wrapper
55+
56+
2957
def build_agent_router(agent: Agent, use_unstable_protocol: bool = False) -> MessageRouter:
3058
router = MessageRouter(use_unstable_protocol=use_unstable_protocol)
3159

@@ -63,12 +91,13 @@ def build_agent_router(agent: Agent, use_unstable_protocol: bool = False) -> Mes
6391
adapt_result=normalize_result,
6492
unstable=True,
6593
)
66-
router.route_request(
67-
AGENT_METHODS["session_set_config_option"],
68-
SetSessionConfigOptionSelectRequest,
69-
agent,
70-
"set_config_option",
71-
adapt_result=normalize_result,
94+
router.add_route(
95+
Route(
96+
method=AGENT_METHODS["session_set_config_option"],
97+
func=_make_set_config_option_handler(agent),
98+
kind="request",
99+
adapt_result=normalize_result,
100+
)
72101
)
73102
router.route_request(
74103
AGENT_METHODS["authenticate"],

src/acp/client/connection.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
ResourceContentBlock,
3636
ResumeSessionRequest,
3737
ResumeSessionResponse,
38+
SetSessionConfigOptionBooleanRequest,
3839
SetSessionConfigOptionResponse,
3940
SetSessionConfigOptionSelectRequest,
4041
SetSessionModelRequest,
@@ -44,7 +45,7 @@
4445
SseMcpServer,
4546
TextContentBlock,
4647
)
47-
from ..utils import compatible_class, notify_model, param_model, request_model, request_model_from_dict
48+
from ..utils import compatible_class, notify_model, param_model, param_models, request_model, request_model_from_dict
4849
from .router import build_client_router
4950

5051
__all__ = ["ClientSideConnection"]
@@ -154,16 +155,30 @@ async def set_session_model(self, model_id: str, session_id: str, **kwargs: Any)
154155
SetSessionModelResponse,
155156
)
156157

157-
@param_model(SetSessionConfigOptionSelectRequest)
158+
@param_models(SetSessionConfigOptionBooleanRequest, SetSessionConfigOptionSelectRequest)
158159
async def set_config_option(
159-
self, config_id: str, session_id: str, value: str, **kwargs: Any
160+
self, config_id: str, session_id: str, value: str | bool, **kwargs: Any
160161
) -> SetSessionConfigOptionResponse:
162+
request = (
163+
SetSessionConfigOptionBooleanRequest(
164+
config_id=config_id,
165+
session_id=session_id,
166+
type="boolean",
167+
value=value,
168+
field_meta=kwargs or None,
169+
)
170+
if isinstance(value, bool)
171+
else SetSessionConfigOptionSelectRequest(
172+
config_id=config_id,
173+
session_id=session_id,
174+
value=value,
175+
field_meta=kwargs or None,
176+
)
177+
)
161178
return await request_model_from_dict(
162179
self._conn,
163180
AGENT_METHODS["session_set_config_option"],
164-
SetSessionConfigOptionSelectRequest(
165-
config_id=config_id, session_id=session_id, value=value, field_meta=kwargs or None
166-
),
181+
request,
167182
SetSessionConfigOptionResponse,
168183
)
169184

@@ -193,7 +208,12 @@ async def prompt(
193208
return await request_model(
194209
self._conn,
195210
AGENT_METHODS["session_prompt"],
196-
PromptRequest(prompt=prompt, session_id=session_id, field_meta=kwargs or None),
211+
PromptRequest(
212+
prompt=prompt,
213+
session_id=session_id,
214+
message_id=message_id,
215+
field_meta=kwargs or None,
216+
),
197217
PromptResponse,
198218
)
199219

src/acp/interfaces.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
ResumeSessionResponse,
5151
SessionInfoUpdate,
5252
SessionNotification,
53+
SetSessionConfigOptionBooleanRequest,
5354
SetSessionConfigOptionResponse,
5455
SetSessionConfigOptionSelectRequest,
5556
SetSessionModelRequest,
@@ -70,7 +71,7 @@
7071
WriteTextFileRequest,
7172
WriteTextFileResponse,
7273
)
73-
from .utils import param_model
74+
from .utils import param_model, param_models
7475

7576
__all__ = ["Agent", "Client"]
7677

@@ -181,9 +182,9 @@ async def set_session_model(
181182
self, model_id: str, session_id: str, **kwargs: Any
182183
) -> SetSessionModelResponse | None: ...
183184

184-
@param_model(SetSessionConfigOptionSelectRequest)
185+
@param_models(SetSessionConfigOptionBooleanRequest, SetSessionConfigOptionSelectRequest)
185186
async def set_config_option(
186-
self, config_id: str, session_id: str, value: str, **kwargs: Any
187+
self, config_id: str, session_id: str, value: str | bool, **kwargs: Any
187188
) -> SetSessionConfigOptionResponse | None: ...
188189

189190
@param_model(AuthenticateRequest)

src/acp/router.py

Lines changed: 31 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,34 @@
2020
HandlerT = TypeVar("HandlerT", bound=RequestHandler)
2121

2222

23+
def _warn_legacy_handler(obj: Any, attr: str) -> None:
24+
warnings.warn(
25+
f"The old style method {type(obj).__name__}.{attr} is deprecated, please update to the snake-cased form.",
26+
DeprecationWarning,
27+
stacklevel=3,
28+
)
29+
30+
31+
def _resolve_handler(obj: Any, attr: str) -> tuple[AsyncHandler | None, str, bool]:
32+
legacy_api = False
33+
func = getattr(obj, attr, None)
34+
if func is None and "_" in attr:
35+
attr = to_camel_case(attr)
36+
func = getattr(obj, attr, None)
37+
legacy_api = True
38+
elif callable(func) and "_" not in attr:
39+
original_func = func
40+
if hasattr(func, "__func__"):
41+
original_func = func.__func__
42+
parameters = inspect.signature(original_func).parameters
43+
if len(parameters) == 2 and "params" in parameters:
44+
legacy_api = True
45+
46+
if func is None or not callable(func):
47+
return None, attr, legacy_api
48+
return func, attr, legacy_api
49+
50+
2351
@dataclass(slots=True)
2452
class Route:
2553
method: str
@@ -63,31 +91,13 @@ def add_route(self, route: Route) -> None:
6391
self._notifications[route.method] = route
6492

6593
def _make_func(self, model: type[BaseModel], obj: Any, attr: str) -> AsyncHandler | None:
66-
legacy_api = False
67-
func = getattr(obj, attr, None)
68-
if func is None and "_" in attr:
69-
attr = to_camel_case(attr)
70-
func = getattr(obj, attr, None)
71-
legacy_api = True
72-
elif callable(func) and "_" not in attr:
73-
original_func = func
74-
if hasattr(func, "__func__"):
75-
original_func = func.__func__
76-
parameters = inspect.signature(original_func).parameters
77-
if len(parameters) == 2 and "params" in parameters:
78-
legacy_api = True
79-
80-
if func is None or not callable(func):
94+
func, attr, legacy_api = _resolve_handler(obj, attr)
95+
if func is None:
8196
return None
8297

8398
async def wrapper(params: Any) -> Any:
8499
if legacy_api:
85-
warnings.warn(
86-
f"The old style method {type(obj).__name__}.{attr} is deprecated, "
87-
"please update to the snake-cased form.",
88-
DeprecationWarning,
89-
stacklevel=3,
90-
)
100+
_warn_legacy_handler(obj, attr)
91101
model_obj = model.model_validate(params)
92102
if legacy_api:
93103
return await func(model_obj) # type: ignore[arg-type]

src/acp/utils.py

Lines changed: 99 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,29 @@
2626
MethodT = TypeVar("MethodT", bound=Callable)
2727
ClassT = TypeVar("ClassT", bound=type)
2828
T = TypeVar("T")
29+
MultiParamModelSpec = tuple[type[BaseModel], ...]
30+
31+
32+
def _param_models_name(models: MultiParamModelSpec) -> str:
33+
return " | ".join(model_type.__name__ for model_type in models)
34+
35+
36+
def _param_models_field_names(models: MultiParamModelSpec) -> tuple[str, ...]:
37+
shared_fields = set(models[0].model_fields)
38+
for model_type in models[1:]:
39+
shared_fields &= set(model_type.model_fields)
40+
return tuple(field_name for field_name in models[0].model_fields if field_name in shared_fields)
41+
42+
43+
def model_to_kwargs(model_obj: BaseModel, models: MultiParamModelSpec) -> dict[str, Any]:
44+
kwargs = {
45+
field_name: getattr(model_obj, field_name)
46+
for field_name in _param_models_field_names(models)
47+
if field_name != "field_meta"
48+
}
49+
if meta := getattr(model_obj, "field_meta", None):
50+
kwargs.update(meta)
51+
return kwargs
2952

3053

3154
def serialize_params(params: BaseModel) -> dict[str, Any]:
@@ -114,6 +137,18 @@ def decorator(func: MethodT) -> MethodT:
114137
return decorator
115138

116139

140+
def param_models(*param_cls: type[BaseModel]) -> Callable[[MethodT], MethodT]:
141+
"""Decorator to mark a method as accepting multiple legacy parameter models."""
142+
if not param_cls:
143+
raise ValueError("param_models() requires at least one model class")
144+
145+
def decorator(func: MethodT) -> MethodT:
146+
func.__param_models__ = param_cls # type: ignore[attr-defined]
147+
return func
148+
149+
return decorator
150+
151+
117152
def to_camel_case(snake_str: str) -> str:
118153
"""Convert snake_case strings to camelCase."""
119154
components = snake_str.split("_")
@@ -129,7 +164,9 @@ def wrapped(self, params: BaseModel) -> T:
129164
DeprecationWarning,
130165
stacklevel=3,
131166
)
132-
kwargs = {k: getattr(params, k) for k in model.model_fields if k != "field_meta"}
167+
kwargs = {
168+
field_name: getattr(params, field_name) for field_name in model.model_fields if field_name != "field_meta"
169+
}
133170
if meta := getattr(params, "field_meta", None):
134171
kwargs.update(meta)
135172
return func(self, **kwargs) # type: ignore[arg-type]
@@ -152,7 +189,11 @@ def wrapped(self, *args: Any, **kwargs: Any) -> T:
152189
DeprecationWarning,
153190
stacklevel=3,
154191
)
155-
kwargs = {k: getattr(param, k) for k in model.model_fields if k != "field_meta"}
192+
kwargs = {
193+
field_name: getattr(param, field_name)
194+
for field_name in model.model_fields
195+
if field_name != "field_meta"
196+
}
156197
if meta := getattr(param, "field_meta", None):
157198
kwargs.update(meta)
158199
return func(self, **kwargs) # type: ignore[arg-type]
@@ -161,14 +202,67 @@ def wrapped(self, *args: Any, **kwargs: Any) -> T:
161202
return wrapped
162203

163204

205+
def _make_multi_legacy_func(func: Callable[..., T], models: MultiParamModelSpec) -> Callable[[Any, BaseModel], T]:
206+
model_name = _param_models_name(models)
207+
208+
@functools.wraps(func)
209+
def wrapped(self, params: BaseModel) -> T:
210+
warnings.warn(
211+
f"Calling {func.__name__} with {model_name} parameter is " # type: ignore[attr-defined]
212+
"deprecated, please update to the new API style.",
213+
DeprecationWarning,
214+
stacklevel=3,
215+
)
216+
return func(self, **model_to_kwargs(params, models)) # type: ignore[arg-type]
217+
218+
return wrapped
219+
220+
221+
def _make_multi_compatible_func(func: Callable[..., T], models: MultiParamModelSpec) -> Callable[..., T]:
222+
model_name = _param_models_name(models)
223+
224+
@functools.wraps(func)
225+
def wrapped(self, *args: Any, **kwargs: Any) -> T:
226+
param = None
227+
if not kwargs and len(args) == 1:
228+
param = args[0]
229+
elif not args and len(kwargs) == 1:
230+
param = kwargs.get("params")
231+
if isinstance(param, models):
232+
warnings.warn(
233+
f"Calling {func.__name__} with {model_name} parameter " # type: ignore[attr-defined]
234+
"is deprecated, please update to the new API style.",
235+
DeprecationWarning,
236+
stacklevel=3,
237+
)
238+
return func(self, **model_to_kwargs(param, models)) # type: ignore[arg-type]
239+
return func(self, *args, **kwargs)
240+
241+
return wrapped
242+
243+
164244
def compatible_class(cls: ClassT) -> ClassT:
165245
"""Mark a class as backward compatible with old API style."""
166246
for attr in dir(cls):
167247
func = getattr(cls, attr)
168-
if not callable(func) or (model := getattr(func, "__param_model__", None)) is None:
248+
if not callable(func):
249+
continue
250+
model = getattr(func, "__param_model__", None)
251+
models = getattr(func, "__param_models__", None)
252+
if model is None and models is None:
169253
continue
170254
if "_" in attr:
171-
setattr(cls, to_camel_case(attr), _make_legacy_func(func, model))
255+
if models is not None:
256+
setattr(cls, to_camel_case(attr), _make_multi_legacy_func(func, models))
257+
else:
258+
if model is None:
259+
continue
260+
setattr(cls, to_camel_case(attr), _make_legacy_func(func, model))
172261
else:
173-
setattr(cls, attr, _make_compatible_func(func, model))
262+
if models is not None:
263+
setattr(cls, attr, _make_multi_compatible_func(func, models))
264+
else:
265+
if model is None:
266+
continue
267+
setattr(cls, attr, _make_compatible_func(func, model))
174268
return cls

0 commit comments

Comments
 (0)