From 45be35e9d25fa549ca1c6337adc635ca5f348224 Mon Sep 17 00:00:00 2001 From: Adrian Shum Date: Thu, 14 Mar 2024 17:43:23 +0800 Subject: [PATCH] test cases --- config/config_test.go | 2 ++ imagor_test.go | 25 +++++++++++++++++++++ loader/httploader/httploader_test.go | 33 +++++++++++++++++++++++++++- 3 files changed, 59 insertions(+), 1 deletion(-) diff --git a/config/config_test.go b/config/config_test.go index e56c3c86b..221ddea14 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -60,6 +60,7 @@ func TestBasic(t *testing.T) { "-imagor-cache-header-ttl", "169h", "-imagor-cache-header-swr", "167h", "-http-loader-insecure-skip-verify-transport", + "-http-loader-override-response-headers", "cache-control,content-type", "-http-loader-base-url", "https://www.example.com/foo.org", }) app := srv.App.(*imagor.Imagor) @@ -85,6 +86,7 @@ func TestBasic(t *testing.T) { httpLoader := app.Loaders[0].(*httploader.HTTPLoader) assert.True(t, httpLoader.Transport.(*http.Transport).TLSClientConfig.InsecureSkipVerify) assert.Equal(t, "https://www.example.com/foo.org", httpLoader.BaseURL.String()) + assert.Equal(t, []string{"cache-control", "content-type"}, httpLoader.OverrideResponseHeaders) } func TestVersion(t *testing.T) { diff --git a/imagor_test.go b/imagor_test.go index 6df6fe10b..668556a3c 100644 --- a/imagor_test.go +++ b/imagor_test.go @@ -331,6 +331,31 @@ func TestWithRaw(t *testing.T) { assert.Equal(t, "bar", w.Header().Get("Content-Type")) } +func TestWithOverrideHeader(t *testing.T) { + app := New( + WithDebug(true), + WithUnsafe(true), + WithLogger(zap.NewExample()), + WithLoaders(loaderFunc(func(r *http.Request, image string) (*Blob, error) { + blob := NewBlobFromBytes([]byte("foo")) + blob.SetContentType("bar") + blob.Header = make(http.Header) + blob.Header.Set("Content-Type", "tada") + blob.Header.Set("Foo", "bar") + blob.Header.Set("asdf", "fghj") + return blob, nil + })), + ) + w := httptest.NewRecorder() + app.ServeHTTP(w, httptest.NewRequest( + http.MethodGet, "https://example.com/unsafe/filters:fill(red):raw()/gopher.png", nil)) + assert.Equal(t, 200, w.Code) + assert.Equal(t, "foo", w.Body.String()) + assert.Equal(t, "script-src 'none'", w.Header().Get("Content-Security-Policy")) + assert.Equal(t, "tada", w.Header().Get("Content-Type")) + assert.Equal(t, "fghj", w.Header().Get("ASDF")) +} + func TestNewBlobFromPathNotFound(t *testing.T) { loader := loaderFunc(func(r *http.Request, image string) (*Blob, error) { return NewBlobFromFile("./non-exists-path"), nil diff --git a/loader/httploader/httploader_test.go b/loader/httploader/httploader_test.go index bc5758f36..debcd1bca 100644 --- a/loader/httploader/httploader_test.go +++ b/loader/httploader/httploader_test.go @@ -23,7 +23,7 @@ func (t testTransport) RoundTrip(r *http.Request) (w *http.Response, err error) w = &http.Response{ StatusCode: http.StatusOK, Body: io.NopCloser(strings.NewReader(res)), - Header: map[string][]string{}, + Header: make(http.Header), } w.Header.Set("Content-Type", "image/jpeg") return @@ -48,6 +48,7 @@ type test struct { name string target string result string + header map[string]string err string } @@ -80,6 +81,11 @@ func doTests(t *testing.T, loader imagor.Loader, tests []test) { } assert.Equal(t, tt.err, msg) } + if tt.header != nil { + for key, val := range tt.header { + assert.Equal(t, val, b.Header.Get(key)) + } + } }) } } @@ -492,6 +498,31 @@ func TestWithForwardHeadersOverrideUserAgent(t *testing.T) { }) } +func TestWithOverrideResponseHeader(t *testing.T) { + doTests(t, New( + WithTransport(roundTripFunc(func(r *http.Request) (w *http.Response, err error) { + res := &http.Response{ + StatusCode: http.StatusOK, + Header: map[string][]string{}, + Body: io.NopCloser(strings.NewReader("ok")), + } + res.Header.Set("Content-Type", "image/jpeg") + res.Header.Set("Foo", "Bar") + return res, nil + })), + WithOverrideResponseHeaders("foo"), + ), []test{ + { + name: "user agent", + target: "https://foo.bar/baz", + result: "ok", + header: map[string]string{ + "Foo": "Bar", + }, + }, + }) +} + func TestWithForwardClientHeaders(t *testing.T) { doTests(t, New( WithTransport(roundTripFunc(func(r *http.Request) (w *http.Response, err error) {