Skip to content

Commit

Permalink
fix: pass cache control header for static files (#3131)
Browse files Browse the repository at this point in the history
* fix: pass cache control header for static files

* refactor: pass resolved headers to StaticFile
  • Loading branch information
guacs authored Feb 27, 2024
1 parent 19f4f04 commit 1ff7f1e
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 2 deletions.
7 changes: 6 additions & 1 deletion litestar/static_files/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
class StaticFiles:
"""ASGI App that handles file sending."""

__slots__ = ("is_html_mode", "directories", "adapter", "send_as_attachment")
__slots__ = ("is_html_mode", "directories", "adapter", "send_as_attachment", "headers")

def __init__(
self,
Expand All @@ -31,6 +31,7 @@ def __init__(
file_system: FileSystemProtocol,
send_as_attachment: bool = False,
resolve_symlinks: bool = True,
headers: dict[str, str] | None = None,
) -> None:
"""Initialize the Application.
Expand All @@ -41,11 +42,13 @@ def __init__(
send_as_attachment: Whether to send the file with a ``content-disposition`` header of
``attachment`` or ``inline``
resolve_symlinks: Resolve symlinks to the directories
headers: Headers that will be sent with every response.
"""
self.adapter = FileSystemAdapter(file_system)
self.directories = tuple(Path(p).resolve() if resolve_symlinks else Path(p) for p in directories)
self.is_html_mode = is_html_mode
self.send_as_attachment = send_as_attachment
self.headers = headers

async def get_fs_info(
self, directories: Sequence[PathType], file_path: PathType
Expand Down Expand Up @@ -111,6 +114,7 @@ async def handle(self, path: str, is_head_response: bool) -> ASGIFileResponse:
filename=filename,
content_disposition_type=content_disposition_type,
is_head_response=is_head_response,
headers=self.headers,
)

if self.is_html_mode:
Expand All @@ -129,6 +133,7 @@ async def handle(self, path: str, is_head_response: bool) -> ASGIFileResponse:
status_code=HTTP_404_NOT_FOUND,
content_disposition_type=content_disposition_type,
is_head_response=is_head_response,
headers=self.headers,
)

raise NotFoundException(
Expand Down
5 changes: 5 additions & 0 deletions litestar/static_files/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,12 +162,17 @@ def create_static_files_router(
_validate_config(path=path, directories=directories, file_system=file_system)
path = normalize_path(path)

headers = None
if cache_control:
headers = {cache_control.HEADER_NAME: cache_control.to_header()}

static_files = StaticFiles(
is_html_mode=html_mode,
directories=directories,
file_system=file_system,
send_as_attachment=send_as_attachment,
resolve_symlinks=resolve_symlinks,
headers=headers,
)

@get("{file_path:path}", name=name)
Expand Down
25 changes: 24 additions & 1 deletion tests/unit/test_static_files/test_create_static_router.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
from typing import Any
from pathlib import Path
from typing import Any, Optional

import pytest

from litestar import Litestar, Request, Response, Router
from litestar.connection import ASGIConnection
from litestar.datastructures import CacheControlHeader
from litestar.exceptions import ValidationException
from litestar.handlers import BaseRouteHandler
from litestar.static_files import create_static_files_router
from litestar.status_codes import HTTP_200_OK
from litestar.testing.helpers import create_test_client


def test_route_reverse() -> None:
Expand Down Expand Up @@ -71,3 +76,21 @@ class MyRouter(Router):

router = create_static_files_router("/", directories=["some"], router_class=MyRouter)
assert isinstance(router, MyRouter)


@pytest.mark.parametrize("cache_control", (None, CacheControlHeader(max_age=3600)))
def test_cache_control(tmp_path: Path, cache_control: Optional[CacheControlHeader]) -> None:
static_dir = tmp_path / "foo"
static_dir.mkdir()
static_dir.joinpath("test.txt").write_text("hello")

router = create_static_files_router("/static", [static_dir], name="static", cache_control=cache_control)

with create_test_client([router]) as client:
response = client.get("static/test.txt")

assert response.status_code == HTTP_200_OK
if cache_control is not None:
assert response.headers["cache-control"] == cache_control.to_header()
else:
assert "cache-control" not in response.headers

0 comments on commit 1ff7f1e

Please sign in to comment.