Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions docs/src/piccolo/query_types/insert.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,21 @@ You can also compose it as follows:
).add(
Band(name="Gophers")
)

-------------------------------------------------------------------------------

on_conflict
-----------

You can use the ``on_conflict`` clause in an insert query.
Piccolo has ``DO_NOTHING`` and ``DO_UPDATE`` clauses:

.. code-block:: python

from piccolo.query.mixins import OnConflict

await Band.insert(
Band(id=1, name="Darts"),
Band(id=2, name="Gophers"),
on_conflict=OnConflict.do_nothing
)
67 changes: 60 additions & 7 deletions piccolo/query/methods/insert.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,12 @@

from piccolo.custom_types import TableInstance
from piccolo.query.base import Query
from piccolo.query.mixins import AddDelegate, ReturningDelegate
from piccolo.query.mixins import (
AddDelegate,
OnConflict,
OnConflictDelegate,
ReturningDelegate,
)
from piccolo.querystring import QueryString

if t.TYPE_CHECKING: # pragma: no cover
Expand All @@ -15,14 +20,24 @@
class Insert(
t.Generic[TableInstance], Query[TableInstance, t.List[t.Dict[str, t.Any]]]
):
__slots__ = ("add_delegate", "returning_delegate")
__slots__ = (
"add_delegate",
"returning_delegate",
"on_conflict_delegate",
)

def __init__(
self, table: t.Type[TableInstance], *instances: TableInstance, **kwargs
self,
table: t.Type[TableInstance],
on_conflict: t.Optional[OnConflict] = None,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would drop consider dropping this, and just using the on_conflict method instead.

The reason being, we can add extra arguments to the on_conflict method in the future, but we don't want to add too many to __init__.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we should change it, but feel free to do as you think is best. With this method, we can also easily solve the problem with duplicate entries in M2M. I'm sorry if I didn't understand well what you wanted to say and feel free to change this if you think on_conflict as method would be better way.

*instances: TableInstance,
**kwargs,
):
super().__init__(table, **kwargs)
self.add_delegate = AddDelegate()
self.returning_delegate = ReturningDelegate()
self.on_conflict_delegate = OnConflictDelegate()
self.on_conflict(on_conflict) # type: ignore
self.add(*instances)

###########################################################################
Expand All @@ -36,6 +51,10 @@ def returning(self: Self, *columns: Column) -> Self:
self.returning_delegate.returning(columns)
return self

def on_conflict(self: Self, conflict: OnConflict) -> Self:
Copy link
Member

@dantownsend dantownsend Apr 5, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you think we should allow the user to pass in specific columns?

await Band.insert(Band(name="Pythonistas")).on_conflict(Band.name, do_nothing=True)

If not specified, then we default to all columns.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you look at the tests, even now the user can update one or more columns on do_update conflict. For do_nothing conflict I don't see point of allowing specifying columns because on do_nothing conflict columns don't change anyway.

self.on_conflict_delegate.on_conflict(conflict)
return self

###########################################################################

def _raw_response_callback(self, results):
Expand All @@ -55,21 +74,55 @@ def _raw_response_callback(self, results):

@property
def default_querystrings(self) -> t.Sequence[QueryString]:
base = f'INSERT INTO "{self.table._meta.tablename}"'
engine_type = self.engine_type
if (
engine_type == "sqlite"
and self.table._meta.db.get_version_sync() < 3.24
): # pragma: no cover
if self.on_conflict_delegate._on_conflict == OnConflict.do_nothing:
base = f'INSERT OR IGNORE INTO "{self.table._meta.tablename}"'
elif (
self.on_conflict_delegate._on_conflict == OnConflict.do_update
):
base = f'INSERT OR REPLACE INTO "{self.table._meta.tablename}"'
else:
raise ValueError("Invalid on conflict value")
else:
base = f'INSERT INTO "{self.table._meta.tablename}"'
columns = ",".join(
f'"{i._meta.db_column_name}"' for i in self.table._meta.columns
)
values = ",".join("{}" for _ in self.add_delegate._add)
query = f"{base} ({columns}) VALUES {values}"
if self.on_conflict_delegate._on_conflict is not None:
if self.on_conflict_delegate._on_conflict == OnConflict.do_nothing:
query = f"""
{base} ({columns}) VALUES {values} ON CONFLICT
{self.on_conflict_delegate._on_conflict.value}
"""
elif (
self.on_conflict_delegate._on_conflict == OnConflict.do_update
):
excluded_updated_columns = ", ".join(
f"{i._meta.db_column_name}=EXCLUDED.{i._meta.db_column_name}" # noqa: E501
for i in self.table._meta.columns
)
query = f"""
{base} ({columns}) VALUES {values} ON CONFLICT
({self.table._meta.primary_key._meta.name})
{self.on_conflict_delegate._on_conflict.value}
SET {excluded_updated_columns}
"""
else:
raise ValueError("Invalid on conflict value")
else:
query = f"{base} ({columns}) VALUES {values}"
querystring = QueryString(
query,
*[i.querystring for i in self.add_delegate._add],
query_type="insert",
table=self.table,
)

