diff --git a/src/main/java/io/supertokens/session/Session.java b/src/main/java/io/supertokens/session/Session.java index dcd0c6ae7..8ca76e2fb 100644 --- a/src/main/java/io/supertokens/session/Session.java +++ b/src/main/java/io/supertokens/session/Session.java @@ -25,10 +25,11 @@ import io.supertokens.exceptions.TryRefreshTokenException; import io.supertokens.exceptions.UnauthorisedException; import io.supertokens.pluginInterface.STORAGE_TYPE; +import io.supertokens.pluginInterface.Storage; import io.supertokens.pluginInterface.exceptions.StorageQueryException; import io.supertokens.pluginInterface.exceptions.StorageTransactionLogicException; import io.supertokens.pluginInterface.multitenancy.AppIdentifier; -import io.supertokens.pluginInterface.multitenancy.TenantIdentifier; +import io.supertokens.pluginInterface.multitenancy.TenantIdentifierWithStorage; import io.supertokens.pluginInterface.multitenancy.exceptions.TenantOrAppNotFoundException; import io.supertokens.pluginInterface.session.noSqlStorage.SessionNoSQLStorage_1; import io.supertokens.pluginInterface.session.sqlStorage.SessionSQLStorage; @@ -58,7 +59,8 @@ public class Session { @TestOnly - public static SessionInformationHolder createNewSession(TenantIdentifier tenantIdentifier, Main main, + public static SessionInformationHolder createNewSession(TenantIdentifierWithStorage tenantIdentifierWithStorage, + Main main, @Nonnull String userId, @Nonnull JsonObject userDataInJWT, @Nonnull JsonObject userDataInDatabase) @@ -66,7 +68,7 @@ public static SessionInformationHolder createNewSession(TenantIdentifier tenantI InvalidKeySpecException, StorageTransactionLogicException, SignatureException, IllegalBlockSizeException, BadPaddingException, InvalidAlgorithmParameterException, NoSuchPaddingException { try { - return createNewSession(tenantIdentifier, main, userId, userDataInJWT, userDataInDatabase, + return createNewSession(tenantIdentifierWithStorage, main, userId, userDataInJWT, userDataInDatabase, false); } catch (TenantOrAppNotFoundException e) { throw new IllegalStateException(e); @@ -81,8 +83,10 @@ public static SessionInformationHolder createNewSession(Main main, throws NoSuchAlgorithmException, UnsupportedEncodingException, StorageQueryException, InvalidKeyException, InvalidKeySpecException, StorageTransactionLogicException, SignatureException, IllegalBlockSizeException, BadPaddingException, InvalidAlgorithmParameterException, NoSuchPaddingException { - return createNewSession(new TenantIdentifier(null, null, null), main, userId, userDataInJWT, - userDataInDatabase); + Storage storage = StorageLayer.getStorage(main); + return createNewSession( + new TenantIdentifierWithStorage(null, null, null, storage), main, + userId, userDataInJWT, userDataInDatabase); } @TestOnly @@ -94,14 +98,16 @@ public static SessionInformationHolder createNewSession(Main main, @Nonnull Stri InvalidKeySpecException, StorageTransactionLogicException, SignatureException, IllegalBlockSizeException, BadPaddingException, InvalidAlgorithmParameterException, NoSuchPaddingException { try { - return createNewSession(new TenantIdentifier(null, null, null), main, userId, userDataInJWT, - userDataInDatabase, enableAntiCsrf); + Storage storage = StorageLayer.getStorage(main); + return createNewSession( + new TenantIdentifierWithStorage(null, null, null, storage), main, + userId, userDataInJWT, userDataInDatabase, enableAntiCsrf); } catch (TenantOrAppNotFoundException e) { throw new IllegalStateException(e); } } - public static SessionInformationHolder createNewSession(TenantIdentifier tenantIdentifier, Main main, + public static SessionInformationHolder createNewSession(TenantIdentifierWithStorage tenantIdentifierWithStorage, Main main, @Nonnull String userId, @Nonnull JsonObject userDataInJWT, @Nonnull JsonObject userDataInDatabase, @@ -112,17 +118,17 @@ public static SessionInformationHolder createNewSession(TenantIdentifier tenantI TenantOrAppNotFoundException { String sessionHandle = UUID.randomUUID().toString(); String antiCsrfToken = enableAntiCsrf ? UUID.randomUUID().toString() : null; - final TokenInfo refreshToken = RefreshToken.createNewRefreshToken(tenantIdentifier, main, + final TokenInfo refreshToken = RefreshToken.createNewRefreshToken(tenantIdentifierWithStorage, main, sessionHandle, userId, null, antiCsrfToken); - TokenInfo accessToken = AccessToken.createNewAccessToken(tenantIdentifier, main, + TokenInfo accessToken = AccessToken.createNewAccessToken(tenantIdentifierWithStorage, main, sessionHandle, userId, Utils.hashSHA256(refreshToken.token), null, userDataInJWT, antiCsrfToken, System.currentTimeMillis(), null); - StorageLayer.getSessionStorage(tenantIdentifier, main).createNewSession(tenantIdentifier, sessionHandle, userId, + tenantIdentifierWithStorage.getSessionStorage().createNewSession(tenantIdentifierWithStorage, sessionHandle, userId, Utils.hashSHA256(Utils.hashSHA256(refreshToken.token)), userDataInDatabase, refreshToken.expiry, userDataInJWT, refreshToken.createdTime); // TODO: add lmrt to database @@ -172,16 +178,14 @@ public static SessionInformationHolder regenerateToken(AppIdentifier appIdentifi // We assume the token has already been verified at this point. It may be expired or JWT signing key may have // changed for it... AccessTokenInfo accessToken = AccessToken.getInfoFromAccessTokenWithoutVerifying(appIdentifier, token); - TenantIdentifier tenantIdentifier = accessToken.tenantIdentifier; - if (!tenantIdentifier.toAppIdentifier().equals(appIdentifier)) { - throw new UnauthorisedException("Access token is from an incorrect app"); - } + TenantIdentifierWithStorage tenantIdentifierWithStorage = accessToken.tenantIdentifier.withStorage( + StorageLayer.getStorage(accessToken.tenantIdentifier, main)); JsonObject newJWTUserPayload = userDataInJWT == null ? - getSession(tenantIdentifier, main, accessToken.sessionHandle).userDataInJWT + getSession(tenantIdentifierWithStorage, accessToken.sessionHandle).userDataInJWT : userDataInJWT; long lmrt = System.currentTimeMillis(); - updateSession(tenantIdentifier, main, accessToken.sessionHandle, null, newJWTUserPayload, lmrt); + updateSession(tenantIdentifierWithStorage, accessToken.sessionHandle, null, newJWTUserPayload, lmrt); // if the above succeeds but the below fails, it's OK since the client will get server error and will try // again. In this case, the JWT data will be updated again since the API will get the old JWT. In case there @@ -194,7 +198,7 @@ public static SessionInformationHolder regenerateToken(AppIdentifier appIdentifi null); } - TokenInfo newAccessToken = AccessToken.createNewAccessToken(tenantIdentifier, main, + TokenInfo newAccessToken = AccessToken.createNewAccessToken(tenantIdentifierWithStorage, main, accessToken.sessionHandle, accessToken.userId, accessToken.refreshTokenHash1, accessToken.parentRefreshTokenHash1, newJWTUserPayload, accessToken.antiCsrfToken, lmrt, accessToken.expiryTime); @@ -229,7 +233,8 @@ public static SessionInformationHolder getSession(AppIdentifier appIdentifier, M AccessTokenInfo accessToken = AccessToken.getInfoFromAccessToken(appIdentifier, main, token, doAntiCsrfCheck && enableAntiCsrf); - TenantIdentifier tenantIdentifier = accessToken.tenantIdentifier; + TenantIdentifierWithStorage tenantIdentifierWithStorage = accessToken.tenantIdentifier.withStorage( + StorageLayer.getStorage(accessToken.tenantIdentifier, main)); if (enableAntiCsrf && doAntiCsrfCheck && (antiCsrfToken == null || !antiCsrfToken.equals(accessToken.antiCsrfToken))) { @@ -237,9 +242,9 @@ public static SessionInformationHolder getSession(AppIdentifier appIdentifier, M } io.supertokens.pluginInterface.session.SessionInfo sessionInfoForBlacklisting = null; - if (Config.getConfig(tenantIdentifier, main).getAccessTokenBlacklisting()) { - sessionInfoForBlacklisting = StorageLayer.getSessionStorage(tenantIdentifier, main) - .getSession(tenantIdentifier, accessToken.sessionHandle); + if (Config.getConfig(tenantIdentifierWithStorage, main).getAccessTokenBlacklisting()) { + sessionInfoForBlacklisting = tenantIdentifierWithStorage.getSessionStorage() + .getSession(tenantIdentifierWithStorage, accessToken.sessionHandle); if (sessionInfoForBlacklisting == null) { throw new UnauthorisedException("Either the session has ended or has been blacklisted"); } @@ -257,15 +262,15 @@ public static SessionInformationHolder getSession(AppIdentifier appIdentifier, M ProcessState.getInstance(main).addState(ProcessState.PROCESS_STATE.GET_SESSION_NEW_TOKENS, null); - if (StorageLayer.getSessionStorage(tenantIdentifier, main).getType() == STORAGE_TYPE.SQL) { - SessionSQLStorage storage = (SessionSQLStorage) StorageLayer.getSessionStorage(tenantIdentifier, main); + if (tenantIdentifierWithStorage.getSessionStorage().getType() == STORAGE_TYPE.SQL) { + SessionSQLStorage storage = (SessionSQLStorage) tenantIdentifierWithStorage.getSessionStorage(); try { - CoreConfig config = Config.getConfig(tenantIdentifier, main); + CoreConfig config = Config.getConfig(tenantIdentifierWithStorage, main); return storage.startTransaction(con -> { try { io.supertokens.pluginInterface.session.SessionInfo sessionInfo = storage - .getSessionInfo_Transaction(tenantIdentifier, con, accessToken.sessionHandle); + .getSessionInfo_Transaction(tenantIdentifierWithStorage, con, accessToken.sessionHandle); if (sessionInfo == null) { storage.commitTransaction(con); @@ -278,7 +283,7 @@ public static SessionInformationHolder getSession(AppIdentifier appIdentifier, M || sessionInfo.refreshTokenHash2.equals(Utils.hashSHA256(accessToken.refreshTokenHash1)) || JWTPayloadNeedsUpdating) { if (promote) { - storage.updateSessionInfo_Transaction(tenantIdentifier, con, accessToken.sessionHandle, + storage.updateSessionInfo_Transaction(tenantIdentifierWithStorage, con, accessToken.sessionHandle, Utils.hashSHA256(accessToken.refreshTokenHash1), System.currentTimeMillis() + config.getRefreshTokenValidity()); @@ -287,14 +292,14 @@ public static SessionInformationHolder getSession(AppIdentifier appIdentifier, M TokenInfo newAccessToken; if (AccessToken.getAccessTokenVersion(accessToken) == AccessToken.VERSION.V1) { - newAccessToken = AccessToken.createNewAccessTokenV1(tenantIdentifier, + newAccessToken = AccessToken.createNewAccessTokenV1(tenantIdentifierWithStorage, main, accessToken.sessionHandle, accessToken.userId, accessToken.refreshTokenHash1, null, sessionInfo.userDataInJWT, accessToken.antiCsrfToken); } else { assert accessToken.lmrt != null; - newAccessToken = AccessToken.createNewAccessToken(tenantIdentifier, + newAccessToken = AccessToken.createNewAccessToken(tenantIdentifierWithStorage, main, accessToken.sessionHandle, accessToken.userId, accessToken.refreshTokenHash1, null, @@ -328,10 +333,9 @@ public static SessionInformationHolder getSession(AppIdentifier appIdentifier, M } throw e; } - } else if (StorageLayer.getSessionStorage(tenantIdentifier, main).getType() == + } else if (tenantIdentifierWithStorage.getSessionStorage().getType() == STORAGE_TYPE.NOSQL_1) { - SessionNoSQLStorage_1 storage = (SessionNoSQLStorage_1) StorageLayer.getSessionStorage(tenantIdentifier, - main); + SessionNoSQLStorage_1 storage = (SessionNoSQLStorage_1) tenantIdentifierWithStorage.getSessionStorage(); while (true) { try { @@ -349,7 +353,7 @@ public static SessionInformationHolder getSession(AppIdentifier appIdentifier, M if (promote) { boolean success = storage.updateSessionInfo_Transaction(accessToken.sessionHandle, Utils.hashSHA256(accessToken.refreshTokenHash1), - System.currentTimeMillis() + Config.getConfig(tenantIdentifier, main) + System.currentTimeMillis() + Config.getConfig(tenantIdentifierWithStorage, main) .getRefreshTokenValidity(), sessionInfo.lastUpdatedSign); if (!success) { @@ -359,14 +363,14 @@ public static SessionInformationHolder getSession(AppIdentifier appIdentifier, M TokenInfo newAccessToken; if (AccessToken.getAccessTokenVersion(accessToken) == AccessToken.VERSION.V1) { - newAccessToken = AccessToken.createNewAccessTokenV1(tenantIdentifier, + newAccessToken = AccessToken.createNewAccessTokenV1(tenantIdentifierWithStorage, main, accessToken.sessionHandle, accessToken.userId, accessToken.refreshTokenHash1, null, sessionInfo.userDataInJWT, accessToken.antiCsrfToken); } else { assert accessToken.lmrt != null; - newAccessToken = AccessToken.createNewAccessToken(tenantIdentifier, main, + newAccessToken = AccessToken.createNewAccessToken(tenantIdentifierWithStorage, main, accessToken.sessionHandle, accessToken.userId, accessToken.refreshTokenHash1, null, sessionInfo.userDataInJWT, accessToken.antiCsrfToken, accessToken.lmrt, null); @@ -423,11 +427,14 @@ public static SessionInformationHolder refreshSession(AppIdentifier appIdentifie } } - return refreshSessionHelper(refreshTokenInfo.tenantIdentifier, main, refreshToken, refreshTokenInfo, + return refreshSessionHelper( + refreshTokenInfo.tenantIdentifier.withStorage( + StorageLayer.getStorage(refreshTokenInfo.tenantIdentifier, main)), + main, refreshToken, refreshTokenInfo, enableAntiCsrf); } - private static SessionInformationHolder refreshSessionHelper(TenantIdentifier tenantIdentifier, Main main, + private static SessionInformationHolder refreshSessionHelper(TenantIdentifierWithStorage tenantIdentifierWithStorage, Main main, String refreshToken, RefreshToken.RefreshTokenInfo refreshTokenInfo, boolean enableAntiCsrf) @@ -439,15 +446,15 @@ private static SessionInformationHolder refreshSessionHelper(TenantIdentifier te ////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////// - if (StorageLayer.getSessionStorage(tenantIdentifier, main).getType() == STORAGE_TYPE.SQL) { - SessionSQLStorage storage = (SessionSQLStorage) StorageLayer.getSessionStorage(tenantIdentifier, main); + if (tenantIdentifierWithStorage.getSessionStorage().getType() == STORAGE_TYPE.SQL) { + SessionSQLStorage storage = (SessionSQLStorage) tenantIdentifierWithStorage.getSessionStorage(); try { - CoreConfig config = Config.getConfig(tenantIdentifier, main); + CoreConfig config = Config.getConfig(tenantIdentifierWithStorage, main); return storage.startTransaction(con -> { try { String sessionHandle = refreshTokenInfo.sessionHandle; io.supertokens.pluginInterface.session.SessionInfo sessionInfo = storage - .getSessionInfo_Transaction(tenantIdentifier, con, sessionHandle); + .getSessionInfo_Transaction(tenantIdentifierWithStorage, con, sessionHandle); if (sessionInfo == null || sessionInfo.expiry < System.currentTimeMillis()) { storage.commitTransaction(con); @@ -459,12 +466,12 @@ private static SessionInformationHolder refreshSessionHelper(TenantIdentifier te storage.commitTransaction(con); String antiCsrfToken = enableAntiCsrf ? UUID.randomUUID().toString() : null; final TokenInfo newRefreshToken = RefreshToken.createNewRefreshToken( - tenantIdentifier, main, + tenantIdentifierWithStorage, main, sessionHandle, sessionInfo.userId, Utils.hashSHA256(refreshToken), antiCsrfToken); TokenInfo newAccessToken = AccessToken.createNewAccessToken( - tenantIdentifier, + tenantIdentifierWithStorage, main, sessionHandle, sessionInfo.userId, Utils.hashSHA256(newRefreshToken.token), Utils.hashSHA256(refreshToken), sessionInfo.userDataInJWT, antiCsrfToken, @@ -484,13 +491,13 @@ private static SessionInformationHolder refreshSessionHelper(TenantIdentifier te || (refreshTokenInfo.parentRefreshTokenHash1 != null && Utils.hashSHA256(refreshTokenInfo.parentRefreshTokenHash1) .equals(sessionInfo.refreshTokenHash2))) { - storage.updateSessionInfo_Transaction(tenantIdentifier, con, sessionHandle, + storage.updateSessionInfo_Transaction(tenantIdentifierWithStorage, con, sessionHandle, Utils.hashSHA256(Utils.hashSHA256(refreshToken)), System.currentTimeMillis() + config.getRefreshTokenValidity()); storage.commitTransaction(con); - return refreshSessionHelper(tenantIdentifier, main, refreshToken, + return refreshSessionHelper(tenantIdentifierWithStorage, main, refreshToken, refreshTokenInfo, enableAntiCsrf); } @@ -522,10 +529,9 @@ private static SessionInformationHolder refreshSessionHelper(TenantIdentifier te ////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////// - } else if (StorageLayer.getSessionStorage(tenantIdentifier, main).getType() == + } else if (tenantIdentifierWithStorage.getSessionStorage().getType() == STORAGE_TYPE.NOSQL_1) { - SessionNoSQLStorage_1 storage = (SessionNoSQLStorage_1) StorageLayer.getSessionStorage(tenantIdentifier, - main); + SessionNoSQLStorage_1 storage = (SessionNoSQLStorage_1) tenantIdentifierWithStorage.getSessionStorage(); while (true) { try { String sessionHandle = refreshTokenInfo.sessionHandle; @@ -541,10 +547,10 @@ private static SessionInformationHolder refreshSessionHelper(TenantIdentifier te String antiCsrfToken = enableAntiCsrf ? UUID.randomUUID().toString() : null; final TokenInfo newRefreshToken = RefreshToken.createNewRefreshToken( - tenantIdentifier, main, + tenantIdentifierWithStorage, main, sessionHandle, sessionInfo.userId, Utils.hashSHA256(refreshToken), antiCsrfToken); - TokenInfo newAccessToken = AccessToken.createNewAccessToken(tenantIdentifier, + TokenInfo newAccessToken = AccessToken.createNewAccessToken(tenantIdentifierWithStorage, main, sessionHandle, sessionInfo.userId, Utils.hashSHA256(newRefreshToken.token), @@ -568,12 +574,12 @@ private static SessionInformationHolder refreshSessionHelper(TenantIdentifier te boolean success = storage.updateSessionInfo_Transaction(sessionHandle, Utils.hashSHA256(Utils.hashSHA256(refreshToken)), System.currentTimeMillis() + - Config.getConfig(tenantIdentifier, main).getRefreshTokenValidity(), + Config.getConfig(tenantIdentifierWithStorage, main).getRefreshTokenValidity(), sessionInfo.lastUpdatedSign); if (!success) { continue; } - return refreshSessionHelper(tenantIdentifier, main, refreshToken, refreshTokenInfo, + return refreshSessionHelper(tenantIdentifierWithStorage, main, refreshToken, refreshTokenInfo, enableAntiCsrf); } @@ -595,18 +601,17 @@ private static SessionInformationHolder refreshSessionHelper(TenantIdentifier te public static String[] revokeSessionUsingSessionHandles(Main main, String[] sessionHandles) throws StorageQueryException { - try { - return revokeSessionUsingSessionHandles(new TenantIdentifier(null, null, null), main, sessionHandles); - } catch (TenantOrAppNotFoundException e) { - throw new IllegalStateException(e); - } + Storage storage = StorageLayer.getStorage(main); + return revokeSessionUsingSessionHandles( + new TenantIdentifierWithStorage(null, null, null, storage), + sessionHandles); } - public static String[] revokeSessionUsingSessionHandles(TenantIdentifier tenantIdentifier, Main main, + public static String[] revokeSessionUsingSessionHandles(TenantIdentifierWithStorage tenantIdentifierWithStorage, String[] sessionHandles) - throws StorageQueryException, TenantOrAppNotFoundException { - int numberOfSessionsRevoked = StorageLayer.getSessionStorage(tenantIdentifier, main) - .deleteSession(tenantIdentifier, sessionHandles); + throws StorageQueryException { + int numberOfSessionsRevoked = tenantIdentifierWithStorage.getSessionStorage() + .deleteSession(tenantIdentifierWithStorage, sessionHandles); // most of the time we will enter the below if statement if (numberOfSessionsRevoked == sessionHandles.length) { @@ -620,7 +625,7 @@ public static String[] revokeSessionUsingSessionHandles(TenantIdentifier tenantI break; } - if (StorageLayer.getSessionStorage(tenantIdentifier, main).getSession(tenantIdentifier, sessionHandle) == + if (tenantIdentifierWithStorage.getSessionStorage().getSession(tenantIdentifierWithStorage, sessionHandle) == null) { result[indexIntoResult] = sessionHandle; indexIntoResult++; @@ -633,56 +638,48 @@ public static String[] revokeSessionUsingSessionHandles(TenantIdentifier tenantI @TestOnly public static String[] revokeAllSessionsForUser(Main main, String userId) throws StorageQueryException { - try { - return revokeAllSessionsForUser(new TenantIdentifier(null, null, null), main, userId); - } catch (TenantOrAppNotFoundException e) { - throw new IllegalStateException(e); - } + Storage storage = StorageLayer.getStorage(main); + return revokeAllSessionsForUser( + new TenantIdentifierWithStorage(null, null, null, storage), userId); } - public static String[] revokeAllSessionsForUser(TenantIdentifier tenantIdentifier, Main main, - String userId) throws StorageQueryException, - TenantOrAppNotFoundException { - String[] sessionHandles = getAllNonExpiredSessionHandlesForUser(tenantIdentifier, main, userId); - return revokeSessionUsingSessionHandles(tenantIdentifier, main, sessionHandles); + public static String[] revokeAllSessionsForUser(TenantIdentifierWithStorage tenantIdentifierWithStorage, + String userId) throws StorageQueryException { + String[] sessionHandles = getAllNonExpiredSessionHandlesForUser(tenantIdentifierWithStorage, userId); + return revokeSessionUsingSessionHandles(tenantIdentifierWithStorage, sessionHandles); } @TestOnly public static String[] getAllNonExpiredSessionHandlesForUser(Main main, String userId) throws StorageQueryException { - try { - return getAllNonExpiredSessionHandlesForUser(new TenantIdentifier(null, null, null), main, userId); - } catch (TenantOrAppNotFoundException e) { - throw new IllegalStateException(e); - } + Storage storage = StorageLayer.getStorage(main); + return getAllNonExpiredSessionHandlesForUser( + new TenantIdentifierWithStorage(null, null, null, storage), userId); } - public static String[] getAllNonExpiredSessionHandlesForUser(TenantIdentifier tenantIdentifier, Main main, - String userId) - throws StorageQueryException, TenantOrAppNotFoundException { - return StorageLayer.getSessionStorage(tenantIdentifier, main) - .getAllNonExpiredSessionHandlesForUser(tenantIdentifier, userId); + public static String[] getAllNonExpiredSessionHandlesForUser( + TenantIdentifierWithStorage tenantIdentifierWithStorage, String userId) + throws StorageQueryException { + return tenantIdentifierWithStorage.getSessionStorage() + .getAllNonExpiredSessionHandlesForUser(tenantIdentifierWithStorage, userId); } @TestOnly @Deprecated - public static JsonObject getSessionData(Main main, - String sessionHandle) + public static JsonObject getSessionData(Main main, String sessionHandle) throws StorageQueryException, UnauthorisedException { - try { - return getSessionData(new TenantIdentifier(null, null, null), main, sessionHandle); - } catch (TenantOrAppNotFoundException e) { - throw new IllegalStateException(e); - } + Storage storage = StorageLayer.getStorage(main); + return getSessionData( + new TenantIdentifierWithStorage(null, null, null, storage), + sessionHandle); } @Deprecated - public static JsonObject getSessionData(TenantIdentifier tenantIdentifier, Main main, + public static JsonObject getSessionData(TenantIdentifierWithStorage tenantIdentifierWithStorage, String sessionHandle) - throws StorageQueryException, UnauthorisedException, TenantOrAppNotFoundException { - io.supertokens.pluginInterface.session.SessionInfo session = StorageLayer.getSessionStorage(tenantIdentifier, - main) - .getSession(tenantIdentifier, sessionHandle); + throws StorageQueryException, UnauthorisedException { + io.supertokens.pluginInterface.session.SessionInfo session = tenantIdentifierWithStorage.getSessionStorage() + .getSession(tenantIdentifierWithStorage, sessionHandle); if (session == null || session.expiry <= System.currentTimeMillis()) { throw new UnauthorisedException("Session does not exist."); } @@ -693,19 +690,17 @@ public static JsonObject getSessionData(TenantIdentifier tenantIdentifier, Main @Deprecated public static JsonObject getJWTData(Main main, String sessionHandle) throws StorageQueryException, UnauthorisedException { - try { - return getJWTData(new TenantIdentifier(null, null, null), main, sessionHandle); - } catch (TenantOrAppNotFoundException e) { - throw new IllegalStateException(e); - } + Storage storage =StorageLayer.getStorage(main); + return getJWTData( + new TenantIdentifierWithStorage(null, null, null, storage), + sessionHandle); } @Deprecated - public static JsonObject getJWTData(TenantIdentifier tenantIdentifier, Main main, String sessionHandle) - throws StorageQueryException, UnauthorisedException, TenantOrAppNotFoundException { - io.supertokens.pluginInterface.session.SessionInfo session = StorageLayer.getSessionStorage(tenantIdentifier, - main) - .getSession(tenantIdentifier, sessionHandle); + public static JsonObject getJWTData(TenantIdentifierWithStorage tenantIdentifierWithStorage, String sessionHandle) + throws StorageQueryException, UnauthorisedException { + io.supertokens.pluginInterface.session.SessionInfo session = tenantIdentifierWithStorage.getSessionStorage() + .getSession(tenantIdentifierWithStorage, sessionHandle); if (session == null || session.expiry <= System.currentTimeMillis()) { throw new UnauthorisedException("Session does not exist."); } @@ -713,14 +708,12 @@ public static JsonObject getJWTData(TenantIdentifier tenantIdentifier, Main main } @TestOnly - public static io.supertokens.pluginInterface.session.SessionInfo getSession(Main main, - String sessionHandle) + public static io.supertokens.pluginInterface.session.SessionInfo getSession(Main main, String sessionHandle) throws StorageQueryException, UnauthorisedException { - try { - return getSession(new TenantIdentifier(null, null, null), main, sessionHandle); - } catch (TenantOrAppNotFoundException e) { - throw new IllegalStateException(e); - } + Storage storage = StorageLayer.getStorage(main); + return getSession( + new TenantIdentifierWithStorage(null, null, null, storage), + sessionHandle); } /** @@ -728,13 +721,11 @@ public static io.supertokens.pluginInterface.session.SessionInfo getSession(Main * Used by: * - /recipe/session GET */ - public static io.supertokens.pluginInterface.session.SessionInfo getSession(TenantIdentifier tenantIdentifier, - Main main, - String sessionHandle) - throws StorageQueryException, UnauthorisedException, TenantOrAppNotFoundException { - io.supertokens.pluginInterface.session.SessionInfo session = StorageLayer.getSessionStorage(tenantIdentifier, - main) - .getSession(tenantIdentifier, sessionHandle); + public static io.supertokens.pluginInterface.session.SessionInfo getSession( + TenantIdentifierWithStorage tenantIdentifierWithStorage, String sessionHandle) + throws StorageQueryException, UnauthorisedException { + io.supertokens.pluginInterface.session.SessionInfo session = tenantIdentifierWithStorage.getSessionStorage() + .getSession(tenantIdentifierWithStorage, sessionHandle); // If there is no session, or session is expired if (session == null || session.expiry <= System.currentTimeMillis()) { @@ -749,27 +740,24 @@ public static void updateSession(Main main, String sessionHandle, @Nullable JsonObject sessionData, @Nullable JsonObject jwtData, @Nullable Long lmrt) throws StorageQueryException, UnauthorisedException { - try { - updateSession(new TenantIdentifier(null, null, null), main, sessionHandle, sessionData, jwtData, lmrt); - } catch (TenantOrAppNotFoundException e) { - throw new IllegalStateException(e); - } + Storage storage = StorageLayer.getStorage(main); + updateSession(new TenantIdentifierWithStorage(null, null, null, storage), + sessionHandle, sessionData, jwtData, lmrt); } - public static void updateSession(TenantIdentifier tenantIdentifier, Main main, String sessionHandle, + public static void updateSession(TenantIdentifierWithStorage tenantIdentifierWithStorage, String sessionHandle, @Nullable JsonObject sessionData, @Nullable JsonObject jwtData, @Nullable Long lmrt) - throws StorageQueryException, UnauthorisedException, TenantOrAppNotFoundException { - io.supertokens.pluginInterface.session.SessionInfo session = StorageLayer.getSessionStorage(tenantIdentifier, - main) - .getSession(tenantIdentifier, sessionHandle); + throws StorageQueryException, UnauthorisedException { + io.supertokens.pluginInterface.session.SessionInfo session = tenantIdentifierWithStorage.getSessionStorage() + .getSession(tenantIdentifierWithStorage, sessionHandle); // If there is no session, or session is expired if (session == null || session.expiry <= System.currentTimeMillis()) { throw new UnauthorisedException("Session does not exist."); } - int numberOfRowsAffected = StorageLayer.getSessionStorage(tenantIdentifier, main) - .updateSession(tenantIdentifier, sessionHandle, sessionData, + int numberOfRowsAffected = tenantIdentifierWithStorage.getSessionStorage() + .updateSession(tenantIdentifierWithStorage, sessionHandle, sessionData, jwtData); // TODO: update lmrt as well if (numberOfRowsAffected != 1) { throw new UnauthorisedException("Session does not exist."); diff --git a/src/main/java/io/supertokens/session/accessToken/AccessToken.java b/src/main/java/io/supertokens/session/accessToken/AccessToken.java index 7e6c05dc9..d8c9581b4 100644 --- a/src/main/java/io/supertokens/session/accessToken/AccessToken.java +++ b/src/main/java/io/supertokens/session/accessToken/AccessToken.java @@ -128,6 +128,10 @@ private static AccessTokenInfo getInfoFromAccessToken(AppIdentifier appIdentifie throw new TryRefreshTokenException("Access token expired"); } + // There is no need to check if the appIdentifier (from request) is the same app in which the + // accessToken was created, because, each app has a different accessTokenSigningKey and + // when a cross app request is made, token decoding will fail and result in TRY_REFRESH_TOKEN + // Hence, we don't bother storing any info related to app in the accessTokenPayload. return new AccessTokenInfo(tokenInfo.sessionHandle, tokenInfo.userId, tokenInfo.refreshTokenHash1, tokenInfo.expiryTime, tokenInfo.parentRefreshTokenHash1, tokenInfo.userData, tokenInfo.antiCsrfToken, tokenInfo.timeCreated, tokenInfo.lmrt, @@ -155,14 +159,21 @@ public static AccessTokenInfo getInfoFromAccessToken(AppIdentifier appIdentifier } @TestOnly - public static AccessTokenInfo getInfoFromAccessTokenWithoutVerifying(@Nonnull String token) { - return getInfoFromAccessTokenWithoutVerifying(new AppIdentifier(null, null), token); + public static AccessTokenInfo getInfoFromAccessTokenWithoutVerifying(Main main, @Nonnull String token) { + try { + return getInfoFromAccessTokenWithoutVerifying( + new AppIdentifier(null, null), token); + } catch (TenantOrAppNotFoundException | NoSuchAlgorithmException e) { + throw new IllegalStateException(e); + } } public static AccessTokenInfo getInfoFromAccessTokenWithoutVerifying(AppIdentifier appIdentifier, - @Nonnull String token) { + @Nonnull String token) + throws TenantOrAppNotFoundException, NoSuchAlgorithmException { AccessTokenPayload tokenInfo = new Gson().fromJson(JWT.getPayloadWithoutVerifying(token).payload, AccessTokenPayload.class); + return new AccessTokenInfo(tokenInfo.sessionHandle, tokenInfo.userId, tokenInfo.refreshTokenHash1, tokenInfo.expiryTime, tokenInfo.parentRefreshTokenHash1, tokenInfo.userData, tokenInfo.antiCsrfToken, tokenInfo.timeCreated, tokenInfo.lmrt, diff --git a/src/main/java/io/supertokens/session/accessToken/AccessTokenSigningKey.java b/src/main/java/io/supertokens/session/accessToken/AccessTokenSigningKey.java index c51850d86..d822ea009 100644 --- a/src/main/java/io/supertokens/session/accessToken/AccessTokenSigningKey.java +++ b/src/main/java/io/supertokens/session/accessToken/AccessTokenSigningKey.java @@ -150,23 +150,35 @@ synchronized void removeKeyFromMemoryIfItHasNotChanged(List oldKeyInfo) public synchronized void transferLegacyKeyToNewTable() throws StorageQueryException, StorageTransactionLogicException, TenantOrAppNotFoundException { - Storage storage = StorageLayer.getSessionStorage(this.appIdentifier.getAsPublicTenantIdentifier(), main); + Storage storage = StorageLayer.getStorage(this.appIdentifier.getAsPublicTenantIdentifier(), main); if (storage.getType() == STORAGE_TYPE.SQL) { SessionSQLStorage sqlStorage = (SessionSQLStorage) storage; - // start transaction - sqlStorage.startTransaction(con -> { - KeyValueInfo legacyKey = sqlStorage.getLegacyAccessTokenSigningKey_Transaction( - appIdentifier, con); + try { + // start transaction + sqlStorage.startTransaction(con -> { + KeyValueInfo legacyKey = sqlStorage.getLegacyAccessTokenSigningKey_Transaction( + appIdentifier, con); - if (legacyKey != null) { - sqlStorage.addAccessTokenSigningKey_Transaction(appIdentifier, con, legacyKey); - sqlStorage.removeLegacyAccessTokenSigningKey_Transaction(appIdentifier, con); - sqlStorage.commitTransaction(con); + if (legacyKey != null) { + try { + sqlStorage.addAccessTokenSigningKey_Transaction(appIdentifier, con, legacyKey); + } catch (TenantOrAppNotFoundException e) { + throw new StorageTransactionLogicException(e); + } + sqlStorage.removeLegacyAccessTokenSigningKey_Transaction(appIdentifier, con); + sqlStorage.commitTransaction(con); + } + return legacyKey; + }); + } catch (StorageTransactionLogicException e) { + if (e.actualException instanceof TenantOrAppNotFoundException) { + throw (TenantOrAppNotFoundException) e.actualException; } - return legacyKey; - }); + throw e; + } + } else { SessionNoSQLStorage_1 noSQLStorage = (SessionNoSQLStorage_1) storage; KeyValueInfoWithLastUpdated legacyKey = noSQLStorage.getLegacyAccessTokenSigningKey_Transaction(); @@ -186,7 +198,7 @@ public synchronized void transferLegacyKeyToNewTable() public synchronized void cleanExpiredAccessTokenSigningKeys() throws StorageQueryException, TenantOrAppNotFoundException { - SessionStorage storage = StorageLayer.getSessionStorage(this.appIdentifier.getAsPublicTenantIdentifier(), main); + SessionStorage storage = (SessionStorage) StorageLayer.getStorage(this.appIdentifier.getAsPublicTenantIdentifier(), main); CoreConfig config = Config.getConfig(this.appIdentifier.getAsPublicTenantIdentifier(), main); if (config.getAccessTokenSigningKeyDynamic()) { @@ -232,7 +244,7 @@ public synchronized long getKeyExpiryTime() private List maybeGenerateNewKeyAndUpdateInDb() throws StorageQueryException, StorageTransactionLogicException, TenantOrAppNotFoundException { - Storage storage = StorageLayer.getSessionStorage(this.appIdentifier.getAsPublicTenantIdentifier(), main); + Storage storage = StorageLayer.getStorage(this.appIdentifier.getAsPublicTenantIdentifier(), main); CoreConfig config = Config.getConfig(this.appIdentifier.getAsPublicTenantIdentifier(), main); // Access token signing keys older than this are deleted (ms) @@ -250,42 +262,54 @@ private List maybeGenerateNewKeyAndUpdateInDb() if (storage.getType() == STORAGE_TYPE.SQL) { SessionSQLStorage sqlStorage = (SessionSQLStorage) storage; - // start transaction - validKeys = sqlStorage.startTransaction(con -> { - List validKeysFromSQL = new ArrayList(); - - // We have to generate a new key if we couldn't find one we can use for signing - boolean generateNewKey = true; - - KeyValueInfo[] keysFromStorage = sqlStorage.getAccessTokenSigningKeys_Transaction(appIdentifier, - con); - - for (KeyValueInfo key : keysFromStorage) { - if (keysCreatedAfterCanVerify <= key.createdAtTime) { - if (keysCreatedAfterCanSign <= key.createdAtTime) { - generateNewKey = false; + try { + // start transaction + validKeys = sqlStorage.startTransaction(con -> { + List validKeysFromSQL = new ArrayList(); + + // We have to generate a new key if we couldn't find one we can use for signing + boolean generateNewKey = true; + + KeyValueInfo[] keysFromStorage = sqlStorage.getAccessTokenSigningKeys_Transaction(appIdentifier, + con); + + for (KeyValueInfo key : keysFromStorage) { + if (keysCreatedAfterCanVerify <= key.createdAtTime) { + if (keysCreatedAfterCanSign <= key.createdAtTime) { + generateNewKey = false; + } + validKeysFromSQL.add(new KeyInfo(key.value, key.createdAtTime, signingKeyLifetime)); } - validKeysFromSQL.add(new KeyInfo(key.value, key.createdAtTime, signingKeyLifetime)); } - } - if (generateNewKey) { - String signingKey; - try { - Utils.PubPriKey rsaKeys = Utils.generateNewPubPriKey(); - signingKey = rsaKeys.toString(); - } catch (NoSuchAlgorithmException e) { - throw new StorageTransactionLogicException(e); + if (generateNewKey) { + String signingKey; + try { + Utils.PubPriKey rsaKeys = Utils.generateNewPubPriKey(); + signingKey = rsaKeys.toString(); + } catch (NoSuchAlgorithmException e) { + throw new StorageTransactionLogicException(e); + } + KeyInfo newKey = new KeyInfo(signingKey, System.currentTimeMillis(), signingKeyLifetime); + try { + sqlStorage.addAccessTokenSigningKey_Transaction(appIdentifier, con, + new KeyValueInfo(newKey.value, newKey.createdAtTime)); + } catch (TenantOrAppNotFoundException e) { + throw new StorageTransactionLogicException(e); + } + validKeysFromSQL.add(newKey); } - KeyInfo newKey = new KeyInfo(signingKey, System.currentTimeMillis(), signingKeyLifetime); - sqlStorage.addAccessTokenSigningKey_Transaction(appIdentifier, con, - new KeyValueInfo(newKey.value, newKey.createdAtTime)); - validKeysFromSQL.add(newKey); + + sqlStorage.commitTransaction(con); + return validKeysFromSQL; + }); + } catch (StorageTransactionLogicException e) { + if (e.actualException instanceof TenantOrAppNotFoundException) { + throw (TenantOrAppNotFoundException) e.actualException; } + throw e; + } - sqlStorage.commitTransaction(con); - return validKeysFromSQL; - }); } else if (storage.getType() == STORAGE_TYPE.NOSQL_1) { SessionNoSQLStorage_1 noSQLStorage = (SessionNoSQLStorage_1) storage; diff --git a/src/main/java/io/supertokens/session/refreshToken/RefreshToken.java b/src/main/java/io/supertokens/session/refreshToken/RefreshToken.java index a9e653bd9..03388ae4a 100644 --- a/src/main/java/io/supertokens/session/refreshToken/RefreshToken.java +++ b/src/main/java/io/supertokens/session/refreshToken/RefreshToken.java @@ -73,6 +73,7 @@ public static RefreshTokenInfo getInfoFromRefreshToken(AppIdentifier appIdentifi || !nonce.equals(tokenPayload.nonce)) { throw new UnauthorisedException("Invalid refresh token"); } + return new RefreshTokenInfo(tokenPayload.sessionHandle, tokenPayload.userId, tokenPayload.parentRefreshTokenHash1, null, tokenPayload.antiCsrfToken, tokenType, new TenantIdentifier(appIdentifier.getConnectionUriDomain(), appIdentifier.getAppId(), diff --git a/src/main/java/io/supertokens/session/refreshToken/RefreshTokenKey.java b/src/main/java/io/supertokens/session/refreshToken/RefreshTokenKey.java index ea0bac0a6..db11b314d 100644 --- a/src/main/java/io/supertokens/session/refreshToken/RefreshTokenKey.java +++ b/src/main/java/io/supertokens/session/refreshToken/RefreshTokenKey.java @@ -126,34 +126,46 @@ public String getKey() throws StorageQueryException, StorageTransactionLogicExce private String maybeGenerateNewKeyAndUpdateInDb() throws StorageQueryException, StorageTransactionLogicException, TenantOrAppNotFoundException { - SessionStorage storage = StorageLayer.getSessionStorage(this.appIdentifier.getAsPublicTenantIdentifier(), main); + SessionStorage storage = (SessionStorage) StorageLayer.getStorage(this.appIdentifier.getAsPublicTenantIdentifier(), main); if (storage.getType() == STORAGE_TYPE.SQL) { SessionSQLStorage sqlStorage = (SessionSQLStorage) storage; - // start transaction - return sqlStorage.startTransaction(con -> { - String key = null; - KeyValueInfo keyFromStorage = sqlStorage.getRefreshTokenSigningKey_Transaction(appIdentifier, con); - if (keyFromStorage != null) { - key = keyFromStorage.value; - } + try { + // start transaction + return sqlStorage.startTransaction(con -> { + String key = null; + KeyValueInfo keyFromStorage = sqlStorage.getRefreshTokenSigningKey_Transaction(appIdentifier, con); + if (keyFromStorage != null) { + key = keyFromStorage.value; + } - if (key == null) { - try { - key = Utils.generateNewSigningKey(); - } catch (NoSuchAlgorithmException | InvalidKeySpecException e) { - throw new StorageTransactionLogicException(e); + if (key == null) { + try { + key = Utils.generateNewSigningKey(); + } catch (NoSuchAlgorithmException | InvalidKeySpecException e) { + throw new StorageTransactionLogicException(e); + } + try { + sqlStorage.setRefreshTokenSigningKey_Transaction(appIdentifier, con, + new KeyValueInfo(key, System.currentTimeMillis())); + } catch (TenantOrAppNotFoundException e) { + throw new StorageTransactionLogicException(e); + } } - sqlStorage.setRefreshTokenSigningKey_Transaction(appIdentifier, con, - new KeyValueInfo(key, System.currentTimeMillis())); - } - sqlStorage.commitTransaction(con); - return key; + sqlStorage.commitTransaction(con); + return key; + + }); + } catch (StorageTransactionLogicException e) { + if (e.actualException instanceof TenantOrAppNotFoundException) { + throw (TenantOrAppNotFoundException) e.actualException; + } + throw e; + } - }); } else if (storage.getType() == STORAGE_TYPE.NOSQL_1) { SessionNoSQLStorage_1 noSQLStorage = (SessionNoSQLStorage_1) storage; diff --git a/src/main/java/io/supertokens/storageLayer/StorageLayer.java b/src/main/java/io/supertokens/storageLayer/StorageLayer.java index c487f1fc3..e1f300bc4 100644 --- a/src/main/java/io/supertokens/storageLayer/StorageLayer.java +++ b/src/main/java/io/supertokens/storageLayer/StorageLayer.java @@ -328,22 +328,6 @@ public static AuthRecipeStorage getAuthRecipeStorage(Main main) { } } - public static SessionStorage getSessionStorage(TenantIdentifier tenantIdentifier, Main main) - throws TenantOrAppNotFoundException { - // TODO remove this function - return (SessionStorage) getInstance(tenantIdentifier, main).storage; - } - - @TestOnly - public static SessionStorage getSessionStorage(Main main) { - // TODO remove this function - try { - return getSessionStorage(new TenantIdentifier(null, null, null), main); - } catch (TenantOrAppNotFoundException e) { - throw new IllegalStateException(e); - } - } - public static JWTRecipeStorage getJWTRecipeStorage(TenantIdentifier tenantIdentifier, Main main) throws TenantOrAppNotFoundException { // TODO remove this function diff --git a/src/main/java/io/supertokens/webserver/WebserverAPI.java b/src/main/java/io/supertokens/webserver/WebserverAPI.java index a2ce52c2e..35b330f46 100644 --- a/src/main/java/io/supertokens/webserver/WebserverAPI.java +++ b/src/main/java/io/supertokens/webserver/WebserverAPI.java @@ -329,7 +329,7 @@ protected void service(HttpServletRequest req, HttpServletResponse resp) throws sendTextResponse(401, "Invalid API key", resp); } else if (rootCause instanceof TenantOrAppNotFoundException) { sendTextResponse(400, - "AppId or tenantId not found => " + ((TenantOrAppNotFoundException) e).getMessage(), + "AppId or tenantId not found => " + ((TenantOrAppNotFoundException) rootCause).getMessage(), resp); } else if (rootCause instanceof BadPermissionException) { sendTextResponse(403, e.getMessage(), resp); diff --git a/src/main/java/io/supertokens/webserver/api/session/HandshakeAPI.java b/src/main/java/io/supertokens/webserver/api/session/HandshakeAPI.java index a734bae80..0a1cca304 100644 --- a/src/main/java/io/supertokens/webserver/api/session/HandshakeAPI.java +++ b/src/main/java/io/supertokens/webserver/api/session/HandshakeAPI.java @@ -23,6 +23,7 @@ import io.supertokens.pluginInterface.RECIPE_ID; import io.supertokens.pluginInterface.exceptions.StorageQueryException; import io.supertokens.pluginInterface.exceptions.StorageTransactionLogicException; +import io.supertokens.pluginInterface.multitenancy.TenantIdentifier; import io.supertokens.pluginInterface.multitenancy.exceptions.TenantOrAppNotFoundException; import io.supertokens.session.accessToken.AccessTokenSigningKey; import io.supertokens.session.accessToken.AccessTokenSigningKey.KeyInfo; @@ -49,35 +50,34 @@ public String getPath() { @Override protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws IOException, ServletException { + // API is tenant specific try { JsonObject result = new JsonObject(); result.addProperty("status", "OK"); + TenantIdentifier tenantIdentifier = this.getTenantIdentifierWithStorageFromRequest(req); + result.addProperty("jwtSigningPublicKey", new Utils.PubPriKey( - AccessTokenSigningKey.getInstance(this.getTenantIdentifierWithStorageFromRequest(req).toAppIdentifier(), + AccessTokenSigningKey.getInstance(tenantIdentifier.toAppIdentifier(), main).getLatestIssuedKey().value).publicKey); result.addProperty("jwtSigningPublicKeyExpiryTime", - AccessTokenSigningKey.getInstance(this.getTenantIdentifierWithStorageFromRequest(req).toAppIdentifier(), main) + AccessTokenSigningKey.getInstance(tenantIdentifier.toAppIdentifier(), main) .getKeyExpiryTime()); if (!super.getVersionFromRequest(req).equals("2.7") && !super.getVersionFromRequest(req).equals("2.8")) { - List keys = AccessTokenSigningKey.getInstance(this.getTenantIdentifierWithStorageFromRequest(req).toAppIdentifier(), - main) + List keys = AccessTokenSigningKey.getInstance(tenantIdentifier.toAppIdentifier(), main) .getAllKeys(); JsonArray jwtSigningPublicKeyListJSON = Utils.keyListToJson(keys); result.add("jwtSigningPublicKeyList", jwtSigningPublicKeyListJSON); } result.addProperty("accessTokenBlacklistingEnabled", - Config.getConfig(this.getTenantIdentifierWithStorageFromRequest(req), main) - .getAccessTokenBlacklisting()); + Config.getConfig(tenantIdentifier, main).getAccessTokenBlacklisting()); result.addProperty("accessTokenValidity", - Config.getConfig(this.getTenantIdentifierWithStorageFromRequest(req), main) - .getAccessTokenValidity()); + Config.getConfig(tenantIdentifier, main).getAccessTokenValidity()); result.addProperty("refreshTokenValidity", - Config.getConfig(this.getTenantIdentifierWithStorageFromRequest(req), main) - .getRefreshTokenValidity()); + Config.getConfig(tenantIdentifier, main).getRefreshTokenValidity()); super.sendJsonResponse(200, result, resp); } catch (StorageQueryException | StorageTransactionLogicException | TenantOrAppNotFoundException e) { throw new ServletException(e); diff --git a/src/main/java/io/supertokens/webserver/api/session/JWTDataAPI.java b/src/main/java/io/supertokens/webserver/api/session/JWTDataAPI.java index 5e42d9155..27b69f983 100644 --- a/src/main/java/io/supertokens/webserver/api/session/JWTDataAPI.java +++ b/src/main/java/io/supertokens/webserver/api/session/JWTDataAPI.java @@ -48,6 +48,7 @@ public String getPath() { @Override protected void doPut(HttpServletRequest req, HttpServletResponse resp) throws IOException, ServletException { + // API is tenant specific JsonObject input = InputParser.parseJsonObjectOrThrowError(req); String sessionHandle = InputParser.parseStringOrThrowError(input, "sessionHandle", false); @@ -57,7 +58,7 @@ protected void doPut(HttpServletRequest req, HttpServletResponse resp) throws IO assert userDataInJWT != null; try { - Session.updateSession(this.getTenantIdentifierWithStorageFromRequest(req), main, sessionHandle, null, + Session.updateSession(this.getTenantIdentifierWithStorageFromRequest(req), sessionHandle, null, userDataInJWT, null); JsonObject result = new JsonObject(); @@ -79,12 +80,12 @@ protected void doPut(HttpServletRequest req, HttpServletResponse resp) throws IO @Override @Deprecated protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws IOException, ServletException { + // API is tenant specific String sessionHandle = InputParser.getQueryParamOrThrowError(req, "sessionHandle", false); assert sessionHandle != null; try { - JsonElement jwtPayload = Session.getJWTData(this.getTenantIdentifierWithStorageFromRequest(req), main, - sessionHandle); + JsonElement jwtPayload = Session.getJWTData(this.getTenantIdentifierWithStorageFromRequest(req), sessionHandle); JsonObject result = new JsonObject(); diff --git a/src/main/java/io/supertokens/webserver/api/session/RefreshSessionAPI.java b/src/main/java/io/supertokens/webserver/api/session/RefreshSessionAPI.java index d8b0c718b..4011dc2a9 100644 --- a/src/main/java/io/supertokens/webserver/api/session/RefreshSessionAPI.java +++ b/src/main/java/io/supertokens/webserver/api/session/RefreshSessionAPI.java @@ -50,6 +50,7 @@ public String getPath() { @Override protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws IOException, ServletException { + // API is app specific, but session is updated based on tenantId obtained from the refreshToken JsonObject input = InputParser.parseJsonObjectOrThrowError(req); String refreshToken = InputParser.parseStringOrThrowError(input, "refreshToken", false); String antiCsrfToken = InputParser.parseStringOrThrowError(input, "antiCsrfToken", true); @@ -59,7 +60,7 @@ protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws I try { SessionInformationHolder sessionInfo = Session.refreshSession( - this.getTenantIdentifierWithStorageFromRequest(req).toAppIdentifier(), main, + this.getAppIdentifierWithStorage(req), main, refreshToken, antiCsrfToken, enableAntiCsrf); JsonObject result = sessionInfo.toJsonObject(); diff --git a/src/main/java/io/supertokens/webserver/api/session/SessionAPI.java b/src/main/java/io/supertokens/webserver/api/session/SessionAPI.java index ee13ab7ed..fcecd56c8 100644 --- a/src/main/java/io/supertokens/webserver/api/session/SessionAPI.java +++ b/src/main/java/io/supertokens/webserver/api/session/SessionAPI.java @@ -63,6 +63,7 @@ public String getPath() { @Override protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws IOException, ServletException { + // API is tenant specific JsonObject input = InputParser.parseJsonObjectOrThrowError(req); String userId = InputParser.parseStringOrThrowError(input, "userId", false); assert userId != null; @@ -74,8 +75,8 @@ protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws I assert userDataInDatabase != null; try { - SessionInformationHolder sessionInfo = Session.createNewSession(this.getTenantIdentifierWithStorageFromRequest(req), main, userId, - userDataInJWT, + SessionInformationHolder sessionInfo = Session.createNewSession( + this.getTenantIdentifierWithStorageFromRequest(req), main, userId, userDataInJWT, userDataInDatabase, enableAntiCsrf); JsonObject result = sessionInfo.toJsonObject(); @@ -106,11 +107,12 @@ protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws I @Override protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws IOException, ServletException { + // API is tenant specific String sessionHandle = InputParser.getQueryParamOrThrowError(req, "sessionHandle", false); assert sessionHandle != null; try { - SessionInfo sessionInfo = Session.getSession(this.getTenantIdentifierWithStorageFromRequest(req), main, sessionHandle); + SessionInfo sessionInfo = Session.getSession(this.getTenantIdentifierWithStorageFromRequest(req), sessionHandle); JsonObject result = new Gson().toJsonTree(sessionInfo).getAsJsonObject(); result.add("userDataInJWT", Utils.toJsonTreeWithNulls(sessionInfo.userDataInJWT)); diff --git a/src/main/java/io/supertokens/webserver/api/session/SessionDataAPI.java b/src/main/java/io/supertokens/webserver/api/session/SessionDataAPI.java index dbcf3c0ab..c2f67900d 100644 --- a/src/main/java/io/supertokens/webserver/api/session/SessionDataAPI.java +++ b/src/main/java/io/supertokens/webserver/api/session/SessionDataAPI.java @@ -48,11 +48,12 @@ public String getPath() { @Override @Deprecated protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws IOException, ServletException { + // API is tenant specific String sessionHandle = InputParser.getQueryParamOrThrowError(req, "sessionHandle", false); assert sessionHandle != null; try { - JsonObject userDataInDatabase = Session.getSessionData(this.getTenantIdentifierWithStorageFromRequest(req), main, sessionHandle); + JsonObject userDataInDatabase = Session.getSessionData(this.getTenantIdentifierWithStorageFromRequest(req), sessionHandle); JsonObject result = new JsonObject(); result.addProperty("status", "OK"); @@ -79,7 +80,7 @@ protected void doPut(HttpServletRequest req, HttpServletResponse resp) throws IO assert userDataInDatabase != null; try { - Session.updateSession(this.getTenantIdentifierWithStorageFromRequest(req), main, sessionHandle, + Session.updateSession(this.getTenantIdentifierWithStorageFromRequest(req), sessionHandle, userDataInDatabase, null, null); JsonObject result = new JsonObject(); diff --git a/src/main/java/io/supertokens/webserver/api/session/SessionRegenerateAPI.java b/src/main/java/io/supertokens/webserver/api/session/SessionRegenerateAPI.java index b32b9556b..c6c5ee230 100644 --- a/src/main/java/io/supertokens/webserver/api/session/SessionRegenerateAPI.java +++ b/src/main/java/io/supertokens/webserver/api/session/SessionRegenerateAPI.java @@ -54,6 +54,7 @@ public String getPath() { @Override protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws IOException, ServletException { + // API is app specific, but the session is updated based on tenantId obtained from the accessToken JsonObject input = InputParser.parseJsonObjectOrThrowError(req); String accessToken = InputParser.parseStringOrThrowError(input, "accessToken", false); @@ -63,7 +64,7 @@ protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws I try { SessionInformationHolder sessionInfo = Session.regenerateToken( - this.getTenantIdentifierWithStorageFromRequest(req).toAppIdentifier(), main, + this.getAppIdentifierWithStorage(req), main, accessToken, userDataInJWT); JsonObject result = sessionInfo.toJsonObject(); diff --git a/src/main/java/io/supertokens/webserver/api/session/SessionRemoveAPI.java b/src/main/java/io/supertokens/webserver/api/session/SessionRemoveAPI.java index 00851a6db..62d15c3c6 100644 --- a/src/main/java/io/supertokens/webserver/api/session/SessionRemoveAPI.java +++ b/src/main/java/io/supertokens/webserver/api/session/SessionRemoveAPI.java @@ -46,6 +46,7 @@ public String getPath() { @Override protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws IOException, ServletException { + // API is tenant specific JsonObject input = InputParser.parseJsonObjectOrThrowError(req); String userId = InputParser.parseStringOrThrowError(input, "userId", true); @@ -74,8 +75,8 @@ protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws I if (userId != null) { try { - String[] sessionHandlesRevoked = Session.revokeAllSessionsForUser(this.getTenantIdentifierWithStorageFromRequest(req), main, - userId); + String[] sessionHandlesRevoked = Session.revokeAllSessionsForUser( + this.getTenantIdentifierWithStorageFromRequest(req), userId); JsonObject result = new JsonObject(); result.addProperty("status", "OK"); JsonArray sessionHandlesRevokedJSON = new JsonArray(); @@ -90,7 +91,7 @@ protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws I } else { try { String[] sessionHandlesRevoked = Session.revokeSessionUsingSessionHandles( - this.getTenantIdentifierWithStorageFromRequest(req), main, sessionHandles); + this.getTenantIdentifierWithStorageFromRequest(req), sessionHandles); JsonObject result = new JsonObject(); result.addProperty("status", "OK"); JsonArray sessionHandlesRevokedJSON = new JsonArray(); diff --git a/src/main/java/io/supertokens/webserver/api/session/SessionUserAPI.java b/src/main/java/io/supertokens/webserver/api/session/SessionUserAPI.java index 893761f64..138c3b369 100644 --- a/src/main/java/io/supertokens/webserver/api/session/SessionUserAPI.java +++ b/src/main/java/io/supertokens/webserver/api/session/SessionUserAPI.java @@ -47,12 +47,13 @@ public String getPath() { @Override protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws IOException, ServletException { + // API is tenant specific String userId = InputParser.getQueryParamOrThrowError(req, "userId", false); assert userId != null; try { - String[] sessionHandles = Session.getAllNonExpiredSessionHandlesForUser(this.getTenantIdentifierWithStorageFromRequest(req), main, - userId); + String[] sessionHandles = Session.getAllNonExpiredSessionHandlesForUser( + this.getTenantIdentifierWithStorageFromRequest(req), userId); JsonObject result = new JsonObject(); result.addProperty("status", "OK"); diff --git a/src/main/java/io/supertokens/webserver/api/session/VerifySessionAPI.java b/src/main/java/io/supertokens/webserver/api/session/VerifySessionAPI.java index 39851c8ef..97a4f7971 100644 --- a/src/main/java/io/supertokens/webserver/api/session/VerifySessionAPI.java +++ b/src/main/java/io/supertokens/webserver/api/session/VerifySessionAPI.java @@ -25,6 +25,7 @@ import io.supertokens.pluginInterface.RECIPE_ID; import io.supertokens.pluginInterface.exceptions.StorageQueryException; import io.supertokens.pluginInterface.exceptions.StorageTransactionLogicException; +import io.supertokens.pluginInterface.multitenancy.AppIdentifier; import io.supertokens.pluginInterface.multitenancy.exceptions.TenantOrAppNotFoundException; import io.supertokens.session.Session; import io.supertokens.session.accessToken.AccessTokenSigningKey; @@ -38,6 +39,7 @@ import jakarta.servlet.http.HttpServletResponse; import java.io.IOException; +import java.security.NoSuchAlgorithmException; import java.util.List; public class VerifySessionAPI extends WebserverAPI { @@ -55,6 +57,7 @@ public String getPath() { @Override protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws IOException, ServletException { + // API is app specific, but the session is fetched based on tenantId obtained from the accessToken JsonObject input = InputParser.parseJsonObjectOrThrowError(req); String accessToken = InputParser.parseStringOrThrowError(input, "accessToken", false); assert accessToken != null; @@ -64,8 +67,17 @@ protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws I Boolean enableAntiCsrf = InputParser.parseBooleanOrThrowError(input, "enableAntiCsrf", false); assert enableAntiCsrf != null; + AppIdentifier appIdentifier; try { - SessionInformationHolder sessionInfo = Session.getSession(this.getTenantIdentifierWithStorageFromRequest(req).toAppIdentifier(), + // We actually don't use the storage because tenantId is obtained from the accessToken, + // and appropriate storage is obtained later + appIdentifier = this.getAppIdentifierWithStorage(req); + } catch (TenantOrAppNotFoundException e) { + throw new ServletException(e); + } + + try { + SessionInformationHolder sessionInfo = Session.getSession(appIdentifier, main, accessToken, antiCsrfToken, enableAntiCsrf, doAntiCsrfCheck); @@ -75,15 +87,13 @@ protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws I result.addProperty("jwtSigningPublicKey", new Utils.PubPriKey( - AccessTokenSigningKey.getInstance(this.getTenantIdentifierWithStorageFromRequest(req).toAppIdentifier(), main) + AccessTokenSigningKey.getInstance(appIdentifier, main) .getLatestIssuedKey().value).publicKey); result.addProperty("jwtSigningPublicKeyExpiryTime", - AccessTokenSigningKey.getInstance(this.getTenantIdentifierWithStorageFromRequest(req).toAppIdentifier(), main) - .getKeyExpiryTime()); + AccessTokenSigningKey.getInstance(appIdentifier, main).getKeyExpiryTime()); if (!super.getVersionFromRequest(req).equals("2.7") && !super.getVersionFromRequest(req).equals("2.8")) { - List keys = AccessTokenSigningKey.getInstance(this.getTenantIdentifierWithStorageFromRequest(req).toAppIdentifier(), - main) + List keys = AccessTokenSigningKey.getInstance(appIdentifier, main) .getAllKeys(); JsonArray jwtSigningPublicKeyListJSON = Utils.keyListToJson(keys); result.add("jwtSigningPublicKeyList", jwtSigningPublicKeyListJSON); @@ -105,17 +115,16 @@ protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws I reply.addProperty("status", "TRY_REFRESH_TOKEN"); reply.addProperty("jwtSigningPublicKey", new Utils.PubPriKey( - AccessTokenSigningKey.getInstance(this.getTenantIdentifierWithStorageFromRequest(req).toAppIdentifier(), main) + AccessTokenSigningKey.getInstance(appIdentifier, main) .getLatestIssuedKey().value).publicKey); reply.addProperty("jwtSigningPublicKeyExpiryTime", - AccessTokenSigningKey.getInstance(this.getTenantIdentifierWithStorageFromRequest(req).toAppIdentifier(), main) + AccessTokenSigningKey.getInstance(appIdentifier, main) .getKeyExpiryTime()); if (!super.getVersionFromRequest(req).equals("2.7") && !super.getVersionFromRequest(req).equals("2.8")) { List keys = AccessTokenSigningKey.getInstance( - this.getTenantIdentifierWithStorageFromRequest(req).toAppIdentifier(), main) - .getAllKeys(); + appIdentifier, main).getAllKeys(); JsonArray jwtSigningPublicKeyListJSON = Utils.keyListToJson(keys); reply.add("jwtSigningPublicKeyList", jwtSigningPublicKeyListJSON); } diff --git a/src/test/java/io/supertokens/test/InMemoryDBStorageTest.java b/src/test/java/io/supertokens/test/InMemoryDBStorageTest.java index 689d3b97b..178d02b48 100644 --- a/src/test/java/io/supertokens/test/InMemoryDBStorageTest.java +++ b/src/test/java/io/supertokens/test/InMemoryDBStorageTest.java @@ -22,6 +22,7 @@ import io.supertokens.pluginInterface.exceptions.StorageQueryException; import io.supertokens.pluginInterface.exceptions.StorageTransactionLogicException; import io.supertokens.pluginInterface.multitenancy.TenantIdentifier; +import io.supertokens.pluginInterface.multitenancy.exceptions.TenantOrAppNotFoundException; import io.supertokens.pluginInterface.sqlStorage.SQLStorage; import io.supertokens.storageLayer.StorageLayer; import org.junit.AfterClass; @@ -72,8 +73,12 @@ public void transactionIsolationTesting() Storage storage = StorageLayer.getStorage(process.getProcess()); SQLStorage sqlStorage = (SQLStorage) storage; sqlStorage.startTransaction(con -> { - sqlStorage.setKeyValue_Transaction(new TenantIdentifier(null, null, null), con, "Key", - new KeyValueInfo("Value")); + try { + sqlStorage.setKeyValue_Transaction(new TenantIdentifier(null, null, null), con, "Key", + new KeyValueInfo("Value")); + } catch (TenantOrAppNotFoundException e) { + throw new IllegalStateException(e); + } sqlStorage.commitTransaction(con); return null; }); @@ -96,8 +101,12 @@ public void transactionIsolationTesting() syncObject.notifyAll(); } - sqlStorage.setKeyValue_Transaction(new TenantIdentifier(null, null, null), con, "Key", - new KeyValueInfo("Value2")); + try { + sqlStorage.setKeyValue_Transaction(new TenantIdentifier(null, null, null), con, "Key", + new KeyValueInfo("Value2")); + } catch (TenantOrAppNotFoundException e) { + throw new IllegalStateException(e); + } try { Thread.sleep(1500); @@ -185,8 +194,12 @@ public void transactionTest() throws InterruptedException, StorageQueryException Storage storage = StorageLayer.getStorage(process.getProcess()); SQLStorage sqlStorage = (SQLStorage) storage; String returnedValue = sqlStorage.startTransaction(con -> { - sqlStorage.setKeyValue_Transaction(new TenantIdentifier(null, null, null), con, "Key", - new KeyValueInfo("Value")); + try { + sqlStorage.setKeyValue_Transaction(new TenantIdentifier(null, null, null), con, "Key", + new KeyValueInfo("Value")); + } catch (TenantOrAppNotFoundException e) { + throw new IllegalStateException(e); + } sqlStorage.commitTransaction(con); return "returned value"; }); @@ -210,8 +223,12 @@ public void transactionDoNotCommitButStillCommitsTest() SQLStorage sqlStorage = (SQLStorage) storage; sqlStorage.startTransaction(con -> { - sqlStorage.setKeyValue_Transaction(new TenantIdentifier(null, null, null), con, "Key", - new KeyValueInfo("Value")); + try { + sqlStorage.setKeyValue_Transaction(new TenantIdentifier(null, null, null), con, "Key", + new KeyValueInfo("Value")); + } catch (TenantOrAppNotFoundException e) { + throw new IllegalStateException(e); + } return null; }); KeyValueInfo value = storage.getKeyValue(new TenantIdentifier(null, null, null), "Key"); @@ -235,8 +252,12 @@ public void transactionThrowCompileTimeErrorAndExpectRollbackTest() SQLStorage sqlStorage = (SQLStorage) storage; try { sqlStorage.startTransaction(con -> { - sqlStorage.setKeyValue_Transaction(new TenantIdentifier(null, null, null), con, "Key", - new KeyValueInfo("Value")); + try { + sqlStorage.setKeyValue_Transaction(new TenantIdentifier(null, null, null), con, "Key", + new KeyValueInfo("Value")); + } catch (TenantOrAppNotFoundException e) { + throw new IllegalStateException(e); + } throw new StorageTransactionLogicException(new Exception("error message")); }); } catch (StorageTransactionLogicException e) { @@ -264,8 +285,12 @@ public void transactionThrowRunTimeErrorAndExpectRollbackTest() SQLStorage sqlStorage = (SQLStorage) storage; try { sqlStorage.startTransaction(con -> { - sqlStorage.setKeyValue_Transaction(new TenantIdentifier(null, null, null), con, "Key", - new KeyValueInfo("Value")); + try { + sqlStorage.setKeyValue_Transaction(new TenantIdentifier(null, null, null), con, "Key", + new KeyValueInfo("Value")); + } catch (TenantOrAppNotFoundException e) { + throw new IllegalStateException(e); + } throw new RuntimeException("error message"); }); } catch (RuntimeException e) { diff --git a/src/test/java/io/supertokens/test/InMemoryDBTest.java b/src/test/java/io/supertokens/test/InMemoryDBTest.java index 175317075..28c1a5498 100644 --- a/src/test/java/io/supertokens/test/InMemoryDBTest.java +++ b/src/test/java/io/supertokens/test/InMemoryDBTest.java @@ -28,6 +28,7 @@ import io.supertokens.pluginInterface.exceptions.StorageQueryException; import io.supertokens.pluginInterface.exceptions.StorageTransactionLogicException; import io.supertokens.pluginInterface.multitenancy.TenantIdentifier; +import io.supertokens.pluginInterface.session.SessionStorage; import io.supertokens.session.Session; import io.supertokens.session.info.SessionInformationHolder; import io.supertokens.storageLayer.StorageLayer; @@ -181,7 +182,7 @@ public void createAndForgetSession() assert sessionInfo.accessToken != null; assert sessionInfo.refreshToken != null; - assertEquals(StorageLayer.getSessionStorage(process.getProcess()) + assertEquals(((SessionStorage) StorageLayer.getStorage(process.getProcess())) .getNumberOfSessions(new TenantIdentifier(null, null, null)), 1); process.kill(); @@ -194,7 +195,7 @@ public void createAndForgetSession() process.startProcess(); assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); - assertEquals(StorageLayer.getSessionStorage(process.getProcess()) + assertEquals(((SessionStorage) StorageLayer.getStorage(process.getProcess())) .getNumberOfSessions(new TenantIdentifier(null, null, null)), 0); process.kill(); @@ -225,7 +226,7 @@ public void createAndGetSession() throws InterruptedException, StorageQueryExcep assertEquals(sessionInfo.session.userId, userId); assertEquals(sessionInfo.session.userDataInJWT.toString(), userDataInJWT.toString()); - assertEquals(StorageLayer.getSessionStorage(process.getProcess()) + assertEquals(((SessionStorage) StorageLayer.getStorage(process.getProcess())) .getNumberOfSessions(new TenantIdentifier(null, null, null)), 1); assert sessionInfo.accessToken != null; assertNull(sessionInfo.antiCsrfToken); @@ -266,7 +267,7 @@ public void createAndGetSessionNoAntiCSRF() throws InterruptedException, Storage assertEquals(sessionInfo.session.userId, userId); assertEquals(sessionInfo.session.userDataInJWT.toString(), userDataInJWT.toString()); - assertEquals(StorageLayer.getSessionStorage(process.getProcess()) + assertEquals(((SessionStorage) StorageLayer.getStorage(process.getProcess())) .getNumberOfSessions(new TenantIdentifier(null, null, null)), 1); assert sessionInfo.accessToken != null; assertNull(sessionInfo.antiCsrfToken); @@ -344,7 +345,7 @@ public void createNewSessionAndAlterJWTPayload() throws InterruptedException, St assertEquals(sessionInfo.session.userId, userId); assertEquals(sessionInfo.session.userDataInJWT.toString(), userDataInJWT.toString()); - assertEquals(StorageLayer.getSessionStorage(process.getProcess()) + assertEquals(((SessionStorage) StorageLayer.getStorage(process.getProcess())) .getNumberOfSessions(new TenantIdentifier(null, null, null)), 1); assert sessionInfo.accessToken != null; assertNull(sessionInfo.antiCsrfToken); @@ -388,7 +389,7 @@ public void createAndGetSessionWithEmptyJWTPayload() throws InterruptedException assertEquals(sessionInfo.session.userId, userId); assertEquals(sessionInfo.session.userDataInJWT.toString(), userDataInJWT.toString()); - assertEquals(StorageLayer.getSessionStorage(process.getProcess()) + assertEquals(((SessionStorage) StorageLayer.getStorage(process.getProcess())) .getNumberOfSessions(new TenantIdentifier(null, null, null)), 1); assert sessionInfo.accessToken != null; assertNull(sessionInfo.antiCsrfToken); @@ -436,7 +437,7 @@ public void createAndGetSessionWithComplexJWTPayload() throws InterruptedExcepti assertEquals(sessionInfo.session.userId, userId); assertEquals(sessionInfo.session.userDataInJWT.toString(), userDataInJWT.toString()); - assertEquals(StorageLayer.getSessionStorage(process.getProcess()) + assertEquals(((SessionStorage) StorageLayer.getStorage(process.getProcess())) .getNumberOfSessions(new TenantIdentifier(null, null, null)), 1); assert sessionInfo.accessToken != null; assertNull(sessionInfo.antiCsrfToken); @@ -563,7 +564,7 @@ public void refreshSessionTestWithAntiCsrf() throws IOException, InterruptedExce assertEquals(refreshedSession.session.handle, sessionInfo.session.handle); assertEquals(refreshedSession.session.userId, sessionInfo.session.userId); assertEquals(refreshedSession.session.userDataInJWT.toString(), sessionInfo.session.userDataInJWT.toString()); - assertEquals(StorageLayer.getSessionStorage(process.getProcess()) + assertEquals(((SessionStorage) StorageLayer.getStorage(process.getProcess())) .getNumberOfSessions(new TenantIdentifier(null, null, null)), 1); SessionInformationHolder newSession = Session.getSession(process.getProcess(), @@ -590,7 +591,7 @@ public void refreshSessionTestWithAntiCsrf() throws IOException, InterruptedExce SessionInformationHolder refreshedSession2 = Session.refreshSession(process.getProcess(), refreshedSession.refreshToken.token, refreshedSession.antiCsrfToken, true); - assertEquals(StorageLayer.getSessionStorage(process.getProcess()) + assertEquals(((SessionStorage) StorageLayer.getStorage(process.getProcess())) .getNumberOfSessions(new TenantIdentifier(null, null, null)), 1); assert refreshedSession2.accessToken != null; @@ -656,7 +657,7 @@ public void refreshSessionTestWithNoAntiCsrf() throws IOException, InterruptedEx assertEquals(refreshedSession.session.handle, sessionInfo.session.handle); assertEquals(refreshedSession.session.userId, sessionInfo.session.userId); assertEquals(refreshedSession.session.userDataInJWT.toString(), sessionInfo.session.userDataInJWT.toString()); - assertEquals(StorageLayer.getSessionStorage(process.getProcess()) + assertEquals(((SessionStorage) StorageLayer.getStorage(process.getProcess())) .getNumberOfSessions(new TenantIdentifier(null, null, null)), 1); SessionInformationHolder newSession = Session.getSession(process.getProcess(), @@ -682,7 +683,7 @@ public void refreshSessionTestWithNoAntiCsrf() throws IOException, InterruptedEx SessionInformationHolder refreshedSession2 = Session.refreshSession(process.getProcess(), refreshedSession.refreshToken.token, refreshedSession.antiCsrfToken, false); - assertEquals(StorageLayer.getSessionStorage(process.getProcess()) + assertEquals(((SessionStorage) StorageLayer.getStorage(process.getProcess())) .getNumberOfSessions(new TenantIdentifier(null, null, null)), 1); assert refreshedSession2.accessToken != null; @@ -767,7 +768,7 @@ public void refreshTokenExpiresAfterShortTime() throws InterruptedException, IOE sessionInfo.antiCsrfToken, false); assert newRefreshedSession.refreshToken != null; - assertEquals(StorageLayer.getSessionStorage(process.getProcess()) + assertEquals(((SessionStorage) StorageLayer.getStorage(process.getProcess())) .getNumberOfSessions(new TenantIdentifier(null, null, null)), 1); Session.getSession(main, sessionInfo.accessToken.token, sessionInfo.antiCsrfToken, false, true); @@ -781,7 +782,7 @@ public void refreshTokenExpiresAfterShortTime() throws InterruptedException, IOE } catch (UnauthorisedException ignored) { } - assertEquals(StorageLayer.getSessionStorage(process.getProcess()) + assertEquals(((SessionStorage) StorageLayer.getStorage(process.getProcess())) .getNumberOfSessions(new TenantIdentifier(null, null, null)), 1); } @@ -791,7 +792,7 @@ public void refreshTokenExpiresAfterShortTime() throws InterruptedException, IOE userDataInDatabase, false); assert sessionInfo.refreshToken != null; assert sessionInfo.accessToken != null; - assertEquals(StorageLayer.getSessionStorage(process.getProcess()) + assertEquals(((SessionStorage) StorageLayer.getStorage(process.getProcess())) .getNumberOfSessions(new TenantIdentifier(null, null, null)), 2); SessionInformationHolder newRefreshedSession = Session.refreshSession(main, sessionInfo.refreshToken.token, @@ -800,7 +801,7 @@ public void refreshTokenExpiresAfterShortTime() throws InterruptedException, IOE assert newRefreshedSession.accessToken != null; assertNotEquals(newRefreshedSession.accessToken.token, sessionInfo.accessToken.token); assertNotEquals(newRefreshedSession.refreshToken.token, sessionInfo.refreshToken.token); - assertEquals(StorageLayer.getSessionStorage(process.getProcess()) + assertEquals(((SessionStorage) StorageLayer.getStorage(process.getProcess())) .getNumberOfSessions(new TenantIdentifier(null, null, null)), 2); Thread.sleep(500); @@ -811,7 +812,7 @@ public void refreshTokenExpiresAfterShortTime() throws InterruptedException, IOE assert newRefreshedSession2.accessToken != null; assertNotEquals(newRefreshedSession.accessToken.token, newRefreshedSession2.accessToken.token); assertNotEquals(newRefreshedSession.refreshToken.token, newRefreshedSession2.refreshToken.token); - assertEquals(StorageLayer.getSessionStorage(process.getProcess()) + assertEquals(((SessionStorage) StorageLayer.getStorage(process.getProcess())) .getNumberOfSessions(new TenantIdentifier(null, null, null)), 2); Thread.sleep(500); @@ -822,7 +823,7 @@ public void refreshTokenExpiresAfterShortTime() throws InterruptedException, IOE assert newRefreshedSession3.accessToken != null; assertNotEquals(newRefreshedSession3.accessToken.token, newRefreshedSession2.accessToken.token); assertNotEquals(newRefreshedSession3.refreshToken.token, newRefreshedSession2.refreshToken.token); - assertEquals(StorageLayer.getSessionStorage(process.getProcess()) + assertEquals(((SessionStorage) StorageLayer.getStorage(process.getProcess())) .getNumberOfSessions(new TenantIdentifier(null, null, null)), 2); } diff --git a/src/test/java/io/supertokens/test/StorageTest.java b/src/test/java/io/supertokens/test/StorageTest.java index 8c4656dde..7e1ba80a3 100644 --- a/src/test/java/io/supertokens/test/StorageTest.java +++ b/src/test/java/io/supertokens/test/StorageTest.java @@ -26,6 +26,7 @@ import io.supertokens.pluginInterface.exceptions.StorageQueryException; import io.supertokens.pluginInterface.exceptions.StorageTransactionLogicException; import io.supertokens.pluginInterface.multitenancy.TenantIdentifier; +import io.supertokens.pluginInterface.multitenancy.exceptions.TenantOrAppNotFoundException; import io.supertokens.pluginInterface.noSqlStorage.NoSQLStorage_1; import io.supertokens.pluginInterface.sqlStorage.SQLStorage; import io.supertokens.storageLayer.StorageLayer; @@ -113,8 +114,12 @@ public void transactionIsolationWithoutAnInitialRowTesting() throws Exception { } if (info == null) { - sqlStorage.setKeyValue_Transaction(new TenantIdentifier(null, null, null), con, key, - new KeyValueInfo("Value1")); + try { + sqlStorage.setKeyValue_Transaction(new TenantIdentifier(null, null, null), con, key, + new KeyValueInfo("Value1")); + } catch (TenantOrAppNotFoundException e) { + throw new IllegalStateException(e); + } } else { endValueOfCon1.set(info.value); return null; @@ -147,8 +152,12 @@ public void transactionIsolationWithoutAnInitialRowTesting() throws Exception { } if (info == null) { - sqlStorage.setKeyValue_Transaction(new TenantIdentifier(null, null, null), con, key, - new KeyValueInfo("Value2")); + try { + sqlStorage.setKeyValue_Transaction(new TenantIdentifier(null, null, null), con, key, + new KeyValueInfo("Value2")); + } catch (TenantOrAppNotFoundException e) { + throw new IllegalStateException(e); + } } else { endValueOfCon2.set(info.value); return null; @@ -194,8 +203,12 @@ public void transactionIsolationWithAnInitialRowTesting() if (storage.getType() == STORAGE_TYPE.SQL) { SQLStorage sqlStorage = (SQLStorage) storage; sqlStorage.startTransaction(con -> { - sqlStorage.setKeyValue_Transaction(new TenantIdentifier(null, null, null), con, "Key", - new KeyValueInfo("Value")); + try { + sqlStorage.setKeyValue_Transaction(new TenantIdentifier(null, null, null), con, "Key", + new KeyValueInfo("Value")); + } catch (TenantOrAppNotFoundException e) { + throw new IllegalStateException(e); + } sqlStorage.commitTransaction(con); return null; }); @@ -213,8 +226,12 @@ public void transactionIsolationWithAnInitialRowTesting() new TenantIdentifier(null, null, null), con, "Key"); if (info.value.equals("Value")) { - sqlStorage.setKeyValue_Transaction(new TenantIdentifier(null, null, null), con, "Key", - new KeyValueInfo("Value1")); + try { + sqlStorage.setKeyValue_Transaction(new TenantIdentifier(null, null, null), con, "Key", + new KeyValueInfo("Value1")); + } catch (TenantOrAppNotFoundException e) { + throw new IllegalStateException(e); + } } else { endValueOfCon1.set(info.value); return null; @@ -236,8 +253,12 @@ public void transactionIsolationWithAnInitialRowTesting() new TenantIdentifier(null, null, null), con, "Key"); if (info.value.equals("Value")) { - sqlStorage.setKeyValue_Transaction(new TenantIdentifier(null, null, null), con, "Key", - new KeyValueInfo("Value2")); + try { + sqlStorage.setKeyValue_Transaction(new TenantIdentifier(null, null, null), con, "Key", + new KeyValueInfo("Value2")); + } catch (TenantOrAppNotFoundException e) { + throw new IllegalStateException(e); + } } else { endValueOfCon2.set(info.value); return null; @@ -287,8 +308,12 @@ public void transactionIsolationTesting() if (storage.getType() == STORAGE_TYPE.SQL) { SQLStorage sqlStorage = (SQLStorage) storage; sqlStorage.startTransaction(con -> { - sqlStorage.setKeyValue_Transaction(new TenantIdentifier(null, null, null), con, "Key", - new KeyValueInfo("Value")); + try { + sqlStorage.setKeyValue_Transaction(new TenantIdentifier(null, null, null), con, "Key", + new KeyValueInfo("Value")); + } catch (TenantOrAppNotFoundException e) { + throw new IllegalStateException(e); + } sqlStorage.commitTransaction(con); return null; }); @@ -311,8 +336,12 @@ public void transactionIsolationTesting() syncObject.notifyAll(); } - sqlStorage.setKeyValue_Transaction(new TenantIdentifier(null, null, null), con, "Key", - new KeyValueInfo("Value2")); + try { + sqlStorage.setKeyValue_Transaction(new TenantIdentifier(null, null, null), con, "Key", + new KeyValueInfo("Value2")); + } catch (TenantOrAppNotFoundException e) { + throw new IllegalStateException(e); + } try { Thread.sleep(1500); @@ -504,8 +533,12 @@ public void transactionTest() throws InterruptedException, StorageQueryException if (storage.getType() == STORAGE_TYPE.SQL) { SQLStorage sqlStorage = (SQLStorage) storage; String returnedValue = sqlStorage.startTransaction(con -> { - sqlStorage.setKeyValue_Transaction(new TenantIdentifier(null, null, null), con, "Key", - new KeyValueInfo("Value")); + try { + sqlStorage.setKeyValue_Transaction(new TenantIdentifier(null, null, null), con, "Key", + new KeyValueInfo("Value")); + } catch (TenantOrAppNotFoundException e) { + throw new IllegalStateException(e); + } sqlStorage.commitTransaction(con); return "returned value"; }); @@ -552,8 +585,12 @@ public void transactionDoNotCommitButStillCommitsTest() if (storage.getType() == STORAGE_TYPE.SQL) { SQLStorage sqlStorage = (SQLStorage) storage; sqlStorage.startTransaction(con -> { - sqlStorage.setKeyValue_Transaction(new TenantIdentifier(null, null, null), con, "Key", - new KeyValueInfo("Value")); + try { + sqlStorage.setKeyValue_Transaction(new TenantIdentifier(null, null, null), con, "Key", + new KeyValueInfo("Value")); + } catch (TenantOrAppNotFoundException e) { + throw new IllegalStateException(e); + } return null; }); KeyValueInfo value = storage.getKeyValue(new TenantIdentifier(null, null, null), "Key"); @@ -609,8 +646,12 @@ public void transactionThrowCompileTimeErrorAndExpectRollbackTest() SQLStorage sqlStorage = (SQLStorage) storage; try { sqlStorage.startTransaction(con -> { - sqlStorage.setKeyValue_Transaction(new TenantIdentifier(null, null, null), con, "Key", - new KeyValueInfo("Value")); + try { + sqlStorage.setKeyValue_Transaction(new TenantIdentifier(null, null, null), con, "Key", + new KeyValueInfo("Value")); + } catch (TenantOrAppNotFoundException e) { + throw new IllegalStateException(e); + } throw new StorageTransactionLogicException(new Exception("error message")); }); } catch (StorageTransactionLogicException e) { @@ -642,8 +683,12 @@ public void transactionThrowRunTimeErrorAndExpectRollbackTest() SQLStorage sqlStorage = (SQLStorage) storage; try { sqlStorage.startTransaction(con -> { - sqlStorage.setKeyValue_Transaction(new TenantIdentifier(null, null, null), con, "Key", - new KeyValueInfo("Value")); + try { + sqlStorage.setKeyValue_Transaction(new TenantIdentifier(null, null, null), con, "Key", + new KeyValueInfo("Value")); + } catch (TenantOrAppNotFoundException e) { + throw new IllegalStateException(e); + } throw new RuntimeException("error message"); }); } catch (RuntimeException e) { diff --git a/src/test/java/io/supertokens/test/multitenant/SigningKeysTest.java b/src/test/java/io/supertokens/test/multitenant/SigningKeysTest.java index 1fbf32ad6..6f240eaf3 100644 --- a/src/test/java/io/supertokens/test/multitenant/SigningKeysTest.java +++ b/src/test/java/io/supertokens/test/multitenant/SigningKeysTest.java @@ -22,6 +22,12 @@ import io.supertokens.config.Config; import io.supertokens.featureflag.EE_FEATURES; import io.supertokens.featureflag.FeatureFlagTestContent; +import io.supertokens.featureflag.exceptions.FeatureNotEnabledException; +import io.supertokens.multitenancy.Multitenancy; +import io.supertokens.multitenancy.MultitenancyHelper; +import io.supertokens.multitenancy.exception.BadPermissionException; +import io.supertokens.multitenancy.exception.CannotModifyBaseConfigException; +import io.supertokens.multitenancy.exception.DeletionInProgressException; import io.supertokens.pluginInterface.exceptions.DbInitException; import io.supertokens.pluginInterface.exceptions.InvalidConfigException; import io.supertokens.pluginInterface.exceptions.StorageQueryException; @@ -32,6 +38,7 @@ import io.supertokens.storageLayer.StorageLayer; import io.supertokens.test.TestingProcessManager; import io.supertokens.test.Utils; +import io.supertokens.thirdparty.InvalidProviderConfigException; import org.junit.AfterClass; import org.junit.Before; import org.junit.Rule; @@ -83,7 +90,9 @@ public void normalConfigContinuesToWork() @Test public void keysAreGeneratedForAllUserPoolIds() throws InterruptedException, IOException, StorageQueryException, StorageTransactionLogicException, - InvalidConfigException, DbInitException, TenantOrAppNotFoundException { + InvalidConfigException, DbInitException, TenantOrAppNotFoundException, InvalidProviderConfigException, + DeletionInProgressException, FeatureNotEnabledException, CannotModifyBaseConfigException, + BadPermissionException { String[] args = {"../"}; TestingProcessManager.TestingProcess process = TestingProcessManager.start(args, false); @@ -103,9 +112,9 @@ public void keysAreGeneratedForAllUserPoolIds() new PasswordlessConfig(false), tenantConfig)}; - Config.loadAllTenantConfig(process.getProcess(), tenants); - - StorageLayer.loadAllTenantStorage(process.getProcess(), tenants); + for (TenantConfig config : tenants) { + Multitenancy.addNewOrUpdateAppOrTenant(process.getProcess(), new TenantIdentifier(null, null, null), config); + } List apps = new ArrayList<>(); for (TenantConfig t : tenants) { @@ -138,7 +147,9 @@ public void keysAreGeneratedForAllUserPoolIds() @Test public void signingKeyClassesAreThereForAllTenants() throws InterruptedException, IOException, InvalidConfigException, DbInitException, StorageQueryException, - StorageTransactionLogicException, TenantOrAppNotFoundException { + StorageTransactionLogicException, TenantOrAppNotFoundException, InvalidProviderConfigException, + DeletionInProgressException, FeatureNotEnabledException, CannotModifyBaseConfigException, + BadPermissionException { String[] args = {"../"}; TestingProcessManager.TestingProcess process = TestingProcessManager.start(args, false); @@ -166,9 +177,9 @@ public void signingKeyClassesAreThereForAllTenants() new PasswordlessConfig(false), tenantConfig2)}; - Config.loadAllTenantConfig(process.getProcess(), tenants); - - StorageLayer.loadAllTenantStorage(process.getProcess(), tenants); + for (TenantConfig config : tenants) { + Multitenancy.addNewOrUpdateAppOrTenant(process.getProcess(), new TenantIdentifier(null, null, null), config); + } List apps = new ArrayList<>(); for (TenantConfig t : tenants) { diff --git a/src/test/java/io/supertokens/test/session/AccessTokenSigningKeyTest.java b/src/test/java/io/supertokens/test/session/AccessTokenSigningKeyTest.java index 8e47689e9..fda2ffa4c 100644 --- a/src/test/java/io/supertokens/test/session/AccessTokenSigningKeyTest.java +++ b/src/test/java/io/supertokens/test/session/AccessTokenSigningKeyTest.java @@ -71,7 +71,7 @@ public void legacySigningKeysAreMigratedProperly() throws InterruptedException, io.supertokens.utils.Utils.PubPriKey rsaKeys = io.supertokens.utils.Utils.generateNewPubPriKey(); String signingKey = rsaKeys.toString(); KeyValueInfo newKey = new KeyValueInfo(signingKey, System.currentTimeMillis()); - SessionStorage sessionStorage = StorageLayer.getSessionStorage(process.getProcess()); + SessionStorage sessionStorage = (SessionStorage) StorageLayer.getStorage(process.getProcess()); sessionStorage.setKeyValue(new TenantIdentifier(null, null, null), "access_token_signing_key", newKey); AccessTokenSigningKey accessTokenSigningKeyInstance = AccessTokenSigningKey.getInstance(process.getProcess()); accessTokenSigningKeyInstance.transferLegacyKeyToNewTable(); @@ -103,7 +103,7 @@ public void getAllKeysReturnsOrdered() KeyValueInfo legacyKey = new KeyValueInfo(signingKey, System.currentTimeMillis() - 2000); // 2 seconds in the // past - SessionStorage sessionStorage = StorageLayer.getSessionStorage(process.getProcess()); + SessionStorage sessionStorage = (SessionStorage) StorageLayer.getStorage(process.getProcess()); sessionStorage.setKeyValue(new TenantIdentifier(null, null, null), "access_token_signing_key", legacyKey); AccessTokenSigningKey accessTokenSigningKeyInstance = AccessTokenSigningKey.getInstance(process.getProcess()); @@ -183,7 +183,7 @@ public void migratingStaticSigningKeys() String signingKey = rsaKeys.toString(); KeyValueInfo legacyKey = new KeyValueInfo(signingKey, System.currentTimeMillis() - 2629743830l); // 1 month old - SessionStorage sessionStorage = StorageLayer.getSessionStorage(process.getProcess()); + SessionStorage sessionStorage = (SessionStorage) StorageLayer.getStorage(process.getProcess()); sessionStorage.setKeyValue(new TenantIdentifier(null, null, null), "access_token_signing_key", legacyKey); AccessTokenSigningKey accessTokenSigningKeyInstance = AccessTokenSigningKey.getInstance(process.getProcess()); diff --git a/src/test/java/io/supertokens/test/session/AccessTokenTest.java b/src/test/java/io/supertokens/test/session/AccessTokenTest.java index cd83b31bd..426c5e370 100644 --- a/src/test/java/io/supertokens/test/session/AccessTokenTest.java +++ b/src/test/java/io/supertokens/test/session/AccessTokenTest.java @@ -19,6 +19,7 @@ import com.google.gson.JsonObject; import io.supertokens.ProcessState.EventAndException; import io.supertokens.ProcessState.PROCESS_STATE; +import io.supertokens.exceptions.UnauthorisedException; import io.supertokens.pluginInterface.multitenancy.exceptions.TenantOrAppNotFoundException; import io.supertokens.exceptions.TryRefreshTokenException; import io.supertokens.pluginInterface.exceptions.StorageQueryException; @@ -95,7 +96,7 @@ public void testCreateSessionWithDataExpireGetAccessTokenAndCheckPayload() throw // get access token without verifying assert sessionInfo.accessToken != null; AccessToken.AccessTokenInfo accessTokenInfo = AccessToken - .getInfoFromAccessTokenWithoutVerifying(sessionInfo.accessToken.token); + .getInfoFromAccessTokenWithoutVerifying(process.getProcess(), sessionInfo.accessToken.token); // check payload is fine assertEquals(accessTokenInfo.userData, userDataInJWT); @@ -138,7 +139,7 @@ public void testSessionWithOldExpiryTimeForAccessToken() throws Exception { accessTokenInfo.lmrt, value); AccessTokenInfo customAccessToken = AccessToken - .getInfoFromAccessTokenWithoutVerifying(newAccessTokenInfo.token); + .getInfoFromAccessTokenWithoutVerifying(process.getProcess(), newAccessTokenInfo.token); assertEquals(customAccessToken.expiryTime, value); } @@ -176,7 +177,7 @@ public void testCreateAccessTokenVersion2AndCheck() throws Exception { @Test public void inputOutputTest() throws InterruptedException, InvalidKeyException, NoSuchAlgorithmException, StorageQueryException, StorageTransactionLogicException, TryRefreshTokenException, - UnsupportedEncodingException, InvalidKeySpecException, SignatureException { + UnsupportedEncodingException, InvalidKeySpecException, SignatureException, UnauthorisedException { String[] args = {"../"}; TestingProcess process = TestingProcessManager.start(args); EventAndException e = process.checkOrWaitForEvent(PROCESS_STATE.STARTED); @@ -205,7 +206,7 @@ public void inputOutputTest() throws InterruptedException, InvalidKeyException, @Test public void inputOutputTestv1() throws InterruptedException, InvalidKeyException, NoSuchAlgorithmException, StorageQueryException, StorageTransactionLogicException, TryRefreshTokenException, - UnsupportedEncodingException, InvalidKeySpecException, SignatureException { + UnsupportedEncodingException, InvalidKeySpecException, SignatureException, UnauthorisedException { String[] args = {"../"}; TestingProcess process = TestingProcessManager.start(args); EventAndException e = process.checkOrWaitForEvent(PROCESS_STATE.STARTED); @@ -249,7 +250,8 @@ public void signingKeyShortInterval() @Test public void signingKeyChangeDoesNotThrow() throws IOException, InterruptedException, InvalidKeyException, NoSuchAlgorithmException, - StorageQueryException, StorageTransactionLogicException, InvalidKeySpecException, SignatureException { + StorageQueryException, StorageTransactionLogicException, InvalidKeySpecException, SignatureException, + UnauthorisedException { Utils.setValueInConfig("access_token_signing_key_update_interval", "0.00027"); // 1 second String[] args = {"../"}; @@ -276,7 +278,8 @@ public void signingKeyChangeDoesNotThrow() @Test public void accessTokenShortLifetimeThrowsRefreshTokenError() throws IOException, InterruptedException, InvalidKeyException, NoSuchAlgorithmException, - StorageQueryException, StorageTransactionLogicException, InvalidKeySpecException, SignatureException { + StorageQueryException, StorageTransactionLogicException, InvalidKeySpecException, SignatureException, + UnauthorisedException { Utils.setValueInConfig("access_token_validity", "1"); // 1 second String[] args = {"../"}; @@ -305,7 +308,8 @@ public void accessTokenShortLifetimeThrowsRefreshTokenError() @Test public void verifyRandomAccessTokenFailure() - throws InterruptedException, StorageQueryException, StorageTransactionLogicException { + throws InterruptedException, StorageQueryException, StorageTransactionLogicException, UnauthorisedException, + NoSuchAlgorithmException { String[] args = {"../"}; TestingProcess process = TestingProcessManager.start(args); assertNotNull(process.checkOrWaitForEvent(PROCESS_STATE.STARTED)); diff --git a/src/test/java/io/supertokens/test/session/DeleteExpiredAccessTokenSigningKeysTest.java b/src/test/java/io/supertokens/test/session/DeleteExpiredAccessTokenSigningKeysTest.java index 08d5b1311..838ec4e16 100644 --- a/src/test/java/io/supertokens/test/session/DeleteExpiredAccessTokenSigningKeysTest.java +++ b/src/test/java/io/supertokens/test/session/DeleteExpiredAccessTokenSigningKeysTest.java @@ -23,6 +23,7 @@ import io.supertokens.pluginInterface.KeyValueInfo; import io.supertokens.pluginInterface.STORAGE_TYPE; import io.supertokens.pluginInterface.multitenancy.AppIdentifier; +import io.supertokens.pluginInterface.multitenancy.exceptions.TenantOrAppNotFoundException; import io.supertokens.pluginInterface.session.SessionStorage; import io.supertokens.pluginInterface.session.sqlStorage.SessionSQLStorage; import io.supertokens.storageLayer.StorageLayer; @@ -62,7 +63,7 @@ public void jobCleansOldKeysTest() throws Exception { assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); - SessionStorage sessionStorage = StorageLayer.getSessionStorage(process.getProcess()); + SessionStorage sessionStorage = (SessionStorage) StorageLayer.getStorage(process.getProcess()); if (sessionStorage.getType() != STORAGE_TYPE.SQL) { return; @@ -72,21 +73,25 @@ public void jobCleansOldKeysTest() throws Exception { SessionSQLStorage sqlStorage = (SessionSQLStorage) sessionStorage; sqlStorage.startTransaction(con -> { - sqlStorage.addAccessTokenSigningKey_Transaction(new AppIdentifier(null, null), con, - new KeyValueInfo("clean!", 100)); - sqlStorage.addAccessTokenSigningKey_Transaction(new AppIdentifier(null, null), con, - new KeyValueInfo("clean!", - System.currentTimeMillis() - signingKeyUpdateInterval - 3 * accessTokenValidity)); - sqlStorage.addAccessTokenSigningKey_Transaction(new AppIdentifier(null, null), con, - new KeyValueInfo("clean!", - System.currentTimeMillis() - signingKeyUpdateInterval - 2 * accessTokenValidity)); - sqlStorage.addAccessTokenSigningKey_Transaction(new AppIdentifier(null, null), con, - new KeyValueInfo("keep!", - System.currentTimeMillis() - signingKeyUpdateInterval - 1 * accessTokenValidity)); - sqlStorage.addAccessTokenSigningKey_Transaction(new AppIdentifier(null, null), con, - new KeyValueInfo("keep!", System.currentTimeMillis() - signingKeyUpdateInterval)); - sqlStorage.addAccessTokenSigningKey_Transaction(new AppIdentifier(null, null), con, - new KeyValueInfo("keep!", System.currentTimeMillis())); + try { + sqlStorage.addAccessTokenSigningKey_Transaction(new AppIdentifier(null, null), con, + new KeyValueInfo("clean!", 100)); + sqlStorage.addAccessTokenSigningKey_Transaction(new AppIdentifier(null, null), con, + new KeyValueInfo("clean!", + System.currentTimeMillis() - signingKeyUpdateInterval - 3 * accessTokenValidity)); + sqlStorage.addAccessTokenSigningKey_Transaction(new AppIdentifier(null, null), con, + new KeyValueInfo("clean!", + System.currentTimeMillis() - signingKeyUpdateInterval - 2 * accessTokenValidity)); + sqlStorage.addAccessTokenSigningKey_Transaction(new AppIdentifier(null, null), con, + new KeyValueInfo("keep!", + System.currentTimeMillis() - signingKeyUpdateInterval - 1 * accessTokenValidity)); + sqlStorage.addAccessTokenSigningKey_Transaction(new AppIdentifier(null, null), con, + new KeyValueInfo("keep!", System.currentTimeMillis() - signingKeyUpdateInterval)); + sqlStorage.addAccessTokenSigningKey_Transaction(new AppIdentifier(null, null), con, + new KeyValueInfo("keep!", System.currentTimeMillis())); + } catch (TenantOrAppNotFoundException e) { + throw new IllegalStateException(e); + } return true; }); @@ -118,7 +123,7 @@ public void jobKeepsOldKeysIfNotDynamicTest() throws Exception { assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); - SessionStorage sessionStorage = StorageLayer.getSessionStorage(process.getProcess()); + SessionStorage sessionStorage = (SessionStorage) StorageLayer.getStorage(process.getProcess()); if (sessionStorage.getType() != STORAGE_TYPE.SQL) { return; @@ -128,21 +133,25 @@ public void jobKeepsOldKeysIfNotDynamicTest() throws Exception { SessionSQLStorage sqlStorage = (SessionSQLStorage) sessionStorage; sqlStorage.startTransaction(con -> { - sqlStorage.addAccessTokenSigningKey_Transaction(new AppIdentifier(null, null), con, - new KeyValueInfo("clean!", 100)); - sqlStorage.addAccessTokenSigningKey_Transaction(new AppIdentifier(null, null), con, - new KeyValueInfo("clean!", - System.currentTimeMillis() - signingKeyUpdateInterval - 3 * accessTokenValidity)); - sqlStorage.addAccessTokenSigningKey_Transaction(new AppIdentifier(null, null), con, - new KeyValueInfo("clean!", - System.currentTimeMillis() - signingKeyUpdateInterval - 2 * accessTokenValidity)); - sqlStorage.addAccessTokenSigningKey_Transaction(new AppIdentifier(null, null), con, - new KeyValueInfo("keep!", - System.currentTimeMillis() - signingKeyUpdateInterval - 1 * accessTokenValidity)); - sqlStorage.addAccessTokenSigningKey_Transaction(new AppIdentifier(null, null), con, - new KeyValueInfo("keep!", System.currentTimeMillis() - signingKeyUpdateInterval)); - sqlStorage.addAccessTokenSigningKey_Transaction(new AppIdentifier(null, null), con, - new KeyValueInfo("keep!", System.currentTimeMillis())); + try { + sqlStorage.addAccessTokenSigningKey_Transaction(new AppIdentifier(null, null), con, + new KeyValueInfo("clean!", 100)); + sqlStorage.addAccessTokenSigningKey_Transaction(new AppIdentifier(null, null), con, + new KeyValueInfo("clean!", + System.currentTimeMillis() - signingKeyUpdateInterval - 3 * accessTokenValidity)); + sqlStorage.addAccessTokenSigningKey_Transaction(new AppIdentifier(null, null), con, + new KeyValueInfo("clean!", + System.currentTimeMillis() - signingKeyUpdateInterval - 2 * accessTokenValidity)); + sqlStorage.addAccessTokenSigningKey_Transaction(new AppIdentifier(null, null), con, + new KeyValueInfo("keep!", + System.currentTimeMillis() - signingKeyUpdateInterval - 1 * accessTokenValidity)); + sqlStorage.addAccessTokenSigningKey_Transaction(new AppIdentifier(null, null), con, + new KeyValueInfo("keep!", System.currentTimeMillis() - signingKeyUpdateInterval)); + sqlStorage.addAccessTokenSigningKey_Transaction(new AppIdentifier(null, null), con, + new KeyValueInfo("keep!", System.currentTimeMillis())); + } catch (TenantOrAppNotFoundException e) { + throw new IllegalStateException(e); + } return true; }); diff --git a/src/test/java/io/supertokens/test/session/RegenerateTokenTest.java b/src/test/java/io/supertokens/test/session/RegenerateTokenTest.java index 87c4cd915..705f70645 100644 --- a/src/test/java/io/supertokens/test/session/RegenerateTokenTest.java +++ b/src/test/java/io/supertokens/test/session/RegenerateTokenTest.java @@ -253,7 +253,7 @@ public void testSessionRegenerateWithTokenExpiryAndRefresh() throws Exception { assert getSessionResponse.accessToken != null; AccessToken.AccessTokenInfo accessTokenInfoAfter = AccessToken - .getInfoFromAccessTokenWithoutVerifying(getSessionResponse.accessToken.token); + .getInfoFromAccessTokenWithoutVerifying(process.getProcess(), getSessionResponse.accessToken.token); assertEquals(accessTokenInfoAfter.userData, newUserDataInJWT); assertNotEquals(accessTokenInfoAfter.expiryTime, accessTokenInfoBefore.expiryTime); // expiry time is different diff --git a/src/test/java/io/supertokens/test/session/SessionTest1.java b/src/test/java/io/supertokens/test/session/SessionTest1.java index ce3c33ffa..1ba683567 100644 --- a/src/test/java/io/supertokens/test/session/SessionTest1.java +++ b/src/test/java/io/supertokens/test/session/SessionTest1.java @@ -27,6 +27,7 @@ import io.supertokens.pluginInterface.exceptions.StorageQueryException; import io.supertokens.pluginInterface.exceptions.StorageTransactionLogicException; import io.supertokens.pluginInterface.multitenancy.TenantIdentifier; +import io.supertokens.pluginInterface.session.SessionStorage; import io.supertokens.session.Session; import io.supertokens.session.info.SessionInformationHolder; import io.supertokens.storageLayer.StorageLayer; @@ -88,7 +89,7 @@ public void createAndGetSession() throws InterruptedException, StorageQueryExcep assertEquals(sessionInfo.session.userId, userId); assertEquals(sessionInfo.session.userDataInJWT.toString(), userDataInJWT.toString()); - assertEquals(StorageLayer.getSessionStorage(process.getProcess()) + assertEquals(((SessionStorage) StorageLayer.getStorage(process.getProcess())) .getNumberOfSessions(new TenantIdentifier(null, null, null)), 1); assert sessionInfo.accessToken != null; assertNull(sessionInfo.antiCsrfToken); @@ -127,7 +128,7 @@ public void createAndGetSessionNoAntiCSRF() throws InterruptedException, Storage assertEquals(sessionInfo.session.userId, userId); assertEquals(sessionInfo.session.userDataInJWT.toString(), userDataInJWT.toString()); - assertEquals(StorageLayer.getSessionStorage(process.getProcess()) + assertEquals(((SessionStorage) StorageLayer.getStorage(process.getProcess())) .getNumberOfSessions(new TenantIdentifier(null, null, null)), 1); assert sessionInfo.accessToken != null; assertNull(sessionInfo.antiCsrfToken); @@ -201,7 +202,7 @@ public void createNewSessionAndAlterJWTPayload() throws InterruptedException, St assertEquals(sessionInfo.session.userId, userId); assertEquals(sessionInfo.session.userDataInJWT.toString(), userDataInJWT.toString()); - assertEquals(StorageLayer.getSessionStorage(process.getProcess()) + assertEquals(((SessionStorage) StorageLayer.getStorage(process.getProcess())) .getNumberOfSessions(new TenantIdentifier(null, null, null)), 1); assert sessionInfo.accessToken != null; assertNull(sessionInfo.antiCsrfToken); @@ -243,7 +244,7 @@ public void createAndGetSessionWithEmptyJWTPayload() throws InterruptedException assertEquals(sessionInfo.session.userId, userId); assertEquals(sessionInfo.session.userDataInJWT.toString(), userDataInJWT.toString()); - assertEquals(StorageLayer.getSessionStorage(process.getProcess()) + assertEquals(((SessionStorage) StorageLayer.getStorage(process.getProcess())) .getNumberOfSessions(new TenantIdentifier(null, null, null)), 1); assert sessionInfo.accessToken != null; assertNull(sessionInfo.antiCsrfToken); @@ -289,7 +290,7 @@ public void createAndGetSessionWithComplexJWTPayload() throws InterruptedExcepti assertEquals(sessionInfo.session.userId, userId); assertEquals(sessionInfo.session.userDataInJWT.toString(), userDataInJWT.toString()); - assertEquals(StorageLayer.getSessionStorage(process.getProcess()) + assertEquals(((SessionStorage) StorageLayer.getStorage(process.getProcess())) .getNumberOfSessions(new TenantIdentifier(null, null, null)), 1); assert sessionInfo.accessToken != null; assertNull(sessionInfo.antiCsrfToken); @@ -412,7 +413,7 @@ public void refreshSessionTestWithAntiCsrf() throws IOException, InterruptedExce assertEquals(refreshedSession.session.handle, sessionInfo.session.handle); assertEquals(refreshedSession.session.userId, sessionInfo.session.userId); assertEquals(refreshedSession.session.userDataInJWT.toString(), sessionInfo.session.userDataInJWT.toString()); - assertEquals(StorageLayer.getSessionStorage(process.getProcess()) + assertEquals(((SessionStorage) StorageLayer.getStorage(process.getProcess())) .getNumberOfSessions(new TenantIdentifier(null, null, null)), 1); SessionInformationHolder newSession = Session.getSession(process.getProcess(), @@ -439,7 +440,7 @@ public void refreshSessionTestWithAntiCsrf() throws IOException, InterruptedExce SessionInformationHolder refreshedSession2 = Session.refreshSession(process.getProcess(), refreshedSession.refreshToken.token, refreshedSession.antiCsrfToken, true); - assertEquals(StorageLayer.getSessionStorage(process.getProcess()) + assertEquals(((SessionStorage) StorageLayer.getStorage(process.getProcess())) .getNumberOfSessions(new TenantIdentifier(null, null, null)), 1); assert refreshedSession2.accessToken != null; @@ -503,7 +504,7 @@ public void refreshSessionTestWithNoAntiCsrf() throws IOException, InterruptedEx assertEquals(refreshedSession.session.handle, sessionInfo.session.handle); assertEquals(refreshedSession.session.userId, sessionInfo.session.userId); assertEquals(refreshedSession.session.userDataInJWT.toString(), sessionInfo.session.userDataInJWT.toString()); - assertEquals(StorageLayer.getSessionStorage(process.getProcess()) + assertEquals(((SessionStorage) StorageLayer.getStorage(process.getProcess())) .getNumberOfSessions(new TenantIdentifier(null, null, null)), 1); SessionInformationHolder newSession = Session.getSession(process.getProcess(), @@ -529,7 +530,7 @@ public void refreshSessionTestWithNoAntiCsrf() throws IOException, InterruptedEx SessionInformationHolder refreshedSession2 = Session.refreshSession(process.getProcess(), refreshedSession.refreshToken.token, refreshedSession.antiCsrfToken, false); - assertEquals(StorageLayer.getSessionStorage(process.getProcess()) + assertEquals(((SessionStorage) StorageLayer.getStorage(process.getProcess())) .getNumberOfSessions(new TenantIdentifier(null, null, null)), 1); assert refreshedSession2.accessToken != null; @@ -610,7 +611,7 @@ public void refreshTokenExpiresAfterShortTime() throws InterruptedException, IOE sessionInfo.antiCsrfToken, false); assert newRefreshedSession.refreshToken != null; - assertEquals(StorageLayer.getSessionStorage(process.getProcess()) + assertEquals(((SessionStorage) StorageLayer.getStorage(process.getProcess())) .getNumberOfSessions(new TenantIdentifier(null, null, null)), 1); Session.getSession(main, sessionInfo.accessToken.token, sessionInfo.antiCsrfToken, false, true); @@ -624,7 +625,7 @@ public void refreshTokenExpiresAfterShortTime() throws InterruptedException, IOE } catch (UnauthorisedException ignored) { } - assertEquals(StorageLayer.getSessionStorage(process.getProcess()) + assertEquals(((SessionStorage) StorageLayer.getStorage(process.getProcess())) .getNumberOfSessions(new TenantIdentifier(null, null, null)), 1); } @@ -634,7 +635,7 @@ public void refreshTokenExpiresAfterShortTime() throws InterruptedException, IOE userDataInDatabase, false); assert sessionInfo.refreshToken != null; assert sessionInfo.accessToken != null; - assertEquals(StorageLayer.getSessionStorage(process.getProcess()) + assertEquals(((SessionStorage) StorageLayer.getStorage(process.getProcess())) .getNumberOfSessions(new TenantIdentifier(null, null, null)), 2); SessionInformationHolder newRefreshedSession = Session.refreshSession(main, sessionInfo.refreshToken.token, @@ -643,7 +644,7 @@ public void refreshTokenExpiresAfterShortTime() throws InterruptedException, IOE assert newRefreshedSession.accessToken != null; assertNotEquals(newRefreshedSession.accessToken.token, sessionInfo.accessToken.token); assertNotEquals(newRefreshedSession.refreshToken.token, sessionInfo.refreshToken.token); - assertEquals(StorageLayer.getSessionStorage(process.getProcess()) + assertEquals(((SessionStorage) StorageLayer.getStorage(process.getProcess())) .getNumberOfSessions(new TenantIdentifier(null, null, null)), 2); Thread.sleep(500); @@ -654,7 +655,7 @@ public void refreshTokenExpiresAfterShortTime() throws InterruptedException, IOE assert newRefreshedSession2.accessToken != null; assertNotEquals(newRefreshedSession.accessToken.token, newRefreshedSession2.accessToken.token); assertNotEquals(newRefreshedSession.refreshToken.token, newRefreshedSession2.refreshToken.token); - assertEquals(StorageLayer.getSessionStorage(process.getProcess()) + assertEquals(((SessionStorage) StorageLayer.getStorage(process.getProcess())) .getNumberOfSessions(new TenantIdentifier(null, null, null)), 2); Thread.sleep(500); @@ -665,7 +666,7 @@ public void refreshTokenExpiresAfterShortTime() throws InterruptedException, IOE assert newRefreshedSession3.accessToken != null; assertNotEquals(newRefreshedSession3.accessToken.token, newRefreshedSession2.accessToken.token); assertNotEquals(newRefreshedSession3.refreshToken.token, newRefreshedSession2.refreshToken.token); - assertEquals(StorageLayer.getSessionStorage(process.getProcess()) + assertEquals(((SessionStorage) StorageLayer.getStorage(process.getProcess())) .getNumberOfSessions(new TenantIdentifier(null, null, null)), 2); } diff --git a/src/test/java/io/supertokens/test/session/SessionTest2.java b/src/test/java/io/supertokens/test/session/SessionTest2.java index cf97c8770..0a79333c3 100644 --- a/src/test/java/io/supertokens/test/session/SessionTest2.java +++ b/src/test/java/io/supertokens/test/session/SessionTest2.java @@ -26,6 +26,7 @@ import io.supertokens.pluginInterface.exceptions.StorageQueryException; import io.supertokens.pluginInterface.exceptions.StorageTransactionLogicException; import io.supertokens.pluginInterface.multitenancy.TenantIdentifier; +import io.supertokens.pluginInterface.session.SessionStorage; import io.supertokens.session.Session; import io.supertokens.session.info.SessionInformationHolder; import io.supertokens.storageLayer.StorageLayer; @@ -218,11 +219,11 @@ public void revokeSessionWithoutBlacklisting() Session.createNewSession(process.getProcess(), userId, userDataInJWT, userDataInDatabase, false); - assertEquals(StorageLayer.getSessionStorage(process.getProcess()) + assertEquals(((SessionStorage) StorageLayer.getStorage(process.getProcess())) .getNumberOfSessions(new TenantIdentifier(null, null, null)), 2); Session.revokeSessionUsingSessionHandles(process.getProcess(), new String[]{sessionInfo.session.handle}); - assertEquals(StorageLayer.getSessionStorage(process.getProcess()) + assertEquals(((SessionStorage) StorageLayer.getStorage(process.getProcess())) .getNumberOfSessions(new TenantIdentifier(null, null, null)), 1); try { diff --git a/src/test/java/io/supertokens/test/session/SessionTest3.java b/src/test/java/io/supertokens/test/session/SessionTest3.java index fda14c29d..a1ede805c 100644 --- a/src/test/java/io/supertokens/test/session/SessionTest3.java +++ b/src/test/java/io/supertokens/test/session/SessionTest3.java @@ -26,6 +26,7 @@ import io.supertokens.pluginInterface.exceptions.StorageQueryException; import io.supertokens.pluginInterface.exceptions.StorageTransactionLogicException; import io.supertokens.pluginInterface.multitenancy.TenantIdentifier; +import io.supertokens.pluginInterface.session.SessionStorage; import io.supertokens.session.Session; import io.supertokens.session.info.SessionInformationHolder; import io.supertokens.storageLayer.StorageLayer; @@ -92,11 +93,11 @@ public void revokeSessionWithBlacklistingRefreshSessionAndGetSessionThrows() Session.createNewSession(process.getProcess(), userId, userDataInJWT, userDataInDatabase, false); - assertEquals(StorageLayer.getSessionStorage(process.getProcess()) + assertEquals(((SessionStorage) StorageLayer.getStorage(process.getProcess())) .getNumberOfSessions(new TenantIdentifier(null, null, null)), 2); Session.revokeSessionUsingSessionHandles(process.getProcess(), new String[]{sessionInfo.session.handle}); - assertEquals(StorageLayer.getSessionStorage(process.getProcess()) + assertEquals(((SessionStorage) StorageLayer.getStorage(process.getProcess())) .getNumberOfSessions(new TenantIdentifier(null, null, null)), 1); try { @@ -204,12 +205,12 @@ public void revokeAllSessionsForUserWithoutBlacklisting() throws InterruptedExce Session.createNewSession(process.getProcess(), "userId2", userDataInJWT, userDataInDatabase, false); - assertEquals(StorageLayer.getSessionStorage(process.getProcess()) + assertEquals(((SessionStorage) StorageLayer.getStorage(process.getProcess())) .getNumberOfSessions(new TenantIdentifier(null, null, null)), 4); assertEquals(Session.revokeAllSessionsForUser(process.getProcess(), userId).length, 3); - assertEquals(StorageLayer.getSessionStorage(process.getProcess()) + assertEquals(((SessionStorage) StorageLayer.getStorage(process.getProcess())) .getNumberOfSessions(new TenantIdentifier(null, null, null)), 1); Session.getSession(process.getProcess(), sessionInfo.accessToken.token, sessionInfo.antiCsrfToken, false, true); @@ -260,13 +261,13 @@ public void removeExpiredSessions() assert sessionInfo3.refreshToken != null; assert sessionInfo3.accessToken != null; - assertEquals(StorageLayer.getSessionStorage(process.getProcess()) + assertEquals(((SessionStorage) StorageLayer.getStorage(process.getProcess())) .getNumberOfSessions(new TenantIdentifier(null, null, null)), 3); Thread.sleep(2500); Session.createNewSession(process.getProcess(), userId, userDataInJWT, userDataInDatabase, false); - assertEquals(StorageLayer.getSessionStorage(process.getProcess()) + assertEquals(((SessionStorage) StorageLayer.getStorage(process.getProcess())) .getNumberOfSessions(new TenantIdentifier(null, null, null)), 1); process.kill(); diff --git a/src/test/java/io/supertokens/test/session/SessionTest4.java b/src/test/java/io/supertokens/test/session/SessionTest4.java index 0cb14cb12..e388640c7 100644 --- a/src/test/java/io/supertokens/test/session/SessionTest4.java +++ b/src/test/java/io/supertokens/test/session/SessionTest4.java @@ -25,6 +25,7 @@ import io.supertokens.pluginInterface.exceptions.StorageTransactionLogicException; import io.supertokens.pluginInterface.multitenancy.TenantIdentifier; import io.supertokens.pluginInterface.multitenancy.exceptions.TenantOrAppNotFoundException; +import io.supertokens.pluginInterface.session.SessionStorage; import io.supertokens.session.Session; import io.supertokens.session.accessToken.AccessTokenSigningKey; import io.supertokens.session.info.SessionInformationHolder; @@ -107,7 +108,7 @@ public void checkForNumberOfDeletedSessions() } assertTrue(revokedAll); - assertEquals(StorageLayer.getSessionStorage(process.getProcess()) + assertEquals(((SessionStorage) StorageLayer.getStorage(process.getProcess())) .getNumberOfSessions(new TenantIdentifier(null, null, null)), 2); Session.createNewSession(process.getProcess(), userId, userDataInJWT, userDataInDatabase, false); @@ -116,12 +117,12 @@ public void checkForNumberOfDeletedSessions() assertEquals(Session.revokeAllSessionsForUser(process.getProcess(), userId).length, 4); - assertEquals(StorageLayer.getSessionStorage(process.getProcess()) + assertEquals(((SessionStorage) StorageLayer.getStorage(process.getProcess())) .getNumberOfSessions(new TenantIdentifier(null, null, null)), 1); assertEquals(Session.revokeAllSessionsForUser(process.getProcess(), "userId2").length, 1); - assertEquals(StorageLayer.getSessionStorage(process.getProcess()) + assertEquals(((SessionStorage) StorageLayer.getStorage(process.getProcess())) .getNumberOfSessions(new TenantIdentifier(null, null, null)), 0); assertEquals(Session.revokeSessionUsingSessionHandles(process.getProcess(), handles).length, 0); @@ -184,7 +185,7 @@ public void createVerifyRefreshVerifyRefresh() userDataInDatabase, false); assertEquals(sessionInfo.session.userId, userId); assertEquals(sessionInfo.session.userDataInJWT.toString(), userDataInJWT.toString()); - assertEquals(StorageLayer.getSessionStorage(process.getProcess()) + assertEquals(((SessionStorage) StorageLayer.getStorage(process.getProcess())) .getNumberOfSessions(new TenantIdentifier(null, null, null)), 1); assert sessionInfo.accessToken != null; assert sessionInfo.refreshToken != null; @@ -219,7 +220,7 @@ public void createVerifyRefreshVerifyRefresh() assert sessionInfo.refreshToken != null; assert sessionInfo.accessToken != null; - assertEquals(StorageLayer.getSessionStorage(process.getProcess()) + assertEquals(((SessionStorage) StorageLayer.getStorage(process.getProcess())) .getNumberOfSessions(new TenantIdentifier(null, null, null)), 1); } @@ -248,7 +249,7 @@ public void verifyAccessTokenThatIsBelongsToGrandparentRefreshToken() userDataInDatabase, false); assertEquals(sessionInfo.session.userId, userId); assertEquals(sessionInfo.session.userDataInJWT.toString(), userDataInJWT.toString()); - assertEquals(StorageLayer.getSessionStorage(process.getProcess()) + assertEquals(((SessionStorage) StorageLayer.getStorage(process.getProcess())) .getNumberOfSessions(new TenantIdentifier(null, null, null)), 1); assert sessionInfo.accessToken != null; assert sessionInfo.refreshToken != null; diff --git a/src/test/java/io/supertokens/test/session/api/MultitenantAPITest.java b/src/test/java/io/supertokens/test/session/api/MultitenantAPITest.java new file mode 100644 index 000000000..1f037bcad --- /dev/null +++ b/src/test/java/io/supertokens/test/session/api/MultitenantAPITest.java @@ -0,0 +1,367 @@ +/* + * Copyright (c) 2023, VRAI Labs and/or its affiliates. All rights reserved. + * + * This software is licensed under the Apache License, Version 2.0 (the + * "License") as published by the Apache Software Foundation. + * + * You may not use this file except in compliance with the License. You may + * obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.supertokens.test.session.api; + +import com.google.gson.JsonObject; +import io.supertokens.ProcessState; +import io.supertokens.featureflag.EE_FEATURES; +import io.supertokens.featureflag.FeatureFlagTestContent; +import io.supertokens.featureflag.exceptions.FeatureNotEnabledException; +import io.supertokens.multitenancy.Multitenancy; +import io.supertokens.multitenancy.exception.BadPermissionException; +import io.supertokens.multitenancy.exception.CannotModifyBaseConfigException; +import io.supertokens.multitenancy.exception.DeletionInProgressException; +import io.supertokens.pluginInterface.exceptions.InvalidConfigException; +import io.supertokens.pluginInterface.exceptions.StorageQueryException; +import io.supertokens.pluginInterface.multitenancy.*; +import io.supertokens.pluginInterface.multitenancy.exceptions.TenantOrAppNotFoundException; +import io.supertokens.storageLayer.StorageLayer; +import io.supertokens.test.TestingProcessManager; +import io.supertokens.test.Utils; +import io.supertokens.test.httpRequest.HttpRequestForTesting; +import io.supertokens.test.httpRequest.HttpResponseException; +import io.supertokens.thirdparty.InvalidProviderConfigException; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.Test; + +import java.io.IOException; +import java.util.HashMap; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; + +public class MultitenantAPITest { + TestingProcessManager.TestingProcess process; + TenantIdentifier t1, t2, t3; + + @AfterClass + public static void afterTesting() { + Utils.afterTesting(); + } + + @After + public void afterEach() throws InterruptedException { + process.kill(); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); + } + + @Before + public void beforeEach() throws InterruptedException, InvalidProviderConfigException, DeletionInProgressException, + StorageQueryException, FeatureNotEnabledException, TenantOrAppNotFoundException, IOException, + InvalidConfigException, CannotModifyBaseConfigException, BadPermissionException { + Utils.reset(); + + String[] args = {"../"}; + + this.process = TestingProcessManager.start(args); + FeatureFlagTestContent.getInstance(process.getProcess()) + .setKeyValue(FeatureFlagTestContent.ENABLED_FEATURES, new EE_FEATURES[]{EE_FEATURES.MULTI_TENANCY}); + process.startProcess(); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + + createTenants(); + } + + + private void createTenants() + throws StorageQueryException, TenantOrAppNotFoundException, InvalidProviderConfigException, + DeletionInProgressException, FeatureNotEnabledException, IOException, InvalidConfigException, + CannotModifyBaseConfigException, BadPermissionException { + // User pool 1 - (null, a1, null) + // User pool 2 - (null, a1, t1), (null, a1, t2) + + { // tenant 1 + JsonObject config = new JsonObject(); + TenantIdentifier tenantIdentifier = new TenantIdentifier(null, "a1", null); + + StorageLayer.getStorage(new TenantIdentifier(null, null, null), process.getProcess()) + .modifyConfigToAddANewUserPoolForTesting(config, 1); + + Multitenancy.addNewOrUpdateAppOrTenant( + process.getProcess(), + new TenantIdentifier(null, null, null), + new TenantConfig( + tenantIdentifier, + new EmailPasswordConfig(false), + new ThirdPartyConfig(false, null), + new PasswordlessConfig(true), + config + ) + ); + } + + { // tenant 2 + JsonObject config = new JsonObject(); + TenantIdentifier tenantIdentifier = new TenantIdentifier(null, "a1", "t1"); + + StorageLayer.getStorage(new TenantIdentifier(null, null, null), process.getProcess()) + .modifyConfigToAddANewUserPoolForTesting(config, 2); + + Multitenancy.addNewOrUpdateAppOrTenant( + process.getProcess(), + new TenantIdentifier(null, "a1", null), + new TenantConfig( + tenantIdentifier, + new EmailPasswordConfig(false), + new ThirdPartyConfig(false, null), + new PasswordlessConfig(true), + config + ) + ); + } + + { // tenant 3 + JsonObject config = new JsonObject(); + TenantIdentifier tenantIdentifier = new TenantIdentifier(null, "a1", "t2"); + + StorageLayer.getStorage(new TenantIdentifier(null, null, null), process.getProcess()) + .modifyConfigToAddANewUserPoolForTesting(config, 2); + + Multitenancy.addNewOrUpdateAppOrTenant( + process.getProcess(), + new TenantIdentifier(null, "a1", null), + new TenantConfig( + tenantIdentifier, + new EmailPasswordConfig(false), + new ThirdPartyConfig(false, null), + new PasswordlessConfig(true), + config + ) + ); + } + + t1 = new TenantIdentifier(null, "a1", null); + t2 = new TenantIdentifier(null, "a1", "t1"); + t3 = new TenantIdentifier(null, "a1", "t2"); + } + + private JsonObject createSession(TenantIdentifier tenantIdentifier, String userId, JsonObject userDataInJWT, JsonObject userDataInDatabase) + throws HttpResponseException, IOException { + JsonObject request = new JsonObject(); + request.addProperty("userId", userId); + request.add("userDataInJWT", userDataInJWT); + request.add("userDataInDatabase", userDataInDatabase); + request.addProperty("enableAntiCsrf", false); + + JsonObject response = HttpRequestForTesting.sendJsonPOSTRequest(process.getProcess(), "", + HttpRequestForTesting.getMultitenantUrl(tenantIdentifier, "/recipe/session"), request, + 1000, 1000, null, Utils.getCdiVersionLatestForTests(), + "session"); + + return response; + } + + private JsonObject getSession(TenantIdentifier tenantIdentifier, String sessionHandle) + throws HttpResponseException, IOException { + HashMap map = new HashMap<>(); + map.put("sessionHandle", sessionHandle); + JsonObject sessionResponse = HttpRequestForTesting.sendGETRequest(process.getProcess(), "", + HttpRequestForTesting.getMultitenantUrl(tenantIdentifier, "/recipe/session"), + map, 1000, 1000, null, Utils.getCdiVersionLatestForTests(), + "session"); + + assertEquals("OK", sessionResponse.getAsJsonPrimitive("status").getAsString()); + return sessionResponse; + } + + private void getSessionUnauthorised(TenantIdentifier tenantIdentifier, String sessionHandle) + throws HttpResponseException, IOException { + HashMap map = new HashMap<>(); + map.put("sessionHandle", sessionHandle); + JsonObject sessionResponse = HttpRequestForTesting.sendGETRequest(process.getProcess(), "", + HttpRequestForTesting.getMultitenantUrl(tenantIdentifier, "/recipe/session"), + map, 1000, 1000, null, Utils.getCdiVersionLatestForTests(), + "session"); + + assertEquals("UNAUTHORISED", sessionResponse.getAsJsonPrimitive("status").getAsString()); + } + + private void regenerateSession(TenantIdentifier tenantIdentifier, String accessToken, JsonObject newUserDataInJWT) + throws HttpResponseException, IOException { + JsonObject sessionRegenerateRequest = new JsonObject(); + sessionRegenerateRequest.addProperty("accessToken", accessToken); + sessionRegenerateRequest.add("userDataInJWT", newUserDataInJWT); + + JsonObject sessionRegenerateResponse = HttpRequestForTesting.sendJsonPOSTRequest(process.getProcess(), "", + HttpRequestForTesting.getMultitenantUrl(tenantIdentifier, "/recipe/session/regenerate"), + sessionRegenerateRequest, 1000, 1000, null, + Utils.getCdiVersionLatestForTests(), "session"); + + assertEquals(sessionRegenerateResponse.get("status").getAsString(), "OK"); + } + + private JsonObject verifySession(TenantIdentifier tenantIdentifier, String accessToken) + throws HttpResponseException, IOException { + JsonObject request = new JsonObject(); + request.addProperty("accessToken", accessToken); + request.addProperty("doAntiCsrfCheck", true); + request.addProperty("enableAntiCsrf", false); + JsonObject response = HttpRequestForTesting.sendJsonPOSTRequest(process.getProcess(), "", + HttpRequestForTesting.getMultitenantUrl(tenantIdentifier, "/recipe/session/verify"), request, + 1000, 1000, null, + Utils.getCdiVersionLatestForTests(), "session"); + return response; + } + + private JsonObject refreshSession(TenantIdentifier tenantIdentifier, String refreshToken) + throws HttpResponseException, IOException { + JsonObject sessionRefreshBody = new JsonObject(); + + sessionRefreshBody.addProperty("refreshToken", refreshToken); + sessionRefreshBody.addProperty("enableAntiCsrf", false); + + JsonObject response = HttpRequestForTesting.sendJsonPOSTRequest(process.getProcess(), "", + HttpRequestForTesting.getMultitenantUrl(tenantIdentifier, "/recipe/session/refresh"), + sessionRefreshBody, 1000, 1000, null, + Utils.getCdiVersionLatestForTests(), "session"); + return response; + } + + @Test + public void testSessionCreatedIsAccessableFromTheSameTenantOnly() throws Exception { + JsonObject user1DataInJWT = new JsonObject(); + user1DataInJWT.addProperty("foo", "val1"); + JsonObject user1DataInDb = new JsonObject(); + user1DataInJWT.addProperty("bar", "val1"); + + JsonObject user2DataInJWT = new JsonObject(); + user1DataInJWT.addProperty("foo", "val2"); + JsonObject user2DataInDb = new JsonObject(); + user1DataInJWT.addProperty("bar", "val2"); + + JsonObject user3DataInJWT = new JsonObject(); + user1DataInJWT.addProperty("foo", "val3"); + JsonObject user3DataInDb = new JsonObject(); + user1DataInJWT.addProperty("bar", "val3"); + + JsonObject session1 = createSession(t1, "userid", user1DataInJWT, user1DataInDb).get("session").getAsJsonObject(); + JsonObject session2 = createSession(t2, "userid", user2DataInJWT, user2DataInDb).get("session").getAsJsonObject(); + JsonObject session3 = createSession(t3, "userid", user3DataInJWT, user3DataInDb).get("session").getAsJsonObject(); + + { + JsonObject getSession = getSession(t1, session1.get("handle").getAsString()); + assertEquals(session1.get("userId"), getSession.get("userId")); + assertEquals(session1.get("handle"), getSession.get("sessionHandle")); + assertEquals(user1DataInJWT, getSession.get("userDataInJWT")); + assertEquals(user1DataInDb, getSession.get("userDataInDatabase")); + } + + { + JsonObject getSession = getSession(t2, session2.get("handle").getAsString()); + assertEquals(session2.get("userId"), getSession.get("userId")); + assertEquals(session2.get("handle"), getSession.get("sessionHandle")); + assertEquals(user2DataInJWT, getSession.get("userDataInJWT")); + assertEquals(user2DataInDb, getSession.get("userDataInDatabase")); + } + + { + JsonObject getSession = getSession(t3, session3.get("handle").getAsString()); + assertEquals(session3.get("userId"), getSession.get("userId")); + assertEquals(session3.get("handle"), getSession.get("sessionHandle")); + assertEquals(user3DataInJWT, getSession.get("userDataInJWT")); + assertEquals(user3DataInDb, getSession.get("userDataInDatabase")); + } + } + + @Test + public void testSessionFromOneTenantCannotBeFetchedFromAnother() throws Exception { + TenantIdentifier[] tenants = new TenantIdentifier[]{t1, t2, t3}; + + for (TenantIdentifier tenant1 : tenants) { + for (TenantIdentifier tenant2 : tenants) { + if (tenant1.equals(tenant2)) { + continue; + } + JsonObject userDataInJWT = new JsonObject(); + userDataInJWT.addProperty("foo", "val1"); + JsonObject userDataInDb = new JsonObject(); + userDataInJWT.addProperty("bar", "val1"); + + JsonObject session = createSession(tenant1, "userid", userDataInJWT, userDataInDb).get("session").getAsJsonObject(); + getSessionUnauthorised(tenant2, session.get("handle").getAsString()); + } + } + } + + @Test + public void testRegenerateSessionWorksFromAnyTenantButUpdatesTheRightSession() throws Exception { + TenantIdentifier[] tenants = new TenantIdentifier[]{t1, t2, t3}; + + for (TenantIdentifier tenant1 : tenants) { + for (TenantIdentifier tenant2 : tenants) { + JsonObject userDataInJWT = new JsonObject(); + userDataInJWT.addProperty("foo", "val1"); + JsonObject userDataInDb = new JsonObject(); + userDataInJWT.addProperty("bar", "val1"); + + JsonObject session = createSession(tenant1, "userid", userDataInJWT, userDataInDb); + userDataInJWT.addProperty("hello", "world"); + + regenerateSession(tenant2, session.get("accessToken").getAsJsonObject().get("token").getAsString(), userDataInJWT); + + JsonObject getSession = getSession(tenant1, session.get("session").getAsJsonObject().get("handle").getAsString()); + assertEquals(userDataInJWT, getSession.get("userDataInJWT")); + } + } + } + + @Test + public void testVerifySessionWorksFromAnyTenantInTheApp() throws Exception { + TenantIdentifier[] tenants = new TenantIdentifier[]{t1, t2, t3}; + + for (TenantIdentifier tenant1 : tenants) { + for (TenantIdentifier tenant2 : tenants) { + JsonObject userDataInJWT = new JsonObject(); + userDataInJWT.addProperty("foo", "val1"); + JsonObject userDataInDb = new JsonObject(); + userDataInJWT.addProperty("bar", "val1"); + + JsonObject session = createSession(tenant1, "userid", userDataInJWT, userDataInDb); + userDataInJWT.addProperty("hello", "world"); + + JsonObject sessionResponse = verifySession(tenant2, session.get("accessToken").getAsJsonObject().get("token").getAsString()); + assertEquals(session.get("session"), sessionResponse.get("session")); + } + } + } + + @Test + public void testVerifySessionDoesNotWorkFromDifferentApp() throws Exception { + JsonObject userDataInJWT = new JsonObject(); + userDataInJWT.addProperty("foo", "val1"); + JsonObject userDataInDb = new JsonObject(); + userDataInJWT.addProperty("bar", "val1"); + + JsonObject session = createSession(t1, "userid", userDataInJWT, userDataInDb); + JsonObject sessionResponse = verifySession(new TenantIdentifier(null, null, null), session.get("accessToken").getAsJsonObject().get("token").getAsString()); + assertEquals("TRY_REFRESH_TOKEN", sessionResponse.get("status").getAsString()); + } + + @Test + public void testRefreshSessionDoesNotWorkFromDifferentApp() throws Exception { + JsonObject userDataInJWT = new JsonObject(); + userDataInJWT.addProperty("foo", "val1"); + JsonObject userDataInDb = new JsonObject(); + userDataInJWT.addProperty("bar", "val1"); + + JsonObject session = createSession(t1, "userid", userDataInJWT, userDataInDb); + JsonObject sessionResponse = refreshSession(new TenantIdentifier(null, null, null), session.get("refreshToken").getAsJsonObject().get("token").getAsString()); + assertEquals("UNAUTHORISED", sessionResponse.get("status").getAsString()); + } +}