Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TrustStore: Resolver handles certificate chains #3495

Merged
merged 2 commits into from
Dec 11, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions go/lib/infra/modules/trust/v2/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ var (
// Chains
var (
chain110v1 = ChainDesc{IA: ia110, Version: 1}
chain120v1 = ChainDesc{IA: ia120, Version: 1}
)

func TestMain(m *testing.M) {
Expand Down
64 changes: 62 additions & 2 deletions go/lib/infra/modules/trust/v2/resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"github.com/scionproto/scion/go/lib/infra/modules/trust/v2/internal/decoded"
"github.com/scionproto/scion/go/lib/log"
"github.com/scionproto/scion/go/lib/scrypto"
"github.com/scionproto/scion/go/lib/scrypto/cert/v2"
"github.com/scionproto/scion/go/lib/scrypto/trc/v2"
"github.com/scionproto/scion/go/lib/serrors"
)
Expand Down Expand Up @@ -157,8 +158,37 @@ func (r *resolver) trcCheck(req TRCReq, t *trc.TRC) error {
func (r *resolver) Chain(ctx context.Context, req ChainReq,
server net.Addr) (decoded.Chain, error) {

// TODO(roosd): implement
return decoded.Chain{}, serrors.New("not implemented")
msg, err := r.rpc.GetCertChain(ctx, req, server)
if err != nil {
return decoded.Chain{}, serrors.WrapStr("error requesting certificate chain", err)
}
dec, err := decoded.DecodeChain(msg)
if err != nil {
return decoded.Chain{}, serrors.WrapStr("error parsing certificate chain", err)
}
if err := r.chainCheck(req, dec.AS); err != nil {
return decoded.Chain{}, serrors.Wrap(ErrInvalidResponse, err)
}
w := resolveWrap{
resolver: r,
server: server,
cacheOnly: req.CacheOnly,
}
if err := r.inserter.InsertChain(ctx, dec, w.TRC); err != nil {
return decoded.Chain{}, serrors.WrapStr("unable to insert certificate chain", err,
"chain", dec)
}
return dec, nil
}

func (r *resolver) chainCheck(req ChainReq, as *cert.AS) error {
switch {
case !req.IA.Equal(as.Subject):
return serrors.New("wrong subject", "expected", req.IA, "actual", as.Subject)
case !req.Version.IsLatest() && req.Version != as.Version:
return serrors.New("wrong version", "expected", req.Version, "actual", as.Version)
}
return nil
}

type resOrErr struct {
Expand All @@ -184,3 +214,33 @@ func (w *prevWrap) TRC(_ context.Context, isd addr.ISD, version scrypto.Version)
}
return w.prev, nil
}

// resolverWrap provides TRCs that are backed by the resolver. If a TRC is
// missing in the DB, network requests are allowed.
type resolveWrap struct {
resolver *resolver
server net.Addr
cacheOnly bool
}

func (w resolveWrap) TRC(ctx context.Context, isd addr.ISD,
version scrypto.Version) (*trc.TRC, error) {

t, err := w.resolver.db.GetTRC(ctx, isd, version)
switch {
case err == nil:
return t, nil
case !xerrors.Is(err, ErrNotFound):
return nil, serrors.WrapStr("error querying DB for TRC", err)
}
req := TRCReq{
ISD: isd,
Version: version,
CacheOnly: w.cacheOnly,
}
decoded, err := w.resolver.TRC(ctx, req, w.server)
if err != nil {
return nil, serrors.WrapStr("unable to fetch TRC from network", err)
}
return decoded.TRC, nil
}
145 changes: 145 additions & 0 deletions go/lib/infra/modules/trust/v2/resolver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,3 +114,148 @@ func TestResolverTRC(t *testing.T) {
})
}
}

