diff --git a/README.md b/README.md index 0831b4a..d4e6d4c 100644 --- a/README.md +++ b/README.md @@ -100,6 +100,8 @@ eu3.sec-tunnel.com,77.111.244.22,443 | cafile | String | use custom CA certificate bundle file | | certchain-workaround | Boolean | add bundled cross-signed intermediate cert to certchain to make it check out on old systems (default true) | | country | String | desired proxy location (default "EU") | +| init-retries | Number | number of attempts for initialization steps, zero for unlimited retry | +| init-retry-interval | Duration | delay between initialization retries (default 5s) | | list-countries | - | list available countries and exit | | list-proxies | - | output proxy list and exit | | proxy | String | sets base proxy to use for all dial-outs. Format: `://[login:password@]host[:port]` Examples: `http://user:password@192.168.1.1:3128`, `socks5://10.0.0.1:1080` | diff --git a/main.go b/main.go index 9e319c9..2c70c8d 100644 --- a/main.go +++ b/main.go @@ -98,6 +98,8 @@ type CLIArgs struct { bootstrapDNS *CSVArg refresh time.Duration refreshRetry time.Duration + initRetries int + initRetryInterval time.Duration certChainWorkaround bool caFile string } @@ -140,6 +142,8 @@ func parse_args() *CLIArgs { "Examples: https://1.1.1.1/dns-query,quic://dns.adguard.com") flag.DurationVar(&args.refresh, "refresh", 4*time.Hour, "login refresh interval") flag.DurationVar(&args.refreshRetry, "refresh-retry", 5*time.Second, "login refresh retry interval") + flag.IntVar(&args.initRetries, "init-retries", 0, "number of attempts for initialization steps, zero for unlimited retry") + flag.DurationVar(&args.initRetryInterval, "init-retry-interval", 5 * time.Second, "delay between initialization retries") flag.BoolVar(&args.certChainWorkaround, "certchain-workaround", true, "add bundled cross-signed intermediate cert to certchain to make it check out on old systems") flag.StringVar(&args.caFile, "cafile", "", "use custom CA certificate bundle file") @@ -237,39 +241,48 @@ func run() int { TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second, }) - seclient.Settings.ClientType = args.apiClientType - seclient.Settings.ClientVersion = args.apiClientVersion - seclient.Settings.UserAgent = args.apiUserAgent if err != nil { mainLogger.Critical("Unable to construct SEClient: %v", err) return 8 } + seclient.Settings.ClientType = args.apiClientType + seclient.Settings.ClientVersion = args.apiClientVersion + seclient.Settings.UserAgent = args.apiUserAgent - ctx, cl := context.WithTimeout(context.Background(), args.timeout) - err = seclient.AnonRegister(ctx) + try := retryPolicy(args.initRetries, args.initRetryInterval, mainLogger) + + err = try("anonymous registration", func() error { + ctx, cl := context.WithTimeout(context.Background(), args.timeout) + defer cl() + return seclient.AnonRegister(ctx) + }) if err != nil { - mainLogger.Critical("Unable to perform anonymous registration: %v", err) return 9 } - cl() - ctx, cl = context.WithTimeout(context.Background(), args.timeout) - err = seclient.RegisterDevice(ctx) + err = try("device registration", func() error { + ctx, cl := context.WithTimeout(context.Background(), args.timeout) + defer cl() + return seclient.RegisterDevice(ctx) + }) if err != nil { - mainLogger.Critical("Unable to perform device registration: %v", err) return 10 } - cl() if args.listCountries { - return printCountries(mainLogger, args.timeout, seclient) + return printCountries(try, mainLogger, args.timeout, seclient) } - ctx, cl = context.WithTimeout(context.Background(), args.timeout) - // TODO: learn about requested_geo value format - ips, err := seclient.Discover(ctx, fmt.Sprintf("\"%s\",,", args.country)) + var ips []se.SEIPEntry + err = try("discover", func() error { + ctx, cl := context.WithTimeout(context.Background(), args.timeout) + defer cl() + // TODO: learn about requested_geo value format + res, err := seclient.Discover(ctx, fmt.Sprintf("\"%s\",,", args.country)) + ips = res + return err + }) if err != nil { - mainLogger.Critical("Endpoint discovery failed: %v", err) return 12 } @@ -340,12 +353,16 @@ func run() int { return 0 } -func printCountries(logger *clog.CondLogger, timeout time.Duration, seclient *se.SEClient) int { - ctx, cl := context.WithTimeout(context.Background(), timeout) - defer cl() - list, err := seclient.GeoList(ctx) +func printCountries(try func(string, func() error) error, logger *clog.CondLogger, timeout time.Duration, seclient *se.SEClient) int { + var list []se.SEGeoEntry + err := try("geolist", func() error { + ctx, cl := context.WithTimeout(context.Background(), timeout) + defer cl() + l, err := seclient.GeoList(ctx) + list = l + return err + }) if err != nil { - logger.Critical("GeoList error: %v", err) return 11 } @@ -382,3 +399,24 @@ func printProxies(ips []se.SEIPEntry, seclient *se.SEClient) int { func main() { os.Exit(run()) } + +func retryPolicy(retries int, retryInterval time.Duration, logger *clog.CondLogger) func(string, func() error) error { + return func(name string, f func() error) error { + var err error + for i:=1; retries <= 0 || i<=retries; i++ { + if i > 1 { + logger.Warning("Retrying action %q in %v...", name, retryInterval) + time.Sleep(retryInterval) + } + logger.Info("Attempting action %q, attempt #%d...", name, i) + err = f() + if err == nil { + logger.Info("Action %q succeeded on attempt #%d", name, i) + return nil + } + logger.Warning("Action %q failed: %v", name, err) + } + logger.Critical("All attempts for action %q have failed. Last error: %v", name, err) + return err + } +}