Skip to content

Commit

Permalink
Cancel Cobra parent context on interrupt
Browse files Browse the repository at this point in the history
  • Loading branch information
phillebaba committed May 31, 2024
1 parent ff83e19 commit b1dde93
Show file tree
Hide file tree
Showing 8 changed files with 51 additions and 75 deletions.
8 changes: 7 additions & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
package main

import (
"context"
"embed"
"os/signal"
"syscall"

"github.com/defenseunicorns/zarf/src/cmd"
"github.com/defenseunicorns/zarf/src/config"
Expand All @@ -19,7 +22,10 @@ var cosignPublicKey string
var zarfSchema embed.FS

func main() {
ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
defer cancel()

config.CosignPublicKey = cosignPublicKey
lint.ZarfSchema = zarfSchema
cmd.Execute()
cmd.Execute(ctx)
}
2 changes: 0 additions & 2 deletions src/cmd/common/setup.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@ var LogLevelCLI string

// SetupCLI sets up the CLI logging, interrupt functions, and more
func SetupCLI() {
ExitOnInterrupt()

match := map[string]message.LogLevel{
"warn": message.WarnLevel,
"info": message.InfoLevel,
Expand Down
19 changes: 0 additions & 19 deletions src/cmd/common/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,11 @@ package common

import (
"context"
"os"
"os/signal"
"syscall"

"github.com/defenseunicorns/zarf/src/config/lang"
"github.com/defenseunicorns/zarf/src/pkg/cluster"
"github.com/defenseunicorns/zarf/src/pkg/message"
)

// SuppressGlobalInterrupt suppresses the global error on an interrupt
var SuppressGlobalInterrupt = false

// SetBaseDirectory sets the base directory. This is a directory with a zarf.yaml.
func SetBaseDirectory(args []string) string {
if len(args) > 0 {
Expand All @@ -26,18 +19,6 @@ func SetBaseDirectory(args []string) string {
return "."
}

// ExitOnInterrupt catches an interrupt and exits with fatal error
func ExitOnInterrupt() {
c := make(chan os.Signal, 1)
signal.Notify(c, os.Interrupt, syscall.SIGTERM)
go func() {
<-c
if !SuppressGlobalInterrupt {
message.Fatal(lang.ErrInterrupt, lang.ErrInterrupt.Error())
}
}()
}

// NewClusterOrDie creates a new Cluster instance and waits for the cluster to be ready or throws a fatal error.
func NewClusterOrDie(ctx context.Context) *cluster.Cluster {
timeoutCtx, cancel := context.WithTimeout(ctx, cluster.DefaultTimeout)
Expand Down
11 changes: 2 additions & 9 deletions src/cmd/connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@ package cmd
import (
"fmt"
"os"
"os/signal"
"syscall"

"github.com/defenseunicorns/zarf/src/cmd/common"
"github.com/defenseunicorns/zarf/src/config/lang"
Expand Down Expand Up @@ -72,17 +70,12 @@ var (
}
}

// Keep this open until an interrupt signal is received.
interruptChan := make(chan os.Signal, 1)
signal.Notify(interruptChan, os.Interrupt, syscall.SIGTERM)
common.SuppressGlobalInterrupt = true

// Wait for the interrupt signal or an error.
select {
case <-ctx.Done():
spinner.Successf(lang.CmdConnectTunnelClosed, url)
case err = <-tunnel.ErrChan():
spinner.Fatalf(err, lang.CmdConnectErrService, err.Error())
case <-interruptChan:
spinner.Successf(lang.CmdConnectTunnelClosed, url)
}
os.Exit(0)
},
Expand Down
8 changes: 4 additions & 4 deletions src/cmd/internal.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,17 +38,17 @@ var agentCmd = &cobra.Command{
Use: "agent",
Short: lang.CmdInternalAgentShort,
Long: lang.CmdInternalAgentLong,
Run: func(_ *cobra.Command, _ []string) {
agent.StartWebhook()
RunE: func(cmd *cobra.Command, _ []string) error {
return agent.StartWebhook(cmd.Context())
},
}

var httpProxyCmd = &cobra.Command{
Use: "http-proxy",
Short: lang.CmdInternalProxyShort,
Long: lang.CmdInternalProxyLong,
Run: func(_ *cobra.Command, _ []string) {
agent.StartHTTPProxy()
RunE: func(cmd *cobra.Command, _ []string) error {
return agent.StartHTTPProxy(cmd.Context())
},
}

Expand Down
20 changes: 8 additions & 12 deletions src/cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,15 @@ import (
"os"
"strings"

"github.com/spf13/cobra"

"github.com/defenseunicorns/zarf/src/cmd/common"
"github.com/defenseunicorns/zarf/src/cmd/tools"
"github.com/defenseunicorns/zarf/src/config"
"github.com/defenseunicorns/zarf/src/config/lang"
"github.com/defenseunicorns/zarf/src/pkg/layout"
"github.com/defenseunicorns/zarf/src/pkg/message"
"github.com/defenseunicorns/zarf/src/types"
"github.com/spf13/cobra"
)

var (
Expand All @@ -32,21 +33,16 @@ var rootCmd = &cobra.Command{
if common.CheckVendorOnlyFromPath(cmd) {
return
}

// Don't log the help command
if cmd.Parent() == nil {
config.SkipLogFile = true
}

// Set the global context for the root command and all child commands
ctx := context.Background()
cmd.SetContext(ctx)

common.SetupCLI()
},
Short: lang.RootCmdShort,
Long: lang.RootCmdLong,
Args: cobra.MaximumNArgs(1),
Short: lang.RootCmdShort,
Long: lang.RootCmdLong,
Args: cobra.MaximumNArgs(1),
SilenceUsage: true,
Run: func(cmd *cobra.Command, args []string) {
zarfLogo := message.GetLogo()
_, _ = fmt.Fprintln(os.Stderr, zarfLogo)
Expand All @@ -64,8 +60,8 @@ var rootCmd = &cobra.Command{
}

// Execute is the entrypoint for the CLI.
func Execute() {
cobra.CheckErr(rootCmd.Execute())
func Execute(ctx context.Context) {
rootCmd.ExecuteContext(ctx)
}

func init() {
Expand Down
4 changes: 0 additions & 4 deletions src/cmd/tools/crane.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"strings"

"github.com/AlecAivazis/survey/v2"
"github.com/defenseunicorns/zarf/src/cmd/common"
"github.com/defenseunicorns/zarf/src/config"
"github.com/defenseunicorns/zarf/src/config/lang"
"github.com/defenseunicorns/zarf/src/internal/packager/images"
Expand Down Expand Up @@ -39,9 +38,6 @@ func init() {
Aliases: []string{"r", "crane"},
Short: lang.CmdToolsRegistryShort,
PersistentPreRun: func(cmd *cobra.Command, _ []string) {

common.ExitOnInterrupt()

// The crane options loading here comes from the rootCmd of crane
craneOptions = append(craneOptions, crane.WithContext(cmd.Context()))
// TODO(jonjohnsonjr): crane.Verbose option?
Expand Down
54 changes: 30 additions & 24 deletions src/internal/agent/start.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@ package agent

import (
"context"
"errors"
"net/http"
"os"
"os/signal"
"syscall"
"time"

"golang.org/x/sync/errgroup"

"github.com/defenseunicorns/zarf/src/config/lang"
agentHttp "github.com/defenseunicorns/zarf/src/internal/agent/http"
Expand All @@ -27,35 +28,40 @@ const (
)

// StartWebhook launches the Zarf agent mutating webhook in the cluster.
func StartWebhook() {
func StartWebhook(ctx context.Context) error {
message.Debug("agent.StartWebhook()")

startServer(agentHttp.NewAdmissionServer(httpPort))
return startServer(ctx, agentHttp.NewAdmissionServer(httpPort))
}

// StartHTTPProxy launches the zarf agent proxy in the cluster.
func StartHTTPProxy() {
func StartHTTPProxy(ctx context.Context) error {
message.Debug("agent.StartHttpProxy()")

startServer(agentHttp.NewProxyServer(httpPort))
return startServer(ctx, agentHttp.NewProxyServer(httpPort))
}

func startServer(server *http.Server) {
go func() {
if err := server.ListenAndServeTLS(tlsCert, tlsKey); err != nil && err != http.ErrServerClosed {
message.Fatal(err, lang.AgentErrStart)
func startServer(ctx context.Context, srv *http.Server) error {
g, gCtx := errgroup.WithContext(ctx)
g.Go(func() error {
err := srv.ListenAndServeTLS(tlsCert, tlsKey)
if err != nil && !errors.Is(err, http.ErrServerClosed) {
return err
}
}()

return nil
})
g.Go(func() error {
<-gCtx.Done()
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
err := srv.Shutdown(ctx)
if err != nil {
return err
}
return nil
})
message.Infof(lang.AgentInfoPort, httpPort)

// listen shutdown signal
signalChan := make(chan os.Signal, 1)
signal.Notify(signalChan, syscall.SIGINT, syscall.SIGTERM)
<-signalChan

message.Infof(lang.AgentInfoShutdown)
if err := server.Shutdown(context.Background()); err != nil {
message.Fatal(err, lang.AgentErrShutdown)
err := g.Wait()
if err != nil {
return err
}
return nil
}

0 comments on commit b1dde93

Please sign in to comment.