Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion reflex/istate/manager/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,10 @@ def create(cls, state: type[BaseState]):
InvalidStateManagerModeError: If the state manager mode is invalid.
"""
config = get_config()
if prerequisites.parse_redis_url() is not None:
if (
"state_manager_mode" not in config._non_default_attributes
and prerequisites.parse_redis_url() is not None
):
config.state_manager_mode = constants.StateManagerMode.REDIS
if config.state_manager_mode == constants.StateManagerMode.MEMORY:
from reflex.istate.manager.memory import StateManagerMemory
Expand Down
146 changes: 134 additions & 12 deletions reflex/istate/manager/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,26 @@
import asyncio
import contextlib
import dataclasses
import time
from collections.abc import AsyncIterator

from typing_extensions import Unpack, override

from reflex.istate.manager import StateManager, StateModificationContext
from reflex.istate.manager import (
StateManager,
StateModificationContext,
_default_token_expiration,
)
from reflex.state import BaseState, _split_substate_key


@dataclasses.dataclass
class StateManagerMemory(StateManager):
"""A state manager that stores states in memory."""

# The token expiration time (s).
token_expiration: int = dataclasses.field(default_factory=_default_token_expiration)

# The mapping of client ids to states.
states: dict[str, BaseState] = dataclasses.field(default_factory=dict)

Expand All @@ -23,9 +31,104 @@ class StateManagerMemory(StateManager):

# The dict of mutexes for each client
_states_locks: dict[str, asyncio.Lock] = dataclasses.field(
default_factory=dict, init=False
default_factory=dict,
init=False,
)

# The latest expiration deadline for each token.
_token_expires_at: dict[str, float] = dataclasses.field(
default_factory=dict,
init=False,
)

_expiration_task: asyncio.Task | None = dataclasses.field(default=None, init=False)

def _get_or_create_state(self, token: str) -> BaseState:
"""Get an existing state or create a fresh one for a token.

Args:
token: The normalized client token.

Returns:
The state for the token.
"""
state = self.states.get(token)
if state is None:
state = self.states[token] = self.state(_reflex_internal_init=True)
return state

def _track_token(self, token: str):
"""Refresh the expiration deadline for an active token."""
self._token_expires_at[token] = time.time() + self.token_expiration
self._ensure_expiration_task()

def _purge_token(self, token: str):
"""Remove a token from in-memory state bookkeeping."""
self._token_expires_at.pop(token, None)
self.states.pop(token, None)
self._states_locks.pop(token, None)

def _purge_expired_tokens(self) -> float | None:
"""Purge expired in-memory state entries and return the next deadline.

Returns:
The next expiration deadline among unlocked tokens, if any.
"""
now = time.time()
next_expires_at = None
token_expires_at = self._token_expires_at
state_locks = self._states_locks

for token, expires_at in list(token_expires_at.items()):
if (
state_lock := state_locks.get(token)
) is not None and state_lock.locked():
continue
if expires_at <= now:
self._purge_token(token)
continue
if next_expires_at is None or expires_at < next_expires_at:
next_expires_at = expires_at

return next_expires_at

async def _get_state_lock(self, token: str) -> asyncio.Lock:
"""Get or create the lock for a token.

Args:
token: The normalized client token.

Returns:
The lock protecting the token's state.
"""
state_lock = self._states_locks.get(token)
if state_lock is None:
async with self._state_manager_lock:
state_lock = self._states_locks.get(token)
if state_lock is None:
state_lock = self._states_locks[token] = asyncio.Lock()
return state_lock

async def _expire_states(self):
"""Purge expired states until there are no unlocked deadlines left."""
try:
while True:
if (next_expires_at := self._purge_expired_tokens()) is None:
return
await asyncio.sleep(max(0.0, next_expires_at - time.time()))
finally:
if self._expiration_task is asyncio.current_task():
self._expiration_task = None

def _ensure_expiration_task(self):
"""Ensure the expiration background task is running."""
if self._expiration_task is None or self._expiration_task.done():
asyncio.get_running_loop() # Ensure we're in an event loop.
self._expiration_task = asyncio.create_task(
self._expire_states(),
name="StateManagerMemory|Expiration",
)

@override
async def get_state(self, token: str) -> BaseState:
"""Get the state for a token.
Expand All @@ -38,9 +141,9 @@ async def get_state(self, token: str) -> BaseState:
"""
# Memory state manager ignores the substate suffix and always returns the top-level state.
token = _split_substate_key(token)[0]
if token not in self.states:
self.states[token] = self.state(_reflex_internal_init=True)
return self.states[token]
state = self._get_or_create_state(token)
self._track_token(token)
return state

@override
async def set_state(
Expand All @@ -58,6 +161,7 @@ async def set_state(
"""
token = _split_substate_key(token)[0]
self.states[token] = state
self._track_token(token)

