diff --git a/github.go b/github.go index c5f5f23..7c7ae15 100644 --- a/github.go +++ b/github.go @@ -1,6 +1,7 @@ package resource import ( + "context" "crypto/tls" "errors" "fmt" @@ -11,8 +12,6 @@ import ( "golang.org/x/oauth2" - "context" - "github.com/google/go-github/github" ) @@ -37,10 +36,11 @@ type GitHub interface { } type GitHubClient struct { - client *github.Client + client *github.Client - owner string - repository string + owner string + repository string + accessToken string } func NewGitHubClient(source Source) (*GitHubClient, error) { @@ -92,9 +92,10 @@ func NewGitHubClient(source Source) (*GitHubClient, error) { } return &GitHubClient{ - client: client, - owner: owner, - repository: source.Repository, + client: client, + owner: owner, + repository: source.Repository, + accessToken: source.AccessToken, }, nil } @@ -232,12 +233,23 @@ func (g *GitHubClient) DownloadReleaseAsset(asset github.ReleaseAsset) (io.ReadC return bodyReader, err } - resp, err := http.Get(redirectURL) + req, err := g.client.NewRequest("GET", redirectURL, nil) + if err != nil { + return nil, err + } + req.Header.Set("Accept", "application/octet-stream") + if g.accessToken != "" && req.URL.Host == g.client.BaseURL.Host { + req.Header.Set("Authorization", "Bearer " + g.accessToken) + } + + httpClient := &http.Client{} + resp, err := httpClient.Do(req) if err != nil { return nil, err } if resp.StatusCode < 200 || resp.StatusCode > 299 { + resp.Body.Close() return nil, fmt.Errorf("redirect URL %q responded with bad status code: %d", redirectURL, resp.StatusCode) } diff --git a/github_test.go b/github_test.go index 0e6b7c9..b6f6669 100644 --- a/github_test.go +++ b/github_test.go @@ -368,20 +368,32 @@ var _ = Describe("GitHub Client", func() { BeforeEach(func() { source.Owner = owner source.Repository = repo + source.AccessToken = "abc123" assetID = 42 asset = github.ReleaseAsset{ID: &assetID} assetPath = fmt.Sprintf("/repos/%s/%s/releases/assets/%d", owner, repo, assetID) }) - var appendGetHandler = func(server *ghttp.Server, path string, statusCode int, body string, headers ...http.Header) { - server.AppendHandlers( - ghttp.CombineHandlers( - ghttp.VerifyRequest("GET", path), + var appendGetHandler = func(server *ghttp.Server, path string, statusCode int, body string, usesAuth bool, headers ...http.Header) { + var authHeaderValue []string + if usesAuth { + authHeaderValue = []string{"Bearer abc123"} + } + server.AppendHandlers(ghttp.CombineHandlers( + ghttp.VerifyRequest("GET", fmt.Sprintf("%s", path)), ghttp.RespondWith(statusCode, body, headers...), - ), - ) + ghttp.VerifyHeaderKV("Accept", "application/octet-stream"), + ghttp.VerifyHeaderKV("Authorization", authHeaderValue...), + )) } + var locationHeader = func(url string) http.Header { + header := make(http.Header) + header.Add("Location", url) + return header + } + + Context("when the asset can be downloaded directly", func() { Context("when the asset is downloaded successfully", func() { const ( @@ -389,7 +401,7 @@ var _ = Describe("GitHub Client", func() { ) BeforeEach(func() { - appendGetHandler(server, assetPath, 200, fileContents) + appendGetHandler(server, assetPath, 200, fileContents, true) }) It("returns the correct body", func() { @@ -405,7 +417,7 @@ var _ = Describe("GitHub Client", func() { Context("when there is an error downloading the asset", func() { BeforeEach(func() { - appendGetHandler(server, assetPath, 401, "authorized personnel only") + appendGetHandler(server, assetPath, 401, "authorized personnel only", true) }) It("returns an error", func() { @@ -416,18 +428,10 @@ var _ = Describe("GitHub Client", func() { }) Context("when the asset is behind a redirect", func() { - const ( - redirectPath = "/the/redirect/path" - ) - - var locationHeader = func(url string) http.Header { - header := make(http.Header) - header.Add("Location", url) - return header - } + const redirectPath = "/the/redirect/path" BeforeEach(func() { - appendGetHandler(server, assetPath, 307, "", locationHeader(redirectPath)) + appendGetHandler(server, assetPath, 307, "", true, locationHeader(redirectPath)) }) Context("when the redirect succeeds", func() { @@ -436,7 +440,7 @@ var _ = Describe("GitHub Client", func() { ) BeforeEach(func() { - appendGetHandler(server, redirectPath, 200, redirectFileContents) + appendGetHandler(server, redirectPath, 200, redirectFileContents, true) }) It("returns the body from the redirect request", func() { @@ -457,8 +461,8 @@ var _ = Describe("GitHub Client", func() { ) BeforeEach(func() { - appendGetHandler(server, redirectPath, 307, "", locationHeader("/somewhere-else")) - appendGetHandler(server, "/somewhere-else", 200, redirectFileContents) + appendGetHandler(server, redirectPath, 307, "", true, locationHeader("/somewhere-else")) + appendGetHandler(server, "/somewhere-else", 200, redirectFileContents, true) }) It("returns the body from the final redirect request", func() { @@ -472,9 +476,34 @@ var _ = Describe("GitHub Client", func() { }) }) + Context("when there is another redirect to an external server", func() { + const ( + redirectFileContents = "some-random-contents-from-redirect" + ) + + var externalServer *ghttp.Server + + BeforeEach(func() { + externalServer = ghttp.NewServer() + + appendGetHandler(server, redirectPath, 307, "", true, locationHeader(externalServer.URL() + "/somewhere-else")) + appendGetHandler(externalServer, "/somewhere-else", 200, redirectFileContents, false) + }) + + It("downloads the file without the Authorization header", func() { + readCloser, err := client.DownloadReleaseAsset(asset) + Expect(err).NotTo(HaveOccurred()) + defer readCloser.Close() + + body, err := ioutil.ReadAll(readCloser) + Expect(err).NotTo(HaveOccurred()) + Expect(string(body)).To(Equal(redirectFileContents)) + }) + }) + Context("when the redirect request response is a 400", func() { BeforeEach(func() { - appendGetHandler(server, redirectPath, 400, "oops") + appendGetHandler(server, redirectPath, 400, "oops", true) }) It("returns an error", func() { @@ -485,7 +514,7 @@ var _ = Describe("GitHub Client", func() { Context("when the redirect request response is a 401", func() { BeforeEach(func() { - appendGetHandler(server, redirectPath, 401, "authorized personnel only") + appendGetHandler(server, redirectPath, 401, "authorized personnel only", true) }) It("returns an error", func() { @@ -497,7 +526,7 @@ var _ = Describe("GitHub Client", func() { Context("when the redirect request response is a 403", func() { BeforeEach(func() { - appendGetHandler(server, redirectPath, 403, "authorized personnel only") + appendGetHandler(server, redirectPath, 403, "authorized personnel only", true) }) It("returns an error", func() { @@ -508,7 +537,7 @@ var _ = Describe("GitHub Client", func() { Context("when the redirect request response is a 404", func() { BeforeEach(func() { - appendGetHandler(server, redirectPath, 404, "I don't know her") + appendGetHandler(server, redirectPath, 404, "I don't know her", true) }) It("returns an error", func() { @@ -519,7 +548,7 @@ var _ = Describe("GitHub Client", func() { Context("when the redirect request response is a 500", func() { BeforeEach(func() { - appendGetHandler(server, redirectPath, 500, "boom") + appendGetHandler(server, redirectPath, 500, "boom", true) }) It("returns an error", func() { @@ -528,5 +557,30 @@ var _ = Describe("GitHub Client", func() { }) }) }) + + Context("when the asset is behind a redirect on an external server", func() { + const ( + redirectFileContents = "some-random-contents-from-redirect" + ) + + var externalServer *ghttp.Server + + BeforeEach(func() { + externalServer = ghttp.NewServer() + + appendGetHandler(server, assetPath, 307, "", true, locationHeader(externalServer.URL() + "/somewhere-else")) + appendGetHandler(externalServer, "/somewhere-else", 200, redirectFileContents, false) + }) + + It("downloads the file without the Authorization header", func() { + readCloser, err := client.DownloadReleaseAsset(asset) + Expect(err).NotTo(HaveOccurred()) + defer readCloser.Close() + + body, err := ioutil.ReadAll(readCloser) + Expect(err).NotTo(HaveOccurred()) + Expect(string(body)).To(Equal(redirectFileContents)) + }) + }) }) })