Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

runtime-local auth support #4843

Merged
merged 5 commits into from
May 9, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions admin/database/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,7 @@ const (
AuthClientIDRillWeb = "12345678-0000-0000-0000-000000000001"
AuthClientIDRillCLI = "12345678-0000-0000-0000-000000000002"
AuthClientIDRillSupport = "12345678-0000-0000-0000-000000000003"
AuthClientIDRillLocal = "12345678-0000-0000-0000-000000000004"
pjain1 marked this conversation as resolved.
Show resolved Hide resolved
)

// DeviceAuthCodeState is an enum representing the approval state of a DeviceAuthCode
Expand Down
3 changes: 3 additions & 0 deletions admin/database/postgres/migrations/0028.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
-- Hard-coded first-party auth clients
INSERT INTO auth_clients (id, display_name)
VALUES ('12345678-0000-0000-0000-000000000004', 'Rill Local');
23 changes: 4 additions & 19 deletions admin/server/auth/device_code.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (
"github.com/rilldata/rill/admin"
"github.com/rilldata/rill/admin/database"
"github.com/rilldata/rill/admin/pkg/urlutil"
"github.com/rilldata/rill/cli/pkg/deviceauth"
"github.com/rilldata/rill/cli/pkg/auth"
)

const deviceCodeGrantType = "urn:ietf:params:oauth:grant-type:device_code"
Expand Down Expand Up @@ -175,23 +175,8 @@ func (a *Authenticator) handleUserCodeConfirmation(w http.ResponseWriter, r *htt
}
}

