diff --git a/pyi_hashes.json b/pyi_hashes.json index 03d20e2342b..60bf946e01e 100644 --- a/pyi_hashes.json +++ b/pyi_hashes.json @@ -1,5 +1,5 @@ { - "reflex/__init__.pyi": "0a3ae880e256b9fd3b960e12a2cb51a7", + "reflex/__init__.pyi": "a0266c47111e9af7f340186013c7a31e", "reflex/components/__init__.pyi": "ac05995852baa81062ba3d18fbc489fb", "reflex/components/base/__init__.pyi": "16e47bf19e0d62835a605baa3d039c5a", "reflex/components/base/app_wrap.pyi": "22e94feaa9fe675bcae51c412f5b67f1", @@ -19,7 +19,7 @@ "reflex/components/core/helmet.pyi": "cb5ac1be02c6f82fcc78ba74651be593", "reflex/components/core/html.pyi": "4ebe946f3fc097fc2e31dddf7040ec1c", "reflex/components/core/sticky.pyi": "cb763b986a9b0654d1a3f33440dfcf60", - "reflex/components/core/upload.pyi": "c90782be1b63276b428bce3fd4ce0af2", + "reflex/components/core/upload.pyi": "ca9f7424f3b74b1b56f5c819e8654eeb", "reflex/components/core/window_events.pyi": "e7af4bf5341c4afaf60c4a534660f68f", "reflex/components/datadisplay/__init__.pyi": "52755871369acbfd3a96b46b9a11d32e", "reflex/components/datadisplay/code.pyi": "1d123d19ef08f085422f3023540e7bb1", diff --git a/reflex/.templates/web/utils/helpers/upload.js b/reflex/.templates/web/utils/helpers/upload.js index 6bbfc746ed6..6d3d146c6c5 100644 --- a/reflex/.templates/web/utils/helpers/upload.js +++ b/reflex/.templates/web/utils/helpers/upload.js @@ -27,10 +27,7 @@ export const uploadFiles = async ( getBackendURL, getToken, ) => { - // return if there's no file to upload - if (files === undefined || files.length === 0) { - return false; - } + files = files ?? 
[]; const upload_ref_name = `__upload_controllers_${upload_id}`; diff --git a/reflex/.templates/web/utils/state.js b/reflex/.templates/web/utils/state.js index 9e937ed62cd..56eba97d4fe 100644 --- a/reflex/.templates/web/utils/state.js +++ b/reflex/.templates/web/utils/state.js @@ -419,16 +419,6 @@ export const applyEvent = async (event, socket, navigate, params) => { export const applyRestEvent = async (event, socket, navigate, params) => { let eventSent = false; if (event.handler === "uploadFiles") { - if (event.payload.files === undefined || event.payload.files.length === 0) { - // Submit the event over the websocket to trigger the event handler. - return await applyEvent( - ReflexEvent(event.name, { files: [] }), - socket, - navigate, - params, - ); - } - // Start upload, but do not wait for it, which would block other events. uploadFiles( event.name, diff --git a/reflex/__init__.py b/reflex/__init__.py index 066df110f02..29841ffd86a 100644 --- a/reflex/__init__.py +++ b/reflex/__init__.py @@ -297,6 +297,10 @@ "config": ["Config", "DBConfig"], "constants": ["Env"], "constants.colors": ["Color"], + "_upload": [ + "UploadChunk", + "UploadChunkIterator", + ], "event": [ "event", "EventChain", @@ -320,6 +324,7 @@ "set_value", "stop_propagation", "upload_files", + "upload_files_chunk", "window_alert", ], "istate.storage": [ diff --git a/reflex/_upload.py b/reflex/_upload.py new file mode 100644 index 00000000000..1e7ca21537c --- /dev/null +++ b/reflex/_upload.py @@ -0,0 +1,719 @@ +"""Backend upload helpers and routes for Reflex apps.""" + +from __future__ import annotations + +import asyncio +import contextlib +import dataclasses +from collections import deque +from collections.abc import AsyncGenerator, AsyncIterator, Awaitable, Callable +from pathlib import Path +from typing import TYPE_CHECKING, Any, BinaryIO, cast + +from python_multipart.multipart import MultipartParser, parse_options_header +from starlette.datastructures import Headers +from 
starlette.datastructures import UploadFile as StarletteUploadFile +from starlette.exceptions import HTTPException +from starlette.formparsers import MultiPartException, _user_safe_decode +from starlette.requests import ClientDisconnect, Request +from starlette.responses import JSONResponse, Response, StreamingResponse +from typing_extensions import Self + +from reflex import constants +from reflex.utils import exceptions + +if TYPE_CHECKING: + from reflex.app import App + from reflex.event import EventHandler + from reflex.state import BaseState + from reflex.utils.types import Receive, Scope, Send + + +@dataclasses.dataclass(frozen=True) +class UploadFile(StarletteUploadFile): + """A file uploaded to the server. + + Args: + file: The standard Python file object (non-async). + filename: The original file name. + size: The size of the file in bytes. + headers: The headers of the request. + """ + + file: BinaryIO + + path: Path | None = dataclasses.field(default=None) + + size: int | None = dataclasses.field(default=None) + + headers: Headers = dataclasses.field(default_factory=Headers) + + @property + def filename(self) -> str | None: + """Get the name of the uploaded file. + + Returns: + The name of the uploaded file. + """ + return self.name + + @property + def name(self) -> str | None: + """Get the name of the uploaded file. + + Returns: + The name of the uploaded file. + """ + if self.path: + return self.path.name + return None + + +@dataclasses.dataclass(frozen=True, kw_only=True, slots=True) +class UploadChunk: + """A chunk of uploaded file data.""" + + filename: str + offset: int + content_type: str + data: bytes + + +class UploadChunkIterator(AsyncIterator[UploadChunk]): + """An async iterator over uploaded file chunks.""" + + __slots__ = ( + "_chunks", + "_closed", + "_condition", + "_consumer_task", + "_error", + "_maxsize", + ) + + def __init__(self, *, maxsize: int = 8): + """Initialize the iterator. 
+ + Args: + maxsize: Maximum number of chunks to buffer before blocking producers. + """ + self._maxsize = maxsize + self._chunks: deque[UploadChunk] = deque() + self._condition = asyncio.Condition() + self._closed = False + self._error: Exception | None = None + self._consumer_task: asyncio.Task[Any] | None = None + + def __aiter__(self) -> Self: + """Return the iterator itself. + + Returns: + The upload chunk iterator. + """ + return self + + async def __anext__(self) -> UploadChunk: + """Yield the next available upload chunk. + + Returns: + The next upload chunk. + + Raises: + _error: Any error forwarded from the upload producer. + StopAsyncIteration: When all chunks have been consumed. + """ + async with self._condition: + while not self._chunks and not self._closed: + await self._condition.wait() + + if self._chunks: + chunk = self._chunks.popleft() + self._condition.notify_all() + return chunk + + if self._error is not None: + raise self._error + raise StopAsyncIteration + + def set_consumer_task(self, task: asyncio.Task[Any]) -> None: + """Track the task consuming this iterator. + + Args: + task: The background task consuming upload chunks. + """ + self._consumer_task = task + task.add_done_callback(self._wake_waiters) + + async def push(self, chunk: UploadChunk) -> None: + """Push a new chunk into the iterator. + + Args: + chunk: The chunk to push. + + Raises: + RuntimeError: If the iterator is already closed or the consumer exited early. + """ + async with self._condition: + while len(self._chunks) >= self._maxsize and not self._closed: + self._raise_if_consumer_finished() + await self._condition.wait() + + if self._closed: + msg = "Upload chunk iterator is closed." 
+ raise RuntimeError(msg) + + self._raise_if_consumer_finished() + self._chunks.append(chunk) + self._condition.notify_all() + + async def finish(self) -> None: + """Mark the iterator as complete.""" + async with self._condition: + if self._closed: + return + self._closed = True + self._condition.notify_all() + + async def fail(self, error: Exception) -> None: + """Mark the iterator as failed. + + Args: + error: The error to raise from the iterator. + """ + async with self._condition: + if self._closed: + return + self._closed = True + self._error = error + self._condition.notify_all() + + def _raise_if_consumer_finished(self) -> None: + """Raise if the consumer task exited before draining the iterator. + + Raises: + RuntimeError: If the consumer task completed before draining the iterator. + """ + if self._consumer_task is None or not self._consumer_task.done(): + return + + try: + task_exc = self._consumer_task.exception() + except asyncio.CancelledError as err: + task_exc = err + + msg = "Upload handler returned before consuming all upload chunks." + if task_exc is not None: + raise RuntimeError(msg) from task_exc + raise RuntimeError(msg) + + def _wake_waiters(self, task: asyncio.Task[Any]) -> None: + """Wake any producers or consumers blocked on the iterator condition. + + Args: + task: The completed consumer task. 
+ """ + task.get_loop().create_task(self._notify_waiters()) + + async def _notify_waiters(self) -> None: + """Notify tasks waiting on the iterator condition.""" + async with self._condition: + self._condition.notify_all() + + +@dataclasses.dataclass(kw_only=True, slots=True) +class _UploadChunkPart: + """Track the current multipart file part for upload streaming.""" + + content_disposition: bytes | None = None + field_name: str = "" + filename: str | None = None + content_type: str = "" + item_headers: list[tuple[bytes, bytes]] = dataclasses.field(default_factory=list) + offset: int = 0 + bytes_emitted: int = 0 + is_upload_chunk: bool = False + + +@dataclasses.dataclass(kw_only=True, slots=True) +class _UploadChunkMultipartParser: + """Streaming multipart parser for streamed upload files.""" + + headers: Headers + stream: AsyncGenerator[bytes, None] + chunk_iter: UploadChunkIterator + _charset: str = "" + _current_partial_header_name: bytes = b"" + _current_partial_header_value: bytes = b"" + _current_part: _UploadChunkPart = dataclasses.field( + default_factory=_UploadChunkPart + ) + _chunks_to_emit: deque[UploadChunk] = dataclasses.field(default_factory=deque) + _seen_upload_chunk: bool = False + _part_count: int = 0 + _emitted_chunk_count: int = 0 + _emitted_bytes: int = 0 + _stream_chunk_count: int = 0 + + def on_part_begin(self) -> None: + """Reset parser state for a new multipart part.""" + self._current_part = _UploadChunkPart() + + def on_part_data(self, data: bytes, start: int, end: int) -> None: + """Record streamed chunk data for the current part.""" + if ( + not self._current_part.is_upload_chunk + or self._current_part.filename is None + ): + return + + message_bytes = data[start:end] + self._chunks_to_emit.append( + UploadChunk( + filename=self._current_part.filename, + offset=self._current_part.offset + self._current_part.bytes_emitted, + content_type=self._current_part.content_type, + data=message_bytes, + ) + ) + self._current_part.bytes_emitted += 
len(message_bytes) + self._emitted_chunk_count += 1 + self._emitted_bytes += len(message_bytes) + + def on_part_end(self) -> None: + """Emit a zero-byte chunk for empty file parts.""" + if ( + self._current_part.is_upload_chunk + and self._current_part.filename is not None + and self._current_part.bytes_emitted == 0 + ): + self._chunks_to_emit.append( + UploadChunk( + filename=self._current_part.filename, + offset=self._current_part.offset, + content_type=self._current_part.content_type, + data=b"", + ) + ) + self._emitted_chunk_count += 1 + + def on_header_field(self, data: bytes, start: int, end: int) -> None: + """Accumulate multipart header field bytes.""" + self._current_partial_header_name += data[start:end] + + def on_header_value(self, data: bytes, start: int, end: int) -> None: + """Accumulate multipart header value bytes.""" + self._current_partial_header_value += data[start:end] + + def on_header_end(self) -> None: + """Store the completed multipart header.""" + field = self._current_partial_header_name.lower() + if field == b"content-disposition": + self._current_part.content_disposition = self._current_partial_header_value + self._current_part.item_headers.append(( + field, + self._current_partial_header_value, + )) + self._current_partial_header_name = b"" + self._current_partial_header_value = b"" + + def on_headers_finished(self) -> None: + """Parse upload metadata from multipart headers.""" + disposition, options = parse_options_header( + self._current_part.content_disposition + ) + if disposition != b"form-data": + msg = "Invalid upload chunk disposition." + raise MultiPartException(msg) + + try: + field_name = _user_safe_decode(options[b"name"], self._charset) + except KeyError as err: + msg = 'The Content-Disposition header field "name" must be provided.' + raise MultiPartException(msg) from err + + try: + filename = _user_safe_decode(options[b"filename"], self._charset) + except KeyError: + # Ignore non-file form fields entirely. 
+ return + filename = Path(filename.lstrip("/")).name + + content_type = "" + for header_name, header_value in self._current_part.item_headers: + if header_name == b"content-type": + content_type = _user_safe_decode(header_value, self._charset) + break + + self._current_part.field_name = field_name + self._current_part.filename = filename + self._current_part.content_type = content_type + self._current_part.offset = 0 + self._current_part.bytes_emitted = 0 + self._current_part.is_upload_chunk = True + self._seen_upload_chunk = True + self._part_count += 1 + + def on_end(self) -> None: + """Finalize parser callbacks.""" + + async def _flush_emitted_chunks(self) -> None: + """Push parsed upload chunks into the handler iterator.""" + while self._chunks_to_emit: + await self.chunk_iter.push(self._chunks_to_emit.popleft()) + + async def parse(self) -> None: + """Parse the incoming request stream and push chunks to the iterator. + + Raises: + MultiPartException: If the request is not valid multipart upload data. + RuntimeError: If the upload handler exits before consuming all chunks. + """ + _, params = parse_options_header(self.headers["Content-Type"]) + charset = params.get(b"charset", "utf-8") + if isinstance(charset, bytes): + charset = charset.decode("latin-1") + self._charset = charset + + try: + boundary = params[b"boundary"] + except KeyError as err: + msg = "Missing boundary in multipart." 
+ raise MultiPartException(msg) from err + + callbacks = { + "on_part_begin": self.on_part_begin, + "on_part_data": self.on_part_data, + "on_part_end": self.on_part_end, + "on_header_field": self.on_header_field, + "on_header_value": self.on_header_value, + "on_header_end": self.on_header_end, + "on_headers_finished": self.on_headers_finished, + "on_end": self.on_end, + } + parser = MultipartParser(boundary, cast(Any, callbacks)) + + async for chunk in self.stream: + self._stream_chunk_count += 1 + parser.write(chunk) + await self._flush_emitted_chunks() + + parser.finalize() + await self._flush_emitted_chunks() + + +class _UploadStreamingResponse(StreamingResponse): + """Streaming response that always releases upload form resources.""" + + _on_finish: Callable[[], Awaitable[None]] + + def __init__( + self, + *args: Any, + on_finish: Callable[[], Awaitable[None]], + **kwargs: Any, + ) -> None: + super().__init__(*args, **kwargs) + self._on_finish = on_finish + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + try: + await super().__call__(scope, receive, send) + finally: + await self._on_finish() + + +def _require_upload_headers(request: Request) -> tuple[str, str]: + """Extract the required upload headers from a request. + + Args: + request: The incoming request. + + Returns: + The client token and event handler name. + + Raises: + HTTPException: If the upload headers are missing. + """ + token = request.headers.get("reflex-client-token") + handler = request.headers.get("reflex-event-handler") + + if not token or not handler: + raise HTTPException( + status_code=400, + detail="Missing reflex-client-token or reflex-event-handler header.", + ) + + return token, handler + + +async def _get_upload_runtime_handler( + app: App, + token: str, + handler_name: str, +) -> tuple[BaseState, EventHandler]: + """Resolve the runtime state and event handler for an upload request. + + Args: + app: The Reflex app. + token: The client token. 
+ handler_name: The fully qualified event handler name. + + Returns: + The root state instance and resolved event handler. + """ + from reflex.state import _substate_key + + substate_token = _substate_key(token, handler_name.rpartition(".")[0]) + state = await app.state_manager.get_state(substate_token) + _current_state, event_handler = state._get_event_handler(handler_name) + return state, event_handler + + +def _seed_upload_router_data(state: BaseState, token: str) -> None: + """Ensure upload-launched handlers have the client token in router state. + + Background upload handlers use ``StateProxy`` which derives its mutable-state + token from ``self.router.session.client_token``. Upload requests do not flow + through the normal websocket event pipeline, so we seed the token here. + + Args: + state: The root state instance. + token: The client token from the upload request. + """ + from reflex.state import RouterData + + router_data = dict(state.router_data) + if router_data.get(constants.RouteVar.CLIENT_TOKEN) == token: + return + + router_data[constants.RouteVar.CLIENT_TOKEN] = token + state.router_data = router_data + state.router = RouterData.from_router_data(router_data) + + +async def _upload_buffered_file( + request: Request, + app: App, + *, + token: str, + handler_name: str, + handler_upload_param: tuple[str, Any], +) -> Response: + """Handle buffered uploads on the standard upload endpoint. + + Returns: + A streaming response for the buffered upload. 
+ """ + from reflex.event import Event + from reflex.utils.exceptions import UploadValueError + + try: + form_data = await request.form() + except ClientDisconnect: + return Response() + + form_data_closed = False + + async def _close_form_data() -> None: + """Close the parsed form data exactly once.""" + nonlocal form_data_closed + if form_data_closed: + return + form_data_closed = True + await form_data.close() + + def _create_upload_event() -> Event: + """Create an upload event using the live Starlette temp files. + + Returns: + The upload event backed by the parsed files. + """ + files = form_data.getlist("files") + file_uploads = [] + for file in files: + if not isinstance(file, StarletteUploadFile): + raise UploadValueError( + "Uploaded file is not an UploadFile." + str(file) + ) + file_uploads.append( + UploadFile( + file=file.file, + path=Path(file.filename.lstrip("/")) if file.filename else None, + size=file.size, + headers=file.headers, + ) + ) + + return Event( + token=token, + name=handler_name, + payload={handler_upload_param[0]: file_uploads}, + ) + + event: Event | None = None + try: + event = _create_upload_event() + finally: + if event is None: + await _close_form_data() + + if event is None: + msg = "Upload event was not created." + raise RuntimeError(msg) + + async def _ndjson_updates(): + """Process the upload event, generating ndjson updates. + + Yields: + Each state update as newline-delimited JSON. 
+ """ + async with app.state_manager.modify_state_with_links( + event.substate_token, event=event + ) as state: + async for update in state._process(event): + update = await app._postprocess(state, event, update) + yield update.json() + "\n" + + return _UploadStreamingResponse( + _ndjson_updates(), + media_type="application/x-ndjson", + on_finish=_close_form_data, + ) + + +def _background_upload_accepted_response() -> StreamingResponse: + """Return a minimal ndjson response for background upload dispatch.""" + from reflex.state import StateUpdate + + def _accepted_updates(): + yield StateUpdate(final=True).json() + "\n" + + return StreamingResponse( + _accepted_updates(), + media_type="application/x-ndjson", + status_code=202, + ) + + +async def _upload_chunk_file( + request: Request, + app: App, + *, + token: str, + handler_name: str, + handler_upload_param: tuple[str, Any], + acknowledge_on_upload_endpoint: bool, +) -> Response: + """Handle a streaming upload request. + + Returns: + The streaming upload response. + """ + from reflex.event import Event + + chunk_iter = UploadChunkIterator(maxsize=8) + event = Event( + token=token, + name=handler_name, + payload={handler_upload_param[0]: chunk_iter}, + ) + + async with app.state_manager.modify_state_with_links( + event.substate_token, + event=event, + ) as state: + _seed_upload_router_data(state, token) + task = app._process_background(state, event) + + if task is None: + msg = f"@rx.event(background=True) is required for upload_files_chunk handler `{handler_name}`." 
+ return JSONResponse({"detail": msg}, status_code=400) + + chunk_iter.set_consumer_task(task) + + parser = _UploadChunkMultipartParser( + headers=request.headers, + stream=request.stream(), + chunk_iter=chunk_iter, + ) + + try: + await parser.parse() + except ClientDisconnect: + task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await task + return Response() + except (MultiPartException, RuntimeError, ValueError) as err: + await chunk_iter.fail(err) + return JSONResponse({"detail": str(err)}, status_code=400) + + try: + await chunk_iter.finish() + except RuntimeError as err: + return JSONResponse({"detail": str(err)}, status_code=400) + + if acknowledge_on_upload_endpoint: + return _background_upload_accepted_response() + return Response(status_code=202) + + +def upload(app: App): + """Upload files, dispatching to buffered or streaming handling. + + Args: + app: The app to upload the file for. + + Returns: + The upload function. + """ + + async def upload_file(request: Request): + """Upload a file. + + Args: + request: The Starlette request object. + + Returns: + The upload response. + + Raises: + UploadValueError: If the handler does not have a supported annotation. + UploadTypeError: If a non-streaming upload is wired to a background task. + HTTPException: when the request does not include token / handler headers. 
+ """ + from reflex.event import ( + resolve_upload_chunk_handler_param, + resolve_upload_handler_param, + ) + + token, handler_name = _require_upload_headers(request) + _state, event_handler = await _get_upload_runtime_handler( + app, token, handler_name + ) + + if event_handler.is_background: + try: + handler_upload_param = resolve_upload_chunk_handler_param(event_handler) + except exceptions.UploadValueError: + pass + else: + return await _upload_chunk_file( + request, + app, + token=token, + handler_name=handler_name, + handler_upload_param=handler_upload_param, + acknowledge_on_upload_endpoint=True, + ) + + handler_upload_param = resolve_upload_handler_param(event_handler) + return await _upload_buffered_file( + request, + app, + token=token, + handler_name=handler_name, + handler_upload_param=handler_upload_param, + ) + + return upload_file diff --git a/reflex/app.py b/reflex/app.py index 39fd6478797..30f5d575783 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -18,7 +18,6 @@ from collections.abc import ( AsyncGenerator, AsyncIterator, - Awaitable, Callable, Coroutine, Mapping, @@ -29,22 +28,21 @@ from pathlib import Path from timeit import default_timer as timer from types import SimpleNamespace -from typing import TYPE_CHECKING, Any, BinaryIO, ParamSpec, get_args, get_type_hints +from typing import TYPE_CHECKING, Any, ParamSpec from rich.progress import MofNCompleteColumn, Progress, TimeElapsedColumn from socketio import ASGIApp as EngineIOApp from socketio import AsyncNamespace, AsyncServer from starlette.applications import Starlette -from starlette.datastructures import Headers -from starlette.datastructures import UploadFile as StarletteUploadFile -from starlette.exceptions import HTTPException from starlette.middleware import cors -from starlette.requests import ClientDisconnect, Request -from starlette.responses import JSONResponse, Response, StreamingResponse +from starlette.requests import Request +from starlette.responses import JSONResponse, 
Response from starlette.staticfiles import StaticFiles from typing_extensions import Unpack from reflex import constants +from reflex._upload import UploadFile as UploadFile +from reflex._upload import upload from reflex.admin import AdminDash from reflex.app_mixins import AppMixin, LifespanMixin, MiddlewareMixin from reflex.compiler import compiler @@ -112,7 +110,6 @@ js_runtimes, path_ops, prerequisites, - types, ) from reflex.utils.exec import ( get_compile_context, @@ -247,46 +244,6 @@ def default_error_boundary(*children: Component, **props) -> Component: ) -@dataclasses.dataclass(frozen=True) -class UploadFile(StarletteUploadFile): - """A file uploaded to the server. - - Args: - file: The standard Python file object (non-async). - filename: The original file name. - size: The size of the file in bytes. - headers: The headers of the request. - """ - - file: BinaryIO - - path: Path | None = dataclasses.field(default=None) - - size: int | None = dataclasses.field(default=None) - - headers: Headers = dataclasses.field(default_factory=Headers) - - @property - def filename(self) -> str | None: - """Get the name of the uploaded file. - - Returns: - The name of the uploaded file. - """ - return self.name - - @property - def name(self) -> str | None: - """Get the name of the uploaded file. - - Returns: - The name of the uploaded file. 
- """ - if self.path: - return self.path.name - return None - - @dataclasses.dataclass( frozen=True, ) @@ -1896,174 +1853,6 @@ async def health(_request: Request) -> JSONResponse: return JSONResponse(content=health_status, status_code=status_code) -class _UploadStreamingResponse(StreamingResponse): - """Streaming response that always releases upload form resources.""" - - _on_finish: Callable[[], Awaitable[None]] - - def __init__( - self, - *args: Any, - on_finish: Callable[[], Awaitable[None]], - **kwargs: Any, - ) -> None: - super().__init__(*args, **kwargs) - self._on_finish = on_finish - - async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: - try: - await super().__call__(scope, receive, send) - finally: - await self._on_finish() - - -def upload(app: App): - """Upload a file. - - Args: - app: The app to upload the file for. - - Returns: - The upload function. - """ - - async def upload_file(request: Request): - """Upload a file. - - Args: - request: The Starlette request object. - - Returns: - StreamingResponse yielding newline-delimited JSON of StateUpdate - emitted by the upload handler. - - Raises: - UploadValueError: if there are no args with supported annotation. - UploadTypeError: if a background task is used as the handler. - HTTPException: when the request does not include token / handler headers. - """ - from reflex.utils.exceptions import UploadTypeError, UploadValueError - - # Get the files from the request. - try: - form_data = await request.form() - except ClientDisconnect: - return Response() # user cancelled - - form_data_closed = False - - async def _close_form_data() -> None: - """Close the parsed form data exactly once.""" - nonlocal form_data_closed - if form_data_closed: - return - form_data_closed = True - await form_data.close() - - async def _create_upload_event() -> Event: - """Create an upload event using the live Starlette temp files. - - Returns: - The upload event backed by the original temp files. 
- """ - files = form_data.getlist("files") - if not files: - msg = "No files were uploaded." - raise UploadValueError(msg) - - token = request.headers.get("reflex-client-token") - handler = request.headers.get("reflex-event-handler") - - if not token or not handler: - raise HTTPException( - status_code=400, - detail="Missing reflex-client-token or reflex-event-handler header.", - ) - - # Get the state for the session. - substate_token = _substate_key(token, handler.rpartition(".")[0]) - state = await app.state_manager.get_state(substate_token) - - handler_upload_param = () - - _current_state, event_handler = state._get_event_handler(handler) - - if event_handler.is_background: - msg = f"@rx.event(background=True) is not supported for upload handler `{handler}`." - raise UploadTypeError(msg) - func = event_handler.fn - if isinstance(func, functools.partial): - func = func.func - for k, v in get_type_hints(func).items(): - if types.is_generic_alias(v) and types._issubclass( - get_args(v)[0], - UploadFile, - ): - handler_upload_param = (k, v) - break - - if not handler_upload_param: - msg = ( - f"`{handler}` handler should have a parameter annotated as " - "list[rx.UploadFile]" - ) - raise UploadValueError(msg) - - # Keep the parsed form data alive until the upload event finishes so - # the underlying Starlette temp files remain available to the handler. - file_uploads = [] - for file in files: - if not isinstance(file, StarletteUploadFile): - raise UploadValueError( - "Uploaded file is not an UploadFile." 
+ str(file) - ) - file_uploads.append( - UploadFile( - file=file.file, - path=Path(file.filename.lstrip("/")) if file.filename else None, - size=file.size, - headers=file.headers, - ) - ) - - return Event( - token=token, - name=handler, - payload={handler_upload_param[0]: file_uploads}, - ) - - event: Event | None = None - try: - event = await _create_upload_event() - finally: - if event is None: - await _close_form_data() - - async def _ndjson_updates(): - """Process the upload event, generating ndjson updates. - - Yields: - Each state update as JSON followed by a new line. - """ - # Process the event. - async with app.state_manager.modify_state_with_links( - event.substate_token, event=event - ) as state: - async for update in state._process(event): - # Postprocess the event. - update = await app._postprocess(state, event, update) - yield update.json() + "\n" - - # Stream updates to client - return _UploadStreamingResponse( - _ndjson_updates(), - media_type="application/x-ndjson", - on_finish=_close_form_data, - ) - - return upload_file - - class EventNamespace(AsyncNamespace): """The event namespace.""" diff --git a/reflex/components/core/upload.py b/reflex/components/core/upload.py index b6fdd476001..a74da9e3d8b 100644 --- a/reflex/components/core/upload.py +++ b/reflex/components/core/upload.py @@ -6,7 +6,7 @@ from pathlib import Path from typing import Any, ClassVar -from reflex.app import UploadFile +from reflex._upload import UploadChunkIterator, UploadFile from reflex.components.base.fragment import Fragment from reflex.components.component import ( Component, @@ -172,6 +172,11 @@ def get_upload_url(file_path: str | Var[str]) -> Var[str]: _on_drop_spec = passthrough_event_spec(list[UploadFile]) +_on_drop_args_spec = ( + _on_drop_spec, + passthrough_event_spec(UploadChunkIterator), +) +_UPLOAD_FILES_CLIENT_HANDLER = "uploadFiles" def _default_drop_rejected(rejected_files: ArrayVar[list[dict[str, Any]]]) -> EventSpec: @@ -211,7 +216,8 @@ class 
UploadFilesProvider(Component): class GhostUpload(Fragment): """A ghost upload component.""" - on_drop: EventHandler[_on_drop_spec] = field(doc="Fired when files are dropped.") + # Fired when files are dropped. + on_drop: EventHandler[_on_drop_args_spec] on_drop_rejected: EventHandler[_on_drop_spec] = field( doc="Fired when dropped files do not meet the specified criteria." @@ -254,7 +260,8 @@ class Upload(MemoizationLeaf): # Marked True when any Upload component is created. is_used: ClassVar[bool] = False - on_drop: EventHandler[_on_drop_spec] = field(doc="Fired when files are dropped.") + # Fired when files are dropped. + on_drop: EventHandler[_on_drop_args_spec] on_drop_rejected: EventHandler[_on_drop_spec] = field( doc="Fired when dropped files do not meet the specified criteria." @@ -310,11 +317,12 @@ def create(cls, *children, **props) -> Component: if isinstance(event, EventHandler): event = event(upload_files(upload_id)) if isinstance(event, EventSpec): - # Call the lambda to get the event chain. - event = call_event_handler(event, _on_drop_spec) + if event.client_handler_name != _UPLOAD_FILES_CLIENT_HANDLER: + # Call the lambda to get the event chain. + event = call_event_handler(event, _on_drop_args_spec) elif isinstance(event, Callable): # Call the lambda to get the event chain. - event = call_event_fn(event, _on_drop_spec) + event = call_event_fn(event, _on_drop_args_spec) if isinstance(event, EventSpec): # Update the provided args for direct use with on_drop. event = event.with_args( diff --git a/reflex/event.py b/reflex/event.py index 959c57aa11e..a95885edac2 100644 --- a/reflex/event.py +++ b/reflex/event.py @@ -93,6 +93,95 @@ def substate_token(self) -> str: BACKGROUND_TASK_MARKER = "_reflex_background_task" EVENT_ACTIONS_MARKER = "_rx_event_actions" +UPLOAD_FILES_CLIENT_HANDLER = "uploadFiles" + + +def _handler_name(handler: "EventHandler") -> str: + """Get a stable fully qualified handler name for errors. + + Args: + handler: The handler to name. 
+ + Returns: + The fully qualified handler name. + """ + if handler.state_full_name: + return f"{handler.state_full_name}.{handler.fn.__name__}" + return handler.fn.__qualname__ + + +def resolve_upload_handler_param(handler: "EventHandler") -> tuple[str, Any]: + """Validate and resolve the UploadFile list parameter for a handler. + + Args: + handler: The event handler to inspect. + + Returns: + The parameter name and annotation for the upload file argument. + + Raises: + UploadTypeError: If the handler is a background task. + UploadValueError: If the handler does not accept ``list[rx.UploadFile]``. + """ + from reflex._upload import UploadFile + from reflex.utils.exceptions import UploadTypeError, UploadValueError + + handler_name = _handler_name(handler) + if handler.is_background: + msg = ( + f"@rx.event(background=True) is not supported for upload handler " + f"`{handler_name}`." + ) + raise UploadTypeError(msg) + + func = handler.fn.func if isinstance(handler.fn, partial) else handler.fn + for name, annotation in get_type_hints(func).items(): + if name == "return" or get_origin(annotation) is not list: + continue + args = get_args(annotation) + if len(args) == 1 and typehint_issubclass(args[0], UploadFile): + return name, annotation + + msg = ( + f"`{handler_name}` handler should have a parameter annotated as " + "list[rx.UploadFile]" + ) + raise UploadValueError(msg) + + +def resolve_upload_chunk_handler_param(handler: "EventHandler") -> tuple[str, type]: + """Validate and resolve the UploadChunkIterator parameter for a handler. + + Args: + handler: The event handler to inspect. + + Returns: + The parameter name and annotation for the iterator argument. + + Raises: + UploadTypeError: If the handler is not a background task. + UploadValueError: If the handler does not accept an UploadChunkIterator. 
+ """ + from reflex._upload import UploadChunkIterator + from reflex.utils.exceptions import UploadTypeError, UploadValueError + + handler_name = _handler_name(handler) + if not handler.is_background: + msg = f"@rx.event(background=True) is required for upload_files_chunk handler `{handler_name}`." + raise UploadTypeError(msg) + + func = handler.fn.func if isinstance(handler.fn, partial) else handler.fn + for name, annotation in get_type_hints(func).items(): + if name == "return": + continue + if annotation is UploadChunkIterator: + return name, annotation + + msg = ( + f"`{handler_name}` handler should have a parameter annotated as " + "rx.UploadChunkIterator" + ) + raise UploadValueError(msg) @dataclasses.dataclass( @@ -290,7 +379,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> "EventSpec": values = [] for arg in [*args, *kwargs.values()]: # Special case for file uploads. - if isinstance(arg, FileUpload): + if isinstance(arg, (FileUpload, UploadFilesChunk)): return arg.as_event_spec(handler=self) # Otherwise, convert to JSON. @@ -868,14 +957,22 @@ def on_upload_progress_args_spec(_prog: Var[dict[str, int | float | bool]]): """ return [_prog] - def as_event_spec(self, handler: EventHandler) -> EventSpec: - """Get the EventSpec for the file upload. + def _as_event_spec( + self, + handler: EventHandler, + *, + client_handler_name: str, + upload_param_name: str, + ) -> EventSpec: + """Create an upload EventSpec. Args: handler: The event handler. + client_handler_name: The client handler name. + upload_param_name: The upload argument name in the event handler. Returns: - The event spec for the handler. + The upload EventSpec. Raises: ValueError: If the on_upload_progress is not a valid event handler. 
@@ -886,14 +983,19 @@ def as_event_spec(self, handler: EventHandler) -> EventSpec: ) upload_id = self.upload_id if self.upload_id is not None else DEFAULT_UPLOAD_ID + upload_files_var = Var( + _js_expr="filesById", + _var_type=dict[str, Any], + _var_data=VarData.merge(upload_files_context_var_data), + ).to(ObjectVar)[LiteralVar.create(upload_id)] spec_args = [ ( Var(_js_expr="files"), - Var( - _js_expr="filesById", - _var_type=dict[str, Any], - _var_data=VarData.merge(upload_files_context_var_data), - ).to(ObjectVar)[LiteralVar.create(upload_id)], + upload_files_var, + ), + ( + Var(_js_expr="upload_param_name"), + LiteralVar.create(upload_param_name), ), ( Var(_js_expr="upload_id"), @@ -906,6 +1008,14 @@ def as_event_spec(self, handler: EventHandler) -> EventSpec: ), ), ] + if upload_param_name != "files": + spec_args.insert( + 1, + ( + Var(_js_expr=upload_param_name), + upload_files_var, + ), + ) if self.on_upload_progress is not None: on_upload_progress = self.on_upload_progress if isinstance(on_upload_progress, EventHandler): @@ -941,16 +1051,65 @@ def as_event_spec(self, handler: EventHandler) -> EventSpec: ) return EventSpec( handler=handler, - client_handler_name="uploadFiles", + client_handler_name=client_handler_name, args=tuple(spec_args), event_actions=handler.event_actions.copy(), ) + def as_event_spec(self, handler: EventHandler) -> EventSpec: + """Get the EventSpec for the file upload. + + Args: + handler: The event handler. + + Returns: + The event spec for the handler. 
+ """ + from reflex.utils.exceptions import UploadValueError + + try: + upload_param_name, _annotation = resolve_upload_handler_param(handler) + except UploadValueError: + upload_param_name = "files" + return self._as_event_spec( + handler, + client_handler_name=UPLOAD_FILES_CLIENT_HANDLER, + upload_param_name=upload_param_name, + ) + # Alias for rx.upload_files upload_files = FileUpload +@dataclasses.dataclass( + init=True, + frozen=True, +) +class UploadFilesChunk(FileUpload): + """Class to represent a streaming file upload.""" + + def as_event_spec(self, handler: EventHandler) -> EventSpec: + """Get the EventSpec for the streaming file upload. + + Args: + handler: The event handler. + + Returns: + The event spec for the handler. + """ + upload_param_name, _annotation = resolve_upload_chunk_handler_param(handler) + return self._as_event_spec( + handler, + client_handler_name=UPLOAD_FILES_CLIENT_HANDLER, + upload_param_name=upload_param_name, + ) + + +# Alias for rx.upload_files_chunk +upload_files_chunk = UploadFilesChunk + + # Special server-side events. def server_side(name: str, sig: inspect.Signature, **kwargs) -> EventSpec: """A server-side event. 
@@ -2313,6 +2472,7 @@ class EventNamespace: # File Upload FileUpload = FileUpload + UploadFilesChunk = UploadFilesChunk # Type Aliases EventType = EventType @@ -2326,10 +2486,15 @@ class EventNamespace: _EVENT_FIELDS = _EVENT_FIELDS FORM_DATA = FORM_DATA upload_files = upload_files + upload_files_chunk = upload_files_chunk stop_propagation = stop_propagation prevent_default = prevent_default # Private/Internal Functions + resolve_upload_handler_param = staticmethod(resolve_upload_handler_param) + resolve_upload_chunk_handler_param = staticmethod( + resolve_upload_chunk_handler_param + ) _values_returned_from_event = staticmethod(_values_returned_from_event) _check_event_args_subclass_of_callback = staticmethod( _check_event_args_subclass_of_callback diff --git a/reflex/utils/pyi_generator.py b/reflex/utils/pyi_generator.py index 3dae7ff4c62..93a7bae6b18 100644 --- a/reflex/utils/pyi_generator.py +++ b/reflex/utils/pyi_generator.py @@ -536,13 +536,48 @@ def _extract_class_props_as_ast_nodes( return kwargs -def type_to_ast(typ: Any, cls: type) -> ast.expr: +def _get_visible_type_name( + typ: Any, type_hint_globals: Mapping[str, Any] | None +) -> str | None: + """Get a visible identifier for a type in the current module. + + Args: + typ: The type annotation to resolve. + type_hint_globals: The globals visible in the current module. + + Returns: + The visible identifier if one exists, otherwise None. + """ + if type_hint_globals is None: + return None + + type_name = getattr(typ, "__name__", None) + if ( + type_name is not None + and type_name in type_hint_globals + and type_hint_globals[type_name] is typ + ): + return type_name + + for name, value in type_hint_globals.items(): + if name.isidentifier() and value is typ: + return name + + return None + + +def type_to_ast( + typ: Any, + cls: type, + type_hint_globals: Mapping[str, Any] | None = None, +) -> ast.expr: """Converts any type annotation into its AST representation. Handles nested generic types, unions, etc. 
Args: typ: The type annotation to convert. cls: The class where the type annotation is used. + type_hint_globals: The globals visible where the annotation is used. Returns: The AST representation of the type annotation. @@ -573,6 +608,8 @@ def type_to_ast(typ: Any, cls: type) -> ast.expr: if all(a == b for a, b in zipped) and len(typ_parts) == len(cls_parts): return ast.Name(id=typ.__name__) + if visible_name := _get_visible_type_name(typ, type_hint_globals): + return ast.Name(id=visible_name) if ( typ.__module__ in DEFAULT_IMPORTS and typ.__name__ in DEFAULT_IMPORTS[typ.__module__] @@ -595,7 +632,7 @@ def type_to_ast(typ: Any, cls: type) -> ast.expr: return ast.Name(id=base_name) # Convert all type arguments recursively - arg_nodes = [type_to_ast(arg, cls) for arg in args] + arg_nodes = [type_to_ast(arg, cls, type_hint_globals) for arg in args] # Special case for single-argument types (like list[T] or Optional[T]) if len(arg_nodes) == 1: @@ -694,7 +731,10 @@ def figure_out_return_type(annotation: Any): ] # Convert each argument type to its AST representation - type_args = [type_to_ast(arg, cls=clz) for arg in arguments_without_var] + type_args = [ + type_to_ast(arg, cls=clz, type_hint_globals=type_hint_globals) + for arg in arguments_without_var + ] # Get all prefixes of the type arguments all_count_args_type = [ diff --git a/tests/integration/test_upload.py b/tests/integration/test_upload.py index ed4a5456cd1..9af85805a1f 100644 --- a/tests/integration/test_upload.py +++ b/tests/integration/test_upload.py @@ -6,11 +6,13 @@ import time from collections.abc import Generator from pathlib import Path +from typing import Any, cast from urllib.parse import urlsplit import pytest from selenium.webdriver.common.by import By +import reflex as rx from reflex.constants.event import Endpoint from reflex.testing import AppHarness, WebDriver @@ -27,9 +29,12 @@ class UploadState(rx.State): _file_data: dict[str, str] = {} event_order: rx.Field[list[str]] = rx.field([]) 
progress_dicts: rx.Field[list[dict]] = rx.field([]) + stream_progress_dicts: rx.Field[list[dict]] = rx.field([]) disabled: rx.Field[bool] = rx.field(False) large_data: rx.Field[str] = rx.field("") quaternary_names: rx.Field[list[str]] = rx.field([]) + stream_chunk_records: rx.Field[list[str]] = rx.field([]) + stream_completed_files: rx.Field[list[str]] = rx.field([]) @rx.event async def handle_upload(self, files: list[rx.UploadFile]): @@ -57,6 +62,11 @@ def chain_event(self): self.large_data = "" self.event_order.append("chain_event") + @rx.event + def stream_upload_progress(self, progress): + assert progress + self.stream_progress_dicts.append(progress) + @rx.event async def handle_upload_tertiary(self, files: list[rx.UploadFile]): for file in files: @@ -68,6 +78,35 @@ async def handle_upload_tertiary(self, files: list[rx.UploadFile]): async def handle_upload_quaternary(self, files: list[rx.UploadFile]): self.quaternary_names = [file.name for file in files if file.name] + @rx.event(background=True) + async def handle_upload_stream(self, chunk_iter: rx.UploadChunkIterator): + upload_dir = rx.get_upload_dir() / "streaming" + file_handles: dict[str, Any] = {} + + try: + async for chunk in chunk_iter: + path = upload_dir / chunk.filename + path.parent.mkdir(parents=True, exist_ok=True) + + fh = file_handles.get(chunk.filename) + if fh is None: + fh = path.open("r+b") if path.exists() else path.open("wb") + file_handles[chunk.filename] = fh + + fh.seek(chunk.offset) + fh.write(chunk.data) + + async with self: + self.stream_chunk_records.append( + f"{chunk.filename}:{chunk.offset}:{len(chunk.data)}" + ) + finally: + for fh in file_handles.values(): + fh.close() + + async with self: + self.stream_completed_files = sorted(file_handles) + @rx.event def do_download(self): return rx.download(rx.get_upload_url("test.txt")) @@ -188,6 +227,44 @@ def index(): UploadState.quaternary_names.to_string(), id="quaternary_files", ), + rx.heading("Streaming Upload"), + rx.upload.root( + 
rx.vstack( + rx.button("Select File"), + rx.text("Drag and drop files here or click to select files"), + ), + id="streaming", + ), + rx.button( + "Upload", + on_click=UploadState.handle_upload_stream( + rx.upload_files_chunk( # pyright: ignore [reportArgumentType] + upload_id="streaming", + on_upload_progress=UploadState.stream_upload_progress, + ) + ), + id="upload_button_streaming", + ), + rx.box( + rx.foreach( + rx.selected_files("streaming"), + lambda f: rx.text(f, as_="p"), + ), + id="selected_files_streaming", + ), + rx.button( + "Cancel", + on_click=rx.cancel_upload("streaming"), + id="cancel_button_streaming", + ), + rx.text( + UploadState.stream_chunk_records.to_string(), + id="stream_chunk_records", + ), + rx.text( + UploadState.stream_completed_files.to_string(), + id="stream_completed_files", + ), rx.text(UploadState.event_order.to_string(), id="event-order"), ) @@ -487,6 +564,140 @@ async def _progress_dicts(): target_file.unlink() +@pytest.mark.asyncio +async def test_upload_chunk_file(tmp_path, upload_file: AppHarness, driver: WebDriver): + """Submit a streaming upload and check that chunks are processed incrementally.""" + assert upload_file.app_instance is not None + token = poll_for_token(driver, upload_file) + state_name = upload_file.get_state_name("_upload_state") + state_full_name = upload_file.get_full_state_name(["_upload_state"]) + substate_token = f"{token}_{state_full_name}" + + upload_box = driver.find_elements(By.XPATH, "//input[@type='file']")[4] + upload_button = driver.find_element(By.ID, "upload_button_streaming") + selected_files = driver.find_element(By.ID, "selected_files_streaming") + chunk_records_display = driver.find_element(By.ID, "stream_chunk_records") + completed_files_display = driver.find_element(By.ID, "stream_completed_files") + + exp_files = { + "stream1.txt": "ABCD" * 262_144, + "stream2.txt": "WXYZ" * 262_144, + } + for exp_name, exp_contents in exp_files.items(): + target_file = tmp_path / exp_name + 
target_file.write_text(exp_contents) + upload_box.send_keys(str(target_file)) + + await asyncio.sleep(0.2) + + assert [Path(name).name for name in selected_files.text.split("\n")] == [ + Path(name).name for name in exp_files + ] + + upload_button.click() + + AppHarness.expect(lambda: "stream1.txt" in chunk_records_display.text) + + async def _stream_completed(): + state = await upload_file.get_state(substate_token) + return ( + len( + state.substates[state_name].stream_completed_files # pyright: ignore[reportAttributeAccessIssue] + ) + == 2 + ) + + await AppHarness._poll_for_async(_stream_completed) + + state = await upload_file.get_state(substate_token) + substate = cast(Any, state.substates[state_name]) + chunk_records = substate.stream_chunk_records + + assert len(chunk_records) > 2 + assert {Path(record.split(":")[0]).name for record in chunk_records} == { + "stream1.txt", + "stream2.txt", + } + assert substate.stream_completed_files == ["stream1.txt", "stream2.txt"] + + AppHarness.expect( + lambda: ( + "stream1.txt" in completed_files_display.text + and "stream2.txt" in completed_files_display.text + ) + ) + + for exp_name, exp_contents in exp_files.items(): + assert ( + rx.get_upload_dir() / "streaming" / exp_name + ).read_text() == exp_contents + + +@pytest.mark.asyncio +async def test_cancel_upload_chunk( + tmp_path, + upload_file: AppHarness, + driver: WebDriver, +): + """Submit a large streaming upload and cancel it.""" + assert upload_file.app_instance is not None + driver.execute_cdp_cmd("Network.enable", {}) + driver.execute_cdp_cmd( + "Network.emulateNetworkConditions", + { + "offline": False, + "downloadThroughput": 1024 * 1024 / 8, # 1 Mbps + "uploadThroughput": 1024 * 1024 / 8, # 1 Mbps + "latency": 200, # 200ms + }, + ) + token = poll_for_token(driver, upload_file) + state_name = upload_file.get_state_name("_upload_state") + state_full_name = upload_file.get_full_state_name(["_upload_state"]) + substate_token = f"{token}_{state_full_name}" + + 
upload_box = driver.find_elements(By.XPATH, "//input[@type='file']")[4] + upload_button = driver.find_element(By.ID, "upload_button_streaming") + cancel_button = driver.find_element(By.ID, "cancel_button_streaming") + + exp_name = "cancel_stream.txt" + target_file = tmp_path / exp_name + with target_file.open("wb") as f: + f.seek(2 * 1024 * 1024) + f.write(b"0") + + upload_box.send_keys(str(target_file)) + upload_button.click() + await asyncio.sleep(1) + cancel_button.click() + + await asyncio.sleep(12) + + async def _stream_progress_dicts(): + state = await upload_file.get_state(substate_token) + return ( + state.substates[state_name].stream_progress_dicts # pyright: ignore[reportAttributeAccessIssue] + ) + + assert await AppHarness._poll_for_async(_stream_progress_dicts) + + for progress in await _stream_progress_dicts(): + assert progress["progress"] != 1 + + state = await upload_file.get_state(substate_token) + substate = cast(Any, state.substates[state_name]) + assert substate.stream_completed_files == [] + assert substate.stream_chunk_records + + partial_path = rx.get_upload_dir() / "streaming" / exp_name + assert partial_path.exists() + assert partial_path.stat().st_size < target_file.stat().st_size + + target_file.unlink() + if partial_path.exists(): + partial_path.unlink() + + def test_upload_download_file( tmp_path, upload_file: AppHarness, diff --git a/tests/units/components/core/test_upload.py b/tests/units/components/core/test_upload.py index 3b03362d6e4..ddeed335c59 100644 --- a/tests/units/components/core/test_upload.py +++ b/tests/units/components/core/test_upload.py @@ -1,5 +1,8 @@ -from typing import Any +from typing import Any, cast +import pytest + +import reflex as rx from reflex import event from reflex.components.core.upload import ( StyledUpload, @@ -9,7 +12,7 @@ cancel_upload, get_upload_url, ) -from reflex.event import EventSpec +from reflex.event import EventChain, EventHandler, EventSpec from reflex.state import State from 
reflex.vars.base import LiteralVar, Var @@ -33,6 +36,31 @@ def not_drop_handler(self, not_files: Any): not_files: The files dropped. """ + @event + async def upload_alias_handler(self, uploads: list[rx.UploadFile]): + """Handle uploaded files with a non-default parameter name.""" + + +class StreamingUploadStateTest(State): + """Test state for streaming uploads.""" + + @event(background=True) + async def chunk_drop_handler(self, chunk_iter: rx.UploadChunkIterator): + """Handle streamed upload chunks.""" + + @event(background=True) + async def chunk_upload_alias_handler(self, stream: rx.UploadChunkIterator): + """Handle streamed upload chunks with a non-default parameter name.""" + + async def chunk_drop_handler_not_background( + self, chunk_iter: rx.UploadChunkIterator + ): + """Invalid handler used to validate background-task requirement.""" + + @event(background=True) + async def chunk_drop_handler_missing_annotation(self, chunk_iter): + """Invalid handler missing the UploadChunkIterator annotation.""" + def test_cancel_upload(): spec = cancel_upload("foo_id") @@ -48,6 +76,37 @@ def test__on_drop_spec(): assert isinstance(_on_drop_spec(LiteralVar.create([])), tuple) +def test_upload_files_chunk_requires_background(): + with pytest.raises(TypeError) as err: + event.resolve_upload_chunk_handler_param( + cast( + EventHandler, StreamingUploadStateTest.chunk_drop_handler_not_background + ) + ) + + assert ( + err.value.args[0] + == "@rx.event(background=True) is required for upload_files_chunk handler " + f"`{StreamingUploadStateTest.get_full_name()}.chunk_drop_handler_not_background`." 
+ ) + + +def test_upload_files_chunk_requires_iterator_annotation(): + with pytest.raises(ValueError) as err: + event.resolve_upload_chunk_handler_param( + cast( + EventHandler, + StreamingUploadStateTest.chunk_drop_handler_missing_annotation, + ) + ) + + assert ( + err.value.args[0] + == f"`{StreamingUploadStateTest.get_full_name()}.chunk_drop_handler_missing_annotation` " + "handler should have a parameter annotated as rx.UploadChunkIterator" + ) + + def test_upload_create(): up_comp_1 = Upload.create() assert isinstance(up_comp_1, Upload) @@ -83,6 +142,53 @@ def test_upload_create(): assert isinstance(up_comp_4, Upload) assert up_comp_4.is_used + # reset is_used + Upload.is_used = False + + up_comp_5 = Upload.create( + id="foo_id", + on_drop=StreamingUploadStateTest.chunk_drop_handler( + rx.upload_files_chunk(upload_id="foo_id") # pyright: ignore[reportArgumentType] + ), + ) + assert isinstance(up_comp_5, Upload) + assert up_comp_5.is_used + + up_comp_6 = Upload.create( + id="foo_id", + on_drop=StreamingUploadStateTest.chunk_upload_alias_handler( + rx.upload_files_chunk(upload_id="foo_id") # pyright: ignore[reportArgumentType] + ), + ) + assert isinstance(up_comp_6, Upload) + assert up_comp_6.is_used + + +def test_upload_button_handlers_allow_custom_param_names(): + legacy_button = rx.button( + "Upload", + on_click=UploadStateTest.upload_alias_handler( + cast(Any, rx.upload_files(upload_id="foo_id")) + ), + ) + legacy_chain = cast(EventChain, legacy_button.event_triggers["on_click"]) + legacy_event = cast(EventSpec, legacy_chain.events[0]) + legacy_arg_names = [arg[0]._js_expr for arg in legacy_event.args] + assert legacy_event.client_handler_name == "uploadFiles" + assert legacy_arg_names[:3] == ["files", "uploads", "upload_param_name"] + + chunk_button = rx.button( + "Upload", + on_click=StreamingUploadStateTest.chunk_upload_alias_handler( + rx.upload_files_chunk(upload_id="foo_id") # pyright: ignore[reportArgumentType] + ), + ) + chunk_chain = 
cast(EventChain, chunk_button.event_triggers["on_click"]) + chunk_event = cast(EventSpec, chunk_chain.events[0]) + chunk_arg_names = [arg[0]._js_expr for arg in chunk_event.args] + assert chunk_event.client_handler_name == "uploadFiles" + assert chunk_arg_names[:3] == ["files", "stream", "upload_param_name"] + def test_styled_upload_create(): styled_up_comp_1 = StyledUpload.create() diff --git a/tests/units/states/upload.py b/tests/units/states/upload.py index 6c732796a73..1ec847f808a 100644 --- a/tests/units/states/upload.py +++ b/tests/units/states/upload.py @@ -1,6 +1,7 @@ """Test states for upload-related tests.""" from pathlib import Path +from typing import BinaryIO import reflex as rx from reflex.state import BaseState, State @@ -65,6 +66,10 @@ async def multi_handle_upload(self, files: list[rx.UploadFile]): self.img_list.append(file.name) yield + async def upload_alias_handler(self, uploads: list[rx.UploadFile]): + """Handle uploaded files with a non-default parameter name.""" + self.img_list = [f"count:{len(uploads)}"] + @rx.event(background=True) async def bg_upload(self, files: list[rx.UploadFile]): """Background task cannot be upload handler. 
@@ -78,6 +83,66 @@ class FileUploadState(_FileUploadMixin, State): """The base state for uploading a file.""" +class _ChunkUploadMixin(BaseState, mixin=True): + """Common fields and handlers for chunk upload tests.""" + + chunk_records: list[str] + completed_files: list[str] + _tmp_path: Path = Path() + + @rx.event(background=True) + async def chunk_handle_upload(self, chunk_iter: rx.UploadChunkIterator): + """Handle a chunked upload in the background.""" + file_handles: dict[str, BinaryIO] = {} + + try: + async for chunk in chunk_iter: + outfile = self._tmp_path / chunk.filename + outfile.parent.mkdir(parents=True, exist_ok=True) + + fh = file_handles.get(chunk.filename) + if fh is None: + fh = outfile.open("r+b") if outfile.exists() else outfile.open("wb") + file_handles[chunk.filename] = fh + + fh.seek(chunk.offset) + fh.write(chunk.data) + + async with self: + self.chunk_records.append( + f"{chunk.filename}:{chunk.offset}:{len(chunk.data)}:{chunk.content_type}" + ) + finally: + for fh in file_handles.values(): + fh.close() + + async with self: + self.completed_files = sorted(file_handles) + + async def chunk_handle_upload_not_background( + self, chunk_iter: rx.UploadChunkIterator + ): + """Invalid streaming upload handler used for compile-time validation tests.""" + + @rx.event(background=True) + async def chunk_handle_upload_missing_annotation(self, chunk_iter): + """Invalid streaming upload handler missing the iterator annotation.""" + + @rx.event(background=True) + async def chunk_handle_upload_alias(self, stream: rx.UploadChunkIterator): + """Handle streamed upload chunks with a non-default parameter name.""" + chunk_count = 0 + async for _chunk in stream: + chunk_count += 1 + + async with self: + self.completed_files = [f"chunks:{chunk_count}"] + + +class ChunkUploadState(_ChunkUploadMixin, State): + """The base state for streaming chunk uploads.""" + + class FileStateBase1(State): """The base state for a child FileUploadState.""" diff --git 
a/tests/units/test_app.py b/tests/units/test_app.py index 25c71c0d17e..97dec39eb51 100644 --- a/tests/units/test_app.py +++ b/tests/units/test_app.py @@ -10,13 +10,14 @@ from contextlib import nullcontext as does_not_raise from importlib.util import find_spec from pathlib import Path +from types import SimpleNamespace from typing import TYPE_CHECKING, Any, ClassVar from unittest.mock import AsyncMock import pytest from pytest_mock import MockerFixture from starlette.applications import Starlette -from starlette.datastructures import FormData, UploadFile +from starlette.datastructures import FormData, Headers, UploadFile from starlette.responses import StreamingResponse import reflex as rx @@ -57,6 +58,7 @@ from .states import GenState from .states.upload import ( ChildFileUploadState, + ChunkUploadState, FileStateBase1, FileUploadState, GrandChildFileUploadState, @@ -1084,12 +1086,65 @@ async def send(message): # noqa: RUF029 await app.state_manager.close() +@pytest.mark.asyncio +async def test_upload_empty_buffered_request_dispatches_alias_handler( + token: str, + mocker: MockerFixture, +): + """Test that empty uploads still dispatch buffered alias handlers.""" + mocker.patch( + "reflex.state.State.class_subclasses", + {FileUploadState}, + ) + app = App() + app.event_namespace.emit = AsyncMock() # pyright: ignore [reportOptionalMemberAccess] + + async with app.modify_state(_substate_key(token, FileUploadState)) as root_state: + substate = root_state.get_substate(FileUploadState.get_full_name().split(".")) + substate.img_list = [] + + request_mock = unittest.mock.Mock() + request_mock.headers = { + "reflex-client-token": token, + "reflex-event-handler": f"{FileUploadState.get_full_name()}.upload_alias_handler", + } + + async def form(): # noqa: RUF029 + return FormData() + + request_mock.form = form + + upload_fn = upload(app) + streaming_response = await upload_fn(request_mock) + assert isinstance(streaming_response, StreamingResponse) + + updates = [] + async for 
state_update in streaming_response.body_iterator: + updates.append(json.loads(str(state_update))) + + assert updates[-1]["final"] + + if environment.REFLEX_OPLOCK_ENABLED.get(): + await app.state_manager.close() + + state = await app.state_manager.get_state(_substate_key(token, FileUploadState)) + substate = ( + state + if isinstance(state, FileUploadState) + else state.get_substate(FileUploadState.get_full_name().split(".")) + ) + assert isinstance(substate, FileUploadState) + assert substate.img_list == ["count:0"] + + await app.state_manager.close() + + @pytest.mark.asyncio async def test_upload_file_closes_form_on_event_creation_cancellation( token: str, mocker: MockerFixture, ): - """Test that cancellation during upload event creation closes form data.""" + """Test that cancellation before form parsing leaves form data untouched.""" mocker.patch( "reflex.state.State.class_subclasses", {FileUploadState}, @@ -1122,8 +1177,8 @@ async def cancelled_get_state(*_args, **_kwargs): with pytest.raises(asyncio.CancelledError): await upload_fn(request_mock) - assert form_close.await_count == 1 - assert file1.file.closed + assert form_close.await_count == 0 + assert not file1.file.closed await app.state_manager.close() @@ -1271,6 +1326,293 @@ async def form(): # noqa: RUF029 await app.state_manager.close() +def _build_chunk_upload_multipart_body( + boundary: str, + parts: list[tuple[str, str, str, bytes]], +) -> bytes: + """Build a multipart upload body for chunk upload tests. + + Args: + boundary: The multipart boundary string. + parts: Tuples of field name, filename, content type, and payload. + + Returns: + The encoded multipart body bytes. 
+    """
+    body = bytearray()
+    for field_name, filename, content_type, data in parts:
+        body.extend(f"--{boundary}\r\n".encode())
+        body.extend(
+            (
+                f'Content-Disposition: form-data; name="{field_name}"; '
+                f'filename="{filename}"\r\n'
+            ).encode()
+        )
+        body.extend(f"Content-Type: {content_type}\r\n\r\n".encode())
+        body.extend(data)
+        body.extend(b"\r\n")
+    body.extend(f"--{boundary}--\r\n".encode())
+    return bytes(body)
+
+
+def _make_chunk_upload_request(
+    token: str,
+    handler_name: str,
+    body: bytes,
+    *,
+    content_type: str,
+    stream_chunk_size: int = 17,
+):
+    """Create a mocked request for the chunk upload endpoint.
+
+    The mock exposes ``headers``, an empty ``query_params`` mapping and an
+    async ``stream`` callable that yields ``body`` in slices.
+
+    Args:
+        token: Value sent in the ``reflex-client-token`` header.
+        handler_name: Value sent in the ``reflex-event-handler`` header.
+        body: Raw request payload yielded piecewise by ``stream``.
+        content_type: Value sent in the ``content-type`` header.
+        stream_chunk_size: Number of bytes yielded per slice of ``body``.
+
+    Returns:
+        A mocked Starlette request object.
+    """
+    request_mock = unittest.mock.Mock()
+    request_mock.headers = Headers({
+        "content-type": content_type,
+        "reflex-client-token": token,
+        "reflex-event-handler": handler_name,
+    })
+    request_mock.query_params = {}
+
+    async def stream():
+        # Emit the payload in slices of ``stream_chunk_size`` bytes.
+        for index in range(0, len(body), stream_chunk_size):
+            yield body[index : index + stream_chunk_size]
+        # NOTE(review): the nesting of the two statements below was
+        # reconstructed from whitespace-collapsed text; confirm they run once
+        # after the loop. The empty chunk presumably mirrors the end-of-stream
+        # marker of a real request stream.
+        yield b""
+        await asyncio.sleep(0)
+
+    request_mock.stream = stream
+    return request_mock
+
+
+async def _drain_background_tasks(app: App):
+    """Wait for all background tasks associated with an app.
+
+    Args:
+        app: The app whose ``_background_tasks`` are awaited.
+
+    Returns:
+        The gathered background task results. Exceptions raised by a task are
+        returned in place of its result, and an empty list is returned when
+        there are no tasks.
+    """
+    # Snapshot the task collection before awaiting it.
+    tasks = tuple(app._background_tasks)
+    results = await asyncio.gather(*tasks, return_exceptions=True) if tasks else []
+    if environment.REFLEX_OPLOCK_ENABLED.get():
+        # Redis oplocks can keep completed background-task writes in the local
+        # lease cache until the manager is closed.
+        await app.state_manager.close()
+    return results
+
+
+@pytest.mark.asyncio
+async def test_upload_dispatches_chunk_handlers_on_upload_endpoint(
+    tmp_path,
+    token: str,
+    mocker: MockerFixture,
+):
+    """Test that the standard upload endpoint dispatches chunk handlers."""
+    # Limit the registered State subclasses to the chunk upload test state.
+    mocker.patch(
+        "reflex.state.State.class_subclasses",
+        {ChunkUploadState},
+    )
+    app = App()
+    mocker.patch(
+        "reflex.utils.prerequisites.get_and_validate_app",
+        return_value=SimpleNamespace(app=app),
+    )
+    # Stub emit_update so its awaits can be counted at the end of the test.
+    app.event_namespace.emit_update = AsyncMock()  # pyright: ignore [reportOptionalMemberAccess]
+
+    # Reset the state the handler writes to before the upload starts.
+    async with app.modify_state(_substate_key(token, ChunkUploadState)) as root_state:
+        substate = root_state.get_substate(ChunkUploadState.get_full_name().split("."))
+        substate._tmp_path = tmp_path
+        substate.chunk_records = []
+        substate.completed_files = []
+
+    upload_fn = upload(app)
+    boundary = "chunk-upload-on-upload-endpoint-boundary"
+    # stream_chunk_size=1 feeds the multipart body one byte at a time.
+    response = await upload_fn(
+        _make_chunk_upload_request(
+            token,
+            f"{ChunkUploadState.get_full_name()}.chunk_handle_upload",
+            _build_chunk_upload_multipart_body(
+                boundary,
+                [
+                    ("files", "alpha.txt", "text/plain", b"abcde"),
+                    ("files", "beta.txt", "text/plain", b"12345"),
+                ],
+            ),
+            content_type=f"multipart/form-data; boundary={boundary}",
+            stream_chunk_size=1,
+        )
+    )
+
+    assert isinstance(response, StreamingResponse)
+    assert response.status_code == 202
+
+    # The HTTP response is expected to carry one empty, final update only.
+    updates = []
+    async for state_update in response.body_iterator:
+        updates.append(json.loads(str(state_update)))
+    assert updates == [{"delta": {}, "events": [], "final": True}]
+
+    # Every background task must finish without a return value or exception.
+    task_results = await _drain_background_tasks(app)
+    assert all(result is None for result in task_results)
+
+    # get_state may hand back the substate itself or a state containing it.
+    state = await app.state_manager.get_state(_substate_key(token, ChunkUploadState))
+    substate = (
+        state
+        if isinstance(state, ChunkUploadState)
+        else state.get_substate(ChunkUploadState.get_full_name().split("."))
+    )
+    assert isinstance(substate, ChunkUploadState)
+    # Split each record from the right into filename, offset, size and
+    # content type, so a ":" inside the filename would not break parsing.
+    parsed_chunk_records = [
+        (filename, int(offset), int(size), content_type)
+        for filename, offset, size, content_type in (
+            record.rsplit(":", 3) for record in substate.chunk_records
+        )
+    ]
+    assert len(parsed_chunk_records) >= 4
+    assert {filename for filename, *_ in parsed_chunk_records} == {
+        "alpha.txt",
+        "beta.txt",
+    }
+    assert all(
+        content_type == "text/plain" for *_, content_type in parsed_chunk_records
+    )
+    # The chunk sizes recorded per file must add up to its 5-byte payload.
+    assert (
+        sum(
+            size
+            for filename, _offset, size, _content_type in parsed_chunk_records
+            if filename == "alpha.txt"
+        )
+        == 5
+    )
+    assert (
+        sum(
+            size
+            for filename, _offset, size, _content_type in parsed_chunk_records
+            if filename == "beta.txt"
+        )
+        == 5
+    )
+    # The first record belongs to alpha.txt and the last one to beta.txt.
+    assert parsed_chunk_records[0][0] == "alpha.txt"
+    assert parsed_chunk_records[-1][0] == "beta.txt"
+    assert substate.completed_files == ["alpha.txt", "beta.txt"]
+    assert (tmp_path / "alpha.txt").read_bytes() == b"abcde"
+    assert (tmp_path / "beta.txt").read_bytes() == b"12345"
+    assert app.event_namespace.emit_update.await_count >= 1  # pyright: ignore [reportOptionalMemberAccess]
+    assert not app._background_tasks
+
+    await app.state_manager.close()
+
+
+@pytest.mark.asyncio
+async def test_upload_empty_chunk_request_dispatches_alias_handler(
+    token: str,
+    mocker: MockerFixture,
+):
+    """Test that empty uploads still dispatch chunk alias handlers."""
+    # Limit the registered State subclasses to the chunk upload test state.
+    mocker.patch(
+        "reflex.state.State.class_subclasses",
+        {ChunkUploadState},
+    )
+    app = App()
+    mocker.patch(
+        "reflex.utils.prerequisites.get_and_validate_app",
+        return_value=SimpleNamespace(app=app),
+    )
+    app.event_namespace.emit_update = AsyncMock()  # pyright: ignore [reportOptionalMemberAccess]
+
+    # Reset the state the handler writes to before the upload starts.
+    async with app.modify_state(_substate_key(token, ChunkUploadState)) as root_state:
+        substate = root_state.get_substate(ChunkUploadState.get_full_name().split("."))
+        substate.chunk_records = []
+        substate.completed_files = []
+
+    upload_fn = upload(app)
+    boundary = "chunk-upload-empty-alias-boundary"
+    # The multipart body is built from an empty list of parts.
+    response = await upload_fn(
+        _make_chunk_upload_request(
+            token,
+            f"{ChunkUploadState.get_full_name()}.chunk_handle_upload_alias",
+            _build_chunk_upload_multipart_body(boundary, []),
+            content_type=f"multipart/form-data; boundary={boundary}",
+        )
+    )
+
+    assert isinstance(response, StreamingResponse)
+    assert response.status_code == 202
+
+    # The HTTP response is expected to carry one empty, final update only.
+    updates = []
+    async for state_update in response.body_iterator:
+        updates.append(json.loads(str(state_update)))
+    assert updates == [{"delta": {}, "events": [], "final": True}]
+
+    # Every background task must finish without a return value or exception.
+    task_results = await _drain_background_tasks(app)
+    assert all(result is None for result in task_results)
+
+    # get_state may hand back the substate itself or a state containing it.
+    state = await app.state_manager.get_state(_substate_key(token, ChunkUploadState))
+    substate = (
+        state
+        if isinstance(state, ChunkUploadState)
+        else state.get_substate(ChunkUploadState.get_full_name().split("."))
+    )
+    assert isinstance(substate, ChunkUploadState)
+    # No chunks were recorded, yet the handler still reported completion.
+    assert substate.chunk_records == []
+    assert substate.completed_files == ["chunks:0"]
+    assert not app._background_tasks
+
+    await app.state_manager.close()
+
+
+@pytest.mark.asyncio
+async def test_upload_chunk_invalid_offset_returns_400(
+    token: str,
+    mocker: MockerFixture,
+):
+    """Test that a chunk upload without a multipart boundary returns 400.
+
+    Despite the ``invalid_offset`` in the test name, the request exercised
+    here is a ``text/plain`` body, so the expected error detail is
+    ``Missing boundary in multipart.`` and the handler state must stay
+    untouched.
+    """
+    # Limit the registered State subclasses to the chunk upload test state.
+    mocker.patch(
+        "reflex.state.State.class_subclasses",
+        {ChunkUploadState},
+    )
+    app = App()
+    mocker.patch(
+        "reflex.utils.prerequisites.get_and_validate_app",
+        return_value=SimpleNamespace(app=app),
+    )
+    app.event_namespace.emit_update = AsyncMock()  # pyright: ignore [reportOptionalMemberAccess]
+
+    # Reset the state so any handler side effect would be detectable below.
+    async with app.modify_state(_substate_key(token, ChunkUploadState)) as root_state:
+        substate = root_state.get_substate(ChunkUploadState.get_full_name().split("."))
+        substate.chunk_records = []
+        substate.completed_files = []
+
+    upload_fn = upload(app)
+    # A non-multipart content type carries no boundary parameter.
+    response = await upload_fn(
+        _make_chunk_upload_request(
+            token,
+            f"{ChunkUploadState.get_full_name()}.chunk_handle_upload",
+            b"abc",
+            content_type="text/plain",
+        )
+    )
+
+    assert response.status_code == 400
+    assert json.loads(bytes(response.body).decode()) == {
+        "detail": "Missing boundary in multipart."
+    }
+
+    await _drain_background_tasks(app)
+
+    # get_state may hand back the substate itself or a state containing it.
+    state = await app.state_manager.get_state(_substate_key(token, ChunkUploadState))
+    substate = (
+        state
+        if isinstance(state, ChunkUploadState)
+        else state.get_substate(ChunkUploadState.get_full_name().split("."))
+    )
+    assert isinstance(substate, ChunkUploadState)
+    # The rejected request must not have reached the handler.
+    assert substate.chunk_records == []
+    assert substate.completed_files == []
+    assert not app._background_tasks
+
+    await app.state_manager.close()
+
+
 class DynamicState(BaseState):
     """State class for testing dynamic route var.