-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathauth.go
151 lines (121 loc) · 4.03 KB
/
auth.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
package main
import (
"database/sql"
"log"
"net/http"
"net/url"
"strings"
"time"
"crypto/rand"
"github.com/google/uuid"
"github.com/joelramilison/timespent/internal/database"
"golang.org/x/crypto/bcrypt"
"errors"
)
// sessionDurationString is the lifetime of a newly created session, written
// in time.ParseDuration syntax ("168h" is seven days).
const sessionDurationString = "168h"
// authedHandler is an HTTP handler that additionally receives the user that
// middlewareAuth authenticated for the request.
type authedHandler func(
	w http.ResponseWriter,
	req *http.Request,
	user database.User,
)
// middlewareAuth wraps handler so that it only runs for authenticated
// requests. It reads the session cookie via extractFromCookie, loads the
// user from cfg.DB, rejects sessions whose expiry time has passed, and
// checks the cookie's session ID against the user's stored bcrypt hash.
// If any step fails the client is redirected to /login and handler is not
// called; otherwise handler is invoked with the loaded user.
func (cfg *apiConfig) middlewareAuth(handler authedHandler) http.HandlerFunc {
	return func(w http.ResponseWriter, req *http.Request) {
		// toLogin redirects the client to /login. A normal browser request
		// (e.g. typing the URL) gets a plain HTTP redirect; an HTMX request
		// is told to navigate via the HX-Redirect response header instead.
		toLogin := func() {
			if req.Header.Get("HX-Request") == "" {
				http.Redirect(w, req, "/login", http.StatusFound)
				return
			}
			// htmx does not process its response headers on 3xx responses,
			// so pair HX-Redirect with a 200 rather than a bodiless 302.
			w.Header().Set("HX-Redirect", "/login")
			w.WriteHeader(http.StatusOK)
		}

		sessionID, userID, err := extractFromCookie(req)
		if err != nil {
			toLogin()
			return
		}

		user, err := cfg.DB.GetUser(req.Context(), userID)
		if err != nil {
			// Stop here instead of continuing with a zero-value user.
			log.Printf("couldn't find user with ID %v in database, error: %v", userID.String(), err)
			toLogin()
			return
		}

		// No active session: the stored expiry time is already in the past.
		if user.SessionExpiresAt.Time.Before(time.Now()) {
			toLogin()
			return
		}

		if err := bcrypt.CompareHashAndPassword(user.SessionIDHash, []byte(sessionID)); err != nil {
			// The cookie's session ID doesn't match the stored hash.
			toLogin()
			return
		}

		// At this point the session ID matches an unexpired session.
		handler(w, req, user)
	}
}
// createSession generates a new random session ID for userID, sets it in a
// "session_id" cookie on w, and returns what the caller needs to store: the
// bcrypt hash of the session ID and the session's expiry time.
//
// The cookie value has the form "<userID>:<url.QueryEscape(sessionID)>",
// which is the format extractFromCookie parses. On error no cookie is set
// and the returned hash and time are zero values.
func createSession(w http.ResponseWriter, userID uuid.UUID) ([]byte, sql.NullTime, error) {
	// 32 bytes from crypto/rand.
	sessionID := make([]byte, 32)
	if _, err := rand.Read(sessionID); err != nil {
		log.Printf("Failed to create a random session ID: %v", err)
		return nil, sql.NullTime{}, err
	}

	// Only the hash is returned for storage; middlewareAuth compares the
	// cookie's session ID against it.
	// NOTE(review): bcrypt is deliberately slow and middlewareAuth runs
	// CompareHashAndPassword on every authenticated request. For a random
	// high-entropy token a fast hash (e.g. SHA-256) would suffice, but that
	// change must be made together with the middleware and stored hashes.
	hashedSessionID, err := bcrypt.GenerateFromPassword(sessionID, bcrypt.DefaultCost)
	if err != nil {
		log.Printf("couldn't hash session ID, error: %v", err)
		return nil, sql.NullTime{}, err
	}

	expireDuration, err := time.ParseDuration(sessionDurationString)
	if err != nil {
		log.Printf("couldn't parse session time duration from string, error: %v", err)
		return nil, sql.NullTime{}, err
	}
	sessionExpiresAt := time.Now().Add(expireDuration)

	// The raw random bytes are not valid cookie characters, so escape them.
	cookieValue := userID.String() + ":" + url.QueryEscape(string(sessionID))

	// Path "/" makes the cookie apply site-wide; without it the browser
	// scopes the cookie to the directory of the request that set it.
	// SameSite=Lax keeps it off cross-site subrequests and POSTs.
	// NOTE(review): Secure is off, presumably for plain-HTTP local
	// development; enable it when the site is served over HTTPS.
	cookie := &http.Cookie{
		Name:     "session_id",
		Value:    cookieValue,
		Path:     "/",
		Expires:  sessionExpiresAt,
		HttpOnly: true,
		SameSite: http.SameSiteLaxMode,
		Secure:   false,
	}
	http.SetCookie(w, cookie)

	return hashedSessionID, sql.NullTime{Time: sessionExpiresAt, Valid: true}, nil
}
// extractFromCookie reads the "session_id" cookie from req and splits it into
// its session ID and user ID. The cookie value is expected in the form
// "<userID>:<url.QueryEscape(sessionID)>", as written by createSession.
//
// It returns the unescaped session ID and the parsed user ID. When the cookie
// is missing, empty, or malformed it returns zero values and a non-nil error.
func extractFromCookie(req *http.Request) (string, uuid.UUID, error) {
	// req.Cookie skips malformed cookies instead of rejecting the whole
	// Cookie header, so an unrelated bad cookie on the same host can't lock
	// the user out. It also considers every Cookie header, not just the first.
	cookie, err := req.Cookie("session_id")
	if err != nil || cookie.Value == "" {
		return "", uuid.UUID{}, errors.New("no session cookie found")
	}

	rawUserID, escapedSessionID, found := strings.Cut(cookie.Value, ":")
	if !found {
		// Never log the cookie value itself: it contains the session secret.
		log.Printf("Found session cookie that couldn't be separated using separator ':'")
		return "", uuid.UUID{}, errors.New("cookie string couldn't be parsed")
	}

	userID, err := uuid.Parse(rawUserID)
	if err != nil {
		log.Printf("couldn't parse UUID %q while extracting session cookie", rawUserID)
		return "", uuid.UUID{}, errors.New("UUID couldn't be parsed")
	}

	sessionID, err := url.QueryUnescape(escapedSessionID)
	if err != nil {
		log.Printf("couldn't unescape sessionID for user %v while extracting session cookie", userID)
		return "", uuid.UUID{}, errors.New("couldn't process sessionID")
	}

	return sessionID, userID, nil
}