|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
3 | | -from copy import copy |
4 | | - |
5 | | -from .hashed_data import HashedIterable |
6 | | -from .utils import All |
7 | | - |
8 | 3 | """ |
9 | 4 | Cache utilities. |
10 | 5 |
|
11 | 6 | This module provides caching datastructures and utilities. |
12 | | -It also exposes a runtime switch to enable/disable caching. |
13 | 7 | """ |
14 | | -import contextvars |
15 | | -from collections import UserDict |
16 | | -from dataclasses import dataclass, field, InitVar |
17 | | -from typing_extensions import Dict, Any, Iterable, Hashable, Optional, Tuple |
18 | | - |
19 | | - |
20 | | -# Runtime switch to enable/disable caching paths |
21 | | -_caching_enabled = contextvars.ContextVar("caching_enabled", default=True) |
22 | | - |
23 | | - |
24 | | -def enable_caching() -> None: |
25 | | - """ |
26 | | - Enable the caching fast-paths for query evaluation. |
27 | | -
|
28 | | - :return: None |
29 | | - :rtype: None |
30 | | - """ |
31 | | - _caching_enabled.set(True) |
32 | | - |
33 | | - |
34 | | -def disable_caching() -> None: |
35 | | - """ |
36 | | - Disable the caching fast-paths for query evaluation. |
37 | | -
|
38 | | - :return: None |
39 | | - :rtype: None |
40 | | - """ |
41 | | - _caching_enabled.set(False) |
42 | | - |
43 | | - |
44 | | -def is_caching_enabled() -> bool: |
45 | | - """ |
46 | | - Check whether caching is currently enabled. |
47 | | -
|
48 | | - :return: True if caching is enabled; False otherwise. |
49 | | - :rtype: bool |
50 | | - """ |
51 | | - return _caching_enabled.get() |
| 8 | +from dataclasses import dataclass, field |
| 9 | +from typing_extensions import Dict |
52 | 10 |
|
53 | 11 |
|
54 | 12 | @dataclass |
@@ -124,183 +82,3 @@ def clear(self): |
124 | 82 | self.all_seen = False |
125 | 83 | self.constraints.clear() |
126 | 84 | self.exact.clear() |
127 | | - |
128 | | - |
129 | | -class CacheDict(UserDict): ... |
130 | | - |
131 | | - |
132 | | -@dataclass |
133 | | -class IndexedCache: |
134 | | - """ |
135 | | - A hierarchical cache keyed by a fixed sequence of indices. |
136 | | -
|
137 | | - It supports insertion of outputs under partial assignments and retrieval with |
138 | | - wildcard handling using the ALL sentinel. |
139 | | -
|
140 | | - :ivar keys: Ordered list of integer keys to index the cache. |
141 | | - :ivar seen_set: Helper to track assignments already checked. |
142 | | - :ivar cache: Nested mapping structure storing cached results. |
143 | | - :ivar flat_cache: Flattened mapping structure storing cached results. |
144 | | - """ |
145 | | - |
146 | | - _keys: InitVar[Tuple[Hashable]] = () |
147 | | - seen_set: SeenSet = field(default_factory=SeenSet, init=False) |
148 | | - cache: CacheDict = field(default_factory=CacheDict, init=False) |
149 | | - flat_cache: HashedIterable = field(default_factory=HashedIterable, init=False) |
150 | | - |
151 | | - def __post_init__(self, _keys: Tuple[Hashable] = ()): |
152 | | - self.keys = _keys |
153 | | - |
154 | | - @property |
155 | | - def keys(self) -> Tuple[Hashable]: |
156 | | - return self.seen_set.keys |
157 | | - |
158 | | - @keys.setter |
159 | | - def keys(self, keys: Iterable[Hashable]): |
160 | | - self.seen_set.keys = tuple(sorted(keys)) |
161 | | - # Reset structures |
162 | | - self.cache.clear() |
163 | | - self.seen_set.clear() |
164 | | - |
165 | | - def insert(self, assignment: Dict, output: Any, index: bool = True) -> None: |
166 | | - """ |
167 | | - Insert an output under the given partial assignment. |
168 | | -
|
169 | | - Missing keys are filled with the ALL sentinel. |
170 | | -
|
171 | | - :param assignment: Mapping from key index to concrete value. |
172 | | - :type assignment: Dict |
173 | | - :param output: Cached value to store at the leaf. |
174 | | - :type output: Any |
175 | | - :param index: If True, insert into cache tree; otherwise, store directly in a flat cache. |
176 | | - :type index: bool |
177 | | - :return: None |
178 | | - :rtype: None |
179 | | - """ |
180 | | - # Make a shallow copy only for seen_set tracking to avoid mutating caller's dict |
181 | | - if not index or not assignment: |
182 | | - self.flat_cache.add(output) |
183 | | - return |
184 | | - |
185 | | - seen_assignment = dict(assignment) |
186 | | - self.seen_set.add(seen_assignment) |
187 | | - |
188 | | - cache = self.cache |
189 | | - keys = self.keys |
190 | | - last_idx = len(keys) - 1 |
191 | | - |
192 | | - # Use local lookups and setdefault to reduce overhead |
193 | | - for idx, k in enumerate(keys): |
194 | | - v = assignment.get(k, All) |
195 | | - if idx < last_idx: |
196 | | - next_cache = cache.get(v) |
197 | | - if next_cache is None: |
198 | | - next_cache = CacheDict() |
199 | | - cache[v] = next_cache |
200 | | - cache = next_cache |
201 | | - else: |
202 | | - cache[v] = output |
203 | | - |
204 | | - def check(self, assignment: Dict) -> bool: |
205 | | - """ |
206 | | - Check if seen entries cover an assignment (dict). |
207 | | -
|
208 | | - :param assignment: The assignment to check. |
209 | | - """ |
210 | | - return self.seen_set.check(assignment) |
211 | | - |
212 | | - def __getitem__(self, key: Any): |
213 | | - return self.flat_cache[key] |
214 | | - |
215 | | - def retrieve( |
216 | | - self, |
217 | | - assignment: Optional[Dict] = None, |
218 | | - cache=None, |
219 | | - key_idx=0, |
220 | | - result: Dict = None, |
221 | | - from_index: bool = True, |
222 | | - ) -> Iterable: |
223 | | - """ |
224 | | - Retrieve leaf results matching a (possibly partial) assignment. |
225 | | -
|
226 | | - This yields tuples of (resolved_assignment, value) when a leaf is found. |
227 | | -
|
228 | | - :param assignment: Partial mapping from key index to values. |
229 | | - :param cache: Internal recursion parameter; the current cache node. |
230 | | - :param key_idx: Internal recursion parameter; current key index position. |
231 | | - :param result: Internal accumulator for building a full assignment. |
232 | | - :param from_index: If True, retrieve from cache tree; otherwise, retrieve directly from flat cache. |
233 | | - :return: Generator of (assignment, value) pairs. |
234 | | - :rtype: Iterable |
235 | | - """ |
236 | | - if not from_index: |
237 | | - for v in self.flat_cache: |
238 | | - yield {}, v |
239 | | - return |
240 | | - |
241 | | - # Initialize result only once; avoid repeated copying where possible |
242 | | - if result is None: |
243 | | - result = copy(assignment) |
244 | | - if cache is None: |
245 | | - cache = self.cache |
246 | | - # Fast return on empty cache node |
247 | | - if isinstance(cache, CacheDict) and not cache: |
248 | | - return |
249 | | - keys = self.keys |
250 | | - n_keys = len(keys) |
251 | | - key = keys[key_idx] |
252 | | - |
253 | | - # Follow the concrete chain as far as it exists without exceptions |
254 | | - while key in assignment: |
255 | | - next_cache = cache.get(assignment[key]) |
256 | | - if next_cache is None: |
257 | | - # Try wildcard branch at this level |
258 | | - wildcard = cache.get(All) |
259 | | - if wildcard is not None: |
260 | | - yield from self._yield_result(assignment, wildcard, key_idx, result) |
261 | | - return |
262 | | - cache = next_cache |
263 | | - if key_idx + 1 < n_keys: |
264 | | - key_idx += 1 |
265 | | - key = keys[key_idx] |
266 | | - else: |
267 | | - break |
268 | | - |
269 | | - if key not in assignment: |
270 | | - # Prefer wildcard branch if available |
271 | | - wildcard = cache.get(All) |
272 | | - if wildcard is not None: |
273 | | - yield from self._yield_result(assignment, wildcard, key_idx, result) |
274 | | - else: |
275 | | - # Explore all branches at this level, copying only the minimal delta |
276 | | - for cache_key, cache_val in cache.items(): |
277 | | - local_result = copy(result) |
278 | | - local_result[key] = cache_key |
279 | | - yield from self._yield_result( |
280 | | - assignment, cache_val, key_idx, local_result |
281 | | - ) |
282 | | - else: |
283 | | - # Reached the leaf (value or next dict) specifically specified by assignment |
284 | | - yield result, cache |
285 | | - |
286 | | - def clear(self): |
287 | | - self.cache.clear() |
288 | | - self.seen_set.clear() |
289 | | - self.flat_cache.clear() |
290 | | - |
291 | | - def _yield_result( |
292 | | - self, assignment: Dict, cache_val: Any, key_idx: int, result: Dict[int, Any] |
293 | | - ): |
294 | | - """ |
295 | | - Internal helper to descend into cache and yield concrete results. |
296 | | -
|
297 | | - :param assignment: Original partial assignment. |
298 | | - :param cache_val: Current cache node or value. |
299 | | - :param key_idx: Current key index. |
300 | | - :param result: Accumulated assignment. |
301 | | - :return: Yields (assignment, value) when reaching leaves. |
302 | | - """ |
303 | | - if isinstance(cache_val, CacheDict): |
304 | | - yield from self.retrieve(assignment, cache_val, key_idx + 1, result) |
305 | | - else: |
306 | | - yield result, cache_val |
0 commit comments