engine_type = self.engine_type

if engine_type in ("postgres", "cockroach") or (
engine_type == "sqlite"
and self.table._meta.db.get_version_sync() >= 3.35
Expand Down
20 changes: 13 additions & 7 deletions piccolo/query/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,6 @@ def __str__(self):

@dataclass
class Output:

as_json: bool = False
as_list: bool = False
as_objects: bool = False
Expand Down Expand Up @@ -170,7 +169,6 @@ class Callback:

@dataclass
class WhereDelegate:

_where: t.Optional[Combinable] = None
_where_columns: t.List[Column] = field(default_factory=list)

Expand Down Expand Up @@ -205,7 +203,6 @@ def where(self, *where: Combinable):

@dataclass
class OrderByDelegate:

_order_by: OrderBy = field(default_factory=OrderBy)

def get_order_by_columns(self) -> t.List[Column]:
Expand All @@ -231,7 +228,6 @@ def order_by(self, *columns: t.Union[Column, OrderByRaw], ascending=True):

@dataclass
class LimitDelegate:

_limit: t.Optional[Limit] = None
_first: bool = False

Expand All @@ -258,7 +254,6 @@ def as_of(self, interval: str = "-1s"):

@dataclass
class DistinctDelegate:

_distinct: bool = False

def distinct(self):
Expand All @@ -275,7 +270,6 @@ def returning(self, columns: t.Sequence[Column]):

@dataclass
class CountDelegate:

_count: bool = False

def count(self):
Expand All @@ -284,7 +278,6 @@ def count(self):

@dataclass
class AddDelegate:

_add: t.List[Table] = field(default_factory=list)

def add(self, *instances: Table, table_class: t.Type[Table]):
Expand Down Expand Up @@ -548,3 +541,16 @@ class GroupByDelegate:

def group_by(self, *columns: Column):
self._group_by = GroupBy(columns=columns)


class OnConflict(str, Enum):
do_nothing = "DO NOTHING"
do_update = "DO UPDATE"


@dataclass
class OnConflictDelegate:
_on_conflict: t.Optional[OnConflict] = None

def on_conflict(self, conflict: OnConflict):
self._on_conflict = conflict
9 changes: 7 additions & 2 deletions piccolo/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from piccolo.query.methods.indexes import Indexes
from piccolo.query.methods.objects import First
from piccolo.query.methods.refresh import Refresh
from piccolo.query.mixins import OnConflict
from piccolo.querystring import QueryString, Unquoted
from piccolo.utils import _camel_to_snake
from piccolo.utils.graphlib import TopologicalSorter
Expand Down Expand Up @@ -906,7 +907,9 @@ def ref(cls, column_name: str) -> Column:

@classmethod
def insert(
cls: t.Type[TableInstance], *rows: TableInstance
cls: t.Type[TableInstance],
*rows: TableInstance,
on_conflict: t.Optional[OnConflict] = None,
) -> Insert[TableInstance]:
"""
Insert rows into the database.
Expand All @@ -918,7 +921,9 @@ def insert(
)

"""
query = Insert(table=cls).returning(cls._meta.primary_key)
query = Insert(table=cls, on_conflict=on_conflict).returning(
cls._meta.primary_key
)
if rows:
query.add(*rows)
return query
Expand Down
Loading