diff --git a/docs/content/policy-reference.md b/docs/content/policy-reference.md index ffb82bc0b0..163d6c0571 100644 --- a/docs/content/policy-reference.md +++ b/docs/content/policy-reference.md @@ -543,7 +543,7 @@ The `request` object parameter may contain the following fields: | `tls_client_key_env_variable` | no | `string` | Environment variable containing a client key in PEM encoded format. | | `tls_client_cert_file` | no | `string` | Path to file containing a client certificate in PEM encoded format. | | `tls_client_key_file` | no | `string` | Path to file containing a key in PEM encoded format. | - +| `timeout` | no | `string` or `number` | Timeout for the HTTP request with a default of 5 seconds (`5s`). Numbers provided are in nanoseconds. Strings must be a valid duration string where a duration string is a possibly signed sequence of decimal numbers, each with optional fraction and a unit suffix, such as "300ms", "-1.5h" or "2h45m". Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h". A zero timeout means no timeout.| To trigger the use of HTTPs the user must provide one of the following combinations: diff --git a/topdown/http.go b/topdown/http.go index 695bd5d100..c6bbb81d22 100644 --- a/topdown/http.go +++ b/topdown/http.go @@ -11,6 +11,7 @@ import ( "fmt" "io" "io/ioutil" + "net/url" "strconv" "github.com/open-policy-agent/opa/internal/version" @@ -24,7 +25,9 @@ import ( "github.com/open-policy-agent/opa/topdown/builtins" ) -const defaultHTTPRequestTimeout = time.Second * 5 +const defaultHTTPRequestTimeoutEnv = "HTTP_SEND_TIMEOUT" + +var defaultHTTPRequestTimeout = time.Second * 5 var allowedKeyNames = [...]string{ "method", @@ -41,13 +44,12 @@ var allowedKeyNames = [...]string{ "tls_client_key_env_variable", "tls_client_cert_file", "tls_client_key_file", + "timeout", } var allowedKeys = ast.NewSet() var requiredKeys = ast.NewSet(ast.StringTerm("method"), ast.StringTerm("url")) -var client *http.Client - func builtinHTTPSend(bctx BuiltinContext, args []*ast.Term, iter func(*ast.Term) error) error { req, err := validateHTTPRequestOperand(args[0], 1) @@ -57,7 +59,7 @@ func builtinHTTPSend(bctx BuiltinContext, args []*ast.Term, iter func(*ast.Term) resp, err := executeHTTPRequest(bctx, req) if err != nil { - return handleBuiltinErr(ast.HTTPSend.Name, bctx.Location, err) + return handleHTTPSendErr(bctx, err) } return iter(ast.NewTerm(resp)) @@ -65,23 +67,32 @@ func builtinHTTPSend(bctx BuiltinContext, args []*ast.Term, iter func(*ast.Term) func init() { createAllowedKeys() - createHTTPClient() + initDefaults() RegisterBuiltinFunc(ast.HTTPSend.Name, builtinHTTPSend) } -func createHTTPClient() { - timeout := defaultHTTPRequestTimeout - timeoutDuration := os.Getenv("HTTP_SEND_TIMEOUT") - if timeoutDuration != "" { - timeout, _ = time.ParseDuration(timeoutDuration) +func handleHTTPSendErr(bctx BuiltinContext, err error) error { + // Return HTTP client timeout errors in a generic error message to avoid confusion about what happened. + // Do not do this if the builtin context was cancelled and is what caused the request to stop. + if urlErr, ok := err.(*url.Error); ok && urlErr.Timeout() && bctx.Context.Err() == nil { + err = fmt.Errorf("%s %s: request timed out", urlErr.Op, urlErr.URL) } + return handleBuiltinErr(ast.HTTPSend.Name, bctx.Location, err) +} - // create a http client with redirects disabled - client = &http.Client{ - Timeout: timeout, - CheckRedirect: func(*http.Request, []*http.Request) error { - return http.ErrUseLastResponse - }, +func initDefaults() { + timeoutDuration := os.Getenv(defaultHTTPRequestTimeoutEnv) + if timeoutDuration != "" { + var err error + defaultHTTPRequestTimeout, err = time.ParseDuration(timeoutDuration) + if err != nil { + // If it is set to something not valid don't let the process continue in a state + // that will almost definitely give unexpected results by having it set at 0 + // which means no timeout.. + // This environment variable isn't considered part of the public API. + // TODO(patrick-east): Remove the environment variable + panic(fmt.Sprintf("invalid value for HTTP_SEND_TIMEOUT: %s", err)) + } } } @@ -139,6 +150,8 @@ func executeHTTPRequest(bctx BuiltinContext, obj ast.Object) (ast.Value, error) var tlsConfig tls.Config var clientCerts []tls.Certificate var customHeaders map[string]interface{} + var timeout = defaultHTTPRequestTimeout + for _, val := range obj.Keys() { key, err := ast.JSON(val.Value) if err != nil { @@ -149,7 +162,7 @@ func executeHTTPRequest(bctx BuiltinContext, obj ast.Object) (ast.Value, error) switch key { case "method": method = obj.Get(val).String() - method = strings.Trim(method, "\"") + method = strings.ToUpper(strings.Trim(method, "\"")) case "url": url = obj.Get(val).String() url = strings.Trim(url, "\"") @@ -218,11 +231,20 @@ func executeHTTPRequest(bctx BuiltinContext, obj ast.Object) (ast.Value, error) if !ok { return nil, fmt.Errorf("invalid type for headers key") } + case "timeout": + timeout, err = parseTimeout(obj.Get(val).Value) + if err != nil { + return nil, err + } default: return nil, fmt.Errorf("invalid parameter %q", key) } } + client := &http.Client{ + Timeout: timeout, + } + if tlsClientCertFile != "" && tlsClientKeyFile != "" { clientCertFromFile, err := tls.LoadX509KeyPair(tlsClientCertFile, tlsClientKeyFile) if err != nil { @@ -253,8 +275,10 @@ func executeHTTPRequest(bctx BuiltinContext, obj ast.Object) (ast.Value, error) } // check if redirects are enabled - if enableRedirect { - client.CheckRedirect = nil + if !enableRedirect { + client.CheckRedirect = func(*http.Request, []*http.Request) error { + return http.ErrUseLastResponse + } } if rawBody != nil { @@ -269,11 +293,13 @@ func executeHTTPRequest(bctx BuiltinContext, obj ast.Object) (ast.Value, error) return cachedResponse, nil } - // create the http request - req, err := http.NewRequest(strings.ToUpper(method), url, body) + // create the http request, use the builtin context's context to ensure + // the request is cancelled if evaluation is cancelled. + req, err := http.NewRequest(method, url, body) if err != nil { return nil, err } + req = req.WithContext(bctx.Context) // Add custom headers passed from CLI @@ -358,3 +384,31 @@ func createAllowedKeys() { allowedKeys.Add(ast.StringTerm(element)) } } + +func parseTimeout(timeoutVal ast.Value) (time.Duration, error) { + var timeout time.Duration + switch t := timeoutVal.(type) { + case ast.Number: + timeoutInt, ok := t.Int64() + if !ok { + return timeout, fmt.Errorf("invalid timeout number value %v, must be int64", timeoutVal) + } + return time.Duration(timeoutInt), nil + case ast.String: + // Support strings without a unit, treat them the same as just a number value (ns) + var err error + timeoutInt, err := strconv.ParseInt(string(t), 10, 64) + if err == nil { + return time.Duration(timeoutInt), nil + } + + // Try parsing it as a duration (requires a supported units suffix) + timeout, err = time.ParseDuration(string(t)) + if err != nil { + return timeout, fmt.Errorf("invalid timeout value %v: %s", timeoutVal, err) + } + return timeout, nil + default: + return timeout, builtins.NewOperandErr(1, "'timeout' must be one of {string, number} but got %s", ast.TypeName(t)) + } +} diff --git a/topdown/http_test.go b/topdown/http_test.go index ae99eb0de6..1074d63936 100644 --- a/topdown/http_test.go +++ b/topdown/http_test.go @@ -6,18 +6,24 @@ package topdown import ( "bytes" + "context" "crypto/tls" "crypto/x509" "encoding/json" + "errors" "fmt" "io/ioutil" "net/http" "net/http/httptest" "os" + "strconv" "strings" + "sync" "testing" + "time" "github.com/open-policy-agent/opa/internal/version" + "github.com/open-policy-agent/opa/topdown/builtins" "github.com/open-policy-agent/opa/ast" ) @@ -375,6 +381,201 @@ func TestInvalidKeyError(t *testing.T) { } } +func TestHTTPSendTimeout(t *testing.T) { + + // Each test can tweak the response delay, default is 0 with no delay + var responseDelay time.Duration + + tsMtx := sync.Mutex{} + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + tsMtx.Lock() + defer tsMtx.Unlock() + time.Sleep(responseDelay) + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`hello`)) + })) + defer ts.Close() + + tests := []struct { + note string + rule string + input string + defaultTimeout time.Duration + evalTimeout time.Duration + serverDelay time.Duration + expected interface{} + }{ + { + note: "no timeout", + rule: `p = x { http.send({"method": "get", "url": "%URL%" }, x) }`, + expected: `{"body": null, "raw_body": "hello", "status": "200 OK", "status_code": 200}`, + }, + { + note: "default timeout", + rule: `p = x { http.send({"method": "get", "url": "%URL%" }, x) }`, + evalTimeout: 1 * time.Minute, + serverDelay: 500 * time.Millisecond, + defaultTimeout: 1 * time.Microsecond, + expected: &Error{Code: BuiltinErr, Message: "http.send: Get %URL%: request timed out"}, + }, + { + note: "eval timeout", + rule: `p = x { http.send({"method": "get", "url": "%URL%" }, x) }`, + evalTimeout: 1 * time.Microsecond, + serverDelay: 500 * time.Millisecond, + defaultTimeout: 1 * time.Minute, + expected: &Error{Code: BuiltinErr, Message: "http.send: Get %URL%: context deadline exceeded"}, + }, + { + note: "param timeout less than default", + rule: `p = x { http.send({"method": "get", "url": "%URL%", "timeout": "1ms"}, x) }`, + evalTimeout: 1 * time.Minute, + serverDelay: 500 * time.Millisecond, + defaultTimeout: 1 * time.Minute, + expected: &Error{Code: BuiltinErr, Message: "http.send: Get %URL%: request timed out"}, + }, + { + note: "param timeout greater than default", + rule: `p = x { http.send({"method": "get", "url": "%URL%", "timeout": "1ms"}, x) }`, + evalTimeout: 1 * time.Minute, + serverDelay: 500 * time.Millisecond, + defaultTimeout: 1 * time.Microsecond, + expected: &Error{Code: BuiltinErr, Message: "http.send: Get %URL%: request timed out"}, + }, + { + note: "eval timeout less than param", + rule: `p = x { http.send({"method": "get", "url": "%URL%", "timeout": "1m" }, x) }`, + evalTimeout: 1 * time.Millisecond, + serverDelay: 100 * time.Millisecond, + defaultTimeout: 1 * time.Minute, + expected: &Error{Code: BuiltinErr, Message: "http.send: Get %URL%: context deadline exceeded"}, + }, + } + + for _, tc := range tests { + responseDelay = tc.serverDelay + + ctx := context.Background() + var cancel context.CancelFunc + if tc.evalTimeout > 0 { + ctx, cancel = context.WithTimeout(ctx, tc.evalTimeout) + } + + // TODO(patrick-east): Remove this along with the environment variable so that the "default" can't change + originalDefaultTimeout := defaultHTTPRequestTimeout + if tc.defaultTimeout > 0 { + defaultHTTPRequestTimeout = tc.defaultTimeout + } + + rule := strings.ReplaceAll(tc.rule, "%URL%", ts.URL) + if e, ok := tc.expected.(*Error); ok { + e.Message = strings.ReplaceAll(e.Message, "%URL%", ts.URL) + } + + runTopDownTestCaseWithContext(ctx, t, map[string]interface{}{}, tc.note, []string{rule}, nil, tc.input, tc.expected) + + // Put back the default (may not have changed) + defaultHTTPRequestTimeout = originalDefaultTimeout + if cancel != nil { + cancel() + } + } +} + +func TestParseTimeout(t *testing.T) { + tests := []struct { + note string + raw ast.Value + expected interface{} + }{ + { + note: "zero string", + raw: ast.String("0"), + expected: time.Duration(0), + }, + { + note: "zero number", + raw: ast.Number(strconv.FormatInt(0, 10)), + expected: time.Duration(0), + }, + { + note: "number", + raw: ast.Number(strconv.FormatInt(1234, 10)), + expected: time.Duration(1234), + }, + { + note: "number with invalid float", + raw: ast.Number("1.234"), + expected: errors.New("invalid timeout number value"), + }, + { + note: "string no units", + raw: ast.String("1000"), + expected: time.Duration(1000), + }, + { + note: "string with units", + raw: ast.String("10ms"), + expected: time.Duration(10000000), + }, + { + note: "string with complex units", + raw: ast.String("1s10ms5us"), + expected: time.Second + (10 * time.Millisecond) + (5 * time.Microsecond), + }, + { + note: "string with invalid duration format", + raw: ast.String("1xyz 2"), + expected: errors.New("invalid timeout value"), + }, + { + note: "string with float", + raw: ast.String("1.234"), + expected: errors.New("invalid timeout value"), + }, + { + note: "invalid value type object", + raw: ast.NewObject(), + expected: builtins.NewOperandErr(1, "'timeout' must be one of {string, number} but got object"), + }, + { + note: "invalid value type set", + raw: ast.NewSet(), + expected: builtins.NewOperandErr(1, "'timeout' must be one of {string, number} but got set"), + }, + { + note: "invalid value type array", + raw: &ast.Array{}, + expected: builtins.NewOperandErr(1, "'timeout' must be one of {string, number} but got array"), + }, + { + note: "invalid value type boolean", + raw: ast.Boolean(true), + expected: builtins.NewOperandErr(1, "'timeout' must be one of {string, number} but got boolean"), + }, + { + note: "invalid value type null", + raw: ast.Null{}, + expected: builtins.NewOperandErr(1, "'timeout' must be one of {string, number} but got null"), + }, + } + + for _, tc := range tests { + t.Run(tc.note, func(t *testing.T) { + actual, err := parseTimeout(tc.raw) + switch e := tc.expected.(type) { + case error: + assertError(t, tc.expected, err) + case time.Duration: + if e != actual { + t.Fatalf("Expected %d but got %d", e, actual) + } + } + }) + } +} + // TestHTTPRedirectDisable tests redirects are not enabled by default func TestHTTPRedirectDisable(t *testing.T) { diff --git a/topdown/topdown_test.go b/topdown/topdown_test.go index 4a6a314e58..a69d32ad4f 100644 --- a/topdown/topdown_test.go +++ b/topdown/topdown_test.go @@ -2868,10 +2868,14 @@ func loadSmallTestData() map[string]interface{} { } func runTopDownTestCase(t *testing.T, data map[string]interface{}, note string, rules []string, expected interface{}) { - runTopDownTestCaseWithModules(t, data, note, rules, nil, "", expected) + runTopDownTestCaseWithContext(context.Background(), t, data, note, rules, nil, "", expected) } func runTopDownTestCaseWithModules(t *testing.T, data map[string]interface{}, note string, rules []string, modules []string, input string, expected interface{}) { + runTopDownTestCaseWithContext(context.Background(), t, data, note, rules, modules, input, expected) +} + +func runTopDownTestCaseWithContext(ctx context.Context, t *testing.T, data map[string]interface{}, note string, rules []string, modules []string, input string, expected interface{}) { imports := []string{} for k := range data { imports = append(imports, "data."+k) @@ -2889,10 +2893,14 @@ func runTopDownTestCaseWithModules(t *testing.T, data map[string]interface{}, no store := inmem.NewFromObject(data) - assertTopDownWithPath(t, compiler, store, note, []string{"p"}, input, expected) + assertTopDownWithPathAndContext(ctx, t, compiler, store, note, []string{"p"}, input, expected) } func assertTopDownWithPath(t *testing.T, compiler *ast.Compiler, store storage.Store, note string, path []string, input string, expected interface{}) { + assertTopDownWithPathAndContext(context.Background(), t, compiler, store, note, path, input, expected) +} + +func assertTopDownWithPathAndContext(ctx context.Context, t *testing.T, compiler *ast.Compiler, store storage.Store, note string, path []string, input string, expected interface{}) { var inputTerm *ast.Term @@ -2900,7 +2908,6 @@ func assertTopDownWithPath(t *testing.T, compiler *ast.Compiler, store storage.S inputTerm = ast.MustParseTerm(input) } - ctx := context.Background() txn := storage.NewTransactionOrDie(ctx, store) defer store.Abort(ctx, txn)