diff --git a/docs/source/session_auth/examples/example.py b/docs/source/session_auth/examples/example.py index e4ba02f4..802a83ca 100644 --- a/docs/source/session_auth/examples/example.py +++ b/docs/source/session_auth/examples/example.py @@ -1,5 +1,5 @@ import datetime -import typing as t +from typing import Any from fastapi import FastAPI from fastapi.responses import JSONResponse @@ -71,16 +71,16 @@ # Example FastAPI endpoints and Pydantic models. -MovieModelIn: t.Any = create_pydantic_model( +MovieModelIn: Any = create_pydantic_model( table=Movie, model_name="MovieModelIn" ) -MovieModelOut: t.Any = create_pydantic_model( +MovieModelOut: Any = create_pydantic_model( table=Movie, include_default_columns=True, model_name="MovieModelOut" ) -@private_app.get("/movies/", response_model=t.List[MovieModelOut]) +@private_app.get("/movies/", response_model=list[MovieModelOut]) async def movies(): return await Movie.select().order_by(Movie._meta.primary_key) diff --git a/piccolo_api/change_password/endpoints.py b/piccolo_api/change_password/endpoints.py index dacac77f..ac754948 100644 --- a/piccolo_api/change_password/endpoints.py +++ b/piccolo_api/change_password/endpoints.py @@ -1,257 +1,257 @@ -from __future__ import annotations - -import os -import typing as t -from abc import ABCMeta, abstractmethod -from json import JSONDecodeError - -from jinja2 import Environment, FileSystemLoader -from starlette.endpoints import HTTPEndpoint, Request -from starlette.exceptions import HTTPException -from starlette.responses import ( - HTMLResponse, - PlainTextResponse, - RedirectResponse, -) -from starlette.status import HTTP_303_SEE_OTHER - -from piccolo_api.session_auth.tables import SessionsBase -from piccolo_api.shared.auth.styles import Styles - -if t.TYPE_CHECKING: # pragma: no cover - from jinja2 import Template - from starlette.responses import Response - - -CHANGE_PASSWORD_TEMPLATE_PATH = os.path.join( - os.path.dirname(os.path.dirname(__file__)), - "templates", - "change_password.html", -) - - -class ChangePasswordEndpoint(HTTPEndpoint, metaclass=ABCMeta): - @property - @abstractmethod - def _login_url(self) -> str: - raise NotImplementedError - - @property - @abstractmethod - def _change_password_template(self) -> Template: - raise NotImplementedError - - @property - @abstractmethod - def _styles(self) -> Styles: - raise NotImplementedError - - @property - @abstractmethod - def _session_table(self) -> t.Optional[t.Type[SessionsBase]]: - raise NotImplementedError - - @property - @abstractmethod - def _session_cookie_name(self) -> t.Optional[str]: - raise NotImplementedError - - @property - @abstractmethod - def _read_only(self) -> bool: - raise NotImplementedError - - def render_template( - self, - request: Request, - template_context: t.Dict[str, t.Any] = {}, - login_url: t.Optional[str] = None, - min_password_length: int = 6, - ) -> HTMLResponse: - # If CSRF middleware is present, we have to include a form field with - # the CSRF token. It only works if CSRFMiddleware has - # allow_form_param=True, otherwise it only looks for the token in the - # header. - csrftoken = request.scope.get("csrftoken") - csrf_cookie_name = request.scope.get("csrf_cookie_name") - - return HTMLResponse( - self._change_password_template.render( - csrftoken=csrftoken, - csrf_cookie_name=csrf_cookie_name, - request=request, - styles=self._styles, - username=request.user.user.username, - login_url=login_url, - min_password_length=min_password_length, - **template_context, - ) - ) - - async def get(self, request: Request) -> Response: - piccolo_user = request.user.user - if piccolo_user: - min_password_length = piccolo_user._min_password_length - return self.render_template( - request, min_password_length=min_password_length - ) - else: - return RedirectResponse(self._login_url) - - async def post(self, request: Request) -> Response: - if self._read_only: - return PlainTextResponse( - content="Running in read only mode", status_code=405 - ) - - # Some middleware (for example CSRF) has already awaited the request - # body, and adds it to the request. - body: t.Any = request.scope.get("form") - - if not body: - try: - body = await request.json() - except JSONDecodeError: - body = await request.form() - - current_password = body.get("current_password", None) - new_password = body.get("new_password", None) - confirm_new_password = body.get("confirm_new_password", None) - - piccolo_user = request.user.user - min_password_length = piccolo_user._min_password_length - - if ( - (not current_password) - or (not new_password) - or (not confirm_new_password) - ): - error = "Form is invalid. Missing one or more fields." - if body.get("format") == "html": - return self.render_template( - request, - template_context={"error": error}, - min_password_length=min_password_length, - ) - raise HTTPException(status_code=422, detail=error) - - if len(new_password) < min_password_length: - error = ( - f"Password must be at least {min_password_length} characters " - "long." - ) - if body.get("format") == "html": - return self.render_template( - request, - min_password_length=min_password_length, - template_context={"error": error}, - ) - else: - raise HTTPException( - status_code=422, - detail=error, - ) - - if confirm_new_password != new_password: - error = "Passwords do not match." - - if body.get("format") == "html": - return self.render_template( - request, - min_password_length=min_password_length, - template_context={"error": error}, - ) - else: - raise HTTPException(status_code=422, detail=error) - - if not await piccolo_user.login( - username=piccolo_user.username, password=current_password - ): - error = "Incorrect password." - if body.get("format") == "html": - return self.render_template( - request, - min_password_length=min_password_length, - template_context={"error": error}, - ) - raise HTTPException(detail=error, status_code=422) - - await piccolo_user.update_password( - user=request.user.user_id, password=new_password - ) - - ####################################################################### - # After the password changes, we invalidate the session and - # redirect the user to the login endpoint. - - session_table = self._session_table - if session_table: - # This will invalidate all of the user's sessions on all devices. - await session_table.delete().where( - session_table.user_id == piccolo_user.id - ) - - response = RedirectResponse( - url=self._login_url, status_code=HTTP_303_SEE_OTHER - ) - - if self._session_cookie_name: - response.delete_cookie(self._session_cookie_name) - - return response - - -def change_password( - login_url: str = "/login/", - session_table: t.Optional[t.Type[SessionsBase]] = SessionsBase, - session_cookie_name: t.Optional[str] = "id", - template_path: t.Optional[str] = None, - styles: t.Optional[Styles] = None, - read_only: bool = False, -) -> t.Type[ChangePasswordEndpoint]: - """ - An endpoint for changing passwords. - - :param login_url: - Where to redirect the user to after successfully changing their - password. - :param session_table: - If provided, when the password is changed, the sessions for the user - will be invalidated in the database. - :param session_cookie_name: - If provided, when the password is changed, the session cookie with this - name will be deleted. - :param template_path: - If you want to override the default change password HTML template, - you can do so by specifying the absolute path to a custom template. - For example ``'/some_directory/change_password.html'``. Refer to - the default template at ``piccolo_api/templates/change_password.html`` - as a basis for your custom template. - :param styles: - Modify the appearance of the HTML template using CSS. - :read_only: - If ``True``, the endpoint only responds to GET requests. It's not - commonly needed, except when running demos. - - """ - template_path = ( - CHANGE_PASSWORD_TEMPLATE_PATH - if template_path is None - else template_path - ) - - directory, filename = os.path.split(template_path) - environment = Environment( - loader=FileSystemLoader(directory), autoescape=True - ) - change_password_template = environment.get_template(filename) - - class _ChangePasswordEndpoint(ChangePasswordEndpoint): - _login_url = login_url - _change_password_template = change_password_template - _styles = styles or Styles() - _session_table = session_table - _session_cookie_name = session_cookie_name - _read_only = read_only - - return _ChangePasswordEndpoint +from __future__ import annotations + +import os +from abc import ABCMeta, abstractmethod +from json import JSONDecodeError +from typing import TYPE_CHECKING, Any, Optional + +from jinja2 import Environment, FileSystemLoader +from starlette.endpoints import HTTPEndpoint, Request +from starlette.exceptions import HTTPException +from starlette.responses import ( + HTMLResponse, + PlainTextResponse, + RedirectResponse, +) +from starlette.status import HTTP_303_SEE_OTHER + +from piccolo_api.session_auth.tables import SessionsBase +from piccolo_api.shared.auth.styles import Styles + +if TYPE_CHECKING: # pragma: no cover + from jinja2 import Template + from starlette.responses import Response + + +CHANGE_PASSWORD_TEMPLATE_PATH = os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "templates", + "change_password.html", +) + + +class ChangePasswordEndpoint(HTTPEndpoint, metaclass=ABCMeta): + @property + @abstractmethod + def _login_url(self) -> str: + raise NotImplementedError + + @property + @abstractmethod + def _change_password_template(self) -> Template: + raise NotImplementedError + + @property + @abstractmethod + def _styles(self) -> Styles: + raise NotImplementedError + + @property + @abstractmethod + def _session_table(self) -> Optional[type[SessionsBase]]: + raise NotImplementedError + + @property + @abstractmethod + def _session_cookie_name(self) -> Optional[str]: + raise NotImplementedError + + @property + @abstractmethod + def _read_only(self) -> bool: + raise NotImplementedError + + def render_template( + self, + request: Request, + template_context: dict[str, Any] = {}, + login_url: Optional[str] = None, + min_password_length: int = 6, + ) -> HTMLResponse: + # If CSRF middleware is present, we have to include a form field with + # the CSRF token. It only works if CSRFMiddleware has + # allow_form_param=True, otherwise it only looks for the token in the + # header. + csrftoken = request.scope.get("csrftoken") + csrf_cookie_name = request.scope.get("csrf_cookie_name") + + return HTMLResponse( + self._change_password_template.render( + csrftoken=csrftoken, + csrf_cookie_name=csrf_cookie_name, + request=request, + styles=self._styles, + username=request.user.user.username, + login_url=login_url, + min_password_length=min_password_length, + **template_context, + ) + ) + + async def get(self, request: Request) -> Response: + piccolo_user = request.user.user + if piccolo_user: + min_password_length = piccolo_user._min_password_length + return self.render_template( + request, min_password_length=min_password_length + ) + else: + return RedirectResponse(self._login_url) + + async def post(self, request: Request) -> Response: + if self._read_only: + return PlainTextResponse( + content="Running in read only mode", status_code=405 + ) + + # Some middleware (for example CSRF) has already awaited the request + # body, and adds it to the request. + body: Any = request.scope.get("form") + + if not body: + try: + body = await request.json() + except JSONDecodeError: + body = await request.form() + + current_password = body.get("current_password", None) + new_password = body.get("new_password", None) + confirm_new_password = body.get("confirm_new_password", None) + + piccolo_user = request.user.user + min_password_length = piccolo_user._min_password_length + + if ( + (not current_password) + or (not new_password) + or (not confirm_new_password) + ): + error = "Form is invalid. Missing one or more fields." + if body.get("format") == "html": + return self.render_template( + request, + template_context={"error": error}, + min_password_length=min_password_length, + ) + raise HTTPException(status_code=422, detail=error) + + if len(new_password) < min_password_length: + error = ( + f"Password must be at least {min_password_length} characters " + "long." + ) + if body.get("format") == "html": + return self.render_template( + request, + min_password_length=min_password_length, + template_context={"error": error}, + ) + else: + raise HTTPException( + status_code=422, + detail=error, + ) + + if confirm_new_password != new_password: + error = "Passwords do not match." + + if body.get("format") == "html": + return self.render_template( + request, + min_password_length=min_password_length, + template_context={"error": error}, + ) + else: + raise HTTPException(status_code=422, detail=error) + + if not await piccolo_user.login( + username=piccolo_user.username, password=current_password + ): + error = "Incorrect password." + if body.get("format") == "html": + return self.render_template( + request, + min_password_length=min_password_length, + template_context={"error": error}, + ) + raise HTTPException(detail=error, status_code=422) + + await piccolo_user.update_password( + user=request.user.user_id, password=new_password + ) + + ####################################################################### + # After the password changes, we invalidate the session and + # redirect the user to the login endpoint. + + session_table = self._session_table + if session_table: + # This will invalidate all of the user's sessions on all devices. + await session_table.delete().where( + session_table.user_id == piccolo_user.id + ) + + response = RedirectResponse( + url=self._login_url, status_code=HTTP_303_SEE_OTHER + ) + + if self._session_cookie_name: + response.delete_cookie(self._session_cookie_name) + + return response + + +def change_password( + login_url: str = "/login/", + session_table: Optional[type[SessionsBase]] = SessionsBase, + session_cookie_name: Optional[str] = "id", + template_path: Optional[str] = None, + styles: Optional[Styles] = None, + read_only: bool = False, +) -> type[ChangePasswordEndpoint]: + """ + An endpoint for changing passwords. + + :param login_url: + Where to redirect the user to after successfully changing their + password. + :param session_table: + If provided, when the password is changed, the sessions for the user + will be invalidated in the database. + :param session_cookie_name: + If provided, when the password is changed, the session cookie with this + name will be deleted. + :param template_path: + If you want to override the default change password HTML template, + you can do so by specifying the absolute path to a custom template. + For example ``'/some_directory/change_password.html'``. Refer to + the default template at ``piccolo_api/templates/change_password.html`` + as a basis for your custom template. + :param styles: + Modify the appearance of the HTML template using CSS. + :read_only: + If ``True``, the endpoint only responds to GET requests. It's not + commonly needed, except when running demos. + + """ + template_path = ( + CHANGE_PASSWORD_TEMPLATE_PATH + if template_path is None + else template_path + ) + + directory, filename = os.path.split(template_path) + environment = Environment( + loader=FileSystemLoader(directory), autoescape=True + ) + change_password_template = environment.get_template(filename) + + class _ChangePasswordEndpoint(ChangePasswordEndpoint): + _login_url = login_url + _change_password_template = change_password_template + _styles = styles or Styles() + _session_table = session_table + _session_cookie_name = session_cookie_name + _read_only = read_only + + return _ChangePasswordEndpoint diff --git a/piccolo_api/crud/endpoints.py b/piccolo_api/crud/endpoints.py index 42509091..1daa682f 100644 --- a/piccolo_api/crud/endpoints.py +++ b/piccolo_api/crud/endpoints.py @@ -1,10 +1,10 @@ from __future__ import annotations import itertools -import typing as t import uuid from collections import defaultdict from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Optional, Union, cast import pydantic from piccolo.apps.user.tables import BaseUser @@ -41,7 +41,7 @@ from .exceptions import MalformedQuery, db_exception_handler from .validators import Validators, apply_validators -if t.TYPE_CHECKING: # pragma: no cover +if TYPE_CHECKING: # pragma: no cover from piccolo.query.methods.count import Count from piccolo.query.methods.objects import Objects from starlette.datastructures import QueryParams @@ -62,7 +62,7 @@ MATCH_TYPES = ("contains", "exact", "starts", "ends") -PK_TYPES = t.Union[str, uuid.UUID, int] +PK_TYPES = Union[str, uuid.UUID, int] class CustomJSONResponse(Response): @@ -96,7 +96,7 @@ def to_dict(self) -> HashableDict: ) return HashableDict(column=column, ascending=self.ascending) - def __eq__(self, value: t.Any) -> bool: + def __eq__(self, value: Any) -> bool: if not isinstance(value, OrderBy): return False @@ -105,28 +105,28 @@ def __eq__(self, value: t.Any) -> bool: @dataclass class Params: - operators: t.Dict[str, t.Type[ComparisonOperator]] = field( + operators: dict[str, type[ComparisonOperator]] = field( default_factory=lambda: defaultdict(lambda: Equal) ) - match_types: t.Dict[str, str] = field( + match_types: dict[str, str] = field( default_factory=lambda: defaultdict(lambda: MATCH_TYPES[0]) ) - fields: t.Dict[str, t.Any] = field(default_factory=dict) - order_by: t.Optional[t.List[OrderBy]] = None + fields: dict[str, Any] = field(default_factory=dict) + order_by: Optional[list[OrderBy]] = None include_readable: bool = False page: int = 1 - page_size: t.Optional[int] = None - visible_fields: t.Optional[t.List[Column]] = None + page_size: Optional[int] = None + visible_fields: Optional[list[Column]] = None range_header: bool = False range_header_name: str = field(default="") def get_visible_fields_options( - table: t.Type[Table], + table: type[Table], exclude_secrets: bool = False, max_joins: int = 0, prefix: str = "", -) -> t.Tuple[str, ...]: +) -> tuple[str, ...]: """ In the schema, we tell the user which fields are allowed with the ``__visible_fields`` GET parameter. This function extracts the column @@ -173,15 +173,15 @@ class PiccoloCRUD(Router): def __init__( self, - table: t.Type[Table], + table: type[Table], read_only: bool = True, allow_bulk_delete: bool = False, page_size: int = 15, exclude_secrets: bool = True, - validators: t.Optional[Validators] = None, - schema_extra: t.Optional[t.Dict[str, t.Any]] = None, + validators: Optional[Validators] = None, + schema_extra: Optional[dict[str, Any]] = None, max_joins: int = 0, - hooks: t.Optional[t.List[Hook]] = None, + hooks: Optional[list[Hook]] = None, ) -> None: """ :param table: @@ -268,7 +268,7 @@ def __init__( ["POST", "DELETE"] if allow_bulk_delete else ["POST"] ) - routes: t.List[BaseRoute] = [ + routes: list[BaseRoute] = [ Route(path="/", endpoint=self.root, methods=root_methods), Route(path="/schema/", endpoint=self.get_schema, methods=["GET"]), Route(path="/ids/", endpoint=self.get_ids, methods=["GET"]), @@ -293,7 +293,7 @@ def __init__( ########################################################################### @property - def pydantic_model(self) -> t.Type[pydantic.BaseModel]: + def pydantic_model(self) -> type[pydantic.BaseModel]: """ Useful for serialising inbound data from POST and PUT requests. """ @@ -307,9 +307,9 @@ def pydantic_model(self) -> t.Type[pydantic.BaseModel]: def _pydantic_model_output( self, include_readable: bool = False, - include_columns: t.Tuple[Column, ...] = (), - nested: t.Union[bool, t.Tuple[ForeignKey, ...]] = False, - ) -> t.Type[pydantic.BaseModel]: + include_columns: tuple[Column, ...] = (), + nested: Union[bool, tuple[ForeignKey, ...]] = False, + ) -> type[pydantic.BaseModel]: return create_pydantic_model( self.table, include_default_columns=True, @@ -320,7 +320,7 @@ def _pydantic_model_output( ) @property - def pydantic_model_output(self) -> t.Type[pydantic.BaseModel]: + def pydantic_model_output(self) -> type[pydantic.BaseModel]: """ Contains the default columns, which is required when exporting data (for example, in a GET request). @@ -328,7 +328,7 @@ def pydantic_model_output(self) -> t.Type[pydantic.BaseModel]: return self._pydantic_model_output() @property - def pydantic_model_optional(self) -> t.Type[pydantic.BaseModel]: + def pydantic_model_optional(self) -> type[pydantic.BaseModel]: """ All fields are optional, which is useful for PATCH requests, which may only update some fields. @@ -341,7 +341,7 @@ def pydantic_model_optional(self) -> t.Type[pydantic.BaseModel]: ) @property - def pydantic_model_filters(self) -> t.Type[pydantic.BaseModel]: + def pydantic_model_filters(self) -> type[pydantic.BaseModel]: """ Used for serialising query params, which are used for filtering. @@ -381,14 +381,14 @@ def pydantic_model_filters(self) -> t.Type[pydantic.BaseModel]: __base__=base_model, **{ i._meta.name: ( - t.Optional[t.List[i._get_inner_value_type()]], # type: ignore # noqa: E501 + Optional[list[i._get_inner_value_type()]], # type: ignore # noqa: E501 pydantic.Field(default=None), ) for i in multidimensional_array_columns }, **{ i._meta.name: ( - t.Optional[str], + Optional[str], pydantic.Field(default=None), ) for i in email_columns @@ -400,13 +400,13 @@ def pydantic_model_filters(self) -> t.Type[pydantic.BaseModel]: def pydantic_model_plural( self, include_readable=False, - include_columns: t.Tuple[Column, ...] = (), - nested: t.Union[bool, t.Tuple[ForeignKey, ...]] = False, - ) -> t.Type[pydantic.BaseModel]: + include_columns: tuple[Column, ...] = (), + nested: Union[bool, tuple[ForeignKey, ...]] = False, + ) -> type[pydantic.BaseModel]: """ This is for when we want to serialise many copies of the model. """ - base_model: t.Any = create_pydantic_model( + base_model: Any = create_pydantic_model( self.table, include_default_columns=True, include_readable=include_readable, @@ -419,7 +419,7 @@ def pydantic_model_plural( __config__=pydantic.config.ConfigDict( arbitrary_types_allowed=True ), - rows=(t.List[base_model], None), + rows=(list[base_model], None), ) @apply_validators @@ -444,11 +444,11 @@ async def get_ids(self, request: Request) -> Response: """ readable = self.table.get_readable() - query: t.Any = self.table.select().columns( + query: Any = self.table.select().columns( self.table._meta.primary_key._meta.name, readable ) - limit: t.Union[t.Optional[str], int] = request.query_params.get( + limit: Union[Optional[str], int] = request.query_params.get( "limit", None ) if limit is not None: @@ -461,7 +461,7 @@ async def get_ids(self, request: Request) -> Response: else: limit = "ALL" - offset: t.Union[t.Optional[str], int] = request.query_params.get( + offset: Union[Optional[str], int] = request.query_params.get( "offset", None ) if offset is not None: @@ -479,7 +479,7 @@ async def get_ids(self, request: Request) -> Response: # Readable doesn't currently have a 'like' method, so we do it # manually. if self.table._meta.db.engine_type == "postgres": - query = t.cast( + query = cast( Select, self.table.raw( ( @@ -499,7 +499,7 @@ async def get_ids(self, request: Request) -> Response: ) if isinstance(limit, int): sql += f" LIMIT {limit} OFFSET {offset}" - query = t.cast( + query = cast( Select, self.table.raw(sql, f"%{search_term.upper()}%") ) else: @@ -559,7 +559,7 @@ async def get_count(self, request: Request) -> Response: ########################################################################### - def _parse_params(self, params: QueryParams) -> t.Dict[str, t.Any]: + def _parse_params(self, params: QueryParams) -> dict[str, Any]: """ The GET params may contain multiple values for each parameter name. For example: @@ -575,7 +575,7 @@ def _parse_params(self, params: QueryParams) -> t.Dict[str, t.Any]: multiple are present. """ - params_map: t.Dict[str, t.Any] = { + params_map: dict[str, Any] = { i[0]: [j[1] for j in i[1]] for i in itertools.groupby(params.multi_items(), lambda x: x[0]) } @@ -615,7 +615,7 @@ async def root(self, request: Request) -> Response: ########################################################################### - def _split_params(self, params: t.Dict[str, t.Any]) -> Params: + def _split_params(self, params: dict[str, Any]) -> Params: """ Some parameters reference fields, and others provide instructions on how to perform the query (e.g. which operator to use). @@ -675,8 +675,8 @@ def _split_params(self, params: t.Dict[str, t.Any]) -> Params: # separated string e.g. 'name,created_on'. The value may # already be a list if the parameter is passed in multiple # times for example `?__order=name?__order=created_on`. - order_by: t.List[OrderBy] = [] - sub_values: t.List[str] + order_by: list[OrderBy] = [] + sub_values: list[str] if isinstance(value, str): sub_values = value.split(",") @@ -721,7 +721,7 @@ def _split_params(self, params: t.Dict[str, t.Any]) -> Params: continue if key == "__visible_fields": - column_names: t.List[str] + column_names: list[str] if isinstance(value, str): column_names = value.split(",") @@ -762,8 +762,8 @@ def _split_params(self, params: t.Dict[str, t.Any]) -> Params: return response def _apply_filters( - self, query: t.Union[Select, Count, Objects, Delete], params: Params - ) -> t.Union[Select, Count, Objects, Delete]: + self, query: Union[Select, Count, Objects, Delete], params: Params + ) -> Union[Select, Count, Objects, Delete]: """ Apply the HTTP query parameters to the Piccolo query object, then return it. @@ -821,7 +821,7 @@ def _apply_filters( @apply_validators async def get_all( - self, request: Request, params: t.Optional[t.Dict[str, t.Any]] = None + self, request: Request, params: Optional[dict[str, Any]] = None ) -> Response: """ Get all rows - query parameters are used for filtering. @@ -835,7 +835,7 @@ async def get_all( # Visible fields visible_fields = split_params.visible_fields - nested: t.Union[bool, t.Tuple[Column, ...]] + nested: Union[bool, tuple[Column, ...]] if visible_fields: nested = tuple( i._meta.call_chain[-1] @@ -871,7 +871,7 @@ async def get_all( # Apply filters try: - query = t.cast(Select, self._apply_filters(query, split_params)) + query = cast(Select, self._apply_filters(query, split_params)) except MalformedQuery as exception: return Response(str(exception), status_code=400) @@ -932,8 +932,8 @@ async def get_all( ########################################################################### - def _clean_data(self, data: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]: - cleaned_data: t.Dict[str, t.Any] = {} + def _clean_data(self, data: dict[str, Any]) -> dict[str, Any]: + cleaned_data: dict[str, Any] = {} for key, value in data.items(): value = ( @@ -948,7 +948,7 @@ def _clean_data(self, data: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]: @apply_validators @db_exception_handler async def post_single( - self, request: Request, data: t.Dict[str, t.Any] + self, request: Request, data: dict[str, Any] ) -> Response: """ Adds a single row, if the id doesn't already exist. @@ -987,7 +987,7 @@ async def post_single( @apply_validators async def delete_all( - self, request: Request, params: t.Optional[t.Dict[str, t.Any]] = None + self, request: Request, params: Optional[dict[str, Any]] = None ) -> Response: """ Deletes all rows - query parameters are used for filtering. @@ -1115,7 +1115,7 @@ async def get_single(self, request: Request, row_id: PK_TYPES) -> Response: return Response(str(exception), status_code=400) # Visible fields - nested: t.Union[bool, t.Tuple[ForeignKey, ...]] + nested: Union[bool, tuple[ForeignKey, ...]] visible_fields = split_params.visible_fields if visible_fields: nested = tuple( @@ -1168,7 +1168,7 @@ async def get_single(self, request: Request, row_id: PK_TYPES) -> Response: @apply_validators @db_exception_handler async def put_single( - self, request: Request, row_id: PK_TYPES, data: t.Dict[str, t.Any] + self, request: Request, row_id: PK_TYPES, data: dict[str, Any] ) -> Response: """ Replaces an existing row. We don't allow new resources to be created. @@ -1197,7 +1197,7 @@ async def put_single( @apply_validators @db_exception_handler async def patch_single( - self, request: Request, row_id: PK_TYPES, data: t.Dict[str, t.Any] + self, request: Request, row_id: PK_TYPES, data: dict[str, Any] ) -> Response: """ Patch a single row. @@ -1282,7 +1282,7 @@ async def delete_single( except ValueError: return Response("Unable to delete the resource.", status_code=500) - def __eq__(self, other: t.Any) -> bool: + def __eq__(self, other: Any) -> bool: """ To keep LGTM happy. """ diff --git a/piccolo_api/crud/exceptions.py b/piccolo_api/crud/exceptions.py index e0a7ddb3..42ccb9ba 100644 --- a/piccolo_api/crud/exceptions.py +++ b/piccolo_api/crud/exceptions.py @@ -1,6 +1,6 @@ import functools import logging -import typing as t +from collections.abc import Callable, Coroutine from sqlite3 import IntegrityError from starlette.responses import JSONResponse @@ -38,7 +38,7 @@ class MalformedQuery(Exception): pass -def db_exception_handler(func: t.Callable[..., t.Coroutine]): +def db_exception_handler(func: Callable[..., Coroutine]): """ A decorator which wraps an endpoint, and converts database exceptions into HTTP responses. diff --git a/piccolo_api/crud/hooks.py b/piccolo_api/crud/hooks.py index e071f66a..40d30b5f 100644 --- a/piccolo_api/crud/hooks.py +++ b/piccolo_api/crud/hooks.py @@ -1,6 +1,7 @@ import inspect -import typing as t +from collections.abc import Callable from enum import Enum +from typing import Any from piccolo.table import Table from starlette.requests import Request @@ -17,20 +18,20 @@ class HookType(Enum): class Hook: - def __init__(self, hook_type: HookType, callable: t.Callable) -> None: + def __init__(self, hook_type: HookType, callable: Callable) -> None: self.hook_type = hook_type self.callable = callable async def execute_post_hooks( - hooks: t.Dict[HookType, t.List[Hook]], + hooks: dict[HookType, list[Hook]], hook_type: HookType, row: Table, request: Request, ): for hook in hooks.get(hook_type, []): signature = inspect.signature(hook.callable) - kwargs: t.Dict[str, t.Any] = dict(row=row) + kwargs: dict[str, Any] = dict(row=row) # Include request in hook call arguments if possible if {i for i in signature.parameters.keys()}.intersection( {"kwargs", "request"} @@ -44,12 +45,12 @@ async def execute_post_hooks( async def execute_patch_hooks( - hooks: t.Dict[HookType, t.List[Hook]], + hooks: dict[HookType, list[Hook]], hook_type: HookType, - row_id: t.Any, - values: t.Dict[t.Any, t.Any], + row_id: Any, + values: dict[Any, Any], request: Request, -) -> t.Dict[t.Any, t.Any]: +) -> dict[Any, Any]: for hook in hooks.get(hook_type, []): signature = inspect.signature(hook.callable) kwargs = dict(row_id=row_id, values=values) @@ -66,9 +67,9 @@ async def execute_patch_hooks( async def execute_delete_hooks( - hooks: t.Dict[HookType, t.List[Hook]], + hooks: dict[HookType, list[Hook]], hook_type: HookType, - row_id: t.Any, + row_id: Any, request: Request, ): for hook in hooks.get(hook_type, []): diff --git a/piccolo_api/crud/validators.py b/piccolo_api/crud/validators.py index a6c1cdc5..85c6ef5c 100644 --- a/piccolo_api/crud/validators.py +++ b/piccolo_api/crud/validators.py @@ -2,19 +2,18 @@ import functools import inspect -import typing as t +from collections.abc import Callable, Coroutine +from typing import TYPE_CHECKING, Any, Union from piccolo.utils.sync import run_sync from starlette.exceptions import HTTPException from starlette.requests import Request -if t.TYPE_CHECKING: # pragma: no cover +if TYPE_CHECKING: # pragma: no cover from .endpoints import PiccoloCRUD -ValidatorFunction = t.Callable[ - ["PiccoloCRUD", Request], t.Union[t.Coroutine, None] -] +ValidatorFunction = Callable[["PiccoloCRUD", Request], Union[Coroutine, None]] class Validators: @@ -48,20 +47,20 @@ async def validator_2(piccolo_crud: PiccoloCRUD, request: Request): def __init__( self, - every: t.List[ValidatorFunction] = [], - get_single: t.List[ValidatorFunction] = [], - put_single: t.List[ValidatorFunction] = [], - patch_single: t.List[ValidatorFunction] = [], - delete_single: t.List[ValidatorFunction] = [], - post_single: t.List[ValidatorFunction] = [], - get_all: t.List[ValidatorFunction] = [], - delete_all: t.List[ValidatorFunction] = [], - get_references: t.List[ValidatorFunction] = [], - get_ids: t.List[ValidatorFunction] = [], - get_new: t.List[ValidatorFunction] = [], - get_schema: t.List[ValidatorFunction] = [], - get_count: t.List[ValidatorFunction] = [], - extra_context: t.Dict[str, t.Any] = {}, + every: list[ValidatorFunction] = [], + get_single: list[ValidatorFunction] = [], + put_single: list[ValidatorFunction] = [], + patch_single: list[ValidatorFunction] = [], + delete_single: list[ValidatorFunction] = [], + post_single: list[ValidatorFunction] = [], + get_all: list[ValidatorFunction] = [], + delete_all: list[ValidatorFunction] = [], + get_references: list[ValidatorFunction] = [], + get_ids: list[ValidatorFunction] = [], + get_new: list[ValidatorFunction] = [], + get_schema: list[ValidatorFunction] = [], + get_count: list[ValidatorFunction] = [], + extra_context: dict[str, Any] = {}, ): self.every = every self.get_single = get_single diff --git a/piccolo_api/csp/middleware.py b/piccolo_api/csp/middleware.py index d4cb537a..42e66963 100644 --- a/piccolo_api/csp/middleware.py +++ b/piccolo_api/csp/middleware.py @@ -1,16 +1,16 @@ from __future__ import annotations -import typing as t from dataclasses import dataclass from functools import wraps +from typing import TYPE_CHECKING, Optional -if t.TYPE_CHECKING: # pragma: no cover +if TYPE_CHECKING: # pragma: no cover from starlette.types import ASGIApp, Message, Receive, Scope, Send @dataclass class CSPConfig: - report_uri: t.Optional[bytes] = None + report_uri: Optional[bytes] = None default_src: str = "self" diff --git a/piccolo_api/csrf/middleware.py b/piccolo_api/csrf/middleware.py index 2ac62eed..9b582dbb 100644 --- a/piccolo_api/csrf/middleware.py +++ b/piccolo_api/csrf/middleware.py @@ -1,6 +1,5 @@ from __future__ import annotations -import typing as t import uuid from collections.abc import Sequence @@ -42,7 +41,7 @@ def get_new_token() -> str: def __init__( self, app: ASGIApp, - allowed_hosts: t.Sequence[str] = [], + allowed_hosts: Sequence[str] = [], cookie_name: str = DEFAULT_COOKIE_NAME, header_name: str = DEFAULT_HEADER_NAME, max_age: int = ONE_YEAR, diff --git a/piccolo_api/encryption/providers.py b/piccolo_api/encryption/providers.py index 1b0cc7f4..ef5d2ee5 100644 --- a/piccolo_api/encryption/providers.py +++ b/piccolo_api/encryption/providers.py @@ -1,10 +1,10 @@ from __future__ import annotations import logging -import typing as t from abc import ABCMeta, abstractmethod +from typing import TYPE_CHECKING -if t.TYPE_CHECKING: +if TYPE_CHECKING: import nacl from cryptography.fernet import Fernet @@ -12,7 +12,7 @@ logger = logging.getLogger(__name__) -def get_fernet_class() -> t.Type[Fernet]: # type: ignore +def get_fernet_class() -> type[Fernet]: # type: ignore try: from cryptography.fernet import Fernet except ImportError as e: diff --git a/piccolo_api/fastapi/endpoints.py b/piccolo_api/fastapi/endpoints.py index 810d5fea..2d8e4205 100644 --- a/piccolo_api/fastapi/endpoints.py +++ b/piccolo_api/fastapi/endpoints.py @@ -5,11 +5,12 @@ from __future__ import annotations import datetime -import typing as t from collections import defaultdict +from collections.abc import Callable from decimal import Decimal from enum import Enum from inspect import Parameter, Signature, isclass +from typing import Any, Optional, Union from fastapi import APIRouter, FastAPI, Request, status from fastapi.params import Query @@ -19,7 +20,7 @@ from piccolo_api.crud.endpoints import PiccoloCRUD from piccolo_api.utils.types import get_type -ANNOTATIONS: t.DefaultDict = defaultdict(dict) +ANNOTATIONS: defaultdict = defaultdict(dict) class HTTPMethod(str, Enum): @@ -34,14 +35,14 @@ class FastAPIKwargs: def __init__( self, - all_routes: t.Dict[str, t.Any] = {}, - get: t.Dict[str, t.Any] = {}, - delete: t.Dict[str, t.Any] = {}, - post: t.Dict[str, t.Any] = {}, - put: t.Dict[str, t.Any] = {}, - patch: t.Dict[str, t.Any] = {}, - get_single: t.Dict[str, t.Any] = {}, - delete_single: t.Dict[str, t.Any] = {}, + all_routes: dict[str, Any] = {}, + get: dict[str, Any] = {}, + delete: dict[str, Any] = {}, + post: dict[str, Any] = {}, + put: dict[str, Any] = {}, + patch: dict[str, Any] = {}, + get_single: dict[str, Any] = {}, + delete_single: dict[str, Any] = {}, ): self.all_routes = all_routes self.get = get @@ -52,7 +53,7 @@ def __init__( self.get_single = get_single self.delete_single = delete_single - def get_kwargs(self, endpoint_name: str) -> t.Dict[str, t.Any]: + def get_kwargs(self, endpoint_name: str) -> dict[str, Any]: """ Merges the arguments for all routes with arguments specific to the given route. @@ -74,7 +75,7 @@ class ReferenceModel(BaseModel): class ReferencesModel(BaseModel): - references: t.List[ReferenceModel] + references: list[ReferenceModel] class FastAPIWrapper: @@ -102,9 +103,9 @@ class FastAPIWrapper: def __init__( self, root_url: str, - fastapi_app: t.Union[FastAPI, APIRouter], + fastapi_app: Union[FastAPI, APIRouter], piccolo_crud: PiccoloCRUD, - fastapi_kwargs: t.Optional[FastAPIKwargs] = None, + fastapi_kwargs: Optional[FastAPIKwargs] = None, ): fastapi_kwargs = fastapi_kwargs or FastAPIKwargs() @@ -157,8 +158,8 @@ async def get(request: Request, **kwargs): async def ids( request: Request, - search: t.Optional[str] = None, - limit: t.Optional[int] = None, + search: Optional[str] = None, + limit: Optional[int] = None, ): """ Returns a mapping of row IDs to a readable representation. @@ -169,7 +170,7 @@ async def ids( path=self.join_urls(root_url, "/ids/"), endpoint=ids, methods=["GET"], - response_model=t.Dict[str, str], + response_model=dict[str, str], **fastapi_kwargs.get_kwargs("get"), ) @@ -187,7 +188,7 @@ async def new(request: Request): path=self.join_urls(root_url, "/new/"), endpoint=new, methods=["GET"], - response_model=t.Dict[str, str], + response_model=dict[str, str], **fastapi_kwargs.get_kwargs("get"), ) @@ -225,7 +226,7 @@ async def schema(request: Request): path=self.join_urls(root_url, "/schema/"), endpoint=schema, methods=["GET"], - response_model=t.Dict[str, t.Any], + response_model=dict[str, Any], **fastapi_kwargs.get_kwargs("get"), ) @@ -397,8 +398,8 @@ def join_urls(url_1: str, url_2: str) -> str: @staticmethod def modify_signature( - endpoint: t.Callable, - model: t.Type[PydanticBaseModel], + endpoint: Callable, + model: type[PydanticBaseModel], http_method: HTTPMethod, allow_pagination: bool = False, allow_ordering: bool = False, diff --git a/piccolo_api/jwt_auth/endpoints.py b/piccolo_api/jwt_auth/endpoints.py index 99240e14..7e40727d 100644 --- a/piccolo_api/jwt_auth/endpoints.py +++ b/piccolo_api/jwt_auth/endpoints.py @@ -1,6 +1,5 @@ from __future__ import annotations -import typing as t from abc import abstractmethod from datetime import datetime, timedelta, timezone @@ -15,7 +14,7 @@ class JWTLoginBase(HTTPEndpoint): @property @abstractmethod - def _auth_table(self) -> t.Type[BaseUser]: + def _auth_table(self) -> type[BaseUser]: raise NotImplementedError @property @@ -49,9 +48,9 @@ async def post(self, request: Request) -> JSONResponse: def jwt_login( secret: str, - auth_table: t.Type[BaseUser] = BaseUser, + auth_table: type[BaseUser] = BaseUser, expiry: timedelta = timedelta(days=1), -) -> t.Type[JWTLoginBase]: +) -> type[JWTLoginBase]: """ Create an endpoint for generating JWT tokens. diff --git a/piccolo_api/jwt_auth/middleware.py b/piccolo_api/jwt_auth/middleware.py index 26ce9194..1f3c4fb6 100644 --- a/piccolo_api/jwt_auth/middleware.py +++ b/piccolo_api/jwt_auth/middleware.py @@ -1,7 +1,7 @@ from __future__ import annotations import enum -import typing as t +from typing import Any, Optional import jwt from piccolo.apps.user.tables import BaseUser @@ -30,14 +30,14 @@ class StaticJWTBlacklist(JWTBlacklist): rejects a token if it's in the given list. """ - def __init__(self, blacklist: t.List[str]): + def __init__(self, blacklist: list[str]): self.blacklist = blacklist async def in_blacklist(self, token: str) -> bool: return token in self.blacklist -def extend_scope(scope: t.Dict, extra: t.Dict) -> t.Dict: +def extend_scope(scope: dict, extra: dict) -> dict: """ We copy the scope and extend it with `extra`. It's best to copy the scope rather than manipulate it directly. @@ -71,7 +71,7 @@ def __init__( self, asgi: ASGIApp, secret: str, - auth_table: t.Type[BaseUser] = BaseUser, + auth_table: type[BaseUser] = BaseUser, blacklist: JWTBlacklist = JWTBlacklist(), allow_unauthenticated: bool = False, ) -> None: @@ -97,7 +97,7 @@ def __init__( self.blacklist = blacklist self.allow_unauthenticated = allow_unauthenticated - def get_token(self, headers: dict) -> t.Optional[str]: + def get_token(self, headers: dict) -> Optional[str]: """ Try and extract the JWT token from the request headers. """ @@ -109,9 +109,7 @@ def get_token(self, headers: dict) -> t.Optional[str]: return None return auth_str.split(" ")[1] - async def get_user( - self, token_dict: t.Dict[str, t.Any] - ) -> t.Optional[BaseUser]: + async def get_user(self, token_dict: dict[str, Any]) -> Optional[BaseUser]: """ Extract the user_id from the token, and return a matching user. """ diff --git a/piccolo_api/media/base.py b/piccolo_api/media/base.py index 696e7496..86cdfe96 100644 --- a/piccolo_api/media/base.py +++ b/piccolo_api/media/base.py @@ -5,8 +5,9 @@ import itertools import pathlib import string -import typing as t import uuid +from collections.abc import Sequence +from typing import IO, Optional, Union from piccolo.apps.user.tables import BaseUser from piccolo.columns.column_types import Array, Text, Varchar @@ -84,9 +85,9 @@ class MediaStorage(metaclass=abc.ABCMeta): def __init__( self, - column: t.Union[Text, Varchar, Array], - allowed_extensions: t.Optional[t.Sequence[str]] = ALLOWED_EXTENSIONS, - allowed_characters: t.Optional[t.Sequence[str]] = ALLOWED_CHARACTERS, + column: Union[Text, Varchar, Array], + allowed_extensions: Optional[Sequence[str]] = ALLOWED_EXTENSIONS, + allowed_characters: Optional[Sequence[str]] = ALLOWED_CHARACTERS, ): if not ( isinstance(column, ALLOWED_COLUMN_TYPES) @@ -152,7 +153,7 @@ def validate_file_name(self, file_name: str): raise ValueError("The file has no extension.") def generate_file_key( - self, file_name: str, user: t.Optional[BaseUser] = None + self, file_name: str, user: Optional[BaseUser] = None ) -> str: """ Generates a unique file ID. If you have your own strategy for naming @@ -193,7 +194,7 @@ def generate_file_key( @abc.abstractmethod async def store_file( - self, file_name: str, file: t.IO, user: t.Optional[BaseUser] = None + self, file_name: str, file: IO, user: Optional[BaseUser] = None ) -> str: """ Stores the file in whichever storage you're using, and returns a key @@ -211,7 +212,7 @@ async def store_file( @abc.abstractmethod async def generate_file_url( - self, file_key: str, root_url: str, user: t.Optional[BaseUser] = None + self, file_key: str, root_url: str, user: Optional[BaseUser] = None ): """ This retrieves an absolute URL for the file. It might be a signed URL, @@ -229,7 +230,7 @@ async def generate_file_url( raise NotImplementedError # pragma: no cover @abc.abstractmethod - async def get_file(self, file_key: str) -> t.Optional[t.IO]: + async def get_file(self, file_key: str) -> Optional[IO]: """ Returns the file object matching the ``file_key``. """ @@ -243,11 +244,11 @@ async def delete_file(self, file_key: str): raise NotImplementedError # pragma: no cover @abc.abstractmethod - async def bulk_delete_files(self, file_keys: t.List[str]): + async def bulk_delete_files(self, file_keys: list[str]): raise NotImplementedError # pragma: no cover @abc.abstractmethod - async def get_file_keys(self) -> t.List[str]: + async def get_file_keys(self) -> list[str]: """ Returns the file key for each file we have stored. """ @@ -255,7 +256,7 @@ async def get_file_keys(self) -> t.List[str]: ########################################################################### - async def get_file_keys_from_db(self) -> t.List[str]: + async def get_file_keys_from_db(self) -> list[str]: """ Returns the file key for each file we have in the database. """ @@ -266,7 +267,7 @@ async def get_file_keys_from_db(self) -> t.List[str]: else: return response - async def get_unused_file_keys(self) -> t.List[str]: + async def get_unused_file_keys(self) -> list[str]: """ Compares the file keys we have stored, vs what's in the database. """ diff --git a/piccolo_api/media/local.py b/piccolo_api/media/local.py index a1638883..df6daa9e 100644 --- a/piccolo_api/media/local.py +++ b/piccolo_api/media/local.py @@ -6,8 +6,9 @@ import os import pathlib import shutil -import typing as t +from collections.abc import Sequence from concurrent.futures import ThreadPoolExecutor +from typing import IO, TYPE_CHECKING, Optional, Union from piccolo.apps.user.tables import BaseUser from piccolo.columns.column_types import Array, Text, Varchar @@ -15,7 +16,7 @@ from .base import ALLOWED_CHARACTERS, ALLOWED_EXTENSIONS, MediaStorage -if t.TYPE_CHECKING: # pragma: no cover +if TYPE_CHECKING: # pragma: no cover from concurrent.futures._base import Executor @@ -25,12 +26,12 @@ class LocalMediaStorage(MediaStorage): def __init__( self, - column: t.Union[Text, Varchar, Array], + column: Union[Text, Varchar, Array], media_path: str, - executor: t.Optional[Executor] = None, - allowed_extensions: t.Optional[t.Sequence[str]] = ALLOWED_EXTENSIONS, - allowed_characters: t.Optional[t.Sequence[str]] = ALLOWED_CHARACTERS, - file_permissions: t.Optional[int] = 0o600, + executor: Optional[Executor] = None, + allowed_extensions: Optional[Sequence[str]] = ALLOWED_EXTENSIONS, + allowed_characters: Optional[Sequence[str]] = ALLOWED_CHARACTERS, + file_permissions: Optional[int] = 0o600, ): """ Stores media files on a local path. This is good for simple @@ -71,7 +72,7 @@ def __init__( ) async def store_file( - self, file_name: str, file: t.IO, user: t.Optional[BaseUser] = None + self, file_name: str, file: IO, user: Optional[BaseUser] = None ) -> str: # If the file_name includes the entire path (e.g. /foo/bar.jpg) - we # just want bar.jpg. @@ -102,7 +103,7 @@ def save(): return file_key def store_file_sync( - self, file_name: str, file: t.IO, user: t.Optional[BaseUser] = None + self, file_name: str, file: IO, user: Optional[BaseUser] = None ) -> str: """ A sync wrapper around :meth:`store_file`. @@ -112,7 +113,7 @@ def store_file_sync( ) async def generate_file_url( - self, file_key: str, root_url: str, user: t.Optional[BaseUser] = None + self, file_key: str, root_url: str, user: Optional[BaseUser] = None ) -> str: """ This retrieves an absolute URL for the file. @@ -120,7 +121,7 @@ async def generate_file_url( return "/".join((root_url.rstrip("/"), file_key)) def generate_file_url_sync( - self, file_key: str, root_url: str, user: t.Optional[BaseUser] = None + self, file_key: str, root_url: str, user: Optional[BaseUser] = None ) -> str: """ A sync wrapper around :meth:`generate_file_url`. @@ -133,7 +134,7 @@ def generate_file_url_sync( ########################################################################### - async def get_file(self, file_key: str) -> t.Optional[t.IO]: + async def get_file(self, file_key: str) -> Optional[IO]: """ Returns the file object matching the ``file_key``. """ @@ -141,7 +142,7 @@ async def get_file(self, file_key: str) -> t.Optional[t.IO]: func = functools.partial(self.get_file_sync, file_key=file_key) return await loop.run_in_executor(self.executor, func) - def get_file_sync(self, file_key: str) -> t.Optional[t.IO]: + def get_file_sync(self, file_key: str) -> Optional[IO]: """ A sync wrapper around :meth:`get_file`. """ @@ -163,12 +164,12 @@ def delete_file_sync(self, file_key: str): path = os.path.join(self.media_path, file_key) os.unlink(path) - async def bulk_delete_files(self, file_keys: t.List[str]): + async def bulk_delete_files(self, file_keys: list[str]): media_path = self.media_path for file_key in file_keys: os.unlink(os.path.join(media_path, file_key)) - async def get_file_keys(self) -> t.List[str]: + async def get_file_keys(self) -> list[str]: """ Returns the file key for each file we have stored. """ diff --git a/piccolo_api/media/s3.py b/piccolo_api/media/s3.py index 1dbc5457..add14e2b 100644 --- a/piccolo_api/media/s3.py +++ b/piccolo_api/media/s3.py @@ -4,8 +4,9 @@ import functools import pathlib import sys -import typing as t +from collections.abc import Callable, Sequence from concurrent.futures import ThreadPoolExecutor +from typing import IO, TYPE_CHECKING, Any, Optional, Union from piccolo.apps.user.tables import BaseUser from piccolo.columns.column_types import Array, Text, Varchar @@ -13,23 +14,23 @@ from .base import ALLOWED_CHARACTERS, ALLOWED_EXTENSIONS, MediaStorage from .content_type import CONTENT_TYPE -if t.TYPE_CHECKING: # pragma: no cover +if TYPE_CHECKING: # pragma: no cover from concurrent.futures._base import Executor class S3MediaStorage(MediaStorage): def __init__( self, - column: t.Union[Text, Varchar, Array], + column: Union[Text, Varchar, Array], bucket_name: str, - folder_name: t.Optional[str] = None, - connection_kwargs: t.Optional[t.Dict[str, t.Any]] = None, + folder_name: Optional[str] = None, + connection_kwargs: Optional[dict[str, Any]] = None, sign_urls: bool = True, signed_url_expiry: int = 3600, - upload_metadata: t.Optional[t.Dict[str, t.Any]] = None, - executor: t.Optional[Executor] = None, - allowed_extensions: t.Optional[t.Sequence[str]] = ALLOWED_EXTENSIONS, - allowed_characters: t.Optional[t.Sequence[str]] = ALLOWED_CHARACTERS, + upload_metadata: Optional[dict[str, Any]] = None, + executor: Optional[Executor] = None, + allowed_extensions: Optional[Sequence[str]] = ALLOWED_EXTENSIONS, + allowed_characters: Optional[Sequence[str]] = ALLOWED_CHARACTERS, ): """ Stores media files in S3 compatible storage. This is a good option when @@ -153,7 +154,7 @@ def get_client(self, config=None): # pragma: no cover return client async def store_file( - self, file_name: str, file: t.IO, user: t.Optional[BaseUser] = None + self, file_name: str, file: IO, user: Optional[BaseUser] = None ) -> str: loop = asyncio.get_running_loop() @@ -173,7 +174,7 @@ def _prepend_folder_name(self, file_key: str) -> str: return file_key def store_file_sync( - self, file_name: str, file: t.IO, user: t.Optional[BaseUser] = None + self, file_name: str, file: IO, user: Optional[BaseUser] = None ) -> str: """ A sync wrapper around :meth:`store_file`. @@ -181,7 +182,7 @@ def store_file_sync( file_key = self.generate_file_key(file_name=file_name, user=user) extension = file_key.rsplit(".", 1)[-1] client = self.get_client() - upload_metadata: t.Dict[str, t.Any] = self.upload_metadata + upload_metadata: dict[str, Any] = self.upload_metadata if extension in CONTENT_TYPE: upload_metadata["ContentType"] = CONTENT_TYPE[extension] @@ -196,14 +197,14 @@ def store_file_sync( return file_key async def generate_file_url( - self, file_key: str, root_url: str, user: t.Optional[BaseUser] = None + self, file_key: str, root_url: str, user: Optional[BaseUser] = None ) -> str: """ This retrieves an absolute URL for the file. """ loop = asyncio.get_running_loop() - blocking_function: t.Callable = functools.partial( + blocking_function: Callable = functools.partial( self.generate_file_url_sync, file_key=file_key, root_url=root_url, @@ -213,7 +214,7 @@ async def generate_file_url( return await loop.run_in_executor(self.executor, blocking_function) def generate_file_url_sync( - self, file_key: str, root_url: str, user: t.Optional[BaseUser] = None + self, file_key: str, root_url: str, user: Optional[BaseUser] = None ) -> str: """ A sync wrapper around :meth:`generate_file_url`. @@ -239,7 +240,7 @@ def generate_file_url_sync( ########################################################################### - async def get_file(self, file_key: str) -> t.Optional[t.IO]: + async def get_file(self, file_key: str) -> Optional[IO]: """ Returns the file object matching the ``file_key``. """ @@ -249,7 +250,7 @@ async def get_file(self, file_key: str) -> t.Optional[t.IO]: return await loop.run_in_executor(self.executor, func) - def get_file_sync(self, file_key: str) -> t.Optional[t.IO]: + def get_file_sync(self, file_key: str) -> Optional[IO]: """ Returns the file object matching the ``file_key``. """ @@ -283,7 +284,7 @@ def delete_file_sync(self, file_key: str): Key=self._prepend_folder_name(file_key), ) - async def bulk_delete_files(self, file_keys: t.List[str]): + async def bulk_delete_files(self, file_keys: list[str]): loop = asyncio.get_running_loop() func = functools.partial( self.bulk_delete_files_sync, @@ -291,7 +292,7 @@ async def bulk_delete_files(self, file_keys: t.List[str]): ) await loop.run_in_executor(self.executor, func) - def bulk_delete_files_sync(self, file_keys: t.List[str]): + def bulk_delete_files_sync(self, file_keys: list[str]): s3_client = self.get_client() batch_size = 100 @@ -321,7 +322,7 @@ def bulk_delete_files_sync(self, file_keys: t.List[str]): iteration += 1 - def get_file_keys_sync(self) -> t.List[str]: + def get_file_keys_sync(self) -> list[str]: """ Returns the file key for each file we have stored. """ @@ -331,7 +332,7 @@ def get_file_keys_sync(self) -> t.List[str]: start_after = None while True: - extra_kwargs: t.Dict[str, t.Any] = {} + extra_kwargs: dict[str, Any] = {} if start_after: extra_kwargs["StartAfter"] = start_after @@ -361,7 +362,7 @@ def get_file_keys_sync(self) -> t.List[str]: else: return keys - async def get_file_keys(self) -> t.List[str]: + async def get_file_keys(self) -> list[str]: """ Returns the file key for each file we have stored. """ diff --git a/piccolo_api/mfa/authenticator/provider.py b/piccolo_api/mfa/authenticator/provider.py index f87e5677..870a0f56 100644 --- a/piccolo_api/mfa/authenticator/provider.py +++ b/piccolo_api/mfa/authenticator/provider.py @@ -1,5 +1,5 @@ import os -import typing as t +from typing import Optional from jinja2 import Environment, FileSystemLoader from piccolo.apps.user.tables import BaseUser @@ -23,10 +23,10 @@ def __init__( self, encryption_provider: EncryptionProvider, recovery_code_count: int = 8, - secret_table: t.Type[AuthenticatorSecret] = AuthenticatorSecret, + secret_table: type[AuthenticatorSecret] = AuthenticatorSecret, issuer_name: str = "Piccolo-MFA", - register_template_path: t.Optional[str] = None, - styles: t.Optional[Styles] = None, + register_template_path: Optional[str] = None, + styles: Optional[Styles] = None, valid_window: int = 0, ): """ diff --git a/piccolo_api/mfa/authenticator/tables.py b/piccolo_api/mfa/authenticator/tables.py index 466ac2a6..f9ee2f2a 100644 --- a/piccolo_api/mfa/authenticator/tables.py +++ b/piccolo_api/mfa/authenticator/tables.py @@ -2,7 +2,7 @@ import datetime import logging -import typing as t +from typing import TYPE_CHECKING from piccolo.apps.user.tables import BaseUser from piccolo.columns import Array, Integer, Serial, Text, Timestamptz @@ -11,7 +11,7 @@ from piccolo_api.encryption.providers import EncryptionProvider from piccolo_api.mfa.recovery_codes import generate_recovery_code -if t.TYPE_CHECKING: # pragma: no cover +if TYPE_CHECKING: # pragma: no cover import pyotp @@ -73,7 +73,7 @@ async def create_new( user_id: int, encryption_provider: EncryptionProvider, recovery_code_count: int = 8, - ) -> t.Tuple[AuthenticatorSecret, t.List[str]]: + ) -> tuple[AuthenticatorSecret, list[str]]: """ Returns the new ``AuthenticatorSecret`` and the unhashed recovery codes. This is the only time the unhashed recovery codes will be diff --git a/piccolo_api/mfa/authenticator/utils.py b/piccolo_api/mfa/authenticator/utils.py index b6c67521..1e78c92e 100644 --- a/piccolo_api/mfa/authenticator/utils.py +++ b/piccolo_api/mfa/authenticator/utils.py @@ -1,10 +1,10 @@ from __future__ import annotations -import typing as t from base64 import b64encode from io import BytesIO +from typing import TYPE_CHECKING -if t.TYPE_CHECKING: # pragma: no cover +if TYPE_CHECKING: # pragma: no cover import qrcode diff --git a/piccolo_api/mfa/endpoints.py b/piccolo_api/mfa/endpoints.py index fc0bf3eb..b08014e7 100644 --- a/piccolo_api/mfa/endpoints.py +++ b/piccolo_api/mfa/endpoints.py @@ -1,7 +1,7 @@ import os -import typing as t from abc import ABCMeta, abstractmethod from json import JSONDecodeError +from typing import Any, Optional from jinja2 import Environment, FileSystemLoader from piccolo.apps.user.tables import BaseUser @@ -33,7 +33,7 @@ def _provider(self) -> MFAProvider: @property @abstractmethod - def _auth_table(self) -> t.Type[BaseUser]: + def _auth_table(self) -> type[BaseUser]: raise NotImplementedError @property @@ -44,7 +44,7 @@ def _styles(self) -> Styles: def _render_register_template( self, request: Request, - extra_context: t.Optional[t.Dict] = None, + extra_context: Optional[dict] = None, status_code: int = 200, ): template = environment.get_template("mfa_setup.html") @@ -85,7 +85,7 @@ async def post(self, request: Request): # Some middleware (for example CSRF) has already awaited the request # body, and adds it to the request. - body: t.Any = request.scope.get("form") + body: Any = request.scope.get("form") if not body: try: @@ -152,9 +152,9 @@ async def post(self, request: Request): def mfa_setup( provider: MFAProvider, - auth_table: t.Type[BaseUser] = BaseUser, - styles: t.Optional[Styles] = None, -) -> t.Type[HTTPEndpoint]: + auth_table: type[BaseUser] = BaseUser, + styles: Optional[Styles] = None, +) -> type[HTTPEndpoint]: """ This endpoint needs to be protected ``SessionAuthMiddleware``, ensuring that only logged in users can access it. diff --git a/piccolo_api/mfa/recovery_codes.py b/piccolo_api/mfa/recovery_codes.py index 2d456b57..55e60deb 100644 --- a/piccolo_api/mfa/recovery_codes.py +++ b/piccolo_api/mfa/recovery_codes.py @@ -1,12 +1,12 @@ import math import secrets import string -import typing as t +from collections.abc import Sequence DEFAULT_CHARACTERS = string.ascii_lowercase + string.digits -def _get_random_string(length: int, characters: t.Sequence[str]) -> str: +def _get_random_string(length: int, characters: Sequence[str]) -> str: """ :param length: How long to make the string. @@ -19,7 +19,7 @@ def _get_random_string(length: int, characters: t.Sequence[str]) -> str: def generate_recovery_code( length: int = 12, - characters: t.Sequence[str] = DEFAULT_CHARACTERS, + characters: Sequence[str] = DEFAULT_CHARACTERS, separator: str = "-", ): """ diff --git a/piccolo_api/openapi/endpoints.py b/piccolo_api/openapi/endpoints.py index 19543aeb..9d277594 100644 --- a/piccolo_api/openapi/endpoints.py +++ b/piccolo_api/openapi/endpoints.py @@ -1,5 +1,5 @@ import os -import typing as t +from typing import Optional import jinja2 from fastapi.openapi.docs import get_swagger_ui_oauth2_redirect_html @@ -24,8 +24,8 @@ def swagger_ui( schema_url: str = "/openapi.json", swagger_ui_title: str = "Piccolo Swagger UI", - csrf_cookie_name: t.Optional[str] = DEFAULT_COOKIE_NAME, - csrf_header_name: t.Optional[str] = DEFAULT_HEADER_NAME, + csrf_cookie_name: Optional[str] = DEFAULT_COOKIE_NAME, + csrf_header_name: Optional[str] = DEFAULT_HEADER_NAME, swagger_ui_version: str = "5", ): """ diff --git a/piccolo_api/rate_limiting/middleware.py b/piccolo_api/rate_limiting/middleware.py index be4fce40..336118f1 100644 --- a/piccolo_api/rate_limiting/middleware.py +++ b/piccolo_api/rate_limiting/middleware.py @@ -1,15 +1,15 @@ from __future__ import annotations -import typing as t from abc import ABCMeta, abstractmethod from collections import defaultdict from time import time +from typing import TYPE_CHECKING, Optional from starlette.middleware.base import BaseHTTPMiddleware from starlette.responses import Response from starlette.types import ASGIApp -if t.TYPE_CHECKING: # pragma: no cover +if TYPE_CHECKING: # pragma: no cover from starlette.middleware.base import Request, RequestResponseEndpoint @@ -52,7 +52,7 @@ def __init__( self, timespan: int, limit: int = 1000, - block_duration: t.Optional[int] = None, + block_duration: Optional[int] = None, ): """ :param timespan: @@ -72,7 +72,7 @@ def __init__( self.last_reset = time() self.limit = limit - self.blocked: t.Dict[str, float] = {} + self.blocked: dict[str, float] = {} self.block_duration = block_duration def _handle_blocked(self): @@ -83,7 +83,7 @@ def is_already_blocked(self, identifier: str) -> bool: Check whether the identifier is already blocked from previous requests. Remove the identifier if the block has expired. """ - blocked_at: t.Optional[float] = self.blocked.get(identifier, None) + blocked_at: Optional[float] = self.blocked.get(identifier, None) if blocked_at: duration = self.block_duration if (time() - blocked_at < duration) if duration else True: @@ -139,7 +139,7 @@ class RateLimitingMiddleware(BaseHTTPMiddleware): def __init__( self, app: ASGIApp, - provider: t.Optional[RateLimitProvider] = None, + provider: Optional[RateLimitProvider] = None, ): """ :param app: diff --git a/piccolo_api/register/endpoints.py b/piccolo_api/register/endpoints.py index 7be7cc2a..62797234 100644 --- a/piccolo_api/register/endpoints.py +++ b/piccolo_api/register/endpoints.py @@ -2,9 +2,9 @@ import os import re -import typing as t from abc import ABCMeta, abstractmethod from json import JSONDecodeError +from typing import TYPE_CHECKING, Any, Optional, Union from jinja2 import Environment, FileSystemLoader from piccolo.apps.user.tables import BaseUser @@ -20,7 +20,7 @@ from piccolo_api.shared.auth.styles import Styles -if t.TYPE_CHECKING: # pragma: no cover +if TYPE_CHECKING: # pragma: no cover from jinja2 import Template from starlette.responses import Response @@ -38,12 +38,12 @@ class RegisterEndpoint(HTTPEndpoint, metaclass=ABCMeta): @property @abstractmethod - def _auth_table(self) -> t.Type[BaseUser]: + def _auth_table(self) -> type[BaseUser]: raise NotImplementedError @property @abstractmethod - def _redirect_to(self) -> t.Union[str, URL]: + def _redirect_to(self) -> Union[str, URL]: """ Where to redirect to after login is successful. """ @@ -56,12 +56,12 @@ def _register_template(self) -> Template: @property @abstractmethod - def _user_defaults(self) -> t.Optional[t.Dict[str, t.Any]]: + def _user_defaults(self) -> Optional[dict[str, Any]]: raise NotImplementedError @property @abstractmethod - def _captcha(self) -> t.Optional[Captcha]: + def _captcha(self) -> Optional[Captcha]: raise NotImplementedError @property @@ -75,7 +75,7 @@ def _read_only(self) -> bool: raise NotImplementedError def render_template( - self, request: Request, template_context: t.Dict[str, t.Any] = {} + self, request: Request, template_context: dict[str, Any] = {} ) -> HTMLResponse: # If CSRF middleware is present, we have to include a form field with # the CSRF token. It only works if CSRFMiddleware has @@ -106,7 +106,7 @@ async def post(self, request: Request) -> Response: # Some middleware (for example CSRF) has already awaited the request # body, and adds it to the request. - body: t.Any = request.scope.get("form") + body: Any = request.scope.get("form") if not body: try: @@ -211,14 +211,14 @@ async def post(self, request: Request) -> Response: def register( - auth_table: t.Type[BaseUser] = BaseUser, - redirect_to: t.Union[str, URL] = "/login/", - template_path: t.Optional[str] = None, - user_defaults: t.Optional[t.Dict[str, t.Any]] = None, - captcha: t.Optional[Captcha] = None, - styles: t.Optional[Styles] = None, + auth_table: type[BaseUser] = BaseUser, + redirect_to: Union[str, URL] = "/login/", + template_path: Optional[str] = None, + user_defaults: Optional[dict[str, Any]] = None, + captcha: Optional[Captcha] = None, + styles: Optional[Styles] = None, read_only: bool = False, -) -> t.Type[RegisterEndpoint]: +) -> type[RegisterEndpoint]: """ An endpoint for register user. diff --git a/piccolo_api/session_auth/endpoints.py b/piccolo_api/session_auth/endpoints.py index ba8003d1..100385cd 100644 --- a/piccolo_api/session_auth/endpoints.py +++ b/piccolo_api/session_auth/endpoints.py @@ -1,11 +1,12 @@ from __future__ import annotations import os -import typing as t import warnings from abc import ABCMeta, abstractmethod +from collections.abc import Sequence from datetime import datetime, timedelta from json import JSONDecodeError +from typing import TYPE_CHECKING, Any, Literal, Optional, cast from jinja2 import Environment, FileSystemLoader from piccolo.apps.user.tables import BaseUser @@ -24,7 +25,7 @@ from piccolo_api.shared.auth.hooks import LoginHooks from piccolo_api.shared.auth.styles import Styles -if t.TYPE_CHECKING: # pragma: no cover +if TYPE_CHECKING: # pragma: no cover from jinja2 import Template from starlette.responses import Response @@ -42,7 +43,7 @@ class SessionLogoutEndpoint(HTTPEndpoint, metaclass=ABCMeta): @property @abstractmethod - def _session_table(self) -> t.Type[SessionsBase]: + def _session_table(self) -> type[SessionsBase]: raise NotImplementedError @property @@ -52,7 +53,7 @@ def _cookie_name(self) -> str: @property @abstractmethod - def _redirect_to(self) -> t.Optional[str]: + def _redirect_to(self) -> Optional[str]: raise NotImplementedError @property @@ -62,11 +63,11 @@ def _logout_template(self) -> Template: @property @abstractmethod - def _styles(self) -> t.Optional[Styles]: + def _styles(self) -> Optional[Styles]: raise NotImplementedError def _render_template( - self, request: Request, template_context: t.Dict[str, t.Any] = {} + self, request: Request, template_context: dict[str, Any] = {} ) -> HTMLResponse: # If CSRF middleware is present, we have to include a form field with # the CSRF token. It only works if CSRFMiddleware has @@ -111,12 +112,12 @@ async def post(self, request: Request) -> Response: class SessionLoginEndpoint(HTTPEndpoint, metaclass=ABCMeta): @property @abstractmethod - def _auth_table(self) -> t.Type[BaseUser]: + def _auth_table(self) -> type[BaseUser]: raise NotImplementedError @property @abstractmethod - def _session_table(self) -> t.Type[SessionsBase]: + def _session_table(self) -> type[SessionsBase]: raise NotImplementedError @property @@ -136,7 +137,7 @@ def _cookie_name(self) -> str: @property @abstractmethod - def _redirect_to(self) -> t.Optional[str]: + def _redirect_to(self) -> Optional[str]: """ Where to redirect to after login is successful. """ @@ -157,28 +158,28 @@ def _login_template(self) -> Template: @property @abstractmethod - def _hooks(self) -> t.Optional[LoginHooks]: + def _hooks(self) -> Optional[LoginHooks]: raise NotImplementedError @property @abstractmethod - def _captcha(self) -> t.Optional[Captcha]: + def _captcha(self) -> Optional[Captcha]: raise NotImplementedError @property @abstractmethod - def _styles(self) -> t.Optional[Styles]: + def _styles(self) -> Optional[Styles]: raise NotImplementedError @property @abstractmethod - def _mfa_providers(self) -> t.Optional[t.Sequence[MFAProvider]]: + def _mfa_providers(self) -> Optional[Sequence[MFAProvider]]: raise NotImplementedError def _render_template( self, request: Request, - template_context: t.Dict[str, t.Any] = {}, + template_context: dict[str, Any] = {}, status_code=200, ) -> HTMLResponse: # If CSRF middleware is present, we have to include a form field with @@ -201,7 +202,7 @@ def _render_template( ) def _get_error_response( - self, request, error: str, response_format: t.Literal["html", "plain"] + self, request, error: str, response_format: Literal["html", "plain"] ) -> Response: if response_format == "html": return self._render_template( @@ -221,7 +222,7 @@ async def get(self, request: Request) -> HTMLResponse: async def post(self, request: Request) -> Response: # Some middleware (for example CSRF) has already awaited the request # body, and adds it to the request. - body: t.Any = request.scope.get("form") + body: Any = request.scope.get("form") if not body: try: @@ -293,7 +294,7 @@ async def post(self, request: Request) -> Response: mfa_code = body.get("mfa_code") if mfa_code is None: - has_sent_code: t.List[bool] = [] + has_sent_code: list[bool] = [] for mfa_provider in enrolled_mfa_providers: # Send the code (only used with things like email # and SMS MFA). @@ -440,7 +441,7 @@ async def post(self, request: Request) -> Response: ) warnings.warn(message) - cookie_value = t.cast(str, session.token) + cookie_value = cast(str, session.token) response.set_cookie( key=self._cookie_name, @@ -454,19 +455,19 @@ async def post(self, request: Request) -> Response: def session_login( - auth_table: t.Type[BaseUser] = BaseUser, - session_table: t.Type[SessionsBase] = SessionsBase, + auth_table: type[BaseUser] = BaseUser, + session_table: type[SessionsBase] = SessionsBase, session_expiry: timedelta = timedelta(hours=1), max_session_expiry: timedelta = timedelta(days=7), - redirect_to: t.Optional[str] = "/", + redirect_to: Optional[str] = "/", production: bool = False, cookie_name: str = "id", - template_path: t.Optional[str] = None, - hooks: t.Optional[LoginHooks] = None, - captcha: t.Optional[Captcha] = None, - styles: t.Optional[Styles] = None, - mfa_providers: t.Optional[t.Sequence[MFAProvider]] = None, -) -> t.Type[SessionLoginEndpoint]: + template_path: Optional[str] = None, + hooks: Optional[LoginHooks] = None, + captcha: Optional[Captcha] = None, + styles: Optional[Styles] = None, + mfa_providers: Optional[Sequence[MFAProvider]] = None, +) -> type[SessionLoginEndpoint]: """ An endpoint for creating a user session. @@ -538,12 +539,12 @@ class _SessionLoginEndpoint(SessionLoginEndpoint): def session_logout( - session_table: t.Type[SessionsBase] = SessionsBase, + session_table: type[SessionsBase] = SessionsBase, cookie_name: str = "id", - redirect_to: t.Optional[str] = None, - template_path: t.Optional[str] = None, - styles: t.Optional[Styles] = None, -) -> t.Type[SessionLogoutEndpoint]: + redirect_to: Optional[str] = None, + template_path: Optional[str] = None, + styles: Optional[Styles] = None, +) -> type[SessionLogoutEndpoint]: """ An endpoint for clearing a user session. diff --git a/piccolo_api/session_auth/middleware.py b/piccolo_api/session_auth/middleware.py index 43d682d4..6d039306 100644 --- a/piccolo_api/session_auth/middleware.py +++ b/piccolo_api/session_auth/middleware.py @@ -1,7 +1,8 @@ from __future__ import annotations -import typing as t +from collections.abc import Sequence from datetime import timedelta +from typing import Optional from piccolo.apps.user.tables import BaseUser as PiccoloBaseUser from starlette.authentication import ( @@ -24,15 +25,15 @@ class SessionsAuthBackend(AuthenticationBackend): def __init__( self, - auth_table: t.Type[PiccoloBaseUser] = PiccoloBaseUser, - session_table: t.Type[SessionsBase] = SessionsBase, + auth_table: type[PiccoloBaseUser] = PiccoloBaseUser, + session_table: type[SessionsBase] = SessionsBase, cookie_name: str = "id", admin_only: bool = True, superuser_only: bool = False, active_only: bool = True, - increase_expiry: t.Optional[timedelta] = None, + increase_expiry: Optional[timedelta] = None, allow_unauthenticated: bool = False, - excluded_paths: t.Optional[t.Sequence[str]] = None, + excluded_paths: Optional[Sequence[str]] = None, ): """ :param auth_table: @@ -80,7 +81,7 @@ def __init__( @check_excluded_paths async def authenticate( self, conn: HTTPConnection - ) -> t.Optional[t.Tuple[AuthCredentials, BaseUser]]: + ) -> Optional[tuple[AuthCredentials, BaseUser]]: token = conn.cookies.get(self.cookie_name, None) if not token: if self.allow_unauthenticated: diff --git a/piccolo_api/session_auth/tables.py b/piccolo_api/session_auth/tables.py index bc1341c6..fd861f64 100644 --- a/piccolo_api/session_auth/tables.py +++ b/piccolo_api/session_auth/tables.py @@ -1,8 +1,8 @@ from __future__ import annotations import secrets -import typing as t from datetime import datetime, timedelta +from typing import Optional, cast from piccolo.columns import Integer, Serial, Timestamp, Varchar from piccolo.columns.defaults.timestamp import TimestampOffset @@ -39,8 +39,8 @@ class SessionsBase(Table, tablename="sessions"): async def create_session( cls, user_id: int, - expiry_date: t.Optional[datetime] = None, - max_expiry_date: t.Optional[datetime] = None, + expiry_date: Optional[datetime] = None, + max_expiry_date: Optional[datetime] = None, ) -> SessionsBase: """ Creates a session in the database. @@ -62,7 +62,7 @@ async def create_session( @classmethod def create_session_sync( - cls, user_id: int, expiry_date: t.Optional[datetime] = None + cls, user_id: int, expiry_date: Optional[datetime] = None ) -> SessionsBase: """ A sync equivalent of :meth:`create_session`. @@ -71,8 +71,8 @@ def create_session_sync( @classmethod async def get_user_id( - cls, token: str, increase_expiry: t.Optional[timedelta] = None - ) -> t.Optional[int]: + cls, token: str, increase_expiry: Optional[timedelta] = None + ) -> Optional[int]: """ Returns the ``user_id`` if the given token is valid, otherwise ``None``. @@ -91,19 +91,19 @@ async def get_user_id( now = datetime.now() if (session.expiry_date > now) and (session.max_expiry_date > now): if increase_expiry and ( - t.cast(datetime, session.expiry_date) - now < increase_expiry + cast(datetime, session.expiry_date) - now < increase_expiry ): session.expiry_date = ( - t.cast(datetime, session.expiry_date) + increase_expiry + cast(datetime, session.expiry_date) + increase_expiry ) await session.save().run() - return t.cast(t.Optional[int], session.user_id) + return cast(Optional[int], session.user_id) else: return None @classmethod - def get_user_id_sync(cls, token: str) -> t.Optional[int]: + def get_user_id_sync(cls, token: str) -> Optional[int]: """ A sync wrapper around :meth:`get_user_id`. """ diff --git a/piccolo_api/shared/auth/captcha.py b/piccolo_api/shared/auth/captcha.py index 85afeec4..2c5eb60e 100644 --- a/piccolo_api/shared/auth/captcha.py +++ b/piccolo_api/shared/auth/captcha.py @@ -1,13 +1,14 @@ import inspect -import typing as t +from collections.abc import Awaitable, Callable from dataclasses import dataclass +from typing import Optional, Union import httpx -Response = t.Optional[str] -Validator = t.Union[ - t.Callable[[str], Response], - t.Callable[[str], t.Awaitable[Response]], +Response = Optional[str] +Validator = Union[ + Callable[[str], Response], + Callable[[str], Awaitable[Response]], ] @@ -34,7 +35,7 @@ class directly if doing something custom. token_field: str validator: Validator - async def validate(self, token: str) -> t.Optional[str]: + async def validate(self, token: str) -> Optional[str]: if self.validator: if inspect.iscoroutinefunction(self.validator): return await self.validator(token) # type: ignore @@ -72,7 +73,7 @@ def hcaptcha(site_key: str, secret_key: str) -> Captcha: """ - async def validator(token: str) -> t.Optional[str]: + async def validator(token: str) -> Optional[str]: if not token: return "Unable to find CAPTCHA token." @@ -85,7 +86,7 @@ async def validator(token: str) -> t.Optional[str]: }, ) data = response.json() - if not data.get("success", None) is True: + if data.get("success", None) is not True: return "CAPTCHA failed." return None @@ -121,7 +122,7 @@ def recaptcha_v2(site_key: str, secret_key: str) -> Captcha: """ - async def validator(token: str) -> t.Optional[str]: + async def validator(token: str) -> Optional[str]: if not token: return "Unable to find CAPTCHA token." @@ -134,7 +135,7 @@ async def validator(token: str) -> t.Optional[str]: }, ) data = response.json() - if not data.get("success", None) is True: + if data.get("success", None) is not True: return "CAPTCHA failed." return None diff --git a/piccolo_api/shared/auth/excluded_paths.py b/piccolo_api/shared/auth/excluded_paths.py index cc86565b..78004f6d 100644 --- a/piccolo_api/shared/auth/excluded_paths.py +++ b/piccolo_api/shared/auth/excluded_paths.py @@ -1,7 +1,7 @@ from __future__ import annotations import functools -import typing as t +from collections.abc import Callable from starlette.authentication import AuthCredentials, AuthenticationBackend from starlette.requests import HTTPConnection @@ -9,7 +9,7 @@ from piccolo_api.shared.auth import UnauthenticatedUser -def check_excluded_paths(authenticate_func: t.Callable): +def check_excluded_paths(authenticate_func: Callable): @functools.wraps(authenticate_func) async def authenticate(self: AuthenticationBackend, conn: HTTPConnection): diff --git a/piccolo_api/shared/auth/hooks.py b/piccolo_api/shared/auth/hooks.py index 094e670e..fd74da03 100644 --- a/piccolo_api/shared/auth/hooks.py +++ b/piccolo_api/shared/auth/hooks.py @@ -2,19 +2,20 @@ import dataclasses import inspect -import typing as t +from collections.abc import Awaitable, Callable +from typing import Optional, Union, cast -PreLoginHook = t.Union[ - t.Callable[[str], t.Optional[str]], - t.Callable[[str], t.Awaitable[t.Optional[str]]], +PreLoginHook = Union[ + Callable[[str], Optional[str]], + Callable[[str], Awaitable[Optional[str]]], ] -LoginSuccessHook = t.Union[ - t.Callable[[str, int], t.Optional[str]], - t.Callable[[str, int], t.Awaitable[t.Optional[str]]], +LoginSuccessHook = Union[ + Callable[[str, int], Optional[str]], + Callable[[str, int], Awaitable[Optional[str]]], ] -LoginFailureHook = t.Union[ - t.Callable[[str], t.Optional[str]], - t.Callable[[str], t.Awaitable[t.Optional[str]]], +LoginFailureHook = Union[ + Callable[[str], Optional[str]], + Callable[[str], Awaitable[Optional[str]]], ] @@ -85,16 +86,16 @@ async def log_failure(username: str, **kwargs): """ # noqa: E501 - pre_login: t.Optional[t.List[PreLoginHook]] = None - login_success: t.Optional[t.List[LoginSuccessHook]] = None - login_failure: t.Optional[t.List[LoginFailureHook]] = None + pre_login: Optional[list[PreLoginHook]] = None + login_success: Optional[list[LoginSuccessHook]] = None + login_failure: Optional[list[LoginFailureHook]] = None - async def run_pre_login(self, username: str) -> t.Optional[str]: + async def run_pre_login(self, username: str) -> Optional[str]: if self.pre_login: for hook in self.pre_login: response = hook(username) if inspect.isawaitable(response): - response = t.cast(t.Awaitable, response) + response = cast(Awaitable, response) response = await response if isinstance(response, str): @@ -104,12 +105,12 @@ async def run_pre_login(self, username: str) -> t.Optional[str]: async def run_login_success( self, username: str, user_id: int - ) -> t.Optional[str]: + ) -> Optional[str]: if self.login_success: for hook in self.login_success: response = hook(username, user_id) if inspect.isawaitable(response): - response = t.cast(t.Awaitable, response) + response = cast(Awaitable, response) response = await response if isinstance(response, str): @@ -117,12 +118,12 @@ async def run_login_success( return None - async def run_login_failure(self, username: str) -> t.Optional[str]: + async def run_login_failure(self, username: str) -> Optional[str]: if self.login_failure: for hook in self.login_failure: response = hook(username) if inspect.isawaitable(response): - response = t.cast(t.Awaitable, response) + response = cast(Awaitable, response) response = await response if isinstance(response, str): diff --git a/piccolo_api/shared/auth/junction.py b/piccolo_api/shared/auth/junction.py index 7e9aa769..cb5c0d06 100644 --- a/piccolo_api/shared/auth/junction.py +++ b/piccolo_api/shared/auth/junction.py @@ -1,4 +1,5 @@ -import typing as t +from collections.abc import Sequence +from typing import Optional from starlette.authentication import ( AuthCredentials, @@ -15,12 +16,12 @@ class AuthenticationBackendJunction(AuthenticationBackend): the same endpoint - if any of them pass, then auth is successful. """ - def __init__(self, backends: t.Sequence[AuthenticationBackend]): + def __init__(self, backends: Sequence[AuthenticationBackend]): self.backends = backends async def authenticate( self, conn: HTTPConnection - ) -> t.Optional[t.Tuple[AuthCredentials, BaseUser]]: + ) -> Optional[tuple[AuthCredentials, BaseUser]]: for backend in self.backends: try: response = await backend.authenticate(conn=conn) diff --git a/piccolo_api/shared/auth/user.py b/piccolo_api/shared/auth/user.py index fd46d25b..63317cc2 100644 --- a/piccolo_api/shared/auth/user.py +++ b/piccolo_api/shared/auth/user.py @@ -1,11 +1,11 @@ from __future__ import annotations -import typing as t +from typing import TYPE_CHECKING, cast from piccolo.apps.user.tables import BaseUser as PiccoloBaseUser from starlette.authentication import BaseUser -if t.TYPE_CHECKING: # pragma: no cover +if TYPE_CHECKING: # pragma: no cover from piccolo.table import Table @@ -19,18 +19,18 @@ def __init__(self, user: PiccoloBaseUser): # contructor, but we can just infer them from the user instance. @property - def auth_table(self) -> t.Type[Table]: + def auth_table(self) -> type[Table]: return self.user.__class__ @property def user_id(self) -> int: - return t.cast( + return cast( int, getattr(self.user, self.user._meta.primary_key._meta.name) ) @property def username(self) -> str: - return t.cast(str, self.user.username) + return cast(str, self.user.username) ########################################################################### # Required properties. diff --git a/piccolo_api/token_auth/endpoints.py b/piccolo_api/token_auth/endpoints.py index 34d74f7a..5b8491b4 100644 --- a/piccolo_api/token_auth/endpoints.py +++ b/piccolo_api/token_auth/endpoints.py @@ -1,7 +1,7 @@ from __future__ import annotations -import typing as t from abc import ABCMeta, abstractmethod +from typing import Optional from piccolo.apps.user.tables import BaseUser from starlette.endpoints import HTTPEndpoint @@ -18,7 +18,7 @@ class TokenProvider(metaclass=ABCMeta): """ @abstractmethod - async def get_token(self, username: str, password: str) -> t.Optional[str]: + async def get_token(self, username: str, password: str) -> Optional[str]: pass @@ -27,7 +27,7 @@ class PiccoloTokenProvider(TokenProvider): Retrieves a token from a Piccolo table. """ - async def get_token(self, username: str, password: str) -> t.Optional[str]: + async def get_token(self, username: str, password: str) -> Optional[str]: user = await BaseUser.login(username=username, password=password) if user: @@ -73,7 +73,7 @@ async def post(self, request: Request) -> Response: def token_login( provider: TokenProvider = PiccoloTokenProvider(), -) -> t.Type[TokenAuthLoginEndpoint]: +) -> type[TokenAuthLoginEndpoint]: """ Create an endpoint for logging using tokens. diff --git a/piccolo_api/token_auth/middleware.py b/piccolo_api/token_auth/middleware.py index b9dedd29..bd5c8128 100644 --- a/piccolo_api/token_auth/middleware.py +++ b/piccolo_api/token_auth/middleware.py @@ -1,7 +1,8 @@ from __future__ import annotations -import typing as t from abc import ABCMeta, abstractmethod +from collections.abc import Sequence +from typing import Optional from piccolo.apps.user.tables import BaseUser as BaseUserTable from starlette.authentication import ( @@ -35,7 +36,7 @@ class SecretTokenAuthProvider(TokenAuthProvider): microservices, where the client is trusted. """ - def __init__(self, tokens: t.Sequence[str]): + def __init__(self, tokens: Sequence[str]): self.tokens = tokens async def get_user(self, token: str) -> SimpleUser: @@ -53,8 +54,8 @@ class PiccoloTokenAuthProvider(TokenAuthProvider): def __init__( self, - auth_table: t.Type[BaseUserTable] = BaseUserTable, - token_table: t.Type[TokenAuth] = TokenAuth, + auth_table: type[BaseUserTable] = BaseUserTable, + token_table: type[TokenAuth] = TokenAuth, ): self.auth_table = auth_table self.token_table = token_table @@ -85,7 +86,7 @@ class TokenAuthBackend(AuthenticationBackend): def __init__( self, token_auth_provider: TokenAuthProvider = DEFAULT_PROVIDER, - excluded_paths: t.Optional[t.Sequence[str]] = None, + excluded_paths: Optional[Sequence[str]] = None, ): """ :param token_auth_provider: @@ -110,7 +111,7 @@ def extract_token(self, header: str) -> str: @check_excluded_paths async def authenticate( self, conn: HTTPConnection - ) -> t.Optional[t.Tuple[AuthCredentials, BaseUser]]: + ) -> Optional[tuple[AuthCredentials, BaseUser]]: auth_header = conn.headers.get("Authorization", None) if not auth_header: diff --git a/piccolo_api/token_auth/tables.py b/piccolo_api/token_auth/tables.py index c2d6ccaa..5a1dc9a1 100644 --- a/piccolo_api/token_auth/tables.py +++ b/piccolo_api/token_auth/tables.py @@ -1,7 +1,7 @@ from __future__ import annotations -import typing as t import uuid +from typing import Optional, cast from piccolo.apps.user.tables import BaseUser from piccolo.columns.column_types import ForeignKey, Serial, Varchar @@ -46,22 +46,22 @@ async def create_token( token_auth = cls(user=user_id) await token_auth.save().run() - return t.cast(str, token_auth.token) + return cast(str, token_auth.token) @classmethod def create_token_sync(cls, user_id: int) -> str: return run_sync(cls.create_token(user_id)) @classmethod - async def authenticate(cls, token: str) -> t.Optional[t.Dict]: + async def authenticate(cls, token: str) -> Optional[dict]: return await cls.select(cls.user).where(cls.token == token).first() @classmethod - def authenticate_sync(cls, token: str) -> t.Optional[t.Dict]: + def authenticate_sync(cls, token: str) -> Optional[dict]: return run_sync(cls.authenticate(token)) @classmethod - async def get_user_id(cls, token: str) -> t.Optional[int]: + async def get_user_id(cls, token: str) -> Optional[int]: """ Returns the user_id if the given token is valid, otherwise None. """ diff --git a/piccolo_api/utils/types.py b/piccolo_api/utils/types.py index 63d96f18..6f374c7a 100644 --- a/piccolo_api/utils/types.py +++ b/piccolo_api/utils/types.py @@ -4,7 +4,7 @@ from __future__ import annotations -import typing as t +from typing import Union, get_args, get_origin try: # Python 3.10 and above @@ -15,7 +15,7 @@ class UnionType: # type: ignore ... -def get_type(type_: t.Type) -> t.Type: +def get_type(type_: type) -> type: """ Extract the inner type from an optional if necessary, otherwise return the type as is. @@ -35,12 +35,12 @@ def get_type(type_: t.Type) -> t.Type: list[str] """ - origin = t.get_origin(type_) + origin = get_origin(type_) - # Note: even if `t.Optional` is passed in, the origin is still a - # `t.Union` or `UnionType` depending on the Python version. - if any(origin is i for i in (t.Union, UnionType)): - union_args = t.get_args(type_) + # Note: even if `Optional` is passed in, the origin is still a + # `Union` or `UnionType` depending on the Python version. + if any(origin is i for i in (Union, UnionType)): + union_args = get_args(type_) NoneType = type(None) diff --git a/setup.py b/setup.py index 9758993f..8aa7e28b 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,6 @@ import itertools import os -import typing as t from setuptools import find_packages, setup @@ -29,7 +28,7 @@ ] -def parse_requirement(req_path: str) -> t.List[str]: +def parse_requirement(req_path: str) -> list[str]: """ Parses a requirement file - returning a list of contents. Example:: @@ -42,7 +41,7 @@ def parse_requirement(req_path: str) -> t.List[str]: return [i.strip() for i in contents.strip().split("\n")] -def extras_require() -> t.Dict[str, t.List[str]]: +def extras_require() -> dict[str, list[str]]: """ Parse requirements in requirements/extras directory """ diff --git a/tests/crud/test_validators.py b/tests/crud/test_validators.py index 67548ffd..99a8fc6c 100644 --- a/tests/crud/test_validators.py +++ b/tests/crud/test_validators.py @@ -1,4 +1,4 @@ -import typing as t +from collections.abc import Callable from dataclasses import dataclass from unittest import TestCase @@ -23,7 +23,7 @@ def get_readable(cls): @dataclass class Scenario: - validators: t.List[t.Callable] + validators: list[Callable] status_code: int content: bytes diff --git a/tests/utils/test_types.py b/tests/utils/test_types.py index b5a705d8..34d83a7e 100644 --- a/tests/utils/test_types.py +++ b/tests/utils/test_types.py @@ -1,5 +1,5 @@ import sys -import typing as t +from typing import Optional, Union from unittest import TestCase import pytest @@ -14,12 +14,12 @@ def test_get_type(self): If we pass in an optional type, it should return the non-optional type. """ # Should return the underlying type, as they're all optional: - self.assertIs(get_type(t.Optional[str]), str) - self.assertIs(get_type(t.Optional[t.List[str]]), t.List[str]) - self.assertIs(get_type(t.Union[str, None]), str) + self.assertIs(get_type(Optional[str]), str) + self.assertEqual(get_type(Optional[list[str]]), list[str]) + self.assertIs(get_type(Union[str, None]), str) # Should be returned as is, because it's not optional: - self.assertIs(get_type(t.List[str]), t.List[str]) + self.assertEqual(get_type(list[str]), list[str]) @pytest.mark.skipif( sys.version_info < (3, 10), reason="Union syntax not available"