Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for custom md5 remote path. #49

Merged
merged 3 commits into from
Jul 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ Config file should be in: `/etc/ansible-puller/config.json`, `$HOME/.ansible-pul
| `http-user` | `""` | Username for HTTP Basic Auth |
| `http-pass` | `""` | Password for HTTP basic Auth |
| `http-url` | `""` | HTTP Url to find the Ansible tarball. Required if s3-arn is not set |
| `http-checksum-url` | `""` | HTTP Url to find the Ansible tarball md5 hash. Defaults to http-url + `.md5`. |
| `log-dir` | `"/var/log/ansible-puller"` | Log directory (must exist) |
| `ansible-dir` | `""` | Path in the pulled tarball to cd into before ansible commands - usually ansible.cfg dir |
| `ansible-playbook` | `"site.yml"` | The playbook that will be run - relative to ansible-dir |
Expand Down Expand Up @@ -104,8 +105,9 @@ remote.

By design, ansible_puller will look at the remote path `<resource_path>.md5` to discover the live
MD5 checksum. If, for example, your resource is located at `https://example.com/some/file.tgz` then
ansible_puller will look for the MD5 hash at `https://example.com/some/file.tgz.md5`. The following
conditions will lead to a (re-)download of the ansible tarball:
ansible_puller will look for the MD5 hash at `https://example.com/some/file.tgz.md5`. A custom remote
path can be specified with the `http-checksum-url` option.
The following conditions will lead to a (re-)download of the ansible tarball:
- There is no current ansible tarball at the specified local path
- The current hash of the local ansible tarball not match the remote checksum
- The remote checksum does not exist
Expand Down
9 changes: 4 additions & 5 deletions http_downloader.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,14 +60,13 @@ func (downloader httpDownloader) Download(remotePath, outputPath string) error {
return nil
}

