Skip to content

Commit

Permalink
Add unit tests for regowriter
Browse files Browse the repository at this point in the history
Rewrite TestRegoRewriter to test AddLib and AddEntryPoint separately
Add go113 error for regowriter error assertion
Minor refactoring for rewriteImportPath

Signed-off-by: Becky Huang <beckyhd@google.com>
  • Loading branch information
becky-hd committed Dec 9, 2021
1 parent 02ea899 commit 1868dfa
Show file tree
Hide file tree
Showing 4 changed files with 453 additions and 67 deletions.
9 changes: 9 additions & 0 deletions constraint/pkg/regorewriter/errors.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package regorewriter

import (
"errors"
"fmt"
"io"
"strings"
Expand Down Expand Up @@ -41,3 +42,11 @@ func (errs Errors) Format(s fmt.State, verb rune) {
_, _ = fmt.Fprintf(s, "%q", errs.Error())
}
}

var (
ErrInvalidModule = errors.New("invalid module")
ErrInvalidImport = errors.New("invalid import")
ErrInvalidLibs = errors.New("invalid lib prefix")
ErrDataReferences = errors.New("invalid Data References")
ErrReadingFile = errors.New("error reading file")
)
48 changes: 21 additions & 27 deletions constraint/pkg/regorewriter/regorewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ type RegoRewriter struct {

// New returns a new RegoRewriter
// args:
// it - the PackageTransformer that will be used for updating the path
// pt - the PackageTransformer that will be used for updating the path
// libs - a list of package prefixes that are allowed for library use
// externs - a list of packages that the rego is allowed to reference but not declared in any libs
func New(pt PackageTransformer, libs []string, externs []string) (*RegoRewriter, error) {
Expand All @@ -64,7 +64,7 @@ func New(pt PackageTransformer, libs []string, externs []string) (*RegoRewriter,
func (r *RegoRewriter) add(path, src string, slice *[]*Module) error {
m, err := ast.ParseModule(path, src)
if err != nil {
return err
return fmt.Errorf("%w: %v", ErrInvalidModule, err)
}

r.addModule(path, m, slice)
Expand Down Expand Up @@ -95,7 +95,7 @@ func (r *RegoRewriter) addTestDir(testDirPath string) error {
glog.V(vLog).Infof("Walking test dir %s", testDirPath)
walkFn := func(path string, info os.FileInfo, err error) error {
if err != nil {
return fmt.Errorf("walk error on path %s: %w", path, err)
return fmt.Errorf("%w: walk error on path %s: %v", ErrReadingFile, path, err)
}
if info.IsDir() {
return nil
Expand All @@ -107,7 +107,7 @@ func (r *RegoRewriter) addTestDir(testDirPath string) error {
glog.V(vLog).Infof("reading %s", path)
bytes, err := ioutil.ReadFile(path)
if err != nil {
return err
return fmt.Errorf("%w: %v", ErrReadingFile, err)
}

r.testData = append(r.testData, &TestData{FilePath: FilePath{path: path}, content: bytes})
Expand All @@ -126,7 +126,7 @@ func (r *RegoRewriter) addFileFromFs(path string, slice *[]*Module) error {

bytes, err := ioutil.ReadFile(path)
if err != nil {
return err
return fmt.Errorf("%w: %v", ErrReadingFile, err)
}
return r.add(path, string(bytes), slice)
}
Expand All @@ -141,13 +141,13 @@ func (r *RegoRewriter) addFileFromFs(path string, slice *[]*Module) error {
func (r *RegoRewriter) addPathFromFs(path string, slice *[]*Module) error {
fileStat, err := os.Stat(path)
if err != nil {
return err
return fmt.Errorf("%w: %v", ErrReadingFile, err)
}

if fileStat.IsDir() {
infos, err := ioutil.ReadDir(path)
if err != nil {
return err
return fmt.Errorf("%w: %v", ErrReadingFile, err)
}

// handle test dirs
Expand Down Expand Up @@ -223,7 +223,7 @@ func (r *RegoRewriter) checkLibPackages() error {
for _, mod := range r.libs {
path := mod.Module.Package.Path
if !r.allowedLibPackage(path) {
return fmt.Errorf("path %s not found in lib prefixes", path)
return fmt.Errorf("%w: path %s not found in lib prefixes", ErrInvalidLibs, path)
}
}
return nil
Expand Down Expand Up @@ -275,11 +275,11 @@ func (r *RegoRewriter) checkImport(i *ast.Import) error {

importRef, ok := i.Path.Value.(ast.Ref)
if !ok {
return fmt.Errorf("got reference of type %T, want %T", i.Path.Value, ast.Ref{})
return fmt.Errorf("%w: got reference of type %T, want %T", ErrInvalidImport, i.Path.Value, ast.Ref{})
}

if isSubRef(inputRefPrefix, importRef) {
return fmt.Errorf("bad import")
return fmt.Errorf("%w: bad import", ErrInvalidImport)
}

for _, libPrefix := range r.allowedLibPrefixes {
Expand All @@ -288,7 +288,7 @@ func (r *RegoRewriter) checkImport(i *ast.Import) error {
}
}

return fmt.Errorf("bad import")
return fmt.Errorf("%w: bad import", ErrInvalidImport)
}

// checkDataReferences checks that all data references are directed to allowed lib prefixes or
Expand All @@ -312,7 +312,7 @@ func (r *RegoRewriter) checkDataReferences() error {
})
}
if errs != nil {
return fmt.Errorf("check refs failed on module %s: %w", m.FilePath, errs)
return fmt.Errorf("%w: check refs failed on module %s: %v", ErrDataReferences, m.FilePath, errs)
}
return nil
})
Expand Down Expand Up @@ -356,17 +356,17 @@ func (r *RegoRewriter) rewriteDataRef(ref ast.Ref) ast.Ref {
}

// rewriteImportPath updates an import path to the new value.
func (r *RegoRewriter) rewriteImportPath(path *ast.Term) (*ast.Term, error) {
func (r *RegoRewriter) rewriteImportPath(path *ast.Term) error {
glog.V(vLogDetail).Infof("import: %s %#v", path, path)
for _, t := range path.Value.(ast.Ref) {
glog.V(vLogDetail).Infof(" term: %s %#v %#v", t, t, reflect.TypeOf(t.Value).String())
}
pathRef, ok := path.Value.(ast.Ref)
if !ok {
return nil, fmt.Errorf("got reference of type %T, want %T", path.Value, ast.Ref{})
return fmt.Errorf("got reference of type %T, want %T", path.Value, ast.Ref{})
}

if !r.refNeedsRewrite(pathRef) {
return path, nil
}

return ast.NewTerm(r.rewriteDataRef(pathRef)), nil
path.Value = r.rewriteDataRef(pathRef)
return nil
}

// Rewrite will check the input source and update the package paths and refs as appropriate.
Expand All @@ -383,13 +383,7 @@ func (r *RegoRewriter) Rewrite() (*Sources, error) {
// libs, entryPoints - update import and other refs
err := r.forAllModules(func(mod *Module) error {
for _, i := range mod.Module.Imports {
glog.V(vLogDetail).Infof("import: %s %#v", i.Path, i.Path)
for _, t := range i.Path.Value.(ast.Ref) {
glog.V(vLogDetail).Infof(" term: %s %#v %#v", t, t, reflect.TypeOf(t.Value).String())
}
var err error
i.Path, err = r.rewriteImportPath(i.Path)
if err != nil {
if err := r.rewriteImportPath(i.Path); err != nil {
return err
}
}
Expand Down
Loading

0 comments on commit 1868dfa

Please sign in to comment.