Skip to content

Commit

Permalink
feat: Refactor Updater's file package (#611)
Browse files Browse the repository at this point in the history
* break CopyFile into separate functions

* break overwrite flag into two functions

* fix comment for CopyFileOverwrite

* small tweaks

* tests for file package

* fix linux build

* remove todo

* explain why we continue even on error.

* empty commit for testing
  • Loading branch information
BinaryFissionGames committed Aug 4, 2022
1 parent f837886 commit fb95c4d
Show file tree
Hide file tree
Showing 7 changed files with 148 additions and 91 deletions.
2 changes: 1 addition & 1 deletion updater/internal/action/file_action.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ func (c CopyFileAction) Rollback() error {

// join the relative path to the backup directory to get the location of the backup path
backupFilePath := filepath.Join(c.backupDir, c.FromPathRel)
if err := file.CopyFile(c.logger.Named("copy-file"), backupFilePath, c.ToPath, true, true); err != nil {
if err := file.CopyFileRollback(c.logger.Named("copy-file"), backupFilePath, c.ToPath); err != nil {
return fmt.Errorf("failed to copy file: %w", err)
}

Expand Down
121 changes: 76 additions & 45 deletions updater/internal/file/file.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package file

import (
"errors"
"fmt"
"io"
"io/fs"
Expand All @@ -24,9 +25,78 @@ import (
"go.uber.org/zap"
)

// CopyFile copies the file from pathIn to pathOut.
// If the file does not exist, it is created. If the file does exist, it is truncated before writing.
func CopyFile(logger *zap.Logger, pathIn, pathOut string, overwrite bool, useInFilePermBackup bool) error {
// CopyFileOverwrite copies the file from pathIn to pathOut.
// The output file is created if it does not exist.
// If the output file does exist, it is removed, then written from the input file, preserving the output file's mode.
func CopyFileOverwrite(logger *zap.Logger, pathIn, pathOut string) error {
fileMode := fs.FileMode(0600)
pathOutClean := filepath.Clean(pathOut)

// Try to save existing file's permissions
outFileInfo, _ := os.Stat(pathOutClean)
if outFileInfo != nil {
fileMode = outFileInfo.Mode()
}

pathInClean := filepath.Clean(pathIn)
// If the input file cannot be opened for some reason, do NOT delete the file
if _, err := os.Stat(pathInClean); err != nil {
return fmt.Errorf("failed to stat input file: %w", err)
}

// Remove old file to prevent issues with mac
if err := os.Remove(pathOutClean); err != nil {
logger.Debug("Failed to remove output file", zap.Error(err))
}

return copyFileInternal(logger, pathIn, pathOut, os.O_CREATE|os.O_WRONLY, fileMode)
}

// CopyFileNoOverwrite copies the file from pathIn to pathOut, preserving the input file's mode.
// If the output file already exists, this function returns an error.
func CopyFileNoOverwrite(logger *zap.Logger, pathIn, pathOut string) error {
pathInClean := filepath.Clean(pathIn)

// Use the new file's permissions and fail if there's an issue (want to fail for backup)
inFileInfo, err := os.Stat(pathInClean)
if err != nil {
return fmt.Errorf("failed to retrieve fileinfo for input file: %w", err)
}

// the os.O_EXCL flag will make OpenFile error if the file already exists
return copyFileInternal(logger, pathIn, pathOut, os.O_EXCL|os.O_CREATE|os.O_WRONLY, inFileInfo.Mode())
}

// CopyFileRollback copies the file to the file from pathIn to pathOut, preserving the input file's mode if possible
// Used to perform a rollback
func CopyFileRollback(logger *zap.Logger, pathIn, pathOut string) error {
// Default to 0600 if we can't determine the input file's mode
fileMode := fs.FileMode(0600)
pathInClean := filepath.Clean(pathIn)
// Use the backup file's permissions as a backup and don't fail on error (best chance for rollback)
inFileInfo, err := os.Stat(pathInClean)
switch {
case errors.Is(err, os.ErrNotExist):
return fmt.Errorf("input file does not exist: %w", err)
case err != nil:
// Even though we failed to stat, we'll continue in this case to give the best chance
// of rolling back successfully.
logger.Error("failed to retrieve fileinfo for input file", zap.Error(err))
default:
fileMode = inFileInfo.Mode()
}

pathOutClean := filepath.Clean(pathOut)
// Remove old file to prevent issues with mac
if err = os.Remove(pathOutClean); err != nil {
logger.Debug("Failed to remove output file", zap.Error(err))
}

return copyFileInternal(logger, pathIn, pathOut, os.O_CREATE|os.O_WRONLY, fileMode)
}

// copyFileInternal copies the file at pathIn to pathOut, using the provided flags and file mode
func copyFileInternal(logger *zap.Logger, pathIn, pathOut string, outFlags int, outMode fs.FileMode) error {
pathInClean := filepath.Clean(pathIn)

// Open the input file for reading.
Expand All @@ -37,66 +107,27 @@ func CopyFile(logger *zap.Logger, pathIn, pathOut string, overwrite bool, useInF
defer func() {
err := inFile.Close()
if err != nil {
logger.Info("Failed to close input file", zap.Error(err))
logger.Error("Failed to close input file", zap.Error(err))
}
}()

pathOutClean := filepath.Clean(pathOut)
fileMode := fs.FileMode(0600)
flags := os.O_CREATE | os.O_WRONLY
if overwrite {
// If we are OK to overwrite, we will truncate the file on open
flags |= os.O_TRUNC

// Try to save old file's permissions
outFileInfo, _ := os.Stat(pathOutClean)
if outFileInfo != nil {
fileMode = outFileInfo.Mode()
} else if useInFilePermBackup {
// Use the new file's permissions as a backup and don't fail on error (best chance for rollback)
inFileInfo, err := inFile.Stat()
switch {
case err != nil:
logger.Error("failed to retrieve fileinfo for input file", zap.Error(err))
case inFileInfo != nil:
fileMode = inFileInfo.Mode()
}
}

// Remove old file to prevent issues with mac
if err = os.Remove(pathOutClean); err != nil {
logger.Debug("Failed to remove output file", zap.Error(err))
}
} else {
// This flag will make OpenFile error if the file already exists
flags |= os.O_EXCL

// Use the new file's permissions and fail if there's an issue (want to fail for backup)
inFileInfo, err := inFile.Stat()
if err != nil {
return fmt.Errorf("failed to retrive fileinfo for input file: %w", err)
}

fileMode = inFileInfo.Mode()
}

// Open the output file, creating it if it does not exist and truncating it.
//#nosec G304 -- out file is cleaned; this is a general purpose copy function
outFile, err := os.OpenFile(pathOutClean, flags, fileMode)
outFile, err := os.OpenFile(pathOutClean, outFlags, outMode)
if err != nil {
return fmt.Errorf("failed to open output file: %w", err)
}
defer func() {
err := outFile.Close()
if err != nil {
logger.Info("Failed to close output file", zap.Error(err))
logger.Error("Failed to close output file", zap.Error(err))
}
}()

// Copy the input file to the output file.
if _, err := io.Copy(outFile, inFile); err != nil {
return fmt.Errorf("failed to copy file: %w", err)
}

return nil
}
104 changes: 65 additions & 39 deletions updater/internal/file/file_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,14 @@ import (
"go.uber.org/zap/zaptest"
)

func TestCopyFile(t *testing.T) {
func TestCopyFileOverwrite(t *testing.T) {
t.Run("Copies file when output does not exist", func(t *testing.T) {
tmpDir := t.TempDir()

inFile := filepath.Join("testdata", "test.txt")
outFile := filepath.Join(tmpDir, "test.txt")

err := CopyFile(zaptest.NewLogger(t), inFile, outFile, true, false)
err := CopyFileOverwrite(zaptest.NewLogger(t), inFile, outFile)
require.NoError(t, err)
require.FileExists(t, outFile)

Expand All @@ -52,34 +52,6 @@ func TestCopyFile(t *testing.T) {
}
})

t.Run("Copies file when output does not exist and uses new permissions when argument set", func(t *testing.T) {
tmpDir := t.TempDir()

inFile := filepath.Join("testdata", "test.txt")
outFile := filepath.Join(tmpDir, "test.txt")

err := CopyFile(zaptest.NewLogger(t), inFile, outFile, true, true)
require.NoError(t, err)
require.FileExists(t, outFile)

contentsIn, err := os.ReadFile(inFile)
require.NoError(t, err)

contentsOut, err := os.ReadFile(outFile)
require.NoError(t, err)

require.Equal(t, contentsIn, contentsOut)

fio, err := os.Stat(outFile)
require.NoError(t, err)
fii, err := os.Stat(outFile)
require.NoError(t, err)
// file mode on windows acts unlike unix, we'll only check for this on linux/darwin
if runtime.GOOS != "windows" {
require.Equal(t, fii.Mode(), fio.Mode())
}
})

t.Run("Copies file when output already exists", func(t *testing.T) {
tmpDir := t.TempDir()

Expand All @@ -95,7 +67,7 @@ func TestCopyFile(t *testing.T) {
fioOrig, err := os.Stat(outFile)
require.NoError(t, err)

err = CopyFile(zaptest.NewLogger(t), inFile, outFile, true, false)
err = CopyFileOverwrite(zaptest.NewLogger(t), inFile, outFile)
require.NoError(t, err)
require.FileExists(t, outFile)

Expand All @@ -117,8 +89,8 @@ func TestCopyFile(t *testing.T) {
inFile := filepath.Join("testdata", "does-not-exist.txt")
outFile := filepath.Join(tmpDir, "test.txt")

err := CopyFile(zaptest.NewLogger(t), inFile, outFile, true, false)
require.ErrorContains(t, err, "failed to open input file")
err := CopyFileOverwrite(zaptest.NewLogger(t), inFile, outFile)
require.ErrorContains(t, err, "failed to stat input file")
require.NoFileExists(t, outFile)
})

Expand All @@ -131,16 +103,46 @@ func TestCopyFile(t *testing.T) {
err := os.WriteFile(outFile, []byte("This is a file that already exists"), 0600)
require.NoError(t, err)

err = CopyFile(zaptest.NewLogger(t), inFile, outFile, true, false)
require.ErrorContains(t, err, "failed to open input file")
err = CopyFileOverwrite(zaptest.NewLogger(t), inFile, outFile)
require.ErrorContains(t, err, "failed to stat input file")
require.FileExists(t, outFile)

contentsOut, err := os.ReadFile(outFile)
require.NoError(t, err)
require.Equal(t, []byte("This is a file that already exists"), contentsOut)
})
}

t.Run("Fails to overwrite the output file if 'overwrite' false", func(t *testing.T) {
func TestCopyFileRollback(t *testing.T) {
t.Run("Copies file when output does not exist", func(t *testing.T) {
tmpDir := t.TempDir()

inFile := filepath.Join("testdata", "test.txt")
outFile := filepath.Join(tmpDir, "test.txt")

err := CopyFileNoOverwrite(zaptest.NewLogger(t), inFile, outFile)
require.NoError(t, err)
require.FileExists(t, outFile)

contentsIn, err := os.ReadFile(inFile)
require.NoError(t, err)

contentsOut, err := os.ReadFile(outFile)
require.NoError(t, err)

require.Equal(t, contentsIn, contentsOut)

fio, err := os.Stat(outFile)
require.NoError(t, err)
fii, err := os.Stat(outFile)
require.NoError(t, err)
// file mode on windows acts unlike unix, we'll only check for this on linux/darwin
if runtime.GOOS != "windows" {
require.Equal(t, fii.Mode(), fio.Mode())
}
})

t.Run("Fails to overwrite the output file", func(t *testing.T) {
tmpDir := t.TempDir()

inFile := filepath.Join("testdata", "test.txt")
Expand All @@ -149,7 +151,7 @@ func TestCopyFile(t *testing.T) {
err := os.WriteFile(outFile, []byte("This is a file that already exists"), 0640)
require.NoError(t, err)

err = CopyFile(zaptest.NewLogger(t), inFile, outFile, false, false)
err = CopyFileNoOverwrite(zaptest.NewLogger(t), inFile, outFile)
require.ErrorContains(t, err, "failed to open output file")
require.FileExists(t, outFile)

Expand All @@ -165,13 +167,26 @@ func TestCopyFile(t *testing.T) {
}
})

t.Run("Copies file when output does not exist when 'overwrite' is false", func(t *testing.T) {
t.Run("Fails when input file does not exist", func(t *testing.T) {
tmpDir := t.TempDir()

inFile := filepath.Join("testdata", "does-not-exist.txt")
outFile := filepath.Join(tmpDir, "test.txt")

err := CopyFileNoOverwrite(zaptest.NewLogger(t), inFile, outFile)
require.ErrorContains(t, err, "failed to retrieve fileinfo for input file")
require.NoFileExists(t, outFile)
})
}

func TestCopyFileNoOverwrite(t *testing.T) {
t.Run("Copies file when output does not exist and uses inFile's permissions", func(t *testing.T) {
tmpDir := t.TempDir()

inFile := filepath.Join("testdata", "test.txt")
outFile := filepath.Join(tmpDir, "test.txt")

err := CopyFile(zaptest.NewLogger(t), inFile, outFile, false, false)
err := CopyFileRollback(zaptest.NewLogger(t), inFile, outFile)
require.NoError(t, err)
require.FileExists(t, outFile)

Expand All @@ -192,4 +207,15 @@ func TestCopyFile(t *testing.T) {
require.Equal(t, fii.Mode(), fio.Mode())
}
})

t.Run("Fails when input file does not exist", func(t *testing.T) {
tmpDir := t.TempDir()

inFile := filepath.Join("testdata", "does-not-exist.txt")
outFile := filepath.Join(tmpDir, "test.txt")

err := CopyFileRollback(zaptest.NewLogger(t), inFile, outFile)
require.ErrorContains(t, err, "input file does not exist")
require.NoFileExists(t, outFile)
})
}
4 changes: 2 additions & 2 deletions updater/internal/install/install.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ func installFiles(logger *zap.Logger, inputPath, installDir, backupDir string, r
// and we will want to roll that back if that is the case.
rb.AppendAction(cfa)

if err := file.CopyFile(logger.Named("copy-file"), inPath, outPath, true, false); err != nil {
if err := file.CopyFileOverwrite(logger.Named("copy-file"), inPath, outPath); err != nil {
return fmt.Errorf("failed to copy file: %w", err)
}

Expand Down Expand Up @@ -195,7 +195,7 @@ func installFile(logger *zap.Logger, inPath, installDirPath, backupDirPath strin
// and we will want to roll that back if that is the case.
rb.AppendAction(cfa)

if err := file.CopyFile(logger.Named("copy-file"), inPath, outPath, true, false); err != nil {
if err := file.CopyFileOverwrite(logger.Named("copy-file"), inPath, outPath); err != nil {
return fmt.Errorf("failed to copy file: %w", err)
}

Expand Down
4 changes: 2 additions & 2 deletions updater/internal/rollback/rollback.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ func backupFiles(logger *zap.Logger, installDir, outputPath string) error {
}

// Fail if copying the input file to the output file would fail
if err := file.CopyFile(logger.Named("copy-file"), inPath, outPath, false, false); err != nil {
if err := file.CopyFileNoOverwrite(logger.Named("copy-file"), inPath, outPath); err != nil {
return fmt.Errorf("failed to copy file: %w", err)
}

Expand Down Expand Up @@ -187,7 +187,7 @@ func backupFile(logger *zap.Logger, inPath, outputDirPath string) error {
}

// Fail if copying the input file to the output file would fail
if err := file.CopyFile(logger.Named("copy-file"), inPath, outPath, false, false); err != nil {
if err := file.CopyFileNoOverwrite(logger.Named("copy-file"), inPath, outPath); err != nil {
return fmt.Errorf("failed to copy file: %w", err)
}

Expand Down
2 changes: 1 addition & 1 deletion updater/internal/service/service_darwin.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ func (d darwinService) Update() error {
}

func (d darwinService) Backup() error {
if err := file.CopyFile(d.logger.Named("copy-file"), d.installedServiceFilePath, path.BackupServiceFile(d.installDir), false, false); err != nil {
if err := file.CopyFileNoOverwrite(d.logger.Named("copy-file"), d.installedServiceFilePath, path.BackupServiceFile(d.installDir)); err != nil {
return fmt.Errorf("failed to copy service file: %w", err)
}

Expand Down
2 changes: 1 addition & 1 deletion updater/internal/service/service_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ func (l linuxService) Update() error {
}

func (l linuxService) Backup() error {
if err := file.CopyFile(l.logger.Named("copy-file"), l.installedServiceFilePath, path.BackupServiceFile(l.installDir), false, false); err != nil {
if err := file.CopyFileNoOverwrite(l.logger.Named("copy-file"), l.installedServiceFilePath, path.BackupServiceFile(l.installDir)); err != nil {
return fmt.Errorf("failed to copy service file: %w", err)
}

Expand Down

0 comments on commit fb95c4d

Please sign in to comment.