Skip to content

Commit

Permalink
TrustStore: Resolver handles certificate chains
Browse files Browse the repository at this point in the history
This PR adds the capability to resolve certificate chains to the
resolver.

fixes scionproto#3469
  • Loading branch information
oncilla authored and lukedirtwalker committed Dec 11, 2019
1 parent 6da631a commit f8d2095
Show file tree
Hide file tree
Showing 3 changed files with 215 additions and 2 deletions.
1 change: 1 addition & 0 deletions go/lib/infra/modules/trust/v2/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ var (
// Chains
var (
chain110v1 = ChainDesc{IA: ia110, Version: 1}
chain120v1 = ChainDesc{IA: ia120, Version: 1}
)

func TestMain(m *testing.M) {
Expand Down
64 changes: 62 additions & 2 deletions go/lib/infra/modules/trust/v2/resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"github.com/scionproto/scion/go/lib/infra/modules/trust/v2/internal/decoded"
"github.com/scionproto/scion/go/lib/log"
"github.com/scionproto/scion/go/lib/scrypto"
"github.com/scionproto/scion/go/lib/scrypto/cert/v2"
"github.com/scionproto/scion/go/lib/scrypto/trc/v2"
"github.com/scionproto/scion/go/lib/serrors"
)
Expand Down Expand Up @@ -157,8 +158,37 @@ func (r *resolver) trcCheck(req TRCReq, t *trc.TRC) error {
func (r *resolver) Chain(ctx context.Context, req ChainReq,
server net.Addr) (decoded.Chain, error) {

// TODO(roosd): implement
return decoded.Chain{}, serrors.New("not implemented")
msg, err := r.rpc.GetCertChain(ctx, req, server)
if err != nil {
return decoded.Chain{}, serrors.WrapStr("error requesting certificate chain", err)
}
dec, err := decoded.DecodeChain(msg)
if err != nil {
return decoded.Chain{}, serrors.WrapStr("error parsing certificate chain", err)
}
if err := r.chainCheck(req, dec.AS); err != nil {
return decoded.Chain{}, serrors.Wrap(ErrInvalidResponse, err)
}
w := resolveWrap{
resolver: r,
server: server,
cacheOnly: req.CacheOnly,
}
if err := r.inserter.InsertChain(ctx, dec, w.TRC); err != nil {
return decoded.Chain{}, serrors.WrapStr("unable to insert certificate chain", err,
"chain", dec)
}
return dec, nil
}

func (r *resolver) chainCheck(req ChainReq, as *cert.AS) error {
switch {
case !req.IA.Equal(as.Subject):
return serrors.New("wrong subject", "expected", req.IA, "actual", as.Subject)
case !req.Version.IsLatest() && req.Version != as.Version:
return serrors.New("wrong version", "expected", req.Version, "actual", as.Version)
}
return nil
}

type resOrErr struct {
Expand All @@ -184,3 +214,33 @@ func (w *prevWrap) TRC(_ context.Context, isd addr.ISD, version scrypto.Version)
}
return w.prev, nil
}

// resolverWrap provides TRCs that are backed by the resolver. If a TRC is
// missing in the DB, network requests are allowed.
type resolveWrap struct {
resolver *resolver
server net.Addr
cacheOnly bool
}

func (w resolveWrap) TRC(ctx context.Context, isd addr.ISD,
version scrypto.Version) (*trc.TRC, error) {

t, err := w.resolver.db.GetTRC(ctx, isd, version)
switch {
case err == nil:
return t, nil
case !xerrors.Is(err, ErrNotFound):
return nil, serrors.WrapStr("error querying DB for TRC", err)
}
req := TRCReq{
ISD: isd,
Version: version,
CacheOnly: w.cacheOnly,
}
decoded, err := w.resolver.TRC(ctx, req, w.server)
if err != nil {
return nil, serrors.WrapStr("unable to fetch TRC from network", err)
}
return decoded.TRC, nil
}
152 changes: 152 additions & 0 deletions go/lib/infra/modules/trust/v2/resolver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"

"github.com/scionproto/scion/go/lib/addr"
trust "github.com/scionproto/scion/go/lib/infra/modules/trust/v2"
Expand Down Expand Up @@ -114,3 +115,154 @@ func TestResolverTRC(t *testing.T) {
})
}
}

