diff --git a/decoder/decoder.go b/decoder/decoder.go new file mode 100644 index 000000000000..d9303f36c361 --- /dev/null +++ b/decoder/decoder.go @@ -0,0 +1,138 @@ +package decoder + +import ( + "fmt" + + "github.com/elastic/libbeat/logp" + "github.com/elastic/packetbeat/protos" + "github.com/elastic/packetbeat/protos/tcp" + "github.com/elastic/packetbeat/protos/udp" + + "github.com/tsg/gopacket" + "github.com/tsg/gopacket/layers" +) + +type DecoderStruct struct { + Parser *gopacket.DecodingLayerParser + + sll layers.LinuxSLL + d1q layers.Dot1Q + lo layers.Loopback + eth layers.Ethernet + ip4 layers.IPv4 + ip6 layers.IPv6 + tcp layers.TCP + udp layers.UDP + payload gopacket.Payload + decoded []gopacket.LayerType + + tcpProc tcp.Processor + udpProc udp.Processor +} + +// Creates and returns a new DecoderStruct. +func NewDecoder(datalink layers.LinkType, tcp tcp.Processor, udp udp.Processor) (*DecoderStruct, error) { + d := DecoderStruct{tcpProc: tcp, udpProc: udp} + + logp.Debug("pcapread", "Layer type: %s", datalink.String()) + + switch datalink { + + case layers.LinkTypeLinuxSLL: + d.Parser = gopacket.NewDecodingLayerParser( + layers.LayerTypeLinuxSLL, + &d.sll, &d.d1q, &d.ip4, &d.ip6, &d.tcp, &d.udp, &d.payload) + + case layers.LinkTypeEthernet: + d.Parser = gopacket.NewDecodingLayerParser( + layers.LayerTypeEthernet, + &d.eth, &d.d1q, &d.ip4, &d.ip6, &d.tcp, &d.udp, &d.payload) + + case layers.LinkTypeNull: // loopback on OSx + d.Parser = gopacket.NewDecodingLayerParser( + layers.LayerTypeLoopback, + &d.lo, &d.d1q, &d.ip4, &d.ip6, &d.tcp, &d.udp, &d.payload) + + default: + return nil, fmt.Errorf("Unsuported link type: %s", datalink.String()) + + } + + d.decoded = []gopacket.LayerType{} + + return &d, nil +} + +func (decoder *DecoderStruct) DecodePacketData(data []byte, ci *gopacket.CaptureInfo) { + + var err error + var packet protos.Packet + + err = decoder.Parser.DecodeLayers(data, &decoder.decoded) + if err != nil { + // Ignore UnsupportedLayerType errors that can occur while parsing + // UDP packets. + lastLayer := decoder.decoded[len(decoder.decoded)-1] + _, unsupported := err.(gopacket.UnsupportedLayerType) + if !(unsupported && lastLayer == layers.LayerTypeUDP) { + logp.Debug("pcapread", "Decoding error: %s", err) + return + } + } + + has_tcp := false + has_udp := false + + for _, layerType := range decoder.decoded { + switch layerType { + case layers.LayerTypeIPv4: + logp.Debug("ip", "IPv4 packet") + + packet.Tuple.Src_ip = decoder.ip4.SrcIP + packet.Tuple.Dst_ip = decoder.ip4.DstIP + packet.Tuple.Ip_length = 4 + + case layers.LayerTypeIPv6: + logp.Debug("ip", "IPv6 packet") + + packet.Tuple.Src_ip = decoder.ip6.SrcIP + packet.Tuple.Dst_ip = decoder.ip6.DstIP + packet.Tuple.Ip_length = 16 + + case layers.LayerTypeTCP: + logp.Debug("ip", "TCP packet") + + packet.Tuple.Src_port = uint16(decoder.tcp.SrcPort) + packet.Tuple.Dst_port = uint16(decoder.tcp.DstPort) + + has_tcp = true + + case layers.LayerTypeUDP: + logp.Debug("ip", "UDP packet") + + packet.Tuple.Src_port = uint16(decoder.udp.SrcPort) + packet.Tuple.Dst_port = uint16(decoder.udp.DstPort) + packet.Payload = decoder.udp.Payload + + has_udp = true + + case gopacket.LayerTypePayload: + packet.Payload = decoder.payload + } + } + + packet.Ts = ci.Timestamp + packet.Tuple.ComputeHashebles() + + if has_udp { + decoder.udpProc.Process(&packet) + } else if has_tcp { + if len(packet.Payload) == 0 && !decoder.tcp.FIN { + // We have no use for this atm. + logp.Debug("pcapread", "Ignore empty non-FIN packet") + return + } + + decoder.tcpProc.Process(&decoder.tcp, &packet) + } +} diff --git a/decoder/decoder_test.go b/decoder/decoder_test.go new file mode 100644 index 000000000000..80be2b7f196d --- /dev/null +++ b/decoder/decoder_test.go @@ -0,0 +1,166 @@ +package decoder + +import ( + "strings" + "testing" + + "github.com/elastic/packetbeat/protos" + + "github.com/stretchr/testify/assert" + "github.com/tsg/gopacket" + "github.com/tsg/gopacket/layers" +) + +type TestTcpProcessor struct { + tcphdr *layers.TCP + pkt *protos.Packet +} + +func (l *TestTcpProcessor) Process(tcphdr *layers.TCP, pkt *protos.Packet) { + l.tcphdr = tcphdr + l.pkt = pkt +} + +type TestUdpProcessor struct { + pkt *protos.Packet +} + +func (l *TestUdpProcessor) Process(pkt *protos.Packet) { + l.pkt = pkt +} + +// 172.16.16.164:1108 172.16.16.139:53 DNS 87 Standard query 0x0007 AXFR contoso.local +var ipv4TcpDns = []byte{ + 0x00, 0x0c, 0x29, 0xce, 0xd1, 0x9e, 0x00, 0x0c, 0x29, 0x7e, 0xec, 0xa4, 0x08, 0x00, 0x45, 0x00, + 0x00, 0x49, 0x46, 0x54, 0x40, 0x00, 0x80, 0x06, 0x3b, 0x0b, 0xac, 0x10, 0x10, 0xa4, 0xac, 0x10, + 0x10, 0x8b, 0x04, 0x54, 0x00, 0x35, 0x5d, 0x9f, 0x0c, 0x90, 0x1a, 0xef, 0x6f, 0x43, 0x50, 0x18, + 0xfa, 0xf0, 0xbc, 0x3d, 0x00, 0x00, 0x00, 0x07, 0x01, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x07, 0x63, 0x6f, 0x6e, 0x74, 0x6f, 0x73, 0x6f, 0x05, 0x6c, 0x6f, 0x63, 0x61, 0x6c, + 0x00, 0x00, 0xfc, 0x00, 0x01, 0x4d, 0x53, +} + +// Test that DecodePacket decodes and IPv4/TCP packet and invokes the TCP processor. +func TestDecodePacketData_ipv4Tcp(t *testing.T) { + p := gopacket.NewPacket(ipv4TcpDns, layers.LinkTypeEthernet, gopacket.Default) + if p.ErrorLayer() != nil { + t.Error("Failed to decode packet:", p.ErrorLayer().Error()) + } + d, tcp, _ := newTestDecoder(t) + d.DecodePacketData(p.Data(), &p.Metadata().CaptureInfo) + + assert.NotNil(t, tcp.pkt, "TCP packet not received") + assert.Equal(t, "172.16.16.164", tcp.pkt.Tuple.Src_ip.String()) + assert.Equal(t, uint16(1108), tcp.pkt.Tuple.Src_port) + assert.Equal(t, "172.16.16.139", tcp.pkt.Tuple.Dst_ip.String()) + assert.Equal(t, uint16(53), tcp.pkt.Tuple.Dst_port) + assert.NotEqual(t, -1, strings.Index(string(p.Data()), string(tcp.pkt.Payload))) +} + +// 192.168.170.8:32795 192.168.170.20:53 DNS 74 Standard query 0x75c0 A www.netbsd.org +var ipv4UdpDns = []byte{ + 0x00, 0xc0, 0x9f, 0x32, 0x41, 0x8c, 0x00, 0xe0, 0x18, 0xb1, 0x0c, 0xad, 0x08, 0x00, 0x45, 0x00, + 0x00, 0x3c, 0x00, 0x00, 0x40, 0x00, 0x40, 0x11, 0x65, 0x43, 0xc0, 0xa8, 0xaa, 0x08, 0xc0, 0xa8, + 0xaa, 0x14, 0x80, 0x1b, 0x00, 0x35, 0x00, 0x28, 0xaf, 0x61, 0x75, 0xc0, 0x01, 0x00, 0x00, 0x01, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, 0x77, 0x77, 0x77, 0x06, 0x6e, 0x65, 0x74, 0x62, 0x73, + 0x64, 0x03, 0x6f, 0x72, 0x67, 0x00, 0x00, 0x01, 0x00, 0x01, +} + +// Test that DecodePacket decodes and IPv4/UDP packet and invokes the UDP processor. +func TestDecodePacketData_ipv4Udp(t *testing.T) { + p := gopacket.NewPacket(ipv4UdpDns, layers.LinkTypeEthernet, gopacket.Default) + if p.ErrorLayer() != nil { + t.Error("Failed to decode packet:", p.ErrorLayer().Error()) + } + d, _, udp := newTestDecoder(t) + d.DecodePacketData(p.Data(), &p.Metadata().CaptureInfo) + + assert.NotNil(t, udp.pkt, "UDP packet not received") + assert.Equal(t, "192.168.170.8", udp.pkt.Tuple.Src_ip.String()) + assert.Equal(t, uint16(32795), udp.pkt.Tuple.Src_port) + assert.Equal(t, "192.168.170.20", udp.pkt.Tuple.Dst_ip.String()) + assert.Equal(t, uint16(53), udp.pkt.Tuple.Dst_port) + assert.NotEqual(t, -1, strings.Index(string(p.Data()), string(udp.pkt.Payload))) +} + +// IP6 2001:6f8:102d::2d0:9ff:fee3:e8de.59201 > 2001:6f8:900:7c0::2.80 +var ipv6TcpHttpGet = []byte{ + 0x00, 0x11, 0x25, 0x82, 0x95, 0xb5, 0x00, 0xd0, 0x09, 0xe3, 0xe8, 0xde, 0x86, 0xdd, 0x60, 0x00, + 0x00, 0x00, 0x01, 0x04, 0x06, 0x40, 0x20, 0x01, 0x06, 0xf8, 0x10, 0x2d, 0x00, 0x00, 0x02, 0xd0, + 0x09, 0xff, 0xfe, 0xe3, 0xe8, 0xde, 0x20, 0x01, 0x06, 0xf8, 0x09, 0x00, 0x07, 0xc0, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xe7, 0x41, 0x00, 0x50, 0xab, 0xdc, 0xd6, 0x61, 0x01, 0x4a, + 0x73, 0x9f, 0x50, 0x18, 0x16, 0x80, 0xf4, 0x48, 0x00, 0x00, 0x47, 0x45, 0x54, 0x20, 0x2f, 0x20, + 0x48, 0x54, 0x54, 0x50, 0x2f, 0x31, 0x2e, 0x30, 0x0d, 0x0a, 0x48, 0x6f, 0x73, 0x74, 0x3a, 0x20, + 0x63, 0x6c, 0x2d, 0x31, 0x39, 0x38, 0x35, 0x2e, 0x68, 0x61, 0x6d, 0x2d, 0x30, 0x31, 0x2e, 0x64, + 0x65, 0x2e, 0x73, 0x69, 0x78, 0x78, 0x73, 0x2e, 0x6e, 0x65, 0x74, 0x0d, 0x0a, 0x41, 0x63, 0x63, + 0x65, 0x70, 0x74, 0x3a, 0x20, 0x74, 0x65, 0x78, 0x74, 0x2f, 0x68, 0x74, 0x6d, 0x6c, 0x2c, 0x20, + 0x74, 0x65, 0x78, 0x74, 0x2f, 0x70, 0x6c, 0x61, 0x69, 0x6e, 0x2c, 0x20, 0x74, 0x65, 0x78, 0x74, + 0x2f, 0x63, 0x73, 0x73, 0x2c, 0x20, 0x74, 0x65, 0x78, 0x74, 0x2f, 0x73, 0x67, 0x6d, 0x6c, 0x2c, + 0x20, 0x2a, 0x2f, 0x2a, 0x3b, 0x71, 0x3d, 0x30, 0x2e, 0x30, 0x31, 0x0d, 0x0a, 0x41, 0x63, 0x63, + 0x65, 0x70, 0x74, 0x2d, 0x45, 0x6e, 0x63, 0x6f, 0x64, 0x69, 0x6e, 0x67, 0x3a, 0x20, 0x67, 0x7a, + 0x69, 0x70, 0x2c, 0x20, 0x62, 0x7a, 0x69, 0x70, 0x32, 0x0d, 0x0a, 0x41, 0x63, 0x63, 0x65, 0x70, + 0x74, 0x2d, 0x4c, 0x61, 0x6e, 0x67, 0x75, 0x61, 0x67, 0x65, 0x3a, 0x20, 0x65, 0x6e, 0x0d, 0x0a, + 0x55, 0x73, 0x65, 0x72, 0x2d, 0x41, 0x67, 0x65, 0x6e, 0x74, 0x3a, 0x20, 0x4c, 0x79, 0x6e, 0x78, + 0x2f, 0x32, 0x2e, 0x38, 0x2e, 0x36, 0x72, 0x65, 0x6c, 0x2e, 0x32, 0x20, 0x6c, 0x69, 0x62, 0x77, + 0x77, 0x77, 0x2d, 0x46, 0x4d, 0x2f, 0x32, 0x2e, 0x31, 0x34, 0x20, 0x53, 0x53, 0x4c, 0x2d, 0x4d, + 0x4d, 0x2f, 0x31, 0x2e, 0x34, 0x2e, 0x31, 0x20, 0x4f, 0x70, 0x65, 0x6e, 0x53, 0x53, 0x4c, 0x2f, + 0x30, 0x2e, 0x39, 0x2e, 0x38, 0x62, 0x0d, 0x0a, 0x0d, 0x0a, +} + +// Test that DecodePacket decodes and IPv6/TCP packet and invokes the TCP processor. +func TestDecodePacketData_ipv6Tcp(t *testing.T) { + p := gopacket.NewPacket(ipv6TcpHttpGet, layers.LinkTypeEthernet, gopacket.Default) + if p.ErrorLayer() != nil { + t.Error("Failed to decode packet: ", p.ErrorLayer().Error()) + } + d, tcp, _ := newTestDecoder(t) + d.DecodePacketData(p.Data(), &p.Metadata().CaptureInfo) + + assert.NotNil(t, tcp.pkt, "TCP packet not received") + assert.Equal(t, "2001:6f8:102d:0:2d0:9ff:fee3:e8de", tcp.pkt.Tuple.Src_ip.String()) + assert.Equal(t, uint16(59201), tcp.pkt.Tuple.Src_port) + assert.Equal(t, "2001:6f8:900:7c0::2", tcp.pkt.Tuple.Dst_ip.String()) + assert.Equal(t, uint16(80), tcp.pkt.Tuple.Dst_port) + assert.NotEqual(t, -1, strings.Index(string(p.Data()), string(tcp.pkt.Payload))) +} + +// 3ffe:507:0:1:200:86ff:fe05:80da.2415 > 3ffe:501:4819::42.53 +var ipv6UdpDns = []byte{ + 0x00, 0x60, 0x97, 0x07, 0x69, 0xea, 0x00, 0x00, 0x86, 0x05, 0x80, 0xda, 0x86, 0xdd, 0x60, 0x00, + 0x00, 0x00, 0x00, 0x61, 0x11, 0x40, 0x3f, 0xfe, 0x05, 0x07, 0x00, 0x00, 0x00, 0x01, 0x02, 0x00, + 0x86, 0xff, 0xfe, 0x05, 0x80, 0xda, 0x3f, 0xfe, 0x05, 0x01, 0x48, 0x19, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x42, 0x09, 0x6f, 0x00, 0x35, 0x00, 0x61, 0xa3, 0x35, 0x5c, 0x78, + 0x01, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x61, 0x01, 0x65, 0x01, 0x39, + 0x01, 0x36, 0x01, 0x37, 0x01, 0x30, 0x01, 0x65, 0x01, 0x66, 0x01, 0x66, 0x01, 0x66, 0x01, 0x37, + 0x01, 0x39, 0x01, 0x30, 0x01, 0x36, 0x01, 0x32, 0x01, 0x30, 0x01, 0x31, 0x01, 0x30, 0x01, 0x30, + 0x01, 0x30, 0x01, 0x30, 0x01, 0x30, 0x01, 0x30, 0x01, 0x30, 0x01, 0x37, 0x01, 0x30, 0x01, 0x35, + 0x01, 0x30, 0x01, 0x65, 0x01, 0x66, 0x01, 0x66, 0x01, 0x33, 0x03, 0x69, 0x70, 0x36, 0x03, 0x69, + 0x6e, 0x74, 0x00, 0x00, 0x0c, 0x00, 0x01, +} + +// Test that DecodePacket decodes and IPv6/UDP packet and invokes the UDP processor. +func TestDecodePacketData_ipv6Udp(t *testing.T) { + p := gopacket.NewPacket(ipv6UdpDns, layers.LinkTypeEthernet, gopacket.Default) + if p.ErrorLayer() != nil { + t.Error("Failed to decode packet:", p.ErrorLayer().Error()) + } + d, _, udp := newTestDecoder(t) + d.DecodePacketData(p.Data(), &p.Metadata().CaptureInfo) + + assert.NotNil(t, udp.pkt, "UDP packet not received") + assert.Equal(t, "3ffe:507:0:1:200:86ff:fe05:80da", udp.pkt.Tuple.Src_ip.String()) + assert.Equal(t, uint16(2415), udp.pkt.Tuple.Src_port) + assert.Equal(t, "3ffe:501:4819::42", udp.pkt.Tuple.Dst_ip.String()) + assert.Equal(t, uint16(53), udp.pkt.Tuple.Dst_port) + assert.NotEqual(t, -1, strings.Index(string(p.Data()), string(udp.pkt.Payload))) +} + +// Creates a new TestDecoder that handles ethernet packets. +func newTestDecoder(t *testing.T) (*DecoderStruct, *TestTcpProcessor, *TestUdpProcessor) { + tcpLayer := &TestTcpProcessor{} + udpLayer := &TestUdpProcessor{} + d, err := NewDecoder(layers.LinkTypeEthernet, tcpLayer, udpLayer) + if err != nil { + t.Fatalf("Error creating decoder %v", err) + } + return d, tcpLayer, udpLayer +} diff --git a/main.go b/main.go index 603fca2d23e1..c610437b8b75 100644 --- a/main.go +++ b/main.go @@ -25,6 +25,7 @@ import ( "github.com/elastic/packetbeat/protos/redis" "github.com/elastic/packetbeat/protos/tcp" "github.com/elastic/packetbeat/protos/thrift" + "github.com/elastic/packetbeat/protos/udp" "github.com/elastic/packetbeat/sniffer" ) @@ -111,7 +112,13 @@ func main() { protos.Protos.Register(proto, plugin) } - if err = tcp.TcpInit(); err != nil { + tcpProc, err := tcp.NewTcp(&protos.Protos) + if err != nil { + logp.Critical(err.Error()) + os.Exit(1) + } + udpProc, err := udp.NewUdp(&protos.Protos) + if err != nil { logp.Critical(err.Error()) os.Exit(1) } @@ -134,7 +141,7 @@ func main() { } logp.Debug("main", "Initializing sniffer") - err = sniff.Init(false, afterInputsQueue) + err = sniff.Init(false, afterInputsQueue, tcpProc, udpProc) if err != nil { logp.Critical("Initializing sniffer failed: %v", err) os.Exit(1) @@ -184,7 +191,7 @@ func main() { if service.WithMemProfile() { // wait for all TCP streams to expire time.Sleep(tcp.TCP_STREAM_EXPIRY * 1.2) - tcp.PrintTcpMap() + tcpProc.PrintTcpMap() } service.Cleanup() } diff --git a/protos/protos.go b/protos/protos.go index cdbb653e8414..7e34d39e2554 100644 --- a/protos/protos.go +++ b/protos/protos.go @@ -1,6 +1,9 @@ package protos import ( + "fmt" + "sort" + "strings" "time" "github.com/elastic/libbeat/common" @@ -25,8 +28,12 @@ type ProtocolPlugin interface { // Called to return the configured ports GetPorts() []int +} + +type TcpProtocolPlugin interface { + ProtocolPlugin - // Called when payload data is available for parsing. + // Called when TCP payload data is available for parsing. Parse(pkt *Packet, tcptuple *common.TcpTuple, dir uint8, private ProtocolData) ProtocolData @@ -40,6 +47,13 @@ type ProtocolPlugin interface { private ProtocolData) (priv ProtocolData, drop bool) } +type UdpProtocolPlugin interface { + ProtocolPlugin + + // ParseUdp is invoked when UDP payload data is available for parsing. + ParseUdp(pkt *Packet) +} + // Protocol identifier. type Protocol uint16 @@ -72,32 +86,117 @@ func (p Protocol) String() string { return ProtocolNames[p] } +type Protocols interface { + BpfFilter(with_vlans bool) string + GetTcp(proto Protocol) TcpProtocolPlugin + GetUdp(proto Protocol) UdpProtocolPlugin + GetAll() map[Protocol]ProtocolPlugin + GetAllTcp() map[Protocol]TcpProtocolPlugin + GetAllUdp() map[Protocol]UdpProtocolPlugin + Register(proto Protocol, plugin ProtocolPlugin) +} + // list of protocol plugins -type Protocols struct { - protos map[Protocol]ProtocolPlugin +type ProtocolsStruct struct { + all map[Protocol]ProtocolPlugin + tcp map[Protocol]TcpProtocolPlugin + udp map[Protocol]UdpProtocolPlugin } // Singleton of Protocols type. -var Protos Protocols +var Protos ProtocolsStruct + +func (protocols ProtocolsStruct) GetTcp(proto Protocol) TcpProtocolPlugin { + plugin, exists := protocols.tcp[proto] + if !exists { + return nil + } + + return plugin +} -func (protocols Protocols) Get(proto Protocol) ProtocolPlugin { - ret, exists := protocols.protos[proto] +func (protocols ProtocolsStruct) GetUdp(proto Protocol) UdpProtocolPlugin { + plugin, exists := protocols.udp[proto] if !exists { return nil } - return ret + + return plugin +} + +func (protocols ProtocolsStruct) GetAll() map[Protocol]ProtocolPlugin { + return protocols.all } -func (protocols Protocols) GetAll() map[Protocol]ProtocolPlugin { - return protocols.protos +func (protocols ProtocolsStruct) GetAllTcp() map[Protocol]TcpProtocolPlugin { + return protocols.tcp } -func (protos Protocols) Register(proto Protocol, plugin ProtocolPlugin) { - protos.protos[proto] = plugin +func (protocols ProtocolsStruct) GetAllUdp() map[Protocol]UdpProtocolPlugin { + return protocols.udp +} + +// BpfFilter returns a Berkeley Packer Filter (BFP) expression that +// will match against packets for the registered protocols. If with_vlans is +// true the filter will match against both IEEE 802.1Q VLAN encapsulated +// and unencapsulated packets +func (protocols ProtocolsStruct) BpfFilter(with_vlans bool) string { + // Sort the protocol IDs so that the return value is consistent. + var protos []int + for proto := range protocols.all { + protos = append(protos, int(proto)) + } + sort.Ints(protos) + + var expressions []string + for _, key := range protos { + proto := Protocol(key) + plugin := protocols.all[proto] + for _, port := range plugin.GetPorts() { + has_tcp := false + has_udp := false + + if _, present := protocols.tcp[proto]; present { + has_tcp = true + } + if _, present := protocols.udp[proto]; present { + has_udp = true + } + + var expr string + if has_tcp && !has_udp { + expr = "tcp port %d" + } else if !has_tcp && has_udp { + expr = "udp port %d" + } else { + expr = "port %d" + } + + expressions = append(expressions, fmt.Sprintf(expr, port)) + } + } + + filter := strings.Join(expressions, " or ") + if with_vlans { + filter = fmt.Sprintf("%s or (vlan and (%s))", filter, filter) + } + return filter +} + +func (protos ProtocolsStruct) Register(proto Protocol, plugin ProtocolPlugin) { + protos.all[proto] = plugin + if tcp, ok := plugin.(TcpProtocolPlugin); ok { + protos.tcp[proto] = tcp + } + if udp, ok := plugin.(UdpProtocolPlugin); ok { + protos.udp[proto] = udp + } } func init() { logp.Debug("protos", "Initializing Protos") - Protos = Protocols{} - Protos.protos = make(map[Protocol]ProtocolPlugin) + Protos = ProtocolsStruct{} + Protos.all = make(map[Protocol]ProtocolPlugin) + Protos.tcp = make(map[Protocol]TcpProtocolPlugin) + Protos.udp = make(map[Protocol]UdpProtocolPlugin) } diff --git a/protos/protos_test.go b/protos/protos_test.go index d941b7a18fd5..f7352b8ffb1e 100644 --- a/protos/protos_test.go +++ b/protos/protos_test.go @@ -3,9 +3,83 @@ package protos import ( "testing" + "github.com/elastic/libbeat/common" + "github.com/stretchr/testify/assert" ) +type TestProtocol struct { + Ports []int +} + +type TcpProtocol TestProtocol + +func (proto *TcpProtocol) Init(test_mode bool, results chan common.MapStr) error { + return nil +} + +func (proto *TcpProtocol) GetPorts() []int { + return proto.Ports +} + +func (proto *TcpProtocol) Parse(pkt *Packet, tcptuple *common.TcpTuple, + dir uint8, private ProtocolData) ProtocolData { + return private +} + +func (proto *TcpProtocol) ReceivedFin(tcptuple *common.TcpTuple, dir uint8, + private ProtocolData) ProtocolData { + return private +} + +func (proto *TcpProtocol) GapInStream(tcptuple *common.TcpTuple, dir uint8, + nbytes int, private ProtocolData) (priv ProtocolData, drop bool) { + return private, true +} + +type UdpProtocol TestProtocol + +func (proto *UdpProtocol) Init(test_mode bool, results chan common.MapStr) error { + return nil +} + +func (proto *UdpProtocol) GetPorts() []int { + return proto.Ports +} + +func (proto *UdpProtocol) ParseUdp(pkt *Packet) { + return +} + +type TcpUdpProtocol TestProtocol + +func (proto *TcpUdpProtocol) Init(test_mode bool, results chan common.MapStr) error { + return nil +} + +func (proto *TcpUdpProtocol) GetPorts() []int { + return proto.Ports +} + +func (proto *TcpUdpProtocol) Parse(pkt *Packet, tcptuple *common.TcpTuple, + dir uint8, private ProtocolData) ProtocolData { + return private +} + +func (proto *TcpUdpProtocol) ReceivedFin(tcptuple *common.TcpTuple, dir uint8, + private ProtocolData) ProtocolData { + return private +} + +func (proto *TcpUdpProtocol) GapInStream(tcptuple *common.TcpTuple, dir uint8, + nbytes int, private ProtocolData) (priv ProtocolData, drop bool) { + return private, true +} + +func (proto *TcpUdpProtocol) ParseUdp(pkt *Packet) { + return +} + func TestProtocolNames(t *testing.T) { assert.Equal(t, "unknown", UnknownProtocol.String()) assert.Equal(t, "http", HttpProtocol.String()) @@ -17,3 +91,84 @@ func TestProtocolNames(t *testing.T) { assert.Equal(t, "impossible", Protocol(100).String()) } + +func newProtocols() Protocols { + p := ProtocolsStruct{} + p.all = make(map[Protocol]ProtocolPlugin) + p.tcp = make(map[Protocol]TcpProtocolPlugin) + p.udp = make(map[Protocol]UdpProtocolPlugin) + + tcp := &TcpProtocol{Ports: []int{80}} + udp := &UdpProtocol{Ports: []int{5060}} + tcpUdp := &TcpUdpProtocol{Ports: []int{53}} + + p.Register(1, tcp) + p.Register(2, udp) + p.Register(3, tcpUdp) + return p +} + +func TestBpfFilter_withoutVlan(t *testing.T) { + p := newProtocols() + filter := p.BpfFilter(false) + assert.Equal(t, "tcp port 80 or udp port 5060 or port 53", filter) +} + +func TestBpfFilter_withtVlan(t *testing.T) { + p := newProtocols() + filter := p.BpfFilter(true) + assert.Equal(t, "tcp port 80 or udp port 5060 or port 53 or (vlan and "+ + "(tcp port 80 or udp port 5060 or port 53))", filter) +} + +func TestGetAll(t *testing.T) { + p := newProtocols() + all := p.GetAll() + assert.NotNil(t, all[1]) + assert.NotNil(t, all[2]) + assert.NotNil(t, all[3]) +} + +func TestGetAllTcp(t *testing.T) { + p := newProtocols() + tcp := p.GetAllTcp() + assert.NotNil(t, tcp[1]) + assert.Nil(t, tcp[2]) + assert.NotNil(t, tcp[3]) +} + +func TestGetAllUdp(t *testing.T) { + p := newProtocols() + udp := p.GetAllUdp() + assert.Nil(t, udp[1]) + assert.NotNil(t, udp[2]) + assert.NotNil(t, udp[3]) +} + +func TestGetTcp(t *testing.T) { + p := newProtocols() + tcp := p.GetTcp(1) + assert.NotNil(t, tcp) + assert.Contains(t, tcp.GetPorts(), 80) + + tcp = p.GetTcp(2) + assert.Nil(t, tcp) + + tcp = p.GetTcp(3) + assert.NotNil(t, tcp) + assert.Contains(t, tcp.GetPorts(), 53) +} + +func TestGetUdp(t *testing.T) { + p := newProtocols() + udp := p.GetUdp(1) + assert.Nil(t, udp) + + udp = p.GetUdp(2) + assert.NotNil(t, udp) + assert.Contains(t, udp.GetPorts(), 5060) + + udp = p.GetUdp(3) + assert.NotNil(t, udp) + assert.Contains(t, udp.GetPorts(), 53) +} diff --git a/protos/tcp/tcp.go b/protos/tcp/tcp.go index 91cafd2892d5..253b943ef2e9 100644 --- a/protos/tcp/tcp.go +++ b/protos/tcp/tcp.go @@ -2,7 +2,6 @@ package tcp import ( "fmt" - "strings" "time" "github.com/elastic/libbeat/common" @@ -10,7 +9,6 @@ import ( "github.com/elastic/packetbeat/protos" - "github.com/tsg/gopacket" "github.com/tsg/gopacket/layers" ) @@ -23,25 +21,29 @@ const ( TcpDirectionOriginal = 1 ) -var __id uint32 = 0 - -func GetId() uint32 { - __id += 1 - return __id +type Tcp struct { + id uint32 + streamsMap map[common.HashableIpPortTuple]*TcpStream + portMap map[uint16]protos.Protocol + protocols protos.Protocols } -// Config +type Processor interface { + Process(tcphdr *layers.TCP, pkt *protos.Packet) +} -var tcpStreamsMap = make(map[common.HashableIpPortTuple]*TcpStream, TCP_STREAM_HASH_SIZE) -var tcpPortMap map[uint16]protos.Protocol +func (tcp *Tcp) getId() uint32 { + tcp.id += 1 + return tcp.id +} -func decideProtocol(tuple *common.IpPortTuple) protos.Protocol { - protocol, exists := tcpPortMap[tuple.Src_port] +func (tcp *Tcp) decideProtocol(tuple *common.IpPortTuple) protos.Protocol { + protocol, exists := tcp.portMap[tuple.Src_port] if exists { return protocol } - protocol, exists = tcpPortMap[tuple.Dst_port] + protocol, exists = tcp.portMap[tuple.Dst_port] if exists { return protocol } @@ -55,6 +57,7 @@ type TcpStream struct { timer *time.Timer protocol protos.Protocol tcptuple common.TcpTuple + tcp *Tcp lastSeq [2]uint32 @@ -70,7 +73,7 @@ func (stream *TcpStream) AddPacket(pkt *protos.Packet, tcphdr *layers.TCP, origi } stream.timer = time.AfterFunc(TCP_STREAM_EXPIRY, func() { stream.Expire() }) - mod := protos.Protos.Get(stream.protocol) + mod := stream.tcp.protocols.GetTcp(stream.protocol) if mod == nil { logp.Debug("tcp", "Ignoring protocol for which we have no module loaded: %s", stream.protocol) return @@ -86,7 +89,7 @@ func (stream *TcpStream) AddPacket(pkt *protos.Packet, tcphdr *layers.TCP, origi } func (stream *TcpStream) GapInStream(original_dir uint8, nbytes int) (drop bool) { - mod := protos.Protos.Get(stream.protocol) + mod := stream.tcp.protocols.GetTcp(stream.protocol) stream.Data, drop = mod.GapInStream(&stream.tcptuple, original_dir, nbytes, stream.Data) return drop } @@ -96,7 +99,7 @@ func (stream *TcpStream) Expire() { logp.Debug("mem", "Tcp stream expired") // de-register from dict - delete(tcpStreamsMap, stream.tuple.Hashable()) + delete(stream.tcp.streamsMap, stream.tuple.Hashable()) // nullify to help the GC stream.Data = nil @@ -110,19 +113,19 @@ func TcpSeqBeforeEq(seq1 uint32, seq2 uint32) bool { return int32(seq1-seq2) <= 0 } -func FollowTcp(tcphdr *layers.TCP, pkt *protos.Packet) { +func (tcp *Tcp) Process(tcphdr *layers.TCP, pkt *protos.Packet) { // This Recover should catch all exceptions in // protocol modules. defer logp.Recover("FollowTcp exception") - stream, exists := tcpStreamsMap[pkt.Tuple.Hashable()] + stream, exists := tcp.streamsMap[pkt.Tuple.Hashable()] var original_dir uint8 = TcpDirectionOriginal created := false if !exists { - stream, exists = tcpStreamsMap[pkt.Tuple.RevHashable()] + stream, exists = tcp.streamsMap[pkt.Tuple.RevHashable()] if !exists { - protocol := decideProtocol(&pkt.Tuple) + protocol := tcp.decideProtocol(&pkt.Tuple) if protocol == protos.UnknownProtocol { // don't follow return @@ -130,9 +133,9 @@ func FollowTcp(tcphdr *layers.TCP, pkt *protos.Packet) { logp.Debug("tcp", "Stream doesn't exists, creating new") // create - stream = &TcpStream{id: GetId(), tuple: &pkt.Tuple, protocol: protocol} + stream = &TcpStream{id: tcp.getId(), tuple: &pkt.Tuple, protocol: protocol, tcp: tcp} stream.tcptuple = common.TcpTupleFromIpPort(stream.tuple, stream.id) - tcpStreamsMap[pkt.Tuple.Hashable()] = stream + tcp.streamsMap[pkt.Tuple.Hashable()] = stream created = true } else { original_dir = TcpDirectionReverse @@ -171,21 +174,20 @@ func FollowTcp(tcphdr *layers.TCP, pkt *protos.Packet) { stream.AddPacket(pkt, tcphdr, original_dir) } -func PrintTcpMap() { +func (tcp *Tcp) PrintTcpMap() { fmt.Printf("Streams in memory:") - for _, stream := range tcpStreamsMap { + for _, stream := range tcp.streamsMap { fmt.Printf(" %d", stream.id) } fmt.Printf("\n") - fmt.Printf("Streams dict: %v", tcpStreamsMap) + fmt.Printf("Streams dict: %v", tcp.streamsMap) } -func buildPortsMap(plugins map[protos.Protocol]protos.ProtocolPlugin) (map[uint16]protos.Protocol, error) { +func buildPortsMap(plugins map[protos.Protocol]protos.TcpProtocolPlugin) (map[uint16]protos.Protocol, error) { var res = map[uint16]protos.Protocol{} for proto, protoPlugin := range plugins { - for _, port := range protoPlugin.GetPorts() { old_proto, exists := res[uint16(port)] if exists { @@ -202,147 +204,16 @@ func buildPortsMap(plugins map[protos.Protocol]protos.ProtocolPlugin) (map[uint1 return res, nil } -func portsToBpfFilter(ports []int, with_vlans bool) string { - res := []string{} - for _, port := range ports { - res = append(res, fmt.Sprintf("port %d", port)) - } - - filter := strings.Join(res, " or ") - if with_vlans { - filter = fmt.Sprintf("%s or (vlan and (%s))", filter, filter) - } - - return filter -} - -func BpfFilter(with_vlans bool) string { - - ports := []int{} - - for _, protoPlugin := range protos.Protos.GetAll() { - for _, port := range protoPlugin.GetPorts() { - ports = append(ports, port) - } - } - - return portsToBpfFilter(ports, with_vlans) - -} - -func TcpInit() error { - var err error - tcpPortMap, err = buildPortsMap(protos.Protos.GetAll()) +// Creates and returns a new Tcp. +func NewTcp(p protos.Protocols) (*Tcp, error) { + portMap, err := buildPortsMap(p.GetAllTcp()) if err != nil { - return err - } - - logp.Debug("tcp", "Port map: %v", tcpPortMap) - - return nil -} - -type DecoderStruct struct { - Parser *gopacket.DecodingLayerParser - - sll layers.LinuxSLL - d1q layers.Dot1Q - lo layers.Loopback - eth layers.Ethernet - ip4 layers.IPv4 - ip6 layers.IPv6 - tcp layers.TCP - payload gopacket.Payload - decoded []gopacket.LayerType -} - -func CreateDecoder(datalink layers.LinkType) (*DecoderStruct, error) { - var d DecoderStruct - - logp.Debug("pcapread", "Layer type: %s", datalink.String()) - - switch datalink { - - case layers.LinkTypeLinuxSLL: - d.Parser = gopacket.NewDecodingLayerParser( - layers.LayerTypeLinuxSLL, - &d.sll, &d.d1q, &d.ip4, &d.ip6, &d.tcp, &d.payload) - - case layers.LinkTypeEthernet: - d.Parser = gopacket.NewDecodingLayerParser( - layers.LayerTypeEthernet, - &d.eth, &d.d1q, &d.ip4, &d.ip6, &d.tcp, &d.payload) - - case layers.LinkTypeNull: // loopback on OSx - d.Parser = gopacket.NewDecodingLayerParser( - layers.LayerTypeLoopback, - &d.lo, &d.d1q, &d.ip4, &d.ip6, &d.tcp, &d.payload) - - default: - return nil, fmt.Errorf("Unsuported link type: %s", datalink.String()) - - } - - d.decoded = []gopacket.LayerType{} - - return &d, nil -} - -func (decoder *DecoderStruct) DecodePacketData(data []byte, ci *gopacket.CaptureInfo) { - - var err error - var packet protos.Packet - - err = decoder.Parser.DecodeLayers(data, &decoder.decoded) - if err != nil { - logp.Debug("pcapread", "Decoding error: %s", err) - return - } - - has_tcp := false - - for _, layerType := range decoder.decoded { - switch layerType { - case layers.LayerTypeIPv4: - logp.Debug("ip", "IPv4 packet") - - packet.Tuple.Src_ip = decoder.ip4.SrcIP - packet.Tuple.Dst_ip = decoder.ip4.DstIP - packet.Tuple.Ip_length = 4 - - case layers.LayerTypeIPv6: - logp.Debug("ip", "IPv6 packet") - - packet.Tuple.Src_ip = decoder.ip6.SrcIP - packet.Tuple.Dst_ip = decoder.ip6.DstIP - packet.Tuple.Ip_length = 16 - - case layers.LayerTypeTCP: - logp.Debug("ip", "TCP packet") - - packet.Tuple.Src_port = uint16(decoder.tcp.SrcPort) - packet.Tuple.Dst_port = uint16(decoder.tcp.DstPort) - - has_tcp = true - - case gopacket.LayerTypePayload: - packet.Payload = decoder.payload - } - } - - if !has_tcp { - logp.Debug("pcapread", "No TCP header found in message") - return - } - - if len(packet.Payload) == 0 && !decoder.tcp.FIN { - // We have no use for this atm. - logp.Debug("pcapread", "Ignore empty non-FIN packet") - return + return nil, err } - packet.Ts = ci.Timestamp + tcp := &Tcp{protocols: p, portMap: portMap} + tcp.streamsMap = make(map[common.HashableIpPortTuple]*TcpStream, TCP_STREAM_HASH_SIZE) + logp.Debug("tcp", "Port map: %v", portMap) - packet.Tuple.ComputeHashebles() - FollowTcp(&decoder.tcp, &packet) + return tcp, nil } diff --git a/protos/tcp/tcp_test.go b/protos/tcp/tcp_test.go index 23e7e99f3869..18e8c9e8b6a4 100644 --- a/protos/tcp/tcp_test.go +++ b/protos/tcp/tcp_test.go @@ -39,13 +39,13 @@ func (proto *TestProtocol) GapInStream(tcptuple *common.TcpTuple, dir uint8, func Test_configToPortsMap(t *testing.T) { type configTest struct { - Input map[protos.Protocol]protos.ProtocolPlugin + Input map[protos.Protocol]protos.TcpProtocolPlugin Output map[uint16]protos.Protocol } config_tests := []configTest{ configTest{ - Input: map[protos.Protocol]protos.ProtocolPlugin{ + Input: map[protos.Protocol]protos.TcpProtocolPlugin{ protos.HttpProtocol: &TestProtocol{Ports: []int{80, 8080}}, }, Output: map[uint16]protos.Protocol{ @@ -54,7 +54,7 @@ func Test_configToPortsMap(t *testing.T) { }, }, configTest{ - Input: map[protos.Protocol]protos.ProtocolPlugin{ + Input: map[protos.Protocol]protos.TcpProtocolPlugin{ protos.HttpProtocol: &TestProtocol{Ports: []int{80, 8080}}, protos.MysqlProtocol: &TestProtocol{Ports: []int{3306}}, protos.RedisProtocol: &TestProtocol{Ports: []int{6379, 6380}}, @@ -70,7 +70,7 @@ func Test_configToPortsMap(t *testing.T) { // should ignore duplicate ports in the same protocol configTest{ - Input: map[protos.Protocol]protos.ProtocolPlugin{ + Input: map[protos.Protocol]protos.TcpProtocolPlugin{ protos.HttpProtocol: &TestProtocol{Ports: []int{80, 8080, 8080}}, protos.MysqlProtocol: &TestProtocol{Ports: []int{3306}}, }, @@ -92,14 +92,14 @@ func Test_configToPortsMap(t *testing.T) { func Test_configToPortsMap_negative(t *testing.T) { type errTest struct { - Input map[protos.Protocol]protos.ProtocolPlugin + Input map[protos.Protocol]protos.TcpProtocolPlugin Err string } tests := []errTest{ errTest{ // should raise error on duplicate port - Input: map[protos.Protocol]protos.ProtocolPlugin{ + Input: map[protos.Protocol]protos.TcpProtocolPlugin{ protos.HttpProtocol: &TestProtocol{Ports: []int{80, 8080}}, protos.MysqlProtocol: &TestProtocol{Ports: []int{3306}}, protos.RedisProtocol: &TestProtocol{Ports: []int{6379, 6380, 3306}}, @@ -114,36 +114,3 @@ func Test_configToPortsMap_negative(t *testing.T) { assert.Contains(t, err.Error(), test.Err) } } - -func Test_portsToBpfFilter(t *testing.T) { - type io struct { - Ports []int - WithVlans bool - Output string - } - - tests := []io{ - io{ - Ports: []int{2, 3, 4}, - Output: "port 2 or port 3 or port 4", - }, - io{ - Ports: []int{80, 8080}, - Output: "port 80 or port 8080", - }, - io{ - Ports: []int{2, 3, 4}, - WithVlans: true, - Output: "port 2 or port 3 or port 4 or (vlan and (port 2 or port 3 or port 4))", - }, - io{ - Ports: []int{80, 8080}, - WithVlans: true, - Output: "port 80 or port 8080 or (vlan and (port 80 or port 8080))", - }, - } - - for _, test := range tests { - assert.Equal(t, test.Output, portsToBpfFilter(test.Ports, test.WithVlans)) - } -} diff --git a/protos/udp/udp.go b/protos/udp/udp.go new file mode 100644 index 000000000000..d21c4e10a6e8 --- /dev/null +++ b/protos/udp/udp.go @@ -0,0 +1,96 @@ +package udp + +import ( + "fmt" + + "github.com/elastic/libbeat/common" + "github.com/elastic/libbeat/logp" + + "github.com/elastic/packetbeat/protos" +) + +type Udp struct { + protocols protos.Protocols + portMap map[uint16]protos.Protocol +} + +type Processor interface { + Process(pkt *protos.Packet) +} + +// decideProtocol determines the protocol based on the source and destination +// ports. If the protocol cannot be determined then protos.UnknownProtocol +// is returned. +func (udp *Udp) decideProtocol(tuple *common.IpPortTuple) protos.Protocol { + protocol, exists := udp.portMap[tuple.Src_port] + if exists { + return protocol + } + + protocol, exists = udp.portMap[tuple.Dst_port] + if exists { + return protocol + } + + return protos.UnknownProtocol +} + +// Process handles UDP packets that have been received. It attempts to +// determine the protocol type and then invokes the associated +// UdpProtocolPlugin's ParseUdp method. If the protocol cannot be determined +// or the payload is empty then the method is a noop. +func (udp *Udp) Process(pkt *protos.Packet) { + protocol := udp.decideProtocol(&pkt.Tuple) + if protocol == protos.UnknownProtocol { + logp.Debug("udp", "unknown protocol") + return + } + + plugin := udp.protocols.GetUdp(protocol) + if plugin == nil { + logp.Debug("udp", "Ignoring protocol for which we have no module loaded: %s", protocol) + return + } + + if len(pkt.Payload) > 0 { + logp.Debug("udp", "Parsing packet from %v of length %d.", + pkt.Tuple.String(), len(pkt.Payload)) + plugin.ParseUdp(pkt) + } +} + +// buildPortsMap creates a mapping of port numbers to protocol identifiers. If +// any two UdpProtocolPlugins operate on the same port number then an error +// will be returned. +func buildPortsMap(plugins map[protos.Protocol]protos.UdpProtocolPlugin) (map[uint16]protos.Protocol, error) { + var res = map[uint16]protos.Protocol{} + + for proto, protoPlugin := range plugins { + for _, port := range protoPlugin.GetPorts() { + old_proto, exists := res[uint16(port)] + if exists { + if old_proto == proto { + continue + } + return nil, fmt.Errorf("Duplicate port (%d) exists in %s and %s protocols", + port, old_proto, proto) + } + res[uint16(port)] = proto + } + } + + return res, nil +} + +// NewUdp creates and returns a new Udp. +func NewUdp(p protos.Protocols) (*Udp, error) { + portMap, err := buildPortsMap(p.GetAllUdp()) + if err != nil { + return nil, err + } + + udp := &Udp{protocols: p, portMap: portMap} + logp.Debug("udp", "Port map: %v", portMap) + + return udp, nil +} diff --git a/protos/udp/udp_test.go b/protos/udp/udp_test.go new file mode 100644 index 000000000000..4a2019bc58fd --- /dev/null +++ b/protos/udp/udp_test.go @@ -0,0 +1,231 @@ +package udp + +import ( + "net" + "testing" + "time" + + "github.com/elastic/libbeat/common" + "github.com/elastic/libbeat/logp" + "github.com/elastic/packetbeat/protos" + + "github.com/stretchr/testify/assert" +) + +// Protocol ID and port number used by TestProtocol in various tests. +const ( + PROTO = protos.Protocol(1) + PORT = 1234 +) + +type TestProtocols struct { + udp map[protos.Protocol]protos.UdpProtocolPlugin +} + +func (p TestProtocols) BpfFilter(with_vlans bool) string { + return "mock bpf filter" +} + +func (p TestProtocols) GetTcp(proto protos.Protocol) protos.TcpProtocolPlugin { + return nil +} + +func (p TestProtocols) GetUdp(proto protos.Protocol) protos.UdpProtocolPlugin { + return p.udp[proto] +} + +func (p TestProtocols) GetAll() map[protos.Protocol]protos.ProtocolPlugin { + return nil +} + +func (p TestProtocols) GetAllTcp() map[protos.Protocol]protos.TcpProtocolPlugin { + return nil +} + +func (p TestProtocols) GetAllUdp() map[protos.Protocol]protos.UdpProtocolPlugin { + return p.udp +} + +func (p TestProtocols) Register(proto protos.Protocol, plugin protos.ProtocolPlugin) { + return +} + +type TestProtocol struct { + Ports []int // Ports that the protocol operates on. + pkt *protos.Packet // UDP packet that the plugin was called to process. +} + +func (proto *TestProtocol) Init(test_mode bool, results chan common.MapStr) error { + return nil +} + +func (proto *TestProtocol) GetPorts() []int { + return proto.Ports +} + +func (proto *TestProtocol) ParseUdp(pkt *protos.Packet) { + proto.pkt = pkt +} + +type TestStruct struct { + protocols *TestProtocols + udp *Udp + plugin *TestProtocol +} + +// Helper method for creating mocks and the Udp instance under test. +func testSetup(t *testing.T) *TestStruct { + if testing.Verbose() { + logp.LogInit(logp.LOG_DEBUG, "", false, true, []string{"udp"}) + } + + protocols := &TestProtocols{} + protocols.udp = make(map[protos.Protocol]protos.UdpProtocolPlugin) + plugin := &TestProtocol{Ports: []int{PORT}} + protocols.udp[PROTO] = plugin + + udp, err := NewUdp(protocols) + if err != nil { + t.Error("Error creating UDP handler: ", err) + } + + return &TestStruct{protocols: protocols, udp: udp, plugin: plugin} +} + +func Test_buildPortsMap(t *testing.T) { + + type configTest struct { + Input map[protos.Protocol]protos.UdpProtocolPlugin + Output map[uint16]protos.Protocol + } + + // The protocols named here are not necessarily UDP. They are just used + // for testing purposes. + config_tests := []configTest{ + configTest{ + Input: map[protos.Protocol]protos.UdpProtocolPlugin{ + protos.HttpProtocol: &TestProtocol{Ports: []int{80, 8080}}, + }, + Output: map[uint16]protos.Protocol{ + 80: protos.HttpProtocol, + 8080: protos.HttpProtocol, + }, + }, + configTest{ + Input: map[protos.Protocol]protos.UdpProtocolPlugin{ + protos.HttpProtocol: &TestProtocol{Ports: []int{80, 8080}}, + protos.MysqlProtocol: &TestProtocol{Ports: []int{3306}}, + protos.RedisProtocol: &TestProtocol{Ports: []int{6379, 6380}}, + }, + Output: map[uint16]protos.Protocol{ + 80: protos.HttpProtocol, + 8080: protos.HttpProtocol, + 3306: protos.MysqlProtocol, + 6379: protos.RedisProtocol, + 6380: protos.RedisProtocol, + }, + }, + + // should ignore duplicate ports in the same protocol + configTest{ + Input: map[protos.Protocol]protos.UdpProtocolPlugin{ + protos.HttpProtocol: &TestProtocol{Ports: []int{80, 8080, 8080}}, + protos.MysqlProtocol: &TestProtocol{Ports: []int{3306}}, + }, + Output: map[uint16]protos.Protocol{ + 80: protos.HttpProtocol, + 8080: protos.HttpProtocol, + 3306: protos.MysqlProtocol, + }, + }, + } + + for _, test := range config_tests { + output, err := buildPortsMap(test.Input) + assert.Nil(t, err) + assert.Equal(t, test.Output, output) + } +} + +// Verify that buildPortsMap returns an error when two plugins are registered +// for the same port number. +func Test_buildPortsMap_portOverlapError(t *testing.T) { + type errTest struct { + Input map[protos.Protocol]protos.UdpProtocolPlugin + Err string + } + + // The protocols named here are not necessarily UDP. They are just used + // for testing purposes. + tests := []errTest{ + errTest{ + // Should raise error on duplicate port + Input: map[protos.Protocol]protos.UdpProtocolPlugin{ + protos.HttpProtocol: &TestProtocol{Ports: []int{80, 8080}}, + protos.MysqlProtocol: &TestProtocol{Ports: []int{3306}}, + protos.RedisProtocol: &TestProtocol{Ports: []int{6379, 6380, 3306}}, + }, + Err: "Duplicate port (3306) exists", + }, + } + + for _, test := range tests { + _, err := buildPortsMap(test.Input) + assert.NotNil(t, err) + assert.Contains(t, err.Error(), test.Err) + } +} + +// Verify that decideProtocol returns the protocol assocated with the +// packet's source port. +func Test_decideProtocol_bySrcPort(t *testing.T) { + test := testSetup(t) + tuple := common.NewIpPortTuple(4, + net.ParseIP("192.168.0.1"), PORT, + net.ParseIP("10.0.0.1"), 34898) + assert.Equal(t, PROTO, test.udp.decideProtocol(&tuple)) +} + +// Verify that decideProtocol returns the protocol assocated with the +// packet's destination port. +func Test_decideProtocol_byDstPort(t *testing.T) { + test := testSetup(t) + tuple := common.NewIpPortTuple(4, + net.ParseIP("10.0.0.1"), 34898, + net.ParseIP("192.168.0.1"), PORT) + assert.Equal(t, PROTO, test.udp.decideProtocol(&tuple)) +} + +// Verify that decideProtocol returns UnknownProtocol when given packet for +// which it does not have a plugin. +func TestProcess_unknownProtocol(t *testing.T) { + test := testSetup(t) + tuple := common.NewIpPortTuple(4, + net.ParseIP("10.0.0.1"), 34898, + net.ParseIP("192.168.0.1"), PORT+1) + assert.Equal(t, protos.UnknownProtocol, test.udp.decideProtocol(&tuple)) +} + +// Verify that Process ignores empty packets. +func TestProcess_emptyPayload(t *testing.T) { + test := testSetup(t) + tuple := common.NewIpPortTuple(4, + net.ParseIP("192.168.0.1"), PORT, + net.ParseIP("10.0.0.1"), 34898) + emptyPkt := &protos.Packet{Ts: time.Now(), Tuple: tuple, Payload: []byte{}} + test.udp.Process(emptyPkt) + assert.Nil(t, test.plugin.pkt) +} + +// Verify that Process finds the plugin associated with the packet and invokes +// ProcessUdp on it. +func TestProcess_nonEmptyPayload(t *testing.T) { + test := testSetup(t) + tuple := common.NewIpPortTuple(4, + net.ParseIP("192.168.0.1"), PORT, + net.ParseIP("10.0.0.1"), 34898) + payload := []byte{1} + pkt := &protos.Packet{Ts: time.Now(), Tuple: tuple, Payload: payload} + test.udp.Process(pkt) + assert.Equal(t, pkt, test.plugin.pkt) +} diff --git a/sniffer/sniffer.go b/sniffer/sniffer.go index 6f93054685f2..90ecb080b715 100644 --- a/sniffer/sniffer.go +++ b/sniffer/sniffer.go @@ -11,7 +11,10 @@ import ( "github.com/elastic/libbeat/logp" "github.com/elastic/packetbeat/config" + "github.com/elastic/packetbeat/decoder" + "github.com/elastic/packetbeat/protos" "github.com/elastic/packetbeat/protos/tcp" + "github.com/elastic/packetbeat/protos/udp" "github.com/tsg/gopacket" "github.com/tsg/gopacket/layers" @@ -26,7 +29,7 @@ type SnifferSetup struct { isAlive bool dumper *pcap.Dumper - Decoder *tcp.DecoderStruct + Decoder *decoder.DecoderStruct DataSource gopacket.PacketDataSource } @@ -205,11 +208,13 @@ func (sniffer *SnifferSetup) Datalink() layers.LinkType { return layers.LinkTypeEthernet } -func (sniffer *SnifferSetup) Init(test_mode bool, events chan common.MapStr) error { +func (sniffer *SnifferSetup) Init(test_mode bool, events chan common.MapStr, + tcp tcp.Processor, udp udp.Processor) error { if config.ConfigSingleton.Interfaces.Bpf_filter == "" { with_vlans := config.ConfigSingleton.Interfaces.With_vlans - config.ConfigSingleton.Interfaces.Bpf_filter = tcp.BpfFilter(with_vlans) + config.ConfigSingleton.Interfaces.Bpf_filter = protos.Protos.BpfFilter(with_vlans) } + logp.Debug("sniffer", "BPF filter: %s", config.ConfigSingleton.Interfaces.Bpf_filter) var err error if !test_mode { @@ -219,7 +224,7 @@ func (sniffer *SnifferSetup) Init(test_mode bool, events chan common.MapStr) err } } - sniffer.Decoder, err = tcp.CreateDecoder(sniffer.Datalink()) + sniffer.Decoder, err = decoder.NewDecoder(sniffer.Datalink(), tcp, udp) if err != nil { return fmt.Errorf("Error creating decoder: %v", err) }