Skip to content

Commit

Permalink
add host fields to sshca.Identity
Browse files Browse the repository at this point in the history
  • Loading branch information
fspmarshall committed Jan 24, 2025
1 parent 4ffaa89 commit 09148bb
Show file tree
Hide file tree
Showing 20 changed files with 331 additions and 244 deletions.
9 changes: 6 additions & 3 deletions integration/helpers/instance.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ import (
"github.com/gravitational/teleport/lib/service"
"github.com/gravitational/teleport/lib/service/servicecfg"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/sshca"
"github.com/gravitational/teleport/lib/sshutils"
"github.com/gravitational/teleport/lib/tlsca"
"github.com/gravitational/teleport/lib/utils"
Expand Down Expand Up @@ -373,14 +374,16 @@ func NewInstance(t *testing.T, cfg InstanceConfig) *TeleInstance {
fatalIf(err)

keygen := keygen.New(context.TODO())
cert, err := keygen.GenerateHostCert(services.HostCertParams{
cert, err := keygen.GenerateHostCert(sshca.HostCertificateRequest{
CASigner: sshSigner,
PublicHostKey: cfg.Pub,
HostID: cfg.HostID,
NodeName: cfg.NodeName,
ClusterName: cfg.ClusterName,
Role: types.RoleAdmin,
TTL: 24 * time.Hour,
Identity: sshca.Identity{
ClusterName: cfg.ClusterName,
SystemRole: types.RoleAdmin,
},
})
fatalIf(err)
tlsCA, err := tlsca.FromKeys(tlsCACert, cfg.Priv)
Expand Down
45 changes: 22 additions & 23 deletions lib/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -2130,28 +2130,30 @@ func (a *Server) GenerateHostCert(ctx context.Context, hostPublicKey []byte, hos
}

// create and sign!
return a.generateHostCert(ctx, services.HostCertParams{
return a.generateHostCert(ctx, sshca.HostCertificateRequest{
CASigner: caSigner,
PublicHostKey: hostPublicKey,
HostID: hostID,
NodeName: nodeName,
Principals: principals,
ClusterName: clusterName,
Role: role,
TTL: ttl,
Identity: sshca.Identity{
Principals: principals,
ClusterName: clusterName,
SystemRole: role,
},
})
}

func (a *Server) generateHostCert(
ctx context.Context, p services.HostCertParams,
ctx context.Context, req sshca.HostCertificateRequest,
) ([]byte, error) {
readOnlyAuthPref, err := a.GetReadOnlyAuthPreference(ctx)
if err != nil {
return nil, trace.Wrap(err)
}

var locks []types.LockTarget
switch p.Role {
switch req.Identity.SystemRole {
case types.RoleNode:
// Node role is a special case because it was previously suported as a
// lock target that only locked the `ssh_service`. If the same Teleport server
Expand All @@ -2164,17 +2166,17 @@ func (a *Server) generateHostCert(
// and `Node` fields if the role is `Node` so that the previous behavior
// is preserved.
// This is a legacy behavior that we need to support for backwards compatibility.
locks = []types.LockTarget{{ServerID: p.HostID, Node: p.HostID}, {ServerID: HostFQDN(p.HostID, p.ClusterName), Node: HostFQDN(p.HostID, p.ClusterName)}}
locks = []types.LockTarget{{ServerID: req.HostID, Node: req.HostID}, {ServerID: HostFQDN(req.HostID, req.Identity.ClusterName), Node: HostFQDN(req.HostID, req.Identity.ClusterName)}}
default:
locks = []types.LockTarget{{ServerID: p.HostID}, {ServerID: HostFQDN(p.HostID, p.ClusterName)}}
locks = []types.LockTarget{{ServerID: req.HostID}, {ServerID: HostFQDN(req.HostID, req.Identity.ClusterName)}}
}
if lockErr := a.checkLockInForce(readOnlyAuthPref.GetLockingMode(),
locks,
); lockErr != nil {
return nil, trace.Wrap(lockErr)
}

return a.Authority.GenerateHostCert(p)
return a.Authority.GenerateHostCert(req)
}

// GetKeyStore returns the KeyStore used by the auth server
Expand Down Expand Up @@ -2226,7 +2228,7 @@ type certRequest struct {
traits wrappers.Traits
// activeRequests tracks privilege escalation requests applied
// during the construction of the certificate.
activeRequests services.RequestIDs
activeRequests []string
// appSessionID is the session ID of the application session.
appSessionID string
// appPublicAddr is the public address of the application.
Expand Down Expand Up @@ -3081,7 +3083,7 @@ func generateCert(ctx context.Context, a *Server, req certRequest, caType types.
defaultMode: readOnlyAuthPref.GetLockingMode(),
username: req.user.GetName(),
mfaVerified: req.mfaVerified,
activeAccessRequests: req.activeRequests.AccessRequests,
activeAccessRequests: req.activeRequests,
deviceID: req.deviceExtensions.DeviceID,
}); err != nil {
return nil, trace.Wrap(err)
Expand Down Expand Up @@ -3210,11 +3212,6 @@ func generateCert(ctx context.Context, a *Server, req certRequest, caType types.
// All users have access to this and join RBAC rules are checked after the connection is established.
allowedLogins = append(allowedLogins, teleport.SSHSessionJoinPrincipal)

requestedResourcesStr, err := types.ResourceIDsToString(req.checker.GetAllowedResourceIDs())
if err != nil {
return nil, trace.Wrap(err)
}

pinnedIP := ""
if caType == types.UserCA && (req.checker.PinSourceIP() || req.pinIP) {
if req.loginIP == "" {
Expand Down Expand Up @@ -3254,7 +3251,7 @@ func generateCert(ctx context.Context, a *Server, req certRequest, caType types.
Identity: sshca.Identity{
Username: req.user.GetName(),
Impersonator: req.impersonator,
AllowedLogins: allowedLogins,
Principals: allowedLogins,
Roles: req.checker.RoleNames(),
PermitPortForwarding: req.checker.CanPortForward(),
PermitAgentForwarding: req.checker.CanForwardAgents(),
Expand All @@ -3272,7 +3269,7 @@ func generateCert(ctx context.Context, a *Server, req certRequest, caType types.
BotName: req.botName,
BotInstanceID: req.botInstanceID,
CertificateExtensions: req.checker.CertificateExtensions(),
AllowedResourceIDs: requestedResourcesStr,
AllowedResourceIDs: req.checker.GetAllowedResourceIDs(),
ConnectionDiagnosticID: req.connectionDiagnosticID,
PrivateKeyPolicy: attestedKeyPolicy,
DeviceID: req.deviceExtensions.DeviceID,
Expand Down Expand Up @@ -3367,7 +3364,7 @@ func generateCert(ctx context.Context, a *Server, req certRequest, caType types.
AWSRoleARNs: roleARNs,
AzureIdentities: azureIdentities,
GCPServiceAccounts: gcpAccounts,
ActiveRequests: req.activeRequests.AccessRequests,
ActiveRequests: req.activeRequests,
DisallowReissue: req.disallowReissue,
Renewable: req.renewable,
Generation: req.generation,
Expand Down Expand Up @@ -4734,14 +4731,16 @@ func (a *Server) GenerateHostCerts(ctx context.Context, req *proto.HostCertsRequ
return nil, trace.Wrap(err)
}
// generate host SSH certificate
hostSSHCert, err := a.generateHostCert(ctx, services.HostCertParams{
hostSSHCert, err := a.generateHostCert(ctx, sshca.HostCertificateRequest{
CASigner: caSigner,
PublicHostKey: req.PublicSSHKey,
HostID: req.HostID,
NodeName: req.NodeName,
ClusterName: clusterName.GetClusterName(),
Role: req.Role,
Principals: req.AdditionalPrincipals,
Identity: sshca.Identity{
ClusterName: clusterName.GetClusterName(),
SystemRole: req.Role,
Principals: req.AdditionalPrincipals,
},
})
if err != nil {
return nil, trace.Wrap(err)
Expand Down
2 changes: 1 addition & 1 deletion lib/auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2642,7 +2642,7 @@ func TestGenerateUserCertWithLocks(t *testing.T) {
mfaVerified: mfaID,
sshPublicKey: sshPubKey,
tlsPublicKey: tlsPubKey,
activeRequests: services.RequestIDs{AccessRequests: []string{requestID}},
activeRequests: []string{requestID},
deviceExtensions: DeviceExtensions{
DeviceID: deviceID,
AssetTag: "assettag1",
Expand Down
8 changes: 3 additions & 5 deletions lib/auth/auth_with_roles.go
Original file line number Diff line number Diff line change
Expand Up @@ -3440,11 +3440,9 @@ func (a *ServerWithRoles) generateUserCerts(ctx context.Context, req proto.UserC
checker: checker,
// Copy IP from current identity to the generated certificate, if present,
// to avoid generateUserCerts() being used to drop IP pinning in the new certificates.
loginIP: a.context.Identity.GetIdentity().LoginIP,
traits: accessInfo.Traits,
activeRequests: services.RequestIDs{
AccessRequests: req.AccessRequests,
},
loginIP: a.context.Identity.GetIdentity().LoginIP,
traits: accessInfo.Traits,
activeRequests: req.AccessRequests,
connectionDiagnosticID: req.ConnectionDiagnosticID,
botName: getBotName(user),

Expand Down
41 changes: 26 additions & 15 deletions lib/auth/init_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ import (
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/services/suite"
"github.com/gravitational/teleport/lib/srv/db/common/databaseobjectimportrule"
"github.com/gravitational/teleport/lib/sshca"
"github.com/gravitational/teleport/lib/sshutils"
"github.com/gravitational/teleport/lib/utils"
"github.com/gravitational/teleport/lib/utils/proxy"
Expand All @@ -77,14 +78,16 @@ func TestReadIdentity(t *testing.T) {
caSigner, err := ssh.ParsePrivateKey(priv)
require.NoError(t, err)

cert, err := a.GenerateHostCert(services.HostCertParams{
cert, err := a.GenerateHostCert(sshca.HostCertificateRequest{
CASigner: caSigner,
PublicHostKey: pub,
HostID: "id1",
NodeName: "node-name",
ClusterName: "example.com",
Role: types.RoleNode,
TTL: 0,
Identity: sshca.Identity{
ClusterName: "example.com",
SystemRole: types.RoleNode,
},
})
require.NoError(t, err)

Expand All @@ -98,14 +101,16 @@ func TestReadIdentity(t *testing.T) {
// test TTL by converting the generated cert to text -> back and making sure ExpireAfter is valid
ttl := 10 * time.Second
expiryDate := clock.Now().Add(ttl)
bytes, err := a.GenerateHostCert(services.HostCertParams{
bytes, err := a.GenerateHostCert(sshca.HostCertificateRequest{
CASigner: caSigner,
PublicHostKey: pub,
HostID: "id1",
NodeName: "node-name",
ClusterName: "example.com",
Role: types.RoleNode,
TTL: ttl,
Identity: sshca.Identity{
ClusterName: "example.com",
SystemRole: types.RoleNode,
},
})
require.NoError(t, err)
copy, err := apisshutils.ParseCertificate(bytes)
Expand All @@ -125,44 +130,50 @@ func TestBadIdentity(t *testing.T) {
require.IsType(t, trace.BadParameter(""), err)

// missing authority domain
cert, err := a.GenerateHostCert(services.HostCertParams{
cert, err := a.GenerateHostCert(sshca.HostCertificateRequest{
CASigner: caSigner,
PublicHostKey: pub,
HostID: "id2",
NodeName: "",
ClusterName: "",
Role: types.RoleNode,
TTL: 0,
Identity: sshca.Identity{
ClusterName: "",
SystemRole: types.RoleNode,
},
})
require.NoError(t, err)

_, err = state.ReadSSHIdentityFromKeyPair(priv, cert)
require.IsType(t, trace.BadParameter(""), err)

// missing host uuid
cert, err = a.GenerateHostCert(services.HostCertParams{
cert, err = a.GenerateHostCert(sshca.HostCertificateRequest{
CASigner: caSigner,
PublicHostKey: pub,
HostID: "example.com",
NodeName: "",
ClusterName: "",
Role: types.RoleNode,
TTL: 0,
Identity: sshca.Identity{
ClusterName: "",
SystemRole: types.RoleNode,
},
})
require.NoError(t, err)

_, err = state.ReadSSHIdentityFromKeyPair(priv, cert)
require.IsType(t, trace.BadParameter(""), err)

// unrecognized role
cert, err = a.GenerateHostCert(services.HostCertParams{
cert, err = a.GenerateHostCert(sshca.HostCertificateRequest{
CASigner: caSigner,
PublicHostKey: pub,
HostID: "example.com",
NodeName: "",
ClusterName: "id1",
Role: "bad role",
TTL: 0,
Identity: sshca.Identity{
ClusterName: "id1",
SystemRole: "bad role",
},
})
require.NoError(t, err)

Expand Down
Loading

0 comments on commit 09148bb

Please sign in to comment.