Skip to content

Commit

Permalink
Merge pull request #9 from madflojo/tlsconfig-helpers
Browse files Browse the repository at this point in the history
Add methods to generate and configure TLS
  • Loading branch information
madflojo authored May 18, 2024
2 parents 7f09672 + 5be6d63 commit 34930f4
Show file tree
Hide file tree
Showing 4 changed files with 136 additions and 37 deletions.
5 changes: 5 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
tests:
go test --race -v -covermode=atomic -coverprofile=coverage.out ./...

benchmarks:
go test -bench=. -benchmem ./...
4 changes: 1 addition & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,7 @@ func TestFunc(t *testing.T) {
// Create a client with the self-signed CA
client := &http.Client{
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
RootCAs: ca.CertPool(),
},
TLSClientConfig: certs.ConfigureTLSConfig(ca.GenerateTLSConfig()),
},
}

Expand Down
33 changes: 26 additions & 7 deletions testcerts.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ Stop saving test certificates in your code repos. Start generating them in your
}
}
For more complex tests, you can also use this package to create a Certificate Authority and a key pair signed by that Certificate Authority for any test domain you want.
For more complex tests, you can also use this package to create a Certificate Authority and a key pair signed by
that Certificate Authority for any test domain you want.
func TestFunc(t *testing.T) {
// Generate Certificate Authority
Expand Down Expand Up @@ -48,9 +49,7 @@ For more complex tests, you can also use this package to create a Certificate Au
// Create a client with the self-signed CA
client := &http.Client{
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
RootCAs: ca.CertPool(),
},
TLSClientConfig: certs.ConfigureTLSConfig(ca.GenerateTLSConfig()),
},
}
Expand All @@ -66,6 +65,7 @@ import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
Expand Down Expand Up @@ -163,7 +163,7 @@ func (ca *CertificateAuthority) NewKeyPair(domains ...string) (*KeyPair, error)
return kp, nil
}

// CertPool returns a Certificate Pool of the CertificateAuthority Certificate
// CertPool returns a Certificate Pool of the CertificateAuthority Certificate.
func (ca *CertificateAuthority) CertPool() *x509.CertPool {
return ca.certPool
}
Expand Down Expand Up @@ -232,6 +232,13 @@ func (ca *CertificateAuthority) ToTempFile(dir string) (cfh *os.File, kfh *os.Fi
return cfh, kfh, nil
}

// GenerateTLSConfig returns a tls.Config with the CertificateAuthority as the RootCA.
func (ca *CertificateAuthority) GenerateTLSConfig() *tls.Config {
return &tls.Config{
RootCAs: ca.CertPool(),
}
}

// PrivateKey returns the private key of the KeyPair.
func (kp *KeyPair) PrivateKey() []byte {
return pem.EncodeToMemory(kp.privateKey)
Expand Down Expand Up @@ -296,6 +303,18 @@ func (kp *KeyPair) ToTempFile(dir string) (cfh *os.File, kfh *os.File, err error
return cfh, kfh, nil
}

// ConfigureTLSConfig will configure the tls.Config with the KeyPair certificate and private key.
// The returned tls.Config can be used for a server or client.
func (kp *KeyPair) ConfigureTLSConfig(tlsConfig *tls.Config) *tls.Config {
tlsConfig.Certificates = []tls.Certificate{
{
Certificate: [][]byte{kp.PublicKey()},
PrivateKey: kp.PrivateKey(),
},
}
return tlsConfig
}

// GenerateCerts generates a x509 certificate and key.
// It returns the certificate and key as byte slices, and any error that occurred.
//
Expand Down Expand Up @@ -381,7 +400,7 @@ func genKeyPair(ca *x509.Certificate, caKey *ecdsa.PrivateKey, cert *x509.Certif
return certToPemBlock(signedCert), key, nil
}

// keyToPemBlock converts the key to a private pem.Block
// keyToPemBlock converts the key to a private pem.Block.
func keyToPemBlock(key *ecdsa.PrivateKey) (*pem.Block, error) {
// Convert key into pem.Block
kb, err := x509.MarshalPKCS8PrivateKey(key)
Expand All @@ -392,7 +411,7 @@ func keyToPemBlock(key *ecdsa.PrivateKey) (*pem.Block, error) {
return k, nil
}

// certToPemBlock converts the certificate to a public pem.Block
// certToPemBlock converts the certificate to a public pem.Block.
func certToPemBlock(cert []byte) *pem.Block {
return &pem.Block{Type: "CERTIFICATE", Bytes: cert}
}
131 changes: 104 additions & 27 deletions testcerts_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -192,10 +192,8 @@ func TestCertsUsage(t *testing.T) {
t.Errorf("Unexpected success with invalid tempfile directory")
}
})

})
}

}

