diff --git a/.changesets/10444.md b/.changesets/10444.md new file mode 100644 index 000000000000..eb3f6aee2016 --- /dev/null +++ b/.changesets/10444.md @@ -0,0 +1,4 @@ +- feat(server-auth): Part 1/3: dbAuth middleware support (web side changes) (#10444) by @dac09 +Adds ability to `createMiddlewareAuth` in dbAuth client which: +1. Updates the dbAuth web client to speak to middleware instead of graphql +2. Implements fetching current user from middleware diff --git a/packages/auth-providers/dbAuth/web/src/__tests__/dbAuth.middleware.test.ts b/packages/auth-providers/dbAuth/web/src/__tests__/dbAuth.middleware.test.ts new file mode 100644 index 000000000000..199e3d8d9aa0 --- /dev/null +++ b/packages/auth-providers/dbAuth/web/src/__tests__/dbAuth.middleware.test.ts @@ -0,0 +1,150 @@ +import { act, renderHook } from '@testing-library/react' + +import type { CustomProviderHooks, DbAuthClientArgs } from '../dbAuth' +import { createDbAuthClient, createMiddlewareAuth } from '../dbAuth' + +import { fetchMock } from './dbAuth.test' + +const defaultArgs = { + fetchConfig: { + credentials: 'include' as const, + }, +} + +export function getMwDbAuth( + args: DbAuthClientArgs & CustomProviderHooks = defaultArgs, +) { + // We have to create a special createDbAuthClient with middleware = true + const dbAuthClient = createDbAuthClient({ ...args, middleware: true }) + const { useAuth, AuthProvider } = createMiddlewareAuth(dbAuthClient, { + useCurrentUser: args.useCurrentUser, + useHasRole: args.useHasRole, + }) + const { result } = renderHook(() => useAuth(), { + wrapper: AuthProvider, + }) + + return result +} + +// These tests are on top of the other tests in dbAuth.test.ts +// They test the middleware specific things about the dbAuth client + +describe('dbAuth web ~ cookie/middleware auth', () => { + it('will create a middleware version of the auth client', async () => { + const { current: dbAuthInstance } = getMwDbAuth() + + // Middleware auth clients should not return tokens + expect(await dbAuthInstance.getToken()).toBeNull() + + let currentUser + await act(async () => { + currentUser = await dbAuthInstance.getCurrentUser() + }) + + expect(globalThis.fetch).toHaveBeenCalledWith( + // Doesn't speak to graphql! + '/middleware/dbauth/currentUser', + expect.objectContaining({ + credentials: 'include', + method: 'GET', // in mw auth, we use GET for currentUser + }), + ) + + expect(currentUser).toEqual({ + id: 'middleware-user-555', + username: 'user@middleware.auth', + }) + }) + + it('can still override getCurrentUser', async () => { + const mockedCustomCurrentUser = jest.fn() + const { current: dbAuthInstance } = getMwDbAuth({ + useCurrentUser: mockedCustomCurrentUser, + }) + await act(async () => { + await dbAuthInstance.getCurrentUser() + }) + + expect(mockedCustomCurrentUser).toHaveBeenCalled() + }) + + it('allows you to override the middleware endpoint', async () => { + const auth = getMwDbAuth({ + dbAuthUrl: '/hello/handsome', + }).current + + await act(async () => await auth.forgotPassword('username')) + + expect(fetchMock).toHaveBeenCalledWith( + '/hello/handsome', + expect.any(Object), + ) + }) + + it('calls login at the middleware endpoint', async () => { + const auth = getMwDbAuth().current + + await act( + async () => + await auth.logIn({ username: 'username', password: 'password' }), + ) + + expect(globalThis.fetch).toHaveBeenCalledWith( + '/middleware/dbauth', + expect.any(Object), + ) + }) + + it('calls middleware endpoint for logout', async () => { + const auth = getMwDbAuth().current + await act(async () => { + await auth.logOut() + }) + + expect(globalThis.fetch).toHaveBeenCalledWith('/middleware/dbauth', { + body: '{"method":"logout"}', + credentials: 'include', + method: 'POST', + }) + }) + + it('calls reset password at the correct endpoint', async () => { + const auth = getMwDbAuth().current + + await act( + async () => + await auth.resetPassword({ + resetToken: 'reset-token', + password: 'password', + }), + ) + + expect(globalThis.fetch).toHaveBeenCalledWith( + '/middleware/dbauth', + expect.objectContaining({ + body: '{"resetToken":"reset-token","password":"password","method":"resetPassword"}', + }), + ) + }) + + it('passes through fetchOptions to signup calls', async () => { + const auth = getMwDbAuth().current + + await act( + async () => + await auth.signUp({ + username: 'username', + password: 'password', + }), + ) + + expect(globalThis.fetch).toHaveBeenCalledWith( + '/middleware/dbauth', + expect.objectContaining({ + method: 'POST', + body: '{"username":"username","password":"password","method":"signup"}', + }), + ) + }) +}) diff --git a/packages/auth-providers/dbAuth/web/src/__tests__/dbAuth.test.ts b/packages/auth-providers/dbAuth/web/src/__tests__/dbAuth.test.ts index 9f21253de96c..35a7cbfa6793 100644 --- a/packages/auth-providers/dbAuth/web/src/__tests__/dbAuth.test.ts +++ b/packages/auth-providers/dbAuth/web/src/__tests__/dbAuth.test.ts @@ -2,7 +2,7 @@ import { renderHook, act } from '@testing-library/react' import type { CurrentUser } from '@redwoodjs/auth' -import type { DbAuthClientArgs } from '../dbAuth' +import type { CustomProviderHooks, DbAuthClientArgs } from '../dbAuth' import { createDbAuthClient, createAuth } from '../dbAuth' globalThis.RWJS_API_URL = '/.redwood/functions' @@ -20,7 +20,7 @@ interface User { let loggedInUser: User | undefined -const fetchMock = jest.fn() +export const fetchMock = jest.fn() fetchMock.mockImplementation(async (url, options) => { const body = options?.body ? JSON.parse(options.body) : {} @@ -63,7 +63,26 @@ fetchMock.mockImplementation(async (url, options) => { return { ok: true, text: () => '', - json: () => ({ data: { redwood: { currentUser: loggedInUser } } }), + json: () => ({ + data: { + redwood: { + currentUser: loggedInUser, + }, + }, + }), + } + } + + if (url.includes('middleware/dbauth/currentUser')) { + return { + ok: true, + text: () => '', + json: () => ({ + currentUser: { + id: 'middleware-user-555', + username: 'user@middleware.auth', + }, + }), } } @@ -79,16 +98,11 @@ beforeEach(() => { loggedInUser = undefined }) -const defaultArgs: DbAuthClientArgs & { - useCurrentUser?: () => Promise - useHasRole?: ( - currentUser: CurrentUser | null, - ) => (rolesToCheck: string | string[]) => boolean -} = { +const defaultArgs: DbAuthClientArgs & CustomProviderHooks = { fetchConfig: { credentials: 'include' }, } -function getDbAuth(args = defaultArgs) { +export function getDbAuth(args = defaultArgs) { const dbAuthClient = createDbAuthClient(args) const { useAuth, AuthProvider } = createAuth(dbAuthClient, { useHasRole: args.useHasRole, @@ -101,7 +115,7 @@ function getDbAuth(args = defaultArgs) { return result } -describe('dbAuth', () => { +describe('dbAuth web client', () => { it('sets a default credentials value if not included', async () => { const authRef = getDbAuth({ fetchConfig: {} }) @@ -113,7 +127,7 @@ describe('dbAuth', () => { await authRef.current.getToken() }) - expect(globalThis.fetch).toBeCalledWith( + expect(globalThis.fetch).toHaveBeenCalledWith( `${globalThis.RWJS_API_URL}/auth?method=getToken`, { credentials: 'same-origin', @@ -126,7 +140,7 @@ describe('dbAuth', () => { await act(async () => await auth.forgotPassword('username')) - expect(fetchMock).toBeCalledWith( + expect(fetchMock).toHaveBeenCalledWith( `${globalThis.RWJS_API_URL}/auth`, expect.objectContaining({ credentials: 'include', @@ -143,7 +157,7 @@ describe('dbAuth', () => { expect(fetchMock).toHaveBeenCalledTimes(1) - expect(fetchMock).toBeCalledWith( + expect(fetchMock).toHaveBeenCalledWith( `${globalThis.RWJS_API_URL}/auth?method=getToken`, { credentials: 'include', @@ -152,21 +166,21 @@ describe('dbAuth', () => { }) it('passes through fetchOptions to login calls', async () => { - const auth = (await getDbAuth()).current + const auth = getDbAuth().current await act( async () => await auth.logIn({ username: 'username', password: 'password' }), ) - expect(globalThis.fetch).toBeCalledWith( + expect(globalThis.fetch).toHaveBeenCalledWith( `${globalThis.RWJS_API_URL}/auth`, expect.objectContaining({ credentials: 'include', }), ) - expect(globalThis.fetch).toBeCalledWith( + expect(globalThis.fetch).toHaveBeenCalledWith( `${globalThis.RWJS_API_URL}/auth`, expect.objectContaining({ credentials: 'include', @@ -180,7 +194,7 @@ describe('dbAuth', () => { await auth.logOut() }) - expect(globalThis.fetch).toBeCalledWith( + expect(globalThis.fetch).toHaveBeenCalledWith( `${globalThis.RWJS_API_URL}/auth`, expect.objectContaining({ credentials: 'include', @@ -198,7 +212,7 @@ describe('dbAuth', () => { }), ) - expect(globalThis.fetch).toBeCalledWith( + expect(globalThis.fetch).toHaveBeenCalledWith( `${globalThis.RWJS_API_URL}/auth`, expect.objectContaining({ credentials: 'include', @@ -216,7 +230,7 @@ describe('dbAuth', () => { }), ) - expect(globalThis.fetch).toBeCalledWith( + expect(globalThis.fetch).toHaveBeenCalledWith( `${globalThis.RWJS_API_URL}/auth`, expect.objectContaining({ credentials: 'include', @@ -228,7 +242,7 @@ describe('dbAuth', () => { const auth = getDbAuth().current await act(async () => await auth.validateResetToken('token')) - expect(globalThis.fetch).toBeCalledWith( + expect(globalThis.fetch).toHaveBeenCalledWith( `${globalThis.RWJS_API_URL}/auth`, expect.objectContaining({ credentials: 'include', @@ -241,7 +255,7 @@ describe('dbAuth', () => { await act(async () => await auth.forgotPassword('username')) - expect(fetchMock).toBeCalledWith( + expect(fetchMock).toHaveBeenCalledWith( '/.redwood/functions/dbauth', expect.objectContaining({ credentials: 'same-origin', @@ -326,7 +340,7 @@ describe('dbAuth', () => { expect(authRef.current.hasRole('user')).toBeFalsy() await act(async () => { - authRef.current.logIn({ + await authRef.current.logIn({ username: 'auth-test', password: 'ThereIsNoSpoon', }) diff --git a/packages/auth-providers/dbAuth/web/src/dbAuth.ts b/packages/auth-providers/dbAuth/web/src/dbAuth.ts index 510f2930c1de..dcb84e414083 100644 --- a/packages/auth-providers/dbAuth/web/src/dbAuth.ts +++ b/packages/auth-providers/dbAuth/web/src/dbAuth.ts @@ -1,6 +1,7 @@ import type { CurrentUser } from '@redwoodjs/auth' import { createAuthentication } from '@redwoodjs/auth' +import { getCurrentUserFromMiddleware } from './getCurrentUserFromMiddleware' import type { WebAuthnClientType } from './webAuthn' export interface LoginAttributes { @@ -17,6 +18,27 @@ export type SignupAttributes = Record & LoginAttributes const TOKEN_CACHE_TIME = 5000 +export type CustomProviderHooks = { + useCurrentUser?: () => Promise + useHasRole?: ( + currentUser: CurrentUser | null, + ) => (rolesToCheck: string | string[]) => boolean +} + +export function createMiddlewareAuth( + dbAuthClient: ReturnType, + customProviderHooks?: CustomProviderHooks, +) { + return createAuthentication(dbAuthClient, { + // @MARK This is key! 👇 + // Override the default getCurrentUser to fetch it from middleware instead + ...customProviderHooks, + useCurrentUser: + customProviderHooks?.useCurrentUser ?? + (() => getCurrentUserFromMiddleware(dbAuthClient.getAuthUrl())), + }) +} + export function createAuth( dbAuthClient: ReturnType, customProviderHooks?: { @@ -35,12 +57,14 @@ export interface DbAuthClientArgs { fetchConfig?: { credentials?: 'include' | 'same-origin' } + middleware?: boolean } export function createDbAuthClient({ webAuthn, dbAuthUrl, fetchConfig, + middleware = false, }: DbAuthClientArgs = {}) { const credentials = fetchConfig?.credentials || 'same-origin' webAuthn?.setAuthApiUrl(dbAuthUrl) @@ -49,8 +73,12 @@ export function createDbAuthClient({ let lastTokenCheckAt = new Date('1970-01-01T00:00:00') let cachedToken: string | null - const getApiDbAuthUrl = () => { - return dbAuthUrl || `${RWJS_API_URL}/auth` + const getDbAuthUrl = () => { + if (dbAuthUrl) { + return dbAuthUrl + } + + return middleware ? `/middleware/dbauth` : `${RWJS_API_URL}/auth` } const resetAndFetch = async (...params: Parameters) => { @@ -69,7 +97,7 @@ export function createDbAuthClient({ } const forgotPassword = async (username: string) => { - const response = await resetAndFetch(getApiDbAuthUrl(), { + const response = await resetAndFetch(getDbAuthUrl(), { credentials, method: 'POST', headers: { 'Content-Type': 'application/json' }, @@ -80,6 +108,10 @@ export function createDbAuthClient({ } const getToken = async () => { + // Middleware auth providers doesn't need a token + if (middleware) { + return null + } // Return the existing fetch promise, so that parallel calls // to getToken only cause a single fetch if (getTokenPromise) { @@ -87,7 +119,7 @@ export function createDbAuthClient({ } if (isTokenCacheExpired()) { - getTokenPromise = fetch(`${getApiDbAuthUrl()}?method=getToken`, { + getTokenPromise = fetch(`${getDbAuthUrl()}?method=getToken`, { credentials, }) .then((response) => response.text()) @@ -110,7 +142,7 @@ export function createDbAuthClient({ } const login = async ({ username, password }: LoginAttributes) => { - const response = await resetAndFetch(getApiDbAuthUrl(), { + const response = await resetAndFetch(getDbAuthUrl(), { credentials, method: 'POST', headers: { 'Content-Type': 'application/json' }, @@ -121,7 +153,7 @@ export function createDbAuthClient({ } const logout = async () => { - await resetAndFetch(getApiDbAuthUrl(), { + await resetAndFetch(getDbAuthUrl(), { credentials, method: 'POST', body: JSON.stringify({ method: 'logout' }), @@ -131,7 +163,7 @@ export function createDbAuthClient({ } const resetPassword = async (attributes: ResetPasswordAttributes) => { - const response = await resetAndFetch(getApiDbAuthUrl(), { + const response = await resetAndFetch(getDbAuthUrl(), { credentials, method: 'POST', headers: { 'Content-Type': 'application/json' }, @@ -142,7 +174,7 @@ export function createDbAuthClient({ } const signup = async (attributes: SignupAttributes) => { - const response = await resetAndFetch(getApiDbAuthUrl(), { + const response = await resetAndFetch(getDbAuthUrl(), { credentials, method: 'POST', headers: { 'Content-Type': 'application/json' }, @@ -153,7 +185,7 @@ export function createDbAuthClient({ } const validateResetToken = async (resetToken: string | null) => { - const response = await resetAndFetch(getApiDbAuthUrl(), { + const response = await resetAndFetch(getDbAuthUrl(), { credentials, method: 'POST', headers: { 'Content-Type': 'application/json' }, @@ -163,6 +195,20 @@ export function createDbAuthClient({ return response.json() } + /* + Cookie+Middleware based auth providers cannot retrieve current user from localStorage, etc. + It either has to retrieve it from serverAuthState (e.g. on first render) + or has to retrieve it from the middleware, where the cookie gets validated first. + + getUserMetadata is used in reauthenticate. So when you login in, the currentUser get's fetched + from the server, so that it will redirect + */ + const getUserMetadata = async () => { + return middleware + ? getCurrentUserFromMiddleware(getDbAuthUrl()) + : getToken() + } + return { type: 'dbAuth', client: webAuthn, @@ -170,9 +216,14 @@ export function createDbAuthClient({ logout, signup, getToken, - getUserMetadata: getToken, + getUserMetadata, forgotPassword, resetPassword, validateResetToken, + // 👇 New methods for middleware auth + // so we can get the dbAuthUrl in getCurrentUserFromMiddleware + getAuthUrl: getDbAuthUrl, + // This is so that we can skip fetching getCurrentUser in reauthenticate + useMiddlewareAuth: middleware, } } diff --git a/packages/auth-providers/dbAuth/web/src/getCurrentUserFromMiddleware.ts b/packages/auth-providers/dbAuth/web/src/getCurrentUserFromMiddleware.ts new file mode 100644 index 000000000000..3e48dee58971 --- /dev/null +++ b/packages/auth-providers/dbAuth/web/src/getCurrentUserFromMiddleware.ts @@ -0,0 +1,28 @@ +/* + This call allows the middleware to validate the cookie and return the current user. + */ +export const getCurrentUserFromMiddleware = async < + TCurrentUser = Record, +>( + authUrl: string, +): Promise => { + const response = await globalThis.fetch(`${authUrl}/currentUser`, { + method: 'GET', + credentials: 'include', + headers: { + 'content-type': 'application/json', + }, + }) + + if (response.ok) { + const { currentUser } = await response.json() + if (!currentUser) { + throw new Error('No current user found') + } + return currentUser + } else { + throw new Error( + `Could not fetch current user: ${response.statusText} (${response.status})`, + ) + } +} diff --git a/packages/auth/src/AuthImplementation.ts b/packages/auth/src/AuthImplementation.ts index 0bb3642416e8..b5264d2f5caa 100644 --- a/packages/auth/src/AuthImplementation.ts +++ b/packages/auth/src/AuthImplementation.ts @@ -52,4 +52,9 @@ export interface AuthImplementation< * Set "loading" to true while the auth provider is reauthenticating. */ loadWhileReauthenticating?: boolean + + // 👇 @TODO: Naming! Middleware-auth only + useMiddlewareAuth?: boolean + // This is the endpoint on the middleware we are going to hit for POST requests + getAuthUrl?: () => string } diff --git a/packages/auth/src/AuthProvider/useReauthenticate.ts b/packages/auth/src/AuthProvider/useReauthenticate.ts index 93d2d84b1d27..4765711504db 100644 --- a/packages/auth/src/AuthProvider/useReauthenticate.ts +++ b/packages/auth/src/AuthProvider/useReauthenticate.ts @@ -52,10 +52,17 @@ export const useReauthenticate = ( client: authImplementation.client, }) } else { - // This call here is a local check against the auth provider's client. - // e.g. if the auth sdk has logged you out, it'll throw an error - await getToken() - const currentUser = await getCurrentUser() + // Prevent a double fetch of the current user if the auth provider is using middleware + let currentUser + if (authImplementation.useMiddlewareAuth) { + // userMetadata === currentUser in middleware-auth + currentUser = userMetadata + } else { + // This call here is a local check against the auth provider's client. + // e.g. if the auth sdk has logged you out, it'll throw an error + await getToken() + currentUser = await getCurrentUser() + } setAuthProviderState((oldState) => ({ ...oldState,