diff --git a/probe/endpoint/conntrack.go b/probe/endpoint/conntrack.go
index e853855bcc..520426a351 100644
--- a/probe/endpoint/conntrack.go
+++ b/probe/endpoint/conntrack.go
@@ -4,6 +4,7 @@ import (
 	"bufio"
 	"encoding/xml"
 	"fmt"
+	"io"
 	"log"
 	"os"
 	"strings"
@@ -59,21 +60,28 @@ type Flow struct {
 	Original, Reply, Independent *Meta `xml:"-"`
 }
 
+type conntrack struct {
+	XMLName xml.Name `xml:"conntrack"`
+	Flows   []Flow   `xml:"flow"`
+}
+
 // Conntracker uses the conntrack command to track network connections
 type Conntracker struct {
 	sync.Mutex
 	cmd           exec.Cmd
 	activeFlows   map[int64]Flow // active flows in state != TIME_WAIT
 	bufferedFlows []Flow         // flows coming out of activeFlows spend 1 walk cycle here
+	existingConns bool
 }
 
 // NewConntracker creates and starts a new Conntracter
-func NewConntracker(args ...string) (*Conntracker, error) {
+func NewConntracker(existingConns bool, args ...string) (*Conntracker, error) {
 	if !ConntrackModulePresent() {
 		return nil, fmt.Errorf("No conntrack module")
 	}
 	result := &Conntracker{
-		activeFlows: map[int64]Flow{},
+		activeFlows:   map[int64]Flow{},
+		existingConns: existingConns,
 	}
 	go result.run(args...)
 	return result, nil
@@ -105,6 +113,19 @@ var ConntrackModulePresent = func() bool {
 
 // NB this is not re-entrant!
 func (c *Conntracker) run(args ...string) {
+	if c.existingConns {
+		// Fork another conntrack, just to capture existing connections
+		// for which we don't get events
+		existingFlows, err := c.existingConnections(args...)
+		if err != nil {
+			log.Printf("conntrack existingConnections error: %v", err)
+			return
+		}
+		for _, flow := range existingFlows {
+			c.handleFlow(flow, true)
+		}
+	}
+
 	args = append([]string{"-E", "-o", "xml", "-p", "tcp"}, args...)
 	cmd := exec.Command("conntrack", args...)
 	stdout, err := cmd.StdoutPipe()
@@ -143,15 +164,43 @@ func (c *Conntracker) run(args ...string) {
 		return
 	}
 
+	defer log.Printf("conntrack exiting")
+
 	// Now loop on the output stream
 	decoder := xml.NewDecoder(reader)
 	for {
 		var f Flow
 		if err := decoder.Decode(&f); err != nil {
 			log.Printf("conntrack error: %v", err)
+			return
+		}
+		c.handleFlow(f, false)
+	}
+}
+
+func (c *Conntracker) existingConnections(args ...string) ([]Flow, error) {
+	args = append([]string{"-L", "-o", "xml", "-p", "tcp"}, args...)
+	cmd := exec.Command("conntrack", args...)
+	stdout, err := cmd.StdoutPipe()
+	if err != nil {
+		return []Flow{}, err
+	}
+	if err := cmd.Start(); err != nil {
+		return []Flow{}, err
+	}
+	defer func() {
+		if err := cmd.Wait(); err != nil {
+			log.Printf("conntrack existingConnections exit error: %v", err)
+		}
+	}()
+	var result conntrack
+	if err := xml.NewDecoder(stdout).Decode(&result); err != nil {
+		if err == io.EOF {
+			return []Flow{}, nil // empty output: no existing flows, not a failure
 		}
-		c.handleFlow(f)
+		return []Flow{}, err
 	}
+	return result.Flows, nil
 }
 
 // Stop stop stop
@@ -167,7 +216,7 @@ func (c *Conntracker) Stop() {
 	}
 }
 
-func (c *Conntracker) handleFlow(f Flow) {
+func (c *Conntracker) handleFlow(f Flow, forceAdd bool) {
 	// A flow consists of 3 'metas' - the 'original' 4 tuple (as seen by this
 	// host) and the 'reply' 4 tuple, which is what it has been rewritten to.
 	// This code finds those metas, which are identified by a Direction
@@ -194,15 +243,15 @@ func (c *Conntracker) handleFlow(f Flow) {
 	c.Lock()
 	defer c.Unlock()
 
-	switch f.Type {
-	case New, Update:
+	switch {
+	case forceAdd || f.Type == New || f.Type == Update:
 		if f.Independent.State != TimeWait {
 			c.activeFlows[f.Independent.ID] = f
 		} else if _, ok := c.activeFlows[f.Independent.ID]; ok {
 			delete(c.activeFlows, f.Independent.ID)
 			c.bufferedFlows = append(c.bufferedFlows, f)
 		}
-	case Destroy:
+	case f.Type == Destroy:
 		if _, ok := c.activeFlows[f.Independent.ID]; ok {
 			delete(c.activeFlows, f.Independent.ID)
 			c.bufferedFlows = append(c.bufferedFlows, f)
diff --git a/probe/endpoint/conntrack_test.go b/probe/endpoint/conntrack_test.go
index daf96f59d4..6b5b03f094 100644
--- a/probe/endpoint/conntrack_test.go
+++ b/probe/endpoint/conntrack_test.go
@@ -76,7 +76,7 @@ func TestConntracker(t *testing.T) {
 		return testExec.NewMockCmd(reader)
 	}
 
-	conntracker, err := NewConntracker()
+	conntracker, err := NewConntracker(false)
 	if err != nil {
 		t.Fatal(err)
 	}
diff --git a/probe/endpoint/nat.go b/probe/endpoint/nat.go
index 17992297cf..d3531f4388 100644
--- a/probe/endpoint/nat.go
+++ b/probe/endpoint/nat.go
@@ -21,7 +21,7 @@ type natmapper struct {
 }
 
 func newNATMapper() (*natmapper, error) {
-	ct, err := NewConntracker("--any-nat")
+	ct, err := NewConntracker(true, "--any-nat")
 	if err != nil {
 		return nil, err
 	}
@@ -53,14 +53,21 @@ func toMapping(f Flow) *endpointMapping {
 // report, based on the NAT table as returns by natTable.
 func (n *natmapper) applyNAT(rpt report.Report, scope string) {
 	n.WalkFlows(func(f Flow) {
-		mapping := toMapping(f)
-		realEndpointID := report.MakeEndpointNodeID(scope, mapping.originalIP, strconv.Itoa(mapping.originalPort))
-		copyEndpointID := report.MakeEndpointNodeID(scope, mapping.rewrittenIP, strconv.Itoa(mapping.rewrittenPort))
-		node, ok := rpt.Endpoint.Nodes[realEndpointID]
+		var (
+			mapping          = toMapping(f)
+			realEndpointID   = report.MakeEndpointNodeID(scope, mapping.originalIP, strconv.Itoa(mapping.originalPort))
+			copyEndpointPort = strconv.Itoa(mapping.rewrittenPort)
+			copyEndpointID   = report.MakeEndpointNodeID(scope, mapping.rewrittenIP, copyEndpointPort)
+			node, ok         = rpt.Endpoint.Nodes[realEndpointID]
+		)
 		if !ok {
 			return
 		}
-		rpt.Endpoint.Nodes[copyEndpointID] = node.Copy()
+		node = node.Copy()
+		node.Metadata[Addr] = mapping.rewrittenIP
+		node.Metadata[Port] = copyEndpointPort
+		node.Metadata["copy_of"] = realEndpointID
+		rpt.Endpoint.AddNode(copyEndpointID, node)
 	})
 }
 
diff --git a/probe/endpoint/reporter.go b/probe/endpoint/reporter.go
index 6df5eb229b..0f58b41cb4 100644
--- a/probe/endpoint/reporter.go
+++ b/probe/endpoint/reporter.go
@@ -55,7 +55,7 @@ func NewReporter(hostID, hostName string, includeProcesses bool, useConntrack bo
 		err error
 	)
 	if conntrackModulePresent && useConntrack {
-		conntracker, err = NewConntracker()
+		conntracker, err = NewConntracker(true)
 		if err != nil {
 			log.Printf("Failed to start conntracker: %v", err)
 		}
@@ -93,6 +93,7 @@ func (r *Reporter) Report() (report.Report, error) {
 		SpyDuration.WithLabelValues().Observe(float64(time.Since(begin)))
 	}(time.Now())
 
+	hostNodeID := report.MakeHostNodeID(r.hostID)
 	rpt := report.MakeReport()
 	conns, err := procspy.Connections(r.includeProcesses)
 	if err != nil {
@@ -109,7 +110,8 @@ func (r *Reporter) Report() (report.Report, error) {
 		extraNodeInfo := report.MakeNode()
 		if conn.Proc.PID > 0 {
 			extraNodeInfo = extraNodeInfo.WithMetadata(report.Metadata{
-				process.PID: strconv.FormatUint(uint64(conn.Proc.PID), 10),
+				process.PID:       strconv.FormatUint(uint64(conn.Proc.PID), 10),
+				report.HostNodeID: hostNodeID,
 			})
 		}
 		r.addConnection(&rpt, localAddr, remoteAddr, localPort, remotePort, &extraNodeInfo, nil)
@@ -138,10 +140,7 @@ func (r *Reporter) Report() (report.Report, error) {
 }
 
 func (r *Reporter) addConnection(rpt *report.Report, localAddr, remoteAddr string, localPort, remotePort uint16, extraLocalNode, extraRemoteNode *report.Node) {
-	var (
-		localIsClient = int(localPort) > int(remotePort)
-		hostNodeID    = report.MakeHostNodeID(r.hostID)
-	)
+	localIsClient := int(localPort) > int(remotePort)
 
 	// Update address topology
 	{
@@ -149,9 +148,8 @@ func (r *Reporter) addConnection(rpt *report.Report, localAddr, remoteAddr strin
 			localAddressNodeID  = report.MakeAddressNodeID(r.hostID, localAddr)
 			remoteAddressNodeID = report.MakeAddressNodeID(r.hostID, remoteAddr)
 			localNode           = report.MakeNodeWith(map[string]string{
-				"name":            r.hostName,
-				Addr:              localAddr,
-				report.HostNodeID: hostNodeID,
+				"name": r.hostName,
+				Addr:   localAddr,
 			})
 			remoteNode = report.MakeNodeWith(map[string]string{
 				Addr: remoteAddr,
@@ -178,6 +176,12 @@ func (r *Reporter) addConnection(rpt *report.Report, localAddr, remoteAddr strin
 			})
 		}
 
+		if extraLocalNode != nil {
+			localNode = localNode.Merge(*extraLocalNode)
+		}
+		if extraRemoteNode != nil {
+			remoteNode = remoteNode.Merge(*extraRemoteNode)
+		}
 		rpt.Address = rpt.Address.AddNode(localAddressNodeID, localNode)
 		rpt.Address = rpt.Address.AddNode(remoteAddressNodeID, remoteNode)
 	}
@@ -189,9 +193,8 @@ func (r *Reporter) addConnection(rpt *report.Report, localAddr, remoteAddr strin
 			remoteEndpointNodeID = report.MakeEndpointNodeID(r.hostID, remoteAddr, strconv.Itoa(int(remotePort)))
 
 			localNode = report.MakeNodeWith(map[string]string{
-				Addr:              localAddr,
-				Port:              strconv.Itoa(int(localPort)),
-				report.HostNodeID: hostNodeID,
+				Addr: localAddr,
+				Port: strconv.Itoa(int(localPort)),
 			})
 			remoteNode = report.MakeNodeWith(map[string]string{
 				Addr: remoteAddr,
diff --git a/probe/process/reporter.go b/probe/process/reporter.go
index 3d46de07d9..f7b22ffef8 100644
--- a/probe/process/reporter.go
+++ b/probe/process/reporter.go
@@ -45,7 +45,7 @@ func (r *Reporter) processTopology() (report.Topology, error) {
 	err := r.walker.Walk(func(p Process) {
 		pidstr := strconv.Itoa(p.PID)
 		nodeID := report.MakeProcessNodeID(r.scope, pidstr)
-		t.Nodes[nodeID] = report.MakeNode()
+		node := report.MakeNode()
 		for _, tuple := range []struct{ key, value string }{
 			{PID, pidstr},
 			{Comm, p.Comm},
@@ -53,12 +53,13 @@ func (r *Reporter) processTopology() (report.Topology, error) {
 			{Threads, strconv.Itoa(p.Threads)},
 		} {
 			if tuple.value != "" {
-				t.Nodes[nodeID].Metadata[tuple.key] = tuple.value
+				node.Metadata[tuple.key] = tuple.value
 			}
 		}
 		if p.PPID > 0 {
-			t.Nodes[nodeID].Metadata[PPID] = strconv.Itoa(p.PPID)
+			node.Metadata[PPID] = strconv.Itoa(p.PPID)
 		}
+		t.AddNode(nodeID, node)
 	})
 
 	return t, err
diff --git a/render/mapping.go b/render/mapping.go
index 40c9242c35..626bd6d219 100644
--- a/render/mapping.go
+++ b/render/mapping.go
@@ -180,9 +180,17 @@ func MapAddressIdentity(m RenderableNode, local report.Networks) RenderableNodes
 		return RenderableNodes{}
 	}
 
+	// Conntracked connections don't have a host id unless
+	// they were merged with a procspied connection. Filter
+	// out those that weren't.
+	_, hasHostID := m.Metadata[report.HostNodeID]
+	_, conntracked := m.Metadata[endpoint.Conntracked]
+	if !hasHostID && conntracked {
+		return RenderableNodes{}
+	}
+
 	// Nodes without a hostid are treated as psuedo nodes
-	_, ok = m.Metadata[report.HostNodeID]
-	if !ok {
+	if !hasHostID {
 		// If the addr is not in a network local to this report, we emit an
 		// internet node
 		if !local.Contains(net.ParseIP(addr)) {
diff --git a/report/topology.go b/report/topology.go
index 4ac349c720..0b4e75b007 100644
--- a/report/topology.go
+++ b/report/topology.go
@@ -66,9 +66,11 @@ func (n Nodes) Copy() Nodes {
 func (n Nodes) Merge(other Nodes) Nodes {
 	cp := n.Copy()
 	for k, v := range other {
-		if _, ok := cp[k]; !ok { // don't overwrite
-			cp[k] = v.Copy()
+		if existing, ok := cp[k]; ok { // merge instead of overwriting
+			cp[k] = v.Merge(existing)
+		} else {
+			cp[k] = v.Copy() // keep the result independent of other
 		}
 	}
 	return cp
 }