Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Refresh Token #36

Draft
wants to merge 2 commits into
base: dev
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions account/refresh.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package account

import (
"database/sql"

"github.com/bhuvansingla/iitk-coin/database"
)

func UpdateToken(token string, rollNo string) error {
_, err := database.DB.Exec(("UPDATE REFRESH_TOKEN SET token = $1 WHERE rollNo = $2"), token, rollNo)
return err
}

func DeleteToken(rollNo string) error {
return UpdateToken("", rollNo)
}

func InvalidateAllTokens() error {
_, err := database.DB.Exec(("UPDATE REFRESH_TOKEN SET token = $1"), "")
return err
}

func GetToken(rollNo string) (string, error) {
var token string
err := database.DB.QueryRow(("SELECT token FROM REFRESH_TOKEN WHERE rollNo = $1"), rollNo).Scan(&token)

if err == sql.ErrNoRows {
return "", nil
}
if err != nil {
return "", err
}
return token, nil
}
13 changes: 13 additions & 0 deletions account/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,27 @@ const (
func Create(rollNo string, hashedPasssword string, name string) error {

role := NormalUser

stmt, err := database.DB.Prepare("INSERT INTO ACCOUNT (rollNo, name, password, coins, role) VALUES ($1, $2, $3, $4, $5)")
if err != nil {
return errors.NewHTTPError(err, http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError))
}

_, err = stmt.Exec(rollNo, name, hashedPasssword, 0, role)
if err != nil {
return errors.NewHTTPError(err, http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError))
}

stmt, err = database.DB.Prepare("INSERT INTO REFRESH_TOKEN (rollNo, token) VALUES ($1, $2)")
if err != nil {
return errors.NewHTTPError(err, http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError))
}

_, err = stmt.Exec(rollNo, "")
if err != nil {
return errors.NewHTTPError(err, http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError))
}

return nil
}

Expand Down
37 changes: 37 additions & 0 deletions auth/accessToken.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package auth

import (
"net/http"
"time"

"github.com/bhuvansingla/iitk-coin/errors"
"github.com/spf13/viper"
)

func GenerateAccessToken(rollNo string) (string, error) {

expirationTime := time.Now().Add(time.Duration(viper.GetInt("JWT.ACCESS_TOKEN.EXPIRATION_TIME_IN_MIN")) * time.Minute)

return generateToken(rollNo, expirationTime)
}

func IsAuthorized(endpoint func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {

return func(w http.ResponseWriter, r *http.Request) {
cookie, err := r.Cookie(viper.GetString("JWT.ACCESS_TOKEN.NAME"))
if err != nil {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte("bad token"))
return
}

err = isTokenValid(cookie)

if err == nil {
endpoint(w, r)
return
}

errors.WriteResponse(err, w)
}
}
105 changes: 44 additions & 61 deletions auth/jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"net/http"
"time"

"github.com/bhuvansingla/iitk-coin/errors"
"github.com/golang-jwt/jwt/v4"
"github.com/spf13/viper"
)
Expand All @@ -16,10 +17,33 @@ type Claims struct {
jwt.RegisteredClaims
}

func GenerateToken(rollNo string) (string, error) {
func GetRollNoFromRequest(r *http.Request) (string, error) {
cookie, err := r.Cookie(viper.GetString("JWT.ACCESS_TOKEN.NAME"))
if err != nil {
return "", err
}
return GetRollNoFromTokenCookie(cookie)
}

expirationTime := time.Now().Add(time.Duration(viper.GetInt("JWT.EXPIRATION_TIME_IN_MIN")) * time.Minute)
func GetRollNoFromTokenCookie(cookie *http.Cookie) (string, error) {
token := cookie.Value
claims := &Claims{}
_, err := jwt.ParseWithClaims(token, claims, keyFunc)
if err != nil {
return "", err
}
return claims.RollNo, nil
}

func keyFunc(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("invalid signing method")
}
return privateKey, nil
}

func generateToken(rollNo string, expirationTime time.Time) (string, error) {

claims := &Claims{
RollNo: rollNo,
RegisteredClaims: jwt.RegisteredClaims{
Expand All @@ -37,66 +61,25 @@ func GenerateToken(rollNo string) (string, error) {
return tokenString, nil
}

func IsAuthorized(endpoint func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {

return func(w http.ResponseWriter, r *http.Request) {
cookie, err := r.Cookie(viper.GetString("JWT.COOKIE_NAME"))
if err != nil {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte("bad token"))
return
}

token, err := jwt.Parse(cookie.Value, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("invalid signing method")
}
return privateKey, nil
})

if token.Valid {
endpoint(w, r)
return
} else if ve, ok := err.(*jwt.ValidationError); ok {
if ve.Errors&jwt.ValidationErrorMalformed != 0 {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte("bad token"))
return
} else if ve.Errors&(jwt.ValidationErrorExpired|jwt.ValidationErrorNotValidYet) != 0 {
// Token is either expired or not active yet
w.WriteHeader(http.StatusUnauthorized)
w.Write([]byte("token expired"))
return
} else {
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte(http.StatusText(http.StatusInternalServerError)))
return
}
func isTokenValid(cookie *http.Cookie) error {

token, err := jwt.Parse(cookie.Value, keyFunc)

if token.Valid {
return nil
}

jwtError, ok := err.(*jwt.ValidationError)

if ok {
if jwtError.Errors&jwt.ValidationErrorMalformed != 0 {
return errors.NewHTTPError(err, http.StatusBadRequest, "validation malformed")
} else if jwtError.Errors&(jwt.ValidationErrorExpired|jwt.ValidationErrorNotValidYet) != 0 {
return errors.NewHTTPError(err, http.StatusUnauthorized, "token expired")
} else {
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte(http.StatusText(http.StatusInternalServerError)))
return
return errors.NewHTTPError(nil, http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError))
}

}
}

