Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 116 additions & 25 deletions piccolo_api/crud/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
execute_post_hooks,
)

from .exceptions import MalformedQuery, db_exception_handler
from .exceptions import MalformedQuery, PageSizeExceeded, db_exception_handler
from .serializers import Config, create_pydantic_model
from .validators import Validators, apply_validators

Expand Down Expand Up @@ -272,6 +272,7 @@ def __init__(
Route(path="/schema/", endpoint=self.get_schema, methods=["GET"]),
Route(path="/ids/", endpoint=self.get_ids, methods=["GET"]),
Route(path="/count/", endpoint=self.get_count, methods=["GET"]),
Route(path="/query/", endpoint=self.post_query, methods=["POST"]),
Route(
path="/references/",
endpoint=self.get_references,
Expand Down Expand Up @@ -759,19 +760,12 @@ def _apply_filters(

return query

@apply_validators
async def get_all(
self, request: Request, params: t.Optional[t.Dict[str, t.Any]] = None
) -> Response:
"""
Get all rows - query parameters are used for filtering.
"""
async def _get_rows(
self, params: t.Optional[t.Dict[str, t.Any]] = None
) -> t.Tuple[pydantic.BaseModel, t.Dict]:
params = self._clean_data(params) if params else {}

try:
split_params = self._split_params(params)
except ParamException as exception:
return Response(str(exception), status_code=400)
split_params = self._split_params(params)

# Visible fields
visible_fields = split_params.visible_fields
Expand Down Expand Up @@ -808,10 +802,7 @@ async def get_all(
query = query.output(nested=True)

# Apply filters
try:
query = t.cast(Select, self._apply_filters(query, split_params))
except MalformedQuery as exception:
return Response(str(exception), status_code=400)
query = t.cast(Select, self._apply_filters(query, split_params))

# Ordering
order_by = split_params.order_by
Expand All @@ -829,10 +820,8 @@ async def get_all(
page_size = split_params.page_size or self.page_size
# If the page_size is greater than max_page_size return an error
if page_size > self.max_page_size:
return JSONResponse(
{"error": "The page size limit has been exceeded"},
status_code=403,
)
raise PageSizeExceeded("The page size limit has been exceeded")

query = query.limit(page_size)
page = split_params.page
offset = 0
Expand Down Expand Up @@ -861,12 +850,114 @@ async def get_all(

# We need to serialise it ourselves, in case there are datetime
# fields.
json = self.pydantic_model_plural(
include_readable=include_readable,
include_columns=tuple(visible_fields),
data_model = self.pydantic_model_plural(
include_readable=split_params.include_readable,
include_columns=tuple(split_params.visible_fields)
if split_params.visible_fields
else tuple(),
nested=nested,
)(rows=rows).json()
return CustomJSONResponse(json, headers=headers)
)(rows=rows)

return (data_model, headers)

@apply_validators
async def get_all(
self, request: Request, params: t.Optional[t.Dict[str, t.Any]] = None
) -> Response:
"""
Get all rows - query parameters are used for filtering.
"""
try:
data_model, headers = await self._get_rows(params=params)
except (MalformedQuery, ParamException, PageSizeExceeded) as exception:
return JSONResponse({"error": str(exception)}, status_code=400)

return CustomJSONResponse(data_model.json(), headers=headers)

###########################################################################

@apply_validators
async def post_query(self, request: Request):
"""
This endpoint allows the user to pass multiple queries in one go via
JSON. This can save on network requests if a lot of queries need to
be performed.

The data structure should be like:

.. code-block:: javascript

{
"queries": [
{
"name": "Star Wars",
"name__match": "contains",
"__visible_fields": ["name"]
},
{
"name": "Lord of the Rings",
"name__match": "contains",
"__visible_fields": ["name"]
}
]
}

In the above example, we query for each ``'Star Wars`'`` and
``'Lord of the Rings'`` movie. The response looks like this:

.. code-block:: javascript

{
"response": [
[
{
"name": "Star Wars: Episode IV - A New Hope"
},
{
"name": "Star Wars: Episode I - The Phantom Menace"
}
],
[
{
"name": "The Lord of the Rings: The Fellowship of the Ring"
},
{
"name": "The Lord of the Rings: The Two Towers"
}
]
]
]
}

The results are returned in the same order as the queries.

""" # noqa E501
data = await request.json()

queries = data.get("queries")
if queries is None or not isinstance(queries, list):
return JSONResponse(
content={"error": "Malformed query body"},
status_code=400,
)

response_json: t.List[str] = []

for query in queries:
try:
data_model, _ = await self._get_rows(params=query)
except (
MalformedQuery,
ParamException,
PageSizeExceeded,
) as exception:
return JSONResponse({"error": str(exception)}, status_code=400)

response_json.append(data_model.json())

return CustomJSONResponse(
content='{"response": [' + ",".join(response_json) + "]}"
)

###########################################################################

Expand Down
17 changes: 17 additions & 0 deletions piccolo_api/crud/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,23 @@ class MalformedQuery(Exception):
pass


class PageSizeExceeded(Exception):
"""
Raised when the page size requested too large (meaning we would return too
much data).
"""

pass


class RowRetrievalError(Exception):
"""
A catch all exception for several more specific errors.
"""

pass


def db_exception_handler(func: t.Callable[..., t.Coroutine]):
"""
A decorator which wraps an endpoint, and converts database exceptions
Expand Down
2 changes: 2 additions & 0 deletions piccolo_api/crud/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def __init__(
get_new: t.List[ValidatorFunction] = [],
get_schema: t.List[ValidatorFunction] = [],
get_count: t.List[ValidatorFunction] = [],
post_query: t.List[ValidatorFunction] = [],
extra_context: t.Dict[str, t.Any] = {},
):
self.every = every
Expand All @@ -76,6 +77,7 @@ def __init__(
self.get_new = get_new
self.get_schema = get_schema
self.get_count = get_count
self.post_query = post_query
self.extra_context = extra_context


Expand Down
40 changes: 38 additions & 2 deletions piccolo_api/fastapi/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def __init__(
patch: t.Dict[str, t.Any] = {},
get_single: t.Dict[str, t.Any] = {},
delete_single: t.Dict[str, t.Any] = {},
post_query: t.Dict[str, t.Any] = {},
):
self.all_routes = all_routes
self.get = get
Expand All @@ -50,6 +51,7 @@ def __init__(
self.patch = patch
self.get_single = get_single
self.delete_single = delete_single
self.post_query = post_query

def get_kwargs(self, endpoint_name: str) -> t.Dict[str, t.Any]:
"""
Expand Down Expand Up @@ -227,6 +229,24 @@ async def schema(request: Request):
**fastapi_kwargs.get_kwargs("get"),
)

#######################################################################
# Root - Post Query

async def post_query(request: Request):
"""
Post a query, which lets you retrieve multiple responses in one
request.
"""
return await piccolo_crud.post_query(request=request)

fastapi_app.add_api_route(
path=self.join_urls(root_url, "/query/"),
endpoint=post_query,
methods=["POST"],
response_model=t.Dict[str, t.Any],
**fastapi_kwargs.get_kwargs("post_query"),
)

#######################################################################
# Root - References

Expand Down Expand Up @@ -441,14 +461,23 @@ def modify_signature(
name=f"{field_name}__operator",
kind=Parameter.POSITIONAL_OR_KEYWORD,
default=Query(
default=None,
default="e",
description=(
f"Which operator to use for `{field_name}`. "
"The options are `e` (equals - default) `lt`, "
"`lte`, `gt`, `gte`, `is_null`, and "
"`not_null`."
),
),
annotation=t.Literal[
"lt",
"lte",
"gt",
"gte",
"e",
"is_null",
"not_null",
],
)
)
else:
Expand All @@ -463,6 +492,10 @@ def modify_signature(
"The options are `is_null`, and `not_null`."
),
),
annotation=t.Literal[
"is_null",
"not_null",
],
)
)

Expand All @@ -474,13 +507,16 @@ def modify_signature(
name=f"{field_name}__match",
kind=Parameter.POSITIONAL_OR_KEYWORD,
default=Query(
default=None,
default="contains",
description=(
f"Specifies how `{field_name}` should be "
"matched - `contains` (default), `exact`, "
"`starts`, `ends`."
),
),
annotation=t.Literal[
"contains", "exact", "starts", "ends"
],
)
)

