diff --git a/methods/delete.go b/methods/delete.go index 748b27e..aa4eb67 100644 --- a/methods/delete.go +++ b/methods/delete.go @@ -10,9 +10,9 @@ import ( //Deletebasic sends a basic DELETE request func Deletebasic(c *cli.Context) error { - url := c.Args().Get(0) - if url == "" { - fmt.Print("URL is needed") + url, err := checkURL(c.Args().Get(0)) + if err != nil { + fmt.Printf("%s\n", err.Error()) return nil } var jsonStr = []byte(c.String("body")) diff --git a/methods/get.go b/methods/get.go index cf85098..d3f97ca 100644 --- a/methods/get.go +++ b/methods/get.go @@ -12,9 +12,9 @@ import ( //Getbasic sends a simple GET request to the url with any potential parameters like Tokens or Basic Auth func Getbasic(c *cli.Context) { - var url = c.Args().Get(0) - if url == "" { - fmt.Print("URL is needed") + var url, err = checkURL(c.Args().Get(0)) + if err != nil { + fmt.Printf("%s\n", err.Error()) os.Exit(0) } req, err := http.NewRequest("GET", url, nil) diff --git a/methods/patch.go b/methods/patch.go index 0ef20c1..7394fc6 100644 --- a/methods/patch.go +++ b/methods/patch.go @@ -11,9 +11,9 @@ import ( //Patchbasic sends a basic PATCH request func Patchbasic(c *cli.Context) { - url := c.Args().Get(0) - if url == "" { - fmt.Print("URL is needed") + url, err := checkURL(c.Args().Get(0)) + if err != nil { + fmt.Printf("%s\n", err.Error()) os.Exit(0) } var jsonStr = []byte(c.String("body")) diff --git a/methods/post.go b/methods/post.go index 85cd843..74d369b 100644 --- a/methods/post.go +++ b/methods/post.go @@ -11,9 +11,9 @@ import ( //Postbasic sends a basic POST request func Postbasic(c *cli.Context) { - url := c.Args().Get(0) - if url == "" { - fmt.Print("URL is needed") + url, err := checkURL(c.Args().Get(0)) + if err != nil { + fmt.Printf("%s\n", err.Error()) os.Exit(0) } var jsonStr = []byte(c.String("body")) diff --git a/methods/put.go b/methods/put.go index 6794656..bf6e934 100644 --- a/methods/put.go +++ b/methods/put.go @@ -10,9 +10,9 @@ import ( //Putbasic sends a basic PUT request func Putbasic(c *cli.Context) error { - url := c.Args().Get(0) - if url == "" { - fmt.Print("URL is needed") + url, err := checkURL(c.Args().Get(0)) + if err != nil { + fmt.Printf("%s\n", err.Error()) return nil } var jsonStr = []byte(c.String("body")) diff --git a/methods/validation.go b/methods/validation.go new file mode 100644 index 0000000..022932d --- /dev/null +++ b/methods/validation.go @@ -0,0 +1,24 @@ +package methods + +import ( + "fmt" + "net/url" + "strings" +) + +func checkURL(urlStr string) (string, error) { + if urlStr == "" { + return "", fmt.Errorf("URL is needed") + } + + prefixCheck := strings.HasPrefix(urlStr, "http://") || strings.HasPrefix(urlStr, "https://") + if !prefixCheck { + return "", fmt.Errorf("URL missing protocol or contains invalid protocol") + } + + if _, err := url.Parse(urlStr); err != nil { + return "", fmt.Errorf("URL is invalid") + } + + return urlStr, nil +} diff --git a/methods/validation_test.go b/methods/validation_test.go new file mode 100644 index 0000000..fa1f981 --- /dev/null +++ b/methods/validation_test.go @@ -0,0 +1,50 @@ +package methods + +import "testing" + +var checkURLCases = []struct { + name string + url string + expectErr bool +}{ + { + name: "valid http://", + url: "http://example.com", + expectErr: false, + }, + { + name: "valid https://", + url: "https://example.com", + expectErr: false, + }, + { + name: "empty url", + url: "", + expectErr: true, + }, + { + name: "invalid protocol", + url: "htp://example.com", + expectErr: true, + }, + { + name: "disallowed protocol", + url: "irc://example.com", + expectErr: true, + }, +} + +func Test_checkURL(t *testing.T) { + for _, tt := range checkURLCases { + out, err := checkURL(tt.url) + if err != nil && !tt.expectErr { + t.Errorf("%s :: %s", tt.name, err.Error()) + } + if out != tt.url && !tt.expectErr { + t.Errorf("URL mangled. Got %s - expected %s", out, tt.url) + } + if out != "" && err != nil && tt.expectErr { + t.Errorf("Didn't fail when expected") + } + } +}