class _SessionResumingSSLContext(ssl.SSLContext):
    """
    SSL context that remembers the most recently saved TLS session and offers it
    again on the next client-side connection while it is still fresh.

    Resuming a session lets a reconnect skip the full TLS handshake. A saved
    session is considered fresh for ``min(session_ttl, session.timeout)``
    seconds, measured on the monotonic clock from the moment it was saved.
    """

    def __new__(cls, protocol: int = ssl.PROTOCOL_TLS_CLIENT, **_kwargs):
        # Python hands the same arguments to __new__ and __init__; keyword-only
        # extras such as ``session_ttl`` are accepted here and left for __init__.
        return ssl.SSLContext.__new__(cls, protocol)

    def __init__(
        self,
        protocol: int = ssl.PROTOCOL_TLS_CLIENT,
        *,
        session_ttl: int = SSL_SESSION_TTL,
    ):
        """
        Args:
            protocol: TLS protocol constant (already consumed by ``__new__``).
            session_ttl: upper bound, in seconds, on how long a saved session
                may be offered for resumption.
        """
        self._session_ttl = session_ttl
        self._saved_session: Optional[ssl.SSLSession] = None
        self._session_established_at: Optional[float] = None

    def save_session(self, session: ssl.SSLSession) -> None:
        """Remember ``session`` and record the monotonic time it was saved."""
        self._session_established_at = time.monotonic()
        self._saved_session = session

    def _session_is_valid(self) -> bool:
        """Return True when a session is stored and younger than its effective TTL."""
        stored = self._saved_session
        saved_at = self._session_established_at
        if stored is None or saved_at is None:
            return False
        age = time.monotonic() - saved_at
        # The effective TTL is capped by both our own limit and the session's timeout.
        return age < min(self._session_ttl, stored.timeout)

    def wrap_bio(
        self, incoming, outgoing, server_side=False, server_hostname=None, session=None
    ):
        """
        Delegate to ``ssl.SSLContext.wrap_bio``, injecting the saved session when
        this is a client-side wrap, the caller supplied no session of its own,
        and the saved session is still valid.
        """
        may_resume = (
            not server_side and session is None and self._session_is_valid()
        )
        if may_resume:
            session = self._saved_session
            logger.debug("Attempting TLS session resumption")
        return super().wrap_bio(
            incoming,
            outgoing,
            server_side=server_side,
            server_hostname=server_hostname,
            session=session,
        )
async def _resolve_host(self) -> tuple:
    """
    Resolve the websocket hostname to a single address-info entry.

    A cached result is returned when it is non-empty and younger than
    ``dns_ttl`` seconds; otherwise the event loop's ``getaddrinfo`` is called
    and its result replaces the cache. Set ``_dns_cache = None`` before
    calling to force a fresh lookup.

    Returns:
        The first ``(family, type, proto, canonname, sockaddr)`` tuple of the
        (cached or freshly resolved) address list.

    Raises:
        InvalidURI: if the URL scheme is not ``ws``/``wss`` or the URL has no
            hostname.
        Any error raised by ``loop.getaddrinfo`` propagates unchanged.
    """
    from urllib.parse import urlparse

    parsed = urlparse(self.ws_url)
    if parsed.scheme not in ("ws", "wss"):
        raise InvalidURI(self.ws_url, f"Invalid URI scheme: {parsed.scheme!r}")
    host = parsed.hostname
    if not host:
        # getaddrinfo accepts host=None instead of failing, so a URL without a
        # host would be "resolved" to an unintended address rather than rejected.
        raise InvalidURI(self.ws_url, "Missing hostname")
    port = parsed.port or (443 if parsed.scheme == "wss" else 80)

    now = time.monotonic()
    if self._dns_cache is not None and self._dns_ttl > 0:
        infos, resolved_at = self._dns_cache
        age = now - resolved_at
        # An empty list (e.g. a damaged entry restored from disk) is treated as
        # a cache miss instead of failing on infos[0].
        if infos and age < self._dns_ttl:
            logger.debug(f"DNS cache hit for {host} (age={age:.0f}s)")
            return infos[0]

    logger.debug(f"Resolving DNS for {host}:{port}")
    loop = asyncio.get_running_loop()
    infos = await loop.getaddrinfo(
        host, port, family=socket.AF_UNSPEC, type=socket.SOCK_STREAM
    )
    # Timestamp taken before the lookup started, so the entry never appears
    # younger than it really is.
    self._dns_cache = (infos, now)
    logger.debug(f"DNS resolved {host} -> {infos[0][4][0]}")
    return infos[0]
