Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

runtime-local auth support #4843

Merged
merged 5 commits into from
May 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 23 additions & 3 deletions admin/database/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,11 @@ type DB interface {
UpdateDeviceAuthCode(ctx context.Context, id, userID string, state DeviceAuthCodeState) error
DeleteExpiredDeviceAuthCodes(ctx context.Context, retention time.Duration) error

FindAuthorizationCode(ctx context.Context, code string) (*AuthorizationCode, error)
InsertAuthorizationCode(ctx context.Context, code, userID, clientID, redirectURI, codeChallenge, codeChallengeMethod string, expiration time.Time) (*AuthorizationCode, error)
DeleteAuthorizationCode(ctx context.Context, code string) error
DeleteExpiredAuthorizationCodes(ctx context.Context, retention time.Duration) error

FindOrganizationRole(ctx context.Context, name string) (*OrganizationRole, error)
FindProjectRole(ctx context.Context, name string) (*ProjectRole, error)
ResolveOrganizationRolesForUser(ctx context.Context, userID, orgID string) ([]*OrganizationRole, error)
Expand Down Expand Up @@ -512,9 +517,10 @@ type AuthClient struct {

// Hard-coded auth client IDs (created in the migrations).
const (
AuthClientIDRillWeb = "12345678-0000-0000-0000-000000000001"
AuthClientIDRillCLI = "12345678-0000-0000-0000-000000000002"
AuthClientIDRillSupport = "12345678-0000-0000-0000-000000000003"
AuthClientIDRillWeb = "12345678-0000-0000-0000-000000000001"
AuthClientIDRillCLI = "12345678-0000-0000-0000-000000000002"
AuthClientIDRillSupport = "12345678-0000-0000-0000-000000000003"
AuthClientIDRillWebLocal = "12345678-0000-0000-0000-000000000004"
)

// DeviceAuthCodeState is an enum representing the approval state of a DeviceAuthCode
Expand All @@ -540,6 +546,20 @@ type DeviceAuthCode struct {
UpdatedOn time.Time `db:"updated_on"`
}

// AuthorizationCode represents an authorization code used for OAuth2 PKCE auth flow.
type AuthorizationCode struct {
ID string `db:"id"`
Code string `db:"code"`
UserID string `db:"user_id"`
ClientID string `db:"client_id"`
RedirectURI string `db:"redirect_uri"`
CodeChallenge string `db:"code_challenge"`
CodeChallengeMethod string `db:"code_challenge_method"`
Expiration time.Time `db:"expires_on"`
CreatedOn time.Time `db:"created_on"`
UpdatedOn time.Time `db:"updated_on"`
}

// Constants for known role names (created in migrations).
const (
OrganizationRoleNameAdmin = "admin"
Expand Down
20 changes: 20 additions & 0 deletions admin/database/postgres/migrations/0028.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
-- Hard-coded first-party auth clients
INSERT INTO auth_clients (id, display_name)
VALUES ('12345678-0000-0000-0000-000000000004', 'Rill Localhost');

-- Table for storing authorization codes for PKCE auth flow
CREATE TABLE authorization_codes (
id UUID DEFAULT uuid_generate_v4() PRIMARY KEY,
code TEXT NOT NULL,
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
client_id UUID NOT NULL REFERENCES auth_clients(id) ON DELETE CASCADE,
redirect_uri TEXT NOT NULL,
code_challenge TEXT NOT NULL,
code_challenge_method TEXT NOT NULL,
expires_on TIMESTAMP NOT NULL,
created_on TIMESTAMPTZ DEFAULT now() NOT NULL,
updated_on TIMESTAMPTZ DEFAULT now() NOT NULL
);

-- create index on code column
CREATE UNIQUE INDEX authorization_codes_code_idx ON authorization_codes(code);
30 changes: 30 additions & 0 deletions admin/database/postgres/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -1006,6 +1006,36 @@ func (c *connection) DeleteExpiredDeviceAuthCodes(ctx context.Context, retention
return parseErr("device auth code", err)
}

func (c *connection) FindAuthorizationCode(ctx context.Context, code string) (*database.AuthorizationCode, error) {
authCode := &database.AuthorizationCode{}
err := c.getDB(ctx).QueryRowxContext(ctx, "SELECT * FROM authorization_codes WHERE code = $1", code).StructScan(authCode)
if err != nil {
return nil, parseErr("authorization code", err)
}
return authCode, nil
}

func (c *connection) InsertAuthorizationCode(ctx context.Context, code, userID, clientID, redirectURI, codeChallenge, codeChallengeMethod string, expiration time.Time) (*database.AuthorizationCode, error) {
res := &database.AuthorizationCode{}
err := c.getDB(ctx).QueryRowxContext(ctx,
`INSERT INTO authorization_codes (code, user_id, client_id, redirect_uri, code_challenge, code_challenge_method, expires_on)
VALUES ($1, $2, $3, $4, $5, $6, $7) RETURNING *`, code, userID, clientID, redirectURI, codeChallenge, codeChallengeMethod, expiration).StructScan(res)
if err != nil {
return nil, parseErr("authorization code", err)
}
return res, nil
}

func (c *connection) DeleteAuthorizationCode(ctx context.Context, code string) error {
res, err := c.getDB(ctx).ExecContext(ctx, "DELETE FROM authorization_codes WHERE code=$1", code)
return checkDeleteRow("authorization code", res, err)
}

func (c *connection) DeleteExpiredAuthorizationCodes(ctx context.Context, retention time.Duration) error {
_, err := c.getDB(ctx).ExecContext(ctx, "DELETE FROM authorization_codes WHERE expires_on + $1 < now()", retention)
return parseErr("authorization code", err)
}

func (c *connection) FindOrganizationRole(ctx context.Context, name string) (*database.OrganizationRole, error) {
role := &database.OrganizationRole{}
err := c.getDB(ctx).QueryRowxContext(ctx, "SELECT * FROM org_roles WHERE lower(name)=lower($1)", name).StructScan(role)
Expand Down
14 changes: 14 additions & 0 deletions admin/pkg/oauth/utils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package oauth

const (
FormMediaType = "application/x-www-form-urlencoded"
JSONMediaType = "application/json"
)

// TokenResponse contains the information returned after fetching an access token from the OAuth server.
type TokenResponse struct {
AccessToken string `json:"access_token"`
ExpiresIn int64 `json:"expires_in,string"`
TokenType string `json:"token_type"`
UserID string `json:"user_id"`
}
25 changes: 5 additions & 20 deletions admin/server/auth/device_code.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ import (

"github.com/rilldata/rill/admin"
"github.com/rilldata/rill/admin/database"
"github.com/rilldata/rill/admin/pkg/oauth"
"github.com/rilldata/rill/admin/pkg/urlutil"
"github.com/rilldata/rill/cli/pkg/deviceauth"
)

const deviceCodeGrantType = "urn:ietf:params:oauth:grant-type:device_code"
Expand Down Expand Up @@ -175,23 +175,8 @@ func (a *Authenticator) handleUserCodeConfirmation(w http.ResponseWriter, r *htt
}
}

// getAccessToken verifies the device code and returns an access token if the request is approved
func (a *Authenticator) getAccessToken(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "expected a POST request", http.StatusBadRequest)
return
}
body, err := io.ReadAll(r.Body)
if err != nil {
internalServerError(w, fmt.Errorf("failed to read request body: %w", err))
return
}
bodyStr := string(body)
values, err := url.ParseQuery(bodyStr)
if err != nil {
internalServerError(w, fmt.Errorf("failed to parse query: %w", err))
return
}
// getAccessTokenForDeviceCode verifies the device code and returns an access token if the request is approved
func (a *Authenticator) getAccessTokenForDeviceCode(w http.ResponseWriter, r *http.Request, values url.Values) {
deviceCode := values.Get("device_code")
if deviceCode == "" {
http.Error(w, "device_code is required", http.StatusBadRequest)
Expand Down Expand Up @@ -253,10 +238,10 @@ func (a *Authenticator) getAccessToken(w http.ResponseWriter, r *http.Request) {
return
}

resp := deviceauth.OAuthTokenResponse{
resp := oauth.TokenResponse{
AccessToken: authToken.Token().String(),
TokenType: "Bearer",
ExpiresIn: time.UnixMilli(0).Unix(), // never expires
ExpiresIn: 0, // never expires
UserID: *authCode.UserID,
}
respBytes, err := json.Marshal(resp)
Expand Down
83 changes: 82 additions & 1 deletion admin/server/auth/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"encoding/base64"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strconv"
Expand Down Expand Up @@ -56,7 +57,8 @@ func (a *Authenticator) RegisterEndpoints(mux *http.ServeMux, limiter ratelimit.
observability.MuxHandle(inner, "/auth/logout", middleware.Check(checkLimit("/auth/logout"), http.HandlerFunc(a.authLogout)))
observability.MuxHandle(inner, "/auth/logout/callback", middleware.Check(checkLimit("/auth/logout/callback"), http.HandlerFunc(a.authLogoutCallback)))
observability.MuxHandle(inner, "/auth/oauth/device_authorization", middleware.Check(checkLimit("/auth/oauth/device_authorization"), http.HandlerFunc(a.handleDeviceCodeRequest)))
observability.MuxHandle(inner, "/auth/oauth/device", a.HTTPMiddleware(middleware.Check(checkLimit("/auth/oauth/device"), http.HandlerFunc(a.handleUserCodeConfirmation)))) // NOTE: Uses auth middleware
observability.MuxHandle(inner, "/auth/oauth/device", a.HTTPMiddleware(middleware.Check(checkLimit("/auth/oauth/device"), http.HandlerFunc(a.handleUserCodeConfirmation)))) // NOTE: Uses auth middleware
observability.MuxHandle(inner, "/auth/oauth/authorize", a.HTTPMiddleware(middleware.Check(checkLimit("/auth/oauth/authorize"), http.HandlerFunc(a.handleAuthorizeRequest)))) // NOTE: Uses auth middleware
observability.MuxHandle(inner, "/auth/oauth/token", middleware.Check(checkLimit("/auth/oauth/token"), http.HandlerFunc(a.getAccessToken)))
mux.Handle("/auth/", observability.Middleware("admin", a.logger, inner))
}
Expand Down Expand Up @@ -355,3 +357,82 @@ func (a *Authenticator) authLogoutCallback(w http.ResponseWriter, r *http.Reques
// Redirect to UI (usually)
http.Redirect(w, r, redirect, http.StatusTemporaryRedirect)
}

// handleAuthorizeRequest handles the incoming OAuth2 Authorization request, if he user is not logged redirect to login, currently only PKCE based authorization code flow is supported
func (a *Authenticator) handleAuthorizeRequest(w http.ResponseWriter, r *http.Request) {
claims := GetClaims(r.Context())
if claims == nil {
internalServerError(w, fmt.Errorf("did not find any claims, %w", errors.New("server error")))
return
}
if claims.OwnerType() == OwnerTypeAnon {
// not logged in, redirect to login
// after login redirect back to same path so encode the current URL as a redirect parameter
encodedURL := url.QueryEscape(r.URL.String())
http.Redirect(w, r, "/auth/login?redirect="+encodedURL, http.StatusTemporaryRedirect)
}
if claims.OwnerType() != OwnerTypeUser {
http.Error(w, "only users can be authorized", http.StatusBadRequest)
return
}
userID := claims.OwnerID()

// Extract necessary details from the query parameters
clientID := r.URL.Query().Get("client_id")
redirectURI := r.URL.Query().Get("redirect_uri")
responseType := r.URL.Query().Get("response_type")

if clientID == "" || redirectURI == "" || responseType == "" {
http.Error(w, "Missing required parameters - client_id or redirect_uri or response_type", http.StatusBadRequest)
return
}

codeChallenge := r.URL.Query().Get("code_challenge")
codeChallengeMethod := r.URL.Query().Get("code_challenge_method")

if codeChallenge != "" {
if codeChallengeMethod == "" {
http.Error(w, "Missing code challenge method", http.StatusBadRequest)
return
}
if responseType != "code" {
http.Error(w, "Invalid response type", http.StatusBadRequest)
return
}
a.handlePKCE(w, r, clientID, userID, codeChallenge, codeChallengeMethod, redirectURI)
} else {
http.Error(w, "only PKCE based authorization code flow is supported", http.StatusBadRequest)
return
}
}

// getAccessToken depending on the grant_type either verifies the device code and returns an access token if the request is approved or exchanges the authorization code for an access token
func (a *Authenticator) getAccessToken(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "expected a POST request", http.StatusBadRequest)
return
}
body, err := io.ReadAll(r.Body)
if err != nil {
internalServerError(w, fmt.Errorf("failed to read request body: %w", err))
return
}
bodyStr := string(body)
values, err := url.ParseQuery(bodyStr)
if err != nil {
internalServerError(w, fmt.Errorf("failed to parse query: %w", err))
return
}

grantType := values.Get("grant_type")
if !(grantType == deviceCodeGrantType || grantType == authorizationCodeGrantType) {
http.Error(w, "invalid grant_type", http.StatusBadRequest)
return
}

if grantType == deviceCodeGrantType {
a.getAccessTokenForDeviceCode(w, r, values)
} else {
a.getAccessTokenForAuthorizationCode(w, r, values)
}
}
Loading
Loading