From 54a9b45cc86bdcd48288ae164621adaefe56e501 Mon Sep 17 00:00:00 2001 From: Farhan Ali Raza Date: Thu, 19 Mar 2026 16:51:23 +0500 Subject: [PATCH 01/10] feat: add experimental memo decorator for JS-level component and function memoization Introduce rx.experimental.memo (rx._x.memo) that allows memoizing components and plain functions at the JavaScript level. Supports component memos with typed props (including children and rest props via RestProp), and function memos that emit raw JS. Updates the compiler pipeline to handle both memo kinds alongside existing CustomComponent memos, and refactors signature rendering to use DestructuredArg. --- pyi_hashes.json | 5 +- reflex/__init__.py | 2 +- reflex/app.py | 6 +- reflex/compiler/compiler.py | 40 +- reflex/compiler/templates.py | 13 +- reflex/compiler/utils.py | 108 ++- reflex/experimental/__init__.py | 2 + reflex/experimental/memo.py | 796 ++++++++++++++++++++ reflex/vars/__init__.py | 3 +- reflex/vars/object.py | 4 + tests/integration/test_experimental_memo.py | 150 ++++ tests/units/experimental/__init__.py | 0 tests/units/experimental/test_memo.py | 348 +++++++++ 13 files changed, 1463 insertions(+), 14 deletions(-) create mode 100644 reflex/experimental/memo.py create mode 100644 tests/integration/test_experimental_memo.py create mode 100644 tests/units/experimental/__init__.py create mode 100644 tests/units/experimental/test_memo.py diff --git a/pyi_hashes.json b/pyi_hashes.json index ec9bbff8850..ff2fcd02e92 100644 --- a/pyi_hashes.json +++ b/pyi_hashes.json @@ -1,5 +1,5 @@ { - "reflex/__init__.pyi": "0a3ae880e256b9fd3b960e12a2cb51a7", + "reflex/__init__.pyi": "224964f24351c79614ab9a4ae47560a0", "reflex/components/__init__.pyi": "ac05995852baa81062ba3d18fbc489fb", "reflex/components/base/__init__.pyi": "16e47bf19e0d62835a605baa3d039c5a", "reflex/components/base/app_wrap.pyi": "22e94feaa9fe675bcae51c412f5b67f1", @@ -118,5 +118,6 @@ "reflex/components/recharts/general.pyi": "d87ff9b85b2a204be01753690df4fb11", "reflex/components/recharts/polar.pyi": "b8b1a3e996e066facdf4f8c9eb363137", "reflex/components/recharts/recharts.pyi": "d5c9fc57a03b419748f0408c23319eee", - "reflex/components/sonner/toast.pyi": "3c27bad1aaeb5183eaa6a41e77e8d7f0" + "reflex/components/sonner/toast.pyi": "3c27bad1aaeb5183eaa6a41e77e8d7f0", + "reflex/experimental/memo.pyi": "a1c5c4682fc4dadbd82a0a5e8fd4bd32" } diff --git a/reflex/__init__.py b/reflex/__init__.py index 066df110f02..29fc5983971 100644 --- a/reflex/__init__.py +++ b/reflex/__init__.py @@ -342,7 +342,7 @@ "utils.imports": ["ImportDict", "ImportVar"], "utils.misc": ["run_in_thread"], "utils.serializers": ["serializer"], - "vars": ["Var", "field", "Field"], + "vars": ["Var", "field", "Field", "RestProp"], } _SUBMODULES: set[str] = { diff --git a/reflex/app.py b/reflex/app.py index 54682543a7d..cbb77cff253 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -83,6 +83,7 @@ get_hydrate_event, noop, ) +from reflex.experimental.memo import EXPERIMENTAL_MEMOS from reflex.istate.proxy import StateProxy from reflex.page import DECORATED_PAGES from reflex.route import ( @@ -1319,7 +1320,10 @@ def memoized_toast_provider(): memo_components_output, memo_components_result, memo_components_imports, - ) = compiler.compile_memo_components(dict.fromkeys(CUSTOM_COMPONENTS.values())) + ) = compiler.compile_memo_components( + dict.fromkeys(CUSTOM_COMPONENTS.values()), + tuple(EXPERIMENTAL_MEMOS.values()), + ) compile_results.append((memo_components_output, memo_components_result)) all_imports.update(memo_components_imports) progress.advance(task) diff --git a/reflex/compiler/compiler.py b/reflex/compiler/compiler.py index 0c8ee62c1a1..4a9c55bff6e 100644 --- a/reflex/compiler/compiler.py +++ b/reflex/compiler/compiler.py @@ -22,6 +22,11 @@ from reflex.constants.compiler import PageNames, ResetStylesheet from reflex.constants.state import FIELD_MARKER from reflex.environment import environment +from reflex.experimental.memo import ( + ExperimentalMemoComponentDefinition, + ExperimentalMemoDefinition, + ExperimentalMemoFunctionDefinition, +) from reflex.state import BaseState from reflex.style import SYSTEM_COLOR_MODE from reflex.utils import console, path_ops @@ -339,20 +344,20 @@ def _compile_component(component: Component | StatefulComponent) -> str: def _compile_memo_components( components: Iterable[CustomComponent], + experimental_memos: Iterable[ExperimentalMemoDefinition] = (), ) -> tuple[str, dict[str, list[ImportVar]]]: """Compile the components. Args: components: The components to compile. + experimental_memos: The experimental memos to compile. Returns: The compiled components. """ - imports = { - "react": [ImportVar(tag="memo")], - f"$/{constants.Dirs.STATE_PATH}": [ImportVar(tag="isTrue")], - } + imports: dict[str, list[ImportVar]] = {} component_renders = [] + function_renders = [] # Compile each component. for component in components: @@ -360,7 +365,27 @@ def _compile_memo_components( component_renders.append(component_render) imports = utils.merge_imports(imports, component_imports) - _apply_common_imports(imports) + for memo in experimental_memos: + if isinstance(memo, ExperimentalMemoComponentDefinition): + memo_render, memo_imports = utils.compile_experimental_component_memo(memo) + component_renders.append(memo_render) + imports = utils.merge_imports(imports, memo_imports) + elif isinstance(memo, ExperimentalMemoFunctionDefinition): + memo_render, memo_imports = utils.compile_experimental_function_memo(memo) + function_renders.append(memo_render) + imports = utils.merge_imports(imports, memo_imports) + + if component_renders: + imports = utils.merge_imports( + { + "react": [ImportVar(tag="memo")], + f"$/{constants.Dirs.STATE_PATH}": [ImportVar(tag="isTrue")], + }, + imports, + ) + + if component_renders: + _apply_common_imports(imports) dynamic_imports = { comp_import: None @@ -380,6 +405,7 @@ def _compile_memo_components( templates.memo_components_template( imports=utils.compile_imports(imports), components=component_renders, + functions=function_renders, dynamic_imports=sorted(dynamic_imports), custom_codes=custom_codes, ), @@ -568,11 +594,13 @@ def compile_page(path: str, component: BaseComponent) -> tuple[str, str]: def compile_memo_components( components: Iterable[CustomComponent], + experimental_memos: Iterable[ExperimentalMemoDefinition] = (), ) -> tuple[str, str, dict[str, list[ImportVar]]]: """Compile the custom components. Args: components: The custom components to compile. + experimental_memos: The experimental memos to compile. Returns: The path and code of the compiled components. @@ -581,7 +609,7 @@ def compile_memo_components( output_path = utils.get_components_path() # Compile the components. - code, imports = _compile_memo_components(components) + code, imports = _compile_memo_components(components, experimental_memos) return output_path, code, imports diff --git a/reflex/compiler/templates.py b/reflex/compiler/templates.py index a8a7dbe4ec2..1bf6c8ce061 100644 --- a/reflex/compiler/templates.py +++ b/reflex/compiler/templates.py @@ -8,7 +8,6 @@ from reflex import constants from reflex.constants import Hooks -from reflex.constants.state import CAMEL_CASE_MEMO_MARKER from reflex.utils.format import format_state_name, json_dumps from reflex.vars.base import VarData @@ -661,6 +660,7 @@ def stateful_components_template(imports: list[_ImportDict], memoized_code: str) def memo_components_template( imports: list[_ImportDict], components: list[dict[str, Any]], + functions: list[dict[str, Any]], dynamic_imports: Iterable[str], custom_codes: Iterable[str], ) -> str: @@ -669,6 +669,7 @@ def memo_components_template( Args: imports: List of import statements. components: List of component definitions. + functions: List of function definitions. dynamic_imports: List of dynamic import statements. custom_codes: List of custom code snippets. @@ -682,7 +683,7 @@ def memo_components_template( components_code = "" for component in components: components_code += f""" -export const {component["name"]} = memo(({{ {",".join([f"{prop}:{prop}{CAMEL_CASE_MEMO_MARKER}" for prop in component.get("props", [])])} }}) => {{ +export const {component["name"]} = memo(({component["signature"]}) => {{ {_render_hooks(component.get("hooks", {}))} return( {_RenderUtils.render(component["render"])} @@ -690,6 +691,12 @@ def memo_components_template( }}); """ + functions_code = "" + for function in functions: + functions_code += ( + f"\nexport const {function['name']} = {function['function']};\n" + ) + return f""" {imports_str} @@ -697,6 +704,8 @@ def memo_components_template( {custom_code_str} +{functions_code} + {components_code}""" diff --git a/reflex/compiler/utils.py b/reflex/compiler/utils.py index 49eef26b924..f555e2807ae 100644 --- a/reflex/compiler/utils.py +++ b/reflex/compiler/utils.py @@ -20,7 +20,11 @@ from reflex.components.el.elements.metadata import Head, Link, Meta, Title from reflex.components.el.elements.other import Html from reflex.components.el.elements.sectioning import Body -from reflex.constants.state import FIELD_MARKER +from reflex.constants.state import CAMEL_CASE_MEMO_MARKER, FIELD_MARKER +from reflex.experimental.memo import ( + ExperimentalMemoComponentDefinition, + ExperimentalMemoFunctionDefinition, +) from reflex.istate.storage import Cookie, LocalStorage, SessionStorage from reflex.state import BaseState, _resolve_delta from reflex.style import Style @@ -28,6 +32,7 @@ from reflex.utils.imports import ImportVar, ParsedImportDict from reflex.utils.prerequisites import get_web_dir from reflex.vars.base import Field, Var, VarData +from reflex.vars.function import DestructuredArg # To re-export this function. merge_imports = imports.merge_imports @@ -344,6 +349,9 @@ def compile_custom_component( { "name": component.tag, "props": props, + "signature": DestructuredArg( + fields=tuple(f"{prop}:{prop}{CAMEL_CASE_MEMO_MARKER}" for prop in props) + ).to_javascript(), "render": render.render(), "hooks": render._get_all_hooks(), "custom_code": render._get_all_custom_code(), @@ -353,6 +361,104 @@ def compile_custom_component( ) +def _apply_component_style_for_compile(component: Component) -> Component: + """Apply the app style to a compiled component tree. + + Args: + component: The component tree. + + Returns: + The styled component tree. + """ + try: + from reflex.utils.prerequisites import get_and_validate_app + + style = get_and_validate_app().app.style + except Exception: + style = {} + + component._add_style_recursive(style) + return component + + +def compile_experimental_component_memo( + definition: ExperimentalMemoComponentDefinition, +) -> tuple[dict, ParsedImportDict]: + """Compile an experimental memo component. + + Args: + definition: The component memo definition. + + Returns: + A tuple of the compiled component definition and its imports. + """ + render = _apply_component_style_for_compile(definition.component) + + imports: ParsedImportDict = { + lib: fields + for lib, fields in render._get_all_imports().items() + if lib != f"$/{constants.Dirs.COMPONENTS_PATH}" + } + + imports.setdefault("@emotion/react", []).append(ImportVar("jsx")) + + signature_fields = [ + f"{param.js_prop_name}:{param.placeholder_name}" + for param in definition.params + if not param.is_children and not param.is_rest + ] + + if any(param.is_children for param in definition.params): + signature_fields.insert(0, "children") + + rest_param = next((param for param in definition.params if param.is_rest), None) + + return ( + { + "kind": "component", + "name": definition.export_name, + "signature": DestructuredArg( + fields=tuple(signature_fields), + rest=rest_param.placeholder_name if rest_param is not None else None, + ).to_javascript(), + "render": render.render(), + "hooks": render._get_all_hooks(), + "custom_code": render._get_all_custom_code(), + "dynamic_imports": render._get_all_dynamic_imports(), + }, + imports, + ) + + +def compile_experimental_function_memo( + definition: ExperimentalMemoFunctionDefinition, +) -> tuple[dict, ParsedImportDict]: + """Compile an experimental memo function. + + Args: + definition: The function memo definition. + + Returns: + A tuple of the compiled function definition and its imports. + """ + imports: ParsedImportDict = {} + if var_data := definition.function._get_all_var_data(): + imports = { + lib: list(fields) + for lib, fields in dict(var_data.imports).items() + if lib != f"$/{constants.Dirs.COMPONENTS_PATH}" + } + + return ( + { + "kind": "function", + "name": definition.python_name, + "function": str(definition.function), + }, + imports, + ) + + def create_document_root( head_components: Sequence[Component] | None = None, html_lang: str | None = None, diff --git a/reflex/experimental/__init__.py b/reflex/experimental/__init__.py index 2734da19112..8ae5f3ee8a6 100644 --- a/reflex/experimental/__init__.py +++ b/reflex/experimental/__init__.py @@ -8,6 +8,7 @@ from . import hooks as hooks from .client_state import ClientStateVar as ClientStateVar +from .memo import memo as memo class ExperimentalNamespace(SimpleNamespace): @@ -58,4 +59,5 @@ def register_component_warning(component_name: str): client_state=ClientStateVar.create, hooks=hooks, code_block=code_block, + memo=memo, ) diff --git a/reflex/experimental/memo.py b/reflex/experimental/memo.py new file mode 100644 index 00000000000..17b4af8bb2f --- /dev/null +++ b/reflex/experimental/memo.py @@ -0,0 +1,796 @@ +"""Experimental memo support for vars and components.""" + +from __future__ import annotations + +import dataclasses +import inspect +from collections.abc import Callable +from functools import wraps +from typing import Any, get_args, get_origin, get_type_hints + +from reflex import constants +from reflex.components.base.bare import Bare +from reflex.components.base.fragment import Fragment +from reflex.components.component import Component +from reflex.components.dynamic import bundled_libraries +from reflex.constants.compiler import SpecialAttributes +from reflex.constants.state import CAMEL_CASE_MEMO_MARKER +from reflex.utils import format +from reflex.utils import types as type_utils +from reflex.utils.imports import ImportVar +from reflex.vars import VarData +from reflex.vars.base import LiteralVar, Var +from reflex.vars.function import ( + ArgsFunctionOperation, + DestructuredArg, + FunctionStringVar, + FunctionVar, + ReflexCallable, +) +from reflex.vars.object import RestProp + + +@dataclasses.dataclass(frozen=True, slots=True) +class MemoParam: + """Metadata about a memo parameter.""" + + name: str + annotation: Any + kind: inspect._ParameterKind + default: Any = inspect.Parameter.empty + js_prop_name: str | None = None + placeholder_name: str = "" + is_children: bool = False + is_rest: bool = False + + +@dataclasses.dataclass(frozen=True, slots=True) +class ExperimentalMemoDefinition: + """Base metadata for an experimental memo.""" + + fn: Callable[..., Any] + python_name: str + params: tuple[MemoParam, ...] + + +@dataclasses.dataclass(frozen=True, slots=True) +class ExperimentalMemoFunctionDefinition(ExperimentalMemoDefinition): + """A memo that compiles to a JavaScript function.""" + + function: ArgsFunctionOperation + imported_var: FunctionVar + + +@dataclasses.dataclass(frozen=True, slots=True) +class ExperimentalMemoComponentDefinition(ExperimentalMemoDefinition): + """A memo that compiles to a React component.""" + + export_name: str + component: Component + + +class ExperimentalMemoComponent(Component): + """A rendered instance of an experimental memo component.""" + + library = f"$/{constants.Dirs.COMPONENTS_PATH}" + + def _post_init(self, **kwargs): + """Initialize the experimental memo component. + + Args: + **kwargs: The kwargs to pass to the component. + """ + definition = kwargs.pop("memo_definition") + + explicit_props = { + param.name + for param in definition.params + if not param.is_children and not param.is_rest + } + component_fields = self.get_fields() + + declared_props = { + key: kwargs.pop(key) for key in list(kwargs) if key in explicit_props + } + + rest_props = {} + if _get_rest_param(definition.params) is not None: + rest_props = { + key: kwargs.pop(key) + for key in list(kwargs) + if key not in component_fields and not SpecialAttributes.is_special(key) + } + + super()._post_init(**kwargs) + + self.tag = definition.export_name + + props: dict[str, Any] = {} + for key, value in {**declared_props, **rest_props}.items(): + camel_cased_key = format.to_camel_case(key) + literal_value = LiteralVar.create(value) + props[camel_cased_key] = literal_value + setattr(self, camel_cased_key, literal_value) + + prop_names = dict.fromkeys(props) + object.__setattr__(self, "get_props", lambda: prop_names) + + +EXPERIMENTAL_MEMOS: dict[str, ExperimentalMemoDefinition] = {} + + +def _annotation_inner_type(annotation: Any) -> Any: + """Unwrap a Var-like annotation to its inner type. + + Args: + annotation: The annotation to unwrap. + + Returns: + The inner type for the annotation. + """ + if _is_rest_annotation(annotation): + return dict[str, Any] + + origin = get_origin(annotation) or annotation + if type_utils.safe_issubclass(origin, Var) and (args := get_args(annotation)): + return args[0] + return Any + + +def _is_rest_annotation(annotation: Any) -> bool: + """Check whether an annotation is a RestProp. + + Args: + annotation: The annotation to check. + + Returns: + Whether the annotation is a RestProp. + """ + origin = get_origin(annotation) or annotation + return isinstance(origin, type) and issubclass(origin, RestProp) + + +def _is_var_annotation(annotation: Any) -> bool: + """Check whether an annotation is a Var-like annotation. + + Args: + annotation: The annotation to check. + + Returns: + Whether the annotation is Var-like. + """ + origin = get_origin(annotation) or annotation + return isinstance(origin, type) and issubclass(origin, Var) + + +def _is_component_annotation(annotation: Any) -> bool: + """Check whether an annotation is component-like. + + Args: + annotation: The annotation to check. + + Returns: + Whether the annotation resolves to Component. + """ + origin = get_origin(annotation) or annotation + return isinstance(origin, type) and issubclass(origin, Component) + + +def _children_annotation_is_valid(annotation: Any) -> bool: + """Check whether an annotation is valid for children. + + Args: + annotation: The annotation to check. + + Returns: + Whether the annotation is valid for children. + """ + return _is_var_annotation(annotation) and type_utils.typehint_issubclass( + _annotation_inner_type(annotation), Component + ) + + +def _get_children_param(params: tuple[MemoParam, ...]) -> MemoParam | None: + return next((param for param in params if param.is_children), None) + + +def _get_rest_param(params: tuple[MemoParam, ...]) -> MemoParam | None: + return next((param for param in params if param.is_rest), None) + + +def _imported_function_var(name: str, return_type: Any) -> FunctionVar: + """Create the imported FunctionVar for an experimental memo. + + Args: + name: The exported function name. + return_type: The return type of the function. + + Returns: + The imported FunctionVar. + """ + return FunctionStringVar.create( + name, + _var_type=ReflexCallable[Any, return_type], + _var_data=VarData( + imports={f"$/{constants.Dirs.COMPONENTS_PATH}": [ImportVar(tag=name)]} + ), + ) + + +def _component_import_var(name: str) -> Var: + """Create the imported component var for an experimental memo component. + + Args: + name: The exported component name. + + Returns: + The component var. + """ + return Var( + name, + _var_type=type[Component], + _var_data=VarData( + imports={ + f"$/{constants.Dirs.COMPONENTS_PATH}": [ImportVar(tag=name)], + "@emotion/react": [ImportVar(tag="jsx")], + } + ), + ) + + +def _validate_var_return_expr(return_expr: Var, func_name: str) -> None: + """Validate that a var-returning memo can compile safely. + + Args: + return_expr: The return expression. + func_name: The function name for error messages. + + Raises: + TypeError: If the return expression depends on unsupported features. + """ + var_data = VarData.merge(return_expr._get_all_var_data()) + if var_data is None: + return + + if var_data.hooks: + msg = ( + f"Var-returning `@rx._x.memo` `{func_name}` cannot depend on hooks. " + "Use a component-returning `@rx._x.memo` instead." + ) + raise TypeError(msg) + + if var_data.components: + msg = ( + f"Var-returning `@rx._x.memo` `{func_name}` cannot depend on embedded " + "components, custom code, or dynamic imports. Use a component-returning " + "`@rx._x.memo` instead." + ) + raise TypeError(msg) + + for lib in dict(var_data.imports): + if not lib: + continue + if lib.startswith((".", "/", "$/", "http")): + continue + if format.format_library_name(lib) in bundled_libraries: + continue + msg = ( + f"Var-returning `@rx._x.memo` `{func_name}` cannot import `{lib}` because " + "it is not bundled. Use a component-returning `@rx._x.memo` instead." + ) + raise TypeError(msg) + + +def _rest_placeholder(name: str) -> RestProp: + """Create the placeholder RestProp. + + Args: + name: The JavaScript identifier. + + Returns: + The placeholder rest prop. + """ + return RestProp(_js_expr=name, _var_type=dict[str, Any]) + + +def _var_placeholder(name: str, annotation: Any) -> Var: + """Create a placeholder Var for a memo parameter. + + Args: + name: The JavaScript identifier. + annotation: The parameter annotation. + + Returns: + The placeholder Var. + """ + return Var(_js_expr=name, _var_type=_annotation_inner_type(annotation)).guess_type() + + +def _placeholder_for_param(param: MemoParam) -> Var: + """Create a placeholder var for a parameter. + + Args: + param: The parameter metadata. + + Returns: + The placeholder var. + """ + if param.is_rest: + return _rest_placeholder(param.placeholder_name) + return _var_placeholder(param.placeholder_name, param.annotation) + + +def _evaluate_memo_function( + fn: Callable[..., Any], + params: tuple[MemoParam, ...], +) -> Any: + """Evaluate a memo function with placeholder vars. + + Args: + fn: The function to evaluate. + params: The memo parameters. + + Returns: + The return value from the function. + """ + positional_args = [] + keyword_args = {} + + for param in params: + placeholder = _placeholder_for_param(param) + if param.kind in ( + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + ): + positional_args.append(placeholder) + else: + keyword_args[param.name] = placeholder + + return fn(*positional_args, **keyword_args) + + +def _lift_rest_props(component: Component) -> Component: + """Convert RestProp children into special props. + + Args: + component: The component tree to rewrite. + + Returns: + The rewritten component tree. + """ + special_props = list(component.special_props) + rewritten_children = [] + + for child in component.children: + if isinstance(child, Bare) and isinstance(child.contents, RestProp): + special_props.append(child.contents) + continue + + if isinstance(child, Component): + child = _lift_rest_props(child) + + rewritten_children.append(child) + + component.children = rewritten_children + component.special_props = special_props + return component + + +def _analyze_params( + fn: Callable[..., Any], + *, + for_component: bool, +) -> tuple[MemoParam, ...]: + """Analyze and validate memo parameters. + + Args: + fn: The function to analyze. + for_component: Whether the memo returns a component. + + Returns: + The analyzed parameters. + + Raises: + TypeError: If the function signature is not supported. + """ + signature = inspect.signature(fn) + hints = get_type_hints(fn) + + params: list[MemoParam] = [] + rest_count = 0 + + for parameter in signature.parameters.values(): + if parameter.kind is inspect.Parameter.VAR_POSITIONAL: + msg = f"`@rx._x.memo` does not support `*args` in `{fn.__name__}`." + raise TypeError(msg) + if parameter.kind is inspect.Parameter.VAR_KEYWORD: + msg = f"`@rx._x.memo` does not support `**kwargs` in `{fn.__name__}`." + raise TypeError(msg) + if parameter.kind is inspect.Parameter.POSITIONAL_ONLY: + msg = ( + f"`@rx._x.memo` does not support positional-only parameters in " + f"`{fn.__name__}`." + ) + raise TypeError(msg) + + annotation = hints.get(parameter.name, parameter.annotation) + if annotation is inspect.Parameter.empty: + msg = ( + f"All parameters of `{fn.__name__}` must be annotated as `rx.Var[...]` " + f"or `rx.RestProp`. Missing annotation for `{parameter.name}`." + ) + raise TypeError(msg) + + is_rest = _is_rest_annotation(annotation) + is_children = parameter.name == "children" and _children_annotation_is_valid( + annotation + ) + + if parameter.name == "children" and not is_children: + msg = ( + f"`children` in `{fn.__name__}` must be annotated as " + "`rx.Var[rx.Component]`." + ) + raise TypeError(msg) + + if not is_rest and not _is_var_annotation(annotation): + msg = ( + f"All parameters of `{fn.__name__}` must be annotated as `rx.Var[...]` " + f"or `rx.RestProp`, got `{annotation}` for `{parameter.name}`." + ) + raise TypeError(msg) + + if is_rest: + rest_count += 1 + if rest_count > 1: + msg = ( + f"`@rx._x.memo` only supports one `rx.RestProp` in `{fn.__name__}`." + ) + raise TypeError(msg) + + js_prop_name = format.to_camel_case(parameter.name) + placeholder_name = ( + parameter.name + if is_children or is_rest or not for_component + else js_prop_name + CAMEL_CASE_MEMO_MARKER + ) + + params.append( + MemoParam( + name=parameter.name, + annotation=annotation, + kind=parameter.kind, + default=parameter.default, + js_prop_name=js_prop_name, + placeholder_name=placeholder_name, + is_children=is_children, + is_rest=is_rest, + ) + ) + + return tuple(params) + + +def _create_function_definition( + fn: Callable[..., Any], + return_annotation: Any, +) -> ExperimentalMemoFunctionDefinition: + """Create a definition for a var-returning memo. + + Args: + fn: The function to analyze. + return_annotation: The return annotation. + + Returns: + The function memo definition. + """ + params = _analyze_params(fn, for_component=False) + return_expr = Var.create(_evaluate_memo_function(fn, params)) + _validate_var_return_expr(return_expr, fn.__name__) + + children_param = _get_children_param(params) + rest_param = _get_rest_param(params) + if children_param is None and rest_param is None: + function = ArgsFunctionOperation.create( + args_names=tuple(param.placeholder_name for param in params), + return_expr=return_expr, + ) + else: + function = ArgsFunctionOperation.create( + args_names=( + DestructuredArg( + fields=tuple( + param.placeholder_name for param in params if not param.is_rest + ), + rest=( + rest_param.placeholder_name if rest_param is not None else None + ), + ), + ), + return_expr=return_expr, + ) + + return ExperimentalMemoFunctionDefinition( + fn=fn, + python_name=fn.__name__, + params=params, + function=function, + imported_var=_imported_function_var( + fn.__name__, _annotation_inner_type(return_annotation) + ), + ) + + +def _create_component_definition( + fn: Callable[..., Any], + return_annotation: Any, +) -> ExperimentalMemoComponentDefinition: + """Create a definition for a component-returning memo. + + Args: + fn: The function to analyze. + return_annotation: The return annotation. + + Returns: + The component memo definition. + + Raises: + TypeError: If the function does not return a component. + """ + params = _analyze_params(fn, for_component=True) + component = _evaluate_memo_function(fn, params) + if not isinstance(component, Component): + msg = ( + f"Component-returning `@rx._x.memo` `{fn.__name__}` must return an " + f"`rx.Component`, got `{type(component).__name__}`." + ) + raise TypeError(msg) + + return ExperimentalMemoComponentDefinition( + fn=fn, + python_name=fn.__name__, + params=params, + export_name=format.to_title_case(fn.__name__), + component=_lift_rest_props(component), + ) + + +def _bind_function_runtime_args( + definition: ExperimentalMemoFunctionDefinition, + *args: Any, + **kwargs: Any, +) -> tuple[Any, ...]: + """Bind runtime args for a var-returning memo. + + Args: + definition: The function memo definition. + *args: Positional arguments. + **kwargs: Keyword arguments. + + Returns: + The ordered arguments for the imported FunctionVar. + + Raises: + TypeError: If the provided arguments are invalid. + """ + children_param = _get_children_param(definition.params) + rest_param = _get_rest_param(definition.params) + if "children" in kwargs: + msg = f"`{definition.python_name}` only accepts children positionally." + raise TypeError(msg) + + if rest_param is not None and rest_param.name in kwargs: + msg = ( + f"`{definition.python_name}` captures rest props from extra keyword " + f"arguments. Do not pass `{rest_param.name}=` directly." + ) + raise TypeError(msg) + + if args and children_param is None: + msg = f"`{definition.python_name}` only accepts keyword props." + raise TypeError(msg) + + if any(not _is_component_child(child) for child in args): + msg = ( + f"`{definition.python_name}` only accepts positional children that are " + "`rx.Component` or `rx.Var[rx.Component]`." + ) + raise TypeError(msg) + + explicit_params = [ + param + for param in definition.params + if not param.is_rest and not param.is_children + ] + explicit_values = {} + remaining_props = kwargs.copy() + for param in explicit_params: + if param.name in remaining_props: + explicit_values[param.name] = remaining_props.pop(param.name) + elif param.default is not inspect.Parameter.empty: + explicit_values[param.name] = param.default + else: + msg = f"`{definition.python_name}` is missing required prop `{param.name}`." + raise TypeError(msg) + + if remaining_props and rest_param is None: + unexpected_prop = next(iter(remaining_props)) + msg = ( + f"`{definition.python_name}` does not accept prop `{unexpected_prop}`. " + "Only declared props may be passed when no `rx.RestProp` is present." + ) + raise TypeError(msg) + + if children_param is None and rest_param is None: + return tuple(explicit_values[param.name] for param in explicit_params) + + children_value: Any | None = None + if children_param is not None: + children_value = args[0] if len(args) == 1 else Fragment.create(*args) + + bound_props = {} + if children_param is not None: + bound_props[children_param.name] = children_value + bound_props.update(explicit_values) + bound_props.update(remaining_props) + return (bound_props,) + + +def _is_component_child(value: Any) -> bool: + """Check whether a value is valid as an experimental memo child. + + Args: + value: The value to check. + + Returns: + Whether the value is a component child. + """ + return isinstance(value, Component) or ( + isinstance(value, Var) + and type_utils.typehint_issubclass(value._var_type, Component) + ) + + +def _create_function_wrapper( + definition: ExperimentalMemoFunctionDefinition, +) -> Callable[..., Var]: + """Create the Python wrapper for a var-returning memo. + + Args: + definition: The function memo definition. + + Returns: + The wrapper callable. + """ + imported_var = definition.imported_var + + @wraps(definition.fn) + def wrapper(*args: Any, **kwargs: Any) -> Var: + return imported_var.call( + *_bind_function_runtime_args(definition, *args, **kwargs) + ) + + def call(*args: Any, **kwargs: Any) -> Var: + return imported_var.call( + *_bind_function_runtime_args(definition, *args, **kwargs) + ) + + def partial(*args: Any, **kwargs: Any) -> FunctionVar: + return imported_var.partial( + *_bind_function_runtime_args(definition, *args, **kwargs) + ) + + object.__setattr__(wrapper, "call", call) + object.__setattr__(wrapper, "partial", partial) + object.__setattr__(wrapper, "_as_var", lambda: imported_var) + return wrapper + + +def _create_component_wrapper( + definition: ExperimentalMemoComponentDefinition, +) -> Callable[..., ExperimentalMemoComponent]: + """Create the Python wrapper for a component-returning memo. + + Args: + definition: The component memo definition. + + Returns: + The wrapper callable. + """ + children_param = _get_children_param(definition.params) + rest_param = _get_rest_param(definition.params) + explicit_params = [ + param + for param in definition.params + if not param.is_children and not param.is_rest + ] + + @wraps(definition.fn) + def wrapper(*children: Any, **props: Any) -> ExperimentalMemoComponent: + if "children" in props: + msg = f"`{definition.python_name}` only accepts children positionally." + raise TypeError(msg) + if rest_param is not None and rest_param.name in props: + msg = ( + f"`{definition.python_name}` captures rest props from extra keyword " + f"arguments. Do not pass `{rest_param.name}=` directly." + ) + raise TypeError(msg) + if children and children_param is None: + msg = f"`{definition.python_name}` only accepts keyword props." + raise TypeError(msg) + if any(not _is_component_child(child) for child in children): + msg = ( + f"`{definition.python_name}` only accepts positional children that are " + "`rx.Component` or `rx.Var[rx.Component]`." + ) + raise TypeError(msg) + + explicit_values = {} + remaining_props = props.copy() + for param in explicit_params: + if param.name in remaining_props: + explicit_values[param.name] = remaining_props.pop(param.name) + elif param.default is not inspect.Parameter.empty: + explicit_values[param.name] = param.default + else: + msg = f"`{definition.python_name}` is missing required prop `{param.name}`." + raise TypeError(msg) + + if remaining_props and rest_param is None: + unexpected_prop = next(iter(remaining_props)) + msg = ( + f"`{definition.python_name}` does not accept prop `{unexpected_prop}`. " + "Only declared props may be passed when no `rx.RestProp` is present." + ) + raise TypeError(msg) + + return ExperimentalMemoComponent._create( + children=list(children), + memo_definition=definition, + **explicit_values, + **remaining_props, + ) + + object.__setattr__( + wrapper, "_as_var", lambda: _component_import_var(definition.export_name) + ) + return wrapper + + +def memo(fn: Callable[..., Any]) -> Callable[..., Any]: + """Create an experimental memo from a function. + + Args: + fn: The function to memoize. + + Returns: + The wrapped function or component factory. + + Raises: + TypeError: If the return type is not supported. + """ + hints = get_type_hints(fn) + return_annotation = hints.get("return", inspect.Signature.empty) + if return_annotation is inspect.Signature.empty: + msg = ( + f"`@rx._x.memo` requires a return annotation on `{fn.__name__}`. " + "Use `-> rx.Component` or `-> rx.Var[...]`." + ) + raise TypeError(msg) + + if _is_component_annotation(return_annotation): + definition = _create_component_definition(fn, return_annotation) + EXPERIMENTAL_MEMOS[definition.export_name] = definition + return _create_component_wrapper(definition) + + if _is_var_annotation(return_annotation): + definition = _create_function_definition(fn, return_annotation) + EXPERIMENTAL_MEMOS[definition.python_name] = definition + return _create_function_wrapper(definition) + + msg = ( + f"`@rx._x.memo` on `{fn.__name__}` must return `rx.Component` or `rx.Var[...]`, " + f"got `{return_annotation}`." + ) + raise TypeError(msg) diff --git a/reflex/vars/__init__.py b/reflex/vars/__init__.py index c81e9a9bff3..c1989e5733d 100644 --- a/reflex/vars/__init__.py +++ b/reflex/vars/__init__.py @@ -17,7 +17,7 @@ from .datetime import DateTimeVar from .function import FunctionStringVar, FunctionVar, VarOperationCall from .number import BooleanVar, LiteralBooleanVar, LiteralNumberVar, NumberVar -from .object import LiteralObjectVar, ObjectVar +from .object import LiteralObjectVar, ObjectVar, RestProp from .sequence import ( ArrayVar, ConcatVarOperation, @@ -46,6 +46,7 @@ "LiteralVar", "NumberVar", "ObjectVar", + "RestProp", "StringVar", "Var", "VarData", diff --git a/reflex/vars/object.py b/reflex/vars/object.py index 1de146db5d0..e13b715e3d6 100644 --- a/reflex/vars/object.py +++ b/reflex/vars/object.py @@ -362,6 +362,10 @@ def contains(self, key: Var | Any) -> BooleanVar: return object_has_own_property_operation(self, key) +class RestProp(ObjectVar[dict[str, Any]]): + """A special object var representing forwarded rest props.""" + + @dataclasses.dataclass( eq=False, frozen=True, diff --git a/tests/integration/test_experimental_memo.py b/tests/integration/test_experimental_memo.py new file mode 100644 index 00000000000..f63be1804c1 --- /dev/null +++ b/tests/integration/test_experimental_memo.py @@ -0,0 +1,150 @@ +"""Integration tests for rx._x.memo.""" + +from collections.abc import Generator + +import pytest +from selenium.webdriver.common.by import By + +import reflex.app as reflex_app +import reflex.state as reflex_state +from reflex import constants +from reflex.testing import AppHarness + + +def ExperimentalMemoApp(): + """Reflex app that exercises experimental memo functions and components.""" + import reflex as rx + + class FooComponent(rx.Fragment): + def add_custom_code(self) -> list[str]: + return [ + "const foo = 'bar'", + ] + + @rx._x.memo + def foo_component(label: rx.Var[str]) -> rx.Component: + return FooComponent.create(label, rx.Var("foo")) + + @rx._x.memo + def format_price(amount: rx.Var[int], currency: rx.Var[str]) -> rx.Var[str]: + return currency.to(str) + ": $" + amount.to(str) + + @rx._x.memo + def summary_card( + children: rx.Var[rx.Component], + rest: rx.RestProp, + *, + title: rx.Var[str], + value: rx.Var[str], + ) -> rx.Component: + return rx.box( + rx.heading(title, id="summary-title"), + rx.text(value, id="summary-value"), + children, + rest, + ) + + class ExperimentalMemoState(rx.State): + amount: int = 125 + currency: str = "USD" + title: str = "Current Price" + + @rx.event + def increment_amount(self): + self.amount += 5 + + def index() -> rx.Component: + formatted_price = format_price( + amount=ExperimentalMemoState.amount, + currency=ExperimentalMemoState.currency, + ) + return rx.vstack( + rx.vstack( + foo_component(label="foo"), + foo_component(label="bar"), + id="experimental-memo-custom-code", + ), + rx.text(formatted_price, id="formatted-price"), + rx.button( + "Increment", + id="increment-price", + on_click=ExperimentalMemoState.increment_amount, + ), + summary_card( + rx.text("Children are passed positionally.", id="summary-child"), + title=ExperimentalMemoState.title, + value=formatted_price, + id="summary-card", + class_name="forwarded-summary-card", + ), + ) + + app = rx.App() + app.add_page(index) + + +@pytest.fixture +def experimental_memo_app(tmp_path, monkeypatch) -> Generator[AppHarness, None, None]: + """Start ExperimentalMemoApp app at tmp_path via AppHarness. + + Args: + tmp_path: pytest tmp_path fixture. + monkeypatch: pytest monkeypatch fixture. + + Yields: + Running AppHarness instance. + """ + monkeypatch.setenv( + constants.PYTEST_CURRENT_TEST, + "tests/integration/test_experimental_memo.py::test_experimental_memo_app", + ) + monkeypatch.setattr(reflex_app, "is_testing_env", lambda: True) + monkeypatch.setattr(reflex_state, "is_testing_env", lambda: True) + with AppHarness.create( + root=tmp_path, + app_source=ExperimentalMemoApp, + ) as harness: + yield harness + + +def test_experimental_memo_app(experimental_memo_app: AppHarness): + """Render experimental memos and assert on their frontend behavior. + + Args: + experimental_memo_app: Harness for ExperimentalMemoApp. + """ + assert experimental_memo_app.app_instance is not None, "app is not running" + driver = experimental_memo_app.frontend() + + memo_custom_code_stack = AppHarness.poll_for_or_raise_timeout( + lambda: driver.find_element(By.ID, "experimental-memo-custom-code") + ) + assert ( + experimental_memo_app.poll_for_content(memo_custom_code_stack, exp_not_equal="") + == "foobarbarbar" + ) + assert memo_custom_code_stack.text == "foobarbarbar" + + formatted_price = driver.find_element(By.ID, "formatted-price") + assert ( + experimental_memo_app.poll_for_content(formatted_price, exp_not_equal="") + == "USD: $125" + ) + + summary_card = driver.find_element(By.ID, "summary-card") + assert "forwarded-summary-card" in (summary_card.get_attribute("class") or "") + assert driver.find_element(By.ID, "summary-title").text == "Current Price" + assert ( + driver.find_element(By.ID, "summary-child").text + == "Children are passed positionally." + ) + + summary_value = driver.find_element(By.ID, "summary-value") + assert ( + experimental_memo_app.poll_for_content(summary_value, exp_not_equal="") + == "USD: $125" + ) + + driver.find_element(By.ID, "increment-price").click() + assert experimental_memo_app.poll_for_content(formatted_price) == "USD: $130" + assert experimental_memo_app.poll_for_content(summary_value) == "USD: $130" diff --git a/tests/units/experimental/__init__.py b/tests/units/experimental/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/units/experimental/test_memo.py b/tests/units/experimental/test_memo.py new file mode 100644 index 00000000000..cd732a87ab4 --- /dev/null +++ b/tests/units/experimental/test_memo.py @@ -0,0 +1,348 @@ +"""Tests for experimental memo support.""" + +from __future__ import annotations + +from typing import Any + +import pytest + +import reflex as rx +from reflex.compiler import compiler +from reflex.compiler import utils as compiler_utils +from reflex.components.component import CUSTOM_COMPONENTS, Component +from reflex.experimental.memo import ( + EXPERIMENTAL_MEMOS, + ExperimentalMemoComponent, + ExperimentalMemoComponentDefinition, + ExperimentalMemoFunctionDefinition, +) +from reflex.utils.imports import ImportVar +from reflex.vars import VarData +from reflex.vars.base import Var +from reflex.vars.function import FunctionVar + + +@pytest.fixture(autouse=True) +def restore_memo_registries(): + """Restore the memo registries after each test.""" + custom_components = dict(CUSTOM_COMPONENTS) + experimental_memos = dict(EXPERIMENTAL_MEMOS) + + yield + + CUSTOM_COMPONENTS.clear() + CUSTOM_COMPONENTS.update(custom_components) + EXPERIMENTAL_MEMOS.clear() + EXPERIMENTAL_MEMOS.update(experimental_memos) + + +def test_var_returning_memo(): + """Var-returning memos should behave like imported function vars.""" + + @rx._x.memo + def format_price(amount: rx.Var[int], currency: rx.Var[str]) -> rx.Var[str]: + return currency.to(str) + ": $" + amount.to(str) + + price = Var(_js_expr="price", _var_type=int) + currency = Var(_js_expr="currency", _var_type=str) + + assert ( + str(format_price(amount=price, currency=currency)) + == "(format_price(price, currency))" + ) + assert isinstance(format_price._as_var(), FunctionVar) + + definition = EXPERIMENTAL_MEMOS["format_price"] + assert isinstance(definition, ExperimentalMemoFunctionDefinition) + assert ( + str(definition.function) == '((amount, currency) => ((currency+": $")+amount))' + ) + + with pytest.raises(TypeError, match="only accepts keyword props"): + format_price(price, currency) + + +def test_component_returning_memo_with_children_and_rest(): + """Component-returning memos should accept positional children and forwarded props.""" + + @rx._x.memo + def my_card( + children: rx.Var[rx.Component], + rest: rx.RestProp, + *, + title: rx.Var[str], + ) -> rx.Component: + return rx.box( + rx.heading(title), + children, + rest, + ) + + component = my_card( + rx.text("child 1"), + rx.text("child 2"), + title="Hello", + class_name="extra", + ) + + assert isinstance(component, ExperimentalMemoComponent) + assert len(component.children) == 2 + + rendered = component.render() + assert rendered["name"] == "MyCard" + assert 'title:"Hello"' in rendered["props"] + assert 'className:"extra"' in rendered["props"] + + definition = EXPERIMENTAL_MEMOS["MyCard"] + assert isinstance(definition, ExperimentalMemoComponentDefinition) + assert any(str(prop) == "rest" for prop in definition.component.special_props) + + _, code, _ = compiler.compile_memo_components( + (), tuple(EXPERIMENTAL_MEMOS.values()) + ) + assert "export const MyCard = memo(({children, title:title" in code + assert "...rest" in code + assert "jsx(RadixThemesBox,{...rest}" in code + + +def test_var_returning_memo_with_rest_props(): + """Var-returning memos should capture extra keyword args into RestProp.""" + + @rx._x.memo + def merge_styles( + base: rx.Var[dict[str, str]], + overrides: rx.RestProp, + ) -> rx.Var[Any]: + return base.to(dict).merge(overrides) + + base = Var(_js_expr="base", _var_type=dict[str, str]) + merged = merge_styles(base=base, color="red") + + assert "merge_styles" in str(merged) + assert '["base"] : base' in str(merged) + assert '["color"] : "red"' in str(merged) + + _, code, _ = compiler.compile_memo_components( + (), tuple(EXPERIMENTAL_MEMOS.values()) + ) + assert ( + "export const merge_styles = (({base, ...overrides}) => ({...base, ...overrides}));" + in code + ) + + with pytest.raises(TypeError, match="Do not pass `overrides=` directly"): + merge_styles(base=base, overrides={"color": "red"}) + + +def test_var_returning_memo_with_children_and_rest(): + """Var-returning memos should accept positional children plus keyword props.""" + + @rx._x.memo + def label_slot( + children: rx.Var[rx.Component], + rest: rx.RestProp, + *, + label: rx.Var[str], + ) -> rx.Var[str]: + return label + + rendered = label_slot( + rx.text("child"), + label="Hello", + class_name="slot", + ) + + assert "label_slot" in str(rendered) + assert '["children"]' in str(rendered) + assert '["class_name"] : "slot"' in str(rendered) + + _, code, _ = compiler.compile_memo_components( + (), tuple(EXPERIMENTAL_MEMOS.values()) + ) + assert "export const label_slot = (({children, label, ...rest}) => label);" in code + + +def test_memo_requires_var_annotations(): + """Experimental memos should require Var annotations on parameters.""" + with pytest.raises(TypeError, match="must be annotated"): + + @rx._x.memo + def bad_annotation(value: int) -> rx.Var[str]: + return rx.Var.create("x") + + with pytest.raises(TypeError, match="Missing annotation"): + + @rx._x.memo + def missing_annotation(value) -> rx.Var[str]: + return rx.Var.create("x") + + +def test_memo_rejects_invalid_children_annotation(): + """Component memos should validate the special children annotation.""" + with pytest.raises(TypeError, match="children"): + + @rx._x.memo + def bad_children(children: rx.Var[str]) -> rx.Component: + return rx.text(children) + + +def test_memo_rejects_multiple_rest_props(): + """Experimental memos should only allow a single RestProp.""" + with pytest.raises(TypeError, match="only supports one"): + + @rx._x.memo + def too_many_rest( + first: rx.RestProp, + second: rx.RestProp, + ) -> rx.Var[Any]: + return first + + +def test_memo_rejects_varargs(): + """Experimental memos should reject *args and **kwargs.""" + with pytest.raises(TypeError, match=r"\*args"): + + @rx._x.memo + def bad_args(*values: rx.Var[str]) -> rx.Var[str]: + return rx.Var.create("x") + + with pytest.raises(TypeError, match=r"\*\*kwargs"): + + @rx._x.memo + def bad_kwargs(**values: rx.Var[str]) -> rx.Var[str]: + return rx.Var.create("x") + + +def test_component_memo_rejects_invalid_positional_usage(): + """Component memos should only accept positional children.""" + + @rx._x.memo + def title_card(*, title: rx.Var[str]) -> rx.Component: + return rx.box(rx.heading(title)) + + with pytest.raises(TypeError, match="only accepts keyword props"): + title_card(rx.text("child")) + + @rx._x.memo + def child_card( + children: rx.Var[rx.Component], *, title: rx.Var[str] + ) -> rx.Component: + return rx.box(rx.heading(title), children) + + with pytest.raises(TypeError, match="only accepts positional children"): + child_card("not a component", title="Hello") + + +def test_var_memo_rejects_invalid_positional_usage(): + """Var memos should also reserve positional arguments for children only.""" + + @rx._x.memo + def format_price(amount: rx.Var[int], currency: rx.Var[str]) -> rx.Var[str]: + return currency.to(str) + ": $" + amount.to(str) + + price = Var(_js_expr="price", _var_type=int) + currency = Var(_js_expr="currency", _var_type=str) + + with pytest.raises(TypeError, match="only accepts keyword props"): + format_price(price, currency) + + @rx._x.memo + def child_label( + children: rx.Var[rx.Component], *, label: rx.Var[str] + ) -> rx.Var[str]: + return label + + with pytest.raises(TypeError, match="only accepts positional children"): + child_label("not a component", label="Hello") + + +def test_var_returning_memo_rejects_hooks(): + """Var-returning memos should reject hook-bearing expressions.""" + with pytest.raises(TypeError, match="cannot depend on hooks"): + + @rx._x.memo + def bad_hook(value: rx.Var[str]) -> rx.Var[str]: + return Var( + _js_expr="value", + _var_type=str, + _var_data=VarData(hooks={"const badHook = 1": None}), + ) + + +def test_var_returning_memo_rejects_non_bundled_imports(): + """Var-returning memos should reject non-bundled imports.""" + with pytest.raises(TypeError, match="not bundled"): + + @rx._x.memo + def bad_import(value: rx.Var[str]) -> rx.Var[str]: + return Var( + _js_expr="value", + _var_type=str, + _var_data=VarData(imports={"some-lib": [ImportVar(tag="x")]}), + ) + + +def test_compile_memo_components_includes_experimental_functions_and_components(): + """The shared memo output should include both experimental functions and components.""" + + @rx.memo + def old_wrapper(title: rx.Var[str]) -> rx.Component: + return rx.text(title) + + @rx._x.memo + def format_price(amount: rx.Var[int], currency: rx.Var[str]) -> rx.Var[str]: + return currency.to(str) + ": $" + amount.to(str) + + @rx._x.memo + def my_card(children: rx.Var[rx.Component], *, title: rx.Var[str]) -> rx.Component: + return rx.box(rx.heading(title), children) + + _, code, _ = compiler.compile_memo_components( + dict.fromkeys(CUSTOM_COMPONENTS.values()), + tuple(EXPERIMENTAL_MEMOS.values()), + ) + + assert "export const OldWrapper = memo(" in code + assert "export const format_price =" in code + assert "export const MyCard = memo(" in code + + +def test_experimental_component_memo_get_imports(): + """Experimental component memos should resolve imports during compilation.""" + + class Inner(Component): + tag = "Inner" + library = "inner" + + @rx._x.memo + def wrapper() -> rx.Component: + return Inner.create() + + experimental_component = wrapper() + + assert "inner" not in experimental_component._get_all_imports() + + definition = EXPERIMENTAL_MEMOS["Wrapper"] + assert isinstance(definition, ExperimentalMemoComponentDefinition) + _, imports = compiler_utils.compile_experimental_component_memo(definition) + assert "inner" in imports + + +def test_compile_memo_components_includes_experimental_custom_code(): + """Experimental component memos should include custom code in compiled output.""" + + class FooComponent(rx.Fragment): + def add_custom_code(self) -> list[str]: + return [ + "const foo = 'bar'", + ] + + @rx._x.memo + def foo_component(label: rx.Var[str]) -> rx.Component: + return FooComponent.create(label, rx.Var("foo")) + + _, code, _ = compiler.compile_memo_components( + (), tuple(EXPERIMENTAL_MEMOS.values()) + ) + + assert "const foo = 'bar'" in code From c078e8b24ef99834e5b8b656b60fb039cb554a0b Mon Sep 17 00:00:00 2001 From: Farhan Ali Raza Date: Thu, 19 Mar 2026 18:51:55 +0500 Subject: [PATCH 02/10] fix: prevent memo name collisions and compile-time mutation of stored components Add registry helpers that detect duplicate exported names across memo kinds and raise on collision. Deepcopy the component before applying styles during compilation so the stored definition stays clean. Simplify the function wrappers .call to alias the wrapper itself. --- reflex/compiler/compiler.py | 2 - reflex/compiler/utils.py | 3 +- reflex/experimental/memo.py | 70 ++++++++++++++++++++++++--- tests/units/experimental/test_memo.py | 64 ++++++++++++++++++++++++ 4 files changed, 128 insertions(+), 11 deletions(-) diff --git a/reflex/compiler/compiler.py b/reflex/compiler/compiler.py index 4a9c55bff6e..b644bfc0ad0 100644 --- a/reflex/compiler/compiler.py +++ b/reflex/compiler/compiler.py @@ -383,8 +383,6 @@ def _compile_memo_components( }, imports, ) - - if component_renders: _apply_common_imports(imports) dynamic_imports = { diff --git a/reflex/compiler/utils.py b/reflex/compiler/utils.py index f555e2807ae..2e99413b068 100644 --- a/reflex/compiler/utils.py +++ b/reflex/compiler/utils.py @@ -4,6 +4,7 @@ import asyncio import concurrent.futures +import copy import operator import traceback from collections.abc import Mapping, Sequence @@ -392,7 +393,7 @@ def compile_experimental_component_memo( Returns: A tuple of the compiled component definition and its imports. """ - render = _apply_component_style_for_compile(definition.component) + render = _apply_component_style_for_compile(copy.deepcopy(definition.component)) imports: ParsedImportDict = { lib: fields diff --git a/reflex/experimental/memo.py b/reflex/experimental/memo.py index 17b4af8bb2f..7115e018d23 100644 --- a/reflex/experimental/memo.py +++ b/reflex/experimental/memo.py @@ -119,6 +119,65 @@ def _post_init(self, **kwargs): EXPERIMENTAL_MEMOS: dict[str, ExperimentalMemoDefinition] = {} +def _memo_registry_key(definition: ExperimentalMemoDefinition) -> str: + """Get the registry key for an experimental memo. + + Args: + definition: The memo definition. + + Returns: + The registry key for the memo. + """ + if isinstance(definition, ExperimentalMemoComponentDefinition): + return definition.export_name + return definition.python_name + + +def _is_memo_reregistration( + existing: ExperimentalMemoDefinition, + definition: ExperimentalMemoDefinition, +) -> bool: + """Check whether a memo definition replaces the same memo during reload. + + Args: + existing: The currently registered memo definition. + definition: The new memo definition being registered. + + Returns: + Whether the new definition should replace the existing one. + """ + return ( + type(existing) is type(definition) + and existing.python_name == definition.python_name + and existing.fn.__module__ == definition.fn.__module__ + and existing.fn.__qualname__ == definition.fn.__qualname__ + ) + + +def _register_memo_definition(definition: ExperimentalMemoDefinition) -> None: + """Register an experimental memo definition. + + Args: + definition: The memo definition to register. + + Raises: + ValueError: If another memo already compiles to the same exported name. + """ + key = _memo_registry_key(definition) + if (existing := EXPERIMENTAL_MEMOS.get(key)) is not None and ( + not _is_memo_reregistration(existing, definition) + ): + msg = ( + f"Experimental memo name collision for `{key}`: " + f"`{existing.fn.__module__}.{existing.python_name}` and " + f"`{definition.fn.__module__}.{definition.python_name}` both compile " + "to the same memo name." + ) + raise ValueError(msg) + + EXPERIMENTAL_MEMOS[key] = definition + + def _annotation_inner_type(annotation: Any) -> Any: """Unwrap a Var-like annotation to its inner type. @@ -670,17 +729,12 @@ def wrapper(*args: Any, **kwargs: Any) -> Var: *_bind_function_runtime_args(definition, *args, **kwargs) ) - def call(*args: Any, **kwargs: Any) -> Var: - return imported_var.call( - *_bind_function_runtime_args(definition, *args, **kwargs) - ) - def partial(*args: Any, **kwargs: Any) -> FunctionVar: return imported_var.partial( *_bind_function_runtime_args(definition, *args, **kwargs) ) - object.__setattr__(wrapper, "call", call) + object.__setattr__(wrapper, "call", wrapper) object.__setattr__(wrapper, "partial", partial) object.__setattr__(wrapper, "_as_var", lambda: imported_var) return wrapper @@ -781,12 +835,12 @@ def memo(fn: Callable[..., Any]) -> Callable[..., Any]: if _is_component_annotation(return_annotation): definition = _create_component_definition(fn, return_annotation) - EXPERIMENTAL_MEMOS[definition.export_name] = definition + _register_memo_definition(definition) return _create_component_wrapper(definition) if _is_var_annotation(return_annotation): definition = _create_function_definition(fn, return_annotation) - EXPERIMENTAL_MEMOS[definition.python_name] = definition + _register_memo_definition(definition) return _create_function_wrapper(definition) msg = ( diff --git a/tests/units/experimental/test_memo.py b/tests/units/experimental/test_memo.py index cd732a87ab4..11be6bac7cd 100644 --- a/tests/units/experimental/test_memo.py +++ b/tests/units/experimental/test_memo.py @@ -2,6 +2,7 @@ from __future__ import annotations +from types import SimpleNamespace from typing import Any import pytest @@ -16,6 +17,7 @@ ExperimentalMemoComponentDefinition, ExperimentalMemoFunctionDefinition, ) +from reflex.style import Style from reflex.utils.imports import ImportVar from reflex.vars import VarData from reflex.vars.base import Var @@ -50,6 +52,10 @@ def format_price(amount: rx.Var[int], currency: rx.Var[str]) -> rx.Var[str]: str(format_price(amount=price, currency=currency)) == "(format_price(price, currency))" ) + assert ( + str(format_price.call(amount=price, currency=currency)) + == "(format_price(price, currency))" + ) assert isinstance(format_price._as_var(), FunctionVar) definition = EXPERIMENTAL_MEMOS["format_price"] @@ -198,6 +204,36 @@ def too_many_rest( return first +def test_memo_rejects_component_and_function_name_collision(): + """Experimental memos should reject same exported name across kinds.""" + + @rx._x.memo + def foo_bar() -> rx.Component: + return rx.box() + + assert "FooBar" in EXPERIMENTAL_MEMOS + + with pytest.raises(ValueError, match=r"name collision.*FooBar"): + + @rx._x.memo + def FooBar() -> rx.Var[str]: + return rx.Var.create("x") + + +def test_memo_rejects_component_export_name_collision(): + """Experimental memos should reject duplicate component export names.""" + + @rx._x.memo + def foo_bar() -> rx.Component: + return rx.box() + + with pytest.raises(ValueError, match=r"name collision.*FooBar"): + + @rx._x.memo + def foo__bar() -> rx.Component: + return rx.box() + + def test_memo_rejects_varargs(): """Experimental memos should reject *args and **kwargs.""" with pytest.raises(TypeError, match=r"\*args"): @@ -328,6 +364,34 @@ def wrapper() -> rx.Component: assert "inner" in imports +def test_compile_experimental_component_memo_does_not_mutate_definition( + monkeypatch: pytest.MonkeyPatch, +): + """Experimental component memo compilation should not mutate stored components.""" + + @rx._x.memo + def wrapper() -> rx.Component: + return rx.box("hi") + + definition = EXPERIMENTAL_MEMOS["Wrapper"] + assert isinstance(definition, ExperimentalMemoComponentDefinition) + assert definition.component.style == Style() + + monkeypatch.setattr( + "reflex.utils.prerequisites.get_and_validate_app", + lambda: SimpleNamespace( + app=SimpleNamespace( + style={type(definition.component): Style({"color": "red"})} + ) + ), + ) + + render, _ = compiler_utils.compile_experimental_component_memo(definition) + + assert render["render"]["props"] == ['css:({ ["color"] : "red" })'] + assert definition.component.style == Style() + + def test_compile_memo_components_includes_experimental_custom_code(): """Experimental component memos should include custom code in compiled output.""" From f98b52f64b2408a2dd6c9dd5d9d329c6563cff00 Mon Sep 17 00:00:00 2001 From: Farhan Ali Raza Date: Fri, 20 Mar 2026 13:19:34 +0500 Subject: [PATCH 03/10] test: clear old memos when testing. --- reflex/testing.py | 18 ++++-- tests/units/test_testing.py | 109 ++++++++++++++++++++++++++++++++++++ 2 files changed, 122 insertions(+), 5 deletions(-) diff --git a/reflex/testing.py b/reflex/testing.py index 4ab72602334..e82cf023104 100644 --- a/reflex/testing.py +++ b/reflex/testing.py @@ -34,9 +34,10 @@ import reflex.utils.format import reflex.utils.prerequisites import reflex.utils.processes -from reflex.components.component import CustomComponent +from reflex.components.component import CUSTOM_COMPONENTS, CustomComponent from reflex.config import get_config from reflex.environment import environment +from reflex.experimental.memo import EXPERIMENTAL_MEMOS from reflex.istate.manager.disk import StateManagerDisk from reflex.istate.manager.memory import StateManagerMemory from reflex.istate.manager.redis import StateManagerRedis @@ -243,6 +244,10 @@ def _initialize_app(self): # disable telemetry reporting for tests os.environ["REFLEX_TELEMETRY_ENABLED"] = "false" + # Reset global memo registries so previous AppHarness apps do not + # leak compiled component definitions into the next test app. + CUSTOM_COMPONENTS.clear() + EXPERIMENTAL_MEMOS.clear() CustomComponent.create().get_component.cache_clear() self.app_path.mkdir(parents=True, exist_ok=True) if self.app_source is not None: @@ -269,15 +274,18 @@ def _initialize_app(self): reflex.utils.prerequisites.initialize_frontend_dependencies() with chdir(self.app_path): # ensure config and app are reloaded when testing different app - reflex.config.get_config(reload=True) + config = reflex.config.get_config(reload=True) # Ensure the AppHarness test does not skip State assignment due to running via pytest os.environ.pop(reflex.constants.PYTEST_CURRENT_TEST, None) os.environ[reflex.constants.APP_HARNESS_FLAG] = "true" - # Ensure we actually compile the app during first initialization. + # Ensure we compile generated apps, and reload pre-existing app modules + # that were already imported so they can re-register memo definitions. + should_reload_app = ( + self.app_source is not None or config.module in sys.modules + ) self.app_instance, self.app_module = ( reflex.utils.prerequisites.get_and_validate_app( - # Do not reload the module for pre-existing apps (only apps generated from source) - reload=self.app_source is not None + reload=should_reload_app ) ) self.app_asgi = self.app_instance() diff --git a/tests/units/test_testing.py b/tests/units/test_testing.py index 8c8f1461bf0..56b2255eeda 100644 --- a/tests/units/test_testing.py +++ b/tests/units/test_testing.py @@ -1,8 +1,18 @@ """Unit tests for the included testing tools.""" +import sys +from types import ModuleType, SimpleNamespace +from unittest import mock + import pytest +import reflex.config +import reflex.reflex as reflex_cli +import reflex.testing as reflex_testing +import reflex.utils.prerequisites +from reflex.components.component import CUSTOM_COMPONENTS from reflex.constants import IS_WINDOWS +from reflex.experimental.memo import EXPERIMENTAL_MEMOS from reflex.testing import AppHarness @@ -38,3 +48,102 @@ class State(rx.State): assert harness.frontend_process.poll() is None assert harness.frontend_process.poll() is not None + + +@pytest.fixture +def preserve_memo_registries(): + """Restore global memo registries after each test.""" + custom_components = dict(CUSTOM_COMPONENTS) + experimental_memos = dict(EXPERIMENTAL_MEMOS) + try: + yield + finally: + CUSTOM_COMPONENTS.clear() + CUSTOM_COMPONENTS.update(custom_components) + EXPERIMENTAL_MEMOS.clear() + EXPERIMENTAL_MEMOS.update(experimental_memos) + + +def test_app_harness_initialize_clears_memo_registries( + tmp_path, preserve_memo_registries, monkeypatch +): + """Ensure app initialization clears leaked memo registries. + + Args: + tmp_path: pytest tmp_path fixture + preserve_memo_registries: restores global memo registries after the test + monkeypatch: pytest monkeypatch fixture + """ + fake_config = SimpleNamespace(loglevel=None, module="memo_app.memo_app") + fake_app = mock.Mock(_state_manager=None) + get_and_validate_app = mock.Mock( + return_value=reflex.utils.prerequisites.AppInfo( + app=fake_app, + module=ModuleType(fake_config.module), + ) + ) + + monkeypatch.setattr(reflex_testing, "get_config", lambda: fake_config) + monkeypatch.setattr(reflex.config, "get_config", lambda reload=False: fake_config) + monkeypatch.setattr(reflex_cli, "_init", lambda **kwargs: None) + monkeypatch.setattr( + reflex.utils.prerequisites, + "get_and_validate_app", + get_and_validate_app, + ) + CUSTOM_COMPONENTS["FooComponent"] = mock.sentinel.component + EXPERIMENTAL_MEMOS["format_value"] = mock.sentinel.memo + + assert "FooComponent" in CUSTOM_COMPONENTS + assert "format_value" in EXPERIMENTAL_MEMOS + + harness = AppHarness.create( + root=tmp_path / "memo_app", + app_source="import reflex as rx\napp = rx.App()", + app_name="memo_app", + ) + harness.app_module_path.parent.mkdir(parents=True, exist_ok=True) + harness._initialize_app() + + assert "FooComponent" not in CUSTOM_COMPONENTS + assert "format_value" not in EXPERIMENTAL_MEMOS + get_and_validate_app.assert_called_once_with(reload=True) + + +def test_app_harness_initialize_reloads_existing_imported_app( + tmp_path, preserve_memo_registries, monkeypatch +): + """Ensure pre-existing imported apps are reloaded after memo registry reset. + + Args: + tmp_path: pytest tmp_path fixture + preserve_memo_registries: restores global memo registries after the test + monkeypatch: pytest monkeypatch fixture + """ + fake_config = SimpleNamespace(loglevel=None, module="plain_app.plain_app") + fake_app = mock.Mock(_state_manager=None) + get_and_validate_app = mock.Mock( + return_value=reflex.utils.prerequisites.AppInfo( + app=fake_app, + module=ModuleType(fake_config.module), + ) + ) + + monkeypatch.setattr(reflex_testing, "get_config", lambda: fake_config) + monkeypatch.setattr(reflex.config, "get_config", lambda reload=False: fake_config) + monkeypatch.setattr( + reflex.utils.prerequisites, + "initialize_frontend_dependencies", + lambda: None, + ) + monkeypatch.setattr( + reflex.utils.prerequisites, + "get_and_validate_app", + get_and_validate_app, + ) + monkeypatch.setitem(sys.modules, fake_config.module, ModuleType(fake_config.module)) + + harness = AppHarness.create(root=tmp_path / "plain_app") + harness._initialize_app() + + get_and_validate_app.assert_called_once_with(reload=True) From 522e7e5a9b0142d192a3c7c3ed4d3c51b002f78d Mon Sep 17 00:00:00 2001 From: Farhan Ali Raza Date: Fri, 20 Mar 2026 13:23:12 +0500 Subject: [PATCH 04/10] test: cleanup --- tests/units/conftest.py | 20 +++++++ tests/units/experimental/test_memo.py | 13 +---- tests/units/test_testing.py | 77 ++++++++++++--------------- 3 files changed, 56 insertions(+), 54 deletions(-) diff --git a/tests/units/conftest.py b/tests/units/conftest.py index 612d8beaf85..1399832d8be 100644 --- a/tests/units/conftest.py +++ b/tests/units/conftest.py @@ -8,7 +8,9 @@ import pytest from reflex.app import App +from reflex.components.component import CUSTOM_COMPONENTS from reflex.event import EventSpec +from reflex.experimental.memo import EXPERIMENTAL_MEMOS from reflex.model import ModelRegistry from reflex.testing import chdir from reflex.utils import prerequisites @@ -209,3 +211,21 @@ def model_registry() -> Generator[type[ModelRegistry], None, None]: """ yield ModelRegistry ModelRegistry._metadata = None + + +@pytest.fixture +def preserve_memo_registries(): + """Save and restore global memo registries around a test. + + Yields: + None + """ + custom_components = dict(CUSTOM_COMPONENTS) + experimental_memos = dict(EXPERIMENTAL_MEMOS) + try: + yield + finally: + CUSTOM_COMPONENTS.clear() + CUSTOM_COMPONENTS.update(custom_components) + EXPERIMENTAL_MEMOS.clear() + EXPERIMENTAL_MEMOS.update(experimental_memos) diff --git a/tests/units/experimental/test_memo.py b/tests/units/experimental/test_memo.py index 11be6bac7cd..2b6fc1fb485 100644 --- a/tests/units/experimental/test_memo.py +++ b/tests/units/experimental/test_memo.py @@ -25,17 +25,8 @@ @pytest.fixture(autouse=True) -def restore_memo_registries(): - """Restore the memo registries after each test.""" - custom_components = dict(CUSTOM_COMPONENTS) - experimental_memos = dict(EXPERIMENTAL_MEMOS) - - yield - - CUSTOM_COMPONENTS.clear() - CUSTOM_COMPONENTS.update(custom_components) - EXPERIMENTAL_MEMOS.clear() - EXPERIMENTAL_MEMOS.update(experimental_memos) +def _restore_memo_registries(preserve_memo_registries): + """Autouse wrapper around the shared preserve_memo_registries fixture.""" def test_var_returning_memo(): diff --git a/tests/units/test_testing.py b/tests/units/test_testing.py index 56b2255eeda..36c8f173b01 100644 --- a/tests/units/test_testing.py +++ b/tests/units/test_testing.py @@ -51,30 +51,16 @@ class State(rx.State): @pytest.fixture -def preserve_memo_registries(): - """Restore global memo registries after each test.""" - custom_components = dict(CUSTOM_COMPONENTS) - experimental_memos = dict(EXPERIMENTAL_MEMOS) - try: - yield - finally: - CUSTOM_COMPONENTS.clear() - CUSTOM_COMPONENTS.update(custom_components) - EXPERIMENTAL_MEMOS.clear() - EXPERIMENTAL_MEMOS.update(experimental_memos) - - -def test_app_harness_initialize_clears_memo_registries( - tmp_path, preserve_memo_registries, monkeypatch -): - """Ensure app initialization clears leaked memo registries. +def harness_mocks(monkeypatch): + """Common mocks for AppHarness initialization tests. Args: - tmp_path: pytest tmp_path fixture - preserve_memo_registries: restores global memo registries after the test monkeypatch: pytest monkeypatch fixture + + Returns: + Namespace with fake_config and get_and_validate_app mock. """ - fake_config = SimpleNamespace(loglevel=None, module="memo_app.memo_app") + fake_config = SimpleNamespace(loglevel=None, module="test_app.test_app") fake_app = mock.Mock(_state_manager=None) get_and_validate_app = mock.Mock( return_value=reflex.utils.prerequisites.AppInfo( @@ -85,18 +71,34 @@ def test_app_harness_initialize_clears_memo_registries( monkeypatch.setattr(reflex_testing, "get_config", lambda: fake_config) monkeypatch.setattr(reflex.config, "get_config", lambda reload=False: fake_config) - monkeypatch.setattr(reflex_cli, "_init", lambda **kwargs: None) monkeypatch.setattr( reflex.utils.prerequisites, "get_and_validate_app", get_and_validate_app, ) + + return SimpleNamespace( + config=fake_config, + get_and_validate_app=get_and_validate_app, + ) + + +def test_app_harness_initialize_clears_memo_registries( + tmp_path, preserve_memo_registries, harness_mocks, monkeypatch +): + """Ensure app initialization clears leaked memo registries. + + Args: + tmp_path: pytest tmp_path fixture + preserve_memo_registries: restores global memo registries after the test + harness_mocks: shared AppHarness mock setup + monkeypatch: pytest monkeypatch fixture + """ + monkeypatch.setattr(reflex_cli, "_init", lambda **kwargs: None) + CUSTOM_COMPONENTS["FooComponent"] = mock.sentinel.component EXPERIMENTAL_MEMOS["format_value"] = mock.sentinel.memo - assert "FooComponent" in CUSTOM_COMPONENTS - assert "format_value" in EXPERIMENTAL_MEMOS - harness = AppHarness.create( root=tmp_path / "memo_app", app_source="import reflex as rx\napp = rx.App()", @@ -107,43 +109,32 @@ def test_app_harness_initialize_clears_memo_registries( assert "FooComponent" not in CUSTOM_COMPONENTS assert "format_value" not in EXPERIMENTAL_MEMOS - get_and_validate_app.assert_called_once_with(reload=True) + harness_mocks.get_and_validate_app.assert_called_once_with(reload=True) def test_app_harness_initialize_reloads_existing_imported_app( - tmp_path, preserve_memo_registries, monkeypatch + tmp_path, preserve_memo_registries, harness_mocks, monkeypatch ): """Ensure pre-existing imported apps are reloaded after memo registry reset. Args: tmp_path: pytest tmp_path fixture preserve_memo_registries: restores global memo registries after the test + harness_mocks: shared AppHarness mock setup monkeypatch: pytest monkeypatch fixture """ - fake_config = SimpleNamespace(loglevel=None, module="plain_app.plain_app") - fake_app = mock.Mock(_state_manager=None) - get_and_validate_app = mock.Mock( - return_value=reflex.utils.prerequisites.AppInfo( - app=fake_app, - module=ModuleType(fake_config.module), - ) - ) - - monkeypatch.setattr(reflex_testing, "get_config", lambda: fake_config) - monkeypatch.setattr(reflex.config, "get_config", lambda reload=False: fake_config) monkeypatch.setattr( reflex.utils.prerequisites, "initialize_frontend_dependencies", lambda: None, ) - monkeypatch.setattr( - reflex.utils.prerequisites, - "get_and_validate_app", - get_and_validate_app, + monkeypatch.setitem( + sys.modules, + harness_mocks.config.module, + ModuleType(harness_mocks.config.module), ) - monkeypatch.setitem(sys.modules, fake_config.module, ModuleType(fake_config.module)) harness = AppHarness.create(root=tmp_path / "plain_app") harness._initialize_app() - get_and_validate_app.assert_called_once_with(reload=True) + harness_mocks.get_and_validate_app.assert_called_once_with(reload=True) From 15015418367b1109f736c17ed91647683305ed63 Mon Sep 17 00:00:00 2001 From: Farhan Ali Raza Date: Fri, 20 Mar 2026 13:24:59 +0500 Subject: [PATCH 05/10] pyi: update hashes --- pyi_hashes.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyi_hashes.json b/pyi_hashes.json index a81f3167afb..d5d699547c5 100644 --- a/pyi_hashes.json +++ b/pyi_hashes.json @@ -118,6 +118,6 @@ "reflex/components/recharts/general.pyi": "d87ff9b85b2a204be01753690df4fb11", "reflex/components/recharts/polar.pyi": "b8b1a3e996e066facdf4f8c9eb363137", "reflex/components/recharts/recharts.pyi": "d5c9fc57a03b419748f0408c23319eee", - "reflex/components/sonner/toast.pyi": "3c27bad1aaeb5183eaa6a41e77e8d7f0", + "reflex/components/sonner/toast.pyi": "dca44901640cda9d58c62ff8434faa3e", "reflex/experimental/memo.pyi": "a1c5c4682fc4dadbd82a0a5e8fd4bd32" } From 24f96983afece695a23b16e4ddb185fe0dc67d67 Mon Sep 17 00:00:00 2001 From: Farhan Ali Raza Date: Wed, 25 Mar 2026 13:53:43 +0500 Subject: [PATCH 06/10] fix: camelCase rest-prop keys in memo function bindings and clean up memo internals MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Convert remaining_props keys to camelCase in _bind_function_runtime_args so rest props (e.g. class_name → className) match the component memo behavior. Also make MemoParam kw_only, return a tuple from get_props instead of a dict, and remove unnecessary monkeypatch boilerplate from the integration test fixture. --- pyi_hashes.json | 2 +- reflex/experimental/memo.py | 21 ++++++++++++++++++--- tests/integration/test_experimental_memo.py | 12 +----------- tests/units/experimental/test_memo.py | 8 ++++++-- 4 files changed, 26 insertions(+), 17 deletions(-) diff --git a/pyi_hashes.json b/pyi_hashes.json index d5d699547c5..a1b362948e0 100644 --- a/pyi_hashes.json +++ b/pyi_hashes.json @@ -119,5 +119,5 @@ "reflex/components/recharts/polar.pyi": "b8b1a3e996e066facdf4f8c9eb363137", "reflex/components/recharts/recharts.pyi": "d5c9fc57a03b419748f0408c23319eee", "reflex/components/sonner/toast.pyi": "dca44901640cda9d58c62ff8434faa3e", - "reflex/experimental/memo.pyi": "a1c5c4682fc4dadbd82a0a5e8fd4bd32" + "reflex/experimental/memo.pyi": "87dc52b4ffa791d54c85816016445790" } diff --git a/reflex/experimental/memo.py b/reflex/experimental/memo.py index 7115e018d23..f315da7371f 100644 --- a/reflex/experimental/memo.py +++ b/reflex/experimental/memo.py @@ -30,7 +30,7 @@ from reflex.vars.object import RestProp -@dataclasses.dataclass(frozen=True, slots=True) +@dataclasses.dataclass(frozen=True, slots=True, kw_only=True) class MemoParam: """Metadata about a memo parameter.""" @@ -112,7 +112,7 @@ def _post_init(self, **kwargs): props[camel_cased_key] = literal_value setattr(self, camel_cased_key, literal_value) - prop_names = dict.fromkeys(props) + prop_names = tuple(props) object.__setattr__(self, "get_props", lambda: prop_names) @@ -634,6 +634,8 @@ def _bind_function_runtime_args( """ children_param = _get_children_param(definition.params) rest_param = _get_rest_param(definition.params) + + # Validate positional children usage and reserved keywords. if "children" in kwargs: msg = f"`{definition.python_name}` only accepts children positionally." raise TypeError(msg) @@ -656,6 +658,7 @@ def _bind_function_runtime_args( ) raise TypeError(msg) + # Bind declared props before collecting any rest props. explicit_params = [ param for param in definition.params @@ -672,6 +675,7 @@ def _bind_function_runtime_args( msg = f"`{definition.python_name}` is missing required prop `{param.name}`." raise TypeError(msg) + # Reject unknown props unless a rest prop is declared. if remaining_props and rest_param is None: unexpected_prop = next(iter(remaining_props)) msg = ( @@ -680,18 +684,25 @@ def _bind_function_runtime_args( ) raise TypeError(msg) + # Return ordered explicit args when no packed props object is needed. if children_param is None and rest_param is None: return tuple(explicit_values[param.name] for param in explicit_params) + # Build the props object passed to the imported FunctionVar. children_value: Any | None = None if children_param is not None: children_value = args[0] if len(args) == 1 else Fragment.create(*args) + # Convert rest-prop keys to camelCase to match component memo behavior. + camel_cased_remaining_props = { + format.to_camel_case(key): value for key, value in remaining_props.items() + } + bound_props = {} if children_param is not None: bound_props[children_param.name] = children_value bound_props.update(explicit_values) - bound_props.update(remaining_props) + bound_props.update(camel_cased_remaining_props) return (bound_props,) @@ -761,6 +772,7 @@ def _create_component_wrapper( @wraps(definition.fn) def wrapper(*children: Any, **props: Any) -> ExperimentalMemoComponent: + # Validate positional children usage and reserved keywords. if "children" in props: msg = f"`{definition.python_name}` only accepts children positionally." raise TypeError(msg) @@ -780,6 +792,7 @@ def wrapper(*children: Any, **props: Any) -> ExperimentalMemoComponent: ) raise TypeError(msg) + # Bind declared props before collecting any rest props. explicit_values = {} remaining_props = props.copy() for param in explicit_params: @@ -791,6 +804,7 @@ def wrapper(*children: Any, **props: Any) -> ExperimentalMemoComponent: msg = f"`{definition.python_name}` is missing required prop `{param.name}`." raise TypeError(msg) + # Reject unknown props unless a rest prop is declared. if remaining_props and rest_param is None: unexpected_prop = next(iter(remaining_props)) msg = ( @@ -799,6 +813,7 @@ def wrapper(*children: Any, **props: Any) -> ExperimentalMemoComponent: ) raise TypeError(msg) + # Build the component props passed into the memo wrapper. return ExperimentalMemoComponent._create( children=list(children), memo_definition=definition, diff --git a/tests/integration/test_experimental_memo.py b/tests/integration/test_experimental_memo.py index f63be1804c1..935f8c75ce6 100644 --- a/tests/integration/test_experimental_memo.py +++ b/tests/integration/test_experimental_memo.py @@ -5,9 +5,6 @@ import pytest from selenium.webdriver.common.by import By -import reflex.app as reflex_app -import reflex.state as reflex_state -from reflex import constants from reflex.testing import AppHarness @@ -84,22 +81,15 @@ def index() -> rx.Component: @pytest.fixture -def experimental_memo_app(tmp_path, monkeypatch) -> Generator[AppHarness, None, None]: +def experimental_memo_app(tmp_path) -> Generator[AppHarness, None, None]: """Start ExperimentalMemoApp app at tmp_path via AppHarness. Args: tmp_path: pytest tmp_path fixture. - monkeypatch: pytest monkeypatch fixture. Yields: Running AppHarness instance. """ - monkeypatch.setenv( - constants.PYTEST_CURRENT_TEST, - "tests/integration/test_experimental_memo.py::test_experimental_memo_app", - ) - monkeypatch.setattr(reflex_app, "is_testing_env", lambda: True) - monkeypatch.setattr(reflex_state, "is_testing_env", lambda: True) with AppHarness.create( root=tmp_path, app_source=ExperimentalMemoApp, diff --git a/tests/units/experimental/test_memo.py b/tests/units/experimental/test_memo.py index 2b6fc1fb485..006baae95bd 100644 --- a/tests/units/experimental/test_memo.py +++ b/tests/units/experimental/test_memo.py @@ -79,15 +79,18 @@ def my_card( rx.text("child 1"), rx.text("child 2"), title="Hello", + foo="extra", class_name="extra", ) assert isinstance(component, ExperimentalMemoComponent) assert len(component.children) == 2 + assert component.get_props() == ("title", "foo") rendered = component.render() assert rendered["name"] == "MyCard" assert 'title:"Hello"' in rendered["props"] + assert 'foo:"extra"' in rendered["props"] assert 'className:"extra"' in rendered["props"] definition = EXPERIMENTAL_MEMOS["MyCard"] @@ -113,11 +116,12 @@ def merge_styles( return base.to(dict).merge(overrides) base = Var(_js_expr="base", _var_type=dict[str, str]) - merged = merge_styles(base=base, color="red") + merged = merge_styles(base=base, color="red", class_name="primary") assert "merge_styles" in str(merged) assert '["base"] : base' in str(merged) assert '["color"] : "red"' in str(merged) + assert '["className"] : "primary"' in str(merged) _, code, _ = compiler.compile_memo_components( (), tuple(EXPERIMENTAL_MEMOS.values()) @@ -151,7 +155,7 @@ def label_slot( assert "label_slot" in str(rendered) assert '["children"]' in str(rendered) - assert '["class_name"] : "slot"' in str(rendered) + assert '["className"] : "slot"' in str(rendered) _, code, _ = compiler.compile_memo_components( (), tuple(EXPERIMENTAL_MEMOS.values()) From e0385921e67c16b1e037031f4ae19153f1ba86b8 Mon Sep 17 00:00:00 2001 From: Farhan Ali Raza Date: Wed, 25 Mar 2026 13:59:46 +0500 Subject: [PATCH 07/10] refactor: replace memo wrapper closures with proper callable classes Replace _create_function_wrapper and _create_component_wrapper closures with _ExperimentalMemoFunctionWrapper and _ExperimentalMemoComponentWrapper classes, eliminating object.__setattr__ hacks for call/partial/_as_var in favor of real methods. --- pyi_hashes.json | 2 +- reflex/experimental/memo.py | 166 ++++++++++++++++++++++++++---------- 2 files changed, 121 insertions(+), 47 deletions(-) diff --git a/pyi_hashes.json b/pyi_hashes.json index a1b362948e0..46d16eb3250 100644 --- a/pyi_hashes.json +++ b/pyi_hashes.json @@ -119,5 +119,5 @@ "reflex/components/recharts/polar.pyi": "b8b1a3e996e066facdf4f8c9eb363137", "reflex/components/recharts/recharts.pyi": "d5c9fc57a03b419748f0408c23319eee", "reflex/components/sonner/toast.pyi": "dca44901640cda9d58c62ff8434faa3e", - "reflex/experimental/memo.pyi": "87dc52b4ffa791d54c85816016445790" + "reflex/experimental/memo.pyi": "aab342a879269a0a7b67df46a7913cb7" } diff --git a/reflex/experimental/memo.py b/reflex/experimental/memo.py index f315da7371f..9139fcb2dc9 100644 --- a/reflex/experimental/memo.py +++ b/reflex/experimental/memo.py @@ -5,7 +5,7 @@ import dataclasses import inspect from collections.abc import Callable -from functools import wraps +from functools import update_wrapper from typing import Any, get_args, get_origin, get_type_hints from reflex import constants @@ -721,57 +721,100 @@ def _is_component_child(value: Any) -> bool: ) -def _create_function_wrapper( - definition: ExperimentalMemoFunctionDefinition, -) -> Callable[..., Var]: - """Create the Python wrapper for a var-returning memo. +class _ExperimentalMemoFunctionWrapper: + """Callable wrapper for a var-returning experimental memo.""" - Args: - definition: The function memo definition. + def __init__(self, definition: ExperimentalMemoFunctionDefinition): + """Initialize the wrapper. - Returns: - The wrapper callable. - """ - imported_var = definition.imported_var + Args: + definition: The function memo definition. + """ + self._definition = definition + self._imported_var = definition.imported_var + update_wrapper(self, definition.fn) - @wraps(definition.fn) - def wrapper(*args: Any, **kwargs: Any) -> Var: - return imported_var.call( - *_bind_function_runtime_args(definition, *args, **kwargs) + def __call__(self, *args: Any, **kwargs: Any) -> Var: + """Call the wrapped memo and return a var. + + Args: + *args: Positional children, if supported. + **kwargs: Explicit props and rest props. + + Returns: + The function call var. + """ + return self.call(*args, **kwargs) + + def call(self, *args: Any, **kwargs: Any) -> Var: + """Call the imported memo function. + + Args: + *args: Positional children, if supported. + **kwargs: Explicit props and rest props. + + Returns: + The function call var. + """ + return self._imported_var.call( + *_bind_function_runtime_args(self._definition, *args, **kwargs) ) - def partial(*args: Any, **kwargs: Any) -> FunctionVar: - return imported_var.partial( - *_bind_function_runtime_args(definition, *args, **kwargs) + def partial(self, *args: Any, **kwargs: Any) -> FunctionVar: + """Partially apply the imported memo function. + + Args: + *args: Positional children, if supported. + **kwargs: Explicit props and rest props. + + Returns: + The partially applied function var. + """ + return self._imported_var.partial( + *_bind_function_runtime_args(self._definition, *args, **kwargs) ) - object.__setattr__(wrapper, "call", wrapper) - object.__setattr__(wrapper, "partial", partial) - object.__setattr__(wrapper, "_as_var", lambda: imported_var) - return wrapper + def _as_var(self) -> FunctionVar: + """Expose the imported function var. + Returns: + The imported function var. + """ + return self._imported_var -def _create_component_wrapper( - definition: ExperimentalMemoComponentDefinition, -) -> Callable[..., ExperimentalMemoComponent]: - """Create the Python wrapper for a component-returning memo. - Args: - definition: The component memo definition. +class _ExperimentalMemoComponentWrapper: + """Callable wrapper for a component-returning experimental memo.""" - Returns: - The wrapper callable. - """ - children_param = _get_children_param(definition.params) - rest_param = _get_rest_param(definition.params) - explicit_params = [ - param - for param in definition.params - if not param.is_children and not param.is_rest - ] + def __init__(self, definition: ExperimentalMemoComponentDefinition): + """Initialize the wrapper. + + Args: + definition: The component memo definition. + """ + self._definition = definition + self._children_param = _get_children_param(definition.params) + self._rest_param = _get_rest_param(definition.params) + self._explicit_params = [ + param + for param in definition.params + if not param.is_children and not param.is_rest + ] + update_wrapper(self, definition.fn) + + def __call__(self, *children: Any, **props: Any) -> ExperimentalMemoComponent: + """Call the wrapped memo and return a component. + + Args: + *children: Positional children passed to the memo. + **props: Explicit props and rest props. + + Returns: + The rendered memo component. + """ + definition = self._definition + rest_param = self._rest_param - @wraps(definition.fn) - def wrapper(*children: Any, **props: Any) -> ExperimentalMemoComponent: # Validate positional children usage and reserved keywords. if "children" in props: msg = f"`{definition.python_name}` only accepts children positionally." @@ -782,7 +825,7 @@ def wrapper(*children: Any, **props: Any) -> ExperimentalMemoComponent: f"arguments. Do not pass `{rest_param.name}=` directly." ) raise TypeError(msg) - if children and children_param is None: + if children and self._children_param is None: msg = f"`{definition.python_name}` only accepts keyword props." raise TypeError(msg) if any(not _is_component_child(child) for child in children): @@ -795,7 +838,7 @@ def wrapper(*children: Any, **props: Any) -> ExperimentalMemoComponent: # Bind declared props before collecting any rest props. explicit_values = {} remaining_props = props.copy() - for param in explicit_params: + for param in self._explicit_params: if param.name in remaining_props: explicit_values[param.name] = remaining_props.pop(param.name) elif param.default is not inspect.Parameter.empty: @@ -821,10 +864,41 @@ def wrapper(*children: Any, **props: Any) -> ExperimentalMemoComponent: **remaining_props, ) - object.__setattr__( - wrapper, "_as_var", lambda: _component_import_var(definition.export_name) - ) - return wrapper + def _as_var(self) -> Var: + """Expose the imported component var. + + Returns: + The imported component var. + """ + return _component_import_var(self._definition.export_name) + + +def _create_function_wrapper( + definition: ExperimentalMemoFunctionDefinition, +) -> _ExperimentalMemoFunctionWrapper: + """Create the Python wrapper for a var-returning memo. + + Args: + definition: The function memo definition. + + Returns: + The wrapper callable. + """ + return _ExperimentalMemoFunctionWrapper(definition) + + +def _create_component_wrapper( + definition: ExperimentalMemoComponentDefinition, +) -> _ExperimentalMemoComponentWrapper: + """Create the Python wrapper for a component-returning memo. + + Args: + definition: The component memo definition. + + Returns: + The wrapper callable. + """ + return _ExperimentalMemoComponentWrapper(definition) def memo(fn: Callable[..., Any]) -> Callable[..., Any]: From 668b8ac8798b1222c0f94d0e292f5aee80147d8a Mon Sep 17 00:00:00 2001 From: Farhan Ali Raza Date: Wed, 25 Mar 2026 14:05:25 +0500 Subject: [PATCH 08/10] updated hashes --- pyi_hashes.json | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/pyi_hashes.json b/pyi_hashes.json index 29f3f3ac23e..aae35c45ddb 100644 --- a/pyi_hashes.json +++ b/pyi_hashes.json @@ -113,11 +113,11 @@ "reflex/components/react_player/video.pyi": "998671c06103d797c554d9278eb3b2a0", "reflex/components/react_router/dom.pyi": "2198359c2d8f3d1856f4391978a4e2de", "reflex/components/recharts/__init__.pyi": "6ee7f1ca2c0912f389ba6f3251a74d99", - "reflex/components/recharts/cartesian.pyi": "d138261ab8259d5208c2f028b9f708bd", - "reflex/components/recharts/charts.pyi": "013036b9c00ad85a570efdb813c1bc40", - "reflex/components/recharts/general.pyi": "d87ff9b85b2a204be01753690df4fb11", - "reflex/components/recharts/polar.pyi": "b8b1a3e996e066facdf4f8c9eb363137", - "reflex/components/recharts/recharts.pyi": "d5c9fc57a03b419748f0408c23319eee", - "reflex/components/sonner/toast.pyi": "dca44901640cda9d58c62ff8434faa3e", - "reflex/experimental/memo.pyi": "aab342a879269a0a7b67df46a7913cb7" + "reflex/components/recharts/cartesian.pyi": "4dc01da3195f80b9408d84373b874c41", + "reflex/components/recharts/charts.pyi": "16cd435d77f06f0315b595b7e62cf44b", + "reflex/components/recharts/general.pyi": "9abf71810a5405fd45b13804c0a7fd1a", + "reflex/components/recharts/polar.pyi": "ea4743e8903365ba95bc4b653c47cc4a", + "reflex/components/recharts/recharts.pyi": "b3d93d085d51053bbb8f65326f34a299", + "reflex/components/sonner/toast.pyi": "636050fcc919f8ab0903c30dceaa18f1", + "reflex/experimental/memo.pyi": "50be0b7fe796412f60f3bd1ead3f829a" } From 3642e5fad82230001aa27a27831d35865b5a12e1 Mon Sep 17 00:00:00 2001 From: Farhan Ali Raza Date: Wed, 25 Mar 2026 14:17:54 +0500 Subject: [PATCH 09/10] fix: accept Var[Component] return from component-returning memos Extract _normalize_component_return to wrap Var[Component] values in Bare.create, allowing memos that return rx.cond or other component-typed vars to be registered as component memos. Add a cond overload for (Any, Var[Component], Var[Component]) -> Component. --- reflex/components/core/cond.py | 4 ++++ reflex/experimental/memo.py | 26 +++++++++++++++++++++++--- tests/units/experimental/test_memo.py | 24 ++++++++++++++++++++++++ 3 files changed, 51 insertions(+), 3 deletions(-) diff --git a/reflex/components/core/cond.py b/reflex/components/core/cond.py index eefcc04ef9d..209e414a5af 100644 --- a/reflex/components/core/cond.py +++ b/reflex/components/core/cond.py @@ -98,6 +98,10 @@ def cond(condition: Any, c1: Component, c2: Any, /) -> Component: ... # pyright def cond(condition: Any, c1: Component, /) -> Component: ... +@overload +def cond(condition: Any, c1: Var[Component], c2: Var[Component], /) -> Component: ... # pyright: ignore [reportOverlappingOverload] + + @overload def cond(condition: Any, c1: Any, c2: Component, /) -> Component: ... # pyright: ignore [reportOverlappingOverload] diff --git a/reflex/experimental/memo.py b/reflex/experimental/memo.py index 9139fcb2dc9..58c3bd8ce95 100644 --- a/reflex/experimental/memo.py +++ b/reflex/experimental/memo.py @@ -408,6 +408,26 @@ def _evaluate_memo_function( return fn(*positional_args, **keyword_args) +def _normalize_component_return(value: Any) -> Component | None: + """Normalize a component-like memo return value into a Component. + + Args: + value: The value returned from the memo function. + + Returns: + The normalized component, or ``None`` if the value is not component-like. + """ + if isinstance(value, Component): + return value + + if isinstance(value, Var) and type_utils.typehint_issubclass( + value._var_type, Component + ): + return Bare.create(value) + + return None + + def _lift_rest_props(component: Component) -> Component: """Convert RestProp children into special props. @@ -597,11 +617,11 @@ def _create_component_definition( TypeError: If the function does not return a component. """ params = _analyze_params(fn, for_component=True) - component = _evaluate_memo_function(fn, params) - if not isinstance(component, Component): + component = _normalize_component_return(_evaluate_memo_function(fn, params)) + if component is None: msg = ( f"Component-returning `@rx._x.memo` `{fn.__name__}` must return an " - f"`rx.Component`, got `{type(component).__name__}`." + "`rx.Component` or `rx.Var[rx.Component]`." ) raise TypeError(msg) diff --git a/tests/units/experimental/test_memo.py b/tests/units/experimental/test_memo.py index 006baae95bd..a67aa2d9ad2 100644 --- a/tests/units/experimental/test_memo.py +++ b/tests/units/experimental/test_memo.py @@ -105,6 +105,30 @@ def my_card( assert "jsx(RadixThemesBox,{...rest}" in code +def test_component_returning_memo_accepts_component_var_result(): + """Component-returning memos should accept component-typed var results.""" + + @rx._x.memo + def conditional_slot( + show: rx.Var[bool], + first: rx.Var[rx.Component], + second: rx.Var[rx.Component], + ) -> rx.Component: + return rx.cond(show, first, second) + + definition = EXPERIMENTAL_MEMOS["ConditionalSlot"] + assert isinstance(definition, ExperimentalMemoComponentDefinition) + assert definition.component.render() == { + "contents": "(showRxMemo ? firstRxMemo : secondRxMemo)" + } + + _, code, _ = compiler.compile_memo_components( + (), tuple(EXPERIMENTAL_MEMOS.values()) + ) + assert "export const ConditionalSlot = memo(({show:showRxMemo" in code + assert "(showRxMemo ? firstRxMemo : secondRxMemo)" in code + + def test_var_returning_memo_with_rest_props(): """Var-returning memos should capture extra keyword args into RestProp.""" From 4e434f052f0c6cd2d34b322c0165c5547749350d Mon Sep 17 00:00:00 2001 From: Farhan Ali Raza Date: Wed, 25 Mar 2026 14:26:42 +0500 Subject: [PATCH 10/10] refactor: create per-memo component subclasses with tag set at class level Replace instance-level self.tag assignment with cached dynamically created ExperimentalMemoComponent subclasses via _get_experimental_memo_component_class, so the tag is a class-level attribute rather than set in _post_init. --- pyi_hashes.json | 2 +- reflex/experimental/memo.py | 28 +++++++++++++++++++++++---- tests/units/experimental/test_memo.py | 4 ++++ 3 files changed, 29 insertions(+), 5 deletions(-) diff --git a/pyi_hashes.json b/pyi_hashes.json index aae35c45ddb..5d5bd8fb037 100644 --- a/pyi_hashes.json +++ b/pyi_hashes.json @@ -119,5 +119,5 @@ "reflex/components/recharts/polar.pyi": "ea4743e8903365ba95bc4b653c47cc4a", "reflex/components/recharts/recharts.pyi": "b3d93d085d51053bbb8f65326f34a299", "reflex/components/sonner/toast.pyi": "636050fcc919f8ab0903c30dceaa18f1", - "reflex/experimental/memo.pyi": "50be0b7fe796412f60f3bd1ead3f829a" + "reflex/experimental/memo.pyi": "78b1968972194785f72eab32476bc61d" } diff --git a/reflex/experimental/memo.py b/reflex/experimental/memo.py index 58c3bd8ce95..ee1912d1757 100644 --- a/reflex/experimental/memo.py +++ b/reflex/experimental/memo.py @@ -5,7 +5,7 @@ import dataclasses import inspect from collections.abc import Callable -from functools import update_wrapper +from functools import cache, update_wrapper from typing import Any, get_args, get_origin, get_type_hints from reflex import constants @@ -103,8 +103,6 @@ def _post_init(self, **kwargs): super()._post_init(**kwargs) - self.tag = definition.export_name - props: dict[str, Any] = {} for key, value in {**declared_props, **rest_props}.items(): camel_cased_key = format.to_camel_case(key) @@ -116,6 +114,28 @@ def _post_init(self, **kwargs): object.__setattr__(self, "get_props", lambda: prop_names) +@cache +def _get_experimental_memo_component_class( + export_name: str, +) -> type[ExperimentalMemoComponent]: + """Get the component subclass for an experimental memo export. + + Args: + export_name: The exported React component name. + + Returns: + A cached component subclass with the tag set at class definition time. + """ + return type( + f"ExperimentalMemoComponent_{export_name}", + (ExperimentalMemoComponent,), + { + "__module__": __name__, + "tag": export_name, + }, + ) + + EXPERIMENTAL_MEMOS: dict[str, ExperimentalMemoDefinition] = {} @@ -877,7 +897,7 @@ def __call__(self, *children: Any, **props: Any) -> ExperimentalMemoComponent: raise TypeError(msg) # Build the component props passed into the memo wrapper. - return ExperimentalMemoComponent._create( + return _get_experimental_memo_component_class(definition.export_name)._create( children=list(children), memo_definition=definition, **explicit_values, diff --git a/tests/units/experimental/test_memo.py b/tests/units/experimental/test_memo.py index a67aa2d9ad2..7de8bc012e9 100644 --- a/tests/units/experimental/test_memo.py +++ b/tests/units/experimental/test_memo.py @@ -82,10 +82,14 @@ def my_card( foo="extra", class_name="extra", ) + component_again = my_card(title="World") assert isinstance(component, ExperimentalMemoComponent) assert len(component.children) == 2 assert component.get_props() == ("title", "foo") + assert type(component) is type(component_again) + assert type(component).tag == "MyCard" + assert type(component).get_fields()["tag"].default == "MyCard" rendered = component.render() assert rendered["name"] == "MyCard"