Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 27 additions & 11 deletions django_cte/cte.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from copy import copy

import django
from django.db.models import Manager, sql
from django.db.models.expressions import Ref
from django.db.models.query import Q, QuerySet, ValuesIterable
Expand Down Expand Up @@ -45,21 +46,30 @@ class CTE:
"""

def __init__(self, queryset, name="cte", materialized=False):
self.query = None if queryset is None else queryset.query
self._set_queryset(queryset)
self.name = name
self.col = CTEColumns(self)
self.materialized = materialized

def __getstate__(self):
return (self.query, self.name, self.materialized)
return (self.query, self.name, self.materialized, self._iterable_class)

def __setstate__(self, state):
self.query, self.name, self.materialized = state
if len(state) == 3:
# Keep compatibility with the previous serialization method
self.query, self.name, self.materialized = state
self._iterable_class = ValuesIterable
else:
self.query, self.name, self.materialized, self._iterable_class = state
self.col = CTEColumns(self)

def __repr__(self):
return f"<{type(self).__name__} {self.name}>"

def _set_queryset(self, queryset):
self.query = None if queryset is None else queryset.query
self._iterable_class = getattr(queryset, "_iterable_class", ValuesIterable)

@classmethod
def recursive(cls, make_cte_queryset, name="cte", materialized=False):
"""Recursive Common Table Expression
Expand All @@ -73,7 +83,7 @@ def recursive(cls, make_cte_queryset, name="cte", materialized=False):
:returns: The fully constructed recursive cte object.
"""
cte = cls(None, name, materialized)
cte.query = make_cte_queryset(cte).query
cte._set_queryset(make_cte_queryset(cte))
return cte

def join(self, model_or_queryset, *filter_q, **filter_kw):
Expand Down Expand Up @@ -124,24 +134,30 @@ def queryset(self):
"""
cte_query = self.query
qs = cte_query.model._default_manager.get_queryset()
qs._iterable_class = self._iterable_class
qs._fields = () # Allow any field names to be used in further annotations

query = jit_mixin(sql.Query(cte_query.model), CTEQuery)
query.join(BaseTable(self.name, None))
query.default_cols = cte_query.default_cols
query.deferred_loading = cte_query.deferred_loading
if cte_query.values_select:

if django.VERSION < (5, 2) and cte_query.values_select:
query.set_values(cte_query.values_select)
qs._iterable_class = ValuesIterable

if cte_query.annotations:
for alias, value in cte_query.annotations.items():
col = CTEColumnRef(alias, self.name, value.output_field)
query.add_annotation(col, alias)
query.annotation_select_mask = cte_query.annotation_select_mask
for alias in getattr(cte_query, "selected", None) or ():
if alias not in cte_query.annotations:
output_field = cte_query.resolve_ref(alias).output_field
col = CTEColumnRef(alias, self.name, output_field)
query.add_annotation(col, alias)

if selected := getattr(cte_query, "selected", None):
for alias in selected:
if alias not in cte_query.annotations:
output_field = cte_query.resolve_ref(alias).output_field
col = CTEColumnRef(alias, self.name, output_field)
query.add_annotation(col, alias)
query.selected = {alias: alias for alias in selected}

qs.query = query
return qs
Expand Down
127 changes: 127 additions & 0 deletions tests/test_cte.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pytest
import django
from django.db.models import IntegerField, TextField
from django.db.models.aggregates import Count, Max, Min, Sum
from django.db.models.expressions import (
Expand Down Expand Up @@ -711,3 +712,129 @@ def test_django52_ambiguous_column_names(self):
('venus', 22, "admin"),
('venus', 23, "admin"),
])

def test_django52_queryset_aggregates_klass_error(self):
cte = CTE(
Order.objects.annotate(user_name=F("user__name"))
.values("user_name")
.annotate(c=Count("user_name"))
.values("user_name", "c")
)
qs = with_cte(cte, select=cte)
# Executing the query should not raise TypeError: 'NoneType' object is not subscriptable
self.assertEqual(list(qs), [{"user_name": "admin", "c": 22}])

def test_django52_annotate_model_field_name_after_queryset(self):
# Select the `id` field in one CTE
cte = CTE(Order.objects.values("id", "region", "user_id"))
# In the next query, when querying from the CTE we reassign the `id` field
# Previously, this would have thrown an exception
qs = (
with_cte(cte, select=cte)
.annotate(id=F('user_id'))
.values_list('id', 'region')
.order_by('id', 'region')
.distinct()
)
self.assertEqual(list(qs), [
(1, 'earth'),
(1, 'mars'),
(1, 'mercury'),
(1, 'moon'),
(1, 'proxima centauri'),
(1, 'proxima centauri b'),
(1, 'sun'),
(1, 'venus'),
])

@pytest.mark.skipif(django.VERSION < (5, 2), reason="Requires Django 5.2+")
def test_queryset_after_values_list(self):
cte = CTE(Order.objects.values_list("region", "amount").order_by("region", "amount"))
qs = with_cte(cte, select=cte)
self.assertEqual(list(qs), [
('earth', 30),
('earth', 31),
('earth', 32),
('earth', 33),
('mars', 40),
('mars', 41),
('mars', 42),
('mercury', 10),
('mercury', 11),
('mercury', 12),
('moon', 1),
('moon', 2),
('moon', 3),
('proxima centauri', 2000),
('proxima centauri b', 10),
('proxima centauri b', 11),
('proxima centauri b', 12),
('sun', 1000),
('venus', 20),
('venus', 21),
('venus', 22),
('venus', 23),
])

@pytest.mark.skipif(django.VERSION < (5, 2), reason="Requires Django 5.2+")
def test_queryset_after_values_list_flat(self):
cte = CTE(
Order.objects.values_list("region", flat=True)
.order_by("region")
.distinct()
)
qs = with_cte(cte, select=cte)
self.assertEqual(list(qs), [
'earth',
'mars',
'mercury',
'moon',
'proxima centauri',
'proxima centauri b',
'sun',
'venus'
])

@pytest.mark.skipif(django.VERSION < (5, 2), reason="Requires Django 5.2+")
def test_queryset_values_list_order1(self):
cte = CTE(
Order.objects.values("region")
.annotate(c=Count("region"))
.values_list("c", "region")
.order_by("region")
)
qs = with_cte(cte, select=cte)
# Ensure the column order of queried fields is the specified one: c, region
# Before the fix, the order would have been this one: region, c
self.assertEqual(list(qs), [
(4, 'earth'),
(3, 'mars'),
(3, 'mercury'),
(3, 'moon'),
(1, 'proxima centauri'),
(3, 'proxima centauri b'),
(1, 'sun'),
(4, 'venus'),
])

@pytest.mark.skipif(django.VERSION < (5, 2), reason="Requires Django 5.2+")
def test_queryset_values_list_order2(self):
cte = CTE(
Order.objects.values("region")
.annotate(r=F("region"), c=Count("region"))
.values_list("c", "r")
.order_by("r")
)
qs = with_cte(cte, select=cte)
# Ensure the column order of queried fields is the specified one: c, r
# Before the fix, the order would have been this one: r, c
self.assertEqual(list(qs), [
(4, 'earth'),
(3, 'mars'),
(3, 'mercury'),
(3, 'moon'),
(1, 'proxima centauri'),
(3, 'proxima centauri b'),
(1, 'sun'),
(4, 'venus'),
])