Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 42 additions & 9 deletions dash/backends/_fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,11 @@
) from _err

from dash.fingerprint import check_fingerprint
from dash import _validate
from dash import _validate, get_app
from dash.exceptions import PreventUpdate
from .base_server import BaseDashServer, RequestAdapter, ResponseAdapter
from ._utils import format_traceback_html
import traceback

if TYPE_CHECKING: # pragma: no cover - typing only
from dash import Dash
Expand Down Expand Up @@ -122,8 +123,12 @@ async def _initialize_dev_tools(self) -> None:
self.dash_app.enable_dev_tools(**config, first_run=False)
self._dev_tools_initialized = True

def _setup_timing(self, request: Request) -> None:
async def _setup_timing(self, request: Request) -> None:
"""Set up timing information for the request."""
try:
request.state.json_body = await request.json() if request.headers.get("content-type", "").startswith("application/json") else None
except:
request.state.json_body = None
if self.enable_timing:
request.state.timing_information = {
"__dash_server": {"dur": time.time(), "desc": None}
Expand Down Expand Up @@ -179,6 +184,12 @@ async def _handle_error(
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
# Handle lifespan events (startup/shutdown)
if scope["type"] == "lifespan":
try:
dash_app = get_app()
dash_app.backend._setup_catchall()
except:
print("Error during catch-all setup:")
print(traceback.format_exc())
await self._initialize_dev_tools()
await self.app(scope, receive, send)
return
Expand All @@ -193,7 +204,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
token = set_current_request(request)

try:
self._setup_timing(request)
await self._setup_timing(request)
await self._run_before_hooks()

await self.app(scope, receive, send)
Expand Down Expand Up @@ -275,11 +286,24 @@ async def index(_request: Request):
dash_app._add_url("", index, methods=["GET"])

def setup_catchall(self, dash_app: Dash):
async def catchall(_request: Request):
return Response(content=dash_app.index(), media_type="text/html")
'''This is needed to ensure that all routes are handled by FastAPI
and passed through the middleware, which is necessary for features like authentication
and timing to work correctly on all routes. FastAPI will match this catch-all route
for any path that isn't matched by a more specific route, allowing the middleware to
process the request and then return the appropriate response (e.g., 404 if no Dash route matches).'''

# pylint: disable=protected-access
dash_app._add_url("{path:path}", catchall, methods=["GET"])

def _setup_catchall(self):
try:
print("Setting up catch-all route for unmatched paths")
dash_app = get_app()
async def catchall(_request: Request):
return Response(content=dash_app.index(), media_type="text/html")

# pylint: disable=protected-access
self.add_url_rule("{path:path}", catchall, methods=["GET"])
except:
print(traceback.format_exc())

def add_url_rule(
self,
Expand All @@ -289,6 +313,7 @@ def add_url_rule(
methods: list[str] | None = None,
include_in_schema: bool = False,
):
print(f"Adding URL rule: {rule} -> {view_func} (endpoint: {endpoint}, methods: {methods})")
if rule == "":
rule = "/"
if isinstance(view_func, str):
Expand Down Expand Up @@ -481,7 +506,7 @@ def add_redirect_rule(self, app, fullname, path):
def serve_callback(self, dash_app: Dash):
async def _dispatch(request: Request):
# pylint: disable=protected-access
body = await request.json()
body = self.request_adapter().get_json()
cb_ctx = dash_app._initialize_context(
body
) # pylint: disable=protected-access
Expand Down Expand Up @@ -641,5 +666,13 @@ def origin(self):
def path(self):
return self._request.url.path

async def _get_json(self, request: Request=None):
req = self._request
if not hasattr(req.state, "json_body"):
req.state.json_body = await request.json()
return req.state.json_body

def get_json(self):
return asyncio.run(self._request.json())
if not hasattr(self, "_request") or self._request is None:
self._request = get_current_request()
return self._request.state.json_body
Loading