Skip to content

Commit

Permalink
Use internal functions for jwt parsing
Browse files Browse the repository at this point in the history
  • Loading branch information
deepjyoti30-st committed Oct 4, 2024
1 parent 81017fd commit 1e75b00
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 37 deletions.
11 changes: 4 additions & 7 deletions lib/build/customFramework.js

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

14 changes: 5 additions & 9 deletions lib/ts/customFramework.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@ import SessionRecipe from "./recipe/session/recipe";
import { availableTokenTransferMethods } from "./recipe/session/constants";
import { getToken } from "./recipe/session/cookieAndHeaders";
import { parseJWTWithoutSignatureVerification } from "./recipe/session/jwt";
import { jwtVerify, JWTPayload, createRemoteJWKSet } from "jose";
import { JWTPayload, createRemoteJWKSet } from "jose";
import SuperTokens from "./supertokens";
import { HTTPMethod } from "./types";
import { getInfoFromAccessToken } from "./recipe/session/accessToken";

export type GetCookieFn<T extends ParsableRequest = Request> = (req: T) => Record<string, string>;

Expand Down Expand Up @@ -84,12 +85,6 @@ function getAccessToken(request: Request): string | undefined {
return getCookieFromRequest(request)["sAccessToken"];
}

async function verifyToken(token: string, jwks: any): Promise<JWTPayload> {
// Verify the JWT using the remote JWK set and return the payload
const { payload } = await jwtVerify(token, jwks);
return payload;
}

