Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow for number of scanners to be configurable #762

Merged
merged 1 commit into from
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions cmd/mal/mal.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ var (
outputFlag string
profileFlag bool
quantityIncreasesRiskFlag bool
scannersFlag int
statsFlag bool
thirdPartyFlag bool
verboseFlag bool
Expand Down Expand Up @@ -92,7 +93,7 @@ func showError(err error) {
fmt.Fprintf(os.Stderr, "%s %s\n", emoji, err.Error())
}

//nolint:cyclop // ignore complexity of 40
//nolint:cyclop,gocognit // ignore complexity of 40,100
func main() {
returnCode := ExitOK
defer func() { os.Exit(returnCode) }()
Expand Down Expand Up @@ -249,9 +250,14 @@ func main() {
concurrency = 1
}

maxScanners := scannersFlag
if maxScanners > concurrency {
maxScanners = concurrency
}

var pool *malcontent.ScannerPool
if mc.ScannerPool == nil {
pool, err = malcontent.NewScannerPool(yrs, concurrency)
pool, err = malcontent.NewScannerPool(yrs, maxScanners)
if err != nil {
returnCode = ExitInvalidRules
}
Expand All @@ -264,6 +270,7 @@ func main() {
IgnoreSelf: ignoreSelfFlag,
IgnoreTags: ignoreTags,
IncludeDataFiles: includeDataFiles,
MaxScanners: maxScanners,
MinFileRisk: minFileRisk,
MinRisk: minRisk,
OCI: ociFlag,
Expand Down Expand Up @@ -372,6 +379,12 @@ func main() {
Usage: "Increase file risk score based on behavior quantity",
Destination: &quantityIncreasesRiskFlag,
},
&cli.IntFlag{
Name: "scanners",
Value: runtime.NumCPU(),
Usage: "Number of scanners to create",
Destination: &scannersFlag,
},
&cli.BoolFlag{
Name: "stats",
Aliases: []string{"s"},
Expand Down
2 changes: 1 addition & 1 deletion pkg/action/scan.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ func scanSinglePath(ctx context.Context, c malcontent.Config, path string, ruleF

var pool *malcontent.ScannerPool
if c.ScannerPool == nil {
pool, err = malcontent.NewScannerPool(yrs, c.Concurrency)
pool, err = malcontent.NewScannerPool(yrs, c.MaxScanners)
if err != nil {
return nil, fmt.Errorf("failed to create scanner pool: %w", err)
}
Expand Down
77 changes: 23 additions & 54 deletions pkg/malcontent/malcontent.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import (
"runtime"
"sync"
"sync/atomic"
"time"

yarax "github.com/VirusTotal/yara-x/go"
orderedmap "github.com/wk8/go-ordered-map/v2"
Expand All @@ -34,6 +33,7 @@ type Config struct {
IgnoreSelf bool
IgnoreTags []string
IncludeDataFiles bool
MaxScanners int
MinFileRisk int
MinRisk int
OCI bool
Expand Down Expand Up @@ -175,23 +175,17 @@ func NewScannerPool(rules *yarax.Rules, maxScanners int) (*ScannerPool, error) {
available: make(chan *yarax.Scanner, maxScanners),
maxScanners: int32(maxScanners),
scanners: make([]*yarax.Scanner, 0, maxScanners),
closed: atomic.Bool{},
}

pool.closed.Store(false)

// Create a subset of the maximum number of scanners to avoid contention
initialScanners := maxScanners/2 + 1
for i := 0; i < initialScanners; i++ {
scanner, err := pool.createScanner()
if err != nil {
pool.Cleanup()
return nil, fmt.Errorf("failed to create initial scanner: %w", err)
}
pool.scanners = append(pool.scanners, scanner)
pool.available <- scanner
atomic.AddInt32(&pool.currentCount, 1)
scanner := yarax.NewScanner(rules)
if scanner == nil {
return nil, fmt.Errorf("failed to create scanner")
}

pool.available <- scanner
atomic.AddInt32(&pool.currentCount, 1)

return pool, nil
}

Expand Down Expand Up @@ -236,39 +230,27 @@ func (p *ScannerPool) Get() (*yarax.Scanner, error) {
return nil, fmt.Errorf("scanner pool is closed")
}

// Retrieve an existing scanner
// If none are available, create up to the maximum number of scanners
select {
case scanner := <-p.available:
if scanner == nil {
return nil, fmt.Errorf("received nil scanner from pool")
}
return scanner, nil
case <-time.After(100 * time.Millisecond):
}

// Create a new scanner if we aren't already running the maximum number
p.mu.Lock()
current := atomic.LoadInt32(&p.currentCount)
if current < p.maxScanners {
scanner, err := p.createScanner()
if err != nil {
default:
p.mu.Lock()
if atomic.LoadInt32(&p.currentCount) < p.maxScanners {
scanner, err := p.createScanner()
if err != nil {
p.mu.Unlock()
return nil, fmt.Errorf("create scanner: %w", err)
}
p.scanners = append(p.scanners, scanner)
atomic.AddInt32(&p.currentCount, 1)
p.mu.Unlock()
return nil, fmt.Errorf("create scanner: %w", err)
return scanner, nil
}
p.scanners = append(p.scanners, scanner)
atomic.AddInt32(&p.currentCount, 1)
p.mu.Unlock()
return scanner, nil
}
p.mu.Unlock()

select {
case scanner := <-p.available:
if scanner == nil {
return nil, fmt.Errorf("received nil scanner from pool")
}
return scanner, nil
case <-time.After(10 * time.Second):
return nil, fmt.Errorf("timeout waiting for available scanner")
return <-p.available, nil
}
}

Expand All @@ -277,20 +259,7 @@ func (p *ScannerPool) Put(scanner *yarax.Scanner) {
if scanner == nil || p.closed.Load() {
return
}

select {
case p.available <- scanner:
default:
p.mu.Lock()
defer func() {
p.mu.Unlock()
if atomic.LoadInt32(&p.currentCount) > p.maxScanners/2 {
runtime.GC()
}
}()
scanner.Destroy()
atomic.AddInt32(&p.currentCount, -1)
}
p.available <- scanner
}

// Cleanup destroys all scanners in the pool.
Expand Down
3 changes: 2 additions & 1 deletion pkg/refresh/action.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ func actionRefresh(ctx context.Context) ([]TestData, error) {
c := &malcontent.Config{
Concurrency: runtime.NumCPU(),
IgnoreSelf: false,
MaxScanners: runtime.NumCPU(),
MinFileRisk: 0,
MinRisk: 0,
OCI: false,
Expand All @@ -81,7 +82,7 @@ func actionRefresh(ctx context.Context) ([]TestData, error) {

var pool *malcontent.ScannerPool
if c.ScannerPool == nil {
pool, err = malcontent.NewScannerPool(yrs, c.Concurrency)
pool, err = malcontent.NewScannerPool(yrs, c.MaxScanners)
if err != nil {
return nil, err
}
Expand Down
3 changes: 2 additions & 1 deletion pkg/refresh/diff.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ func diffRefresh(ctx context.Context, rc Config) ([]TestData, error) {
Concurrency: runtime.NumCPU(),
FileRiskChange: td.riskChange,
FileRiskIncrease: td.riskIncrease,
MaxScanners: runtime.NumCPU(),
MinFileRisk: minFileRisk,
MinRisk: minRisk,
QuantityIncreasesRisk: true,
Expand All @@ -207,7 +208,7 @@ func diffRefresh(ctx context.Context, rc Config) ([]TestData, error) {

var pool *malcontent.ScannerPool
if c.ScannerPool == nil {
pool, err = malcontent.NewScannerPool(yrs, c.Concurrency)
pool, err = malcontent.NewScannerPool(yrs, c.MaxScanners)
if err != nil {
return nil, err
}
Expand Down
1 change: 1 addition & 0 deletions pkg/refresh/refresh.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ func newConfig(rc Config) *malcontent.Config {
return &malcontent.Config{
Concurrency: runtime.NumCPU(),
IgnoreTags: []string{"harmless"},
MaxScanners: runtime.NumCPU(),
MinFileRisk: 1,
MinRisk: 1,
QuantityIncreasesRisk: true,
Expand Down
Loading