diff --git a/docs/src/piccolo/schema/index.rst b/docs/src/piccolo/schema/index.rst index ec9b887e6..bdda8bc4c 100644 --- a/docs/src/piccolo/schema/index.rst +++ b/docs/src/piccolo/schema/index.rst @@ -9,5 +9,6 @@ The schema is how you define your database tables, columns and relationships. ./defining ./column_types ./m2m + ./reverse_lookup ./one_to_one ./advanced diff --git a/docs/src/piccolo/schema/reverse_lookup.rst b/docs/src/piccolo/schema/reverse_lookup.rst new file mode 100644 index 000000000..d1df333f0 --- /dev/null +++ b/docs/src/piccolo/schema/reverse_lookup.rst @@ -0,0 +1,137 @@ +.. currentmodule:: piccolo.columns.reverse_lookup + +############## +Reverse Lookup +############## + +For example, we might have our ``Manager`` table, and we want to +get all the bands associated with the same manager. +For this we can use reverse foreign key lookup. + +We create it in Piccolo like this: + +.. code-block:: python + + from piccolo.columns.column_types import ( + ForeignKey, + LazyTableReference, + Varchar + ) + from piccolo.columns.reverse_lookup import ReverseLookup + from piccolo.table import Table + + + class Manager(Table): + name = Varchar() + bands = ReverseLookup( + LazyTableReference("Band", module_path=__name__), + reverse_fk="manager", + ) + + + class Band(Table): + name = Varchar() + manager = ForeignKey(Manager) + +------------------------------------------------------------------------------- + +Select queries +============== + +If we want to select each manager, along with a list of associated band names, +we can do this: + +.. code-block:: python + + >>> await Manager.select(Manager.name, Manager.bands(Band.name, as_list=True)) + [ + {'name': 'John', 'bands': ['C-Sharps']}, + {'name': 'Guido', 'bands': ['Pythonistas', 'Rustaceans']}, + ] + +You can request whichever column you like from the reverse lookup: + +.. code-block:: python + + >>> await Manager.select(Manager.name, Manager.bands(Band.id, as_list=True)) + [ + {'name': 'John', 'bands': [3]}, + {'name': 'Guido', 'bands': [1, 2]}, + ] + +You can also request multiple columns from the reverse lookup: + +.. code-block:: python + + >>> await Manager.select(Manager.name, Manager.bands(Band.id, Band.name)) + [ + { + 'name': 'John', + 'bands': [ + {'id': 3, 'name': 'C-Sharps'}, + ] + }, + { + 'name': 'Guido', + 'bands': [ + {'id': 1, 'name': 'Pythonistas'}, + {'id': 2, 'name': 'Rustaceans'}, + ] + } + ] + +If you omit the columns argument, then all of the columns are returned. + +.. code-block:: python + + >>> await Manager.select(Manager.name, Manager.bands()) + [ + { + 'name': 'John', + 'bands': [ + {'id': 3, 'name': 'C-Sharps'}, + ] + }, + { + 'name': 'Guido', + 'bands': [ + {'id': 1, 'name': 'Pythonistas'}, + {'id': 2, 'name': 'Rustaceans'}, + ] + } + ] + +The default order of reverse lookup results is ascending, but if you +specify ``descending=True``, you can get the results in descending order. + +.. code-block:: python + + >>> await Manager.select(Manager.name, Manager.bands(descending=True)) + [ + { + 'name': 'John', + 'bands': [ + {'id': 3, 'name': 'C-Sharps'}, + ] + }, + { + 'name': 'Guido', + 'bands': [ + {'id': 2, 'name': 'Rustaceans'}, + {'id': 1, 'name': 'Pythonistas'}, + ] + } + ] + +Object queries +============== + +We can also use object queries to ``ReverseLookup``. + +get_reverse_lookup +------------------ + +.. currentmodule:: piccolo.table + +.. automethod:: Table.get_reverse_lookup + :noindex: diff --git a/piccolo/columns/reverse_lookup.py b/piccolo/columns/reverse_lookup.py new file mode 100644 index 000000000..64772396d --- /dev/null +++ b/piccolo/columns/reverse_lookup.py @@ -0,0 +1,286 @@ +from __future__ import annotations + +import inspect +import typing as t +from dataclasses import dataclass + +from piccolo.columns.base import QueryString, Selectable +from piccolo.columns.column_types import ( + JSON, + JSONB, + Column, + LazyTableReference, +) +from piccolo.utils.sync import run_sync + +if t.TYPE_CHECKING: # pragma: no cover + from piccolo.table import Table + + +class ReverseLookupSelect(Selectable): + """ + This is a subquery used within a select to fetch reverse lookup data. + """ + + def __init__( + self, + *columns: Column, + reverse_lookup: ReverseLookup, + as_list: bool = False, + load_json: bool = False, + descending: bool = False, + ): + """ + :param columns: + Which columns to include from the related table. + :param as_list: + If a single column is provided, and ``as_list`` is ``True`` a + flattened list will be returned, rather than a list of objects. + :param load_json: + If ``True``, any JSON strings are loaded as Python objects. + :param descending: + If ``True'', reverse lookup results sorted in descending order, + otherwise in default ascending order. + + """ + self.as_list = as_list + self.columns = columns + self.reverse_lookup = reverse_lookup + self.load_json = load_json + self.descending = descending + + safe_types = [int, str] + + # If the columns can be serialised / deserialise as JSON, then we + # can fetch the data all in one go. + self.serialisation_safe = all( + (column.__class__.value_type in safe_types) + and (type(column) not in (JSON, JSONB)) + for column in columns + ) + + def get_select_string( + self, engine_type: str, with_alias=True + ) -> QueryString: + reverse_lookup_name = self.reverse_lookup._meta.name + + table1 = self.reverse_lookup._meta.table + table1_pk = table1._meta.primary_key._meta.name + table1_name = table1._meta.tablename + + table2 = self.reverse_lookup._meta.resolved_reverse_joining_table + table2_name = table2._meta.tablename + table2_pk = table2._meta.primary_key._meta.name + table2_fk = self.reverse_lookup._meta.reverse_fk + + reverse_select = f""" + "{table2_name}" + WHERE "{table2_name}"."{table2_fk}" + = "{table1_name}"."{table1_pk}" + """ + + if engine_type in ("postgres", "cockroach"): + if self.as_list: + column_name = self.columns[0]._meta.db_column_name + return QueryString( + f""" + ARRAY( + SELECT + "{table2_name}"."{column_name}" + FROM {reverse_select} + ) AS "{reverse_lookup_name}" + """ + ) + elif not self.serialisation_safe: + column_name = table2_pk + return QueryString( + f""" + ARRAY( + SELECT + "{table2_name}"."{column_name}" + FROM {reverse_select} + ) AS "{reverse_lookup_name}" + """ + ) + else: + if len(self.columns) > 0: + column_names = ", ".join( + f'"{table2_name}"."{column._meta.db_column_name}"' # noqa: E501 + for column in self.columns + ) + else: + column_names = ", ".join( + f'"{table2_name}"."{column._meta.db_column_name}"' # noqa: E501 + for column in table2._meta.columns + ) + return QueryString( + f""" + ( + SELECT JSON_AGG("{table2_name}s") + FROM ( + SELECT {column_names} FROM {reverse_select} + ) AS "{table2_name}s" + ) AS "{reverse_lookup_name}" + """ + ) + elif engine_type == "sqlite": + if len(self.columns) > 1 or not self.serialisation_safe: + column_name = table2_pk + else: + try: + column_name = self.columns[0]._meta.db_column_name + except IndexError: + column_name = table2_pk + + return QueryString( + f""" + ( + SELECT group_concat( + "{table2_name}"."{column_name}" + ) + FROM {reverse_select} + ) + AS "{reverse_lookup_name} [M2M]" + """ + ) + else: + raise ValueError(f"{engine_type} is an unrecognised engine type") + + +@dataclass +class ReverseLookupMeta: + reverse_joining_table: t.Union[t.Type[Table], LazyTableReference] + reverse_fk: str + + # Set by the Table Metaclass: + _name: t.Optional[str] = None + _table: t.Optional[t.Type[Table]] = None + + @property + def name(self) -> str: + if not self._name: + raise ValueError( + "`_name` isn't defined - the Table Metaclass should set it." + ) + return self._name + + @property + def table(self) -> t.Type[Table]: + if not self._table: + raise ValueError( + "`_table` isn't defined - the Table Metaclass should set it." + ) + return self._table + + @property + def resolved_reverse_joining_table(self) -> t.Type[Table]: + """ + Evaluates the ``reverse_joining_table`` attribute if it's a + ``LazyTableReference``, raising a ``ValueError`` if it fails, + otherwise returns a ``Table`` subclass. + """ + from piccolo.table import Table + + if isinstance(self.reverse_joining_table, LazyTableReference): + return self.reverse_joining_table.resolve() + elif inspect.isclass(self.reverse_joining_table) and issubclass( + self.reverse_joining_table, Table + ): + return self.reverse_joining_table + else: + raise ValueError( + "The reverse_joining_table attribute is neither a Table" + " subclass or a LazyTableReference instance." + ) + + +@dataclass +class ReverseLookupGetRelated: + row: Table + reverse_lookup: ReverseLookup + + async def run(self): + primary_table = self.reverse_lookup._meta._table + reverse_lookup_table = ( + self.reverse_lookup._meta.resolved_reverse_joining_table + ) + + for fk_column in reverse_lookup_table._meta.foreign_key_columns: + ids = ( + await primary_table.select( + primary_table._meta.primary_key.join_on( + fk_column + ).all_columns()[0] + ) + .where(primary_table._meta.primary_key == self.row) + .output(as_list=True) + ) + + results = ( + await reverse_lookup_table.objects().where( + reverse_lookup_table._meta.primary_key.is_in(ids) + ) + if len(ids) > 0 + else [] + ) + + return results + + def run_sync(self): + return run_sync(self.run()) + + def __await__(self): + return self.run().__await__() + + +class ReverseLookup: + def __init__( + self, + reverse_joining_table: t.Union[t.Type[Table], LazyTableReference], + reverse_fk: str, + ): + """ + :param reverse_joining_table: + A ``Table`` for reverse lookup. + :param reverse_fk: + The ForeignKey to be used for the reverse lookup. + """ + self._meta = ReverseLookupMeta( + reverse_joining_table=reverse_joining_table, + reverse_fk=reverse_fk, + ) + + def __call__( + self, + *columns: Column, + as_list: bool = False, + load_json: bool = False, + descending: bool = False, + ) -> ReverseLookupSelect: + """ + :param columns: + Which columns to include from the related table. If none are + specified, then all of the columns are returned. + :param as_list: + If a single column is provided, and ``as_list`` is ``True`` a + flattened list will be returned, rather than a list of objects. + :param load_json: + If ``True``, any JSON strings are loaded as Python objects. + :param descending: + If ``True'', reverse lookup results sorted in descending order, + otherwise in default ascending order. + + """ + + if as_list and len(columns) != 1: + raise ValueError( + "`as_list` is only valid with a single column argument" + ) + + return ReverseLookupSelect( + *columns, + reverse_lookup=self, + as_list=as_list, + load_json=load_json, + descending=descending, + ) diff --git a/piccolo/query/methods/select.py b/piccolo/query/methods/select.py index 4ba3a2977..ef868707a 100644 --- a/piccolo/query/methods/select.py +++ b/piccolo/query/methods/select.py @@ -17,6 +17,7 @@ from piccolo.columns.column_types import JSON, JSONB from piccolo.columns.m2m import M2MSelect from piccolo.columns.readable import Readable +from piccolo.columns.reverse_lookup import ReverseLookupSelect from piccolo.custom_types import TableInstance from piccolo.engine.base import BaseBatch from piccolo.query.base import Query @@ -250,33 +251,33 @@ def lock_rows( ) return self - async def _splice_m2m_rows( + async def _splice_related_rows( self, response: list[dict[str, Any]], secondary_table: type[Table], secondary_table_pk: Column, - m2m_name: str, - m2m_select: M2MSelect, + related_name: str, + related_select: Union[M2MSelect, ReverseLookupSelect], as_list: bool = False, ): row_ids = list( - set(itertools.chain(*[row[m2m_name] for row in response])) + set(itertools.chain(*[row[related_name] for row in response])) ) extra_rows = ( ( await secondary_table.select( - *m2m_select.columns, + *related_select.columns, secondary_table_pk.as_alias("mapping_key"), ) .where(secondary_table_pk.is_in(row_ids)) - .output(load_json=m2m_select.load_json) + .output(load_json=related_select.load_json) .run() ) if row_ids else [] ) if as_list: - column_name = m2m_select.columns[0]._meta.name + column_name = related_select.columns[0]._meta.name extra_rows_map = { row["mapping_key"]: row[column_name] for row in extra_rows } @@ -290,15 +291,22 @@ async def _splice_m2m_rows( for row in extra_rows } for row in response: - row[m2m_name] = [extra_rows_map.get(i) for i in row[m2m_name]] + row[related_name] = [ + extra_rows_map.get(i) for i in row[related_name] + ] return response - async def response_handler(self, response): + async def response_handler(self, response: list[dict[str, Any]]): m2m_selects = [ i for i in self.columns_delegate.selected_columns if isinstance(i, M2MSelect) ] + reverse_lookup_selects = [ + i + for i in self.columns_delegate.selected_columns + if isinstance(i, ReverseLookupSelect) + ] for m2m_select in m2m_selects: m2m_name = m2m_select.m2m._meta.name secondary_table = m2m_select.m2m._meta.secondary_table @@ -334,7 +342,7 @@ async def response_handler(self, response): if m2m_select.serialisation_safe: pass else: - response = await self._splice_m2m_rows( + response = await self._splice_related_rows( response, secondary_table, secondary_table_pk, @@ -353,7 +361,7 @@ async def response_handler(self, response): {column_name: i} for i in row[m2m_name] ] else: - response = await self._splice_m2m_rows( + response = await self._splice_related_rows( response, secondary_table, secondary_table_pk, @@ -383,7 +391,7 @@ async def response_handler(self, response): # If the data can't be safely serialised as JSON, we get # back an array of primary key values, and need to # splice in the correct values using Python. - response = await self._splice_m2m_rows( + response = await self._splice_related_rows( response, secondary_table, secondary_table_pk, @@ -391,6 +399,168 @@ async def response_handler(self, response): m2m_select, ) + for reverse_lookup_select in reverse_lookup_selects: + reverse_lookup = reverse_lookup_select.reverse_lookup + reverse_table = ( + reverse_lookup._meta.resolved_reverse_joining_table # noqa: E501 + ) + reverse_lookup_name = reverse_lookup._meta.name + + if self.engine_type == "sqlite": + # With ReverseLookup queries in SQLite, we always get + # the value back as a list of strings, so we need to + # do some type conversion. + value_type = ( + reverse_lookup_select.columns[0].__class__.value_type + if reverse_lookup_select.as_list + and reverse_lookup_select.serialisation_safe + else reverse_table._meta.primary_key.value_type + ) + try: + for row in response: + data = row[reverse_lookup_name] + row[reverse_lookup_name] = ( + [value_type(i) for i in row[reverse_lookup_name]] + if data + else [] + ) + except ValueError: + colored_warning( + "Unable to do type conversion for the " + f"{reverse_lookup_name} relation" + ) + + # If the user requested a single column, we just return that + # from the database. Otherwise we request the primary key + # value, so we can fetch the rest of the data in a subsequent + # SQL query - see below. + if reverse_lookup_select.as_list: + if reverse_lookup_select.serialisation_safe: + pass + else: + response = await self._splice_related_rows( + response, + reverse_table, + reverse_table._meta.primary_key, + reverse_lookup_name, + reverse_lookup_select, + as_list=True, + ) + else: + if ( + len(reverse_lookup_select.columns) == 1 + and reverse_lookup_select.serialisation_safe + ): + column_name = reverse_lookup_select.columns[ + 0 + ]._meta.name + for row in response: + if row[reverse_lookup_name] is None: + row[reverse_lookup_name] = [] + row[reverse_lookup_name] = [ + {column_name: i} + for i in row[reverse_lookup_name] + ] + elif ( + len(reverse_lookup_select.columns) == 0 + and reverse_lookup_select.serialisation_safe + ): + # if user request all columns + row_ids = list( + set( + itertools.chain( + *[ + row[reverse_lookup_name] + for row in response + ] + ) + ) + ) + extra_rows = ( + ( + await reverse_table.select( + *reverse_table._meta.columns, + reverse_table._meta.primary_key.as_alias( + "mapping_key" + ), + ) + .where( + reverse_table._meta.primary_key.is_in( + row_ids + ) + ) + .output( + load_json=reverse_lookup_select.load_json + ) + .run() + ) + if row_ids + else [] + ) + extra_rows_map = { + row["mapping_key"]: { + key: value + for key, value in row.items() + if key != "mapping_key" + } + for row in extra_rows + } + for row in response: + row[reverse_lookup_name] = [ + extra_rows_map.get(i) + for i in row[reverse_lookup_name] + ] + else: + response = await self._splice_related_rows( + response, + reverse_table, + reverse_table._meta.primary_key, + reverse_lookup_name, + reverse_lookup_select, + as_list=False, + ) + if self.engine_type in ("postgres", "cockroach"): + if reverse_lookup_select.as_list: + # We get the data back as an array, and can just return it + # unless it's JSON. + if ( + type(reverse_lookup_select.columns[0]) in (JSON, JSONB) + and reverse_lookup_select.load_json + ): + for row in response: + data = row[str(reverse_lookup_select.columns[0])] + row[str(reverse_lookup_select.columns[0])] = [ + load_json(i) for i in data + ] + + elif reverse_lookup_select.serialisation_safe: + # If the columns requested can be safely serialised, they + # are returned as a JSON string, so we need to deserialise + # it. + for row in response: + data = row[reverse_lookup_name] + row[reverse_lookup_name] = ( + load_json(data) if data else [] + ) + else: + # If the data can't be safely serialised as JSON, we get + # back an array of primary key values, and need to + # splice in the correct values using Python. + response = await self._splice_related_rows( + response, + reverse_table, + reverse_table._meta.primary_key, + reverse_lookup_name, + reverse_lookup_select, + as_list=False, + ) + + if reverse_lookup_select.descending: + for row in response: + row[reverse_lookup_name] = list( + reversed(row[reverse_lookup_name]) + ) + ####################################################################### # If no columns were specified, it's a select *, so we know that diff --git a/piccolo/table.py b/piccolo/table.py index bdfda2cdd..7188cc23a 100644 --- a/piccolo/table.py +++ b/piccolo/table.py @@ -29,6 +29,10 @@ ) from piccolo.columns.readable import Readable from piccolo.columns.reference import LAZY_COLUMN_REFERENCES +from piccolo.columns.reverse_lookup import ( + ReverseLookup, + ReverseLookupGetRelated, +) from piccolo.custom_types import TableInstance from piccolo.engine import Engine, engine_finder from piccolo.query import ( @@ -88,7 +92,9 @@ class TableMeta: tags: list[str] = field(default_factory=list) help_text: Optional[str] = None _db: Optional[Engine] = None - m2m_relationships: list[M2M] = field(default_factory=list) + m2m_relationships: list[Union[M2M, ReverseLookup]] = field( + default_factory=list + ) schema: Optional[str] = None # Records reverse foreign key relationships - i.e. when the current table @@ -278,7 +284,7 @@ def __init_subclass__( email_columns: list[Email] = [] auto_update_columns: list[Column] = [] primary_key: Optional[Column] = None - m2m_relationships: list[M2M] = [] + m2m_relationships: list[Union[M2M, ReverseLookup]] = [] attribute_names = itertools.chain( *[i.__dict__.keys() for i in reversed(cls.__mro__)] @@ -326,7 +332,7 @@ def __init_subclass__( if column._meta.auto_update is not ...: auto_update_columns.append(column) - if isinstance(attribute, M2M): + if isinstance(attribute, (M2M, ReverseLookup)): attribute._meta._name = attribute_name attribute._meta._table = cls m2m_relationships.append(attribute) @@ -731,6 +737,21 @@ def remove_m2m(self, *rows: Table, m2m: M2M) -> M2MRemoveRelated: m2m=m2m, ) + def get_reverse_lookup( + self, reverse_lookup: ReverseLookup + ) -> ReverseLookupGetRelated: + """ + Get all matching rows via the reverse lookup. + + .. code-block:: python + + >>> band = await Band.objects().get(Band.name == "Pythonistas") + >>> await band.get_reverse_lookup(Band.genres) + [, ] + + """ + return ReverseLookupGetRelated(row=self, reverse_lookup=reverse_lookup) + def to_dict(self, *columns: Column) -> dict[str, Any]: """ A convenience method which returns a dictionary, mapping column names @@ -1478,7 +1499,7 @@ def _table_str( for m2m_relationship in cls._meta.m2m_relationships: joining_table_name = ( - m2m_relationship._meta.resolved_joining_table.__name__ + m2m_relationship._meta.resolved_joining_table.__name__ # type: ignore # noqa: E501 ) columns.append( f"{m2m_relationship._meta.name} = M2M({joining_table_name})" diff --git a/tests/columns/test_reverse_lookup.py b/tests/columns/test_reverse_lookup.py new file mode 100644 index 000000000..bd9abd545 --- /dev/null +++ b/tests/columns/test_reverse_lookup.py @@ -0,0 +1,383 @@ +from unittest import TestCase + +from piccolo.columns.column_types import ( + UUID, + ForeignKey, + LazyTableReference, + Varchar, +) +from piccolo.columns.reverse_lookup import ReverseLookup +from piccolo.table import Table, create_db_tables_sync, drop_db_tables_sync +from tests.base import engine_is, engines_skip + + +class Manager(Table): + name = Varchar() + bands = ReverseLookup( + LazyTableReference( + "Band", + module_path=__name__, + ), + reverse_fk="manager", + ) + + +class Band(Table): + name = Varchar() + manager = ForeignKey(Manager) + + +SIMPLE_SCHEMA = [Manager, Band] + + +class TestReverseLookup(TestCase): + def setUp(self): + create_db_tables_sync(*SIMPLE_SCHEMA, if_not_exists=True) + + if engine_is("cockroach"): + managers = ( + Manager.insert( + Manager(name="Guido"), + Manager(name="Mark"), + Manager(name="John"), + ) + .returning(Manager.id) + .run_sync() + ) + + Band.insert( + Band(name="Pythonistas", manager=managers[0]["id"]), + Band(name="Rustaceans", manager=managers[0]["id"]), + Band(name="C-Sharps", manager=managers[1]["id"]), + ).returning(Band.id).run_sync() + + else: + Manager.insert( + Manager(name="Guido"), + Manager(name="Mark"), + Manager(name="John"), + ).run_sync() + + Band.insert( + Band(name="Pythonistas", manager=1), + Band(name="Rustaceans", manager=1), + Band(name="C-Sharps", manager=2), + ).run_sync() + + def tearDown(self): + drop_db_tables_sync(*SIMPLE_SCHEMA) + + def test_select_name(self): + response = Manager.select( + Manager.name, Manager.bands(Band.name, as_list=True) + ).run_sync() + + self.assertEqual( + response, + [ + {"name": "Guido", "bands": ["Pythonistas", "Rustaceans"]}, + {"name": "Mark", "bands": ["C-Sharps"]}, + {"name": "John", "bands": []}, + ], + ) + + def test_select_multiple(self): + response = Manager.select( + Manager.name, Manager.bands(Band.id, Band.name) + ).run_sync() + + if engine_is("cockroach"): + self.assertEqual(len(response[0]["bands"]), 2) + self.assertEqual(len(response[1]["bands"]), 1) + self.assertEqual(len(response[2]["bands"]), 0) + else: + self.assertEqual( + response, + [ + { + "name": "Guido", + "bands": [ + {"id": 1, "name": "Pythonistas"}, + {"id": 2, "name": "Rustaceans"}, + ], + }, + {"name": "Mark", "bands": [{"id": 3, "name": "C-Sharps"}]}, + { + "name": "John", + "bands": [], + }, + ], + ) + + def test_select_multiple_all_columns(self): + response = Manager.select(Manager.name, Manager.bands()).run_sync() + + if engine_is("cockroach"): + self.assertEqual(len(response[0]["bands"]), 2) + self.assertEqual(len(response[1]["bands"]), 1) + self.assertEqual(len(response[2]["bands"]), 0) + else: + self.assertEqual( + response, + [ + { + "name": "Guido", + "bands": [ + {"id": 1, "name": "Pythonistas", "manager": 1}, + {"id": 2, "name": "Rustaceans", "manager": 1}, + ], + }, + { + "name": "Mark", + "bands": [{"id": 3, "name": "C-Sharps", "manager": 2}], + }, + { + "name": "John", + "bands": [], + }, + ], + ) + + def test_select_id(self): + response = Manager.select( + Manager.name, Manager.bands(Band.id, as_list=True) + ).run_sync() + + if engine_is("cockroach"): + self.assertEqual(len(response[0]["bands"]), 2) + self.assertEqual(len(response[1]["bands"]), 1) + self.assertEqual(len(response[2]["bands"]), 0) + else: + self.assertEqual( + response, + [ + {"name": "Guido", "bands": [1, 2]}, + {"name": "Mark", "bands": [3]}, + {"name": "John", "bands": []}, + ], + ) + + def test_select_multiple_all_columns_descending(self): + response = Manager.select( + Manager.name, Manager.bands(descending=True) + ).run_sync() + + if engine_is("cockroach"): + self.assertEqual(len(response[0]["bands"]), 2) + self.assertEqual(len(response[1]["bands"]), 1) + self.assertEqual(len(response[2]["bands"]), 0) + else: + self.assertEqual( + response, + [ + { + "name": "Guido", + "bands": [ + {"id": 2, "name": "Rustaceans", "manager": 1}, + {"id": 1, "name": "Pythonistas", "manager": 1}, + ], + }, + { + "name": "Mark", + "bands": [{"id": 3, "name": "C-Sharps", "manager": 2}], + }, + { + "name": "John", + "bands": [], + }, + ], + ) + + def test_select_id_descending(self): + response = Manager.select( + Manager.name, Manager.bands(Band.id, as_list=True, descending=True) + ).run_sync() + + if engine_is("cockroach"): + self.assertEqual(len(response[0]["bands"]), 2) + self.assertEqual(len(response[1]["bands"]), 1) + self.assertEqual(len(response[2]["bands"]), 0) + else: + self.assertEqual( + response, + [ + {"name": "Guido", "bands": [2, 1]}, + {"name": "Mark", "bands": [3]}, + {"name": "John", "bands": []}, + ], + ) + + def test_select_multiple_as_list_error(self): + with self.assertRaises(ValueError): + Manager.select( + Manager.name, + Manager.bands(Band.id, Band.name, as_list=True), + ).run_sync() + + def test_objects_query(self): + manager = Manager.objects().get(Manager.name == "Guido").run_sync() + bands = manager.get_reverse_lookup(Manager.bands).run_sync() + response = { + "name": manager.name, + "bands": [i.to_dict() for i in bands], + } + + if engine_is("cockroach"): + self.assertEqual(len(response["bands"]), 2) + else: + self.assertEqual( + response, + { + "name": "Guido", + "bands": [ + {"id": 1, "name": "Pythonistas", "manager": 1}, + {"id": 2, "name": "Rustaceans", "manager": 1}, + ], + }, + ) + + +############################################################################### + +# A schema using custom primary keys + + +class Customer(Table): + uuid = UUID(primary_key=True) + name = Varchar() + concerts = ReverseLookup( + LazyTableReference( + "Concert", + module_path=__name__, + ), + reverse_fk="customer", + ) + + +class Concert(Table): + uuid = UUID(primary_key=True) + name = Varchar() + customer = ForeignKey(Customer) + + +CUSTOM_PK_SCHEMA = [Customer, Concert] + + +class TestReverseLookupCustomPrimaryKey(TestCase): + """ + Make sure the ReverseLookupCustom functionality works correctly + when the tables have custom primary key columns. + """ + + def setUp(self): + create_db_tables_sync(*CUSTOM_PK_SCHEMA, if_not_exists=True) + + Customer.objects().create(name="Bob").run_sync() + Customer.objects().create(name="Sally").run_sync() + Customer.objects().create(name="Fred").run_sync() + + bob_pk = ( + Customer.select(Customer.uuid) + .where(Customer.name == "Bob") + .first() + .run_sync() + ) + sally_pk = ( + Customer.select(Customer.uuid) + .where(Customer.name == "Sally") + .first() + .run_sync() + ) + + Concert.objects().create( + name="Rockfest", customer=bob_pk["uuid"] + ).run_sync() + Concert.objects().create( + name="Folkfest", customer=bob_pk["uuid"] + ).run_sync() + Concert.objects().create( + name="Classicfest", customer=sally_pk["uuid"] + ).run_sync() + + def tearDown(self): + drop_db_tables_sync(*CUSTOM_PK_SCHEMA) + + @engines_skip("cockroach") + def test_select_custom_primary_key(self): + response = Customer.select( + Customer.name, + Customer.concerts(Concert.name, as_list=True), + ).run_sync() + + self.assertListEqual( + response, + [ + {"name": "Bob", "concerts": ["Rockfest", "Folkfest"]}, + {"name": "Sally", "concerts": ["Classicfest"]}, + {"name": "Fred", "concerts": []}, + ], + ) + + response = Customer.select( + Customer.name, Customer.concerts(Concert.name) + ).run_sync() + + self.assertEqual( + response, + [ + { + "name": "Bob", + "concerts": [ + {"name": "Rockfest"}, + {"name": "Folkfest"}, + ], + }, + {"name": "Sally", "concerts": [{"name": "Classicfest"}]}, + {"name": "Fred", "concerts": []}, + ], + ) + + response = Customer.select( + Customer.name, Customer.concerts(Concert.name, descending=True) + ).run_sync() + + self.assertEqual( + response, + [ + { + "name": "Bob", + "concerts": [ + {"name": "Folkfest"}, + {"name": "Rockfest"}, + ], + }, + {"name": "Sally", "concerts": [{"name": "Classicfest"}]}, + {"name": "Fred", "concerts": []}, + ], + ) + + def test_objects_custom_primary_key(self): + customer_bob = ( + Customer.objects().get(Customer.name == "Bob").run_sync() + ) + concerts_bob = customer_bob.get_reverse_lookup( + Customer.concerts + ).run_sync() + + customer_sally = ( + Customer.objects().get(Customer.name == "Sally").run_sync() + ) + concerts_sally = customer_sally.get_reverse_lookup( + Customer.concerts + ).run_sync() + + customer_fred = ( + Customer.objects().get(Customer.name == "Fred").run_sync() + ) + concerts_fred = customer_fred.get_reverse_lookup( + Customer.concerts + ).run_sync() + + self.assertEqual(len(concerts_bob), 2) + self.assertEqual(len(concerts_sally), 1) + self.assertEqual(len(concerts_fred), 0)