55
66from piccolo .apps .migrations .auto .operations import (
77 AddColumn ,
8+ AddConstraint ,
89 AlterColumn ,
910 DropColumn ,
11+ DropConstraint ,
1012)
1113from piccolo .apps .migrations .auto .serialisation import (
1214 deserialise_params ,
1315 serialise_params ,
1416)
1517from piccolo .columns .base import Column
18+ from piccolo .constraint import Constraint
1619from piccolo .table import Table , create_table_class
1720
1821
@@ -55,6 +58,8 @@ class TableDelta:
5558 add_columns : t .List [AddColumn ] = field (default_factory = list )
5659 drop_columns : t .List [DropColumn ] = field (default_factory = list )
5760 alter_columns : t .List [AlterColumn ] = field (default_factory = list )
61+ add_constraints : t .List [AddConstraint ] = field (default_factory = list )
62+ drop_constraints : t .List [DropConstraint ] = field (default_factory = list )
5863
5964 def __eq__ (self , value : TableDelta ) -> bool : # type: ignore
6065 """
@@ -85,6 +90,19 @@ def __eq__(self, value) -> bool:
8590 return False
8691
8792
93+ @dataclass
94+ class ConstraintComparison :
95+ constraint : Constraint
96+
97+ def __hash__ (self ) -> int :
98+ return self .constraint .__hash__ ()
99+
100+ def __eq__ (self , value ) -> bool :
101+ if isinstance (value , ConstraintComparison ):
102+ return self .constraint ._meta .name == value .constraint ._meta .name
103+ return False
104+
105+
88106@dataclass
89107class DiffableTable :
90108 """
@@ -96,6 +114,7 @@ class DiffableTable:
96114 tablename : str
97115 schema : t .Optional [str ] = None
98116 columns : t .List [Column ] = field (default_factory = list )
117+ constraints : t .List [Constraint ] = field (default_factory = list )
99118 previous_class_name : t .Optional [str ] = None
100119
101120 def __post_init__ (self ) -> None :
@@ -189,10 +208,54 @@ def __sub__(self, value: DiffableTable) -> TableDelta:
189208 )
190209 )
191210
211+ add_constraints = [
212+ AddConstraint (
213+ table_class_name = self .class_name ,
214+ constraint_name = i .constraint ._meta .name ,
215+ constraint_class_name = i .constraint .__class__ .__name__ ,
216+ constraint_class = i .constraint .__class__ ,
217+ params = i .constraint ._meta .params ,
218+ schema = self .schema ,
219+ )
220+ for i in sorted (
221+ {
222+ ConstraintComparison (constraint = constraint )
223+ for constraint in self .constraints
224+ }
225+ - {
226+ ConstraintComparison (constraint = constraint )
227+ for constraint in value .constraints
228+ },
229+ key = lambda x : x .constraint ._meta .name ,
230+ )
231+ ]
232+
233+ drop_constraints = [
234+ DropConstraint (
235+ table_class_name = self .class_name ,
236+ constraint_name = i .constraint ._meta .name ,
237+ tablename = value .tablename ,
238+ schema = self .schema ,
239+ )
240+ for i in sorted (
241+ {
242+ ConstraintComparison (constraint = constraint )
243+ for constraint in value .constraints
244+ }
245+ - {
246+ ConstraintComparison (constraint = constraint )
247+ for constraint in self .constraints
248+ },
249+ key = lambda x : x .constraint ._meta .name ,
250+ )
251+ ]
252+
192253 return TableDelta (
193254 add_columns = add_columns ,
194255 drop_columns = drop_columns ,
195256 alter_columns = alter_columns ,
257+ add_constraints = add_constraints ,
258+ drop_constraints = drop_constraints ,
196259 )
197260
198261 def __hash__ (self ) -> int :
@@ -218,10 +281,14 @@ def to_table_class(self) -> t.Type[Table]:
218281 """
219282 Converts the DiffableTable into a Table subclass.
220283 """
284+ class_members : t .Dict [str , t .Any ] = {}
285+ for column in self .columns :
286+ class_members [column ._meta .name ] = column
287+ for constraint in self .constraints :
288+ class_members [constraint ._meta .name ] = constraint
289+
221290 return create_table_class (
222291 class_name = self .class_name ,
223292 class_kwargs = {"tablename" : self .tablename , "schema" : self .schema },
224- class_members = {
225- column ._meta .name : column for column in self .columns
226- },
293+ class_members = class_members ,
227294 )
0 commit comments