Skip to content

Commit

Permalink
[remote/downloader] Migrate Downloader to take Credentials (#16732)
Browse files Browse the repository at this point in the history
Progress on #15856

Closes #16601.

PiperOrigin-RevId: 485837451
Change-Id: I785882d0ff3e171dcaee6aa6f628bca9232c730a
  • Loading branch information
Yannic authored Nov 10, 2022
1 parent f589512 commit 38c5019
Show file tree
Hide file tree
Showing 12 changed files with 111 additions and 46 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
// Copyright 2022 The Bazel Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.google.devtools.build.lib.authandtls;

import com.google.auth.Credentials;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableMap;
import java.net.URI;
import java.util.List;
import java.util.Map;

/** Implementation of {@link Credentials} which provides a static set of credentials. */
public final class StaticCredentials extends Credentials {
public static final StaticCredentials EMPTY = new StaticCredentials(ImmutableMap.of());

private final ImmutableMap<URI, Map<String, List<String>>> credentials;

public StaticCredentials(Map<URI, Map<String, List<String>>> credentials) {
Preconditions.checkNotNull(credentials);

this.credentials = ImmutableMap.copyOf(credentials);
}

@Override
public String getAuthenticationType() {
return "static";
}

@Override
public Map<String, List<String>> getRequestMetadata(URI uri) {
Preconditions.checkNotNull(uri);

return credentials.getOrDefault(uri, ImmutableMap.of());
}

@Override
public boolean hasRequestMetadata() {
return true;
}

@Override
public boolean hasRequestMetadataOnly() {
return true;
}

@Override
public void refresh() {
// Can't refresh static credentials.
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@

package com.google.devtools.build.lib.bazel.repository.downloader;

import com.google.auth.Credentials;
import com.google.common.base.Optional;
import com.google.devtools.build.lib.events.ExtendedEventHandler;
import com.google.devtools.build.lib.vfs.Path;
import java.io.IOException;
import java.net.URI;
import java.net.URL;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -47,7 +47,7 @@ public void setDelegate(@Nullable Downloader delegate) {
@Override
public void download(
List<URL> urls,
Map<URI, Map<String, List<String>>> authHeaders,
Credentials credentials,
Optional<Checksum> checksum,
String canonicalId,
Path destination,
Expand All @@ -60,6 +60,6 @@ public void download(
downloader = delegate;
}
downloader.download(
urls, authHeaders, checksum, canonicalId, destination, eventHandler, clientEnv, type);
urls, credentials, checksum, canonicalId, destination, eventHandler, clientEnv, type);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.devtools.build.lib.authandtls.StaticCredentials;
import com.google.devtools.build.lib.bazel.repository.cache.RepositoryCache;
import com.google.devtools.build.lib.bazel.repository.cache.RepositoryCache.KeyType;
import com.google.devtools.build.lib.bazel.repository.cache.RepositoryCacheHitEvent;
Expand Down Expand Up @@ -256,7 +257,7 @@ public Path download(
try {
downloader.download(
rewrittenUrls,
rewrittenAuthHeaders,
new StaticCredentials(rewrittenAuthHeaders),
checksum,
canonicalId,
destination,
Expand Down Expand Up @@ -337,7 +338,7 @@ public byte[] downloadAndReadOneUrl(
for (int attempt = 0; attempt <= retries; ++attempt) {
try {
return httpDownloader.downloadAndReadOneUrl(
rewrittenUrls.get(0), authHeaders, eventHandler, clientEnv);
rewrittenUrls.get(0), new StaticCredentials(authHeaders), eventHandler, clientEnv);
} catch (ContentLengthMismatchException e) {
if (attempt == retries) {
throw e;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@

package com.google.devtools.build.lib.bazel.repository.downloader;

import com.google.auth.Credentials;
import com.google.common.base.Optional;
import com.google.devtools.build.lib.events.ExtendedEventHandler;
import com.google.devtools.build.lib.vfs.Path;
import java.io.IOException;
import java.net.URI;
import java.net.URL;
import java.util.List;
import java.util.Map;
Expand All @@ -33,7 +33,7 @@ public interface Downloader {
* caller is responsible for cleaning up outputs of failed downloads.
*
* @param urls list of mirror URLs with identical content
* @param authHeaders map of authentication headers per URL
* @param credentials credentials to use when connecting to URLs
* @param checksum valid checksum which is checked, or absent to disable
* @param output path to the destination file to write
* @param type extension, e.g. "tar.gz" to force on downloaded filename, or empty to not do this
Expand All @@ -42,7 +42,7 @@ public interface Downloader {
*/
void download(
List<URL> urls,
Map<URI, Map<String, List<String>>> authHeaders,
Credentials credentials,
Optional<Checksum> checksum,
String canonicalId,
Path output,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,21 @@

package com.google.devtools.build.lib.bazel.repository.downloader;

import com.google.auth.Credentials;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Function;
import com.google.common.base.Optional;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.devtools.build.lib.analysis.BlazeVersionInfo;
import com.google.devtools.build.lib.authandtls.StaticCredentials;
import com.google.devtools.build.lib.concurrent.ThreadSafety.ThreadSafe;
import com.google.devtools.build.lib.events.Event;
import com.google.devtools.build.lib.events.EventHandler;
import java.io.IOException;
import java.io.InputStream;
import java.io.InterruptedIOException;
import java.net.URI;
import java.net.URISyntaxException;
import java.net.URL;
import java.net.URLConnection;
Expand Down Expand Up @@ -74,7 +75,7 @@ final class HttpConnectorMultiplexer {
}

public HttpStream connect(URL url, Optional<Checksum> checksum) throws IOException {
return connect(url, checksum, ImmutableMap.of(), Optional.absent());
return connect(url, checksum, StaticCredentials.EMPTY, Optional.absent());
}

/**
Expand All @@ -87,25 +88,22 @@ public HttpStream connect(URL url, Optional<Checksum> checksum) throws IOExcepti
*
* @param url the URL to conenct to. can be: file, http, or https
* @param checksum checksum lazily checked on entire payload, or empty to disable
* @param authHeaders the authentication headers
* @param credentials the credentials
* @param type extension, e.g. "tar.gz" to force on downloaded filename, or empty to not do this
* @return an {@link InputStream} of response payload
* @throws IOException if all mirrors are down and contains suppressed exception of each attempt
* @throws InterruptedIOException if current thread is being cast into oblivion
* @throws IllegalArgumentException if {@code urls} is empty or has an unsupported protocol
*/
public HttpStream connect(
URL url,
Optional<Checksum> checksum,
Map<URI, Map<String, List<String>>> authHeaders,
Optional<String> type)
URL url, Optional<Checksum> checksum, Credentials credentials, Optional<String> type)
throws IOException {
Preconditions.checkArgument(HttpUtils.isUrlSupportedByDownloader(url));
if (Thread.interrupted()) {
throw new InterruptedIOException();
}
Function<URL, ImmutableMap<String, List<String>>> headerFunction =
getHeaderFunction(REQUEST_HEADERS, authHeaders);
getHeaderFunction(REQUEST_HEADERS, credentials);
URLConnection connection = connector.connect(url, headerFunction);
return httpStreamFactory.create(
connection,
Expand All @@ -127,21 +125,20 @@ public HttpStream connect(

@VisibleForTesting
static Function<URL, ImmutableMap<String, List<String>>> getHeaderFunction(
Map<String, List<String>> baseHeaders,
Map<URI, Map<String, List<String>>> additionalHeaders) {
Map<String, List<String>> baseHeaders, Credentials credentials) {
Preconditions.checkNotNull(baseHeaders);
Preconditions.checkNotNull(credentials);

return url -> {
ImmutableMap<String, List<String>> headers = ImmutableMap.copyOf(baseHeaders);
Map<String, List<String>> headers = new HashMap<>(baseHeaders);
try {
if (additionalHeaders.containsKey(url.toURI())) {
Map<String, List<String>> newHeaders = new HashMap<>(headers);
newHeaders.putAll(additionalHeaders.get(url.toURI()));
headers = ImmutableMap.copyOf(newHeaders);
}
} catch (URISyntaxException e) {
// If we can't convert the URL to a URI (because it is syntactically malformed), still try
// to do the connection, not adding authentication information as we cannot look it up.
headers.putAll(credentials.getRequestMetadata(url.toURI()));
} catch (URISyntaxException | IOException e) {
// If we can't convert the URL to a URI (because it is syntactically malformed), or fetching
// credentials fails for any other reason, still try to do the connection, not adding
// authentication information as we cannot look it up.
}
return headers;
return ImmutableMap.copyOf(headers);
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

package com.google.devtools.build.lib.bazel.repository.downloader;

import com.google.auth.Credentials;
import com.google.common.base.Optional;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
Expand All @@ -31,7 +32,6 @@
import java.io.InterruptedIOException;
import java.io.OutputStream;
import java.net.SocketTimeoutException;
import java.net.URI;
import java.net.URL;
import java.util.ArrayList;
import java.util.List;
Expand Down Expand Up @@ -63,7 +63,7 @@ public void setTimeoutScaling(float timeoutScaling) {
@Override
public void download(
List<URL> urls,
Map<URI, Map<String, List<String>>> authHeaders,
Credentials credentials,
Optional<Checksum> checksum,
String canonicalId,
Path destination,
Expand All @@ -82,7 +82,7 @@ public void download(
for (URL url : urls) {
SEMAPHORE.acquire();

try (HttpStream payload = multiplexer.connect(url, checksum, authHeaders, type);
try (HttpStream payload = multiplexer.connect(url, checksum, credentials, type);
OutputStream out = destination.getOutputStream()) {
try {
ByteStreams.copy(payload, out);
Expand Down Expand Up @@ -132,7 +132,7 @@ public void download(
/** Downloads the contents of one URL and reads it into a byte array. */
public byte[] downloadAndReadOneUrl(
URL url,
Map<URI, Map<String, List<String>>> authHeaders,
Credentials credentials,
ExtendedEventHandler eventHandler,
Map<String, String> clientEnv)
throws IOException, InterruptedException {
Expand All @@ -141,7 +141,7 @@ public byte[] downloadAndReadOneUrl(
ByteArrayOutputStream out = new ByteArrayOutputStream();
SEMAPHORE.acquire();
try (HttpStream payload =
multiplexer.connect(url, Optional.absent(), authHeaders, Optional.absent())) {
multiplexer.connect(url, Optional.absent(), credentials, Optional.absent())) {
ByteStreams.copy(payload, out);
} catch (SocketTimeoutException e) {
// SocketTimeoutExceptions are InterruptedIOExceptions; however they do not signify
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ java_library(
"//src/main/java/com/google/devtools/build/lib/remote/options",
"//src/main/java/com/google/devtools/build/lib/remote/util",
"//src/main/java/com/google/devtools/build/lib/vfs",
"//third_party:auth",
"//third_party:guava",
"//third_party:jsr305",
"//third_party/grpc-java:grpc-jar",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import build.bazel.remote.asset.v1.Qualifier;
import build.bazel.remote.execution.v2.Digest;
import build.bazel.remote.execution.v2.RequestMetadata;
import com.google.auth.Credentials;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Strings;
import com.google.devtools.build.lib.bazel.repository.downloader.Checksum;
Expand All @@ -41,7 +42,6 @@
import io.grpc.StatusRuntimeException;
import java.io.IOException;
import java.io.OutputStream;
import java.net.URI;
import java.net.URL;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -110,7 +110,7 @@ public void close() {
@Override
public void download(
List<URL> urls,
Map<URI, Map<String, List<String>>> authHeaders,
Credentials credentials,
com.google.common.base.Optional<Checksum> checksum,
String canonicalId,
Path destination,
Expand Down Expand Up @@ -154,7 +154,7 @@ public void download(
eventHandler.handle(
Event.warn("Remote Cache: " + Utils.grpcAwareErrorMessage(e, verboseFailures)));
fallbackDownloader.download(
urls, authHeaders, checksum, canonicalId, destination, eventHandler, clientEnv, type);
urls, credentials, checksum, canonicalId, destination, eventHandler, clientEnv, type);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import com.google.common.base.Optional;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.devtools.build.lib.authandtls.StaticCredentials;
import com.google.devtools.build.lib.bazel.repository.cache.RepositoryCache.KeyType;
import com.google.devtools.build.lib.bazel.repository.downloader.RetryingInputStream.Reconnector;
import com.google.devtools.build.lib.events.EventHandler;
Expand Down Expand Up @@ -163,7 +164,8 @@ public void testHeaderComputationFunction() throws Exception {
ImmutableMap.of("Authentication", ImmutableList.of("Zm9vOmZvb3NlY3JldA==")));

Function<URL, ImmutableMap<String, List<String>>> headerFunction =
HttpConnectorMultiplexer.getHeaderFunction(baseHeaders, additionalHeaders);
HttpConnectorMultiplexer.getHeaderFunction(
baseHeaders, new StaticCredentials(additionalHeaders));

// Unreleated URL
assertThat(headerFunction.apply(new URL("http://example.org/some/path/file.txt")))
Expand Down Expand Up @@ -215,7 +217,8 @@ public void testHeaderComputationFunction() throws Exception {
ImmutableMap<String, List<String>> annonAuth =
ImmutableMap.of("Authentication", ImmutableList.of("YW5vbnltb3VzOmZvb0BleGFtcGxlLm9yZw=="));
Function<URL, ImmutableMap<String, List<String>>> combinedHeaders =
HttpConnectorMultiplexer.getHeaderFunction(annonAuth, additionalHeaders);
HttpConnectorMultiplexer.getHeaderFunction(
annonAuth, new StaticCredentials(additionalHeaders));
assertThat(combinedHeaders.apply(new URL("http://hosting.example.com/user/foo/file.txt")))
.containsExactly("Authentication", ImmutableList.of("Zm9vOmZvb3NlY3JldA=="));
assertThat(combinedHeaders.apply(new URL("http://unreleated.example.org/user/foo/file.txt")))
Expand Down
Loading

0 comments on commit 38c5019

Please sign in to comment.