Skip to content

Commit e966d06

Browse files
committed
Use mapped_column to improve sqlalchemy typing.
1 parent b3cb6c1 commit e966d06

File tree

2 files changed

+229
-199
lines changed

2 files changed

+229
-199
lines changed

api/src/membership/models.py

Lines changed: 97 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1+
from datetime import datetime
12
from logging import getLogger
2-
from typing import Any, Optional
3+
from typing import Any, List, Literal, Optional
34

45
import phonenumbers as phonenumbers
56
from basic_types.enums import PriceLevel
@@ -18,13 +19,23 @@
1819
func,
1920
select,
2021
)
21-
from sqlalchemy.orm import column_property, configure_mappers, declarative_base, relationship, validates
22+
from sqlalchemy.orm import (
23+
DeclarativeBase,
24+
Mapped,
25+
MappedAsDataclass,
26+
column_property,
27+
configure_mappers,
28+
mapped_column,
29+
relationship,
30+
validates,
31+
)
2232

23-
Base = declarative_base()
2433

34+
class Base(DeclarativeBase):
35+
pass
2536

26-
logger = getLogger("makeradmin")
2737

38+
logger = getLogger("makeradmin")
2839

2940
member_group = Table(
3041
"membership_members_groups",
@@ -37,41 +48,43 @@
3748
class Member(Base):
3849
__tablename__ = "membership_members"
3950

40-
member_id = Column(Integer, primary_key=True, nullable=False, autoincrement=True)
41-
email = Column(String(255), unique=True, nullable=False)
42-
password = Column(String(60))
43-
firstname = Column(String(255), nullable=False)
44-
lastname = Column(String(255))
45-
civicregno = Column(String(25))
46-
company = Column(String(255))
47-
orgno = Column(String(12))
48-
address_street = Column(String(255))
49-
address_extra = Column(String(255))
50-
address_zipcode = Column(Integer)
51-
address_city = Column(String(255))
52-
address_country = Column(String(2))
53-
phone = Column(String(255))
54-
created_at = Column(DateTime, server_default=func.now())
55-
updated_at = Column(DateTime, server_default=func.now())
56-
deleted_at = Column(DateTime)
51+
member_id: Mapped[int] = mapped_column(Integer, primary_key=True, nullable=False, autoincrement=True)
52+
email: Mapped[str] = mapped_column(String(255), unique=True, nullable=False)
53+
password: Mapped[Optional[str]]
54+
firstname: Mapped[str]
55+
lastname: Mapped[Optional[str]]
56+
civicregno: Mapped[Optional[str]]
57+
company: Mapped[Optional[str]]
58+
orgno: Mapped[Optional[str]]
59+
address_street: Mapped[Optional[str]]
60+
address_extra: Mapped[Optional[str]]
61+
address_zipcode: Mapped[Optional[int]]
62+
address_city: Mapped[Optional[str]]
63+
address_country: Mapped[Optional[str]]
64+
phone: Mapped[Optional[str]]
65+
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
66+
updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
67+
deleted_at: Mapped[Optional[datetime]]
5768

5869
# True during the registration flow as the payment is being processed
59-
pending_activation = Column(Boolean, nullable=False)
70+
pending_activation: Mapped[bool]
6071

61-
member_number = Column(Integer, unique=True)
62-
labaccess_agreement_at = Column(DateTime)
63-
pin_code = Column(String(30))
64-
stripe_customer_id = Column(String(64))
65-
stripe_membership_subscription_id = Column(String(64))
66-
stripe_labaccess_subscription_id = Column(String(64))
67-
price_level = Column(Enum(*[x.value for x in PriceLevel]), nullable=False)
68-
price_level_motivation = Column(String)
72+
member_number: Mapped[int] = mapped_column(Integer, unique=True, nullable=False)
73+
labaccess_agreement_at: Mapped[Optional[datetime]]
74+
pin_code: Mapped[Optional[str]]
75+
stripe_customer_id: Mapped[Optional[str]]
76+
stripe_membership_subscription_id: Mapped[Optional[str]]
77+
stripe_labaccess_subscription_id: Mapped[Optional[str]]
78+
price_level: Mapped[str] = mapped_column(Enum(*[x.value for x in PriceLevel]), nullable=False)
79+
price_level_motivation: Mapped[Optional[str]]
6980

7081
@validates("phone")
7182
def validate_phone(self, key: Any, value: Optional[str]) -> Optional[str]:
7283
return normalise_phone_number(value)
7384

74-
groups = relationship("Group", secondary=member_group, back_populates="members", cascade_backrefs=False)
85+
groups: Mapped[List["Group"]] = relationship(
86+
"Group", secondary=member_group, back_populates="members", cascade_backrefs=False
87+
)
7588

7689
def __repr__(self) -> str:
7790
return f"Member(member_id={self.member_id}, member_number={self.member_number}, email={self.email})"
@@ -89,19 +102,19 @@ def __repr__(self) -> str:
89102
class Group(Base):
90103
__tablename__ = "membership_groups"
91104

92-
group_id = Column(Integer, primary_key=True, nullable=False, autoincrement=True)
93-
name = Column(String(255), nullable=False)
94-
title = Column(String(255), nullable=False)
95-
description = Column(Text)
96-
created_at = Column(DateTime, server_default=func.now())
97-
updated_at = Column(DateTime, server_default=func.now())
98-
deleted_at = Column(DateTime)
105+
group_id: Mapped[int] = mapped_column(Integer, primary_key=True, nullable=False, autoincrement=True)
106+
name: Mapped[str]
107+
title: Mapped[str]
108+
description: Mapped[Optional[str]]
109+
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
110+
updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
111+
deleted_at: Mapped[Optional[datetime]]
99112

100-
members = relationship(
113+
members: Mapped[List[Member]] = relationship(
101114
"Member", secondary=member_group, lazy="dynamic", back_populates="groups", cascade_backrefs=False
102115
)
103116

104-
permissions = relationship(
117+
permissions: Mapped[List["Permission"]] = relationship(
105118
"Permission", secondary=group_permission, back_populates="groups", cascade_backrefs=False
106119
)
107120

@@ -121,27 +134,29 @@ def __repr__(self) -> str:
121134
class Permission(Base):
122135
__tablename__ = "membership_permissions"
123136

124-
permission_id = Column(Integer, primary_key=True, nullable=False, autoincrement=True)
125-
permission = Column(String(255), nullable=False, unique=True)
126-
created_at = Column(DateTime, server_default=func.now())
127-
updated_at = Column(DateTime, server_default=func.now())
128-
deleted_at = Column(DateTime)
137+
permission_id: Mapped[int] = mapped_column(Integer, primary_key=True, nullable=False, autoincrement=True)
138+
permission: Mapped[str] = mapped_column(String(255), nullable=False, unique=True)
139+
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
140+
updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
141+
deleted_at: Mapped[Optional[datetime]]
129142

130-
groups = relationship("Group", secondary=group_permission, back_populates="permissions", cascade_backrefs=False)
143+
groups: Mapped[List[Group]] = relationship(
144+
"Group", secondary=group_permission, back_populates="permissions", cascade_backrefs=False
145+
)
131146

132147

133148
class Key(Base):
134149
__tablename__ = "membership_keys"
135150

136-
key_id = Column(Integer, primary_key=True, nullable=False, autoincrement=True)
137-
member_id = Column(Integer, ForeignKey("membership_members.member_id"), nullable=False)
138-
description = Column(Text)
139-
tagid = Column(String(255), nullable=False, unique=True)
140-
created_at = Column(DateTime, server_default=func.now())
141-
updated_at = Column(DateTime, server_default=func.now())
142-
deleted_at = Column(DateTime)
151+
key_id: Mapped[int] = mapped_column(Integer, primary_key=True, nullable=False, autoincrement=True)
152+
member_id: Mapped[int] = mapped_column(Integer, ForeignKey("membership_members.member_id"), nullable=False)
153+
description: Mapped[Optional[str]] = mapped_column(Text)
154+
tagid: Mapped[str] = mapped_column(String(255), nullable=False, unique=True)
155+
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
156+
updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
157+
deleted_at: Mapped[Optional[datetime]]
143158

144-
member = relationship(Member, backref="keys", cascade_backrefs=False)
159+
member: Mapped[Member] = relationship(Member, backref="keys", cascade_backrefs=False)
145160

146161
def __repr__(self) -> str:
147162
return f"Key(key_id={self.key_id}, tagid={self.tagid})"
@@ -150,21 +165,22 @@ def __repr__(self) -> str:
150165
class Span(Base):
151166
__tablename__ = "membership_spans"
152167

153-
LABACCESS = "labaccess"
154-
MEMBERSHIP = "membership"
155-
SPECIAL_LABACESS = "special_labaccess"
168+
ACCESS_TYPE = Literal["labaccess", "membership", "special_labaccess"]
169+
LABACCESS: ACCESS_TYPE = "labaccess"
170+
MEMBERSHIP: ACCESS_TYPE = "membership"
171+
SPECIAL_LABACESS: ACCESS_TYPE = "special_labaccess"
156172

157-
span_id = Column(Integer, primary_key=True, nullable=False, autoincrement=True)
158-
member_id = Column(Integer, ForeignKey("membership_members.member_id"), nullable=False)
159-
startdate = Column(Date, nullable=False) # Start date, inclusive
160-
enddate = Column(Date, nullable=False) # End date, inclusive
161-
type = Column(Enum(LABACCESS, MEMBERSHIP, SPECIAL_LABACESS), nullable=False)
162-
creation_reason = Column(String(255), unique=True)
163-
created_at = Column(DateTime, server_default=func.now())
164-
deleted_at = Column(DateTime)
165-
deletion_reason = Column(String(255))
173+
span_id: Mapped[int] = mapped_column(Integer, primary_key=True, nullable=False, autoincrement=True)
174+
member_id: Mapped[int] = mapped_column(Integer, ForeignKey("membership_members.member_id"), nullable=False)
175+
startdate: Mapped[Date] = mapped_column(Date, nullable=False) # Start date, inclusive
176+
enddate: Mapped[Date] = mapped_column(Date, nullable=False) # End date, inclusive
177+
type: Mapped[ACCESS_TYPE] = mapped_column(Enum(LABACCESS, MEMBERSHIP, SPECIAL_LABACESS), nullable=False)
178+
creation_reason: Mapped[str] = mapped_column(String(255), unique=True)
179+
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
180+
deleted_at: Mapped[Optional[datetime]]
181+
deletion_reason: Mapped[Optional[str]]
166182

167-
member = relationship(Member, backref="spans", cascade_backrefs=False)
183+
member: Mapped[Member] = relationship(Member, backref="spans", cascade_backrefs=False)
168184

169185
def __repr__(self) -> str:
170186
return f"Span(span_id={self.span_id}, type={self.type}, enddate={self.enddate})"
@@ -173,24 +189,24 @@ def __repr__(self) -> str:
173189
class Box(Base):
174190
__tablename__ = "membership_box"
175191

176-
id = Column(Integer, primary_key=True, nullable=False, autoincrement=True)
192+
id: Mapped[int] = mapped_column(Integer, primary_key=True, nullable=False, autoincrement=True)
177193

178-
member_id = Column(Integer, ForeignKey("membership_members.member_id"), nullable=False)
194+
member_id: Mapped[int] = mapped_column(Integer, ForeignKey("membership_members.member_id"), nullable=False)
179195

180196
# The id of the printed label on the box.
181-
box_label_id = Column(BigInteger, unique=True, nullable=False)
197+
box_label_id: Mapped[int] = mapped_column(BigInteger, unique=True, nullable=False)
182198

183199
# Scanning session to be able to make list of all scanned boxes during the session.
184-
session_token = Column(String(32), index=True, nullable=False)
200+
session_token: Mapped[str] = mapped_column(String(32), index=True, nullable=False)
185201

186202
# Box last checked at timestamp.
187-
last_check_at = Column(DateTime, nullable=True)
203+
last_check_at: Mapped[Optional[datetime]]
188204

189205
# Last time a nag mail was sent out for this box, note that for a member with several boxes this may not be the
190206
# last nag date for that member.
191-
last_nag_at = Column(DateTime, nullable=False)
207+
last_nag_at: Mapped[Optional[datetime]]
192208

193-
member = relationship(Member, backref="boxes", cascade_backrefs=False)
209+
member: Mapped[Member] = relationship(Member, backref="boxes", cascade_backrefs=False)
194210

195211
def __repr__(self) -> str:
196212
return (
@@ -202,20 +218,20 @@ def __repr__(self) -> str:
202218
class PhoneNumberChangeRequest(Base):
203219
__tablename__ = "change_phone_number_requests"
204220

205-
id = Column(Integer, primary_key=True, nullable=False, autoincrement=True)
206-
member_id = Column(Integer, ForeignKey("membership_members.member_id"), nullable=True)
207-
phone = Column(String(255), nullable=False)
221+
id: Mapped[int] = mapped_column(Integer, primary_key=True, nullable=False, autoincrement=True)
222+
member_id: Mapped[int] = mapped_column(Integer, ForeignKey("membership_members.member_id"), nullable=True)
223+
phone: Mapped[str]
208224

209225
# Number used to compare if the reques is valid or not.
210-
validation_code = Column(Integer, nullable=False)
226+
validation_code: Mapped[int]
211227

212228
# If the request has been completed or not.
213-
completed = Column(Boolean, nullable=False)
229+
completed: Mapped[bool]
214230

215231
# When the request was made.
216-
timestamp = Column(DateTime, nullable=False)
232+
timestamp: Mapped[datetime]
217233

218-
member = relationship(Member, backref="change_phone_number_requests", cascade_backrefs=False)
234+
member: Mapped[Member] = relationship(Member, backref="change_phone_number_requests", cascade_backrefs=False)
219235

220236
@validates("phone")
221237
def validate_phone(self, key: Any, value: Optional[str]) -> Optional[str]:

0 commit comments

Comments
 (0)