diff --git a/auth/auth.go b/auth/auth.go index 8e6a87ff..6df52067 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -29,6 +29,8 @@ type Authentication interface { ReadUserWithSession(ctx echo.Context) error RenewAccessToken(ctx echo.Context) error VerifyEmail(ctx echo.Context) error + ResetPassword(ctx echo.Context) error + Invites(ctx echo.Context) error } // New is the constructor function returns an Authentication implementation diff --git a/auth/bcrypt.go b/auth/bcrypt.go index d25f0514..c6e6b430 100644 --- a/auth/bcrypt.go +++ b/auth/bcrypt.go @@ -1,6 +1,8 @@ package auth -import "golang.org/x/crypto/bcrypt" +import ( + "golang.org/x/crypto/bcrypt" +) const bcryptMinCost = 6 diff --git a/auth/invites.go b/auth/invites.go new file mode 100644 index 00000000..f7fb01bc --- /dev/null +++ b/auth/invites.go @@ -0,0 +1,38 @@ +package auth + +import ( + "encoding/json" + "net/http" + "strings" + + "github.com/labstack/echo/v4" +) + +type List struct { + Emails string +} + +func (a *auth) Invites(ctx echo.Context) error { + var list List + err := json.NewDecoder(ctx.Request().Body).Decode(&list) + if err != nil { + a.logger.Log(ctx, err) + return ctx.JSON(http.StatusBadRequest, echo.Map{ + "error": err.Error(), + "msg": "error decode body, expecting and array of emails", + }) + } + err = a.emailClient.WelcomeEmail(strings.Split(list.Emails, ",")) + if err != nil { + a.logger.Log(ctx, err) + return ctx.JSON(http.StatusInternalServerError, echo.Map{ + "error": err.Error(), + "msg": "err sending invites", + }) + } + + a.logger.Log(ctx, err) + return ctx.JSON(http.StatusAccepted, echo.Map{ + "msg": "success", + }) +} diff --git a/auth/jwt.go b/auth/jwt.go index 9457d251..3cf4e23c 100644 --- a/auth/jwt.go +++ b/auth/jwt.go @@ -6,7 +6,6 @@ import ( "github.com/containerish/OpenRegistry/types" "github.com/golang-jwt/jwt" - "github.com/google/uuid" "golang.org/x/oauth2" ) @@ -158,7 +157,7 @@ func (a *auth) createOAuthClaims(u types.User, token *oauth2.Token) PlatformClai StandardClaims: jwt.StandardClaims{ Audience: a.c.Endpoint(), ExpiresAt: time.Now().Add(time.Hour).Unix(), - Id: uuid.NewString(), + Id: u.Id, IssuedAt: time.Now().Unix(), Issuer: a.c.Endpoint(), NotBefore: time.Now().Unix(), diff --git a/auth/jwt_middleware.go b/auth/jwt_middleware.go index a14bddfa..eb8cb5fd 100644 --- a/auth/jwt_middleware.go +++ b/auth/jwt_middleware.go @@ -3,18 +3,29 @@ package auth import ( "fmt" "net/http" + "strings" "time" "github.com/containerish/OpenRegistry/types" + "github.com/fatih/color" "github.com/golang-jwt/jwt" "github.com/labstack/echo/v4" "github.com/labstack/echo/v4/middleware" ) +const ( + AccessCookieKey = "access" + QueryToken = "token" +) + // JWT basically uses the default JWT middleware by echo, but has a slightly different skipper func func (a *auth) JWT() echo.MiddlewareFunc { return middleware.JWTWithConfig(middleware.JWTConfig{ Skipper: func(ctx echo.Context) bool { + if strings.HasPrefix(ctx.Request().RequestURI, "/auth") { + return false + } + // if JWT_AUTH is not set, we don't need to perform JWT authentication jwtAuth, ok := ctx.Get(JWT_AUTH_KEY).(bool) if !ok { @@ -33,6 +44,7 @@ func (a *auth) JWT() echo.MiddlewareFunc { ErrorHandlerWithContext: func(err error, ctx echo.Context) error { // ErrorHandlerWithContext only logs the failing requtest ctx.Set(types.HandlerStartTime, time.Now()) + color.Red(ctx.QueryParam("token")) a.logger.Log(ctx, err) return ctx.JSON(http.StatusUnauthorized, echo.Map{ "error": err.Error(), @@ -45,6 +57,7 @@ func (a *auth) JWT() echo.MiddlewareFunc { SigningKeys: map[string]interface{}{}, SigningMethod: jwt.SigningMethodHS256.Name, Claims: &Claims{}, + TokenLookup: fmt.Sprintf("cookie:%s,header:%s", AccessCookieKey, echo.HeaderAuthorization), }) } @@ -73,7 +86,7 @@ func (a *auth) ACL() echo.MiddlewareFunc { username := ctx.Param("username") - user, err := a.pgStore.GetUserById(ctx.Request().Context(), claims.Id) + user, err := a.pgStore.GetUserById(ctx.Request().Context(), claims.Id, false) if err != nil { a.logger.Log(ctx, err) return ctx.NoContent(http.StatusUnauthorized) diff --git a/auth/renew.go b/auth/renew.go index 10b0f46a..11a8afe5 100644 --- a/auth/renew.go +++ b/auth/renew.go @@ -55,7 +55,7 @@ func (a *auth) RenewAccessToken(ctx echo.Context) error { } userId := claims.Id - user, err := a.pgStore.GetUserById(ctx.Request().Context(), userId) + user, err := a.pgStore.GetUserById(ctx.Request().Context(), userId, false) if err != nil { a.logger.Log(ctx, err) return ctx.JSON(http.StatusUnauthorized, echo.Map{ diff --git a/auth/reset_password.go b/auth/reset_password.go new file mode 100644 index 00000000..8c6f2efd --- /dev/null +++ b/auth/reset_password.go @@ -0,0 +1,125 @@ +package auth + +import ( + "encoding/json" + "fmt" + "net/http" + + "github.com/containerish/OpenRegistry/services/email" + "github.com/containerish/OpenRegistry/types" + "github.com/golang-jwt/jwt" + "github.com/labstack/echo/v4" +) + +func (a *auth) ResetPassword(ctx echo.Context) error { + token, ok := ctx.Get("user").(*jwt.Token) + if !ok { + err := fmt.Errorf("JWT token can not be empty") + a.logger.Log(ctx, err) + return ctx.JSON(http.StatusUnauthorized, echo.Map{ + "error": err.Error(), + }) + } + + c, ok := token.Claims.(*Claims) + if !ok { + err := fmt.Errorf("invalid claims in JWT") + a.logger.Log(ctx, err) + return ctx.JSON(http.StatusInternalServerError, echo.Map{ + "error": err.Error(), + }) + } + + userId := c.Id + user, err := a.pgStore.GetUserById(ctx.Request().Context(), userId, true) + if err != nil { + a.logger.Log(ctx, err) + return ctx.JSON(http.StatusNotFound, echo.Map{ + "error": err.Error(), + }) + } + + var pwd *types.Password + + kind := ctx.QueryParam("kind") + + if kind == "forgot" { + if err = a.emailClient.SendEmail(user, token.Raw, email.ResetPasswordEmailKind); err != nil { + a.logger.Log(ctx, err) + return ctx.JSON(http.StatusInternalServerError, echo.Map{ + "error": err.Error(), + "msg": "error sending reset password link", + }) + } + a.logger.Log(ctx, err) + return ctx.JSON(http.StatusAccepted, echo.Map{ + "msg": "success", + }) + + } + + err = json.NewDecoder(ctx.Request().Body).Decode(&pwd) + if err != nil { + a.logger.Log(ctx, err) + return ctx.JSON(http.StatusBadRequest, echo.Map{ + "error": err.Error(), + "msg": "request body could not be decoded", + }) + } + + if kind == "forgot_password_callback" { + hashPassword, err := a.hashPassword(pwd.NewPassword) + if err != nil { + a.logger.Log(ctx, err) + return ctx.JSON(http.StatusInternalServerError, echo.Map{ + "error": err.Error(), + }) + } + + if err = a.pgStore.UpdateUserPWD(ctx.Request().Context(), userId, hashPassword); err != nil { + a.logger.Log(ctx, err) + return ctx.JSON(http.StatusInternalServerError, echo.Map{ + "error": err.Error(), + }) + } + + return ctx.NoContent(http.StatusOK) + } + + if !a.verifyPassword(user.Password, pwd.OldPassword) { + err = fmt.Errorf("passwords do not match") + a.logger.Log(ctx, err) + return ctx.JSON(http.StatusBadRequest, echo.Map{ + "error": err.Error(), + }) + } + + if pwd.OldPassword == pwd.NewPassword { + err = fmt.Errorf("new password can not be same as old password") + a.logger.Log(ctx, err) + return ctx.JSON(http.StatusBadRequest, echo.Map{ + "error": err.Error(), + }) + } + + newHashedPwd, err := a.hashPassword(pwd.NewPassword) + if err != nil { + a.logger.Log(ctx, err) + return ctx.JSON(http.StatusInternalServerError, echo.Map{ + "error": err.Error(), + }) + } + + if err = a.pgStore.UpdateUserPWD(ctx.Request().Context(), userId, newHashedPwd); err != nil { + a.logger.Log(ctx, err) + return ctx.JSON(http.StatusInternalServerError, echo.Map{ + "error": err.Error(), + "msg": "error updating user in db", + }) + } + + a.logger.Log(ctx, nil) + return ctx.JSON(http.StatusAccepted, echo.Map{ + "msg": "success", + }) +} diff --git a/auth/signin.go b/auth/signin.go index 052146a0..70abd07b 100644 --- a/auth/signin.go +++ b/auth/signin.go @@ -40,6 +40,7 @@ func (a *auth) SignIn(ctx echo.Context) error { a.logger.Log(ctx, err) return ctx.JSON(http.StatusBadRequest, echo.Map{ "error": err.Error(), + "msg": "error while get user", }) } diff --git a/auth/signup.go b/auth/signup.go index 42f34aff..97bf4fd1 100644 --- a/auth/signup.go +++ b/auth/signup.go @@ -10,6 +10,7 @@ import ( "unicode" "github.com/containerish/OpenRegistry/config" + "github.com/containerish/OpenRegistry/services/email" "github.com/containerish/OpenRegistry/types" "github.com/google/uuid" "github.com/labstack/echo/v4" @@ -79,15 +80,7 @@ func (a *auth) SignUp(ctx echo.Context) error { }) } - token, err := a.newWebLoginToken(newUser.Id, newUser.Username, "access") - if err != nil { - a.logger.Log(ctx, err) - return ctx.JSON(http.StatusInternalServerError, echo.Map{ - "error": err.Error(), - "code": "CREATE_NEW_ACCESS_TOKEN", - }) - } - + token := uuid.NewString() err = a.pgStore.AddVerifyEmail(ctx.Request().Context(), token, newUser.Id) if err != nil { ctx.Set(types.HttpEndpointErrorKey, err.Error()) @@ -97,7 +90,7 @@ func (a *auth) SignUp(ctx echo.Context) error { }) } - if err = a.emailClient.SendEmail(newUser, token); err != nil { + if err = a.emailClient.SendEmail(newUser, token, email.VerifyEmailKind); err != nil { ctx.Set(types.HttpEndpointErrorKey, err.Error()) return ctx.JSON(http.StatusInternalServerError, echo.Map{ "error": err.Error(), @@ -106,7 +99,7 @@ func (a *auth) SignUp(ctx echo.Context) error { a.logger.Log(ctx, err) return ctx.JSON(http.StatusCreated, echo.Map{ - "message": "user successfully created", + "message": "signup was successful, please check your email to activate your account", }) } diff --git a/auth/verify_email.go b/auth/verify_email.go index c92f7a54..fe9a6936 100644 --- a/auth/verify_email.go +++ b/auth/verify_email.go @@ -1,64 +1,39 @@ package auth import ( - "encoding/base64" - "fmt" "net/http" "time" "github.com/containerish/OpenRegistry/types" - "github.com/golang-jwt/jwt" + "github.com/google/uuid" "github.com/labstack/echo/v4" ) func (a *auth) VerifyEmail(ctx echo.Context) error { ctx.Set(types.HandlerStartTime, time.Now()) - token := ctx.QueryParam("token") if token == "" { return ctx.JSON(http.StatusBadRequest, echo.Map{ - "error": "EMPTY_TOKEN", + "error": "token can not be empty", }) } - jToken, err := base64.StdEncoding.DecodeString(token) - if err != nil { + if _, err := uuid.Parse(token); err != nil { return ctx.JSON(http.StatusBadRequest, echo.Map{ - "error": err.Error(), - "msg": "EMPTY_TOKEN", + "error": err.Error(), + "message": "ERR_PARSE_TOKEN", }) } - var t *jwt.Token - - _, err = jwt.Parse(string(jToken), func(jt *jwt.Token) (interface{}, error) { - if jt == nil { - return nil, fmt.Errorf("ERR_PARSE_JWT_TOKEN") - } - t = jt - return nil, nil - }) - - claims, ok := t.Claims.(jwt.MapClaims) - if !ok { - ctx.Set(types.HttpEndpointErrorKey, err.Error()) + userId, err := a.pgStore.GetVerifyEmail(ctx.Request().Context(), token) + if err != nil { return ctx.JSON(http.StatusBadRequest, echo.Map{ - "error": err.Error(), - "msg": "ERR_CONVERT_CLAIMS", - }) - } - - id := claims["ID"].(string) - tokenFromDb, err := a.pgStore.GetVerifyEmail(ctx.Request().Context(), id) - if tokenFromDb != string(jToken) { - ctx.Set(types.HttpEndpointErrorKey, err.Error()) - return ctx.JSON(http.StatusInternalServerError, echo.Map{ - "error": err.Error(), - "msg": "ERR_TOKEN_MISMATCH", + "error": err.Error(), + "message": "invalid token", }) } - user, err := a.pgStore.GetUserById(ctx.Request().Context(), id) + user, err := a.pgStore.GetUserById(ctx.Request().Context(), userId, false) if err != nil { ctx.Set(types.HttpEndpointErrorKey, err.Error()) return ctx.JSON(http.StatusInternalServerError, echo.Map{ @@ -69,7 +44,7 @@ func (a *auth) VerifyEmail(ctx echo.Context) error { user.IsActive = true - err = a.pgStore.UpdateUser(ctx.Request().Context(), id, user) + err = a.pgStore.UpdateUser(ctx.Request().Context(), userId, user) if err != nil { ctx.Set(types.HttpEndpointErrorKey, err.Error()) return ctx.JSON(http.StatusInternalServerError, echo.Map{ @@ -78,7 +53,7 @@ func (a *auth) VerifyEmail(ctx echo.Context) error { }) } - err = a.pgStore.DeleteVerifyEmail(ctx.Request().Context(), id) + err = a.pgStore.DeleteVerifyEmail(ctx.Request().Context(), token) if err != nil { ctx.Set(types.HttpEndpointErrorKey, err.Error()) return ctx.JSON(http.StatusInternalServerError, echo.Map{ @@ -88,6 +63,6 @@ func (a *auth) VerifyEmail(ctx echo.Context) error { } return ctx.JSON(http.StatusOK, echo.Map{ - "message": "success", + "message": "user profile activated successfully", }) } diff --git a/config/config.go b/config/config.go index 506cc17b..458ae642 100644 --- a/config/config.go +++ b/config/config.go @@ -71,11 +71,12 @@ type ( } Email struct { - ApiKey string `mapstructure:"api_key"` - SendAs string `mapstructure:"send_as"` - VerifyEmailTemplate string `mapstructure:"verify_template_id"` - WelcomeEmailTemplate string `mapstructure:"welcome_template_id"` - Enabled bool `mapstructure:"enabled"` + Enabled bool `mapstructure:"enabled"` + ApiKey string `mapstructure:"api_key"` + SendAs string `mapstructure:"send_as"` + VerifyEmailTemplateId string `mapstructure:"verify_template_id"` + ForgotPasswordTemplateId string `mapstructure:"forgot_password_template_id"` + WelcomeEmailTemplateId string `mapstructure:"welcome_template_id"` } ) diff --git a/router/helpers.go b/router/helpers.go index 0e253552..3b830f5a 100644 --- a/router/helpers.go +++ b/router/helpers.go @@ -12,7 +12,9 @@ import ( // RegisterAuthRoutes includes all the auth related endpoints func RegisterAuthRoutes(authRouter *echo.Group, authSvc auth.Authentication) { + //send-email/welcome authRouter.Add(http.MethodPost, "/signup", authSvc.SignUp) + authRouter.Add(http.MethodPost, "/send-email/welcome", authSvc.Invites) authRouter.Add(http.MethodGet, "/signup/verify", authSvc.VerifyEmail) authRouter.Add(http.MethodPost, "/signin", authSvc.SignIn) authRouter.Add(http.MethodPost, "/token", authSvc.SignIn) @@ -20,4 +22,6 @@ func RegisterAuthRoutes(authRouter *echo.Group, authSvc auth.Authentication) { authRouter.Add(http.MethodGet, "/sessions/me", authSvc.ReadUserWithSession) authRouter.Add(http.MethodDelete, "/sessions", authSvc.ExpireSessions) authRouter.Add(http.MethodGet, "/renew", authSvc.RenewAccessToken) + authRouter.Add(http.MethodPost, "/reset-password", authSvc.ResetPassword, authSvc.JWT()) + } diff --git a/services/email/createEmail.go b/services/email/createEmail.go index d919538c..9d09dd7f 100644 --- a/services/email/createEmail.go +++ b/services/email/createEmail.go @@ -1,19 +1,44 @@ package email -import "github.com/sendgrid/sendgrid-go/helpers/mail" +import ( + "fmt" -func (e *email) CreateEmail(mailReq *Mail) (*mail.SGMailV3, error) { + "github.com/containerish/OpenRegistry/types" + "github.com/sendgrid/sendgrid-go/helpers/mail" +) + +func (e *email) CreateEmail(u *types.User, kind EmailKind, token string) (*mail.SGMailV3, error) { + mailReq := &Mail{} m := mail.NewV3Mail() - email := mail.NewEmail(mailReq.Name, e.config.SendAs) - m.SetFrom(email) + mailReq.To = append(mailReq.To, u.Email) + mailReq.Data.Username = u.Username + + switch kind { + case VerifyEmailKind: + m.SetTemplateID(e.config.VerifyEmailTemplateId) + mailReq.Name = "OpenRegistry" + mailReq.Subject = "Verify Email" + mailReq.Data.Link = fmt.Sprintf("%s/auth/signup/verify?token=%s", e.backendEndpoint, token) + + case ResetPasswordEmailKind: + m.SetTemplateID(e.config.ForgotPasswordTemplateId) + mailReq.Name = "OpenRegistry" + mailReq.Subject = "Forgot Password" + mailReq.Data.Link = fmt.Sprintf("%s/auth/reset-password?token=%s", e.backendEndpoint, token) - m.SetTemplateID(e.config.VerifyEmailTemplate) + default: + return nil, fmt.Errorf("incorrect email kind") + } + email := mail.NewEmail(mailReq.Name, e.config.SendAs) + m.SetFrom(email) p := mail.NewPersonalization() + tos := []*mail.Email{ - mail.NewEmail(mailReq.To, mailReq.To), + mail.NewEmail(mailReq.To[0], mailReq.To[0]), } + p.AddTos(tos...) p.SetDynamicTemplateData("user", mailReq.Data.Username) diff --git a/services/email/email.go b/services/email/email.go index dad0f3f2..8b71fca9 100644 --- a/services/email/email.go +++ b/services/email/email.go @@ -21,17 +21,18 @@ type MailData struct { } type Mail struct { + Data MailData Name string - To string + To []string Subject string Body string Mtype MailType - Data MailData } type MailService interface { - CreateEmail(mailReq *Mail) (*mail.SGMailV3, error) - SendEmail(u *types.User, token string) error + CreateEmail(u *types.User, kind EmailKind, token string) (*mail.SGMailV3, error) + SendEmail(u *types.User, token string, kind EmailKind) error + WelcomeEmail(list []string) error } func New(config *config.Email, backendEndpoint string) MailService { diff --git a/services/email/sendEmail.go b/services/email/sendEmail.go index 3a3f26ad..23995af8 100644 --- a/services/email/sendEmail.go +++ b/services/email/sendEmail.go @@ -1,26 +1,23 @@ package email import ( - "encoding/base64" "fmt" "net/http" "github.com/containerish/OpenRegistry/types" ) -func (e *email) SendEmail(u *types.User, token string) error { - token = base64.StdEncoding.EncodeToString([]byte(token)) - mailReq := &Mail{ - Name: "Verify OpenRegistry Signup", - To: u.Email, - Subject: "OpenRegistry - Signup Verification Email", - Data: MailData{ - Username: u.Username, - Link: fmt.Sprintf("%s/auth/signup/verify?token=%s", e.backendEndpoint, token), - }, - } +const ( + //KindWelcomeEmail + WelcomeEmailKind EmailKind = iota + VerifyEmailKind + ResetPasswordEmailKind +) + +type EmailKind int8 - mailMsg, err := e.CreateEmail(mailReq) +func (e *email) SendEmail(u *types.User, token string, kind EmailKind) error { + mailMsg, err := e.CreateEmail(u, kind, token) if err != nil { return fmt.Errorf("ERR_CREATE_EMAIL: %w", err) } diff --git a/services/email/welcome_email.go b/services/email/welcome_email.go new file mode 100644 index 00000000..b6c68a4f --- /dev/null +++ b/services/email/welcome_email.go @@ -0,0 +1,36 @@ +package email + +import ( + "fmt" + "net/http" + + "github.com/sendgrid/sendgrid-go/helpers/mail" +) + +func (e *email) WelcomeEmail(list []string) error { + + mailReq := &Mail{} + m := mail.NewV3Mail() + + m.SetTemplateID(e.config.WelcomeEmailTemplateId) + mailReq.Name = "OpenRegistry" + mailReq.Subject = "Welcome to OpenRegistry" + mailReq.Data.Link = fmt.Sprintf("%s/send-email/welcome", e.backendEndpoint) + + email := mail.NewEmail(mailReq.Name, e.config.SendAs) + m.SetFrom(email) + p := mail.NewPersonalization() + + var tos []*mail.Email + for _, v := range list { + tos = append(tos, mail.NewEmail(v, v)) + } + p.AddTos(tos...) + m.AddPersonalizations(p) + + resp, err := e.client.Send(m) + if err != nil && resp.StatusCode != http.StatusAccepted { + return fmt.Errorf("ERR_SEND_EMAIL: %w", err) + } + return nil +} diff --git a/store/postgres/postgres.go b/store/postgres/postgres.go index c42efa00..531c3d04 100644 --- a/store/postgres/postgres.go +++ b/store/postgres/postgres.go @@ -23,14 +23,15 @@ type UserStore interface { AddOAuthUser(ctx context.Context, u *types.User) error UserExists(ctx context.Context, id string) bool GetUser(ctx context.Context, identifier string, wihtPassword bool) (*types.User, error) - GetUserById(ctx context.Context, userId string) (*types.User, error) + GetUserById(ctx context.Context, userId string, wihtPassword bool) (*types.User, error) GetUserWithSession(ctx context.Context, sessionId string) (*types.User, error) - AddSession(ctx context.Context, sessionId, refreshToken, owner string) error - DeleteSession(ctx context.Context, sessionId, userId string) error - DeleteAllSessions(ctx context.Context, userId string) error UpdateUser(ctx context.Context, identifier string, u *types.User) error + UpdateUserPWD(ctx context.Context, identifier string, newPassword string) error DeleteUser(ctx context.Context, identifier string) error IsActive(ctx context.Context, identifier string) bool + AddSession(ctx context.Context, sessionId, refreshToken, owner string) error + DeleteSession(ctx context.Context, sessionId, userId string) error + DeleteAllSessions(ctx context.Context, userId string) error AddVerifyEmail(ctx context.Context, userId, token string) error GetVerifyEmail(ctx context.Context, userId string) (string, error) DeleteVerifyEmail(ctx context.Context, userId string) error diff --git a/store/postgres/queries/users.go b/store/postgres/queries/users.go index e3f23304..8b3135ec 100644 --- a/store/postgres/queries/users.go +++ b/store/postgres/queries/users.go @@ -4,14 +4,17 @@ package queries var ( AddUser = `insert into users (id, is_active, username, name, email, password, created_at, updated_at) values ($1, $2, $3, $4, $5, $6, $7, $8);` - GetUser = `select id, is_active, username, email, created_at, updated_at from users where email=$1 or username=$1;` - GetUserWithPassword = `select id, is_active, username, email, password, created_at, updated_at from users where email=$1 or username=$1;` - GetUserById = `select id, is_active, username, email, created_at, updated_at from users where id=$1;` - GetUserWithSession = `select id, is_active, name, username, email, created_at, updated_at from users where id=(select owner from session where id=$1);` - UpdateUser = `update user set username = $1, email = $2, password = $3, updated_at = $4 where username = $5;` - DeleteUser = `delete from user where username = $1;` - GetAllEmails = `select email from users;` - AddOAuthUser = `insert into users (id, username, email, created_at, updated_at, + GetUser = `select id, is_active, username, email, created_at, updated_at from users where email=$1 or username=$1;` + GetUserWithPassword = `select id, is_active, username, email, password, created_at, updated_at from users where email=$1 or username=$1;` + GetUserById = `select id, is_active, username, email, created_at, updated_at from users where id=$1;` + GetUserByIdWithPassword = `select id, is_active, username, email, password, created_at, updated_at from users where id=$1;` + GetUserWithSession = `select id, is_active, name, username, email, created_at, updated_at from users where id=(select owner from session where id=$1);` + UpdateUser = `update users set is_active = $1, updated_at = $2 where id = $3;` + SetUserActive = `update users set is_active=true where id=$1` + DeleteUser = `delete from users where username = $1;` + UpdateUserPwd = `update users set password=$1 where id=$2;` + GetAllEmails = `select email from users;` + AddOAuthUser = `insert into users (id, username, email, created_at, updated_at, bio, type, gravatar_id, login, name, node_id, avatar_url, oauth_id, is_active, hireable) values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15) on conflict (email) do update set username=$2, email=$3` ) diff --git a/store/postgres/queries/verify_email.go b/store/postgres/queries/verify_email.go index 11ad484a..7bd4e543 100644 --- a/store/postgres/queries/verify_email.go +++ b/store/postgres/queries/verify_email.go @@ -2,6 +2,6 @@ package queries const ( AddVerifyUser = `insert into verify_emails (token,user_id) values ($1,$2);` - GetVerifyUser = `select token from verify_emails where user_id=$1;` - DeleteVerifyUser = `delete from verify_emails where user_id=$1;` + GetVerifyUser = `select user_id from verify_emails where token=$1;` + DeleteVerifyUser = `delete from verify_emails where token=$1;` ) diff --git a/store/postgres/users.go b/store/postgres/users.go index 9a944043..ca2379c3 100644 --- a/store/postgres/users.go +++ b/store/postgres/users.go @@ -107,12 +107,30 @@ func (p *pg) GetUser(ctx context.Context, identifier string, withPassword bool) return &user, nil } -func (p *pg) GetUserById(ctx context.Context, userId string) (*types.User, error) { +func (p *pg) GetUserById(ctx context.Context, userId string, withPassword bool) (*types.User, error) { childCtx, cancel := context.WithTimeout(ctx, time.Millisecond*100) defer cancel() - row := p.conn.QueryRow(childCtx, queries.GetUserById, userId) + if withPassword { + row := p.conn.QueryRow(childCtx, queries.GetUserByIdWithPassword, userId) + + var user types.User + if err := row.Scan( + &user.Id, + &user.IsActive, + &user.Username, + &user.Email, + &user.Password, + &user.CreatedAt, + &user.UpdatedAt, + ); err != nil { + return nil, fmt.Errorf("ERR_GET_USER_BY_ID_PWD_HASH: %w", err) + } + return &user, nil + } + + row := p.conn.QueryRow(childCtx, queries.GetUserById, userId) var user types.User err := row.Scan( &user.Id, @@ -123,7 +141,7 @@ func (p *pg) GetUserById(ctx context.Context, userId string) (*types.User, error &user.UpdatedAt, ) if err != nil { - return nil, fmt.Errorf("ERR_SESSION_NOT_FOUND: %w", err) + return nil, fmt.Errorf("ERR_GET_USER_BY_ID: %w", err) } return &user, nil @@ -153,54 +171,32 @@ func (p *pg) GetUserWithSession(ctx context.Context, sessionId string) (*types.U // UpdateUser //update users set username = $1, email = $2, updated_at = $3 where username = $4 -func (p *pg) UpdateUser(ctx context.Context, identifier string, u *types.User) error { - if err := u.Validate(); err != nil { - return err - } +func (p *pg) UpdateUser(ctx context.Context, userId string, u *types.User) error { childCtx, cancel := context.WithTimeout(ctx, time.Millisecond*100) defer cancel() t := time.Now() - _, err := p.conn.Exec(childCtx, queries.UpdateUser, u.Username, u.Email, t, u.IsActive, identifier) + _, err := p.conn.Exec(childCtx, queries.UpdateUser, u.IsActive, t, userId) if err != nil { return fmt.Errorf("error updating user: %s", err) } return nil } -// DeleteUser - delete from user where username = $1; -func (p *pg) DeleteUser(ctx context.Context, identifier string) error { +func (p *pg) UpdateUserPWD(ctx context.Context, identifier string, newPassword string) error { + if newPassword == "" { + return fmt.Errorf("insufficient feilds for updating user") + } childCtx, cancel := context.WithTimeout(ctx, time.Millisecond*100) defer cancel() - _, err := p.conn.Exec(childCtx, queries.DeleteUser, identifier) + _, err := p.conn.Exec(childCtx, queries.UpdateUserPwd, newPassword, identifier) if err != nil { - return fmt.Errorf("error deleting user: %s", identifier) + return fmt.Errorf("error updating user: %s", err) } return nil } -//IsActive - if the user has logged in, isActive returns true -// this method is also useful for limiting access of malicious actors -func (p *pg) IsActive(ctx context.Context, identifier string) bool { - childCtx, cancel := context.WithTimeout(ctx, time.Millisecond*100) - defer cancel() - row := p.conn.QueryRow(childCtx, queries.GetUser, identifier) - return row != nil -} - -func (p *pg) UserExists(ctx context.Context, id string) bool { - childCtx, cancel := context.WithTimeout(ctx, time.Millisecond*100) - defer cancel() - - row, err := p.GetUserById(childCtx, id) - if err != nil || row == nil { - return false - } - - return true -} - func (p *pg) AddVerifyEmail(ctx context.Context, token, userId string) error { childCtx, cancel := context.WithTimeout(ctx, time.Millisecond*100) defer cancel() @@ -211,31 +207,65 @@ func (p *pg) AddVerifyEmail(ctx context.Context, token, userId string) error { return nil } -func (p *pg) GetVerifyEmail(ctx context.Context, userId string) (string, error) { +func (p *pg) GetVerifyEmail(ctx context.Context, token string) (string, error) { childCtx, cancel := context.WithTimeout(ctx, time.Millisecond*100) defer cancel() - row := p.conn.QueryRow(childCtx, queries.GetVerifyUser, userId) + row := p.conn.QueryRow(childCtx, queries.GetVerifyUser, token) if row == nil { return "", fmt.Errorf("could not find verify token for userId") } - var token string - err := row.Scan(&token) + var userId string + err := row.Scan(&userId) if err != nil { return "", fmt.Errorf("error scanning verify token: %w", err) } - return token, nil + return userId, nil } -func (p *pg) DeleteVerifyEmail(ctx context.Context, userId string) error { +func (p *pg) DeleteVerifyEmail(ctx context.Context, token string) error { childCtx, cancel := context.WithTimeout(ctx, time.Millisecond*100) defer cancel() - _, err := p.conn.Exec(childCtx, queries.DeleteVerifyUser, userId) + _, err := p.conn.Exec(childCtx, queries.DeleteVerifyUser, token) if err != nil { return fmt.Errorf("error deleting verify token: %w", err) } + + return nil +} + +// DeleteUser - delete from user where username = $1; +func (p *pg) DeleteUser(ctx context.Context, identifier string) error { + childCtx, cancel := context.WithTimeout(ctx, time.Millisecond*100) + defer cancel() + + _, err := p.conn.Exec(childCtx, queries.DeleteUser, identifier) + if err != nil { + return fmt.Errorf("error deleting user: %s", identifier) + } return nil } + +//IsActive - if the user has logged in, isActive returns true +// this method is also useful for limiting access of malicious actors +func (p *pg) IsActive(ctx context.Context, identifier string) bool { + childCtx, cancel := context.WithTimeout(ctx, time.Millisecond*100) + defer cancel() + row := p.conn.QueryRow(childCtx, queries.GetUser, identifier) + return row != nil +} + +func (p *pg) UserExists(ctx context.Context, id string) bool { + childCtx, cancel := context.WithTimeout(ctx, time.Millisecond*100) + defer cancel() + + row, err := p.GetUserById(childCtx, id, false) + if err != nil || row == nil { + return false + } + + return true +} diff --git a/types/types.go b/types/types.go index f601de35..c275646b 100644 --- a/types/types.go +++ b/types/types.go @@ -91,6 +91,11 @@ type ( Namespace string `json:"namespace"` Tags []*ConfigV2 `json:"tags"` } + + Password struct { + OldPassword string `json:"old_password"` + NewPassword string `json:"new_password"` + } ) func (md Metadata) GetManifestByRef(ref string) (*Config, error) {