Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prepare internals for exposing context.Context in exported API #266

Merged
merged 2 commits into from
Feb 9, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions ability.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
package pagerduty

import "context"

// ListAbilityResponse is the response when calling the ListAbility API endpoint.
type ListAbilityResponse struct {
Abilities []string `json:"abilities"`
}

// ListAbilities lists all abilities on your account.
func (c *Client) ListAbilities() (*ListAbilityResponse, error) {
resp, err := c.get("/abilities")
resp, err := c.get(context.TODO(), "/abilities")
if err != nil {
return nil, err
}
Expand All @@ -17,6 +19,6 @@ func (c *Client) ListAbilities() (*ListAbilityResponse, error) {

// TestAbility Check if your account has the given ability.
func (c *Client) TestAbility(ability string) error {
_, err := c.get("/abilities/" + ability)
_, err := c.get(context.TODO(), "/abilities/"+ability)
return err
}
13 changes: 7 additions & 6 deletions addon.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package pagerduty

import (
"context"
"fmt"
"net/http"

Expand Down Expand Up @@ -35,7 +36,7 @@ func (c *Client) ListAddons(o ListAddonOptions) (*ListAddonResponse, error) {
if err != nil {
return nil, err
}
resp, err := c.get("/addons?" + v.Encode())
resp, err := c.get(context.TODO(), "/addons?"+v.Encode())
if err != nil {
return nil, err
}
Expand All @@ -47,8 +48,8 @@ func (c *Client) ListAddons(o ListAddonOptions) (*ListAddonResponse, error) {
func (c *Client) InstallAddon(a Addon) (*Addon, error) {
data := make(map[string]Addon)
data["addon"] = a
resp, err := c.post("/addons", data, nil)
defer resp.Body.Close()
resp, err := c.post(context.TODO(), "/addons", data, nil)
defer resp.Body.Close() // TODO(theckman): validate that this is safe
if err != nil {
return nil, err
}
Expand All @@ -60,13 +61,13 @@ func (c *Client) InstallAddon(a Addon) (*Addon, error) {

// DeleteAddon deletes an add-on from your account.
func (c *Client) DeleteAddon(id string) error {
_, err := c.delete("/addons/" + id)
_, err := c.delete(context.TODO(), "/addons/"+id)
return err
}

// GetAddon gets details about an existing add-on.
func (c *Client) GetAddon(id string) (*Addon, error) {
resp, err := c.get("/addons/" + id)
resp, err := c.get(context.TODO(), "/addons/"+id)
if err != nil {
return nil, err
}
Expand All @@ -77,7 +78,7 @@ func (c *Client) GetAddon(id string) (*Addon, error) {
func (c *Client) UpdateAddon(id string, a Addon) (*Addon, error) {
v := make(map[string]Addon)
v["addon"] = a
resp, err := c.put("/addons/"+id, v, nil)
resp, err := c.put(context.TODO(), "/addons/"+id, v, nil)
if err != nil {
return nil, err
}
Expand Down
11 changes: 6 additions & 5 deletions business_service.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package pagerduty

import (
"context"
"fmt"
"net/http"

Expand Down Expand Up @@ -75,7 +76,7 @@ func (c *Client) ListBusinessServices(o ListBusinessServiceOptions) (*ListBusine
}

// Make call to get all pages associated with the base endpoint.
if err := c.pagedGet("/business_services"+queryParms.Encode(), responseHandler); err != nil {
if err := c.pagedGet(context.TODO(), "/business_services"+queryParms.Encode(), responseHandler); err != nil {
return nil, err
}
businessServiceResponse.BusinessServices = businessServices
Expand All @@ -87,19 +88,19 @@ func (c *Client) ListBusinessServices(o ListBusinessServiceOptions) (*ListBusine
func (c *Client) CreateBusinessService(b *BusinessService) (*BusinessService, *http.Response, error) {
data := make(map[string]*BusinessService)
data["business_service"] = b
resp, err := c.post("/business_services", data, nil)
resp, err := c.post(context.TODO(), "/business_services", data, nil)
return getBusinessServiceFromResponse(c, resp, err)
}

// GetBusinessService gets details about a business service.
func (c *Client) GetBusinessService(ID string) (*BusinessService, *http.Response, error) {
resp, err := c.get("/business_services/" + ID)
resp, err := c.get(context.TODO(), "/business_services/"+ID)
return getBusinessServiceFromResponse(c, resp, err)
}

// DeleteBusinessService deletes a business_service.
func (c *Client) DeleteBusinessService(ID string) error {
_, err := c.delete("/business_services/" + ID)
_, err := c.delete(context.TODO(), "/business_services/"+ID)
return err
}

Expand All @@ -109,7 +110,7 @@ func (c *Client) UpdateBusinessService(b *BusinessService) (*BusinessService, *h
id := b.ID
b.ID = ""
v["business_service"] = b
resp, err := c.put("/business_services/"+id, v, nil)
resp, err := c.put(context.TODO(), "/business_services/"+id, v, nil)
return getBusinessServiceFromResponse(c, resp, err)
}

Expand Down
5 changes: 4 additions & 1 deletion change_events.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@ package pagerduty

import (
"bytes"
"context"
"encoding/json"
"errors"
"net/http"
)

const changeEventPath = "/v2/change/enqueue"
Expand Down Expand Up @@ -55,8 +57,9 @@ func (c *Client) CreateChangeEvent(e ChangeEvent) (*ChangeEventResponse, error)
}

resp, err := c.doWithEndpoint(
context.TODO(),
c.v2EventsAPIEndpoint,
"POST",
http.MethodPost,
changeEventPath,
false,
bytes.NewBuffer(data),
Expand Down
35 changes: 20 additions & 15 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package pagerduty

import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
Expand Down Expand Up @@ -220,36 +221,40 @@ func WithOAuth() ClientOptions {
}
}

func (c *Client) delete(path string) (*http.Response, error) {
return c.do("DELETE", path, nil, nil)
func (c *Client) delete(ctx context.Context, path string) (*http.Response, error) {
return c.do(ctx, http.MethodDelete, path, nil, nil)
}

func (c *Client) put(path string, payload interface{}, headers *map[string]string) (*http.Response, error) {
func (c *Client) put(ctx context.Context, path string, payload interface{}, headers *map[string]string) (*http.Response, error) {
if payload != nil {
data, err := json.Marshal(payload)
if err != nil {
return nil, err
}
return c.do("PUT", path, bytes.NewBuffer(data), headers)
return c.do(ctx, http.MethodPut, path, bytes.NewBuffer(data), headers)
}
return c.do("PUT", path, nil, headers)
return c.do(ctx, http.MethodPut, path, nil, headers)
}

func (c *Client) post(path string, payload interface{}, headers *map[string]string) (*http.Response, error) {
func (c *Client) post(ctx context.Context, path string, payload interface{}, headers *map[string]string) (*http.Response, error) {
data, err := json.Marshal(payload)
if err != nil {
return nil, err
}
return c.do("POST", path, bytes.NewBuffer(data), headers)
return c.do(ctx, http.MethodPost, path, bytes.NewBuffer(data), headers)
}

func (c *Client) get(path string) (*http.Response, error) {
return c.do("GET", path, nil, nil)
func (c *Client) get(ctx context.Context, path string) (*http.Response, error) {
return c.do(ctx, http.MethodGet, path, nil, nil)
}

// needed where pagerduty use a different endpoint for certain actions (eg: v2 events)
func (c *Client) doWithEndpoint(endpoint, method, path string, authRequired bool, body io.Reader, headers *map[string]string) (*http.Response, error) {
req, _ := http.NewRequest(method, endpoint+path, body)
func (c *Client) doWithEndpoint(ctx context.Context, endpoint, method, path string, authRequired bool, body io.Reader, headers *map[string]string) (*http.Response, error) {
req, err := http.NewRequestWithContext(ctx, method, endpoint+path, body)
if err != nil {
return nil, fmt.Errorf("failed to build request: %w", err)
}

req.Header.Set("Accept", "application/vnd.pagerduty+json;version=2")
if headers != nil {
for k, v := range *headers {
Expand All @@ -273,8 +278,8 @@ func (c *Client) doWithEndpoint(endpoint, method, path string, authRequired bool
return c.checkResponse(resp, err)
}

func (c *Client) do(method, path string, body io.Reader, headers *map[string]string) (*http.Response, error) {
return c.doWithEndpoint(c.apiEndpoint, method, path, true, body, headers)
func (c *Client) do(ctx context.Context, method, path string, body io.Reader, headers *map[string]string) (*http.Response, error) {
return c.doWithEndpoint(ctx, c.apiEndpoint, method, path, true, body, headers)
}

func (c *Client) decodeJSON(resp *http.Response, payload interface{}) error {
Expand Down Expand Up @@ -330,7 +335,7 @@ func (c *Client) getErrorFromResponse(resp *http.Response) APIError {
// a specific slice. The responseHandler is responsible for closing the response.
type responseHandler func(response *http.Response) (APIListObject, error)

func (c *Client) pagedGet(basePath string, handler responseHandler) error {
func (c *Client) pagedGet(ctx context.Context, basePath string, handler responseHandler) error {
// Indicates whether there are still additional pages associated with request.
var stillMore bool

Expand All @@ -339,7 +344,7 @@ func (c *Client) pagedGet(basePath string, handler responseHandler) error {

// While there are more pages, keep adjusting the offset to get all results.
for stillMore, nextOffset = true, 0; stillMore; {
response, err := c.do("GET", fmt.Sprintf("%s?offset=%d", basePath, nextOffset), nil, nil)
response, err := c.do(ctx, http.MethodGet, fmt.Sprintf("%s?offset=%d", basePath, nextOffset), nil, nil)
if err != nil {
return err
}
Expand Down
21 changes: 11 additions & 10 deletions escalation_policy.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package pagerduty

import (
"context"
"fmt"
"net/http"

Expand Down Expand Up @@ -62,7 +63,7 @@ func (c *Client) ListEscalationPolicies(o ListEscalationPoliciesOptions) (*ListE
if err != nil {
return nil, err
}
resp, err := c.get(escPath + "?" + v.Encode())
resp, err := c.get(context.TODO(), escPath+"?"+v.Encode())
if err != nil {
return nil, err
}
Expand All @@ -74,13 +75,13 @@ func (c *Client) ListEscalationPolicies(o ListEscalationPoliciesOptions) (*ListE
func (c *Client) CreateEscalationPolicy(e EscalationPolicy) (*EscalationPolicy, error) {
data := make(map[string]EscalationPolicy)
data["escalation_policy"] = e
resp, err := c.post(escPath, data, nil)
resp, err := c.post(context.TODO(), escPath, data, nil)
return getEscalationPolicyFromResponse(c, resp, err)
}

// DeleteEscalationPolicy deletes an existing escalation policy and rules.
func (c *Client) DeleteEscalationPolicy(id string) error {
_, err := c.delete(escPath + "/" + id)
_, err := c.delete(context.TODO(), escPath+"/"+id)
return err
}

Expand All @@ -95,15 +96,15 @@ func (c *Client) GetEscalationPolicy(id string, o *GetEscalationPolicyOptions) (
if err != nil {
return nil, err
}
resp, err := c.get(escPath + "/" + id + "?" + v.Encode())
resp, err := c.get(context.TODO(), escPath+"/"+id+"?"+v.Encode())
return getEscalationPolicyFromResponse(c, resp, err)
}

// UpdateEscalationPolicy updates an existing escalation policy and its rules.
func (c *Client) UpdateEscalationPolicy(id string, e *EscalationPolicy) (*EscalationPolicy, error) {
data := make(map[string]EscalationPolicy)
data["escalation_policy"] = *e
resp, err := c.put(escPath+"/"+id, data, nil)
resp, err := c.put(context.TODO(), escPath+"/"+id, data, nil)
return getEscalationPolicyFromResponse(c, resp, err)
}

Expand All @@ -112,7 +113,7 @@ func (c *Client) UpdateEscalationPolicy(id string, e *EscalationPolicy) (*Escala
func (c *Client) CreateEscalationRule(escID string, e EscalationRule) (*EscalationRule, error) {
data := make(map[string]EscalationRule)
data["escalation_rule"] = e
resp, err := c.post(escPath+"/"+escID+"/escalation_rules", data, nil)
resp, err := c.post(context.TODO(), escPath+"/"+escID+"/escalation_rules", data, nil)
return getEscalationRuleFromResponse(c, resp, err)
}

Expand All @@ -122,27 +123,27 @@ func (c *Client) GetEscalationRule(escID string, id string, o *GetEscalationRule
if err != nil {
return nil, err
}
resp, err := c.get(escPath + "/" + escID + "/escalation_rules/" + id + "?" + v.Encode())
resp, err := c.get(context.TODO(), escPath+"/"+escID+"/escalation_rules/"+id+"?"+v.Encode())
return getEscalationRuleFromResponse(c, resp, err)
}

// DeleteEscalationRule deletes an existing escalation rule.
func (c *Client) DeleteEscalationRule(escID string, id string) error {
_, err := c.delete(escPath + "/" + escID + "/escalation_rules/" + id)
_, err := c.delete(context.TODO(), escPath+"/"+escID+"/escalation_rules/"+id)
return err
}

// UpdateEscalationRule updates an existing escalation rule.
func (c *Client) UpdateEscalationRule(escID string, id string, e *EscalationRule) (*EscalationRule, error) {
data := make(map[string]EscalationRule)
data["escalation_rule"] = *e
resp, err := c.put(escPath+"/"+escID+"/escalation_rules/"+id, data, nil)
resp, err := c.put(context.TODO(), escPath+"/"+escID+"/escalation_rules/"+id, data, nil)
return getEscalationRuleFromResponse(c, resp, err)
}

// ListEscalationRules lists all of the escalation rules for an existing escalation policy.
func (c *Client) ListEscalationRules(escID string) (*ListEscalationRulesResponse, error) {
resp, err := c.get(escPath + "/" + escID + "/escalation_rules")
resp, err := c.get(context.TODO(), escPath+"/"+escID+"/escalation_rules")
if err != nil {
return nil, err
}
Expand Down
13 changes: 11 additions & 2 deletions event_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package pagerduty

import (
"bytes"
"context"
"encoding/json"
"fmt"
"io/ioutil"
Expand Down Expand Up @@ -48,9 +49,16 @@ func ManageEvent(e V2Event) (*V2EventResponse, error) {
if err != nil {
return nil, err
}
req, _ := http.NewRequest("POST", v2eventEndPoint, bytes.NewBuffer(data))

req, err := http.NewRequestWithContext(context.TODO(), http.MethodPost, v2eventEndPoint, bytes.NewBuffer(data))
if err != nil {
return nil, fmt.Errorf("failed to create HTTP request: %w", err)
}

req.Header.Set("User-Agent", "go-pagerduty/"+Version)
req.Header.Set("Content-Type", "application/json")

// TODO(theckman): switch to a package-local default client
resp, err := http.DefaultClient.Do(req)
if err != nil {
return nil, err
Expand Down Expand Up @@ -78,7 +86,8 @@ func (c *Client) ManageEvent(e *V2Event) (*V2EventResponse, error) {
if err != nil {
return nil, err
}
resp, err := c.doWithEndpoint(c.v2EventsAPIEndpoint, "POST", "/v2/enqueue", false, bytes.NewBuffer(data), &headers)

resp, err := c.doWithEndpoint(context.TODO(), c.v2EventsAPIEndpoint, http.MethodPost, "/v2/enqueue", false, bytes.NewBuffer(data), &headers)
if err != nil {
return nil, err
}
Expand Down
Loading