Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement visor restart from hypervisor #80

Merged
merged 13 commits into from
Dec 27, 2019
22 changes: 21 additions & 1 deletion cmd/skywire-visor/commands/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"github.com/spf13/cobra"

"github.com/SkycoinProject/skywire-mainnet/internal/utclient"
"github.com/SkycoinProject/skywire-mainnet/pkg/restart"
"github.com/SkycoinProject/skywire-mainnet/pkg/util/pathutil"
"github.com/SkycoinProject/skywire-mainnet/pkg/visor"
)
Expand All @@ -39,13 +40,15 @@ type runCfg struct {
cfgFromStdin bool
profileMode string
port string
startDelay string
args []string

profileStop func()
logger *logging.Logger
masterLogger *logging.MasterLogger
conf visor.Config
node *visor.Node
restartCtx *restart.Context
}

var cfg *runCfg
Expand Down Expand Up @@ -73,6 +76,9 @@ func init() {
rootCmd.Flags().BoolVarP(&cfg.cfgFromStdin, "stdin", "i", false, "read config from STDIN")
rootCmd.Flags().StringVarP(&cfg.profileMode, "profile", "p", "none", "enable profiling with pprof. Mode: none or one of: [cpu, mem, mutex, block, trace, http]")
rootCmd.Flags().StringVarP(&cfg.port, "port", "", "6060", "port for http-mode of pprof")
rootCmd.Flags().StringVarP(&cfg.startDelay, "delay", "", "0ns", "delay before visor start")

cfg.restartCtx = restart.CaptureContext()
}

// Execute executes root CLI command.
Expand Down Expand Up @@ -148,7 +154,19 @@ func (cfg *runCfg) readConfig() *runCfg {
}

func (cfg *runCfg) runNode() *runCfg {
node, err := visor.NewNode(&cfg.conf, cfg.masterLogger)
startDelay, err := time.ParseDuration(cfg.startDelay)
if err != nil {
cfg.logger.Warnf("Using no visor start delay due to parsing failure: %v", err)
startDelay = time.Duration(0)
}

if startDelay != 0 {
cfg.logger.Infof("Visor start delay is %v, waiting...", startDelay)
}

time.Sleep(startDelay)

node, err := visor.NewNode(&cfg.conf, cfg.masterLogger, cfg.restartCtx)
if err != nil {
cfg.logger.Fatal("Failed to initialize node: ", err)
}
Expand Down Expand Up @@ -181,7 +199,9 @@ func (cfg *runCfg) runNode() *runCfg {
if cfg.conf.ShutdownTimeout == 0 {
cfg.conf.ShutdownTimeout = defaultShutdownTimeout
}

cfg.node = node

return cfg
}

Expand Down
13 changes: 13 additions & 0 deletions pkg/hypervisor/hypervisor.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ func (m *Node) ServeHTTP(w http.ResponseWriter, req *http.Request) {
r.Put("/nodes/{pk}/routes/{rid}", m.putRoute())
r.Delete("/nodes/{pk}/routes/{rid}", m.deleteRoute())
r.Get("/nodes/{pk}/loops", m.getLoops())
r.Get("/nodes/{pk}/restart", m.restart())
})
})
r.ServeHTTP(w, req)
Expand Down Expand Up @@ -569,6 +570,18 @@ func (m *Node) getLoops() http.HandlerFunc {
})
}

// restart returns a handler that asks the visor selected by the request to restart itself.
// NOTE: The reply is delayed, because the visor first checks that the new executable started successfully.
func (m *Node) restart() http.HandlerFunc {
	handle := func(w http.ResponseWriter, r *http.Request, ctx *httpCtx) {
		err := ctx.RPC.Restart()
		if err == nil {
			httputil.WriteJSON(w, r, http.StatusOK, true)
			return
		}

		httputil.WriteJSON(w, r, http.StatusInternalServerError, err)
	}

	return m.withCtx(m.nodeCtx, handle)
}

/*
<<< Helper functions >>>
*/
Expand Down
155 changes: 155 additions & 0 deletions pkg/restart/restart.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
package restart

import (
"errors"
"log"
"os"
"os/exec"
"sync/atomic"
"time"

"github.com/sirupsen/logrus"
)

var (
	// ErrAlreadyStarting is returned by Context.Start when another Start call
	// on the same Context is still in progress.
	ErrAlreadyStarting = errors.New("already starting")
)

const (
	// DefaultCheckDelay is a default delay for checking if a new instance is started successfully.
	DefaultCheckDelay = 1 * time.Second
	// extraWaitingTime is added to the check delay to form the start delay
	// that is passed to the new instance.
	extraWaitingTime = 1 * time.Second
	// delayArgName is the command-line flag used to pass the start delay
	// to the new instance.
	delayArgName = "--delay"
)

