Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ static = HashedStatic("static")
app.mount("/static", static)

# Wrap any ASGI app to rewrite static paths in HTML responses
app = StaticRewriteMiddleware(app, static=static)
StaticRewriteMiddleware(app, static=static)

# In templates, resolve cache-busted URLs:
static.url("styles.css") # /static/styles.a1b2c3d4.css
Expand Down
2 changes: 1 addition & 1 deletion docs/usage.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ static = HashedStatic("static")
app.mount("/static", static)

# Wrap the app to rewrite static paths in HTML responses:
app = StaticRewriteMiddleware(app, static=static)
StaticRewriteMiddleware(app, static=static)
```

`HashedStatic` hashes every file in the directory at startup. When a browser requests the hashed filename, it gets an immutable cache header. When it requests the original filename, the file is served without aggressive caching.
Expand Down
74 changes: 37 additions & 37 deletions src/staticware/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
static = HashedStatic("static")

# Wrap any ASGI app to rewrite /static/styles.css -> /static/styles.a1b2c3d4.css
app = StaticRewriteMiddleware(your_app, static=static)
StaticRewriteMiddleware(your_app, static=static)

# In templates:
static.url("styles.css") # -> /static/styles.a1b2c3d4.css
Expand All @@ -19,8 +19,8 @@
import hashlib
import mimetypes
import re
from pathlib import Path
from collections.abc import Awaitable, Callable
from pathlib import Path
from typing import Any

# ASGI protocol types — inlined so we depend on nothing.
Expand Down Expand Up @@ -102,7 +102,7 @@ def _hash_files(self) -> None:

self.file_map[relative] = hashed
self._reverse[hashed] = relative
self._etags[relative] = f'"{hash_val}"'.encode('latin-1')
self._etags[relative] = f'"{hash_val}"'.encode("latin-1")

def url(self, path: str) -> str:
"""Return the cache-busted URL for a static file path.
Expand Down Expand Up @@ -165,9 +165,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if hdr_name == b"if-none-match" and hdr_value == etag:
await _send_text(send, 304, b"")
return
await _send_file(
send, file_path, extra_headers=[(b"etag", etag)]
)
await _send_file(send, file_path, extra_headers=[(b"etag", etag)])
else:
await _send_file(send, file_path)
return
Expand All @@ -186,7 +184,7 @@ class StaticRewriteMiddleware:
HTML — no template function needed (though ``static.url()`` is there
if you want it).

app = StaticRewriteMiddleware(app, static=static)
StaticRewriteMiddleware(app, static=static)
"""

def __init__(self, app: ASGIApp, *, static: HashedStatic) -> None:
Expand Down Expand Up @@ -225,9 +223,7 @@ async def send_wrapper(message: dict[str, Any]) -> None:

if message["type"] == "http.response.body":
if response_start is None:
raise RuntimeError(
"http.response.body received before http.response.start"
)
raise RuntimeError("http.response.body received before http.response.start")
if not is_html:
await send(message)
return
Expand All @@ -246,13 +242,9 @@ async def send_wrapper(message: dict[str, Any]) -> None:
pass

if response_start is None:
raise RuntimeError(
"http.response.body received before http.response.start"
)
raise RuntimeError("http.response.body received before http.response.start")
new_headers = [
(k, str(len(full_body)).encode("latin-1"))
if k == b"content-length"
else (k, v)
(k, str(len(full_body)).encode("latin-1")) if k == b"content-length" else (k, v)
for k, v in response_start.get("headers", [])
]
response_start["headers"] = new_headers
Expand Down Expand Up @@ -283,28 +275,36 @@ async def _send_file(
if extra_headers:
headers.extend(extra_headers)

await send({
"type": "http.response.start",
"status": 200,
"headers": headers,
})
await send({
"type": "http.response.body",
"body": content,
})
await send(
{
"type": "http.response.start",
"status": 200,
"headers": headers,
}
)
await send(
{
"type": "http.response.body",
"body": content,
}
)


async def _send_text(send: Send, status: int, body: bytes) -> None:
"""Send a plain-text ASGI response."""
await send({
"type": "http.response.start",
"status": status,
"headers": [
(b"content-type", b"text/plain"),
(b"content-length", str(len(body)).encode("latin-1")),
],
})
await send({
"type": "http.response.body",
"body": body,
})
await send(
{
"type": "http.response.start",
"status": status,
"headers": [
(b"content-type", b"text/plain"),
(b"content-length", str(len(body)).encode("latin-1")),
],
}
)
await send(
{
"type": "http.response.body",
"body": body,
}
)
104 changes: 51 additions & 53 deletions tests/test_staticware.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@

from staticware import HashedStatic, StaticRewriteMiddleware


# ── Helpers ──────────────────────────────────────────────────────────────


Expand Down Expand Up @@ -61,7 +60,6 @@ def expected_hash(content: bytes, length: int = 8) -> str:
# ── HashedStatic: hashing and url() ──────────────────────────────────────



def test_file_map_contains_all_files(static: HashedStatic, static_dir: Path) -> None:
assert "styles.css" in static.file_map
assert "images/logo.png" in static.file_map
Expand Down Expand Up @@ -216,14 +214,16 @@ def make_html_app(html: str):
body = html.encode("utf-8")

async def app(scope: dict, receive: Any, send: Any) -> None:
await send({
"type": "http.response.start",
"status": 200,
"headers": [
(b"content-type", b"text/html; charset=utf-8"),
(b"content-length", str(len(body)).encode("latin-1")),
],
})
await send(
{
"type": "http.response.start",
"status": 200,
"headers": [
(b"content-type", b"text/html; charset=utf-8"),
(b"content-length", str(len(body)).encode("latin-1")),
],
}
)
await send({"type": "http.response.body", "body": body})

return app
Expand All @@ -233,14 +233,16 @@ def make_json_app(data: bytes):
"""Create a dummy ASGI app that returns JSON."""

async def app(scope: dict, receive: Any, send: Any) -> None:
await send({
"type": "http.response.start",
"status": 200,
"headers": [
(b"content-type", b"application/json"),
(b"content-length", str(len(data)).encode("latin-1")),
],
})
await send(
{
"type": "http.response.start",
"status": 200,
"headers": [
(b"content-type", b"application/json"),
(b"content-length", str(len(data)).encode("latin-1")),
],
}
)
await send({"type": "http.response.body", "body": data})

return app
Expand Down Expand Up @@ -318,10 +320,12 @@ async def test_rewrite_raises_runtime_error_on_body_before_start(

async def broken_app(scope: dict, receive: Any, send: Any) -> None:
# Skip http.response.start entirely — straight to body.
await send({
"type": "http.response.body",
"body": b"<html>oops</html>",
})
await send(
{
"type": "http.response.body",
"body": b"<html>oops</html>",
}
)

app = StaticRewriteMiddleware(broken_app, static=static)
with pytest.raises(RuntimeError):
Expand All @@ -335,14 +339,16 @@ async def test_rewrite_streaming_html_response(static: HashedStatic) -> None:

async def streaming_app(scope: dict, receive: Any, send: Any) -> None:
total = len(chunk1) + len(chunk2)
await send({
"type": "http.response.start",
"status": 200,
"headers": [
(b"content-type", b"text/html; charset=utf-8"),
(b"content-length", str(total).encode("latin-1")),
],
})
await send(
{
"type": "http.response.start",
"status": 200,
"headers": [
(b"content-type", b"text/html; charset=utf-8"),
(b"content-length", str(total).encode("latin-1")),
],
}
)
await send({"type": "http.response.body", "body": chunk1, "more_body": True})
await send({"type": "http.response.body", "body": chunk2, "more_body": False})

Expand Down Expand Up @@ -373,14 +379,16 @@ async def test_rewrite_non_utf8_html_passes_through(static: HashedStatic) -> Non
raw_body = b"<html>\x80\x81\x82 not valid utf-8</html>"

async def bad_encoding_app(scope: dict, receive: Any, send: Any) -> None:
await send({
"type": "http.response.start",
"status": 200,
"headers": [
(b"content-type", b"text/html; charset=utf-8"),
(b"content-length", str(len(raw_body)).encode("latin-1")),
],
})
await send(
{
"type": "http.response.start",
"status": 200,
"headers": [
(b"content-type", b"text/html; charset=utf-8"),
(b"content-length", str(len(raw_body)).encode("latin-1")),
],
}
)
await send({"type": "http.response.body", "body": raw_body})

app = StaticRewriteMiddleware(bad_encoding_app, static=static)
Expand All @@ -397,9 +405,7 @@ def make_mount_scope(path: str, *, root_path: str = "") -> dict[str, Any]:
return {"type": "http", "path": path, "root_path": root_path, "method": "GET"}


async def test_serve_with_root_path_scope(
static: HashedStatic, static_dir: Path
) -> None:
async def test_serve_with_root_path_scope(static: HashedStatic, static_dir: Path) -> None:
"""Starlette-style mount: root_path set, path still includes the prefix.

Starlette sets scope["root_path"] = "/static" and leaves
Expand All @@ -414,9 +420,7 @@ async def test_serve_with_root_path_scope(
assert resp.text == "body { color: red; }"


async def test_serve_with_stripped_path(
static: HashedStatic, static_dir: Path
) -> None:
async def test_serve_with_stripped_path(static: HashedStatic, static_dir: Path) -> None:
"""Litestar-style mount: framework strips the prefix from scope["path"].

The sub-app sees scope["root_path"] = "/static" and
Expand Down Expand Up @@ -466,18 +470,14 @@ async def test_serve_with_mismatched_mount_and_prefix(static_dir: Path) -> None:
# ── HashedStatic: ETag and conditional requests ──────────────────────


def make_scope_with_headers(
path: str, headers: list[tuple[bytes, bytes]] | None = None
) -> dict[str, Any]:
def make_scope_with_headers(path: str, headers: list[tuple[bytes, bytes]] | None = None) -> dict[str, Any]:
scope: dict[str, Any] = {"type": "http", "path": path, "method": "GET"}
if headers:
scope["headers"] = headers
return scope


async def test_etag_on_unhashed_response(
static: HashedStatic, static_dir: Path
) -> None:
async def test_etag_on_unhashed_response(static: HashedStatic, static_dir: Path) -> None:
"""Original filename response includes an ETag header with the content hash."""
resp = ResponseCollector()
await static(make_scope("/static/styles.css"), receive, resp)
Expand All @@ -489,9 +489,7 @@ async def test_etag_on_unhashed_response(
assert resp.headers[b"etag"] == f'"{h}"'.encode("latin-1")


async def test_conditional_request_returns_304(
static: HashedStatic, static_dir: Path
) -> None:
async def test_conditional_request_returns_304(static: HashedStatic, static_dir: Path) -> None:
"""If-None-Match with matching ETag returns 304 and empty body."""
css_content = (static_dir / "styles.css").read_bytes()
h = expected_hash(css_content)
Expand Down