func TestResolverChain(t *testing.T) {
internal := serrors.New("internal")
type mocks struct {
DB *mock_v2.MockDB
Inserter *mock_v2.MockInserter
RPC *mock_v2.MockRPC
}
tests := map[string]struct {
Expect func(t *testing.T, m mocks) decoded.Chain
ChainReq trust.ChainReq
ExpectedErr error
}{
"valid, TRC in DB": {
Expect: func(t *testing.T, m mocks) decoded.Chain {
req := trust.ChainReq{IA: ia110, Version: scrypto.LatestVer}
dec := loadChain(t, chain110v1)
m.RPC.EXPECT().GetCertChain(gomock.Any(), req, nil).Return(
dec.Raw, nil,
)
decTRC := loadTRC(t, trc1v1)
m.Inserter.EXPECT().InsertChain(gomock.Any(), dec, gomock.Any()).DoAndReturn(
func(ctx context.Context, dec decoded.Chain, p trust.TRCProviderFunc) error {
trc, err := p(ctx, dec.Issuer.Subject.I, dec.Issuer.Issuer.TRCVersion)
require.NoError(t, err)
assert.Equal(t, decTRC.TRC, trc)
return err
},
)
m.DB.EXPECT().GetTRC(gomock.Any(), addr.ISD(1), scrypto.Version(1)).Return(
decTRC.TRC, nil,
)
return dec
},
ChainReq: trust.ChainReq{IA: ia110, Version: scrypto.LatestVer},
},
"RPC fail": {
Expect: func(t *testing.T, m mocks) decoded.Chain {
req := trust.ChainReq{IA: ia110, Version: scrypto.LatestVer}
m.RPC.EXPECT().GetCertChain(gomock.Any(), req, nil).Return(
nil, internal,
)
return decoded.Chain{}
},
ChainReq: trust.ChainReq{IA: ia110, Version: scrypto.LatestVer},
ExpectedErr: internal,
},
"garbage chain": {
Expect: func(t *testing.T, m mocks) decoded.Chain {
req := trust.ChainReq{IA: ia110, Version: scrypto.LatestVer}
m.RPC.EXPECT().GetCertChain(gomock.Any(), req, nil).Return(
[]byte("some_garbage"), nil,
)
return decoded.Chain{}
},
ChainReq: trust.ChainReq{IA: ia110, Version: scrypto.LatestVer},
ExpectedErr: decoded.ErrParse,
},
"wrong subject": {
Expect: func(t *testing.T, m mocks) decoded.Chain {
req := trust.ChainReq{IA: ia110, Version: scrypto.LatestVer}
m.RPC.EXPECT().GetCertChain(gomock.Any(), req, nil).Return(
loadChain(t, chain120v1).Raw, nil,
)
return decoded.Chain{}
},
ChainReq: trust.ChainReq{IA: ia110, Version: scrypto.LatestVer},
ExpectedErr: trust.ErrInvalidResponse,
},
"wrong version": {
Expect: func(t *testing.T, m mocks) decoded.Chain {
req := trust.ChainReq{IA: ia110, Version: 2}
m.RPC.EXPECT().GetCertChain(gomock.Any(), req, nil).Return(
loadChain(t, chain110v1).Raw, nil,
)
return decoded.Chain{}
},
ChainReq: trust.ChainReq{IA: ia110, Version: 2},
ExpectedErr: trust.ErrInvalidResponse,
},
"DB error": {
Expect: func(t *testing.T, m mocks) decoded.Chain {
req := trust.ChainReq{IA: ia110, Version: scrypto.LatestVer}
dec := loadChain(t, chain110v1)
m.RPC.EXPECT().GetCertChain(gomock.Any(), req, nil).Return(
dec.Raw, nil,
)
m.Inserter.EXPECT().InsertChain(gomock.Any(), dec, gomock.Any()).DoAndReturn(
func(ctx context.Context, dec decoded.Chain, p trust.TRCProviderFunc) error {
_, err := p(ctx, dec.Issuer.Subject.I, dec.Issuer.Issuer.TRCVersion)
require.Error(t, err)
return err
},
)
m.DB.EXPECT().GetTRC(gomock.Any(), addr.ISD(1), scrypto.Version(1)).Return(
nil, internal,
)
return decoded.Chain{}
},
ChainReq: trust.ChainReq{IA: ia110, Version: scrypto.LatestVer},
ExpectedErr: internal,
},
"TRC RPC error": {
Expect: func(t *testing.T, m mocks) decoded.Chain {
req := trust.ChainReq{IA: ia110, Version: scrypto.LatestVer}
dec := loadChain(t, chain110v1)
m.RPC.EXPECT().GetCertChain(gomock.Any(), req, nil).Return(
dec.Raw, nil,
)
m.Inserter.EXPECT().InsertChain(gomock.Any(), dec, gomock.Any()).DoAndReturn(
func(ctx context.Context, dec decoded.Chain, p trust.TRCProviderFunc) error {
_, err := p(ctx, dec.Issuer.Subject.I, dec.Issuer.Issuer.TRCVersion)
xtest.AssertErrorsIs(t, err, internal)
return err
},
)
m.DB.EXPECT().GetTRC(gomock.Any(), addr.ISD(1), scrypto.Version(1)).Return(
nil, trust.ErrNotFound,
)
m.DB.EXPECT().GetTRC(gomock.Any(), addr.ISD(1), scrypto.LatestVer).Return(
nil, internal,
)
return decoded.Chain{}
},
ChainReq: trust.ChainReq{IA: ia110, Version: scrypto.LatestVer},
ExpectedErr: internal,
},
}
for name, test := range tests {
t.Run(name, func(t *testing.T) {
mctrl := gomock.NewController(t)
defer mctrl.Finish()
m := mocks{
DB: mock_v2.NewMockDB(mctrl),
Inserter: mock_v2.NewMockInserter(mctrl),
RPC: mock_v2.NewMockRPC(mctrl),
}
expected := test.Expect(t, m)
r := trust.NewResolver(m.DB, m.Inserter, m.RPC)
dec, err := r.Chain(context.Background(), test.ChainReq, nil)
xtest.AssertErrorsIs(t, err, test.ExpectedErr)
assert.Equal(t, expected, dec)
})
}
}