Skip to content

Commit

Permalink
use timers instead of sleeps when probing
Browse files Browse the repository at this point in the history
  • Loading branch information
marten-seemann committed Sep 27, 2021
1 parent bed14e1 commit f91e31c
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 17 deletions.
47 changes: 30 additions & 17 deletions server.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package zeroconf

import (
"errors"
"fmt"
"log"
"math/rand"
Expand Down Expand Up @@ -189,14 +188,10 @@ func (s *Server) start() {
s.refCount.Add(1)
go s.recv6(s.ipv6conn)
}
s.refCount.Add(1)
go s.probe()
}

// Shutdown closes all udp connections and unregisters the service
func (s *Server) Shutdown() {
s.shutdown()
}

// SetText updates and announces the TXT records
func (s *Server) SetText(text []string) {
s.service.Text = text
Expand All @@ -208,15 +203,17 @@ func (s *Server) TTL(ttl uint32) {
s.ttl = ttl
}

// Shutdown server will close currently open connections & channel
func (s *Server) shutdown() error {
// Shutdown closes all udp connections and unregisters the service
func (s *Server) Shutdown() {
s.shutdownLock.Lock()
defer s.shutdownLock.Unlock()
if s.isShutdown {
return errors.New("server is already shutdown")
return
}

err := s.unregister()
if err := s.unregister(); err != nil {
log.Printf("failed to unregister: %s", err)
}

close(s.shouldShutdown)

Expand All @@ -230,8 +227,6 @@ func (s *Server) shutdown() error {
// Wait for connection and routines to be closed
s.refCount.Wait()
s.isShutdown = true

return err
}

// recv4 is a long running routine to receive packets from an interface
Expand Down Expand Up @@ -526,6 +521,8 @@ func (s *Server) serviceTypeName(resp *dns.Msg, ttl uint32) {
// Perform probing & announcement
//TODO: implement a proper probing & conflict resolution
func (s *Server) probe() {
defer s.refCount.Done()

q := new(dns.Msg)
q.SetQuestion(s.service.ServiceInstanceName(), dns.TypePTR)
q.RecursionDesired = false
Expand Down Expand Up @@ -555,12 +552,23 @@ func (s *Server) probe() {

// Wait for a random duration uniformly distributed between 0 and 250 ms
// before sending the first probe packet.
time.Sleep(time.Duration(rand.Intn(250)) * time.Millisecond)
timer := time.NewTimer(time.Duration(rand.Intn(250)) * time.Millisecond)
defer timer.Stop()
select {
case <-timer.C:
case <-s.shouldShutdown:
return
}
for i := 0; i < 3; i++ {
if err := s.multicastResponse(q, 0); err != nil {
log.Println("[ERR] zeroconf: failed to send probe:", err.Error())
}
time.Sleep(250 * time.Millisecond)
timer.Reset(250 * time.Millisecond)
select {
case <-timer.C:
case <-s.shouldShutdown:
return
}
}

// From RFC6762
Expand All @@ -569,7 +577,7 @@ func (s *Server) probe() {
// packet loss, a responder MAY send up to eight unsolicited responses,
// provided that the interval between unsolicited responses increases by
// at least a factor of two with every response sent.
timeout := 1 * time.Second
timeout := time.Second
for i := 0; i < multicastRepetitions; i++ {
for _, intf := range s.ifaces {
resp := new(dns.Msg)
Expand All @@ -583,7 +591,12 @@ func (s *Server) probe() {
log.Println("[ERR] zeroconf: failed to send announcement:", err.Error())
}
}
time.Sleep(timeout)
timer.Reset(timeout)
select {
case <-timer.C:
case <-s.shouldShutdown:
return
}
timeout *= 2
}
}
Expand Down Expand Up @@ -715,7 +728,7 @@ func (s *Server) unicastResponse(resp *dns.Msg, ifIndex int, from net.Addr) erro
}
}

// multicastResponse us used to send a multicast response packet
// multicastResponse is used to send a multicast response packet
func (s *Server) multicastResponse(msg *dns.Msg, ifIndex int) error {
buf, err := msg.Pack()
if err != nil {
Expand Down
18 changes: 18 additions & 0 deletions service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,24 @@ func startMDNS(t *testing.T, port int, name, service, domain string) {
log.Printf("Published service: %s, type: %s, domain: %s", name, service, domain)
}

func TestQuickShutdown(t *testing.T) {
server, err := Register(mdnsName, mdnsService, mdnsDomain, mdnsPort, []string{"txtv=0", "lo=1", "la=2"}, nil)
if err != nil {
t.Fatal(err)
}

done := make(chan struct{})
go func() {
defer close(done)
server.Shutdown()
}()
select {
case <-done:
case <-time.After(500 * time.Millisecond):
t.Fatal("shutdown took longer than 500ms")
}
}

func TestBasic(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
Expand Down

0 comments on commit f91e31c

Please sign in to comment.