diff --git a/django_cte/cte.py b/django_cte/cte.py index 11743df..7faad61 100644 --- a/django_cte/cte.py +++ b/django_cte/cte.py @@ -1,5 +1,6 @@ from copy import copy +import django from django.db.models import Manager, sql from django.db.models.expressions import Ref from django.db.models.query import Q, QuerySet, ValuesIterable @@ -45,21 +46,30 @@ class CTE: """ def __init__(self, queryset, name="cte", materialized=False): - self.query = None if queryset is None else queryset.query + self._set_queryset(queryset) self.name = name self.col = CTEColumns(self) self.materialized = materialized def __getstate__(self): - return (self.query, self.name, self.materialized) + return (self.query, self.name, self.materialized, self._iterable_class) def __setstate__(self, state): - self.query, self.name, self.materialized = state + if len(state) == 3: + # Keep compatibility with the previous serialization method + self.query, self.name, self.materialized = state + self._iterable_class = ValuesIterable + else: + self.query, self.name, self.materialized, self._iterable_class = state self.col = CTEColumns(self) def __repr__(self): return f"<{type(self).__name__} {self.name}>" + def _set_queryset(self, queryset): + self.query = None if queryset is None else queryset.query + self._iterable_class = getattr(queryset, "_iterable_class", ValuesIterable) + @classmethod def recursive(cls, make_cte_queryset, name="cte", materialized=False): """Recursive Common Table Expression @@ -73,7 +83,7 @@ def recursive(cls, make_cte_queryset, name="cte", materialized=False): :returns: The fully constructed recursive cte object. """ cte = cls(None, name, materialized) - cte.query = make_cte_queryset(cte).query + cte._set_queryset(make_cte_queryset(cte)) return cte def join(self, model_or_queryset, *filter_q, **filter_kw): @@ -124,24 +134,30 @@ def queryset(self): """ cte_query = self.query qs = cte_query.model._default_manager.get_queryset() + qs._iterable_class = self._iterable_class + qs._fields = () # Allow any field names to be used in further annotations query = jit_mixin(sql.Query(cte_query.model), CTEQuery) query.join(BaseTable(self.name, None)) query.default_cols = cte_query.default_cols query.deferred_loading = cte_query.deferred_loading - if cte_query.values_select: + + if django.VERSION < (5, 2) and cte_query.values_select: query.set_values(cte_query.values_select) - qs._iterable_class = ValuesIterable + if cte_query.annotations: for alias, value in cte_query.annotations.items(): col = CTEColumnRef(alias, self.name, value.output_field) query.add_annotation(col, alias) query.annotation_select_mask = cte_query.annotation_select_mask - for alias in getattr(cte_query, "selected", None) or (): - if alias not in cte_query.annotations: - output_field = cte_query.resolve_ref(alias).output_field - col = CTEColumnRef(alias, self.name, output_field) - query.add_annotation(col, alias) + + if selected := getattr(cte_query, "selected", None): + for alias in selected: + if alias not in cte_query.annotations: + output_field = cte_query.resolve_ref(alias).output_field + col = CTEColumnRef(alias, self.name, output_field) + query.add_annotation(col, alias) + query.selected = {alias: alias for alias in selected} qs.query = query return qs diff --git a/tests/test_cte.py b/tests/test_cte.py index 1c28d4c..f3831a9 100644 --- a/tests/test_cte.py +++ b/tests/test_cte.py @@ -1,4 +1,5 @@ import pytest +import django from django.db.models import IntegerField, TextField from django.db.models.aggregates import Count, Max, Min, Sum from django.db.models.expressions import ( @@ -711,3 +712,129 @@ def test_django52_ambiguous_column_names(self): ('venus', 22, "admin"), ('venus', 23, "admin"), ]) + + def test_django52_queryset_aggregates_klass_error(self): + cte = CTE( + Order.objects.annotate(user_name=F("user__name")) + .values("user_name") + .annotate(c=Count("user_name")) + .values("user_name", "c") + ) + qs = with_cte(cte, select=cte) + # Executing the query should not raise TypeError: 'NoneType' object is not subscriptable + self.assertEqual(list(qs), [{"user_name": "admin", "c": 22}]) + + def test_django52_annotate_model_field_name_after_queryset(self): + # Select the `id` field in one CTE + cte = CTE(Order.objects.values("id", "region", "user_id")) + # In the next query, when querying from the CTE we reassign the `id` field + # Previously, this would have thrown an exception + qs = ( + with_cte(cte, select=cte) + .annotate(id=F('user_id')) + .values_list('id', 'region') + .order_by('id', 'region') + .distinct() + ) + self.assertEqual(list(qs), [ + (1, 'earth'), + (1, 'mars'), + (1, 'mercury'), + (1, 'moon'), + (1, 'proxima centauri'), + (1, 'proxima centauri b'), + (1, 'sun'), + (1, 'venus'), + ]) + + @pytest.mark.skipif(django.VERSION < (5, 2), reason="Requires Django 5.2+") + def test_queryset_after_values_list(self): + cte = CTE(Order.objects.values_list("region", "amount").order_by("region", "amount")) + qs = with_cte(cte, select=cte) + self.assertEqual(list(qs), [ + ('earth', 30), + ('earth', 31), + ('earth', 32), + ('earth', 33), + ('mars', 40), + ('mars', 41), + ('mars', 42), + ('mercury', 10), + ('mercury', 11), + ('mercury', 12), + ('moon', 1), + ('moon', 2), + ('moon', 3), + ('proxima centauri', 2000), + ('proxima centauri b', 10), + ('proxima centauri b', 11), + ('proxima centauri b', 12), + ('sun', 1000), + ('venus', 20), + ('venus', 21), + ('venus', 22), + ('venus', 23), + ]) + + @pytest.mark.skipif(django.VERSION < (5, 2), reason="Requires Django 5.2+") + def test_queryset_after_values_list_flat(self): + cte = CTE( + Order.objects.values_list("region", flat=True) + .order_by("region") + .distinct() + ) + qs = with_cte(cte, select=cte) + self.assertEqual(list(qs), [ + 'earth', + 'mars', + 'mercury', + 'moon', + 'proxima centauri', + 'proxima centauri b', + 'sun', + 'venus' + ]) + + @pytest.mark.skipif(django.VERSION < (5, 2), reason="Requires Django 5.2+") + def test_queryset_values_list_order1(self): + cte = CTE( + Order.objects.values("region") + .annotate(c=Count("region")) + .values_list("c", "region") + .order_by("region") + ) + qs = with_cte(cte, select=cte) + # Ensure the column order of queried fields is the specified one: c, region + # Before the fix, the order would have been this one: region, c + self.assertEqual(list(qs), [ + (4, 'earth'), + (3, 'mars'), + (3, 'mercury'), + (3, 'moon'), + (1, 'proxima centauri'), + (3, 'proxima centauri b'), + (1, 'sun'), + (4, 'venus'), + ]) + + @pytest.mark.skipif(django.VERSION < (5, 2), reason="Requires Django 5.2+") + def test_queryset_values_list_order2(self): + cte = CTE( + Order.objects.values("region") + .annotate(r=F("region"), c=Count("region")) + .values_list("c", "r") + .order_by("r") + ) + qs = with_cte(cte, select=cte) + # Ensure the column order of queried fields is the specified one: c, r + # Before the fix, the order would have been this one: r, c + self.assertEqual(list(qs), [ + (4, 'earth'), + (3, 'mars'), + (3, 'mercury'), + (3, 'moon'), + (1, 'proxima centauri'), + (3, 'proxima centauri b'), + (1, 'sun'), + (4, 'venus'), + ])