Skip to content

Commit 683c1e3

Browse files
committed
Fix bug: Discriminated unions didn't work for Feature 🐝
1 parent 8e7cbf1 commit 683c1e3

File tree

4 files changed

+629
-25
lines changed

4 files changed

+629
-25
lines changed

packages/overture-schema-system/src/overture/schema/system/feature.py

Lines changed: 198 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,14 @@
22
Geospatial feature model with GeoJSON-compatible JSON Schema.
33
"""
44

5+
import inspect
56
from enum import Enum
67
from functools import reduce
78
from typing import Any
89

910
from pydantic import (
1011
BaseModel,
12+
Discriminator,
1113
Field,
1214
GetJsonSchemaHandler,
1315
ModelWrapValidatorHandler,
@@ -88,6 +90,82 @@ class Feature(BaseModel):
8890
... ]
8991
9092
.. _GeoJSON format: https://datatracker.ietf.org/doc/html/rfc7946
93+
94+
Because the GeoJSON format moves feature fields to a place Pydantic does not expect them (the
95+
`"properties"` block of the JSON object), a naive use of Pydantic discriminated unions will not
96+
work when deserializing from JSON with `model_validate_json`. To create a robust discriminated
97+
union, use the `field_discriminator` method:
98+
99+
>>> from typing import Annotated, Literal
100+
>>> from overture.schema.system.primitive import float32
101+
>>> import pydantic
102+
>>>
103+
>>> class Field(Feature):
104+
... geometry: Annotated[
105+
... Geometry,
106+
... GeometryTypeConstraint(GeometryType.POLYGON, GeometryType.MULTI_POLYGON)
107+
... ]
108+
... type: Literal['field']
109+
...
110+
>>> class Fence(Feature):
111+
... geometry: Annotated[
112+
... Geometry,
113+
... GeometryTypeConstraint(GeometryType.LINE_STRING)
114+
... ]
115+
... type: Literal['fence']
116+
... subtype: Literal[
117+
... 'chain_link', 'barb_wire_3', 'barb_wire_4', 'barb_wire_5', 'electric_1',
118+
... 'electric_2', 'electric_3', 'electric_4', 'split_rail', 'woven_wire'
119+
... ]
120+
... height: float32 | None = pydantic.Field(
121+
... default=None,
122+
... description='Optional fence height in meters'
123+
... )
124+
...
125+
>>> FarmFeature = pydantic.TypeAdapter(
126+
... Annotated[
127+
... Annotated[Field, pydantic.Tag('field')]
128+
... | Annotated[Fence, pydantic.Tag('fence')],
129+
... pydantic.Field(
130+
... discriminator=Feature.field_discriminator('type', Field, Fence)
131+
... ),
132+
... ]
133+
... )
134+
>>>
135+
>>> FarmFeature.validate_json('''{
136+
... "type": "Feature",
137+
... "geometry": {
138+
... "type": "LineString",
139+
... "coordinates": [[0, 0], [0, 0.01]]
140+
... },
141+
... "properties": {
142+
... "type": "fence",
143+
... "subtype": "barb_wire_4"
144+
... }
145+
... }''')
146+
Fence(id=<MISSING>, bbox=<MISSING>, geometry=<<LINESTRING (0 0, 0 0.01)>>, type='fence', subtype='barb_wire_4', height=None)
147+
148+
You can include model classes that are not `Feature` subclasses in `field_discriminator` to enable
149+
discriminated unions between features and non-features, as long as at least one model class is
150+
a `Feature`:
151+
152+
>>> class Farmer(BaseModel):
153+
... type: Literal['farmer']
154+
... name: str
155+
...
156+
>>> FarmModel = pydantic.TypeAdapter(
157+
... Annotated[
158+
... Annotated[Farmer, pydantic.Tag('farmer')]
159+
... | Annotated[Field, pydantic.Tag('field')]
160+
... | Annotated[Fence, pydantic.Tag('fence')],
161+
... pydantic.Field(
162+
... discriminator=Feature.field_discriminator('type', Farmer, Field, Fence)
163+
... ),
164+
... ]
165+
... )
166+
>>>
167+
>>> FarmModel.validate_json('{"type":"farmer","name":"John Deere"}')
168+
Farmer(type='farmer', name='John Deere')
91169
"""
92170

