diff --git a/lib/build/customFramework.js b/lib/build/customFramework.js index f4794a458..9326c42ef 100644 --- a/lib/build/customFramework.js +++ b/lib/build/customFramework.js @@ -20,6 +20,7 @@ const cookieAndHeaders_1 = require("./recipe/session/cookieAndHeaders"); const jwt_1 = require("./recipe/session/jwt"); const jose_1 = require("jose"); const supertokens_1 = __importDefault(require("./supertokens")); +const accessToken_1 = require("./recipe/session/accessToken"); function createPreParsedRequest(request, getCookieFn = getCookieFromRequest) { /** * This helper function can take any `Request` type of object @@ -75,11 +76,6 @@ exports.getQueryFromRequest = getQueryFromRequest; function getAccessToken(request) { return getCookieFromRequest(request)["sAccessToken"]; } -async function verifyToken(token, jwks) { - // Verify the JWT using the remote JWK set and return the payload - const { payload } = await jose_1.jwtVerify(token, jwks); - return payload; -} function getHandleCall(res, stMiddleware) { return async function handleCall(req) { const baseResponse = new custom_1.CollectingResponse(); @@ -190,8 +186,9 @@ async function getSessionForSSR(request, jwks) { } try { if (accessToken) { - const decoded = await verifyToken(accessToken, jwksToUse); - return { accessTokenPayload: decoded, hasToken, error: undefined }; + const tokenInfo = jwt_1.parseJWTWithoutSignatureVerification(accessToken); + const decoded = await accessToken_1.getInfoFromAccessToken(tokenInfo, jwksToUse, false); + return { accessTokenPayload: decoded.userData, hasToken, error: undefined }; } return { accessTokenPayload: undefined, hasToken, error: undefined }; } catch (error) { diff --git a/lib/ts/customFramework.ts b/lib/ts/customFramework.ts index 6049d5f49..544685fb1 100644 --- a/lib/ts/customFramework.ts +++ b/lib/ts/customFramework.ts @@ -11,9 +11,10 @@ import SessionRecipe from "./recipe/session/recipe"; import { availableTokenTransferMethods } from "./recipe/session/constants"; import { getToken } from "./recipe/session/cookieAndHeaders"; import { parseJWTWithoutSignatureVerification } from "./recipe/session/jwt"; -import { jwtVerify, JWTPayload, createRemoteJWKSet } from "jose"; +import { JWTPayload, createRemoteJWKSet } from "jose"; import SuperTokens from "./supertokens"; import { HTTPMethod } from "./types"; +import { getInfoFromAccessToken } from "./recipe/session/accessToken"; export type GetCookieFn = (req: T) => Record; @@ -84,12 +85,6 @@ function getAccessToken(request: Request): string | undefined { return getCookieFromRequest(request)["sAccessToken"]; } -async function verifyToken(token: string, jwks: any): Promise { - // Verify the JWT using the remote JWK set and return the payload - const { payload } = await jwtVerify(token, jwks); - return payload; -} - export function getHandleCall(res: typeof Response, stMiddleware: any) { return async function handleCall(req: T) { const baseResponse = new CollectingResponse(); @@ -224,8 +219,9 @@ export async function getSessionForSSR( try { if (accessToken) { - const decoded = await verifyToken(accessToken, jwksToUse); - return { accessTokenPayload: decoded, hasToken, error: undefined }; + const tokenInfo = parseJWTWithoutSignatureVerification(accessToken); + const decoded = await getInfoFromAccessToken(tokenInfo, jwksToUse, false); + return { accessTokenPayload: decoded.userData, hasToken, error: undefined }; } return { accessTokenPayload: undefined, hasToken, error: undefined }; } catch (error) { diff --git a/test/customFramework.test.js b/test/customFramework.test.js index c010b0da4..aee2a5ab2 100644 --- a/test/customFramework.test.js +++ b/test/customFramework.test.js @@ -10,19 +10,56 @@ let SuperTokens = require("../lib/build/").default; const Session = require("../lib/build/recipe/session"); const EmailPassword = require("../lib/build/recipe/emailpassword"); const { PreParsedRequest } = require("../lib/build/framework/custom"); -const { printPath, setupST, startST, killAllST, cleanST, delay } = require("./utils"); -const { generateKeyPair, SignJWT } = require("jose"); +const { printPath, setupST, startST, killAllST, cleanST } = require("./utils"); +const { generateKeyPair, SignJWT, exportJWK, importJWK, decodeJwt } = require("jose"); // Helper function to create a JWKS async function createJWKS() { - const { privateKey } = await generateKeyPair("RS256"); - return privateKey; + // Generate an RSA key pair + const { privateKey, publicKey } = await generateKeyPair("RS256"); + + // Export the public key to JWK format + const jwk = await exportJWK(publicKey); + + // Construct the JWKS + const jwks = { + keys: [ + { + ...jwk, + alg: "RS256", + use: "sig", + kid: "test-key-id", + }, + ], + }; + + return { privateKey, jwks }; +} + +async function createJWTVerifyGetKey(jwks) { + // Find the JWK in the set based on `kid` + const jwk = jwks.keys.find((k) => k.kid === "test-key-id"); + + if (!jwk) { + throw new Error("Key with the specified kid not found in JWKS"); + } + + // Import the JWK as a CryptoKey suitable for RS256 verification + return await importJWK(jwk, "RS256"); } // Function to sign a JWT -async function signJWT(privateKey, payload, expiresIn = "2h") { +async function signJWT(privateKey, jwks, payload, expiresIn = "2h") { + // Find the corresponding public key in the JWKS to get the `kid` and `alg` + const publicJWK = jwks.keys.find((k) => k.kid === "test-key-id"); + + if (!publicJWK) { + throw new Error("Key with the specified kid not found in JWKS"); + } + + // Sign the JWT using the private key return new SignJWT(payload) - .setProtectedHeader({ alg: "RS256", kid: "test-key-id" }) + .setProtectedHeader({ alg: publicJWK.alg, kid: publicJWK.kid, version: "5", typ: "JWT" }) .setIssuedAt() .setExpirationTime(expiresIn) .sign(privateKey); @@ -84,8 +121,8 @@ describe(`createPreParsedRequest ${printPath("[test/customFramework.test.js]")}` describe(`handleAuthAPIRequest ${printPath("[test/customFramework.test.js]")}`, () => { let connectionURI; - let accessToken; - let privateKey; + let accessToken, accessTokenPayload; + let privateKey, jwks; before(async function () { process.env.user = undefined; @@ -122,7 +159,9 @@ describe(`handleAuthAPIRequest ${printPath("[test/customFramework.test.js]")}`, ], }); - privateKey = await createJWKS(); + const { privateKey: privateKeyGenerated, jwks: jwksGenerated } = await createJWKS(); + privateKey = privateKeyGenerated; + jwks = jwksGenerated; }); after(async function () { @@ -202,6 +241,7 @@ describe(`handleAuthAPIRequest ${printPath("[test/customFramework.test.js]")}`, ); accessToken = response.headers.get("st-access-token"); + accessTokenPayload = decodeJwt(accessToken); assert.ok(accessToken, "st-access-token header should be set"); assert.ok(response.headers.get("st-refresh-token"), "st-refresh-token header should be set"); @@ -309,12 +349,14 @@ describe(`handleAuthAPIRequest ${printPath("[test/customFramework.test.js]")}`, assert.strictEqual(await response.text(), "Not found", "Should return Not found"); }); - it("getSessionForSSR should return session for valid token", async () => { - // Create a valid JWT payload - const payload = { userId: "123", email: "john.doe@example.com" }; + // NOTE: For all the JWT related testing, we are using a different key because + // the default way of getting the key is by hitting the `/jwt/jwks` endpoint + // but that endpoint doesn't return anything for testing and thus we are testing + // with a custom key. + it("getSessionForSSR should return session for valid token", async () => { // Sign the JWT - const validToken = await signJWT(privateKey, payload); + const validToken = await signJWT(privateKey, jwks, accessTokenPayload); // Create a mock request containing the valid token as a cookie const mockRequest = new Request("https://example.com", { @@ -322,14 +364,13 @@ describe(`handleAuthAPIRequest ${printPath("[test/customFramework.test.js]")}`, }); // Call the getSessionForSSR function - const result = await getSessionForSSR(mockRequest, privateKey); + const result = await getSessionForSSR(mockRequest, await createJWTVerifyGetKey(jwks)); // Assertions assert.strictEqual(result.hasToken, true, "hasToken should be true for a valid token"); assert.ok(result.accessTokenPayload, "accessTokenPayload should be present for a valid token"); assert.strictEqual(result.error, undefined, "error should be undefined for a valid token"); - assert.strictEqual(result.accessTokenPayload.userId, "123", "User ID in payload should match"); - assert.strictEqual(result.accessTokenPayload.email, "john.doe@example.com", "Email in payload should match"); + assert.strictEqual(result.accessTokenPayload.sub, accessTokenPayload.sub, "User ID in payload should match"); }); it("should return undefined accessTokenPayload and hasToken as false when no token is present", async () => { @@ -350,11 +391,8 @@ describe(`handleAuthAPIRequest ${printPath("[test/customFramework.test.js]")}`, }); it("should handle an expired token gracefully", async () => { - // Create a payload for the token - const payload = { userId: "123", email: "john.doe@example.com" }; - // Sign the JWT with an expiration time in the past (e.g., 1 second ago) - const expiredToken = await signJWT(privateKey, payload, Math.floor(Date.now() / 1000) - 1); + const expiredToken = await signJWT(privateKey, jwks, accessTokenPayload, Math.floor(Date.now() / 1000) - 1); // Create a mock request containing the expired token as a cookie const mockRequest = new Request("https://example.com", { @@ -371,7 +409,11 @@ describe(`handleAuthAPIRequest ${printPath("[test/customFramework.test.js]")}`, undefined, "accessTokenPayload should be undefined for an expired token" ); - assert.strictEqual(result.error, undefined, "error should be undefined for an expired token"); + assert.strictEqual( + result.error.type, + "TRY_REFRESH_TOKEN", + "error should be TRY_REFRESH_TOKEN for an expired token" + ); }); it("should return an error for an invalid token", async () => {