From 95edb1f78f41aecb28a5c504bc8b02b3fb32302b Mon Sep 17 00:00:00 2001 From: Osama Khalid Date: Mon, 22 Jan 2024 07:21:41 -0500 Subject: [PATCH] separate protocol and device --- device/device.go | 119 ++++++++++++++++++++++++++++++++++++++++++ exporter/exporter.go | 120 ++++++------------------------------------- logger/logger.go | 38 ++++++++++++++ main.go | 41 ++++++--------- model/device.go | 17 ++++++ protocol/aes.go | 22 ++++---- protocol/protocol.go | 12 ----- util/commit.go | 14 +++++ 8 files changed, 228 insertions(+), 155 deletions(-) create mode 100644 device/device.go create mode 100644 logger/logger.go create mode 100644 model/device.go create mode 100644 util/commit.go diff --git a/device/device.go b/device/device.go new file mode 100644 index 0000000..7eb0c4e --- /dev/null +++ b/device/device.go @@ -0,0 +1,119 @@ +package device + +import ( + "crypto/rsa" + "encoding/base64" + "fmt" + + "github.com/dehydr8/kasa-go/model" + "github.com/dehydr8/kasa-go/protocol" +) + +type Device struct { + config *model.DeviceConfig + transport protocol.Protocol +} + +type DeviceInfoResult struct { + DeviceId string `json:"device_id"` + DeviceOn bool `json:"device_on"` + Model string `json:"model"` + Type string `json:"type"` + Alias string `json:"nickname"` + Rssi int `json:"rssi"` + OnTime int `json:"on_time"` + SoftwareVersion string `json:"sw_ver"` + HardwareVersion string `json:"hw_ver"` + MAC string `json:"mac"` + Overheated bool `json:"overheated"` + PowerProtectionStatus string `json:"power_protection_status"` + OvercurrentStatus string `json:"overcurrent_status"` + SignalLevel int `json:"signal_level"` + SSID string `json:"ssid"` +} + +type EnergyUsageResult struct { + CurrentPower int `json:"current_power"` + MonthEnergy int `json:"month_energy"` + MonthRuntime int `json:"month_runtime"` + TodayEnergy int `json:"today_energy"` + TodayRuntime int `json:"today_runtime"` +} + +type EnergyUsageResponse struct { + protocol.AesProtoBaseResponse + Result EnergyUsageResult `json:"result"` +} + +type DeviceInfoResponse struct { + protocol.AesProtoBaseResponse + Result DeviceInfoResult `json:"result"` +} + +func NewDevice(key *rsa.PrivateKey, config *model.DeviceConfig) (*Device, error) { + transport, err := protocol.NewAesTransport(key, config) + if err != nil { + return nil, err + } + + return &Device{ + config: config, + transport: transport, + }, nil +} + +func (d *Device) Address() string { + return d.config.Address +} + +func (d *Device) GetEnergyUsage() (*EnergyUsageResult, error) { + var response EnergyUsageResponse + req := map[string]interface{}{ + "method": "get_energy_usage", + } + + err := d.transport.Send(&req, &response) + + if err != nil { + return nil, err + } + + if response.ErrorCode != 0 { + return nil, fmt.Errorf("error code: %d", response.ErrorCode) + } + + return &response.Result, nil +} + +func (d *Device) GetDeviceInfo() (*DeviceInfoResult, error) { + var response DeviceInfoResponse + req := map[string]interface{}{ + "method": "get_device_info", + } + + err := d.transport.Send(&req, &response) + + if err != nil { + return nil, err + } + + if response.ErrorCode != 0 { + return nil, fmt.Errorf("error code: %d", response.ErrorCode) + } + + // try decoding nickname + nickname, err := base64.StdEncoding.DecodeString(response.Result.Alias) + + if err == nil { + response.Result.Alias = string(nickname) + } + + // try decoding ssid + ssid, err := base64.StdEncoding.DecodeString(response.Result.SSID) + + if err == nil { + response.Result.SSID = string(ssid) + } + + return &response.Result, nil +} diff --git a/exporter/exporter.go b/exporter/exporter.go index a2a9e41..321e23c 100644 --- a/exporter/exporter.go +++ b/exporter/exporter.go @@ -1,67 +1,23 @@ package exporter import ( - "encoding/base64" - "fmt" - - "github.com/dehydr8/kasa-go/protocol" + "github.com/dehydr8/kasa-go/device" + "github.com/dehydr8/kasa-go/logger" "github.com/prometheus/client_golang/prometheus" - - "github.com/go-kit/log" - "github.com/go-kit/log/level" ) var _ prometheus.Collector = (*PlugExporter)(nil) -type DeviceInfoResult struct { - DeviceId string `json:"device_id"` - DeviceOn bool `json:"device_on"` - Model string `json:"model"` - Type string `json:"type"` - Alias string `json:"nickname"` - Rssi int `json:"rssi"` - OnTime int `json:"on_time"` - SoftwareVersion string `json:"sw_ver"` - HardwareVersion string `json:"hw_ver"` - MAC string `json:"mac"` - Overheated bool `json:"overheated"` - PowerProtectionStatus string `json:"power_protection_status"` - OvercurrentStatus string `json:"overcurrent_status"` - SignalLevel int `json:"signal_level"` - SSID string `json:"ssid"` -} - -type EnergyUsageResult struct { - CurrentPower int `json:"current_power"` - MonthEnergy int `json:"month_energy"` - MonthRuntime int `json:"month_runtime"` - TodayEnergy int `json:"today_energy"` - TodayRuntime int `json:"today_runtime"` -} - -type EnergyUsageResponse struct { - protocol.AesProtoBaseResponse - Result EnergyUsageResult `json:"result"` -} - -type DeviceInfoResponse struct { - protocol.AesProtoBaseResponse - Result DeviceInfoResult `json:"result"` -} - type PlugExporter struct { - target string - proto protocol.Protocol + device *device.Device metricsUp, metricsRssi, metricsPowerLoad *prometheus.Desc - - logger log.Logger } -func NewPlugExporter(host string, proto protocol.Protocol, logger log.Logger) (*PlugExporter, error) { - info, err := collectDeviceInfo(proto) +func NewPlugExporter(device *device.Device) (*PlugExporter, error) { + info, err := device.GetDeviceInfo() if err != nil { return nil, err @@ -75,8 +31,7 @@ func NewPlugExporter(host string, proto protocol.Protocol, logger log.Logger) (* } e := &PlugExporter{ - target: host, - proto: proto, + device: device, metricsPowerLoad: prometheus.NewDesc("kasa_power_load", "Current power in Milliwatts (mW)", nil, constLabels), @@ -88,78 +43,35 @@ func NewPlugExporter(host string, proto protocol.Protocol, logger log.Logger) (* metricsRssi: prometheus.NewDesc("kasa_rssi", "Wifi received signal strength indicator", nil, constLabels), - - logger: logger, } return e, nil } func (k *PlugExporter) Collect(ch chan<- prometheus.Metric) { + logger.Debug("msg", "collecting metrics", "target", k.device.Address()) - level.Debug(k.logger).Log("msg", "collecting metrics", "target", k.target) - - var energyUsageResponse EnergyUsageResponse - - if err := k.proto.Send(&map[string]interface{}{ - "method": "get_energy_usage", - }, &energyUsageResponse); err == nil && energyUsageResponse.ErrorCode == 0 { - ch <- prometheus.MustNewConstMetric(k.metricsPowerLoad, prometheus.GaugeValue, float64(energyUsageResponse.Result.CurrentPower)) + if energyUsage, err := k.device.GetEnergyUsage(); err == nil { + ch <- prometheus.MustNewConstMetric(k.metricsPowerLoad, prometheus.GaugeValue, float64(energyUsage.CurrentPower)) } else { - level.Debug(k.logger).Log("msg", "error getting energy usage", "err", err, "code", energyUsageResponse.ErrorCode) + logger.Warn("msg", "error getting energy usage", "err", err) } - var deviceInfoResponse DeviceInfoResponse - - if err := k.proto.Send(&map[string]interface{}{ - "method": "get_device_info", - }, &deviceInfoResponse); err == nil && deviceInfoResponse.ErrorCode == 0 { - - if deviceInfoResponse.Result.DeviceOn { + if deviceInfo, err := k.device.GetDeviceInfo(); err == nil { + if deviceInfo.DeviceOn { ch <- prometheus.MustNewConstMetric(k.metricsUp, prometheus.GaugeValue, 1) } else { ch <- prometheus.MustNewConstMetric(k.metricsUp, prometheus.GaugeValue, 0) } - ch <- prometheus.MustNewConstMetric(k.metricsRssi, prometheus.GaugeValue, float64(deviceInfoResponse.Result.Rssi)) + ch <- prometheus.MustNewConstMetric(k.metricsRssi, prometheus.GaugeValue, float64(deviceInfo.Rssi)) } else { - level.Debug(k.logger).Log("msg", "error getting device info", "err", err, "code", deviceInfoResponse.ErrorCode) + logger.Warn("msg", "error getting device info", "err", err) } } func (k *PlugExporter) Describe(ch chan<- *prometheus.Desc) { ch <- k.metricsPowerLoad -} - -func collectDeviceInfo(proto protocol.Protocol) (*DeviceInfoResult, error) { - var response DeviceInfoResponse - req := map[string]interface{}{ - "method": "get_device_info", - } - - err := proto.Send(&req, &response) - - if err != nil { - return nil, err - } - - if response.ErrorCode != 0 { - return nil, fmt.Errorf("error code: %d", response.ErrorCode) - } - - // try decoding nickname - nickname, err := base64.StdEncoding.DecodeString(response.Result.Alias) - - if err == nil { - response.Result.Alias = string(nickname) - } - - // try decoding ssid - ssid, err := base64.StdEncoding.DecodeString(response.Result.SSID) - - if err == nil { - response.Result.SSID = string(ssid) - } - - return &response.Result, nil + ch <- k.metricsUp + ch <- k.metricsRssi } diff --git a/logger/logger.go b/logger/logger.go new file mode 100644 index 0000000..c350738 --- /dev/null +++ b/logger/logger.go @@ -0,0 +1,38 @@ +package logger + +import ( + "os" + "time" + + "github.com/go-kit/log" + "github.com/go-kit/log/level" +) + +var ( + GLOBAL_LOGGER log.Logger +) + +func SetupLogging(lvl string) { + GLOBAL_LOGGER = log.NewLogfmtLogger(os.Stderr) + GLOBAL_LOGGER = level.NewFilter(GLOBAL_LOGGER, level.Allow(level.ParseDefault(lvl, level.InfoValue()))) + GLOBAL_LOGGER = log.With(GLOBAL_LOGGER, "ts", log.TimestampFormat( + func() time.Time { return time.Now() }, + time.DateTime, + ), "caller", log.Caller(4)) +} + +func Info(keyvals ...interface{}) error { + return level.Info(GLOBAL_LOGGER).Log(keyvals...) +} + +func Debug(keyvals ...interface{}) error { + return level.Debug(GLOBAL_LOGGER).Log(keyvals...) +} + +func Warn(keyvals ...interface{}) error { + return level.Warn(GLOBAL_LOGGER).Log(keyvals...) +} + +func Error(keyvals ...interface{}) error { + return level.Error(GLOBAL_LOGGER).Log(keyvals...) +} diff --git a/main.go b/main.go index ad93b8f..38e366d 100644 --- a/main.go +++ b/main.go @@ -6,13 +6,11 @@ import ( "flag" "fmt" "net/http" - "os" - "time" + "github.com/dehydr8/kasa-go/device" "github.com/dehydr8/kasa-go/exporter" - "github.com/dehydr8/kasa-go/protocol" - "github.com/go-kit/log" - "github.com/go-kit/log/level" + "github.com/dehydr8/kasa-go/logger" + "github.com/dehydr8/kasa-go/model" lru "github.com/hashicorp/golang-lru/v2" @@ -22,13 +20,11 @@ import ( type MetricsServer struct { key *rsa.PrivateKey - credentials *protocol.Credentials + credentials *model.Credentials registryCache *lru.Cache[string, *prometheus.Registry] - - logger log.Logger } -func NewMetricsServer(key *rsa.PrivateKey, credentials *protocol.Credentials, logger log.Logger) *MetricsServer { +func NewMetricsServer(key *rsa.PrivateKey, credentials *model.Credentials) *MetricsServer { cache, err := lru.New[string, *prometheus.Registry](10) if err != nil { @@ -39,7 +35,6 @@ func NewMetricsServer(key *rsa.PrivateKey, credentials *protocol.Credentials, lo key: key, credentials: credentials, registryCache: cache, - logger: logger, } } @@ -72,20 +67,14 @@ func main() { panic("username and password must be specified") } - var logger log.Logger - logger = log.NewLogfmtLogger(os.Stderr) - logger = level.NewFilter(logger, level.Allow(level.ParseDefault(*lvl, level.InfoValue()))) - logger = log.With(logger, "ts", log.TimestampFormat( - func() time.Time { return time.Now() }, - time.DateTime, - ), "caller", log.DefaultCaller) + logger.SetupLogging(*lvl) - credentials := protocol.Credentials{ + credentials := model.Credentials{ Username: *username, Password: *password, } - level.Debug(logger).Log("msg", "Generating RSA key") + logger.Debug("msg", "Generating RSA key") key, err := rsa.GenerateKey(rand.Reader, 1024) @@ -93,11 +82,11 @@ func main() { panic(err) } - server := NewMetricsServer(key, &credentials, logger) + server := NewMetricsServer(key, &credentials) http.HandleFunc("/scrape", server.ScrapeHandler) - level.Info(logger).Log("msg", "Starting metrics server", "address", fmt.Sprintf("%s:%d", *address, *port)) + logger.Info("msg", "Starting metrics server", "address", fmt.Sprintf("%s:%d", *address, *port)) http.ListenAndServe(fmt.Sprintf("%s:%d", *address, *port), nil) } @@ -111,18 +100,18 @@ func (s *MetricsServer) ScrapeHandler(w http.ResponseWriter, r *http.Request) { registry, err := s.getOrCreate(target, func() (*prometheus.Registry, error) { - level.Debug(s.logger).Log("msg", "Creating new registry for target", "target", target) + logger.Debug("msg", "Creating new registry for target", "target", target) - transport, err := protocol.NewAesTransport(s.key, &protocol.DeviceConfig{ + device, err := device.NewDevice(s.key, &model.DeviceConfig{ Address: target, Credentials: s.credentials, - }, s.logger) + }) if err != nil { return nil, err } - exporter, err := exporter.NewPlugExporter(target, transport, s.logger) + exporter, err := exporter.NewPlugExporter(device) if err != nil { return nil, err @@ -135,7 +124,7 @@ func (s *MetricsServer) ScrapeHandler(w http.ResponseWriter, r *http.Request) { }) if err != nil { - level.Error(s.logger).Log("msg", "Error creating registry", "err", err) + logger.Error("msg", "Error creating registry", "err", err) http.Error(w, err.Error(), 500) return } diff --git a/model/device.go b/model/device.go new file mode 100644 index 0000000..c18936c --- /dev/null +++ b/model/device.go @@ -0,0 +1,17 @@ +package model + +import "crypto/rsa" + +type Credentials struct { + Username string + Password string +} + +type DeviceConfig struct { + Address string + + Credentials *Credentials + CredentialsHash *string + + Key *rsa.PrivateKey +} diff --git a/protocol/aes.go b/protocol/aes.go index 24baee9..5fc8e0b 100644 --- a/protocol/aes.go +++ b/protocol/aes.go @@ -14,14 +14,14 @@ import ( "net/http" "time" - "github.com/go-kit/log" - "github.com/go-kit/log/level" + "github.com/dehydr8/kasa-go/logger" + "github.com/dehydr8/kasa-go/model" ) var _ Protocol = (*AesTransport)(nil) type AesTransport struct { - config *DeviceConfig + config *model.DeviceConfig loginVersion int key *rsa.PrivateKey @@ -35,8 +35,6 @@ type AesTransport struct { httpClient *http.Client commonHeaders map[string]string - - logger log.Logger } type AesProtoBaseRequest struct { @@ -98,7 +96,7 @@ type AesPassthroughResponse struct { Result AesPassthroughResponseResult `json:"result"` } -func NewAesTransport(key *rsa.PrivateKey, config *DeviceConfig, logger log.Logger) (*AesTransport, error) { +func NewAesTransport(key *rsa.PrivateKey, config *model.DeviceConfig) (*AesTransport, error) { return &AesTransport{ key: key, config: config, @@ -114,8 +112,6 @@ func NewAesTransport(key *rsa.PrivateKey, config *DeviceConfig, logger log.Logge }, cookies: make(map[string]string), sessionExpiry: time.Now(), - - logger: logger, }, nil } @@ -157,7 +153,7 @@ func (t *AesTransport) Close() error { func (t *AesTransport) handshake() error { - level.Debug(t.logger).Log("msg", "performing handshake", "target", t.config.Address) + logger.Debug("msg", "performing handshake", "target", t.config.Address) encoded, err := x509.MarshalPKIXPublicKey(&t.key.PublicKey) @@ -248,7 +244,7 @@ func (t *AesTransport) handshakeExpired() bool { func (t *AesTransport) login() error { - level.Debug(t.logger).Log("msg", "performing login", "target", t.config.Address) + logger.Debug("msg", "performing login", "target", t.config.Address) req := &AesLoginRequest{ AesProtoBaseRequest: AesProtoBaseRequest{ @@ -288,7 +284,7 @@ func (t *AesTransport) securePassthrough(request interface{}, response interface return err } - level.Debug(t.logger).Log("msg", "sending request", "request", string(marshalledRequest)) + logger.Debug("msg", "sending request", "request", string(marshalledRequest)) encrypted, err := t.session.Encrypt(marshalledRequest) @@ -348,7 +344,7 @@ func (t *AesTransport) securePassthrough(request interface{}, response interface return err } - level.Debug(t.logger).Log("msg", "received encrypted response", "encrypted", res.Result.Response) + logger.Debug("msg", "received encrypted response", "encrypted", res.Result.Response) if res.ErrorCode != 0 { return fmt.Errorf("passthrough failed with error code %d", res.ErrorCode) @@ -366,7 +362,7 @@ func (t *AesTransport) securePassthrough(request interface{}, response interface return err } - level.Debug(t.logger).Log("msg", "decrypted response", "response", string(decrypted)) + logger.Debug("msg", "decrypted response", "response", string(decrypted)) return json.Unmarshal(decrypted, response) } diff --git a/protocol/protocol.go b/protocol/protocol.go index b2592d2..6bf4795 100644 --- a/protocol/protocol.go +++ b/protocol/protocol.go @@ -1,17 +1,5 @@ package protocol -type Credentials struct { - Username string - Password string -} - -type DeviceConfig struct { - Address string - - Credentials *Credentials - CredentialsHash *string -} - type Protocol interface { Send(request, response interface{}) error Close() error diff --git a/util/commit.go b/util/commit.go new file mode 100644 index 0000000..4b2dbd9 --- /dev/null +++ b/util/commit.go @@ -0,0 +1,14 @@ +package util + +import "runtime/debug" + +var Commit = func() string { + if info, ok := debug.ReadBuildInfo(); ok { + for _, setting := range info.Settings { + if setting.Key == "vcs.revision" { + return setting.Value + } + } + } + return "unknown" +}()