From ab71d844a6a3a21cedee22629e40c8dc2aab7456 Mon Sep 17 00:00:00 2001 From: Farhan Ali Raza Date: Fri, 20 Mar 2026 15:13:04 +0500 Subject: [PATCH 1/5] Add in-memory state expiration to StateManagerMemory Extract reusable expiration logic into StateManagerExpiration base class that tracks token access times and purges expired states using a deadline-ordered heap. Integrate it into StateManagerMemory with a background asyncio task that automatically cleans up idle client states. --- reflex/istate/manager/_expiration.py | 250 ++++++++++++++++++ reflex/istate/manager/memory.py | 70 ++++- .../test_memory_state_manager_expiration.py | 109 ++++++++ tests/units/istate/manager/test_expiration.py | 203 ++++++++++++++ 4 files changed, 623 insertions(+), 9 deletions(-) create mode 100644 reflex/istate/manager/_expiration.py create mode 100644 tests/integration/test_memory_state_manager_expiration.py create mode 100644 tests/units/istate/manager/test_expiration.py diff --git a/reflex/istate/manager/_expiration.py b/reflex/istate/manager/_expiration.py new file mode 100644 index 00000000000..2e6534f9b8c --- /dev/null +++ b/reflex/istate/manager/_expiration.py @@ -0,0 +1,250 @@ +"""Internal helpers for in-memory state expiration.""" + +import asyncio +import contextlib +import dataclasses +import heapq +import time +from typing import ClassVar + +from reflex.state import BaseState + +from . import _default_token_expiration + + +@dataclasses.dataclass +class StateManagerExpiration: + """Internal base for managers with in-memory state expiration.""" + + _locked_expiration_poll_interval: ClassVar[float] = 0.1 + _recheck_expired_locks_on_unlock: ClassVar[bool] = False + + token_expiration: int = dataclasses.field(default_factory=_default_token_expiration) + + # The mapping of client ids to states. + states: dict[str, BaseState] = dataclasses.field(default_factory=dict) + + # The dict of mutexes for each client. + _states_locks: dict[str, asyncio.Lock] = dataclasses.field( + default_factory=dict, + init=False, + ) + + # The latest expiration deadline for each token. + _token_expires_at: dict[str, float] = dataclasses.field( + default_factory=dict, + init=False, + ) + + # Last time a token was touched. + _token_last_touched: dict[str, float] = dataclasses.field( + default_factory=dict, + init=False, + ) + + # Deadline-ordered token expiration heap. + _token_expiration_heap: list[tuple[float, str]] = dataclasses.field( + default_factory=list, + init=False, + repr=False, + ) + + # Tokens whose expiration is deferred until their state lock is released. + _pending_locked_expirations: set[str] = dataclasses.field( + default_factory=set, + init=False, + repr=False, + ) + + # Wake any background expiration worker when token activity changes. + _token_activity: asyncio.Event = dataclasses.field( + default_factory=asyncio.Event, + init=False, + repr=False, + ) + + _scheduled_expiration_deadline: float | None = dataclasses.field( + default=None, + init=False, + repr=False, + ) + + def _touch_token(self, token: str): + """Record access for a token. + + Args: + token: The token that was accessed. + """ + touched_at = time.time() + expires_at = touched_at + self.token_expiration + self._token_last_touched[token] = touched_at + self._token_expires_at[token] = expires_at + self._pending_locked_expirations.discard(token) + heapq.heappush(self._token_expiration_heap, (expires_at, token)) + self._maybe_compact_expiration_heap() + if ( + self._scheduled_expiration_deadline is None + or expires_at <= self._scheduled_expiration_deadline + ): + self._token_activity.set() + + def _maybe_compact_expiration_heap(self): + """Rebuild the heap when stale deadline entries accumulate.""" + if len(self._token_expiration_heap) <= (2 * len(self._token_expires_at)) + 1: + return + self._token_expiration_heap = [ + (expires_at, token) + for token, expires_at in self._token_expires_at.items() + if token not in self._pending_locked_expirations + ] + heapq.heapify(self._token_expiration_heap) + + def _next_expiration(self) -> tuple[float, str] | None: + """Get the next valid token expiration from the heap. + + Returns: + The next expiration deadline and token, or None if there are no + active deadlines to process. + """ + while self._token_expiration_heap: + expires_at, token = self._token_expiration_heap[0] + current_expiration = self._token_expires_at.get(token) + if ( + current_expiration != expires_at + or token in self._pending_locked_expirations + ): + heapq.heappop(self._token_expiration_heap) + continue + return expires_at, token + return None + + def _purge_token(self, token: str): + """Remove a token from all in-memory expiration bookkeeping. + + Args: + token: The token to purge. + """ + self._token_last_touched.pop(token, None) + self._token_expires_at.pop(token, None) + self.states.pop(token, None) + self._states_locks.pop(token, None) + self._pending_locked_expirations.discard(token) + + def _purge_expired_tokens( + self, + now: float | None = None, + ) -> list[str]: + """Purge expired in-memory state entries. + + If a token's state lock is currently held, defer cleanup until a later pass + to avoid replacing the state while it is being modified. + + Args: + now: The time to compare against. + + Returns: + The list of purged tokens. + """ + now = time.time() if now is None else now + expired_tokens = [] + while ( + next_expiration := self._next_expiration() + ) is not None and next_expiration[0] <= now: + _expires_at, token = heapq.heappop(self._token_expiration_heap) + if ( + state_lock := self._states_locks.get(token) + ) is not None and state_lock.locked(): + self._pending_locked_expirations.add(token) + continue + self._purge_token(token) + expired_tokens.append(token) + return expired_tokens + + def _next_expiration_in( + self, + now: float | None = None, + ) -> float | None: + """Get the delay until the next expiration check should run. + + Args: + now: The time to compare against. + + Returns: + The number of seconds until the next check, or None when there are no + tracked tokens. + """ + if (next_expiration := self._next_expiration()) is None: + if ( + self._pending_locked_expirations + and not self._recheck_expired_locks_on_unlock + ): + return self._locked_expiration_poll_interval + return None + + now = time.time() if now is None else now + next_delay = max(0.0, next_expiration[0] - now) + if ( + self._pending_locked_expirations + and not self._recheck_expired_locks_on_unlock + ): + return min(next_delay, self._locked_expiration_poll_interval) + return next_delay + + def _reset_token_activity_wait(self): + """Reset the token activity event before waiting.""" + self._token_activity.clear() + + def _prepare_expiration_wait( + self, + *, + now: float | None = None, + default_timeout: float | None = None, + ) -> float | None: + """Prepare the next wait window for an expiration worker. + + Args: + now: The current time. + default_timeout: A fallback timeout when there are no in-memory token + deadlines to wait on. + + Returns: + The timeout to use for the next wait. + """ + self._reset_token_activity_wait() + now = time.time() if now is None else now + timeout = self._next_expiration_in(now=now) + if timeout is None: + timeout = default_timeout + elif default_timeout is not None: + timeout = min(timeout, default_timeout) + self._scheduled_expiration_deadline = None if timeout is None else now + timeout + return timeout + + def _notify_token_unlocked(self, token: str): + """Requeue a deferred expiration check for a token after its lock is released. + + Args: + token: The unlocked token. + """ + if token not in self._pending_locked_expirations: + return + self._pending_locked_expirations.discard(token) + if (expires_at := self._token_expires_at.get(token)) is None: + return + heapq.heappush(self._token_expiration_heap, (expires_at, token)) + self._token_activity.set() + + async def _wait_for_token_activity(self, timeout: float | None): + """Wait for token activity or a timeout. + + Args: + timeout: The maximum time to wait. When None, waits indefinitely. + """ + try: + if timeout is None: + await self._token_activity.wait() + return + with contextlib.suppress(asyncio.TimeoutError): + await asyncio.wait_for(self._token_activity.wait(), timeout=timeout) + finally: + self._scheduled_expiration_deadline = None diff --git a/reflex/istate/manager/memory.py b/reflex/istate/manager/memory.py index fab4df56305..a3d3de2f46b 100644 --- a/reflex/istate/manager/memory.py +++ b/reflex/istate/manager/memory.py @@ -3,28 +3,64 @@ import asyncio import contextlib import dataclasses +import time from collections.abc import AsyncIterator +from typing import ClassVar from typing_extensions import Unpack, override from reflex.istate.manager import StateManager, StateModificationContext +from reflex.istate.manager._expiration import StateManagerExpiration from reflex.state import BaseState, _split_substate_key +from reflex.utils import console + +_EXPIRATION_ERROR_RETRY_SECONDS = 1.0 @dataclasses.dataclass -class StateManagerMemory(StateManager): +class StateManagerMemory(StateManagerExpiration, StateManager): """A state manager that stores states in memory.""" - # The mapping of client ids to states. - states: dict[str, BaseState] = dataclasses.field(default_factory=dict) + _recheck_expired_locks_on_unlock: ClassVar[bool] = True # The mutex ensures the dict of mutexes is updated exclusively _state_manager_lock: asyncio.Lock = dataclasses.field(default=asyncio.Lock()) - # The dict of mutexes for each client - _states_locks: dict[str, asyncio.Lock] = dataclasses.field( - default_factory=dict, init=False - ) + _expiration_task: asyncio.Task | None = None + + async def _expire_states_once(self): + """Perform one expiration pass and wait for the next check.""" + try: + now = time.time() + self._purge_expired_tokens(now=now) + await self._wait_for_token_activity( + self._prepare_expiration_wait(now=now), + ) + except asyncio.CancelledError: + raise + except Exception as err: + console.error(f"Error expiring in-memory states: {err!r}") + await asyncio.sleep(_EXPIRATION_ERROR_RETRY_SECONDS) + + async def _expire_states(self): + """Long running task that removes expired states from memory. + + Raises: + asyncio.CancelledError: When the task is cancelled. + """ + while True: + await self._expire_states_once() + + async def _schedule_expiration_task(self): + """Schedule the expiration task if it is not already running.""" + if self._expiration_task is None or self._expiration_task.done(): + async with self._state_manager_lock: + if self._expiration_task is None or self._expiration_task.done(): + self._expiration_task = asyncio.create_task( + self._expire_states(), + name="StateManagerMemory|ExpirationProcessor", + ) + await asyncio.sleep(0) @override async def get_state(self, token: str) -> BaseState: @@ -38,6 +74,8 @@ async def get_state(self, token: str) -> BaseState: """ # Memory state manager ignores the substate suffix and always returns the top-level state. token = _split_substate_key(token)[0] + self._touch_token(token) + await self._schedule_expiration_task() if token not in self.states: self.states[token] = self.state(_reflex_internal_init=True) return self.states[token] @@ -57,7 +95,9 @@ async def set_state( context: The state modification context. """ token = _split_substate_key(token)[0] + self._touch_token(token) self.states[token] = state + await self._schedule_expiration_task() @override @contextlib.asynccontextmanager @@ -80,5 +120,17 @@ async def modify_state( if token not in self._states_locks: self._states_locks[token] = asyncio.Lock() - async with self._states_locks[token]: - yield await self.get_state(token) + try: + async with self._states_locks[token]: + yield await self.get_state(token) + finally: + self._notify_token_unlocked(token) + + async def close(self): + """Cancel the in-memory expiration task.""" + async with self._state_manager_lock: + if self._expiration_task: + self._expiration_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await self._expiration_task + self._expiration_task = None diff --git a/tests/integration/test_memory_state_manager_expiration.py b/tests/integration/test_memory_state_manager_expiration.py new file mode 100644 index 00000000000..a3d3b4a19af --- /dev/null +++ b/tests/integration/test_memory_state_manager_expiration.py @@ -0,0 +1,109 @@ +"""Integration tests for in-memory state expiration.""" + +from collections.abc import Generator + +import pytest +from selenium.webdriver.common.by import By +from selenium.webdriver.remote.webdriver import WebDriver + +from reflex.istate.manager.memory import StateManagerMemory +from reflex.testing import AppHarness + + +def MemoryExpirationApp(): + """Reflex app that exposes state expiration through a simple counter UI.""" + import reflex as rx + + class State(rx.State): + counter: int = 0 + + @rx.event + def increment(self): + self.counter += 1 + + app = rx.App() + + @app.add_page + def index(): + return rx.vstack( + rx.input( + id="token", + value=State.router.session.client_token, + is_read_only=True, + ), + rx.text(State.counter, id="counter"), + rx.button("Increment", id="increment", on_click=State.increment), + ) + + +@pytest.fixture +def memory_expiration_app( + app_harness_env: type[AppHarness], + monkeypatch: pytest.MonkeyPatch, + tmp_path_factory: pytest.TempPathFactory, +) -> Generator[AppHarness, None, None]: + """Start a memory-backed app with a short expiration window. + + Yields: + A running app harness configured to use StateManagerMemory. + """ + monkeypatch.setenv("REFLEX_STATE_MANAGER_MODE", "memory") + monkeypatch.setenv("REFLEX_REDIS_TOKEN_EXPIRATION", "1") + + with app_harness_env.create( + root=tmp_path_factory.mktemp("memory_expiration_app"), + app_name=f"memory_expiration_{app_harness_env.__name__.lower()}", + app_source=MemoryExpirationApp, + ) as harness: + assert isinstance(harness.state_manager, StateManagerMemory) + yield harness + + +@pytest.fixture +def driver(memory_expiration_app: AppHarness) -> Generator[WebDriver, None, None]: + """Open the memory expiration app in a browser. + + Yields: + A webdriver instance pointed at the running app. + """ + assert memory_expiration_app.app_instance is not None, "app is not running" + driver = memory_expiration_app.frontend() + try: + yield driver + finally: + driver.quit() + + +def test_memory_state_manager_expires_state_end_to_end( + memory_expiration_app: AppHarness, + driver: WebDriver, +): + """An idle in-memory state should expire and reset on the next event.""" + app_instance = memory_expiration_app.app_instance + assert app_instance is not None + + token_input = AppHarness.poll_for_or_raise_timeout( + lambda: driver.find_element(By.ID, "token") + ) + token = memory_expiration_app.poll_for_value(token_input) + assert token is not None + + counter = driver.find_element(By.ID, "counter") + increment = driver.find_element(By.ID, "increment") + app_state_manager = app_instance.state_manager + assert isinstance(app_state_manager, StateManagerMemory) + + AppHarness.expect(lambda: counter.text == "0") + + increment.click() + AppHarness.expect(lambda: counter.text == "1") + + increment.click() + AppHarness.expect(lambda: counter.text == "2") + + AppHarness.expect(lambda: token in app_state_manager.states) + AppHarness.expect(lambda: token not in app_state_manager.states, timeout=5) + + increment.click() + AppHarness.expect(lambda: counter.text == "1") + assert token_input.get_attribute("value") == token diff --git a/tests/units/istate/manager/test_expiration.py b/tests/units/istate/manager/test_expiration.py new file mode 100644 index 00000000000..7f843161142 --- /dev/null +++ b/tests/units/istate/manager/test_expiration.py @@ -0,0 +1,203 @@ +"""Tests for state manager token expiration.""" + +import asyncio +import time +from collections.abc import AsyncGenerator, Callable + +import pytest +import pytest_asyncio + +from reflex.istate.manager.memory import StateManagerMemory +from reflex.state import BaseState, _substate_key + + +class ExpiringState(BaseState): + """A test state for expiration-specific manager tests.""" + + value: int = 0 + + +async def _poll_until( + predicate: Callable[[], bool], + *, + timeout: float = 3.0, + interval: float = 0.05, +): + """Poll until a predicate succeeds. + + Args: + predicate: The predicate to evaluate. + timeout: The maximum time to wait. + interval: The delay between attempts. + """ + deadline = time.monotonic() + timeout + while time.monotonic() < deadline: + if predicate(): + return + await asyncio.sleep(interval) + assert predicate() + + +@pytest_asyncio.fixture(loop_scope="function", scope="function") +async def state_manager_memory() -> AsyncGenerator[StateManagerMemory]: + """Create a memory state manager with a short expiration. + + Yields: + The memory state manager under test. + """ + state_manager = StateManagerMemory(state=ExpiringState, token_expiration=1) + yield state_manager + await state_manager.close() + + +@pytest.mark.asyncio +async def test_memory_state_manager_evicts_expired_state( + state_manager_memory: StateManagerMemory, + token: str, +): + """Expired states should be removed from the in-memory cache and locks.""" + state_token = _substate_key(token, ExpiringState) + + async with state_manager_memory.modify_state(state_token) as state: + state.value = 42 + + assert token in state_manager_memory.states + assert token in state_manager_memory._states_locks + assert token in state_manager_memory._token_last_touched + + await _poll_until( + lambda: ( + token not in state_manager_memory.states + and token not in state_manager_memory._states_locks + and token not in state_manager_memory._token_last_touched + ) + ) + + +@pytest.mark.asyncio +async def test_memory_state_manager_get_state_refreshes_expiration( + state_manager_memory: StateManagerMemory, + token: str, +): + """Accessing a state should extend its expiration window.""" + state_token = _substate_key(token, ExpiringState) + state = await state_manager_memory.get_state(state_token) + assert isinstance(state, ExpiringState) + state.value = 7 + first_touch = state_manager_memory._token_last_touched[token] + + await asyncio.sleep(0.6) + + same_state = await state_manager_memory.get_state(state_token) + assert same_state is state + assert state_manager_memory._token_last_touched[token] > first_touch + + await asyncio.sleep(0.6) + + assert token in state_manager_memory.states + assert state_manager_memory.states[token] is state + + +@pytest.mark.asyncio +async def test_memory_state_manager_set_state_refreshes_expiration( + state_manager_memory: StateManagerMemory, + token: str, +): + """Persisting a state should extend its expiration window.""" + state_token = _substate_key(token, ExpiringState) + state = await state_manager_memory.get_state(state_token) + assert isinstance(state, ExpiringState) + state.value = 17 + first_touch = state_manager_memory._token_last_touched[token] + + await asyncio.sleep(0.6) + + await state_manager_memory.set_state(state_token, state) + + assert state_manager_memory._token_last_touched[token] > first_touch + + await asyncio.sleep(0.6) + + assert token in state_manager_memory.states + assert state_manager_memory.states[token] is state + + +@pytest.mark.asyncio +async def test_memory_state_manager_multiple_touches_do_not_evict_early( + state_manager_memory: StateManagerMemory, + token: str, +): + """Repeated touches should honor the latest expiration deadline.""" + state_token = _substate_key(token, ExpiringState) + state = await state_manager_memory.get_state(state_token) + assert isinstance(state, ExpiringState) + + for _ in range(3): + await asyncio.sleep(0.35) + assert await state_manager_memory.get_state(state_token) is state + + # The first deadlines have passed, but the latest touch should still keep the + # token alive until its own expiration window ends. + await asyncio.sleep(0.2) + + assert token in state_manager_memory.states + + await _poll_until(lambda: token not in state_manager_memory.states) + + +@pytest.mark.asyncio +async def test_memory_state_manager_returns_fresh_state_after_eviction( + state_manager_memory: StateManagerMemory, + token: str, +): + """A token should get a fresh state after the previous one expires.""" + state_token = _substate_key(token, ExpiringState) + state = await state_manager_memory.get_state(state_token) + assert isinstance(state, ExpiringState) + state.value = 99 + + await _poll_until(lambda: token not in state_manager_memory.states) + + fresh_state = await state_manager_memory.get_state(state_token) + assert isinstance(fresh_state, ExpiringState) + assert fresh_state is not state + assert fresh_state.value == 0 + + +@pytest.mark.asyncio +async def test_memory_state_manager_close_cancels_expiration_task( + state_manager_memory: StateManagerMemory, + token: str, +): + """Closing the manager should cancel the expiration task cleanly.""" + await state_manager_memory.get_state(_substate_key(token, ExpiringState)) + + expiration_task = state_manager_memory._expiration_task + assert expiration_task is not None + assert not expiration_task.done() + + await state_manager_memory.close() + + assert state_manager_memory._expiration_task is None + assert expiration_task.done() + + await state_manager_memory.close() + + +@pytest.mark.asyncio +async def test_memory_state_manager_evicts_expired_locked_state_after_unlock( + state_manager_memory: StateManagerMemory, + token: str, +): + """An expired locked state should be evicted once its lock is released.""" + state_token = _substate_key(token, ExpiringState) + + async with state_manager_memory.modify_state(state_token) as state: + state.value = 5 + await _poll_until( + lambda: token in state_manager_memory._pending_locked_expirations, + timeout=2.0, + ) + assert token in state_manager_memory.states + + await _poll_until(lambda: token not in state_manager_memory.states) From c0b8742c00ba266e00a301d98b98440c99687853 Mon Sep 17 00:00:00 2001 From: Farhan Ali Raza Date: Fri, 20 Mar 2026 15:18:09 +0500 Subject: [PATCH 2/5] Simplify StateManagerMemory expiration internals Remove dead _token_last_touched dict, replace hand-rolled task scheduling with ensure_task, move heap compaction off the hot path, and fix touch ordering in get_state/set_state. --- reflex/istate/manager/_expiration.py | 30 +++------- reflex/istate/manager/memory.py | 59 +++++++------------ tests/units/istate/manager/test_expiration.py | 12 ++-- 3 files changed, 35 insertions(+), 66 deletions(-) diff --git a/reflex/istate/manager/_expiration.py b/reflex/istate/manager/_expiration.py index 2e6534f9b8c..04a5988c277 100644 --- a/reflex/istate/manager/_expiration.py +++ b/reflex/istate/manager/_expiration.py @@ -1,16 +1,19 @@ """Internal helpers for in-memory state expiration.""" +from __future__ import annotations + import asyncio import contextlib import dataclasses import heapq import time -from typing import ClassVar - -from reflex.state import BaseState +from typing import TYPE_CHECKING, ClassVar from . import _default_token_expiration +if TYPE_CHECKING: + from reflex.state import BaseState + @dataclasses.dataclass class StateManagerExpiration: @@ -36,12 +39,6 @@ class StateManagerExpiration: init=False, ) - # Last time a token was touched. - _token_last_touched: dict[str, float] = dataclasses.field( - default_factory=dict, - init=False, - ) - # Deadline-ordered token expiration heap. _token_expiration_heap: list[tuple[float, str]] = dataclasses.field( default_factory=list, @@ -75,13 +72,10 @@ def _touch_token(self, token: str): Args: token: The token that was accessed. """ - touched_at = time.time() - expires_at = touched_at + self.token_expiration - self._token_last_touched[token] = touched_at + expires_at = time.time() + self.token_expiration self._token_expires_at[token] = expires_at self._pending_locked_expirations.discard(token) heapq.heappush(self._token_expiration_heap, (expires_at, token)) - self._maybe_compact_expiration_heap() if ( self._scheduled_expiration_deadline is None or expires_at <= self._scheduled_expiration_deadline @@ -124,7 +118,6 @@ def _purge_token(self, token: str): Args: token: The token to purge. """ - self._token_last_touched.pop(token, None) self._token_expires_at.pop(token, None) self.states.pop(token, None) self._states_locks.pop(token, None) @@ -133,7 +126,7 @@ def _purge_token(self, token: str): def _purge_expired_tokens( self, now: float | None = None, - ) -> list[str]: + ): """Purge expired in-memory state entries. If a token's state lock is currently held, defer cleanup until a later pass @@ -141,12 +134,8 @@ def _purge_expired_tokens( Args: now: The time to compare against. - - Returns: - The list of purged tokens. """ now = time.time() if now is None else now - expired_tokens = [] while ( next_expiration := self._next_expiration() ) is not None and next_expiration[0] <= now: @@ -157,8 +146,7 @@ def _purge_expired_tokens( self._pending_locked_expirations.add(token) continue self._purge_token(token) - expired_tokens.append(token) - return expired_tokens + self._maybe_compact_expiration_heap() def _next_expiration_in( self, diff --git a/reflex/istate/manager/memory.py b/reflex/istate/manager/memory.py index a3d3de2f46b..d5be3d43211 100644 --- a/reflex/istate/manager/memory.py +++ b/reflex/istate/manager/memory.py @@ -12,9 +12,7 @@ from reflex.istate.manager import StateManager, StateModificationContext from reflex.istate.manager._expiration import StateManagerExpiration from reflex.state import BaseState, _split_substate_key -from reflex.utils import console - -_EXPIRATION_ERROR_RETRY_SECONDS = 1.0 +from reflex.utils.tasks import ensure_task @dataclasses.dataclass @@ -26,41 +24,24 @@ class StateManagerMemory(StateManagerExpiration, StateManager): # The mutex ensures the dict of mutexes is updated exclusively _state_manager_lock: asyncio.Lock = dataclasses.field(default=asyncio.Lock()) - _expiration_task: asyncio.Task | None = None + _expiration_task: asyncio.Task | None = dataclasses.field(default=None, init=False) async def _expire_states_once(self): """Perform one expiration pass and wait for the next check.""" - try: - now = time.time() - self._purge_expired_tokens(now=now) - await self._wait_for_token_activity( - self._prepare_expiration_wait(now=now), - ) - except asyncio.CancelledError: - raise - except Exception as err: - console.error(f"Error expiring in-memory states: {err!r}") - await asyncio.sleep(_EXPIRATION_ERROR_RETRY_SECONDS) - - async def _expire_states(self): - """Long running task that removes expired states from memory. - - Raises: - asyncio.CancelledError: When the task is cancelled. - """ - while True: - await self._expire_states_once() - - async def _schedule_expiration_task(self): - """Schedule the expiration task if it is not already running.""" - if self._expiration_task is None or self._expiration_task.done(): - async with self._state_manager_lock: - if self._expiration_task is None or self._expiration_task.done(): - self._expiration_task = asyncio.create_task( - self._expire_states(), - name="StateManagerMemory|ExpirationProcessor", - ) - await asyncio.sleep(0) + now = time.time() + self._purge_expired_tokens(now=now) + await self._wait_for_token_activity( + self._prepare_expiration_wait(now=now), + ) + + def _ensure_expiration_task(self): + """Ensure the expiration background task is running.""" + ensure_task( + self, + "_expiration_task", + self._expire_states_once, + suppress_exceptions=[Exception], + ) @override async def get_state(self, token: str) -> BaseState: @@ -74,10 +55,10 @@ async def get_state(self, token: str) -> BaseState: """ # Memory state manager ignores the substate suffix and always returns the top-level state. token = _split_substate_key(token)[0] - self._touch_token(token) - await self._schedule_expiration_task() if token not in self.states: self.states[token] = self.state(_reflex_internal_init=True) + self._touch_token(token) + self._ensure_expiration_task() return self.states[token] @override @@ -95,9 +76,9 @@ async def set_state( context: The state modification context. """ token = _split_substate_key(token)[0] - self._touch_token(token) self.states[token] = state - await self._schedule_expiration_task() + self._touch_token(token) + self._ensure_expiration_task() @override @contextlib.asynccontextmanager diff --git a/tests/units/istate/manager/test_expiration.py b/tests/units/istate/manager/test_expiration.py index 7f843161142..c123477eb42 100644 --- a/tests/units/istate/manager/test_expiration.py +++ b/tests/units/istate/manager/test_expiration.py @@ -63,13 +63,13 @@ async def test_memory_state_manager_evicts_expired_state( assert token in state_manager_memory.states assert token in state_manager_memory._states_locks - assert token in state_manager_memory._token_last_touched + assert token in state_manager_memory._token_expires_at await _poll_until( lambda: ( token not in state_manager_memory.states and token not in state_manager_memory._states_locks - and token not in state_manager_memory._token_last_touched + and token not in state_manager_memory._token_expires_at ) ) @@ -84,13 +84,13 @@ async def test_memory_state_manager_get_state_refreshes_expiration( state = await state_manager_memory.get_state(state_token) assert isinstance(state, ExpiringState) state.value = 7 - first_touch = state_manager_memory._token_last_touched[token] + first_expires_at = state_manager_memory._token_expires_at[token] await asyncio.sleep(0.6) same_state = await state_manager_memory.get_state(state_token) assert same_state is state - assert state_manager_memory._token_last_touched[token] > first_touch + assert state_manager_memory._token_expires_at[token] > first_expires_at await asyncio.sleep(0.6) @@ -108,13 +108,13 @@ async def test_memory_state_manager_set_state_refreshes_expiration( state = await state_manager_memory.get_state(state_token) assert isinstance(state, ExpiringState) state.value = 17 - first_touch = state_manager_memory._token_last_touched[token] + first_expires_at = state_manager_memory._token_expires_at[token] await asyncio.sleep(0.6) await state_manager_memory.set_state(state_token, state) - assert state_manager_memory._token_last_touched[token] > first_touch + assert state_manager_memory._token_expires_at[token] > first_expires_at await asyncio.sleep(0.6) From 2b0e0972a27df3e9f8da2abc40d38415feb6e83d Mon Sep 17 00:00:00 2001 From: Farhan Ali Raza Date: Fri, 20 Mar 2026 15:45:46 +0500 Subject: [PATCH 3/5] Fix StateManager.create to respect explicit memory mode when Redis URL is set Previously, StateManager.create always overrode the mode to REDIS when a Redis URL was detected, ignoring an explicitly configured memory mode. Now it only auto-promotes to REDIS when state_manager_mode was not explicitly set. Adds a test verifying the explicit mode is honored. --- reflex/istate/manager/__init__.py | 5 +++- reflex/istate/manager/_expiration.py | 5 ++-- .../test_memory_state_manager_expiration.py | 1 + tests/units/test_state.py | 28 +++++++++++++++++++ 4 files changed, 36 insertions(+), 3 deletions(-) diff --git a/reflex/istate/manager/__init__.py b/reflex/istate/manager/__init__.py index bf4477cfd08..970c0393370 100644 --- a/reflex/istate/manager/__init__.py +++ b/reflex/istate/manager/__init__.py @@ -46,7 +46,10 @@ def create(cls, state: type[BaseState]): InvalidStateManagerModeError: If the state manager mode is invalid. """ config = get_config() - if prerequisites.parse_redis_url() is not None: + if ( + "state_manager_mode" not in config._non_default_attributes + and prerequisites.parse_redis_url() is not None + ): config.state_manager_mode = constants.StateManagerMode.REDIS if config.state_manager_mode == constants.StateManagerMode.MEMORY: from reflex.istate.manager.memory import StateManagerMemory diff --git a/reflex/istate/manager/_expiration.py b/reflex/istate/manager/_expiration.py index 04a5988c277..080de251b28 100644 --- a/reflex/istate/manager/_expiration.py +++ b/reflex/istate/manager/_expiration.py @@ -19,7 +19,7 @@ class StateManagerExpiration: """Internal base for managers with in-memory state expiration.""" - _locked_expiration_poll_interval: ClassVar[float] = 0.1 + _locked_expiration_poll_interval: ClassVar[float] = 0.1 # 100 ms _recheck_expired_locks_on_unlock: ClassVar[bool] = False token_expiration: int = dataclasses.field(default_factory=_default_token_expiration) @@ -72,7 +72,8 @@ def _touch_token(self, token: str): Args: token: The token that was accessed. """ - expires_at = time.time() + self.token_expiration + touched_at = time.time() + expires_at = touched_at + self.token_expiration # seconds from last touch self._token_expires_at[token] = expires_at self._pending_locked_expirations.discard(token) heapq.heappush(self._token_expiration_heap, (expires_at, token)) diff --git a/tests/integration/test_memory_state_manager_expiration.py b/tests/integration/test_memory_state_manager_expiration.py index a3d3b4a19af..965b65a5520 100644 --- a/tests/integration/test_memory_state_manager_expiration.py +++ b/tests/integration/test_memory_state_manager_expiration.py @@ -48,6 +48,7 @@ def memory_expiration_app( A running app harness configured to use StateManagerMemory. """ monkeypatch.setenv("REFLEX_STATE_MANAGER_MODE", "memory") + # Memory expiration reuses the shared token_expiration config field. monkeypatch.setenv("REFLEX_REDIS_TOKEN_EXPIRATION", "1") with app_harness_env.create( diff --git a/tests/units/test_state.py b/tests/units/test_state.py index 0775dd000e5..d02940d24bd 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -3596,6 +3596,34 @@ def test_redis_state_manager_config_knobs_invalid_lock_warning_threshold( del sys.modules[constants.Config.MODULE] +def test_state_manager_create_respects_explicit_memory_mode_with_redis_url( + tmp_path, monkeypatch: pytest.MonkeyPatch +): + proj_root = tmp_path / "project1" + proj_root.mkdir() + + config_string = """ +import reflex as rx +config = rx.Config( + app_name="project1", +) + """ + + (proj_root / "rxconfig.py").write_text(dedent(config_string)) + monkeypatch.setenv("REFLEX_STATE_MANAGER_MODE", "memory") + monkeypatch.setenv("REFLEX_REDIS_URL", "redis://localhost:6379") + + with chdir(proj_root): + reflex.config.get_config(reload=True) + monkeypatch.setattr(prerequisites, "get_redis", mock_redis) + from reflex.state import State + + state_manager = StateManager.create(state=State) + assert isinstance(state_manager, StateManagerMemory) + + del sys.modules[constants.Config.MODULE] + + def test_auto_setters_off(tmp_path): proj_root = tmp_path / "project1" proj_root.mkdir() From cda08bba3ef98c6a22317d9ce76edfc79d7dec40 Mon Sep 17 00:00:00 2001 From: Farhan Ali Raza Date: Sat, 21 Mar 2026 15:01:35 +0500 Subject: [PATCH 4/5] feat: add in-memory state expiration to StateManagerMemory Implement automatic eviction of idle client states in the memory state manager using a heap-based expiration system. States are touched on get/set and purged after token_expiration seconds of inactivity. Locked states defer eviction until their lock is released. Also fix StateManager.create to respect an explicit memory mode when a Redis URL is configured. --- reflex/istate/manager/_expiration.py | 239 ------------------ reflex/istate/manager/memory.py | 149 ++++++++--- tests/units/istate/manager/test_expiration.py | 50 ++-- 3 files changed, 137 insertions(+), 301 deletions(-) delete mode 100644 reflex/istate/manager/_expiration.py diff --git a/reflex/istate/manager/_expiration.py b/reflex/istate/manager/_expiration.py deleted file mode 100644 index 080de251b28..00000000000 --- a/reflex/istate/manager/_expiration.py +++ /dev/null @@ -1,239 +0,0 @@ -"""Internal helpers for in-memory state expiration.""" - -from __future__ import annotations - -import asyncio -import contextlib -import dataclasses -import heapq -import time -from typing import TYPE_CHECKING, ClassVar - -from . import _default_token_expiration - -if TYPE_CHECKING: - from reflex.state import BaseState - - -@dataclasses.dataclass -class StateManagerExpiration: - """Internal base for managers with in-memory state expiration.""" - - _locked_expiration_poll_interval: ClassVar[float] = 0.1 # 100 ms - _recheck_expired_locks_on_unlock: ClassVar[bool] = False - - token_expiration: int = dataclasses.field(default_factory=_default_token_expiration) - - # The mapping of client ids to states. - states: dict[str, BaseState] = dataclasses.field(default_factory=dict) - - # The dict of mutexes for each client. - _states_locks: dict[str, asyncio.Lock] = dataclasses.field( - default_factory=dict, - init=False, - ) - - # The latest expiration deadline for each token. - _token_expires_at: dict[str, float] = dataclasses.field( - default_factory=dict, - init=False, - ) - - # Deadline-ordered token expiration heap. - _token_expiration_heap: list[tuple[float, str]] = dataclasses.field( - default_factory=list, - init=False, - repr=False, - ) - - # Tokens whose expiration is deferred until their state lock is released. - _pending_locked_expirations: set[str] = dataclasses.field( - default_factory=set, - init=False, - repr=False, - ) - - # Wake any background expiration worker when token activity changes. - _token_activity: asyncio.Event = dataclasses.field( - default_factory=asyncio.Event, - init=False, - repr=False, - ) - - _scheduled_expiration_deadline: float | None = dataclasses.field( - default=None, - init=False, - repr=False, - ) - - def _touch_token(self, token: str): - """Record access for a token. - - Args: - token: The token that was accessed. - """ - touched_at = time.time() - expires_at = touched_at + self.token_expiration # seconds from last touch - self._token_expires_at[token] = expires_at - self._pending_locked_expirations.discard(token) - heapq.heappush(self._token_expiration_heap, (expires_at, token)) - if ( - self._scheduled_expiration_deadline is None - or expires_at <= self._scheduled_expiration_deadline - ): - self._token_activity.set() - - def _maybe_compact_expiration_heap(self): - """Rebuild the heap when stale deadline entries accumulate.""" - if len(self._token_expiration_heap) <= (2 * len(self._token_expires_at)) + 1: - return - self._token_expiration_heap = [ - (expires_at, token) - for token, expires_at in self._token_expires_at.items() - if token not in self._pending_locked_expirations - ] - heapq.heapify(self._token_expiration_heap) - - def _next_expiration(self) -> tuple[float, str] | None: - """Get the next valid token expiration from the heap. - - Returns: - The next expiration deadline and token, or None if there are no - active deadlines to process. - """ - while self._token_expiration_heap: - expires_at, token = self._token_expiration_heap[0] - current_expiration = self._token_expires_at.get(token) - if ( - current_expiration != expires_at - or token in self._pending_locked_expirations - ): - heapq.heappop(self._token_expiration_heap) - continue - return expires_at, token - return None - - def _purge_token(self, token: str): - """Remove a token from all in-memory expiration bookkeeping. - - Args: - token: The token to purge. - """ - self._token_expires_at.pop(token, None) - self.states.pop(token, None) - self._states_locks.pop(token, None) - self._pending_locked_expirations.discard(token) - - def _purge_expired_tokens( - self, - now: float | None = None, - ): - """Purge expired in-memory state entries. - - If a token's state lock is currently held, defer cleanup until a later pass - to avoid replacing the state while it is being modified. - - Args: - now: The time to compare against. - """ - now = time.time() if now is None else now - while ( - next_expiration := self._next_expiration() - ) is not None and next_expiration[0] <= now: - _expires_at, token = heapq.heappop(self._token_expiration_heap) - if ( - state_lock := self._states_locks.get(token) - ) is not None and state_lock.locked(): - self._pending_locked_expirations.add(token) - continue - self._purge_token(token) - self._maybe_compact_expiration_heap() - - def _next_expiration_in( - self, - now: float | None = None, - ) -> float | None: - """Get the delay until the next expiration check should run. - - Args: - now: The time to compare against. - - Returns: - The number of seconds until the next check, or None when there are no - tracked tokens. - """ - if (next_expiration := self._next_expiration()) is None: - if ( - self._pending_locked_expirations - and not self._recheck_expired_locks_on_unlock - ): - return self._locked_expiration_poll_interval - return None - - now = time.time() if now is None else now - next_delay = max(0.0, next_expiration[0] - now) - if ( - self._pending_locked_expirations - and not self._recheck_expired_locks_on_unlock - ): - return min(next_delay, self._locked_expiration_poll_interval) - return next_delay - - def _reset_token_activity_wait(self): - """Reset the token activity event before waiting.""" - self._token_activity.clear() - - def _prepare_expiration_wait( - self, - *, - now: float | None = None, - default_timeout: float | None = None, - ) -> float | None: - """Prepare the next wait window for an expiration worker. - - Args: - now: The current time. - default_timeout: A fallback timeout when there are no in-memory token - deadlines to wait on. - - Returns: - The timeout to use for the next wait. - """ - self._reset_token_activity_wait() - now = time.time() if now is None else now - timeout = self._next_expiration_in(now=now) - if timeout is None: - timeout = default_timeout - elif default_timeout is not None: - timeout = min(timeout, default_timeout) - self._scheduled_expiration_deadline = None if timeout is None else now + timeout - return timeout - - def _notify_token_unlocked(self, token: str): - """Requeue a deferred expiration check for a token after its lock is released. - - Args: - token: The unlocked token. - """ - if token not in self._pending_locked_expirations: - return - self._pending_locked_expirations.discard(token) - if (expires_at := self._token_expires_at.get(token)) is None: - return - heapq.heappush(self._token_expiration_heap, (expires_at, token)) - self._token_activity.set() - - async def _wait_for_token_activity(self, timeout: float | None): - """Wait for token activity or a timeout. - - Args: - timeout: The maximum time to wait. When None, waits indefinitely. - """ - try: - if timeout is None: - await self._token_activity.wait() - return - with contextlib.suppress(asyncio.TimeoutError): - await asyncio.wait_for(self._token_activity.wait(), timeout=timeout) - finally: - self._scheduled_expiration_deadline = None diff --git a/reflex/istate/manager/memory.py b/reflex/istate/manager/memory.py index d5be3d43211..1bb2971c925 100644 --- a/reflex/istate/manager/memory.py +++ b/reflex/istate/manager/memory.py @@ -5,43 +5,130 @@ import dataclasses import time from collections.abc import AsyncIterator -from typing import ClassVar from typing_extensions import Unpack, override -from reflex.istate.manager import StateManager, StateModificationContext -from reflex.istate.manager._expiration import StateManagerExpiration +from reflex.istate.manager import ( + StateManager, + StateModificationContext, + _default_token_expiration, +) from reflex.state import BaseState, _split_substate_key -from reflex.utils.tasks import ensure_task @dataclasses.dataclass -class StateManagerMemory(StateManagerExpiration, StateManager): +class StateManagerMemory(StateManager): """A state manager that stores states in memory.""" - _recheck_expired_locks_on_unlock: ClassVar[bool] = True + # The token expiration time (s). + token_expiration: int = dataclasses.field(default_factory=_default_token_expiration) + + # The mapping of client ids to states. + states: dict[str, BaseState] = dataclasses.field(default_factory=dict) # The mutex ensures the dict of mutexes is updated exclusively _state_manager_lock: asyncio.Lock = dataclasses.field(default=asyncio.Lock()) + # The dict of mutexes for each client + _states_locks: dict[str, asyncio.Lock] = dataclasses.field( + default_factory=dict, + init=False, + ) + + # The latest expiration deadline for each token. + _token_expires_at: dict[str, float] = dataclasses.field( + default_factory=dict, + init=False, + ) + _expiration_task: asyncio.Task | None = dataclasses.field(default=None, init=False) - async def _expire_states_once(self): - """Perform one expiration pass and wait for the next check.""" + def _get_or_create_state(self, token: str) -> BaseState: + """Get an existing state or create a fresh one for a token. + + Args: + token: The normalized client token. + + Returns: + The state for the token. + """ + state = self.states.get(token) + if state is None: + state = self.states[token] = self.state(_reflex_internal_init=True) + return state + + def _track_token(self, token: str): + """Track a token for fixed-time expiration.""" + if token not in self._token_expires_at: + self._token_expires_at[token] = time.time() + self.token_expiration + self._ensure_expiration_task() + + def _purge_token(self, token: str): + """Remove a token from in-memory state bookkeeping.""" + self._token_expires_at.pop(token, None) + self.states.pop(token, None) + self._states_locks.pop(token, None) + + def _purge_expired_tokens(self) -> float | None: + """Purge expired in-memory state entries and return the next deadline. + + Returns: + The next expiration deadline among unlocked tokens, if any. + """ now = time.time() - self._purge_expired_tokens(now=now) - await self._wait_for_token_activity( - self._prepare_expiration_wait(now=now), - ) + next_expires_at = None + token_expires_at = self._token_expires_at + state_locks = self._states_locks + + for token, expires_at in list(token_expires_at.items()): + if ( + state_lock := state_locks.get(token) + ) is not None and state_lock.locked(): + continue + if expires_at <= now: + self._purge_token(token) + continue + if next_expires_at is None or expires_at < next_expires_at: + next_expires_at = expires_at + + return next_expires_at + + async def _get_state_lock(self, token: str) -> asyncio.Lock: + """Get or create the lock for a token. + + Args: + token: The normalized client token. + + Returns: + The lock protecting the token's state. + """ + state_lock = self._states_locks.get(token) + if state_lock is None: + async with self._state_manager_lock: + state_lock = self._states_locks.get(token) + if state_lock is None: + state_lock = self._states_locks[token] = asyncio.Lock() + return state_lock + + async def _expire_states(self): + """Purge expired states until there are no unlocked deadlines left.""" + try: + while True: + if (next_expires_at := self._purge_expired_tokens()) is None: + return + await asyncio.sleep(max(0.0, next_expires_at - time.time())) + finally: + if self._expiration_task is asyncio.current_task(): + self._expiration_task = None def _ensure_expiration_task(self): """Ensure the expiration background task is running.""" - ensure_task( - self, - "_expiration_task", - self._expire_states_once, - suppress_exceptions=[Exception], - ) + if self._expiration_task is None or self._expiration_task.done(): + asyncio.get_running_loop() # Ensure we're in an event loop. + self._expiration_task = asyncio.create_task( + self._expire_states(), + name="StateManagerMemory|Expiration", + ) @override async def get_state(self, token: str) -> BaseState: @@ -55,11 +142,9 @@ async def get_state(self, token: str) -> BaseState: """ # Memory state manager ignores the substate suffix and always returns the top-level state. token = _split_substate_key(token)[0] - if token not in self.states: - self.states[token] = self.state(_reflex_internal_init=True) - self._touch_token(token) - self._ensure_expiration_task() - return self.states[token] + state = self._get_or_create_state(token) + self._track_token(token) + return state @override async def set_state( @@ -77,8 +162,7 @@ async def set_state( """ token = _split_substate_key(token)[0] self.states[token] = state - self._touch_token(token) - self._ensure_expiration_task() + self._track_token(token) @override @contextlib.asynccontextmanager @@ -96,16 +180,17 @@ async def modify_state( """ # Memory state manager ignores the substate suffix and always returns the top-level state. token = _split_substate_key(token)[0] - if token not in self._states_locks: - async with self._state_manager_lock: - if token not in self._states_locks: - self._states_locks[token] = asyncio.Lock() + state_lock = await self._get_state_lock(token) try: - async with self._states_locks[token]: - yield await self.get_state(token) + async with state_lock: + state = self._get_or_create_state(token) + self._track_token(token) + yield state finally: - self._notify_token_unlocked(token) + # Re-run expiration after the lock is released in case only locked + # tokens were being tracked when the worker last ran. + self._ensure_expiration_task() async def close(self): """Cancel the in-memory expiration task.""" diff --git a/tests/units/istate/manager/test_expiration.py b/tests/units/istate/manager/test_expiration.py index c123477eb42..ed0c66fb0e6 100644 --- a/tests/units/istate/manager/test_expiration.py +++ b/tests/units/istate/manager/test_expiration.py @@ -75,72 +75,62 @@ async def test_memory_state_manager_evicts_expired_state( @pytest.mark.asyncio -async def test_memory_state_manager_get_state_refreshes_expiration( +async def test_memory_state_manager_get_state_does_not_refresh_expiration( state_manager_memory: StateManagerMemory, token: str, ): - """Accessing a state should extend its expiration window.""" + """Accessing a state should not extend its expiration window.""" state_token = _substate_key(token, ExpiringState) state = await state_manager_memory.get_state(state_token) assert isinstance(state, ExpiringState) state.value = 7 - first_expires_at = state_manager_memory._token_expires_at[token] + expires_at = state_manager_memory._token_expires_at[token] await asyncio.sleep(0.6) same_state = await state_manager_memory.get_state(state_token) assert same_state is state - assert state_manager_memory._token_expires_at[token] > first_expires_at + assert state_manager_memory._token_expires_at[token] == expires_at - await asyncio.sleep(0.6) - - assert token in state_manager_memory.states - assert state_manager_memory.states[token] is state + await _poll_until(lambda: token not in state_manager_memory.states) @pytest.mark.asyncio -async def test_memory_state_manager_set_state_refreshes_expiration( +async def test_memory_state_manager_set_state_does_not_refresh_expiration( state_manager_memory: StateManagerMemory, token: str, ): - """Persisting a state should extend its expiration window.""" + """Persisting a state should not extend its expiration window.""" state_token = _substate_key(token, ExpiringState) state = await state_manager_memory.get_state(state_token) assert isinstance(state, ExpiringState) state.value = 17 - first_expires_at = state_manager_memory._token_expires_at[token] + expires_at = state_manager_memory._token_expires_at[token] await asyncio.sleep(0.6) await state_manager_memory.set_state(state_token, state) - assert state_manager_memory._token_expires_at[token] > first_expires_at + assert state_manager_memory._token_expires_at[token] == expires_at - await asyncio.sleep(0.6) - - assert token in state_manager_memory.states - assert state_manager_memory.states[token] is state + await _poll_until(lambda: token not in state_manager_memory.states) @pytest.mark.asyncio -async def test_memory_state_manager_multiple_touches_do_not_evict_early( +async def test_memory_state_manager_multiple_accesses_do_not_extend_expiration( state_manager_memory: StateManagerMemory, token: str, ): - """Repeated touches should honor the latest expiration deadline.""" + """Repeated accesses should still expire on the original deadline.""" state_token = _substate_key(token, ExpiringState) state = await state_manager_memory.get_state(state_token) assert isinstance(state, ExpiringState) + expires_at = state_manager_memory._token_expires_at[token] for _ in range(3): - await asyncio.sleep(0.35) + await asyncio.sleep(0.25) assert await state_manager_memory.get_state(state_token) is state - - # The first deadlines have passed, but the latest touch should still keep the - # token alive until its own expiration window ends. - await asyncio.sleep(0.2) - - assert token in state_manager_memory.states + assert state_manager_memory._token_expires_at[token] == expires_at await _poll_until(lambda: token not in state_manager_memory.states) @@ -189,15 +179,15 @@ async def test_memory_state_manager_evicts_expired_locked_state_after_unlock( state_manager_memory: StateManagerMemory, token: str, ): - """An expired locked state should be evicted once its lock is released.""" + """Releasing an expired locked state should evict it without refreshing TTL.""" state_token = _substate_key(token, ExpiringState) async with state_manager_memory.modify_state(state_token) as state: state.value = 5 - await _poll_until( - lambda: token in state_manager_memory._pending_locked_expirations, - timeout=2.0, - ) + expires_at = state_manager_memory._token_expires_at[token] + await asyncio.sleep(1.2) assert token in state_manager_memory.states + assert state_manager_memory._token_expires_at[token] == expires_at + await _poll_until(lambda: token not in state_manager_memory.states) From 98773128f1b86d52a08344864cf3b33e677c2efb Mon Sep 17 00:00:00 2001 From: Farhan Ali Raza Date: Tue, 24 Mar 2026 18:49:21 +0500 Subject: [PATCH 5/5] feat: Addded expire extension --- reflex/istate/manager/memory.py | 12 +++-- .../test_memory_state_manager_expiration.py | 47 +++++++++++++++++++ tests/units/istate/manager/test_expiration.py | 41 +++++++++++----- 3 files changed, 84 insertions(+), 16 deletions(-) diff --git a/reflex/istate/manager/memory.py b/reflex/istate/manager/memory.py index 1bb2971c925..ec898388ebd 100644 --- a/reflex/istate/manager/memory.py +++ b/reflex/istate/manager/memory.py @@ -58,9 +58,8 @@ def _get_or_create_state(self, token: str) -> BaseState: return state def _track_token(self, token: str): - """Track a token for fixed-time expiration.""" - if token not in self._token_expires_at: - self._token_expires_at[token] = time.time() + self.token_expiration + """Refresh the expiration deadline for an active token.""" + self._token_expires_at[token] = time.time() + self.token_expiration self._ensure_expiration_task() def _purge_token(self, token: str): @@ -186,7 +185,12 @@ async def modify_state( async with state_lock: state = self._get_or_create_state(token) self._track_token(token) - yield state + try: + yield state + finally: + # Treat modify_state like a read followed by a write so the + # expiration window starts after the state is no longer busy. + self._track_token(token) finally: # Re-run expiration after the lock is released in case only locked # tokens were being tracked when the worker last ran. diff --git a/tests/integration/test_memory_state_manager_expiration.py b/tests/integration/test_memory_state_manager_expiration.py index 965b65a5520..a6feda0f0ef 100644 --- a/tests/integration/test_memory_state_manager_expiration.py +++ b/tests/integration/test_memory_state_manager_expiration.py @@ -1,5 +1,6 @@ """Integration tests for in-memory state expiration.""" +import time from collections.abc import Generator import pytest @@ -108,3 +109,49 @@ def test_memory_state_manager_expires_state_end_to_end( increment.click() AppHarness.expect(lambda: counter.text == "1") assert token_input.get_attribute("value") == token + + +def test_memory_state_manager_delays_expiration_after_use_end_to_end( + memory_expiration_app: AppHarness, + driver: WebDriver, +): + """Using a token should start a fresh expiration window from the last use.""" + app_instance = memory_expiration_app.app_instance + assert app_instance is not None + + token_input = AppHarness.poll_for_or_raise_timeout( + lambda: driver.find_element(By.ID, "token") + ) + token = memory_expiration_app.poll_for_value(token_input) + assert token is not None + + counter = driver.find_element(By.ID, "counter") + increment = driver.find_element(By.ID, "increment") + app_state_manager = app_instance.state_manager + assert isinstance(app_state_manager, StateManagerMemory) + + AppHarness.expect(lambda: counter.text == "0") + + increment.click() + AppHarness.expect(lambda: counter.text == "1") + AppHarness.expect(lambda: token in app_state_manager.states) + + time.sleep(0.6) + increment.click() + AppHarness.expect(lambda: counter.text == "2") + AppHarness.expect(lambda: token in app_state_manager.states) + + time.sleep(0.6) + increment.click() + AppHarness.expect(lambda: counter.text == "3") + AppHarness.expect(lambda: token in app_state_manager.states) + + time.sleep(0.6) + assert token in app_state_manager.states + assert counter.text == "3" + + AppHarness.expect(lambda: token not in app_state_manager.states, timeout=5) + + increment.click() + AppHarness.expect(lambda: counter.text == "1") + assert token_input.get_attribute("value") == token diff --git a/tests/units/istate/manager/test_expiration.py b/tests/units/istate/manager/test_expiration.py index ed0c66fb0e6..ff5c76de458 100644 --- a/tests/units/istate/manager/test_expiration.py +++ b/tests/units/istate/manager/test_expiration.py @@ -75,11 +75,11 @@ async def test_memory_state_manager_evicts_expired_state( @pytest.mark.asyncio -async def test_memory_state_manager_get_state_does_not_refresh_expiration( +async def test_memory_state_manager_get_state_refreshes_expiration( state_manager_memory: StateManagerMemory, token: str, ): - """Accessing a state should not extend its expiration window.""" + """Accessing a state should extend its expiration window.""" state_token = _substate_key(token, ExpiringState) state = await state_manager_memory.get_state(state_token) assert isinstance(state, ExpiringState) @@ -90,17 +90,21 @@ async def test_memory_state_manager_get_state_does_not_refresh_expiration( same_state = await state_manager_memory.get_state(state_token) assert same_state is state - assert state_manager_memory._token_expires_at[token] == expires_at + assert state_manager_memory._token_expires_at[token] > expires_at + + await asyncio.sleep(0.6) + + assert token in state_manager_memory.states await _poll_until(lambda: token not in state_manager_memory.states) @pytest.mark.asyncio -async def test_memory_state_manager_set_state_does_not_refresh_expiration( +async def test_memory_state_manager_set_state_refreshes_expiration( state_manager_memory: StateManagerMemory, token: str, ): - """Persisting a state should not extend its expiration window.""" + """Persisting a state should extend its expiration window.""" state_token = _substate_key(token, ExpiringState) state = await state_manager_memory.get_state(state_token) assert isinstance(state, ExpiringState) @@ -111,17 +115,21 @@ async def test_memory_state_manager_set_state_does_not_refresh_expiration( await state_manager_memory.set_state(state_token, state) - assert state_manager_memory._token_expires_at[token] == expires_at + assert state_manager_memory._token_expires_at[token] > expires_at + + await asyncio.sleep(0.6) + + assert token in state_manager_memory.states await _poll_until(lambda: token not in state_manager_memory.states) @pytest.mark.asyncio -async def test_memory_state_manager_multiple_accesses_do_not_extend_expiration( +async def test_memory_state_manager_multiple_accesses_extend_expiration( state_manager_memory: StateManagerMemory, token: str, ): - """Repeated accesses should still expire on the original deadline.""" + """Repeated accesses should keep the state alive until it goes idle.""" state_token = _substate_key(token, ExpiringState) state = await state_manager_memory.get_state(state_token) assert isinstance(state, ExpiringState) @@ -130,7 +138,12 @@ async def test_memory_state_manager_multiple_accesses_do_not_extend_expiration( for _ in range(3): await asyncio.sleep(0.25) assert await state_manager_memory.get_state(state_token) is state - assert state_manager_memory._token_expires_at[token] == expires_at + assert state_manager_memory._token_expires_at[token] > expires_at + expires_at = state_manager_memory._token_expires_at[token] + + await asyncio.sleep(0.6) + + assert token in state_manager_memory.states await _poll_until(lambda: token not in state_manager_memory.states) @@ -175,11 +188,11 @@ async def test_memory_state_manager_close_cancels_expiration_task( @pytest.mark.asyncio -async def test_memory_state_manager_evicts_expired_locked_state_after_unlock( +async def test_memory_state_manager_refreshes_expiration_after_locked_access( state_manager_memory: StateManagerMemory, token: str, ): - """Releasing an expired locked state should evict it without refreshing TTL.""" + """Releasing a long-held state should start a fresh expiration window.""" state_token = _substate_key(token, ExpiringState) async with state_manager_memory.modify_state(state_token) as state: @@ -188,6 +201,10 @@ async def test_memory_state_manager_evicts_expired_locked_state_after_unlock( await asyncio.sleep(1.2) assert token in state_manager_memory.states - assert state_manager_memory._token_expires_at[token] == expires_at + assert state_manager_memory._token_expires_at[token] > expires_at + + await asyncio.sleep(0.6) + + assert token in state_manager_memory.states await _poll_until(lambda: token not in state_manager_memory.states)