diff --git a/piccolo_api/crud/endpoints.py b/piccolo_api/crud/endpoints.py index b2067cb5..d0d644cc 100644 --- a/piccolo_api/crud/endpoints.py +++ b/piccolo_api/crud/endpoints.py @@ -37,7 +37,7 @@ execute_post_hooks, ) -from .exceptions import MalformedQuery, db_exception_handler +from .exceptions import MalformedQuery, PageSizeExceeded, db_exception_handler from .serializers import Config, create_pydantic_model from .validators import Validators, apply_validators @@ -272,6 +272,7 @@ def __init__( Route(path="/schema/", endpoint=self.get_schema, methods=["GET"]), Route(path="/ids/", endpoint=self.get_ids, methods=["GET"]), Route(path="/count/", endpoint=self.get_count, methods=["GET"]), + Route(path="/query/", endpoint=self.post_query, methods=["POST"]), Route( path="/references/", endpoint=self.get_references, @@ -759,19 +760,12 @@ def _apply_filters( return query - @apply_validators - async def get_all( - self, request: Request, params: t.Optional[t.Dict[str, t.Any]] = None - ) -> Response: - """ - Get all rows - query parameters are used for filtering. - """ + async def _get_rows( + self, params: t.Optional[t.Dict[str, t.Any]] = None + ) -> t.Tuple[pydantic.BaseModel, t.Dict]: params = self._clean_data(params) if params else {} - try: - split_params = self._split_params(params) - except ParamException as exception: - return Response(str(exception), status_code=400) + split_params = self._split_params(params) # Visible fields visible_fields = split_params.visible_fields @@ -808,10 +802,7 @@ async def get_all( query = query.output(nested=True) # Apply filters - try: - query = t.cast(Select, self._apply_filters(query, split_params)) - except MalformedQuery as exception: - return Response(str(exception), status_code=400) + query = t.cast(Select, self._apply_filters(query, split_params)) # Ordering order_by = split_params.order_by @@ -829,10 +820,8 @@ async def get_all( page_size = split_params.page_size or self.page_size # If the page_size is greater than max_page_size return an error if page_size > self.max_page_size: - return JSONResponse( - {"error": "The page size limit has been exceeded"}, - status_code=403, - ) + raise PageSizeExceeded("The page size limit has been exceeded") + query = query.limit(page_size) page = split_params.page offset = 0 @@ -861,12 +850,114 @@ async def get_all( # We need to serialise it ourselves, in case there are datetime # fields. - json = self.pydantic_model_plural( - include_readable=include_readable, - include_columns=tuple(visible_fields), + data_model = self.pydantic_model_plural( + include_readable=split_params.include_readable, + include_columns=tuple(split_params.visible_fields) + if split_params.visible_fields + else tuple(), nested=nested, - )(rows=rows).json() - return CustomJSONResponse(json, headers=headers) + )(rows=rows) + + return (data_model, headers) + + @apply_validators + async def get_all( + self, request: Request, params: t.Optional[t.Dict[str, t.Any]] = None + ) -> Response: + """ + Get all rows - query parameters are used for filtering. + """ + try: + data_model, headers = await self._get_rows(params=params) + except (MalformedQuery, ParamException, PageSizeExceeded) as exception: + return JSONResponse({"error": str(exception)}, status_code=400) + + return CustomJSONResponse(data_model.json(), headers=headers) + + ########################################################################### + + @apply_validators + async def post_query(self, request: Request): + """ + This endpoint allows the user to pass multiple queries in one go via + JSON. This can save on network requests if a lot of queries need to + be performed. + + The data structure should be like: + + .. code-block:: javascript + + { + "queries": [ + { + "name": "Star Wars", + "name__match": "contains", + "__visible_fields": ["name"] + }, + { + "name": "Lord of the Rings", + "name__match": "contains", + "__visible_fields": ["name"] + } + ] + } + + In the above example, we query for each ``'Star Wars`'`` and + ``'Lord of the Rings'`` movie. The response looks like this: + + .. code-block:: javascript + + { + "response": [ + [ + { + "name": "Star Wars: Episode IV - A New Hope" + }, + { + "name": "Star Wars: Episode I - The Phantom Menace" + } + ], + [ + { + "name": "The Lord of the Rings: The Fellowship of the Ring" + }, + { + "name": "The Lord of the Rings: The Two Towers" + } + ] + ] + ] + } + + The results are returned in the same order as the queries. + + """ # noqa E501 + data = await request.json() + + queries = data.get("queries") + if queries is None or not isinstance(queries, list): + return JSONResponse( + content={"error": "Malformed query body"}, + status_code=400, + ) + + response_json: t.List[str] = [] + + for query in queries: + try: + data_model, _ = await self._get_rows(params=query) + except ( + MalformedQuery, + ParamException, + PageSizeExceeded, + ) as exception: + return JSONResponse({"error": str(exception)}, status_code=400) + + response_json.append(data_model.json()) + + return CustomJSONResponse( + content='{"response": [' + ",".join(response_json) + "]}" + ) ########################################################################### diff --git a/piccolo_api/crud/exceptions.py b/piccolo_api/crud/exceptions.py index 7934118c..7622851e 100644 --- a/piccolo_api/crud/exceptions.py +++ b/piccolo_api/crud/exceptions.py @@ -31,6 +31,23 @@ class MalformedQuery(Exception): pass +class PageSizeExceeded(Exception): + """ + Raised when the page size requested too large (meaning we would return too + much data). + """ + + pass + + +class RowRetrievalError(Exception): + """ + A catch all exception for several more specific errors. + """ + + pass + + def db_exception_handler(func: t.Callable[..., t.Coroutine]): """ A decorator which wraps an endpoint, and converts database exceptions diff --git a/piccolo_api/crud/validators.py b/piccolo_api/crud/validators.py index a86bbaa3..b9839397 100644 --- a/piccolo_api/crud/validators.py +++ b/piccolo_api/crud/validators.py @@ -61,6 +61,7 @@ def __init__( get_new: t.List[ValidatorFunction] = [], get_schema: t.List[ValidatorFunction] = [], get_count: t.List[ValidatorFunction] = [], + post_query: t.List[ValidatorFunction] = [], extra_context: t.Dict[str, t.Any] = {}, ): self.every = every @@ -76,6 +77,7 @@ def __init__( self.get_new = get_new self.get_schema = get_schema self.get_count = get_count + self.post_query = post_query self.extra_context = extra_context diff --git a/piccolo_api/fastapi/endpoints.py b/piccolo_api/fastapi/endpoints.py index 78d01ad9..627a68be 100644 --- a/piccolo_api/fastapi/endpoints.py +++ b/piccolo_api/fastapi/endpoints.py @@ -41,6 +41,7 @@ def __init__( patch: t.Dict[str, t.Any] = {}, get_single: t.Dict[str, t.Any] = {}, delete_single: t.Dict[str, t.Any] = {}, + post_query: t.Dict[str, t.Any] = {}, ): self.all_routes = all_routes self.get = get @@ -50,6 +51,7 @@ def __init__( self.patch = patch self.get_single = get_single self.delete_single = delete_single + self.post_query = post_query def get_kwargs(self, endpoint_name: str) -> t.Dict[str, t.Any]: """ @@ -227,6 +229,24 @@ async def schema(request: Request): **fastapi_kwargs.get_kwargs("get"), ) + ####################################################################### + # Root - Post Query + + async def post_query(request: Request): + """ + Post a query, which lets you retrieve multiple responses in one + request. + """ + return await piccolo_crud.post_query(request=request) + + fastapi_app.add_api_route( + path=self.join_urls(root_url, "/query/"), + endpoint=post_query, + methods=["POST"], + response_model=t.Dict[str, t.Any], + **fastapi_kwargs.get_kwargs("post_query"), + ) + ####################################################################### # Root - References @@ -441,7 +461,7 @@ def modify_signature( name=f"{field_name}__operator", kind=Parameter.POSITIONAL_OR_KEYWORD, default=Query( - default=None, + default="e", description=( f"Which operator to use for `{field_name}`. " "The options are `e` (equals - default) `lt`, " @@ -449,6 +469,15 @@ def modify_signature( "`not_null`." ), ), + annotation=t.Literal[ + "lt", + "lte", + "gt", + "gte", + "e", + "is_null", + "not_null", + ], ) ) else: @@ -463,6 +492,10 @@ def modify_signature( "The options are `is_null`, and `not_null`." ), ), + annotation=t.Literal[ + "is_null", + "not_null", + ], ) ) @@ -474,13 +507,16 @@ def modify_signature( name=f"{field_name}__match", kind=Parameter.POSITIONAL_OR_KEYWORD, default=Query( - default=None, + default="contains", description=( f"Specifies how `{field_name}` should be " "matched - `contains` (default), `exact`, " "`starts`, `ends`." ), ), + annotation=t.Literal[ + "contains", "exact", "starts", "ends" + ], ) ) diff --git a/tests/crud/test_crud_endpoints.py b/tests/crud/test_crud_endpoints.py index bcc5d0b5..8c67773d 100644 --- a/tests/crud/test_crud_endpoints.py +++ b/tests/crud/test_crud_endpoints.py @@ -756,11 +756,13 @@ def test_visible_fields(self): ) self.assertEqual(response.status_code, 400) self.assertEqual( - response.content, - ( - b"No matching column found with name == foobar - the column " - b"options are ('id', 'name', 'rating')." - ), + response.json(), + { + "error": ( + "No matching column found with name == foobar - the " + "column options are ('id', 'name', 'rating')." + ), + }, ) def test_visible_fields_with_join(self): @@ -775,7 +777,7 @@ def test_visible_fields_with_join(self): params={"__visible_fields": "name,movie.name", "__order": "id"}, ) self.assertEqual(response.status_code, 400) - self.assertEqual(response.content, b"Max join depth exceeded") + self.assertEqual(response.json(), {"error": "Max join depth exceeded"}) # Test 2 - should work as `max_joins` is set: client = TestClient( @@ -865,7 +867,7 @@ def test_page_size_limit(self): response = client.get( "/", params={"__page_size": PiccoloCRUD.max_page_size + 1} ) - self.assertEqual(response.status_code, 403) + self.assertEqual(response.status_code, 400) self.assertEqual( response.json(), {"error": "The page size limit has been exceeded"} )