diff --git a/litestar/static_files/base.py b/litestar/static_files/base.py index 12f197f987..9827697933 100644 --- a/litestar/static_files/base.py +++ b/litestar/static_files/base.py @@ -22,7 +22,7 @@ class StaticFiles: """ASGI App that handles file sending.""" - __slots__ = ("is_html_mode", "directories", "adapter", "send_as_attachment") + __slots__ = ("is_html_mode", "directories", "adapter", "send_as_attachment", "headers") def __init__( self, @@ -31,6 +31,7 @@ def __init__( file_system: FileSystemProtocol, send_as_attachment: bool = False, resolve_symlinks: bool = True, + headers: dict[str, str] | None = None, ) -> None: """Initialize the Application. @@ -41,11 +42,13 @@ def __init__( send_as_attachment: Whether to send the file with a ``content-disposition`` header of ``attachment`` or ``inline`` resolve_symlinks: Resolve symlinks to the directories + headers: Headers that will be sent with every response. """ self.adapter = FileSystemAdapter(file_system) self.directories = tuple(Path(p).resolve() if resolve_symlinks else Path(p) for p in directories) self.is_html_mode = is_html_mode self.send_as_attachment = send_as_attachment + self.headers = headers async def get_fs_info( self, directories: Sequence[PathType], file_path: PathType @@ -111,6 +114,7 @@ async def handle(self, path: str, is_head_response: bool) -> ASGIFileResponse: filename=filename, content_disposition_type=content_disposition_type, is_head_response=is_head_response, + headers=self.headers, ) if self.is_html_mode: @@ -129,6 +133,7 @@ async def handle(self, path: str, is_head_response: bool) -> ASGIFileResponse: status_code=HTTP_404_NOT_FOUND, content_disposition_type=content_disposition_type, is_head_response=is_head_response, + headers=self.headers, ) raise NotFoundException( diff --git a/litestar/static_files/config.py b/litestar/static_files/config.py index cf34a9f2b3..22b6620aa4 100644 --- a/litestar/static_files/config.py +++ b/litestar/static_files/config.py @@ -162,12 +162,17 @@ def create_static_files_router( _validate_config(path=path, directories=directories, file_system=file_system) path = normalize_path(path) + headers = None + if cache_control: + headers = {cache_control.HEADER_NAME: cache_control.to_header()} + static_files = StaticFiles( is_html_mode=html_mode, directories=directories, file_system=file_system, send_as_attachment=send_as_attachment, resolve_symlinks=resolve_symlinks, + headers=headers, ) @get("{file_path:path}", name=name) diff --git a/tests/unit/test_static_files/test_create_static_router.py b/tests/unit/test_static_files/test_create_static_router.py index 08cc71653a..ffab6383ca 100644 --- a/tests/unit/test_static_files/test_create_static_router.py +++ b/tests/unit/test_static_files/test_create_static_router.py @@ -1,4 +1,7 @@ -from typing import Any +from pathlib import Path +from typing import Any, Optional + +import pytest from litestar import Litestar, Request, Response, Router from litestar.connection import ASGIConnection @@ -6,6 +9,8 @@ from litestar.exceptions import ValidationException from litestar.handlers import BaseRouteHandler from litestar.static_files import create_static_files_router +from litestar.status_codes import HTTP_200_OK +from litestar.testing.helpers import create_test_client def test_route_reverse() -> None: @@ -71,3 +76,21 @@ class MyRouter(Router): router = create_static_files_router("/", directories=["some"], router_class=MyRouter) assert isinstance(router, MyRouter) + + +@pytest.mark.parametrize("cache_control", (None, CacheControlHeader(max_age=3600))) +def test_cache_control(tmp_path: Path, cache_control: Optional[CacheControlHeader]) -> None: + static_dir = tmp_path / "foo" + static_dir.mkdir() + static_dir.joinpath("test.txt").write_text("hello") + + router = create_static_files_router("/static", [static_dir], name="static", cache_control=cache_control) + + with create_test_client([router]) as client: + response = client.get("static/test.txt") + + assert response.status_code == HTTP_200_OK + if cache_control is not None: + assert response.headers["cache-control"] == cache_control.to_header() + else: + assert "cache-control" not in response.headers