forked from justinas/nosurf
-
Notifications
You must be signed in to change notification settings - Fork 0
/
handler.go
119 lines (99 loc) · 2.79 KB
/
handler.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
package csrf
import (
"context"
"encoding/base64"
"net/http"
)
// cookieName is the name of the cookie that stores the CSRF token. The same
// string is also used as the request header name in which clients must send
// the token back on requests other than GET and HEAD.
const cookieName = "CSRF"
// CSRFHandler wraps an http.Handler with CSRF protection. It keeps a token in
// a cookie and, for requests other than GET and HEAD, requires the token sent
// in a request header to pass verifyToken against the cookie's token before
// the wrapped handler is invoked.
type CSRFHandler struct {
	// successHandler receives requests that are exempt or pass verification;
	// failureHandler receives requests that fail verification.
	successHandler, failureHandler http.Handler
	// baseCookie is the template copied for every token cookie sent to the
	// client; its attributes (MaxAge, Secure, Path, ...) carry over, while
	// Name, Value and HttpOnly are overwritten.
	baseCookie http.Cookie
}
// defaultFailureHandler answers a request that failed CSRF verification with
// a plain-text "400 Bad Request" response.
func defaultFailureHandler(w http.ResponseWriter, _ *http.Request) {
	const status = http.StatusBadRequest
	http.Error(w, http.StatusText(status), status)
}
// New returns a CSRFHandler that wraps handler. Requests that are exempt from
// or pass the CSRF check are forwarded to handler; requests that fail it are
// answered by defaultFailureHandler with 400 Bad Request.
//
// The token cookie is Secure, lives for one year (MaxAge 31536000 seconds)
// and is scoped to Path "/".
func New(handler http.Handler) *CSRFHandler {
	baseCookie := http.Cookie{
		// Without an explicit Path the browser scopes the cookie to the
		// directory of the URL that set it (RFC 6265 §5.1.4). A token issued
		// under /a/ would then not be sent to /b/, and unsafe requests there
		// would arrive without a cookie and fail verification.
		Path:   "/",
		MaxAge: 31536000, // one year, in seconds
		Secure: true,
	}
	return &CSRFHandler{
		successHandler: handler,
		failureHandler: http.HandlerFunc(defaultFailureHandler),
		baseCookie:     baseCookie,
	}
}
// ServeHTTP implements http.Handler.
//
// For every request the token stored in the cookie named cookieName is
// decoded (standard base64). If it is missing, undecodable or not tokenLength
// bytes long, a new token and cookie are issued via RegenerateToken;
// otherwise the existing token is recorded with ctxSetToken. GET and HEAD
// requests are then passed straight to the wrapped handler. Any other method
// must also carry the base64-encoded token in the request header named
// cookieName; if verifyToken rejects it, the failure handler runs instead.
func (h *CSRFHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	// Attach a pointer to a per-request token slot to the context so that
	// helpers called below can fill it in.
	var token string
	r = r.WithContext(context.WithValue(r.Context(), nosurfKey, &token))

	// decode returns the standard-base64 decoding of s, or nil if s is invalid.
	decode := func(s string) []byte {
		b, err := base64.StdEncoding.DecodeString(s)
		if err != nil {
			return nil
		}
		return b
	}

	var realToken []byte
	if c, err := r.Cookie(cookieName); err == nil {
		realToken = decode(c.Value)
	}

	if len(realToken) == tokenLength {
		ctxSetToken(r, realToken)
	} else {
		// realToken deliberately keeps the (invalid) value read from the
		// cookie. An unsafe request without a valid cookie is therefore
		// verified against it below, not against the freshly issued token,
		// and is presumably rejected by verifyToken.
		h.RegenerateToken(w, r)
	}

	// The response may depend on the token cookie (a new one may have just
	// been set), so shared caches must key on Cookie.
	w.Header().Add("vary", "cookie")

	switch r.Method {
	case http.MethodGet, http.MethodHead:
		h.successHandler.ServeHTTP(w, r)
		return
	}

	// NOTE: a Referer same-origin check for HTTPS requests (MITM defence) is
	// currently disabled. If it is revived, do not gate it on r.URL.Scheme:
	// for server requests net/http usually leaves that field empty, so the
	// check would never run.
	sentToken := decode(r.Header.Get(cookieName))
	if verifyToken(realToken, sentToken) {
		h.successHandler.ServeHTTP(w, r)
		return
	}
	h.failureHandler.ServeHTTP(w, r)
}
// RegenerateToken issues a brand-new CSRF token for the request. It creates
// the token with generateToken, hands it to setTokenCookie (which records it
// for the request and sends it to the client as the token cookie), and
// returns the token as reported by Token(r).
func (h *CSRFHandler) RegenerateToken(w http.ResponseWriter, r *http.Request) string {
	h.setTokenCookie(w, r, generateToken())
	return Token(r)
}
// setTokenCookie records token for the current request via ctxSetToken and
// writes it to the client as an HttpOnly cookie named cookieName. The cookie
// is a copy of the handler's baseCookie template whose value is the token
// encoded with standard base64.
func (h *CSRFHandler) setTokenCookie(w http.ResponseWriter, r *http.Request, token []byte) {
	ctxSetToken(r, token)

	// Work on a value copy so the shared baseCookie template is never mutated.
	c := h.baseCookie
	c.HttpOnly = true
	c.Name = cookieName
	c.Value = base64.StdEncoding.EncodeToString(token)
	http.SetCookie(w, &c)
}
// SetFailureHandler replaces the handler that answers requests failing CSRF
// verification (by default a plain 400 Bad Request).
// func (h *CSRFHandler) SetFailureHandler(handler http.Handler) {
// h.failureHandler = handler
// }
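// SetBaseCookie sets the template cookie copied for every CSRF token cookie;
// its Name, Value and HttpOnly fields are overwritten when the cookie is sent.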
// func (h *CSRFHandler) SetBaseCookie(cookie http.Cookie) {
// h.baseCookie = cookie
// }
// func (h CSRFHandler) getcookieName() string {
// if h.baseCookie.Name != "" {
// return h.baseCookie.Name
// }
// return cookieName
// }