diff --git a/source/extensions/filters/http/oauth2/oauth_client.h b/source/extensions/filters/http/oauth2/oauth_client.h index ebc9a97e3ab5..e270c327236e 100644 --- a/source/extensions/filters/http/oauth2/oauth_client.h +++ b/source/extensions/filters/http/oauth2/oauth_client.h @@ -90,7 +90,13 @@ class OAuth2ClientImpl : public OAuth2Client, Logger::Loggableheaders().setReferenceMethod(Http::Headers::get().MethodValues.Post); - request->headers().setContentType(Http::Headers::get().ContentTypeValues.FormUrlEncoded); + request->headers().setReferenceContentType( + Http::Headers::get().ContentTypeValues.FormUrlEncoded); + // Use the Accept header to ensure the Access Token Response is returned as JSON. + // Some authorization servers return other encodings (e.g. FormUrlEncoded) in the absence of the + // Accept header. RFC 6749 Section 5.1 defines the media type to be JSON, so this is safe. + request->headers().setReference(Http::CustomHeaders::get().Accept, + Http::Headers::get().ContentTypeValues.Json); return request; } }; diff --git a/test/extensions/filters/http/oauth2/oauth_test.cc b/test/extensions/filters/http/oauth2/oauth_test.cc index c1de35521ab8..f0a72cdfd4b2 100644 --- a/test/extensions/filters/http/oauth2/oauth_test.cc +++ b/test/extensions/filters/http/oauth2/oauth_test.cc @@ -79,12 +79,20 @@ TEST_F(OAuth2ClientTest, RequestAccessTokenSuccess) { mock_response->body().add(json); EXPECT_CALL(cm_.thread_local_cluster_.async_client_, send_(_, _, _)) - .WillRepeatedly( - Invoke([&](Http::RequestMessagePtr&, Http::AsyncClient::Callbacks& cb, - const Http::AsyncClient::RequestOptions&) -> Http::AsyncClient::Request* { - callbacks_.push_back(&cb); - return &request_; - })); + .WillRepeatedly(Invoke([&](Http::RequestMessagePtr& message, Http::AsyncClient::Callbacks& cb, + const Http::AsyncClient::RequestOptions&) + -> Http::AsyncClient::Request* { + EXPECT_EQ(Http::Headers::get().MethodValues.Post, + message->headers().Method()->value().getStringView()); + EXPECT_EQ(Http::Headers::get().ContentTypeValues.FormUrlEncoded, + message->headers().ContentType()->value().getStringView()); + EXPECT_TRUE( + !message->headers().get(Http::CustomHeaders::get().Accept).empty() && + message->headers().get(Http::CustomHeaders::get().Accept)[0]->value().getStringView() == + Http::Headers::get().ContentTypeValues.Json); + callbacks_.push_back(&cb); + return &request_; + })); client_->setCallbacks(*mock_callbacks_); client_->asyncGetAccessToken("a", "b", "c", "d");