diff --git a/README.md b/README.md index e1c6dd8..3d81325 100644 --- a/README.md +++ b/README.md @@ -11,10 +11,12 @@ Exported variables and functions implemented till now : ```go var Headers map[string]string // Set headers as a map of key-value pairs, an alternative to calling Header() individually var Cookies map[string]string // Set cookies as a map of key-value pairs, an alternative to calling Cookie() individually -func Get(string) (string,error){} // Takes the url as an argument, returns HTML string -func GetWithClient(string, *http.Client){} // Takes the url and a custom HTTP client as arguments, returns HTML string -func Header(string, string){} // Takes key,value pair to set as headers for the HTTP request made in Get() -func Cookie(string, string){} // Takes key, value pair to set as cookies to be sent with the HTTP request in Get() +func Get(string) (string,error) {} // Takes the url as an argument, returns HTML string +func GetWithClient(string, *http.Client) {} // Takes the url and a custom HTTP client as arguments, returns HTML string +func Post(string, string, interface{}) (string, error) {} // Takes the url, bodyType, and payload as an argument, returns HTML string +func PostForm(string, url.Values) {} // Takes the url and body. bodyType is set to "application/x-www-form-urlencoded" +func Header(string, string) {} // Takes key,value pair to set as headers for the HTTP request made in Get() +func Cookie(string, string) {} // Takes key, value pair to set as cookies to be sent with the HTTP request in Get() func HTMLParse(string) Root {} // Takes the HTML string as an argument, returns a pointer to the DOM constructed func Find([]string) Root {} // Element tag,(attribute key-value pair) as argument, pointer to first occurence returned func FindAll([]string) []Root {} // Same as Find(), but pointers to all occurrences returned diff --git a/soup.go b/soup.go index 4b5a8d0..d75a622 100644 --- a/soup.go +++ b/soup.go @@ -6,9 +6,14 @@ package soup import ( "bytes" + "encoding/json" "fmt" + "io" "io/ioutil" "net/http" + "net/http/httputil" + "net/url" + netURL "net/url" "regexp" "strings" @@ -35,6 +40,10 @@ const ( ErrCreatingGetRequest // ErrInGetRequest will be returned when there was an error during the get request ErrInGetRequest + // ErrCreatingPostRequest will be returned when the post request couldn't be created + ErrCreatingPostRequest + // ErrMarshallingPostRequest will be returned when the body of a post request couldn't be serialized + ErrMarshallingPostRequest // ErrReadingResponse will be returned if there was an error reading the response to our get request ErrReadingResponse ) @@ -99,10 +108,34 @@ func GetWithClient(url string, client *http.Client) (string, error) { req, err := http.NewRequest("GET", url, nil) if err != nil { if debug { - panic("Couldn't perform GET request to " + url) + panic("Couldn't create GET request to " + url) } return "", newError(ErrCreatingGetRequest, "error creating get request to "+url) } + + setHeadersAndCookies(req) + + // Perform request + resp, err := client.Do(req) + if err != nil { + if debug { + panic("Couldn't perform GET request to " + url) + } + return "", newError(ErrInGetRequest, "couldn't perform GET request to "+url) + } + defer resp.Body.Close() + bytes, err := ioutil.ReadAll(resp.Body) + if err != nil { + if debug { + panic("Unable to read the response body") + } + return "", newError(ErrReadingResponse, "unable to read the response body") + } + return string(bytes), nil +} + +// setHeadersAndCookies helps build a request +func setHeadersAndCookies(req *http.Request) { // Set headers for hName, hValue := range Headers { req.Header.Set(hName, hValue) @@ -114,13 +147,66 @@ func GetWithClient(url string, client *http.Client) (string, error) { Value: cValue, }) } +} + +// getBodyReader serializes the body for a network request. See the test file for examples +func getBodyReader(rawBody interface{}) (io.Reader, error) { + var bodyReader io.Reader + + if rawBody != nil { + switch body := rawBody.(type) { + case map[string]string: + jsonBody, err := json.Marshal(body) + if err != nil { + if debug { + panic("Unable to read the response body") + } + return nil, newError(ErrMarshallingPostRequest, "couldn't serialize map of strings to JSON.") + } + bodyReader = bytes.NewBuffer(jsonBody) + case netURL.Values: + bodyReader = strings.NewReader(body.Encode()) + case []byte: //expects JSON format + bodyReader = bytes.NewBuffer(body) + case string: //expects JSON format + bodyReader = strings.NewReader(body) + default: + return nil, newError(ErrMarshallingPostRequest, fmt.Sprintf("Cannot handle body type %T", rawBody)) + } + } + + return bodyReader, nil +} + +// PostWithClient returns the HTML returned by the url using a provided HTTP client +// The type of the body must conform to one of the types listed in func getBodyReader() +func PostWithClient(url string, bodyType string, body interface{}, client *http.Client) (string, error) { + bodyReader, err := getBodyReader(body) + if err != nil { + return "todo:", err + } + + req, err := http.NewRequest("POST", url, bodyReader) + Header("Content-Type", bodyType) + setHeadersAndCookies(req) + + if debug { + // Save a copy of this request for debugging. + requestDump, err := httputil.DumpRequest(req, true) + if err != nil { + fmt.Println(err) + } + fmt.Println(string(requestDump)) + } + // Perform request resp, err := client.Do(req) + if err != nil { if debug { - panic("Couldn't perform GET request to " + url) + panic("Couldn't perform POST request to " + url) } - return "", newError(ErrInGetRequest, "couldn't perform GET request to "+url) + return "", newError(ErrCreatingPostRequest, "couldn't perform POST request to"+url) } defer resp.Body.Close() bytes, err := ioutil.ReadAll(resp.Body) @@ -133,7 +219,17 @@ func GetWithClient(url string, client *http.Client) (string, error) { return string(bytes), nil } -// Get returns the HTML returned by the url in string using the default HTTP client +// Post returns the HTML returned by the url as a string using the default HTTP client +func Post(url string, bodyType string, body interface{}) (string, error) { + return PostWithClient(url, bodyType, body, defaultClient) +} + +// PostForm is a convenience method for POST requests that +func PostForm(url string, data url.Values) (string, error) { + return PostWithClient(url, "application/x-www-form-urlencoded", data, defaultClient) +} + +// Get returns the HTML returned by the url as a string using the default HTTP client func Get(url string) (string, error) { return GetWithClient(url, defaultClient) } diff --git a/soup_test.go b/soup_test.go index 64a7b8d..d2010d4 100644 --- a/soup_test.go +++ b/soup_test.go @@ -1,6 +1,11 @@ package soup import ( + "bytes" + "io/ioutil" + "net/http" + "net/http/httptest" + "net/url" "strconv" "strings" "testing" @@ -177,7 +182,113 @@ func TestFindReturnsInspectableError(t *testing.T) { assert.Equal(t, ErrElementNotFound, r.Error.(Error).Type) } +// Similar test: https://github.com/hashicorp/go-retryablehttp/blob/master/client_test.go#L616 +func TestClient_Post(t *testing.T) { + // Mock server which always responds 200. + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" { + t.Fatalf("bad method: %s", r.Method) + } + if r.RequestURI != "/foo/bar" { + t.Fatalf("bad uri: %s", r.RequestURI) + } + if ct := r.Header.Get("Content-Type"); ct != "application/json" { + t.Fatalf("bad content-type: %s", ct) + } + + // Check the payload + body, err := ioutil.ReadAll(r.Body) + if err != nil { + t.Fatalf("err: %s", err) + } + expected := []byte(`{"hello":"world"}`) + if !bytes.Equal(body, expected) { + t.Fatalf("bad: %v", string(body)) + } + + w.WriteHeader(200) + })) + defer ts.Close() + + // Make the request with JSON payload + _, err := Post( + ts.URL+"/foo/bar", "application/json", `{"hello":"world"}`) + if err != nil { + t.Fatalf("err: %v", err) + } + + // Make the request with byte payload + _, err = Post( + ts.URL+"/foo/bar", "application/json", []byte(`{"hello":"world"}`)) + if err != nil { + t.Fatalf("err: %v", err) + } + + // Make the request with string map payload + _, err = Post( + ts.URL+"/foo/bar", + "application/json", + map[string]string{ + "hello": "world", + }) + if err != nil { + t.Fatalf("err: %v", err) + } +} + +// Similar test: https://github.com/hashicorp/go-retryablehttp/blob/add-circleci/client_test.go#L631 +func TestClient_PostForm(t *testing.T) { + // Mock server which always responds 200. + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" { + t.Fatalf("bad method: %s", r.Method) + } + if r.RequestURI != "/foo/bar" { + t.Fatalf("bad uri: %s", r.RequestURI) + } + if ct := r.Header.Get("Content-Type"); ct != "application/x-www-form-urlencoded" { + t.Fatalf("bad content-type: %s", ct) + } + + // Check the payload + body, err := ioutil.ReadAll(r.Body) + if err != nil { + t.Fatalf("err: %s", err) + } + expected := []byte(`hello=world`) + if !bytes.Equal(body, expected) { + t.Fatalf("bad: %v", string(body)) + } + + w.WriteHeader(200) + })) + defer ts.Close() + + // Create the form data. + form1, err := url.ParseQuery("hello=world") + if err != nil { + t.Fatalf("err: %v", err) + } + + form2 := url.Values{ + "hello": []string{"world"}, + } + + // Make the request. + _, err = PostForm(ts.URL+"/foo/bar", form1) + if err != nil { + t.Fatalf("err: %v", err) + } + + // Make the request. + _, err = PostForm(ts.URL+"/foo/bar", form2) + if err != nil { + t.Fatalf("err: %v", err) + } +} + func TestHTML(t *testing.T) { li := doc.Find("ul").Find("li") assert.Equal(t, "
  • To a JSP page right?
  • ", li.HTML()) + }