Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 159 additions & 3 deletions async_substrate_interface/async_substrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import os
import socket
import ssl
import time
import warnings
from contextlib import suppress
from unittest.mock import AsyncMock
Expand Down Expand Up @@ -39,6 +40,7 @@
from websockets.asyncio.client import connect, ClientConnection
from websockets.exceptions import (
ConnectionClosed,
InvalidURI,
)
from websockets.protocol import State

Expand Down Expand Up @@ -87,6 +89,7 @@
# env vars dictating the cache size of the cached methods
SUBSTRATE_CACHE_METHOD_SIZE = int(os.getenv("SUBSTRATE_CACHE_METHOD_SIZE", "512"))
SUBSTRATE_RUNTIME_CACHE_SIZE = int(os.getenv("SUBSTRATE_RUNTIME_CACHE_SIZE", "16"))
SSL_SESSION_TTL = int(os.getenv("SUBSTRATE_SSL_SESSION_TTL", "300"))


class AsyncExtrinsicReceipt:
Expand Down Expand Up @@ -577,6 +580,55 @@ def __getitem__(self, item):
return self.records[item]


class _SessionResumingSSLContext(ssl.SSLContext):
    """
    SSL context that remembers the most recent TLS session and transparently
    offers it again when a new client-side connection is wrapped, provided the
    session has not outlived its TTL.

    Reusing a session lets the TLS handshake complete without a full key
    exchange on reconnect, which reduces latency. The TTL actually applied is
    whichever is smaller: the configured ``session_ttl`` or the session
    timeout advertised by the server.
    """

    def __new__(cls, protocol: int = ssl.PROTOCOL_TLS_CLIENT, **_kwargs):
        # ssl.SSLContext allocates its state in __new__ and only accepts the
        # protocol there; extra keyword args are swallowed and handled in
        # __init__ instead.
        return ssl.SSLContext.__new__(cls, protocol)

    def __init__(
        self,
        protocol: int = ssl.PROTOCOL_TLS_CLIENT,
        *,
        session_ttl: int = SSL_SESSION_TTL,
    ):
        # The most recently captured TLS session and the monotonic time at
        # which it was captured (None until the first save_session call).
        self._saved_session: Optional[ssl.SSLSession] = None
        self._session_established_at: Optional[float] = None
        self._session_ttl = session_ttl

    def save_session(self, session: ssl.SSLSession) -> None:
        """Record ``session`` so later handshakes can attempt resumption."""
        self._session_established_at = time.monotonic()
        self._saved_session = session

    def _session_is_valid(self) -> bool:
        """Return True when a session is saved and still within its TTL."""
        session = self._saved_session
        started = self._session_established_at
        if session is None or started is None:
            return False
        effective_ttl = min(self._session_ttl, session.timeout)
        return (time.monotonic() - started) < effective_ttl

    def wrap_bio(
        self, incoming, outgoing, server_side=False, server_hostname=None, session=None
    ):
        """Wrap BIOs, injecting the saved session for eligible client handshakes."""
        resumable = (
            not server_side and session is None and self._session_is_valid()
        )
        if resumable:
            session = self._saved_session
            logger.debug("Attempting TLS session resumption")
        return super().wrap_bio(
            incoming,
            outgoing,
            server_side=server_side,
            server_hostname=server_hostname,
            session=session,
        )


