diff --git a/reflex/istate/manager/__init__.py b/reflex/istate/manager/__init__.py index bf4477cfd08..970c0393370 100644 --- a/reflex/istate/manager/__init__.py +++ b/reflex/istate/manager/__init__.py @@ -46,7 +46,10 @@ def create(cls, state: type[BaseState]): InvalidStateManagerModeError: If the state manager mode is invalid. """ config = get_config() - if prerequisites.parse_redis_url() is not None: + if ( + "state_manager_mode" not in config._non_default_attributes + and prerequisites.parse_redis_url() is not None + ): config.state_manager_mode = constants.StateManagerMode.REDIS if config.state_manager_mode == constants.StateManagerMode.MEMORY: from reflex.istate.manager.memory import StateManagerMemory diff --git a/reflex/istate/manager/memory.py b/reflex/istate/manager/memory.py index fab4df56305..ec898388ebd 100644 --- a/reflex/istate/manager/memory.py +++ b/reflex/istate/manager/memory.py @@ -3,11 +3,16 @@ import asyncio import contextlib import dataclasses +import time from collections.abc import AsyncIterator from typing_extensions import Unpack, override -from reflex.istate.manager import StateManager, StateModificationContext +from reflex.istate.manager import ( + StateManager, + StateModificationContext, + _default_token_expiration, +) from reflex.state import BaseState, _split_substate_key @@ -15,6 +20,9 @@ class StateManagerMemory(StateManager): """A state manager that stores states in memory.""" + # The token expiration time (s). + token_expiration: int = dataclasses.field(default_factory=_default_token_expiration) + # The mapping of client ids to states. states: dict[str, BaseState] = dataclasses.field(default_factory=dict) @@ -23,9 +31,104 @@ class StateManagerMemory(StateManager): # The dict of mutexes for each client _states_locks: dict[str, asyncio.Lock] = dataclasses.field( - default_factory=dict, init=False + default_factory=dict, + init=False, + ) + + # The latest expiration deadline for each token. + _token_expires_at: dict[str, float] = dataclasses.field( + default_factory=dict, + init=False, ) + _expiration_task: asyncio.Task | None = dataclasses.field(default=None, init=False) + + def _get_or_create_state(self, token: str) -> BaseState: + """Get an existing state or create a fresh one for a token. + + Args: + token: The normalized client token. + + Returns: + The state for the token. + """ + state = self.states.get(token) + if state is None: + state = self.states[token] = self.state(_reflex_internal_init=True) + return state + + def _track_token(self, token: str): + """Refresh the expiration deadline for an active token.""" + self._token_expires_at[token] = time.time() + self.token_expiration + self._ensure_expiration_task() + + def _purge_token(self, token: str): + """Remove a token from in-memory state bookkeeping.""" + self._token_expires_at.pop(token, None) + self.states.pop(token, None) + self._states_locks.pop(token, None) + + def _purge_expired_tokens(self) -> float | None: + """Purge expired in-memory state entries and return the next deadline. + + Returns: + The next expiration deadline among unlocked tokens, if any. + """ + now = time.time() + next_expires_at = None + token_expires_at = self._token_expires_at + state_locks = self._states_locks + + for token, expires_at in list(token_expires_at.items()): + if ( + state_lock := state_locks.get(token) + ) is not None and state_lock.locked(): + continue + if expires_at <= now: + self._purge_token(token) + continue + if next_expires_at is None or expires_at < next_expires_at: + next_expires_at = expires_at + + return next_expires_at + + async def _get_state_lock(self, token: str) -> asyncio.Lock: + """Get or create the lock for a token. + + Args: + token: The normalized client token. + + Returns: + The lock protecting the token's state. + """ + state_lock = self._states_locks.get(token) + if state_lock is None: + async with self._state_manager_lock: + state_lock = self._states_locks.get(token) + if state_lock is None: + state_lock = self._states_locks[token] = asyncio.Lock() + return state_lock + + async def _expire_states(self): + """Purge expired states until there are no unlocked deadlines left.""" + try: + while True: + if (next_expires_at := self._purge_expired_tokens()) is None: + return + await asyncio.sleep(max(0.0, next_expires_at - time.time())) + finally: + if self._expiration_task is asyncio.current_task(): + self._expiration_task = None + + def _ensure_expiration_task(self): + """Ensure the expiration background task is running.""" + if self._expiration_task is None or self._expiration_task.done(): + asyncio.get_running_loop() # Ensure we're in an event loop. + self._expiration_task = asyncio.create_task( + self._expire_states(), + name="StateManagerMemory|Expiration", + ) + @override async def get_state(self, token: str) -> BaseState: """Get the state for a token. @@ -38,9 +141,9 @@ async def get_state(self, token: str) -> BaseState: """ # Memory state manager ignores the substate suffix and always returns the top-level state. token = _split_substate_key(token)[0] - if token not in self.states: - self.states[token] = self.state(_reflex_internal_init=True) - return self.states[token] + state = self._get_or_create_state(token) + self._track_token(token) + return state @override async def set_state( @@ -58,6 +161,7 @@ async def set_state( """ token = _split_substate_key(token)[0] self.states[token] = state + self._track_token(token) @override @contextlib.asynccontextmanager @@ -75,10 +179,28 @@ async def modify_state( """ # Memory state manager ignores the substate suffix and always returns the top-level state. token = _split_substate_key(token)[0] - if token not in self._states_locks: - async with self._state_manager_lock: - if token not in self._states_locks: - self._states_locks[token] = asyncio.Lock() - - async with self._states_locks[token]: - yield await self.get_state(token) + state_lock = await self._get_state_lock(token) + + try: + async with state_lock: + state = self._get_or_create_state(token) + self._track_token(token) + try: + yield state + finally: + # Treat modify_state like a read followed by a write so the + # expiration window starts after the state is no longer busy. + self._track_token(token) + finally: + # Re-run expiration after the lock is released in case only locked + # tokens were being tracked when the worker last ran. + self._ensure_expiration_task() + + async def close(self): + """Cancel the in-memory expiration task.""" + async with self._state_manager_lock: + if self._expiration_task: + self._expiration_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await self._expiration_task + self._expiration_task = None diff --git a/tests/integration/test_memory_state_manager_expiration.py b/tests/integration/test_memory_state_manager_expiration.py new file mode 100644 index 00000000000..a6feda0f0ef --- /dev/null +++ b/tests/integration/test_memory_state_manager_expiration.py @@ -0,0 +1,157 @@ +"""Integration tests for in-memory state expiration.""" + +import time +from collections.abc import Generator + +import pytest +from selenium.webdriver.common.by import By +from selenium.webdriver.remote.webdriver import WebDriver + +from reflex.istate.manager.memory import StateManagerMemory +from reflex.testing import AppHarness + + +def MemoryExpirationApp(): + """Reflex app that exposes state expiration through a simple counter UI.""" + import reflex as rx + + class State(rx.State): + counter: int = 0 + + @rx.event + def increment(self): + self.counter += 1 + + app = rx.App() + + @app.add_page + def index(): + return rx.vstack( + rx.input( + id="token", + value=State.router.session.client_token, + is_read_only=True, + ), + rx.text(State.counter, id="counter"), + rx.button("Increment", id="increment", on_click=State.increment), + ) + + +@pytest.fixture +def memory_expiration_app( + app_harness_env: type[AppHarness], + monkeypatch: pytest.MonkeyPatch, + tmp_path_factory: pytest.TempPathFactory, +) -> Generator[AppHarness, None, None]: + """Start a memory-backed app with a short expiration window. + + Yields: + A running app harness configured to use StateManagerMemory. + """ + monkeypatch.setenv("REFLEX_STATE_MANAGER_MODE", "memory") + # Memory expiration reuses the shared token_expiration config field. + monkeypatch.setenv("REFLEX_REDIS_TOKEN_EXPIRATION", "1") + + with app_harness_env.create( + root=tmp_path_factory.mktemp("memory_expiration_app"), + app_name=f"memory_expiration_{app_harness_env.__name__.lower()}", + app_source=MemoryExpirationApp, + ) as harness: + assert isinstance(harness.state_manager, StateManagerMemory) + yield harness + + +@pytest.fixture +def driver(memory_expiration_app: AppHarness) -> Generator[WebDriver, None, None]: + """Open the memory expiration app in a browser. + + Yields: + A webdriver instance pointed at the running app. + """ + assert memory_expiration_app.app_instance is not None, "app is not running" + driver = memory_expiration_app.frontend() + try: + yield driver + finally: + driver.quit() + + +def test_memory_state_manager_expires_state_end_to_end( + memory_expiration_app: AppHarness, + driver: WebDriver, +): + """An idle in-memory state should expire and reset on the next event.""" + app_instance = memory_expiration_app.app_instance + assert app_instance is not None + + token_input = AppHarness.poll_for_or_raise_timeout( + lambda: driver.find_element(By.ID, "token") + ) + token = memory_expiration_app.poll_for_value(token_input) + assert token is not None + + counter = driver.find_element(By.ID, "counter") + increment = driver.find_element(By.ID, "increment") + app_state_manager = app_instance.state_manager + assert isinstance(app_state_manager, StateManagerMemory) + + AppHarness.expect(lambda: counter.text == "0") + + increment.click() + AppHarness.expect(lambda: counter.text == "1") + + increment.click() + AppHarness.expect(lambda: counter.text == "2") + + AppHarness.expect(lambda: token in app_state_manager.states) + AppHarness.expect(lambda: token not in app_state_manager.states, timeout=5) + + increment.click() + AppHarness.expect(lambda: counter.text == "1") + assert token_input.get_attribute("value") == token + + +def test_memory_state_manager_delays_expiration_after_use_end_to_end( + memory_expiration_app: AppHarness, + driver: WebDriver, +): + """Using a token should start a fresh expiration window from the last use.""" + app_instance = memory_expiration_app.app_instance + assert app_instance is not None + + token_input = AppHarness.poll_for_or_raise_timeout( + lambda: driver.find_element(By.ID, "token") + ) + token = memory_expiration_app.poll_for_value(token_input) + assert token is not None + + counter = driver.find_element(By.ID, "counter") + increment = driver.find_element(By.ID, "increment") + app_state_manager = app_instance.state_manager + assert isinstance(app_state_manager, StateManagerMemory) + + AppHarness.expect(lambda: counter.text == "0") + + increment.click() + AppHarness.expect(lambda: counter.text == "1") + AppHarness.expect(lambda: token in app_state_manager.states) + + time.sleep(0.6) + increment.click() + AppHarness.expect(lambda: counter.text == "2") + AppHarness.expect(lambda: token in app_state_manager.states) + + time.sleep(0.6) + increment.click() + AppHarness.expect(lambda: counter.text == "3") + AppHarness.expect(lambda: token in app_state_manager.states) + + time.sleep(0.6) + assert token in app_state_manager.states + assert counter.text == "3" + + AppHarness.expect(lambda: token not in app_state_manager.states, timeout=5) + + increment.click() + AppHarness.expect(lambda: counter.text == "1") + assert token_input.get_attribute("value") == token diff --git a/tests/units/istate/manager/test_expiration.py b/tests/units/istate/manager/test_expiration.py new file mode 100644 index 00000000000..ff5c76de458 --- /dev/null +++ b/tests/units/istate/manager/test_expiration.py @@ -0,0 +1,210 @@ +"""Tests for state manager token expiration.""" + +import asyncio +import time +from collections.abc import AsyncGenerator, Callable + +import pytest +import pytest_asyncio + +from reflex.istate.manager.memory import StateManagerMemory +from reflex.state import BaseState, _substate_key + + +class ExpiringState(BaseState): + """A test state for expiration-specific manager tests.""" + + value: int = 0 + + +async def _poll_until( + predicate: Callable[[], bool], + *, + timeout: float = 3.0, + interval: float = 0.05, +): + """Poll until a predicate succeeds. + + Args: + predicate: The predicate to evaluate. + timeout: The maximum time to wait. + interval: The delay between attempts. + """ + deadline = time.monotonic() + timeout + while time.monotonic() < deadline: + if predicate(): + return + await asyncio.sleep(interval) + assert predicate() + + +@pytest_asyncio.fixture(loop_scope="function", scope="function") +async def state_manager_memory() -> AsyncGenerator[StateManagerMemory]: + """Create a memory state manager with a short expiration. + + Yields: + The memory state manager under test. + """ + state_manager = StateManagerMemory(state=ExpiringState, token_expiration=1) + yield state_manager + await state_manager.close() + + +@pytest.mark.asyncio +async def test_memory_state_manager_evicts_expired_state( + state_manager_memory: StateManagerMemory, + token: str, +): + """Expired states should be removed from the in-memory cache and locks.""" + state_token = _substate_key(token, ExpiringState) + + async with state_manager_memory.modify_state(state_token) as state: + state.value = 42 + + assert token in state_manager_memory.states + assert token in state_manager_memory._states_locks + assert token in state_manager_memory._token_expires_at + + await _poll_until( + lambda: ( + token not in state_manager_memory.states + and token not in state_manager_memory._states_locks + and token not in state_manager_memory._token_expires_at + ) + ) + + +@pytest.mark.asyncio +async def test_memory_state_manager_get_state_refreshes_expiration( + state_manager_memory: StateManagerMemory, + token: str, +): + """Accessing a state should extend its expiration window.""" + state_token = _substate_key(token, ExpiringState) + state = await state_manager_memory.get_state(state_token) + assert isinstance(state, ExpiringState) + state.value = 7 + expires_at = state_manager_memory._token_expires_at[token] + + await asyncio.sleep(0.6) + + same_state = await state_manager_memory.get_state(state_token) + assert same_state is state + assert state_manager_memory._token_expires_at[token] > expires_at + + await asyncio.sleep(0.6) + + assert token in state_manager_memory.states + + await _poll_until(lambda: token not in state_manager_memory.states) + + +@pytest.mark.asyncio +async def test_memory_state_manager_set_state_refreshes_expiration( + state_manager_memory: StateManagerMemory, + token: str, +): + """Persisting a state should extend its expiration window.""" + state_token = _substate_key(token, ExpiringState) + state = await state_manager_memory.get_state(state_token) + assert isinstance(state, ExpiringState) + state.value = 17 + expires_at = state_manager_memory._token_expires_at[token] + + await asyncio.sleep(0.6) + + await state_manager_memory.set_state(state_token, state) + + assert state_manager_memory._token_expires_at[token] > expires_at + + await asyncio.sleep(0.6) + + assert token in state_manager_memory.states + + await _poll_until(lambda: token not in state_manager_memory.states) + + +@pytest.mark.asyncio +async def test_memory_state_manager_multiple_accesses_extend_expiration( + state_manager_memory: StateManagerMemory, + token: str, +): + """Repeated accesses should keep the state alive until it goes idle.""" + state_token = _substate_key(token, ExpiringState) + state = await state_manager_memory.get_state(state_token) + assert isinstance(state, ExpiringState) + expires_at = state_manager_memory._token_expires_at[token] + + for _ in range(3): + await asyncio.sleep(0.25) + assert await state_manager_memory.get_state(state_token) is state + assert state_manager_memory._token_expires_at[token] > expires_at + expires_at = state_manager_memory._token_expires_at[token] + + await asyncio.sleep(0.6) + + assert token in state_manager_memory.states + + await _poll_until(lambda: token not in state_manager_memory.states) + + +@pytest.mark.asyncio +async def test_memory_state_manager_returns_fresh_state_after_eviction( + state_manager_memory: StateManagerMemory, + token: str, +): + """A token should get a fresh state after the previous one expires.""" + state_token = _substate_key(token, ExpiringState) + state = await state_manager_memory.get_state(state_token) + assert isinstance(state, ExpiringState) + state.value = 99 + + await _poll_until(lambda: token not in state_manager_memory.states) + + fresh_state = await state_manager_memory.get_state(state_token) + assert isinstance(fresh_state, ExpiringState) + assert fresh_state is not state + assert fresh_state.value == 0 + + +@pytest.mark.asyncio +async def test_memory_state_manager_close_cancels_expiration_task( + state_manager_memory: StateManagerMemory, + token: str, +): + """Closing the manager should cancel the expiration task cleanly.""" + await state_manager_memory.get_state(_substate_key(token, ExpiringState)) + + expiration_task = state_manager_memory._expiration_task + assert expiration_task is not None + assert not expiration_task.done() + + await state_manager_memory.close() + + assert state_manager_memory._expiration_task is None + assert expiration_task.done() + + await state_manager_memory.close() + + +@pytest.mark.asyncio +async def test_memory_state_manager_refreshes_expiration_after_locked_access( + state_manager_memory: StateManagerMemory, + token: str, +): + """Releasing a long-held state should start a fresh expiration window.""" + state_token = _substate_key(token, ExpiringState) + + async with state_manager_memory.modify_state(state_token) as state: + state.value = 5 + expires_at = state_manager_memory._token_expires_at[token] + await asyncio.sleep(1.2) + assert token in state_manager_memory.states + + assert state_manager_memory._token_expires_at[token] > expires_at + + await asyncio.sleep(0.6) + + assert token in state_manager_memory.states + + await _poll_until(lambda: token not in state_manager_memory.states) diff --git a/tests/units/test_state.py b/tests/units/test_state.py index 0775dd000e5..d02940d24bd 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -3596,6 +3596,34 @@ def test_redis_state_manager_config_knobs_invalid_lock_warning_threshold( del sys.modules[constants.Config.MODULE] +def test_state_manager_create_respects_explicit_memory_mode_with_redis_url( + tmp_path, monkeypatch: pytest.MonkeyPatch +): + proj_root = tmp_path / "project1" + proj_root.mkdir() + + config_string = """ +import reflex as rx +config = rx.Config( + app_name="project1", +) + """ + + (proj_root / "rxconfig.py").write_text(dedent(config_string)) + monkeypatch.setenv("REFLEX_STATE_MANAGER_MODE", "memory") + monkeypatch.setenv("REFLEX_REDIS_URL", "redis://localhost:6379") + + with chdir(proj_root): + reflex.config.get_config(reload=True) + monkeypatch.setattr(prerequisites, "get_redis", mock_redis) + from reflex.state import State + + state_manager = StateManager.create(state=State) + assert isinstance(state_manager, StateManagerMemory) + + del sys.modules[constants.Config.MODULE] + + def test_auto_setters_off(tmp_path): proj_root = tmp_path / "project1" proj_root.mkdir()