diff --git a/auth/authenticator.go b/auth/authenticator.go index 5aef553..80cfc23 100644 --- a/auth/authenticator.go +++ b/auth/authenticator.go @@ -83,8 +83,8 @@ const ( // complete. // // NOTE: This is part of the Authenticator interface. -func (l *L402Authenticator) FreshChallengeHeader(r *http.Request, - serviceName string, servicePrice int64) (http.Header, error) { +func (l *L402Authenticator) FreshChallengeHeader(serviceName string, + servicePrice int64) (http.Header, error) { service := l402.Service{ Name: serviceName, diff --git a/auth/interface.go b/auth/interface.go index 7b51dcc..29ebfeb 100644 --- a/auth/interface.go +++ b/auth/interface.go @@ -27,7 +27,7 @@ type Authenticator interface { // FreshChallengeHeader returns a header containing a challenge for the // user to complete. - FreshChallengeHeader(*http.Request, string, int64) (http.Header, error) + FreshChallengeHeader(string, int64) (http.Header, error) } // Minter is an entity that is able to mint and verify L402s for a set of diff --git a/auth/mock_authenticator.go b/auth/mock_authenticator.go index 7f4c380..390db29 100644 --- a/auth/mock_authenticator.go +++ b/auth/mock_authenticator.go @@ -31,8 +31,8 @@ func (a MockAuthenticator) Accept(header *http.Header, _ string) bool { // FreshChallengeHeader returns a header containing a challenge for the user to // complete. -func (a MockAuthenticator) FreshChallengeHeader(r *http.Request, - _ string, _ int64) (http.Header, error) { +func (a MockAuthenticator) FreshChallengeHeader(string, int64) (http.Header, + error) { header := http.Header{ "Content-Type": []string{"application/grpc"}, diff --git a/proxy/proxy.go b/proxy/proxy.go index 1d2bf8f..c134123 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -399,7 +399,7 @@ func (p *Proxy) handlePaymentRequired(w http.ResponseWriter, r *http.Request, serviceName string, servicePrice int64) { header, err := p.authenticator.FreshChallengeHeader( - r, serviceName, servicePrice, + serviceName, servicePrice, ) if err != nil { log.Errorf("Error creating new challenge header: %v", err) diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go index dbe326b..a3b8056 100644 --- a/proxy/proxy_test.go +++ b/proxy/proxy_test.go @@ -101,15 +101,19 @@ func TestProxyHTTP(t *testing.T) { for _, tc := range testCases { tc := tc - t.Run(tc.name, func(t *testing.T) { - runHTTPTest(t, tc) + t.Run(tc.name+" GET", func(t *testing.T) { + runHTTPTest(t, tc, "GET") + }) + + t.Run(tc.name+" POST", func(t *testing.T) { + runHTTPTest(t, tc, "POST") }) } } // TestProxyHTTP tests that the proxy can forward HTTP requests to a backend // service and handle L402 authentication correctly. -func runHTTPTest(t *testing.T, tc *testCase) { +func runHTTPTest(t *testing.T, tc *testCase, method string) { // Create a list of services to proxy between. services := []*proxy.Service{{ Address: testTargetServiceAddress, @@ -148,11 +152,25 @@ func runHTTPTest(t *testing.T, tc *testCase) { // Authorization header set. client := &http.Client{} url := fmt.Sprintf("http://%s/http/test", testProxyAddr) - resp, err := client.Get(url) + + req, err := http.NewRequest(method, url, nil) + require.NoError(t, err) + + if method == "POST" { + req.Header.Add("Content-Type", "application/json") + req.Body = io.NopCloser(strings.NewReader(`{}`)) + } + + resp, err := client.Do(req) require.NoError(t, err) require.Equal(t, "402 Payment Required", resp.Status) + bodyContent, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, "payment required\n", string(bodyContent)) + require.EqualValues(t, len(bodyContent), resp.ContentLength) + authHeader := resp.Header.Get("Www-Authenticate") require.Regexp(t, "(LSAT|L402)", authHeader) _ = resp.Body.Close() @@ -161,7 +179,7 @@ func runHTTPTest(t *testing.T, tc *testCase) { // get the 402 response. if len(tc.authWhitelist) > 0 { url = fmt.Sprintf("http://%s/http/white", testProxyAddr) - req, err := http.NewRequest("GET", url, nil) + req, err := http.NewRequest(method, url, nil) require.NoError(t, err) resp, err = client.Do(req) require.NoError(t, err) @@ -174,11 +192,12 @@ func runHTTPTest(t *testing.T, tc *testCase) { require.NoError(t, err) require.Equal(t, testHTTPResponseBody, string(bodyBytes)) + require.EqualValues(t, len(bodyBytes), resp.ContentLength) } // Make sure that if the Auth header is set, the client's request is // proxied to the backend service. - req, err := http.NewRequest("GET", url, nil) + req, err = http.NewRequest(method, url, nil) require.NoError(t, err) req.Header.Add("Authorization", "foobar") @@ -193,6 +212,7 @@ func runHTTPTest(t *testing.T, tc *testCase) { require.NoError(t, err) require.Equal(t, testHTTPResponseBody, string(bodyBytes)) + require.EqualValues(t, len(bodyBytes), resp.ContentLength) } // TestProxyHTTP tests that the proxy can forward gRPC requests to a backend @@ -313,9 +333,7 @@ func runGRPCTest(t *testing.T, tc *testCase) { // We expect the WWW-Authenticate header field to be set to an L402 // auth response. - expectedHeaderContent, _ := mockAuth.FreshChallengeHeader(&http.Request{ - Header: map[string][]string{}, - }, "", 0) + expectedHeaderContent, _ := mockAuth.FreshChallengeHeader("", 0) capturedHeader := captureMetadata.Get("WWW-Authenticate") require.Len(t, capturedHeader, 2) require.Equal(