func TestResolverChain(t *testing.T) {
internal := serrors.New("internal")
type mocks struct {
DB *mock_v2.MockDB
Inserter *mock_v2.MockInserter
RPC *mock_v2.MockRPC
}
tests := map[string]struct {
Expect func(t *testing.T, m mocks) decoded.Chain
ChainReq trust.ChainReq
ExpectedErr error
}{
"valid, TRC in DB": {
Expect: func(t *testing.T, m mocks) decoded.Chain {
req := trust.ChainReq{IA: ia110, Version: scrypto.LatestVer}
dec := loadChain(t, chain110v1)
m.RPC.EXPECT().GetCertChain(gomock.Any(), req, nil).Return(
dec.Raw, nil,
)
decTRC := loadTRC(t, trc1v1)
m.Inserter.EXPECT().InsertChain(gomock.Any(), dec, gomock.Any()).DoAndReturn(
func(ctx context.Context, dec decoded.Chain, p trust.TRCProviderFunc) error {
trc, err := p(ctx, dec.Issuer.Subject.I, dec.Issuer.Issuer.TRCVersion)
require.NoError(t, err)
assert.Equal(t, decTRC.TRC, trc)
return err
},
)
m.DB.EXPECT().GetTRC(gomock.Any(), addr.ISD(1), scrypto.Version(1)).Return(
decTRC.TRC, nil,
)
return dec
},
ChainReq: trust.ChainReq{IA: ia110, Version: scrypto.LatestVer},
},
"RPC fail": {
Expect: func(t *testing.T, m mocks) decoded.Chain {
req := trust.ChainReq{IA: ia110, Version: scrypto.LatestVer}
m.RPC.EXPECT().GetCertChain(gomock.Any(), req, nil).Return(
nil, internal,
)
return loadChain(t, chain110v1)
},
ChainReq: trust.ChainReq{IA: ia110, Version: scrypto.LatestVer},
ExpectedErr: internal,
},
"garbage chain": {
Expect: func(t *testing.T, m mocks) decoded.Chain {
req := trust.ChainReq{IA: ia110, Version: scrypto.LatestVer}
m.RPC.EXPECT().GetCertChain(gomock.Any(), req, nil).Return(
[]byte("some_garbage"), nil,
)
return loadChain(t, chain110v1)
},
ChainReq: trust.ChainReq{IA: ia110, Version: scrypto.LatestVer},
ExpectedErr: decoded.ErrParse,
},
"wrong subject": {
Expect: func(t *testing.T, m mocks) decoded.Chain {
req := trust.ChainReq{IA: ia110, Version: scrypto.LatestVer}
m.RPC.EXPECT().GetCertChain(gomock.Any(), req, nil).Return(
loadChain(t, chain120v1).Raw, nil,
)
return loadChain(t, chain110v1)
},
ChainReq: trust.ChainReq{IA: ia110, Version: scrypto.LatestVer},
ExpectedErr: trust.ErrInvalidResponse,
},
"wrong version": {
Expect: func(t *testing.T, m mocks) decoded.Chain {
req := trust.ChainReq{IA: ia110, Version: 2}
m.RPC.EXPECT().GetCertChain(gomock.Any(), req, nil).Return(
loadChain(t, chain110v1).Raw, nil,
)
return loadChain(t, chain110v1)
},
ChainReq: trust.ChainReq{IA: ia110, Version: 2},
ExpectedErr: trust.ErrInvalidResponse,
},
"DB error": {
Expect: func(t *testing.T, m mocks) decoded.Chain {
req := trust.ChainReq{IA: ia110, Version: scrypto.LatestVer}
dec := loadChain(t, chain110v1)
m.RPC.EXPECT().GetCertChain(gomock.Any(), req, nil).Return(
dec.Raw, nil,
)
m.Inserter.EXPECT().InsertChain(gomock.Any(), dec, gomock.Any()).DoAndReturn(
func(ctx context.Context, dec decoded.Chain, p trust.TRCProviderFunc) error {
_, err := p(ctx, dec.Issuer.Subject.I, dec.Issuer.Issuer.TRCVersion)
require.Error(t, err)
return err
},
)
m.DB.EXPECT().GetTRC(gomock.Any(), addr.ISD(1), scrypto.Version(1)).Return(
nil, internal,
)
return dec
},
ChainReq: trust.ChainReq{IA: ia110, Version: scrypto.LatestVer},
ExpectedErr: internal,
},
"TRC RPC error": {
Expect: func(t *testing.T, m mocks) decoded.Chain {
req := trust.ChainReq{IA: ia110, Version: scrypto.LatestVer}
dec := loadChain(t, chain110v1)
m.RPC.EXPECT().GetCertChain(gomock.Any(), req, nil).Return(
dec.Raw, nil,
)
m.Inserter.EXPECT().InsertChain(gomock.Any(), dec, gomock.Any()).DoAndReturn(
func(ctx context.Context, dec decoded.Chain, p trust.TRCProviderFunc) error {
_, err := p(ctx, dec.Issuer.Subject.I, dec.Issuer.Issuer.TRCVersion)
xtest.AssertErrorsIs(t, err, internal)
return err
},
)
m.DB.EXPECT().GetTRC(gomock.Any(), addr.ISD(1), scrypto.Version(1)).Return(
nil, trust.ErrNotFound,
)
m.DB.EXPECT().GetTRC(gomock.Any(), addr.ISD(1), scrypto.LatestVer).Return(
nil, internal,
)
return dec
},
ChainReq: trust.ChainReq{IA: ia110, Version: scrypto.LatestVer},
ExpectedErr: internal,
},
}
for name, test := range tests {
t.Run(name, func(t *testing.T) {
mctrl := gomock.NewController(t)
defer mctrl.Finish()
m := mocks{
DB: mock_v2.NewMockDB(mctrl),
Inserter: mock_v2.NewMockInserter(mctrl),
RPC: mock_v2.NewMockRPC(mctrl),
}
expected := test.Expect(t, m)
r := trust.NewResolver(m.DB, m.Inserter, m.RPC)
dec, err := r.Chain(context.Background(), test.ChainReq, nil)
if test.ExpectedErr != nil {
require.Error(t, err)
assert.Truef(t, xerrors.Is(err, test.ExpectedErr),
"actual: %s\nexpected: %s", err, test.ExpectedErr)
} else {
require.NoError(t, err)
assert.Equal(t, expected, dec)
}
})
}
}

0 comments on commit f8d2095

Please sign in to comment.