diff --git a/src/PlatformEngines.jl b/src/PlatformEngines.jl index 387b8869cb..3a3de4d6d3 100644 --- a/src/PlatformEngines.jl +++ b/src/PlatformEngines.jl @@ -4,6 +4,7 @@ module PlatformEngines using SHA, Logging +import ...Pkg: TOML, pkg_server, depots1 export probe_platform_engines!, parse_7z_list, parse_tar_list, verify, download_verify, unpack, package, download_verify_unpack, @@ -588,12 +589,109 @@ function parse_tar_list(output::AbstractString) return Sys.iswindows() ? replace.(lines, ['/' => '\\']) : lines end +is_secure_url(url::AbstractString) = + occursin(r"^(https://|\w+://(127\.0\.0\.1|localhost)(:\d+)?($|/))"i, url) + +function get_auth_header(url::AbstractString; verbose::Bool = false) + server = pkg_server() + server === nothing && return + startswith(url, server) || return + # find and parse auth file + m = match(r"^(\w+)://([^\\/]+)$", server) + if m === nothing + @warn "malformed Pkg server value" server=server + return + end + proto, host = m.captures + auth_file = joinpath(depots1(), "servers", host, "auth.toml") + isfile(auth_file) || return + # TODO: check for insecure auth file permissions + if !is_secure_url(url) + @warn "refusing to send auth info over insecure connection" url=url + return + end + # parse the auth file + auth_info = try + TOML.parsefile(auth_file) + catch err + @error "malformed auth file" file=auth_file err=err + return + end + # check for an auth token + if !haskey(auth_info, "access_token") + @warn "auth file without access_token field" file=auth_file + return + end + auth_header = "Authorization: Bearer $(auth_info["access_token"])" + # handle token expiration and refresh + expires_at = Inf + if haskey(auth_info, "expires_at") + expires_at = min(expires_at, auth_info["expires_at"]::Integer) + end + if haskey(auth_info, "expires_in") + expires_at = min(expires_at, mtime(auth_file) + auth_info["expires_in"]::Integer) + end + # if token is good until ten minutes from now, use it + time_now = time() + if expires_at ≥ time_now + 10*60 # ten minutes + return auth_header + end + if !haskey(auth_info, "refresh_url") || !haskey(auth_info, "refresh_token") + if expires_at ≤ time_now + @warn "expired auth without refresh keys" file=auth_file + end + # try it anyway since we can't refresh + return auth_header + end + refresh_url = auth_info["refresh_url"] + if !is_secure_url(refresh_url) + @warn "ignoring insecure auth refresh URL" url=refresh_url + return auth_header + end + verbose && @info "Refreshing expired auth token..." file=auth_file + tmp = tempname() + refresh_auth = "Authorization: Bearer $(auth_info["refresh_token"])" + try download(refresh_url, tmp, auth_header=refresh_auth, verbose=verbose) + catch err + @warn "token refresh failure" file=auth_file url=refresh_url err=err + rm(tmp, force=true) + return + end + auth_info = try TOML.parsefile(tmp) + catch err + @warn "discarding malformed auth file" url=refresh_url err=err + rm(tmp, force=true) + return auth_header + end + if !haskey(auth_info, "access_token") + if haskey(auth_info, "refresh_token") + auth_info["refresh_token"] = "*"^64 + end + @warn "discarding auth file without access token" auth=auth_info + rm(tmp, force=true) + return auth_header + end + if haskey(auth_info, "expires_in") + expires_in = auth_info["expires_in"] + if expires_in isa Number + expires_at = floor(Int64, time_now + expires_in) + # overwrite expires_at (avoids clock skew issues) + auth_info["expires_at"] = expires_at + end + end + open(tmp, write=true) do io + TOML.print(io, auth_info, sorted=true) + end + mv(tmp, auth_file, force=true) + return "Authorization: Bearer $(auth_info["access_token"])" +end + """ download( url::AbstractString, dest::AbstractString; - headers::Vector{Pair{String}} = Pair{String}[], verbose::Bool = false, + auth_header::Union{AbstractString, Nothing} = nothing, ) Download file located at `url`, store it at `dest`, continuing if `dest` @@ -602,11 +700,17 @@ already exists and the server and download engine support it. function download( url::AbstractString, dest::AbstractString; - headers::Vector{Pair{String}} = Pair{String}[], verbose::Bool = false, + auth_header::Union{AbstractString, Nothing} = nothing, ) - hdrs = String["$key: $val" for (key, val) in headers] - download_cmd = gen_download_cmd(url, dest, hdrs...) + if auth_header === nothing + auth_header = get_auth_header(url, verbose=verbose) + end + if auth_header === nothing + download_cmd = gen_download_cmd(url, dest) + else + download_cmd = gen_download_cmd(url, dest, auth_header) + end if verbose @info("Downloading $(url) to $(dest)...") end @@ -625,7 +729,6 @@ end url::AbstractString, hash::Union{AbstractString, Nothing}, dest::AbstractString; - headers::Vector{Pair{String}} = Pair{String}[], verbose::Bool = false, force::Bool = false, quiet_download::Bool = true, @@ -652,7 +755,6 @@ function download_verify( url::AbstractString, hash::Union{AbstractString, Nothing}, dest::AbstractString; - headers::Vector{Pair{String}} = Pair{String}[], verbose::Bool = false, force::Bool = false, quiet_download::Bool = true, @@ -834,7 +936,6 @@ end url::AbstractString, hash::Union{AbstractString, Nothing}, dest::AbstractString; - headers::Vector{Pair{String}} = Pair{String}[], tarball_path = nothing, ignore_existence::Bool = false, force::Bool = false, @@ -868,7 +969,6 @@ function download_verify_unpack( url::AbstractString, hash::Union{AbstractString, Nothing}, dest::AbstractString; - headers::Vector{Pair{String}} = Pair{String}[], tarball_path = nothing, ignore_existence::Bool = false, force::Bool = false,