93171
id: Omitable[Id] = Field(description="An optional unique ID for the feature")
@@ -104,6 +182,122 @@ class Feature(BaseModel):
104182
field and annotating it with a `GeometryTypeConstraint`.
105183
"""
106184

185+
@staticmethod
def field_discriminator(
    field: str, *model_classes: type[BaseModel]
) -> Discriminator:
    """
    Return a discriminator that can be used in a Pydantic `Field` to support tagged unions of
    features.

    Use this method to generate a Pydantic discriminator that works *both* with Python-style
    flat data *and* GeoJSON. Note that at least one member of `model_classes` must be a
    `Feature` (if no feature models are involved, you don't need this method and should build
    your discriminated union using Pydantic's standard discriminator facilities).

    The `model_classes` are only used for up-front validation of the arguments: the returned
    discriminator itself depends solely on `field`.

    Parameters
    ----------
    field : str
        Field name, which must be present in all models
    *model_classes : type[BaseModel]
        Two or more Pydantic model classes, at least one of which must be a subclass of the
        `Feature` class

    Returns
    -------
    Discriminator
        Discriminator that enables discriminated unions that include features

    Raises
    ------
    TypeError
        If `field` is not a `str`.
    TypeError
        If any member of `model_classes` is not a subclass of `BaseModel`.
    TypeError
        If no member of `model_classes` is a subclass of `Feature`.
    TypeError
        If any member of `model_classes` does not have a field named `field`.
    ValueError
        If `field` names one of the core GeoJSON feature fields that cannot be discriminated:
        `"bbox"`, `"geometry"`, or `"id"`.
    ValueError
        If `model_classes` has length less than 2.
    """
    # Argument checks are written as guard clauses: each one raises, so no `elif` is needed.
    if not isinstance(field, str):
        raise TypeError(
            f"`field` must be a `str`, but {repr(field)} has type `{type(field).__name__}`"
        )
    if field in ("bbox", "geometry", "id"):
        raise ValueError(
            f"`field` value {repr(field)} is not allowed because it is one of the core GeoJSON "
            "feature properties: 'bbox', 'geometry', and 'id' - use a different discriminator "
            "field!"
        )
    if len(model_classes) < 2:
        raise ValueError(
            f"`model_classes` must have at least two items, but {repr(model_classes)} has length {len(model_classes)}"
        )

    # `isinstance(x, type)` must be checked first because `issubclass` raises on non-classes.
    non_models = [
        x
        for x in model_classes
        if not (isinstance(x, type) and issubclass(x, BaseModel))
    ]
    if non_models:
        # For a class report its own name; for any other value report the name of its type.
        type_names = ", ".join(
            f"`{x.__name__ if isinstance(x, type) else type(x).__name__}`"
            for x in non_models
        )
        raise TypeError(
            "`model_classes` contains at least one non-model class: the value(s) "
            f"{repr(non_models)} should be subclasses of {BaseModel.__name__} but the type(s) "
            f"are {type_names}, "
            f"respectively, which are not subclasses of `{BaseModel.__name__}`..."
        )

    missing_field = [
        f"`{t.__name__}`" for t in model_classes if field not in t.model_fields
    ]
    if missing_field:
        raise TypeError(
            "`model_classes` contains at least one model class that does not have a field "
            f"named {repr(field)}: {', '.join(missing_field)}"
        )

    # Test the predicate itself rather than the truthiness of the filtered class objects.
    if not any(issubclass(t, Feature) for t in model_classes):
        # The current frame is only used to spell this method's own name in the message;
        # `currentframe()` may return `None`, in which case a placeholder is used.
        frame = inspect.currentframe()
        method_name = (
            f"{Feature.__name__}.{frame.f_code.co_name if frame else '???'}"
        )
        raise TypeError(
            f"`model_classes` does not contain any subclasses of `{Feature.__name__}` - "
            f"you don't need `{method_name}(...)` unless you have at least one "
            f"`{Feature.__name__}` model - use standard Pydantic discriminators instead"
        )

    def get_discriminator_value(data: object) -> Any:
        # Pydantic doesn't have a facility to tell the dynamic discriminator function whether
        # the context is 'python' or 'json', so we just have to use heuristics. If the input is
        # a `dict` with the mandatory attributes `"type": "Feature"`, `"geometry"`, and
        # `"properties"`, we assume we're in GeoJSON-land, otherwise not.
        #
        # If the data doesn't contain the discriminator field at all, we return `None` to tell
        # Pydantic to proceed to try the next variant in the union, if there is one. This is
        # equivalent to how Pydantic behaves with static unions.
        if not isinstance(data, dict):
            # Not a mapping (e.g. a model instance): read the field as an attribute.
            return getattr(data, field, None)

        looks_like_geo_json = (
            all(key in data for key in ("geometry", "properties", "type"))
            and data["type"] == "Feature"
        )
        if not looks_like_geo_json:
            # Flat, Python-style data: the discriminator lives at the top level.
            return data.get(field, None)

        # GeoJSON: the discriminator lives inside the `"properties"` object.
        properties: Any = data["properties"]
        if not isinstance(properties, dict):
            return None
        return properties.get(field, None)

    return Discriminator(get_discriminator_value)
300+
107301
@model_serializer(mode="wrap")
108302
def __serialize_with_geo_json_support__(
109303
self, serializer: SerializerFunctionWrapHandler, info: SerializationInfo
@@ -133,12 +327,11 @@ def __validate_with_geo_json_support__(
133327
Validate the model as GeoJSON when the mode is JSON, otherwise applies Pydantic's standard
134328
validation.
135329
"""
136-
if not isinstance(data, dict):
137-
raise TypeError(
138-
f"feature data must be a `dict`, but {repr(data)} is a `{type(data).__name__}`"
139-
)
140-
141330
if info.mode == "json":
331+
if not isinstance(data, dict):
332+
raise TypeError(
333+
f"feature data must be a `dict` when validating JSON, but {repr(data)} is a `{type(data).__name__}`"
334+
)
142335

143336
def validation_error(
144337
type: str, input: object, error: str, *loc: str

0 commit comments

Comments
 (0)