-
Notifications
You must be signed in to change notification settings - Fork 56
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Certificates management (#1086)
* basics of mTLS working * allow mtls to be enabled or disabled * use db for cert backend
- Loading branch information
Showing
18 changed files
with
1,052 additions
and
269 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
package server | ||
|
||
import ( | ||
"crypto/tls" | ||
"crypto/x509" | ||
"fmt" | ||
"io/ioutil" | ||
"net/http" | ||
"net/http/httptest" | ||
"strings" | ||
"testing" | ||
"time" | ||
|
||
"github.com/infrahq/infra/internal/server/models" | ||
"github.com/infrahq/infra/pki" | ||
"github.com/infrahq/infra/uid" | ||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
func TestCertificateSigningWorks(t *testing.T) { | ||
db := setupDB(t) | ||
|
||
cp, err := pki.NewNativeCertificateProvider(db, pki.NativeCertificateProviderConfig{ | ||
FullKeyRotationDurationInDays: 2, | ||
}) | ||
require.NoError(t, err) | ||
|
||
err = cp.CreateCA() | ||
require.NoError(t, err) | ||
|
||
err = cp.RotateCA() | ||
require.NoError(t, err) | ||
|
||
user := &models.User{ | ||
Model: models.Model{ID: uid.New()}, | ||
Email: "joe@example.com", | ||
} | ||
|
||
keyPair, err := pki.MakeUserCert("User "+user.ID.String(), 24*time.Hour) | ||
require.NoError(t, err) | ||
|
||
// happens on the server, needs to be a request for this. | ||
signedCert, signedRaw, err := pki.SignUserCert(cp, keyPair.Cert, user) | ||
require.NoError(t, err) | ||
keyPair.SignedCert = signedCert | ||
keyPair.SignedCertPEM = signedRaw | ||
|
||
// create a test server and client to make sure the certs work. | ||
requireMutualTLSWorks(t, keyPair, cp) | ||
} | ||
|
||
func requireMutualTLSWorks(t *testing.T, clientKeypair *pki.KeyPair, cp pki.CertificateProvider) { | ||
server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
fmt.Fprintf(w, "success!") | ||
})) | ||
|
||
serverTLSCerts, err := cp.TLSCertificates() | ||
require.NoError(t, err) | ||
|
||
caPool := x509.NewCertPool() | ||
|
||
for _, cert := range cp.ActiveCAs() { | ||
cert := cert | ||
caPool.AddCert(&cert) | ||
} | ||
|
||
server.TLS = &tls.Config{ | ||
Certificates: serverTLSCerts, | ||
ClientAuth: tls.RequireAndVerifyClientCert, | ||
ClientCAs: caPool, | ||
MinVersion: tls.VersionTLS12, | ||
} | ||
|
||
server.StartTLS() | ||
defer server.Close() | ||
|
||
// This will response with HTTP 200 OK and a body containing success!. We can now set up the client to trust the CA, and send a request to the server: | ||
|
||
clientTLSCert, err := clientKeypair.TLSCertificate() | ||
require.NoError(t, err) | ||
|
||
transport := &http.Transport{ | ||
TLSClientConfig: &tls.Config{ | ||
Certificates: []tls.Certificate{*clientTLSCert}, | ||
ClientCAs: caPool, | ||
RootCAs: caPool, | ||
MinVersion: tls.VersionTLS12, | ||
}, | ||
} | ||
http := http.Client{ | ||
Transport: transport, | ||
} | ||
|
||
resp, err := http.Get(server.URL) | ||
require.NoError(t, err) | ||
|
||
// If no errors occurred, we now have our success! response from the server, and can verify it: | ||
|
||
respBodyBytes, err := ioutil.ReadAll(resp.Body) | ||
require.NoError(t, err) | ||
|
||
body := strings.TrimSpace(string(respBodyBytes[:])) | ||
require.Equal(t, "success!", body) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
package data | ||
|
||
import ( | ||
"encoding/base64" | ||
"errors" | ||
"time" | ||
|
||
"github.com/infrahq/infra/internal" | ||
"github.com/infrahq/infra/internal/server/models" | ||
"gorm.io/gorm" | ||
) | ||
|
||
// TrustPublicKey trusts a public key (in base64 format) from a user or service | ||
// Callers must have received the key from a mTLS/e2ee (mutually encrypted), trusted source. | ||
func TrustPublicKey(db *gorm.DB, tc *models.TrustedCertificate) error { | ||
_, err := get[models.TrustedCertificate](db, ByPublicKey(tc.PublicKey)) | ||
if err != nil && !errors.Is(err, internal.ErrNotFound) { | ||
return err | ||
} | ||
|
||
if err == nil { | ||
// this one already exists | ||
return nil | ||
} | ||
|
||
return add(db, tc) | ||
} | ||
|
||
func ListTrustedClientCertificates(db *gorm.DB) ([]models.TrustedCertificate, error) { | ||
return list[models.TrustedCertificate](db) | ||
} | ||
|
||
func ListRootCertificates(db *gorm.DB) ([]models.RootCertificate, error) { | ||
return list[models.RootCertificate](db, OrderBy("id desc"), ByNotExpired(), Limit(2)) | ||
} | ||
|
||
func GetRootCertificate(db *gorm.DB, selectors ...SelectorFunc) (*models.RootCertificate, error) { | ||
return get[models.RootCertificate](db, selectors...) | ||
} | ||
|
||
func AddRootCertificate(db *gorm.DB, cert *models.RootCertificate) error { | ||
return add(db, cert) | ||
} | ||
|
||
func ByPublicKey(key []byte) SelectorFunc { | ||
return func(db *gorm.DB) *gorm.DB { | ||
k := base64.StdEncoding.EncodeToString(key) | ||
return db.Where("public_key = ?", k) | ||
} | ||
} | ||
|
||
func OrderBy(order string) SelectorFunc { | ||
return func(db *gorm.DB) *gorm.DB { | ||
return db.Order(order) | ||
} | ||
} | ||
|
||
func Limit(limit int) SelectorFunc { | ||
return func(db *gorm.DB) *gorm.DB { | ||
return db.Limit(limit) | ||
} | ||
} | ||
|
||
func ByNotExpired() SelectorFunc { | ||
return func(db *gorm.DB) *gorm.DB { | ||
return db.Where("expires_at is null or expires_at > ?", time.Now().UTC()) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
package models | ||
|
||
import ( | ||
"time" | ||
) | ||
|
||
type TrustedCertificate struct { | ||
Model | ||
|
||
KeyAlgorithm string `validate:"required"` | ||
SigningAlgorithm string `validate:"required"` | ||
PublicKey Base64 `validate:"required"` | ||
CertPEM []byte `validate:"required"` // pem encoded | ||
Identity string `validate:"required"` | ||
ExpiresAt time.Time | ||
OneTimeUse bool | ||
} | ||
|
||
type RootCertificate struct { | ||
Model | ||
|
||
KeyAlgorithm string `validate:"required"` | ||
SigningAlgorithm string `validate:"required"` | ||
PublicKey Base64 `validate:"required"` | ||
PrivateKey EncryptedAtRest `validate:"required"` | ||
SignedCert EncryptedAtRest `validate:"required"` // contains private key? probably not pem encoded | ||
ExpiresAt time.Time `validate:"required"` | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
package models | ||
|
||
import ( | ||
"database/sql/driver" | ||
"encoding/base64" | ||
"fmt" | ||
) | ||
|
||
type Base64 []byte | ||
|
||
func (f Base64) Value() (driver.Value, error) { | ||
r := base64.StdEncoding.EncodeToString([]byte(f)) | ||
|
||
return r, nil | ||
} | ||
|
||
func (f *Base64) Scan(v interface{}) error { | ||
b, err := base64.StdEncoding.DecodeString(string(v.(string))) | ||
if err != nil { | ||
return fmt.Errorf("base64 decoding field: %w", err) | ||
} | ||
|
||
*f = Base64(b) | ||
|
||
return nil | ||
} | ||
|
||
func (f Base64) GormDataType() string { | ||
return "text" | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.