From 196b8f4ef55ad6f363fec90b2794252e71662716 Mon Sep 17 00:00:00 2001 From: ayyghost Date: Wed, 20 Dec 2023 15:20:15 +0000 Subject: [PATCH] more config validation, refactoring --- extract.go | 77 ++++++++++++++++++++++++ updater.go | 171 +++++++++++++++++++++-------------------------------- 2 files changed, 146 insertions(+), 102 deletions(-) create mode 100644 extract.go diff --git a/extract.go b/extract.go new file mode 100644 index 0000000..fbaf6ba --- /dev/null +++ b/extract.go @@ -0,0 +1,77 @@ +//go:debug tarinsecurepath=0 + +package main + +import ( + "archive/tar" + "bytes" + "compress/gzip" + "fmt" + "io" + "os" + "path/filepath" + "strings" +) + +func extractTarGz(tarGzData []byte, destination string, stripComponents int) error { + buf := bytes.NewBuffer(tarGzData) + gzipReader, err := gzip.NewReader(buf) + if err != nil { + return err + } + defer gzipReader.Close() + tarReader := tar.NewReader(gzipReader) + + for { + header, err := tarReader.Next() + if err == io.EOF { + break + } + if err != nil { + return err + } + // Skip pax_global_header entries + if header.Name == "pax_global_header" { + continue + } + + // Calculate the target path by stripping components + target := header.Name + if stripComponents > 0 { + components := strings.SplitN(target, string(filepath.Separator), stripComponents+1) + if len(components) > stripComponents { + target = strings.Join(components[stripComponents:], string(filepath.Separator)) + } else { + target = "" + } + } + + // Get the full path for the file + target = filepath.Join(destination, target) + + switch header.Typeflag { + case tar.TypeDir: + // Create directory if it doesn't exist + if err := os.MkdirAll(target, os.ModePerm); err != nil { + return err + } + + case tar.TypeReg: + // Create file + file, err := os.Create(target) + if err != nil { + return err + } + defer file.Close() + + if _, err := io.Copy(file, tarReader); err != nil { + return err + } + + default: + return fmt.Errorf("unsupported file type: %v in %v", header.Typeflag, header.Name) + } + } + + return nil +} diff --git a/updater.go b/updater.go index bc2a290..734cd52 100644 --- a/updater.go +++ b/updater.go @@ -1,11 +1,6 @@ -//go:debug tarinsecurepath=0 - package main import ( - "archive/tar" - "bytes" - "compress/gzip" "context" "encoding/json" "flag" @@ -32,7 +27,6 @@ type Target struct { } type Config struct { - MetadataDir string `json:"metadata_dir"` DeployDir string `json:"deploy_dir"` Targets []*Target `json:"targets"` PublicSigningKey string `json:"public_signing_key"` @@ -55,19 +49,14 @@ func main() { log.Fatal(err) } - err = validateConfig(&config) - if err != nil { - log.Fatal(err) - } - - err = os.MkdirAll(config.MetadataDir, 0750) + err = config.Validate() if err != nil { log.Fatal(err) } githubAPIToken, ok := os.LookupEnv("GITHUB_API_TOKEN") - if !ok { - log.Fatal("GITHUB_API_TOKEN environment variable must be set") + if !ok || githubAPIToken == "" { + log.Fatal("GITHUB_API_TOKEN environment variable must be set and non-empty") } client := github.NewClient(nil).WithAuthToken(githubAPIToken) @@ -83,14 +72,14 @@ func main() { } releaseID := strconv.FormatInt(*release.ID, 10) - lastReleaseFile := filepath.Join(config.MetadataDir, fmt.Sprintf("%v_last_release_id", target.Name)) missingLastRelease := false - lastReleaseID, err := os.ReadFile(lastReleaseFile) + lastReleaseID, err := getLastReleaseID(config.DeployDir, target.Name) if err != nil { if os.IsNotExist(err) { missingLastRelease = true } else { - log.Fatal(err) + log.Printf("%s: error getting last release ID: %v", target.Name, err) + continue } } if !missingLastRelease && string(lastReleaseID) == releaseID { @@ -115,7 +104,7 @@ func main() { log.Printf("%s: skipping signature verification!", target.Name) } - err = deploy(config.DeployDir, target.Name, releaseID, tarGzBytes, lastReleaseFile, string(lastReleaseID)) + err = deployRelease(config.DeployDir, target.Name, releaseID, lastReleaseID, tarGzBytes) if err != nil { log.Printf("%s update failed: %v", target.Name, err) continue @@ -126,25 +115,67 @@ func main() { } } -func validateConfig(config *Config) error { - if config.MetadataDir == "" { - return fmt.Errorf("metadata directory must be set") - } - if config.DeployDir == "" { +func (c *Config) Validate() error { + if c.DeployDir == "" { return fmt.Errorf("deploy directory must be set") } - if config.UpdateInterval <= 0 { + if c.UpdateInterval <= 0 { return fmt.Errorf("update interval must be >0") } - if !config.UnsafeSkipSignatureVerification && config.PublicSigningKey == "" { + if !c.UnsafeSkipSignatureVerification && c.PublicSigningKey == "" { return fmt.Errorf("public signing key must be set if signature verification is enabled") } - if len(config.Targets) == 0 { + if len(c.Targets) == 0 { return fmt.Errorf("at least one target must be set") } + targetNames := make(map[string]bool) + for i, target := range c.Targets { + if target.Name == "" { + return fmt.Errorf("name for target %d must be set", i) + } + if target.Owner == "" { + return fmt.Errorf("owner for target %d must be set", i) + } + if target.Repo == "" { + return fmt.Errorf("repo for target %d must be set", i) + } + if targetNames[target.Name] { + return fmt.Errorf("target %d has duplicate name", i) + } + targetNames[target.Name] = true + } return nil } +func getReleaseDir(deployDir, targetName, releaseID string) string { + return filepath.Join(deployDir, targetName+"-"+releaseID) +} + +func getReleaseSymlink(deployDir, targetName string) string { + return filepath.Join(deployDir, targetName) +} + +func getLastReleaseID(deployDir, targetName string) (string, error) { + lastReleaseSymlink := getReleaseSymlink(deployDir, targetName) + fi, err := os.Lstat(lastReleaseSymlink) + if err != nil { + return "", err + } + if fi.Mode()&os.ModeSymlink == 0 { + return "", fmt.Errorf("%s is not a symlink", lastReleaseSymlink) + } + + lastReleaseDir, err := os.Readlink(lastReleaseSymlink) + if err != nil { + return "", err + } + split := strings.Split(filepath.Base(lastReleaseDir), "-") + if len(split) != 2 { + return "", fmt.Errorf("invalid last release directory name: %s", lastReleaseDir) + } + return split[1], nil +} + func downloadReleaseAssets(target *Target, release *github.RepositoryRelease) (tarGzBytes, sigBytes []byte, err error) { if len(release.Assets) < 2 { err = fmt.Errorf("release needs at least 2 assets (have %v)", len(release.Assets)) @@ -153,8 +184,8 @@ func downloadReleaseAssets(target *Target, release *github.RepositoryRelease) (t const tarGzRegexFmt = `^%s-[\w.]+\.tar\.gz$` const sigRegexFmt = `^%s-[\w.]+\.minisig$` - tarGzRegex := regexp.MustCompile(fmt.Sprintf(tarGzRegexFmt, target.Name)) - sigRegex := regexp.MustCompile(fmt.Sprintf(sigRegexFmt, target.Name)) + tarGzRegex := regexp.MustCompile(fmt.Sprintf(tarGzRegexFmt, target.Repo)) + sigRegex := regexp.MustCompile(fmt.Sprintf(sigRegexFmt, target.Repo)) if !(tarGzRegex.MatchString(*release.Assets[0].Name)) { err = fmt.Errorf("first asset doesn't have expected name (%v)", *release.Assets[0].Name) @@ -222,87 +253,23 @@ func verifySignature(publicSigningKey string, tarGzBytes, sigBytes []byte) (bool return pk.Verify(tarGzBytes, sig) } -func deploy(deployDir, targetName, releaseID string, tarGzBytes []byte, lastReleaseFile, lastReleaseID string) error { - extractDir := filepath.Join(deployDir, targetName) + "-" + releaseID - if err := os.Mkdir(extractDir, 0755); err != nil { +func deployRelease(deployDir, targetName, releaseID, lastReleaseID string, tarGzBytes []byte) error { + releaseDir := getReleaseDir(deployDir, targetName, releaseID) + if err := os.Mkdir(releaseDir, 0755); err != nil { return err } - if err := extractTarGz(tarGzBytes, extractDir, 1); err != nil { - return err - } - if err := os.Symlink(extractDir, extractDir+".tmp"); err != nil { - return err - } - if err := os.Rename(extractDir+".tmp", filepath.Join(deployDir, targetName)); err != nil { - return err - } - if err := os.WriteFile(lastReleaseFile, []byte(releaseID), 0640); err != nil { + if err := extractTarGz(tarGzBytes, releaseDir, 1); err != nil { return err } - // clean up old release dir - return os.RemoveAll(filepath.Join(deployDir, targetName) + "-" + lastReleaseID) -} - -func extractTarGz(tarGzData []byte, destination string, stripComponents int) error { - buf := bytes.NewBuffer(tarGzData) - gzipReader, err := gzip.NewReader(buf) - if err != nil { + releaseSymlink := getReleaseSymlink(deployDir, targetName) + if err := os.Symlink(releaseDir, releaseSymlink+".tmp"); err != nil { return err } - defer gzipReader.Close() - tarReader := tar.NewReader(gzipReader) - - for { - header, err := tarReader.Next() - if err == io.EOF { - break - } - if err != nil { - return err - } - // Skip pax_global_header entries - if header.Name == "pax_global_header" { - continue - } - - // Calculate the target path by stripping components - target := header.Name - if stripComponents > 0 { - components := strings.SplitN(target, string(filepath.Separator), stripComponents+1) - if len(components) > stripComponents { - target = strings.Join(components[stripComponents:], string(filepath.Separator)) - } else { - target = "" - } - } - - // Get the full path for the file - target = filepath.Join(destination, target) - - switch header.Typeflag { - case tar.TypeDir: - // Create directory if it doesn't exist - if err := os.MkdirAll(target, os.ModePerm); err != nil { - return err - } - - case tar.TypeReg: - // Create file - file, err := os.Create(target) - if err != nil { - return err - } - defer file.Close() - - if _, err := io.Copy(file, tarReader); err != nil { - return err - } - - default: - return fmt.Errorf("unsupported file type: %v in %v", header.Typeflag, header.Name) - } + if err := os.Rename(releaseSymlink+".tmp", releaseSymlink); err != nil { + return err } - return nil + // clean up last release dir + return os.RemoveAll(getReleaseDir(deployDir, targetName, lastReleaseID)) }