Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for repacking and merging tls.Config structs #196

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 61 additions & 31 deletions transport/tlscommon/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,28 +159,37 @@ func (m TLSVerificationMode) MarshalText() ([]byte, error) {
return nil, fmt.Errorf("could not marshal '%+v' to text", m)
}

// Unpack unpacks the string into constants.
// Unpack unpacks the input into a TLSVerificationMode.
func (m *TLSVerificationMode) Unpack(in interface{}) error {
if in == nil {
*m = VerifyFull
return nil
}

s, ok := in.(string)
if !ok {
return fmt.Errorf("verification mode must be an identifier")
}
if s == "" {
*m = VerifyFull
return nil
switch o := in.(type) {
case string:
if o == "" {
*m = VerifyFull
return nil
}

mode, found := tlsVerificationModes[o]
if !found {
return fmt.Errorf("unknown verification mode '%v'", o)
}
*m = mode
case uint64:
*m = TLSVerificationMode(o)
default:
return fmt.Errorf("verification mode is an unknown type: %T", o)
}
return nil
}

mode, found := tlsVerificationModes[s]
if !found {
return fmt.Errorf("unknown verification mode '%v'", s)
func (m *TLSVerificationMode) Validate() error {
if *m > VerifyStrict {
return fmt.Errorf("unsupported verification mode: %v", m)
}

*m = mode
return nil
}

Expand Down Expand Up @@ -214,13 +223,20 @@ func (m *TLSClientAuth) Unpack(s string) error {

type CipherSuite uint16

func (cs *CipherSuite) Unpack(s string) error {
suite, found := tlsCipherSuites[s]
if !found {
return fmt.Errorf("invalid tls cipher suite '%v'", s)
func (cs *CipherSuite) Unpack(i interface{}) error {
switch o := i.(type) {
case string:
suite, found := tlsCipherSuites[o]
if !found {
return fmt.Errorf("invalid tls cipher suite '%v'", o)
}

*cs = suite
case uint64:
*cs = CipherSuite(o)
default:
return fmt.Errorf("cipher suite is an unknown type: %T", o)
}

*cs = suite
return nil
}

Expand All @@ -233,13 +249,20 @@ func (cs CipherSuite) String() string {

type tlsCurveType tls.CurveID

func (ct *tlsCurveType) Unpack(s string) error {
t, found := tlsCurveTypes[s]
if !found {
return fmt.Errorf("invalid tls curve type '%v'", s)
func (ct *tlsCurveType) Unpack(i interface{}) error {
switch o := i.(type) {
case string:
t, found := tlsCurveTypes[o]
if !found {
return fmt.Errorf("invalid tls curve type '%v'", o)
}

*ct = t
case uint64:
*ct = tlsCurveType(o)
default:
return fmt.Errorf("tls curve type is an unsupported input type: %T", o)
}

*ct = t
return nil
}

Expand All @@ -252,13 +275,20 @@ func (r TLSRenegotiationSupport) String() string {
return "<" + unknownType + ">"
}

func (r *TLSRenegotiationSupport) Unpack(s string) error {
t, found := tlsRenegotiationSupportTypes[s]
if !found {
return fmt.Errorf("invalid tls renegotiation type '%v'", s)
func (r *TLSRenegotiationSupport) Unpack(i interface{}) error {
switch o := i.(type) {
case string:
t, found := tlsRenegotiationSupportTypes[o]
if !found {
return fmt.Errorf("invalid tls renegotiation type '%v'", o)
}

*r = t
case uint64:
*r = TLSRenegotiationSupport(o)
default:
return fmt.Errorf("tls renegotation support is an unknown type: %T", o)
}

*r = t
return nil
}

Expand Down
31 changes: 31 additions & 0 deletions transport/tlscommon/types_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"testing"

"github.com/elastic/elastic-agent-libs/config"
"github.com/elastic/go-ucfg"
"github.com/stretchr/testify/assert"

"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -69,6 +70,36 @@ func TestLoadWithEmptyVerificationMode(t *testing.T) {
assert.Equal(t, cfg.VerificationMode, VerifyFull)
}

func TestRepackConfig(t *testing.T) {
cfg, err := load(`
enabled: true
verification_mode: certificate
supported_protocols: [TLSv1.1, TLSv1.2]
cipher_suites:
- RSA-AES-256-CBC-SHA
certificate_authorities:
- /path/to/ca.crt
certificate: /path/to/cert.crt
key: /path/to/key.crt
curve_types:
- P-521
renegotiation: freely
ca_sha256:
- example
ca_trusted_fingerprint: fingerprint
`)

assert.NoError(t, err)
assert.Equal(t, cfg.VerificationMode, VerifyCertificate)

tmp, err := ucfg.NewFrom(cfg)
assert.NoError(t, err)

err = tmp.Unpack(cfg)
assert.NoError(t, err)
assert.Equal(t, cfg.VerificationMode, VerifyCertificate)
}

func TestTLSClientAuthUnpack(t *testing.T) {
tests := []struct {
val string
Expand Down
23 changes: 18 additions & 5 deletions transport/tlscommon/versions.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,25 @@ func (v TLSVersion) Details() *TLSVersionDetails {
}

// Unpack transforms the string into a constant.
func (v *TLSVersion) Unpack(s string) error {
version, found := tlsProtocolVersions[s]
if !found {
return fmt.Errorf("invalid tls version '%v'", s)
func (v *TLSVersion) Unpack(i interface{}) error {
switch o := i.(type) {
case string:
version, found := tlsProtocolVersions[o]
if !found {
return fmt.Errorf("invalid tls version '%v'", o)
}
*v = version
case uint64:
*v = TLSVersion(o)
default:
return fmt.Errorf("tls version is an unknown type: %T", o)
}
return nil
}

*v = version
func (v *TLSVersion) Validate() error {
if *v < TLSVersionMin || *v > TLSVersionMax {
return fmt.Errorf("unsupported tls version: %v", v)
}
return nil
}
Loading