diff --git a/CHANGELOG.md b/CHANGELOG.md
index 4fcca57..a1cf15f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,12 @@
 # Changelog
 
+## 1.6.1 /2025-02-03
+* RuntimeCache updates by @thewhaleking in https://github.com/opentensor/async-substrate-interface/pull/260
+* fix memory leak by @thewhaleking in https://github.com/opentensor/async-substrate-interface/pull/261
+* Avoid Race Condition on SQLite Table Creation by @thewhaleking in https://github.com/opentensor/async-substrate-interface/pull/263
+
+**Full Changelog**: https://github.com/opentensor/async-substrate-interface/compare/v1.6.0...v1.6.1
+
 ## 1.6.0 /2025-01-27
 * Fix typo by @thewhaleking in https://github.com/opentensor/async-substrate-interface/pull/258
 * Improve Disk Caching by @thewhaleking in https://github.com/opentensor/async-substrate-interface/pull/227
diff --git a/async_substrate_interface/async_substrate.py b/async_substrate_interface/async_substrate.py
index 91c4ca0..177e0e2 100644
--- a/async_substrate_interface/async_substrate.py
+++ b/async_substrate_interface/async_substrate.py
@@ -11,7 +11,6 @@
 import socket
 import ssl
 import warnings
-from contextlib import suppress
 from unittest.mock import AsyncMock
 from hashlib import blake2b
 from typing import (
@@ -40,7 +39,6 @@ from websockets.asyncio.client import connect, ClientConnection
 from websockets.exceptions import (
     ConnectionClosed,
-    WebSocketException,
 )
 from websockets.protocol import State
@@ -708,6 +706,10 @@ async def _cancel(self):
             logger.debug("Cancelling send/recv tasks")
             if self._send_recv_task is not None:
                 self._send_recv_task.cancel()
+                try:
+                    await self._send_recv_task
+                except asyncio.CancelledError:
+                    pass
         except asyncio.CancelledError:
             pass
         except Exception as e:
@@ -777,16 +779,31 @@ async def _handler(self, ws: ClientConnection) -> Union[None, Exception]:
         logger.debug("WS handler attached")
         recv_task = asyncio.create_task(self._start_receiving(ws))
         send_task = asyncio.create_task(self._start_sending(ws))
-        done, pending = await asyncio.wait(
-            [recv_task, send_task],
-            return_when=asyncio.FIRST_COMPLETED,
-        )
+        try:
+            done, pending = await asyncio.wait(
+                [recv_task, send_task],
+                return_when=asyncio.FIRST_COMPLETED,
+            )
+        except asyncio.CancelledError:
+            # Handler was cancelled; clean up child tasks
+            for task in [recv_task, send_task]:
+                if not task.done():
+                    task.cancel()
+                    try:
+                        await task
+                    except asyncio.CancelledError:
+                        pass
+            raise
         loop = asyncio.get_running_loop()
         should_reconnect = False
         is_retry = False
         for task in pending:
             task.cancel()
+            try:
+                await task
+            except asyncio.CancelledError:
+                pass
 
         for task in done:
             task_res = task.result()
@@ -887,6 +904,14 @@ async def _exit_with_timer(self):
 
     async def shutdown(self):
         logger.debug("Shutdown requested")
+        # Cancel the exit timer task if it exists
+        if self._exit_task is not None:
+            self._exit_task.cancel()
+            try:
+                await self._exit_task
+            except asyncio.CancelledError:
+                pass
+            self._exit_task = None
         try:
             await asyncio.wait_for(self._cancel(), timeout=10.0)
         except asyncio.TimeoutError:
@@ -990,8 +1015,9 @@ async def _start_sending(self, ws) -> Exception:
                 )
                 if to_send is not None:
                     to_send_ = json.loads(to_send)
-                    self._received[to_send_["id"]].set_exception(e)
-                    self._received[to_send_["id"]].cancel()
+                    if to_send_["id"] in self._received:
+                        self._received[to_send_["id"]].set_exception(e)
+                        self._received[to_send_["id"]].cancel()
                 else:
                     for i in self._received.keys():
                         self._received[i].set_exception(e)
@@ -1975,7 +2001,6 @@ async def result_handler(
                     if subscription_result is not None:
                         reached = True
-                        logger.info("REACHED!")
                         # Handler returned end result: unsubscribe from further updates
                         async with self.ws as ws:
                             await ws.unsubscribe(
diff --git a/async_substrate_interface/sync_substrate.py b/async_substrate_interface/sync_substrate.py
index 8ddd90b..5b6db72 100644
--- a/async_substrate_interface/sync_substrate.py
+++ b/async_substrate_interface/sync_substrate.py
@@ -3398,5 +3398,12 @@ def close(self):
             self.ws.close()
         except AttributeError:
             pass
+        # Clear lru_cache on instance methods to allow garbage collection
+        self.get_runtime_for_version.cache_clear()
+        self.get_parent_block_hash.cache_clear()
+        self.get_block_runtime_info.cache_clear()
+        self.get_block_runtime_version_for.cache_clear()
+        self.supports_rpc_method.cache_clear()
+        self.get_block_hash.cache_clear()
 
 encode_scale = SubstrateMixin._encode_scale
diff --git a/async_substrate_interface/types.py b/async_substrate_interface/types.py
index 8878497..842e260 100644
--- a/async_substrate_interface/types.py
+++ b/async_substrate_interface/types.py
@@ -1,11 +1,13 @@
+import bisect
 import logging
+import os
 from abc import ABC
 from collections import defaultdict, deque
 from collections.abc import Iterable
 from contextlib import suppress
 from dataclasses import dataclass
 from datetime import datetime
-from typing import Optional, Union, Any
+from typing import Optional, Union, Any, Sequence
 
 import scalecodec.types
 from bt_decode import PortableRegistry, encode as encode_by_type_string
@@ -17,9 +19,11 @@
 from .const import SS58_FORMAT
 from .utils import json
-from .utils.cache import AsyncSqliteDB
+from .utils.cache import AsyncSqliteDB, LRUCache
 
 logger = logging.getLogger("async_substrate_interface")
 
+SUBSTRATE_RUNTIME_CACHE_SIZE = int(os.getenv("SUBSTRATE_RUNTIME_CACHE_SIZE", "16"))
+SUBSTRATE_CACHE_METHOD_SIZE = int(os.getenv("SUBSTRATE_CACHE_METHOD_SIZE", "512"))
 
 
 class RuntimeCache:
@@ -41,11 +45,45 @@ class RuntimeCache:
     versions: dict[int, "Runtime"]
     last_used: Optional["Runtime"]
 
-    def __init__(self):
-        self.blocks = {}
-        self.block_hashes = {}
-        self.versions = {}
-        self.last_used = None
+    def __init__(self, known_versions: Optional[Sequence[tuple[int, int]]] = None):
+        # {block: block_hash, ...}
+        self.blocks: LRUCache = LRUCache(max_size=SUBSTRATE_CACHE_METHOD_SIZE)
+        # {block_hash: specVersion, ...}
+        self.block_hashes: LRUCache = LRUCache(max_size=SUBSTRATE_CACHE_METHOD_SIZE)
+        # {specVersion: Runtime, ...}
+        self.versions: LRUCache = LRUCache(max_size=SUBSTRATE_RUNTIME_CACHE_SIZE)
+        # [(block, specVersion), ...]
+        self.known_versions: list[tuple[int, int]] = []
+        # [block, ...] for binary search (excludes last item)
+        self._known_version_blocks: list[int] = []
+        if known_versions:
+            self.add_known_versions(known_versions)
+        self.last_used: Optional["Runtime"] = None
+
+    def add_known_versions(self, known_versions: Sequence[tuple[int, int]]):
+        """
+        Known versions are a sequence of (block, specVersion) pairs marking the blocks at which runtimes change.
+
+        E.g.
+        [
+            (561, 102),
+            (1075, 103),
+            ...,
+            (7257645, 367)
+        ]
+
+        This mapping is generally user-created or pulled from an external API, such as
+        https://api.tao.app/docs#/chain/get_runtime_versions_api_beta_chain_runtime_version_get
+
+        Preloading the known versions can significantly reduce the number of chain calls needed to determine
+        the runtime version.
+
+        Note that the last runtime in the supplied known versions will be ignored, as otherwise we would
+        have to assume that the final known version never changes.
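+
+        For illustration, with the hypothetical runtime versions above:
+
+            cache = RuntimeCache(known_versions=[(561, 102), (1075, 103), (7257645, 367)])
+            # Any block in [561, 1075) maps to specVersion 102 by binary search,
+            # so retrieve(block=800) can check the versions cache for 102 without
+            # the caller needing a chain call to resolve the runtime version.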
+ """ + known_versions = list(sorted(known_versions, key=lambda v: v[0])) + self.known_versions = known_versions + # Cache block numbers (excluding last) for O(log n) binary search lookups + self._known_version_blocks = [v[0] for v in known_versions[:-1]] def add_item( self, @@ -59,11 +97,11 @@ def add_item( """ self.last_used = runtime if block is not None and block_hash is not None: - self.blocks[block] = block_hash + self.blocks.set(block, block_hash) if block_hash is not None and runtime_version is not None: - self.block_hashes[block_hash] = runtime_version + self.block_hashes.set(block_hash, runtime_version) if runtime_version is not None: - self.versions[runtime_version] = runtime + self.versions.set(runtime_version, runtime) def retrieve( self, @@ -75,26 +113,35 @@ def retrieve( Retrieves a Runtime object from the cache, using the key of its block number, block hash, or runtime version. Retrieval happens in this order. If no Runtime is found mapped to any of your supplied keys, returns `None`. """ + # No reason to do this lookup if the runtime version is already supplied in this call + if block is not None and runtime_version is None and self._known_version_blocks: + # _known_version_blocks excludes the last item (see note in `add_known_versions`) + idx = bisect.bisect_right(self._known_version_blocks, block) - 1 + if idx >= 0: + runtime_version = self.known_versions[idx][1] + runtime = None if block is not None: if block_hash is not None: - self.blocks[block] = block_hash + self.blocks.set(block, block_hash) if runtime_version is not None: - self.block_hashes[block_hash] = runtime_version - with suppress(KeyError): - runtime = self.versions[self.block_hashes[self.blocks[block]]] + self.block_hashes.set(block_hash, runtime_version) + with suppress(AttributeError): + runtime = self.versions.get( + self.block_hashes.get(self.blocks.get(block)) + ) self.last_used = runtime return runtime if block_hash is not None: if runtime_version is not None: - self.block_hashes[block_hash] = runtime_version - with suppress(KeyError): - runtime = self.versions[self.block_hashes[block_hash]] + self.block_hashes.set(block_hash, runtime_version) + with suppress(AttributeError): + runtime = self.versions.get(self.block_hashes.get(block_hash)) self.last_used = runtime return runtime if runtime_version is not None: - with suppress(KeyError): - runtime = self.versions[runtime_version] + runtime = self.versions.get(runtime_version) + if runtime is not None: self.last_used = runtime return runtime return runtime @@ -110,16 +157,21 @@ async def load_from_disk(self, chain_endpoint: str): logger.debug("No runtime mappings in disk cache") else: logger.debug("Found runtime mappings in disk cache") - self.blocks = block_mapping - self.block_hashes = block_hash_mapping - self.versions = { - x: Runtime.deserialize(y) for x, y in runtime_version_mapping.items() - } + self.blocks.cache = block_mapping + self.block_hashes.cache = block_hash_mapping + for x, y in runtime_version_mapping.items(): + self.versions.cache[x] = Runtime.deserialize(y) async def dump_to_disk(self, chain_endpoint: str): db = AsyncSqliteDB(chain_endpoint=chain_endpoint) + blocks = self.blocks.cache + block_hashes = self.block_hashes.cache + versions = self.versions.cache await db.dump_runtime_cache( - chain_endpoint, self.blocks, self.block_hashes, self.versions + chain=chain_endpoint, + block_mapping=blocks, + block_hash_mapping=block_hashes, + version_mapping=versions, ) diff --git a/async_substrate_interface/utils/cache.py 
b/async_substrate_interface/utils/cache.py index 24c609c..431a430 100644 --- a/async_substrate_interface/utils/cache.py +++ b/async_substrate_interface/utils/cache.py @@ -1,5 +1,6 @@ import asyncio import inspect +import weakref from collections import OrderedDict import functools import logging @@ -60,6 +61,7 @@ async def _create_if_not_exists(self, chain: str, table_name: str): ); """ ) + await self._db.commit() await self._db.execute( f""" CREATE TRIGGER IF NOT EXISTS prune_rows_trigger_{table_name} AFTER INSERT ON {table_name} @@ -81,8 +83,8 @@ async def __call__(self, chain, other_self, func, args, kwargs) -> Optional[Any] if not self._db: _ensure_dir() self._db = await aiosqlite.connect(CACHE_LOCATION) - table_name = _get_table_name(func) - local_chain = await self._create_if_not_exists(chain, table_name) + table_name = _get_table_name(func) + local_chain = await self._create_if_not_exists(chain, table_name) key = pickle.dumps((args, kwargs or None)) try: cursor: aiosqlite.Cursor = await self._db.execute( @@ -111,9 +113,9 @@ async def load_runtime_cache(self, chain: str) -> tuple[dict, dict, dict]: if not self._db: _ensure_dir() self._db = await aiosqlite.connect(CACHE_LOCATION) - block_mapping = {} - block_hash_mapping = {} - version_mapping = {} + block_mapping = OrderedDict() + block_hash_mapping = OrderedDict() + version_mapping = OrderedDict() tables = { "RuntimeCache_blocks": block_mapping, "RuntimeCache_block_hashes": block_hash_mapping, @@ -419,6 +421,26 @@ async def __call__(self, *args: Any, **kwargs: Any) -> Any: self._inflight.pop(key, None) +class _WeakMethod: + """ + Weak reference to a bound method that allows the instance to be garbage collected. + Preserves the method's signature for introspection. + """ + + def __init__(self, method): + self._func = method.__func__ + self._instance_ref = weakref.ref(method.__self__) + # Store the bound method's signature (without 'self') for inspect.signature() to find. + # We capture this once at creation time to avoid holding references to the bound method. 
+        self.__signature__ = inspect.signature(method)
+
+    def __call__(self, *args, **kwargs):
+        instance = self._instance_ref()
+        if instance is None:
+            raise ReferenceError("Instance has been garbage collected")
+        return self._func(instance, *args, **kwargs)
+
+
 class _CachedFetcherMethod:
     """
     Helper class for using CachedFetcher with method caches (rather than functions)
@@ -428,18 +450,21 @@ def __init__(self, method, max_size: int, cache_key_index: int):
         self.method = method
         self.max_size = max_size
         self.cache_key_index = cache_key_index
-        self._instances = {}
+        # Use WeakKeyDictionary to avoid preventing garbage collection of instances
+        self._instances: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
 
     def __get__(self, instance, owner):
         if instance is None:
             return self
 
-        # Cache per-instance
+        # Cache per-instance (weak references allow GC when instance is no longer used)
         if instance not in self._instances:
             bound_method = self.method.__get__(instance, owner)
+            # Use weak reference wrapper to avoid preventing GC of instance
+            weak_method = _WeakMethod(bound_method)
             self._instances[instance] = CachedFetcher(
                 max_size=self.max_size,
-                method=bound_method,
+                method=weak_method,
                 cache_key_index=self.cache_key_index,
             )
         return self._instances[instance]
diff --git a/pyproject.toml b/pyproject.toml
index ee9db3b..5eb5015 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "async-substrate-interface"
-version = "1.6.0"
+version = "1.6.1"
 description = "Asyncio library for interacting with substrate. Mostly API-compatible with py-substrate-interface"
 readme = "README.md"
 license = { file = "LICENSE" }
@@ -37,6 +37,8 @@
 classifiers = [
     "Programming Language :: Python :: 3.10",
     "Programming Language :: Python :: 3.11",
    "Programming Language :: Python :: 3.12",
+    "Programming Language :: Python :: 3.13",
+    "Programming Language :: Python :: 3.14",
     "Programming Language :: Python :: 3 :: Only",
 ]
diff --git a/tests/helpers/settings.py b/tests/helpers/settings.py
index 0e9e1da..ff6b3f7 100644
--- a/tests/helpers/settings.py
+++ b/tests/helpers/settings.py
@@ -33,6 +33,6 @@
     environ.get("SUBSTRATE_AURA_NODE_URL") or "wss://acala-rpc-1.aca-api.network"
 )
 
-ARCHIVE_ENTRYPOINT = "wss://archive.chain.opentensor.ai:443"
+ARCHIVE_ENTRYPOINT = "wss://archive.sub.latent.to"
 
 LATENT_LITE_ENTRYPOINT = "wss://lite.sub.latent.to:443"
diff --git a/tests/integration_tests/test_disk_cache.py b/tests/integration_tests/test_disk_cache.py
index cdebcc6..063eca1 100644
--- a/tests/integration_tests/test_disk_cache.py
+++ b/tests/integration_tests/test_disk_cache.py
@@ -5,13 +5,15 @@
     AsyncSubstrateInterface,
 )
 from async_substrate_interface.sync_substrate import SubstrateInterface
+from tests.helpers.settings import LATENT_LITE_ENTRYPOINT
 
 
 @pytest.mark.asyncio
 async def test_disk_cache():
     print("Testing test_disk_cache")
-    entrypoint = "wss://entrypoint-finney.opentensor.ai:443"
-    async with DiskCachedAsyncSubstrateInterface(entrypoint) as disk_cached_substrate:
+    async with DiskCachedAsyncSubstrateInterface(
+        LATENT_LITE_ENTRYPOINT, ss58_format=42, chain_name="Bittensor"
+    ) as disk_cached_substrate:
         current_block = await disk_cached_substrate.get_block_number(None)
         block_hash = await disk_cached_substrate.get_block_hash(current_block)
         parent_block_hash = await disk_cached_substrate.get_parent_block_hash(
@@ -42,7 +44,9 @@ async def test_disk_cache():
     assert block_runtime_info == block_runtime_info_from_cache
     assert block_runtime_version_for == block_runtime_version_from_cache
     # Verify data integrity with non-disk cached Async Substrate Interface
-    async with AsyncSubstrateInterface(entrypoint) as non_cache_substrate:
+    async with AsyncSubstrateInterface(
+        LATENT_LITE_ENTRYPOINT, ss58_format=42, chain_name="Bittensor"
+    ) as non_cache_substrate:
         block_hash_non_cache = await non_cache_substrate.get_block_hash(current_block)
         parent_block_hash_non_cache = await non_cache_substrate.get_parent_block_hash(
             block_hash_non_cache
@@ -60,7 +64,9 @@ async def test_disk_cache():
     assert block_runtime_info == block_runtime_info_non_cache
     assert block_runtime_version_for == block_runtime_version_for_non_cache
     # Verify data integrity with sync Substrate Interface
-    with SubstrateInterface(entrypoint) as sync_substrate:
+    with SubstrateInterface(
+        LATENT_LITE_ENTRYPOINT, ss58_format=42, chain_name="Bittensor"
+    ) as sync_substrate:
         block_hash_sync = sync_substrate.get_block_hash(current_block)
         parent_block_hash_sync = sync_substrate.get_parent_block_hash(
             block_hash_non_cache
@@ -76,7 +82,9 @@ async def test_disk_cache():
     assert block_runtime_info == block_runtime_info_sync
     assert block_runtime_version_for == block_runtime_version_for_sync
     # Verify data is pulling from disk cache
-    async with DiskCachedAsyncSubstrateInterface(entrypoint) as disk_cached_substrate:
+    async with DiskCachedAsyncSubstrateInterface(
+        LATENT_LITE_ENTRYPOINT, ss58_format=42, chain_name="Bittensor"
+    ) as disk_cached_substrate:
         start = time.monotonic()
         new_block_hash = await disk_cached_substrate.get_block_hash(current_block)
         new_time = time.monotonic()
diff --git a/tests/unit_tests/asyncio_/test_substrate_interface.py b/tests/unit_tests/asyncio_/test_substrate_interface.py
index 1253e6c..721804b 100644
--- a/tests/unit_tests/asyncio_/test_substrate_interface.py
+++ b/tests/unit_tests/asyncio_/test_substrate_interface.py
@@ -1,13 +1,17 @@
 import asyncio
+import tracemalloc
 from unittest.mock import AsyncMock, MagicMock, ANY
 
 import pytest
 from websockets.exceptions import InvalidURI
 from websockets.protocol import State
 
-from async_substrate_interface.async_substrate import AsyncSubstrateInterface
+from async_substrate_interface.async_substrate import (
+    AsyncSubstrateInterface,
+    get_async_substrate_interface,
+)
 from async_substrate_interface.types import ScaleObj
-from tests.helpers.settings import ARCHIVE_ENTRYPOINT
+from tests.helpers.settings import ARCHIVE_ENTRYPOINT, LATENT_LITE_ENTRYPOINT
 
 
 @pytest.mark.asyncio
@@ -139,3 +143,35 @@ async def test_runtime_switching():
     assert one is not None
     assert two is not None
     print("test_runtime_switching succeeded")
+
+
+@pytest.mark.asyncio
+async def test_memory_leak():
+    import gc
+
+    # Stop any existing tracemalloc and start fresh
+    tracemalloc.stop()
+    tracemalloc.start()
+    two_mb = 2 * 1024 * 1024
+
+    # Warmup: populate caches before taking baseline
+    for _ in range(2):
+        subtensor = await get_async_substrate_interface(LATENT_LITE_ENTRYPOINT)
+        await subtensor.close()
+
+    baseline_snapshot = tracemalloc.take_snapshot()
+
+    for i in range(5):
+        subtensor = await get_async_substrate_interface(LATENT_LITE_ENTRYPOINT)
+        await subtensor.close()
+        gc.collect()
+
+        snapshot = tracemalloc.take_snapshot()
+        stats = snapshot.compare_to(baseline_snapshot, "lineno")
+        total_diff = sum(stat.size_diff for stat in stats)
+        current, peak = tracemalloc.get_traced_memory()
+        # Allow cumulative growth up to 2MB per iteration from baseline
+        assert total_diff < two_mb * (i + 1), (
+            f"Loop {i}: diff={total_diff / 1024:.2f} KiB, current={current / 1024:.2f} KiB, "
+            f"peak={peak / 1024:.2f} KiB"
+        )
diff --git a/tests/unit_tests/sync/test_substrate_interface.py b/tests/unit_tests/sync/test_substrate_interface.py
index 68f51b4..54a5b7d 100644
--- a/tests/unit_tests/sync/test_substrate_interface.py
+++ b/tests/unit_tests/sync/test_substrate_interface.py
@@ -1,9 +1,10 @@
+import tracemalloc
 from unittest.mock import MagicMock
 
 from async_substrate_interface.sync_substrate import SubstrateInterface
 from async_substrate_interface.types import ScaleObj
 
-from tests.helpers.settings import ARCHIVE_ENTRYPOINT
+from tests.helpers.settings import ARCHIVE_ENTRYPOINT, LATENT_LITE_ENTRYPOINT
 
 
 def test_runtime_call(monkeypatch):
@@ -90,3 +91,34 @@ def test_runtime_switching():
     assert substrate.get_extrinsics(block_number=block) is not None
     assert substrate.get_extrinsics(block_number=block - 21) is not None
     print("test_runtime_switching succeeded")
+
+
+def test_memory_leak():
+    import gc
+
+    # Stop any existing tracemalloc and start fresh
+    tracemalloc.stop()
+    tracemalloc.start()
+    two_mb = 2 * 1024 * 1024
+
+    # Warmup: populate caches before taking baseline
+    for _ in range(2):
+        subtensor = SubstrateInterface(LATENT_LITE_ENTRYPOINT)
+        subtensor.close()
+
+    baseline_snapshot = tracemalloc.take_snapshot()
+
+    for i in range(5):
+        subtensor = SubstrateInterface(LATENT_LITE_ENTRYPOINT)
+        subtensor.close()
+        gc.collect()
+
+        snapshot = tracemalloc.take_snapshot()
+        stats = snapshot.compare_to(baseline_snapshot, "lineno")
+        total_diff = sum(stat.size_diff for stat in stats)
+        current, peak = tracemalloc.get_traced_memory()
+        # Allow cumulative growth up to 2MB per iteration from baseline
+        assert total_diff < two_mb * (i + 1), (
+            f"Loop {i}: diff={total_diff / 1024:.2f} KiB, current={current / 1024:.2f} KiB, "
+            f"peak={peak / 1024:.2f} KiB"
+        )
diff --git a/tests/unit_tests/test_types.py b/tests/unit_tests/test_types.py
index f2e13b4..928d809 100644
--- a/tests/unit_tests/test_types.py
+++ b/tests/unit_tests/test_types.py
@@ -111,15 +111,15 @@ async def test_runtime_cache_from_disk():
     substrate.initialized = True
 
     # runtime cache should be completely empty
-    assert substrate.runtime_cache.block_hashes == {}
-    assert substrate.runtime_cache.blocks == {}
-    assert substrate.runtime_cache.versions == {}
+    assert len(substrate.runtime_cache.block_hashes.cache) == 0
+    assert len(substrate.runtime_cache.blocks.cache) == 0
+    assert len(substrate.runtime_cache.versions.cache) == 0
     await substrate.initialize()
 
     # after initialization, runtime cache should still be completely empty
-    assert substrate.runtime_cache.block_hashes == {}
-    assert substrate.runtime_cache.blocks == {}
-    assert substrate.runtime_cache.versions == {}
+    assert len(substrate.runtime_cache.block_hashes.cache) == 0
+    assert len(substrate.runtime_cache.blocks.cache) == 0
+    assert len(substrate.runtime_cache.versions.cache) == 0
     await substrate.close()
 
     # ensure we have created the SQLite DB during initialize()
@@ -136,7 +136,7 @@
     substrate.initialized = True
     await substrate.initialize()
 
-    assert substrate.runtime_cache.blocks == {fake_block: fake_hash}
+    assert substrate.runtime_cache.blocks.cache == {fake_block: fake_hash}
     # add an item to the cache
     substrate.runtime_cache.add_item(
         runtime=None, block_hash=new_fake_hash, block=new_fake_block
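
For context on the memory-leak fix: the _CachedFetcherMethod change works because the descriptor no longer holds strong references to instances or their bound methods. Below is a minimal standalone sketch of that pattern; PerInstanceCache and Client are illustrative names only, not part of the library, and the sketch assumes nothing beyond the standard-library weakref behavior used in the patch.

import weakref


class PerInstanceCache:
    """Descriptor caching one wrapper per instance without keeping instances alive."""

    def __init__(self, method):
        self._method = method
        # Weak keys: entries vanish when the owning instance is garbage collected.
        self._instances = weakref.WeakKeyDictionary()

    def __get__(self, instance, owner):
        if instance is None:
            return self
        if instance not in self._instances:
            # Hold only the plain function and a weak reference to the instance,
            # so the cached wrapper never keeps the instance alive.
            func = self._method
            instance_ref = weakref.ref(instance)

            def weak_bound(*args, **kwargs):
                obj = instance_ref()
                if obj is None:
                    raise ReferenceError("Instance has been garbage collected")
                return func(obj, *args, **kwargs)

            self._instances[instance] = weak_bound
        return self._instances[instance]


class Client:
    @PerInstanceCache
    def fetch(self, key):
        return f"value for {key}"


c = Client()
print(c.fetch("a"))  # "value for a"
del c  # the WeakKeyDictionary entry is dropped; nothing leaks

The design point mirrored from the patch: the previous plain dict (self._instances = {}) kept every instance, and the bound method captured in its CachedFetcher, alive for the lifetime of the descriptor; the WeakKeyDictionary plus weakref.ref combination is what breaks that retention.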