diff --git a/cmd/launcher/desktop.go b/cmd/launcher/desktop.go index a397fadb0..ea13e515e 100644 --- a/cmd/launcher/desktop.go +++ b/cmd/launcher/desktop.go @@ -130,7 +130,9 @@ func runDesktop(_ *multislogger.MultiSlogger, args []string) error { }, func(error) {}) shutdownChan := make(chan struct{}) - server, err := userserver.New(slogger, *flUserServerAuthToken, *flUserServerSocketPath, shutdownChan, notifier) + showDesktopChan := make(chan struct{}) + + server, err := userserver.New(slogger, *flUserServerAuthToken, *flUserServerSocketPath, shutdownChan, showDesktopChan, notifier) if err != nil { return err } @@ -182,9 +184,10 @@ func runDesktop(_ *multislogger.MultiSlogger, args []string) error { } }() + // block until a send on showDesktopChan + <-showDesktopChan // blocks until shutdown called m.Init() - return nil } diff --git a/cmd/launcher/launcher.go b/cmd/launcher/launcher.go index 9d39daa44..dea7592c3 100644 --- a/cmd/launcher/launcher.go +++ b/cmd/launcher/launcher.go @@ -504,6 +504,7 @@ func runLauncher(ctx context.Context, cancel func(), multiSlogger, systemMultiSl ls, err := localserver.New( ctx, k, + runner, ) if err != nil { diff --git a/ee/desktop/runner/runner.go b/ee/desktop/runner/runner.go index db02aaf98..5c560f995 100644 --- a/ee/desktop/runner/runner.go +++ b/ee/desktop/runner/runner.go @@ -29,6 +29,7 @@ import ( "github.com/kolide/launcher/ee/desktop/user/client" "github.com/kolide/launcher/ee/desktop/user/menu" "github.com/kolide/launcher/ee/desktop/user/notify" + "github.com/kolide/launcher/ee/presencedetection" "github.com/kolide/launcher/ee/ui/assets" "github.com/kolide/launcher/pkg/backoff" "github.com/kolide/launcher/pkg/rungroup" @@ -118,9 +119,6 @@ type DesktopUsersProcessesRunner struct { // usersFilesRoot is the launcher root dir with will be the parent dir // for kolide desktop files on a per user basis usersFilesRoot string - // processSpawningEnabled controls whether or not desktop user processes are automatically spawned - // This effectively represents whether or not the launcher desktop GUI is enabled or not - processSpawningEnabled bool // knapsack is the almighty sack of knaps knapsack types.Knapsack // runnerServer is a local server that desktop processes call to monitor parent @@ -155,17 +153,16 @@ func (pr processRecord) String() string { // New creates and returns a new DesktopUsersProcessesRunner runner and initializes all required fields func New(k types.Knapsack, messenger runnerserver.Messenger, opts ...desktopUsersProcessesRunnerOption) (*DesktopUsersProcessesRunner, error) { runner := &DesktopUsersProcessesRunner{ - interrupt: make(chan struct{}), - uidProcs: make(map[string]processRecord), - updateInterval: k.DesktopUpdateInterval(), - menuRefreshInterval: k.DesktopMenuRefreshInterval(), - procsWg: &sync.WaitGroup{}, - interruptTimeout: time.Second * 5, - hostname: k.KolideServerURL(), - usersFilesRoot: agent.TempPath("kolide-desktop"), - processSpawningEnabled: k.DesktopEnabled(), - knapsack: k, - cachedMenuData: newMenuItemCache(), + interrupt: make(chan struct{}), + uidProcs: make(map[string]processRecord), + updateInterval: k.DesktopUpdateInterval(), + menuRefreshInterval: k.DesktopMenuRefreshInterval(), + procsWg: &sync.WaitGroup{}, + interruptTimeout: time.Second * 5, + hostname: k.KolideServerURL(), + usersFilesRoot: agent.TempPath("kolide-desktop"), + knapsack: k, + cachedMenuData: newMenuItemCache(), } runner.slogger = k.Slogger().With("component", "desktop_runner") @@ -286,6 +283,29 @@ func (r *DesktopUsersProcessesRunner) Interrupt(_ error) { ) } +func (r *DesktopUsersProcessesRunner) DetectPresence(reason string, interval time.Duration) (time.Duration, error) { + if r.uidProcs == nil || len(r.uidProcs) == 0 { + return presencedetection.DetectionFailedDurationValue, errors.New("no desktop processes running") + } + + var lastErr error + var lastDurationSinceLastDetection time.Duration + + for _, proc := range r.uidProcs { + client := client.New(r.userServerAuthToken, proc.socketPath) + lastDurationSinceLastDetection, err := client.DetectPresence(reason, interval) + + if err != nil { + lastErr = err + continue + } + + return lastDurationSinceLastDetection, nil + } + + return lastDurationSinceLastDetection, fmt.Errorf("no desktop processes detected presence, last error: %w", lastErr) +} + // killDesktopProcesses kills any existing desktop processes func (r *DesktopUsersProcessesRunner) killDesktopProcesses(ctx context.Context) { wgDone := make(chan struct{}) @@ -452,12 +472,35 @@ func (r *DesktopUsersProcessesRunner) Update(data io.Reader) error { } func (r *DesktopUsersProcessesRunner) FlagsChanged(flagKeys ...keys.FlagKey) { - if slices.Contains(flagKeys, keys.DesktopEnabled) { - r.processSpawningEnabled = r.knapsack.DesktopEnabled() - r.slogger.Log(context.TODO(), slog.LevelDebug, - "runner processSpawningEnabled set by control server", - "process_spawning_enabled", r.processSpawningEnabled, - ) + if !slices.Contains(flagKeys, keys.DesktopEnabled) { + return + } + + r.slogger.Log(context.TODO(), slog.LevelDebug, + "desktop enabled set by control server", + "desktop_enabled", r.knapsack.DesktopEnabled(), + ) + + if !r.knapsack.DesktopEnabled() { + // there is no way to "hide" the menu, so we will just kill any existing processes + // they will respawn in "silent" mode + r.killDesktopProcesses(context.TODO()) + return + } + + // DesktopEnabled() == true + // Tell any running desktop user processes that they should show the menu + for uid, proc := range r.uidProcs { + client := client.New(r.userServerAuthToken, proc.socketPath) + if err := client.ShowDesktop(); err != nil { + r.slogger.Log(context.TODO(), slog.LevelError, + "sending refresh command to user desktop process", + "uid", uid, + "pid", proc.Process.Pid, + "path", proc.path, + "err", err, + ) + } } } @@ -483,6 +526,10 @@ func (r *DesktopUsersProcessesRunner) writeSharedFile(path string, data []byte) // refreshMenu updates the menu file and tells desktop processes to refresh their menus func (r *DesktopUsersProcessesRunner) refreshMenu() { + if !r.knapsack.DesktopEnabled() { + return + } + if err := r.generateMenuFile(); err != nil { if r.knapsack.DebugServerData() { r.slogger.Log(context.TODO(), slog.LevelError, @@ -503,7 +550,6 @@ func (r *DesktopUsersProcessesRunner) refreshMenu() { for uid, proc := range r.uidProcs { client := client.New(r.userServerAuthToken, proc.socketPath) if err := client.Refresh(); err != nil { - r.slogger.Log(context.TODO(), slog.LevelError, "sending refresh command to user desktop process", "uid", uid, @@ -601,12 +647,6 @@ func (r *DesktopUsersProcessesRunner) runConsoleUserDesktop() error { return nil } - if !r.processSpawningEnabled { - // Desktop is disabled, kill any existing desktop user processes - r.killDesktopProcesses(context.Background()) - return nil - } - executablePath, err := r.determineExecutablePath() if err != nil { return fmt.Errorf("determining executable path: %w", err) @@ -669,13 +709,23 @@ func (r *DesktopUsersProcessesRunner) spawnForUser(ctx context.Context, uid stri r.waitOnProcessAsync(uid, cmd.Process) client := client.New(r.userServerAuthToken, socketPath) - if err := backoff.WaitFor(client.Ping, 10*time.Second, 1*time.Second); err != nil { + + pingFunc := client.Ping + + // if the desktop is enabled, we want to show the desktop + // just perform this instead of ping to verify the desktop is running + // and show it right away + if r.knapsack.DesktopEnabled() { + pingFunc = client.ShowDesktop + } + + if err := backoff.WaitFor(pingFunc, 10*time.Second, 1*time.Second); err != nil { // unregister proc from desktop server so server will not respond to its requests r.runnerServer.DeRegisterClient(uid) if err := cmd.Process.Kill(); err != nil { r.slogger.Log(ctx, slog.LevelError, - "killing user desktop process after startup ping failed", + "killing user desktop process after startup ping / show desktop failed", "uid", uid, "pid", cmd.Process.Pid, "path", cmd.Path, diff --git a/ee/desktop/runner/runner_test.go b/ee/desktop/runner/runner_test.go index 211a9ace5..9b26444b1 100644 --- a/ee/desktop/runner/runner_test.go +++ b/ee/desktop/runner/runner_test.go @@ -153,7 +153,7 @@ func TestDesktopUserProcessRunner_Execute(t *testing.T) { }() // let it run a few intervals - time.Sleep(r.updateInterval * 3) + time.Sleep(r.updateInterval * 6) r.Interrupt(nil) user, err := user.Current() diff --git a/ee/desktop/user/client/client.go b/ee/desktop/user/client/client.go index 8fe1641b6..da90291dd 100644 --- a/ee/desktop/user/client/client.go +++ b/ee/desktop/user/client/client.go @@ -4,11 +4,15 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" "net/http" + "net/url" "time" "github.com/kolide/launcher/ee/desktop/user/notify" + "github.com/kolide/launcher/ee/desktop/user/server" + "github.com/kolide/launcher/ee/presencedetection" ) type transport struct { @@ -55,6 +59,42 @@ func (c *client) Refresh() error { return c.get("refresh") } +func (c *client) ShowDesktop() error { + return c.get("show") +} + +func (c *client) DetectPresence(reason string, interval time.Duration) (time.Duration, error) { + encodedReason := url.QueryEscape(reason) + encodedInterval := url.QueryEscape(interval.String()) + + // default time out of 30s is set in New() + resp, requestErr := c.base.Get(fmt.Sprintf("http://unix/detect_presence?reason=%s&interval=%s", encodedReason, encodedInterval)) + if requestErr != nil { + return presencedetection.DetectionFailedDurationValue, fmt.Errorf("getting presence: %w", requestErr) + } + + var response server.DetectPresenceResponse + if resp.Body != nil { + defer resp.Body.Close() + + if err := json.NewDecoder(resp.Body).Decode(&response); err != nil { + return presencedetection.DetectionFailedDurationValue, fmt.Errorf("decoding response: %w", err) + } + } + + var detectionErr error + if response.Error != "" { + detectionErr = errors.New(response.Error) + } + + durationSinceLastDetection, parseErr := time.ParseDuration(response.DurationSinceLastDetection) + if parseErr != nil { + return presencedetection.DetectionFailedDurationValue, fmt.Errorf("parsing time since last detection: %w", parseErr) + } + + return durationSinceLastDetection, detectionErr +} + func (c *client) Notify(n notify.Notification) error { notificationToSend := notify.Notification{ Title: n.Title, diff --git a/ee/desktop/user/client/client_test.go b/ee/desktop/user/client/client_test.go index 97884c9d1..f000a2e2c 100644 --- a/ee/desktop/user/client/client_test.go +++ b/ee/desktop/user/client/client_test.go @@ -42,7 +42,7 @@ func TestClient_GetAndShutdown(t *testing.T) { socketPath := testSocketPath(t) shutdownChan := make(chan struct{}) - server, err := server.New(multislogger.NewNopLogger(), validAuthToken, socketPath, shutdownChan, nil) + server, err := server.New(multislogger.NewNopLogger(), validAuthToken, socketPath, shutdownChan, make(chan<- struct{}), nil) require.NoError(t, err) go func() { diff --git a/ee/desktop/user/server/server.go b/ee/desktop/user/server/server.go index 87bd16636..c15164a3a 100644 --- a/ee/desktop/user/server/server.go +++ b/ee/desktop/user/server/server.go @@ -13,9 +13,11 @@ import ( "os" "runtime" "strings" + "sync" "time" "github.com/kolide/launcher/ee/desktop/user/notify" + "github.com/kolide/launcher/ee/presencedetection" "github.com/kolide/launcher/pkg/backoff" ) @@ -26,23 +28,33 @@ type notificationSender interface { // UserServer provides IPC for the root desktop runner to communicate with the user desktop processes. // It allows the runner process to send notficaitons and commands to the desktop processes. type UserServer struct { - slogger *slog.Logger - server *http.Server - listener net.Listener - shutdownChan chan<- struct{} - authToken string - socketPath string - notifier notificationSender - refreshListeners []func() + slogger *slog.Logger + server *http.Server + listener net.Listener + shutdownChan chan<- struct{} + authToken string + socketPath string + notifier notificationSender + refreshListeners []func() + presenceDetector presencedetection.PresenceDetector + showDesktopOnceFunc func() } -func New(slogger *slog.Logger, authToken string, socketPath string, shutdownChan chan<- struct{}, notifier notificationSender) (*UserServer, error) { +func New(slogger *slog.Logger, + authToken string, + socketPath string, + shutdownChan chan<- struct{}, + showDesktopChan chan<- struct{}, + notifier notificationSender) (*UserServer, error) { userServer := &UserServer{ shutdownChan: shutdownChan, authToken: authToken, slogger: slogger.With("component", "desktop_server"), socketPath: socketPath, notifier: notifier, + showDesktopOnceFunc: sync.OnceFunc(func() { + showDesktopChan <- struct{}{} + }), } authedMux := http.NewServeMux() @@ -50,6 +62,8 @@ func New(slogger *slog.Logger, authToken string, socketPath string, shutdownChan authedMux.HandleFunc("/ping", userServer.pingHandler) authedMux.HandleFunc("/notification", userServer.notificationHandler) authedMux.HandleFunc("/refresh", userServer.refreshHandler) + authedMux.HandleFunc("/show", userServer.showDesktop) + authedMux.HandleFunc("/detect_presence", userServer.detectPresence) userServer.server = &http.Server{ Handler: userServer.authMiddleware(authedMux), @@ -152,6 +166,70 @@ func (s *UserServer) notificationHandler(w http.ResponseWriter, req *http.Reques w.WriteHeader(http.StatusOK) } +func (s *UserServer) showDesktop(w http.ResponseWriter, req *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + + s.showDesktopOnceFunc() +} + +type DetectPresenceResponse struct { + DurationSinceLastDetection string `json:"duration_since_last_detection,omitempty"` + Error string `json:"error,omitempty"` +} + +func (s *UserServer) detectPresence(w http.ResponseWriter, req *http.Request) { + // get reason url param from req + reason := req.URL.Query().Get("reason") + + if reason == "" { + http.Error(w, "reason is required", http.StatusBadRequest) + return + } + + // get intervalString from url param + intervalString := req.URL.Query().Get("interval") + if intervalString == "" { + http.Error(w, "interval is required", http.StatusBadRequest) + return + } + + interval, err := time.ParseDuration(intervalString) + if err != nil { + http.Error(w, "interval is not a valid duration", http.StatusBadRequest) + return + } + + // detect presence + durationSinceLastDetection, err := s.presenceDetector.DetectPresence(reason, interval) + response := DetectPresenceResponse{ + DurationSinceLastDetection: durationSinceLastDetection.String(), + } + + if err != nil { + response.Error = err.Error() + + s.slogger.Log(req.Context(), slog.LevelDebug, + "detecting presence", + "reason", reason, + "interval", interval, + "err", err, + ) + } + + // convert response to json + responseBytes, err := json.Marshal(response) + if err != nil { + http.Error(w, "could not marshal response", http.StatusInternalServerError) + return + } + + // write response + w.Header().Set("Content-Type", "application/json") + w.Write(responseBytes) + w.WriteHeader(http.StatusOK) +} + func (s *UserServer) refreshHandler(w http.ResponseWriter, req *http.Request) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) diff --git a/ee/desktop/user/server/server_test.go b/ee/desktop/user/server/server_test.go index 964c4639a..6ec884560 100644 --- a/ee/desktop/user/server/server_test.go +++ b/ee/desktop/user/server/server_test.go @@ -121,7 +121,7 @@ func testServer(t *testing.T, authHeader, socketPath string, logBytes *bytes.Buf Level: slog.LevelDebug, })) - server, err := New(slogger, authHeader, socketPath, shutdownChan, nil) + server, err := New(slogger, authHeader, socketPath, shutdownChan, make(chan<- struct{}), nil) require.NoError(t, err) return server, shutdownChan } diff --git a/ee/localserver/krypto-ec-middleware.go b/ee/localserver/krypto-ec-middleware.go index c7f28a0d6..1d7b72aab 100644 --- a/ee/localserver/krypto-ec-middleware.go +++ b/ee/localserver/krypto-ec-middleware.go @@ -25,15 +25,19 @@ import ( ) const ( - timestampValidityRange = 150 - kolideKryptoEccHeader20230130Value = "2023-01-30" - kolideKryptoHeaderKey = "X-Kolide-Krypto" - kolideSessionIdHeaderKey = "X-Kolide-Session" + timestampValidityRange = 150 + kolideKryptoEccHeader20230130Value = "2023-01-30" + kolideKryptoHeaderKey = "X-Kolide-Krypto" + kolideSessionIdHeaderKey = "X-Kolide-Session" + kolidePresenceDetectionInterval = "X-Kolide-Presence-Detection-Interval" + kolidePresenceDetectionReason = "X-Kolide-Presence-Detection-Reason" + kolideDurationSinceLastPresenceDetection = "X-Kolide-Duration-Since-Last-Presence-Detection" ) type v2CmdRequestType struct { Path string Body []byte + Headers map[string][]string CallbackUrl string CallbackHeaders map[string][]string AllowedOrigins []string @@ -285,6 +289,12 @@ func (e *kryptoEcMiddleware) Wrap(next http.Handler) http.Handler { }, } + for h, vals := range cmdReq.Headers { + for _, v := range vals { + newReq.Header.Add(h, v) + } + } + newReq.Header.Set("Origin", r.Header.Get("Origin")) newReq.Header.Set("Referer", r.Header.Get("Referer")) @@ -306,15 +316,44 @@ func (e *kryptoEcMiddleware) Wrap(next http.Handler) http.Handler { bhr := &bufferedHttpResponse{} next.ServeHTTP(bhr, newReq) + // add headers to the response map + // this assumes that the response to `bhr` was a json encoded blob. + var responseMap map[string]interface{} + bhrBytes := bhr.Bytes() + if err := json.Unmarshal(bhrBytes, &responseMap); err != nil { + traces.SetError(span, err) + e.slogger.Log(r.Context(), slog.LevelError, + "unable to unmarshal response", + "err", err, + ) + responseMap = map[string]any{ + "headers": bhr.Header(), + + // the request body was not in json format, just pass it through as "body" + "body": string(bhrBytes), + } + } else { + responseMap["headers"] = bhr.Header() + } + + responseBytes, err := json.Marshal(responseMap) + if err != nil { + traces.SetError(span, err) + e.slogger.Log(r.Context(), slog.LevelError, + "unable to marshal response", + "err", err, + ) + } + var response []byte // it's possible the keys will be noop keys, then they will error or give nil when crypto.Signer funcs are called // krypto library has a nil check for the object but not the funcs, so if are getting nil from the funcs, just // pass nil to krypto // hardware signing is not implemented for darwin if runtime.GOOS != "darwin" && e.hardwareSigner != nil && e.hardwareSigner.Public() != nil { - response, err = challengeBox.Respond(e.localDbSigner, e.hardwareSigner, bhr.Bytes()) + response, err = challengeBox.Respond(e.localDbSigner, e.hardwareSigner, responseBytes) } else { - response, err = challengeBox.Respond(e.localDbSigner, nil, bhr.Bytes()) + response, err = challengeBox.Respond(e.localDbSigner, nil, responseBytes) } if err != nil { diff --git a/ee/localserver/krypto-ec-middleware_test.go b/ee/localserver/krypto-ec-middleware_test.go index fa135795a..ace5ef654 100644 --- a/ee/localserver/krypto-ec-middleware_test.go +++ b/ee/localserver/krypto-ec-middleware_test.go @@ -7,12 +7,14 @@ import ( "crypto/rand" "encoding/base64" "encoding/json" + "fmt" "io" "log/slog" "math/big" "net/http" "net/http/httptest" "net/url" + "runtime" "strings" "testing" "time" @@ -21,8 +23,11 @@ import ( "github.com/kolide/krypto/pkg/challenge" "github.com/kolide/krypto/pkg/echelper" "github.com/kolide/launcher/ee/agent/keys" + "github.com/kolide/launcher/ee/localserver/mocks" + "github.com/kolide/launcher/pkg/log/multislogger" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) @@ -36,6 +41,10 @@ func TestKryptoEcMiddleware(t *testing.T) { challengeData := []byte(ulid.New()) koldieSessionId := ulid.New() + cmdRequestHeaders := map[string][]string{ + kolidePresenceDetectionInterval: {"0s"}, + } + cmdReqCallBackHeaders := map[string][]string{ kolideSessionIdHeaderKey: {koldieSessionId}, } @@ -44,6 +53,7 @@ func TestKryptoEcMiddleware(t *testing.T) { cmdReq := mustMarshal(t, v2CmdRequestType{ Path: "whatevs", Body: cmdReqBody, + Headers: cmdRequestHeaders, CallbackHeaders: cmdReqCallBackHeaders, }) @@ -130,12 +140,18 @@ func TestKryptoEcMiddleware(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - responseData := tt.responseData - // generate the response we want the handler to return - if responseData == nil { - responseData = []byte(ulid.New()) + responseMap := make(map[string]any) + const testMsgKey = "body" + + responseValue := string(tt.responseData) + if responseValue == "" { + responseValue = ulid.New() } + responseMap[testMsgKey] = responseValue + + responseDataRaw := mustMarshal(t, responseMap) + testHandler := tt.handler // this handler is what will respond to the request made by the kryptoEcMiddleware.Wrap handler @@ -143,12 +159,18 @@ func TestKryptoEcMiddleware(t *testing.T) { // this should match the responseData in the opened response if testHandler == nil { testHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + + // make sure all the request headers are present + for k, v := range cmdRequestHeaders { + require.Equal(t, v[0], r.Header.Get(k)) + } + reqBodyRaw, err := io.ReadAll(r.Body) require.NoError(t, err) defer r.Body.Close() require.Equal(t, cmdReqBody, reqBodyRaw) - w.Write(responseData) + w.Write(responseDataRaw) }) } @@ -169,8 +191,15 @@ func TestKryptoEcMiddleware(t *testing.T) { kryptoEcMiddleware := newKryptoEcMiddleware(slogger, tt.localDbKey, tt.hardwareKey, counterpartyKey.PublicKey) require.NoError(t, err) + mockPresenceDetector := mocks.NewPresenceDetector(t) + mockPresenceDetector.On("DetectPresence", mock.AnythingOfType("string"), mock.AnythingOfType("Duration")).Return(0*time.Second, nil).Maybe() + localServer := &localServer{ + presenceDetector: mockPresenceDetector, + slogger: multislogger.NewNopLogger(), + } + // give our middleware with the test handler to the determiner - h := kryptoEcMiddleware.Wrap(testHandler) + h := kryptoEcMiddleware.Wrap(localServer.presenceDetectionHandler(testHandler)) rr := httptest.NewRecorder() h.ServeHTTP(rr, req) @@ -201,8 +230,20 @@ func TestKryptoEcMiddleware(t *testing.T) { opened, err := responseUnmarshalled.Open(*privateEncryptionKey) require.NoError(t, err) require.Equal(t, challengeData, opened.ChallengeData) - require.Equal(t, responseData, opened.ResponseData) + + opendResponseValue, err := extractJsonProperty[string](opened.ResponseData, testMsgKey) + require.NoError(t, err) + require.Equal(t, responseValue, opendResponseValue) + require.WithinDuration(t, time.Now(), time.Unix(opened.Timestamp, 0), time.Second*5) + + responseHeaders, err := extractJsonProperty[map[string][]string](opened.ResponseData, "headers") + require.NoError(t, err) + + // check that the presence detection interval is present + if runtime.GOOS == "darwin" { + require.Equal(t, (0 * time.Second).String(), responseHeaders[kolideDurationSinceLastPresenceDetection][0]) + } }) } }) @@ -357,7 +398,11 @@ func Test_AllowedOrigin(t *testing.T) { opened, err := responseUnmarshalled.Open(*privateEncryptionKey) require.NoError(t, err) require.Equal(t, challengeData, opened.ChallengeData) - require.Equal(t, responseData, opened.ResponseData) + + openedResponseValue, err := extractJsonProperty[string](opened.ResponseData, "body") + require.NoError(t, err) + + require.Equal(t, responseData, []byte(openedResponseValue)) require.WithinDuration(t, time.Now(), time.Unix(opened.Timestamp, 0), time.Second*5) }) @@ -422,3 +467,28 @@ func mustMarshal(t *testing.T, v interface{}) []byte { require.NoError(t, err) return b } + +func extractJsonProperty[T any](jsonData []byte, property string) (T, error) { + var result map[string]json.RawMessage + + // Unmarshal the JSON data into a map with json.RawMessage + err := json.Unmarshal(jsonData, &result) + if err != nil { + return *new(T), err + } + + // Retrieve the field from the map + value, ok := result[property] + if !ok { + return *new(T), fmt.Errorf("property %s not found", property) + } + + // Unmarshal the value into the type T + var extractedValue T + err = json.Unmarshal(value, &extractedValue) + if err != nil { + return *new(T), err + } + + return extractedValue, nil +} diff --git a/ee/localserver/mocks/presenceDetector.go b/ee/localserver/mocks/presenceDetector.go new file mode 100644 index 000000000..0af0518f7 --- /dev/null +++ b/ee/localserver/mocks/presenceDetector.go @@ -0,0 +1,56 @@ +// Code generated by mockery v2.44.1. DO NOT EDIT. + +package mocks + +import ( + time "time" + + mock "github.com/stretchr/testify/mock" +) + +// PresenceDetector is an autogenerated mock type for the presenceDetector type +type PresenceDetector struct { + mock.Mock +} + +// DetectPresence provides a mock function with given fields: reason, interval +func (_m *PresenceDetector) DetectPresence(reason string, interval time.Duration) (time.Duration, error) { + ret := _m.Called(reason, interval) + + if len(ret) == 0 { + panic("no return value specified for DetectPresence") + } + + var r0 time.Duration + var r1 error + if rf, ok := ret.Get(0).(func(string, time.Duration) (time.Duration, error)); ok { + return rf(reason, interval) + } + if rf, ok := ret.Get(0).(func(string, time.Duration) time.Duration); ok { + r0 = rf(reason, interval) + } else { + r0 = ret.Get(0).(time.Duration) + } + + if rf, ok := ret.Get(1).(func(string, time.Duration) error); ok { + r1 = rf(reason, interval) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// NewPresenceDetector creates a new instance of PresenceDetector. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewPresenceDetector(t interface { + mock.TestingT + Cleanup(func()) +}) *PresenceDetector { + mock := &PresenceDetector{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/ee/localserver/presence-detection-middleware_test.go b/ee/localserver/presence-detection-middleware_test.go new file mode 100644 index 000000000..2ff95031c --- /dev/null +++ b/ee/localserver/presence-detection-middleware_test.go @@ -0,0 +1,117 @@ +package localserver + +import ( + "net/http" + "net/http/httptest" + "runtime" + "testing" + "time" + + "github.com/kolide/launcher/ee/localserver/mocks" + "github.com/kolide/launcher/ee/presencedetection" + "github.com/kolide/launcher/pkg/log/multislogger" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +func TestPresenceDetectionHandler(t *testing.T) { + t.Parallel() + + if runtime.GOOS != "darwin" { + t.Skip("test only runs on darwin until implemented for other OSes") + } + + tests := []struct { + name string + expectDetectPresenceCall bool + intervalHeader, reasonHeader string + durationSinceLastDetection time.Duration + presenceDetectionError error + shouldHavePresenceDetectionDurationResponseHeader bool + expectedStatusCode int + }{ + { + name: "no presence detection headers", + expectedStatusCode: http.StatusOK, + shouldHavePresenceDetectionDurationResponseHeader: false, + }, + { + name: "invalid presence detection interval", + intervalHeader: "invalid-interval", + expectedStatusCode: http.StatusBadRequest, + }, + { + name: "valid presence detection, detection fails", + expectDetectPresenceCall: true, + intervalHeader: "10s", + reasonHeader: "test reason", + durationSinceLastDetection: presencedetection.DetectionFailedDurationValue, + expectedStatusCode: http.StatusOK, + shouldHavePresenceDetectionDurationResponseHeader: true, + }, + { + name: "valid presence detection, detection succeeds", + expectDetectPresenceCall: true, + intervalHeader: "10s", + reasonHeader: "test reason", + durationSinceLastDetection: 0, + expectedStatusCode: http.StatusOK, + shouldHavePresenceDetectionDurationResponseHeader: true, + }, + { + name: "presence detection error", + expectDetectPresenceCall: true, + intervalHeader: "10s", + reasonHeader: "test reason", + durationSinceLastDetection: presencedetection.DetectionFailedDurationValue, + presenceDetectionError: assert.AnError, + expectedStatusCode: http.StatusOK, + shouldHavePresenceDetectionDurationResponseHeader: true, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + mockPresenceDetector := mocks.NewPresenceDetector(t) + + if tt.expectDetectPresenceCall { + mockPresenceDetector.On("DetectPresence", mock.AnythingOfType("string"), mock.AnythingOfType("Duration")).Return(tt.durationSinceLastDetection, tt.presenceDetectionError) + } + + server := &localServer{ + presenceDetector: mockPresenceDetector, + slogger: multislogger.NewNopLogger(), + } + + // Create a test handler for the middleware to call + nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + // Wrap the test handler in the middleware + handlerToTest := server.presenceDetectionHandler(nextHandler) + + // Create a request with the specified headers + req := httptest.NewRequest("GET", "/", nil) + if tt.intervalHeader != "" { + req.Header.Add("X-Kolide-Presence-Detection-Interval", tt.intervalHeader) + } + + if tt.reasonHeader != "" { + req.Header.Add("X-Kolide-Presence-Detection-Reason", tt.reasonHeader) + } + + rr := httptest.NewRecorder() + handlerToTest.ServeHTTP(rr, req) + + if tt.shouldHavePresenceDetectionDurationResponseHeader { + require.NotEmpty(t, rr.Header().Get(kolideDurationSinceLastPresenceDetection)) + } + require.Equal(t, tt.expectedStatusCode, rr.Code) + }) + } +} diff --git a/ee/localserver/request-id_test.go b/ee/localserver/request-id_test.go index f917571ca..8873bb34c 100644 --- a/ee/localserver/request-id_test.go +++ b/ee/localserver/request-id_test.go @@ -61,7 +61,7 @@ func Test_localServer_requestIdHandler(t *testing.T) { func testServer(t *testing.T, k types.Knapsack) *localServer { require.NoError(t, osquery.SetupLauncherKeys(k.ConfigStore())) - server, err := New(context.TODO(), k) + server, err := New(context.TODO(), k, nil) require.NoError(t, err) return server } diff --git a/ee/localserver/server.go b/ee/localserver/server.go index 06f2b8fda..af8bd519f 100644 --- a/ee/localserver/server.go +++ b/ee/localserver/server.go @@ -12,6 +12,7 @@ import ( "log/slog" "net" "net/http" + "runtime" "strings" "time" @@ -56,6 +57,8 @@ type localServer struct { serverKey *rsa.PublicKey serverEcKey *ecdsa.PublicKey + + presenceDetector presenceDetector } const ( @@ -63,7 +66,11 @@ const ( defaultRateBurst = 10 ) -func New(ctx context.Context, k types.Knapsack) (*localServer, error) { +type presenceDetector interface { + DetectPresence(reason string, interval time.Duration) (time.Duration, error) +} + +func New(ctx context.Context, k types.Knapsack, presenceDetector presenceDetector) (*localServer, error) { _, span := traces.StartSpan(ctx) defer span.End() @@ -74,6 +81,7 @@ func New(ctx context.Context, k types.Knapsack) (*localServer, error) { kolideServer: k.KolideServerURL(), myLocalDbSigner: agent.LocalDbKeys(), myLocalHardwareSigner: agent.HardwareKeys(), + presenceDetector: presenceDetector, } // TODO: As there may be things that adjust the keys during runtime, we need to persist that across @@ -103,13 +111,13 @@ func New(ctx context.Context, k types.Knapsack) (*localServer, error) { mux := http.NewServeMux() mux.HandleFunc("/", http.NotFound) - mux.Handle("/v0/cmd", ecKryptoMiddleware.Wrap(ecAuthedMux)) + mux.Handle("/v0/cmd", ecKryptoMiddleware.Wrap(ls.presenceDetectionHandler(ecAuthedMux))) // /v1/cmd was added after fixing a bug where local server would panic when an endpoint was not found // after making it through the kryptoEcMiddleware // by using v1, k2 can call endpoints without fear of panicing local server // /v0/cmd left for transition period - mux.Handle("/v1/cmd", ecKryptoMiddleware.Wrap(ecAuthedMux)) + mux.Handle("/v1/cmd", ecKryptoMiddleware.Wrap(ls.presenceDetectionHandler(ecAuthedMux))) // uncomment to test without going through middleware // for example: @@ -119,11 +127,19 @@ func New(ctx context.Context, k types.Knapsack) (*localServer, error) { // mux.Handle("/scheduledquery", ls.requestScheduledQueryHandler()) // curl localhost:40978/acceleratecontrol --data '{"interval":"250ms", "duration":"1s"}' // mux.Handle("/acceleratecontrol", ls.requestAccelerateControlHandler()) + // curl localhost:40978/id + // mux.Handle("/id", ls.requestIdHandler()) srv := &http.Server{ - Handler: otelhttp.NewHandler(ls.requestLoggingHandler(ls.preflightCorsHandler(ls.rateLimitHandler(mux))), "localserver", otelhttp.WithSpanNameFormatter(func(operation string, r *http.Request) string { - return r.URL.Path - })), + Handler: otelhttp.NewHandler( + ls.requestLoggingHandler( + ls.preflightCorsHandler( + ls.rateLimitHandler( + mux, + ), + )), "localserver", otelhttp.WithSpanNameFormatter(func(operation string, r *http.Request) string { + return r.URL.Path + })), ReadTimeout: 500 * time.Millisecond, ReadHeaderTimeout: 50 * time.Millisecond, // WriteTimeout very high due to retry logic in the scheduledquery endpoint @@ -393,3 +409,58 @@ func (ls *localServer) rateLimitHandler(next http.Handler) http.Handler { next.ServeHTTP(w, r) }) } + +func (ls *localServer) presenceDetectionHandler(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + + // presence detection is only supported on macos currently + if runtime.GOOS != "darwin" { + next.ServeHTTP(w, r) + return + } + + // can test this by adding an unauthed endpoint to the mux and running, for example: + // curl -i -H "X-Kolide-Presence-Detection-Interval: 10s" -H "X-Kolide-Presence-Detection-Reason: my reason" localhost:12519/id + detectionIntervalStr := r.Header.Get(kolidePresenceDetectionInterval) + + // no presence detection requested + if detectionIntervalStr == "" { + next.ServeHTTP(w, r) + return + } + + detectionIntervalDuration, err := time.ParseDuration(detectionIntervalStr) + if err != nil { + // this is the only time this should returna non-200 status code + // asked for presence detection, but the interval is invalid + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + // set a default reason, on macos the popup will look like "Kolide is trying to authenticate." + reason := "authenticate" + reasonHeader := r.Header.Get(kolidePresenceDetectionReason) + if reasonHeader != "" { + reason = reasonHeader + } + + durationSinceLastDetection, err := ls.presenceDetector.DetectPresence(reason, detectionIntervalDuration) + + if err != nil { + ls.slogger.Log(r.Context(), slog.LevelInfo, + "presence_detection", + "reason", reason, + "interval", detectionIntervalDuration, + "duration_since_last_detection", durationSinceLastDetection, + "err", err, + ) + } + + // if there was an error, we still want to return a 200 status code + // and send the request through + // allow the server to decide what to do based on last detection duration + + w.Header().Add(kolideDurationSinceLastPresenceDetection, durationSinceLastDetection.String()) + next.ServeHTTP(w, r) + }) +} diff --git a/ee/localserver/server_test.go b/ee/localserver/server_test.go index 83dfd5cdf..7e4732c06 100644 --- a/ee/localserver/server_test.go +++ b/ee/localserver/server_test.go @@ -36,7 +36,7 @@ func TestInterrupt_Multiple(t *testing.T) { recalculateInterval = 100 * time.Millisecond // Create the localserver - ls, err := New(context.TODO(), k) + ls, err := New(context.TODO(), k, nil) require.NoError(t, err) // Set the querier diff --git a/ee/presencedetection/auth.h b/ee/presencedetection/auth.h new file mode 100644 index 000000000..6d2538ced --- /dev/null +++ b/ee/presencedetection/auth.h @@ -0,0 +1,18 @@ +//go:build darwin +// +build darwin + +// auth.h +#ifndef AUTH_H +#define AUTH_H + +#include + +struct AuthResult { + bool success; // true for success, false for failure + char* error_msg; // Error message if any + int error_code; // Error code if any +}; + +struct AuthResult Authenticate(char const* reason); + +#endif diff --git a/ee/presencedetection/auth.m b/ee/presencedetection/auth.m new file mode 100644 index 000000000..ccfdc4ee5 --- /dev/null +++ b/ee/presencedetection/auth.m @@ -0,0 +1,68 @@ +//go:build darwin +// +build darwin + +// auth.m +#import +#include "auth.h" + +struct AuthResult Authenticate(char const* reason) { + struct AuthResult authResult; + LAContext *myContext = [[LAContext alloc] init]; + NSError *authError = nil; + dispatch_semaphore_t sema = dispatch_semaphore_create(0); + NSString *nsReason = [NSString stringWithUTF8String:reason]; + __block bool success = false; + __block NSString *errorMessage = nil; + __block int errorCode = 0; + + // Use LAPolicyDeviceOwnerAuthentication to allow biometrics and password fallback + if ([myContext canEvaluatePolicy:LAPolicyDeviceOwnerAuthentication error:&authError]) { + [myContext evaluatePolicy:LAPolicyDeviceOwnerAuthentication + localizedReason:nsReason + reply:^(BOOL policySuccess, NSError *error) { + if (policySuccess) { + success = true; // Authentication successful + } else { + success = false; + errorCode = (int)[error code]; + errorMessage = [error localizedDescription]; + if (error.code == LAErrorUserFallback || error.code == LAErrorAuthenticationFailed) { + // Prompting for password + [myContext evaluatePolicy:LAPolicyDeviceOwnerAuthentication + localizedReason:nsReason + reply:^(BOOL pwdSuccess, NSError *error) { + if (pwdSuccess) { + success = true; + } else { + success = false; + errorCode = (int)[error code]; + errorMessage = [error localizedDescription]; + } + dispatch_semaphore_signal(sema); + }]; + } else { + errorCode = (int)[error code]; + errorMessage = [error localizedDescription]; + } + } + dispatch_semaphore_signal(sema); + }]; + } else { + success = false; // Cannot evaluate policy + errorCode = (int)[authError code]; + errorMessage = [authError localizedDescription]; + } + + dispatch_semaphore_wait(sema, DISPATCH_TIME_FOREVER); + dispatch_release(sema); + + authResult.success = success; + authResult.error_code = errorCode; + if (errorMessage != nil) { + authResult.error_msg = strdup([errorMessage UTF8String]); // Copy error message to C string + } else { + authResult.error_msg = NULL; + } + + return authResult; +} diff --git a/ee/presencedetection/mocks/detectorIface.go b/ee/presencedetection/mocks/detectorIface.go new file mode 100644 index 000000000..19bd705d4 --- /dev/null +++ b/ee/presencedetection/mocks/detectorIface.go @@ -0,0 +1,52 @@ +// Code generated by mockery v2.44.1. DO NOT EDIT. + +package mocks + +import mock "github.com/stretchr/testify/mock" + +// DetectorIface is an autogenerated mock type for the detectorIface type +type DetectorIface struct { + mock.Mock +} + +// Detect provides a mock function with given fields: reason +func (_m *DetectorIface) Detect(reason string) (bool, error) { + ret := _m.Called(reason) + + if len(ret) == 0 { + panic("no return value specified for Detect") + } + + var r0 bool + var r1 error + if rf, ok := ret.Get(0).(func(string) (bool, error)); ok { + return rf(reason) + } + if rf, ok := ret.Get(0).(func(string) bool); ok { + r0 = rf(reason) + } else { + r0 = ret.Get(0).(bool) + } + + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(reason) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// NewDetectorIface creates a new instance of DetectorIface. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewDetectorIface(t interface { + mock.TestingT + Cleanup(func()) +}) *DetectorIface { + mock := &DetectorIface{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/ee/presencedetection/presencedetection.go b/ee/presencedetection/presencedetection.go new file mode 100644 index 000000000..7a77d1042 --- /dev/null +++ b/ee/presencedetection/presencedetection.go @@ -0,0 +1,58 @@ +package presencedetection + +import ( + "fmt" + "sync" + "time" +) + +const DetectionFailedDurationValue = -1 * time.Second + +type PresenceDetector struct { + lastDetection time.Time + mutext sync.Mutex + // detector is an interface to allow for mocking in tests + detector detectorIface +} + +// just exists for testing purposes +type detectorIface interface { + Detect(reason string) (bool, error) +} + +type detector struct{} + +func (d *detector) Detect(reason string) (bool, error) { + return Detect(reason) +} + +// DetectPresence checks if the user is present by detecting the presence of a user. +// It returns the duration since the last detection. +func (pd *PresenceDetector) DetectPresence(reason string, detectionInterval time.Duration) (time.Duration, error) { + pd.mutext.Lock() + defer pd.mutext.Unlock() + + if pd.detector == nil { + pd.detector = &detector{} + } + + // Check if the last detection was within the detection interval + if (pd.lastDetection != time.Time{}) && time.Since(pd.lastDetection) < detectionInterval { + return time.Since(pd.lastDetection), nil + } + + success, err := pd.detector.Detect(reason) + if err != nil { + // if we got an error, we behave as if there have been no successful detections in the past + return DetectionFailedDurationValue, fmt.Errorf("detecting presence: %w", err) + } + + if success { + pd.lastDetection = time.Now().UTC() + return 0, nil + } + + // if we got here it means we failed without an error + // this "should" never happen, but here for completeness + return DetectionFailedDurationValue, fmt.Errorf("detection failed without OS error") +} diff --git a/ee/presencedetection/presencedetection_darwin.go b/ee/presencedetection/presencedetection_darwin.go new file mode 100644 index 000000000..40f76d823 --- /dev/null +++ b/ee/presencedetection/presencedetection_darwin.go @@ -0,0 +1,36 @@ +//go:build darwin +// +build darwin + +package presencedetection + +/* +#cgo CFLAGS: -x objective-c -fmodules -fblocks +#cgo LDFLAGS: -framework CoreFoundation -framework LocalAuthentication -framework Foundation +#include +#include "auth.h" +*/ +import "C" +import ( + "fmt" + "unsafe" +) + +func Detect(reason string) (bool, error) { + reasonStr := C.CString(reason) + defer C.free(unsafe.Pointer(reasonStr)) + + result := C.Authenticate(reasonStr) + + // Convert C error message to Go string + if result.error_msg != nil { + defer C.free(unsafe.Pointer(result.error_msg)) + } + errorMessage := C.GoString(result.error_msg) + + // Return success or failure, with an error if applicable + if result.success { + return true, nil + } + + return false, fmt.Errorf("authentication failed: %d %s", int(result.error_code), errorMessage) +} diff --git a/ee/presencedetection/presencedetection_darwin_test.go b/ee/presencedetection/presencedetection_darwin_test.go new file mode 100644 index 000000000..ec528ce02 --- /dev/null +++ b/ee/presencedetection/presencedetection_darwin_test.go @@ -0,0 +1,48 @@ +package presencedetection + +import ( + "os" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const testPresenceEnvVar = "LAUNCHER_TEST_PRESENCE" + +// Since there is no way to test user presence in a CI / automated fashion, +// these test are expected to be run manually via cmd line when needed. + +// To test this run +// +// LAUNCHER_TEST_PRESENCE=true go test ./ee/presencedetection/ -run Test_detectSuccess +// +// then successfully auth with the pop up +func Test_detectSuccess(t *testing.T) { + t.Parallel() + + if os.Getenv(testPresenceEnvVar) == "" { + t.Skip("Skipping Test_detectSuccess") + } + + success, err := Detect("IS TRYING TO TEST SUCCESS, PLEASE AUTHENTICATE") + require.NoError(t, err, "should not get an error on successful detect") + assert.True(t, success, "should be successful") +} + +// To test this run +// +// LAUNCHER_TEST_PRESENCE=true go test ./ee/presencedetection/ -run Test_detectCancel +// +// then cancel the biometric auth that pops up +func Test_detectCancel(t *testing.T) { + t.Parallel() + + if os.Getenv(testPresenceEnvVar) == "" { + t.Skip("Skipping test_biometricDetectCancel") + } + + success, err := Detect("IS TRYING TO TEST CANCEL, PLEASE PRESS CANCEL") + require.Error(t, err, "should get an error on failed detect") + assert.False(t, success, "should not be successful") +} diff --git a/ee/presencedetection/presencedetection_other.go b/ee/presencedetection/presencedetection_other.go new file mode 100644 index 000000000..bcb742bf0 --- /dev/null +++ b/ee/presencedetection/presencedetection_other.go @@ -0,0 +1,11 @@ +//go:build !darwin +// +build !darwin + +package presencedetection + +import "errors" + +func Detect(reason string) (bool, error) { + // Implement detection logic for non-Darwin platforms + return false, errors.New("detection not implemented for this platform") +} diff --git a/ee/presencedetection/presencedetection_test.go b/ee/presencedetection/presencedetection_test.go new file mode 100644 index 000000000..e8dbb603d --- /dev/null +++ b/ee/presencedetection/presencedetection_test.go @@ -0,0 +1,108 @@ +package presencedetection + +import ( + "errors" + "math" + "testing" + "time" + + "github.com/kolide/launcher/ee/presencedetection/mocks" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func TestPresenceDetector_DetectPresence(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + interval time.Duration + detector func(t *testing.T) detectorIface + initialLastDetectionUTC time.Time + expectError bool + }{ + { + name: "first detection success", + interval: 0, + detector: func(t *testing.T) detectorIface { + d := mocks.NewDetectorIface(t) + d.On("Detect", mock.AnythingOfType("string")).Return(true, nil) + return d + }, + }, + { + name: "detection outside interval", + interval: time.Minute, + detector: func(t *testing.T) detectorIface { + d := mocks.NewDetectorIface(t) + d.On("Detect", mock.AnythingOfType("string")).Return(true, nil) + return d + }, + initialLastDetectionUTC: time.Now().UTC().Add(-time.Minute), + }, + { + name: "detection within interval", + interval: time.Minute, + detector: func(t *testing.T) detectorIface { + // should not be called, will get error if it is + return mocks.NewDetectorIface(t) + }, + initialLastDetectionUTC: time.Now().UTC(), + }, + { + name: "error first detection", + interval: 0, + detector: func(t *testing.T) detectorIface { + d := mocks.NewDetectorIface(t) + d.On("Detect", mock.AnythingOfType("string")).Return(true, errors.New("error")) + return d + }, + expectError: true, + }, + { + name: "error after first detection", + interval: 0, + detector: func(t *testing.T) detectorIface { + d := mocks.NewDetectorIface(t) + d.On("Detect", mock.AnythingOfType("string")).Return(true, errors.New("error")) + return d + }, + initialLastDetectionUTC: time.Now().UTC(), + expectError: true, + }, + { + // this should never happen, but it is here for completeness + name: "detection failed without OS error", + interval: 0, + detector: func(t *testing.T) detectorIface { + d := mocks.NewDetectorIface(t) + d.On("Detect", mock.AnythingOfType("string")).Return(false, nil) + return d + }, + initialLastDetectionUTC: time.Now().UTC(), + expectError: true, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + pd := &PresenceDetector{ + detector: tt.detector(t), + lastDetection: tt.initialLastDetectionUTC, + } + + timeSinceLastDetection, err := pd.DetectPresence("this is a test", tt.interval) + + if tt.expectError { + assert.Error(t, err) + return + } + + absDelta := math.Abs(timeSinceLastDetection.Seconds() - tt.interval.Seconds()) + assert.LessOrEqual(t, absDelta, tt.interval.Seconds()) + }) + } +}