export function getHandleCall<T = Request>(res: typeof Response, stMiddleware: any) {
return async function handleCall(req: T) {
const baseResponse = new CollectingResponse();
Expand Down Expand Up @@ -224,8 +219,9 @@ export async function getSessionForSSR(

try {
if (accessToken) {
const decoded = await verifyToken(accessToken, jwksToUse);
return { accessTokenPayload: decoded, hasToken, error: undefined };
const tokenInfo = parseJWTWithoutSignatureVerification(accessToken);
const decoded = await getInfoFromAccessToken(tokenInfo, jwksToUse, false);
return { accessTokenPayload: decoded.userData, hasToken, error: undefined };
}
return { accessTokenPayload: undefined, hasToken, error: undefined };
} catch (error) {
Expand Down
84 changes: 63 additions & 21 deletions test/customFramework.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,56 @@ let SuperTokens = require("../lib/build/").default;
const Session = require("../lib/build/recipe/session");
const EmailPassword = require("../lib/build/recipe/emailpassword");
const { PreParsedRequest } = require("../lib/build/framework/custom");
const { printPath, setupST, startST, killAllST, cleanST, delay } = require("./utils");
const { generateKeyPair, SignJWT } = require("jose");
const { printPath, setupST, startST, killAllST, cleanST } = require("./utils");
const { generateKeyPair, SignJWT, exportJWK, importJWK, decodeJwt } = require("jose");

// Helper function to create a JWKS
async function createJWKS() {
const { privateKey } = await generateKeyPair("RS256");
return privateKey;
// Generate an RSA key pair
const { privateKey, publicKey } = await generateKeyPair("RS256");

// Export the public key to JWK format
const jwk = await exportJWK(publicKey);

// Construct the JWKS
const jwks = {
keys: [
{
...jwk,
alg: "RS256",
use: "sig",
kid: "test-key-id",
},
],
};

return { privateKey, jwks };
}

async function createJWTVerifyGetKey(jwks) {
// Find the JWK in the set based on `kid`
const jwk = jwks.keys.find((k) => k.kid === "test-key-id");

if (!jwk) {
throw new Error("Key with the specified kid not found in JWKS");
}

// Import the JWK as a CryptoKey suitable for RS256 verification
return await importJWK(jwk, "RS256");
}

// Function to sign a JWT
async function signJWT(privateKey, payload, expiresIn = "2h") {
async function signJWT(privateKey, jwks, payload, expiresIn = "2h") {
// Find the corresponding public key in the JWKS to get the `kid` and `alg`
const publicJWK = jwks.keys.find((k) => k.kid === "test-key-id");

if (!publicJWK) {
throw new Error("Key with the specified kid not found in JWKS");
}

// Sign the JWT using the private key
return new SignJWT(payload)
.setProtectedHeader({ alg: "RS256", kid: "test-key-id" })
.setProtectedHeader({ alg: publicJWK.alg, kid: publicJWK.kid, version: "5", typ: "JWT" })
.setIssuedAt()
.setExpirationTime(expiresIn)
.sign(privateKey);
Expand Down Expand Up @@ -84,8 +121,8 @@ describe(`createPreParsedRequest ${printPath("[test/customFramework.test.js]")}`

describe(`handleAuthAPIRequest ${printPath("[test/customFramework.test.js]")}`, () => {
let connectionURI;
let accessToken;
let privateKey;
let accessToken, accessTokenPayload;
let privateKey, jwks;

before(async function () {
process.env.user = undefined;
Expand Down Expand Up @@ -122,7 +159,9 @@ describe(`handleAuthAPIRequest ${printPath("[test/customFramework.test.js]")}`,
],
});

privateKey = await createJWKS();
const { privateKey: privateKeyGenerated, jwks: jwksGenerated } = await createJWKS();
privateKey = privateKeyGenerated;
jwks = jwksGenerated;
});

after(async function () {
Expand Down Expand Up @@ -202,6 +241,7 @@ describe(`handleAuthAPIRequest ${printPath("[test/customFramework.test.js]")}`,
);

accessToken = response.headers.get("st-access-token");
accessTokenPayload = decodeJwt(accessToken);

assert.ok(accessToken, "st-access-token header should be set");
assert.ok(response.headers.get("st-refresh-token"), "st-refresh-token header should be set");
Expand Down Expand Up @@ -309,27 +349,28 @@ describe(`handleAuthAPIRequest ${printPath("[test/customFramework.test.js]")}`,
assert.strictEqual(await response.text(), "Not found", "Should return Not found");
});

it("getSessionForSSR should return session for valid token", async () => {
// Create a valid JWT payload
const payload = { userId: "123", email: "john.doe@example.com" };
// NOTE: For all the JWT related testing, we are using a different key because
// the default way of getting the key is by hitting the `/jwt/jwks` endpoint
// but that endpoint doesn't return anything for testing and thus we are testing
// with a custom key.

it("getSessionForSSR should return session for valid token", async () => {
// Sign the JWT
const validToken = await signJWT(privateKey, payload);
const validToken = await signJWT(privateKey, jwks, accessTokenPayload);

// Create a mock request containing the valid token as a cookie
const mockRequest = new Request("https://example.com", {
headers: { Cookie: `sAccessToken=${validToken}` },
});

// Call the getSessionForSSR function
const result = await getSessionForSSR(mockRequest, privateKey);
const result = await getSessionForSSR(mockRequest, await createJWTVerifyGetKey(jwks));

// Assertions
assert.strictEqual(result.hasToken, true, "hasToken should be true for a valid token");
assert.ok(result.accessTokenPayload, "accessTokenPayload should be present for a valid token");
assert.strictEqual(result.error, undefined, "error should be undefined for a valid token");
assert.strictEqual(result.accessTokenPayload.userId, "123", "User ID in payload should match");
assert.strictEqual(result.accessTokenPayload.email, "john.doe@example.com", "Email in payload should match");
assert.strictEqual(result.accessTokenPayload.sub, accessTokenPayload.sub, "User ID in payload should match");
});

it("should return undefined accessTokenPayload and hasToken as false when no token is present", async () => {
Expand All @@ -350,11 +391,8 @@ describe(`handleAuthAPIRequest ${printPath("[test/customFramework.test.js]")}`,
});

it("should handle an expired token gracefully", async () => {
// Create a payload for the token
const payload = { userId: "123", email: "john.doe@example.com" };

// Sign the JWT with an expiration time in the past (e.g., 1 second ago)
const expiredToken = await signJWT(privateKey, payload, Math.floor(Date.now() / 1000) - 1);
const expiredToken = await signJWT(privateKey, jwks, accessTokenPayload, Math.floor(Date.now() / 1000) - 1);

// Create a mock request containing the expired token as a cookie
const mockRequest = new Request("https://example.com", {
Expand All @@ -371,7 +409,11 @@ describe(`handleAuthAPIRequest ${printPath("[test/customFramework.test.js]")}`,
undefined,
"accessTokenPayload should be undefined for an expired token"
);
assert.strictEqual(result.error, undefined, "error should be undefined for an expired token");
assert.strictEqual(
result.error.type,
"TRY_REFRESH_TOKEN",
"error should be TRY_REFRESH_TOKEN for an expired token"
);
});

it("should return an error for an invalid token", async () => {
Expand Down

0 comments on commit 1e75b00

Please sign in to comment.