Skip to content

Commit

Permalink
mfa: add WithMFA to session-related audit events (#5833)
Browse files Browse the repository at this point in the history
* mfa: put device UUID in MFAVerified cert extensions

Instead of just a bool, add the device UUID. This will be used in audit
events when a session is started using MFA-issued certs.

* Add WithMFA to session-related audit events

Also extract a common function for parsing SSH certs.
  • Loading branch information
Andrew Lytvynov authored Mar 4, 2021
1 parent 85c2e88 commit 6d8f778
Show file tree
Hide file tree
Showing 33 changed files with 529 additions and 506 deletions.
684 changes: 363 additions & 321 deletions api/types/events/events.pb.go

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions api/types/events/events.proto
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ message Metadata {
message SessionMetadata {
// SessionID is a unique UUID of the session.
string SessionID = 1 [ (gogoproto.jsontag) = "sid" ];
// WithMFA is a UUID of an MFA device used to start this session.
string WithMFA = 2 [ (gogoproto.jsontag) = "with_mfa,omitempty" ];
}

// UserMetadata is a common user event metadata
Expand Down
20 changes: 10 additions & 10 deletions lib/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -478,16 +478,18 @@ type certRequest struct {
// dbName is the optional database name which, if provided, will be used
// as a default database.
dbName string
// mfaVerified is set when this certRequest was created immediately after
// an MFA check.
mfaVerified bool
// mfaVerified is the UUID of an MFA device when this certRequest was
// created immediately after an MFA check.
mfaVerified string
// clientIP is an IP of the client requesting the certificate.
clientIP string
}

type certRequestOption func(*certRequest)

func certRequestMFAVerified(r *certRequest) { r.mfaVerified = true }
func certRequestMFAVerified(mfaID string) certRequestOption {
return func(r *certRequest) { r.mfaVerified = mfaID }
}
func certRequestClientIP(ip string) certRequestOption {
return func(r *certRequest) { r.clientIP = ip }
}
Expand Down Expand Up @@ -2026,21 +2028,19 @@ func (a *Server) mfaAuthChallenge(ctx context.Context, user string, u2fStorage u
return challenge, nil
}

func (a *Server) validateMFAAuthResponse(ctx context.Context, user string, resp *proto.MFAAuthenticateResponse, u2fStorage u2f.AuthenticationStorage) error {
var err error
func (a *Server) validateMFAAuthResponse(ctx context.Context, user string, resp *proto.MFAAuthenticateResponse, u2fStorage u2f.AuthenticationStorage) (*types.MFADevice, error) {
switch res := resp.Response.(type) {
case *proto.MFAAuthenticateResponse_TOTP:
_, err = a.checkOTP(user, res.TOTP.Code)
return a.checkOTP(user, res.TOTP.Code)
case *proto.MFAAuthenticateResponse_U2F:
_, err = a.checkU2F(ctx, user, u2f.AuthenticateChallengeResponse{
return a.checkU2F(ctx, user, u2f.AuthenticateChallengeResponse{
KeyHandle: res.U2F.KeyHandle,
ClientData: res.U2F.ClientData,
SignatureData: res.U2F.Signature,
}, u2fStorage)
default:
err = trace.BadParameter("unknown or missing MFAAuthenticateResponse type %T", resp.Response)
return nil, trace.BadParameter("unknown or missing MFAAuthenticateResponse type %T", resp.Response)
}
return trace.Wrap(err)
}

func (a *Server) checkU2F(ctx context.Context, user string, res u2f.AuthenticateChallengeResponse, u2fStorage u2f.AuthenticationStorage) (*types.MFADevice, error) {
Expand Down
7 changes: 3 additions & 4 deletions lib/auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ import (
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/services/local"
"github.com/gravitational/teleport/lib/services/suite"
"github.com/gravitational/teleport/lib/sshutils"
"github.com/gravitational/teleport/lib/tlsca"
"github.com/gravitational/teleport/lib/utils"

Expand Down Expand Up @@ -225,9 +226,8 @@ func (s *AuthSuite) TestAuthenticateSSHUser(c *C) {
// Verify the public key and principals in SSH cert.
inSSHPub, _, _, _, err := ssh.ParseAuthorizedKey(pub)
c.Assert(err, IsNil)
gotSSHCertPub, _, _, _, err := ssh.ParseAuthorizedKey(resp.Cert)
gotSSHCert, err := sshutils.ParseCertificate(resp.Cert)
c.Assert(err, IsNil)
gotSSHCert := gotSSHCertPub.(*ssh.Certificate)
c.Assert(gotSSHCert.Key, DeepEquals, inSSHPub)
c.Assert(gotSSHCert.ValidPrincipals, DeepEquals, []string{user})
// Verify the public key and Subject in TLS cert.
Expand Down Expand Up @@ -567,9 +567,8 @@ func (s *AuthSuite) TestTokensCRUD(c *C) {
c.Assert(err, IsNil)

// along the way, make sure that additional principals work
key, _, _, _, err := ssh.ParseAuthorizedKey(keys.Cert)
hostCert, err := sshutils.ParseCertificate(keys.Cert)
c.Assert(err, IsNil)
hostCert := key.(*ssh.Certificate)
comment := Commentf("can't find example.com in %v", hostCert.ValidPrincipals)
c.Assert(utils.SliceContainsStr(hostCert.ValidPrincipals, "example.com"), Equals, true, comment)

Expand Down
2 changes: 1 addition & 1 deletion lib/auth/auth_with_roles.go
Original file line number Diff line number Diff line change
Expand Up @@ -2512,7 +2512,7 @@ func (a *ServerWithRoles) UpsertKubeService(ctx context.Context, s services.Serv
}

for _, kube := range s.GetKubernetesClusters() {
if err := a.context.Checker.CheckAccessToKubernetes(s.GetNamespace(), kube, a.context.Identity.GetIdentity().MFAVerified); err != nil {
if err := a.context.Checker.CheckAccessToKubernetes(s.GetNamespace(), kube, a.context.Identity.GetIdentity().MFAVerified != ""); err != nil {
return utils.OpaqueAccessDenied(err)
}
}
Expand Down
32 changes: 17 additions & 15 deletions lib/auth/grpcserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -1363,7 +1363,7 @@ func addMFADeviceAuthChallenge(gctx *grpcContext, stream proto.AuthService_AddMF
}
// Only validate if there was a challenge.
if authChallenge.TOTP != nil || len(authChallenge.U2F) > 0 {
if err := auth.validateMFAAuthResponse(ctx, user, authResp, u2fStorage); err != nil {
if _, err := auth.validateMFAAuthResponse(ctx, user, authResp, u2fStorage); err != nil {
return trace.Wrap(err)
}
}
Expand Down Expand Up @@ -1581,7 +1581,7 @@ func deleteMFADeviceAuthChallenge(gctx *grpcContext, stream proto.AuthService_De
if authResp == nil {
return trace.BadParameter("expected MFAAuthenticateResponse, got %T", req)
}
if err := auth.validateMFAAuthResponse(ctx, user, authResp, u2fStorage); err != nil {
if _, err := auth.validateMFAAuthResponse(ctx, user, authResp, u2fStorage); err != nil {
return trace.Wrap(err)
}
return nil
Expand Down Expand Up @@ -1649,12 +1649,13 @@ func (g *GRPCServer) GenerateUserSingleUseCerts(stream proto.AuthService_Generat

// 2. send MFAChallenge
// 3. receive and validate MFAResponse
if err := userSingleUseCertsAuthChallenge(actx, stream); err != nil {
mfaDev, err := userSingleUseCertsAuthChallenge(actx, stream)
if err != nil {
return trail.ToGRPC(err)
}

// Generate the cert.
respCert, err := userSingleUseCertsGenerate(stream.Context(), actx, *initReq)
respCert, err := userSingleUseCertsGenerate(stream.Context(), actx, *initReq, mfaDev)
if err != nil {
return trail.ToGRPC(err)
}
Expand Down Expand Up @@ -1694,40 +1695,41 @@ func validateUserSingleUseCertRequest(req *proto.UserCertsRequest, clock clockwo
return nil
}

func userSingleUseCertsAuthChallenge(gctx *grpcContext, stream proto.AuthService_GenerateUserSingleUseCertsServer) error {
func userSingleUseCertsAuthChallenge(gctx *grpcContext, stream proto.AuthService_GenerateUserSingleUseCertsServer) (*types.MFADevice, error) {
ctx := stream.Context()
auth := gctx.authServer
user := gctx.User.GetName()
u2fStorage, err := u2f.InMemoryAuthenticationStorage(auth.Identity)
if err != nil {
return trace.Wrap(err)
return nil, trace.Wrap(err)
}

authChallenge, err := auth.mfaAuthChallenge(ctx, user, u2fStorage)
if err != nil {
return trace.Wrap(err)
return nil, trace.Wrap(err)
}
if err := stream.Send(&proto.UserSingleUseCertsResponse{
Response: &proto.UserSingleUseCertsResponse_MFAChallenge{MFAChallenge: authChallenge},
}); err != nil {
return trace.Wrap(err)
return nil, trace.Wrap(err)
}

req, err := stream.Recv()
if err != nil {
return trace.Wrap(err)
return nil, trace.Wrap(err)
}
authResp := req.GetMFAResponse()
if authResp == nil {
return trace.BadParameter("expected MFAAuthenticateResponse, got %T", req.Request)
return nil, trace.BadParameter("expected MFAAuthenticateResponse, got %T", req.Request)
}
if err := auth.validateMFAAuthResponse(ctx, user, authResp, u2fStorage); err != nil {
return trace.Wrap(err)
mfaDev, err := auth.validateMFAAuthResponse(ctx, user, authResp, u2fStorage)
if err != nil {
return nil, trace.Wrap(err)
}
return nil
return mfaDev, nil
}

func userSingleUseCertsGenerate(ctx context.Context, actx *grpcContext, req proto.UserCertsRequest) (*proto.SingleUseUserCert, error) {
func userSingleUseCertsGenerate(ctx context.Context, actx *grpcContext, req proto.UserCertsRequest, mfaDev *types.MFADevice) (*proto.SingleUseUserCert, error) {
// Get the client IP.
clientPeer, ok := peer.FromContext(ctx)
if !ok {
Expand All @@ -1739,7 +1741,7 @@ func userSingleUseCertsGenerate(ctx context.Context, actx *grpcContext, req prot
}

// Generate the cert.
certs, err := actx.generateUserCerts(ctx, req, certRequestMFAVerified, certRequestClientIP(clientIP))
certs, err := actx.generateUserCerts(ctx, req, certRequestMFAVerified(mfaDev.Id), certRequestClientIP(clientIP))
if err != nil {
return nil, trace.Wrap(err)
}
Expand Down
24 changes: 13 additions & 11 deletions lib/auth/grpcserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ import (
"github.com/pquerna/otp"
"github.com/pquerna/otp/totp"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/ssh"

"github.com/gravitational/teleport"
"github.com/gravitational/teleport/api/client/proto"
Expand All @@ -40,6 +39,7 @@ import (
"github.com/gravitational/teleport/lib/auth/mocku2f"
"github.com/gravitational/teleport/lib/auth/u2f"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/sshutils"
"github.com/gravitational/teleport/lib/tlsca"
)

Expand Down Expand Up @@ -619,6 +619,12 @@ func TestGenerateUserSingleUseCert(t *testing.T) {
return wantDev
},
})
// Fetch MFA device ID.
devs, err := srv.Auth().GetMFADevices(ctx, user.GetName())
require.NoError(t, err)
require.Len(t, devs, 1)
u2fDevID := devs[0].Id

u2fChallengeHandler := func(t *testing.T, req *proto.MFAAuthenticateChallenge) *proto.MFAAuthenticateResponse {
require.Len(t, req.U2F, 1)
chal := req.U2F[0]
Expand Down Expand Up @@ -660,12 +666,10 @@ func TestGenerateUserSingleUseCert(t *testing.T) {
crt := c.GetSSH()
require.NotEmpty(t, crt)

key, _, _, _, err := ssh.ParseAuthorizedKey(crt)
cert, err := sshutils.ParseCertificate(crt)
require.NoError(t, err)
cert, ok := key.(*ssh.Certificate)
require.True(t, ok)

require.Contains(t, cert.Extensions, teleport.CertExtensionMFAVerified)
require.Equal(t, cert.Extensions[teleport.CertExtensionMFAVerified], u2fDevID)
require.True(t, net.ParseIP(cert.Extensions[teleport.CertExtensionClientIP]).IsLoopback())
require.Equal(t, cert.ValidBefore, uint64(clock.Now().Add(teleport.UserSingleUseCertTTL).Unix()))
},
Expand Down Expand Up @@ -694,7 +698,7 @@ func TestGenerateUserSingleUseCert(t *testing.T) {

identity, err := tlsca.FromSubject(cert.Subject, cert.NotAfter)
require.NoError(t, err)
require.True(t, identity.MFAVerified)
require.Equal(t, identity.MFAVerified, u2fDevID)
require.True(t, net.ParseIP(identity.ClientIP).IsLoopback())
require.Equal(t, identity.Usage, []string{teleport.UsageKubeOnly})
require.Equal(t, identity.KubernetesCluster, "kube-a")
Expand Down Expand Up @@ -726,7 +730,7 @@ func TestGenerateUserSingleUseCert(t *testing.T) {

identity, err := tlsca.FromSubject(cert.Subject, cert.NotAfter)
require.NoError(t, err)
require.True(t, identity.MFAVerified)
require.Equal(t, identity.MFAVerified, u2fDevID)
require.True(t, net.ParseIP(identity.ClientIP).IsLoopback())
require.Equal(t, identity.Usage, []string{teleport.UsageDatabaseOnly})
require.Equal(t, identity.RouteToDatabase.ServiceName, "db-a")
Expand Down Expand Up @@ -765,12 +769,10 @@ func TestGenerateUserSingleUseCert(t *testing.T) {
crt := c.GetSSH()
require.NotEmpty(t, crt)

key, _, _, _, err := ssh.ParseAuthorizedKey(crt)
cert, err := sshutils.ParseCertificate(crt)
require.NoError(t, err)
cert, ok := key.(*ssh.Certificate)
require.True(t, ok)

require.Contains(t, cert.Extensions, teleport.CertExtensionMFAVerified)
require.Equal(t, cert.Extensions[teleport.CertExtensionMFAVerified], u2fDevID)
require.True(t, net.ParseIP(cert.Extensions[teleport.CertExtensionClientIP]).IsLoopback())
require.Equal(t, cert.ValidBefore, uint64(clock.Now().Add(teleport.UserSingleUseCertTTL).Unix()))
},
Expand Down
7 changes: 1 addition & 6 deletions lib/auth/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -1016,16 +1016,11 @@ func ReadSSHIdentityFromKeyPair(keyBytes, certBytes []byte) (*Identity, error) {
return nil, trace.BadParameter("Cert: missing parameter")
}

pubKey, _, _, _, err := ssh.ParseAuthorizedKey(certBytes)
cert, err := sshutils.ParseCertificate(certBytes)
if err != nil {
return nil, trace.BadParameter("failed to parse server certificate: %v", err)
}

cert, ok := pubKey.(*ssh.Certificate)
if !ok {
return nil, trace.BadParameter("expected ssh.Certificate, got %v", pubKey)
}

signer, err := ssh.ParsePrivateKey(keyBytes)
if err != nil {
return nil, trace.BadParameter("failed to parse private key: %v", err)
Expand Down
4 changes: 1 addition & 3 deletions lib/auth/init_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,10 +90,8 @@ func TestReadIdentity(t *testing.T) {
TTL: ttl,
})
require.NoError(t, err)
pk, _, _, _, err := ssh.ParseAuthorizedKey(bytes)
copy, err := sshutils.ParseCertificate(bytes)
require.NoError(t, err)
copy, ok := pk.(*ssh.Certificate)
require.True(t, ok)
require.Equal(t, uint64(expiryDate.Unix()), copy.ValidBefore)
}

Expand Down
4 changes: 2 additions & 2 deletions lib/auth/native/native.go
Original file line number Diff line number Diff line change
Expand Up @@ -277,8 +277,8 @@ func (k *Keygen) GenerateUserCertWithoutValidation(c services.UserCertParams) ([
if c.PermitPortForwarding {
cert.Permissions.Extensions[teleport.CertExtensionPermitPortForwarding] = ""
}
if c.MFAVerified {
cert.Permissions.Extensions[teleport.CertExtensionMFAVerified] = ""
if c.MFAVerified != "" {
cert.Permissions.Extensions[teleport.CertExtensionMFAVerified] = c.MFAVerified
}
if c.ClientIP != "" {
cert.Permissions.Extensions[teleport.CertExtensionClientIP] = c.ClientIP
Expand Down
15 changes: 4 additions & 11 deletions lib/auth/native/native_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,11 @@ import (
"testing"
"time"

"golang.org/x/crypto/ssh"

"github.com/gravitational/teleport"
"github.com/gravitational/teleport/lib/auth/test"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/sshutils"
"github.com/gravitational/teleport/lib/utils"

"github.com/jonboulle/clockwork"
Expand Down Expand Up @@ -182,12 +181,9 @@ func (s *NativeSuite) TestBuildPrincipals(c *check.C) {
})
c.Assert(err, check.IsNil)

publicKey, _, _, _, err := ssh.ParseAuthorizedKey(hostCertificateBytes)
hostCertificate, err := sshutils.ParseCertificate(hostCertificateBytes)
c.Assert(err, check.IsNil)

hostCertificate, ok := publicKey.(*ssh.Certificate)
c.Assert(ok, check.Equals, true)

c.Assert(hostCertificate.ValidPrincipals, check.DeepEquals, tt.outValidPrincipals)
}
}
Expand Down Expand Up @@ -232,15 +228,12 @@ func (s *NativeSuite) TestUserCertCompatibility(c *check.C) {
})
c.Assert(err, check.IsNil, comment)

publicKey, _, _, _, err := ssh.ParseAuthorizedKey(userCertificateBytes)
userCertificate, err := sshutils.ParseCertificate(userCertificateBytes)
c.Assert(err, check.IsNil, comment)

userCertificate, ok := publicKey.(*ssh.Certificate)
c.Assert(ok, check.Equals, true, comment)
// Check that the signature algorithm is correct.
c.Assert(userCertificate.Signature.Format, check.Equals, defaults.CASignatureAlgorithm)
// check if we added the roles extension
_, ok = userCertificate.Extensions[teleport.CertExtensionTeleportRoles]
_, ok := userCertificate.Extensions[teleport.CertExtensionTeleportRoles]
c.Assert(ok, check.Equals, tt.outHasRoles, comment)
}
}
Loading

0 comments on commit 6d8f778

Please sign in to comment.