diff --git a/internal/app/agent/readiness_checker/readiness_checker.go b/internal/app/agent/readiness_checker/readiness_checker.go index ac66e791..b6976f9b 100644 --- a/internal/app/agent/readiness_checker/readiness_checker.go +++ b/internal/app/agent/readiness_checker/readiness_checker.go @@ -1,5 +1,7 @@ +//nolint:staticcheck package readiness_checker +//TODO: io/ioutil is deprecated, replace to fs.FS import ( "context" "errors" @@ -154,7 +156,8 @@ func serializeReport(r *protoagent.AgentReadinessReport) ([]byte, error) { } func openLogFile(path string) (*os.File, error) { - f, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0o644) + // #nosec G304 + f, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0o600) if err != nil { return nil, fmt.Errorf("failed to open the file \"%s\": %w", path, err) } @@ -166,6 +169,7 @@ func ReadReport(ctx context.Context, logDir string) (*protoagent.AgentReadinessR if err != nil { return nil, err } + // #nosec G304 fp, err := os.Open(logFilePath) if err != nil { if errors.Is(err, os.ErrNotExist) { diff --git a/internal/app/agent/upgrader/starter/files/starter.go b/internal/app/agent/upgrader/starter/files/starter.go index b9fd4db8..bb1bf7df 100644 --- a/internal/app/agent/upgrader/starter/files/starter.go +++ b/internal/app/agent/upgrader/starter/files/starter.go @@ -1,5 +1,7 @@ +//nolint:staticcheck package files +//TODO: io/ioutil is deprecated, replace to fs.FS and delete "nolint:staticcheck" import ( "fmt" "io/ioutil" diff --git a/internal/app/agent/upgrader/upgrader.go b/internal/app/agent/upgrader/upgrader.go index a110e35f..acd0a748 100644 --- a/internal/app/agent/upgrader/upgrader.go +++ b/internal/app/agent/upgrader/upgrader.go @@ -189,6 +189,7 @@ func (u *upgrader) replaceAgentExecutableWithUpgrader(_ context.Context) error { func (u *upgrader) startAgentProcess(_ context.Context) (int, error) { args := getAgentProcessFlags() + // #nosec G204 cmd := exec.Command(u.conf.AgentExecutablePath, args...) if err := cmd.Start(); err != nil { return -1, fmt.Errorf("failed to start the new agent process: %w", err) diff --git a/internal/app/agent/utils/lock_file.go b/internal/app/agent/utils/lock_file.go index c1f69e71..997d43f9 100644 --- a/internal/app/agent/utils/lock_file.go +++ b/internal/app/agent/utils/lock_file.go @@ -1,5 +1,7 @@ +//nolint:staticcheck package utils +//TODO: io/ioutil is deprecated, replace to fs.FS and delete "nolint:staticcheck" import ( "context" "errors" @@ -13,7 +15,8 @@ import ( ) func CreateLockFile(name string) error { - f, err := os.OpenFile(name, os.O_WRONLY|os.O_CREATE|os.O_EXCL, 0o666) + // #nosec G304 + f, err := os.OpenFile(name, os.O_WRONLY|os.O_CREATE|os.O_EXCL, 0o600) if err != nil { return fmt.Errorf("failed to create a lock file: %w", err) } @@ -52,6 +55,7 @@ func WatchLockFile(ctx context.Context, name string) error { } func getPID(name string) (int, error) { + // #nosec G304 pidBytes, err := ioutil.ReadFile(name) if err != nil { if errors.Is(err, os.ErrNotExist) { diff --git a/internal/app/agent/utils/utils.go b/internal/app/agent/utils/utils.go index 2ada933d..39c9a758 100644 --- a/internal/app/agent/utils/utils.go +++ b/internal/app/agent/utils/utils.go @@ -154,6 +154,7 @@ func checkFileExists(path string) error { } func copyContents(src string, dst string, mode os.FileMode) error { + // #nosec G304 srcF, err := os.Open(src) if err != nil { return fmt.Errorf("failed to open the source file %s: %w", src, err) @@ -168,6 +169,7 @@ func copyContents(src string, dst string, mode os.FileMode) error { return fmt.Errorf("failed to get the source file %s mode: %w", src, err) } } + // #nosec G304 dstF, err := os.OpenFile(dst, os.O_RDWR|os.O_CREATE|os.O_TRUNC, fMode) if err != nil { return fmt.Errorf("failed to open the destination file %s: %w", dst, err) @@ -217,7 +219,7 @@ func checkDstDir(dst string) (func(e error) error, error) { if err != nil { return nil, fmt.Errorf("failed to get the list of dirs to create: %w", err) } - if err = os.MkdirAll(dstDir, 0o755); err != nil { + if err = os.MkdirAll(dstDir, 0o750); err != nil { return nil, fmt.Errorf("failed to create the destination file directory %s: %w", dstDir, err) } return func(e error) error { diff --git a/internal/app/api/server/private/options.go b/internal/app/api/server/private/options.go index daa17875..0f3c6806 100644 --- a/internal/app/api/server/private/options.go +++ b/internal/app/api/server/private/options.go @@ -127,8 +127,10 @@ func getBaseQueryForItem(db *gorm.DB, option string) *gorm.SqlExpr { SubQuery() } +//nolint:typecheck func validOptions(c *gin.Context, value interface{}) bool { var vlist []models.IValid + //TODO: Check pointers in switch switch tvalue := value.(type) { case *[]models.OptionsActions: for _, v := range *tvalue { diff --git a/internal/app/api/server/proto/vm/abh_calculator.go b/internal/app/api/server/proto/vm/abh_calculator.go index 46073d67..a747fdc9 100644 --- a/internal/app/api/server/proto/vm/abh_calculator.go +++ b/internal/app/api/server/proto/vm/abh_calculator.go @@ -9,6 +9,8 @@ import ( "sync" "soldr/internal/hardening/luavm/vm" + + "github.com/sirupsen/logrus" ) type abhCalculator struct { @@ -33,11 +35,17 @@ func calculateABH() (vm.ABH, error) { if err != nil { return nil, fmt.Errorf("failed to get the current executable path: %w", err) } + // #nosec G304 f, err := os.Open(execFile) if err != nil { return nil, fmt.Errorf("failed to open the executable file %s: %w", execFile, err) } - defer f.Close() + defer func(f *os.File) { + err = f.Close() + if err != nil { + logrus.Errorf("failed to close file: %s", err) + } + }(f) h := sha256.New() if _, err := io.Copy(h, f); err != nil { return nil, fmt.Errorf("failed to get the hash of the executable file: %w", err) diff --git a/internal/app/api/server/proto/vm/vm.go b/internal/app/api/server/proto/vm/vm.go index f50b61be..4084b946 100644 --- a/internal/app/api/server/proto/vm/vm.go +++ b/internal/app/api/server/proto/vm/vm.go @@ -48,7 +48,11 @@ func NewVM( func (v *VM) ProcessConnectionChallengeRequest(ctx context.Context, req []byte) ([]byte, error) { var connChallengeReq protoagent.ConnectionChallengeRequest - if err := protoagent.UnpackProtoMessage(&connChallengeReq, req, protoagent.Message_CONNECTION_CHALLENGE_REQUEST); err != nil { + if err := protoagent.UnpackProtoMessage( + &connChallengeReq, + req, + protoagent.Message_CONNECTION_CHALLENGE_REQUEST, + ); err != nil { return nil, fmt.Errorf("failed to unpack the connection challenge request: %w", err) } agentID, err := getAgentID(ctx) diff --git a/internal/app/api/utils/meter/gorm.go b/internal/app/api/utils/meter/gorm.go index a570faab..3bfa0998 100644 --- a/internal/app/api/utils/meter/gorm.go +++ b/internal/app/api/utils/meter/gorm.go @@ -93,31 +93,31 @@ func ApplyGorm(db *gorm.DB) { recordStartTime(scope) }) db.Callback().Create().After("gorm:create").Register("metric:create-after", func(scope *gorm.Scope) { - labels := append(labels, makeQueryLabel(scope)) - recordHistogram(scope, createHist, labels...) + l := append(labels, makeQueryLabel(scope)) + recordHistogram(scope, createHist, l...) }) db.Callback().Update().Before("gorm:update").Register("metric:update-before", func(scope *gorm.Scope) { recordStartTime(scope) }) db.Callback().Update().After("gorm:update").Register("metric:update-after", func(scope *gorm.Scope) { - labels := append(labels, makeQueryLabel(scope)) - recordHistogram(scope, updateHist, labels...) + l := append(labels, makeQueryLabel(scope)) + recordHistogram(scope, updateHist, l...) }) db.Callback().Query().Before("gorm:query").Register("metric:query-before", func(scope *gorm.Scope) { recordStartTime(scope) }) db.Callback().Query().After("gorm:query").Register("metric:query-after", func(scope *gorm.Scope) { - labels := append(labels, makeQueryLabel(scope)) - recordHistogram(scope, queryHist, labels...) + l := append(labels, makeQueryLabel(scope)) + recordHistogram(scope, queryHist, l...) }) db.Callback().Delete().Before("gorm:delete").Register("metric:delete-before", func(scope *gorm.Scope) { recordStartTime(scope) }) db.Callback().Delete().After("gorm:delete").Register("metric:delete-after", func(scope *gorm.Scope) { - labels := append(labels, makeQueryLabel(scope)) - recordHistogram(scope, deleteHist, labels...) + l := append(labels, makeQueryLabel(scope)) + recordHistogram(scope, deleteHist, l...) }) } diff --git a/internal/app/app.go b/internal/app/app.go index 9825583e..0de4b58b 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -2,6 +2,7 @@ package app import ( "github.com/oklog/run" + "github.com/sirupsen/logrus" ) type AppGroup struct { @@ -13,7 +14,10 @@ func NewAppGroup() *AppGroup { } func (a *AppGroup) Run() { - a.runGroup.Run() + err := a.runGroup.Run() + if err != nil { + logrus.Errorf("failed to run appGroup: %s", err) + } } // Add (function) to the application group. Each actor must be pre-emptable by an diff --git a/internal/app/server/certs/provider/static/static.go b/internal/app/server/certs/provider/static/static.go index 55156c92..aac4146d 100644 --- a/internal/app/server/certs/provider/static/static.go +++ b/internal/app/server/certs/provider/static/static.go @@ -1,5 +1,7 @@ +//nolint:staticcheck package static +//TODO: io/ioutil is deprecated, replace to fs.FS and delete "nolint:staticcheck" import ( "crypto/ed25519" "crypto/rand" @@ -80,8 +82,8 @@ func getVXCAs(certsDir string) (map[string]*x509Cert, error) { if filepath.Ext(fi.Name()) != ".cert" { continue } - fPath := filepath.Join(vxcasPath, fi.Name()) - data, err := os.ReadFile(fPath) + // #nosec G304 + data, err := os.ReadFile(filepath.Join(vxcasPath, fi.Name())) if err != nil { return nil, err } @@ -154,8 +156,8 @@ func getCertKeyFiles(filesInfo []fs.FileInfo, baseDir string) (certKeyFiles, err if ext != certExt && ext != keyExt { continue } - fPath := filepath.Join(baseDir, fi.Name()) - data, err := os.ReadFile(fPath) + // #nosec G304 + data, err := os.ReadFile(filepath.Join(baseDir, fi.Name())) if err != nil { return nil, err } diff --git a/internal/app/server/config/config.go b/internal/app/server/config/config.go index e102465d..e10dadbb 100644 --- a/internal/app/server/config/config.go +++ b/internal/app/server/config/config.go @@ -1,5 +1,7 @@ +//nolint:staticcheck package config +//TODO: io/ioutil is deprecated, replace to fs.FS and delete "nolint:staticcheck" import ( "encoding/json" "errors" @@ -133,6 +135,7 @@ var defaultConfig = &Config{ func parseConfigFile(path string) (*Config, error) { var cfg Config + // #nosec G304 cfgData, err := ioutil.ReadFile(path) if err != nil { return nil, err diff --git a/internal/controller/loader.go b/internal/controller/loader.go index c224264a..339276f4 100644 --- a/internal/controller/loader.go +++ b/internal/controller/loader.go @@ -24,9 +24,62 @@ const ( // List of SQL queries string const ( - sLoadModulesSQL string = "SELECT m.`id`, IFNULL(g.`hash`, '') AS `group_id`, IFNULL(p.`hash`, '') AS `policy_id`, m.`info`, m.`last_update`, m.`last_module_update`, m.`state`, m.`template` FROM `modules` m LEFT JOIN (SELECT * FROM `policies` UNION SELECT 0, '', '{}', NOW(), NOW(), NULL) p ON m.`policy_id` = p.`id` AND p.deleted_at IS NULL LEFT JOIN `groups_to_policies` gp ON gp.`policy_id` = p.`id` LEFT JOIN (SELECT * FROM `groups` UNION SELECT 0, '', '{}', NOW(), NOW(), NULL) g ON gp.`group_id` = g.`id` AND g.deleted_at IS NULL WHERE m.`status` = 'joined' AND NOT (ISNULL(g.`hash`) AND p.`hash` NOT LIKE '') AND m.deleted_at IS NULL" - sGetModuleFieldSQL string = "SELECT `%s` FROM `modules` WHERE `id` = ? LIMIT 1" - sSetModuleFieldSQL string = "UPDATE `modules` SET `%s` = ? WHERE `id` = ?" + sLoadModulesSQL string = ` + SELECT + m.id + , IFNULL(g.hash, '') AS group_id + , IFNULL(p.hash, '') AS policy_id + , m.info + , m.last_update + , m.last_module_update + , m.state + , m.template + FROM modules m + LEFT JOIN ( + SELECT * + FROM policies + UNION + SELECT + 0 + , '' + , '{}' + , NOW() + , NOW() + , NULL + ) p ON + m.policy_id = p.id + AND p.deleted_at IS NULL + LEFT JOIN groups_to_policies gp + ON gp.policy_id = p.id + LEFT JOIN ( + SELECT * + FROM groups + UNION + SELECT + 0 + , '' + , '{}' + , NOW() + , NOW() + , NULL + ) g ON + gp.group_id = g.id + AND g.deleted_at IS NULL + WHERE m.status = 'joined' + AND NOT (ISNULL(g.hash) + AND p.hash NOT LIKE '') + AND m.deleted_at IS NULL` + + sGetModuleFieldSQL string = ` + SELECT %s + FROM modules + WHERE id = ? + LIMIT 1` + + sSetModuleFieldSQL string = ` + UPDATE modules + SET %s = ? + WHERE id = ?` ) // tFilesLoaderType is type for loading module @@ -115,7 +168,9 @@ func (cl *configLoaderDB) load() ([]*loader.ModuleConfig, error) { return nil, fmt.Errorf("failed to parse the module config: %w", err) } } else { - return nil, fmt.Errorf("failed to load the module config: returned rows do not contain the field '%s'", moduleInfoField) + return nil, fmt.Errorf( + "failed to load the module config: returned rows do not contain the field '%s'", moduleInfoField, + ) } if groupID, ok := m["group_id"]; ok { mc.GroupID = groupID diff --git a/internal/errors/errors.go b/internal/errors/errors.go index 90eb3360..00e59c1f 100644 --- a/internal/errors/errors.go +++ b/internal/errors/errors.go @@ -3,7 +3,8 @@ package errors import "errors" var ( - ErrConnectionInitializationRequired = errors.New("failed to connect to the server: a connection initialization required") - ErrUnexpectedUnpackType = errors.New("unexpected agent message type") - ErrRecordNotFound = errors.New("record not found") + ErrConnectionInitializationRequired = errors.New("failed to connect to the server: " + + "a connection initialization required") + ErrUnexpectedUnpackType = errors.New("unexpected agent message type") + ErrRecordNotFound = errors.New("record not found") ) diff --git a/internal/hardening/luavm/store/simple/simple.go b/internal/hardening/luavm/store/simple/simple.go index f0bf274f..38a0c8f5 100644 --- a/internal/hardening/luavm/store/simple/simple.go +++ b/internal/hardening/luavm/store/simple/simple.go @@ -1,5 +1,7 @@ +//nolint:staticcheck package simple +//TODO: io/ioutil is deprecated, replace to fs.FS and delete "nolint:staticcheck" import ( "encoding/json" "errors" diff --git a/internal/hardening/luavm/vm/abh_calculator.go b/internal/hardening/luavm/vm/abh_calculator.go index d9cefcee..7cb6a2b1 100644 --- a/internal/hardening/luavm/vm/abh_calculator.go +++ b/internal/hardening/luavm/vm/abh_calculator.go @@ -6,6 +6,8 @@ import ( "io" "os" "path/filepath" + + "github.com/sirupsen/logrus" ) type ABH []byte @@ -25,11 +27,17 @@ func (c *abhCalculator) GetABH() (ABH, error) { if err != nil { return nil, fmt.Errorf("failed to get the current executable path: %w", err) } + // #nosec G304 f, err := os.Open(execFile) if err != nil { return nil, fmt.Errorf("failed to open the executable file %s: %w", execFile, err) } - defer f.Close() + defer func(f *os.File) { + err = f.Close() + if err != nil { + logrus.Errorf("failed close file: %s", err) + } + }(f) h := sha256.New() if _, err := io.Copy(h, f); err != nil { return nil, fmt.Errorf("failed to get the hash of the executable file: %w", err) diff --git a/internal/hardening/luavm/vm/securestore.go b/internal/hardening/luavm/vm/securestore.go index 21207c9f..e51b10f2 100644 --- a/internal/hardening/luavm/vm/securestore.go +++ b/internal/hardening/luavm/vm/securestore.go @@ -1,5 +1,7 @@ +//nolint:gosec package vm +//TODO: replace mean cryptographic primitive and delete "nolint:gosec" import ( "crypto/rc4" "crypto/sha256" diff --git a/internal/hardening/luavm/vm/tls_configurer.go b/internal/hardening/luavm/vm/tls_configurer.go index 2a81478a..574f386f 100644 --- a/internal/hardening/luavm/vm/tls_configurer.go +++ b/internal/hardening/luavm/vm/tls_configurer.go @@ -26,7 +26,11 @@ type SimpleTLSConfigurer struct { scaStore scaStore } -func NewSimpleTLSConfigurer(certsProvider certs.CertProvider, ltacGetter LTACGetter, scaStore scaStore) *SimpleTLSConfigurer { +func NewSimpleTLSConfigurer( + certsProvider certs.CertProvider, + ltacGetter LTACGetter, + scaStore scaStore, +) *SimpleTLSConfigurer { return &SimpleTLSConfigurer{ certsProvider: certsProvider, ltacGetter: ltacGetter, @@ -43,6 +47,9 @@ func (c *SimpleTLSConfigurer) GetTLSConfigForInitConnection() (*tls.Config, erro if err != nil { return nil, fmt.Errorf("failed to get the VXCA pool: %w", err) } + //TODO: TLS version is too low + + // #nosec G402 tlsConfig := &tls.Config{ Certificates: []tls.Certificate{*iac}, RootCAs: vxcaPool, @@ -61,6 +68,9 @@ func (c *SimpleTLSConfigurer) GetTLSConfigForConnection() (*tls.Config, error) { if err != nil { return nil, fmt.Errorf("failed to get the LTAC certificate: %w", err) } + //TODO: TLS version is too low + + // #nosec G402 return &tls.Config{ Certificates: []tls.Certificate{*ltacCert}, RootCAs: vxcaPool, @@ -100,8 +110,14 @@ func (c *SimpleTLSConfigurer) getLTACCertificate() (*tls.Certificate, error) { return &cert, nil } -func (c *SimpleTLSConfigurer) initConnectionVerifyPeerCertificate(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { - c.scaStore.PopSCA() +func (c *SimpleTLSConfigurer) initConnectionVerifyPeerCertificate( + rawCerts [][]byte, + verifiedChains [][]*x509.Certificate, +) error { + _, err := c.scaStore.PopSCA() + if err != nil { + return fmt.Errorf("popSCA is failed: %w", err) + } if len(rawCerts) != 2 { return fmt.Errorf("expected to get two raw certs, actually got %d", len(rawCerts)) } @@ -111,7 +127,7 @@ func (c *SimpleTLSConfigurer) initConnectionVerifyPeerCertificate(rawCerts [][]b if len(verifiedChains[0]) != 3 { return fmt.Errorf("expected to get three certificates in the verified chain, actually got %d", len(verifiedChains[0])) } - if err := c.scaStore.PushSCA(rawCerts[1]); err != nil { + if err = c.scaStore.PushSCA(rawCerts[1]); err != nil { return fmt.Errorf("failed to save the passed SCA certificate: %w", err) } return nil diff --git a/internal/hardening/luavm/vm/vm.go b/internal/hardening/luavm/vm/vm.go index 33537f2f..60153161 100644 --- a/internal/hardening/luavm/vm/vm.go +++ b/internal/hardening/luavm/vm/vm.go @@ -21,6 +21,8 @@ import ( "soldr/internal/protoagent" utilsErrors "soldr/internal/utils/errors" vxprotoTunnel "soldr/internal/vxproto/tunnel" + + "github.com/sirupsen/logrus" ) const certServerName = "example" @@ -164,10 +166,16 @@ func (v *vm) popSCA() ([]byte, error) { return sca, nil } -func (v *vm) PrepareInitConnectionRequest(info *InitConnectionAgentInfo, agentInfo *protoagent.Information) (msg []byte, err error) { +func (v *vm) PrepareInitConnectionRequest( + info *InitConnectionAgentInfo, + agentInfo *protoagent.Information, +) (msg []byte, err error) { defer func() { if err != nil { - v.popSCA() + _, err = v.popSCA() + if err != nil { + logrus.Errorf("popSCA is failed: %s", err) + } } }() csr, ltacKey, err := v.generateLTACRequest() @@ -191,7 +199,10 @@ func (v *vm) PrepareInitConnectionRequest(info *InitConnectionAgentInfo, agentIn func (v *vm) ProcessInitConnectionResponse(respData []byte) (err error) { defer func() { if err != nil { - v.popSCA() + _, err = v.popSCA() + if err != nil { + logrus.Errorf("popSCA is failed: %s", err) + } } }() var initConnResp protoagent.InitConnectionResponse @@ -271,7 +282,11 @@ func (vm *vm) ProcessConnectionRequest(req []byte, packEncryptor vxprotoTunnel.P } if err = vm.checkSBH(connReq.Sbh); err != nil { if resetErr := vm.Reset(); resetErr != nil { - return nil, fmt.Errorf("failed to reset the VM store (%v), while processing the SBH verification error: %w", resetErr, err) + return nil, fmt.Errorf( + "failed to reset the VM store (%v), while processing the SBH verification error: %w", + resetErr, + err, + ) } return nil, err } @@ -348,7 +363,12 @@ func checkKeyPairValidity(pub interface{}, priv []byte) error { return nil } -func prepareInitConnReqProtoMsg(csr []byte, abh []byte, info *InitConnectionAgentInfo, agentInfo *protoagent.Information) ([]byte, error) { +func prepareInitConnReqProtoMsg( + csr []byte, + abh []byte, + info *InitConnectionAgentInfo, + agentInfo *protoagent.Information, +) ([]byte, error) { os := runtime.GOOS arch := runtime.GOARCH req := &protoagent.InitConnectionRequest{ @@ -378,6 +398,9 @@ func (v *vm) GetTLSConfigForInitConnection() (*tls.Config, error) { if err != nil { return nil, fmt.Errorf("failed to get the VXCA pool: %w", err) } + //TODO: TLS version is too low + + // #nosec G402 tlsConfig := &tls.Config{ Certificates: []tls.Certificate{*iac}, RootCAs: vxcaPool, @@ -389,7 +412,10 @@ func (v *vm) GetTLSConfigForInitConnection() (*tls.Config, error) { func (v *vm) initConnectionVerifyPeerCertificate(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { // TODO(SSH): change - v.popSCA() + _, err := v.popSCA() + if err != nil { + return fmt.Errorf("failed to pop SCA certificate: %w", err) + } if len(rawCerts) != 2 { return fmt.Errorf("expected to get two raw certs, actually got %d", len(rawCerts)) } @@ -399,7 +425,7 @@ func (v *vm) initConnectionVerifyPeerCertificate(rawCerts [][]byte, verifiedChai if len(verifiedChains[0]) != 3 { return fmt.Errorf("expected to get three certificates in the verified chain, actually got %d", len(verifiedChains[0])) } - if err := v.pushSCA(rawCerts[1]); err != nil { + if err = v.pushSCA(rawCerts[1]); err != nil { return fmt.Errorf("failed to save the passed SCA certificate: %w", err) } return nil @@ -414,6 +440,9 @@ func (v *vm) GetTLSConfigForConnection() (*tls.Config, error) { if err != nil { return nil, fmt.Errorf("failed to get the LTAC certificate: %w", err) } + //TODO: TLS version is too low + + // #nosec G402 return &tls.Config{ Certificates: []tls.Certificate{*ltacCert}, RootCAs: vxcaPool, @@ -423,7 +452,11 @@ func (v *vm) GetTLSConfigForConnection() (*tls.Config, error) { func (v *vm) ProcessConnectionChallengeRequest(ctx context.Context, req []byte) ([]byte, error) { var connChallengeReq protoagent.ConnectionChallengeRequest - if err := protoagent.UnpackProtoMessage(&connChallengeReq, req, protoagent.Message_CONNECTION_CHALLENGE_REQUEST); err != nil { + if err := protoagent.UnpackProtoMessage( + &connChallengeReq, + req, + protoagent.Message_CONNECTION_CHALLENGE_REQUEST, + ); err != nil { return nil, fmt.Errorf("failed to unpack the connection challenge request: %w", err) } ct, err := v.prepareChallengeResponseCT(&connChallengeReq) diff --git a/internal/hardening/validator/init.go b/internal/hardening/validator/init.go index 9dcaf888..6767ae51 100644 --- a/internal/hardening/validator/init.go +++ b/internal/hardening/validator/init.go @@ -10,7 +10,11 @@ import ( "soldr/internal/vxproto" ) -func (v *Validator) OnInitConnect(ctx context.Context, socket vxproto.SyncWS, agentInfo *protoagent.Information) (err error) { +func (v *Validator) OnInitConnect( + ctx context.Context, + socket vxproto.SyncWS, + agentInfo *protoagent.Information, +) (err error) { defer func() { if err != nil { v.vm.ResetInitConnection() diff --git a/internal/hardening/validator/validator.go b/internal/hardening/validator/validator.go index b94a3b34..d938ade2 100644 --- a/internal/hardening/validator/validator.go +++ b/internal/hardening/validator/validator.go @@ -8,10 +8,10 @@ import ( ) type Validator struct { - vm vm.VM - agentID string - version string - connProtocolVersion string + vm vm.VM + agentID string + version string + //connProtocolVersion string - unused field } func NewValidator(agentID string, version string, luaVM vm.VM) *Validator { diff --git a/internal/lua/module.go b/internal/lua/module.go index 7c5a19b8..12430c2c 100644 --- a/internal/lua/module.go +++ b/internal/lua/module.go @@ -84,7 +84,10 @@ func (m *Module) Start() { m.wgRun.Add(1) go func() { defer m.wgRun.Done() - m.recvPacket() + err := m.recvPacket() + if err != nil { + m.logger.WithContext(m.state.ctx).Errorf("failed to recive packet: %s", err) + } }() m.logger.WithContext(m.state.ctx).Info("the module was started") defer func(m *Module) { @@ -118,7 +121,10 @@ func (m *Module) Stop() { defer m.logger.WithContext(m.state.ctx).Info("the module stopping has done") m.closed = true - m.controlMsgCb(m.state.ctx, "quit", "") + _, err := m.controlMsgCb(m.state.ctx, "quit", "") + if err != nil { + m.logger.WithContext(m.state.ctx).Error(err) + } close(m.quit) close(m.notifier) @@ -1449,25 +1455,28 @@ func NewModule(args map[string][]string, state *State, socket vxproto.IModuleSoc state.L.SetGlobal("__args") // TODO: change it to native load function - state.L.DoString(` - io.stdout:setvbuf('no') - - function __api.async(f, ...) - local glue = require("glue") - __api.unsafe.unlock() - t = glue.pack(f(...)) - __api.unsafe.lock() - return glue.unpack(t) - end - - function __api.sync(f, ...) - local glue = require("glue") - __api.unsafe.lock() - t = glue.pack(f(...)) - __api.unsafe.unlock() - return glue.unpack(t) - end - `) + err := state.L.DoString(` + io.stdout:setvbuf('no') + + function __api.async(f, ...) + local glue = require("glue") + __api.unsafe.unlock() + t = glue.pack(f(...)) + __api.unsafe.lock() + return glue.unpack(t) + end + + function __api.sync(f, ...) + local glue = require("glue") + __api.unsafe.lock() + t = glue.pack(f(...)) + __api.unsafe.unlock() + return glue.unpack(t) + end + `) + if err != nil { + logrus.Errorf("DoString is failed: %s", err) + } m.logger.WithContext(state.ctx).Info("the module was created") return m, nil diff --git a/internal/lua/state.go b/internal/lua/state.go index e97661c0..88a3c48b 100644 --- a/internal/lua/state.go +++ b/internal/lua/state.go @@ -1,5 +1,7 @@ +//nolint:staticcheck package lua +//TODO: io/ioutil is deprecated, replace to fs.FS and delete "nolint:staticcheck" import ( "context" "fmt" @@ -58,7 +60,7 @@ func NewState(files map[string][]byte) (*State, error) { pathToPID := filepath.Join(tmpdir, "lock.pid") pid := strconv.Itoa(os.Getpid()) - if err = ioutil.WriteFile(pathToPID, []byte(pid), 0640); err != nil { + if err = ioutil.WriteFile(pathToPID, []byte(pid), 0600); err != nil { return nil, err } @@ -91,7 +93,10 @@ func NewState(files map[string][]byte) (*State, error) { for name, data := range nfiles { lfiles[name] = string(data) } - s.getRegisterFFILoader()(s.L, s.tmpdir) + err = s.getRegisterFFILoader()(s.L, s.tmpdir) + if err != nil { + return nil, err + } luar.GoToLua(s.L, tmpdir) s.L.SetGlobal("__tmpdir") luar.GoToLua(s.L, lfiles) @@ -151,7 +156,7 @@ func writeFile(ctx context.Context, path string, contents []byte, logger *logrus Error("failed to create a directory") return fmt.Errorf("failed to create a directory: %w", err) } - if err := ioutil.WriteFile(path, contents, 0640); err != nil { + if err := ioutil.WriteFile(path, contents, 0600); err != nil { logger.WithContext(ctx).WithError(err).WithField("name", path). Error("failed to write a file") return fmt.Errorf("failed to write a file: %w", err) @@ -164,7 +169,7 @@ func (s *State) getModuleLoader() func(*lua.State) int { return func(L *lua.State) int { moduleName := L.CheckString(1) - logger := logger.WithFields(logrus.Fields{ + l := logger.WithFields(logrus.Fields{ "component": "lua_module_loader", "require": moduleName, }) @@ -173,7 +178,7 @@ func (s *State) getModuleLoader() func(*lua.State) int { err := luar.LuaToGo(L, -1, &files) L.Pop(1) if err != nil { - logger.WithError(err).Error("failed to put the module files into the lua state") + l.WithError(err).Error("failed to put the module files into the lua state") } var moduleNames []string @@ -184,7 +189,7 @@ func (s *State) getModuleLoader() func(*lua.State) int { moduleData, ok := files[filePath] if ok && L.LoadBuffer([]byte(moduleData), len(moduleData), moduleName) != 0 { err = fmt.Errorf(L.ToString(-1)) - logger.WithError(err).Error("failed to put the module data into the lua state") + l.WithError(err).Error("failed to put the module data into the lua state") L.Pop(1) break } @@ -348,51 +353,57 @@ func (s *State) RegisterLogger(level logrus.Level, fields logrus.Fields) error { s.L.SetField(-2, "_debug") s.L.SetGlobal("__log") - s.L.DoString(` - function __log.error(...) - if __log.level >= __log.level_error then - __log._error(...) + err := s.L.DoString(` + function __log.error(...) + if __log.level >= __log.level_error then + __log._error(...) + end end - end - function __log.warn(...) - if __log.level >= __log.level_warn then - __log._warn(...) + function __log.warn(...) + if __log.level >= __log.level_warn then + __log._warn(...) + end end - end - function __log.info(...) - if __log.level >= __log.level_info then - __log._info(...) + function __log.info(...) + if __log.level >= __log.level_info then + __log._info(...) + end end - end - function __log.debug(...) - if __log.level >= __log.level_debug then - __log._debug(...) + function __log.debug(...) + if __log.level >= __log.level_debug then + __log._debug(...) + end end - end - `) + `) + if err != nil { + logrus.Errorf("DoString is failed: %s", err) + } - s.L.DoString(` - function __log.errorf(fmt, ...) - if __log.level >= __log.level_error then - __log._error(string.format(fmt, ...)) + err = s.L.DoString(` + function __log.errorf(fmt, ...) + if __log.level >= __log.level_error then + __log._error(string.format(fmt, ...)) + end end - end - function __log.warnf(fmt, ...) - if __log.level >= __log.level_warn then - __log._warn(string.format(fmt, ...)) + function __log.warnf(fmt, ...) + if __log.level >= __log.level_warn then + __log._warn(string.format(fmt, ...)) + end end - end - function __log.infof(fmt, ...) - if __log.level >= __log.level_info then - __log._info(string.format(fmt, ...)) + function __log.infof(fmt, ...) + if __log.level >= __log.level_info then + __log._info(string.format(fmt, ...)) + end end - end - function __log.debugf(fmt, ...) - if __log.level >= __log.level_debug then - __log._debug(string.format(fmt, ...)) + function __log.debugf(fmt, ...) + if __log.level >= __log.level_debug then + __log._debug(string.format(fmt, ...)) + end end - end - `) + `) + if err != nil { + logrus.Errorf("DoString is failed: %s", err) + } return nil } diff --git a/internal/observability/collector.go b/internal/observability/collector.go index 4ea88aa0..85471bc7 100644 --- a/internal/observability/collector.go +++ b/internal/observability/collector.go @@ -45,11 +45,33 @@ func startProcessMetricCollect(meter metric.Meter, attrs []attribute.KeyValue) e } if _, err := proc.MemoryInfo(); err == nil { - meter.NewInt64GaugeObserver("process_resident_memory_bytes", collectRssMem) - meter.NewInt64GaugeObserver("process_virtual_memory_bytes", collectVirtMem) + _, err = meter.NewInt64GaugeObserver("process_resident_memory_bytes", collectRssMem) + if err != nil { + logrus.Errorf( + "meter.NewInt64GaugeObserver is failed: %s: %s", + "process_resident_memory_bytes", + err, + ) + } + + _, err = meter.NewInt64GaugeObserver("process_virtual_memory_bytes", collectVirtMem) + if err != nil { + logrus.Errorf( + "meter.NewInt64GaugeObserver is failed: %s: %s", + "process_virtual_memory_bytes", + err, + ) + } } if _, err := proc.Percent(time.Duration(0)); err == nil { - meter.NewFloat64GaugeObserver("process_cpu_usage_percent", collectCpuPercent) + _, err = meter.NewFloat64GaugeObserver("process_cpu_usage_percent", collectCpuPercent) + if err != nil { + logrus.Errorf( + "meter.NewInt64GaugeObserver is failed: %s: %s", + "process_cpu_usage_percent", + err, + ) + } } return nil @@ -75,35 +97,97 @@ func startGoRuntimeMetricCollect(meter metric.Meter, attrs []attribute.KeyValue) return &procRuntimeMemStat } - meter.NewInt64GaugeObserver("go_cgo_calls", func(ctx context.Context, m metric.Int64ObserverResult) { - m.Observe(runtime.NumCgoCall(), attrs...) - }) - meter.NewInt64GaugeObserver("go_goroutines", func(ctx context.Context, m metric.Int64ObserverResult) { - m.Observe(int64(runtime.NumGoroutine()), attrs...) - }) - meter.NewInt64GaugeObserver("go_heap_objects_bytes", func(ctx context.Context, m metric.Int64ObserverResult) { - m.Observe(int64(getMemStats().HeapInuse), attrs...) - }) - meter.NewInt64GaugeObserver("go_heap_objects_counter", func(ctx context.Context, m metric.Int64ObserverResult) { - m.Observe(int64(getMemStats().HeapObjects), attrs...) - }) - meter.NewInt64GaugeObserver("go_stack_inuse_bytes", func(ctx context.Context, m metric.Int64ObserverResult) { - m.Observe(int64(getMemStats().StackInuse), attrs...) - }) - meter.NewInt64GaugeObserver("go_stack_sys_bytes", func(ctx context.Context, m metric.Int64ObserverResult) { - m.Observe(int64(getMemStats().StackSys), attrs...) - }) - meter.NewInt64GaugeObserver("go_total_allocs_bytes", func(ctx context.Context, m metric.Int64ObserverResult) { - m.Observe(int64(getMemStats().TotalAlloc), attrs...) - }) - meter.NewInt64GaugeObserver("go_heap_allocs_bytes", func(ctx context.Context, m metric.Int64ObserverResult) { - m.Observe(int64(getMemStats().HeapAlloc), attrs...) - }) - meter.NewInt64GaugeObserver("go_pause_gc_total_nanosec", func(ctx context.Context, m metric.Int64ObserverResult) { - m.Observe(int64(getMemStats().PauseTotalNs), attrs...) - }) + _, err := meter.NewInt64GaugeObserver( + "go_cgo_calls", + func(ctx context.Context, m metric.Int64ObserverResult) { + m.Observe(runtime.NumCgoCall(), attrs...) + }, + ) + if err != nil { + logrus.WithError(err).Errorf("meter.NewInt64GaugeObserver is failed: %s", "go_cgo_calls") + } + + _, err = meter.NewInt64GaugeObserver( + "go_goroutines", + func(ctx context.Context, m metric.Int64ObserverResult) { + m.Observe(int64(runtime.NumGoroutine()), attrs...) + }, + ) + if err != nil { + logrus.WithError(err).Errorf("meter.NewInt64GaugeObserver is failed: %s", "go_goroutines") + } - return nil + _, err = meter.NewInt64GaugeObserver( + "go_heap_objects_bytes", + func(ctx context.Context, m metric.Int64ObserverResult) { + m.Observe(int64(getMemStats().HeapInuse), attrs...) + }, + ) + if err != nil { + logrus.WithError(err).Errorf("meter.NewInt64GaugeObserver is failed: %s", "go_heap_objects_types") + } + + _, err = meter.NewInt64GaugeObserver( + "go_heap_objects_counter", + func(ctx context.Context, m metric.Int64ObserverResult) { + m.Observe(int64(getMemStats().HeapObjects), attrs...) + }, + ) + if err != nil { + logrus.WithError(err).Errorf("meter.NewInt64GaugeObserver is failed: %s", "go_heap_objects_counter") + } + + _, err = meter.NewInt64GaugeObserver( + "go_stack_inuse_bytes", + func(ctx context.Context, m metric.Int64ObserverResult) { + m.Observe(int64(getMemStats().StackInuse), attrs...) + }, + ) + if err != nil { + logrus.WithError(err).Errorf("meter.NewInt64GaugeObserver is failed: %s", "go_stack_inuse_bytes") + } + + _, err = meter.NewInt64GaugeObserver( + "go_stack_sys_bytes", + func(ctx context.Context, m metric.Int64ObserverResult) { + m.Observe(int64(getMemStats().StackSys), attrs...) + }, + ) + if err != nil { + logrus.WithError(err).Errorf("meter.NewInt64GaugeObserver is failed: %s", "go_stack_sys_bytes") + } + + _, err = meter.NewInt64GaugeObserver( + "go_total_allocs_bytes", + func(ctx context.Context, m metric.Int64ObserverResult) { + m.Observe(int64(getMemStats().TotalAlloc), attrs...) + }, + ) + if err != nil { + logrus.WithError(err).Errorf("meter.NewInt64GaugeObserver is failed: %s", "go_total_allocs_bytes") + } + + _, err = meter.NewInt64GaugeObserver( + "go_heap_allocs_bytes", + func(ctx context.Context, m metric.Int64ObserverResult) { + m.Observe(int64(getMemStats().HeapAlloc), attrs...) + }, + ) + if err != nil { + logrus.WithError(err).Errorf("meter.NewInt64GaugeObserver is failed: %s", "go_heap_allocs_bytes") + } + + _, err = meter.NewInt64GaugeObserver( + "go_pause_gc_total_nanosec", + func(ctx context.Context, m metric.Int64ObserverResult) { + m.Observe(int64(getMemStats().PauseTotalNs), attrs...) + }, + ) + if err != nil { + logrus.WithError(err).Errorf("meter.NewInt64GaugeObserver is failed: %s", "go_pause_gc_total_nanosec") + } + + return err } func startDumperMetricCollect(stats IDumper, meter metric.Meter, attrs []attribute.KeyValue) error { @@ -136,11 +220,14 @@ func startDumperMetricCollect(stats IDumper, meter metric.Meter, attrs []attribu for key := range lastStats { metricName := key - meter.NewFloat64CounterObserver(metricName, func(ctx context.Context, m metric.Float64ObserverResult) { + _, err = meter.NewFloat64CounterObserver(metricName, func(ctx context.Context, m metric.Float64ObserverResult) { if value, ok := getProtoStats()[metricName]; ok { m.Observe(value, attrs...) } }) + if err != nil { + logrus.Errorf("meter.NewFloat64CounterObserver is failed: %s", err) + } } return nil diff --git a/internal/observability/obs.go b/internal/observability/obs.go index 670d9eac..0b6dfc5d 100644 --- a/internal/observability/obs.go +++ b/internal/observability/obs.go @@ -148,7 +148,10 @@ func InitObserver( if mprovider != nil { metricglobal.SetMeterProvider(mprovider) obs.meter = mprovider.Meter(tname) - mprovider.Start(ctx) + err := mprovider.Start(ctx) + if err != nil { + logrus.Errorf("failed to start mprovider: %s", err) + } } Observer = obs @@ -276,15 +279,27 @@ func (obs *observer) Flush(ctx context.Context) { if obs.mprovider != nil { // TODO: it's dirty hack because otel sdk can't use Collect method when ticker was running if obs.mprovider.IsRunning() { - obs.mprovider.Stop(ctx) - obs.mprovider.Start(ctx) + err := obs.mprovider.Stop(ctx) + if err != nil { + logrus.Errorf("failed to stop mprovider: %s", err) + } + err = obs.mprovider.Start(ctx) + if err != nil { + logrus.Errorf("failed to start mprovider: %s", err) + } } else { - obs.mprovider.Collect(ctx) + err := obs.mprovider.Collect(ctx) + if err != nil { + logrus.Errorf("failed to collect mprovider: %s", err) + } } obs.flushMClient(ctx) } if obs.tprovider != nil { - obs.tprovider.ForceFlush(ctx) + err := obs.tprovider.ForceFlush(ctx) + if err != nil { + logrus.Errorf("failed to force flush tprovider: %s", err) + } obs.flushTClient(ctx) } } @@ -305,15 +320,27 @@ func (obs *observer) Close() { ctx := context.Background() if obs.mprovider != nil { if obs.mprovider.IsRunning() { - obs.mprovider.Stop(ctx) + err := obs.mprovider.Stop(ctx) + if err != nil { + logrus.Errorf("failt to stop mprovider: %s", err) + } + } + err := obs.mprovider.Collect(ctx) + if err != nil { + logrus.Errorf("failed to collect mprovider: %s", err) } - obs.mprovider.Collect(ctx) obs.flushMClient(ctx) } if obs.tprovider != nil { - obs.tprovider.ForceFlush(ctx) + err := obs.tprovider.ForceFlush(ctx) + if err != nil { + logrus.Errorf("failed to force flush tprovider: %s", err) + } obs.flushTClient(ctx) - obs.tprovider.Shutdown(ctx) + err = obs.tprovider.Shutdown(ctx) + if err != nil { + logrus.Errorf("failed to shutdown tprovider: %s", err) + } } obs.cancelCtx() } @@ -490,7 +517,10 @@ func (obs *observer) NewSpanWithParent(ctx context.Context, kind oteltrace.SpanK ) tid, err = oteltrace.TraceIDFromHex(traceID) if err != nil { - rand.Read(tid[:]) + _, e := rand.Read(tid[:]) + if e != nil { + logrus.Errorf("rand.Read is failed: %s", e) + } } sid, err = oteltrace.SpanIDFromHex(pspanID) if err != nil { diff --git a/internal/protoagent/utils.go b/internal/protoagent/utils.go index 0ec1cf3d..d28f64ae 100644 --- a/internal/protoagent/utils.go +++ b/internal/protoagent/utils.go @@ -54,7 +54,12 @@ func UnpackProtoMessage(dst proto.Message, msg []byte, msgType Message_Type) err return err } if actualMsgType != msgType { - return fmt.Errorf("%w: expected agent message type: %d, got: %d", errors.ErrUnexpectedUnpackType, msgType, actualMsgType) + return fmt.Errorf( + "%w: expected agent message type: %d, got: %d", + errors.ErrUnexpectedUnpackType, + msgType, + actualMsgType, + ) } if err := UnpackProtoMessagePayload(dst, payload); err != nil { return err diff --git a/internal/storage/fs.go b/internal/storage/fs.go index a80e3845..d90f345c 100755 --- a/internal/storage/fs.go +++ b/internal/storage/fs.go @@ -1,11 +1,15 @@ +//nolint:staticcheck package storage +//TODO: io/ioutil is deprecated, replace to fs.FS import ( "io" "io/ioutil" "os" "path/filepath" "strings" + + "github.com/sirupsen/logrus" ) // FS is main class for FS API @@ -14,7 +18,7 @@ type FS struct { } // NewFS is function that construct FS driver with IStorage -func NewFS() (IStorage, error) { +func NewFS() (*FS, error) { return &FS{ sLimits: sLimits{ defPerm: 0644, @@ -112,6 +116,7 @@ func (fs *FS) ReadFile(path string) ([]byte, error) { } else if info.Size() > fs.maxFileSize { return nil, ErrLimitExceeded } + // #nosec G304 if data, err := ioutil.ReadFile(path); err == nil { return data, nil } @@ -178,6 +183,7 @@ func (fs *FS) CreateDir(path string) error { // CreateFile is function for create new file if not exists func (fs *FS) CreateFile(path string) error { if !fs.IsExist(path) { + // #nosec G304 file, err := os.Create(path) if err != nil { return ErrCreateFailed @@ -193,6 +199,7 @@ func (fs *FS) CreateFile(path string) error { // WriteFile is function that write (override) data to a file func (fs *FS) WriteFile(path string, data []byte) error { + // #nosec G304 file, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE, fs.defPerm) if err != nil { return ErrOpenFailed @@ -212,6 +219,7 @@ func (fs *FS) WriteFile(path string, data []byte) error { // AppendFile is function that append data to an exist file func (fs *FS) AppendFile(path string, data []byte) error { + // #nosec G304 file, err := os.OpenFile(path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, fs.defPerm) if err != nil { return ErrOpenFailed @@ -294,11 +302,18 @@ func (fs *FS) CopyFile(src, dst string) error { } } + // #nosec G304 in, err := os.Open(src) if err != nil { return ErrOpenFailed } - defer in.Close() + defer func(in *os.File) { + e := in.Close() + if e != nil { + logrus.Errorf("failed to close file: %s", e) + } + }(in) + // #nosec G304 out, err := os.Create(dst) if err != nil { return ErrCreateFailed diff --git a/internal/storage/s3.go b/internal/storage/s3.go index aeee3b28..ff84a229 100755 --- a/internal/storage/s3.go +++ b/internal/storage/s3.go @@ -112,6 +112,7 @@ func (s *S3) ListDir(path string) (map[string]os.FileInfo, error) { Prefix: path, Recursive: true, }) + //nolint:exportloopref for object := range objectCh { if object.Err != nil { return nil, ErrListFailed @@ -123,6 +124,7 @@ func (s *S3) ListDir(path string) (map[string]os.FileInfo, error) { dirName, fileName := filepath.Split(shortPath) spl := strings.Split(dirName, "/") if dirName == "/" { + //TODO: Check use pointer in loop. Perhaps it is worth passing a pointer to the channel tree[shortPath] = &S3FileInfo{ isDir: false, path: fileName, diff --git a/internal/storage/structs.go b/internal/storage/structs.go index 8fedd1e5..e2e00e04 100755 --- a/internal/storage/structs.go +++ b/internal/storage/structs.go @@ -43,7 +43,7 @@ type IStorage interface { RemoveDir(path string) error RemoveFile(path string) error Remove(path string) error - Rename(old, new string) error + Rename(old, n string) error CopyFile(src, dst string) error IFileReader ILimits diff --git a/internal/system/agent_info.go b/internal/system/agent_info.go index e67d5e80..1df9878e 100644 --- a/internal/system/agent_info.go +++ b/internal/system/agent_info.go @@ -1,5 +1,7 @@ +//nolint:gosec package system +//TODO: replace mean cryptographic primitive and delete "nolint:gosec" import ( "context" "crypto/md5" diff --git a/internal/system/utils_linux.go b/internal/system/utils_linux.go index 449e71f8..3b31ccd4 100644 --- a/internal/system/utils_linux.go +++ b/internal/system/utils_linux.go @@ -1,8 +1,10 @@ //go:build linux // +build linux +//nolint:staticcheck package system +//TODO: io/ioutil is deprecated, replace to fs.FS and delete "nolint:staticcheck" import ( "io/ioutil" "strings" diff --git a/internal/system/utils_unix.go b/internal/system/utils_unix.go index cea8bbeb..54a7e42c 100644 --- a/internal/system/utils_unix.go +++ b/internal/system/utils_unix.go @@ -1,8 +1,10 @@ //go:build !windows // +build !windows +//nolint:staticcheck package system +//TODO: io/ioutil is deprecated, replace to fs.FS and delete "nolint:staticcheck" import ( "fmt" "io/ioutil" diff --git a/internal/utils/utils.go b/internal/utils/utils.go index ff23c77c..3382efe5 100644 --- a/internal/utils/utils.go +++ b/internal/utils/utils.go @@ -1,14 +1,18 @@ +//nolint:staticcheck package utils +//TODO: io/ioutil is deprecated, replace to fs.FS and delete "nolint:staticcheck" import ( "io/ioutil" "os" "path/filepath" "strconv" "strings" + + "github.com/sirupsen/logrus" ) -// GetRef is function for returning referense of string +// GetRef is function for returning reference of string func GetRef(str string) *string { return &str } @@ -22,22 +26,29 @@ func RemoveUnusedTempDir() { for _, f := range files { if f.IsDir() && strings.HasPrefix(f.Name(), "vxlua-") { - pathToPID := filepath.Join(os.TempDir(), f.Name(), "lock.pid") - fdata, err := ioutil.ReadFile(pathToPID) + fdata, err := ioutil.ReadFile(filepath.Join(os.TempDir(), f.Name(), "lock.pid")) + path := filepath.Join(os.TempDir(), f.Name()) if err != nil { - os.RemoveAll(filepath.Join(os.TempDir(), f.Name())) + removeAll(path) continue } pid, err := strconv.Atoi(string(fdata)) if err != nil { - os.RemoveAll(filepath.Join(os.TempDir(), f.Name())) + removeAll(path) continue } proc, _ := os.FindProcess(pid) if proc == nil || err != nil { - os.RemoveAll(filepath.Join(os.TempDir(), f.Name())) + removeAll(path) continue } } } } + +func removeAll(path string) { + e := os.RemoveAll(path) + if e != nil { + logrus.Errorf("failed to remove all: %s", e) + } +} diff --git a/internal/vxproto/agent.go b/internal/vxproto/agent.go index 47329071..9820d316 100644 --- a/internal/vxproto/agent.go +++ b/internal/vxproto/agent.go @@ -72,7 +72,7 @@ func (at AgentType) String() string { return str } - return "unknown" + return typeUnknown } // MarshalJSON using for convert from AgentType to JSON diff --git a/internal/vxproto/module.go b/internal/vxproto/module.go index ccbb4b90..c22bb616 100755 --- a/internal/vxproto/module.go +++ b/internal/vxproto/module.go @@ -1,5 +1,8 @@ +//nolint:gosec,staticcheck package vxproto +//TODO: replace mean cryptographic primitive and delete "nolint:gosec" +//TODO: io/ioutil is deprecated, replace to fs.FS and delete "nolint:staticcheck" import ( "context" "crypto/md5" @@ -241,10 +244,20 @@ func (ms *moduleSocket) sendFileStream(ctx context.Context, dst string, file *Fi return fmt.Errorf("invalid left index %d: cannot be negative", left) } if left > right { - return fmt.Errorf("invalid left (%d) and right (%d) indices: the left index cannot be greater than the right index", left, right) + return fmt.Errorf( + "invalid left (%d) and right (%d) indices: "+ + "the left index cannot be greater than the right index", + left, + right, + ) } if right > dataLen { - return fmt.Errorf("invalid right index (%d) for the given data length (%d): the right index cannot be greater than the data length", right, dataLen) + return fmt.Errorf( + "invalid right index (%d) for the given data length (%d): "+ + "the right index cannot be greater than the data length", + right, + dataLen, + ) } return nil } @@ -355,7 +368,10 @@ func (ms *moduleSocket) parseFileStream(ctx context.Context, packet *Packet) (bo } tempDir := filepath.Join(os.TempDir(), "vx-store") - os.Mkdir(tempDir, 0700) + err := os.Mkdir(tempDir, 0700) + if err != nil { + return false, fmt.Errorf("failed to make directory: %s", err) + } tempFile := filepath.Join(tempDir, uniqArr[0]) tempFileFlags := os.O_WRONLY | os.O_CREATE diff --git a/internal/vxproto/packet_checker.go b/internal/vxproto/packet_checker.go index 55f20e07..4389c870 100644 --- a/internal/vxproto/packet_checker.go +++ b/internal/vxproto/packet_checker.go @@ -10,7 +10,9 @@ import ( "soldr/internal/protoagent" ) -func NewConnectionPolicyManagerIterator(versionsConfig ServerAPIVersionsConfig) (*ConnectionPolicyManagerIterator, error) { +func NewConnectionPolicyManagerIterator( + versionsConfig ServerAPIVersionsConfig, +) (*ConnectionPolicyManagerIterator, error) { if versionsConfig == nil { return nil, fmt.Errorf("a nil configuration object passed") } @@ -126,7 +128,7 @@ func (p EndpointConnectionPolicy) String() string { case EndpointConnectionPolicyUpgrade: return "upgrade" default: - return "unknown" + return typeUnknown } } @@ -135,7 +137,10 @@ type staticConnectionPolicyManager struct { policyType EndpointConnectionPolicy } -func newStaticConnectionPolicyManager(policy ConnectionPolicy, policyType EndpointConnectionPolicy) *staticConnectionPolicyManager { +func newStaticConnectionPolicyManager( + policy ConnectionPolicy, + policyType EndpointConnectionPolicy, +) *staticConnectionPolicyManager { return &staticConnectionPolicyManager{ policy: policy, policyType: policyType, @@ -309,7 +314,8 @@ func newUpgradePacketChecker() *upgradePacketChecker { return fmt.Errorf("only upgrade files can be sent via this endpoint") default: return fmt.Errorf( - "sending packet of type %d, but only data packets (of type %d) or file packets (of type %d) are allowed for this connection", + "sending packet of type %d, "+ + "but only data packets (of type %d) or file packets (of type %d) are allowed for this connection", p.PType, PTData, PTFile, diff --git a/internal/vxproto/proto.go b/internal/vxproto/proto.go index 9bda1239..2620f6a7 100644 --- a/internal/vxproto/proto.go +++ b/internal/vxproto/proto.go @@ -1,5 +1,7 @@ +//nolint:gosec package vxproto +//TODO: replace mean cryptographic primitive and delete "nolint:gosec" import ( "bytes" "context" @@ -286,7 +288,10 @@ func (vxp *vxProto) runRecvPQueue() func() { for _, p := range pq { if time.Since(time.Unix(p.TS, 0)).Seconds() < DeferredPacketTTLSeconds { // Here using public method because there need save proc time to recv other packets - vxp.recvPacket(ctx, p) + err := vxp.recvPacket(ctx, p) + if err != nil { + logrus.Errorf("failed to recive packet: %s", err) + } } } }) @@ -301,7 +306,10 @@ func (vxp *vxProto) runSendPQueue() func() { for _, p := range pq { if time.Since(time.Unix(p.TS, 0)).Seconds() < DeferredPacketTTLSeconds { // Here using public method because there need save proc time to send other packets - vxp.sendPacket(ctx, p) + err := vxp.sendPacket(ctx, p) + if err != nil { + logrus.Errorf("failed to send packet: %s", err) + } } } }) diff --git a/internal/vxproto/proto_init.go b/internal/vxproto/proto_init.go index c6260bd5..c9cbc46f 100644 --- a/internal/vxproto/proto_init.go +++ b/internal/vxproto/proto_init.go @@ -13,7 +13,12 @@ import ( "soldr/internal/system" ) -func (vxp *vxProto) InitConnection(ctx context.Context, connValidator AgentConnectionValidator, config *ClientInitConfig, logger *logrus.Entry) error { +func (vxp *vxProto) InitConnection( + ctx context.Context, + connValidator AgentConnectionValidator, + config *ClientInitConfig, + logger *logrus.Entry, +) error { vxp.isClosedMux.RLock() defer vxp.isClosedMux.RUnlock() if vxp.isClosed { @@ -130,7 +135,13 @@ func NewSyncWS(ctx context.Context, logger *logrus.Entry, ws *websocket.Conn, co return s, nil } -func configurePing(done <-chan struct{}, logger *logrus.Entry, ws IWSConnection, readTimeout time.Duration, pingFrequency time.Duration) { +func configurePing( + done <-chan struct{}, + logger *logrus.Entry, + ws IWSConnection, + readTimeout time.Duration, + pingFrequency time.Duration, +) { if pingFrequency == 0 { handleWSPing(ws, logger, readTimeout) return @@ -138,7 +149,13 @@ func configurePing(done <-chan struct{}, logger *logrus.Entry, ws IWSConnection, configureWSPing(done, logger, ws, readTimeout, pingFrequency) } -func configureWSPing(done <-chan struct{}, logger *logrus.Entry, ws IWSConnection, readTimeout time.Duration, pingFrequency time.Duration) { +func configureWSPing( + done <-chan struct{}, + logger *logrus.Entry, + ws IWSConnection, + readTimeout time.Duration, + pingFrequency time.Duration, +) { go func() { ticker := time.NewTicker(pingFrequency) for { diff --git a/internal/vxproto/structs.go b/internal/vxproto/structs.go index 386c92c7..9d69dc62 100755 --- a/internal/vxproto/structs.go +++ b/internal/vxproto/structs.go @@ -16,8 +16,10 @@ import ( // ControlMessageType is type of message that used for modules communication type ControlMessageType int32 -// Enumerate control message types const ( + typeUnknown = "unknown" + + // Enumerate control message types AgentConnected ControlMessageType = 0 AgentDisconnected ControlMessageType = 1 StopModule ControlMessageType = 2 @@ -40,7 +42,7 @@ func (cmt ControlMessageType) String() string { return str } - return "unknown" + return typeUnknown } // MarshalJSON using for convert from ControlMessageType to JSON @@ -197,7 +199,7 @@ func (mt MsgType) String() string { return str } - return "unknown" + return typeUnknown } // MarshalJSON using for convert from MsgType to JSON @@ -307,7 +309,7 @@ func (pt PacketType) String() string { return str } - return "unknown" + return typeUnknown } // MarshalJSON using for convert from PacketType to JSON @@ -418,7 +420,7 @@ func (p *Packet) fromPB(packet *protocol.Packet) (*Packet, error) { p.PType = PTAction p.Payload = (&Action{}).fromPB(content) default: - return nil, fmt.Errorf("unknown packet type") + return nil, fmt.Errorf("%s packet type", typeUnknown) } return p, nil @@ -485,7 +487,7 @@ func (p *Packet) fromBytesJSON(data []byte) (*Packet, error) { } p.Payload = &act default: - return nil, fmt.Errorf("unknown packet type") + return nil, fmt.Errorf("%s packet type", typeUnknown) } var traceID string @@ -527,7 +529,7 @@ func (p *Packet) toPB() (*protocol.Packet, error) { case PTAction: content = p.Payload.(*Action).toPB() default: - return nil, fmt.Errorf("unknown packet type") + return nil, fmt.Errorf("%s packet type", typeUnknown) } spanCtx := obs.Observer.SpanContextFromContext(p.ctx) diff --git a/internal/vxproto/tunnel/rc4/encrypter.go b/internal/vxproto/tunnel/rc4/encrypter.go index ed0cb5b4..f104140f 100644 --- a/internal/vxproto/tunnel/rc4/encrypter.go +++ b/internal/vxproto/tunnel/rc4/encrypter.go @@ -1,5 +1,7 @@ +//nolint:gosec package rc4 +// TODO: replace mean cryptographic primitive and delete "nolint:gosec" import ( "crypto/rc4" "fmt" @@ -9,6 +11,7 @@ import ( compressor "soldr/internal/vxproto/tunnel/compressor/simple" ) +// GenerateKey TODO: Check use function func GenerateKey(rand func(buf []byte) error) ([]byte, error) { const keyLen = 48 buf := make([]byte, keyLen) @@ -61,11 +64,13 @@ func (e *Encrypter) Decrypt(data []byte) ([]byte, error) { return data, nil } +//nolint:gosec func (e *Encrypter) applyCipher(data []byte) ([]byte, error) { xoredData := make([]byte, len(data)) e.keyMux.RLock() defer e.keyMux.RUnlock() + // TODO: Replace weak cryptographic primitive cipher, err := rc4.NewCipher(e.key) if err != nil { return nil, fmt.Errorf("failed to initialize a new RC4 cipher: %w", err) diff --git a/internal/vxproto/ws.go b/internal/vxproto/ws.go index 694d3b11..e84f0371 100755 --- a/internal/vxproto/ws.go +++ b/internal/vxproto/ws.go @@ -1,5 +1,7 @@ +//nolint:staticcheck package vxproto +//TODO: io/ioutil is deprecated, replace to fs.FS and delete "nolint:staticcheck" import ( "bytes" "context" @@ -13,6 +15,7 @@ import ( "time" "github.com/gorilla/websocket" + "github.com/sirupsen/logrus" ) var PacketMarkerV1 = []byte{0xf0, 0xa0, 0x60, 0x30} @@ -208,8 +211,14 @@ func (ws *wsConnection) Write(ctx context.Context, data []byte) (err error) { } if !ws.original { prefixBuffer := new(bytes.Buffer) - binary.Write(prefixBuffer, binary.LittleEndian, PacketMarkerV1) - binary.Write(prefixBuffer, binary.LittleEndian, uint32(len(data))) + e := binary.Write(prefixBuffer, binary.LittleEndian, PacketMarkerV1) + if e != nil { + logrus.Errorf("binary.Write failed: %s", e) + } + e = binary.Write(prefixBuffer, binary.LittleEndian, uint32(len(data))) + if e != nil { + logrus.Errorf("binary.Write failed: %s", e) + } prefix := prefixBuffer.Bytes() if n, err := writer.Write(prefix); err != nil || n != len(prefix) { return fmt.Errorf( @@ -289,7 +298,10 @@ func (ws *wsConnection) Close(ctx context.Context) error { _ = ws.Conn.Close() }() - if err := ws.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")); err != nil { + if err := ws.WriteMessage( + websocket.CloseMessage, + websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), + ); err != nil { if strings.HasPrefix(err.Error(), "websocket: close ") { // to avoid issue with hanging TLS/TCP connection _ = ws.Conn.Close() @@ -297,7 +309,11 @@ func (ws *wsConnection) Close(ctx context.Context) error { } err = fmt.Errorf("failed to write a request Close control message: %w", err) if closeErr := ws.Conn.Close(); closeErr != nil { - err = fmt.Errorf("failed to properly close the websocket connection (%v) after another error has occurred: %w", closeErr, err) + err = fmt.Errorf( + "failed to properly close the websocket connection (%v) after another error has occurred: %w", + closeErr, + err, + ) } return err } diff --git a/internal/vxproto/ws_client.go b/internal/vxproto/ws_client.go index a439d13f..c7be2bcb 100644 --- a/internal/vxproto/ws_client.go +++ b/internal/vxproto/ws_client.go @@ -19,6 +19,11 @@ import ( "soldr/internal/vxproto/tunnel" ) +const ( + configTypeBrowser = "browser" + configTypeExternal = "external" +) + type agentConn interface { connect(ctx context.Context) error } @@ -46,8 +51,8 @@ func (vxp *vxProto) openAgentSocket( tunnelEncrypter tunnel.PackEncryptor, infoGetter system.AgentInfoGetter, ) (socket *agentSocket, cleanup func(), err error) { - if config.Type == "browser" || config.Type == "external" { - return nil, nil, fmt.Errorf("connection initialization for the browser type is NYI") + if config.Type == configTypeBrowser || config.Type == configTypeExternal { + return nil, nil, fmt.Errorf("connection initialization for the %s type is NYI", configTypeBrowser) } dialer := websocket.Dialer{ TLSClientConfig: config.TLSConfig, @@ -147,7 +152,10 @@ func (a *agentSocket) connect(ctx context.Context) error { return fmt.Errorf("failed to start the pingee: %w", err) } defer func() { - a.pinger.Stop(ctx) + e := a.pinger.Stop(ctx) + if e != nil { + logrus.Errorf("failed to stop pinger: %s", e) + } }() if closeRootSpanCb, ok := ctx.Value(obs.VXProtoAgentConnect).(func()); ok { @@ -170,9 +178,13 @@ const ( defaultReadTimeout = defaultPingFrequency * 6 ) -func openAgentSocketToInitConnection(ctx context.Context, logger *logrus.Entry, config *ClientInitConfig) (SyncWS, error) { - if config.Type == "browser" || config.Type == "external" { - return nil, fmt.Errorf("connection initialization for the browser type is NYI") +func openAgentSocketToInitConnection( + ctx context.Context, + logger *logrus.Entry, + config *ClientInitConfig, +) (SyncWS, error) { + if config.Type == configTypeBrowser || config.Type == configTypeExternal { + return nil, fmt.Errorf("connection initialization for the %s type is NYI", configTypeBrowser) } dialer := websocket.Dialer{ TLSClientConfig: config.TLSConfig, diff --git a/internal/vxproto/ws_server.go b/internal/vxproto/ws_server.go index 2eed2b90..7283728b 100644 --- a/internal/vxproto/ws_server.go +++ b/internal/vxproto/ws_server.go @@ -109,7 +109,10 @@ func (vxp *vxProto) listenWS( vxp.mutex.Lock() defer vxp.mutex.Unlock() ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) - server.Shutdown(ctx) + e := server.Shutdown(ctx) + if e != nil { + logrus.Errorf("failed to shutdown server: %s", e) + } cancel() for idx, closer := range vxp.closers { @@ -132,7 +135,11 @@ func (vxp *vxProto) listenWS( return server.ListenAndServeTLS("", "") } -func (vxp *vxProto) configureRouter(c ServerAPIVersionsConfig, validatorFactory ConnectionValidatorFactory, logger *logrus.Entry) (http.Handler, error) { +func (vxp *vxProto) configureRouter( + c ServerAPIVersionsConfig, + validatorFactory ConnectionValidatorFactory, + logger *logrus.Entry, +) (http.Handler, error) { policyManagerIterator, err := NewConnectionPolicyManagerIterator(c) if err != nil { return nil, fmt.Errorf("failed to initialize a connection policy manager factory: %w", err) @@ -143,8 +150,18 @@ func (vxp *vxProto) configureRouter(c ServerAPIVersionsConfig, validatorFactory if err != nil { return nil, fmt.Errorf("failed to get a connection policy manager: %w", err) } - if err := vxp.configureVersionHandlers(r, validatorFactory, managerWithVer.Manager, managerWithVer.Version, logger); err != nil { - return nil, fmt.Errorf("failed to configurate the version handlers for version \"%s\": %w", managerWithVer.Version, err) + if err := vxp.configureVersionHandlers( + r, + validatorFactory, + managerWithVer.Manager, + managerWithVer.Version, + logger, + ); err != nil { + return nil, fmt.Errorf( + "failed to configure the version handlers for version \"%s\": %w", + managerWithVer.Version, + err, + ) } } return r, nil @@ -299,7 +316,13 @@ type getConnectionPolicyMiddlewareResult struct { AgentConnectionInfo *AgentConnectionInfo } -func getConnectionPolicy(ctx context.Context, urlPath string, agentID string, idFetcher AgentIDFetcher, policyManager ConnectionPolicyManager) (*getConnectionPolicyMiddlewareResult, error) { +func getConnectionPolicy( + ctx context.Context, + urlPath string, + agentID string, + idFetcher AgentIDFetcher, + policyManager ConnectionPolicyManager, +) (*getConnectionPolicyMiddlewareResult, error) { agentType, err := getAgentTypeFromURL(urlPath) if err != nil { return nil, fmt.Errorf("failed to get the agent type from the URL path: %w", err) @@ -356,7 +379,11 @@ func extractAgentType(r *http.Request) (AgentType, error) { return agentType, nil } -func fetchAgentConnectionInfo(reqCtx context.Context, idFetcher AgentIDFetcher, id string) (*AgentConnectionInfo, error) { +func fetchAgentConnectionInfo( + reqCtx context.Context, + idFetcher AgentIDFetcher, + id string, +) (*AgentConnectionInfo, error) { connInfo, err := idFetcher.GetAgentConnectionInfo(reqCtx, &AgentInfoForIDFetcher{ ID: id, }) @@ -404,9 +431,15 @@ func handleAgentWS( }() // Run ping sender - socket.pinger.Start(r.Context(), socket.ping) + err = socket.pinger.Start(r.Context(), socket.ping) + if err != nil { + log.Errorf("failed to start pinger: %s", err) + } defer func() { - socket.pinger.Stop(r.Context()) + e := socket.pinger.Stop(r.Context()) + if e != nil { + log.Errorf("failed to stop pinger: %s", e) + } }() // Read messages before connection will be closed