@override
@contextlib.asynccontextmanager
Expand All @@ -75,10 +179,28 @@ async def modify_state(
"""
# Memory state manager ignores the substate suffix and always returns the top-level state.
token = _split_substate_key(token)[0]
if token not in self._states_locks:
async with self._state_manager_lock:
if token not in self._states_locks:
self._states_locks[token] = asyncio.Lock()

async with self._states_locks[token]:
yield await self.get_state(token)
state_lock = await self._get_state_lock(token)

try:
async with state_lock:
state = self._get_or_create_state(token)
self._track_token(token)
try:
yield state
finally:
# Treat modify_state like a read followed by a write so the
# expiration window starts after the state is no longer busy.
self._track_token(token)
finally:
# Re-run expiration after the lock is released in case only locked
# tokens were being tracked when the worker last ran.
self._ensure_expiration_task()

async def close(self):
"""Cancel the in-memory expiration task."""
async with self._state_manager_lock:
if self._expiration_task:
self._expiration_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await self._expiration_task
self._expiration_task = None
157 changes: 157 additions & 0 deletions tests/integration/test_memory_state_manager_expiration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
"""Integration tests for in-memory state expiration."""

import time
from collections.abc import Generator

import pytest
from selenium.webdriver.common.by import By
from selenium.webdriver.remote.webdriver import WebDriver

from reflex.istate.manager.memory import StateManagerMemory
from reflex.testing import AppHarness


def MemoryExpirationApp():
"""Reflex app that exposes state expiration through a simple counter UI."""
import reflex as rx

class State(rx.State):
counter: int = 0

@rx.event
def increment(self):
self.counter += 1

app = rx.App()

@app.add_page
def index():
return rx.vstack(
rx.input(
id="token",
value=State.router.session.client_token,
is_read_only=True,
),
rx.text(State.counter, id="counter"),
rx.button("Increment", id="increment", on_click=State.increment),
)


@pytest.fixture
def memory_expiration_app(
app_harness_env: type[AppHarness],
monkeypatch: pytest.MonkeyPatch,
tmp_path_factory: pytest.TempPathFactory,
) -> Generator[AppHarness, None, None]:
"""Start a memory-backed app with a short expiration window.

Yields:
A running app harness configured to use StateManagerMemory.
"""
monkeypatch.setenv("REFLEX_STATE_MANAGER_MODE", "memory")
# Memory expiration reuses the shared token_expiration config field.
monkeypatch.setenv("REFLEX_REDIS_TOKEN_EXPIRATION", "1")

with app_harness_env.create(
root=tmp_path_factory.mktemp("memory_expiration_app"),
app_name=f"memory_expiration_{app_harness_env.__name__.lower()}",
app_source=MemoryExpirationApp,
) as harness:
assert isinstance(harness.state_manager, StateManagerMemory)
yield harness


@pytest.fixture
def driver(memory_expiration_app: AppHarness) -> Generator[WebDriver, None, None]:
"""Open the memory expiration app in a browser.

Yields:
A webdriver instance pointed at the running app.
"""
assert memory_expiration_app.app_instance is not None, "app is not running"
driver = memory_expiration_app.frontend()
try:
yield driver
finally:
driver.quit()


def test_memory_state_manager_expires_state_end_to_end(
memory_expiration_app: AppHarness,
driver: WebDriver,
):
"""An idle in-memory state should expire and reset on the next event."""
app_instance = memory_expiration_app.app_instance
assert app_instance is not None

token_input = AppHarness.poll_for_or_raise_timeout(
lambda: driver.find_element(By.ID, "token")
)
token = memory_expiration_app.poll_for_value(token_input)
assert token is not None

counter = driver.find_element(By.ID, "counter")
increment = driver.find_element(By.ID, "increment")
app_state_manager = app_instance.state_manager
assert isinstance(app_state_manager, StateManagerMemory)

AppHarness.expect(lambda: counter.text == "0")

increment.click()
AppHarness.expect(lambda: counter.text == "1")

increment.click()
AppHarness.expect(lambda: counter.text == "2")

AppHarness.expect(lambda: token in app_state_manager.states)
AppHarness.expect(lambda: token not in app_state_manager.states, timeout=5)

increment.click()
AppHarness.expect(lambda: counter.text == "1")
assert token_input.get_attribute("value") == token


def test_memory_state_manager_delays_expiration_after_use_end_to_end(
memory_expiration_app: AppHarness,
driver: WebDriver,
):
"""Using a token should start a fresh expiration window from the last use."""
app_instance = memory_expiration_app.app_instance
assert app_instance is not None

token_input = AppHarness.poll_for_or_raise_timeout(
lambda: driver.find_element(By.ID, "token")
)
token = memory_expiration_app.poll_for_value(token_input)
assert token is not None

counter = driver.find_element(By.ID, "counter")
increment = driver.find_element(By.ID, "increment")
app_state_manager = app_instance.state_manager
assert isinstance(app_state_manager, StateManagerMemory)

AppHarness.expect(lambda: counter.text == "0")

increment.click()
AppHarness.expect(lambda: counter.text == "1")
AppHarness.expect(lambda: token in app_state_manager.states)

time.sleep(0.6)
increment.click()
AppHarness.expect(lambda: counter.text == "2")
AppHarness.expect(lambda: token in app_state_manager.states)

time.sleep(0.6)
increment.click()
AppHarness.expect(lambda: counter.text == "3")
AppHarness.expect(lambda: token in app_state_manager.states)

time.sleep(0.6)
assert token in app_state_manager.states
assert counter.text == "3"

AppHarness.expect(lambda: token not in app_state_manager.states, timeout=5)

increment.click()
AppHarness.expect(lambda: counter.text == "1")
assert token_input.get_attribute("value") == token
Loading
Loading