diff --git a/tools/log4shell/analyze/analyze.go b/tools/log4shell/analyze/analyze.go index dc7a41544..4fc43f9eb 100644 --- a/tools/log4shell/analyze/analyze.go +++ b/tools/log4shell/analyze/analyze.go @@ -15,7 +15,6 @@ package analyze import ( - "archive/zip" "github.com/lunasec-io/lunasec/tools/log4shell/constants" "github.com/lunasec-io/lunasec/tools/log4shell/types" "github.com/lunasec-io/lunasec/tools/log4shell/util" @@ -25,8 +24,11 @@ import ( "strings" ) -func GetJndiLookupHash(zipReader *zip.Reader, filePath string) (fileHash string) { - reader, err := zipReader.Open(constants.JndiLookupClasspath) +func GetJndiLookupHash( + resolveArchiveFile types.ResolveArchiveFile, + filePath string, +) (fileHash string) { + reader, err := resolveArchiveFile(constants.JndiLookupClasspath) if err != nil { log.Debug(). Str("fieName", constants.JndiLookupClasspath). @@ -49,7 +51,11 @@ func GetJndiLookupHash(zipReader *zip.Reader, filePath string) (fileHash string) return } -func ProcessArchiveFile(zipReader *zip.Reader, reader io.Reader, filePath, fileName string) (finding *types.Finding) { +func ProcessArchiveFile( + resolveArchiveFile types.ResolveArchiveFile, + reader io.Reader, + filePath, fileName string, +) (finding *types.Finding) { var ( jndiLookupFileHash string ) @@ -93,7 +99,7 @@ func ProcessArchiveFile(zipReader *zip.Reader, reader io.Reader, filePath, fileN } if VersionIsInRange(archiveName, semverVersion, constants.JndiLookupPatchFileVersions) { - jndiLookupFileHash = GetJndiLookupHash(zipReader, filePath) + jndiLookupFileHash = GetJndiLookupHash(resolveArchiveFile, filePath) } log.Log(). 
diff --git a/tools/log4shell/commands/patch.go b/tools/log4shell/commands/patch.go index ee12cec2b..11322d4ed 100644 --- a/tools/log4shell/commands/patch.go +++ b/tools/log4shell/commands/patch.go @@ -38,6 +38,7 @@ func JavaArchivePatchCommand( forcePatch := c.Bool("force-patch") dryRun := c.Bool("dry-run") + backup := c.Bool("backup") var patchedLibraries []string @@ -64,7 +65,7 @@ func JavaArchivePatchCommand( } } - err = patch.ProcessJavaArchive(finding, dryRun) + err = patch.ProcessJavaArchive(finding, dryRun, backup) if err != nil { log.Error(). Str("path", finding.Path). diff --git a/tools/log4shell/constants/fs.go b/tools/log4shell/constants/fs.go index 464e42800..2d52c5103 100644 --- a/tools/log4shell/constants/fs.go +++ b/tools/log4shell/constants/fs.go @@ -21,3 +21,7 @@ const ( EarFileExt = ".ear" ClassFileExt = ".class" ) + +var ( + CleanupDirs []string +) diff --git a/tools/log4shell/main.go b/tools/log4shell/main.go index a626ac734..ce0372563 100644 --- a/tools/log4shell/main.go +++ b/tools/log4shell/main.go @@ -17,6 +17,7 @@ package main import ( "github.com/lunasec-io/lunasec/tools/log4shell/commands" "github.com/lunasec-io/lunasec/tools/log4shell/constants" + "github.com/lunasec-io/lunasec/tools/log4shell/util" "github.com/rs/zerolog" "github.com/rs/zerolog/log" "github.com/urfave/cli/v2" @@ -48,6 +49,10 @@ func main() { zerolog.SetGlobalLevel(zerolog.InfoLevel) + util.RunOnProcessExit(func() { + util.RemoveCleanupDirs() + }) + globalBoolFlags := map[string]bool{ "verbose": false, "json": false, @@ -187,6 +192,10 @@ func main() { Usage: "Patches findings of libraries vulnerable toLog4Shell by removing the JndiLookup.class file from each.", Before: setGlobalBoolFlags, Flags: []cli.Flag{ + &cli.BoolFlag{ + Name: "backup", + Usage: "Backup each library to path/to/library.jar.bak before overwriting.", + }, &cli.StringSliceFlag{ Name: "exclude", Usage: "Exclude subdirectories from scanning. 
This can be helpful if there are directories which your user does not have access to when starting a scan from `/`.", diff --git a/tools/log4shell/patch/archivepatch.go b/tools/log4shell/patch/archivepatch.go index b80a6dc15..1a7caff6b 100644 --- a/tools/log4shell/patch/archivepatch.go +++ b/tools/log4shell/patch/archivepatch.go @@ -332,7 +332,7 @@ func copyAndFilterFilesFromZip( return } -func ProcessJavaArchive(finding types.Finding, dryRun bool) (err error) { +func ProcessJavaArchive(finding types.Finding, dryRun, backup bool) (err error) { var ( libraryFile *os.File zipReader *zip.Reader @@ -383,6 +383,23 @@ func ProcessJavaArchive(finding types.Finding, dryRun bool) (err error) { return } + if backup { + backupFilePath := fsFile + ".bak" + log.Info(). + Str("libraryFileName", fsFile). + Str("backupFileName", backupFilePath). + Msg("Backing up library file before overwritting.") + _, err = util.CopyFile(fsFile, backupFilePath) + if err != nil { + log.Error(). + Str("libraryFileName", fsFile). + Str("backupFileName", backupFilePath). + Err(err). + Msg("Unable to backup library file.") + return + } + } + _, err = util.CopyFile(filteredLibrary, fsFile) if err != nil { log.Error(). 
diff --git a/tools/log4shell/scan/executablejar.go b/tools/log4shell/scan/executablejar.go
index 923699e1e..8b93f05eb 100644
--- a/tools/log4shell/scan/executablejar.go
+++ b/tools/log4shell/scan/executablejar.go
@@ -17,12 +17,13 @@ package scan
 import (
 	"bytes"
 	"github.com/lunasec-io/lunasec/tools/log4shell/constants"
+	"github.com/lunasec-io/lunasec/tools/log4shell/types"
 	"github.com/rs/zerolog/log"
 	"io"
 	"os"
 )
 
-func readerAtStartOfArchive(path string, file *os.File) (reader io.ReaderAt, offset int64, err error) {
+func readerAtStartOfArchive(path string, file *os.File) (reader types.ReaderAtCloser, offset int64, err error) {
 	// By default, we assume our original file will be our returned reader
 	reader = file
 
@@ -76,7 +77,7 @@ func readerAtStartOfArchive(path string, file *os.File) (reader io.ReaderAt, off
 				Msg("unable to locate start of archive in bash executable jar file")
 			return
 		}
-		reader = bytes.NewReader(fileContents[idx:])
+		reader = types.NopReaderAtCloser(bytes.NewReader(fileContents[idx:]))
 		offset = int64(idx)
 	}
 	return
diff --git a/tools/log4shell/scan/scan.go b/tools/log4shell/scan/scan.go
index 2cb10810a..de53ecd5e 100644
--- a/tools/log4shell/scan/scan.go
+++ b/tools/log4shell/scan/scan.go
@@ -147,9 +147,95 @@ func (s *Log4jDirectoryScanner) scanLocatedArchive(
 	return s.scanArchiveForVulnerableFiles(path, reader, info.Size() - offset)
 }
 
+// maxInMemoryArchiveSize is the largest archive size, in bytes, that is
+// scanned through the zip reader directly. Anything larger is extracted to
+// a temporary directory and scanned from disk.
+const maxInMemoryArchiveSize = 1024 * 1024 * 1024
+
+// getFilesToScan lists the entries of the archive at path that should be
+// scanned, together with a resolver for looking up sibling entries by their
+// archive-relative name.
+//
+// Archives no larger than maxInMemoryArchiveSize are read through zipReader
+// and cleanup is nil. Larger archives are extracted to a temporary directory;
+// the entries and the resolver then refer to the extracted files, and the
+// caller must invoke the returned cleanup to delete that directory.
+func (s *Log4jDirectoryScanner) getFilesToScan(
+	path string,
+	size int64,
+	zipReader *zip.Reader,
+) (
+	filesToScan []types.FileToScan,
+	resolveArchiveFile types.ResolveArchiveFile,
+	cleanup func(),
+	err error,
+) {
+	if size > maxInMemoryArchiveSize {
+		var (
+			tmpPath   string
+			filenames []string
+		)
+
+		_, name := filepath.Split(path)
+		tmpPath, err = os.MkdirTemp(os.TempDir(), name)
+		if err != nil {
+			log.Warn().
+				Str("path", path).
+				Err(err).
+				Msg("unable to create temporary path")
+			return
+		}
+		util.EnsureDirIsCleanedUp(tmpPath)
+		removeTmpPath := func() {
+			os.RemoveAll(tmpPath)
+			util.RemoveDirFromCleanup(tmpPath)
+		}
+
+		filenames, err = util.Unzip(zipReader, tmpPath)
+		if err != nil {
+			log.Warn().
+				Str("path", path).
+				Err(err).
+				Msg("unable to unzip file")
+			// the caller only receives cleanup on success, so the
+			// partially extracted files have to be removed here
+			removeTmpPath()
+			return
+		}
+
+		for _, file := range filenames {
+			dir, extractedFilename := filepath.Split(file)
+
+			fileToScan := &types.DiskFileToScan{
+				Filename: extractedFilename,
+				Path:     dir,
+			}
+			filesToScan = append(filesToScan, fileToScan)
+		}
+
+		// entries are looked up by their archive-relative name, so they
+		// have to be resolved against the extraction directory rather
+		// than the process working directory
+		resolveArchiveFile = func(archivePath string) (io.ReadCloser, error) {
+			return util.ResolveDiskFile(filepath.Join(tmpPath, filepath.FromSlash(archivePath)))
+		}
+		cleanup = removeTmpPath
+		return
+	}
+
+	for _, zipFile := range zipReader.File {
+		fileToScan := &types.ZipFileToScan{
+			File: zipFile,
+		}
+		filesToScan = append(filesToScan, fileToScan)
+	}
+	resolveArchiveFile = util.ResolveZipFile(zipReader)
+	return
+}
+
 func (s *Log4jDirectoryScanner) scanArchiveForVulnerableFiles(
 	path string,
-	reader io.ReaderAt,
+	reader types.ReaderAtCloser,
 	size int64,
 ) (findings []types.Finding) {
 	zipReader, err := zip.NewReader(reader, size)
@@ -161,31 +247,44 @@ func (s *Log4jDirectoryScanner) scanArchiveForVulnerableFiles(
 		return
 	}
 
-	for _, zipFile := range zipReader.File {
-		locatedFindings := s.scanFile(zipReader, path, zipFile)
+	filesToScan, resolveArchiveFile, cleanup, err := s.getFilesToScan(path, size, zipReader)
+	if err != nil {
+		return
+	}
+
+	// if cleanup is specified, then we are reading files from disk
+	// and should close the current zip reader to free up space and
+	// defer a call to cleanup to remove all temporary extracted files
+	if cleanup != nil {
+		reader.Close()
+		defer cleanup()
+	}
+
+	for _, fileToScan := range filesToScan {
+		locatedFindings := s.scanFile(resolveArchiveFile, path, fileToScan)
 		findings = append(findings, locatedFindings...)
 	}
 	return
 }
 
 func (s *Log4jDirectoryScanner) scanFile(
-	zipReader *zip.Reader,
+	resolveArchiveFile types.ResolveArchiveFile,
 	path string,
-	file *zip.File,
+	file types.FileToScan,
 ) (findings []types.Finding) {
 	//log.Debug().
 	//	Str("path", path).
 	//	Str("file", file.Name).
 	//	Msg("Scanning archive file")
 
-	fileExt := util.FileExt(file.Name)
+	fileExt := util.FileExt(file.Name())
 	switch fileExt {
 	case constants.ClassFileExt:
 		if s.onlyScanArchives {
 			return
 		}
 
-		finding := s.scanArchiveFile(zipReader, path, file)
+		finding := s.scanArchiveFile(resolveArchiveFile, path, file)
 		if finding != nil {
 			findings = []types.Finding{*finding}
 		}
@@ -195,7 +294,7 @@ func (s *Log4jDirectoryScanner) scanFile(
 		constants.ZipFileExt,
 		constants.EarFileExt:
 		if s.onlyScanArchives {
-			finding := s.scanArchiveFile(zipReader, path, file)
+			finding := s.scanArchiveFile(resolveArchiveFile, path, file)
 			if finding != nil {
 				findings = []types.Finding{*finding}
 			}
@@ -207,14 +306,14 @@ func (s *Log4jDirectoryScanner) scanFile(
 }
 
 func (s *Log4jDirectoryScanner) scanArchiveFile(
-	zipReader *zip.Reader,
+	resolveArchiveFile types.ResolveArchiveFile,
 	path string,
-	file *zip.File,
+	file types.FileToScan,
 ) (finding *types.Finding) {
-	reader, err := file.Open()
+	reader, err := file.Reader()
 	if err != nil {
 		log.Warn().
-			Str("classFile", file.Name).
+			Str("classFile", file.Name()).
 			Str("path", path).
 			Err(err).
 			Msg("unable to open class file")
@@ -222,17 +321,17 @@ func (s *Log4jDirectoryScanner) scanArchiveFile(
 	}
 	defer reader.Close()
 
-	return s.processArchiveFile(zipReader, reader, path, file.Name)
+	return s.processArchiveFile(resolveArchiveFile, reader, path, file.Name())
 }
 
 func (s *Log4jDirectoryScanner) scanEmbeddedArchive(
 	path string,
-	file *zip.File,
+	file types.FileToScan,
 ) (findings []types.Finding) {
-	reader, err := file.Open()
+	reader, err := file.Reader()
 	if err != nil {
 		log.Warn().
-			Str("classFile", file.Name).
+			Str("classFile", file.Name()).
 			Str("path", path).
 			Err(err).
 			Msg("unable to open embedded archive")
@@ -243,15 +342,16 @@ func (s *Log4jDirectoryScanner) scanEmbeddedArchive(
 	buffer, err := ioutil.ReadAll(reader)
 	if err != nil {
 		log.Warn().
-			Str("classFile", file.Name).
+			Str("classFile", file.Name()).
 			Str("path", path).
 			Err(err).
 			Msg("unable to read embedded archive")
 		return
 	}
+	reader.Close()
 
-	newPath := path + "::" + file.Name
-	archiveReader := bytes.NewReader(buffer)
+	newPath := path + "::" + file.Name()
+	archiveReader := types.NopReaderAtCloser(bytes.NewReader(buffer))
 	archiveSize := int64(len(buffer))
 
 	return s.scanArchiveForVulnerableFiles(newPath, archiveReader, archiveSize)
diff --git a/tools/log4shell/scan/scan_test.go b/tools/log4shell/scan/scan_test.go
index ac3970ae6..d1928b857 100644
--- a/tools/log4shell/scan/scan_test.go
+++ b/tools/log4shell/scan/scan_test.go
@@ -41,6 +41,7 @@ func createNewScanner() (scanner Log4jVulnerableDependencyScanner, err error) {
 }
 
 func BenchmarkScanningForVulnerablePackages(b *testing.B) {
+	b.Skip("benchmark is disabled")
 	b.ReportAllocs()
 
 	scanner, err := createNewScanner()
@@ -54,6 +55,22 @@ func BenchmarkScanningForVulnerablePackages(b *testing.B) {
 	fmt.Printf("Number of findings: %d\n", len(findings))
 }
 
+func BenchmarkScanningForLargeArchives(b *testing.B) {
+	b.ReportAllocs()
+
+	scanner, err := createNewScanner()
+	if err != nil {
+		b.Error(err)
+		return
+	}
+
+	for i := 0; i < b.N; i++ {
+		findings := scanner.Scan([]string{"../test/large-archives"})
+
+		fmt.Printf("Number of findings: %d\n", len(findings))
+	}
+}
+
 func TestForFalsePositiveLibraryFindings(t *testing.T) {
 	scanner, err := createNewScanner()
 
diff --git a/tools/log4shell/scan/scanfile.go b/tools/log4shell/scan/scanfile.go
index a66723188..ec53f88dc 100644
--- a/tools/log4shell/scan/scanfile.go
+++ b/tools/log4shell/scan/scanfile.go
@@ -15,7 +15,6 @@
 package scan
 
 import (
-	"archive/zip"
 	"github.com/blang/semver/v4"
 	"github.com/lunasec-io/lunasec/tools/log4shell/analyze"
 	"github.com/lunasec-io/lunasec/tools/log4shell/constants"
@@ -30,8 +29,8 @@ import (
 
 func IdentifyPotentiallyVulnerableFiles(scanLog4j1 bool, archiveHashLookup types.VulnerableHashLookup) types.ProcessArchiveFile {
 	hashLookup := FilterVulnerableHashLookup(archiveHashLookup, scanLog4j1)
-	return func(zipReader *zip.Reader, reader io.Reader, path, fileName string) (finding *types.Finding) {
-		return identifyPotentiallyVulnerableFile(zipReader, reader, path, fileName, hashLookup)
+	return func(resolveArchiveFile types.ResolveArchiveFile, reader io.Reader, path, fileName string) (finding *types.Finding) {
+		return identifyPotentiallyVulnerableFile(resolveArchiveFile, reader, path, fileName, hashLookup)
 	}
 }
 
@@ -50,7 +49,7 @@ func isVulnerableIfContainsJndiLookup(versions []string) bool {
 }
 
 func identifyPotentiallyVulnerableFile(
-	zipReader *zip.Reader,
+	resolveArchiveFile types.ResolveArchiveFile,
 	reader io.Reader,
 	path, fileName string,
 	hashLookup types.VulnerableHashLookup,
@@ -83,7 +82,7 @@ func identifyPotentiallyVulnerableFile(
 	versions := strings.Split(vulnerableFile.Version, ", ")
 	patchableVersion := isVulnerableIfContainsJndiLookup(versions)
 
-	jndiLookupFileHash := analyze.GetJndiLookupHash(zipReader, path)
+	jndiLookupFileHash := analyze.GetJndiLookupHash(resolveArchiveFile, path)
 	if jndiLookupFileHash != "" {
 		if _, ok := vulnerableFile.VulnerableFileHashLookup[jndiLookupFileHash]; !ok {
 			log.Warn().
diff --git a/tools/log4shell/test/large-archives/.gitignore b/tools/log4shell/test/large-archives/.gitignore
new file mode 100644
index 000000000..4c6615fb2
--- /dev/null
+++ b/tools/log4shell/test/large-archives/.gitignore
@@ -0,0 +1,3 @@
+struts-2.5.28-all/
+large-random-file.bin
+large.jar
diff --git a/tools/log4shell/test/large-archives/create.sh b/tools/log4shell/test/large-archives/create.sh
new file mode 100644
index 000000000..5bc8984a7
--- /dev/null
+++ b/tools/log4shell/test/large-archives/create.sh
@@ -0,0 +1,4 @@
+#!/bin/bash
+
+head -c 1G < /dev/urandom > large-random-file.bin
+zip -q large.jar large-random-file.bin apache-tomcat-8.5.73
diff --git a/tools/log4shell/types/findings.go b/tools/log4shell/types/findings.go
index 1df78a321..519e154b9 100644
--- a/tools/log4shell/types/findings.go
+++ b/tools/log4shell/types/findings.go
@@ -15,11 +15,14 @@
 package types
 
 import (
-	"archive/zip"
 	"io"
 )
 
-type ProcessArchiveFile func(zipReader *zip.Reader, reader io.Reader, path, fileName string) (finding *Finding)
+// ResolveArchiveFile opens the file identified by path for reading. The
+// caller is responsible for closing the returned reader.
+type ResolveArchiveFile func(path string) (io.ReadCloser, error)
+
+type ProcessArchiveFile func(resolveFile ResolveArchiveFile, reader io.Reader, path, file string) (finding *Finding)
 
 type Finding struct {
 	Path string `json:"path"`
diff --git a/tools/log4shell/types/scan.go b/tools/log4shell/types/scan.go
new file mode 100644
index 000000000..c408bd798
--- /dev/null
+++ b/tools/log4shell/types/scan.go
@@ -0,0 +1,80 @@
+// Copyright 2022 by LunaSec (owned by Refinery Labs, Inc)
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+package types
+
+import (
+	"archive/zip"
+	"io"
+	"os"
+	"path/filepath"
+)
+
+// ReaderAtCloser groups the io.ReaderAt and io.Closer interfaces.
+type ReaderAtCloser interface {
+	io.ReaderAt
+	io.Closer
+}
+
+// NopReaderAtCloser returns a ReaderAtCloser with a no-op Close method
+// wrapping the provided ReaderAt r.
+func NopReaderAtCloser(r io.ReaderAt) ReaderAtCloser {
+	return nopCloser{r}
+}
+
+type nopCloser struct {
+	io.ReaderAt
+}
+
+func (nopCloser) Close() error { return nil }
+
+// ZipFileToScan is a FileToScan backed by an entry of a zip archive.
+type ZipFileToScan struct {
+	File *zip.File
+}
+
+// DiskFileToScan is a FileToScan backed by the file Filename in the
+// directory Path.
+type DiskFileToScan struct {
+	Filename string
+	Path     string
+}
+
+// FileToScan is a file that can be identified by name and opened for
+// reading.
+type FileToScan interface {
+	Name() string
+	Reader() (io.ReadCloser, error)
+}
+
+// Name returns the name of the entry as recorded in the zip archive.
+func (s *ZipFileToScan) Name() string {
+	return s.File.Name
+}
+
+// Reader opens the zip entry for reading.
+func (s *ZipFileToScan) Reader() (io.ReadCloser, error) {
+	return s.File.Open()
+}
+
+// Name returns the path of the file on disk, built with the separator of
+// the host operating system.
+func (s *DiskFileToScan) Name() string {
+	return filepath.Join(s.Path, s.Filename)
+}
+
+// Reader opens the file on disk for reading.
+func (s *DiskFileToScan) Reader() (io.ReadCloser, error) {
+	return os.Open(s.Name())
+}
diff --git a/tools/log4shell/util/fs.go b/tools/log4shell/util/fs.go
index 050c6a3f4..406558b62 100644
--- a/tools/log4shell/util/fs.go
+++ b/tools/log4shell/util/fs.go
@@ -17,6 +17,9 @@ package util
 import (
 	"archive/zip"
 	"bytes"
+	"fmt"
+	"github.com/lunasec-io/lunasec/tools/log4shell/constants"
+	"github.com/lunasec-io/lunasec/tools/log4shell/types"
 	"github.com/rs/zerolog/log"
 	"io"
 	"io/ioutil"
@@ -25,6 +28,19 @@ import (
 	"strings"
+	"sync"
 )
 
+// ResolveZipFile returns a resolver that opens files from zipReader.
+func ResolveZipFile(zipReader *zip.Reader) types.ResolveArchiveFile {
+	return func(path string) (io.ReadCloser, error) {
+		return zipReader.Open(path)
+	}
+}
+
+// ResolveDiskFile opens the file at path on the local file system.
+func ResolveDiskFile(path string) (io.ReadCloser, error) {
+	return os.Open(path)
+}
+
 func FileExt(path string) string {
 	return strings.ToLower(filepath.Ext(path))
 }
@@ -110,4 +126,105 @@ func CopyFile(in, out string) (int64, error) {
 	if e != nil { return 0, e }
 	defer o.Close()
 	return io.Copy(o, i)
-}
\ No newline at end of file
+}
+
+// Unzip will decompress a zip archive, moving all files and folders
+// within the zip file (parameter 1) to an output directory (parameter 2).
+// It returns the path of every entry it processed; when an error occurs
+// the paths handled up to and including the failing entry are returned
+// alongside it.
+// from: https://golangcode.com/unzip-files-in-go/
+func Unzip(reader *zip.Reader, dest string) (filenames []string, err error) {
+	for _, f := range reader.File {
+		// Store filename/path for returning and using later on
+		fpath := filepath.Join(dest, f.Name)
+
+		// Check for ZipSlip. More Info: http://bit.ly/2MsjAWE
+		if !strings.HasPrefix(fpath, filepath.Clean(dest)+string(os.PathSeparator)) {
+			return filenames, fmt.Errorf("%s: illegal file path", fpath)
+		}
+
+		filenames = append(filenames, fpath)
+
+		if f.FileInfo().IsDir() {
+			// Make Folder
+			if err = os.MkdirAll(fpath, os.ModePerm); err != nil {
+				return filenames, err
+			}
+			continue
+		}
+
+		if err = extractZipFile(f, fpath); err != nil {
+			return filenames, err
+		}
+	}
+	return filenames, nil
+}
+
+// extractZipFile writes the contents of the zip entry f to fpath, creating
+// any missing parent directories first.
+func extractZipFile(f *zip.File, fpath string) error {
+	if err := os.MkdirAll(filepath.Dir(fpath), os.ModePerm); err != nil {
+		return err
+	}
+
+	rc, err := f.Open()
+	if err != nil {
+		return err
+	}
+	defer rc.Close()
+
+	outFile, err := os.OpenFile(fpath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, f.Mode())
+	if err != nil {
+		return err
+	}
+
+	if _, err = io.Copy(outFile, rc); err != nil {
+		outFile.Close()
+		return err
+	}
+	// report a failed Close as well, the data may not have been written out
+	return outFile.Close()
+}
+
+// cleanupDirsMu guards constants.CleanupDirs, which can be accessed from
+// more than one goroutine.
+var cleanupDirsMu sync.Mutex
+
+// EnsureDirIsCleanedUp registers dir for deletion by RemoveCleanupDirs.
+func EnsureDirIsCleanedUp(dir string) {
+	cleanupDirsMu.Lock()
+	defer cleanupDirsMu.Unlock()
+
+	constants.CleanupDirs = append(constants.CleanupDirs, dir)
+}
+
+// RemoveDirFromCleanup removes dir from the directories registered for
+// deletion.
+func RemoveDirFromCleanup(dir string) {
+	cleanupDirsMu.Lock()
+	defer cleanupDirsMu.Unlock()
+
+	var (
+		newCleanupDirs []string
+	)
+	for _, cleanupDir := range constants.CleanupDirs {
+		if dir == cleanupDir {
+			continue
+		}
+		newCleanupDirs = append(newCleanupDirs, cleanupDir)
+	}
+	constants.CleanupDirs = newCleanupDirs
+}
+
+// RemoveCleanupDirs deletes every directory that is still registered for
+// cleanup.
+func RemoveCleanupDirs() {
+	cleanupDirsMu.Lock()
+	defer cleanupDirsMu.Unlock()
+
+	for _, cleanupDir := range constants.CleanupDirs {
+		os.RemoveAll(cleanupDir)
+	}
+	constants.CleanupDirs = nil
+}
diff --git a/tools/log4shell/util/process.go b/tools/log4shell/util/process.go
index 8bea6ad69..9a4885abb 100644
--- a/tools/log4shell/util/process.go
+++ b/tools/log4shell/util/process.go
@@ -27,3 +27,27 @@ func WaitForProcessExit(callback func()) {
 	close(ch)
 	callback()
 }
+
+// RunOnProcessExit arranges for callback to run when the process receives
+// SIGINT or SIGTERM, after which the process exits with the conventional
+// 128+signal status. It returns immediately; the waiting happens in a
+// background goroutine.
+func RunOnProcessExit(callback func()) {
+	ch := make(chan os.Signal, 2)
+	signal.Notify(ch, syscall.SIGINT, syscall.SIGTERM)
+	go func() {
+		sig := <-ch
+		// stop delivery instead of closing ch: a signal arriving after a
+		// close would make the runtime send on a closed channel
+		signal.Stop(ch)
+		callback()
+
+		// Notify disabled the default "terminate" behavior for these
+		// signals, so the process has to exit explicitly
+		code := 1
+		if s, ok := sig.(syscall.Signal); ok {
+			code = 128 + int(s)
+		}
+		os.Exit(code)
+	}()
+}