From debe88b858314e54b921feb2d28540c7d02f49b4 Mon Sep 17 00:00:00 2001 From: sinisaos Date: Fri, 31 Mar 2023 07:38:57 +0200 Subject: [PATCH 1/5] upserting api attempt --- piccolo/query/methods/insert.py | 47 ++++++++++++-- piccolo/query/mixins.py | 20 +++--- piccolo/table.py | 9 ++- tests/table/test_insert.py | 106 +++++++++++++++++++++++++++++++- 4 files changed, 167 insertions(+), 15 deletions(-) diff --git a/piccolo/query/methods/insert.py b/piccolo/query/methods/insert.py index 9f31f445a..0c38f5ff5 100644 --- a/piccolo/query/methods/insert.py +++ b/piccolo/query/methods/insert.py @@ -4,7 +4,12 @@ from piccolo.custom_types import TableInstance from piccolo.query.base import Query -from piccolo.query.mixins import AddDelegate, ReturningDelegate +from piccolo.query.mixins import ( + AddDelegate, + Conflict, + OnConflictDelegate, + ReturningDelegate, +) from piccolo.querystring import QueryString if t.TYPE_CHECKING: # pragma: no cover @@ -15,14 +20,25 @@ class Insert( t.Generic[TableInstance], Query[TableInstance, t.List[t.Dict[str, t.Any]]] ): - __slots__ = ("add_delegate", "returning_delegate") + __slots__ = ( + "add_delegate", + "returning_delegate", + "on_conflict_delegate", + "values_delegate", + ) def __init__( - self, table: t.Type[TableInstance], *instances: TableInstance, **kwargs + self, + table: t.Type[TableInstance], + on_conflict: t.Optional[Conflict] = None, + *instances: TableInstance, + **kwargs, ): super().__init__(table, **kwargs) self.add_delegate = AddDelegate() self.returning_delegate = ReturningDelegate() + self.on_conflict_delegate = OnConflictDelegate() + self.on_conflict(on_conflict) # type: ignore self.add(*instances) ########################################################################### @@ -36,6 +52,10 @@ def returning(self: Self, *columns: Column) -> Self: self.returning_delegate.returning(columns) return self + def on_conflict(self: Self, conflict: Conflict) -> Self: + self.on_conflict_delegate.on_conflict(conflict) + return self + ########################################################################### def _raw_response_callback(self, results): @@ -60,7 +80,25 @@ def default_querystrings(self) -> t.Sequence[QueryString]: f'"{i._meta.db_column_name}"' for i in self.table._meta.columns ) values = ",".join("{}" for _ in self.add_delegate._add) - query = f"{base} ({columns}) VALUES {values}" + if self.on_conflict_delegate._on_conflict is not None: + if self.on_conflict_delegate._on_conflict.value == "DO NOTHING": + query = f""" + {base} ({columns}) VALUES {values} ON CONFLICT + {self.on_conflict_delegate._on_conflict.value} + """ + else: + excluded_updated_columns = ", ".join( + f"{i._meta.db_column_name}=EXCLUDED.{i._meta.db_column_name}" # noqa: E501 + for i in self.table._meta.columns + ) + query = f""" + {base} ({columns}) VALUES {values} ON CONFLICT + ({self.table._meta.primary_key._meta.name}) + {self.on_conflict_delegate._on_conflict.value} + SET {excluded_updated_columns} + """ + else: + query = f"{base} ({columns}) VALUES {values}" querystring = QueryString( query, *[i.querystring for i in self.add_delegate._add], @@ -84,7 +122,6 @@ def default_querystrings(self) -> t.Sequence[QueryString]: table=self.table, ) ] - return [querystring] diff --git a/piccolo/query/mixins.py b/piccolo/query/mixins.py index 397f84f99..91ec166c8 100644 --- a/piccolo/query/mixins.py +++ b/piccolo/query/mixins.py @@ -141,7 +141,6 @@ def __str__(self): @dataclass class Output: - as_json: bool = False as_list: bool = False as_objects: bool = False @@ -170,7 +169,6 @@ class Callback: @dataclass class WhereDelegate: - _where: t.Optional[Combinable] = None _where_columns: t.List[Column] = field(default_factory=list) @@ -205,7 +203,6 @@ def where(self, *where: Combinable): @dataclass class OrderByDelegate: - _order_by: OrderBy = field(default_factory=OrderBy) def get_order_by_columns(self) -> t.List[Column]: @@ -231,7 +228,6 @@ def order_by(self, *columns: t.Union[Column, OrderByRaw], ascending=True): @dataclass class LimitDelegate: - _limit: t.Optional[Limit] = None _first: bool = False @@ -258,7 +254,6 @@ def as_of(self, interval: str = "-1s"): @dataclass class DistinctDelegate: - _distinct: bool = False def distinct(self): @@ -275,7 +270,6 @@ def returning(self, columns: t.Sequence[Column]): @dataclass class CountDelegate: - _count: bool = False def count(self): @@ -284,7 +278,6 @@ def count(self): @dataclass class AddDelegate: - _add: t.List[Table] = field(default_factory=list) def add(self, *instances: Table, table_class: t.Type[Table]): @@ -548,3 +541,16 @@ class GroupByDelegate: def group_by(self, *columns: Column): self._group_by = GroupBy(columns=columns) + + +class Conflict(Enum): + do_nothing = "DO NOTHING" + do_update = "DO UPDATE" + + +@dataclass +class OnConflictDelegate: + _on_conflict: t.Optional[Conflict] = None + + def on_conflict(self, conflict: Conflict): + self._on_conflict = conflict diff --git a/piccolo/table.py b/piccolo/table.py index 2982180bd..0034c6179 100644 --- a/piccolo/table.py +++ b/piccolo/table.py @@ -46,6 +46,7 @@ from piccolo.query.methods.indexes import Indexes from piccolo.query.methods.objects import First from piccolo.query.methods.refresh import Refresh +from piccolo.query.mixins import Conflict from piccolo.querystring import QueryString, Unquoted from piccolo.utils import _camel_to_snake from piccolo.utils.graphlib import TopologicalSorter @@ -906,7 +907,9 @@ def ref(cls, column_name: str) -> Column: @classmethod def insert( - cls: t.Type[TableInstance], *rows: TableInstance + cls: t.Type[TableInstance], + *rows: TableInstance, + on_conflict: t.Optional[Conflict] = None, ) -> Insert[TableInstance]: """ Insert rows into the database. @@ -918,7 +921,9 @@ def insert( ) """ - query = Insert(table=cls).returning(cls._meta.primary_key) + query = Insert(table=cls, on_conflict=on_conflict).returning( + cls._meta.primary_key + ) if rows: query.add(*rows) return query diff --git a/tests/table/test_insert.py b/tests/table/test_insert.py index 474497f25..da9a14b67 100644 --- a/tests/table/test_insert.py +++ b/tests/table/test_insert.py @@ -1,6 +1,12 @@ import pytest -from tests.base import DBTestCase, engine_version_lt, is_running_sqlite +from piccolo.query.mixins import Conflict +from tests.base import ( + DBTestCase, + engine_version_lt, + engines_only, + is_running_sqlite, +) from tests.example_apps.music.tables import Band, Manager @@ -45,6 +51,104 @@ def test_insert_curly_braces(self): self.assertIn("{}", names) + @engines_only("postgres", "sqlite") + def test_insert_on_conflict_do_nothing(self): + """ + Check that the record has not changed because of the + `on_conflict` clause. + """ + self.insert_rows() + + Band.insert( + Band(id=1, name="Javas", popularity=100), + on_conflict=Conflict.do_nothing, + ).run_sync() + + response = ( + Band.select(Band.name).where(Band.id == 1).first().run_sync() + ) + self.assertEqual(response["name"], "Pythonistas") + + @engines_only("postgres", "sqlite") + def test_insert_on_conflict_do_update_single_column(self): + """ + Check that the record has changed because of the + `on_update` clause. + """ + self.insert_rows() + + Band.insert( + Band(id=1, name="Pythonstas-updated", manager=1, popularity=1000), + Band(id=2, name="Rustaceans-updated", manager=2, popularity=2000), + Band(id=3, name="CSharps-updated", manager=3, popularity=10), + on_conflict=Conflict.do_update, + ).run_sync() + + response = Band.select().run_sync() + self.assertEqual( + response, + [ + { + "id": 1, + "name": "Pythonstas-updated", + "manager": 1, + "popularity": 1000, + }, + { + "id": 2, + "name": "Rustaceans-updated", + "manager": 2, + "popularity": 2000, + }, + { + "id": 3, + "name": "CSharps-updated", + "manager": 3, + "popularity": 10, + }, + ], + ) + + @engines_only("postgres", "sqlite") + def test_insert_on_conflict_do_update_multiple_columns(self): + """ + Check that the record has changed because of the + `on_update` clause. + """ + self.insert_rows() + + Band.insert( + Band(id=1, name="Pythonstas-updated", manager=3, popularity=200), + Band(id=2, name="Rustaceans-updated", manager=2, popularity=1000), + Band(id=3, name="CSharps-updated", manager=1, popularity=20), + on_conflict=Conflict.do_update, + ).run_sync() + + response = Band.select().run_sync() + self.assertEqual( + response, + [ + { + "id": 1, + "name": "Pythonstas-updated", + "manager": 3, + "popularity": 200, + }, + { + "id": 2, + "name": "Rustaceans-updated", + "manager": 2, + "popularity": 1000, + }, + { + "id": 3, + "name": "CSharps-updated", + "manager": 1, + "popularity": 20, + }, + ], + ) + @pytest.mark.skipif( is_running_sqlite() and engine_version_lt(3.35), reason="SQLite version not supported", From dcafcf1e4873073f398050d3838a8c4cf31480fc Mon Sep 17 00:00:00 2001 From: sinisaos Date: Fri, 31 Mar 2023 14:35:03 +0200 Subject: [PATCH 2/5] support Sqlite less than version 3.24.0 --- piccolo/query/methods/insert.py | 24 ++++++++++++++++-------- piccolo/query/mixins.py | 6 +++--- piccolo/table.py | 4 ++-- tests/table/test_insert.py | 8 ++++---- 4 files changed, 25 insertions(+), 17 deletions(-) diff --git a/piccolo/query/methods/insert.py b/piccolo/query/methods/insert.py index 0c38f5ff5..45064d6f7 100644 --- a/piccolo/query/methods/insert.py +++ b/piccolo/query/methods/insert.py @@ -6,7 +6,7 @@ from piccolo.query.base import Query from piccolo.query.mixins import ( AddDelegate, - Conflict, + OnConflict, OnConflictDelegate, ReturningDelegate, ) @@ -24,13 +24,12 @@ class Insert( "add_delegate", "returning_delegate", "on_conflict_delegate", - "values_delegate", ) def __init__( self, table: t.Type[TableInstance], - on_conflict: t.Optional[Conflict] = None, + on_conflict: t.Optional[OnConflict] = None, *instances: TableInstance, **kwargs, ): @@ -52,7 +51,7 @@ def returning(self: Self, *columns: Column) -> Self: self.returning_delegate.returning(columns) return self - def on_conflict(self: Self, conflict: Conflict) -> Self: + def on_conflict(self: Self, conflict: OnConflict) -> Self: self.on_conflict_delegate.on_conflict(conflict) return self @@ -75,13 +74,23 @@ def _raw_response_callback(self, results): @property def default_querystrings(self) -> t.Sequence[QueryString]: - base = f'INSERT INTO "{self.table._meta.tablename}"' + engine_type = self.engine_type + if ( + engine_type == "sqlite" + and self.table._meta.db.get_version_sync() < 3.24 + ): + if self.on_conflict_delegate._on_conflict == OnConflict.do_nothing: + base = f'INSERT OR IGNORE INTO "{self.table._meta.tablename}"' + else: + base = f'INSERT OR REPLACE INTO "{self.table._meta.tablename}"' + else: + base = f'INSERT INTO "{self.table._meta.tablename}"' columns = ",".join( f'"{i._meta.db_column_name}"' for i in self.table._meta.columns ) values = ",".join("{}" for _ in self.add_delegate._add) if self.on_conflict_delegate._on_conflict is not None: - if self.on_conflict_delegate._on_conflict.value == "DO NOTHING": + if self.on_conflict_delegate._on_conflict == OnConflict.do_nothing: query = f""" {base} ({columns}) VALUES {values} ON CONFLICT {self.on_conflict_delegate._on_conflict.value} @@ -106,8 +115,6 @@ def default_querystrings(self) -> t.Sequence[QueryString]: table=self.table, ) - engine_type = self.engine_type - if engine_type in ("postgres", "cockroach") or ( engine_type == "sqlite" and self.table._meta.db.get_version_sync() >= 3.35 @@ -122,6 +129,7 @@ def default_querystrings(self) -> t.Sequence[QueryString]: table=self.table, ) ] + return [querystring] diff --git a/piccolo/query/mixins.py b/piccolo/query/mixins.py index 91ec166c8..9c96dbfb8 100644 --- a/piccolo/query/mixins.py +++ b/piccolo/query/mixins.py @@ -543,14 +543,14 @@ def group_by(self, *columns: Column): self._group_by = GroupBy(columns=columns) -class Conflict(Enum): +class OnConflict(str, Enum): do_nothing = "DO NOTHING" do_update = "DO UPDATE" @dataclass class OnConflictDelegate: - _on_conflict: t.Optional[Conflict] = None + _on_conflict: t.Optional[OnConflict] = None - def on_conflict(self, conflict: Conflict): + def on_conflict(self, conflict: OnConflict): self._on_conflict = conflict diff --git a/piccolo/table.py b/piccolo/table.py index 0034c6179..375d62b7e 100644 --- a/piccolo/table.py +++ b/piccolo/table.py @@ -46,7 +46,7 @@ from piccolo.query.methods.indexes import Indexes from piccolo.query.methods.objects import First from piccolo.query.methods.refresh import Refresh -from piccolo.query.mixins import Conflict +from piccolo.query.mixins import OnConflict from piccolo.querystring import QueryString, Unquoted from piccolo.utils import _camel_to_snake from piccolo.utils.graphlib import TopologicalSorter @@ -909,7 +909,7 @@ def ref(cls, column_name: str) -> Column: def insert( cls: t.Type[TableInstance], *rows: TableInstance, - on_conflict: t.Optional[Conflict] = None, + on_conflict: t.Optional[OnConflict] = None, ) -> Insert[TableInstance]: """ Insert rows into the database. diff --git a/tests/table/test_insert.py b/tests/table/test_insert.py index da9a14b67..cd9eb0990 100644 --- a/tests/table/test_insert.py +++ b/tests/table/test_insert.py @@ -1,6 +1,6 @@ import pytest -from piccolo.query.mixins import Conflict +from piccolo.query.mixins import OnConflict from tests.base import ( DBTestCase, engine_version_lt, @@ -61,7 +61,7 @@ def test_insert_on_conflict_do_nothing(self): Band.insert( Band(id=1, name="Javas", popularity=100), - on_conflict=Conflict.do_nothing, + on_conflict=OnConflict.do_nothing, ).run_sync() response = ( @@ -81,7 +81,7 @@ def test_insert_on_conflict_do_update_single_column(self): Band(id=1, name="Pythonstas-updated", manager=1, popularity=1000), Band(id=2, name="Rustaceans-updated", manager=2, popularity=2000), Band(id=3, name="CSharps-updated", manager=3, popularity=10), - on_conflict=Conflict.do_update, + on_conflict=OnConflict.do_update, ).run_sync() response = Band.select().run_sync() @@ -121,7 +121,7 @@ def test_insert_on_conflict_do_update_multiple_columns(self): Band(id=1, name="Pythonstas-updated", manager=3, popularity=200), Band(id=2, name="Rustaceans-updated", manager=2, popularity=1000), Band(id=3, name="CSharps-updated", manager=1, popularity=20), - on_conflict=Conflict.do_update, + on_conflict=OnConflict.do_update, ).run_sync() response = Band.select().run_sync() From 3c1b74fc37879bce09159f066a4053837d66d908 Mon Sep 17 00:00:00 2001 From: sinisaos Date: Fri, 31 Mar 2023 17:22:06 +0200 Subject: [PATCH 3/5] CockroachDB tests --- piccolo/query/methods/insert.py | 2 +- tests/table/test_insert.py | 143 ++++++++++++++++++++++++++++++++ 2 files changed, 144 insertions(+), 1 deletion(-) diff --git a/piccolo/query/methods/insert.py b/piccolo/query/methods/insert.py index 45064d6f7..e2575482b 100644 --- a/piccolo/query/methods/insert.py +++ b/piccolo/query/methods/insert.py @@ -78,7 +78,7 @@ def default_querystrings(self) -> t.Sequence[QueryString]: if ( engine_type == "sqlite" and self.table._meta.db.get_version_sync() < 3.24 - ): + ): # pragma: no cover if self.on_conflict_delegate._on_conflict == OnConflict.do_nothing: base = f'INSERT OR IGNORE INTO "{self.table._meta.tablename}"' else: diff --git a/tests/table/test_insert.py b/tests/table/test_insert.py index cd9eb0990..2cc234978 100644 --- a/tests/table/test_insert.py +++ b/tests/table/test_insert.py @@ -149,6 +149,149 @@ def test_insert_on_conflict_do_update_multiple_columns(self): ], ) + @engines_only("cockroach") + def test_insert_on_conflict_do_nothing_cockroach(self): + """ + Check that the record has not changed because of the + `on_conflict` clause. + """ + self.insert_rows() + + results = Band.select().run_sync() + + Band.insert( + Band( + id=results[0]["id"], + name="Javas", + manager=results[0]["manager"], + popularity=100, + ), + on_conflict=OnConflict.do_nothing, + ).run_sync() + + response = ( + Band.select(Band.name) + .where(Band.id == results[0]["id"]) + .first() + .run_sync() + ) + self.assertEqual(response["name"], "Pythonistas") + + @engines_only("cockroach") + def test_insert_on_conflict_do_update_single_column_cockroach(self): + """ + Check that the record has changed because of the + `on_update` clause. + """ + self.insert_rows() + + results = Band.select().run_sync() + + Band.insert( + Band( + id=results[0]["id"], + name="Pythonstas-updated", + manager=results[0]["manager"], + popularity=1000, + ), + Band( + id=results[1]["id"], + name="Rustaceans-updated", + manager=results[1]["manager"], + popularity=2000, + ), + Band( + id=results[2]["id"], + name="CSharps-updated", + manager=results[2]["manager"], + popularity=10, + ), + on_conflict=OnConflict.do_update, + ).run_sync() + + response = Band.select().run_sync() + + self.assertEqual( + response, + [ + { + "id": results[0]["id"], + "name": "Pythonstas-updated", + "manager": results[0]["manager"], + "popularity": 1000, + }, + { + "id": results[1]["id"], + "name": "Rustaceans-updated", + "manager": results[1]["manager"], + "popularity": 2000, + }, + { + "id": results[2]["id"], + "name": "CSharps-updated", + "manager": results[2]["manager"], + "popularity": 10, + }, + ], + ) + + @engines_only("cockroach") + def test_insert_on_conflict_do_update_multiple_columns_cockroach(self): + """ + Check that the record has changed because of the + `on_update` clause. + """ + self.insert_rows() + + results = Band.select().run_sync() + + Band.insert( + Band( + id=results[0]["id"], + name="Pythonstas-updated", + manager=results[2]["manager"], + popularity=200, + ), + Band( + id=results[1]["id"], + name="Rustaceans-updated", + manager=results[1]["manager"], + popularity=1000, + ), + Band( + id=results[2]["id"], + name="CSharps-updated", + manager=results[0]["manager"], + popularity=20, + ), + on_conflict=OnConflict.do_update, + ).run_sync() + + response = Band.select().run_sync() + self.assertEqual( + response, + [ + { + "id": results[0]["id"], + "name": "Pythonstas-updated", + "manager": results[2]["manager"], + "popularity": 200, + }, + { + "id": results[1]["id"], + "name": "Rustaceans-updated", + "manager": results[1]["manager"], + "popularity": 1000, + }, + { + "id": results[2]["id"], + "name": "CSharps-updated", + "manager": results[0]["manager"], + "popularity": 20, + }, + ], + ) + @pytest.mark.skipif( is_running_sqlite() and engine_version_lt(3.35), reason="SQLite version not supported", From e6fd34b97dbdf9bdfbd5c2657a73c1ac676cf24a Mon Sep 17 00:00:00 2001 From: sinisaos Date: Fri, 31 Mar 2023 17:50:13 +0200 Subject: [PATCH 4/5] add docs --- docs/src/piccolo/query_types/insert.rst | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/docs/src/piccolo/query_types/insert.rst b/docs/src/piccolo/query_types/insert.rst index eda460de1..61d554a2c 100644 --- a/docs/src/piccolo/query_types/insert.rst +++ b/docs/src/piccolo/query_types/insert.rst @@ -33,3 +33,21 @@ You can also compose it as follows: ).add( Band(name="Gophers") ) + +------------------------------------------------------------------------------- + +on_conflict +----------- + +You can use the ``on_conflict`` clause in an insert query. +Piccolo has ``DO_NOTHING`` and ``DO_UPDATE`` clauses: + +.. code-block:: python + + from piccolo.query.mixins import OnConflict + + await Band.insert( + Band(id=1, name="Darts"), + Band(id=2, name="Gophers"), + on_conflict=OnConflict.do_nothing + ) From 3b566eaa22ccb6d554b50544ecdead4c1fb340b8 Mon Sep 17 00:00:00 2001 From: sinisaos Date: Sun, 2 Apr 2023 10:04:17 +0200 Subject: [PATCH 5/5] add on_conflict check wrong value --- piccolo/query/methods/insert.py | 12 ++++++++++-- tests/table/test_insert.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 2 deletions(-) diff --git a/piccolo/query/methods/insert.py b/piccolo/query/methods/insert.py index e2575482b..c8e5ff6a5 100644 --- a/piccolo/query/methods/insert.py +++ b/piccolo/query/methods/insert.py @@ -81,8 +81,12 @@ def default_querystrings(self) -> t.Sequence[QueryString]: ): # pragma: no cover if self.on_conflict_delegate._on_conflict == OnConflict.do_nothing: base = f'INSERT OR IGNORE INTO "{self.table._meta.tablename}"' - else: + elif ( + self.on_conflict_delegate._on_conflict == OnConflict.do_update + ): base = f'INSERT OR REPLACE INTO "{self.table._meta.tablename}"' + else: + raise ValueError("Invalid on conflict value") else: base = f'INSERT INTO "{self.table._meta.tablename}"' columns = ",".join( @@ -95,7 +99,9 @@ def default_querystrings(self) -> t.Sequence[QueryString]: {base} ({columns}) VALUES {values} ON CONFLICT {self.on_conflict_delegate._on_conflict.value} """ - else: + elif ( + self.on_conflict_delegate._on_conflict == OnConflict.do_update + ): excluded_updated_columns = ", ".join( f"{i._meta.db_column_name}=EXCLUDED.{i._meta.db_column_name}" # noqa: E501 for i in self.table._meta.columns @@ -106,6 +112,8 @@ def default_querystrings(self) -> t.Sequence[QueryString]: {self.on_conflict_delegate._on_conflict.value} SET {excluded_updated_columns} """ + else: + raise ValueError("Invalid on conflict value") else: query = f"{base} ({columns}) VALUES {values}" querystring = QueryString( diff --git a/tests/table/test_insert.py b/tests/table/test_insert.py index 2cc234978..c48371edc 100644 --- a/tests/table/test_insert.py +++ b/tests/table/test_insert.py @@ -69,6 +69,20 @@ def test_insert_on_conflict_do_nothing(self): ) self.assertEqual(response["name"], "Pythonistas") + @engines_only("postgres", "sqlite") + def test_insert_on_conflict_value_error(self): + """ + Check ValueError if user pass wrong value to + `on_conflict` clause. + """ + self.insert_rows() + + with self.assertRaises(ValueError): + Band.insert( + Band(id=1, name="Javas", popularity=100), + on_conflict="do_nothing", + ).run_sync() + @engines_only("postgres", "sqlite") def test_insert_on_conflict_do_update_single_column(self): """ @@ -177,6 +191,20 @@ def test_insert_on_conflict_do_nothing_cockroach(self): ) self.assertEqual(response["name"], "Pythonistas") + @engines_only("cockroach") + def test_insert_on_conflict_value_error_cockroach(self): + """ + Check ValueError if user pass wrong value to + `on_conflict` clause. + """ + self.insert_rows() + + with self.assertRaises(ValueError): + Band.insert( + Band(id=1, name="Javas", popularity=100), + on_conflict="do_nothing", + ).run_sync() + @engines_only("cockroach") def test_insert_on_conflict_do_update_single_column_cockroach(self): """