From 06e9df0d09f232afc7bbc48cca420881e062213d Mon Sep 17 00:00:00 2001 From: Oncilla Date: Mon, 7 Oct 2019 11:25:59 +0200 Subject: [PATCH] TrustStore: Implement inserter (#3225) Adds: - Implement TRC verification and insertion logic. - The forwarding inserter registers the new trust material with the local certificate server before inserting into the database. It is supposed to be used by the beacon and path server. --- go/lib/infra/modules/trust/v2/BUILD.bazel | 5 + go/lib/infra/modules/trust/v2/export_test.go | 55 +++++- go/lib/infra/modules/trust/v2/inserter.go | 158 +++++++++++++++++- .../infra/modules/trust/v2/inserter_test.go | 100 +++++++++++ go/lib/infra/modules/trust/v2/router.go | 68 ++++++++ go/lib/infra/modules/trust/v2/router_test.go | 151 +++++++++++++++++ 6 files changed, 530 insertions(+), 7 deletions(-) create mode 100644 go/lib/infra/modules/trust/v2/inserter_test.go create mode 100644 go/lib/infra/modules/trust/v2/router_test.go diff --git a/go/lib/infra/modules/trust/v2/BUILD.bazel b/go/lib/infra/modules/trust/v2/BUILD.bazel index eb732c7f9a..12ed4bcc23 100644 --- a/go/lib/infra/modules/trust/v2/BUILD.bazel +++ b/go/lib/infra/modules/trust/v2/BUILD.bazel @@ -32,11 +32,13 @@ go_test( name = "go_default_test", srcs = [ "export_test.go", + "inserter_test.go", "inspector_test.go", "main_test.go", "provider_test.go", "recurser_test.go", "resolver_test.go", + "router_test.go", ], data = [ "//go/lib/infra/modules/trust/v2/testdata:crypto_tar", @@ -44,6 +46,7 @@ go_test( embed = [":go_default_library"], deps = [ "//go/lib/addr:go_default_library", + "//go/lib/common:go_default_library", "//go/lib/infra:go_default_library", "//go/lib/infra/modules/trust/v2/internal/decoded:go_default_library", "//go/lib/infra/modules/trust/v2/mock_v2:go_default_library", @@ -52,6 +55,8 @@ go_test( "//go/lib/scrypto/trc/v2:go_default_library", "//go/lib/serrors:go_default_library", "//go/lib/snet:go_default_library", + "//go/lib/snet/mock_snet:go_default_library", + "//go/lib/spath:go_default_library", "//go/lib/util:go_default_library", "//go/lib/xtest:go_default_library", "@com_github_golang_mock//gomock:go_default_library", diff --git a/go/lib/infra/modules/trust/v2/export_test.go b/go/lib/infra/modules/trust/v2/export_test.go index bc0631d306..5f730f5e40 100644 --- a/go/lib/infra/modules/trust/v2/export_test.go +++ b/go/lib/infra/modules/trust/v2/export_test.go @@ -14,13 +14,30 @@ package trust +import ( + "github.com/scionproto/scion/go/lib/addr" + "github.com/scionproto/scion/go/lib/snet" +) + var ( // NewCryptoProvider allows instantiating the private cryptoProvider for // black-box testing. NewCryptoProvider = newTestCryptoProvider - // newTestInspector allows instantiating the private inspector for - // black-box testing. + // NewCSRouter allows instantiating the private CS router for black-box + // testing. + NewCSRouter = newTestCSRouter + // NewFwdInserter allows instantiating the private forwarding + // inserter for black-box testing. + NewFwdInserter = newTestFwdInserter + // NewInserter allows instantiating the private inserter for black-box + // testing. + NewInserter = newTestInserter + // NewTestInspector allows instantiating the private inspector for black-box + // testing. NewTestInspector = newTestInspector + // NewLocalRouter allows instantiating the private resolver for black-box + // testing. + NewLocalRouter = newTestLocalRouter // NewResolver allows instantiating the private resolver for black-box // testing. NewResolver = newTestResolver @@ -39,6 +56,35 @@ func newTestCryptoProvider(db DBRead, recurser Recurser, resolver Resolver, rout } } +// newTestCSRouter returns a new router for testing. +func newTestCSRouter(isd addr.ISD, router snet.Router, db TRCRead) Router { + return &csRouter{ + isd: isd, + router: router, + db: db, + } +} + +// newTestFwdInserter returns a new forwarding inserter for testing. +func newTestFwdInserter(db ReadWrite, rpc RPC) Inserter { + return &fwdInserter{ + baseInserter: baseInserter{ + db: db, + }, + rpc: rpc, + } +} + +// newTestInserter returns a new inserter for testing. +func newTestInserter(db ReadWrite, unsafe bool) Inserter { + return &inserter{ + baseInserter: baseInserter{ + db: db, + unsafe: unsafe, + }, + } +} + // newTestInspector returns a new inspector for testing. func newTestInspector(provider CryptoProvider) Inspector { return &inspector{ @@ -46,6 +92,11 @@ func newTestInspector(provider CryptoProvider) Inspector { } } +// newTestLocalRouter returns a new router for testing. +func newTestLocalRouter(ia addr.IA) Router { + return &localRouter{ia: ia} +} + // newTestResolver returns a new resolver for testing. func newTestResolver(db DBRead, inserter Inserter, rpc RPC) Resolver { return &resolver{ diff --git a/go/lib/infra/modules/trust/v2/inserter.go b/go/lib/infra/modules/trust/v2/inserter.go index 3c2eb7a22d..6cbb3dc679 100644 --- a/go/lib/infra/modules/trust/v2/inserter.go +++ b/go/lib/infra/modules/trust/v2/inserter.go @@ -16,32 +16,180 @@ package trust import ( "context" - "errors" "github.com/scionproto/scion/go/lib/addr" "github.com/scionproto/scion/go/lib/infra/modules/trust/v2/internal/decoded" "github.com/scionproto/scion/go/lib/scrypto" "github.com/scionproto/scion/go/lib/scrypto/trc/v2" + "github.com/scionproto/scion/go/lib/serrors" ) var ( + // ErrBaseNotSupported indicates base TRC insertion is not supported. + ErrBaseNotSupported = serrors.New("inserting base TRC not supported") // ErrValidation indicates a validation error. - ErrValidation = errors.New("validation error") + ErrValidation = serrors.New("validation error") // ErrVerification indicates a verification error. - ErrVerification = errors.New("verification error") + ErrVerification = serrors.New("verification error") ) // Inserter inserts and verifies trust material into the database. type Inserter interface { // InsertTRC verifies the signed TRC and inserts it into the database. // The previous TRC is queried through the provider function, when necessary. - InsertTRC(ctx context.Context, decoded decoded.TRC, trcProvider TRCProviderFunc) error + InsertTRC(ctx context.Context, decTRC decoded.TRC, trcProvider TRCProviderFunc) error // InsertChain verifies the signed certificate chain and inserts it into the // database. The issuing TRC is queried through the provider function, when // necessary. - InsertChain(ctx context.Context, decoded decoded.Chain, trcProvider TRCProviderFunc) error + InsertChain(ctx context.Context, decChain decoded.Chain, trcProvider TRCProviderFunc) error } // TRCProviderFunc provides TRCs. It is used to configure the TRC retrieval // method of the inserter. type TRCProviderFunc func(context.Context, addr.ISD, scrypto.Version) (*trc.TRC, error) + +// inserter is used to verify and insert trust material into the database. +type inserter struct { + baseInserter +} + +// InsertTRC verifies the signed TRC and inserts it into the database. +// The previous TRC is queried through the provider function, when necessary. +func (ins *inserter) InsertTRC(ctx context.Context, decTRC decoded.TRC, + trcProvider TRCProviderFunc) error { + + if insert, err := ins.shouldInsertTRC(ctx, decTRC, trcProvider); err != nil || !insert { + return err + } + if _, err := ins.db.InsertTRC(ctx, decTRC); err != nil { + return serrors.WrapStr("unable to insert TRC", err) + } + return nil +} + +// InsertChain verifies the signed certificate chain and inserts it into the +// database. The issuing TRC is queried through the provider function, when +// necessary. +func (ins *inserter) InsertChain(ctx context.Context, chain decoded.Chain, + trcProvider TRCProviderFunc) error { + + if insert, err := ins.shouldInsertChain(ctx, chain, trcProvider); err != nil || !insert { + return err + } + if _, _, err := ins.db.InsertChain(ctx, chain); err != nil { + return serrors.WrapStr("unable to insert chain", err) + } + return nil +} + +// fwdInserter is an inserter that always forwards the trust material to the +// certificate server before inserting it into the database. Forwarding must be +// successful, otherwise the material is not inserted into the database. +type fwdInserter struct { + baseInserter + router localRouter + rpc RPC +} + +// InsertTRC verifies the signed TRC and inserts it into the database. The +// previous TRC is queried through the provider function, when necessary. Before +// insertion, the TRC is forwarded to the certificate server. If the certificate +// server does not successfully handle the TRC, the insertion fails. +func (ins *fwdInserter) InsertTRC(ctx context.Context, decTRC decoded.TRC, + trcProvider TRCProviderFunc) error { + + if insert, err := ins.shouldInsertTRC(ctx, decTRC, trcProvider); err != nil || !insert { + return err + } + cs := ins.router.chooseServer() + if err := ins.rpc.SendTRC(ctx, decTRC.Raw, cs); err != nil { + return serrors.WrapStr("unable to push TRC to certificate server", err, "addr", cs) + } + if _, err := ins.db.InsertTRC(ctx, decTRC); err != nil { + return serrors.WrapStr("unable to insert TRC", err) + } + return nil +} + +// InsertChain verifies the signed certificate chain and inserts it into the +// database. The issuing TRC is queried through the provider function, when +// necessary. Before insertion, the certificate chain is forwarded to the +// certificate server. If the certificate server does not successfully handle +// the certificate chain, the insertion fails. +func (ins *fwdInserter) InsertChain(ctx context.Context, chain decoded.Chain, + trcProvider TRCProviderFunc) error { + + if insert, err := ins.shouldInsertChain(ctx, chain, trcProvider); err != nil || !insert { + return err + } + cs := ins.router.chooseServer() + if err := ins.rpc.SendCertChain(ctx, chain.Raw, cs); err != nil { + return serrors.WrapStr("unable to push chain to certificate server", err, + "addr", cs) + } + if _, _, err := ins.db.InsertChain(ctx, chain); err != nil { + return serrors.WrapStr("unable to insert chain", err) + } + return nil +} + +type baseInserter struct { + db ReadWrite + // unsafe allows inserts of base TRCs. This is used as a workaround until + // TAAC support is implemented. + unsafe bool +} + +func (ins *baseInserter) shouldInsertTRC(ctx context.Context, decTRC decoded.TRC, + trcProvider TRCProviderFunc) (bool, error) { + + found, err := ins.db.TRCExists(ctx, decTRC) + if err != nil || found { + return !found, err + } + if decTRC.TRC.Base() { + // XXX(roosd): remove when TAACs are supported. + if ins.unsafe { + if _, err := ins.db.InsertTRC(ctx, decTRC); err != nil { + return false, serrors.WrapStr("unable to insert base TRC", err) + } + return false, nil + } + return false, serrors.WithCtx(ErrBaseNotSupported, "trc", decTRC) + } + prev, err := trcProvider(ctx, decTRC.TRC.ISD, decTRC.TRC.Version-1) + if err != nil { + return false, serrors.WrapStr("unable to get previous TRC", err, + "isd", decTRC.TRC.ISD, "version", decTRC.TRC.Version-1) + } + if err := ins.checkUpdate(ctx, prev, decTRC); err != nil { + return false, serrors.WrapStr("error checking TRC update", err) + } + return true, nil +} + +func (ins *baseInserter) checkUpdate(ctx context.Context, prev *trc.TRC, next decoded.TRC) error { + validator := trc.UpdateValidator{ + Next: next.TRC, + Prev: prev, + } + if _, err := validator.Validate(); err != nil { + return serrors.Wrap(ErrValidation, err) + } + verifier := trc.UpdateVerifier{ + Next: next.TRC, + NextEncoded: next.Signed.EncodedTRC, + Prev: prev, + Signatures: next.Signed.Signatures, + } + if err := verifier.Verify(); err != nil { + return serrors.Wrap(ErrVerification, err) + } + return nil +} + +func (ins *baseInserter) shouldInsertChain(ctx context.Context, chain decoded.Chain, + trcProvider TRCProviderFunc) (bool, error) { + + return false, serrors.New("not implemented") +} diff --git a/go/lib/infra/modules/trust/v2/inserter_test.go b/go/lib/infra/modules/trust/v2/inserter_test.go new file mode 100644 index 0000000000..6f9665c6d5 --- /dev/null +++ b/go/lib/infra/modules/trust/v2/inserter_test.go @@ -0,0 +1,100 @@ +// Copyright 2019 Anapaya Systems +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package trust_test + +import ( + "context" + "testing" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/require" + "golang.org/x/xerrors" + + "github.com/scionproto/scion/go/lib/infra/modules/trust/v2" + "github.com/scionproto/scion/go/lib/infra/modules/trust/v2/internal/decoded" + "github.com/scionproto/scion/go/lib/infra/modules/trust/v2/mock_v2" +) + +func TestInserterInsertTRC(t *testing.T) { + tests := map[string]struct { + Expect func(*mock_v2.MockDB, decoded.TRC) + Unsafe bool + ExpectedErr error + }{ + "Exists with same contents": { + Expect: func(db *mock_v2.MockDB, decTRC decoded.TRC) { + db.EXPECT().TRCExists(gomock.Any(), decTRC).Return( + true, nil, + ) + }, + }, + "Exists with different contents": { + Expect: func(db *mock_v2.MockDB, decTRC decoded.TRC) { + db.EXPECT().TRCExists(gomock.Any(), decTRC).Return( + true, trust.ErrContentMismatch, + ) + }, + ExpectedErr: trust.ErrContentMismatch, + }, + "Base TRC and unsafe set": { + Expect: func(db *mock_v2.MockDB, decTRC decoded.TRC) { + db.EXPECT().TRCExists(gomock.Any(), decTRC).Return( + false, nil, + ) + db.EXPECT().InsertTRC(gomock.Any(), decTRC).Return(true, nil) + }, + Unsafe: true, + }, + "Base TRC and unsafe set, insert fail": { + Expect: func(db *mock_v2.MockDB, decTRC decoded.TRC) { + db.EXPECT().TRCExists(gomock.Any(), decTRC).Return( + false, nil, + ) + db.EXPECT().InsertTRC(gomock.Any(), decTRC).Return( + false, trust.ErrContentMismatch, + ) + }, + ExpectedErr: trust.ErrContentMismatch, + Unsafe: true, + }, + "Base TRC and unsafe not set": { + Expect: func(db *mock_v2.MockDB, decTRC decoded.TRC) { + db.EXPECT().TRCExists(gomock.Any(), decTRC).Return( + false, nil, + ) + }, + ExpectedErr: trust.ErrBaseNotSupported, + }, + } + for name, test := range tests { + t.Run(name, func(t *testing.T) { + mctrl := gomock.NewController(t) + defer mctrl.Finish() + + db := mock_v2.NewMockDB(mctrl) + decoded := loadTRC(t, trc1v1) + test.Expect(db, decoded) + ins := trust.NewInserter(db, test.Unsafe) + + err := ins.InsertTRC(context.Background(), decoded, nil) + if test.ExpectedErr != nil { + require.Truef(t, xerrors.Is(err, test.ExpectedErr), + "Expected: %s Actual: %s", test.ExpectedErr, err) + } else { + require.NoError(t, err) + } + }) + } +} diff --git a/go/lib/infra/modules/trust/v2/router.go b/go/lib/infra/modules/trust/v2/router.go index 36f9da669a..c0c874b8fb 100644 --- a/go/lib/infra/modules/trust/v2/router.go +++ b/go/lib/infra/modules/trust/v2/router.go @@ -17,8 +17,14 @@ package trust import ( "context" "net" + "time" + + "golang.org/x/xerrors" "github.com/scionproto/scion/go/lib/addr" + "github.com/scionproto/scion/go/lib/scrypto" + "github.com/scionproto/scion/go/lib/serrors" + "github.com/scionproto/scion/go/lib/snet" ) // Router builds the CS address for crypto material with the subject in a given ISD. @@ -27,3 +33,65 @@ type Router interface { // subject in the provided ISD. ChooseServer(ctx context.Context, subjectISD addr.ISD) (net.Addr, error) } + +type localRouter struct { + ia addr.IA +} + +// ChooseServer always routes to the local CS. +func (r *localRouter) ChooseServer(_ context.Context, _ addr.ISD) (net.Addr, error) { + return r.chooseServer(), nil +} + +func (r *localRouter) chooseServer() net.Addr { + return &snet.Addr{IA: r.ia, Host: addr.NewSVCUDPAppAddr(addr.SvcCS)} +} + +type csRouter struct { + isd addr.ISD + router snet.Router + db TRCRead +} + +// ChooseServer builds a CS address for crypto with the subject in a given ISD. +// * a local authoritative CS if subject is ISD-local. +// * a local authoritative CS if subject is in remote ISD, but no active TRC is available. +// * a remote authoritative CS otherwise. +func (r *csRouter) ChooseServer(ctx context.Context, subjectISD addr.ISD) (net.Addr, error) { + dstISD, err := r.dstISD(ctx, subjectISD) + if err != nil { + return nil, serrors.WrapStr("unable to determine dest ISD to query", err) + } + path, err := r.router.Route(ctx, addr.IA{I: dstISD}) + if err != nil { + return nil, serrors.WrapStr("unable to find path to any core AS", err, "isd", dstISD) + } + a := &snet.Addr{ + IA: path.Destination(), + Host: addr.NewSVCUDPAppAddr(addr.SvcCS), + Path: path.Path(), + NextHop: path.OverlayNextHop(), + } + return a, nil +} + +// dstISD selects the CS to ask for crypto material, using the following strategy: +// * a local authoritative CS if subject is ISD-local. +// * a local authoritative CS if subject is in remote ISD, but no active TRC is available. +// * a remote authoritative CS otherwise. +func (r *csRouter) dstISD(ctx context.Context, destination addr.ISD) (addr.ISD, error) { + if destination == r.isd { + return r.isd, nil + } + info, err := r.db.GetTRCInfo(ctx, destination, scrypto.Version(scrypto.LatestVer)) + if err != nil { + if xerrors.Is(err, ErrNotFound) { + return r.isd, nil + } + return 0, serrors.WrapStr("error querying DB for TRC", err) + } + if !info.Validity.Contains(time.Now()) { + return r.isd, nil + } + return destination, nil +} diff --git a/go/lib/infra/modules/trust/v2/router_test.go b/go/lib/infra/modules/trust/v2/router_test.go new file mode 100644 index 0000000000..74175d81f3 --- /dev/null +++ b/go/lib/infra/modules/trust/v2/router_test.go @@ -0,0 +1,151 @@ +// Copyright 2019 Anapaya Systems +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package trust_test + +import ( + "context" + "testing" + "time" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/xerrors" + + "github.com/scionproto/scion/go/lib/addr" + "github.com/scionproto/scion/go/lib/common" + "github.com/scionproto/scion/go/lib/infra/modules/trust/v2" + "github.com/scionproto/scion/go/lib/infra/modules/trust/v2/mock_v2" + "github.com/scionproto/scion/go/lib/scrypto" + "github.com/scionproto/scion/go/lib/snet" + "github.com/scionproto/scion/go/lib/snet/mock_snet" + "github.com/scionproto/scion/go/lib/spath" + "github.com/scionproto/scion/go/lib/util" +) + +func TestLocalRouterChooseServer(t *testing.T) { + tests := map[string]addr.ISD{ + "ISD local": 1, + "Remote ISD": 2, + } + for name, isd := range tests { + t.Run(name, func(t *testing.T) { + localCS := &snet.Addr{IA: ia122, Host: addr.NewSVCUDPAppAddr(addr.SvcCS)} + router := trust.NewLocalRouter(localCS.IA) + routed, err := router.ChooseServer(context.Background(), isd) + require.NoError(t, err) + assert.Equal(t, localCS, routed) + }) + } +} + +func TestCSRouterChooseServer(t *testing.T) { + tests := map[string]struct { + ISD addr.ISD + Expect func(*mock_v2.MockDB, *mock_snet.MockRouter, *mock_snet.MockPath) + ExpectedErr error + }{ + "ISD local": { + ISD: 1, + Expect: func(_ *mock_v2.MockDB, r *mock_snet.MockRouter, p *mock_snet.MockPath) { + p.EXPECT().Path().AnyTimes().Return(&spath.Path{Raw: []byte("isd local path")}) + p.EXPECT().Destination().AnyTimes().Return(ia110) + p.EXPECT().OverlayNextHop().AnyTimes().Return(nil) + r.EXPECT().Route(gomock.Any(), addr.IA{I: 1}).Return(p, nil) + }, + }, + "ISD local, Route error": { + ISD: 1, + Expect: func(_ *mock_v2.MockDB, r *mock_snet.MockRouter, p *mock_snet.MockPath) { + r.EXPECT().Route(gomock.Any(), addr.IA{I: 1}).Return( + nil, common.NewBasicError("unable to route", nil), + ) + }, + ExpectedErr: common.NewBasicError("unable to route", nil), + }, + "Remote ISD, Valid TRC": { + ISD: 2, + Expect: func(db *mock_v2.MockDB, r *mock_snet.MockRouter, p *mock_snet.MockPath) { + future := util.UnixTime{Time: time.Now().Add(time.Hour)} + db.EXPECT().GetTRCInfo(gomock.Any(), addr.ISD(2), scrypto.Version(0)).Return( + trust.TRCInfo{Validity: scrypto.Validity{NotAfter: future}}, nil, + ) + p.EXPECT().Path().AnyTimes().Return(&spath.Path{Raw: []byte("remote ISD path")}) + p.EXPECT().Destination().AnyTimes().Return(ia210) + p.EXPECT().OverlayNextHop().AnyTimes().Return(nil) + r.EXPECT().Route(gomock.Any(), addr.IA{I: 2}).Return(p, nil) + }, + }, + "Remote ISD, TRC not found": { + ISD: 2, + Expect: func(db *mock_v2.MockDB, r *mock_snet.MockRouter, p *mock_snet.MockPath) { + db.EXPECT().GetTRCInfo(gomock.Any(), addr.ISD(2), scrypto.Version(0)).Return( + trust.TRCInfo{}, trust.ErrNotFound, + ) + p.EXPECT().Path().AnyTimes().Return(&spath.Path{Raw: []byte("isd local path")}) + p.EXPECT().Destination().AnyTimes().Return(ia110) + p.EXPECT().OverlayNextHop().AnyTimes().Return(nil) + r.EXPECT().Route(gomock.Any(), addr.IA{I: 1}).Return(p, nil) + }, + }, + "Remote ISD, Expired TRC": { + ISD: 2, + Expect: func(db *mock_v2.MockDB, r *mock_snet.MockRouter, p *mock_snet.MockPath) { + passed := util.UnixTime{Time: time.Now().Add(-time.Second)} + db.EXPECT().GetTRCInfo(gomock.Any(), addr.ISD(2), scrypto.Version(0)).Return( + trust.TRCInfo{Validity: scrypto.Validity{NotAfter: passed}}, nil, + ) + p.EXPECT().Path().AnyTimes().Return(&spath.Path{Raw: []byte("isd local path")}) + p.EXPECT().Destination().AnyTimes().Return(ia110) + p.EXPECT().OverlayNextHop().AnyTimes().Return(nil) + r.EXPECT().Route(gomock.Any(), addr.IA{I: 1}).Return(p, nil) + }, + }, + "Remote ISD, DB error": { + ISD: 2, + Expect: func(db *mock_v2.MockDB, r *mock_snet.MockRouter, p *mock_snet.MockPath) { + db.EXPECT().GetTRCInfo(gomock.Any(), addr.ISD(2), scrypto.Version(0)).Return( + trust.TRCInfo{}, common.NewBasicError("DB error", nil), + ) + }, + ExpectedErr: common.NewBasicError("DB error", nil), + }, + } + for name, test := range tests { + t.Run(name, func(t *testing.T) { + mctrl := gomock.NewController(t) + defer mctrl.Finish() + db := mock_v2.NewMockDB(mctrl) + r, p := mock_snet.NewMockRouter(mctrl), mock_snet.NewMockPath(mctrl) + test.Expect(db, r, p) + router := trust.NewCSRouter(1, r, db) + res, err := router.ChooseServer(context.Background(), test.ISD) + if test.ExpectedErr != nil { + require.Error(t, err) + assert.True(t, xerrors.Is(err, test.ExpectedErr), "Expected: %s Actual: %s", + test.ExpectedErr, err) + } else { + require.NoError(t, err) + expected := &snet.Addr{ + IA: p.Destination(), + Host: addr.NewSVCUDPAppAddr(addr.SvcCS), + Path: p.Path(), + NextHop: p.OverlayNextHop(), + } + assert.Equal(t, expected, res) + } + }) + } +}