diff --git a/gen/main.go b/gen/main.go index 01cd756f78..ec8faf35b2 100644 --- a/gen/main.go +++ b/gen/main.go @@ -34,7 +34,7 @@ func main() { err = gen.WriteMapEncodersToFile("./paychmgr/cbor_gen.go", "paychmgr", paychmgr.VoucherInfo{}, - paychmgr.ChannelInfo{}, + paychmgr.ChannelInfoStorable{}, ) if err != nil { fmt.Println(err) diff --git a/node/builder.go b/node/builder.go index 16be84d2cd..2d860cc461 100644 --- a/node/builder.go +++ b/node/builder.go @@ -107,6 +107,7 @@ const ( HandleIncomingMessagesKey RegisterClientValidatorKey + HandlePaymentChannelManagerKey // miner GetParamsKey @@ -272,6 +273,7 @@ func Online() Option { Override(new(*paychmgr.Store), paychmgr.NewStore), Override(new(*paychmgr.Manager), paychmgr.NewManager), + Override(HandlePaymentChannelManagerKey, paychmgr.HandleManager), Override(new(*market.FundMgr), market.NewFundMgr), Override(SettlePaymentChannelsKey, settler.SettlePaymentChannels), ), diff --git a/paychmgr/accessorcache.go b/paychmgr/accessorcache.go new file mode 100644 index 0000000000..11f9d3bc96 --- /dev/null +++ b/paychmgr/accessorcache.go @@ -0,0 +1,67 @@ +package paychmgr + +import "github.com/filecoin-project/go-address" + +// accessorByFromTo gets a channel accessor for a given from / to pair. +// The channel accessor facilitates locking a channel so that operations +// must be performed sequentially on a channel (but can be performed at +// the same time on different channels). +func (pm *Manager) accessorByFromTo(from address.Address, to address.Address) (*channelAccessor, error) { + key := pm.accessorCacheKey(from, to) + + // First take a read lock and check the cache + pm.lk.RLock() + ca, ok := pm.channels[key] + pm.lk.RUnlock() + if ok { + return ca, nil + } + + // Not in cache, so take a write lock + pm.lk.Lock() + defer pm.lk.Unlock() + + // Need to check cache again in case it was updated between releasing read + // lock and taking write lock + ca, ok = pm.channels[key] + if !ok { + // Not in cache, so create a new one and store in cache + ca = pm.addAccessorToCache(from, to) + } + + return ca, nil +} + +// accessorByAddress gets a channel accessor for a given channel address. +// The channel accessor facilitates locking a channel so that operations +// must be performed sequentially on a channel (but can be performed at +// the same time on different channels). +func (pm *Manager) accessorByAddress(ch address.Address) (*channelAccessor, error) { + // Get the channel from / to + pm.lk.RLock() + channelInfo, err := pm.store.ByAddress(ch) + pm.lk.RUnlock() + if err != nil { + return nil, err + } + + // TODO: cache by channel address so we can get by address instead of using from / to + return pm.accessorByFromTo(channelInfo.Control, channelInfo.Target) +} + +// accessorCacheKey returns the cache key use to reference a channel accessor +func (pm *Manager) accessorCacheKey(from address.Address, to address.Address) string { + return from.String() + "->" + to.String() +} + +// addAccessorToCache adds a channel accessor to a cache. 
Note that channelInfo +// may be nil if the channel hasn't been created yet, but we still want to +// reference the same channel accessor for a given from/to, so that all +// attempts to access a channel use the same lock (the lock on the accessor) +func (pm *Manager) addAccessorToCache(from address.Address, to address.Address) *channelAccessor { + key := pm.accessorCacheKey(from, to) + ca := newChannelAccessor(pm) + // TODO: Use LRU + pm.channels[key] = ca + return ca +} diff --git a/paychmgr/cbor_gen.go b/paychmgr/cbor_gen.go index 828b0f45c1..5e22efad48 100644 --- a/paychmgr/cbor_gen.go +++ b/paychmgr/cbor_gen.go @@ -147,18 +147,18 @@ func (t *VoucherInfo) UnmarshalCBOR(r io.Reader) error { return nil } -func (t *ChannelInfo) MarshalCBOR(w io.Writer) error { +func (t *ChannelInfoStorable) MarshalCBOR(w io.Writer) error { if t == nil { _, err := w.Write(cbg.CborNull) return err } - if _, err := w.Write([]byte{166}); err != nil { + if _, err := w.Write([]byte{171}); err != nil { return err } scratch := make([]byte, 9) - // t.Channel (address.Address) (struct) + // t.Channel (string) (string) if len("Channel") > cbg.MaxLength { return xerrors.Errorf("Value in field \"Channel\" was too long") } @@ -170,7 +170,14 @@ func (t *ChannelInfo) MarshalCBOR(w io.Writer) error { return err } - if err := t.Channel.MarshalCBOR(w); err != nil { + if len(t.Channel) > cbg.MaxLength { + return xerrors.Errorf("Value in field t.Channel was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len(t.Channel))); err != nil { + return err + } + if _, err := io.WriteString(w, t.Channel); err != nil { return err } @@ -206,6 +213,22 @@ func (t *ChannelInfo) MarshalCBOR(w io.Writer) error { return err } + // t.Sequence (uint64) (uint64) + if len("Sequence") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"Sequence\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("Sequence"))); err != nil { + return err + } + if _, err := io.WriteString(w, "Sequence"); err != nil { + return err + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajUnsignedInt, uint64(t.Sequence)); err != nil { + return err + } + // t.Direction (uint64) (uint64) if len("Direction") > cbg.MaxLength { return xerrors.Errorf("Value in field \"Direction\" was too long") @@ -263,11 +286,87 @@ func (t *ChannelInfo) MarshalCBOR(w io.Writer) error { return err } + // t.Amount (big.Int) (struct) + if len("Amount") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"Amount\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("Amount"))); err != nil { + return err + } + if _, err := io.WriteString(w, "Amount"); err != nil { + return err + } + + if err := t.Amount.MarshalCBOR(w); err != nil { + return err + } + + // t.PendingAmount (big.Int) (struct) + if len("PendingAmount") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"PendingAmount\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("PendingAmount"))); err != nil { + return err + } + if _, err := io.WriteString(w, "PendingAmount"); err != nil { + return err + } + + if err := t.PendingAmount.MarshalCBOR(w); err != nil { + return err + } + + // t.AddFundsMsg (cid.Cid) (struct) + if len("AddFundsMsg") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"AddFundsMsg\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, 
uint64(len("AddFundsMsg"))); err != nil { + return err + } + if _, err := io.WriteString(w, "AddFundsMsg"); err != nil { + return err + } + + if t.AddFundsMsg == nil { + if _, err := w.Write(cbg.CborNull); err != nil { + return err + } + } else { + if err := cbg.WriteCidBuf(scratch, w, *t.AddFundsMsg); err != nil { + return xerrors.Errorf("failed to write cid field t.AddFundsMsg: %w", err) + } + } + + // t.CreateMsg (cid.Cid) (struct) + if len("CreateMsg") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"CreateMsg\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("CreateMsg"))); err != nil { + return err + } + if _, err := io.WriteString(w, "CreateMsg"); err != nil { + return err + } + + if t.CreateMsg == nil { + if _, err := w.Write(cbg.CborNull); err != nil { + return err + } + } else { + if err := cbg.WriteCidBuf(scratch, w, *t.CreateMsg); err != nil { + return xerrors.Errorf("failed to write cid field t.CreateMsg: %w", err) + } + } + return nil } -func (t *ChannelInfo) UnmarshalCBOR(r io.Reader) error { - *t = ChannelInfo{} +func (t *ChannelInfoStorable) UnmarshalCBOR(r io.Reader) error { + *t = ChannelInfoStorable{} br := cbg.GetPeeker(r) scratch := make([]byte, 8) @@ -281,7 +380,7 @@ func (t *ChannelInfo) UnmarshalCBOR(r io.Reader) error { } if extra > cbg.MaxLength { - return fmt.Errorf("ChannelInfo: map struct too large (%d)", extra) + return fmt.Errorf("ChannelInfoStorable: map struct too large (%d)", extra) } var name string @@ -299,15 +398,16 @@ func (t *ChannelInfo) UnmarshalCBOR(r io.Reader) error { } switch name { - // t.Channel (address.Address) (struct) + // t.Channel (string) (string) case "Channel": { - - if err := t.Channel.UnmarshalCBOR(br); err != nil { - return xerrors.Errorf("unmarshaling t.Channel: %w", err) + sval, err := cbg.ReadStringBuf(br, scratch) + if err != nil { + return err } + t.Channel = string(sval) } // t.Control (address.Address) (struct) case "Control": @@ -328,6 +428,21 @@ func (t *ChannelInfo) UnmarshalCBOR(r io.Reader) error { return xerrors.Errorf("unmarshaling t.Target: %w", err) } + } + // t.Sequence (uint64) (uint64) + case "Sequence": + + { + + maj, extra, err = cbg.CborReadHeaderBuf(br, scratch) + if err != nil { + return err + } + if maj != cbg.MajUnsignedInt { + return fmt.Errorf("wrong type for uint64 field") + } + t.Sequence = uint64(extra) + } // t.Direction (uint64) (uint64) case "Direction": @@ -389,6 +504,76 @@ func (t *ChannelInfo) UnmarshalCBOR(r io.Reader) error { t.NextLane = uint64(extra) } + // t.Amount (big.Int) (struct) + case "Amount": + + { + + if err := t.Amount.UnmarshalCBOR(br); err != nil { + return xerrors.Errorf("unmarshaling t.Amount: %w", err) + } + + } + // t.PendingAmount (big.Int) (struct) + case "PendingAmount": + + { + + if err := t.PendingAmount.UnmarshalCBOR(br); err != nil { + return xerrors.Errorf("unmarshaling t.PendingAmount: %w", err) + } + + } + // t.AddFundsMsg (cid.Cid) (struct) + case "AddFundsMsg": + + { + + pb, err := br.PeekByte() + if err != nil { + return err + } + if pb == cbg.CborNull[0] { + var nbuf [1]byte + if _, err := br.Read(nbuf[:]); err != nil { + return err + } + } else { + + c, err := cbg.ReadCid(br) + if err != nil { + return xerrors.Errorf("failed to read cid field t.AddFundsMsg: %w", err) + } + + t.AddFundsMsg = &c + } + + } + // t.CreateMsg (cid.Cid) (struct) + case "CreateMsg": + + { + + pb, err := br.PeekByte() + if err != nil { + return err + } + if pb == cbg.CborNull[0] { + var nbuf [1]byte + if _, err := 
br.Read(nbuf[:]); err != nil { + return err + } + } else { + + c, err := cbg.ReadCid(br) + if err != nil { + return xerrors.Errorf("failed to read cid field t.CreateMsg: %w", err) + } + + t.CreateMsg = &c + } + + } default: return fmt.Errorf("unknown struct field %d: '%s'", i, name) diff --git a/paychmgr/channellock.go b/paychmgr/channellock.go new file mode 100644 index 0000000000..0dc785ec0f --- /dev/null +++ b/paychmgr/channellock.go @@ -0,0 +1,33 @@ +package paychmgr + +import "sync" + +type rwlock interface { + RLock() + RUnlock() +} + +// channelLock manages locking for a specific channel. +// Some operations update the state of a single channel, and need to block +// other operations only on the same channel's state. +// Some operations update state that affects all channels, and need to block +// any operation against any channel. +type channelLock struct { + globalLock rwlock + chanLock sync.Mutex +} + +func (l *channelLock) Lock() { + // Wait for other operations by this channel to finish. + // Exclusive per-channel (no other ops by this channel allowed). + l.chanLock.Lock() + // Wait for operations affecting all channels to finish. + // Allows ops by other channels in parallel, but blocks all operations + // if global lock is taken exclusively (eg when adding a channel) + l.globalLock.RLock() +} + +func (l *channelLock) Unlock() { + l.globalLock.RUnlock() + l.chanLock.Unlock() +} diff --git a/paychmgr/manager.go b/paychmgr/manager.go new file mode 100644 index 0000000000..2643b0e1d1 --- /dev/null +++ b/paychmgr/manager.go @@ -0,0 +1,272 @@ +package paychmgr + +import ( + "context" + "sync" + + "golang.org/x/sync/errgroup" + + "github.com/filecoin-project/go-statemachine/fsm" + xerrors "golang.org/x/xerrors" + + "github.com/filecoin-project/lotus/api" + + "github.com/filecoin-project/specs-actors/actors/builtin/paych" + + "github.com/ipfs/go-cid" + logging "github.com/ipfs/go-log/v2" + "go.uber.org/fx" + + "github.com/filecoin-project/go-address" + + "github.com/filecoin-project/lotus/chain/stmgr" + "github.com/filecoin-project/lotus/chain/types" + "github.com/filecoin-project/lotus/node/impl/full" +) + +var log = logging.Logger("paych") + +type ManagerApi struct { + fx.In + + full.MpoolAPI + full.WalletAPI + full.StateAPI +} + +type StateManagerApi interface { + LoadActorState(ctx context.Context, a address.Address, out interface{}, ts *types.TipSet) (*types.Actor, error) + Call(ctx context.Context, msg *types.Message, ts *types.TipSet) (*api.InvocResult, error) +} + +type Manager struct { + // The Manager context is used to terminate wait operations on shutdown + ctx context.Context + shutdown context.CancelFunc + + store *Store + sm StateManagerApi + sa *stateAccessor + statemachines fsm.Group + pchapi paychApi + + lk sync.RWMutex + channels map[string]*channelAccessor + + mpool full.MpoolAPI + wallet full.WalletAPI + state full.StateAPI +} + +type paychAPIImpl struct { + full.MpoolAPI + full.StateAPI +} + +func NewManager(sm *stmgr.StateManager, pchstore *Store, api ManagerApi) *Manager { + return &Manager{ + store: pchstore, + sm: sm, + sa: &stateAccessor{sm: sm}, + channels: make(map[string]*channelAccessor), + // TODO: Is this the correct way to do this or can I do something different + // with dependency injection? 
+ pchapi: &paychAPIImpl{api.MpoolAPI, api.StateAPI}, + + mpool: api.MpoolAPI, + wallet: api.WalletAPI, + state: api.StateAPI, + } +} + +// newManager is used by the tests to supply mocks +func newManager(sm StateManagerApi, pchstore *Store, pchapi paychApi) (*Manager, error) { + pm := &Manager{ + store: pchstore, + sm: sm, + sa: &stateAccessor{sm: sm}, + channels: make(map[string]*channelAccessor), + pchapi: pchapi, + } + return pm, pm.Start(context.Background()) +} + +// HandleManager is called by dependency injection to set up hooks +func HandleManager(lc fx.Lifecycle, pm *Manager) { + lc.Append(fx.Hook{ + OnStart: func(ctx context.Context) error { + return pm.Start(ctx) + }, + OnStop: func(context.Context) error { + return pm.Stop() + }, + }) +} + +// Start checks the datastore to see if there are any channels that have +// outstanding add funds messages, and if so, waits on the messages. +// Outstanding messages can occur if an add funds message was sent +// and then lotus was shut down or crashed before the result was +// received. +func (pm *Manager) Start(ctx context.Context) error { + pm.ctx, pm.shutdown = context.WithCancel(ctx) + + cis, err := pm.store.WithPendingAddFunds() + if err != nil { + return err + } + + group := errgroup.Group{} + for _, ci := range cis { + if ci.CreateMsg != nil { + group.Go(func() error { + ca, err := pm.accessorByFromTo(ci.Control, ci.Target) + if err != nil { + return xerrors.Errorf("error initializing payment channel manager %s -> %s: %s", ci.Control, ci.Target, err) + } + go func() { + err = ca.waitForPaychCreateMsg(ci.Control, ci.Target, *ci.CreateMsg) + ca.msgWaitComplete(err) + }() + return nil + }) + } else if ci.AddFundsMsg != nil { + group.Go(func() error { + ca, err := pm.accessorByAddress(*ci.Channel) + if err != nil { + return xerrors.Errorf("error initializing payment channel manager %s: %s", ci.Channel, err) + } + go func() { + err = ca.waitForAddFundsMsg(ci.Control, ci.Target, *ci.AddFundsMsg) + ca.msgWaitComplete(err) + }() + return nil + }) + } + } + + return group.Wait() +} + +// Stop shuts down any processes used by the manager +func (pm *Manager) Stop() error { + pm.shutdown() + return nil +} + +func (pm *Manager) TrackOutboundChannel(ctx context.Context, ch address.Address) error { + return pm.trackChannel(ctx, ch, DirOutbound) +} + +func (pm *Manager) TrackInboundChannel(ctx context.Context, ch address.Address) error { + return pm.trackChannel(ctx, ch, DirInbound) +} + +func (pm *Manager) trackChannel(ctx context.Context, ch address.Address, dir uint64) error { + pm.lk.Lock() + defer pm.lk.Unlock() + + ci, err := pm.sa.loadStateChannelInfo(ctx, ch, dir) + if err != nil { + return err + } + + return pm.store.TrackChannel(ci) +} + +func (pm *Manager) GetPaych(ctx context.Context, from, to address.Address, ensureFree types.BigInt) (address.Address, cid.Cid, error) { + chanAccessor, err := pm.accessorByFromTo(from, to) + if err != nil { + return address.Undef, cid.Undef, err + } + + return chanAccessor.getPaych(ctx, from, to, ensureFree) +} + +func (pm *Manager) ListChannels() ([]address.Address, error) { + // Need to take an exclusive lock here so that channel operations can't run + // in parallel (see channelLock) + pm.lk.Lock() + defer pm.lk.Unlock() + + return pm.store.ListChannels() +} + +func (pm *Manager) GetChannelInfo(addr address.Address) (*ChannelInfo, error) { + ca, err := pm.accessorByAddress(addr) + if err != nil { + return nil, err + } + return ca.getChannelInfo(addr) +} + +// CheckVoucherValid checks if the given 
voucher is valid (is or could become spendable at some point) +func (pm *Manager) CheckVoucherValid(ctx context.Context, ch address.Address, sv *paych.SignedVoucher) error { + ca, err := pm.accessorByAddress(ch) + if err != nil { + return err + } + + _, err = ca.checkVoucherValid(ctx, ch, sv) + return err +} + +// CheckVoucherSpendable checks if the given voucher is currently spendable +func (pm *Manager) CheckVoucherSpendable(ctx context.Context, ch address.Address, sv *paych.SignedVoucher, secret []byte, proof []byte) (bool, error) { + ca, err := pm.accessorByAddress(ch) + if err != nil { + return false, err + } + + return ca.checkVoucherSpendable(ctx, ch, sv, secret, proof) +} + +func (pm *Manager) AddVoucher(ctx context.Context, ch address.Address, sv *paych.SignedVoucher, proof []byte, minDelta types.BigInt) (types.BigInt, error) { + ca, err := pm.accessorByAddress(ch) + if err != nil { + return types.NewInt(0), err + } + return ca.addVoucher(ctx, ch, sv, proof, minDelta) +} + +func (pm *Manager) AllocateLane(ch address.Address) (uint64, error) { + ca, err := pm.accessorByAddress(ch) + if err != nil { + return 0, err + } + return ca.allocateLane(ch) +} + +func (pm *Manager) ListVouchers(ctx context.Context, ch address.Address) ([]*VoucherInfo, error) { + ca, err := pm.accessorByAddress(ch) + if err != nil { + return nil, err + } + return ca.listVouchers(ctx, ch) +} + +func (pm *Manager) OutboundChanTo(from, to address.Address) (address.Address, error) { + pm.lk.Lock() + defer pm.lk.Unlock() + + ci, err := pm.store.OutboundByFromTo(from, to) + if err != nil { + return address.Undef, err + } + + // Channel create message has been sent but channel still hasn't been + // created on chain yet + if ci.Channel == nil { + return address.Undef, ErrChannelNotTracked + } + + return *ci.Channel, nil +} + +func (pm *Manager) NextNonceForLane(ctx context.Context, ch address.Address, lane uint64) (uint64, error) { + ca, err := pm.accessorByAddress(ch) + if err != nil { + return 0, err + } + return ca.nextNonceForLane(ctx, ch, lane) +} diff --git a/paychmgr/paych.go b/paychmgr/paych.go index 85db664cdf..a35981c980 100644 --- a/paychmgr/paych.go +++ b/paychmgr/paych.go @@ -5,118 +5,69 @@ import ( "context" "fmt" - "github.com/filecoin-project/specs-actors/actors/abi/big" - - "github.com/filecoin-project/lotus/api" - - cborutil "github.com/filecoin-project/go-cbor-util" - "github.com/filecoin-project/specs-actors/actors/builtin" - "github.com/filecoin-project/specs-actors/actors/builtin/account" - "github.com/filecoin-project/specs-actors/actors/builtin/paych" - "golang.org/x/xerrors" - - logging "github.com/ipfs/go-log/v2" - "go.uber.org/fx" - "github.com/filecoin-project/go-address" - + cborutil "github.com/filecoin-project/go-cbor-util" "github.com/filecoin-project/lotus/chain/actors" - "github.com/filecoin-project/lotus/chain/stmgr" "github.com/filecoin-project/lotus/chain/types" "github.com/filecoin-project/lotus/lib/sigs" - "github.com/filecoin-project/lotus/node/impl/full" + "github.com/filecoin-project/specs-actors/actors/abi/big" + "github.com/filecoin-project/specs-actors/actors/builtin" + "github.com/filecoin-project/specs-actors/actors/builtin/account" + "github.com/filecoin-project/specs-actors/actors/builtin/paych" + xerrors "golang.org/x/xerrors" ) -var log = logging.Logger("paych") - -type ManagerApi struct { - fx.In - - full.MpoolAPI - full.WalletAPI - full.StateAPI -} - -type StateManagerApi interface { - LoadActorState(ctx context.Context, a address.Address, out interface{}, ts 
*types.TipSet) (*types.Actor, error) - Call(ctx context.Context, msg *types.Message, ts *types.TipSet) (*api.InvocResult, error) +type channelAccessor struct { + waitCtx context.Context + sm StateManagerApi + sa *stateAccessor + api paychApi + store *Store + lk *channelLock + ensureFundsReqQueue []*ensureFundsReq } -type Manager struct { - store *Store - sm StateManagerApi - - mpool full.MpoolAPI - wallet full.WalletAPI - state full.StateAPI -} - -func NewManager(sm *stmgr.StateManager, pchstore *Store, api ManagerApi) *Manager { - return &Manager{ - store: pchstore, - sm: sm, - - mpool: api.MpoolAPI, - wallet: api.WalletAPI, - state: api.StateAPI, +func newChannelAccessor(pm *Manager) *channelAccessor { + return &channelAccessor{ + lk: &channelLock{globalLock: &pm.lk}, + sm: pm.sm, + sa: &stateAccessor{sm: pm.sm}, + api: pm.pchapi, + store: pm.store, + waitCtx: pm.ctx, } } -// Used by the tests to supply mocks -func newManager(sm StateManagerApi, pchstore *Store) *Manager { - return &Manager{ - store: pchstore, - sm: sm, - } -} - -func (pm *Manager) TrackOutboundChannel(ctx context.Context, ch address.Address) error { - return pm.trackChannel(ctx, ch, DirOutbound) -} - -func (pm *Manager) TrackInboundChannel(ctx context.Context, ch address.Address) error { - return pm.trackChannel(ctx, ch, DirInbound) -} +func (ca *channelAccessor) getChannelInfo(addr address.Address) (*ChannelInfo, error) { + ca.lk.Lock() + defer ca.lk.Unlock() -func (pm *Manager) trackChannel(ctx context.Context, ch address.Address, dir uint64) error { - ci, err := pm.loadStateChannelInfo(ctx, ch, dir) - if err != nil { - return err - } - - return pm.store.TrackChannel(ci) + return ca.store.ByAddress(addr) } -func (pm *Manager) ListChannels() ([]address.Address, error) { - return pm.store.ListChannels() -} +func (ca *channelAccessor) checkVoucherValid(ctx context.Context, ch address.Address, sv *paych.SignedVoucher) (map[uint64]*paych.LaneState, error) { + ca.lk.Lock() + defer ca.lk.Unlock() -func (pm *Manager) GetChannelInfo(addr address.Address) (*ChannelInfo, error) { - return pm.store.getChannelInfo(addr) + return ca.checkVoucherValidUnlocked(ctx, ch, sv) } -// CheckVoucherValid checks if the given voucher is valid (is or could become spendable at some point) -func (pm *Manager) CheckVoucherValid(ctx context.Context, ch address.Address, sv *paych.SignedVoucher) error { - _, err := pm.checkVoucherValid(ctx, ch, sv) - return err -} - -func (pm *Manager) checkVoucherValid(ctx context.Context, ch address.Address, sv *paych.SignedVoucher) (map[uint64]*paych.LaneState, error) { +func (ca *channelAccessor) checkVoucherValidUnlocked(ctx context.Context, ch address.Address, sv *paych.SignedVoucher) (map[uint64]*paych.LaneState, error) { if sv.ChannelAddr != ch { return nil, xerrors.Errorf("voucher ChannelAddr doesn't match channel address, got %s, expected %s", sv.ChannelAddr, ch) } - act, pchState, err := pm.loadPaychState(ctx, ch) + act, pchState, err := ca.sa.loadPaychState(ctx, ch) if err != nil { return nil, err } - var account account.State - _, err = pm.sm.LoadActorState(ctx, pchState.From, &account, nil) + var actState account.State + _, err = ca.sm.LoadActorState(ctx, pchState.From, &actState, nil) if err != nil { return nil, err } - from := account.Address + from := actState.Address // verify signature vb, err := sv.SigningBytes() @@ -132,7 +83,7 @@ func (pm *Manager) checkVoucherValid(ctx context.Context, ch address.Address, sv } // Check the voucher against the highest known voucher nonce / value - laneStates, 
err := pm.laneState(pchState, ch) + laneStates, err := ca.laneState(pchState, ch) if err != nil { return nil, err } @@ -164,7 +115,7 @@ func (pm *Manager) checkVoucherValid(ctx context.Context, ch address.Address, sv // lane 2: 2 // - // total: 7 - totalRedeemed, err := pm.totalRedeemedWithVoucher(laneStates, sv) + totalRedeemed, err := ca.totalRedeemedWithVoucher(laneStates, sv) if err != nil { return nil, err } @@ -183,15 +134,17 @@ func (pm *Manager) checkVoucherValid(ctx context.Context, ch address.Address, sv return laneStates, nil } -// CheckVoucherSpendable checks if the given voucher is currently spendable -func (pm *Manager) CheckVoucherSpendable(ctx context.Context, ch address.Address, sv *paych.SignedVoucher, secret []byte, proof []byte) (bool, error) { - recipient, err := pm.getPaychRecipient(ctx, ch) +func (ca *channelAccessor) checkVoucherSpendable(ctx context.Context, ch address.Address, sv *paych.SignedVoucher, secret []byte, proof []byte) (bool, error) { + ca.lk.Lock() + defer ca.lk.Unlock() + + recipient, err := ca.getPaychRecipient(ctx, ch) if err != nil { return false, err } if sv.Extra != nil && proof == nil { - known, err := pm.ListVouchers(ctx, ch) + known, err := ca.store.VouchersForPaych(ch) if err != nil { return false, err } @@ -221,7 +174,7 @@ func (pm *Manager) CheckVoucherSpendable(ctx context.Context, ch address.Address return false, err } - ret, err := pm.sm.Call(ctx, &types.Message{ + ret, err := ca.sm.Call(ctx, &types.Message{ From: recipient, To: ch, Method: builtin.MethodsPaych.UpdateChannelState, @@ -238,22 +191,22 @@ func (pm *Manager) CheckVoucherSpendable(ctx context.Context, ch address.Address return true, nil } -func (pm *Manager) getPaychRecipient(ctx context.Context, ch address.Address) (address.Address, error) { +func (ca *channelAccessor) getPaychRecipient(ctx context.Context, ch address.Address) (address.Address, error) { var state paych.State - if _, err := pm.sm.LoadActorState(ctx, ch, &state, nil); err != nil { + if _, err := ca.sm.LoadActorState(ctx, ch, &state, nil); err != nil { return address.Address{}, err } return state.To, nil } -func (pm *Manager) AddVoucher(ctx context.Context, ch address.Address, sv *paych.SignedVoucher, proof []byte, minDelta types.BigInt) (types.BigInt, error) { - pm.store.lk.Lock() - defer pm.store.lk.Unlock() +func (ca *channelAccessor) addVoucher(ctx context.Context, ch address.Address, sv *paych.SignedVoucher, proof []byte, minDelta types.BigInt) (types.BigInt, error) { + ca.lk.Lock() + defer ca.lk.Unlock() - ci, err := pm.store.getChannelInfo(ch) + ci, err := ca.store.ByAddress(ch) if err != nil { - return types.NewInt(0), err + return types.BigInt{}, err } // Check if the voucher has already been added @@ -275,7 +228,7 @@ func (pm *Manager) AddVoucher(ctx context.Context, ch address.Address, sv *paych Proof: proof, } - return types.NewInt(0), pm.store.putChannelInfo(ci) + return types.NewInt(0), ca.store.putChannelInfo(ci) } // Otherwise just ignore the duplicate voucher @@ -284,7 +237,7 @@ func (pm *Manager) AddVoucher(ctx context.Context, ch address.Address, sv *paych } // Check voucher validity - laneStates, err := pm.checkVoucherValid(ctx, ch, sv) + laneStates, err := ca.checkVoucherValidUnlocked(ctx, ch, sv) if err != nil { return types.NewInt(0), err } @@ -311,35 +264,32 @@ func (pm *Manager) AddVoucher(ctx context.Context, ch address.Address, sv *paych ci.NextLane = sv.Lane + 1 } - return delta, pm.store.putChannelInfo(ci) + return delta, ca.store.putChannelInfo(ci) } -func (pm *Manager) 
AllocateLane(ch address.Address) (uint64, error) { +func (ca *channelAccessor) allocateLane(ch address.Address) (uint64, error) { + ca.lk.Lock() + defer ca.lk.Unlock() + // TODO: should this take into account lane state? - return pm.store.AllocateLane(ch) + return ca.store.AllocateLane(ch) } -func (pm *Manager) ListVouchers(ctx context.Context, ch address.Address) ([]*VoucherInfo, error) { +func (ca *channelAccessor) listVouchers(ctx context.Context, ch address.Address) ([]*VoucherInfo, error) { + ca.lk.Lock() + defer ca.lk.Unlock() + // TODO: just having a passthrough method like this feels odd. Seems like // there should be some filtering we're doing here - return pm.store.VouchersForPaych(ch) + return ca.store.VouchersForPaych(ch) } -func (pm *Manager) OutboundChanTo(from, to address.Address) (address.Address, error) { - pm.store.lk.Lock() - defer pm.store.lk.Unlock() +func (ca *channelAccessor) nextNonceForLane(ctx context.Context, ch address.Address, lane uint64) (uint64, error) { + ca.lk.Lock() + defer ca.lk.Unlock() - return pm.store.findChan(func(ci *ChannelInfo) bool { - if ci.Direction != DirOutbound { - return false - } - return ci.Control == from && ci.Target == to - }) -} - -func (pm *Manager) NextNonceForLane(ctx context.Context, ch address.Address, lane uint64) (uint64, error) { // TODO: should this take into account lane state? - vouchers, err := pm.store.VouchersForPaych(ch) + vouchers, err := ca.store.VouchersForPaych(ch) if err != nil { return 0, err } @@ -355,3 +305,80 @@ func (pm *Manager) NextNonceForLane(ctx context.Context, ch address.Address, lan return maxnonce + 1, nil } + +// laneState gets the LaneStates from chain, then applies all vouchers in +// the data store over the chain state +func (ca *channelAccessor) laneState(state *paych.State, ch address.Address) (map[uint64]*paych.LaneState, error) { + // TODO: we probably want to call UpdateChannelState with all vouchers to be fully correct + // (but technically dont't need to) + laneStates := make(map[uint64]*paych.LaneState, len(state.LaneStates)) + + // Get the lane state from the chain + for _, laneState := range state.LaneStates { + laneStates[laneState.ID] = laneState + } + + // Apply locally stored vouchers + vouchers, err := ca.store.VouchersForPaych(ch) + if err != nil && err != ErrChannelNotTracked { + return nil, err + } + + for _, v := range vouchers { + for range v.Voucher.Merges { + return nil, xerrors.Errorf("paych merges not handled yet") + } + + // If there's a voucher for a lane that isn't in chain state just + // create it + ls, ok := laneStates[v.Voucher.Lane] + if !ok { + ls = &paych.LaneState{ + ID: v.Voucher.Lane, + Redeemed: types.NewInt(0), + Nonce: 0, + } + laneStates[v.Voucher.Lane] = ls + } + + if v.Voucher.Nonce < ls.Nonce { + continue + } + + ls.Nonce = v.Voucher.Nonce + ls.Redeemed = v.Voucher.Amount + } + + return laneStates, nil +} + +// Get the total redeemed amount across all lanes, after applying the voucher +func (ca *channelAccessor) totalRedeemedWithVoucher(laneStates map[uint64]*paych.LaneState, sv *paych.SignedVoucher) (big.Int, error) { + // TODO: merges + if len(sv.Merges) != 0 { + return big.Int{}, xerrors.Errorf("dont currently support paych lane merges") + } + + total := big.NewInt(0) + for _, ls := range laneStates { + total = big.Add(total, ls.Redeemed) + } + + lane, ok := laneStates[sv.Lane] + if ok { + // If the voucher is for an existing lane, and the voucher nonce + // and is higher than the lane nonce + if sv.Nonce > lane.Nonce { + // Add the delta between the 
redeemed amount and the voucher + // amount to the total + delta := big.Sub(sv.Amount, lane.Redeemed) + total = big.Add(total, delta) + } + } else { + // If the voucher is *not* for an existing lane, just add its + // value (implicitly a new lane will be created for the voucher) + total = big.Add(total, sv.Amount) + } + + return total, nil +} diff --git a/paychmgr/paych_test.go b/paychmgr/paych_test.go index 2cbea5cb54..9c28fdcb8b 100644 --- a/paychmgr/paych_test.go +++ b/paychmgr/paych_test.go @@ -104,13 +104,15 @@ func TestPaychOutbound(t *testing.T) { LaneStates: []*paych.LaneState{}, }) - mgr := newManager(sm, store) - err := mgr.TrackOutboundChannel(ctx, ch) + mgr, err := newManager(sm, store, nil) + require.NoError(t, err) + + err = mgr.TrackOutboundChannel(ctx, ch) require.NoError(t, err) ci, err := mgr.GetChannelInfo(ch) require.NoError(t, err) - require.Equal(t, ci.Channel, ch) + require.Equal(t, *ci.Channel, ch) require.Equal(t, ci.Control, from) require.Equal(t, ci.Target, to) require.EqualValues(t, ci.Direction, DirOutbound) @@ -140,13 +142,15 @@ func TestPaychInbound(t *testing.T) { LaneStates: []*paych.LaneState{}, }) - mgr := newManager(sm, store) - err := mgr.TrackInboundChannel(ctx, ch) + mgr, err := newManager(sm, store, nil) + require.NoError(t, err) + + err = mgr.TrackInboundChannel(ctx, ch) require.NoError(t, err) ci, err := mgr.GetChannelInfo(ch) require.NoError(t, err) - require.Equal(t, ci.Channel, ch) + require.Equal(t, *ci.Channel, ch) require.Equal(t, ci.Control, to) require.Equal(t, ci.Target, from) require.EqualValues(t, ci.Direction, DirInbound) @@ -321,8 +325,10 @@ func TestCheckVoucherValid(t *testing.T) { LaneStates: tcase.laneStates, }) - mgr := newManager(sm, store) - err := mgr.TrackInboundChannel(ctx, ch) + mgr, err := newManager(sm, store, nil) + require.NoError(t, err) + + err = mgr.TrackInboundChannel(ctx, ch) require.NoError(t, err) sv := testCreateVoucher(t, ch, tcase.voucherLane, tcase.voucherNonce, tcase.voucherAmount, tcase.key) @@ -382,8 +388,10 @@ func TestCheckVoucherValidCountingAllLanes(t *testing.T) { LaneStates: laneStates, }) - mgr := newManager(sm, store) - err := mgr.TrackInboundChannel(ctx, ch) + mgr, err := newManager(sm, store, nil) + require.NoError(t, err) + + err = mgr.TrackInboundChannel(ctx, ch) require.NoError(t, err) // @@ -690,8 +698,10 @@ func testSetupMgrWithChannel(ctx context.Context, t *testing.T) (*Manager, addre }) store := NewStore(ds_sync.MutexWrap(ds.NewMapDatastore())) - mgr := newManager(sm, store) - err := mgr.TrackInboundChannel(ctx, ch) + mgr, err := newManager(sm, store, nil) + require.NoError(t, err) + + err = mgr.TrackInboundChannel(ctx, ch) require.NoError(t, err) return mgr, ch, fromKeyPrivate } diff --git a/paychmgr/paychget_test.go b/paychmgr/paychget_test.go new file mode 100644 index 0000000000..48e2e066e7 --- /dev/null +++ b/paychmgr/paychget_test.go @@ -0,0 +1,629 @@ +package paychmgr + +import ( + "context" + "sync" + "testing" + "time" + + cborrpc "github.com/filecoin-project/go-cbor-util" + + init_ "github.com/filecoin-project/specs-actors/actors/builtin/init" + + "github.com/filecoin-project/specs-actors/actors/builtin" + + "github.com/filecoin-project/lotus/api" + "github.com/filecoin-project/lotus/chain/types" + + "github.com/filecoin-project/go-address" + + "github.com/filecoin-project/specs-actors/actors/abi/big" + tutils "github.com/filecoin-project/specs-actors/support/testing" + "github.com/ipfs/go-cid" + ds "github.com/ipfs/go-datastore" + ds_sync "github.com/ipfs/go-datastore/sync" + + 
"github.com/stretchr/testify/require" +) + +type waitingCall struct { + response chan types.MessageReceipt +} + +type mockPaychAPI struct { + lk sync.Mutex + messages map[cid.Cid]*types.SignedMessage + waitingCalls []*waitingCall +} + +func newMockPaychAPI() *mockPaychAPI { + return &mockPaychAPI{ + messages: make(map[cid.Cid]*types.SignedMessage), + } +} + +func (pchapi *mockPaychAPI) StateWaitMsg(ctx context.Context, msg cid.Cid, confidence uint64) (*api.MsgLookup, error) { + response := make(chan types.MessageReceipt) + + pchapi.lk.Lock() + pchapi.waitingCalls = append(pchapi.waitingCalls, &waitingCall{response: response}) + pchapi.lk.Unlock() + + receipt := <-response + + return &api.MsgLookup{Receipt: receipt}, nil +} + +func (pchapi *mockPaychAPI) MpoolPushMessage(ctx context.Context, msg *types.Message) (*types.SignedMessage, error) { + pchapi.lk.Lock() + defer pchapi.lk.Unlock() + + smsg := &types.SignedMessage{Message: *msg} + pchapi.messages[smsg.Cid()] = smsg + return smsg, nil +} + +func (pchapi *mockPaychAPI) pushedMessages(c cid.Cid) *types.SignedMessage { + pchapi.lk.Lock() + defer pchapi.lk.Unlock() + + return pchapi.messages[c] +} + +func (pchapi *mockPaychAPI) pushedMessageCount() int { + pchapi.lk.Lock() + defer pchapi.lk.Unlock() + + return len(pchapi.messages) +} + +func (pchapi *mockPaychAPI) finishWaitingCalls(receipt types.MessageReceipt) { + pchapi.lk.Lock() + defer pchapi.lk.Unlock() + + for _, call := range pchapi.waitingCalls { + call.response <- receipt + } + pchapi.waitingCalls = nil +} + +func (pchapi *mockPaychAPI) close() { + pchapi.finishWaitingCalls(types.MessageReceipt{ + ExitCode: 0, + Return: []byte{}, + }) +} + +func TestPaychGetCreateChannelMsg(t *testing.T) { + ctx := context.Background() + store := NewStore(ds_sync.MutexWrap(ds.NewMapDatastore())) + + from := tutils.NewIDAddr(t, 101) + to := tutils.NewIDAddr(t, 102) + + sm := newMockStateManager() + pchapi := newMockPaychAPI() + defer pchapi.close() + + mgr, err := newManager(sm, store, pchapi) + require.NoError(t, err) + + ensureFree := big.NewInt(10) + ch, mcid, err := mgr.GetPaych(ctx, from, to, ensureFree) + require.NoError(t, err) + require.Equal(t, address.Undef, ch) + + pushedMsg := pchapi.pushedMessages(mcid) + require.Equal(t, from, pushedMsg.Message.From) + require.Equal(t, builtin.InitActorAddr, pushedMsg.Message.To) + require.Equal(t, ensureFree, pushedMsg.Message.Value) +} + +func TestPaychGetAddFundsSameValue(t *testing.T) { + ctx := context.Background() + store := NewStore(ds_sync.MutexWrap(ds.NewMapDatastore())) + + from := tutils.NewIDAddr(t, 101) + to := tutils.NewIDAddr(t, 102) + + sm := newMockStateManager() + pchapi := newMockPaychAPI() + defer pchapi.close() + + mgr, err := newManager(sm, store, pchapi) + require.NoError(t, err) + + // Send create message for a channel with value 10 + ensureFree := big.NewInt(10) + _, mcid, err := mgr.GetPaych(ctx, from, to, ensureFree) + require.NoError(t, err) + + // Requesting the same or a lesser amount should just return the same message + // CID as for the create channel message + ensureFree2 := big.NewInt(10) + ch2, mcid2, err := mgr.GetPaych(ctx, from, to, ensureFree2) + require.NoError(t, err) + require.Equal(t, address.Undef, ch2) + require.Equal(t, mcid, mcid2) + + // Should not have sent a second message (because the amount was already + // covered by the create channel message) + msgCount := pchapi.pushedMessageCount() + require.Equal(t, 1, msgCount) +} + +func TestPaychGetCreateChannelThenAddFunds(t *testing.T) { + ctx := 
context.Background() + store := NewStore(ds_sync.MutexWrap(ds.NewMapDatastore())) + + ch := tutils.NewIDAddr(t, 100) + from := tutils.NewIDAddr(t, 101) + to := tutils.NewIDAddr(t, 102) + + sm := newMockStateManager() + pchapi := newMockPaychAPI() + defer pchapi.close() + + mgr, err := newManager(sm, store, pchapi) + require.NoError(t, err) + + // Send create message for a channel with value 10 + ensureFree := big.NewInt(10) + _, createMsgCid, err := mgr.GetPaych(ctx, from, to, ensureFree) + require.NoError(t, err) + + // Should have no channels yet (message sent but channel not created) + cis, err := mgr.ListChannels() + require.NoError(t, err) + require.Len(t, cis, 0) + + // 1. Set up create channel response (sent in response to WaitForMsg()) + createChannelRet := init_.ExecReturn{ + IDAddress: ch, + RobustAddress: ch, + } + createChannelRetBytes, err := cborrpc.Dump(&createChannelRet) + require.NoError(t, err) + createChannelResponse := types.MessageReceipt{ + ExitCode: 0, + Return: createChannelRetBytes, + } + + done := make(chan struct{}) + go func() { + defer close(done) + + // 2. Request add funds - should block until create channel has completed + ensureFree2 := big.NewInt(15) + ch2, addFundsMsgCid, err := mgr.GetPaych(ctx, from, to, ensureFree2) + + // 4. This GetPaych should return after create channel from first + // GetPaych completes + require.NoError(t, err) + + // Expect the channel to have been created + require.Equal(t, ch, ch2) + // Expect add funds message CID to be different to create message cid + require.NotEqual(t, createMsgCid, addFundsMsgCid) + + // Should have one channel, whose address is the channel that was created + cis, err := mgr.ListChannels() + require.NoError(t, err) + require.Len(t, cis, 1) + require.Equal(t, ch, cis[0]) + + // Amount should be amount sent to first GetPaych (to create + // channel). + // PendingAmount should be amount sent in second GetPaych + // (second GetPaych triggered add funds, which has not yet been confirmed) + ci, err := mgr.GetChannelInfo(ch) + require.NoError(t, err) + require.EqualValues(t, 10, ci.Amount.Int64()) + require.EqualValues(t, 15, ci.PendingAmount.Int64()) + require.Nil(t, ci.CreateMsg) + + // Trigger add funds confirmation + pchapi.finishWaitingCalls(types.MessageReceipt{ExitCode: 0}) + + time.Sleep(time.Millisecond * 10) + + // Should still have one channel + cis, err = mgr.ListChannels() + require.NoError(t, err) + require.Len(t, cis, 1) + require.Equal(t, ch, cis[0]) + + // Channel amount should include last amount sent to GetPaych + ci, err = mgr.GetChannelInfo(ch) + require.NoError(t, err) + require.EqualValues(t, 15, ci.Amount.Int64()) + require.EqualValues(t, 0, ci.PendingAmount.Int64()) + require.Nil(t, ci.AddFundsMsg) + }() + + // Give the go routine above a moment to run + time.Sleep(time.Millisecond * 10) + + // 3. Send create channel response + pchapi.finishWaitingCalls(createChannelResponse) + + <-done +} + +func TestPaychGetCreateChannelWithErrorThenAddFunds(t *testing.T) { + ctx := context.Background() + store := NewStore(ds_sync.MutexWrap(ds.NewMapDatastore())) + + from := tutils.NewIDAddr(t, 101) + to := tutils.NewIDAddr(t, 102) + + sm := newMockStateManager() + pchapi := newMockPaychAPI() + defer pchapi.close() + + mgr, err := newManager(sm, store, pchapi) + require.NoError(t, err) + + // Send create message for a channel + ensureFree := big.NewInt(10) + _, _, err = mgr.GetPaych(ctx, from, to, ensureFree) + require.NoError(t, err) + + // 1. 
Set up create channel response (sent in response to WaitForMsg()) + // This response indicates an error. + createChannelResponse := types.MessageReceipt{ + ExitCode: 1, // error + Return: []byte{}, + } + + done := make(chan struct{}) + go func() { + defer close(done) + + // 2. Request add funds - should block until create channel has completed + ensureFree2 := big.NewInt(15) + _, _, err := mgr.GetPaych(ctx, from, to, ensureFree2) + + // 4. This GetPaych should complete after create channel from first + // GetPaych completes, and it should error out because the create + // channel was unsuccessful + require.Error(t, err) + }() + + // Give the go routine above a moment to run + time.Sleep(time.Millisecond * 10) + + // 3. Send create channel response + pchapi.finishWaitingCalls(createChannelResponse) + + <-done +} + +func TestPaychGetRecoverAfterError(t *testing.T) { + ctx := context.Background() + store := NewStore(ds_sync.MutexWrap(ds.NewMapDatastore())) + + ch := tutils.NewIDAddr(t, 100) + from := tutils.NewIDAddr(t, 101) + to := tutils.NewIDAddr(t, 102) + + sm := newMockStateManager() + pchapi := newMockPaychAPI() + defer pchapi.close() + + mgr, err := newManager(sm, store, pchapi) + require.NoError(t, err) + + // Send create message for a channel + ensureFree := big.NewInt(10) + _, _, err = mgr.GetPaych(ctx, from, to, ensureFree) + require.NoError(t, err) + + time.Sleep(time.Millisecond * 10) + + // Send error create channel response + pchapi.finishWaitingCalls(types.MessageReceipt{ + ExitCode: 1, // error + Return: []byte{}, + }) + + time.Sleep(time.Millisecond * 10) + + // Send create message for a channel again + ensureFree2 := big.NewInt(7) + _, _, err = mgr.GetPaych(ctx, from, to, ensureFree2) + require.NoError(t, err) + + time.Sleep(time.Millisecond * 10) + + // Send success create channel response + createChannelRet := init_.ExecReturn{ + IDAddress: ch, + RobustAddress: ch, + } + createChannelRetBytes, err := cborrpc.Dump(&createChannelRet) + require.NoError(t, err) + createChannelResponse := types.MessageReceipt{ + ExitCode: 0, + Return: createChannelRetBytes, + } + pchapi.finishWaitingCalls(createChannelResponse) + + time.Sleep(time.Millisecond * 10) + + // Should have one channel, whose address is the channel that was created + cis, err := mgr.ListChannels() + require.NoError(t, err) + require.Len(t, cis, 1) + require.Equal(t, ch, cis[0]) + + ci, err := mgr.GetChannelInfo(ch) + require.NoError(t, err) + require.Equal(t, ensureFree2, ci.Amount) + require.EqualValues(t, 0, ci.PendingAmount.Int64()) + require.Nil(t, ci.CreateMsg) +} + +func TestPaychGetRecoverAfterAddFundsError(t *testing.T) { + ctx := context.Background() + store := NewStore(ds_sync.MutexWrap(ds.NewMapDatastore())) + + ch := tutils.NewIDAddr(t, 100) + from := tutils.NewIDAddr(t, 101) + to := tutils.NewIDAddr(t, 102) + + sm := newMockStateManager() + pchapi := newMockPaychAPI() + defer pchapi.close() + + mgr, err := newManager(sm, store, pchapi) + require.NoError(t, err) + + // Send create message for a channel + ensureFree := big.NewInt(10) + _, _, err = mgr.GetPaych(ctx, from, to, ensureFree) + require.NoError(t, err) + + time.Sleep(time.Millisecond * 10) + + // Send success create channel response + createChannelRet := init_.ExecReturn{ + IDAddress: ch, + RobustAddress: ch, + } + createChannelRetBytes, err := cborrpc.Dump(&createChannelRet) + require.NoError(t, err) + createChannelResponse := types.MessageReceipt{ + ExitCode: 0, + Return: createChannelRetBytes, + } + 
pchapi.finishWaitingCalls(createChannelResponse) + + time.Sleep(time.Millisecond * 10) + + // Send add funds message for channel + ensureFree2 := big.NewInt(15) + _, _, err = mgr.GetPaych(ctx, from, to, ensureFree2) + require.NoError(t, err) + + time.Sleep(time.Millisecond * 10) + + // Send error add funds response + pchapi.finishWaitingCalls(types.MessageReceipt{ + ExitCode: 1, // error + Return: []byte{}, + }) + + time.Sleep(time.Millisecond * 10) + + // Should have one channel, whose address is the channel that was created + cis, err := mgr.ListChannels() + require.NoError(t, err) + require.Len(t, cis, 1) + require.Equal(t, ch, cis[0]) + + ci, err := mgr.GetChannelInfo(ch) + require.NoError(t, err) + require.Equal(t, ensureFree, ci.Amount) + require.EqualValues(t, 0, ci.PendingAmount.Int64()) + require.Nil(t, ci.CreateMsg) + require.Nil(t, ci.AddFundsMsg) + + // Send add funds message for channel again + ensureFree3 := big.NewInt(12) + _, _, err = mgr.GetPaych(ctx, from, to, ensureFree3) + require.NoError(t, err) + + time.Sleep(time.Millisecond * 10) + + // Send success add funds response + pchapi.finishWaitingCalls(types.MessageReceipt{ + ExitCode: 0, + Return: []byte{}, + }) + + time.Sleep(time.Millisecond * 10) + + // Should have one channel, whose address is the channel that was created + cis, err = mgr.ListChannels() + require.NoError(t, err) + require.Len(t, cis, 1) + require.Equal(t, ch, cis[0]) + + // Amount should be equal to ensure free for successful add funds msg + ci, err = mgr.GetChannelInfo(ch) + require.NoError(t, err) + require.Equal(t, ensureFree3, ci.Amount) + require.EqualValues(t, 0, ci.PendingAmount.Int64()) + require.Nil(t, ci.CreateMsg) + require.Nil(t, ci.AddFundsMsg) +} + +func TestPaychGetRestartAfterCreateChannelMsg(t *testing.T) { + ctx := context.Background() + store := NewStore(ds_sync.MutexWrap(ds.NewMapDatastore())) + + ch := tutils.NewIDAddr(t, 100) + from := tutils.NewIDAddr(t, 101) + to := tutils.NewIDAddr(t, 102) + + sm := newMockStateManager() + pchapi := newMockPaychAPI() + + mgr, err := newManager(sm, store, pchapi) + require.NoError(t, err) + + // Send create message for a channel with value 10 + ensureFree := big.NewInt(10) + _, createMsgCid, err := mgr.GetPaych(ctx, from, to, ensureFree) + require.NoError(t, err) + + // Simulate shutting down lotus + pchapi.close() + + // Create a new manager with the same datastore + sm2 := newMockStateManager() + pchapi2 := newMockPaychAPI() + defer pchapi2.close() + + mgr2, err := newManager(sm2, store, pchapi2) + require.NoError(t, err) + + // Should have no channels yet (message sent but channel not created) + cis, err := mgr2.ListChannels() + require.NoError(t, err) + require.Len(t, cis, 0) + + // 1. Set up create channel response (sent in response to WaitForMsg()) + createChannelRet := init_.ExecReturn{ + IDAddress: ch, + RobustAddress: ch, + } + createChannelRetBytes, err := cborrpc.Dump(&createChannelRet) + require.NoError(t, err) + createChannelResponse := types.MessageReceipt{ + ExitCode: 0, + Return: createChannelRetBytes, + } + + done := make(chan struct{}) + go func() { + defer close(done) + + // 2. Request add funds - should block until create channel has completed + ensureFree2 := big.NewInt(15) + ch2, addFundsMsgCid, err := mgr2.GetPaych(ctx, from, to, ensureFree2) + + // 4. 
This GetPaych should return after create channel from first + // GetPaych completes + require.NoError(t, err) + + // Expect the channel to have been created + require.Equal(t, ch, ch2) + // Expect add funds message CID to be different to create message cid + require.NotEqual(t, createMsgCid, addFundsMsgCid) + + // Should have one channel, whose address is the channel that was created + cis, err := mgr2.ListChannels() + require.NoError(t, err) + require.Len(t, cis, 1) + require.Equal(t, ch, cis[0]) + + // Amount should be amount sent to first GetPaych (to create + // channel). + // PendingAmount should be amount sent in second GetPaych + // (second GetPaych triggered add funds, which has not yet been confirmed) + ci, err := mgr2.GetChannelInfo(ch) + require.NoError(t, err) + require.EqualValues(t, 10, ci.Amount.Int64()) + require.EqualValues(t, 15, ci.PendingAmount.Int64()) + require.Nil(t, ci.CreateMsg) + }() + + // Give the go routine above a moment to run + time.Sleep(time.Millisecond * 10) + + // 3. Send create channel response + pchapi2.finishWaitingCalls(createChannelResponse) + + <-done +} + +func TestPaychGetRestartAfterAddFundsMsg(t *testing.T) { + ctx := context.Background() + store := NewStore(ds_sync.MutexWrap(ds.NewMapDatastore())) + + ch := tutils.NewIDAddr(t, 100) + from := tutils.NewIDAddr(t, 101) + to := tutils.NewIDAddr(t, 102) + + sm := newMockStateManager() + pchapi := newMockPaychAPI() + + mgr, err := newManager(sm, store, pchapi) + require.NoError(t, err) + + // Send create message for a channel + ensureFree := big.NewInt(10) + _, _, err = mgr.GetPaych(ctx, from, to, ensureFree) + require.NoError(t, err) + + time.Sleep(time.Millisecond * 10) + + // Send success create channel response + createChannelRet := init_.ExecReturn{ + IDAddress: ch, + RobustAddress: ch, + } + createChannelRetBytes, err := cborrpc.Dump(&createChannelRet) + require.NoError(t, err) + createChannelResponse := types.MessageReceipt{ + ExitCode: 0, + Return: createChannelRetBytes, + } + pchapi.finishWaitingCalls(createChannelResponse) + + time.Sleep(time.Millisecond * 10) + + // Send add funds message for channel + ensureFree2 := big.NewInt(15) + _, _, err = mgr.GetPaych(ctx, from, to, ensureFree2) + require.NoError(t, err) + + time.Sleep(time.Millisecond * 10) + + // Simulate shutting down lotus + pchapi.close() + + // Create a new manager with the same datastore + sm2 := newMockStateManager() + pchapi2 := newMockPaychAPI() + defer pchapi2.close() + + time.Sleep(time.Millisecond * 10) + + mgr2, err := newManager(sm2, store, pchapi2) + require.NoError(t, err) + + // Send success add funds response + pchapi2.finishWaitingCalls(types.MessageReceipt{ + ExitCode: 0, + Return: []byte{}, + }) + + time.Sleep(time.Millisecond * 10) + + // Should have one channel, whose address is the channel that was created + cis, err := mgr2.ListChannels() + require.NoError(t, err) + require.Len(t, cis, 1) + require.Equal(t, ch, cis[0]) + + // Amount should be equal to ensure free for successful add funds msg + ci, err := mgr2.GetChannelInfo(ch) + require.NoError(t, err) + require.Equal(t, ensureFree2, ci.Amount) + require.EqualValues(t, 0, ci.PendingAmount.Int64()) + require.Nil(t, ci.CreateMsg) + require.Nil(t, ci.AddFundsMsg) +} diff --git a/paychmgr/simple.go b/paychmgr/simple.go index 0d0075d626..c5367d06fd 100644 --- a/paychmgr/simple.go +++ b/paychmgr/simple.go @@ -4,6 +4,10 @@ import ( "bytes" "context" + "github.com/filecoin-project/lotus/api" + + "github.com/filecoin-project/specs-actors/actors/abi/big" + 
"github.com/filecoin-project/specs-actors/actors/builtin" init_ "github.com/filecoin-project/specs-actors/actors/builtin/init" "github.com/filecoin-project/specs-actors/actors/builtin/paych" @@ -17,7 +21,210 @@ import ( "github.com/filecoin-project/lotus/chain/types" ) -func (pm *Manager) createPaych(ctx context.Context, from, to address.Address, amt types.BigInt) (cid.Cid, error) { +// TODO: +// +// Handle settle event +// - Mark channel as settled (in store) +// - Any subsequent add funds should go to a new channel +// - Tests +// + +type paychApi interface { + StateWaitMsg(ctx context.Context, msg cid.Cid, confidence uint64) (*api.MsgLookup, error) + MpoolPushMessage(ctx context.Context, msg *types.Message) (*types.SignedMessage, error) +} + +// paychFundsRes is the response to a create channel or add funds request +type paychFundsRes struct { + channel address.Address + mcid cid.Cid + err error +} + +// ensureFundsReq is a request to create a channel or add funds to a channel +type ensureFundsReq struct { + ctx context.Context + from address.Address + to address.Address + ensureFree types.BigInt + onComplete func(res *paychFundsRes) +} + +// getPaych ensures that a channel exists between the from and to addresses +// with the given amount of funds. +// If the channel does not exist a create channel message is sent and the message CID is returned. +// If the channel does exist and the funds are sufficient, the channel address is returned. +// If the channel does exist and the funds are not sufficient, an add funds message is sent and +// the message CID is returned. +// If there is an in progress operation (create channel / add funds) and +// - the amount in the in-progress operation would cover the requested amount, +// the message CID of the operation is returned. +// - the amount in the in-progress operation will not cover the requested amount, +// getPaych blocks until the previous operation completes, then returns the +// CID of the new add funds message. +// If an operation returns an error, all subsequent waiting operations complete with the error. +func (ca *channelAccessor) getPaych(ctx context.Context, from, to address.Address, ensureFree types.BigInt) (address.Address, cid.Cid, error) { + // Add the request to ensure funds to a queue and wait for the result + promise := ca.enqueue(&ensureFundsReq{ctx: ctx, from: from, to: to, ensureFree: ensureFree}) + select { + case res := <-promise: + return res.channel, res.mcid, res.err + case <-ctx.Done(): + return address.Undef, cid.Undef, ctx.Err() + } +} + +// Queue up an ensure funds operation +func (ca *channelAccessor) enqueue(task *ensureFundsReq) chan *paychFundsRes { + promise := make(chan *paychFundsRes) + task.onComplete = func(res *paychFundsRes) { + select { + case <-task.ctx.Done(): + case promise <- res: + } + } + + ca.lk.Lock() + defer ca.lk.Unlock() + + ca.ensureFundsReqQueue = append(ca.ensureFundsReqQueue, task) + go ca.processNextQueueItem() + + return promise +} + +// Run the operation at the head of the queue +func (ca *channelAccessor) processNextQueueItem() { + ca.lk.Lock() + defer ca.lk.Unlock() + + if len(ca.ensureFundsReqQueue) == 0 { + return + } + + head := ca.ensureFundsReqQueue[0] + res := ca.processTask(head.ctx, head.from, head.to, head.ensureFree) + + // If the task is waiting on an external event (eg something to appear on + // chain) it will return nil + if res == nil { + // Stop processing the ensureFundsReqQueue and wait. 
When the event occurs it will + // call processNextQueueItem() again + return + } + + // If there was an error, invoke the callback for the task and all + // subsequent ensureFundsReqQueue tasks with an error + if res.err != nil && res.err != context.Canceled { + for _, task := range ca.ensureFundsReqQueue { + task.onComplete(&paychFundsRes{err: res.err}) + } + ca.ensureFundsReqQueue = nil + return + } + + // The task has finished processing so clean it up + ca.ensureFundsReqQueue[0] = nil // allow GC of element + ca.ensureFundsReqQueue = ca.ensureFundsReqQueue[1:] + + // Task completed so callback with its results + head.onComplete(res) + + // Process the next task + if len(ca.ensureFundsReqQueue) > 0 { + go ca.processNextQueueItem() + } +} + +// msgWaitComplete is called when the message for a previous task is confirmed +// or there is an error. In the case of an error, all subsequent tasks in the +// queue are completed with the error. +func (ca *channelAccessor) msgWaitComplete(err error) { + ca.lk.Lock() + defer ca.lk.Unlock() + + if len(ca.ensureFundsReqQueue) == 0 { + return + } + + // If there was an error, complete all subsequent ensureFundsReqQueue tasks with an error + if err != nil { + for _, task := range ca.ensureFundsReqQueue { + task.onComplete(&paychFundsRes{err: err}) + } + ca.ensureFundsReqQueue = nil + return + } + + go ca.processNextQueueItem() +} + +// processTask checks the state of the channel and takes appropriate action +// (see description of getPaych). +// Note that processTask may be called repeatedly in the same state, and should +// return nil if there is no state change to be made (eg when waiting for a +// message to be confirmed on chain) +func (ca *channelAccessor) processTask(ctx context.Context, from address.Address, to address.Address, ensureFree types.BigInt) *paychFundsRes { + // Note: It's ok if we get ErrChannelNotTracked. It just means we need to + // create a channel. + channelInfo, err := ca.store.OutboundByFromTo(from, to) + if err != nil && err != ErrChannelNotTracked { + return &paychFundsRes{err: err} + } + + // If a channel has not yet been created, create one. + // Note that if the previous attempt to create the channel failed because of a VM error + // (eg not enough gas), both channelInfo.Channel and channelInfo.CreateMsg will be nil. 
+	if channelInfo == nil || channelInfo.Channel == nil && channelInfo.CreateMsg == nil {
+		mcid, err := ca.createPaych(ctx, from, to, ensureFree)
+		if err != nil {
+			return &paychFundsRes{err: err}
+		}
+
+		return &paychFundsRes{mcid: mcid}
+	}
+
+	// If the create channel message has been sent but the channel hasn't
+	// been created on chain yet
+	if channelInfo.CreateMsg != nil {
+		// If the amount in the pending create message will cover the
+		// requested amount, there's no need to add more funds so just
+		// return the channel create message CID
+		if channelInfo.PendingAmount.GreaterThanEqual(ensureFree) {
+			return &paychFundsRes{mcid: *channelInfo.CreateMsg}
+		}
+
+		// Otherwise just wait for the channel to be created and try again
+		return nil
+	}
+
+	// If the channel already has the requested amount, there's no need to
+	// add any more, just return the channel address (there is no pending
+	// message at this point, so there is no message CID to return)
+	if channelInfo.Amount.GreaterThanEqual(ensureFree) {
+		return &paychFundsRes{channel: *channelInfo.Channel}
+	}
+
+	// If an add funds message was sent to the chain
+	if channelInfo.AddFundsMsg != nil {
+		// If the amount in the pending add funds message covers the amount for
+		// this request, there's no need to add more, just return the message
+		// CID for the pending request
+		if channelInfo.PendingAmount.GreaterThanEqual(ensureFree) {
+			return &paychFundsRes{channel: *channelInfo.Channel, mcid: *channelInfo.AddFundsMsg}
+		}
+
+		// Otherwise wait for the add funds message to complete and try again
+		return nil
+	}
+
+	// We need to add more funds, so send an add funds message to
+	// cover the amount for this request
+	mcid, err := ca.addFunds(ctx, from, to, ensureFree)
+	if err != nil {
+		return &paychFundsRes{err: err}
+	}
+	return &paychFundsRes{channel: *channelInfo.Channel, mcid: *mcid}
+}
+
+// createPaych sends a message to create the channel and returns the message cid
+func (ca *channelAccessor) createPaych(ctx context.Context, from, to address.Address, amt types.BigInt) (cid.Cid, error) {
 	params, aerr := actors.SerializeParams(&paych.ConstructorParams{From: from, To: to})
 	if aerr != nil {
 		return cid.Undef, aerr
@@ -41,106 +248,171 @@ func (pm *Manager) createPaych(ctx context.Context, from, to address.Address, am
 		GasPrice: types.NewInt(0),
 	}
 
-	smsg, err := pm.mpool.MpoolPushMessage(ctx, msg)
+	smsg, err := ca.api.MpoolPushMessage(ctx, msg)
 	if err != nil {
 		return cid.Undef, xerrors.Errorf("initializing paych actor: %w", err)
 	}
 
 	mcid := smsg.Cid()
-	go pm.waitForPaychCreateMsg(ctx, mcid)
+
+	ci := &ChannelInfo{
+		Direction:     DirOutbound,
+		NextLane:      0,
+		Control:       from,
+		Target:        to,
+		CreateMsg:     &mcid,
+		PendingAmount: amt,
+	}
+
+	// Create a new channel in the store
+	if err := ca.store.putChannelInfo(ci); err != nil {
+		log.Errorf("tracking channel: %s", err)
+		return mcid, err
+	}
+
+	go func() {
+		// Wait for the channel to be created on chain
+		err := ca.waitForPaychCreateMsg(from, to, mcid)
+		ca.msgWaitComplete(err)
+	}()
 
 	return mcid, nil
 }
 
-// WaitForPaychCreateMsg waits for mcid to appear on chain and returns the robust address of the
+// waitForPaychCreateMsg waits for mcid to appear on chain and stores the robust address of the
 // created payment channel
-// TODO: wait outside the store lock!
-// (tricky because we need to setup channel tracking before we know its address) -func (pm *Manager) waitForPaychCreateMsg(ctx context.Context, mcid cid.Cid) { - defer pm.store.lk.Unlock() - mwait, err := pm.state.StateWaitMsg(ctx, mcid, build.MessageConfidence) +func (ca *channelAccessor) waitForPaychCreateMsg(from address.Address, to address.Address, mcid cid.Cid) error { + mwait, err := ca.api.StateWaitMsg(ca.waitCtx, mcid, build.MessageConfidence) if err != nil { log.Errorf("wait msg: %w", err) - return + return err } if mwait.Receipt.ExitCode != 0 { - log.Errorf("payment channel creation failed (exit code %d)", mwait.Receipt.ExitCode) - return + err := xerrors.Errorf("payment channel creation failed (exit code %d)", mwait.Receipt.ExitCode) + log.Error(err) + + ca.lk.Lock() + defer ca.lk.Unlock() + + ca.mutateChannelInfo(from, to, func(channelInfo *ChannelInfo) { + channelInfo.PendingAmount = big.NewInt(0) + channelInfo.CreateMsg = nil + }) + + return err } var decodedReturn init_.ExecReturn err = decodedReturn.UnmarshalCBOR(bytes.NewReader(mwait.Receipt.Return)) if err != nil { log.Error(err) - return + return err } - paychaddr := decodedReturn.RobustAddress - ci, err := pm.loadStateChannelInfo(ctx, paychaddr, DirOutbound) + ca.lk.Lock() + defer ca.lk.Unlock() + + // Store robust address of channel + ca.mutateChannelInfo(from, to, func(channelInfo *ChannelInfo) { + channelInfo.Channel = &decodedReturn.RobustAddress + channelInfo.Amount = channelInfo.PendingAmount + channelInfo.PendingAmount = big.NewInt(0) + channelInfo.CreateMsg = nil + }) + + return nil +} + +// addFunds sends a message to add funds to the channel and returns the message cid +func (ca *channelAccessor) addFunds(ctx context.Context, from address.Address, to address.Address, ensureFree types.BigInt) (*cid.Cid, error) { + channelInfo, err := ca.store.OutboundByFromTo(from, to) if err != nil { - log.Errorf("loading channel info: %w", err) - return + return nil, err } - if err := pm.store.trackChannel(ci); err != nil { - log.Errorf("tracking channel: %w", err) - } -} + amt := big.Sub(ensureFree, channelInfo.PendingAmount) -func (pm *Manager) addFunds(ctx context.Context, ch address.Address, from address.Address, amt types.BigInt) (cid.Cid, error) { msg := &types.Message{ - To: ch, - From: from, + To: *channelInfo.Channel, + From: channelInfo.Control, Value: amt, Method: 0, GasLimit: 0, GasPrice: types.NewInt(0), } - smsg, err := pm.mpool.MpoolPushMessage(ctx, msg) + smsg, err := ca.api.MpoolPushMessage(ctx, msg) if err != nil { - return cid.Undef, err + return nil, err } mcid := smsg.Cid() - go pm.waitForAddFundsMsg(ctx, mcid) - return mcid, nil + + ca.mutateChannelInfo(from, to, func(ci *ChannelInfo) { + ci.PendingAmount = ensureFree + ci.AddFundsMsg = &mcid + }) + + go func() { + // Wait for funds to be added on chain + err := ca.waitForAddFundsMsg(from, to, mcid) + ca.msgWaitComplete(err) + }() + + return &mcid, nil } -// WaitForAddFundsMsg waits for mcid to appear on chain and returns error, if any -// TODO: wait outside the store lock! 
-// (tricky because we need to setup channel tracking before we know it's address) -func (pm *Manager) waitForAddFundsMsg(ctx context.Context, mcid cid.Cid) { - defer pm.store.lk.Unlock() - mwait, err := pm.state.StateWaitMsg(ctx, mcid, build.MessageConfidence) +// waitForAddFundsMsg waits for mcid to appear on chain and returns error, if any +func (ca *channelAccessor) waitForAddFundsMsg(from address.Address, to address.Address, mcid cid.Cid) error { + mwait, err := ca.api.StateWaitMsg(ca.waitCtx, mcid, build.MessageConfidence) if err != nil { log.Error(err) + return err } if mwait.Receipt.ExitCode != 0 { - log.Errorf("voucher channel creation failed: adding funds (exit code %d)", mwait.Receipt.ExitCode) + err := xerrors.Errorf("voucher channel creation failed: adding funds (exit code %d)", mwait.Receipt.ExitCode) + log.Error(err) + + ca.lk.Lock() + defer ca.lk.Unlock() + + ca.mutateChannelInfo(from, to, func(channelInfo *ChannelInfo) { + channelInfo.PendingAmount = big.NewInt(0) + channelInfo.AddFundsMsg = nil + }) + + return err } -} -func (pm *Manager) GetPaych(ctx context.Context, from, to address.Address, ensureFree types.BigInt) (address.Address, cid.Cid, error) { - pm.store.lk.Lock() // unlock only on err; wait funcs will defer unlock - var mcid cid.Cid - ch, err := pm.store.findChan(func(ci *ChannelInfo) bool { - if ci.Direction != DirOutbound { - return false - } - return ci.Control == from && ci.Target == to + ca.lk.Lock() + defer ca.lk.Unlock() + + // Store updated amount + ca.mutateChannelInfo(from, to, func(channelInfo *ChannelInfo) { + channelInfo.Amount = channelInfo.PendingAmount + channelInfo.PendingAmount = big.NewInt(0) + channelInfo.AddFundsMsg = nil }) + + return nil +} + +// Change the state of the channel in the store +func (ca *channelAccessor) mutateChannelInfo(from address.Address, to address.Address, mutate func(*ChannelInfo)) { + channelInfo, err := ca.store.OutboundByFromTo(from, to) + + // If there's an error reading or writing to the store just log an error. + // For now we're assuming it's unlikely to happen in practice. + // Later we may want to implement a transactional approach, whereby + // we record to the store that we're going to send a message, send + // the message, and then record that the message was sent. 
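+	// A rough sketch of what that transactional approach could look like
+	// (the PendingMsgType field below is hypothetical, nothing here is
+	// implemented yet):
+	//
+	//   ci.PendingMsgType = "add-funds"           // 1. record the intent to send
+	//   ps.putChannelInfo(ci)
+	//   smsg, _ := api.MpoolPushMessage(ctx, msg) // 2. send the message
+	//   mcid := smsg.Cid()
+	//   ci.AddFundsMsg = &mcid                    // 3. record that it was sent
+	//   ci.PendingMsgType = ""
+	//   ps.putChannelInfo(ci)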
if err != nil { - pm.store.lk.Unlock() - return address.Undef, cid.Undef, xerrors.Errorf("findChan: %w", err) - } - if ch != address.Undef { - // TODO: Track available funds - mcid, err = pm.addFunds(ctx, ch, from, ensureFree) - } else { - mcid, err = pm.createPaych(ctx, from, to, ensureFree) + log.Errorf("Error reading channel info from store: %s", err) } + + mutate(channelInfo) + + err = ca.store.putChannelInfo(channelInfo) if err != nil { - pm.store.lk.Unlock() + log.Errorf("Error writing channel info to store: %s", err) } - return ch, mcid, err } diff --git a/paychmgr/state.go b/paychmgr/state.go index 7d06a35a4d..9ba2740e69 100644 --- a/paychmgr/state.go +++ b/paychmgr/state.go @@ -3,20 +3,21 @@ package paychmgr import ( "context" - "github.com/filecoin-project/specs-actors/actors/abi/big" - "github.com/filecoin-project/specs-actors/actors/builtin/account" "github.com/filecoin-project/go-address" "github.com/filecoin-project/specs-actors/actors/builtin/paych" - xerrors "golang.org/x/xerrors" "github.com/filecoin-project/lotus/chain/types" ) -func (pm *Manager) loadPaychState(ctx context.Context, ch address.Address) (*types.Actor, *paych.State, error) { +type stateAccessor struct { + sm StateManagerApi +} + +func (ca *stateAccessor) loadPaychState(ctx context.Context, ch address.Address) (*types.Actor, *paych.State, error) { var pcast paych.State - act, err := pm.sm.LoadActorState(ctx, ch, &pcast, nil) + act, err := ca.sm.LoadActorState(ctx, ch, &pcast, nil) if err != nil { return nil, nil, err } @@ -24,26 +25,26 @@ func (pm *Manager) loadPaychState(ctx context.Context, ch address.Address) (*typ return act, &pcast, nil } -func (pm *Manager) loadStateChannelInfo(ctx context.Context, ch address.Address, dir uint64) (*ChannelInfo, error) { - _, st, err := pm.loadPaychState(ctx, ch) +func (ca *stateAccessor) loadStateChannelInfo(ctx context.Context, ch address.Address, dir uint64) (*ChannelInfo, error) { + _, st, err := ca.loadPaychState(ctx, ch) if err != nil { return nil, err } var account account.State - _, err = pm.sm.LoadActorState(ctx, st.From, &account, nil) + _, err = ca.sm.LoadActorState(ctx, st.From, &account, nil) if err != nil { return nil, err } from := account.Address - _, err = pm.sm.LoadActorState(ctx, st.To, &account, nil) + _, err = ca.sm.LoadActorState(ctx, st.To, &account, nil) if err != nil { return nil, err } to := account.Address ci := &ChannelInfo{ - Channel: ch, + Channel: &ch, Direction: dir, NextLane: nextLaneFromState(st), } @@ -72,80 +73,3 @@ func nextLaneFromState(st *paych.State) uint64 { } return maxLane + 1 } - -// laneState gets the LaneStates from chain, then applies all vouchers in -// the data store over the chain state -func (pm *Manager) laneState(state *paych.State, ch address.Address) (map[uint64]*paych.LaneState, error) { - // TODO: we probably want to call UpdateChannelState with all vouchers to be fully correct - // (but technically dont't need to) - laneStates := make(map[uint64]*paych.LaneState, len(state.LaneStates)) - - // Get the lane state from the chain - for _, laneState := range state.LaneStates { - laneStates[laneState.ID] = laneState - } - - // Apply locally stored vouchers - vouchers, err := pm.store.VouchersForPaych(ch) - if err != nil && err != ErrChannelNotTracked { - return nil, err - } - - for _, v := range vouchers { - for range v.Voucher.Merges { - return nil, xerrors.Errorf("paych merges not handled yet") - } - - // If there's a voucher for a lane that isn't in chain state just - // create it - ls, ok := 
laneStates[v.Voucher.Lane] - if !ok { - ls = &paych.LaneState{ - ID: v.Voucher.Lane, - Redeemed: types.NewInt(0), - Nonce: 0, - } - laneStates[v.Voucher.Lane] = ls - } - - if v.Voucher.Nonce < ls.Nonce { - continue - } - - ls.Nonce = v.Voucher.Nonce - ls.Redeemed = v.Voucher.Amount - } - - return laneStates, nil -} - -// Get the total redeemed amount across all lanes, after applying the voucher -func (pm *Manager) totalRedeemedWithVoucher(laneStates map[uint64]*paych.LaneState, sv *paych.SignedVoucher) (big.Int, error) { - // TODO: merges - if len(sv.Merges) != 0 { - return big.Int{}, xerrors.Errorf("dont currently support paych lane merges") - } - - total := big.NewInt(0) - for _, ls := range laneStates { - total = big.Add(total, ls.Redeemed) - } - - lane, ok := laneStates[sv.Lane] - if ok { - // If the voucher is for an existing lane, and the voucher nonce - // and is higher than the lane nonce - if sv.Nonce > lane.Nonce { - // Add the delta between the redeemed amount and the voucher - // amount to the total - delta := big.Sub(sv.Amount, lane.Redeemed) - total = big.Add(total, delta) - } - } else { - // If the voucher is *not* for an existing lane, just add its - // value (implicitly a new lane will be created for the voucher) - total = big.Add(total, sv.Amount) - } - - return total, nil -} diff --git a/paychmgr/store.go b/paychmgr/store.go index 66a514feb2..a352f3a17c 100644 --- a/paychmgr/store.go +++ b/paychmgr/store.go @@ -4,14 +4,14 @@ import ( "bytes" "errors" "fmt" - "strings" - "sync" + + "github.com/filecoin-project/lotus/chain/types" "github.com/filecoin-project/specs-actors/actors/builtin/paych" + "github.com/ipfs/go-cid" "github.com/ipfs/go-datastore" "github.com/ipfs/go-datastore/namespace" dsq "github.com/ipfs/go-datastore/query" - "golang.org/x/xerrors" "github.com/filecoin-project/go-address" cborrpc "github.com/filecoin-project/go-cbor-util" @@ -22,8 +22,6 @@ import ( var ErrChannelNotTracked = errors.New("channel not tracked") type Store struct { - lk sync.Mutex // TODO: this can be split per paych - ds datastore.Batching } @@ -44,109 +42,105 @@ type VoucherInfo struct { Proof []byte } +// ChannelInfo keeps track of information about a channel type ChannelInfo struct { - Channel address.Address + // Channel address - may be nil if the channel hasn't been created yet + Channel *address.Address + // Control is the address of the account that created the channel Control address.Address - Target address.Address - + // Target is the address of the account on the other end of the channel + Target address.Address + // Sequence distinguishes channels with the same Control / Target + Sequence uint64 + // Direction indicates if the channel is inbound (this node is the Target) + // or outbound (this node is the Control) Direction uint64 - Vouchers []*VoucherInfo - NextLane uint64 -} - -func dskeyForChannel(addr address.Address) datastore.Key { - return datastore.NewKey(addr.String()) -} - -func (ps *Store) putChannelInfo(ci *ChannelInfo) error { - k := dskeyForChannel(ci.Channel) - - b, err := cborrpc.Dump(ci) - if err != nil { - return err - } - - return ps.ds.Put(k, b) + // Vouchers is a list of all vouchers sent on the channel + Vouchers []*VoucherInfo + // NextLane is the number of the next lane that should be used when the + // client requests a new lane (eg to create a voucher for a new deal) + NextLane uint64 + // Amount added to the channel. + // Note: This amount is only used by GetPaych to keep track of how much + // has locally been added to the channel. 
It should reflect the channel's + // Balance on chain as long as all operations occur on the same datastore. + Amount types.BigInt + // Pending amount that we're awaiting confirmation of + PendingAmount types.BigInt + // CID of a pending create message (while waiting for confirmation) + CreateMsg *cid.Cid + // CID of a pending add funds message (while waiting for confirmation) + AddFundsMsg *cid.Cid } -func (ps *Store) getChannelInfo(addr address.Address) (*ChannelInfo, error) { - k := dskeyForChannel(addr) - - b, err := ps.ds.Get(k) - if err == datastore.ErrNotFound { - return nil, ErrChannelNotTracked - } - if err != nil { - return nil, err - } - - var ci ChannelInfo - if err := ci.UnmarshalCBOR(bytes.NewReader(b)); err != nil { - return nil, err - } - - return &ci, nil -} - -func (ps *Store) TrackChannel(ch *ChannelInfo) error { - ps.lk.Lock() - defer ps.lk.Unlock() - - return ps.trackChannel(ch) -} - -func (ps *Store) trackChannel(ch *ChannelInfo) error { - _, err := ps.getChannelInfo(ch.Channel) +// TrackChannel stores a channel, returning an error if the channel was already +// being tracked +func (ps *Store) TrackChannel(ci *ChannelInfo) error { + _, err := ps.ByAddress(*ci.Channel) switch err { default: return err case nil: - return fmt.Errorf("already tracking channel: %s", ch.Channel) + return fmt.Errorf("already tracking channel: %s", ci.Channel) case ErrChannelNotTracked: - return ps.putChannelInfo(ch) + return ps.putChannelInfo(ci) } } func (ps *Store) ListChannels() ([]address.Address, error) { - ps.lk.Lock() - defer ps.lk.Unlock() - - res, err := ps.ds.Query(dsq.Query{KeysOnly: true}) + cis, err := ps.findChans(func(ci *ChannelInfo) bool { + return ci.Channel != nil + }, 0) if err != nil { return nil, err } - defer res.Close() //nolint:errcheck - - var out []address.Address - for { - res, ok := res.NextSync() - if !ok { - break - } - - if res.Error != nil { - return nil, err - } - addr, err := address.NewFromString(strings.TrimPrefix(res.Key, "/")) - if err != nil { - return nil, xerrors.Errorf("failed reading paych key (%q) from datastore: %w", res.Key, err) - } - - out = append(out, addr) + addrs := make([]address.Address, 0, len(cis)) + for _, ci := range cis { + addrs = append(addrs, *ci.Channel) } - return out, nil + return addrs, nil + + //res, err := ps.ds.Query(dsq.Query{KeysOnly: true}) + //if err != nil { + // return nil, err + //} + //defer res.Close() //nolint:errcheck + // + //var out []address.Address + //for { + // res, ok := res.NextSync() + // if !ok { + // break + // } + // + // if res.Error != nil { + // return nil, err + // } + // + // addr, err := address.NewFromString(strings.TrimPrefix(res.Key, "/")) + // if err != nil { + // return nil, xerrors.Errorf("failed reading paych key (%q) from datastore: %w", res.Key, err) + // } + // + // out = append(out, addr) + //} + // + //return out, nil } -func (ps *Store) findChan(filter func(*ChannelInfo) bool) (address.Address, error) { +// findChans loops over all channels, only including those that pass the filter. +// max is the maximum number of channels to return. Set to zero to return unlimited channels. 
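+// Note that every call queries the whole datastore and unmarshals each entry
+// it visits, so lookups that happen frequently (eg ByAddress, see its TODO)
+// may eventually want a cache in front of this.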
+func (ps *Store) findChans(filter func(*ChannelInfo) bool, max int) ([]ChannelInfo, error) { res, err := ps.ds.Query(dsq.Query{}) if err != nil { - return address.Undef, err + return nil, err } defer res.Close() //nolint:errcheck - var ci ChannelInfo + var stored ChannelInfoStorable + var matches []ChannelInfo for { res, ok := res.NextSync() @@ -155,33 +149,38 @@ func (ps *Store) findChan(filter func(*ChannelInfo) bool) (address.Address, erro } if res.Error != nil { - return address.Undef, err + return nil, err } - if err := ci.UnmarshalCBOR(bytes.NewReader(res.Value)); err != nil { - return address.Undef, err + ci, err := unmarshallChannelInfo(&stored, res) + if err != nil { + return nil, err } - if !filter(&ci) { + if !filter(ci) { continue } - addr, err := address.NewFromString(strings.TrimPrefix(res.Key, "/")) - if err != nil { - return address.Undef, xerrors.Errorf("failed reading paych key (%q) from datastore: %w", res.Key, err) - } + //addr, err := address.NewFromString(strings.TrimPrefix(res.Key, "/")) + //if err != nil { + // return nil, xerrors.Errorf("failed reading paych key (%q) from datastore: %w", res.Key, err) + //} + + matches = append(matches, *ci) - return addr, nil + // If we've reached the maximum number of matches, return. + // Note that if max is zero we return an unlimited number of matches + // because len(matches) will always be at least 1. + if len(matches) == max { + return matches, nil + } } - return address.Undef, nil + return matches, nil } func (ps *Store) AllocateLane(ch address.Address) (uint64, error) { - ps.lk.Lock() - defer ps.lk.Unlock() - - ci, err := ps.getChannelInfo(ch) + ci, err := ps.ByAddress(ch) if err != nil { return 0, err } @@ -193,10 +192,156 @@ func (ps *Store) AllocateLane(ch address.Address) (uint64, error) { } func (ps *Store) VouchersForPaych(ch address.Address) ([]*VoucherInfo, error) { - ci, err := ps.getChannelInfo(ch) + ci, err := ps.ByAddress(ch) if err != nil { return nil, err } return ci.Vouchers, nil } + +func (ps *Store) ByAddress(addr address.Address) (*ChannelInfo, error) { + // TODO: cache + cis, err := ps.findChans(func(ci *ChannelInfo) bool { + return ci.Channel != nil && *ci.Channel == addr + }, 1) + if err != nil { + return nil, err + } + + if len(cis) == 0 { + return nil, ErrChannelNotTracked + } + + return &cis[0], nil +} + +// OutboundByFromTo collects the outbound channel with the given from / to +// addresses and returns the one with the highest sequence number +func (ps *Store) OutboundByFromTo(from address.Address, to address.Address) (*ChannelInfo, error) { + cis, err := ps.findChans(func(ci *ChannelInfo) bool { + if ci.Direction != DirOutbound { + return false + } + return ci.Control == from && ci.Target == to + }, 0) + if err != nil { + return nil, err + } + + return highestSequence(cis) +} + +func highestSequence(cis []ChannelInfo) (*ChannelInfo, error) { + if len(cis) == 0 { + return nil, ErrChannelNotTracked + } + + highestIndex := 0 + highest := cis[0].Sequence + for i := 1; i < len(cis); i++ { + if cis[i].Sequence > highest { + highest = cis[i].Sequence + highestIndex = i + } + } + return &cis[highestIndex], nil +} + +// WithPendingAddFunds is used on startup to find channels for which a +// create channel or add funds message has been sent, but lotus shut down +// before the response was received. 
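+// Only outbound channels are considered, since create channel and add funds
+// messages are only sent on channels where this node is the Control address.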
+func (ps *Store) WithPendingAddFunds() ([]ChannelInfo, error) { + return ps.findChans(func(ci *ChannelInfo) bool { + if ci.Direction != DirOutbound { + return false + } + return ci.CreateMsg != nil || ci.AddFundsMsg != nil + }, 0) +} + +// The datastore key used to identify the channel info +func dskeyForChannel(ci *ChannelInfo) datastore.Key { + return datastore.NewKey(fmt.Sprintf("%s->%s:%d", ci.Control.String(), ci.Target.String(), ci.Sequence)) +} + +func (ps *Store) putChannelInfo(ci *ChannelInfo) error { + // TODO: When a channel is settled, the next call to putChannelInfo should + // create a new channel with a higher Sequence number + k := dskeyForChannel(ci) + + b, err := marshallChannelInfo(ci) + if err != nil { + return err + } + + return ps.ds.Put(k, b) +} + +// ChannelInfoStorable is used to store information about a channel in the data store. +// TODO: Only need this because we can't currently marshall a nil address.Address for +// Channel so we can't marshall ChannelInfo directly +type ChannelInfoStorable struct { + Channel string + Control address.Address + Target address.Address + Sequence uint64 + Direction uint64 + Vouchers []*VoucherInfo + NextLane uint64 + Amount types.BigInt + PendingAmount types.BigInt + AddFundsMsg *cid.Cid + CreateMsg *cid.Cid +} + +func marshallChannelInfo(ci *ChannelInfo) ([]byte, error) { + ch := "" + if ci.Channel != nil { + ch = ci.Channel.String() + } + toStore := ChannelInfoStorable{ + Channel: ch, + Control: ci.Control, + Target: ci.Target, + Sequence: ci.Sequence, + Direction: ci.Direction, + Vouchers: ci.Vouchers, + NextLane: ci.NextLane, + Amount: ci.Amount, + PendingAmount: ci.PendingAmount, + CreateMsg: ci.CreateMsg, + AddFundsMsg: ci.AddFundsMsg, + } + + return cborrpc.Dump(&toStore) +} + +func unmarshallChannelInfo(stored *ChannelInfoStorable, res dsq.Result) (*ChannelInfo, error) { + if err := stored.UnmarshalCBOR(bytes.NewReader(res.Value)); err != nil { + return nil, err + } + + var ch *address.Address + if len(stored.Channel) > 0 { + addr, err := address.NewFromString(stored.Channel) + if err != nil { + return nil, err + } + ch = &addr + } + ci := ChannelInfo{ + Channel: ch, + Control: stored.Control, + Target: stored.Target, + Sequence: stored.Sequence, + Direction: stored.Direction, + Vouchers: stored.Vouchers, + NextLane: stored.NextLane, + Amount: stored.Amount, + PendingAmount: stored.PendingAmount, + CreateMsg: stored.CreateMsg, + AddFundsMsg: stored.AddFundsMsg, + } + return &ci, nil +} diff --git a/paychmgr/store_test.go b/paychmgr/store_test.go index 0942264646..65be6f1b19 100644 --- a/paychmgr/store_test.go +++ b/paychmgr/store_test.go @@ -17,8 +17,9 @@ func TestStore(t *testing.T) { require.NoError(t, err) require.Len(t, addrs, 0) + ch := tutils.NewIDAddr(t, 100) ci := &ChannelInfo{ - Channel: tutils.NewIDAddr(t, 100), + Channel: &ch, Control: tutils.NewIDAddr(t, 101), Target: tutils.NewIDAddr(t, 102), @@ -26,8 +27,9 @@ func TestStore(t *testing.T) { Vouchers: []*VoucherInfo{{Voucher: nil, Proof: []byte{}}}, } + ch2 := tutils.NewIDAddr(t, 200) ci2 := &ChannelInfo{ - Channel: tutils.NewIDAddr(t, 200), + Channel: &ch2, Control: tutils.NewIDAddr(t, 201), Target: tutils.NewIDAddr(t, 202), @@ -55,7 +57,7 @@ func TestStore(t *testing.T) { require.Contains(t, addrsStrings(addrs), "t0200") // Request vouchers for channel - vouchers, err := store.VouchersForPaych(ci.Channel) + vouchers, err := store.VouchersForPaych(*ci.Channel) require.NoError(t, err) require.Len(t, vouchers, 1) @@ -64,12 +66,12 @@ func TestStore(t 
*testing.T) { require.Equal(t, err, ErrChannelNotTracked) // Allocate lane for channel - lane, err := store.AllocateLane(ci.Channel) + lane, err := store.AllocateLane(*ci.Channel) require.NoError(t, err) require.Equal(t, lane, uint64(0)) // Allocate next lane for channel - lane, err = store.AllocateLane(ci.Channel) + lane, err = store.AllocateLane(*ci.Channel) require.NoError(t, err) require.Equal(t, lane, uint64(1))
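For reference, the sketch below shows the flow a caller of getPaych goes through under the semantics described in its doc comment above: ask for funds, wait for any pending create / add-funds message to be confirmed, then ask again to obtain the confirmed channel address. It is an illustrative, package-internal example only; ensureReady is not part of this change, and real callers may drive the Manager-level API instead of the channelAccessor directly.

// ensureReady is an illustrative sketch (not part of this change). It ensures
// that at least amt is available on the channel from -> to and returns the
// channel address once any pending on-chain message has been confirmed.
func ensureReady(ctx context.Context, ca *channelAccessor, from, to address.Address, amt types.BigInt) (address.Address, error) {
	ch, mcid, err := ca.getPaych(ctx, from, to, amt)
	if err != nil {
		return address.Undef, err
	}

	// A message CID is returned when a create channel or add funds message
	// is in flight; wait for it to land on chain before asking again.
	if mcid != cid.Undef {
		if _, err := ca.api.StateWaitMsg(ctx, mcid, build.MessageConfidence); err != nil {
			return address.Undef, err
		}

		// By now the store holds the channel address and confirmed amount.
		ch, _, err = ca.getPaych(ctx, from, to, amt)
		if err != nil {
			return address.Undef, err
		}
	}

	return ch, nil
}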