Expand Down
16 changes: 9 additions & 7 deletions tests/crud/test_crud_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -756,11 +756,13 @@ def test_visible_fields(self):
)
self.assertEqual(response.status_code, 400)
self.assertEqual(
response.content,
(
b"No matching column found with name == foobar - the column "
b"options are ('id', 'name', 'rating')."
),
response.json(),
{
"error": (
"No matching column found with name == foobar - the "
"column options are ('id', 'name', 'rating')."
),
},
)

def test_visible_fields_with_join(self):
Expand All @@ -775,7 +777,7 @@ def test_visible_fields_with_join(self):
params={"__visible_fields": "name,movie.name", "__order": "id"},
)
self.assertEqual(response.status_code, 400)
self.assertEqual(response.content, b"Max join depth exceeded")
self.assertEqual(response.json(), {"error": "Max join depth exceeded"})

# Test 2 - should work as `max_joins` is set:
client = TestClient(
Expand Down Expand Up @@ -865,7 +867,7 @@ def test_page_size_limit(self):
response = client.get(
"/", params={"__page_size": PiccoloCRUD.max_page_size + 1}
)
self.assertEqual(response.status_code, 403)
self.assertEqual(response.status_code, 400)
self.assertEqual(
response.json(), {"error": "The page size limit has been exceeded"}
)
Expand Down