diff --git a/docs/src/piccolo/query_types/insert.rst b/docs/src/piccolo/query_types/insert.rst index eda460de1..61d554a2c 100644 --- a/docs/src/piccolo/query_types/insert.rst +++ b/docs/src/piccolo/query_types/insert.rst @@ -33,3 +33,21 @@ You can also compose it as follows: ).add( Band(name="Gophers") ) + +------------------------------------------------------------------------------- + +on_conflict +----------- + +You can use the ``on_conflict`` clause in an insert query. +Piccolo has ``DO_NOTHING`` and ``DO_UPDATE`` clauses: + +.. code-block:: python + + from piccolo.query.mixins import OnConflict + + await Band.insert( + Band(id=1, name="Darts"), + Band(id=2, name="Gophers"), + on_conflict=OnConflict.do_nothing + ) diff --git a/piccolo/query/methods/insert.py b/piccolo/query/methods/insert.py index 9f31f445a..c8e5ff6a5 100644 --- a/piccolo/query/methods/insert.py +++ b/piccolo/query/methods/insert.py @@ -4,7 +4,12 @@ from piccolo.custom_types import TableInstance from piccolo.query.base import Query -from piccolo.query.mixins import AddDelegate, ReturningDelegate +from piccolo.query.mixins import ( + AddDelegate, + OnConflict, + OnConflictDelegate, + ReturningDelegate, +) from piccolo.querystring import QueryString if t.TYPE_CHECKING: # pragma: no cover @@ -15,14 +20,24 @@ class Insert( t.Generic[TableInstance], Query[TableInstance, t.List[t.Dict[str, t.Any]]] ): - __slots__ = ("add_delegate", "returning_delegate") + __slots__ = ( + "add_delegate", + "returning_delegate", + "on_conflict_delegate", + ) def __init__( - self, table: t.Type[TableInstance], *instances: TableInstance, **kwargs + self, + table: t.Type[TableInstance], + on_conflict: t.Optional[OnConflict] = None, + *instances: TableInstance, + **kwargs, ): super().__init__(table, **kwargs) self.add_delegate = AddDelegate() self.returning_delegate = ReturningDelegate() + self.on_conflict_delegate = OnConflictDelegate() + self.on_conflict(on_conflict) # type: ignore self.add(*instances) ########################################################################### @@ -36,6 +51,10 @@ def returning(self: Self, *columns: Column) -> Self: self.returning_delegate.returning(columns) return self + def on_conflict(self: Self, conflict: OnConflict) -> Self: + self.on_conflict_delegate.on_conflict(conflict) + return self + ########################################################################### def _raw_response_callback(self, results): @@ -55,12 +74,48 @@ def _raw_response_callback(self, results): @property def default_querystrings(self) -> t.Sequence[QueryString]: - base = f'INSERT INTO "{self.table._meta.tablename}"' + engine_type = self.engine_type + if ( + engine_type == "sqlite" + and self.table._meta.db.get_version_sync() < 3.24 + ): # pragma: no cover + if self.on_conflict_delegate._on_conflict == OnConflict.do_nothing: + base = f'INSERT OR IGNORE INTO "{self.table._meta.tablename}"' + elif ( + self.on_conflict_delegate._on_conflict == OnConflict.do_update + ): + base = f'INSERT OR REPLACE INTO "{self.table._meta.tablename}"' + else: + raise ValueError("Invalid on conflict value") + else: + base = f'INSERT INTO "{self.table._meta.tablename}"' columns = ",".join( f'"{i._meta.db_column_name}"' for i in self.table._meta.columns ) values = ",".join("{}" for _ in self.add_delegate._add) - query = f"{base} ({columns}) VALUES {values}" + if self.on_conflict_delegate._on_conflict is not None: + if self.on_conflict_delegate._on_conflict == OnConflict.do_nothing: + query = f""" + {base} ({columns}) VALUES {values} ON CONFLICT + {self.on_conflict_delegate._on_conflict.value} + """ + elif ( + self.on_conflict_delegate._on_conflict == OnConflict.do_update + ): + excluded_updated_columns = ", ".join( + f"{i._meta.db_column_name}=EXCLUDED.{i._meta.db_column_name}" # noqa: E501 + for i in self.table._meta.columns + ) + query = f""" + {base} ({columns}) VALUES {values} ON CONFLICT + ({self.table._meta.primary_key._meta.name}) + {self.on_conflict_delegate._on_conflict.value} + SET {excluded_updated_columns} + """ + else: + raise ValueError("Invalid on conflict value") + else: + query = f"{base} ({columns}) VALUES {values}" querystring = QueryString( query, *[i.querystring for i in self.add_delegate._add], @@ -68,8 +123,6 @@ def default_querystrings(self) -> t.Sequence[QueryString]: table=self.table, ) - engine_type = self.engine_type - if engine_type in ("postgres", "cockroach") or ( engine_type == "sqlite" and self.table._meta.db.get_version_sync() >= 3.35 diff --git a/piccolo/query/mixins.py b/piccolo/query/mixins.py index 397f84f99..9c96dbfb8 100644 --- a/piccolo/query/mixins.py +++ b/piccolo/query/mixins.py @@ -141,7 +141,6 @@ def __str__(self): @dataclass class Output: - as_json: bool = False as_list: bool = False as_objects: bool = False @@ -170,7 +169,6 @@ class Callback: @dataclass class WhereDelegate: - _where: t.Optional[Combinable] = None _where_columns: t.List[Column] = field(default_factory=list) @@ -205,7 +203,6 @@ def where(self, *where: Combinable): @dataclass class OrderByDelegate: - _order_by: OrderBy = field(default_factory=OrderBy) def get_order_by_columns(self) -> t.List[Column]: @@ -231,7 +228,6 @@ def order_by(self, *columns: t.Union[Column, OrderByRaw], ascending=True): @dataclass class LimitDelegate: - _limit: t.Optional[Limit] = None _first: bool = False @@ -258,7 +254,6 @@ def as_of(self, interval: str = "-1s"): @dataclass class DistinctDelegate: - _distinct: bool = False def distinct(self): @@ -275,7 +270,6 @@ def returning(self, columns: t.Sequence[Column]): @dataclass class CountDelegate: - _count: bool = False def count(self): @@ -284,7 +278,6 @@ def count(self): @dataclass class AddDelegate: - _add: t.List[Table] = field(default_factory=list) def add(self, *instances: Table, table_class: t.Type[Table]): @@ -548,3 +541,16 @@ class GroupByDelegate: def group_by(self, *columns: Column): self._group_by = GroupBy(columns=columns) + + +class OnConflict(str, Enum): + do_nothing = "DO NOTHING" + do_update = "DO UPDATE" + + +@dataclass +class OnConflictDelegate: + _on_conflict: t.Optional[OnConflict] = None + + def on_conflict(self, conflict: OnConflict): + self._on_conflict = conflict diff --git a/piccolo/table.py b/piccolo/table.py index 2982180bd..375d62b7e 100644 --- a/piccolo/table.py +++ b/piccolo/table.py @@ -46,6 +46,7 @@ from piccolo.query.methods.indexes import Indexes from piccolo.query.methods.objects import First from piccolo.query.methods.refresh import Refresh +from piccolo.query.mixins import OnConflict from piccolo.querystring import QueryString, Unquoted from piccolo.utils import _camel_to_snake from piccolo.utils.graphlib import TopologicalSorter @@ -906,7 +907,9 @@ def ref(cls, column_name: str) -> Column: @classmethod def insert( - cls: t.Type[TableInstance], *rows: TableInstance + cls: t.Type[TableInstance], + *rows: TableInstance, + on_conflict: t.Optional[OnConflict] = None, ) -> Insert[TableInstance]: """ Insert rows into the database. @@ -918,7 +921,9 @@ def insert( ) """ - query = Insert(table=cls).returning(cls._meta.primary_key) + query = Insert(table=cls, on_conflict=on_conflict).returning( + cls._meta.primary_key + ) if rows: query.add(*rows) return query diff --git a/tests/table/test_insert.py b/tests/table/test_insert.py index 474497f25..c48371edc 100644 --- a/tests/table/test_insert.py +++ b/tests/table/test_insert.py @@ -1,6 +1,12 @@ import pytest -from tests.base import DBTestCase, engine_version_lt, is_running_sqlite +from piccolo.query.mixins import OnConflict +from tests.base import ( + DBTestCase, + engine_version_lt, + engines_only, + is_running_sqlite, +) from tests.example_apps.music.tables import Band, Manager @@ -45,6 +51,275 @@ def test_insert_curly_braces(self): self.assertIn("{}", names) + @engines_only("postgres", "sqlite") + def test_insert_on_conflict_do_nothing(self): + """ + Check that the record has not changed because of the + `on_conflict` clause. + """ + self.insert_rows() + + Band.insert( + Band(id=1, name="Javas", popularity=100), + on_conflict=OnConflict.do_nothing, + ).run_sync() + + response = ( + Band.select(Band.name).where(Band.id == 1).first().run_sync() + ) + self.assertEqual(response["name"], "Pythonistas") + + @engines_only("postgres", "sqlite") + def test_insert_on_conflict_value_error(self): + """ + Check ValueError if user pass wrong value to + `on_conflict` clause. + """ + self.insert_rows() + + with self.assertRaises(ValueError): + Band.insert( + Band(id=1, name="Javas", popularity=100), + on_conflict="do_nothing", + ).run_sync() + + @engines_only("postgres", "sqlite") + def test_insert_on_conflict_do_update_single_column(self): + """ + Check that the record has changed because of the + `on_update` clause. + """ + self.insert_rows() + + Band.insert( + Band(id=1, name="Pythonstas-updated", manager=1, popularity=1000), + Band(id=2, name="Rustaceans-updated", manager=2, popularity=2000), + Band(id=3, name="CSharps-updated", manager=3, popularity=10), + on_conflict=OnConflict.do_update, + ).run_sync() + + response = Band.select().run_sync() + self.assertEqual( + response, + [ + { + "id": 1, + "name": "Pythonstas-updated", + "manager": 1, + "popularity": 1000, + }, + { + "id": 2, + "name": "Rustaceans-updated", + "manager": 2, + "popularity": 2000, + }, + { + "id": 3, + "name": "CSharps-updated", + "manager": 3, + "popularity": 10, + }, + ], + ) + + @engines_only("postgres", "sqlite") + def test_insert_on_conflict_do_update_multiple_columns(self): + """ + Check that the record has changed because of the + `on_update` clause. + """ + self.insert_rows() + + Band.insert( + Band(id=1, name="Pythonstas-updated", manager=3, popularity=200), + Band(id=2, name="Rustaceans-updated", manager=2, popularity=1000), + Band(id=3, name="CSharps-updated", manager=1, popularity=20), + on_conflict=OnConflict.do_update, + ).run_sync() + + response = Band.select().run_sync() + self.assertEqual( + response, + [ + { + "id": 1, + "name": "Pythonstas-updated", + "manager": 3, + "popularity": 200, + }, + { + "id": 2, + "name": "Rustaceans-updated", + "manager": 2, + "popularity": 1000, + }, + { + "id": 3, + "name": "CSharps-updated", + "manager": 1, + "popularity": 20, + }, + ], + ) + + @engines_only("cockroach") + def test_insert_on_conflict_do_nothing_cockroach(self): + """ + Check that the record has not changed because of the + `on_conflict` clause. + """ + self.insert_rows() + + results = Band.select().run_sync() + + Band.insert( + Band( + id=results[0]["id"], + name="Javas", + manager=results[0]["manager"], + popularity=100, + ), + on_conflict=OnConflict.do_nothing, + ).run_sync() + + response = ( + Band.select(Band.name) + .where(Band.id == results[0]["id"]) + .first() + .run_sync() + ) + self.assertEqual(response["name"], "Pythonistas") + + @engines_only("cockroach") + def test_insert_on_conflict_value_error_cockroach(self): + """ + Check ValueError if user pass wrong value to + `on_conflict` clause. + """ + self.insert_rows() + + with self.assertRaises(ValueError): + Band.insert( + Band(id=1, name="Javas", popularity=100), + on_conflict="do_nothing", + ).run_sync() + + @engines_only("cockroach") + def test_insert_on_conflict_do_update_single_column_cockroach(self): + """ + Check that the record has changed because of the + `on_update` clause. + """ + self.insert_rows() + + results = Band.select().run_sync() + + Band.insert( + Band( + id=results[0]["id"], + name="Pythonstas-updated", + manager=results[0]["manager"], + popularity=1000, + ), + Band( + id=results[1]["id"], + name="Rustaceans-updated", + manager=results[1]["manager"], + popularity=2000, + ), + Band( + id=results[2]["id"], + name="CSharps-updated", + manager=results[2]["manager"], + popularity=10, + ), + on_conflict=OnConflict.do_update, + ).run_sync() + + response = Band.select().run_sync() + + self.assertEqual( + response, + [ + { + "id": results[0]["id"], + "name": "Pythonstas-updated", + "manager": results[0]["manager"], + "popularity": 1000, + }, + { + "id": results[1]["id"], + "name": "Rustaceans-updated", + "manager": results[1]["manager"], + "popularity": 2000, + }, + { + "id": results[2]["id"], + "name": "CSharps-updated", + "manager": results[2]["manager"], + "popularity": 10, + }, + ], + ) + + @engines_only("cockroach") + def test_insert_on_conflict_do_update_multiple_columns_cockroach(self): + """ + Check that the record has changed because of the + `on_update` clause. + """ + self.insert_rows() + + results = Band.select().run_sync() + + Band.insert( + Band( + id=results[0]["id"], + name="Pythonstas-updated", + manager=results[2]["manager"], + popularity=200, + ), + Band( + id=results[1]["id"], + name="Rustaceans-updated", + manager=results[1]["manager"], + popularity=1000, + ), + Band( + id=results[2]["id"], + name="CSharps-updated", + manager=results[0]["manager"], + popularity=20, + ), + on_conflict=OnConflict.do_update, + ).run_sync() + + response = Band.select().run_sync() + self.assertEqual( + response, + [ + { + "id": results[0]["id"], + "name": "Pythonstas-updated", + "manager": results[2]["manager"], + "popularity": 200, + }, + { + "id": results[1]["id"], + "name": "Rustaceans-updated", + "manager": results[1]["manager"], + "popularity": 1000, + }, + { + "id": results[2]["id"], + "name": "CSharps-updated", + "manager": results[0]["manager"], + "popularity": 20, + }, + ], + ) + @pytest.mark.skipif( is_running_sqlite() and engine_version_lt(3.35), reason="SQLite version not supported",