Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create functional option for ctx.SetCookie #208

Merged
merged 8 commits into from
Nov 13, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 43 additions & 18 deletions context.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,15 @@ type Request struct {
*http.Request
}

// Body returns a RequestBody for the request
func (r *Request) Body() *RequestBody {
return &RequestBody{r.Request.Body}
}

// ContextInvoker is an inject.FastInvoker wrapper of func(ctx *Context).
type ContextInvoker func(ctx *Context)

// Invoke implements inject.FastInvoker which simplifies calls of `func(ctx *Context)` function.
func (invoke ContextInvoker) Invoke(params []interface{}) ([]reflect.Value, error) {
invoke(params[0].(*Context))
return nil, nil
Expand All @@ -97,41 +99,43 @@ type Context struct {
Data map[string]interface{}
}

func (c *Context) handler() Handler {
if c.index < len(c.handlers) {
return c.handlers[c.index]
func (ctx *Context) handler() Handler {
if ctx.index < len(ctx.handlers) {
return ctx.handlers[ctx.index]
}
if c.index == len(c.handlers) {
return c.action
if ctx.index == len(ctx.handlers) {
return ctx.action
}
panic("invalid index for context handler")
}

func (c *Context) Next() {
c.index += 1
c.run()
// Next runs the next handler in the context chain
func (ctx *Context) Next() {
ctx.index++
ctx.run()
}

func (c *Context) Written() bool {
return c.Resp.Written()
// Written returns whether the context response has been written to
func (ctx *Context) Written() bool {
return ctx.Resp.Written()
}

func (c *Context) run() {
for c.index <= len(c.handlers) {
vals, err := c.Invoke(c.handler())
func (ctx *Context) run() {
for ctx.index <= len(ctx.handlers) {
vals, err := ctx.Invoke(ctx.handler())
if err != nil {
panic(err)
}
c.index += 1
ctx.index++

// if the handler returned something, write it to the http response
if len(vals) > 0 {
ev := c.GetVal(reflect.TypeOf(ReturnHandler(nil)))
ev := ctx.GetVal(reflect.TypeOf(ReturnHandler(nil)))
handleReturn := ev.Interface().(ReturnHandler)
handleReturn(c, vals)
handleReturn(ctx, vals)
}

if c.Written() {
if ctx.Written() {
return
}
}
Expand Down Expand Up @@ -172,6 +176,7 @@ func (ctx *Context) HTMLSet(status int, setName, tplName string, data ...interfa
ctx.renderHTML(status, setName, tplName, data...)
}

// Redirect sends a redirect response
func (ctx *Context) Redirect(location string, status ...int) {
code := http.StatusFound
if len(status) == 1 {
Expand All @@ -181,7 +186,7 @@ func (ctx *Context) Redirect(location string, status ...int) {
http.Redirect(ctx.Resp, ctx.Req.Request, location, code)
}

// Maximum amount of memory to use when parsing a multipart form.
// MaxMemory is the maximum amount of memory to use when parsing a multipart form.
// Set this to whatever value you prefer; default is 10 MB.
var MaxMemory = int64(1024 * 1024 * 10)

Expand Down Expand Up @@ -341,26 +346,34 @@ func (ctx *Context) SetCookie(name string, value string, others ...interface{})
cookie.MaxAge = int(v)
case int32:
cookie.MaxAge = int(v)
case func(*http.Cookie):
v(&cookie)
}
}

cookie.Path = "/"
if len(others) > 1 {
if v, ok := others[1].(string); ok && len(v) > 0 {
cookie.Path = v
} else if v, ok := others[1].(func(*http.Cookie)); ok {
v(&cookie)
}
}

if len(others) > 2 {
if v, ok := others[2].(string); ok && len(v) > 0 {
cookie.Domain = v
} else if v, ok := others[1].(func(*http.Cookie)); ok {
v(&cookie)
}
}

if len(others) > 3 {
switch v := others[3].(type) {
case bool:
cookie.Secure = v
case func(*http.Cookie):
v(&cookie)
default:
if others[3] != nil {
cookie.Secure = true
Expand All @@ -371,13 +384,25 @@ func (ctx *Context) SetCookie(name string, value string, others ...interface{})
if len(others) > 4 {
if v, ok := others[4].(bool); ok && v {
cookie.HttpOnly = true
} else if v, ok := others[1].(func(*http.Cookie)); ok {
v(&cookie)
}
}

if len(others) > 5 {
if v, ok := others[5].(time.Time); ok {
cookie.Expires = v
cookie.RawExpires = v.Format(time.UnixDate)
} else if v, ok := others[1].(func(*http.Cookie)); ok {
v(&cookie)
}
}

if len(others) > 6 {
for _, other := range others[6:] {
if v, ok := other.(func(*http.Cookie)); ok {
v(&cookie)
}
}
}

Expand Down
14 changes: 13 additions & 1 deletion context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"time"

"github.com/unknwon/com"
"gopkg.in/macaron.v1/cookie"

. "github.com/smartystreets/goconvey/convey"
)
Expand Down Expand Up @@ -209,7 +210,18 @@ func Test_Context(t *testing.T) {
So(err, ShouldBeNil)
ctx.SetCookie("user", "Unknwon", 1, "/", "localhost", true, true, t)
ctx.SetCookie("user", "Unknwon", int32(1), "/", "localhost", 1)
ctx.SetCookie("user", "Unknwon", int64(1))
called := false
ctx.SetCookie("user", "Unknwon", int64(1), func(c *http.Cookie) {
called = true
})
So(called, ShouldBeTrue)
ctx.SetCookie("user", "Unknown",
cookie.Secure(true),
cookie.HttpOnly(true),
cookie.Path("/"),
cookie.MaxAge(1),
cookie.Domain("localhost"),
)
})

resp := httptest.NewRecorder()
Expand Down
78 changes: 78 additions & 0 deletions cookie/helper.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
// Copyright 2020 The Macaron Authors
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.

// Package cookie contains helper functions for setting cookie values.
package cookie

import (
"net/http"
"time"
)

// MaxAge sets the maximum age for a provided cookie
func MaxAge(maxAge int) func(*http.Cookie) {
return func(c *http.Cookie) {
c.MaxAge = maxAge
}
}

// Path sets the path for a provided cookie
func Path(path string) func(*http.Cookie) {
return func(c *http.Cookie) {
c.Path = path
}
}

// Domain sets the domain for a provided cookie
func Domain(domain string) func(*http.Cookie) {
return func(c *http.Cookie) {
c.Domain = domain
}
}

// Secure sets the secure setting for a provided cookie
func Secure(secure bool) func(*http.Cookie) {
return func(c *http.Cookie) {
c.Secure = secure
}
}

// HttpOnly sets the HttpOnly setting for a provided cookie
func HttpOnly(httpOnly bool) func(*http.Cookie) {
return func(c *http.Cookie) {
c.HttpOnly = httpOnly
}
}

// HTTPOnly sets the HttpOnly setting for a provided cookie
func HTTPOnly(httpOnly bool) func(*http.Cookie) {
return func(c *http.Cookie) {
c.HttpOnly = httpOnly
}
}

// Expires sets the expires and rawexpires for a provided cookie
func Expires(expires time.Time) func(*http.Cookie) {
return func(c *http.Cookie) {
c.Expires = expires
c.RawExpires = expires.Format(time.UnixDate)
}
}

// SameSite sets the SameSite for a provided cookie
func SameSite(sameSite http.SameSite) func(*http.Cookie) {
return func(c *http.Cookie) {
c.SameSite = sameSite
}
}
8 changes: 4 additions & 4 deletions tree_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ func Test_getWildcards(t *testing.T) {
":id([0-9]+)_:name": result{"([0-9]+)_(.+)", ":id :name"},
"article_:id_:page.html": result{"article_(.+)_(.+).html", ":id :page"},
"article_:id:int_:page:string.html": result{"article_([0-9]+)_([\\w]+).html", ":id :page"},
"*": result{"*", ""},
"*.*": result{"*.*", ""},
"*": result{"*", ""},
"*.*": result{"*.*", ""},
}
Convey("Get wildcards", t, func() {
for key, result := range cases {
Expand All @@ -56,8 +56,8 @@ func Test_getRawPattern(t *testing.T) {
"article_:id_:page.html": "article_:id_:page.html",
"article_:id:int_:page:string.html": "article_:id_:page.html",
"article_:id([0-9]+)_:page([\\w]+).html": "article_:id_:page.html",
"*": "*",
"*.*": "*.*",
"*": "*",
"*.*": "*.*",
}
Convey("Get raw pattern", t, func() {
for k, v := range cases {
Expand Down