From 6f0734af874f7086d2f5c465c316d90ebcf02ab5 Mon Sep 17 00:00:00 2001 From: zeripath Date: Fri, 13 Nov 2020 12:00:30 +0000 Subject: [PATCH] Create functional option for ctx.SetCookie (#208) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: ᴜɴᴋɴᴡᴏɴ Co-authored-by: 6543 <6543@obermui.de> --- context.go | 61 ++++++++++++++++++++++++++----------- context_test.go | 14 ++++++++- cookie/helper.go | 78 ++++++++++++++++++++++++++++++++++++++++++++++++ tree_test.go | 8 ++--- 4 files changed, 138 insertions(+), 23 deletions(-) create mode 100644 cookie/helper.go diff --git a/context.go b/context.go index be94d35..05d09f3 100644 --- a/context.go +++ b/context.go @@ -68,6 +68,7 @@ type Request struct { *http.Request } +// Body returns a RequestBody for the request func (r *Request) Body() *RequestBody { return &RequestBody{r.Request.Body} } @@ -75,6 +76,7 @@ func (r *Request) Body() *RequestBody { // ContextInvoker is an inject.FastInvoker wrapper of func(ctx *Context). type ContextInvoker func(ctx *Context) +// Invoke implements inject.FastInvoker which simplifies calls of `func(ctx *Context)` function. func (invoke ContextInvoker) Invoke(params []interface{}) ([]reflect.Value, error) { invoke(params[0].(*Context)) return nil, nil @@ -97,41 +99,43 @@ type Context struct { Data map[string]interface{} } -func (c *Context) handler() Handler { - if c.index < len(c.handlers) { - return c.handlers[c.index] +func (ctx *Context) handler() Handler { + if ctx.index < len(ctx.handlers) { + return ctx.handlers[ctx.index] } - if c.index == len(c.handlers) { - return c.action + if ctx.index == len(ctx.handlers) { + return ctx.action } panic("invalid index for context handler") } -func (c *Context) Next() { - c.index += 1 - c.run() +// Next runs the next handler in the context chain +func (ctx *Context) Next() { + ctx.index++ + ctx.run() } -func (c *Context) Written() bool { - return c.Resp.Written() +// Written returns whether the context response has been written to +func (ctx *Context) Written() bool { + return ctx.Resp.Written() } -func (c *Context) run() { - for c.index <= len(c.handlers) { - vals, err := c.Invoke(c.handler()) +func (ctx *Context) run() { + for ctx.index <= len(ctx.handlers) { + vals, err := ctx.Invoke(ctx.handler()) if err != nil { panic(err) } - c.index += 1 + ctx.index++ // if the handler returned something, write it to the http response if len(vals) > 0 { - ev := c.GetVal(reflect.TypeOf(ReturnHandler(nil))) + ev := ctx.GetVal(reflect.TypeOf(ReturnHandler(nil))) handleReturn := ev.Interface().(ReturnHandler) - handleReturn(c, vals) + handleReturn(ctx, vals) } - if c.Written() { + if ctx.Written() { return } } @@ -172,6 +176,7 @@ func (ctx *Context) HTMLSet(status int, setName, tplName string, data ...interfa ctx.renderHTML(status, setName, tplName, data...) } +// Redirect sends a redirect response func (ctx *Context) Redirect(location string, status ...int) { code := http.StatusFound if len(status) == 1 { @@ -181,7 +186,7 @@ func (ctx *Context) Redirect(location string, status ...int) { http.Redirect(ctx.Resp, ctx.Req.Request, location, code) } -// Maximum amount of memory to use when parsing a multipart form. +// MaxMemory is the maximum amount of memory to use when parsing a multipart form. // Set this to whatever value you prefer; default is 10 MB. var MaxMemory = int64(1024 * 1024 * 10) @@ -341,6 +346,8 @@ func (ctx *Context) SetCookie(name string, value string, others ...interface{}) cookie.MaxAge = int(v) case int32: cookie.MaxAge = int(v) + case func(*http.Cookie): + v(&cookie) } } @@ -348,12 +355,16 @@ func (ctx *Context) SetCookie(name string, value string, others ...interface{}) if len(others) > 1 { if v, ok := others[1].(string); ok && len(v) > 0 { cookie.Path = v + } else if v, ok := others[1].(func(*http.Cookie)); ok { + v(&cookie) } } if len(others) > 2 { if v, ok := others[2].(string); ok && len(v) > 0 { cookie.Domain = v + } else if v, ok := others[1].(func(*http.Cookie)); ok { + v(&cookie) } } @@ -361,6 +372,8 @@ func (ctx *Context) SetCookie(name string, value string, others ...interface{}) switch v := others[3].(type) { case bool: cookie.Secure = v + case func(*http.Cookie): + v(&cookie) default: if others[3] != nil { cookie.Secure = true @@ -371,6 +384,8 @@ func (ctx *Context) SetCookie(name string, value string, others ...interface{}) if len(others) > 4 { if v, ok := others[4].(bool); ok && v { cookie.HttpOnly = true + } else if v, ok := others[1].(func(*http.Cookie)); ok { + v(&cookie) } } @@ -378,6 +393,16 @@ func (ctx *Context) SetCookie(name string, value string, others ...interface{}) if v, ok := others[5].(time.Time); ok { cookie.Expires = v cookie.RawExpires = v.Format(time.UnixDate) + } else if v, ok := others[1].(func(*http.Cookie)); ok { + v(&cookie) + } + } + + if len(others) > 6 { + for _, other := range others[6:] { + if v, ok := other.(func(*http.Cookie)); ok { + v(&cookie) + } } } diff --git a/context_test.go b/context_test.go index de2ffaa..44bdd92 100644 --- a/context_test.go +++ b/context_test.go @@ -27,6 +27,7 @@ import ( "time" "github.com/unknwon/com" + "gopkg.in/macaron.v1/cookie" . "github.com/smartystreets/goconvey/convey" ) @@ -209,7 +210,18 @@ func Test_Context(t *testing.T) { So(err, ShouldBeNil) ctx.SetCookie("user", "Unknwon", 1, "/", "localhost", true, true, t) ctx.SetCookie("user", "Unknwon", int32(1), "/", "localhost", 1) - ctx.SetCookie("user", "Unknwon", int64(1)) + called := false + ctx.SetCookie("user", "Unknwon", int64(1), func(c *http.Cookie) { + called = true + }) + So(called, ShouldBeTrue) + ctx.SetCookie("user", "Unknown", + cookie.Secure(true), + cookie.HttpOnly(true), + cookie.Path("/"), + cookie.MaxAge(1), + cookie.Domain("localhost"), + ) }) resp := httptest.NewRecorder() diff --git a/cookie/helper.go b/cookie/helper.go new file mode 100644 index 0000000..c5f8eb4 --- /dev/null +++ b/cookie/helper.go @@ -0,0 +1,78 @@ +// Copyright 2020 The Macaron Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"): you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + +// Package cookie contains helper functions for setting cookie values. +package cookie + +import ( + "net/http" + "time" +) + +// MaxAge sets the maximum age for a provided cookie +func MaxAge(maxAge int) func(*http.Cookie) { + return func(c *http.Cookie) { + c.MaxAge = maxAge + } +} + +// Path sets the path for a provided cookie +func Path(path string) func(*http.Cookie) { + return func(c *http.Cookie) { + c.Path = path + } +} + +// Domain sets the domain for a provided cookie +func Domain(domain string) func(*http.Cookie) { + return func(c *http.Cookie) { + c.Domain = domain + } +} + +// Secure sets the secure setting for a provided cookie +func Secure(secure bool) func(*http.Cookie) { + return func(c *http.Cookie) { + c.Secure = secure + } +} + +// HttpOnly sets the HttpOnly setting for a provided cookie +func HttpOnly(httpOnly bool) func(*http.Cookie) { + return func(c *http.Cookie) { + c.HttpOnly = httpOnly + } +} + +// HTTPOnly sets the HttpOnly setting for a provided cookie +func HTTPOnly(httpOnly bool) func(*http.Cookie) { + return func(c *http.Cookie) { + c.HttpOnly = httpOnly + } +} + +// Expires sets the expires and rawexpires for a provided cookie +func Expires(expires time.Time) func(*http.Cookie) { + return func(c *http.Cookie) { + c.Expires = expires + c.RawExpires = expires.Format(time.UnixDate) + } +} + +// SameSite sets the SameSite for a provided cookie +func SameSite(sameSite http.SameSite) func(*http.Cookie) { + return func(c *http.Cookie) { + c.SameSite = sameSite + } +} diff --git a/tree_test.go b/tree_test.go index d82c043..3921b74 100644 --- a/tree_test.go +++ b/tree_test.go @@ -34,8 +34,8 @@ func Test_getWildcards(t *testing.T) { ":id([0-9]+)_:name": result{"([0-9]+)_(.+)", ":id :name"}, "article_:id_:page.html": result{"article_(.+)_(.+).html", ":id :page"}, "article_:id:int_:page:string.html": result{"article_([0-9]+)_([\\w]+).html", ":id :page"}, - "*": result{"*", ""}, - "*.*": result{"*.*", ""}, + "*": result{"*", ""}, + "*.*": result{"*.*", ""}, } Convey("Get wildcards", t, func() { for key, result := range cases { @@ -56,8 +56,8 @@ func Test_getRawPattern(t *testing.T) { "article_:id_:page.html": "article_:id_:page.html", "article_:id:int_:page:string.html": "article_:id_:page.html", "article_:id([0-9]+)_:page([\\w]+).html": "article_:id_:page.html", - "*": "*", - "*.*": "*.*", + "*": "*", + "*.*": "*.*", } Convey("Get raw pattern", t, func() { for k, v := range cases {