diff --git a/windows/service.go b/windows/service.go index f8deca839..c964b6848 100644 --- a/windows/service.go +++ b/windows/service.go @@ -141,6 +141,12 @@ const ( SERVICE_DYNAMIC_INFORMATION_LEVEL_START_REASON = 1 ) +type ENUM_SERVICE_STATUS struct { + ServiceName *uint16 + DisplayName *uint16 + ServiceStatus SERVICE_STATUS +} + type SERVICE_STATUS struct { ServiceType uint32 CurrentState uint32 @@ -245,3 +251,4 @@ type QUERY_SERVICE_LOCK_STATUS struct { //sys UnsubscribeServiceChangeNotifications(subscription uintptr) = sechost.UnsubscribeServiceChangeNotifications? //sys RegisterServiceCtrlHandlerEx(serviceName *uint16, handlerProc uintptr, context uintptr) (handle Handle, err error) = advapi32.RegisterServiceCtrlHandlerExW //sys QueryServiceDynamicInformation(service Handle, infoLevel uint32, dynamicInfo unsafe.Pointer) (err error) = advapi32.QueryServiceDynamicInformation? +//sys EnumDependentServices(service Handle, activityState uint32, services *ENUM_SERVICE_STATUS, buffSize uint32, bytesNeeded *uint32, servicesReturned *uint32) (err error) = advapi32.EnumDependentServicesW diff --git a/windows/svc/mgr/mgr_test.go b/windows/svc/mgr/mgr_test.go index 330cca4a1..bc763f060 100644 --- a/windows/svc/mgr/mgr_test.go +++ b/windows/svc/mgr/mgr_test.go @@ -17,6 +17,7 @@ import ( "testing" "time" + "golang.org/x/sys/windows/svc" "golang.org/x/sys/windows/svc/mgr" ) @@ -294,3 +295,45 @@ func TestMyService(t *testing.T) { remove(t, s) } + +func TestListDependentServices(t *testing.T) { + m, err := mgr.Connect() + if err != nil { + if errno, ok := err.(syscall.Errno); ok && errno == syscall.ERROR_ACCESS_DENIED { + t.Skip("Skipping test: we don't have rights to manage services.") + } + t.Fatalf("SCM connection failed: %s", err) + } + defer m.Disconnect() + + baseServiceName := "testservice1" + dependentServiceName := "testservice2" + install(t, m, baseServiceName, "", mgr.Config{}) + baseService, err := m.OpenService(baseServiceName) + if err != nil { + t.Fatalf("OpenService failed: %v", err) + } + defer remove(t, baseService) + install(t, m, dependentServiceName, "", mgr.Config{Dependencies: []string{baseServiceName}}) + dependentService, err := m.OpenService(dependentServiceName) + if err != nil { + t.Fatalf("OpenService failed: %v", err) + } + defer remove(t, dependentService) + + // test that both the base service and dependent service list the correct dependencies + dependentServices, err := baseService.ListDependentServices(svc.AnyActivity) + if err != nil { + t.Fatalf("baseService.ListDependentServices failed: %v", err) + } + if len(dependentServices) != 1 || dependentServices[0] != dependentServiceName { + t.Errorf("Found %v, instead of expected contents %s", dependentServices, dependentServiceName) + } + dependentServices, err = dependentService.ListDependentServices(svc.AnyActivity) + if err != nil { + t.Fatalf("dependentService.ListDependentServices failed: %v", err) + } + if len(dependentServices) != 0 { + t.Errorf("Found %v, where no service should be listed", dependentService) + } +} diff --git a/windows/svc/mgr/service.go b/windows/svc/mgr/service.go index 0623fc0b0..90f5d95d5 100644 --- a/windows/svc/mgr/service.go +++ b/windows/svc/mgr/service.go @@ -15,8 +15,6 @@ import ( "golang.org/x/sys/windows/svc" ) -// TODO(brainman): Use EnumDependentServices to enumerate dependent services. - // Service is used to access Windows service. type Service struct { Name string @@ -76,3 +74,44 @@ func (s *Service) Query() (svc.Status, error) { ServiceSpecificExitCode: t.ServiceSpecificExitCode, }, nil } + +// ListDependentServices returns the names of the services dependent on service s, which match the given status. +func (s *Service) ListDependentServices(status svc.ActivityStatus) ([]string, error) { + var bytesNeeded, returnedServiceCount uint32 + var services []windows.ENUM_SERVICE_STATUS + for { + var servicesPtr *windows.ENUM_SERVICE_STATUS + if len(services) > 0 { + servicesPtr = &services[0] + } + allocatedBytes := uint32(len(services)) * uint32(unsafe.Sizeof(windows.ENUM_SERVICE_STATUS{})) + err := windows.EnumDependentServices(s.Handle, uint32(status), servicesPtr, allocatedBytes, &bytesNeeded, + &returnedServiceCount) + if err == nil { + break + } + if err != syscall.ERROR_MORE_DATA { + return nil, err + } + if bytesNeeded <= allocatedBytes { + return nil, err + } + // ERROR_MORE_DATA indicates the provided buffer was too small, run the call again after resizing the buffer + requiredSliceLen := bytesNeeded / uint32(unsafe.Sizeof(windows.ENUM_SERVICE_STATUS{})) + if bytesNeeded%uint32(unsafe.Sizeof(windows.ENUM_SERVICE_STATUS{})) != 0 { + requiredSliceLen += 1 + } + services = make([]windows.ENUM_SERVICE_STATUS, requiredSliceLen) + } + if returnedServiceCount == 0 { + return nil, nil + } + + // The slice mutated by EnumDependentServices may have a length greater than returnedServiceCount, any elements + // past that should be ignored. + var dependents []string + for i := 0; i < int(returnedServiceCount); i++ { + dependents = append(dependents, windows.UTF16PtrToString(services[i].ServiceName)) + } + return dependents, nil +} diff --git a/windows/svc/service.go b/windows/svc/service.go index 806baa055..2b4a7bc6c 100644 --- a/windows/svc/service.go +++ b/windows/svc/service.go @@ -68,6 +68,15 @@ const ( AcceptPreShutdown = Accepted(windows.SERVICE_ACCEPT_PRESHUTDOWN) ) +// ActivityStatus allows for services to be selected based on active and inactive categories of service state. +type ActivityStatus uint32 + +const ( + Active = ActivityStatus(windows.SERVICE_ACTIVE) + Inactive = ActivityStatus(windows.SERVICE_INACTIVE) + AnyActivity = ActivityStatus(windows.SERVICE_STATE_ALL) +) + // Status combines State and Accepted commands to fully describe running service. type Status struct { State State diff --git a/windows/zsyscall_windows.go b/windows/zsyscall_windows.go index 6d2a26853..a81ea2c70 100644 --- a/windows/zsyscall_windows.go +++ b/windows/zsyscall_windows.go @@ -86,6 +86,7 @@ var ( procDeleteService = modadvapi32.NewProc("DeleteService") procDeregisterEventSource = modadvapi32.NewProc("DeregisterEventSource") procDuplicateTokenEx = modadvapi32.NewProc("DuplicateTokenEx") + procEnumDependentServicesW = modadvapi32.NewProc("EnumDependentServicesW") procEnumServicesStatusExW = modadvapi32.NewProc("EnumServicesStatusExW") procEqualSid = modadvapi32.NewProc("EqualSid") procFreeSid = modadvapi32.NewProc("FreeSid") @@ -734,6 +735,14 @@ func DuplicateTokenEx(existingToken Token, desiredAccess uint32, tokenAttributes return } +func EnumDependentServices(service Handle, activityState uint32, services *ENUM_SERVICE_STATUS, buffSize uint32, bytesNeeded *uint32, servicesReturned *uint32) (err error) { + r1, _, e1 := syscall.Syscall6(procEnumDependentServicesW.Addr(), 6, uintptr(service), uintptr(activityState), uintptr(unsafe.Pointer(services)), uintptr(buffSize), uintptr(unsafe.Pointer(bytesNeeded)), uintptr(unsafe.Pointer(servicesReturned))) + if r1 == 0 { + err = errnoErr(e1) + } + return +} + func EnumServicesStatusEx(mgr Handle, infoLevel uint32, serviceType uint32, serviceState uint32, services *byte, bufSize uint32, bytesNeeded *uint32, servicesReturned *uint32, resumeHandle *uint32, groupName *uint16) (err error) { r1, _, e1 := syscall.Syscall12(procEnumServicesStatusExW.Addr(), 10, uintptr(mgr), uintptr(infoLevel), uintptr(serviceType), uintptr(serviceState), uintptr(unsafe.Pointer(services)), uintptr(bufSize), uintptr(unsafe.Pointer(bytesNeeded)), uintptr(unsafe.Pointer(servicesReturned)), uintptr(unsafe.Pointer(resumeHandle)), uintptr(unsafe.Pointer(groupName)), 0, 0) if r1 == 0 {