diff --git a/.travis.yml b/.travis.yml index 8df6a2c..f338973 100644 --- a/.travis.yml +++ b/.travis.yml @@ -12,7 +12,7 @@ install: - curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s -- -b $(go env GOPATH)/bin v1.23.6 script: - - go test ./... -v -covermode=count -coverprofile=coverage.out + - sudo go test ./... -v -covermode=count -coverprofile=coverage.out - cat coverage.out | grep -v "main.go" | grep -v "hosts.go" | grep -v "ports.go" > cover.out - goveralls -coverprofile=cover.out -service=travis-ci - golangci-lint run diff --git a/Makefile b/Makefile index ae01790..4b5e265 100644 --- a/Makefile +++ b/Makefile @@ -2,4 +2,4 @@ .PHONY: test test: - go test -v ./... + sudo go test ./scanner/... -v -covermode=count diff --git a/hosts.go b/hosts.go index bb37293..48cad26 100644 --- a/hosts.go +++ b/hosts.go @@ -1,6 +1,7 @@ package main import ( + "context" "fmt" "os" "os/signal" @@ -21,45 +22,26 @@ func (c *hostsCommand) Execute(_ []string) error { signal.Notify(quit, os.Interrupt) scanner := host.NewScanner() + _, stopFunc := scanner.Ctx(context.Background()) go scanner.Scan() - timeout := time.NewTicker(time.Duration(c.Timeout) * time.Second) + go func() { + timeout := time.NewTicker(time.Duration(c.Timeout) * time.Second) - func() { - for { - fmt.Printf("found %d hosts...\r", len(scanner.Hosts)) - - time.Sleep(1 * time.Second) - - select { - case <-timeout.C: - scanner.Stop() - return - case <-quit: - scanner.Stop() - return - case <-scanner.Done: - - return - default: - } + select { + case <-timeout.C: + stopFunc() + return + case <-quit: + stopFunc() + return } }() - // Clear line - fmt.Printf("%c[2K\r", 27) - - if scanner.Error != nil { - fmt.Printf("\n\r") - return scanner.Error - } - - fmt.Printf("\nFound %d hosts: \n", len(scanner.Hosts)) - - for _, h := range scanner.Hosts { + for h := range scanner.Hosts() { fmt.Printf(" [IP: %s] \t [MAC: %s] \n", h.IP, h.MAC) } - return nil + return scanner.Error } diff --git a/scanner/host/scanner.go b/scanner/host/scanner.go index da4671c..bc0e5f4 100644 --- a/scanner/host/scanner.go +++ b/scanner/host/scanner.go @@ -2,6 +2,7 @@ package host import ( "bytes" + "context" "encoding/binary" "fmt" "net" @@ -16,79 +17,64 @@ import ( // Scanner provide container for control local network scanning // process and checking results type Scanner struct { - mu sync.RWMutex - unique map[string]bool - Hosts []*Host - started bool - stopped bool - stop chan struct{} - Done chan struct{} - Error error + mu sync.RWMutex + ctx context.Context + cancelFunc context.CancelFunc + found chan *Host + unique map[string]bool + Error error } // NewScanner will initialise new instance of Scanner func NewScanner() *Scanner { + ctx, cancelFunc := context.WithCancel(context.Background()) + return &Scanner{ - mu: sync.RWMutex{}, - started: false, - stopped: false, - stop: make(chan struct{}), - unique: make(map[string]bool), - Hosts: make([]*Host, 0), - Done: make(chan struct{}), + mu: sync.RWMutex{}, + ctx: ctx, + cancelFunc: cancelFunc, + found: make(chan *Host), + unique: make(map[string]bool), } } -// Stop perform manually stopping of scan process with blocking -// until stopping is not finished in case of scanning already started -// Safe to call before or after scanning started or stopped -func (s *Scanner) Stop() { - if s.started { - s.stop <- struct{}{} - <-s.Done - } - - if !s.stopped { - close(s.stop) +// Ctx wrap given context and return new with cancel func +func (s *Scanner) Ctx(ctx context.Context) (context.Context, context.CancelFunc) { + s.ctx, s.cancelFunc = context.WithCancel(ctx) - if !s.started { - close(s.Done) - } - } - - s.stopped = true + return s.ctx, s.cancelFunc } -func (s *Scanner) finish(err error) { - if err != nil { - s.Error = err - } - - if s.started && !s.stopped { - s.Done <- struct{}{} - close(s.Done) - } +// Hosts will return a read only channel to receive found Host +func (s *Scanner) Hosts() <-chan *Host { + return s.found } -func (s *Scanner) hasHost(host *Host) bool { - s.mu.RLock() - defer s.mu.RUnlock() +func (s *Scanner) fail(err error) { + s.mu.Lock() + defer s.mu.Unlock() - if _, ok := s.unique[host.ID()]; ok { - return true - } + s.Error = err - return false + if s.ctx.Err() == nil { + s.cancelFunc() + } } -func (s *Scanner) addHost(host *Host) *Scanner { +func (s *Scanner) foundHost(host *Host) bool { s.mu.Lock() defer s.mu.Unlock() - s.Hosts = append(s.Hosts, host) - s.unique[host.ID()] = true + if s.ctx.Err() != nil { + return false + } + + if _, ok := s.unique[host.ID()]; !ok { + s.unique[host.ID()] = true + s.found <- host + } - return s + return true } // Scan will detect system interfaces and go over each one to detect @@ -98,7 +84,7 @@ func (s *Scanner) addHost(host *Host) *Scanner { func (s *Scanner) Scan() { interfaces, err := net.Interfaces() if err != nil { - s.finish(err) + s.fail(err) return } @@ -111,16 +97,15 @@ func (s *Scanner) Scan() { defer wg.Done() if err := s.scanInterface(&iface); err != nil { - s.finish(fmt.Errorf("interface [%v] error: %w", iface.Name, err)) + s.fail(fmt.Errorf("interface [%v] error: %w", iface.Name, err)) return } }(interfaces[i]) } - // Wait for all interfaces' scans to complete. They'll try to run - // forever, but will stop on an error, so if we get past this Wait - // it means all attempts to write have failed. wg.Wait() + + close(s.found) } // Scans an individual interface's local network for machines using ARP requests/replies. @@ -131,19 +116,22 @@ func (s *Scanner) scanInterface(iface *net.Interface) error { // We just look for IPv4 addresses, so try to find if the interface has one. var addr *net.IPNet - if addresses, err := iface.Addrs(); err != nil { + addresses, err := iface.Addrs() + if err != nil { return err - } else { - for _, a := range addresses { - if IPNet, ok := a.(*net.IPNet); ok { - if IPv4 := IPNet.IP.To4(); IPv4 != nil { - addr = &net.IPNet{ - IP: IPv4, - Mask: IPNet.Mask[len(IPNet.Mask)-4:], - } - - break - } + } + + for _, a := range addresses { + if IPNet, ok := a.(*net.IPNet); ok { + IPv4 := IPNet.IP.To4() + + if IPv4 == nil { + continue + } + + addr = &net.IPNet{ + IP: IPv4, + Mask: IPNet.Mask[len(IPNet.Mask)-4:], } } } @@ -176,7 +164,15 @@ func (s *Scanner) scanInterface(iface *net.Interface) error { // We don't know exactly how long it'll take for packets to be // sent back to us, but 10 seconds should be more than enough // time ;) - time.Sleep(10 * time.Second) + + timeout := time.NewTicker(10 * time.Second) + + select { + case <-timeout.C: + continue + case <-s.ctx.Done(): + return nil + } } } @@ -184,12 +180,6 @@ func (s *Scanner) scanInterface(iface *net.Interface) error { // Push new Host once any correct response received // Work until 'stop' is closed. func (s *Scanner) listenARP(handle *pcap.Handle, iface *net.Interface) { - s.mu.Lock() - if !s.started { - s.started = true - } - s.mu.Unlock() - src := gopacket.NewPacketSource(handle, layers.LayerTypeEthernet) in := src.Packets() @@ -197,8 +187,7 @@ func (s *Scanner) listenARP(handle *pcap.Handle, iface *net.Interface) { var packet gopacket.Packet select { - case <-s.stop: - s.finish(nil) + case <-s.ctx.Done(): return case packet = <-in: arpLayer := packet.Layer(layers.LayerTypeARP) @@ -217,13 +206,11 @@ func (s *Scanner) listenARP(handle *pcap.Handle, iface *net.Interface) { // Note: we might get some packets here that aren't responses to ones we've sent, // if for example someone else sends US an ARP request. Doesn't much matter, though... // all information is good information :) - host := Host{ + if !s.foundHost(&Host{ IP: fmt.Sprintf("%v", net.IP(arp.SourceProtAddress)), MAC: fmt.Sprintf("%v", net.HardwareAddr(arp.SourceHwAddress)), - } - - if !s.hasHost(&host) { - s.addHost(&host) + }) { + return } } } @@ -277,7 +264,7 @@ func writeARP(handle *pcap.Handle, iface *net.Interface, addr *net.IPNet) error // ips is a simple and not very good method for getting all IPv4 addresses from a // net.IPNet. It returns all IPs it can over the channel it sends back, closing -// the channel when done. +// the channel when fail. func ips(n *net.IPNet) (out []net.IP) { num := binary.BigEndian.Uint32([]byte(n.IP)) mask := binary.BigEndian.Uint32([]byte(n.Mask)) diff --git a/scanner/host/scanner_test.go b/scanner/host/scanner_test.go index 6a1efd8..8dc625c 100644 --- a/scanner/host/scanner_test.go +++ b/scanner/host/scanner_test.go @@ -1,6 +1,9 @@ package host import ( + "context" + "fmt" + "strings" "testing" "time" ) @@ -16,169 +19,155 @@ func TestNewScanner(t *testing.T) { t.Errorf("Unique host registry is not initialised") } - if scanner.started == true { - t.Errorf("Scanner has wrong started flag") + if scanner.ctx == nil { + t.Errorf("Context is not initialised") } - if scanner.stopped == true { - t.Errorf("Scanner has wrong stopped flag") + if scanner.cancelFunc == nil { + t.Errorf("Cancel func is not initialised") } - if scanner.Hosts == nil { - t.Errorf("Host storage is not initialised") + if scanner.found == nil { + t.Errorf("Found channel is not initialised") } - if scanner.Done == nil { - t.Errorf("Done channel is not created") - } - - if scanner.stop == nil { - t.Errorf("Stop channel is not created") + if scanner.Error != nil { + t.Errorf("Error is not empty") } } -func TestScanner_StopEmpty(t *testing.T) { +func TestScanner_Ctx(t *testing.T) { scanner := NewScanner() if scanner == nil { t.Errorf("Scanner instance is empty") return } - scanner.Stop() - - select { - case <-scanner.stop: - default: - t.Errorf("stop channel must be closed once Stop method called") + ctx, cancelFunc := scanner.Ctx(context.Background()) + if ctx == nil { + t.Errorf("Wrapped context is empty") + return } -} -func TestScanner_StopStartedStopped(t *testing.T) { - scanner := NewScanner() - if scanner == nil { - t.Errorf("Scanner instance is empty") + if cancelFunc == nil { + t.Errorf("Cancel func is empty") return } - go scanner.Scan() - scanner.Stop() - - select { - case <-scanner.stop: - default: - t.Errorf("stop channel must be closed once Stop method called") + if scanner.ctx != ctx { + t.Errorf("Context is not equal") } - select { - case <-scanner.Done: - default: - t.Errorf("done channel must be closed once Stop method finished") + if scanner.ctx.Err() != nil { + t.Errorf("Context is already closed") } - scanner.Stop() + cancelFunc() - select { - case <-scanner.stop: - default: - t.Errorf("stop channel must be closed once Stop method called") - } - - select { - case <-scanner.Done: - default: - t.Errorf("done channel must be closed once Stop method finished") + if scanner.ctx.Err() == nil { + t.Errorf("Context is not closed") } } -func TestScanner_StopWorking(t *testing.T) { +func TestScanner_Hosts(t *testing.T) { scanner := NewScanner() if scanner == nil { t.Errorf("Scanner instance is empty") return } - // simulate fake scanner - go func(s *Scanner) { - s.started = true - <-s.stop - s.finish(nil) - }(scanner) + channel := scanner.Hosts() + if channel == nil { + t.Errorf("Result hosts channel is empty") + } - time.Sleep(1 * time.Millisecond) + host := Host{ + IP: "127.0.0.1", + MAC: "ff:ff:ff:ff:ff:ff", + } - scanner.Stop() + go func() { + for host := range channel { + if host == nil { + t.Errorf("Received nil host") + return + } - select { - case <-scanner.stop: - default: - t.Errorf("stop channel must be closed once Stop method called") - } + if host.IP != "127.0.0.1" { + t.Errorf("Received host IP is not same as input") + return + } - select { - case <-scanner.Done: - default: - t.Errorf("done channel must be closed once Stop method finished") - } + if host.MAC != "ff:ff:ff:ff:ff:ff" { + t.Errorf("Received host IP is not same as input") + return + } + } + }() + + scanner.foundHost(&host) + close(scanner.found) } -func TestScanner_AddHost(t *testing.T) { +func TestScanner_Scan(t *testing.T) { scanner := NewScanner() if scanner == nil { t.Errorf("Scanner instance is empty") return } - host := Host{ - IP: "127.0.0.1", - MAC: "ff:ff:ff:ff:ff:ff", - } + _, stopFunc := scanner.Ctx(context.Background()) - scanner.addHost(&host) - - if len(scanner.Hosts) != 1 { - t.Errorf("Host is not added") - } + go scanner.Scan() - if host.ID() != scanner.Hosts[0].ID() { - t.Errorf("Host is added but changed") - } + go func() { + <-time.NewTicker(2 * time.Second).C + stopFunc() + }() - if host.IP != scanner.Hosts[0].IP { - t.Errorf("Host is added but changed IP") - } + hosts := make([]*Host, 0) - if host.MAC != scanner.Hosts[0].MAC { - t.Errorf("Host is added but changed MAC") + for h := range scanner.Hosts() { + hosts = append(hosts, h) } - if len(scanner.unique) != 1 { - t.Errorf("Host is not registered to unique registry") + if len(scanner.unique) != len(hosts) { + t.Errorf("Some of found hosts is not registered") } - if _, ok := scanner.unique[host.ID()]; !ok { - t.Errorf("Host is not registered to unique registry") + if scanner.foundHost(&Host{}) { + t.Errorf("Registering host after stop scanning must be failed") } } -func TestScanner_HasHost(t *testing.T) { +func TestScanner_ScanFail(t *testing.T) { scanner := NewScanner() if scanner == nil { t.Errorf("Scanner instance is empty") return } - host := Host{ - IP: "127.0.0.1", - MAC: "ff:ff:ff:ff:ff:ff", - } + err := fmt.Errorf("test error") - if scanner.hasHost(&host) { - t.Errorf("Host is wrongly detected as already added") + go scanner.Scan() + + go func() { + <-time.NewTicker(2 * time.Second).C + scanner.fail(err) + }() + + for range scanner.Hosts() { } - scanner.addHost(&host) + if scanner.Error != nil { + if strings.Contains(scanner.Error.Error(), "permission") { + t.Skipf("run tests with sudo to allow scan interface") + return + } - if !scanner.hasHost(&host) { - t.Errorf("Host is not registered as already addded") + if scanner.Error != err { + t.Errorf("Error is not propagated") + } } + }