From b41ace040794d456d4b48a86179a81ddcd57cd16 Mon Sep 17 00:00:00 2001 From: Dirk McCormick Date: Tue, 28 Jul 2020 19:16:47 -0400 Subject: [PATCH] WIP: fix payment channel locking --- api/api_full.go | 3 +- api/apistruct/struct.go | 11 +- api/test/paych.go | 12 +- cli/paych.go | 2 +- gen/main.go | 1 + node/builder.go | 2 + node/impl/paych/paych.go | 29 +- paychmgr/accessorcache.go | 67 ++++ paychmgr/cbor_gen.go | 419 ++++++++++++++++++- paychmgr/channellock.go | 33 ++ paychmgr/manager.go | 278 +++++++++++++ paychmgr/msglistener.go | 58 +++ paychmgr/msglistener_test.go | 96 +++++ paychmgr/paych.go | 296 ++++++++------ paychmgr/paych_test.go | 34 +- paychmgr/paychget_test.go | 756 +++++++++++++++++++++++++++++++++++ paychmgr/settle_test.go | 72 ++++ paychmgr/simple.go | 479 +++++++++++++++++++--- paychmgr/state.go | 98 +---- paychmgr/store.go | 376 ++++++++++++----- paychmgr/store_test.go | 12 +- 21 files changed, 2729 insertions(+), 405 deletions(-) create mode 100644 paychmgr/accessorcache.go create mode 100644 paychmgr/channellock.go create mode 100644 paychmgr/manager.go create mode 100644 paychmgr/msglistener.go create mode 100644 paychmgr/msglistener_test.go create mode 100644 paychmgr/paychget_test.go create mode 100644 paychmgr/settle_test.go diff --git a/api/api_full.go b/api/api_full.go index 4dc39581fc..830369fe45 100644 --- a/api/api_full.go +++ b/api/api_full.go @@ -376,7 +376,8 @@ type FullNode interface { // MethodGroup: Paych // The Paych methods are for interacting with and managing payment channels - PaychGet(ctx context.Context, from, to address.Address, ensureFunds types.BigInt) (*ChannelInfo, error) + PaychGet(ctx context.Context, from, to address.Address, amt types.BigInt) (*ChannelInfo, error) + PaychGetWaitReady(context.Context, cid.Cid) (address.Address, error) PaychList(context.Context) ([]address.Address, error) PaychStatus(context.Context, address.Address) (*PaychStatus, error) PaychSettle(context.Context, address.Address) (cid.Cid, error) diff --git a/api/apistruct/struct.go b/api/apistruct/struct.go index 3e22856485..925b827c1f 100644 --- a/api/apistruct/struct.go +++ b/api/apistruct/struct.go @@ -183,7 +183,8 @@ type FullNodeStruct struct { MarketEnsureAvailable func(context.Context, address.Address, address.Address, types.BigInt) (cid.Cid, error) `perm:"sign"` - PaychGet func(ctx context.Context, from, to address.Address, ensureFunds types.BigInt) (*api.ChannelInfo, error) `perm:"sign"` + PaychGet func(ctx context.Context, from, to address.Address, amt types.BigInt) (*api.ChannelInfo, error) `perm:"sign"` + PaychGetWaitReady func(context.Context, cid.Cid) (address.Address, error) `perm:"sign"` // TODO: is perm:"sign" correct? PaychList func(context.Context) ([]address.Address, error) `perm:"read"` PaychStatus func(context.Context, address.Address) (*api.PaychStatus, error) `perm:"read"` PaychSettle func(context.Context, address.Address) (cid.Cid, error) `perm:"sign"` @@ -801,8 +802,12 @@ func (c *FullNodeStruct) MarketEnsureAvailable(ctx context.Context, addr, wallet return c.Internal.MarketEnsureAvailable(ctx, addr, wallet, amt) } -func (c *FullNodeStruct) PaychGet(ctx context.Context, from, to address.Address, ensureFunds types.BigInt) (*api.ChannelInfo, error) { - return c.Internal.PaychGet(ctx, from, to, ensureFunds) +func (c *FullNodeStruct) PaychGet(ctx context.Context, from, to address.Address, amt types.BigInt) (*api.ChannelInfo, error) { + return c.Internal.PaychGet(ctx, from, to, amt) +} + +func (c *FullNodeStruct) PaychGetWaitReady(ctx context.Context, mcid cid.Cid) (address.Address, error) { + return c.Internal.PaychGetWaitReady(ctx, mcid) } func (c *FullNodeStruct) PaychList(ctx context.Context) ([]address.Address, error) { diff --git a/api/test/paych.go b/api/test/paych.go index 1684413a90..8904c7de0e 100644 --- a/api/test/paych.go +++ b/api/test/paych.go @@ -1,18 +1,17 @@ package test import ( - "bytes" "context" "fmt" - "github.com/filecoin-project/specs-actors/actors/builtin" "os" "sync/atomic" "testing" "time" + "github.com/filecoin-project/specs-actors/actors/builtin" + "github.com/filecoin-project/specs-actors/actors/abi" "github.com/filecoin-project/specs-actors/actors/abi/big" - initactor "github.com/filecoin-project/specs-actors/actors/builtin/init" "github.com/filecoin-project/specs-actors/actors/builtin/paych" "github.com/ipfs/go-cid" @@ -77,13 +76,10 @@ func TestPaymentChannels(t *testing.T, b APIBuilder, blocktime time.Duration) { t.Fatal(err) } - res := waitForMessage(ctx, t, paymentCreator, channelInfo.ChannelMessage, time.Second, "channel create") - var params initactor.ExecReturn - err = params.UnmarshalCBOR(bytes.NewReader(res.Receipt.Return)) + channel, err := paymentCreator.PaychGetWaitReady(ctx, channelInfo.ChannelMessage) if err != nil { t.Fatal(err) } - channel := params.RobustAddress // allocate three lanes var lanes []uint64 @@ -129,7 +125,7 @@ func TestPaymentChannels(t *testing.T, b APIBuilder, blocktime time.Duration) { t.Fatal(err) } - res = waitForMessage(ctx, t, paymentCreator, settleMsgCid, time.Second*10, "settle") + res := waitForMessage(ctx, t, paymentCreator, settleMsgCid, time.Second*10, "settle") if res.Receipt.ExitCode != 0 { t.Fatal("Unable to settle payment channel") } diff --git a/cli/paych.go b/cli/paych.go index 969a36df6b..05dc1f319c 100644 --- a/cli/paych.go +++ b/cli/paych.go @@ -28,7 +28,7 @@ var paychCmd = &cli.Command{ var paychGetCmd = &cli.Command{ Name: "get", - Usage: "Create a new payment channel or get existing one", + Usage: "Create a new payment channel or get existing one and add amount to it", ArgsUsage: "[fromAddress toAddress amount]", Action: func(cctx *cli.Context) error { if cctx.Args().Len() != 3 { diff --git a/gen/main.go b/gen/main.go index 01cd756f78..1467a8943d 100644 --- a/gen/main.go +++ b/gen/main.go @@ -35,6 +35,7 @@ func main() { err = gen.WriteMapEncodersToFile("./paychmgr/cbor_gen.go", "paychmgr", paychmgr.VoucherInfo{}, paychmgr.ChannelInfo{}, + paychmgr.MsgInfo{}, ) if err != nil { fmt.Println(err) diff --git a/node/builder.go b/node/builder.go index 171c4a96e3..2561d0ef9b 100644 --- a/node/builder.go +++ b/node/builder.go @@ -108,6 +108,7 @@ const ( HandleIncomingMessagesKey RegisterClientValidatorKey + HandlePaymentChannelManagerKey // miner GetParamsKey @@ -274,6 +275,7 @@ func Online() Option { Override(new(*paychmgr.Store), paychmgr.NewStore), Override(new(*paychmgr.Manager), paychmgr.NewManager), + Override(HandlePaymentChannelManagerKey, paychmgr.HandleManager), Override(new(*market.FundMgr), market.NewFundMgr), Override(SettlePaymentChannelsKey, settler.SettlePaymentChannels), ), diff --git a/node/impl/paych/paych.go b/node/impl/paych/paych.go index c9f2f215da..8e28979f55 100644 --- a/node/impl/paych/paych.go +++ b/node/impl/paych/paych.go @@ -28,8 +28,8 @@ type PaychAPI struct { PaychMgr *paychmgr.Manager } -func (a *PaychAPI) PaychGet(ctx context.Context, from, to address.Address, ensureFunds types.BigInt) (*api.ChannelInfo, error) { - ch, mcid, err := a.PaychMgr.GetPaych(ctx, from, to, ensureFunds) +func (a *PaychAPI) PaychGet(ctx context.Context, from, to address.Address, amt types.BigInt) (*api.ChannelInfo, error) { + ch, mcid, err := a.PaychMgr.GetPaych(ctx, from, to, amt) if err != nil { return nil, err } @@ -40,6 +40,10 @@ func (a *PaychAPI) PaychGet(ctx context.Context, from, to address.Address, ensur }, nil } +func (a *PaychAPI) PaychGetWaitReady(ctx context.Context, mcid cid.Cid) (address.Address, error) { + return a.PaychMgr.GetPaychWaitReady(ctx, mcid) +} + func (a *PaychAPI) PaychAllocateLane(ctx context.Context, ch address.Address) (uint64, error) { return a.PaychMgr.AllocateLane(ch) } @@ -66,7 +70,7 @@ func (a *PaychAPI) PaychNewPayment(ctx context.Context, from, to address.Address ChannelAddr: ch.Channel, Amount: v.Amount, - Lane: uint64(lane), + Lane: lane, Extra: v.Extra, TimeLockMin: v.TimeLockMin, @@ -108,24 +112,7 @@ func (a *PaychAPI) PaychStatus(ctx context.Context, pch address.Address) (*api.P } func (a *PaychAPI) PaychSettle(ctx context.Context, addr address.Address) (cid.Cid, error) { - - ci, err := a.PaychMgr.GetChannelInfo(addr) - if err != nil { - return cid.Undef, err - } - - msg := &types.Message{ - To: addr, - From: ci.Control, - Value: types.NewInt(0), - Method: builtin.MethodsPaych.Settle, - } - smgs, err := a.MpoolPushMessage(ctx, msg) - - if err != nil { - return cid.Undef, err - } - return smgs.Cid(), nil + return a.PaychMgr.Settle(ctx, addr) } func (a *PaychAPI) PaychCollect(ctx context.Context, addr address.Address) (cid.Cid, error) { diff --git a/paychmgr/accessorcache.go b/paychmgr/accessorcache.go new file mode 100644 index 0000000000..11f9d3bc96 --- /dev/null +++ b/paychmgr/accessorcache.go @@ -0,0 +1,67 @@ +package paychmgr + +import "github.com/filecoin-project/go-address" + +// accessorByFromTo gets a channel accessor for a given from / to pair. +// The channel accessor facilitates locking a channel so that operations +// must be performed sequentially on a channel (but can be performed at +// the same time on different channels). +func (pm *Manager) accessorByFromTo(from address.Address, to address.Address) (*channelAccessor, error) { + key := pm.accessorCacheKey(from, to) + + // First take a read lock and check the cache + pm.lk.RLock() + ca, ok := pm.channels[key] + pm.lk.RUnlock() + if ok { + return ca, nil + } + + // Not in cache, so take a write lock + pm.lk.Lock() + defer pm.lk.Unlock() + + // Need to check cache again in case it was updated between releasing read + // lock and taking write lock + ca, ok = pm.channels[key] + if !ok { + // Not in cache, so create a new one and store in cache + ca = pm.addAccessorToCache(from, to) + } + + return ca, nil +} + +// accessorByAddress gets a channel accessor for a given channel address. +// The channel accessor facilitates locking a channel so that operations +// must be performed sequentially on a channel (but can be performed at +// the same time on different channels). +func (pm *Manager) accessorByAddress(ch address.Address) (*channelAccessor, error) { + // Get the channel from / to + pm.lk.RLock() + channelInfo, err := pm.store.ByAddress(ch) + pm.lk.RUnlock() + if err != nil { + return nil, err + } + + // TODO: cache by channel address so we can get by address instead of using from / to + return pm.accessorByFromTo(channelInfo.Control, channelInfo.Target) +} + +// accessorCacheKey returns the cache key use to reference a channel accessor +func (pm *Manager) accessorCacheKey(from address.Address, to address.Address) string { + return from.String() + "->" + to.String() +} + +// addAccessorToCache adds a channel accessor to a cache. Note that channelInfo +// may be nil if the channel hasn't been created yet, but we still want to +// reference the same channel accessor for a given from/to, so that all +// attempts to access a channel use the same lock (the lock on the accessor) +func (pm *Manager) addAccessorToCache(from address.Address, to address.Address) *channelAccessor { + key := pm.accessorCacheKey(from, to) + ca := newChannelAccessor(pm) + // TODO: Use LRU + pm.channels[key] = ca + return ca +} diff --git a/paychmgr/cbor_gen.go b/paychmgr/cbor_gen.go index 8876f6c8ad..57666fe2d1 100644 --- a/paychmgr/cbor_gen.go +++ b/paychmgr/cbor_gen.go @@ -6,6 +6,7 @@ import ( "fmt" "io" + "github.com/filecoin-project/go-address" "github.com/filecoin-project/specs-actors/actors/builtin/paych" cbg "github.com/whyrusleeping/cbor-gen" xerrors "golang.org/x/xerrors" @@ -156,12 +157,35 @@ func (t *ChannelInfo) MarshalCBOR(w io.Writer) error { _, err := w.Write(cbg.CborNull) return err } - if _, err := w.Write([]byte{166}); err != nil { + if _, err := w.Write([]byte{172}); err != nil { return err } scratch := make([]byte, 9) + // t.ChannelID (string) (string) + if len("ChannelID") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"ChannelID\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("ChannelID"))); err != nil { + return err + } + if _, err := io.WriteString(w, string("ChannelID")); err != nil { + return err + } + + if len(t.ChannelID) > cbg.MaxLength { + return xerrors.Errorf("Value in field t.ChannelID was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len(t.ChannelID))); err != nil { + return err + } + if _, err := io.WriteString(w, string(t.ChannelID)); err != nil { + return err + } + // t.Channel (address.Address) (struct) if len("Channel") > cbg.MaxLength { return xerrors.Errorf("Value in field \"Channel\" was too long") @@ -267,6 +291,97 @@ func (t *ChannelInfo) MarshalCBOR(w io.Writer) error { return err } + // t.Amount (big.Int) (struct) + if len("Amount") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"Amount\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("Amount"))); err != nil { + return err + } + if _, err := io.WriteString(w, string("Amount")); err != nil { + return err + } + + if err := t.Amount.MarshalCBOR(w); err != nil { + return err + } + + // t.PendingAmount (big.Int) (struct) + if len("PendingAmount") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"PendingAmount\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("PendingAmount"))); err != nil { + return err + } + if _, err := io.WriteString(w, string("PendingAmount")); err != nil { + return err + } + + if err := t.PendingAmount.MarshalCBOR(w); err != nil { + return err + } + + // t.CreateMsg (cid.Cid) (struct) + if len("CreateMsg") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"CreateMsg\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("CreateMsg"))); err != nil { + return err + } + if _, err := io.WriteString(w, string("CreateMsg")); err != nil { + return err + } + + if t.CreateMsg == nil { + if _, err := w.Write(cbg.CborNull); err != nil { + return err + } + } else { + if err := cbg.WriteCidBuf(scratch, w, *t.CreateMsg); err != nil { + return xerrors.Errorf("failed to write cid field t.CreateMsg: %w", err) + } + } + + // t.AddFundsMsg (cid.Cid) (struct) + if len("AddFundsMsg") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"AddFundsMsg\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("AddFundsMsg"))); err != nil { + return err + } + if _, err := io.WriteString(w, string("AddFundsMsg")); err != nil { + return err + } + + if t.AddFundsMsg == nil { + if _, err := w.Write(cbg.CborNull); err != nil { + return err + } + } else { + if err := cbg.WriteCidBuf(scratch, w, *t.AddFundsMsg); err != nil { + return xerrors.Errorf("failed to write cid field t.AddFundsMsg: %w", err) + } + } + + // t.Settling (bool) (bool) + if len("Settling") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"Settling\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("Settling"))); err != nil { + return err + } + if _, err := io.WriteString(w, string("Settling")); err != nil { + return err + } + + if err := cbg.WriteBool(w, t.Settling); err != nil { + return err + } return nil } @@ -303,13 +418,36 @@ func (t *ChannelInfo) UnmarshalCBOR(r io.Reader) error { } switch name { - // t.Channel (address.Address) (struct) + // t.ChannelID (string) (string) + case "ChannelID": + + { + sval, err := cbg.ReadStringBuf(br, scratch) + if err != nil { + return err + } + + t.ChannelID = string(sval) + } + // t.Channel (address.Address) (struct) case "Channel": { - if err := t.Channel.UnmarshalCBOR(br); err != nil { - return xerrors.Errorf("unmarshaling t.Channel: %w", err) + pb, err := br.PeekByte() + if err != nil { + return err + } + if pb == cbg.CborNull[0] { + var nbuf [1]byte + if _, err := br.Read(nbuf[:]); err != nil { + return err + } + } else { + t.Channel = new(address.Address) + if err := t.Channel.UnmarshalCBOR(br); err != nil { + return xerrors.Errorf("unmarshaling t.Channel pointer: %w", err) + } } } @@ -393,6 +531,279 @@ func (t *ChannelInfo) UnmarshalCBOR(r io.Reader) error { t.NextLane = uint64(extra) } + // t.Amount (big.Int) (struct) + case "Amount": + + { + + if err := t.Amount.UnmarshalCBOR(br); err != nil { + return xerrors.Errorf("unmarshaling t.Amount: %w", err) + } + + } + // t.PendingAmount (big.Int) (struct) + case "PendingAmount": + + { + + if err := t.PendingAmount.UnmarshalCBOR(br); err != nil { + return xerrors.Errorf("unmarshaling t.PendingAmount: %w", err) + } + + } + // t.CreateMsg (cid.Cid) (struct) + case "CreateMsg": + + { + + pb, err := br.PeekByte() + if err != nil { + return err + } + if pb == cbg.CborNull[0] { + var nbuf [1]byte + if _, err := br.Read(nbuf[:]); err != nil { + return err + } + } else { + + c, err := cbg.ReadCid(br) + if err != nil { + return xerrors.Errorf("failed to read cid field t.CreateMsg: %w", err) + } + + t.CreateMsg = &c + } + + } + // t.AddFundsMsg (cid.Cid) (struct) + case "AddFundsMsg": + + { + + pb, err := br.PeekByte() + if err != nil { + return err + } + if pb == cbg.CborNull[0] { + var nbuf [1]byte + if _, err := br.Read(nbuf[:]); err != nil { + return err + } + } else { + + c, err := cbg.ReadCid(br) + if err != nil { + return xerrors.Errorf("failed to read cid field t.AddFundsMsg: %w", err) + } + + t.AddFundsMsg = &c + } + + } + // t.Settling (bool) (bool) + case "Settling": + + maj, extra, err = cbg.CborReadHeaderBuf(br, scratch) + if err != nil { + return err + } + if maj != cbg.MajOther { + return fmt.Errorf("booleans must be major type 7") + } + switch extra { + case 20: + t.Settling = false + case 21: + t.Settling = true + default: + return fmt.Errorf("booleans are either major type 7, value 20 or 21 (got %d)", extra) + } + + default: + return fmt.Errorf("unknown struct field %d: '%s'", i, name) + } + } + + return nil +} +func (t *MsgInfo) MarshalCBOR(w io.Writer) error { + if t == nil { + _, err := w.Write(cbg.CborNull) + return err + } + if _, err := w.Write([]byte{164}); err != nil { + return err + } + + scratch := make([]byte, 9) + + // t.ChannelID (string) (string) + if len("ChannelID") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"ChannelID\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("ChannelID"))); err != nil { + return err + } + if _, err := io.WriteString(w, string("ChannelID")); err != nil { + return err + } + + if len(t.ChannelID) > cbg.MaxLength { + return xerrors.Errorf("Value in field t.ChannelID was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len(t.ChannelID))); err != nil { + return err + } + if _, err := io.WriteString(w, string(t.ChannelID)); err != nil { + return err + } + + // t.MsgCid (cid.Cid) (struct) + if len("MsgCid") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"MsgCid\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("MsgCid"))); err != nil { + return err + } + if _, err := io.WriteString(w, string("MsgCid")); err != nil { + return err + } + + if err := cbg.WriteCidBuf(scratch, w, t.MsgCid); err != nil { + return xerrors.Errorf("failed to write cid field t.MsgCid: %w", err) + } + + // t.Received (bool) (bool) + if len("Received") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"Received\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("Received"))); err != nil { + return err + } + if _, err := io.WriteString(w, string("Received")); err != nil { + return err + } + + if err := cbg.WriteBool(w, t.Received); err != nil { + return err + } + + // t.Err (string) (string) + if len("Err") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"Err\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("Err"))); err != nil { + return err + } + if _, err := io.WriteString(w, string("Err")); err != nil { + return err + } + + if len(t.Err) > cbg.MaxLength { + return xerrors.Errorf("Value in field t.Err was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len(t.Err))); err != nil { + return err + } + if _, err := io.WriteString(w, string(t.Err)); err != nil { + return err + } + return nil +} + +func (t *MsgInfo) UnmarshalCBOR(r io.Reader) error { + *t = MsgInfo{} + + br := cbg.GetPeeker(r) + scratch := make([]byte, 8) + + maj, extra, err := cbg.CborReadHeaderBuf(br, scratch) + if err != nil { + return err + } + if maj != cbg.MajMap { + return fmt.Errorf("cbor input should be of type map") + } + + if extra > cbg.MaxLength { + return fmt.Errorf("MsgInfo: map struct too large (%d)", extra) + } + + var name string + n := extra + + for i := uint64(0); i < n; i++ { + + { + sval, err := cbg.ReadStringBuf(br, scratch) + if err != nil { + return err + } + + name = string(sval) + } + + switch name { + // t.ChannelID (string) (string) + case "ChannelID": + + { + sval, err := cbg.ReadStringBuf(br, scratch) + if err != nil { + return err + } + + t.ChannelID = string(sval) + } + // t.MsgCid (cid.Cid) (struct) + case "MsgCid": + + { + + c, err := cbg.ReadCid(br) + if err != nil { + return xerrors.Errorf("failed to read cid field t.MsgCid: %w", err) + } + + t.MsgCid = c + + } + // t.Received (bool) (bool) + case "Received": + + maj, extra, err = cbg.CborReadHeaderBuf(br, scratch) + if err != nil { + return err + } + if maj != cbg.MajOther { + return fmt.Errorf("booleans must be major type 7") + } + switch extra { + case 20: + t.Received = false + case 21: + t.Received = true + default: + return fmt.Errorf("booleans are either major type 7, value 20 or 21 (got %d)", extra) + } + // t.Err (string) (string) + case "Err": + + { + sval, err := cbg.ReadStringBuf(br, scratch) + if err != nil { + return err + } + + t.Err = string(sval) + } default: return fmt.Errorf("unknown struct field %d: '%s'", i, name) diff --git a/paychmgr/channellock.go b/paychmgr/channellock.go new file mode 100644 index 0000000000..0dc785ec0f --- /dev/null +++ b/paychmgr/channellock.go @@ -0,0 +1,33 @@ +package paychmgr + +import "sync" + +type rwlock interface { + RLock() + RUnlock() +} + +// channelLock manages locking for a specific channel. +// Some operations update the state of a single channel, and need to block +// other operations only on the same channel's state. +// Some operations update state that affects all channels, and need to block +// any operation against any channel. +type channelLock struct { + globalLock rwlock + chanLock sync.Mutex +} + +func (l *channelLock) Lock() { + // Wait for other operations by this channel to finish. + // Exclusive per-channel (no other ops by this channel allowed). + l.chanLock.Lock() + // Wait for operations affecting all channels to finish. + // Allows ops by other channels in parallel, but blocks all operations + // if global lock is taken exclusively (eg when adding a channel) + l.globalLock.RLock() +} + +func (l *channelLock) Unlock() { + l.globalLock.RUnlock() + l.chanLock.Unlock() +} diff --git a/paychmgr/manager.go b/paychmgr/manager.go new file mode 100644 index 0000000000..bdc9f516fc --- /dev/null +++ b/paychmgr/manager.go @@ -0,0 +1,278 @@ +package paychmgr + +import ( + "context" + "sync" + + "github.com/ipfs/go-datastore" + + "golang.org/x/sync/errgroup" + + xerrors "golang.org/x/xerrors" + + "github.com/filecoin-project/lotus/api" + + "github.com/filecoin-project/specs-actors/actors/builtin/paych" + + "github.com/ipfs/go-cid" + logging "github.com/ipfs/go-log/v2" + "go.uber.org/fx" + + "github.com/filecoin-project/go-address" + + "github.com/filecoin-project/lotus/chain/stmgr" + "github.com/filecoin-project/lotus/chain/types" + "github.com/filecoin-project/lotus/node/impl/full" +) + +var log = logging.Logger("paych") + +type ManagerApi struct { + fx.In + + full.MpoolAPI + full.WalletAPI + full.StateAPI +} + +type StateManagerApi interface { + LoadActorState(ctx context.Context, a address.Address, out interface{}, ts *types.TipSet) (*types.Actor, error) + Call(ctx context.Context, msg *types.Message, ts *types.TipSet) (*api.InvocResult, error) +} + +type Manager struct { + // The Manager context is used to terminate wait operations on shutdown + ctx context.Context + shutdown context.CancelFunc + + store *Store + sm StateManagerApi + sa *stateAccessor + pchapi paychApi + + lk sync.RWMutex + channels map[string]*channelAccessor + + mpool full.MpoolAPI + wallet full.WalletAPI + state full.StateAPI +} + +type paychAPIImpl struct { + full.MpoolAPI + full.StateAPI +} + +func NewManager(sm *stmgr.StateManager, pchstore *Store, api ManagerApi) *Manager { + return &Manager{ + store: pchstore, + sm: sm, + sa: &stateAccessor{sm: sm}, + channels: make(map[string]*channelAccessor), + // TODO: Is this the correct way to do this or can I do something different + // with dependency injection? + pchapi: &paychAPIImpl{api.MpoolAPI, api.StateAPI}, + + mpool: api.MpoolAPI, + wallet: api.WalletAPI, + state: api.StateAPI, + } +} + +// newManager is used by the tests to supply mocks +func newManager(sm StateManagerApi, pchstore *Store, pchapi paychApi) (*Manager, error) { + pm := &Manager{ + store: pchstore, + sm: sm, + sa: &stateAccessor{sm: sm}, + channels: make(map[string]*channelAccessor), + pchapi: pchapi, + } + return pm, pm.Start(context.Background()) +} + +// HandleManager is called by dependency injection to set up hooks +func HandleManager(lc fx.Lifecycle, pm *Manager) { + lc.Append(fx.Hook{ + OnStart: func(ctx context.Context) error { + return pm.Start(ctx) + }, + OnStop: func(context.Context) error { + return pm.Stop() + }, + }) +} + +// Start checks the datastore to see if there are any channels that have +// outstanding add funds messages, and if so, waits on the messages. +// Outstanding messages can occur if an add funds message was sent +// and then lotus was shut down or crashed before the result was +// received. +func (pm *Manager) Start(ctx context.Context) error { + pm.ctx, pm.shutdown = context.WithCancel(ctx) + + cis, err := pm.store.WithPendingAddFunds() + if err != nil { + return err + } + + group := errgroup.Group{} + for _, chanInfo := range cis { + ci := chanInfo + if ci.CreateMsg != nil { + group.Go(func() error { + ca, err := pm.accessorByFromTo(ci.Control, ci.Target) + if err != nil { + return xerrors.Errorf("error initializing payment channel manager %s -> %s: %s", ci.Control, ci.Target, err) + } + go ca.waitForPaychCreateMsg(ci.Control, ci.Target, *ci.CreateMsg, nil) + return nil + }) + } else if ci.AddFundsMsg != nil { + group.Go(func() error { + ca, err := pm.accessorByAddress(*ci.Channel) + if err != nil { + return xerrors.Errorf("error initializing payment channel manager %s: %s", ci.Channel, err) + } + go ca.waitForAddFundsMsg(ci.Control, ci.Target, *ci.AddFundsMsg, nil) + return nil + }) + } + } + + return group.Wait() +} + +// Stop shuts down any processes used by the manager +func (pm *Manager) Stop() error { + pm.shutdown() + return nil +} + +func (pm *Manager) TrackOutboundChannel(ctx context.Context, ch address.Address) error { + return pm.trackChannel(ctx, ch, DirOutbound) +} + +func (pm *Manager) TrackInboundChannel(ctx context.Context, ch address.Address) error { + return pm.trackChannel(ctx, ch, DirInbound) +} + +func (pm *Manager) trackChannel(ctx context.Context, ch address.Address, dir uint64) error { + pm.lk.Lock() + defer pm.lk.Unlock() + + ci, err := pm.sa.loadStateChannelInfo(ctx, ch, dir) + if err != nil { + return err + } + + return pm.store.TrackChannel(ci) +} + +func (pm *Manager) GetPaych(ctx context.Context, from, to address.Address, amt types.BigInt) (address.Address, cid.Cid, error) { + chanAccessor, err := pm.accessorByFromTo(from, to) + if err != nil { + return address.Undef, cid.Undef, err + } + + return chanAccessor.getPaych(ctx, from, to, amt) +} + +// GetPaychWaitReady waits until the create channel / add funds message with the +// given message CID arrives. +// The returned channel address can safely be used against the Manager methods. +func (pm *Manager) GetPaychWaitReady(ctx context.Context, mcid cid.Cid) (address.Address, error) { + // Find the channel associated with the message CID + ci, err := pm.store.ByMessageCid(mcid) + if err != nil { + if err == datastore.ErrNotFound { + return address.Undef, xerrors.Errorf("Could not find wait msg cid %s", mcid) + } + return address.Undef, err + } + + chanAccessor, err := pm.accessorByFromTo(ci.Control, ci.Target) + if err != nil { + return address.Undef, err + } + + return chanAccessor.getPaychWaitReady(ctx, mcid) +} + +func (pm *Manager) ListChannels() ([]address.Address, error) { + // Need to take an exclusive lock here so that channel operations can't run + // in parallel (see channelLock) + pm.lk.Lock() + defer pm.lk.Unlock() + + return pm.store.ListChannels() +} + +func (pm *Manager) GetChannelInfo(addr address.Address) (*ChannelInfo, error) { + ca, err := pm.accessorByAddress(addr) + if err != nil { + return nil, err + } + return ca.getChannelInfo(addr) +} + +// CheckVoucherValid checks if the given voucher is valid (is or could become spendable at some point) +func (pm *Manager) CheckVoucherValid(ctx context.Context, ch address.Address, sv *paych.SignedVoucher) error { + ca, err := pm.accessorByAddress(ch) + if err != nil { + return err + } + + _, err = ca.checkVoucherValid(ctx, ch, sv) + return err +} + +// CheckVoucherSpendable checks if the given voucher is currently spendable +func (pm *Manager) CheckVoucherSpendable(ctx context.Context, ch address.Address, sv *paych.SignedVoucher, secret []byte, proof []byte) (bool, error) { + ca, err := pm.accessorByAddress(ch) + if err != nil { + return false, err + } + + return ca.checkVoucherSpendable(ctx, ch, sv, secret, proof) +} + +func (pm *Manager) AddVoucher(ctx context.Context, ch address.Address, sv *paych.SignedVoucher, proof []byte, minDelta types.BigInt) (types.BigInt, error) { + ca, err := pm.accessorByAddress(ch) + if err != nil { + return types.NewInt(0), err + } + return ca.addVoucher(ctx, ch, sv, proof, minDelta) +} + +func (pm *Manager) AllocateLane(ch address.Address) (uint64, error) { + ca, err := pm.accessorByAddress(ch) + if err != nil { + return 0, err + } + return ca.allocateLane(ch) +} + +func (pm *Manager) ListVouchers(ctx context.Context, ch address.Address) ([]*VoucherInfo, error) { + ca, err := pm.accessorByAddress(ch) + if err != nil { + return nil, err + } + return ca.listVouchers(ctx, ch) +} + +func (pm *Manager) NextNonceForLane(ctx context.Context, ch address.Address, lane uint64) (uint64, error) { + ca, err := pm.accessorByAddress(ch) + if err != nil { + return 0, err + } + return ca.nextNonceForLane(ctx, ch, lane) +} + +func (pm *Manager) Settle(ctx context.Context, addr address.Address) (cid.Cid, error) { + ca, err := pm.accessorByAddress(addr) + if err != nil { + return cid.Undef, err + } + return ca.settle(ctx, addr) +} diff --git a/paychmgr/msglistener.go b/paychmgr/msglistener.go new file mode 100644 index 0000000000..6b4cd8346a --- /dev/null +++ b/paychmgr/msglistener.go @@ -0,0 +1,58 @@ +package paychmgr + +import ( + "sync" + + "github.com/google/uuid" + "github.com/ipfs/go-cid" +) + +type msgListener struct { + id string + cb func(c cid.Cid, err error) +} + +type msgListeners struct { + lk sync.Mutex + listeners []*msgListener +} + +func (ml *msgListeners) onMsg(mcid cid.Cid, cb func(error)) string { + ml.lk.Lock() + defer ml.lk.Unlock() + + l := &msgListener{ + id: uuid.New().String(), + cb: func(c cid.Cid, err error) { + if mcid.Equals(c) { + cb(err) + } + }, + } + ml.listeners = append(ml.listeners, l) + return l.id +} + +func (ml *msgListeners) fireMsgComplete(mcid cid.Cid, err error) { + ml.lk.Lock() + defer ml.lk.Unlock() + + for _, l := range ml.listeners { + l.cb(mcid, err) + } +} + +func (ml *msgListeners) unsubscribe(sub string) { + for i, l := range ml.listeners { + if l.id == sub { + ml.removeListener(i) + return + } + } +} + +func (ml *msgListeners) removeListener(i int) { + copy(ml.listeners[i:], ml.listeners[i+1:]) + ml.listeners[len(ml.listeners)-1] = nil + ml.listeners = ml.listeners[:len(ml.listeners)-1] +} diff --git a/paychmgr/msglistener_test.go b/paychmgr/msglistener_test.go new file mode 100644 index 0000000000..fd457a518d --- /dev/null +++ b/paychmgr/msglistener_test.go @@ -0,0 +1,96 @@ +package paychmgr + +import ( + "testing" + + "github.com/ipfs/go-cid" + + "github.com/stretchr/testify/require" + + "golang.org/x/xerrors" +) + +func testCids() []cid.Cid { + c1, _ := cid.Decode("QmdmGQmRgRjazArukTbsXuuxmSHsMCcRYPAZoGhd6e3MuS") + c2, _ := cid.Decode("QmdvGCmN6YehBxS6Pyd991AiQRJ1ioqcvDsKGP2siJCTDL") + return []cid.Cid{c1, c2} +} + +func TestMsgListener(t *testing.T) { + var ml msgListeners + + done := false + experr := xerrors.Errorf("some err") + cids := testCids() + ml.onMsg(cids[0], func(err error) { + require.Equal(t, experr, err) + done = true + }) + + ml.fireMsgComplete(cids[0], experr) + + if !done { + t.Fatal("failed to fire event") + } +} + +func TestMsgListenerNilErr(t *testing.T) { + var ml msgListeners + + done := false + cids := testCids() + ml.onMsg(cids[0], func(err error) { + require.Nil(t, err) + done = true + }) + + ml.fireMsgComplete(cids[0], nil) + + if !done { + t.Fatal("failed to fire event") + } +} + +func TestMsgListenerUnsub(t *testing.T) { + var ml msgListeners + + done := false + experr := xerrors.Errorf("some err") + cids := testCids() + id1 := ml.onMsg(cids[0], func(err error) { + t.Fatal("should not call unsubscribed listener") + }) + ml.onMsg(cids[0], func(err error) { + require.Equal(t, experr, err) + done = true + }) + + ml.unsubscribe(id1) + ml.fireMsgComplete(cids[0], experr) + + if !done { + t.Fatal("failed to fire event") + } +} + +func TestMsgListenerMulti(t *testing.T) { + var ml msgListeners + + count := 0 + cids := testCids() + ml.onMsg(cids[0], func(err error) { + count++ + }) + ml.onMsg(cids[0], func(err error) { + count++ + }) + ml.onMsg(cids[1], func(err error) { + count++ + }) + + ml.fireMsgComplete(cids[0], nil) + require.Equal(t, 2, count) + + ml.fireMsgComplete(cids[1], nil) + require.Equal(t, 3, count) +} diff --git a/paychmgr/paych.go b/paychmgr/paych.go index 85db664cdf..f1d5199f6b 100644 --- a/paychmgr/paych.go +++ b/paychmgr/paych.go @@ -5,118 +5,75 @@ import ( "context" "fmt" - "github.com/filecoin-project/specs-actors/actors/abi/big" - - "github.com/filecoin-project/lotus/api" - - cborutil "github.com/filecoin-project/go-cbor-util" - "github.com/filecoin-project/specs-actors/actors/builtin" - "github.com/filecoin-project/specs-actors/actors/builtin/account" - "github.com/filecoin-project/specs-actors/actors/builtin/paych" - "golang.org/x/xerrors" - - logging "github.com/ipfs/go-log/v2" - "go.uber.org/fx" + "github.com/ipfs/go-cid" "github.com/filecoin-project/go-address" - + cborutil "github.com/filecoin-project/go-cbor-util" "github.com/filecoin-project/lotus/chain/actors" - "github.com/filecoin-project/lotus/chain/stmgr" "github.com/filecoin-project/lotus/chain/types" "github.com/filecoin-project/lotus/lib/sigs" - "github.com/filecoin-project/lotus/node/impl/full" + "github.com/filecoin-project/specs-actors/actors/abi/big" + "github.com/filecoin-project/specs-actors/actors/builtin" + "github.com/filecoin-project/specs-actors/actors/builtin/account" + "github.com/filecoin-project/specs-actors/actors/builtin/paych" + xerrors "golang.org/x/xerrors" ) -var log = logging.Logger("paych") - -type ManagerApi struct { - fx.In - - full.MpoolAPI - full.WalletAPI - full.StateAPI +// channelAccessor is used to simplify locking when accessing a channel +type channelAccessor struct { + // waitCtx is used by processes that wait for things to be confirmed + // on chain + waitCtx context.Context + sm StateManagerApi + sa *stateAccessor + api paychApi + store *Store + lk *channelLock + fundsReqQueue []*fundsReq + msgListeners msgListeners } -type StateManagerApi interface { - LoadActorState(ctx context.Context, a address.Address, out interface{}, ts *types.TipSet) (*types.Actor, error) - Call(ctx context.Context, msg *types.Message, ts *types.TipSet) (*api.InvocResult, error) -} - -type Manager struct { - store *Store - sm StateManagerApi - - mpool full.MpoolAPI - wallet full.WalletAPI - state full.StateAPI -} - -func NewManager(sm *stmgr.StateManager, pchstore *Store, api ManagerApi) *Manager { - return &Manager{ - store: pchstore, - sm: sm, - - mpool: api.MpoolAPI, - wallet: api.WalletAPI, - state: api.StateAPI, +func newChannelAccessor(pm *Manager) *channelAccessor { + return &channelAccessor{ + lk: &channelLock{globalLock: &pm.lk}, + sm: pm.sm, + sa: &stateAccessor{sm: pm.sm}, + api: pm.pchapi, + store: pm.store, + waitCtx: pm.ctx, } } -// Used by the tests to supply mocks -func newManager(sm StateManagerApi, pchstore *Store) *Manager { - return &Manager{ - store: pchstore, - sm: sm, - } -} - -func (pm *Manager) TrackOutboundChannel(ctx context.Context, ch address.Address) error { - return pm.trackChannel(ctx, ch, DirOutbound) -} +func (ca *channelAccessor) getChannelInfo(addr address.Address) (*ChannelInfo, error) { + ca.lk.Lock() + defer ca.lk.Unlock() -func (pm *Manager) TrackInboundChannel(ctx context.Context, ch address.Address) error { - return pm.trackChannel(ctx, ch, DirInbound) + return ca.store.ByAddress(addr) } -func (pm *Manager) trackChannel(ctx context.Context, ch address.Address, dir uint64) error { - ci, err := pm.loadStateChannelInfo(ctx, ch, dir) - if err != nil { - return err - } +func (ca *channelAccessor) checkVoucherValid(ctx context.Context, ch address.Address, sv *paych.SignedVoucher) (map[uint64]*paych.LaneState, error) { + ca.lk.Lock() + defer ca.lk.Unlock() - return pm.store.TrackChannel(ci) + return ca.checkVoucherValidUnlocked(ctx, ch, sv) } -func (pm *Manager) ListChannels() ([]address.Address, error) { - return pm.store.ListChannels() -} - -func (pm *Manager) GetChannelInfo(addr address.Address) (*ChannelInfo, error) { - return pm.store.getChannelInfo(addr) -} - -// CheckVoucherValid checks if the given voucher is valid (is or could become spendable at some point) -func (pm *Manager) CheckVoucherValid(ctx context.Context, ch address.Address, sv *paych.SignedVoucher) error { - _, err := pm.checkVoucherValid(ctx, ch, sv) - return err -} - -func (pm *Manager) checkVoucherValid(ctx context.Context, ch address.Address, sv *paych.SignedVoucher) (map[uint64]*paych.LaneState, error) { +func (ca *channelAccessor) checkVoucherValidUnlocked(ctx context.Context, ch address.Address, sv *paych.SignedVoucher) (map[uint64]*paych.LaneState, error) { if sv.ChannelAddr != ch { return nil, xerrors.Errorf("voucher ChannelAddr doesn't match channel address, got %s, expected %s", sv.ChannelAddr, ch) } - act, pchState, err := pm.loadPaychState(ctx, ch) + act, pchState, err := ca.sa.loadPaychState(ctx, ch) if err != nil { return nil, err } - var account account.State - _, err = pm.sm.LoadActorState(ctx, pchState.From, &account, nil) + var actState account.State + _, err = ca.sm.LoadActorState(ctx, pchState.From, &actState, nil) if err != nil { return nil, err } - from := account.Address + from := actState.Address // verify signature vb, err := sv.SigningBytes() @@ -132,7 +89,7 @@ func (pm *Manager) checkVoucherValid(ctx context.Context, ch address.Address, sv } // Check the voucher against the highest known voucher nonce / value - laneStates, err := pm.laneState(pchState, ch) + laneStates, err := ca.laneState(pchState, ch) if err != nil { return nil, err } @@ -164,7 +121,7 @@ func (pm *Manager) checkVoucherValid(ctx context.Context, ch address.Address, sv // lane 2: 2 // - // total: 7 - totalRedeemed, err := pm.totalRedeemedWithVoucher(laneStates, sv) + totalRedeemed, err := ca.totalRedeemedWithVoucher(laneStates, sv) if err != nil { return nil, err } @@ -183,15 +140,17 @@ func (pm *Manager) checkVoucherValid(ctx context.Context, ch address.Address, sv return laneStates, nil } -// CheckVoucherSpendable checks if the given voucher is currently spendable -func (pm *Manager) CheckVoucherSpendable(ctx context.Context, ch address.Address, sv *paych.SignedVoucher, secret []byte, proof []byte) (bool, error) { - recipient, err := pm.getPaychRecipient(ctx, ch) +func (ca *channelAccessor) checkVoucherSpendable(ctx context.Context, ch address.Address, sv *paych.SignedVoucher, secret []byte, proof []byte) (bool, error) { + ca.lk.Lock() + defer ca.lk.Unlock() + + recipient, err := ca.getPaychRecipient(ctx, ch) if err != nil { return false, err } if sv.Extra != nil && proof == nil { - known, err := pm.ListVouchers(ctx, ch) + known, err := ca.store.VouchersForPaych(ch) if err != nil { return false, err } @@ -221,7 +180,7 @@ func (pm *Manager) CheckVoucherSpendable(ctx context.Context, ch address.Address return false, err } - ret, err := pm.sm.Call(ctx, &types.Message{ + ret, err := ca.sm.Call(ctx, &types.Message{ From: recipient, To: ch, Method: builtin.MethodsPaych.UpdateChannelState, @@ -238,22 +197,22 @@ func (pm *Manager) CheckVoucherSpendable(ctx context.Context, ch address.Address return true, nil } -func (pm *Manager) getPaychRecipient(ctx context.Context, ch address.Address) (address.Address, error) { +func (ca *channelAccessor) getPaychRecipient(ctx context.Context, ch address.Address) (address.Address, error) { var state paych.State - if _, err := pm.sm.LoadActorState(ctx, ch, &state, nil); err != nil { + if _, err := ca.sm.LoadActorState(ctx, ch, &state, nil); err != nil { return address.Address{}, err } return state.To, nil } -func (pm *Manager) AddVoucher(ctx context.Context, ch address.Address, sv *paych.SignedVoucher, proof []byte, minDelta types.BigInt) (types.BigInt, error) { - pm.store.lk.Lock() - defer pm.store.lk.Unlock() +func (ca *channelAccessor) addVoucher(ctx context.Context, ch address.Address, sv *paych.SignedVoucher, proof []byte, minDelta types.BigInt) (types.BigInt, error) { + ca.lk.Lock() + defer ca.lk.Unlock() - ci, err := pm.store.getChannelInfo(ch) + ci, err := ca.store.ByAddress(ch) if err != nil { - return types.NewInt(0), err + return types.BigInt{}, err } // Check if the voucher has already been added @@ -275,7 +234,7 @@ func (pm *Manager) AddVoucher(ctx context.Context, ch address.Address, sv *paych Proof: proof, } - return types.NewInt(0), pm.store.putChannelInfo(ci) + return types.NewInt(0), ca.store.putChannelInfo(ci) } // Otherwise just ignore the duplicate voucher @@ -284,7 +243,7 @@ func (pm *Manager) AddVoucher(ctx context.Context, ch address.Address, sv *paych } // Check voucher validity - laneStates, err := pm.checkVoucherValid(ctx, ch, sv) + laneStates, err := ca.checkVoucherValidUnlocked(ctx, ch, sv) if err != nil { return types.NewInt(0), err } @@ -311,35 +270,32 @@ func (pm *Manager) AddVoucher(ctx context.Context, ch address.Address, sv *paych ci.NextLane = sv.Lane + 1 } - return delta, pm.store.putChannelInfo(ci) + return delta, ca.store.putChannelInfo(ci) } -func (pm *Manager) AllocateLane(ch address.Address) (uint64, error) { +func (ca *channelAccessor) allocateLane(ch address.Address) (uint64, error) { + ca.lk.Lock() + defer ca.lk.Unlock() + // TODO: should this take into account lane state? - return pm.store.AllocateLane(ch) + return ca.store.AllocateLane(ch) } -func (pm *Manager) ListVouchers(ctx context.Context, ch address.Address) ([]*VoucherInfo, error) { +func (ca *channelAccessor) listVouchers(ctx context.Context, ch address.Address) ([]*VoucherInfo, error) { + ca.lk.Lock() + defer ca.lk.Unlock() + // TODO: just having a passthrough method like this feels odd. Seems like // there should be some filtering we're doing here - return pm.store.VouchersForPaych(ch) + return ca.store.VouchersForPaych(ch) } -func (pm *Manager) OutboundChanTo(from, to address.Address) (address.Address, error) { - pm.store.lk.Lock() - defer pm.store.lk.Unlock() +func (ca *channelAccessor) nextNonceForLane(ctx context.Context, ch address.Address, lane uint64) (uint64, error) { + ca.lk.Lock() + defer ca.lk.Unlock() - return pm.store.findChan(func(ci *ChannelInfo) bool { - if ci.Direction != DirOutbound { - return false - } - return ci.Control == from && ci.Target == to - }) -} - -func (pm *Manager) NextNonceForLane(ctx context.Context, ch address.Address, lane uint64) (uint64, error) { // TODO: should this take into account lane state? - vouchers, err := pm.store.VouchersForPaych(ch) + vouchers, err := ca.store.VouchersForPaych(ch) if err != nil { return 0, err } @@ -355,3 +311,109 @@ func (pm *Manager) NextNonceForLane(ctx context.Context, ch address.Address, lan return maxnonce + 1, nil } + +// laneState gets the LaneStates from chain, then applies all vouchers in +// the data store over the chain state +func (ca *channelAccessor) laneState(state *paych.State, ch address.Address) (map[uint64]*paych.LaneState, error) { + // TODO: we probably want to call UpdateChannelState with all vouchers to be fully correct + // (but technically dont't need to) + laneStates := make(map[uint64]*paych.LaneState, len(state.LaneStates)) + + // Get the lane state from the chain + for _, laneState := range state.LaneStates { + laneStates[laneState.ID] = laneState + } + + // Apply locally stored vouchers + vouchers, err := ca.store.VouchersForPaych(ch) + if err != nil && err != ErrChannelNotTracked { + return nil, err + } + + for _, v := range vouchers { + for range v.Voucher.Merges { + return nil, xerrors.Errorf("paych merges not handled yet") + } + + // If there's a voucher for a lane that isn't in chain state just + // create it + ls, ok := laneStates[v.Voucher.Lane] + if !ok { + ls = &paych.LaneState{ + ID: v.Voucher.Lane, + Redeemed: types.NewInt(0), + Nonce: 0, + } + laneStates[v.Voucher.Lane] = ls + } + + if v.Voucher.Nonce < ls.Nonce { + continue + } + + ls.Nonce = v.Voucher.Nonce + ls.Redeemed = v.Voucher.Amount + } + + return laneStates, nil +} + +// Get the total redeemed amount across all lanes, after applying the voucher +func (ca *channelAccessor) totalRedeemedWithVoucher(laneStates map[uint64]*paych.LaneState, sv *paych.SignedVoucher) (big.Int, error) { + // TODO: merges + if len(sv.Merges) != 0 { + return big.Int{}, xerrors.Errorf("dont currently support paych lane merges") + } + + total := big.NewInt(0) + for _, ls := range laneStates { + total = big.Add(total, ls.Redeemed) + } + + lane, ok := laneStates[sv.Lane] + if ok { + // If the voucher is for an existing lane, and the voucher nonce + // and is higher than the lane nonce + if sv.Nonce > lane.Nonce { + // Add the delta between the redeemed amount and the voucher + // amount to the total + delta := big.Sub(sv.Amount, lane.Redeemed) + total = big.Add(total, delta) + } + } else { + // If the voucher is *not* for an existing lane, just add its + // value (implicitly a new lane will be created for the voucher) + total = big.Add(total, sv.Amount) + } + + return total, nil +} + +func (ca *channelAccessor) settle(ctx context.Context, ch address.Address) (cid.Cid, error) { + ca.lk.Lock() + defer ca.lk.Unlock() + + ci, err := ca.store.ByAddress(ch) + if err != nil { + return cid.Undef, err + } + + msg := &types.Message{ + To: ch, + From: ci.Control, + Value: types.NewInt(0), + Method: builtin.MethodsPaych.Settle, + } + smgs, err := ca.api.MpoolPushMessage(ctx, msg) + if err != nil { + return cid.Undef, err + } + + ci.Settling = true + err = ca.store.putChannelInfo(ci) + if err != nil { + log.Errorf("Error marking channel as settled: %s", err) + } + + return smgs.Cid(), err +} diff --git a/paychmgr/paych_test.go b/paychmgr/paych_test.go index 2cbea5cb54..9c28fdcb8b 100644 --- a/paychmgr/paych_test.go +++ b/paychmgr/paych_test.go @@ -104,13 +104,15 @@ func TestPaychOutbound(t *testing.T) { LaneStates: []*paych.LaneState{}, }) - mgr := newManager(sm, store) - err := mgr.TrackOutboundChannel(ctx, ch) + mgr, err := newManager(sm, store, nil) + require.NoError(t, err) + + err = mgr.TrackOutboundChannel(ctx, ch) require.NoError(t, err) ci, err := mgr.GetChannelInfo(ch) require.NoError(t, err) - require.Equal(t, ci.Channel, ch) + require.Equal(t, *ci.Channel, ch) require.Equal(t, ci.Control, from) require.Equal(t, ci.Target, to) require.EqualValues(t, ci.Direction, DirOutbound) @@ -140,13 +142,15 @@ func TestPaychInbound(t *testing.T) { LaneStates: []*paych.LaneState{}, }) - mgr := newManager(sm, store) - err := mgr.TrackInboundChannel(ctx, ch) + mgr, err := newManager(sm, store, nil) + require.NoError(t, err) + + err = mgr.TrackInboundChannel(ctx, ch) require.NoError(t, err) ci, err := mgr.GetChannelInfo(ch) require.NoError(t, err) - require.Equal(t, ci.Channel, ch) + require.Equal(t, *ci.Channel, ch) require.Equal(t, ci.Control, to) require.Equal(t, ci.Target, from) require.EqualValues(t, ci.Direction, DirInbound) @@ -321,8 +325,10 @@ func TestCheckVoucherValid(t *testing.T) { LaneStates: tcase.laneStates, }) - mgr := newManager(sm, store) - err := mgr.TrackInboundChannel(ctx, ch) + mgr, err := newManager(sm, store, nil) + require.NoError(t, err) + + err = mgr.TrackInboundChannel(ctx, ch) require.NoError(t, err) sv := testCreateVoucher(t, ch, tcase.voucherLane, tcase.voucherNonce, tcase.voucherAmount, tcase.key) @@ -382,8 +388,10 @@ func TestCheckVoucherValidCountingAllLanes(t *testing.T) { LaneStates: laneStates, }) - mgr := newManager(sm, store) - err := mgr.TrackInboundChannel(ctx, ch) + mgr, err := newManager(sm, store, nil) + require.NoError(t, err) + + err = mgr.TrackInboundChannel(ctx, ch) require.NoError(t, err) // @@ -690,8 +698,10 @@ func testSetupMgrWithChannel(ctx context.Context, t *testing.T) (*Manager, addre }) store := NewStore(ds_sync.MutexWrap(ds.NewMapDatastore())) - mgr := newManager(sm, store) - err := mgr.TrackInboundChannel(ctx, ch) + mgr, err := newManager(sm, store, nil) + require.NoError(t, err) + + err = mgr.TrackInboundChannel(ctx, ch) require.NoError(t, err) return mgr, ch, fromKeyPrivate } diff --git a/paychmgr/paychget_test.go b/paychmgr/paychget_test.go new file mode 100644 index 0000000000..bdc19d9ac3 --- /dev/null +++ b/paychmgr/paychget_test.go @@ -0,0 +1,756 @@ +package paychmgr + +import ( + "context" + "sync" + "testing" + "time" + + cborrpc "github.com/filecoin-project/go-cbor-util" + + init_ "github.com/filecoin-project/specs-actors/actors/builtin/init" + + "github.com/filecoin-project/specs-actors/actors/builtin" + + "github.com/filecoin-project/lotus/api" + "github.com/filecoin-project/lotus/chain/types" + + "github.com/filecoin-project/go-address" + + "github.com/filecoin-project/specs-actors/actors/abi/big" + tutils "github.com/filecoin-project/specs-actors/support/testing" + "github.com/ipfs/go-cid" + ds "github.com/ipfs/go-datastore" + ds_sync "github.com/ipfs/go-datastore/sync" + + "github.com/stretchr/testify/require" +) + +type waitingCall struct { + response chan types.MessageReceipt +} + +type mockPaychAPI struct { + lk sync.Mutex + messages map[cid.Cid]*types.SignedMessage + waitingCalls []*waitingCall +} + +func newMockPaychAPI() *mockPaychAPI { + return &mockPaychAPI{ + messages: make(map[cid.Cid]*types.SignedMessage), + } +} + +func (pchapi *mockPaychAPI) StateWaitMsg(ctx context.Context, msg cid.Cid, confidence uint64) (*api.MsgLookup, error) { + response := make(chan types.MessageReceipt) + + pchapi.lk.Lock() + pchapi.waitingCalls = append(pchapi.waitingCalls, &waitingCall{response: response}) + pchapi.lk.Unlock() + + receipt := <-response + + return &api.MsgLookup{Receipt: receipt}, nil +} + +func (pchapi *mockPaychAPI) MpoolPushMessage(ctx context.Context, msg *types.Message) (*types.SignedMessage, error) { + pchapi.lk.Lock() + defer pchapi.lk.Unlock() + + smsg := &types.SignedMessage{Message: *msg} + pchapi.messages[smsg.Cid()] = smsg + return smsg, nil +} + +func (pchapi *mockPaychAPI) pushedMessages(c cid.Cid) *types.SignedMessage { + pchapi.lk.Lock() + defer pchapi.lk.Unlock() + + return pchapi.messages[c] +} + +func (pchapi *mockPaychAPI) pushedMessageCount() int { + pchapi.lk.Lock() + defer pchapi.lk.Unlock() + + return len(pchapi.messages) +} + +func (pchapi *mockPaychAPI) finishWaitingCalls(receipt types.MessageReceipt) { + pchapi.lk.Lock() + defer pchapi.lk.Unlock() + + for _, call := range pchapi.waitingCalls { + call.response <- receipt + } + pchapi.waitingCalls = nil +} + +func (pchapi *mockPaychAPI) close() { + pchapi.finishWaitingCalls(types.MessageReceipt{ + ExitCode: 0, + Return: []byte{}, + }) +} + +func testChannelResponse(t *testing.T, ch address.Address) types.MessageReceipt { + createChannelRet := init_.ExecReturn{ + IDAddress: ch, + RobustAddress: ch, + } + createChannelRetBytes, err := cborrpc.Dump(&createChannelRet) + require.NoError(t, err) + createChannelResponse := types.MessageReceipt{ + ExitCode: 0, + Return: createChannelRetBytes, + } + return createChannelResponse +} + +// TestPaychGetCreateChannelMsg tests that GetPaych sends a message to create +// a new channel with the correct funds +func TestPaychGetCreateChannelMsg(t *testing.T) { + ctx := context.Background() + store := NewStore(ds_sync.MutexWrap(ds.NewMapDatastore())) + + from := tutils.NewIDAddr(t, 101) + to := tutils.NewIDAddr(t, 102) + + sm := newMockStateManager() + pchapi := newMockPaychAPI() + defer pchapi.close() + + mgr, err := newManager(sm, store, pchapi) + require.NoError(t, err) + + amt := big.NewInt(10) + ch, mcid, err := mgr.GetPaych(ctx, from, to, amt) + require.NoError(t, err) + require.Equal(t, address.Undef, ch) + + pushedMsg := pchapi.pushedMessages(mcid) + require.Equal(t, from, pushedMsg.Message.From) + require.Equal(t, builtin.InitActorAddr, pushedMsg.Message.To) + require.Equal(t, amt, pushedMsg.Message.Value) +} + +// TestPaychGetCreateChannelThenAddFunds tests creating a channel and then +// adding funds to it +func TestPaychGetCreateChannelThenAddFunds(t *testing.T) { + ctx := context.Background() + store := NewStore(ds_sync.MutexWrap(ds.NewMapDatastore())) + + ch := tutils.NewIDAddr(t, 100) + from := tutils.NewIDAddr(t, 101) + to := tutils.NewIDAddr(t, 102) + + sm := newMockStateManager() + pchapi := newMockPaychAPI() + defer pchapi.close() + + mgr, err := newManager(sm, store, pchapi) + require.NoError(t, err) + + // Send create message for a channel with value 10 + amt := big.NewInt(10) + _, createMsgCid, err := mgr.GetPaych(ctx, from, to, amt) + require.NoError(t, err) + + // Should have no channels yet (message sent but channel not created) + cis, err := mgr.ListChannels() + require.NoError(t, err) + require.Len(t, cis, 0) + + // 1. Set up create channel response (sent in response to WaitForMsg()) + response := testChannelResponse(t, ch) + + done := make(chan struct{}) + go func() { + defer close(done) + + // 2. Request add funds - should block until create channel has completed + amt2 := big.NewInt(5) + ch2, addFundsMsgCid, err := mgr.GetPaych(ctx, from, to, amt2) + + // 4. This GetPaych should return after create channel from first + // GetPaych completes + require.NoError(t, err) + + // Expect the channel to be the same + require.Equal(t, ch, ch2) + // Expect add funds message CID to be different to create message cid + require.NotEqual(t, createMsgCid, addFundsMsgCid) + + // Should have one channel, whose address is the channel that was created + cis, err := mgr.ListChannels() + require.NoError(t, err) + require.Len(t, cis, 1) + require.Equal(t, ch, cis[0]) + + // Amount should be amount sent to first GetPaych (to create + // channel). + // PendingAmount should be amount sent in second GetPaych + // (second GetPaych triggered add funds, which has not yet been confirmed) + ci, err := mgr.GetChannelInfo(ch) + require.NoError(t, err) + require.EqualValues(t, 10, ci.Amount.Int64()) + require.EqualValues(t, 5, ci.PendingAmount.Int64()) + require.Nil(t, ci.CreateMsg) + + // Trigger add funds confirmation + pchapi.finishWaitingCalls(types.MessageReceipt{ExitCode: 0}) + + time.Sleep(time.Millisecond * 10) + + // Should still have one channel + cis, err = mgr.ListChannels() + require.NoError(t, err) + require.Len(t, cis, 1) + require.Equal(t, ch, cis[0]) + + // Channel amount should include last amount sent to GetPaych + ci, err = mgr.GetChannelInfo(ch) + require.NoError(t, err) + require.EqualValues(t, 15, ci.Amount.Int64()) + require.EqualValues(t, 0, ci.PendingAmount.Int64()) + require.Nil(t, ci.AddFundsMsg) + }() + + // Give the go routine above a moment to run + time.Sleep(time.Millisecond * 10) + + // 3. Send create channel response + pchapi.finishWaitingCalls(response) + + <-done +} + +// TestPaychGetCreateChannelWithErrorThenCreateAgain tests that if an +// operation is queued up behind a create channel operation, and the create +// channel fails, then the waiting operation can succeed. +func TestPaychGetCreateChannelWithErrorThenCreateAgain(t *testing.T) { + ctx := context.Background() + store := NewStore(ds_sync.MutexWrap(ds.NewMapDatastore())) + + from := tutils.NewIDAddr(t, 101) + to := tutils.NewIDAddr(t, 102) + + sm := newMockStateManager() + pchapi := newMockPaychAPI() + defer pchapi.close() + + mgr, err := newManager(sm, store, pchapi) + require.NoError(t, err) + + // Send create message for a channel + amt := big.NewInt(10) + _, _, err = mgr.GetPaych(ctx, from, to, amt) + require.NoError(t, err) + + // 1. Set up create channel response (sent in response to WaitForMsg()) + // This response indicates an error. + errResponse := types.MessageReceipt{ + ExitCode: 1, // error + Return: []byte{}, + } + + done := make(chan struct{}) + go func() { + defer close(done) + + // 2. Should block until create channel has completed. + // Because first channel create fails, this request + // should be for channel create. + amt2 := big.NewInt(5) + ch2, _, err := mgr.GetPaych(ctx, from, to, amt2) + require.NoError(t, err) + require.Equal(t, address.Undef, ch2) + + time.Sleep(time.Millisecond * 10) + + // 4. Send a success response + ch := tutils.NewIDAddr(t, 100) + successResponse := testChannelResponse(t, ch) + pchapi.finishWaitingCalls(successResponse) + + time.Sleep(time.Millisecond * 10) + + // Should have one channel, whose address is the channel that was created + cis, err := mgr.ListChannels() + require.NoError(t, err) + require.Len(t, cis, 1) + require.Equal(t, ch, cis[0]) + + ci, err := mgr.GetChannelInfo(ch) + require.NoError(t, err) + require.Equal(t, amt2, ci.Amount) + }() + + // Give the go routine above a moment to run + time.Sleep(time.Millisecond * 10) + + // 3. Send error response to first channel create + pchapi.finishWaitingCalls(errResponse) + + <-done +} + +// TestPaychGetRecoverAfterError tests that after a create channel fails, the +// next attempt to create channel can succeed. +func TestPaychGetRecoverAfterError(t *testing.T) { + ctx := context.Background() + store := NewStore(ds_sync.MutexWrap(ds.NewMapDatastore())) + + ch := tutils.NewIDAddr(t, 100) + from := tutils.NewIDAddr(t, 101) + to := tutils.NewIDAddr(t, 102) + + sm := newMockStateManager() + pchapi := newMockPaychAPI() + defer pchapi.close() + + mgr, err := newManager(sm, store, pchapi) + require.NoError(t, err) + + // Send create message for a channel + amt := big.NewInt(10) + _, _, err = mgr.GetPaych(ctx, from, to, amt) + require.NoError(t, err) + + time.Sleep(time.Millisecond * 10) + + // Send error create channel response + pchapi.finishWaitingCalls(types.MessageReceipt{ + ExitCode: 1, // error + Return: []byte{}, + }) + + time.Sleep(time.Millisecond * 10) + + // Send create message for a channel again + amt2 := big.NewInt(7) + _, _, err = mgr.GetPaych(ctx, from, to, amt2) + require.NoError(t, err) + + time.Sleep(time.Millisecond * 10) + + // Send success create channel response + response := testChannelResponse(t, ch) + pchapi.finishWaitingCalls(response) + + time.Sleep(time.Millisecond * 10) + + // Should have one channel, whose address is the channel that was created + cis, err := mgr.ListChannels() + require.NoError(t, err) + require.Len(t, cis, 1) + require.Equal(t, ch, cis[0]) + + ci, err := mgr.GetChannelInfo(ch) + require.NoError(t, err) + require.Equal(t, amt2, ci.Amount) + require.EqualValues(t, 0, ci.PendingAmount.Int64()) + require.Nil(t, ci.CreateMsg) +} + +// TestPaychGetRecoverAfterAddFundsError tests that after an add funds fails, the +// next attempt to add funds can succeed. +func TestPaychGetRecoverAfterAddFundsError(t *testing.T) { + ctx := context.Background() + store := NewStore(ds_sync.MutexWrap(ds.NewMapDatastore())) + + ch := tutils.NewIDAddr(t, 100) + from := tutils.NewIDAddr(t, 101) + to := tutils.NewIDAddr(t, 102) + + sm := newMockStateManager() + pchapi := newMockPaychAPI() + defer pchapi.close() + + mgr, err := newManager(sm, store, pchapi) + require.NoError(t, err) + + // Send create message for a channel + amt := big.NewInt(10) + _, _, err = mgr.GetPaych(ctx, from, to, amt) + require.NoError(t, err) + + time.Sleep(time.Millisecond * 10) + + // Send success create channel response + response := testChannelResponse(t, ch) + pchapi.finishWaitingCalls(response) + + time.Sleep(time.Millisecond * 10) + + // Send add funds message for channel + amt2 := big.NewInt(5) + _, _, err = mgr.GetPaych(ctx, from, to, amt2) + require.NoError(t, err) + + time.Sleep(time.Millisecond * 10) + + // Send error add funds response + pchapi.finishWaitingCalls(types.MessageReceipt{ + ExitCode: 1, // error + Return: []byte{}, + }) + + time.Sleep(time.Millisecond * 10) + + // Should have one channel, whose address is the channel that was created + cis, err := mgr.ListChannels() + require.NoError(t, err) + require.Len(t, cis, 1) + require.Equal(t, ch, cis[0]) + + ci, err := mgr.GetChannelInfo(ch) + require.NoError(t, err) + require.Equal(t, amt, ci.Amount) + require.EqualValues(t, 0, ci.PendingAmount.Int64()) + require.Nil(t, ci.CreateMsg) + require.Nil(t, ci.AddFundsMsg) + + // Send add funds message for channel again + amt3 := big.NewInt(2) + _, _, err = mgr.GetPaych(ctx, from, to, amt3) + require.NoError(t, err) + + time.Sleep(time.Millisecond * 10) + + // Send success add funds response + pchapi.finishWaitingCalls(types.MessageReceipt{ + ExitCode: 0, + Return: []byte{}, + }) + + time.Sleep(time.Millisecond * 10) + + // Should have one channel, whose address is the channel that was created + cis, err = mgr.ListChannels() + require.NoError(t, err) + require.Len(t, cis, 1) + require.Equal(t, ch, cis[0]) + + // Amount should include amount for successful add funds msg + ci, err = mgr.GetChannelInfo(ch) + require.NoError(t, err) + require.Equal(t, amt.Int64()+amt3.Int64(), ci.Amount.Int64()) + require.EqualValues(t, 0, ci.PendingAmount.Int64()) + require.Nil(t, ci.CreateMsg) + require.Nil(t, ci.AddFundsMsg) +} + +// TestPaychGetRestartAfterCreateChannelMsg tests that if the system stops +// right after the create channel message is sent, the channel will be +// created when the system restarts. +func TestPaychGetRestartAfterCreateChannelMsg(t *testing.T) { + ctx := context.Background() + store := NewStore(ds_sync.MutexWrap(ds.NewMapDatastore())) + + ch := tutils.NewIDAddr(t, 100) + from := tutils.NewIDAddr(t, 101) + to := tutils.NewIDAddr(t, 102) + + sm := newMockStateManager() + pchapi := newMockPaychAPI() + + mgr, err := newManager(sm, store, pchapi) + require.NoError(t, err) + + // Send create message for a channel with value 10 + amt := big.NewInt(10) + _, createMsgCid, err := mgr.GetPaych(ctx, from, to, amt) + require.NoError(t, err) + + // Simulate shutting down lotus + pchapi.close() + + // Create a new manager with the same datastore + sm2 := newMockStateManager() + pchapi2 := newMockPaychAPI() + defer pchapi2.close() + + mgr2, err := newManager(sm2, store, pchapi2) + require.NoError(t, err) + + // Should have no channels yet (message sent but channel not created) + cis, err := mgr2.ListChannels() + require.NoError(t, err) + require.Len(t, cis, 0) + + // 1. Set up create channel response (sent in response to WaitForMsg()) + response := testChannelResponse(t, ch) + + done := make(chan struct{}) + go func() { + defer close(done) + + // 2. Request add funds - should block until create channel has completed + amt2 := big.NewInt(5) + ch2, addFundsMsgCid, err := mgr2.GetPaych(ctx, from, to, amt2) + + // 4. This GetPaych should return after create channel from first + // GetPaych completes + require.NoError(t, err) + + // Expect the channel to have been created + require.Equal(t, ch, ch2) + // Expect add funds message CID to be different to create message cid + require.NotEqual(t, createMsgCid, addFundsMsgCid) + + // Should have one channel, whose address is the channel that was created + cis, err := mgr2.ListChannels() + require.NoError(t, err) + require.Len(t, cis, 1) + require.Equal(t, ch, cis[0]) + + // Amount should be amount sent to first GetPaych (to create + // channel). + // PendingAmount should be amount sent in second GetPaych + // (second GetPaych triggered add funds, which has not yet been confirmed) + ci, err := mgr2.GetChannelInfo(ch) + require.NoError(t, err) + require.EqualValues(t, 10, ci.Amount.Int64()) + require.EqualValues(t, 5, ci.PendingAmount.Int64()) + require.Nil(t, ci.CreateMsg) + }() + + // Give the go routine above a moment to run + time.Sleep(time.Millisecond * 10) + + // 3. Send create channel response + pchapi2.finishWaitingCalls(response) + + <-done +} + +// TestPaychGetRestartAfterAddFundsMsg tests that if the system stops +// right after the add funds message is sent, the add funds will be +// processed when the system restarts. +func TestPaychGetRestartAfterAddFundsMsg(t *testing.T) { + ctx := context.Background() + store := NewStore(ds_sync.MutexWrap(ds.NewMapDatastore())) + + ch := tutils.NewIDAddr(t, 100) + from := tutils.NewIDAddr(t, 101) + to := tutils.NewIDAddr(t, 102) + + sm := newMockStateManager() + pchapi := newMockPaychAPI() + + mgr, err := newManager(sm, store, pchapi) + require.NoError(t, err) + + // Send create message for a channel + amt := big.NewInt(10) + _, _, err = mgr.GetPaych(ctx, from, to, amt) + require.NoError(t, err) + + time.Sleep(time.Millisecond * 10) + + // Send success create channel response + response := testChannelResponse(t, ch) + pchapi.finishWaitingCalls(response) + + time.Sleep(time.Millisecond * 10) + + // Send add funds message for channel + amt2 := big.NewInt(5) + _, _, err = mgr.GetPaych(ctx, from, to, amt2) + require.NoError(t, err) + + time.Sleep(time.Millisecond * 10) + + // Simulate shutting down lotus + pchapi.close() + + // Create a new manager with the same datastore + sm2 := newMockStateManager() + pchapi2 := newMockPaychAPI() + defer pchapi2.close() + + time.Sleep(time.Millisecond * 10) + + mgr2, err := newManager(sm2, store, pchapi2) + require.NoError(t, err) + + // Send success add funds response + pchapi2.finishWaitingCalls(types.MessageReceipt{ + ExitCode: 0, + Return: []byte{}, + }) + + time.Sleep(time.Millisecond * 10) + + // Should have one channel, whose address is the channel that was created + cis, err := mgr2.ListChannels() + require.NoError(t, err) + require.Len(t, cis, 1) + require.Equal(t, ch, cis[0]) + + // Amount should include amount for successful add funds msg + ci, err := mgr2.GetChannelInfo(ch) + require.NoError(t, err) + require.Equal(t, amt.Int64()+amt2.Int64(), ci.Amount.Int64()) + require.EqualValues(t, 0, ci.PendingAmount.Int64()) + require.Nil(t, ci.CreateMsg) + require.Nil(t, ci.AddFundsMsg) +} + +// TestPaychGetWait tests that GetPaychWaitReady correctly waits for the +// channel to be created or funds to be added +func TestPaychGetWait(t *testing.T) { + ctx := context.Background() + store := NewStore(ds_sync.MutexWrap(ds.NewMapDatastore())) + + from := tutils.NewIDAddr(t, 101) + to := tutils.NewIDAddr(t, 102) + + sm := newMockStateManager() + pchapi := newMockPaychAPI() + defer pchapi.close() + + mgr, err := newManager(sm, store, pchapi) + require.NoError(t, err) + + // 1. Get + amt := big.NewInt(10) + _, mcid, err := mgr.GetPaych(ctx, from, to, amt) + require.NoError(t, err) + + done := make(chan address.Address) + go func() { + // 2. Wait till ready + ch, err := mgr.GetPaychWaitReady(ctx, mcid) + require.NoError(t, err) + + done <- ch + }() + + time.Sleep(time.Millisecond * 10) + + // 3. Send response + expch := tutils.NewIDAddr(t, 100) + response := testChannelResponse(t, expch) + pchapi.finishWaitingCalls(response) + + time.Sleep(time.Millisecond * 10) + + ch := <-done + require.Equal(t, expch, ch) + + // 4. Wait again - message has already been received so should + // return immediately + ch, err = mgr.GetPaychWaitReady(ctx, mcid) + require.NoError(t, err) + require.Equal(t, expch, ch) + + // Request add funds + amt2 := big.NewInt(15) + _, addFundsMsgCid, err := mgr.GetPaych(ctx, from, to, amt2) + require.NoError(t, err) + + time.Sleep(time.Millisecond * 10) + + go func() { + // 5. Wait for add funds + ch, err := mgr.GetPaychWaitReady(ctx, addFundsMsgCid) + require.NoError(t, err) + require.Equal(t, expch, ch) + + done <- ch + }() + + time.Sleep(time.Millisecond * 10) + + // 6. Send add funds response + addFundsResponse := types.MessageReceipt{ + ExitCode: 0, + Return: []byte{}, + } + pchapi.finishWaitingCalls(addFundsResponse) + + <-done +} + +// TestPaychGetWaitErr tests that GetPaychWaitReady correctly handles errors +func TestPaychGetWaitErr(t *testing.T) { + ctx := context.Background() + store := NewStore(ds_sync.MutexWrap(ds.NewMapDatastore())) + + from := tutils.NewIDAddr(t, 101) + to := tutils.NewIDAddr(t, 102) + + sm := newMockStateManager() + pchapi := newMockPaychAPI() + defer pchapi.close() + + mgr, err := newManager(sm, store, pchapi) + require.NoError(t, err) + + // 1. Create channel + amt := big.NewInt(10) + _, mcid, err := mgr.GetPaych(ctx, from, to, amt) + require.NoError(t, err) + + done := make(chan address.Address) + go func() { + defer close(done) + + // 2. Wait for channel to be created + _, err := mgr.GetPaychWaitReady(ctx, mcid) + + // 4. Channel creation should have failed + require.NotNil(t, err) + + // 5. Call wait again with the same message CID + _, err = mgr.GetPaychWaitReady(ctx, mcid) + + // 6. Should return immediately with the same error + require.NotNil(t, err) + }() + + // Give the wait a moment to start before sending response + time.Sleep(time.Millisecond * 10) + + // 3. Send error response to create channel + response := types.MessageReceipt{ + ExitCode: 1, // error + Return: []byte{}, + } + pchapi.finishWaitingCalls(response) + + <-done +} + +// TestPaychGetWaitCtx tests that GetPaychWaitReady returns early if the context +// is cancelled +func TestPaychGetWaitCtx(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + store := NewStore(ds_sync.MutexWrap(ds.NewMapDatastore())) + + from := tutils.NewIDAddr(t, 101) + to := tutils.NewIDAddr(t, 102) + + sm := newMockStateManager() + pchapi := newMockPaychAPI() + defer pchapi.close() + + mgr, err := newManager(sm, store, pchapi) + require.NoError(t, err) + + amt := big.NewInt(10) + _, mcid, err := mgr.GetPaych(ctx, from, to, amt) + require.NoError(t, err) + + // When the context is cancelled, should unblock wait + go func() { + time.Sleep(time.Millisecond * 10) + cancel() + }() + + _, err = mgr.GetPaychWaitReady(ctx, mcid) + require.Error(t, ctx.Err(), err) +} diff --git a/paychmgr/settle_test.go b/paychmgr/settle_test.go new file mode 100644 index 0000000000..a60351d4d8 --- /dev/null +++ b/paychmgr/settle_test.go @@ -0,0 +1,72 @@ +package paychmgr + +import ( + "context" + "testing" + "time" + + "github.com/ipfs/go-cid" + + "github.com/filecoin-project/specs-actors/actors/abi/big" + tutils "github.com/filecoin-project/specs-actors/support/testing" + ds "github.com/ipfs/go-datastore" + ds_sync "github.com/ipfs/go-datastore/sync" + "github.com/stretchr/testify/require" +) + +func TestPaychSettle(t *testing.T) { + ctx := context.Background() + store := NewStore(ds_sync.MutexWrap(ds.NewMapDatastore())) + + expch := tutils.NewIDAddr(t, 100) + expch2 := tutils.NewIDAddr(t, 101) + from := tutils.NewIDAddr(t, 101) + to := tutils.NewIDAddr(t, 102) + + sm := newMockStateManager() + pchapi := newMockPaychAPI() + defer pchapi.close() + + mgr, err := newManager(sm, store, pchapi) + require.NoError(t, err) + + amt := big.NewInt(10) + _, mcid, err := mgr.GetPaych(ctx, from, to, amt) + require.NoError(t, err) + + time.Sleep(10 * time.Millisecond) + + // Send channel create response + response := testChannelResponse(t, expch) + pchapi.finishWaitingCalls(response) + + // Get the channel address + ch, err := mgr.GetPaychWaitReady(ctx, mcid) + require.NoError(t, err) + require.Equal(t, expch, ch) + + // Settle the channel + _, err = mgr.Settle(ctx, ch) + require.NoError(t, err) + + // Send another request for funds to the same from/to + // (should create a new channel because the previous channel + // is settling) + amt2 := big.NewInt(5) + _, mcid2, err := mgr.GetPaych(ctx, from, to, amt2) + require.NoError(t, err) + require.NotEqual(t, cid.Undef, mcid2) + + time.Sleep(10 * time.Millisecond) + + // Send new channel create response + response2 := testChannelResponse(t, expch2) + pchapi.finishWaitingCalls(response2) + + time.Sleep(10 * time.Millisecond) + + // Make sure the new channel is different from the old channel + ch2, err := mgr.GetPaychWaitReady(ctx, mcid2) + require.NoError(t, err) + require.NotEqual(t, ch, ch2) +} diff --git a/paychmgr/simple.go b/paychmgr/simple.go index 0d0075d626..0113f70612 100644 --- a/paychmgr/simple.go +++ b/paychmgr/simple.go @@ -3,6 +3,11 @@ package paychmgr import ( "bytes" "context" + "fmt" + + "github.com/filecoin-project/lotus/api" + + "github.com/filecoin-project/specs-actors/actors/abi/big" "github.com/filecoin-project/specs-actors/actors/builtin" init_ "github.com/filecoin-project/specs-actors/actors/builtin/init" @@ -17,7 +22,211 @@ import ( "github.com/filecoin-project/lotus/chain/types" ) -func (pm *Manager) createPaych(ctx context.Context, from, to address.Address, amt types.BigInt) (cid.Cid, error) { +type paychApi interface { + StateWaitMsg(ctx context.Context, msg cid.Cid, confidence uint64) (*api.MsgLookup, error) + MpoolPushMessage(ctx context.Context, msg *types.Message) (*types.SignedMessage, error) +} + +// paychFundsRes is the response to a create channel or add funds request +type paychFundsRes struct { + channel address.Address + mcid cid.Cid + err error +} + +type onCompleteFn func(*paychFundsRes) + +// fundsReq is a request to create a channel or add funds to a channel +type fundsReq struct { + ctx context.Context + from address.Address + to address.Address + amt types.BigInt + onComplete onCompleteFn +} + +// getPaych ensures that a channel exists between the from and to addresses, +// and adds the given amount of funds. +// If the channel does not exist a create channel message is sent and the +// message CID is returned. +// If the channel does exist an add funds message is sent and both the channel +// address and message CID are returned. +// If there is an in progress operation (create channel / add funds), getPaych +// blocks until the previous operation completes, then returns both the channel +// address and the CID of the new add funds message. +// If an operation returns an error, subsequent waiting operations will still +// be attempted. +func (ca *channelAccessor) getPaych(ctx context.Context, from, to address.Address, amt types.BigInt) (address.Address, cid.Cid, error) { + // Add the request to add funds to a queue and wait for the result + promise := ca.enqueue(&fundsReq{ctx: ctx, from: from, to: to, amt: amt}) + select { + case res := <-promise: + return res.channel, res.mcid, res.err + case <-ctx.Done(): + return address.Undef, cid.Undef, ctx.Err() + } +} + +// Queue up an add funds operation +func (ca *channelAccessor) enqueue(task *fundsReq) chan *paychFundsRes { + promise := make(chan *paychFundsRes) + task.onComplete = func(res *paychFundsRes) { + select { + case <-task.ctx.Done(): + case promise <- res: + } + } + + ca.lk.Lock() + defer ca.lk.Unlock() + + ca.fundsReqQueue = append(ca.fundsReqQueue, task) + go ca.processNextQueueItem() + + return promise +} + +// Run the operation at the head of the queue +func (ca *channelAccessor) processNextQueueItem() { + ca.lk.Lock() + defer ca.lk.Unlock() + + if len(ca.fundsReqQueue) == 0 { + return + } + + head := ca.fundsReqQueue[0] + res := ca.processTask(head.ctx, head.from, head.to, head.amt, head.onComplete) + + // If the task is waiting on an external event (eg something to appear on + // chain) it will return nil + if res == nil { + // Stop processing the fundsReqQueue and wait. When the event occurs it will + // call processNextQueueItem() again + return + } + + // The task has finished processing so clean it up + ca.fundsReqQueue[0] = nil // allow GC of element + ca.fundsReqQueue = ca.fundsReqQueue[1:] + + // Call the task callback with its results + head.onComplete(res) + + // Process the next task + if len(ca.fundsReqQueue) > 0 { + go ca.processNextQueueItem() + } +} + +// msgWaitComplete is called when the message for a previous task is confirmed +// or there is an error. +func (ca *channelAccessor) msgWaitComplete(mcid cid.Cid, err error, cb onCompleteFn) { + ca.lk.Lock() + defer ca.lk.Unlock() + + // Save the message result to the store + dserr := ca.store.SaveMessageResult(mcid, err) + if dserr != nil { + log.Errorf("saving message result: %s", dserr) + } + + // Call the onComplete callback + ca.callOnComplete(mcid, err, cb) + + // Inform listeners that the message has completed + ca.msgListeners.fireMsgComplete(mcid, err) + + // The queue may have been waiting for msg completion to proceed, so + // process the next queue item + if len(ca.fundsReqQueue) > 0 { + go ca.processNextQueueItem() + } +} + +// callOnComplete calls the onComplete callback for a task +func (ca *channelAccessor) callOnComplete(mcid cid.Cid, err error, cb onCompleteFn) { + if cb == nil { + return + } + + if err != nil { + go cb(&paychFundsRes{err: err}) + return + } + + // Get the channel address + ci, storeErr := ca.store.ByMessageCid(mcid) + if storeErr != nil { + log.Errorf("getting channel by message cid: %s", err) + go cb(&paychFundsRes{err: storeErr}) + return + } + + if ci.Channel == nil { + panic("channel address is nil when calling onComplete callback") + } + + go cb(&paychFundsRes{channel: *ci.Channel, mcid: mcid, err: err}) +} + +// processTask checks the state of the channel and takes appropriate action +// (see description of getPaych). +// Note that processTask may be called repeatedly in the same state, and should +// return nil if there is no state change to be made (eg when waiting for a +// message to be confirmed on chain) +func (ca *channelAccessor) processTask( + ctx context.Context, + from address.Address, + to address.Address, + amt types.BigInt, + onComplete onCompleteFn, +) *paychFundsRes { + // Get the payment channel for the from/to addresses. + // Note: It's ok if we get ErrChannelNotTracked. It just means we need to + // create a channel. + channelInfo, err := ca.store.OutboundActiveByFromTo(from, to) + if err != nil && err != ErrChannelNotTracked { + return &paychFundsRes{err: err} + } + + // If a channel has not yet been created, create one. + // Note that if the previous attempt to create the channel failed because of a VM error + // (eg not enough gas), both channelInfo.Channel and channelInfo.CreateMsg will be nil. + if channelInfo == nil || channelInfo.Channel == nil && channelInfo.CreateMsg == nil { + mcid, err := ca.createPaych(ctx, from, to, amt, onComplete) + if err != nil { + return &paychFundsRes{err: err} + } + + return &paychFundsRes{mcid: mcid} + } + + // If the create channel message has been sent but the channel hasn't + // been created on chain yet + if channelInfo.CreateMsg != nil { + // Wait for the channel to be created before trying again + return nil + } + + // If an add funds message was sent to the chain but hasn't been confirmed + // on chain yet + if channelInfo.AddFundsMsg != nil { + // Wait for the add funds message to be confirmed before trying again + return nil + } + + // We need to add more funds, so send an add funds message to + // cover the amount for this request + mcid, err := ca.addFunds(ctx, from, to, amt, onComplete) + if err != nil { + return &paychFundsRes{err: err} + } + return &paychFundsRes{channel: *channelInfo.Channel, mcid: *mcid} +} + +// createPaych sends a message to create the channel and returns the message cid +func (ca *channelAccessor) createPaych(ctx context.Context, from, to address.Address, amt types.BigInt, cb onCompleteFn) (cid.Cid, error) { params, aerr := actors.SerializeParams(&paych.ConstructorParams{From: from, To: to}) if aerr != nil { return cid.Undef, aerr @@ -41,106 +250,268 @@ func (pm *Manager) createPaych(ctx context.Context, from, to address.Address, am GasPrice: types.NewInt(0), } - smsg, err := pm.mpool.MpoolPushMessage(ctx, msg) + smsg, err := ca.api.MpoolPushMessage(ctx, msg) if err != nil { return cid.Undef, xerrors.Errorf("initializing paych actor: %w", err) } mcid := smsg.Cid() - go pm.waitForPaychCreateMsg(ctx, mcid) + + // Create a new channel in the store + if _, err := ca.store.createChannel(from, to, mcid, amt); err != nil { + log.Errorf("creating channel: %s", err) + return cid.Undef, err + } + + // Wait for the channel to be created on chain + go ca.waitForPaychCreateMsg(from, to, mcid, cb) + return mcid, nil } -// WaitForPaychCreateMsg waits for mcid to appear on chain and returns the robust address of the +// waitForPaychCreateMsg waits for mcid to appear on chain and stores the robust address of the // created payment channel -// TODO: wait outside the store lock! -// (tricky because we need to setup channel tracking before we know its address) -func (pm *Manager) waitForPaychCreateMsg(ctx context.Context, mcid cid.Cid) { - defer pm.store.lk.Unlock() - mwait, err := pm.state.StateWaitMsg(ctx, mcid, build.MessageConfidence) +func (ca *channelAccessor) waitForPaychCreateMsg(from address.Address, to address.Address, mcid cid.Cid, cb onCompleteFn) { + err := ca.waitPaychCreateMsg(from, to, mcid) + ca.msgWaitComplete(mcid, err, cb) +} + +func (ca *channelAccessor) waitPaychCreateMsg(from address.Address, to address.Address, mcid cid.Cid) error { + mwait, err := ca.api.StateWaitMsg(ca.waitCtx, mcid, build.MessageConfidence) if err != nil { log.Errorf("wait msg: %w", err) - return + return err } if mwait.Receipt.ExitCode != 0 { - log.Errorf("payment channel creation failed (exit code %d)", mwait.Receipt.ExitCode) - return + err := xerrors.Errorf("payment channel creation failed (exit code %d)", mwait.Receipt.ExitCode) + log.Error(err) + + ca.lk.Lock() + defer ca.lk.Unlock() + + ca.mutateChannelInfo(from, to, func(channelInfo *ChannelInfo) { + channelInfo.PendingAmount = big.NewInt(0) + channelInfo.CreateMsg = nil + }) + + return err } var decodedReturn init_.ExecReturn err = decodedReturn.UnmarshalCBOR(bytes.NewReader(mwait.Receipt.Return)) if err != nil { log.Error(err) - return + return err } - paychaddr := decodedReturn.RobustAddress - ci, err := pm.loadStateChannelInfo(ctx, paychaddr, DirOutbound) - if err != nil { - log.Errorf("loading channel info: %w", err) - return - } + ca.lk.Lock() + defer ca.lk.Unlock() - if err := pm.store.trackChannel(ci); err != nil { - log.Errorf("tracking channel: %w", err) - } + // Store robust address of channel + ca.mutateChannelInfo(from, to, func(channelInfo *ChannelInfo) { + channelInfo.Channel = &decodedReturn.RobustAddress + channelInfo.Amount = channelInfo.PendingAmount + channelInfo.PendingAmount = big.NewInt(0) + channelInfo.CreateMsg = nil + }) + + return nil } -func (pm *Manager) addFunds(ctx context.Context, ch address.Address, from address.Address, amt types.BigInt) (cid.Cid, error) { +// addFunds sends a message to add funds to the channel and returns the message cid +func (ca *channelAccessor) addFunds(ctx context.Context, from address.Address, to address.Address, amt types.BigInt, cb onCompleteFn) (*cid.Cid, error) { + channelInfo, err := ca.store.OutboundActiveByFromTo(from, to) + if err != nil { + return nil, err + } + msg := &types.Message{ - To: ch, - From: from, + To: *channelInfo.Channel, + From: channelInfo.Control, Value: amt, Method: 0, GasLimit: 0, GasPrice: types.NewInt(0), } - smsg, err := pm.mpool.MpoolPushMessage(ctx, msg) + smsg, err := ca.api.MpoolPushMessage(ctx, msg) if err != nil { - return cid.Undef, err + return nil, err } mcid := smsg.Cid() - go pm.waitForAddFundsMsg(ctx, mcid) - return mcid, nil + + // Store the add funds message CID on the channel + ca.mutateChannelInfo(from, to, func(ci *ChannelInfo) { + ci.PendingAmount = amt + ci.AddFundsMsg = &mcid + }) + + // Store a reference from the message CID to the channel, so that we can + // look up the channel from the message CID + err = ca.store.SaveNewMessage(channelInfo.ChannelID, mcid) + if err != nil { + log.Errorf("saving add funds message CID %s: %s", mcid, err) + } + + go ca.waitForAddFundsMsg(from, to, mcid, cb) + + return &mcid, nil +} + +// waitForAddFundsMsg waits for mcid to appear on chain and returns error, if any +func (ca *channelAccessor) waitForAddFundsMsg(from address.Address, to address.Address, mcid cid.Cid, cb onCompleteFn) { + err := ca.waitAddFundsMsg(from, to, mcid) + ca.msgWaitComplete(mcid, err, cb) } -// WaitForAddFundsMsg waits for mcid to appear on chain and returns error, if any -// TODO: wait outside the store lock! -// (tricky because we need to setup channel tracking before we know it's address) -func (pm *Manager) waitForAddFundsMsg(ctx context.Context, mcid cid.Cid) { - defer pm.store.lk.Unlock() - mwait, err := pm.state.StateWaitMsg(ctx, mcid, build.MessageConfidence) +func (ca *channelAccessor) waitAddFundsMsg(from address.Address, to address.Address, mcid cid.Cid) error { + mwait, err := ca.api.StateWaitMsg(ca.waitCtx, mcid, build.MessageConfidence) if err != nil { log.Error(err) + return err } if mwait.Receipt.ExitCode != 0 { - log.Errorf("voucher channel creation failed: adding funds (exit code %d)", mwait.Receipt.ExitCode) + err := xerrors.Errorf("voucher channel creation failed: adding funds (exit code %d)", mwait.Receipt.ExitCode) + log.Error(err) + + ca.lk.Lock() + defer ca.lk.Unlock() + + ca.mutateChannelInfo(from, to, func(channelInfo *ChannelInfo) { + channelInfo.PendingAmount = big.NewInt(0) + channelInfo.AddFundsMsg = nil + }) + + return err } -} -func (pm *Manager) GetPaych(ctx context.Context, from, to address.Address, ensureFree types.BigInt) (address.Address, cid.Cid, error) { - pm.store.lk.Lock() // unlock only on err; wait funcs will defer unlock - var mcid cid.Cid - ch, err := pm.store.findChan(func(ci *ChannelInfo) bool { - if ci.Direction != DirOutbound { - return false - } - return ci.Control == from && ci.Target == to + ca.lk.Lock() + defer ca.lk.Unlock() + + // Store updated amount + ca.mutateChannelInfo(from, to, func(channelInfo *ChannelInfo) { + channelInfo.Amount = types.BigAdd(channelInfo.Amount, channelInfo.PendingAmount) + channelInfo.PendingAmount = big.NewInt(0) + channelInfo.AddFundsMsg = nil }) + + return nil +} + +// Change the state of the channel in the store +func (ca *channelAccessor) mutateChannelInfo(from address.Address, to address.Address, mutate func(*ChannelInfo)) { + channelInfo, err := ca.store.OutboundActiveByFromTo(from, to) + + // If there's an error reading or writing to the store just log an error. + // For now we're assuming it's unlikely to happen in practice. + // Later we may want to implement a transactional approach, whereby + // we record to the store that we're going to send a message, send + // the message, and then record that the message was sent. if err != nil { - pm.store.lk.Unlock() - return address.Undef, cid.Undef, xerrors.Errorf("findChan: %w", err) + log.Errorf("Error reading channel info from store: %s", err) } - if ch != address.Undef { - // TODO: Track available funds - mcid, err = pm.addFunds(ctx, ch, from, ensureFree) - } else { - mcid, err = pm.createPaych(ctx, from, to, ensureFree) + + mutate(channelInfo) + + err = ca.store.putChannelInfo(channelInfo) + if err != nil { + log.Errorf("Error writing channel info to store: %s", err) } +} + +// getPaychWaitReady waits for a the response to the message with the given cid +func (ca *channelAccessor) getPaychWaitReady(ctx context.Context, mcid cid.Cid) (address.Address, error) { + ca.lk.Lock() + + // First check if the message has completed + msgInfo, err := ca.store.GetMessage(mcid) if err != nil { - pm.store.lk.Unlock() + ca.lk.Unlock() + + return address.Undef, err + } + + // If the create channel / add funds message failed, return an error + if len(msgInfo.Err) > 0 { + ca.lk.Unlock() + + return address.Undef, xerrors.New(msgInfo.Err) + } + + // If the message has completed successfully + if msgInfo.Received { + ca.lk.Unlock() + + // Get the channel address + ci, err := ca.store.ByMessageCid(mcid) + if err != nil { + return address.Undef, err + } + + if ci.Channel == nil { + panic(fmt.Sprintf("create / add funds message %s succeeded but channelInfo.Channel is nil", mcid)) + } + return *ci.Channel, nil } - return ch, mcid, err + + // The message hasn't completed yet so wait for it to complete + promise := ca.msgPromise(ctx, mcid) + + // Unlock while waiting + ca.lk.Unlock() + + select { + case res := <-promise: + return res.channel, res.err + case <-ctx.Done(): + return address.Undef, ctx.Err() + } +} + +type onMsgRes struct { + channel address.Address + err error +} + +// msgPromise returns a channel that receives the result of the message with +// the given CID +func (ca *channelAccessor) msgPromise(ctx context.Context, mcid cid.Cid) chan onMsgRes { + promise := make(chan onMsgRes) + triggerUnsub := make(chan struct{}) + sub := ca.msgListeners.onMsg(mcid, func(err error) { + close(triggerUnsub) + + // Use a go-routine so as not to block the event handler loop + go func() { + res := onMsgRes{err: err} + if res.err == nil { + // Get the channel associated with the message cid + ci, err := ca.store.ByMessageCid(mcid) + if err != nil { + res.err = err + } else { + res.channel = *ci.Channel + } + } + + // Pass the result to the caller + select { + case promise <- res: + case <-ctx.Done(): + } + }() + }) + + // Unsubscribe when the message is received or the context is done + go func() { + select { + case <-ctx.Done(): + case <-triggerUnsub: + } + + ca.msgListeners.unsubscribe(sub) + }() + + return promise } diff --git a/paychmgr/state.go b/paychmgr/state.go index 7d06a35a4d..9ba2740e69 100644 --- a/paychmgr/state.go +++ b/paychmgr/state.go @@ -3,20 +3,21 @@ package paychmgr import ( "context" - "github.com/filecoin-project/specs-actors/actors/abi/big" - "github.com/filecoin-project/specs-actors/actors/builtin/account" "github.com/filecoin-project/go-address" "github.com/filecoin-project/specs-actors/actors/builtin/paych" - xerrors "golang.org/x/xerrors" "github.com/filecoin-project/lotus/chain/types" ) -func (pm *Manager) loadPaychState(ctx context.Context, ch address.Address) (*types.Actor, *paych.State, error) { +type stateAccessor struct { + sm StateManagerApi +} + +func (ca *stateAccessor) loadPaychState(ctx context.Context, ch address.Address) (*types.Actor, *paych.State, error) { var pcast paych.State - act, err := pm.sm.LoadActorState(ctx, ch, &pcast, nil) + act, err := ca.sm.LoadActorState(ctx, ch, &pcast, nil) if err != nil { return nil, nil, err } @@ -24,26 +25,26 @@ func (pm *Manager) loadPaychState(ctx context.Context, ch address.Address) (*typ return act, &pcast, nil } -func (pm *Manager) loadStateChannelInfo(ctx context.Context, ch address.Address, dir uint64) (*ChannelInfo, error) { - _, st, err := pm.loadPaychState(ctx, ch) +func (ca *stateAccessor) loadStateChannelInfo(ctx context.Context, ch address.Address, dir uint64) (*ChannelInfo, error) { + _, st, err := ca.loadPaychState(ctx, ch) if err != nil { return nil, err } var account account.State - _, err = pm.sm.LoadActorState(ctx, st.From, &account, nil) + _, err = ca.sm.LoadActorState(ctx, st.From, &account, nil) if err != nil { return nil, err } from := account.Address - _, err = pm.sm.LoadActorState(ctx, st.To, &account, nil) + _, err = ca.sm.LoadActorState(ctx, st.To, &account, nil) if err != nil { return nil, err } to := account.Address ci := &ChannelInfo{ - Channel: ch, + Channel: &ch, Direction: dir, NextLane: nextLaneFromState(st), } @@ -72,80 +73,3 @@ func nextLaneFromState(st *paych.State) uint64 { } return maxLane + 1 } - -// laneState gets the LaneStates from chain, then applies all vouchers in -// the data store over the chain state -func (pm *Manager) laneState(state *paych.State, ch address.Address) (map[uint64]*paych.LaneState, error) { - // TODO: we probably want to call UpdateChannelState with all vouchers to be fully correct - // (but technically dont't need to) - laneStates := make(map[uint64]*paych.LaneState, len(state.LaneStates)) - - // Get the lane state from the chain - for _, laneState := range state.LaneStates { - laneStates[laneState.ID] = laneState - } - - // Apply locally stored vouchers - vouchers, err := pm.store.VouchersForPaych(ch) - if err != nil && err != ErrChannelNotTracked { - return nil, err - } - - for _, v := range vouchers { - for range v.Voucher.Merges { - return nil, xerrors.Errorf("paych merges not handled yet") - } - - // If there's a voucher for a lane that isn't in chain state just - // create it - ls, ok := laneStates[v.Voucher.Lane] - if !ok { - ls = &paych.LaneState{ - ID: v.Voucher.Lane, - Redeemed: types.NewInt(0), - Nonce: 0, - } - laneStates[v.Voucher.Lane] = ls - } - - if v.Voucher.Nonce < ls.Nonce { - continue - } - - ls.Nonce = v.Voucher.Nonce - ls.Redeemed = v.Voucher.Amount - } - - return laneStates, nil -} - -// Get the total redeemed amount across all lanes, after applying the voucher -func (pm *Manager) totalRedeemedWithVoucher(laneStates map[uint64]*paych.LaneState, sv *paych.SignedVoucher) (big.Int, error) { - // TODO: merges - if len(sv.Merges) != 0 { - return big.Int{}, xerrors.Errorf("dont currently support paych lane merges") - } - - total := big.NewInt(0) - for _, ls := range laneStates { - total = big.Add(total, ls.Redeemed) - } - - lane, ok := laneStates[sv.Lane] - if ok { - // If the voucher is for an existing lane, and the voucher nonce - // and is higher than the lane nonce - if sv.Nonce > lane.Nonce { - // Add the delta between the redeemed amount and the voucher - // amount to the total - delta := big.Sub(sv.Amount, lane.Redeemed) - total = big.Add(total, delta) - } - } else { - // If the voucher is *not* for an existing lane, just add its - // value (implicitly a new lane will be created for the voucher) - total = big.Add(total, sv.Amount) - } - - return total, nil -} diff --git a/paychmgr/store.go b/paychmgr/store.go index 66a514feb2..bf7b3ee9bb 100644 --- a/paychmgr/store.go +++ b/paychmgr/store.go @@ -4,14 +4,16 @@ import ( "bytes" "errors" "fmt" - "strings" - "sync" + + "github.com/google/uuid" + + "github.com/filecoin-project/lotus/chain/types" "github.com/filecoin-project/specs-actors/actors/builtin/paych" + "github.com/ipfs/go-cid" "github.com/ipfs/go-datastore" "github.com/ipfs/go-datastore/namespace" dsq "github.com/ipfs/go-datastore/query" - "golang.org/x/xerrors" "github.com/filecoin-project/go-address" cborrpc "github.com/filecoin-project/go-cbor-util" @@ -22,8 +24,6 @@ import ( var ErrChannelNotTracked = errors.New("channel not tracked") type Store struct { - lk sync.Mutex // TODO: this can be split per paych - ds datastore.Batching } @@ -39,85 +39,107 @@ const ( DirOutbound = 2 ) +const ( + dsKeyChannelInfo = "ChannelInfo" + dsKeyMsgCid = "MsgCid" +) + type VoucherInfo struct { Voucher *paych.SignedVoucher Proof []byte } +// ChannelInfo keeps track of information about a channel type ChannelInfo struct { - Channel address.Address + // ChannelID is a uuid set at channel creation + ChannelID string + // Channel address - may be nil if the channel hasn't been created yet + Channel *address.Address + // Control is the address of the account that created the channel Control address.Address - Target address.Address - + // Target is the address of the account on the other end of the channel + Target address.Address + // Direction indicates if the channel is inbound (this node is the Target) + // or outbound (this node is the Control) Direction uint64 - Vouchers []*VoucherInfo - NextLane uint64 + // Vouchers is a list of all vouchers sent on the channel + Vouchers []*VoucherInfo + // NextLane is the number of the next lane that should be used when the + // client requests a new lane (eg to create a voucher for a new deal) + NextLane uint64 + // Amount added to the channel. + // Note: This amount is only used by GetPaych to keep track of how much + // has locally been added to the channel. It should reflect the channel's + // Balance on chain as long as all operations occur on the same datastore. + Amount types.BigInt + // PendingAmount is the amount that we're awaiting confirmation of + PendingAmount types.BigInt + // CreateMsg is the CID of a pending create message (while waiting for confirmation) + CreateMsg *cid.Cid + // AddFundsMsg is the CID of a pending add funds message (while waiting for confirmation) + AddFundsMsg *cid.Cid + // Settling indicates whether the channel has entered into the settling state + Settling bool } -func dskeyForChannel(addr address.Address) datastore.Key { - return datastore.NewKey(addr.String()) -} - -func (ps *Store) putChannelInfo(ci *ChannelInfo) error { - k := dskeyForChannel(ci.Channel) - - b, err := cborrpc.Dump(ci) - if err != nil { +// TrackChannel stores a channel, returning an error if the channel was already +// being tracked +func (ps *Store) TrackChannel(ci *ChannelInfo) error { + _, err := ps.ByAddress(*ci.Channel) + switch err { + default: return err + case nil: + return fmt.Errorf("already tracking channel: %s", ci.Channel) + case ErrChannelNotTracked: + return ps.putChannelInfo(ci) } - - return ps.ds.Put(k, b) } -func (ps *Store) getChannelInfo(addr address.Address) (*ChannelInfo, error) { - k := dskeyForChannel(addr) - - b, err := ps.ds.Get(k) - if err == datastore.ErrNotFound { - return nil, ErrChannelNotTracked - } +// ListChannels returns the addresses of all channels that have been created +func (ps *Store) ListChannels() ([]address.Address, error) { + cis, err := ps.findChans(func(ci *ChannelInfo) bool { + return ci.Channel != nil + }, 0) if err != nil { return nil, err } - var ci ChannelInfo - if err := ci.UnmarshalCBOR(bytes.NewReader(b)); err != nil { - return nil, err + addrs := make([]address.Address, 0, len(cis)) + for _, ci := range cis { + addrs = append(addrs, *ci.Channel) } - return &ci, nil + return addrs, nil } -func (ps *Store) TrackChannel(ch *ChannelInfo) error { - ps.lk.Lock() - defer ps.lk.Unlock() - - return ps.trackChannel(ch) -} +// findChan finds a single channel using the given filter. +// If there isn't a channel that matches the filter, returns ErrChannelNotTracked +func (ps *Store) findChan(filter func(ci *ChannelInfo) bool) (*ChannelInfo, error) { + cis, err := ps.findChans(filter, 1) + if err != nil { + return nil, err + } -func (ps *Store) trackChannel(ch *ChannelInfo) error { - _, err := ps.getChannelInfo(ch.Channel) - switch err { - default: - return err - case nil: - return fmt.Errorf("already tracking channel: %s", ch.Channel) - case ErrChannelNotTracked: - return ps.putChannelInfo(ch) + if len(cis) == 0 { + return nil, ErrChannelNotTracked } -} -func (ps *Store) ListChannels() ([]address.Address, error) { - ps.lk.Lock() - defer ps.lk.Unlock() + return &cis[0], err +} - res, err := ps.ds.Query(dsq.Query{KeysOnly: true}) +// findChans loops over all channels, only including those that pass the filter. +// max is the maximum number of channels to return. Set to zero to return unlimited channels. +func (ps *Store) findChans(filter func(*ChannelInfo) bool, max int) ([]ChannelInfo, error) { + res, err := ps.ds.Query(dsq.Query{Prefix: dsKeyChannelInfo}) if err != nil { return nil, err } defer res.Close() //nolint:errcheck - var out []address.Address + var stored ChannelInfo + var matches []ChannelInfo + for { res, ok := res.NextSync() if !ok { @@ -128,75 +150,245 @@ func (ps *Store) ListChannels() ([]address.Address, error) { return nil, err } - addr, err := address.NewFromString(strings.TrimPrefix(res.Key, "/")) + ci, err := unmarshallChannelInfo(&stored, res) if err != nil { - return nil, xerrors.Errorf("failed reading paych key (%q) from datastore: %w", res.Key, err) + return nil, err + } + + if !filter(ci) { + continue } - out = append(out, addr) + matches = append(matches, *ci) + + // If we've reached the maximum number of matches, return. + // Note that if max is zero we return an unlimited number of matches + // because len(matches) will always be at least 1. + if len(matches) == max { + return matches, nil + } } - return out, nil + return matches, nil } -func (ps *Store) findChan(filter func(*ChannelInfo) bool) (address.Address, error) { - res, err := ps.ds.Query(dsq.Query{}) +// AllocateLane allocates a new lane for the given channel +func (ps *Store) AllocateLane(ch address.Address) (uint64, error) { + ci, err := ps.ByAddress(ch) if err != nil { - return address.Undef, err + return 0, err } - defer res.Close() //nolint:errcheck - var ci ChannelInfo + out := ci.NextLane + ci.NextLane++ - for { - res, ok := res.NextSync() - if !ok { - break - } + return out, ps.putChannelInfo(ci) +} - if res.Error != nil { - return address.Undef, err - } +// VouchersForPaych gets the vouchers for the given channel +func (ps *Store) VouchersForPaych(ch address.Address) ([]*VoucherInfo, error) { + ci, err := ps.ByAddress(ch) + if err != nil { + return nil, err + } - if err := ci.UnmarshalCBOR(bytes.NewReader(res.Value)); err != nil { - return address.Undef, err - } + return ci.Vouchers, nil +} - if !filter(&ci) { - continue +// ByAddress gets the channel that matches the given address +func (ps *Store) ByAddress(addr address.Address) (*ChannelInfo, error) { + return ps.findChan(func(ci *ChannelInfo) bool { + return ci.Channel != nil && *ci.Channel == addr + }) +} + +// MsgInfo stores information about a create channel / add funds message +// that has been sent +type MsgInfo struct { + // ChannelID links the message to a channel + ChannelID string + // MsgCid is the CID of the message + MsgCid cid.Cid + // Received indicates whether a response has been received + Received bool + // Err is the error received in the response + Err string +} + +// The datastore key used to identify the message +func dskeyForMsg(mcid cid.Cid) datastore.Key { + return datastore.KeyWithNamespaces([]string{dsKeyMsgCid, mcid.String()}) +} + +// SaveNewMessage is called when a message is sent +func (ps *Store) SaveNewMessage(channelID string, mcid cid.Cid) error { + k := dskeyForMsg(mcid) + + b, err := cborrpc.Dump(&MsgInfo{ChannelID: channelID, MsgCid: mcid}) + if err != nil { + return err + } + + return ps.ds.Put(k, b) +} + +// SaveMessageResult is called when the result of a message is received +func (ps *Store) SaveMessageResult(mcid cid.Cid, msgErr error) error { + minfo, err := ps.GetMessage(mcid) + if err != nil { + return err + } + + k := dskeyForMsg(mcid) + minfo.Received = true + if msgErr != nil { + minfo.Err = msgErr.Error() + } + + b, err := cborrpc.Dump(minfo) + if err != nil { + return err + } + + return ps.ds.Put(k, b) +} + +// ByMessageCid gets the channel associated with a message +func (ps *Store) ByMessageCid(mcid cid.Cid) (*ChannelInfo, error) { + minfo, err := ps.GetMessage(mcid) + if err != nil { + return nil, err + } + + ci, err := ps.findChan(func(ci *ChannelInfo) bool { + return ci.ChannelID == minfo.ChannelID + }) + if err != nil { + return nil, err + } + + return ci, err +} + +// GetMessage gets the message info for a given message CID +func (ps *Store) GetMessage(mcid cid.Cid) (*MsgInfo, error) { + k := dskeyForMsg(mcid) + + val, err := ps.ds.Get(k) + if err != nil { + return nil, err + } + + var minfo MsgInfo + if err := minfo.UnmarshalCBOR(bytes.NewReader(val)); err != nil { + return nil, err + } + + return &minfo, nil +} + +// OutboundActiveByFromTo looks for outbound channels that have not been +// settled, with the given from / to addresses +func (ps *Store) OutboundActiveByFromTo(from address.Address, to address.Address) (*ChannelInfo, error) { + return ps.findChan(func(ci *ChannelInfo) bool { + if ci.Direction != DirOutbound { + return false + } + if ci.Settling { + return false } + return ci.Control == from && ci.Target == to + }) +} - addr, err := address.NewFromString(strings.TrimPrefix(res.Key, "/")) - if err != nil { - return address.Undef, xerrors.Errorf("failed reading paych key (%q) from datastore: %w", res.Key, err) +// WithPendingAddFunds is used on startup to find channels for which a +// create channel or add funds message has been sent, but lotus shut down +// before the response was received. +func (ps *Store) WithPendingAddFunds() ([]ChannelInfo, error) { + return ps.findChans(func(ci *ChannelInfo) bool { + if ci.Direction != DirOutbound { + return false } + return ci.CreateMsg != nil || ci.AddFundsMsg != nil + }, 0) +} + +// createChannel creates an outbound channel for the given from / to, ensuring +// it has a higher sequence number than any existing channel with the same from / to +func (ps *Store) createChannel(from address.Address, to address.Address, createMsgCid cid.Cid, amt types.BigInt) (*ChannelInfo, error) { + ci := &ChannelInfo{ + ChannelID: uuid.New().String(), + Direction: DirOutbound, + NextLane: 0, + Control: from, + Target: to, + CreateMsg: &createMsgCid, + PendingAmount: amt, + } + + // Save the new channel + err := ps.putChannelInfo(ci) + if err != nil { + return nil, err + } - return addr, nil + // Save a reference to the create message + err = ps.SaveNewMessage(ci.ChannelID, createMsgCid) + if err != nil { + return nil, err } - return address.Undef, nil + return ci, err } -func (ps *Store) AllocateLane(ch address.Address) (uint64, error) { - ps.lk.Lock() - defer ps.lk.Unlock() +// The datastore key used to identify the channel info +func dskeyForChannel(ci *ChannelInfo) datastore.Key { + chanKey := fmt.Sprintf("%s->%s", ci.Control.String(), ci.Target.String()) + return datastore.KeyWithNamespaces([]string{dsKeyChannelInfo, chanKey}) +} + +// putChannelInfo stores the channel info in the datastore +func (ps *Store) putChannelInfo(ci *ChannelInfo) error { + k := dskeyForChannel(ci) - ci, err := ps.getChannelInfo(ch) + b, err := marshallChannelInfo(ci) if err != nil { - return 0, err + return err } - out := ci.NextLane - ci.NextLane++ - - return out, ps.putChannelInfo(ci) + return ps.ds.Put(k, b) } -func (ps *Store) VouchersForPaych(ch address.Address) ([]*VoucherInfo, error) { - ci, err := ps.getChannelInfo(ch) +// TODO: This is a hack to get around not being able to CBOR marshall a nil +// address.Address. It's been fixed in address.Address but we need to wait +// for the change to propagate to specs-actors before we can remove this hack. +var emptyAddr address.Address + +func init() { + addr, err := address.NewActorAddress([]byte("empty")) if err != nil { + panic(err) + } + emptyAddr = addr +} + +func marshallChannelInfo(ci *ChannelInfo) ([]byte, error) { + // See note above about CBOR marshalling address.Address + if ci.Channel == nil { + ci.Channel = &emptyAddr + } + return cborrpc.Dump(ci) +} + +func unmarshallChannelInfo(stored *ChannelInfo, res dsq.Result) (*ChannelInfo, error) { + if err := stored.UnmarshalCBOR(bytes.NewReader(res.Value)); err != nil { return nil, err } - return ci.Vouchers, nil + // See note above about CBOR marshalling address.Address + if stored.Channel != nil && *stored.Channel == emptyAddr { + stored.Channel = nil + } + + return stored, nil } diff --git a/paychmgr/store_test.go b/paychmgr/store_test.go index 0942264646..65be6f1b19 100644 --- a/paychmgr/store_test.go +++ b/paychmgr/store_test.go @@ -17,8 +17,9 @@ func TestStore(t *testing.T) { require.NoError(t, err) require.Len(t, addrs, 0) + ch := tutils.NewIDAddr(t, 100) ci := &ChannelInfo{ - Channel: tutils.NewIDAddr(t, 100), + Channel: &ch, Control: tutils.NewIDAddr(t, 101), Target: tutils.NewIDAddr(t, 102), @@ -26,8 +27,9 @@ func TestStore(t *testing.T) { Vouchers: []*VoucherInfo{{Voucher: nil, Proof: []byte{}}}, } + ch2 := tutils.NewIDAddr(t, 200) ci2 := &ChannelInfo{ - Channel: tutils.NewIDAddr(t, 200), + Channel: &ch2, Control: tutils.NewIDAddr(t, 201), Target: tutils.NewIDAddr(t, 202), @@ -55,7 +57,7 @@ func TestStore(t *testing.T) { require.Contains(t, addrsStrings(addrs), "t0200") // Request vouchers for channel - vouchers, err := store.VouchersForPaych(ci.Channel) + vouchers, err := store.VouchersForPaych(*ci.Channel) require.NoError(t, err) require.Len(t, vouchers, 1) @@ -64,12 +66,12 @@ func TestStore(t *testing.T) { require.Equal(t, err, ErrChannelNotTracked) // Allocate lane for channel - lane, err := store.AllocateLane(ci.Channel) + lane, err := store.AllocateLane(*ci.Channel) require.NoError(t, err) require.Equal(t, lane, uint64(0)) // Allocate next lane for channel - lane, err = store.AllocateLane(ci.Channel) + lane, err = store.AllocateLane(*ci.Channel) require.NoError(t, err) require.Equal(t, lane, uint64(1))