Skip to content

Commit

Permalink
Handle repository redirects and skip authorization header for LFS
Browse files Browse the repository at this point in the history
  • Loading branch information
jonatanklosko committed Dec 18, 2023
1 parent b947dd2 commit b4b8f53
Show file tree
Hide file tree
Showing 2 changed files with 137 additions and 33 deletions.
92 changes: 59 additions & 33 deletions lib/bumblebee/huggingface/hub.ex
Original file line number Diff line number Diff line change
Expand Up @@ -67,34 +67,39 @@ defmodule Bumblebee.HuggingFace.Hub do

metadata_path = Path.join(dir, metadata_filename(url))

if offline do
case load_json(metadata_path) do
{:ok, %{"etag" => etag}} ->
entry_path = Path.join(dir, entry_filename(url, etag))
{:ok, entry_path}

_ ->
{:error,
"could not find file in local cache and outgoing traffic is disabled, url: #{url}"}
end
else
head_result =
if etag = opts[:etag] do
{:ok, etag, url}
else
head_download(url, headers)
end

with {:ok, etag, download_url} <- head_result do
entry_path = Path.join(dir, entry_filename(url, etag))

cond do
offline ->
case load_json(metadata_path) do
{:ok, %{"etag" => ^etag}} ->
{:ok, %{"etag" => etag}} ->
entry_path = Path.join(dir, entry_filename(url, etag))
{:ok, entry_path}

_ ->
case HTTP.download(download_url, entry_path, headers: headers)
|> finish_request(download_url) do
{:error,
"could not find file in local cache and outgoing traffic is disabled, url: #{url}"}
end

entry_path = opts[:etag] && cached_path_for_etag(dir, url, opts[:etag]) ->
{:ok, entry_path}

true ->
with {:ok, etag, download_url, redirect?} <- head_download(url, headers) do
if entry_path = cached_path_for_etag(dir, url, etag) do
{:ok, entry_path}
else
entry_path = Path.join(dir, entry_filename(url, etag))

headers =
if redirect? do
List.keydelete(headers, "Authorization", 0)
else
headers
end

download_url
|> HTTP.download(entry_path, headers: headers)
|> finish_request(download_url)
|> case do
:ok ->
:ok = store_json(metadata_path, %{"etag" => etag, "url" => url})
{:ok, entry_path}
Expand All @@ -104,24 +109,45 @@ defmodule Bumblebee.HuggingFace.Hub do
File.rm_rf!(entry_path)
error
end
end
end
end
end
end

defp cached_path_for_etag(dir, url, etag) do
metadata_path = Path.join(dir, metadata_filename(url))

case load_json(metadata_path) do
{:ok, %{"etag" => ^etag}} ->
Path.join(dir, entry_filename(url, etag))

_ ->
nil
end
end

defp head_download(url, headers) do
with {:ok, response} <-
HTTP.request(:head, url, follow_redirects: false, headers: headers)
|> finish_request(url),
{:ok, etag} <- fetch_etag(response) do
download_url =
if response.status in 300..399 do
HTTP.get_header(response, "location")
|> finish_request(url) do
if response.status in 300..399 do
location = HTTP.get_header(response, "location")

# Follow relative redirects
if URI.parse(location).host == nil do
url =
url
|> URI.parse()
|> Map.replace!(:path, location)
|> URI.to_string()

head_download(url, headers)
else
url
with {:ok, etag} <- fetch_etag(response), do: {:ok, etag, location, true}
end

{:ok, etag, download_url}
else
with {:ok, etag} <- fetch_etag(response), do: {:ok, etag, url, false}
end
end
end

Expand Down
78 changes: 78 additions & 0 deletions test/bumblebee/huggingface/hub_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,28 @@ defmodule Bumblebee.HuggingFace.HubTest do
assert File.read!(path) == <<0, 1>>
end

@tag :tmp_dir
test "follows relative redirect before checking etag", %{bypass: bypass, tmp_dir: tmp_dir} do
Bypass.expect_once(bypass, "HEAD", "/repo/file.json", fn conn ->
conn
|> Plug.Conn.put_resp_header("location", "/repo-renamed/file.json")
|> Plug.Conn.resp(307, "")
end)

Bypass.expect_once(bypass, "HEAD", "/repo-renamed/file.json", fn conn ->
serve_with_etag(conn, ~s/"hash"/, "{}")
end)

Bypass.expect_once(bypass, "GET", "/repo-renamed/file.json", fn conn ->
serve_with_etag(conn, ~s/"hash"/, "{}")
end)

url = url(bypass.port) <> "/repo/file.json"

assert {:ok, path} = Hub.cached_download(url, cache_dir: tmp_dir, offline: false)
assert File.read!(path) == "{}"
end

@tag :tmp_dir
test "returns an error on missing etag header", %{bypass: bypass, tmp_dir: tmp_dir} do
Bypass.expect_once(bypass, "HEAD", "/file.json", fn conn ->
Expand Down Expand Up @@ -183,6 +205,62 @@ defmodule Bumblebee.HuggingFace.HubTest do
"could not find file in local cache and outgoing traffic is disabled, url: " <> _} =
Hub.cached_download(url, cache_dir: tmp_dir, offline: true)
end

@tag :tmp_dir
test "includes authorization header when :auth_token is given",
%{bypass: bypass, tmp_dir: tmp_dir} do
Bypass.expect_once(bypass, "HEAD", "/file.json", fn conn ->
assert {"authorization", "Bearer token"} in conn.req_headers

serve_with_etag(conn, ~s/"hash"/, "")
end)

Bypass.expect_once(bypass, "GET", "/file.json", fn conn ->
assert {"authorization", "Bearer token"} in conn.req_headers

serve_with_etag(conn, ~s/"hash"/, "{}")
end)

url = url(bypass.port) <> "/file.json"

assert {:ok, _path} =
Hub.cached_download(url, auth_token: "token", cache_dir: tmp_dir, offline: false)
end

@tag :tmp_dir
test "skips authentication header for redirected requests",
%{bypass: bypass, tmp_dir: tmp_dir} do
# Context: HuggingFace Hub returns redirects for files stored
# in LFS and the redirect location already has S3 signature in
# URL params. If the location points to S3 directly, which is
# the case within HF spaces, passing the Authorization header
# leads to failure, presumably because it takes precedence over
# the params signature

Bypass.expect_once(bypass, "HEAD", "/file.bin", fn conn ->
assert {"authorization", "Bearer token"} in conn.req_headers

url = Plug.Conn.request_url(conn)

conn
|> Plug.Conn.put_resp_header("x-linked-etag", ~s/"hash"/)
|> Plug.Conn.put_resp_header("location", url <> "/storage")
|> Plug.Conn.resp(302, "")
end)

Bypass.expect_once(bypass, "GET", "/file.bin/storage", fn conn ->
assert {"authorization", "Bearer token"} not in conn.req_headers

conn
|> Plug.Conn.put_resp_header("etag", ~s/"hash"/)
|> Plug.Conn.resp(200, <<0, 1>>)
end)

url = url(bypass.port) <> "/file.bin"

assert {:ok, _path} =
Hub.cached_download(url, auth_token: "token", cache_dir: tmp_dir, offline: false)
end
end

defp url(port), do: "http://localhost:#{port}"
Expand Down

0 comments on commit b4b8f53

Please sign in to comment.