From 4c34cc720778ebb3444f3ca0cd761d975b292049 Mon Sep 17 00:00:00 2001 From: Marcin Tojek Date: Mon, 3 Feb 2020 09:11:17 +0100 Subject: [PATCH 1/7] Refactor mqtt input --- filebeat/input/mqtt/client.go | 192 ++++----------------------------- filebeat/input/mqtt/config.go | 57 +++------- filebeat/input/mqtt/input.go | 164 ++++++++++++++++++---------- filebeat/input/mqtt/logging.go | 79 ++++++++++++++ 4 files changed, 222 insertions(+), 270 deletions(-) create mode 100644 filebeat/input/mqtt/logging.go diff --git a/filebeat/input/mqtt/client.go b/filebeat/input/mqtt/client.go index 0079dbf78f4..73e8d2f5344 100644 --- a/filebeat/input/mqtt/client.go +++ b/filebeat/input/mqtt/client.go @@ -18,191 +18,37 @@ package mqtt import ( - "crypto/tls" - "crypto/x509" - "encoding/json" - "io/ioutil" - "strings" - "time" + libmqtt "github.com/eclipse/paho.mqtt.golang" - "gopkg.in/vmihailenco/msgpack.v2" - - "github.com/elastic/beats/libbeat/beat" - "github.com/elastic/beats/libbeat/common" - "github.com/elastic/beats/libbeat/logp" - - MQTT "github.com/eclipse/paho.mqtt.golang" + "github.com/elastic/beats/libbeat/outputs" ) -func (input *mqttInput) newTLSConfig() (*tls.Config, error) { - config := input.config - - // Import trusted certificates from CAfile.pem. - // Alternatively, manually add CA certificates to - // default openssl CA bundle. - certpool := x509.NewCertPool() - if config.CA != "" { - logp.Info("[MQTT] Set the CA") - pemCerts, err := ioutil.ReadFile(config.CA) - if err != nil { - return nil, err - } - certpool.AppendCertsFromPEM(pemCerts) - } +func createClientOptions(config mqttInputConfig, onConnectHandler func(client libmqtt.Client)) (*libmqtt.ClientOptions, error) { + clientOptions := libmqtt.NewClientOptions(). + SetClientID(config.ClientID). + SetUsername(config.Username). + SetPassword(config.Password). + SetConnectRetry(true). + SetOnConnectHandler(onConnectHandler) - tlsconfig := &tls.Config{ - // RootCAs = certs used to verify server cert. - RootCAs: certpool, - // ClientAuth = whether to request cert from server. - // Since the server is set up for SSL, this happens - // anyways. - ClientAuth: tls.NoClientCert, - // ClientCAs = certs used to validate client cert. - ClientCAs: nil, - // InsecureSkipVerify = verify that cert contents - // match server. IP matches what is in cert etc. - InsecureSkipVerify: true, + for _, host := range config.Hosts { + clientOptions.AddBroker(host) } - // Import client certificate/key pair - if config.ClientCert != "" && config.ClientKey != "" { - logp.Info("[MQTT] Set the Certs") - cert, err := tls.LoadX509KeyPair(config.ClientCert, config.ClientKey) + if config.TLS != nil { + tlsConfig, err := outputs.LoadTLSConfig(config.TLS) if err != nil { return nil, err } - - // Certificates = list of certs client sends to server. - tlsconfig.Certificates = []tls.Certificate{cert} - } - - // Create tls.Config with desired tls properties - return tlsconfig, nil -} - -// Prepare MQTT client -func (input *mqttInput) setupMqttClient() error { - c := input.config - - logp.Info("[MQTT] Connect to broker URL: %s", c.Host) - - mqttClientOpt := MQTT.NewClientOptions() - mqttClientOpt.SetClientID(c.ClientID) - mqttClientOpt.AddBroker(c.Host) - - mqttClientOpt.SetMaxReconnectInterval(1 * time.Second) - mqttClientOpt.SetConnectionLostHandler(input.connectionLostHandler) - mqttClientOpt.SetOnConnectHandler(input.subscribeOnConnect) - mqttClientOpt.SetAutoReconnect(true) - - if c.Username != "" { - logp.Info("[MQTT] Broker username: %s", c.Username) - mqttClientOpt.SetUsername(c.Username) - } - - if c.Password != "" { - mqttClientOpt.SetPassword(c.Password) - } - - if c.SSL == true { - logp.Info("[MQTT] Configure session to use SSL") - tlsconfig, err := input.newTLSConfig() - if err != nil { - return err - } - mqttClientOpt.SetTLSConfig(tlsconfig) - } - - input.client = MQTT.NewClient(mqttClientOpt) - return nil -} - -func (input *mqttInput) connect() error { - if token := input.client.Connect(); token.WaitTimeout(input.config.WaitClose) && token.Error() != nil { - logp.Err("MQTT Failed to connect") - return token.Error() - } - logp.Info("MQTT Client connected: %t", input.client.IsConnected()) - return nil -} - -func (input *mqttInput) subscribeOnConnect(client MQTT.Client) { - subscriptions := prepareSubscriptionsForTopics(input.config.Topics, input.config.QoS) - - // Mqtt client - Subscribe to every topic in the config file, and bind with message handler - if token := input.client.SubscribeMultiple(subscriptions, input.onMessage); token.WaitTimeout(input.config.WaitClose) && token.Error() != nil { - logp.Error(token.Error()) - } - logp.Info("MQTT Subscribed to configured topics") -} - -// Mqtt message handler -func (input *mqttInput) onMessage(client MQTT.Client, msg MQTT.Message) { - logp.Debug("MQTT", "MQTT message received: %s", string(msg.Payload())) - var beatEvent beat.Event - eventFields := make(common.MapStr) - - // default case - var mqtt = make(common.MapStr) - eventFields["message"] = string(msg.Payload()) - if input.config.DecodePayload { - mqtt["fields"] = decodeBytes(msg.Payload()) - } - - eventFields["is_system_topic"] = strings.HasPrefix(msg.Topic(), "$") - eventFields["topic"] = msg.Topic() - - mqtt["id"] = msg.MessageID() - mqtt["retained"] = msg.Retained() - eventFields["mqtt"] = mqtt - - // Finally sending the message to elasticsearch - beatEvent.Fields = eventFields - isSent := input.outlet.OnEvent(beatEvent) - - logp.Debug("MQTT", "Event sent: %t", isSent) -} - -// connectionLostHandler will try to reconnect when connection is lost -func (input *mqttInput) connectionLostHandler(client MQTT.Client, reason error) { - logp.Warn("[MQTT] Connection lost: %s", reason.Error()) - - //Rerun the input - input.Run() -} - -// decodeBytes will try to decode the bytes in the following order -// 1.) Check for msgpack format -// 2.) Check for json format -// 3.) If every check fails, it will -// return the the string representation -func decodeBytes(payload []byte) common.MapStr { - event := make(common.MapStr) - - // A msgpack payload must be a json-like object - err := msgpack.Unmarshal(payload, &event) - if err == nil { - logp.Debug("MQTT", "Payload decoded - msgpack") - return event + clientOptions.SetTLSConfig(tlsConfig.BuildModuleConfig("")) } - - err = json.Unmarshal(payload, &event) - if err == nil { - logp.Debug("MQTT", "Payload decoded - as json") - return event - } - - logp.Debug("MQTT", "decoded - as text") - return event + return clientOptions, nil } -// ParseTopics will parse the config file and return a map with topic:QoS -func prepareSubscriptionsForTopics(topics []string, qos int) map[string]byte { - subscriptions := make(map[string]byte) - for _, value := range topics { - // Finally, filling the subscriptions map - subscriptions[value] = byte(qos) - logp.Info("Subscribe to %v with QoS %v", value, qos) +func createClientSubscriptions(config mqttInputConfig) map[string]byte { + subscriptions := map[string]byte{} + for _, topic := range config.Topics { + subscriptions[topic] = byte(config.QoS) } return subscriptions } diff --git a/filebeat/input/mqtt/config.go b/filebeat/input/mqtt/config.go index 9be3ef6103e..03abd82e5d0 100644 --- a/filebeat/input/mqtt/config.go +++ b/filebeat/input/mqtt/config.go @@ -19,57 +19,34 @@ package mqtt import ( "errors" - "fmt" - "time" + + "github.com/elastic/beats/libbeat/common/transport/tlscommon" ) type mqttInputConfig struct { - Host string `config:"host"` - Topics []string `config:"topics"` - Username string `config:"user"` - Password string `config:"password"` - QoS int `config:"QoS"` - DecodePayload bool `config:"decode_payload"` - SSL bool `config:"ssl"` - CA string `config:"CA"` - ClientCert string `config:"clientCert"` - ClientKey string `config:"clientKey"` - ClientID string `config:"clientID"` - WaitClose time.Duration `config:"wait_close" validate:"min=0"` - ConnectBackoff time.Duration `config:"connect_backoff" validate:"min=0"` + Hosts []string `config:"hosts" validate:"required,min=1"` + Topics []string `config:"topics" validate:"nonzero,min=1"` + QoS int `config:"qos" validate:"nonzero,min=0,max=2"` + + ClientID string `config:"clientID" validate:"nonzero"` + Username string `config:"user"` + Password string `config:"password"` + + TLS *tlscommon.Config `config:"ssl"` } -// The default config for the mqtt input +// The default config for the mqtt input. func defaultConfig() mqttInputConfig { return mqttInputConfig{ - Host: "localhost", - Topics: []string{"#"}, - ClientID: "Filebeat", - Username: "", - Password: "", - DecodePayload: true, - QoS: 0, - SSL: false, - CA: "", - ClientCert: "", - ClientKey: "", - WaitClose: 5 * time.Second, - ConnectBackoff: 30 * time.Second, + ClientID: "filebeat", + Topics: []string{"#"}, } } // Validate validates the config. -func (c *mqttInputConfig) Validate() error { - if c.Host == "" { - return errors.New("no host configured") - } - - if c.Username != "" && c.Password == "" { - return fmt.Errorf("password must be set when username is configured") - } - - if len(c.ClientID) > 23 || len(c.ClientID) < 1 { - return fmt.Errorf("client id must be between 1 and 23 characters long") +func (mic *mqttInputConfig) Validate() error { + if len(mic.ClientID) < 1 || len(mic.ClientID) > 23 { + return errors.New("ClientID must be between 1 and 23 characters long") } return nil } diff --git a/filebeat/input/mqtt/input.go b/filebeat/input/mqtt/input.go index e8253d901ab..20b24709a08 100644 --- a/filebeat/input/mqtt/input.go +++ b/filebeat/input/mqtt/input.go @@ -18,7 +18,12 @@ package mqtt import ( + "strings" "sync" + "time" + + libmqtt "github.com/eclipse/paho.mqtt.golang" + "github.com/pkg/errors" "github.com/elastic/beats/filebeat/channel" "github.com/elastic/beats/filebeat/input" @@ -26,12 +31,25 @@ import ( "github.com/elastic/beats/libbeat/common" "github.com/elastic/beats/libbeat/common/backoff" "github.com/elastic/beats/libbeat/logp" +) - "github.com/pkg/errors" +const ( + disconnectTimeout = 3 * 1000 // 3000 ms = 3 sec - MQTT "github.com/eclipse/paho.mqtt.golang" + subscribeTimeout = 35 * time.Second // in client: subscribeWaitTimeout = 30s + subscribeRetryInterval = 1 * time.Second ) +// Input contains the input and its config +type mqttInput struct { + once sync.Once + + logger *logp.Logger + + client libmqtt.Client + inflightMessages *sync.WaitGroup +} + func init() { err := input.Register("mqtt", NewInput) if err != nil { @@ -39,24 +57,12 @@ func init() { } } -// Input contains the input and its config -type mqttInput struct { - config mqttInputConfig - context input.Context - outlet channel.Outleter - log *logp.Logger - mqttWaitGroup sync.WaitGroup - runOnce sync.Once - client MQTT.Client -} - -// NewInput creates a new mqtt input +// NewInput method creates a new mqtt input, func NewInput( cfg *common.Config, connector channel.Connector, inputContext input.Context, ) (input.Input, error) { - config := defaultConfig() if err := cfg.Unpack(&config); err != nil { return nil, errors.Wrap(err, "reading mqtt input config") @@ -66,65 +72,109 @@ func NewInput( Processing: beat.ProcessingConfig{ DynamicFields: inputContext.DynamicFields, }, - // ACKEvents: func(events []interface{}) { - // for _, event := range events { - // if meta, ok := event.(eventMeta); ok { - // meta.handler.ack(meta.message) - // } - // } - // }, - WaitClose: config.WaitClose, }) if err != nil { return nil, err } - input := &mqttInput{ - config: config, - context: inputContext, - outlet: out, - log: logp.NewLogger("mqtt input").With("host", config.Host), - } + logger := logp.NewLogger("mqtt input").With("hosts", config.Hosts) + setupLibraryLogging() - err = input.setupMqttClient() + inflightMessages := new(sync.WaitGroup) + clientSubscriptions := createClientSubscriptions(config) + onMessageHandler := createOnMessageHandler(logger, out, inflightMessages) + onConnectHandler := createOnConnectHandler(logger, &inputContext, onMessageHandler, clientSubscriptions) + clientOptions, err := createClientOptions(config, onConnectHandler) if err != nil { return nil, err } - return input, nil + return &mqttInput{ + client: libmqtt.NewClient(clientOptions), + inflightMessages: inflightMessages, + logger: logp.NewLogger("mqtt input").With("hosts", config.Hosts), + }, nil +} + +func createOnMessageHandler(logger *logp.Logger, outlet channel.Outleter, inflightMessages *sync.WaitGroup) func(client libmqtt.Client, message libmqtt.Message) { + return func(client libmqtt.Client, message libmqtt.Message) { + inflightMessages.Add(1) + + logger.Debugf("Received message on topic '%s', messageID: %d, size: %d", message.Topic(), + message.MessageID(), len(message.Payload())) + + mqttFields := common.MapStr{ + "duplicate": message.Duplicate(), + "message_id": message.MessageID(), + "qos": message.Qos(), + "retained": message.Retained(), + "topic": message.Topic(), + } + outlet.OnEvent(beat.Event{ + Timestamp: time.Now(), + Fields: common.MapStr{ + "message": string(message.Payload()), + "mqtt": mqttFields, + }, + }) + + inflightMessages.Done() + } } -// Run starts the input by scanning for incoming messages and errors. -func (input *mqttInput) Run() { - input.runOnce.Do(func() { - go func() { - - // If the consumer fails to connect, we use exponential backoff with - // jitter up to 8 * the initial backoff interval. - backoff := backoff.NewEqualJitterBackoff( - input.context.Done, - input.config.ConnectBackoff, - 8*input.config.ConnectBackoff) - - for !input.client.IsConnected() { - err := input.connect() - if err != nil { - logp.Error(err) - backoff.Wait() +func createOnConnectHandler(logger *logp.Logger, inputContext *input.Context, onMessageHandler func(client libmqtt.Client, message libmqtt.Message), clientSubscriptions map[string]byte) func(client libmqtt.Client) { + // The function subscribes the client to the specific topics (with retry backoff in case of failure). + return func(client libmqtt.Client) { + backoff := backoff.NewEqualJitterBackoff( + inputContext.Done, + subscribeRetryInterval, + 8*subscribeRetryInterval) + + var topics []string + for topic := range clientSubscriptions { + topics = append(topics, topic) + } + + var success bool + for !success { + logger.Debugf("Try subscribe to topics: %v", strings.Join(topics, ", ")) + + token := client.SubscribeMultiple(clientSubscriptions, onMessageHandler) + if !token.WaitTimeout(subscribeTimeout) { + if token.Error() != nil { + logger.Warnf("Subscribing to topics failed due to error: %v", token.Error()) + } + + if !backoff.Wait() { + backoff.Reset() + success = true } + } else { + backoff.Reset() + success = true } - //All the rest is working asynchronously within the MQTT client - }() + } + } +} + +// Run method starts the mqtt input and processing. +// The mqtt client starts in auto-connect mode (with connection retries and resuming topic subscriptions). +func (mi *mqttInput) Run() { + mi.once.Do(func() { + mi.logger.Debug("Run the input once.") + mi.client.Connect() }) } -// Stop disconnects the MQTT client -func (input *mqttInput) Stop() { - input.client.Disconnect(250) +// Stop method stops the input. +func (mi *mqttInput) Stop() { + mi.logger.Debug("Stop the input.") + mi.client.Disconnect(disconnectTimeout) + mi.Wait() } -// Wait should shut down the input and wait for it to complete -// The disconnect of the client will do this for us -func (input *mqttInput) Wait() { - input.Stop() +// Wait method waits until event processing is finished and stops the input. +func (mi *mqttInput) Wait() { + mi.logger.Debug("Wait for the input to finish processing.") + mi.inflightMessages.Wait() } diff --git a/filebeat/input/mqtt/logging.go b/filebeat/input/mqtt/logging.go new file mode 100644 index 00000000000..35510c20b34 --- /dev/null +++ b/filebeat/input/mqtt/logging.go @@ -0,0 +1,79 @@ +// Licensed to Elasticsearch B.V. under one or more contributor +// license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright +// ownership. Elasticsearch B.V. licenses this file to you under +// the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package mqtt + +import ( + "sync" + + libmqtt "github.com/eclipse/paho.mqtt.golang" + "go.uber.org/zap" + + "github.com/elastic/beats/libbeat/logp" +) + +var setupLoggingOnce sync.Once + +type loggerWrapper struct { + log *logp.Logger +} + +type ( + debugLogger loggerWrapper + errorLogger loggerWrapper + warnLogger loggerWrapper +) + +var ( + _ libmqtt.Logger = new(debugLogger) + _ libmqtt.Logger = new(errorLogger) + _ libmqtt.Logger = new(warnLogger) +) + +func setupLibraryLogging() { + setupLoggingOnce.Do(func() { + logger := logp.NewLogger("libmqtt", zap.AddCallerSkip(1)) + libmqtt.CRITICAL = &errorLogger{log: logger} + libmqtt.DEBUG = &debugLogger{log: logger} + libmqtt.ERROR = &errorLogger{log: logger} + libmqtt.WARN = &warnLogger{log: logger} + }) +} + +func (l *debugLogger) Println(v ...interface{}) { + l.log.Debug(v...) +} + +func (l *debugLogger) Printf(format string, v ...interface{}) { + l.log.Debugf(format, v...) +} + +func (l *errorLogger) Println(v ...interface{}) { + l.log.Error(v...) +} + +func (l *errorLogger) Printf(format string, v ...interface{}) { + l.log.Errorf(format, v...) +} + +func (l *warnLogger) Println(v ...interface{}) { + l.log.Warn(v...) +} + +func (l *warnLogger) Printf(format string, v ...interface{}) { + l.log.Warnf(format, v...) +} From 796b1d4afd294fa625c4cd8f6683364ccef3d6c8 Mon Sep 17 00:00:00 2001 From: Marcin Tojek Date: Mon, 3 Feb 2020 11:17:24 +0100 Subject: [PATCH 2/7] Fix: comment --- filebeat/input/mqtt/input.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/filebeat/input/mqtt/input.go b/filebeat/input/mqtt/input.go index 20b24709a08..909e80c320f 100644 --- a/filebeat/input/mqtt/input.go +++ b/filebeat/input/mqtt/input.go @@ -173,7 +173,7 @@ func (mi *mqttInput) Stop() { mi.Wait() } -// Wait method waits until event processing is finished and stops the input. +// Wait method waits until event processing is finished. func (mi *mqttInput) Wait() { mi.logger.Debug("Wait for the input to finish processing.") mi.inflightMessages.Wait() From 96758449644f06968a429c45f4982235001042a0 Mon Sep 17 00:00:00 2001 From: Marcin Tojek Date: Mon, 3 Feb 2020 13:38:03 +0100 Subject: [PATCH 3/7] Add unit tests --- filebeat/input/mqtt/client_mocked.go | 192 +++++++++++++++++++++++++ filebeat/input/mqtt/input.go | 4 +- filebeat/input/mqtt/input_test.go | 205 +++++++++++++++++++++++++++ 3 files changed, 400 insertions(+), 1 deletion(-) create mode 100644 filebeat/input/mqtt/client_mocked.go create mode 100644 filebeat/input/mqtt/input_test.go diff --git a/filebeat/input/mqtt/client_mocked.go b/filebeat/input/mqtt/client_mocked.go new file mode 100644 index 00000000000..6afeaa0726b --- /dev/null +++ b/filebeat/input/mqtt/client_mocked.go @@ -0,0 +1,192 @@ +// Licensed to Elasticsearch B.V. under one or more contributor +// license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright +// ownership. Elasticsearch B.V. licenses this file to you under +// the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package mqtt + +import ( + "time" + + libmqtt "github.com/eclipse/paho.mqtt.golang" + + "github.com/elastic/beats/filebeat/channel" + "github.com/elastic/beats/libbeat/beat" + "github.com/elastic/beats/libbeat/common" +) + +type mockedMessage struct { + duplicate bool + messageID uint16 + qos byte + retained bool + topic string + payload []byte + ack func() +} + +var _ libmqtt.Message = new(mockedMessage) + +func (m *mockedMessage) Duplicate() bool { + return m.duplicate +} + +func (m *mockedMessage) Qos() byte { + return m.qos +} + +func (m *mockedMessage) Retained() bool { + return m.retained +} + +func (m *mockedMessage) Topic() string { + return m.topic +} + +func (m *mockedMessage) MessageID() uint16 { + return m.messageID +} + +func (m *mockedMessage) Payload() []byte { + return m.payload +} + +func (m *mockedMessage) Ack() { + panic("implement me") +} + +type mockedToken struct{} + +var _ libmqtt.Token = new(mockedToken) + +func (m *mockedToken) Wait() bool { + panic("implement me") +} + +func (m *mockedToken) WaitTimeout(time.Duration) bool { + return true +} + +func (m *mockedToken) Error() error { + panic("implement me") +} + +type mockedClient struct { + connectCount int + disconnectCount int + subscribeMultipleCount int + + subscriptions []string + messages []mockedMessage + + onConnectHandler func(client libmqtt.Client) + onMessageHandler func(client libmqtt.Client, message libmqtt.Message) +} + +var _ libmqtt.Client = new(mockedClient) + +func (m *mockedClient) IsConnected() bool { + panic("implement me") +} + +func (m *mockedClient) IsConnectionOpen() bool { + panic("implement me") +} + +func (m *mockedClient) Connect() libmqtt.Token { + m.connectCount++ + + if m.onConnectHandler != nil { + m.onConnectHandler(m) + } + return nil +} + +func (m *mockedClient) Disconnect(quiesce uint) { + m.disconnectCount++ +} + +func (m *mockedClient) Publish(topic string, qos byte, retained bool, payload interface{}) libmqtt.Token { + panic("implement me") +} + +func (m *mockedClient) Subscribe(topic string, qos byte, callback libmqtt.MessageHandler) libmqtt.Token { + panic("implement me") +} + +func (m *mockedClient) SubscribeMultiple(filters map[string]byte, callback libmqtt.MessageHandler) libmqtt.Token { + m.subscribeMultipleCount++ + + for filter := range filters { + m.subscriptions = append(m.subscriptions, filter) + } + m.onMessageHandler = callback + + go func() { + for _, msg := range m.messages { + m.onMessageHandler(m, &msg) + } + }() + return new(mockedToken) +} + +func (m *mockedClient) Unsubscribe(topics ...string) libmqtt.Token { + panic("implement me") +} + +func (m *mockedClient) AddRoute(topic string, callback libmqtt.MessageHandler) { + panic("implement me") +} + +func (m *mockedClient) OptionsReader() libmqtt.ClientOptionsReader { + panic("implement me") +} + +type mockedConnector struct { + connectWithError error + outlet channel.Outleter +} + +var _ channel.Connector = new(mockedConnector) + +func (m *mockedConnector) Connect(*common.Config) (channel.Outleter, error) { + panic("implement me") +} + +func (m *mockedConnector) ConnectWith(*common.Config, beat.ClientConfig) (channel.Outleter, error) { + if m.connectWithError != nil { + return nil, m.connectWithError + } + return m.outlet, nil +} + +type mockedOutleter struct { + events chan<- beat.Event +} + +var _ channel.Outleter = new(mockedOutleter) + +func (m mockedOutleter) Close() error { + panic("implement me") +} + +func (m mockedOutleter) Done() <-chan struct{} { + panic("implement me") +} + +func (m mockedOutleter) OnEvent(event beat.Event) bool { + m.events <- event + return true +} diff --git a/filebeat/input/mqtt/input.go b/filebeat/input/mqtt/input.go index 909e80c320f..3c8e7e26a06 100644 --- a/filebeat/input/mqtt/input.go +++ b/filebeat/input/mqtt/input.go @@ -40,6 +40,8 @@ const ( subscribeRetryInterval = 1 * time.Second ) +var newMqttClient = libmqtt.NewClient + // Input contains the input and its config type mqttInput struct { once sync.Once @@ -90,7 +92,7 @@ func NewInput( } return &mqttInput{ - client: libmqtt.NewClient(clientOptions), + client: newMqttClient(clientOptions), inflightMessages: inflightMessages, logger: logp.NewLogger("mqtt input").With("hosts", config.Hosts), }, nil diff --git a/filebeat/input/mqtt/input_test.go b/filebeat/input/mqtt/input_test.go new file mode 100644 index 00000000000..cb6ee0526a9 --- /dev/null +++ b/filebeat/input/mqtt/input_test.go @@ -0,0 +1,205 @@ +// Licensed to Elasticsearch B.V. under one or more contributor +// license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright +// ownership. Elasticsearch B.V. licenses this file to you under +// the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package mqtt + +import ( + "errors" + "sync" + "testing" + + libmqtt "github.com/eclipse/paho.mqtt.golang" + "github.com/stretchr/testify/require" + + finput "github.com/elastic/beats/filebeat/input" + "github.com/elastic/beats/libbeat/beat" + "github.com/elastic/beats/libbeat/common" + "github.com/elastic/beats/libbeat/logp" +) + +var ( + logger = logp.NewLogger("test") +) + +func TestNewInput_MissingConfigField(t *testing.T) { + config := common.MustNewConfigFrom(common.MapStr{ + "topics": "#", + }) + connector := new(mockedConnector) + var inputContext finput.Context + + input, err := NewInput(config, connector, inputContext) + + require.Error(t, err) + require.Nil(t, input) +} + +func TestNewInput_ConnectWithFailed(t *testing.T) { + connectWithError := errors.New("failure") + config := common.MustNewConfigFrom(common.MapStr{ + "hosts": "tcp://mocked:1234", + "topics": "#", + }) + connector := &mockedConnector{ + connectWithError: connectWithError, + } + var inputContext finput.Context + + input, err := NewInput(config, connector, inputContext) + + require.Equal(t, connectWithError, err) + require.Nil(t, input) +} + +func TestNewInput_Run(t *testing.T) { + config := common.MustNewConfigFrom(common.MapStr{ + "hosts": "tcp://mocked:1234", + "topics": []string{"first", "second"}, + "qos": 2, + }) + + events := make(chan beat.Event) + outlet := &mockedOutleter{ + events: events, + } + connector := &mockedConnector{ + outlet: outlet, + } + var inputContext finput.Context + + firstMessage := mockedMessage{ + duplicate: false, + messageID: 1, + qos: 2, + retained: false, + topic: "first", + payload: []byte("first-message"), + } + secondMessage := mockedMessage{ + duplicate: false, + messageID: 2, + qos: 2, + retained: false, + topic: "second", + payload: []byte("second-message"), + } + + var client *mockedClient + newMqttClient = func(o *libmqtt.ClientOptions) libmqtt.Client { + client = &mockedClient{ + onConnectHandler: o.OnConnect, + messages: []mockedMessage{firstMessage, secondMessage}, + } + return client + } + + input, err := NewInput(config, connector, inputContext) + require.NoError(t, err) + require.NotNil(t, input) + + input.Run() + + require.Equal(t, 1, client.connectCount) + require.Equal(t, 1, client.subscribeMultipleCount) + require.ElementsMatch(t, []string{"first", "second"}, client.subscriptions) + + for _, event := range []beat.Event{<-events, <-events} { + topic, err := event.GetValue("mqtt.topic") + require.NoError(t, err) + + if topic == "first" { + assertEventMatches(t, firstMessage, event) + } else { + assertEventMatches(t, secondMessage, event) + } + } +} + +func TestRun_Once(t *testing.T) { + client := new(mockedClient) + input := &mqttInput{ + client: client, + logger: logger, + } + + input.Run() + + require.Equal(t, 1, client.connectCount) +} + +func TestRun_Twice(t *testing.T) { + client := new(mockedClient) + input := &mqttInput{ + client: client, + logger: logger, + } + + input.Run() + input.Run() + + require.Equal(t, 1, client.connectCount) +} + +func TestStop(t *testing.T) { + inflightMessages := new(sync.WaitGroup) + client := new(mockedClient) + input := &mqttInput{ + client: client, + logger: logger, + inflightMessages: inflightMessages, + } + + input.Stop() + + require.Equal(t, 1, client.disconnectCount) +} + +func TestWait(t *testing.T) { + inflightMessages := new(sync.WaitGroup) + input := &mqttInput{ + logger: logger, + inflightMessages: inflightMessages, + } + + input.Wait() +} + +func assertEventMatches(t *testing.T, expected mockedMessage, got beat.Event) { + topic, err := got.GetValue("mqtt.topic") + require.NoError(t, err) + require.Equal(t, expected.topic, topic) + + duplicate, err := got.GetValue("mqtt.duplicate") + require.NoError(t, err) + require.Equal(t, expected.duplicate, duplicate) + + messageID, err := got.GetValue("mqtt.message_id") + require.NoError(t, err) + require.Equal(t, expected.messageID, messageID) + + qos, err := got.GetValue("mqtt.qos") + require.NoError(t, err) + require.Equal(t, expected.qos, qos) + + retained, err := got.GetValue("mqtt.retained") + require.NoError(t, err) + require.Equal(t, expected.retained, retained) + + message, err := got.GetValue("message") + require.NoError(t, err) + require.Equal(t, string(expected.payload), message) +} From 133bc96143103c663e19715ac8a9ce9f43e0fe40 Mon Sep 17 00:00:00 2001 From: Marcin Tojek Date: Mon, 3 Feb 2020 15:33:31 +0100 Subject: [PATCH 4/7] Test: input run --- filebeat/input/mqtt/input_test.go | 57 +++++++++++++++++++++++++++++-- 1 file changed, 54 insertions(+), 3 deletions(-) diff --git a/filebeat/input/mqtt/input_test.go b/filebeat/input/mqtt/input_test.go index cb6ee0526a9..d89fb31df97 100644 --- a/filebeat/input/mqtt/input_test.go +++ b/filebeat/input/mqtt/input_test.go @@ -72,9 +72,9 @@ func TestNewInput_Run(t *testing.T) { "qos": 2, }) - events := make(chan beat.Event) + eventsCh := make(chan beat.Event) outlet := &mockedOutleter{ - events: events, + events: eventsCh, } connector := &mockedConnector{ outlet: outlet, @@ -117,7 +117,7 @@ func TestNewInput_Run(t *testing.T) { require.Equal(t, 1, client.subscribeMultipleCount) require.ElementsMatch(t, []string{"first", "second"}, client.subscriptions) - for _, event := range []beat.Event{<-events, <-events} { + for _, event := range []beat.Event{<-eventsCh, <-eventsCh} { topic, err := event.GetValue("mqtt.topic") require.NoError(t, err) @@ -129,6 +129,57 @@ func TestNewInput_Run(t *testing.T) { } } +func TestNewInput_Run_Stop(t *testing.T) { + config := common.MustNewConfigFrom(common.MapStr{ + "hosts": "tcp://mocked:1234", + "topics": []string{"first", "second"}, + "qos": 2, + }) + + eventsCh := make(chan beat.Event) + outlet := &mockedOutleter{ + events: eventsCh, + } + connector := &mockedConnector{ + outlet: outlet, + } + var inputContext finput.Context + + const numMessages = 5 + var messages []mockedMessage + for i := 0; i < numMessages; i++ { + messages = append(messages, mockedMessage{ + duplicate: false, + messageID: 1, + qos: 2, + retained: false, + topic: "first", + payload: []byte("first-message"), + }) + } + + var client *mockedClient + newMqttClient = func(o *libmqtt.ClientOptions) libmqtt.Client { + client = &mockedClient{ + onConnectHandler: o.OnConnect, + messages: messages, + } + return client + } + + input, err := NewInput(config, connector, inputContext) + require.NoError(t, err) + require.NotNil(t, input) + + input.Run() + + go func() { + for range eventsCh {} + }() + + input.Stop() +} + func TestRun_Once(t *testing.T) { client := new(mockedClient) input := &mqttInput{ From 1c73e798e2d10214113bcf362bfbf9583d8b4036 Mon Sep 17 00:00:00 2001 From: Marcin Tojek Date: Mon, 3 Feb 2020 17:11:41 +0100 Subject: [PATCH 5/7] Fix Test: run and stop --- filebeat/input/mqtt/client_mocked.go | 17 +++++++++-------- filebeat/input/mqtt/input_test.go | 26 ++++++++++++++++++++------ 2 files changed, 29 insertions(+), 14 deletions(-) diff --git a/filebeat/input/mqtt/client_mocked.go b/filebeat/input/mqtt/client_mocked.go index 6afeaa0726b..47abb13a06a 100644 --- a/filebeat/input/mqtt/client_mocked.go +++ b/filebeat/input/mqtt/client_mocked.go @@ -134,11 +134,13 @@ func (m *mockedClient) SubscribeMultiple(filters map[string]byte, callback libmq } m.onMessageHandler = callback - go func() { - for _, msg := range m.messages { - m.onMessageHandler(m, &msg) - } - }() + for _, msg := range m.messages { + var thatMsg = msg + go func() { + m.onMessageHandler(m, &thatMsg) + }() + } + return new(mockedToken) } @@ -173,7 +175,7 @@ func (m *mockedConnector) ConnectWith(*common.Config, beat.ClientConfig) (channe } type mockedOutleter struct { - events chan<- beat.Event + onEventHandler func(event beat.Event) bool } var _ channel.Outleter = new(mockedOutleter) @@ -187,6 +189,5 @@ func (m mockedOutleter) Done() <-chan struct{} { } func (m mockedOutleter) OnEvent(event beat.Event) bool { - m.events <- event - return true + return m.onEventHandler(event) } diff --git a/filebeat/input/mqtt/input_test.go b/filebeat/input/mqtt/input_test.go index d89fb31df97..f1d8b082855 100644 --- a/filebeat/input/mqtt/input_test.go +++ b/filebeat/input/mqtt/input_test.go @@ -21,6 +21,7 @@ import ( "errors" "sync" "testing" + "time" libmqtt "github.com/eclipse/paho.mqtt.golang" "github.com/stretchr/testify/require" @@ -74,7 +75,10 @@ func TestNewInput_Run(t *testing.T) { eventsCh := make(chan beat.Event) outlet := &mockedOutleter{ - events: eventsCh, + onEventHandler: func(event beat.Event) bool { + eventsCh <- event + return true + }, } connector := &mockedConnector{ outlet: outlet, @@ -136,16 +140,23 @@ func TestNewInput_Run_Stop(t *testing.T) { "qos": 2, }) + const numMessages = 5 + + var eventProcessing sync.WaitGroup + eventProcessing.Add(numMessages) eventsCh := make(chan beat.Event) outlet := &mockedOutleter{ - events: eventsCh, + onEventHandler: func(event beat.Event) bool { + eventProcessing.Done() + eventsCh <- event + return true + }, } connector := &mockedConnector{ outlet: outlet, } var inputContext finput.Context - const numMessages = 5 var messages []mockedMessage for i := 0; i < numMessages; i++ { messages = append(messages, mockedMessage{ @@ -153,8 +164,8 @@ func TestNewInput_Run_Stop(t *testing.T) { messageID: 1, qos: 2, retained: false, - topic: "first", - payload: []byte("first-message"), + topic: "topic", + payload: []byte("a-message"), }) } @@ -172,9 +183,12 @@ func TestNewInput_Run_Stop(t *testing.T) { require.NotNil(t, input) input.Run() + eventProcessing.Wait() go func() { - for range eventsCh {} + time.Sleep(100 * time.Millisecond) // let input.Stop() be executed. + for range eventsCh { + } }() input.Stop() From 2e1955ca95346692236ab78b594def1817d840b1 Mon Sep 17 00:00:00 2001 From: Marcin Tojek Date: Tue, 4 Feb 2020 12:22:00 +0100 Subject: [PATCH 6/7] Test: backoff --- filebeat/input/mqtt/client_mocked.go | 35 +++++++++++-- filebeat/input/mqtt/input.go | 7 ++- filebeat/input/mqtt/input_test.go | 75 ++++++++++++++++++++++++++++ 3 files changed, 111 insertions(+), 6 deletions(-) diff --git a/filebeat/input/mqtt/client_mocked.go b/filebeat/input/mqtt/client_mocked.go index 47abb13a06a..dadccdf9356 100644 --- a/filebeat/input/mqtt/client_mocked.go +++ b/filebeat/input/mqtt/client_mocked.go @@ -25,6 +25,7 @@ import ( "github.com/elastic/beats/filebeat/channel" "github.com/elastic/beats/libbeat/beat" "github.com/elastic/beats/libbeat/common" + "github.com/elastic/beats/libbeat/common/backoff" ) type mockedMessage struct { @@ -67,7 +68,28 @@ func (m *mockedMessage) Ack() { panic("implement me") } -type mockedToken struct{} +type mockedBackoff struct { + resetCount int + + waits []bool + waitIndex int +} + +var _ backoff.Backoff = new(mockedBackoff) + +func (m *mockedBackoff) Wait() bool { + wait := m.waits[m.waitIndex] + m.waitIndex++ + return wait +} + +func (m *mockedBackoff) Reset() { + m.resetCount++ +} + +type mockedToken struct { + timeout bool +} var _ libmqtt.Token = new(mockedToken) @@ -76,11 +98,11 @@ func (m *mockedToken) Wait() bool { } func (m *mockedToken) WaitTimeout(time.Duration) bool { - return true + return m.timeout } func (m *mockedToken) Error() error { - panic("implement me") + return nil } type mockedClient struct { @@ -91,6 +113,9 @@ type mockedClient struct { subscriptions []string messages []mockedMessage + tokens []libmqtt.Token + tokenIndex int + onConnectHandler func(client libmqtt.Client) onMessageHandler func(client libmqtt.Client, message libmqtt.Message) } @@ -141,7 +166,9 @@ func (m *mockedClient) SubscribeMultiple(filters map[string]byte, callback libmq }() } - return new(mockedToken) + token := m.tokens[m.tokenIndex] + m.tokenIndex++ + return token } func (m *mockedClient) Unsubscribe(topics ...string) libmqtt.Token { diff --git a/filebeat/input/mqtt/input.go b/filebeat/input/mqtt/input.go index 3c8e7e26a06..f1829eca4f0 100644 --- a/filebeat/input/mqtt/input.go +++ b/filebeat/input/mqtt/input.go @@ -40,7 +40,10 @@ const ( subscribeRetryInterval = 1 * time.Second ) -var newMqttClient = libmqtt.NewClient +var ( + newMqttClient = libmqtt.NewClient + newBackoff = backoff.NewEqualJitterBackoff +) // Input contains the input and its config type mqttInput struct { @@ -127,7 +130,7 @@ func createOnMessageHandler(logger *logp.Logger, outlet channel.Outleter, inflig func createOnConnectHandler(logger *logp.Logger, inputContext *input.Context, onMessageHandler func(client libmqtt.Client, message libmqtt.Message), clientSubscriptions map[string]byte) func(client libmqtt.Client) { // The function subscribes the client to the specific topics (with retry backoff in case of failure). return func(client libmqtt.Client) { - backoff := backoff.NewEqualJitterBackoff( + backoff := newBackoff( inputContext.Done, subscribeRetryInterval, 8*subscribeRetryInterval) diff --git a/filebeat/input/mqtt/input_test.go b/filebeat/input/mqtt/input_test.go index f1d8b082855..172c7badd1f 100644 --- a/filebeat/input/mqtt/input_test.go +++ b/filebeat/input/mqtt/input_test.go @@ -29,6 +29,7 @@ import ( finput "github.com/elastic/beats/filebeat/input" "github.com/elastic/beats/libbeat/beat" "github.com/elastic/beats/libbeat/common" + "github.com/elastic/beats/libbeat/common/backoff" "github.com/elastic/beats/libbeat/logp" ) @@ -107,6 +108,9 @@ func TestNewInput_Run(t *testing.T) { client = &mockedClient{ onConnectHandler: o.OnConnect, messages: []mockedMessage{firstMessage, secondMessage}, + tokens: []libmqtt.Token{&mockedToken{ + timeout: true, + }}, } return client } @@ -174,6 +178,9 @@ func TestNewInput_Run_Stop(t *testing.T) { client = &mockedClient{ onConnectHandler: o.OnConnect, messages: messages, + tokens: []libmqtt.Token{&mockedToken{ + timeout: true, + }}, } return client } @@ -243,6 +250,74 @@ func TestWait(t *testing.T) { input.Wait() } +func TestOnCreateHandler_SubscribeMultiple_Succeeded(t *testing.T) { + inputContext := new(finput.Context) + onMessageHandler := func(client libmqtt.Client, message libmqtt.Message) {} + var clientSubscriptions map[string]byte + handler := createOnConnectHandler(logger, inputContext, onMessageHandler, clientSubscriptions) + + newBackoff = func(done <-chan struct{}, init, max time.Duration) backoff.Backoff { + return backoff.NewEqualJitterBackoff(inputContext.Done, time.Nanosecond, 2*time.Nanosecond) + } + + client := &mockedClient{ + tokens: []libmqtt.Token{&mockedToken{ + timeout: true, + }}, + } + handler(client) + + require.Equal(t, 1, client.subscribeMultipleCount) +} + +func TestOnCreateHandler_SubscribeMultiple_BackoffSucceeded(t *testing.T) { + inputContext := new(finput.Context) + onMessageHandler := func(client libmqtt.Client, message libmqtt.Message) {} + var clientSubscriptions map[string]byte + handler := createOnConnectHandler(logger, inputContext, onMessageHandler, clientSubscriptions) + + newBackoff = func(done <-chan struct{}, init, max time.Duration) backoff.Backoff { + return backoff.NewEqualJitterBackoff(inputContext.Done, time.Nanosecond, 2*time.Nanosecond) + } + + client := &mockedClient{ + tokens: []libmqtt.Token{&mockedToken{ + timeout: false, + }, &mockedToken{ + timeout: true, + }}, + } + handler(client) + + require.Equal(t, 2, client.subscribeMultipleCount) +} + +func TestOnCreateHandler_SubscribeMultiple_BackoffSignalDone(t *testing.T) { + inputContext := new(finput.Context) + onMessageHandler := func(client libmqtt.Client, message libmqtt.Message) {} + var clientSubscriptions map[string]byte + handler := createOnConnectHandler(logger, inputContext, onMessageHandler, clientSubscriptions) + + mockedBackoff := &mockedBackoff{ + waits: []bool{true, false}, + } + newBackoff = func(done <-chan struct{}, init, max time.Duration) backoff.Backoff { + return mockedBackoff + } + + client := &mockedClient{ + tokens: []libmqtt.Token{&mockedToken{ + timeout: false, + }, &mockedToken{ + timeout: false, + }}, + } + handler(client) + + require.Equal(t, 2, client.subscribeMultipleCount) + require.Equal(t, 1, mockedBackoff.resetCount) +} + func assertEventMatches(t *testing.T, expected mockedMessage, got beat.Event) { topic, err := got.GetValue("mqtt.topic") require.NoError(t, err) From 1e547b88c82cf0377ea56e02d1f84eeb81ab1fb3 Mon Sep 17 00:00:00 2001 From: Marcin Tojek Date: Wed, 5 Feb 2020 21:40:27 +0100 Subject: [PATCH 7/7] Adjust code after review --- filebeat/input/mqtt/config.go | 2 +- filebeat/input/mqtt/input.go | 8 ++++---- filebeat/input/mqtt/input_test.go | 14 +++++++------- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/filebeat/input/mqtt/config.go b/filebeat/input/mqtt/config.go index 03abd82e5d0..1f9c2817055 100644 --- a/filebeat/input/mqtt/config.go +++ b/filebeat/input/mqtt/config.go @@ -25,7 +25,7 @@ import ( type mqttInputConfig struct { Hosts []string `config:"hosts" validate:"required,min=1"` - Topics []string `config:"topics" validate:"nonzero,min=1"` + Topics []string `config:"topics" validate:"required,min=1"` QoS int `config:"qos" validate:"nonzero,min=0,max=2"` ClientID string `config:"clientID" validate:"nonzero"` diff --git a/filebeat/input/mqtt/input.go b/filebeat/input/mqtt/input.go index f1829eca4f0..a3e00338cb8 100644 --- a/filebeat/input/mqtt/input.go +++ b/filebeat/input/mqtt/input.go @@ -34,7 +34,7 @@ import ( ) const ( - disconnectTimeout = 3 * 1000 // 3000 ms = 3 sec + disconnectTimeout = 3 * time.Second subscribeTimeout = 35 * time.Second // in client: subscribeWaitTimeout = 30s subscribeRetryInterval = 1 * time.Second @@ -174,12 +174,12 @@ func (mi *mqttInput) Run() { // Stop method stops the input. func (mi *mqttInput) Stop() { mi.logger.Debug("Stop the input.") - mi.client.Disconnect(disconnectTimeout) - mi.Wait() + mi.client.Disconnect(uint(disconnectTimeout.Milliseconds())) } -// Wait method waits until event processing is finished. +// Wait method stops the input and waits until event processing is finished. func (mi *mqttInput) Wait() { + mi.Stop() mi.logger.Debug("Wait for the input to finish processing.") mi.inflightMessages.Wait() } diff --git a/filebeat/input/mqtt/input_test.go b/filebeat/input/mqtt/input_test.go index 172c7badd1f..9261e40b468 100644 --- a/filebeat/input/mqtt/input_test.go +++ b/filebeat/input/mqtt/input_test.go @@ -226,7 +226,7 @@ func TestRun_Twice(t *testing.T) { require.Equal(t, 1, client.connectCount) } -func TestStop(t *testing.T) { +func TestWait(t *testing.T) { inflightMessages := new(sync.WaitGroup) client := new(mockedClient) input := &mqttInput{ @@ -235,19 +235,19 @@ func TestStop(t *testing.T) { inflightMessages: inflightMessages, } - input.Stop() + input.Wait() require.Equal(t, 1, client.disconnectCount) } -func TestWait(t *testing.T) { - inflightMessages := new(sync.WaitGroup) +func TestStop(t *testing.T) { + client := new(mockedClient) input := &mqttInput{ - logger: logger, - inflightMessages: inflightMessages, + client: client, + logger: logger, } - input.Wait() + input.Stop() } func TestOnCreateHandler_SubscribeMultiple_Succeeded(t *testing.T) {