diff --git a/fsutil/filesystem.go b/fsutil/filesystem.go index 7d65e2a..292eaca 100644 --- a/fsutil/filesystem.go +++ b/fsutil/filesystem.go @@ -4,13 +4,14 @@ package fsutil import ( "archive/tar" "compress/gzip" + "fmt" "io" + "io/fs" "os" "path/filepath" "strings" "github.com/kolide/kit/env" - "github.com/pkg/errors" ) const ( @@ -93,13 +94,13 @@ func CopyFile(src, dest string) error { func UntarBundle(destination string, source string) error { f, err := os.Open(source) if err != nil { - return errors.Wrap(err, "open download source") + return fmt.Errorf("opening source: %w", err) } defer f.Close() gzr, err := gzip.NewReader(f) if err != nil { - return errors.Wrapf(err, "create gzip reader from %s", source) + return fmt.Errorf("creating gzip reader from %s: %w", source, err) } defer gzr.Close() @@ -110,40 +111,93 @@ func UntarBundle(destination string, source string) error { break } if err != nil { - return errors.Wrap(err, "reading tar file") + return fmt.Errorf("reading tar file: %w", err) } if err := sanitizeExtractPath(filepath.Dir(destination), header.Name); err != nil { - return errors.Wrap(err, "checking filename") + return fmt.Errorf("checking filename: %w", err) } - path := filepath.Join(filepath.Dir(destination), header.Name) + destPath := filepath.Join(filepath.Dir(destination), header.Name) info := header.FileInfo() if info.IsDir() { - if err = os.MkdirAll(path, info.Mode()); err != nil { - return errors.Wrapf(err, "creating directory for tar file: %s", path) + if err = os.MkdirAll(destPath, info.Mode()); err != nil { + return fmt.Errorf("creating directory %s for tar file: %w", destPath, err) } continue } - file, err := os.OpenFile(path, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, info.Mode()) + if err := writeBundleFile(destPath, info.Mode(), tr); err != nil { + return fmt.Errorf("writing file: %w", err) + } + } + return nil +} + +// UntarBundleWithRequiredFilePermission performs the same operation as UntarBundle, +// but enforces `requiredFilePerm` for all files in the bundle. +func UntarBundleWithRequiredFilePermission(destination string, source string, requiredFilePerm fs.FileMode) error { + f, err := os.Open(source) + if err != nil { + return fmt.Errorf("opening source: %w", err) + } + defer f.Close() + + gzr, err := gzip.NewReader(f) + if err != nil { + return fmt.Errorf("creating gzip reader from %s: %w", source, err) + } + defer gzr.Close() + + tr := tar.NewReader(gzr) + for { + header, err := tr.Next() + if err == io.EOF { + break + } if err != nil { - return errors.Wrapf(err, "open file %s", path) + return fmt.Errorf("reading tar file: %w", err) + } + + if err := sanitizeExtractPath(filepath.Dir(destination), header.Name); err != nil { + return fmt.Errorf("checking filename: %w", err) } - defer file.Close() - if _, err := io.Copy(file, tr); err != nil { - return errors.Wrapf(err, "copy tar %s to destination %s", header.FileInfo().Name(), path) + + destPath := filepath.Join(filepath.Dir(destination), header.Name) + info := header.FileInfo() + if info.IsDir() { + if err = os.MkdirAll(destPath, info.Mode()); err != nil { + return fmt.Errorf("creating directory %s for tar file: %w", destPath, err) + } + continue + } + + if err := writeBundleFile(destPath, requiredFilePerm, tr); err != nil { + return fmt.Errorf("writing file: %w", err) } } return nil } +func writeBundleFile(destPath string, perm fs.FileMode, srcReader io.Reader) error { + file, err := os.OpenFile(destPath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, perm) + if err != nil { + return fmt.Errorf("opening %s: %w", destPath, err) + } + defer file.Close() + if _, err := io.Copy(file, srcReader); err != nil { + return fmt.Errorf("copying to %s: %w", destPath, err) + } + + return nil +} + // sanitizeExtractPath checks that the supplied extraction path is nor // vulnerable to zip slip attacks. See https://snyk.io/research/zip-slip-vulnerability func sanitizeExtractPath(filePath string, destination string) error { destpath := filepath.Join(destination, filePath) if !strings.HasPrefix(destpath, filepath.Clean(destination)+string(os.PathSeparator)) { - return errors.Errorf("%s: illegal file path", filePath) + return fmt.Errorf("%s: illegal file path", filePath) } return nil } diff --git a/fsutil/filesystem_test.go b/fsutil/filesystem_test.go index f420523..e420472 100644 --- a/fsutil/filesystem_test.go +++ b/fsutil/filesystem_test.go @@ -1,11 +1,145 @@ package fsutil import ( + "archive/tar" + "compress/gzip" + "fmt" + "io" + "io/fs" + "os" + "path/filepath" + "strings" "testing" "github.com/stretchr/testify/require" ) +func TestUntarBundle(t *testing.T) { + t.Parallel() + + // Create tarball contents + originalDir := t.TempDir() + topLevelFile := filepath.Join(originalDir, "testfile.txt") + var topLevelFileMode fs.FileMode = 0655 + require.NoError(t, os.WriteFile(topLevelFile, []byte("test1"), topLevelFileMode)) + internalDir := filepath.Join(originalDir, "some", "path", "to") + var nestedFileMode fs.FileMode = 0755 + require.NoError(t, os.MkdirAll(internalDir, nestedFileMode)) + nestedFile := filepath.Join(internalDir, "anotherfile.txt") + require.NoError(t, os.WriteFile(nestedFile, []byte("test2"), nestedFileMode)) + + // Create test tarball + tarballDir := t.TempDir() + tarballFile := filepath.Join(tarballDir, "test.gz") + createTar(t, tarballFile, originalDir) + + // Confirm we can untar the tarball successfully + newDir := t.TempDir() + require.NoError(t, UntarBundle(filepath.Join(newDir, "anything"), tarballFile)) + + // Confirm the tarball has the contents we expect + newTopLevelFile := filepath.Join(newDir, filepath.Base(topLevelFile)) + require.FileExists(t, newTopLevelFile) + newNestedFile := filepath.Join(newDir, "some", "path", "to", filepath.Base(nestedFile)) + require.FileExists(t, newNestedFile) + + // Confirm each file retained its original permissions + topLevelFileInfo, err := os.Stat(newTopLevelFile) + require.NoError(t, err) + require.Equal(t, topLevelFileMode, topLevelFileInfo.Mode()) + nestedFileInfo, err := os.Stat(newNestedFile) + require.NoError(t, err) + require.Equal(t, nestedFileMode, nestedFileInfo.Mode()) +} + +func TestUntarBundleWithRequiredFilePermission(t *testing.T) { + t.Parallel() + + // Create tarball contents + originalDir := t.TempDir() + topLevelFile := filepath.Join(originalDir, "testfile.txt") + require.NoError(t, os.WriteFile(topLevelFile, []byte("test1"), 0655)) + internalDir := filepath.Join(originalDir, "some", "path", "to") + require.NoError(t, os.MkdirAll(internalDir, 0744)) + nestedFile := filepath.Join(internalDir, "anotherfile.txt") + require.NoError(t, os.WriteFile(nestedFile, []byte("test2"), 0744)) + + // Create test tarball + tarballDir := t.TempDir() + tarballFile := filepath.Join(tarballDir, "test.gz") + createTar(t, tarballFile, originalDir) + + // Confirm we can untar the tarball successfully + newDir := t.TempDir() + var requiredFileMode fs.FileMode = 0755 + require.NoError(t, UntarBundleWithRequiredFilePermission(filepath.Join(newDir, "anything"), tarballFile, requiredFileMode)) + + // Confirm the tarball has the contents we expect + newTopLevelFile := filepath.Join(newDir, filepath.Base(topLevelFile)) + require.FileExists(t, newTopLevelFile) + newNestedFile := filepath.Join(newDir, "some", "path", "to", filepath.Base(nestedFile)) + require.FileExists(t, newNestedFile) + + // Require that both files have the required permission 0755 + topLevelFileInfo, err := os.Stat(newTopLevelFile) + require.NoError(t, err) + require.Equal(t, requiredFileMode, topLevelFileInfo.Mode()) + nestedFileInfo, err := os.Stat(newNestedFile) + require.NoError(t, err) + require.Equal(t, requiredFileMode, nestedFileInfo.Mode()) +} + +// createTar is a helper to create a test tar +func createTar(t *testing.T, createLocation string, sourceDir string) { + tarballFile, err := os.Create(createLocation) + require.NoError(t, err) + defer tarballFile.Close() + + gzw := gzip.NewWriter(tarballFile) + defer gzw.Close() + + tw := tar.NewWriter(gzw) + defer tw.Close() + + require.NoError(t, filepath.Walk(sourceDir, func(path string, info fs.FileInfo, err error) error { + if err != nil { + return err + } + + srcInfo, err := os.Lstat(path) + if os.IsNotExist(err) { + return fmt.Errorf("error adding %s to tarball: %w", path, err) + } + + hdr, err := tar.FileInfoHeader(srcInfo, path) + if err != nil { + return fmt.Errorf("error creating tar header: %w", err) + } + hdr.Name = strings.TrimPrefix(path, sourceDir+"/") + + if err := tw.WriteHeader(hdr); err != nil { + return fmt.Errorf("error writing tar header: %w", err) + } + + if !srcInfo.Mode().IsRegular() { + // Don't open/copy over directories + return nil + } + + srcFile, err := os.Open(path) + if err != nil { + return fmt.Errorf("error opening file to add to tarball: %w", err) + } + defer srcFile.Close() + + if _, err := io.Copy(tw, srcFile); err != nil { + return fmt.Errorf("error copying file %s to tarball: %w", path, err) + } + + return nil + })) +} + func TestSanitizeExtractPath(t *testing.T) { t.Parallel()