Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Authentication Support for Java & Go SDKs #971

Merged
merged 14 commits into from
Sep 4, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,18 +28,25 @@
import java.util.Map;
import java.util.concurrent.Executor;

/*
* Google auth provider's callCredentials Implementation for serving.
* Used by CoreSpecService to connect to core.
/**
* GoogleAuthCredentials provides a Google OIDC ID token for making authenticated gRPC calls. Uses
* <a href="https://cloud.google.com/docs/authentication/getting-started">Google Application
* Default</a> credentials to obtain the OIDC token used for authentication. The given token will be
* passed as authorization bearer token when making calls.
*/
public class GoogleAuthCredentials extends CallCredentials {
private final IdTokenCredentials credentials;
private static final String BEARER_TYPE = "Bearer";
private static final Metadata.Key<String> AUTHORIZATION_METADATA_KEY =
Metadata.Key.of("Authorization", ASCII_STRING_MARSHALLER);

/**
* Construct a new GoogleAuthCredentials with given options.
*
* @param options a map of options, Required unless specified: audience - Optional, Sets the
* target audience of the token obtained.
*/
public GoogleAuthCredentials(Map<String, String> options) throws IOException {

String targetAudience = options.getOrDefault("audience", "https://localhost");
ServiceAccountCredentials serviceCreds =
(ServiceAccountCredentials)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,16 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package feast.core.auth.infra;
package feast.common.auth.credentials;

import io.grpc.CallCredentials;
import io.grpc.Metadata;
import java.util.concurrent.Executor;

/**
* JWTCallCredentials provides/attaches a static JWT token for making authenticated gRPC calls. The
* given token will be passed as authorization bearer token when making calls.
*/
public final class JwtCallCredentials extends CallCredentials {

private String jwt;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,10 @@
import okhttp3.Response;
import org.springframework.security.oauth2.jwt.NimbusJwtDecoder;

/*
* Oauth Credentials Implementation for serving.
* Used by CoreSpecService to connect to core.
/**
* OAuthCredentials uses a OAuth OIDC ID token making authenticated gRPC calls. Makes an OAuth
* request to obtain the OIDC token used for authentication. The given token will be passed as
* authorization bearer token when making calls.
*/
public class OAuthCredentials extends CallCredentials {

Expand All @@ -58,6 +59,15 @@ public class OAuthCredentials extends CallCredentials {
private Instant tokenExpiryTime;
private NimbusJwtDecoder jwtDecoder;

/**
* Constructs a new OAuthCredentials with given options.
*
* @param options a map of options, Required unless specified: grant_type - OAuth grant type.
* Should be set as client_credentials audience - Sets the target audience of the token
* obtained. client_id - Client id to use in the OAuth request. client_secret - Client securet
* to use in the OAuth request. jwtEndpointURI - HTTPS URL used to retrieve a JWK that can be
* used to decode the credential.
*/
public OAuthCredentials(Map<String, String> options) {
this.httpClient = new OkHttpClient();
if (!(options.containsKey(GRANT_TYPE)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,10 +115,15 @@ public void onMessage(ReqT message) {
private String getIdentity(Authentication authentication) {
// use subject claim as identity if set in security authorization properties
if (securityProperties != null) {
Map<String, String> options = securityProperties.getAuthorization().getOptions();
Map<String, String> options = securityProperties.getAuthentication().getOptions();
if (options.containsKey(AuthenticationProperties.SUBJECT_CLAIM)) {
return AuthUtils.getSubjectFromAuth(
authentication, options.get(AuthenticationProperties.SUBJECT_CLAIM));
try {
return AuthUtils.getSubjectFromAuth(
authentication, options.get(AuthenticationProperties.SUBJECT_CLAIM));
} catch (IllegalStateException e) {
// could not extract claim, revert to authenticated name.
return authentication.getName();
}
}
}
return authentication.getName();
Expand Down
1 change: 1 addition & 0 deletions core/src/test/java/feast/core/auth/infra/JwtHelper.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import com.nimbusds.jose.jwk.gen.RSAKeyGenerator;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.SignedJWT;
import feast.common.auth.credentials.JwtCallCredentials;
import io.grpc.*;
import java.security.interfaces.RSAPublicKey;
import java.time.Instant;
Expand Down
136 changes: 136 additions & 0 deletions sdk/go/auth.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
package feast

import (
"bytes"
"context"
"encoding/json"
"fmt"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
"google.golang.org/api/idtoken"
"io/ioutil"
"net/http"
"net/url"
)

// Credential provides OIDC ID tokens used when authenticating with Feast.
// Implements credentials.PerRPCCredentials
type Credential struct {
tokenSrc oauth2.TokenSource
}

// GetRequestMetadata attaches OIDC token as metadata, refreshing tokens if required.
// This should be called by the GRPC to authenticate each request.
func (provider *Credential) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
token, err := provider.tokenSrc.Token()
if err != nil {
return map[string]string{}, nil
}
return map[string]string{
"Authorization": "Bearer: " + token.AccessToken,
}, nil
}

// Disable requirement of transport security to allow user to configure it explictly instead.
func (provider *Credential) RequireTransportSecurity() bool {
return false
}

// Create a Static Authentication Provider that provides a static token
func NewStaticCredential(token string) *Credential {
return &Credential{tokenSrc: oauth2.StaticTokenSource(
&oauth2.Token{
AccessToken: token,
}),
}
}

func newGoogleCredential(
audience string,
findDefaultCredentials func(ctx context.Context, scopes ...string) (*google.Credentials, error),
makeTokenSource func(ctx context.Context, audience string, opts ...idtoken.ClientOption) (oauth2.TokenSource, error)) (*Credential, error) {
// Refresh a Google Id token
// Attempt to id token from Google Application Default Credentials
ctx := context.Background()
creds, err := findDefaultCredentials(ctx, "openid", "email")
if err != nil {
return nil, err
}
tokenSrc, err := makeTokenSource(ctx, audience, idtoken.WithCredentialsJSON(creds.JSON))
if err != nil {
return nil, err
}
return &Credential{tokenSrc: tokenSrc}, nil
}

// Creates a new Google Credential which obtains credentials from Application Default Credentials
func NewGoogleCredential(audience string) (*Credential, error) {
return newGoogleCredential(audience, google.FindDefaultCredentials, idtoken.NewTokenSource)
}

// Creates a new OAuth credential witch obtains credentials by making a client credentials request to an OAuth endpoint.
// clientId, clientSecret - Client credentials used to authenticate the client when obtaining credentials.
// endpointURL - target URL of the OAuth endpoint to make the OAuth request to.
func NewOAuthCredential(audience string, clientId string, clientSecret string, endpointURL *url.URL) *Credential {
tokenSrc := &oauthTokenSource{
clientId: clientId,
clientSecret: clientSecret,
endpointURL: endpointURL,
audience: audience,
}
return &Credential{tokenSrc: tokenSrc}
}

// Defines a Token Source that obtains tokens via making a OAuth client credentials request.
type oauthTokenSource struct {
clientId string
clientSecret string
endpointURL *url.URL
audience string
token *oauth2.Token
}

// Defines a Oauth cleint credentials request.
type oauthClientCredientialsRequest struct {
GrantType string `json:"grant_type"`
ClientId string `json:"client_id"`
ClientSecret string `json:"client_secret"`
Audience string `json:"audience"`
}

// Obtain or Refresh token from OAuth Token Source.
func (tokenSrc *oauthTokenSource) Token() (*oauth2.Token, error) {
if tokenSrc.token == nil || !tokenSrc.token.Valid() {
// Refresh Oauth Id token by making Oauth client credentials request
req := &oauthClientCredientialsRequest{
GrantType: "client_credentials",
ClientId: tokenSrc.clientId,
ClientSecret: tokenSrc.clientSecret,
Audience: tokenSrc.audience,
}

reqBytes, err := json.Marshal(req)
if err != nil {
return nil, err
}
resp, err := http.Post(tokenSrc.endpointURL.String(),
"application/json", bytes.NewBuffer(reqBytes))
if err != nil {
return nil, err
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("OAuth Endpoint returned unexpected status: %s", resp.Status)
}
respBytes, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, err
}
tokenSrc.token = &oauth2.Token{}
err = json.Unmarshal(respBytes, tokenSrc.token)
if err != nil {
return nil, err
}
}

return tokenSrc.token, nil
}
142 changes: 142 additions & 0 deletions sdk/go/auth_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
package feast

import (
"context"
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"net/url"
"testing"

"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
"google.golang.org/api/idtoken"
)

// Returns a mocked google credential.
func mockGoogleCredential(token string, targetAudience string) (*Credential, error) {
// mock find default credentials implementation.
findDefaultCredentials := func(ctx context.Context, scopes ...string) (*google.Credentials, error) {
if len(scopes) != 2 && scopes[0] != "openid" && scopes[1] != "email" {
return nil, fmt.Errorf("Got bad scopes. Expected 'openid', 'email'")
}

return &google.Credentials{
ProjectID: "project_id",
JSON: []byte("mock key json"),
}, nil
}

// mock id token source implementation.
makeTokenSource := func(ctx context.Context, audience string, opts ...idtoken.ClientOption) (oauth2.TokenSource, error) {
// unable to check opts as ClientOption refrences internal type.
if targetAudience != audience {
return nil, fmt.Errorf("Audience does not match up with target audience")
}

return oauth2.StaticTokenSource(&oauth2.Token{
AccessToken: "google token",
}), nil
}

return newGoogleCredential(targetAudience, findDefaultCredentials, makeTokenSource)
}

// Create a mocked OAuth credential with a backing mocked OAuth server.
func mockOAuthCredential(token string, audience string) (*httptest.Server, *Credential) {
clientId := "id"
clientSecret := "secret"
path := "/oauth"

// Create a mock OAuth server to test Oauth provider.
handlers := http.NewServeMux()
handlers.HandleFunc(path, func(resp http.ResponseWriter, req *http.Request) {
reqBytes, err := ioutil.ReadAll(req.Body)
if err != nil {
resp.WriteHeader(http.StatusBadRequest)
}

oauthReq := oauthClientCredientialsRequest{}
err = json.Unmarshal(reqBytes, &oauthReq)
if err != nil {
resp.WriteHeader(http.StatusBadRequest)
}

if oauthReq.GrantType != "client_credentials" ||
oauthReq.ClientId != clientId ||
oauthReq.ClientSecret != clientSecret ||
oauthReq.Audience != audience {
resp.WriteHeader(http.StatusUnauthorized)
}

_, err = resp.Write([]byte(fmt.Sprintf("{\"access_token\": \"%s\"}", token)))
if err != nil {
resp.WriteHeader(http.StatusInternalServerError)
}
})

srv := httptest.NewServer(handlers)
endpointURL, _ := url.Parse(srv.URL + path)
return srv, NewOAuthCredential(audience, clientId, clientSecret, endpointURL)
}

func TestCredentials(t *testing.T) {
audience := "localhost"
srv, oauthCred := mockOAuthCredential("oauth token", audience)
defer srv.Close()
googleCred, err := mockGoogleCredential("google token", audience)
if err != nil {
t.Errorf("Unexpected error creating mock google credential: %v", err)
}

tt := []struct {
name string
credential *Credential
want string
wantErr bool
err error
}{
{
name: "Valid Static Credential get authentication metadata.",
credential: NewStaticCredential("static token"),
want: "static token",
wantErr: false,
err: nil,
},
{
name: "Valid Google Credential get authentication metadata.",
credential: googleCred,
want: "google token",
wantErr: false,
err: nil,
},
{
name: "Valid OAuth Credential get authentication metadata.",
credential: oauthCred,
want: "oauth token",
wantErr: false,
err: nil,
},
}

for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
ctx := context.Background()
meta, err := tc.credential.GetRequestMetadata(ctx, "feast.serving")
if err != nil {
t.Error(err)
}
authKey := "Authorization"
if _, ok := meta[authKey]; !ok {
t.Errorf("Expected authentication metadata with key: '%s'", authKey)
}

expectedVal := "Bearer: " + tc.want
if meta[authKey] != expectedVal {
t.Errorf("Expected authentication metadata with value: '%s' Got instead: '%s'", expectedVal, meta[authKey])
}
})
}
}
Loading