From 64792df422ee84af58e5b35aa3fda1b58c9c6c22 Mon Sep 17 00:00:00 2001 From: Kent Quirk Date: Thu, 29 Jun 2023 10:34:17 -0400 Subject: [PATCH] Fix queryauth bug; add tests --- route/middleware.go | 1 + route/middleware_test.go | 74 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 75 insertions(+) create mode 100644 route/middleware_test.go diff --git a/route/middleware.go b/route/middleware.go index 55f3345350..3a9a67979b 100644 --- a/route/middleware.go +++ b/route/middleware.go @@ -23,6 +23,7 @@ func (r *Router) queryTokenChecker(next http.Handler) http.Handler { if requiredToken == "" { err := fmt.Errorf("/query endpoint is not authorized for use (specify QueryAuthToken in config)") r.handlerReturnWithError(w, ErrAuthNeeded, err) + return } token := req.Header.Get(types.QueryTokenHeader) diff --git a/route/middleware_test.go b/route/middleware_test.go new file mode 100644 index 0000000000..0806edec2b --- /dev/null +++ b/route/middleware_test.go @@ -0,0 +1,74 @@ +package route + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/honeycombio/refinery/config" + "github.com/honeycombio/refinery/logger" + "github.com/honeycombio/refinery/types" +) + +type dummyHandler struct{} + +func (d *dummyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("good")) +} + +func TestRouter_queryTokenChecker(t *testing.T) { + tests := []struct { + name string + authtoken string + reqtoken string + want int + mustcontain string + mustnotcontain string + }{ + {"both_empty", "", "", 400, "not authorized for use", "good"}, + {"auth_empty", "", "foo", 400, "not authorized for use", "good"}, + {"req_empty", "foo", "", 400, "not authorized for query", "good"}, + {"correct", "testtoken", "testtoken", 200, "good", "authorized"}, + {"incorrect", "testtoken", "wrongtoken", 400, "not authorized for query", "good"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + router := &Router{ + Logger: &logger.NullLogger{}, + Config: &config.MockConfig{QueryAuthToken: tt.authtoken}, + } // we're not using anything else on this router + + // Create a request to pass to our handler. We don't have any query parameters for now, so we'll + // pass 'nil' as the third parameter. + req, err := http.NewRequest("GET", "/query", nil) + if err != nil { + t.Fatal(err) + } + req.Header.Set(types.QueryTokenHeader, tt.reqtoken) + + // We create a ResponseRecorder (which satisfies http.ResponseWriter) to record the response. + rr := httptest.NewRecorder() + + handler := router.queryTokenChecker(&dummyHandler{}) + handler.ServeHTTP(rr, req) + + // Check the status code is what we expect. + if status := rr.Code; status != tt.want { + t.Errorf("handler returned wrong status code: got %v want %v", + status, tt.want) + } + + // Check the response body is what we expect. + if !strings.Contains(rr.Body.String(), tt.mustcontain) { + t.Errorf("handler returned unexpected body: got %v should have contained %v", + rr.Body.String(), tt.mustcontain) + } + // Check the response body is what we expect. + if strings.Contains(rr.Body.String(), tt.mustnotcontain) { + t.Errorf("handler returned unexpected body: got %v should NOT have contained %v", + rr.Body.String(), tt.mustnotcontain) + } + }) + } +}