diff --git a/internal/publicip/api/api.go b/internal/publicip/api/api.go index e7774a6c3..c44f580eb 100644 --- a/internal/publicip/api/api.go +++ b/internal/publicip/api/api.go @@ -3,7 +3,11 @@ package api import ( "errors" "fmt" + "maps" "net/http" + "net/url" + "regexp" + "slices" "strings" ) @@ -16,6 +20,8 @@ const ( IP2Location Provider = "ip2location" ) +const echoipPrefix = "echoip#" + type NameToken struct { Name string Token string @@ -30,15 +36,19 @@ func New(nameTokenPairs []NameToken, client *http.Client) ( if err != nil { return nil, fmt.Errorf("parsing API name: %w", err) } - switch provider { - case Cloudflare: + switch { + case provider == Cloudflare: fetchers[i] = newCloudflare(client) - case IfConfigCo: - fetchers[i] = newIfConfigCo(client) - case IPInfo: + case provider == IfConfigCo: + const ifConfigCoURL = "https://ifconfig.co" + fetchers[i] = newEchoip(client, ifConfigCoURL) + case provider == IPInfo: fetchers[i] = newIPInfo(client, nameTokenPair.Token) - case IP2Location: + case provider == IP2Location: fetchers[i] = newIP2Location(client, nameTokenPair.Token) + case strings.HasPrefix(string(provider), echoipPrefix): + url := strings.TrimPrefix(string(provider), echoipPrefix) + fetchers[i] = newEchoip(client, url) default: panic("provider not valid: " + provider) } @@ -46,20 +56,88 @@ func New(nameTokenPairs []NameToken, client *http.Client) ( return fetchers, nil } +var regexEchoipURL = regexp.MustCompile(`^http(s|):\/\/.+$`) + var ErrProviderNotValid = errors.New("API name is not valid") func ParseProvider(s string) (provider Provider, err error) { - switch strings.ToLower(s) { - case "cloudflare": - return Cloudflare, nil - case string(IfConfigCo): - return IfConfigCo, nil - case "ipinfo": - return IPInfo, nil - case "ip2location": - return IP2Location, nil - default: - return "", fmt.Errorf(`%w: %q can only be "cloudflare", "ifconfigco", "ip2location" or "ipinfo"`, - ErrProviderNotValid, s) + possibleProviders := []Provider{ + Cloudflare, + IfConfigCo, + IP2Location, + IPInfo, + } + stringToProvider := make(map[string]Provider, len(possibleProviders)) + for _, provider := range possibleProviders { + stringToProvider[string(provider)] = provider + } + provider, ok := stringToProvider[strings.ToLower(s)] + if ok { + return provider, nil + } + + customPrefixToURLRegex := map[string]*regexp.Regexp{ + echoipPrefix: regexEchoipURL, + } + for prefix, urlRegex := range customPrefixToURLRegex { + match, err := checkCustomURL(s, prefix, urlRegex) + if !match { + continue + } else if err != nil { + return "", err + } + return Provider(s), nil } + + providerStrings := make([]string, 0, len(stringToProvider)+len(customPrefixToURLRegex)) + for _, providerString := range slices.Sorted(maps.Keys(stringToProvider)) { + providerStrings = append(providerStrings, `"`+providerString+`"`) + } + for _, prefix := range slices.Sorted(maps.Keys(customPrefixToURLRegex)) { + providerStrings = append(providerStrings, "a custom "+prefix+" url") + } + + return "", fmt.Errorf(`%w: %q can only be %s`, + ErrProviderNotValid, s, orStrings(providerStrings)) +} + +var ErrCustomURLNotValid = errors.New("custom URL is not valid") + +func checkCustomURL(s, prefix string, regex *regexp.Regexp) (match bool, err error) { + if !strings.HasPrefix(s, prefix) { + return false, nil + } + s = strings.TrimPrefix(s, prefix) + _, err = url.Parse(s) + if err != nil { + return true, fmt.Errorf("%s %w: %w", prefix, ErrCustomURLNotValid, err) + } + + if regex.MatchString(s) { + return true, nil + } + + return true, fmt.Errorf("%s %w: %q does not match regular expression: %s", + prefix, ErrCustomURLNotValid, s, regex) +} + +func orStrings(strings []string) (result string) { + return joinStrings(strings, "or") +} + +func joinStrings(strings []string, lastJoin string) (result string) { + if len(strings) == 0 { + return "" + } + + result = strings[0] + for i := 1; i < len(strings); i++ { + if i < len(strings)-1 { + result += ", " + strings[i] + } else { + result += " " + lastJoin + " " + strings[i] + } + } + + return result } diff --git a/internal/publicip/api/api_test.go b/internal/publicip/api/api_test.go new file mode 100644 index 000000000..5d4ddedca --- /dev/null +++ b/internal/publicip/api/api_test.go @@ -0,0 +1,68 @@ +package api + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_ParseProvider(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + s string + provider Provider + errWrapped error + errMessage string + }{ + "empty": { + errWrapped: ErrProviderNotValid, + errMessage: `API name is not valid: "" can only be ` + + `"cloudflare", "ifconfigco", "ip2location", "ipinfo" or a custom echoip# url`, + }, + "invalid": { + s: "xyz", + errWrapped: ErrProviderNotValid, + errMessage: `API name is not valid: "xyz" can only be ` + + `"cloudflare", "ifconfigco", "ip2location", "ipinfo" or a custom echoip# url`, + }, + "ipinfo": { + s: "ipinfo", + provider: IPInfo, + }, + "IpInfo": { + s: "IpInfo", + provider: IPInfo, + }, + "echoip_url_empty": { + s: "echoip#", + errWrapped: ErrCustomURLNotValid, + errMessage: `echoip# custom URL is not valid: "" ` + + `does not match regular expression: ^http(s|):\/\/.+$`, + }, + "echoip_url_invalid": { + s: "echoip#postgres://localhost:3451", + errWrapped: ErrCustomURLNotValid, + errMessage: `echoip# custom URL is not valid: "postgres://localhost:3451" ` + + `does not match regular expression: ^http(s|):\/\/.+$`, + }, + "echoip_url_valid": { + s: "echoip#http://localhost:3451", + provider: Provider("echoip#http://localhost:3451"), + }, + } + + for name, testCase := range testCases { + t.Run(name, func(t *testing.T) { + t.Parallel() + + provider, err := ParseProvider(testCase.s) + + assert.Equal(t, testCase.provider, provider) + assert.ErrorIs(t, err, testCase.errWrapped) + if testCase.errWrapped != nil { + assert.EqualError(t, err, testCase.errMessage) + } + }) + } +} diff --git a/internal/publicip/api/ifconfigco.go b/internal/publicip/api/echoip.go similarity index 78% rename from internal/publicip/api/ifconfigco.go rename to internal/publicip/api/echoip.go index c0d8fdcbc..07351953e 100644 --- a/internal/publicip/api/ifconfigco.go +++ b/internal/publicip/api/echoip.go @@ -6,39 +6,45 @@ import ( "fmt" "net/http" "net/netip" + "strings" "github.com/qdm12/gluetun/internal/models" ) -type ifConfigCo struct { +type echoip struct { client *http.Client + url string } -func newIfConfigCo(client *http.Client) *ifConfigCo { - return &ifConfigCo{ +func newEchoip(client *http.Client, url string) *echoip { + return &echoip{ client: client, + url: url, } } -func (i *ifConfigCo) String() string { - return string(IfConfigCo) +func (e *echoip) String() string { + s := e.url + s = strings.TrimPrefix(s, "http://") + s = strings.TrimPrefix(s, "https://") + return s } -func (i *ifConfigCo) CanFetchAnyIP() bool { +func (e *echoip) CanFetchAnyIP() bool { return true } -func (i *ifConfigCo) Token() string { +func (e *echoip) Token() string { return "" } // FetchInfo obtains information on the ip address provided -// using the ifconfig.co/json API. If the ip is the zero value, +// using the echoip API at the url given. If the ip is the zero value, // the public IP address of the machine is used as the IP. -func (i *ifConfigCo) FetchInfo(ctx context.Context, ip netip.Addr) ( +func (e *echoip) FetchInfo(ctx context.Context, ip netip.Addr) ( result models.PublicIP, err error, ) { - url := "https://ifconfig.co/json" + url := e.url + "/json" if ip.IsValid() { url += "?ip=" + ip.String() } @@ -48,7 +54,7 @@ func (i *ifConfigCo) FetchInfo(ctx context.Context, ip netip.Addr) ( return result, err } - response, err := i.client.Do(request) + response, err := e.client.Do(request) if err != nil { return result, err }