Skip to content

Commit

Permalink
Fix/tanlang/close node manually (#286)
Browse files Browse the repository at this point in the history
* fix: close node manually
  • Loading branch information
LinZexiao authored Nov 22, 2022
1 parent 09dc818 commit 12bbd5f
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 39 deletions.
80 changes: 45 additions & 35 deletions publisher/publisher.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"context"
"errors"
"fmt"
"runtime"
"strings"
"sync"
"time"
Expand Down Expand Up @@ -60,19 +59,26 @@ type RpcPublisher struct {
nodeProvider repo.INodeProvider
enableMultiNode bool

nodeThreads map[types.UUID]*nodeThread
lk sync.RWMutex
nodeThreads map[types.UUID]struct {
nodeThread *nodeThread
close func()
}
lk sync.Mutex
}

func NewRpcPublisher(ctx context.Context, nodeClient v1.FullNode, nodeProvider repo.INodeProvider, enableMultiNode bool) *RpcPublisher {
nThread := newNodeThread(ctx, nodeClient)
nThread := newNodeThread(ctx, "mainNode", nodeClient)
return &RpcPublisher{
ctx: ctx,
mainNodeThread: nThread,
nodeProvider: nodeProvider,
enableMultiNode: enableMultiNode,
nodeThreads: make(map[types.UUID]*nodeThread),
lk: sync.RWMutex{},
nodeThreads: make(map[types.UUID]struct {
nodeThread *nodeThread
close func()
}),

lk: sync.Mutex{},
}
}

Expand All @@ -88,53 +94,57 @@ func (p *RpcPublisher) PublishMessages(ctx context.Context, msgs []*types.Signed
return fmt.Errorf("list node fail %w", err)
}

newThreadMap := make(map[types.UUID]*nodeThread, len(nodeList))
needUpdate := false

p.lk.RLock()
oriLen := len(p.nodeThreads)
for _, node := range nodeList {
thread, ok := p.nodeThreads[node.ID]
if ok {
newThreadMap[node.ID] = thread
}
}
p.lk.RUnlock()
p.lk.Lock()
defer p.lk.Unlock()

nodesRemain := make(map[types.UUID]struct{})
for _, node := range nodeList {
thr, ok := newThreadMap[node.ID]
threadStruct, ok := p.nodeThreads[node.ID]
nodesRemain[node.ID] = struct{}{}
if !ok {
needUpdate = true
cli, closer, err := v1.DialFullNodeRPC(ctx, node.URL, node.Token, nil)
thrCtx, cancel := context.WithCancel(p.ctx) // nolint ignore lostcancel
cli, closer, err := v1.DialFullNodeRPC(thrCtx, node.URL, node.Token, nil)
if err != nil {
log.Warnf("connect node(%s) %v", node.Name, err)
log.Warnf("connect node(%s) fail %v", node.Name, err)
continue
}
runtime.SetFinalizer(cli, func(c *v1.FullNodeStruct) {
closer()
})
thr = newNodeThread(p.ctx, cli)
newThreadMap[node.ID] = thr

nodeName := node.Name
threadStruct = struct {
nodeThread *nodeThread
close func()
}{
nodeThread: newNodeThread(thrCtx, nodeName, cli),
close: func() {
cancel()
closer()
log.Debugf("close node thread %s", nodeName)
},
}
p.nodeThreads[node.ID] = threadStruct
}
thr.HandleMsg(msgs)
threadStruct.nodeThread.HandleMsg(msgs)
}

if needUpdate || len(newThreadMap) != oriLen {
p.lk.Lock()
p.nodeThreads = newThreadMap
p.lk.Unlock()
for id, threadStruct := range p.nodeThreads {
if _, ok := nodesRemain[id]; !ok {
threadStruct.close()
delete(p.nodeThreads, id)
}
}

return nil
return nil // nolint ignore lostcancel
}

type nodeThread struct {
name string
nodeClient v1.FullNode
msgChan chan []*types.SignedMessage
}

func newNodeThread(ctx context.Context, nodeClient v1.FullNode) *nodeThread {
func newNodeThread(ctx context.Context, name string, nodeClient v1.FullNode) *nodeThread {
t := &nodeThread{
name: name,
nodeClient: nodeClient,
msgChan: make(chan []*types.SignedMessage, 30),
}
Expand All @@ -152,7 +162,7 @@ func (n *nodeThread) run(ctx context.Context) {
if _, err := n.nodeClient.MpoolBatchPush(ctx, msgs); err != nil {
//skip error
if !strings.Contains(err.Error(), errMinimumNonce.Error()) && !strings.Contains(err.Error(), errAlreadyInMpool.Error()) {
log.Errorf("push message to node failed %v", err)
log.Errorf("push message to node %s failed %v", n.name, err)
} else {
log.Debugf("push message to node failed %v", err)
}
Expand Down
12 changes: 8 additions & 4 deletions publisher/publisher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,18 +64,21 @@ func TestMultiNodePublishMessage(t *testing.T) {
rpcPublisher := NewRpcPublisher(ctx, mainNode, nodeProvider, true)

t.Run("publish message to multi node", func(t *testing.T) {
nodeProvider.EXPECT().ListNode().Return(nodes[:2], nil).Times(1)
for _, srv := range servers[:2] {
nodeProvider.EXPECT().ListNode().Return(nodes[:3], nil).Times(1)
for _, srv := range servers[:3] {
srv.FullNode.EXPECT().MpoolBatchPush(gomock.Any(), msgs).Return(nil, nil).Times(1)
}
err := rpcPublisher.PublishMessages(ctx, msgs)
assert.NoError(t, err)
runtime.Gosched()
})

// wait for messager consume
time.Sleep(1 * time.Second)

t.Run("publish message to multi node after delete node", func(t *testing.T) {
nodeProvider.EXPECT().ListNode().Return(nodes[:1], nil).Times(1)
for _, srv := range servers[:1] {
nodeProvider.EXPECT().ListNode().Return(nodes[1:2], nil).Times(1)
for _, srv := range servers[1:2] {
srv.FullNode.EXPECT().MpoolBatchPush(gomock.Any(), msgs).Return(nil, nil).Times(1)
}
err := rpcPublisher.PublishMessages(ctx, msgs)
Expand All @@ -92,6 +95,7 @@ func TestMultiNodePublishMessage(t *testing.T) {
assert.NoError(t, err)
runtime.Gosched()
})

// wait goroutine
time.Sleep(1 * time.Second)
}
Expand Down

0 comments on commit 12bbd5f

Please sign in to comment.