Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[v16] Workload Identity: JWT SVID OIDC compatability (#47079) #47317

Merged
merged 2 commits into from
Oct 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion lib/auth/machineid/machineidv1/workload_identity_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ import (
"github.com/gravitational/teleport/lib/tlsca"
usagereporter "github.com/gravitational/teleport/lib/usagereporter/teleport"
"github.com/gravitational/teleport/lib/utils"
"github.com/gravitational/teleport/lib/utils/oidc"
)

const (
Expand Down Expand Up @@ -73,6 +74,7 @@ type WorkloadIdentityServiceConfig struct {
type WorkloadIdentityCacher interface {
GetCertAuthority(ctx context.Context, id types.CertAuthID, loadKeys bool) (types.CertAuthority, error)
GetClusterName(opts ...services.MarshalOption) (types.ClusterName, error)
GetProxies() ([]types.Server, error)
}

// KeyStorer is an interface that provides methods to retrieve keys and
Expand Down Expand Up @@ -377,6 +379,7 @@ func (wis *WorkloadIdentityService) signJWTSVID(
ctx context.Context,
authCtx *authz.Context,
clusterName string,
issuer string,
key *jwt.Key,
req *pb.JWTSVIDRequest,
) (res *pb.JWTSVIDResponse, err error) {
Expand Down Expand Up @@ -457,6 +460,7 @@ func (wis *WorkloadIdentityService) signJWTSVID(
SPIFFEID: spiffeID,
TTL: ttl,
JTI: jti,
Issuer: issuer,
})
if err != nil {
return nil, trace.Wrap(err, "signing jwt")
Expand Down Expand Up @@ -507,10 +511,17 @@ func (wis *WorkloadIdentityService) SignJWTSVIDs(
return nil, trace.Wrap(err, "getting JWT key")
}

// Determine the public address of the proxy for inclusion in the JWT as
// the issuer for purposes of OIDC compatibility.
issuer, err := oidc.IssuerForCluster(ctx, wis.cache, "/workload-identity")
if err != nil {
return nil, trace.Wrap(err, "determining issuer")
}

res := &pb.SignJWTSVIDsResponse{}
for i, svidReq := range req.Svids {
svidRes, err := wis.signJWTSVID(
ctx, authCtx, clusterName.GetClusterName(), jwtKey, svidReq,
ctx, authCtx, clusterName.GetClusterName(), issuer, jwtKey, svidReq,
)
if err != nil {
return nil, trace.Wrap(err, "signing svid %d", i)
Expand Down
11 changes: 11 additions & 0 deletions lib/auth/machineid/machineidv1/workload_identity_service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,16 @@ func TestWorkloadIdentityService_SignJWTSVIDs(t *testing.T) {

kid := libjwt.KeyID(jwtSigner.Public().(*rsa.PublicKey))

// Upsert a fake proxy to ensure we have a public address to use for the
// issuer.
proxy, err := types.NewServer("proxy", types.KindProxy, types.ServerSpecV2{
PublicAddrs: []string{"teleport.example.com"},
})
require.NoError(t, err)
err = srv.Auth().UpsertProxy(ctx, proxy)
require.NoError(t, err)
wantIssuer := "https://teleport.example.com/workload-identity"

tests := []struct {
name string
user string
Expand Down Expand Up @@ -336,6 +346,7 @@ func TestWorkloadIdentityService_SignJWTSVIDs(t *testing.T) {
require.Equal(t, wantSPIFFEID, claims.Subject)
require.Equal(t, svid.Jti, claims.ID)
require.Equal(t, "example.com", claims.Audience[0])
require.Equal(t, wantIssuer, claims.Issuer)
require.WithinDuration(t, time.Now().Add(30*time.Minute), claims.Expiry.Time(), 5*time.Second)
require.WithinDuration(t, time.Now(), claims.IssuedAt.Time(), 5*time.Second)
},
Expand Down
2 changes: 1 addition & 1 deletion lib/integrations/awsoidc/idp_thumbprint.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ import (
// Returns the thumbprint of the top intermediate CA that signed the TLS cert used to serve HTTPS requests.
// In case of a self signed certificate, then it returns the thumbprint of the TLS cert itself.
func ThumbprintIdP(ctx context.Context, publicAddress string) (string, error) {
issuer, err := oidc.IssuerFromPublicAddress(publicAddress)
issuer, err := oidc.IssuerFromPublicAddress(publicAddress, "")
if err != nil {
return "", trace.Wrap(err)
}
Expand Down
2 changes: 1 addition & 1 deletion lib/integrations/awsoidc/token_generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ func GenerateAWSOIDCToken(ctx context.Context, cacheClt Cache, keyStoreManager K
}

if issuer == "" {
issuer, err = oidc.IssuerForCluster(ctx, cacheClt)
issuer, err = oidc.IssuerForCluster(ctx, cacheClt, "")
if err != nil {
return "", trace.Wrap(err)
}
Expand Down
2 changes: 1 addition & 1 deletion lib/integrations/azureoidc/token_generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ func GenerateEntraOIDCToken(ctx context.Context, cache Cache, manager KeyStoreMa
return "", trace.Wrap(err)
}

issuer, err := oidc.IssuerForCluster(ctx, cache)
issuer, err := oidc.IssuerForCluster(ctx, cache, "")
if err != nil {
return "", trace.Wrap(err)
}
Expand Down
8 changes: 8 additions & 0 deletions lib/jwt/jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,9 @@ type SignParamsJWTSVID struct {
Audiences []string
// TTL is the time to live for the token.
TTL time.Duration
// Issuer is the value that should be included in the `iss` claim of the
// created token.
Issuer string
}

// SignJWTSVID signs a JWT SVID token.
Expand All @@ -283,6 +286,11 @@ func (k *Key) SignJWTSVID(p SignParamsJWTSVID) (string, error) {
// > noted that JWT-SVID validators are not required to track jti
// > uniqueness.
ID: p.JTI,
// The SPIFFE specification makes no comment on the inclusion of `iss`,
// however, we provide this value so that the issued token can be a
// valid OIDC ID token and used with non-SPIFFE aware systems that do
// understand OIDC.
Issuer: p.Issuer,
}

// > 2.2. Key ID:
Expand Down
18 changes: 14 additions & 4 deletions lib/utils/oidc/issuer.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@ type ProxiesGetter interface {
}

// IssuerForCluster returns the issuer URL using the Cluster state.
func IssuerForCluster(ctx context.Context, clt ProxiesGetter) (string, error) {
// Path is an optional element to append to the issuer to distinguish a
// separate CA within the same cluster.
func IssuerForCluster(ctx context.Context, clt ProxiesGetter, path string) (string, error) {
proxies, err := clt.GetProxies()
if err != nil {
return "", trace.Wrap(err)
Expand All @@ -44,18 +46,22 @@ func IssuerForCluster(ctx context.Context, clt ProxiesGetter) (string, error) {
for _, p := range proxies {
proxyPublicAddress := p.GetPublicAddr()
if proxyPublicAddress != "" {
return IssuerFromPublicAddress(proxyPublicAddress)
return IssuerFromPublicAddress(proxyPublicAddress, path)
}
}

return "", trace.BadParameter("failed to get Proxy Public Address")
}

// IssuerFromPublicAddress is the address for the AWS OIDC Provider.
// IssuerFromPublicAddress is the address for an OIDC Provider.
//
// It must match exactly what was introduced in AWS IAM console when adding the Identity Provider.
// PublicProxyAddr from `teleport.yaml/proxy` does not come with the desired format: it misses the protocol and has a port
// This method adds the `https` protocol and removes the port if it is the default one for https (443)
func IssuerFromPublicAddress(addr string) (string, error) {
//
// Path is an optional element to append to the issuer to distinguish a
// separate CA within the same cluster.
func IssuerFromPublicAddress(addr string, path string) (string, error) {
// Add protocol if not present.
if !strings.HasPrefix(addr, "https://") && !strings.HasPrefix(addr, "http://") {
addr = "https://" + addr
Expand All @@ -70,5 +76,9 @@ func IssuerFromPublicAddress(addr string) (string, error) {
// Cut off redundant :443
result.Host = result.Hostname()
}

if path != "" {
result.Path = path
}
return result.String(), nil
}
34 changes: 32 additions & 2 deletions lib/utils/oidc/issuer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,18 +32,31 @@ func TestIssuerFromPublicAddress(t *testing.T) {
for _, tt := range []struct {
name string
addr string
path string
expected string
}{
{
name: "valid host:port",
addr: "127.0.0.1.nip.io:3080",
expected: "https://127.0.0.1.nip.io:3080",
},
{
name: "valid host:port with path",
addr: "127.0.0.1.nip.io:3080",
path: "/workload-identity",
expected: "https://127.0.0.1.nip.io:3080/workload-identity",
},
{
name: "valid ip:port",
addr: "127.0.0.1:3080",
expected: "https://127.0.0.1:3080",
},
{
name: "valid ip:port with path",
addr: "127.0.0.1:3080",
path: "/workload-identity",
expected: "https://127.0.0.1:3080/workload-identity",
},
{
name: "removes 443 port",
addr: "https://teleport-local.example.com:443",
Expand All @@ -54,9 +67,15 @@ func TestIssuerFromPublicAddress(t *testing.T) {
addr: "localhost",
expected: "https://localhost",
},
{
name: "only host with path",
addr: "localhost",
path: "/workload-identity",
expected: "https://localhost/workload-identity",
},
} {
t.Run(tt.name, func(t *testing.T) {
got, err := IssuerFromPublicAddress(tt.addr)
got, err := IssuerFromPublicAddress(tt.addr, tt.path)
require.NoError(t, err)
require.Equal(t, tt.expected, got)
})
Expand All @@ -79,6 +98,7 @@ func TestIssuerForCluster(t *testing.T) {
ctx := context.Background()
for _, tt := range []struct {
name string
path string
mockProxies []types.Server
mockErr error
checkErr require.ErrorAssertionFunc
Expand All @@ -93,6 +113,16 @@ func TestIssuerForCluster(t *testing.T) {
},
expectedIssuer: "https://127.0.0.1.nip.io",
},
{
name: "valid with subpath",
path: "/workload-identity",
mockProxies: []types.Server{
&types.ServerV2{Spec: types.ServerSpecV2{
PublicAddrs: []string{"127.0.0.1.nip.io"},
}},
},
expectedIssuer: "https://127.0.0.1.nip.io/workload-identity",
},
{
name: "only the second server has a valid public address",
mockProxies: []types.Server{
Expand Down Expand Up @@ -121,7 +151,7 @@ func TestIssuerForCluster(t *testing.T) {
proxies: tt.mockProxies,
returnErr: tt.mockErr,
}
issuer, err := IssuerForCluster(ctx, clt)
issuer, err := IssuerForCluster(ctx, clt, tt.path)
if tt.checkErr != nil {
tt.checkErr(t, err)
}
Expand Down
5 changes: 3 additions & 2 deletions lib/utils/oidc/openidconfig.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,15 @@
package oidc

// OpenIDConfiguration is the default OpenID Configuration used by Teleport.
// Based on https://openid.net/specs/openid-connect-discovery-1_0.html
type OpenIDConfiguration struct {
Issuer string `json:"issuer"`
JWKSURI string `json:"jwks_uri"`
Claims []string `json:"claims"`
IdTokenSigningAlgValuesSupported []string `json:"id_token_signing_alg_values_supported"`
ResponseTypesSupported []string `json:"response_types_supported"`
ScopesSupported []string `json:"scopes_supported"`
SubjectTypesSupported []string `json:"subject_types_supported"`
ScopesSupported []string `json:"scopes_supported,omitempty"`
SubjectTypesSupported []string `json:"subject_types_supported,omitempty"`
}

// OpenIDConfigurationForIssuer returns the OpenID Configuration for
Expand Down
4 changes: 3 additions & 1 deletion lib/web/apiserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -997,6 +997,8 @@ func (h *Handler) bindDefaultEndpoints() {

// SPIFFE Federation Trust Bundle
h.GET("/webapi/spiffe/bundle.json", h.WithLimiter(h.getSPIFFEBundle))
h.GET("/workload-identity/jwt-jwks.json", h.WithLimiter(h.getSPIFFEJWKS))
h.GET("/workload-identity/.well-known/openid-configuration", h.WithLimiter(h.getSPIFFEOIDCDiscoveryDocument))

// DiscoveryConfig CRUD
h.GET("/webapi/sites/:site/discoveryconfig", h.WithClusterAuth(h.discoveryconfigList))
Expand Down Expand Up @@ -1887,7 +1889,7 @@ func (h *Handler) getUIConfig(ctx context.Context) webclient.UIConfig {

// jwks returns all public keys used to sign JWT tokens for this cluster.
func (h *Handler) wellKnownJWKS(w http.ResponseWriter, r *http.Request, p httprouter.Params) (interface{}, error) {
return h.jwks(r.Context(), types.JWTSigner)
return h.jwks(r.Context(), types.JWTSigner, true)
}

func (h *Handler) motd(w http.ResponseWriter, r *http.Request, p httprouter.Params) (interface{}, error) {
Expand Down
4 changes: 2 additions & 2 deletions lib/web/integrations_awsoidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -1129,7 +1129,7 @@ func (h *Handler) awsOIDCConfigureIdP(w http.ResponseWriter, r *http.Request, p

switch {
case s3Bucket == "" && s3Prefix == "":
proxyAddr, err := oidc.IssuerFromPublicAddress(h.cfg.PublicProxyAddr)
proxyAddr, err := oidc.IssuerFromPublicAddress(h.cfg.PublicProxyAddr, "")
if err != nil {
return nil, trace.Wrap(err)
}
Expand All @@ -1144,7 +1144,7 @@ func (h *Handler) awsOIDCConfigureIdP(w http.ResponseWriter, r *http.Request, p
}
s3URI := url.URL{Scheme: "s3", Host: s3Bucket, Path: s3Prefix}

jwksContents, err := h.jwks(r.Context(), types.OIDCIdPCA)
jwksContents, err := h.jwks(r.Context(), types.OIDCIdPCA, true)
if err != nil {
return nil, trace.Wrap(err)
}
Expand Down
2 changes: 1 addition & 1 deletion lib/web/integrations_azureoidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func (h *Handler) azureOIDCConfigure(w http.ResponseWriter, r *http.Request, p h
ctx := r.Context()
queryParams := r.URL.Query()

oidcIssuer, err := oidc.IssuerFromPublicAddress(h.cfg.PublicProxyAddr)
oidcIssuer, err := oidc.IssuerFromPublicAddress(h.cfg.PublicProxyAddr, "")
if err != nil {
return nil, trace.Wrap(err)
}
Expand Down
12 changes: 7 additions & 5 deletions lib/web/oidcidp.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ const (

// openidConfiguration returns the openid-configuration for setting up the AWS OIDC Integration
func (h *Handler) openidConfiguration(_ http.ResponseWriter, _ *http.Request, _ httprouter.Params) (interface{}, error) {
issuer, err := oidc.IssuerFromPublicAddress(h.cfg.PublicProxyAddr)
issuer, err := oidc.IssuerFromPublicAddress(h.cfg.PublicProxyAddr, "")
if err != nil {
return nil, trace.Wrap(err)
}
Expand All @@ -48,10 +48,10 @@ func (h *Handler) openidConfiguration(_ http.ResponseWriter, _ *http.Request, _

// jwksOIDC returns all public keys used to sign JWT tokens for this cluster.
func (h *Handler) jwksOIDC(_ http.ResponseWriter, r *http.Request, _ httprouter.Params) (interface{}, error) {
return h.jwks(r.Context(), types.OIDCIdPCA)
return h.jwks(r.Context(), types.OIDCIdPCA, true)
}

func (h *Handler) jwks(ctx context.Context, caType types.CertAuthType) (*JWKSResponse, error) {
func (h *Handler) jwks(ctx context.Context, caType types.CertAuthType, includeBlankKeyID bool) (*JWKSResponse, error) {
clusterName, err := h.GetProxyClient().GetDomainName(ctx)
if err != nil {
return nil, trace.Wrap(err)
Expand Down Expand Up @@ -82,8 +82,10 @@ func (h *Handler) jwks(ctx context.Context, caType types.CertAuthType) (*JWKSRes

// Return an additional copy of the same JWK
// with KeyID set to the empty string for compatibility.
jwk.KeyID = ""
resp.Keys = append(resp.Keys, jwk)
if includeBlankKeyID {
jwk.KeyID = ""
resp.Keys = append(resp.Keys, jwk)
}
}
return &resp, nil
}
Expand Down
Loading
Loading