Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Refactor Updater's file package #611

Merged
merged 9 commits into from
Aug 4, 2022
2 changes: 1 addition & 1 deletion updater/internal/action/file_action.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ func (c CopyFileAction) Rollback() error {

// join the relative path to the backup directory to get the location of the backup path
backupFilePath := filepath.Join(c.backupDir, c.FromPathRel)
if err := file.CopyFile(c.logger.Named("copy-file"), backupFilePath, c.ToPath, true, true); err != nil {
if err := file.CopyFileRollback(c.logger.Named("copy-file"), backupFilePath, c.ToPath); err != nil {
return fmt.Errorf("failed to copy file: %w", err)
}

Expand Down
121 changes: 76 additions & 45 deletions updater/internal/file/file.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package file

import (
"errors"
"fmt"
"io"
"io/fs"
Expand All @@ -24,9 +25,78 @@ import (
"go.uber.org/zap"
)

// CopyFile copies the file from pathIn to pathOut.
// If the file does not exist, it is created. If the file does exist, it is truncated before writing.
func CopyFile(logger *zap.Logger, pathIn, pathOut string, overwrite bool, useInFilePermBackup bool) error {
// CopyFileOverwrite copies the file from pathIn to pathOut.
// The output file is created if it does not exist.
// If the output file does exist, it is removed, then written from the input file, preserving the output file's mode.
func CopyFileOverwrite(logger *zap.Logger, pathIn, pathOut string) error {
fileMode := fs.FileMode(0600)
pathOutClean := filepath.Clean(pathOut)

// Try to save existing file's permissions
outFileInfo, _ := os.Stat(pathOutClean)
if outFileInfo != nil {
fileMode = outFileInfo.Mode()
}

pathInClean := filepath.Clean(pathIn)
// If the input file cannot be opened for some reason, do NOT delete the file
if _, err := os.Stat(pathInClean); err != nil {
return fmt.Errorf("failed to stat input file: %w", err)
}

// Remove old file to prevent issues with mac
if err := os.Remove(pathOutClean); err != nil {
logger.Debug("Failed to remove output file", zap.Error(err))
}

return copyFileInternal(logger, pathIn, pathOut, os.O_CREATE|os.O_WRONLY, fileMode)
}

// CopyFileNoOverwrite copies the file from pathIn to pathOut, preserving the input file's mode.
// If the output file already exists, this function returns an error.
func CopyFileNoOverwrite(logger *zap.Logger, pathIn, pathOut string) error {
pathInClean := filepath.Clean(pathIn)

// Use the new file's permissions and fail if there's an issue (want to fail for backup)
inFileInfo, err := os.Stat(pathInClean)
if err != nil {
return fmt.Errorf("failed to retrieve fileinfo for input file: %w", err)
}

// the os.O_EXCL flag will make OpenFile error if the file already exists
return copyFileInternal(logger, pathIn, pathOut, os.O_EXCL|os.O_CREATE|os.O_WRONLY, inFileInfo.Mode())
}

// CopyFileRollback copies the file to the file from pathIn to pathOut, preserving the input file's mode if possible
// Used to perform a rollback
func CopyFileRollback(logger *zap.Logger, pathIn, pathOut string) error {
// Default to 0600 if we can't determine the input file's mode
fileMode := fs.FileMode(0600)
pathInClean := filepath.Clean(pathIn)
// Use the backup file's permissions as a backup and don't fail on error (best chance for rollback)
inFileInfo, err := os.Stat(pathInClean)
switch {
case errors.Is(err, os.ErrNotExist):
return fmt.Errorf("input file does not exist: %w", err)
case err != nil:
// Even though we failed to stat, we'll continue in this case to give the best chance
// of rolling back successfully.
logger.Error("failed to retrieve fileinfo for input file", zap.Error(err))
default:
fileMode = inFileInfo.Mode()
}

pathOutClean := filepath.Clean(pathOut)
// Remove old file to prevent issues with mac
if err = os.Remove(pathOutClean); err != nil {
logger.Debug("Failed to remove output file", zap.Error(err))
}

return copyFileInternal(logger, pathIn, pathOut, os.O_CREATE|os.O_WRONLY, fileMode)
}

// copyFileInternal copies the file at pathIn to pathOut, using the provided flags and file mode
func copyFileInternal(logger *zap.Logger, pathIn, pathOut string, outFlags int, outMode fs.FileMode) error {
pathInClean := filepath.Clean(pathIn)

// Open the input file for reading.
Expand All @@ -37,66 +107,27 @@ func CopyFile(logger *zap.Logger, pathIn, pathOut string, overwrite bool, useInF
defer func() {
err := inFile.Close()
if err != nil {
logger.Info("Failed to close input file", zap.Error(err))
logger.Error("Failed to close input file", zap.Error(err))
}
}()

pathOutClean := filepath.Clean(pathOut)
fileMode := fs.FileMode(0600)
flags := os.O_CREATE | os.O_WRONLY
if overwrite {
// If we are OK to overwrite, we will truncate the file on open
flags |= os.O_TRUNC

// Try to save old file's permissions
outFileInfo, _ := os.Stat(pathOutClean)
if outFileInfo != nil {
fileMode = outFileInfo.Mode()
} else if useInFilePermBackup {
// Use the new file's permissions as a backup and don't fail on error (best chance for rollback)
inFileInfo, err := inFile.Stat()
switch {
case err != nil:
logger.Error("failed to retrieve fileinfo for input file", zap.Error(err))
case inFileInfo != nil:
fileMode = inFileInfo.Mode()
}
}

// Remove old file to prevent issues with mac
if err = os.Remove(pathOutClean); err != nil {
logger.Debug("Failed to remove output file", zap.Error(err))
}
} else {
// This flag will make OpenFile error if the file already exists
flags |= os.O_EXCL

// Use the new file's permissions and fail if there's an issue (want to fail for backup)
inFileInfo, err := inFile.Stat()
if err != nil {
return fmt.Errorf("failed to retrive fileinfo for input file: %w", err)
}

fileMode = inFileInfo.Mode()
}

// Open the output file, creating it if it does not exist and truncating it.
//#nosec G304 -- out file is cleaned; this is a general purpose copy function
outFile, err := os.OpenFile(pathOutClean, flags, fileMode)
outFile, err := os.OpenFile(pathOutClean, outFlags, outMode)
if err != nil {
return fmt.Errorf("failed to open output file: %w", err)
}
defer func() {
err := outFile.Close()
if err != nil {
logger.Info("Failed to close output file", zap.Error(err))
logger.Error("Failed to close output file", zap.Error(err))
}
}()

// Copy the input file to the output file.
if _, err := io.Copy(outFile, inFile); err != nil {
return fmt.Errorf("failed to copy file: %w", err)
}

return nil
}
104 changes: 65 additions & 39 deletions updater/internal/file/file_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,14 @@ import (
"go.uber.org/zap/zaptest"
)

func TestCopyFile(t *testing.T) {
func TestCopyFileOverwrite(t *testing.T) {
t.Run("Copies file when output does not exist", func(t *testing.T) {
tmpDir := t.TempDir()

inFile := filepath.Join("testdata", "test.txt")
outFile := filepath.Join(tmpDir, "test.txt")

err := CopyFile(zaptest.NewLogger(t), inFile, outFile, true, false)
err := CopyFileOverwrite(zaptest.NewLogger(t), inFile, outFile)
require.NoError(t, err)
require.FileExists(t, outFile)

Expand All @@ -52,34 +52,6 @@ func TestCopyFile(t *testing.T) {
}
})

t.Run("Copies file when output does not exist and uses new permissions when argument set", func(t *testing.T) {
tmpDir := t.TempDir()

inFile := filepath.Join("testdata", "test.txt")
outFile := filepath.Join(tmpDir, "test.txt")

err := CopyFile(zaptest.NewLogger(t), inFile, outFile, true, true)
require.NoError(t, err)
require.FileExists(t, outFile)

contentsIn, err := os.ReadFile(inFile)
require.NoError(t, err)

contentsOut, err := os.ReadFile(outFile)
require.NoError(t, err)

require.Equal(t, contentsIn, contentsOut)

fio, err := os.Stat(outFile)
require.NoError(t, err)
fii, err := os.Stat(outFile)
require.NoError(t, err)
// file mode on windows acts unlike unix, we'll only check for this on linux/darwin
if runtime.GOOS != "windows" {
require.Equal(t, fii.Mode(), fio.Mode())
}
})

t.Run("Copies file when output already exists", func(t *testing.T) {
tmpDir := t.TempDir()

Expand All @@ -95,7 +67,7 @@ func TestCopyFile(t *testing.T) {
fioOrig, err := os.Stat(outFile)
require.NoError(t, err)

err = CopyFile(zaptest.NewLogger(t), inFile, outFile, true, false)
err = CopyFileOverwrite(zaptest.NewLogger(t), inFile, outFile)
require.NoError(t, err)
require.FileExists(t, outFile)

Expand All @@ -117,8 +89,8 @@ func TestCopyFile(t *testing.T) {
inFile := filepath.Join("testdata", "does-not-exist.txt")
outFile := filepath.Join(tmpDir, "test.txt")

err := CopyFile(zaptest.NewLogger(t), inFile, outFile, true, false)
require.ErrorContains(t, err, "failed to open input file")
err := CopyFileOverwrite(zaptest.NewLogger(t), inFile, outFile)
require.ErrorContains(t, err, "failed to stat input file")
require.NoFileExists(t, outFile)
})

Expand All @@ -131,16 +103,46 @@ func TestCopyFile(t *testing.T) {
err := os.WriteFile(outFile, []byte("This is a file that already exists"), 0600)
require.NoError(t, err)

err = CopyFile(zaptest.NewLogger(t), inFile, outFile, true, false)
require.ErrorContains(t, err, "failed to open input file")
err = CopyFileOverwrite(zaptest.NewLogger(t), inFile, outFile)
require.ErrorContains(t, err, "failed to stat input file")
require.FileExists(t, outFile)

contentsOut, err := os.ReadFile(outFile)
require.NoError(t, err)
require.Equal(t, []byte("This is a file that already exists"), contentsOut)
})
}

t.Run("Fails to overwrite the output file if 'overwrite' false", func(t *testing.T) {
func TestCopyFileRollback(t *testing.T) {
t.Run("Copies file when output does not exist", func(t *testing.T) {
tmpDir := t.TempDir()

inFile := filepath.Join("testdata", "test.txt")
outFile := filepath.Join(tmpDir, "test.txt")

err := CopyFileNoOverwrite(zaptest.NewLogger(t), inFile, outFile)
require.NoError(t, err)
require.FileExists(t, outFile)

contentsIn, err := os.ReadFile(inFile)
require.NoError(t, err)

contentsOut, err := os.ReadFile(outFile)
require.NoError(t, err)

require.Equal(t, contentsIn, contentsOut)

fio, err := os.Stat(outFile)
require.NoError(t, err)
fii, err := os.Stat(outFile)
require.NoError(t, err)
// file mode on windows acts unlike unix, we'll only check for this on linux/darwin
if runtime.GOOS != "windows" {
require.Equal(t, fii.Mode(), fio.Mode())
}
})

t.Run("Fails to overwrite the output file", func(t *testing.T) {
tmpDir := t.TempDir()

inFile := filepath.Join("testdata", "test.txt")
Expand All @@ -149,7 +151,7 @@ func TestCopyFile(t *testing.T) {
err := os.WriteFile(outFile, []byte("This is a file that already exists"), 0640)
require.NoError(t, err)

err = CopyFile(zaptest.NewLogger(t), inFile, outFile, false, false)
err = CopyFileNoOverwrite(zaptest.NewLogger(t), inFile, outFile)
require.ErrorContains(t, err, "failed to open output file")
require.FileExists(t, outFile)

Expand All @@ -165,13 +167,26 @@ func TestCopyFile(t *testing.T) {
}
})

t.Run("Copies file when output does not exist when 'overwrite' is false", func(t *testing.T) {
t.Run("Fails when input file does not exist", func(t *testing.T) {
tmpDir := t.TempDir()

inFile := filepath.Join("testdata", "does-not-exist.txt")
outFile := filepath.Join(tmpDir, "test.txt")

err := CopyFileNoOverwrite(zaptest.NewLogger(t), inFile, outFile)
require.ErrorContains(t, err, "failed to retrieve fileinfo for input file")
require.NoFileExists(t, outFile)
})
}

func TestCopyFileNoOverwrite(t *testing.T) {
t.Run("Copies file when output does not exist and uses inFile's permissions", func(t *testing.T) {
tmpDir := t.TempDir()

inFile := filepath.Join("testdata", "test.txt")
outFile := filepath.Join(tmpDir, "test.txt")

err := CopyFile(zaptest.NewLogger(t), inFile, outFile, false, false)
err := CopyFileRollback(zaptest.NewLogger(t), inFile, outFile)
require.NoError(t, err)
require.FileExists(t, outFile)

Expand All @@ -192,4 +207,15 @@ func TestCopyFile(t *testing.T) {
require.Equal(t, fii.Mode(), fio.Mode())
}
})

t.Run("Fails when input file does not exist", func(t *testing.T) {
tmpDir := t.TempDir()

inFile := filepath.Join("testdata", "does-not-exist.txt")
outFile := filepath.Join(tmpDir, "test.txt")

err := CopyFileRollback(zaptest.NewLogger(t), inFile, outFile)
require.ErrorContains(t, err, "input file does not exist")
require.NoFileExists(t, outFile)
})
}
4 changes: 2 additions & 2 deletions updater/internal/install/install.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ func installFiles(logger *zap.Logger, inputPath, installDir, backupDir string, r
// and we will want to roll that back if that is the case.
rb.AppendAction(cfa)

if err := file.CopyFile(logger.Named("copy-file"), inPath, outPath, true, false); err != nil {
if err := file.CopyFileOverwrite(logger.Named("copy-file"), inPath, outPath); err != nil {
return fmt.Errorf("failed to copy file: %w", err)
}

Expand Down Expand Up @@ -195,7 +195,7 @@ func installFile(logger *zap.Logger, inPath, installDirPath, backupDirPath strin
// and we will want to roll that back if that is the case.
rb.AppendAction(cfa)

if err := file.CopyFile(logger.Named("copy-file"), inPath, outPath, true, false); err != nil {
if err := file.CopyFileOverwrite(logger.Named("copy-file"), inPath, outPath); err != nil {
return fmt.Errorf("failed to copy file: %w", err)
}

Expand Down
4 changes: 2 additions & 2 deletions updater/internal/rollback/rollback.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ func backupFiles(logger *zap.Logger, installDir, outputPath string) error {
}

// Fail if copying the input file to the output file would fail
if err := file.CopyFile(logger.Named("copy-file"), inPath, outPath, false, false); err != nil {
if err := file.CopyFileNoOverwrite(logger.Named("copy-file"), inPath, outPath); err != nil {
return fmt.Errorf("failed to copy file: %w", err)
}

Expand Down Expand Up @@ -187,7 +187,7 @@ func backupFile(logger *zap.Logger, inPath, outputDirPath string) error {
}

// Fail if copying the input file to the output file would fail
if err := file.CopyFile(logger.Named("copy-file"), inPath, outPath, false, false); err != nil {
if err := file.CopyFileNoOverwrite(logger.Named("copy-file"), inPath, outPath); err != nil {
return fmt.Errorf("failed to copy file: %w", err)
}

Expand Down
2 changes: 1 addition & 1 deletion updater/internal/service/service_darwin.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ func (d darwinService) Update() error {
}

func (d darwinService) Backup() error {
if err := file.CopyFile(d.logger.Named("copy-file"), d.installedServiceFilePath, path.BackupServiceFile(d.installDir), false, false); err != nil {
if err := file.CopyFileNoOverwrite(d.logger.Named("copy-file"), d.installedServiceFilePath, path.BackupServiceFile(d.installDir)); err != nil {
return fmt.Errorf("failed to copy service file: %w", err)
}

Expand Down
2 changes: 1 addition & 1 deletion updater/internal/service/service_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ func (l linuxService) Update() error {
}

func (l linuxService) Backup() error {
if err := file.CopyFile(l.logger.Named("copy-file"), l.installedServiceFilePath, path.BackupServiceFile(l.installDir), false, false); err != nil {
if err := file.CopyFileNoOverwrite(l.logger.Named("copy-file"), l.installedServiceFilePath, path.BackupServiceFile(l.installDir)); err != nil {
return fmt.Errorf("failed to copy service file: %w", err)
}

Expand Down