Skip to content

Commit

Permalink
Merge pull request #41 from lstoll/lstoll-manage-credentials-not-users
Browse files Browse the repository at this point in the history
Credential addition via CLI
  • Loading branch information
lstoll authored Mar 9, 2024
2 parents d804e42 + d9d5d66 commit 63d37df
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 87 deletions.
50 changes: 26 additions & 24 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,14 +71,19 @@ type User struct {
// FullName to refer to the user as.
FullName string `json:"fullName"`

// Activated if the user is valid to be used.
Activated bool `json:"activated"`

// EnrollmentKey used for enrolling users first token.
// EnrollmentKey used for enrolling tokens for a user. It is removed when
// the token is enrolled.
EnrollmentKey string `json:"enrollmentKey"`

// Credentials is the user's WebAuthn credentials keyed by a user-provided identifier.
Credentials map[string]webauthn.Credential `json:"credentials"`
Credentials map[string]WebauthnCredential `json:"credentials"`
}

// WebauthnCredential wraps the webauthn.Credential with some more metadata.
type WebauthnCredential struct {
webauthn.Credential
Name string `json:"name"`
AddedAt time.Time `json:"addedAt"`
}

func (u User) WebAuthnID() []byte {
Expand All @@ -100,7 +105,7 @@ func (u User) WebAuthnIcon() string {
func (u User) WebAuthnCredentials() []webauthn.Credential {
var ret []webauthn.Credential
for _, v := range u.Credentials {
ret = append(ret, v)
ret = append(ret, v.Credential)
}
return ret
}
Expand Down Expand Up @@ -178,17 +183,6 @@ func (db *DB) GetUserByID(userID string) (User, error) {
return v, nil
}

func (db *DB) GetActivatedUserByID(id string) (User, error) {
user, err := db.GetUserByID(id)
if err != nil {
return User{}, err
}
if !user.Activated {
return User{}, ErrUserNotActivated
}
return user, nil
}

func (db *DB) CreateUser(user User) (User, error) {
if user.ID != "" {
return User{}, fmt.Errorf("user ID already assigned")
Expand All @@ -198,7 +192,7 @@ func (db *DB) CreateUser(user User) (User, error) {
}
user.ID = uuid.NewString()
user.EnrollmentKey = uuid.NewString()
user.Credentials = make(map[string]webauthn.Credential)
user.Credentials = make(map[string]WebauthnCredential)
err := db.f.Write(func(db *schema) error {
if _, ok := db.Users[user.ID]; ok {
panic("generated UUID already in use")
Expand Down Expand Up @@ -252,7 +246,9 @@ func (db *DB) UpdateUserCredential(userID string, cred webauthn.Credential) erro
}
for k, v := range user.Credentials {
if bytes.Equal(cred.ID, v.ID) {
db.Users[userID].Credentials[k] = cred
c := db.Users[userID].Credentials[k]
c.Credential = cred
db.Users[userID].Credentials[k] = c
return nil
}
}
Expand All @@ -261,12 +257,15 @@ func (db *DB) UpdateUserCredential(userID string, cred webauthn.Credential) erro
return err
}

func (db *DB) CreateUserCredential(userID, name string, cred webauthn.Credential) error {
func (db *DB) CreateUserCredential(userID, name string, cred WebauthnCredential) error {
err := db.f.Write(func(db *schema) error {
if _, ok := db.Users[userID]; !ok {
return ErrUserNotFound
}
db.Users[userID].Credentials[name] = cred
u := db.Users[userID]
u.EnrollmentKey = ""
u.Credentials[name] = cred
db.Users[userID] = u
return nil
})
return err
Expand Down Expand Up @@ -357,12 +356,15 @@ func migrateSQLToJSON(sqldb *storage, jsondb *DB) error {
ID: user.ID,
Email: user.Email,
FullName: user.FullName,
Activated: user.Activated,
EnrollmentKey: user.EnrollmentKey,
Credentials: make(map[string]webauthn.Credential),
Credentials: make(map[string]WebauthnCredential),
}
for _, cred := range user.Credentials {
newUser.Credentials[cred.Name] = cred.Credential
newUser.Credentials[cred.Name] = WebauthnCredential{
Credential: cred.Credential,
Name: cred.Name,
AddedAt: time.Now(),
}
}
if err := jsondb.createMigratedUser(newUser); err != nil {
return fmt.Errorf("json.createMigratedUser: %w", err)
Expand Down
20 changes: 3 additions & 17 deletions db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,24 +82,14 @@ func TestUsers(t *testing.T) {
t.Fatalf("want GetUserByID to return user not found error, got: %v", err)
}

if _, err := db.GetActivatedUserByID(newUser.ID); err != ErrUserNotActivated {
t.Fatalf("want GetActivatedUserByID to return user not activated error, got: %v", err)
}

newUser.Activated = true

if err := db.UpdateUser(newUser); err != nil {
t.Fatalf("UpdateUser: %v", err)
}

user2, err := db.CreateUser(User{Email: "me@example.com"})
if err != nil {
t.Fatalf("CreateUser: %v", err)
}
if err := db.UpdateUser(User{ID: user2.ID, Email: user.Email}); err != ErrUserEmailTaken {
t.Fatalf("UpdateUser did not reject non-unique email, got: %v", err)
}
err = db.CreateUserCredential(user2.ID, "1pass", webauthn.Credential{ID: []byte("ID")})
err = db.CreateUserCredential(user2.ID, "1pass", WebauthnCredential{Credential: webauthn.Credential{ID: []byte("ID")}})
if err != nil {
t.Fatalf("AddCredentialToUser: %v", err)
}
Expand All @@ -108,7 +98,7 @@ func TestUsers(t *testing.T) {
t.Fatal("UpdateUser: want missing user ID error")
}

err = db.CreateUserCredential(newUser.ID, "test", webauthn.Credential{ID: []byte("ID")})
err = db.CreateUserCredential(newUser.ID, "test", WebauthnCredential{Credential: webauthn.Credential{ID: []byte("ID")}})
if err != nil {
t.Fatalf("AddCredentialToUser: %v", err)
}
Expand All @@ -122,7 +112,7 @@ func TestUsers(t *testing.T) {
t.Errorf("want user to have 1 credentials, got %d", len(user.Credentials))
}

if err := db.UpdateUserCredential(user.ID, user.Credentials["test"]); err != nil {
if err := db.UpdateUserCredential(user.ID, user.Credentials["test"].Credential); err != nil {
t.Fatalf("UpdateUserCredential: %v", err)
}
if err := db.UpdateUserCredential(user.ID, webauthn.Credential{}); err != ErrCredentialNotFound {
Expand Down Expand Up @@ -196,10 +186,6 @@ func TestMigrateSQLToJSON(t *testing.T) {
t.Fatalf("migrateSQLToJSON: %v", err)
}

if _, err = jsondb.GetActivatedUserByID(user.ID); err != nil {
t.Fatalf("jsondb.GetActivatedUserByID: %v", err)
}

user2, err := jsondb.GetUserByID(user.ID)
if err != nil {
t.Fatalf("GetUserByID: %v", err)
Expand Down
8 changes: 1 addition & 7 deletions e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,13 +196,7 @@ func TestE2E(t *testing.T) {
case <-doneC:
}

// we need to mark the user as active for their credentials to be
// usable.
if err := activateUser(db, user.ID); err != nil {
t.Fatal(err)
}

user, err = db.GetActivatedUserByID(user.ID)
user, err = db.GetUserByID(user.ID)
if err != nil {
t.Fatal(err)
}
Expand Down
59 changes: 31 additions & 28 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (

"github.com/go-webauthn/webauthn/protocol"
"github.com/go-webauthn/webauthn/webauthn"
"github.com/google/uuid"
"github.com/lstoll/cookiesession"
"github.com/lstoll/oidc/core"
"github.com/lstoll/oidc/discovery"
Expand All @@ -36,8 +37,9 @@ func main() {
enroll := flag.Bool("enroll", false, "Enroll a user into the system.")
email := flag.String("email", "", "Email address for the user.")
fullname := flag.String("fullname", "", "Full name of the user.")
activate := flag.Bool("activate", false, "Activate an enrolled user.")
userID := flag.String("user-id", "", "ID of user to activate.")
addCredential := flag.Bool("add-credential", false, "Generate a new credential enrollment URL for a user")
userID := flag.String("user-id", "", "ID of user to add credential to.")
listCredential := flag.Bool("list-credentials", false, "List credentials for the user-id")

flag.Parse()

Expand Down Expand Up @@ -82,24 +84,45 @@ func main() {
}

user, err := db.CreateUser(User{
Email: *email,
FullName: *fullname,
Activated: false,
Email: *email,
FullName: *fullname,
})
if err != nil {
fatalf("create user: %v", err)
}
reloadDB(*addr)
fmt.Printf("Enroll at: %s\n", registrationURL(cfg.Issuer[0].URL, user))
return
} else if *activate {
} else if *addCredential {
if *userID == "" {
fatal("required flag missing: user-id")
}
if err := activateUser(db, *userID); err != nil {
fatalf("ativate user: %v", err)
user, err := db.GetUserByID(*userID)
if err != nil {
fatalf("get user %s: %w", userID, err)
}

user.EnrollmentKey = uuid.NewString()

if err := db.UpdateUser(user); err != nil {
fatalf("update user %s: %w", userID, err)
}

reloadDB(*addr)
fmt.Printf("Enroll at: %s\n", registrationURL(cfg.Issuer[0].URL, user))
return
} else if *listCredential {
if *userID == "" {
fatal("required flag missing: user-id")
}
user, err := db.GetUserByID(*userID)
if err != nil {
fatalf("get user %s: %w", userID, err)
}

for _, c := range user.Credentials {
fmt.Printf("credential: %s (added at %s)\n", c.Name, c.AddedAt)
}
return
}

Expand Down Expand Up @@ -293,26 +316,6 @@ func registrationURL(iss *url.URL, user User) *url.URL {
return u2
}

// activateUser marks the user as activated and deletes its enrollment key.
// Must be called after the user has completed the registration flow.
func activateUser(db *DB, userID string) error {
user, err := db.GetUserByID(userID)
if err != nil {
return fmt.Errorf("get user %s: %w", userID, err)
}

user.EnrollmentKey = ""
user.Activated = true

if err := db.UpdateUser(user); err != nil {
return fmt.Errorf("update user %s: %w", userID, err)
}

fmt.Println("Done.")

return nil
}

// reloadDB tells the server running on addr to reload its database from disk.
func reloadDB(addr string) {
resp, err := http.Get("http://" + addr + "/reloaddb")
Expand Down
4 changes: 2 additions & 2 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ func (s *oidcServer) finishLogin(rw http.ResponseWriter, req *http.Request) {

userID := string(parsedResponse.Response.UserHandle) // user handle is the webauthn.User#ID we registered with

user, err := s.db.GetActivatedUserByID(userID)
user, err := s.db.GetUserByID(userID)
if err != nil {
s.httpErr(rw, err)
return
Expand Down Expand Up @@ -247,7 +247,7 @@ func (s *oidcServer) loggedIn(rw http.ResponseWriter, req *http.Request) {
return
}

user, err := s.db.GetActivatedUserByID(authdUser.UserID)
user, err := s.db.GetUserByID(authdUser.UserID)
if err != nil {
s.httpErr(rw, err)
return
Expand Down
14 changes: 5 additions & 9 deletions webauthn_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"html/template"
"log/slog"
"net/http"
"time"

"github.com/go-webauthn/webauthn/protocol"
"github.com/go-webauthn/webauthn/webauthn"
Expand All @@ -26,18 +27,13 @@ type webauthnManager struct {
func (w *webauthnManager) AddHandlers(mux *http.ServeMux) {
mux.HandleFunc("/registration/begin", w.beginRegistration)
mux.HandleFunc("/registration/finish", w.finishRegistration)
mux.HandleFunc("/registration", w.registration)
mux.HandleFunc("GET /registration", w.registration)
}

// registration is a page used to add a new key. It should handle either a user
// in the session (from the logged in keys page), or a boostrap token and user
// id as query params for an inactive user.
func (w *webauthnManager) registration(rw http.ResponseWriter, req *http.Request) {
if req.Method != http.MethodGet {
http.Error(rw, "Invalid Method", http.StatusMethodNotAllowed)
return
}

// first, check the URL for a registration token and user id. If it exists,
// check if we have the user and if they are active/with a matching token,
// embed it in the page.
Expand All @@ -50,8 +46,8 @@ func (w *webauthnManager) registration(rw http.ResponseWriter, req *http.Request
w.httpErr(req.Context(), rw, fmt.Errorf("get user %s: %w", uid, err))
return
}
if user.Activated || subtle.ConstantTimeCompare([]byte(et), []byte(user.EnrollmentKey)) == 0 {
w.httpUnauth(rw, "invalid enrollment")
if user.EnrollmentKey == "" || subtle.ConstantTimeCompare([]byte(et), []byte(user.EnrollmentKey)) == 0 {
w.httpUnauth(rw, "either previous enrollment completed fine, or invalid enrollment")
return
}
sess := w.sessmgr.Get(req.Context())
Expand Down Expand Up @@ -145,7 +141,7 @@ func (w *webauthnManager) finishRegistration(rw http.ResponseWriter, req *http.R
return
}

if err := w.db.CreateUserCredential(user.ID, keyName, *credential); err != nil {
if err := w.db.CreateUserCredential(user.ID, keyName, WebauthnCredential{Credential: *credential, Name: keyName, AddedAt: time.Now()}); err != nil {
w.httpErr(req.Context(), rw, err)
return
}
Expand Down

0 comments on commit 63d37df

Please sign in to comment.