2626MethodT = TypeVar ("MethodT" , bound = Callable )
2727ClassT = TypeVar ("ClassT" , bound = type )
2828T = TypeVar ("T" )
29+ MultiParamModelSpec = tuple [type [BaseModel ], ...]
30+
31+
32+ def _param_models_name (models : MultiParamModelSpec ) -> str :
33+ return " | " .join (model_type .__name__ for model_type in models )
34+
35+
36+ def _param_models_field_names (models : MultiParamModelSpec ) -> tuple [str , ...]:
37+ shared_fields = set (models [0 ].model_fields )
38+ for model_type in models [1 :]:
39+ shared_fields &= set (model_type .model_fields )
40+ return tuple (field_name for field_name in models [0 ].model_fields if field_name in shared_fields )
41+
42+
43+ def model_to_kwargs (model_obj : BaseModel , models : MultiParamModelSpec ) -> dict [str , Any ]:
44+ kwargs = {
45+ field_name : getattr (model_obj , field_name )
46+ for field_name in _param_models_field_names (models )
47+ if field_name != "field_meta"
48+ }
49+ if meta := getattr (model_obj , "field_meta" , None ):
50+ kwargs .update (meta )
51+ return kwargs
2952
3053
3154def serialize_params (params : BaseModel ) -> dict [str , Any ]:
@@ -114,6 +137,18 @@ def decorator(func: MethodT) -> MethodT:
114137 return decorator
115138
116139
140+ def param_models (* param_cls : type [BaseModel ]) -> Callable [[MethodT ], MethodT ]:
141+ """Decorator to mark a method as accepting multiple legacy parameter models."""
142+ if not param_cls :
143+ raise ValueError ("param_models() requires at least one model class" )
144+
145+ def decorator (func : MethodT ) -> MethodT :
146+ func .__param_models__ = param_cls # type: ignore[attr-defined]
147+ return func
148+
149+ return decorator
150+
151+
117152def to_camel_case (snake_str : str ) -> str :
118153 """Convert snake_case strings to camelCase."""
119154 components = snake_str .split ("_" )
@@ -129,7 +164,9 @@ def wrapped(self, params: BaseModel) -> T:
129164 DeprecationWarning ,
130165 stacklevel = 3 ,
131166 )
132- kwargs = {k : getattr (params , k ) for k in model .model_fields if k != "field_meta" }
167+ kwargs = {
168+ field_name : getattr (params , field_name ) for field_name in model .model_fields if field_name != "field_meta"
169+ }
133170 if meta := getattr (params , "field_meta" , None ):
134171 kwargs .update (meta )
135172 return func (self , ** kwargs ) # type: ignore[arg-type]
@@ -152,7 +189,11 @@ def wrapped(self, *args: Any, **kwargs: Any) -> T:
152189 DeprecationWarning ,
153190 stacklevel = 3 ,
154191 )
155- kwargs = {k : getattr (param , k ) for k in model .model_fields if k != "field_meta" }
192+ kwargs = {
193+ field_name : getattr (param , field_name )
194+ for field_name in model .model_fields
195+ if field_name != "field_meta"
196+ }
156197 if meta := getattr (param , "field_meta" , None ):
157198 kwargs .update (meta )
158199 return func (self , ** kwargs ) # type: ignore[arg-type]
@@ -161,14 +202,67 @@ def wrapped(self, *args: Any, **kwargs: Any) -> T:
161202 return wrapped
162203
163204
205+ def _make_multi_legacy_func (func : Callable [..., T ], models : MultiParamModelSpec ) -> Callable [[Any , BaseModel ], T ]:
206+ model_name = _param_models_name (models )
207+
208+ @functools .wraps (func )
209+ def wrapped (self , params : BaseModel ) -> T :
210+ warnings .warn (
211+ f"Calling { func .__name__ } with { model_name } parameter is " # type: ignore[attr-defined]
212+ "deprecated, please update to the new API style." ,
213+ DeprecationWarning ,
214+ stacklevel = 3 ,
215+ )
216+ return func (self , ** model_to_kwargs (params , models )) # type: ignore[arg-type]
217+
218+ return wrapped
219+
220+
221+ def _make_multi_compatible_func (func : Callable [..., T ], models : MultiParamModelSpec ) -> Callable [..., T ]:
222+ model_name = _param_models_name (models )
223+
224+ @functools .wraps (func )
225+ def wrapped (self , * args : Any , ** kwargs : Any ) -> T :
226+ param = None
227+ if not kwargs and len (args ) == 1 :
228+ param = args [0 ]
229+ elif not args and len (kwargs ) == 1 :
230+ param = kwargs .get ("params" )
231+ if isinstance (param , models ):
232+ warnings .warn (
233+ f"Calling { func .__name__ } with { model_name } parameter " # type: ignore[attr-defined]
234+ "is deprecated, please update to the new API style." ,
235+ DeprecationWarning ,
236+ stacklevel = 3 ,
237+ )
238+ return func (self , ** model_to_kwargs (param , models )) # type: ignore[arg-type]
239+ return func (self , * args , ** kwargs )
240+
241+ return wrapped
242+
243+
164244def compatible_class (cls : ClassT ) -> ClassT :
165245 """Mark a class as backward compatible with old API style."""
166246 for attr in dir (cls ):
167247 func = getattr (cls , attr )
168- if not callable (func ) or (model := getattr (func , "__param_model__" , None )) is None :
248+ if not callable (func ):
249+ continue
250+ model = getattr (func , "__param_model__" , None )
251+ models = getattr (func , "__param_models__" , None )
252+ if model is None and models is None :
169253 continue
170254 if "_" in attr :
171- setattr (cls , to_camel_case (attr ), _make_legacy_func (func , model ))
255+ if models is not None :
256+ setattr (cls , to_camel_case (attr ), _make_multi_legacy_func (func , models ))
257+ else :
258+ if model is None :
259+ continue
260+ setattr (cls , to_camel_case (attr ), _make_legacy_func (func , model ))
172261 else :
173- setattr (cls , attr , _make_compatible_func (func , model ))
262+ if models is not None :
263+ setattr (cls , attr , _make_multi_compatible_func (func , models ))
264+ else :
265+ if model is None :
266+ continue
267+ setattr (cls , attr , _make_compatible_func (func , model ))
174268 return cls
0 commit comments