Skip to content
Merged
1 change: 1 addition & 0 deletions .github/workflows/e2e-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ on:
env:
CARGO_TERM_COLOR: always
VERBOSE: ${{ github.event.inputs.verbose }}
CACHE_LOCAL: "1"

# job to run tests in parallel
jobs:
Expand Down
25 changes: 23 additions & 2 deletions async_substrate_interface/async_substrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import socket
import ssl
import warnings
from contextlib import suppress
from unittest.mock import AsyncMock
from hashlib import blake2b
from typing import (
Expand Down Expand Up @@ -1211,6 +1212,8 @@ def __init__(
self.metadata_version_hex = "0x0f000000" # v15
self._initializing = False
self._mock = _mock
self.startup_runtime_task: Optional[asyncio.Task] = None
self.startup_block_hash: Optional[str] = None

async def __aenter__(self):
if not self._mock:
Expand All @@ -1230,8 +1233,12 @@ async def _initialize(self) -> None:
if not self._chain:
chain = await self.rpc_request("system_chain", [])
self._chain = chain.get("result")
runtime = await self.init_runtime()
self.startup_block_hash = block_hash = await self.get_chain_head()
self.startup_runtime_task = asyncio.create_task(
self.init_runtime(block_hash=block_hash, init=True)
)
if self.ss58_format is None:
runtime = await self.init_runtime(block_hash)
# Check and apply runtime constants
ss58_prefix_constant = await self.get_constant(
"System", "SS58Prefix", runtime=runtime
Expand Down Expand Up @@ -1438,7 +1445,10 @@ async def decode_scale(
return obj

async def init_runtime(
self, block_hash: Optional[str] = None, block_id: Optional[int] = None
self,
block_hash: Optional[str] = None,
block_id: Optional[int] = None,
init: bool = False,
) -> Runtime:
"""
This method is used by all other methods that deals with metadata and types defined in the type registry.
Expand All @@ -1455,6 +1465,13 @@ async def init_runtime(
Returns:
Runtime object
"""
if (
not init
and self.startup_runtime_task is not None
and block_hash == self.startup_block_hash
):
await self.startup_runtime_task
self.startup_runtime_task = None

if block_id and block_hash:
raise ValueError("Cannot provide block_hash and block_id at the same time")
Expand Down Expand Up @@ -4322,6 +4339,10 @@ async def close(self):
Closes the substrate connection, and the websocket connection.
"""
try:
if self.startup_runtime_task is not None:
self.startup_runtime_task.cancel()
with suppress(asyncio.CancelledError):
await self.startup_runtime_task
await self.ws.shutdown()
except AttributeError:
pass
Expand Down
33 changes: 21 additions & 12 deletions async_substrate_interface/utils/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@


USE_CACHE = True if os.getenv("NO_CACHE") != "1" else False
CACHE_LOCAL = os.getenv("CACHE_LOCAL") == "1"
CACHE_LOCATION = (
os.path.expanduser(
os.getenv("CACHE_LOCATION", "~/.cache/async-substrate-interface")
Expand All @@ -30,13 +31,15 @@ class AsyncSqliteDB:
_instances: dict[str, "AsyncSqliteDB"] = {}
_db: Optional[aiosqlite.Connection] = None
_lock: Optional[asyncio.Lock] = None
_created_tables: set

def __new__(cls, chain_endpoint: str):
try:
return cls._instances[chain_endpoint]
except KeyError:
instance = super().__new__(cls)
instance._lock = asyncio.Lock()
instance._created_tables = set()
cls._instances[chain_endpoint] = instance
return instance

Expand All @@ -45,8 +48,11 @@ async def close(self):
if self._db:
await self._db.close()
self._db = None
self._created_tables.clear()

async def _create_if_not_exists(self, chain: str, table_name: str):
if table_name in self._created_tables:
return _check_if_local(chain)
if not (local_chain := _check_if_local(chain)) or not USE_CACHE:
await self._db.execute(
f"""
Expand Down Expand Up @@ -76,6 +82,7 @@ async def _create_if_not_exists(self, chain: str, table_name: str):
"""
)
await self._db.commit()
self._created_tables.add(table_name)
return local_chain

async def __call__(self, chain, other_self, func, args, kwargs) -> Optional[Any]:
Expand All @@ -86,18 +93,18 @@ async def __call__(self, chain, other_self, func, args, kwargs) -> Optional[Any]
table_name = _get_table_name(func)
local_chain = await self._create_if_not_exists(chain, table_name)
key = pickle.dumps((args, kwargs or None))
try:
cursor: aiosqlite.Cursor = await self._db.execute(
f"SELECT value FROM {table_name} WHERE key=? AND chain=?",
(key, chain),
)
result = await cursor.fetchone()
await cursor.close()
if result is not None:
return pickle.loads(result[0])
except (pickle.PickleError, sqlite3.Error) as e:
logger.exception("Cache error", exc_info=e)
pass
if not local_chain or not USE_CACHE:
try:
cursor: aiosqlite.Cursor = await self._db.execute(
f"SELECT value FROM {table_name} WHERE key=? AND chain=?",
(key, chain),
)
result = await cursor.fetchone()
await cursor.close()
if result is not None:
return pickle.loads(result[0])
except (pickle.PickleError, sqlite3.Error) as e:
logger.exception("Cache error", exc_info=e)
result = await func(other_self, *args, **kwargs)
if not local_chain or not USE_CACHE:
# TODO use a task here
Expand Down Expand Up @@ -202,6 +209,8 @@ def _get_table_name(func):


def _check_if_local(chain: str) -> bool:
if CACHE_LOCAL:
return False
return any([x in chain for x in ["127.0.0.1", "localhost", "0.0.0.0"]])


Expand Down
61 changes: 30 additions & 31 deletions tests/integration_tests/test_disk_cache.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
"""
Thresholds:
DISK_CACHE_TIMEOUT – first access per method hits SQLite (aiosqlite thread-pool
overhead); must be << any real network call (~200 ms).
MEMORY_CACHE_TIMEOUT – repeat access with the same args hits the in-process LRU;
should be effectively instant.
"""

import pytest
import time
from async_substrate_interface.async_substrate import (
Expand All @@ -8,6 +16,10 @@
from tests.helpers.settings import LATENT_LITE_ENTRYPOINT


DISK_CACHE_TIMEOUT = 0.5
MEMORY_CACHE_TIMEOUT = 0.002


@pytest.mark.asyncio
async def test_disk_cache():
print("Testing test_disk_cache")
Expand Down Expand Up @@ -81,57 +93,44 @@ async def test_disk_cache():
assert parent_block_hash == parent_block_hash_sync
assert block_runtime_info == block_runtime_info_sync
assert block_runtime_version_for == block_runtime_version_for_sync
# Verify data is pulling from disk cache
# Verify data is pulling from disk cache.
async with DiskCachedAsyncSubstrateInterface(
LATENT_LITE_ENTRYPOINT, ss58_format=42, chain_name="Bittensor"
) as disk_cached_substrate:
start = time.monotonic()
new_block_hash = await disk_cached_substrate.get_block_hash(current_block)
new_time = time.monotonic()
assert new_time - start < 0.001
assert new_time - start < DISK_CACHE_TIMEOUT

start = time.monotonic()
new_parent_block_hash = await disk_cached_substrate.get_parent_block_hash(
block_hash
)
_ = await disk_cached_substrate.get_parent_block_hash(block_hash)
new_time = time.monotonic()
assert new_time - start < 0.001
assert new_time - start < DISK_CACHE_TIMEOUT
start = time.monotonic()
new_block_runtime_info = await disk_cached_substrate.get_block_runtime_info(
block_hash
)
_ = await disk_cached_substrate.get_block_runtime_info(block_hash)
new_time = time.monotonic()
assert new_time - start < 0.001
assert new_time - start < DISK_CACHE_TIMEOUT
start = time.monotonic()
new_block_runtime_version_for = (
await disk_cached_substrate.get_block_runtime_version_for(block_hash)
)
_ = await disk_cached_substrate.get_block_runtime_version_for(block_hash)
new_time = time.monotonic()
assert new_time - start < 0.001
assert new_time - start < DISK_CACHE_TIMEOUT
# Repeat calls with the same args must come from the in-process LRU cache.
start = time.monotonic()
new_block_hash_from_cache = await disk_cached_substrate.get_block_hash(
current_block
)
_ = await disk_cached_substrate.get_block_hash(current_block)
new_time = time.monotonic()
assert new_time - start < 0.001
assert new_time - start < MEMORY_CACHE_TIMEOUT
start = time.monotonic()
new_parent_block_hash_from_cache = (
await disk_cached_substrate.get_parent_block_hash(block_hash_from_cache)
)
_ = await disk_cached_substrate.get_parent_block_hash(block_hash_from_cache)
new_time = time.monotonic()
assert new_time - start < 0.001
assert new_time - start < MEMORY_CACHE_TIMEOUT
start = time.monotonic()
new_block_runtime_info_from_cache = (
await disk_cached_substrate.get_block_runtime_info(block_hash_from_cache)
)
_ = await disk_cached_substrate.get_block_runtime_info(block_hash_from_cache)
new_time = time.monotonic()
assert new_time - start < 0.001
assert new_time - start < MEMORY_CACHE_TIMEOUT
start = time.monotonic()
new_block_runtime_version_from_cache = (
await disk_cached_substrate.get_block_runtime_version_for(
block_hash_from_cache
)
_ = await disk_cached_substrate.get_block_runtime_version_for(
block_hash_from_cache
)
new_time = time.monotonic()
assert new_time - start < 0.001
assert new_time - start < MEMORY_CACHE_TIMEOUT
print("Disk Cache tests passed")
Loading