Skip to content

Commit

Permalink
Add support for custom md5 remote path. (#49)
Browse files Browse the repository at this point in the history
We use the gitlab package registry to store ansible playbook releases.
Gitlab has permalinks to the latest artifacts or to release assets.
These links to artifacts or release assets are accessed using the gitlab
API and my not be the actual path to the tarball or md5 checksum file.

A custom remote path for the link to the md5 hash would be very nice!
This PR will add a `http-checksum-url` flag. If the `http-checksum-url`
flag has not been set, it will use the `http-url` with `.md5` added to
the path.
  • Loading branch information
dmeulen committed Jul 5, 2024
1 parent 5d7be01 commit a6da503
Show file tree
Hide file tree
Showing 6 changed files with 71 additions and 29 deletions.
6 changes: 4 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ Config file should be in: `/etc/ansible-puller/config.json`, `$HOME/.ansible-pul
| `http-user` | `""` | Username for HTTP Basic Auth |
| `http-pass` | `""` | Password for HTTP basic Auth |
| `http-url` | `""` | HTTP Url to find the Ansible tarball. Required if s3-arn is not set |
| `http-checksum-url` | `""` | HTTP Url to find the Ansible tarball md5 hash. Defaults to http-url + `.md5`. |
| `log-dir` | `"/var/log/ansible-puller"` | Log directory (must exist) |
| `ansible-dir` | `""` | Path in the pulled tarball to cd into before ansible commands - usually ansible.cfg dir |
| `ansible-playbook` | `"site.yml"` | The playbook that will be run - relative to ansible-dir |
Expand Down Expand Up @@ -104,8 +105,9 @@ remote.

By design, ansible_puller will look at the remote path `<resource_path>.md5` to discover the live
MD5 checksum. If, for example, your resource is located at `https://example.com/some/file.tgz` then
ansible_puller will look for the MD5 hash at `https://example.com/some/file.tgz.md5`. The following
conditions will lead to a (re-)download of the ansible tarball:
ansible_puller will look for the MD5 hash at `https://example.com/some/file.tgz.md5`. A custom remote
path can be specified with the `http-checksum-url` option.
The following conditions will lead to a (re-)download of the ansible tarball:
- There is no current ansible tarball at the specified local path
- The current hash of the local ansible tarball not match the remote checksum
- The remote checksum does not exist
Expand Down
9 changes: 4 additions & 5 deletions http_downloader.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,14 +60,13 @@ func (downloader httpDownloader) Download(remotePath, outputPath string) error {
return nil
}

func (downloader httpDownloader) RemoteChecksum(remotePath string) (string, error) {
hashRemotePath := fmt.Sprintf("%s.md5", remotePath)
func (downloader httpDownloader) RemoteChecksum(checksumURL string) (string, error) {

timeout := time.Duration(2 * time.Second)
client := http.Client{
Timeout: timeout,
}
req, err := http.NewRequest("GET", hashRemotePath, nil)
req, err := http.NewRequest("GET", checksumURL, nil)
if err != nil {
return "", errors.Wrap(err, "failed to create request")
}
Expand All @@ -82,15 +81,15 @@ func (downloader httpDownloader) RemoteChecksum(remotePath string) (string, erro
}
// Ignore the checksum if it's not found, as assumed by the caller of this function.
if resp.StatusCode == http.StatusNotFound {
logrus.Debugf("MD5 sum not found at: %s", hashRemotePath)
logrus.Debugf("MD5 sum not found at: %s", checksumURL)
return "", nil
}
// A non-2xx status code does not cause an error, so we handle it here. https://pkg.go.dev/net/http#Client.Do
if resp.StatusCode >= 400 {
return "", fmt.Errorf("bad status code: %v", resp.StatusCode)
}

logrus.Debugf("Found MD5 sum at: %s", hashRemotePath)
logrus.Debugf("Found MD5 sum at: %s", checksumURL)
defer resp.Body.Close()
remoteChecksum, err := ioutil.ReadAll(resp.Body)
if err != nil {
Expand Down
56 changes: 45 additions & 11 deletions http_downloader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ var (
testFilenameHash = "testfile.txt.md5"
testMD5 = "7b20fda6af27c1b59ebdd8c09a93e770"

testEmptyChecksumUrl = ""
testChecksumUrlPath = "custom.txt.md5"

testHashlessFilename = "nohash.txt"
testHashlessText = []byte("Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.")

Expand Down Expand Up @@ -65,6 +68,8 @@ func (s *HttpDownloaderTestSuite) SetupTest() {
rw.Write(testText)
case "/" + testFilenameHash:
rw.Write([]byte(testMD5))
case "/" + testChecksumUrlPath:
rw.Write([]byte(testMD5))
case "/" + testHashlessFilename:
rw.Write(testHashlessText)
default:
Expand Down Expand Up @@ -98,7 +103,7 @@ func (s *HttpDownloaderTestSuite) TestIdempotentDownloadWhenNoFileExists() {
username: "",
password: "",
}
err := idempotentFileDownload(downloader, s.testServer.URL+"/"+testFilename, testFilename)
err := idempotentFileDownload(downloader, s.testServer.URL+"/"+testFilename, testEmptyChecksumUrl, testFilename)
assert.Nil(s.T(), err)

text, err := ioutil.ReadFile(testFilename)
Expand All @@ -111,7 +116,7 @@ func (s *HttpDownloaderTestSuite) TestIdempotentDownloadWhenCurrentFileExists()
username: "",
password: "",
}
err := idempotentFileDownload(downloader, s.testServer.URL+"/"+testFilename, testFilename)
err := idempotentFileDownload(downloader, s.testServer.URL+"/"+testFilename, testEmptyChecksumUrl, testFilename)
assert.Nil(s.T(), err)

text, err := ioutil.ReadFile(testFilename)
Expand All @@ -125,7 +130,7 @@ func (s *HttpDownloaderTestSuite) TestIdempotentDownloadWhenCurrentFileExists()
modtime := finfo.ModTime()

// Idempotent Download
err = idempotentFileDownload(downloader, s.testServer.URL+"/"+testFilename, testFilename)
err = idempotentFileDownload(downloader, s.testServer.URL+"/"+testFilename, testEmptyChecksumUrl, testFilename)
assert.Nil(s.T(), err)

newFinfo, err := os.Stat(testFilename)
Expand All @@ -140,7 +145,7 @@ func (s *HttpDownloaderTestSuite) TestIdempotentDownloadWhenOldFileExists() {
username: "",
password: "",
}
err := idempotentFileDownload(downloader, s.testServer.URL+"/"+testHashlessFilename, testFilename)
err := idempotentFileDownload(downloader, s.testServer.URL+"/"+testHashlessFilename, testEmptyChecksumUrl, testFilename)
assert.Nil(s.T(), err)

_, err = ioutil.ReadFile(testFilename)
Expand All @@ -156,7 +161,7 @@ func (s *HttpDownloaderTestSuite) TestIdempotentDownloadWhenOldFileExists() {
time.Sleep(1 * time.Second)

// Idempotent Download
err = idempotentFileDownload(downloader, s.testServer.URL+"/"+testFilename, testFilename)
err = idempotentFileDownload(downloader, s.testServer.URL+"/"+testFilename, testEmptyChecksumUrl, testFilename)
assert.Nil(s.T(), err)

newFinfo, err := os.Stat(testFilename)
Expand All @@ -166,12 +171,41 @@ func (s *HttpDownloaderTestSuite) TestIdempotentDownloadWhenOldFileExists() {
assert.NotEqual(s.T(), modtime, newModtime, "modification time should change")
}

func (s *HttpDownloaderTestSuite) TestIdempotentDownloadWhenCurrentFileExistsUsingChecksumUrl() {
downloader := httpDownloader{
username: "",
password: "",
}
err := idempotentFileDownload(downloader, s.testServer.URL+"/"+testFilename, s.testServer.URL+"/"+testChecksumUrlPath, testFilename)
assert.Nil(s.T(), err)

text, err := ioutil.ReadFile(testFilename)
assert.Nil(s.T(), err)
assert.Equal(s.T(), text, testText, "file should download correctly")

// Get original file info
finfo, err := os.Stat(testFilename)
assert.Nil(s.T(), err)

modtime := finfo.ModTime()

// Idempotent Download
err = idempotentFileDownload(downloader, s.testServer.URL+"/"+testFilename, testEmptyChecksumUrl, testFilename)
assert.Nil(s.T(), err)

newFinfo, err := os.Stat(testFilename)
newModtime := newFinfo.ModTime()

// Make sure the file didn't change
assert.Equal(s.T(), modtime, newModtime, "modification time should not change")
}

func (s *HttpDownloaderTestSuite) TestIdempotentDownloadNoRemoteHash() {
downloader := httpDownloader{
username: "",
password: "",
}
err := idempotentFileDownload(downloader, s.testServer.URL+"/"+testHashlessFilename, testHashlessFilename)
err := idempotentFileDownload(downloader, s.testServer.URL+"/"+testHashlessFilename, testEmptyChecksumUrl, testHashlessFilename)
assert.Nil(s.T(), err)

text, err := ioutil.ReadFile(testHashlessFilename)
Expand All @@ -188,7 +222,7 @@ func (s *HttpDownloaderTestSuite) TestIdempotentDownloadNoRemoteHash() {
time.Sleep(1 * time.Second)

// Idempotent Download
err = idempotentFileDownload(downloader, s.testServer.URL+"/"+testHashlessFilename, testHashlessFilename)
err = idempotentFileDownload(downloader, s.testServer.URL+"/"+testHashlessFilename, testEmptyChecksumUrl, testHashlessFilename)
assert.Nil(s.T(), err)

newFinfo, err := os.Stat(testHashlessFilename)
Expand All @@ -203,7 +237,7 @@ func (s *HttpDownloaderTestSuite) TestIdempotentDownloadBasicAuth() {
username: testBasicAuthUser,
password: testBasicAuthPass,
}
err := idempotentFileDownload(downloader, s.testServer.URL+"/"+testBasicAuthFilename, testBasicAuthFilename)
err := idempotentFileDownload(downloader, s.testServer.URL+"/"+testBasicAuthFilename, testEmptyChecksumUrl, testBasicAuthFilename)
assert.Nil(s.T(), err)

text, err := ioutil.ReadFile(testBasicAuthFilename)
Expand All @@ -216,18 +250,18 @@ func (s *HttpDownloaderTestSuite) TestIdempotentDownloadBasicAuthFailure() {
username: "nottherightuser",
password: "nottherightpass",
}
err := idempotentFileDownload(downloader, s.testServer.URL+"/"+testBasicAuthFilename, testBasicAuthFilename)
err := idempotentFileDownload(downloader, s.testServer.URL+"/"+testBasicAuthFilename, testEmptyChecksumUrl, testBasicAuthFilename)
assert.NotNil(s.T(), err)
}

func (s *HttpDownloaderTestSuite) TestIdempotentDownloadFailureFromInvalidURL() {
downloader := httpDownloader{}
err := idempotentFileDownload(downloader, "http://192.168.0.%31/invalid-url", testFilename)
err := idempotentFileDownload(downloader, "http://192.168.0.%31/invalid-url", testEmptyChecksumUrl, testFilename)
assert.NotNil(s.T(), err)
}

func (s *HttpDownloaderTestSuite) TestIdempotentDownloadFailureFromUnresponsiveServer() {
downloader := httpDownloader{}
err := idempotentFileDownload(downloader, "http://0.0.0.0/unresponsive/"+testFilename, testFilename)
err := idempotentFileDownload(downloader, "http://0.0.0.0/unresponsive/"+testFilename, testEmptyChecksumUrl, testFilename)
assert.NotNil(s.T(), err)
}
11 changes: 8 additions & 3 deletions idempotent_download.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package main
import (
"crypto/md5"
"encoding/hex"
"fmt"
"io"
"os"

Expand Down Expand Up @@ -52,9 +53,13 @@ func validateMd5Sum(path, checksum string) error {
// Checks the md5sum of the file to see if the remote file should be downloaded
//
// The MD5 checking may be an Artifactory-specific setup because it will look for the hash at "${url}.md5"
// or will look for the hash in the path provided in http-checksum-url.
// If the MD5 is not found, this will download the file
func idempotentFileDownload(downloader downloader, remotePath, localPath string) error {
logrus.Debugf("Starting idempotent download of %s to %s", remotePath, localPath)
func idempotentFileDownload(downloader downloader, remotePath, checksumURL, localPath string) error {
if len(checksumURL) == 0 {
checksumURL = fmt.Sprintf("%s.md5", remotePath)
}
logrus.Debugf("Starting idempotent download of %s to %s, remote checksum: %s", remotePath, localPath, checksumURL)

currentChecksum, err := md5sum(localPath)
if os.IsNotExist(err) {
Expand All @@ -64,7 +69,7 @@ func idempotentFileDownload(downloader downloader, remotePath, localPath string)
return errors.Wrap(err, "failed to calc local md5sum")
}

remoteChecksum, err := downloader.RemoteChecksum(remotePath)
remoteChecksum, err := downloader.RemoteChecksum(checksumURL)
if err != nil {
return errors.Wrap(err, "failed to download md5sum")
}
Expand Down
6 changes: 4 additions & 2 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ func init() {
pflag.String("http-pass", "", "HTTP password for pulling the remote file")

pflag.String("http-url", "", "Remote endpoint to retrieve the file from")
pflag.String("http-checksum-url", "", "Remote endpoint to retrieve the checksum from")
pflag.String("s3-arn", "", "Remote object ARN in S3 to retrieve")
pflag.String("s3-conn-region", "", "AWS service endpoint region for S3")

Expand Down Expand Up @@ -153,6 +154,7 @@ func ansibleEnable() {

func getAnsibleRepository(runDir string) error {
httpURL := viper.GetString("http-url")
checksumURL := viper.GetString("http-checksum-url")
s3Obj := viper.GetString("s3-arn")
s3ConnectionRegion := viper.GetString("s3-conn-region")
localCacheFile := fmt.Sprintf("/tmp/%s.tgz", appName)
Expand All @@ -167,13 +169,13 @@ func getAnsibleRepository(runDir string) error {
username: viper.GetString("http-user"),
password: viper.GetString("http-pass"),
}
err = idempotentFileDownload(downloader, remoteHttpURL, localCacheFile)
err = idempotentFileDownload(downloader, remoteHttpURL, checksumURL, localCacheFile)
} else if s3Obj != "" {
downloader, createError := createS3Downloader(s3ConnectionRegion)
if createError != nil {
return errors.Wrap(err, "unable to pull Ansible repo")
}
err = idempotentFileDownload(downloader, s3Obj, localCacheFile)
err = idempotentFileDownload(downloader, s3Obj, checksumURL, localCacheFile)
}
if err != nil {
return errors.Wrap(err, "unable to pull Ansible repo")
Expand Down
12 changes: 6 additions & 6 deletions s3_downloader.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,14 @@ import (
"os"
"regexp"

"io/ioutil"
"path/filepath"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/feature/s3/manager"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/sirupsen/logrus"
"io/ioutil"
"path/filepath"
)

type s3Downloader struct {
Expand Down Expand Up @@ -105,8 +106,7 @@ func (downloader s3Downloader) Download(remotePath, outputPath string) (err erro
return
}

func (downloader s3Downloader) RemoteChecksum(remotePath string) (string, error) {
hashRemotePath := fmt.Sprintf("%s.md5", remotePath)
func (downloader s3Downloader) RemoteChecksum(checksumURL string) (string, error) {

dir, err := ioutil.TempDir("", "*")
if err != nil {
Expand All @@ -116,13 +116,13 @@ func (downloader s3Downloader) RemoteChecksum(remotePath string) (string, error)
defer os.RemoveAll(dir)
hashFile := filepath.Join(dir, "md5Hash")

err = downloader.Download(hashRemotePath, hashFile)
err = downloader.Download(checksumURL, hashFile)
if err != nil {
logrus.Infof("MD5 sum not reachable. %v", err)
return "", nil
}

logrus.Infof("Found MD5 sum at: %s", hashRemotePath)
logrus.Infof("Found MD5 sum at: %s", checksumURL)

content, err := ioutil.ReadFile(hashFile)
if err != nil {
Expand Down

0 comments on commit a6da503

Please sign in to comment.