// Context describes data required for restarting visor.
type Context struct {
	// log is an optional logger; when nil, a standard library logger is used instead.
	log logrus.FieldLogger
	// cmd is the command that is executed to start the new instance.
	cmd *exec.Cmd
	// checkDelay is how long Start waits for the new instance to report an error.
	checkDelay time.Duration
	// isStarting is 1 while a Start call is in progress and 0 otherwise;
	// Start reads and writes it via sync/atomic.
	isStarting int32
	// appendDelay controls whether the delay flag is appended to the arguments
	// of the new instance.
	appendDelay bool // disabled in tests
}

// CaptureContext snapshots everything needed to start this executable again:
// its command line, environment and standard streams.
// The snapshot is taken from os.Args and os.Environ at call time, so it should
// be called right after the executable starts, before any of them is modified.
func CaptureContext() *Context {
	name, args := os.Args[0], os.Args[1:]

	cmd := exec.Command(name, args...) // nolint:gosec
	cmd.Env = os.Environ()
	cmd.Stdin, cmd.Stdout, cmd.Stderr = os.Stdin, os.Stdout, os.Stderr

	captured := new(Context)
	captured.cmd = cmd
	captured.checkDelay = DefaultCheckDelay
	captured.appendDelay = true

	return captured
}

// RegisterLogger makes the Context log through logger instead of the standard one.
// It is a no-op on a nil Context.
func (c *Context) RegisterLogger(logger logrus.FieldLogger) {
	if c == nil {
		return
	}

	c.log = logger
}

// SetCheckDelay replaces the standard check delay with delay.
// It is a no-op on a nil Context.
func (c *Context) SetCheckDelay(delay time.Duration) {
	if c == nil {
		return
	}

	c.checkDelay = delay
}

// Start starts a new executable using Context and waits up to the check delay
// to see whether the new instance reports an error.
//
// It returns ErrAlreadyStarting if another Start call on the same Context is
// still in progress, and the error of the new instance if it fails to start
// or exits with an error before the check delay elapses. Otherwise it returns
// nil, including the case when the new instance exits without error before
// the check delay elapses.
func (c *Context) Start() error {
	if !atomic.CompareAndSwapInt32(&c.isStarting, 0, 1) {
		return ErrAlreadyStarting
	}

	defer atomic.StoreInt32(&c.isStarting, 0)

	errCh := c.startExec()

	// A one-shot timer is used instead of a ticker: only a single expiry is
	// needed, and time.NewTicker panics on a non-positive check delay.
	timer := time.NewTimer(c.checkDelay)
	defer timer.Stop()

	select {
	case err := <-errCh:
		if err != nil {
			c.errorLogger()("Failed to start new instance: %v", err)
			return err
		}

		// errCh was closed without an error: the new instance has already
		// exited successfully, so this must not be reported as a failure.
		c.infoLogger()("New instance exited without error before the check delay elapsed")

		return nil
	case <-timer.C:
		c.infoLogger()("New instance started successfully, exiting from the old one")
		return nil
	}
}

// startExec runs the captured command in a separate goroutine and returns a
// channel that receives the error of starting or waiting for the command, if
// any. The channel is closed once the command has finished.
func (c *Context) startExec() chan error {
	result := make(chan error, 1)

	run := func() {
		defer close(result)

		c.adjustArgs()

		c.infoLogger()("Starting new instance of executable (args: %q)", c.cmd.Args)

		// Run starts the command and then waits for it to complete.
		if err := c.cmd.Run(); err != nil {
			result <- err
		}
	}

	go run()

	return result
}

// adjustArgs rewrites the command arguments for the new instance: every
// delay flag that is followed by a value is removed together with that value,
// and, if appendDelay is set, a fresh delay flag equal to the check delay plus
// the extra waiting time is appended.
func (c *Context) adjustArgs() {
	current := c.cmd.Args
	kept := current[:0] // filter in place, reusing the backing array

	for pos := 0; pos < len(current); pos++ {
		// A delay flag in the last position has no value and is kept as is.
		if current[pos] == delayArgName && pos+1 < len(current) {
			pos++ // skip the value of the flag as well
			continue
		}

		kept = append(kept, current[pos])
	}

	if c.appendDelay {
		kept = append(kept, delayArgName, (c.checkDelay + extraWaitingTime).String())
	}

	c.cmd.Args = kept
}

// infoLogger returns a printf-style function for informational messages.
// It uses the registered logger if there is one, and otherwise a standard
// library logger that writes to stdout with the "[INFO] " prefix.
func (c *Context) infoLogger() func(string, ...interface{}) {
	if c.log == nil {
		return log.New(os.Stdout, "[INFO] ", log.LstdFlags).Printf
	}

	return c.log.Infof
}

// errorLogger returns a printf-style function for error messages.
// It uses the registered logger if there is one, and otherwise a standard
// library logger that writes to stdout with the "[ERROR] " prefix.
func (c *Context) errorLogger() func(string, ...interface{}) {
	if registered := c.log; registered != nil {
		return registered.Errorf
	}

	return log.New(os.Stdout, "[ERROR] ", log.LstdFlags).Printf
}
83 changes: 83 additions & 0 deletions pkg/restart/restart_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package restart

