diff --git a/lib/bumblebee/huggingface/hub.ex b/lib/bumblebee/huggingface/hub.ex index bd2ce305..03b166a3 100644 --- a/lib/bumblebee/huggingface/hub.ex +++ b/lib/bumblebee/huggingface/hub.ex @@ -67,34 +67,39 @@ defmodule Bumblebee.HuggingFace.Hub do metadata_path = Path.join(dir, metadata_filename(url)) - if offline do - case load_json(metadata_path) do - {:ok, %{"etag" => etag}} -> - entry_path = Path.join(dir, entry_filename(url, etag)) - {:ok, entry_path} - - _ -> - {:error, - "could not find file in local cache and outgoing traffic is disabled, url: #{url}"} - end - else - head_result = - if etag = opts[:etag] do - {:ok, etag, url} - else - head_download(url, headers) - end - - with {:ok, etag, download_url} <- head_result do - entry_path = Path.join(dir, entry_filename(url, etag)) - + cond do + offline -> case load_json(metadata_path) do - {:ok, %{"etag" => ^etag}} -> + {:ok, %{"etag" => etag}} -> + entry_path = Path.join(dir, entry_filename(url, etag)) {:ok, entry_path} _ -> - case HTTP.download(download_url, entry_path, headers: headers) - |> finish_request(download_url) do + {:error, + "could not find file in local cache and outgoing traffic is disabled, url: #{url}"} + end + + entry_path = opts[:etag] && cached_path_for_etag(dir, url, opts[:etag]) -> + {:ok, entry_path} + + true -> + with {:ok, etag, download_url, redirect?} <- head_download(url, headers) do + if entry_path = cached_path_for_etag(dir, url, etag) do + {:ok, entry_path} + else + entry_path = Path.join(dir, entry_filename(url, etag)) + + headers = + if redirect? do + List.keydelete(headers, "Authorization", 0) + else + headers + end + + download_url + |> HTTP.download(entry_path, headers: headers) + |> finish_request(download_url) + |> case do :ok -> :ok = store_json(metadata_path, %{"etag" => etag, "url" => url}) {:ok, entry_path} @@ -104,24 +109,45 @@ defmodule Bumblebee.HuggingFace.Hub do File.rm_rf!(entry_path) error end + end end - end + end + end + + defp cached_path_for_etag(dir, url, etag) do + metadata_path = Path.join(dir, metadata_filename(url)) + + case load_json(metadata_path) do + {:ok, %{"etag" => ^etag}} -> + Path.join(dir, entry_filename(url, etag)) + + _ -> + nil end end defp head_download(url, headers) do with {:ok, response} <- HTTP.request(:head, url, follow_redirects: false, headers: headers) - |> finish_request(url), - {:ok, etag} <- fetch_etag(response) do - download_url = - if response.status in 300..399 do - HTTP.get_header(response, "location") + |> finish_request(url) do + if response.status in 300..399 do + location = HTTP.get_header(response, "location") + + # Follow relative redirects + if URI.parse(location).host == nil do + url = + url + |> URI.parse() + |> Map.replace!(:path, location) + |> URI.to_string() + + head_download(url, headers) else - url + with {:ok, etag} <- fetch_etag(response), do: {:ok, etag, location, true} end - - {:ok, etag, download_url} + else + with {:ok, etag} <- fetch_etag(response), do: {:ok, etag, url, false} + end end end diff --git a/test/bumblebee/huggingface/hub_test.exs b/test/bumblebee/huggingface/hub_test.exs index 0a6f2ffc..14f8a6fa 100644 --- a/test/bumblebee/huggingface/hub_test.exs +++ b/test/bumblebee/huggingface/hub_test.exs @@ -115,6 +115,28 @@ defmodule Bumblebee.HuggingFace.HubTest do assert File.read!(path) == <<0, 1>> end + @tag :tmp_dir + test "follows relative redirect before checking etag", %{bypass: bypass, tmp_dir: tmp_dir} do + Bypass.expect_once(bypass, "HEAD", "/repo/file.json", fn conn -> + conn + |> Plug.Conn.put_resp_header("location", "/repo-renamed/file.json") + |> Plug.Conn.resp(307, "") + end) + + Bypass.expect_once(bypass, "HEAD", "/repo-renamed/file.json", fn conn -> + serve_with_etag(conn, ~s/"hash"/, "{}") + end) + + Bypass.expect_once(bypass, "GET", "/repo-renamed/file.json", fn conn -> + serve_with_etag(conn, ~s/"hash"/, "{}") + end) + + url = url(bypass.port) <> "/repo/file.json" + + assert {:ok, path} = Hub.cached_download(url, cache_dir: tmp_dir, offline: false) + assert File.read!(path) == "{}" + end + @tag :tmp_dir test "returns an error on missing etag header", %{bypass: bypass, tmp_dir: tmp_dir} do Bypass.expect_once(bypass, "HEAD", "/file.json", fn conn -> @@ -183,6 +205,62 @@ defmodule Bumblebee.HuggingFace.HubTest do "could not find file in local cache and outgoing traffic is disabled, url: " <> _} = Hub.cached_download(url, cache_dir: tmp_dir, offline: true) end + + @tag :tmp_dir + test "includes authorization header when :auth_token is given", + %{bypass: bypass, tmp_dir: tmp_dir} do + Bypass.expect_once(bypass, "HEAD", "/file.json", fn conn -> + assert {"authorization", "Bearer token"} in conn.req_headers + + serve_with_etag(conn, ~s/"hash"/, "") + end) + + Bypass.expect_once(bypass, "GET", "/file.json", fn conn -> + assert {"authorization", "Bearer token"} in conn.req_headers + + serve_with_etag(conn, ~s/"hash"/, "{}") + end) + + url = url(bypass.port) <> "/file.json" + + assert {:ok, _path} = + Hub.cached_download(url, auth_token: "token", cache_dir: tmp_dir, offline: false) + end + + @tag :tmp_dir + test "skips authentication header for redirected requests", + %{bypass: bypass, tmp_dir: tmp_dir} do + # Context: HuggingFace Hub returns redirects for files stored + # in LFS and the redirect location already has S3 signature in + # URL params. If the location points to S3 directly, which is + # the case within HF spaces, passing the Authorization header + # leads to failure, presumably because it takes precedence over + # the params signature + + Bypass.expect_once(bypass, "HEAD", "/file.bin", fn conn -> + assert {"authorization", "Bearer token"} in conn.req_headers + + url = Plug.Conn.request_url(conn) + + conn + |> Plug.Conn.put_resp_header("x-linked-etag", ~s/"hash"/) + |> Plug.Conn.put_resp_header("location", url <> "/storage") + |> Plug.Conn.resp(302, "") + end) + + Bypass.expect_once(bypass, "GET", "/file.bin/storage", fn conn -> + assert {"authorization", "Bearer token"} not in conn.req_headers + + conn + |> Plug.Conn.put_resp_header("etag", ~s/"hash"/) + |> Plug.Conn.resp(200, <<0, 1>>) + end) + + url = url(bypass.port) <> "/file.bin" + + assert {:ok, _path} = + Hub.cached_download(url, auth_token: "token", cache_dir: tmp_dir, offline: false) + end end defp url(port), do: "http://localhost:#{port}"