Skip to content

Commit f67dff1

Browse files
committed
feat(cache): implement async locking for cache updates to prevent race conditions
- Enhanced the `GuildConfigCacheManager` and `JailStatusCache` classes with async locking mechanisms (`asyncio.Lock`) to ensure coroutine safety during concurrent updates on the event loop. - Introduced `async_set` methods for both classes to handle cache updates safely when called from async code. - Added detailed documentation for new methods, emphasizing their atomicity and safety in concurrent environments. - Improved the overall cache management strategy to prevent stampedes and ensure data integrity across multiple coroutines.
1 parent 7c3d40f commit f67dff1

File tree

1 file changed

+176
-3
lines changed

1 file changed

+176
-3
lines changed

src/tux/shared/cache.py

Lines changed: 176 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@
77

88
from __future__ import annotations
99

10+
import asyncio
1011
import time
11-
from collections.abc import Callable
12+
from collections.abc import Callable, Coroutine
1213
from typing import Any
1314

1415
from loguru import logger
@@ -67,6 +68,13 @@ def get(self, key: Any) -> Any | None:
6768
-------
6869
Any | None
6970
The cached value, or None if not found or expired.
71+
72+
Notes
73+
-----
74+
This method is safe for concurrent async access. In Python's async model
75+
(single-threaded event loop), dict operations are atomic. The check-then-act
76+
pattern here has no await points, so no other coroutine can run between
77+
the check and the access, making it race-condition safe.
7078
"""
7179
if key not in self._cache:
7280
return None
@@ -181,17 +189,25 @@ class GuildConfigCacheManager:
181189
when config is updated from multiple sources.
182190
"""
183191

184-
__slots__ = ("_cache",)
192+
__slots__ = ("_cache", "_locks")
185193
_instance: GuildConfigCacheManager | None = None
186194
_cache: TTLCache
195+
_locks: dict[int, asyncio.Lock]
187196

188197
def __new__(cls) -> GuildConfigCacheManager:
189198
"""Create or return the singleton instance."""
190199
if cls._instance is None:
191200
cls._instance = super().__new__(cls)
192201
cls._instance._cache = TTLCache(ttl=300.0, max_size=1000)
202+
cls._instance._locks = {} # Per-guild locks for cache updates
193203
return cls._instance
194204

