diff --git a/google/default.go b/google/default.go index 4b55b3f5a..df958359a 100644 --- a/google/default.go +++ b/google/default.go @@ -42,6 +42,17 @@ type Credentials struct { // running on Google Cloud Platform. JSON []byte + // UniverseDomainProvider returns the default service domain for a given + // Cloud universe. Optional. + // + // On GCE, UniverseDomainProvider should return the universe domain value + // from Google Compute Engine (GCE)'s metadata server. See also [The attached service + // account](https://cloud.google.com/docs/authentication/application-default-credentials#attached-sa). + // If the GCE metadata server returns a 404 error, the default universe + // domain value should be returned. If the GCE metadata server returns an + // error other than 404, the error should be returned. + UniverseDomainProvider func() (string, error) + udMu sync.Mutex // guards universeDomain // universeDomain is the default service domain for a given Cloud universe. universeDomain string @@ -64,54 +75,32 @@ func (c *Credentials) UniverseDomain() string { } // GetUniverseDomain returns the default service domain for a given Cloud -// universe. +// universe. If present, UniverseDomainProvider will be invoked and its return +// value will be cached. // // The default value is "googleapis.com". -// -// It obtains the universe domain from the attached service account on GCE when -// authenticating via the GCE metadata server. See also [The attached service -// account](https://cloud.google.com/docs/authentication/application-default-credentials#attached-sa). -// If the GCE metadata server returns a 404 error, the default value is -// returned. If the GCE metadata server returns an error other than 404, the -// error is returned. func (c *Credentials) GetUniverseDomain() (string, error) { c.udMu.Lock() defer c.udMu.Unlock() - if c.universeDomain == "" && metadata.OnGCE() { - // If we're on Google Compute Engine, an App Engine standard second - // generation runtime, or App Engine flexible, use the metadata server. - err := c.computeUniverseDomain() + if c.universeDomain == "" && c.UniverseDomainProvider != nil { + // On Google Compute Engine, an App Engine standard second generation + // runtime, or App Engine flexible, use an externally provided function + // to request the universe domain from the metadata server. + ud, err := c.UniverseDomainProvider() if err != nil { return "", err } + c.universeDomain = ud } - // If not on Google Compute Engine, or in case of any non-error path in - // computeUniverseDomain that did not set universeDomain, set the default - // universe domain. + // If no UniverseDomainProvider (meaning not on Google Compute Engine), or + // in case of any (non-error) empty return value from + // UniverseDomainProvider, set the default universe domain. if c.universeDomain == "" { c.universeDomain = defaultUniverseDomain } return c.universeDomain, nil } -// computeUniverseDomain fetches the default service domain for a given Cloud -// universe from Google Compute Engine (GCE)'s metadata server. It's only valid -// to use this method if your program is running on a GCE instance. -func (c *Credentials) computeUniverseDomain() error { - var err error - c.universeDomain, err = metadata.Get("universe/universe_domain") - if err != nil { - if _, ok := err.(metadata.NotDefinedError); ok { - // http.StatusNotFound (404) - c.universeDomain = defaultUniverseDomain - return nil - } else { - return err - } - } - return nil -} - // DefaultCredentials is the old name of Credentials. // // Deprecated: use Credentials instead. @@ -226,10 +215,23 @@ func FindDefaultCredentialsWithParams(ctx context.Context, params CredentialsPar // or App Engine flexible, use the metadata server. if metadata.OnGCE() { id, _ := metadata.ProjectID() + universeDomainProvider := func() (string, error) { + universeDomain, err := metadata.Get("universe/universe_domain") + if err != nil { + if _, ok := err.(metadata.NotDefinedError); ok { + // http.StatusNotFound (404) + return defaultUniverseDomain, nil + } else { + return "", err + } + } + return universeDomain, nil + } return &Credentials{ - ProjectID: id, - TokenSource: computeTokenSource("", params.EarlyTokenRefresh, params.Scopes...), - universeDomain: params.UniverseDomain, + ProjectID: id, + TokenSource: computeTokenSource("", params.EarlyTokenRefresh, params.Scopes...), + UniverseDomainProvider: universeDomainProvider, + universeDomain: params.UniverseDomain, }, nil } diff --git a/google/default_test.go b/google/default_test.go index 7352ffcce..c8465e94f 100644 --- a/google/default_test.go +++ b/google/default_test.go @@ -10,6 +10,8 @@ import ( "net/http/httptest" "strings" "testing" + + "cloud.google.com/go/compute/metadata" ) var saJSONJWT = []byte(`{ @@ -255,9 +257,14 @@ func TestCredentialsFromJSONWithParams_User_UniverseDomain_Params_UniverseDomain func TestComputeUniverseDomain(t *testing.T) { universeDomainPath := "/computeMetadata/v1/universe/universe_domain" universeDomainResponseBody := "example.com" + var requests int s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requests++ if r.URL.Path != universeDomainPath { - t.Errorf("got %s, want %s", r.URL.Path, universeDomainPath) + t.Errorf("bad path, got %s, want %s", r.URL.Path, universeDomainPath) + } + if requests > 1 { + t.Errorf("too many requests, got %d, want 1", requests) } w.Write([]byte(universeDomainResponseBody)) })) @@ -268,11 +275,19 @@ func TestComputeUniverseDomain(t *testing.T) { params := CredentialsParams{ Scopes: []string{scope}, } + universeDomainProvider := func() (string, error) { + universeDomain, err := metadata.Get("universe/universe_domain") + if err != nil { + return "", err + } + return universeDomain, nil + } // Copied from FindDefaultCredentialsWithParams, metadata.OnGCE() = true block creds := &Credentials{ - ProjectID: "fake_project", - TokenSource: computeTokenSource("", params.EarlyTokenRefresh, params.Scopes...), - universeDomain: params.UniverseDomain, // empty + ProjectID: "fake_project", + TokenSource: computeTokenSource("", params.EarlyTokenRefresh, params.Scopes...), + UniverseDomainProvider: universeDomainProvider, + universeDomain: params.UniverseDomain, // empty } c := make(chan bool) go func() { @@ -285,7 +300,7 @@ func TestComputeUniverseDomain(t *testing.T) { } c <- true }() - got, err := creds.GetUniverseDomain() // Second conflicting access. + got, err := creds.GetUniverseDomain() // Second conflicting (and potentially uncached) access. <-c if err != nil { t.Error(err)