diff --git a/internal/requests_test.go b/internal/requests_test.go index a37f19f01..3c1d92d42 100644 --- a/internal/requests_test.go +++ b/internal/requests_test.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "io" + "net/http" "net/http/httptest" "net/url" "strings" @@ -217,6 +218,44 @@ func TestRequests(t *testing.T) { }) } +func TestRequestJSONTrailingData(t *testing.T) { + a := require.New(t) + ctx := context.Background() + + testData := "bababoi" + srv, err := api.NewServer(testHTTPRequests{}, + api.WithErrorHandler(func(ctx context.Context, w http.ResponseWriter, r *http.Request, err error) { + w.WriteHeader(http.StatusBadRequest) + _, _ = io.WriteString(w, err.Error()) + }), + ) + a.NoError(err) + + s := httptest.NewServer(srv) + defer s.Close() + + reqBody := fmt.Sprintf(`{"name":%q}{"name":"trailing"}`, testData) + req, err := http.NewRequestWithContext( + ctx, + http.MethodPost, + s.URL+"/allRequestBodies", + strings.NewReader(reqBody), + ) + a.NoError(err) + req.Header.Set("Content-Type", "application/json") + + resp, err := s.Client().Do(req) + a.NoError(err) + defer func() { + _ = resp.Body.Close() + }() + + a.Equal(http.StatusBadRequest, resp.StatusCode) + data, err := io.ReadAll(resp.Body) + a.NoError(err) + a.Contains(string(data), ": unexpected trailing data") +} + func TestServerURLOverride(t *testing.T) { a := require.New(t) ctx := context.Background() diff --git a/internal/responses_test.go b/internal/responses_test.go index 8bfda2478..b55954366 100644 --- a/internal/responses_test.go +++ b/internal/responses_test.go @@ -352,3 +352,22 @@ func TestResponsesPattern(t *testing.T) { } }) } + +func TestResponseJSONTrailingData(t *testing.T) { + a := require.New(t) + ctx := context.Background() + + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = fmt.Fprint(w, `{"ok": "yes"} +{"ok": "trailing"}`) + })) + defer s.Close() + + client, err := api.NewClient(s.URL, api.WithClient(s.Client())) + a.NoError(err) + + _, err = client.Combined(ctx, api.CombinedParams{Type: api.CombinedType200}) + a.ErrorContains(err, "unexpected trailing data") +}