func GetRollNoFromRequest(r *http.Request) (string, error) {
cookie, err := r.Cookie(viper.GetString("JWT.COOKIE_NAME"))
if err != nil {
return "", err
}
return GetRollNoFromTokenCookie(cookie)
}

func GetRollNoFromTokenCookie(cookie *http.Cookie) (string, error) {
token := cookie.Value
claims := &Claims{}
_, err := jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (interface{}, error) {
return privateKey, nil
})
if err != nil {
return "", err
} else {
return errors.NewHTTPError(err, http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError))
}
return claims.RollNo, nil
}
77 changes: 77 additions & 0 deletions auth/refresh.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package auth

import (
"net/http"
"time"

"github.com/bhuvansingla/iitk-coin/account"
"github.com/bhuvansingla/iitk-coin/errors"
"github.com/spf13/viper"
)

func GenerateRefreshToken(rollNo string) (string, error) {

expirationTime := time.Now().Add(time.Duration(viper.GetInt("JWT.REFRESH_TOKEN.EXPIRATION_TIME_IN_MIN")) * time.Minute)

refreshToken, err := generateToken(rollNo, expirationTime)
if err != nil {
return "", err
}

err = account.UpdateToken(refreshToken, rollNo)
if err != nil {
return "", err
}

return refreshToken, nil
}

func CheckRefreshTokenValidity(r *http.Request) (string, error) {

cookie, err := r.Cookie(viper.GetString("JWT.ACCESS_TOKEN.NAME"))
if err != nil {
return "", errors.NewHTTPError(err, http.StatusBadRequest, "bad access token")
}

err = isTokenValid(cookie)

if err == nil {
rollNo, err := GetRollNoFromTokenCookie(cookie)
if err != nil {
return "", errors.NewHTTPError(err, http.StatusBadRequest, "bad access token")
}
return rollNo, nil
}

clientError, ok := err.(errors.ClientError)
if !ok {
return "", errors.NewHTTPError(err, http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError))
}

if status, _ := clientError.ResponseHeaders(); status!=http.StatusUnauthorized {
return "", err
}

cookie, err = r.Cookie(viper.GetString("JWT.REFRESH_TOKEN.NAME"))
if err != nil {
return "", errors.NewHTTPError(err, http.StatusBadRequest, "bad refresh token")
}

rollNo, err := GetRollNoFromTokenCookie(cookie)
if err != nil {
return "", errors.NewHTTPError(err, http.StatusBadRequest, "bad refresh token")
}

refreshToken, err := account.GetToken(rollNo)
if err != nil {
return "", errors.NewHTTPError(err, http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError))
}

if refreshToken != cookie.Value {
return "", errors.NewHTTPError(err, http.StatusBadRequest, "bad refresh token")
}

err = isTokenValid(cookie)

return rollNo, err
}
8 changes: 6 additions & 2 deletions config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,13 @@ TAX:
INTRA_BATCH: 33

JWT:
ACCESS_TOKEN:
NAME: "access_token"
EXPIRATION_TIME_IN_MIN: 10
REFRESH_TOKEN:
NAME: "refresh_token"
EXPIRATION_TIME_IN_MIN: 50000
PRIVATE_KEY: "this-is-a-secret"
EXPIRATION_TIME_IN_MIN: 10
COOKIE_NAME: "token"

LOGGER:
LOG_LEVEL: 5 # Error: 2, Warn: 3, Info: 4, Debug: 5
Expand Down
12 changes: 11 additions & 1 deletion database/tables.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,16 @@ func createTables() (err error) {
log.Error(err.Error())
return
}
err = createRefreshTokenTable()
if err != nil {
log.Error(err.Error())
return
}
return
}

func createAccountTable() (err error) {
_, err = DB.Exec("create table if not exists ACCOUNT (rollNo text, name text, password text, coins int, role int)")
_, err = DB.Exec("create table if not exists ACCOUNT (rollNo text PRIMARY KEY NOT NULL, name text, password text, coins int, role int)")
return
}

Expand All @@ -58,3 +63,8 @@ func createRewardHistoryTable() (err error) {
_, err = DB.Exec("create table if not exists REWARD_HISTORY (id SERIAL PRIMARY KEY NOT NULL, rollNo text, coins int, time NUMERIC, remarks text)")
return
}

func createRefreshTokenTable() (err error) {
_, err = DB.Exec("create table if not exists REFRESH_TOKEN (rollNo text PRIMARY KEY NOT NULL, token text)")
return
}
Loading