Skip to content

Commit fef1bab

Browse files
author
Micah Denbraver
committed
use aioapns
1 parent 6b4d226 commit fef1bab

File tree

7 files changed

+438
-325
lines changed

7 files changed

+438
-325
lines changed

.pre-commit-config.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,5 @@ repos:
1212
rev: v3.15.2
1313
hooks:
1414
- id: pyupgrade
15+
args:
16+
- --keep-mock # for AsyncMock in 3.7

push_notifications/apns.py

Lines changed: 188 additions & 122 deletions
Original file line numberDiff line numberDiff line change
@@ -4,141 +4,207 @@
44
https://developer.apple.com/library/content/documentation/NetworkingInternet/Conceptual/RemoteNotificationsPG/APNSOverview.html
55
"""
66

7+
import asyncio
78
import contextlib
9+
import tempfile
810
import time
911

10-
from apns2 import client as apns2_client
11-
from apns2 import credentials as apns2_credentials
12-
from apns2 import errors as apns2_errors
13-
from apns2 import payload as apns2_payload
12+
import aioapns
13+
from aioapns.common import APNS_RESPONSE_CODE, PRIORITY_HIGH, PRIORITY_NORMAL
14+
from asgiref.sync import async_to_sync
1415

1516
from . import models
1617
from .conf import get_manager
17-
from .exceptions import APNSError, APNSUnsupportedPriority, APNSServerError
18+
from .exceptions import APNSError, APNSServerError, APNSUnsupportedPriority
19+
20+
21+
SUCCESS_RESULT = "Success"
22+
UNREGISTERED_RESULT = "Unregistered"
1823

1924

2025
@contextlib.contextmanager
21-
def _apns_create_socket(application_id=None):
22-
if not get_manager().has_auth_token_creds(application_id):
23-
cert = get_manager().get_apns_certificate(application_id)
24-
creds = apns2_credentials.CertificateCredentials(cert)
25-
else:
26-
keyPath, keyId, teamId = get_manager().get_apns_auth_creds(application_id)
27-
# No use getting a lifetime because this credential is
28-
# ephemeral, but if you're looking at this to see how to
29-
# create a credential, you could also pass the lifetime and
30-
# algorithm. Neither of those settings are exposed in the
31-
# settings API at the moment.
32-
creds = apns2_credentials.TokenCredentials(keyPath, keyId, teamId)
33-
client = apns2_client.APNsClient(
34-
creds,
35-
use_sandbox=get_manager().get_apns_use_sandbox(application_id),
36-
use_alternative_port=get_manager().get_apns_use_alternative_port(application_id)
37-
)
38-
client.connect()
39-
yield client
26+
def _apns_path_for_cert(cert):
27+
if cert is None:
28+
yield None
29+
with tempfile.NamedTemporaryFile("w") as cert_file:
30+
cert_file.write(cert)
31+
cert_file.flush()
32+
yield cert_file.name
33+
34+
35+
def _apns_create_client(application_id=None):
36+
cert = None
37+
key_path = None
38+
key_id = None
39+
team_id = None
40+
41+
if not get_manager().has_auth_token_creds(application_id):
42+
cert = get_manager().get_apns_certificate(application_id)
43+
else:
44+
key_path, key_id, team_id = get_manager().get_apns_auth_creds(application_id)
45+
# No use getting a lifetime because this credential is
46+
# ephemeral, but if you're looking at this to see how to
47+
# create a credential, you could also pass the lifetime and
48+
# algorithm. Neither of those settings are exposed in the
49+
# settings API at the moment.
50+
51+
with _apns_path_for_cert(cert) as cert_path:
52+
client = aioapns.APNs(
53+
client_cert=cert_path,
54+
key=key_path,
55+
key_id=key_id,
56+
team_id=team_id,
57+
use_sandbox=get_manager().get_apns_use_sandbox(application_id),
58+
)
59+
60+
return client
4061

4162

4263
def _apns_prepare(
43-
token, alert, application_id=None, badge=None, sound=None, category=None,
44-
content_available=False, action_loc_key=None, loc_key=None, loc_args=[],
45-
extra={}, mutable_content=False, thread_id=None, url_args=None):
46-
if action_loc_key or loc_key or loc_args:
47-
apns2_alert = apns2_payload.PayloadAlert(
48-
body=alert if alert else {}, body_localized_key=loc_key,
49-
body_localized_args=loc_args, action_localized_key=action_loc_key)
50-
else:
51-
apns2_alert = alert
52-
53-
if callable(badge):
54-
badge = badge(token)
55-
56-
return apns2_payload.Payload(
57-
alert=apns2_alert, badge=badge, sound=sound, category=category,
58-
url_args=url_args, custom=extra, thread_id=thread_id,
59-
content_available=content_available, mutable_content=mutable_content)
60-
61-
62-
def _apns_send(
63-
registration_id, alert, batch=False, application_id=None, **kwargs
64+
token,
65+
alert,
66+
application_id=None,
67+
badge=None,
68+
sound=None,
69+
category=None,
70+
content_available=False,
71+
action_loc_key=None,
72+
loc_key=None,
73+
loc_args=[],
74+
extra={},
75+
mutable_content=False,
76+
thread_id=None,
77+
url_args=None,
6478
):
65-
notification_kwargs = {}
66-
67-
# if expiration isn"t specified use 1 month from now
68-
notification_kwargs["expiration"] = kwargs.pop("expiration", None)
69-
if not notification_kwargs["expiration"]:
70-
notification_kwargs["expiration"] = int(time.time()) + 2592000
71-
72-
priority = kwargs.pop("priority", None)
73-
if priority:
74-
try:
75-
notification_kwargs["priority"] = apns2_client.NotificationPriority(str(priority))
76-
except ValueError:
77-
raise APNSUnsupportedPriority("Unsupported priority %d" % (priority))
78-
79-
notification_kwargs["collapse_id"] = kwargs.pop("collapse_id", None)
80-
81-
with _apns_create_socket(application_id=application_id) as client:
82-
if batch:
83-
data = [apns2_client.Notification(
84-
token=rid, payload=_apns_prepare(rid, alert, **kwargs)) for rid in registration_id]
85-
# returns a dictionary mapping each token to its result. That
86-
# result is either "Success" or the reason for the failure.
87-
return client.send_notification_batch(
88-
data, get_manager().get_apns_topic(application_id=application_id),
89-
**notification_kwargs
90-
)
91-
92-
data = _apns_prepare(registration_id, alert, **kwargs)
93-
client.send_notification(
94-
registration_id, data,
95-
get_manager().get_apns_topic(application_id=application_id),
96-
**notification_kwargs
97-
)
79+
if action_loc_key or loc_key or loc_args:
80+
alert_payload = {
81+
"body": alert if alert else {},
82+
"body_localized_key": loc_key,
83+
"body_localized_args": loc_args,
84+
"action_localized_key": action_loc_key,
85+
}
86+
else:
87+
alert_payload = alert
88+
89+
if callable(badge):
90+
badge = badge(token)
91+
92+
return {
93+
"alert": alert_payload,
94+
"badge": badge,
95+
"sound": sound,
96+
"category": category,
97+
"url_args": url_args,
98+
"custom": extra,
99+
"thread_id": thread_id,
100+
"content_available": content_available,
101+
"mutable_content": mutable_content,
102+
}
103+
104+
105+
@async_to_sync
106+
async def _apns_send(
107+
registration_ids,
108+
alert,
109+
application_id=None,
110+
*,
111+
priority=None,
112+
expiration=None,
113+
collapse_id=None,
114+
**kwargs,
115+
):
116+
"""Make calls to APNs for each device token (registration_id) provided.
117+
118+
Since the underlying library (aioapns) is asynchronous, we are
119+
taking advantage of that here and making the requests in parallel.
120+
"""
121+
client = _apns_create_client(application_id=application_id)
122+
123+
# if expiration isn't specified use 1 month from now
124+
# converting to ttl for underlying library
125+
if expiration:
126+
time_to_live = expiration - int(time.time())
127+
else:
128+
time_to_live = 2592000
129+
130+
if priority is not None:
131+
if str(priority) not in [PRIORITY_HIGH, PRIORITY_NORMAL]:
132+
raise APNSUnsupportedPriority(f"Unsupported priority {priority}")
133+
134+
# track which device token belongs to each coroutine.
135+
# this allows us to stitch the results back together later
136+
coro_registration_ids = {}
137+
for registration_id in set(registration_ids):
138+
coro = client.send_notification(
139+
aioapns.NotificationRequest(
140+
device_token=registration_id,
141+
message={"aps": _apns_prepare(registration_id, alert, **kwargs)},
142+
time_to_live=time_to_live,
143+
priority=priority,
144+
collapse_key=collapse_id,
145+
)
146+
)
147+
coro_registration_ids[asyncio.create_task(coro)] = registration_id
148+
149+
# run all of the tasks. this will resolve once all requests are complete
150+
done, _ = await asyncio.wait(coro_registration_ids.keys())
151+
152+
# recombine task results with their device tokens
153+
results = {}
154+
for coro in done:
155+
registration_id = coro_registration_ids[coro]
156+
result = await coro
157+
if result.is_successful:
158+
results[registration_id] = SUCCESS_RESULT
159+
else:
160+
results[registration_id] = result.description
161+
162+
return results
98163

99164

100165
def apns_send_message(registration_id, alert, application_id=None, **kwargs):
101-
"""
102-
Sends an APNS notification to a single registration_id.
103-
This will send the notification as form data.
104-
If sending multiple notifications, it is more efficient to use
105-
apns_send_bulk_message()
106-
107-
Note that if set alert should always be a string. If it is not set,
108-
it won"t be included in the notification. You will need to pass None
109-
to this for silent notifications.
110-
"""
111-
112-
try:
113-
_apns_send(
114-
registration_id, alert, application_id=application_id,
115-
**kwargs
116-
)
117-
except apns2_errors.APNsException as apns2_exception:
118-
if isinstance(apns2_exception, apns2_errors.Unregistered):
119-
device = models.APNSDevice.objects.get(registration_id=registration_id)
120-
device.active = False
121-
device.save()
122-
123-
raise APNSServerError(status=apns2_exception.__class__.__name__)
124-
125-
126-
def apns_send_bulk_message(
127-
registration_ids, alert, application_id=None, **kwargs
128-
):
129-
"""
130-
Sends an APNS notification to one or more registration_ids.
131-
The registration_ids argument needs to be a list.
132-
133-
Note that if set alert should always be a string. If it is not set,
134-
it won"t be included in the notification. You will need to pass None
135-
to this for silent notifications.
136-
"""
137-
138-
results = _apns_send(
139-
registration_ids, alert, batch=True, application_id=application_id,
140-
**kwargs
141-
)
142-
inactive_tokens = [token for token, result in results.items() if result == "Unregistered"]
143-
models.APNSDevice.objects.filter(registration_id__in=inactive_tokens).update(active=False)
144-
return results
166+
"""
167+
Sends an APNS notification to a single registration_id.
168+
This will send the notification as form data.
169+
If sending multiple notifications, it is more efficient to use
170+
apns_send_bulk_message()
171+
172+
Note that if set alert should always be a string. If it is not set,
173+
it won"t be included in the notification. You will need to pass None
174+
to this for silent notifications.
175+
"""
176+
177+
results = _apns_send(
178+
[registration_id], alert, application_id=application_id, **kwargs
179+
)
180+
result = results[registration_id]
181+
182+
if result == SUCCESS_RESULT:
183+
return
184+
if result == UNREGISTERED_RESULT:
185+
models.APNSDevice.objects.filter(registration_id=registration_id).update(
186+
active=False
187+
)
188+
raise APNSServerError(status=result)
189+
190+
191+
def apns_send_bulk_message(registration_ids, alert, application_id=None, **kwargs):
192+
"""
193+
Sends an APNS notification to one or more registration_ids.
194+
The registration_ids argument needs to be a list.
195+
196+
Note that if set alert should always be a string. If it is not set,
197+
it won"t be included in the notification. You will need to pass None
198+
to this for silent notifications.
199+
"""
200+
201+
results = _apns_send(
202+
registration_ids, alert, application_id=application_id, **kwargs
203+
)
204+
inactive_tokens = [
205+
token for token, result in results.items() if result == UNREGISTERED_RESULT
206+
]
207+
models.APNSDevice.objects.filter(registration_id__in=inactive_tokens).update(
208+
active=False
209+
)
210+
return results

push_notifications/models.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from django.db import models
22
from django.utils.translation import gettext_lazy as _
33

4+
from .apns import apns_send_bulk_message
45
from .fields import HexIntegerField
56
from .settings import PUSH_NOTIFICATIONS_SETTINGS as SETTINGS
67

setup.cfg

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,9 @@ setup_requires =
3535

3636
[options.extras_require]
3737
APNS =
38-
apns2>=0.3.0
38+
aioapns
39+
asgiref>=2.0
3940
importlib-metadata;python_version < "3.8"
40-
Django>=2.2
4141

4242
WP = pywebpush>=1.3.0
4343

0 commit comments

Comments
 (0)