Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(misconf): pass options to Rego scanner as is #7529

Merged
merged 1 commit into from
Sep 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions pkg/iac/rego/load.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ func (s *Scanner) loadEmbedded() error {
return nil
}

func (s *Scanner) LoadPolicies(enableEmbeddedLibraries, enableEmbeddedPolicies bool, srcFS fs.FS, paths []string, readers []io.Reader) error {
func (s *Scanner) LoadPolicies(srcFS fs.FS) error {

if s.policies == nil {
s.policies = make(map[string]*ast.Module)
Expand All @@ -90,28 +90,28 @@ func (s *Scanner) LoadPolicies(enableEmbeddedLibraries, enableEmbeddedPolicies b
return err
}

if enableEmbeddedPolicies {
if s.includeEmbeddedPolicies {
s.policies = lo.Assign(s.policies, s.embeddedChecks)
}

if enableEmbeddedLibraries {
if s.includeEmbeddedLibraries {
s.policies = lo.Assign(s.policies, s.embeddedLibs)
}

var err error
if len(paths) > 0 {
loaded, err := LoadPoliciesFromDirs(srcFS, paths...)
if len(s.policyDirs) > 0 {
loaded, err := LoadPoliciesFromDirs(srcFS, s.policyDirs...)
if err != nil {
return fmt.Errorf("failed to load rego checks from %s: %w", paths, err)
return fmt.Errorf("failed to load rego checks from %s: %w", s.policyDirs, err)
}
for name, policy := range loaded {
s.policies[name] = policy
}
s.logger.Debug("Checks from disk are loaded", log.Int("count", len(loaded)))
}

if len(readers) > 0 {
loaded, err := s.loadPoliciesFromReaders(readers)
if len(s.policyReaders) > 0 {
loaded, err := s.loadPoliciesFromReaders(s.policyReaders)
if err != nil {
return fmt.Errorf("failed to load rego checks from reader(s): %w", err)
}
Expand Down Expand Up @@ -143,7 +143,7 @@ func (s *Scanner) LoadPolicies(enableEmbeddedLibraries, enableEmbeddedPolicies b
}
s.store = store

return s.compilePolicies(srcFS, paths)
return s.compilePolicies(srcFS, s.policyDirs)
}

func (s *Scanner) fallbackChecks(compiler *ast.Compiler) {
Expand Down
49 changes: 31 additions & 18 deletions pkg/iac/rego/load_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"bytes"
"embed"
"fmt"
"io"
"log/slog"
"strings"
"testing"
Expand All @@ -16,7 +15,6 @@ import (

checks "github.com/aquasecurity/trivy-checks"
"github.com/aquasecurity/trivy/pkg/iac/rego"
"github.com/aquasecurity/trivy/pkg/iac/scanners/options"
"github.com/aquasecurity/trivy/pkg/iac/types"
"github.com/aquasecurity/trivy/pkg/log"
)
Expand All @@ -33,10 +31,11 @@ func Test_RegoScanning_WithSomeInvalidPolicies(t *testing.T) {
slog.SetDefault(log.New(log.NewHandler(&debugBuf, nil)))
scanner := rego.NewScanner(
types.SourceDockerfile,
options.ScannerWithRegoErrorLimits(0),
rego.WithRegoErrorLimits(0),
rego.WithPolicyDirs("."),
)

err := scanner.LoadPolicies(false, false, testEmbedFS, []string{"."}, nil)
err := scanner.LoadPolicies(testEmbedFS)
require.ErrorContains(t, err, `want (one of): ["Cmd" "EndLine" "Flags" "JSON" "Original" "Path" "Stage" "StartLine" "SubCmd" "Value"]`)
assert.Contains(t, debugBuf.String(), "Error(s) occurred while loading checks")
})
Expand All @@ -46,10 +45,11 @@ func Test_RegoScanning_WithSomeInvalidPolicies(t *testing.T) {
slog.SetDefault(log.New(log.NewHandler(&debugBuf, nil)))
scanner := rego.NewScanner(
types.SourceDockerfile,
options.ScannerWithRegoErrorLimits(1),
rego.WithRegoErrorLimits(1),
rego.WithPolicyDirs("."),
)

err := scanner.LoadPolicies(false, false, testEmbedFS, []string{"."}, nil)
err := scanner.LoadPolicies(testEmbedFS)
require.NoError(t, err)

assert.Contains(t, debugBuf.String(), "Error occurred while parsing\tfile_path=\"testdata/policies/invalid.rego\" err=\"testdata/policies/invalid.rego:7")
Expand All @@ -64,9 +64,13 @@ package mypackage
deny {
input.evil == "foo bar"
}`
scanner := rego.NewScanner(types.SourceJSON)
scanner := rego.NewScanner(
types.SourceJSON,
rego.WithPolicyDirs("."),
rego.WithPolicyReader(strings.NewReader(check)),
)

err := scanner.LoadPolicies(false, false, fstest.MapFS{}, []string{"."}, []io.Reader{strings.NewReader(check)})
err := scanner.LoadPolicies(fstest.MapFS{})
assert.ErrorContains(t, err, "could not find schema \"fooschema\"")
})

Expand All @@ -79,15 +83,19 @@ package mypackage
deny {
input.evil == "foo bar"
}`
scanner := rego.NewScanner(types.SourceJSON)
scanner := rego.NewScanner(
types.SourceJSON,
rego.WithPolicyDirs("."),
rego.WithPolicyReader(strings.NewReader(check)),
)

fsys := fstest.MapFS{
"schemas/fooschema.json": &fstest.MapFile{
Data: []byte("bad json"),
},
}

err := scanner.LoadPolicies(false, false, fsys, []string{"."}, []io.Reader{strings.NewReader(check)})
err := scanner.LoadPolicies(fsys)
assert.ErrorContains(t, err, "could not parse schema \"fooschema\"")
})

Expand All @@ -97,8 +105,12 @@ deny {
deny {
input.evil == "foo bar"
}`
scanner := rego.NewScanner(types.SourceJSON)
err := scanner.LoadPolicies(false, false, fstest.MapFS{}, []string{"."}, []io.Reader{strings.NewReader(check)})
scanner := rego.NewScanner(
types.SourceJSON,
rego.WithPolicyDirs("."),
rego.WithPolicyReader(strings.NewReader(check)),
)
err := scanner.LoadPolicies(fstest.MapFS{})
require.NoError(t, err)
})

Expand Down Expand Up @@ -184,8 +196,9 @@ deny {
t.Run(tt.name, func(t *testing.T) {
scanner := rego.NewScanner(
types.SourceDockerfile,
options.ScannerWithRegoErrorLimits(0),
options.ScannerWithEmbeddedPolicies(false),
rego.WithRegoErrorLimits(0),
rego.WithEmbeddedPolicies(false),
rego.WithPolicyDirs("."),
)

tt.files["schemas/fooschema.json"] = &fstest.MapFile{
Expand All @@ -200,9 +213,8 @@ deny {
}`),
}

fsys := fstest.MapFS(tt.files)
checks.EmbeddedPolicyFileSystem = embeddedChecksFS
err := scanner.LoadPolicies(false, false, fsys, []string{"."}, nil)
err := scanner.LoadPolicies(fstest.MapFS(tt.files))

if tt.expectedErr != "" {
assert.ErrorContains(t, err, tt.expectedErr)
Expand Down Expand Up @@ -244,8 +256,9 @@ deny {

scanner := rego.NewScanner(
types.SourceDockerfile,
options.ScannerWithEmbeddedPolicies(false),
rego.WithEmbeddedPolicies(false),
rego.WithPolicyDirs("."),
)
err := scanner.LoadPolicies(false, false, fsys, []string{"."}, nil)
err := scanner.LoadPolicies(fsys)
require.Error(t, err)
}
108 changes: 108 additions & 0 deletions pkg/iac/rego/options.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
package rego

import (
"io"
"io/fs"

"github.com/aquasecurity/trivy/pkg/iac/scanners/options"
)

func WithPolicyReader(readers ...io.Reader) options.ScannerOption {
return func(s options.ConfigurableScanner) {
if ss, ok := s.(*Scanner); ok {
ss.policyReaders = readers
}
}
}

func WithEmbeddedPolicies(include bool) options.ScannerOption {
return func(s options.ConfigurableScanner) {
if ss, ok := s.(*Scanner); ok {
ss.includeEmbeddedPolicies = include
}
}
}

func WithEmbeddedLibraries(include bool) options.ScannerOption {
return func(s options.ConfigurableScanner) {
if ss, ok := s.(*Scanner); ok {
ss.includeEmbeddedLibraries = include
}
}
}

// WithTrace specifies an io.Writer for trace logs (mainly rego tracing) - if not set, they are discarded
func WithTrace(w io.Writer) options.ScannerOption {
return func(s options.ConfigurableScanner) {
if ss, ok := s.(*Scanner); ok {
ss.traceWriter = w
}
}
}

func WithPerResultTracing(enabled bool) options.ScannerOption {
return func(s options.ConfigurableScanner) {
if ss, ok := s.(*Scanner); ok {
ss.tracePerResult = enabled
}
}
}

func WithPolicyDirs(paths ...string) options.ScannerOption {
return func(s options.ConfigurableScanner) {
if ss, ok := s.(*Scanner); ok {
ss.policyDirs = paths
}
}
}

func WithDataDirs(paths ...string) options.ScannerOption {
return func(s options.ConfigurableScanner) {
if ss, ok := s.(*Scanner); ok {
ss.dataDirs = paths
}
}
}

// WithPolicyNamespaces - namespaces which indicate rego policies containing enforced rules
func WithPolicyNamespaces(namespaces ...string) options.ScannerOption {
return func(s options.ConfigurableScanner) {
if ss, ok := s.(*Scanner); ok {
for _, namespace := range namespaces {
ss.ruleNamespaces[namespace] = struct{}{}
}
}
}
}

func WithPolicyFilesystem(fsys fs.FS) options.ScannerOption {
return func(s options.ConfigurableScanner) {
if ss, ok := s.(*Scanner); ok {
ss.policyFS = fsys
}
}
}

func WithDataFilesystem(fsys fs.FS) options.ScannerOption {
return func(s options.ConfigurableScanner) {
if ss, ok := s.(*Scanner); ok {
ss.dataFS = fsys
}
}
}

func WithRegoErrorLimits(limit int) options.ScannerOption {
return func(s options.ConfigurableScanner) {
if ss, ok := s.(*Scanner); ok {
ss.regoErrorLimit = limit
}
}
}

func WithCustomSchemas(schemas map[string][]byte) options.ScannerOption {
return func(s options.ConfigurableScanner) {
if ss, ok := s.(*Scanner); ok {
ss.customSchemas = schemas
}
}
}
Loading