diff --git a/CHANGELOG.md b/CHANGELOG.md index 90e094baa..17cfb26ae 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,28 @@ to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## [unreleased] +- Add TOTP recipe + +### Database changes: + +- Add new tables for TOTP recipe: + - `totp_users` that stores the users that have enabled TOTP + - `totp_user_devices` that stores devices (each device has its own secret) for each user + - `totp_used_codes` that stores used codes for each user. This is to implement rate limiting and prevent replay + attacks. + - `user_last_active` that stores the last active time for each user. + +### New APIs: + +- `GET /users/count/active` to fetch the number of active users after the given timestamp. +- `POST /recipe/totp/device` to create a new device as well as the user if it doesn't exist. +- `POST /recipe/totp/device/verify` to verify a device. This is to ensure that the user has access to the device. +- `POST /recipe/totp/verify` to verify a code and continue the login flow. +- `PUT /recipe/totp/device` to update the name of a device. Name is just a string that the user can set to identify the + device. +- `GET /recipe/totp/device/list` to get all devices for a user. +- `POST /recipe/totp/device/remove` to remove a device. If the user has no more devices, the user is also removed. + ## [4.4.2] - 2023-03-16 - Adds null check in email normalisation to fix: https://github.com/supertokens/supertokens-node/issues/514 diff --git a/build.gradle b/build.gradle index d6d374e11..bd5bbbc95 100644 --- a/build.gradle +++ b/build.gradle @@ -65,6 +65,12 @@ dependencies { // https://mvnrepository.com/artifact/com.lambdaworks/scrypt implementation group: 'com.lambdaworks', name: 'scrypt', version: '1.4.0' + // https://mvnrepository.com/artifact/com.eatthepath/java-otp + implementation group: 'com.eatthepath', name: 'java-otp', version: '0.4.0' + + // https://mvnrepository.com/artifact/commons-codec/commons-codec + implementation group: 'commons-codec', name: 'commons-codec', version: '1.15' + compileOnly project(":supertokens-plugin-interface") testImplementation project(":supertokens-plugin-interface") @@ -159,4 +165,3 @@ tasks.withType(Test) { } } } - \ No newline at end of file diff --git a/config.yaml b/config.yaml index 3b6d02904..c8eb96a7c 100644 --- a/config.yaml +++ b/config.yaml @@ -54,6 +54,11 @@ core_config_version: 0 # (OPTIONAL | Default: 900000) long value. Time in milliseconds for how long a passwordless code is valid for. # passwordless_code_lifetime: +# (OPTIONAL | Default: 5) integer value. The maximum number of invalid TOTP attempts that will trigger rate limiting. +# totp_max_attempts: + +# (OPTIONAL | Default: 900) integer value. The time in seconds for which the user will be rate limited once totp_max_attempts is crossed. +# totp_rate_limit_cooldown_sec: # (OPTIONAL | Default: installation directory/logs/info.log) string value. Give the path to a file (on your local # system) in which the SuperTokens service can write INFO logs to. Set it to "null" if you want it to log to @@ -120,4 +125,4 @@ core_config_version: 0 # (OPTIONAL | Default: null). Regex for denying requests from IP addresses that match with the value. Comment this # value to deny no IP address. -# ip_deny_regex: \ No newline at end of file +# ip_deny_regex: diff --git a/coreDriverInterfaceSupported.json b/coreDriverInterfaceSupported.json index a2126e64c..01bb39ccb 100644 --- a/coreDriverInterfaceSupported.json +++ b/coreDriverInterfaceSupported.json @@ -12,6 +12,7 @@ "2.15", "2.16", "2.17", - "2.18" + "2.18", + "2.19" ] } \ No newline at end of file diff --git a/devConfig.yaml b/devConfig.yaml index 5b15e7c8f..82e0e695a 100644 --- a/devConfig.yaml +++ b/devConfig.yaml @@ -54,6 +54,11 @@ core_config_version: 0 # (OPTIONAL | Default: 900000) long value. Time in milliseconds for how long a passwordless code is valid for. # passwordless_code_lifetime: +# (OPTIONAL | Default: 5) integer value. The maximum number of invalid TOTP attempts that will trigger rate limiting. +# totp_max_attempts: + +# (OPTIONAL | Default: 900) integer value. The time in seconds for which the user will be rate limited once totp_max_attempts is crossed. +# totp_rate_limit_cooldown_sec: # (OPTIONAL | Default: installation directory/logs/info.log) string value. Give the path to a file (on your local # system) in which the SuperTokens service can write INFO logs to. Set it to "null" if you want it to log to @@ -120,4 +125,4 @@ disable_telemetry: true # (OPTIONAL | Default: null). Regex for denying requests from IP addresses that match with the value. Comment this # value to deny no IP address. -# ip_deny_regex: \ No newline at end of file +# ip_deny_regex: diff --git a/ee/src/main/java/io/supertokens/ee/EEFeatureFlag.java b/ee/src/main/java/io/supertokens/ee/EEFeatureFlag.java index f87766863..3daa3caa9 100644 --- a/ee/src/main/java/io/supertokens/ee/EEFeatureFlag.java +++ b/ee/src/main/java/io/supertokens/ee/EEFeatureFlag.java @@ -6,10 +6,8 @@ import com.auth0.jwt.exceptions.JWTVerificationException; import com.auth0.jwt.interfaces.DecodedJWT; import com.auth0.jwt.interfaces.RSAKeyProvider; -import com.google.gson.JsonArray; -import com.google.gson.JsonObject; -import com.google.gson.JsonParser; -import com.google.gson.JsonPrimitive; +import com.google.gson.*; +import io.supertokens.ActiveUsers; import io.supertokens.Main; import io.supertokens.ProcessState; import io.supertokens.cronjobs.Cronjobs; @@ -21,6 +19,7 @@ import io.supertokens.httpRequest.HttpRequest; import io.supertokens.httpRequest.HttpResponseException; import io.supertokens.output.Logging; +import io.supertokens.pluginInterface.ActiveUsersStorage; import io.supertokens.pluginInterface.KeyValueInfo; import io.supertokens.pluginInterface.Storage; import io.supertokens.pluginInterface.exceptions.StorageQueryException; @@ -144,15 +143,50 @@ public Boolean getIsLicenseKeyPresent() { @Override public JsonObject getPaidFeatureStats() throws StorageQueryException { - JsonObject result = new JsonObject(); + JsonObject usageStats = new JsonObject(); EE_FEATURES[] features = getEnabledEEFeaturesFromDbOrCache(); - if (Arrays.stream(features).anyMatch(t -> t == EE_FEATURES.DASHBOARD_LOGIN)) { - JsonObject stats = new JsonObject(); - int userCount = StorageLayer.getDashboardStorage(main).getAllDashboardUsers().length; - stats.addProperty("user_count", userCount); - result.add(EE_FEATURES.DASHBOARD_LOGIN.toString(), stats); + ActiveUsersStorage activeUsersStorage = StorageLayer.getActiveUsersStorage(main); + + for (EE_FEATURES feature : features) { + if (feature == EE_FEATURES.DASHBOARD_LOGIN) { + JsonObject stats = new JsonObject(); + int userCount = StorageLayer.getDashboardStorage(main).getAllDashboardUsers().length; + stats.addProperty("user_count", userCount); + usageStats.add(EE_FEATURES.DASHBOARD_LOGIN.toString(), stats); + } + if (feature == EE_FEATURES.TOTP) { + JsonObject totpStats = new JsonObject(); + JsonArray totpMauArr = new JsonArray(); + + for (int i = 0; i < 30; i++) { + long now = System.currentTimeMillis(); + long today = now - (now % (24 * 60 * 60 * 1000L)); + long timestamp = today - (i * 24 * 60 * 60 * 1000L); + + int totpMau = activeUsersStorage.countUsersEnabledTotpAndActiveSince(timestamp); + totpMauArr.add(new JsonPrimitive(totpMau)); + } + + totpStats.add("maus", totpMauArr); + + int totpTotalUsers = activeUsersStorage.countUsersEnabledTotp(); + totpStats.addProperty("total_users", totpTotalUsers); + usageStats.add(EE_FEATURES.TOTP.toString(), totpStats); + } + } + + JsonArray mauArr = new JsonArray(); + for (int i = 0; i < 30; i++) { + long now = System.currentTimeMillis(); + long today = now - (now % (24 * 60 * 60 * 1000L)); + long timestamp = today - (i * 24 * 60 * 60 * 1000L); + + int mau = activeUsersStorage.countUsersActiveSince(timestamp); + mauArr.add(new JsonPrimitive(mau)); } - return result; + + usageStats.add("maus", mauArr); + return usageStats; } private EE_FEATURES[] verifyLicenseKey(String licenseKey) diff --git a/implementationDependencies.json b/implementationDependencies.json index 82e762a01..3e5b6dad3 100644 --- a/implementationDependencies.json +++ b/implementationDependencies.json @@ -100,6 +100,16 @@ "jar": "https://repo1.maven.org/maven2/com/lambdaworks/scrypt/1.4.0/scrypt-1.4.0.jar", "name": "Scrypt 1.4.0", "src": "https://repo1.maven.org/maven2/com/lambdaworks/scrypt/1.4.0/scrypt-1.4.0-sources.jar" + }, + { + "jar": "https://repo1.maven.org/maven2/com/eatthepath/java-otp/0.4.0/java-otp-0.4.0.jar", + "name": "Java OTP 0.4.0", + "src": "https://repo1.maven.org/maven2/com/eatthepath/java-otp/0.4.0/java-otp-0.4.0-sources.jar" + }, + { + "jar": "https://repo1.maven.org/maven2/commons-codec/commons-codec/1.15/commons-codec-1.15.jar", + "name": "Commons Codec 1.15", + "src": "https://repo1.maven.org/maven2/commons-codec/commons-codec/1.15/commons-codec-1.15-sources.jar" } ] } \ No newline at end of file diff --git a/src/main/java/io/supertokens/ActiveUsers.java b/src/main/java/io/supertokens/ActiveUsers.java new file mode 100644 index 000000000..6503389f9 --- /dev/null +++ b/src/main/java/io/supertokens/ActiveUsers.java @@ -0,0 +1,18 @@ +package io.supertokens; + +import io.supertokens.pluginInterface.exceptions.StorageQueryException; +import io.supertokens.storageLayer.StorageLayer; + +public class ActiveUsers { + + public static void updateLastActive(Main main, String userId) { + try { + StorageLayer.getActiveUsersStorage(main).updateLastActive(userId); + } catch (StorageQueryException ignored) { + } + } + + public static int countUsersActiveSince(Main main, long time) throws StorageQueryException { + return StorageLayer.getActiveUsersStorage(main).countUsersActiveSince(time); + } +} diff --git a/src/main/java/io/supertokens/Main.java b/src/main/java/io/supertokens/Main.java index deed1b270..f2e8ee32b 100644 --- a/src/main/java/io/supertokens/Main.java +++ b/src/main/java/io/supertokens/Main.java @@ -26,6 +26,7 @@ import io.supertokens.cronjobs.deleteExpiredPasswordResetTokens.DeleteExpiredPasswordResetTokens; import io.supertokens.cronjobs.deleteExpiredPasswordlessDevices.DeleteExpiredPasswordlessDevices; import io.supertokens.cronjobs.deleteExpiredSessions.DeleteExpiredSessions; +import io.supertokens.cronjobs.deleteExpiredTotpTokens.DeleteExpiredTotpTokens; import io.supertokens.cronjobs.telemetry.Telemetry; import io.supertokens.emailpassword.PasswordHashing; import io.supertokens.exceptions.QuitProgramException; @@ -205,6 +206,9 @@ private void init() throws IOException { // removes passwordless devices with only expired codes Cronjobs.addCronjob(this, DeleteExpiredPasswordlessDevices.getInstance(this)); + // removes expired TOTP used tokens + Cronjobs.addCronjob(this, DeleteExpiredTotpTokens.getInstance(this)); + // removes expired dashboard session Cronjobs.addCronjob(this, DeleteExpiredDashboardSessions.getInstance(this)); diff --git a/src/main/java/io/supertokens/authRecipe/AuthRecipe.java b/src/main/java/io/supertokens/authRecipe/AuthRecipe.java index 98af4fe7e..8f4734501 100644 --- a/src/main/java/io/supertokens/authRecipe/AuthRecipe.java +++ b/src/main/java/io/supertokens/authRecipe/AuthRecipe.java @@ -20,6 +20,8 @@ import io.supertokens.pluginInterface.RECIPE_ID; import io.supertokens.pluginInterface.authRecipe.AuthRecipeUserInfo; import io.supertokens.pluginInterface.exceptions.StorageQueryException; +import io.supertokens.pluginInterface.exceptions.StorageTransactionLogicException; +import io.supertokens.pluginInterface.totp.sqlStorage.TOTPSQLStorage; import io.supertokens.pluginInterface.useridmapping.UserIdMapping; import io.supertokens.storageLayer.StorageLayer; import io.supertokens.useridmapping.UserIdType; @@ -59,22 +61,30 @@ public static UserPaginationContainer getUsers(Main main, Integer limit, String return new UserPaginationContainer(resultUsers, nextPaginationToken); } - public static void deleteUser(Main main, String userId) throws StorageQueryException { - // We clean up the user last so that if anything before that throws an error, then that will throw a 500 to the - // developer. In this case, they expect that the user has not been deleted (which will be true). This is as - // opposed to deleting the user first, in which case if something later throws an error, then the user has + public static void deleteUser(Main main, String userId) + throws StorageQueryException, StorageTransactionLogicException { + // We clean up the user last so that if anything before that throws an error, + // then that will throw a 500 to the + // developer. In this case, they expect that the user has not been deleted + // (which will be true). This is as + // opposed to deleting the user first, in which case if something later throws + // an error, then the user has // actually been deleted already (which is not expected by the dev) - // For things created after the intial cleanup and before finishing the operation: + // For things created after the intial cleanup and before finishing the + // operation: // - session: the session will expire anyway - // - email verification: email verification tokens can be created for any userId anyway + // - email verification: email verification tokens can be created for any userId + // anyway - // If userId mapping exists then delete entries with superTokensUserId from auth related tables and + // If userId mapping exists then delete entries with superTokensUserId from auth + // related tables and // externalUserid from non-auth tables UserIdMapping userIdMapping = io.supertokens.useridmapping.UserIdMapping.getUserIdMapping(main, userId, UserIdType.ANY); if (userIdMapping != null) { - // We check if the mapped externalId is another SuperTokens UserId, this could come up when migrating + // We check if the mapped externalId is another SuperTokens UserId, this could + // come up when migrating // recipes. // in reference to // https://docs.google.com/spreadsheets/d/17hYV32B0aDCeLnSxbZhfRN2Y9b0LC2xUF44vV88RNAA/edit?usp=sharing @@ -97,12 +107,20 @@ public static void deleteUser(Main main, String userId) throws StorageQueryExcep } - private static void deleteNonAuthRecipeUser(Main main, String userId) throws StorageQueryException { + private static void deleteNonAuthRecipeUser(Main main, String userId) + throws StorageQueryException, StorageTransactionLogicException { // non auth recipe deletion StorageLayer.getUserMetadataStorage(main).deleteUserMetadata(userId); StorageLayer.getSessionStorage(main).deleteSessionsOfUser(userId); StorageLayer.getEmailVerificationStorage(main).deleteEmailVerificationUserInfo(userId); StorageLayer.getUserRolesStorage(main).deleteAllRolesForUser(userId); + + TOTPSQLStorage storage = StorageLayer.getTOTPStorage(main); + storage.startTransaction(con -> { + storage.removeUser_Transaction(con, userId); + storage.commitTransaction(con); + return null; + }); } private static void deleteAuthRecipeUser(Main main, String userId) throws StorageQueryException { diff --git a/src/main/java/io/supertokens/config/CoreConfig.java b/src/main/java/io/supertokens/config/CoreConfig.java index 516d771ba..f19f64b86 100644 --- a/src/main/java/io/supertokens/config/CoreConfig.java +++ b/src/main/java/io/supertokens/config/CoreConfig.java @@ -56,6 +56,12 @@ public class CoreConfig { @JsonProperty private long passwordless_code_lifetime = 900000; // in MS + @JsonProperty + private int totp_max_attempts = 5; + + @JsonProperty + private int totp_rate_limit_cooldown_sec = 900; // in seconds (Default 15 mins) + private final String logDefault = "asdkfahbdfk3kjHS"; @JsonProperty private String info_log_path = logDefault; @@ -106,10 +112,13 @@ public class CoreConfig { private int bcrypt_log_rounds = 11; // TODO: add https in later version -// # (OPTIONAL) boolean value (true or false). Set to true if you want to enable https requests to SuperTokens. -// # If you are not running SuperTokens within a closed network along with your API process, for -// # example if you are using multiple cloud vendors, then it is recommended to set this to true. -// # webserver_https_enabled: + // # (OPTIONAL) boolean value (true or false). Set to true if you want to enable + // https requests to SuperTokens. + // # If you are not running SuperTokens within a closed network along with your + // API process, for + // # example if you are using multiple cloud vendors, then it is recommended to + // set this to true. + // # webserver_https_enabled: @JsonProperty private boolean webserver_https_enabled = false; @@ -191,9 +200,11 @@ public enum PASSWORD_HASHING_ALG { } public int getArgon2HashingPoolSize() { - // the reason we do Math.max below is that if the password hashing algo is bcrypt, + // the reason we do Math.max below is that if the password hashing algo is + // bcrypt, // then we don't check the argon2 hashing pool size config at all. In this case, - // if the user gives a <= 0 number, it crashes the core (since it creates a blockedqueue in PaswordHashing + // if the user gives a <= 0 number, it crashes the core (since it creates a + // blockedqueue in PaswordHashing // .java with length <= 0). So we do a Math.max return Math.max(1, argon2_hashing_pool_size); } @@ -266,6 +277,15 @@ public long getPasswordlessCodeLifetime() { return passwordless_code_lifetime; } + public int getTotpMaxAttempts() { + return totp_max_attempts; + } + + /** TOTP rate limit cooldown time (in seconds) */ + public int getTotpRateLimitCooldownTimeSec() { + return totp_rate_limit_cooldown_sec; + } + public boolean isTelemetryDisabled() { return disable_telemetry; } @@ -384,6 +404,14 @@ void validateAndInitialise(Main main) throws IOException { throw new QuitProgramException("'passwordless_max_code_input_attempts' must be > 0"); } + if (totp_max_attempts <= 0) { + throw new QuitProgramException("'totp_max_attempts' must be > 0"); + } + + if (totp_rate_limit_cooldown_sec <= 0) { + throw new QuitProgramException("'totp_rate_limit_cooldown_sec' must be > 0"); + } + if (max_server_pool_size <= 0) { throw new QuitProgramException("'max_server_pool_size' must be >= 1. The config file can be found here: " + getConfigFileLocation(main)); @@ -475,4 +503,4 @@ void validateAndInitialise(Main main) throws IOException { } } -} \ No newline at end of file +} diff --git a/src/main/java/io/supertokens/cronjobs/deleteExpiredTotpTokens/DeleteExpiredTotpTokens.java b/src/main/java/io/supertokens/cronjobs/deleteExpiredTotpTokens/DeleteExpiredTotpTokens.java new file mode 100644 index 000000000..94730f045 --- /dev/null +++ b/src/main/java/io/supertokens/cronjobs/deleteExpiredTotpTokens/DeleteExpiredTotpTokens.java @@ -0,0 +1,69 @@ +package io.supertokens.cronjobs.deleteExpiredTotpTokens; + +import io.supertokens.Main; +import io.supertokens.ResourceDistributor; +import io.supertokens.config.Config; +import io.supertokens.pluginInterface.STORAGE_TYPE; +import io.supertokens.pluginInterface.totp.sqlStorage.TOTPSQLStorage; +import io.supertokens.cronjobs.CronTask; +import io.supertokens.cronjobs.CronTaskTest; +import io.supertokens.storageLayer.StorageLayer; +import io.supertokens.output.Logging; + +public class DeleteExpiredTotpTokens extends CronTask { + + public static final String RESOURCE_KEY = "io.supertokens.cronjobs.deleteExpiredTotpTokens.DeleteExpiredTotpTokens"; + + private DeleteExpiredTotpTokens(Main main) { + super("DeleteExpiredTotpTokens", main); + } + + public static DeleteExpiredTotpTokens getInstance(Main main) { + ResourceDistributor.SingletonResource instance = main.getResourceDistributor().getResource(RESOURCE_KEY); + if (instance == null) { + instance = main.getResourceDistributor().setResource(RESOURCE_KEY, new DeleteExpiredTotpTokens(main)); + } + return (DeleteExpiredTotpTokens) instance; + } + + @Override + protected void doTask() throws Exception { + if (StorageLayer.getStorage(this.main).getType() != STORAGE_TYPE.SQL) { + return; + } + + TOTPSQLStorage storage = StorageLayer.getTOTPStorage(this.main); + + long rateLimitResetInMs = Config.getConfig(this.main).getTotpRateLimitCooldownTimeSec() * 1000; + long expiredBefore = System.currentTimeMillis() - rateLimitResetInMs; + + // We will only remove expired codes that have been expired for longer + // than rate limiting duration. This ensures that this DB query + // doesn't delete totp codes that keep the rate limiting active for + // the expected cooldown duration. + int deletedCount = storage.removeExpiredCodes(expiredBefore); + Logging.debug(this.main, "Cron DeleteExpiredTotpTokens deleted " + deletedCount + " expired TOTP codes"); + } + + @Override + public int getIntervalTimeSeconds() { + if (Main.isTesting) { + Integer interval = CronTaskTest.getInstance(main).getIntervalInSeconds(RESOURCE_KEY); + if (interval != null) { + return interval; + } + } + + return 3600; // every hour + } + + @Override + public int getInitialWaitTimeSeconds() { + if (!Main.isTesting) { + return getIntervalTimeSeconds(); + } else { + return 0; + } + } + +} diff --git a/src/main/java/io/supertokens/featureflag/EE_FEATURES.java b/src/main/java/io/supertokens/featureflag/EE_FEATURES.java index 05e554dc6..820f2f73a 100644 --- a/src/main/java/io/supertokens/featureflag/EE_FEATURES.java +++ b/src/main/java/io/supertokens/featureflag/EE_FEATURES.java @@ -17,7 +17,8 @@ package io.supertokens.featureflag; public enum EE_FEATURES { - ACCOUNT_LINKING("account_linking"), MULTI_TENANCY("multi_tenancy"), TEST("test"), DASHBOARD_LOGIN("dashboard_login"); + ACCOUNT_LINKING("account_linking"), MULTI_TENANCY("multi_tenancy"), TEST("test"), DASHBOARD_LOGIN("dashboard_login"), + TOTP("totp"); private final String name; diff --git a/src/main/java/io/supertokens/inmemorydb/Start.java b/src/main/java/io/supertokens/inmemorydb/Start.java index aacd4e5c5..b24541294 100644 --- a/src/main/java/io/supertokens/inmemorydb/Start.java +++ b/src/main/java/io/supertokens/inmemorydb/Start.java @@ -22,8 +22,10 @@ import io.supertokens.ResourceDistributor; import io.supertokens.emailverification.EmailVerification; import io.supertokens.emailverification.exception.EmailAlreadyVerifiedException; +import io.supertokens.featureflag.exceptions.FeatureNotEnabledException; import io.supertokens.inmemorydb.config.Config; import io.supertokens.inmemorydb.queries.*; +import io.supertokens.pluginInterface.ActiveUsersStorage; import io.supertokens.pluginInterface.KeyValueInfo; import io.supertokens.pluginInterface.LOG_LEVEL; import io.supertokens.pluginInterface.RECIPE_ID; @@ -61,8 +63,16 @@ import io.supertokens.pluginInterface.sqlStorage.TransactionConnection; import io.supertokens.pluginInterface.thirdparty.exception.DuplicateThirdPartyUserException; import io.supertokens.pluginInterface.thirdparty.sqlStorage.ThirdPartySQLStorage; +import io.supertokens.pluginInterface.totp.TOTPStorage; import io.supertokens.pluginInterface.useridmapping.UserIdMapping; import io.supertokens.pluginInterface.useridmapping.UserIdMappingStorage; +import io.supertokens.pluginInterface.totp.TOTPDevice; +import io.supertokens.pluginInterface.totp.TOTPUsedCode; +import io.supertokens.pluginInterface.totp.exception.DeviceAlreadyExistsException; +import io.supertokens.pluginInterface.totp.exception.TotpNotEnabledException; +import io.supertokens.pluginInterface.totp.exception.UnknownDeviceException; +import io.supertokens.pluginInterface.totp.exception.UsedCodeAlreadyExistsException; +import io.supertokens.pluginInterface.totp.sqlStorage.TOTPSQLStorage; import io.supertokens.pluginInterface.useridmapping.exception.UnknownSuperTokensUserIdException; import io.supertokens.pluginInterface.useridmapping.exception.UserIdMappingAlreadyExistsException; import io.supertokens.pluginInterface.usermetadata.UserMetadataStorage; @@ -72,6 +82,7 @@ import io.supertokens.pluginInterface.userroles.exception.UnknownRoleException; import io.supertokens.pluginInterface.userroles.sqlStorage.UserRolesSQLStorage; import io.supertokens.session.Session; +import io.supertokens.totp.Totp; import io.supertokens.usermetadata.UserMetadata; import io.supertokens.userroles.UserRoles; import org.jetbrains.annotations.NotNull; @@ -92,7 +103,7 @@ public class Start implements SessionSQLStorage, EmailPasswordSQLStorage, EmailVerificationSQLStorage, ThirdPartySQLStorage, JWTRecipeSQLStorage, PasswordlessSQLStorage, UserMetadataSQLStorage, UserRolesSQLStorage, UserIdMappingStorage, - DashboardSQLStorage { + DashboardSQLStorage, TOTPSQLStorage, ActiveUsersStorage { private static final Object appenderLock = new Object(); private static final String APP_ID_KEY_NAME = "app_id"; @@ -301,8 +312,8 @@ public void close() { @Override public void createNewSession(String sessionHandle, String userId, String refreshTokenHash2, - JsonObject userDataInDatabase, long expiry, JsonObject userDataInJWT, - long createdAtTime) + JsonObject userDataInDatabase, long expiry, JsonObject userDataInJWT, + long createdAtTime) throws StorageQueryException { try { SessionQueries.createNewSession(this, sessionHandle, userId, refreshTokenHash2, userDataInDatabase, expiry, @@ -415,8 +426,8 @@ public long getUsersCount(RECIPE_ID[] includeRecipeIds) throws StorageQueryExcep @Override public AuthRecipeUserInfo[] getUsers(@NotNull Integer limit, @NotNull String timeJoinedOrder, - @Nullable RECIPE_ID[] includeRecipeIds, @Nullable String userId, - @Nullable Long timeJoined) + @Nullable RECIPE_ID[] includeRecipeIds, @Nullable String userId, + @Nullable Long timeJoined) throws StorageQueryException { try { return GeneralQueries.getUsers(this, limit, timeJoinedOrder, includeRecipeIds, userId, timeJoined); @@ -434,6 +445,42 @@ public boolean doesUserIdExist(String userId) throws StorageQueryException { } } + @Override + public void updateLastActive(String userId) throws StorageQueryException { + try { + ActiveUsersQueries.updateUserLastActive(this, userId); + } catch (SQLException e) { + throw new StorageQueryException(e); + } + } + + @Override + public int countUsersActiveSince(long time) throws StorageQueryException { + try { + return ActiveUsersQueries.countUsersActiveSince(this, time); + } catch (SQLException e) { + throw new StorageQueryException(e); + } + } + + @Override + public int countUsersEnabledTotp() throws StorageQueryException { + try { + return ActiveUsersQueries.countUsersEnabledTotp(this); + } catch (SQLException e) { + throw new StorageQueryException(e); + } + } + + @Override + public int countUsersEnabledTotpAndActiveSince(long time) throws StorageQueryException { + try { + return ActiveUsersQueries.countUsersEnabledTotpAndActiveSince(this, time); + } catch (SQLException e) { + throw new StorageQueryException(e); + } + } + @Override public SessionInfo getSessionInfo_Transaction(TransactionConnection con, String sessionHandle) throws StorageQueryException { @@ -447,7 +494,7 @@ public SessionInfo getSessionInfo_Transaction(TransactionConnection con, String @Override public void updateSessionInfo_Transaction(TransactionConnection con, String sessionHandle, String refreshTokenHash2, - long expiry) throws StorageQueryException { + long expiry) throws StorageQueryException { Connection sqlCon = (Connection) con.getConnection(); try { SessionQueries.updateSessionInfo_Transaction(this, sqlCon, sessionHandle, refreshTokenHash2, expiry); @@ -492,8 +539,8 @@ public void signUp(UserInfo userInfo) .equals("[SQLITE_CONSTRAINT] Abort due to constraint violation (UNIQUE constraint failed: " + Config.getConfig(this).getEmailPasswordUsersTable() + ".user_id)") || e.getMessage() - .equals("[SQLITE_CONSTRAINT] Abort due to constraint violation (UNIQUE constraint failed: " - + Config.getConfig(this).getUsersTable() + ".user_id)")) { + .equals("[SQLITE_CONSTRAINT] Abort due to constraint violation (UNIQUE constraint failed: " + + Config.getConfig(this).getUsersTable() + ".user_id)")) { throw new DuplicateUserIdException(); } throw new StorageQueryException(e); @@ -567,7 +614,7 @@ public PasswordResetTokenInfo[] getAllPasswordResetTokenInfoForUser(String userI @Override public PasswordResetTokenInfo[] getAllPasswordResetTokenInfoForUser_Transaction(TransactionConnection con, - String userId) + String userId) throws StorageQueryException { Connection sqlCon = (Connection) con.getConnection(); try { @@ -630,7 +677,7 @@ public UserInfo getUserInfoUsingId_Transaction(TransactionConnection con, String @Override @Deprecated public UserInfo[] getUsers(@NotNull String userId, @NotNull Long timeJoined, @NotNull Integer limit, - @NotNull String timeJoinedOrder) throws StorageQueryException { + @NotNull String timeJoinedOrder) throws StorageQueryException { try { return EmailPasswordQueries.getUsersInfo(this, userId, timeJoined, limit, timeJoinedOrder); } catch (SQLException e) { @@ -669,7 +716,7 @@ public void deleteExpiredPasswordResetTokens() throws StorageQueryException { @Override public EmailVerificationTokenInfo[] getAllEmailVerificationTokenInfoForUser_Transaction(TransactionConnection con, - String userId, String email) + String userId, String email) throws StorageQueryException { Connection sqlCon = (Connection) con.getConnection(); try { @@ -700,7 +747,7 @@ public void addEmailVerificationToken(EmailVerificationTokenInfo emailVerificati @Override public void deleteAllEmailVerificationTokensForUser_Transaction(TransactionConnection con, String userId, - String email) throws StorageQueryException { + String email) throws StorageQueryException { Connection sqlCon = (Connection) con.getConnection(); try { EmailVerificationQueries.deleteAllEmailVerificationTokensForUser_Transaction(this, sqlCon, userId, email); @@ -711,7 +758,7 @@ public void deleteAllEmailVerificationTokensForUser_Transaction(TransactionConne @Override public void updateIsEmailVerified_Transaction(TransactionConnection con, String userId, String email, - boolean isEmailVerified) throws StorageQueryException { + boolean isEmailVerified) throws StorageQueryException { Connection sqlCon = (Connection) con.getConnection(); try { EmailVerificationQueries.updateUsersIsEmailVerified_Transaction(this, sqlCon, userId, email, @@ -793,8 +840,8 @@ public boolean isEmailVerified(String userId, String email) throws StorageQueryE @Override public io.supertokens.pluginInterface.thirdparty.UserInfo getUserInfoUsingId_Transaction(TransactionConnection con, - String thirdPartyId, - String thirdPartyUserId) + String thirdPartyId, + String thirdPartyUserId) throws StorageQueryException { Connection sqlCon = (Connection) con.getConnection(); try { @@ -806,7 +853,7 @@ public io.supertokens.pluginInterface.thirdparty.UserInfo getUserInfoUsingId_Tra @Override public void updateUserEmail_Transaction(TransactionConnection con, String thirdPartyId, String thirdPartyUserId, - String newEmail) throws StorageQueryException { + String newEmail) throws StorageQueryException { Connection sqlCon = (Connection) con.getConnection(); try { ThirdPartyQueries.updateUserEmail_Transaction(this, sqlCon, thirdPartyId, thirdPartyUserId, newEmail); @@ -832,8 +879,8 @@ public void signUp(io.supertokens.pluginInterface.thirdparty.UserInfo userInfo) .equals("[SQLITE_CONSTRAINT] Abort due to constraint violation (UNIQUE constraint failed: " + Config.getConfig(this).getThirdPartyUsersTable() + ".user_id)") || e.getMessage() - .equals("[SQLITE_CONSTRAINT] Abort due to constraint violation (UNIQUE constraint failed: " - + Config.getConfig(this).getUsersTable() + ".user_id)")) { + .equals("[SQLITE_CONSTRAINT] Abort due to constraint violation (UNIQUE constraint failed: " + + Config.getConfig(this).getUsersTable() + ".user_id)")) { throw new io.supertokens.pluginInterface.thirdparty.exception.DuplicateUserIdException(); } throw new StorageQueryException(e); @@ -851,7 +898,7 @@ public void deleteThirdPartyUser(String userId) throws StorageQueryException { @Override public io.supertokens.pluginInterface.thirdparty.UserInfo getThirdPartyUserInfoUsingId(String thirdPartyId, - String thirdPartyUserId) + String thirdPartyUserId) throws StorageQueryException { try { return ThirdPartyQueries.getThirdPartyUserInfoUsingId(this, thirdPartyId, thirdPartyUserId); @@ -873,9 +920,9 @@ public io.supertokens.pluginInterface.thirdparty.UserInfo getThirdPartyUserInfoU @Override @Deprecated public io.supertokens.pluginInterface.thirdparty.UserInfo[] getThirdPartyUsers(@NotNull String userId, - @NotNull Long timeJoined, - @NotNull Integer limit, - @NotNull String timeJoinedOrder) + @NotNull Long timeJoined, + @NotNull Integer limit, + @NotNull String timeJoinedOrder) throws StorageQueryException { try { return ThirdPartyQueries.getThirdPartyUsers(this, userId, timeJoined, limit, timeJoinedOrder); @@ -887,7 +934,7 @@ public io.supertokens.pluginInterface.thirdparty.UserInfo[] getThirdPartyUsers(@ @Override @Deprecated public io.supertokens.pluginInterface.thirdparty.UserInfo[] getThirdPartyUsers(@NotNull Integer limit, - @NotNull String timeJoinedOrder) + @NotNull String timeJoinedOrder) throws StorageQueryException { try { return ThirdPartyQueries.getThirdPartyUsers(this, limit, timeJoinedOrder); @@ -1040,7 +1087,7 @@ public io.supertokens.pluginInterface.passwordless.UserInfo getUserByPhoneNumber @Override public void createDeviceWithCode(@Nullable String email, @Nullable String phoneNumber, String linkCodeSalt, - PasswordlessCode code) + PasswordlessCode code) throws StorageQueryException, DuplicateDeviceIdHashException, DuplicateCodeIdException, DuplicateLinkCodeHashException { if (email == null && phoneNumber == null) { @@ -1191,8 +1238,8 @@ public void createUser(io.supertokens.pluginInterface.passwordless.UserInfo user .equals("[SQLITE_CONSTRAINT] Abort due to constraint violation (UNIQUE constraint failed: " + Config.getConfig(this).getPasswordlessUsersTable() + ".user_id)") || message - .equals("[SQLITE_CONSTRAINT] Abort due to constraint violation (UNIQUE constraint failed: " - + Config.getConfig(this).getUsersTable() + ".user_id)")) { + .equals("[SQLITE_CONSTRAINT] Abort due to constraint violation (UNIQUE constraint failed: " + + Config.getConfig(this).getUsersTable() + ".user_id)")) { throw new DuplicateUserIdException(); } @@ -1410,7 +1457,7 @@ public boolean createNewRoleOrDoNothingIfExists_Transaction(TransactionConnectio @Override public void addPermissionToRoleOrDoNothingIfExists_Transaction(TransactionConnection con, String role, - String permission) + String permission) throws StorageQueryException, UnknownRoleException { Connection sqlCon = (Connection) con.getConnection(); @@ -1460,7 +1507,7 @@ public boolean doesRoleExist_Transaction(TransactionConnection con, String role) @Override public void createUserIdMapping(String superTokensUserId, String externalUserId, - @Nullable String externalUserIdInfo) + @Nullable String externalUserIdInfo) throws StorageQueryException, UnknownSuperTokensUserIdException, UserIdMappingAlreadyExistsException { try { UserIdMappingQueries.createUserIdMapping(this, superTokensUserId, externalUserId, externalUserIdInfo); @@ -1530,7 +1577,7 @@ public UserIdMapping[] getUserIdMapping(String userId) throws StorageQueryExcept @Override public boolean updateOrDeleteExternalUserIdInfo(String userId, boolean isSuperTokensUserId, - @Nullable String externalUserIdInfo) throws StorageQueryException { + @Nullable String externalUserIdInfo) throws StorageQueryException { try { if (isSuperTokensUserId) { return UserIdMappingQueries.updateOrDeleteExternalUserIdInfoWithSuperTokensUserId(this, userId, @@ -1576,6 +1623,14 @@ public boolean isUserIdBeingUsedInNonAuthRecipe(String className, String userId) } } else if (className.equals(JWTRecipeStorage.class.getName())) { return false; + } else if (className.equals(TOTPStorage.class.getName())) { + try{ + TOTPDevice[] devices = TOTPQueries.getDevices(this, userId); + return devices.length > 0; + } + catch (SQLException e){ + throw new StorageQueryException(e); + } } else { throw new IllegalStateException("ClassName: " + className + " is not part of NonAuthRecipeStorage"); } @@ -1617,7 +1672,15 @@ public void addInfoToNonAuthRecipesBasedOnUserId(String className, String userId } catch (StorageTransactionLogicException e) { throw new StorageQueryException(e); } - } else if (className.equals(JWTRecipeStorage.class.getName())) { + } else if (className.equals(TOTPStorage.class.getName())) { + try { + Totp.registerDevice(this.main, userId, "testDevice", 0, 30); + } + catch (DeviceAlreadyExistsException | NoSuchAlgorithmException | FeatureNotEnabledException e) { + throw new StorageQueryException(e); + } + } + else if (className.equals(JWTRecipeStorage.class.getName())) { /* Since JWT recipe tables do not store userId we do not add any data to them */ } else { throw new IllegalStateException("ClassName: " + className + " is not part of NonAuthRecipeStorage"); @@ -1666,7 +1729,6 @@ public boolean deleteDashboardUserWithUserId(String userId) throws StorageQueryE } } - @Override public DashboardUser getDashboardUserByEmail(String email) throws StorageQueryException { try { return DashboardQueries.getDashboardUserByEmail(this, email); @@ -1675,9 +1737,8 @@ public DashboardUser getDashboardUserByEmail(String email) throws StorageQueryEx } } - @Override public void updateDashboardUsersEmailWithUserId_Transaction(TransactionConnection con, String userId, - String newEmail) + String newEmail) throws StorageQueryException, io.supertokens.pluginInterface.dashboard.exceptions.DuplicateEmailException, UserIdNotFoundException { Connection sqlCon = (Connection) con.getConnection(); @@ -1691,13 +1752,11 @@ public void updateDashboardUsersEmailWithUserId_Transaction(TransactionConnectio + Config.getConfig(this).getDashboardUsersTable() + ".email)")) { throw new io.supertokens.pluginInterface.dashboard.exceptions.DuplicateEmailException(); } - throw new StorageQueryException(e); } } - @Override public void updateDashboardUsersPasswordWithUserId_Transaction(TransactionConnection con, String userId, - String newPassword) + String newPassword) throws StorageQueryException, UserIdNotFoundException { Connection sqlCon = (Connection) con.getConnection(); try { @@ -1710,7 +1769,6 @@ public void updateDashboardUsersPasswordWithUserId_Transaction(TransactionConnec } } - @Override public DashboardSessionInfo[] getAllSessionsForUserId(String userId) throws StorageQueryException { try { return DashboardQueries.getAllSessionsForUserId(this, userId); @@ -1726,7 +1784,6 @@ public boolean revokeSessionWithSessionId(String sessionId) throws StorageQueryE } catch (SQLException e) { throw new StorageQueryException(e); } - } @Override @@ -1750,7 +1807,6 @@ public void createNewDashboardUserSession(String userId, String sessionId, long } throw new StorageQueryException(e); } - } @Override @@ -1762,7 +1818,6 @@ public void revokeExpiredSessions() throws StorageQueryException { } } - @Override public DashboardUser getDashboardUserByUserId(String userId) throws StorageQueryException { try { return DashboardQueries.getDashboardUserByUserId(this, userId); @@ -1770,4 +1825,145 @@ public DashboardUser getDashboardUserByUserId(String userId) throws StorageQuery throw new StorageQueryException(e); } } + + // TOTP recipe: + + @Override + public void createDevice(TOTPDevice device) throws StorageQueryException, DeviceAlreadyExistsException { + try { + TOTPQueries.createDevice(this, device); + } catch (StorageTransactionLogicException e) { + String message = e.actualException.getMessage(); + if (message.equals("[SQLITE_CONSTRAINT] Abort due to constraint violation (UNIQUE constraint failed: " + + Config.getConfig(this).getTotpUserDevicesTable() + ".user_id, " + + Config.getConfig(this).getTotpUserDevicesTable() + ".device_name" + ")")) { + throw new DeviceAlreadyExistsException(); + } + + throw new StorageQueryException(e.actualException); + } + } + + @Override + public void markDeviceAsVerified(String userId, String deviceName) + throws StorageQueryException, UnknownDeviceException { + try { + int matchedCount = TOTPQueries.markDeviceAsVerified(this, userId, deviceName); + if (matchedCount == 0) { + // Note matchedCount != updatedCount + throw new UnknownDeviceException(); + } + return; // Device was marked as verified + } catch (SQLException e) { + throw new StorageQueryException(e); + } + } + + @Override + public int deleteDevice_Transaction(TransactionConnection con, String userId, String deviceName) + throws StorageQueryException { + Connection sqlCon = (Connection) con.getConnection(); + try { + return TOTPQueries.deleteDevice_Transaction(this, sqlCon, userId, deviceName); + } catch (SQLException e) { + throw new StorageQueryException(e); + } + } + + @Override + public void removeUser_Transaction(TransactionConnection con, String userId) + throws StorageQueryException { + Connection sqlCon = (Connection) con.getConnection(); + try { + TOTPQueries.removeUser_Transaction(this, sqlCon, userId); + } catch (SQLException e) { + throw new StorageQueryException(e); + } + } + + @Override + public void updateDeviceName(String userId, String oldDeviceName, String newDeviceName) + throws StorageQueryException, DeviceAlreadyExistsException, + UnknownDeviceException { + try { + int updatedCount = TOTPQueries.updateDeviceName(this, userId, oldDeviceName, newDeviceName); + if (updatedCount == 0) { + throw new UnknownDeviceException(); + } + } catch (SQLException e) { + if (e.getMessage().equals( + "[SQLITE_CONSTRAINT] Abort due to constraint violation (UNIQUE constraint failed: " + + Config.getConfig(this).getTotpUserDevicesTable() + ".user_id, " + + Config.getConfig(this).getTotpUserDevicesTable() + ".device_name" + ")")) { + throw new DeviceAlreadyExistsException(); + } + } + } + + @Override + public TOTPDevice[] getDevices(String userId) + throws StorageQueryException { + try { + return TOTPQueries.getDevices(this, userId); + } catch (SQLException e) { + throw new StorageQueryException(e); + } + } + + @Override + public TOTPDevice[] getDevices_Transaction(TransactionConnection con, String userId) + throws StorageQueryException { + Connection sqlCon = (Connection) con.getConnection(); + try { + return TOTPQueries.getDevices_Transaction(this, sqlCon, userId); + } catch (SQLException e) { + throw new StorageQueryException(e); + } + } + + @Override + public void insertUsedCode_Transaction(TransactionConnection con, TOTPUsedCode usedCodeObj) + throws StorageQueryException, TotpNotEnabledException, UsedCodeAlreadyExistsException { + Connection sqlCon = (Connection) con.getConnection(); + try { + TOTPQueries.insertUsedCode_Transaction(this, sqlCon, usedCodeObj); + } catch (SQLException e) { + String message = e.getMessage(); + // No user/device exists for the given usedCodeObj.userId + + if (message + .equals("[SQLITE_CONSTRAINT] Abort due to constraint violation (FOREIGN KEY constraint failed)")) { + throw new TotpNotEnabledException(); + } + // Failed due to primary key on (userId, created_time) + if (message.equals("[SQLITE_CONSTRAINT] Abort due to constraint violation (UNIQUE constraint failed: " + + Config.getConfig(this).getTotpUsedCodesTable() + ".user_id, " + + Config.getConfig(this).getTotpUsedCodesTable() + ".created_time_ms" + ")")) { + throw new UsedCodeAlreadyExistsException(); + } + + throw new StorageQueryException(e); + } + } + + @Override + public TOTPUsedCode[] getAllUsedCodesDescOrder_Transaction(TransactionConnection con, String userId) + throws StorageQueryException { + Connection sqlCon = (Connection) con.getConnection(); + try { + return TOTPQueries.getAllUsedCodesDescOrder_Transaction(this, sqlCon, userId); + } catch (SQLException e) { + throw new StorageQueryException(e); + } + } + + @Override + public int removeExpiredCodes(long expiredBefore) + throws StorageQueryException { + try { + return TOTPQueries.removeExpiredCodes(this, expiredBefore); + } catch (SQLException e) { + throw new StorageQueryException(e); + } + } } diff --git a/src/main/java/io/supertokens/inmemorydb/config/SQLiteConfig.java b/src/main/java/io/supertokens/inmemorydb/config/SQLiteConfig.java index 75e8acd64..3ddaf0ac7 100644 --- a/src/main/java/io/supertokens/inmemorydb/config/SQLiteConfig.java +++ b/src/main/java/io/supertokens/inmemorydb/config/SQLiteConfig.java @@ -26,6 +26,10 @@ public String getUsersTable() { return "all_auth_recipe_users"; } + public String getUserLastActiveTable() { + return "user_last_active"; + } + public String getAccessTokenSigningKeysTable() { return "session_access_token_signing_keys"; } @@ -90,11 +94,23 @@ public String getUserIdMappingTable() { return "userid_mapping"; } - public String getDashboardUsersTable(){ + public String getTotpUsersTable() { + return "totp_users"; + } + + public String getTotpUserDevicesTable() { + return "totp_user_devices"; + } + + public String getTotpUsedCodesTable() { + return "totp_used_codes"; + } + + public String getDashboardUsersTable() { return "dashboard_users"; } - public String getDashboardSessionsTable(){ + public String getDashboardSessionsTable() { return "dashboard_user_sessions"; } -} \ No newline at end of file +} diff --git a/src/main/java/io/supertokens/inmemorydb/queries/ActiveUsersQueries.java b/src/main/java/io/supertokens/inmemorydb/queries/ActiveUsersQueries.java new file mode 100644 index 000000000..ec1672f27 --- /dev/null +++ b/src/main/java/io/supertokens/inmemorydb/queries/ActiveUsersQueries.java @@ -0,0 +1,68 @@ +package io.supertokens.inmemorydb.queries; + +import java.sql.SQLException; + +import io.supertokens.inmemorydb.config.Config; +import io.supertokens.pluginInterface.exceptions.StorageQueryException; +import io.supertokens.inmemorydb.Start; + +import static io.supertokens.inmemorydb.QueryExecutorTemplate.execute; +import static io.supertokens.inmemorydb.QueryExecutorTemplate.update; + +public class ActiveUsersQueries { + static String getQueryToCreateUserLastActiveTable(Start start) { + return "CREATE TABLE IF NOT EXISTS " + Config.getConfig(start).getUserLastActiveTable() + " (" + + "user_id VARCHAR(128)," + + "last_active_time BIGINT UNSIGNED," + "PRIMARY KEY(user_id)" + " );"; + } + + public static int countUsersActiveSince(Start start, long sinceTime) throws SQLException, StorageQueryException { + String QUERY = "SELECT COUNT(*) as total FROM " + Config.getConfig(start).getUserLastActiveTable() + + " WHERE last_active_time >= ?"; + + return execute(start, QUERY, pst -> pst.setLong(1, sinceTime), result -> { + if (result.next()) { + return result.getInt("total"); + } + return 0; + }); + } + + public static int countUsersEnabledTotp(Start start) throws SQLException, StorageQueryException { + String QUERY = "SELECT COUNT(*) as total FROM " + Config.getConfig(start).getTotpUsersTable(); + + return execute(start, QUERY, null, result -> { + if (result.next()) { + return result.getInt("total"); + } + return 0; + }); + } + + public static int countUsersEnabledTotpAndActiveSince(Start start, long sinceTime) throws SQLException, StorageQueryException { + String QUERY = "SELECT COUNT(*) as total FROM " + Config.getConfig(start).getTotpUsersTable() + " AS totp_users " + + "INNER JOIN " + Config.getConfig(start).getUserLastActiveTable() + " AS user_last_active " + + "ON totp_users.user_id = user_last_active.user_id " + + "WHERE user_last_active.last_active_time >= ?"; + + return execute(start, QUERY, pst -> pst.setLong(1, sinceTime), result -> { + if (result.next()) { + return result.getInt("total"); + } + return 0; + }); + } + + public static int updateUserLastActive(Start start, String userId) throws SQLException, StorageQueryException { + String QUERY = "INSERT INTO " + Config.getConfig(start).getUserLastActiveTable() + + "(user_id, last_active_time) VALUES(?, ?) ON CONFLICT(user_id) DO UPDATE SET last_active_time = ?"; + + long now = System.currentTimeMillis(); + return update(start, QUERY, pst -> { + pst.setString(1, userId); + pst.setLong(2, now); + pst.setLong(3, now); + }); + } + +} diff --git a/src/main/java/io/supertokens/inmemorydb/queries/GeneralQueries.java b/src/main/java/io/supertokens/inmemorydb/queries/GeneralQueries.java index 21863dae7..d1c30556e 100644 --- a/src/main/java/io/supertokens/inmemorydb/queries/GeneralQueries.java +++ b/src/main/java/io/supertokens/inmemorydb/queries/GeneralQueries.java @@ -94,6 +94,11 @@ public static void createTablesIfNotExists(Start start, Main main) throws SQLExc update(start, getQueryToCreateUserPaginationIndex(start), NO_OP_SETTER); } + if (!doesTableExists(start, Config.getConfig(start).getUserLastActiveTable())) { + getInstance(main).addState(CREATING_NEW_TABLE, null); + update(start, ActiveUsersQueries.getQueryToCreateUserLastActiveTable(start), NO_OP_SETTER); + } + if (!doesTableExists(start, Config.getConfig(start).getAccessTokenSigningKeysTable())) { getInstance(main).addState(CREATING_NEW_TABLE, null); update(start, getQueryToCreateAccessTokenSigningKeysTable(start), NO_OP_SETTER); @@ -186,6 +191,23 @@ public static void createTablesIfNotExists(Start start, Main main) throws SQLExc update(start, UserIdMappingQueries.getQueryToCreateUserIdMappingTable(start), NO_OP_SETTER); } + if (!doesTableExists(start, Config.getConfig(start).getTotpUsersTable())) { + getInstance(main).addState(CREATING_NEW_TABLE, null); + update(start, TOTPQueries.getQueryToCreateUsersTable(start), NO_OP_SETTER); + } + + if (!doesTableExists(start, Config.getConfig(start).getTotpUserDevicesTable())) { + getInstance(main).addState(CREATING_NEW_TABLE, null); + update(start, TOTPQueries.getQueryToCreateUserDevicesTable(start), NO_OP_SETTER); + } + + if (!doesTableExists(start, Config.getConfig(start).getTotpUsedCodesTable())) { + getInstance(main).addState(CREATING_NEW_TABLE, null); + update(start, TOTPQueries.getQueryToCreateUsedCodesTable(start), NO_OP_SETTER); + // index: + update(start, TOTPQueries.getQueryToCreateUsedCodesExpiryTimeIndex(start), NO_OP_SETTER); + } + if (!doesTableExists(start, Config.getConfig(start).getDashboardUsersTable())) { getInstance(main).addState(CREATING_NEW_TABLE, null); update(start, DashboardQueries.getQueryToCreateDashboardUsersTable(start), NO_OP_SETTER); @@ -306,8 +328,8 @@ public static long getUsersCount(Start start, RECIPE_ID[] includeRecipeIds) } public static AuthRecipeUserInfo[] getUsers(Start start, @NotNull Integer limit, @NotNull String timeJoinedOrder, - @Nullable RECIPE_ID[] includeRecipeIds, @Nullable String userId, - @Nullable Long timeJoined) + @Nullable RECIPE_ID[] includeRecipeIds, @Nullable String userId, + @Nullable Long timeJoined) throws SQLException, StorageQueryException { // This list will be used to keep track of the result's order from the db @@ -416,7 +438,7 @@ public static boolean doesUserIdExist(Start start, String userId) throws SQLExce } private static List getUserInfoForRecipeIdFromUserIds(Start start, RECIPE_ID recipeId, - List userIds) + List userIds) throws StorageQueryException, SQLException { if (recipeId == RECIPE_ID.EMAIL_PASSWORD) { return EmailPasswordQueries.getUsersInfoUsingIdList(start, userIds); diff --git a/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java b/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java new file mode 100644 index 000000000..76d2e17e3 --- /dev/null +++ b/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java @@ -0,0 +1,262 @@ +package io.supertokens.inmemorydb.queries; + +import java.sql.Connection; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.List; + +import io.supertokens.inmemorydb.Start; +import io.supertokens.inmemorydb.config.Config; +import io.supertokens.inmemorydb.ConnectionWithLocks; +import io.supertokens.pluginInterface.RowMapper; +import io.supertokens.pluginInterface.exceptions.StorageQueryException; +import io.supertokens.pluginInterface.exceptions.StorageTransactionLogicException; +import io.supertokens.pluginInterface.totp.TOTPDevice; +import io.supertokens.pluginInterface.totp.TOTPUsedCode; + +import static io.supertokens.inmemorydb.QueryExecutorTemplate.execute; +import static io.supertokens.inmemorydb.QueryExecutorTemplate.update; + +public class TOTPQueries { + public static String getQueryToCreateUsersTable(Start start) { + return "CREATE TABLE IF NOT EXISTS " + Config.getConfig(start).getTotpUsersTable() + " (" + + "user_id VARCHAR(128) NOT NULL," + + "PRIMARY KEY (user_id))"; + } + + public static String getQueryToCreateUserDevicesTable(Start start) { + return "CREATE TABLE IF NOT EXISTS " + Config.getConfig(start).getTotpUserDevicesTable() + " (" + + "user_id VARCHAR(128) NOT NULL," + "device_name VARCHAR(256) NOT NULL," + + "secret_key VARCHAR(256) NOT NULL," + + "period INTEGER NOT NULL," + "skew INTEGER NOT NULL," + "verified BOOLEAN NOT NULL," + + "PRIMARY KEY (user_id, device_name)," + + "FOREIGN KEY (user_id) REFERENCES " + Config.getConfig(start).getTotpUsersTable() + + "(user_id) ON DELETE CASCADE);"; + } + + public static String getQueryToCreateUsedCodesTable(Start start) { + return "CREATE TABLE IF NOT EXISTS " + Config.getConfig(start).getTotpUsedCodesTable() + " (" + + "user_id VARCHAR(128) NOT NULL, " + // SQLite doesn't follow VARCHAR length by default + // But we can add a check constraint to make sure the length is <= 8 + + "code VARCHAR(8) NOT NULL CHECK(LENGTH(code) <= 8)," + + "is_valid BOOLEAN NOT NULL," + + "expiry_time_ms BIGINT UNSIGNED NOT NULL," + + "created_time_ms BIGINT UNSIGNED NOT NULL," + + "PRIMARY KEY (user_id, created_time_ms)," + + "FOREIGN KEY (user_id) REFERENCES " + Config.getConfig(start).getTotpUsersTable() + + "(user_id) ON DELETE CASCADE);"; + } + + public static String getQueryToCreateUsedCodesExpiryTimeIndex(Start start) { + return "CREATE INDEX IF NOT EXISTS totp_used_codes_expiry_time_ms_index ON " + + Config.getConfig(start).getTotpUsedCodesTable() + " (expiry_time_ms)"; + } + + private static int insertUser_Transaction(Start start, Connection con, String userId) + throws SQLException, StorageQueryException { + // Create user if not exists: + String QUERY = "INSERT INTO " + Config.getConfig(start).getTotpUsersTable() + + " (user_id) VALUES (?) ON CONFLICT DO NOTHING"; + + return update(con, QUERY, pst -> pst.setString(1, userId)); + } + + private static int insertDevice_Transaction(Start start, Connection con, TOTPDevice device) + throws SQLException, StorageQueryException { + String QUERY = "INSERT INTO " + Config.getConfig(start).getTotpUserDevicesTable() + + " (user_id, device_name, secret_key, period, skew, verified) VALUES (?, ?, ?, ?, ?, ?)"; + + return update(con, QUERY, pst -> { + pst.setString(1, device.userId); + pst.setString(2, device.deviceName); + pst.setString(3, device.secretKey); + pst.setInt(4, device.period); + pst.setInt(5, device.skew); + pst.setBoolean(6, device.verified); + }); + } + + public static void createDevice(Start start, TOTPDevice device) + throws StorageQueryException, StorageTransactionLogicException { + start.startTransaction(con -> { + Connection sqlCon = (Connection) con.getConnection(); + + try { + insertUser_Transaction(start, sqlCon, device.userId); + insertDevice_Transaction(start, sqlCon, device); + sqlCon.commit(); + } catch (SQLException e) { + throw new StorageTransactionLogicException(e); + } + + return null; + }); + return; + } + + public static int markDeviceAsVerified(Start start, String userId, String deviceName) + throws StorageQueryException, SQLException { + String QUERY = "UPDATE " + Config.getConfig(start).getTotpUserDevicesTable() + + " SET verified = true WHERE user_id = ? AND device_name = ?"; + return update(start, QUERY, pst -> { + pst.setString(1, userId); + pst.setString(2, deviceName); + }); + } + + public static int deleteDevice_Transaction(Start start, Connection con, String userId, String deviceName) + throws SQLException, StorageQueryException { + String QUERY = "DELETE FROM " + Config.getConfig(start).getTotpUserDevicesTable() + + " WHERE user_id = ? AND device_name = ?;"; + + return update(con, QUERY, pst -> { + pst.setString(1, userId); + pst.setString(2, deviceName); + }); + } + + public static int removeUser_Transaction(Start start, Connection con, String userId) + throws SQLException, StorageQueryException { + String QUERY = "DELETE FROM " + Config.getConfig(start).getTotpUsersTable() + + " WHERE user_id = ?;"; + int removedUsersCount = update(con, QUERY, pst -> pst.setString(1, userId)); + + return removedUsersCount; + } + + public static int updateDeviceName(Start start, String userId, String oldDeviceName, String newDeviceName) + throws StorageQueryException, SQLException { + String QUERY = "UPDATE " + Config.getConfig(start).getTotpUserDevicesTable() + + " SET device_name = ? WHERE user_id = ? AND device_name = ?;"; + + return update(start, QUERY, pst -> { + pst.setString(1, newDeviceName); + pst.setString(2, userId); + pst.setString(3, oldDeviceName); + }); + } + + public static TOTPDevice[] getDevices(Start start, String userId) + throws StorageQueryException, SQLException { + String QUERY = "SELECT * FROM " + Config.getConfig(start).getTotpUserDevicesTable() + + " WHERE user_id = ?;"; + + return execute(start, QUERY, pst -> pst.setString(1, userId), result -> { + List devices = new ArrayList<>(); + while (result.next()) { + devices.add(TOTPDeviceRowMapper.getInstance().map(result)); + } + + return devices.toArray(TOTPDevice[]::new); + }); + } + + public static TOTPDevice[] getDevices_Transaction(Start start, Connection con, String userId) + throws StorageQueryException, SQLException { + ((ConnectionWithLocks) con).lock(userId + Config.getConfig(start).getTotpUserDevicesTable()); + String QUERY = "SELECT * FROM " + Config.getConfig(start).getTotpUserDevicesTable() + + " WHERE user_id = ?;"; + + return execute(con, QUERY, pst -> pst.setString(1, userId), result -> { + List devices = new ArrayList<>(); + while (result.next()) { + devices.add(TOTPDeviceRowMapper.getInstance().map(result)); + } + + return devices.toArray(TOTPDevice[]::new); + }); + + } + + public static int insertUsedCode_Transaction(Start start, Connection con, TOTPUsedCode code) + throws SQLException, StorageQueryException { + String QUERY = "INSERT INTO " + Config.getConfig(start).getTotpUsedCodesTable() + + " (user_id, code, is_valid, expiry_time_ms, created_time_ms) VALUES (?, ?, ?, ?, ?);"; + + return update(con, QUERY, pst -> { + pst.setString(1, code.userId); + pst.setString(2, code.code); + pst.setBoolean(3, code.isValid); + pst.setLong(4, code.expiryTime); + pst.setLong(5, code.createdTime); + }); + } + + /** + * Query to get all used codes (expired/non-expired) for a user in descending + * order of creation time. + */ + public static TOTPUsedCode[] getAllUsedCodesDescOrder_Transaction(Start start, Connection con, + String userId) + throws SQLException, StorageQueryException { + // Take a lock based on the user id: + ((ConnectionWithLocks) con).lock(userId + Config.getConfig(start).getTotpUsedCodesTable()); + + String QUERY = "SELECT * FROM " + + Config.getConfig(start).getTotpUsedCodesTable() + + " WHERE user_id = ? ORDER BY created_time_ms DESC"; + return execute(con, QUERY, pst -> { + pst.setString(1, userId); + }, result -> { + List codes = new ArrayList<>(); + while (result.next()) { + codes.add(TOTPUsedCodeRowMapper.getInstance().map(result)); + } + + return codes.toArray(TOTPUsedCode[]::new); + }); + } + + public static int removeExpiredCodes(Start start, long expiredBefore) + throws StorageQueryException, SQLException { + String QUERY = "DELETE FROM " + Config.getConfig(start).getTotpUsedCodesTable() + + " WHERE expiry_time_ms < ?;"; + + return update(start, QUERY, pst -> pst.setLong(1, expiredBefore)); + } + + private static class TOTPDeviceRowMapper implements RowMapper { + private static final TOTPDeviceRowMapper INSTANCE = new TOTPDeviceRowMapper(); + + private TOTPDeviceRowMapper() { + } + + private static TOTPDeviceRowMapper getInstance() { + return INSTANCE; + } + + @Override + public TOTPDevice map(ResultSet result) throws SQLException { + return new TOTPDevice( + result.getString("user_id"), + result.getString("device_name"), + result.getString("secret_key"), + result.getInt("period"), + result.getInt("skew"), + result.getBoolean("verified")); + } + } + + private static class TOTPUsedCodeRowMapper implements RowMapper { + private static final TOTPUsedCodeRowMapper INSTANCE = new TOTPUsedCodeRowMapper(); + + private TOTPUsedCodeRowMapper() { + } + + private static TOTPUsedCodeRowMapper getInstance() { + return INSTANCE; + } + + @Override + public TOTPUsedCode map(ResultSet result) throws SQLException { + return new TOTPUsedCode( + result.getString("user_id"), + result.getString("code"), + result.getBoolean("is_valid"), + result.getLong("expiry_time_ms"), + result.getLong("created_time_ms")); + } + } +} diff --git a/src/main/java/io/supertokens/storageLayer/StorageLayer.java b/src/main/java/io/supertokens/storageLayer/StorageLayer.java index d96435fa1..7a35d1531 100644 --- a/src/main/java/io/supertokens/storageLayer/StorageLayer.java +++ b/src/main/java/io/supertokens/storageLayer/StorageLayer.java @@ -23,6 +23,7 @@ import io.supertokens.exceptions.QuitProgramException; import io.supertokens.inmemorydb.Start; import io.supertokens.output.Logging; +import io.supertokens.pluginInterface.ActiveUsersStorage; import io.supertokens.pluginInterface.STORAGE_TYPE; import io.supertokens.pluginInterface.Storage; import io.supertokens.pluginInterface.authRecipe.AuthRecipeStorage; @@ -34,6 +35,7 @@ import io.supertokens.pluginInterface.passwordless.sqlStorage.PasswordlessSQLStorage; import io.supertokens.pluginInterface.session.SessionStorage; import io.supertokens.pluginInterface.thirdparty.sqlStorage.ThirdPartySQLStorage; +import io.supertokens.pluginInterface.totp.sqlStorage.TOTPSQLStorage; import io.supertokens.pluginInterface.useridmapping.UserIdMappingStorage; import io.supertokens.pluginInterface.usermetadata.sqlStorage.UserMetadataSQLStorage; import io.supertokens.pluginInterface.userroles.sqlStorage.UserRolesSQLStorage; @@ -174,6 +176,14 @@ public static AuthRecipeStorage getAuthRecipeStorage(Main main) { return (AuthRecipeStorage) getInstance(main).storage; } + public static ActiveUsersStorage getActiveUsersStorage(Main main) { + if (getInstance(main) == null) { + throw new QuitProgramException("please call init() before calling getStorageLayer"); + } + + return (ActiveUsersStorage) getInstance(main).storage; + } + public static SessionStorage getSessionStorage(Main main) { if (getInstance(main) == null) { throw new QuitProgramException("please call init() before calling getStorageLayer"); @@ -266,6 +276,17 @@ public static UserIdMappingStorage getUserIdMappingStorage(Main main) { return (UserIdMappingStorage) getInstance(main).storage; } + public static TOTPSQLStorage getTOTPStorage(Main main) { + if (getInstance(main) == null) { + throw new QuitProgramException("please call init() before calling getStorageLayer"); + } + if (getInstance(main).storage.getType() != STORAGE_TYPE.SQL) { + // we only support SQL for now + throw new UnsupportedOperationException(""); + } + return (TOTPSQLStorage) getInstance(main).storage; + } + public static DashboardSQLStorage getDashboardStorage(Main main) { if (getInstance(main) == null) { throw new QuitProgramException("please call init() before calling getStorageLayer"); diff --git a/src/main/java/io/supertokens/totp/Totp.java b/src/main/java/io/supertokens/totp/Totp.java new file mode 100644 index 000000000..9e5d43ace --- /dev/null +++ b/src/main/java/io/supertokens/totp/Totp.java @@ -0,0 +1,402 @@ +package io.supertokens.totp; + +import java.security.InvalidKeyException; +import java.security.Key; +import java.security.NoSuchAlgorithmException; +import java.time.Duration; +import java.time.Instant; +import java.util.Arrays; + +import javax.crypto.KeyGenerator; +import javax.crypto.spec.SecretKeySpec; + +import io.supertokens.Main; +import io.supertokens.config.Config; + +import com.eatthepath.otp.TimeBasedOneTimePasswordGenerator; + +import io.supertokens.featureflag.EE_FEATURES; +import io.supertokens.featureflag.FeatureFlag; +import io.supertokens.featureflag.exceptions.FeatureNotEnabledException; +import io.supertokens.pluginInterface.exceptions.StorageQueryException; +import io.supertokens.pluginInterface.exceptions.StorageTransactionLogicException; +import io.supertokens.pluginInterface.totp.TOTPDevice; +import io.supertokens.pluginInterface.totp.TOTPStorage; +import io.supertokens.pluginInterface.totp.TOTPUsedCode; +import io.supertokens.pluginInterface.totp.exception.DeviceAlreadyExistsException; +import io.supertokens.pluginInterface.totp.exception.TotpNotEnabledException; +import io.supertokens.pluginInterface.totp.exception.UnknownDeviceException; +import io.supertokens.pluginInterface.totp.exception.UsedCodeAlreadyExistsException; +import io.supertokens.pluginInterface.totp.sqlStorage.TOTPSQLStorage; +import io.supertokens.storageLayer.StorageLayer; +import io.supertokens.totp.exceptions.InvalidTotpException; +import io.supertokens.totp.exceptions.LimitReachedException; +import org.apache.commons.codec.binary.Base32; +import org.jetbrains.annotations.TestOnly; + +import static io.supertokens.featureflag.FeatureFlag.getInstance; + +public class Totp { + private static String generateSecret() throws NoSuchAlgorithmException { + // Reference: https://github.com/jchambers/java-otp + final String TOTP_ALGORITHM = "HmacSHA1"; + + final KeyGenerator keyGenerator = KeyGenerator.getInstance(TOTP_ALGORITHM); + keyGenerator.init(160); // 160 bits = 20 bytes + + return new Base32().encodeToString(keyGenerator.generateKey().getEncoded()); + } + + private static boolean checkCode(TOTPDevice device, String code) { + final TimeBasedOneTimePasswordGenerator totp = new TimeBasedOneTimePasswordGenerator( + Duration.ofSeconds(device.period), 6); + + byte[] keyBytes = new Base32().decode(device.secretKey); + Key key = new SecretKeySpec(keyBytes, "HmacSHA1"); + + final int period = device.period; + final int skew = device.skew; + + // Check if code is valid for any of the time periods in the skew: + for (int i = -skew; i <= skew; i++) { + try { + if (totp.generateOneTimePasswordString(key, Instant.now().plusSeconds(i * period)).equals(code)) { + return true; + } + } catch (InvalidKeyException e) { + // This should never happen because we are always using a valid secretKey. + return false; + } + } + + return false; + } + + private static boolean isTotpEnabled(Main main) throws StorageQueryException { + EE_FEATURES[] features = FeatureFlag.getInstance(main).getEnabledFeatures(); + for (EE_FEATURES f : features) { + if (f == EE_FEATURES.TOTP) { + return true; + } + } + return false; + } + + + public static TOTPDevice registerDevice(Main main, String userId, String deviceName, int skew, int period) + throws StorageQueryException, DeviceAlreadyExistsException, NoSuchAlgorithmException, + FeatureNotEnabledException { + + if (!isTotpEnabled(main)){ + throw new FeatureNotEnabledException( + "TOTP feature is not enabled. Please subscribe to a SuperTokens core license key to enable this feature."); + } + + TOTPSQLStorage totpStorage = StorageLayer.getTOTPStorage(main); + + String secret = generateSecret(); + TOTPDevice device = new TOTPDevice(userId, deviceName, secret, period, skew, false); + totpStorage.createDevice(device); + + return device; + } + + private static void checkAndStoreCode(Main main, TOTPStorage totpStorage, String userId, TOTPDevice[] devices, + String code) + throws InvalidTotpException, TotpNotEnabledException, + LimitReachedException, StorageQueryException, StorageTransactionLogicException { + // Note that the TOTP cron runs every 1 hour, so all the expired tokens can stay + // in the db for at max 1 hour after expiry. + + // If we filter expired codes in rate limiting logic, then because of + // differences in expiry time of different codes, we might end up with a + // situation where: + // Case 1: users might get released from the rate limiting too early because of + // some invalid codes in the checking window were expired. + // Case 2: users might face random rate limiting because if some valid codes + // expire and if it leads to N contagious invalid + // codes, then the user will be rate limited for no reason. + + // For examaple, assume 0 means expired; 1 means non-expired: + // Also, assume that totp_max_attempts is 3, totp_rate_limit_cooldown_time is + // 15 minutes, and totp_invalid_code_expiry is 5 minutes. + + // Example for Case 1: + // User is rate limited and the used codes are like this: [1, 1, 0, 0, 0]. Now + // if 1st zero (invalid code) expires in 5 mins and we filter + // out expired codes, we'll end up [1, 1, 0, 0]. This doesn't contain 3 + // contiguous invalid codes, so the user will be released from rate limiting in + // 5 minutes instead of 15 minutes. + + // Example for Case 2: + // User has used codes like this: [0, 1, 0, 0]. + // The 1st one (valid code) will expire in merely 1.5 minutes (assuming skew = 2 + // and period = 30s). So now if we filter out expired codes, we'll see + // [0, 0, 0] and this contains 3 contagious invalid codes, so now the user will + // be rate limited for no reason. + + // That's why we need to fetch all the codes (expired + non-expired). + // TOTPUsedCode[] usedCodes = + + TOTPSQLStorage totpSQLStorage = (TOTPSQLStorage) totpStorage; + + while (true) { + try { + totpSQLStorage.startTransaction(con -> { + TOTPUsedCode[] usedCodes = totpSQLStorage.getAllUsedCodesDescOrder_Transaction(con, + userId); + + // N represents # of invalid attempts that will trigger rate limiting: + int N = Config.getConfig(main).getTotpMaxAttempts(); // (Default 5) + // Count # of contiguous invalids in latest N attempts (stop at first valid): + long invalidOutOfN = Arrays.stream(usedCodes).limit(N).takeWhile(usedCode -> !usedCode.isValid) + .count(); + int rateLimitResetTimeInMs = Config.getConfig(main).getTotpRateLimitCooldownTimeSec() * 1000; // (Default + // 15 mins) + + // Check if the user has been rate limited: + if (invalidOutOfN == N) { + // All of the latest N attempts were invalid: + long latestInvalidCodeCreatedTime = usedCodes[0].createdTime; + long now = System.currentTimeMillis(); + + if (now - latestInvalidCodeCreatedTime < rateLimitResetTimeInMs) { + // Less than rateLimitResetTimeInMs (default = 15 mins) time has elasped since + // the last invalid code: + long timeLeftMs = (rateLimitResetTimeInMs - (now - latestInvalidCodeCreatedTime)); + throw new StorageTransactionLogicException(new LimitReachedException(timeLeftMs)); + + // If we insert the used code here, then it will further delay the user from + // being able to login. So not inserting it here. + } + } + + // Check if the code is valid for any device: + boolean isValid = false; + TOTPDevice matchingDevice = null; + for (TOTPDevice device : devices) { + // Check if the code is valid for this device: + if (checkCode(device, code)) { + isValid = true; + matchingDevice = device; + break; + } + } + + // Check if the code has been previously used by the user and it was valid (and + // is still valid). If so, this could be a replay attack. So reject it. + if (isValid) { + for (TOTPUsedCode usedCode : usedCodes) { + // One edge case is that if the user has 2 devices, and they are used back to + // back (within 90 seconds) such that the code of the first device was + // regenerated by the second device, then it won't allow the second device's + // code to be used until it is expired. + // But this would be rare so we can ignore it for now. + if (usedCode.code.equals(code) && usedCode.isValid + && usedCode.expiryTime > System.currentTimeMillis()) { + isValid = false; + // We found a matching device but the code + // will be considered invalid here. + } + } + } + + // Insert the code into the list of used codes: + + // If device is found, calculate used code expiry time for that device (based on + // its period and skew). Otherwise, use the max used code expiry time of all the + // devices. + int maxUsedCodeExpiry = Arrays.stream(devices) + .mapToInt(device -> device.period * (2 * device.skew + 1)) + .max() + .orElse(0); + int expireInSec = (matchingDevice != null) ? matchingDevice.period * (2 * matchingDevice.skew + 1) + : maxUsedCodeExpiry; + + long now = System.currentTimeMillis(); + TOTPUsedCode newCode = new TOTPUsedCode(userId, code, isValid, now + 1000 * expireInSec, now); + try { + totpSQLStorage.insertUsedCode_Transaction(con, newCode); + totpSQLStorage.commitTransaction(con); + } catch (UsedCodeAlreadyExistsException | TotpNotEnabledException e) { + throw new StorageTransactionLogicException(e); + } + + if (!isValid) { + // transaction has been committed, so we can directly throw the exception: + throw new StorageTransactionLogicException(new InvalidTotpException()); + } + + return null; + }); + return; // exit the while loop + } catch (StorageTransactionLogicException e) { + // throwing errors will also help exit the while loop: + if (e.actualException instanceof LimitReachedException) { + throw (LimitReachedException) e.actualException; + } else if (e.actualException instanceof InvalidTotpException) { + throw (InvalidTotpException) e.actualException; + } else if (e.actualException instanceof TotpNotEnabledException) { + throw (TotpNotEnabledException) e.actualException; + } else if (e.actualException instanceof UsedCodeAlreadyExistsException) { + // retry the transaction after a small delay: + int delayInMs = (int) (Math.random() * 10 + 1); + try { + Thread.sleep(delayInMs); + continue; + } catch (InterruptedException err) { + // ignore the error and retry + continue; + } + } else { + throw e; + } + } + } + } + + public static boolean verifyDevice(Main main, String userId, String deviceName, String code) + throws TotpNotEnabledException, UnknownDeviceException, InvalidTotpException, + LimitReachedException, StorageQueryException, StorageTransactionLogicException { + // Here boolean return value tells whether the device has been + // newly verified (true) OR it was already verified (false) + + TOTPSQLStorage totpStorage = StorageLayer.getTOTPStorage(main); + TOTPDevice matchingDevice = null; + + // Here one race condition is that the same device + // is to be verified twice in parallel. In that case, + // both the API calls will return true, but that's okay. + + // Check if the user has any devices: + TOTPDevice[] devices = totpStorage.getDevices(userId); + if (devices.length == 0) { + throw new TotpNotEnabledException(); + } + + // Check if the requested device exists: + for (TOTPDevice device : devices) { + if (device.deviceName.equals(deviceName)) { + matchingDevice = device; + if (device.verified) { + return false; // Was already verified + } + break; + } + } + + // No device found: + if (matchingDevice == null) { + throw new UnknownDeviceException(); + } + + // At this point, even if device is suddenly deleted/renamed by another API + // call. We will still check the code against the new set of devices and store + // it in the used codes table. However, the device will not be marked as + // verified in the devices table (because it was deleted/renamed). So the user + // gets a UnknownDevceException. + // This behaviour is okay so we can ignore it. + checkAndStoreCode(main, totpStorage, userId, new TOTPDevice[] { matchingDevice }, code); + // Will reach here only if the code is valid: + totpStorage.markDeviceAsVerified(userId, deviceName); + return true; // Newly verified + } + + public static void verifyCode(Main main, String userId, String code, boolean allowUnverifiedDevices) + throws TotpNotEnabledException, InvalidTotpException, LimitReachedException, + StorageQueryException, StorageTransactionLogicException, FeatureNotEnabledException { + + if (!isTotpEnabled(main)){ + throw new FeatureNotEnabledException( + "TOTP feature is not enabled. Please subscribe to a SuperTokens core license key to enable this feature."); + } + + TOTPSQLStorage totpStorage = StorageLayer.getTOTPStorage(main); + + // Check if the user has any devices: + TOTPDevice[] devices = totpStorage.getDevices(userId); + if (devices.length == 0) { + throw new TotpNotEnabledException(); + } + + // Filter out unverified devices: + if (!allowUnverifiedDevices) { + devices = Arrays.stream(devices).filter(device -> device.verified).toArray(TOTPDevice[]::new); + } + + // At this point, even if some of the devices are suddenly deleted/renamed by + // another API call. We will still check the code against the updated set of + // devices and store it in the used codes table. This behaviour is okay so we + // can ignore it. + checkAndStoreCode(main, totpStorage, userId, devices, code); + } + + /** Delete device and also delete the user if deleting the last device */ + public static void removeDevice(Main main, String userId, String deviceName) + throws StorageQueryException, UnknownDeviceException, TotpNotEnabledException, + StorageTransactionLogicException { + TOTPSQLStorage storage = StorageLayer.getTOTPStorage(main); + + try { + storage.startTransaction(con -> { + int deletedCount = storage.deleteDevice_Transaction(con, userId, deviceName); + if (deletedCount == 0) { + throw new StorageTransactionLogicException(new UnknownDeviceException()); + } + + // Some device(s) were deleted. Check if user has any other device left: + // This also takes a lock on the user devices. + TOTPDevice[] devices = storage.getDevices_Transaction(con, userId); + if (devices.length == 0) { + // no device left. delete user + storage.removeUser_Transaction(con, userId); + } + + storage.commitTransaction(con); + return null; + }); + return; + } catch (StorageTransactionLogicException e) { + if (e.actualException instanceof UnknownDeviceException) { + // Check if any device exists for the user: + TOTPDevice[] devices = storage.getDevices(userId); + if (devices.length == 0) { + throw new TotpNotEnabledException(); + } + + throw (UnknownDeviceException) e.actualException; + } + + throw e; + } + } + + public static void updateDeviceName(Main main, String userId, String oldDeviceName, String newDeviceName) + throws StorageQueryException, DeviceAlreadyExistsException, UnknownDeviceException, + TotpNotEnabledException { + TOTPSQLStorage totpStorage = StorageLayer.getTOTPStorage(main); + try { + totpStorage.updateDeviceName(userId, oldDeviceName, newDeviceName); + } catch (UnknownDeviceException e) { + // Check if any device exists for the user: + TOTPDevice[] devices = totpStorage.getDevices(userId); + if (devices.length == 0) { + throw new TotpNotEnabledException(); + } else { + throw e; + } + } + } + + public static TOTPDevice[] getDevices(Main main, String userId) + throws StorageQueryException, TotpNotEnabledException { + TOTPSQLStorage totpStorage = StorageLayer.getTOTPStorage(main); + + TOTPDevice[] devices = totpStorage.getDevices(userId); + if (devices.length == 0) { + throw new TotpNotEnabledException(); + } + return devices; + } + +} diff --git a/src/main/java/io/supertokens/totp/exceptions/InvalidTotpException.java b/src/main/java/io/supertokens/totp/exceptions/InvalidTotpException.java new file mode 100644 index 000000000..9dce2f51d --- /dev/null +++ b/src/main/java/io/supertokens/totp/exceptions/InvalidTotpException.java @@ -0,0 +1,5 @@ +package io.supertokens.totp.exceptions; + +public class InvalidTotpException extends Exception { + +} diff --git a/src/main/java/io/supertokens/totp/exceptions/LimitReachedException.java b/src/main/java/io/supertokens/totp/exceptions/LimitReachedException.java new file mode 100644 index 000000000..b7b1c8078 --- /dev/null +++ b/src/main/java/io/supertokens/totp/exceptions/LimitReachedException.java @@ -0,0 +1,11 @@ +package io.supertokens.totp.exceptions; + +public class LimitReachedException extends Exception { + + public long retryAfterMs; + + public LimitReachedException(long retryAfterMs) { + super("Retry in " + retryAfterMs + " ms"); + this.retryAfterMs = retryAfterMs; + } +} diff --git a/src/main/java/io/supertokens/useridmapping/UserIdMapping.java b/src/main/java/io/supertokens/useridmapping/UserIdMapping.java index 61faddad4..b8a12e476 100644 --- a/src/main/java/io/supertokens/useridmapping/UserIdMapping.java +++ b/src/main/java/io/supertokens/useridmapping/UserIdMapping.java @@ -22,6 +22,7 @@ import io.supertokens.pluginInterface.exceptions.StorageQueryException; import io.supertokens.pluginInterface.jwt.JWTRecipeStorage; import io.supertokens.pluginInterface.session.SessionStorage; +import io.supertokens.pluginInterface.totp.TOTPStorage; import io.supertokens.pluginInterface.useridmapping.UserIdMappingStorage; import io.supertokens.pluginInterface.useridmapping.exception.UnknownSuperTokensUserIdException; import io.supertokens.pluginInterface.useridmapping.exception.UserIdMappingAlreadyExistsException; @@ -38,7 +39,8 @@ public class UserIdMapping { public static void createUserIdMapping(Main main, String superTokensUserId, String externalUserId, - String externalUserIdInfo, boolean force) throws UnknownSuperTokensUserIdException, + String externalUserIdInfo, boolean force) + throws UnknownSuperTokensUserIdException, UserIdMappingAlreadyExistsException, StorageQueryException, ServletException { // if a userIdMapping is created with force, then we skip the following checks if (!force) { @@ -64,7 +66,8 @@ public static void createUserIdMapping(Main main, String superTokensUserId, Stri } public static io.supertokens.pluginInterface.useridmapping.UserIdMapping getUserIdMapping(Main main, String userId, - UserIdType userIdType) throws StorageQueryException { + UserIdType userIdType) + throws StorageQueryException { UserIdMappingStorage storage = StorageLayer.getUserIdMappingStorage(main); if (userIdType == UserIdType.SUPERTOKENS) { @@ -139,7 +142,8 @@ public static boolean deleteUserIdMapping(Main main, String userId, UserIdType u } public static boolean updateOrDeleteExternalUserIdInfo(Main main, String userId, UserIdType userIdType, - @Nullable String externalUserIdInfo) throws StorageQueryException { + @Nullable String externalUserIdInfo) + throws StorageQueryException { UserIdMappingStorage storage = StorageLayer.getUserIdMappingStorage(main); if (userIdType == UserIdType.SUPERTOKENS) { @@ -192,6 +196,13 @@ private static void assertThatUserIdIsNotBeingUsedInNonAuthRecipes(Main main, St new WebserverAPI.BadRequestException("UserId is already in use in EmailVerification recipe")); } } + { + if (StorageLayer.getStorage(main).isUserIdBeingUsedInNonAuthRecipe(TOTPStorage.class.getName(), + userId)) { + throw new ServletException( + new WebserverAPI.BadRequestException("UserId is already in use in TOTP recipe")); + } + } { if (StorageLayer.getStorage(main).isUserIdBeingUsedInNonAuthRecipe(JWTRecipeStorage.class.getName(), userId)) { diff --git a/src/main/java/io/supertokens/webserver/InputParser.java b/src/main/java/io/supertokens/webserver/InputParser.java index 23be846bc..98a1db6d8 100644 --- a/src/main/java/io/supertokens/webserver/InputParser.java +++ b/src/main/java/io/supertokens/webserver/InputParser.java @@ -86,6 +86,21 @@ public static Integer getIntQueryParamOrThrowError(HttpServletRequest request, S } } + public static Long getLongQueryParamOrThrowError(HttpServletRequest request, String fieldName, boolean nullable) + throws ServletException { + String key = getQueryParamOrThrowError(request, fieldName, nullable); + if (key == null && nullable) { + return null; + } + try { + assert key != null; + return Long.parseLong(key); + } catch (Exception e) { + throw new ServletException(new WebserverAPI.BadRequestException( + "Field name '" + fieldName + "' must be a long in the GET request")); + } + } + public static JsonObject parseJsonObjectOrThrowError(JsonObject element, String fieldName, boolean nullable) throws ServletException { try { @@ -197,4 +212,28 @@ public static Long parseLongOrThrowError(JsonObject element, String fieldName, b } } + + + public static Integer parseIntOrThrowError(JsonObject element, String fieldName, boolean nullable) + throws ServletException { + try { + if (nullable && element.get(fieldName) == null) { + return null; + + } + String stringified = element.toString(); + if (!stringified.contains("\"")) { + throw new Exception(); + + } + return element.get(fieldName).getAsInt(); + + } catch (Exception e) { + throw new ServletException( + new WebserverAPI.BadRequestException("Field name '" + fieldName + "' is invalid in JSON input")); + + } + + } + } diff --git a/src/main/java/io/supertokens/webserver/Webserver.java b/src/main/java/io/supertokens/webserver/Webserver.java index 68bfebcff..216996de2 100644 --- a/src/main/java/io/supertokens/webserver/Webserver.java +++ b/src/main/java/io/supertokens/webserver/Webserver.java @@ -45,6 +45,11 @@ import io.supertokens.webserver.api.session.*; import io.supertokens.webserver.api.thirdparty.GetUsersByEmailAPI; import io.supertokens.webserver.api.thirdparty.SignInUpAPI; +import io.supertokens.webserver.api.totp.CreateOrUpdateTotpDeviceAPI; +import io.supertokens.webserver.api.totp.GetTotpDevicesAPI; +import io.supertokens.webserver.api.totp.RemoveTotpDeviceAPI; +import io.supertokens.webserver.api.totp.VerifyTotpAPI; +import io.supertokens.webserver.api.totp.VerifyTotpDeviceAPI; import io.supertokens.webserver.api.useridmapping.RemoveUserIdMappingAPI; import io.supertokens.webserver.api.useridmapping.UpdateExternalUserIdInfoAPI; import io.supertokens.webserver.api.useridmapping.UserIdMappingAPI; @@ -233,6 +238,7 @@ private void setupRoutes() throws Exception { addAPI(new ConsumeCodeAPI(main)); addAPI(new TelemetryAPI(main)); addAPI(new UsersCountAPI(main)); + addAPI(new ActiveUsersCountAPI(main)); addAPI(new UsersAPI(main)); addAPI(new DeleteUserAPI(main)); addAPI(new RevokeAllTokensForUserAPI(main)); @@ -253,6 +259,11 @@ private void setupRoutes() throws Exception { addAPI(new GetRolesAPI(main)); addAPI(new UserIdMappingAPI(main)); addAPI(new RemoveUserIdMappingAPI(main)); + addAPI(new CreateOrUpdateTotpDeviceAPI(main)); + addAPI(new VerifyTotpDeviceAPI(main)); + addAPI(new VerifyTotpAPI(main)); + addAPI(new RemoveTotpDeviceAPI(main)); + addAPI(new GetTotpDevicesAPI(main)); addAPI(new UpdateExternalUserIdInfoAPI(main)); addAPI(new ImportUserWithPasswordHashAPI(main)); addAPI(new LicenseKeyAPI(main)); diff --git a/src/main/java/io/supertokens/webserver/WebserverAPI.java b/src/main/java/io/supertokens/webserver/WebserverAPI.java index 6c4a025e8..ece9509e8 100644 --- a/src/main/java/io/supertokens/webserver/WebserverAPI.java +++ b/src/main/java/io/supertokens/webserver/WebserverAPI.java @@ -51,10 +51,11 @@ public abstract class WebserverAPI extends HttpServlet { supportedVersions.add("2.16"); supportedVersions.add("2.17"); supportedVersions.add("2.18"); + supportedVersions.add("2.19"); } public static String getLatestCDIVersion() { - return "2.18"; + return "2.19"; } public WebserverAPI(Main main, String rid) { diff --git a/src/main/java/io/supertokens/webserver/api/core/ActiveUsersCountAPI.java b/src/main/java/io/supertokens/webserver/api/core/ActiveUsersCountAPI.java new file mode 100644 index 000000000..73efb131a --- /dev/null +++ b/src/main/java/io/supertokens/webserver/api/core/ActiveUsersCountAPI.java @@ -0,0 +1,62 @@ +/* + * Copyright (c) 2023, VRAI Labs and/or its affiliates. All rights reserved. + * + * This software is licensed under the Apache License, Version 2.0 (the + * "License") as published by the Apache Software Foundation. + * + * You may not use this file except in compliance with the License. You may + * obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.supertokens.webserver.api.core; + +import com.google.gson.JsonObject; +import io.supertokens.ActiveUsers; +import io.supertokens.Main; +import io.supertokens.pluginInterface.exceptions.StorageQueryException; +import io.supertokens.webserver.InputParser; +import io.supertokens.webserver.WebserverAPI; +import jakarta.servlet.ServletException; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; + +import java.io.IOException; + +public class ActiveUsersCountAPI extends WebserverAPI { + private static final long serialVersionUID = -2225750492558064634L; + + public ActiveUsersCountAPI(Main main) { + super(main, ""); + } + + @Override + public String getPath() { + return "/users/count/active"; + } + + @Override + protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws IOException, ServletException { + Long sinceTimestamp = InputParser.getLongQueryParamOrThrowError(req, "since", false); + + if (sinceTimestamp < 0) { + throw new ServletException(new BadRequestException("'since' query parameter must be >= 0")); + } + + try { + int count = ActiveUsers.countUsersActiveSince(main, sinceTimestamp); + JsonObject result = new JsonObject(); + result.addProperty("status", "OK"); + result.addProperty("count", count); + super.sendJsonResponse(200, result, resp); + } catch (StorageQueryException e) { + throw new ServletException(e); + } + } +} + diff --git a/src/main/java/io/supertokens/webserver/api/core/DeleteUserAPI.java b/src/main/java/io/supertokens/webserver/api/core/DeleteUserAPI.java index bbb9ea916..26c1dc0d5 100644 --- a/src/main/java/io/supertokens/webserver/api/core/DeleteUserAPI.java +++ b/src/main/java/io/supertokens/webserver/api/core/DeleteUserAPI.java @@ -28,6 +28,7 @@ import io.supertokens.authRecipe.AuthRecipe; import io.supertokens.pluginInterface.RECIPE_ID; import io.supertokens.pluginInterface.exceptions.StorageQueryException; +import io.supertokens.pluginInterface.exceptions.StorageTransactionLogicException; import io.supertokens.webserver.InputParser; import io.supertokens.webserver.WebserverAPI; @@ -53,7 +54,7 @@ protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws I JsonObject result = new JsonObject(); result.addProperty("status", "OK"); super.sendJsonResponse(200, result, resp); - } catch (StorageQueryException e) { + } catch (StorageQueryException | StorageTransactionLogicException e) { throw new ServletException(e); } } diff --git a/src/main/java/io/supertokens/webserver/api/emailpassword/SignInAPI.java b/src/main/java/io/supertokens/webserver/api/emailpassword/SignInAPI.java index 7d09ab04f..3b598601c 100644 --- a/src/main/java/io/supertokens/webserver/api/emailpassword/SignInAPI.java +++ b/src/main/java/io/supertokens/webserver/api/emailpassword/SignInAPI.java @@ -19,6 +19,8 @@ import com.google.gson.Gson; import com.google.gson.JsonObject; import com.google.gson.JsonParser; + +import io.supertokens.ActiveUsers; import io.supertokens.Main; import io.supertokens.emailpassword.EmailPassword; import io.supertokens.emailpassword.exceptions.WrongCredentialsException; @@ -65,6 +67,8 @@ protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws I try { UserInfo user = EmailPassword.signIn(super.main, normalisedEmail, password); + ActiveUsers.updateLastActive(main, user.id); // use the internal user id + // if a userIdMapping exists, pass the externalUserId to the response UserIdMapping userIdMapping = io.supertokens.useridmapping.UserIdMapping.getUserIdMapping(super.main, user.id, UserIdType.ANY); diff --git a/src/main/java/io/supertokens/webserver/api/emailpassword/SignUpAPI.java b/src/main/java/io/supertokens/webserver/api/emailpassword/SignUpAPI.java index 45310c081..6c4ed02b9 100644 --- a/src/main/java/io/supertokens/webserver/api/emailpassword/SignUpAPI.java +++ b/src/main/java/io/supertokens/webserver/api/emailpassword/SignUpAPI.java @@ -19,6 +19,8 @@ import com.google.gson.Gson; import com.google.gson.JsonObject; import com.google.gson.JsonParser; + +import io.supertokens.ActiveUsers; import io.supertokens.Main; import io.supertokens.emailpassword.EmailPassword; import io.supertokens.output.Logging; @@ -67,6 +69,8 @@ protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws I try { UserInfo user = EmailPassword.signUp(super.main, normalisedEmail, password); + ActiveUsers.updateLastActive(main, user.id); + JsonObject result = new JsonObject(); result.addProperty("status", "OK"); JsonObject userJson = new JsonParser().parse(new Gson().toJson(user)).getAsJsonObject(); diff --git a/src/main/java/io/supertokens/webserver/api/passwordless/ConsumeCodeAPI.java b/src/main/java/io/supertokens/webserver/api/passwordless/ConsumeCodeAPI.java index fb508e230..7c37c36a9 100644 --- a/src/main/java/io/supertokens/webserver/api/passwordless/ConsumeCodeAPI.java +++ b/src/main/java/io/supertokens/webserver/api/passwordless/ConsumeCodeAPI.java @@ -19,6 +19,8 @@ import com.google.gson.Gson; import com.google.gson.JsonObject; import com.google.gson.JsonParser; + +import io.supertokens.ActiveUsers; import io.supertokens.Main; import io.supertokens.passwordless.Passwordless; import io.supertokens.passwordless.Passwordless.ConsumeCodeResponse; @@ -81,6 +83,8 @@ protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws I ConsumeCodeResponse consumeCodeResponse = Passwordless.consumeCode(main, deviceId, deviceIdHash, userInputCode, linkCode); + ActiveUsers.updateLastActive(main, consumeCodeResponse.user.id); + UserIdMapping userIdMapping = io.supertokens.useridmapping.UserIdMapping.getUserIdMapping(main, consumeCodeResponse.user.id, UserIdType.ANY); if (userIdMapping != null) { diff --git a/src/main/java/io/supertokens/webserver/api/session/RefreshSessionAPI.java b/src/main/java/io/supertokens/webserver/api/session/RefreshSessionAPI.java index 60308d563..20cf5001f 100644 --- a/src/main/java/io/supertokens/webserver/api/session/RefreshSessionAPI.java +++ b/src/main/java/io/supertokens/webserver/api/session/RefreshSessionAPI.java @@ -16,25 +16,28 @@ package io.supertokens.webserver.api.session; -import com.google.gson.Gson; import com.google.gson.JsonObject; -import com.google.gson.JsonParser; +import io.supertokens.ActiveUsers; import io.supertokens.Main; import io.supertokens.exceptions.TokenTheftDetectedException; import io.supertokens.exceptions.UnauthorisedException; import io.supertokens.output.Logging; import io.supertokens.pluginInterface.RECIPE_ID; +import io.supertokens.pluginInterface.STORAGE_TYPE; import io.supertokens.pluginInterface.exceptions.StorageQueryException; import io.supertokens.pluginInterface.exceptions.StorageTransactionLogicException; +import io.supertokens.pluginInterface.useridmapping.UserIdMapping; import io.supertokens.session.Session; import io.supertokens.session.info.SessionInformationHolder; +import io.supertokens.storageLayer.StorageLayer; +import io.supertokens.useridmapping.UserIdType; import io.supertokens.utils.Utils; import io.supertokens.webserver.InputParser; import io.supertokens.webserver.WebserverAPI; - import jakarta.servlet.ServletException; import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletResponse; + import java.io.IOException; public class RefreshSessionAPI extends WebserverAPI { @@ -61,6 +64,22 @@ protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws I try { SessionInformationHolder sessionInfo = Session.refreshSession(main, refreshToken, antiCsrfToken, enableAntiCsrf); + + if (StorageLayer.getStorage(main).getType() == STORAGE_TYPE.SQL) { + try { + UserIdMapping userIdMapping = io.supertokens.useridmapping.UserIdMapping.getUserIdMapping( + super.main, + sessionInfo.session.userId, UserIdType.ANY); + if (userIdMapping != null) { + ActiveUsers.updateLastActive(main, userIdMapping.superTokensUserId); + } else { + ActiveUsers.updateLastActive(main, sessionInfo.session.userId); + } + } catch (StorageQueryException ignored) { + } + } + + JsonObject result = sessionInfo.toJsonObject(); result.addProperty("status", "OK"); super.sendJsonResponse(200, result, resp); diff --git a/src/main/java/io/supertokens/webserver/api/session/SessionAPI.java b/src/main/java/io/supertokens/webserver/api/session/SessionAPI.java index 8ef9a0441..6a4273ec8 100644 --- a/src/main/java/io/supertokens/webserver/api/session/SessionAPI.java +++ b/src/main/java/io/supertokens/webserver/api/session/SessionAPI.java @@ -19,28 +19,32 @@ import com.google.gson.Gson; import com.google.gson.JsonArray; import com.google.gson.JsonObject; -import com.google.gson.JsonParser; +import io.supertokens.ActiveUsers; import io.supertokens.Main; import io.supertokens.exceptions.UnauthorisedException; import io.supertokens.output.Logging; import io.supertokens.pluginInterface.RECIPE_ID; +import io.supertokens.pluginInterface.STORAGE_TYPE; import io.supertokens.pluginInterface.exceptions.StorageQueryException; import io.supertokens.pluginInterface.exceptions.StorageTransactionLogicException; import io.supertokens.pluginInterface.session.SessionInfo; +import io.supertokens.pluginInterface.useridmapping.UserIdMapping; import io.supertokens.session.Session; import io.supertokens.session.accessToken.AccessTokenSigningKey; import io.supertokens.session.accessToken.AccessTokenSigningKey.KeyInfo; import io.supertokens.session.info.SessionInformationHolder; +import io.supertokens.storageLayer.StorageLayer; +import io.supertokens.useridmapping.UserIdType; import io.supertokens.utils.Utils; import io.supertokens.webserver.InputParser; import io.supertokens.webserver.WebserverAPI; +import jakarta.servlet.ServletException; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; import javax.crypto.BadPaddingException; import javax.crypto.IllegalBlockSizeException; import javax.crypto.NoSuchPaddingException; -import jakarta.servlet.ServletException; -import jakarta.servlet.http.HttpServletRequest; -import jakarta.servlet.http.HttpServletResponse; import java.io.IOException; import java.security.InvalidAlgorithmParameterException; import java.security.InvalidKeyException; @@ -77,6 +81,20 @@ protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws I SessionInformationHolder sessionInfo = Session.createNewSession(main, userId, userDataInJWT, userDataInDatabase, enableAntiCsrf); + if (StorageLayer.getStorage(main).getType() == STORAGE_TYPE.SQL) { + try { + UserIdMapping userIdMapping = io.supertokens.useridmapping.UserIdMapping.getUserIdMapping( + super.main, + sessionInfo.session.userId, UserIdType.ANY); + if (userIdMapping != null) { + ActiveUsers.updateLastActive(main, userIdMapping.superTokensUserId); + } else { + ActiveUsers.updateLastActive(main, sessionInfo.session.userId); + } + } catch (StorageQueryException ignored) { + } + } + JsonObject result = sessionInfo.toJsonObject(); result.addProperty("status", "OK"); diff --git a/src/main/java/io/supertokens/webserver/api/session/SessionRemoveAPI.java b/src/main/java/io/supertokens/webserver/api/session/SessionRemoveAPI.java index 213f3e943..875034cc0 100644 --- a/src/main/java/io/supertokens/webserver/api/session/SessionRemoveAPI.java +++ b/src/main/java/io/supertokens/webserver/api/session/SessionRemoveAPI.java @@ -19,16 +19,21 @@ import com.google.gson.JsonArray; import com.google.gson.JsonObject; import com.google.gson.JsonPrimitive; +import io.supertokens.ActiveUsers; import io.supertokens.Main; import io.supertokens.pluginInterface.RECIPE_ID; +import io.supertokens.pluginInterface.STORAGE_TYPE; import io.supertokens.pluginInterface.exceptions.StorageQueryException; +import io.supertokens.pluginInterface.useridmapping.UserIdMapping; import io.supertokens.session.Session; +import io.supertokens.storageLayer.StorageLayer; +import io.supertokens.useridmapping.UserIdType; import io.supertokens.webserver.InputParser; import io.supertokens.webserver.WebserverAPI; - import jakarta.servlet.ServletException; import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletResponse; + import java.io.IOException; public class SessionRemoveAPI extends WebserverAPI { @@ -74,6 +79,21 @@ protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws I if (userId != null) { try { String[] sessionHandlesRevoked = Session.revokeAllSessionsForUser(main, userId); + + if (StorageLayer.getStorage(main).getType() == STORAGE_TYPE.SQL) { + try { + UserIdMapping userIdMapping = io.supertokens.useridmapping.UserIdMapping.getUserIdMapping( + super.main, + userId, UserIdType.ANY); + if (userIdMapping != null) { + ActiveUsers.updateLastActive(main, userIdMapping.superTokensUserId); + } else { + ActiveUsers.updateLastActive(main, userId); + } + } catch (StorageQueryException ignored) { + } + } + JsonObject result = new JsonObject(); result.addProperty("status", "OK"); JsonArray sessionHandlesRevokedJSON = new JsonArray(); diff --git a/src/main/java/io/supertokens/webserver/api/thirdparty/SignInUpAPI.java b/src/main/java/io/supertokens/webserver/api/thirdparty/SignInUpAPI.java index 7f768b26b..36bae966b 100644 --- a/src/main/java/io/supertokens/webserver/api/thirdparty/SignInUpAPI.java +++ b/src/main/java/io/supertokens/webserver/api/thirdparty/SignInUpAPI.java @@ -19,6 +19,8 @@ import com.google.gson.Gson; import com.google.gson.JsonObject; import com.google.gson.JsonParser; + +import io.supertokens.ActiveUsers; import io.supertokens.Main; import io.supertokens.pluginInterface.RECIPE_ID; import io.supertokens.pluginInterface.exceptions.StorageQueryException; @@ -70,6 +72,8 @@ protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws I ThirdParty.SignInUpResponse response = ThirdParty.signInUp2_7(super.main, thirdPartyId, thirdPartyUserId, email, isEmailVerified); + ActiveUsers.updateLastActive(main, response.user.id); + JsonObject result = new JsonObject(); result.addProperty("status", "OK"); result.addProperty("createdNewUser", response.createdNewUser); @@ -100,6 +104,8 @@ protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws I ThirdParty.SignInUpResponse response = ThirdParty.signInUp(super.main, thirdPartyId, thirdPartyUserId, email); + ActiveUsers.updateLastActive(main, response.user.id); + // io.supertokens.pluginInterface.useridmapping.UserIdMapping userIdMapping = UserIdMapping .getUserIdMapping(main, response.user.id, UserIdType.ANY); diff --git a/src/main/java/io/supertokens/webserver/api/totp/CreateOrUpdateTotpDeviceAPI.java b/src/main/java/io/supertokens/webserver/api/totp/CreateOrUpdateTotpDeviceAPI.java new file mode 100644 index 000000000..fea6e0725 --- /dev/null +++ b/src/main/java/io/supertokens/webserver/api/totp/CreateOrUpdateTotpDeviceAPI.java @@ -0,0 +1,130 @@ +package io.supertokens.webserver.api.totp; + +import java.io.IOException; +import java.security.NoSuchAlgorithmException; + +import com.google.gson.JsonObject; +import io.supertokens.Main; +import io.supertokens.featureflag.exceptions.FeatureNotEnabledException; +import io.supertokens.pluginInterface.RECIPE_ID; +import io.supertokens.pluginInterface.exceptions.StorageQueryException; +import io.supertokens.pluginInterface.totp.TOTPDevice; +import io.supertokens.pluginInterface.totp.exception.DeviceAlreadyExistsException; +import io.supertokens.pluginInterface.totp.exception.TotpNotEnabledException; +import io.supertokens.pluginInterface.totp.exception.UnknownDeviceException; +import io.supertokens.pluginInterface.useridmapping.UserIdMapping; +import io.supertokens.totp.Totp; +import io.supertokens.useridmapping.UserIdType; +import io.supertokens.webserver.InputParser; +import io.supertokens.webserver.WebserverAPI; +import jakarta.servlet.ServletException; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; + +public class CreateOrUpdateTotpDeviceAPI extends WebserverAPI { + private static final long serialVersionUID = -4641988458637882374L; + + public CreateOrUpdateTotpDeviceAPI(Main main) { + super(main, RECIPE_ID.TOTP.toString()); + } + + @Override + public String getPath() { + return "/recipe/totp/device"; + } + + @Override + protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws IOException, ServletException { + JsonObject input = InputParser.parseJsonObjectOrThrowError(req); + + String userId = InputParser.parseStringOrThrowError(input, "userId", false); + String deviceName = InputParser.parseStringOrThrowError(input, "deviceName", false); + Integer skew = InputParser.parseIntOrThrowError(input, "skew", false); + Integer period = InputParser.parseIntOrThrowError(input, "period", false); + + // Note: Not allowing the user to change the hashing algo and totp + // length (6-8) at the moment because it's rare to change them + + if (userId.isEmpty()) { + throw new ServletException(new BadRequestException("userId cannot be empty")); + } + if (deviceName.isEmpty()) { + throw new ServletException(new BadRequestException("deviceName cannot be empty")); + } + if (skew < 0) { + throw new ServletException(new BadRequestException("skew must be >= 0")); + } + if (period <= 0) { + throw new ServletException(new BadRequestException("period must be > 0")); + } + + JsonObject result = new JsonObject(); + + try { + // This step is required only because user_last_active table stores supertokens internal user id. + // While sending the usage stats we do a join, so totp tables also must use internal user id. + UserIdMapping userIdMapping = io.supertokens.useridmapping.UserIdMapping.getUserIdMapping(super.main, userId, UserIdType.ANY); + if (userIdMapping != null) { + userId = userIdMapping.superTokensUserId; + } + + TOTPDevice device = Totp.registerDevice(main, userId, deviceName, skew, period); + + result.addProperty("status", "OK"); + result.addProperty("secret", device.secretKey); + super.sendJsonResponse(200, result, resp); + } catch (DeviceAlreadyExistsException e) { + result.addProperty("status", "DEVICE_ALREADY_EXISTS_ERROR"); + super.sendJsonResponse(200, result, resp); + } catch (StorageQueryException | NoSuchAlgorithmException | FeatureNotEnabledException e) { + throw new ServletException(e); + } + } + + @Override + protected void doPut(HttpServletRequest req, HttpServletResponse resp) throws IOException, ServletException { + JsonObject input = InputParser.parseJsonObjectOrThrowError(req); + + String userId = InputParser.parseStringOrThrowError(input, "userId", false); + String existingDeviceName = InputParser.parseStringOrThrowError(input, "existingDeviceName", false); + String newDeviceName = InputParser.parseStringOrThrowError(input, "newDeviceName", false); + + if (userId.isEmpty()) { + throw new ServletException(new BadRequestException("userId cannot be empty")); + } + if (existingDeviceName.isEmpty()) { + throw new ServletException(new BadRequestException("existingDeviceName cannot be empty")); + } + if (newDeviceName.isEmpty()) { + throw new ServletException(new BadRequestException("newDeviceName cannot be empty")); + } + + JsonObject result = new JsonObject(); + + try { + // This step is required only because user_last_active table stores supertokens internal user id. + // While sending the usage stats we do a join, so totp tables also must use internal user id. + UserIdMapping userIdMapping = io.supertokens.useridmapping.UserIdMapping.getUserIdMapping(super.main, userId, UserIdType.ANY); + if (userIdMapping != null) { + userId = userIdMapping.superTokensUserId; + } + + Totp.updateDeviceName(main, userId, existingDeviceName, newDeviceName); + + result.addProperty("status", "OK"); + super.sendJsonResponse(200, result, resp); + } catch (TotpNotEnabledException e) { + result.addProperty("status", "TOTP_NOT_ENABLED_ERROR"); + super.sendJsonResponse(200, result, resp); + } catch (UnknownDeviceException e) { + result.addProperty("status", "UNKNOWN_DEVICE_ERROR"); + super.sendJsonResponse(200, result, resp); + } catch (DeviceAlreadyExistsException e) { + result.addProperty("status", "DEVICE_ALREADY_EXISTS_ERROR"); + super.sendJsonResponse(200, result, resp); + } catch (StorageQueryException e) { + throw new ServletException(e); + } + } + +} diff --git a/src/main/java/io/supertokens/webserver/api/totp/GetTotpDevicesAPI.java b/src/main/java/io/supertokens/webserver/api/totp/GetTotpDevicesAPI.java new file mode 100644 index 000000000..9bda4ed27 --- /dev/null +++ b/src/main/java/io/supertokens/webserver/api/totp/GetTotpDevicesAPI.java @@ -0,0 +1,75 @@ +package io.supertokens.webserver.api.totp; + +import java.io.IOException; + +import com.google.gson.JsonArray; +import com.google.gson.JsonObject; + +import io.supertokens.Main; +import io.supertokens.pluginInterface.RECIPE_ID; +import io.supertokens.pluginInterface.exceptions.StorageQueryException; +import io.supertokens.pluginInterface.totp.TOTPDevice; +import io.supertokens.pluginInterface.totp.exception.TotpNotEnabledException; +import io.supertokens.pluginInterface.useridmapping.UserIdMapping; +import io.supertokens.totp.Totp; +import io.supertokens.useridmapping.UserIdType; +import io.supertokens.webserver.InputParser; +import io.supertokens.webserver.WebserverAPI; +import jakarta.servlet.ServletException; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; + +public class GetTotpDevicesAPI extends WebserverAPI { + private static final long serialVersionUID = -4641988458637882374L; + + public GetTotpDevicesAPI(Main main) { + super(main, RECIPE_ID.TOTP.toString()); + } + + @Override + public String getPath() { + return "/recipe/totp/device/list"; + } + + @Override + protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws IOException, ServletException { + String userId = InputParser.getQueryParamOrThrowError(req, "userId", false); + + if (userId.isEmpty()) { + throw new ServletException(new BadRequestException("userId cannot be empty")); + } + + JsonObject result = new JsonObject(); + + try { + // This step is required only because user_last_active table stores supertokens internal user id. + // While sending the usage stats we do a join, so totp tables also must use internal user id. + UserIdMapping userIdMapping = io.supertokens.useridmapping.UserIdMapping.getUserIdMapping(super.main, userId, UserIdType.ANY); + if (userIdMapping != null) { + userId = userIdMapping.superTokensUserId; + } + + TOTPDevice[] devices = Totp.getDevices(main, userId); + JsonArray devicesArray = new JsonArray(); + + for (TOTPDevice d : devices) { + JsonObject item = new JsonObject(); + item.addProperty("name", d.deviceName); + item.addProperty("period", d.period); + item.addProperty("skew", d.skew); + item.addProperty("verified", d.verified); + + devicesArray.add(item); + } + + result.addProperty("status", "OK"); + result.add("devices", devicesArray); + super.sendJsonResponse(200, result, resp); + } catch (TotpNotEnabledException e) { + result.addProperty("status", "TOTP_NOT_ENABLED_ERROR"); + super.sendJsonResponse(200, result, resp); + } catch (StorageQueryException e) { + throw new ServletException(e); + } + } +} diff --git a/src/main/java/io/supertokens/webserver/api/totp/RemoveTotpDeviceAPI.java b/src/main/java/io/supertokens/webserver/api/totp/RemoveTotpDeviceAPI.java new file mode 100644 index 000000000..634e6fe20 --- /dev/null +++ b/src/main/java/io/supertokens/webserver/api/totp/RemoveTotpDeviceAPI.java @@ -0,0 +1,74 @@ +package io.supertokens.webserver.api.totp; + +import java.io.IOException; + +import com.google.gson.JsonObject; + +import io.supertokens.Main; +import io.supertokens.pluginInterface.RECIPE_ID; +import io.supertokens.pluginInterface.exceptions.StorageQueryException; +import io.supertokens.pluginInterface.exceptions.StorageTransactionLogicException; +import io.supertokens.pluginInterface.totp.exception.TotpNotEnabledException; +import io.supertokens.pluginInterface.totp.exception.UnknownDeviceException; +import io.supertokens.pluginInterface.useridmapping.UserIdMapping; +import io.supertokens.totp.Totp; +import io.supertokens.useridmapping.UserIdType; +import io.supertokens.webserver.InputParser; +import io.supertokens.webserver.WebserverAPI; +import jakarta.servlet.ServletException; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; + +public class RemoveTotpDeviceAPI extends WebserverAPI { + private static final long serialVersionUID = -4641988458637882374L; + + public RemoveTotpDeviceAPI(Main main) { + super(main, RECIPE_ID.TOTP.toString()); + } + + @Override + public String getPath() { + return "/recipe/totp/device/remove"; + } + + @Override + protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws IOException, ServletException { + JsonObject input = InputParser.parseJsonObjectOrThrowError(req); + + String userId = InputParser.parseStringOrThrowError(input, "userId", false); + String deviceName = InputParser.parseStringOrThrowError(input, "deviceName", false); + + if (userId.isEmpty()) { + throw new ServletException(new BadRequestException("userId cannot be empty")); + } + if (deviceName.isEmpty()) { + throw new ServletException(new BadRequestException("deviceName cannot be empty")); + } + + JsonObject result = new JsonObject(); + + try { + // This step is required only because user_last_active table stores supertokens internal user id. + // While sending the usage stats we do a join, so totp tables also must use internal user id. + UserIdMapping userIdMapping = io.supertokens.useridmapping.UserIdMapping.getUserIdMapping(super.main, userId, UserIdType.ANY); + if (userIdMapping != null) { + userId = userIdMapping.superTokensUserId; + } + + Totp.removeDevice(main, userId, deviceName); + + result.addProperty("status", "OK"); + result.addProperty("didDeviceExist", true); + super.sendJsonResponse(200, result, resp); + } catch (TotpNotEnabledException e) { + result.addProperty("status", "TOTP_NOT_ENABLED_ERROR"); + super.sendJsonResponse(200, result, resp); + } catch (UnknownDeviceException e) { + result.addProperty("status", "OK"); + result.addProperty("didDeviceExist", false); + super.sendJsonResponse(200, result, resp); + } catch (StorageQueryException | StorageTransactionLogicException e) { + throw new ServletException(e); + } + } +} diff --git a/src/main/java/io/supertokens/webserver/api/totp/VerifyTotpAPI.java b/src/main/java/io/supertokens/webserver/api/totp/VerifyTotpAPI.java new file mode 100644 index 000000000..d7c684c5b --- /dev/null +++ b/src/main/java/io/supertokens/webserver/api/totp/VerifyTotpAPI.java @@ -0,0 +1,80 @@ +package io.supertokens.webserver.api.totp; + +import java.io.IOException; + +import com.google.gson.JsonObject; + +import io.supertokens.Main; +import io.supertokens.featureflag.exceptions.FeatureNotEnabledException; +import io.supertokens.pluginInterface.RECIPE_ID; +import io.supertokens.pluginInterface.exceptions.StorageQueryException; +import io.supertokens.pluginInterface.exceptions.StorageTransactionLogicException; +import io.supertokens.pluginInterface.totp.exception.TotpNotEnabledException; +import io.supertokens.pluginInterface.useridmapping.UserIdMapping; +import io.supertokens.totp.Totp; +import io.supertokens.totp.exceptions.InvalidTotpException; +import io.supertokens.totp.exceptions.LimitReachedException; +import io.supertokens.useridmapping.UserIdType; +import io.supertokens.webserver.InputParser; +import io.supertokens.webserver.WebserverAPI; +import jakarta.servlet.ServletException; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; + +public class VerifyTotpAPI extends WebserverAPI { + private static final long serialVersionUID = -4641988458637882374L; + + public VerifyTotpAPI(Main main) { + super(main, RECIPE_ID.TOTP.toString()); + } + + @Override + public String getPath() { + return "/recipe/totp/verify"; + } + + @Override + protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws IOException, ServletException { + JsonObject input = InputParser.parseJsonObjectOrThrowError(req); + + String userId = InputParser.parseStringOrThrowError(input, "userId", false); + String totp = InputParser.parseStringOrThrowError(input, "totp", false); + Boolean allowUnverifiedDevices = InputParser.parseBooleanOrThrowError(input, "allowUnverifiedDevices", false); + + if (userId.isEmpty()) { + throw new ServletException(new BadRequestException("userId cannot be empty")); + } + if (totp.length() != 6) { + throw new ServletException(new BadRequestException("totp must be 6 characters long")); + } + // Already checked that allowUnverifiedDevices is not null. + + JsonObject result = new JsonObject(); + + try { + // This step is required only because user_last_active table stores supertokens internal user id. + // While sending the usage stats we do a join, so totp tables also must use internal user id. + UserIdMapping userIdMapping = io.supertokens.useridmapping.UserIdMapping.getUserIdMapping(super.main, userId, UserIdType.ANY); + if (userIdMapping != null) { + userId = userIdMapping.superTokensUserId; + } + + Totp.verifyCode(main, userId, totp, allowUnverifiedDevices); + + result.addProperty("status", "OK"); + super.sendJsonResponse(200, result, resp); + } catch (TotpNotEnabledException e) { + result.addProperty("status", "TOTP_NOT_ENABLED_ERROR"); + super.sendJsonResponse(200, result, resp); + } catch (InvalidTotpException e) { + result.addProperty("status", "INVALID_TOTP_ERROR"); + super.sendJsonResponse(200, result, resp); + } catch (LimitReachedException e) { + result.addProperty("status", "LIMIT_REACHED_ERROR"); + result.addProperty("retryAfterMs", e.retryAfterMs); + super.sendJsonResponse(200, result, resp); + } catch (StorageQueryException | StorageTransactionLogicException | FeatureNotEnabledException e) { + throw new ServletException(e); + } + } +} diff --git a/src/main/java/io/supertokens/webserver/api/totp/VerifyTotpDeviceAPI.java b/src/main/java/io/supertokens/webserver/api/totp/VerifyTotpDeviceAPI.java new file mode 100644 index 000000000..f1dd1ba0a --- /dev/null +++ b/src/main/java/io/supertokens/webserver/api/totp/VerifyTotpDeviceAPI.java @@ -0,0 +1,86 @@ +package io.supertokens.webserver.api.totp; + +import java.io.IOException; + +import com.google.gson.JsonObject; + +import io.supertokens.Main; +import io.supertokens.pluginInterface.RECIPE_ID; +import io.supertokens.pluginInterface.exceptions.StorageQueryException; +import io.supertokens.pluginInterface.exceptions.StorageTransactionLogicException; +import io.supertokens.pluginInterface.totp.exception.TotpNotEnabledException; +import io.supertokens.pluginInterface.totp.exception.UnknownDeviceException; +import io.supertokens.pluginInterface.useridmapping.UserIdMapping; +import io.supertokens.totp.Totp; +import io.supertokens.totp.exceptions.InvalidTotpException; +import io.supertokens.totp.exceptions.LimitReachedException; +import io.supertokens.useridmapping.UserIdType; +import io.supertokens.webserver.InputParser; +import io.supertokens.webserver.WebserverAPI; +import jakarta.servlet.ServletException; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; + +public class VerifyTotpDeviceAPI extends WebserverAPI { + private static final long serialVersionUID = -4641988458637882374L; + + public VerifyTotpDeviceAPI(Main main) { + super(main, RECIPE_ID.TOTP.toString()); + } + + @Override + public String getPath() { + return "/recipe/totp/device/verify"; + } + + @Override + protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws IOException, ServletException { + JsonObject input = InputParser.parseJsonObjectOrThrowError(req); + + String userId = InputParser.parseStringOrThrowError(input, "userId", false); + String deviceName = InputParser.parseStringOrThrowError(input, "deviceName", false); + String totp = InputParser.parseStringOrThrowError(input, "totp", false); + + if (userId.isEmpty()) { + throw new ServletException(new BadRequestException("userId cannot be empty")); + } + if (deviceName.isEmpty()) { + throw new ServletException(new BadRequestException("deviceName cannot be empty")); + } + if (totp.length() != 6) { + throw new ServletException(new BadRequestException("totp must be 6 characters long")); + } + + JsonObject result = new JsonObject(); + + try { + // This step is required only because user_last_active table stores supertokens internal user id. + // While sending the usage stats we do a join, so totp tables also must use internal user id. + UserIdMapping userIdMapping = io.supertokens.useridmapping.UserIdMapping.getUserIdMapping(super.main, userId, UserIdType.ANY); + if (userIdMapping != null) { + userId = userIdMapping.superTokensUserId; + } + + boolean isNewlyVerified = Totp.verifyDevice(main, userId, deviceName, totp); + + result.addProperty("status", "OK"); + result.addProperty("wasAlreadyVerified", !isNewlyVerified); + super.sendJsonResponse(200, result, resp); + } catch (TotpNotEnabledException e) { + result.addProperty("status", "TOTP_NOT_ENABLED_ERROR"); + super.sendJsonResponse(200, result, resp); + } catch (UnknownDeviceException e) { + result.addProperty("status", "UNKNOWN_DEVICE_ERROR"); + super.sendJsonResponse(200, result, resp); + } catch (InvalidTotpException e) { + result.addProperty("status", "INVALID_TOTP_ERROR"); + super.sendJsonResponse(200, result, resp); + } catch (LimitReachedException e) { + result.addProperty("status", "LIMIT_REACHED_ERROR"); + result.addProperty("retryAfterMs", e.retryAfterMs); + super.sendJsonResponse(200, result, resp); + } catch (StorageQueryException | StorageTransactionLogicException e) { + throw new ServletException(e); + } + } +} diff --git a/src/test/java/io/supertokens/test/ActiveUsersTest.java b/src/test/java/io/supertokens/test/ActiveUsersTest.java new file mode 100644 index 000000000..d0ff48460 --- /dev/null +++ b/src/test/java/io/supertokens/test/ActiveUsersTest.java @@ -0,0 +1,213 @@ +package io.supertokens.test; + +import com.google.gson.JsonObject; +import io.supertokens.ActiveUsers; +import io.supertokens.Main; +import io.supertokens.ProcessState; +import io.supertokens.pluginInterface.STORAGE_TYPE; +import io.supertokens.storageLayer.StorageLayer; +import io.supertokens.test.httpRequest.HttpRequestForTesting; +import io.supertokens.test.httpRequest.HttpResponseException; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TestRule; + +import java.util.HashMap; + +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertThrows; + +public class ActiveUsersTest { + + @Rule + public TestRule watchman = Utils.getOnFailure(); + + @AfterClass + public static void afterTesting() { + Utils.afterTesting(); + } + + @Before + public void beforeEach() { + Utils.reset(); + } + + @Test + public void updateAndCountUserLastActiveTest() throws Exception { + String[] args = {"../"}; + + TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + + if (StorageLayer.getStorage(process.getProcess()).getType() != STORAGE_TYPE.SQL) { + return; + } + + Main main = process.getProcess(); + long now = System.currentTimeMillis(); + + assert ActiveUsers.countUsersActiveSince(main, now) == 0; + + ActiveUsers.updateLastActive(main, "user1"); + ActiveUsers.updateLastActive(main, "user2"); + + assert ActiveUsers.countUsersActiveSince(main, now) == 2; + + long now2 = System.currentTimeMillis(); + + ActiveUsers.updateLastActive(main, "user1"); + + assert ActiveUsers.countUsersActiveSince(main, now2) == 1; // only user1 is counted + assert ActiveUsers.countUsersActiveSince(main, now) == 2; // user1 and user2 are counted + } + + @Test + public void activeUserCountAPITest() throws Exception { + String[] args = {"../"}; + + TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + + if (StorageLayer.getStorage(process.getProcess()).getType() != STORAGE_TYPE.SQL) { + return; + } + + Main main = process.getProcess(); + long now = System.currentTimeMillis(); + + HashMap params = new HashMap<>(); + + HttpResponseException e = + assertThrows( + HttpResponseException.class, + () -> { + HttpRequestForTesting.sendGETRequest( + process.getProcess(), + "", + "http://localhost:3567/users/count/active", + params, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + ""); + } + ); + + assert e.statusCode == 400; + assert e.getMessage().contains("Field name 'since' is missing in GET request"); + + params.put("since", "not a number"); + e = + assertThrows( + HttpResponseException.class, + () -> { + HttpRequestForTesting.sendGETRequest( + process.getProcess(), + "", + "http://localhost:3567/users/count/active", + params, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + ""); + } + ); + + assert e.statusCode == 400; + assert e.getMessage().contains("Field name 'since' must be a long in the GET request"); + + params.put("since", "-1"); + e = + assertThrows( + HttpResponseException.class, + () -> { + HttpRequestForTesting.sendGETRequest( + process.getProcess(), + "", + "http://localhost:3567/users/count/active", + params, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + ""); + } + ); + + assert e.statusCode == 400; + assert e.getMessage().contains("'since' query parameter must be >= 0"); + + + params.put("since", Long.toString(now)); + + JsonObject res = HttpRequestForTesting.sendGETRequest( + process.getProcess(), + "", + "http://localhost:3567/users/count/active", + params, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + ""); + + assert res.get("status").getAsString().equals("OK"); + assert res.get("count").getAsInt() == 0; + + ActiveUsers.updateLastActive(main, "user1"); + ActiveUsers.updateLastActive(main, "user2"); + + res = HttpRequestForTesting.sendGETRequest( + process.getProcess(), + "", + "http://localhost:3567/users/count/active", + params, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + ""); + + assert res.get("status").getAsString().equals("OK"); + assert res.get("count").getAsInt() == 2; + + long now2 = System.currentTimeMillis(); + + ActiveUsers.updateLastActive(main, "user1"); + + params.put("since", Long.toString(now2)); + res = HttpRequestForTesting.sendGETRequest( + process.getProcess(), + "", + "http://localhost:3567/users/count/active", + params, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + ""); + + assert res.get("status").getAsString().equals("OK"); + assert res.get("count").getAsInt() == 1; + + params.put("since", Long.toString(now)); + res = HttpRequestForTesting.sendGETRequest( + process.getProcess(), + "", + "http://localhost:3567/users/count/active", + params, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + ""); + + assert res.get("status").getAsString().equals("OK"); + assert res.get("count").getAsInt() == 2; + } + +} diff --git a/src/test/java/io/supertokens/test/ConfigTest2_6.java b/src/test/java/io/supertokens/test/ConfigTest2_6.java index 6f9f04d74..14db92a86 100644 --- a/src/test/java/io/supertokens/test/ConfigTest2_6.java +++ b/src/test/java/io/supertokens/test/ConfigTest2_6.java @@ -135,6 +135,36 @@ public void testThatInvalidConfigThrowRightError() throws Exception { } + @Test + public void testInvalidTotpConfigThrowsExpectedError() throws Exception { + String[] args = { "../" }; + + Utils.setValueInConfig("totp_max_attempts", "0"); + + TestingProcess process = TestingProcessManager.start(args); + + ProcessState.EventAndException e = process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.INIT_FAILURE); + assertNotNull(e); + assertEquals(e.exception.getMessage(), + "'totp_max_attempts' must be > 0"); + + process.kill(); + assertNotNull(process.checkOrWaitForEvent(PROCESS_STATE.STOPPED)); + + Utils.reset(); + + Utils.setValueInConfig("totp_rate_limit_cooldown_sec", "0"); + process = TestingProcessManager.start(args); + + e = process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.INIT_FAILURE); + assertNotNull(e); + assertEquals(e.exception.getMessage(), + "'totp_rate_limit_cooldown_sec' must be > 0"); + + process.kill(); + assertNotNull(process.checkOrWaitForEvent(PROCESS_STATE.STOPPED)); + } + private String getConfigFileLocation(Main main) { return new File(CLIOptions.get(main).getConfigFilePath() == null ? CLIOptions.get(main).getInstallationPath() + "config.yaml" @@ -220,6 +250,9 @@ private static void checkConfigValues(CoreConfig config, TestingProcess process, assertFalse("Config access token blacklisting did not match default", config.getAccessTokenBlacklisting()); assertEquals("Config refresh token validity did not match default", config.getRefreshTokenValidity(), 60 * 2400 * 60 * (long) 1000); + assertEquals(5, config.getTotpMaxAttempts()); // 5 + assertEquals(900, config.getTotpRateLimitCooldownTimeSec()); // 15 minutes + assertEquals("Config info log path did not match default", config.getInfoLogPath(process.getProcess()), CLIOptions.get(process.getProcess()).getInstallationPath() + "logs/info.log"); assertEquals("Config error log path did not match default", config.getErrorLogPath(process.getProcess()), diff --git a/src/test/java/io/supertokens/test/FeatureFlagTest.java b/src/test/java/io/supertokens/test/FeatureFlagTest.java index bce7b57b6..7dc0d45ab 100644 --- a/src/test/java/io/supertokens/test/FeatureFlagTest.java +++ b/src/test/java/io/supertokens/test/FeatureFlagTest.java @@ -16,6 +16,7 @@ package io.supertokens.test; +import com.google.gson.JsonArray; import com.google.gson.JsonObject; import io.supertokens.ProcessState; import io.supertokens.featureflag.FeatureFlag; @@ -62,7 +63,11 @@ public void noLicenseKeyShouldHaveEmptyFeatureFlag() throws InterruptedException } catch (NoLicenseKeyFoundException ignored) { } - Assert.assertEquals(FeatureFlag.getInstance(process.getProcess()).getPaidFeatureStats().entrySet().size(), 0); + JsonObject stats = FeatureFlag.getInstance(process.getProcess()).getPaidFeatureStats(); + Assert.assertEquals(stats.entrySet().size(), 1); + Assert.assertEquals(stats.get("maus").getAsJsonArray().size(), 30); + Assert.assertEquals(stats.get("maus").getAsJsonArray().get(0).getAsInt(), 0); + Assert.assertEquals(stats.get("maus").getAsJsonArray().get(29).getAsInt(), 0); process.kill(); assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); @@ -111,4 +116,101 @@ public void testThatCallingGetFeatureFlagAPIReturnsEmptyArray() throws Exception process.kill(); Assert.assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); } + + private final String OPAQUE_KEY_WITH_TOTP_FEATURE = "pXhNK=nYiEsb6gJEOYP2kIR6M0kn4XLvNqcwT1XbX8xHtm44K-lQfGCbaeN0Ieeza39fxkXr=tiiUU=DXxDH40Y=4FLT4CE-rG1ETjkXxO4yucLpJvw3uSegPayoISGL"; + + @Test + public void testThatCallingGetFeatureFlagAPIReturnsTotpStats() throws Exception { + String[] args = {"../"}; + + TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); + Assert.assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + + FeatureFlag.getInstance(process.main).setLicenseKeyAndSyncFeatures(OPAQUE_KEY_WITH_TOTP_FEATURE); + + // Get the stats without any users/activity + { + JsonObject response = HttpRequestForTesting.sendGETRequest(process.getProcess(), "", + "http://localhost:3567/ee/featureflag", + null, 1000, 1000, null, WebserverAPI.getLatestCDIVersion(), ""); + Assert.assertEquals("OK", response.get("status").getAsString()); + + JsonArray features = response.get("features").getAsJsonArray(); + JsonObject usageStats = response.get("usageStats").getAsJsonObject(); + JsonArray maus = usageStats.get("maus").getAsJsonArray(); + + assert features.size() == 1; + assert features.get(0).getAsString().equals("totp"); + assert maus.size() == 30; + assert maus.get(0).getAsInt() == 0; + assert maus.get(29).getAsInt() == 0; + + JsonObject totpStats = usageStats.get("totp").getAsJsonObject(); + JsonArray totpMaus = totpStats.get("maus").getAsJsonArray(); + int totalTotpUsers = totpStats.get("total_users").getAsInt(); + + assert totpMaus.size() == 30; + assert totpMaus.get(0).getAsInt() == 0; + assert totpMaus.get(29).getAsInt() == 0; + + assert totalTotpUsers == 0; + } + + // First register 2 users for emailpassword recipe. + // This also marks them as active. + JsonObject signUpResponse = Utils.signUpRequest_2_5(process, "random@gmail.com", "validPass123"); + assert signUpResponse.get("status").getAsString().equals("OK"); + + JsonObject signUpResponse2 = Utils.signUpRequest_2_5(process, "random2@gmail.com", "validPass123"); + assert signUpResponse2.get("status").getAsString().equals("OK"); + + // Now enable TOTP for the first user by registering a device. + JsonObject body = new JsonObject(); + body.addProperty("userId", signUpResponse.get("user").getAsJsonObject().get("id").getAsString()); + body.addProperty("deviceName", "d1"); + body.addProperty("skew", 0); + body.addProperty("period", 30); + JsonObject res = HttpRequestForTesting.sendJsonPOSTRequest( + process.getProcess(), + "", + "http://localhost:3567/recipe/totp/device", + body, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + "totp"); + assert res.get("status").getAsString().equals("OK"); + + // Now check the stats again: + { + JsonObject response = HttpRequestForTesting.sendGETRequest(process.getProcess(), "", + "http://localhost:3567/ee/featureflag", + null, 1000, 1000, null, WebserverAPI.getLatestCDIVersion(), ""); + Assert.assertEquals("OK", response.get("status").getAsString()); + + JsonArray features = response.get("features").getAsJsonArray(); + JsonObject usageStats = response.get("usageStats").getAsJsonObject(); + JsonArray maus = usageStats.get("maus").getAsJsonArray(); + + assert features.size() == 1; + assert features.get(0).getAsString().equals("totp"); + assert maus.size() == 30; + assert maus.get(0).getAsInt() == 2; // 2 users have signed up + assert maus.get(29).getAsInt() == 2; + + JsonObject totpStats = usageStats.get("totp").getAsJsonObject(); + JsonArray totpMaus = totpStats.get("maus").getAsJsonArray(); + int totalTotpUsers = totpStats.get("total_users").getAsInt(); + + assert totpMaus.size() == 30; + assert totpMaus.get(0).getAsInt() == 1; // only 1 user has TOTP enabled + assert totpMaus.get(29).getAsInt() == 1; + + assert totalTotpUsers == 1; + } + + process.kill(); + Assert.assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); + } } diff --git a/src/test/java/io/supertokens/test/StorageLayerTest.java b/src/test/java/io/supertokens/test/StorageLayerTest.java new file mode 100644 index 000000000..f90a79850 --- /dev/null +++ b/src/test/java/io/supertokens/test/StorageLayerTest.java @@ -0,0 +1,98 @@ +package io.supertokens.test; + +import io.supertokens.ProcessState; +import io.supertokens.inmemorydb.Start; +import io.supertokens.inmemorydb.config.Config; +import io.supertokens.pluginInterface.STORAGE_TYPE; +import io.supertokens.pluginInterface.exceptions.StorageQueryException; +import io.supertokens.pluginInterface.exceptions.StorageTransactionLogicException; +import io.supertokens.pluginInterface.totp.TOTPDevice; +import io.supertokens.pluginInterface.totp.TOTPUsedCode; +import io.supertokens.pluginInterface.totp.exception.TotpNotEnabledException; +import io.supertokens.pluginInterface.totp.exception.UsedCodeAlreadyExistsException; +import io.supertokens.pluginInterface.totp.sqlStorage.TOTPSQLStorage; +import io.supertokens.storageLayer.StorageLayer; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TestRule; + +import static org.junit.Assert.assertNotNull; + +public class StorageLayerTest { + + @Rule + public TestRule watchman = Utils.getOnFailure(); + + @AfterClass + public static void afterTesting() { + Utils.afterTesting(); + } + + @Before + public void beforeEach() { + Utils.reset(); + } + + // TOTP recipe: + + public static void insertUsedCodeUtil(TOTPSQLStorage storage, TOTPUsedCode usedCode) throws Exception { + try { + storage.startTransaction(con -> { + try { + storage.insertUsedCode_Transaction(con, usedCode); + storage.commitTransaction(con); + return null; + } catch (TotpNotEnabledException | UsedCodeAlreadyExistsException e) { + throw new StorageTransactionLogicException(e); + } + }); + } catch (StorageTransactionLogicException e) { + Exception actual = e.actualException; + if (actual instanceof TotpNotEnabledException || actual instanceof UsedCodeAlreadyExistsException) { + throw actual; + } else { + throw e; + } + } + } + + @Test + public void totpCodeLengthTest() throws Exception { + String[] args = {"../"}; + + TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); + process.getProcess().setForceInMemoryDB(); // this test is only for SQLite + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + + if (StorageLayer.getStorage(process.getProcess()).getType() != STORAGE_TYPE.SQL) { + return; + } + TOTPSQLStorage storage = StorageLayer.getTOTPStorage(process.getProcess()); + long now = System.currentTimeMillis(); + long nextDay = now + 1000 * 60 * 60 * 24; // 1 day from now + + Start start = (Start) StorageLayer.getStorage(process.getProcess()); + + TOTPDevice d1 = new TOTPDevice("user", "d1", "secret", 30, 1, false); + storage.createDevice(d1); + + // Try code with length > 8 + try { + TOTPUsedCode code = new TOTPUsedCode("user", "123456789", true, nextDay, now); + insertUsedCodeUtil(storage, code); + assert (false); + } catch (StorageQueryException e) { + // This error will be different in Postgres and MySQL + // We added (CHECK (LENGTH(code) <= 8)) to the table definition in SQLite + String totpUsedCodeTable = Config.getConfig(start).getTotpUsedCodesTable(); + assert e.getMessage().contains("CHECK constraint failed: " + totpUsedCodeTable); + } + + // Try code with length < 8 + TOTPUsedCode code = new TOTPUsedCode("user", "12345678", true, nextDay, now); + insertUsedCodeUtil(storage, code); + } + +} diff --git a/src/test/java/io/supertokens/test/dashboard/DashboardTest.java b/src/test/java/io/supertokens/test/dashboard/DashboardTest.java index d1dceafcc..a394fe7c9 100644 --- a/src/test/java/io/supertokens/test/dashboard/DashboardTest.java +++ b/src/test/java/io/supertokens/test/dashboard/DashboardTest.java @@ -285,7 +285,12 @@ public void testDashboardUsageStats() throws Exception { assertEquals(3, response.entrySet().size()); assertEquals("OK", response.get("status").getAsString()); assertEquals(0, response.get("features").getAsJsonArray().size()); - assertEquals(0, response.get("usageStats").getAsJsonObject().entrySet().size()); + JsonObject usageStats = response.get("usageStats").getAsJsonObject(); + JsonArray mauArr = usageStats.get("maus").getAsJsonArray(); + assertEquals(1, usageStats.entrySet().size()); + assertEquals(30, mauArr.size()); + assertEquals(0, mauArr.get(0).getAsInt()); + assertEquals(0, mauArr.get(29).getAsInt()); } // create a dashboard user @@ -298,7 +303,12 @@ public void testDashboardUsageStats() throws Exception { assertEquals(3, response.entrySet().size()); assertEquals("OK", response.get("status").getAsString()); assertEquals(0, response.get("features").getAsJsonArray().size()); - assertEquals(0, response.get("usageStats").getAsJsonObject().entrySet().size()); + JsonObject usageStats = response.get("usageStats").getAsJsonObject(); + JsonArray mauArr = usageStats.get("maus").getAsJsonArray(); + assertEquals(1, usageStats.entrySet().size()); + assertEquals(30, mauArr.size()); + assertEquals(0, mauArr.get(0).getAsInt()); + assertEquals(0, mauArr.get(29).getAsInt()); } // enable the dashboard feature @@ -315,9 +325,9 @@ public void testDashboardUsageStats() throws Exception { assertEquals(1, featuresArray.size()); assertEquals(EE_FEATURES.DASHBOARD_LOGIN.toString(), featuresArray.get(0).getAsString()); JsonObject usageStats = response.get("usageStats").getAsJsonObject(); - assertEquals(1, - usageStats.entrySet().size()); JsonObject dashboardLoginObject = usageStats.get("dashboard_login").getAsJsonObject(); + assertEquals(2, usageStats.entrySet().size()); + assertEquals(30, usageStats.get("maus").getAsJsonArray().size()); assertEquals(1, dashboardLoginObject.entrySet().size()); assertEquals(1, dashboardLoginObject.get("user_count").getAsInt()); } @@ -338,9 +348,9 @@ public void testDashboardUsageStats() throws Exception { assertEquals(1, featuresArray.size()); assertEquals(EE_FEATURES.DASHBOARD_LOGIN.toString(), featuresArray.get(0).getAsString()); JsonObject usageStats = response.get("usageStats").getAsJsonObject(); - assertEquals(1, - usageStats.entrySet().size()); JsonObject dashboardLoginObject = usageStats.get("dashboard_login").getAsJsonObject(); + assertEquals(2, usageStats.entrySet().size()); + assertEquals(30, usageStats.get("maus").getAsJsonArray().size()); assertEquals(1, dashboardLoginObject.entrySet().size()); assertEquals(4, dashboardLoginObject.get("user_count").getAsInt()); } diff --git a/src/test/java/io/supertokens/test/emailpassword/api/SignInAPITest2_7.java b/src/test/java/io/supertokens/test/emailpassword/api/SignInAPITest2_7.java index 1691c2900..8fdfcbdea 100644 --- a/src/test/java/io/supertokens/test/emailpassword/api/SignInAPITest2_7.java +++ b/src/test/java/io/supertokens/test/emailpassword/api/SignInAPITest2_7.java @@ -17,6 +17,8 @@ package io.supertokens.test.emailpassword.api; import com.google.gson.JsonObject; + +import io.supertokens.ActiveUsers; import io.supertokens.ProcessState; import io.supertokens.pluginInterface.STORAGE_TYPE; import io.supertokens.storageLayer.StorageLayer; @@ -68,6 +70,8 @@ public void testBadInput() throws Exception { return; } + long startTs = System.currentTimeMillis(); + { try { HttpRequestForTesting.sendJsonPOSTRequest(process.getProcess(), "", @@ -107,6 +111,9 @@ public void testBadInput() throws Exception { } } + int activeUsers = ActiveUsers.countUsersActiveSince(process.getProcess(), startTs); + assert (activeUsers == 0); + process.kill(); assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); } @@ -133,6 +140,8 @@ public void testGoodInput() throws Exception { responseBody.addProperty("email", "random@gmail.com"); responseBody.addProperty("password", "validPass123"); + long beforeSignIn = System.currentTimeMillis(); + JsonObject signInResponse = HttpRequestForTesting.sendJsonPOSTRequest(process.getProcess(), "", "http://localhost:3567/recipe/signin", responseBody, 1000, 1000, null, Utils.getCdiVersion2_7ForTests(), "emailpassword"); @@ -147,11 +156,15 @@ public void testGoodInput() throws Exception { signInResponse.get("user").getAsJsonObject().get("timeJoined").getAsLong(); assertEquals(signInResponse.get("user").getAsJsonObject().entrySet().size(), 3); + int activeUsers = ActiveUsers.countUsersActiveSince(process.getProcess(), beforeSignIn); + assert (activeUsers == 1); + process.kill(); assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); } - // Test that sign in with unnormalised email like Test@gmail.com should also work + // Test that sign in with unnormalised email like Test@gmail.com should also + // work @Test public void testThatUnnormalisedEmailShouldAlsoWork() throws Exception { String[] args = { "../" }; @@ -190,7 +203,8 @@ public void testThatUnnormalisedEmailShouldAlsoWork() throws Exception { assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); } - // Test that giving an empty password, empty email, invalid email, random email or wrong password throws a wrong + // Test that giving an empty password, empty email, invalid email, random email + // or wrong password throws a wrong // * credentials error @Test public void testInputsToSignInAPI() throws Exception { diff --git a/src/test/java/io/supertokens/test/emailpassword/api/SignUpAPITest2_7.java b/src/test/java/io/supertokens/test/emailpassword/api/SignUpAPITest2_7.java index 5e921a2d3..bd4b97285 100644 --- a/src/test/java/io/supertokens/test/emailpassword/api/SignUpAPITest2_7.java +++ b/src/test/java/io/supertokens/test/emailpassword/api/SignUpAPITest2_7.java @@ -17,6 +17,8 @@ package io.supertokens.test.emailpassword.api; import com.google.gson.JsonObject; + +import io.supertokens.ActiveUsers; import io.supertokens.ProcessState; import io.supertokens.pluginInterface.STORAGE_TYPE; import io.supertokens.pluginInterface.emailpassword.UserInfo; @@ -68,6 +70,8 @@ public void testBadInput() throws Exception { return; } + long beforeTestTs = System.currentTimeMillis(); + { try { HttpRequestForTesting.sendJsonPOSTRequest(process.getProcess(), "", @@ -107,6 +111,9 @@ public void testBadInput() throws Exception { } } + int activeUsers = ActiveUsers.countUsersActiveSince(process.getProcess(), beforeTestTs); + assert (activeUsers == 0); + process.kill(); assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); } @@ -123,6 +130,8 @@ public void testGoodInput() throws Exception { return; } + long beforeSignUpTs = System.currentTimeMillis(); + JsonObject signUpResponse = Utils.signUpRequest_2_5(process, "random@gmail.com", "validPass123"); assertEquals(signUpResponse.get("status").getAsString(), "OK"); assertEquals(signUpResponse.entrySet().size(), 2); @@ -131,6 +140,9 @@ public void testGoodInput() throws Exception { assertEquals(signUpUser.get("email").getAsString(), "random@gmail.com"); assertNotNull(signUpUser.get("id")); + int activeUsers = ActiveUsers.countUsersActiveSince(process.getProcess(), beforeSignUpTs); + assert (activeUsers == 1); + UserInfo user = StorageLayer.getEmailPasswordStorage(process.getProcess()) .getUserInfoUsingEmail("random@gmail.com"); assertEquals(user.email, signUpUser.get("email").getAsString()); @@ -160,7 +172,8 @@ public void testGoodInput() throws Exception { // Test the normalise email function // Test that only the normalised email is saved in the db - // Failure condition: If the email retrieved from the data is not normalised the test will fail + // Failure condition: If the email retrieved from the data is not normalised the + // test will fail @Test public void testTheNormaliseEmailFunction() throws Exception { String[] args = { "../" }; diff --git a/src/test/java/io/supertokens/test/passwordless/PasswordlessConsumeCodeTest.java b/src/test/java/io/supertokens/test/passwordless/PasswordlessConsumeCodeTest.java index bd716a73d..e83aaa81e 100644 --- a/src/test/java/io/supertokens/test/passwordless/PasswordlessConsumeCodeTest.java +++ b/src/test/java/io/supertokens/test/passwordless/PasswordlessConsumeCodeTest.java @@ -234,7 +234,8 @@ public void testConsumeLinkCodeWithExistingUser() throws Exception { } /** - * Check device clean up when user input code is generated via email & phone number + * Check device clean up when user input code is generated via email & phone + * number * * @throws Exception */ @@ -732,7 +733,8 @@ public void testConsumeWrongLinkCodeExceedingMaxAttempts() throws Exception { } /** - * user input code with too many failedAttempts (changed maxCodeInputAttempts configuration between consumes) + * user input code with too many failedAttempts (changed maxCodeInputAttempts + * configuration between consumes) * TODO: review -> do we need to create code again post restart ? * * @throws Exception diff --git a/src/test/java/io/supertokens/test/passwordless/api/PasswordlessConsumeCodeAPITest2_11.java b/src/test/java/io/supertokens/test/passwordless/api/PasswordlessConsumeCodeAPITest2_11.java index 9cfe9eaa6..cf8f36950 100644 --- a/src/test/java/io/supertokens/test/passwordless/api/PasswordlessConsumeCodeAPITest2_11.java +++ b/src/test/java/io/supertokens/test/passwordless/api/PasswordlessConsumeCodeAPITest2_11.java @@ -27,6 +27,7 @@ import org.junit.Test; import org.junit.rules.TestRule; +import io.supertokens.ActiveUsers; import io.supertokens.ProcessState; import io.supertokens.passwordless.Passwordless; import io.supertokens.passwordless.Passwordless.CreateCodeResponse; @@ -62,6 +63,8 @@ public void testBadInput() throws Exception { return; } + long startTs = System.currentTimeMillis(); + String email = "test@example.com"; CreateCodeResponse createResp = Passwordless.createCode(process.getProcess(), email, null, null, null); { @@ -276,6 +279,9 @@ public void testBadInput() throws Exception { error.getMessage()); } + int activeUsers = ActiveUsers.countUsersActiveSince(process.getProcess(), startTs); + assert (activeUsers == 0); + process.kill(); assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); } @@ -291,6 +297,8 @@ public void testLinkCode() throws Exception { return; } + long startTs = System.currentTimeMillis(); + String email = "test@example.com"; CreateCodeResponse createResp = Passwordless.createCode(process.getProcess(), email, null, null, null); @@ -304,6 +312,9 @@ public void testLinkCode() throws Exception { checkResponse(response, true, email, null); + int activeUsers = ActiveUsers.countUsersActiveSince(process.getProcess(), startTs); + assert (activeUsers == 1); + process.kill(); assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); } @@ -321,6 +332,8 @@ public void testExpiredLinkCode() throws Exception { return; } + long startTs = System.currentTimeMillis(); + String email = "test@example.com"; CreateCodeResponse createResp = Passwordless.createCode(process.getProcess(), email, null, null, null); Thread.sleep(150); @@ -334,6 +347,9 @@ public void testExpiredLinkCode() throws Exception { assertEquals("RESTART_FLOW_ERROR", response.get("status").getAsString()); + int activeUsers = ActiveUsers.countUsersActiveSince(process.getProcess(), startTs); + assert (activeUsers == 0); + process.kill(); assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); } @@ -349,6 +365,8 @@ public void testUserInputCode() throws Exception { return; } + long startTs = System.currentTimeMillis(); + String email = "test@example.com"; CreateCodeResponse createResp = Passwordless.createCode(process.getProcess(), email, null, null, null); @@ -363,6 +381,9 @@ public void testUserInputCode() throws Exception { checkResponse(response, true, email, null); + int activeUsers = ActiveUsers.countUsersActiveSince(process.getProcess(), startTs); + assert (activeUsers == 1); + process.kill(); assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); } @@ -379,6 +400,8 @@ public void testExpiredUserInputCode() throws Exception { return; } + long startTs = System.currentTimeMillis(); + String email = "test@example.com"; CreateCodeResponse createResp = Passwordless.createCode(process.getProcess(), email, null, null, null); Thread.sleep(150); @@ -394,6 +417,9 @@ public void testExpiredUserInputCode() throws Exception { assertEquals("EXPIRED_USER_INPUT_CODE_ERROR", response.get("status").getAsString()); + int activeUsers = ActiveUsers.countUsersActiveSince(process.getProcess(), startTs); + assert (activeUsers == 0); + process.kill(); assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); } diff --git a/src/test/java/io/supertokens/test/session/api/RefreshSessionAPITest2_7.java b/src/test/java/io/supertokens/test/session/api/RefreshSessionAPITest2_7.java index 6fd495fd2..72ab594e5 100644 --- a/src/test/java/io/supertokens/test/session/api/RefreshSessionAPITest2_7.java +++ b/src/test/java/io/supertokens/test/session/api/RefreshSessionAPITest2_7.java @@ -18,6 +18,8 @@ import com.google.gson.JsonNull; import com.google.gson.JsonObject; + +import io.supertokens.ActiveUsers; import io.supertokens.ProcessState; import io.supertokens.test.TestingProcessManager; import io.supertokens.test.Utils; @@ -35,7 +37,7 @@ import static org.junit.Assert.fail; public class RefreshSessionAPITest2_7 { - @Rule + @Rule public TestRule watchman = Utils.getOnFailure(); @AfterClass @@ -187,6 +189,8 @@ public void badInputErrorTest() throws Exception { TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + long start1 = System.currentTimeMillis(); + try { HttpRequestForTesting.sendJsonPOSTRequest(process.getProcess(), "", "http://localhost:3567/recipe/session/refresh", null, 1000, 1000, null, @@ -195,12 +199,18 @@ public void badInputErrorTest() throws Exception { } catch (io.supertokens.test.httpRequest.HttpResponseException e) { assertEquals("Http error. Status Code: 400. Message: Invalid Json Input", e.getMessage()); } + + int activeUsers = ActiveUsers.countUsersActiveSince(process.getProcess(), start1); + assert (activeUsers == 0); + process.kill(); assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); process = TestingProcessManager.start(args); assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + long start2 = System.currentTimeMillis(); + try { JsonObject jsonBody = new JsonObject(); jsonBody.addProperty("random", "random"); @@ -214,6 +224,13 @@ public void badInputErrorTest() throws Exception { } + activeUsers = ActiveUsers.countUsersActiveSince(process.getProcess(), start2); + assert (activeUsers == 0); + + long start3 = System.currentTimeMillis(); + + long start3Inner = 0; // to be set after session is created + try { String userId = "userId"; JsonObject userDataInJWT = new JsonObject(); @@ -236,6 +253,8 @@ public void badInputErrorTest() throws Exception { sessionRefreshBody.addProperty("refreshToken", sessionInfo.get("refreshToken").getAsJsonObject().get("token").getAsString()); + start3Inner = System.currentTimeMillis(); + HttpRequestForTesting.sendJsonPOSTRequest(process.getProcess(), "", "http://localhost:3567/recipe/session/refresh", sessionRefreshBody, 1000, 1000, null, Utils.getCdiVersion2_7ForTests(), "session"); @@ -246,6 +265,16 @@ public void badInputErrorTest() throws Exception { } + activeUsers = ActiveUsers.countUsersActiveSince(process.getProcess(), start3); + assert (activeUsers == 1); + + int activeUsersAfterSessionCreate = ActiveUsers.countUsersActiveSince(process.getProcess(), start3Inner); + assert (activeUsersAfterSessionCreate == 0); + + long start4 = System.currentTimeMillis(); + + long start4Inner = 0; // to be set after session is created + try { String userId = "userId"; JsonObject userDataInJWT = new JsonObject(); @@ -269,6 +298,8 @@ public void badInputErrorTest() throws Exception { sessionInfo.get("refreshToken").getAsJsonObject().get("token").getAsString()); sessionRefreshBody.addProperty("enableAntiCsrf", "false"); + start4Inner = System.currentTimeMillis(); + HttpRequestForTesting.sendJsonPOSTRequest(process.getProcess(), "", "http://localhost:3567/recipe/session/refresh", sessionRefreshBody, 1000, 1000, null, Utils.getCdiVersion2_7ForTests(), "session"); @@ -278,6 +309,12 @@ public void badInputErrorTest() throws Exception { "Http error. Status Code: 400. Message: Field name 'enableAntiCsrf' is invalid in JSON input"); } + + activeUsers = ActiveUsers.countUsersActiveSince(process.getProcess(), start4); + assert (activeUsers == 1); + + activeUsersAfterSessionCreate = ActiveUsers.countUsersActiveSince(process.getProcess(), start4Inner); + assert (activeUsersAfterSessionCreate == 0); } @Test @@ -375,6 +412,8 @@ public void successOutputWithValidRefreshTokenTest() throws Exception { request.add("userDataInDatabase", userDataInDatabase); request.addProperty("enableAntiCsrf", false); + long startTs = System.currentTimeMillis(); + JsonObject sessionInfo = HttpRequestForTesting.sendJsonPOSTRequest(process.getProcess(), "", "http://localhost:3567/recipe/session", request, 1000, 1000, null, Utils.getCdiVersion2_7ForTests(), "session"); @@ -386,14 +425,19 @@ public void successOutputWithValidRefreshTokenTest() throws Exception { sessionInfo.get("refreshToken").getAsJsonObject().get("token").getAsString()); sessionRefreshBody.addProperty("enableAntiCsrf", false); + long afterSessionCreateTs = System.currentTimeMillis(); + JsonObject sessionRefreshResponse = HttpRequestForTesting.sendJsonPOSTRequest(process.getProcess(), "", "http://localhost:3567/recipe/session/refresh", sessionRefreshBody, 1000, 1000, null, Utils.getCdiVersion2_7ForTests(), "session"); checkRefreshSessionResponse(sessionRefreshResponse, process, userId, userDataInJWT, false); - process.kill(); - assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); + int activeUsers = ActiveUsers.countUsersActiveSince(process.getProcess(), startTs); + assert (activeUsers == 1); + + int activeUsersAfterSessionCreate = ActiveUsers.countUsersActiveSince(process.getProcess(), afterSessionCreateTs); + assert (activeUsersAfterSessionCreate == 1); } @Test diff --git a/src/test/java/io/supertokens/test/session/api/SessionAPITest2_7.java b/src/test/java/io/supertokens/test/session/api/SessionAPITest2_7.java index 3af20714b..7813238c0 100644 --- a/src/test/java/io/supertokens/test/session/api/SessionAPITest2_7.java +++ b/src/test/java/io/supertokens/test/session/api/SessionAPITest2_7.java @@ -17,6 +17,8 @@ package io.supertokens.test.session.api; import com.google.gson.JsonObject; + +import io.supertokens.ActiveUsers; import io.supertokens.ProcessState; import io.supertokens.test.TestingProcessManager; import io.supertokens.test.Utils; @@ -54,6 +56,8 @@ public void successOutputCheckWithAntiCsrfWithCookieDomain() throws Exception { TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + long startTs = System.currentTimeMillis(); + String userId = "userId"; JsonObject userDataInJWT = new JsonObject(); userDataInJWT.addProperty("key", "value"); @@ -73,6 +77,9 @@ public void successOutputCheckWithAntiCsrfWithCookieDomain() throws Exception { checkSessionResponse(response, process, userId, userDataInJWT); assertTrue(response.has("antiCsrfToken")); + int activeUsers = ActiveUsers.countUsersActiveSince(process.getProcess(), startTs); + assert (activeUsers == 1); + process.kill(); assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); } @@ -150,6 +157,8 @@ public void badInputTest() throws Exception { JsonObject userDataInDatabase = new JsonObject(); userDataInDatabase.addProperty("key", "value"); + long startTs = System.currentTimeMillis(); + try { JsonObject request = new JsonObject(); request.add("userDataInJWT", userDataInJWT); @@ -219,6 +228,9 @@ public void badInputTest() throws Exception { + "input"); } + int activeUsers = ActiveUsers.countUsersActiveSince(process.getProcess(), startTs); + assert (activeUsers == 0); + JsonObject request = new JsonObject(); request.addProperty("userId", userId); request.add("userDataInJWT", userDataInJWT); @@ -227,6 +239,9 @@ public void badInputTest() throws Exception { HttpRequestForTesting.sendJsonPOSTRequest(process.getProcess(), "", "http://localhost:3567/recipe/session", request, 1000, 1000, null, Utils.getCdiVersion2_7ForTests(), "session"); + activeUsers = ActiveUsers.countUsersActiveSince(process.getProcess(), startTs); + assert (activeUsers == 1); + request = new JsonObject(); request.addProperty("userId", userId); request.add("userDataInJWT", userDataInJWT); @@ -235,6 +250,9 @@ public void badInputTest() throws Exception { HttpRequestForTesting.sendJsonPOSTRequest(process.getProcess(), "", "http://localhost:3567/recipe/session", request, 1000, 1000, null, Utils.getCdiVersion2_7ForTests(), "session"); + activeUsers = ActiveUsers.countUsersActiveSince(process.getProcess(), startTs); + assert (activeUsers == 1); + process.kill(); assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); } diff --git a/src/test/java/io/supertokens/test/session/api/SessionAPITest2_9.java b/src/test/java/io/supertokens/test/session/api/SessionAPITest2_9.java index 6ebc5835e..8f1bfe467 100644 --- a/src/test/java/io/supertokens/test/session/api/SessionAPITest2_9.java +++ b/src/test/java/io/supertokens/test/session/api/SessionAPITest2_9.java @@ -19,6 +19,8 @@ import com.google.gson.JsonArray; import com.google.gson.JsonNull; import com.google.gson.JsonObject; + +import io.supertokens.ActiveUsers; import io.supertokens.ProcessState; import io.supertokens.test.TestingProcessManager; import io.supertokens.test.Utils; @@ -56,6 +58,8 @@ public void successOutputCheckWithAntiCsrfWithCookieDomain() throws Exception { TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + long startTs = System.currentTimeMillis(); + String userId = "userId"; JsonObject userDataInJWT = new JsonObject(); userDataInJWT.addProperty("key", "value"); @@ -75,6 +79,9 @@ public void successOutputCheckWithAntiCsrfWithCookieDomain() throws Exception { checkSessionResponse(response, process, userId, userDataInJWT); assertTrue(response.has("antiCsrfToken")); + int activeUsers = ActiveUsers.countUsersActiveSince(process.getProcess(), startTs); + assert (activeUsers == 1); + process.kill(); assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); } @@ -177,6 +184,8 @@ public void badInputTest() throws Exception { TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + long startTs = System.currentTimeMillis(); + String userId = "userId"; JsonObject userDataInJWT = new JsonObject(); userDataInJWT.addProperty("key", "value"); @@ -251,6 +260,9 @@ public void badInputTest() throws Exception { "Http error. Status Code: 400. Message: Field name 'userDataInDatabase' is invalid in JSON " + "input"); } + + int activeUsers = ActiveUsers.countUsersActiveSince(process.getProcess(), startTs); + assert (activeUsers == 0); JsonObject request = new JsonObject(); request.addProperty("userId", userId); @@ -267,6 +279,9 @@ public void badInputTest() throws Exception { request.addProperty("enableAntiCsrf", false); HttpRequestForTesting.sendJsonPOSTRequest(process.getProcess(), "", "http://localhost:3567/recipe/session", request, 1000, 1000, null, Utils.getCdiVersion2_9ForTests(), "session"); + + activeUsers = ActiveUsers.countUsersActiveSince(process.getProcess(), startTs); + assert (activeUsers == 1); process.kill(); assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); diff --git a/src/test/java/io/supertokens/test/session/api/SessionRemoveAPITest2_7.java b/src/test/java/io/supertokens/test/session/api/SessionRemoveAPITest2_7.java index f534ce7e9..39c463a0f 100644 --- a/src/test/java/io/supertokens/test/session/api/SessionRemoveAPITest2_7.java +++ b/src/test/java/io/supertokens/test/session/api/SessionRemoveAPITest2_7.java @@ -19,6 +19,8 @@ import com.google.gson.JsonArray; import com.google.gson.JsonObject; import com.google.gson.JsonParser; + +import io.supertokens.ActiveUsers; import io.supertokens.ProcessState; import io.supertokens.test.TestingProcessManager; import io.supertokens.test.Utils; @@ -98,6 +100,8 @@ public void testRemovingMultipleSessionsGivesCorrectOutput() throws Exception { // remove s2 and s4 and make sure they are returned + long checkpoint1 = System.currentTimeMillis(); + String sessionRemoveBodyString = "{" + " sessionHandles : [ " + s2Info.get("session").getAsJsonObject().get("handle").getAsString() + " ," + s4Info.get("session").getAsJsonObject().get("handle").getAsString() + " ] " + "}"; @@ -107,6 +111,9 @@ public void testRemovingMultipleSessionsGivesCorrectOutput() throws Exception { Utils.getCdiVersion2_7ForTests(), "session"); JsonArray revokedSessions = sessionRemovedResponse.getAsJsonArray("sessionHandlesRevoked"); + int activeUsers = ActiveUsers.countUsersActiveSince(process.getProcess(), checkpoint1); + assert (activeUsers == 0); // user ID is not set so not counted as active (we don't have userId) + for (int i = 0; i < revokedSessions.size(); i++) { assertTrue(sessionRemoveBody.getAsJsonArray("sessionHandles").contains(revokedSessions.get(i))); } @@ -120,6 +127,8 @@ public void testRemovingMultipleSessionsGivesCorrectOutput() throws Exception { + s4Info.get("session").getAsJsonObject().get("handle").getAsString() + " ] " + "}"; sessionRemoveBody = new JsonParser().parse(sessionRemoveBodyString).getAsJsonObject(); + long checkpoint2 = System.currentTimeMillis(); + sessionRemovedResponse = HttpRequestForTesting.sendJsonPOSTRequest(process.getProcess(), "", "http://localhost:3567/recipe/session/remove", sessionRemoveBody, 1000, 1000, null, Utils.getCdiVersion2_7ForTests(), "session"); @@ -131,6 +140,9 @@ public void testRemovingMultipleSessionsGivesCorrectOutput() throws Exception { assertEquals(revokedSessions.size(), 2); + activeUsers = ActiveUsers.countUsersActiveSince(process.getProcess(), checkpoint2); + assert (activeUsers == 0); // user ID is not set so not counted as active (we don't have userId) + process.kill(); assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); @@ -221,6 +233,8 @@ public void testRemovingSessionByUserId() throws Exception { "session"); assertEquals(session2Info.get("status").getAsString(), "OK"); + long checkpoint1 = System.currentTimeMillis(); + // remove session using user id JsonObject removeSessionBody = new JsonObject(); removeSessionBody.addProperty("userId", userId); @@ -237,6 +251,9 @@ public void testRemovingSessionByUserId() throws Exception { assertTrue(sessionRemovedResponse.getAsJsonArray("sessionHandlesRevoked") .contains(session2Info.get("session").getAsJsonObject().get("handle"))); + int activeUsers = ActiveUsers.countUsersActiveSince(process.getProcess(), checkpoint1); + assert (activeUsers == 1); // user ID is set + // check that the number of sessions for user is 0 Map userParams = new HashMap<>(); userParams.put("userId", userId); diff --git a/src/test/java/io/supertokens/test/thirdparty/api/ThirdPartySignInUpAPITest2_7.java b/src/test/java/io/supertokens/test/thirdparty/api/ThirdPartySignInUpAPITest2_7.java index 8409e2809..a9a01195c 100644 --- a/src/test/java/io/supertokens/test/thirdparty/api/ThirdPartySignInUpAPITest2_7.java +++ b/src/test/java/io/supertokens/test/thirdparty/api/ThirdPartySignInUpAPITest2_7.java @@ -17,6 +17,8 @@ package io.supertokens.test.thirdparty.api; import com.google.gson.JsonObject; + +import io.supertokens.ActiveUsers; import io.supertokens.ProcessState; import io.supertokens.emailverification.EmailVerification; import io.supertokens.pluginInterface.STORAGE_TYPE; @@ -69,6 +71,8 @@ public void testGoodInput() throws Exception { return; } + long startTs = System.currentTimeMillis(); + JsonObject response = Utils.signInUpRequest_2_7(process, "test@example.com", true, "testThirdPartyId", "testThirdPartyUserId"); checkSignInUpResponse(response, "testThirdPartyId", "testThirdPartyUserId", "test@example.com", true); @@ -78,6 +82,10 @@ public void testGoodInput() throws Exception { assertTrue(EmailVerification.isEmailVerified(process.getProcess(), user.get("id").getAsString(), user.get("email").getAsString())); } + + int activeUsers = ActiveUsers.countUsersActiveSince(process.getProcess(), startTs); + assert (activeUsers == 1); + process.kill(); assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); } @@ -125,6 +133,9 @@ public void testBadInput() throws Exception { if (StorageLayer.getStorage(process.getProcess()).getType() != STORAGE_TYPE.SQL) { return; } + + long startTs = System.currentTimeMillis(); + { try { HttpRequestForTesting.sendJsonPOSTRequest(process.getProcess(), "", @@ -174,6 +185,10 @@ public void testBadInput() throws Exception { + "in " + "JSON input")); } } + + int activeUsers = ActiveUsers.countUsersActiveSince(process.getProcess(), startTs); + assert (activeUsers == 0); + process.kill(); assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); } diff --git a/src/test/java/io/supertokens/test/thirdparty/api/ThirdPartySignInUpAPITest2_8.java b/src/test/java/io/supertokens/test/thirdparty/api/ThirdPartySignInUpAPITest2_8.java index 59dcbec82..25d697557 100644 --- a/src/test/java/io/supertokens/test/thirdparty/api/ThirdPartySignInUpAPITest2_8.java +++ b/src/test/java/io/supertokens/test/thirdparty/api/ThirdPartySignInUpAPITest2_8.java @@ -17,6 +17,8 @@ package io.supertokens.test.thirdparty.api; import com.google.gson.JsonObject; + +import io.supertokens.ActiveUsers; import io.supertokens.ProcessState; import io.supertokens.emailverification.EmailVerification; import io.supertokens.pluginInterface.STORAGE_TYPE; @@ -69,6 +71,8 @@ public void testGoodInput() throws Exception { return; } + long startTs = System.currentTimeMillis(); + JsonObject response = Utils.signInUpRequest_2_8(process, "test@examplE.com", "testThirdPartyId", "testThirdPartyUserId"); checkSignInUpResponse(response, "testThirdPartyId", "testThirdPartyUserId", "test@example.com", true); @@ -78,6 +82,10 @@ public void testGoodInput() throws Exception { assertFalse(EmailVerification.isEmailVerified(process.getProcess(), user.get("id").getAsString(), user.get("email").getAsString())); } + + int activeUsers = ActiveUsers.countUsersActiveSince(process.getProcess(), startTs); + assert (activeUsers == 1); + process.kill(); assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); } @@ -125,6 +133,9 @@ public void testBadInput() throws Exception { if (StorageLayer.getStorage(process.getProcess()).getType() != STORAGE_TYPE.SQL) { return; } + + long startTs = System.currentTimeMillis(); + { try { HttpRequestForTesting.sendJsonPOSTRequest(process.getProcess(), "", @@ -174,6 +185,10 @@ public void testBadInput() throws Exception { + "in " + "JSON input")); } } + + int activeUsers = ActiveUsers.countUsersActiveSince(process.getProcess(), startTs); + assert (activeUsers == 0); + process.kill(); assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); } diff --git a/src/test/java/io/supertokens/test/totp/TOTPRecipeTest.java b/src/test/java/io/supertokens/test/totp/TOTPRecipeTest.java new file mode 100644 index 000000000..6841f9982 --- /dev/null +++ b/src/test/java/io/supertokens/test/totp/TOTPRecipeTest.java @@ -0,0 +1,524 @@ +/* + * Copyright (c) 2021, VRAI Labs and/or its affiliates. All rights reserved. + * + * This software is licensed under the Apache License, Version 2.0 (the + * "License") as published by the Apache Software Foundation. + * + * You may not use this file except in compliance with the License. You may + * obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.supertokens.test.totp; + +import com.eatthepath.otp.TimeBasedOneTimePasswordGenerator; +import io.supertokens.Main; +import io.supertokens.ProcessState; +import io.supertokens.config.Config; +import io.supertokens.cronjobs.deleteExpiredTotpTokens.DeleteExpiredTotpTokens; +import io.supertokens.featureflag.EE_FEATURES; +import io.supertokens.featureflag.FeatureFlagTestContent; +import io.supertokens.featureflag.exceptions.InvalidLicenseKeyException; +import io.supertokens.httpRequest.HttpResponseException; +import io.supertokens.pluginInterface.STORAGE_TYPE; +import io.supertokens.pluginInterface.exceptions.StorageQueryException; +import io.supertokens.pluginInterface.exceptions.StorageTransactionLogicException; +import io.supertokens.pluginInterface.totp.TOTPDevice; +import io.supertokens.pluginInterface.totp.TOTPStorage; +import io.supertokens.pluginInterface.totp.TOTPUsedCode; +import io.supertokens.pluginInterface.totp.exception.DeviceAlreadyExistsException; +import io.supertokens.pluginInterface.totp.exception.TotpNotEnabledException; +import io.supertokens.pluginInterface.totp.exception.UnknownDeviceException; +import io.supertokens.pluginInterface.totp.sqlStorage.TOTPSQLStorage; +import io.supertokens.storageLayer.StorageLayer; +import io.supertokens.test.TestingProcessManager; +import io.supertokens.test.Utils; +import io.supertokens.totp.Totp; +import io.supertokens.totp.exceptions.InvalidTotpException; +import io.supertokens.totp.exceptions.LimitReachedException; +import org.apache.commons.codec.binary.Base32; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TestRule; + +import javax.crypto.spec.SecretKeySpec; +import java.io.IOException; +import java.security.InvalidKeyException; +import java.security.Key; +import java.time.Duration; +import java.time.Instant; + +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertThrows; + +// TODO: Add test for UsedCodeAlreadyExistsException once we implement time mocking + +public class TOTPRecipeTest { + + @Rule + public TestRule watchman = Utils.getOnFailure(); + + @AfterClass + public static void afterTesting() { + Utils.afterTesting(); + } + + @Before + public void beforeEach() { + Utils.reset(); + } + + public class TestSetupResult { + public TOTPStorage storage; + public TestingProcessManager.TestingProcess process; + + public TestSetupResult(TOTPStorage storage, TestingProcessManager.TestingProcess process) { + this.storage = storage; + this.process = process; + } + } + + public TestSetupResult defaultInit() + throws InterruptedException, IOException, StorageQueryException, InvalidLicenseKeyException, + HttpResponseException { + String[] args = {"../"}; + + TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + + if (StorageLayer.getStorage(process.getProcess()).getType() != STORAGE_TYPE.SQL) { + return null; + } + TOTPStorage storage = StorageLayer.getTOTPStorage(process.getProcess()); + + FeatureFlagTestContent.getInstance(process.main) + .setKeyValue(FeatureFlagTestContent.ENABLED_FEATURES, new EE_FEATURES[]{EE_FEATURES.TOTP}); + + return new TestSetupResult(storage, process); + } + + public static String generateTotpCode(Main main, TOTPDevice device) + throws InvalidKeyException, StorageQueryException { + return generateTotpCode(main, device, 0); + } + + /** + * Generates TOTP code similar to apps like Google Authenticator and Authy + */ + private static String generateTotpCode(Main main, TOTPDevice device, int step) + throws InvalidKeyException, StorageQueryException { + final TimeBasedOneTimePasswordGenerator totp = new TimeBasedOneTimePasswordGenerator( + Duration.ofSeconds(device.period)); + + byte[] keyBytes = new Base32().decode(device.secretKey); + Key key = new SecretKeySpec(keyBytes, "HmacSHA1"); + + return totp.generateOneTimePasswordString(key, Instant.now().plusSeconds(step * device.period)); + } + + private static TOTPUsedCode[] getAllUsedCodesUtil(TOTPStorage storage, String userId) + throws StorageQueryException, StorageTransactionLogicException { + assert storage instanceof TOTPSQLStorage; + TOTPSQLStorage sqlStorage = (TOTPSQLStorage) storage; + + return (TOTPUsedCode[]) sqlStorage.startTransaction(con -> { + TOTPUsedCode[] usedCodes = sqlStorage.getAllUsedCodesDescOrder_Transaction(con, userId); + sqlStorage.commitTransaction(con); + return usedCodes; + }); + } + + @Test + public void createDeviceTest() throws Exception { + TestSetupResult result = defaultInit(); + if (result == null) { + return; + } + Main main = result.process.getProcess(); + + // Create device + TOTPDevice device = Totp.registerDevice(main, "user", "device1", 1, 30); + assert device.secretKey != ""; + + // Create same device again (should fail) + assertThrows(DeviceAlreadyExistsException.class, + () -> Totp.registerDevice(main, "user", "device1", 1, 30)); + } + + @Test + public void createDeviceAndVerifyCodeTest() throws Exception { + TestSetupResult result = defaultInit(); + if (result == null) { + return; + } + Main main = result.process.getProcess(); + + // Create device + TOTPDevice device = Totp.registerDevice(main, "user", "device1", 1, 1); + + // Try login with non-existent user: + assertThrows(TotpNotEnabledException.class, + () -> Totp.verifyCode(main, "non-existent-user", "any-code", true)); + + // {Code: [INVALID, VALID]} * {Devices: [VERIFIED_ONLY, ALL]} + + // Invalid code & allowUnverifiedDevice = true: + assertThrows(InvalidTotpException.class, + () -> Totp.verifyCode(main, "user", "invalid", true)); + + // Invalid code & allowUnverifiedDevice = false: + assertThrows(InvalidTotpException.class, + () -> Totp.verifyCode(main, "user", "invalid", false)); + + // Valid code & allowUnverifiedDevice = false: + assertThrows( + InvalidTotpException.class, + () -> Totp.verifyCode(main, "user", generateTotpCode(main, device), false)); + + // Valid code & allowUnverifiedDevice = true (Success): + String validCode = generateTotpCode(main, device); + Totp.verifyCode(main, "user", validCode, true); + + // Now try again with same code: + assertThrows( + InvalidTotpException.class, + () -> Totp.verifyCode(main, "user", validCode, true)); + + // Sleep for 1s so that code changes. + Thread.sleep(1000); + + // Use a new valid code: + String newValidCode = generateTotpCode(main, device); + Totp.verifyCode(main, "user", newValidCode, true); + + // Reuse the same code and use it again (should fail): + assertThrows(InvalidTotpException.class, + () -> Totp.verifyCode(main, "user", newValidCode, true)); + + // Use a code from next period: + String nextValidCode = generateTotpCode(main, device, 1); + Totp.verifyCode(main, "user", nextValidCode, true); + + // Use previous period code (should fail coz validCode has been used): + String previousCode = generateTotpCode(main, device, -1); + assert previousCode.equals(validCode); + assertThrows(InvalidTotpException.class, () -> Totp.verifyCode(main, "user", previousCode, true)); + + // Create device with skew = 0, check that it only works with the current code + TOTPDevice device2 = Totp.registerDevice(main, "user", "device2", 0, 1); + assert device2.secretKey != device.secretKey; + + String nextValidCode2 = generateTotpCode(main, device2, 1); + assertThrows(InvalidTotpException.class, + () -> Totp.verifyCode(main, "user", nextValidCode2, true)); + + String previousValidCode2 = generateTotpCode(main, device2, -1); + assertThrows(InvalidTotpException.class, + () -> Totp.verifyCode(main, "user", previousValidCode2, true)); + + String currentValidCode2 = generateTotpCode(main, device2); + Totp.verifyCode(main, "user", currentValidCode2, true); + + // Submit invalid code and check that it's expiry time is correct + // created - expiryTime = max of ((2 * skew + 1) * period) for all devices + assertThrows(InvalidTotpException.class, + () -> Totp.verifyCode(main, "user", "invalid", true)); + + TOTPUsedCode[] usedCodes = getAllUsedCodesUtil(result.storage, "user"); + TOTPUsedCode latestCode = usedCodes[0]; + assert latestCode.isValid == false; + assert latestCode.expiryTime - latestCode.createdTime == 3000; // it should be 3s because of device1 + } + + /* + * Triggers rate limiting and checks that it works. + * It returns the number of attempts that were made before rate limiting was + * triggered. + */ + public int triggerAndCheckRateLimit(Main main, TOTPDevice device) throws Exception { + int N = Config.getConfig(main).getTotpMaxAttempts(); + + // First N attempts should fail with invalid code: + // This is to trigger rate limiting + for (int i = 0; i < N; i++) { + String code = "ic-" + i; + assertThrows( + InvalidTotpException.class, + () -> Totp.verifyCode(main, "user", code, true)); + } + + // Any kind of attempt after this should fail with rate limiting error. + // This should happen until rate limiting cooldown happens: + assertThrows( + LimitReachedException.class, + () -> Totp.verifyCode(main, "user", "icN+1", true)); + assertThrows( + LimitReachedException.class, + () -> Totp.verifyCode(main, "user", generateTotpCode(main, device), true)); + assertThrows( + LimitReachedException.class, + () -> Totp.verifyCode(main, "user", "icN+2", true)); + + return N; + } + + @Test + public void rateLimitCooldownTest() throws Exception { + String[] args = {"../"}; + + // set rate limiting cooldown time to 1s + Utils.setValueInConfig("totp_rate_limit_cooldown_sec", "1"); + // set max attempts to 3 + Utils.setValueInConfig("totp_max_attempts", "3"); + + TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + + if (StorageLayer.getStorage(process.getProcess()).getType() != STORAGE_TYPE.SQL) { + return; + } + + FeatureFlagTestContent.getInstance(process.main) + .setKeyValue(FeatureFlagTestContent.ENABLED_FEATURES, new EE_FEATURES[]{EE_FEATURES.TOTP}); + + Main main = process.getProcess(); + + // Create device + TOTPDevice device = Totp.registerDevice(main, "user", "deviceName", 1, 1); + + // Trigger rate limiting and fix it with a correct code after some time: + int attemptsRequired = triggerAndCheckRateLimit(main, device); + assert attemptsRequired == 3; + // Wait for 1 second (Should cool down rate limiting): + Thread.sleep(1000); + // But again try with invalid code: + assertThrows(InvalidTotpException.class, () -> Totp.verifyCode(main, "user", "invalid0", true)); + // This triggered rate limiting again. So even valid codes will fail for + // another cooldown period: + assertThrows(LimitReachedException.class, + () -> Totp.verifyCode(main, "user", generateTotpCode(main, device), true)); + // Wait for 1 second (Should cool down rate limiting): + Thread.sleep(1000); + // Now try with valid code: + Totp.verifyCode(main, "user", generateTotpCode(main, device), true); + // Now invalid code shouldn't trigger rate limiting. Unless you do it N times: + assertThrows(InvalidTotpException.class, () -> Totp.verifyCode(main, "user", "invaldd", true)); + } + + @Test + public void cronRemovesCodesDuringRateLimitTest() throws Exception { + // This test is flaky because of time. + TestSetupResult result = defaultInit(); + if (result == null) { + return; + } + Main main = result.process.getProcess(); + + // Create device + TOTPDevice device = Totp.registerDevice(main, "user", "deviceName", 0, 1); + + // Trigger rate limiting and fix it with cronjob (manually run cronjob): + int attemptsRequired = triggerAndCheckRateLimit(main, device); + assert attemptsRequired == 5; + // Wait for 1 second so that all the codes expire: + Thread.sleep(1500); + // Manually run cronjob to delete all the codes after their + // expiry time + rate limiting period is over: + DeleteExpiredTotpTokens.getInstance(main).run(); + + // This removal shouldn't affect rate limiting. User must remain rate limited. + assertThrows(LimitReachedException.class, + () -> Totp.verifyCode(main, "user", generateTotpCode(main, device), true)); + assertThrows(LimitReachedException.class, + () -> Totp.verifyCode(main, "user", "yet-ic", true)); + } + + @Test + public void createAndVerifyDeviceTest() throws Exception { + TestSetupResult result = defaultInit(); + if (result == null) { + return; + } + Main main = result.process.getProcess(); + + // Create device + TOTPDevice device = Totp.registerDevice(main, "user", "deviceName", 1, 30); + + // Try verify non-existent user: + assertThrows(TotpNotEnabledException.class, + () -> Totp.verifyDevice(main, "non-existent-user", "deviceName", "XXXX")); + + // Try verify non-existent device + assertThrows(UnknownDeviceException.class, + () -> Totp.verifyDevice(main, "user", "non-existent-device", "XXXX")); + + // Verify device with wrong code + assertThrows(InvalidTotpException.class, () -> Totp.verifyDevice(main, "user", "deviceName", "ic0")); + + // Verify device with correct code + String validCode = generateTotpCode(main, device); + boolean justVerfied = Totp.verifyDevice(main, "user", "deviceName", validCode); + assert justVerfied; + + // Verify again with same correct code: + justVerfied = Totp.verifyDevice(main, "user", "deviceName", validCode); + assert !justVerfied; + + // Verify again with new correct code: + String newValidCode = generateTotpCode(main, device); + justVerfied = Totp.verifyDevice(main, "user", "deviceName", newValidCode); + assert !justVerfied; + + // Verify again with wrong code: + justVerfied = Totp.verifyDevice(main, "user", "deviceName", "ic1"); + assert !justVerfied; + + result.process.kill(); + assertNotNull(result.process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); + } + + @Test + public void removeDeviceTest() throws Exception { + // Flaky test. + TestSetupResult result = defaultInit(); + if (result == null) { + return; + } + Main main = result.process.getProcess(); + TOTPStorage storage = result.storage; + + // Create devices + TOTPDevice device1 = Totp.registerDevice(main, "user", "device1", 1, 30); + TOTPDevice device2 = Totp.registerDevice(main, "user", "device2", 1, 30); + + TOTPDevice[] devices = Totp.getDevices(main, "user"); + assert (devices.length == 2); + + // Try to delete device for non-existent user: + assertThrows(TotpNotEnabledException.class, () -> Totp.removeDevice(main, "non-existent-user", "device1")); + + // Try to delete non-existent device: + assertThrows(UnknownDeviceException.class, () -> Totp.removeDevice(main, "user", "non-existent-device")); + + // Delete one of the devices + { + assertThrows(InvalidTotpException.class, () -> Totp.verifyCode(main, "user", "ic0", true)); + Totp.verifyCode(main, "user", generateTotpCode(main, device1), true); + Totp.verifyCode(main, "user", generateTotpCode(main, device2), true); + + // Delete device1 + Totp.removeDevice(main, "user", "device1"); + + devices = Totp.getDevices(main, "user"); + assert (devices.length == 1); + + // 1 device still remain so all codes should still be still there: + TOTPUsedCode[] usedCodes = getAllUsedCodesUtil(storage, "user"); + assert (usedCodes.length == 3); + } + + // Deleting the last device of a user should delete all related codes: + // Delete the 2nd (and the last) device + { + + // Create another user to test that other users aren't affected: + TOTPDevice otherUserDevice = Totp.registerDevice(main, "other-user", "device", 1, 30); + Totp.verifyCode(main, "other-user", generateTotpCode(main, otherUserDevice), true); + assertThrows(InvalidTotpException.class, () -> Totp.verifyCode(main, "other-user", "ic1", true)); + + // Delete device2 + Totp.removeDevice(main, "user", "device2"); + + // TOTP has ben disabled for the user: + assertThrows(TotpNotEnabledException.class, () -> Totp.getDevices(main, "user")); + + // No device left so all codes of the user should be deleted: + TOTPUsedCode[] usedCodes = getAllUsedCodesUtil(storage, "user"); + assert (usedCodes.length == 0); + + // But for other users things should still be there: + TOTPDevice[] otherUserDevices = Totp.getDevices(main, "other-user"); + assert (otherUserDevices.length == 1); + + usedCodes = getAllUsedCodesUtil(storage, "other-user"); + assert (usedCodes.length == 2); + } + } + + @Test + public void updateDeviceNameTest() throws Exception { + TestSetupResult result = defaultInit(); + if (result == null) { + return; + } + Main main = result.process.getProcess(); + + Totp.registerDevice(main, "user", "device1", 1, 30); + Totp.registerDevice(main, "user", "device2", 1, 30); + + // Try update non-existent user: + assertThrows(TotpNotEnabledException.class, + () -> Totp.updateDeviceName(main, "non-existent-user", "device1", "new-device-name")); + + // Try update non-existent device: + assertThrows(UnknownDeviceException.class, + () -> Totp.updateDeviceName(main, "user", "non-existent-device", "new-device-name")); + + // Update device name (should work) + Totp.updateDeviceName(main, "user", "device1", "new-device-name"); + + // Verify that the device name has been updated: + TOTPDevice[] devices = Totp.getDevices(main, "user"); + assert (devices.length == 2); + assert (devices[0].deviceName.equals("device2")); + assert (devices[1].deviceName.equals("new-device-name")); + + // Verify that TOTP verification still works: + Totp.verifyDevice(main, "user", devices[0].deviceName, generateTotpCode(main, devices[0])); + Totp.verifyDevice(main, "user", devices[0].deviceName, generateTotpCode(main, devices[1])); + + // Try update device name to an already existing device name: + assertThrows(DeviceAlreadyExistsException.class, + () -> Totp.updateDeviceName(main, "user", "device2", "new-device-name")); + } + + @Test + public void getDevicesTest() throws Exception { + TestSetupResult result = defaultInit(); + if (result == null) { + return; + } + Main main = result.process.getProcess(); + + // Try get devices for non-existent user: + assertThrows(TotpNotEnabledException.class, () -> Totp.getDevices(main, "non-existent-user")); + + TOTPDevice device1 = Totp.registerDevice(main, "user", "device1", 2, 30); + TOTPDevice device2 = Totp.registerDevice(main, "user", "device2", 1, 10); + + TOTPDevice[] devices = Totp.getDevices(main, "user"); + assert (devices.length == 2); + assert devices[0].equals(device1); + assert devices[1].equals(device2); + } + + @Test + public void deleteExpiredTokensCronIntervalTest() throws Exception { + TestSetupResult result = defaultInit(); + if (result == null) { + return; + } + Main main = result.process.getProcess(); + + // Ensure that delete expired tokens cron runs every hour: + assert DeleteExpiredTotpTokens.getInstance(main).getIntervalTimeSeconds() == 60 * 60; + } + +} diff --git a/src/test/java/io/supertokens/test/totp/TOTPStorageTest.java b/src/test/java/io/supertokens/test/totp/TOTPStorageTest.java new file mode 100644 index 000000000..a9d3c0493 --- /dev/null +++ b/src/test/java/io/supertokens/test/totp/TOTPStorageTest.java @@ -0,0 +1,510 @@ +package io.supertokens.test.totp; + +import io.supertokens.ProcessState; +import io.supertokens.cronjobs.deleteExpiredTotpTokens.DeleteExpiredTotpTokens; +import io.supertokens.featureflag.EE_FEATURES; +import io.supertokens.featureflag.FeatureFlagTestContent; +import io.supertokens.featureflag.exceptions.InvalidLicenseKeyException; +import io.supertokens.httpRequest.HttpResponseException; +import io.supertokens.pluginInterface.STORAGE_TYPE; +import io.supertokens.pluginInterface.exceptions.StorageQueryException; +import io.supertokens.pluginInterface.exceptions.StorageTransactionLogicException; +import io.supertokens.pluginInterface.totp.TOTPDevice; +import io.supertokens.pluginInterface.totp.TOTPStorage; +import io.supertokens.pluginInterface.totp.TOTPUsedCode; +import io.supertokens.pluginInterface.totp.exception.DeviceAlreadyExistsException; +import io.supertokens.pluginInterface.totp.exception.TotpNotEnabledException; +import io.supertokens.pluginInterface.totp.exception.UnknownDeviceException; +import io.supertokens.pluginInterface.totp.exception.UsedCodeAlreadyExistsException; +import io.supertokens.pluginInterface.totp.sqlStorage.TOTPSQLStorage; +import io.supertokens.storageLayer.StorageLayer; +import io.supertokens.test.TestingProcessManager; +import io.supertokens.test.Utils; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TestRule; + +import java.io.IOException; + +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertThrows; + +public class TOTPStorageTest { + + public class TestSetupResult { + public TOTPSQLStorage storage; + public TestingProcessManager.TestingProcess process; + + public TestSetupResult(TOTPSQLStorage storage, TestingProcessManager.TestingProcess process) { + this.storage = storage; + this.process = process; + } + } + + @Rule + public TestRule watchman = Utils.getOnFailure(); + + @AfterClass + public static void afterTesting() { + Utils.afterTesting(); + } + + @Before + public void beforeEach() { + Utils.reset(); + } + + public TestSetupResult initSteps() + throws InterruptedException, StorageQueryException, InvalidLicenseKeyException, HttpResponseException, + IOException { + String[] args = {"../"}; + + TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + + if (StorageLayer.getStorage(process.getProcess()).getType() != STORAGE_TYPE.SQL) { + return null; + } + TOTPSQLStorage storage = StorageLayer.getTOTPStorage(process.getProcess()); + + FeatureFlagTestContent.getInstance(process.main) + .setKeyValue(FeatureFlagTestContent.ENABLED_FEATURES, new EE_FEATURES[]{EE_FEATURES.TOTP}); + + return new TestSetupResult(storage, process); + } + + private static TOTPUsedCode[] getAllUsedCodesUtil(TOTPStorage storage, String userId) + throws StorageQueryException, StorageTransactionLogicException { + assert storage instanceof TOTPSQLStorage; + TOTPSQLStorage sqlStorage = (TOTPSQLStorage) storage; + + return (TOTPUsedCode[]) sqlStorage.startTransaction(con -> { + TOTPUsedCode[] usedCodes = sqlStorage.getAllUsedCodesDescOrder_Transaction(con, userId); + sqlStorage.commitTransaction(con); + return usedCodes; + }); + } + + public static void insertUsedCodesUtil(TOTPSQLStorage storage, TOTPUsedCode[] usedCodes) + throws StorageQueryException, StorageTransactionLogicException, TotpNotEnabledException, + UsedCodeAlreadyExistsException { + try { + storage.startTransaction(con -> { + try { + for (TOTPUsedCode usedCode : usedCodes) { + storage.insertUsedCode_Transaction(con, usedCode); + } + } catch (TotpNotEnabledException | UsedCodeAlreadyExistsException e) { + throw new StorageTransactionLogicException(e); + } + storage.commitTransaction(con); + + return null; + }); + } catch (StorageTransactionLogicException e) { + Exception actual = e.actualException; + if (actual instanceof TotpNotEnabledException) { + throw (TotpNotEnabledException) actual; + } else if (actual instanceof UsedCodeAlreadyExistsException) { + throw (UsedCodeAlreadyExistsException) actual; + } + throw e; + } + } + + @Test + public void createDeviceTests() throws Exception { + TestSetupResult result = initSteps(); + if (result == null) { + return; + } + TOTPSQLStorage storage = result.storage; + + TOTPDevice device1 = new TOTPDevice("user", "d1", "secret", 30, 1, false); + TOTPDevice device2 = new TOTPDevice("user", "d2", "secret", 30, 1, true); + TOTPDevice device2Duplicate = new TOTPDevice("user", "d2", "new-secret", 30, 1, false); + + storage.createDevice(device1); + + TOTPDevice[] storedDevices = storage.getDevices("user"); + assert (storedDevices.length == 1); + assert storedDevices[0].equals(device1); + + storage.createDevice(device2); + storedDevices = storage.getDevices("user"); + + assert (storedDevices.length == 2); + assert storedDevices[0].equals(device1); + assert storedDevices[1].equals(device2); + + assertThrows(DeviceAlreadyExistsException.class, () -> storage.createDevice(device2Duplicate)); + + result.process.kill(); + assertNotNull(result.process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); + } + + @Test + public void verifyDeviceTests() throws Exception { + TestSetupResult result = initSteps(); + if (result == null) { + return; + } + TOTPSQLStorage storage = result.storage; + + TOTPDevice device = new TOTPDevice("user", "device", "secretKey", 30, 1, false); + storage.createDevice(device); + + TOTPDevice[] storedDevices = storage.getDevices("user"); + assert (storedDevices.length == 1); + assert (!storedDevices[0].verified); + + // Verify the device: + storage.markDeviceAsVerified("user", "device"); + + storedDevices = storage.getDevices("user"); + assert (storedDevices.length == 1); + assert (storedDevices[0].verified); + + // Try to verify the device again: + storage.markDeviceAsVerified("user", "device"); + + // Try to verify a device that doesn't exist: + assertThrows(UnknownDeviceException.class, () -> storage.markDeviceAsVerified("user", "non-existent-device")); + } + + @Test + public void getDevicesCount_TransactionTests() throws Exception { + TestSetupResult result = initSteps(); + if (result == null) { + return; + } + TOTPSQLStorage storage = result.storage; + + // Try to get the count for a user that doesn't exist (Should pass because + // this is DB level txn that doesn't throw TotpNotEnabledException): + int devicesCount = storage.startTransaction(con -> { + TOTPDevice[] devices = storage.getDevices_Transaction(con, "non-existent-user"); + storage.commitTransaction(con); + return devices.length; + }); + assert devicesCount == 0; + + TOTPDevice device1 = new TOTPDevice("user", "device1", "sk1", 30, 1, false); + TOTPDevice device2 = new TOTPDevice("user", "device2", "sk2", 30, 1, false); + + storage.createDevice(device1); + storage.createDevice(device2); + + devicesCount = storage.startTransaction(con -> { + TOTPDevice[] devices = storage.getDevices_Transaction(con, "user"); + storage.commitTransaction(con); + return devices.length; + }); + assert devicesCount == 2; + } + + @Test + public void removeUser_TransactionTests() throws Exception { + TestSetupResult result = initSteps(); + if (result == null) { + return; + } + TOTPSQLStorage storage = result.storage; + + // Try to remove a user that doesn't exist (Should pass because + // this is DB level txn that doesn't throw TotpNotEnabledException): + storage.startTransaction(con -> { + storage.removeUser_Transaction(con, "non-existent-user"); + storage.commitTransaction(con); + return null; + }); + + TOTPDevice device1 = new TOTPDevice("user", "device1", "sk1", 30, 1, false); + TOTPDevice device2 = new TOTPDevice("user", "device2", "sk2", 30, 1, false); + + storage.createDevice(device1); + storage.createDevice(device2); + + long now = System.currentTimeMillis(); + long expiryAfter10mins = now + 10 * 60 * 1000; + + TOTPUsedCode usedCode1 = new TOTPUsedCode("user", "code1", true, expiryAfter10mins, now); + TOTPUsedCode usedCode2 = new TOTPUsedCode("user", "code2", false, expiryAfter10mins, now + 1); + + insertUsedCodesUtil(storage, new TOTPUsedCode[]{usedCode1, usedCode2}); + + TOTPDevice[] storedDevices = storage.getDevices("user"); + assert (storedDevices.length == 2); + + TOTPUsedCode[] storedUsedCodes = getAllUsedCodesUtil(storage, "user"); + assert (storedUsedCodes.length == 2); + + storage.startTransaction(con -> { + storage.removeUser_Transaction(con, "user"); + storage.commitTransaction(con); + return null; + }); + + storedDevices = storage.getDevices("user"); + assert (storedDevices.length == 0); + + storedUsedCodes = getAllUsedCodesUtil(storage, "user"); + assert (storedUsedCodes.length == 0); + } + + @Test + public void deleteDevice_TransactionTests() throws Exception { + TestSetupResult result = initSteps(); + if (result == null) { + return; + } + TOTPSQLStorage storage = result.storage; + + TOTPDevice device1 = new TOTPDevice("user", "device1", "sk1", 30, 1, false); + TOTPDevice device2 = new TOTPDevice("user", "device2", "sk2", 30, 1, false); + + storage.createDevice(device1); + storage.createDevice(device2); + + TOTPDevice[] storedDevices = storage.getDevices("user"); + assert (storedDevices.length == 2); + + // Try to delete a device for a user that doesn't exist (Should pass because + // this is DB level txn that doesn't throw TotpNotEnabledException): + storage.startTransaction(con -> { + storage.deleteDevice_Transaction(con, "non-existent-user", "device1"); + storage.commitTransaction(con); + return null; + }); + + // Try to delete a device that doesn't exist: + try { + storage.startTransaction(con -> { + storage.deleteDevice_Transaction(con, "user", "non-existent-device"); + storage.commitTransaction(con); + return null; + }); + } catch (Exception e) { + assert (e instanceof UnknownDeviceException) ? true : false; + } + + // Successfully delete device1: + storage.startTransaction(con -> { + storage.deleteDevice_Transaction(con, "user", "device1"); + storage.commitTransaction(con); + return null; + }); + + storedDevices = storage.getDevices("user"); + assert (storedDevices.length == 1); // device2 should still be there + } + + @Test + public void updateDeviceNameTests() throws Exception { + TestSetupResult result = initSteps(); + if (result == null) { + return; + } + TOTPSQLStorage storage = result.storage; + + TOTPDevice device = new TOTPDevice("user", "device", "secretKey", 30, 1, false); + storage.createDevice(device); + + TOTPDevice[] storedDevices = storage.getDevices("user"); + assert (storedDevices.length == 1); + assert (storedDevices[0].deviceName.equals("device")); + + // Try to update a device that doesn't exist: + assertThrows(UnknownDeviceException.class, + () -> storage.updateDeviceName("user", "non-existent-device", "new-device-name")); + + // Update the device name: + storage.updateDeviceName("user", "device", "updated-device-name"); + + storedDevices = storage.getDevices("user"); + assert (storedDevices.length == 1); + assert (storedDevices[0].deviceName.equals("updated-device-name")); + + // Try to create a new device and rename it to the same name as an existing + // device: + TOTPDevice newDevice = new TOTPDevice("user", "new-device", "secretKey", 30, 1, false); + storage.createDevice(newDevice); + + assertThrows(DeviceAlreadyExistsException.class, + () -> storage.updateDeviceName("user", "new-device", "updated-device-name")); + + // Try to rename the device the same name (Should work at database level): + storage.updateDeviceName("user", "updated-device-name", "updated-device-name"); + } + + @Test + public void getDevicesTest() throws Exception { + TestSetupResult result = initSteps(); + if (result == null) { + return; + } + TOTPSQLStorage storage = result.storage; + + TOTPDevice device1 = new TOTPDevice("user", "d1", "secretKey", 30, 1, false); + TOTPDevice device2 = new TOTPDevice("user", "d2", "secretKey", 30, 1, false); + + storage.createDevice(device1); + storage.createDevice(device2); + + TOTPDevice[] storedDevices = storage.getDevices("user"); + + assert (storedDevices.length == 2); + assert (storedDevices[0].deviceName.equals("d1")); + assert (storedDevices[1].deviceName.equals("d2")); + + storedDevices = storage.getDevices("non-existent-user"); + assert (storedDevices.length == 0); + } + + @Test + public void insertUsedCodeTest() throws Exception { + TestSetupResult result = initSteps(); + if (result == null) { + return; + } + TOTPSQLStorage storage = result.storage; + long nextDay = System.currentTimeMillis() + 1000 * 60 * 60 * 24; // 1 day from now + long now = System.currentTimeMillis(); + + // Insert a long lasting valid code and check that it's returned when queried: + { + TOTPDevice device = new TOTPDevice("user", "device", "secretKey", 30, 1, false); + TOTPUsedCode code = new TOTPUsedCode("user", "1234", true, nextDay, now); + + storage.createDevice(device); + insertUsedCodesUtil(storage, new TOTPUsedCode[]{code}); + TOTPUsedCode[] usedCodes = getAllUsedCodesUtil(storage, "user"); + + assert (usedCodes.length == 1); + assert usedCodes[0].equals(code); + } + + // Try to insert a code with same user and created time. It should fail: + { + TOTPUsedCode codeWithRepeatedCreatedTime = new TOTPUsedCode("user", "any-code", true, nextDay, now); + assertThrows(UsedCodeAlreadyExistsException.class, + () -> insertUsedCodesUtil(storage, new TOTPUsedCode[]{codeWithRepeatedCreatedTime})); + } + + // Try to insert code when user doesn't have any device (i.e. TOTP not enabled) + { + assertThrows(TotpNotEnabledException.class, + () -> insertUsedCodesUtil(storage, new TOTPUsedCode[]{ + new TOTPUsedCode("new-user-without-totp", "1234", true, nextDay, + System.currentTimeMillis()) + })); + } + + // Try to insert code after user has atleast one device (i.e. TOTP enabled) + { + TOTPDevice newDevice = new TOTPDevice("user", "new-device", "secretKey", 30, 1, false); + storage.createDevice(newDevice); + insertUsedCodesUtil( + storage, + new TOTPUsedCode[]{ + new TOTPUsedCode("user", "1234", true, nextDay, System.currentTimeMillis()) + }); + } + + // Try to insert code when user doesn't exist: + assertThrows(TotpNotEnabledException.class, + () -> insertUsedCodesUtil(storage, new TOTPUsedCode[]{ + new TOTPUsedCode("non-existent-user", "1234", true, nextDay, + System.currentTimeMillis()) + })); + } + + @Test + public void getAllUsedCodesTest() throws Exception { + TestSetupResult result = initSteps(); + if (result == null) { + return; + } + TOTPSQLStorage storage = result.storage; + + TOTPUsedCode[] usedCodes = getAllUsedCodesUtil(storage, "non-existent-user"); + assert (usedCodes.length == 0); + + long now = System.currentTimeMillis(); + long nextDay = now + 1000 * 60 * 60 * 24; // 1 day from now + long prevDay = now - 1000 * 60 * 60 * 24; // 1 day ago + + TOTPDevice device = new TOTPDevice("user", "device", "secretKey", 30, 1, false); + TOTPUsedCode validCode1 = new TOTPUsedCode("user", "valid1", true, nextDay, now + 1); + TOTPUsedCode invalidCode = new TOTPUsedCode("user", "invalid", false, nextDay, now + 2); + TOTPUsedCode expiredCode = new TOTPUsedCode("user", "expired", true, prevDay, now + 3); + TOTPUsedCode expiredInvalidCode = new TOTPUsedCode("user", "ex-in", false, prevDay, now + 4); + TOTPUsedCode validCode2 = new TOTPUsedCode("user", "valid2", true, nextDay, now + 5); + TOTPUsedCode validCode3 = new TOTPUsedCode("user", "valid3", true, nextDay, now + 6); + + storage.createDevice(device); + insertUsedCodesUtil(storage, new TOTPUsedCode[]{ + validCode1, invalidCode, + expiredCode, expiredInvalidCode, + validCode2, validCode3 + }); + + // Try to create a code with same user and created time. It should fail: + assertThrows(UsedCodeAlreadyExistsException.class, + () -> insertUsedCodesUtil(storage, new TOTPUsedCode[]{ + new TOTPUsedCode("user", "any-code", true, nextDay, now + 1) + })); + + usedCodes = getAllUsedCodesUtil(storage, "user"); + assert (usedCodes.length == 6); + + DeleteExpiredTotpTokens.getInstance(result.process.getProcess()).run(); + + usedCodes = getAllUsedCodesUtil(storage, "user"); + assert (usedCodes.length == 4); // expired codes shouldn't be returned + assert (usedCodes[0].equals(validCode3)); // order is DESC by created time (now + X) + assert (usedCodes[1].equals(validCode2)); + assert (usedCodes[2].equals(invalidCode)); + assert (usedCodes[3].equals(validCode1)); + } + + @Test + public void removeExpiredCodesTest() throws Exception { + TestSetupResult result = initSteps(); + if (result == null) { + return; + } + TOTPSQLStorage storage = result.storage; + + long now = System.currentTimeMillis(); + long nextDay = System.currentTimeMillis() + 1000 * 60 * 60 * 24; // 1 day from now + long halfSecond = System.currentTimeMillis() + 500; // 500ms from now + + TOTPDevice device = new TOTPDevice("user", "device", "secretKey", 30, 1, false); + TOTPUsedCode validCodeToLive = new TOTPUsedCode("user", "valid", true, nextDay, now); + TOTPUsedCode invalidCodeToLive = new TOTPUsedCode("user", "invalid", false, nextDay, now + 1); + TOTPUsedCode validCodeToExpire = new TOTPUsedCode("user", "valid", true, halfSecond, now + 2); + TOTPUsedCode invalidCodeToExpire = new TOTPUsedCode("user", "invalid", false, halfSecond, now + 3); + + storage.createDevice(device); + insertUsedCodesUtil(storage, new TOTPUsedCode[]{ + validCodeToLive, invalidCodeToLive, + validCodeToExpire, invalidCodeToExpire + }); + + TOTPUsedCode[] usedCodes = getAllUsedCodesUtil(storage, "user"); + assert (usedCodes.length == 4); + + // After 500ms seconds pass: + Thread.sleep(500); + + storage.removeExpiredCodes(System.currentTimeMillis()); + + usedCodes = getAllUsedCodesUtil(storage, "user"); + assert (usedCodes.length == 2); + assert (usedCodes[0].equals(invalidCodeToLive)); + assert (usedCodes[1].equals(validCodeToLive)); + } +} diff --git a/src/test/java/io/supertokens/test/totp/TotpLicenseTest.java b/src/test/java/io/supertokens/test/totp/TotpLicenseTest.java new file mode 100644 index 000000000..c0f0745c5 --- /dev/null +++ b/src/test/java/io/supertokens/test/totp/TotpLicenseTest.java @@ -0,0 +1,175 @@ +/* + * Copyright (c) 2023, VRAI Labs and/or its affiliates. All rights reserved. + * + * This software is licensed under the Apache License, Version 2.0 (the + * "License") as published by the Apache Software Foundation. + * + * You may not use this file except in compliance with the License. You may + * obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.supertokens.test.totp; + +import com.google.gson.JsonObject; +import io.supertokens.Main; +import io.supertokens.ProcessState; +import io.supertokens.featureflag.EE_FEATURES; +import io.supertokens.featureflag.FeatureFlagTestContent; +import io.supertokens.featureflag.exceptions.FeatureNotEnabledException; +import io.supertokens.pluginInterface.STORAGE_TYPE; +import io.supertokens.pluginInterface.totp.TOTPDevice; +import io.supertokens.pluginInterface.totp.TOTPStorage; +import io.supertokens.storageLayer.StorageLayer; +import io.supertokens.test.TestingProcessManager; +import io.supertokens.test.Utils; +import io.supertokens.test.httpRequest.HttpRequestForTesting; +import io.supertokens.test.httpRequest.HttpResponseException; +import io.supertokens.totp.Totp; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TestRule; + +import static io.supertokens.test.totp.TOTPRecipeTest.generateTotpCode; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertThrows; + +public class TotpLicenseTest { + public final static String OPAQUE_KEY_WITH_TOTP_FEATURE = "pXhNK=nYiEsb6gJEOYP2kIR6M0kn4XLvNqcwT1XbX8xHtm44K" + + "-lQfGCbaeN0Ieeza39fxkXr=tiiUU=DXxDH40Y=4FLT4CE-rG1ETjkXxO4yucLpJvw3uSegPayoISGL"; + + @Rule + public TestRule watchman = Utils.getOnFailure(); + + @AfterClass + public static void afterTesting() { + Utils.afterTesting(); + } + + @Before + public void beforeEach() { + Utils.reset(); + } + + public class TestSetupResult { + public TOTPStorage storage; + public TestingProcessManager.TestingProcess process; + + public TestSetupResult(TOTPStorage storage, TestingProcessManager.TestingProcess process) { + this.storage = storage; + this.process = process; + } + } + + public TestSetupResult defaultInit() throws InterruptedException { + String[] args = {"../"}; + + TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + + if (StorageLayer.getStorage(process.getProcess()).getType() != STORAGE_TYPE.SQL) { + return null; + } + TOTPStorage storage = StorageLayer.getTOTPStorage(process.getProcess()); + + return new TestSetupResult(storage, process); + } + + @Test + public void testTotpWithoutLicense() throws Exception { + TestSetupResult result = defaultInit(); + if (result == null) { + return; + } + Main main = result.process.getProcess(); + + // Create device + assertThrows(FeatureNotEnabledException.class, () -> { + Totp.registerDevice(main, "user", "device1", 1, 30); + }); + // Verify code + assertThrows(FeatureNotEnabledException.class, () -> { + Totp.verifyCode(main, "user", "device1", true); + }); + + // Try to create device via API: + JsonObject body = new JsonObject(); + body.addProperty("userId", "user-id"); + body.addProperty("deviceName", "d1"); + body.addProperty("skew", 0); + body.addProperty("period", 30); + + + HttpResponseException e = assertThrows( + HttpResponseException.class, + () -> { + HttpRequestForTesting.sendJsonPOSTRequest( + result.process.getProcess(), + "", + "http://localhost:3567/recipe/totp/device", + body, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + "totp"); + } + ); + assert e.statusCode == 402; + assert e.getMessage().contains("TOTP feature is not enabled"); + + + // Try to verify code via API: + JsonObject body2 = new JsonObject(); + body2.addProperty("userId", "user-id"); + body2.addProperty("totp", "123456"); + body2.addProperty("allowUnverifiedDevices", true); + + + HttpResponseException e2 = assertThrows( + HttpResponseException.class, + () -> { + HttpRequestForTesting.sendJsonPOSTRequest( + result.process.getProcess(), + "", + "http://localhost:3567/recipe/totp/verify", + body2, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + "totp"); + } + ); + assert e2.statusCode == 402; + assert e2.getMessage().contains("TOTP feature is not enabled"); + } + + + @Test + public void testTotpWithLicense() throws Exception { + TestSetupResult result = defaultInit(); + if (result == null) { + return; + } + FeatureFlagTestContent.getInstance(result.process.main) + .setKeyValue(FeatureFlagTestContent.ENABLED_FEATURES, new EE_FEATURES[]{EE_FEATURES.TOTP}); + + Main main = result.process.getProcess(); + + // Create device + TOTPDevice device = Totp.registerDevice(main, "user", "device1", 1, 30); + // Verify code + String code = generateTotpCode(main, device); + Totp.verifyCode(main, "user", code, true); + } + + +} diff --git a/src/test/java/io/supertokens/test/totp/api/CreateTotpDeviceAPITest.java b/src/test/java/io/supertokens/test/totp/api/CreateTotpDeviceAPITest.java new file mode 100644 index 000000000..92e844840 --- /dev/null +++ b/src/test/java/io/supertokens/test/totp/api/CreateTotpDeviceAPITest.java @@ -0,0 +1,153 @@ +package io.supertokens.test.totp.api; + +import com.google.gson.JsonObject; +import io.supertokens.ProcessState; +import io.supertokens.featureflag.EE_FEATURES; +import io.supertokens.featureflag.FeatureFlag; +import io.supertokens.featureflag.FeatureFlagTestContent; +import io.supertokens.test.httpRequest.HttpResponseException; +import io.supertokens.pluginInterface.STORAGE_TYPE; +import io.supertokens.storageLayer.StorageLayer; +import io.supertokens.test.TestingProcessManager; +import io.supertokens.test.Utils; +import io.supertokens.test.httpRequest.HttpRequestForTesting; +import io.supertokens.test.totp.TotpLicenseTest; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TestRule; + +import static org.junit.Assert.*; + +public class CreateTotpDeviceAPITest { + + @Rule + public TestRule watchman = Utils.getOnFailure(); + + @AfterClass + public static void afterTesting() { + Utils.afterTesting(); + } + + @Before + public void beforeEach() { + Utils.reset(); + } + + private Exception createDeviceRequest(TestingProcessManager.TestingProcess process, JsonObject body) { + return assertThrows( + io.supertokens.test.httpRequest.HttpResponseException.class, + () -> HttpRequestForTesting.sendJsonPOSTRequest( + process.getProcess(), + "", + "http://localhost:3567/recipe/totp/device", + body, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + "totp")); + } + + private void checkFieldMissingErrorResponse(Exception ex, String fieldName) { + assert ex instanceof HttpResponseException; + HttpResponseException e = (HttpResponseException) ex; + assert e.statusCode == 400; + assertTrue(e.getMessage().contains( + "Http error. Status Code: 400. Message: Field name '" + fieldName + "' is invalid in JSON input")); + } + + private void checkResponseErrorContains(Exception ex, String msg) { + assert ex instanceof HttpResponseException; + HttpResponseException e = (HttpResponseException) ex; + assert e.statusCode == 400; + assertTrue(e.getMessage().contains(msg)); + } + + + @Test + public void testApi() throws Exception { + String[] args = { "../" }; + + TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + + FeatureFlag.getInstance(process.main).setLicenseKeyAndSyncFeatures(TotpLicenseTest.OPAQUE_KEY_WITH_TOTP_FEATURE); + FeatureFlagTestContent.getInstance(process.main).setKeyValue(FeatureFlagTestContent.ENABLED_FEATURES, new EE_FEATURES[] { EE_FEATURES.TOTP }); + + if (StorageLayer.getStorage(process.getProcess()).getType() != STORAGE_TYPE.SQL) { + return; + } + + JsonObject body = new JsonObject(); + + // Missing userId/deviceName/skew/period + { + Exception e = createDeviceRequest(process, body); + checkFieldMissingErrorResponse(e, "userId"); + + body.addProperty("userId", ""); + e = createDeviceRequest(process, body); + checkFieldMissingErrorResponse(e, "deviceName"); + + body.addProperty("deviceName", ""); + e = createDeviceRequest(process, body); + checkFieldMissingErrorResponse(e, "skew"); + + body.addProperty("skew", -1); + e = createDeviceRequest(process, body); + checkFieldMissingErrorResponse(e, "period"); + } + + // Invalid userId/deviceName/skew/period + { + body.addProperty("period", 0); + Exception e = createDeviceRequest(process, body); + checkResponseErrorContains(e, "userId cannot be empty"); // Note that this is not a field missing error + + body.addProperty("userId", "user-id"); + e = createDeviceRequest(process, body); + checkResponseErrorContains(e, "deviceName cannot be empty"); + + body.addProperty("deviceName", "d1"); + e = createDeviceRequest(process, body); + checkResponseErrorContains(e, "skew must be >= 0"); + + body.addProperty("skew", 0); + e = createDeviceRequest(process, body); + checkResponseErrorContains(e, "period must be > 0"); + + body.addProperty("period", 30); + + // should pass now: + JsonObject res = HttpRequestForTesting.sendJsonPOSTRequest( + process.getProcess(), + "", + "http://localhost:3567/recipe/totp/device", + body, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + "totp"); + assert res.get("status").getAsString().equals("OK"); + + // try again with same device: + JsonObject res2 = HttpRequestForTesting.sendJsonPOSTRequest( + process.getProcess(), + "", + "http://localhost:3567/recipe/totp/device", + body, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + "totp"); + assert res2.get("status").getAsString().equals("DEVICE_ALREADY_EXISTS_ERROR"); + } + + process.kill(); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); + } +} diff --git a/src/test/java/io/supertokens/test/totp/api/GetTotpDevicesAPITest.java b/src/test/java/io/supertokens/test/totp/api/GetTotpDevicesAPITest.java new file mode 100644 index 000000000..c21f8a88d --- /dev/null +++ b/src/test/java/io/supertokens/test/totp/api/GetTotpDevicesAPITest.java @@ -0,0 +1,163 @@ +package io.supertokens.test.totp.api; + +import com.google.gson.JsonArray; +import com.google.gson.JsonObject; +import io.supertokens.ProcessState; +import io.supertokens.featureflag.EE_FEATURES; +import io.supertokens.featureflag.FeatureFlag; +import io.supertokens.featureflag.FeatureFlagTestContent; +import io.supertokens.test.httpRequest.HttpResponseException; +import io.supertokens.pluginInterface.STORAGE_TYPE; +import io.supertokens.storageLayer.StorageLayer; +import io.supertokens.test.TestingProcessManager; +import io.supertokens.test.Utils; +import io.supertokens.test.httpRequest.HttpRequestForTesting; +import io.supertokens.test.totp.TotpLicenseTest; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TestRule; + +import static org.junit.Assert.*; + +import java.util.HashMap; + +public class GetTotpDevicesAPITest { + + @Rule + public TestRule watchman = Utils.getOnFailure(); + + @AfterClass + public static void afterTesting() { + Utils.afterTesting(); + } + + @Before + public void beforeEach() { + Utils.reset(); + } + + private Exception getDevicesRequestException(TestingProcessManager.TestingProcess process, + HashMap params) { + + return assertThrows( + io.supertokens.test.httpRequest.HttpResponseException.class, + () -> HttpRequestForTesting.sendGETRequest( + process.getProcess(), + "", + "http://localhost:3567/recipe/totp/device/list", + params, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + "totp")); + } + + private void checkFieldMissingErrorResponse(Exception ex, String fieldName) { + assert ex instanceof HttpResponseException; + HttpResponseException e = (HttpResponseException) ex; + assert e.statusCode == 400; + assertTrue(e.getMessage().contains( + "Http error. Status Code: 400. Message: Field name '" + fieldName + "' is missing in GET request")); + } + + private void checkResponseErrorContains(Exception ex, String msg) { + assert ex instanceof HttpResponseException; + HttpResponseException e = (HttpResponseException) ex; + assert e.statusCode == 400; + assertTrue(e.getMessage().contains(msg)); + } + + @Test + public void testApi() throws Exception { + String[] args = { "../" }; + + TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + + FeatureFlagTestContent.getInstance(process.main).setKeyValue(FeatureFlagTestContent.ENABLED_FEATURES, new EE_FEATURES[] { EE_FEATURES.TOTP }); + + if (StorageLayer.getStorage(process.getProcess()).getType() != STORAGE_TYPE.SQL) { + return; + } + + // Setup to create a device (which also creates a user) + { + JsonObject body = new JsonObject(); + body.addProperty("userId", "user-id"); + body.addProperty("deviceName", "device-name"); + body.addProperty("skew", 0); + body.addProperty("period", 30); + JsonObject res = HttpRequestForTesting.sendJsonPOSTRequest( + process.getProcess(), + "", + "http://localhost:3567/recipe/totp/device", + body, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + "totp"); + assert res.get("status").getAsString().equals("OK"); + } + + HashMap params = new HashMap<>(); + + // Missing userId + { + Exception e = getDevicesRequestException(process, params); + checkFieldMissingErrorResponse(e, "userId"); + } + + // Invalid userId + { + params.put("userId", ""); + Exception e = getDevicesRequestException(process, params); + checkResponseErrorContains(e, "userId cannot be empty"); // Note that this is not a field missing error + + params.put("userId", "user-id"); + + // should pass now: + JsonObject res = HttpRequestForTesting.sendGETRequest( + process.getProcess(), + "", + "http://localhost:3567/recipe/totp/device/list", + params, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + "totp"); + assert res.get("status").getAsString().equals("OK"); + + JsonArray devicesArr = res.get("devices").getAsJsonArray(); + JsonObject deviceJson = devicesArr.get(0).getAsJsonObject(); + + assert devicesArr.size() == 1; + assert deviceJson.get("name").getAsString().equals("device-name"); + assert deviceJson.get("period").getAsInt() == 30; + assert deviceJson.get("skew").getAsInt() == 0; + assert deviceJson.get("verified").getAsBoolean() == false; + + // try for non-existent user: + params.put("userId", "non-existent-user-id"); + JsonObject res2 = HttpRequestForTesting.sendGETRequest( + process.getProcess(), + "", + "http://localhost:3567/recipe/totp/device/list", + params, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + "totp"); + assert res2.get("status").getAsString().equals("TOTP_NOT_ENABLED_ERROR"); + } + + process.kill(); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); + } + +} diff --git a/src/test/java/io/supertokens/test/totp/api/RemoveTotpDeviceAPITest.java b/src/test/java/io/supertokens/test/totp/api/RemoveTotpDeviceAPITest.java new file mode 100644 index 000000000..5ef0d1b4c --- /dev/null +++ b/src/test/java/io/supertokens/test/totp/api/RemoveTotpDeviceAPITest.java @@ -0,0 +1,188 @@ +package io.supertokens.test.totp.api; + +import com.google.gson.JsonObject; +import io.supertokens.ProcessState; +import io.supertokens.featureflag.EE_FEATURES; +import io.supertokens.featureflag.FeatureFlag; +import io.supertokens.featureflag.FeatureFlagTestContent; +import io.supertokens.test.httpRequest.HttpResponseException; +import io.supertokens.pluginInterface.STORAGE_TYPE; +import io.supertokens.storageLayer.StorageLayer; +import io.supertokens.test.TestingProcessManager; +import io.supertokens.test.Utils; +import io.supertokens.test.httpRequest.HttpRequestForTesting; +import io.supertokens.test.totp.TotpLicenseTest; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TestRule; + +import static org.junit.Assert.*; + +public class RemoveTotpDeviceAPITest { + + @Rule + public TestRule watchman = Utils.getOnFailure(); + + @AfterClass + public static void afterTesting() { + Utils.afterTesting(); + } + + @Before + public void beforeEach() { + Utils.reset(); + } + + private Exception removeDeviceRequest(TestingProcessManager.TestingProcess process, JsonObject body) { + return assertThrows( + io.supertokens.test.httpRequest.HttpResponseException.class, + () -> HttpRequestForTesting.sendJsonPOSTRequest( + process.getProcess(), + "", + "http://localhost:3567/recipe/totp/device/remove", + body, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + "totp")); + } + + private void checkFieldMissingErrorResponse(Exception ex, String fieldName) { + assert ex instanceof HttpResponseException; + HttpResponseException e = (HttpResponseException) ex; + assert e.statusCode == 400; + assertTrue(e.getMessage().contains( + "Http error. Status Code: 400. Message: Field name '" + fieldName + "' is invalid in JSON input")); + } + + private void checkResponseErrorContains(Exception ex, String msg) { + assert ex instanceof HttpResponseException; + HttpResponseException e = (HttpResponseException) ex; + assert e.statusCode == 400; + assertTrue(e.getMessage().contains(msg)); + } + + @Test + public void testApi() throws Exception { + String[] args = { "../" }; + + TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + + if (StorageLayer.getStorage(process.getProcess()).getType() != STORAGE_TYPE.SQL) { + return; + } + + FeatureFlagTestContent.getInstance(process.main).setKeyValue(FeatureFlagTestContent.ENABLED_FEATURES, new EE_FEATURES[] { EE_FEATURES.TOTP }); + + // Setup user and devices: + JsonObject createDeviceReq = new JsonObject(); + createDeviceReq.addProperty("userId", "user-id"); + createDeviceReq.addProperty("deviceName", "d1"); + createDeviceReq.addProperty("period", 30); + createDeviceReq.addProperty("skew", 0); + + JsonObject createDeviceRes = HttpRequestForTesting.sendJsonPOSTRequest( + process.getProcess(), + "", + "http://localhost:3567/recipe/totp/device", + createDeviceReq, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + "totp"); + assertEquals(createDeviceRes.get("status").getAsString(), "OK"); + + // create another device d2: + createDeviceReq.addProperty("deviceName", "d2"); + JsonObject createDeviceRes2 = HttpRequestForTesting.sendJsonPOSTRequest( + process.getProcess(), + "", + "http://localhost:3567/recipe/totp/device", + createDeviceReq, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + "totp"); + assertEquals(createDeviceRes2.get("status").getAsString(), "OK"); + + // Start the actual tests for remove device API: + + JsonObject body = new JsonObject(); + + // Missing userId/deviceName + { + Exception e = removeDeviceRequest(process, body); + checkFieldMissingErrorResponse(e, "userId"); + + body.addProperty("userId", ""); + e = removeDeviceRequest(process, body); + checkFieldMissingErrorResponse(e, "deviceName"); + + } + + // Invalid userId/deviceName + { + body.addProperty("deviceName", ""); + Exception e = removeDeviceRequest(process, body); + checkResponseErrorContains(e, "userId cannot be empty"); // Note that this is not a field missing error + + body.addProperty("userId", "user-id"); + e = removeDeviceRequest(process, body); + checkResponseErrorContains(e, "deviceName cannot be empty"); + + body.addProperty("deviceName", "d1"); + + // should pass now: + JsonObject res = HttpRequestForTesting.sendJsonPOSTRequest( + process.getProcess(), + "", + "http://localhost:3567/recipe/totp/device/remove", + body, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + "totp"); + assert res.get("status").getAsString().equals("OK"); + assert res.get("didDeviceExist").getAsBoolean() == true; + + // try again with same device (still pass but didDeviceExist should be false) + JsonObject res2 = HttpRequestForTesting.sendJsonPOSTRequest( + process.getProcess(), + "", + "http://localhost:3567/recipe/totp/device/remove", + body, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + "totp"); + assert res2.get("status").getAsString().equals("OK"); + assert res2.get("didDeviceExist").getAsBoolean() == false; + + // try deleting device for a non-existent user + body.addProperty("userId", "non-existent-user"); + JsonObject res3 = HttpRequestForTesting.sendJsonPOSTRequest( + process.getProcess(), + "", + "http://localhost:3567/recipe/totp/device/remove", + body, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + "totp"); + assert res3.get("status").getAsString().equals("TOTP_NOT_ENABLED_ERROR"); + } + + process.kill(); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); + } + +} diff --git a/src/test/java/io/supertokens/test/totp/api/TotpUserIdMappingTest.java b/src/test/java/io/supertokens/test/totp/api/TotpUserIdMappingTest.java new file mode 100644 index 000000000..d8dd0d974 --- /dev/null +++ b/src/test/java/io/supertokens/test/totp/api/TotpUserIdMappingTest.java @@ -0,0 +1,181 @@ +package io.supertokens.test.totp.api; + +import com.google.gson.JsonObject; +import io.supertokens.ProcessState; +import io.supertokens.emailpassword.EmailPassword; +import io.supertokens.featureflag.EE_FEATURES; +import io.supertokens.featureflag.FeatureFlag; +import io.supertokens.featureflag.FeatureFlagTestContent; +import io.supertokens.pluginInterface.emailpassword.UserInfo; +import io.supertokens.pluginInterface.totp.TOTPDevice; +import io.supertokens.pluginInterface.STORAGE_TYPE; +import io.supertokens.storageLayer.StorageLayer; +import io.supertokens.test.TestingProcessManager; +import io.supertokens.test.Utils; +import io.supertokens.test.httpRequest.HttpRequestForTesting; +import io.supertokens.test.totp.TOTPRecipeTest; +import io.supertokens.test.totp.TotpLicenseTest; +import io.supertokens.useridmapping.UserIdMapping; + +import static org.junit.Assert.assertNotNull; + +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TestRule; + +public class TotpUserIdMappingTest { + @Rule + public TestRule watchman = Utils.getOnFailure(); + + @AfterClass + public static void afterTesting() { + Utils.afterTesting(); + } + + @Before + public void beforeEach() { + Utils.reset(); + } + + @Test + public void testExternalUserIdTranslation() throws Exception { + String[] args = { "../" }; + + TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + + if (StorageLayer.getStorage(process.getProcess()).getType() != STORAGE_TYPE.SQL) { + return; + } + + FeatureFlagTestContent.getInstance(process.main).setKeyValue(FeatureFlagTestContent.ENABLED_FEATURES, new EE_FEATURES[] { EE_FEATURES.TOTP }); + + JsonObject body = new JsonObject(); + + UserInfo user = EmailPassword.signUp(process.main, "test@example.com", "testPass123"); + String superTokensUserId = user.id; + String externalUserId = "external-user-id"; + + // Create user id mapping first: + UserIdMapping.createUserIdMapping(process.main, superTokensUserId, externalUserId, null, false); + + body.addProperty("userId", externalUserId); + body.addProperty("deviceName", "d1"); + body.addProperty("skew", 0); + body.addProperty("period", 30); + + // Register 1st device + JsonObject res1 = HttpRequestForTesting.sendJsonPOSTRequest( + process.getProcess(), + "", + "http://localhost:3567/recipe/totp/device", + body, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + "totp"); + assert res1.get("status").getAsString().equals("OK"); + String d1Secret = res1.get("secret").getAsString(); + TOTPDevice device1 = new TOTPDevice(externalUserId, "deviceName", d1Secret, 30, 0, false); + + body.addProperty("deviceName", "d2"); + + JsonObject res2 = HttpRequestForTesting.sendJsonPOSTRequest( + process.getProcess(), + "", + "http://localhost:3567/recipe/totp/device", + body, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + "totp"); + assert res2.get("status").getAsString().equals("OK"); + String d2Secret = res2.get("secret").getAsString(); + TOTPDevice device2 = new TOTPDevice(externalUserId, "deviceName", d2Secret, 30, 0, false); + + // Verify d1 but not d2: + JsonObject verifyD1Input = new JsonObject(); + verifyD1Input.addProperty("userId", externalUserId); + String d1Totp = TOTPRecipeTest.generateTotpCode(process.getProcess(), device1); + verifyD1Input.addProperty("deviceName", "d1"); + verifyD1Input.addProperty("totp", d1Totp ); + + JsonObject verifyD1Res = HttpRequestForTesting.sendJsonPOSTRequest( + process.getProcess(), + "", + "http://localhost:3567/recipe/totp/device/verify", + verifyD1Input, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + "totp"); + + assert verifyD1Res.get("status").getAsString().equals("OK"); + assert verifyD1Res.get("wasAlreadyVerified").getAsBoolean() == false; + + // use d2 to login in totp: + JsonObject loginInput = new JsonObject(); + loginInput.addProperty("userId", externalUserId); + String d2Totp = TOTPRecipeTest.generateTotpCode(process.getProcess(), device2); + loginInput.addProperty("totp", d2Totp); // use code from d2 which is unverified + loginInput.addProperty("allowUnverifiedDevices", true); + + JsonObject loginRes = HttpRequestForTesting.sendJsonPOSTRequest( + process.getProcess(), + "", + "http://localhost:3567/recipe/totp/verify", + loginInput, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + "totp"); + + assert loginRes.get("status").getAsString().equals("OK"); + + // Change the name of d1 to d3: + JsonObject updateDeviceNameInput = new JsonObject(); + updateDeviceNameInput.addProperty("userId", externalUserId); + updateDeviceNameInput.addProperty("existingDeviceName", "d1"); + updateDeviceNameInput.addProperty("newDeviceName", "d3"); + + JsonObject updateDeviceNameRes = HttpRequestForTesting.sendJsonPUTRequest( + process.getProcess(), + "", + "http://localhost:3567/recipe/totp/device", + updateDeviceNameInput, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + "totp"); + + assert updateDeviceNameRes.get("status").getAsString().equals("OK"); + + // Delete d3: + JsonObject deleteDeviceInput = new JsonObject(); + deleteDeviceInput.addProperty("userId", externalUserId); + deleteDeviceInput.addProperty("deviceName", "d3"); + + JsonObject deleteDeviceRes = HttpRequestForTesting.sendJsonPOSTRequest( + process.getProcess(), + "", + "http://localhost:3567/recipe/totp/device/remove", + deleteDeviceInput, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + "totp"); + + + assert deleteDeviceRes.get("status").getAsString().equals("OK"); + assert deleteDeviceRes.get("didDeviceExist").getAsBoolean() == true; + + } +} diff --git a/src/test/java/io/supertokens/test/totp/api/UpdateTotpDeviceAPITest.java b/src/test/java/io/supertokens/test/totp/api/UpdateTotpDeviceAPITest.java new file mode 100644 index 000000000..f021d2804 --- /dev/null +++ b/src/test/java/io/supertokens/test/totp/api/UpdateTotpDeviceAPITest.java @@ -0,0 +1,209 @@ +package io.supertokens.test.totp.api; + +import com.google.gson.JsonObject; +import io.supertokens.ProcessState; +import io.supertokens.featureflag.EE_FEATURES; +import io.supertokens.featureflag.FeatureFlag; +import io.supertokens.featureflag.FeatureFlagTestContent; +import io.supertokens.test.httpRequest.HttpResponseException; +import io.supertokens.pluginInterface.STORAGE_TYPE; +import io.supertokens.storageLayer.StorageLayer; +import io.supertokens.test.TestingProcessManager; +import io.supertokens.test.Utils; +import io.supertokens.test.httpRequest.HttpRequestForTesting; +import io.supertokens.test.totp.TotpLicenseTest; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TestRule; + +import static org.junit.Assert.*; + +public class UpdateTotpDeviceAPITest { + + @Rule + public TestRule watchman = Utils.getOnFailure(); + + @AfterClass + public static void afterTesting() { + Utils.afterTesting(); + } + + @Before + public void beforeEach() { + Utils.reset(); + } + + private Exception updateDeviceRequest(TestingProcessManager.TestingProcess process, JsonObject body) { + return assertThrows( + io.supertokens.test.httpRequest.HttpResponseException.class, + () -> HttpRequestForTesting.sendJsonPUTRequest( + process.getProcess(), + "", + "http://localhost:3567/recipe/totp/device", + body, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + "totp")); + } + + private void checkFieldMissingErrorResponse(Exception ex, String fieldName) { + assert ex instanceof HttpResponseException; + HttpResponseException e = (HttpResponseException) ex; + assert e.statusCode == 400; + assertTrue(e.getMessage().contains( + "Http error. Status Code: 400. Message: Field name '" + fieldName + "' is invalid in JSON input")); + } + + private void checkResponseErrorContains(Exception ex, String msg) { + assert ex instanceof HttpResponseException; + HttpResponseException e = (HttpResponseException) ex; + assert e.statusCode == 400; + assertTrue(e.getMessage().contains(msg)); + } + + @Test + public void testApi() throws Exception { + String[] args = { "../" }; + + TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + + if (StorageLayer.getStorage(process.getProcess()).getType() != STORAGE_TYPE.SQL) { + return; + } + + FeatureFlagTestContent.getInstance(process.main).setKeyValue(FeatureFlagTestContent.ENABLED_FEATURES, new EE_FEATURES[] { EE_FEATURES.TOTP }); + + // Setup user and devices: + JsonObject createDeviceReq = new JsonObject(); + createDeviceReq.addProperty("userId", "user-id"); + createDeviceReq.addProperty("deviceName", "d1"); + createDeviceReq.addProperty("period", 30); + createDeviceReq.addProperty("skew", 0); + + JsonObject createDeviceRes = HttpRequestForTesting.sendJsonPOSTRequest( + process.getProcess(), + "", + "http://localhost:3567/recipe/totp/device", + createDeviceReq, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + "totp"); + assertEquals(createDeviceRes.get("status").getAsString(), "OK"); + + // create another device d2: + createDeviceReq.addProperty("deviceName", "d2"); + JsonObject createDeviceRes2 = HttpRequestForTesting.sendJsonPOSTRequest( + process.getProcess(), + "", + "http://localhost:3567/recipe/totp/device", + createDeviceReq, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + "totp"); + assertEquals(createDeviceRes2.get("status").getAsString(), "OK"); + + // Start the actual tests for update device API: + + JsonObject body = new JsonObject(); + + // Missing userId/deviceName/skew/period + { + Exception e = updateDeviceRequest(process, body); + checkFieldMissingErrorResponse(e, "userId"); + + body.addProperty("userId", ""); + e = updateDeviceRequest(process, body); + checkFieldMissingErrorResponse(e, "existingDeviceName"); + + body.addProperty("existingDeviceName", ""); + e = updateDeviceRequest(process, body); + checkFieldMissingErrorResponse(e, "newDeviceName"); + + } + + // Invalid userId/deviceName/skew/period + { + body.addProperty("newDeviceName", ""); + Exception e = updateDeviceRequest(process, body); + checkResponseErrorContains(e, "userId cannot be empty"); // Note that this is not a field missing error + + body.addProperty("userId", "user-id"); + e = updateDeviceRequest(process, body); + checkResponseErrorContains(e, "existingDeviceName cannot be empty"); + + body.addProperty("existingDeviceName", "d1"); + e = updateDeviceRequest(process, body); + checkResponseErrorContains(e, "newDeviceName cannot be empty"); + + body.addProperty("newDeviceName", "d1-new"); + + // should pass now: + JsonObject res = HttpRequestForTesting.sendJsonPUTRequest( + process.getProcess(), + "", + "http://localhost:3567/recipe/totp/device", + body, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + "totp"); + assert res.get("status").getAsString().equals("OK"); + + // try again with same device (has been renamed so should fail) + JsonObject res2 = HttpRequestForTesting.sendJsonPUTRequest( + process.getProcess(), + "", + "http://localhost:3567/recipe/totp/device", + body, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + "totp"); + assert res2.get("status").getAsString().equals("UNKNOWN_DEVICE_ERROR"); + + // try renaming to a device that already exists + body.addProperty("existingDeviceName", "d1-new"); + body.addProperty("newDeviceName", "d2"); + JsonObject res3 = HttpRequestForTesting.sendJsonPUTRequest( + process.getProcess(), + "", + "http://localhost:3567/recipe/totp/device", + body, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + "totp"); + assert res3.get("status").getAsString().equals("DEVICE_ALREADY_EXISTS_ERROR"); + + // try renaming to a device that already exists for a non-existent user + body.addProperty("userId", "non-existent-user"); + JsonObject res4 = HttpRequestForTesting.sendJsonPUTRequest( + process.getProcess(), + "", + "http://localhost:3567/recipe/totp/device", + body, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + "totp"); + assert res4.get("status").getAsString().equals("TOTP_NOT_ENABLED_ERROR"); + } + + process.kill(); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); + } + +} diff --git a/src/test/java/io/supertokens/test/totp/api/VerifyTotpAPITest.java b/src/test/java/io/supertokens/test/totp/api/VerifyTotpAPITest.java new file mode 100644 index 000000000..5158f8046 --- /dev/null +++ b/src/test/java/io/supertokens/test/totp/api/VerifyTotpAPITest.java @@ -0,0 +1,228 @@ +package io.supertokens.test.totp.api; + +import com.google.gson.JsonObject; + +import io.supertokens.ProcessState; +import io.supertokens.featureflag.EE_FEATURES; +import io.supertokens.featureflag.FeatureFlag; +import io.supertokens.featureflag.FeatureFlagTestContent; +import io.supertokens.test.httpRequest.HttpResponseException; +import io.supertokens.test.totp.TOTPRecipeTest; +import io.supertokens.pluginInterface.STORAGE_TYPE; +import io.supertokens.pluginInterface.totp.TOTPDevice; +import io.supertokens.storageLayer.StorageLayer; +import io.supertokens.test.TestingProcessManager; +import io.supertokens.test.Utils; +import io.supertokens.test.httpRequest.HttpRequestForTesting; +import io.supertokens.test.totp.TotpLicenseTest; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TestRule; + +import static org.junit.Assert.*; + +public class VerifyTotpAPITest { + + @Rule + public TestRule watchman = Utils.getOnFailure(); + + @AfterClass + public static void afterTesting() { + Utils.afterTesting(); + } + + @Before + public void beforeEach() { + Utils.reset(); + } + + private Exception updateDeviceRequest(TestingProcessManager.TestingProcess process, JsonObject body) { + return assertThrows( + io.supertokens.test.httpRequest.HttpResponseException.class, + () -> HttpRequestForTesting.sendJsonPOSTRequest( + process.getProcess(), + "", + "http://localhost:3567/recipe/totp/verify", + body, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + "totp")); + } + + private void checkFieldMissingErrorResponse(Exception ex, String fieldName) { + assert ex instanceof HttpResponseException; + HttpResponseException e = (HttpResponseException) ex; + assert e.statusCode == 400; + assertTrue(e.getMessage().contains( + "Http error. Status Code: 400. Message: Field name '" + fieldName + "' is invalid in JSON input")); + } + + private void checkResponseErrorContains(Exception ex, String msg) { + assert ex instanceof HttpResponseException; + HttpResponseException e = (HttpResponseException) ex; + assert e.statusCode == 400; + assertTrue(e.getMessage().contains(msg)); + } + + @Test + public void testApi() throws Exception { + String[] args = { "../" }; + + // Trigger rate limiting on 1 wrong attempts: + Utils.setValueInConfig("totp_max_attempts", "1"); + // Set cooldown to 1 second: + Utils.setValueInConfig("totp_rate_limit_cooldown_sec", "1"); + + TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + + if (StorageLayer.getStorage(process.getProcess()).getType() != STORAGE_TYPE.SQL) { + return; + } + + FeatureFlagTestContent.getInstance(process.main).setKeyValue(FeatureFlagTestContent.ENABLED_FEATURES, new EE_FEATURES[] { EE_FEATURES.TOTP }); + + // Setup user and devices: + JsonObject createDeviceReq = new JsonObject(); + createDeviceReq.addProperty("userId", "user-id"); + createDeviceReq.addProperty("deviceName", "deviceName"); + createDeviceReq.addProperty("period", 30); + createDeviceReq.addProperty("skew", 0); + + JsonObject createDeviceRes = HttpRequestForTesting.sendJsonPOSTRequest( + process.getProcess(), + "", + "http://localhost:3567/recipe/totp/device", + createDeviceReq, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + "totp"); + assertEquals(createDeviceRes.get("status").getAsString(), "OK"); + String secretKey = createDeviceRes.get("secret").getAsString(); + + TOTPDevice device = new TOTPDevice("user-id", "deviceName", secretKey, 30, 0, false); + + // Start the actual tests for update device API: + + JsonObject body = new JsonObject(); + + // Missing userId/deviceName/skew/period + { + Exception e = updateDeviceRequest(process, body); + checkFieldMissingErrorResponse(e, "userId"); + + body.addProperty("userId", ""); + e = updateDeviceRequest(process, body); + checkFieldMissingErrorResponse(e, "totp"); + + body.addProperty("totp", ""); + e = updateDeviceRequest(process, body); + checkFieldMissingErrorResponse(e, "allowUnverifiedDevices"); + } + + // Invalid userId/deviceName/skew/period + { + body.addProperty("allowUnverifiedDevices", true); + Exception e = updateDeviceRequest(process, body); + checkResponseErrorContains(e, "userId cannot be empty"); // Note that this is not a field missing error + + body.addProperty("userId", device.userId); + e = updateDeviceRequest(process, body); + checkResponseErrorContains(e, "totp must be 6 characters long"); + + // test totp of length 5: + body.addProperty("totp", "12345"); + e = updateDeviceRequest(process, body); + checkResponseErrorContains(e, "totp must be 6 characters long"); + + // test totp of length 8: + body.addProperty("totp", "12345678"); + e = updateDeviceRequest(process, body); + checkResponseErrorContains(e, "totp must be 6 characters long"); + + // but let's pass invalid code first + body.addProperty("totp", "123456"); + JsonObject res0 = HttpRequestForTesting.sendJsonPOSTRequest( + process.getProcess(), + "", + "http://localhost:3567/recipe/totp/verify", + body, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + "totp"); + assert res0.get("status").getAsString().equals("INVALID_TOTP_ERROR"); + + // Check that rate limiting is triggered for the user: + JsonObject res3 = HttpRequestForTesting.sendJsonPOSTRequest( + process.getProcess(), + "", + "http://localhost:3567/recipe/totp/verify", + body, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + "totp"); + assert res3.get("status").getAsString().equals("LIMIT_REACHED_ERROR"); + assert res3.get("retryAfterMs") != null; + + // wait for cooldown to end (1s) + Thread.sleep(1000); + + // should pass now on valid code + String validTotp = TOTPRecipeTest.generateTotpCode(process.getProcess(), device); + body.addProperty("totp", validTotp); + JsonObject res = HttpRequestForTesting.sendJsonPOSTRequest( + process.getProcess(), + "", + "http://localhost:3567/recipe/totp/verify", + body, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + "totp"); + assert res.get("status").getAsString().equals("OK"); + + // try to reuse the same code (replay attack) + body.addProperty("totp", "mycode"); + JsonObject res2 = HttpRequestForTesting.sendJsonPOSTRequest( + process.getProcess(), + "", + "http://localhost:3567/recipe/totp/verify", + body, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + "totp"); + assert res2.get("status").getAsString().equals("INVALID_TOTP_ERROR"); + + // try verifying device for a non-existent user + body.addProperty("userId", "non-existent-user"); + JsonObject res5 = HttpRequestForTesting.sendJsonPOSTRequest( + process.getProcess(), + "", + "http://localhost:3567/recipe/totp/verify", + body, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + "totp"); + assert res5.get("status").getAsString().equals("TOTP_NOT_ENABLED_ERROR"); + } + + process.kill(); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); + } + +} diff --git a/src/test/java/io/supertokens/test/totp/api/VerifyTotpDeviceAPITest.java b/src/test/java/io/supertokens/test/totp/api/VerifyTotpDeviceAPITest.java new file mode 100644 index 000000000..53703b50b --- /dev/null +++ b/src/test/java/io/supertokens/test/totp/api/VerifyTotpDeviceAPITest.java @@ -0,0 +1,248 @@ +package io.supertokens.test.totp.api; + +import com.google.gson.JsonObject; + +import io.supertokens.ProcessState; +import io.supertokens.featureflag.EE_FEATURES; +import io.supertokens.featureflag.FeatureFlag; +import io.supertokens.featureflag.FeatureFlagTestContent; +import io.supertokens.test.httpRequest.HttpResponseException; +import io.supertokens.test.totp.TOTPRecipeTest; +import io.supertokens.pluginInterface.STORAGE_TYPE; +import io.supertokens.pluginInterface.totp.TOTPDevice; +import io.supertokens.storageLayer.StorageLayer; +import io.supertokens.test.TestingProcessManager; +import io.supertokens.test.Utils; +import io.supertokens.test.httpRequest.HttpRequestForTesting; +import io.supertokens.test.totp.TotpLicenseTest; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TestRule; + +import static org.junit.Assert.*; + +public class VerifyTotpDeviceAPITest { + + @Rule + public TestRule watchman = Utils.getOnFailure(); + + @AfterClass + public static void afterTesting() { + Utils.afterTesting(); + } + + @Before + public void beforeEach() { + Utils.reset(); + } + + private Exception updateDeviceRequest(TestingProcessManager.TestingProcess process, JsonObject body) { + return assertThrows( + io.supertokens.test.httpRequest.HttpResponseException.class, + () -> HttpRequestForTesting.sendJsonPOSTRequest( + process.getProcess(), + "", + "http://localhost:3567/recipe/totp/device/verify", + body, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + "totp")); + } + + private void checkFieldMissingErrorResponse(Exception ex, String fieldName) { + assert ex instanceof HttpResponseException; + HttpResponseException e = (HttpResponseException) ex; + assert e.statusCode == 400; + assertTrue(e.getMessage().contains( + "Http error. Status Code: 400. Message: Field name '" + fieldName + "' is invalid in JSON input")); + } + + private void checkResponseErrorContains(Exception ex, String msg) { + assert ex instanceof HttpResponseException; + HttpResponseException e = (HttpResponseException) ex; + assert e.statusCode == 400; + assertTrue(e.getMessage().contains(msg)); + } + + @Test + public void testApi() throws Exception { + String[] args = { "../" }; + + // Trigger rate limiting on 1 wrong attempts: + Utils.setValueInConfig("totp_max_attempts", "1"); + // Set cooldown to 1 second: + Utils.setValueInConfig("totp_rate_limit_cooldown_sec", "1"); + + TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + + if (StorageLayer.getStorage(process.getProcess()).getType() != STORAGE_TYPE.SQL) { + return; + } + + FeatureFlagTestContent.getInstance(process.main).setKeyValue(FeatureFlagTestContent.ENABLED_FEATURES, new EE_FEATURES[] { EE_FEATURES.TOTP }); + + // Setup user and devices: + JsonObject createDeviceReq = new JsonObject(); + createDeviceReq.addProperty("userId", "user-id"); + createDeviceReq.addProperty("deviceName", "deviceName"); + createDeviceReq.addProperty("period", 30); + createDeviceReq.addProperty("skew", 0); + + JsonObject createDeviceRes = HttpRequestForTesting.sendJsonPOSTRequest( + process.getProcess(), + "", + "http://localhost:3567/recipe/totp/device", + createDeviceReq, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + "totp"); + assertEquals(createDeviceRes.get("status").getAsString(), "OK"); + String secretKey = createDeviceRes.get("secret").getAsString(); + + TOTPDevice device = new TOTPDevice("user-id", "deviceName", secretKey, 30, 0, false); + + // Start the actual tests for update device API: + + JsonObject body = new JsonObject(); + + // Missing userId/deviceName/skew/period + { + Exception e = updateDeviceRequest(process, body); + checkFieldMissingErrorResponse(e, "userId"); + + body.addProperty("userId", ""); + e = updateDeviceRequest(process, body); + checkFieldMissingErrorResponse(e, "deviceName"); + + body.addProperty("deviceName", ""); + e = updateDeviceRequest(process, body); + checkFieldMissingErrorResponse(e, "totp"); + } + + // Invalid userId/deviceName/skew/period + { + body.addProperty("totp", ""); + Exception e = updateDeviceRequest(process, body); + checkResponseErrorContains(e, "userId cannot be empty"); // Note that this is not a field missing error + + body.addProperty("userId", device.userId); + e = updateDeviceRequest(process, body); + checkResponseErrorContains(e, "deviceName cannot be empty"); + + body.addProperty("deviceName", device.deviceName); + e = updateDeviceRequest(process, body); + checkResponseErrorContains(e, "totp must be 6 characters long"); + + // test totp of length 5: + body.addProperty("totp", "12345"); + e = updateDeviceRequest(process, body); + checkResponseErrorContains(e, "totp must be 6 characters long"); + + // test totp of length 8: + body.addProperty("totp", "12345678"); + e = updateDeviceRequest(process, body); + checkResponseErrorContains(e, "totp must be 6 characters long"); + + // but let's pass invalid code first + body.addProperty("totp", "123456"); + JsonObject res0 = HttpRequestForTesting.sendJsonPOSTRequest( + process.getProcess(), + "", + "http://localhost:3567/recipe/totp/device/verify", + body, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + "totp"); + assert res0.get("status").getAsString().equals("INVALID_TOTP_ERROR"); + + // Check that rate limiting is triggered for the user: + JsonObject res3 = HttpRequestForTesting.sendJsonPOSTRequest( + process.getProcess(), + "", + "http://localhost:3567/recipe/totp/device/verify", + body, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + "totp"); + assert res3.get("status").getAsString().equals("LIMIT_REACHED_ERROR"); + assert res3.get("retryAfterMs") != null; + + // wait for cooldown to end (1s) + Thread.sleep(1000); + + // should pass now on valid code + String validTotp = TOTPRecipeTest.generateTotpCode(process.getProcess(), device); + body.addProperty("totp", validTotp); + JsonObject res = HttpRequestForTesting.sendJsonPOSTRequest( + process.getProcess(), + "", + "http://localhost:3567/recipe/totp/device/verify", + body, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + "totp"); + assert res.get("status").getAsString().equals("OK"); + assert res.get("wasAlreadyVerified").getAsBoolean() == false; + + // try again to verify the user with any code (valid/invalid) + body.addProperty("totp", "mycode"); + JsonObject res2 = HttpRequestForTesting.sendJsonPOSTRequest( + process.getProcess(), + "", + "http://localhost:3567/recipe/totp/device/verify", + body, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + "totp"); + assert res2.get("status").getAsString().equals("OK"); + assert res2.get("wasAlreadyVerified").getAsBoolean() == true; + + // try again with unknown device + body.addProperty("deviceName", "non-existent-device"); + JsonObject res4 = HttpRequestForTesting.sendJsonPOSTRequest( + process.getProcess(), + "", + "http://localhost:3567/recipe/totp/device/verify", + body, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + "totp"); + assert res4.get("status").getAsString().equals("UNKNOWN_DEVICE_ERROR"); + + // try verifying device for a non-existent user + body.addProperty("userId", "non-existent-user"); + JsonObject res5 = HttpRequestForTesting.sendJsonPOSTRequest( + process.getProcess(), + "", + "http://localhost:3567/recipe/totp/device/verify", + body, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + "totp"); + assert res5.get("status").getAsString().equals("TOTP_NOT_ENABLED_ERROR"); + } + + process.kill(); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); + } + +} diff --git a/src/test/java/io/supertokens/test/userIdMapping/UserIdMappingTest.java b/src/test/java/io/supertokens/test/userIdMapping/UserIdMappingTest.java index cbaff9585..3b0484510 100644 --- a/src/test/java/io/supertokens/test/userIdMapping/UserIdMappingTest.java +++ b/src/test/java/io/supertokens/test/userIdMapping/UserIdMappingTest.java @@ -20,6 +20,8 @@ import io.supertokens.ProcessState; import io.supertokens.authRecipe.AuthRecipe; import io.supertokens.emailpassword.EmailPassword; +import io.supertokens.featureflag.EE_FEATURES; +import io.supertokens.featureflag.FeatureFlagTestContent; import io.supertokens.pluginInterface.STORAGE_TYPE; import io.supertokens.pluginInterface.emailpassword.UserInfo; import io.supertokens.pluginInterface.nonAuthRecipe.NonAuthRecipeStorage; @@ -783,6 +785,9 @@ public void checkThatCreateUserIdMappingHasAllNonAuthRecipeChecks() throws Excep if (StorageLayer.getStorage(process.getProcess()).getType() != STORAGE_TYPE.SQL) { return; } + + FeatureFlagTestContent.getInstance(process.main).setKeyValue(FeatureFlagTestContent.ENABLED_FEATURES, new EE_FEATURES[] { EE_FEATURES.TOTP }); + // this list contains the package names for recipes which dont use UserIdMapping ArrayList nonAuthRecipesWhichDontNeedUserIdMapping = new ArrayList<>( List.of("io.supertokens.pluginInterface.jwt.JWTRecipeStorage"));