diff --git a/pyi_hashes.json b/pyi_hashes.json index 152aa676fb0..d5d699547c5 100644 --- a/pyi_hashes.json +++ b/pyi_hashes.json @@ -1,5 +1,5 @@ { - "reflex/__init__.pyi": "0a3ae880e256b9fd3b960e12a2cb51a7", + "reflex/__init__.pyi": "224964f24351c79614ab9a4ae47560a0", "reflex/components/__init__.pyi": "ac05995852baa81062ba3d18fbc489fb", "reflex/components/base/__init__.pyi": "16e47bf19e0d62835a605baa3d039c5a", "reflex/components/base/app_wrap.pyi": "22e94feaa9fe675bcae51c412f5b67f1", @@ -118,5 +118,6 @@ "reflex/components/recharts/general.pyi": "d87ff9b85b2a204be01753690df4fb11", "reflex/components/recharts/polar.pyi": "b8b1a3e996e066facdf4f8c9eb363137", "reflex/components/recharts/recharts.pyi": "d5c9fc57a03b419748f0408c23319eee", - "reflex/components/sonner/toast.pyi": "dca44901640cda9d58c62ff8434faa3e" + "reflex/components/sonner/toast.pyi": "dca44901640cda9d58c62ff8434faa3e", + "reflex/experimental/memo.pyi": "a1c5c4682fc4dadbd82a0a5e8fd4bd32" } diff --git a/reflex/__init__.py b/reflex/__init__.py index 066df110f02..29fc5983971 100644 --- a/reflex/__init__.py +++ b/reflex/__init__.py @@ -342,7 +342,7 @@ "utils.imports": ["ImportDict", "ImportVar"], "utils.misc": ["run_in_thread"], "utils.serializers": ["serializer"], - "vars": ["Var", "field", "Field"], + "vars": ["Var", "field", "Field", "RestProp"], } _SUBMODULES: set[str] = { diff --git a/reflex/app.py b/reflex/app.py index 9f4ef9c6229..98896561aaf 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -84,6 +84,7 @@ get_hydrate_event, noop, ) +from reflex.experimental.memo import EXPERIMENTAL_MEMOS from reflex.istate.manager import StateModificationContext from reflex.istate.proxy import StateProxy from reflex.page import DECORATED_PAGES @@ -1321,7 +1322,10 @@ def memoized_toast_provider(): memo_components_output, memo_components_result, memo_components_imports, - ) = compiler.compile_memo_components(dict.fromkeys(CUSTOM_COMPONENTS.values())) + ) = compiler.compile_memo_components( + dict.fromkeys(CUSTOM_COMPONENTS.values()), + tuple(EXPERIMENTAL_MEMOS.values()), + ) compile_results.append((memo_components_output, memo_components_result)) all_imports.update(memo_components_imports) progress.advance(task) diff --git a/reflex/compiler/compiler.py b/reflex/compiler/compiler.py index a097b2d99f2..c5f18cad07c 100644 --- a/reflex/compiler/compiler.py +++ b/reflex/compiler/compiler.py @@ -22,6 +22,11 @@ from reflex.constants.compiler import PageNames, ResetStylesheet from reflex.constants.state import FIELD_MARKER from reflex.environment import environment +from reflex.experimental.memo import ( + ExperimentalMemoComponentDefinition, + ExperimentalMemoDefinition, + ExperimentalMemoFunctionDefinition, +) from reflex.state import BaseState from reflex.style import SYSTEM_COLOR_MODE from reflex.utils import console, path_ops @@ -339,20 +344,20 @@ def _compile_component(component: Component | StatefulComponent) -> str: def _compile_memo_components( components: Iterable[CustomComponent], + experimental_memos: Iterable[ExperimentalMemoDefinition] = (), ) -> tuple[str, dict[str, list[ImportVar]]]: """Compile the components. Args: components: The components to compile. + experimental_memos: The experimental memos to compile. Returns: The compiled components. """ - imports = { - "react": [ImportVar(tag="memo")], - f"$/{constants.Dirs.STATE_PATH}": [ImportVar(tag="isTrue")], - } + imports: dict[str, list[ImportVar]] = {} component_renders = [] + function_renders = [] # Compile each component. for component in components: @@ -360,7 +365,25 @@ def _compile_memo_components( component_renders.append(component_render) imports = utils.merge_imports(imports, component_imports) - _apply_common_imports(imports) + for memo in experimental_memos: + if isinstance(memo, ExperimentalMemoComponentDefinition): + memo_render, memo_imports = utils.compile_experimental_component_memo(memo) + component_renders.append(memo_render) + imports = utils.merge_imports(imports, memo_imports) + elif isinstance(memo, ExperimentalMemoFunctionDefinition): + memo_render, memo_imports = utils.compile_experimental_function_memo(memo) + function_renders.append(memo_render) + imports = utils.merge_imports(imports, memo_imports) + + if component_renders: + imports = utils.merge_imports( + { + "react": [ImportVar(tag="memo")], + f"$/{constants.Dirs.STATE_PATH}": [ImportVar(tag="isTrue")], + }, + imports, + ) + _apply_common_imports(imports) dynamic_imports = { comp_import: None @@ -380,6 +403,7 @@ def _compile_memo_components( templates.memo_components_template( imports=utils.compile_imports(imports), components=component_renders, + functions=function_renders, dynamic_imports=sorted(dynamic_imports), custom_codes=custom_codes, ), @@ -573,11 +597,13 @@ def compile_page(path: str, component: BaseComponent) -> tuple[str, str]: def compile_memo_components( components: Iterable[CustomComponent], + experimental_memos: Iterable[ExperimentalMemoDefinition] = (), ) -> tuple[str, str, dict[str, list[ImportVar]]]: """Compile the custom components. Args: components: The custom components to compile. + experimental_memos: The experimental memos to compile. Returns: The path and code of the compiled components. @@ -586,7 +612,7 @@ def compile_memo_components( output_path = utils.get_components_path() # Compile the components. - code, imports = _compile_memo_components(components) + code, imports = _compile_memo_components(components, experimental_memos) return output_path, code, imports diff --git a/reflex/compiler/templates.py b/reflex/compiler/templates.py index a8a7dbe4ec2..1bf6c8ce061 100644 --- a/reflex/compiler/templates.py +++ b/reflex/compiler/templates.py @@ -8,7 +8,6 @@ from reflex import constants from reflex.constants import Hooks -from reflex.constants.state import CAMEL_CASE_MEMO_MARKER from reflex.utils.format import format_state_name, json_dumps from reflex.vars.base import VarData @@ -661,6 +660,7 @@ def stateful_components_template(imports: list[_ImportDict], memoized_code: str) def memo_components_template( imports: list[_ImportDict], components: list[dict[str, Any]], + functions: list[dict[str, Any]], dynamic_imports: Iterable[str], custom_codes: Iterable[str], ) -> str: @@ -669,6 +669,7 @@ def memo_components_template( Args: imports: List of import statements. components: List of component definitions. + functions: List of function definitions. dynamic_imports: List of dynamic import statements. custom_codes: List of custom code snippets. @@ -682,7 +683,7 @@ def memo_components_template( components_code = "" for component in components: components_code += f""" -export const {component["name"]} = memo(({{ {",".join([f"{prop}:{prop}{CAMEL_CASE_MEMO_MARKER}" for prop in component.get("props", [])])} }}) => {{ +export const {component["name"]} = memo(({component["signature"]}) => {{ {_render_hooks(component.get("hooks", {}))} return( {_RenderUtils.render(component["render"])} @@ -690,6 +691,12 @@ def memo_components_template( }}); """ + functions_code = "" + for function in functions: + functions_code += ( + f"\nexport const {function['name']} = {function['function']};\n" + ) + return f""" {imports_str} @@ -697,6 +704,8 @@ def memo_components_template( {custom_code_str} +{functions_code} + {components_code}""" diff --git a/reflex/compiler/utils.py b/reflex/compiler/utils.py index f02edacbd71..cd3d2f6ce2b 100644 --- a/reflex/compiler/utils.py +++ b/reflex/compiler/utils.py @@ -4,6 +4,7 @@ import asyncio import concurrent.futures +import copy import operator import traceback from collections.abc import Mapping, Sequence @@ -20,7 +21,11 @@ from reflex.components.el.elements.metadata import Head, Link, Meta, Title from reflex.components.el.elements.other import Html from reflex.components.el.elements.sectioning import Body -from reflex.constants.state import FIELD_MARKER +from reflex.constants.state import CAMEL_CASE_MEMO_MARKER, FIELD_MARKER +from reflex.experimental.memo import ( + ExperimentalMemoComponentDefinition, + ExperimentalMemoFunctionDefinition, +) from reflex.istate.storage import Cookie, LocalStorage, SessionStorage from reflex.state import BaseState, _resolve_delta from reflex.style import Style @@ -28,6 +33,7 @@ from reflex.utils.imports import ImportVar, ParsedImportDict from reflex.utils.prerequisites import get_web_dir from reflex.vars.base import Field, Var, VarData +from reflex.vars.function import DestructuredArg # To re-export this function. merge_imports = imports.merge_imports @@ -344,6 +350,9 @@ def compile_custom_component( { "name": component.tag, "props": props, + "signature": DestructuredArg( + fields=tuple(f"{prop}:{prop}{CAMEL_CASE_MEMO_MARKER}" for prop in props) + ).to_javascript(), "render": render.render(), "hooks": render._get_all_hooks(), "custom_code": render._get_all_custom_code(), @@ -353,6 +362,104 @@ def compile_custom_component( ) +def _apply_component_style_for_compile(component: Component) -> Component: + """Apply the app style to a compiled component tree. + + Args: + component: The component tree. + + Returns: + The styled component tree. + """ + try: + from reflex.utils.prerequisites import get_and_validate_app + + style = get_and_validate_app().app.style + except Exception: + style = {} + + component._add_style_recursive(style) + return component + + +def compile_experimental_component_memo( + definition: ExperimentalMemoComponentDefinition, +) -> tuple[dict, ParsedImportDict]: + """Compile an experimental memo component. + + Args: + definition: The component memo definition. + + Returns: + A tuple of the compiled component definition and its imports. + """ + render = _apply_component_style_for_compile(copy.deepcopy(definition.component)) + + imports: ParsedImportDict = { + lib: fields + for lib, fields in render._get_all_imports().items() + if lib != f"$/{constants.Dirs.COMPONENTS_PATH}" + } + + imports.setdefault("@emotion/react", []).append(ImportVar("jsx")) + + signature_fields = [ + f"{param.js_prop_name}:{param.placeholder_name}" + for param in definition.params + if not param.is_children and not param.is_rest + ] + + if any(param.is_children for param in definition.params): + signature_fields.insert(0, "children") + + rest_param = next((param for param in definition.params if param.is_rest), None) + + return ( + { + "kind": "component", + "name": definition.export_name, + "signature": DestructuredArg( + fields=tuple(signature_fields), + rest=rest_param.placeholder_name if rest_param is not None else None, + ).to_javascript(), + "render": render.render(), + "hooks": render._get_all_hooks(), + "custom_code": render._get_all_custom_code(), + "dynamic_imports": render._get_all_dynamic_imports(), + }, + imports, + ) + + +def compile_experimental_function_memo( + definition: ExperimentalMemoFunctionDefinition, +) -> tuple[dict, ParsedImportDict]: + """Compile an experimental memo function. + + Args: + definition: The function memo definition. + + Returns: + A tuple of the compiled function definition and its imports. + """ + imports: ParsedImportDict = {} + if var_data := definition.function._get_all_var_data(): + imports = { + lib: list(fields) + for lib, fields in dict(var_data.imports).items() + if lib != f"$/{constants.Dirs.COMPONENTS_PATH}" + } + + return ( + { + "kind": "function", + "name": definition.python_name, + "function": str(definition.function), + }, + imports, + ) + + def create_document_root( head_components: Sequence[Component] | None = None, html_lang: str | None = None, diff --git a/reflex/experimental/__init__.py b/reflex/experimental/__init__.py index 2734da19112..8ae5f3ee8a6 100644 --- a/reflex/experimental/__init__.py +++ b/reflex/experimental/__init__.py @@ -8,6 +8,7 @@ from . import hooks as hooks from .client_state import ClientStateVar as ClientStateVar +from .memo import memo as memo class ExperimentalNamespace(SimpleNamespace): @@ -58,4 +59,5 @@ def register_component_warning(component_name: str): client_state=ClientStateVar.create, hooks=hooks, code_block=code_block, + memo=memo, ) diff --git a/reflex/experimental/memo.py b/reflex/experimental/memo.py new file mode 100644 index 00000000000..7115e018d23 --- /dev/null +++ b/reflex/experimental/memo.py @@ -0,0 +1,850 @@ +"""Experimental memo support for vars and components.""" + +from __future__ import annotations + +import dataclasses +import inspect +from collections.abc import Callable +from functools import wraps +from typing import Any, get_args, get_origin, get_type_hints + +from reflex import constants +from reflex.components.base.bare import Bare +from reflex.components.base.fragment import Fragment +from reflex.components.component import Component +from reflex.components.dynamic import bundled_libraries +from reflex.constants.compiler import SpecialAttributes +from reflex.constants.state import CAMEL_CASE_MEMO_MARKER +from reflex.utils import format +from reflex.utils import types as type_utils +from reflex.utils.imports import ImportVar +from reflex.vars import VarData +from reflex.vars.base import LiteralVar, Var +from reflex.vars.function import ( + ArgsFunctionOperation, + DestructuredArg, + FunctionStringVar, + FunctionVar, + ReflexCallable, +) +from reflex.vars.object import RestProp + + +@dataclasses.dataclass(frozen=True, slots=True) +class MemoParam: + """Metadata about a memo parameter.""" + + name: str + annotation: Any + kind: inspect._ParameterKind + default: Any = inspect.Parameter.empty + js_prop_name: str | None = None + placeholder_name: str = "" + is_children: bool = False + is_rest: bool = False + + +@dataclasses.dataclass(frozen=True, slots=True) +class ExperimentalMemoDefinition: + """Base metadata for an experimental memo.""" + + fn: Callable[..., Any] + python_name: str + params: tuple[MemoParam, ...] + + +@dataclasses.dataclass(frozen=True, slots=True) +class ExperimentalMemoFunctionDefinition(ExperimentalMemoDefinition): + """A memo that compiles to a JavaScript function.""" + + function: ArgsFunctionOperation + imported_var: FunctionVar + + +@dataclasses.dataclass(frozen=True, slots=True) +class ExperimentalMemoComponentDefinition(ExperimentalMemoDefinition): + """A memo that compiles to a React component.""" + + export_name: str + component: Component + + +class ExperimentalMemoComponent(Component): + """A rendered instance of an experimental memo component.""" + + library = f"$/{constants.Dirs.COMPONENTS_PATH}" + + def _post_init(self, **kwargs): + """Initialize the experimental memo component. + + Args: + **kwargs: The kwargs to pass to the component. + """ + definition = kwargs.pop("memo_definition") + + explicit_props = { + param.name + for param in definition.params + if not param.is_children and not param.is_rest + } + component_fields = self.get_fields() + + declared_props = { + key: kwargs.pop(key) for key in list(kwargs) if key in explicit_props + } + + rest_props = {} + if _get_rest_param(definition.params) is not None: + rest_props = { + key: kwargs.pop(key) + for key in list(kwargs) + if key not in component_fields and not SpecialAttributes.is_special(key) + } + + super()._post_init(**kwargs) + + self.tag = definition.export_name + + props: dict[str, Any] = {} + for key, value in {**declared_props, **rest_props}.items(): + camel_cased_key = format.to_camel_case(key) + literal_value = LiteralVar.create(value) + props[camel_cased_key] = literal_value + setattr(self, camel_cased_key, literal_value) + + prop_names = dict.fromkeys(props) + object.__setattr__(self, "get_props", lambda: prop_names) + + +EXPERIMENTAL_MEMOS: dict[str, ExperimentalMemoDefinition] = {} + + +def _memo_registry_key(definition: ExperimentalMemoDefinition) -> str: + """Get the registry key for an experimental memo. + + Args: + definition: The memo definition. + + Returns: + The registry key for the memo. + """ + if isinstance(definition, ExperimentalMemoComponentDefinition): + return definition.export_name + return definition.python_name + + +def _is_memo_reregistration( + existing: ExperimentalMemoDefinition, + definition: ExperimentalMemoDefinition, +) -> bool: + """Check whether a memo definition replaces the same memo during reload. + + Args: + existing: The currently registered memo definition. + definition: The new memo definition being registered. + + Returns: + Whether the new definition should replace the existing one. + """ + return ( + type(existing) is type(definition) + and existing.python_name == definition.python_name + and existing.fn.__module__ == definition.fn.__module__ + and existing.fn.__qualname__ == definition.fn.__qualname__ + ) + + +def _register_memo_definition(definition: ExperimentalMemoDefinition) -> None: + """Register an experimental memo definition. + + Args: + definition: The memo definition to register. + + Raises: + ValueError: If another memo already compiles to the same exported name. + """ + key = _memo_registry_key(definition) + if (existing := EXPERIMENTAL_MEMOS.get(key)) is not None and ( + not _is_memo_reregistration(existing, definition) + ): + msg = ( + f"Experimental memo name collision for `{key}`: " + f"`{existing.fn.__module__}.{existing.python_name}` and " + f"`{definition.fn.__module__}.{definition.python_name}` both compile " + "to the same memo name." + ) + raise ValueError(msg) + + EXPERIMENTAL_MEMOS[key] = definition + + +def _annotation_inner_type(annotation: Any) -> Any: + """Unwrap a Var-like annotation to its inner type. + + Args: + annotation: The annotation to unwrap. + + Returns: + The inner type for the annotation. + """ + if _is_rest_annotation(annotation): + return dict[str, Any] + + origin = get_origin(annotation) or annotation + if type_utils.safe_issubclass(origin, Var) and (args := get_args(annotation)): + return args[0] + return Any + + +def _is_rest_annotation(annotation: Any) -> bool: + """Check whether an annotation is a RestProp. + + Args: + annotation: The annotation to check. + + Returns: + Whether the annotation is a RestProp. + """ + origin = get_origin(annotation) or annotation + return isinstance(origin, type) and issubclass(origin, RestProp) + + +def _is_var_annotation(annotation: Any) -> bool: + """Check whether an annotation is a Var-like annotation. + + Args: + annotation: The annotation to check. + + Returns: + Whether the annotation is Var-like. + """ + origin = get_origin(annotation) or annotation + return isinstance(origin, type) and issubclass(origin, Var) + + +def _is_component_annotation(annotation: Any) -> bool: + """Check whether an annotation is component-like. + + Args: + annotation: The annotation to check. + + Returns: + Whether the annotation resolves to Component. + """ + origin = get_origin(annotation) or annotation + return isinstance(origin, type) and issubclass(origin, Component) + + +def _children_annotation_is_valid(annotation: Any) -> bool: + """Check whether an annotation is valid for children. + + Args: + annotation: The annotation to check. + + Returns: + Whether the annotation is valid for children. + """ + return _is_var_annotation(annotation) and type_utils.typehint_issubclass( + _annotation_inner_type(annotation), Component + ) + + +def _get_children_param(params: tuple[MemoParam, ...]) -> MemoParam | None: + return next((param for param in params if param.is_children), None) + + +def _get_rest_param(params: tuple[MemoParam, ...]) -> MemoParam | None: + return next((param for param in params if param.is_rest), None) + + +def _imported_function_var(name: str, return_type: Any) -> FunctionVar: + """Create the imported FunctionVar for an experimental memo. + + Args: + name: The exported function name. + return_type: The return type of the function. + + Returns: + The imported FunctionVar. + """ + return FunctionStringVar.create( + name, + _var_type=ReflexCallable[Any, return_type], + _var_data=VarData( + imports={f"$/{constants.Dirs.COMPONENTS_PATH}": [ImportVar(tag=name)]} + ), + ) + + +def _component_import_var(name: str) -> Var: + """Create the imported component var for an experimental memo component. + + Args: + name: The exported component name. + + Returns: + The component var. + """ + return Var( + name, + _var_type=type[Component], + _var_data=VarData( + imports={ + f"$/{constants.Dirs.COMPONENTS_PATH}": [ImportVar(tag=name)], + "@emotion/react": [ImportVar(tag="jsx")], + } + ), + ) + + +def _validate_var_return_expr(return_expr: Var, func_name: str) -> None: + """Validate that a var-returning memo can compile safely. + + Args: + return_expr: The return expression. + func_name: The function name for error messages. + + Raises: + TypeError: If the return expression depends on unsupported features. + """ + var_data = VarData.merge(return_expr._get_all_var_data()) + if var_data is None: + return + + if var_data.hooks: + msg = ( + f"Var-returning `@rx._x.memo` `{func_name}` cannot depend on hooks. " + "Use a component-returning `@rx._x.memo` instead." + ) + raise TypeError(msg) + + if var_data.components: + msg = ( + f"Var-returning `@rx._x.memo` `{func_name}` cannot depend on embedded " + "components, custom code, or dynamic imports. Use a component-returning " + "`@rx._x.memo` instead." + ) + raise TypeError(msg) + + for lib in dict(var_data.imports): + if not lib: + continue + if lib.startswith((".", "/", "$/", "http")): + continue + if format.format_library_name(lib) in bundled_libraries: + continue + msg = ( + f"Var-returning `@rx._x.memo` `{func_name}` cannot import `{lib}` because " + "it is not bundled. Use a component-returning `@rx._x.memo` instead." + ) + raise TypeError(msg) + + +def _rest_placeholder(name: str) -> RestProp: + """Create the placeholder RestProp. + + Args: + name: The JavaScript identifier. + + Returns: + The placeholder rest prop. + """ + return RestProp(_js_expr=name, _var_type=dict[str, Any]) + + +def _var_placeholder(name: str, annotation: Any) -> Var: + """Create a placeholder Var for a memo parameter. + + Args: + name: The JavaScript identifier. + annotation: The parameter annotation. + + Returns: + The placeholder Var. + """ + return Var(_js_expr=name, _var_type=_annotation_inner_type(annotation)).guess_type() + + +def _placeholder_for_param(param: MemoParam) -> Var: + """Create a placeholder var for a parameter. + + Args: + param: The parameter metadata. + + Returns: + The placeholder var. + """ + if param.is_rest: + return _rest_placeholder(param.placeholder_name) + return _var_placeholder(param.placeholder_name, param.annotation) + + +def _evaluate_memo_function( + fn: Callable[..., Any], + params: tuple[MemoParam, ...], +) -> Any: + """Evaluate a memo function with placeholder vars. + + Args: + fn: The function to evaluate. + params: The memo parameters. + + Returns: + The return value from the function. + """ + positional_args = [] + keyword_args = {} + + for param in params: + placeholder = _placeholder_for_param(param) + if param.kind in ( + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + ): + positional_args.append(placeholder) + else: + keyword_args[param.name] = placeholder + + return fn(*positional_args, **keyword_args) + + +def _lift_rest_props(component: Component) -> Component: + """Convert RestProp children into special props. + + Args: + component: The component tree to rewrite. + + Returns: + The rewritten component tree. + """ + special_props = list(component.special_props) + rewritten_children = [] + + for child in component.children: + if isinstance(child, Bare) and isinstance(child.contents, RestProp): + special_props.append(child.contents) + continue + + if isinstance(child, Component): + child = _lift_rest_props(child) + + rewritten_children.append(child) + + component.children = rewritten_children + component.special_props = special_props + return component + + +def _analyze_params( + fn: Callable[..., Any], + *, + for_component: bool, +) -> tuple[MemoParam, ...]: + """Analyze and validate memo parameters. + + Args: + fn: The function to analyze. + for_component: Whether the memo returns a component. + + Returns: + The analyzed parameters. + + Raises: + TypeError: If the function signature is not supported. + """ + signature = inspect.signature(fn) + hints = get_type_hints(fn) + + params: list[MemoParam] = [] + rest_count = 0 + + for parameter in signature.parameters.values(): + if parameter.kind is inspect.Parameter.VAR_POSITIONAL: + msg = f"`@rx._x.memo` does not support `*args` in `{fn.__name__}`." + raise TypeError(msg) + if parameter.kind is inspect.Parameter.VAR_KEYWORD: + msg = f"`@rx._x.memo` does not support `**kwargs` in `{fn.__name__}`." + raise TypeError(msg) + if parameter.kind is inspect.Parameter.POSITIONAL_ONLY: + msg = ( + f"`@rx._x.memo` does not support positional-only parameters in " + f"`{fn.__name__}`." + ) + raise TypeError(msg) + + annotation = hints.get(parameter.name, parameter.annotation) + if annotation is inspect.Parameter.empty: + msg = ( + f"All parameters of `{fn.__name__}` must be annotated as `rx.Var[...]` " + f"or `rx.RestProp`. Missing annotation for `{parameter.name}`." + ) + raise TypeError(msg) + + is_rest = _is_rest_annotation(annotation) + is_children = parameter.name == "children" and _children_annotation_is_valid( + annotation + ) + + if parameter.name == "children" and not is_children: + msg = ( + f"`children` in `{fn.__name__}` must be annotated as " + "`rx.Var[rx.Component]`." + ) + raise TypeError(msg) + + if not is_rest and not _is_var_annotation(annotation): + msg = ( + f"All parameters of `{fn.__name__}` must be annotated as `rx.Var[...]` " + f"or `rx.RestProp`, got `{annotation}` for `{parameter.name}`." + ) + raise TypeError(msg) + + if is_rest: + rest_count += 1 + if rest_count > 1: + msg = ( + f"`@rx._x.memo` only supports one `rx.RestProp` in `{fn.__name__}`." + ) + raise TypeError(msg) + + js_prop_name = format.to_camel_case(parameter.name) + placeholder_name = ( + parameter.name + if is_children or is_rest or not for_component + else js_prop_name + CAMEL_CASE_MEMO_MARKER + ) + + params.append( + MemoParam( + name=parameter.name, + annotation=annotation, + kind=parameter.kind, + default=parameter.default, + js_prop_name=js_prop_name, + placeholder_name=placeholder_name, + is_children=is_children, + is_rest=is_rest, + ) + ) + + return tuple(params) + + +def _create_function_definition( + fn: Callable[..., Any], + return_annotation: Any, +) -> ExperimentalMemoFunctionDefinition: + """Create a definition for a var-returning memo. + + Args: + fn: The function to analyze. + return_annotation: The return annotation. + + Returns: + The function memo definition. + """ + params = _analyze_params(fn, for_component=False) + return_expr = Var.create(_evaluate_memo_function(fn, params)) + _validate_var_return_expr(return_expr, fn.__name__) + + children_param = _get_children_param(params) + rest_param = _get_rest_param(params) + if children_param is None and rest_param is None: + function = ArgsFunctionOperation.create( + args_names=tuple(param.placeholder_name for param in params), + return_expr=return_expr, + ) + else: + function = ArgsFunctionOperation.create( + args_names=( + DestructuredArg( + fields=tuple( + param.placeholder_name for param in params if not param.is_rest + ), + rest=( + rest_param.placeholder_name if rest_param is not None else None + ), + ), + ), + return_expr=return_expr, + ) + + return ExperimentalMemoFunctionDefinition( + fn=fn, + python_name=fn.__name__, + params=params, + function=function, + imported_var=_imported_function_var( + fn.__name__, _annotation_inner_type(return_annotation) + ), + ) + + +def _create_component_definition( + fn: Callable[..., Any], + return_annotation: Any, +) -> ExperimentalMemoComponentDefinition: + """Create a definition for a component-returning memo. + + Args: + fn: The function to analyze. + return_annotation: The return annotation. + + Returns: + The component memo definition. + + Raises: + TypeError: If the function does not return a component. + """ + params = _analyze_params(fn, for_component=True) + component = _evaluate_memo_function(fn, params) + if not isinstance(component, Component): + msg = ( + f"Component-returning `@rx._x.memo` `{fn.__name__}` must return an " + f"`rx.Component`, got `{type(component).__name__}`." + ) + raise TypeError(msg) + + return ExperimentalMemoComponentDefinition( + fn=fn, + python_name=fn.__name__, + params=params, + export_name=format.to_title_case(fn.__name__), + component=_lift_rest_props(component), + ) + + +def _bind_function_runtime_args( + definition: ExperimentalMemoFunctionDefinition, + *args: Any, + **kwargs: Any, +) -> tuple[Any, ...]: + """Bind runtime args for a var-returning memo. + + Args: + definition: The function memo definition. + *args: Positional arguments. + **kwargs: Keyword arguments. + + Returns: + The ordered arguments for the imported FunctionVar. + + Raises: + TypeError: If the provided arguments are invalid. + """ + children_param = _get_children_param(definition.params) + rest_param = _get_rest_param(definition.params) + if "children" in kwargs: + msg = f"`{definition.python_name}` only accepts children positionally." + raise TypeError(msg) + + if rest_param is not None and rest_param.name in kwargs: + msg = ( + f"`{definition.python_name}` captures rest props from extra keyword " + f"arguments. Do not pass `{rest_param.name}=` directly." + ) + raise TypeError(msg) + + if args and children_param is None: + msg = f"`{definition.python_name}` only accepts keyword props." + raise TypeError(msg) + + if any(not _is_component_child(child) for child in args): + msg = ( + f"`{definition.python_name}` only accepts positional children that are " + "`rx.Component` or `rx.Var[rx.Component]`." + ) + raise TypeError(msg) + + explicit_params = [ + param + for param in definition.params + if not param.is_rest and not param.is_children + ] + explicit_values = {} + remaining_props = kwargs.copy() + for param in explicit_params: + if param.name in remaining_props: + explicit_values[param.name] = remaining_props.pop(param.name) + elif param.default is not inspect.Parameter.empty: + explicit_values[param.name] = param.default + else: + msg = f"`{definition.python_name}` is missing required prop `{param.name}`." + raise TypeError(msg) + + if remaining_props and rest_param is None: + unexpected_prop = next(iter(remaining_props)) + msg = ( + f"`{definition.python_name}` does not accept prop `{unexpected_prop}`. " + "Only declared props may be passed when no `rx.RestProp` is present." + ) + raise TypeError(msg) + + if children_param is None and rest_param is None: + return tuple(explicit_values[param.name] for param in explicit_params) + + children_value: Any | None = None + if children_param is not None: + children_value = args[0] if len(args) == 1 else Fragment.create(*args) + + bound_props = {} + if children_param is not None: + bound_props[children_param.name] = children_value + bound_props.update(explicit_values) + bound_props.update(remaining_props) + return (bound_props,) + + +def _is_component_child(value: Any) -> bool: + """Check whether a value is valid as an experimental memo child. + + Args: + value: The value to check. + + Returns: + Whether the value is a component child. + """ + return isinstance(value, Component) or ( + isinstance(value, Var) + and type_utils.typehint_issubclass(value._var_type, Component) + ) + + +def _create_function_wrapper( + definition: ExperimentalMemoFunctionDefinition, +) -> Callable[..., Var]: + """Create the Python wrapper for a var-returning memo. + + Args: + definition: The function memo definition. + + Returns: + The wrapper callable. + """ + imported_var = definition.imported_var + + @wraps(definition.fn) + def wrapper(*args: Any, **kwargs: Any) -> Var: + return imported_var.call( + *_bind_function_runtime_args(definition, *args, **kwargs) + ) + + def partial(*args: Any, **kwargs: Any) -> FunctionVar: + return imported_var.partial( + *_bind_function_runtime_args(definition, *args, **kwargs) + ) + + object.__setattr__(wrapper, "call", wrapper) + object.__setattr__(wrapper, "partial", partial) + object.__setattr__(wrapper, "_as_var", lambda: imported_var) + return wrapper + + +def _create_component_wrapper( + definition: ExperimentalMemoComponentDefinition, +) -> Callable[..., ExperimentalMemoComponent]: + """Create the Python wrapper for a component-returning memo. + + Args: + definition: The component memo definition. + + Returns: + The wrapper callable. + """ + children_param = _get_children_param(definition.params) + rest_param = _get_rest_param(definition.params) + explicit_params = [ + param + for param in definition.params + if not param.is_children and not param.is_rest + ] + + @wraps(definition.fn) + def wrapper(*children: Any, **props: Any) -> ExperimentalMemoComponent: + if "children" in props: + msg = f"`{definition.python_name}` only accepts children positionally." + raise TypeError(msg) + if rest_param is not None and rest_param.name in props: + msg = ( + f"`{definition.python_name}` captures rest props from extra keyword " + f"arguments. Do not pass `{rest_param.name}=` directly." + ) + raise TypeError(msg) + if children and children_param is None: + msg = f"`{definition.python_name}` only accepts keyword props." + raise TypeError(msg) + if any(not _is_component_child(child) for child in children): + msg = ( + f"`{definition.python_name}` only accepts positional children that are " + "`rx.Component` or `rx.Var[rx.Component]`." + ) + raise TypeError(msg) + + explicit_values = {} + remaining_props = props.copy() + for param in explicit_params: + if param.name in remaining_props: + explicit_values[param.name] = remaining_props.pop(param.name) + elif param.default is not inspect.Parameter.empty: + explicit_values[param.name] = param.default + else: + msg = f"`{definition.python_name}` is missing required prop `{param.name}`." + raise TypeError(msg) + + if remaining_props and rest_param is None: + unexpected_prop = next(iter(remaining_props)) + msg = ( + f"`{definition.python_name}` does not accept prop `{unexpected_prop}`. " + "Only declared props may be passed when no `rx.RestProp` is present." + ) + raise TypeError(msg) + + return ExperimentalMemoComponent._create( + children=list(children), + memo_definition=definition, + **explicit_values, + **remaining_props, + ) + + object.__setattr__( + wrapper, "_as_var", lambda: _component_import_var(definition.export_name) + ) + return wrapper + + +def memo(fn: Callable[..., Any]) -> Callable[..., Any]: + """Create an experimental memo from a function. + + Args: + fn: The function to memoize. + + Returns: + The wrapped function or component factory. + + Raises: + TypeError: If the return type is not supported. + """ + hints = get_type_hints(fn) + return_annotation = hints.get("return", inspect.Signature.empty) + if return_annotation is inspect.Signature.empty: + msg = ( + f"`@rx._x.memo` requires a return annotation on `{fn.__name__}`. " + "Use `-> rx.Component` or `-> rx.Var[...]`." + ) + raise TypeError(msg) + + if _is_component_annotation(return_annotation): + definition = _create_component_definition(fn, return_annotation) + _register_memo_definition(definition) + return _create_component_wrapper(definition) + + if _is_var_annotation(return_annotation): + definition = _create_function_definition(fn, return_annotation) + _register_memo_definition(definition) + return _create_function_wrapper(definition) + + msg = ( + f"`@rx._x.memo` on `{fn.__name__}` must return `rx.Component` or `rx.Var[...]`, " + f"got `{return_annotation}`." + ) + raise TypeError(msg) diff --git a/reflex/testing.py b/reflex/testing.py index 5b5892ad33c..66e1f11b485 100644 --- a/reflex/testing.py +++ b/reflex/testing.py @@ -34,9 +34,10 @@ import reflex.utils.format import reflex.utils.prerequisites import reflex.utils.processes -from reflex.components.component import CustomComponent +from reflex.components.component import CUSTOM_COMPONENTS, CustomComponent from reflex.config import get_config from reflex.environment import environment +from reflex.experimental.memo import EXPERIMENTAL_MEMOS from reflex.istate.manager.disk import StateManagerDisk from reflex.istate.manager.memory import StateManagerMemory from reflex.istate.manager.redis import StateManagerRedis @@ -243,6 +244,10 @@ def _initialize_app(self): # disable telemetry reporting for tests os.environ["REFLEX_TELEMETRY_ENABLED"] = "false" + # Reset global memo registries so previous AppHarness apps do not + # leak compiled component definitions into the next test app. + CUSTOM_COMPONENTS.clear() + EXPERIMENTAL_MEMOS.clear() CustomComponent.create().get_component.cache_clear() self.app_path.mkdir(parents=True, exist_ok=True) if self.app_source is not None: @@ -269,15 +274,18 @@ def _initialize_app(self): reflex.utils.prerequisites.initialize_frontend_dependencies() with chdir(self.app_path): # ensure config and app are reloaded when testing different app - reflex.config.get_config(reload=True) + config = reflex.config.get_config(reload=True) # Ensure the AppHarness test does not skip State assignment due to running via pytest os.environ.pop(reflex.constants.PYTEST_CURRENT_TEST, None) os.environ[reflex.constants.APP_HARNESS_FLAG] = "true" - # Ensure we actually compile the app during first initialization. + # Ensure we compile generated apps, and reload pre-existing app modules + # that were already imported so they can re-register memo definitions. + should_reload_app = ( + self.app_source is not None or config.module in sys.modules + ) self.app_instance, self.app_module = ( reflex.utils.prerequisites.get_and_validate_app( - # Do not reload the module for pre-existing apps (only apps generated from source) - reload=self.app_source is not None + reload=should_reload_app ) ) self.app_asgi = self.app_instance() diff --git a/reflex/vars/__init__.py b/reflex/vars/__init__.py index c81e9a9bff3..c1989e5733d 100644 --- a/reflex/vars/__init__.py +++ b/reflex/vars/__init__.py @@ -17,7 +17,7 @@ from .datetime import DateTimeVar from .function import FunctionStringVar, FunctionVar, VarOperationCall from .number import BooleanVar, LiteralBooleanVar, LiteralNumberVar, NumberVar -from .object import LiteralObjectVar, ObjectVar +from .object import LiteralObjectVar, ObjectVar, RestProp from .sequence import ( ArrayVar, ConcatVarOperation, @@ -46,6 +46,7 @@ "LiteralVar", "NumberVar", "ObjectVar", + "RestProp", "StringVar", "Var", "VarData", diff --git a/reflex/vars/object.py b/reflex/vars/object.py index 1ecb09394ba..0bcc195a0fc 100644 --- a/reflex/vars/object.py +++ b/reflex/vars/object.py @@ -362,6 +362,10 @@ def contains(self, key: Var | Any) -> BooleanVar: return object_has_own_property_operation(self, key) +class RestProp(ObjectVar[dict[str, Any]]): + """A special object var representing forwarded rest props.""" + + @dataclasses.dataclass( eq=False, frozen=True, diff --git a/tests/integration/test_experimental_memo.py b/tests/integration/test_experimental_memo.py new file mode 100644 index 00000000000..f63be1804c1 --- /dev/null +++ b/tests/integration/test_experimental_memo.py @@ -0,0 +1,150 @@ +"""Integration tests for rx._x.memo.""" + +from collections.abc import Generator + +import pytest +from selenium.webdriver.common.by import By + +import reflex.app as reflex_app +import reflex.state as reflex_state +from reflex import constants +from reflex.testing import AppHarness + + +def ExperimentalMemoApp(): + """Reflex app that exercises experimental memo functions and components.""" + import reflex as rx + + class FooComponent(rx.Fragment): + def add_custom_code(self) -> list[str]: + return [ + "const foo = 'bar'", + ] + + @rx._x.memo + def foo_component(label: rx.Var[str]) -> rx.Component: + return FooComponent.create(label, rx.Var("foo")) + + @rx._x.memo + def format_price(amount: rx.Var[int], currency: rx.Var[str]) -> rx.Var[str]: + return currency.to(str) + ": $" + amount.to(str) + + @rx._x.memo + def summary_card( + children: rx.Var[rx.Component], + rest: rx.RestProp, + *, + title: rx.Var[str], + value: rx.Var[str], + ) -> rx.Component: + return rx.box( + rx.heading(title, id="summary-title"), + rx.text(value, id="summary-value"), + children, + rest, + ) + + class ExperimentalMemoState(rx.State): + amount: int = 125 + currency: str = "USD" + title: str = "Current Price" + + @rx.event + def increment_amount(self): + self.amount += 5 + + def index() -> rx.Component: + formatted_price = format_price( + amount=ExperimentalMemoState.amount, + currency=ExperimentalMemoState.currency, + ) + return rx.vstack( + rx.vstack( + foo_component(label="foo"), + foo_component(label="bar"), + id="experimental-memo-custom-code", + ), + rx.text(formatted_price, id="formatted-price"), + rx.button( + "Increment", + id="increment-price", + on_click=ExperimentalMemoState.increment_amount, + ), + summary_card( + rx.text("Children are passed positionally.", id="summary-child"), + title=ExperimentalMemoState.title, + value=formatted_price, + id="summary-card", + class_name="forwarded-summary-card", + ), + ) + + app = rx.App() + app.add_page(index) + + +@pytest.fixture +def experimental_memo_app(tmp_path, monkeypatch) -> Generator[AppHarness, None, None]: + """Start ExperimentalMemoApp app at tmp_path via AppHarness. + + Args: + tmp_path: pytest tmp_path fixture. + monkeypatch: pytest monkeypatch fixture. + + Yields: + Running AppHarness instance. + """ + monkeypatch.setenv( + constants.PYTEST_CURRENT_TEST, + "tests/integration/test_experimental_memo.py::test_experimental_memo_app", + ) + monkeypatch.setattr(reflex_app, "is_testing_env", lambda: True) + monkeypatch.setattr(reflex_state, "is_testing_env", lambda: True) + with AppHarness.create( + root=tmp_path, + app_source=ExperimentalMemoApp, + ) as harness: + yield harness + + +def test_experimental_memo_app(experimental_memo_app: AppHarness): + """Render experimental memos and assert on their frontend behavior. + + Args: + experimental_memo_app: Harness for ExperimentalMemoApp. + """ + assert experimental_memo_app.app_instance is not None, "app is not running" + driver = experimental_memo_app.frontend() + + memo_custom_code_stack = AppHarness.poll_for_or_raise_timeout( + lambda: driver.find_element(By.ID, "experimental-memo-custom-code") + ) + assert ( + experimental_memo_app.poll_for_content(memo_custom_code_stack, exp_not_equal="") + == "foobarbarbar" + ) + assert memo_custom_code_stack.text == "foobarbarbar" + + formatted_price = driver.find_element(By.ID, "formatted-price") + assert ( + experimental_memo_app.poll_for_content(formatted_price, exp_not_equal="") + == "USD: $125" + ) + + summary_card = driver.find_element(By.ID, "summary-card") + assert "forwarded-summary-card" in (summary_card.get_attribute("class") or "") + assert driver.find_element(By.ID, "summary-title").text == "Current Price" + assert ( + driver.find_element(By.ID, "summary-child").text + == "Children are passed positionally." + ) + + summary_value = driver.find_element(By.ID, "summary-value") + assert ( + experimental_memo_app.poll_for_content(summary_value, exp_not_equal="") + == "USD: $125" + ) + + driver.find_element(By.ID, "increment-price").click() + assert experimental_memo_app.poll_for_content(formatted_price) == "USD: $130" + assert experimental_memo_app.poll_for_content(summary_value) == "USD: $130" diff --git a/tests/units/conftest.py b/tests/units/conftest.py index 612d8beaf85..1399832d8be 100644 --- a/tests/units/conftest.py +++ b/tests/units/conftest.py @@ -8,7 +8,9 @@ import pytest from reflex.app import App +from reflex.components.component import CUSTOM_COMPONENTS from reflex.event import EventSpec +from reflex.experimental.memo import EXPERIMENTAL_MEMOS from reflex.model import ModelRegistry from reflex.testing import chdir from reflex.utils import prerequisites @@ -209,3 +211,21 @@ def model_registry() -> Generator[type[ModelRegistry], None, None]: """ yield ModelRegistry ModelRegistry._metadata = None + + +@pytest.fixture +def preserve_memo_registries(): + """Save and restore global memo registries around a test. + + Yields: + None + """ + custom_components = dict(CUSTOM_COMPONENTS) + experimental_memos = dict(EXPERIMENTAL_MEMOS) + try: + yield + finally: + CUSTOM_COMPONENTS.clear() + CUSTOM_COMPONENTS.update(custom_components) + EXPERIMENTAL_MEMOS.clear() + EXPERIMENTAL_MEMOS.update(experimental_memos) diff --git a/tests/units/experimental/__init__.py b/tests/units/experimental/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/units/experimental/test_memo.py b/tests/units/experimental/test_memo.py new file mode 100644 index 00000000000..2b6fc1fb485 --- /dev/null +++ b/tests/units/experimental/test_memo.py @@ -0,0 +1,403 @@ +"""Tests for experimental memo support.""" + +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any + +import pytest + +import reflex as rx +from reflex.compiler import compiler +from reflex.compiler import utils as compiler_utils +from reflex.components.component import CUSTOM_COMPONENTS, Component +from reflex.experimental.memo import ( + EXPERIMENTAL_MEMOS, + ExperimentalMemoComponent, + ExperimentalMemoComponentDefinition, + ExperimentalMemoFunctionDefinition, +) +from reflex.style import Style +from reflex.utils.imports import ImportVar +from reflex.vars import VarData +from reflex.vars.base import Var +from reflex.vars.function import FunctionVar + + +@pytest.fixture(autouse=True) +def _restore_memo_registries(preserve_memo_registries): + """Autouse wrapper around the shared preserve_memo_registries fixture.""" + + +def test_var_returning_memo(): + """Var-returning memos should behave like imported function vars.""" + + @rx._x.memo + def format_price(amount: rx.Var[int], currency: rx.Var[str]) -> rx.Var[str]: + return currency.to(str) + ": $" + amount.to(str) + + price = Var(_js_expr="price", _var_type=int) + currency = Var(_js_expr="currency", _var_type=str) + + assert ( + str(format_price(amount=price, currency=currency)) + == "(format_price(price, currency))" + ) + assert ( + str(format_price.call(amount=price, currency=currency)) + == "(format_price(price, currency))" + ) + assert isinstance(format_price._as_var(), FunctionVar) + + definition = EXPERIMENTAL_MEMOS["format_price"] + assert isinstance(definition, ExperimentalMemoFunctionDefinition) + assert ( + str(definition.function) == '((amount, currency) => ((currency+": $")+amount))' + ) + + with pytest.raises(TypeError, match="only accepts keyword props"): + format_price(price, currency) + + +def test_component_returning_memo_with_children_and_rest(): + """Component-returning memos should accept positional children and forwarded props.""" + + @rx._x.memo + def my_card( + children: rx.Var[rx.Component], + rest: rx.RestProp, + *, + title: rx.Var[str], + ) -> rx.Component: + return rx.box( + rx.heading(title), + children, + rest, + ) + + component = my_card( + rx.text("child 1"), + rx.text("child 2"), + title="Hello", + class_name="extra", + ) + + assert isinstance(component, ExperimentalMemoComponent) + assert len(component.children) == 2 + + rendered = component.render() + assert rendered["name"] == "MyCard" + assert 'title:"Hello"' in rendered["props"] + assert 'className:"extra"' in rendered["props"] + + definition = EXPERIMENTAL_MEMOS["MyCard"] + assert isinstance(definition, ExperimentalMemoComponentDefinition) + assert any(str(prop) == "rest" for prop in definition.component.special_props) + + _, code, _ = compiler.compile_memo_components( + (), tuple(EXPERIMENTAL_MEMOS.values()) + ) + assert "export const MyCard = memo(({children, title:title" in code + assert "...rest" in code + assert "jsx(RadixThemesBox,{...rest}" in code + + +def test_var_returning_memo_with_rest_props(): + """Var-returning memos should capture extra keyword args into RestProp.""" + + @rx._x.memo + def merge_styles( + base: rx.Var[dict[str, str]], + overrides: rx.RestProp, + ) -> rx.Var[Any]: + return base.to(dict).merge(overrides) + + base = Var(_js_expr="base", _var_type=dict[str, str]) + merged = merge_styles(base=base, color="red") + + assert "merge_styles" in str(merged) + assert '["base"] : base' in str(merged) + assert '["color"] : "red"' in str(merged) + + _, code, _ = compiler.compile_memo_components( + (), tuple(EXPERIMENTAL_MEMOS.values()) + ) + assert ( + "export const merge_styles = (({base, ...overrides}) => ({...base, ...overrides}));" + in code + ) + + with pytest.raises(TypeError, match="Do not pass `overrides=` directly"): + merge_styles(base=base, overrides={"color": "red"}) + + +def test_var_returning_memo_with_children_and_rest(): + """Var-returning memos should accept positional children plus keyword props.""" + + @rx._x.memo + def label_slot( + children: rx.Var[rx.Component], + rest: rx.RestProp, + *, + label: rx.Var[str], + ) -> rx.Var[str]: + return label + + rendered = label_slot( + rx.text("child"), + label="Hello", + class_name="slot", + ) + + assert "label_slot" in str(rendered) + assert '["children"]' in str(rendered) + assert '["class_name"] : "slot"' in str(rendered) + + _, code, _ = compiler.compile_memo_components( + (), tuple(EXPERIMENTAL_MEMOS.values()) + ) + assert "export const label_slot = (({children, label, ...rest}) => label);" in code + + +def test_memo_requires_var_annotations(): + """Experimental memos should require Var annotations on parameters.""" + with pytest.raises(TypeError, match="must be annotated"): + + @rx._x.memo + def bad_annotation(value: int) -> rx.Var[str]: + return rx.Var.create("x") + + with pytest.raises(TypeError, match="Missing annotation"): + + @rx._x.memo + def missing_annotation(value) -> rx.Var[str]: + return rx.Var.create("x") + + +def test_memo_rejects_invalid_children_annotation(): + """Component memos should validate the special children annotation.""" + with pytest.raises(TypeError, match="children"): + + @rx._x.memo + def bad_children(children: rx.Var[str]) -> rx.Component: + return rx.text(children) + + +def test_memo_rejects_multiple_rest_props(): + """Experimental memos should only allow a single RestProp.""" + with pytest.raises(TypeError, match="only supports one"): + + @rx._x.memo + def too_many_rest( + first: rx.RestProp, + second: rx.RestProp, + ) -> rx.Var[Any]: + return first + + +def test_memo_rejects_component_and_function_name_collision(): + """Experimental memos should reject same exported name across kinds.""" + + @rx._x.memo + def foo_bar() -> rx.Component: + return rx.box() + + assert "FooBar" in EXPERIMENTAL_MEMOS + + with pytest.raises(ValueError, match=r"name collision.*FooBar"): + + @rx._x.memo + def FooBar() -> rx.Var[str]: + return rx.Var.create("x") + + +def test_memo_rejects_component_export_name_collision(): + """Experimental memos should reject duplicate component export names.""" + + @rx._x.memo + def foo_bar() -> rx.Component: + return rx.box() + + with pytest.raises(ValueError, match=r"name collision.*FooBar"): + + @rx._x.memo + def foo__bar() -> rx.Component: + return rx.box() + + +def test_memo_rejects_varargs(): + """Experimental memos should reject *args and **kwargs.""" + with pytest.raises(TypeError, match=r"\*args"): + + @rx._x.memo + def bad_args(*values: rx.Var[str]) -> rx.Var[str]: + return rx.Var.create("x") + + with pytest.raises(TypeError, match=r"\*\*kwargs"): + + @rx._x.memo + def bad_kwargs(**values: rx.Var[str]) -> rx.Var[str]: + return rx.Var.create("x") + + +def test_component_memo_rejects_invalid_positional_usage(): + """Component memos should only accept positional children.""" + + @rx._x.memo + def title_card(*, title: rx.Var[str]) -> rx.Component: + return rx.box(rx.heading(title)) + + with pytest.raises(TypeError, match="only accepts keyword props"): + title_card(rx.text("child")) + + @rx._x.memo + def child_card( + children: rx.Var[rx.Component], *, title: rx.Var[str] + ) -> rx.Component: + return rx.box(rx.heading(title), children) + + with pytest.raises(TypeError, match="only accepts positional children"): + child_card("not a component", title="Hello") + + +def test_var_memo_rejects_invalid_positional_usage(): + """Var memos should also reserve positional arguments for children only.""" + + @rx._x.memo + def format_price(amount: rx.Var[int], currency: rx.Var[str]) -> rx.Var[str]: + return currency.to(str) + ": $" + amount.to(str) + + price = Var(_js_expr="price", _var_type=int) + currency = Var(_js_expr="currency", _var_type=str) + + with pytest.raises(TypeError, match="only accepts keyword props"): + format_price(price, currency) + + @rx._x.memo + def child_label( + children: rx.Var[rx.Component], *, label: rx.Var[str] + ) -> rx.Var[str]: + return label + + with pytest.raises(TypeError, match="only accepts positional children"): + child_label("not a component", label="Hello") + + +def test_var_returning_memo_rejects_hooks(): + """Var-returning memos should reject hook-bearing expressions.""" + with pytest.raises(TypeError, match="cannot depend on hooks"): + + @rx._x.memo + def bad_hook(value: rx.Var[str]) -> rx.Var[str]: + return Var( + _js_expr="value", + _var_type=str, + _var_data=VarData(hooks={"const badHook = 1": None}), + ) + + +def test_var_returning_memo_rejects_non_bundled_imports(): + """Var-returning memos should reject non-bundled imports.""" + with pytest.raises(TypeError, match="not bundled"): + + @rx._x.memo + def bad_import(value: rx.Var[str]) -> rx.Var[str]: + return Var( + _js_expr="value", + _var_type=str, + _var_data=VarData(imports={"some-lib": [ImportVar(tag="x")]}), + ) + + +def test_compile_memo_components_includes_experimental_functions_and_components(): + """The shared memo output should include both experimental functions and components.""" + + @rx.memo + def old_wrapper(title: rx.Var[str]) -> rx.Component: + return rx.text(title) + + @rx._x.memo + def format_price(amount: rx.Var[int], currency: rx.Var[str]) -> rx.Var[str]: + return currency.to(str) + ": $" + amount.to(str) + + @rx._x.memo + def my_card(children: rx.Var[rx.Component], *, title: rx.Var[str]) -> rx.Component: + return rx.box(rx.heading(title), children) + + _, code, _ = compiler.compile_memo_components( + dict.fromkeys(CUSTOM_COMPONENTS.values()), + tuple(EXPERIMENTAL_MEMOS.values()), + ) + + assert "export const OldWrapper = memo(" in code + assert "export const format_price =" in code + assert "export const MyCard = memo(" in code + + +def test_experimental_component_memo_get_imports(): + """Experimental component memos should resolve imports during compilation.""" + + class Inner(Component): + tag = "Inner" + library = "inner" + + @rx._x.memo + def wrapper() -> rx.Component: + return Inner.create() + + experimental_component = wrapper() + + assert "inner" not in experimental_component._get_all_imports() + + definition = EXPERIMENTAL_MEMOS["Wrapper"] + assert isinstance(definition, ExperimentalMemoComponentDefinition) + _, imports = compiler_utils.compile_experimental_component_memo(definition) + assert "inner" in imports + + +def test_compile_experimental_component_memo_does_not_mutate_definition( + monkeypatch: pytest.MonkeyPatch, +): + """Experimental component memo compilation should not mutate stored components.""" + + @rx._x.memo + def wrapper() -> rx.Component: + return rx.box("hi") + + definition = EXPERIMENTAL_MEMOS["Wrapper"] + assert isinstance(definition, ExperimentalMemoComponentDefinition) + assert definition.component.style == Style() + + monkeypatch.setattr( + "reflex.utils.prerequisites.get_and_validate_app", + lambda: SimpleNamespace( + app=SimpleNamespace( + style={type(definition.component): Style({"color": "red"})} + ) + ), + ) + + render, _ = compiler_utils.compile_experimental_component_memo(definition) + + assert render["render"]["props"] == ['css:({ ["color"] : "red" })'] + assert definition.component.style == Style() + + +def test_compile_memo_components_includes_experimental_custom_code(): + """Experimental component memos should include custom code in compiled output.""" + + class FooComponent(rx.Fragment): + def add_custom_code(self) -> list[str]: + return [ + "const foo = 'bar'", + ] + + @rx._x.memo + def foo_component(label: rx.Var[str]) -> rx.Component: + return FooComponent.create(label, rx.Var("foo")) + + _, code, _ = compiler.compile_memo_components( + (), tuple(EXPERIMENTAL_MEMOS.values()) + ) + + assert "const foo = 'bar'" in code diff --git a/tests/units/test_testing.py b/tests/units/test_testing.py index 8c8f1461bf0..36c8f173b01 100644 --- a/tests/units/test_testing.py +++ b/tests/units/test_testing.py @@ -1,8 +1,18 @@ """Unit tests for the included testing tools.""" +import sys +from types import ModuleType, SimpleNamespace +from unittest import mock + import pytest +import reflex.config +import reflex.reflex as reflex_cli +import reflex.testing as reflex_testing +import reflex.utils.prerequisites +from reflex.components.component import CUSTOM_COMPONENTS from reflex.constants import IS_WINDOWS +from reflex.experimental.memo import EXPERIMENTAL_MEMOS from reflex.testing import AppHarness @@ -38,3 +48,93 @@ class State(rx.State): assert harness.frontend_process.poll() is None assert harness.frontend_process.poll() is not None + + +@pytest.fixture +def harness_mocks(monkeypatch): + """Common mocks for AppHarness initialization tests. + + Args: + monkeypatch: pytest monkeypatch fixture + + Returns: + Namespace with fake_config and get_and_validate_app mock. + """ + fake_config = SimpleNamespace(loglevel=None, module="test_app.test_app") + fake_app = mock.Mock(_state_manager=None) + get_and_validate_app = mock.Mock( + return_value=reflex.utils.prerequisites.AppInfo( + app=fake_app, + module=ModuleType(fake_config.module), + ) + ) + + monkeypatch.setattr(reflex_testing, "get_config", lambda: fake_config) + monkeypatch.setattr(reflex.config, "get_config", lambda reload=False: fake_config) + monkeypatch.setattr( + reflex.utils.prerequisites, + "get_and_validate_app", + get_and_validate_app, + ) + + return SimpleNamespace( + config=fake_config, + get_and_validate_app=get_and_validate_app, + ) + + +def test_app_harness_initialize_clears_memo_registries( + tmp_path, preserve_memo_registries, harness_mocks, monkeypatch +): + """Ensure app initialization clears leaked memo registries. + + Args: + tmp_path: pytest tmp_path fixture + preserve_memo_registries: restores global memo registries after the test + harness_mocks: shared AppHarness mock setup + monkeypatch: pytest monkeypatch fixture + """ + monkeypatch.setattr(reflex_cli, "_init", lambda **kwargs: None) + + CUSTOM_COMPONENTS["FooComponent"] = mock.sentinel.component + EXPERIMENTAL_MEMOS["format_value"] = mock.sentinel.memo + + harness = AppHarness.create( + root=tmp_path / "memo_app", + app_source="import reflex as rx\napp = rx.App()", + app_name="memo_app", + ) + harness.app_module_path.parent.mkdir(parents=True, exist_ok=True) + harness._initialize_app() + + assert "FooComponent" not in CUSTOM_COMPONENTS + assert "format_value" not in EXPERIMENTAL_MEMOS + harness_mocks.get_and_validate_app.assert_called_once_with(reload=True) + + +def test_app_harness_initialize_reloads_existing_imported_app( + tmp_path, preserve_memo_registries, harness_mocks, monkeypatch +): + """Ensure pre-existing imported apps are reloaded after memo registry reset. + + Args: + tmp_path: pytest tmp_path fixture + preserve_memo_registries: restores global memo registries after the test + harness_mocks: shared AppHarness mock setup + monkeypatch: pytest monkeypatch fixture + """ + monkeypatch.setattr( + reflex.utils.prerequisites, + "initialize_frontend_dependencies", + lambda: None, + ) + monkeypatch.setitem( + sys.modules, + harness_mocks.config.module, + ModuleType(harness_mocks.config.module), + ) + + harness = AppHarness.create(root=tmp_path / "plain_app") + harness._initialize_app() + + harness_mocks.get_and_validate_app.assert_called_once_with(reload=True)