// getAccessToken verifies the device code and returns an access token if the request is approved
func (a *Authenticator) getAccessToken(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "expected a POST request", http.StatusBadRequest)
return
}
body, err := io.ReadAll(r.Body)
if err != nil {
internalServerError(w, fmt.Errorf("failed to read request body: %w", err))
return
}
bodyStr := string(body)
values, err := url.ParseQuery(bodyStr)
if err != nil {
internalServerError(w, fmt.Errorf("failed to parse query: %w", err))
return
}
// getAccessTokenForDeviceCode verifies the device code and returns an access token if the request is approved
func (a *Authenticator) getAccessTokenForDeviceCode(w http.ResponseWriter, r *http.Request, values url.Values) {
deviceCode := values.Get("device_code")
if deviceCode == "" {
http.Error(w, "device_code is required", http.StatusBadRequest)
Expand Down Expand Up @@ -253,7 +238,7 @@ func (a *Authenticator) getAccessToken(w http.ResponseWriter, r *http.Request) {
return
}

resp := deviceauth.OAuthTokenResponse{
resp := auth.OAuthTokenResponse{
AccessToken: authToken.Token().String(),
TokenType: "Bearer",
ExpiresIn: time.UnixMilli(0).Unix(), // never expires
Expand Down
85 changes: 84 additions & 1 deletion admin/server/auth/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"encoding/base64"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strconv"
Expand Down Expand Up @@ -56,7 +57,8 @@ func (a *Authenticator) RegisterEndpoints(mux *http.ServeMux, limiter ratelimit.
observability.MuxHandle(inner, "/auth/logout", middleware.Check(checkLimit("/auth/logout"), http.HandlerFunc(a.authLogout)))
observability.MuxHandle(inner, "/auth/logout/callback", middleware.Check(checkLimit("/auth/logout/callback"), http.HandlerFunc(a.authLogoutCallback)))
observability.MuxHandle(inner, "/auth/oauth/device_authorization", middleware.Check(checkLimit("/auth/oauth/device_authorization"), http.HandlerFunc(a.handleDeviceCodeRequest)))
observability.MuxHandle(inner, "/auth/oauth/device", a.HTTPMiddleware(middleware.Check(checkLimit("/auth/oauth/device"), http.HandlerFunc(a.handleUserCodeConfirmation)))) // NOTE: Uses auth middleware
observability.MuxHandle(inner, "/auth/oauth/device", a.HTTPMiddleware(middleware.Check(checkLimit("/auth/oauth/device"), http.HandlerFunc(a.handleUserCodeConfirmation)))) // NOTE: Uses auth middleware
observability.MuxHandle(inner, "/auth/oauth/authorize", a.HTTPMiddleware(middleware.Check(checkLimit("/auth/oauth/authorize"), http.HandlerFunc(a.handleAuthorizeRequest)))) // NOTE: Uses auth middleware
observability.MuxHandle(inner, "/auth/oauth/token", middleware.Check(checkLimit("/auth/oauth/token"), http.HandlerFunc(a.getAccessToken)))
mux.Handle("/auth/", observability.Middleware("admin", a.logger, inner))
}
Expand Down Expand Up @@ -355,3 +357,84 @@ func (a *Authenticator) authLogoutCallback(w http.ResponseWriter, r *http.Reques
// Redirect to UI (usually)
http.Redirect(w, r, redirect, http.StatusTemporaryRedirect)
}

// / handleAuthorizeRequest handles the incoming OAuth2 Authorization request
// and generates an authorization code while associating the code challenge.
func (a *Authenticator) handleAuthorizeRequest(w http.ResponseWriter, r *http.Request) {
claims := GetClaims(r.Context())
if claims == nil {
internalServerError(w, fmt.Errorf("did not find any claims, %w", errors.New("server error")))
return
}
if claims.OwnerType() == OwnerTypeAnon {
// not logged in, redirect to login
// TODO how to choose between login and signup?
pjain1 marked this conversation as resolved.
Show resolved Hide resolved
// after login redirect back to same path so encode the current URL as a redirect parameter
encodedURL := url.QueryEscape(r.URL.String())
http.Redirect(w, r, "/auth/login?redirect="+encodedURL, http.StatusTemporaryRedirect)
}
if claims.OwnerType() != OwnerTypeUser {
http.Error(w, "only users can be authorized", http.StatusBadRequest)
return
}
userID := claims.OwnerID()

// Extract necessary details from the query parameters
clientID := r.URL.Query().Get("client_id")
redirectURI := r.URL.Query().Get("redirect_uri")
responseType := r.URL.Query().Get("response_type")

if clientID == "" || redirectURI == "" || responseType == "" {
http.Error(w, "Missing required parameters - client_id or redirect_uri or response_type", http.StatusBadRequest)
return
}

codeChallenge := r.URL.Query().Get("code_challenge")
codeChallengeMethod := r.URL.Query().Get("code_challenge_method")

if codeChallenge != "" {
if codeChallengeMethod == "" {
http.Error(w, "Missing code challenge method", http.StatusBadRequest)
return
}
if responseType != "code" {
http.Error(w, "Invalid response type", http.StatusBadRequest)
return
}
handlePKCE(w, r, clientID, userID, codeChallenge, codeChallengeMethod, redirectURI)
} else {
http.Error(w, "only PKCE based authorization code flow is supported", http.StatusBadRequest)
return
}
}

// getAccessToken verifies the device code and returns an access token if the request is approved
func (a *Authenticator) getAccessToken(w http.ResponseWriter, r *http.Request) {
pjain1 marked this conversation as resolved.
Show resolved Hide resolved
if r.Method != http.MethodPost {
http.Error(w, "expected a POST request", http.StatusBadRequest)
return
}
body, err := io.ReadAll(r.Body)
if err != nil {
internalServerError(w, fmt.Errorf("failed to read request body: %w", err))
return
}
bodyStr := string(body)
values, err := url.ParseQuery(bodyStr)
if err != nil {
internalServerError(w, fmt.Errorf("failed to parse query: %w", err))
return
}

grantType := values.Get("grant_type")
if !(grantType == deviceCodeGrantType || grantType == authorizationCodeGrantType) {
http.Error(w, "invalid grant_type", http.StatusBadRequest)
return
}

if grantType == deviceCodeGrantType {
a.getAccessTokenForDeviceCode(w, r, values)
} else {
a.getAccessTokenForAuthorizationCode(w, r, values)
}
}
178 changes: 178 additions & 0 deletions admin/server/auth/pkce.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
package auth

import (
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"encoding/json"
"fmt"
"net/http"
"net/url"
"time"

"github.com/rilldata/rill/cli/pkg/auth"
pjain1 marked this conversation as resolved.
Show resolved Hide resolved
)

const authorizationCodeGrantType = "authorization_code"

// in-memory storage for authorization codes
// TODO use persistent storage to handle admin server restarts, but since codes expire in a minute, is it really necessary, user can just re-initiate login ?
var authCodeDB = make(map[string]AuthorizationCode)
pjain1 marked this conversation as resolved.
Show resolved Hide resolved

// AuthorizationCode represents the stored information for an authorization code
type AuthorizationCode struct {
ClientID string
RedirectURI string
UserID string
Expiration time.Time
CodeChallenge string
CodeChallengeMethod string
}

func handlePKCE(w http.ResponseWriter, r *http.Request, clientID, userID, codeChallenge, codeChallengeMethod, redirectURI string) {
// Generate a unique authorization code
code, err := generateRandomString(16) // 16 bytes, resulting in a 32-character hex string
if err != nil {
http.Error(w, "Failed to generate authorization code", http.StatusInternalServerError)
return
}

// Set the expiration date for the authorization code (e.g., a minute from now)
// Note from https://www.oauth.com/oauth2-servers/authorization/the-authorization-response/
// The authorization code must expire shortly after it is issued. The OAuth 2.0 spec recommends a maximum lifetime of 10 minutes, but in practice, most services set the expiration much shorter, around 30-60 seconds.
expiration := time.Now().Add(1 * time.Minute)

authCodeDB[code] = AuthorizationCode{
ClientID: clientID,
RedirectURI: redirectURI,
UserID: userID,
Expiration: expiration,
CodeChallenge: codeChallenge,
CodeChallengeMethod: codeChallengeMethod,
}

// Build the redirection URI with the authorization code as per OAuth2 spec
redirectWithCode := fmt.Sprintf("%s?code=%s&state=%s", redirectURI, code, r.URL.Query().Get("state"))
pjain1 marked this conversation as resolved.
Show resolved Hide resolved

// Redirect the user agent to the redirect URI with the authorization code
http.Redirect(w, r, redirectWithCode, http.StatusFound)
}

// getAccessTokenForAuthorizationCode exchanges an authorization code for an access token
func (a *Authenticator) getAccessTokenForAuthorizationCode(w http.ResponseWriter, r *http.Request, values url.Values) {
// Extract the authorization code
code := values.Get("code")
if code == "" {
http.Error(w, "authorization code is required", http.StatusBadRequest)
return
}

// Extract the client ID
clientID := values.Get("client_id")
if clientID == "" {
http.Error(w, "client ID is required", http.StatusBadRequest)
return
}

// Extract the redirect URI
redirectURI := values.Get("redirect_uri")
if redirectURI == "" {
http.Error(w, "redirect URI is required", http.StatusBadRequest)
return
}

// Extract the code verifier
codeVerifier := values.Get("code_verifier")
if codeVerifier == "" {
http.Error(w, "code verifier is required", http.StatusBadRequest)
return
}

// Validate the authorization code
authCode, ok := authCodeDB[code]
if !ok {
http.Error(w, "invalid authorization code, please re-initiate login", http.StatusBadRequest)
return
}

userID := authCode.UserID
if userID == "" {
http.Error(w, "no user found for authorization code", http.StatusInternalServerError)
return
}

// remove the authorization code from the database to prevent reuse
delete(authCodeDB, code)

// Check if the client ID matches the stored client ID
if authCode.ClientID != clientID {
http.Error(w, "invalid client ID", http.StatusBadRequest)
return
}

// Check if the redirect URI matches the stored redirect URI
if authCode.RedirectURI != redirectURI {
http.Error(w, "invalid redirect URI", http.StatusBadRequest)
return
}

// Check if the authorization code has expired
if time.Now().After(authCode.Expiration) {
http.Error(w, "authorization code has expired", http.StatusBadRequest)
return
}

// Verify the code verifier against the stored code challenge
if !verifyCodeChallenge(codeVerifier, authCode.CodeChallenge, authCode.CodeChallengeMethod) {
http.Error(w, "invalid code verifier", http.StatusBadRequest)
return
}

// Issue an access token
authToken, err := a.admin.IssueUserAuthToken(r.Context(), userID, authCode.ClientID, "", nil, nil)
if err != nil {
internalServerError(w, fmt.Errorf("failed to issue access token, %w", err))
return
}
pjain1 marked this conversation as resolved.
Show resolved Hide resolved

resp := auth.OAuthTokenResponse{
AccessToken: authToken.Token().String(),
TokenType: "Bearer",
ExpiresIn: time.UnixMilli(0).Unix(), // never expires
pjain1 marked this conversation as resolved.
Show resolved Hide resolved
UserID: userID,
}
respBytes, err := json.Marshal(resp)
if err != nil {
internalServerError(w, fmt.Errorf("failed to marshal response, %w", err))
return
}
w.Header().Set("Content-Type", "application/json")
_, err = w.Write(respBytes)
if err != nil {
internalServerError(w, fmt.Errorf("failed to write response, %w", err))
return
}
}

// verifyCodeChallenge validates the code verifier with the stored code challenge
func verifyCodeChallenge(verifier, challenge, method string) bool {
switch method {
case "S256":
s256 := sha256.Sum256([]byte(verifier))
computedChallenge := base64.RawURLEncoding.EncodeToString(s256[:])
return computedChallenge == challenge
default:
return false
}
}

// Generates a random string for use as the authorization code
func generateRandomString(n int) (string, error) {
b := make([]byte, n)
_, err := rand.Read(b)
if err != nil {
return "", err
}
return hex.EncodeToString(b), nil
}
14 changes: 14 additions & 0 deletions cli/pkg/auth/utils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package auth

const (
FormMediaType = "application/x-www-form-urlencoded"
JSONMediaType = "application/json"
)

// OAuthTokenResponse contains the information returned after fetching an access token from the OAuth server.
type OAuthTokenResponse struct {
pjain1 marked this conversation as resolved.
Show resolved Hide resolved
AccessToken string `json:"access_token"`
ExpiresIn int64 `json:"expires_in,string"`
TokenType string `json:"token_type"`
UserID string `json:"user_id"`
}
Loading
Loading