diff --git a/pkg/http/handler/timeout.go b/pkg/http/handler/timeout.go index f864a07fb268..bc8b6d308bcc 100644 --- a/pkg/http/handler/timeout.go +++ b/pkg/http/handler/timeout.go @@ -127,9 +127,8 @@ func (h *timeoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { return case <-timeout.C(): timeoutDrained = true - if tw.tryTimeoutAndWriteError(h.body) { - return - } + tw.forceTimeoutAndWriteError(h.body) + return case now := <-idleTimeoutCh: timedOut, timeToNextTimeout := tw.tryIdleTimeoutAndWriteError(now, revIdleTimeout, h.body) if timedOut { @@ -214,22 +213,10 @@ func (tw *timeoutWriter) WriteHeader(code int) { tw.w.WriteHeader(code) } -// tryTimeoutAndWriteError writes an error to the responsewriter if -// nothing has been written to the writer before. Returns whether -// an error was written or not. -// -// If this writes an error, all subsequent calls to Write will -// result in http.ErrHandlerTimeout. -func (tw *timeoutWriter) tryTimeoutAndWriteError(msg string) bool { +func (tw *timeoutWriter) forceTimeoutAndWriteError(msg string) { tw.mu.Lock() defer tw.mu.Unlock() - - if tw.lastWriteTime.IsZero() { - tw.timeoutAndWriteError(msg) - return true - } - - return false + tw.timeoutAndWriteError(msg) } // tryResponseStartTimeoutAndWriteError writes an error to the responsewriter if diff --git a/pkg/http/handler/timeout_test.go b/pkg/http/handler/timeout_test.go index b388e94d5ddd..aa0cbb2b4cce 100644 --- a/pkg/http/handler/timeout_test.go +++ b/pkg/http/handler/timeout_test.go @@ -34,8 +34,6 @@ func TestTimeoutWriterAllowsForAdditionalWritesBeforeTimeout(t *testing.T) { clock := clock.RealClock{} handler := &timeoutWriter{w: recorder, clock: clock} handler.WriteHeader(http.StatusOK) - handler.tryTimeoutAndWriteError("error") - handler.tryResponseStartTimeoutAndWriteError("error") handler.tryIdleTimeoutAndWriteError(clock.Now(), 10*time.Second, "error") if _, err := io.WriteString(handler, "test"); err != nil { t.Fatalf("handler.Write() = %v, want no error", err) diff --git a/test/conformance/api/v1/revision_timeout_test.go b/test/conformance/api/v1/revision_timeout_test.go index 6c8097559515..221c7ed020c8 100644 --- a/test/conformance/api/v1/revision_timeout_test.go +++ b/test/conformance/api/v1/revision_timeout_test.go @@ -37,7 +37,7 @@ import ( // sendRequest send a request to "endpoint", returns error if unexpected response code, nil otherwise. func sendRequest(t *testing.T, clients *test.Clients, endpoint *url.URL, - initialSleep, sleep time.Duration, expectedResponseCode int) error { + initialSleep, sleep time.Duration, expectedResponseCode int, expectedBody string) error { client, err := pkgtest.NewSpoofingClient(context.Background(), clients.KubeClient, t.Logf, endpoint.Hostname(), test.ServingFlags.ResolvableDomain, test.AddRootCAtoTransport(context.Background(), t.Logf, clients, test.ServingFlags.HTTPS)) if err != nil { return fmt.Errorf("error creating Spoofing client: %w", err) @@ -68,7 +68,12 @@ func sendRequest(t *testing.T, clients *test.Clients, endpoint *url.URL, if expectedResponseCode != resp.StatusCode { return fmt.Errorf("response status code = %v, want = %v, response = %v", resp.StatusCode, expectedResponseCode, resp) } - + if expectedBody != "" { + gotBody := string(resp.Body) + if expectedBody != gotBody { + return fmt.Errorf("response body = %v, want = %v, response = %v", gotBody, expectedBody, resp) + } + } return nil } @@ -99,6 +104,7 @@ func TestRevisionTimeout(t *testing.T) { expectedStatus: http.StatusOK, sleep: 15 * time.Second, initialSleep: 0, + expectedBody: "activator request timeout", }} for _, tc := range testCases { @@ -136,7 +142,7 @@ func TestRevisionTimeout(t *testing.T) { t.Fatalf("Error probing %s: %v", serviceURL, err) } - if err := sendRequest(t, clients, serviceURL, tc.initialSleep, tc.sleep, tc.expectedStatus); err != nil { + if err := sendRequest(t, clients, serviceURL, tc.initialSleep, tc.sleep, tc.expectedStatus, tc.expectedBody); err != nil { t.Errorf("Failed request with initialSleep %v, sleep %v, with revision timeout %ds, expecting status %v: %v", tc.initialSleep, tc.sleep, tc.timeoutSeconds, tc.expectedStatus, err) }