diff --git a/svc.go b/svc.go index 24c4b32..8307896 100644 --- a/svc.go +++ b/svc.go @@ -18,12 +18,16 @@ package svc import ( "context" + "errors" + "os" "os/signal" ) // Create variable signal.Notify function so we can mock it in tests var signalNotify = signal.Notify +var ErrStop = errors.New("stopping service") + // Service interface contains Start and Stop methods which are called // when the service is started and stopped. The Init method is called // before the service is started, and after it's determined if the program @@ -41,11 +45,18 @@ type Service interface { // Start is called after Init. This method must be non-blocking. Start() error - // Stop is called in response to syscall.SIGINT, syscall.SIGTERM, or when a + // Stop is called when Handle() returns ErrStop or when a // Windows Service is stopped. Stop() error } +// Handler is an optional interface a Service can implement. +// When implemented, Handle() is called when a signal is received. +// Returning ErrStop from this method will result in Service.Stop() being called. +type Handler interface { + Handle(os.Signal) error +} + // Context interface contains an optional Context function which a Service can implement. // When implemented the context.Done() channel will be used in addition to signal handling // to exit a process. diff --git a/svc_common.go b/svc_common.go index 5daebc6..61de7a4 100644 --- a/svc_common.go +++ b/svc_common.go @@ -36,11 +36,23 @@ func Run(service Service, sig ...os.Signal) error { ctx = context.Background() } - select { - case <-signalChan: - case <-ctx.Done(): + for { + select { + case s := <-signalChan: + if h, ok := service.(Handler); ok { + if err := h.Handle(s); err == ErrStop { + goto stop + } + } else { + // this maintains backwards compatibility for Services that do not implement Handle() + goto stop + } + case <-ctx.Done(): + goto stop + } } +stop: return service.Stop() }