205+
def _get_lock(self, guild_id: int) -> asyncio.Lock:
206+
"""Get or create a lock for a specific guild."""
207+
if guild_id not in self._locks:
208+
self._locks[guild_id] = asyncio.Lock()
209+
return self._locks[guild_id]
210+
195211
def get(self, guild_id: int) -> dict[str, int | None] | None:
196212
"""
197213
Get cached guild config for a guild.
@@ -221,6 +237,9 @@ def set(
221237
"""
222238
Cache guild config for a guild.
223239
240+
This method is safe for concurrent access. Multiple coroutines can
241+
call this with partial updates without losing data.
242+
224243
Parameters
225244
----------
226245
guild_id : int
@@ -233,6 +252,12 @@ def set(
233252
The jail role ID. Omit to skip updating this field.
234253
jail_channel_id : int | None, optional
235254
The jail channel ID. Omit to skip updating this field.
255+
256+
Notes
257+
-----
258+
This method is synchronous and has no await points, making it atomic
259+
in Python's async model. However, when called from async code that
260+
has await points between cache check and set, use async_set() instead.
236261
"""
237262
cache_key = f"guild_config_{guild_id}"
238263
# Get existing cache or create new dict
@@ -252,6 +277,53 @@ def set(
252277

253278
self._cache.set(cache_key, updated)
254279

280+
async def async_set(
281+
self,
282+
guild_id: int,
283+
audit_log_id: int | None = _MISSING,
284+
mod_log_id: int | None = _MISSING,
285+
jail_role_id: int | None = _MISSING,
286+
jail_channel_id: int | None = _MISSING,
287+
) -> None:
288+
"""
289+
Cache guild config for a guild with async locking.
290+
291+
Use this method when called from async code that has await points
292+
between cache check and set, to prevent race conditions with concurrent
293+
partial updates.
294+
295+
Parameters
296+
----------
297+
guild_id : int
298+
The guild ID.
299+
audit_log_id : int | None, optional
300+
The audit log channel ID. Omit to skip updating this field.
301+
mod_log_id : int | None, optional
302+
The mod log channel ID. Omit to skip updating this field.
303+
jail_role_id : int | None, optional
304+
The jail role ID. Omit to skip updating this field.
305+
jail_channel_id : int | None, optional
306+
The jail channel ID. Omit to skip updating this field.
307+
"""
308+
lock = self._get_lock(guild_id)
309+
async with lock:
310+
# Re-check cache after acquiring lock (another coroutine may have updated it)
311+
cache_key = f"guild_config_{guild_id}"
312+
existing: dict[str, int | None] = self._cache.get(cache_key) or {}
313+
updated: dict[str, int | None] = dict(existing)
314+
315+
# Only update fields that were explicitly provided (not _MISSING)
316+
if audit_log_id is not _MISSING:
317+
updated["audit_log_id"] = audit_log_id
318+
if mod_log_id is not _MISSING:
319+
updated["mod_log_id"] = mod_log_id
320+
if jail_role_id is not _MISSING:
321+
updated["jail_role_id"] = jail_role_id
322+
if jail_channel_id is not _MISSING:
323+
updated["jail_channel_id"] = jail_channel_id
324+
325+
self._cache.set(cache_key, updated)
326+
255327
def invalidate(self, guild_id: int) -> None:
256328
"""
257329
Invalidate cached guild config for a guild.
@@ -279,18 +351,83 @@ class JailStatusCache:
279351
tuple to reduce database queries for frequently checked jail status.
280352
"""
281353

282-
__slots__ = ("_cache",)
354+
__slots__ = ("_cache", "_locks", "_locks_lock")
283355
_instance: JailStatusCache | None = None
284356
_cache: TTLCache
357+
_locks: dict[tuple[int, int], asyncio.Lock]
358+
_locks_lock: asyncio.Lock
285359

286360
def __new__(cls) -> JailStatusCache:
287361
"""Create or return the singleton instance."""
288362
if cls._instance is None:
289363
cls._instance = super().__new__(cls)
290364
# 60 second TTL - jail status changes infrequently
291365
cls._instance._cache = TTLCache(ttl=60.0, max_size=5000)
366+
cls._instance._locks = {} # Per (guild_id, user_id) locks for cache updates
367+
cls._instance._locks_lock = (
368+
asyncio.Lock()
369+
) # Lock for managing the locks dict
292370
return cls._instance
293371

372+
def _get_lock_key(self, guild_id: int, user_id: int) -> tuple[int, int]:
373+
"""Generate lock key for guild_id and user_id."""
374+
return (guild_id, user_id)
375+
376+
async def _get_lock(self, guild_id: int, user_id: int) -> asyncio.Lock:
377+
"""Get or create a lock for a specific (guild_id, user_id) pair."""
378+
lock_key = self._get_lock_key(guild_id, user_id)
379+
async with self._locks_lock:
380+
if lock_key not in self._locks:
381+
self._locks[lock_key] = asyncio.Lock()
382+
return self._locks[lock_key]
383+
384+
async def get_or_fetch(
385+
self,
386+
guild_id: int,
387+
user_id: int,
388+
fetch_func: Callable[[], Coroutine[Any, Any, bool]],
389+
) -> bool:
390+
"""
391+
Get cached value or fetch and cache with async locking.
392+
393+
Prevents cache stampede when multiple coroutines miss the cache
394+
simultaneously. Only one coroutine will fetch, others will wait
395+
and use the cached result.
396+
397+
Parameters
398+
----------
399+
guild_id : int
400+
The guild ID.
401+
user_id : int
402+
The user ID.
403+
fetch_func : Callable[[], Coroutine[Any, Any, bool]]
404+
Async function to fetch the value if not cached.
405+
406+
Returns
407+
-------
408+
bool
409+
The cached or fetched jail status.
410+
"""
411+
# Check cache first (fast path)
412+
cached_status = self.get(guild_id, user_id)
413+
if cached_status is not None:
414+
return cached_status
415+
416+
# Cache miss - acquire lock to prevent stampede
417+
lock = await self._get_lock(guild_id, user_id)
418+
async with lock:
419+
# Re-check cache after acquiring lock
420+
cached_status = self.get(guild_id, user_id)
421+
if cached_status is not None:
422+
return cached_status
423+
424+
# Still a cache miss - fetch from database
425+
is_jailed = await fetch_func()
426+
427+
# Cache the result (atomic operation, no await points)
428+
self.set(guild_id, user_id, is_jailed)
429+
return is_jailed
430+
294431
def _get_key(self, guild_id: int, user_id: int) -> str:
295432
"""Generate cache key for guild_id and user_id."""
296433
return f"jail_status_{guild_id}_{user_id}"
@@ -318,6 +455,9 @@ def set(self, guild_id: int, user_id: int, is_jailed: bool) -> None:
318455
"""
319456
Cache jail status for a user.
320457
458+
This method is safe for concurrent access. Multiple coroutines can
459+
call this without losing data.
460+
321461
Parameters
322462
----------
323463
guild_id : int
@@ -326,10 +466,43 @@ def set(self, guild_id: int, user_id: int, is_jailed: bool) -> None:
326466
The user ID.
327467
is_jailed : bool
328468
Whether the user is jailed.
469+
470+
Notes
471+
-----
472+
This method is synchronous and has no await points, making it atomic
473+
in Python's async model. However, when called from async code that
474+
has await points between cache check and set, use async_set() instead.
329475
"""
330476
cache_key = self._get_key(guild_id, user_id)
331477
self._cache.set(cache_key, is_jailed)
332478

479+
async def async_set(self, guild_id: int, user_id: int, is_jailed: bool) -> None:
    """
    Cache jail status for a user with async locking.

    Waits for the per-(guild_id, user_id) lock, so the write is serialized
    with any other coroutine holding that lock (for example an in-flight
    ``get_or_fetch`` for the same user), then stores ``is_jailed``,
    replacing whatever value is currently cached.

    Parameters
    ----------
    guild_id : int
        The guild ID.
    user_id : int
        The user ID.
    is_jailed : bool
        Whether the user is jailed.

    Notes
    -----
    The value is always written. Skipping the write when an entry already
    exists would silently drop a changed status (jailed -> unjailed or the
    reverse) and leave the stale value cached until its TTL expires.
    """
    lock = await self._get_lock(guild_id, user_id)
    async with lock:
        # The write itself has no await points; the lock only orders it
        # relative to other coroutines working on the same user.
        self._cache.set(self._get_key(guild_id, user_id), is_jailed)
505+
333506
def invalidate(self, guild_id: int, user_id: int) -> None:
334507
"""
335508
Invalidate cached jail status for a user.

0 commit comments

Comments
 (0)