diff --git a/piecestore/impl/piecestore.go b/piecestore/impl/piecestore.go index 60a7b904..f82cf84a 100644 --- a/piecestore/impl/piecestore.go +++ b/piecestore/impl/piecestore.go @@ -8,12 +8,14 @@ import ( "github.com/ipfs/go-datastore" "github.com/ipfs/go-datastore/namespace" logging "github.com/ipfs/go-log/v2" + "golang.org/x/xerrors" versioning "github.com/filecoin-project/go-ds-versioning/pkg" versioned "github.com/filecoin-project/go-ds-versioning/pkg/statestore" "github.com/filecoin-project/go-fil-markets/piecestore" "github.com/filecoin-project/go-fil-markets/piecestore/migrations" + "github.com/filecoin-project/go-fil-markets/retrievalmarket" "github.com/filecoin-project/go-fil-markets/shared" ) @@ -144,6 +146,9 @@ func (ps *pieceStore) ListCidInfoKeys() ([]cid.Cid, error) { func (ps *pieceStore) GetPieceInfo(pieceCID cid.Cid) (piecestore.PieceInfo, error) { var out piecestore.PieceInfo if err := ps.pieces.Get(pieceCID).Get(&out); err != nil { + if xerrors.Is(err, datastore.ErrNotFound) { + return piecestore.PieceInfo{}, xerrors.Errorf("piece with CID %s: %w", pieceCID, retrievalmarket.ErrNotFound) + } return piecestore.PieceInfo{}, err } return out, nil @@ -153,6 +158,9 @@ func (ps *pieceStore) GetPieceInfo(pieceCID cid.Cid) (piecestore.PieceInfo, erro func (ps *pieceStore) GetCIDInfo(payloadCID cid.Cid) (piecestore.CIDInfo, error) { var out piecestore.CIDInfo if err := ps.cidInfos.Get(payloadCID).Get(&out); err != nil { + if xerrors.Is(err, datastore.ErrNotFound) { + return piecestore.CIDInfo{}, xerrors.Errorf("payload CID %s: %w", payloadCID, retrievalmarket.ErrNotFound) + } return piecestore.CIDInfo{}, err } return out, nil diff --git a/piecestore/impl/piecestore_test.go b/piecestore/impl/piecestore_test.go index c9e06c6d..d7c5eb98 100644 --- a/piecestore/impl/piecestore_test.go +++ b/piecestore/impl/piecestore_test.go @@ -11,6 +11,7 @@ import ( "github.com/ipfs/go-datastore/namespace" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.org/x/xerrors" "github.com/filecoin-project/go-state-types/abi" "github.com/filecoin-project/go-statestore" @@ -18,12 +19,14 @@ import ( "github.com/filecoin-project/go-fil-markets/piecestore" piecestoreimpl "github.com/filecoin-project/go-fil-markets/piecestore/impl" "github.com/filecoin-project/go-fil-markets/piecestore/migrations" + "github.com/filecoin-project/go-fil-markets/retrievalmarket" "github.com/filecoin-project/go-fil-markets/shared_testutil" ) func TestStorePieceInfo(t *testing.T) { ctx := context.Background() pieceCid := shared_testutil.GenerateCids(1)[0] + pieceCid2 := shared_testutil.GenerateCids(1)[0] initializePieceStore := func(t *testing.T, ctx context.Context) piecestore.PieceStore { ps, err := piecestoreimpl.NewPieceStore(datastore.NewMapDatastore()) require.NoError(t, err) @@ -51,6 +54,11 @@ func TestStorePieceInfo(t *testing.T) { assert.NoError(t, err) assert.Len(t, pi.Deals, 1) assert.Equal(t, pi.Deals[0], dealInfo) + + // Verify that getting a piece with a non-existent CID returns ErrNotFound + pi, err = ps.GetPieceInfo(pieceCid2) + assert.Error(t, err) + assert.True(t, xerrors.Is(err, retrievalmarket.ErrNotFound)) }) t.Run("adding same deal twice does not dup", func(t *testing.T) { @@ -86,7 +94,7 @@ func TestStoreCIDInfo(t *testing.T) { pieceCids := shared_testutil.GenerateCids(2) pieceCid1 := pieceCids[0] pieceCid2 := pieceCids[1] - testCIDs := shared_testutil.GenerateCids(3) + testCIDs := shared_testutil.GenerateCids(4) blockLocations := make([]piecestore.BlockLocation, 0, 3) for i := 0; i < 3; i++ { blockLocations = append(blockLocations, piecestore.BlockLocation{ @@ -129,6 +137,11 @@ func TestStoreCIDInfo(t *testing.T) { assert.NoError(t, err) assert.Len(t, ci.PieceBlockLocations, 1) assert.Equal(t, ci.PieceBlockLocations[0], piecestore.PieceBlockLocation{BlockLocation: blockLocations[2], PieceCID: pieceCid1}) + + // Verify that getting CID info with a non-existent CID returns ErrNotFound + ci, err = ps.GetCIDInfo(testCIDs[3]) + assert.Error(t, err) + assert.True(t, xerrors.Is(err, retrievalmarket.ErrNotFound)) }) t.Run("overlapping adds", func(t *testing.T) {