diff --git a/probe/endpoint/reporter.go b/probe/endpoint/reporter.go index c2fea6facf..049659b131 100644 --- a/probe/endpoint/reporter.go +++ b/probe/endpoint/reporter.go @@ -26,7 +26,7 @@ type Reporter struct { includeNAT bool conntracker *Conntracker natmapper *natmapper - revResolver *ReverseResolver + revResolver *reverseResolver } // SpyDuration is an exported prometheus metric @@ -66,7 +66,7 @@ func NewReporter(hostID, hostName string, includeProcesses bool, useConntrack bo } } - revRes := NewReverseResolver(rAddrCacheLen) + revRes := newReverseResolver() return &Reporter{ hostID: hostID, @@ -152,7 +152,7 @@ func (r *Reporter) addConnection(rpt *report.Report, localAddr, remoteAddr strin ) // in case we have a reverse resolution for the IP, we can use it for the name... - if revRemoteName, err := r.revResolver.Get(remoteAddr); err == nil { + if revRemoteName, err := r.revResolver.Get(remoteAddr, false); err == nil { remoteNode = remoteNode.AddMetadata(map[string]string{ "name": revRemoteName, }) @@ -191,7 +191,7 @@ func (r *Reporter) addConnection(rpt *report.Report, localAddr, remoteAddr strin ) // in case we have a reverse resolution for the IP, we can use it for the name... - if revRemoteName, err := r.revResolver.Get(remoteAddr); err == nil { + if revRemoteName, err := r.revResolver.Get(remoteAddr, false); err == nil { remoteNode = remoteNode.AddMetadata(map[string]string{ "name": revRemoteName, }) diff --git a/probe/endpoint/resolver.go b/probe/endpoint/resolver.go index 99ca8ff7ff..40b0a05a4c 100644 --- a/probe/endpoint/resolver.go +++ b/probe/endpoint/resolver.go @@ -4,6 +4,8 @@ import ( "net" "time" + "strings" + "github.com/bluele/gcache" ) @@ -13,60 +15,72 @@ const ( rAddrCacheExpiration = 30 * time.Minute ) +type revResFunc func(addr string) (names []string, err error) + +type revResRequest struct { + address string + done chan struct{} +} + // ReverseResolver is a caching, reverse resolver -type ReverseResolver struct { - addresses chan string +type reverseResolver struct { + addresses chan revResRequest cache gcache.Cache + resolver revResFunc } // NewReverseResolver starts a new reverse resolver that // performs reverse resolutions and caches the result. -func NewReverseResolver(cacheLen int) *ReverseResolver { - r := ReverseResolver{ - addresses: make(chan string, rAddrBacklog), - cache: gcache.New(cacheLen).LRU().Expiration(rAddrCacheExpiration).Build(), +func newReverseResolver() *reverseResolver { + r := reverseResolver{ + addresses: make(chan revResRequest, rAddrBacklog), + cache: gcache.New(rAddrCacheLen).LRU().Expiration(rAddrCacheExpiration).Build(), + resolver: net.LookupAddr, } - - go r.run() + go r.loop() return &r } -// Get the reverse resolution for an IP address if already in the cache, gcache.NotFoundKeyError otherwise -// Note: it returns one of the possible names that can be obtained for that IP -func (r *ReverseResolver) Get(address string) (string, error) { +// Get the reverse resolution for an IP address if already in the cache, +// a gcache.NotFoundKeyError error otherwise. +// Note: it returns one of the possible names that can be obtained for that IP. +func (r *reverseResolver) Get(address string, wait bool) (string, error) { val, err := r.cache.Get(address) if err == nil { return val.(string), nil } if err == gcache.NotFoundKeyError { + request := revResRequest{address: address, done: make(chan struct{})} // we trigger a asynchronous reverse resolution when not cached select { - case r.addresses <- address: + case r.addresses <- request: + if wait { + <-request.done + } default: } } return "", err } -func (r *ReverseResolver) run() { +func (r *reverseResolver) loop() { throttle := time.Tick(time.Second / 10) - for address := range r.addresses { - <-throttle // rate limit our DNS resolutions - _, err := r.cache.Get(address) // and check if the answer is already in the cache - if err == nil { - continue - } - names, err := net.LookupAddr(address) - if err != nil { + for request := range r.addresses { + <-throttle // rate limit our DNS resolutions + // and check if the answer is already in the cache + if _, err := r.cache.Get(request.address); err == nil { continue } - if len(names) > 0 { - r.cache.Set(address, names[0]) + names, err := r.resolver(request.address) + if err == nil && len(names) > 0 { + name := strings.TrimRight(names[0], ".") + r.cache.Set(request.address, name) } + close(request.done) } } // Stop the async reverse resolver -func (r *ReverseResolver) Stop() { +func (r *reverseResolver) Stop() { close(r.addresses) } diff --git a/probe/endpoint/resolver_internal_test.go b/probe/endpoint/resolver_internal_test.go new file mode 100644 index 0000000000..42407d6a00 --- /dev/null +++ b/probe/endpoint/resolver_internal_test.go @@ -0,0 +1,41 @@ +package endpoint + +import ( + "errors" + "testing" +) + +func TestReverseResolver(t *testing.T) { + tests := map[string]string{ + "8.8.8.8": "google-public-dns-a.google.com", + "8.8.4.4": "google-public-dns-b.google.com", + } + + revRes := newReverseResolver() + + // use a mocked resolver function + revRes.resolver = func(addr string) (names []string, err error) { + if name, ok := tests[addr]; ok { + return []string{name}, nil + } + return []string{}, errors.New("invalid IP") + } + + // first time: no names are returned for our reverse resolutions + for ip, _ := range tests { + if have, err := revRes.Get(ip, true); have != "" || err == nil { + t.Errorf("we didn't get an error, or the cache was not empty, when trying to resolve '%q'", ip) + } + } + + // so, if we check again these IPs, we should have the names now + for ip, want := range tests { + have, err := revRes.Get(ip, true) + if err != nil { + t.Errorf("%s: %v", ip, err) + } + if want != have { + t.Errorf("%s: want %q, have %q", ip, want, have) + } + } +} diff --git a/render/detailed_node.go b/render/detailed_node.go index fb1aaa3984..6e40e35321 100644 --- a/render/detailed_node.go +++ b/render/detailed_node.go @@ -207,7 +207,7 @@ func connectionDetailsRows(topology report.Topology, originID string) []Row { rows := []Row{} labeler := func(nodeID string, meta map[string]string) (string, bool) { if _, addr, port, ok := report.ParseEndpointNodeID(nodeID); ok { - if name, found := meta["name"]; found { + if name, ok := meta["name"]; ok { return fmt.Sprintf("%s:%s", name, port), true } return fmt.Sprintf("%s:%s", addr, port), true