= await self._resolve_host() + tcp_sock = socket.socket(family, type_, proto) + tcp_sock.setblocking(False) + loop = asyncio.get_running_loop() + try: + await asyncio.wait_for( + loop.sock_connect(tcp_sock, sockaddr), timeout=10.0 + ) + except Exception: + tcp_sock.close() + self._dns_cache = None # invalidate on TCP failure + raise connection = await asyncio.wait_for( - connect(self.ws_url, **self._options), timeout=10.0 + connect(self.ws_url, sock=tcp_sock, **self._options), timeout=10.0 ) except socket.gaierror: logger.debug(f"Hostname not known (this is just for testing") @@ -780,6 +886,18 @@ async def _connect_internal(self, force): return await self.connect(force=force) logger.debug("Connection established") self.ws = connection + if self._ssl_context is not None: + try: + ssl_obj = connection.transport.get_extra_info("ssl_object") + if ssl_obj is not None and ssl_obj.session is not None: + self._ssl_context.save_session(ssl_obj.session) + logger.debug( + f"Saved TLS session " + f"(reused={ssl_obj.session_reused}, " + f"timeout={ssl_obj.session.timeout}s)" + ) + except Exception as e: + logger.debug(f"Could not save TLS session: {e}") if self._send_recv_task is None or self._send_recv_task.done(): self._send_recv_task = asyncio.get_running_loop().create_task( self._handler(self.ws) @@ -1146,6 +1264,8 @@ def __init__( _log_raw_websockets: bool = False, ws_shutdown_timer: Optional[float] = 5.0, decode_ss58: bool = False, + _ssl_context: Optional[_SessionResumingSSLContext] = None, + dns_ttl: int = 300, ): """ The asyncio-compatible version of the subtensor interface commands we use in bittensor. It is important to @@ -1166,6 +1286,10 @@ def __init__( _log_raw_websockets: whether to log raw websocket requests during RPC requests ws_shutdown_timer: how long after the last connection your websocket should close decode_ss58: Whether to decode AccountIds to SS58 or leave them in raw bytes tuples. 
def __init__(
    self,
    url: str,
    *args,
    ssl_session_ttl: int = SSL_SESSION_TTL,
    **kwargs,
):
    """
    Build the interface and, for real (non-mock) ``wss://`` endpoints, attach a
    ``_SessionResumingSSLContext`` so TLS sessions can be reused on reconnect.

    Args:
        url: the endpoint URL; only URLs starting with ``wss://`` get an SSL
            context.
        *args: forwarded unchanged to the parent constructor.
        ssl_session_ttl: upper bound in seconds on how long a saved TLS session
            may be offered for resumption.
        **kwargs: forwarded unchanged to the parent constructor; ``_mock`` is
            inspected (not consumed) to skip SSL setup for mocked connections.
    """
    ssl_context: Optional[_SessionResumingSSLContext] = None
    if url.startswith("wss://") and not kwargs.get("_mock", False):
        ssl_context = _SessionResumingSSLContext(session_ttl=ssl_session_ttl)
        # load_default_certs() also loads the Windows system certificate stores
        # and then calls set_default_verify_paths() itself; calling only
        # set_default_verify_paths() leaves the trust store empty on Windows.
        ssl_context.load_default_certs()
    super().__init__(url, *args, _ssl_context=ssl_context, **kwargs)
async def load_dns_cache(self, url: str) -> Optional[tuple[list, float]]:
    """
    Load a previously saved DNS result for ``url``.

    The lookup is best-effort: database errors, undecodable blobs and
    malformed rows are logged at debug level and reported as "nothing cached".

    Args:
        url: the endpoint URL used as the cache key. Local URLs (as decided by
            ``_check_if_local``) are never looked up.

    Returns:
        ``(addrinfos, saved_at_unix)`` where ``addrinfos`` is a non-empty list
        and ``saved_at_unix`` is the stored timestamp, or ``None`` if nothing
        usable is cached.
    """
    if _check_if_local(url):
        return None
    async with self._lock:
        if not self._db:
            _ensure_dir()
            self._db = await aiosqlite.connect(CACHE_LOCATION)
        await self._ensure_dns_table()
        try:
            cursor = await self._db.execute(
                "SELECT addrinfos, saved_at FROM dns_cache WHERE url=?", (url,)
            )
            row = await cursor.fetchone()
            await cursor.close()
        except sqlite3.Error as e:
            logger.debug(f"DNS cache load error: {e}")
            return None

    if row is None:
        return None
    blob, saved_at = row

    # NOTE(security): pickle.loads executes arbitrary code from the blob; this
    # is only acceptable because the cache file is local and written by
    # save_dns_cache. Do not point CACHE_LOCATION at untrusted data.
    try:
        addrinfos = pickle.loads(blob)
    except (
        pickle.PickleError,
        AttributeError,
        EOFError,
        ImportError,
        IndexError,
        TypeError,
        ValueError,
    ) as e:
        # Unpickling a truncated or corrupt blob raises more than PickleError
        # (see the pickle docs); none of them should abort start-up.
        logger.debug(f"DNS cache load error: {e}")
        return None

    # Callers index addrinfos[0] and do arithmetic on saved_at, so reject rows
    # that cannot support that instead of failing later.
    if (
        not isinstance(addrinfos, list)
        or not addrinfos
        or not isinstance(saved_at, (int, float))
    ):
        logger.debug(f"DNS cache load error: malformed entry for {url}")
        return None
    return addrinfos, saved_at