22Geospatial feature model with GeoJSON-compatible JSON Schema.
33"""
44
5+ import inspect
56from enum import Enum
67from functools import reduce
78from typing import Any
89
910from pydantic import (
1011 BaseModel ,
12+ Discriminator ,
1113 Field ,
1214 GetJsonSchemaHandler ,
1315 ModelWrapValidatorHandler ,
@@ -88,6 +90,82 @@ class Feature(BaseModel):
8890 ... ]
8991
9092 .. _GeoJSON format: https://datatracker.ietf.org/doc/html/rfc7946
93+
94+ Because the GeoJSON format moves feature fields to a place Pydantic does not expect them (the
95+ `"properties"` block of the JSON object), a naive use of Pydantic discriminated unions will not
96+ work when deserializing from JSON with `model_validate_json`. To create a robust discriminated
97+ union, use the `field_discriminator` method:
98+
99+ >>> from typing import Annotated, Literal
100+ >>> from overture.schema.system.primitive import float32
101+ >>> import pydantic
102+ >>>
103+ >>> class Field(Feature):
104+ ... geometry: Annotated[
105+ ... Geometry,
106+ ... GeometryTypeConstraint(GeometryType.POLYGON, GeometryType.MULTI_POLYGON)
107+ ... ]
108+ ... type: Literal['field']
109+ ...
110+ >>> class Fence(Feature):
111+ ... geometry: Annotated[
112+ ... Geometry,
113+ ... GeometryTypeConstraint(GeometryType.LINE_STRING)
114+ ... ]
115+ ... type: Literal['fence']
116+ ... subtype: Literal[
117+ ... 'chain_link', 'barb_wire_3', 'barb_wire_4', 'barb_wire_5', 'electric_1',
118+ ... 'electric_2', 'electric_3', 'electric_4', 'split_rail', 'woven_wire'
119+ ... ]
120+ ... height: float32 | None = pydantic.Field(
121+ ... default=None,
122+ ... description='Optional fence height in meters'
123+ ... )
124+ ...
125+ >>> FarmFeature = pydantic.TypeAdapter(
126+ ... Annotated[
127+ ... Annotated[Field, pydantic.Tag('field')]
128+ ... | Annotated[Fence, pydantic.Tag('fence')],
129+ ... pydantic.Field(
130+ ... discriminator=Feature.field_discriminator('type', Field, Fence)
131+ ... ),
132+ ... ]
133+ ... )
134+ >>>
135+ >>> FarmFeature.validate_json('''{
136+ ... "type": "Feature",
137+ ... "geometry": {
138+ ... "type": "LineString",
139+ ... "coordinates": [[0, 0], [0, 0.01]]
140+ ... },
141+ ... "properties": {
142+ ... "type": "fence",
143+ ... "subtype": "barb_wire_4"
144+ ... }
145+ ... }''')
146+ Fence(id=<MISSING>, bbox=<MISSING>, geometry=<<LINESTRING (0 0, 0 0.01)>>, type='fence', subtype='barb_wire_4', height=None)
147+
148+ You can model classes that are not `Feature` subclasses in `field_discriminator` to enable
149+ discriminated unions between features and non-features, as long as at least one model class is
150+ a `Feature`:
151+
152+ >>> class Farmer(BaseModel):
153+ ... type: Literal['farmer']
154+ ... name: str
155+ ...
156+ >>> FarmModel = pydantic.TypeAdapter(
157+ ... Annotated[
158+ ... Annotated[Farmer, pydantic.Tag('farmer')]
159+ ... | Annotated[Field, pydantic.Tag('field')]
160+ ... | Annotated[Fence, pydantic.Tag('fence')],
161+ ... pydantic.Field(
162+ ... discriminator=Feature.field_discriminator('type', Field, Fence)
163+ ... ),
164+ ... ]
165+ ... )
166+ >>>
167+ >>> FarmModel.validate_json('{"type":"farmer","name":"John Deere"}')
168+ Farmer(type='farmer', name='John Deere')
91169 """
92170
93171 id : Omitable [Id ] = Field (description = "An optional unique ID for the feature" )
@@ -104,6 +182,122 @@ class Feature(BaseModel):
104182 field and annotating it with a `GeometryTypeConstraint`.
105183 """
106184
185+ @staticmethod
186+ def field_discriminator (
187+ field : str , * model_classes : type [BaseModel ]
188+ ) -> Discriminator :
189+ """
190+ Return a discriminator that can be used in a Pydantic `Field` to support tagged unions of
191+ features.
192+
193+ Use this method to generate a Pydantic discriminator that works *both* with Python-style
194+ flat data *and* GeoJSON. Note that at least one member of `model_classes` must be a
195+ `Feature` (if no feature models are involved, you don't need this method and should build
196+ your discriminated union using Pydantic's standard discriminator facilities).
197+
198+ Parameters
199+ ----------
200+ field : str
201+ Field name, which must be present in all models
202+ *model_classes : type[BaseModel]
203+ One or more Pydantic model classes, at least one of which must be a subclass of the
204+ `Feature` class
205+
206+ Returns
207+ -------
208+ Discriminator
209+ Discriminator that enables discriminated unions that include features
210+
211+ Raises
212+ ------
213+ TypeError
214+ If any member of `model_classes` is not a subclass of `BaseModel`.
215+ TypeError
216+ If no member of `model_classes` is a subclass of `Feature`.
217+ TypeError
218+ If any member of `model_classes` does not have a field named `field`.
219+ ValueError
220+ If `field` names one of the core GeoJSON feature fields that cannot be discriminated:
221+ `"bbox"`, `"geometry`", or `"id"`.
222+ ValueError
223+ If `model_classes` has length less than 2.
224+ """
225+ if not isinstance (field , str ):
226+ raise TypeError (
227+ f"`field` must be a `str`, but { repr (field )} has type `{ type (field ).__name__ } `"
228+ )
229+ elif field in ["bbox" , "geometry" , "id" ]:
230+ raise ValueError (
231+ f"`field` value { repr (field )} is not allowed because it is one of the core GeoJSON "
232+ "feature properties: 'bbox', 'geometry', and 'id' - use a different discriminator "
233+ "field!"
234+ )
235+ elif len (model_classes ) < 2 :
236+ raise ValueError (
237+ f"`model_classes` must have at least two items, but { repr (model_classes )} has length { len (model_classes )} "
238+ )
239+
240+ non_models = [
241+ x
242+ for x in model_classes
243+ if not isinstance (x , type ) or not issubclass (x , BaseModel )
244+ ]
245+ if non_models :
246+ raise TypeError (
247+ "`model_classes` contains at least one non-model class: the value(s) "
248+ f"{ repr (non_models )} should be subclasses of { BaseModel .__name__ } but the type(s) "
249+ f"are { ', ' .join ([f'`{ x .__name__ if isinstance (x , type ) else type (x ).__name__ } `' for x in non_models ])} , "
250+ f"respectively, which are not subclasses of `{ BaseModel .__name__ } `..."
251+ )
252+
253+ missing_field = [
254+ f"`{ t .__name__ } `" for t in model_classes if field not in t .model_fields
255+ ]
256+ if missing_field :
257+ raise TypeError (
258+ "`model_classes` contains at least one model class that does not have a field "
259+ f"named { repr (field )} : { ', ' .join (missing_field )} "
260+ )
261+
262+ if not any (t for t in model_classes if issubclass (t , Feature )):
263+ frame = inspect .currentframe ()
264+ method_name = (
265+ f"{ Feature .__name__ } .{ frame .f_code .co_name if frame else '???' } "
266+ )
267+ raise TypeError (
268+ f"`model_classes` does not contain any subclasses of `{ Feature .__name__ } ` - "
269+ f"you don't need `{ method_name } (...)` unless you have at least one "
270+ f"`{ Feature .__name__ } ` model - use standard Pydantic discriminators instead"
271+ )
272+
273+ def get_discriminator_value (data : object ) -> Any :
274+ # Pydantic doesn't have a facility to tell the dynamic discriminator function whether
275+ # the context is 'python' or 'json', so we just have to use heuristics. If the input is
276+ # a `dict` with the mandatory attributes `"type": "Feature"`, `"geometry"`, and
277+ # `"properties"`, we assume we're in GeoJSON-land, otherwise not.
278+ #
279+ # If the data doesn't contain the discriminator field at all, we return `None` to tell
280+ # Pydantic proceed to try the next variant in the union, if there is one. This is
281+ # equivalent to how Pydantic behaves with static unions.
282+ if (
283+ isinstance (data , dict )
284+ and all (f in data for f in ["geometry" , "properties" , "type" ])
285+ and data ["type" ] == "Feature"
286+ ):
287+ properties : Any = data ["properties" ]
288+ if not isinstance (properties , dict ):
289+ return None
290+ else :
291+ return properties .get (field , None )
292+ else :
293+ return (
294+ data .get (field , None )
295+ if isinstance (data , dict )
296+ else getattr (data , field , None )
297+ )
298+
299+ return Discriminator (get_discriminator_value )
300+
107301 @model_serializer (mode = "wrap" )
108302 def __serialize_with_geo_json_support__ (
109303 self , serializer : SerializerFunctionWrapHandler , info : SerializationInfo
@@ -133,12 +327,11 @@ def __validate_with_geo_json_support__(
133327 Validate the model as GeoJSON when the mode is JSON, otherwise applies Pydantic's standard
134328 validation.
135329 """
136- if not isinstance (data , dict ):
137- raise TypeError (
138- f"feature data must be a `dict`, but { repr (data )} is a `{ type (data ).__name__ } `"
139- )
140-
141330 if info .mode == "json" :
331+ if not isinstance (data , dict ):
332+ raise TypeError (
333+ f"feature data must be a `dict` when validating JSON, but { repr (data )} is a `{ type (data ).__name__ } `"
334+ )
142335
143336 def validation_error (
144337 type : str , input : object , error : str , * loc : str
0 commit comments