func (downloader httpDownloader) RemoteChecksum(remotePath string) (string, error) {
hashRemotePath := fmt.Sprintf("%s.md5", remotePath)
func (downloader httpDownloader) RemoteChecksum(checksumURL string) (string, error) {

timeout := time.Duration(2 * time.Second)
client := http.Client{
Timeout: timeout,
}
req, err := http.NewRequest("GET", hashRemotePath, nil)
req, err := http.NewRequest("GET", checksumURL, nil)
if err != nil {
return "", errors.Wrap(err, "failed to create request")
}
Expand All @@ -82,15 +81,15 @@ func (downloader httpDownloader) RemoteChecksum(remotePath string) (string, erro
}
// Ignore the checksum if it's not found, as assumed by the caller of this function.
if resp.StatusCode == http.StatusNotFound {
logrus.Debugf("MD5 sum not found at: %s", hashRemotePath)
logrus.Debugf("MD5 sum not found at: %s", checksumURL)
return "", nil
}
// A non-2xx status code does not cause an error, so we handle it here. https://pkg.go.dev/net/http#Client.Do
if resp.StatusCode >= 400 {
return "", fmt.Errorf("bad status code: %v", resp.StatusCode)
}

logrus.Debugf("Found MD5 sum at: %s", hashRemotePath)
logrus.Debugf("Found MD5 sum at: %s", checksumURL)
defer resp.Body.Close()
remoteChecksum, err := ioutil.ReadAll(resp.Body)
if err != nil {
Expand Down
56 changes: 45 additions & 11 deletions http_downloader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ var (
testFilenameHash = "testfile.txt.md5"
testMD5 = "7b20fda6af27c1b59ebdd8c09a93e770"

testEmptyChecksumUrl = ""
testChecksumUrlPath = "custom.txt.md5"

testHashlessFilename = "nohash.txt"
testHashlessText = []byte("Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.")

Expand Down Expand Up @@ -65,6 +68,8 @@ func (s *HttpDownloaderTestSuite) SetupTest() {
rw.Write(testText)
case "/" + testFilenameHash:
rw.Write([]byte(testMD5))
case "/" + testChecksumUrlPath:
rw.Write([]byte(testMD5))
case "/" + testHashlessFilename:
rw.Write(testHashlessText)
default:
Expand Down Expand Up @@ -98,7 +103,7 @@ func (s *HttpDownloaderTestSuite) TestIdempotentDownloadWhenNoFileExists() {
username: "",
password: "",
}
err := idempotentFileDownload(downloader, s.testServer.URL+"/"+testFilename, testFilename)
err := idempotentFileDownload(downloader, s.testServer.URL+"/"+testFilename, testEmptyChecksumUrl, testFilename)
assert.Nil(s.T(), err)

text, err := ioutil.ReadFile(testFilename)
Expand All @@ -111,7 +116,7 @@ func (s *HttpDownloaderTestSuite) TestIdempotentDownloadWhenCurrentFileExists()
username: "",
password: "",
}
err := idempotentFileDownload(downloader, s.testServer.URL+"/"+testFilename, testFilename)
err := idempotentFileDownload(downloader, s.testServer.URL+"/"+testFilename, testEmptyChecksumUrl, testFilename)
assert.Nil(s.T(), err)

text, err := ioutil.ReadFile(testFilename)
Expand All @@ -125,7 +130,7 @@ func (s *HttpDownloaderTestSuite) TestIdempotentDownloadWhenCurrentFileExists()
modtime := finfo.ModTime()

// Idempotent Download
err = idempotentFileDownload(downloader, s.testServer.URL+"/"+testFilename, testFilename)
err = idempotentFileDownload(downloader, s.testServer.URL+"/"+testFilename, testEmptyChecksumUrl, testFilename)
assert.Nil(s.T(), err)

newFinfo, err := os.Stat(testFilename)
Expand All @@ -140,7 +145,7 @@ func (s *HttpDownloaderTestSuite) TestIdempotentDownloadWhenOldFileExists() {
username: "",
password: "",
}
err := idempotentFileDownload(downloader, s.testServer.URL+"/"+testHashlessFilename, testFilename)
err := idempotentFileDownload(downloader, s.testServer.URL+"/"+testHashlessFilename, testEmptyChecksumUrl, testFilename)
assert.Nil(s.T(), err)

_, err = ioutil.ReadFile(testFilename)
Expand All @@ -156,7 +161,7 @@ func (s *HttpDownloaderTestSuite) TestIdempotentDownloadWhenOldFileExists() {
time.Sleep(1 * time.Second)

// Idempotent Download
err = idempotentFileDownload(downloader, s.testServer.URL+"/"+testFilename, testFilename)
err = idempotentFileDownload(downloader, s.testServer.URL+"/"+testFilename, testEmptyChecksumUrl, testFilename)
assert.Nil(s.T(), err)

newFinfo, err := os.Stat(testFilename)
Expand All @@ -166,12 +171,41 @@ func (s *HttpDownloaderTestSuite) TestIdempotentDownloadWhenOldFileExists() {
assert.NotEqual(s.T(), modtime, newModtime, "modification time should change")
}

func (s *HttpDownloaderTestSuite) TestIdempotentDownloadWhenCurrentFileExistsUsingChecksumUrl() {
downloader := httpDownloader{
username: "",
password: "",
}
err := idempotentFileDownload(downloader, s.testServer.URL+"/"+testFilename, s.testServer.URL+"/"+testChecksumUrlPath, testFilename)
assert.Nil(s.T(), err)

text, err := ioutil.ReadFile(testFilename)
assert.Nil(s.T(), err)
assert.Equal(s.T(), text, testText, "file should download correctly")

// Get original file info
finfo, err := os.Stat(testFilename)
assert.Nil(s.T(), err)

modtime := finfo.ModTime()

// Idempotent Download
err = idempotentFileDownload(downloader, s.testServer.URL+"/"+testFilename, testEmptyChecksumUrl, testFilename)
assert.Nil(s.T(), err)

newFinfo, err := os.Stat(testFilename)
newModtime := newFinfo.ModTime()

// Make sure the file didn't change
assert.Equal(s.T(), modtime, newModtime, "modification time should not change")
}

func (s *HttpDownloaderTestSuite) TestIdempotentDownloadNoRemoteHash() {
downloader := httpDownloader{
username: "",
password: "",
}
err := idempotentFileDownload(downloader, s.testServer.URL+"/"+testHashlessFilename, testHashlessFilename)
err := idempotentFileDownload(downloader, s.testServer.URL+"/"+testHashlessFilename, testEmptyChecksumUrl, testHashlessFilename)
assert.Nil(s.T(), err)

text, err := ioutil.ReadFile(testHashlessFilename)
Expand All @@ -188,7 +222,7 @@ func (s *HttpDownloaderTestSuite) TestIdempotentDownloadNoRemoteHash() {
time.Sleep(1 * time.Second)

// Idempotent Download
err = idempotentFileDownload(downloader, s.testServer.URL+"/"+testHashlessFilename, testHashlessFilename)
err = idempotentFileDownload(downloader, s.testServer.URL+"/"+testHashlessFilename, testEmptyChecksumUrl, testHashlessFilename)
assert.Nil(s.T(), err)

newFinfo, err := os.Stat(testHashlessFilename)
Expand All @@ -203,7 +237,7 @@ func (s *HttpDownloaderTestSuite) TestIdempotentDownloadBasicAuth() {
username: testBasicAuthUser,
password: testBasicAuthPass,
}
err := idempotentFileDownload(downloader, s.testServer.URL+"/"+testBasicAuthFilename, testBasicAuthFilename)
err := idempotentFileDownload(downloader, s.testServer.URL+"/"+testBasicAuthFilename, testEmptyChecksumUrl, testBasicAuthFilename)
assert.Nil(s.T(), err)

text, err := ioutil.ReadFile(testBasicAuthFilename)
Expand All @@ -216,18 +250,18 @@ func (s *HttpDownloaderTestSuite) TestIdempotentDownloadBasicAuthFailure() {
username: "nottherightuser",
password: "nottherightpass",
}
err := idempotentFileDownload(downloader, s.testServer.URL+"/"+testBasicAuthFilename, testBasicAuthFilename)
err := idempotentFileDownload(downloader, s.testServer.URL+"/"+testBasicAuthFilename, testEmptyChecksumUrl, testBasicAuthFilename)
assert.NotNil(s.T(), err)
}

func (s *HttpDownloaderTestSuite) TestIdempotentDownloadFailureFromInvalidURL() {
downloader := httpDownloader{}
err := idempotentFileDownload(downloader, "http://192.168.0.%31/invalid-url", testFilename)
err := idempotentFileDownload(downloader, "http://192.168.0.%31/invalid-url", testEmptyChecksumUrl, testFilename)
assert.NotNil(s.T(), err)
}

func (s *HttpDownloaderTestSuite) TestIdempotentDownloadFailureFromUnresponsiveServer() {
downloader := httpDownloader{}
err := idempotentFileDownload(downloader, "http://0.0.0.0/unresponsive/"+testFilename, testFilename)
err := idempotentFileDownload(downloader, "http://0.0.0.0/unresponsive/"+testFilename, testEmptyChecksumUrl, testFilename)
assert.NotNil(s.T(), err)
}
11 changes: 8 additions & 3 deletions idempotent_download.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package main
import (
"crypto/md5"
"encoding/hex"
"fmt"
"io"
"os"

Expand Down Expand Up @@ -52,9 +53,13 @@ func validateMd5Sum(path, checksum string) error {
// Checks the md5sum of the file to see if the remote file should be downloaded
//
// The MD5 checking may be an Artifactory-specific setup because it will look for the hash at "${url}.md5"
// or will look for the hash in the path provided in http-checksum-url.
// If the MD5 is not found, this will download the file
func idempotentFileDownload(downloader downloader, remotePath, localPath string) error {
logrus.Debugf("Starting idempotent download of %s to %s", remotePath, localPath)
func idempotentFileDownload(downloader downloader, remotePath, checksumURL, localPath string) error {
if len(checksumURL) == 0 {
checksumURL = fmt.Sprintf("%s.md5", remotePath)
}
logrus.Debugf("Starting idempotent download of %s to %s, remote checksum: %s", remotePath, localPath, checksumURL)

currentChecksum, err := md5sum(localPath)
if os.IsNotExist(err) {
Expand All @@ -64,7 +69,7 @@ func idempotentFileDownload(downloader downloader, remotePath, localPath string)
return errors.Wrap(err, "failed to calc local md5sum")
}

remoteChecksum, err := downloader.RemoteChecksum(remotePath)
remoteChecksum, err := downloader.RemoteChecksum(checksumURL)
if err != nil {
return errors.Wrap(err, "failed to download md5sum")
}
Expand Down
6 changes: 4 additions & 2 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ func init() {
pflag.String("http-pass", "", "HTTP password for pulling the remote file")

pflag.String("http-url", "", "Remote endpoint to retrieve the file from")
pflag.String("http-checksum-url", "", "Remote endpoint to retrieve the checksum from")
pflag.String("s3-arn", "", "Remote object ARN in S3 to retrieve")
pflag.String("s3-conn-region", "", "AWS service endpoint region for S3")

Expand Down Expand Up @@ -153,6 +154,7 @@ func ansibleEnable() {

func getAnsibleRepository(runDir string) error {
httpURL := viper.GetString("http-url")
checksumURL := viper.GetString("http-checksum-url")
s3Obj := viper.GetString("s3-arn")
s3ConnectionRegion := viper.GetString("s3-conn-region")
localCacheFile := fmt.Sprintf("/tmp/%s.tgz", appName)
Expand All @@ -167,13 +169,13 @@ func getAnsibleRepository(runDir string) error {
username: viper.GetString("http-user"),
password: viper.GetString("http-pass"),
}
err = idempotentFileDownload(downloader, remoteHttpURL, localCacheFile)
err = idempotentFileDownload(downloader, remoteHttpURL, checksumURL, localCacheFile)
} else if s3Obj != "" {
downloader, createError := createS3Downloader(s3ConnectionRegion)
if createError != nil {
return errors.Wrap(err, "unable to pull Ansible repo")
}
err = idempotentFileDownload(downloader, s3Obj, localCacheFile)
err = idempotentFileDownload(downloader, s3Obj, checksumURL, localCacheFile)
}
if err != nil {
return errors.Wrap(err, "unable to pull Ansible repo")
Expand Down
12 changes: 6 additions & 6 deletions s3_downloader.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,14 @@ import (
"os"
"regexp"

"io/ioutil"
"path/filepath"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/feature/s3/manager"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/sirupsen/logrus"
"io/ioutil"
"path/filepath"
)

type s3Downloader struct {
Expand Down Expand Up @@ -105,8 +106,7 @@ func (downloader s3Downloader) Download(remotePath, outputPath string) (err erro
return
}

func (downloader s3Downloader) RemoteChecksum(remotePath string) (string, error) {
hashRemotePath := fmt.Sprintf("%s.md5", remotePath)
func (downloader s3Downloader) RemoteChecksum(checksumURL string) (string, error) {

dir, err := ioutil.TempDir("", "*")
if err != nil {
Expand All @@ -116,13 +116,13 @@ func (downloader s3Downloader) RemoteChecksum(remotePath string) (string, error)
defer os.RemoveAll(dir)
hashFile := filepath.Join(dir, "md5Hash")

err = downloader.Download(hashRemotePath, hashFile)
err = downloader.Download(checksumURL, hashFile)
if err != nil {
logrus.Infof("MD5 sum not reachable. %v", err)
return "", nil
}

logrus.Infof("Found MD5 sum at: %s", hashRemotePath)
logrus.Infof("Found MD5 sum at: %s", checksumURL)

content, err := ioutil.ReadFile(hashFile)
if err != nil {
Expand Down
Loading