77
88from __future__ import annotations
99
10+ import asyncio
1011import time
11- from collections .abc import Callable
12+ from collections .abc import Callable , Coroutine
1213from typing import Any
1314
1415from loguru import logger
@@ -67,6 +68,13 @@ def get(self, key: Any) -> Any | None:
6768 -------
6869 Any | None
6970 The cached value, or None if not found or expired.
71+
72+ Notes
73+ -----
74+ This method is safe for concurrent async access. In Python's async model
75+ (single-threaded event loop), dict operations are atomic. The check-then-act
76+ pattern here has no await points, so no other coroutine can run between
77+ the check and the access, making it race-condition safe.
7078 """
7179 if key not in self ._cache :
7280 return None
@@ -181,17 +189,25 @@ class GuildConfigCacheManager:
181189 when config is updated from multiple sources.
182190 """
183191
184- __slots__ = ("_cache" ,)
192+ __slots__ = ("_cache" , "_locks" )
185193 _instance : GuildConfigCacheManager | None = None
186194 _cache : TTLCache
195+ _locks : dict [int , asyncio .Lock ]
187196
188197 def __new__ (cls ) -> GuildConfigCacheManager :
189198 """Create or return the singleton instance."""
190199 if cls ._instance is None :
191200 cls ._instance = super ().__new__ (cls )
192201 cls ._instance ._cache = TTLCache (ttl = 300.0 , max_size = 1000 )
202+ cls ._instance ._locks = {} # Per-guild locks for cache updates
193203 return cls ._instance
194204
205+ def _get_lock (self , guild_id : int ) -> asyncio .Lock :
206+ """Get or create a lock for a specific guild."""
207+ if guild_id not in self ._locks :
208+ self ._locks [guild_id ] = asyncio .Lock ()
209+ return self ._locks [guild_id ]
210+
195211 def get (self , guild_id : int ) -> dict [str , int | None ] | None :
196212 """
197213 Get cached guild config for a guild.
@@ -221,6 +237,9 @@ def set(
221237 """
222238 Cache guild config for a guild.
223239
240+ This method is safe for concurrent access. Multiple coroutines can
241+ call this with partial updates without losing data.
242+
224243 Parameters
225244 ----------
226245 guild_id : int
@@ -233,6 +252,12 @@ def set(
233252 The jail role ID. Omit to skip updating this field.
234253 jail_channel_id : int | None, optional
235254 The jail channel ID. Omit to skip updating this field.
255+
256+ Notes
257+ -----
258+ This method is synchronous and has no await points, making it atomic
259+ in Python's async model. However, when called from async code that
260+ has await points between cache check and set, use async_set() instead.
236261 """
237262 cache_key = f"guild_config_{ guild_id } "
238263 # Get existing cache or create new dict
@@ -252,6 +277,53 @@ def set(
252277
253278 self ._cache .set (cache_key , updated )
254279
280+ async def async_set (
281+ self ,
282+ guild_id : int ,
283+ audit_log_id : int | None = _MISSING ,
284+ mod_log_id : int | None = _MISSING ,
285+ jail_role_id : int | None = _MISSING ,
286+ jail_channel_id : int | None = _MISSING ,
287+ ) -> None :
288+ """
289+ Cache guild config for a guild with async locking.
290+
291+ Use this method when called from async code that has await points
292+ between cache check and set, to prevent race conditions with concurrent
293+ partial updates.
294+
295+ Parameters
296+ ----------
297+ guild_id : int
298+ The guild ID.
299+ audit_log_id : int | None, optional
300+ The audit log channel ID. Omit to skip updating this field.
301+ mod_log_id : int | None, optional
302+ The mod log channel ID. Omit to skip updating this field.
303+ jail_role_id : int | None, optional
304+ The jail role ID. Omit to skip updating this field.
305+ jail_channel_id : int | None, optional
306+ The jail channel ID. Omit to skip updating this field.
307+ """
308+ lock = self ._get_lock (guild_id )
309+ async with lock :
310+ # Re-check cache after acquiring lock (another coroutine may have updated it)
311+ cache_key = f"guild_config_{ guild_id } "
312+ existing : dict [str , int | None ] = self ._cache .get (cache_key ) or {}
313+ updated : dict [str , int | None ] = dict (existing )
314+
315+ # Only update fields that were explicitly provided (not _MISSING)
316+ if audit_log_id is not _MISSING :
317+ updated ["audit_log_id" ] = audit_log_id
318+ if mod_log_id is not _MISSING :
319+ updated ["mod_log_id" ] = mod_log_id
320+ if jail_role_id is not _MISSING :
321+ updated ["jail_role_id" ] = jail_role_id
322+ if jail_channel_id is not _MISSING :
323+ updated ["jail_channel_id" ] = jail_channel_id
324+
325+ self ._cache .set (cache_key , updated )
326+
255327 def invalidate (self , guild_id : int ) -> None :
256328 """
257329 Invalidate cached guild config for a guild.
@@ -279,18 +351,83 @@ class JailStatusCache:
279351 tuple to reduce database queries for frequently checked jail status.
280352 """
281353
282- __slots__ = ("_cache" ,)
354+ __slots__ = ("_cache" , "_locks" , "_locks_lock" )
283355 _instance : JailStatusCache | None = None
284356 _cache : TTLCache
357+ _locks : dict [tuple [int , int ], asyncio .Lock ]
358+ _locks_lock : asyncio .Lock
285359
286360 def __new__ (cls ) -> JailStatusCache :
287361 """Create or return the singleton instance."""
288362 if cls ._instance is None :
289363 cls ._instance = super ().__new__ (cls )
290364 # 60 second TTL - jail status changes infrequently
291365 cls ._instance ._cache = TTLCache (ttl = 60.0 , max_size = 5000 )
366+ cls ._instance ._locks = {} # Per (guild_id, user_id) locks for cache updates
367+ cls ._instance ._locks_lock = (
368+ asyncio .Lock ()
369+ ) # Lock for managing the locks dict
292370 return cls ._instance
293371
372+ def _get_lock_key (self , guild_id : int , user_id : int ) -> tuple [int , int ]:
373+ """Generate lock key for guild_id and user_id."""
374+ return (guild_id , user_id )
375+
376+ async def _get_lock (self , guild_id : int , user_id : int ) -> asyncio .Lock :
377+ """Get or create a lock for a specific (guild_id, user_id) pair."""
378+ lock_key = self ._get_lock_key (guild_id , user_id )
379+ async with self ._locks_lock :
380+ if lock_key not in self ._locks :
381+ self ._locks [lock_key ] = asyncio .Lock ()
382+ return self ._locks [lock_key ]
383+
384+ async def get_or_fetch (
385+ self ,
386+ guild_id : int ,
387+ user_id : int ,
388+ fetch_func : Callable [[], Coroutine [Any , Any , bool ]],
389+ ) -> bool :
390+ """
391+ Get cached value or fetch and cache with async locking.
392+
393+ Prevents cache stampede when multiple coroutines miss the cache
394+ simultaneously. Only one coroutine will fetch, others will wait
395+ and use the cached result.
396+
397+ Parameters
398+ ----------
399+ guild_id : int
400+ The guild ID.
401+ user_id : int
402+ The user ID.
403+ fetch_func : Callable[[], Coroutine[Any, Any, bool]]
404+ Async function to fetch the value if not cached.
405+
406+ Returns
407+ -------
408+ bool
409+ The cached or fetched jail status.
410+ """
411+ # Check cache first (fast path)
412+ cached_status = self .get (guild_id , user_id )
413+ if cached_status is not None :
414+ return cached_status
415+
416+ # Cache miss - acquire lock to prevent stampede
417+ lock = await self ._get_lock (guild_id , user_id )
418+ async with lock :
419+ # Re-check cache after acquiring lock
420+ cached_status = self .get (guild_id , user_id )
421+ if cached_status is not None :
422+ return cached_status
423+
424+ # Still a cache miss - fetch from database
425+ is_jailed = await fetch_func ()
426+
427+ # Cache the result (atomic operation, no await points)
428+ self .set (guild_id , user_id , is_jailed )
429+ return is_jailed
430+
294431 def _get_key (self , guild_id : int , user_id : int ) -> str :
295432 """Generate cache key for guild_id and user_id."""
296433 return f"jail_status_{ guild_id } _{ user_id } "
@@ -318,6 +455,9 @@ def set(self, guild_id: int, user_id: int, is_jailed: bool) -> None:
318455 """
319456 Cache jail status for a user.
320457
458+ This method is safe for concurrent access. Multiple coroutines can
459+ call this without losing data.
460+
321461 Parameters
322462 ----------
323463 guild_id : int
@@ -326,10 +466,43 @@ def set(self, guild_id: int, user_id: int, is_jailed: bool) -> None:
326466 The user ID.
327467 is_jailed : bool
328468 Whether the user is jailed.
469+
470+ Notes
471+ -----
472+ This method is synchronous and has no await points, making it atomic
473+ in Python's async model. However, when called from async code that
474+ has await points between cache check and set, use async_set() instead.
329475 """
330476 cache_key = self ._get_key (guild_id , user_id )
331477 self ._cache .set (cache_key , is_jailed )
332478
479+ async def async_set (self , guild_id : int , user_id : int , is_jailed : bool ) -> None :
480+ """
481+ Cache jail status for a user with async locking.
482+
483+ Use this method when called from async code that has await points
484+ between cache check and set, to prevent cache stampede when multiple
485+ coroutines miss the cache simultaneously.
486+
487+ Parameters
488+ ----------
489+ guild_id : int
490+ The guild ID.
491+ user_id : int
492+ The user ID.
493+ is_jailed : bool
494+ Whether the user is jailed.
495+ """
496+ lock = await self ._get_lock (guild_id , user_id )
497+ async with lock :
498+ # Re-check cache after acquiring lock (another coroutine may have updated it)
499+ cache_key = self ._get_key (guild_id , user_id )
500+ cached = self ._cache .get (cache_key )
501+ if cached is not None :
502+ # Another coroutine already cached the value, no need to set again
503+ return
504+ self ._cache .set (cache_key , is_jailed )
505+
333506 def invalidate (self , guild_id : int , user_id : int ) -> None :
334507 """
335508 Invalidate cached jail status for a user.
0 commit comments