Skip to content

Commit

Permalink
Merge pull request #4 from vokomarov/experimental
Browse files Browse the repository at this point in the history
Refactor scanner to use context instead of a sets of channels
  • Loading branch information
vokomarov committed Mar 9, 2020
2 parents 1e98cf7 + ce2813b commit f167532
Show file tree
Hide file tree
Showing 5 changed files with 173 additions and 214 deletions.
3 changes: 2 additions & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ install:
- curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s -- -b $(go env GOPATH)/bin v1.23.6

script:
- go test ./... -v -covermode=count -coverprofile=coverage.out
- go mod download
- sudo $(which go) test ./... -v -covermode=count -coverprofile=coverage.out
- cat coverage.out | grep -v "main.go" | grep -v "hosts.go" | grep -v "ports.go" > cover.out
- goveralls -coverprofile=cover.out -service=travis-ci
- golangci-lint run
Expand Down
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
.PHONY: test

test:
go test -v ./...
sudo go test ./... -v -covermode=count
44 changes: 13 additions & 31 deletions hosts.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package main

import (
"context"
"fmt"
"os"
"os/signal"
Expand All @@ -21,45 +22,26 @@ func (c *hostsCommand) Execute(_ []string) error {
signal.Notify(quit, os.Interrupt)

scanner := host.NewScanner()
_, stopFunc := scanner.Ctx(context.Background())

go scanner.Scan()

timeout := time.NewTicker(time.Duration(c.Timeout) * time.Second)
go func() {
timeout := time.NewTicker(time.Duration(c.Timeout) * time.Second)

func() {
for {
fmt.Printf("found %d hosts...\r", len(scanner.Hosts))

time.Sleep(1 * time.Second)

select {
case <-timeout.C:
scanner.Stop()
return
case <-quit:
scanner.Stop()
return
case <-scanner.Done:

return
default:
}
select {
case <-timeout.C:
stopFunc()
return
case <-quit:
stopFunc()
return
}
}()

// Clear line
fmt.Printf("%c[2K\r", 27)

if scanner.Error != nil {
fmt.Printf("\n\r")
return scanner.Error
}

fmt.Printf("\nFound %d hosts: \n", len(scanner.Hosts))

for _, h := range scanner.Hosts {
for h := range scanner.Hosts() {
fmt.Printf(" [IP: %s] \t [MAC: %s] \n", h.IP, h.MAC)
}

return nil
return scanner.Error
}
155 changes: 71 additions & 84 deletions scanner/host/scanner.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package host

import (
"bytes"
"context"
"encoding/binary"
"fmt"
"net"
Expand All @@ -16,79 +17,64 @@ import (
// Scanner provide container for control local network scanning
// process and checking results
type Scanner struct {
mu sync.RWMutex
unique map[string]bool
Hosts []*Host
started bool
stopped bool
stop chan struct{}
Done chan struct{}
Error error
mu sync.RWMutex
ctx context.Context
cancelFunc context.CancelFunc
found chan *Host
unique map[string]bool
Error error
}

// NewScanner will initialise new instance of Scanner
func NewScanner() *Scanner {
ctx, cancelFunc := context.WithCancel(context.Background())

return &Scanner{
mu: sync.RWMutex{},
started: false,
stopped: false,
stop: make(chan struct{}),
unique: make(map[string]bool),
Hosts: make([]*Host, 0),
Done: make(chan struct{}),
mu: sync.RWMutex{},
ctx: ctx,
cancelFunc: cancelFunc,
found: make(chan *Host),
unique: make(map[string]bool),
}
}

// Stop perform manually stopping of scan process with blocking
// until stopping is not finished in case of scanning already started
// Safe to call before or after scanning started or stopped
func (s *Scanner) Stop() {
if s.started {
s.stop <- struct{}{}
<-s.Done
}

if !s.stopped {
close(s.stop)
// Ctx wrap given context and return new with cancel func
func (s *Scanner) Ctx(ctx context.Context) (context.Context, context.CancelFunc) {
s.ctx, s.cancelFunc = context.WithCancel(ctx)

if !s.started {
close(s.Done)
}
}

s.stopped = true
return s.ctx, s.cancelFunc
}

func (s *Scanner) finish(err error) {
if err != nil {
s.Error = err
}

if s.started && !s.stopped {
s.Done <- struct{}{}
close(s.Done)
}
// Hosts will return a read only channel to receive found Host
func (s *Scanner) Hosts() <-chan *Host {
return s.found
}

func (s *Scanner) hasHost(host *Host) bool {
s.mu.RLock()
defer s.mu.RUnlock()
func (s *Scanner) fail(err error) {
s.mu.Lock()
defer s.mu.Unlock()

if _, ok := s.unique[host.ID()]; ok {
return true
}
s.Error = err

return false
if s.ctx.Err() == nil {
s.cancelFunc()
}
}

func (s *Scanner) addHost(host *Host) *Scanner {
func (s *Scanner) foundHost(host *Host) bool {
s.mu.Lock()
defer s.mu.Unlock()

s.Hosts = append(s.Hosts, host)
s.unique[host.ID()] = true
if s.ctx.Err() != nil {
return false
}

if _, ok := s.unique[host.ID()]; !ok {
s.unique[host.ID()] = true
s.found <- host
}

return s
return true
}

// Scan will detect system interfaces and go over each one to detect
Expand All @@ -98,7 +84,7 @@ func (s *Scanner) addHost(host *Host) *Scanner {
func (s *Scanner) Scan() {
interfaces, err := net.Interfaces()
if err != nil {
s.finish(err)
s.fail(err)
return
}

Expand All @@ -111,16 +97,15 @@ func (s *Scanner) Scan() {
defer wg.Done()

if err := s.scanInterface(&iface); err != nil {
s.finish(fmt.Errorf("interface [%v] error: %w", iface.Name, err))
s.fail(fmt.Errorf("interface [%v] error: %w", iface.Name, err))
return
}
}(interfaces[i])
}

// Wait for all interfaces' scans to complete. They'll try to run
// forever, but will stop on an error, so if we get past this Wait
// it means all attempts to write have failed.
wg.Wait()

close(s.found)
}

// Scans an individual interface's local network for machines using ARP requests/replies.
Expand All @@ -131,19 +116,22 @@ func (s *Scanner) scanInterface(iface *net.Interface) error {
// We just look for IPv4 addresses, so try to find if the interface has one.
var addr *net.IPNet

if addresses, err := iface.Addrs(); err != nil {
addresses, err := iface.Addrs()
if err != nil {
return err
} else {
for _, a := range addresses {
if IPNet, ok := a.(*net.IPNet); ok {
if IPv4 := IPNet.IP.To4(); IPv4 != nil {
addr = &net.IPNet{
IP: IPv4,
Mask: IPNet.Mask[len(IPNet.Mask)-4:],
}

break
}
}

for _, a := range addresses {
if IPNet, ok := a.(*net.IPNet); ok {
IPv4 := IPNet.IP.To4()

if IPv4 == nil {
continue
}

addr = &net.IPNet{
IP: IPv4,
Mask: IPNet.Mask[len(IPNet.Mask)-4:],
}
}
}
Expand Down Expand Up @@ -176,29 +164,30 @@ func (s *Scanner) scanInterface(iface *net.Interface) error {
// We don't know exactly how long it'll take for packets to be
// sent back to us, but 10 seconds should be more than enough
// time ;)
time.Sleep(10 * time.Second)

timeout := time.NewTicker(10 * time.Second)

select {
case <-timeout.C:
continue
case <-s.ctx.Done():
return nil
}
}
}

// Watches a handle for incoming ARP responses we might care about.
// Push new Host once any correct response received
// Work until 'stop' is closed.
func (s *Scanner) listenARP(handle *pcap.Handle, iface *net.Interface) {
s.mu.Lock()
if !s.started {
s.started = true
}
s.mu.Unlock()

src := gopacket.NewPacketSource(handle, layers.LayerTypeEthernet)
in := src.Packets()

for {
var packet gopacket.Packet

select {
case <-s.stop:
s.finish(nil)
case <-s.ctx.Done():
return
case packet = <-in:
arpLayer := packet.Layer(layers.LayerTypeARP)
Expand All @@ -217,13 +206,11 @@ func (s *Scanner) listenARP(handle *pcap.Handle, iface *net.Interface) {
// Note: we might get some packets here that aren't responses to ones we've sent,
// if for example someone else sends US an ARP request. Doesn't much matter, though...
// all information is good information :)
host := Host{
if !s.foundHost(&Host{
IP: fmt.Sprintf("%v", net.IP(arp.SourceProtAddress)),
MAC: fmt.Sprintf("%v", net.HardwareAddr(arp.SourceHwAddress)),
}

if !s.hasHost(&host) {
s.addHost(&host)
}) {
return
}
}
}
Expand Down Expand Up @@ -277,7 +264,7 @@ func writeARP(handle *pcap.Handle, iface *net.Interface, addr *net.IPNet) error

// ips is a simple and not very good method for getting all IPv4 addresses from a
// net.IPNet. It returns all IPs it can over the channel it sends back, closing
// the channel when done.
// the channel when fail.
func ips(n *net.IPNet) (out []net.IP) {
num := binary.BigEndian.Uint32([]byte(n.IP))
mask := binary.BigEndian.Uint32([]byte(n.Mask))
Expand Down
Loading

0 comments on commit f167532

Please sign in to comment.