From b72ecf4aa66d2d1ac8ed80d8f2cde10aa594a72e Mon Sep 17 00:00:00 2001 From: Lukas Vogel Date: Tue, 17 Dec 2019 08:17:20 +0100 Subject: [PATCH] TrustStore: unify interface (#3532) Introduce TRCID and use ChainID and TRCID consistently in trust package. Use value types for simple structs instead of pointers. Fixes #3525 --- go/lib/infra/common.go | 3 + go/lib/infra/modules/trust/v2/db.go | 8 +- go/lib/infra/modules/trust/v2/handlers.go | 26 ++- .../infra/modules/trust/v2/handlers_test.go | 30 ++- go/lib/infra/modules/trust/v2/inserter.go | 11 +- .../infra/modules/trust/v2/inserter_test.go | 7 +- go/lib/infra/modules/trust/v2/inspector.go | 8 +- .../infra/modules/trust/v2/inspector_test.go | 12 +- go/lib/infra/modules/trust/v2/mock_v2/v2.go | 60 ++--- go/lib/infra/modules/trust/v2/provider.go | 102 +++++---- .../infra/modules/trust/v2/provider_test.go | 209 ++++++++++++------ go/lib/infra/modules/trust/v2/resolver.go | 20 +- .../infra/modules/trust/v2/resolver_test.go | 36 ++- go/lib/infra/modules/trust/v2/router.go | 2 +- go/lib/infra/modules/trust/v2/router_test.go | 12 +- go/lib/infra/modules/trust/v2/rpc.go | 2 +- go/lib/infra/modules/trust/v2/signer.go | 6 +- go/lib/infra/modules/trust/v2/signer_test.go | 49 ++-- .../modules/trust/v2/trustdbsqlite/db.go | 32 ++- .../trust/v2/trustdbtest/trustdbtest.go | 31 +-- go/lib/infra/modules/trust/v2/verifier.go | 2 +- .../infra/modules/trust/v2/verifier_test.go | 2 +- 22 files changed, 397 insertions(+), 273 deletions(-) diff --git a/go/lib/infra/common.go b/go/lib/infra/common.go index e8a318e555..9273f69f60 100644 --- a/go/lib/infra/common.go +++ b/go/lib/infra/common.go @@ -508,6 +508,9 @@ type TrustStoreOpts struct { // requests, if they are not available locally. If it is not set, the // trust store does its own server resolution. Server net.Addr + // Client indicates the peer who sent this request to the trust store, if + // applicable. + Client net.Addr // LocalOnly indicates that the store should only check locally. LocalOnly bool } diff --git a/go/lib/infra/modules/trust/v2/db.go b/go/lib/infra/modules/trust/v2/db.go index 8a6e429f4d..ded9544ea4 100644 --- a/go/lib/infra/modules/trust/v2/db.go +++ b/go/lib/infra/modules/trust/v2/db.go @@ -80,13 +80,13 @@ type TRCRead interface { // database with differing contents. TRCExists(ctx context.Context, d decoded.TRC) (bool, error) // GetTRC returns the TRC. If it is not found, ErrNotFound is returned. - GetTRC(ctx context.Context, isd addr.ISD, version scrypto.Version) (*trc.TRC, error) + GetTRC(ctx context.Context, id TRCID) (*trc.TRC, error) // GetRawTRC returns the raw signed TRC bytes. If it is not found, // ErrNotFound is returned. - GetRawTRC(ctx context.Context, isd addr.ISD, version scrypto.Version) ([]byte, error) + GetRawTRC(ctx context.Context, id TRCID) ([]byte, error) // GetTRCInfo returns the infos for the requested TRC. If it is not found, // ErrNotFound is returned. - GetTRCInfo(ctx context.Context, isd addr.ISD, version scrypto.Version) (TRCInfo, error) + GetTRCInfo(ctx context.Context, id TRCID) (TRCInfo, error) // GetIssuingKeyInfo returns the infos of the requested AS. If it is not // found, ErrNotFound is returned. GetIssuingKeyInfo(ctx context.Context, ia addr.IA, version scrypto.Version) (KeyInfo, error) @@ -105,7 +105,7 @@ type TRCWrite interface { type ChainRead interface { // GetRawChain returns the raw signed certificate chain bytes. If it is not // found, ErrNotFound is returned. - GetRawChain(ctx context.Context, ia addr.IA, version scrypto.Version) ([]byte, error) + GetRawChain(ctx context.Context, id ChainID) ([]byte, error) // ChainExists returns whether the certificate chain is found in the // database and the content matches. ErrContentMismatch is returned if any // of the two certificates exist in the database with differing contents. diff --git a/go/lib/infra/modules/trust/v2/handlers.go b/go/lib/infra/modules/trust/v2/handlers.go index 330d2d1972..16021465f4 100644 --- a/go/lib/infra/modules/trust/v2/handlers.go +++ b/go/lib/infra/modules/trust/v2/handlers.go @@ -21,14 +21,12 @@ import ( "golang.org/x/xerrors" - "github.com/scionproto/scion/go/lib/addr" "github.com/scionproto/scion/go/lib/common" "github.com/scionproto/scion/go/lib/ctrl/cert_mgmt" "github.com/scionproto/scion/go/lib/infra" "github.com/scionproto/scion/go/lib/infra/messenger" "github.com/scionproto/scion/go/lib/infra/modules/trust/v2/internal/decoded" "github.com/scionproto/scion/go/lib/log" - "github.com/scionproto/scion/go/lib/scrypto" "github.com/scionproto/scion/go/lib/scrypto/trc/v2" "github.com/scionproto/scion/go/proto" ) @@ -66,8 +64,13 @@ func (h *chainReqHandler) Handle() *infra.HandlerResult { return infra.MetricsErrInternal } sendAck := messenger.SendAckHelper(ctx, rw) - raw, err := h.provider.GetRawChain(ctx, chainReq.IA(), chainReq.Version, - infra.ChainOpts{AllowInactive: true}, h.request.Peer) + raw, err := h.provider.GetRawChain(ctx, + ChainID{IA: chainReq.IA(), Version: chainReq.Version}, + infra.ChainOpts{ + TrustStoreOpts: infra.TrustStoreOpts{Client: h.request.Peer}, + AllowInactive: true, + }, + ) if err != nil { logger.Error("[TrustStore:chainReqHandler] Unable to retrieve chain", "err", err) sendAck(proto.Ack_ErrCode_reject, AckNotFound) @@ -114,8 +117,13 @@ func (h *trcReqHandler) Handle() *infra.HandlerResult { return infra.MetricsErrInternal } sendAck := messenger.SendAckHelper(ctx, rw) - raw, err := h.provider.GetRawTRC(ctx, trcReq.ISD, trcReq.Version, - infra.TRCOpts{AllowInactive: true}, h.request.Peer) + raw, err := h.provider.GetRawTRC(ctx, + TRCID{ISD: trcReq.ISD, Version: trcReq.Version}, + infra.TRCOpts{ + TrustStoreOpts: infra.TrustStoreOpts{Client: h.request.Peer}, + AllowInactive: true, + }, + ) if err != nil { logger.Error("[TrustStore:trcReqHandler] Unable to retrieve TRC", "err", err) sendAck(proto.Ack_ErrCode_reject, AckNotFound) @@ -254,15 +262,15 @@ func (h *trcPushHandler) Handle() *infra.HandlerResult { } func newTRCGetter(provider CryptoProvider, peer net.Addr) func(context.Context, - addr.ISD, scrypto.Version) (*trc.TRC, error) { + TRCID) (*trc.TRC, error) { - return func(ctx context.Context, isd addr.ISD, version scrypto.Version) (*trc.TRC, error) { + return func(ctx context.Context, id TRCID) (*trc.TRC, error) { opts := infra.TRCOpts{ TrustStoreOpts: infra.TrustStoreOpts{ Server: peer, }, AllowInactive: true, } - return provider.GetTRC(ctx, isd, version, opts) + return provider.GetTRC(ctx, id, opts) } } diff --git a/go/lib/infra/modules/trust/v2/handlers_test.go b/go/lib/infra/modules/trust/v2/handlers_test.go index ee5169eda0..99fee4e751 100644 --- a/go/lib/infra/modules/trust/v2/handlers_test.go +++ b/go/lib/infra/modules/trust/v2/handlers_test.go @@ -84,8 +84,9 @@ func TestChainReqHandler(t *testing.T) { AllowInactive: true, } p := mock_v2.NewMockCryptoProvider(ctrl) - p.EXPECT().GetRawChain(gomock.Any(), ia110, scrypto.LatestVer, - opts, nil).Return(nil, trust.ErrNotFound) + p.EXPECT().GetRawChain(gomock.Any(), + trust.ChainID{IA: ia110, Version: scrypto.LatestVer}, + opts).Return(nil, trust.ErrNotFound) return p }, ExpectedResult: infra.MetricsErrTrustStore(trust.ErrNotFound), @@ -109,8 +110,9 @@ func TestChainReqHandler(t *testing.T) { AllowInactive: true, } p := mock_v2.NewMockCryptoProvider(ctrl) - p.EXPECT().GetRawChain(gomock.Any(), ia110, scrypto.LatestVer, - opts, nil).Return([]byte("test"), nil) + p.EXPECT().GetRawChain(gomock.Any(), + trust.ChainID{IA: ia110, Version: scrypto.LatestVer}, + opts).Return([]byte("test"), nil) return p }, ExpectedResult: infra.MetricsErrMsger(infra.ErrTransport), @@ -134,8 +136,9 @@ func TestChainReqHandler(t *testing.T) { AllowInactive: true, } p := mock_v2.NewMockCryptoProvider(ctrl) - p.EXPECT().GetRawChain(gomock.Any(), ia110, scrypto.LatestVer, - opts, nil).Return([]byte("test"), nil) + p.EXPECT().GetRawChain(gomock.Any(), + trust.ChainID{IA: ia110, Version: scrypto.LatestVer}, + opts).Return([]byte("test"), nil) return p }, ExpectedResult: infra.MetricsResultOk, @@ -208,8 +211,9 @@ func TestTRCReqHandler(t *testing.T) { AllowInactive: true, } p := mock_v2.NewMockCryptoProvider(ctrl) - p.EXPECT().GetRawTRC(gomock.Any(), addr.ISD(1), scrypto.LatestVer, - opts, nil).Return(nil, trust.ErrNotFound) + p.EXPECT().GetRawTRC(gomock.Any(), + trust.TRCID{ISD: addr.ISD(1), Version: scrypto.LatestVer}, + opts).Return(nil, trust.ErrNotFound) return p }, ExpectedResult: infra.MetricsErrTrustStore(trust.ErrNotFound), @@ -233,8 +237,9 @@ func TestTRCReqHandler(t *testing.T) { AllowInactive: true, } p := mock_v2.NewMockCryptoProvider(ctrl) - p.EXPECT().GetRawTRC(gomock.Any(), addr.ISD(1), scrypto.LatestVer, - opts, nil).Return([]byte("test"), nil) + p.EXPECT().GetRawTRC(gomock.Any(), + trust.TRCID{ISD: addr.ISD(1), Version: scrypto.LatestVer}, + opts).Return([]byte("test"), nil) return p }, ExpectedResult: infra.MetricsErrMsger(infra.ErrTransport), @@ -258,8 +263,9 @@ func TestTRCReqHandler(t *testing.T) { AllowInactive: true, } p := mock_v2.NewMockCryptoProvider(ctrl) - p.EXPECT().GetRawTRC(gomock.Any(), addr.ISD(1), scrypto.LatestVer, - opts, nil).Return([]byte("test"), nil) + p.EXPECT().GetRawTRC(gomock.Any(), + trust.TRCID{ISD: addr.ISD(1), Version: scrypto.LatestVer}, + opts).Return([]byte("test"), nil) return p }, ExpectedResult: infra.MetricsResultOk, diff --git a/go/lib/infra/modules/trust/v2/inserter.go b/go/lib/infra/modules/trust/v2/inserter.go index ca460298e0..ad197aac1b 100644 --- a/go/lib/infra/modules/trust/v2/inserter.go +++ b/go/lib/infra/modules/trust/v2/inserter.go @@ -17,9 +17,7 @@ package trust import ( "context" - "github.com/scionproto/scion/go/lib/addr" "github.com/scionproto/scion/go/lib/infra/modules/trust/v2/internal/decoded" - "github.com/scionproto/scion/go/lib/scrypto" "github.com/scionproto/scion/go/lib/scrypto/cert/v2" "github.com/scionproto/scion/go/lib/scrypto/trc/v2" "github.com/scionproto/scion/go/lib/serrors" @@ -47,7 +45,7 @@ type Inserter interface { // TRCProviderFunc provides TRCs. It is used to configure the TRC retrieval // method of the inserter. -type TRCProviderFunc func(context.Context, addr.ISD, scrypto.Version) (*trc.TRC, error) +type TRCProviderFunc func(context.Context, TRCID) (*trc.TRC, error) // inserter is used to verify and insert trust material into the database. type inserter struct { @@ -168,7 +166,7 @@ func (ins *baseInserter) shouldInsertTRC(ctx context.Context, decTRC decoded.TRC } return false, serrors.WithCtx(ErrBaseNotSupported, "trc", decTRC) } - prev, err := trcProvider(ctx, decTRC.TRC.ISD, decTRC.TRC.Version-1) + prev, err := trcProvider(ctx, TRCID{ISD: decTRC.TRC.ISD, Version: decTRC.TRC.Version - 1}) if err != nil { return false, serrors.WrapStr("unable to get previous TRC", err, "isd", decTRC.TRC.ISD, "version", decTRC.TRC.Version-1) @@ -212,7 +210,10 @@ func (ins *baseInserter) shouldInsertChain(ctx context.Context, chain decoded.Ch if err := ins.validateChain(chain); err != nil { return false, serrors.WrapStr("error validating the certificate chain", err) } - t, err := trcProvider(ctx, chain.Issuer.Subject.I, chain.Issuer.Issuer.TRCVersion) + t, err := trcProvider(ctx, TRCID{ + ISD: chain.Issuer.Subject.I, + Version: chain.Issuer.Issuer.TRCVersion, + }) if err != nil { return false, serrors.WrapStr("unable to get issuing TRC", err, "isd", chain.Issuer.Subject.I, "version", chain.Issuer.Issuer.TRCVersion) diff --git a/go/lib/infra/modules/trust/v2/inserter_test.go b/go/lib/infra/modules/trust/v2/inserter_test.go index 8d27cd1f3b..03277be76c 100644 --- a/go/lib/infra/modules/trust/v2/inserter_test.go +++ b/go/lib/infra/modules/trust/v2/inserter_test.go @@ -24,7 +24,6 @@ import ( "github.com/scionproto/scion/go/lib/infra/modules/trust/v2" "github.com/scionproto/scion/go/lib/infra/modules/trust/v2/internal/decoded" "github.com/scionproto/scion/go/lib/infra/modules/trust/v2/mock_v2" - "github.com/scionproto/scion/go/lib/scrypto" "github.com/scionproto/scion/go/lib/scrypto/trc/v2" "github.com/scionproto/scion/go/lib/serrors" "github.com/scionproto/scion/go/lib/snet" @@ -137,7 +136,7 @@ func TestInserterInsertChain(t *testing.T) { false, nil, ) }, - TRCProvider: func(context.Context, addr.ISD, scrypto.Version) (*trc.TRC, error) { + TRCProvider: func(context.Context, trust.TRCID) (*trc.TRC, error) { return nil, notFound }, ExpectedErr: notFound, @@ -201,7 +200,7 @@ func TestInserterInsertChain(t *testing.T) { ins := trust.NewInserter(db, false) decTRC := loadTRC(t, trc1v1) - p := func(ctx context.Context, isd addr.ISD, ver scrypto.Version) (*trc.TRC, error) { + p := func(_ context.Context, _ trust.TRCID) (*trc.TRC, error) { return decTRC.TRC, nil } if test.TRCProvider != nil { @@ -292,7 +291,7 @@ func TestFwdInserterInsertChain(t *testing.T) { ins := trust.NewFwdInserter(m.DB, m.RPC) decTRC := loadTRC(t, trc1v1) - p := func(ctx context.Context, isd addr.ISD, ver scrypto.Version) (*trc.TRC, error) { + p := func(_ context.Context, _ trust.TRCID) (*trc.TRC, error) { return decTRC.TRC, nil } if test.TRCProvider != nil { diff --git a/go/lib/infra/modules/trust/v2/inspector.go b/go/lib/infra/modules/trust/v2/inspector.go index 785152da5c..abe9af7b23 100644 --- a/go/lib/infra/modules/trust/v2/inspector.go +++ b/go/lib/infra/modules/trust/v2/inspector.go @@ -44,7 +44,9 @@ func (i *inspector) ByAttributes(ctx context.Context, isd addr.ISD, opts infra.ASInspectorOpts) ([]addr.IA, error) { trcOpts := infra.TRCOpts{TrustStoreOpts: opts.TrustStoreOpts} - t, err := i.provider.GetTRC(ctx, isd, scrypto.Version(scrypto.LatestVer), trcOpts) + t, err := i.provider.GetTRC(ctx, TRCID{ + ISD: isd, Version: scrypto.Version(scrypto.LatestVer)}, + trcOpts) if err != nil { return nil, serrors.WrapStr("unable to get latest TRC", err, "isd", isd) } @@ -63,7 +65,9 @@ func (i *inspector) HasAttributes(ctx context.Context, ia addr.IA, opts infra.ASInspectorOpts) (bool, error) { trcOpts := infra.TRCOpts{TrustStoreOpts: opts.TrustStoreOpts} - trc, err := i.provider.GetTRC(ctx, ia.I, scrypto.Version(scrypto.LatestVer), trcOpts) + trc, err := i.provider.GetTRC(ctx, TRCID{ + ISD: ia.I, Version: scrypto.Version(scrypto.LatestVer)}, + trcOpts) if err != nil { return false, serrors.WrapStr("unable to get latest TRC", err, "isd", ia.I) } diff --git a/go/lib/infra/modules/trust/v2/inspector_test.go b/go/lib/infra/modules/trust/v2/inspector_test.go index 1c784c650b..7ff7b5e228 100644 --- a/go/lib/infra/modules/trust/v2/inspector_test.go +++ b/go/lib/infra/modules/trust/v2/inspector_test.go @@ -60,8 +60,8 @@ func TestInspectorByAttributes(t *testing.T) { }, "error": { Expect: func(provider *mock_v2.MockCryptoProvider, _ *trc.TRC) { - provider.EXPECT().GetTRC(gomock.Any(), trc1v1.ISD, - scrypto.LatestVer, gomock.Any()).Return(nil, trust.ErrNotFound) + provider.EXPECT().GetTRC(gomock.Any(), trust.TRCID{ISD: trc1v1.ISD, + Version: scrypto.LatestVer}, gomock.Any()).Return(nil, trust.ErrNotFound) }, ExpectedErr: trust.ErrNotFound, }, @@ -140,8 +140,8 @@ func TestInspectorHasAttributes(t *testing.T) { "error": { IA: ia110, Expect: func(provider *mock_v2.MockCryptoProvider, _ *trc.TRC) { - provider.EXPECT().GetTRC(gomock.Any(), ia110.I, - scrypto.LatestVer, gomock.Any()).Return(nil, trust.ErrNotFound) + provider.EXPECT().GetTRC(gomock.Any(), trust.TRCID{ISD: ia110.I, + Version: scrypto.LatestVer}, gomock.Any()).Return(nil, trust.ErrNotFound) }, ExpectedErr: trust.ErrNotFound, }, @@ -182,6 +182,6 @@ func defaultExpect(provider *mock_v2.MockCryptoProvider, trcObj *trc.TRC) { entry = trcObj.PrimaryASes[ia120.A] entry.Attributes = []trc.Attribute{trc.Issuing} trcObj.PrimaryASes[ia120.A] = entry - provider.EXPECT().GetTRC(gomock.Any(), addr.ISD(1), - scrypto.LatestVer, gomock.Any()).Return(trcObj, nil) + provider.EXPECT().GetTRC(gomock.Any(), trust.TRCID{ISD: addr.ISD(1), + Version: scrypto.LatestVer}, gomock.Any()).Return(trcObj, nil) } diff --git a/go/lib/infra/modules/trust/v2/mock_v2/v2.go b/go/lib/infra/modules/trust/v2/mock_v2/v2.go index d0b20dc3bd..c7d97d1422 100644 --- a/go/lib/infra/modules/trust/v2/mock_v2/v2.go +++ b/go/lib/infra/modules/trust/v2/mock_v2/v2.go @@ -43,10 +43,10 @@ func (m *MockCryptoProvider) EXPECT() *MockCryptoProviderMockRecorder { } // GetASKey mocks base method -func (m *MockCryptoProvider) GetASKey(arg0 context.Context, arg1 v2.ChainID, arg2 *infra.ChainOpts) (*scrypto.KeyMeta, error) { +func (m *MockCryptoProvider) GetASKey(arg0 context.Context, arg1 v2.ChainID, arg2 infra.ChainOpts) (scrypto.KeyMeta, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetASKey", arg0, arg1, arg2) - ret0, _ := ret[0].(*scrypto.KeyMeta) + ret0, _ := ret[0].(scrypto.KeyMeta) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -58,48 +58,48 @@ func (mr *MockCryptoProviderMockRecorder) GetASKey(arg0, arg1, arg2 interface{}) } // GetRawChain mocks base method -func (m *MockCryptoProvider) GetRawChain(arg0 context.Context, arg1 addr.IA, arg2 scrypto.Version, arg3 infra.ChainOpts, arg4 net.Addr) ([]byte, error) { +func (m *MockCryptoProvider) GetRawChain(arg0 context.Context, arg1 v2.ChainID, arg2 infra.ChainOpts) ([]byte, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetRawChain", arg0, arg1, arg2, arg3, arg4) + ret := m.ctrl.Call(m, "GetRawChain", arg0, arg1, arg2) ret0, _ := ret[0].([]byte) ret1, _ := ret[1].(error) return ret0, ret1 } // GetRawChain indicates an expected call of GetRawChain -func (mr *MockCryptoProviderMockRecorder) GetRawChain(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call { +func (mr *MockCryptoProviderMockRecorder) GetRawChain(arg0, arg1, arg2 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRawChain", reflect.TypeOf((*MockCryptoProvider)(nil).GetRawChain), arg0, arg1, arg2, arg3, arg4) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRawChain", reflect.TypeOf((*MockCryptoProvider)(nil).GetRawChain), arg0, arg1, arg2) } // GetRawTRC mocks base method -func (m *MockCryptoProvider) GetRawTRC(arg0 context.Context, arg1 addr.ISD, arg2 scrypto.Version, arg3 infra.TRCOpts, arg4 net.Addr) ([]byte, error) { +func (m *MockCryptoProvider) GetRawTRC(arg0 context.Context, arg1 v2.TRCID, arg2 infra.TRCOpts) ([]byte, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetRawTRC", arg0, arg1, arg2, arg3, arg4) + ret := m.ctrl.Call(m, "GetRawTRC", arg0, arg1, arg2) ret0, _ := ret[0].([]byte) ret1, _ := ret[1].(error) return ret0, ret1 } // GetRawTRC indicates an expected call of GetRawTRC -func (mr *MockCryptoProviderMockRecorder) GetRawTRC(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call { +func (mr *MockCryptoProviderMockRecorder) GetRawTRC(arg0, arg1, arg2 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRawTRC", reflect.TypeOf((*MockCryptoProvider)(nil).GetRawTRC), arg0, arg1, arg2, arg3, arg4) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRawTRC", reflect.TypeOf((*MockCryptoProvider)(nil).GetRawTRC), arg0, arg1, arg2) } // GetTRC mocks base method -func (m *MockCryptoProvider) GetTRC(arg0 context.Context, arg1 addr.ISD, arg2 scrypto.Version, arg3 infra.TRCOpts) (*v20.TRC, error) { +func (m *MockCryptoProvider) GetTRC(arg0 context.Context, arg1 v2.TRCID, arg2 infra.TRCOpts) (*v20.TRC, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetTRC", arg0, arg1, arg2, arg3) + ret := m.ctrl.Call(m, "GetTRC", arg0, arg1, arg2) ret0, _ := ret[0].(*v20.TRC) ret1, _ := ret[1].(error) return ret0, ret1 } // GetTRC indicates an expected call of GetTRC -func (mr *MockCryptoProviderMockRecorder) GetTRC(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { +func (mr *MockCryptoProviderMockRecorder) GetTRC(arg0, arg1, arg2 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTRC", reflect.TypeOf((*MockCryptoProvider)(nil).GetTRC), arg0, arg1, arg2, arg3) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTRC", reflect.TypeOf((*MockCryptoProvider)(nil).GetTRC), arg0, arg1, arg2) } // MockDB is a mock of DB interface @@ -185,63 +185,63 @@ func (mr *MockDBMockRecorder) GetIssuingKeyInfo(arg0, arg1, arg2 interface{}) *g } // GetRawChain mocks base method -func (m *MockDB) GetRawChain(arg0 context.Context, arg1 addr.IA, arg2 scrypto.Version) ([]byte, error) { +func (m *MockDB) GetRawChain(arg0 context.Context, arg1 v2.ChainID) ([]byte, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetRawChain", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "GetRawChain", arg0, arg1) ret0, _ := ret[0].([]byte) ret1, _ := ret[1].(error) return ret0, ret1 } // GetRawChain indicates an expected call of GetRawChain -func (mr *MockDBMockRecorder) GetRawChain(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockDBMockRecorder) GetRawChain(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRawChain", reflect.TypeOf((*MockDB)(nil).GetRawChain), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRawChain", reflect.TypeOf((*MockDB)(nil).GetRawChain), arg0, arg1) } // GetRawTRC mocks base method -func (m *MockDB) GetRawTRC(arg0 context.Context, arg1 addr.ISD, arg2 scrypto.Version) ([]byte, error) { +func (m *MockDB) GetRawTRC(arg0 context.Context, arg1 v2.TRCID) ([]byte, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetRawTRC", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "GetRawTRC", arg0, arg1) ret0, _ := ret[0].([]byte) ret1, _ := ret[1].(error) return ret0, ret1 } // GetRawTRC indicates an expected call of GetRawTRC -func (mr *MockDBMockRecorder) GetRawTRC(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockDBMockRecorder) GetRawTRC(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRawTRC", reflect.TypeOf((*MockDB)(nil).GetRawTRC), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRawTRC", reflect.TypeOf((*MockDB)(nil).GetRawTRC), arg0, arg1) } // GetTRC mocks base method -func (m *MockDB) GetTRC(arg0 context.Context, arg1 addr.ISD, arg2 scrypto.Version) (*v20.TRC, error) { +func (m *MockDB) GetTRC(arg0 context.Context, arg1 v2.TRCID) (*v20.TRC, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetTRC", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "GetTRC", arg0, arg1) ret0, _ := ret[0].(*v20.TRC) ret1, _ := ret[1].(error) return ret0, ret1 } // GetTRC indicates an expected call of GetTRC -func (mr *MockDBMockRecorder) GetTRC(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockDBMockRecorder) GetTRC(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTRC", reflect.TypeOf((*MockDB)(nil).GetTRC), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTRC", reflect.TypeOf((*MockDB)(nil).GetTRC), arg0, arg1) } // GetTRCInfo mocks base method -func (m *MockDB) GetTRCInfo(arg0 context.Context, arg1 addr.ISD, arg2 scrypto.Version) (v2.TRCInfo, error) { +func (m *MockDB) GetTRCInfo(arg0 context.Context, arg1 v2.TRCID) (v2.TRCInfo, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetTRCInfo", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "GetTRCInfo", arg0, arg1) ret0, _ := ret[0].(v2.TRCInfo) ret1, _ := ret[1].(error) return ret0, ret1 } // GetTRCInfo indicates an expected call of GetTRCInfo -func (mr *MockDBMockRecorder) GetTRCInfo(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockDBMockRecorder) GetTRCInfo(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTRCInfo", reflect.TypeOf((*MockDB)(nil).GetTRCInfo), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTRCInfo", reflect.TypeOf((*MockDB)(nil).GetTRCInfo), arg0, arg1) } // InsertChain mocks base method diff --git a/go/lib/infra/modules/trust/v2/provider.go b/go/lib/infra/modules/trust/v2/provider.go index 7f2782f836..6506eac0c1 100644 --- a/go/lib/infra/modules/trust/v2/provider.go +++ b/go/lib/infra/modules/trust/v2/provider.go @@ -40,23 +40,27 @@ type CryptoProvider interface { // server is queried over the network if the TRC is not available locally. // Otherwise, the default server is queried. How the default server is // determined differs between implementations. - GetTRC(ctx context.Context, isd addr.ISD, version scrypto.Version, - opts infra.TRCOpts) (*trc.TRC, error) + GetTRC(context.Context, TRCID, infra.TRCOpts) (*trc.TRC, error) // GetRawTRC behaves the same as GetTRC, except returning the raw signed TRC. - GetRawTRC(ctx context.Context, isd addr.ISD, version scrypto.Version, - opts infra.TRCOpts, client net.Addr) ([]byte, error) + GetRawTRC(context.Context, TRCID, infra.TRCOpts) ([]byte, error) // GetRawChain asks the trust store to return a valid and active certificate // chain, unless inactive chains are specifically allowed. The optionally // configured server is queried over the network if the certificate chain is // not available locally. Otherwise, the default server is queried. How the // default server is determined differs between implementations. - GetRawChain(ctx context.Context, ia addr.IA, version scrypto.Version, - opts infra.ChainOpts, client net.Addr) ([]byte, error) + GetRawChain(context.Context, ChainID, infra.ChainOpts) ([]byte, error) //GetASKey returns from trust store the public key required to verify signature //originated from an AS. - GetASKey(context.Context, ChainID, *infra.ChainOpts) (*scrypto.KeyMeta, error) + GetASKey(context.Context, ChainID, infra.ChainOpts) (scrypto.KeyMeta, error) } +// TRCID identifies a TRC. +type TRCID struct { + ISD addr.ISD + Version scrypto.Version +} + +// ChainID identifies a chain. type ChainID struct { IA addr.IA Version scrypto.Version @@ -69,31 +73,31 @@ type cryptoProvider struct { router Router } -func (p *cryptoProvider) GetTRC(ctx context.Context, isd addr.ISD, version scrypto.Version, +func (p *cryptoProvider) GetTRC(ctx context.Context, id TRCID, opts infra.TRCOpts) (*trc.TRC, error) { - t, _, err := p.getCheckedTRC(ctx, isd, version, opts, nil) + t, _, err := p.getCheckedTRC(ctx, id, opts) return t, err } -func (p *cryptoProvider) GetRawTRC(ctx context.Context, isd addr.ISD, version scrypto.Version, - opts infra.TRCOpts, client net.Addr) ([]byte, error) { +func (p *cryptoProvider) GetRawTRC(ctx context.Context, id TRCID, + opts infra.TRCOpts) ([]byte, error) { - _, raw, err := p.getCheckedTRC(ctx, isd, version, opts, client) + _, raw, err := p.getCheckedTRC(ctx, id, opts) return raw, err } -func (p *cryptoProvider) getCheckedTRC(ctx context.Context, isd addr.ISD, version scrypto.Version, - opts infra.TRCOpts, client net.Addr) (*trc.TRC, []byte, error) { +func (p *cryptoProvider) getCheckedTRC(ctx context.Context, id TRCID, + opts infra.TRCOpts) (*trc.TRC, []byte, error) { - decTRC, err := p.getTRC(ctx, isd, version, opts, nil) + decTRC, err := p.getTRC(ctx, id, opts, nil) if err != nil { return nil, nil, serrors.WrapStr("unable to get requested TRC", err) } if opts.AllowInactive { return decTRC.TRC, decTRC.Raw, nil } - info, err := p.db.GetTRCInfo(ctx, isd, scrypto.LatestVer) + info, err := p.db.GetTRCInfo(ctx, TRCID{ISD: id.ISD, Version: scrypto.LatestVer}) if err != nil { return nil, nil, serrors.WrapStr("unable to get latest TRC info", err) } @@ -105,14 +109,14 @@ func (p *cryptoProvider) getCheckedTRC(ctx context.Context, isd addr.ISD, versio return nil, nil, serrors.WrapStr("grace period has passed", ErrInactive, "end", info.Validity.NotBefore.Add(info.GracePeriod), "latest", info.Version) case !decTRC.TRC.Validity.Contains(time.Now()): - if !version.IsLatest() || opts.LocalOnly { + if !id.Version.IsLatest() || opts.LocalOnly { return nil, nil, serrors.WrapStr("requested TRC expired", ErrInactive, "validity", decTRC.TRC.Validity) } // There might exist a more recent TRC that is not available locally // yet. Fetch it if the latest version was requested and recursion // is allowed. - fetched, err := p.fetchTRC(ctx, isd, scrypto.LatestVer, opts, client) + fetched, err := p.fetchTRC(ctx, TRCID{ISD: id.ISD, Version: scrypto.LatestVer}, opts) if err != nil { return nil, nil, serrors.WrapStr("unable to fetch latest TRC from network", err) } @@ -135,10 +139,10 @@ func (p *cryptoProvider) getCheckedTRC(ctx context.Context, isd addr.ISD, versio // whether this function is allowed to create new network requests. Parameter // client contains the node that caused the function to be called, or nil if the // function was called due to a local feature. -func (p *cryptoProvider) getTRC(ctx context.Context, isd addr.ISD, version scrypto.Version, +func (p *cryptoProvider) getTRC(ctx context.Context, id TRCID, opts infra.TRCOpts, client net.Addr) (decoded.TRC, error) { - raw, err := p.db.GetRawTRC(ctx, isd, version) + raw, err := p.db.GetRawTRC(ctx, id) switch { case err == nil: return decoded.DecodeTRC(raw) @@ -147,26 +151,26 @@ func (p *cryptoProvider) getTRC(ctx context.Context, isd addr.ISD, version scryp case opts.LocalOnly: return decoded.TRC{}, serrors.WrapStr("localOnly requested", err) default: - return p.fetchTRC(ctx, isd, version, opts, client) + return p.fetchTRC(ctx, id, opts) } } // fetchTRC fetches a TRC via a network request, if allowed. -func (p *cryptoProvider) fetchTRC(ctx context.Context, isd addr.ISD, version scrypto.Version, - opts infra.TRCOpts, client net.Addr) (decoded.TRC, error) { +func (p *cryptoProvider) fetchTRC(ctx context.Context, id TRCID, + opts infra.TRCOpts) (decoded.TRC, error) { server := opts.Server - if err := p.recurser.AllowRecursion(client); err != nil { + if err := p.recurser.AllowRecursion(opts.Client); err != nil { return decoded.TRC{}, err } req := TRCReq{ - ISD: isd, - Version: version, + ISD: id.ISD, + Version: id.Version, } // Choose remote server, if not set. if server == nil { var err error - if server, err = p.router.ChooseServer(ctx, isd); err != nil { + if server, err = p.router.ChooseServer(ctx, id.ISD); err != nil { return decoded.TRC{}, serrors.WrapStr("unable to route TRC request", err) } } @@ -177,45 +181,45 @@ func (p *cryptoProvider) fetchTRC(ctx context.Context, isd addr.ISD, version scr return decTRC, nil } -func (p *cryptoProvider) GetRawChain(ctx context.Context, ia addr.IA, version scrypto.Version, - opts infra.ChainOpts, client net.Addr) ([]byte, error) { +func (p *cryptoProvider) GetRawChain(ctx context.Context, id ChainID, + opts infra.ChainOpts) ([]byte, error) { - chain, err := p.getChain(ctx, ia, version, opts, client) + chain, err := p.getChain(ctx, id, opts) if err != nil { return nil, serrors.WrapStr("unable to get requested certificate chain", err) } if opts.AllowInactive { return chain.Raw, nil } - err = p.issuerActive(ctx, chain, opts.TrustStoreOpts, client) + err = p.issuerActive(ctx, chain, opts.TrustStoreOpts) switch { case err == nil: return chain.Raw, nil case !xerrors.Is(err, ErrInactive): return nil, err - case !version.IsLatest(): + case !id.Version.IsLatest(): return nil, err case opts.LocalOnly: return nil, err default: // In case the latest certificate chain is requested, there might be a more // recent and active one that is not locally available yet. - fetched, err := p.fetchChain(ctx, ia, scrypto.LatestVer, opts, client) + fetched, err := p.fetchChain(ctx, id, opts) if err != nil { return nil, serrors.WrapStr("unable to fetch latest certificate chain from network", err) } - if err := p.issuerActive(ctx, fetched, opts.TrustStoreOpts, client); err != nil { + if err := p.issuerActive(ctx, fetched, opts.TrustStoreOpts); err != nil { return nil, serrors.WrapStr("latest certificate chain from network not active", err) } return fetched.Raw, nil } } -func (p *cryptoProvider) getChain(ctx context.Context, ia addr.IA, version scrypto.Version, - opts infra.ChainOpts, client net.Addr) (decoded.Chain, error) { +func (p *cryptoProvider) getChain(ctx context.Context, id ChainID, + opts infra.ChainOpts) (decoded.Chain, error) { - raw, err := p.db.GetRawChain(ctx, ia, version) + raw, err := p.db.GetRawChain(ctx, id) switch { case err == nil: return decoded.DecodeChain(raw) @@ -224,12 +228,12 @@ func (p *cryptoProvider) getChain(ctx context.Context, ia addr.IA, version scryp case opts.LocalOnly: return decoded.Chain{}, serrors.WrapStr("localOnly requested", err) default: - return p.fetchChain(ctx, ia, version, opts, client) + return p.fetchChain(ctx, id, opts) } } func (p *cryptoProvider) issuerActive(ctx context.Context, chain decoded.Chain, - opts infra.TrustStoreOpts, client net.Addr) error { + opts infra.TrustStoreOpts) error { if !chain.AS.Validity.Contains(time.Now()) { return serrors.WrapStr("AS certificate outside of validity period", ErrInactive, @@ -237,7 +241,9 @@ func (p *cryptoProvider) issuerActive(ctx context.Context, chain decoded.Chain, } // Ensure that an active TRC is available locally. trcOpts := infra.TRCOpts{TrustStoreOpts: opts} - _, _, err := p.getCheckedTRC(ctx, chain.Issuer.Subject.I, scrypto.LatestVer, trcOpts, client) + _, _, err := p.getCheckedTRC(ctx, TRCID{ + ISD: chain.Issuer.Subject.I, Version: scrypto.LatestVer}, + trcOpts) if err != nil { return serrors.WrapStr("unable to preload latest TRC", err) } @@ -271,21 +277,21 @@ func (p *cryptoProvider) issuerActive(ctx context.Context, chain decoded.Chain, return nil } -func (p *cryptoProvider) fetchChain(ctx context.Context, ia addr.IA, version scrypto.Version, - opts infra.ChainOpts, client net.Addr) (decoded.Chain, error) { +func (p *cryptoProvider) fetchChain(ctx context.Context, id ChainID, + opts infra.ChainOpts) (decoded.Chain, error) { server := opts.Server - if err := p.recurser.AllowRecursion(client); err != nil { + if err := p.recurser.AllowRecursion(opts.Client); err != nil { return decoded.Chain{}, err } req := ChainReq{ - IA: ia, - Version: version, + IA: id.IA, + Version: id.Version, } // Choose remote server, if not set. if server == nil { var err error - if server, err = p.router.ChooseServer(ctx, ia.I); err != nil { + if server, err = p.router.ChooseServer(ctx, id.IA.I); err != nil { return decoded.Chain{}, serrors.WrapStr("unable to route TRC request", err) } } @@ -298,10 +304,10 @@ func (p *cryptoProvider) fetchChain(ctx context.Context, ia addr.IA, version scr } func (p *cryptoProvider) GetASKey(ctx context.Context, - id ChainID, opts *infra.ChainOpts) (*scrypto.KeyMeta, error) { + id ChainID, opts infra.ChainOpts) (scrypto.KeyMeta, error) { // TODO(karampok): implement. - return nil, serrors.New("not implemented") + return scrypto.KeyMeta{}, serrors.New("not implemented") } func graceExpired(info TRCInfo) bool { diff --git a/go/lib/infra/modules/trust/v2/provider_test.go b/go/lib/infra/modules/trust/v2/provider_test.go index 31cf518ccf..d18e4eb5ec 100644 --- a/go/lib/infra/modules/trust/v2/provider_test.go +++ b/go/lib/infra/modules/trust/v2/provider_test.go @@ -52,7 +52,8 @@ func TestCryptoProviderGetTRC(t *testing.T) { }{ "TRC in database, allow inactive": { Expect: func(m *mocks, dec *decoded.TRC) { - m.DB.EXPECT().GetRawTRC(gomock.Any(), dec.TRC.ISD, dec.TRC.Version).Return( + m.DB.EXPECT().GetRawTRC(gomock.Any(), + trust.TRCID{ISD: dec.TRC.ISD, Version: dec.TRC.Version}).Return( dec.Raw, nil, ) }, @@ -60,10 +61,12 @@ func TestCryptoProviderGetTRC(t *testing.T) { }, "TRC in database, is newest": { Expect: func(m *mocks, dec *decoded.TRC) { - m.DB.EXPECT().GetRawTRC(gomock.Any(), dec.TRC.ISD, dec.TRC.Version).Return( + m.DB.EXPECT().GetRawTRC(gomock.Any(), + trust.TRCID{ISD: dec.TRC.ISD, Version: dec.TRC.Version}).Return( dec.Raw, nil, ) - m.DB.EXPECT().GetTRCInfo(gomock.Any(), dec.TRC.ISD, scrypto.LatestVer).Return( + m.DB.EXPECT().GetTRCInfo(gomock.Any(), + trust.TRCID{ISD: dec.TRC.ISD, Version: scrypto.LatestVer}).Return( trust.TRCInfo{Version: dec.TRC.Version}, nil, ) }, @@ -75,10 +78,12 @@ func TestCryptoProviderGetTRC(t *testing.T) { GracePeriod: time.Hour, Validity: scrypto.Validity{NotBefore: util.UnixTime{Time: time.Now()}}, } - m.DB.EXPECT().GetRawTRC(gomock.Any(), dec.TRC.ISD, dec.TRC.Version).Return( + m.DB.EXPECT().GetRawTRC(gomock.Any(), + trust.TRCID{ISD: dec.TRC.ISD, Version: dec.TRC.Version}).Return( dec.Raw, nil, ) - m.DB.EXPECT().GetTRCInfo(gomock.Any(), dec.TRC.ISD, scrypto.LatestVer).Return( + m.DB.EXPECT().GetTRCInfo(gomock.Any(), + trust.TRCID{ISD: dec.TRC.ISD, Version: scrypto.LatestVer}).Return( info, nil, ) }, @@ -86,7 +91,8 @@ func TestCryptoProviderGetTRC(t *testing.T) { "not found, resolve success": { Expect: func(m *mocks, dec *decoded.TRC) { ip := &net.IPAddr{IP: []byte{127, 0, 0, 1}} - m.DB.EXPECT().GetRawTRC(gomock.Any(), dec.TRC.ISD, dec.TRC.Version).Return( + m.DB.EXPECT().GetRawTRC(gomock.Any(), + trust.TRCID{ISD: dec.TRC.ISD, Version: dec.TRC.Version}).Return( nil, trust.ErrNotFound, ) m.Recurser.EXPECT().AllowRecursion(gomock.Any()).Return(nil) @@ -108,10 +114,12 @@ func TestCryptoProviderGetTRC(t *testing.T) { dec.TRC.Validity.NotAfter.Time = time.Now() dec.Signed.EncodedTRC, _ = trc.Encode(dec.TRC) dec.Raw, _ = json.Marshal(dec.Signed) - m.DB.EXPECT().GetRawTRC(gomock.Any(), dec.TRC.ISD, dec.TRC.Version).Return( + m.DB.EXPECT().GetRawTRC(gomock.Any(), + trust.TRCID{ISD: dec.TRC.ISD, Version: dec.TRC.Version}).Return( dec.Raw, nil, ) - m.DB.EXPECT().GetTRCInfo(gomock.Any(), dec.TRC.ISD, scrypto.LatestVer).Return( + m.DB.EXPECT().GetTRCInfo(gomock.Any(), + trust.TRCID{ISD: dec.TRC.ISD, Version: scrypto.LatestVer}).Return( trust.TRCInfo{Version: dec.TRC.Version}, nil, ) }, @@ -119,10 +127,12 @@ func TestCryptoProviderGetTRC(t *testing.T) { }, "TRC in database, invalidated by newer": { Expect: func(m *mocks, dec *decoded.TRC) { - m.DB.EXPECT().GetRawTRC(gomock.Any(), dec.TRC.ISD, dec.TRC.Version).Return( + m.DB.EXPECT().GetRawTRC(gomock.Any(), + trust.TRCID{ISD: dec.TRC.ISD, Version: dec.TRC.Version}).Return( dec.Raw, nil, ) - m.DB.EXPECT().GetTRCInfo(gomock.Any(), dec.TRC.ISD, scrypto.LatestVer).Return( + m.DB.EXPECT().GetTRCInfo(gomock.Any(), + trust.TRCID{ISD: dec.TRC.ISD, Version: scrypto.LatestVer}).Return( trust.TRCInfo{Version: dec.TRC.Version + 2}, nil, ) }, @@ -137,10 +147,12 @@ func TestCryptoProviderGetTRC(t *testing.T) { NotBefore: util.UnixTime{Time: time.Now().Add(-2 * time.Second)}, }, } - m.DB.EXPECT().GetRawTRC(gomock.Any(), dec.TRC.ISD, dec.TRC.Version).Return( + m.DB.EXPECT().GetRawTRC(gomock.Any(), + trust.TRCID{ISD: dec.TRC.ISD, Version: dec.TRC.Version}).Return( dec.Raw, nil, ) - m.DB.EXPECT().GetTRCInfo(gomock.Any(), dec.TRC.ISD, scrypto.LatestVer).Return( + m.DB.EXPECT().GetTRCInfo(gomock.Any(), + trust.TRCID{ISD: dec.TRC.ISD, Version: scrypto.LatestVer}).Return( info, nil, ) }, @@ -148,7 +160,8 @@ func TestCryptoProviderGetTRC(t *testing.T) { }, "DB error": { Expect: func(m *mocks, dec *decoded.TRC) { - m.DB.EXPECT().GetRawTRC(gomock.Any(), dec.TRC.ISD, dec.TRC.Version).Return( + m.DB.EXPECT().GetRawTRC(gomock.Any(), + trust.TRCID{ISD: dec.TRC.ISD, Version: dec.TRC.Version}).Return( nil, internal, ) }, @@ -156,10 +169,12 @@ func TestCryptoProviderGetTRC(t *testing.T) { }, "Fail getting TRC info": { Expect: func(m *mocks, dec *decoded.TRC) { - m.DB.EXPECT().GetRawTRC(gomock.Any(), dec.TRC.ISD, dec.TRC.Version).Return( + m.DB.EXPECT().GetRawTRC(gomock.Any(), + trust.TRCID{ISD: dec.TRC.ISD, Version: dec.TRC.Version}).Return( dec.Raw, nil, ) - m.DB.EXPECT().GetTRCInfo(gomock.Any(), dec.TRC.ISD, scrypto.LatestVer).Return( + m.DB.EXPECT().GetTRCInfo(gomock.Any(), + trust.TRCID{ISD: dec.TRC.ISD, Version: scrypto.LatestVer}).Return( trust.TRCInfo{}, internal, ) }, @@ -167,7 +182,8 @@ func TestCryptoProviderGetTRC(t *testing.T) { }, "not found, local only": { Expect: func(m *mocks, dec *decoded.TRC) { - m.DB.EXPECT().GetRawTRC(gomock.Any(), dec.TRC.ISD, dec.TRC.Version).Return( + m.DB.EXPECT().GetRawTRC(gomock.Any(), + trust.TRCID{ISD: dec.TRC.ISD, Version: dec.TRC.Version}).Return( nil, trust.ErrNotFound, ) }, @@ -176,7 +192,8 @@ func TestCryptoProviderGetTRC(t *testing.T) { }, "not found, recursion not allowed": { Expect: func(m *mocks, dec *decoded.TRC) { - m.DB.EXPECT().GetRawTRC(gomock.Any(), dec.TRC.ISD, dec.TRC.Version).Return( + m.DB.EXPECT().GetRawTRC(gomock.Any(), + trust.TRCID{ISD: dec.TRC.ISD, Version: dec.TRC.Version}).Return( nil, trust.ErrNotFound, ) m.Recurser.EXPECT().AllowRecursion(gomock.Any()).Return(internal) @@ -185,7 +202,8 @@ func TestCryptoProviderGetTRC(t *testing.T) { }, "not found, router error": { Expect: func(m *mocks, dec *decoded.TRC) { - m.DB.EXPECT().GetRawTRC(gomock.Any(), dec.TRC.ISD, dec.TRC.Version).Return( + m.DB.EXPECT().GetRawTRC(gomock.Any(), + trust.TRCID{ISD: dec.TRC.ISD, Version: dec.TRC.Version}).Return( nil, trust.ErrNotFound, ) m.Recurser.EXPECT().AllowRecursion(gomock.Any()).Return(nil) @@ -196,7 +214,8 @@ func TestCryptoProviderGetTRC(t *testing.T) { "not found, resolve error": { Expect: func(m *mocks, dec *decoded.TRC) { ip := &net.IPAddr{IP: []byte{127, 0, 0, 1}} - m.DB.EXPECT().GetRawTRC(gomock.Any(), dec.TRC.ISD, dec.TRC.Version).Return( + m.DB.EXPECT().GetRawTRC(gomock.Any(), + trust.TRCID{ISD: dec.TRC.ISD, Version: dec.TRC.Version}).Return( nil, trust.ErrNotFound, ) m.Recurser.EXPECT().AllowRecursion(gomock.Any()).Return(nil) @@ -212,7 +231,8 @@ func TestCryptoProviderGetTRC(t *testing.T) { "not found, server set": { Expect: func(m *mocks, dec *decoded.TRC) { ip := &net.IPAddr{IP: []byte{127, 0, 0, 1}} - m.DB.EXPECT().GetRawTRC(gomock.Any(), dec.TRC.ISD, dec.TRC.Version).Return( + m.DB.EXPECT().GetRawTRC(gomock.Any(), + trust.TRCID{ISD: dec.TRC.ISD, Version: dec.TRC.Version}).Return( nil, trust.ErrNotFound, ) m.Recurser.EXPECT().AllowRecursion(gomock.Any()).Return(nil) @@ -241,7 +261,8 @@ func TestCryptoProviderGetTRC(t *testing.T) { decoded := loadTRC(t, trc1v1) test.Expect(&m, &decoded) provider := trust.NewCryptoProvider(m.DB, m.Recurser, m.Resolver, m.Router) - ptrc, err := provider.GetTRC(nil, trc1v1.ISD, trc1v1.Version, test.Opts) + ptrc, err := provider.GetTRC(nil, + trust.TRCID{ISD: trc1v1.ISD, Version: trc1v1.Version}, test.Opts) if test.ExpectedErr != nil { require.Error(t, err) assert.Truef(t, xerrors.Is(err, test.ExpectedErr), @@ -269,7 +290,8 @@ func TestCryptoProviderGetTRCLatest(t *testing.T) { }{ "TRC in database, allow inactive": { Expect: func(m *mocks, dec *decoded.TRC) decoded.TRC { - m.DB.EXPECT().GetRawTRC(gomock.Any(), dec.TRC.ISD, scrypto.LatestVer).Return( + m.DB.EXPECT().GetRawTRC(gomock.Any(), + trust.TRCID{ISD: dec.TRC.ISD, Version: scrypto.LatestVer}).Return( dec.Raw, nil, ) return *dec @@ -278,7 +300,8 @@ func TestCryptoProviderGetTRCLatest(t *testing.T) { }, "not found, resolve success": { Expect: func(m *mocks, dec *decoded.TRC) decoded.TRC { - m.DB.EXPECT().GetRawTRC(gomock.Any(), dec.TRC.ISD, scrypto.LatestVer).Return( + m.DB.EXPECT().GetRawTRC(gomock.Any(), + trust.TRCID{ISD: dec.TRC.ISD, Version: scrypto.LatestVer}).Return( nil, trust.ErrNotFound, ) m.Recurser.EXPECT().AllowRecursion(gomock.Any()).Return(nil) @@ -300,10 +323,12 @@ func TestCryptoProviderGetTRCLatest(t *testing.T) { dec.TRC.Validity.NotAfter.Time = time.Now() dec.Signed.EncodedTRC, _ = trc.Encode(dec.TRC) dec.Raw, _ = json.Marshal(dec.Signed) - m.DB.EXPECT().GetRawTRC(gomock.Any(), dec.TRC.ISD, scrypto.LatestVer).Return( + m.DB.EXPECT().GetRawTRC(gomock.Any(), + trust.TRCID{ISD: dec.TRC.ISD, Version: scrypto.LatestVer}).Return( dec.Raw, nil, ) - m.DB.EXPECT().GetTRCInfo(gomock.Any(), dec.TRC.ISD, scrypto.LatestVer).Return( + m.DB.EXPECT().GetTRCInfo(gomock.Any(), + trust.TRCID{ISD: dec.TRC.ISD, Version: scrypto.LatestVer}).Return( trust.TRCInfo{Version: dec.TRC.Version}, nil, ) m.Recurser.EXPECT().AllowRecursion(gomock.Any()).Return(internal) @@ -316,10 +341,12 @@ func TestCryptoProviderGetTRCLatest(t *testing.T) { dec.TRC.Validity.NotAfter.Time = time.Now() dec.Signed.EncodedTRC, _ = trc.Encode(dec.TRC) dec.Raw, _ = json.Marshal(dec.Signed) - m.DB.EXPECT().GetRawTRC(gomock.Any(), dec.TRC.ISD, scrypto.LatestVer).Return( + m.DB.EXPECT().GetRawTRC(gomock.Any(), + trust.TRCID{ISD: dec.TRC.ISD, Version: scrypto.LatestVer}).Return( dec.Raw, nil, ) - m.DB.EXPECT().GetTRCInfo(gomock.Any(), dec.TRC.ISD, scrypto.LatestVer).Return( + m.DB.EXPECT().GetTRCInfo(gomock.Any(), + trust.TRCID{ISD: dec.TRC.ISD, Version: scrypto.LatestVer}).Return( trust.TRCInfo{Version: dec.TRC.Version}, nil, ) m.Recurser.EXPECT().AllowRecursion(gomock.Any()).Return(nil) @@ -339,10 +366,12 @@ func TestCryptoProviderGetTRCLatest(t *testing.T) { dec.TRC.Validity.NotAfter.Time = time.Now() dec.Signed.EncodedTRC, _ = trc.Encode(dec.TRC) dec.Raw, _ = json.Marshal(dec.Signed) - m.DB.EXPECT().GetRawTRC(gomock.Any(), dec.TRC.ISD, scrypto.LatestVer).Return( + m.DB.EXPECT().GetRawTRC(gomock.Any(), + trust.TRCID{ISD: dec.TRC.ISD, Version: scrypto.LatestVer}).Return( dec.Raw, nil, ) - m.DB.EXPECT().GetTRCInfo(gomock.Any(), dec.TRC.ISD, scrypto.LatestVer).Return( + m.DB.EXPECT().GetTRCInfo(gomock.Any(), + trust.TRCID{ISD: dec.TRC.ISD, Version: scrypto.LatestVer}).Return( trust.TRCInfo{Version: dec.TRC.Version}, nil, ) m.Recurser.EXPECT().AllowRecursion(gomock.Any()).Return(nil) @@ -366,10 +395,12 @@ func TestCryptoProviderGetTRCLatest(t *testing.T) { dec.TRC.Validity.NotAfter.Time = time.Now() dec.Signed.EncodedTRC, _ = trc.Encode(dec.TRC) dec.Raw, _ = json.Marshal(dec.Signed) - m.DB.EXPECT().GetRawTRC(gomock.Any(), dec.TRC.ISD, scrypto.LatestVer).Return( + m.DB.EXPECT().GetRawTRC(gomock.Any(), + trust.TRCID{ISD: dec.TRC.ISD, Version: scrypto.LatestVer}).Return( dec.Raw, nil, ) - m.DB.EXPECT().GetTRCInfo(gomock.Any(), dec.TRC.ISD, scrypto.LatestVer).Return( + m.DB.EXPECT().GetTRCInfo(gomock.Any(), + trust.TRCID{ISD: dec.TRC.ISD, Version: scrypto.LatestVer}).Return( trust.TRCInfo{Version: dec.TRC.Version}, nil, ) m.Recurser.EXPECT().AllowRecursion(gomock.Any()).Return(nil) @@ -404,7 +435,8 @@ func TestCryptoProviderGetTRCLatest(t *testing.T) { decoded := loadTRC(t, trc1v1) expected := test.Expect(&m, &decoded) provider := trust.NewCryptoProvider(m.DB, m.Recurser, m.Resolver, m.Router) - trcObj, err := provider.GetTRC(nil, trc1v1.ISD, scrypto.LatestVer, test.Opts) + trcObj, err := provider.GetTRC(nil, + trust.TRCID{ISD: trc1v1.ISD, Version: scrypto.LatestVer}, test.Opts) assert.Equal(t, expected.TRC, trcObj) if test.ExpectedErr != nil { require.Error(t, err) @@ -444,7 +476,8 @@ func TestCryptoProviderGetRawChain(t *testing.T) { "chain in database, allow inactive": { DB: func(t *testing.T, ctrl *gomock.Controller) trust.DB { db := mock_v2.NewMockDB(ctrl) - db.EXPECT().GetRawChain(gomock.Any(), ia110, scrypto.Version(1)).Return( + db.EXPECT().GetRawChain(gomock.Any(), + trust.ChainID{IA: ia110, Version: scrypto.Version(1)}).Return( loadChain(t, chain110v1).Raw, nil, ) return db @@ -465,7 +498,8 @@ func TestCryptoProviderGetRawChain(t *testing.T) { "not found, resolve success": { DB: func(t *testing.T, ctrl *gomock.Controller) trust.DB { db := mock_v2.NewMockDB(ctrl) - db.EXPECT().GetRawChain(gomock.Any(), ia110, scrypto.Version(1)).Return( + db.EXPECT().GetRawChain(gomock.Any(), + trust.ChainID{IA: ia110, Version: scrypto.Version(1)}).Return( nil, trust.ErrNotFound, ) return db @@ -498,15 +532,18 @@ func TestCryptoProviderGetRawChain(t *testing.T) { "latest TRC with same key version": { DB: func(t *testing.T, ctrl *gomock.Controller) trust.DB { db := mock_v2.NewMockDB(ctrl) - db.EXPECT().GetRawChain(gomock.Any(), ia110, scrypto.Version(1)).Return( + db.EXPECT().GetRawChain(gomock.Any(), + trust.ChainID{IA: ia110, Version: scrypto.Version(1)}).Return( loadChain(t, chain110v1).Raw, nil, ) dec := loadTRC(t, trc1v1) - db.EXPECT().GetRawTRC(gomock.Any(), ia110.I, scrypto.LatestVer).Return( + db.EXPECT().GetRawTRC(gomock.Any(), + trust.TRCID{ISD: ia110.I, Version: scrypto.LatestVer}).Return( dec.Raw, nil, ) info := trust.TRCInfo{Validity: *dec.TRC.Validity, Version: 1} - db.EXPECT().GetTRCInfo(gomock.Any(), ia110.I, scrypto.LatestVer).Return(info, nil) + db.EXPECT().GetTRCInfo(gomock.Any(), + trust.TRCID{ISD: ia110.I, Version: scrypto.LatestVer}).Return(info, nil) db.EXPECT().GetIssuingKeyInfo(gomock.Any(), ia110, scrypto.Version(1)).Return( trust.KeyInfo{ TRC: trust.TRCInfo{ @@ -545,15 +582,18 @@ func TestCryptoProviderGetRawChain(t *testing.T) { "expired latest chain, fetch active latest": { DB: func(t *testing.T, ctrl *gomock.Controller) trust.DB { db := mock_v2.NewMockDB(ctrl) - db.EXPECT().GetRawChain(gomock.Any(), ia110, scrypto.LatestVer).Return( + db.EXPECT().GetRawChain(gomock.Any(), + trust.ChainID{IA: ia110, Version: scrypto.LatestVer}).Return( expired(t, chain110v1).Raw, nil, ) dec := loadTRC(t, trc1v1) - db.EXPECT().GetRawTRC(gomock.Any(), ia110.I, scrypto.LatestVer).Return( + db.EXPECT().GetRawTRC(gomock.Any(), + trust.TRCID{ISD: ia110.I, Version: scrypto.LatestVer}).Return( dec.Raw, nil, ) info := trust.TRCInfo{Validity: *dec.TRC.Validity, Version: 1} - db.EXPECT().GetTRCInfo(gomock.Any(), ia110.I, scrypto.LatestVer).Return(info, nil) + db.EXPECT().GetTRCInfo(gomock.Any(), + trust.TRCID{ISD: ia110.I, Version: scrypto.LatestVer}).Return(info, nil) db.EXPECT().GetIssuingKeyInfo(gomock.Any(), ia110, scrypto.Version(1)).Return( trust.KeyInfo{ TRC: trust.TRCInfo{ @@ -604,15 +644,18 @@ func TestCryptoProviderGetRawChain(t *testing.T) { "grace TRC with same key version": { DB: func(t *testing.T, ctrl *gomock.Controller) trust.DB { db := mock_v2.NewMockDB(ctrl) - db.EXPECT().GetRawChain(gomock.Any(), ia110, scrypto.Version(1)).Return( + db.EXPECT().GetRawChain(gomock.Any(), + trust.ChainID{IA: ia110, Version: scrypto.Version(1)}).Return( loadChain(t, chain110v1).Raw, nil, ) dec := loadTRC(t, trc1v1) - db.EXPECT().GetRawTRC(gomock.Any(), ia110.I, scrypto.LatestVer).Return( + db.EXPECT().GetRawTRC(gomock.Any(), + trust.TRCID{ISD: ia110.I, Version: scrypto.LatestVer}).Return( dec.Raw, nil, ) info := trust.TRCInfo{Validity: *dec.TRC.Validity, Version: 1} - db.EXPECT().GetTRCInfo(gomock.Any(), ia110.I, scrypto.LatestVer).Return(info, nil) + db.EXPECT().GetTRCInfo(gomock.Any(), + trust.TRCID{ISD: ia110.I, Version: scrypto.LatestVer}).Return(info, nil) db.EXPECT().GetIssuingKeyInfo(gomock.Any(), ia110, scrypto.Version(1)).Return( trust.KeyInfo{ TRC: trust.TRCInfo{ @@ -665,15 +708,18 @@ func TestCryptoProviderGetRawChain(t *testing.T) { "latest TRC with different key version": { DB: func(t *testing.T, ctrl *gomock.Controller) trust.DB { db := mock_v2.NewMockDB(ctrl) - db.EXPECT().GetRawChain(gomock.Any(), ia110, scrypto.Version(1)).Return( + db.EXPECT().GetRawChain(gomock.Any(), + trust.ChainID{IA: ia110, Version: scrypto.Version(1)}).Return( loadChain(t, chain110v1).Raw, nil, ) dec := loadTRC(t, trc1v1) - db.EXPECT().GetRawTRC(gomock.Any(), ia110.I, scrypto.LatestVer).Return( + db.EXPECT().GetRawTRC(gomock.Any(), + trust.TRCID{ISD: ia110.I, Version: scrypto.LatestVer}).Return( dec.Raw, nil, ) info := trust.TRCInfo{Validity: *dec.TRC.Validity, Version: 1} - db.EXPECT().GetTRCInfo(gomock.Any(), ia110.I, scrypto.LatestVer).Return(info, nil) + db.EXPECT().GetTRCInfo(gomock.Any(), + trust.TRCID{ISD: ia110.I, Version: scrypto.LatestVer}).Return(info, nil) db.EXPECT().GetIssuingKeyInfo(gomock.Any(), ia110, scrypto.Version(1)).Return( trust.KeyInfo{ TRC: trust.TRCInfo{ @@ -717,15 +763,18 @@ func TestCryptoProviderGetRawChain(t *testing.T) { "grace TRC with different key version": { DB: func(t *testing.T, ctrl *gomock.Controller) trust.DB { db := mock_v2.NewMockDB(ctrl) - db.EXPECT().GetRawChain(gomock.Any(), ia110, scrypto.Version(1)).Return( + db.EXPECT().GetRawChain(gomock.Any(), + trust.ChainID{IA: ia110, Version: scrypto.Version(1)}).Return( loadChain(t, chain110v1).Raw, nil, ) dec := loadTRC(t, trc1v1) - db.EXPECT().GetRawTRC(gomock.Any(), ia110.I, scrypto.LatestVer).Return( + db.EXPECT().GetRawTRC(gomock.Any(), + trust.TRCID{ISD: ia110.I, Version: scrypto.LatestVer}).Return( dec.Raw, nil, ) info := trust.TRCInfo{Validity: *dec.TRC.Validity, Version: 1} - db.EXPECT().GetTRCInfo(gomock.Any(), ia110.I, scrypto.LatestVer).Return(info, nil) + db.EXPECT().GetTRCInfo(gomock.Any(), + trust.TRCID{ISD: ia110.I, Version: scrypto.LatestVer}).Return(info, nil) db.EXPECT().GetIssuingKeyInfo(gomock.Any(), ia110, scrypto.Version(1)).Return( trust.KeyInfo{ TRC: trust.TRCInfo{ @@ -779,7 +828,8 @@ func TestCryptoProviderGetRawChain(t *testing.T) { "expired latest chain, fetch inactive": { DB: func(t *testing.T, ctrl *gomock.Controller) trust.DB { db := mock_v2.NewMockDB(ctrl) - db.EXPECT().GetRawChain(gomock.Any(), ia110, scrypto.LatestVer).Return( + db.EXPECT().GetRawChain(gomock.Any(), + trust.ChainID{IA: ia110, Version: scrypto.LatestVer}).Return( expired(t, chain110v1).Raw, nil, ) return db @@ -813,7 +863,8 @@ func TestCryptoProviderGetRawChain(t *testing.T) { "expired latest chain, fetch fails": { DB: func(t *testing.T, ctrl *gomock.Controller) trust.DB { db := mock_v2.NewMockDB(ctrl) - db.EXPECT().GetRawChain(gomock.Any(), ia110, scrypto.LatestVer).Return( + db.EXPECT().GetRawChain(gomock.Any(), + trust.ChainID{IA: ia110, Version: scrypto.LatestVer}).Return( expired(t, chain110v1).Raw, nil, ) return db @@ -847,10 +898,12 @@ func TestCryptoProviderGetRawChain(t *testing.T) { "failing to fetch TRC": { DB: func(t *testing.T, ctrl *gomock.Controller) trust.DB { db := mock_v2.NewMockDB(ctrl) - db.EXPECT().GetRawChain(gomock.Any(), ia110, scrypto.Version(1)).Return( + db.EXPECT().GetRawChain(gomock.Any(), + trust.ChainID{IA: ia110, Version: scrypto.Version(1)}).Return( loadChain(t, chain110v1).Raw, nil, ) - db.EXPECT().GetRawTRC(gomock.Any(), ia110.I, scrypto.LatestVer).Return( + db.EXPECT().GetRawTRC(gomock.Any(), + trust.TRCID{ISD: ia110.I, Version: scrypto.LatestVer}).Return( nil, internal, ) return db @@ -872,15 +925,18 @@ func TestCryptoProviderGetRawChain(t *testing.T) { "failing to get key info for issuing TRC": { DB: func(t *testing.T, ctrl *gomock.Controller) trust.DB { db := mock_v2.NewMockDB(ctrl) - db.EXPECT().GetRawChain(gomock.Any(), ia110, scrypto.Version(1)).Return( + db.EXPECT().GetRawChain(gomock.Any(), + trust.ChainID{IA: ia110, Version: scrypto.Version(1)}).Return( loadChain(t, chain110v1).Raw, nil, ) dec := loadTRC(t, trc1v1) - db.EXPECT().GetRawTRC(gomock.Any(), ia110.I, scrypto.LatestVer).Return( + db.EXPECT().GetRawTRC(gomock.Any(), + trust.TRCID{ISD: ia110.I, Version: scrypto.LatestVer}).Return( dec.Raw, nil, ) info := trust.TRCInfo{Validity: *dec.TRC.Validity, Version: 1} - db.EXPECT().GetTRCInfo(gomock.Any(), ia110.I, scrypto.LatestVer).Return(info, nil) + db.EXPECT().GetTRCInfo(gomock.Any(), + trust.TRCID{ISD: ia110.I, Version: scrypto.LatestVer}).Return(info, nil) db.EXPECT().GetIssuingKeyInfo(gomock.Any(), ia110, scrypto.Version(1)).Return( trust.KeyInfo{}, internal, ) @@ -903,15 +959,18 @@ func TestCryptoProviderGetRawChain(t *testing.T) { "failing to get key info for latest TRC": { DB: func(t *testing.T, ctrl *gomock.Controller) trust.DB { db := mock_v2.NewMockDB(ctrl) - db.EXPECT().GetRawChain(gomock.Any(), ia110, scrypto.Version(1)).Return( + db.EXPECT().GetRawChain(gomock.Any(), + trust.ChainID{IA: ia110, Version: scrypto.Version(1)}).Return( loadChain(t, chain110v1).Raw, nil, ) dec := loadTRC(t, trc1v1) - db.EXPECT().GetRawTRC(gomock.Any(), ia110.I, scrypto.LatestVer).Return( + db.EXPECT().GetRawTRC(gomock.Any(), + trust.TRCID{ISD: ia110.I, Version: scrypto.LatestVer}).Return( dec.Raw, nil, ) info := trust.TRCInfo{Validity: *dec.TRC.Validity, Version: 1} - db.EXPECT().GetTRCInfo(gomock.Any(), ia110.I, scrypto.LatestVer).Return(info, nil) + db.EXPECT().GetTRCInfo(gomock.Any(), + trust.TRCID{ISD: ia110.I, Version: scrypto.LatestVer}).Return(info, nil) db.EXPECT().GetIssuingKeyInfo(gomock.Any(), ia110, scrypto.Version(1)).Return( trust.KeyInfo{ TRC: trust.TRCInfo{ @@ -945,15 +1004,18 @@ func TestCryptoProviderGetRawChain(t *testing.T) { "failing to get key info for grace TRC": { DB: func(t *testing.T, ctrl *gomock.Controller) trust.DB { db := mock_v2.NewMockDB(ctrl) - db.EXPECT().GetRawChain(gomock.Any(), ia110, scrypto.Version(1)).Return( + db.EXPECT().GetRawChain(gomock.Any(), + trust.ChainID{IA: ia110, Version: scrypto.Version(1)}).Return( loadChain(t, chain110v1).Raw, nil, ) dec := loadTRC(t, trc1v1) - db.EXPECT().GetRawTRC(gomock.Any(), ia110.I, scrypto.LatestVer).Return( + db.EXPECT().GetRawTRC(gomock.Any(), + trust.TRCID{ISD: ia110.I, Version: scrypto.LatestVer}).Return( dec.Raw, nil, ) info := trust.TRCInfo{Validity: *dec.TRC.Validity, Version: 1} - db.EXPECT().GetTRCInfo(gomock.Any(), ia110.I, scrypto.LatestVer).Return(info, nil) + db.EXPECT().GetTRCInfo(gomock.Any(), + trust.TRCID{ISD: ia110.I, Version: scrypto.LatestVer}).Return(info, nil) db.EXPECT().GetIssuingKeyInfo(gomock.Any(), ia110, scrypto.Version(1)).Return( trust.KeyInfo{ TRC: trust.TRCInfo{ @@ -1000,7 +1062,8 @@ func TestCryptoProviderGetRawChain(t *testing.T) { "database error": { DB: func(t *testing.T, ctrl *gomock.Controller) trust.DB { db := mock_v2.NewMockDB(ctrl) - db.EXPECT().GetRawChain(gomock.Any(), ia110, scrypto.Version(1)).Return( + db.EXPECT().GetRawChain(gomock.Any(), + trust.ChainID{IA: ia110, Version: scrypto.Version(1)}).Return( nil, internal, ) return db @@ -1022,7 +1085,8 @@ func TestCryptoProviderGetRawChain(t *testing.T) { "not found, local only": { DB: func(t *testing.T, ctrl *gomock.Controller) trust.DB { db := mock_v2.NewMockDB(ctrl) - db.EXPECT().GetRawChain(gomock.Any(), ia110, scrypto.Version(1)).Return( + db.EXPECT().GetRawChain(gomock.Any(), + trust.ChainID{IA: ia110, Version: scrypto.Version(1)}).Return( nil, trust.ErrNotFound, ) return db @@ -1044,7 +1108,8 @@ func TestCryptoProviderGetRawChain(t *testing.T) { "not found, recursion not allowed": { DB: func(t *testing.T, ctrl *gomock.Controller) trust.DB { db := mock_v2.NewMockDB(ctrl) - db.EXPECT().GetRawChain(gomock.Any(), ia110, scrypto.Version(1)).Return( + db.EXPECT().GetRawChain(gomock.Any(), + trust.ChainID{IA: ia110, Version: scrypto.Version(1)}).Return( nil, trust.ErrNotFound, ) return db @@ -1068,7 +1133,8 @@ func TestCryptoProviderGetRawChain(t *testing.T) { "not found, router error": { DB: func(t *testing.T, ctrl *gomock.Controller) trust.DB { db := mock_v2.NewMockDB(ctrl) - db.EXPECT().GetRawChain(gomock.Any(), ia110, scrypto.Version(1)).Return( + db.EXPECT().GetRawChain(gomock.Any(), + trust.ChainID{IA: ia110, Version: scrypto.Version(1)}).Return( nil, trust.ErrNotFound, ) return db @@ -1094,7 +1160,8 @@ func TestCryptoProviderGetRawChain(t *testing.T) { "not found, resolve error": { DB: func(t *testing.T, ctrl *gomock.Controller) trust.DB { db := mock_v2.NewMockDB(ctrl) - db.EXPECT().GetRawChain(gomock.Any(), ia110, scrypto.Version(1)).Return( + db.EXPECT().GetRawChain(gomock.Any(), + trust.ChainID{IA: ia110, Version: scrypto.Version(1)}).Return( nil, trust.ErrNotFound, ) return db @@ -1128,7 +1195,8 @@ func TestCryptoProviderGetRawChain(t *testing.T) { "not found, server set": { DB: func(t *testing.T, ctrl *gomock.Controller) trust.DB { db := mock_v2.NewMockDB(ctrl) - db.EXPECT().GetRawChain(gomock.Any(), ia110, scrypto.Version(1)).Return( + db.EXPECT().GetRawChain(gomock.Any(), + trust.ChainID{IA: ia110, Version: scrypto.Version(1)}).Return( nil, trust.ErrNotFound, ) return db @@ -1172,8 +1240,9 @@ func TestCryptoProviderGetRawChain(t *testing.T) { test.Recurser(t, mctrl), test.Resolver(t, mctrl), test.Router(t, mctrl)) - raw, err := p.GetRawChain(nil, test.ChainDesc.IA, test.ChainDesc.Version, - test.Opts, nil) + raw, err := p.GetRawChain(nil, + trust.ChainID{IA: test.ChainDesc.IA, Version: test.ChainDesc.Version}, + test.Opts) xtest.AssertErrorsIs(t, err, test.ExpectedErr) assert.Equal(t, test.ExpectedRaw, raw) }) diff --git a/go/lib/infra/modules/trust/v2/resolver.go b/go/lib/infra/modules/trust/v2/resolver.go index 1009260aea..932c426bfe 100644 --- a/go/lib/infra/modules/trust/v2/resolver.go +++ b/go/lib/infra/modules/trust/v2/resolver.go @@ -20,7 +20,6 @@ import ( "golang.org/x/xerrors" - "github.com/scionproto/scion/go/lib/addr" "github.com/scionproto/scion/go/lib/infra/modules/trust/v2/internal/decoded" "github.com/scionproto/scion/go/lib/log" "github.com/scionproto/scion/go/lib/scrypto" @@ -62,7 +61,7 @@ func (r *resolver) TRC(ctx context.Context, req TRCReq, server net.Addr) (decode } req = req.withVersion(latest) } - prev, err := r.db.GetTRC(ctx, req.ISD, scrypto.LatestVer) + prev, err := r.db.GetTRC(ctx, TRCID{ISD: req.ISD, Version: scrypto.LatestVer}) if err != nil && !xerrors.Is(err, ErrNotFound) { return decoded.TRC{}, serrors.WrapStr("error fetching latest locally available TRC", err) } @@ -205,11 +204,11 @@ func (w *prevWrap) SetTRC(prev *trc.TRC) { w.prev = prev } -func (w *prevWrap) TRC(_ context.Context, isd addr.ISD, version scrypto.Version) (*trc.TRC, error) { - if isd != w.prev.ISD || version != w.prev.Version { +func (w *prevWrap) TRC(_ context.Context, id TRCID) (*trc.TRC, error) { + if id.ISD != w.prev.ISD || id.Version != w.prev.Version { return nil, serrors.New("previous wrapper can only provide a single TRC", - "expected_isd", w.prev.ISD, "expted_version", w.prev.Version, - "actual_isd", isd, "actual_version", version) + "expected_isd", w.prev.ISD, "expected_version", w.prev.Version, + "actual_isd", id.ISD, "actual_version", id.Version) } return w.prev, nil } @@ -221,10 +220,9 @@ type resolveWrap struct { server net.Addr } -func (w resolveWrap) TRC(ctx context.Context, isd addr.ISD, - version scrypto.Version) (*trc.TRC, error) { +func (w resolveWrap) TRC(ctx context.Context, id TRCID) (*trc.TRC, error) { - t, err := w.resolver.db.GetTRC(ctx, isd, version) + t, err := w.resolver.db.GetTRC(ctx, id) switch { case err == nil: return t, nil @@ -232,8 +230,8 @@ func (w resolveWrap) TRC(ctx context.Context, isd addr.ISD, return nil, serrors.WrapStr("error querying DB for TRC", err) } req := TRCReq{ - ISD: isd, - Version: version, + ISD: id.ISD, + Version: id.Version, } decoded, err := w.resolver.TRC(ctx, req, w.server) if err != nil { diff --git a/go/lib/infra/modules/trust/v2/resolver_test.go b/go/lib/infra/modules/trust/v2/resolver_test.go index 81ad07e148..cd8f97daef 100644 --- a/go/lib/infra/modules/trust/v2/resolver_test.go +++ b/go/lib/infra/modules/trust/v2/resolver_test.go @@ -22,7 +22,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/scionproto/scion/go/lib/addr" trust "github.com/scionproto/scion/go/lib/infra/modules/trust/v2" "github.com/scionproto/scion/go/lib/infra/modules/trust/v2/internal/decoded" "github.com/scionproto/scion/go/lib/infra/modules/trust/v2/mock_v2" @@ -45,7 +44,8 @@ func TestResolverTRC(t *testing.T) { }{ "Fetch missing links successfully": { Expect: func(t *testing.T, m mocks) decoded.TRC { - m.DB.EXPECT().GetTRC(gomock.Any(), addr.ISD(1), scrypto.LatestVer).Return( + m.DB.EXPECT().GetTRC(gomock.Any(), + trust.TRCID{ISD: 1, Version: scrypto.LatestVer}).Return( loadTRC(t, trc1v1).TRC, nil, ) req := trust.TRCReq{ISD: 1, Version: scrypto.LatestVer} @@ -56,7 +56,7 @@ func TestResolverTRC(t *testing.T) { m.RPC.EXPECT().GetTRC(gomock.Any(), req, nil).Return(dec.Raw, nil) m.Inserter.EXPECT().InsertTRC(gomock.Any(), dec, gomock.Any()).DoAndReturn( func(_ interface{}, decTRC decoded.TRC, p trust.TRCProviderFunc) error { - prev, err := p(nil, 1, req.Version-1) + prev, err := p(nil, trust.TRCID{ISD: 1, Version: req.Version - 1}) require.NoError(t, err) assert.Equal(t, req.Version-1, prev.Version) assert.Equal(t, dec, decTRC) @@ -70,7 +70,8 @@ func TestResolverTRC(t *testing.T) { }, "DB error": { Expect: func(t *testing.T, m mocks) decoded.TRC { - m.DB.EXPECT().GetTRC(gomock.Any(), addr.ISD(1), scrypto.LatestVer).Return( + m.DB.EXPECT().GetTRC(gomock.Any(), + trust.TRCID{ISD: 1, Version: scrypto.LatestVer}).Return( nil, internal, ) return decoded.TRC{} @@ -80,7 +81,8 @@ func TestResolverTRC(t *testing.T) { }, "Superseded": { Expect: func(t *testing.T, m mocks) decoded.TRC { - m.DB.EXPECT().GetTRC(gomock.Any(), addr.ISD(1), scrypto.LatestVer).Return( + m.DB.EXPECT().GetTRC(gomock.Any(), + trust.TRCID{ISD: 1, Version: scrypto.LatestVer}).Return( loadTRC(t, trc1v3).TRC, nil, ) return decoded.TRC{} @@ -137,13 +139,16 @@ func TestResolverChain(t *testing.T) { decTRC := loadTRC(t, trc1v1) m.Inserter.EXPECT().InsertChain(gomock.Any(), dec, gomock.Any()).DoAndReturn( func(ctx context.Context, dec decoded.Chain, p trust.TRCProviderFunc) error { - trc, err := p(ctx, dec.Issuer.Subject.I, dec.Issuer.Issuer.TRCVersion) + trc, err := p(ctx, trust.TRCID{ + ISD: dec.Issuer.Subject.I, + Version: dec.Issuer.Issuer.TRCVersion, + }) require.NoError(t, err) assert.Equal(t, decTRC.TRC, trc) return err }, ) - m.DB.EXPECT().GetTRC(gomock.Any(), addr.ISD(1), scrypto.Version(1)).Return( + m.DB.EXPECT().GetTRC(gomock.Any(), trust.TRCID{ISD: 1, Version: 1}).Return( decTRC.TRC, nil, ) return dec @@ -203,12 +208,15 @@ func TestResolverChain(t *testing.T) { ) m.Inserter.EXPECT().InsertChain(gomock.Any(), dec, gomock.Any()).DoAndReturn( func(ctx context.Context, dec decoded.Chain, p trust.TRCProviderFunc) error { - _, err := p(ctx, dec.Issuer.Subject.I, dec.Issuer.Issuer.TRCVersion) + _, err := p(ctx, trust.TRCID{ + ISD: dec.Issuer.Subject.I, + Version: dec.Issuer.Issuer.TRCVersion, + }) require.Error(t, err) return err }, ) - m.DB.EXPECT().GetTRC(gomock.Any(), addr.ISD(1), scrypto.Version(1)).Return( + m.DB.EXPECT().GetTRC(gomock.Any(), trust.TRCID{ISD: 1, Version: 1}).Return( nil, internal, ) return decoded.Chain{} @@ -225,15 +233,19 @@ func TestResolverChain(t *testing.T) { ) m.Inserter.EXPECT().InsertChain(gomock.Any(), dec, gomock.Any()).DoAndReturn( func(ctx context.Context, dec decoded.Chain, p trust.TRCProviderFunc) error { - _, err := p(ctx, dec.Issuer.Subject.I, dec.Issuer.Issuer.TRCVersion) + _, err := p(ctx, trust.TRCID{ + ISD: dec.Issuer.Subject.I, + Version: dec.Issuer.Issuer.TRCVersion, + }) xtest.AssertErrorsIs(t, err, internal) return err }, ) - m.DB.EXPECT().GetTRC(gomock.Any(), addr.ISD(1), scrypto.Version(1)).Return( + m.DB.EXPECT().GetTRC(gomock.Any(), trust.TRCID{ISD: 1, Version: 1}).Return( nil, trust.ErrNotFound, ) - m.DB.EXPECT().GetTRC(gomock.Any(), addr.ISD(1), scrypto.LatestVer).Return( + m.DB.EXPECT().GetTRC(gomock.Any(), + trust.TRCID{ISD: 1, Version: scrypto.LatestVer}).Return( nil, internal, ) return decoded.Chain{} diff --git a/go/lib/infra/modules/trust/v2/router.go b/go/lib/infra/modules/trust/v2/router.go index ac67d73411..32d4b88d6a 100644 --- a/go/lib/infra/modules/trust/v2/router.go +++ b/go/lib/infra/modules/trust/v2/router.go @@ -78,7 +78,7 @@ func (r *csRouter) dstISD(ctx context.Context, destination addr.ISD) (addr.ISD, if destination == r.isd { return r.isd, nil } - info, err := r.db.GetTRCInfo(ctx, destination, scrypto.Version(scrypto.LatestVer)) + info, err := r.db.GetTRCInfo(ctx, TRCID{ISD: destination, Version: scrypto.LatestVer}) if err != nil { if xerrors.Is(err, ErrNotFound) { return r.isd, nil diff --git a/go/lib/infra/modules/trust/v2/router_test.go b/go/lib/infra/modules/trust/v2/router_test.go index 7160ca1935..9cdf9533ea 100644 --- a/go/lib/infra/modules/trust/v2/router_test.go +++ b/go/lib/infra/modules/trust/v2/router_test.go @@ -79,7 +79,8 @@ func TestCSRouterChooseServer(t *testing.T) { ISD: 2, Expect: func(db *mock_v2.MockDB, r *mock_snet.MockRouter, p *mock_snet.MockPath) { future := util.UnixTime{Time: time.Now().Add(time.Hour)} - db.EXPECT().GetTRCInfo(gomock.Any(), addr.ISD(2), scrypto.Version(0)).Return( + db.EXPECT().GetTRCInfo(gomock.Any(), + trust.TRCID{ISD: addr.ISD(2), Version: scrypto.LatestVer}).Return( trust.TRCInfo{Validity: scrypto.Validity{NotAfter: future}}, nil, ) p.EXPECT().Path().AnyTimes().Return(&spath.Path{Raw: []byte("remote ISD path")}) @@ -91,7 +92,8 @@ func TestCSRouterChooseServer(t *testing.T) { "Remote ISD, TRC not found": { ISD: 2, Expect: func(db *mock_v2.MockDB, r *mock_snet.MockRouter, p *mock_snet.MockPath) { - db.EXPECT().GetTRCInfo(gomock.Any(), addr.ISD(2), scrypto.Version(0)).Return( + db.EXPECT().GetTRCInfo(gomock.Any(), + trust.TRCID{ISD: addr.ISD(2), Version: scrypto.LatestVer}).Return( trust.TRCInfo{}, trust.ErrNotFound, ) p.EXPECT().Path().AnyTimes().Return(&spath.Path{Raw: []byte("isd local path")}) @@ -104,7 +106,8 @@ func TestCSRouterChooseServer(t *testing.T) { ISD: 2, Expect: func(db *mock_v2.MockDB, r *mock_snet.MockRouter, p *mock_snet.MockPath) { passed := util.UnixTime{Time: time.Now().Add(-time.Second)} - db.EXPECT().GetTRCInfo(gomock.Any(), addr.ISD(2), scrypto.Version(0)).Return( + db.EXPECT().GetTRCInfo(gomock.Any(), + trust.TRCID{ISD: addr.ISD(2), Version: scrypto.LatestVer}).Return( trust.TRCInfo{Validity: scrypto.Validity{NotAfter: passed}}, nil, ) p.EXPECT().Path().AnyTimes().Return(&spath.Path{Raw: []byte("isd local path")}) @@ -116,7 +119,8 @@ func TestCSRouterChooseServer(t *testing.T) { "Remote ISD, DB error": { ISD: 2, Expect: func(db *mock_v2.MockDB, r *mock_snet.MockRouter, p *mock_snet.MockPath) { - db.EXPECT().GetTRCInfo(gomock.Any(), addr.ISD(2), scrypto.Version(0)).Return( + db.EXPECT().GetTRCInfo(gomock.Any(), + trust.TRCID{ISD: addr.ISD(2), Version: scrypto.LatestVer}).Return( trust.TRCInfo{}, common.NewBasicError("DB error", nil), ) }, diff --git a/go/lib/infra/modules/trust/v2/rpc.go b/go/lib/infra/modules/trust/v2/rpc.go index 5bc138c328..0edf78910c 100644 --- a/go/lib/infra/modules/trust/v2/rpc.go +++ b/go/lib/infra/modules/trust/v2/rpc.go @@ -28,7 +28,7 @@ import ( // RPC abstracts the RPC calls over the messenger. type RPC interface { GetTRC(context.Context, TRCReq, net.Addr) ([]byte, error) - GetCertChain(ctx context.Context, msg ChainReq, a net.Addr) ([]byte, error) + GetCertChain(context.Context, ChainReq, net.Addr) ([]byte, error) SendTRC(context.Context, []byte, net.Addr) error SendCertChain(context.Context, []byte, net.Addr) error } diff --git a/go/lib/infra/modules/trust/v2/signer.go b/go/lib/infra/modules/trust/v2/signer.go index b786d56cbf..a4acb754c3 100644 --- a/go/lib/infra/modules/trust/v2/signer.go +++ b/go/lib/infra/modules/trust/v2/signer.go @@ -123,7 +123,8 @@ type SignerGen struct { // Signer returns the active signer. func (g *SignerGen) Signer(ctx context.Context) (*Signer, error) { - raw, err := g.Provider.GetRawChain(ctx, g.IA, scrypto.LatestVer, infra.ChainOpts{}, nil) + raw, err := g.Provider.GetRawChain(ctx, ChainID{IA: g.IA, Version: scrypto.LatestVer}, + infra.ChainOpts{}) if err != nil { return nil, serrors.WrapStr("error fetching latest chain", err) } @@ -144,7 +145,8 @@ func (g *SignerGen) Signer(ctx context.Context) (*Signer, error) { return nil, serrors.WrapStr("public key does not match", err, "chain_version", dec.AS.Version) } - trc, err := g.Provider.GetTRC(ctx, g.IA.I, scrypto.LatestVer, infra.TRCOpts{}) + trc, err := g.Provider.GetTRC(ctx, TRCID{ISD: g.IA.I, Version: scrypto.LatestVer}, + infra.TRCOpts{}) if err != nil { return nil, serrors.WrapStr("unable to get latest TRC", err) } diff --git a/go/lib/infra/modules/trust/v2/signer_test.go b/go/lib/infra/modules/trust/v2/signer_test.go index e99c8af493..bf2831fcac 100644 --- a/go/lib/infra/modules/trust/v2/signer_test.go +++ b/go/lib/infra/modules/trust/v2/signer_test.go @@ -196,8 +196,9 @@ func TestSignerGenSigner(t *testing.T) { "chain lookup fails": { Provider: func(t *testing.T, ctrl *gomock.Controller) trust.CryptoProvider { p := mock_v2.NewMockCryptoProvider(ctrl) - p.EXPECT().GetRawChain(gomock.Any(), ia110, scrypto.LatestVer, - infra.ChainOpts{}, nil).Return(nil, internal) + p.EXPECT().GetRawChain(gomock.Any(), + trust.ChainID{IA: ia110, Version: scrypto.LatestVer}, + infra.ChainOpts{}).Return(nil, internal) return p }, KeyRing: func(t *testing.T, ctrl *gomock.Controller) trust.KeyRing { @@ -208,8 +209,9 @@ func TestSignerGenSigner(t *testing.T) { "garbage chain": { Provider: func(t *testing.T, ctrl *gomock.Controller) trust.CryptoProvider { p := mock_v2.NewMockCryptoProvider(ctrl) - p.EXPECT().GetRawChain(gomock.Any(), ia110, scrypto.LatestVer, - infra.ChainOpts{}, nil).Return([]byte("garbage"), nil) + p.EXPECT().GetRawChain(gomock.Any(), + trust.ChainID{IA: ia110, Version: scrypto.LatestVer}, + infra.ChainOpts{}).Return([]byte("garbage"), nil) return p }, KeyRing: func(t *testing.T, ctrl *gomock.Controller) trust.KeyRing { @@ -220,8 +222,9 @@ func TestSignerGenSigner(t *testing.T) { "key not found": { Provider: func(t *testing.T, ctrl *gomock.Controller) trust.CryptoProvider { p := mock_v2.NewMockCryptoProvider(ctrl) - p.EXPECT().GetRawChain(gomock.Any(), ia110, scrypto.LatestVer, - infra.ChainOpts{}, nil).Return(loadChain(t, chain110v1).Raw, nil) + p.EXPECT().GetRawChain(gomock.Any(), + trust.ChainID{IA: ia110, Version: scrypto.LatestVer}, + infra.ChainOpts{}).Return(loadChain(t, chain110v1).Raw, nil) return p }, KeyRing: func(t *testing.T, ctrl *gomock.Controller) trust.KeyRing { @@ -234,8 +237,9 @@ func TestSignerGenSigner(t *testing.T) { "garbage private key": { Provider: func(t *testing.T, ctrl *gomock.Controller) trust.CryptoProvider { p := mock_v2.NewMockCryptoProvider(ctrl) - p.EXPECT().GetRawChain(gomock.Any(), ia110, scrypto.LatestVer, - infra.ChainOpts{}, nil).Return(loadChain(t, chain110v1).Raw, nil) + p.EXPECT().GetRawChain(gomock.Any(), + trust.ChainID{IA: ia110, Version: scrypto.LatestVer}, + infra.ChainOpts{}).Return(loadChain(t, chain110v1).Raw, nil) return p }, KeyRing: func(t *testing.T, ctrl *gomock.Controller) trust.KeyRing { @@ -251,8 +255,9 @@ func TestSignerGenSigner(t *testing.T) { "differing key": { Provider: func(t *testing.T, ctrl *gomock.Controller) trust.CryptoProvider { p := mock_v2.NewMockCryptoProvider(ctrl) - p.EXPECT().GetRawChain(gomock.Any(), ia110, scrypto.LatestVer, - infra.ChainOpts{}, nil).Return(loadChain(t, chain110v1).Raw, nil) + p.EXPECT().GetRawChain(gomock.Any(), + trust.ChainID{IA: ia110, Version: scrypto.LatestVer}, + infra.ChainOpts{}).Return(loadChain(t, chain110v1).Raw, nil) return p }, KeyRing: func(t *testing.T, ctrl *gomock.Controller) trust.KeyRing { @@ -268,9 +273,11 @@ func TestSignerGenSigner(t *testing.T) { "getting TRC fails": { Provider: func(t *testing.T, ctrl *gomock.Controller) trust.CryptoProvider { p := mock_v2.NewMockCryptoProvider(ctrl) - p.EXPECT().GetRawChain(gomock.Any(), ia110, scrypto.LatestVer, - infra.ChainOpts{}, nil).Return(loadChain(t, chain110v1).Raw, nil) - p.EXPECT().GetTRC(gomock.Any(), ia110.I, scrypto.LatestVer, gomock.Any()).Return( + p.EXPECT().GetRawChain(gomock.Any(), + trust.ChainID{IA: ia110, Version: scrypto.LatestVer}, + infra.ChainOpts{}).Return(loadChain(t, chain110v1).Raw, nil) + p.EXPECT().GetTRC(gomock.Any(), + trust.TRCID{ia110.I, scrypto.LatestVer}, infra.TRCOpts{}).Return( nil, internal, ) return p @@ -286,9 +293,11 @@ func TestSignerGenSigner(t *testing.T) { "invalid IA": { Provider: func(t *testing.T, ctrl *gomock.Controller) trust.CryptoProvider { p := mock_v2.NewMockCryptoProvider(ctrl) - p.EXPECT().GetRawChain(gomock.Any(), ia110, scrypto.LatestVer, - infra.ChainOpts{}, nil).Return(loadChain(t, chain110v1).Raw, nil) - p.EXPECT().GetTRC(gomock.Any(), ia110.I, scrypto.LatestVer, gomock.Any()).Return( + p.EXPECT().GetRawChain(gomock.Any(), + trust.ChainID{IA: ia110, Version: scrypto.LatestVer}, + infra.ChainOpts{}).Return(loadChain(t, chain110v1).Raw, nil) + p.EXPECT().GetTRC(gomock.Any(), + trust.TRCID{ISD: ia110.I, Version: scrypto.LatestVer}, infra.TRCOpts{}).Return( loadTRC(t, trc1v1).TRC, nil, ) return p @@ -306,9 +315,11 @@ func TestSignerGenSigner(t *testing.T) { "valid": { Provider: func(t *testing.T, ctrl *gomock.Controller) trust.CryptoProvider { p := mock_v2.NewMockCryptoProvider(ctrl) - p.EXPECT().GetRawChain(gomock.Any(), ia110, scrypto.LatestVer, - infra.ChainOpts{}, nil).Return(loadChain(t, chain110v1).Raw, nil) - p.EXPECT().GetTRC(gomock.Any(), ia110.I, scrypto.LatestVer, gomock.Any()).Return( + p.EXPECT().GetRawChain(gomock.Any(), + trust.ChainID{IA: ia110, Version: scrypto.LatestVer}, + infra.ChainOpts{}).Return(loadChain(t, chain110v1).Raw, nil) + p.EXPECT().GetTRC(gomock.Any(), + trust.TRCID{ISD: ia110.I, Version: scrypto.LatestVer}, infra.TRCOpts{}).Return( loadTRC(t, trc1v1).TRC, nil, ) return p diff --git a/go/lib/infra/modules/trust/v2/trustdbsqlite/db.go b/go/lib/infra/modules/trust/v2/trustdbsqlite/db.go index 594f3848f0..ae700d13a4 100644 --- a/go/lib/infra/modules/trust/v2/trustdbsqlite/db.go +++ b/go/lib/infra/modules/trust/v2/trustdbsqlite/db.go @@ -120,18 +120,17 @@ func (e *executor) TRCExists(ctx context.Context, d decoded.TRC) (bool, error) { return trcExists(ctx, e.db, d) } -func (e *executor) GetTRC(ctx context.Context, isd addr.ISD, - version scrypto.Version) (*trc.TRC, error) { +func (e *executor) GetTRC(ctx context.Context, id trust.TRCID) (*trc.TRC, error) { e.RLock() defer e.RUnlock() var pld []byte query := `SELECT pld FROM trcs WHERE isd_id=? AND version=?` - if version.IsLatest() { + if id.Version.IsLatest() { query = `SELECT pld FROM (SELECT pld, max(version) FROM trcs WHERE isd_id=?) WHERE pld IS NOT NULL` } - err := e.db.QueryRowContext(ctx, query, isd, version).Scan(&pld) + err := e.db.QueryRowContext(ctx, query, id.ISD, id.Version).Scan(&pld) switch { case err == sql.ErrNoRows: return nil, trust.ErrNotFound @@ -141,38 +140,36 @@ func (e *executor) GetTRC(ctx context.Context, isd addr.ISD, return trc.Encoded(pld).Decode() } -func (e *executor) GetRawTRC(ctx context.Context, isd addr.ISD, - version scrypto.Version) ([]byte, error) { +func (e *executor) GetRawTRC(ctx context.Context, id trust.TRCID) ([]byte, error) { e.RLock() defer e.RUnlock() query := `SELECT raw FROM trcs WHERE isd_id=? AND version=?` - if version.IsLatest() { + if id.Version.IsLatest() { query = `SELECT raw FROM (SELECT raw, max(version) FROM trcs WHERE isd_id=?) WHERE raw IS NOT NULL` } var raw []byte - err := e.db.QueryRowContext(ctx, query, isd, version).Scan(&raw) + err := e.db.QueryRowContext(ctx, query, id.ISD, id.Version).Scan(&raw) if err == sql.ErrNoRows { return nil, trust.ErrNotFound } return raw, err } -func (e *executor) GetTRCInfo(ctx context.Context, isd addr.ISD, - version scrypto.Version) (trust.TRCInfo, error) { - +func (e *executor) GetTRCInfo(ctx context.Context, id trust.TRCID) (trust.TRCInfo, error) { e.RLock() defer e.RUnlock() query := `SELECT version, not_before, not_after, grace_period from trcs WHERE isd_id=? AND version=?` - if version.IsLatest() { + if id.Version.IsLatest() { query = `SELECT max(version), not_before, not_after, grace_period from trcs WHERE isd_id=?` } var ver scrypto.Version var grace int var notBefore, notAfter uint32 - err := e.db.QueryRowContext(ctx, query, isd, version).Scan(&ver, ¬Before, ¬After, &grace) + err := e.db.QueryRowContext(ctx, query, id.ISD, id.Version). + Scan(&ver, ¬Before, ¬After, &grace) switch { case err == sql.ErrNoRows: return trust.TRCInfo{}, trust.ErrNotFound @@ -196,7 +193,7 @@ func (e *executor) GetIssuingKeyInfo(ctx context.Context, ia addr.IA, // we chose the simple way to implement this, if this ever becomes a // performance bottleneck we can still add a separate table which contains // this information. - t, err := e.GetTRC(ctx, ia.I, version) + t, err := e.GetTRC(ctx, trust.TRCID{ISD: ia.I, Version: version}) if err != nil { return trust.KeyInfo{}, err } @@ -250,18 +247,17 @@ func (e *executor) InsertTRC(ctx context.Context, d decoded.TRC) (bool, error) { return inserted, nil } -func (e *executor) GetRawChain(ctx context.Context, ia addr.IA, - version scrypto.Version) ([]byte, error) { +func (e *executor) GetRawChain(ctx context.Context, id trust.ChainID) ([]byte, error) { e.RLock() defer e.RUnlock() query := `SELECT raw FROM chains WHERE isd_id=? AND as_id=? AND version=?` - if version.IsLatest() { + if id.Version.IsLatest() { query = `SELECT raw FROM (SELECT raw, max(version) FROM chains WHERE isd_id=? AND as_id=?) WHERE raw IS NOT NULL` } var raw []byte - err := e.db.QueryRowContext(ctx, query, ia.I, ia.A, version).Scan(&raw) + err := e.db.QueryRowContext(ctx, query, id.IA.I, id.IA.A, id.Version).Scan(&raw) if err == sql.ErrNoRows { return nil, trust.ErrNotFound } diff --git a/go/lib/infra/modules/trust/v2/trustdbtest/trustdbtest.go b/go/lib/infra/modules/trust/v2/trustdbtest/trustdbtest.go index 31d91056c9..b414e78d96 100644 --- a/go/lib/infra/modules/trust/v2/trustdbtest/trustdbtest.go +++ b/go/lib/infra/modules/trust/v2/trustdbtest/trustdbtest.go @@ -139,28 +139,28 @@ func testTRC(t *testing.T, db trust.ReadWrite, cfg Config) { }) t.Run("GetTRC", func(t *testing.T) { // Fetch existing TRC. - fetched, err := db.GetTRC(ctx, v1.TRC.ISD, v1.TRC.Version) + fetched, err := db.GetTRC(ctx, trust.TRCID{ISD: v1.TRC.ISD, Version: v1.TRC.Version}) assert.NoError(t, err) assert.Equal(t, v1.TRC, fetched) // Fetch max of existing TRC. - max, err := db.GetTRC(ctx, v1.TRC.ISD, scrypto.LatestVer) + max, err := db.GetTRC(ctx, trust.TRCID{ISD: v1.TRC.ISD, Version: scrypto.LatestVer}) assert.NoError(t, err) assert.Equal(t, v7.TRC, max) // Fetch inexistent TRC. - _, err = db.GetTRC(ctx, 42, scrypto.LatestVer) + _, err = db.GetTRC(ctx, trust.TRCID{ISD: 42, Version: scrypto.LatestVer}) xtest.AssertErrorsIs(t, err, trust.ErrNotFound) }) t.Run("GetRawTRC", func(t *testing.T) { // Fetch existing TRC. - fetched, err := db.GetRawTRC(ctx, v1.TRC.ISD, v1.TRC.Version) + fetched, err := db.GetRawTRC(ctx, trust.TRCID{ISD: v1.TRC.ISD, Version: v1.TRC.Version}) assert.NoError(t, err) assert.Equal(t, v1.Raw, fetched) // Fetch max of existing TRC. - max, err := db.GetRawTRC(ctx, v1.TRC.ISD, scrypto.LatestVer) + max, err := db.GetRawTRC(ctx, trust.TRCID{ISD: v1.TRC.ISD, Version: scrypto.LatestVer}) assert.NoError(t, err) assert.Equal(t, v7.Raw, max) // Fetch inexistent TRC. - _, err = db.GetRawTRC(ctx, 42, scrypto.LatestVer) + _, err = db.GetRawTRC(ctx, trust.TRCID{ISD: 42, Version: scrypto.LatestVer}) xtest.AssertErrorsIs(t, err, trust.ErrNotFound) }) t.Run("GetTRCInfo", func(t *testing.T) { @@ -170,7 +170,7 @@ func testTRC(t *testing.T, db trust.ReadWrite, cfg Config) { GracePeriod: v1.TRC.GracePeriod.Duration, Version: v1.TRC.Version, } - fetched, err := db.GetTRCInfo(ctx, v1.TRC.ISD, v1.TRC.Version) + fetched, err := db.GetTRCInfo(ctx, trust.TRCID{ISD: v1.TRC.ISD, Version: v1.TRC.Version}) assert.NoError(t, err) assert.Equal(t, info, fetched) // Fetch max of existing TRC. @@ -179,11 +179,11 @@ func testTRC(t *testing.T, db trust.ReadWrite, cfg Config) { GracePeriod: v7.TRC.GracePeriod.Duration, Version: v7.TRC.Version, } - max, err := db.GetTRCInfo(ctx, v1.TRC.ISD, scrypto.LatestVer) + max, err := db.GetTRCInfo(ctx, trust.TRCID{ISD: v1.TRC.ISD, Version: scrypto.LatestVer}) assert.NoError(t, err) assert.Equal(t, info, max) // Fetch inexistent TRC. - _, err = db.GetTRCInfo(ctx, 42, scrypto.LatestVer) + _, err = db.GetTRCInfo(ctx, trust.TRCID{ISD: 42, Version: scrypto.LatestVer}) xtest.AssertErrorsIs(t, err, trust.ErrNotFound) }) t.Run("GetIssuingKeyInfo", func(t *testing.T) { @@ -297,15 +297,20 @@ func testChain(t *testing.T, db trust.ReadWrite, cfg Config) { }) t.Run("GetRawChain", func(t *testing.T) { // Check existing certificate chain. - fetched, err := db.GetRawChain(ctx, v1.AS.Subject, v1.AS.Version) + fetched, err := db.GetRawChain(ctx, trust.ChainID{ + IA: v1.AS.Subject, Version: v1.AS.Version}) assert.NoError(t, err) assert.Equal(t, v1.Raw, fetched) // Check max of existing certificate chain. - max, err := db.GetRawChain(ctx, v1.AS.Subject, scrypto.LatestVer) + max, err := db.GetRawChain(ctx, trust.ChainID{ + IA: v1.AS.Subject, Version: scrypto.LatestVer}) assert.NoError(t, err) assert.Equal(t, v7.Raw, max) // Check inexistent certificate chain. - _, err = db.GetRawChain(ctx, xtest.MustParseIA("42-ff00:0:142"), scrypto.LatestVer) + _, err = db.GetRawChain(ctx, trust.ChainID{ + IA: xtest.MustParseIA("42-ff00:0:142"), + Version: scrypto.LatestVer, + }) xtest.AssertErrorsIs(t, err, trust.ErrNotFound) }) t.Run("InsertChain", func(t *testing.T) { @@ -352,7 +357,7 @@ func testRollback(t *testing.T, db trust.DB, cfg Config) { err = tx.Rollback() assert.NoError(t, err) // Check that TRC is not in database after rollback. - _, err = db.GetTRCInfo(ctx, 1, scrypto.LatestVer) + _, err = db.GetTRCInfo(ctx, trust.TRCID{ISD: 1, Version: scrypto.LatestVer}) xtest.AssertErrorsIs(t, err, trust.ErrNotFound) } diff --git a/go/lib/infra/modules/trust/v2/verifier.go b/go/lib/infra/modules/trust/v2/verifier.go index f11fc64492..f058c4eefd 100644 --- a/go/lib/infra/modules/trust/v2/verifier.go +++ b/go/lib/infra/modules/trust/v2/verifier.go @@ -106,7 +106,7 @@ func (v *verifier) Verify(ctx context.Context, msg []byte, sign *proto.SignS) er } id := ChainID{IA: src.IA, Version: src.ChainVer} - opts := &infra.ChainOpts{ + opts := infra.ChainOpts{ TrustStoreOpts: infra.TrustStoreOpts{Server: v.Server}, } diff --git a/go/lib/infra/modules/trust/v2/verifier_test.go b/go/lib/infra/modules/trust/v2/verifier_test.go index fddfe56991..42ba300810 100644 --- a/go/lib/infra/modules/trust/v2/verifier_test.go +++ b/go/lib/infra/modules/trust/v2/verifier_test.go @@ -103,7 +103,7 @@ func TestVerify(t *testing.T) { defer ctrl.Finish() p := mock_v2.NewMockCryptoProvider(ctrl) p.EXPECT().GetASKey(gomock.Any(), gomock.Any(), - gomock.Any()).Return(&scrypto.KeyMeta{Key: public, Algorithm: scrypto.Ed25519}, nil) + gomock.Any()).Return(scrypto.KeyMeta{Key: public, Algorithm: scrypto.Ed25519}, nil) v := &trust.Verifier{ Store: p,