import (
"os"
"os/exec"
"testing"
"time"

"github.com/SkycoinProject/skycoin/src/util/logging"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

// TestCaptureContext checks that a freshly captured context mirrors the
// current process and has no logger registered.
func TestCaptureContext(t *testing.T) {
	captured := CaptureContext()

	require.Equal(t, DefaultCheckDelay, captured.checkDelay)
	require.Nil(t, captured.log)

	cmd := captured.cmd
	require.Equal(t, os.Args, cmd.Args)
	require.Equal(t, os.Environ(), cmd.Env)
	require.Equal(t, os.Stdin, cmd.Stdin)
	require.Equal(t, os.Stdout, cmd.Stdout)
	require.Equal(t, os.Stderr, cmd.Stderr)
}

// TestContext_RegisterLogger checks that a registered logger replaces the
// initially unset one.
func TestContext_RegisterLogger(t *testing.T) {
	captured := CaptureContext()
	require.Nil(t, captured.log)

	testLogger := logging.MustGetLogger("test")

	captured.RegisterLogger(testLogger)
	require.Equal(t, testLogger, captured.log)
}

// TestContext_Start checks Start for a command that runs, a command that
// cannot be found, and a start attempt made while another one is in progress.
func TestContext_Start(t *testing.T) {
	cc := CaptureContext()
	assert.NotZero(t, len(cc.cmd.Args))

	t.Run("executable started", func(t *testing.T) {
		cmd := "touch"
		path := "/tmp/test_restart"
		cc.cmd = exec.Command(cmd, path) // nolint:gosec
		cc.appendDelay = false

		assert.NoError(t, cc.Start())
		assert.NoError(t, os.Remove(path))
	})

	t.Run("bad args", func(t *testing.T) {
		cmd := "bad_command"
		cc.cmd = exec.Command(cmd) // nolint:gosec

		// TODO(nkryuchkov): Check if it works on Linux and Windows, if not then change the error text.
		assert.EqualError(t, cc.Start(), `exec: "bad_command": executable file not found in $PATH`)
	})

	t.Run("already restarting", func(t *testing.T) {
		cmd := "touch"
		path := "/tmp/test_restart"
		cc.cmd = exec.Command(cmd, path) // nolint:gosec
		cc.appendDelay = false

		// Mark a start as in progress directly instead of racing two
		// goroutines: with concurrent calls it is not defined which of them
		// wins, and a quickly exiting command lets both of them succeed.
		// No other goroutine touches the flag here, so a plain write is safe.
		cc.isStarting = 1
		assert.Equal(t, ErrAlreadyStarting, cc.Start())

		// Once the start in progress is over, starting must work again.
		cc.isStarting = 0
		assert.NoError(t, cc.Start())

		assert.NoError(t, os.Remove(path))
	})
}

// TestContext_SetCheckDelay checks that SetCheckDelay overrides the default
// check delay.
func TestContext_SetCheckDelay(t *testing.T) {
	cc := CaptureContext()
	require.Equal(t, DefaultCheckDelay, cc.checkDelay)

	// The new value must differ from DefaultCheckDelay, otherwise the test
	// would also pass if SetCheckDelay did nothing.
	const customDelay = DefaultCheckDelay + time.Second

	cc.SetCheckDelay(customDelay)
	require.Equal(t, customDelay, cc.checkDelay)
}
2 changes: 2 additions & 0 deletions pkg/visor/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ type Config struct {
Interfaces InterfaceConfig `json:"interfaces"`

AppServerSockFile string `json:"app_server_sock_file"`

RestartCheckDelay string `json:"restart_check_delay"`
}

// MessagingConfig returns config for dmsg client.
Expand Down
28 changes: 28 additions & 0 deletions pkg/visor/rpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"net/http"
"os"
"path/filepath"
"time"

Expand All @@ -29,6 +30,9 @@ var (

// ErrNotFound is returned when a requested resource is not found.
ErrNotFound = errors.New("not found")

// ErrMalformedRestartContext is returned when restart context is malformed.
ErrMalformedRestartContext = errors.New("restart context is malformed")
)

// RPC defines RPC methods for Node.
Expand Down Expand Up @@ -390,3 +394,27 @@ func (r *RPC) Loops(_ *struct{}, out *[]LoopInfo) error {
*out = loops
return nil
}

/*
<<< VISOR MANAGEMENT >>>
*/

// exitDelay is how long Restart waits after a successful start of the new
// instance before exiting the current process.
// NOTE(review): presumably this gives the RPC reply time to be sent — confirm.
const exitDelay = 100 * time.Millisecond

// Restart restarts visor.
func (r *RPC) Restart(_ *struct{}, _ *struct{}) (err error) {
defer func() {
if err == nil {
go func() {
time.Sleep(exitDelay)
os.Exit(0)
}()
}
}()

if r.node.restartCtx == nil {
nkryuchkov marked this conversation as resolved.
Show resolved Hide resolved
return ErrMalformedRestartContext
}

return r.node.restartCtx.Start()
}
Loading