diff --git a/cmd/launcher/launcher.go b/cmd/launcher/launcher.go index 139e56917..beef2cca2 100644 --- a/cmd/launcher/launcher.go +++ b/cmd/launcher/launcher.go @@ -178,20 +178,6 @@ func runSocket(args []string) error { return nil } -func runSubcommands() error { - var run func([]string) error - switch os.Args[1] { - case "socket": - run = runSocket - case "query": - run = runQuery - case "flare": - run = runFlare - } - err := run(os.Args[2:]) - return errors.Wrapf(err, "running subcommand %s", os.Args[1]) -} - // run the launcher daemon func runLauncher(ctx context.Context, cancel func(), opts *options, logger log.Logger) error { // determine the root directory, create one if it's not provided diff --git a/cmd/launcher/main.go b/cmd/launcher/main.go index 53f44b5f4..5dd3cdb75 100644 --- a/cmd/launcher/main.go +++ b/cmd/launcher/main.go @@ -1,5 +1,3 @@ -// +build !windows - package main import ( @@ -10,7 +8,6 @@ import ( "github.com/go-kit/kit/log" "github.com/go-kit/kit/log/level" "github.com/kolide/kit/logutil" - "github.com/kolide/kit/version" "github.com/pkg/errors" ) @@ -27,24 +24,12 @@ func main() { } } - opts, err := parseOptions() + opts, err := parseOptions(os.Args[1:]) if err != nil { level.Info(logger).Log("err", err) os.Exit(1) } - // handle --version - if opts.printVersion { - version.PrintFull() - os.Exit(0) - } - - // handle --usage - if opts.developerUsage { - developerUsage() - os.Exit(0) - } - logger = logutil.NewServerLogger(opts.debug) ctx, cancel := context.WithCancel(context.Background()) @@ -56,7 +41,7 @@ func main() { } func isSubCommand() bool { - if len(os.Args) != 2 { + if len(os.Args) < 2 { return false } @@ -64,6 +49,8 @@ func isSubCommand() bool { "socket", "query", "flare", + "svc", + "svc-fg", } for _, sc := range subCommands { @@ -74,3 +61,21 @@ func isSubCommand() bool { return false } + +func runSubcommands() error { + var run func([]string) error + switch os.Args[1] { + case "socket": + run = runSocket + case "query": + run = runQuery + case "flare": + run = runFlare + case "svc": + run = runWindowsSvc + case "svc-fg": + run = runWindowsSvcForeground + } + err := run(os.Args[2:]) + return errors.Wrapf(err, "running subcommand %s", os.Args[1]) +} diff --git a/cmd/launcher/main_windows.go b/cmd/launcher/main_windows.go deleted file mode 100644 index fa35c62d3..000000000 --- a/cmd/launcher/main_windows.go +++ /dev/null @@ -1,268 +0,0 @@ -// +build windows - -package main - -import ( - "context" - "fmt" - "os" - "path/filepath" - "strings" - "time" - - "github.com/go-kit/kit/log" - "github.com/go-kit/kit/log/level" - "github.com/kolide/kit/logutil" - "github.com/kolide/kit/version" - "github.com/pkg/errors" - "golang.org/x/sys/windows/svc" - "golang.org/x/sys/windows/svc/debug" - "golang.org/x/sys/windows/svc/mgr" - - "github.com/kolide/launcher/pkg/log/eventlog" -) - -const serviceName = "launcher" - -func main() { - - var logger log.Logger - logger = logutil.NewCLILogger(true) //interactive - - isIntSess, err := svc.IsAnInteractiveSession() - if err != nil { - logutil.Fatal(logger, "err", errors.Wrap(err, "cannot determine if session is interactive")) - - } - - run := debug.Run - if !isIntSess { - w, err := eventlog.NewWriter(serviceName) - if err != nil { - logutil.Fatal(logger, "err", errors.Wrap(err, "create eventlog writer")) - } - defer w.Close() - logger = eventlog.New(w) - level.Debug(logger).Log("msg", "daemonized service start requested") - run = svc.Run - } - - if isSubCommand() { - switch strings.ToLower(os.Args[1]) { - case "install": - err = installService(serviceName, "Kolide Osquery Launcher") - case "remove": - err = removeService(serviceName) - case "start": - err = startService(serviceName) - case "stop": - err = controlService(serviceName, svc.Stop, svc.Stopped) - } - if err != nil { - logutil.Fatal(logger, "err", errors.Wrap(err, "run")) - } - return - } - - opts, err := parseOptions() - if err != nil { - level.Info(logger).Log("err", err) - os.Exit(1) - } - - // handle --version - if opts.printVersion { - version.PrintFull() - os.Exit(0) - } - - // handle --usage - if opts.developerUsage { - developerUsage() - os.Exit(0) - } - - err = run(serviceName, &winSvc{logger: logger, opts: opts}) - if err != nil { - logutil.Fatal(logger, "err", errors.Wrap(err, "run")) - } - -} - -type winSvc struct { - logger log.Logger - opts *options -} - -func (w *winSvc) Execute(args []string, r <-chan svc.ChangeRequest, changes chan<- svc.Status) (ssec bool, errno uint32) { - const cmdsAccepted = svc.AcceptStop | svc.AcceptShutdown - changes <- svc.Status{State: svc.StartPending} - level.Debug(w.logger).Log("msg", "windows service starting") - changes <- svc.Status{State: svc.Running, Accepts: cmdsAccepted} - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - go func() { - err := runLauncher(ctx, cancel, w.opts, w.logger) - if err != nil { - level.Info(w.logger).Log("err", err) - changes <- svc.Status{State: svc.Stopped, Accepts: cmdsAccepted} - os.Exit(1) - } - }() - - for { - select { - case c := <-r: - switch c.Cmd { - case svc.Interrogate: - changes <- c.CurrentStatus - // Testing deadlock from https://code.google.com/p/winsvc/issues/detail?id=4 - time.Sleep(100 * time.Millisecond) - changes <- c.CurrentStatus - case svc.Stop, svc.Shutdown: - changes <- svc.Status{State: svc.StopPending} - cancel() - time.Sleep(100 * time.Millisecond) - changes <- svc.Status{State: svc.Stopped, Accepts: cmdsAccepted} - return - default: - level.Info(w.logger).Log("err", "unexpected control request", "control_request", c) - } - } - } -} - -func exePath() (string, error) { - prog := os.Args[0] - p, err := filepath.Abs(prog) - if err != nil { - return "", err - } - fi, err := os.Stat(p) - if err == nil { - if !fi.Mode().IsDir() { - return p, nil - } - err = fmt.Errorf("%s is directory", p) - } - if filepath.Ext(p) == "" { - p += ".exe" - fi, err := os.Stat(p) - if err == nil { - if !fi.Mode().IsDir() { - return p, nil - } - err = fmt.Errorf("%s is directory", p) - } - } - return "", err -} - -func isSubCommand() bool { - if len(os.Args) > 2 { - return false - } - - subCommands := []string{ - "install", - "remove", - "start", - "stop", - } - - for _, sc := range subCommands { - if sc == os.Args[1] { - return true - } - } - - return false -} - -func installService(name, desc string) error { - exepath, err := exePath() - if err != nil { - return err - } - m, err := mgr.Connect() - if err != nil { - return err - } - defer m.Disconnect() - s, err := m.OpenService(name) - if err == nil { - s.Close() - return fmt.Errorf("service %s already exists", name) - } - cfg := mgr.Config{DisplayName: desc, StartType: mgr.StartAutomatic} - s, err = m.CreateService(name, exepath, cfg, "is", "auto-started") - if err != nil { - return err - } - defer s.Close() - return nil -} - -func removeService(name string) error { - m, err := mgr.Connect() - if err != nil { - return err - } - defer m.Disconnect() - s, err := m.OpenService(name) - if err != nil { - return fmt.Errorf("service %s is not installed", name) - } - defer s.Close() - err = s.Delete() - return err -} - -func startService(name string) error { - m, err := mgr.Connect() - if err != nil { - return err - } - defer m.Disconnect() - s, err := m.OpenService(name) - if err != nil { - return fmt.Errorf("could not access service: %v", err) - } - defer s.Close() - err = s.Start("is", "manual-started") - if err != nil { - return fmt.Errorf("could not start service: %v", err) - } - return nil -} - -func controlService(name string, c svc.Cmd, to svc.State) error { - m, err := mgr.Connect() - if err != nil { - return err - } - defer m.Disconnect() - s, err := m.OpenService(name) - if err != nil { - return fmt.Errorf("could not access service: %v", err) - } - defer s.Close() - status, err := s.Control(c) - if err != nil { - return fmt.Errorf("could not send control=%d: %v", c, err) - } - timeout := time.Now().Add(10 * time.Second) - for status.State != to { - if timeout.Before(time.Now()) { - return fmt.Errorf("timeout waiting for service to go to state=%d", to) - } - time.Sleep(300 * time.Millisecond) - status, err = s.Query() - if err != nil { - return fmt.Errorf("could not retrieve service status: %v", err) - } - } - return nil -} diff --git a/cmd/launcher/options.go b/cmd/launcher/options.go index a7e16c792..42b616cad 100644 --- a/cmd/launcher/options.go +++ b/cmd/launcher/options.go @@ -9,9 +9,9 @@ import ( "strings" "time" - "github.com/kolide/kit/env" "github.com/kolide/kit/version" "github.com/kolide/launcher/pkg/autoupdate" + "github.com/peterbourgon/ff" "github.com/pkg/errors" ) @@ -33,8 +33,6 @@ type options struct { getShellsInterval time.Duration autoupdate bool - printVersion bool - developerUsage bool debug bool disableControlTLS bool insecureTLS bool @@ -52,141 +50,60 @@ const ( // parseOptions parses the options that may be configured via command-line flags // and/or environment variables, determines order of precedence and returns a // typed struct of options for further application use -func parseOptions() (*options, error) { - var ( - // Primary options - flRootDirectory = flag.String( - "root_directory", - env.String("KOLIDE_LAUNCHER_ROOT_DIRECTORY", ""), - "The location of the local database, pidfiles, etc.", - ) - flKolideServerURL = flag.String( - "hostname", - env.String("KOLIDE_LAUNCHER_HOSTNAME", ""), - "The hostname of the gRPC server", - ) +func parseOptions(args []string) (*options, error) { - flControl = flag.Bool( - "control", - env.Bool("KOLIDE_CONTROL", false), - "Whether or not the control server is enabled (default: false)", - ) - flControlServerURL = flag.String( - "control_hostname", - env.String("KOLIDE_CONTROL_HOSTNAME", ""), - "The hostname of the control server", - ) - flGetShellsInterval = flag.Duration( - "control_get_shells_interval", - env.Duration("KOLIDE_CONTROL_GET_SHELLS_INTERVAL", 3*time.Second), - "The interval at which the get shells request will be made", - ) + flagset := flag.NewFlagSet("launcher", flag.ExitOnError) + flagset.Usage = func() { usage(flagset) } - flEnrollSecret = flag.String( - "enroll_secret", - env.String("KOLIDE_LAUNCHER_ENROLL_SECRET", ""), - "The enroll secret that is used in your environment", - ) - flEnrollSecretPath = flag.String( - "enroll_secret_path", - env.String("KOLIDE_LAUNCHER_ENROLL_SECRET_PATH", ""), - "Optionally, the path to your enrollment secret", - ) - flOsquerydPath = flag.String( - "osqueryd_path", - env.String("KOLIDE_LAUNCHER_OSQUERYD_PATH", ""), - "Path to the osqueryd binary to use (Default: find osqueryd in $PATH)", - ) - flCertPins = flag.String( - "cert_pins", - env.String("KOLIDE_LAUNCHER_CERT_PINS", ""), - "Comma separated, hex encoded SHA256 hashes of pinned subject public key info", - ) - flRootPEM = flag.String( - "root_pem", - env.String("KOLIDE_LAUNCHER_ROOT_PEM", ""), - "Path to PEM file including root certificates to verify against", - ) - flLoggingInterval = flag.Duration( - "logging_interval", - env.Duration("KOLIDE_LAUNCHER_LOGGING_INTERVAL", 60*time.Second), - "The interval at which logs should be flushed to the server", - ) + var ( + // Primary options + flCertPins = flagset.String("cert_pins", "", "Comma separated, hex encoded SHA256 hashes of pinned subject public key info") + flControl = flagset.Bool("control", false, "Whether or not the control server is enabled (default: false)") + flControlServerURL = flagset.String("control_hostname", "", "The hostname of the control server") + flEnrollSecret = flagset.String("enroll_secret", "", "The enroll secret that is used in your environment") + flEnrollSecretPath = flagset.String("enroll_secret_path", "", "Optionally, the path to your enrollment secret") + flGetShellsInterval = flagset.Duration("control_get_shells_interval", 3*time.Second, "The interval at which the 'get shells' request will be made") + flInitialRunner = flagset.Bool("with_initial_runner", false, "Run differential queries from config ahead of scheduled interval.") + flKolideServerURL = flagset.String("hostname", "", "The hostname of the gRPC server") + flLoggingInterval = flagset.Duration("logging_interval", 60*time.Second, "The interval at which logs should be flushed to the server") + flOsquerydPath = flagset.String("osqueryd_path", "", "Path to the osqueryd binary to use (Default: find osqueryd in $PATH)") + flRootDirectory = flagset.String("root_directory", "", "The location of the local database, pidfiles, etc.") + flRootPEM = flagset.String("root_pem", "", "Path to PEM file including root certificates to verify against") + flVersion = flagset.Bool("version", false, "Print Launcher version and exit") + _ = flagset.String("config", "", "config file to parse options from (optional)") // Autoupdate options - flAutoupdate = flag.Bool( - "autoupdate", - env.Bool("KOLIDE_LAUNCHER_AUTOUPDATE", false), - "Whether or not the osquery autoupdater is enabled (default: false)", - ) - flNotaryServerURL = flag.String( - "notary_url", - env.String("KOLIDE_LAUNCHER_NOTARY_SERVER_URL", autoupdate.DefaultNotary), - "The Notary update server (default: https://notary.kolide.co)", - ) - flMirrorURL = flag.String( - "mirror_url", - env.String("KOLIDE_LAUNCHER_MIRROR_SERVER_URL", autoupdate.DefaultMirror), - "The mirror server for autoupdates (default: https://dl.kolide.co)", - ) - flAutoupdateInterval = flag.Duration( - "autoupdate_interval", - duration("KOLIDE_LAUNCHER_AUTOUPDATE_INTERVAL", 1*time.Hour), - "The interval to check for updates (default: once every hour)", - ) - flUpdateChannel = flag.String( - "update_channel", - env.String("KOLIDE_LAUNCHER_UPDATE_CHANNEL", "stable"), - "The channel to pull updates from (options: stable, beta, nightly)", - ) + flAutoupdate = flagset.Bool("autoupdate", false, "Whether or not the osquery autoupdater is enabled (default: false)") + flNotaryServerURL = flagset.String("notary_url", autoupdate.DefaultNotary, "The Notary update server (default: https://notary.kolide.co)") + flMirrorURL = flagset.String("mirror_url", autoupdate.DefaultMirror, "The mirror server for autoupdates (default: https://dl.kolide.co)") + flAutoupdateInterval = flagset.Duration("autoupdate_interval", 1*time.Hour, "The interval to check for updates (default: once every hour)") + flUpdateChannel = flagset.String("update_channel", "stable", "The channel to pull updates from (options: stable, beta, nightly)") // Development options - flDebug = flag.Bool( - "debug", - env.Bool("KOLIDE_LAUNCHER_DEBUG", false), - "Whether or not debug logging is enabled (default: false)", - ) - flDisableControlTLS = flag.Bool( - "disable_control_tls", - env.Bool("KOLIDE_LAUNCHER_DISABLE_CONTROL_TLS", false), - "Disable TLS encryption for the control features", - ) - flInsecureTLS = flag.Bool( - "insecure", - env.Bool("KOLIDE_LAUNCHER_INSECURE", false), - "Do not verify TLS certs for outgoing connections (default: false)", - ) - flInsecureGRPC = flag.Bool( - "insecure_grpc", - env.Bool("KOLIDE_LAUNCHER_INSECURE_GRPC", false), - "Dial GRPC without a TLS config (default: false)", - ) - - // Version command: launcher --version - flVersion = flag.Bool( - "version", - env.Bool("KOLIDE_LAUNCHER_VERSION", false), - "Print Launcher version and exit", - ) - - // Developer usage - flDeveloperUsage = flag.Bool( - "dev_help", - env.Bool("KOLIDE_LAUNCHER_DEV_HELP", false), - "Print full Launcher help, including developer options", - ) + flDebug = flagset.Bool("debug", false, "Whether or not debug logging is enabled (default: false)") + flDeveloperUsage = flagset.Bool("dev_help", false, "Print full Launcher help, including developer options") + flDisableControlTLS = flagset.Bool("disable_control_tls", false, "Disable TLS encryption for the control features") + flInsecureGRPC = flagset.Bool("insecure_grpc", false, "Dial GRPC without a TLS config (default: false)") + flInsecureTLS = flagset.Bool("insecure", false, "Do not verify TLS certs for outgoing connections (default: false)") + ) - // Enable Initial Runner: launcher --with_initial_runner - flInitialRunner = flag.Bool( - "with_initial_runner", - env.Bool("KOLIDE_LAUNCHER_INITIAL_RUNNER", false), - "Run differential queries from config ahead of scheduled interval.", - ) + ff.Parse(flagset, args, + ff.WithConfigFileFlag("config"), + ff.WithConfigFileParser(ff.PlainParser), + ff.WithEnvVarPrefix("KOLIDE_LAUNCHER"), ) - flag.Usage = usage + // handle --version + if *flVersion { + version.PrintFull() + os.Exit(0) + } - flag.Parse() + // handle --usage + if *flDeveloperUsage { + developerUsage(flagset) + os.Exit(0) + } // if an osqueryd path was not set, it's likely that we want to use the bundled // osqueryd path, but if it cannot be found, we will fail back to using an @@ -237,8 +154,6 @@ func parseOptions() (*options, error) { loggingInterval: *flLoggingInterval, enableInitialRunner: *flInitialRunner, autoupdate: *flAutoupdate, - printVersion: *flVersion, - developerUsage: *flDeveloperUsage, debug: *flDebug, disableControlTLS: *flDisableControlTLS, insecureTLS: *flInsecureTLS, @@ -251,12 +166,12 @@ func parseOptions() (*options, error) { return opts, nil } -func shortUsage() { +func shortUsage(flagset *flag.FlagSet) { launcherFlags := map[string]string{} flagAggregator := func(f *flag.Flag) { launcherFlags[f.Name] = f.Usage } - flag.VisitAll(flagAggregator) + flagset.VisitAll(flagAggregator) printOpt := func(opt string) { fmt.Fprintf(os.Stderr, " --%s", opt) @@ -294,17 +209,17 @@ func shortUsage() { fmt.Fprintf(os.Stderr, "\n") } -func usage() { - shortUsage() +func usage(flagset *flag.FlagSet) { + shortUsage(flagset) usageFooter() } -func developerUsage() { +func developerUsage(flagset *flag.FlagSet) { launcherFlags := map[string]string{} flagAggregator := func(f *flag.Flag) { launcherFlags[f.Name] = f.Usage } - flag.VisitAll(flagAggregator) + flagset.VisitAll(flagAggregator) printOpt := func(opt string) { fmt.Fprintf(os.Stderr, " --%s", opt) @@ -314,7 +229,8 @@ func developerUsage() { fmt.Fprintf(os.Stderr, "%s\n", launcherFlags[opt]) } - shortUsage() + shortUsage(flagset) + fmt.Fprintf(os.Stderr, "\n") fmt.Fprintf(os.Stderr, "Development Options:\n") fmt.Fprintf(os.Stderr, "\n") @@ -341,19 +257,6 @@ func usageFooter() { fmt.Fprintf(os.Stderr, "\n") } -// TODO: move to kolide/kit and figure out error handling there. -func duration(key string, def time.Duration) time.Duration { - if env, ok := os.LookupEnv(key); ok { - t, err := time.ParseDuration(env) - if err != nil { - fmt.Println("env: parse duration flag: ", err) - os.Exit(1) - } - return t - } - return def -} - func parseCertPins(pins string) ([][]byte, error) { var certPins [][]byte if pins != "" { diff --git a/cmd/launcher/options_test.go b/cmd/launcher/options_test.go new file mode 100644 index 000000000..3e85dfc10 --- /dev/null +++ b/cmd/launcher/options_test.go @@ -0,0 +1,109 @@ +package main + +import ( + "fmt" + "io/ioutil" + "math/rand" + "os" + "strings" + "testing" + "time" + + "github.com/kolide/kit/stringutil" + "github.com/stretchr/testify/require" +) + +// TestOptionsFromFlags isn't parallel to ensure that we don't pollute the environment +func TestOptionsFromFlags(t *testing.T) { + os.Clearenv() + + testArgs, expectedOpts := getArgsAndResponse() + + testFlags := []string{} + for k, v := range testArgs { + testFlags = append(testFlags, k) + if v != "" { + testFlags = append(testFlags, v) + } + } + + opts, err := parseOptions(testFlags) + require.NoError(t, err) + require.Equal(t, expectedOpts, opts) +} + +func TestOptionsFromEnv(t *testing.T) { + os.Clearenv() + + testArgs, expectedOpts := getArgsAndResponse() + + for k, val := range testArgs { + if val == "" { + val = "true" + } + name := fmt.Sprintf("KOLIDE_LAUNCHER_%s", strings.ToUpper(strings.TrimLeft(k, "-"))) + require.NoError(t, os.Setenv(name, val)) + } + opts, err := parseOptions([]string{}) + require.NoError(t, err) + require.Equal(t, expectedOpts, opts) +} + +func TestOptionsFromFile(t *testing.T) { + os.Clearenv() + + testArgs, expectedOpts := getArgsAndResponse() + + flagFile, err := ioutil.TempFile("", "flag-file") + require.NoError(t, err) + defer os.Remove(flagFile.Name()) + + for k, val := range testArgs { + var err error + + _, err = flagFile.WriteString(strings.TrimLeft(k, "-")) + require.NoError(t, err) + + if val != "" { + _, err = flagFile.WriteString(fmt.Sprintf(" %s", val)) + require.NoError(t, err) + } + + _, err = flagFile.WriteString("\n") + require.NoError(t, err) + } + + require.NoError(t, flagFile.Close()) + + opts, err := parseOptions([]string{"-config", flagFile.Name()}) + require.NoError(t, err) + require.Equal(t, expectedOpts, opts) +} + +func getArgsAndResponse() (map[string]string, *options) { + randomHostname := fmt.Sprintf("%s.example.com", stringutil.RandomString(8)) + randomInt := rand.Intn(1024) + + // includes both `-` and `--` for variety. + args := map[string]string{ + "-control": "", // This is a bool, it's special cased in the test routines + "--hostname": randomHostname, + "-autoupdate_interval": "48h", + "-logging_interval": fmt.Sprintf("%ds", randomInt), + "-osqueryd_path": "/dev/null", + } + + opts := &options{ + control: true, + osquerydPath: "/dev/null", + kolideServerURL: randomHostname, + getShellsInterval: 3 * time.Second, + loggingInterval: time.Duration(randomInt) * time.Second, + autoupdateInterval: 48 * time.Hour, + notaryServerURL: "https://notary.kolide.co", + mirrorServerURL: "https://dl.kolide.co", + updateChannel: "stable", + } + + return args, opts +} diff --git a/cmd/launcher/svc.go b/cmd/launcher/svc.go new file mode 100644 index 000000000..70f06bf7b --- /dev/null +++ b/cmd/launcher/svc.go @@ -0,0 +1,20 @@ +// +build !windows + +package main + +import ( + "fmt" + "os" +) + +func runWindowsSvc(args []string) error { + fmt.Println("This isn't windows") + os.Exit(1) + return nil +} + +func runWindowsSvcForeground(args []string) error { + fmt.Println("This isn't windows") + os.Exit(1) + return nil +} diff --git a/cmd/launcher/svc_windows.go b/cmd/launcher/svc_windows.go new file mode 100644 index 000000000..97125ed3e --- /dev/null +++ b/cmd/launcher/svc_windows.go @@ -0,0 +1,103 @@ +// +build windows + +package main + +import ( + "context" + "os" + "time" + + "github.com/go-kit/kit/log" + "github.com/go-kit/kit/log/level" + "github.com/kolide/kit/logutil" + "github.com/kolide/launcher/pkg/log/eventlog" + "github.com/pkg/errors" + "golang.org/x/sys/windows/svc" + "golang.org/x/sys/windows/svc/debug" +) + +// TODO This should be inherited from some setting +const serviceName = "launcher" + +// runWindowsSvc starts launcher as a windows service. This will +// probably not behave correctly if you start it from the command line. +func runWindowsSvc(args []string) error { + eventLogWriter, err := eventlog.NewWriter(serviceName) + if err != nil { + return errors.Wrap(err, "create eventlog writer") + } + defer eventLogWriter.Close() + + logger := eventlog.New(eventLogWriter) + level.Debug(logger).Log("msg", "service start requested") + + opts, err := parseOptions(os.Args[2:]) + if err != nil { + level.Info(logger).Log("err", err) + os.Exit(1) + } + + run := svc.Run + + return run(serviceName, &winSvc{logger: logger, opts: opts}) +} + +func runWindowsSvcForeground(args []string) error { + logger := logutil.NewCLILogger(true) //interactive + level.Debug(logger).Log("msg", "foreground service start requested (debug mode)") + + opts, err := parseOptions(os.Args[2:]) + if err != nil { + level.Info(logger).Log("err", err) + os.Exit(1) + } + + run := debug.Run + + return run(serviceName, &winSvc{logger: logger, opts: opts}) +} + +type winSvc struct { + logger log.Logger + opts *options +} + +func (w *winSvc) Execute(args []string, r <-chan svc.ChangeRequest, changes chan<- svc.Status) (ssec bool, errno uint32) { + const cmdsAccepted = svc.AcceptStop | svc.AcceptShutdown + changes <- svc.Status{State: svc.StartPending} + level.Debug(w.logger).Log("msg", "windows service starting") + changes <- svc.Status{State: svc.Running, Accepts: cmdsAccepted} + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { + err := runLauncher(ctx, cancel, w.opts, w.logger) + if err != nil { + level.Info(w.logger).Log("err", err) + changes <- svc.Status{State: svc.Stopped, Accepts: cmdsAccepted} + os.Exit(1) + } + }() + + for { + select { + case c := <-r: + switch c.Cmd { + case svc.Interrogate: + changes <- c.CurrentStatus + // Testing deadlock from https://code.google.com/p/winsvc/issues/detail?id=4 + time.Sleep(100 * time.Millisecond) + changes <- c.CurrentStatus + case svc.Stop, svc.Shutdown: + changes <- svc.Status{State: svc.StopPending} + cancel() + time.Sleep(100 * time.Millisecond) + changes <- svc.Status{State: svc.Stopped, Accepts: cmdsAccepted} + return + default: + level.Info(w.logger).Log("err", "unexpected control request", "control_request", c) + } + } + } +} diff --git a/go.mod b/go.mod index 0a7871506..97b1cc2d0 100644 --- a/go.mod +++ b/go.mod @@ -18,6 +18,7 @@ require ( github.com/bugsnag/panicwrap v1.2.0 // indirect github.com/cenkalti/backoff v2.0.0+incompatible // indirect github.com/cloudflare/cfssl v0.0.0-20181102015659-ea4033a214e7 // indirect + github.com/davecgh/go-spew v1.1.1 github.com/denisenkom/go-mssqldb v0.0.0-20181014144952-4e0d7dc8888f // indirect github.com/docker/distribution v2.6.2+incompatible // indirect github.com/docker/go-connections v0.4.0 // indirect @@ -49,6 +50,7 @@ require ( github.com/oklog/run v1.0.0 github.com/onsi/ginkgo v1.7.0 // indirect github.com/onsi/gomega v1.4.3 // indirect + github.com/peterbourgon/ff v1.1.0 github.com/pkg/errors v0.8.1 github.com/prometheus/client_golang v0.9.2 // indirect github.com/sirupsen/logrus v1.2.0 // indirect diff --git a/go.sum b/go.sum index 7d7302e31..10268e954 100644 --- a/go.sum +++ b/go.sum @@ -151,6 +151,8 @@ github.com/opencensus-integrations/ocsql v0.1.1/go.mod h1:ozPYpNVBHZsX33jfoQPO5T github.com/openzipkin/zipkin-go v0.1.1/go.mod h1:NtoC/o8u3JlF1lSlyPNswIbeQH9bJTmOf0Erfk+hxe8= github.com/pelletier/go-toml v1.2.0 h1:T5zMGML61Wp+FlcbWjRDT7yAxhJNAiPPLOFECq181zc= github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic= +github.com/peterbourgon/ff v1.1.0 h1:mz0/c1gtUwz/MsyJs3GwPjFSTFlD4UEwKEPw5Z+0cm8= +github.com/peterbourgon/ff v1.1.0/go.mod h1:P2CDTJbjip+iAWS2jnkp6uKBX2Bcvc/K4p5j47CTRv0= github.com/pkg/errors v0.8.0 h1:WdK/asTD0HN+q6hsWO3/vpuAkAr+tw6aNJNDFFf0+qw= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= diff --git a/pkg/log/eventlog/writer_windows.go b/pkg/log/eventlog/writer_windows.go index ed855fa32..ebbe425c0 100644 --- a/pkg/log/eventlog/writer_windows.go +++ b/pkg/log/eventlog/writer_windows.go @@ -53,5 +53,8 @@ func (w *Writer) Write(p []byte) (n int, err error) { } func isAlreadyExists(err error) bool { + if err == nil { + return false + } return strings.Contains(err.Error(), "registry key already exists") } diff --git a/pkg/packaging/packaging.go b/pkg/packaging/packaging.go index 3359e0777..9ca463eb3 100644 --- a/pkg/packaging/packaging.go +++ b/pkg/packaging/packaging.go @@ -106,7 +106,7 @@ func (p *PackageOptions) Build(ctx context.Context, packageWriter io.Writer, tar } if p.Control && p.ControlHostname != "" { - launcherEnv["KOLIDE_CONTROL_HOSTNAME"] = p.ControlHostname + launcherEnv["KOLIDE_LAUNCHER_CONTROL_HOSTNAME"] = p.ControlHostname } if p.Autoupdate && p.UpdateChannel != "" {