class Websocket:
def __init__(
self,
Expand All @@ -588,6 +640,8 @@ def __init__(
_log_raw_websockets: bool = False,
retry_timeout: float = 60.0,
max_retries: int = 5,
ssl_context: Optional[_SessionResumingSSLContext] = None,
dns_ttl: int = 300,
):
"""
Websocket manager object. Allows for the use of a single websocket connection by multiple
Expand All @@ -604,6 +658,10 @@ def __init__(
_log_raw_websockets: Whether to log raw websockets in the "raw_websocket" logger
retry_timeout: Timeout in seconds to retry websocket connection
max_retries: Maximum number of retries following a timeout
ssl_context: Optional session-resuming SSL context for wss:// connections.
When provided, the context's saved TLS session is reused on reconnection
to avoid a full handshake.
dns_ttl: Seconds to cache DNS results. Set to 0 to disable caching.
"""
# TODO allow setting max concurrent connections and rpc subscriptions per connection
self.ws_url = ws_url
Expand All @@ -627,6 +685,11 @@ def __init__(
self._last_activity = asyncio.Event()
self._last_activity.set()
self._waiting_for_response = 0
self._ssl_context = ssl_context
if ssl_context is not None and ws_url.startswith("wss://"):
self._options["ssl"] = ssl_context
self._dns_ttl = dns_ttl
self._dns_cache: Optional[tuple[list, float]] = None

@property
def state(self):
Expand Down Expand Up @@ -736,6 +799,37 @@ async def _cancel(self):
f"{e} encountered while trying to close websocket connection."
)

async def _resolve_host(self) -> tuple:
    """
    Resolve the websocket hostname to a (family, type, proto, canonname, sockaddr) tuple,
    using a cached result if it is still within ``dns_ttl`` seconds.

    Invalidate the cache by setting ``_dns_cache = None`` before calling.

    Raises:
        InvalidURI: if the websocket URL scheme is not ``ws`` or ``wss``.
        socket.gaierror: if the hostname cannot be resolved.
    """
    from urllib.parse import urlparse

    parsed = urlparse(self.ws_url)
    if parsed.scheme not in ("ws", "wss"):
        raise InvalidURI(self.ws_url, f"Invalid URI scheme: {parsed.scheme!r}")
    host = parsed.hostname
    # Default ports per scheme when the URL does not carry one explicitly.
    port = parsed.port or (443 if parsed.scheme == "wss" else 80)

    now = time.monotonic()
    if self._dns_cache is not None and self._dns_ttl > 0:
        infos, resolved_at = self._dns_cache
        if now - resolved_at < self._dns_ttl:
            logger.debug(f"DNS cache hit for {host} (age={now - resolved_at:.0f}s)")
            return infos[0]

    logger.debug(f"Resolving DNS for {host}:{port}")
    loop = asyncio.get_running_loop()
    infos = await loop.getaddrinfo(
        host, port, family=socket.AF_UNSPEC, type=socket.SOCK_STREAM
    )
    if self._dns_ttl > 0:
        # Only retain results when caching is enabled; with dns_ttl == 0 a
        # stored entry would never be read here but could still be persisted
        # to disk by DiskCachedAsyncSubstrateInterface.close().
        self._dns_cache = (infos, now)
    logger.debug(f"DNS resolved {host} -> {infos[0][4][0]}")
    return infos[0]

async def connect(self, force=False):
if not force:
async with self._lock:
Expand Down Expand Up @@ -771,15 +865,39 @@ async def _connect_internal(self, force):
pass
logger.debug("Attempting connection")
try:
family, type_, proto, _, sockaddr = await self._resolve_host()
tcp_sock = socket.socket(family, type_, proto)
tcp_sock.setblocking(False)
loop = asyncio.get_running_loop()
try:
await asyncio.wait_for(
loop.sock_connect(tcp_sock, sockaddr), timeout=10.0
)
except Exception:
tcp_sock.close()
self._dns_cache = None # invalidate on TCP failure
raise
connection = await asyncio.wait_for(
connect(self.ws_url, **self._options), timeout=10.0
connect(self.ws_url, sock=tcp_sock, **self._options), timeout=10.0
)
except socket.gaierror:
logger.debug(f"Hostname not known (this is just for testing")
await asyncio.sleep(10)
return await self.connect(force=force)
logger.debug("Connection established")
self.ws = connection
if self._ssl_context is not None:
try:
ssl_obj = connection.transport.get_extra_info("ssl_object")
if ssl_obj is not None and ssl_obj.session is not None:
self._ssl_context.save_session(ssl_obj.session)
logger.debug(
f"Saved TLS session "
f"(reused={ssl_obj.session_reused}, "
f"timeout={ssl_obj.session.timeout}s)"
)
except Exception as e:
logger.debug(f"Could not save TLS session: {e}")
if self._send_recv_task is None or self._send_recv_task.done():
self._send_recv_task = asyncio.get_running_loop().create_task(
self._handler(self.ws)
Expand Down Expand Up @@ -1146,6 +1264,8 @@ def __init__(
_log_raw_websockets: bool = False,
ws_shutdown_timer: Optional[float] = 5.0,
decode_ss58: bool = False,
_ssl_context: Optional[_SessionResumingSSLContext] = None,
dns_ttl: int = 300,
):
"""
The asyncio-compatible version of the subtensor interface commands we use in bittensor. It is important to
Expand All @@ -1166,6 +1286,10 @@ def __init__(
_log_raw_websockets: whether to log raw websocket requests during RPC requests
ws_shutdown_timer: how long after the last connection your websocket should close
decode_ss58: Whether to decode AccountIds to SS58 or leave them in raw bytes tuples.
_ssl_context: optional session-resuming SSL context; used internally by
DiskCachedAsyncSubstrateInterface to enable TLS session reuse.
dns_ttl: seconds to cache DNS results for the websocket URL (default 300). Set to 0
to disable caching.

"""
super().__init__(
Expand All @@ -1192,6 +1316,8 @@ def __init__(
shutdown_timer=ws_shutdown_timer,
retry_timeout=self.retry_timeout,
max_retries=max_retries,
ssl_context=_ssl_context,
dns_ttl=dns_ttl,
)
else:
self.ws = AsyncMock(spec=Websocket)
Expand Down Expand Up @@ -4400,22 +4526,52 @@ class DiskCachedAsyncSubstrateInterface(AsyncSubstrateInterface):

Loads the cache from the disk at startup, where it is kept in-memory, and dumps to the disk
when the connection is closed.

For `wss://` endpoints, a persistent `_SessionResumingSSLContext` is created so
that TLS sessions are reused across reconnections. The effective session TTL is the minimum
of `ssl_session_ttl` (default `SSL_SESSION_TTL`) and the server-advertised timeout.
"""

def __init__(
    self,
    url: str,
    *args,
    ssl_session_ttl: int = SSL_SESSION_TTL,
    **kwargs,
):
    """
    Args:
        url: websocket URL of the chain endpoint
        ssl_session_ttl: maximum number of seconds a saved TLS session may be
            reused on reconnection (the effective TTL is also capped by the
            server-advertised session timeout)
        *args: forwarded to AsyncSubstrateInterface
        **kwargs: forwarded to AsyncSubstrateInterface
    """
    ssl_context: Optional[_SessionResumingSSLContext] = None
    if url.startswith("wss://") and not kwargs.get("_mock", False):
        ssl_context = _SessionResumingSSLContext(session_ttl=ssl_session_ttl)
        # load_default_certs() is a superset of set_default_verify_paths():
        # on POSIX it calls set_default_verify_paths(), and on Windows it
        # additionally loads the system certificate stores, which
        # set_default_verify_paths() alone would miss.
        ssl_context.load_default_certs(ssl.Purpose.SERVER_AUTH)
    super().__init__(url, *args, _ssl_context=ssl_context, **kwargs)

async def initialize(self) -> None:
    """
    Load the persisted DNS cache (if any) into the websocket and the runtime
    cache from disk, then perform the normal substrate initialization.
    """
    db = AsyncSqliteDB(self.url)
    cached = await db.load_dns_cache(self.url)
    if cached is not None:
        addrinfos, saved_at_unix = cached
        # Clamp to zero: if the wall clock moved backwards since the cache was
        # written, a negative age would place resolved_at in the monotonic
        # future and make the entry look perpetually fresh to _resolve_host.
        age = max(0.0, time.time() - saved_at_unix)
        # Reconstruct a monotonic timestamp so _resolve_host's TTL check works correctly
        self.ws._dns_cache = (addrinfos, time.monotonic() - age)
        logger.debug(f"Loaded DNS cache from disk (age={age:.0f}s)")
    await self.runtime_cache.load_from_disk(self.url)
    await self._initialize()

async def close(self):
    """
    Closes the substrate connection and the websocket connection, dumps the runtime and DNS
    caches to disk.
    """
    db = AsyncSqliteDB(self.url)
    dns_cache = getattr(self.ws, "_dns_cache", None)
    if dns_cache is not None:
        addrinfos, _ = dns_cache
        await db.save_dns_cache(self.url, addrinfos)
    try:
        await self.runtime_cache.dump_to_disk(self.url)
        await self.ws.shutdown()
    except AttributeError:
        # self.ws may be a mock or already torn down; best-effort shutdown.
        pass
    # Reuse the handle created above instead of constructing a second
    # AsyncSqliteDB for the same URL before closing.
    await db.close()

@async_sql_lru_cache(maxsize=SUBSTRATE_CACHE_METHOD_SIZE)
Expand Down
56 changes: 56 additions & 0 deletions async_substrate_interface/utils/cache.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import inspect
import time
import weakref
from collections import OrderedDict
import functools
Expand Down Expand Up @@ -115,6 +116,61 @@ async def __call__(self, chain, other_self, func, args, kwargs) -> Optional[Any]
await self._db.commit()
return result

async def _ensure_dns_table(self):
    """Create the ``dns_cache`` table used to persist DNS results, if absent."""
    ddl = """CREATE TABLE IF NOT EXISTS dns_cache (
            url TEXT PRIMARY KEY,
            addrinfos BLOB,
            saved_at REAL
        )"""
    await self._db.execute(ddl)
    await self._db.commit()

async def load_dns_cache(self, url: str) -> Optional[tuple[list, float]]:
    """
    Load a previously saved DNS result for ``url``.

    Returns ``(addrinfos, saved_at_unix)`` where ``saved_at_unix`` is the Unix
    timestamp at which the result was saved, or ``None`` if nothing is cached.
    Skips localhost URLs.
    """
    if _check_if_local(url):
        return None
    async with self._lock:
        if not self._db:
            _ensure_dir()
            self._db = await aiosqlite.connect(CACHE_LOCATION)
        await self._ensure_dns_table()
        try:
            cursor = await self._db.execute(
                "SELECT addrinfos, saved_at FROM dns_cache WHERE url=?", (url,)
            )
            try:
                row = await cursor.fetchone()
            finally:
                # Release the cursor even when fetchone() raises, so a failed
                # read does not leak it.
                await cursor.close()
            if row is not None:
                return pickle.loads(row[0]), row[1]
        except (pickle.PickleError, sqlite3.Error) as e:
            logger.debug(f"DNS cache load error: {e}")
        return None

async def save_dns_cache(self, url: str, addrinfos: list) -> None:
    """Persist DNS results for ``url`` to disk. Skips localhost URLs."""
    if _check_if_local(url):
        return
    async with self._lock:
        if not self._db:
            _ensure_dir()
            self._db = await aiosqlite.connect(CACHE_LOCATION)
        await self._ensure_dns_table()
        try:
            # pickle.dumps stays inside the try so PickleError is handled.
            payload = pickle.dumps(addrinfos)
            await self._db.execute(
                "INSERT OR REPLACE INTO dns_cache (url, addrinfos, saved_at) VALUES (?,?,?)",
                (url, payload, time.time()),
            )
            await self._db.commit()
        except (pickle.PickleError, sqlite3.Error) as e:
            logger.debug(f"DNS cache save error: {e}")

async def load_runtime_cache(
self, chain: str
) -> tuple[OrderedDict[int, str], OrderedDict[str, int], OrderedDict[int, dict]]:
Expand Down
Loading