func TestGeneratingCerts(t *testing.T) {
Expand Down Expand Up @@ -338,10 +336,7 @@ func TestGenerateCertsToTempFile(t *testing.T) {
})
}

// testUsingCerts is called by the two tests below. Both test setting up certificates and subsequently
// configuring the transport of the http.Client to use the generated certificate.
// One uses the own public key as part of the pool (self signed cert) and one uses the cert from the CA.
func testUsingCerts(t *testing.T, rootCAs func(ca *CertificateAuthority, certs *KeyPair) *x509.CertPool) {
func TestFullFlow(t *testing.T) {
// Create a signed Certificate and Key for "localhost"
ca := NewCA()
certs, err := ca.NewKeyPair("localhost")
Expand Down Expand Up @@ -369,39 +364,121 @@ func testUsingCerts(t *testing.T, rootCAs func(ca *CertificateAuthority, certs *
}
}()

// Add handler
server.Handler = http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
_, err := w.Write([]byte("Hello, World!"))
if err != nil {
t.Errorf("Error writing response - %s", err)
}
})

// Wait for Listener to start
<-time.After(3 * time.Second)

// Setup HTTP Client with Cert Pool
certpool := rootCAs(ca, certs)
if certpool == nil {
t.Fatalf("Test configuration error: rootCAs arg function returned nil instead of a x509.CertPool")
t.Run("TestUsingCA", func(t *testing.T) {
// Setup HTTP Client with Cert Pool
client := &http.Client{
Transport: &http.Transport{
TLSClientConfig: certs.ConfigureTLSConfig(ca.GenerateTLSConfig()),
},
}

// Make an HTTPS request
rsp, err := client.Get("https://localhost:8443")
if err != nil {
t.Errorf("Client returned error - %s", err)
}

// Check the response
if rsp.StatusCode != http.StatusOK {
t.Errorf("Unexpected response code - %d", rsp.StatusCode)
}
})

t.Run("TestUsingSelfSigned", func(t *testing.T) {
// Create new CertPool
pool := x509.NewCertPool()
pool.AppendCertsFromPEM(certs.PublicKey())

// Setup HTTP Client with Cert Pool
client := &http.Client{
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
RootCAs: pool,
},
},
}

// Make an HTTPS request
rsp, err := client.Get("https://localhost:8443")
if err != nil {
t.Errorf("Client returned error - %s", err)
}

// Check the response
if rsp.StatusCode != http.StatusOK {
t.Errorf("Unexpected response code - %d", rsp.StatusCode)
}
})
}

func ExampleNewCA() {
// Generate a new Certificate Authority
ca := NewCA()

// Create a new KeyPair with a list of domains
certs, err := ca.NewKeyPair("localhost")
if err != nil {
fmt.Printf("Error generating keypair - %s", err)
}

// Write the certificates to a file
cert, key, err := certs.ToTempFile("")
if err != nil {
fmt.Printf("Error writing certs to temp files - %s", err)
}

// Create an HTTP Server
server := &http.Server{
Addr: "0.0.0.0:8443",
}
defer server.Close()

go func() {
// Start HTTP Listener
err = server.ListenAndServeTLS(cert.Name(), key.Name())
if err != nil && err != http.ErrServerClosed {
fmt.Printf("Listener returned error - %s", err)
}
}()

// Add handler
server.Handler = http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
_, err := w.Write([]byte("Hello, World!"))
if err != nil {
fmt.Printf("Error writing response - %s", err)
}
})

// Wait for Listener to start
<-time.After(3 * time.Second)

// Setup HTTP Client with Cert Pool
client := &http.Client{
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
RootCAs: certpool,
},
TLSClientConfig: certs.ConfigureTLSConfig(ca.GenerateTLSConfig()),
},
}

// Make an HTTPS request
_, err = client.Get("https://localhost:8443")
rsp, err := client.Get("https://localhost:8443")
if err != nil {
t.Errorf("Client returned error - %s", err)
fmt.Printf("Client returned error - %s", err)
}
}

func TestUsingSelfSignedCerts(t *testing.T) {
testUsingCerts(t, func(_ *CertificateAuthority, certs *KeyPair) *x509.CertPool {
pool := x509.NewCertPool()
pool.AppendCertsFromPEM(certs.PublicKey())
return pool
})
}
// Print the response
fmt.Println(rsp.Status)

func TestUsingCertsWithCA(t *testing.T) {
testUsingCerts(t, func(ca *CertificateAuthority, _ *KeyPair) *x509.CertPool {
return ca.certPool
})
// Output:
// 200 OK
}

0 comments on commit 34930f4

Please sign in to comment.