From 06cf6d5d71532fb6b46c49aeb8a9214048a4f7b1 Mon Sep 17 00:00:00 2001 From: KShivendu Date: Fri, 10 Feb 2023 12:43:05 +0530 Subject: [PATCH 01/42] feat: Implement TOTP inmemory classes --- .../java/io/supertokens/inmemorydb/Start.java | 94 ++++++++- .../inmemorydb/config/SQLiteConfig.java | 10 +- .../inmemorydb/queries/TOTPQueries.java | 188 ++++++++++++++++++ .../storageLayer/StorageLayer.java | 12 ++ src/main/java/io/supertokens/totp/Totp.java | 54 +++++ .../test/totp/TOTPDevicesTest.java | 96 +++++++++ 6 files changed, 451 insertions(+), 3 deletions(-) create mode 100644 src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java create mode 100644 src/main/java/io/supertokens/totp/Totp.java create mode 100644 src/test/java/io/supertokens/test/totp/TOTPDevicesTest.java diff --git a/src/main/java/io/supertokens/inmemorydb/Start.java b/src/main/java/io/supertokens/inmemorydb/Start.java index 65c69ff34..d1d950806 100644 --- a/src/main/java/io/supertokens/inmemorydb/Start.java +++ b/src/main/java/io/supertokens/inmemorydb/Start.java @@ -59,6 +59,12 @@ import io.supertokens.pluginInterface.thirdparty.sqlStorage.ThirdPartySQLStorage; import io.supertokens.pluginInterface.useridmapping.UserIdMapping; import io.supertokens.pluginInterface.useridmapping.UserIdMappingStorage; +import io.supertokens.pluginInterface.totp.TOTPDevice; +import io.supertokens.pluginInterface.totp.TOTPUsedCode; +import io.supertokens.pluginInterface.totp.exception.DeviceAlreadyExistsException; +import io.supertokens.pluginInterface.totp.exception.TotpNotEnabledException; +import io.supertokens.pluginInterface.totp.exception.UnknownDeviceException; +import io.supertokens.pluginInterface.totp.sqlStorage.TOTPSQLStorage; import io.supertokens.pluginInterface.useridmapping.exception.UnknownSuperTokensUserIdException; import io.supertokens.pluginInterface.useridmapping.exception.UserIdMappingAlreadyExistsException; import io.supertokens.pluginInterface.usermetadata.UserMetadataStorage; @@ -87,7 +93,8 @@ public class Start implements SessionSQLStorage, EmailPasswordSQLStorage, EmailVerificationSQLStorage, ThirdPartySQLStorage, - JWTRecipeSQLStorage, PasswordlessSQLStorage, UserMetadataSQLStorage, UserRolesSQLStorage, UserIdMappingStorage { + JWTRecipeSQLStorage, PasswordlessSQLStorage, UserMetadataSQLStorage, UserRolesSQLStorage, UserIdMappingStorage, + TOTPSQLStorage { private static final Object appenderLock = new Object(); private static final String APP_ID_KEY_NAME = "app_id"; @@ -1455,7 +1462,8 @@ public void createUserIdMapping(String superTokensUserId, String externalUserId, @Nullable String externalUserIdInfo) throws StorageQueryException, UnknownSuperTokensUserIdException, UserIdMappingAlreadyExistsException { - // SQLite is not compiled with foreign key constraint, so we need an explicit check to see if superTokensUserId + // SQLite is not compiled with foreign key constraint, so we need an explicit + // check to see if superTokensUserId // is a valid // userId. if (!doesUserIdExist(superTokensUserId)) { @@ -1619,4 +1627,86 @@ public void addInfoToNonAuthRecipesBasedOnUserId(String className, String userId throw new IllegalStateException("ClassName: " + className + " is not part of NonAuthRecipeStorage"); } } + + // TOTP recipe: + + @Override + public void createDevice(TOTPDevice device) throws StorageQueryException { + try { + TOTPQueries.createDevice(this, device); + } catch (Exception e) { + throw new StorageQueryException(e); + } + } + + @Override + public void markDeviceAsVerified(String userId, String deviceName) + throws StorageQueryException, TotpNotEnabledException, UnknownDeviceException { + try { + TOTPQueries.markDeviceAsVerified(this, userId, deviceName); + } catch (Exception e) { + throw new StorageQueryException(e); + } + } + + @Override + public void deleteDevice(String userId, String deviceName) + throws StorageQueryException, TotpNotEnabledException, UnknownDeviceException { + try { + TOTPQueries.deleteDevice(this, userId, deviceName); + } catch (Exception e) { + throw new StorageQueryException(e); + } + } + + @Override + public void updateDeviceName(String userId, String oldDeviceName, String newDeviceName) + throws StorageQueryException, TotpNotEnabledException, DeviceAlreadyExistsException, + UnknownDeviceException { + try { + TOTPQueries.updateDeviceName(this, userId, oldDeviceName, newDeviceName); + } catch (Exception e) { + throw new StorageQueryException(e); + } + } + + @Override + public TOTPDevice[] getDevices(String userId) + throws StorageQueryException { + try { + return TOTPQueries.getDevices(this, userId); + } catch (Exception e) { + throw new StorageQueryException(e); + } + } + + @Override + public boolean insertUsedCode(TOTPUsedCode code) + throws StorageQueryException, TotpNotEnabledException { + try { + return TOTPQueries.insertUsedCode(this, code); + } catch (Exception e) { + throw new StorageQueryException(e); + } + } + + @Override + public TOTPUsedCode[] getUsedCodes(String userId) + throws StorageQueryException, TotpNotEnabledException { + try { + return TOTPQueries.getUsedCodes(this, userId); + } catch (Exception e) { + throw new StorageQueryException(e); + } + } + + @Override + public void removeExpiredCodes() + throws StorageQueryException { + try { + TOTPQueries.removeExpiredCodes(this); + } catch (Exception e) { + throw new StorageQueryException(e); + } + } } diff --git a/src/main/java/io/supertokens/inmemorydb/config/SQLiteConfig.java b/src/main/java/io/supertokens/inmemorydb/config/SQLiteConfig.java index eb7d61574..c2eb0b2dd 100644 --- a/src/main/java/io/supertokens/inmemorydb/config/SQLiteConfig.java +++ b/src/main/java/io/supertokens/inmemorydb/config/SQLiteConfig.java @@ -89,4 +89,12 @@ public String getUserRolesTable() { public String getUserIdMappingTable() { return "userid_mapping"; } -} \ No newline at end of file + + public String getTotpUserDevicesTable() { + return "totp_user_devices"; + } + + public String getTotpUsedCodesTable() { + return "totp_used_codes"; + } +} diff --git a/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java b/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java new file mode 100644 index 000000000..0e36ea3bc --- /dev/null +++ b/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java @@ -0,0 +1,188 @@ +package io.supertokens.inmemorydb.queries; + +import java.sql.Connection; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.List; + +import io.supertokens.inmemorydb.ConnectionWithLocks; +import io.supertokens.inmemorydb.PreparedStatementValueSetter; +import io.supertokens.inmemorydb.Start; +import io.supertokens.inmemorydb.config.Config; +import io.supertokens.pluginInterface.RowMapper; +import io.supertokens.pluginInterface.exceptions.StorageQueryException; +import io.supertokens.pluginInterface.exceptions.StorageTransactionLogicException; +import io.supertokens.pluginInterface.sqlStorage.SQLStorage.TransactionIsolationLevel; + +import io.supertokens.pluginInterface.totp.TOTPDevice; +import io.supertokens.pluginInterface.totp.TOTPUsedCode; + +import static io.supertokens.inmemorydb.QueryExecutorTemplate.execute; +import static io.supertokens.inmemorydb.QueryExecutorTemplate.update; + +public class TOTPQueries { + public static String getQueryToCreateUserDevicesTable(Start start) { + // Todo: verify if "DEFAULT FALSE" is correct + // TODO: Verify all queries using SQLite + return "CREATE TABLE IF NOT EXISTS" + Config.getConfig(start).getTotpUserDevicesTable() + " (" + + "user_id VARCHAR(128) NOT NULL," + "device_name VARCHAR(256) NOT NULL," + + "secret_key VARCHAR(256) NOT NULL," + + "period INTEGER NOT NULL," + "skew INTEGER NOT NULL," + "verified BOOLEAN NOT NULL," + + "PRIMARY KEY (user_id, device_name)"; + } + + public static String getQueryToCreateUsedCodesTable(Start start) { + return "CREATE TABLE IF NOT EXISTS" + Config.getConfig(start).getTotpUsedCodesTable() + " (" + + "user_id VARCHAR(128) NOT NULL," + "code VARCHAR(6) NOT NULL," + "is_valid_code BOOLEAN NOT NULL," + + "expiry_time BIGINT NOT NULL," + "PRIMARY KEY (user_id, code)," + + "FOREIGN KEY (user_id) REFERENCES" + Config.getConfig(start).getTotpUserDevicesTable() + + "(user_id) ON DELETE CASCADE"; + } + + public static String getQueryToCreateUsedCodesIndex(Start start) { + return "CREATE INDEX IF NOT EXISTS totp_used_codes_expiry_time_index ON " + + Config.getConfig(start).getTotpUsedCodesTable() + " (expiry_time)"; + } + + public static void createDevice(Start start, TOTPDevice device) + throws StorageQueryException, StorageTransactionLogicException, SQLException { + String QUERY = "INSERT INTO " + Config.getConfig(start).getTotpUserDevicesTable() + + " (deviceName, userId, secretKey, period, skew, verified) VALUES (?, ?, ?, ?, ?, ?)"; + + update(start, QUERY, pst -> { + pst.setString(1, device.deviceName); + pst.setString(2, device.userId); + pst.setString(3, device.secretKey); + pst.setInt(4, device.period); + pst.setInt(5, device.skew); + pst.setBoolean(6, device.verified); + }); + } + + public static void markDeviceAsVerified(Start start, String userId, String deviceName) + throws StorageTransactionLogicException, StorageQueryException, SQLException { + String QUERY = "UPDATE " + Config.getConfig(start).getTotpUserDevicesTable() + + " SET verified = true WHERE user_id = ? AND device_name = ?;"; // What if device is already + // verified? + update(start, QUERY, pst -> { + pst.setString(1, userId); + pst.setString(2, deviceName); + }); + } + + public static void deleteDevice(Start start, String userId, String deviceName) + throws StorageTransactionLogicException, StorageQueryException, SQLException { + String QUERY = "DELETE FROM " + Config.getConfig(start).getTotpUserDevicesTable() + + " WHERE user_id = ? AND device_name = ?;"; + + update(start, QUERY, pst -> { + pst.setString(1, userId); + pst.setString(2, deviceName); + }); + } + + public static void updateDeviceName(Start start, String userId, String oldDeviceName, String newDeviceName) + throws StorageTransactionLogicException, StorageQueryException, SQLException { + String QUERY = "UPDATE " + Config.getConfig(start).getTotpUserDevicesTable() + + " SET device_name = ? WHERE user_id = ? AND device_name = ?;"; + + update(start, QUERY, pst -> { + pst.setString(1, newDeviceName); + pst.setString(2, userId); + pst.setString(3, oldDeviceName); + }); + } + + public static TOTPDevice[] getDevices(Start start, String userId) + throws StorageTransactionLogicException, StorageQueryException, SQLException { + String QUERY = "SELECT * FROM " + Config.getConfig(start).getTotpUserDevicesTable() + + " WHERE user_id = ?;"; + + return execute(start, QUERY, pst -> pst.setString(1, userId), result -> { + List devices = new ArrayList<>(); + while (result.next()) { + devices.add(TOTPDeviceRowMapper.getInstance().map(result)); + } + + return devices.toArray(new TOTPDevice[0]); + }); + } + + private static class TOTPDeviceRowMapper implements RowMapper { + private static final TOTPDeviceRowMapper INSTANCE = new TOTPDeviceRowMapper(); + + private TOTPDeviceRowMapper() { + } + + private static TOTPDeviceRowMapper getInstance() { + return INSTANCE; + } + + @Override + public TOTPDevice map(ResultSet result) throws SQLException { + return new TOTPDevice( + result.getString("device_name"), + result.getString("user_id"), + result.getString("secret_key"), + result.getInt("period"), + result.getInt("skew"), + result.getBoolean("verified")); + } + } + + private static class TOTPUsedCodeRowMapper implements RowMapper { + private static final TOTPUsedCodeRowMapper INSTANCE = new TOTPUsedCodeRowMapper(); + + private TOTPUsedCodeRowMapper() { + } + + private static TOTPUsedCodeRowMapper getInstance() { + return INSTANCE; + } + + @Override + public TOTPUsedCode map(ResultSet result) throws SQLException { + return new TOTPUsedCode( + result.getString("user_id"), + result.getString("code"), + result.getBoolean("is_valid_code"), + result.getLong("expiry_time")); + } + } + + public static boolean insertUsedCode(Start start, TOTPUsedCode code) throws SQLException, StorageQueryException { + String QUERY = "INSERT INTO " + Config.getConfig(start).getTotpUsedCodesTable() + + " (user_id, code, is_valid_code, expiry_time) VALUES (?, ?, ?, ?);"; + + update(start, QUERY, pst -> { + pst.setString(1, code.userId); + pst.setString(2, code.code); + pst.setBoolean(3, code.isValidCode); + pst.setLong(4, code.expiryTime); + }); + return true; // FIXME: This is not correct. We should check if the code was inserted or not. + } + + public static TOTPUsedCode[] getUsedCodes(Start start, String userId) throws SQLException, StorageQueryException { + String QUERY = "SELECT * FROM " + + Config.getConfig(start).getTotpUsedCodesTable() + + " WHERE user_id = ?;"; + return execute(start, QUERY, pst -> pst.setString(1, userId), result -> { + List codes = new ArrayList<>(); + while (result.next()) { + codes.add(TOTPUsedCodeRowMapper.getInstance().map(result)); + } + + return codes.toArray(new TOTPUsedCode[0]); + }); + } + + public static void removeExpiredCodes(Start start) + throws StorageTransactionLogicException, StorageQueryException, SQLException { + String QUERY = "DELETE FROM " + Config.getConfig(start).getTotpUsedCodesTable() + + " WHERE expiry_time < ?;"; + + update(start, QUERY, pst -> pst.setLong(1, System.currentTimeMillis())); + } +} diff --git a/src/main/java/io/supertokens/storageLayer/StorageLayer.java b/src/main/java/io/supertokens/storageLayer/StorageLayer.java index d8509819e..847a99984 100644 --- a/src/main/java/io/supertokens/storageLayer/StorageLayer.java +++ b/src/main/java/io/supertokens/storageLayer/StorageLayer.java @@ -32,6 +32,7 @@ import io.supertokens.pluginInterface.passwordless.sqlStorage.PasswordlessSQLStorage; import io.supertokens.pluginInterface.session.SessionStorage; import io.supertokens.pluginInterface.thirdparty.sqlStorage.ThirdPartySQLStorage; +import io.supertokens.pluginInterface.totp.sqlStorage.TOTPSQLStorage; import io.supertokens.pluginInterface.useridmapping.UserIdMappingStorage; import io.supertokens.pluginInterface.usermetadata.sqlStorage.UserMetadataSQLStorage; import io.supertokens.pluginInterface.userroles.sqlStorage.UserRolesSQLStorage; @@ -264,6 +265,17 @@ public static UserIdMappingStorage getUserIdMappingStorage(Main main) { return (UserIdMappingStorage) getInstance(main).storage; } + public static TOTPSQLStorage getTOTPStorage(Main main) { + if (getInstance(main) == null) { + throw new QuitProgramException("please call init() before calling getStorageLayer"); + } + if (getInstance(main).storage.getType() != STORAGE_TYPE.SQL) { + // we only support SQL for now + throw new UnsupportedOperationException(""); + } + return (TOTPSQLStorage) getInstance(main).storage; + } + public boolean isInMemDb() { return this.storage instanceof Start; } diff --git a/src/main/java/io/supertokens/totp/Totp.java b/src/main/java/io/supertokens/totp/Totp.java new file mode 100644 index 000000000..dbfd5d0ce --- /dev/null +++ b/src/main/java/io/supertokens/totp/Totp.java @@ -0,0 +1,54 @@ +package io.supertokens.totp; + +import java.io.IOException; + +import io.supertokens.Main; +import io.supertokens.pluginInterface.totp.TOTPDevice; +import io.supertokens.pluginInterface.totp.sqlStorage.TOTPSQLStorage; +import io.supertokens.storageLayer.StorageLayer; + +public class Totp { + + public static CreateDeviceResponse createDevice(Main main, String userId, String deviceName, int skew, int period) + throws IOException { + + TOTPSQLStorage totpStorage = StorageLayer.getTOTPStorage(main); + String secret = GenerateDeviceSecret.generate(); + + if (userId == null || deviceName == null || secret == null) { + throw new IllegalArgumentException("userId, deviceName and secret cannot be null"); + } + + try { + TOTPDevice device = new TOTPDevice(userId, deviceName, secret, skew, period, false); + totpStorage.createDevice(device); + } catch (Exception e) { + throw new IOException(e); + } + + return new CreateDeviceResponse("deviceName", secret); + } + + private static class GenerateDeviceSecret { + // private final String secret; + + // private GenerateDeviceSecret(String secret) { + // this.secret = secret; + // } + + public static String generate() { + return "XXXX"; + } + } + + public static class CreateDeviceResponse { + public String deviceName; + public String secret; + + public CreateDeviceResponse(String deviceName, String secret) { + this.deviceName = deviceName; + this.secret = secret; + } + + } +} diff --git a/src/test/java/io/supertokens/test/totp/TOTPDevicesTest.java b/src/test/java/io/supertokens/test/totp/TOTPDevicesTest.java new file mode 100644 index 000000000..c81b3bfd4 --- /dev/null +++ b/src/test/java/io/supertokens/test/totp/TOTPDevicesTest.java @@ -0,0 +1,96 @@ +/* + * Copyright (c) 2021, VRAI Labs and/or its affiliates. All rights reserved. + * + * This software is licensed under the Apache License, Version 2.0 (the + * "License") as published by the Apache Software Foundation. + * + * You may not use this file except in compliance with the License. You may + * obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.supertokens.test.totp; + +import static org.junit.Assert.assertNotNull; // Not sure about this + +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TestRule; + +import io.supertokens.test.Utils; +import io.supertokens.ProcessState; +import io.supertokens.pluginInterface.STORAGE_TYPE; +import io.supertokens.storageLayer.StorageLayer; +import io.supertokens.test.TestingProcessManager; + +import io.supertokens.totp.Totp; +import io.supertokens.pluginInterface.totp.TOTPDevice; +import io.supertokens.pluginInterface.totp.TOTPStorage; + +public class TOTPDevicesTest { + + @Rule + public TestRule watchman = Utils.getOnFailure(); + + @AfterClass + public static void afterTesting() { + Utils.afterTesting(); + } + + @Before + public void beforeEach() { + Utils.reset(); + } + + @Test + public void createDeviceWithFullCode() throws Exception { + String[] args = { "../" }; + + TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + + if (StorageLayer.getStorage(process.getProcess()).getType() != STORAGE_TYPE.SQL) { + return; + } + + TOTPStorage storage = StorageLayer.getTOTPStorage(process.getProcess()); + + Totp.CreateDeviceResponse createDeviceResponse = Totp.createDevice(process.getProcess(), "userId", "deviceName", + 1, 30); + assertNotNull(createDeviceResponse); + createDeviceResponse.deviceName.equals("deviceName"); + + process.kill(); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); + } + + @Test + public void createDeviceWithDb() throws Exception { + String[] args = { "../" }; + + TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + + if (StorageLayer.getStorage(process.getProcess()).getType() != STORAGE_TYPE.SQL) { + return; + } + + TOTPDevice newDevice = new TOTPDevice("deviceName", "userId", "secretKey", 30, 1, false); + + TOTPStorage storage = StorageLayer.getTOTPStorage(process.getProcess()); + storage.createDevice(newDevice); + + TOTPDevice[] storedDevices = storage.getDevices("userId"); + assertNotNull(storedDevices); + assert (storedDevices.length == 1); + assert (storedDevices[0].deviceName.equals("deviceName")); + } + +} From 37aec0e3e4eba52fb9e9a95523855689b07c1755 Mon Sep 17 00:00:00 2001 From: KShivendu Date: Fri, 10 Feb 2023 13:17:16 +0530 Subject: [PATCH 02/42] feat: Create tables and indexes for TOTP --- .../inmemorydb/queries/GeneralQueries.java | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/main/java/io/supertokens/inmemorydb/queries/GeneralQueries.java b/src/main/java/io/supertokens/inmemorydb/queries/GeneralQueries.java index 17045c6fb..5812d8cef 100644 --- a/src/main/java/io/supertokens/inmemorydb/queries/GeneralQueries.java +++ b/src/main/java/io/supertokens/inmemorydb/queries/GeneralQueries.java @@ -186,6 +186,19 @@ public static void createTablesIfNotExists(Start start, Main main) throws SQLExc update(start, UserIdMappingQueries.getQueryToCreateUserIdMappingTable(start), NO_OP_SETTER); } + + if (!doesTableExists(start, Config.getConfig(start).getTotpUserDevicesTable())) { + getInstance(main).addState(CREATING_NEW_TABLE, null); + update(start, TOTPQueries.getQueryToCreateUserDevicesTable(start), NO_OP_SETTER); + } + + if (!doesTableExists(start, Config.getConfig(start).getTotpUsedCodesTable())) { + getInstance(main).addState(CREATING_NEW_TABLE, null); + update(start, TOTPQueries.getQueryToCreateUsedCodesTable(start), NO_OP_SETTER); + // index: + update(start, TOTPQueries.getQueryToCreateUsedCodesIndex(start), NO_OP_SETTER); + } + } public static void setKeyValue_Transaction(Start start, Connection con, String key, KeyValueInfo info) From dcbaf29ac03d5f0ddb9f150134ceb0076a5c18a0 Mon Sep 17 00:00:00 2001 From: KShivendu Date: Fri, 10 Feb 2023 13:25:46 +0530 Subject: [PATCH 03/42] refactor: Remove comments and unused code --- .../supertokens/inmemorydb/queries/TOTPQueries.java | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java b/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java index 0e36ea3bc..8602ff408 100644 --- a/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java +++ b/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java @@ -1,19 +1,15 @@ package io.supertokens.inmemorydb.queries; -import java.sql.Connection; import java.sql.ResultSet; import java.sql.SQLException; import java.util.ArrayList; import java.util.List; -import io.supertokens.inmemorydb.ConnectionWithLocks; -import io.supertokens.inmemorydb.PreparedStatementValueSetter; import io.supertokens.inmemorydb.Start; import io.supertokens.inmemorydb.config.Config; import io.supertokens.pluginInterface.RowMapper; import io.supertokens.pluginInterface.exceptions.StorageQueryException; import io.supertokens.pluginInterface.exceptions.StorageTransactionLogicException; -import io.supertokens.pluginInterface.sqlStorage.SQLStorage.TransactionIsolationLevel; import io.supertokens.pluginInterface.totp.TOTPDevice; import io.supertokens.pluginInterface.totp.TOTPUsedCode; @@ -23,8 +19,6 @@ public class TOTPQueries { public static String getQueryToCreateUserDevicesTable(Start start) { - // Todo: verify if "DEFAULT FALSE" is correct - // TODO: Verify all queries using SQLite return "CREATE TABLE IF NOT EXISTS" + Config.getConfig(start).getTotpUserDevicesTable() + " (" + "user_id VARCHAR(128) NOT NULL," + "device_name VARCHAR(256) NOT NULL," + "secret_key VARCHAR(256) NOT NULL," @@ -63,8 +57,7 @@ public static void createDevice(Start start, TOTPDevice device) public static void markDeviceAsVerified(Start start, String userId, String deviceName) throws StorageTransactionLogicException, StorageQueryException, SQLException { String QUERY = "UPDATE " + Config.getConfig(start).getTotpUserDevicesTable() - + " SET verified = true WHERE user_id = ? AND device_name = ?;"; // What if device is already - // verified? + + " SET verified = true WHERE user_id = ? AND device_name = ?;"; update(start, QUERY, pst -> { pst.setString(1, userId); pst.setString(2, deviceName); @@ -161,7 +154,7 @@ public static boolean insertUsedCode(Start start, TOTPUsedCode code) throws SQLE pst.setBoolean(3, code.isValidCode); pst.setLong(4, code.expiryTime); }); - return true; // FIXME: This is not correct. We should check if the code was inserted or not. + return true; // FIXME: Count the number of rows inserted OR Check if the code was inserted } public static TOTPUsedCode[] getUsedCodes(Start start, String userId) throws SQLException, StorageQueryException { From 81e37655f4ffc967835e43dbf75b1108a8c11b50 Mon Sep 17 00:00:00 2001 From: KShivendu Date: Mon, 13 Feb 2023 18:12:01 +0530 Subject: [PATCH 04/42] feat: Throws expected exceptions from totp in memory implementation with tests --- .../java/io/supertokens/inmemorydb/Start.java | 62 ++-- .../inmemorydb/queries/TOTPQueries.java | 41 ++- src/main/java/io/supertokens/totp/Totp.java | 9 + .../test/totp/TOTPDevicesTest.java | 24 +- .../test/totp/TOTPStorageTest.java | 279 ++++++++++++++++++ 5 files changed, 355 insertions(+), 60 deletions(-) create mode 100644 src/test/java/io/supertokens/test/totp/TOTPStorageTest.java diff --git a/src/main/java/io/supertokens/inmemorydb/Start.java b/src/main/java/io/supertokens/inmemorydb/Start.java index d1d950806..706062415 100644 --- a/src/main/java/io/supertokens/inmemorydb/Start.java +++ b/src/main/java/io/supertokens/inmemorydb/Start.java @@ -1628,44 +1628,66 @@ public void addInfoToNonAuthRecipesBasedOnUserId(String className, String userId } } - // TOTP recipe: + // TOTP recipe: @Override - public void createDevice(TOTPDevice device) throws StorageQueryException { + public void createDevice(TOTPDevice device) throws StorageQueryException, DeviceAlreadyExistsException { try { TOTPQueries.createDevice(this, device); - } catch (Exception e) { + } catch (SQLException e) { + if (e.getMessage() + .equals("[SQLITE_CONSTRAINT] Abort due to constraint violation (UNIQUE constraint failed: " + + Config.getConfig(this).getTotpUserDevicesTable() + ".user_id, " + + Config.getConfig(this).getTotpUserDevicesTable() + ".device_name" + ")")) { + throw new DeviceAlreadyExistsException(); + } + throw new StorageQueryException(e); } } @Override public void markDeviceAsVerified(String userId, String deviceName) - throws StorageQueryException, TotpNotEnabledException, UnknownDeviceException { + throws StorageQueryException, UnknownDeviceException { try { - TOTPQueries.markDeviceAsVerified(this, userId, deviceName); - } catch (Exception e) { + int updatedCount = TOTPQueries.markDeviceAsVerified(this, userId, deviceName); + if (updatedCount == 0) { + throw new UnknownDeviceException(); + } + } catch (SQLException e) { throw new StorageQueryException(e); } } @Override public void deleteDevice(String userId, String deviceName) - throws StorageQueryException, TotpNotEnabledException, UnknownDeviceException { + throws StorageQueryException, UnknownDeviceException { try { - TOTPQueries.deleteDevice(this, userId, deviceName); - } catch (Exception e) { + int deletedCount = TOTPQueries.deleteDevice(this, userId, deviceName); + if (deletedCount == 0) { + throw new UnknownDeviceException(); + } + } catch (SQLException e) { throw new StorageQueryException(e); } } @Override public void updateDeviceName(String userId, String oldDeviceName, String newDeviceName) - throws StorageQueryException, TotpNotEnabledException, DeviceAlreadyExistsException, + throws StorageQueryException, DeviceAlreadyExistsException, UnknownDeviceException { try { - TOTPQueries.updateDeviceName(this, userId, oldDeviceName, newDeviceName); - } catch (Exception e) { + int updatedCount = TOTPQueries.updateDeviceName(this, userId, oldDeviceName, newDeviceName); + if (updatedCount == 0) { + throw new UnknownDeviceException(); + } + } catch (SQLException e) { + if (e.getMessage() + .equals("[SQLITE_CONSTRAINT] Abort due to constraint violation (UNIQUE constraint failed: " + + Config.getConfig(this).getTotpUserDevicesTable() + ".user_id, " + + Config.getConfig(this).getTotpUserDevicesTable() + ".device_name" + ")")) { + throw new DeviceAlreadyExistsException(); + } throw new StorageQueryException(e); } } @@ -1675,27 +1697,33 @@ public TOTPDevice[] getDevices(String userId) throws StorageQueryException { try { return TOTPQueries.getDevices(this, userId); - } catch (Exception e) { + } catch (SQLException e) { throw new StorageQueryException(e); } } @Override - public boolean insertUsedCode(TOTPUsedCode code) + public boolean insertUsedCode(TOTPUsedCode usedCodeObj) throws StorageQueryException, TotpNotEnabledException { try { - return TOTPQueries.insertUsedCode(this, code); + int insertCount = TOTPQueries.insertUsedCode(this, usedCodeObj); + return insertCount == 1; } catch (Exception e) { + // FIXME: Not working without `PRAGMA foreign_keys = ON;` but unable to setup it in tests. + if (e.getMessage() + .equals("[SQLITE_CONSTRAINT] Abort due to constraint violation (FOREIGN KEY constraint failed)")) { + throw new TotpNotEnabledException(); + } throw new StorageQueryException(e); } } @Override public TOTPUsedCode[] getUsedCodes(String userId) - throws StorageQueryException, TotpNotEnabledException { + throws StorageQueryException { try { return TOTPQueries.getUsedCodes(this, userId); - } catch (Exception e) { + } catch (SQLException e) { throw new StorageQueryException(e); } } diff --git a/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java b/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java index 8602ff408..28875d9c2 100644 --- a/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java +++ b/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java @@ -19,19 +19,19 @@ public class TOTPQueries { public static String getQueryToCreateUserDevicesTable(Start start) { - return "CREATE TABLE IF NOT EXISTS" + Config.getConfig(start).getTotpUserDevicesTable() + " (" + return "CREATE TABLE IF NOT EXISTS " + Config.getConfig(start).getTotpUserDevicesTable() + " (" + "user_id VARCHAR(128) NOT NULL," + "device_name VARCHAR(256) NOT NULL," + "secret_key VARCHAR(256) NOT NULL," + "period INTEGER NOT NULL," + "skew INTEGER NOT NULL," + "verified BOOLEAN NOT NULL," - + "PRIMARY KEY (user_id, device_name)"; + + "PRIMARY KEY (user_id, device_name))"; } public static String getQueryToCreateUsedCodesTable(Start start) { - return "CREATE TABLE IF NOT EXISTS" + Config.getConfig(start).getTotpUsedCodesTable() + " (" + return "CREATE TABLE IF NOT EXISTS " + Config.getConfig(start).getTotpUsedCodesTable() + " (" + "user_id VARCHAR(128) NOT NULL," + "code VARCHAR(6) NOT NULL," + "is_valid_code BOOLEAN NOT NULL," - + "expiry_time BIGINT NOT NULL," + "PRIMARY KEY (user_id, code)," - + "FOREIGN KEY (user_id) REFERENCES" + Config.getConfig(start).getTotpUserDevicesTable() - + "(user_id) ON DELETE CASCADE"; + + "expiry_time BIGINT NOT NULL," + + "FOREIGN KEY (user_id) REFERENCES " + Config.getConfig(start).getTotpUserDevicesTable() + + "(user_id) ON DELETE CASCADE)"; } public static String getQueryToCreateUsedCodesIndex(Start start) { @@ -40,9 +40,9 @@ public static String getQueryToCreateUsedCodesIndex(Start start) { } public static void createDevice(Start start, TOTPDevice device) - throws StorageQueryException, StorageTransactionLogicException, SQLException { + throws StorageQueryException, SQLException { String QUERY = "INSERT INTO " + Config.getConfig(start).getTotpUserDevicesTable() - + " (deviceName, userId, secretKey, period, skew, verified) VALUES (?, ?, ?, ?, ?, ?)"; + + " (device_name, user_id, secret_key, period, skew, verified) VALUES (?, ?, ?, ?, ?, ?)"; update(start, QUERY, pst -> { pst.setString(1, device.deviceName); @@ -54,33 +54,33 @@ public static void createDevice(Start start, TOTPDevice device) }); } - public static void markDeviceAsVerified(Start start, String userId, String deviceName) - throws StorageTransactionLogicException, StorageQueryException, SQLException { + public static int markDeviceAsVerified(Start start, String userId, String deviceName) + throws StorageQueryException, SQLException { String QUERY = "UPDATE " + Config.getConfig(start).getTotpUserDevicesTable() + " SET verified = true WHERE user_id = ? AND device_name = ?;"; - update(start, QUERY, pst -> { + return update(start, QUERY, pst -> { pst.setString(1, userId); pst.setString(2, deviceName); }); } - public static void deleteDevice(Start start, String userId, String deviceName) - throws StorageTransactionLogicException, StorageQueryException, SQLException { + public static int deleteDevice(Start start, String userId, String deviceName) + throws StorageQueryException, SQLException { String QUERY = "DELETE FROM " + Config.getConfig(start).getTotpUserDevicesTable() + " WHERE user_id = ? AND device_name = ?;"; - update(start, QUERY, pst -> { + return update(start, QUERY, pst -> { pst.setString(1, userId); pst.setString(2, deviceName); }); } - public static void updateDeviceName(Start start, String userId, String oldDeviceName, String newDeviceName) - throws StorageTransactionLogicException, StorageQueryException, SQLException { + public static int updateDeviceName(Start start, String userId, String oldDeviceName, String newDeviceName) + throws StorageQueryException, SQLException { String QUERY = "UPDATE " + Config.getConfig(start).getTotpUserDevicesTable() + " SET device_name = ? WHERE user_id = ? AND device_name = ?;"; - update(start, QUERY, pst -> { + return update(start, QUERY, pst -> { pst.setString(1, newDeviceName); pst.setString(2, userId); pst.setString(3, oldDeviceName); @@ -88,7 +88,7 @@ public static void updateDeviceName(Start start, String userId, String oldDevice } public static TOTPDevice[] getDevices(Start start, String userId) - throws StorageTransactionLogicException, StorageQueryException, SQLException { + throws StorageQueryException, SQLException { String QUERY = "SELECT * FROM " + Config.getConfig(start).getTotpUserDevicesTable() + " WHERE user_id = ?;"; @@ -144,17 +144,16 @@ public TOTPUsedCode map(ResultSet result) throws SQLException { } } - public static boolean insertUsedCode(Start start, TOTPUsedCode code) throws SQLException, StorageQueryException { + public static int insertUsedCode(Start start, TOTPUsedCode code) throws SQLException, StorageQueryException { String QUERY = "INSERT INTO " + Config.getConfig(start).getTotpUsedCodesTable() + " (user_id, code, is_valid_code, expiry_time) VALUES (?, ?, ?, ?);"; - update(start, QUERY, pst -> { + return update(start, QUERY, pst -> { pst.setString(1, code.userId); pst.setString(2, code.code); pst.setBoolean(3, code.isValidCode); pst.setLong(4, code.expiryTime); }); - return true; // FIXME: Count the number of rows inserted OR Check if the code was inserted } public static TOTPUsedCode[] getUsedCodes(Start start, String userId) throws SQLException, StorageQueryException { diff --git a/src/main/java/io/supertokens/totp/Totp.java b/src/main/java/io/supertokens/totp/Totp.java index dbfd5d0ce..48cc65080 100644 --- a/src/main/java/io/supertokens/totp/Totp.java +++ b/src/main/java/io/supertokens/totp/Totp.java @@ -29,6 +29,15 @@ public static CreateDeviceResponse createDevice(Main main, String userId, String return new CreateDeviceResponse("deviceName", secret); } + public static void markDeviceAsVerified(Main main, String userId, String deviceName) throws IOException { + TOTPSQLStorage totpStorage = StorageLayer.getTOTPStorage(main); + try { + totpStorage.markDeviceAsVerified(userId, deviceName); + } catch (Exception e) { + throw new IOException(e); + } + } + private static class GenerateDeviceSecret { // private final String secret; diff --git a/src/test/java/io/supertokens/test/totp/TOTPDevicesTest.java b/src/test/java/io/supertokens/test/totp/TOTPDevicesTest.java index c81b3bfd4..2ab28a0f0 100644 --- a/src/test/java/io/supertokens/test/totp/TOTPDevicesTest.java +++ b/src/test/java/io/supertokens/test/totp/TOTPDevicesTest.java @@ -67,30 +67,10 @@ public void createDeviceWithFullCode() throws Exception { assertNotNull(createDeviceResponse); createDeviceResponse.deviceName.equals("deviceName"); + Totp.markDeviceAsVerified(process.getProcess(), "userId", "deviceName"); + process.kill(); assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); } - @Test - public void createDeviceWithDb() throws Exception { - String[] args = { "../" }; - - TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); - assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); - - if (StorageLayer.getStorage(process.getProcess()).getType() != STORAGE_TYPE.SQL) { - return; - } - - TOTPDevice newDevice = new TOTPDevice("deviceName", "userId", "secretKey", 30, 1, false); - - TOTPStorage storage = StorageLayer.getTOTPStorage(process.getProcess()); - storage.createDevice(newDevice); - - TOTPDevice[] storedDevices = storage.getDevices("userId"); - assertNotNull(storedDevices); - assert (storedDevices.length == 1); - assert (storedDevices[0].deviceName.equals("deviceName")); - } - } diff --git a/src/test/java/io/supertokens/test/totp/TOTPStorageTest.java b/src/test/java/io/supertokens/test/totp/TOTPStorageTest.java new file mode 100644 index 000000000..9f1d5ba3c --- /dev/null +++ b/src/test/java/io/supertokens/test/totp/TOTPStorageTest.java @@ -0,0 +1,279 @@ +package io.supertokens.test.totp; + +import static org.junit.Assert.assertNotNull; // Not sure about this + +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TestRule; + +import io.supertokens.test.Utils; +import io.supertokens.ProcessState; +import io.supertokens.pluginInterface.STORAGE_TYPE; +import io.supertokens.storageLayer.StorageLayer; +import io.supertokens.test.TestingProcessManager; + +import io.supertokens.totp.Totp; +import io.supertokens.pluginInterface.totp.TOTPDevice; +import io.supertokens.pluginInterface.totp.TOTPStorage; +import io.supertokens.pluginInterface.totp.TOTPUsedCode; +import io.supertokens.pluginInterface.totp.exception.DeviceAlreadyExistsException; + +public class TOTPStorageTest { + + public class TestSetupResult { + public TOTPStorage storage; + public TestingProcessManager.TestingProcess process; + + public TestSetupResult(TOTPStorage storage, TestingProcessManager.TestingProcess process) { + this.storage = storage; + this.process = process; + } + } + + @Rule + public TestRule watchman = Utils.getOnFailure(); + + @AfterClass + public static void afterTesting() { + Utils.afterTesting(); + } + + @Before + public void beforeEach() { + Utils.reset(); + } + + public TestSetupResult setup() throws InterruptedException { + String[] args = { "../" }; + + TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + + if (StorageLayer.getStorage(process.getProcess()).getType() != STORAGE_TYPE.SQL) { + assert (false); + } + TOTPStorage storage = StorageLayer.getTOTPStorage(process.getProcess()); + + return new TestSetupResult(storage, process); + } + + @Test + public void createDeviceTests() throws Exception { + TestSetupResult result = setup(); + TOTPStorage storage = result.storage; + + TOTPDevice device1 = new TOTPDevice("d1", "user", "secretKey", 30, 1, false); + TOTPDevice device2 = new TOTPDevice("d2", "user", "secretKey", 30, 1, false); + TOTPDevice device2Duplicate = new TOTPDevice("d2", "user", "secretKey", 30, 1, false); + + storage.createDevice(device1); + + TOTPDevice[] storedDevices = storage.getDevices("user"); + TOTPDevice storedDevice = storedDevices[0]; + assert (storedDevices.length == 1); + assert (storedDevice.deviceName.equals("d1")); + assert (storedDevice.userId.equals("user")); + assert (storedDevice.secretKey.equals("secretKey")); + assert (storedDevice.period == 30); + assert (storedDevice.skew == 1); + assert (storedDevice.verified == false); + + storage.createDevice(device2); + storedDevices = storage.getDevices("user"); + assert (storedDevices.length == 2); + + try { + storage.createDevice(device2Duplicate); + assert (false); + } catch (DeviceAlreadyExistsException e) { + assert (true); + } + + result.process.kill(); + assertNotNull(result.process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); + } + + @Test + public void verifyDeviceTests() throws Exception { + TestSetupResult result = setup(); + TOTPStorage storage = result.storage; + + TOTPDevice device = new TOTPDevice("device", "user", "secretKey", 30, 1, false); + storage.createDevice(device); + + TOTPDevice[] storedDevices = storage.getDevices("user"); + assert (storedDevices.length == 1); + assert (!storedDevices[0].verified); + + // Verify the device: + storage.markDeviceAsVerified("user", "device"); + + storedDevices = storage.getDevices("user"); + assert (storedDevices.length == 1); + assert (storedDevices[0].verified); + + // Try to verify a device that doesn't exist: + try { + storage.markDeviceAsVerified("user", "non-existent-device"); + assert (false); + } catch (Exception e) { + assert (true); + } + } + + @Test + public void deleteDeviceTests() throws Exception { + TestSetupResult result = setup(); + TOTPStorage storage = result.storage; + + TOTPDevice device = new TOTPDevice("device", "user", "secretKey", 30, 1, false); + storage.createDevice(device); + + TOTPDevice[] storedDevices = storage.getDevices("user"); + assert (storedDevices.length == 1); + + // Try to delete a device that doesn't exist: + try { + storage.deleteDevice("user", "non-existent-device"); + assert (false); + } catch (Exception e) { + assert (true); + } + + // Delete the device: + storage.deleteDevice("user", "device"); + + storedDevices = storage.getDevices("user"); + assert (storedDevices.length == 0); + } + + @Test + public void updateDeviceNametests() throws Exception { + TestSetupResult result = setup(); + TOTPStorage storage = result.storage; + + TOTPDevice device = new TOTPDevice("device", "user", "secretKey", 30, 1, false); + storage.createDevice(device); + + TOTPDevice[] storedDevices = storage.getDevices("user"); + assert (storedDevices.length == 1); + assert (storedDevices[0].deviceName.equals("device")); + + // Try to update a device that doesn't exist: + try { + storage.updateDeviceName("user", "non-existent-device", "new-device-name"); + assert (false); + } catch (Exception e) { + assert (true); + } + + // Update the device name: + storage.updateDeviceName("user", "device", "new-device-name"); + + storedDevices = storage.getDevices("user"); + assert (storedDevices.length == 1); + assert (storedDevices[0].deviceName.equals("new-device-name")); + } + + @Test + public void getDevicesTest() throws Exception { + TestSetupResult result = setup(); + TOTPStorage storage = result.storage; + + TOTPDevice device1 = new TOTPDevice("d1", "user", "secretKey", 30, 1, false); + TOTPDevice device2 = new TOTPDevice("d2", "user", "secretKey", 30, 1, false); + + storage.createDevice(device1); + storage.createDevice(device2); + + TOTPDevice[] storedDevices = storage.getDevices("user"); + + assert (storedDevices.length == 2); + assert (storedDevices[0].deviceName.equals("d1")); + assert (storedDevices[1].deviceName.equals("d2")); + } + + @Test + public void insertUsedCodeTest() throws Exception { + TestSetupResult result = setup(); + TOTPStorage storage = result.storage; + + TOTPDevice device = new TOTPDevice("device", "user", "secretKey", 30, 1, false); + TOTPUsedCode code = new TOTPUsedCode("user", "1234", true, 1); + + storage.createDevice(device); + boolean isInserted = storage.insertUsedCode(code); + TOTPUsedCode[] usedCodes = storage.getUsedCodes("user"); + + assert (isInserted); + assert (usedCodes.length == 1); + assert (usedCodes[0].userId.equals("user")); + assert (usedCodes[0].code.equals("1234")); + assert (usedCodes[0].isValidCode); + assert (usedCodes[0].expiryTime == 1); + + // FIXME: Next two features aren't working because foreign key constraint is not + // working in tests: + + // Deleting the device should delete the used codes: + storage.deleteDevice("user", "device"); + usedCodes = storage.getUsedCodes("user"); + assert (usedCodes.length == 0); + + // Try to insert code when device (userId) doesn't exist: + try { + // Need to run `PRAGMA foreign_keys = ON;` then only will throws exception. But + // unable to setup that also in tests. + storage.insertUsedCode(new TOTPUsedCode("non-existent-user", "1234", true, 1)); + assert (false); + } catch (Exception e) { + assert (true); + } + } + + @Test + public void getUsedCodesTest() throws Exception { + TestSetupResult result = setup(); + TOTPStorage storage = result.storage; + + TOTPDevice device = new TOTPDevice("device", "user", "secretKey", 30, 1, false); + TOTPUsedCode code1 = new TOTPUsedCode("user", "code1", true, 1); + TOTPUsedCode code2 = new TOTPUsedCode("user", "code2", false, 1); + + storage.createDevice(device); + storage.insertUsedCode(code1); + storage.insertUsedCode(code2); + + TOTPUsedCode[] usedCodes = storage.getUsedCodes("user"); + assert (usedCodes.length == 2); + assert (usedCodes[0].code.equals("code1")); + assert (usedCodes[0].isValidCode); + assert (usedCodes[1].code.equals("code2")); + assert (!usedCodes[1].isValidCode); + } + + @Test + public void removeExpiredCodesTest() throws Exception { + TestSetupResult result = setup(); + TOTPStorage storage = result.storage; + + TOTPDevice device = new TOTPDevice("device", "user", "secretKey", 30, 1, false); + TOTPUsedCode codeToDelete = new TOTPUsedCode("user", "codeToDelete", true, 1); + TOTPUsedCode codeToKeep = new TOTPUsedCode("user", "codeToKeep", false, 1); + + storage.createDevice(device); + storage.insertUsedCode(codeToDelete); + storage.insertUsedCode(codeToKeep); + + TOTPUsedCode[] usedCodes = storage.getUsedCodes("user"); + assert (usedCodes.length == 2); + + storage.removeExpiredCodes(); + + usedCodes = storage.getUsedCodes("user"); + assert (usedCodes.length == 1); + assert (usedCodes[0].code.equals("codeToKeep")); + } +} From 7251944fe45367944e79556e2f579baffbeb1ae0 Mon Sep 17 00:00:00 2001 From: KShivendu Date: Thu, 16 Feb 2023 14:02:31 +0530 Subject: [PATCH 05/42] feat: Fix TOTP.java and inmemory implementation --- .../java/io/supertokens/inmemorydb/Start.java | 10 +- .../inmemorydb/queries/TOTPQueries.java | 70 +++--- src/main/java/io/supertokens/totp/Totp.java | 200 +++++++++++++++--- .../totp/exceptions/InvalidTotpException.java | 5 + .../exceptions/LimitReachedException.java | 5 + 5 files changed, 223 insertions(+), 67 deletions(-) create mode 100644 src/main/java/io/supertokens/totp/exceptions/InvalidTotpException.java create mode 100644 src/main/java/io/supertokens/totp/exceptions/LimitReachedException.java diff --git a/src/main/java/io/supertokens/inmemorydb/Start.java b/src/main/java/io/supertokens/inmemorydb/Start.java index 706062415..657d2fd5c 100644 --- a/src/main/java/io/supertokens/inmemorydb/Start.java +++ b/src/main/java/io/supertokens/inmemorydb/Start.java @@ -1706,14 +1706,14 @@ public TOTPDevice[] getDevices(String userId) public boolean insertUsedCode(TOTPUsedCode usedCodeObj) throws StorageQueryException, TotpNotEnabledException { try { + TOTPDevice[] devices = TOTPQueries.getDevices(this, usedCodeObj.userId); + if (devices.length == 0) { + throw new TotpNotEnabledException(); + } + int insertCount = TOTPQueries.insertUsedCode(this, usedCodeObj); return insertCount == 1; } catch (Exception e) { - // FIXME: Not working without `PRAGMA foreign_keys = ON;` but unable to setup it in tests. - if (e.getMessage() - .equals("[SQLITE_CONSTRAINT] Abort due to constraint violation (FOREIGN KEY constraint failed)")) { - throw new TotpNotEnabledException(); - } throw new StorageQueryException(e); } } diff --git a/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java b/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java index 28875d9c2..8e5eb2ac2 100644 --- a/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java +++ b/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java @@ -98,10 +98,44 @@ public static TOTPDevice[] getDevices(Start start, String userId) devices.add(TOTPDeviceRowMapper.getInstance().map(result)); } - return devices.toArray(new TOTPDevice[0]); + return devices.toArray(TOTPDevice[]::new); }); } + public static int insertUsedCode(Start start, TOTPUsedCode code) throws SQLException, StorageQueryException { + String QUERY = "INSERT INTO " + Config.getConfig(start).getTotpUsedCodesTable() + + " (user_id, code, is_valid_code, expiry_time) VALUES (?, ?, ?, ?);"; + + return update(start, QUERY, pst -> { + pst.setString(1, code.userId); + pst.setString(2, code.code); + pst.setBoolean(3, code.isValidCode); + pst.setLong(4, code.expiryTime); + }); + } + + public static TOTPUsedCode[] getUsedCodes(Start start, String userId) throws SQLException, StorageQueryException { + String QUERY = "SELECT * FROM " + + Config.getConfig(start).getTotpUsedCodesTable() + + " WHERE user_id = ?;"; + return execute(start, QUERY, pst -> pst.setString(1, userId), result -> { + List codes = new ArrayList<>(); + while (result.next()) { + codes.add(TOTPUsedCodeRowMapper.getInstance().map(result)); + } + + return codes.toArray(TOTPUsedCode[]::new); + }); + } + + public static int removeExpiredCodes(Start start) + throws StorageTransactionLogicException, StorageQueryException, SQLException { + String QUERY = "DELETE FROM " + Config.getConfig(start).getTotpUsedCodesTable() + + " WHERE expiry_time < ?;"; + + return update(start, QUERY, pst -> pst.setLong(1, System.currentTimeMillis())); + } + private static class TOTPDeviceRowMapper implements RowMapper { private static final TOTPDeviceRowMapper INSTANCE = new TOTPDeviceRowMapper(); @@ -143,38 +177,4 @@ public TOTPUsedCode map(ResultSet result) throws SQLException { result.getLong("expiry_time")); } } - - public static int insertUsedCode(Start start, TOTPUsedCode code) throws SQLException, StorageQueryException { - String QUERY = "INSERT INTO " + Config.getConfig(start).getTotpUsedCodesTable() - + " (user_id, code, is_valid_code, expiry_time) VALUES (?, ?, ?, ?);"; - - return update(start, QUERY, pst -> { - pst.setString(1, code.userId); - pst.setString(2, code.code); - pst.setBoolean(3, code.isValidCode); - pst.setLong(4, code.expiryTime); - }); - } - - public static TOTPUsedCode[] getUsedCodes(Start start, String userId) throws SQLException, StorageQueryException { - String QUERY = "SELECT * FROM " + - Config.getConfig(start).getTotpUsedCodesTable() - + " WHERE user_id = ?;"; - return execute(start, QUERY, pst -> pst.setString(1, userId), result -> { - List codes = new ArrayList<>(); - while (result.next()) { - codes.add(TOTPUsedCodeRowMapper.getInstance().map(result)); - } - - return codes.toArray(new TOTPUsedCode[0]); - }); - } - - public static void removeExpiredCodes(Start start) - throws StorageTransactionLogicException, StorageQueryException, SQLException { - String QUERY = "DELETE FROM " + Config.getConfig(start).getTotpUsedCodesTable() - + " WHERE expiry_time < ?;"; - - update(start, QUERY, pst -> pst.setLong(1, System.currentTimeMillis())); - } } diff --git a/src/main/java/io/supertokens/totp/Totp.java b/src/main/java/io/supertokens/totp/Totp.java index 48cc65080..f8145b676 100644 --- a/src/main/java/io/supertokens/totp/Totp.java +++ b/src/main/java/io/supertokens/totp/Totp.java @@ -1,63 +1,209 @@ package io.supertokens.totp; -import java.io.IOException; - import io.supertokens.Main; +import io.supertokens.pluginInterface.exceptions.StorageQueryException; import io.supertokens.pluginInterface.totp.TOTPDevice; +import io.supertokens.pluginInterface.totp.TOTPUsedCode; +import io.supertokens.pluginInterface.totp.exception.DeviceAlreadyExistsException; +import io.supertokens.pluginInterface.totp.exception.TotpNotEnabledException; +import io.supertokens.pluginInterface.totp.exception.UnknownDeviceException; import io.supertokens.pluginInterface.totp.sqlStorage.TOTPSQLStorage; import io.supertokens.storageLayer.StorageLayer; +import io.supertokens.totp.exceptions.InvalidTotpException; +import io.supertokens.totp.exceptions.LimitReachedException; public class Totp { + public static String generateSecret() { + return "XXXX"; + } + + public static boolean checkCode(TOTPDevice device, String code) { + return true; + } + public static CreateDeviceResponse createDevice(Main main, String userId, String deviceName, int skew, int period) - throws IOException { + throws StorageQueryException, DeviceAlreadyExistsException { TOTPSQLStorage totpStorage = StorageLayer.getTOTPStorage(main); - String secret = GenerateDeviceSecret.generate(); + String secret = generateSecret(); + + TOTPDevice device = new TOTPDevice(userId, deviceName, secret, skew, period, false); + totpStorage.createDevice(device); - if (userId == null || deviceName == null || secret == null) { - throw new IllegalArgumentException("userId, deviceName and secret cannot be null"); + // TODO: Should we just return the secret as a string? + return new CreateDeviceResponse(secret); + } + + public static VerifyDeviceResponse verifyDevice(Main main, String userId, String deviceName, String code) + throws StorageQueryException, TotpNotEnabledException, UnknownDeviceException, InvalidTotpException { + TOTPSQLStorage totpStorage = StorageLayer.getTOTPStorage(main); + + TOTPDevice[] devices = totpStorage.getDevices(userId); + if (devices.length == 0) { + throw new TotpNotEnabledException(); } - try { - TOTPDevice device = new TOTPDevice(userId, deviceName, secret, skew, period, false); - totpStorage.createDevice(device); - } catch (Exception e) { - throw new IOException(e); + // Find the device: + TOTPDevice matchingDevice = null; + for (TOTPDevice device : devices) { + if (device.deviceName.equals(deviceName)) { + if (device.verified) { + // TODO: Should we just return a boolean here? + return new VerifyDeviceResponse(true); + } else { + matchingDevice = device; + break; + } + } + } + if (matchingDevice == null) { + throw new UnknownDeviceException(); + } + + // // Insert the code into the list of used codes: + TOTPUsedCode newCode = new TOTPUsedCode(userId, code, true, System.currentTimeMillis() + 1000 * 60 * 5); + totpStorage.insertUsedCode(newCode); + + // Check if the code is valid: + if (!checkCode(matchingDevice, code)) { + throw new InvalidTotpException(); + } + + // Check if the code is unused: + TOTPUsedCode[] usedCodes = totpStorage.getUsedCodes(userId); + for (TOTPUsedCode usedCode : usedCodes) { + if (usedCode.code.equals(code) && usedCode.isValidCode) { + throw new InvalidTotpException(); + } + } + + totpStorage.markDeviceAsVerified(userId, deviceName); + return new VerifyDeviceResponse(false); + } + + public static void verifyCode(Main main, String userId, String code, boolean allowUnverifiedDevices) + throws StorageQueryException, TotpNotEnabledException, InvalidTotpException, + LimitReachedException { + TOTPSQLStorage totpStorage = StorageLayer.getTOTPStorage(main); + + // Check if the user has any devices: + TOTPDevice[] devices = totpStorage.getDevices(userId); + if (devices.length == 0) { + throw new TotpNotEnabledException(); + } + + // FIXME: Every unused_code should be linked to the device it was used for. + // Otherwise, if a user has multiple devices, and they use the same code for + // both, + // then the code will be considered as used for both devices. This could cause + // UX + // issues. + // If we do this, then it also means that we need to assign a device ID to each + // device (OR use + // (userId, deviceName) as the ID) + + // Check if the code has been successfully used by the user (for any device): + TOTPUsedCode[] usedCodes = totpStorage.getUsedCodes(userId); + for (TOTPUsedCode usedCode : usedCodes) { + if (usedCode.code.equals(code) && usedCode.isValidCode) { + throw new InvalidTotpException(); + } + } + + // Try different devices until we find one that works: + boolean isValid = false; + for (TOTPDevice device : devices) { + // Check if the code is valid for this device: + if (checkCode(device, code)) { + isValid = true; + break; + } + } + + // Insert the code into the list of used codes: + TOTPUsedCode newCode = new TOTPUsedCode(userId, code, isValid, System.currentTimeMillis() + 1000 * 60 * 5); + totpStorage.insertUsedCode(newCode); + + if (isValid) { + return; + } + + // Check if last 5 codes are all invalid: + int WINDOW_SIZE = 5; + int invalidCodes = 0; + for (int i = usedCodes.length - 1; i >= 0 && i >= usedCodes.length - WINDOW_SIZE; i--) { + if (!usedCodes[i].isValidCode) { + invalidCodes++; + } + } + if (invalidCodes == WINDOW_SIZE) { + throw new LimitReachedException(); } - return new CreateDeviceResponse("deviceName", secret); + // Code is invalid and the user has not exceeded the limit: + throw new InvalidTotpException(); } - public static void markDeviceAsVerified(Main main, String userId, String deviceName) throws IOException { + public static void deleteDevice(Main main, String userId, String deviceName) + throws StorageQueryException, UnknownDeviceException, TotpNotEnabledException { TOTPSQLStorage totpStorage = StorageLayer.getTOTPStorage(main); + try { - totpStorage.markDeviceAsVerified(userId, deviceName); - } catch (Exception e) { - throw new IOException(e); + totpStorage.deleteDevice(userId, deviceName); + } catch (UnknownDeviceException e) { + // See if any device exists for the user: + TOTPDevice[] devices = totpStorage.getDevices(userId); + if (devices.length == 0) { + throw new TotpNotEnabledException(); + } else { + throw e; + } } } - private static class GenerateDeviceSecret { - // private final String secret; + public static void updateDeviceName(Main main, String userId, String oldDeviceName, String newDeviceName) + throws StorageQueryException, DeviceAlreadyExistsException, UnknownDeviceException, + TotpNotEnabledException { + TOTPSQLStorage totpStorage = StorageLayer.getTOTPStorage(main); + try { + totpStorage.updateDeviceName(userId, oldDeviceName, newDeviceName); + } catch (UnknownDeviceException e) { + // See if any device exists for the user: + TOTPDevice[] devices = totpStorage.getDevices(userId); + if (devices.length == 0) { + throw new TotpNotEnabledException(); + } else { + throw e; + } + } + } - // private GenerateDeviceSecret(String secret) { - // this.secret = secret; - // } + public static TOTPDevice[] getDevices(Main main, String userId) + throws StorageQueryException, TotpNotEnabledException { + TOTPSQLStorage totpStorage = StorageLayer.getTOTPStorage(main); - public static String generate() { - return "XXXX"; + TOTPDevice[] devices = totpStorage.getDevices(userId); + if (devices.length == 0) { + throw new TotpNotEnabledException(); } + return devices; } public static class CreateDeviceResponse { - public String deviceName; - public String secret; + public final String secret; - public CreateDeviceResponse(String deviceName, String secret) { - this.deviceName = deviceName; + public CreateDeviceResponse(String secret) { this.secret = secret; } + } + + public static class VerifyDeviceResponse { + public final boolean deviceWasAlreadyVerified; + public VerifyDeviceResponse(boolean deviceWasAlreadyVerified) { + this.deviceWasAlreadyVerified = deviceWasAlreadyVerified; + } } + } diff --git a/src/main/java/io/supertokens/totp/exceptions/InvalidTotpException.java b/src/main/java/io/supertokens/totp/exceptions/InvalidTotpException.java new file mode 100644 index 000000000..9dce2f51d --- /dev/null +++ b/src/main/java/io/supertokens/totp/exceptions/InvalidTotpException.java @@ -0,0 +1,5 @@ +package io.supertokens.totp.exceptions; + +public class InvalidTotpException extends Exception { + +} diff --git a/src/main/java/io/supertokens/totp/exceptions/LimitReachedException.java b/src/main/java/io/supertokens/totp/exceptions/LimitReachedException.java new file mode 100644 index 000000000..0da203afe --- /dev/null +++ b/src/main/java/io/supertokens/totp/exceptions/LimitReachedException.java @@ -0,0 +1,5 @@ +package io.supertokens.totp.exceptions; + +public class LimitReachedException extends Exception { + +} From d5551b64cbe10588f59e3d6332cc4aa58b9b2e6d Mon Sep 17 00:00:00 2001 From: KShivendu Date: Thu, 16 Feb 2023 18:38:31 +0530 Subject: [PATCH 06/42] feat: Improvemnts in TOTP in memory implementation --- .../java/io/supertokens/inmemorydb/Start.java | 7 +++ .../test/totp/TOTPStorageTest.java | 60 ++++++++----------- 2 files changed, 33 insertions(+), 34 deletions(-) diff --git a/src/main/java/io/supertokens/inmemorydb/Start.java b/src/main/java/io/supertokens/inmemorydb/Start.java index 657d2fd5c..61cab4764 100644 --- a/src/main/java/io/supertokens/inmemorydb/Start.java +++ b/src/main/java/io/supertokens/inmemorydb/Start.java @@ -1667,6 +1667,13 @@ public void deleteDevice(String userId, String deviceName) if (deletedCount == 0) { throw new UnknownDeviceException(); } + + // FIXME: Should delete all the codes associated with this device. + // But problem is that we store all the codes whether they are valid (i.e. matching device found) or not. + // So we can't just delete all the codes for the user. We need to delete only the codes that are associated + // with this device. But we don't store that information. + + // One way is only delete the valid codes. } catch (SQLException e) { throw new StorageQueryException(e); } diff --git a/src/test/java/io/supertokens/test/totp/TOTPStorageTest.java b/src/test/java/io/supertokens/test/totp/TOTPStorageTest.java index 9f1d5ba3c..32e38ce5e 100644 --- a/src/test/java/io/supertokens/test/totp/TOTPStorageTest.java +++ b/src/test/java/io/supertokens/test/totp/TOTPStorageTest.java @@ -1,6 +1,7 @@ package io.supertokens.test.totp; import static org.junit.Assert.assertNotNull; // Not sure about this +import static org.junit.Assert.assertThrows; import org.junit.AfterClass; import org.junit.Before; @@ -19,6 +20,8 @@ import io.supertokens.pluginInterface.totp.TOTPStorage; import io.supertokens.pluginInterface.totp.TOTPUsedCode; import io.supertokens.pluginInterface.totp.exception.DeviceAlreadyExistsException; +import io.supertokens.pluginInterface.totp.exception.TotpNotEnabledException; +import io.supertokens.pluginInterface.totp.exception.UnknownDeviceException; public class TOTPStorageTest { @@ -84,12 +87,7 @@ public void createDeviceTests() throws Exception { storedDevices = storage.getDevices("user"); assert (storedDevices.length == 2); - try { - storage.createDevice(device2Duplicate); - assert (false); - } catch (DeviceAlreadyExistsException e) { - assert (true); - } + assertThrows(DeviceAlreadyExistsException.class, () -> storage.createDevice(device2Duplicate)); result.process.kill(); assertNotNull(result.process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); @@ -115,12 +113,7 @@ public void verifyDeviceTests() throws Exception { assert (storedDevices[0].verified); // Try to verify a device that doesn't exist: - try { - storage.markDeviceAsVerified("user", "non-existent-device"); - assert (false); - } catch (Exception e) { - assert (true); - } + assertThrows(UnknownDeviceException.class, () -> storage.markDeviceAsVerified("user", "non-existent-device")); } @Test @@ -135,12 +128,7 @@ public void deleteDeviceTests() throws Exception { assert (storedDevices.length == 1); // Try to delete a device that doesn't exist: - try { - storage.deleteDevice("user", "non-existent-device"); - assert (false); - } catch (Exception e) { - assert (true); - } + assertThrows(UnknownDeviceException.class, () -> storage.deleteDevice("user", "non-existent-device")); // Delete the device: storage.deleteDevice("user", "device"); @@ -162,19 +150,26 @@ public void updateDeviceNametests() throws Exception { assert (storedDevices[0].deviceName.equals("device")); // Try to update a device that doesn't exist: - try { - storage.updateDeviceName("user", "non-existent-device", "new-device-name"); - assert (false); - } catch (Exception e) { - assert (true); - } + assertThrows(UnknownDeviceException.class, + () -> storage.updateDeviceName("user", "non-existent-device", "new-device-name")); // Update the device name: - storage.updateDeviceName("user", "device", "new-device-name"); + storage.updateDeviceName("user", "device", "updated-device-name"); storedDevices = storage.getDevices("user"); assert (storedDevices.length == 1); - assert (storedDevices[0].deviceName.equals("new-device-name")); + assert (storedDevices[0].deviceName.equals("updated-device-name")); + + // Try to create a new device and rename it to the same name as an existing + // device: + TOTPDevice newDevice = new TOTPDevice("new-device", "user", "secretKey", 30, 1, false); + storage.createDevice(newDevice); + + assertThrows(DeviceAlreadyExistsException.class, + () -> storage.updateDeviceName("user", "new-device", "updated-device-name")); + + // Try to rename the device the same name (Should work at database level): + storage.updateDeviceName("user", "updated-device-name", "updated-device-name"); } @Test @@ -222,15 +217,12 @@ public void insertUsedCodeTest() throws Exception { usedCodes = storage.getUsedCodes("user"); assert (usedCodes.length == 0); + // Need to run `PRAGMA foreign_keys = ON;` then only will throws exception. But + // unable to setup that also in tests. + // Try to insert code when device (userId) doesn't exist: - try { - // Need to run `PRAGMA foreign_keys = ON;` then only will throws exception. But - // unable to setup that also in tests. - storage.insertUsedCode(new TOTPUsedCode("non-existent-user", "1234", true, 1)); - assert (false); - } catch (Exception e) { - assert (true); - } + assertThrows(TotpNotEnabledException.class, + () -> storage.insertUsedCode(new TOTPUsedCode("non-existent-user", "1234", true, 1))); } @Test From 353093970dd7b7cf446b81da34623c940e16eca3 Mon Sep 17 00:00:00 2001 From: KShivendu Date: Fri, 17 Feb 2023 11:45:29 +0530 Subject: [PATCH 07/42] feat: Improve tests and used code handling logic --- src/main/java/io/supertokens/inmemorydb/Start.java | 12 +++++++----- .../inmemorydb/queries/TOTPQueries.java | 2 +- .../io/supertokens/test/totp/TOTPStorageTest.java | 14 ++++---------- 3 files changed, 12 insertions(+), 16 deletions(-) diff --git a/src/main/java/io/supertokens/inmemorydb/Start.java b/src/main/java/io/supertokens/inmemorydb/Start.java index 61cab4764..68e1ce15f 100644 --- a/src/main/java/io/supertokens/inmemorydb/Start.java +++ b/src/main/java/io/supertokens/inmemorydb/Start.java @@ -1669,11 +1669,13 @@ public void deleteDevice(String userId, String deviceName) } // FIXME: Should delete all the codes associated with this device. - // But problem is that we store all the codes whether they are valid (i.e. matching device found) or not. - // So we can't just delete all the codes for the user. We need to delete only the codes that are associated + // But problem is that we store all the codes whether they are valid (i.e. + // matching device found) or not. + // So we can't just delete all the codes for the user. We need to delete only + // the codes that are associated // with this device. But we don't store that information. - // One way is only delete the valid codes. + // One way is only delete the valid codes? } catch (SQLException e) { throw new StorageQueryException(e); } @@ -1720,7 +1722,7 @@ public boolean insertUsedCode(TOTPUsedCode usedCodeObj) int insertCount = TOTPQueries.insertUsedCode(this, usedCodeObj); return insertCount == 1; - } catch (Exception e) { + } catch (SQLException e) { throw new StorageQueryException(e); } } @@ -1740,7 +1742,7 @@ public void removeExpiredCodes() throws StorageQueryException { try { TOTPQueries.removeExpiredCodes(this); - } catch (Exception e) { + } catch (SQLException e) { throw new StorageQueryException(e); } } diff --git a/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java b/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java index 8e5eb2ac2..7331d8aa5 100644 --- a/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java +++ b/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java @@ -129,7 +129,7 @@ public static TOTPUsedCode[] getUsedCodes(Start start, String userId) throws SQL } public static int removeExpiredCodes(Start start) - throws StorageTransactionLogicException, StorageQueryException, SQLException { + throws StorageQueryException, SQLException { String QUERY = "DELETE FROM " + Config.getConfig(start).getTotpUsedCodesTable() + " WHERE expiry_time < ?;"; diff --git a/src/test/java/io/supertokens/test/totp/TOTPStorageTest.java b/src/test/java/io/supertokens/test/totp/TOTPStorageTest.java index 32e38ce5e..a92b6226e 100644 --- a/src/test/java/io/supertokens/test/totp/TOTPStorageTest.java +++ b/src/test/java/io/supertokens/test/totp/TOTPStorageTest.java @@ -209,20 +209,14 @@ public void insertUsedCodeTest() throws Exception { assert (usedCodes[0].isValidCode); assert (usedCodes[0].expiryTime == 1); - // FIXME: Next two features aren't working because foreign key constraint is not - // working in tests: + // Try to insert code when device (userId) doesn't exist: + assertThrows(TotpNotEnabledException.class, + () -> storage.insertUsedCode(new TOTPUsedCode("non-existent-user", "1234", true, 1))); - // Deleting the device should delete the used codes: + // FIXME: Deleting the device should delete the used codes storage.deleteDevice("user", "device"); usedCodes = storage.getUsedCodes("user"); assert (usedCodes.length == 0); - - // Need to run `PRAGMA foreign_keys = ON;` then only will throws exception. But - // unable to setup that also in tests. - - // Try to insert code when device (userId) doesn't exist: - assertThrows(TotpNotEnabledException.class, - () -> storage.insertUsedCode(new TOTPUsedCode("non-existent-user", "1234", true, 1))); } @Test From 0f10e8bbf580fbf0ddfa73ae6130097825a624e2 Mon Sep 17 00:00:00 2001 From: KShivendu Date: Fri, 17 Feb 2023 11:58:59 +0530 Subject: [PATCH 08/42] feat: Improve TOTP inmemorydb queries --- .../inmemorydb/queries/TOTPQueries.java | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java b/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java index 7331d8aa5..eda075faa 100644 --- a/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java +++ b/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java @@ -9,7 +9,6 @@ import io.supertokens.inmemorydb.config.Config; import io.supertokens.pluginInterface.RowMapper; import io.supertokens.pluginInterface.exceptions.StorageQueryException; -import io.supertokens.pluginInterface.exceptions.StorageTransactionLogicException; import io.supertokens.pluginInterface.totp.TOTPDevice; import io.supertokens.pluginInterface.totp.TOTPUsedCode; @@ -28,15 +27,15 @@ public static String getQueryToCreateUserDevicesTable(Start start) { public static String getQueryToCreateUsedCodesTable(Start start) { return "CREATE TABLE IF NOT EXISTS " + Config.getConfig(start).getTotpUsedCodesTable() + " (" - + "user_id VARCHAR(128) NOT NULL," + "code VARCHAR(6) NOT NULL," + "is_valid_code BOOLEAN NOT NULL," - + "expiry_time BIGINT NOT NULL," + + "user_id VARCHAR(128) NOT NULL," + "code CHAR(6) NOT NULL," + "is_valid_code BOOLEAN NOT NULL," + + "expiry_time_ms BIGINT UNSIGNED NOT NULL," + "FOREIGN KEY (user_id) REFERENCES " + Config.getConfig(start).getTotpUserDevicesTable() + "(user_id) ON DELETE CASCADE)"; } public static String getQueryToCreateUsedCodesIndex(Start start) { - return "CREATE INDEX IF NOT EXISTS totp_used_codes_expiry_time_index ON " - + Config.getConfig(start).getTotpUsedCodesTable() + " (expiry_time)"; + return "CREATE INDEX IF NOT EXISTS totp_used_codes_expiry_time_ms_index ON " + + Config.getConfig(start).getTotpUsedCodesTable() + " (expiry_time_ms)"; } public static void createDevice(Start start, TOTPDevice device) @@ -104,7 +103,7 @@ public static TOTPDevice[] getDevices(Start start, String userId) public static int insertUsedCode(Start start, TOTPUsedCode code) throws SQLException, StorageQueryException { String QUERY = "INSERT INTO " + Config.getConfig(start).getTotpUsedCodesTable() - + " (user_id, code, is_valid_code, expiry_time) VALUES (?, ?, ?, ?);"; + + " (user_id, code, is_valid_code, expiry_time_ms) VALUES (?, ?, ?, ?);"; return update(start, QUERY, pst -> { pst.setString(1, code.userId); @@ -131,7 +130,7 @@ public static TOTPUsedCode[] getUsedCodes(Start start, String userId) throws SQL public static int removeExpiredCodes(Start start) throws StorageQueryException, SQLException { String QUERY = "DELETE FROM " + Config.getConfig(start).getTotpUsedCodesTable() - + " WHERE expiry_time < ?;"; + + " WHERE expiry_time_ms < ?;"; return update(start, QUERY, pst -> pst.setLong(1, System.currentTimeMillis())); } @@ -174,7 +173,7 @@ public TOTPUsedCode map(ResultSet result) throws SQLException { result.getString("user_id"), result.getString("code"), result.getBoolean("is_valid_code"), - result.getLong("expiry_time")); + result.getLong("expiry_time_ms")); } } } From 074ddfc12a3c8cc4d8d14bb8fa5bfcfeadb4b4b9 Mon Sep 17 00:00:00 2001 From: KShivendu Date: Mon, 20 Feb 2023 14:24:19 +0530 Subject: [PATCH 09/42] refactor: Use compound foreign key in totp_used_codes table and fix order of init params in TOTPDevice --- .../java/io/supertokens/inmemorydb/Start.java | 11 +- .../inmemorydb/queries/TOTPQueries.java | 32 +++-- src/main/java/io/supertokens/totp/Totp.java | 66 ++++------ .../test/totp/TOTPDevicesTest.java | 76 ----------- .../supertokens/test/totp/TOTPRecipeTest.java | 124 ++++++++++++++++++ .../test/totp/TOTPStorageTest.java | 26 ++-- 6 files changed, 192 insertions(+), 143 deletions(-) delete mode 100644 src/test/java/io/supertokens/test/totp/TOTPDevicesTest.java create mode 100644 src/test/java/io/supertokens/test/totp/TOTPRecipeTest.java diff --git a/src/main/java/io/supertokens/inmemorydb/Start.java b/src/main/java/io/supertokens/inmemorydb/Start.java index 68e1ce15f..6405db34e 100644 --- a/src/main/java/io/supertokens/inmemorydb/Start.java +++ b/src/main/java/io/supertokens/inmemorydb/Start.java @@ -1668,14 +1668,9 @@ public void deleteDevice(String userId, String deviceName) throw new UnknownDeviceException(); } - // FIXME: Should delete all the codes associated with this device. - // But problem is that we store all the codes whether they are valid (i.e. - // matching device found) or not. - // So we can't just delete all the codes for the user. We need to delete only - // the codes that are associated - // with this device. But we don't store that information. - - // One way is only delete the valid codes? + // Note: This step is only required for in-memory databases. + // They don't have cascading deletes, so we need to manually delete the codes + TOTPQueries.removeUsedCodesForUser(this, userId, deviceName); } catch (SQLException e) { throw new StorageQueryException(e); } diff --git a/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java b/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java index eda075faa..08b024a66 100644 --- a/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java +++ b/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java @@ -1,5 +1,6 @@ package io.supertokens.inmemorydb.queries; +import java.sql.Connection; import java.sql.ResultSet; import java.sql.SQLException; import java.util.ArrayList; @@ -9,9 +10,10 @@ import io.supertokens.inmemorydb.config.Config; import io.supertokens.pluginInterface.RowMapper; import io.supertokens.pluginInterface.exceptions.StorageQueryException; - +import io.supertokens.pluginInterface.exceptions.StorageTransactionLogicException; import io.supertokens.pluginInterface.totp.TOTPDevice; import io.supertokens.pluginInterface.totp.TOTPUsedCode; +import jakarta.annotation.Nullable; import static io.supertokens.inmemorydb.QueryExecutorTemplate.execute; import static io.supertokens.inmemorydb.QueryExecutorTemplate.update; @@ -27,10 +29,11 @@ public static String getQueryToCreateUserDevicesTable(Start start) { public static String getQueryToCreateUsedCodesTable(Start start) { return "CREATE TABLE IF NOT EXISTS " + Config.getConfig(start).getTotpUsedCodesTable() + " (" - + "user_id VARCHAR(128) NOT NULL," + "code CHAR(6) NOT NULL," + "is_valid_code BOOLEAN NOT NULL," + + "user_id VARCHAR(128) NOT NULL, " + "device_name VARCHAR(256), " + + "code CHAR(6) NOT NULL," + "is_valid_code BOOLEAN NOT NULL," + "expiry_time_ms BIGINT UNSIGNED NOT NULL," - + "FOREIGN KEY (user_id) REFERENCES " + Config.getConfig(start).getTotpUserDevicesTable() - + "(user_id) ON DELETE CASCADE)"; + + "FOREIGN KEY (user_id, device_name) REFERENCES " + Config.getConfig(start).getTotpUserDevicesTable() + + "(user_id, device_name) ON DELETE CASCADE);"; } public static String getQueryToCreateUsedCodesIndex(Start start) { @@ -41,11 +44,11 @@ public static String getQueryToCreateUsedCodesIndex(Start start) { public static void createDevice(Start start, TOTPDevice device) throws StorageQueryException, SQLException { String QUERY = "INSERT INTO " + Config.getConfig(start).getTotpUserDevicesTable() - + " (device_name, user_id, secret_key, period, skew, verified) VALUES (?, ?, ?, ?, ?, ?)"; + + " (user_id, device_name, secret_key, period, skew, verified) VALUES (?, ?, ?, ?, ?, ?)"; update(start, QUERY, pst -> { - pst.setString(1, device.deviceName); - pst.setString(2, device.userId); + pst.setString(1, device.userId); + pst.setString(2, device.deviceName); pst.setString(3, device.secretKey); pst.setInt(4, device.period); pst.setInt(5, device.skew); @@ -135,6 +138,19 @@ public static int removeExpiredCodes(Start start) return update(start, QUERY, pst -> pst.setLong(1, System.currentTimeMillis())); } + public static int removeUsedCodesForUser(Start start, String userId, String deviceName) + throws StorageQueryException, SQLException { + // Remove codes where userId matches the given userId + // ONLY required for inmemorydb, as it does not support foreign key constraints. + String QUERY = "DELETE FROM " + Config.getConfig(start).getTotpUsedCodesTable() + + " WHERE user_id = ? AND device_name = ?;"; + + return update(start, QUERY, pst -> { + pst.setString(1, userId); + pst.setString(2, deviceName); + }); + } + private static class TOTPDeviceRowMapper implements RowMapper { private static final TOTPDeviceRowMapper INSTANCE = new TOTPDeviceRowMapper(); @@ -148,8 +164,8 @@ private static TOTPDeviceRowMapper getInstance() { @Override public TOTPDevice map(ResultSet result) throws SQLException { return new TOTPDevice( - result.getString("device_name"), result.getString("user_id"), + result.getString("device_name"), result.getString("secret_key"), result.getInt("period"), result.getInt("skew"), diff --git a/src/main/java/io/supertokens/totp/Totp.java b/src/main/java/io/supertokens/totp/Totp.java index f8145b676..f06ccab81 100644 --- a/src/main/java/io/supertokens/totp/Totp.java +++ b/src/main/java/io/supertokens/totp/Totp.java @@ -19,24 +19,30 @@ public static String generateSecret() { } public static boolean checkCode(TOTPDevice device, String code) { - return true; + if (code.startsWith("XXXX")) { + return true; + } + return false; } - public static CreateDeviceResponse createDevice(Main main, String userId, String deviceName, int skew, int period) + public static String createDevice(Main main, String userId, String deviceName, int skew, int period) throws StorageQueryException, DeviceAlreadyExistsException { TOTPSQLStorage totpStorage = StorageLayer.getTOTPStorage(main); - String secret = generateSecret(); - TOTPDevice device = new TOTPDevice(userId, deviceName, secret, skew, period, false); + String secret = generateSecret(); + TOTPDevice device = new TOTPDevice(userId, deviceName, secret, period, skew, false); totpStorage.createDevice(device); - // TODO: Should we just return the secret as a string? - return new CreateDeviceResponse(secret); + return secret; } - public static VerifyDeviceResponse verifyDevice(Main main, String userId, String deviceName, String code) + public static boolean verifyDevice(Main main, String userId, String deviceName, String code) throws StorageQueryException, TotpNotEnabledException, UnknownDeviceException, InvalidTotpException { + // Here boolean return value tells whether the device was already verified + + boolean deviceAlreadyVerified = false; + TOTPSQLStorage totpStorage = StorageLayer.getTOTPStorage(main); TOTPDevice[] devices = totpStorage.getDevices(userId); @@ -48,20 +54,24 @@ public static VerifyDeviceResponse verifyDevice(Main main, String userId, String TOTPDevice matchingDevice = null; for (TOTPDevice device : devices) { if (device.deviceName.equals(deviceName)) { - if (device.verified) { - // TODO: Should we just return a boolean here? - return new VerifyDeviceResponse(true); - } else { - matchingDevice = device; - break; - } + deviceAlreadyVerified = device.verified; + matchingDevice = device; + break; } } if (matchingDevice == null) { throw new UnknownDeviceException(); } - // // Insert the code into the list of used codes: + // Check if the code is unused: + TOTPUsedCode[] usedCodes = totpStorage.getUsedCodes(userId); + for (TOTPUsedCode usedCode : usedCodes) { + if (usedCode.code.equals(code)) { + throw new InvalidTotpException(); + } + } + + // Insert the code into the list of used codes: TOTPUsedCode newCode = new TOTPUsedCode(userId, code, true, System.currentTimeMillis() + 1000 * 60 * 5); totpStorage.insertUsedCode(newCode); @@ -70,16 +80,8 @@ public static VerifyDeviceResponse verifyDevice(Main main, String userId, String throw new InvalidTotpException(); } - // Check if the code is unused: - TOTPUsedCode[] usedCodes = totpStorage.getUsedCodes(userId); - for (TOTPUsedCode usedCode : usedCodes) { - if (usedCode.code.equals(code) && usedCode.isValidCode) { - throw new InvalidTotpException(); - } - } - totpStorage.markDeviceAsVerified(userId, deviceName); - return new VerifyDeviceResponse(false); + return deviceAlreadyVerified; } public static void verifyCode(Main main, String userId, String code, boolean allowUnverifiedDevices) @@ -190,20 +192,4 @@ public static TOTPDevice[] getDevices(Main main, String userId) return devices; } - public static class CreateDeviceResponse { - public final String secret; - - public CreateDeviceResponse(String secret) { - this.secret = secret; - } - } - - public static class VerifyDeviceResponse { - public final boolean deviceWasAlreadyVerified; - - public VerifyDeviceResponse(boolean deviceWasAlreadyVerified) { - this.deviceWasAlreadyVerified = deviceWasAlreadyVerified; - } - } - } diff --git a/src/test/java/io/supertokens/test/totp/TOTPDevicesTest.java b/src/test/java/io/supertokens/test/totp/TOTPDevicesTest.java deleted file mode 100644 index 2ab28a0f0..000000000 --- a/src/test/java/io/supertokens/test/totp/TOTPDevicesTest.java +++ /dev/null @@ -1,76 +0,0 @@ -/* - * Copyright (c) 2021, VRAI Labs and/or its affiliates. All rights reserved. - * - * This software is licensed under the Apache License, Version 2.0 (the - * "License") as published by the Apache Software Foundation. - * - * You may not use this file except in compliance with the License. You may - * obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - */ - -package io.supertokens.test.totp; - -import static org.junit.Assert.assertNotNull; // Not sure about this - -import org.junit.AfterClass; -import org.junit.Before; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TestRule; - -import io.supertokens.test.Utils; -import io.supertokens.ProcessState; -import io.supertokens.pluginInterface.STORAGE_TYPE; -import io.supertokens.storageLayer.StorageLayer; -import io.supertokens.test.TestingProcessManager; - -import io.supertokens.totp.Totp; -import io.supertokens.pluginInterface.totp.TOTPDevice; -import io.supertokens.pluginInterface.totp.TOTPStorage; - -public class TOTPDevicesTest { - - @Rule - public TestRule watchman = Utils.getOnFailure(); - - @AfterClass - public static void afterTesting() { - Utils.afterTesting(); - } - - @Before - public void beforeEach() { - Utils.reset(); - } - - @Test - public void createDeviceWithFullCode() throws Exception { - String[] args = { "../" }; - - TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); - assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); - - if (StorageLayer.getStorage(process.getProcess()).getType() != STORAGE_TYPE.SQL) { - return; - } - - TOTPStorage storage = StorageLayer.getTOTPStorage(process.getProcess()); - - Totp.CreateDeviceResponse createDeviceResponse = Totp.createDevice(process.getProcess(), "userId", "deviceName", - 1, 30); - assertNotNull(createDeviceResponse); - createDeviceResponse.deviceName.equals("deviceName"); - - Totp.markDeviceAsVerified(process.getProcess(), "userId", "deviceName"); - - process.kill(); - assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); - } - -} diff --git a/src/test/java/io/supertokens/test/totp/TOTPRecipeTest.java b/src/test/java/io/supertokens/test/totp/TOTPRecipeTest.java new file mode 100644 index 000000000..91dd5309e --- /dev/null +++ b/src/test/java/io/supertokens/test/totp/TOTPRecipeTest.java @@ -0,0 +1,124 @@ +/* + * Copyright (c) 2021, VRAI Labs and/or its affiliates. All rights reserved. + * + * This software is licensed under the Apache License, Version 2.0 (the + * "License") as published by the Apache Software Foundation. + * + * You may not use this file except in compliance with the License. You may + * obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.supertokens.test.totp; + +import static org.junit.Assert.assertNotNull; // Not sure about this +import static org.junit.Assert.assertThrows; + +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TestRule; + +import io.supertokens.test.Utils; +import io.supertokens.test.totp.TOTPStorageTest.TestSetupResult; +import io.supertokens.Main; +import io.supertokens.ProcessState; +import io.supertokens.pluginInterface.STORAGE_TYPE; +import io.supertokens.storageLayer.StorageLayer; +import io.supertokens.test.TestingProcessManager; + +import io.supertokens.totp.Totp; +import io.supertokens.totp.exceptions.InvalidTotpException; +import io.supertokens.pluginInterface.totp.TOTPDevice; +import io.supertokens.pluginInterface.totp.TOTPStorage; +import io.supertokens.pluginInterface.totp.exception.DeviceAlreadyExistsException; +import io.supertokens.pluginInterface.totp.exception.TotpNotEnabledException; +import io.supertokens.pluginInterface.totp.exception.UnknownDeviceException; + +public class TOTPRecipeTest { + + @Rule + public TestRule watchman = Utils.getOnFailure(); + + @AfterClass + public static void afterTesting() { + Utils.afterTesting(); + } + + @Before + public void beforeEach() { + Utils.reset(); + } + + public class TestSetupResult { + public TOTPStorage storage; + public TestingProcessManager.TestingProcess process; + + public TestSetupResult(TOTPStorage storage, TestingProcessManager.TestingProcess process) { + this.storage = storage; + this.process = process; + } + } + + public TestSetupResult setup() throws InterruptedException { + String[] args = { "../" }; + + TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + + if (StorageLayer.getStorage(process.getProcess()).getType() != STORAGE_TYPE.SQL) { + assert (false); + } + TOTPStorage storage = StorageLayer.getTOTPStorage(process.getProcess()); + + return new TestSetupResult(storage, process); + } + + @Test + public void createAndVerifyDevice() throws Exception { + TestSetupResult result = setup(); + Main main = result.process.getProcess(); + + // Create device + String secret = Totp.createDevice(main, "userId", "deviceName", 1, 30); + assert secret != ""; + + // Create same device again (should fail) + assertThrows(DeviceAlreadyExistsException.class, () -> Totp.createDevice(main, "userId", "deviceName", 1, 30)); + + // Try verify non-existent user: + assertThrows(TotpNotEnabledException.class, + () -> Totp.verifyDevice(main, "non-existent-user", "deviceName", "XXXX")); + + // Try verify non-existent device + assertThrows(UnknownDeviceException.class, + () -> Totp.verifyDevice(main, "userId", "non-existent-device", "XXXX")); + + // Verify device with wrong code + assertThrows(InvalidTotpException.class, () -> Totp.verifyDevice(main, "userId", "deviceName", "wrong-code")); + + // Verify device with correct code + boolean deviceAlreadyVerified = Totp.verifyDevice(main, "userId", "deviceName", "XXXX"); + assert !deviceAlreadyVerified; + + // Verify again with same correct code: + assertThrows(InvalidTotpException.class, () -> Totp.verifyDevice(main, "userId", "deviceName", "XXXX")); + + // Verify again with new correct code: + deviceAlreadyVerified = Totp.verifyDevice(main, "userId", "deviceName", "XXXX-new"); + assert deviceAlreadyVerified; + + // Verify again with wrong code + assertThrows(InvalidTotpException.class, () -> Totp.verifyDevice(main, "userId", "deviceName", "wrong-code")); + + result.process.kill(); + assertNotNull(result.process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); + } + +} diff --git a/src/test/java/io/supertokens/test/totp/TOTPStorageTest.java b/src/test/java/io/supertokens/test/totp/TOTPStorageTest.java index a92b6226e..461e2021e 100644 --- a/src/test/java/io/supertokens/test/totp/TOTPStorageTest.java +++ b/src/test/java/io/supertokens/test/totp/TOTPStorageTest.java @@ -67,9 +67,9 @@ public void createDeviceTests() throws Exception { TestSetupResult result = setup(); TOTPStorage storage = result.storage; - TOTPDevice device1 = new TOTPDevice("d1", "user", "secretKey", 30, 1, false); - TOTPDevice device2 = new TOTPDevice("d2", "user", "secretKey", 30, 1, false); - TOTPDevice device2Duplicate = new TOTPDevice("d2", "user", "secretKey", 30, 1, false); + TOTPDevice device1 = new TOTPDevice("user", "d1", "secretKey", 30, 1, false); + TOTPDevice device2 = new TOTPDevice("user", "d2", "secretKey", 30, 1, false); + TOTPDevice device2Duplicate = new TOTPDevice("user", "d2", "secretKey", 30, 1, false); storage.createDevice(device1); @@ -98,7 +98,7 @@ public void verifyDeviceTests() throws Exception { TestSetupResult result = setup(); TOTPStorage storage = result.storage; - TOTPDevice device = new TOTPDevice("device", "user", "secretKey", 30, 1, false); + TOTPDevice device = new TOTPDevice("user", "device", "secretKey", 30, 1, false); storage.createDevice(device); TOTPDevice[] storedDevices = storage.getDevices("user"); @@ -121,7 +121,7 @@ public void deleteDeviceTests() throws Exception { TestSetupResult result = setup(); TOTPStorage storage = result.storage; - TOTPDevice device = new TOTPDevice("device", "user", "secretKey", 30, 1, false); + TOTPDevice device = new TOTPDevice("user", "device", "secretKey", 30, 1, false); storage.createDevice(device); TOTPDevice[] storedDevices = storage.getDevices("user"); @@ -142,7 +142,7 @@ public void updateDeviceNametests() throws Exception { TestSetupResult result = setup(); TOTPStorage storage = result.storage; - TOTPDevice device = new TOTPDevice("device", "user", "secretKey", 30, 1, false); + TOTPDevice device = new TOTPDevice("user", "device", "secretKey", 30, 1, false); storage.createDevice(device); TOTPDevice[] storedDevices = storage.getDevices("user"); @@ -162,7 +162,7 @@ public void updateDeviceNametests() throws Exception { // Try to create a new device and rename it to the same name as an existing // device: - TOTPDevice newDevice = new TOTPDevice("new-device", "user", "secretKey", 30, 1, false); + TOTPDevice newDevice = new TOTPDevice("user", "new-device", "secretKey", 30, 1, false); storage.createDevice(newDevice); assertThrows(DeviceAlreadyExistsException.class, @@ -177,8 +177,8 @@ public void getDevicesTest() throws Exception { TestSetupResult result = setup(); TOTPStorage storage = result.storage; - TOTPDevice device1 = new TOTPDevice("d1", "user", "secretKey", 30, 1, false); - TOTPDevice device2 = new TOTPDevice("d2", "user", "secretKey", 30, 1, false); + TOTPDevice device1 = new TOTPDevice("user", "d1", "secretKey", 30, 1, false); + TOTPDevice device2 = new TOTPDevice("user", "d2", "secretKey", 30, 1, false); storage.createDevice(device1); storage.createDevice(device2); @@ -188,6 +188,9 @@ public void getDevicesTest() throws Exception { assert (storedDevices.length == 2); assert (storedDevices[0].deviceName.equals("d1")); assert (storedDevices[1].deviceName.equals("d2")); + + storedDevices = storage.getDevices("non-existent-user"); + assert (storedDevices.length == 0); } @Test @@ -195,7 +198,7 @@ public void insertUsedCodeTest() throws Exception { TestSetupResult result = setup(); TOTPStorage storage = result.storage; - TOTPDevice device = new TOTPDevice("device", "user", "secretKey", 30, 1, false); + TOTPDevice device = new TOTPDevice("user", "device", "secretKey", 30, 1, false); TOTPUsedCode code = new TOTPUsedCode("user", "1234", true, 1); storage.createDevice(device); @@ -213,7 +216,8 @@ public void insertUsedCodeTest() throws Exception { assertThrows(TotpNotEnabledException.class, () -> storage.insertUsedCode(new TOTPUsedCode("non-existent-user", "1234", true, 1))); - // FIXME: Deleting the device should delete the used codes + // FIXME: Deleting last device of the user should delete all used codes of the + // user storage.deleteDevice("user", "device"); usedCodes = storage.getUsedCodes("user"); assert (usedCodes.length == 0); From dc3b14311d08de61adeb09ac47cfa61f30f71501 Mon Sep 17 00:00:00 2001 From: KShivendu Date: Mon, 20 Feb 2023 16:05:28 +0530 Subject: [PATCH 10/42] fix: Remove related used codes when a user device is being deleted --- .../java/io/supertokens/inmemorydb/Start.java | 2 +- .../inmemorydb/queries/TOTPQueries.java | 17 ++++---- src/main/java/io/supertokens/totp/Totp.java | 6 ++- .../test/totp/TOTPStorageTest.java | 41 ++++++++++++------- 4 files changed, 40 insertions(+), 26 deletions(-) diff --git a/src/main/java/io/supertokens/inmemorydb/Start.java b/src/main/java/io/supertokens/inmemorydb/Start.java index 6405db34e..d3855de3f 100644 --- a/src/main/java/io/supertokens/inmemorydb/Start.java +++ b/src/main/java/io/supertokens/inmemorydb/Start.java @@ -1670,7 +1670,7 @@ public void deleteDevice(String userId, String deviceName) // Note: This step is only required for in-memory databases. // They don't have cascading deletes, so we need to manually delete the codes - TOTPQueries.removeUsedCodesForUser(this, userId, deviceName); + TOTPQueries.removeUsedCodes(this, userId, deviceName); } catch (SQLException e) { throw new StorageQueryException(e); } diff --git a/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java b/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java index 08b024a66..53a4fa1eb 100644 --- a/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java +++ b/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java @@ -1,6 +1,5 @@ package io.supertokens.inmemorydb.queries; -import java.sql.Connection; import java.sql.ResultSet; import java.sql.SQLException; import java.util.ArrayList; @@ -10,10 +9,8 @@ import io.supertokens.inmemorydb.config.Config; import io.supertokens.pluginInterface.RowMapper; import io.supertokens.pluginInterface.exceptions.StorageQueryException; -import io.supertokens.pluginInterface.exceptions.StorageTransactionLogicException; import io.supertokens.pluginInterface.totp.TOTPDevice; import io.supertokens.pluginInterface.totp.TOTPUsedCode; -import jakarta.annotation.Nullable; import static io.supertokens.inmemorydb.QueryExecutorTemplate.execute; import static io.supertokens.inmemorydb.QueryExecutorTemplate.update; @@ -31,7 +28,7 @@ public static String getQueryToCreateUsedCodesTable(Start start) { return "CREATE TABLE IF NOT EXISTS " + Config.getConfig(start).getTotpUsedCodesTable() + " (" + "user_id VARCHAR(128) NOT NULL, " + "device_name VARCHAR(256), " + "code CHAR(6) NOT NULL," + "is_valid_code BOOLEAN NOT NULL," - + "expiry_time_ms BIGINT UNSIGNED NOT NULL," + + "expiry_time_ms BIGINT UNSIGNED NOT NULL," // Note: UNSIGNED won't work in Postgres + "FOREIGN KEY (user_id, device_name) REFERENCES " + Config.getConfig(start).getTotpUserDevicesTable() + "(user_id, device_name) ON DELETE CASCADE);"; } @@ -106,13 +103,14 @@ public static TOTPDevice[] getDevices(Start start, String userId) public static int insertUsedCode(Start start, TOTPUsedCode code) throws SQLException, StorageQueryException { String QUERY = "INSERT INTO " + Config.getConfig(start).getTotpUsedCodesTable() - + " (user_id, code, is_valid_code, expiry_time_ms) VALUES (?, ?, ?, ?);"; + + " (user_id, device_name, code, is_valid_code, expiry_time_ms) VALUES (?, ?, ?, ?, ?);"; return update(start, QUERY, pst -> { pst.setString(1, code.userId); - pst.setString(2, code.code); - pst.setBoolean(3, code.isValidCode); - pst.setLong(4, code.expiryTime); + pst.setString(2, code.deviceName); + pst.setString(3, code.code); + pst.setBoolean(4, code.isValidCode); + pst.setLong(5, code.expiryTime); }); } @@ -138,7 +136,7 @@ public static int removeExpiredCodes(Start start) return update(start, QUERY, pst -> pst.setLong(1, System.currentTimeMillis())); } - public static int removeUsedCodesForUser(Start start, String userId, String deviceName) + public static int removeUsedCodes(Start start, String userId, String deviceName) throws StorageQueryException, SQLException { // Remove codes where userId matches the given userId // ONLY required for inmemorydb, as it does not support foreign key constraints. @@ -187,6 +185,7 @@ private static TOTPUsedCodeRowMapper getInstance() { public TOTPUsedCode map(ResultSet result) throws SQLException { return new TOTPUsedCode( result.getString("user_id"), + result.getString("device_name"), result.getString("code"), result.getBoolean("is_valid_code"), result.getLong("expiry_time_ms")); diff --git a/src/main/java/io/supertokens/totp/Totp.java b/src/main/java/io/supertokens/totp/Totp.java index f06ccab81..dad00625b 100644 --- a/src/main/java/io/supertokens/totp/Totp.java +++ b/src/main/java/io/supertokens/totp/Totp.java @@ -72,7 +72,7 @@ public static boolean verifyDevice(Main main, String userId, String deviceName, } // Insert the code into the list of used codes: - TOTPUsedCode newCode = new TOTPUsedCode(userId, code, true, System.currentTimeMillis() + 1000 * 60 * 5); + TOTPUsedCode newCode = new TOTPUsedCode(userId, matchingDevice.deviceName, code, true, System.currentTimeMillis() + 1000 * 60 * 5); totpStorage.insertUsedCode(newCode); // Check if the code is valid: @@ -115,16 +115,18 @@ public static void verifyCode(Main main, String userId, String code, boolean all // Try different devices until we find one that works: boolean isValid = false; + TOTPDevice matchingDevice = null; for (TOTPDevice device : devices) { // Check if the code is valid for this device: if (checkCode(device, code)) { isValid = true; + matchingDevice = device; break; } } // Insert the code into the list of used codes: - TOTPUsedCode newCode = new TOTPUsedCode(userId, code, isValid, System.currentTimeMillis() + 1000 * 60 * 5); + TOTPUsedCode newCode = new TOTPUsedCode(userId, matchingDevice.deviceName, code, isValid, System.currentTimeMillis() + 1000 * 60 * 5); totpStorage.insertUsedCode(newCode); if (isValid) { diff --git a/src/test/java/io/supertokens/test/totp/TOTPStorageTest.java b/src/test/java/io/supertokens/test/totp/TOTPStorageTest.java index 461e2021e..ea595e680 100644 --- a/src/test/java/io/supertokens/test/totp/TOTPStorageTest.java +++ b/src/test/java/io/supertokens/test/totp/TOTPStorageTest.java @@ -199,7 +199,7 @@ public void insertUsedCodeTest() throws Exception { TOTPStorage storage = result.storage; TOTPDevice device = new TOTPDevice("user", "device", "secretKey", 30, 1, false); - TOTPUsedCode code = new TOTPUsedCode("user", "1234", true, 1); + TOTPUsedCode code = new TOTPUsedCode("user", "device", "1234", true, 1); storage.createDevice(device); boolean isInserted = storage.insertUsedCode(code); @@ -212,15 +212,28 @@ public void insertUsedCodeTest() throws Exception { assert (usedCodes[0].isValidCode); assert (usedCodes[0].expiryTime == 1); - // Try to insert code when device (userId) doesn't exist: + // Deleting a device of the user should delete all related valid codes (coz they + // have deviceName != null) + TOTPUsedCode invalidCode = new TOTPUsedCode("user", null, "invalid-code", false, 1); + storage.insertUsedCode(invalidCode); + // Delete the device and check if the only the valid code is deleted: + storage.deleteDevice("user", "device"); + TOTPUsedCode[] newUsedCodes = storage.getUsedCodes("user"); + assert (newUsedCodes.length == 1); + assert (newUsedCodes[0].code.equals("invalid-code")); + + // Try to insert code when device doesn't exist and user doesn't have any device (i.e. TOTP not enabled) assertThrows(TotpNotEnabledException.class, - () -> storage.insertUsedCode(new TOTPUsedCode("non-existent-user", "1234", true, 1))); + () -> storage.insertUsedCode(new TOTPUsedCode("user", "non-existent-device", "1234", true, 1))); - // FIXME: Deleting last device of the user should delete all used codes of the - // user - storage.deleteDevice("user", "device"); - usedCodes = storage.getUsedCodes("user"); - assert (usedCodes.length == 0); + // Try to insert code when device doesn't exist and user already has a device (i.e. TOTP enabled) + TOTPDevice newDevice = new TOTPDevice("user", "new-device", "secretKey", 30, 1, false); + storage.createDevice(newDevice); + storage.insertUsedCode(new TOTPUsedCode("user", "non-existent-device", "1234", true, 1)); + + // Try to insert code when user doesn't exist: + assertThrows(TotpNotEnabledException.class, + () -> storage.insertUsedCode(new TOTPUsedCode("non-existent-user", "device", "1234", true, 1))); } @Test @@ -228,9 +241,9 @@ public void getUsedCodesTest() throws Exception { TestSetupResult result = setup(); TOTPStorage storage = result.storage; - TOTPDevice device = new TOTPDevice("device", "user", "secretKey", 30, 1, false); - TOTPUsedCode code1 = new TOTPUsedCode("user", "code1", true, 1); - TOTPUsedCode code2 = new TOTPUsedCode("user", "code2", false, 1); + TOTPDevice device = new TOTPDevice("user", "device", "secretKey", 30, 1, false); + TOTPUsedCode code1 = new TOTPUsedCode("user", "device", "code1", true, 1); + TOTPUsedCode code2 = new TOTPUsedCode("user", null, "code2", false, 1); storage.createDevice(device); storage.insertUsedCode(code1); @@ -249,9 +262,9 @@ public void removeExpiredCodesTest() throws Exception { TestSetupResult result = setup(); TOTPStorage storage = result.storage; - TOTPDevice device = new TOTPDevice("device", "user", "secretKey", 30, 1, false); - TOTPUsedCode codeToDelete = new TOTPUsedCode("user", "codeToDelete", true, 1); - TOTPUsedCode codeToKeep = new TOTPUsedCode("user", "codeToKeep", false, 1); + TOTPDevice device = new TOTPDevice("user", "device", "secretKey", 30, 1, false); + TOTPUsedCode codeToDelete = new TOTPUsedCode("user", "device", "codeToDelete", true, System.currentTimeMillis() - 1000); + TOTPUsedCode codeToKeep = new TOTPUsedCode("user", null, "codeToKeep", false, System.currentTimeMillis() + 10000); storage.createDevice(device); storage.insertUsedCode(codeToDelete); From ed448128c0cdf212dbedebc07aeaf94f1d760cbb Mon Sep 17 00:00:00 2001 From: KShivendu Date: Mon, 20 Feb 2023 18:23:01 +0530 Subject: [PATCH 11/42] feat(totp): Add cron to remove expired codes and improve tests --- .../DeleteExpiredTotpTokens.java | 50 +++++++++ .../java/io/supertokens/inmemorydb/Start.java | 10 +- .../inmemorydb/queries/TOTPQueries.java | 9 +- src/main/java/io/supertokens/totp/Totp.java | 106 +++++++++--------- .../supertokens/test/totp/TOTPRecipeTest.java | 70 +++++++++++- 5 files changed, 185 insertions(+), 60 deletions(-) create mode 100644 src/main/java/io/supertokens/cronjobs/deleteExpiredTotpTokens/DeleteExpiredTotpTokens.java diff --git a/src/main/java/io/supertokens/cronjobs/deleteExpiredTotpTokens/DeleteExpiredTotpTokens.java b/src/main/java/io/supertokens/cronjobs/deleteExpiredTotpTokens/DeleteExpiredTotpTokens.java new file mode 100644 index 000000000..3faec9c99 --- /dev/null +++ b/src/main/java/io/supertokens/cronjobs/deleteExpiredTotpTokens/DeleteExpiredTotpTokens.java @@ -0,0 +1,50 @@ +package io.supertokens.cronjobs.deleteExpiredTotpTokens; + +import io.supertokens.Main; +import io.supertokens.pluginInterface.STORAGE_TYPE; +import io.supertokens.pluginInterface.totp.sqlStorage.TOTPSQLStorage; +import io.supertokens.cronjobs.CronTask; +import io.supertokens.cronjobs.CronTaskTest; +import io.supertokens.storageLayer.StorageLayer; + +public class DeleteExpiredTotpTokens extends CronTask { + + public static final String RESOURCE_KEY = "io.supertokens.cronjobs.deleteExpiredTotpTokens.DeleteExpiredTotpTokens"; + + private DeleteExpiredTotpTokens(Main main) { + super("DeleteExpiredTotpTokens", main); + } + + @Override + protected void doTask() throws Exception { + if (StorageLayer.getStorage(this.main).getType() != STORAGE_TYPE.SQL) { + return; + } + + TOTPSQLStorage storage = StorageLayer.getTOTPStorage(this.main); + + storage.removeExpiredCodes(); + } + + @Override + public int getIntervalTimeSeconds() { + if (Main.isTesting) { + Integer interval = CronTaskTest.getInstance(main).getIntervalInSeconds(RESOURCE_KEY); + if (interval != null) { + return interval; + } + } + + return 3600; // every hour + } + + @Override + public int getInitialWaitTimeSeconds() { + if (!Main.isTesting) { + return getIntervalTimeSeconds(); + } else { + return 0; + } + } + +} diff --git a/src/main/java/io/supertokens/inmemorydb/Start.java b/src/main/java/io/supertokens/inmemorydb/Start.java index d3855de3f..f3b3940c4 100644 --- a/src/main/java/io/supertokens/inmemorydb/Start.java +++ b/src/main/java/io/supertokens/inmemorydb/Start.java @@ -1647,13 +1647,21 @@ public void createDevice(TOTPDevice device) throws StorageQueryException, Device } @Override - public void markDeviceAsVerified(String userId, String deviceName) + public boolean markDeviceAsVerified(String userId, String deviceName) throws StorageQueryException, UnknownDeviceException { try { int updatedCount = TOTPQueries.markDeviceAsVerified(this, userId, deviceName); if (updatedCount == 0) { + TOTPDevice[] devices = TOTPQueries.getDevices(this, userId); + for (TOTPDevice device : devices) { + if (device.deviceName.equals(deviceName) && device.verified) { + return true; // Device was already verified + } + } + // Device was not found: throw new UnknownDeviceException(); } + return false; // Device was marked as verified } catch (SQLException e) { throw new StorageQueryException(e); } diff --git a/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java b/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java index 53a4fa1eb..cebdbc71b 100644 --- a/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java +++ b/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java @@ -56,7 +56,7 @@ public static void createDevice(Start start, TOTPDevice device) public static int markDeviceAsVerified(Start start, String userId, String deviceName) throws StorageQueryException, SQLException { String QUERY = "UPDATE " + Config.getConfig(start).getTotpUserDevicesTable() - + " SET verified = true WHERE user_id = ? AND device_name = ?;"; + + " SET verified = true WHERE user_id = ? AND device_name = ? WHERE verified = false;"; return update(start, QUERY, pst -> { pst.setString(1, userId); pst.setString(2, deviceName); @@ -117,8 +117,11 @@ public static int insertUsedCode(Start start, TOTPUsedCode code) throws SQLExcep public static TOTPUsedCode[] getUsedCodes(Start start, String userId) throws SQLException, StorageQueryException { String QUERY = "SELECT * FROM " + Config.getConfig(start).getTotpUsedCodesTable() - + " WHERE user_id = ?;"; - return execute(start, QUERY, pst -> pst.setString(1, userId), result -> { + + " WHERE user_id = ? AND expiry_time_ms > ?;"; + return execute(start, QUERY, pst -> { + pst.setString(1, userId); + pst.setLong(2, System.currentTimeMillis()); + }, result -> { List codes = new ArrayList<>(); while (result.next()) { codes.add(TOTPUsedCodeRowMapper.getInstance().map(result)); diff --git a/src/main/java/io/supertokens/totp/Totp.java b/src/main/java/io/supertokens/totp/Totp.java index dad00625b..5369bd6d1 100644 --- a/src/main/java/io/supertokens/totp/Totp.java +++ b/src/main/java/io/supertokens/totp/Totp.java @@ -41,47 +41,19 @@ public static boolean verifyDevice(Main main, String userId, String deviceName, throws StorageQueryException, TotpNotEnabledException, UnknownDeviceException, InvalidTotpException { // Here boolean return value tells whether the device was already verified - boolean deviceAlreadyVerified = false; - TOTPSQLStorage totpStorage = StorageLayer.getTOTPStorage(main); - TOTPDevice[] devices = totpStorage.getDevices(userId); - if (devices.length == 0) { - throw new TotpNotEnabledException(); - } - - // Find the device: - TOTPDevice matchingDevice = null; - for (TOTPDevice device : devices) { - if (device.deviceName.equals(deviceName)) { - deviceAlreadyVerified = device.verified; - matchingDevice = device; - break; - } - } - if (matchingDevice == null) { - throw new UnknownDeviceException(); - } - - // Check if the code is unused: - TOTPUsedCode[] usedCodes = totpStorage.getUsedCodes(userId); - for (TOTPUsedCode usedCode : usedCodes) { - if (usedCode.code.equals(code)) { - throw new InvalidTotpException(); - } - } + boolean deviceAlreadyVerified = totpStorage.markDeviceAsVerified(userId, deviceName); - // Insert the code into the list of used codes: - TOTPUsedCode newCode = new TOTPUsedCode(userId, matchingDevice.deviceName, code, true, System.currentTimeMillis() + 1000 * 60 * 5); - totpStorage.insertUsedCode(newCode); + if (deviceAlreadyVerified) + return true; - // Check if the code is valid: - if (!checkCode(matchingDevice, code)) { + try { + verifyCode(main, userId, code, true); + return false; + } catch (LimitReachedException e) { throw new InvalidTotpException(); } - - totpStorage.markDeviceAsVerified(userId, deviceName); - return deviceAlreadyVerified; } public static void verifyCode(Main main, String userId, String code, boolean allowUnverifiedDevices) @@ -95,22 +67,26 @@ public static void verifyCode(Main main, String userId, String code, boolean all throw new TotpNotEnabledException(); } - // FIXME: Every unused_code should be linked to the device it was used for. - // Otherwise, if a user has multiple devices, and they use the same code for - // both, - // then the code will be considered as used for both devices. This could cause - // UX - // issues. - // If we do this, then it also means that we need to assign a device ID to each - // device (OR use - // (userId, deviceName) as the ID) - - // Check if the code has been successfully used by the user (for any device): - TOTPUsedCode[] usedCodes = totpStorage.getUsedCodes(userId); - for (TOTPUsedCode usedCode : usedCodes) { - if (usedCode.code.equals(code) && usedCode.isValidCode) { - throw new InvalidTotpException(); + // If allowUnverifiedDevices is false, then remove all unverified devices from + // the list: + if (!allowUnverifiedDevices) { + int verifiedDeviceCount = 0; + for (TOTPDevice device : devices) { + if (device.verified) { + verifiedDeviceCount++; + } + } + + TOTPDevice[] verifiedDevices = new TOTPDevice[verifiedDeviceCount]; + int index = 0; + for (TOTPDevice device : devices) { + if (device.verified) { + verifiedDevices[index] = device; + index++; + } } + + devices = verifiedDevices; } // Try different devices until we find one that works: @@ -125,16 +101,40 @@ public static void verifyCode(Main main, String userId, String code, boolean all } } + // Check if the code has been successfully used by the user (for any of their + // devices): + TOTPUsedCode[] usedCodes = totpStorage.getUsedCodes(userId); + for (TOTPUsedCode usedCode : usedCodes) { + if (usedCode.code.equals(code)) { // FIXME: Only check for the same device for better UX? + throw new InvalidTotpException(); + } + } + // Insert the code into the list of used codes: - TOTPUsedCode newCode = new TOTPUsedCode(userId, matchingDevice.deviceName, code, isValid, System.currentTimeMillis() + 1000 * 60 * 5); + TOTPUsedCode newCode = null; + if (matchingDevice == null) { + // TODO: Verify that this doesn't pile up OR gets deleted very quickly: + int expireAfterSeconds = 60 * 5; // 5 minutes + newCode = new TOTPUsedCode(userId, null, code, isValid, + System.currentTimeMillis() + 1000 * expireAfterSeconds); + } else { + int expireAfterSeconds = matchingDevice.period * (2 * matchingDevice.period + 1); + newCode = new TOTPUsedCode(userId, matchingDevice.deviceName, code, isValid, + System.currentTimeMillis() + 1000 * expireAfterSeconds); + } totpStorage.insertUsedCode(newCode); if (isValid) { return; } - // Check if last 5 codes are all invalid: - int WINDOW_SIZE = 5; + // Now we know that the code is invalid. + + // Check if last N codes are all invalid: + // Note that usedCodes will get updated when: + // - A valid code is used: It will break the chain of invalid codes. + // - Cron job runs: deletes expired codes every 1 hour + int WINDOW_SIZE = 3; int invalidCodes = 0; for (int i = usedCodes.length - 1; i >= 0 && i >= usedCodes.length - WINDOW_SIZE; i--) { if (!usedCodes[i].isValidCode) { diff --git a/src/test/java/io/supertokens/test/totp/TOTPRecipeTest.java b/src/test/java/io/supertokens/test/totp/TOTPRecipeTest.java index 91dd5309e..1f6cabf57 100644 --- a/src/test/java/io/supertokens/test/totp/TOTPRecipeTest.java +++ b/src/test/java/io/supertokens/test/totp/TOTPRecipeTest.java @@ -26,7 +26,6 @@ import org.junit.rules.TestRule; import io.supertokens.test.Utils; -import io.supertokens.test.totp.TOTPStorageTest.TestSetupResult; import io.supertokens.Main; import io.supertokens.ProcessState; import io.supertokens.pluginInterface.STORAGE_TYPE; @@ -35,7 +34,6 @@ import io.supertokens.totp.Totp; import io.supertokens.totp.exceptions.InvalidTotpException; -import io.supertokens.pluginInterface.totp.TOTPDevice; import io.supertokens.pluginInterface.totp.TOTPStorage; import io.supertokens.pluginInterface.totp.exception.DeviceAlreadyExistsException; import io.supertokens.pluginInterface.totp.exception.TotpNotEnabledException; @@ -81,7 +79,7 @@ public TestSetupResult setup() throws InterruptedException { } @Test - public void createAndVerifyDevice() throws Exception { + public void createDevice() throws Exception { TestSetupResult result = setup(); Main main = result.process.getProcess(); @@ -91,6 +89,72 @@ public void createAndVerifyDevice() throws Exception { // Create same device again (should fail) assertThrows(DeviceAlreadyExistsException.class, () -> Totp.createDevice(main, "userId", "deviceName", 1, 30)); + } + + public void triggerRateLimit(Main main) throws Exception { + for (int i = 0; i < 4; i++) { + assertThrows( + InvalidTotpException.class, + () -> Totp.verifyCode(main, "user", "wrong-code", true)); + } + + // 5th attempt should fail with rate limiting error: + assertThrows( + InvalidTotpException.class, + () -> Totp.verifyCode(main, "user", "XXXX-code", true)); + } + + @Test + public void createDeviceAndVerifyCode() throws Exception { + TestSetupResult result = setup(); + Main main = result.process.getProcess(); + + // Create device + String secret = Totp.createDevice(main, "userId", "deviceName", 1, 30); + + // Try login with non-existent user: + assertThrows(TotpNotEnabledException.class, + () -> Totp.verifyCode(main, "non-existent-user", "XXXX-code", true)); + + // Try login with invalid code: + assertThrows(InvalidTotpException.class, + () -> Totp.verifyCode(main, "user", "invalid-code", true)); + + // Try login with with unverified device: + assertThrows(InvalidTotpException.class, + () -> Totp.verifyCode(main, "user", "XXXX-code", false)); + + // Successfully login: + Totp.verifyCode(main, "user", "XXXX-code", true); + // Now try again with same code: + assertThrows( + InvalidTotpException.class, + () -> Totp.verifyCode(main, "user", "XXXX-code", true)); + + // Trigger rate limiting and fix it with a correct code: + { + triggerRateLimit(main); + // Using a correct code should fix the rate limiting: + Totp.verifyCode(main, "user", "XXXX-code", true); + } + + // Trigger rate limiting and fix it with cronjob (runs every 1 hour) + { + triggerRateLimit(main); + // Run cronjob: + // Totp.runCron(main); + Totp.verifyCode(main, "user", "XXXX-code", true); + } + } + + @Test + public void createAndVerifyDevice() throws Exception { + TestSetupResult result = setup(); + Main main = result.process.getProcess(); + + // Create device + // FIXME: Use secret to generate actual TOTP code + String secret = Totp.createDevice(main, "userId", "deviceName", 1, 30); // Try verify non-existent user: assertThrows(TotpNotEnabledException.class, From 957e016c72598a2ea866b92529b1c062f1f37ff0 Mon Sep 17 00:00:00 2001 From: KShivendu Date: Tue, 21 Feb 2023 13:12:58 +0530 Subject: [PATCH 12/42] feat: Add java-otp as a dependency --- build.gradle | 4 +++- src/test/java/io/supertokens/test/totp/TOTPRecipeTest.java | 2 +- src/test/java/io/supertokens/test/totp/TOTPStorageTest.java | 2 +- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/build.gradle b/build.gradle index fe8e62519..dfb77d289 100644 --- a/build.gradle +++ b/build.gradle @@ -65,6 +65,9 @@ dependencies { // https://mvnrepository.com/artifact/com.lambdaworks/scrypt implementation group: 'com.lambdaworks', name: 'scrypt', version: '1.4.0' + // https://mvnrepository.com/artifact/com.eatthepath/java-otp + implementation group: 'com.eatthepath', name: 'java-otp', version: '0.4.0' + compileOnly project(":supertokens-plugin-interface") testImplementation project(":supertokens-plugin-interface") @@ -159,4 +162,3 @@ tasks.withType(Test) { } } } - \ No newline at end of file diff --git a/src/test/java/io/supertokens/test/totp/TOTPRecipeTest.java b/src/test/java/io/supertokens/test/totp/TOTPRecipeTest.java index 1f6cabf57..e74a42722 100644 --- a/src/test/java/io/supertokens/test/totp/TOTPRecipeTest.java +++ b/src/test/java/io/supertokens/test/totp/TOTPRecipeTest.java @@ -16,7 +16,7 @@ package io.supertokens.test.totp; -import static org.junit.Assert.assertNotNull; // Not sure about this +import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertThrows; import org.junit.AfterClass; diff --git a/src/test/java/io/supertokens/test/totp/TOTPStorageTest.java b/src/test/java/io/supertokens/test/totp/TOTPStorageTest.java index ea595e680..b0989610e 100644 --- a/src/test/java/io/supertokens/test/totp/TOTPStorageTest.java +++ b/src/test/java/io/supertokens/test/totp/TOTPStorageTest.java @@ -1,6 +1,6 @@ package io.supertokens.test.totp; -import static org.junit.Assert.assertNotNull; // Not sure about this +import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertThrows; import org.junit.AfterClass; From 22e64a552c7d1b9847a04b459fca5a0e115b36fa Mon Sep 17 00:00:00 2001 From: KShivendu Date: Tue, 21 Feb 2023 19:08:48 +0530 Subject: [PATCH 13/42] feat: Introduce totp_users to keep track of users who have more than one TOTP device --- .../java/io/supertokens/inmemorydb/Start.java | 24 ++- .../inmemorydb/config/SQLiteConfig.java | 4 + .../inmemorydb/queries/GeneralQueries.java | 4 + .../inmemorydb/queries/TOTPQueries.java | 126 ++++++++++++--- src/main/java/io/supertokens/totp/Totp.java | 9 +- .../test/totp/TOTPStorageTest.java | 148 +++++++++++------- 6 files changed, 220 insertions(+), 95 deletions(-) diff --git a/src/main/java/io/supertokens/inmemorydb/Start.java b/src/main/java/io/supertokens/inmemorydb/Start.java index f3b3940c4..e2875acd4 100644 --- a/src/main/java/io/supertokens/inmemorydb/Start.java +++ b/src/main/java/io/supertokens/inmemorydb/Start.java @@ -1633,16 +1633,16 @@ public void addInfoToNonAuthRecipesBasedOnUserId(String className, String userId @Override public void createDevice(TOTPDevice device) throws StorageQueryException, DeviceAlreadyExistsException { try { - TOTPQueries.createDevice(this, device); - } catch (SQLException e) { - if (e.getMessage() - .equals("[SQLITE_CONSTRAINT] Abort due to constraint violation (UNIQUE constraint failed: " - + Config.getConfig(this).getTotpUserDevicesTable() + ".user_id, " - + Config.getConfig(this).getTotpUserDevicesTable() + ".device_name" + ")")) { + TOTPQueries.createDeviceAndUser(this, device); + } catch (StorageTransactionLogicException e) { + String message = e.actualException.getMessage(); + if (message.equals("[SQLITE_CONSTRAINT] Abort due to constraint violation (UNIQUE constraint failed: " + + Config.getConfig(this).getTotpUserDevicesTable() + ".user_id, " + + Config.getConfig(this).getTotpUserDevicesTable() + ".device_name" + ")")) { throw new DeviceAlreadyExistsException(); } - throw new StorageQueryException(e); + throw new StorageQueryException(e.actualException); } } @@ -1675,12 +1675,8 @@ public void deleteDevice(String userId, String deviceName) if (deletedCount == 0) { throw new UnknownDeviceException(); } - - // Note: This step is only required for in-memory databases. - // They don't have cascading deletes, so we need to manually delete the codes - TOTPQueries.removeUsedCodes(this, userId, deviceName); - } catch (SQLException e) { - throw new StorageQueryException(e); + } catch (StorageTransactionLogicException e) { + throw new StorageQueryException(e.actualException); } } @@ -1731,7 +1727,7 @@ public boolean insertUsedCode(TOTPUsedCode usedCodeObj) } @Override - public TOTPUsedCode[] getUsedCodes(String userId) + public TOTPUsedCode[] getNonExpiredUsedCodes(String userId) throws StorageQueryException { try { return TOTPQueries.getUsedCodes(this, userId); diff --git a/src/main/java/io/supertokens/inmemorydb/config/SQLiteConfig.java b/src/main/java/io/supertokens/inmemorydb/config/SQLiteConfig.java index c2eb0b2dd..7db4cab02 100644 --- a/src/main/java/io/supertokens/inmemorydb/config/SQLiteConfig.java +++ b/src/main/java/io/supertokens/inmemorydb/config/SQLiteConfig.java @@ -90,6 +90,10 @@ public String getUserIdMappingTable() { return "userid_mapping"; } + public String getTotpUsersTable() { + return "totp_users"; + } + public String getTotpUserDevicesTable() { return "totp_user_devices"; } diff --git a/src/main/java/io/supertokens/inmemorydb/queries/GeneralQueries.java b/src/main/java/io/supertokens/inmemorydb/queries/GeneralQueries.java index 5812d8cef..f0343f912 100644 --- a/src/main/java/io/supertokens/inmemorydb/queries/GeneralQueries.java +++ b/src/main/java/io/supertokens/inmemorydb/queries/GeneralQueries.java @@ -186,6 +186,10 @@ public static void createTablesIfNotExists(Start start, Main main) throws SQLExc update(start, UserIdMappingQueries.getQueryToCreateUserIdMappingTable(start), NO_OP_SETTER); } + if (!doesTableExists(start, Config.getConfig(start).getTotpUsersTable())) { + getInstance(main).addState(CREATING_NEW_TABLE, null); + update(start, TOTPQueries.getQueryToCreateUsersTable(start), NO_OP_SETTER); + } if (!doesTableExists(start, Config.getConfig(start).getTotpUserDevicesTable())) { getInstance(main).addState(CREATING_NEW_TABLE, null); diff --git a/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java b/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java index cebdbc71b..b409f9ec8 100644 --- a/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java +++ b/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java @@ -1,5 +1,6 @@ package io.supertokens.inmemorydb.queries; +import java.sql.Connection; import java.sql.ResultSet; import java.sql.SQLException; import java.util.ArrayList; @@ -9,6 +10,7 @@ import io.supertokens.inmemorydb.config.Config; import io.supertokens.pluginInterface.RowMapper; import io.supertokens.pluginInterface.exceptions.StorageQueryException; +import io.supertokens.pluginInterface.exceptions.StorageTransactionLogicException; import io.supertokens.pluginInterface.totp.TOTPDevice; import io.supertokens.pluginInterface.totp.TOTPUsedCode; @@ -16,21 +18,29 @@ import static io.supertokens.inmemorydb.QueryExecutorTemplate.update; public class TOTPQueries { + public static String getQueryToCreateUsersTable(Start start) { + return "CREATE TABLE IF NOT EXISTS " + Config.getConfig(start).getTotpUsersTable() + " (" + + "user_id VARCHAR(128) NOT NULL," + + "PRIMARY KEY (user_id))"; + } + public static String getQueryToCreateUserDevicesTable(Start start) { return "CREATE TABLE IF NOT EXISTS " + Config.getConfig(start).getTotpUserDevicesTable() + " (" + "user_id VARCHAR(128) NOT NULL," + "device_name VARCHAR(256) NOT NULL," + "secret_key VARCHAR(256) NOT NULL," + "period INTEGER NOT NULL," + "skew INTEGER NOT NULL," + "verified BOOLEAN NOT NULL," - + "PRIMARY KEY (user_id, device_name))"; + + "PRIMARY KEY (user_id, device_name)" + + "FOREIGN KEY (user_id) REFERENCES " + Config.getConfig(start).getTotpUsersTable() + + "(user_id) ON DELETE CASCADE);"; } public static String getQueryToCreateUsedCodesTable(Start start) { return "CREATE TABLE IF NOT EXISTS " + Config.getConfig(start).getTotpUsedCodesTable() + " (" + "user_id VARCHAR(128) NOT NULL, " + "device_name VARCHAR(256), " + "code CHAR(6) NOT NULL," + "is_valid_code BOOLEAN NOT NULL," - + "expiry_time_ms BIGINT UNSIGNED NOT NULL," // Note: UNSIGNED won't work in Postgres - + "FOREIGN KEY (user_id, device_name) REFERENCES " + Config.getConfig(start).getTotpUserDevicesTable() - + "(user_id, device_name) ON DELETE CASCADE);"; + + "expiry_time_ms BIGINT UNSIGNED NOT NULL," + + "FOREIGN KEY (user_id) REFERENCES " + Config.getConfig(start).getTotpUsersTable() + + "(user_id) ON DELETE CASCADE);"; } public static String getQueryToCreateUsedCodesIndex(Start start) { @@ -38,12 +48,23 @@ public static String getQueryToCreateUsedCodesIndex(Start start) { + Config.getConfig(start).getTotpUsedCodesTable() + " (expiry_time_ms)"; } - public static void createDevice(Start start, TOTPDevice device) - throws StorageQueryException, SQLException { + public static int insertUser_Transaction(Start start, Connection con, String userId) + throws SQLException, StorageQueryException { + // Create user if not exists: + // TODO: Check if not using "CONFLICT DO NOTHING" will break the transaction + // It's not a problem anyways. but we should check for clarity + String QUERY = "INSERT INTO " + Config.getConfig(start).getTotpUsersTable() + + " (user_id) VALUES (?) ON CONFLICT DO NOTHING"; + + return update(con, QUERY, pst -> pst.setString(1, userId)); + } + + public static int insertDevice_Transaction(Start start, Connection con, TOTPDevice device) + throws SQLException, StorageQueryException { String QUERY = "INSERT INTO " + Config.getConfig(start).getTotpUserDevicesTable() + " (user_id, device_name, secret_key, period, skew, verified) VALUES (?, ?, ?, ?, ?, ?)"; - update(start, QUERY, pst -> { + return update(con, QUERY, pst -> { pst.setString(1, device.userId); pst.setString(2, device.deviceName); pst.setString(3, device.secretKey); @@ -53,27 +74,85 @@ public static void createDevice(Start start, TOTPDevice device) }); } + public static void createDeviceAndUser(Start start, TOTPDevice device) + throws StorageQueryException, StorageTransactionLogicException { + start.startTransaction(con -> { + Connection sqlCon = (Connection) con.getConnection(); + + try { + insertUser_Transaction(start, sqlCon, device.userId); + insertDevice_Transaction(start, sqlCon, device); + sqlCon.commit(); + } catch (SQLException e) { + throw new StorageTransactionLogicException(e); + } + + return null; + }); + } + public static int markDeviceAsVerified(Start start, String userId, String deviceName) throws StorageQueryException, SQLException { String QUERY = "UPDATE " + Config.getConfig(start).getTotpUserDevicesTable() - + " SET verified = true WHERE user_id = ? AND device_name = ? WHERE verified = false;"; + + " SET verified = true WHERE user_id = ? AND device_name = ? AND verified = false;"; return update(start, QUERY, pst -> { pst.setString(1, userId); pst.setString(2, deviceName); }); } - public static int deleteDevice(Start start, String userId, String deviceName) - throws StorageQueryException, SQLException { + public static int deleteDevice_Transaction(Start start, Connection con, String userId, String deviceName) + throws SQLException, StorageQueryException { String QUERY = "DELETE FROM " + Config.getConfig(start).getTotpUserDevicesTable() + " WHERE user_id = ? AND device_name = ?;"; - return update(start, QUERY, pst -> { + return update(con, QUERY, pst -> { pst.setString(1, userId); pst.setString(2, deviceName); }); } + public static int removeUser_Transaction(Start start, Connection con, String userId) + throws SQLException, StorageQueryException { + String QUERY = "DELETE FROM " + Config.getConfig(start).getTotpUsersTable() + + " WHERE user_id = ?;"; + int removedUsersCount = update(con, QUERY, pst -> pst.setString(1, userId)); + + // Delete all used codes for this user: + // Note: This step is required only for in-memory db. + // Other databases will automatically delete the used codes when the user is + // deleted because of foreign key constraints. + String QUERY2 = "DELETE FROM " + Config.getConfig(start).getTotpUsedCodesTable() + + " WHERE user_id = ?;"; + update(con, QUERY2, pst -> pst.setString(1, userId)); + + return removedUsersCount; + } + + public static int deleteDevice(Start start, String userId, String deviceName) + throws StorageQueryException, StorageTransactionLogicException { + return start.startTransaction(con -> { + Connection sqlCon = (Connection) con.getConnection(); + + try { + int deletedCount = deleteDevice_Transaction(start, sqlCon, userId, deviceName); + if (deletedCount > 0) { + // some device deleted. Check if user has any other device left: + int devicesCount = getDevicesCount_Transaction(start, sqlCon, userId); + if (devicesCount == 0) { + // no device left. delete user + removeUser_Transaction(start, sqlCon, userId); + } + } + + sqlCon.commit(); + return deletedCount; + } catch (SQLException e) { + throw new StorageTransactionLogicException(e); + } + }); + } + public static int updateDeviceName(Start start, String userId, String oldDeviceName, String newDeviceName) throws StorageQueryException, SQLException { String QUERY = "UPDATE " + Config.getConfig(start).getTotpUserDevicesTable() @@ -101,16 +180,25 @@ public static TOTPDevice[] getDevices(Start start, String userId) }); } + public static int getDevicesCount_Transaction(Start start, Connection con, String userId) + throws StorageQueryException, SQLException { + String QUERY = "SELECT COUNT(*) as count FROM " + Config.getConfig(start).getTotpUserDevicesTable() + + " WHERE user_id = ?;"; + + return execute(con, QUERY, pst -> pst.setString(1, userId), result -> { + return result.getInt("count"); + }); + } + public static int insertUsedCode(Start start, TOTPUsedCode code) throws SQLException, StorageQueryException { String QUERY = "INSERT INTO " + Config.getConfig(start).getTotpUsedCodesTable() - + " (user_id, device_name, code, is_valid_code, expiry_time_ms) VALUES (?, ?, ?, ?, ?);"; + + " (user_id, code, is_valid_code, expiry_time_ms) VALUES (?, ?, ?, ?);"; return update(start, QUERY, pst -> { pst.setString(1, code.userId); - pst.setString(2, code.deviceName); - pst.setString(3, code.code); - pst.setBoolean(4, code.isValidCode); - pst.setLong(5, code.expiryTime); + pst.setString(2, code.code); + pst.setBoolean(3, code.isValidCode); + pst.setLong(4, code.expiryTime); }); } @@ -139,16 +227,15 @@ public static int removeExpiredCodes(Start start) return update(start, QUERY, pst -> pst.setLong(1, System.currentTimeMillis())); } - public static int removeUsedCodes(Start start, String userId, String deviceName) + public static int removeUsedCodes(Start start, String userId) throws StorageQueryException, SQLException { // Remove codes where userId matches the given userId // ONLY required for inmemorydb, as it does not support foreign key constraints. String QUERY = "DELETE FROM " + Config.getConfig(start).getTotpUsedCodesTable() - + " WHERE user_id = ? AND device_name = ?;"; + + " WHERE user_id = ?;"; return update(start, QUERY, pst -> { pst.setString(1, userId); - pst.setString(2, deviceName); }); } @@ -188,7 +275,6 @@ private static TOTPUsedCodeRowMapper getInstance() { public TOTPUsedCode map(ResultSet result) throws SQLException { return new TOTPUsedCode( result.getString("user_id"), - result.getString("device_name"), result.getString("code"), result.getBoolean("is_valid_code"), result.getLong("expiry_time_ms")); diff --git a/src/main/java/io/supertokens/totp/Totp.java b/src/main/java/io/supertokens/totp/Totp.java index 5369bd6d1..7b2c1bd9c 100644 --- a/src/main/java/io/supertokens/totp/Totp.java +++ b/src/main/java/io/supertokens/totp/Totp.java @@ -103,7 +103,7 @@ public static void verifyCode(Main main, String userId, String code, boolean all // Check if the code has been successfully used by the user (for any of their // devices): - TOTPUsedCode[] usedCodes = totpStorage.getUsedCodes(userId); + TOTPUsedCode[] usedCodes = totpStorage.getNonExpiredUsedCodes(userId); for (TOTPUsedCode usedCode : usedCodes) { if (usedCode.code.equals(code)) { // FIXME: Only check for the same device for better UX? throw new InvalidTotpException(); @@ -115,11 +115,11 @@ public static void verifyCode(Main main, String userId, String code, boolean all if (matchingDevice == null) { // TODO: Verify that this doesn't pile up OR gets deleted very quickly: int expireAfterSeconds = 60 * 5; // 5 minutes - newCode = new TOTPUsedCode(userId, null, code, isValid, + newCode = new TOTPUsedCode(userId, code, isValid, System.currentTimeMillis() + 1000 * expireAfterSeconds); } else { int expireAfterSeconds = matchingDevice.period * (2 * matchingDevice.period + 1); - newCode = new TOTPUsedCode(userId, matchingDevice.deviceName, code, isValid, + newCode = new TOTPUsedCode(userId, code, isValid, System.currentTimeMillis() + 1000 * expireAfterSeconds); } totpStorage.insertUsedCode(newCode); @@ -160,9 +160,8 @@ public static void deleteDevice(Main main, String userId, String deviceName) TOTPDevice[] devices = totpStorage.getDevices(userId); if (devices.length == 0) { throw new TotpNotEnabledException(); - } else { - throw e; } + throw e; } } diff --git a/src/test/java/io/supertokens/test/totp/TOTPStorageTest.java b/src/test/java/io/supertokens/test/totp/TOTPStorageTest.java index b0989610e..fb2a02af4 100644 --- a/src/test/java/io/supertokens/test/totp/TOTPStorageTest.java +++ b/src/test/java/io/supertokens/test/totp/TOTPStorageTest.java @@ -121,24 +121,45 @@ public void deleteDeviceTests() throws Exception { TestSetupResult result = setup(); TOTPStorage storage = result.storage; - TOTPDevice device = new TOTPDevice("user", "device", "secretKey", 30, 1, false); - storage.createDevice(device); + TOTPDevice device1 = new TOTPDevice("user", "device1", "sk1", 30, 1, false); + TOTPDevice device2 = new TOTPDevice("user", "device2", "sk2", 30, 1, false); + + storage.createDevice(device1); + storage.createDevice(device2); TOTPDevice[] storedDevices = storage.getDevices("user"); - assert (storedDevices.length == 1); + assert (storedDevices.length == 2); + + // Try to delete a device for a user that doesn't exist: + assertThrows(UnknownDeviceException.class, () -> storage.deleteDevice("non-existent-user", "device1")); // Try to delete a device that doesn't exist: assertThrows(UnknownDeviceException.class, () -> storage.deleteDevice("user", "non-existent-device")); - // Delete the device: - storage.deleteDevice("user", "device"); + // Successfully delete device1: + storage.deleteDevice("user", "device1"); storedDevices = storage.getDevices("user"); - assert (storedDevices.length == 0); + assert (storedDevices.length == 1); // device2 should still be there + + long nextDay = System.currentTimeMillis() + 1000 * 60 * 60 * 24; // 1 day from now + + // Deleting all devices of a user should delete all related codes: + { + TOTPUsedCode validCode = new TOTPUsedCode("user", "valid-code", true, nextDay); + TOTPUsedCode invalidCode = new TOTPUsedCode("user", "invalid-code", false, nextDay); + storage.insertUsedCode(validCode); + storage.insertUsedCode(invalidCode); + + storage.deleteDevice("user", "device2"); // delete device2 as well + + TOTPUsedCode[] newUsedCodes = storage.getNonExpiredUsedCodes("user"); + assert (newUsedCodes.length == 0); + } } @Test - public void updateDeviceNametests() throws Exception { + public void updateDeviceNameTests() throws Exception { TestSetupResult result = setup(); TOTPStorage storage = result.storage; @@ -197,64 +218,68 @@ public void getDevicesTest() throws Exception { public void insertUsedCodeTest() throws Exception { TestSetupResult result = setup(); TOTPStorage storage = result.storage; + long nextDay = System.currentTimeMillis() + 1000 * 60 * 60 * 24; // 1 day from now - TOTPDevice device = new TOTPDevice("user", "device", "secretKey", 30, 1, false); - TOTPUsedCode code = new TOTPUsedCode("user", "device", "1234", true, 1); + // Insert a long lasting valid code and check that it's returned when queried: + { + TOTPDevice device = new TOTPDevice("user", "device", "secretKey", 30, 1, false); + TOTPUsedCode code = new TOTPUsedCode("user", "1234", true, nextDay); - storage.createDevice(device); - boolean isInserted = storage.insertUsedCode(code); - TOTPUsedCode[] usedCodes = storage.getUsedCodes("user"); - - assert (isInserted); - assert (usedCodes.length == 1); - assert (usedCodes[0].userId.equals("user")); - assert (usedCodes[0].code.equals("1234")); - assert (usedCodes[0].isValidCode); - assert (usedCodes[0].expiryTime == 1); - - // Deleting a device of the user should delete all related valid codes (coz they - // have deviceName != null) - TOTPUsedCode invalidCode = new TOTPUsedCode("user", null, "invalid-code", false, 1); - storage.insertUsedCode(invalidCode); - // Delete the device and check if the only the valid code is deleted: - storage.deleteDevice("user", "device"); - TOTPUsedCode[] newUsedCodes = storage.getUsedCodes("user"); - assert (newUsedCodes.length == 1); - assert (newUsedCodes[0].code.equals("invalid-code")); + storage.createDevice(device); + boolean isInserted = storage.insertUsedCode(code); + TOTPUsedCode[] usedCodes = storage.getNonExpiredUsedCodes("user"); - // Try to insert code when device doesn't exist and user doesn't have any device (i.e. TOTP not enabled) - assertThrows(TotpNotEnabledException.class, - () -> storage.insertUsedCode(new TOTPUsedCode("user", "non-existent-device", "1234", true, 1))); + assert (isInserted); + assert (usedCodes.length == 1); + assert usedCodes[0].equals(code); + } - // Try to insert code when device doesn't exist and user already has a device (i.e. TOTP enabled) - TOTPDevice newDevice = new TOTPDevice("user", "new-device", "secretKey", 30, 1, false); - storage.createDevice(newDevice); - storage.insertUsedCode(new TOTPUsedCode("user", "non-existent-device", "1234", true, 1)); + // Try to insert code when user doesn't have any device (i.e. TOTP not enabled) + { + storage.deleteDevice("user", "device"); + assertThrows(TotpNotEnabledException.class, + () -> storage.insertUsedCode(new TOTPUsedCode("user", "1234", true, nextDay))); + } + + // Try to insert code after user has atleast one device (i.e. TOTP enabled) + { + TOTPDevice newDevice = new TOTPDevice("user", "new-device", "secretKey", 30, 1, false); + storage.createDevice(newDevice); + storage.insertUsedCode(new TOTPUsedCode("user", "1234", true, nextDay)); + } // Try to insert code when user doesn't exist: assertThrows(TotpNotEnabledException.class, - () -> storage.insertUsedCode(new TOTPUsedCode("non-existent-user", "device", "1234", true, 1))); + () -> storage.insertUsedCode(new TOTPUsedCode("non-existent-user", "1234", true, nextDay))); } @Test - public void getUsedCodesTest() throws Exception { + public void getNonExpiredUsedCodesTest() throws Exception { TestSetupResult result = setup(); TOTPStorage storage = result.storage; + TOTPUsedCode[] usedCodes = storage.getNonExpiredUsedCodes("non-existent-user"); + assert (usedCodes.length == 0); + + long nextDay = System.currentTimeMillis() + 1000 * 60 * 60 * 24; // 1 day from now + long prevDay = System.currentTimeMillis() - 1000 * 60 * 60 * 24; // 1 day ago + TOTPDevice device = new TOTPDevice("user", "device", "secretKey", 30, 1, false); - TOTPUsedCode code1 = new TOTPUsedCode("user", "device", "code1", true, 1); - TOTPUsedCode code2 = new TOTPUsedCode("user", null, "code2", false, 1); + TOTPUsedCode validCode = new TOTPUsedCode("user", "code1", true, nextDay); + TOTPUsedCode invalidCode = new TOTPUsedCode("user", "code2", false, nextDay); + TOTPUsedCode expiredCode = new TOTPUsedCode("user", "expired-code", true, prevDay); + TOTPUsedCode expiredInvalidCode = new TOTPUsedCode("user", "expired-invalid-code", false, prevDay); storage.createDevice(device); - storage.insertUsedCode(code1); - storage.insertUsedCode(code2); + storage.insertUsedCode(validCode); + storage.insertUsedCode(invalidCode); + storage.insertUsedCode(expiredCode); + storage.insertUsedCode(expiredInvalidCode); - TOTPUsedCode[] usedCodes = storage.getUsedCodes("user"); + usedCodes = storage.getNonExpiredUsedCodes("user"); assert (usedCodes.length == 2); - assert (usedCodes[0].code.equals("code1")); - assert (usedCodes[0].isValidCode); - assert (usedCodes[1].code.equals("code2")); - assert (!usedCodes[1].isValidCode); + assert (usedCodes[0].equals(validCode)); + assert (usedCodes[1].equals(invalidCode)); } @Test @@ -262,21 +287,32 @@ public void removeExpiredCodesTest() throws Exception { TestSetupResult result = setup(); TOTPStorage storage = result.storage; + long nextDay = System.currentTimeMillis() + 1000 * 60 * 60 * 24; // 1 day from now + long halfSecond = System.currentTimeMillis() + 500; // 500ms from now + TOTPDevice device = new TOTPDevice("user", "device", "secretKey", 30, 1, false); - TOTPUsedCode codeToDelete = new TOTPUsedCode("user", "device", "codeToDelete", true, System.currentTimeMillis() - 1000); - TOTPUsedCode codeToKeep = new TOTPUsedCode("user", null, "codeToKeep", false, System.currentTimeMillis() + 10000); + TOTPUsedCode validCodeToLive = new TOTPUsedCode("user", "valid-code", true, nextDay); + TOTPUsedCode invalidCodeToLive = new TOTPUsedCode("user", "invalid-code", false, nextDay); + TOTPUsedCode validCodeToExpire = new TOTPUsedCode("user", "valid-code", true, halfSecond); + TOTPUsedCode invalidCodeToExpire = new TOTPUsedCode("user", "invalid-code", false, halfSecond); storage.createDevice(device); - storage.insertUsedCode(codeToDelete); - storage.insertUsedCode(codeToKeep); + storage.insertUsedCode(validCodeToLive); + storage.insertUsedCode(invalidCodeToLive); + storage.insertUsedCode(validCodeToExpire); + storage.insertUsedCode(invalidCodeToExpire); - TOTPUsedCode[] usedCodes = storage.getUsedCodes("user"); - assert (usedCodes.length == 2); + TOTPUsedCode[] usedCodes = storage.getNonExpiredUsedCodes("user"); + assert (usedCodes.length == 4); + + // After 500ms seconds pass: + Thread.sleep(500); storage.removeExpiredCodes(); - usedCodes = storage.getUsedCodes("user"); - assert (usedCodes.length == 1); - assert (usedCodes[0].code.equals("codeToKeep")); + usedCodes = storage.getNonExpiredUsedCodes("user"); + assert (usedCodes.length == 2); + assert (usedCodes[0].equals(validCodeToLive)); + assert (usedCodes[1].equals(invalidCodeToLive)); } } From fe17057d045f234deac86ad77ac97a224aeb9ca6 Mon Sep 17 00:00:00 2001 From: KShivendu Date: Tue, 21 Feb 2023 19:16:07 +0530 Subject: [PATCH 14/42] test: Use equals function for cleaner code --- .../io/supertokens/test/totp/TOTPStorageTest.java | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/src/test/java/io/supertokens/test/totp/TOTPStorageTest.java b/src/test/java/io/supertokens/test/totp/TOTPStorageTest.java index fb2a02af4..24a537b7f 100644 --- a/src/test/java/io/supertokens/test/totp/TOTPStorageTest.java +++ b/src/test/java/io/supertokens/test/totp/TOTPStorageTest.java @@ -74,18 +74,15 @@ public void createDeviceTests() throws Exception { storage.createDevice(device1); TOTPDevice[] storedDevices = storage.getDevices("user"); - TOTPDevice storedDevice = storedDevices[0]; assert (storedDevices.length == 1); - assert (storedDevice.deviceName.equals("d1")); - assert (storedDevice.userId.equals("user")); - assert (storedDevice.secretKey.equals("secretKey")); - assert (storedDevice.period == 30); - assert (storedDevice.skew == 1); - assert (storedDevice.verified == false); + assert storedDevices[0].equals(device1); storage.createDevice(device2); storedDevices = storage.getDevices("user"); + assert (storedDevices.length == 2); + assert storedDevices[0].equals(device1); + assert storedDevices[1].equals(device2); assertThrows(DeviceAlreadyExistsException.class, () -> storage.createDevice(device2Duplicate)); From 6106c1a361aff38ac3e048adb7bce599eb9ac3a1 Mon Sep 17 00:00:00 2001 From: KShivendu Date: Wed, 22 Feb 2023 19:05:04 +0530 Subject: [PATCH 15/42] feat: Improve TOTP recipe - Add created_time_ms - Run cron to delete expired used tokens - Add feature to delete all TOTP data on user deletion - Simulate foreign key constraint in totp_used_codes for inmemorydb - Refactor and clean TOTP.java code - Use java-totp to generate secret key and verify code - Add and update tests --- src/main/java/io/supertokens/Main.java | 4 + .../io/supertokens/authRecipe/AuthRecipe.java | 1 + .../DeleteExpiredTotpTokens.java | 9 + .../java/io/supertokens/inmemorydb/Start.java | 34 +-- .../inmemorydb/queries/TOTPQueries.java | 106 +++++++-- src/main/java/io/supertokens/totp/Totp.java | 214 +++++++++++------- .../supertokens/test/totp/TOTPRecipeTest.java | 43 ++-- .../test/totp/TOTPStorageTest.java | 70 ++++-- 8 files changed, 332 insertions(+), 149 deletions(-) diff --git a/src/main/java/io/supertokens/Main.java b/src/main/java/io/supertokens/Main.java index 4fea72159..2a1575498 100644 --- a/src/main/java/io/supertokens/Main.java +++ b/src/main/java/io/supertokens/Main.java @@ -25,6 +25,7 @@ import io.supertokens.cronjobs.deleteExpiredPasswordResetTokens.DeleteExpiredPasswordResetTokens; import io.supertokens.cronjobs.deleteExpiredPasswordlessDevices.DeleteExpiredPasswordlessDevices; import io.supertokens.cronjobs.deleteExpiredSessions.DeleteExpiredSessions; +import io.supertokens.cronjobs.deleteExpiredTotpTokens.DeleteExpiredTotpTokens; import io.supertokens.cronjobs.telemetry.Telemetry; import io.supertokens.emailpassword.PasswordHashing; import io.supertokens.exceptions.QuitProgramException; @@ -204,6 +205,9 @@ private void init() throws IOException { // removes passwordless devices with only expired codes Cronjobs.addCronjob(this, DeleteExpiredPasswordlessDevices.getInstance(this)); + // removes expired TOTP used tokens + Cronjobs.addCronjob(this, DeleteExpiredTotpTokens.getInstance(this)); + // starts Telemetry cronjob if the user has not disabled it if (!Config.getConfig(this).isTelemetryDisabled()) { Cronjobs.addCronjob(this, Telemetry.getInstance(this)); diff --git a/src/main/java/io/supertokens/authRecipe/AuthRecipe.java b/src/main/java/io/supertokens/authRecipe/AuthRecipe.java index 98af4fe7e..3080a48b4 100644 --- a/src/main/java/io/supertokens/authRecipe/AuthRecipe.java +++ b/src/main/java/io/supertokens/authRecipe/AuthRecipe.java @@ -103,6 +103,7 @@ private static void deleteNonAuthRecipeUser(Main main, String userId) throws Sto StorageLayer.getSessionStorage(main).deleteSessionsOfUser(userId); StorageLayer.getEmailVerificationStorage(main).deleteEmailVerificationUserInfo(userId); StorageLayer.getUserRolesStorage(main).deleteAllRolesForUser(userId); + StorageLayer.getTOTPStorage(main).deleteAllDataForUser(userId); } private static void deleteAuthRecipeUser(Main main, String userId) throws StorageQueryException { diff --git a/src/main/java/io/supertokens/cronjobs/deleteExpiredTotpTokens/DeleteExpiredTotpTokens.java b/src/main/java/io/supertokens/cronjobs/deleteExpiredTotpTokens/DeleteExpiredTotpTokens.java index 3faec9c99..aa794fc3c 100644 --- a/src/main/java/io/supertokens/cronjobs/deleteExpiredTotpTokens/DeleteExpiredTotpTokens.java +++ b/src/main/java/io/supertokens/cronjobs/deleteExpiredTotpTokens/DeleteExpiredTotpTokens.java @@ -1,6 +1,7 @@ package io.supertokens.cronjobs.deleteExpiredTotpTokens; import io.supertokens.Main; +import io.supertokens.ResourceDistributor; import io.supertokens.pluginInterface.STORAGE_TYPE; import io.supertokens.pluginInterface.totp.sqlStorage.TOTPSQLStorage; import io.supertokens.cronjobs.CronTask; @@ -15,6 +16,14 @@ private DeleteExpiredTotpTokens(Main main) { super("DeleteExpiredTotpTokens", main); } + public static DeleteExpiredTotpTokens getInstance(Main main) { + ResourceDistributor.SingletonResource instance = main.getResourceDistributor().getResource(RESOURCE_KEY); + if (instance == null) { + instance = main.getResourceDistributor().setResource(RESOURCE_KEY, new DeleteExpiredTotpTokens(main)); + } + return (DeleteExpiredTotpTokens) instance; + } + @Override protected void doTask() throws Exception { if (StorageLayer.getStorage(this.main).getType() != STORAGE_TYPE.SQL) { diff --git a/src/main/java/io/supertokens/inmemorydb/Start.java b/src/main/java/io/supertokens/inmemorydb/Start.java index e2875acd4..d3f9da21b 100644 --- a/src/main/java/io/supertokens/inmemorydb/Start.java +++ b/src/main/java/io/supertokens/inmemorydb/Start.java @@ -1652,16 +1652,9 @@ public boolean markDeviceAsVerified(String userId, String deviceName) try { int updatedCount = TOTPQueries.markDeviceAsVerified(this, userId, deviceName); if (updatedCount == 0) { - TOTPDevice[] devices = TOTPQueries.getDevices(this, userId); - for (TOTPDevice device : devices) { - if (device.deviceName.equals(deviceName) && device.verified) { - return true; // Device was already verified - } - } - // Device was not found: throw new UnknownDeviceException(); } - return false; // Device was marked as verified + return true; // Device was marked as verified } catch (SQLException e) { throw new StorageQueryException(e); } @@ -1711,18 +1704,18 @@ public TOTPDevice[] getDevices(String userId) } @Override - public boolean insertUsedCode(TOTPUsedCode usedCodeObj) + public void insertUsedCode(TOTPUsedCode usedCodeObj) throws StorageQueryException, TotpNotEnabledException { try { - TOTPDevice[] devices = TOTPQueries.getDevices(this, usedCodeObj.userId); - if (devices.length == 0) { + TOTPQueries.insertUsedCode(this, usedCodeObj); + } catch (StorageTransactionLogicException e) { + String message = e.actualException.getMessage(); + if (message + .equals("[SQLITE_CONSTRAINT] Abort due to constraint violation (FOREIGN KEY constraint failed)")) { + // No user/device exists for the given usedCodeObj.userId throw new TotpNotEnabledException(); } - - int insertCount = TOTPQueries.insertUsedCode(this, usedCodeObj); - return insertCount == 1; - } catch (SQLException e) { - throw new StorageQueryException(e); + throw new StorageQueryException(e.actualException); } } @@ -1745,4 +1738,13 @@ public void removeExpiredCodes() throw new StorageQueryException(e); } } + + @Override + public void deleteAllDataForUser(String userId) throws StorageQueryException { + try { + TOTPQueries.deleteAllDataForUser(this, userId); + } catch (StorageTransactionLogicException e) { + throw new StorageQueryException(e); + } + } } diff --git a/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java b/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java index b409f9ec8..06a294ea2 100644 --- a/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java +++ b/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java @@ -38,6 +38,7 @@ public static String getQueryToCreateUsedCodesTable(Start start) { return "CREATE TABLE IF NOT EXISTS " + Config.getConfig(start).getTotpUsedCodesTable() + " (" + "user_id VARCHAR(128) NOT NULL, " + "device_name VARCHAR(256), " + "code CHAR(6) NOT NULL," + "is_valid_code BOOLEAN NOT NULL," + + "created_time_ms BIGINT UNSIGNED NOT NULL," + "expiry_time_ms BIGINT UNSIGNED NOT NULL," + "FOREIGN KEY (user_id) REFERENCES " + Config.getConfig(start).getTotpUsersTable() + "(user_id) ON DELETE CASCADE);"; @@ -46,6 +47,7 @@ public static String getQueryToCreateUsedCodesTable(Start start) { public static String getQueryToCreateUsedCodesIndex(Start start) { return "CREATE INDEX IF NOT EXISTS totp_used_codes_expiry_time_ms_index ON " + Config.getConfig(start).getTotpUsedCodesTable() + " (expiry_time_ms)"; + // TODO: Create index on created_time_ms as well } public static int insertUser_Transaction(Start start, Connection con, String userId) @@ -94,7 +96,7 @@ public static void createDeviceAndUser(Start start, TOTPDevice device) public static int markDeviceAsVerified(Start start, String userId, String deviceName) throws StorageQueryException, SQLException { String QUERY = "UPDATE " + Config.getConfig(start).getTotpUserDevicesTable() - + " SET verified = true WHERE user_id = ? AND device_name = ? AND verified = false;"; + + " SET verified = true WHERE user_id = ? AND device_name = ?"; return update(start, QUERY, pst -> { pst.setString(1, userId); pst.setString(2, deviceName); @@ -137,7 +139,7 @@ public static int deleteDevice(Start start, String userId, String deviceName) try { int deletedCount = deleteDevice_Transaction(start, sqlCon, userId, deviceName); if (deletedCount > 0) { - // some device deleted. Check if user has any other device left: + // Some device was deleted. Check if user has any other device left: int devicesCount = getDevicesCount_Transaction(start, sqlCon, userId); if (devicesCount == 0) { // no device left. delete user @@ -190,22 +192,69 @@ public static int getDevicesCount_Transaction(Start start, Connection con, Strin }); } - public static int insertUsedCode(Start start, TOTPUsedCode code) throws SQLException, StorageQueryException { + private static int insertUsedCode_Transaction(Start start, Connection con, TOTPUsedCode code) + throws SQLException, StorageQueryException { String QUERY = "INSERT INTO " + Config.getConfig(start).getTotpUsedCodesTable() - + " (user_id, code, is_valid_code, expiry_time_ms) VALUES (?, ?, ?, ?);"; + + " (user_id, code, is_valid_code, expiry_time_ms, created_time_ms) VALUES (?, ?, ?, ?, ?);"; - return update(start, QUERY, pst -> { + return update(con, QUERY, pst -> { pst.setString(1, code.userId); pst.setString(2, code.code); pst.setBoolean(3, code.isValidCode); pst.setLong(4, code.expiryTime); + pst.setLong(5, code.createdTime); }); } + public static void insertUsedCode(Start start, TOTPUsedCode code) + throws StorageQueryException, StorageTransactionLogicException { + start.startTransaction(con -> { + Connection sqlCon = (Connection) con.getConnection(); + + try { + // Check if user exists or not (if no device exists, user does not exist) + // NOTE: This step is required only for in-memory db. + int devicesCount = getDevicesCount_Transaction(start, sqlCon, code.userId); + if (devicesCount == 0) { + // no device left. transaction cannot be completed. + // foreign key constraint will fail. + throw new SQLException( + "[SQLITE_CONSTRAINT] Abort due to constraint violation (FOREIGN KEY constraint failed)"); + } + + insertUsedCode_Transaction(start, sqlCon, code); + sqlCon.commit(); + } catch (SQLException e) { + throw new StorageTransactionLogicException(e); + } + + return null; + }); + + // String QUERY = "INSERT INTO " + + // Config.getConfig(start).getTotpUsedCodesTable() + // + " (user_id, code, is_valid_code, expiry_time_ms, created_time_ms) VALUES + // (?, ?, ?, ?, ?);"; + + // return update(start, QUERY, pst -> { + // pst.setString(1, code.userId); + // pst.setString(2, code.code); + // pst.setBoolean(3, code.isValidCode); + // pst.setLong(4, code.expiryTime); + // pst.setLong(5, code.createdTime); + // }); + } + public static TOTPUsedCode[] getUsedCodes(Start start, String userId) throws SQLException, StorageQueryException { String QUERY = "SELECT * FROM " + Config.getConfig(start).getTotpUsedCodesTable() - + " WHERE user_id = ? AND expiry_time_ms > ?;"; + + " WHERE user_id = ? AND expiry_time_ms > ? ORDER BY created_time_ms DESC;"; // FIXME: Should be based + // on creation_time + // because + // of different devices + // having different expiry + // times (bcoz of period + // and skew values) return execute(start, QUERY, pst -> { pst.setString(1, userId); pst.setLong(2, System.currentTimeMillis()); @@ -227,15 +276,45 @@ public static int removeExpiredCodes(Start start) return update(start, QUERY, pst -> pst.setLong(1, System.currentTimeMillis())); } - public static int removeUsedCodes(Start start, String userId) - throws StorageQueryException, SQLException { - // Remove codes where userId matches the given userId - // ONLY required for inmemorydb, as it does not support foreign key constraints. + public static int deleteAllDevices_Transaction(Start start, Connection con, String userId) + throws SQLException, StorageQueryException { + String QUERY = "DELETE FROM " + Config.getConfig(start).getTotpUserDevicesTable() + + " WHERE user_id = ?;"; + return update(con, QUERY, pst -> pst.setString(1, userId)); + } + + public static int deleteAllUsedCodes_Transaction(Start start, Connection con, String userId) + throws SQLException, StorageQueryException { String QUERY = "DELETE FROM " + Config.getConfig(start).getTotpUsedCodesTable() + " WHERE user_id = ?;"; + return update(con, QUERY, pst -> pst.setString(1, userId)); + } - return update(start, QUERY, pst -> { - pst.setString(1, userId); + public static int deleteUser_Transaction(Start start, Connection con, String userId) + throws SQLException, StorageQueryException { + String QUERY = "DELETE FROM " + Config.getConfig(start).getTotpUsersTable() + + " WHERE user_id = ?;"; + return update(con, QUERY, pst -> pst.setString(1, userId)); + } + + public static void deleteAllDataForUser(Start start, String userId) + throws StorageQueryException, StorageTransactionLogicException { + start.startTransaction(con -> { + Connection sqlCon = (Connection) con.getConnection(); + + try { + // NOTE: These two steps are required only for in-memory db. + // Since foreign key constraints are not supported in in-memory db. + deleteAllDevices_Transaction(start, sqlCon, userId); + deleteAllUsedCodes_Transaction(start, sqlCon, userId); + + deleteUser_Transaction(start, sqlCon, userId); + sqlCon.commit(); + } catch (SQLException e) { + throw new StorageTransactionLogicException(e); + } + + return null; }); } @@ -277,7 +356,8 @@ public TOTPUsedCode map(ResultSet result) throws SQLException { result.getString("user_id"), result.getString("code"), result.getBoolean("is_valid_code"), - result.getLong("expiry_time_ms")); + result.getLong("expiry_time_ms"), + result.getLong("created_time_ms")); } } } diff --git a/src/main/java/io/supertokens/totp/Totp.java b/src/main/java/io/supertokens/totp/Totp.java index 7b2c1bd9c..4cfdb5834 100644 --- a/src/main/java/io/supertokens/totp/Totp.java +++ b/src/main/java/io/supertokens/totp/Totp.java @@ -1,8 +1,24 @@ package io.supertokens.totp; +import java.security.InvalidKeyException; +import java.security.Key; +import java.security.NoSuchAlgorithmException; +import java.security.SecureRandom; +import java.time.Instant; +import java.util.Arrays; +import java.util.Base64; + +import javax.crypto.KeyGenerator; +import javax.crypto.Mac; +import javax.crypto.spec.SecretKeySpec; + +import java.util.Base64; + import io.supertokens.Main; +import com.eatthepath.otp.TimeBasedOneTimePasswordGenerator; import io.supertokens.pluginInterface.exceptions.StorageQueryException; import io.supertokens.pluginInterface.totp.TOTPDevice; +import io.supertokens.pluginInterface.totp.TOTPStorage; import io.supertokens.pluginInterface.totp.TOTPUsedCode; import io.supertokens.pluginInterface.totp.exception.DeviceAlreadyExistsException; import io.supertokens.pluginInterface.totp.exception.TotpNotEnabledException; @@ -14,52 +30,94 @@ public class Totp { - public static String generateSecret() { - return "XXXX"; + public static String generateSecret() throws Exception { + final String TOTP_ALGORITHM = "HmacSHA1"; + + try { + final KeyGenerator keyGenerator = KeyGenerator.getInstance(TOTP_ALGORITHM); + keyGenerator.init(160); // 160 bits = 20 bytes + + // FIXME: Should return base32 or base16 + // Return base64 string of the secret key: + return Base64.getEncoder().encodeToString(keyGenerator.generateKey().getEncoded()); + } catch (NoSuchAlgorithmException e) { + throw new Exception("TOTP algorithm not found"); + } } public static boolean checkCode(TOTPDevice device, String code) { - if (code.startsWith("XXXX")) { - return true; + final TimeBasedOneTimePasswordGenerator totp = new TimeBasedOneTimePasswordGenerator(); + + byte[] keyBytes = Base64.getDecoder().decode(device.secretKey); + Key key = new SecretKeySpec(keyBytes, "HmacSHA1"); + + final int period = device.period; + final int skew = device.skew; + + // Check if code is valid for any of the time periods in the skew: + for (int i = -skew; i <= skew; i++) { + try { + if (totp.generateOneTimePasswordString(key, Instant.now().plusSeconds(i * period)).equals(code)) { + return true; + } + } catch (InvalidKeyException e) { + return false; + } } + return false; } - public static String createDevice(Main main, String userId, String deviceName, int skew, int period) - throws StorageQueryException, DeviceAlreadyExistsException { + public static String registerDevice(Main main, String userId, String deviceName, int skew, int period) + throws StorageQueryException, DeviceAlreadyExistsException, Exception { TOTPSQLStorage totpStorage = StorageLayer.getTOTPStorage(main); - String secret = generateSecret(); + String secret = generateSecret(); // TODO: should we handle Exception differently? TOTPDevice device = new TOTPDevice(userId, deviceName, secret, period, skew, false); totpStorage.createDevice(device); return secret; } - public static boolean verifyDevice(Main main, String userId, String deviceName, String code) - throws StorageQueryException, TotpNotEnabledException, UnknownDeviceException, InvalidTotpException { - // Here boolean return value tells whether the device was already verified + private static void checkAndStoreCode(TOTPStorage totpStorage, String userId, TOTPDevice[] devices, + TOTPUsedCode[] usedCodes, String code) + throws InvalidTotpException, StorageQueryException, TotpNotEnabledException { - TOTPSQLStorage totpStorage = StorageLayer.getTOTPStorage(main); + // Check if the code is valid for any device: + boolean isValid = false; + TOTPDevice matchingDevice = null; + for (TOTPDevice device : devices) { + // Check if the code is valid for this device: + if (checkCode(device, code)) { + isValid = true; + matchingDevice = device; + break; + } + } - boolean deviceAlreadyVerified = totpStorage.markDeviceAsVerified(userId, deviceName); + // Check if the code has been successfully used by the user (for any device): + for (TOTPUsedCode usedCode : usedCodes) { + if (usedCode.code.equals(code)) { + throw new InvalidTotpException(); + } + } - if (deviceAlreadyVerified) - return true; + // Insert the code into the list of used codes: + long now = System.currentTimeMillis(); + int expireInSec = isValid ? matchingDevice.period * (2 * matchingDevice.skew + 1) : 60 * 5; - try { - verifyCode(main, userId, code, true); - return false; - } catch (LimitReachedException e) { - throw new InvalidTotpException(); - } + TOTPUsedCode newCode = new TOTPUsedCode(userId, code, isValid, now + 1000 * expireInSec, now); + totpStorage.insertUsedCode(newCode); } - public static void verifyCode(Main main, String userId, String code, boolean allowUnverifiedDevices) - throws StorageQueryException, TotpNotEnabledException, InvalidTotpException, - LimitReachedException { + public static boolean verifyDevice(Main main, String userId, String deviceName, String code) + throws StorageQueryException, TotpNotEnabledException, UnknownDeviceException, InvalidTotpException { + // Here boolean return value tells whether the device has been + // newly verified (true) OR it was already verified (false) + TOTPSQLStorage totpStorage = StorageLayer.getTOTPStorage(main); + TOTPDevice matchingDevice = null; // Check if the user has any devices: TOTPDevice[] devices = totpStorage.getDevices(userId); @@ -67,86 +125,66 @@ public static void verifyCode(Main main, String userId, String code, boolean all throw new TotpNotEnabledException(); } - // If allowUnverifiedDevices is false, then remove all unverified devices from - // the list: - if (!allowUnverifiedDevices) { - int verifiedDeviceCount = 0; - for (TOTPDevice device : devices) { - if (device.verified) { - verifiedDeviceCount++; - } - } - - TOTPDevice[] verifiedDevices = new TOTPDevice[verifiedDeviceCount]; - int index = 0; - for (TOTPDevice device : devices) { + // Check if the requested device exists: + for (TOTPDevice device : devices) { + if (device.deviceName.equals(deviceName)) { + matchingDevice = device; if (device.verified) { - verifiedDevices[index] = device; - index++; + return false; } + break; } - - devices = verifiedDevices; } - // Try different devices until we find one that works: - boolean isValid = false; - TOTPDevice matchingDevice = null; - for (TOTPDevice device : devices) { - // Check if the code is valid for this device: - if (checkCode(device, code)) { - isValid = true; - matchingDevice = device; - break; - } + // No device found: + if (matchingDevice == null) { + throw new UnknownDeviceException(); } - // Check if the code has been successfully used by the user (for any of their - // devices): + // If the device is not verified, check if the code is valid and unused. + // If it is successful, mark the device as verified. TOTPUsedCode[] usedCodes = totpStorage.getNonExpiredUsedCodes(userId); - for (TOTPUsedCode usedCode : usedCodes) { - if (usedCode.code.equals(code)) { // FIXME: Only check for the same device for better UX? - throw new InvalidTotpException(); - } - } + checkAndStoreCode(totpStorage, userId, new TOTPDevice[] { matchingDevice }, usedCodes, code); + totpStorage.markDeviceAsVerified(userId, deviceName); + return true; + } - // Insert the code into the list of used codes: - TOTPUsedCode newCode = null; - if (matchingDevice == null) { - // TODO: Verify that this doesn't pile up OR gets deleted very quickly: - int expireAfterSeconds = 60 * 5; // 5 minutes - newCode = new TOTPUsedCode(userId, code, isValid, - System.currentTimeMillis() + 1000 * expireAfterSeconds); - } else { - int expireAfterSeconds = matchingDevice.period * (2 * matchingDevice.period + 1); - newCode = new TOTPUsedCode(userId, code, isValid, - System.currentTimeMillis() + 1000 * expireAfterSeconds); + public static void verifyCode(Main main, String userId, String code, boolean allowUnverifiedDevices) + throws StorageQueryException, TotpNotEnabledException, InvalidTotpException, + LimitReachedException { + TOTPSQLStorage totpStorage = StorageLayer.getTOTPStorage(main); + + // Check if the user has any devices: + TOTPDevice[] devices = totpStorage.getDevices(userId); + if (devices.length == 0) { + throw new TotpNotEnabledException(); } - totpStorage.insertUsedCode(newCode); - if (isValid) { - return; + // Filter out unverified devices: + if (!allowUnverifiedDevices) { + devices = Arrays.stream(devices).filter(device -> device.verified).toArray(TOTPDevice[]::new); } - // Now we know that the code is invalid. - - // Check if last N codes are all invalid: - // Note that usedCodes will get updated when: - // - A valid code is used: It will break the chain of invalid codes. - // - Cron job runs: deletes expired codes every 1 hour - int WINDOW_SIZE = 3; - int invalidCodes = 0; - for (int i = usedCodes.length - 1; i >= 0 && i >= usedCodes.length - WINDOW_SIZE; i--) { - if (!usedCodes[i].isValidCode) { - invalidCodes++; + TOTPUsedCode[] usedCodes = totpStorage.getNonExpiredUsedCodes(userId); + + try { + checkAndStoreCode(totpStorage, userId, devices, usedCodes, code); + } catch (InvalidTotpException e) { + // Now we know that the code is invalid. + // Check if latest 3 codes are all invalid: + + // Note: usedCodes will get updated when + // - A valid code is used: It will break the chain of invalid codes. + // - Cron job runs: deletes expired codes every hour + + // All the latest 3 codes are invalid: + if (Arrays.stream(usedCodes).limit(3).allMatch(usedCode -> !usedCode.isValidCode)) { + throw new LimitReachedException(); } - } - if (invalidCodes == WINDOW_SIZE) { - throw new LimitReachedException(); - } - // Code is invalid and the user has not exceeded the limit: - throw new InvalidTotpException(); + // Code is invalid but the user has not exceeded the limit: + throw e; + } } public static void deleteDevice(Main main, String userId, String deviceName) diff --git a/src/test/java/io/supertokens/test/totp/TOTPRecipeTest.java b/src/test/java/io/supertokens/test/totp/TOTPRecipeTest.java index e74a42722..e14c5dfd9 100644 --- a/src/test/java/io/supertokens/test/totp/TOTPRecipeTest.java +++ b/src/test/java/io/supertokens/test/totp/TOTPRecipeTest.java @@ -34,6 +34,7 @@ import io.supertokens.totp.Totp; import io.supertokens.totp.exceptions.InvalidTotpException; +import io.supertokens.totp.exceptions.LimitReachedException; import io.supertokens.pluginInterface.totp.TOTPStorage; import io.supertokens.pluginInterface.totp.exception.DeviceAlreadyExistsException; import io.supertokens.pluginInterface.totp.exception.TotpNotEnabledException; @@ -84,24 +85,32 @@ public void createDevice() throws Exception { Main main = result.process.getProcess(); // Create device - String secret = Totp.createDevice(main, "userId", "deviceName", 1, 30); + String secret = Totp.registerDevice(main, "user", "device1", 1, 30); assert secret != ""; // Create same device again (should fail) - assertThrows(DeviceAlreadyExistsException.class, () -> Totp.createDevice(main, "userId", "deviceName", 1, 30)); + assertThrows(DeviceAlreadyExistsException.class, + () -> Totp.registerDevice(main, "user", "device1", 1, 30)); } public void triggerRateLimit(Main main) throws Exception { - for (int i = 0; i < 4; i++) { - assertThrows( - InvalidTotpException.class, - () -> Totp.verifyCode(main, "user", "wrong-code", true)); - } - - // 5th attempt should fail with rate limiting error: + // First 2 attempts should fail with invalid code: + assertThrows( + InvalidTotpException.class, + () -> Totp.verifyCode(main, "user", "wrong-code-1", true)); assertThrows( InvalidTotpException.class, - () -> Totp.verifyCode(main, "user", "XXXX-code", true)); + () -> Totp.verifyCode(main, "user", "wrong-code-2", true)); + + // 3th attempt should fail with rate limiting error: + assertThrows( + LimitReachedException.class, + () -> Totp.verifyCode(main, "user", "wrong-code-3", true)); + } + + @Test + public void createDeviceAndVerifyCodeAgainstUnverifiedDevices() throws Exception { + } @Test @@ -110,7 +119,7 @@ public void createDeviceAndVerifyCode() throws Exception { Main main = result.process.getProcess(); // Create device - String secret = Totp.createDevice(main, "userId", "deviceName", 1, 30); + String secret = Totp.registerDevice(main, "user", "deviceName", 1, 30); // Try login with non-existent user: assertThrows(TotpNotEnabledException.class, @@ -124,12 +133,16 @@ public void createDeviceAndVerifyCode() throws Exception { assertThrows(InvalidTotpException.class, () -> Totp.verifyCode(main, "user", "XXXX-code", false)); + // Try logging in with same again code but allowUnverifiedDevice = true: + // TODO: Think if this is correct + assertThrows(InvalidTotpException.class, () -> Totp.verifyCode(main, "user", "XXXX-code", true)); + // Successfully login: - Totp.verifyCode(main, "user", "XXXX-code", true); + Totp.verifyCode(main, "user", "XXXX-code2", true); // Now try again with same code: assertThrows( InvalidTotpException.class, - () -> Totp.verifyCode(main, "user", "XXXX-code", true)); + () -> Totp.verifyCode(main, "user", "XXXX-code2", true)); // Trigger rate limiting and fix it with a correct code: { @@ -154,7 +167,7 @@ public void createAndVerifyDevice() throws Exception { // Create device // FIXME: Use secret to generate actual TOTP code - String secret = Totp.createDevice(main, "userId", "deviceName", 1, 30); + String secret = Totp.registerDevice(main, "userId", "deviceName", 1, 30); // Try verify non-existent user: assertThrows(TotpNotEnabledException.class, @@ -168,7 +181,7 @@ public void createAndVerifyDevice() throws Exception { assertThrows(InvalidTotpException.class, () -> Totp.verifyDevice(main, "userId", "deviceName", "wrong-code")); // Verify device with correct code - boolean deviceAlreadyVerified = Totp.verifyDevice(main, "userId", "deviceName", "XXXX"); + boolean deviceAlreadyVerified = Totp.verifyDevice(main, "userId", "deviceName", "XXXX-correct"); assert !deviceAlreadyVerified; // Verify again with same correct code: diff --git a/src/test/java/io/supertokens/test/totp/TOTPStorageTest.java b/src/test/java/io/supertokens/test/totp/TOTPStorageTest.java index 24a537b7f..a1bd42f37 100644 --- a/src/test/java/io/supertokens/test/totp/TOTPStorageTest.java +++ b/src/test/java/io/supertokens/test/totp/TOTPStorageTest.java @@ -140,11 +140,11 @@ public void deleteDeviceTests() throws Exception { assert (storedDevices.length == 1); // device2 should still be there long nextDay = System.currentTimeMillis() + 1000 * 60 * 60 * 24; // 1 day from now - + long now = System.currentTimeMillis(); // Deleting all devices of a user should delete all related codes: { - TOTPUsedCode validCode = new TOTPUsedCode("user", "valid-code", true, nextDay); - TOTPUsedCode invalidCode = new TOTPUsedCode("user", "invalid-code", false, nextDay); + TOTPUsedCode validCode = new TOTPUsedCode("user", "valid-code", true, nextDay, now); + TOTPUsedCode invalidCode = new TOTPUsedCode("user", "invalid-code", false, nextDay, now); storage.insertUsedCode(validCode); storage.insertUsedCode(invalidCode); @@ -220,13 +220,12 @@ public void insertUsedCodeTest() throws Exception { // Insert a long lasting valid code and check that it's returned when queried: { TOTPDevice device = new TOTPDevice("user", "device", "secretKey", 30, 1, false); - TOTPUsedCode code = new TOTPUsedCode("user", "1234", true, nextDay); + TOTPUsedCode code = new TOTPUsedCode("user", "1234", true, nextDay, System.currentTimeMillis()); storage.createDevice(device); - boolean isInserted = storage.insertUsedCode(code); + storage.insertUsedCode(code); TOTPUsedCode[] usedCodes = storage.getNonExpiredUsedCodes("user"); - assert (isInserted); assert (usedCodes.length == 1); assert usedCodes[0].equals(code); } @@ -235,19 +234,21 @@ public void insertUsedCodeTest() throws Exception { { storage.deleteDevice("user", "device"); assertThrows(TotpNotEnabledException.class, - () -> storage.insertUsedCode(new TOTPUsedCode("user", "1234", true, nextDay))); + () -> storage.insertUsedCode( + new TOTPUsedCode("user", "1234", true, nextDay, System.currentTimeMillis()))); } // Try to insert code after user has atleast one device (i.e. TOTP enabled) { TOTPDevice newDevice = new TOTPDevice("user", "new-device", "secretKey", 30, 1, false); storage.createDevice(newDevice); - storage.insertUsedCode(new TOTPUsedCode("user", "1234", true, nextDay)); + storage.insertUsedCode(new TOTPUsedCode("user", "1234", true, nextDay, System.currentTimeMillis())); } // Try to insert code when user doesn't exist: assertThrows(TotpNotEnabledException.class, - () -> storage.insertUsedCode(new TOTPUsedCode("non-existent-user", "1234", true, nextDay))); + () -> storage.insertUsedCode( + new TOTPUsedCode("non-existent-user", "1234", true, nextDay, System.currentTimeMillis()))); } @Test @@ -260,12 +261,13 @@ public void getNonExpiredUsedCodesTest() throws Exception { long nextDay = System.currentTimeMillis() + 1000 * 60 * 60 * 24; // 1 day from now long prevDay = System.currentTimeMillis() - 1000 * 60 * 60 * 24; // 1 day ago + long now = System.currentTimeMillis(); TOTPDevice device = new TOTPDevice("user", "device", "secretKey", 30, 1, false); - TOTPUsedCode validCode = new TOTPUsedCode("user", "code1", true, nextDay); - TOTPUsedCode invalidCode = new TOTPUsedCode("user", "code2", false, nextDay); - TOTPUsedCode expiredCode = new TOTPUsedCode("user", "expired-code", true, prevDay); - TOTPUsedCode expiredInvalidCode = new TOTPUsedCode("user", "expired-invalid-code", false, prevDay); + TOTPUsedCode validCode = new TOTPUsedCode("user", "code1", true, nextDay, now); + TOTPUsedCode invalidCode = new TOTPUsedCode("user", "code2", false, nextDay, now); + TOTPUsedCode expiredCode = new TOTPUsedCode("user", "expired-code", true, prevDay, now); + TOTPUsedCode expiredInvalidCode = new TOTPUsedCode("user", "expired-invalid-code", false, prevDay, now); storage.createDevice(device); storage.insertUsedCode(validCode); @@ -284,14 +286,15 @@ public void removeExpiredCodesTest() throws Exception { TestSetupResult result = setup(); TOTPStorage storage = result.storage; + long now = System.currentTimeMillis(); long nextDay = System.currentTimeMillis() + 1000 * 60 * 60 * 24; // 1 day from now long halfSecond = System.currentTimeMillis() + 500; // 500ms from now TOTPDevice device = new TOTPDevice("user", "device", "secretKey", 30, 1, false); - TOTPUsedCode validCodeToLive = new TOTPUsedCode("user", "valid-code", true, nextDay); - TOTPUsedCode invalidCodeToLive = new TOTPUsedCode("user", "invalid-code", false, nextDay); - TOTPUsedCode validCodeToExpire = new TOTPUsedCode("user", "valid-code", true, halfSecond); - TOTPUsedCode invalidCodeToExpire = new TOTPUsedCode("user", "invalid-code", false, halfSecond); + TOTPUsedCode validCodeToLive = new TOTPUsedCode("user", "valid-code", true, nextDay, now); + TOTPUsedCode invalidCodeToLive = new TOTPUsedCode("user", "invalid-code", false, nextDay, now); + TOTPUsedCode validCodeToExpire = new TOTPUsedCode("user", "valid-code", true, halfSecond, now); + TOTPUsedCode invalidCodeToExpire = new TOTPUsedCode("user", "invalid-code", false, halfSecond, now); storage.createDevice(device); storage.insertUsedCode(validCodeToLive); @@ -312,4 +315,37 @@ public void removeExpiredCodesTest() throws Exception { assert (usedCodes[0].equals(validCodeToLive)); assert (usedCodes[1].equals(invalidCodeToLive)); } + + @Test + public void deleteAllDataForUserTest() throws Exception { + TestSetupResult result = setup(); + TOTPStorage storage = result.storage; + + long now = System.currentTimeMillis(); + long nextDay = now + 1000 * 60 * 60 * 24; // 1 day from now + + TOTPDevice device1 = new TOTPDevice("user", "d1", "secretKey", 30, 1, false); + TOTPDevice device2 = new TOTPDevice("user", "d2", "secretKey", 30, 1, false); + TOTPUsedCode validCode = new TOTPUsedCode("user", "d1-valid", true, nextDay, now); + TOTPUsedCode invalidCode = new TOTPUsedCode("user", "invalid-code", false, nextDay, now); + + storage.createDevice(device1); + storage.createDevice(device2); + storage.insertUsedCode(validCode); + storage.insertUsedCode(invalidCode); + + TOTPDevice[] storedDevices = storage.getDevices("user"); + TOTPUsedCode[] usedCodes = storage.getNonExpiredUsedCodes("user"); + + assert (storedDevices.length == 2); + assert (usedCodes.length == 2); + + storage.deleteAllDataForUser("user"); + + storedDevices = storage.getDevices("user"); + usedCodes = storage.getNonExpiredUsedCodes("user"); + + assert (storedDevices.length == 0); + assert (usedCodes.length == 0); + } } From c26ae127c215c4a1cd0300f0e5b262b651994f5f Mon Sep 17 00:00:00 2001 From: KShivendu Date: Thu, 23 Feb 2023 17:24:24 +0530 Subject: [PATCH 16/42] feat: Improve TOTP recipe - Add config for totp_rate_limit_window_size - Improve function names and return types - Use `is_valid` for totp_used_code - Expose function to generate TOTP code for tests to use --- .../io/supertokens/config/CoreConfig.java | 30 ++++++++++---- .../java/io/supertokens/inmemorydb/Start.java | 6 +-- .../inmemorydb/queries/TOTPQueries.java | 36 +++++------------ src/main/java/io/supertokens/totp/Totp.java | 40 +++++++++++++++---- .../supertokens/test/totp/TOTPRecipeTest.java | 1 + .../test/totp/TOTPStorageTest.java | 2 +- 6 files changed, 71 insertions(+), 44 deletions(-) diff --git a/src/main/java/io/supertokens/config/CoreConfig.java b/src/main/java/io/supertokens/config/CoreConfig.java index 516d771ba..83e97a8ce 100644 --- a/src/main/java/io/supertokens/config/CoreConfig.java +++ b/src/main/java/io/supertokens/config/CoreConfig.java @@ -56,6 +56,9 @@ public class CoreConfig { @JsonProperty private long passwordless_code_lifetime = 900000; // in MS + @JsonProperty + private int totp_rate_limit_window_size = 5; + private final String logDefault = "asdkfahbdfk3kjHS"; @JsonProperty private String info_log_path = logDefault; @@ -106,10 +109,13 @@ public class CoreConfig { private int bcrypt_log_rounds = 11; // TODO: add https in later version -// # (OPTIONAL) boolean value (true or false). Set to true if you want to enable https requests to SuperTokens. -// # If you are not running SuperTokens within a closed network along with your API process, for -// # example if you are using multiple cloud vendors, then it is recommended to set this to true. -// # webserver_https_enabled: + // # (OPTIONAL) boolean value (true or false). Set to true if you want to enable + // https requests to SuperTokens. + // # If you are not running SuperTokens within a closed network along with your + // API process, for + // # example if you are using multiple cloud vendors, then it is recommended to + // set this to true. + // # webserver_https_enabled: @JsonProperty private boolean webserver_https_enabled = false; @@ -191,9 +197,11 @@ public enum PASSWORD_HASHING_ALG { } public int getArgon2HashingPoolSize() { - // the reason we do Math.max below is that if the password hashing algo is bcrypt, + // the reason we do Math.max below is that if the password hashing algo is + // bcrypt, // then we don't check the argon2 hashing pool size config at all. In this case, - // if the user gives a <= 0 number, it crashes the core (since it creates a blockedqueue in PaswordHashing + // if the user gives a <= 0 number, it crashes the core (since it creates a + // blockedqueue in PaswordHashing // .java with length <= 0). So we do a Math.max return Math.max(1, argon2_hashing_pool_size); } @@ -266,6 +274,10 @@ public long getPasswordlessCodeLifetime() { return passwordless_code_lifetime; } + public int getTotpRateLimitWindowSize() { + return totp_rate_limit_window_size; + } + public boolean isTelemetryDisabled() { return disable_telemetry; } @@ -384,6 +396,10 @@ void validateAndInitialise(Main main) throws IOException { throw new QuitProgramException("'passwordless_max_code_input_attempts' must be > 0"); } + if (totp_rate_limit_window_size <= 0) { + throw new QuitProgramException("'totp_rate_limit_window_size' must be > 0"); + } + if (max_server_pool_size <= 0) { throw new QuitProgramException("'max_server_pool_size' must be >= 1. The config file can be found here: " + getConfigFileLocation(main)); @@ -475,4 +491,4 @@ void validateAndInitialise(Main main) throws IOException { } } -} \ No newline at end of file +} diff --git a/src/main/java/io/supertokens/inmemorydb/Start.java b/src/main/java/io/supertokens/inmemorydb/Start.java index d3f9da21b..9f35ab84c 100644 --- a/src/main/java/io/supertokens/inmemorydb/Start.java +++ b/src/main/java/io/supertokens/inmemorydb/Start.java @@ -1647,14 +1647,14 @@ public void createDevice(TOTPDevice device) throws StorageQueryException, Device } @Override - public boolean markDeviceAsVerified(String userId, String deviceName) + public void markDeviceAsVerified(String userId, String deviceName) throws StorageQueryException, UnknownDeviceException { try { int updatedCount = TOTPQueries.markDeviceAsVerified(this, userId, deviceName); if (updatedCount == 0) { throw new UnknownDeviceException(); } - return true; // Device was marked as verified + return; // Device was marked as verified } catch (SQLException e) { throw new StorageQueryException(e); } @@ -1723,7 +1723,7 @@ public void insertUsedCode(TOTPUsedCode usedCodeObj) public TOTPUsedCode[] getNonExpiredUsedCodes(String userId) throws StorageQueryException { try { - return TOTPQueries.getUsedCodes(this, userId); + return TOTPQueries.getNonExpiredUsedCodesDescOrder(this, userId); } catch (SQLException e) { throw new StorageQueryException(e); } diff --git a/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java b/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java index 06a294ea2..b3892492d 100644 --- a/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java +++ b/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java @@ -37,7 +37,7 @@ public static String getQueryToCreateUserDevicesTable(Start start) { public static String getQueryToCreateUsedCodesTable(Start start) { return "CREATE TABLE IF NOT EXISTS " + Config.getConfig(start).getTotpUsedCodesTable() + " (" + "user_id VARCHAR(128) NOT NULL, " + "device_name VARCHAR(256), " - + "code CHAR(6) NOT NULL," + "is_valid_code BOOLEAN NOT NULL," + + "code CHAR(6) NOT NULL," + "is_valid BOOLEAN NOT NULL," + "created_time_ms BIGINT UNSIGNED NOT NULL," + "expiry_time_ms BIGINT UNSIGNED NOT NULL," + "FOREIGN KEY (user_id) REFERENCES " + Config.getConfig(start).getTotpUsersTable() @@ -195,12 +195,12 @@ public static int getDevicesCount_Transaction(Start start, Connection con, Strin private static int insertUsedCode_Transaction(Start start, Connection con, TOTPUsedCode code) throws SQLException, StorageQueryException { String QUERY = "INSERT INTO " + Config.getConfig(start).getTotpUsedCodesTable() - + " (user_id, code, is_valid_code, expiry_time_ms, created_time_ms) VALUES (?, ?, ?, ?, ?);"; + + " (user_id, code, is_valid, expiry_time_ms, created_time_ms) VALUES (?, ?, ?, ?, ?);"; return update(con, QUERY, pst -> { pst.setString(1, code.userId); pst.setString(2, code.code); - pst.setBoolean(3, code.isValidCode); + pst.setBoolean(3, code.isValid); pst.setLong(4, code.expiryTime); pst.setLong(5, code.createdTime); }); @@ -230,31 +230,17 @@ public static void insertUsedCode(Start start, TOTPUsedCode code) return null; }); - - // String QUERY = "INSERT INTO " + - // Config.getConfig(start).getTotpUsedCodesTable() - // + " (user_id, code, is_valid_code, expiry_time_ms, created_time_ms) VALUES - // (?, ?, ?, ?, ?);"; - - // return update(start, QUERY, pst -> { - // pst.setString(1, code.userId); - // pst.setString(2, code.code); - // pst.setBoolean(3, code.isValidCode); - // pst.setLong(4, code.expiryTime); - // pst.setLong(5, code.createdTime); - // }); } - public static TOTPUsedCode[] getUsedCodes(Start start, String userId) throws SQLException, StorageQueryException { + /** + * Query to get all non expired used codes for a user in descending order of + * creation time. + */ + public static TOTPUsedCode[] getNonExpiredUsedCodesDescOrder(Start start, String userId) + throws SQLException, StorageQueryException { String QUERY = "SELECT * FROM " + Config.getConfig(start).getTotpUsedCodesTable() - + " WHERE user_id = ? AND expiry_time_ms > ? ORDER BY created_time_ms DESC;"; // FIXME: Should be based - // on creation_time - // because - // of different devices - // having different expiry - // times (bcoz of period - // and skew values) + + " WHERE user_id = ? AND expiry_time_ms > ? ORDER BY created_time_ms DESC;"; return execute(start, QUERY, pst -> { pst.setString(1, userId); pst.setLong(2, System.currentTimeMillis()); @@ -355,7 +341,7 @@ public TOTPUsedCode map(ResultSet result) throws SQLException { return new TOTPUsedCode( result.getString("user_id"), result.getString("code"), - result.getBoolean("is_valid_code"), + result.getBoolean("is_valid"), result.getLong("expiry_time_ms"), result.getLong("created_time_ms")); } diff --git a/src/main/java/io/supertokens/totp/Totp.java b/src/main/java/io/supertokens/totp/Totp.java index 4cfdb5834..36f82b364 100644 --- a/src/main/java/io/supertokens/totp/Totp.java +++ b/src/main/java/io/supertokens/totp/Totp.java @@ -12,9 +12,13 @@ import javax.crypto.Mac; import javax.crypto.spec.SecretKeySpec; +import org.jetbrains.annotations.TestOnly; + import java.util.Base64; import io.supertokens.Main; +import io.supertokens.config.Config; + import com.eatthepath.otp.TimeBasedOneTimePasswordGenerator; import io.supertokens.pluginInterface.exceptions.StorageQueryException; import io.supertokens.pluginInterface.totp.TOTPDevice; @@ -29,8 +33,7 @@ import io.supertokens.totp.exceptions.LimitReachedException; public class Totp { - - public static String generateSecret() throws Exception { + private static String generateSecret() throws Exception { final String TOTP_ALGORITHM = "HmacSHA1"; try { @@ -45,7 +48,7 @@ public static String generateSecret() throws Exception { } } - public static boolean checkCode(TOTPDevice device, String code) { + private static boolean checkCode(TOTPDevice device, String code) { final TimeBasedOneTimePasswordGenerator totp = new TimeBasedOneTimePasswordGenerator(); byte[] keyBytes = Base64.getDecoder().decode(device.secretKey); @@ -57,6 +60,7 @@ public static boolean checkCode(TOTPDevice device, String code) { // Check if code is valid for any of the time periods in the skew: for (int i = -skew; i <= skew; i++) { try { + // FIXME: Where is this using % on the period? if (totp.generateOneTimePasswordString(key, Instant.now().plusSeconds(i * period)).equals(code)) { return true; } @@ -68,6 +72,19 @@ public static boolean checkCode(TOTPDevice device, String code) { return false; } + /** + * Replicates TOTP code is generated by apps like Google Authenticator and Authy + */ + @TestOnly + public static String generateTotpCodeForDevice(TOTPDevice device) throws InvalidKeyException { + final TimeBasedOneTimePasswordGenerator totp = new TimeBasedOneTimePasswordGenerator(); + + byte[] keyBytes = Base64.getDecoder().decode(device.secretKey); + Key key = new SecretKeySpec(keyBytes, "HmacSHA1"); + + return totp.generateOneTimePasswordString(key, Instant.now()); + } + public static String registerDevice(Main main, String userId, String deviceName, int skew, int period) throws StorageQueryException, DeviceAlreadyExistsException, Exception { @@ -98,6 +115,11 @@ private static void checkAndStoreCode(TOTPStorage totpStorage, String userId, TO // Check if the code has been successfully used by the user (for any device): for (TOTPUsedCode usedCode : usedCodes) { + // One edge case is that if the user has 2 devices, and they are used back to + // back (within 90 seconds) such that the code of the first device was + // regenerated by the second device, then it won't allow the second device's + // code to be used until it is expired. + // But this would be rare so we can ignore it for now. if (usedCode.code.equals(code)) { throw new InvalidTotpException(); } @@ -130,7 +152,7 @@ public static boolean verifyDevice(Main main, String userId, String deviceName, if (device.deviceName.equals(deviceName)) { matchingDevice = device; if (device.verified) { - return false; + return false; // Was already verified } break; } @@ -146,7 +168,8 @@ public static boolean verifyDevice(Main main, String userId, String deviceName, TOTPUsedCode[] usedCodes = totpStorage.getNonExpiredUsedCodes(userId); checkAndStoreCode(totpStorage, userId, new TOTPDevice[] { matchingDevice }, usedCodes, code); totpStorage.markDeviceAsVerified(userId, deviceName); - return true; + // Note: No rate limiting in device verification + return true; // Newly verified } public static void verifyCode(Main main, String userId, String code, boolean allowUnverifiedDevices) @@ -171,14 +194,15 @@ public static void verifyCode(Main main, String userId, String code, boolean all checkAndStoreCode(totpStorage, userId, devices, usedCodes, code); } catch (InvalidTotpException e) { // Now we know that the code is invalid. - // Check if latest 3 codes are all invalid: // Note: usedCodes will get updated when // - A valid code is used: It will break the chain of invalid codes. // - Cron job runs: deletes expired codes every hour - // All the latest 3 codes are invalid: - if (Arrays.stream(usedCodes).limit(3).allMatch(usedCode -> !usedCode.isValidCode)) { + // Check if latest N codes are all invalid: + int N = Config.getConfig(main).getTotpRateLimitWindowSize(); // default = 5 + long invalidCodeCount = Arrays.stream(usedCodes).limit(N).takeWhile(usedCode -> !usedCode.isValid).count(); + if (invalidCodeCount >= N) { throw new LimitReachedException(); } diff --git a/src/test/java/io/supertokens/test/totp/TOTPRecipeTest.java b/src/test/java/io/supertokens/test/totp/TOTPRecipeTest.java index e14c5dfd9..fa9670fd5 100644 --- a/src/test/java/io/supertokens/test/totp/TOTPRecipeTest.java +++ b/src/test/java/io/supertokens/test/totp/TOTPRecipeTest.java @@ -155,6 +155,7 @@ public void createDeviceAndVerifyCode() throws Exception { { triggerRateLimit(main); // Run cronjob: + assert false; // Totp.runCron(main); Totp.verifyCode(main, "user", "XXXX-code", true); } diff --git a/src/test/java/io/supertokens/test/totp/TOTPStorageTest.java b/src/test/java/io/supertokens/test/totp/TOTPStorageTest.java index a1bd42f37..d677594b5 100644 --- a/src/test/java/io/supertokens/test/totp/TOTPStorageTest.java +++ b/src/test/java/io/supertokens/test/totp/TOTPStorageTest.java @@ -276,7 +276,7 @@ public void getNonExpiredUsedCodesTest() throws Exception { storage.insertUsedCode(expiredInvalidCode); usedCodes = storage.getNonExpiredUsedCodes("user"); - assert (usedCodes.length == 2); + assert (usedCodes.length == 2); // expired codes shouldn't be returned assert (usedCodes[0].equals(validCode)); assert (usedCodes[1].equals(invalidCode)); } From 457f091942ec379f74bc78fc245c6b70c9e6f2e5 Mon Sep 17 00:00:00 2001 From: KShivendu Date: Thu, 23 Feb 2023 17:41:15 +0530 Subject: [PATCH 17/42] refactor: Remove device_name from totp_used_codes table Every used code is only linked to the user now. No concept of code to device linking. So removed device_name. --- .../java/io/supertokens/inmemorydb/queries/TOTPQueries.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java b/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java index b3892492d..3a283fb01 100644 --- a/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java +++ b/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java @@ -36,7 +36,7 @@ public static String getQueryToCreateUserDevicesTable(Start start) { public static String getQueryToCreateUsedCodesTable(Start start) { return "CREATE TABLE IF NOT EXISTS " + Config.getConfig(start).getTotpUsedCodesTable() + " (" - + "user_id VARCHAR(128) NOT NULL, " + "device_name VARCHAR(256), " + + "user_id VARCHAR(128) NOT NULL, " + "code CHAR(6) NOT NULL," + "is_valid BOOLEAN NOT NULL," + "created_time_ms BIGINT UNSIGNED NOT NULL," + "expiry_time_ms BIGINT UNSIGNED NOT NULL," From 6fbfebc00665d56eaa59449263ef24480fa3ee99 Mon Sep 17 00:00:00 2001 From: KShivendu Date: Mon, 27 Feb 2023 19:16:32 +0530 Subject: [PATCH 18/42] feat: Improve TOTP recipe - Improve rate limiting and device removal logic - Add config for TOTP rate limting - Properly use transactions - Add/update tests --- config.yaml | 5 + devConfig.yaml | 7 +- implementationDependencies.json | 5 + .../io/supertokens/config/CoreConfig.java | 22 +- .../DeleteExpiredTotpTokens.java | 3 +- .../java/io/supertokens/inmemorydb/Start.java | 42 +++- .../inmemorydb/queries/GeneralQueries.java | 6 +- .../inmemorydb/queries/TOTPQueries.java | 84 ++++---- src/main/java/io/supertokens/totp/Totp.java | 166 +++++++++------ .../io/supertokens/test/ConfigTest2_6.java | 4 + .../supertokens/test/totp/TOTPRecipeTest.java | 198 +++++++++++++----- .../test/totp/TOTPStorageTest.java | 178 ++++++++++++---- 12 files changed, 497 insertions(+), 223 deletions(-) diff --git a/config.yaml b/config.yaml index 3b6d02904..489458b0c 100644 --- a/config.yaml +++ b/config.yaml @@ -54,6 +54,11 @@ core_config_version: 0 # (OPTIONAL | Default: 900000) long value. Time in milliseconds for how long a passwordless code is valid for. # passwordless_code_lifetime: +# (OPTIONAL | Default: 5) integer value. The maximum number of invalid TOTP attempts that will trigger rate limiting. +# totp_max_attempts: + +# (OPTIONAL | Default: 900) integer value. The time in seconds for which the user will be rate limited (once totp_max_attempts is crossed) +# totp_rate_limit_cooldown_time: # (OPTIONAL | Default: installation directory/logs/info.log) string value. Give the path to a file (on your local # system) in which the SuperTokens service can write INFO logs to. Set it to "null" if you want it to log to diff --git a/devConfig.yaml b/devConfig.yaml index 5b15e7c8f..15f8ffa12 100644 --- a/devConfig.yaml +++ b/devConfig.yaml @@ -54,6 +54,11 @@ core_config_version: 0 # (OPTIONAL | Default: 900000) long value. Time in milliseconds for how long a passwordless code is valid for. # passwordless_code_lifetime: +# (OPTIONAL | Default: 5) integer value. The maximum number of invalid TOTP attempts that will trigger rate limiting. +# totp_max_attempts: + +# (OPTIONAL | Default: 900) integer value. The time in seconds for which the user will be rate limited (once totp_max_attempts is crossed) +# totp_rate_limit_cooldown_time: # (OPTIONAL | Default: installation directory/logs/info.log) string value. Give the path to a file (on your local # system) in which the SuperTokens service can write INFO logs to. Set it to "null" if you want it to log to @@ -120,4 +125,4 @@ disable_telemetry: true # (OPTIONAL | Default: null). Regex for denying requests from IP addresses that match with the value. Comment this # value to deny no IP address. -# ip_deny_regex: \ No newline at end of file +# ip_deny_regex: diff --git a/implementationDependencies.json b/implementationDependencies.json index 82e762a01..a96c50bc5 100644 --- a/implementationDependencies.json +++ b/implementationDependencies.json @@ -100,6 +100,11 @@ "jar": "https://repo1.maven.org/maven2/com/lambdaworks/scrypt/1.4.0/scrypt-1.4.0.jar", "name": "Scrypt 1.4.0", "src": "https://repo1.maven.org/maven2/com/lambdaworks/scrypt/1.4.0/scrypt-1.4.0-sources.jar" + }, + { + "jar": "https://repo1.maven.org/maven2/com/eatthepath/java-otp/0.4.0/java-otp-0.4.0.jar", + "name": "Java OTP 0.4.0", + "src": "https://repo1.maven.org/maven2/com/eatthepath/java-otp/0.4.0/java-otp-0.4.0-sources.jar" } ] } \ No newline at end of file diff --git a/src/main/java/io/supertokens/config/CoreConfig.java b/src/main/java/io/supertokens/config/CoreConfig.java index 83e97a8ce..1b3e57df6 100644 --- a/src/main/java/io/supertokens/config/CoreConfig.java +++ b/src/main/java/io/supertokens/config/CoreConfig.java @@ -57,7 +57,10 @@ public class CoreConfig { private long passwordless_code_lifetime = 900000; // in MS @JsonProperty - private int totp_rate_limit_window_size = 5; + private int totp_max_attempts = 5; + + @JsonProperty + private int totp_rate_limit_cooldown_time = 900; // in seconds (Default 15 mins) private final String logDefault = "asdkfahbdfk3kjHS"; @JsonProperty @@ -274,8 +277,13 @@ public long getPasswordlessCodeLifetime() { return passwordless_code_lifetime; } - public int getTotpRateLimitWindowSize() { - return totp_rate_limit_window_size; + public int getTotpMaxAttempts() { + return totp_max_attempts; + } + + /** TOTP rate limit cooldown time (in seconds) */ + public int getTotpRateLimitCooldownTime() { + return totp_rate_limit_cooldown_time; } public boolean isTelemetryDisabled() { @@ -396,8 +404,12 @@ void validateAndInitialise(Main main) throws IOException { throw new QuitProgramException("'passwordless_max_code_input_attempts' must be > 0"); } - if (totp_rate_limit_window_size <= 0) { - throw new QuitProgramException("'totp_rate_limit_window_size' must be > 0"); + if (totp_max_attempts <= 0) { + throw new QuitProgramException("'totp_max_attempts' must be > 0"); + } + + if (totp_rate_limit_cooldown_time <= 0) { + throw new QuitProgramException("'totp_rate_limit_cooldown_time' must be > 0"); } if (max_server_pool_size <= 0) { diff --git a/src/main/java/io/supertokens/cronjobs/deleteExpiredTotpTokens/DeleteExpiredTotpTokens.java b/src/main/java/io/supertokens/cronjobs/deleteExpiredTotpTokens/DeleteExpiredTotpTokens.java index aa794fc3c..fa61f4db7 100644 --- a/src/main/java/io/supertokens/cronjobs/deleteExpiredTotpTokens/DeleteExpiredTotpTokens.java +++ b/src/main/java/io/supertokens/cronjobs/deleteExpiredTotpTokens/DeleteExpiredTotpTokens.java @@ -32,7 +32,8 @@ protected void doTask() throws Exception { TOTPSQLStorage storage = StorageLayer.getTOTPStorage(this.main); - storage.removeExpiredCodes(); + int N = 5; // FIXME:: This is not used anywhere + storage.removeExpiredCodes(N); } @Override diff --git a/src/main/java/io/supertokens/inmemorydb/Start.java b/src/main/java/io/supertokens/inmemorydb/Start.java index 9f35ab84c..ca91f51a3 100644 --- a/src/main/java/io/supertokens/inmemorydb/Start.java +++ b/src/main/java/io/supertokens/inmemorydb/Start.java @@ -1633,7 +1633,7 @@ public void addInfoToNonAuthRecipesBasedOnUserId(String className, String userId @Override public void createDevice(TOTPDevice device) throws StorageQueryException, DeviceAlreadyExistsException { try { - TOTPQueries.createDeviceAndUser(this, device); + TOTPQueries.createDevice(this, device); } catch (StorageTransactionLogicException e) { String message = e.actualException.getMessage(); if (message.equals("[SQLITE_CONSTRAINT] Abort due to constraint violation (UNIQUE constraint failed: " @@ -1661,15 +1661,35 @@ public void markDeviceAsVerified(String userId, String deviceName) } @Override - public void deleteDevice(String userId, String deviceName) - throws StorageQueryException, UnknownDeviceException { + public int deleteDevice_Transaction(TransactionConnection con, String userId, String deviceName) + throws StorageQueryException { + Connection sqlCon = (Connection) con.getConnection(); try { - int deletedCount = TOTPQueries.deleteDevice(this, userId, deviceName); - if (deletedCount == 0) { - throw new UnknownDeviceException(); - } - } catch (StorageTransactionLogicException e) { - throw new StorageQueryException(e.actualException); + return TOTPQueries.deleteDevice_Transaction(this, sqlCon, userId, deviceName); + } catch (SQLException e) { + throw new StorageQueryException(e); + } + } + + @Override + public int getDevicesCount_Transaction(TransactionConnection con, String userId) + throws StorageQueryException { + Connection sqlCon = (Connection) con.getConnection(); + try { + return TOTPQueries.getDevicesCount_Transaction(this, sqlCon, userId); + } catch (SQLException e) { + throw new StorageQueryException(e); + } + } + + @Override + public void removeUser_Transaction(TransactionConnection con, String userId) + throws StorageQueryException { + Connection sqlCon = (Connection) con.getConnection(); + try { + TOTPQueries.removeUser_Transaction(this, sqlCon, userId); + } catch (SQLException e) { + throw new StorageQueryException(e); } } @@ -1730,10 +1750,10 @@ public TOTPUsedCode[] getNonExpiredUsedCodes(String userId) } @Override - public void removeExpiredCodes() + public void removeExpiredCodes(int N) throws StorageQueryException { try { - TOTPQueries.removeExpiredCodes(this); + TOTPQueries.removeExpiredCodes(this, N); } catch (SQLException e) { throw new StorageQueryException(e); } diff --git a/src/main/java/io/supertokens/inmemorydb/queries/GeneralQueries.java b/src/main/java/io/supertokens/inmemorydb/queries/GeneralQueries.java index f0343f912..d4a325a36 100644 --- a/src/main/java/io/supertokens/inmemorydb/queries/GeneralQueries.java +++ b/src/main/java/io/supertokens/inmemorydb/queries/GeneralQueries.java @@ -200,7 +200,8 @@ public static void createTablesIfNotExists(Start start, Main main) throws SQLExc getInstance(main).addState(CREATING_NEW_TABLE, null); update(start, TOTPQueries.getQueryToCreateUsedCodesTable(start), NO_OP_SETTER); // index: - update(start, TOTPQueries.getQueryToCreateUsedCodesIndex(start), NO_OP_SETTER); + update(start, TOTPQueries.getQueryToCreateUsedCodesExpiryTimeIndex(start), NO_OP_SETTER); + update(start, TOTPQueries.getQueryToCreateUsedCodesCreatedTimeIndex(start), NO_OP_SETTER); } } @@ -396,7 +397,8 @@ public static AuthRecipeUserInfo[] getUsers(Start start, @NotNull Integer limit, List users = getUserInfoForRecipeIdFromUserIds(start, recipeId, recipeIdToUserIdListMap.get(recipeId)); - // we fill in all the slots in finalResult based on their position in usersFromQuery + // we fill in all the slots in finalResult based on their position in + // usersFromQuery Map userIdToInfoMap = new HashMap<>(); for (AuthRecipeUserInfo user : users) { userIdToInfoMap.put(user.id, user); diff --git a/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java b/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java index 3a283fb01..21cf1a86d 100644 --- a/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java +++ b/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java @@ -44,13 +44,17 @@ public static String getQueryToCreateUsedCodesTable(Start start) { + "(user_id) ON DELETE CASCADE);"; } - public static String getQueryToCreateUsedCodesIndex(Start start) { + public static String getQueryToCreateUsedCodesExpiryTimeIndex(Start start) { return "CREATE INDEX IF NOT EXISTS totp_used_codes_expiry_time_ms_index ON " + Config.getConfig(start).getTotpUsedCodesTable() + " (expiry_time_ms)"; - // TODO: Create index on created_time_ms as well } - public static int insertUser_Transaction(Start start, Connection con, String userId) + public static String getQueryToCreateUsedCodesCreatedTimeIndex(Start start) { + return "CREATE INDEX IF NOT EXISTS totp_used_codes_created_time_ms_index ON " + + Config.getConfig(start).getTotpUsedCodesTable() + " (created_time_ms DESC)"; + } + + private static int insertUser_Transaction(Start start, Connection con, String userId) throws SQLException, StorageQueryException { // Create user if not exists: // TODO: Check if not using "CONFLICT DO NOTHING" will break the transaction @@ -61,7 +65,7 @@ public static int insertUser_Transaction(Start start, Connection con, String use return update(con, QUERY, pst -> pst.setString(1, userId)); } - public static int insertDevice_Transaction(Start start, Connection con, TOTPDevice device) + private static int insertDevice_Transaction(Start start, Connection con, TOTPDevice device) throws SQLException, StorageQueryException { String QUERY = "INSERT INTO " + Config.getConfig(start).getTotpUserDevicesTable() + " (user_id, device_name, secret_key, period, skew, verified) VALUES (?, ?, ?, ?, ?, ?)"; @@ -76,7 +80,7 @@ public static int insertDevice_Transaction(Start start, Connection con, TOTPDevi }); } - public static void createDeviceAndUser(Start start, TOTPDevice device) + public static void createDevice(Start start, TOTPDevice device) throws StorageQueryException, StorageTransactionLogicException { start.startTransaction(con -> { Connection sqlCon = (Connection) con.getConnection(); @@ -91,6 +95,7 @@ public static void createDeviceAndUser(Start start, TOTPDevice device) return null; }); + return; } public static int markDeviceAsVerified(Start start, String userId, String deviceName) @@ -120,40 +125,45 @@ public static int removeUser_Transaction(Start start, Connection con, String use + " WHERE user_id = ?;"; int removedUsersCount = update(con, QUERY, pst -> pst.setString(1, userId)); - // Delete all used codes for this user: - // Note: This step is required only for in-memory db. - // Other databases will automatically delete the used codes when the user is + // Delete all devices and used codes for this user: + // This step is required only for in-memory db. + // Other databases will automatically delete these when the user is // deleted because of foreign key constraints. - String QUERY2 = "DELETE FROM " + Config.getConfig(start).getTotpUsedCodesTable() + String QUERY2 = "DELETE FROM " + Config.getConfig(start).getTotpUserDevicesTable() + " WHERE user_id = ?;"; update(con, QUERY2, pst -> pst.setString(1, userId)); + String QUERY3 = "DELETE FROM " + Config.getConfig(start).getTotpUsedCodesTable() + + " WHERE user_id = ?;"; + update(con, QUERY3, pst -> pst.setString(1, userId)); + return removedUsersCount; } - public static int deleteDevice(Start start, String userId, String deviceName) - throws StorageQueryException, StorageTransactionLogicException { - return start.startTransaction(con -> { - Connection sqlCon = (Connection) con.getConnection(); - - try { - int deletedCount = deleteDevice_Transaction(start, sqlCon, userId, deviceName); - if (deletedCount > 0) { - // Some device was deleted. Check if user has any other device left: - int devicesCount = getDevicesCount_Transaction(start, sqlCon, userId); - if (devicesCount == 0) { - // no device left. delete user - removeUser_Transaction(start, sqlCon, userId); - } - } - - sqlCon.commit(); - return deletedCount; - } catch (SQLException e) { - throw new StorageTransactionLogicException(e); - } - }); - } + // public static int deleteDevice(Start start, String userId, String deviceName) + // throws StorageQueryException, StorageTransactionLogicException { + // return start.startTransaction(con -> { + // Connection sqlCon = (Connection) con.getConnection(); + + // try { + // int deletedCount = deleteDevice_Transaction(start, sqlCon, userId, + // deviceName); + // if (deletedCount > 0) { + // // Some device was deleted. Check if user has any other device left: + // int devicesCount = getDevicesCount_Transaction(start, sqlCon, userId); + // if (devicesCount == 0) { + // // no device left. delete user + // removeUser_Transaction(start, sqlCon, userId); + // } + // } + + // sqlCon.commit(); + // return deletedCount; + // } catch (SQLException e) { + // throw new StorageTransactionLogicException(e); + // } + // }); + // } public static int updateDeviceName(Start start, String userId, String oldDeviceName, String newDeviceName) throws StorageQueryException, SQLException { @@ -254,29 +264,31 @@ public static TOTPUsedCode[] getNonExpiredUsedCodesDescOrder(Start start, String }); } - public static int removeExpiredCodes(Start start) + public static int removeExpiredCodes(Start start, int offset) throws StorageQueryException, SQLException { + // TODO: Use offset with `order by created_at desc` to exclude latest N codes + // for each user String QUERY = "DELETE FROM " + Config.getConfig(start).getTotpUsedCodesTable() + " WHERE expiry_time_ms < ?;"; return update(start, QUERY, pst -> pst.setLong(1, System.currentTimeMillis())); } - public static int deleteAllDevices_Transaction(Start start, Connection con, String userId) + private static int deleteAllDevices_Transaction(Start start, Connection con, String userId) throws SQLException, StorageQueryException { String QUERY = "DELETE FROM " + Config.getConfig(start).getTotpUserDevicesTable() + " WHERE user_id = ?;"; return update(con, QUERY, pst -> pst.setString(1, userId)); } - public static int deleteAllUsedCodes_Transaction(Start start, Connection con, String userId) + private static int deleteAllUsedCodes_Transaction(Start start, Connection con, String userId) throws SQLException, StorageQueryException { String QUERY = "DELETE FROM " + Config.getConfig(start).getTotpUsedCodesTable() + " WHERE user_id = ?;"; return update(con, QUERY, pst -> pst.setString(1, userId)); } - public static int deleteUser_Transaction(Start start, Connection con, String userId) + private static int deleteUser_Transaction(Start start, Connection con, String userId) throws SQLException, StorageQueryException { String QUERY = "DELETE FROM " + Config.getConfig(start).getTotpUsersTable() + " WHERE user_id = ?;"; diff --git a/src/main/java/io/supertokens/totp/Totp.java b/src/main/java/io/supertokens/totp/Totp.java index 36f82b364..42e6ef6d8 100644 --- a/src/main/java/io/supertokens/totp/Totp.java +++ b/src/main/java/io/supertokens/totp/Totp.java @@ -14,13 +14,12 @@ import org.jetbrains.annotations.TestOnly; -import java.util.Base64; - import io.supertokens.Main; import io.supertokens.config.Config; import com.eatthepath.otp.TimeBasedOneTimePasswordGenerator; import io.supertokens.pluginInterface.exceptions.StorageQueryException; +import io.supertokens.pluginInterface.exceptions.StorageTransactionLogicException; import io.supertokens.pluginInterface.totp.TOTPDevice; import io.supertokens.pluginInterface.totp.TOTPStorage; import io.supertokens.pluginInterface.totp.TOTPUsedCode; @@ -33,19 +32,18 @@ import io.supertokens.totp.exceptions.LimitReachedException; public class Totp { - private static String generateSecret() throws Exception { + private static String generateSecret() throws NoSuchAlgorithmException { + // TODO: We can actually allow the user to choose this algorithm. + // Changing it a would be rare but it can be a requirement for someone + // who's dealing with unconventional totp apps/devices. final String TOTP_ALGORITHM = "HmacSHA1"; - try { - final KeyGenerator keyGenerator = KeyGenerator.getInstance(TOTP_ALGORITHM); - keyGenerator.init(160); // 160 bits = 20 bytes - - // FIXME: Should return base32 or base16 - // Return base64 string of the secret key: - return Base64.getEncoder().encodeToString(keyGenerator.generateKey().getEncoded()); - } catch (NoSuchAlgorithmException e) { - throw new Exception("TOTP algorithm not found"); - } + final KeyGenerator keyGenerator = KeyGenerator.getInstance(TOTP_ALGORITHM); + keyGenerator.init(160); // 160 bits = 20 bytes + + // FIXME: Should return base32 or base16 + // Return base64 string of the secret key: + return Base64.getEncoder().encodeToString(keyGenerator.generateKey().getEncoded()); } private static boolean checkCode(TOTPDevice device, String code) { @@ -74,9 +72,16 @@ private static boolean checkCode(TOTPDevice device, String code) { /** * Replicates TOTP code is generated by apps like Google Authenticator and Authy + * + * @throws StorageQueryException */ @TestOnly - public static String generateTotpCodeForDevice(TOTPDevice device) throws InvalidKeyException { + public static String generateTotpCode(Main main, String userId, String deviceName) + throws InvalidKeyException, StorageQueryException { + TOTPDevice[] devices = StorageLayer.getTOTPStorage(main).getDevices(userId); + TOTPDevice device = Arrays.stream(devices).filter(d -> d.deviceName.equals(deviceName)).findFirst() + .orElse(null); + final TimeBasedOneTimePasswordGenerator totp = new TimeBasedOneTimePasswordGenerator(); byte[] keyBytes = Base64.getDecoder().decode(device.secretKey); @@ -86,20 +91,43 @@ public static String generateTotpCodeForDevice(TOTPDevice device) throws Invalid } public static String registerDevice(Main main, String userId, String deviceName, int skew, int period) - throws StorageQueryException, DeviceAlreadyExistsException, Exception { + throws StorageQueryException, DeviceAlreadyExistsException, NoSuchAlgorithmException { TOTPSQLStorage totpStorage = StorageLayer.getTOTPStorage(main); - String secret = generateSecret(); // TODO: should we handle Exception differently? + // TODO: There should be a hard limit on number of devices per user + // 8 devices per user should be enough. Otherwise, it is a security risk. + + String secret = generateSecret(); TOTPDevice device = new TOTPDevice(userId, deviceName, secret, period, skew, false); totpStorage.createDevice(device); return secret; } - private static void checkAndStoreCode(TOTPStorage totpStorage, String userId, TOTPDevice[] devices, - TOTPUsedCode[] usedCodes, String code) - throws InvalidTotpException, StorageQueryException, TotpNotEnabledException { + private static void checkAndStoreCode(Main main, TOTPStorage totpStorage, String userId, TOTPDevice[] devices, + String code) + throws InvalidTotpException, StorageQueryException, TotpNotEnabledException, LimitReachedException { + TOTPUsedCode[] usedCodes = totpStorage.getNonExpiredUsedCodes(userId); + + // N represents # of invalid attempts that will trigger rate limiting: + int N = Config.getConfig(main).getTotpMaxAttempts(); // (Default 5) + // Count # of contiguous invalids in latest N attempts (stop at first valid): + long invalidOutOfN = Arrays.stream(usedCodes).limit(N).takeWhile(usedCode -> !usedCode.isValid).count(); + int rateLimitResetTimeInMs = Config.getConfig(main).getTotpRateLimitCooldownTime() * 1000; // (Default 15 mins) + + // Check if the user has been rate limited: + if (invalidOutOfN == N) { + // All of the latest N attempts were invalid: + long latestInvalidCodeCreatedTime = usedCodes[0].createdTime; + long now = System.currentTimeMillis(); + + if (now - latestInvalidCodeCreatedTime < rateLimitResetTimeInMs) { + // Less than rateLimitResetTimeInMs (default = 15 mins) time has elasped since + // the last invalid code: + throw new LimitReachedException(); + } + } // Check if the code is valid for any device: boolean isValid = false; @@ -113,28 +141,37 @@ private static void checkAndStoreCode(TOTPStorage totpStorage, String userId, TO } } - // Check if the code has been successfully used by the user (for any device): - for (TOTPUsedCode usedCode : usedCodes) { - // One edge case is that if the user has 2 devices, and they are used back to - // back (within 90 seconds) such that the code of the first device was - // regenerated by the second device, then it won't allow the second device's - // code to be used until it is expired. - // But this would be rare so we can ignore it for now. - if (usedCode.code.equals(code)) { - throw new InvalidTotpException(); + // Check if the code has been previously used by the user and it was valid (and + // is still valid). If so, this could be a replay attack. So reject it. + if (isValid) { + for (TOTPUsedCode usedCode : usedCodes) { + // One edge case is that if the user has 2 devices, and they are used back to + // back (within 90 seconds) such that the code of the first device was + // regenerated by the second device, then it won't allow the second device's + // code to be used until it is expired. + // But this would be rare so we can ignore it for now. + if (usedCode.code.equals(code) && usedCode.isValid) { + isValid = false; + matchingDevice = null; + } } } // Insert the code into the list of used codes: long now = System.currentTimeMillis(); - int expireInSec = isValid ? matchingDevice.period * (2 * matchingDevice.skew + 1) : 60 * 5; + int expireInSec = isValid ? matchingDevice.period * (2 * matchingDevice.skew + 1) : 60 * 30; // 30 mins TOTPUsedCode newCode = new TOTPUsedCode(userId, code, isValid, now + 1000 * expireInSec, now); totpStorage.insertUsedCode(newCode); + + if (!isValid) { + throw new InvalidTotpException(); + } } public static boolean verifyDevice(Main main, String userId, String deviceName, String code) - throws StorageQueryException, TotpNotEnabledException, UnknownDeviceException, InvalidTotpException { + throws StorageQueryException, TotpNotEnabledException, UnknownDeviceException, InvalidTotpException, + LimitReachedException { // Here boolean return value tells whether the device has been // newly verified (true) OR it was already verified (false) @@ -163,12 +200,9 @@ public static boolean verifyDevice(Main main, String userId, String deviceName, throw new UnknownDeviceException(); } - // If the device is not verified, check if the code is valid and unused. - // If it is successful, mark the device as verified. - TOTPUsedCode[] usedCodes = totpStorage.getNonExpiredUsedCodes(userId); - checkAndStoreCode(totpStorage, userId, new TOTPDevice[] { matchingDevice }, usedCodes, code); + checkAndStoreCode(main, totpStorage, userId, new TOTPDevice[] { matchingDevice }, code); + // Will reach here only if the code is valid: totpStorage.markDeviceAsVerified(userId, deviceName); - // Note: No rate limiting in device verification return true; // Newly verified } @@ -188,41 +222,41 @@ public static void verifyCode(Main main, String userId, String code, boolean all devices = Arrays.stream(devices).filter(device -> device.verified).toArray(TOTPDevice[]::new); } - TOTPUsedCode[] usedCodes = totpStorage.getNonExpiredUsedCodes(userId); - - try { - checkAndStoreCode(totpStorage, userId, devices, usedCodes, code); - } catch (InvalidTotpException e) { - // Now we know that the code is invalid. - - // Note: usedCodes will get updated when - // - A valid code is used: It will break the chain of invalid codes. - // - Cron job runs: deletes expired codes every hour - - // Check if latest N codes are all invalid: - int N = Config.getConfig(main).getTotpRateLimitWindowSize(); // default = 5 - long invalidCodeCount = Arrays.stream(usedCodes).limit(N).takeWhile(usedCode -> !usedCode.isValid).count(); - if (invalidCodeCount >= N) { - throw new LimitReachedException(); - } - - // Code is invalid but the user has not exceeded the limit: - throw e; - } + checkAndStoreCode(main, totpStorage, userId, devices, code); } - public static void deleteDevice(Main main, String userId, String deviceName) - throws StorageQueryException, UnknownDeviceException, TotpNotEnabledException { - TOTPSQLStorage totpStorage = StorageLayer.getTOTPStorage(main); + /** Delete device and also delete the user if deleting the last device */ + public static void removeDevice(Main main, String userId, String deviceName) + throws StorageQueryException, UnknownDeviceException, TotpNotEnabledException, + StorageTransactionLogicException { + TOTPSQLStorage storage = StorageLayer.getTOTPStorage(main); try { - totpStorage.deleteDevice(userId, deviceName); - } catch (UnknownDeviceException e) { - // See if any device exists for the user: - TOTPDevice[] devices = totpStorage.getDevices(userId); - if (devices.length == 0) { - throw new TotpNotEnabledException(); + storage.startTransaction(con -> { + int deletedCount = storage.deleteDevice_Transaction(con, userId, deviceName); + if (deletedCount == 0) + throw new StorageTransactionLogicException(new UnknownDeviceException()); + + // Some device(s) were deleted. Check if user has any other device left: + int devicesCount = storage.getDevicesCount_Transaction(con, userId); + if (devicesCount == 0) { + // no device left. delete user + storage.removeUser_Transaction(con, userId); + } + + storage.commitTransaction(con); + return null; + }); + return; + } catch (StorageTransactionLogicException e) { + if (e.actualException instanceof UnknownDeviceException) { + // Check if any device exists for the user: + TOTPDevice[] devices = storage.getDevices(userId); + if (devices.length == 0) { + throw new TotpNotEnabledException(); + } } + throw e; } } @@ -234,7 +268,7 @@ public static void updateDeviceName(Main main, String userId, String oldDeviceNa try { totpStorage.updateDeviceName(userId, oldDeviceName, newDeviceName); } catch (UnknownDeviceException e) { - // See if any device exists for the user: + // Check if any device exists for the user: TOTPDevice[] devices = totpStorage.getDevices(userId); if (devices.length == 0) { throw new TotpNotEnabledException(); diff --git a/src/test/java/io/supertokens/test/ConfigTest2_6.java b/src/test/java/io/supertokens/test/ConfigTest2_6.java index 6f9f04d74..c40491762 100644 --- a/src/test/java/io/supertokens/test/ConfigTest2_6.java +++ b/src/test/java/io/supertokens/test/ConfigTest2_6.java @@ -220,6 +220,10 @@ private static void checkConfigValues(CoreConfig config, TestingProcess process, assertFalse("Config access token blacklisting did not match default", config.getAccessTokenBlacklisting()); assertEquals("Config refresh token validity did not match default", config.getRefreshTokenValidity(), 60 * 2400 * 60 * (long) 1000); + // TODO: Is this correct? + assertEquals(5, config.getTotpMaxAttempts()); // 5 + assertEquals(900, config.getTotpRateLimitCooldownTime()); // 15 minutes + assertEquals("Config info log path did not match default", config.getInfoLogPath(process.getProcess()), CLIOptions.get(process.getProcess()).getInstallationPath() + "logs/info.log"); assertEquals("Config error log path did not match default", config.getErrorLogPath(process.getProcess()), diff --git a/src/test/java/io/supertokens/test/totp/TOTPRecipeTest.java b/src/test/java/io/supertokens/test/totp/TOTPRecipeTest.java index fa9670fd5..68098762d 100644 --- a/src/test/java/io/supertokens/test/totp/TOTPRecipeTest.java +++ b/src/test/java/io/supertokens/test/totp/TOTPRecipeTest.java @@ -19,6 +19,8 @@ import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertThrows; +import java.io.IOException; + import org.junit.AfterClass; import org.junit.Before; import org.junit.Rule; @@ -28,6 +30,8 @@ import io.supertokens.test.Utils; import io.supertokens.Main; import io.supertokens.ProcessState; +import io.supertokens.config.Config; +import io.supertokens.config.CoreConfig; import io.supertokens.pluginInterface.STORAGE_TYPE; import io.supertokens.storageLayer.StorageLayer; import io.supertokens.test.TestingProcessManager; @@ -35,7 +39,9 @@ import io.supertokens.totp.Totp; import io.supertokens.totp.exceptions.InvalidTotpException; import io.supertokens.totp.exceptions.LimitReachedException; +import io.supertokens.pluginInterface.totp.TOTPDevice; import io.supertokens.pluginInterface.totp.TOTPStorage; +import io.supertokens.pluginInterface.totp.TOTPUsedCode; import io.supertokens.pluginInterface.totp.exception.DeviceAlreadyExistsException; import io.supertokens.pluginInterface.totp.exception.TotpNotEnabledException; import io.supertokens.pluginInterface.totp.exception.UnknownDeviceException; @@ -65,7 +71,7 @@ public TestSetupResult(TOTPStorage storage, TestingProcessManager.TestingProcess } } - public TestSetupResult setup() throws InterruptedException { + public TestSetupResult setup() throws InterruptedException, IOException { String[] args = { "../" }; TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); @@ -93,72 +99,117 @@ public void createDevice() throws Exception { () -> Totp.registerDevice(main, "user", "device1", 1, 30)); } - public void triggerRateLimit(Main main) throws Exception { - // First 2 attempts should fail with invalid code: - assertThrows( - InvalidTotpException.class, - () -> Totp.verifyCode(main, "user", "wrong-code-1", true)); - assertThrows( - InvalidTotpException.class, - () -> Totp.verifyCode(main, "user", "wrong-code-2", true)); - - // 3th attempt should fail with rate limiting error: - assertThrows( - LimitReachedException.class, - () -> Totp.verifyCode(main, "user", "wrong-code-3", true)); - } - - @Test - public void createDeviceAndVerifyCodeAgainstUnverifiedDevices() throws Exception { - - } - @Test public void createDeviceAndVerifyCode() throws Exception { TestSetupResult result = setup(); Main main = result.process.getProcess(); // Create device - String secret = Totp.registerDevice(main, "user", "deviceName", 1, 30); + String secret = Totp.registerDevice(main, "user", "device", 1, 30); // Try login with non-existent user: assertThrows(TotpNotEnabledException.class, - () -> Totp.verifyCode(main, "non-existent-user", "XXXX-code", true)); + () -> Totp.verifyCode(main, "non-existent-user", "any-code", true)); - // Try login with invalid code: + // {Code: [INVALID, VALID]} * {Devices: [VERIFIED_ONLY, ALL]} + + // Invalid code & allowUnverifiedDevice = true: assertThrows(InvalidTotpException.class, () -> Totp.verifyCode(main, "user", "invalid-code", true)); - // Try login with with unverified device: + // Invalid code & allowUnverifiedDevice = false: assertThrows(InvalidTotpException.class, - () -> Totp.verifyCode(main, "user", "XXXX-code", false)); + () -> Totp.verifyCode(main, "user", "invalid-code", false)); + + // Valid code & allowUnverifiedDevice = false: + assertThrows( + InvalidTotpException.class, + () -> Totp.verifyCode(main, "user", Totp.generateTotpCode(main, "user", "device"), false)); - // Try logging in with same again code but allowUnverifiedDevice = true: - // TODO: Think if this is correct - assertThrows(InvalidTotpException.class, () -> Totp.verifyCode(main, "user", "XXXX-code", true)); + // Valid code & allowUnverifiedDevice = true (Success): + String validCode = Totp.generateTotpCode(main, "user", "device"); + Totp.verifyCode(main, "user", validCode, true); - // Successfully login: - Totp.verifyCode(main, "user", "XXXX-code2", true); // Now try again with same code: assertThrows( InvalidTotpException.class, - () -> Totp.verifyCode(main, "user", "XXXX-code2", true)); + () -> Totp.verifyCode(main, "user", validCode, true)); - // Trigger rate limiting and fix it with a correct code: - { - triggerRateLimit(main); - // Using a correct code should fix the rate limiting: - Totp.verifyCode(main, "user", "XXXX-code", true); + // Use a new valid code: + String newValidCode = Totp.generateTotpCode(main, "user", "device"); + Totp.verifyCode(main, "user", newValidCode, true); + } + + public void triggerAndCheckRateLimit(Main main, String userId, String deviceName) throws Exception { + int N = Config.getConfig(main).getTotpMaxAttempts(); + // First N attempts should fail with invalid code: + // This is to trigger rate limiting + for (int i = 0; i < N; i++) { + String code = "invalid-code-" + i; + assertThrows( + InvalidTotpException.class, + () -> Totp.verifyCode(main, "user", code, true)); } - // Trigger rate limiting and fix it with cronjob (runs every 1 hour) - { - triggerRateLimit(main); - // Run cronjob: - assert false; - // Totp.runCron(main); - Totp.verifyCode(main, "user", "XXXX-code", true); + // Any kind of attempt after this should fail with rate limiting error. + // This should happen until rate limiting cooldown happens: + assertThrows( + LimitReachedException.class, + () -> Totp.verifyCode(main, "user", "invalid-code-N+1", true)); + assertThrows( + LimitReachedException.class, + () -> Totp.verifyCode(main, "user", Totp.generateTotpCode(main, userId, deviceName), true)); + assertThrows( + LimitReachedException.class, + () -> Totp.verifyCode(main, "user", "invalid-code-N+2", true)); + } + + @Test + public void rateLimitCooldownTest() throws Exception { + String[] args = { "../" }; + + // set rate limiting cooldown time to 1s + Utils.setValueInConfig("totp_rate_limit_cooldown_time", "1"); + + TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + + if (StorageLayer.getStorage(process.getProcess()).getType() != STORAGE_TYPE.SQL) { + assert (false); } + TOTPStorage storage = StorageLayer.getTOTPStorage(process.getProcess()); + + Main main = process.getProcess(); + + // Create device + String secret = Totp.registerDevice(main, "user", "deviceName", 1, 30); + + // Trigger rate limiting and fix it with a correct code after some time: + triggerAndCheckRateLimit(main, "user", "deviceName"); + // Wait for 1 second (Should cool down rate limiting): + Thread.sleep(1000); + // But again try with invalid code: + Totp.verifyCode(main, "user", "yet-another-invalid-code", true); + // Wait for 1 second (Should cool down rate limiting): + Thread.sleep(1000); + // Now try with valid code: + Totp.verifyCode(main, "user", Totp.generateTotpCode(main, "user", "deviceName"), true); + } + + @Test + public void removeExpiredCodesCronDuringRateLimitTest() throws Exception { + TestSetupResult result = setup(); + Main main = result.process.getProcess(); + + // Create device + String secret = Totp.registerDevice(main, "user", "deviceName", 1, 30); + + // Trigger rate limiting and fix it with cronjob (runs every 1 hour) + triggerAndCheckRateLimit(main, "user", "deviceName"); + // FIXME: Run cronjob at higher frequency: + assert false; + // Totp.runCron(main); + Totp.verifyCode(main, "user", "XXXX-code", true); } @Test @@ -167,8 +218,7 @@ public void createAndVerifyDevice() throws Exception { Main main = result.process.getProcess(); // Create device - // FIXME: Use secret to generate actual TOTP code - String secret = Totp.registerDevice(main, "userId", "deviceName", 1, 30); + String secret = Totp.registerDevice(main, "user", "deviceName", 1, 30); // Try verify non-existent user: assertThrows(TotpNotEnabledException.class, @@ -176,27 +226,65 @@ public void createAndVerifyDevice() throws Exception { // Try verify non-existent device assertThrows(UnknownDeviceException.class, - () -> Totp.verifyDevice(main, "userId", "non-existent-device", "XXXX")); + () -> Totp.verifyDevice(main, "user", "non-existent-device", "XXXX")); // Verify device with wrong code - assertThrows(InvalidTotpException.class, () -> Totp.verifyDevice(main, "userId", "deviceName", "wrong-code")); + assertThrows(InvalidTotpException.class, () -> Totp.verifyDevice(main, "user", "deviceName", "wrong-code")); // Verify device with correct code - boolean deviceAlreadyVerified = Totp.verifyDevice(main, "userId", "deviceName", "XXXX-correct"); - assert !deviceAlreadyVerified; + String validCode = Totp.generateTotpCode(main, "user", "deviceName"); + boolean justVerfied = Totp.verifyDevice(main, "user", "deviceName", validCode); + assert justVerfied; // Verify again with same correct code: - assertThrows(InvalidTotpException.class, () -> Totp.verifyDevice(main, "userId", "deviceName", "XXXX")); + justVerfied = Totp.verifyDevice(main, "user", "deviceName", validCode); + assert !justVerfied; // Verify again with new correct code: - deviceAlreadyVerified = Totp.verifyDevice(main, "userId", "deviceName", "XXXX-new"); - assert deviceAlreadyVerified; + String newValidCode = Totp.generateTotpCode(main, "user", "deviceName"); + justVerfied = Totp.verifyDevice(main, "user", "deviceName", newValidCode); + assert !justVerfied; - // Verify again with wrong code - assertThrows(InvalidTotpException.class, () -> Totp.verifyDevice(main, "userId", "deviceName", "wrong-code")); + // Verify again with wrong code: + justVerfied = Totp.verifyDevice(main, "user", "deviceName", "wrong-code"); + assert !justVerfied; result.process.kill(); assertNotNull(result.process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); } + @Test + public void deleteDevice() throws Exception { + + // Deleting the last device of a user should delete all related codes: + TestSetupResult result = setup(); + Main main = result.process.getProcess(); + + // Create device + String secret = Totp.registerDevice(main, "user", "device", 1, 30); + + long nextDay = System.currentTimeMillis() + 1000 * 60 * 60 * 24; // 1 day + long now = System.currentTimeMillis(); + { + Totp.verifyCode(main, "user", "invalid-code", true); + Totp.verifyCode(main, "user", Totp.generateTotpCode(main, "user", "device"), true); + + // delete device2 as well + // storage.startTransaction(con -> { + // storage.deleteDevice_Transaction(con, "user", "device2"); + // storage.commitTransaction(con); + // return null; + // }); + + TOTPDevice[] devices = Totp.getDevices(main, "user"); + assert (devices.length == 0); + } + } + + @Test + public void deleteUser() throws Exception { + // Deleting a user should delete all related devices and codes: + + } + } diff --git a/src/test/java/io/supertokens/test/totp/TOTPStorageTest.java b/src/test/java/io/supertokens/test/totp/TOTPStorageTest.java index d677594b5..e2a421441 100644 --- a/src/test/java/io/supertokens/test/totp/TOTPStorageTest.java +++ b/src/test/java/io/supertokens/test/totp/TOTPStorageTest.java @@ -22,14 +22,15 @@ import io.supertokens.pluginInterface.totp.exception.DeviceAlreadyExistsException; import io.supertokens.pluginInterface.totp.exception.TotpNotEnabledException; import io.supertokens.pluginInterface.totp.exception.UnknownDeviceException; +import io.supertokens.pluginInterface.totp.sqlStorage.TOTPSQLStorage; public class TOTPStorageTest { public class TestSetupResult { - public TOTPStorage storage; + public TOTPSQLStorage storage; public TestingProcessManager.TestingProcess process; - public TestSetupResult(TOTPStorage storage, TestingProcessManager.TestingProcess process) { + public TestSetupResult(TOTPSQLStorage storage, TestingProcessManager.TestingProcess process) { this.storage = storage; this.process = process; } @@ -57,7 +58,7 @@ public TestSetupResult setup() throws InterruptedException { if (StorageLayer.getStorage(process.getProcess()).getType() != STORAGE_TYPE.SQL) { assert (false); } - TOTPStorage storage = StorageLayer.getTOTPStorage(process.getProcess()); + TOTPSQLStorage storage = StorageLayer.getTOTPStorage(process.getProcess()); return new TestSetupResult(storage, process); } @@ -65,11 +66,11 @@ public TestSetupResult setup() throws InterruptedException { @Test public void createDeviceTests() throws Exception { TestSetupResult result = setup(); - TOTPStorage storage = result.storage; + TOTPSQLStorage storage = result.storage; - TOTPDevice device1 = new TOTPDevice("user", "d1", "secretKey", 30, 1, false); - TOTPDevice device2 = new TOTPDevice("user", "d2", "secretKey", 30, 1, false); - TOTPDevice device2Duplicate = new TOTPDevice("user", "d2", "secretKey", 30, 1, false); + TOTPDevice device1 = new TOTPDevice("user", "d1", "secret", 30, 1, false); + TOTPDevice device2 = new TOTPDevice("user", "d2", "secret", 30, 1, true); + TOTPDevice device2Duplicate = new TOTPDevice("user", "d2", "new-secret", 30, 1, false); storage.createDevice(device1); @@ -93,7 +94,7 @@ public void createDeviceTests() throws Exception { @Test public void verifyDeviceTests() throws Exception { TestSetupResult result = setup(); - TOTPStorage storage = result.storage; + TOTPSQLStorage storage = result.storage; TOTPDevice device = new TOTPDevice("user", "device", "secretKey", 30, 1, false); storage.createDevice(device); @@ -113,10 +114,48 @@ public void verifyDeviceTests() throws Exception { assertThrows(UnknownDeviceException.class, () -> storage.markDeviceAsVerified("user", "non-existent-device")); } + // FIXME: Should write tests for other transaction functions as well. + + @Test + public void getDevicesCount_TransactionTests() throws Exception { + TestSetupResult result = setup(); + TOTPSQLStorage storage = result.storage; + + // Try to get the count for a user that doesn't exist (Should pass because + // this is DB level txn that doesn't throw TotpNotEnabledException): + int devicesCount = storage.startTransaction(con -> { + int value = storage.getDevicesCount_Transaction(con, "non-existent-user"); + storage.commitTransaction(con); + return value; + }); + assert devicesCount == 0; + + TOTPDevice device1 = new TOTPDevice("user", "device1", "sk1", 30, 1, false); + TOTPDevice device2 = new TOTPDevice("user", "device2", "sk2", 30, 1, false); + + storage.createDevice(device1); + storage.createDevice(device2); + + devicesCount = storage.startTransaction(con -> { + int value = storage.getDevicesCount_Transaction(con, "user"); + storage.commitTransaction(con); + return value; + }); + assert devicesCount == 2; + } + @Test - public void deleteDeviceTests() throws Exception { + public void removeUser_TransactionTests() throws Exception { TestSetupResult result = setup(); - TOTPStorage storage = result.storage; + TOTPSQLStorage storage = result.storage; + + // Try to remove a user that doesn't exist (Should pass because + // this is DB level txn that doesn't throw TotpNotEnabledException): + storage.startTransaction(con -> { + storage.removeUser_Transaction(con, "non-existent-user"); + storage.commitTransaction(con); + return null; + }); TOTPDevice device1 = new TOTPDevice("user", "device1", "sk1", 30, 1, false); TOTPDevice device2 = new TOTPDevice("user", "device2", "sk2", 30, 1, false); @@ -124,41 +163,82 @@ public void deleteDeviceTests() throws Exception { storage.createDevice(device1); storage.createDevice(device2); + long now = System.currentTimeMillis(); + long expiryAfter10mins = now + 10 * 60 * 1000; + + TOTPUsedCode usedCode1 = new TOTPUsedCode("user", "code1", true, expiryAfter10mins, now); + TOTPUsedCode usedCode2 = new TOTPUsedCode("user", "code2", false, expiryAfter10mins, now); + + storage.insertUsedCode(usedCode1); + storage.insertUsedCode(usedCode2); + TOTPDevice[] storedDevices = storage.getDevices("user"); assert (storedDevices.length == 2); - // Try to delete a device for a user that doesn't exist: - assertThrows(UnknownDeviceException.class, () -> storage.deleteDevice("non-existent-user", "device1")); + TOTPUsedCode[] storedUsedCodes = storage.getNonExpiredUsedCodes("user"); + assert (storedUsedCodes.length == 2); - // Try to delete a device that doesn't exist: - assertThrows(UnknownDeviceException.class, () -> storage.deleteDevice("user", "non-existent-device")); - - // Successfully delete device1: - storage.deleteDevice("user", "device1"); + storage.startTransaction(con -> { + storage.removeUser_Transaction(con, "user"); + storage.commitTransaction(con); + return null; + }); storedDevices = storage.getDevices("user"); - assert (storedDevices.length == 1); // device2 should still be there + assert (storedDevices.length == 0); - long nextDay = System.currentTimeMillis() + 1000 * 60 * 60 * 24; // 1 day from now - long now = System.currentTimeMillis(); - // Deleting all devices of a user should delete all related codes: - { - TOTPUsedCode validCode = new TOTPUsedCode("user", "valid-code", true, nextDay, now); - TOTPUsedCode invalidCode = new TOTPUsedCode("user", "invalid-code", false, nextDay, now); - storage.insertUsedCode(validCode); - storage.insertUsedCode(invalidCode); + storedUsedCodes = storage.getNonExpiredUsedCodes("user"); + assert (storedUsedCodes.length == 0); + } - storage.deleteDevice("user", "device2"); // delete device2 as well + @Test + public void deleteDevice_TransactionTests() throws Exception { + TestSetupResult result = setup(); + TOTPSQLStorage storage = result.storage; - TOTPUsedCode[] newUsedCodes = storage.getNonExpiredUsedCodes("user"); - assert (newUsedCodes.length == 0); + TOTPDevice device1 = new TOTPDevice("user", "device1", "sk1", 30, 1, false); + TOTPDevice device2 = new TOTPDevice("user", "device2", "sk2", 30, 1, false); + + storage.createDevice(device1); + storage.createDevice(device2); + + TOTPDevice[] storedDevices = storage.getDevices("user"); + assert (storedDevices.length == 2); + + // Try to delete a device for a user that doesn't exist (Should pass because + // this is DB level txn that doesn't throw TotpNotEnabledException): + storage.startTransaction(con -> { + storage.deleteDevice_Transaction(con, "non-existent-user", "device1"); + storage.commitTransaction(con); + return null; + }); + + // Try to delete a device that doesn't exist: + try { + storage.startTransaction(con -> { + storage.deleteDevice_Transaction(con, "user", "non-existent-device"); + storage.commitTransaction(con); + return null; + }); + } catch (Exception e) { + assert (e instanceof UnknownDeviceException) ? true : false; } + + // Successfully delete device1: + storage.startTransaction(con -> { + storage.deleteDevice_Transaction(con, "user", "device1"); + storage.commitTransaction(con); + return null; + }); + + storedDevices = storage.getDevices("user"); + assert (storedDevices.length == 1); // device2 should still be there } @Test public void updateDeviceNameTests() throws Exception { TestSetupResult result = setup(); - TOTPStorage storage = result.storage; + TOTPSQLStorage storage = result.storage; TOTPDevice device = new TOTPDevice("user", "device", "secretKey", 30, 1, false); storage.createDevice(device); @@ -193,7 +273,7 @@ public void updateDeviceNameTests() throws Exception { @Test public void getDevicesTest() throws Exception { TestSetupResult result = setup(); - TOTPStorage storage = result.storage; + TOTPSQLStorage storage = result.storage; TOTPDevice device1 = new TOTPDevice("user", "d1", "secretKey", 30, 1, false); TOTPDevice device2 = new TOTPDevice("user", "d2", "secretKey", 30, 1, false); @@ -214,7 +294,7 @@ public void getDevicesTest() throws Exception { @Test public void insertUsedCodeTest() throws Exception { TestSetupResult result = setup(); - TOTPStorage storage = result.storage; + TOTPSQLStorage storage = result.storage; long nextDay = System.currentTimeMillis() + 1000 * 60 * 60 * 24; // 1 day from now // Insert a long lasting valid code and check that it's returned when queried: @@ -232,10 +312,10 @@ public void insertUsedCodeTest() throws Exception { // Try to insert code when user doesn't have any device (i.e. TOTP not enabled) { - storage.deleteDevice("user", "device"); assertThrows(TotpNotEnabledException.class, () -> storage.insertUsedCode( - new TOTPUsedCode("user", "1234", true, nextDay, System.currentTimeMillis()))); + new TOTPUsedCode("new-user-without-totp", "1234", true, nextDay, + System.currentTimeMillis()))); } // Try to insert code after user has atleast one device (i.e. TOTP enabled) @@ -254,37 +334,43 @@ public void insertUsedCodeTest() throws Exception { @Test public void getNonExpiredUsedCodesTest() throws Exception { TestSetupResult result = setup(); - TOTPStorage storage = result.storage; + TOTPSQLStorage storage = result.storage; TOTPUsedCode[] usedCodes = storage.getNonExpiredUsedCodes("non-existent-user"); assert (usedCodes.length == 0); - long nextDay = System.currentTimeMillis() + 1000 * 60 * 60 * 24; // 1 day from now - long prevDay = System.currentTimeMillis() - 1000 * 60 * 60 * 24; // 1 day ago long now = System.currentTimeMillis(); + long nextDay = now + 1000 * 60 * 60 * 24; // 1 day from now + long prevDay = now - 1000 * 60 * 60 * 24; // 1 day ago TOTPDevice device = new TOTPDevice("user", "device", "secretKey", 30, 1, false); - TOTPUsedCode validCode = new TOTPUsedCode("user", "code1", true, nextDay, now); - TOTPUsedCode invalidCode = new TOTPUsedCode("user", "code2", false, nextDay, now); + TOTPUsedCode validCode1 = new TOTPUsedCode("user", "valid-code-1", true, nextDay, now); + TOTPUsedCode invalidCode = new TOTPUsedCode("user", "invalid-code", false, nextDay, now); TOTPUsedCode expiredCode = new TOTPUsedCode("user", "expired-code", true, prevDay, now); TOTPUsedCode expiredInvalidCode = new TOTPUsedCode("user", "expired-invalid-code", false, prevDay, now); + TOTPUsedCode validCode2 = new TOTPUsedCode("user", "valid-code-2", true, nextDay, now + 1); + TOTPUsedCode validCode3 = new TOTPUsedCode("user", "valid-code-3", true, nextDay, now + 2); storage.createDevice(device); - storage.insertUsedCode(validCode); + storage.insertUsedCode(validCode1); storage.insertUsedCode(invalidCode); storage.insertUsedCode(expiredCode); storage.insertUsedCode(expiredInvalidCode); + storage.insertUsedCode(validCode2); + storage.insertUsedCode(validCode3); usedCodes = storage.getNonExpiredUsedCodes("user"); - assert (usedCodes.length == 2); // expired codes shouldn't be returned - assert (usedCodes[0].equals(validCode)); - assert (usedCodes[1].equals(invalidCode)); + assert (usedCodes.length == 4); // expired codes shouldn't be returned + assert (usedCodes[0].equals(validCode3)); // order is DESC by created time (now + X) + assert (usedCodes[1].equals(validCode2)); + assert (usedCodes[2].equals(validCode1)); + assert (usedCodes[3].equals(invalidCode)); } @Test public void removeExpiredCodesTest() throws Exception { TestSetupResult result = setup(); - TOTPStorage storage = result.storage; + TOTPSQLStorage storage = result.storage; long now = System.currentTimeMillis(); long nextDay = System.currentTimeMillis() + 1000 * 60 * 60 * 24; // 1 day from now @@ -308,7 +394,7 @@ public void removeExpiredCodesTest() throws Exception { // After 500ms seconds pass: Thread.sleep(500); - storage.removeExpiredCodes(); + storage.removeExpiredCodes(5); // FIXME:::: usedCodes = storage.getNonExpiredUsedCodes("user"); assert (usedCodes.length == 2); @@ -319,7 +405,7 @@ public void removeExpiredCodesTest() throws Exception { @Test public void deleteAllDataForUserTest() throws Exception { TestSetupResult result = setup(); - TOTPStorage storage = result.storage; + TOTPSQLStorage storage = result.storage; long now = System.currentTimeMillis(); long nextDay = now + 1000 * 60 * 60 * 24; // 1 day from now From a83c6f638df5d40383a95a1f2f366e3430fb80cf Mon Sep 17 00:00:00 2001 From: KShivendu Date: Tue, 28 Feb 2023 16:59:48 +0530 Subject: [PATCH 19/42] feat: Improve TOTP recipe - Use device period in totp generation and validation - Add tests to cover most edge cases of Totp.java - Fix overriding of totp_rate_limit_cooldown_sec - Add tests for TOTP cron - Add comments for edge cases and readability - Introduce totp_invalid_code_expiry_sec config --- config.yaml | 9 +- devConfig.yaml | 7 +- .../io/supertokens/config/CoreConfig.java | 20 +- .../DeleteExpiredTotpTokens.java | 10 +- .../java/io/supertokens/inmemorydb/Start.java | 4 +- .../inmemorydb/queries/TOTPQueries.java | 4 +- src/main/java/io/supertokens/totp/Totp.java | 51 +++- .../io/supertokens/test/ConfigTest2_6.java | 1 + .../supertokens/test/totp/TOTPRecipeTest.java | 269 ++++++++++++++---- .../test/totp/TOTPStorageTest.java | 26 +- 10 files changed, 304 insertions(+), 97 deletions(-) diff --git a/config.yaml b/config.yaml index 489458b0c..c239836d9 100644 --- a/config.yaml +++ b/config.yaml @@ -57,8 +57,11 @@ core_config_version: 0 # (OPTIONAL | Default: 5) integer value. The maximum number of invalid TOTP attempts that will trigger rate limiting. # totp_max_attempts: -# (OPTIONAL | Default: 900) integer value. The time in seconds for which the user will be rate limited (once totp_max_attempts is crossed) -# totp_rate_limit_cooldown_time: +# (OPTIONAL | Default: 900) integer value. The time in seconds for which the user will be rate limited once totp_max_attempts is crossed. +# totp_rate_limit_cooldown_sec: + +# (OPTIONAL | Default: 1800) integer value. The time in seconds in which invalid TOTP codes will be considered expired. +# totp_invalid_code_expiry_sec: # (OPTIONAL | Default: installation directory/logs/info.log) string value. Give the path to a file (on your local # system) in which the SuperTokens service can write INFO logs to. Set it to "null" if you want it to log to @@ -125,4 +128,4 @@ core_config_version: 0 # (OPTIONAL | Default: null). Regex for denying requests from IP addresses that match with the value. Comment this # value to deny no IP address. -# ip_deny_regex: \ No newline at end of file +# ip_deny_regex: diff --git a/devConfig.yaml b/devConfig.yaml index 15f8ffa12..8969d5f5a 100644 --- a/devConfig.yaml +++ b/devConfig.yaml @@ -57,8 +57,11 @@ core_config_version: 0 # (OPTIONAL | Default: 5) integer value. The maximum number of invalid TOTP attempts that will trigger rate limiting. # totp_max_attempts: -# (OPTIONAL | Default: 900) integer value. The time in seconds for which the user will be rate limited (once totp_max_attempts is crossed) -# totp_rate_limit_cooldown_time: +# (OPTIONAL | Default: 900) integer value. The time in seconds for which the user will be rate limited once totp_max_attempts is crossed. +# totp_rate_limit_cooldown_sec: + +# (OPTIONAL | Default: 1800) integer value. The time in seconds in which invalid TOTP codes will be considered expired. +# totp_invalid_code_expiry_sec: # (OPTIONAL | Default: installation directory/logs/info.log) string value. Give the path to a file (on your local # system) in which the SuperTokens service can write INFO logs to. Set it to "null" if you want it to log to diff --git a/src/main/java/io/supertokens/config/CoreConfig.java b/src/main/java/io/supertokens/config/CoreConfig.java index 1b3e57df6..383e0caab 100644 --- a/src/main/java/io/supertokens/config/CoreConfig.java +++ b/src/main/java/io/supertokens/config/CoreConfig.java @@ -60,7 +60,10 @@ public class CoreConfig { private int totp_max_attempts = 5; @JsonProperty - private int totp_rate_limit_cooldown_time = 900; // in seconds (Default 15 mins) + private int totp_rate_limit_cooldown_sec = 900; // in seconds (Default 15 mins) + + @JsonProperty + private int totp_invalid_code_expiry_sec = 1800; // in seconds (Default 30 mins) private final String logDefault = "asdkfahbdfk3kjHS"; @JsonProperty @@ -283,7 +286,12 @@ public int getTotpMaxAttempts() { /** TOTP rate limit cooldown time (in seconds) */ public int getTotpRateLimitCooldownTime() { - return totp_rate_limit_cooldown_time; + return totp_rate_limit_cooldown_sec; + } + + /** TOTP invalid code expiry time (in seconds) */ + public int getTotpInvalidCodeExpiryTime() { + return totp_invalid_code_expiry_sec; } public boolean isTelemetryDisabled() { @@ -408,8 +416,12 @@ void validateAndInitialise(Main main) throws IOException { throw new QuitProgramException("'totp_max_attempts' must be > 0"); } - if (totp_rate_limit_cooldown_time <= 0) { - throw new QuitProgramException("'totp_rate_limit_cooldown_time' must be > 0"); + if (totp_rate_limit_cooldown_sec <= 0) { + throw new QuitProgramException("'totp_rate_limit_cooldown_sec' must be > 0"); + } + + if (totp_invalid_code_expiry_sec <= 0) { + throw new QuitProgramException("'totp_invalid_code_expiry_sec' must be > 0"); } if (max_server_pool_size <= 0) { diff --git a/src/main/java/io/supertokens/cronjobs/deleteExpiredTotpTokens/DeleteExpiredTotpTokens.java b/src/main/java/io/supertokens/cronjobs/deleteExpiredTotpTokens/DeleteExpiredTotpTokens.java index fa61f4db7..32d2d6084 100644 --- a/src/main/java/io/supertokens/cronjobs/deleteExpiredTotpTokens/DeleteExpiredTotpTokens.java +++ b/src/main/java/io/supertokens/cronjobs/deleteExpiredTotpTokens/DeleteExpiredTotpTokens.java @@ -1,5 +1,7 @@ package io.supertokens.cronjobs.deleteExpiredTotpTokens; +import org.jetbrains.annotations.TestOnly; + import io.supertokens.Main; import io.supertokens.ResourceDistributor; import io.supertokens.pluginInterface.STORAGE_TYPE; @@ -24,6 +26,11 @@ public static DeleteExpiredTotpTokens getInstance(Main main) { return (DeleteExpiredTotpTokens) instance; } + @TestOnly + public void doTaskForTest() throws Exception { + doTask(); + } + @Override protected void doTask() throws Exception { if (StorageLayer.getStorage(this.main).getType() != STORAGE_TYPE.SQL) { @@ -32,8 +39,7 @@ protected void doTask() throws Exception { TOTPSQLStorage storage = StorageLayer.getTOTPStorage(this.main); - int N = 5; // FIXME:: This is not used anywhere - storage.removeExpiredCodes(N); + storage.removeExpiredCodes(); } @Override diff --git a/src/main/java/io/supertokens/inmemorydb/Start.java b/src/main/java/io/supertokens/inmemorydb/Start.java index ca91f51a3..0e0b4c882 100644 --- a/src/main/java/io/supertokens/inmemorydb/Start.java +++ b/src/main/java/io/supertokens/inmemorydb/Start.java @@ -1750,10 +1750,10 @@ public TOTPUsedCode[] getNonExpiredUsedCodes(String userId) } @Override - public void removeExpiredCodes(int N) + public void removeExpiredCodes() throws StorageQueryException { try { - TOTPQueries.removeExpiredCodes(this, N); + TOTPQueries.removeExpiredCodes(this); } catch (SQLException e) { throw new StorageQueryException(e); } diff --git a/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java b/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java index 21cf1a86d..506291357 100644 --- a/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java +++ b/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java @@ -264,10 +264,8 @@ public static TOTPUsedCode[] getNonExpiredUsedCodesDescOrder(Start start, String }); } - public static int removeExpiredCodes(Start start, int offset) + public static int removeExpiredCodes(Start start) throws StorageQueryException, SQLException { - // TODO: Use offset with `order by created_at desc` to exclude latest N codes - // for each user String QUERY = "DELETE FROM " + Config.getConfig(start).getTotpUsedCodesTable() + " WHERE expiry_time_ms < ?;"; diff --git a/src/main/java/io/supertokens/totp/Totp.java b/src/main/java/io/supertokens/totp/Totp.java index 42e6ef6d8..b242e7cce 100644 --- a/src/main/java/io/supertokens/totp/Totp.java +++ b/src/main/java/io/supertokens/totp/Totp.java @@ -4,6 +4,7 @@ import java.security.Key; import java.security.NoSuchAlgorithmException; import java.security.SecureRandom; +import java.time.Duration; import java.time.Instant; import java.util.Arrays; import java.util.Base64; @@ -30,6 +31,7 @@ import io.supertokens.storageLayer.StorageLayer; import io.supertokens.totp.exceptions.InvalidTotpException; import io.supertokens.totp.exceptions.LimitReachedException; +import jakarta.annotation.Nullable; public class Totp { private static String generateSecret() throws NoSuchAlgorithmException { @@ -47,7 +49,8 @@ private static String generateSecret() throws NoSuchAlgorithmException { } private static boolean checkCode(TOTPDevice device, String code) { - final TimeBasedOneTimePasswordGenerator totp = new TimeBasedOneTimePasswordGenerator(); + final TimeBasedOneTimePasswordGenerator totp = new TimeBasedOneTimePasswordGenerator( + Duration.ofSeconds(device.period)); byte[] keyBytes = Base64.getDecoder().decode(device.secretKey); Key key = new SecretKeySpec(keyBytes, "HmacSHA1"); @@ -58,7 +61,6 @@ private static boolean checkCode(TOTPDevice device, String code) { // Check if code is valid for any of the time periods in the skew: for (int i = -skew; i <= skew; i++) { try { - // FIXME: Where is this using % on the period? if (totp.generateOneTimePasswordString(key, Instant.now().plusSeconds(i * period)).equals(code)) { return true; } @@ -70,31 +72,40 @@ private static boolean checkCode(TOTPDevice device, String code) { return false; } + @TestOnly + public static String generateTotpCode(Main main, TOTPDevice device) + throws InvalidKeyException, StorageQueryException { + return generateTotpCode(main, device, 0); + } + /** * Replicates TOTP code is generated by apps like Google Authenticator and Authy - * + * * @throws StorageQueryException */ @TestOnly - public static String generateTotpCode(Main main, String userId, String deviceName) + public static String generateTotpCode(Main main, TOTPDevice device, int step) throws InvalidKeyException, StorageQueryException { - TOTPDevice[] devices = StorageLayer.getTOTPStorage(main).getDevices(userId); - TOTPDevice device = Arrays.stream(devices).filter(d -> d.deviceName.equals(deviceName)).findFirst() - .orElse(null); - - final TimeBasedOneTimePasswordGenerator totp = new TimeBasedOneTimePasswordGenerator(); + final TimeBasedOneTimePasswordGenerator totp = new TimeBasedOneTimePasswordGenerator( + Duration.ofSeconds(device.period)); byte[] keyBytes = Base64.getDecoder().decode(device.secretKey); Key key = new SecretKeySpec(keyBytes, "HmacSHA1"); - return totp.generateOneTimePasswordString(key, Instant.now()); + return totp.generateOneTimePasswordString(key, Instant.now().plusSeconds(step * device.period)); } - public static String registerDevice(Main main, String userId, String deviceName, int skew, int period) + public static TOTPDevice registerDevice(Main main, String userId, String deviceName, int skew, int period) throws StorageQueryException, DeviceAlreadyExistsException, NoSuchAlgorithmException { TOTPSQLStorage totpStorage = StorageLayer.getTOTPStorage(main); + // Assert that period > 0 (not == 0 because it would lead to a divide by 0 + // error) + // Assert that period <= 60. Otherwise, it is a security risk. Actually, + // anything > 30 is bad. + // and skew >= 0 and skew <= 2. Otherwise, it is a security risk. + // TODO: There should be a hard limit on number of devices per user // 8 devices per user should be enough. Otherwise, it is a security risk. @@ -102,7 +113,7 @@ public static String registerDevice(Main main, String userId, String deviceName, TOTPDevice device = new TOTPDevice(userId, deviceName, secret, period, skew, false); totpStorage.createDevice(device); - return secret; + return device; } private static void checkAndStoreCode(Main main, TOTPStorage totpStorage, String userId, TOTPDevice[] devices, @@ -126,6 +137,19 @@ private static void checkAndStoreCode(Main main, TOTPStorage totpStorage, String // Less than rateLimitResetTimeInMs (default = 15 mins) time has elasped since // the last invalid code: throw new LimitReachedException(); + + // If we insert the used code here, then it will further delay the user from + // being able to login. So not inserting it here. + + // Note: One edge case here is: user is rate limited, and then the + // DeleteExpiredTotpTokens cron removes the latest invalid attempts + // (because they are expired), and then user will again be able to + // do extra login attempts (totp_max_attempts times). + // But the cron running during cooldown of a user is somewhat rare. + // And every period has enough entropy to make brute force attacks + // infeasible, the code changes every period, and rate limiting will + // kick in after totp_max_attempts number of attempts anyways. + // So this edge case is not a big deal. } } @@ -159,7 +183,8 @@ private static void checkAndStoreCode(Main main, TOTPStorage totpStorage, String // Insert the code into the list of used codes: long now = System.currentTimeMillis(); - int expireInSec = isValid ? matchingDevice.period * (2 * matchingDevice.skew + 1) : 60 * 30; // 30 mins + int invalidCodeExpirySec = Config.getConfig(main).getTotpInvalidCodeExpiryTime(); // (Default 30 mins) + int expireInSec = isValid ? matchingDevice.period * (2 * matchingDevice.skew + 1) : invalidCodeExpirySec; TOTPUsedCode newCode = new TOTPUsedCode(userId, code, isValid, now + 1000 * expireInSec, now); totpStorage.insertUsedCode(newCode); diff --git a/src/test/java/io/supertokens/test/ConfigTest2_6.java b/src/test/java/io/supertokens/test/ConfigTest2_6.java index c40491762..50faca16d 100644 --- a/src/test/java/io/supertokens/test/ConfigTest2_6.java +++ b/src/test/java/io/supertokens/test/ConfigTest2_6.java @@ -223,6 +223,7 @@ private static void checkConfigValues(CoreConfig config, TestingProcess process, // TODO: Is this correct? assertEquals(5, config.getTotpMaxAttempts()); // 5 assertEquals(900, config.getTotpRateLimitCooldownTime()); // 15 minutes + assertEquals(5, config.getTotpInvalidCodeExpiryTime()); // 30 minutes assertEquals("Config info log path did not match default", config.getInfoLogPath(process.getProcess()), CLIOptions.get(process.getProcess()).getInstallationPath() + "logs/info.log"); diff --git a/src/test/java/io/supertokens/test/totp/TOTPRecipeTest.java b/src/test/java/io/supertokens/test/totp/TOTPRecipeTest.java index 68098762d..f23d2897d 100644 --- a/src/test/java/io/supertokens/test/totp/TOTPRecipeTest.java +++ b/src/test/java/io/supertokens/test/totp/TOTPRecipeTest.java @@ -32,6 +32,8 @@ import io.supertokens.ProcessState; import io.supertokens.config.Config; import io.supertokens.config.CoreConfig; +import io.supertokens.cronjobs.CronTaskTest; +import io.supertokens.cronjobs.deleteExpiredTotpTokens.DeleteExpiredTotpTokens; import io.supertokens.pluginInterface.STORAGE_TYPE; import io.supertokens.storageLayer.StorageLayer; import io.supertokens.test.TestingProcessManager; @@ -71,7 +73,7 @@ public TestSetupResult(TOTPStorage storage, TestingProcessManager.TestingProcess } } - public TestSetupResult setup() throws InterruptedException, IOException { + public TestSetupResult defaultInit() throws InterruptedException, IOException { String[] args = { "../" }; TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); @@ -86,13 +88,13 @@ public TestSetupResult setup() throws InterruptedException, IOException { } @Test - public void createDevice() throws Exception { - TestSetupResult result = setup(); + public void createDeviceTest() throws Exception { + TestSetupResult result = defaultInit(); Main main = result.process.getProcess(); // Create device - String secret = Totp.registerDevice(main, "user", "device1", 1, 30); - assert secret != ""; + TOTPDevice device = Totp.registerDevice(main, "user", "device1", 1, 30); + assert device.secretKey != ""; // Create same device again (should fail) assertThrows(DeviceAlreadyExistsException.class, @@ -100,12 +102,12 @@ public void createDevice() throws Exception { } @Test - public void createDeviceAndVerifyCode() throws Exception { - TestSetupResult result = setup(); + public void createDeviceAndVerifyCodeTest() throws Exception { + TestSetupResult result = defaultInit(); Main main = result.process.getProcess(); // Create device - String secret = Totp.registerDevice(main, "user", "device", 1, 30); + TOTPDevice device = Totp.registerDevice(main, "user", "device", 1, 1); // Try login with non-existent user: assertThrows(TotpNotEnabledException.class, @@ -124,10 +126,10 @@ public void createDeviceAndVerifyCode() throws Exception { // Valid code & allowUnverifiedDevice = false: assertThrows( InvalidTotpException.class, - () -> Totp.verifyCode(main, "user", Totp.generateTotpCode(main, "user", "device"), false)); + () -> Totp.verifyCode(main, "user", Totp.generateTotpCode(main, device), false)); // Valid code & allowUnverifiedDevice = true (Success): - String validCode = Totp.generateTotpCode(main, "user", "device"); + String validCode = Totp.generateTotpCode(main, device); Totp.verifyCode(main, "user", validCode, true); // Now try again with same code: @@ -135,13 +137,36 @@ public void createDeviceAndVerifyCode() throws Exception { InvalidTotpException.class, () -> Totp.verifyCode(main, "user", validCode, true)); + // Sleep for 1s so that code changes. + Thread.sleep(1500); + // Use a new valid code: - String newValidCode = Totp.generateTotpCode(main, "user", "device"); + String newValidCode = Totp.generateTotpCode(main, device); Totp.verifyCode(main, "user", newValidCode, true); + + // Regenerate the same code and use it again (should fail): + String newValidCodeCopy = Totp.generateTotpCode(main, device); + assertThrows(InvalidTotpException.class, + () -> Totp.verifyCode(main, "user", newValidCodeCopy, true)); + + // Use a code from next period: + String nextValidCode = Totp.generateTotpCode(main, device, 1); + Totp.verifyCode(main, "user", nextValidCode, true); + + // Use previous period code (should fail coz validCode): // FIXME: This should + // // fail + // String previousCode = Totp.generateTotpCode(main, "user", "device", -1); + // Totp.verifyCode(main, "user", previousCode, true); + + // TODO: Add tests for next and previous codes as well. + // TODO: Add tests for different skew values (0 and 1) + // TODO: Add tests where we change totp_max_attempts + // TODO: Add tests where we change totp_invalid_code_expiry_sec } - public void triggerAndCheckRateLimit(Main main, String userId, String deviceName) throws Exception { + public void triggerAndCheckRateLimit(Main main, TOTPDevice device) throws Exception { int N = Config.getConfig(main).getTotpMaxAttempts(); + // First N attempts should fail with invalid code: // This is to trigger rate limiting for (int i = 0; i < N; i++) { @@ -158,7 +183,7 @@ public void triggerAndCheckRateLimit(Main main, String userId, String deviceName () -> Totp.verifyCode(main, "user", "invalid-code-N+1", true)); assertThrows( LimitReachedException.class, - () -> Totp.verifyCode(main, "user", Totp.generateTotpCode(main, userId, deviceName), true)); + () -> Totp.verifyCode(main, "user", Totp.generateTotpCode(main, device), true)); assertThrows( LimitReachedException.class, () -> Totp.verifyCode(main, "user", "invalid-code-N+2", true)); @@ -169,7 +194,7 @@ public void rateLimitCooldownTest() throws Exception { String[] args = { "../" }; // set rate limiting cooldown time to 1s - Utils.setValueInConfig("totp_rate_limit_cooldown_time", "1"); + Utils.setValueInConfig("totp_rate_limit_cooldown_sec", "1"); TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); @@ -177,48 +202,62 @@ public void rateLimitCooldownTest() throws Exception { if (StorageLayer.getStorage(process.getProcess()).getType() != STORAGE_TYPE.SQL) { assert (false); } - TOTPStorage storage = StorageLayer.getTOTPStorage(process.getProcess()); Main main = process.getProcess(); // Create device - String secret = Totp.registerDevice(main, "user", "deviceName", 1, 30); + TOTPDevice device = Totp.registerDevice(main, "user", "deviceName", 1, 1); // Trigger rate limiting and fix it with a correct code after some time: - triggerAndCheckRateLimit(main, "user", "deviceName"); + triggerAndCheckRateLimit(main, device); // Wait for 1 second (Should cool down rate limiting): Thread.sleep(1000); // But again try with invalid code: - Totp.verifyCode(main, "user", "yet-another-invalid-code", true); + assertThrows(InvalidTotpException.class, () -> Totp.verifyCode(main, "user", "yet-another-invalid-code", true)); + // This triggered rate limiting again. So even valid codes will fail for + // another cooldown period: + assertThrows(LimitReachedException.class, + () -> Totp.verifyCode(main, "user", Totp.generateTotpCode(main, device), true)); // Wait for 1 second (Should cool down rate limiting): Thread.sleep(1000); // Now try with valid code: - Totp.verifyCode(main, "user", Totp.generateTotpCode(main, "user", "deviceName"), true); + Totp.verifyCode(main, "user", Totp.generateTotpCode(main, device), true); + // Now invalid code shouldn't trigger rate limiting. Unless you do it N times: + assertThrows(InvalidTotpException.class, () -> Totp.verifyCode(main, "user", "some-invalid-code", true)); } @Test public void removeExpiredCodesCronDuringRateLimitTest() throws Exception { - TestSetupResult result = setup(); + TestSetupResult result = defaultInit(); Main main = result.process.getProcess(); // Create device - String secret = Totp.registerDevice(main, "user", "deviceName", 1, 30); - - // Trigger rate limiting and fix it with cronjob (runs every 1 hour) - triggerAndCheckRateLimit(main, "user", "deviceName"); - // FIXME: Run cronjob at higher frequency: - assert false; - // Totp.runCron(main); - Totp.verifyCode(main, "user", "XXXX-code", true); + TOTPDevice device = Totp.registerDevice(main, "user", "deviceName", 0, 1); + + // Trigger rate limiting and fix it with cronjob (manually run cronjob): + triggerAndCheckRateLimit(main, device); + // Wait for 1 second so that all the codes expire: + Thread.sleep(1000); + // FIXME: Can this be cleaner? + DeleteExpiredTotpTokens.getInstance(main).doTaskForTest(); + // Will completely reset the rate limiting. Allowing the user to do N attempts + // here N == totp_max_attempts from the config: + assertThrows(InvalidTotpException.class, () -> Totp.verifyCode(main, "user", "again-wrong-code1", true)); + // This should have throws LimitReachedException but it won't because of cron: + assertThrows(InvalidTotpException.class, () -> Totp.verifyCode(main, "user", "again-wrong-code2", true)); + + Totp.verifyCode(main, "user", Totp.generateTotpCode(main, device), true); + // We can do N attempts again: + triggerAndCheckRateLimit(main, device); } @Test - public void createAndVerifyDevice() throws Exception { - TestSetupResult result = setup(); + public void createAndVerifyDeviceTest() throws Exception { + TestSetupResult result = defaultInit(); Main main = result.process.getProcess(); // Create device - String secret = Totp.registerDevice(main, "user", "deviceName", 1, 30); + TOTPDevice device = Totp.registerDevice(main, "user", "deviceName", 1, 30); // Try verify non-existent user: assertThrows(TotpNotEnabledException.class, @@ -232,7 +271,7 @@ public void createAndVerifyDevice() throws Exception { assertThrows(InvalidTotpException.class, () -> Totp.verifyDevice(main, "user", "deviceName", "wrong-code")); // Verify device with correct code - String validCode = Totp.generateTotpCode(main, "user", "deviceName"); + String validCode = Totp.generateTotpCode(main, device); boolean justVerfied = Totp.verifyDevice(main, "user", "deviceName", validCode); assert justVerfied; @@ -241,7 +280,7 @@ public void createAndVerifyDevice() throws Exception { assert !justVerfied; // Verify again with new correct code: - String newValidCode = Totp.generateTotpCode(main, "user", "deviceName"); + String newValidCode = Totp.generateTotpCode(main, device); justVerfied = Totp.verifyDevice(main, "user", "deviceName", newValidCode); assert !justVerfied; @@ -254,37 +293,157 @@ public void createAndVerifyDevice() throws Exception { } @Test - public void deleteDevice() throws Exception { - - // Deleting the last device of a user should delete all related codes: - TestSetupResult result = setup(); + public void removeDeviceTest() throws Exception { + TestSetupResult result = defaultInit(); Main main = result.process.getProcess(); + TOTPStorage storage = result.storage; - // Create device - String secret = Totp.registerDevice(main, "user", "device", 1, 30); + // Create devices + TOTPDevice device1 = Totp.registerDevice(main, "user", "device1", 1, 30); + TOTPDevice device2 = Totp.registerDevice(main, "user", "device2", 1, 30); - long nextDay = System.currentTimeMillis() + 1000 * 60 * 60 * 24; // 1 day - long now = System.currentTimeMillis(); + TOTPDevice[] devices = Totp.getDevices(main, "user"); + assert (devices.length == 2); + + // Delete one of the devices { - Totp.verifyCode(main, "user", "invalid-code", true); - Totp.verifyCode(main, "user", Totp.generateTotpCode(main, "user", "device"), true); - - // delete device2 as well - // storage.startTransaction(con -> { - // storage.deleteDevice_Transaction(con, "user", "device2"); - // storage.commitTransaction(con); - // return null; - // }); - - TOTPDevice[] devices = Totp.getDevices(main, "user"); - assert (devices.length == 0); + assertThrows(InvalidTotpException.class, () -> Totp.verifyCode(main, "user", "invalid-code", true)); + Totp.verifyCode(main, "user", Totp.generateTotpCode(main, device1), true); + Totp.verifyCode(main, "user", Totp.generateTotpCode(main, device2), true); + + // Delete device1 + Totp.removeDevice(main, "user", "device1"); + + devices = Totp.getDevices(main, "user"); + assert (devices.length == 1); + + // 1 device still remain so all codes should still be still there: + TOTPUsedCode[] usedCodes = storage.getNonExpiredUsedCodes("user"); + assert (usedCodes.length == 3); + } + + // Deleting the last device of a user should delete all related codes: + // Delete the 2nd (and the last) device + { + + // Create another user to test that other users aren't affected: + TOTPDevice otherUserDevice = Totp.registerDevice(main, "other-user", "device", 1, 30); + Totp.verifyCode(main, "other-user", Totp.generateTotpCode(main, otherUserDevice), true); + assertThrows(InvalidTotpException.class, () -> Totp.verifyCode(main, "other-user", "invalid-code", true)); + + // Delete device2 + Totp.removeDevice(main, "user", "device2"); + + // TOTP has ben disabled for the user: + assertThrows(TotpNotEnabledException.class, () -> Totp.getDevices(main, "user")); + + // No device left so all codes of the user should be deleted: + TOTPUsedCode[] usedCodes = storage.getNonExpiredUsedCodes("user"); + assert (usedCodes.length == 0); + + // But for other users things should still be there: + TOTPDevice[] otherUserDevices = Totp.getDevices(main, "other-user"); + assert (otherUserDevices.length == 1); + + usedCodes = storage.getNonExpiredUsedCodes("other-user"); + assert (usedCodes.length == 2); } } @Test - public void deleteUser() throws Exception { - // Deleting a user should delete all related devices and codes: + public void updateDeviceNameTest() throws Exception { + TestSetupResult result = defaultInit(); + Main main = result.process.getProcess(); + + Totp.registerDevice(main, "user", "device1", 1, 30); + Totp.registerDevice(main, "user", "device2", 1, 30); + + // Try update non-existent user: + assertThrows(TotpNotEnabledException.class, + () -> Totp.updateDeviceName(main, "non-existent-user", "device1", "new-device-name")); + + // Try update non-existent device: + assertThrows(UnknownDeviceException.class, + () -> Totp.updateDeviceName(main, "user", "non-existent-device", "new-device-name")); + // Update device name (should work) + Totp.updateDeviceName(main, "user", "device1", "new-device-name"); + + // Verify that the device name has been updated: + TOTPDevice[] devices = Totp.getDevices(main, "user"); + assert (devices.length == 2); + assert (devices[0].deviceName.equals("device2")); + assert (devices[1].deviceName.equals("new-device-name")); + + // Verify that TOTP verification still works: + Totp.verifyDevice(main, "user", devices[0].deviceName, Totp.generateTotpCode(main, devices[0])); + Totp.verifyDevice(main, "user", devices[0].deviceName, Totp.generateTotpCode(main, devices[1])); + + // Try update device name to an already existing device name: + assertThrows(DeviceAlreadyExistsException.class, + () -> Totp.updateDeviceName(main, "user", "device2", "new-device-name")); + } + + @Test + public void getDevicesTest() throws Exception { + TestSetupResult result = defaultInit(); + Main main = result.process.getProcess(); + + // Try get devices for non-existent user: + assertThrows(TotpNotEnabledException.class, () -> Totp.getDevices(main, "non-existent-user")); + + TOTPDevice device1 = Totp.registerDevice(main, "user", "device1", 2, 30); + TOTPDevice device2 = Totp.registerDevice(main, "user", "device2", 1, 10); + + TOTPDevice[] devices = Totp.getDevices(main, "user"); + assert (devices.length == 2); + assert devices[0].equals(device1); + assert devices[1].equals(device2); + } + + @Test + public void deleteExpiredTokensCronIntervalTest() throws Exception { + TestSetupResult result = defaultInit(); + Main main = result.process.getProcess(); + + // Ensure that delete expired tokens cron runs every hour: + assert DeleteExpiredTotpTokens.getInstance(main).getIntervalTimeSeconds() == 60 * 60; } + @Test + public void deleteExpiredTokensCronTest() throws Exception { + String[] args = { "../" }; + TestingProcessManager.TestingProcess process = TestingProcessManager.start(args, false); + + CronTaskTest.getInstance(process.getProcess()).setIntervalInSeconds(DeleteExpiredTotpTokens.RESOURCE_KEY, 1); + process.startProcess(); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + + Main main = process.getProcess(); + + // Create device + // Set period and skew to 0 to make sure that the codes are one time usable and + // expire in 1 second + TOTPDevice device = Totp.registerDevice(main, "user", "device", 0, 1); + + // Add codes: + Totp.verifyCode(main, "user", Totp.generateTotpCode(main, device), true); + assertThrows(InvalidTotpException.class, () -> Totp.verifyCode(main, "user", + "invalid-code", true)); + + TOTPStorage storage = StorageLayer.getTOTPStorage(process.getProcess()); + + // Verify that the codes have been added: + TOTPUsedCode[] usedCodes = storage.getNonExpiredUsedCodes("user"); + assert (usedCodes.length == 2); + + // Wait for 1 second to make sure that the valid codes expire + // (and crons deletes the valid ones since they are expired) + Thread.sleep(1000); + + usedCodes = storage.getNonExpiredUsedCodes("user"); + assert (usedCodes.length == 1); + // Invalid code will still remain because their expiration time is 5 minutes + assert usedCodes[0].code.equals("invalid-code"); + } } diff --git a/src/test/java/io/supertokens/test/totp/TOTPStorageTest.java b/src/test/java/io/supertokens/test/totp/TOTPStorageTest.java index e2a421441..c9ded66fa 100644 --- a/src/test/java/io/supertokens/test/totp/TOTPStorageTest.java +++ b/src/test/java/io/supertokens/test/totp/TOTPStorageTest.java @@ -49,7 +49,7 @@ public void beforeEach() { Utils.reset(); } - public TestSetupResult setup() throws InterruptedException { + public TestSetupResult initSteps() throws InterruptedException { String[] args = { "../" }; TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); @@ -65,7 +65,7 @@ public TestSetupResult setup() throws InterruptedException { @Test public void createDeviceTests() throws Exception { - TestSetupResult result = setup(); + TestSetupResult result = initSteps(); TOTPSQLStorage storage = result.storage; TOTPDevice device1 = new TOTPDevice("user", "d1", "secret", 30, 1, false); @@ -93,7 +93,7 @@ public void createDeviceTests() throws Exception { @Test public void verifyDeviceTests() throws Exception { - TestSetupResult result = setup(); + TestSetupResult result = initSteps(); TOTPSQLStorage storage = result.storage; TOTPDevice device = new TOTPDevice("user", "device", "secretKey", 30, 1, false); @@ -118,7 +118,7 @@ public void verifyDeviceTests() throws Exception { @Test public void getDevicesCount_TransactionTests() throws Exception { - TestSetupResult result = setup(); + TestSetupResult result = initSteps(); TOTPSQLStorage storage = result.storage; // Try to get the count for a user that doesn't exist (Should pass because @@ -146,7 +146,7 @@ public void getDevicesCount_TransactionTests() throws Exception { @Test public void removeUser_TransactionTests() throws Exception { - TestSetupResult result = setup(); + TestSetupResult result = initSteps(); TOTPSQLStorage storage = result.storage; // Try to remove a user that doesn't exist (Should pass because @@ -193,7 +193,7 @@ public void removeUser_TransactionTests() throws Exception { @Test public void deleteDevice_TransactionTests() throws Exception { - TestSetupResult result = setup(); + TestSetupResult result = initSteps(); TOTPSQLStorage storage = result.storage; TOTPDevice device1 = new TOTPDevice("user", "device1", "sk1", 30, 1, false); @@ -237,7 +237,7 @@ public void deleteDevice_TransactionTests() throws Exception { @Test public void updateDeviceNameTests() throws Exception { - TestSetupResult result = setup(); + TestSetupResult result = initSteps(); TOTPSQLStorage storage = result.storage; TOTPDevice device = new TOTPDevice("user", "device", "secretKey", 30, 1, false); @@ -272,7 +272,7 @@ public void updateDeviceNameTests() throws Exception { @Test public void getDevicesTest() throws Exception { - TestSetupResult result = setup(); + TestSetupResult result = initSteps(); TOTPSQLStorage storage = result.storage; TOTPDevice device1 = new TOTPDevice("user", "d1", "secretKey", 30, 1, false); @@ -293,7 +293,7 @@ public void getDevicesTest() throws Exception { @Test public void insertUsedCodeTest() throws Exception { - TestSetupResult result = setup(); + TestSetupResult result = initSteps(); TOTPSQLStorage storage = result.storage; long nextDay = System.currentTimeMillis() + 1000 * 60 * 60 * 24; // 1 day from now @@ -333,7 +333,7 @@ public void insertUsedCodeTest() throws Exception { @Test public void getNonExpiredUsedCodesTest() throws Exception { - TestSetupResult result = setup(); + TestSetupResult result = initSteps(); TOTPSQLStorage storage = result.storage; TOTPUsedCode[] usedCodes = storage.getNonExpiredUsedCodes("non-existent-user"); @@ -369,7 +369,7 @@ public void getNonExpiredUsedCodesTest() throws Exception { @Test public void removeExpiredCodesTest() throws Exception { - TestSetupResult result = setup(); + TestSetupResult result = initSteps(); TOTPSQLStorage storage = result.storage; long now = System.currentTimeMillis(); @@ -394,7 +394,7 @@ public void removeExpiredCodesTest() throws Exception { // After 500ms seconds pass: Thread.sleep(500); - storage.removeExpiredCodes(5); // FIXME:::: + storage.removeExpiredCodes(); usedCodes = storage.getNonExpiredUsedCodes("user"); assert (usedCodes.length == 2); @@ -404,7 +404,7 @@ public void removeExpiredCodesTest() throws Exception { @Test public void deleteAllDataForUserTest() throws Exception { - TestSetupResult result = setup(); + TestSetupResult result = initSteps(); TOTPSQLStorage storage = result.storage; long now = System.currentTimeMillis(); From 54ad75e450105a2226ab5d7322811e45ffd83bb0 Mon Sep 17 00:00:00 2001 From: KShivendu Date: Tue, 28 Feb 2023 18:31:55 +0530 Subject: [PATCH 20/42] feat: Improve TOTP rate limiting - Query all codes instead of only expired ones - Remove redundant deleteAllDataForUser from TOTPQueries - Move TOTP code generation to tests - Add logging to DeleteExpriedTotoTokens cron --- .../io/supertokens/authRecipe/AuthRecipe.java | 2 +- .../DeleteExpiredTotpTokens.java | 9 +- .../java/io/supertokens/inmemorydb/Start.java | 23 ++-- .../inmemorydb/queries/TOTPQueries.java | 55 +-------- src/main/java/io/supertokens/totp/Totp.java | 46 +++----- .../supertokens/test/totp/TOTPRecipeTest.java | 105 ++++++++++++------ .../test/totp/TOTPStorageTest.java | 28 +++-- 7 files changed, 125 insertions(+), 143 deletions(-) diff --git a/src/main/java/io/supertokens/authRecipe/AuthRecipe.java b/src/main/java/io/supertokens/authRecipe/AuthRecipe.java index 3080a48b4..d3e3c66e6 100644 --- a/src/main/java/io/supertokens/authRecipe/AuthRecipe.java +++ b/src/main/java/io/supertokens/authRecipe/AuthRecipe.java @@ -103,7 +103,7 @@ private static void deleteNonAuthRecipeUser(Main main, String userId) throws Sto StorageLayer.getSessionStorage(main).deleteSessionsOfUser(userId); StorageLayer.getEmailVerificationStorage(main).deleteEmailVerificationUserInfo(userId); StorageLayer.getUserRolesStorage(main).deleteAllRolesForUser(userId); - StorageLayer.getTOTPStorage(main).deleteAllDataForUser(userId); + StorageLayer.getTOTPStorage(main).deleteAllTotpDataForUser(userId); } private static void deleteAuthRecipeUser(Main main, String userId) throws StorageQueryException { diff --git a/src/main/java/io/supertokens/cronjobs/deleteExpiredTotpTokens/DeleteExpiredTotpTokens.java b/src/main/java/io/supertokens/cronjobs/deleteExpiredTotpTokens/DeleteExpiredTotpTokens.java index 32d2d6084..339367dfc 100644 --- a/src/main/java/io/supertokens/cronjobs/deleteExpiredTotpTokens/DeleteExpiredTotpTokens.java +++ b/src/main/java/io/supertokens/cronjobs/deleteExpiredTotpTokens/DeleteExpiredTotpTokens.java @@ -9,6 +9,7 @@ import io.supertokens.cronjobs.CronTask; import io.supertokens.cronjobs.CronTaskTest; import io.supertokens.storageLayer.StorageLayer; +import io.supertokens.output.Logging; public class DeleteExpiredTotpTokens extends CronTask { @@ -26,11 +27,6 @@ public static DeleteExpiredTotpTokens getInstance(Main main) { return (DeleteExpiredTotpTokens) instance; } - @TestOnly - public void doTaskForTest() throws Exception { - doTask(); - } - @Override protected void doTask() throws Exception { if (StorageLayer.getStorage(this.main).getType() != STORAGE_TYPE.SQL) { @@ -39,7 +35,8 @@ protected void doTask() throws Exception { TOTPSQLStorage storage = StorageLayer.getTOTPStorage(this.main); - storage.removeExpiredCodes(); + int deletedCount = storage.removeExpiredCodes(); + Logging.debug(this.main, "Cron DeleteExpiredTotpTokens deleted " + deletedCount + " expired TOTP codes"); } @Override diff --git a/src/main/java/io/supertokens/inmemorydb/Start.java b/src/main/java/io/supertokens/inmemorydb/Start.java index 0e0b4c882..df1bdbb86 100644 --- a/src/main/java/io/supertokens/inmemorydb/Start.java +++ b/src/main/java/io/supertokens/inmemorydb/Start.java @@ -1740,31 +1740,40 @@ public void insertUsedCode(TOTPUsedCode usedCodeObj) } @Override - public TOTPUsedCode[] getNonExpiredUsedCodes(String userId) + public TOTPUsedCode[] getAllUsedCodes(String userId) throws StorageQueryException { try { - return TOTPQueries.getNonExpiredUsedCodesDescOrder(this, userId); + return TOTPQueries.getAllUsedCodesDescOrder(this, userId); } catch (SQLException e) { throw new StorageQueryException(e); } } @Override - public void removeExpiredCodes() + public int removeExpiredCodes() throws StorageQueryException { try { - TOTPQueries.removeExpiredCodes(this); + return TOTPQueries.removeExpiredCodes(this); } catch (SQLException e) { throw new StorageQueryException(e); } } @Override - public void deleteAllDataForUser(String userId) throws StorageQueryException { + public void deleteAllTotpDataForUser(String userId) throws StorageQueryException { + // TODO: Logically this is corrrect. But is this the right way to do it? try { - TOTPQueries.deleteAllDataForUser(this, userId); + this.startTransaction(con -> { + Connection sqlCon = (Connection) con.getConnection(); + try { + TOTPQueries.removeUser_Transaction(this, sqlCon, userId); + } catch (SQLException e) { + throw new StorageTransactionLogicException(e); + } + return null; + }); } catch (StorageTransactionLogicException e) { - throw new StorageQueryException(e); + throw new StorageQueryException(e.actualException); } } } diff --git a/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java b/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java index 506291357..224a31ba6 100644 --- a/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java +++ b/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java @@ -140,31 +140,6 @@ public static int removeUser_Transaction(Start start, Connection con, String use return removedUsersCount; } - // public static int deleteDevice(Start start, String userId, String deviceName) - // throws StorageQueryException, StorageTransactionLogicException { - // return start.startTransaction(con -> { - // Connection sqlCon = (Connection) con.getConnection(); - - // try { - // int deletedCount = deleteDevice_Transaction(start, sqlCon, userId, - // deviceName); - // if (deletedCount > 0) { - // // Some device was deleted. Check if user has any other device left: - // int devicesCount = getDevicesCount_Transaction(start, sqlCon, userId); - // if (devicesCount == 0) { - // // no device left. delete user - // removeUser_Transaction(start, sqlCon, userId); - // } - // } - - // sqlCon.commit(); - // return deletedCount; - // } catch (SQLException e) { - // throw new StorageTransactionLogicException(e); - // } - // }); - // } - public static int updateDeviceName(Start start, String userId, String oldDeviceName, String newDeviceName) throws StorageQueryException, SQLException { String QUERY = "UPDATE " + Config.getConfig(start).getTotpUserDevicesTable() @@ -243,17 +218,16 @@ public static void insertUsedCode(Start start, TOTPUsedCode code) } /** - * Query to get all non expired used codes for a user in descending order of - * creation time. + * Query to get all used codes (expired/non-expired) for a user in descending + * order of creation time. */ - public static TOTPUsedCode[] getNonExpiredUsedCodesDescOrder(Start start, String userId) + public static TOTPUsedCode[] getAllUsedCodesDescOrder(Start start, String userId) throws SQLException, StorageQueryException { String QUERY = "SELECT * FROM " + Config.getConfig(start).getTotpUsedCodesTable() - + " WHERE user_id = ? AND expiry_time_ms > ? ORDER BY created_time_ms DESC;"; + + " WHERE user_id = ? ORDER BY created_time_ms DESC;"; return execute(start, QUERY, pst -> { pst.setString(1, userId); - pst.setLong(2, System.currentTimeMillis()); }, result -> { List codes = new ArrayList<>(); while (result.next()) { @@ -293,27 +267,6 @@ private static int deleteUser_Transaction(Start start, Connection con, String us return update(con, QUERY, pst -> pst.setString(1, userId)); } - public static void deleteAllDataForUser(Start start, String userId) - throws StorageQueryException, StorageTransactionLogicException { - start.startTransaction(con -> { - Connection sqlCon = (Connection) con.getConnection(); - - try { - // NOTE: These two steps are required only for in-memory db. - // Since foreign key constraints are not supported in in-memory db. - deleteAllDevices_Transaction(start, sqlCon, userId); - deleteAllUsedCodes_Transaction(start, sqlCon, userId); - - deleteUser_Transaction(start, sqlCon, userId); - sqlCon.commit(); - } catch (SQLException e) { - throw new StorageTransactionLogicException(e); - } - - return null; - }); - } - private static class TOTPDeviceRowMapper implements RowMapper { private static final TOTPDeviceRowMapper INSTANCE = new TOTPDeviceRowMapper(); diff --git a/src/main/java/io/supertokens/totp/Totp.java b/src/main/java/io/supertokens/totp/Totp.java index b242e7cce..a0acd2863 100644 --- a/src/main/java/io/supertokens/totp/Totp.java +++ b/src/main/java/io/supertokens/totp/Totp.java @@ -72,29 +72,6 @@ private static boolean checkCode(TOTPDevice device, String code) { return false; } - @TestOnly - public static String generateTotpCode(Main main, TOTPDevice device) - throws InvalidKeyException, StorageQueryException { - return generateTotpCode(main, device, 0); - } - - /** - * Replicates TOTP code is generated by apps like Google Authenticator and Authy - * - * @throws StorageQueryException - */ - @TestOnly - public static String generateTotpCode(Main main, TOTPDevice device, int step) - throws InvalidKeyException, StorageQueryException { - final TimeBasedOneTimePasswordGenerator totp = new TimeBasedOneTimePasswordGenerator( - Duration.ofSeconds(device.period)); - - byte[] keyBytes = Base64.getDecoder().decode(device.secretKey); - Key key = new SecretKeySpec(keyBytes, "HmacSHA1"); - - return totp.generateOneTimePasswordString(key, Instant.now().plusSeconds(step * device.period)); - } - public static TOTPDevice registerDevice(Main main, String userId, String deviceName, int skew, int period) throws StorageQueryException, DeviceAlreadyExistsException, NoSuchAlgorithmException { @@ -119,7 +96,14 @@ public static TOTPDevice registerDevice(Main main, String userId, String deviceN private static void checkAndStoreCode(Main main, TOTPStorage totpStorage, String userId, TOTPDevice[] devices, String code) throws InvalidTotpException, StorageQueryException, TotpNotEnabledException, LimitReachedException { - TOTPUsedCode[] usedCodes = totpStorage.getNonExpiredUsedCodes(userId); + // Note that here we are fetching all the codes (expired/non-expired). + // otherwise, because of differences in expiry time of different codes, we might + // end up with a situation where the will be released from the rate limiting too + // early because of some invalid codes in the checking window expired OR it can + // also lead to random rate limiting because if some valid codes blip out of the + // checking window and if it leads to N contagious invalid codes, then the user + // will be rate limited for no reason. + TOTPUsedCode[] usedCodes = totpStorage.getAllUsedCodes(userId); // N represents # of invalid attempts that will trigger rate limiting: int N = Config.getConfig(main).getTotpMaxAttempts(); // (Default 5) @@ -143,12 +127,11 @@ private static void checkAndStoreCode(Main main, TOTPStorage totpStorage, String // Note: One edge case here is: user is rate limited, and then the // DeleteExpiredTotpTokens cron removes the latest invalid attempts - // (because they are expired), and then user will again be able to - // do extra login attempts (totp_max_attempts times). - // But the cron running during cooldown of a user is somewhat rare. - // And every period has enough entropy to make brute force attacks - // infeasible, the code changes every period, and rate limiting will - // kick in after totp_max_attempts number of attempts anyways. + // (because they have expired), and then user will again be able to + // do extra login attempts (totp_max_attempts more times). + // But rate limiting will kick in after totp_max_attempts number + // disarming the brute force attack. + // Furthermore, the cron running during cooldown of a user is somewhat rare. // So this edge case is not a big deal. } } @@ -174,7 +157,8 @@ private static void checkAndStoreCode(Main main, TOTPStorage totpStorage, String // regenerated by the second device, then it won't allow the second device's // code to be used until it is expired. // But this would be rare so we can ignore it for now. - if (usedCode.code.equals(code) && usedCode.isValid) { + if (usedCode.code.equals(code) && usedCode.isValid + && usedCode.expiryTime > System.currentTimeMillis()) { isValid = false; matchingDevice = null; } diff --git a/src/test/java/io/supertokens/test/totp/TOTPRecipeTest.java b/src/test/java/io/supertokens/test/totp/TOTPRecipeTest.java index f23d2897d..a20fcc1de 100644 --- a/src/test/java/io/supertokens/test/totp/TOTPRecipeTest.java +++ b/src/test/java/io/supertokens/test/totp/TOTPRecipeTest.java @@ -20,6 +20,13 @@ import static org.junit.Assert.assertThrows; import java.io.IOException; +import java.security.InvalidKeyException; +import java.security.Key; +import java.time.Duration; +import java.time.Instant; +import java.util.Base64; + +import javax.crypto.spec.SecretKeySpec; import org.junit.AfterClass; import org.junit.Before; @@ -27,6 +34,8 @@ import org.junit.Test; import org.junit.rules.TestRule; +import com.eatthepath.otp.TimeBasedOneTimePasswordGenerator; + import io.supertokens.test.Utils; import io.supertokens.Main; import io.supertokens.ProcessState; @@ -35,6 +44,7 @@ import io.supertokens.cronjobs.CronTaskTest; import io.supertokens.cronjobs.deleteExpiredTotpTokens.DeleteExpiredTotpTokens; import io.supertokens.pluginInterface.STORAGE_TYPE; +import io.supertokens.pluginInterface.exceptions.StorageQueryException; import io.supertokens.storageLayer.StorageLayer; import io.supertokens.test.TestingProcessManager; @@ -87,6 +97,23 @@ public TestSetupResult defaultInit() throws InterruptedException, IOException { return new TestSetupResult(storage, process); } + private static String generateTotpCode(Main main, TOTPDevice device) + throws InvalidKeyException, StorageQueryException { + return generateTotpCode(main, device, 0); + } + + /** Generates TOTP code similar to apps like Google Authenticator and Authy */ + private static String generateTotpCode(Main main, TOTPDevice device, int step) + throws InvalidKeyException, StorageQueryException { + final TimeBasedOneTimePasswordGenerator totp = new TimeBasedOneTimePasswordGenerator( + Duration.ofSeconds(device.period)); + + byte[] keyBytes = Base64.getDecoder().decode(device.secretKey); + Key key = new SecretKeySpec(keyBytes, "HmacSHA1"); + + return totp.generateOneTimePasswordString(key, Instant.now().plusSeconds(step * device.period)); + } + @Test public void createDeviceTest() throws Exception { TestSetupResult result = defaultInit(); @@ -126,10 +153,10 @@ public void createDeviceAndVerifyCodeTest() throws Exception { // Valid code & allowUnverifiedDevice = false: assertThrows( InvalidTotpException.class, - () -> Totp.verifyCode(main, "user", Totp.generateTotpCode(main, device), false)); + () -> Totp.verifyCode(main, "user", generateTotpCode(main, device), false)); // Valid code & allowUnverifiedDevice = true (Success): - String validCode = Totp.generateTotpCode(main, device); + String validCode = generateTotpCode(main, device); Totp.verifyCode(main, "user", validCode, true); // Now try again with same code: @@ -141,27 +168,27 @@ public void createDeviceAndVerifyCodeTest() throws Exception { Thread.sleep(1500); // Use a new valid code: - String newValidCode = Totp.generateTotpCode(main, device); + String newValidCode = generateTotpCode(main, device); Totp.verifyCode(main, "user", newValidCode, true); // Regenerate the same code and use it again (should fail): - String newValidCodeCopy = Totp.generateTotpCode(main, device); + String newValidCodeCopy = generateTotpCode(main, device); assertThrows(InvalidTotpException.class, () -> Totp.verifyCode(main, "user", newValidCodeCopy, true)); // Use a code from next period: - String nextValidCode = Totp.generateTotpCode(main, device, 1); + String nextValidCode = generateTotpCode(main, device, 1); Totp.verifyCode(main, "user", nextValidCode, true); // Use previous period code (should fail coz validCode): // FIXME: This should // // fail - // String previousCode = Totp.generateTotpCode(main, "user", "device", -1); + // String previousCode = generateTotpCode(main, "user", "device", -1); // Totp.verifyCode(main, "user", previousCode, true); - // TODO: Add tests for next and previous codes as well. - // TODO: Add tests for different skew values (0 and 1) - // TODO: Add tests where we change totp_max_attempts - // TODO: Add tests where we change totp_invalid_code_expiry_sec + // TODO: Add isolated tests where we + // - we try next and previous codes as well (try different skew values) + // - change totp_max_attempts + // - change totp_invalid_code_expiry_sec } public void triggerAndCheckRateLimit(Main main, TOTPDevice device) throws Exception { @@ -183,7 +210,7 @@ public void triggerAndCheckRateLimit(Main main, TOTPDevice device) throws Except () -> Totp.verifyCode(main, "user", "invalid-code-N+1", true)); assertThrows( LimitReachedException.class, - () -> Totp.verifyCode(main, "user", Totp.generateTotpCode(main, device), true)); + () -> Totp.verifyCode(main, "user", generateTotpCode(main, device), true)); assertThrows( LimitReachedException.class, () -> Totp.verifyCode(main, "user", "invalid-code-N+2", true)); @@ -217,19 +244,25 @@ public void rateLimitCooldownTest() throws Exception { // This triggered rate limiting again. So even valid codes will fail for // another cooldown period: assertThrows(LimitReachedException.class, - () -> Totp.verifyCode(main, "user", Totp.generateTotpCode(main, device), true)); + () -> Totp.verifyCode(main, "user", generateTotpCode(main, device), true)); // Wait for 1 second (Should cool down rate limiting): Thread.sleep(1000); // Now try with valid code: - Totp.verifyCode(main, "user", Totp.generateTotpCode(main, device), true); + Totp.verifyCode(main, "user", generateTotpCode(main, device), true); // Now invalid code shouldn't trigger rate limiting. Unless you do it N times: assertThrows(InvalidTotpException.class, () -> Totp.verifyCode(main, "user", "some-invalid-code", true)); } @Test - public void removeExpiredCodesCronDuringRateLimitTest() throws Exception { - TestSetupResult result = defaultInit(); - Main main = result.process.getProcess(); + public void cronRemovesAllCodesDuringRateLimitTest() throws Exception { + String[] args = { "../" }; + TestingProcessManager.TestingProcess process = TestingProcessManager.start(args, false); + + Utils.setValueInConfig("totp_invalid_code_expiry_sec", "1"); + process.startProcess(); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + + Main main = process.getProcess(); // Create device TOTPDevice device = Totp.registerDevice(main, "user", "deviceName", 0, 1); @@ -237,16 +270,16 @@ public void removeExpiredCodesCronDuringRateLimitTest() throws Exception { // Trigger rate limiting and fix it with cronjob (manually run cronjob): triggerAndCheckRateLimit(main, device); // Wait for 1 second so that all the codes expire: - Thread.sleep(1000); - // FIXME: Can this be cleaner? - DeleteExpiredTotpTokens.getInstance(main).doTaskForTest(); + Thread.sleep(2000); + // Manually run cronjob to delete all the codes: (Here all of them are expired) + DeleteExpiredTotpTokens.getInstance(main).run(); // Will completely reset the rate limiting. Allowing the user to do N attempts // here N == totp_max_attempts from the config: assertThrows(InvalidTotpException.class, () -> Totp.verifyCode(main, "user", "again-wrong-code1", true)); // This should have throws LimitReachedException but it won't because of cron: assertThrows(InvalidTotpException.class, () -> Totp.verifyCode(main, "user", "again-wrong-code2", true)); - Totp.verifyCode(main, "user", Totp.generateTotpCode(main, device), true); + Totp.verifyCode(main, "user", generateTotpCode(main, device), true); // We can do N attempts again: triggerAndCheckRateLimit(main, device); } @@ -271,7 +304,7 @@ public void createAndVerifyDeviceTest() throws Exception { assertThrows(InvalidTotpException.class, () -> Totp.verifyDevice(main, "user", "deviceName", "wrong-code")); // Verify device with correct code - String validCode = Totp.generateTotpCode(main, device); + String validCode = generateTotpCode(main, device); boolean justVerfied = Totp.verifyDevice(main, "user", "deviceName", validCode); assert justVerfied; @@ -280,7 +313,7 @@ public void createAndVerifyDeviceTest() throws Exception { assert !justVerfied; // Verify again with new correct code: - String newValidCode = Totp.generateTotpCode(main, device); + String newValidCode = generateTotpCode(main, device); justVerfied = Totp.verifyDevice(main, "user", "deviceName", newValidCode); assert !justVerfied; @@ -308,8 +341,8 @@ public void removeDeviceTest() throws Exception { // Delete one of the devices { assertThrows(InvalidTotpException.class, () -> Totp.verifyCode(main, "user", "invalid-code", true)); - Totp.verifyCode(main, "user", Totp.generateTotpCode(main, device1), true); - Totp.verifyCode(main, "user", Totp.generateTotpCode(main, device2), true); + Totp.verifyCode(main, "user", generateTotpCode(main, device1), true); + Totp.verifyCode(main, "user", generateTotpCode(main, device2), true); // Delete device1 Totp.removeDevice(main, "user", "device1"); @@ -318,7 +351,7 @@ public void removeDeviceTest() throws Exception { assert (devices.length == 1); // 1 device still remain so all codes should still be still there: - TOTPUsedCode[] usedCodes = storage.getNonExpiredUsedCodes("user"); + TOTPUsedCode[] usedCodes = storage.getAllUsedCodes("user"); assert (usedCodes.length == 3); } @@ -328,7 +361,7 @@ public void removeDeviceTest() throws Exception { // Create another user to test that other users aren't affected: TOTPDevice otherUserDevice = Totp.registerDevice(main, "other-user", "device", 1, 30); - Totp.verifyCode(main, "other-user", Totp.generateTotpCode(main, otherUserDevice), true); + Totp.verifyCode(main, "other-user", generateTotpCode(main, otherUserDevice), true); assertThrows(InvalidTotpException.class, () -> Totp.verifyCode(main, "other-user", "invalid-code", true)); // Delete device2 @@ -338,14 +371,14 @@ public void removeDeviceTest() throws Exception { assertThrows(TotpNotEnabledException.class, () -> Totp.getDevices(main, "user")); // No device left so all codes of the user should be deleted: - TOTPUsedCode[] usedCodes = storage.getNonExpiredUsedCodes("user"); + TOTPUsedCode[] usedCodes = storage.getAllUsedCodes("user"); assert (usedCodes.length == 0); // But for other users things should still be there: TOTPDevice[] otherUserDevices = Totp.getDevices(main, "other-user"); assert (otherUserDevices.length == 1); - usedCodes = storage.getNonExpiredUsedCodes("other-user"); + usedCodes = storage.getAllUsedCodes("other-user"); assert (usedCodes.length == 2); } } @@ -376,8 +409,8 @@ public void updateDeviceNameTest() throws Exception { assert (devices[1].deviceName.equals("new-device-name")); // Verify that TOTP verification still works: - Totp.verifyDevice(main, "user", devices[0].deviceName, Totp.generateTotpCode(main, devices[0])); - Totp.verifyDevice(main, "user", devices[0].deviceName, Totp.generateTotpCode(main, devices[1])); + Totp.verifyDevice(main, "user", devices[0].deviceName, generateTotpCode(main, devices[0])); + Totp.verifyDevice(main, "user", devices[0].deviceName, generateTotpCode(main, devices[1])); // Try update device name to an already existing device name: assertThrows(DeviceAlreadyExistsException.class, @@ -427,23 +460,23 @@ public void deleteExpiredTokensCronTest() throws Exception { TOTPDevice device = Totp.registerDevice(main, "user", "device", 0, 1); // Add codes: - Totp.verifyCode(main, "user", Totp.generateTotpCode(main, device), true); + Totp.verifyCode(main, "user", generateTotpCode(main, device), true); assertThrows(InvalidTotpException.class, () -> Totp.verifyCode(main, "user", "invalid-code", true)); TOTPStorage storage = StorageLayer.getTOTPStorage(process.getProcess()); // Verify that the codes have been added: - TOTPUsedCode[] usedCodes = storage.getNonExpiredUsedCodes("user"); + TOTPUsedCode[] usedCodes = storage.getAllUsedCodes("user"); assert (usedCodes.length == 2); - // Wait for 1 second to make sure that the valid codes expire + // Wait for 2 second to make sure that the valid codes expire // (and crons deletes the valid ones since they are expired) - Thread.sleep(1000); + Thread.sleep(1000 * 2); - usedCodes = storage.getNonExpiredUsedCodes("user"); + usedCodes = storage.getAllUsedCodes("user"); assert (usedCodes.length == 1); - // Invalid code will still remain because their expiration time is 5 minutes + // Invalid code will still remain because their expiry time is 30 minutes assert usedCodes[0].code.equals("invalid-code"); } } diff --git a/src/test/java/io/supertokens/test/totp/TOTPStorageTest.java b/src/test/java/io/supertokens/test/totp/TOTPStorageTest.java index c9ded66fa..72098962d 100644 --- a/src/test/java/io/supertokens/test/totp/TOTPStorageTest.java +++ b/src/test/java/io/supertokens/test/totp/TOTPStorageTest.java @@ -11,6 +11,7 @@ import io.supertokens.test.Utils; import io.supertokens.ProcessState; +import io.supertokens.cronjobs.deleteExpiredTotpTokens.DeleteExpiredTotpTokens; import io.supertokens.pluginInterface.STORAGE_TYPE; import io.supertokens.storageLayer.StorageLayer; import io.supertokens.test.TestingProcessManager; @@ -175,7 +176,7 @@ public void removeUser_TransactionTests() throws Exception { TOTPDevice[] storedDevices = storage.getDevices("user"); assert (storedDevices.length == 2); - TOTPUsedCode[] storedUsedCodes = storage.getNonExpiredUsedCodes("user"); + TOTPUsedCode[] storedUsedCodes = storage.getAllUsedCodes("user"); assert (storedUsedCodes.length == 2); storage.startTransaction(con -> { @@ -187,7 +188,7 @@ public void removeUser_TransactionTests() throws Exception { storedDevices = storage.getDevices("user"); assert (storedDevices.length == 0); - storedUsedCodes = storage.getNonExpiredUsedCodes("user"); + storedUsedCodes = storage.getAllUsedCodes("user"); assert (storedUsedCodes.length == 0); } @@ -304,7 +305,7 @@ public void insertUsedCodeTest() throws Exception { storage.createDevice(device); storage.insertUsedCode(code); - TOTPUsedCode[] usedCodes = storage.getNonExpiredUsedCodes("user"); + TOTPUsedCode[] usedCodes = storage.getAllUsedCodes("user"); assert (usedCodes.length == 1); assert usedCodes[0].equals(code); @@ -332,11 +333,11 @@ public void insertUsedCodeTest() throws Exception { } @Test - public void getNonExpiredUsedCodesTest() throws Exception { + public void getAllUsedCodesTest() throws Exception { TestSetupResult result = initSteps(); TOTPSQLStorage storage = result.storage; - TOTPUsedCode[] usedCodes = storage.getNonExpiredUsedCodes("non-existent-user"); + TOTPUsedCode[] usedCodes = storage.getAllUsedCodes("non-existent-user"); assert (usedCodes.length == 0); long now = System.currentTimeMillis(); @@ -359,7 +360,12 @@ public void getNonExpiredUsedCodesTest() throws Exception { storage.insertUsedCode(validCode2); storage.insertUsedCode(validCode3); - usedCodes = storage.getNonExpiredUsedCodes("user"); + usedCodes = storage.getAllUsedCodes("user"); + assert (usedCodes.length == 6); + + DeleteExpiredTotpTokens.getInstance(result.process.getProcess()).run(); + + usedCodes = storage.getAllUsedCodes("user"); assert (usedCodes.length == 4); // expired codes shouldn't be returned assert (usedCodes[0].equals(validCode3)); // order is DESC by created time (now + X) assert (usedCodes[1].equals(validCode2)); @@ -388,7 +394,7 @@ public void removeExpiredCodesTest() throws Exception { storage.insertUsedCode(validCodeToExpire); storage.insertUsedCode(invalidCodeToExpire); - TOTPUsedCode[] usedCodes = storage.getNonExpiredUsedCodes("user"); + TOTPUsedCode[] usedCodes = storage.getAllUsedCodes("user"); assert (usedCodes.length == 4); // After 500ms seconds pass: @@ -396,7 +402,7 @@ public void removeExpiredCodesTest() throws Exception { storage.removeExpiredCodes(); - usedCodes = storage.getNonExpiredUsedCodes("user"); + usedCodes = storage.getAllUsedCodes("user"); assert (usedCodes.length == 2); assert (usedCodes[0].equals(validCodeToLive)); assert (usedCodes[1].equals(invalidCodeToLive)); @@ -421,15 +427,15 @@ public void deleteAllDataForUserTest() throws Exception { storage.insertUsedCode(invalidCode); TOTPDevice[] storedDevices = storage.getDevices("user"); - TOTPUsedCode[] usedCodes = storage.getNonExpiredUsedCodes("user"); + TOTPUsedCode[] usedCodes = storage.getAllUsedCodes("user"); assert (storedDevices.length == 2); assert (usedCodes.length == 2); - storage.deleteAllDataForUser("user"); + storage.deleteAllTotpDataForUser("user"); storedDevices = storage.getDevices("user"); - usedCodes = storage.getNonExpiredUsedCodes("user"); + usedCodes = storage.getAllUsedCodes("user"); assert (storedDevices.length == 0); assert (usedCodes.length == 0); From 4ac4760b74ed106d5a44d4bf43d4bc6a181da0f1 Mon Sep 17 00:00:00 2001 From: KShivendu Date: Tue, 28 Feb 2023 19:14:14 +0530 Subject: [PATCH 21/42] refactor: Remove redundant method deleteAllTotpDataForUser --- .../io/supertokens/authRecipe/AuthRecipe.java | 12 ++++++- .../java/io/supertokens/inmemorydb/Start.java | 18 ---------- .../test/totp/TOTPStorageTest.java | 33 ------------------- 3 files changed, 11 insertions(+), 52 deletions(-) diff --git a/src/main/java/io/supertokens/authRecipe/AuthRecipe.java b/src/main/java/io/supertokens/authRecipe/AuthRecipe.java index d3e3c66e6..2f08c3061 100644 --- a/src/main/java/io/supertokens/authRecipe/AuthRecipe.java +++ b/src/main/java/io/supertokens/authRecipe/AuthRecipe.java @@ -20,6 +20,7 @@ import io.supertokens.pluginInterface.RECIPE_ID; import io.supertokens.pluginInterface.authRecipe.AuthRecipeUserInfo; import io.supertokens.pluginInterface.exceptions.StorageQueryException; +import io.supertokens.pluginInterface.exceptions.StorageTransactionLogicException; import io.supertokens.pluginInterface.useridmapping.UserIdMapping; import io.supertokens.storageLayer.StorageLayer; import io.supertokens.useridmapping.UserIdType; @@ -103,7 +104,16 @@ private static void deleteNonAuthRecipeUser(Main main, String userId) throws Sto StorageLayer.getSessionStorage(main).deleteSessionsOfUser(userId); StorageLayer.getEmailVerificationStorage(main).deleteEmailVerificationUserInfo(userId); StorageLayer.getUserRolesStorage(main).deleteAllRolesForUser(userId); - StorageLayer.getTOTPStorage(main).deleteAllTotpDataForUser(userId); + try { + StorageLayer.getTOTPStorage(main).startTransaction(con -> { + StorageLayer.getTOTPStorage(main).removeUser_Transaction(con, userId); + return null; + }); + } catch (StorageTransactionLogicException e) { + if (e.actualException instanceof StorageQueryException) { + throw (StorageQueryException) e.actualException; + } + } } private static void deleteAuthRecipeUser(Main main, String userId) throws StorageQueryException { diff --git a/src/main/java/io/supertokens/inmemorydb/Start.java b/src/main/java/io/supertokens/inmemorydb/Start.java index df1bdbb86..fa050aa96 100644 --- a/src/main/java/io/supertokens/inmemorydb/Start.java +++ b/src/main/java/io/supertokens/inmemorydb/Start.java @@ -1758,22 +1758,4 @@ public int removeExpiredCodes() throw new StorageQueryException(e); } } - - @Override - public void deleteAllTotpDataForUser(String userId) throws StorageQueryException { - // TODO: Logically this is corrrect. But is this the right way to do it? - try { - this.startTransaction(con -> { - Connection sqlCon = (Connection) con.getConnection(); - try { - TOTPQueries.removeUser_Transaction(this, sqlCon, userId); - } catch (SQLException e) { - throw new StorageTransactionLogicException(e); - } - return null; - }); - } catch (StorageTransactionLogicException e) { - throw new StorageQueryException(e.actualException); - } - } } diff --git a/src/test/java/io/supertokens/test/totp/TOTPStorageTest.java b/src/test/java/io/supertokens/test/totp/TOTPStorageTest.java index 72098962d..7fd375e83 100644 --- a/src/test/java/io/supertokens/test/totp/TOTPStorageTest.java +++ b/src/test/java/io/supertokens/test/totp/TOTPStorageTest.java @@ -407,37 +407,4 @@ public void removeExpiredCodesTest() throws Exception { assert (usedCodes[0].equals(validCodeToLive)); assert (usedCodes[1].equals(invalidCodeToLive)); } - - @Test - public void deleteAllDataForUserTest() throws Exception { - TestSetupResult result = initSteps(); - TOTPSQLStorage storage = result.storage; - - long now = System.currentTimeMillis(); - long nextDay = now + 1000 * 60 * 60 * 24; // 1 day from now - - TOTPDevice device1 = new TOTPDevice("user", "d1", "secretKey", 30, 1, false); - TOTPDevice device2 = new TOTPDevice("user", "d2", "secretKey", 30, 1, false); - TOTPUsedCode validCode = new TOTPUsedCode("user", "d1-valid", true, nextDay, now); - TOTPUsedCode invalidCode = new TOTPUsedCode("user", "invalid-code", false, nextDay, now); - - storage.createDevice(device1); - storage.createDevice(device2); - storage.insertUsedCode(validCode); - storage.insertUsedCode(invalidCode); - - TOTPDevice[] storedDevices = storage.getDevices("user"); - TOTPUsedCode[] usedCodes = storage.getAllUsedCodes("user"); - - assert (storedDevices.length == 2); - assert (usedCodes.length == 2); - - storage.deleteAllTotpDataForUser("user"); - - storedDevices = storage.getDevices("user"); - usedCodes = storage.getAllUsedCodes("user"); - - assert (storedDevices.length == 0); - assert (usedCodes.length == 0); - } } From 207513169d90b249ff06f09f608afc1ae225a360 Mon Sep 17 00:00:00 2001 From: KShivendu Date: Wed, 1 Mar 2023 14:45:16 +0530 Subject: [PATCH 22/42] feat: Add APIs for TOTP recipe - Add APIs with input validation and error handling - Refactor LimitReachedException to store Retry-After header --- src/main/java/io/supertokens/totp/Totp.java | 8 +- .../exceptions/LimitReachedException.java | 6 ++ .../io/supertokens/webserver/Webserver.java | 12 +++ .../api/totp/CreateTotpDeviceAPI.java | 93 +++++++++++++++++++ .../webserver/api/totp/GetTotpDevicesAPI.java | 72 ++++++++++++++ .../api/totp/RemoveTotpDeviceAPI.java | 72 ++++++++++++++ .../api/totp/UpdateTotpDeviceNameAPI.java | 80 ++++++++++++++++ .../webserver/api/totp/VerifyTotpAPI.java | 83 +++++++++++++++++ .../api/totp/VerifyTotpDeviceAPI.java | 87 +++++++++++++++++ 9 files changed, 506 insertions(+), 7 deletions(-) create mode 100644 src/main/java/io/supertokens/webserver/api/totp/CreateTotpDeviceAPI.java create mode 100644 src/main/java/io/supertokens/webserver/api/totp/GetTotpDevicesAPI.java create mode 100644 src/main/java/io/supertokens/webserver/api/totp/RemoveTotpDeviceAPI.java create mode 100644 src/main/java/io/supertokens/webserver/api/totp/UpdateTotpDeviceNameAPI.java create mode 100644 src/main/java/io/supertokens/webserver/api/totp/VerifyTotpAPI.java create mode 100644 src/main/java/io/supertokens/webserver/api/totp/VerifyTotpDeviceAPI.java diff --git a/src/main/java/io/supertokens/totp/Totp.java b/src/main/java/io/supertokens/totp/Totp.java index a0acd2863..a9684f36c 100644 --- a/src/main/java/io/supertokens/totp/Totp.java +++ b/src/main/java/io/supertokens/totp/Totp.java @@ -77,12 +77,6 @@ public static TOTPDevice registerDevice(Main main, String userId, String deviceN TOTPSQLStorage totpStorage = StorageLayer.getTOTPStorage(main); - // Assert that period > 0 (not == 0 because it would lead to a divide by 0 - // error) - // Assert that period <= 60. Otherwise, it is a security risk. Actually, - // anything > 30 is bad. - // and skew >= 0 and skew <= 2. Otherwise, it is a security risk. - // TODO: There should be a hard limit on number of devices per user // 8 devices per user should be enough. Otherwise, it is a security risk. @@ -120,7 +114,7 @@ private static void checkAndStoreCode(Main main, TOTPStorage totpStorage, String if (now - latestInvalidCodeCreatedTime < rateLimitResetTimeInMs) { // Less than rateLimitResetTimeInMs (default = 15 mins) time has elasped since // the last invalid code: - throw new LimitReachedException(); + throw new LimitReachedException(rateLimitResetTimeInMs / 1000); // If we insert the used code here, then it will further delay the user from // being able to login. So not inserting it here. diff --git a/src/main/java/io/supertokens/totp/exceptions/LimitReachedException.java b/src/main/java/io/supertokens/totp/exceptions/LimitReachedException.java index 0da203afe..1cf9772fe 100644 --- a/src/main/java/io/supertokens/totp/exceptions/LimitReachedException.java +++ b/src/main/java/io/supertokens/totp/exceptions/LimitReachedException.java @@ -2,4 +2,10 @@ public class LimitReachedException extends Exception { + public int retryInSeconds; + + public LimitReachedException(int retryInSeconds) { + super("Retry in " + retryInSeconds + " seconds"); + this.retryInSeconds = retryInSeconds; + } } diff --git a/src/main/java/io/supertokens/webserver/Webserver.java b/src/main/java/io/supertokens/webserver/Webserver.java index 4d6f9a4ba..c941eaa9b 100644 --- a/src/main/java/io/supertokens/webserver/Webserver.java +++ b/src/main/java/io/supertokens/webserver/Webserver.java @@ -39,6 +39,12 @@ import io.supertokens.webserver.api.session.*; import io.supertokens.webserver.api.thirdparty.GetUsersByEmailAPI; import io.supertokens.webserver.api.thirdparty.SignInUpAPI; +import io.supertokens.webserver.api.totp.CreateTotpDeviceAPI; +import io.supertokens.webserver.api.totp.GetTotpDevicesAPI; +import io.supertokens.webserver.api.totp.RemoveTotpDeviceAPI; +import io.supertokens.webserver.api.totp.UpdateTotpDeviceNameAPI; +import io.supertokens.webserver.api.totp.VerifyTotpAPI; +import io.supertokens.webserver.api.totp.VerifyTotpDeviceAPI; import io.supertokens.webserver.api.useridmapping.RemoveUserIdMappingAPI; import io.supertokens.webserver.api.useridmapping.UpdateExternalUserIdInfoAPI; import io.supertokens.webserver.api.useridmapping.UserIdMappingAPI; @@ -247,6 +253,12 @@ private void setupRoutes() throws Exception { addAPI(new GetRolesAPI(main)); addAPI(new UserIdMappingAPI(main)); addAPI(new RemoveUserIdMappingAPI(main)); + addAPI(new CreateTotpDeviceAPI(main)); + addAPI(new VerifyTotpDeviceAPI(main)); + addAPI(new VerifyTotpAPI(main)); + addAPI(new RemoveTotpDeviceAPI(main)); + addAPI(new UpdateTotpDeviceNameAPI(main)); + addAPI(new GetTotpDevicesAPI(main)); addAPI(new UpdateExternalUserIdInfoAPI(main)); addAPI(new ImportUserWithPasswordHashAPI(main)); addAPI(new LicenseKeyAPI(main)); diff --git a/src/main/java/io/supertokens/webserver/api/totp/CreateTotpDeviceAPI.java b/src/main/java/io/supertokens/webserver/api/totp/CreateTotpDeviceAPI.java new file mode 100644 index 000000000..249c076b0 --- /dev/null +++ b/src/main/java/io/supertokens/webserver/api/totp/CreateTotpDeviceAPI.java @@ -0,0 +1,93 @@ +package io.supertokens.webserver.api.totp; + +import java.io.IOException; +import java.security.NoSuchAlgorithmException; + +import com.google.gson.JsonObject; + +import io.supertokens.Main; +import io.supertokens.pluginInterface.RECIPE_ID; +import io.supertokens.pluginInterface.exceptions.StorageQueryException; +import io.supertokens.pluginInterface.totp.TOTPDevice; +import io.supertokens.pluginInterface.totp.exception.DeviceAlreadyExistsException; +import io.supertokens.totp.Totp; +import io.supertokens.webserver.InputParser; +import io.supertokens.webserver.WebserverAPI; +import jakarta.servlet.ServletException; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; + +public class CreateTotpDeviceAPI extends WebserverAPI { + private static final long serialVersionUID = -4641988458637882374L; + + public CreateTotpDeviceAPI(Main main) { + super(main, RECIPE_ID.TOTP.toString()); + } + + @Override + public String getPath() { + return "/recipe/totp/device"; + } + + @Override + protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws IOException, ServletException { + JsonObject input = InputParser.parseJsonObjectOrThrowError(req); + + String userId = null; + String deviceName = null; + Integer skew = null; + Integer period = null; + + // TODO: Should we also allow the user to change the hashing algo and totp + // length (6-8) since we are already allowing them to change the period and skew + // which are advanced options anyways? + + if (input.has("userId")) { + userId = InputParser.parseStringOrThrowError(input, "userId", false); + } + if (input.has("deviceName")) { + deviceName = InputParser.parseStringOrThrowError(input, "deviceName", false); + } + if (input.has("skew")) { + // FIXME: No function to parse integer: + skew = InputParser.parseLongOrThrowError(input, "skew", false).intValue(); + } + if (input.has("period")) { + // FIXME: No function to parse integer: + period = InputParser.parseLongOrThrowError(input, "period", false).intValue(); + } + + if (userId.isEmpty()) { + throw new ServletException(new IllegalArgumentException("userId cannot be empty")); + } + if (deviceName.isEmpty()) { + throw new ServletException(new IllegalArgumentException("deviceName cannot be empty")); + } + if (skew < 0) { + throw new ServletException(new IllegalArgumentException("skew must be >= 0")); + } + if (period <= 0) { + throw new ServletException(new IllegalArgumentException("period must be > 0")); + } + + // Should we do these as well? + // - Assert period <= 60. Otherwise, it is a security risk. > 30 is also bad. + // - Assert skew <= 2. Otherwise, it is a security risk. + + JsonObject result = new JsonObject(); + + try { + TOTPDevice device = Totp.registerDevice(main, userId, deviceName, skew, period); + + result.addProperty("status", "OK"); + result.addProperty("secret", device.secretKey); + super.sendJsonResponse(200, result, resp); + } catch (DeviceAlreadyExistsException e) { + result.addProperty("status", "DEVICE_ALREADY_EXISTS_ERROR"); + super.sendJsonResponse(200, result, resp); + } catch (StorageQueryException | NoSuchAlgorithmException e) { + throw new ServletException(e); + } + } + +} diff --git a/src/main/java/io/supertokens/webserver/api/totp/GetTotpDevicesAPI.java b/src/main/java/io/supertokens/webserver/api/totp/GetTotpDevicesAPI.java new file mode 100644 index 000000000..5cc43e77c --- /dev/null +++ b/src/main/java/io/supertokens/webserver/api/totp/GetTotpDevicesAPI.java @@ -0,0 +1,72 @@ +package io.supertokens.webserver.api.totp; + +import java.io.IOException; + +import com.google.gson.JsonArray; +import com.google.gson.JsonObject; + +import io.supertokens.Main; +import io.supertokens.pluginInterface.RECIPE_ID; +import io.supertokens.pluginInterface.exceptions.StorageQueryException; +import io.supertokens.pluginInterface.totp.TOTPDevice; +import io.supertokens.pluginInterface.totp.exception.TotpNotEnabledException; +import io.supertokens.totp.Totp; +import io.supertokens.webserver.InputParser; +import io.supertokens.webserver.WebserverAPI; +import jakarta.servlet.ServletException; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; + +public class GetTotpDevicesAPI extends WebserverAPI { + private static final long serialVersionUID = -4641988458637882374L; + + public GetTotpDevicesAPI(Main main) { + super(main, RECIPE_ID.TOTP.toString()); + } + + @Override + public String getPath() { + return "/recipe/totp/device/list"; + } + + @Override + protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws IOException, ServletException { + JsonObject input = InputParser.parseJsonObjectOrThrowError(req); + + String userId = null; + + if (input.has("userId")) { + userId = InputParser.parseStringOrThrowError(input, "userId", false); + } + + if (userId.isEmpty()) { + throw new ServletException(new IllegalArgumentException("userId cannot be empty")); + } + + JsonObject result = new JsonObject(); + + try { + TOTPDevice[] devices = Totp.getDevices(main, userId); + JsonArray devicesArray = new JsonArray(); + + for (TOTPDevice d : devices) { + JsonObject item = new JsonObject(); + item.addProperty("name", d.deviceName); + item.addProperty("period", d.period); + item.addProperty("skew", d.skew); + item.addProperty("verified", d.verified); + + devicesArray.add(item); + } + + result.addProperty("status", "OK"); + result.add("devices", devicesArray); + super.sendJsonResponse(200, result, resp); + } catch (TotpNotEnabledException e) { + result.addProperty("status", "TOTP_NOT_ENABLED_ERROR"); + super.sendJsonResponse(200, result, resp); + } catch (StorageQueryException e) { + throw new ServletException(e); + } + } +} diff --git a/src/main/java/io/supertokens/webserver/api/totp/RemoveTotpDeviceAPI.java b/src/main/java/io/supertokens/webserver/api/totp/RemoveTotpDeviceAPI.java new file mode 100644 index 000000000..bb349c561 --- /dev/null +++ b/src/main/java/io/supertokens/webserver/api/totp/RemoveTotpDeviceAPI.java @@ -0,0 +1,72 @@ +package io.supertokens.webserver.api.totp; + +import java.io.IOException; + +import com.google.gson.JsonObject; + +import io.supertokens.Main; +import io.supertokens.pluginInterface.RECIPE_ID; +import io.supertokens.pluginInterface.exceptions.StorageQueryException; +import io.supertokens.pluginInterface.exceptions.StorageTransactionLogicException; +import io.supertokens.pluginInterface.totp.exception.TotpNotEnabledException; +import io.supertokens.pluginInterface.totp.exception.UnknownDeviceException; +import io.supertokens.totp.Totp; +import io.supertokens.webserver.InputParser; +import io.supertokens.webserver.WebserverAPI; +import jakarta.servlet.ServletException; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; + +public class RemoveTotpDeviceAPI extends WebserverAPI { + private static final long serialVersionUID = -4641988458637882374L; + + public RemoveTotpDeviceAPI(Main main) { + super(main, RECIPE_ID.TOTP.toString()); + } + + @Override + public String getPath() { + return "/recipe/totp/device/remove"; + } + + @Override + protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws IOException, ServletException { + JsonObject input = InputParser.parseJsonObjectOrThrowError(req); + + String userId = null; + String deviceName = null; + + if (input.has("userId")) { + userId = InputParser.parseStringOrThrowError(input, "userId", false); + } + if (input.has("deviceName")) { + deviceName = InputParser.parseStringOrThrowError(input, "deviceName", false); + } + + if (userId.isEmpty()) { + throw new ServletException(new IllegalArgumentException("userId cannot be empty")); + } + if (deviceName.isEmpty()) { + throw new ServletException(new IllegalArgumentException("deviceName cannot be empty")); + } + + JsonObject result = new JsonObject(); + + try { + Totp.removeDevice(main, userId, deviceName); + + result.addProperty("status", "OK"); + result.addProperty("didDeviceExist", true); + super.sendJsonResponse(200, result, resp); + } catch (TotpNotEnabledException e) { + result.addProperty("status", "TOTP_NOT_ENABLED_ERROR"); + super.sendJsonResponse(200, result, resp); + } catch (UnknownDeviceException e) { + result.addProperty("status", "OK"); + result.addProperty("didDeviceExist", false); + super.sendJsonResponse(200, result, resp); + } catch (StorageQueryException | StorageTransactionLogicException e) { + throw new ServletException(e); + } + } +} diff --git a/src/main/java/io/supertokens/webserver/api/totp/UpdateTotpDeviceNameAPI.java b/src/main/java/io/supertokens/webserver/api/totp/UpdateTotpDeviceNameAPI.java new file mode 100644 index 000000000..891fc2f7a --- /dev/null +++ b/src/main/java/io/supertokens/webserver/api/totp/UpdateTotpDeviceNameAPI.java @@ -0,0 +1,80 @@ +package io.supertokens.webserver.api.totp; + +import java.io.IOException; + +import com.google.gson.JsonObject; + +import io.supertokens.Main; +import io.supertokens.pluginInterface.RECIPE_ID; +import io.supertokens.pluginInterface.exceptions.StorageQueryException; +import io.supertokens.pluginInterface.totp.exception.DeviceAlreadyExistsException; +import io.supertokens.pluginInterface.totp.exception.TotpNotEnabledException; +import io.supertokens.pluginInterface.totp.exception.UnknownDeviceException; +import io.supertokens.totp.Totp; +import io.supertokens.webserver.InputParser; +import io.supertokens.webserver.WebserverAPI; +import jakarta.servlet.ServletException; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; + +public class UpdateTotpDeviceNameAPI extends WebserverAPI { + private static final long serialVersionUID = -4641988458637882374L; + + public UpdateTotpDeviceNameAPI(Main main) { + super(main, RECIPE_ID.TOTP.toString()); + } + + @Override + public String getPath() { + return "/recipe/totp/device"; + } + + @Override + protected void doPut(HttpServletRequest req, HttpServletResponse resp) throws IOException, ServletException { + JsonObject input = InputParser.parseJsonObjectOrThrowError(req); + + String userId = null; + String existingDeviceName = null; + String newDeviceName = null; + + if (input.has("userId")) { + userId = InputParser.parseStringOrThrowError(input, "userId", false); + } + if (input.has("existingDeviceName")) { + existingDeviceName = InputParser.parseStringOrThrowError(input, "existingDeviceName", false); + } + if (input.has("newDeviceName")) { + newDeviceName = InputParser.parseStringOrThrowError(input, "newDeviceName", false); + } + + if (userId.isEmpty()) { + throw new ServletException(new IllegalArgumentException("userId cannot be empty")); + } + if (existingDeviceName.isEmpty()) { + throw new ServletException(new IllegalArgumentException("existingDeviceName cannot be empty")); + } + if (newDeviceName.isEmpty()) { + throw new ServletException(new IllegalArgumentException("newDeviceName cannot be empty")); + } + + JsonObject result = new JsonObject(); + + try { + Totp.updateDeviceName(main, userId, existingDeviceName, newDeviceName); + + result.addProperty("status", "OK"); + super.sendJsonResponse(200, result, resp); + } catch (TotpNotEnabledException e) { + result.addProperty("status", "TOTP_NOT_ENABLED_ERROR"); + super.sendJsonResponse(200, result, resp); + } catch (UnknownDeviceException e) { + result.addProperty("status", "UNKNOWN_DEVICE_ERROR"); + super.sendJsonResponse(200, result, resp); + } catch (DeviceAlreadyExistsException e) { + result.addProperty("status", "DEVICE_ALREADY_EXISTS_ERROR"); + super.sendJsonResponse(200, result, resp); + } catch (StorageQueryException e) { + throw new ServletException(e); + } + } +} diff --git a/src/main/java/io/supertokens/webserver/api/totp/VerifyTotpAPI.java b/src/main/java/io/supertokens/webserver/api/totp/VerifyTotpAPI.java new file mode 100644 index 000000000..570c023a9 --- /dev/null +++ b/src/main/java/io/supertokens/webserver/api/totp/VerifyTotpAPI.java @@ -0,0 +1,83 @@ +package io.supertokens.webserver.api.totp; + +import java.io.IOException; + +import com.google.gson.JsonObject; + +import io.supertokens.Main; +import io.supertokens.pluginInterface.RECIPE_ID; +import io.supertokens.pluginInterface.exceptions.StorageQueryException; +import io.supertokens.pluginInterface.totp.exception.TotpNotEnabledException; +import io.supertokens.totp.Totp; +import io.supertokens.totp.exceptions.InvalidTotpException; +import io.supertokens.totp.exceptions.LimitReachedException; +import io.supertokens.webserver.InputParser; +import io.supertokens.webserver.WebserverAPI; +import jakarta.servlet.ServletException; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; + +public class VerifyTotpAPI extends WebserverAPI { + private static final long serialVersionUID = -4641988458637882374L; + + public VerifyTotpAPI(Main main) { + super(main, RECIPE_ID.TOTP.toString()); + } + + @Override + public String getPath() { + return "/recipe/totp/device/verify"; + } + + @Override + protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws IOException, ServletException { + JsonObject input = InputParser.parseJsonObjectOrThrowError(req); + + String userId = null; + String totp = null; + Boolean allowUnverifiedDevices = null; + + if (input.has("userId")) { + userId = InputParser.parseStringOrThrowError(input, "userId", false); + } + if (input.has("totp")) { + totp = InputParser.parseStringOrThrowError(input, "totp", false); + if (totp.length() != 6) { + throw new ServletException(new IllegalArgumentException("totp must be 6 characters long")); + } + } + if (input.has("allowUnverifiedDevices")) { + allowUnverifiedDevices = InputParser.parseBooleanOrThrowError(input, "allowUnverifiedDevices", false); + } + + if (userId.isEmpty()) { + throw new ServletException(new IllegalArgumentException("userId cannot be empty")); + } + if (totp.length() != 6) { + throw new ServletException(new IllegalArgumentException("totp must be 6 characters long")); + } + // Already checked that allowUnverifiedDevices is not null. + + JsonObject result = new JsonObject(); + + try { + Totp.verifyCode(main, userId, totp, allowUnverifiedDevices); + + result.addProperty("status", "OK"); + super.sendJsonResponse(200, result, resp); + } catch (TotpNotEnabledException e) { + result.addProperty("status", "TOTP_NOT_ENABLED_ERROR"); + super.sendJsonResponse(200, result, resp); + } catch (InvalidTotpException e) { + result.addProperty("status", "INVALID_TOTP_ERROR"); + super.sendJsonResponse(200, result, resp); + } catch (LimitReachedException e) { + result.addProperty("status", "LIMIT_REACHED_ERROR"); + // Also return a retryAfter value: + resp.addHeader("Retry-After", Integer.toString(e.retryInSeconds)); + super.sendJsonResponse(429, result, resp); // 429 Too Many Requests + } catch (StorageQueryException e) { + throw new ServletException(e); + } + } +} diff --git a/src/main/java/io/supertokens/webserver/api/totp/VerifyTotpDeviceAPI.java b/src/main/java/io/supertokens/webserver/api/totp/VerifyTotpDeviceAPI.java new file mode 100644 index 000000000..4218f219a --- /dev/null +++ b/src/main/java/io/supertokens/webserver/api/totp/VerifyTotpDeviceAPI.java @@ -0,0 +1,87 @@ +package io.supertokens.webserver.api.totp; + +import java.io.IOException; + +import com.google.gson.JsonObject; + +import io.supertokens.Main; +import io.supertokens.pluginInterface.RECIPE_ID; +import io.supertokens.pluginInterface.exceptions.StorageQueryException; +import io.supertokens.pluginInterface.totp.exception.TotpNotEnabledException; +import io.supertokens.pluginInterface.totp.exception.UnknownDeviceException; +import io.supertokens.totp.Totp; +import io.supertokens.totp.exceptions.InvalidTotpException; +import io.supertokens.totp.exceptions.LimitReachedException; +import io.supertokens.webserver.InputParser; +import io.supertokens.webserver.WebserverAPI; +import jakarta.servlet.ServletException; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; + +public class VerifyTotpDeviceAPI extends WebserverAPI { + private static final long serialVersionUID = -4641988458637882374L; + + public VerifyTotpDeviceAPI(Main main) { + super(main, RECIPE_ID.TOTP.toString()); + } + + @Override + public String getPath() { + return "/recipe/totp/device/verify"; + } + + @Override + protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws IOException, ServletException { + JsonObject input = InputParser.parseJsonObjectOrThrowError(req); + + String userId = null; + String deviceName = null; + String totp = null; + + if (input.has("userId")) { + userId = InputParser.parseStringOrThrowError(input, "userId", false); + } + if (input.has("deviceName")) { + deviceName = InputParser.parseStringOrThrowError(input, "deviceName", false); + } + if (input.has("totp")) { + totp = InputParser.parseStringOrThrowError(input, "totp", false); + } + + if (userId.isEmpty()) { + throw new ServletException(new IllegalArgumentException("userId cannot be empty")); + } + if (deviceName.isEmpty()) { + throw new ServletException(new IllegalArgumentException("deviceName cannot be empty")); + } + if (totp.length() != 6) { + throw new ServletException(new IllegalArgumentException("totp must be 6 characters long")); + } + + JsonObject result = new JsonObject(); + + try { + boolean isNewlyVerified = Totp.verifyDevice(main, userId, deviceName, totp); + + result.addProperty("status", "OK"); + result.addProperty("wasAlreadyVerified", !isNewlyVerified); + super.sendJsonResponse(200, result, resp); + } catch (TotpNotEnabledException e) { + result.addProperty("status", "TOTP_NOT_ENABLED_ERROR"); + super.sendJsonResponse(200, result, resp); + } catch (UnknownDeviceException e) { + result.addProperty("status", "UNKNOWN_DEVICE_ERROR"); + super.sendJsonResponse(200, result, resp); + } catch (InvalidTotpException e) { + result.addProperty("status", "INVALID_TOTP_ERROR"); + super.sendJsonResponse(200, result, resp); + } catch (LimitReachedException e) { + result.addProperty("status", "LIMIT_REACHED_ERROR"); + // Also return a retryAfter value: + resp.addHeader("Retry-After", Integer.toString(e.retryInSeconds)); + super.sendJsonResponse(429, result, resp); // 429 (Too Many Requests) + } catch (StorageQueryException e) { + throw new ServletException(e); + } + } +} From 25ea1b5c35494642a39bd6fe8fce7e5870be00bc Mon Sep 17 00:00:00 2001 From: KShivendu Date: Thu, 2 Mar 2023 16:48:25 +0530 Subject: [PATCH 23/42] feat: Improve TOTP recipe - Remove unused code and improve var names - Use lock for getDevicesCount Txn - Clearly explain rate limiting logic - Add test for invalid totp core config - Merge create and update TOTP device API functions - Bubble up error from removeUserTxn on deleteUser --- .../io/supertokens/authRecipe/AuthRecipe.java | 19 +++---- .../DeleteExpiredTotpTokens.java | 2 - .../java/io/supertokens/inmemorydb/Start.java | 7 ++- .../inmemorydb/queries/TOTPQueries.java | 24 +------- src/main/java/io/supertokens/totp/Totp.java | 53 +++++++++++++----- .../io/supertokens/webserver/Webserver.java | 5 +- .../webserver/api/core/DeleteUserAPI.java | 3 +- ....java => CreateOrUpdateTotpDeviceAPI.java} | 55 ++++++++++++++++++- .../webserver/api/totp/VerifyTotpAPI.java | 2 +- .../io/supertokens/test/ConfigTest2_6.java | 46 +++++++++++++++- .../supertokens/test/totp/TOTPRecipeTest.java | 10 ++-- .../test/totp/TOTPStorageTest.java | 21 +++---- 12 files changed, 170 insertions(+), 77 deletions(-) rename src/main/java/io/supertokens/webserver/api/totp/{CreateTotpDeviceAPI.java => CreateOrUpdateTotpDeviceAPI.java} (59%) diff --git a/src/main/java/io/supertokens/authRecipe/AuthRecipe.java b/src/main/java/io/supertokens/authRecipe/AuthRecipe.java index 2f08c3061..a86ca123e 100644 --- a/src/main/java/io/supertokens/authRecipe/AuthRecipe.java +++ b/src/main/java/io/supertokens/authRecipe/AuthRecipe.java @@ -60,7 +60,7 @@ public static UserPaginationContainer getUsers(Main main, Integer limit, String return new UserPaginationContainer(resultUsers, nextPaginationToken); } - public static void deleteUser(Main main, String userId) throws StorageQueryException { + public static void deleteUser(Main main, String userId) throws StorageQueryException, StorageTransactionLogicException { // We clean up the user last so that if anything before that throws an error, then that will throw a 500 to the // developer. In this case, they expect that the user has not been deleted (which will be true). This is as // opposed to deleting the user first, in which case if something later throws an error, then the user has @@ -98,22 +98,17 @@ public static void deleteUser(Main main, String userId) throws StorageQueryExcep } - private static void deleteNonAuthRecipeUser(Main main, String userId) throws StorageQueryException { + private static void deleteNonAuthRecipeUser(Main main, String userId) throws StorageQueryException, StorageTransactionLogicException { // non auth recipe deletion StorageLayer.getUserMetadataStorage(main).deleteUserMetadata(userId); StorageLayer.getSessionStorage(main).deleteSessionsOfUser(userId); StorageLayer.getEmailVerificationStorage(main).deleteEmailVerificationUserInfo(userId); StorageLayer.getUserRolesStorage(main).deleteAllRolesForUser(userId); - try { - StorageLayer.getTOTPStorage(main).startTransaction(con -> { - StorageLayer.getTOTPStorage(main).removeUser_Transaction(con, userId); - return null; - }); - } catch (StorageTransactionLogicException e) { - if (e.actualException instanceof StorageQueryException) { - throw (StorageQueryException) e.actualException; - } - } + StorageLayer.getTOTPStorage(main).startTransaction(con -> { + StorageLayer.getTOTPStorage(main).removeUser_Transaction(con, userId); + StorageLayer.getTOTPStorage(main).commitTransaction(con); + return null; + }); } private static void deleteAuthRecipeUser(Main main, String userId) throws StorageQueryException { diff --git a/src/main/java/io/supertokens/cronjobs/deleteExpiredTotpTokens/DeleteExpiredTotpTokens.java b/src/main/java/io/supertokens/cronjobs/deleteExpiredTotpTokens/DeleteExpiredTotpTokens.java index 339367dfc..cfe02f81a 100644 --- a/src/main/java/io/supertokens/cronjobs/deleteExpiredTotpTokens/DeleteExpiredTotpTokens.java +++ b/src/main/java/io/supertokens/cronjobs/deleteExpiredTotpTokens/DeleteExpiredTotpTokens.java @@ -1,7 +1,5 @@ package io.supertokens.cronjobs.deleteExpiredTotpTokens; -import org.jetbrains.annotations.TestOnly; - import io.supertokens.Main; import io.supertokens.ResourceDistributor; import io.supertokens.pluginInterface.STORAGE_TYPE; diff --git a/src/main/java/io/supertokens/inmemorydb/Start.java b/src/main/java/io/supertokens/inmemorydb/Start.java index fa050aa96..426afb610 100644 --- a/src/main/java/io/supertokens/inmemorydb/Start.java +++ b/src/main/java/io/supertokens/inmemorydb/Start.java @@ -1650,8 +1650,9 @@ public void createDevice(TOTPDevice device) throws StorageQueryException, Device public void markDeviceAsVerified(String userId, String deviceName) throws StorageQueryException, UnknownDeviceException { try { - int updatedCount = TOTPQueries.markDeviceAsVerified(this, userId, deviceName); - if (updatedCount == 0) { + int matchedCount = TOTPQueries.markDeviceAsVerified(this, userId, deviceName); + if (matchedCount == 0) { + // Note matchedCount != updatedCount throw new UnknownDeviceException(); } return; // Device was marked as verified @@ -1740,7 +1741,7 @@ public void insertUsedCode(TOTPUsedCode usedCodeObj) } @Override - public TOTPUsedCode[] getAllUsedCodes(String userId) + public TOTPUsedCode[] getAllUsedCodesDescOrder(String userId) throws StorageQueryException { try { return TOTPQueries.getAllUsedCodesDescOrder(this, userId); diff --git a/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java b/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java index 224a31ba6..cda1e841d 100644 --- a/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java +++ b/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java @@ -8,6 +8,7 @@ import io.supertokens.inmemorydb.Start; import io.supertokens.inmemorydb.config.Config; +import io.supertokens.inmemorydb.ConnectionWithLocks; import io.supertokens.pluginInterface.RowMapper; import io.supertokens.pluginInterface.exceptions.StorageQueryException; import io.supertokens.pluginInterface.exceptions.StorageTransactionLogicException; @@ -169,6 +170,8 @@ public static TOTPDevice[] getDevices(Start start, String userId) public static int getDevicesCount_Transaction(Start start, Connection con, String userId) throws StorageQueryException, SQLException { + ((ConnectionWithLocks) con).lock(userId + Config.getConfig(start).getTotpUserDevicesTable()); + String QUERY = "SELECT COUNT(*) as count FROM " + Config.getConfig(start).getTotpUserDevicesTable() + " WHERE user_id = ?;"; @@ -246,27 +249,6 @@ public static int removeExpiredCodes(Start start) return update(start, QUERY, pst -> pst.setLong(1, System.currentTimeMillis())); } - private static int deleteAllDevices_Transaction(Start start, Connection con, String userId) - throws SQLException, StorageQueryException { - String QUERY = "DELETE FROM " + Config.getConfig(start).getTotpUserDevicesTable() - + " WHERE user_id = ?;"; - return update(con, QUERY, pst -> pst.setString(1, userId)); - } - - private static int deleteAllUsedCodes_Transaction(Start start, Connection con, String userId) - throws SQLException, StorageQueryException { - String QUERY = "DELETE FROM " + Config.getConfig(start).getTotpUsedCodesTable() - + " WHERE user_id = ?;"; - return update(con, QUERY, pst -> pst.setString(1, userId)); - } - - private static int deleteUser_Transaction(Start start, Connection con, String userId) - throws SQLException, StorageQueryException { - String QUERY = "DELETE FROM " + Config.getConfig(start).getTotpUsersTable() - + " WHERE user_id = ?;"; - return update(con, QUERY, pst -> pst.setString(1, userId)); - } - private static class TOTPDeviceRowMapper implements RowMapper { private static final TOTPDeviceRowMapper INSTANCE = new TOTPDeviceRowMapper(); diff --git a/src/main/java/io/supertokens/totp/Totp.java b/src/main/java/io/supertokens/totp/Totp.java index a9684f36c..65508eacf 100644 --- a/src/main/java/io/supertokens/totp/Totp.java +++ b/src/main/java/io/supertokens/totp/Totp.java @@ -31,20 +31,17 @@ import io.supertokens.storageLayer.StorageLayer; import io.supertokens.totp.exceptions.InvalidTotpException; import io.supertokens.totp.exceptions.LimitReachedException; -import jakarta.annotation.Nullable; public class Totp { private static String generateSecret() throws NoSuchAlgorithmException { - // TODO: We can actually allow the user to choose this algorithm. - // Changing it a would be rare but it can be a requirement for someone - // who's dealing with unconventional totp apps/devices. + // Reference: https://github.com/jchambers/java-otp final String TOTP_ALGORITHM = "HmacSHA1"; final KeyGenerator keyGenerator = KeyGenerator.getInstance(TOTP_ALGORITHM); keyGenerator.init(160); // 160 bits = 20 bytes // FIXME: Should return base32 or base16 - // Return base64 string of the secret key: + // Return base64 encoded string for the secret key: return Base64.getEncoder().encodeToString(keyGenerator.generateKey().getEncoded()); } @@ -61,10 +58,12 @@ private static boolean checkCode(TOTPDevice device, String code) { // Check if code is valid for any of the time periods in the skew: for (int i = -skew; i <= skew; i++) { try { + // TODO: Would there be any effect of timezones here? if (totp.generateOneTimePasswordString(key, Instant.now().plusSeconds(i * period)).equals(code)) { return true; } } catch (InvalidKeyException e) { + // This should never happen because we are always using a valid secretKey. return false; } } @@ -90,14 +89,38 @@ public static TOTPDevice registerDevice(Main main, String userId, String deviceN private static void checkAndStoreCode(Main main, TOTPStorage totpStorage, String userId, TOTPDevice[] devices, String code) throws InvalidTotpException, StorageQueryException, TotpNotEnabledException, LimitReachedException { - // Note that here we are fetching all the codes (expired/non-expired). - // otherwise, because of differences in expiry time of different codes, we might - // end up with a situation where the will be released from the rate limiting too - // early because of some invalid codes in the checking window expired OR it can - // also lead to random rate limiting because if some valid codes blip out of the - // checking window and if it leads to N contagious invalid codes, then the user - // will be rate limited for no reason. - TOTPUsedCode[] usedCodes = totpStorage.getAllUsedCodes(userId); + // Note that the TOTP cron runs every 1 hour, so all the expired tokens can stay + // in the db for at max 1 hour after expiry. + + // If we filter expired codes in rate limiting logic, then because of + // differences in expiry time of different codes, we might end up with a + // situation where: + // Case 1: users might get released from the rate limiting too early because of + // some invalid codes in the checking window were expired. + // Case 2: users might face random rate limiting because if some valid codes + // expire and if it leads to N contagious invalid + // codes, then the user will be rate limited for no reason. + + // For examaple, assume 0 means expired; 1 means non-expired: + // Also, assume that totp_max_attempts is 3, totp_rate_limit_cooldown_time is + // 15 minutes, and totp_invalid_code_expiry is 5 minutes. + + // Example for Case 1: + // User is rate limited and the used codes are like this: [1, 1, 0, 0, 0]. Now + // if 1st zero (invalid code) expires in 5 mins and we filter + // out expired codes, we'll end up [1, 1, 0, 0]. This doesn't contain 3 + // contiguous invalid codes, so the user will be released from rate limiting in + // 5 minutes instead of 15 minutes. + + // Example for Case 2: + // User has used codes like this: [0, 1, 0, 0]. + // The 1st one (valid code) will expire in merely 1.5 minutes (assuming skew = 2 + // and period = 30s). So now if we filter out expired codes, we'll see + // [0, 0, 0] and this contains 3 contagious invalid codes, so now the user will + // be rate limited for no reason. + + // That's why we need to fetch all the codes (expired + non-expired). + TOTPUsedCode[] usedCodes = totpStorage.getAllUsedCodesDescOrder(userId); // N represents # of invalid attempts that will trigger rate limiting: int N = Config.getConfig(main).getTotpMaxAttempts(); // (Default 5) @@ -119,14 +142,14 @@ private static void checkAndStoreCode(Main main, TOTPStorage totpStorage, String // If we insert the used code here, then it will further delay the user from // being able to login. So not inserting it here. - // Note: One edge case here is: user is rate limited, and then the + // Note: One edge case here is: User is rate limited, and then the // DeleteExpiredTotpTokens cron removes the latest invalid attempts // (because they have expired), and then user will again be able to // do extra login attempts (totp_max_attempts more times). // But rate limiting will kick in after totp_max_attempts number // disarming the brute force attack. // Furthermore, the cron running during cooldown of a user is somewhat rare. - // So this edge case is not a big deal. + // So this edge case is practically harmless. } } diff --git a/src/main/java/io/supertokens/webserver/Webserver.java b/src/main/java/io/supertokens/webserver/Webserver.java index c941eaa9b..c828673fc 100644 --- a/src/main/java/io/supertokens/webserver/Webserver.java +++ b/src/main/java/io/supertokens/webserver/Webserver.java @@ -39,7 +39,7 @@ import io.supertokens.webserver.api.session.*; import io.supertokens.webserver.api.thirdparty.GetUsersByEmailAPI; import io.supertokens.webserver.api.thirdparty.SignInUpAPI; -import io.supertokens.webserver.api.totp.CreateTotpDeviceAPI; +import io.supertokens.webserver.api.totp.CreateOrUpdateTotpDeviceAPI; import io.supertokens.webserver.api.totp.GetTotpDevicesAPI; import io.supertokens.webserver.api.totp.RemoveTotpDeviceAPI; import io.supertokens.webserver.api.totp.UpdateTotpDeviceNameAPI; @@ -253,11 +253,10 @@ private void setupRoutes() throws Exception { addAPI(new GetRolesAPI(main)); addAPI(new UserIdMappingAPI(main)); addAPI(new RemoveUserIdMappingAPI(main)); - addAPI(new CreateTotpDeviceAPI(main)); + addAPI(new CreateOrUpdateTotpDeviceAPI(main)); addAPI(new VerifyTotpDeviceAPI(main)); addAPI(new VerifyTotpAPI(main)); addAPI(new RemoveTotpDeviceAPI(main)); - addAPI(new UpdateTotpDeviceNameAPI(main)); addAPI(new GetTotpDevicesAPI(main)); addAPI(new UpdateExternalUserIdInfoAPI(main)); addAPI(new ImportUserWithPasswordHashAPI(main)); diff --git a/src/main/java/io/supertokens/webserver/api/core/DeleteUserAPI.java b/src/main/java/io/supertokens/webserver/api/core/DeleteUserAPI.java index bbb9ea916..26c1dc0d5 100644 --- a/src/main/java/io/supertokens/webserver/api/core/DeleteUserAPI.java +++ b/src/main/java/io/supertokens/webserver/api/core/DeleteUserAPI.java @@ -28,6 +28,7 @@ import io.supertokens.authRecipe.AuthRecipe; import io.supertokens.pluginInterface.RECIPE_ID; import io.supertokens.pluginInterface.exceptions.StorageQueryException; +import io.supertokens.pluginInterface.exceptions.StorageTransactionLogicException; import io.supertokens.webserver.InputParser; import io.supertokens.webserver.WebserverAPI; @@ -53,7 +54,7 @@ protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws I JsonObject result = new JsonObject(); result.addProperty("status", "OK"); super.sendJsonResponse(200, result, resp); - } catch (StorageQueryException e) { + } catch (StorageQueryException | StorageTransactionLogicException e) { throw new ServletException(e); } } diff --git a/src/main/java/io/supertokens/webserver/api/totp/CreateTotpDeviceAPI.java b/src/main/java/io/supertokens/webserver/api/totp/CreateOrUpdateTotpDeviceAPI.java similarity index 59% rename from src/main/java/io/supertokens/webserver/api/totp/CreateTotpDeviceAPI.java rename to src/main/java/io/supertokens/webserver/api/totp/CreateOrUpdateTotpDeviceAPI.java index 249c076b0..5f51c02bd 100644 --- a/src/main/java/io/supertokens/webserver/api/totp/CreateTotpDeviceAPI.java +++ b/src/main/java/io/supertokens/webserver/api/totp/CreateOrUpdateTotpDeviceAPI.java @@ -10,6 +10,8 @@ import io.supertokens.pluginInterface.exceptions.StorageQueryException; import io.supertokens.pluginInterface.totp.TOTPDevice; import io.supertokens.pluginInterface.totp.exception.DeviceAlreadyExistsException; +import io.supertokens.pluginInterface.totp.exception.TotpNotEnabledException; +import io.supertokens.pluginInterface.totp.exception.UnknownDeviceException; import io.supertokens.totp.Totp; import io.supertokens.webserver.InputParser; import io.supertokens.webserver.WebserverAPI; @@ -17,10 +19,10 @@ import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletResponse; -public class CreateTotpDeviceAPI extends WebserverAPI { +public class CreateOrUpdateTotpDeviceAPI extends WebserverAPI { private static final long serialVersionUID = -4641988458637882374L; - public CreateTotpDeviceAPI(Main main) { + public CreateOrUpdateTotpDeviceAPI(Main main) { super(main, RECIPE_ID.TOTP.toString()); } @@ -90,4 +92,53 @@ protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws I } } + @Override + protected void doPut(HttpServletRequest req, HttpServletResponse resp) throws IOException, ServletException { + JsonObject input = InputParser.parseJsonObjectOrThrowError(req); + + String userId = null; + String existingDeviceName = null; + String newDeviceName = null; + + if (input.has("userId")) { + userId = InputParser.parseStringOrThrowError(input, "userId", false); + } + if (input.has("existingDeviceName")) { + existingDeviceName = InputParser.parseStringOrThrowError(input, "existingDeviceName", false); + } + if (input.has("newDeviceName")) { + newDeviceName = InputParser.parseStringOrThrowError(input, "newDeviceName", false); + } + + if (userId.isEmpty()) { + throw new ServletException(new IllegalArgumentException("userId cannot be empty")); + } + if (existingDeviceName.isEmpty()) { + throw new ServletException(new IllegalArgumentException("existingDeviceName cannot be empty")); + } + if (newDeviceName.isEmpty()) { + throw new ServletException(new IllegalArgumentException("newDeviceName cannot be empty")); + } + + JsonObject result = new JsonObject(); + + try { + Totp.updateDeviceName(main, userId, existingDeviceName, newDeviceName); + + result.addProperty("status", "OK"); + super.sendJsonResponse(200, result, resp); + } catch (TotpNotEnabledException e) { + result.addProperty("status", "TOTP_NOT_ENABLED_ERROR"); + super.sendJsonResponse(200, result, resp); + } catch (UnknownDeviceException e) { + result.addProperty("status", "UNKNOWN_DEVICE_ERROR"); + super.sendJsonResponse(200, result, resp); + } catch (DeviceAlreadyExistsException e) { + result.addProperty("status", "DEVICE_ALREADY_EXISTS_ERROR"); + super.sendJsonResponse(200, result, resp); + } catch (StorageQueryException e) { + throw new ServletException(e); + } + } + } diff --git a/src/main/java/io/supertokens/webserver/api/totp/VerifyTotpAPI.java b/src/main/java/io/supertokens/webserver/api/totp/VerifyTotpAPI.java index 570c023a9..e1880251e 100644 --- a/src/main/java/io/supertokens/webserver/api/totp/VerifyTotpAPI.java +++ b/src/main/java/io/supertokens/webserver/api/totp/VerifyTotpAPI.java @@ -26,7 +26,7 @@ public VerifyTotpAPI(Main main) { @Override public String getPath() { - return "/recipe/totp/device/verify"; + return "/recipe/totp/verify"; } @Override diff --git a/src/test/java/io/supertokens/test/ConfigTest2_6.java b/src/test/java/io/supertokens/test/ConfigTest2_6.java index 50faca16d..0b869f38a 100644 --- a/src/test/java/io/supertokens/test/ConfigTest2_6.java +++ b/src/test/java/io/supertokens/test/ConfigTest2_6.java @@ -135,6 +135,49 @@ public void testThatInvalidConfigThrowRightError() throws Exception { } + @Test + public void testInvalidTotpConfigThrowsExpectedError() throws Exception { + String[] args = { "../" }; + + Utils.setValueInConfig("totp_max_attempts", "0"); + + TestingProcess process = TestingProcessManager.start(args); + + ProcessState.EventAndException e = process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.INIT_FAILURE); + assertNotNull(e); + assertEquals(e.exception.getMessage(), + "'totp_max_attempts' must be > 0"); + + process.kill(); + assertNotNull(process.checkOrWaitForEvent(PROCESS_STATE.STOPPED)); + + Utils.reset(); + + Utils.setValueInConfig("totp_rate_limit_cooldown_sec", "0"); + process = TestingProcessManager.start(args); + + e = process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.INIT_FAILURE); + assertNotNull(e); + assertEquals(e.exception.getMessage(), + "'totp_rate_limit_cooldown_sec' must be > 0"); + + process.kill(); + assertNotNull(process.checkOrWaitForEvent(PROCESS_STATE.STOPPED)); + + Utils.reset(); + + Utils.setValueInConfig("totp_invalid_code_expiry_sec", "0"); + process = TestingProcessManager.start(args); + + e = process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.INIT_FAILURE); + assertNotNull(e); + assertEquals(e.exception.getMessage(), + "'totp_invalid_code_expiry_sec' must be > 0"); + + process.kill(); + assertNotNull(process.checkOrWaitForEvent(PROCESS_STATE.STOPPED)); + } + private String getConfigFileLocation(Main main) { return new File(CLIOptions.get(main).getConfigFilePath() == null ? CLIOptions.get(main).getInstallationPath() + "config.yaml" @@ -220,10 +263,9 @@ private static void checkConfigValues(CoreConfig config, TestingProcess process, assertFalse("Config access token blacklisting did not match default", config.getAccessTokenBlacklisting()); assertEquals("Config refresh token validity did not match default", config.getRefreshTokenValidity(), 60 * 2400 * 60 * (long) 1000); - // TODO: Is this correct? assertEquals(5, config.getTotpMaxAttempts()); // 5 assertEquals(900, config.getTotpRateLimitCooldownTime()); // 15 minutes - assertEquals(5, config.getTotpInvalidCodeExpiryTime()); // 30 minutes + assertEquals(1800, config.getTotpInvalidCodeExpiryTime()); // 30 minutes assertEquals("Config info log path did not match default", config.getInfoLogPath(process.getProcess()), CLIOptions.get(process.getProcess()).getInstallationPath() + "logs/info.log"); diff --git a/src/test/java/io/supertokens/test/totp/TOTPRecipeTest.java b/src/test/java/io/supertokens/test/totp/TOTPRecipeTest.java index a20fcc1de..be0f4d0ac 100644 --- a/src/test/java/io/supertokens/test/totp/TOTPRecipeTest.java +++ b/src/test/java/io/supertokens/test/totp/TOTPRecipeTest.java @@ -351,7 +351,7 @@ public void removeDeviceTest() throws Exception { assert (devices.length == 1); // 1 device still remain so all codes should still be still there: - TOTPUsedCode[] usedCodes = storage.getAllUsedCodes("user"); + TOTPUsedCode[] usedCodes = storage.getAllUsedCodesDescOrder("user"); assert (usedCodes.length == 3); } @@ -371,14 +371,14 @@ public void removeDeviceTest() throws Exception { assertThrows(TotpNotEnabledException.class, () -> Totp.getDevices(main, "user")); // No device left so all codes of the user should be deleted: - TOTPUsedCode[] usedCodes = storage.getAllUsedCodes("user"); + TOTPUsedCode[] usedCodes = storage.getAllUsedCodesDescOrder("user"); assert (usedCodes.length == 0); // But for other users things should still be there: TOTPDevice[] otherUserDevices = Totp.getDevices(main, "other-user"); assert (otherUserDevices.length == 1); - usedCodes = storage.getAllUsedCodes("other-user"); + usedCodes = storage.getAllUsedCodesDescOrder("other-user"); assert (usedCodes.length == 2); } } @@ -467,14 +467,14 @@ public void deleteExpiredTokensCronTest() throws Exception { TOTPStorage storage = StorageLayer.getTOTPStorage(process.getProcess()); // Verify that the codes have been added: - TOTPUsedCode[] usedCodes = storage.getAllUsedCodes("user"); + TOTPUsedCode[] usedCodes = storage.getAllUsedCodesDescOrder("user"); assert (usedCodes.length == 2); // Wait for 2 second to make sure that the valid codes expire // (and crons deletes the valid ones since they are expired) Thread.sleep(1000 * 2); - usedCodes = storage.getAllUsedCodes("user"); + usedCodes = storage.getAllUsedCodesDescOrder("user"); assert (usedCodes.length == 1); // Invalid code will still remain because their expiry time is 30 minutes assert usedCodes[0].code.equals("invalid-code"); diff --git a/src/test/java/io/supertokens/test/totp/TOTPStorageTest.java b/src/test/java/io/supertokens/test/totp/TOTPStorageTest.java index 7fd375e83..0cec2634b 100644 --- a/src/test/java/io/supertokens/test/totp/TOTPStorageTest.java +++ b/src/test/java/io/supertokens/test/totp/TOTPStorageTest.java @@ -111,12 +111,13 @@ public void verifyDeviceTests() throws Exception { assert (storedDevices.length == 1); assert (storedDevices[0].verified); + // Try to verify the device again: + storage.markDeviceAsVerified("user", "device"); + // Try to verify a device that doesn't exist: assertThrows(UnknownDeviceException.class, () -> storage.markDeviceAsVerified("user", "non-existent-device")); } - // FIXME: Should write tests for other transaction functions as well. - @Test public void getDevicesCount_TransactionTests() throws Exception { TestSetupResult result = initSteps(); @@ -176,7 +177,7 @@ public void removeUser_TransactionTests() throws Exception { TOTPDevice[] storedDevices = storage.getDevices("user"); assert (storedDevices.length == 2); - TOTPUsedCode[] storedUsedCodes = storage.getAllUsedCodes("user"); + TOTPUsedCode[] storedUsedCodes = storage.getAllUsedCodesDescOrder("user"); assert (storedUsedCodes.length == 2); storage.startTransaction(con -> { @@ -188,7 +189,7 @@ public void removeUser_TransactionTests() throws Exception { storedDevices = storage.getDevices("user"); assert (storedDevices.length == 0); - storedUsedCodes = storage.getAllUsedCodes("user"); + storedUsedCodes = storage.getAllUsedCodesDescOrder("user"); assert (storedUsedCodes.length == 0); } @@ -305,7 +306,7 @@ public void insertUsedCodeTest() throws Exception { storage.createDevice(device); storage.insertUsedCode(code); - TOTPUsedCode[] usedCodes = storage.getAllUsedCodes("user"); + TOTPUsedCode[] usedCodes = storage.getAllUsedCodesDescOrder("user"); assert (usedCodes.length == 1); assert usedCodes[0].equals(code); @@ -337,7 +338,7 @@ public void getAllUsedCodesTest() throws Exception { TestSetupResult result = initSteps(); TOTPSQLStorage storage = result.storage; - TOTPUsedCode[] usedCodes = storage.getAllUsedCodes("non-existent-user"); + TOTPUsedCode[] usedCodes = storage.getAllUsedCodesDescOrder("non-existent-user"); assert (usedCodes.length == 0); long now = System.currentTimeMillis(); @@ -360,12 +361,12 @@ public void getAllUsedCodesTest() throws Exception { storage.insertUsedCode(validCode2); storage.insertUsedCode(validCode3); - usedCodes = storage.getAllUsedCodes("user"); + usedCodes = storage.getAllUsedCodesDescOrder("user"); assert (usedCodes.length == 6); DeleteExpiredTotpTokens.getInstance(result.process.getProcess()).run(); - usedCodes = storage.getAllUsedCodes("user"); + usedCodes = storage.getAllUsedCodesDescOrder("user"); assert (usedCodes.length == 4); // expired codes shouldn't be returned assert (usedCodes[0].equals(validCode3)); // order is DESC by created time (now + X) assert (usedCodes[1].equals(validCode2)); @@ -394,7 +395,7 @@ public void removeExpiredCodesTest() throws Exception { storage.insertUsedCode(validCodeToExpire); storage.insertUsedCode(invalidCodeToExpire); - TOTPUsedCode[] usedCodes = storage.getAllUsedCodes("user"); + TOTPUsedCode[] usedCodes = storage.getAllUsedCodesDescOrder("user"); assert (usedCodes.length == 4); // After 500ms seconds pass: @@ -402,7 +403,7 @@ public void removeExpiredCodesTest() throws Exception { storage.removeExpiredCodes(); - usedCodes = storage.getAllUsedCodes("user"); + usedCodes = storage.getAllUsedCodesDescOrder("user"); assert (usedCodes.length == 2); assert (usedCodes[0].equals(validCodeToLive)); assert (usedCodes[1].equals(invalidCodeToLive)); From 3aca9e4d2be92e14fd95125f063fb594368033a7 Mon Sep 17 00:00:00 2001 From: KShivendu Date: Thu, 2 Mar 2023 17:38:49 +0530 Subject: [PATCH 24/42] refactor: Remove created_time index from totp_used_codes table --- .../io/supertokens/inmemorydb/queries/GeneralQueries.java | 1 - .../java/io/supertokens/inmemorydb/queries/TOTPQueries.java | 5 ----- 2 files changed, 6 deletions(-) diff --git a/src/main/java/io/supertokens/inmemorydb/queries/GeneralQueries.java b/src/main/java/io/supertokens/inmemorydb/queries/GeneralQueries.java index d4a325a36..134e07ead 100644 --- a/src/main/java/io/supertokens/inmemorydb/queries/GeneralQueries.java +++ b/src/main/java/io/supertokens/inmemorydb/queries/GeneralQueries.java @@ -201,7 +201,6 @@ public static void createTablesIfNotExists(Start start, Main main) throws SQLExc update(start, TOTPQueries.getQueryToCreateUsedCodesTable(start), NO_OP_SETTER); // index: update(start, TOTPQueries.getQueryToCreateUsedCodesExpiryTimeIndex(start), NO_OP_SETTER); - update(start, TOTPQueries.getQueryToCreateUsedCodesCreatedTimeIndex(start), NO_OP_SETTER); } } diff --git a/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java b/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java index cda1e841d..6ab5eeabc 100644 --- a/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java +++ b/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java @@ -50,11 +50,6 @@ public static String getQueryToCreateUsedCodesExpiryTimeIndex(Start start) { + Config.getConfig(start).getTotpUsedCodesTable() + " (expiry_time_ms)"; } - public static String getQueryToCreateUsedCodesCreatedTimeIndex(Start start) { - return "CREATE INDEX IF NOT EXISTS totp_used_codes_created_time_ms_index ON " - + Config.getConfig(start).getTotpUsedCodesTable() + " (created_time_ms DESC)"; - } - private static int insertUser_Transaction(Start start, Connection con, String userId) throws SQLException, StorageQueryException { // Create user if not exists: From cf16b6c93c37b47febde2afe2d6193634bee36f8 Mon Sep 17 00:00:00 2001 From: KShivendu Date: Thu, 2 Mar 2023 19:27:04 +0530 Subject: [PATCH 25/42] refactor: Remove foreign key constraint emulation in TOTP We have now enabled foreign key constraints in inmemory db --- .../inmemorydb/queries/TOTPQueries.java | 22 ------------------- .../supertokens/test/totp/TOTPRecipeTest.java | 1 + 2 files changed, 1 insertion(+), 22 deletions(-) diff --git a/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java b/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java index 6ab5eeabc..91524fae8 100644 --- a/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java +++ b/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java @@ -121,18 +121,6 @@ public static int removeUser_Transaction(Start start, Connection con, String use + " WHERE user_id = ?;"; int removedUsersCount = update(con, QUERY, pst -> pst.setString(1, userId)); - // Delete all devices and used codes for this user: - // This step is required only for in-memory db. - // Other databases will automatically delete these when the user is - // deleted because of foreign key constraints. - String QUERY2 = "DELETE FROM " + Config.getConfig(start).getTotpUserDevicesTable() - + " WHERE user_id = ?;"; - update(con, QUERY2, pst -> pst.setString(1, userId)); - - String QUERY3 = "DELETE FROM " + Config.getConfig(start).getTotpUsedCodesTable() - + " WHERE user_id = ?;"; - update(con, QUERY3, pst -> pst.setString(1, userId)); - return removedUsersCount; } @@ -195,16 +183,6 @@ public static void insertUsedCode(Start start, TOTPUsedCode code) Connection sqlCon = (Connection) con.getConnection(); try { - // Check if user exists or not (if no device exists, user does not exist) - // NOTE: This step is required only for in-memory db. - int devicesCount = getDevicesCount_Transaction(start, sqlCon, code.userId); - if (devicesCount == 0) { - // no device left. transaction cannot be completed. - // foreign key constraint will fail. - throw new SQLException( - "[SQLITE_CONSTRAINT] Abort due to constraint violation (FOREIGN KEY constraint failed)"); - } - insertUsedCode_Transaction(start, sqlCon, code); sqlCon.commit(); } catch (SQLException e) { diff --git a/src/test/java/io/supertokens/test/totp/TOTPRecipeTest.java b/src/test/java/io/supertokens/test/totp/TOTPRecipeTest.java index be0f4d0ac..fadf9a2d6 100644 --- a/src/test/java/io/supertokens/test/totp/TOTPRecipeTest.java +++ b/src/test/java/io/supertokens/test/totp/TOTPRecipeTest.java @@ -255,6 +255,7 @@ public void rateLimitCooldownTest() throws Exception { @Test public void cronRemovesAllCodesDuringRateLimitTest() throws Exception { + // FIXME: This test is flaky because of time being involved. String[] args = { "../" }; TestingProcessManager.TestingProcess process = TestingProcessManager.start(args, false); From 8c18d03074f7a8d934268ae9c7325dc6c2ffe314 Mon Sep 17 00:00:00 2001 From: KShivendu Date: Tue, 7 Mar 2023 19:09:41 +0530 Subject: [PATCH 26/42] feat: Improve TOTP recipe - Set totp code column size to 8 in DB - Introduce UsedCodeAlreadyExistsException - Improve comments - Improve input parsing API layer - Only remove expired totp codes after cooldown - Add primary key to TOTP used codes table - Use max expiry instead of totp_invalid_code_expiry_time - Use BadRequestException in TOTP API layer --- .../io/supertokens/authRecipe/AuthRecipe.java | 36 +++++++---- .../io/supertokens/config/CoreConfig.java | 12 ---- .../DeleteExpiredTotpTokens.java | 10 ++- .../java/io/supertokens/inmemorydb/Start.java | 16 +++-- .../inmemorydb/queries/TOTPQueries.java | 8 ++- src/main/java/io/supertokens/totp/Totp.java | 38 +++++++++--- .../io/supertokens/webserver/InputParser.java | 24 +++++++ .../api/totp/CreateOrUpdateTotpDeviceAPI.java | 62 +++++-------------- .../webserver/api/totp/GetTotpDevicesAPI.java | 8 +-- .../api/totp/RemoveTotpDeviceAPI.java | 15 ++--- .../api/totp/UpdateTotpDeviceNameAPI.java | 22 ++----- .../webserver/api/totp/VerifyTotpAPI.java | 23 ++----- .../api/totp/VerifyTotpDeviceAPI.java | 22 ++----- 13 files changed, 142 insertions(+), 154 deletions(-) diff --git a/src/main/java/io/supertokens/authRecipe/AuthRecipe.java b/src/main/java/io/supertokens/authRecipe/AuthRecipe.java index a86ca123e..8f4734501 100644 --- a/src/main/java/io/supertokens/authRecipe/AuthRecipe.java +++ b/src/main/java/io/supertokens/authRecipe/AuthRecipe.java @@ -21,6 +21,7 @@ import io.supertokens.pluginInterface.authRecipe.AuthRecipeUserInfo; import io.supertokens.pluginInterface.exceptions.StorageQueryException; import io.supertokens.pluginInterface.exceptions.StorageTransactionLogicException; +import io.supertokens.pluginInterface.totp.sqlStorage.TOTPSQLStorage; import io.supertokens.pluginInterface.useridmapping.UserIdMapping; import io.supertokens.storageLayer.StorageLayer; import io.supertokens.useridmapping.UserIdType; @@ -60,22 +61,30 @@ public static UserPaginationContainer getUsers(Main main, Integer limit, String return new UserPaginationContainer(resultUsers, nextPaginationToken); } - public static void deleteUser(Main main, String userId) throws StorageQueryException, StorageTransactionLogicException { - // We clean up the user last so that if anything before that throws an error, then that will throw a 500 to the - // developer. In this case, they expect that the user has not been deleted (which will be true). This is as - // opposed to deleting the user first, in which case if something later throws an error, then the user has + public static void deleteUser(Main main, String userId) + throws StorageQueryException, StorageTransactionLogicException { + // We clean up the user last so that if anything before that throws an error, + // then that will throw a 500 to the + // developer. In this case, they expect that the user has not been deleted + // (which will be true). This is as + // opposed to deleting the user first, in which case if something later throws + // an error, then the user has // actually been deleted already (which is not expected by the dev) - // For things created after the intial cleanup and before finishing the operation: + // For things created after the intial cleanup and before finishing the + // operation: // - session: the session will expire anyway - // - email verification: email verification tokens can be created for any userId anyway + // - email verification: email verification tokens can be created for any userId + // anyway - // If userId mapping exists then delete entries with superTokensUserId from auth related tables and + // If userId mapping exists then delete entries with superTokensUserId from auth + // related tables and // externalUserid from non-auth tables UserIdMapping userIdMapping = io.supertokens.useridmapping.UserIdMapping.getUserIdMapping(main, userId, UserIdType.ANY); if (userIdMapping != null) { - // We check if the mapped externalId is another SuperTokens UserId, this could come up when migrating + // We check if the mapped externalId is another SuperTokens UserId, this could + // come up when migrating // recipes. // in reference to // https://docs.google.com/spreadsheets/d/17hYV32B0aDCeLnSxbZhfRN2Y9b0LC2xUF44vV88RNAA/edit?usp=sharing @@ -98,15 +107,18 @@ public static void deleteUser(Main main, String userId) throws StorageQueryExcep } - private static void deleteNonAuthRecipeUser(Main main, String userId) throws StorageQueryException, StorageTransactionLogicException { + private static void deleteNonAuthRecipeUser(Main main, String userId) + throws StorageQueryException, StorageTransactionLogicException { // non auth recipe deletion StorageLayer.getUserMetadataStorage(main).deleteUserMetadata(userId); StorageLayer.getSessionStorage(main).deleteSessionsOfUser(userId); StorageLayer.getEmailVerificationStorage(main).deleteEmailVerificationUserInfo(userId); StorageLayer.getUserRolesStorage(main).deleteAllRolesForUser(userId); - StorageLayer.getTOTPStorage(main).startTransaction(con -> { - StorageLayer.getTOTPStorage(main).removeUser_Transaction(con, userId); - StorageLayer.getTOTPStorage(main).commitTransaction(con); + + TOTPSQLStorage storage = StorageLayer.getTOTPStorage(main); + storage.startTransaction(con -> { + storage.removeUser_Transaction(con, userId); + storage.commitTransaction(con); return null; }); } diff --git a/src/main/java/io/supertokens/config/CoreConfig.java b/src/main/java/io/supertokens/config/CoreConfig.java index 383e0caab..e59690107 100644 --- a/src/main/java/io/supertokens/config/CoreConfig.java +++ b/src/main/java/io/supertokens/config/CoreConfig.java @@ -62,9 +62,6 @@ public class CoreConfig { @JsonProperty private int totp_rate_limit_cooldown_sec = 900; // in seconds (Default 15 mins) - @JsonProperty - private int totp_invalid_code_expiry_sec = 1800; // in seconds (Default 30 mins) - private final String logDefault = "asdkfahbdfk3kjHS"; @JsonProperty private String info_log_path = logDefault; @@ -289,11 +286,6 @@ public int getTotpRateLimitCooldownTime() { return totp_rate_limit_cooldown_sec; } - /** TOTP invalid code expiry time (in seconds) */ - public int getTotpInvalidCodeExpiryTime() { - return totp_invalid_code_expiry_sec; - } - public boolean isTelemetryDisabled() { return disable_telemetry; } @@ -420,10 +412,6 @@ void validateAndInitialise(Main main) throws IOException { throw new QuitProgramException("'totp_rate_limit_cooldown_sec' must be > 0"); } - if (totp_invalid_code_expiry_sec <= 0) { - throw new QuitProgramException("'totp_invalid_code_expiry_sec' must be > 0"); - } - if (max_server_pool_size <= 0) { throw new QuitProgramException("'max_server_pool_size' must be >= 1. The config file can be found here: " + getConfigFileLocation(main)); diff --git a/src/main/java/io/supertokens/cronjobs/deleteExpiredTotpTokens/DeleteExpiredTotpTokens.java b/src/main/java/io/supertokens/cronjobs/deleteExpiredTotpTokens/DeleteExpiredTotpTokens.java index cfe02f81a..78d422ad6 100644 --- a/src/main/java/io/supertokens/cronjobs/deleteExpiredTotpTokens/DeleteExpiredTotpTokens.java +++ b/src/main/java/io/supertokens/cronjobs/deleteExpiredTotpTokens/DeleteExpiredTotpTokens.java @@ -2,6 +2,7 @@ import io.supertokens.Main; import io.supertokens.ResourceDistributor; +import io.supertokens.config.Config; import io.supertokens.pluginInterface.STORAGE_TYPE; import io.supertokens.pluginInterface.totp.sqlStorage.TOTPSQLStorage; import io.supertokens.cronjobs.CronTask; @@ -33,7 +34,14 @@ protected void doTask() throws Exception { TOTPSQLStorage storage = StorageLayer.getTOTPStorage(this.main); - int deletedCount = storage.removeExpiredCodes(); + long rateLimitResetInMs = Config.getConfig(this.main).getTotpRateLimitCooldownTime(); + long expiredBefore = System.currentTimeMillis() - rateLimitResetInMs; + + // We will only remove expired codes that have been expired for longer + // than rate limiting duration. This ensures that this DB query + // doesn't delete totp codes that keep the rate limiting active for + // the expected cooldown duration. + int deletedCount = storage.removeExpiredCodes(expiredBefore); Logging.debug(this.main, "Cron DeleteExpiredTotpTokens deleted " + deletedCount + " expired TOTP codes"); } diff --git a/src/main/java/io/supertokens/inmemorydb/Start.java b/src/main/java/io/supertokens/inmemorydb/Start.java index 318006d2f..34ebb8f74 100644 --- a/src/main/java/io/supertokens/inmemorydb/Start.java +++ b/src/main/java/io/supertokens/inmemorydb/Start.java @@ -68,6 +68,7 @@ import io.supertokens.pluginInterface.totp.exception.DeviceAlreadyExistsException; import io.supertokens.pluginInterface.totp.exception.TotpNotEnabledException; import io.supertokens.pluginInterface.totp.exception.UnknownDeviceException; +import io.supertokens.pluginInterface.totp.exception.UsedCodeAlreadyExistsException; import io.supertokens.pluginInterface.totp.sqlStorage.TOTPSQLStorage; import io.supertokens.pluginInterface.useridmapping.exception.UnknownSuperTokensUserIdException; import io.supertokens.pluginInterface.useridmapping.exception.UserIdMappingAlreadyExistsException; @@ -1866,16 +1867,23 @@ public TOTPDevice[] getDevices(String userId) @Override public void insertUsedCode(TOTPUsedCode usedCodeObj) - throws StorageQueryException, TotpNotEnabledException { + throws StorageQueryException, TotpNotEnabledException, UsedCodeAlreadyExistsException { try { TOTPQueries.insertUsedCode(this, usedCodeObj); } catch (StorageTransactionLogicException e) { String message = e.actualException.getMessage(); + // No user/device exists for the given usedCodeObj.userId if (message .equals("[SQLITE_CONSTRAINT] Abort due to constraint violation (FOREIGN KEY constraint failed)")) { - // No user/device exists for the given usedCodeObj.userId throw new TotpNotEnabledException(); } + // Failed due to primary key on (userId, created_time) + if (message.equals("[SQLITE_CONSTRAINT] Abort due to constraint violation (UNIQUE constraint failed: " + + Config.getConfig(this).getTotpUsedCodesTable() + ".user_id, " + + Config.getConfig(this).getTotpUsedCodesTable() + ".created_time" + ")")) { + throw new UsedCodeAlreadyExistsException(); + } + throw new StorageQueryException(e.actualException); } } @@ -1891,10 +1899,10 @@ public TOTPUsedCode[] getAllUsedCodesDescOrder(String userId) } @Override - public int removeExpiredCodes() + public int removeExpiredCodes(long expiredBefore) throws StorageQueryException { try { - return TOTPQueries.removeExpiredCodes(this); + return TOTPQueries.removeExpiredCodes(this, expiredBefore); } catch (SQLException e) { throw new StorageQueryException(e); } diff --git a/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java b/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java index 91524fae8..ffb6eaa1b 100644 --- a/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java +++ b/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java @@ -38,9 +38,10 @@ public static String getQueryToCreateUserDevicesTable(Start start) { public static String getQueryToCreateUsedCodesTable(Start start) { return "CREATE TABLE IF NOT EXISTS " + Config.getConfig(start).getTotpUsedCodesTable() + " (" + "user_id VARCHAR(128) NOT NULL, " - + "code CHAR(6) NOT NULL," + "is_valid BOOLEAN NOT NULL," + + "code VARCHAR(8) NOT NULL," + "is_valid BOOLEAN NOT NULL," + "created_time_ms BIGINT UNSIGNED NOT NULL," + "expiry_time_ms BIGINT UNSIGNED NOT NULL," + + "PRIMARY KEY (user_id, created_time_ms)" + "FOREIGN KEY (user_id) REFERENCES " + Config.getConfig(start).getTotpUsersTable() + "(user_id) ON DELETE CASCADE);"; } @@ -214,12 +215,12 @@ public static TOTPUsedCode[] getAllUsedCodesDescOrder(Start start, String userId }); } - public static int removeExpiredCodes(Start start) + public static int removeExpiredCodes(Start start, long expiredBefore) throws StorageQueryException, SQLException { String QUERY = "DELETE FROM " + Config.getConfig(start).getTotpUsedCodesTable() + " WHERE expiry_time_ms < ?;"; - return update(start, QUERY, pst -> pst.setLong(1, System.currentTimeMillis())); + return update(start, QUERY, pst -> pst.setLong(1, expiredBefore)); } private static class TOTPDeviceRowMapper implements RowMapper { @@ -262,6 +263,7 @@ public TOTPUsedCode map(ResultSet result) throws SQLException { result.getBoolean("is_valid"), result.getLong("expiry_time_ms"), result.getLong("created_time_ms")); + // FIXME: Put created time first, then expiry time. } } } diff --git a/src/main/java/io/supertokens/totp/Totp.java b/src/main/java/io/supertokens/totp/Totp.java index 65508eacf..4506f7366 100644 --- a/src/main/java/io/supertokens/totp/Totp.java +++ b/src/main/java/io/supertokens/totp/Totp.java @@ -27,6 +27,7 @@ import io.supertokens.pluginInterface.totp.exception.DeviceAlreadyExistsException; import io.supertokens.pluginInterface.totp.exception.TotpNotEnabledException; import io.supertokens.pluginInterface.totp.exception.UnknownDeviceException; +import io.supertokens.pluginInterface.totp.exception.UsedCodeAlreadyExistsException; import io.supertokens.pluginInterface.totp.sqlStorage.TOTPSQLStorage; import io.supertokens.storageLayer.StorageLayer; import io.supertokens.totp.exceptions.InvalidTotpException; @@ -110,7 +111,7 @@ private static void checkAndStoreCode(Main main, TOTPStorage totpStorage, String // if 1st zero (invalid code) expires in 5 mins and we filter // out expired codes, we'll end up [1, 1, 0, 0]. This doesn't contain 3 // contiguous invalid codes, so the user will be released from rate limiting in - // 5 minutes instead of 15 minutes. + // 5 minutes instead of 15 minutes. // Example for Case 2: // User has used codes like this: [0, 1, 0, 0]. @@ -177,18 +178,32 @@ private static void checkAndStoreCode(Main main, TOTPStorage totpStorage, String if (usedCode.code.equals(code) && usedCode.isValid && usedCode.expiryTime > System.currentTimeMillis()) { isValid = false; - matchingDevice = null; + // We found a matching device but the code + // will be considered invalid here. } } } // Insert the code into the list of used codes: - long now = System.currentTimeMillis(); - int invalidCodeExpirySec = Config.getConfig(main).getTotpInvalidCodeExpiryTime(); // (Default 30 mins) - int expireInSec = isValid ? matchingDevice.period * (2 * matchingDevice.skew + 1) : invalidCodeExpirySec; - TOTPUsedCode newCode = new TOTPUsedCode(userId, code, isValid, now + 1000 * expireInSec, now); - totpStorage.insertUsedCode(newCode); + // If device is found, calculate used code expiry time for that device (based on + // its period and skew). Otherwise, use the max used code expiry time of all the + // devices. + int maxUsedCodeExpiry = Arrays.stream(devices).mapToInt(device -> device.period * (2 * device.skew + 1)).max() + .orElse(0); + int expireInSec = (matchingDevice != null) ? matchingDevice.period * (2 * matchingDevice.skew + 1) + : maxUsedCodeExpiry; + + while (true) { + long now = System.currentTimeMillis(); + TOTPUsedCode newCode = new TOTPUsedCode(userId, code, isValid, now + 1000 * expireInSec, now); + try { + totpStorage.insertUsedCode(newCode); + break; + } catch (UsedCodeAlreadyExistsException e) { + continue; // Try again + } + } if (!isValid) { throw new InvalidTotpException(); @@ -204,6 +219,10 @@ public static boolean verifyDevice(Main main, String userId, String deviceName, TOTPSQLStorage totpStorage = StorageLayer.getTOTPStorage(main); TOTPDevice matchingDevice = null; + // Here one race condition is that the same device + // is to be verified twice in parallel. In that case, + // both the API calls will return true, but that's okay. + // Check if the user has any devices: TOTPDevice[] devices = totpStorage.getDevices(userId); if (devices.length == 0) { @@ -260,8 +279,9 @@ public static void removeDevice(Main main, String userId, String deviceName) try { storage.startTransaction(con -> { int deletedCount = storage.deleteDevice_Transaction(con, userId, deviceName); - if (deletedCount == 0) + if (deletedCount == 0) { throw new StorageTransactionLogicException(new UnknownDeviceException()); + } // Some device(s) were deleted. Check if user has any other device left: int devicesCount = storage.getDevicesCount_Transaction(con, userId); @@ -283,7 +303,7 @@ public static void removeDevice(Main main, String userId, String deviceName) } } - throw e; + throw new StorageQueryException(e.actualException); } } diff --git a/src/main/java/io/supertokens/webserver/InputParser.java b/src/main/java/io/supertokens/webserver/InputParser.java index 23be846bc..6581b4c9f 100644 --- a/src/main/java/io/supertokens/webserver/InputParser.java +++ b/src/main/java/io/supertokens/webserver/InputParser.java @@ -197,4 +197,28 @@ public static Long parseLongOrThrowError(JsonObject element, String fieldName, b } } + + + public static Integer parseIntOrThrowError(JsonObject element, String fieldName, boolean nullable) + throws ServletException { + try { + if (nullable && element.get(fieldName) == null) { + return null; + + } + String stringified = element.toString(); + if (!stringified.contains("\"")) { + throw new Exception(); + + } + return element.get(fieldName).getAsInt(); + + } catch (Exception e) { + throw new ServletException( + new WebserverAPI.BadRequestException("Field name '" + fieldName + "' is invalid in JSON input")); + + } + + } + } diff --git a/src/main/java/io/supertokens/webserver/api/totp/CreateOrUpdateTotpDeviceAPI.java b/src/main/java/io/supertokens/webserver/api/totp/CreateOrUpdateTotpDeviceAPI.java index 5f51c02bd..76b3e09cf 100644 --- a/src/main/java/io/supertokens/webserver/api/totp/CreateOrUpdateTotpDeviceAPI.java +++ b/src/main/java/io/supertokens/webserver/api/totp/CreateOrUpdateTotpDeviceAPI.java @@ -35,47 +35,27 @@ public String getPath() { protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws IOException, ServletException { JsonObject input = InputParser.parseJsonObjectOrThrowError(req); - String userId = null; - String deviceName = null; - Integer skew = null; - Integer period = null; + String userId = InputParser.parseStringOrThrowError(input, "userId", false); + String deviceName = InputParser.parseStringOrThrowError(input, "deviceName", false); + Integer skew = InputParser.parseIntOrThrowError(input, "skew", false); + Integer period = InputParser.parseIntOrThrowError(input, "period", false); - // TODO: Should we also allow the user to change the hashing algo and totp - // length (6-8) since we are already allowing them to change the period and skew - // which are advanced options anyways? - - if (input.has("userId")) { - userId = InputParser.parseStringOrThrowError(input, "userId", false); - } - if (input.has("deviceName")) { - deviceName = InputParser.parseStringOrThrowError(input, "deviceName", false); - } - if (input.has("skew")) { - // FIXME: No function to parse integer: - skew = InputParser.parseLongOrThrowError(input, "skew", false).intValue(); - } - if (input.has("period")) { - // FIXME: No function to parse integer: - period = InputParser.parseLongOrThrowError(input, "period", false).intValue(); - } + // Note: Not allowing the user to change the hashing algo and totp + // length (6-8) at the moment because it's rare to change them if (userId.isEmpty()) { - throw new ServletException(new IllegalArgumentException("userId cannot be empty")); + throw new ServletException(new BadRequestException("userId cannot be empty")); } if (deviceName.isEmpty()) { - throw new ServletException(new IllegalArgumentException("deviceName cannot be empty")); + throw new ServletException(new BadRequestException("deviceName cannot be empty")); } if (skew < 0) { - throw new ServletException(new IllegalArgumentException("skew must be >= 0")); + throw new ServletException(new BadRequestException("skew must be >= 0")); } if (period <= 0) { - throw new ServletException(new IllegalArgumentException("period must be > 0")); + throw new ServletException(new BadRequestException("period must be > 0")); } - // Should we do these as well? - // - Assert period <= 60. Otherwise, it is a security risk. > 30 is also bad. - // - Assert skew <= 2. Otherwise, it is a security risk. - JsonObject result = new JsonObject(); try { @@ -96,28 +76,18 @@ protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws I protected void doPut(HttpServletRequest req, HttpServletResponse resp) throws IOException, ServletException { JsonObject input = InputParser.parseJsonObjectOrThrowError(req); - String userId = null; - String existingDeviceName = null; - String newDeviceName = null; - - if (input.has("userId")) { - userId = InputParser.parseStringOrThrowError(input, "userId", false); - } - if (input.has("existingDeviceName")) { - existingDeviceName = InputParser.parseStringOrThrowError(input, "existingDeviceName", false); - } - if (input.has("newDeviceName")) { - newDeviceName = InputParser.parseStringOrThrowError(input, "newDeviceName", false); - } + String userId = InputParser.parseStringOrThrowError(input, "userId", false); + String existingDeviceName = InputParser.parseStringOrThrowError(input, "existingDeviceName", false); + String newDeviceName = InputParser.parseStringOrThrowError(input, "newDeviceName", false); if (userId.isEmpty()) { - throw new ServletException(new IllegalArgumentException("userId cannot be empty")); + throw new ServletException(new BadRequestException("userId cannot be empty")); } if (existingDeviceName.isEmpty()) { - throw new ServletException(new IllegalArgumentException("existingDeviceName cannot be empty")); + throw new ServletException(new BadRequestException("existingDeviceName cannot be empty")); } if (newDeviceName.isEmpty()) { - throw new ServletException(new IllegalArgumentException("newDeviceName cannot be empty")); + throw new ServletException(new BadRequestException("newDeviceName cannot be empty")); } JsonObject result = new JsonObject(); diff --git a/src/main/java/io/supertokens/webserver/api/totp/GetTotpDevicesAPI.java b/src/main/java/io/supertokens/webserver/api/totp/GetTotpDevicesAPI.java index 5cc43e77c..122e8be8e 100644 --- a/src/main/java/io/supertokens/webserver/api/totp/GetTotpDevicesAPI.java +++ b/src/main/java/io/supertokens/webserver/api/totp/GetTotpDevicesAPI.java @@ -33,14 +33,10 @@ public String getPath() { protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws IOException, ServletException { JsonObject input = InputParser.parseJsonObjectOrThrowError(req); - String userId = null; - - if (input.has("userId")) { - userId = InputParser.parseStringOrThrowError(input, "userId", false); - } + String userId = InputParser.parseStringOrThrowError(input, "userId", false); if (userId.isEmpty()) { - throw new ServletException(new IllegalArgumentException("userId cannot be empty")); + throw new ServletException(new BadRequestException("userId cannot be empty")); } JsonObject result = new JsonObject(); diff --git a/src/main/java/io/supertokens/webserver/api/totp/RemoveTotpDeviceAPI.java b/src/main/java/io/supertokens/webserver/api/totp/RemoveTotpDeviceAPI.java index bb349c561..b2a079694 100644 --- a/src/main/java/io/supertokens/webserver/api/totp/RemoveTotpDeviceAPI.java +++ b/src/main/java/io/supertokens/webserver/api/totp/RemoveTotpDeviceAPI.java @@ -33,21 +33,14 @@ public String getPath() { protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws IOException, ServletException { JsonObject input = InputParser.parseJsonObjectOrThrowError(req); - String userId = null; - String deviceName = null; - - if (input.has("userId")) { - userId = InputParser.parseStringOrThrowError(input, "userId", false); - } - if (input.has("deviceName")) { - deviceName = InputParser.parseStringOrThrowError(input, "deviceName", false); - } + String userId = InputParser.parseStringOrThrowError(input, "userId", false); + String deviceName = InputParser.parseStringOrThrowError(input, "deviceName", false); if (userId.isEmpty()) { - throw new ServletException(new IllegalArgumentException("userId cannot be empty")); + throw new ServletException(new BadRequestException("userId cannot be empty")); } if (deviceName.isEmpty()) { - throw new ServletException(new IllegalArgumentException("deviceName cannot be empty")); + throw new ServletException(new BadRequestException("deviceName cannot be empty")); } JsonObject result = new JsonObject(); diff --git a/src/main/java/io/supertokens/webserver/api/totp/UpdateTotpDeviceNameAPI.java b/src/main/java/io/supertokens/webserver/api/totp/UpdateTotpDeviceNameAPI.java index 891fc2f7a..eaa69a388 100644 --- a/src/main/java/io/supertokens/webserver/api/totp/UpdateTotpDeviceNameAPI.java +++ b/src/main/java/io/supertokens/webserver/api/totp/UpdateTotpDeviceNameAPI.java @@ -33,28 +33,18 @@ public String getPath() { protected void doPut(HttpServletRequest req, HttpServletResponse resp) throws IOException, ServletException { JsonObject input = InputParser.parseJsonObjectOrThrowError(req); - String userId = null; - String existingDeviceName = null; - String newDeviceName = null; - - if (input.has("userId")) { - userId = InputParser.parseStringOrThrowError(input, "userId", false); - } - if (input.has("existingDeviceName")) { - existingDeviceName = InputParser.parseStringOrThrowError(input, "existingDeviceName", false); - } - if (input.has("newDeviceName")) { - newDeviceName = InputParser.parseStringOrThrowError(input, "newDeviceName", false); - } + String userId = InputParser.parseStringOrThrowError(input, "userId", false); + String existingDeviceName = InputParser.parseStringOrThrowError(input, "existingDeviceName", false); + String newDeviceName = InputParser.parseStringOrThrowError(input, "newDeviceName", false); if (userId.isEmpty()) { - throw new ServletException(new IllegalArgumentException("userId cannot be empty")); + throw new ServletException(new BadRequestException("userId cannot be empty")); } if (existingDeviceName.isEmpty()) { - throw new ServletException(new IllegalArgumentException("existingDeviceName cannot be empty")); + throw new ServletException(new BadRequestException("existingDeviceName cannot be empty")); } if (newDeviceName.isEmpty()) { - throw new ServletException(new IllegalArgumentException("newDeviceName cannot be empty")); + throw new ServletException(new BadRequestException("newDeviceName cannot be empty")); } JsonObject result = new JsonObject(); diff --git a/src/main/java/io/supertokens/webserver/api/totp/VerifyTotpAPI.java b/src/main/java/io/supertokens/webserver/api/totp/VerifyTotpAPI.java index e1880251e..e486baf0b 100644 --- a/src/main/java/io/supertokens/webserver/api/totp/VerifyTotpAPI.java +++ b/src/main/java/io/supertokens/webserver/api/totp/VerifyTotpAPI.java @@ -33,28 +33,15 @@ public String getPath() { protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws IOException, ServletException { JsonObject input = InputParser.parseJsonObjectOrThrowError(req); - String userId = null; - String totp = null; - Boolean allowUnverifiedDevices = null; - - if (input.has("userId")) { - userId = InputParser.parseStringOrThrowError(input, "userId", false); - } - if (input.has("totp")) { - totp = InputParser.parseStringOrThrowError(input, "totp", false); - if (totp.length() != 6) { - throw new ServletException(new IllegalArgumentException("totp must be 6 characters long")); - } - } - if (input.has("allowUnverifiedDevices")) { - allowUnverifiedDevices = InputParser.parseBooleanOrThrowError(input, "allowUnverifiedDevices", false); - } + String userId = InputParser.parseStringOrThrowError(input, "userId", false); + String totp = InputParser.parseStringOrThrowError(input, "totp", false); + Boolean allowUnverifiedDevices = InputParser.parseBooleanOrThrowError(input, "allowUnverifiedDevices", false); if (userId.isEmpty()) { - throw new ServletException(new IllegalArgumentException("userId cannot be empty")); + throw new ServletException(new BadRequestException("userId cannot be empty")); } if (totp.length() != 6) { - throw new ServletException(new IllegalArgumentException("totp must be 6 characters long")); + throw new ServletException(new BadRequestException("totp must be 6 characters long")); } // Already checked that allowUnverifiedDevices is not null. diff --git a/src/main/java/io/supertokens/webserver/api/totp/VerifyTotpDeviceAPI.java b/src/main/java/io/supertokens/webserver/api/totp/VerifyTotpDeviceAPI.java index 4218f219a..dca121d1b 100644 --- a/src/main/java/io/supertokens/webserver/api/totp/VerifyTotpDeviceAPI.java +++ b/src/main/java/io/supertokens/webserver/api/totp/VerifyTotpDeviceAPI.java @@ -34,28 +34,18 @@ public String getPath() { protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws IOException, ServletException { JsonObject input = InputParser.parseJsonObjectOrThrowError(req); - String userId = null; - String deviceName = null; - String totp = null; - - if (input.has("userId")) { - userId = InputParser.parseStringOrThrowError(input, "userId", false); - } - if (input.has("deviceName")) { - deviceName = InputParser.parseStringOrThrowError(input, "deviceName", false); - } - if (input.has("totp")) { - totp = InputParser.parseStringOrThrowError(input, "totp", false); - } + String userId = InputParser.parseStringOrThrowError(input, "userId", false); + String deviceName = InputParser.parseStringOrThrowError(input, "deviceName", false); + String totp = InputParser.parseStringOrThrowError(input, "totp", false); if (userId.isEmpty()) { - throw new ServletException(new IllegalArgumentException("userId cannot be empty")); + throw new ServletException(new BadRequestException("userId cannot be empty")); } if (deviceName.isEmpty()) { - throw new ServletException(new IllegalArgumentException("deviceName cannot be empty")); + throw new ServletException(new BadRequestException("deviceName cannot be empty")); } if (totp.length() != 6) { - throw new ServletException(new IllegalArgumentException("totp must be 6 characters long")); + throw new ServletException(new BadRequestException("totp must be 6 characters long")); } JsonObject result = new JsonObject(); From 6d8a2b228b777c3dda09ba7512eb755737e61a9b Mon Sep 17 00:00:00 2001 From: KShivendu Date: Thu, 9 Mar 2023 16:30:57 +0530 Subject: [PATCH 27/42] feat: Improve TOTP recipe - Fix totp cron - Update tests - Use base32 - Use transaction for check and store code --- build.gradle | 3 + cli/bin/main/install-linux.sh | 13 ++ cli/bin/main/install-windows.bat | 12 ++ config.yaml | 3 - ...rtokens.featureflag.EEFeatureFlagInterface | 1 + implementationDependencies.json | 5 + .../io/supertokens/config/CoreConfig.java | 2 +- .../DeleteExpiredTotpTokens.java | 2 +- .../java/io/supertokens/inmemorydb/Start.java | 19 +- .../inmemorydb/queries/TOTPQueries.java | 30 +-- src/main/java/io/supertokens/totp/Totp.java | 201 ++++++++++-------- .../io/supertokens/test/ConfigTest2_6.java | 3 +- .../supertokens/test/totp/TOTPRecipeTest.java | 99 ++++----- .../test/totp/TOTPStorageTest.java | 142 +++++++++---- 14 files changed, 305 insertions(+), 230 deletions(-) create mode 100644 cli/bin/main/install-linux.sh create mode 100644 cli/bin/main/install-windows.bat create mode 100644 ee/bin/main/META-INF/services/io.supertokens.featureflag.EEFeatureFlagInterface diff --git a/build.gradle b/build.gradle index c8b912723..98ec6d997 100644 --- a/build.gradle +++ b/build.gradle @@ -68,6 +68,9 @@ dependencies { // https://mvnrepository.com/artifact/com.eatthepath/java-otp implementation group: 'com.eatthepath', name: 'java-otp', version: '0.4.0' + // https://mvnrepository.com/artifact/commons-codec/commons-codec + implementation group: 'commons-codec', name: 'commons-codec', version: '1.15' + compileOnly project(":supertokens-plugin-interface") testImplementation project(":supertokens-plugin-interface") diff --git a/cli/bin/main/install-linux.sh b/cli/bin/main/install-linux.sh new file mode 100644 index 000000000..d29c5b040 --- /dev/null +++ b/cli/bin/main/install-linux.sh @@ -0,0 +1,13 @@ +#!/bin/bash + +ST_INSTALL_LOC=$ST_INSTALL_LOC + +if [ -f /proc/1/cgroup ] && grep docker /proc/1/cgroup -qa; then + trap 'kill -TERM $PID' TERM INT + "${ST_INSTALL_LOC}"jre/bin/java -classpath "${ST_INSTALL_LOC}cli/*" io.supertokens.cli.Main false "${ST_INSTALL_LOC}" $@ & + PID=$! + wait $PID + trap - TERM INT +else + "${ST_INSTALL_LOC}"jre/bin/java -classpath "${ST_INSTALL_LOC}cli/*" io.supertokens.cli.Main false "${ST_INSTALL_LOC}" $@ +fi diff --git a/cli/bin/main/install-windows.bat b/cli/bin/main/install-windows.bat new file mode 100644 index 000000000..af13fb98d --- /dev/null +++ b/cli/bin/main/install-windows.bat @@ -0,0 +1,12 @@ +@echo off +set st_install_loc=$ST_INSTALL_LOC +"%st_install_loc%jre\bin"\java -classpath "%st_install_loc%cli\*" io.supertokens.cli.Main false "%st_install_loc%\" %* +IF %errorlevel% NEQ 0 ( +echo exiting +goto:eof +) +IF "%1" == "uninstall" ( +rmdir /S /Q "%st_install_loc%" +del "%~f0" +) +:eof diff --git a/config.yaml b/config.yaml index c239836d9..c8eb96a7c 100644 --- a/config.yaml +++ b/config.yaml @@ -60,9 +60,6 @@ core_config_version: 0 # (OPTIONAL | Default: 900) integer value. The time in seconds for which the user will be rate limited once totp_max_attempts is crossed. # totp_rate_limit_cooldown_sec: -# (OPTIONAL | Default: 1800) integer value. The time in seconds in which invalid TOTP codes will be considered expired. -# totp_invalid_code_expiry_sec: - # (OPTIONAL | Default: installation directory/logs/info.log) string value. Give the path to a file (on your local # system) in which the SuperTokens service can write INFO logs to. Set it to "null" if you want it to log to # standard output instead. diff --git a/ee/bin/main/META-INF/services/io.supertokens.featureflag.EEFeatureFlagInterface b/ee/bin/main/META-INF/services/io.supertokens.featureflag.EEFeatureFlagInterface new file mode 100644 index 000000000..d940a8488 --- /dev/null +++ b/ee/bin/main/META-INF/services/io.supertokens.featureflag.EEFeatureFlagInterface @@ -0,0 +1 @@ +io.supertokens.ee.EEFeatureFlag \ No newline at end of file diff --git a/implementationDependencies.json b/implementationDependencies.json index a96c50bc5..3e5b6dad3 100644 --- a/implementationDependencies.json +++ b/implementationDependencies.json @@ -105,6 +105,11 @@ "jar": "https://repo1.maven.org/maven2/com/eatthepath/java-otp/0.4.0/java-otp-0.4.0.jar", "name": "Java OTP 0.4.0", "src": "https://repo1.maven.org/maven2/com/eatthepath/java-otp/0.4.0/java-otp-0.4.0-sources.jar" + }, + { + "jar": "https://repo1.maven.org/maven2/commons-codec/commons-codec/1.15/commons-codec-1.15.jar", + "name": "Commons Codec 1.15", + "src": "https://repo1.maven.org/maven2/commons-codec/commons-codec/1.15/commons-codec-1.15-sources.jar" } ] } \ No newline at end of file diff --git a/src/main/java/io/supertokens/config/CoreConfig.java b/src/main/java/io/supertokens/config/CoreConfig.java index e59690107..f19f64b86 100644 --- a/src/main/java/io/supertokens/config/CoreConfig.java +++ b/src/main/java/io/supertokens/config/CoreConfig.java @@ -282,7 +282,7 @@ public int getTotpMaxAttempts() { } /** TOTP rate limit cooldown time (in seconds) */ - public int getTotpRateLimitCooldownTime() { + public int getTotpRateLimitCooldownTimeSec() { return totp_rate_limit_cooldown_sec; } diff --git a/src/main/java/io/supertokens/cronjobs/deleteExpiredTotpTokens/DeleteExpiredTotpTokens.java b/src/main/java/io/supertokens/cronjobs/deleteExpiredTotpTokens/DeleteExpiredTotpTokens.java index 78d422ad6..94730f045 100644 --- a/src/main/java/io/supertokens/cronjobs/deleteExpiredTotpTokens/DeleteExpiredTotpTokens.java +++ b/src/main/java/io/supertokens/cronjobs/deleteExpiredTotpTokens/DeleteExpiredTotpTokens.java @@ -34,7 +34,7 @@ protected void doTask() throws Exception { TOTPSQLStorage storage = StorageLayer.getTOTPStorage(this.main); - long rateLimitResetInMs = Config.getConfig(this.main).getTotpRateLimitCooldownTime(); + long rateLimitResetInMs = Config.getConfig(this.main).getTotpRateLimitCooldownTimeSec() * 1000; long expiredBefore = System.currentTimeMillis() - rateLimitResetInMs; // We will only remove expired codes that have been expired for longer diff --git a/src/main/java/io/supertokens/inmemorydb/Start.java b/src/main/java/io/supertokens/inmemorydb/Start.java index 34ebb8f74..98bc21b55 100644 --- a/src/main/java/io/supertokens/inmemorydb/Start.java +++ b/src/main/java/io/supertokens/inmemorydb/Start.java @@ -1866,13 +1866,15 @@ public TOTPDevice[] getDevices(String userId) } @Override - public void insertUsedCode(TOTPUsedCode usedCodeObj) + public void insertUsedCode_Transaction(TransactionConnection con, TOTPUsedCode usedCodeObj) throws StorageQueryException, TotpNotEnabledException, UsedCodeAlreadyExistsException { + Connection sqlCon = (Connection) con.getConnection(); try { - TOTPQueries.insertUsedCode(this, usedCodeObj); - } catch (StorageTransactionLogicException e) { - String message = e.actualException.getMessage(); + TOTPQueries.insertUsedCode_Transaction(this, sqlCon, usedCodeObj); + } catch (SQLException e) { + String message = e.getMessage(); // No user/device exists for the given usedCodeObj.userId + if (message .equals("[SQLITE_CONSTRAINT] Abort due to constraint violation (FOREIGN KEY constraint failed)")) { throw new TotpNotEnabledException(); @@ -1880,19 +1882,20 @@ public void insertUsedCode(TOTPUsedCode usedCodeObj) // Failed due to primary key on (userId, created_time) if (message.equals("[SQLITE_CONSTRAINT] Abort due to constraint violation (UNIQUE constraint failed: " + Config.getConfig(this).getTotpUsedCodesTable() + ".user_id, " - + Config.getConfig(this).getTotpUsedCodesTable() + ".created_time" + ")")) { + + Config.getConfig(this).getTotpUsedCodesTable() + ".created_time_ms" + ")")) { throw new UsedCodeAlreadyExistsException(); } - throw new StorageQueryException(e.actualException); + throw new StorageQueryException(e); } } @Override - public TOTPUsedCode[] getAllUsedCodesDescOrder(String userId) + public TOTPUsedCode[] getAllUsedCodesDescOrderAndLockByUser_Transaction(TransactionConnection con, String userId) throws StorageQueryException { + Connection sqlCon = (Connection) con.getConnection(); try { - return TOTPQueries.getAllUsedCodesDescOrder(this, userId); + return TOTPQueries.getAllUsedCodesDescOrderAndLockByUser_Transaction(this, sqlCon, userId); } catch (SQLException e) { throw new StorageQueryException(e); } diff --git a/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java b/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java index ffb6eaa1b..227d02f68 100644 --- a/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java +++ b/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java @@ -54,8 +54,6 @@ public static String getQueryToCreateUsedCodesExpiryTimeIndex(Start start) { private static int insertUser_Transaction(Start start, Connection con, String userId) throws SQLException, StorageQueryException { // Create user if not exists: - // TODO: Check if not using "CONFLICT DO NOTHING" will break the transaction - // It's not a problem anyways. but we should check for clarity String QUERY = "INSERT INTO " + Config.getConfig(start).getTotpUsersTable() + " (user_id) VALUES (?) ON CONFLICT DO NOTHING"; @@ -164,7 +162,7 @@ public static int getDevicesCount_Transaction(Start start, Connection con, Strin }); } - private static int insertUsedCode_Transaction(Start start, Connection con, TOTPUsedCode code) + public static int insertUsedCode_Transaction(Start start, Connection con, TOTPUsedCode code) throws SQLException, StorageQueryException { String QUERY = "INSERT INTO " + Config.getConfig(start).getTotpUsedCodesTable() + " (user_id, code, is_valid, expiry_time_ms, created_time_ms) VALUES (?, ?, ?, ?, ?);"; @@ -178,32 +176,20 @@ private static int insertUsedCode_Transaction(Start start, Connection con, TOTPU }); } - public static void insertUsedCode(Start start, TOTPUsedCode code) - throws StorageQueryException, StorageTransactionLogicException { - start.startTransaction(con -> { - Connection sqlCon = (Connection) con.getConnection(); - - try { - insertUsedCode_Transaction(start, sqlCon, code); - sqlCon.commit(); - } catch (SQLException e) { - throw new StorageTransactionLogicException(e); - } - - return null; - }); - } - /** * Query to get all used codes (expired/non-expired) for a user in descending * order of creation time. */ - public static TOTPUsedCode[] getAllUsedCodesDescOrder(Start start, String userId) + public static TOTPUsedCode[] getAllUsedCodesDescOrderAndLockByUser_Transaction(Start start, Connection con, + String userId) throws SQLException, StorageQueryException { + // Take a lock based on the user id: + ((ConnectionWithLocks) con).lock(userId + Config.getConfig(start).getTotpUsedCodesTable()); + String QUERY = "SELECT * FROM " + Config.getConfig(start).getTotpUsedCodesTable() - + " WHERE user_id = ? ORDER BY created_time_ms DESC;"; - return execute(start, QUERY, pst -> { + + " WHERE user_id = ? ORDER BY created_time_ms DESC"; + return execute(con, QUERY, pst -> { pst.setString(1, userId); }, result -> { List codes = new ArrayList<>(); diff --git a/src/main/java/io/supertokens/totp/Totp.java b/src/main/java/io/supertokens/totp/Totp.java index 4506f7366..b86992f1f 100644 --- a/src/main/java/io/supertokens/totp/Totp.java +++ b/src/main/java/io/supertokens/totp/Totp.java @@ -3,18 +3,13 @@ import java.security.InvalidKeyException; import java.security.Key; import java.security.NoSuchAlgorithmException; -import java.security.SecureRandom; import java.time.Duration; import java.time.Instant; import java.util.Arrays; -import java.util.Base64; import javax.crypto.KeyGenerator; -import javax.crypto.Mac; import javax.crypto.spec.SecretKeySpec; -import org.jetbrains.annotations.TestOnly; - import io.supertokens.Main; import io.supertokens.config.Config; @@ -32,6 +27,7 @@ import io.supertokens.storageLayer.StorageLayer; import io.supertokens.totp.exceptions.InvalidTotpException; import io.supertokens.totp.exceptions.LimitReachedException; +import org.apache.commons.codec.binary.Base32; public class Totp { private static String generateSecret() throws NoSuchAlgorithmException { @@ -41,16 +37,14 @@ private static String generateSecret() throws NoSuchAlgorithmException { final KeyGenerator keyGenerator = KeyGenerator.getInstance(TOTP_ALGORITHM); keyGenerator.init(160); // 160 bits = 20 bytes - // FIXME: Should return base32 or base16 - // Return base64 encoded string for the secret key: - return Base64.getEncoder().encodeToString(keyGenerator.generateKey().getEncoded()); + return new Base32().encodeToString(keyGenerator.generateKey().getEncoded()); } private static boolean checkCode(TOTPDevice device, String code) { final TimeBasedOneTimePasswordGenerator totp = new TimeBasedOneTimePasswordGenerator( Duration.ofSeconds(device.period)); - byte[] keyBytes = Base64.getDecoder().decode(device.secretKey); + byte[] keyBytes = new Base32().decode(device.secretKey); Key key = new SecretKeySpec(keyBytes, "HmacSHA1"); final int period = device.period; @@ -121,93 +115,124 @@ private static void checkAndStoreCode(Main main, TOTPStorage totpStorage, String // be rate limited for no reason. // That's why we need to fetch all the codes (expired + non-expired). - TOTPUsedCode[] usedCodes = totpStorage.getAllUsedCodesDescOrder(userId); - - // N represents # of invalid attempts that will trigger rate limiting: - int N = Config.getConfig(main).getTotpMaxAttempts(); // (Default 5) - // Count # of contiguous invalids in latest N attempts (stop at first valid): - long invalidOutOfN = Arrays.stream(usedCodes).limit(N).takeWhile(usedCode -> !usedCode.isValid).count(); - int rateLimitResetTimeInMs = Config.getConfig(main).getTotpRateLimitCooldownTime() * 1000; // (Default 15 mins) - - // Check if the user has been rate limited: - if (invalidOutOfN == N) { - // All of the latest N attempts were invalid: - long latestInvalidCodeCreatedTime = usedCodes[0].createdTime; - long now = System.currentTimeMillis(); - - if (now - latestInvalidCodeCreatedTime < rateLimitResetTimeInMs) { - // Less than rateLimitResetTimeInMs (default = 15 mins) time has elasped since - // the last invalid code: - throw new LimitReachedException(rateLimitResetTimeInMs / 1000); - - // If we insert the used code here, then it will further delay the user from - // being able to login. So not inserting it here. - - // Note: One edge case here is: User is rate limited, and then the - // DeleteExpiredTotpTokens cron removes the latest invalid attempts - // (because they have expired), and then user will again be able to - // do extra login attempts (totp_max_attempts more times). - // But rate limiting will kick in after totp_max_attempts number - // disarming the brute force attack. - // Furthermore, the cron running during cooldown of a user is somewhat rare. - // So this edge case is practically harmless. - } - } + // TOTPUsedCode[] usedCodes = - // Check if the code is valid for any device: - boolean isValid = false; - TOTPDevice matchingDevice = null; - for (TOTPDevice device : devices) { - // Check if the code is valid for this device: - if (checkCode(device, code)) { - isValid = true; - matchingDevice = device; - break; - } - } + TOTPSQLStorage totpSQLStorage = (TOTPSQLStorage) totpStorage; - // Check if the code has been previously used by the user and it was valid (and - // is still valid). If so, this could be a replay attack. So reject it. - if (isValid) { - for (TOTPUsedCode usedCode : usedCodes) { - // One edge case is that if the user has 2 devices, and they are used back to - // back (within 90 seconds) such that the code of the first device was - // regenerated by the second device, then it won't allow the second device's - // code to be used until it is expired. - // But this would be rare so we can ignore it for now. - if (usedCode.code.equals(code) && usedCode.isValid - && usedCode.expiryTime > System.currentTimeMillis()) { - isValid = false; - // We found a matching device but the code - // will be considered invalid here. + try { + totpSQLStorage.startTransaction(con -> { + TOTPUsedCode[] usedCodes = totpSQLStorage.getAllUsedCodesDescOrderAndLockByUser_Transaction(con, + userId); + + // N represents # of invalid attempts that will trigger rate limiting: + int N = Config.getConfig(main).getTotpMaxAttempts(); // (Default 5) + // Count # of contiguous invalids in latest N attempts (stop at first valid): + long invalidOutOfN = Arrays.stream(usedCodes).limit(N).takeWhile(usedCode -> !usedCode.isValid).count(); + int rateLimitResetTimeInMs = Config.getConfig(main).getTotpRateLimitCooldownTimeSec() * 1000; // (Default + // 15 + // mins) + + // Check if the user has been rate limited: + if (invalidOutOfN == N) { + // All of the latest N attempts were invalid: + long latestInvalidCodeCreatedTime = usedCodes[0].createdTime; + long now = System.currentTimeMillis(); + + if (now - latestInvalidCodeCreatedTime < rateLimitResetTimeInMs) { + // Less than rateLimitResetTimeInMs (default = 15 mins) time has elasped since + // the last invalid code: + throw new StorageTransactionLogicException( + new LimitReachedException(rateLimitResetTimeInMs / 1000)); + + // If we insert the used code here, then it will further delay the user from + // being able to login. So not inserting it here. + + // Note: One edge case here is: User is rate limited, and then the + // DeleteExpiredTotpTokens cron removes the latest invalid attempts + // (because they have expired), and then user will again be able to + // do extra login attempts (totp_max_attempts more times). + // But rate limiting will kick in after totp_max_attempts number + // disarming the brute force attack. + // Furthermore, the cron running during cooldown of a user is somewhat rare. + // So this edge case is practically harmless. + } } - } - } - // Insert the code into the list of used codes: + // Check if the code is valid for any device: + boolean isValid = false; + TOTPDevice matchingDevice = null; + for (TOTPDevice device : devices) { + // Check if the code is valid for this device: + if (checkCode(device, code)) { + isValid = true; + matchingDevice = device; + break; + } + } - // If device is found, calculate used code expiry time for that device (based on - // its period and skew). Otherwise, use the max used code expiry time of all the - // devices. - int maxUsedCodeExpiry = Arrays.stream(devices).mapToInt(device -> device.period * (2 * device.skew + 1)).max() - .orElse(0); - int expireInSec = (matchingDevice != null) ? matchingDevice.period * (2 * matchingDevice.skew + 1) - : maxUsedCodeExpiry; + // Check if the code has been previously used by the user and it was valid (and + // is still valid). If so, this could be a replay attack. So reject it. + if (isValid) { + for (TOTPUsedCode usedCode : usedCodes) { + // One edge case is that if the user has 2 devices, and they are used back to + // back (within 90 seconds) such that the code of the first device was + // regenerated by the second device, then it won't allow the second device's + // code to be used until it is expired. + // But this would be rare so we can ignore it for now. + if (usedCode.code.equals(code) && usedCode.isValid + && usedCode.expiryTime > System.currentTimeMillis()) { + isValid = false; + // We found a matching device but the code + // will be considered invalid here. + } + } + } - while (true) { - long now = System.currentTimeMillis(); - TOTPUsedCode newCode = new TOTPUsedCode(userId, code, isValid, now + 1000 * expireInSec, now); - try { - totpStorage.insertUsedCode(newCode); - break; - } catch (UsedCodeAlreadyExistsException e) { - continue; // Try again + // Insert the code into the list of used codes: + + // If device is found, calculate used code expiry time for that device (based on + // its period and skew). Otherwise, use the max used code expiry time of all the + // devices. + int maxUsedCodeExpiry = Arrays.stream(devices).mapToInt(device -> device.period * (2 * device.skew + 1)) + .max() + .orElse(0); + int expireInSec = (matchingDevice != null) ? matchingDevice.period * (2 * matchingDevice.skew + 1) + : maxUsedCodeExpiry; + + while (true) { + long now = System.currentTimeMillis(); + TOTPUsedCode newCode = new TOTPUsedCode(userId, code, isValid, now + 1000 * expireInSec, now); + try { + totpSQLStorage.insertUsedCode_Transaction(con, newCode); + break; + } catch (UsedCodeAlreadyExistsException e) { + break; + } catch (TotpNotEnabledException e) { + throw new StorageTransactionLogicException(e); + } + } + + if (!isValid) { + totpSQLStorage.commitTransaction(con); + throw new StorageTransactionLogicException(new InvalidTotpException()); + } + + return null; + }); + } catch (StorageTransactionLogicException e) { + if (e.actualException instanceof LimitReachedException) { + throw (LimitReachedException) e.actualException; + } else if (e.actualException instanceof InvalidTotpException) { + throw (InvalidTotpException) e.actualException; + } else if (e.actualException instanceof TotpNotEnabledException) { + throw (TotpNotEnabledException) e.actualException; + } else { + throw new StorageQueryException(e.actualException); } } - if (!isValid) { - throw new InvalidTotpException(); - } + return; + } public static boolean verifyDevice(Main main, String userId, String deviceName, String code) @@ -301,6 +326,8 @@ public static void removeDevice(Main main, String userId, String deviceName) if (devices.length == 0) { throw new TotpNotEnabledException(); } + + throw (UnknownDeviceException) e.actualException; } throw new StorageQueryException(e.actualException); diff --git a/src/test/java/io/supertokens/test/ConfigTest2_6.java b/src/test/java/io/supertokens/test/ConfigTest2_6.java index 0b869f38a..e97f07b03 100644 --- a/src/test/java/io/supertokens/test/ConfigTest2_6.java +++ b/src/test/java/io/supertokens/test/ConfigTest2_6.java @@ -264,8 +264,7 @@ private static void checkConfigValues(CoreConfig config, TestingProcess process, assertEquals("Config refresh token validity did not match default", config.getRefreshTokenValidity(), 60 * 2400 * 60 * (long) 1000); assertEquals(5, config.getTotpMaxAttempts()); // 5 - assertEquals(900, config.getTotpRateLimitCooldownTime()); // 15 minutes - assertEquals(1800, config.getTotpInvalidCodeExpiryTime()); // 30 minutes + assertEquals(900, config.getTotpRateLimitCooldownTimeSec()); // 15 minutes assertEquals("Config info log path did not match default", config.getInfoLogPath(process.getProcess()), CLIOptions.get(process.getProcess()).getInstallationPath() + "logs/info.log"); diff --git a/src/test/java/io/supertokens/test/totp/TOTPRecipeTest.java b/src/test/java/io/supertokens/test/totp/TOTPRecipeTest.java index fadf9a2d6..950cf2108 100644 --- a/src/test/java/io/supertokens/test/totp/TOTPRecipeTest.java +++ b/src/test/java/io/supertokens/test/totp/TOTPRecipeTest.java @@ -24,10 +24,10 @@ import java.security.Key; import java.time.Duration; import java.time.Instant; -import java.util.Base64; import javax.crypto.spec.SecretKeySpec; +import org.apache.commons.codec.binary.Base32; import org.junit.AfterClass; import org.junit.Before; import org.junit.Rule; @@ -40,11 +40,10 @@ import io.supertokens.Main; import io.supertokens.ProcessState; import io.supertokens.config.Config; -import io.supertokens.config.CoreConfig; -import io.supertokens.cronjobs.CronTaskTest; import io.supertokens.cronjobs.deleteExpiredTotpTokens.DeleteExpiredTotpTokens; import io.supertokens.pluginInterface.STORAGE_TYPE; import io.supertokens.pluginInterface.exceptions.StorageQueryException; +import io.supertokens.pluginInterface.exceptions.StorageTransactionLogicException; import io.supertokens.storageLayer.StorageLayer; import io.supertokens.test.TestingProcessManager; @@ -57,6 +56,7 @@ import io.supertokens.pluginInterface.totp.exception.DeviceAlreadyExistsException; import io.supertokens.pluginInterface.totp.exception.TotpNotEnabledException; import io.supertokens.pluginInterface.totp.exception.UnknownDeviceException; +import io.supertokens.pluginInterface.totp.sqlStorage.TOTPSQLStorage; public class TOTPRecipeTest { @@ -108,12 +108,24 @@ private static String generateTotpCode(Main main, TOTPDevice device, int step) final TimeBasedOneTimePasswordGenerator totp = new TimeBasedOneTimePasswordGenerator( Duration.ofSeconds(device.period)); - byte[] keyBytes = Base64.getDecoder().decode(device.secretKey); + byte[] keyBytes = new Base32().decode(device.secretKey); Key key = new SecretKeySpec(keyBytes, "HmacSHA1"); return totp.generateOneTimePasswordString(key, Instant.now().plusSeconds(step * device.period)); } + private static TOTPUsedCode[] getAllUsedCodesUtil(TOTPStorage storage, String userId) + throws StorageQueryException, StorageTransactionLogicException { + assert storage instanceof TOTPSQLStorage; + TOTPSQLStorage sqlStorage = (TOTPSQLStorage) storage; + + return (TOTPUsedCode[]) sqlStorage.startTransaction(con -> { + TOTPUsedCode[] usedCodes = sqlStorage.getAllUsedCodesDescOrderAndLockByUser_Transaction(con, userId); + sqlStorage.commitTransaction(con); + return usedCodes; + }); + } + @Test public void createDeviceTest() throws Exception { TestSetupResult result = defaultInit(); @@ -255,15 +267,9 @@ public void rateLimitCooldownTest() throws Exception { @Test public void cronRemovesAllCodesDuringRateLimitTest() throws Exception { - // FIXME: This test is flaky because of time being involved. - String[] args = { "../" }; - TestingProcessManager.TestingProcess process = TestingProcessManager.start(args, false); - - Utils.setValueInConfig("totp_invalid_code_expiry_sec", "1"); - process.startProcess(); - assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); - - Main main = process.getProcess(); + // This test is flaky because of time. + TestSetupResult result = defaultInit(); + Main main = result.process.getProcess(); // Create device TOTPDevice device = Totp.registerDevice(main, "user", "deviceName", 0, 1); @@ -271,18 +277,16 @@ public void cronRemovesAllCodesDuringRateLimitTest() throws Exception { // Trigger rate limiting and fix it with cronjob (manually run cronjob): triggerAndCheckRateLimit(main, device); // Wait for 1 second so that all the codes expire: - Thread.sleep(2000); - // Manually run cronjob to delete all the codes: (Here all of them are expired) + Thread.sleep(1500); + // Manually run cronjob to delete all the codes after their + // expiry time + rate limiting period is over: DeleteExpiredTotpTokens.getInstance(main).run(); - // Will completely reset the rate limiting. Allowing the user to do N attempts - // here N == totp_max_attempts from the config: - assertThrows(InvalidTotpException.class, () -> Totp.verifyCode(main, "user", "again-wrong-code1", true)); - // This should have throws LimitReachedException but it won't because of cron: - assertThrows(InvalidTotpException.class, () -> Totp.verifyCode(main, "user", "again-wrong-code2", true)); - Totp.verifyCode(main, "user", generateTotpCode(main, device), true); - // We can do N attempts again: - triggerAndCheckRateLimit(main, device); + // This removal shouldn't affect rate limiting. User must remain rate limited. + assertThrows(LimitReachedException.class, + () -> Totp.verifyCode(main, "user", generateTotpCode(main, device), true)); + assertThrows(LimitReachedException.class, + () -> Totp.verifyCode(main, "user", "again-wrong-code1", true)); } @Test @@ -328,6 +332,7 @@ public void createAndVerifyDeviceTest() throws Exception { @Test public void removeDeviceTest() throws Exception { + // Flaky test. TestSetupResult result = defaultInit(); Main main = result.process.getProcess(); TOTPStorage storage = result.storage; @@ -339,6 +344,12 @@ public void removeDeviceTest() throws Exception { TOTPDevice[] devices = Totp.getDevices(main, "user"); assert (devices.length == 2); + // Try to delete device for non-existent user: + assertThrows(TotpNotEnabledException.class, () -> Totp.removeDevice(main, "non-existent-user", "device1")); + + // Try to delete non-existent device: + assertThrows(UnknownDeviceException.class, () -> Totp.removeDevice(main, "user", "non-existent-device")); + // Delete one of the devices { assertThrows(InvalidTotpException.class, () -> Totp.verifyCode(main, "user", "invalid-code", true)); @@ -352,7 +363,7 @@ public void removeDeviceTest() throws Exception { assert (devices.length == 1); // 1 device still remain so all codes should still be still there: - TOTPUsedCode[] usedCodes = storage.getAllUsedCodesDescOrder("user"); + TOTPUsedCode[] usedCodes = getAllUsedCodesUtil(storage, "user"); assert (usedCodes.length == 3); } @@ -372,14 +383,14 @@ public void removeDeviceTest() throws Exception { assertThrows(TotpNotEnabledException.class, () -> Totp.getDevices(main, "user")); // No device left so all codes of the user should be deleted: - TOTPUsedCode[] usedCodes = storage.getAllUsedCodesDescOrder("user"); + TOTPUsedCode[] usedCodes = getAllUsedCodesUtil(storage, "user"); assert (usedCodes.length == 0); // But for other users things should still be there: TOTPDevice[] otherUserDevices = Totp.getDevices(main, "other-user"); assert (otherUserDevices.length == 1); - usedCodes = storage.getAllUsedCodesDescOrder("other-user"); + usedCodes = getAllUsedCodesUtil(storage, "other-user"); assert (usedCodes.length == 2); } } @@ -444,40 +455,4 @@ public void deleteExpiredTokensCronIntervalTest() throws Exception { assert DeleteExpiredTotpTokens.getInstance(main).getIntervalTimeSeconds() == 60 * 60; } - @Test - public void deleteExpiredTokensCronTest() throws Exception { - String[] args = { "../" }; - TestingProcessManager.TestingProcess process = TestingProcessManager.start(args, false); - - CronTaskTest.getInstance(process.getProcess()).setIntervalInSeconds(DeleteExpiredTotpTokens.RESOURCE_KEY, 1); - process.startProcess(); - assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); - - Main main = process.getProcess(); - - // Create device - // Set period and skew to 0 to make sure that the codes are one time usable and - // expire in 1 second - TOTPDevice device = Totp.registerDevice(main, "user", "device", 0, 1); - - // Add codes: - Totp.verifyCode(main, "user", generateTotpCode(main, device), true); - assertThrows(InvalidTotpException.class, () -> Totp.verifyCode(main, "user", - "invalid-code", true)); - - TOTPStorage storage = StorageLayer.getTOTPStorage(process.getProcess()); - - // Verify that the codes have been added: - TOTPUsedCode[] usedCodes = storage.getAllUsedCodesDescOrder("user"); - assert (usedCodes.length == 2); - - // Wait for 2 second to make sure that the valid codes expire - // (and crons deletes the valid ones since they are expired) - Thread.sleep(1000 * 2); - - usedCodes = storage.getAllUsedCodesDescOrder("user"); - assert (usedCodes.length == 1); - // Invalid code will still remain because their expiry time is 30 minutes - assert usedCodes[0].code.equals("invalid-code"); - } } diff --git a/src/test/java/io/supertokens/test/totp/TOTPStorageTest.java b/src/test/java/io/supertokens/test/totp/TOTPStorageTest.java index 0cec2634b..6815efffa 100644 --- a/src/test/java/io/supertokens/test/totp/TOTPStorageTest.java +++ b/src/test/java/io/supertokens/test/totp/TOTPStorageTest.java @@ -13,16 +13,18 @@ import io.supertokens.ProcessState; import io.supertokens.cronjobs.deleteExpiredTotpTokens.DeleteExpiredTotpTokens; import io.supertokens.pluginInterface.STORAGE_TYPE; +import io.supertokens.pluginInterface.exceptions.StorageQueryException; +import io.supertokens.pluginInterface.exceptions.StorageTransactionLogicException; import io.supertokens.storageLayer.StorageLayer; import io.supertokens.test.TestingProcessManager; -import io.supertokens.totp.Totp; import io.supertokens.pluginInterface.totp.TOTPDevice; import io.supertokens.pluginInterface.totp.TOTPStorage; import io.supertokens.pluginInterface.totp.TOTPUsedCode; import io.supertokens.pluginInterface.totp.exception.DeviceAlreadyExistsException; import io.supertokens.pluginInterface.totp.exception.TotpNotEnabledException; import io.supertokens.pluginInterface.totp.exception.UnknownDeviceException; +import io.supertokens.pluginInterface.totp.exception.UsedCodeAlreadyExistsException; import io.supertokens.pluginInterface.totp.sqlStorage.TOTPSQLStorage; public class TOTPStorageTest { @@ -64,6 +66,45 @@ public TestSetupResult initSteps() throws InterruptedException { return new TestSetupResult(storage, process); } + private static TOTPUsedCode[] getAllUsedCodesUtil(TOTPStorage storage, String userId) + throws StorageQueryException, StorageTransactionLogicException { + assert storage instanceof TOTPSQLStorage; + TOTPSQLStorage sqlStorage = (TOTPSQLStorage) storage; + + return (TOTPUsedCode[]) sqlStorage.startTransaction(con -> { + TOTPUsedCode[] usedCodes = sqlStorage.getAllUsedCodesDescOrderAndLockByUser_Transaction(con, userId); + sqlStorage.commitTransaction(con); + return usedCodes; + }); + } + + private static void insertUsedCodesUtil(TOTPSQLStorage storage, TOTPUsedCode[] usedCodes) + throws StorageQueryException, StorageTransactionLogicException, TotpNotEnabledException, + UsedCodeAlreadyExistsException { + try { + storage.startTransaction(con -> { + try { + for (TOTPUsedCode usedCode : usedCodes) { + storage.insertUsedCode_Transaction(con, usedCode); + } + } catch (TotpNotEnabledException | UsedCodeAlreadyExistsException e) { + throw new StorageTransactionLogicException(e); + } + storage.commitTransaction(con); + + return null; + }); + } catch (StorageTransactionLogicException e) { + Exception actual = e.actualException; + if (actual instanceof TotpNotEnabledException) { + throw (TotpNotEnabledException) actual; + } else if (actual instanceof UsedCodeAlreadyExistsException) { + throw (UsedCodeAlreadyExistsException) actual; + } + throw new StorageQueryException(e); + } + } + @Test public void createDeviceTests() throws Exception { TestSetupResult result = initSteps(); @@ -169,15 +210,14 @@ public void removeUser_TransactionTests() throws Exception { long expiryAfter10mins = now + 10 * 60 * 1000; TOTPUsedCode usedCode1 = new TOTPUsedCode("user", "code1", true, expiryAfter10mins, now); - TOTPUsedCode usedCode2 = new TOTPUsedCode("user", "code2", false, expiryAfter10mins, now); + TOTPUsedCode usedCode2 = new TOTPUsedCode("user", "code2", false, expiryAfter10mins, now + 1); - storage.insertUsedCode(usedCode1); - storage.insertUsedCode(usedCode2); + insertUsedCodesUtil(storage, new TOTPUsedCode[] { usedCode1, usedCode2 }); TOTPDevice[] storedDevices = storage.getDevices("user"); assert (storedDevices.length == 2); - TOTPUsedCode[] storedUsedCodes = storage.getAllUsedCodesDescOrder("user"); + TOTPUsedCode[] storedUsedCodes = getAllUsedCodesUtil(storage, "user"); assert (storedUsedCodes.length == 2); storage.startTransaction(con -> { @@ -189,7 +229,7 @@ public void removeUser_TransactionTests() throws Exception { storedDevices = storage.getDevices("user"); assert (storedDevices.length == 0); - storedUsedCodes = storage.getAllUsedCodesDescOrder("user"); + storedUsedCodes = getAllUsedCodesUtil(storage, "user"); assert (storedUsedCodes.length == 0); } @@ -298,39 +338,54 @@ public void insertUsedCodeTest() throws Exception { TestSetupResult result = initSteps(); TOTPSQLStorage storage = result.storage; long nextDay = System.currentTimeMillis() + 1000 * 60 * 60 * 24; // 1 day from now + long now = System.currentTimeMillis(); // Insert a long lasting valid code and check that it's returned when queried: { TOTPDevice device = new TOTPDevice("user", "device", "secretKey", 30, 1, false); - TOTPUsedCode code = new TOTPUsedCode("user", "1234", true, nextDay, System.currentTimeMillis()); + TOTPUsedCode code = new TOTPUsedCode("user", "1234", true, nextDay, now); storage.createDevice(device); - storage.insertUsedCode(code); - TOTPUsedCode[] usedCodes = storage.getAllUsedCodesDescOrder("user"); + insertUsedCodesUtil(storage, new TOTPUsedCode[] { code }); + TOTPUsedCode[] usedCodes = getAllUsedCodesUtil(storage, "user"); assert (usedCodes.length == 1); assert usedCodes[0].equals(code); } + // Try to insert a code with same user and created time. It should fail: + { + TOTPUsedCode codeWithRepeatedCreatedTime = new TOTPUsedCode("user", "any-code", true, nextDay, now); + assertThrows(UsedCodeAlreadyExistsException.class, + () -> insertUsedCodesUtil(storage, new TOTPUsedCode[] { codeWithRepeatedCreatedTime })); + } + // Try to insert code when user doesn't have any device (i.e. TOTP not enabled) { assertThrows(TotpNotEnabledException.class, - () -> storage.insertUsedCode( + () -> insertUsedCodesUtil(storage, new TOTPUsedCode[] { new TOTPUsedCode("new-user-without-totp", "1234", true, nextDay, - System.currentTimeMillis()))); + System.currentTimeMillis()) + })); } // Try to insert code after user has atleast one device (i.e. TOTP enabled) { TOTPDevice newDevice = new TOTPDevice("user", "new-device", "secretKey", 30, 1, false); storage.createDevice(newDevice); - storage.insertUsedCode(new TOTPUsedCode("user", "1234", true, nextDay, System.currentTimeMillis())); + insertUsedCodesUtil( + storage, + new TOTPUsedCode[] { + new TOTPUsedCode("user", "1234", true, nextDay, System.currentTimeMillis()) + }); } // Try to insert code when user doesn't exist: assertThrows(TotpNotEnabledException.class, - () -> storage.insertUsedCode( - new TOTPUsedCode("non-existent-user", "1234", true, nextDay, System.currentTimeMillis()))); + () -> insertUsedCodesUtil(storage, new TOTPUsedCode[] { + new TOTPUsedCode("non-existent-user", "1234", true, nextDay, + System.currentTimeMillis()) + })); } @Test @@ -338,7 +393,7 @@ public void getAllUsedCodesTest() throws Exception { TestSetupResult result = initSteps(); TOTPSQLStorage storage = result.storage; - TOTPUsedCode[] usedCodes = storage.getAllUsedCodesDescOrder("non-existent-user"); + TOTPUsedCode[] usedCodes = getAllUsedCodesUtil(storage, "non-existent-user"); assert (usedCodes.length == 0); long now = System.currentTimeMillis(); @@ -346,32 +401,31 @@ public void getAllUsedCodesTest() throws Exception { long prevDay = now - 1000 * 60 * 60 * 24; // 1 day ago TOTPDevice device = new TOTPDevice("user", "device", "secretKey", 30, 1, false); - TOTPUsedCode validCode1 = new TOTPUsedCode("user", "valid-code-1", true, nextDay, now); - TOTPUsedCode invalidCode = new TOTPUsedCode("user", "invalid-code", false, nextDay, now); - TOTPUsedCode expiredCode = new TOTPUsedCode("user", "expired-code", true, prevDay, now); - TOTPUsedCode expiredInvalidCode = new TOTPUsedCode("user", "expired-invalid-code", false, prevDay, now); - TOTPUsedCode validCode2 = new TOTPUsedCode("user", "valid-code-2", true, nextDay, now + 1); - TOTPUsedCode validCode3 = new TOTPUsedCode("user", "valid-code-3", true, nextDay, now + 2); + TOTPUsedCode validCode1 = new TOTPUsedCode("user", "valid-code-1", true, nextDay, now + 1); + TOTPUsedCode invalidCode = new TOTPUsedCode("user", "invalid-code", false, nextDay, now + 2); + TOTPUsedCode expiredCode = new TOTPUsedCode("user", "expired-code", true, prevDay, now + 3); + TOTPUsedCode expiredInvalidCode = new TOTPUsedCode("user", "expired-invalid-code", false, prevDay, now + 4); + TOTPUsedCode validCode2 = new TOTPUsedCode("user", "valid-code-2", true, nextDay, now + 5); + TOTPUsedCode validCode3 = new TOTPUsedCode("user", "valid-code-3", true, nextDay, now + 6); storage.createDevice(device); - storage.insertUsedCode(validCode1); - storage.insertUsedCode(invalidCode); - storage.insertUsedCode(expiredCode); - storage.insertUsedCode(expiredInvalidCode); - storage.insertUsedCode(validCode2); - storage.insertUsedCode(validCode3); - - usedCodes = storage.getAllUsedCodesDescOrder("user"); + insertUsedCodesUtil(storage, new TOTPUsedCode[] { + validCode1, invalidCode, + expiredCode, expiredInvalidCode, + validCode2, validCode3 + }); + + usedCodes = getAllUsedCodesUtil(storage, "user"); assert (usedCodes.length == 6); DeleteExpiredTotpTokens.getInstance(result.process.getProcess()).run(); - usedCodes = storage.getAllUsedCodesDescOrder("user"); + usedCodes = getAllUsedCodesUtil(storage, "user"); assert (usedCodes.length == 4); // expired codes shouldn't be returned assert (usedCodes[0].equals(validCode3)); // order is DESC by created time (now + X) assert (usedCodes[1].equals(validCode2)); - assert (usedCodes[2].equals(validCode1)); - assert (usedCodes[3].equals(invalidCode)); + assert (usedCodes[2].equals(invalidCode)); + assert (usedCodes[3].equals(validCode1)); } @Test @@ -385,27 +439,27 @@ public void removeExpiredCodesTest() throws Exception { TOTPDevice device = new TOTPDevice("user", "device", "secretKey", 30, 1, false); TOTPUsedCode validCodeToLive = new TOTPUsedCode("user", "valid-code", true, nextDay, now); - TOTPUsedCode invalidCodeToLive = new TOTPUsedCode("user", "invalid-code", false, nextDay, now); - TOTPUsedCode validCodeToExpire = new TOTPUsedCode("user", "valid-code", true, halfSecond, now); - TOTPUsedCode invalidCodeToExpire = new TOTPUsedCode("user", "invalid-code", false, halfSecond, now); + TOTPUsedCode invalidCodeToLive = new TOTPUsedCode("user", "invalid-code", false, nextDay, now + 1); + TOTPUsedCode validCodeToExpire = new TOTPUsedCode("user", "valid-code", true, halfSecond, now + 2); + TOTPUsedCode invalidCodeToExpire = new TOTPUsedCode("user", "invalid-code", false, halfSecond, now + 3); storage.createDevice(device); - storage.insertUsedCode(validCodeToLive); - storage.insertUsedCode(invalidCodeToLive); - storage.insertUsedCode(validCodeToExpire); - storage.insertUsedCode(invalidCodeToExpire); + insertUsedCodesUtil(storage, new TOTPUsedCode[] { + validCodeToLive, invalidCodeToLive, + validCodeToExpire, invalidCodeToExpire + }); - TOTPUsedCode[] usedCodes = storage.getAllUsedCodesDescOrder("user"); + TOTPUsedCode[] usedCodes = getAllUsedCodesUtil(storage, "user"); assert (usedCodes.length == 4); // After 500ms seconds pass: Thread.sleep(500); - storage.removeExpiredCodes(); + storage.removeExpiredCodes(System.currentTimeMillis()); - usedCodes = storage.getAllUsedCodesDescOrder("user"); + usedCodes = getAllUsedCodesUtil(storage, "user"); assert (usedCodes.length == 2); - assert (usedCodes[0].equals(validCodeToLive)); - assert (usedCodes[1].equals(invalidCodeToLive)); + assert (usedCodes[0].equals(invalidCodeToLive)); + assert (usedCodes[1].equals(validCodeToLive)); } } From 39ff5b79e82133540cf5893f0f3e53dd523b498a Mon Sep 17 00:00:00 2001 From: KShivendu Date: Fri, 10 Mar 2023 17:29:46 +0530 Subject: [PATCH 28/42] feat: Improve TOTP implementation - Use locks while verifying and storing code - Improve var names - Retry checkAndStore code on failure due to primary key - Update tests --- devConfig.yaml | 3 - .../java/io/supertokens/inmemorydb/Start.java | 26 +-- .../inmemorydb/queries/TOTPQueries.java | 15 +- src/main/java/io/supertokens/totp/Totp.java | 205 +++++++++--------- .../io/supertokens/webserver/Webserver.java | 1 - .../api/totp/UpdateTotpDeviceNameAPI.java | 70 ------ .../webserver/api/totp/VerifyTotpAPI.java | 3 +- .../api/totp/VerifyTotpDeviceAPI.java | 3 +- .../supertokens/test/totp/TOTPRecipeTest.java | 52 +++-- .../test/totp/TOTPStorageTest.java | 10 +- 10 files changed, 166 insertions(+), 222 deletions(-) delete mode 100644 src/main/java/io/supertokens/webserver/api/totp/UpdateTotpDeviceNameAPI.java diff --git a/devConfig.yaml b/devConfig.yaml index 8969d5f5a..82e0e695a 100644 --- a/devConfig.yaml +++ b/devConfig.yaml @@ -60,9 +60,6 @@ core_config_version: 0 # (OPTIONAL | Default: 900) integer value. The time in seconds for which the user will be rate limited once totp_max_attempts is crossed. # totp_rate_limit_cooldown_sec: -# (OPTIONAL | Default: 1800) integer value. The time in seconds in which invalid TOTP codes will be considered expired. -# totp_invalid_code_expiry_sec: - # (OPTIONAL | Default: installation directory/logs/info.log) string value. Give the path to a file (on your local # system) in which the SuperTokens service can write INFO logs to. Set it to "null" if you want it to log to # standard output instead. diff --git a/src/main/java/io/supertokens/inmemorydb/Start.java b/src/main/java/io/supertokens/inmemorydb/Start.java index 98bc21b55..07c6a625f 100644 --- a/src/main/java/io/supertokens/inmemorydb/Start.java +++ b/src/main/java/io/supertokens/inmemorydb/Start.java @@ -1814,17 +1814,6 @@ public int deleteDevice_Transaction(TransactionConnection con, String userId, St } } - @Override - public int getDevicesCount_Transaction(TransactionConnection con, String userId) - throws StorageQueryException { - Connection sqlCon = (Connection) con.getConnection(); - try { - return TOTPQueries.getDevicesCount_Transaction(this, sqlCon, userId); - } catch (SQLException e) { - throw new StorageQueryException(e); - } - } - @Override public void removeUser_Transaction(TransactionConnection con, String userId) throws StorageQueryException { @@ -1865,6 +1854,17 @@ public TOTPDevice[] getDevices(String userId) } } + @Override + public TOTPDevice[] getDevices_Transaction(TransactionConnection con, String userId) + throws StorageQueryException { + Connection sqlCon = (Connection) con.getConnection(); + try { + return TOTPQueries.getDevices_Transaction(this, sqlCon, userId); + } catch (SQLException e) { + throw new StorageQueryException(e); + } + } + @Override public void insertUsedCode_Transaction(TransactionConnection con, TOTPUsedCode usedCodeObj) throws StorageQueryException, TotpNotEnabledException, UsedCodeAlreadyExistsException { @@ -1891,11 +1891,11 @@ public void insertUsedCode_Transaction(TransactionConnection con, TOTPUsedCode u } @Override - public TOTPUsedCode[] getAllUsedCodesDescOrderAndLockByUser_Transaction(TransactionConnection con, String userId) + public TOTPUsedCode[] getAllUsedCodesDescOrder_Transaction(TransactionConnection con, String userId) throws StorageQueryException { Connection sqlCon = (Connection) con.getConnection(); try { - return TOTPQueries.getAllUsedCodesDescOrderAndLockByUser_Transaction(this, sqlCon, userId); + return TOTPQueries.getAllUsedCodesDescOrder_Transaction(this, sqlCon, userId); } catch (SQLException e) { throw new StorageQueryException(e); } diff --git a/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java b/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java index 227d02f68..dd5cf392c 100644 --- a/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java +++ b/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java @@ -150,16 +150,21 @@ public static TOTPDevice[] getDevices(Start start, String userId) }); } - public static int getDevicesCount_Transaction(Start start, Connection con, String userId) + public static TOTPDevice[] getDevices_Transaction(Start start, Connection con, String userId) throws StorageQueryException, SQLException { ((ConnectionWithLocks) con).lock(userId + Config.getConfig(start).getTotpUserDevicesTable()); - - String QUERY = "SELECT COUNT(*) as count FROM " + Config.getConfig(start).getTotpUserDevicesTable() + String QUERY = "SELECT * FROM " + Config.getConfig(start).getTotpUserDevicesTable() + " WHERE user_id = ?;"; return execute(con, QUERY, pst -> pst.setString(1, userId), result -> { - return result.getInt("count"); + List devices = new ArrayList<>(); + while (result.next()) { + devices.add(TOTPDeviceRowMapper.getInstance().map(result)); + } + + return devices.toArray(TOTPDevice[]::new); }); + } public static int insertUsedCode_Transaction(Start start, Connection con, TOTPUsedCode code) @@ -180,7 +185,7 @@ public static int insertUsedCode_Transaction(Start start, Connection con, TOTPUs * Query to get all used codes (expired/non-expired) for a user in descending * order of creation time. */ - public static TOTPUsedCode[] getAllUsedCodesDescOrderAndLockByUser_Transaction(Start start, Connection con, + public static TOTPUsedCode[] getAllUsedCodesDescOrder_Transaction(Start start, Connection con, String userId) throws SQLException, StorageQueryException { // Take a lock based on the user id: diff --git a/src/main/java/io/supertokens/totp/Totp.java b/src/main/java/io/supertokens/totp/Totp.java index b86992f1f..99041c790 100644 --- a/src/main/java/io/supertokens/totp/Totp.java +++ b/src/main/java/io/supertokens/totp/Totp.java @@ -53,7 +53,6 @@ private static boolean checkCode(TOTPDevice device, String code) { // Check if code is valid for any of the time periods in the skew: for (int i = -skew; i <= skew; i++) { try { - // TODO: Would there be any effect of timezones here? if (totp.generateOneTimePasswordString(key, Instant.now().plusSeconds(i * period)).equals(code)) { return true; } @@ -71,9 +70,6 @@ public static TOTPDevice registerDevice(Main main, String userId, String deviceN TOTPSQLStorage totpStorage = StorageLayer.getTOTPStorage(main); - // TODO: There should be a hard limit on number of devices per user - // 8 devices per user should be enough. Otherwise, it is a security risk. - String secret = generateSecret(); TOTPDevice device = new TOTPDevice(userId, deviceName, secret, period, skew, false); totpStorage.createDevice(device); @@ -83,7 +79,8 @@ public static TOTPDevice registerDevice(Main main, String userId, String deviceN private static void checkAndStoreCode(Main main, TOTPStorage totpStorage, String userId, TOTPDevice[] devices, String code) - throws InvalidTotpException, StorageQueryException, TotpNotEnabledException, LimitReachedException { + throws InvalidTotpException, TotpNotEnabledException, + LimitReachedException, StorageQueryException, StorageTransactionLogicException, InterruptedException { // Note that the TOTP cron runs every 1 hour, so all the expired tokens can stay // in the db for at max 1 hour after expiry. @@ -119,125 +116,118 @@ private static void checkAndStoreCode(Main main, TOTPStorage totpStorage, String TOTPSQLStorage totpSQLStorage = (TOTPSQLStorage) totpStorage; - try { - totpSQLStorage.startTransaction(con -> { - TOTPUsedCode[] usedCodes = totpSQLStorage.getAllUsedCodesDescOrderAndLockByUser_Transaction(con, - userId); - - // N represents # of invalid attempts that will trigger rate limiting: - int N = Config.getConfig(main).getTotpMaxAttempts(); // (Default 5) - // Count # of contiguous invalids in latest N attempts (stop at first valid): - long invalidOutOfN = Arrays.stream(usedCodes).limit(N).takeWhile(usedCode -> !usedCode.isValid).count(); - int rateLimitResetTimeInMs = Config.getConfig(main).getTotpRateLimitCooldownTimeSec() * 1000; // (Default - // 15 - // mins) - - // Check if the user has been rate limited: - if (invalidOutOfN == N) { - // All of the latest N attempts were invalid: - long latestInvalidCodeCreatedTime = usedCodes[0].createdTime; - long now = System.currentTimeMillis(); - - if (now - latestInvalidCodeCreatedTime < rateLimitResetTimeInMs) { - // Less than rateLimitResetTimeInMs (default = 15 mins) time has elasped since - // the last invalid code: - throw new StorageTransactionLogicException( - new LimitReachedException(rateLimitResetTimeInMs / 1000)); - - // If we insert the used code here, then it will further delay the user from - // being able to login. So not inserting it here. - - // Note: One edge case here is: User is rate limited, and then the - // DeleteExpiredTotpTokens cron removes the latest invalid attempts - // (because they have expired), and then user will again be able to - // do extra login attempts (totp_max_attempts more times). - // But rate limiting will kick in after totp_max_attempts number - // disarming the brute force attack. - // Furthermore, the cron running during cooldown of a user is somewhat rare. - // So this edge case is practically harmless. + while (true) { + try { + totpSQLStorage.startTransaction(con -> { + TOTPUsedCode[] usedCodes = totpSQLStorage.getAllUsedCodesDescOrder_Transaction(con, + userId); + + // N represents # of invalid attempts that will trigger rate limiting: + int N = Config.getConfig(main).getTotpMaxAttempts(); // (Default 5) + // Count # of contiguous invalids in latest N attempts (stop at first valid): + long invalidOutOfN = Arrays.stream(usedCodes).limit(N).takeWhile(usedCode -> !usedCode.isValid) + .count(); + int rateLimitResetTimeInMs = Config.getConfig(main).getTotpRateLimitCooldownTimeSec() * 1000; // (Default + // 15 mins) + + // Check if the user has been rate limited: + if (invalidOutOfN == N) { + // All of the latest N attempts were invalid: + long latestInvalidCodeCreatedTime = usedCodes[0].createdTime; + long now = System.currentTimeMillis(); + + if (now - latestInvalidCodeCreatedTime < rateLimitResetTimeInMs) { + // Less than rateLimitResetTimeInMs (default = 15 mins) time has elasped since + // the last invalid code: + throw new StorageTransactionLogicException( + new LimitReachedException(rateLimitResetTimeInMs / 1000)); + + // If we insert the used code here, then it will further delay the user from + // being able to login. So not inserting it here. + } } - } - // Check if the code is valid for any device: - boolean isValid = false; - TOTPDevice matchingDevice = null; - for (TOTPDevice device : devices) { - // Check if the code is valid for this device: - if (checkCode(device, code)) { - isValid = true; - matchingDevice = device; - break; + // Check if the code is valid for any device: + boolean isValid = false; + TOTPDevice matchingDevice = null; + for (TOTPDevice device : devices) { + // Check if the code is valid for this device: + if (checkCode(device, code)) { + isValid = true; + matchingDevice = device; + break; + } } - } - // Check if the code has been previously used by the user and it was valid (and - // is still valid). If so, this could be a replay attack. So reject it. - if (isValid) { - for (TOTPUsedCode usedCode : usedCodes) { - // One edge case is that if the user has 2 devices, and they are used back to - // back (within 90 seconds) such that the code of the first device was - // regenerated by the second device, then it won't allow the second device's - // code to be used until it is expired. - // But this would be rare so we can ignore it for now. - if (usedCode.code.equals(code) && usedCode.isValid - && usedCode.expiryTime > System.currentTimeMillis()) { - isValid = false; - // We found a matching device but the code - // will be considered invalid here. + // Check if the code has been previously used by the user and it was valid (and + // is still valid). If so, this could be a replay attack. So reject it. + if (isValid) { + for (TOTPUsedCode usedCode : usedCodes) { + // One edge case is that if the user has 2 devices, and they are used back to + // back (within 90 seconds) such that the code of the first device was + // regenerated by the second device, then it won't allow the second device's + // code to be used until it is expired. + // But this would be rare so we can ignore it for now. + if (usedCode.code.equals(code) && usedCode.isValid + && usedCode.expiryTime > System.currentTimeMillis()) { + isValid = false; + // We found a matching device but the code + // will be considered invalid here. + } } } - } - // Insert the code into the list of used codes: + // Insert the code into the list of used codes: - // If device is found, calculate used code expiry time for that device (based on - // its period and skew). Otherwise, use the max used code expiry time of all the - // devices. - int maxUsedCodeExpiry = Arrays.stream(devices).mapToInt(device -> device.period * (2 * device.skew + 1)) - .max() - .orElse(0); - int expireInSec = (matchingDevice != null) ? matchingDevice.period * (2 * matchingDevice.skew + 1) - : maxUsedCodeExpiry; + // If device is found, calculate used code expiry time for that device (based on + // its period and skew). Otherwise, use the max used code expiry time of all the + // devices. + int maxUsedCodeExpiry = Arrays.stream(devices) + .mapToInt(device -> device.period * (2 * device.skew + 1)) + .max() + .orElse(0); + int expireInSec = (matchingDevice != null) ? matchingDevice.period * (2 * matchingDevice.skew + 1) + : maxUsedCodeExpiry; - while (true) { long now = System.currentTimeMillis(); TOTPUsedCode newCode = new TOTPUsedCode(userId, code, isValid, now + 1000 * expireInSec, now); try { totpSQLStorage.insertUsedCode_Transaction(con, newCode); - break; - } catch (UsedCodeAlreadyExistsException e) { - break; - } catch (TotpNotEnabledException e) { + totpSQLStorage.commitTransaction(con); + } catch (UsedCodeAlreadyExistsException | TotpNotEnabledException e) { throw new StorageTransactionLogicException(e); } - } - if (!isValid) { - totpSQLStorage.commitTransaction(con); - throw new StorageTransactionLogicException(new InvalidTotpException()); - } + if (!isValid) { + // transaction has been committed, so we can directly throw the exception: + throw new StorageTransactionLogicException(new InvalidTotpException()); + } - return null; - }); - } catch (StorageTransactionLogicException e) { - if (e.actualException instanceof LimitReachedException) { - throw (LimitReachedException) e.actualException; - } else if (e.actualException instanceof InvalidTotpException) { - throw (InvalidTotpException) e.actualException; - } else if (e.actualException instanceof TotpNotEnabledException) { - throw (TotpNotEnabledException) e.actualException; - } else { - throw new StorageQueryException(e.actualException); + return null; + }); + return; // exit the while loop + } catch (StorageTransactionLogicException e) { + // throwing errors will also help exit the while loop: + if (e.actualException instanceof LimitReachedException) { + throw (LimitReachedException) e.actualException; + } else if (e.actualException instanceof InvalidTotpException) { + throw (InvalidTotpException) e.actualException; + } else if (e.actualException instanceof TotpNotEnabledException) { + throw (TotpNotEnabledException) e.actualException; + } else if (e.actualException instanceof UsedCodeAlreadyExistsException) { + // retry the transaction after 3ms (not sec) + Thread.sleep(3); + continue; + } else { + throw e; + } } } - - return; - } public static boolean verifyDevice(Main main, String userId, String deviceName, String code) - throws StorageQueryException, TotpNotEnabledException, UnknownDeviceException, InvalidTotpException, - LimitReachedException { + throws TotpNotEnabledException, UnknownDeviceException, InvalidTotpException, + LimitReachedException, StorageQueryException, StorageTransactionLogicException, InterruptedException { // Here boolean return value tells whether the device has been // newly verified (true) OR it was already verified (false) @@ -277,8 +267,8 @@ public static boolean verifyDevice(Main main, String userId, String deviceName, } public static void verifyCode(Main main, String userId, String code, boolean allowUnverifiedDevices) - throws StorageQueryException, TotpNotEnabledException, InvalidTotpException, - LimitReachedException { + throws TotpNotEnabledException, InvalidTotpException, LimitReachedException, + StorageQueryException, StorageTransactionLogicException, InterruptedException { TOTPSQLStorage totpStorage = StorageLayer.getTOTPStorage(main); // Check if the user has any devices: @@ -309,8 +299,9 @@ public static void removeDevice(Main main, String userId, String deviceName) } // Some device(s) were deleted. Check if user has any other device left: - int devicesCount = storage.getDevicesCount_Transaction(con, userId); - if (devicesCount == 0) { + // This also takes a lock on the user devices. + TOTPDevice[] devices = storage.getDevices_Transaction(con, userId); + if (devices.length == 0) { // no device left. delete user storage.removeUser_Transaction(con, userId); } @@ -330,7 +321,7 @@ public static void removeDevice(Main main, String userId, String deviceName) throw (UnknownDeviceException) e.actualException; } - throw new StorageQueryException(e.actualException); + throw e; } } diff --git a/src/main/java/io/supertokens/webserver/Webserver.java b/src/main/java/io/supertokens/webserver/Webserver.java index c243bfa94..a0a6cfaf9 100644 --- a/src/main/java/io/supertokens/webserver/Webserver.java +++ b/src/main/java/io/supertokens/webserver/Webserver.java @@ -48,7 +48,6 @@ import io.supertokens.webserver.api.totp.CreateOrUpdateTotpDeviceAPI; import io.supertokens.webserver.api.totp.GetTotpDevicesAPI; import io.supertokens.webserver.api.totp.RemoveTotpDeviceAPI; -import io.supertokens.webserver.api.totp.UpdateTotpDeviceNameAPI; import io.supertokens.webserver.api.totp.VerifyTotpAPI; import io.supertokens.webserver.api.totp.VerifyTotpDeviceAPI; import io.supertokens.webserver.api.useridmapping.RemoveUserIdMappingAPI; diff --git a/src/main/java/io/supertokens/webserver/api/totp/UpdateTotpDeviceNameAPI.java b/src/main/java/io/supertokens/webserver/api/totp/UpdateTotpDeviceNameAPI.java deleted file mode 100644 index eaa69a388..000000000 --- a/src/main/java/io/supertokens/webserver/api/totp/UpdateTotpDeviceNameAPI.java +++ /dev/null @@ -1,70 +0,0 @@ -package io.supertokens.webserver.api.totp; - -import java.io.IOException; - -import com.google.gson.JsonObject; - -import io.supertokens.Main; -import io.supertokens.pluginInterface.RECIPE_ID; -import io.supertokens.pluginInterface.exceptions.StorageQueryException; -import io.supertokens.pluginInterface.totp.exception.DeviceAlreadyExistsException; -import io.supertokens.pluginInterface.totp.exception.TotpNotEnabledException; -import io.supertokens.pluginInterface.totp.exception.UnknownDeviceException; -import io.supertokens.totp.Totp; -import io.supertokens.webserver.InputParser; -import io.supertokens.webserver.WebserverAPI; -import jakarta.servlet.ServletException; -import jakarta.servlet.http.HttpServletRequest; -import jakarta.servlet.http.HttpServletResponse; - -public class UpdateTotpDeviceNameAPI extends WebserverAPI { - private static final long serialVersionUID = -4641988458637882374L; - - public UpdateTotpDeviceNameAPI(Main main) { - super(main, RECIPE_ID.TOTP.toString()); - } - - @Override - public String getPath() { - return "/recipe/totp/device"; - } - - @Override - protected void doPut(HttpServletRequest req, HttpServletResponse resp) throws IOException, ServletException { - JsonObject input = InputParser.parseJsonObjectOrThrowError(req); - - String userId = InputParser.parseStringOrThrowError(input, "userId", false); - String existingDeviceName = InputParser.parseStringOrThrowError(input, "existingDeviceName", false); - String newDeviceName = InputParser.parseStringOrThrowError(input, "newDeviceName", false); - - if (userId.isEmpty()) { - throw new ServletException(new BadRequestException("userId cannot be empty")); - } - if (existingDeviceName.isEmpty()) { - throw new ServletException(new BadRequestException("existingDeviceName cannot be empty")); - } - if (newDeviceName.isEmpty()) { - throw new ServletException(new BadRequestException("newDeviceName cannot be empty")); - } - - JsonObject result = new JsonObject(); - - try { - Totp.updateDeviceName(main, userId, existingDeviceName, newDeviceName); - - result.addProperty("status", "OK"); - super.sendJsonResponse(200, result, resp); - } catch (TotpNotEnabledException e) { - result.addProperty("status", "TOTP_NOT_ENABLED_ERROR"); - super.sendJsonResponse(200, result, resp); - } catch (UnknownDeviceException e) { - result.addProperty("status", "UNKNOWN_DEVICE_ERROR"); - super.sendJsonResponse(200, result, resp); - } catch (DeviceAlreadyExistsException e) { - result.addProperty("status", "DEVICE_ALREADY_EXISTS_ERROR"); - super.sendJsonResponse(200, result, resp); - } catch (StorageQueryException e) { - throw new ServletException(e); - } - } -} diff --git a/src/main/java/io/supertokens/webserver/api/totp/VerifyTotpAPI.java b/src/main/java/io/supertokens/webserver/api/totp/VerifyTotpAPI.java index e486baf0b..31a08bda6 100644 --- a/src/main/java/io/supertokens/webserver/api/totp/VerifyTotpAPI.java +++ b/src/main/java/io/supertokens/webserver/api/totp/VerifyTotpAPI.java @@ -7,6 +7,7 @@ import io.supertokens.Main; import io.supertokens.pluginInterface.RECIPE_ID; import io.supertokens.pluginInterface.exceptions.StorageQueryException; +import io.supertokens.pluginInterface.exceptions.StorageTransactionLogicException; import io.supertokens.pluginInterface.totp.exception.TotpNotEnabledException; import io.supertokens.totp.Totp; import io.supertokens.totp.exceptions.InvalidTotpException; @@ -63,7 +64,7 @@ protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws I // Also return a retryAfter value: resp.addHeader("Retry-After", Integer.toString(e.retryInSeconds)); super.sendJsonResponse(429, result, resp); // 429 Too Many Requests - } catch (StorageQueryException e) { + } catch (StorageQueryException | StorageTransactionLogicException | InterruptedException e) { throw new ServletException(e); } } diff --git a/src/main/java/io/supertokens/webserver/api/totp/VerifyTotpDeviceAPI.java b/src/main/java/io/supertokens/webserver/api/totp/VerifyTotpDeviceAPI.java index dca121d1b..150f8ccf4 100644 --- a/src/main/java/io/supertokens/webserver/api/totp/VerifyTotpDeviceAPI.java +++ b/src/main/java/io/supertokens/webserver/api/totp/VerifyTotpDeviceAPI.java @@ -7,6 +7,7 @@ import io.supertokens.Main; import io.supertokens.pluginInterface.RECIPE_ID; import io.supertokens.pluginInterface.exceptions.StorageQueryException; +import io.supertokens.pluginInterface.exceptions.StorageTransactionLogicException; import io.supertokens.pluginInterface.totp.exception.TotpNotEnabledException; import io.supertokens.pluginInterface.totp.exception.UnknownDeviceException; import io.supertokens.totp.Totp; @@ -70,7 +71,7 @@ protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws I // Also return a retryAfter value: resp.addHeader("Retry-After", Integer.toString(e.retryInSeconds)); super.sendJsonResponse(429, result, resp); // 429 (Too Many Requests) - } catch (StorageQueryException e) { + } catch (StorageQueryException | StorageTransactionLogicException | InterruptedException e) { throw new ServletException(e); } } diff --git a/src/test/java/io/supertokens/test/totp/TOTPRecipeTest.java b/src/test/java/io/supertokens/test/totp/TOTPRecipeTest.java index 950cf2108..4cbf6c30f 100644 --- a/src/test/java/io/supertokens/test/totp/TOTPRecipeTest.java +++ b/src/test/java/io/supertokens/test/totp/TOTPRecipeTest.java @@ -120,7 +120,7 @@ private static TOTPUsedCode[] getAllUsedCodesUtil(TOTPStorage storage, String us TOTPSQLStorage sqlStorage = (TOTPSQLStorage) storage; return (TOTPUsedCode[]) sqlStorage.startTransaction(con -> { - TOTPUsedCode[] usedCodes = sqlStorage.getAllUsedCodesDescOrderAndLockByUser_Transaction(con, userId); + TOTPUsedCode[] usedCodes = sqlStorage.getAllUsedCodesDescOrder_Transaction(con, userId); sqlStorage.commitTransaction(con); return usedCodes; }); @@ -177,33 +177,47 @@ public void createDeviceAndVerifyCodeTest() throws Exception { () -> Totp.verifyCode(main, "user", validCode, true)); // Sleep for 1s so that code changes. - Thread.sleep(1500); + Thread.sleep(1000); // Use a new valid code: String newValidCode = generateTotpCode(main, device); Totp.verifyCode(main, "user", newValidCode, true); - // Regenerate the same code and use it again (should fail): - String newValidCodeCopy = generateTotpCode(main, device); + // Reuse the same code and use it again (should fail): assertThrows(InvalidTotpException.class, - () -> Totp.verifyCode(main, "user", newValidCodeCopy, true)); + () -> Totp.verifyCode(main, "user", newValidCode, true)); // Use a code from next period: String nextValidCode = generateTotpCode(main, device, 1); Totp.verifyCode(main, "user", nextValidCode, true); - // Use previous period code (should fail coz validCode): // FIXME: This should - // // fail - // String previousCode = generateTotpCode(main, "user", "device", -1); - // Totp.verifyCode(main, "user", previousCode, true); + // Use previous period code (should fail coz validCode has been used): + String previousCode = generateTotpCode(main, device, -1); + assert previousCode.equals(validCode); + assertThrows(InvalidTotpException.class, () -> Totp.verifyCode(main, "user", previousCode, true)); + + // Create device with skew = 0, check that it only works with the current code + TOTPDevice device2 = Totp.registerDevice(main, "user", "device2", 0, 1); + assert device2.secretKey != device.secretKey; + + String nextValidCode2 = generateTotpCode(main, device2, 1); + assertThrows(InvalidTotpException.class, + () -> Totp.verifyCode(main, "user", nextValidCode2, true)); + + String previousValidCode2 = generateTotpCode(main, device2, -1); + assertThrows(InvalidTotpException.class, + () -> Totp.verifyCode(main, "user", previousValidCode2, true)); - // TODO: Add isolated tests where we - // - we try next and previous codes as well (try different skew values) - // - change totp_max_attempts - // - change totp_invalid_code_expiry_sec + String currentValidCode2 = generateTotpCode(main, device2); + Totp.verifyCode(main, "user", currentValidCode2, true); } - public void triggerAndCheckRateLimit(Main main, TOTPDevice device) throws Exception { + /* + * Triggers rate limiting and checks that it works. + * It returns the number of attempts that were made before rate limiting was + * triggered. + */ + public int triggerAndCheckRateLimit(Main main, TOTPDevice device) throws Exception { int N = Config.getConfig(main).getTotpMaxAttempts(); // First N attempts should fail with invalid code: @@ -226,6 +240,8 @@ public void triggerAndCheckRateLimit(Main main, TOTPDevice device) throws Except assertThrows( LimitReachedException.class, () -> Totp.verifyCode(main, "user", "invalid-code-N+2", true)); + + return N; } @Test @@ -234,6 +250,8 @@ public void rateLimitCooldownTest() throws Exception { // set rate limiting cooldown time to 1s Utils.setValueInConfig("totp_rate_limit_cooldown_sec", "1"); + // set max attempts to 3 + Utils.setValueInConfig("totp_max_attempts", "3"); TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); @@ -248,7 +266,8 @@ public void rateLimitCooldownTest() throws Exception { TOTPDevice device = Totp.registerDevice(main, "user", "deviceName", 1, 1); // Trigger rate limiting and fix it with a correct code after some time: - triggerAndCheckRateLimit(main, device); + int attemptsRequired = triggerAndCheckRateLimit(main, device); + assert attemptsRequired == 3; // Wait for 1 second (Should cool down rate limiting): Thread.sleep(1000); // But again try with invalid code: @@ -275,7 +294,8 @@ public void cronRemovesAllCodesDuringRateLimitTest() throws Exception { TOTPDevice device = Totp.registerDevice(main, "user", "deviceName", 0, 1); // Trigger rate limiting and fix it with cronjob (manually run cronjob): - triggerAndCheckRateLimit(main, device); + int attemptsRequired = triggerAndCheckRateLimit(main, device); + assert attemptsRequired == 5; // Wait for 1 second so that all the codes expire: Thread.sleep(1500); // Manually run cronjob to delete all the codes after their diff --git a/src/test/java/io/supertokens/test/totp/TOTPStorageTest.java b/src/test/java/io/supertokens/test/totp/TOTPStorageTest.java index 6815efffa..b82864845 100644 --- a/src/test/java/io/supertokens/test/totp/TOTPStorageTest.java +++ b/src/test/java/io/supertokens/test/totp/TOTPStorageTest.java @@ -72,7 +72,7 @@ private static TOTPUsedCode[] getAllUsedCodesUtil(TOTPStorage storage, String us TOTPSQLStorage sqlStorage = (TOTPSQLStorage) storage; return (TOTPUsedCode[]) sqlStorage.startTransaction(con -> { - TOTPUsedCode[] usedCodes = sqlStorage.getAllUsedCodesDescOrderAndLockByUser_Transaction(con, userId); + TOTPUsedCode[] usedCodes = sqlStorage.getAllUsedCodesDescOrder_Transaction(con, userId); sqlStorage.commitTransaction(con); return usedCodes; }); @@ -167,9 +167,9 @@ public void getDevicesCount_TransactionTests() throws Exception { // Try to get the count for a user that doesn't exist (Should pass because // this is DB level txn that doesn't throw TotpNotEnabledException): int devicesCount = storage.startTransaction(con -> { - int value = storage.getDevicesCount_Transaction(con, "non-existent-user"); + TOTPDevice[] devices = storage.getDevices_Transaction(con, "non-existent-user"); storage.commitTransaction(con); - return value; + return devices.length; }); assert devicesCount == 0; @@ -180,9 +180,9 @@ public void getDevicesCount_TransactionTests() throws Exception { storage.createDevice(device2); devicesCount = storage.startTransaction(con -> { - int value = storage.getDevicesCount_Transaction(con, "user"); + TOTPDevice[] devices = storage.getDevices_Transaction(con, "user"); storage.commitTransaction(con); - return value; + return devices.length; }); assert devicesCount == 2; } From 3adc73d6ada6a824936465de1f5cc94db40c7fdb Mon Sep 17 00:00:00 2001 From: KShivendu Date: Fri, 10 Mar 2023 17:32:57 +0530 Subject: [PATCH 29/42] refactor: Adjust order of columns in totp_used_codes table --- .../inmemorydb/queries/TOTPQueries.java | 3 +- .../io/supertokens/test/totp/TOTPApiTest.java | 67 +++++++++++++++++++ 2 files changed, 68 insertions(+), 2 deletions(-) create mode 100644 src/test/java/io/supertokens/test/totp/TOTPApiTest.java diff --git a/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java b/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java index dd5cf392c..04222139d 100644 --- a/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java +++ b/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java @@ -39,8 +39,8 @@ public static String getQueryToCreateUsedCodesTable(Start start) { return "CREATE TABLE IF NOT EXISTS " + Config.getConfig(start).getTotpUsedCodesTable() + " (" + "user_id VARCHAR(128) NOT NULL, " + "code VARCHAR(8) NOT NULL," + "is_valid BOOLEAN NOT NULL," - + "created_time_ms BIGINT UNSIGNED NOT NULL," + "expiry_time_ms BIGINT UNSIGNED NOT NULL," + + "created_time_ms BIGINT UNSIGNED NOT NULL," + "PRIMARY KEY (user_id, created_time_ms)" + "FOREIGN KEY (user_id) REFERENCES " + Config.getConfig(start).getTotpUsersTable() + "(user_id) ON DELETE CASCADE);"; @@ -254,7 +254,6 @@ public TOTPUsedCode map(ResultSet result) throws SQLException { result.getBoolean("is_valid"), result.getLong("expiry_time_ms"), result.getLong("created_time_ms")); - // FIXME: Put created time first, then expiry time. } } } diff --git a/src/test/java/io/supertokens/test/totp/TOTPApiTest.java b/src/test/java/io/supertokens/test/totp/TOTPApiTest.java new file mode 100644 index 000000000..e2ab46039 --- /dev/null +++ b/src/test/java/io/supertokens/test/totp/TOTPApiTest.java @@ -0,0 +1,67 @@ +/* + * Copyright (c) 2023, VRAI Labs and/or its affiliates. All rights reserved. + * + * This software is licensed under the Apache License, Version 2.0 (the + * "License") as published by the Apache Software Foundation. + * + * You may not use this file except in compliance with the License. You may + * obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.supertokens.test.totp; + +import com.google.gson.JsonObject; +import io.supertokens.ProcessState; +import io.supertokens.pluginInterface.STORAGE_TYPE; +import io.supertokens.storageLayer.StorageLayer; +import io.supertokens.test.TestingProcessManager; +import io.supertokens.test.Utils; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TestRule; + +import static org.junit.Assert.assertNotNull; + +public class TOTPApiTest { + @Rule + public TestRule watchman = Utils.getOnFailure(); + + @AfterClass + public static void afterTesting() { + Utils.afterTesting(); + } + + @Before + public void beforeEach() { + Utils.reset(); + } + + @Test + public void testBadInput() throws Exception { + String[] args = { "../" }; + + TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + + if (StorageLayer.getStorage(process.getProcess()).getType() != STORAGE_TYPE.SQL) { + return; + } + + String userId = "userId"; + String deviceName = "XXX"; + + // + { + JsonObject code = new JsonObject(); + + } + } +} From 35b06a054f370cec433a430e7e3e04f526df4b87 Mon Sep 17 00:00:00 2001 From: KShivendu Date: Tue, 14 Mar 2023 18:42:31 +0530 Subject: [PATCH 30/42] feat: Improve TOTP implementation - Add CHECK constraints in totp used codes table - Supress InterruptedException error and retry - Send 200 with retryAfter in body instead of 429 - Seperate test for CHECK constraint only for inmemorydb --- cli/bin/main/install-linux.sh | 13 --- cli/bin/main/install-windows.bat | 12 --- ...rtokens.featureflag.EEFeatureFlagInterface | 1 - .../inmemorydb/queries/TOTPQueries.java | 9 +- src/main/java/io/supertokens/totp/Totp.java | 30 ++++-- .../webserver/api/totp/VerifyTotpAPI.java | 7 +- .../api/totp/VerifyTotpDeviceAPI.java | 2 +- .../io/supertokens/test/StorageLayerTest.java | 100 ++++++++++++++++++ .../supertokens/test/totp/TOTPRecipeTest.java | 36 ++++--- .../test/totp/TOTPStorageTest.java | 30 +++--- 10 files changed, 174 insertions(+), 66 deletions(-) delete mode 100644 cli/bin/main/install-linux.sh delete mode 100644 cli/bin/main/install-windows.bat delete mode 100644 ee/bin/main/META-INF/services/io.supertokens.featureflag.EEFeatureFlagInterface create mode 100644 src/test/java/io/supertokens/test/StorageLayerTest.java diff --git a/cli/bin/main/install-linux.sh b/cli/bin/main/install-linux.sh deleted file mode 100644 index d29c5b040..000000000 --- a/cli/bin/main/install-linux.sh +++ /dev/null @@ -1,13 +0,0 @@ -#!/bin/bash - -ST_INSTALL_LOC=$ST_INSTALL_LOC - -if [ -f /proc/1/cgroup ] && grep docker /proc/1/cgroup -qa; then - trap 'kill -TERM $PID' TERM INT - "${ST_INSTALL_LOC}"jre/bin/java -classpath "${ST_INSTALL_LOC}cli/*" io.supertokens.cli.Main false "${ST_INSTALL_LOC}" $@ & - PID=$! - wait $PID - trap - TERM INT -else - "${ST_INSTALL_LOC}"jre/bin/java -classpath "${ST_INSTALL_LOC}cli/*" io.supertokens.cli.Main false "${ST_INSTALL_LOC}" $@ -fi diff --git a/cli/bin/main/install-windows.bat b/cli/bin/main/install-windows.bat deleted file mode 100644 index af13fb98d..000000000 --- a/cli/bin/main/install-windows.bat +++ /dev/null @@ -1,12 +0,0 @@ -@echo off -set st_install_loc=$ST_INSTALL_LOC -"%st_install_loc%jre\bin"\java -classpath "%st_install_loc%cli\*" io.supertokens.cli.Main false "%st_install_loc%\" %* -IF %errorlevel% NEQ 0 ( -echo exiting -goto:eof -) -IF "%1" == "uninstall" ( -rmdir /S /Q "%st_install_loc%" -del "%~f0" -) -:eof diff --git a/ee/bin/main/META-INF/services/io.supertokens.featureflag.EEFeatureFlagInterface b/ee/bin/main/META-INF/services/io.supertokens.featureflag.EEFeatureFlagInterface deleted file mode 100644 index d940a8488..000000000 --- a/ee/bin/main/META-INF/services/io.supertokens.featureflag.EEFeatureFlagInterface +++ /dev/null @@ -1 +0,0 @@ -io.supertokens.ee.EEFeatureFlag \ No newline at end of file diff --git a/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java b/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java index 04222139d..fb4758cc8 100644 --- a/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java +++ b/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java @@ -30,7 +30,7 @@ public static String getQueryToCreateUserDevicesTable(Start start) { + "user_id VARCHAR(128) NOT NULL," + "device_name VARCHAR(256) NOT NULL," + "secret_key VARCHAR(256) NOT NULL," + "period INTEGER NOT NULL," + "skew INTEGER NOT NULL," + "verified BOOLEAN NOT NULL," - + "PRIMARY KEY (user_id, device_name)" + + "PRIMARY KEY (user_id, device_name)," + "FOREIGN KEY (user_id) REFERENCES " + Config.getConfig(start).getTotpUsersTable() + "(user_id) ON DELETE CASCADE);"; } @@ -38,10 +38,13 @@ public static String getQueryToCreateUserDevicesTable(Start start) { public static String getQueryToCreateUsedCodesTable(Start start) { return "CREATE TABLE IF NOT EXISTS " + Config.getConfig(start).getTotpUsedCodesTable() + " (" + "user_id VARCHAR(128) NOT NULL, " - + "code VARCHAR(8) NOT NULL," + "is_valid BOOLEAN NOT NULL," + // SQLite doesn't follow VARCHAR length by default + // But we can add a check constraint to make sure the length is <= 8 + + "code VARCHAR(8) NOT NULL CHECK(LENGTH(code) <= 8)," + + "is_valid BOOLEAN NOT NULL," + "expiry_time_ms BIGINT UNSIGNED NOT NULL," + "created_time_ms BIGINT UNSIGNED NOT NULL," - + "PRIMARY KEY (user_id, created_time_ms)" + + "PRIMARY KEY (user_id, created_time_ms)," // failing without comma in postgres. validate + "FOREIGN KEY (user_id) REFERENCES " + Config.getConfig(start).getTotpUsersTable() + "(user_id) ON DELETE CASCADE);"; } diff --git a/src/main/java/io/supertokens/totp/Totp.java b/src/main/java/io/supertokens/totp/Totp.java index 99041c790..f679cac2e 100644 --- a/src/main/java/io/supertokens/totp/Totp.java +++ b/src/main/java/io/supertokens/totp/Totp.java @@ -42,7 +42,7 @@ private static String generateSecret() throws NoSuchAlgorithmException { private static boolean checkCode(TOTPDevice device, String code) { final TimeBasedOneTimePasswordGenerator totp = new TimeBasedOneTimePasswordGenerator( - Duration.ofSeconds(device.period)); + Duration.ofSeconds(device.period), 6); byte[] keyBytes = new Base32().decode(device.secretKey); Key key = new SecretKeySpec(keyBytes, "HmacSHA1"); @@ -80,7 +80,7 @@ public static TOTPDevice registerDevice(Main main, String userId, String deviceN private static void checkAndStoreCode(Main main, TOTPStorage totpStorage, String userId, TOTPDevice[] devices, String code) throws InvalidTotpException, TotpNotEnabledException, - LimitReachedException, StorageQueryException, StorageTransactionLogicException, InterruptedException { + LimitReachedException, StorageQueryException, StorageTransactionLogicException { // Note that the TOTP cron runs every 1 hour, so all the expired tokens can stay // in the db for at max 1 hour after expiry. @@ -215,9 +215,15 @@ private static void checkAndStoreCode(Main main, TOTPStorage totpStorage, String } else if (e.actualException instanceof TotpNotEnabledException) { throw (TotpNotEnabledException) e.actualException; } else if (e.actualException instanceof UsedCodeAlreadyExistsException) { - // retry the transaction after 3ms (not sec) - Thread.sleep(3); - continue; + // retry the transaction after a small delay: + int delayInMs = (int) (Math.random() * 10 + 1); + try { + Thread.sleep(delayInMs); + continue; + } catch (InterruptedException err) { + // ignore the error and retry + continue; + } } else { throw e; } @@ -227,7 +233,7 @@ private static void checkAndStoreCode(Main main, TOTPStorage totpStorage, String public static boolean verifyDevice(Main main, String userId, String deviceName, String code) throws TotpNotEnabledException, UnknownDeviceException, InvalidTotpException, - LimitReachedException, StorageQueryException, StorageTransactionLogicException, InterruptedException { + LimitReachedException, StorageQueryException, StorageTransactionLogicException { // Here boolean return value tells whether the device has been // newly verified (true) OR it was already verified (false) @@ -260,6 +266,12 @@ public static boolean verifyDevice(Main main, String userId, String deviceName, throw new UnknownDeviceException(); } + // At this point, even if device is suddenly deleted/renamed by another API + // call. We will still check the code against the new set of devices and store + // it in the used codes table. However, the device will not be marked as + // verified in the devices table (because it was deleted/renamed). So the user + // gets a UnknownDevceException. + // This behaviour is okay so we can ignore it. checkAndStoreCode(main, totpStorage, userId, new TOTPDevice[] { matchingDevice }, code); // Will reach here only if the code is valid: totpStorage.markDeviceAsVerified(userId, deviceName); @@ -268,7 +280,7 @@ public static boolean verifyDevice(Main main, String userId, String deviceName, public static void verifyCode(Main main, String userId, String code, boolean allowUnverifiedDevices) throws TotpNotEnabledException, InvalidTotpException, LimitReachedException, - StorageQueryException, StorageTransactionLogicException, InterruptedException { + StorageQueryException, StorageTransactionLogicException { TOTPSQLStorage totpStorage = StorageLayer.getTOTPStorage(main); // Check if the user has any devices: @@ -282,6 +294,10 @@ public static void verifyCode(Main main, String userId, String code, boolean all devices = Arrays.stream(devices).filter(device -> device.verified).toArray(TOTPDevice[]::new); } + // At this point, even if some of the devices are suddenly deleted/renamed by + // another API call. We will still check the code against the updated set of + // devices and store it in the used codes table. This behaviour is okay so we + // can ignore it. checkAndStoreCode(main, totpStorage, userId, devices, code); } diff --git a/src/main/java/io/supertokens/webserver/api/totp/VerifyTotpAPI.java b/src/main/java/io/supertokens/webserver/api/totp/VerifyTotpAPI.java index 31a08bda6..4927af36e 100644 --- a/src/main/java/io/supertokens/webserver/api/totp/VerifyTotpAPI.java +++ b/src/main/java/io/supertokens/webserver/api/totp/VerifyTotpAPI.java @@ -61,10 +61,9 @@ protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws I super.sendJsonResponse(200, result, resp); } catch (LimitReachedException e) { result.addProperty("status", "LIMIT_REACHED_ERROR"); - // Also return a retryAfter value: - resp.addHeader("Retry-After", Integer.toString(e.retryInSeconds)); - super.sendJsonResponse(429, result, resp); // 429 Too Many Requests - } catch (StorageQueryException | StorageTransactionLogicException | InterruptedException e) { + result.addProperty("retryAfter", e.retryInSeconds); + super.sendJsonResponse(200, result, resp); + } catch (StorageQueryException | StorageTransactionLogicException e) { throw new ServletException(e); } } diff --git a/src/main/java/io/supertokens/webserver/api/totp/VerifyTotpDeviceAPI.java b/src/main/java/io/supertokens/webserver/api/totp/VerifyTotpDeviceAPI.java index 150f8ccf4..55f27b5bc 100644 --- a/src/main/java/io/supertokens/webserver/api/totp/VerifyTotpDeviceAPI.java +++ b/src/main/java/io/supertokens/webserver/api/totp/VerifyTotpDeviceAPI.java @@ -71,7 +71,7 @@ protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws I // Also return a retryAfter value: resp.addHeader("Retry-After", Integer.toString(e.retryInSeconds)); super.sendJsonResponse(429, result, resp); // 429 (Too Many Requests) - } catch (StorageQueryException | StorageTransactionLogicException | InterruptedException e) { + } catch (StorageQueryException | StorageTransactionLogicException e) { throw new ServletException(e); } } diff --git a/src/test/java/io/supertokens/test/StorageLayerTest.java b/src/test/java/io/supertokens/test/StorageLayerTest.java new file mode 100644 index 000000000..f1f420334 --- /dev/null +++ b/src/test/java/io/supertokens/test/StorageLayerTest.java @@ -0,0 +1,100 @@ +package io.supertokens.test; + +import io.supertokens.ProcessState; +import io.supertokens.inmemorydb.Start; +import io.supertokens.inmemorydb.config.Config; +import io.supertokens.pluginInterface.STORAGE_TYPE; +import io.supertokens.pluginInterface.exceptions.StorageQueryException; +import io.supertokens.pluginInterface.exceptions.StorageTransactionLogicException; +import io.supertokens.pluginInterface.totp.TOTPDevice; +import io.supertokens.pluginInterface.totp.TOTPUsedCode; +import io.supertokens.pluginInterface.totp.exception.TotpNotEnabledException; +import io.supertokens.pluginInterface.totp.exception.UsedCodeAlreadyExistsException; +import io.supertokens.pluginInterface.totp.sqlStorage.TOTPSQLStorage; +import io.supertokens.storageLayer.StorageLayer; + +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TestRule; + +import static org.junit.Assert.assertNotNull; + +public class StorageLayerTest { + + @Rule + public TestRule watchman = Utils.getOnFailure(); + + @AfterClass + public static void afterTesting() { + Utils.afterTesting(); + } + + @Before + public void beforeEach() { + Utils.reset(); + } + + // TOTP recipe: + + public static void insertUsedCodeUtil(TOTPSQLStorage storage, TOTPUsedCode usedCode) throws Exception { + try { + storage.startTransaction(con -> { + try { + storage.insertUsedCode_Transaction(con, usedCode); + storage.commitTransaction(con); + return null; + } catch (TotpNotEnabledException | UsedCodeAlreadyExistsException e) { + throw new StorageTransactionLogicException(e); + } + }); + } catch (StorageTransactionLogicException e) { + Exception actual = e.actualException; + if (actual instanceof TotpNotEnabledException || actual instanceof UsedCodeAlreadyExistsException) { + throw actual; + } else { + throw e; + } + } + } + + @Test + public void totpCodeLengthTest() throws Exception { + String[] args = { "../" }; + + TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); + process.getProcess().setForceInMemoryDB(); // this test is only for SQLite + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + + if (StorageLayer.getStorage(process.getProcess()).getType() != STORAGE_TYPE.SQL) { + assert (false); + } + TOTPSQLStorage storage = StorageLayer.getTOTPStorage(process.getProcess()); + long now = System.currentTimeMillis(); + long nextDay = now + 1000 * 60 * 60 * 24; // 1 day from now + + Start start = (Start) StorageLayer.getStorage(process.getProcess()); + + TOTPDevice d1 = new TOTPDevice("user", "d1", "secret", 30, 1, false); + storage.createDevice(d1); + + // Try code with length > 8 + try { + + TOTPUsedCode code = new TOTPUsedCode("user", "123456789", true, nextDay, now); + insertUsedCodeUtil(storage, code); + assert (false); + } catch (StorageQueryException e) { + // This error will be different in Postgres and MySQL + // We added (CHECK (LENGTH(code) <= 8)) to the table definition in SQLite + String totpUsedCodeTable = Config.getConfig(start).getTotpUsedCodesTable(); + assert e.getMessage().contains("CHECK constraint failed: " + totpUsedCodeTable); + } + + // Try code with length < 8 + TOTPUsedCode code = new TOTPUsedCode("user", "12345678", true, nextDay, now); + insertUsedCodeUtil(storage, code); + } + +} diff --git a/src/test/java/io/supertokens/test/totp/TOTPRecipeTest.java b/src/test/java/io/supertokens/test/totp/TOTPRecipeTest.java index 4cbf6c30f..c998681dd 100644 --- a/src/test/java/io/supertokens/test/totp/TOTPRecipeTest.java +++ b/src/test/java/io/supertokens/test/totp/TOTPRecipeTest.java @@ -146,7 +146,7 @@ public void createDeviceAndVerifyCodeTest() throws Exception { Main main = result.process.getProcess(); // Create device - TOTPDevice device = Totp.registerDevice(main, "user", "device", 1, 1); + TOTPDevice device = Totp.registerDevice(main, "user", "device1", 1, 1); // Try login with non-existent user: assertThrows(TotpNotEnabledException.class, @@ -156,11 +156,11 @@ public void createDeviceAndVerifyCodeTest() throws Exception { // Invalid code & allowUnverifiedDevice = true: assertThrows(InvalidTotpException.class, - () -> Totp.verifyCode(main, "user", "invalid-code", true)); + () -> Totp.verifyCode(main, "user", "invalid", true)); // Invalid code & allowUnverifiedDevice = false: assertThrows(InvalidTotpException.class, - () -> Totp.verifyCode(main, "user", "invalid-code", false)); + () -> Totp.verifyCode(main, "user", "invalid", false)); // Valid code & allowUnverifiedDevice = false: assertThrows( @@ -210,6 +210,16 @@ public void createDeviceAndVerifyCodeTest() throws Exception { String currentValidCode2 = generateTotpCode(main, device2); Totp.verifyCode(main, "user", currentValidCode2, true); + + // Submit invalid code and check that it's expiry time is correct + // created - expiryTime = max of ((2 * skew + 1) * period) for all devices + assertThrows(InvalidTotpException.class, + () -> Totp.verifyCode(main, "user", "invalid", true)); + + TOTPUsedCode[] usedCodes = getAllUsedCodesUtil(result.storage, "user"); + TOTPUsedCode latestCode = usedCodes[0]; + assert latestCode.isValid == false; + assert latestCode.expiryTime - latestCode.createdTime == 3000; // it should be 3s because of device1 } /* @@ -223,7 +233,7 @@ public int triggerAndCheckRateLimit(Main main, TOTPDevice device) throws Excepti // First N attempts should fail with invalid code: // This is to trigger rate limiting for (int i = 0; i < N; i++) { - String code = "invalid-code-" + i; + String code = "ic-" + i; assertThrows( InvalidTotpException.class, () -> Totp.verifyCode(main, "user", code, true)); @@ -233,13 +243,13 @@ public int triggerAndCheckRateLimit(Main main, TOTPDevice device) throws Excepti // This should happen until rate limiting cooldown happens: assertThrows( LimitReachedException.class, - () -> Totp.verifyCode(main, "user", "invalid-code-N+1", true)); + () -> Totp.verifyCode(main, "user", "icN+1", true)); assertThrows( LimitReachedException.class, () -> Totp.verifyCode(main, "user", generateTotpCode(main, device), true)); assertThrows( LimitReachedException.class, - () -> Totp.verifyCode(main, "user", "invalid-code-N+2", true)); + () -> Totp.verifyCode(main, "user", "icN+2", true)); return N; } @@ -271,7 +281,7 @@ public void rateLimitCooldownTest() throws Exception { // Wait for 1 second (Should cool down rate limiting): Thread.sleep(1000); // But again try with invalid code: - assertThrows(InvalidTotpException.class, () -> Totp.verifyCode(main, "user", "yet-another-invalid-code", true)); + assertThrows(InvalidTotpException.class, () -> Totp.verifyCode(main, "user", "invalid0", true)); // This triggered rate limiting again. So even valid codes will fail for // another cooldown period: assertThrows(LimitReachedException.class, @@ -281,7 +291,7 @@ public void rateLimitCooldownTest() throws Exception { // Now try with valid code: Totp.verifyCode(main, "user", generateTotpCode(main, device), true); // Now invalid code shouldn't trigger rate limiting. Unless you do it N times: - assertThrows(InvalidTotpException.class, () -> Totp.verifyCode(main, "user", "some-invalid-code", true)); + assertThrows(InvalidTotpException.class, () -> Totp.verifyCode(main, "user", "invaldd", true)); } @Test @@ -306,7 +316,7 @@ public void cronRemovesAllCodesDuringRateLimitTest() throws Exception { assertThrows(LimitReachedException.class, () -> Totp.verifyCode(main, "user", generateTotpCode(main, device), true)); assertThrows(LimitReachedException.class, - () -> Totp.verifyCode(main, "user", "again-wrong-code1", true)); + () -> Totp.verifyCode(main, "user", "yet-ic", true)); } @Test @@ -326,7 +336,7 @@ public void createAndVerifyDeviceTest() throws Exception { () -> Totp.verifyDevice(main, "user", "non-existent-device", "XXXX")); // Verify device with wrong code - assertThrows(InvalidTotpException.class, () -> Totp.verifyDevice(main, "user", "deviceName", "wrong-code")); + assertThrows(InvalidTotpException.class, () -> Totp.verifyDevice(main, "user", "deviceName", "ic0")); // Verify device with correct code String validCode = generateTotpCode(main, device); @@ -343,7 +353,7 @@ public void createAndVerifyDeviceTest() throws Exception { assert !justVerfied; // Verify again with wrong code: - justVerfied = Totp.verifyDevice(main, "user", "deviceName", "wrong-code"); + justVerfied = Totp.verifyDevice(main, "user", "deviceName", "ic1"); assert !justVerfied; result.process.kill(); @@ -372,7 +382,7 @@ public void removeDeviceTest() throws Exception { // Delete one of the devices { - assertThrows(InvalidTotpException.class, () -> Totp.verifyCode(main, "user", "invalid-code", true)); + assertThrows(InvalidTotpException.class, () -> Totp.verifyCode(main, "user", "ic0", true)); Totp.verifyCode(main, "user", generateTotpCode(main, device1), true); Totp.verifyCode(main, "user", generateTotpCode(main, device2), true); @@ -394,7 +404,7 @@ public void removeDeviceTest() throws Exception { // Create another user to test that other users aren't affected: TOTPDevice otherUserDevice = Totp.registerDevice(main, "other-user", "device", 1, 30); Totp.verifyCode(main, "other-user", generateTotpCode(main, otherUserDevice), true); - assertThrows(InvalidTotpException.class, () -> Totp.verifyCode(main, "other-user", "invalid-code", true)); + assertThrows(InvalidTotpException.class, () -> Totp.verifyCode(main, "other-user", "ic1", true)); // Delete device2 Totp.removeDevice(main, "user", "device2"); diff --git a/src/test/java/io/supertokens/test/totp/TOTPStorageTest.java b/src/test/java/io/supertokens/test/totp/TOTPStorageTest.java index b82864845..766ba5b88 100644 --- a/src/test/java/io/supertokens/test/totp/TOTPStorageTest.java +++ b/src/test/java/io/supertokens/test/totp/TOTPStorageTest.java @@ -78,7 +78,7 @@ private static TOTPUsedCode[] getAllUsedCodesUtil(TOTPStorage storage, String us }); } - private static void insertUsedCodesUtil(TOTPSQLStorage storage, TOTPUsedCode[] usedCodes) + public static void insertUsedCodesUtil(TOTPSQLStorage storage, TOTPUsedCode[] usedCodes) throws StorageQueryException, StorageTransactionLogicException, TotpNotEnabledException, UsedCodeAlreadyExistsException { try { @@ -101,7 +101,7 @@ private static void insertUsedCodesUtil(TOTPSQLStorage storage, TOTPUsedCode[] u } else if (actual instanceof UsedCodeAlreadyExistsException) { throw (UsedCodeAlreadyExistsException) actual; } - throw new StorageQueryException(e); + throw e; } } @@ -401,12 +401,12 @@ public void getAllUsedCodesTest() throws Exception { long prevDay = now - 1000 * 60 * 60 * 24; // 1 day ago TOTPDevice device = new TOTPDevice("user", "device", "secretKey", 30, 1, false); - TOTPUsedCode validCode1 = new TOTPUsedCode("user", "valid-code-1", true, nextDay, now + 1); - TOTPUsedCode invalidCode = new TOTPUsedCode("user", "invalid-code", false, nextDay, now + 2); - TOTPUsedCode expiredCode = new TOTPUsedCode("user", "expired-code", true, prevDay, now + 3); - TOTPUsedCode expiredInvalidCode = new TOTPUsedCode("user", "expired-invalid-code", false, prevDay, now + 4); - TOTPUsedCode validCode2 = new TOTPUsedCode("user", "valid-code-2", true, nextDay, now + 5); - TOTPUsedCode validCode3 = new TOTPUsedCode("user", "valid-code-3", true, nextDay, now + 6); + TOTPUsedCode validCode1 = new TOTPUsedCode("user", "valid1", true, nextDay, now + 1); + TOTPUsedCode invalidCode = new TOTPUsedCode("user", "invalid", false, nextDay, now + 2); + TOTPUsedCode expiredCode = new TOTPUsedCode("user", "expired", true, prevDay, now + 3); + TOTPUsedCode expiredInvalidCode = new TOTPUsedCode("user", "ex-in", false, prevDay, now + 4); + TOTPUsedCode validCode2 = new TOTPUsedCode("user", "valid2", true, nextDay, now + 5); + TOTPUsedCode validCode3 = new TOTPUsedCode("user", "valid3", true, nextDay, now + 6); storage.createDevice(device); insertUsedCodesUtil(storage, new TOTPUsedCode[] { @@ -415,6 +415,12 @@ public void getAllUsedCodesTest() throws Exception { validCode2, validCode3 }); + // Try to create a code with same user and created time. It should fail: + assertThrows(UsedCodeAlreadyExistsException.class, + () -> insertUsedCodesUtil(storage, new TOTPUsedCode[] { + new TOTPUsedCode("user", "any-code", true, nextDay, now + 1) + })); + usedCodes = getAllUsedCodesUtil(storage, "user"); assert (usedCodes.length == 6); @@ -438,10 +444,10 @@ public void removeExpiredCodesTest() throws Exception { long halfSecond = System.currentTimeMillis() + 500; // 500ms from now TOTPDevice device = new TOTPDevice("user", "device", "secretKey", 30, 1, false); - TOTPUsedCode validCodeToLive = new TOTPUsedCode("user", "valid-code", true, nextDay, now); - TOTPUsedCode invalidCodeToLive = new TOTPUsedCode("user", "invalid-code", false, nextDay, now + 1); - TOTPUsedCode validCodeToExpire = new TOTPUsedCode("user", "valid-code", true, halfSecond, now + 2); - TOTPUsedCode invalidCodeToExpire = new TOTPUsedCode("user", "invalid-code", false, halfSecond, now + 3); + TOTPUsedCode validCodeToLive = new TOTPUsedCode("user", "valid", true, nextDay, now); + TOTPUsedCode invalidCodeToLive = new TOTPUsedCode("user", "invalid", false, nextDay, now + 1); + TOTPUsedCode validCodeToExpire = new TOTPUsedCode("user", "valid", true, halfSecond, now + 2); + TOTPUsedCode invalidCodeToExpire = new TOTPUsedCode("user", "invalid", false, halfSecond, now + 3); storage.createDevice(device); insertUsedCodesUtil(storage, new TOTPUsedCode[] { From c33fb260e5047827cfb1325181f2a9fbb2d46c36 Mon Sep 17 00:00:00 2001 From: KShivendu Date: Wed, 15 Mar 2023 18:42:08 +0530 Subject: [PATCH 31/42] test: Add API layer test for TOTP recipe - Fixed get devices, verify TOTP, and verify device API - Add tests for all the APIs covering all exceptions --- src/main/java/io/supertokens/totp/Totp.java | 4 +- .../webserver/api/totp/GetTotpDevicesAPI.java | 4 +- .../webserver/api/totp/VerifyTotpAPI.java | 2 +- .../api/totp/VerifyTotpDeviceAPI.java | 5 +- .../io/supertokens/test/StorageLayerTest.java | 1 - .../io/supertokens/test/totp/TOTPApiTest.java | 67 ----- .../supertokens/test/totp/TOTPRecipeTest.java | 2 +- .../totp/totp/CreateTotpDeviceAPITest.java | 146 +++++++++++ .../test/totp/totp/GetTotpDevicesAPITest.java | 158 ++++++++++++ .../totp/totp/RemoveTotpDeviceAPITest.java | 182 +++++++++++++ .../totp/totp/UpdateTotpDeviceAPITest.java | 203 +++++++++++++++ .../test/totp/totp/VerifyTotpAPITest.java | 222 ++++++++++++++++ .../totp/totp/VerifyTotpDeviceAPITest.java | 242 ++++++++++++++++++ 13 files changed, 1160 insertions(+), 78 deletions(-) delete mode 100644 src/test/java/io/supertokens/test/totp/TOTPApiTest.java create mode 100644 src/test/java/io/supertokens/test/totp/totp/CreateTotpDeviceAPITest.java create mode 100644 src/test/java/io/supertokens/test/totp/totp/GetTotpDevicesAPITest.java create mode 100644 src/test/java/io/supertokens/test/totp/totp/RemoveTotpDeviceAPITest.java create mode 100644 src/test/java/io/supertokens/test/totp/totp/UpdateTotpDeviceAPITest.java create mode 100644 src/test/java/io/supertokens/test/totp/totp/VerifyTotpAPITest.java create mode 100644 src/test/java/io/supertokens/test/totp/totp/VerifyTotpDeviceAPITest.java diff --git a/src/main/java/io/supertokens/totp/Totp.java b/src/main/java/io/supertokens/totp/Totp.java index f679cac2e..668226096 100644 --- a/src/main/java/io/supertokens/totp/Totp.java +++ b/src/main/java/io/supertokens/totp/Totp.java @@ -139,8 +139,8 @@ private static void checkAndStoreCode(Main main, TOTPStorage totpStorage, String if (now - latestInvalidCodeCreatedTime < rateLimitResetTimeInMs) { // Less than rateLimitResetTimeInMs (default = 15 mins) time has elasped since // the last invalid code: - throw new StorageTransactionLogicException( - new LimitReachedException(rateLimitResetTimeInMs / 1000)); + int timeLeftMs = (int) (rateLimitResetTimeInMs - (now - latestInvalidCodeCreatedTime)); + throw new StorageTransactionLogicException(new LimitReachedException(timeLeftMs / 1000)); // If we insert the used code here, then it will further delay the user from // being able to login. So not inserting it here. diff --git a/src/main/java/io/supertokens/webserver/api/totp/GetTotpDevicesAPI.java b/src/main/java/io/supertokens/webserver/api/totp/GetTotpDevicesAPI.java index 122e8be8e..e8a9d156d 100644 --- a/src/main/java/io/supertokens/webserver/api/totp/GetTotpDevicesAPI.java +++ b/src/main/java/io/supertokens/webserver/api/totp/GetTotpDevicesAPI.java @@ -31,9 +31,7 @@ public String getPath() { @Override protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws IOException, ServletException { - JsonObject input = InputParser.parseJsonObjectOrThrowError(req); - - String userId = InputParser.parseStringOrThrowError(input, "userId", false); + String userId = InputParser.getQueryParamOrThrowError(req, "userId", false); if (userId.isEmpty()) { throw new ServletException(new BadRequestException("userId cannot be empty")); diff --git a/src/main/java/io/supertokens/webserver/api/totp/VerifyTotpAPI.java b/src/main/java/io/supertokens/webserver/api/totp/VerifyTotpAPI.java index 4927af36e..0f83bbd4e 100644 --- a/src/main/java/io/supertokens/webserver/api/totp/VerifyTotpAPI.java +++ b/src/main/java/io/supertokens/webserver/api/totp/VerifyTotpAPI.java @@ -61,7 +61,7 @@ protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws I super.sendJsonResponse(200, result, resp); } catch (LimitReachedException e) { result.addProperty("status", "LIMIT_REACHED_ERROR"); - result.addProperty("retryAfter", e.retryInSeconds); + result.addProperty("retryAfterSec", e.retryInSeconds); super.sendJsonResponse(200, result, resp); } catch (StorageQueryException | StorageTransactionLogicException e) { throw new ServletException(e); diff --git a/src/main/java/io/supertokens/webserver/api/totp/VerifyTotpDeviceAPI.java b/src/main/java/io/supertokens/webserver/api/totp/VerifyTotpDeviceAPI.java index 55f27b5bc..741729d40 100644 --- a/src/main/java/io/supertokens/webserver/api/totp/VerifyTotpDeviceAPI.java +++ b/src/main/java/io/supertokens/webserver/api/totp/VerifyTotpDeviceAPI.java @@ -68,9 +68,8 @@ protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws I super.sendJsonResponse(200, result, resp); } catch (LimitReachedException e) { result.addProperty("status", "LIMIT_REACHED_ERROR"); - // Also return a retryAfter value: - resp.addHeader("Retry-After", Integer.toString(e.retryInSeconds)); - super.sendJsonResponse(429, result, resp); // 429 (Too Many Requests) + result.addProperty("retryAfterSec", e.retryInSeconds); + super.sendJsonResponse(200, result, resp); } catch (StorageQueryException | StorageTransactionLogicException e) { throw new ServletException(e); } diff --git a/src/test/java/io/supertokens/test/StorageLayerTest.java b/src/test/java/io/supertokens/test/StorageLayerTest.java index f1f420334..c9896d8e9 100644 --- a/src/test/java/io/supertokens/test/StorageLayerTest.java +++ b/src/test/java/io/supertokens/test/StorageLayerTest.java @@ -81,7 +81,6 @@ public void totpCodeLengthTest() throws Exception { // Try code with length > 8 try { - TOTPUsedCode code = new TOTPUsedCode("user", "123456789", true, nextDay, now); insertUsedCodeUtil(storage, code); assert (false); diff --git a/src/test/java/io/supertokens/test/totp/TOTPApiTest.java b/src/test/java/io/supertokens/test/totp/TOTPApiTest.java deleted file mode 100644 index e2ab46039..000000000 --- a/src/test/java/io/supertokens/test/totp/TOTPApiTest.java +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Copyright (c) 2023, VRAI Labs and/or its affiliates. All rights reserved. - * - * This software is licensed under the Apache License, Version 2.0 (the - * "License") as published by the Apache Software Foundation. - * - * You may not use this file except in compliance with the License. You may - * obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - */ - -package io.supertokens.test.totp; - -import com.google.gson.JsonObject; -import io.supertokens.ProcessState; -import io.supertokens.pluginInterface.STORAGE_TYPE; -import io.supertokens.storageLayer.StorageLayer; -import io.supertokens.test.TestingProcessManager; -import io.supertokens.test.Utils; -import org.junit.AfterClass; -import org.junit.Before; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TestRule; - -import static org.junit.Assert.assertNotNull; - -public class TOTPApiTest { - @Rule - public TestRule watchman = Utils.getOnFailure(); - - @AfterClass - public static void afterTesting() { - Utils.afterTesting(); - } - - @Before - public void beforeEach() { - Utils.reset(); - } - - @Test - public void testBadInput() throws Exception { - String[] args = { "../" }; - - TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); - assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); - - if (StorageLayer.getStorage(process.getProcess()).getType() != STORAGE_TYPE.SQL) { - return; - } - - String userId = "userId"; - String deviceName = "XXX"; - - // - { - JsonObject code = new JsonObject(); - - } - } -} diff --git a/src/test/java/io/supertokens/test/totp/TOTPRecipeTest.java b/src/test/java/io/supertokens/test/totp/TOTPRecipeTest.java index c998681dd..ac1f22a43 100644 --- a/src/test/java/io/supertokens/test/totp/TOTPRecipeTest.java +++ b/src/test/java/io/supertokens/test/totp/TOTPRecipeTest.java @@ -97,7 +97,7 @@ public TestSetupResult defaultInit() throws InterruptedException, IOException { return new TestSetupResult(storage, process); } - private static String generateTotpCode(Main main, TOTPDevice device) + public static String generateTotpCode(Main main, TOTPDevice device) throws InvalidKeyException, StorageQueryException { return generateTotpCode(main, device, 0); } diff --git a/src/test/java/io/supertokens/test/totp/totp/CreateTotpDeviceAPITest.java b/src/test/java/io/supertokens/test/totp/totp/CreateTotpDeviceAPITest.java new file mode 100644 index 000000000..96ca02b3e --- /dev/null +++ b/src/test/java/io/supertokens/test/totp/totp/CreateTotpDeviceAPITest.java @@ -0,0 +1,146 @@ +package io.supertokens.test.totp.totp; + +import com.google.gson.JsonObject; +import io.supertokens.ProcessState; +import io.supertokens.test.httpRequest.HttpResponseException; +import io.supertokens.pluginInterface.STORAGE_TYPE; +import io.supertokens.storageLayer.StorageLayer; +import io.supertokens.test.TestingProcessManager; +import io.supertokens.test.Utils; +import io.supertokens.test.httpRequest.HttpRequestForTesting; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TestRule; + +import static org.junit.Assert.*; + +public class CreateTotpDeviceAPITest { + + @Rule + public TestRule watchman = Utils.getOnFailure(); + + @AfterClass + public static void afterTesting() { + Utils.afterTesting(); + } + + @Before + public void beforeEach() { + Utils.reset(); + } + + private Exception createDeviceRequest(TestingProcessManager.TestingProcess process, JsonObject body) { + return assertThrows( + io.supertokens.test.httpRequest.HttpResponseException.class, + () -> HttpRequestForTesting.sendJsonPOSTRequest( + process.getProcess(), + "", + "http://localhost:3567/recipe/totp/device", + body, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + "totp")); + } + + private void checkFieldMissingErrorResponse(Exception ex, String fieldName) { + assert ex instanceof HttpResponseException; + HttpResponseException e = (HttpResponseException) ex; + assert e.statusCode == 400; + assertTrue(e.getMessage().contains( + "Http error. Status Code: 400. Message: Field name '" + fieldName + "' is invalid in JSON input")); + } + + private void checkResponseErrorContains(Exception ex, String msg) { + assert ex instanceof HttpResponseException; + HttpResponseException e = (HttpResponseException) ex; + assert e.statusCode == 400; + assertTrue(e.getMessage().contains(msg)); + } + + @Test + public void testApi() throws Exception { + String[] args = { "../" }; + + TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + + if (StorageLayer.getStorage(process.getProcess()).getType() != STORAGE_TYPE.SQL) { + return; + } + + JsonObject body = new JsonObject(); + + // Missing userId/deviceName/skew/period + { + Exception e = createDeviceRequest(process, body); + checkFieldMissingErrorResponse(e, "userId"); + + body.addProperty("userId", ""); + e = createDeviceRequest(process, body); + checkFieldMissingErrorResponse(e, "deviceName"); + + body.addProperty("deviceName", ""); + e = createDeviceRequest(process, body); + checkFieldMissingErrorResponse(e, "skew"); + + body.addProperty("skew", -1); + e = createDeviceRequest(process, body); + checkFieldMissingErrorResponse(e, "period"); + } + + // Invalid userId/deviceName/skew/period + { + body.addProperty("period", 0); + Exception e = createDeviceRequest(process, body); + checkResponseErrorContains(e, "userId cannot be empty"); // Note that this is not a field missing error + + body.addProperty("userId", "user-id"); + e = createDeviceRequest(process, body); + checkResponseErrorContains(e, "deviceName cannot be empty"); + + body.addProperty("deviceName", "d1"); + e = createDeviceRequest(process, body); + checkResponseErrorContains(e, "skew must be >= 0"); + + body.addProperty("skew", 0); + e = createDeviceRequest(process, body); + checkResponseErrorContains(e, "period must be > 0"); + + body.addProperty("period", 30); + + // should pass now: + JsonObject res = HttpRequestForTesting.sendJsonPOSTRequest( + process.getProcess(), + "", + "http://localhost:3567/recipe/totp/device", + body, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + "totp"); + assert res.get("status").getAsString().equals("OK"); + + // try again with same device: + JsonObject res2 = HttpRequestForTesting.sendJsonPOSTRequest( + process.getProcess(), + "", + "http://localhost:3567/recipe/totp/device", + body, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + "totp"); + assert res2.get("status").getAsString().equals("DEVICE_ALREADY_EXISTS_ERROR"); + } + + process.kill(); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); + } + +} diff --git a/src/test/java/io/supertokens/test/totp/totp/GetTotpDevicesAPITest.java b/src/test/java/io/supertokens/test/totp/totp/GetTotpDevicesAPITest.java new file mode 100644 index 000000000..48ce4eab9 --- /dev/null +++ b/src/test/java/io/supertokens/test/totp/totp/GetTotpDevicesAPITest.java @@ -0,0 +1,158 @@ +package io.supertokens.test.totp.totp; + +import com.google.gson.JsonArray; +import com.google.gson.JsonObject; +import io.supertokens.ProcessState; +import io.supertokens.test.httpRequest.HttpResponseException; +import io.supertokens.pluginInterface.STORAGE_TYPE; +import io.supertokens.pluginInterface.totp.TOTPDevice; +import io.supertokens.storageLayer.StorageLayer; +import io.supertokens.test.TestingProcessManager; +import io.supertokens.test.Utils; +import io.supertokens.test.httpRequest.HttpRequestForTesting; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TestRule; + +import static org.junit.Assert.*; + +import java.util.HashMap; + +public class GetTotpDevicesAPITest { + + @Rule + public TestRule watchman = Utils.getOnFailure(); + + @AfterClass + public static void afterTesting() { + Utils.afterTesting(); + } + + @Before + public void beforeEach() { + Utils.reset(); + } + + private Exception getDevicesRequestException(TestingProcessManager.TestingProcess process, + HashMap params) { + + return assertThrows( + io.supertokens.test.httpRequest.HttpResponseException.class, + () -> HttpRequestForTesting.sendGETRequest( + process.getProcess(), + "", + "http://localhost:3567/recipe/totp/device/list", + params, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + "totp")); + } + + private void checkFieldMissingErrorResponse(Exception ex, String fieldName) { + assert ex instanceof HttpResponseException; + HttpResponseException e = (HttpResponseException) ex; + assert e.statusCode == 400; + assertTrue(e.getMessage().contains( + "Http error. Status Code: 400. Message: Field name '" + fieldName + "' is missing in GET request")); + } + + private void checkResponseErrorContains(Exception ex, String msg) { + assert ex instanceof HttpResponseException; + HttpResponseException e = (HttpResponseException) ex; + assert e.statusCode == 400; + assertTrue(e.getMessage().contains(msg)); + } + + @Test + public void testApi() throws Exception { + String[] args = { "../" }; + + TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + + if (StorageLayer.getStorage(process.getProcess()).getType() != STORAGE_TYPE.SQL) { + return; + } + + // Setup to create a device (which also creates a user) + { + JsonObject body = new JsonObject(); + body.addProperty("userId", "user-id"); + body.addProperty("deviceName", "device-name"); + body.addProperty("skew", 0); + body.addProperty("period", 30); + JsonObject res = HttpRequestForTesting.sendJsonPOSTRequest( + process.getProcess(), + "", + "http://localhost:3567/recipe/totp/device", + body, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + "totp"); + assert res.get("status").getAsString().equals("OK"); + } + + HashMap params = new HashMap<>(); + + // Missing userId + { + Exception e = getDevicesRequestException(process, params); + checkFieldMissingErrorResponse(e, "userId"); + } + + // Invalid userId + { + params.put("userId", ""); + Exception e = getDevicesRequestException(process, params); + checkResponseErrorContains(e, "userId cannot be empty"); // Note that this is not a field missing error + + params.put("userId", "user-id"); + + // should pass now: + JsonObject res = HttpRequestForTesting.sendGETRequest( + process.getProcess(), + "", + "http://localhost:3567/recipe/totp/device/list", + params, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + "totp"); + assert res.get("status").getAsString().equals("OK"); + + JsonArray devicesArr = res.get("devices").getAsJsonArray(); + JsonObject deviceJson = devicesArr.get(0).getAsJsonObject(); + + assert devicesArr.size() == 1; + assert deviceJson.get("name").getAsString().equals("device-name"); + assert deviceJson.get("period").getAsInt() == 30; + assert deviceJson.get("skew").getAsInt() == 0; + assert deviceJson.get("verified").getAsBoolean() == false; + + // try for non-existent user: + params.put("userId", "non-existent-user-id"); + JsonObject res2 = HttpRequestForTesting.sendGETRequest( + process.getProcess(), + "", + "http://localhost:3567/recipe/totp/device/list", + params, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + "totp"); + assert res2.get("status").getAsString().equals("TOTP_NOT_ENABLED_ERROR"); + } + + process.kill(); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); + } + +} diff --git a/src/test/java/io/supertokens/test/totp/totp/RemoveTotpDeviceAPITest.java b/src/test/java/io/supertokens/test/totp/totp/RemoveTotpDeviceAPITest.java new file mode 100644 index 000000000..621ba7f42 --- /dev/null +++ b/src/test/java/io/supertokens/test/totp/totp/RemoveTotpDeviceAPITest.java @@ -0,0 +1,182 @@ +package io.supertokens.test.totp.totp; + +import com.google.gson.JsonObject; +import io.supertokens.ProcessState; +import io.supertokens.test.httpRequest.HttpResponseException; +import io.supertokens.pluginInterface.STORAGE_TYPE; +import io.supertokens.storageLayer.StorageLayer; +import io.supertokens.test.TestingProcessManager; +import io.supertokens.test.Utils; +import io.supertokens.test.httpRequest.HttpRequestForTesting; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TestRule; + +import static org.junit.Assert.*; + +public class RemoveTotpDeviceAPITest { + + @Rule + public TestRule watchman = Utils.getOnFailure(); + + @AfterClass + public static void afterTesting() { + Utils.afterTesting(); + } + + @Before + public void beforeEach() { + Utils.reset(); + } + + private Exception removeDeviceRequest(TestingProcessManager.TestingProcess process, JsonObject body) { + return assertThrows( + io.supertokens.test.httpRequest.HttpResponseException.class, + () -> HttpRequestForTesting.sendJsonPOSTRequest( + process.getProcess(), + "", + "http://localhost:3567/recipe/totp/device/remove", + body, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + "totp")); + } + + private void checkFieldMissingErrorResponse(Exception ex, String fieldName) { + assert ex instanceof HttpResponseException; + HttpResponseException e = (HttpResponseException) ex; + assert e.statusCode == 400; + assertTrue(e.getMessage().contains( + "Http error. Status Code: 400. Message: Field name '" + fieldName + "' is invalid in JSON input")); + } + + private void checkResponseErrorContains(Exception ex, String msg) { + assert ex instanceof HttpResponseException; + HttpResponseException e = (HttpResponseException) ex; + assert e.statusCode == 400; + assertTrue(e.getMessage().contains(msg)); + } + + @Test + public void testApi() throws Exception { + String[] args = { "../" }; + + TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + + if (StorageLayer.getStorage(process.getProcess()).getType() != STORAGE_TYPE.SQL) { + return; + } + + // Setup user and devices: + JsonObject createDeviceReq = new JsonObject(); + createDeviceReq.addProperty("userId", "user-id"); + createDeviceReq.addProperty("deviceName", "d1"); + createDeviceReq.addProperty("period", 30); + createDeviceReq.addProperty("skew", 0); + + JsonObject createDeviceRes = HttpRequestForTesting.sendJsonPOSTRequest( + process.getProcess(), + "", + "http://localhost:3567/recipe/totp/device", + createDeviceReq, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + "totp"); + assertEquals(createDeviceRes.get("status").getAsString(), "OK"); + + // create another device d2: + createDeviceReq.addProperty("deviceName", "d2"); + JsonObject createDeviceRes2 = HttpRequestForTesting.sendJsonPOSTRequest( + process.getProcess(), + "", + "http://localhost:3567/recipe/totp/device", + createDeviceReq, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + "totp"); + assertEquals(createDeviceRes2.get("status").getAsString(), "OK"); + + // Start the actual tests for remove device API: + + JsonObject body = new JsonObject(); + + // Missing userId/deviceName + { + Exception e = removeDeviceRequest(process, body); + checkFieldMissingErrorResponse(e, "userId"); + + body.addProperty("userId", ""); + e = removeDeviceRequest(process, body); + checkFieldMissingErrorResponse(e, "deviceName"); + + } + + // Invalid userId/deviceName + { + body.addProperty("deviceName", ""); + Exception e = removeDeviceRequest(process, body); + checkResponseErrorContains(e, "userId cannot be empty"); // Note that this is not a field missing error + + body.addProperty("userId", "user-id"); + e = removeDeviceRequest(process, body); + checkResponseErrorContains(e, "deviceName cannot be empty"); + + body.addProperty("deviceName", "d1"); + + // should pass now: + JsonObject res = HttpRequestForTesting.sendJsonPOSTRequest( + process.getProcess(), + "", + "http://localhost:3567/recipe/totp/device/remove", + body, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + "totp"); + assert res.get("status").getAsString().equals("OK"); + assert res.get("didDeviceExist").getAsBoolean() == true; + + // try again with same device (still pass but didDeviceExist should be false) + JsonObject res2 = HttpRequestForTesting.sendJsonPOSTRequest( + process.getProcess(), + "", + "http://localhost:3567/recipe/totp/device/remove", + body, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + "totp"); + assert res2.get("status").getAsString().equals("OK"); + assert res2.get("didDeviceExist").getAsBoolean() == false; + + // try deleting device for a non-existent user + body.addProperty("userId", "non-existent-user"); + JsonObject res3 = HttpRequestForTesting.sendJsonPOSTRequest( + process.getProcess(), + "", + "http://localhost:3567/recipe/totp/device/remove", + body, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + "totp"); + assert res3.get("status").getAsString().equals("TOTP_NOT_ENABLED_ERROR"); + } + + process.kill(); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); + } + +} diff --git a/src/test/java/io/supertokens/test/totp/totp/UpdateTotpDeviceAPITest.java b/src/test/java/io/supertokens/test/totp/totp/UpdateTotpDeviceAPITest.java new file mode 100644 index 000000000..0e54b4962 --- /dev/null +++ b/src/test/java/io/supertokens/test/totp/totp/UpdateTotpDeviceAPITest.java @@ -0,0 +1,203 @@ +package io.supertokens.test.totp.totp; + +import com.google.gson.JsonObject; +import io.supertokens.ProcessState; +import io.supertokens.test.httpRequest.HttpResponseException; +import io.supertokens.pluginInterface.STORAGE_TYPE; +import io.supertokens.storageLayer.StorageLayer; +import io.supertokens.test.TestingProcessManager; +import io.supertokens.test.Utils; +import io.supertokens.test.httpRequest.HttpRequestForTesting; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TestRule; + +import static org.junit.Assert.*; + +public class UpdateTotpDeviceAPITest { + + @Rule + public TestRule watchman = Utils.getOnFailure(); + + @AfterClass + public static void afterTesting() { + Utils.afterTesting(); + } + + @Before + public void beforeEach() { + Utils.reset(); + } + + private Exception updateDeviceRequest(TestingProcessManager.TestingProcess process, JsonObject body) { + return assertThrows( + io.supertokens.test.httpRequest.HttpResponseException.class, + () -> HttpRequestForTesting.sendJsonPUTRequest( + process.getProcess(), + "", + "http://localhost:3567/recipe/totp/device", + body, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + "totp")); + } + + private void checkFieldMissingErrorResponse(Exception ex, String fieldName) { + assert ex instanceof HttpResponseException; + HttpResponseException e = (HttpResponseException) ex; + assert e.statusCode == 400; + assertTrue(e.getMessage().contains( + "Http error. Status Code: 400. Message: Field name '" + fieldName + "' is invalid in JSON input")); + } + + private void checkResponseErrorContains(Exception ex, String msg) { + assert ex instanceof HttpResponseException; + HttpResponseException e = (HttpResponseException) ex; + assert e.statusCode == 400; + assertTrue(e.getMessage().contains(msg)); + } + + @Test + public void testApi() throws Exception { + String[] args = { "../" }; + + TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + + if (StorageLayer.getStorage(process.getProcess()).getType() != STORAGE_TYPE.SQL) { + return; + } + + // Setup user and devices: + JsonObject createDeviceReq = new JsonObject(); + createDeviceReq.addProperty("userId", "user-id"); + createDeviceReq.addProperty("deviceName", "d1"); + createDeviceReq.addProperty("period", 30); + createDeviceReq.addProperty("skew", 0); + + JsonObject createDeviceRes = HttpRequestForTesting.sendJsonPOSTRequest( + process.getProcess(), + "", + "http://localhost:3567/recipe/totp/device", + createDeviceReq, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + "totp"); + assertEquals(createDeviceRes.get("status").getAsString(), "OK"); + + // create another device d2: + createDeviceReq.addProperty("deviceName", "d2"); + JsonObject createDeviceRes2 = HttpRequestForTesting.sendJsonPOSTRequest( + process.getProcess(), + "", + "http://localhost:3567/recipe/totp/device", + createDeviceReq, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + "totp"); + assertEquals(createDeviceRes2.get("status").getAsString(), "OK"); + + // Start the actual tests for update device API: + + JsonObject body = new JsonObject(); + + // Missing userId/deviceName/skew/period + { + Exception e = updateDeviceRequest(process, body); + checkFieldMissingErrorResponse(e, "userId"); + + body.addProperty("userId", ""); + e = updateDeviceRequest(process, body); + checkFieldMissingErrorResponse(e, "existingDeviceName"); + + body.addProperty("existingDeviceName", ""); + e = updateDeviceRequest(process, body); + checkFieldMissingErrorResponse(e, "newDeviceName"); + + } + + // Invalid userId/deviceName/skew/period + { + body.addProperty("newDeviceName", ""); + Exception e = updateDeviceRequest(process, body); + checkResponseErrorContains(e, "userId cannot be empty"); // Note that this is not a field missing error + + body.addProperty("userId", "user-id"); + e = updateDeviceRequest(process, body); + checkResponseErrorContains(e, "existingDeviceName cannot be empty"); + + body.addProperty("existingDeviceName", "d1"); + e = updateDeviceRequest(process, body); + checkResponseErrorContains(e, "newDeviceName cannot be empty"); + + body.addProperty("newDeviceName", "d1-new"); + + // should pass now: + JsonObject res = HttpRequestForTesting.sendJsonPUTRequest( + process.getProcess(), + "", + "http://localhost:3567/recipe/totp/device", + body, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + "totp"); + assert res.get("status").getAsString().equals("OK"); + + // try again with same device (has been renamed so should fail) + JsonObject res2 = HttpRequestForTesting.sendJsonPUTRequest( + process.getProcess(), + "", + "http://localhost:3567/recipe/totp/device", + body, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + "totp"); + assert res2.get("status").getAsString().equals("UNKNOWN_DEVICE_ERROR"); + + // try renaming to a device that already exists + body.addProperty("existingDeviceName", "d1-new"); + body.addProperty("newDeviceName", "d2"); + JsonObject res3 = HttpRequestForTesting.sendJsonPUTRequest( + process.getProcess(), + "", + "http://localhost:3567/recipe/totp/device", + body, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + "totp"); + assert res3.get("status").getAsString().equals("DEVICE_ALREADY_EXISTS_ERROR"); + + // try renaming to a device that already exists for a non-existent user + body.addProperty("userId", "non-existent-user"); + JsonObject res4 = HttpRequestForTesting.sendJsonPUTRequest( + process.getProcess(), + "", + "http://localhost:3567/recipe/totp/device", + body, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + "totp"); + assert res4.get("status").getAsString().equals("TOTP_NOT_ENABLED_ERROR"); + } + + process.kill(); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); + } + +} diff --git a/src/test/java/io/supertokens/test/totp/totp/VerifyTotpAPITest.java b/src/test/java/io/supertokens/test/totp/totp/VerifyTotpAPITest.java new file mode 100644 index 000000000..d8bef9e27 --- /dev/null +++ b/src/test/java/io/supertokens/test/totp/totp/VerifyTotpAPITest.java @@ -0,0 +1,222 @@ +package io.supertokens.test.totp.totp; + +import com.google.gson.JsonObject; + +import io.supertokens.ProcessState; +import io.supertokens.test.httpRequest.HttpResponseException; +import io.supertokens.test.totp.TOTPRecipeTest; +import io.supertokens.pluginInterface.STORAGE_TYPE; +import io.supertokens.pluginInterface.totp.TOTPDevice; +import io.supertokens.storageLayer.StorageLayer; +import io.supertokens.test.TestingProcessManager; +import io.supertokens.test.Utils; +import io.supertokens.test.httpRequest.HttpRequestForTesting; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TestRule; + +import static org.junit.Assert.*; + +public class VerifyTotpAPITest { + + @Rule + public TestRule watchman = Utils.getOnFailure(); + + @AfterClass + public static void afterTesting() { + Utils.afterTesting(); + } + + @Before + public void beforeEach() { + Utils.reset(); + } + + private Exception updateDeviceRequest(TestingProcessManager.TestingProcess process, JsonObject body) { + return assertThrows( + io.supertokens.test.httpRequest.HttpResponseException.class, + () -> HttpRequestForTesting.sendJsonPOSTRequest( + process.getProcess(), + "", + "http://localhost:3567/recipe/totp/verify", + body, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + "totp")); + } + + private void checkFieldMissingErrorResponse(Exception ex, String fieldName) { + assert ex instanceof HttpResponseException; + HttpResponseException e = (HttpResponseException) ex; + assert e.statusCode == 400; + assertTrue(e.getMessage().contains( + "Http error. Status Code: 400. Message: Field name '" + fieldName + "' is invalid in JSON input")); + } + + private void checkResponseErrorContains(Exception ex, String msg) { + assert ex instanceof HttpResponseException; + HttpResponseException e = (HttpResponseException) ex; + assert e.statusCode == 400; + assertTrue(e.getMessage().contains(msg)); + } + + @Test + public void testApi() throws Exception { + String[] args = { "../" }; + + // Trigger rate limiting on 1 wrong attempts: + Utils.setValueInConfig("totp_max_attempts", "1"); + // Set cooldown to 1 second: + Utils.setValueInConfig("totp_rate_limit_cooldown_sec", "1"); + + TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + + if (StorageLayer.getStorage(process.getProcess()).getType() != STORAGE_TYPE.SQL) { + return; + } + + // Setup user and devices: + JsonObject createDeviceReq = new JsonObject(); + createDeviceReq.addProperty("userId", "user-id"); + createDeviceReq.addProperty("deviceName", "deviceName"); + createDeviceReq.addProperty("period", 30); + createDeviceReq.addProperty("skew", 0); + + JsonObject createDeviceRes = HttpRequestForTesting.sendJsonPOSTRequest( + process.getProcess(), + "", + "http://localhost:3567/recipe/totp/device", + createDeviceReq, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + "totp"); + assertEquals(createDeviceRes.get("status").getAsString(), "OK"); + String secretKey = createDeviceRes.get("secret").getAsString(); + + TOTPDevice device = new TOTPDevice("user-id", "deviceName", secretKey, 30, 0, false); + + // Start the actual tests for update device API: + + JsonObject body = new JsonObject(); + + // Missing userId/deviceName/skew/period + { + Exception e = updateDeviceRequest(process, body); + checkFieldMissingErrorResponse(e, "userId"); + + body.addProperty("userId", ""); + e = updateDeviceRequest(process, body); + checkFieldMissingErrorResponse(e, "totp"); + + body.addProperty("totp", ""); + e = updateDeviceRequest(process, body); + checkFieldMissingErrorResponse(e, "allowUnverifiedDevices"); + } + + // Invalid userId/deviceName/skew/period + { + body.addProperty("allowUnverifiedDevices", true); + Exception e = updateDeviceRequest(process, body); + checkResponseErrorContains(e, "userId cannot be empty"); // Note that this is not a field missing error + + body.addProperty("userId", device.userId); + e = updateDeviceRequest(process, body); + checkResponseErrorContains(e, "totp must be 6 characters long"); + + // test totp of length 5: + body.addProperty("totp", "12345"); + e = updateDeviceRequest(process, body); + checkResponseErrorContains(e, "totp must be 6 characters long"); + + // test totp of length 8: + body.addProperty("totp", "12345678"); + e = updateDeviceRequest(process, body); + checkResponseErrorContains(e, "totp must be 6 characters long"); + + // but let's pass invalid code first + body.addProperty("totp", "123456"); + JsonObject res0 = HttpRequestForTesting.sendJsonPOSTRequest( + process.getProcess(), + "", + "http://localhost:3567/recipe/totp/verify", + body, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + "totp"); + assert res0.get("status").getAsString().equals("INVALID_TOTP_ERROR"); + + // Check that rate limiting is triggered for the user: + JsonObject res3 = HttpRequestForTesting.sendJsonPOSTRequest( + process.getProcess(), + "", + "http://localhost:3567/recipe/totp/verify", + body, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + "totp"); + assert res3.get("status").getAsString().equals("LIMIT_REACHED_ERROR"); + assert res3.get("retryAfterSec") != null; + + // wait for cooldown to end (1s) + Thread.sleep(1000); + + // should pass now on valid code + String validTotp = TOTPRecipeTest.generateTotpCode(process.getProcess(), device); + body.addProperty("totp", validTotp); + JsonObject res = HttpRequestForTesting.sendJsonPOSTRequest( + process.getProcess(), + "", + "http://localhost:3567/recipe/totp/verify", + body, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + "totp"); + assert res.get("status").getAsString().equals("OK"); + + // try to reuse the same code (replay attack) + body.addProperty("totp", "mycode"); + JsonObject res2 = HttpRequestForTesting.sendJsonPOSTRequest( + process.getProcess(), + "", + "http://localhost:3567/recipe/totp/verify", + body, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + "totp"); + assert res2.get("status").getAsString().equals("INVALID_TOTP_ERROR"); + + // try verifying device for a non-existent user + body.addProperty("userId", "non-existent-user"); + JsonObject res5 = HttpRequestForTesting.sendJsonPOSTRequest( + process.getProcess(), + "", + "http://localhost:3567/recipe/totp/verify", + body, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + "totp"); + assert res5.get("status").getAsString().equals("TOTP_NOT_ENABLED_ERROR"); + } + + process.kill(); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); + } + +} diff --git a/src/test/java/io/supertokens/test/totp/totp/VerifyTotpDeviceAPITest.java b/src/test/java/io/supertokens/test/totp/totp/VerifyTotpDeviceAPITest.java new file mode 100644 index 000000000..978282d2e --- /dev/null +++ b/src/test/java/io/supertokens/test/totp/totp/VerifyTotpDeviceAPITest.java @@ -0,0 +1,242 @@ +package io.supertokens.test.totp.totp; + +import com.google.gson.JsonObject; + +import io.supertokens.ProcessState; +import io.supertokens.test.httpRequest.HttpResponseException; +import io.supertokens.test.totp.TOTPRecipeTest; +import io.supertokens.pluginInterface.STORAGE_TYPE; +import io.supertokens.pluginInterface.totp.TOTPDevice; +import io.supertokens.storageLayer.StorageLayer; +import io.supertokens.test.TestingProcessManager; +import io.supertokens.test.Utils; +import io.supertokens.test.httpRequest.HttpRequestForTesting; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TestRule; + +import static org.junit.Assert.*; + +public class VerifyTotpDeviceAPITest { + + @Rule + public TestRule watchman = Utils.getOnFailure(); + + @AfterClass + public static void afterTesting() { + Utils.afterTesting(); + } + + @Before + public void beforeEach() { + Utils.reset(); + } + + private Exception updateDeviceRequest(TestingProcessManager.TestingProcess process, JsonObject body) { + return assertThrows( + io.supertokens.test.httpRequest.HttpResponseException.class, + () -> HttpRequestForTesting.sendJsonPOSTRequest( + process.getProcess(), + "", + "http://localhost:3567/recipe/totp/device/verify", + body, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + "totp")); + } + + private void checkFieldMissingErrorResponse(Exception ex, String fieldName) { + assert ex instanceof HttpResponseException; + HttpResponseException e = (HttpResponseException) ex; + assert e.statusCode == 400; + assertTrue(e.getMessage().contains( + "Http error. Status Code: 400. Message: Field name '" + fieldName + "' is invalid in JSON input")); + } + + private void checkResponseErrorContains(Exception ex, String msg) { + assert ex instanceof HttpResponseException; + HttpResponseException e = (HttpResponseException) ex; + assert e.statusCode == 400; + assertTrue(e.getMessage().contains(msg)); + } + + @Test + public void testApi() throws Exception { + String[] args = { "../" }; + + // Trigger rate limiting on 1 wrong attempts: + Utils.setValueInConfig("totp_max_attempts", "1"); + // Set cooldown to 1 second: + Utils.setValueInConfig("totp_rate_limit_cooldown_sec", "1"); + + TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + + if (StorageLayer.getStorage(process.getProcess()).getType() != STORAGE_TYPE.SQL) { + return; + } + + // Setup user and devices: + JsonObject createDeviceReq = new JsonObject(); + createDeviceReq.addProperty("userId", "user-id"); + createDeviceReq.addProperty("deviceName", "deviceName"); + createDeviceReq.addProperty("period", 30); + createDeviceReq.addProperty("skew", 0); + + JsonObject createDeviceRes = HttpRequestForTesting.sendJsonPOSTRequest( + process.getProcess(), + "", + "http://localhost:3567/recipe/totp/device", + createDeviceReq, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + "totp"); + assertEquals(createDeviceRes.get("status").getAsString(), "OK"); + String secretKey = createDeviceRes.get("secret").getAsString(); + + TOTPDevice device = new TOTPDevice("user-id", "deviceName", secretKey, 30, 0, false); + + // Start the actual tests for update device API: + + JsonObject body = new JsonObject(); + + // Missing userId/deviceName/skew/period + { + Exception e = updateDeviceRequest(process, body); + checkFieldMissingErrorResponse(e, "userId"); + + body.addProperty("userId", ""); + e = updateDeviceRequest(process, body); + checkFieldMissingErrorResponse(e, "deviceName"); + + body.addProperty("deviceName", ""); + e = updateDeviceRequest(process, body); + checkFieldMissingErrorResponse(e, "totp"); + } + + // Invalid userId/deviceName/skew/period + { + body.addProperty("totp", ""); + Exception e = updateDeviceRequest(process, body); + checkResponseErrorContains(e, "userId cannot be empty"); // Note that this is not a field missing error + + body.addProperty("userId", device.userId); + e = updateDeviceRequest(process, body); + checkResponseErrorContains(e, "deviceName cannot be empty"); + + body.addProperty("deviceName", device.deviceName); + e = updateDeviceRequest(process, body); + checkResponseErrorContains(e, "totp must be 6 characters long"); + + // test totp of length 5: + body.addProperty("totp", "12345"); + e = updateDeviceRequest(process, body); + checkResponseErrorContains(e, "totp must be 6 characters long"); + + // test totp of length 8: + body.addProperty("totp", "12345678"); + e = updateDeviceRequest(process, body); + checkResponseErrorContains(e, "totp must be 6 characters long"); + + // but let's pass invalid code first + body.addProperty("totp", "123456"); + JsonObject res0 = HttpRequestForTesting.sendJsonPOSTRequest( + process.getProcess(), + "", + "http://localhost:3567/recipe/totp/device/verify", + body, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + "totp"); + assert res0.get("status").getAsString().equals("INVALID_TOTP_ERROR"); + + // Check that rate limiting is triggered for the user: + JsonObject res3 = HttpRequestForTesting.sendJsonPOSTRequest( + process.getProcess(), + "", + "http://localhost:3567/recipe/totp/device/verify", + body, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + "totp"); + assert res3.get("status").getAsString().equals("LIMIT_REACHED_ERROR"); + assert res3.get("retryAfterSec") != null; + + // wait for cooldown to end (1s) + Thread.sleep(1000); + + // should pass now on valid code + String validTotp = TOTPRecipeTest.generateTotpCode(process.getProcess(), device); + body.addProperty("totp", validTotp); + JsonObject res = HttpRequestForTesting.sendJsonPOSTRequest( + process.getProcess(), + "", + "http://localhost:3567/recipe/totp/device/verify", + body, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + "totp"); + assert res.get("status").getAsString().equals("OK"); + assert res.get("wasAlreadyVerified").getAsBoolean() == false; + + // try again to verify the user with any code (valid/invalid) + body.addProperty("totp", "mycode"); + JsonObject res2 = HttpRequestForTesting.sendJsonPOSTRequest( + process.getProcess(), + "", + "http://localhost:3567/recipe/totp/device/verify", + body, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + "totp"); + assert res2.get("status").getAsString().equals("OK"); + assert res2.get("wasAlreadyVerified").getAsBoolean() == true; + + // try again with unknown device + body.addProperty("deviceName", "non-existent-device"); + JsonObject res4 = HttpRequestForTesting.sendJsonPOSTRequest( + process.getProcess(), + "", + "http://localhost:3567/recipe/totp/device/verify", + body, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + "totp"); + assert res4.get("status").getAsString().equals("UNKNOWN_DEVICE_ERROR"); + + // try verifying device for a non-existent user + body.addProperty("userId", "non-existent-user"); + JsonObject res5 = HttpRequestForTesting.sendJsonPOSTRequest( + process.getProcess(), + "", + "http://localhost:3567/recipe/totp/device/verify", + body, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + "totp"); + assert res5.get("status").getAsString().equals("TOTP_NOT_ENABLED_ERROR"); + } + + process.kill(); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); + } + +} From 235335f2da8f6d1f73a12761dba05f3e4b8d71c6 Mon Sep 17 00:00:00 2001 From: KShivendu Date: Thu, 16 Mar 2023 20:20:31 +0530 Subject: [PATCH 32/42] feat: Finish totp implementation --- CHANGELOG.md | 2 ++ .../java/io/supertokens/inmemorydb/queries/TOTPQueries.java | 2 +- src/main/java/io/supertokens/totp/Totp.java | 6 ++++-- .../supertokens/totp/exceptions/LimitReachedException.java | 6 +++--- .../io/supertokens/webserver/api/totp/VerifyTotpAPI.java | 2 +- .../supertokens/webserver/api/totp/VerifyTotpDeviceAPI.java | 2 +- .../test/totp/{totp => api}/CreateTotpDeviceAPITest.java | 2 +- .../test/totp/{totp => api}/GetTotpDevicesAPITest.java | 3 +-- .../test/totp/{totp => api}/RemoveTotpDeviceAPITest.java | 2 +- .../test/totp/{totp => api}/UpdateTotpDeviceAPITest.java | 2 +- .../test/totp/{totp => api}/VerifyTotpAPITest.java | 2 +- .../test/totp/{totp => api}/VerifyTotpDeviceAPITest.java | 2 +- 12 files changed, 18 insertions(+), 15 deletions(-) rename src/test/java/io/supertokens/test/totp/{totp => api}/CreateTotpDeviceAPITest.java (99%) rename src/test/java/io/supertokens/test/totp/{totp => api}/GetTotpDevicesAPITest.java (98%) rename src/test/java/io/supertokens/test/totp/{totp => api}/RemoveTotpDeviceAPITest.java (99%) rename src/test/java/io/supertokens/test/totp/{totp => api}/UpdateTotpDeviceAPITest.java (99%) rename src/test/java/io/supertokens/test/totp/{totp => api}/VerifyTotpAPITest.java (99%) rename src/test/java/io/supertokens/test/totp/{totp => api}/VerifyTotpDeviceAPITest.java (99%) diff --git a/CHANGELOG.md b/CHANGELOG.md index 096b79a40..96f9a4059 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,8 @@ to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## [unreleased] +- Add TOTP recipe + ## [4.4.1] - 2023-03-09 - Normalises email in all APIs in which email was not being diff --git a/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java b/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java index fb4758cc8..76d2e17e3 100644 --- a/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java +++ b/src/main/java/io/supertokens/inmemorydb/queries/TOTPQueries.java @@ -44,7 +44,7 @@ public static String getQueryToCreateUsedCodesTable(Start start) { + "is_valid BOOLEAN NOT NULL," + "expiry_time_ms BIGINT UNSIGNED NOT NULL," + "created_time_ms BIGINT UNSIGNED NOT NULL," - + "PRIMARY KEY (user_id, created_time_ms)," // failing without comma in postgres. validate + + "PRIMARY KEY (user_id, created_time_ms)," + "FOREIGN KEY (user_id) REFERENCES " + Config.getConfig(start).getTotpUsersTable() + "(user_id) ON DELETE CASCADE);"; } diff --git a/src/main/java/io/supertokens/totp/Totp.java b/src/main/java/io/supertokens/totp/Totp.java index 668226096..115633add 100644 --- a/src/main/java/io/supertokens/totp/Totp.java +++ b/src/main/java/io/supertokens/totp/Totp.java @@ -29,6 +29,8 @@ import io.supertokens.totp.exceptions.LimitReachedException; import org.apache.commons.codec.binary.Base32; +// TODO: Add test for UsedCodeAlreadyExistsException once we implement time mocking + public class Totp { private static String generateSecret() throws NoSuchAlgorithmException { // Reference: https://github.com/jchambers/java-otp @@ -139,8 +141,8 @@ private static void checkAndStoreCode(Main main, TOTPStorage totpStorage, String if (now - latestInvalidCodeCreatedTime < rateLimitResetTimeInMs) { // Less than rateLimitResetTimeInMs (default = 15 mins) time has elasped since // the last invalid code: - int timeLeftMs = (int) (rateLimitResetTimeInMs - (now - latestInvalidCodeCreatedTime)); - throw new StorageTransactionLogicException(new LimitReachedException(timeLeftMs / 1000)); + long timeLeftMs = (rateLimitResetTimeInMs - (now - latestInvalidCodeCreatedTime)); + throw new StorageTransactionLogicException(new LimitReachedException(timeLeftMs)); // If we insert the used code here, then it will further delay the user from // being able to login. So not inserting it here. diff --git a/src/main/java/io/supertokens/totp/exceptions/LimitReachedException.java b/src/main/java/io/supertokens/totp/exceptions/LimitReachedException.java index 1cf9772fe..fd70e5522 100644 --- a/src/main/java/io/supertokens/totp/exceptions/LimitReachedException.java +++ b/src/main/java/io/supertokens/totp/exceptions/LimitReachedException.java @@ -2,10 +2,10 @@ public class LimitReachedException extends Exception { - public int retryInSeconds; + public long retryInMs; - public LimitReachedException(int retryInSeconds) { + public LimitReachedException(long retryInSeconds) { super("Retry in " + retryInSeconds + " seconds"); - this.retryInSeconds = retryInSeconds; + this.retryInMs = retryInSeconds; } } diff --git a/src/main/java/io/supertokens/webserver/api/totp/VerifyTotpAPI.java b/src/main/java/io/supertokens/webserver/api/totp/VerifyTotpAPI.java index 0f83bbd4e..2b910745b 100644 --- a/src/main/java/io/supertokens/webserver/api/totp/VerifyTotpAPI.java +++ b/src/main/java/io/supertokens/webserver/api/totp/VerifyTotpAPI.java @@ -61,7 +61,7 @@ protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws I super.sendJsonResponse(200, result, resp); } catch (LimitReachedException e) { result.addProperty("status", "LIMIT_REACHED_ERROR"); - result.addProperty("retryAfterSec", e.retryInSeconds); + result.addProperty("retryAfterMs", e.retryInMs); super.sendJsonResponse(200, result, resp); } catch (StorageQueryException | StorageTransactionLogicException e) { throw new ServletException(e); diff --git a/src/main/java/io/supertokens/webserver/api/totp/VerifyTotpDeviceAPI.java b/src/main/java/io/supertokens/webserver/api/totp/VerifyTotpDeviceAPI.java index 741729d40..a31a0e326 100644 --- a/src/main/java/io/supertokens/webserver/api/totp/VerifyTotpDeviceAPI.java +++ b/src/main/java/io/supertokens/webserver/api/totp/VerifyTotpDeviceAPI.java @@ -68,7 +68,7 @@ protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws I super.sendJsonResponse(200, result, resp); } catch (LimitReachedException e) { result.addProperty("status", "LIMIT_REACHED_ERROR"); - result.addProperty("retryAfterSec", e.retryInSeconds); + result.addProperty("retryAfterMs", e.retryInMs); super.sendJsonResponse(200, result, resp); } catch (StorageQueryException | StorageTransactionLogicException e) { throw new ServletException(e); diff --git a/src/test/java/io/supertokens/test/totp/totp/CreateTotpDeviceAPITest.java b/src/test/java/io/supertokens/test/totp/api/CreateTotpDeviceAPITest.java similarity index 99% rename from src/test/java/io/supertokens/test/totp/totp/CreateTotpDeviceAPITest.java rename to src/test/java/io/supertokens/test/totp/api/CreateTotpDeviceAPITest.java index 96ca02b3e..fd15bd114 100644 --- a/src/test/java/io/supertokens/test/totp/totp/CreateTotpDeviceAPITest.java +++ b/src/test/java/io/supertokens/test/totp/api/CreateTotpDeviceAPITest.java @@ -1,4 +1,4 @@ -package io.supertokens.test.totp.totp; +package io.supertokens.test.totp.api; import com.google.gson.JsonObject; import io.supertokens.ProcessState; diff --git a/src/test/java/io/supertokens/test/totp/totp/GetTotpDevicesAPITest.java b/src/test/java/io/supertokens/test/totp/api/GetTotpDevicesAPITest.java similarity index 98% rename from src/test/java/io/supertokens/test/totp/totp/GetTotpDevicesAPITest.java rename to src/test/java/io/supertokens/test/totp/api/GetTotpDevicesAPITest.java index 48ce4eab9..93d749c0f 100644 --- a/src/test/java/io/supertokens/test/totp/totp/GetTotpDevicesAPITest.java +++ b/src/test/java/io/supertokens/test/totp/api/GetTotpDevicesAPITest.java @@ -1,11 +1,10 @@ -package io.supertokens.test.totp.totp; +package io.supertokens.test.totp.api; import com.google.gson.JsonArray; import com.google.gson.JsonObject; import io.supertokens.ProcessState; import io.supertokens.test.httpRequest.HttpResponseException; import io.supertokens.pluginInterface.STORAGE_TYPE; -import io.supertokens.pluginInterface.totp.TOTPDevice; import io.supertokens.storageLayer.StorageLayer; import io.supertokens.test.TestingProcessManager; import io.supertokens.test.Utils; diff --git a/src/test/java/io/supertokens/test/totp/totp/RemoveTotpDeviceAPITest.java b/src/test/java/io/supertokens/test/totp/api/RemoveTotpDeviceAPITest.java similarity index 99% rename from src/test/java/io/supertokens/test/totp/totp/RemoveTotpDeviceAPITest.java rename to src/test/java/io/supertokens/test/totp/api/RemoveTotpDeviceAPITest.java index 621ba7f42..abcc74451 100644 --- a/src/test/java/io/supertokens/test/totp/totp/RemoveTotpDeviceAPITest.java +++ b/src/test/java/io/supertokens/test/totp/api/RemoveTotpDeviceAPITest.java @@ -1,4 +1,4 @@ -package io.supertokens.test.totp.totp; +package io.supertokens.test.totp.api; import com.google.gson.JsonObject; import io.supertokens.ProcessState; diff --git a/src/test/java/io/supertokens/test/totp/totp/UpdateTotpDeviceAPITest.java b/src/test/java/io/supertokens/test/totp/api/UpdateTotpDeviceAPITest.java similarity index 99% rename from src/test/java/io/supertokens/test/totp/totp/UpdateTotpDeviceAPITest.java rename to src/test/java/io/supertokens/test/totp/api/UpdateTotpDeviceAPITest.java index 0e54b4962..551f92793 100644 --- a/src/test/java/io/supertokens/test/totp/totp/UpdateTotpDeviceAPITest.java +++ b/src/test/java/io/supertokens/test/totp/api/UpdateTotpDeviceAPITest.java @@ -1,4 +1,4 @@ -package io.supertokens.test.totp.totp; +package io.supertokens.test.totp.api; import com.google.gson.JsonObject; import io.supertokens.ProcessState; diff --git a/src/test/java/io/supertokens/test/totp/totp/VerifyTotpAPITest.java b/src/test/java/io/supertokens/test/totp/api/VerifyTotpAPITest.java similarity index 99% rename from src/test/java/io/supertokens/test/totp/totp/VerifyTotpAPITest.java rename to src/test/java/io/supertokens/test/totp/api/VerifyTotpAPITest.java index d8bef9e27..a6ee05aad 100644 --- a/src/test/java/io/supertokens/test/totp/totp/VerifyTotpAPITest.java +++ b/src/test/java/io/supertokens/test/totp/api/VerifyTotpAPITest.java @@ -1,4 +1,4 @@ -package io.supertokens.test.totp.totp; +package io.supertokens.test.totp.api; import com.google.gson.JsonObject; diff --git a/src/test/java/io/supertokens/test/totp/totp/VerifyTotpDeviceAPITest.java b/src/test/java/io/supertokens/test/totp/api/VerifyTotpDeviceAPITest.java similarity index 99% rename from src/test/java/io/supertokens/test/totp/totp/VerifyTotpDeviceAPITest.java rename to src/test/java/io/supertokens/test/totp/api/VerifyTotpDeviceAPITest.java index 978282d2e..55da8c14e 100644 --- a/src/test/java/io/supertokens/test/totp/totp/VerifyTotpDeviceAPITest.java +++ b/src/test/java/io/supertokens/test/totp/api/VerifyTotpDeviceAPITest.java @@ -1,4 +1,4 @@ -package io.supertokens.test.totp.totp; +package io.supertokens.test.totp.api; import com.google.gson.JsonObject; From 04e4fd795eda219b18a01f57f9f2ba6a95bde4e1 Mon Sep 17 00:00:00 2001 From: KShivendu Date: Mon, 20 Mar 2023 17:15:28 +0530 Subject: [PATCH 33/42] refactor: Update TOTP recipe vars and comments --- src/main/java/io/supertokens/totp/Totp.java | 2 -- .../totp/exceptions/LimitReachedException.java | 8 ++++---- .../io/supertokens/webserver/api/totp/VerifyTotpAPI.java | 2 +- .../webserver/api/totp/VerifyTotpDeviceAPI.java | 2 +- .../java/io/supertokens/test/totp/TOTPRecipeTest.java | 4 +++- 5 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/main/java/io/supertokens/totp/Totp.java b/src/main/java/io/supertokens/totp/Totp.java index 115633add..b055f95c7 100644 --- a/src/main/java/io/supertokens/totp/Totp.java +++ b/src/main/java/io/supertokens/totp/Totp.java @@ -29,8 +29,6 @@ import io.supertokens.totp.exceptions.LimitReachedException; import org.apache.commons.codec.binary.Base32; -// TODO: Add test for UsedCodeAlreadyExistsException once we implement time mocking - public class Totp { private static String generateSecret() throws NoSuchAlgorithmException { // Reference: https://github.com/jchambers/java-otp diff --git a/src/main/java/io/supertokens/totp/exceptions/LimitReachedException.java b/src/main/java/io/supertokens/totp/exceptions/LimitReachedException.java index fd70e5522..b7b1c8078 100644 --- a/src/main/java/io/supertokens/totp/exceptions/LimitReachedException.java +++ b/src/main/java/io/supertokens/totp/exceptions/LimitReachedException.java @@ -2,10 +2,10 @@ public class LimitReachedException extends Exception { - public long retryInMs; + public long retryAfterMs; - public LimitReachedException(long retryInSeconds) { - super("Retry in " + retryInSeconds + " seconds"); - this.retryInMs = retryInSeconds; + public LimitReachedException(long retryAfterMs) { + super("Retry in " + retryAfterMs + " ms"); + this.retryAfterMs = retryAfterMs; } } diff --git a/src/main/java/io/supertokens/webserver/api/totp/VerifyTotpAPI.java b/src/main/java/io/supertokens/webserver/api/totp/VerifyTotpAPI.java index 2b910745b..027eaca5d 100644 --- a/src/main/java/io/supertokens/webserver/api/totp/VerifyTotpAPI.java +++ b/src/main/java/io/supertokens/webserver/api/totp/VerifyTotpAPI.java @@ -61,7 +61,7 @@ protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws I super.sendJsonResponse(200, result, resp); } catch (LimitReachedException e) { result.addProperty("status", "LIMIT_REACHED_ERROR"); - result.addProperty("retryAfterMs", e.retryInMs); + result.addProperty("retryAfterMs", e.retryAfterMs); super.sendJsonResponse(200, result, resp); } catch (StorageQueryException | StorageTransactionLogicException e) { throw new ServletException(e); diff --git a/src/main/java/io/supertokens/webserver/api/totp/VerifyTotpDeviceAPI.java b/src/main/java/io/supertokens/webserver/api/totp/VerifyTotpDeviceAPI.java index a31a0e326..00f38d40f 100644 --- a/src/main/java/io/supertokens/webserver/api/totp/VerifyTotpDeviceAPI.java +++ b/src/main/java/io/supertokens/webserver/api/totp/VerifyTotpDeviceAPI.java @@ -68,7 +68,7 @@ protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws I super.sendJsonResponse(200, result, resp); } catch (LimitReachedException e) { result.addProperty("status", "LIMIT_REACHED_ERROR"); - result.addProperty("retryAfterMs", e.retryInMs); + result.addProperty("retryAfterMs", e.retryAfterMs); super.sendJsonResponse(200, result, resp); } catch (StorageQueryException | StorageTransactionLogicException e) { throw new ServletException(e); diff --git a/src/test/java/io/supertokens/test/totp/TOTPRecipeTest.java b/src/test/java/io/supertokens/test/totp/TOTPRecipeTest.java index ac1f22a43..a22dc4d67 100644 --- a/src/test/java/io/supertokens/test/totp/TOTPRecipeTest.java +++ b/src/test/java/io/supertokens/test/totp/TOTPRecipeTest.java @@ -58,6 +58,8 @@ import io.supertokens.pluginInterface.totp.exception.UnknownDeviceException; import io.supertokens.pluginInterface.totp.sqlStorage.TOTPSQLStorage; +// TODO: Add test for UsedCodeAlreadyExistsException once we implement time mocking + public class TOTPRecipeTest { @Rule @@ -295,7 +297,7 @@ public void rateLimitCooldownTest() throws Exception { } @Test - public void cronRemovesAllCodesDuringRateLimitTest() throws Exception { + public void cronRemovesCodesDuringRateLimitTest() throws Exception { // This test is flaky because of time. TestSetupResult result = defaultInit(); Main main = result.process.getProcess(); From 5ab015e73fc8928d7b27c38a407fc9648ad2781a Mon Sep 17 00:00:00 2001 From: KShivendu Date: Mon, 20 Mar 2023 17:34:08 +0530 Subject: [PATCH 34/42] chores: Mention API and DB changes for TOTP recipe in CHANGELOG --- CHANGELOG.md | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 96f9a4059..0b072c924 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,20 @@ to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). - Add TOTP recipe +### Database changes: +- Add new tables for TOTP recipe: + - `totp_users` that stores the users that have enabled TOTP + - `totp_user_devices` that stores devices (each device has its own secret) for each user + - `totp_used_codes` that stores used codes for each user. This is to implement rate limiting and prevent replay attacks. + +### New APIs: +- `POST /recipe/totp/device` to create a new device as well as the user if it doesn't exist. +- `POST /recipe/totp/device/verify` to verify a device. This is to ensure that the user has access to the device. +- `POST /recipe/totp/verify` to verify a code and continue the login flow. +- `PUT /recipe/totp/device` to update the name of a device. Name is just a string that the user can set to identify the device. +- `GET /recipe/totp/device/list` to get all devices for a user. +- `POST /recipe/totp/device/remove` to remove a device. If the user has no more devices, the user is also removed. + ## [4.4.1] - 2023-03-09 - Normalises email in all APIs in which email was not being From 4e13470c8ff403293c976e5fb2b124abcf3e5134 Mon Sep 17 00:00:00 2001 From: Kumar Shivendu Date: Tue, 21 Mar 2023 14:01:33 +0530 Subject: [PATCH 35/42] feat: Add support for active users stats (#585) * feat: Add support for active users stats * feat: Monitor active users for all auth recipes and session recipe --- src/main/java/io/supertokens/ActiveUsers.java | 14 ++++++ .../java/io/supertokens/inmemorydb/Start.java | 21 ++++++++- .../inmemorydb/config/SQLiteConfig.java | 8 +++- .../queries/ActiveUsersQueries.java | 43 +++++++++++++++++++ .../inmemorydb/queries/GeneralQueries.java | 5 +++ .../storageLayer/StorageLayer.java | 9 ++++ src/main/java/io/supertokens/totp/Totp.java | 1 + .../api/emailpassword/SignInAPI.java | 4 ++ .../api/emailpassword/SignUpAPI.java | 4 ++ .../api/passwordless/ConsumeCodeAPI.java | 4 ++ .../api/session/RefreshSessionAPI.java | 13 ++++++ .../webserver/api/session/SessionAPI.java | 13 +++++- .../api/session/SessionRemoveAPI.java | 13 ++++++ .../webserver/api/thirdparty/SignInUpAPI.java | 6 +++ 14 files changed, 154 insertions(+), 4 deletions(-) create mode 100644 src/main/java/io/supertokens/ActiveUsers.java create mode 100644 src/main/java/io/supertokens/inmemorydb/queries/ActiveUsersQueries.java diff --git a/src/main/java/io/supertokens/ActiveUsers.java b/src/main/java/io/supertokens/ActiveUsers.java new file mode 100644 index 000000000..f2ad3627e --- /dev/null +++ b/src/main/java/io/supertokens/ActiveUsers.java @@ -0,0 +1,14 @@ +package io.supertokens; + +import io.supertokens.pluginInterface.exceptions.StorageQueryException; +import io.supertokens.storageLayer.StorageLayer; + +public class ActiveUsers { + + public static void updateLastActive(Main main, String userId) { + try { + StorageLayer.getActiveUsersStorage(main).updateLastActive(userId); + } catch (StorageQueryException ignored) { + } + } +} diff --git a/src/main/java/io/supertokens/inmemorydb/Start.java b/src/main/java/io/supertokens/inmemorydb/Start.java index 07c6a625f..cbc47bc6d 100644 --- a/src/main/java/io/supertokens/inmemorydb/Start.java +++ b/src/main/java/io/supertokens/inmemorydb/Start.java @@ -24,6 +24,7 @@ import io.supertokens.emailverification.exception.EmailAlreadyVerifiedException; import io.supertokens.inmemorydb.config.Config; import io.supertokens.inmemorydb.queries.*; +import io.supertokens.pluginInterface.ActiveUsersStorage; import io.supertokens.pluginInterface.KeyValueInfo; import io.supertokens.pluginInterface.LOG_LEVEL; import io.supertokens.pluginInterface.RECIPE_ID; @@ -99,7 +100,7 @@ public class Start implements SessionSQLStorage, EmailPasswordSQLStorage, EmailVerificationSQLStorage, ThirdPartySQLStorage, JWTRecipeSQLStorage, PasswordlessSQLStorage, UserMetadataSQLStorage, UserRolesSQLStorage, UserIdMappingStorage, - DashboardSQLStorage, TOTPSQLStorage { + DashboardSQLStorage, TOTPSQLStorage, ActiveUsersStorage { private static final Object appenderLock = new Object(); private static final String APP_ID_KEY_NAME = "app_id"; @@ -441,6 +442,24 @@ public boolean doesUserIdExist(String userId) throws StorageQueryException { } } + @Override + public void updateLastActive(String userId) throws StorageQueryException { + try { + ActiveUsersQueries.updateUserLastActive(this, userId); + } catch (SQLException e) { + throw new StorageQueryException(e); + } + } + + @Override + public int countUsersActiveSince(long time) throws StorageQueryException { + try { + return ActiveUsersQueries.countUsersActiveSince(this, time); + } catch (SQLException e) { + throw new StorageQueryException(e); + } + } + @Override public SessionInfo getSessionInfo_Transaction(TransactionConnection con, String sessionHandle) throws StorageQueryException { diff --git a/src/main/java/io/supertokens/inmemorydb/config/SQLiteConfig.java b/src/main/java/io/supertokens/inmemorydb/config/SQLiteConfig.java index a903c74f8..3ddaf0ac7 100644 --- a/src/main/java/io/supertokens/inmemorydb/config/SQLiteConfig.java +++ b/src/main/java/io/supertokens/inmemorydb/config/SQLiteConfig.java @@ -26,6 +26,10 @@ public String getUsersTable() { return "all_auth_recipe_users"; } + public String getUserLastActiveTable() { + return "user_last_active"; + } + public String getAccessTokenSigningKeysTable() { return "session_access_token_signing_keys"; } @@ -102,11 +106,11 @@ public String getTotpUsedCodesTable() { return "totp_used_codes"; } - public String getDashboardUsersTable(){ + public String getDashboardUsersTable() { return "dashboard_users"; } - public String getDashboardSessionsTable(){ + public String getDashboardSessionsTable() { return "dashboard_user_sessions"; } } diff --git a/src/main/java/io/supertokens/inmemorydb/queries/ActiveUsersQueries.java b/src/main/java/io/supertokens/inmemorydb/queries/ActiveUsersQueries.java new file mode 100644 index 000000000..43fc22f52 --- /dev/null +++ b/src/main/java/io/supertokens/inmemorydb/queries/ActiveUsersQueries.java @@ -0,0 +1,43 @@ +package io.supertokens.inmemorydb.queries; + +import java.sql.SQLException; + +import io.supertokens.inmemorydb.config.Config; +import io.supertokens.pluginInterface.exceptions.StorageQueryException; +import io.supertokens.inmemorydb.Start; + +import static io.supertokens.inmemorydb.QueryExecutorTemplate.execute; +import static io.supertokens.inmemorydb.QueryExecutorTemplate.update; + +public class ActiveUsersQueries { + static String getQueryToCreateUserLastActiveTable(Start start) { + return "CREATE TABLE IF NOT EXISTS " + Config.getConfig(start).getUserLastActiveTable() + " (" + + "user_id VARCHAR(128)," + + "last_active_time BIGINT UNSIGNED," + "PRIMARY KEY(user_id)" + " );"; + } + + public static int countUsersActiveSince(Start start, long sinceTime) throws SQLException, StorageQueryException { + String QUERY = "SELECT COUNT(*) as total FROM " + Config.getConfig(start).getUserLastActiveTable() + + " WHERE last_active_time >= ?"; + + return execute(start, QUERY, pst -> pst.setLong(1, sinceTime), result -> { + if (result.next()) { + return result.getInt("total"); + } + return 0; + }); + } + + public static int updateUserLastActive(Start start, String userId) throws SQLException, StorageQueryException { + String QUERY = "INSERT INTO " + Config.getConfig(start).getUserLastActiveTable() + + "(user_id, last_active_time) VALUES(?, ?) ON CONFLICT(user_id) DO UPDATE SET last_active_time = ?"; + + long now = System.currentTimeMillis(); + return update(start, QUERY, pst -> { + pst.setString(1, userId); + pst.setLong(2, now); + pst.setLong(3, now); + }); + } + +} diff --git a/src/main/java/io/supertokens/inmemorydb/queries/GeneralQueries.java b/src/main/java/io/supertokens/inmemorydb/queries/GeneralQueries.java index 4c98a9724..d1c30556e 100644 --- a/src/main/java/io/supertokens/inmemorydb/queries/GeneralQueries.java +++ b/src/main/java/io/supertokens/inmemorydb/queries/GeneralQueries.java @@ -94,6 +94,11 @@ public static void createTablesIfNotExists(Start start, Main main) throws SQLExc update(start, getQueryToCreateUserPaginationIndex(start), NO_OP_SETTER); } + if (!doesTableExists(start, Config.getConfig(start).getUserLastActiveTable())) { + getInstance(main).addState(CREATING_NEW_TABLE, null); + update(start, ActiveUsersQueries.getQueryToCreateUserLastActiveTable(start), NO_OP_SETTER); + } + if (!doesTableExists(start, Config.getConfig(start).getAccessTokenSigningKeysTable())) { getInstance(main).addState(CREATING_NEW_TABLE, null); update(start, getQueryToCreateAccessTokenSigningKeysTable(start), NO_OP_SETTER); diff --git a/src/main/java/io/supertokens/storageLayer/StorageLayer.java b/src/main/java/io/supertokens/storageLayer/StorageLayer.java index a9cb7313c..7a35d1531 100644 --- a/src/main/java/io/supertokens/storageLayer/StorageLayer.java +++ b/src/main/java/io/supertokens/storageLayer/StorageLayer.java @@ -23,6 +23,7 @@ import io.supertokens.exceptions.QuitProgramException; import io.supertokens.inmemorydb.Start; import io.supertokens.output.Logging; +import io.supertokens.pluginInterface.ActiveUsersStorage; import io.supertokens.pluginInterface.STORAGE_TYPE; import io.supertokens.pluginInterface.Storage; import io.supertokens.pluginInterface.authRecipe.AuthRecipeStorage; @@ -175,6 +176,14 @@ public static AuthRecipeStorage getAuthRecipeStorage(Main main) { return (AuthRecipeStorage) getInstance(main).storage; } + public static ActiveUsersStorage getActiveUsersStorage(Main main) { + if (getInstance(main) == null) { + throw new QuitProgramException("please call init() before calling getStorageLayer"); + } + + return (ActiveUsersStorage) getInstance(main).storage; + } + public static SessionStorage getSessionStorage(Main main) { if (getInstance(main) == null) { throw new QuitProgramException("please call init() before calling getStorageLayer"); diff --git a/src/main/java/io/supertokens/totp/Totp.java b/src/main/java/io/supertokens/totp/Totp.java index b055f95c7..4c5040d93 100644 --- a/src/main/java/io/supertokens/totp/Totp.java +++ b/src/main/java/io/supertokens/totp/Totp.java @@ -14,6 +14,7 @@ import io.supertokens.config.Config; import com.eatthepath.otp.TimeBasedOneTimePasswordGenerator; + import io.supertokens.pluginInterface.exceptions.StorageQueryException; import io.supertokens.pluginInterface.exceptions.StorageTransactionLogicException; import io.supertokens.pluginInterface.totp.TOTPDevice; diff --git a/src/main/java/io/supertokens/webserver/api/emailpassword/SignInAPI.java b/src/main/java/io/supertokens/webserver/api/emailpassword/SignInAPI.java index 7d09ab04f..3b598601c 100644 --- a/src/main/java/io/supertokens/webserver/api/emailpassword/SignInAPI.java +++ b/src/main/java/io/supertokens/webserver/api/emailpassword/SignInAPI.java @@ -19,6 +19,8 @@ import com.google.gson.Gson; import com.google.gson.JsonObject; import com.google.gson.JsonParser; + +import io.supertokens.ActiveUsers; import io.supertokens.Main; import io.supertokens.emailpassword.EmailPassword; import io.supertokens.emailpassword.exceptions.WrongCredentialsException; @@ -65,6 +67,8 @@ protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws I try { UserInfo user = EmailPassword.signIn(super.main, normalisedEmail, password); + ActiveUsers.updateLastActive(main, user.id); // use the internal user id + // if a userIdMapping exists, pass the externalUserId to the response UserIdMapping userIdMapping = io.supertokens.useridmapping.UserIdMapping.getUserIdMapping(super.main, user.id, UserIdType.ANY); diff --git a/src/main/java/io/supertokens/webserver/api/emailpassword/SignUpAPI.java b/src/main/java/io/supertokens/webserver/api/emailpassword/SignUpAPI.java index 45310c081..6c4ed02b9 100644 --- a/src/main/java/io/supertokens/webserver/api/emailpassword/SignUpAPI.java +++ b/src/main/java/io/supertokens/webserver/api/emailpassword/SignUpAPI.java @@ -19,6 +19,8 @@ import com.google.gson.Gson; import com.google.gson.JsonObject; import com.google.gson.JsonParser; + +import io.supertokens.ActiveUsers; import io.supertokens.Main; import io.supertokens.emailpassword.EmailPassword; import io.supertokens.output.Logging; @@ -67,6 +69,8 @@ protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws I try { UserInfo user = EmailPassword.signUp(super.main, normalisedEmail, password); + ActiveUsers.updateLastActive(main, user.id); + JsonObject result = new JsonObject(); result.addProperty("status", "OK"); JsonObject userJson = new JsonParser().parse(new Gson().toJson(user)).getAsJsonObject(); diff --git a/src/main/java/io/supertokens/webserver/api/passwordless/ConsumeCodeAPI.java b/src/main/java/io/supertokens/webserver/api/passwordless/ConsumeCodeAPI.java index fb508e230..7c37c36a9 100644 --- a/src/main/java/io/supertokens/webserver/api/passwordless/ConsumeCodeAPI.java +++ b/src/main/java/io/supertokens/webserver/api/passwordless/ConsumeCodeAPI.java @@ -19,6 +19,8 @@ import com.google.gson.Gson; import com.google.gson.JsonObject; import com.google.gson.JsonParser; + +import io.supertokens.ActiveUsers; import io.supertokens.Main; import io.supertokens.passwordless.Passwordless; import io.supertokens.passwordless.Passwordless.ConsumeCodeResponse; @@ -81,6 +83,8 @@ protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws I ConsumeCodeResponse consumeCodeResponse = Passwordless.consumeCode(main, deviceId, deviceIdHash, userInputCode, linkCode); + ActiveUsers.updateLastActive(main, consumeCodeResponse.user.id); + UserIdMapping userIdMapping = io.supertokens.useridmapping.UserIdMapping.getUserIdMapping(main, consumeCodeResponse.user.id, UserIdType.ANY); if (userIdMapping != null) { diff --git a/src/main/java/io/supertokens/webserver/api/session/RefreshSessionAPI.java b/src/main/java/io/supertokens/webserver/api/session/RefreshSessionAPI.java index 60308d563..9fcc4c7c6 100644 --- a/src/main/java/io/supertokens/webserver/api/session/RefreshSessionAPI.java +++ b/src/main/java/io/supertokens/webserver/api/session/RefreshSessionAPI.java @@ -19,6 +19,8 @@ import com.google.gson.Gson; import com.google.gson.JsonObject; import com.google.gson.JsonParser; + +import io.supertokens.ActiveUsers; import io.supertokens.Main; import io.supertokens.exceptions.TokenTheftDetectedException; import io.supertokens.exceptions.UnauthorisedException; @@ -26,6 +28,8 @@ import io.supertokens.pluginInterface.RECIPE_ID; import io.supertokens.pluginInterface.exceptions.StorageQueryException; import io.supertokens.pluginInterface.exceptions.StorageTransactionLogicException; +import io.supertokens.pluginInterface.useridmapping.UserIdMapping; +import io.supertokens.useridmapping.UserIdType; import io.supertokens.session.Session; import io.supertokens.session.info.SessionInformationHolder; import io.supertokens.utils.Utils; @@ -61,6 +65,15 @@ protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws I try { SessionInformationHolder sessionInfo = Session.refreshSession(main, refreshToken, antiCsrfToken, enableAntiCsrf); + + UserIdMapping userIdMapping = io.supertokens.useridmapping.UserIdMapping.getUserIdMapping(super.main, + sessionInfo.session.userId, UserIdType.ANY); + if (userIdMapping != null) { + ActiveUsers.updateLastActive(main, userIdMapping.superTokensUserId); + } else { + ActiveUsers.updateLastActive(main, sessionInfo.session.userId); + } + JsonObject result = sessionInfo.toJsonObject(); result.addProperty("status", "OK"); super.sendJsonResponse(200, result, resp); diff --git a/src/main/java/io/supertokens/webserver/api/session/SessionAPI.java b/src/main/java/io/supertokens/webserver/api/session/SessionAPI.java index 8ef9a0441..0ea621cd9 100644 --- a/src/main/java/io/supertokens/webserver/api/session/SessionAPI.java +++ b/src/main/java/io/supertokens/webserver/api/session/SessionAPI.java @@ -19,7 +19,8 @@ import com.google.gson.Gson; import com.google.gson.JsonArray; import com.google.gson.JsonObject; -import com.google.gson.JsonParser; + +import io.supertokens.ActiveUsers; import io.supertokens.Main; import io.supertokens.exceptions.UnauthorisedException; import io.supertokens.output.Logging; @@ -27,6 +28,8 @@ import io.supertokens.pluginInterface.exceptions.StorageQueryException; import io.supertokens.pluginInterface.exceptions.StorageTransactionLogicException; import io.supertokens.pluginInterface.session.SessionInfo; +import io.supertokens.pluginInterface.useridmapping.UserIdMapping; +import io.supertokens.useridmapping.UserIdType; import io.supertokens.session.Session; import io.supertokens.session.accessToken.AccessTokenSigningKey; import io.supertokens.session.accessToken.AccessTokenSigningKey.KeyInfo; @@ -77,6 +80,14 @@ protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws I SessionInformationHolder sessionInfo = Session.createNewSession(main, userId, userDataInJWT, userDataInDatabase, enableAntiCsrf); + UserIdMapping userIdMapping = io.supertokens.useridmapping.UserIdMapping.getUserIdMapping(super.main, + sessionInfo.session.userId, UserIdType.ANY); + if (userIdMapping != null) { + ActiveUsers.updateLastActive(main, userIdMapping.superTokensUserId); + } else { + ActiveUsers.updateLastActive(main, sessionInfo.session.userId); + } + JsonObject result = sessionInfo.toJsonObject(); result.addProperty("status", "OK"); diff --git a/src/main/java/io/supertokens/webserver/api/session/SessionRemoveAPI.java b/src/main/java/io/supertokens/webserver/api/session/SessionRemoveAPI.java index 213f3e943..70d9ac092 100644 --- a/src/main/java/io/supertokens/webserver/api/session/SessionRemoveAPI.java +++ b/src/main/java/io/supertokens/webserver/api/session/SessionRemoveAPI.java @@ -19,9 +19,13 @@ import com.google.gson.JsonArray; import com.google.gson.JsonObject; import com.google.gson.JsonPrimitive; + +import io.supertokens.ActiveUsers; import io.supertokens.Main; import io.supertokens.pluginInterface.RECIPE_ID; import io.supertokens.pluginInterface.exceptions.StorageQueryException; +import io.supertokens.pluginInterface.useridmapping.UserIdMapping; +import io.supertokens.useridmapping.UserIdType; import io.supertokens.session.Session; import io.supertokens.webserver.InputParser; import io.supertokens.webserver.WebserverAPI; @@ -74,6 +78,15 @@ protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws I if (userId != null) { try { String[] sessionHandlesRevoked = Session.revokeAllSessionsForUser(main, userId); + + UserIdMapping userIdMapping = io.supertokens.useridmapping.UserIdMapping.getUserIdMapping(super.main, + userId, UserIdType.ANY); + if (userIdMapping != null) { + ActiveUsers.updateLastActive(main, userIdMapping.superTokensUserId); + } else { + ActiveUsers.updateLastActive(main, userId); + } + JsonObject result = new JsonObject(); result.addProperty("status", "OK"); JsonArray sessionHandlesRevokedJSON = new JsonArray(); diff --git a/src/main/java/io/supertokens/webserver/api/thirdparty/SignInUpAPI.java b/src/main/java/io/supertokens/webserver/api/thirdparty/SignInUpAPI.java index 7f768b26b..36bae966b 100644 --- a/src/main/java/io/supertokens/webserver/api/thirdparty/SignInUpAPI.java +++ b/src/main/java/io/supertokens/webserver/api/thirdparty/SignInUpAPI.java @@ -19,6 +19,8 @@ import com.google.gson.Gson; import com.google.gson.JsonObject; import com.google.gson.JsonParser; + +import io.supertokens.ActiveUsers; import io.supertokens.Main; import io.supertokens.pluginInterface.RECIPE_ID; import io.supertokens.pluginInterface.exceptions.StorageQueryException; @@ -70,6 +72,8 @@ protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws I ThirdParty.SignInUpResponse response = ThirdParty.signInUp2_7(super.main, thirdPartyId, thirdPartyUserId, email, isEmailVerified); + ActiveUsers.updateLastActive(main, response.user.id); + JsonObject result = new JsonObject(); result.addProperty("status", "OK"); result.addProperty("createdNewUser", response.createdNewUser); @@ -100,6 +104,8 @@ protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws I ThirdParty.SignInUpResponse response = ThirdParty.signInUp(super.main, thirdPartyId, thirdPartyUserId, email); + ActiveUsers.updateLastActive(main, response.user.id); + // io.supertokens.pluginInterface.useridmapping.UserIdMapping userIdMapping = UserIdMapping .getUserIdMapping(main, response.user.id, UserIdType.ANY); From b02a42099877dcef78d06f987f27d5c48eda7beb Mon Sep 17 00:00:00 2001 From: Kumar Shivendu Date: Wed, 22 Mar 2023 12:23:28 +0530 Subject: [PATCH 36/42] test: Add tests for active users update across different API calls (#586) * test: Add tests for active users update across different API calls * fix: Suppress usermapping exceptions for active users monitoring (#587) --- src/main/java/io/supertokens/ActiveUsers.java | 4 ++ .../api/session/RefreshSessionAPI.java | 18 +++--- .../webserver/api/session/SessionAPI.java | 15 +++-- .../api/session/SessionRemoveAPI.java | 16 +++-- .../io/supertokens/test/ActiveUsersTest.java | 61 +++++++++++++++++++ .../emailpassword/api/SignInAPITest2_7.java | 18 +++++- .../emailpassword/api/SignUpAPITest2_7.java | 15 ++++- .../PasswordlessConsumeCodeTest.java | 6 +- .../PasswordlessConsumeCodeAPITest2_11.java | 26 ++++++++ .../session/api/RefreshSessionAPITest2_7.java | 50 ++++++++++++++- .../test/session/api/SessionAPITest2_7.java | 18 ++++++ .../test/session/api/SessionAPITest2_9.java | 15 +++++ .../session/api/SessionRemoveAPITest2_7.java | 17 ++++++ .../api/ThirdPartySignInUpAPITest2_7.java | 15 +++++ .../api/ThirdPartySignInUpAPITest2_8.java | 15 +++++ 15 files changed, 281 insertions(+), 28 deletions(-) create mode 100644 src/test/java/io/supertokens/test/ActiveUsersTest.java diff --git a/src/main/java/io/supertokens/ActiveUsers.java b/src/main/java/io/supertokens/ActiveUsers.java index f2ad3627e..6503389f9 100644 --- a/src/main/java/io/supertokens/ActiveUsers.java +++ b/src/main/java/io/supertokens/ActiveUsers.java @@ -11,4 +11,8 @@ public static void updateLastActive(Main main, String userId) { } catch (StorageQueryException ignored) { } } + + public static int countUsersActiveSince(Main main, long time) throws StorageQueryException { + return StorageLayer.getActiveUsersStorage(main).countUsersActiveSince(time); + } } diff --git a/src/main/java/io/supertokens/webserver/api/session/RefreshSessionAPI.java b/src/main/java/io/supertokens/webserver/api/session/RefreshSessionAPI.java index 9fcc4c7c6..c1220aa8a 100644 --- a/src/main/java/io/supertokens/webserver/api/session/RefreshSessionAPI.java +++ b/src/main/java/io/supertokens/webserver/api/session/RefreshSessionAPI.java @@ -16,9 +16,7 @@ package io.supertokens.webserver.api.session; -import com.google.gson.Gson; import com.google.gson.JsonObject; -import com.google.gson.JsonParser; import io.supertokens.ActiveUsers; import io.supertokens.Main; @@ -66,14 +64,18 @@ protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws I SessionInformationHolder sessionInfo = Session.refreshSession(main, refreshToken, antiCsrfToken, enableAntiCsrf); - UserIdMapping userIdMapping = io.supertokens.useridmapping.UserIdMapping.getUserIdMapping(super.main, - sessionInfo.session.userId, UserIdType.ANY); - if (userIdMapping != null) { - ActiveUsers.updateLastActive(main, userIdMapping.superTokensUserId); - } else { - ActiveUsers.updateLastActive(main, sessionInfo.session.userId); + try{ + UserIdMapping userIdMapping = io.supertokens.useridmapping.UserIdMapping.getUserIdMapping(super.main, + sessionInfo.session.userId, UserIdType.ANY); + if (userIdMapping != null) { + ActiveUsers.updateLastActive(main, userIdMapping.superTokensUserId); + } else { + ActiveUsers.updateLastActive(main, sessionInfo.session.userId); + } + } catch (StorageQueryException ignored){ } + JsonObject result = sessionInfo.toJsonObject(); result.addProperty("status", "OK"); super.sendJsonResponse(200, result, resp); diff --git a/src/main/java/io/supertokens/webserver/api/session/SessionAPI.java b/src/main/java/io/supertokens/webserver/api/session/SessionAPI.java index 0ea621cd9..fb297a8ff 100644 --- a/src/main/java/io/supertokens/webserver/api/session/SessionAPI.java +++ b/src/main/java/io/supertokens/webserver/api/session/SessionAPI.java @@ -80,12 +80,15 @@ protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws I SessionInformationHolder sessionInfo = Session.createNewSession(main, userId, userDataInJWT, userDataInDatabase, enableAntiCsrf); - UserIdMapping userIdMapping = io.supertokens.useridmapping.UserIdMapping.getUserIdMapping(super.main, - sessionInfo.session.userId, UserIdType.ANY); - if (userIdMapping != null) { - ActiveUsers.updateLastActive(main, userIdMapping.superTokensUserId); - } else { - ActiveUsers.updateLastActive(main, sessionInfo.session.userId); + try { + UserIdMapping userIdMapping = io.supertokens.useridmapping.UserIdMapping.getUserIdMapping(super.main, + sessionInfo.session.userId, UserIdType.ANY); + if (userIdMapping != null) { + ActiveUsers.updateLastActive(main, userIdMapping.superTokensUserId); + } else { + ActiveUsers.updateLastActive(main, sessionInfo.session.userId); + } + } catch (StorageQueryException ignored) { } JsonObject result = sessionInfo.toJsonObject(); diff --git a/src/main/java/io/supertokens/webserver/api/session/SessionRemoveAPI.java b/src/main/java/io/supertokens/webserver/api/session/SessionRemoveAPI.java index 70d9ac092..92a2d77a3 100644 --- a/src/main/java/io/supertokens/webserver/api/session/SessionRemoveAPI.java +++ b/src/main/java/io/supertokens/webserver/api/session/SessionRemoveAPI.java @@ -79,12 +79,16 @@ protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws I try { String[] sessionHandlesRevoked = Session.revokeAllSessionsForUser(main, userId); - UserIdMapping userIdMapping = io.supertokens.useridmapping.UserIdMapping.getUserIdMapping(super.main, - userId, UserIdType.ANY); - if (userIdMapping != null) { - ActiveUsers.updateLastActive(main, userIdMapping.superTokensUserId); - } else { - ActiveUsers.updateLastActive(main, userId); + try { + UserIdMapping userIdMapping = io.supertokens.useridmapping.UserIdMapping.getUserIdMapping( + super.main, + userId, UserIdType.ANY); + if (userIdMapping != null) { + ActiveUsers.updateLastActive(main, userIdMapping.superTokensUserId); + } else { + ActiveUsers.updateLastActive(main, userId); + } + } catch (StorageQueryException ignored) { } JsonObject result = new JsonObject(); diff --git a/src/test/java/io/supertokens/test/ActiveUsersTest.java b/src/test/java/io/supertokens/test/ActiveUsersTest.java new file mode 100644 index 000000000..2b27708bd --- /dev/null +++ b/src/test/java/io/supertokens/test/ActiveUsersTest.java @@ -0,0 +1,61 @@ +package io.supertokens.test; + +import static org.junit.Assert.assertNotNull; + +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TestRule; + +import io.supertokens.pluginInterface.STORAGE_TYPE; +import io.supertokens.storageLayer.StorageLayer; +import io.supertokens.ActiveUsers; +import io.supertokens.Main; +import io.supertokens.ProcessState; + +public class ActiveUsersTest { + + @Rule + public TestRule watchman = Utils.getOnFailure(); + + @AfterClass + public static void afterTesting() { + Utils.afterTesting(); + } + + @Before + public void beforeEach() { + Utils.reset(); + } + + @Test + public void updateAndCountUserLastActiveTest() throws Exception { + String[] args = { "../" }; + + TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + + if (StorageLayer.getStorage(process.getProcess()).getType() != STORAGE_TYPE.SQL) { + assert (false); + } + + Main main = process.getProcess(); + long now = System.currentTimeMillis(); + + assert ActiveUsers.countUsersActiveSince(main, now) == 0; + + ActiveUsers.updateLastActive(main, "user1"); + ActiveUsers.updateLastActive(main, "user2"); + + assert ActiveUsers.countUsersActiveSince(main, now) == 2; + + long now2 = System.currentTimeMillis(); + + ActiveUsers.updateLastActive(main, "user1"); + + assert ActiveUsers.countUsersActiveSince(main, now) == 2; // user1 just got updated + assert ActiveUsers.countUsersActiveSince(main, now2) == 1; // only user1 is counted + } + +} diff --git a/src/test/java/io/supertokens/test/emailpassword/api/SignInAPITest2_7.java b/src/test/java/io/supertokens/test/emailpassword/api/SignInAPITest2_7.java index 1691c2900..8fdfcbdea 100644 --- a/src/test/java/io/supertokens/test/emailpassword/api/SignInAPITest2_7.java +++ b/src/test/java/io/supertokens/test/emailpassword/api/SignInAPITest2_7.java @@ -17,6 +17,8 @@ package io.supertokens.test.emailpassword.api; import com.google.gson.JsonObject; + +import io.supertokens.ActiveUsers; import io.supertokens.ProcessState; import io.supertokens.pluginInterface.STORAGE_TYPE; import io.supertokens.storageLayer.StorageLayer; @@ -68,6 +70,8 @@ public void testBadInput() throws Exception { return; } + long startTs = System.currentTimeMillis(); + { try { HttpRequestForTesting.sendJsonPOSTRequest(process.getProcess(), "", @@ -107,6 +111,9 @@ public void testBadInput() throws Exception { } } + int activeUsers = ActiveUsers.countUsersActiveSince(process.getProcess(), startTs); + assert (activeUsers == 0); + process.kill(); assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); } @@ -133,6 +140,8 @@ public void testGoodInput() throws Exception { responseBody.addProperty("email", "random@gmail.com"); responseBody.addProperty("password", "validPass123"); + long beforeSignIn = System.currentTimeMillis(); + JsonObject signInResponse = HttpRequestForTesting.sendJsonPOSTRequest(process.getProcess(), "", "http://localhost:3567/recipe/signin", responseBody, 1000, 1000, null, Utils.getCdiVersion2_7ForTests(), "emailpassword"); @@ -147,11 +156,15 @@ public void testGoodInput() throws Exception { signInResponse.get("user").getAsJsonObject().get("timeJoined").getAsLong(); assertEquals(signInResponse.get("user").getAsJsonObject().entrySet().size(), 3); + int activeUsers = ActiveUsers.countUsersActiveSince(process.getProcess(), beforeSignIn); + assert (activeUsers == 1); + process.kill(); assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); } - // Test that sign in with unnormalised email like Test@gmail.com should also work + // Test that sign in with unnormalised email like Test@gmail.com should also + // work @Test public void testThatUnnormalisedEmailShouldAlsoWork() throws Exception { String[] args = { "../" }; @@ -190,7 +203,8 @@ public void testThatUnnormalisedEmailShouldAlsoWork() throws Exception { assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); } - // Test that giving an empty password, empty email, invalid email, random email or wrong password throws a wrong + // Test that giving an empty password, empty email, invalid email, random email + // or wrong password throws a wrong // * credentials error @Test public void testInputsToSignInAPI() throws Exception { diff --git a/src/test/java/io/supertokens/test/emailpassword/api/SignUpAPITest2_7.java b/src/test/java/io/supertokens/test/emailpassword/api/SignUpAPITest2_7.java index 5e921a2d3..d47ae0a79 100644 --- a/src/test/java/io/supertokens/test/emailpassword/api/SignUpAPITest2_7.java +++ b/src/test/java/io/supertokens/test/emailpassword/api/SignUpAPITest2_7.java @@ -17,6 +17,8 @@ package io.supertokens.test.emailpassword.api; import com.google.gson.JsonObject; + +import io.supertokens.ActiveUsers; import io.supertokens.ProcessState; import io.supertokens.pluginInterface.STORAGE_TYPE; import io.supertokens.pluginInterface.emailpassword.UserInfo; @@ -68,6 +70,8 @@ public void testBadInput() throws Exception { return; } + long beforeTestTs = System.currentTimeMillis(); + { try { HttpRequestForTesting.sendJsonPOSTRequest(process.getProcess(), "", @@ -107,6 +111,9 @@ public void testBadInput() throws Exception { } } + int activeUsers = ActiveUsers.countUsersActiveSince(process.getProcess(), beforeTestTs); + assert (activeUsers == 0); + process.kill(); assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); } @@ -123,6 +130,8 @@ public void testGoodInput() throws Exception { return; } + long beforeSignUpTs = System.currentTimeMillis(); + JsonObject signUpResponse = Utils.signUpRequest_2_5(process, "random@gmail.com", "validPass123"); assertEquals(signUpResponse.get("status").getAsString(), "OK"); assertEquals(signUpResponse.entrySet().size(), 2); @@ -131,6 +140,9 @@ public void testGoodInput() throws Exception { assertEquals(signUpUser.get("email").getAsString(), "random@gmail.com"); assertNotNull(signUpUser.get("id")); + int activeUsers = ActiveUsers.countUsersActiveSince(process.getProcess(), beforeSignUpTs); + assert (activeUsers == 0); + UserInfo user = StorageLayer.getEmailPasswordStorage(process.getProcess()) .getUserInfoUsingEmail("random@gmail.com"); assertEquals(user.email, signUpUser.get("email").getAsString()); @@ -160,7 +172,8 @@ public void testGoodInput() throws Exception { // Test the normalise email function // Test that only the normalised email is saved in the db - // Failure condition: If the email retrieved from the data is not normalised the test will fail + // Failure condition: If the email retrieved from the data is not normalised the + // test will fail @Test public void testTheNormaliseEmailFunction() throws Exception { String[] args = { "../" }; diff --git a/src/test/java/io/supertokens/test/passwordless/PasswordlessConsumeCodeTest.java b/src/test/java/io/supertokens/test/passwordless/PasswordlessConsumeCodeTest.java index bd716a73d..e83aaa81e 100644 --- a/src/test/java/io/supertokens/test/passwordless/PasswordlessConsumeCodeTest.java +++ b/src/test/java/io/supertokens/test/passwordless/PasswordlessConsumeCodeTest.java @@ -234,7 +234,8 @@ public void testConsumeLinkCodeWithExistingUser() throws Exception { } /** - * Check device clean up when user input code is generated via email & phone number + * Check device clean up when user input code is generated via email & phone + * number * * @throws Exception */ @@ -732,7 +733,8 @@ public void testConsumeWrongLinkCodeExceedingMaxAttempts() throws Exception { } /** - * user input code with too many failedAttempts (changed maxCodeInputAttempts configuration between consumes) + * user input code with too many failedAttempts (changed maxCodeInputAttempts + * configuration between consumes) * TODO: review -> do we need to create code again post restart ? * * @throws Exception diff --git a/src/test/java/io/supertokens/test/passwordless/api/PasswordlessConsumeCodeAPITest2_11.java b/src/test/java/io/supertokens/test/passwordless/api/PasswordlessConsumeCodeAPITest2_11.java index 9cfe9eaa6..cf8f36950 100644 --- a/src/test/java/io/supertokens/test/passwordless/api/PasswordlessConsumeCodeAPITest2_11.java +++ b/src/test/java/io/supertokens/test/passwordless/api/PasswordlessConsumeCodeAPITest2_11.java @@ -27,6 +27,7 @@ import org.junit.Test; import org.junit.rules.TestRule; +import io.supertokens.ActiveUsers; import io.supertokens.ProcessState; import io.supertokens.passwordless.Passwordless; import io.supertokens.passwordless.Passwordless.CreateCodeResponse; @@ -62,6 +63,8 @@ public void testBadInput() throws Exception { return; } + long startTs = System.currentTimeMillis(); + String email = "test@example.com"; CreateCodeResponse createResp = Passwordless.createCode(process.getProcess(), email, null, null, null); { @@ -276,6 +279,9 @@ public void testBadInput() throws Exception { error.getMessage()); } + int activeUsers = ActiveUsers.countUsersActiveSince(process.getProcess(), startTs); + assert (activeUsers == 0); + process.kill(); assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); } @@ -291,6 +297,8 @@ public void testLinkCode() throws Exception { return; } + long startTs = System.currentTimeMillis(); + String email = "test@example.com"; CreateCodeResponse createResp = Passwordless.createCode(process.getProcess(), email, null, null, null); @@ -304,6 +312,9 @@ public void testLinkCode() throws Exception { checkResponse(response, true, email, null); + int activeUsers = ActiveUsers.countUsersActiveSince(process.getProcess(), startTs); + assert (activeUsers == 1); + process.kill(); assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); } @@ -321,6 +332,8 @@ public void testExpiredLinkCode() throws Exception { return; } + long startTs = System.currentTimeMillis(); + String email = "test@example.com"; CreateCodeResponse createResp = Passwordless.createCode(process.getProcess(), email, null, null, null); Thread.sleep(150); @@ -334,6 +347,9 @@ public void testExpiredLinkCode() throws Exception { assertEquals("RESTART_FLOW_ERROR", response.get("status").getAsString()); + int activeUsers = ActiveUsers.countUsersActiveSince(process.getProcess(), startTs); + assert (activeUsers == 0); + process.kill(); assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); } @@ -349,6 +365,8 @@ public void testUserInputCode() throws Exception { return; } + long startTs = System.currentTimeMillis(); + String email = "test@example.com"; CreateCodeResponse createResp = Passwordless.createCode(process.getProcess(), email, null, null, null); @@ -363,6 +381,9 @@ public void testUserInputCode() throws Exception { checkResponse(response, true, email, null); + int activeUsers = ActiveUsers.countUsersActiveSince(process.getProcess(), startTs); + assert (activeUsers == 1); + process.kill(); assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); } @@ -379,6 +400,8 @@ public void testExpiredUserInputCode() throws Exception { return; } + long startTs = System.currentTimeMillis(); + String email = "test@example.com"; CreateCodeResponse createResp = Passwordless.createCode(process.getProcess(), email, null, null, null); Thread.sleep(150); @@ -394,6 +417,9 @@ public void testExpiredUserInputCode() throws Exception { assertEquals("EXPIRED_USER_INPUT_CODE_ERROR", response.get("status").getAsString()); + int activeUsers = ActiveUsers.countUsersActiveSince(process.getProcess(), startTs); + assert (activeUsers == 0); + process.kill(); assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); } diff --git a/src/test/java/io/supertokens/test/session/api/RefreshSessionAPITest2_7.java b/src/test/java/io/supertokens/test/session/api/RefreshSessionAPITest2_7.java index 6fd495fd2..72ab594e5 100644 --- a/src/test/java/io/supertokens/test/session/api/RefreshSessionAPITest2_7.java +++ b/src/test/java/io/supertokens/test/session/api/RefreshSessionAPITest2_7.java @@ -18,6 +18,8 @@ import com.google.gson.JsonNull; import com.google.gson.JsonObject; + +import io.supertokens.ActiveUsers; import io.supertokens.ProcessState; import io.supertokens.test.TestingProcessManager; import io.supertokens.test.Utils; @@ -35,7 +37,7 @@ import static org.junit.Assert.fail; public class RefreshSessionAPITest2_7 { - @Rule + @Rule public TestRule watchman = Utils.getOnFailure(); @AfterClass @@ -187,6 +189,8 @@ public void badInputErrorTest() throws Exception { TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + long start1 = System.currentTimeMillis(); + try { HttpRequestForTesting.sendJsonPOSTRequest(process.getProcess(), "", "http://localhost:3567/recipe/session/refresh", null, 1000, 1000, null, @@ -195,12 +199,18 @@ public void badInputErrorTest() throws Exception { } catch (io.supertokens.test.httpRequest.HttpResponseException e) { assertEquals("Http error. Status Code: 400. Message: Invalid Json Input", e.getMessage()); } + + int activeUsers = ActiveUsers.countUsersActiveSince(process.getProcess(), start1); + assert (activeUsers == 0); + process.kill(); assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); process = TestingProcessManager.start(args); assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + long start2 = System.currentTimeMillis(); + try { JsonObject jsonBody = new JsonObject(); jsonBody.addProperty("random", "random"); @@ -214,6 +224,13 @@ public void badInputErrorTest() throws Exception { } + activeUsers = ActiveUsers.countUsersActiveSince(process.getProcess(), start2); + assert (activeUsers == 0); + + long start3 = System.currentTimeMillis(); + + long start3Inner = 0; // to be set after session is created + try { String userId = "userId"; JsonObject userDataInJWT = new JsonObject(); @@ -236,6 +253,8 @@ public void badInputErrorTest() throws Exception { sessionRefreshBody.addProperty("refreshToken", sessionInfo.get("refreshToken").getAsJsonObject().get("token").getAsString()); + start3Inner = System.currentTimeMillis(); + HttpRequestForTesting.sendJsonPOSTRequest(process.getProcess(), "", "http://localhost:3567/recipe/session/refresh", sessionRefreshBody, 1000, 1000, null, Utils.getCdiVersion2_7ForTests(), "session"); @@ -246,6 +265,16 @@ public void badInputErrorTest() throws Exception { } + activeUsers = ActiveUsers.countUsersActiveSince(process.getProcess(), start3); + assert (activeUsers == 1); + + int activeUsersAfterSessionCreate = ActiveUsers.countUsersActiveSince(process.getProcess(), start3Inner); + assert (activeUsersAfterSessionCreate == 0); + + long start4 = System.currentTimeMillis(); + + long start4Inner = 0; // to be set after session is created + try { String userId = "userId"; JsonObject userDataInJWT = new JsonObject(); @@ -269,6 +298,8 @@ public void badInputErrorTest() throws Exception { sessionInfo.get("refreshToken").getAsJsonObject().get("token").getAsString()); sessionRefreshBody.addProperty("enableAntiCsrf", "false"); + start4Inner = System.currentTimeMillis(); + HttpRequestForTesting.sendJsonPOSTRequest(process.getProcess(), "", "http://localhost:3567/recipe/session/refresh", sessionRefreshBody, 1000, 1000, null, Utils.getCdiVersion2_7ForTests(), "session"); @@ -278,6 +309,12 @@ public void badInputErrorTest() throws Exception { "Http error. Status Code: 400. Message: Field name 'enableAntiCsrf' is invalid in JSON input"); } + + activeUsers = ActiveUsers.countUsersActiveSince(process.getProcess(), start4); + assert (activeUsers == 1); + + activeUsersAfterSessionCreate = ActiveUsers.countUsersActiveSince(process.getProcess(), start4Inner); + assert (activeUsersAfterSessionCreate == 0); } @Test @@ -375,6 +412,8 @@ public void successOutputWithValidRefreshTokenTest() throws Exception { request.add("userDataInDatabase", userDataInDatabase); request.addProperty("enableAntiCsrf", false); + long startTs = System.currentTimeMillis(); + JsonObject sessionInfo = HttpRequestForTesting.sendJsonPOSTRequest(process.getProcess(), "", "http://localhost:3567/recipe/session", request, 1000, 1000, null, Utils.getCdiVersion2_7ForTests(), "session"); @@ -386,14 +425,19 @@ public void successOutputWithValidRefreshTokenTest() throws Exception { sessionInfo.get("refreshToken").getAsJsonObject().get("token").getAsString()); sessionRefreshBody.addProperty("enableAntiCsrf", false); + long afterSessionCreateTs = System.currentTimeMillis(); + JsonObject sessionRefreshResponse = HttpRequestForTesting.sendJsonPOSTRequest(process.getProcess(), "", "http://localhost:3567/recipe/session/refresh", sessionRefreshBody, 1000, 1000, null, Utils.getCdiVersion2_7ForTests(), "session"); checkRefreshSessionResponse(sessionRefreshResponse, process, userId, userDataInJWT, false); - process.kill(); - assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); + int activeUsers = ActiveUsers.countUsersActiveSince(process.getProcess(), startTs); + assert (activeUsers == 1); + + int activeUsersAfterSessionCreate = ActiveUsers.countUsersActiveSince(process.getProcess(), afterSessionCreateTs); + assert (activeUsersAfterSessionCreate == 1); } @Test diff --git a/src/test/java/io/supertokens/test/session/api/SessionAPITest2_7.java b/src/test/java/io/supertokens/test/session/api/SessionAPITest2_7.java index 3af20714b..7813238c0 100644 --- a/src/test/java/io/supertokens/test/session/api/SessionAPITest2_7.java +++ b/src/test/java/io/supertokens/test/session/api/SessionAPITest2_7.java @@ -17,6 +17,8 @@ package io.supertokens.test.session.api; import com.google.gson.JsonObject; + +import io.supertokens.ActiveUsers; import io.supertokens.ProcessState; import io.supertokens.test.TestingProcessManager; import io.supertokens.test.Utils; @@ -54,6 +56,8 @@ public void successOutputCheckWithAntiCsrfWithCookieDomain() throws Exception { TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + long startTs = System.currentTimeMillis(); + String userId = "userId"; JsonObject userDataInJWT = new JsonObject(); userDataInJWT.addProperty("key", "value"); @@ -73,6 +77,9 @@ public void successOutputCheckWithAntiCsrfWithCookieDomain() throws Exception { checkSessionResponse(response, process, userId, userDataInJWT); assertTrue(response.has("antiCsrfToken")); + int activeUsers = ActiveUsers.countUsersActiveSince(process.getProcess(), startTs); + assert (activeUsers == 1); + process.kill(); assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); } @@ -150,6 +157,8 @@ public void badInputTest() throws Exception { JsonObject userDataInDatabase = new JsonObject(); userDataInDatabase.addProperty("key", "value"); + long startTs = System.currentTimeMillis(); + try { JsonObject request = new JsonObject(); request.add("userDataInJWT", userDataInJWT); @@ -219,6 +228,9 @@ public void badInputTest() throws Exception { + "input"); } + int activeUsers = ActiveUsers.countUsersActiveSince(process.getProcess(), startTs); + assert (activeUsers == 0); + JsonObject request = new JsonObject(); request.addProperty("userId", userId); request.add("userDataInJWT", userDataInJWT); @@ -227,6 +239,9 @@ public void badInputTest() throws Exception { HttpRequestForTesting.sendJsonPOSTRequest(process.getProcess(), "", "http://localhost:3567/recipe/session", request, 1000, 1000, null, Utils.getCdiVersion2_7ForTests(), "session"); + activeUsers = ActiveUsers.countUsersActiveSince(process.getProcess(), startTs); + assert (activeUsers == 1); + request = new JsonObject(); request.addProperty("userId", userId); request.add("userDataInJWT", userDataInJWT); @@ -235,6 +250,9 @@ public void badInputTest() throws Exception { HttpRequestForTesting.sendJsonPOSTRequest(process.getProcess(), "", "http://localhost:3567/recipe/session", request, 1000, 1000, null, Utils.getCdiVersion2_7ForTests(), "session"); + activeUsers = ActiveUsers.countUsersActiveSince(process.getProcess(), startTs); + assert (activeUsers == 1); + process.kill(); assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); } diff --git a/src/test/java/io/supertokens/test/session/api/SessionAPITest2_9.java b/src/test/java/io/supertokens/test/session/api/SessionAPITest2_9.java index 6ebc5835e..8f1bfe467 100644 --- a/src/test/java/io/supertokens/test/session/api/SessionAPITest2_9.java +++ b/src/test/java/io/supertokens/test/session/api/SessionAPITest2_9.java @@ -19,6 +19,8 @@ import com.google.gson.JsonArray; import com.google.gson.JsonNull; import com.google.gson.JsonObject; + +import io.supertokens.ActiveUsers; import io.supertokens.ProcessState; import io.supertokens.test.TestingProcessManager; import io.supertokens.test.Utils; @@ -56,6 +58,8 @@ public void successOutputCheckWithAntiCsrfWithCookieDomain() throws Exception { TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + long startTs = System.currentTimeMillis(); + String userId = "userId"; JsonObject userDataInJWT = new JsonObject(); userDataInJWT.addProperty("key", "value"); @@ -75,6 +79,9 @@ public void successOutputCheckWithAntiCsrfWithCookieDomain() throws Exception { checkSessionResponse(response, process, userId, userDataInJWT); assertTrue(response.has("antiCsrfToken")); + int activeUsers = ActiveUsers.countUsersActiveSince(process.getProcess(), startTs); + assert (activeUsers == 1); + process.kill(); assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); } @@ -177,6 +184,8 @@ public void badInputTest() throws Exception { TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + long startTs = System.currentTimeMillis(); + String userId = "userId"; JsonObject userDataInJWT = new JsonObject(); userDataInJWT.addProperty("key", "value"); @@ -251,6 +260,9 @@ public void badInputTest() throws Exception { "Http error. Status Code: 400. Message: Field name 'userDataInDatabase' is invalid in JSON " + "input"); } + + int activeUsers = ActiveUsers.countUsersActiveSince(process.getProcess(), startTs); + assert (activeUsers == 0); JsonObject request = new JsonObject(); request.addProperty("userId", userId); @@ -267,6 +279,9 @@ public void badInputTest() throws Exception { request.addProperty("enableAntiCsrf", false); HttpRequestForTesting.sendJsonPOSTRequest(process.getProcess(), "", "http://localhost:3567/recipe/session", request, 1000, 1000, null, Utils.getCdiVersion2_9ForTests(), "session"); + + activeUsers = ActiveUsers.countUsersActiveSince(process.getProcess(), startTs); + assert (activeUsers == 1); process.kill(); assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); diff --git a/src/test/java/io/supertokens/test/session/api/SessionRemoveAPITest2_7.java b/src/test/java/io/supertokens/test/session/api/SessionRemoveAPITest2_7.java index f534ce7e9..39c463a0f 100644 --- a/src/test/java/io/supertokens/test/session/api/SessionRemoveAPITest2_7.java +++ b/src/test/java/io/supertokens/test/session/api/SessionRemoveAPITest2_7.java @@ -19,6 +19,8 @@ import com.google.gson.JsonArray; import com.google.gson.JsonObject; import com.google.gson.JsonParser; + +import io.supertokens.ActiveUsers; import io.supertokens.ProcessState; import io.supertokens.test.TestingProcessManager; import io.supertokens.test.Utils; @@ -98,6 +100,8 @@ public void testRemovingMultipleSessionsGivesCorrectOutput() throws Exception { // remove s2 and s4 and make sure they are returned + long checkpoint1 = System.currentTimeMillis(); + String sessionRemoveBodyString = "{" + " sessionHandles : [ " + s2Info.get("session").getAsJsonObject().get("handle").getAsString() + " ," + s4Info.get("session").getAsJsonObject().get("handle").getAsString() + " ] " + "}"; @@ -107,6 +111,9 @@ public void testRemovingMultipleSessionsGivesCorrectOutput() throws Exception { Utils.getCdiVersion2_7ForTests(), "session"); JsonArray revokedSessions = sessionRemovedResponse.getAsJsonArray("sessionHandlesRevoked"); + int activeUsers = ActiveUsers.countUsersActiveSince(process.getProcess(), checkpoint1); + assert (activeUsers == 0); // user ID is not set so not counted as active (we don't have userId) + for (int i = 0; i < revokedSessions.size(); i++) { assertTrue(sessionRemoveBody.getAsJsonArray("sessionHandles").contains(revokedSessions.get(i))); } @@ -120,6 +127,8 @@ public void testRemovingMultipleSessionsGivesCorrectOutput() throws Exception { + s4Info.get("session").getAsJsonObject().get("handle").getAsString() + " ] " + "}"; sessionRemoveBody = new JsonParser().parse(sessionRemoveBodyString).getAsJsonObject(); + long checkpoint2 = System.currentTimeMillis(); + sessionRemovedResponse = HttpRequestForTesting.sendJsonPOSTRequest(process.getProcess(), "", "http://localhost:3567/recipe/session/remove", sessionRemoveBody, 1000, 1000, null, Utils.getCdiVersion2_7ForTests(), "session"); @@ -131,6 +140,9 @@ public void testRemovingMultipleSessionsGivesCorrectOutput() throws Exception { assertEquals(revokedSessions.size(), 2); + activeUsers = ActiveUsers.countUsersActiveSince(process.getProcess(), checkpoint2); + assert (activeUsers == 0); // user ID is not set so not counted as active (we don't have userId) + process.kill(); assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); @@ -221,6 +233,8 @@ public void testRemovingSessionByUserId() throws Exception { "session"); assertEquals(session2Info.get("status").getAsString(), "OK"); + long checkpoint1 = System.currentTimeMillis(); + // remove session using user id JsonObject removeSessionBody = new JsonObject(); removeSessionBody.addProperty("userId", userId); @@ -237,6 +251,9 @@ public void testRemovingSessionByUserId() throws Exception { assertTrue(sessionRemovedResponse.getAsJsonArray("sessionHandlesRevoked") .contains(session2Info.get("session").getAsJsonObject().get("handle"))); + int activeUsers = ActiveUsers.countUsersActiveSince(process.getProcess(), checkpoint1); + assert (activeUsers == 1); // user ID is set + // check that the number of sessions for user is 0 Map userParams = new HashMap<>(); userParams.put("userId", userId); diff --git a/src/test/java/io/supertokens/test/thirdparty/api/ThirdPartySignInUpAPITest2_7.java b/src/test/java/io/supertokens/test/thirdparty/api/ThirdPartySignInUpAPITest2_7.java index 8409e2809..a9a01195c 100644 --- a/src/test/java/io/supertokens/test/thirdparty/api/ThirdPartySignInUpAPITest2_7.java +++ b/src/test/java/io/supertokens/test/thirdparty/api/ThirdPartySignInUpAPITest2_7.java @@ -17,6 +17,8 @@ package io.supertokens.test.thirdparty.api; import com.google.gson.JsonObject; + +import io.supertokens.ActiveUsers; import io.supertokens.ProcessState; import io.supertokens.emailverification.EmailVerification; import io.supertokens.pluginInterface.STORAGE_TYPE; @@ -69,6 +71,8 @@ public void testGoodInput() throws Exception { return; } + long startTs = System.currentTimeMillis(); + JsonObject response = Utils.signInUpRequest_2_7(process, "test@example.com", true, "testThirdPartyId", "testThirdPartyUserId"); checkSignInUpResponse(response, "testThirdPartyId", "testThirdPartyUserId", "test@example.com", true); @@ -78,6 +82,10 @@ public void testGoodInput() throws Exception { assertTrue(EmailVerification.isEmailVerified(process.getProcess(), user.get("id").getAsString(), user.get("email").getAsString())); } + + int activeUsers = ActiveUsers.countUsersActiveSince(process.getProcess(), startTs); + assert (activeUsers == 1); + process.kill(); assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); } @@ -125,6 +133,9 @@ public void testBadInput() throws Exception { if (StorageLayer.getStorage(process.getProcess()).getType() != STORAGE_TYPE.SQL) { return; } + + long startTs = System.currentTimeMillis(); + { try { HttpRequestForTesting.sendJsonPOSTRequest(process.getProcess(), "", @@ -174,6 +185,10 @@ public void testBadInput() throws Exception { + "in " + "JSON input")); } } + + int activeUsers = ActiveUsers.countUsersActiveSince(process.getProcess(), startTs); + assert (activeUsers == 0); + process.kill(); assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); } diff --git a/src/test/java/io/supertokens/test/thirdparty/api/ThirdPartySignInUpAPITest2_8.java b/src/test/java/io/supertokens/test/thirdparty/api/ThirdPartySignInUpAPITest2_8.java index 59dcbec82..25d697557 100644 --- a/src/test/java/io/supertokens/test/thirdparty/api/ThirdPartySignInUpAPITest2_8.java +++ b/src/test/java/io/supertokens/test/thirdparty/api/ThirdPartySignInUpAPITest2_8.java @@ -17,6 +17,8 @@ package io.supertokens.test.thirdparty.api; import com.google.gson.JsonObject; + +import io.supertokens.ActiveUsers; import io.supertokens.ProcessState; import io.supertokens.emailverification.EmailVerification; import io.supertokens.pluginInterface.STORAGE_TYPE; @@ -69,6 +71,8 @@ public void testGoodInput() throws Exception { return; } + long startTs = System.currentTimeMillis(); + JsonObject response = Utils.signInUpRequest_2_8(process, "test@examplE.com", "testThirdPartyId", "testThirdPartyUserId"); checkSignInUpResponse(response, "testThirdPartyId", "testThirdPartyUserId", "test@example.com", true); @@ -78,6 +82,10 @@ public void testGoodInput() throws Exception { assertFalse(EmailVerification.isEmailVerified(process.getProcess(), user.get("id").getAsString(), user.get("email").getAsString())); } + + int activeUsers = ActiveUsers.countUsersActiveSince(process.getProcess(), startTs); + assert (activeUsers == 1); + process.kill(); assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); } @@ -125,6 +133,9 @@ public void testBadInput() throws Exception { if (StorageLayer.getStorage(process.getProcess()).getType() != STORAGE_TYPE.SQL) { return; } + + long startTs = System.currentTimeMillis(); + { try { HttpRequestForTesting.sendJsonPOSTRequest(process.getProcess(), "", @@ -174,6 +185,10 @@ public void testBadInput() throws Exception { + "in " + "JSON input")); } } + + int activeUsers = ActiveUsers.countUsersActiveSince(process.getProcess(), startTs); + assert (activeUsers == 0); + process.kill(); assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); } From 5b2740e1208a263f844fe72b19302b8c56d37f48 Mon Sep 17 00:00:00 2001 From: Kumar Shivendu Date: Thu, 23 Mar 2023 18:26:37 +0530 Subject: [PATCH 37/42] feat: Make TOTP a paid feature and report stats (#589) * feat: Make TOTP a paid feature and report stats * test: Add test for TOTP usage stats * refactor: MAU should be sent irrespective of TOTP * refactor: Use internal supertokens user id in TOTP APIs (#591) * refactor: Use internal supertokens user id in TOTP APIs * test: Add test for user id mapping in TOTP APIs --- .../java/io/supertokens/ee/EEFeatureFlag.java | 56 ++++-- .../supertokens/featureflag/EE_FEATURES.java | 3 +- .../java/io/supertokens/inmemorydb/Start.java | 18 ++ .../queries/ActiveUsersQueries.java | 25 +++ .../api/totp/CreateOrUpdateTotpDeviceAPI.java | 17 +- .../webserver/api/totp/GetTotpDevicesAPI.java | 9 + .../api/totp/RemoveTotpDeviceAPI.java | 9 + .../webserver/api/totp/VerifyTotpAPI.java | 9 + .../api/totp/VerifyTotpDeviceAPI.java | 9 + .../io/supertokens/test/FeatureFlagTest.java | 98 ++++++++++ .../emailpassword/api/SignUpAPITest2_7.java | 2 +- .../totp/api/CreateTotpDeviceAPITest.java | 1 - .../test/totp/api/TotpUserIdMappingTest.java | 176 ++++++++++++++++++ 13 files changed, 417 insertions(+), 15 deletions(-) create mode 100644 src/test/java/io/supertokens/test/totp/api/TotpUserIdMappingTest.java diff --git a/ee/src/main/java/io/supertokens/ee/EEFeatureFlag.java b/ee/src/main/java/io/supertokens/ee/EEFeatureFlag.java index f87766863..3daa3caa9 100644 --- a/ee/src/main/java/io/supertokens/ee/EEFeatureFlag.java +++ b/ee/src/main/java/io/supertokens/ee/EEFeatureFlag.java @@ -6,10 +6,8 @@ import com.auth0.jwt.exceptions.JWTVerificationException; import com.auth0.jwt.interfaces.DecodedJWT; import com.auth0.jwt.interfaces.RSAKeyProvider; -import com.google.gson.JsonArray; -import com.google.gson.JsonObject; -import com.google.gson.JsonParser; -import com.google.gson.JsonPrimitive; +import com.google.gson.*; +import io.supertokens.ActiveUsers; import io.supertokens.Main; import io.supertokens.ProcessState; import io.supertokens.cronjobs.Cronjobs; @@ -21,6 +19,7 @@ import io.supertokens.httpRequest.HttpRequest; import io.supertokens.httpRequest.HttpResponseException; import io.supertokens.output.Logging; +import io.supertokens.pluginInterface.ActiveUsersStorage; import io.supertokens.pluginInterface.KeyValueInfo; import io.supertokens.pluginInterface.Storage; import io.supertokens.pluginInterface.exceptions.StorageQueryException; @@ -144,15 +143,50 @@ public Boolean getIsLicenseKeyPresent() { @Override public JsonObject getPaidFeatureStats() throws StorageQueryException { - JsonObject result = new JsonObject(); + JsonObject usageStats = new JsonObject(); EE_FEATURES[] features = getEnabledEEFeaturesFromDbOrCache(); - if (Arrays.stream(features).anyMatch(t -> t == EE_FEATURES.DASHBOARD_LOGIN)) { - JsonObject stats = new JsonObject(); - int userCount = StorageLayer.getDashboardStorage(main).getAllDashboardUsers().length; - stats.addProperty("user_count", userCount); - result.add(EE_FEATURES.DASHBOARD_LOGIN.toString(), stats); + ActiveUsersStorage activeUsersStorage = StorageLayer.getActiveUsersStorage(main); + + for (EE_FEATURES feature : features) { + if (feature == EE_FEATURES.DASHBOARD_LOGIN) { + JsonObject stats = new JsonObject(); + int userCount = StorageLayer.getDashboardStorage(main).getAllDashboardUsers().length; + stats.addProperty("user_count", userCount); + usageStats.add(EE_FEATURES.DASHBOARD_LOGIN.toString(), stats); + } + if (feature == EE_FEATURES.TOTP) { + JsonObject totpStats = new JsonObject(); + JsonArray totpMauArr = new JsonArray(); + + for (int i = 0; i < 30; i++) { + long now = System.currentTimeMillis(); + long today = now - (now % (24 * 60 * 60 * 1000L)); + long timestamp = today - (i * 24 * 60 * 60 * 1000L); + + int totpMau = activeUsersStorage.countUsersEnabledTotpAndActiveSince(timestamp); + totpMauArr.add(new JsonPrimitive(totpMau)); + } + + totpStats.add("maus", totpMauArr); + + int totpTotalUsers = activeUsersStorage.countUsersEnabledTotp(); + totpStats.addProperty("total_users", totpTotalUsers); + usageStats.add(EE_FEATURES.TOTP.toString(), totpStats); + } + } + + JsonArray mauArr = new JsonArray(); + for (int i = 0; i < 30; i++) { + long now = System.currentTimeMillis(); + long today = now - (now % (24 * 60 * 60 * 1000L)); + long timestamp = today - (i * 24 * 60 * 60 * 1000L); + + int mau = activeUsersStorage.countUsersActiveSince(timestamp); + mauArr.add(new JsonPrimitive(mau)); } - return result; + + usageStats.add("maus", mauArr); + return usageStats; } private EE_FEATURES[] verifyLicenseKey(String licenseKey) diff --git a/src/main/java/io/supertokens/featureflag/EE_FEATURES.java b/src/main/java/io/supertokens/featureflag/EE_FEATURES.java index 05e554dc6..820f2f73a 100644 --- a/src/main/java/io/supertokens/featureflag/EE_FEATURES.java +++ b/src/main/java/io/supertokens/featureflag/EE_FEATURES.java @@ -17,7 +17,8 @@ package io.supertokens.featureflag; public enum EE_FEATURES { - ACCOUNT_LINKING("account_linking"), MULTI_TENANCY("multi_tenancy"), TEST("test"), DASHBOARD_LOGIN("dashboard_login"); + ACCOUNT_LINKING("account_linking"), MULTI_TENANCY("multi_tenancy"), TEST("test"), DASHBOARD_LOGIN("dashboard_login"), + TOTP("totp"); private final String name; diff --git a/src/main/java/io/supertokens/inmemorydb/Start.java b/src/main/java/io/supertokens/inmemorydb/Start.java index cbc47bc6d..2cf559eee 100644 --- a/src/main/java/io/supertokens/inmemorydb/Start.java +++ b/src/main/java/io/supertokens/inmemorydb/Start.java @@ -460,6 +460,24 @@ public int countUsersActiveSince(long time) throws StorageQueryException { } } + @Override + public int countUsersEnabledTotp() throws StorageQueryException { + try { + return ActiveUsersQueries.countUsersEnabledTotp(this); + } catch (SQLException e) { + throw new StorageQueryException(e); + } + } + + @Override + public int countUsersEnabledTotpAndActiveSince(long time) throws StorageQueryException { + try { + return ActiveUsersQueries.countUsersEnabledTotpAndActiveSince(this, time); + } catch (SQLException e) { + throw new StorageQueryException(e); + } + } + @Override public SessionInfo getSessionInfo_Transaction(TransactionConnection con, String sessionHandle) throws StorageQueryException { diff --git a/src/main/java/io/supertokens/inmemorydb/queries/ActiveUsersQueries.java b/src/main/java/io/supertokens/inmemorydb/queries/ActiveUsersQueries.java index 43fc22f52..ec1672f27 100644 --- a/src/main/java/io/supertokens/inmemorydb/queries/ActiveUsersQueries.java +++ b/src/main/java/io/supertokens/inmemorydb/queries/ActiveUsersQueries.java @@ -28,6 +28,31 @@ public static int countUsersActiveSince(Start start, long sinceTime) throws SQLE }); } + public static int countUsersEnabledTotp(Start start) throws SQLException, StorageQueryException { + String QUERY = "SELECT COUNT(*) as total FROM " + Config.getConfig(start).getTotpUsersTable(); + + return execute(start, QUERY, null, result -> { + if (result.next()) { + return result.getInt("total"); + } + return 0; + }); + } + + public static int countUsersEnabledTotpAndActiveSince(Start start, long sinceTime) throws SQLException, StorageQueryException { + String QUERY = "SELECT COUNT(*) as total FROM " + Config.getConfig(start).getTotpUsersTable() + " AS totp_users " + + "INNER JOIN " + Config.getConfig(start).getUserLastActiveTable() + " AS user_last_active " + + "ON totp_users.user_id = user_last_active.user_id " + + "WHERE user_last_active.last_active_time >= ?"; + + return execute(start, QUERY, pst -> pst.setLong(1, sinceTime), result -> { + if (result.next()) { + return result.getInt("total"); + } + return 0; + }); + } + public static int updateUserLastActive(Start start, String userId) throws SQLException, StorageQueryException { String QUERY = "INSERT INTO " + Config.getConfig(start).getUserLastActiveTable() + "(user_id, last_active_time) VALUES(?, ?) ON CONFLICT(user_id) DO UPDATE SET last_active_time = ?"; diff --git a/src/main/java/io/supertokens/webserver/api/totp/CreateOrUpdateTotpDeviceAPI.java b/src/main/java/io/supertokens/webserver/api/totp/CreateOrUpdateTotpDeviceAPI.java index 76b3e09cf..bb04f7070 100644 --- a/src/main/java/io/supertokens/webserver/api/totp/CreateOrUpdateTotpDeviceAPI.java +++ b/src/main/java/io/supertokens/webserver/api/totp/CreateOrUpdateTotpDeviceAPI.java @@ -4,7 +4,6 @@ import java.security.NoSuchAlgorithmException; import com.google.gson.JsonObject; - import io.supertokens.Main; import io.supertokens.pluginInterface.RECIPE_ID; import io.supertokens.pluginInterface.exceptions.StorageQueryException; @@ -12,7 +11,9 @@ import io.supertokens.pluginInterface.totp.exception.DeviceAlreadyExistsException; import io.supertokens.pluginInterface.totp.exception.TotpNotEnabledException; import io.supertokens.pluginInterface.totp.exception.UnknownDeviceException; +import io.supertokens.pluginInterface.useridmapping.UserIdMapping; import io.supertokens.totp.Totp; +import io.supertokens.useridmapping.UserIdType; import io.supertokens.webserver.InputParser; import io.supertokens.webserver.WebserverAPI; import jakarta.servlet.ServletException; @@ -59,6 +60,13 @@ protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws I JsonObject result = new JsonObject(); try { + // This step is required only because user_last_active table stores supertokens internal user id. + // While sending the usage stats we do a join, so totp tables also must use internal user id. + UserIdMapping userIdMapping = io.supertokens.useridmapping.UserIdMapping.getUserIdMapping(super.main, userId, UserIdType.ANY); + if (userIdMapping != null) { + userId = userIdMapping.superTokensUserId; + } + TOTPDevice device = Totp.registerDevice(main, userId, deviceName, skew, period); result.addProperty("status", "OK"); @@ -93,6 +101,13 @@ protected void doPut(HttpServletRequest req, HttpServletResponse resp) throws IO JsonObject result = new JsonObject(); try { + // This step is required only because user_last_active table stores supertokens internal user id. + // While sending the usage stats we do a join, so totp tables also must use internal user id. + UserIdMapping userIdMapping = io.supertokens.useridmapping.UserIdMapping.getUserIdMapping(super.main, userId, UserIdType.ANY); + if (userIdMapping != null) { + userId = userIdMapping.superTokensUserId; + } + Totp.updateDeviceName(main, userId, existingDeviceName, newDeviceName); result.addProperty("status", "OK"); diff --git a/src/main/java/io/supertokens/webserver/api/totp/GetTotpDevicesAPI.java b/src/main/java/io/supertokens/webserver/api/totp/GetTotpDevicesAPI.java index e8a9d156d..9bda4ed27 100644 --- a/src/main/java/io/supertokens/webserver/api/totp/GetTotpDevicesAPI.java +++ b/src/main/java/io/supertokens/webserver/api/totp/GetTotpDevicesAPI.java @@ -10,7 +10,9 @@ import io.supertokens.pluginInterface.exceptions.StorageQueryException; import io.supertokens.pluginInterface.totp.TOTPDevice; import io.supertokens.pluginInterface.totp.exception.TotpNotEnabledException; +import io.supertokens.pluginInterface.useridmapping.UserIdMapping; import io.supertokens.totp.Totp; +import io.supertokens.useridmapping.UserIdType; import io.supertokens.webserver.InputParser; import io.supertokens.webserver.WebserverAPI; import jakarta.servlet.ServletException; @@ -40,6 +42,13 @@ protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws IO JsonObject result = new JsonObject(); try { + // This step is required only because user_last_active table stores supertokens internal user id. + // While sending the usage stats we do a join, so totp tables also must use internal user id. + UserIdMapping userIdMapping = io.supertokens.useridmapping.UserIdMapping.getUserIdMapping(super.main, userId, UserIdType.ANY); + if (userIdMapping != null) { + userId = userIdMapping.superTokensUserId; + } + TOTPDevice[] devices = Totp.getDevices(main, userId); JsonArray devicesArray = new JsonArray(); diff --git a/src/main/java/io/supertokens/webserver/api/totp/RemoveTotpDeviceAPI.java b/src/main/java/io/supertokens/webserver/api/totp/RemoveTotpDeviceAPI.java index b2a079694..634e6fe20 100644 --- a/src/main/java/io/supertokens/webserver/api/totp/RemoveTotpDeviceAPI.java +++ b/src/main/java/io/supertokens/webserver/api/totp/RemoveTotpDeviceAPI.java @@ -10,7 +10,9 @@ import io.supertokens.pluginInterface.exceptions.StorageTransactionLogicException; import io.supertokens.pluginInterface.totp.exception.TotpNotEnabledException; import io.supertokens.pluginInterface.totp.exception.UnknownDeviceException; +import io.supertokens.pluginInterface.useridmapping.UserIdMapping; import io.supertokens.totp.Totp; +import io.supertokens.useridmapping.UserIdType; import io.supertokens.webserver.InputParser; import io.supertokens.webserver.WebserverAPI; import jakarta.servlet.ServletException; @@ -46,6 +48,13 @@ protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws I JsonObject result = new JsonObject(); try { + // This step is required only because user_last_active table stores supertokens internal user id. + // While sending the usage stats we do a join, so totp tables also must use internal user id. + UserIdMapping userIdMapping = io.supertokens.useridmapping.UserIdMapping.getUserIdMapping(super.main, userId, UserIdType.ANY); + if (userIdMapping != null) { + userId = userIdMapping.superTokensUserId; + } + Totp.removeDevice(main, userId, deviceName); result.addProperty("status", "OK"); diff --git a/src/main/java/io/supertokens/webserver/api/totp/VerifyTotpAPI.java b/src/main/java/io/supertokens/webserver/api/totp/VerifyTotpAPI.java index 027eaca5d..af9400863 100644 --- a/src/main/java/io/supertokens/webserver/api/totp/VerifyTotpAPI.java +++ b/src/main/java/io/supertokens/webserver/api/totp/VerifyTotpAPI.java @@ -9,9 +9,11 @@ import io.supertokens.pluginInterface.exceptions.StorageQueryException; import io.supertokens.pluginInterface.exceptions.StorageTransactionLogicException; import io.supertokens.pluginInterface.totp.exception.TotpNotEnabledException; +import io.supertokens.pluginInterface.useridmapping.UserIdMapping; import io.supertokens.totp.Totp; import io.supertokens.totp.exceptions.InvalidTotpException; import io.supertokens.totp.exceptions.LimitReachedException; +import io.supertokens.useridmapping.UserIdType; import io.supertokens.webserver.InputParser; import io.supertokens.webserver.WebserverAPI; import jakarta.servlet.ServletException; @@ -49,6 +51,13 @@ protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws I JsonObject result = new JsonObject(); try { + // This step is required only because user_last_active table stores supertokens internal user id. + // While sending the usage stats we do a join, so totp tables also must use internal user id. + UserIdMapping userIdMapping = io.supertokens.useridmapping.UserIdMapping.getUserIdMapping(super.main, userId, UserIdType.ANY); + if (userIdMapping != null) { + userId = userIdMapping.superTokensUserId; + } + Totp.verifyCode(main, userId, totp, allowUnverifiedDevices); result.addProperty("status", "OK"); diff --git a/src/main/java/io/supertokens/webserver/api/totp/VerifyTotpDeviceAPI.java b/src/main/java/io/supertokens/webserver/api/totp/VerifyTotpDeviceAPI.java index 00f38d40f..f1dd1ba0a 100644 --- a/src/main/java/io/supertokens/webserver/api/totp/VerifyTotpDeviceAPI.java +++ b/src/main/java/io/supertokens/webserver/api/totp/VerifyTotpDeviceAPI.java @@ -10,9 +10,11 @@ import io.supertokens.pluginInterface.exceptions.StorageTransactionLogicException; import io.supertokens.pluginInterface.totp.exception.TotpNotEnabledException; import io.supertokens.pluginInterface.totp.exception.UnknownDeviceException; +import io.supertokens.pluginInterface.useridmapping.UserIdMapping; import io.supertokens.totp.Totp; import io.supertokens.totp.exceptions.InvalidTotpException; import io.supertokens.totp.exceptions.LimitReachedException; +import io.supertokens.useridmapping.UserIdType; import io.supertokens.webserver.InputParser; import io.supertokens.webserver.WebserverAPI; import jakarta.servlet.ServletException; @@ -52,6 +54,13 @@ protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws I JsonObject result = new JsonObject(); try { + // This step is required only because user_last_active table stores supertokens internal user id. + // While sending the usage stats we do a join, so totp tables also must use internal user id. + UserIdMapping userIdMapping = io.supertokens.useridmapping.UserIdMapping.getUserIdMapping(super.main, userId, UserIdType.ANY); + if (userIdMapping != null) { + userId = userIdMapping.superTokensUserId; + } + boolean isNewlyVerified = Totp.verifyDevice(main, userId, deviceName, totp); result.addProperty("status", "OK"); diff --git a/src/test/java/io/supertokens/test/FeatureFlagTest.java b/src/test/java/io/supertokens/test/FeatureFlagTest.java index bce7b57b6..926595b44 100644 --- a/src/test/java/io/supertokens/test/FeatureFlagTest.java +++ b/src/test/java/io/supertokens/test/FeatureFlagTest.java @@ -16,6 +16,7 @@ package io.supertokens.test; +import com.google.gson.JsonArray; import com.google.gson.JsonObject; import io.supertokens.ProcessState; import io.supertokens.featureflag.FeatureFlag; @@ -111,4 +112,101 @@ public void testThatCallingGetFeatureFlagAPIReturnsEmptyArray() throws Exception process.kill(); Assert.assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); } + + private final String OPAQUE_KEY_WITH_TOTP_FEATURE = "pXhNK=nYiEsb6gJEOYP2kIR6M0kn4XLvNqcwT1XbX8xHtm44K-lQfGCbaeN0Ieeza39fxkXr=tiiUU=DXxDH40Y=4FLT4CE-rG1ETjkXxO4yucLpJvw3uSegPayoISGL"; + + @Test + public void testThatCallingGetFeatureFlagAPIReturnsTotpStats() throws Exception { + String[] args = {"../"}; + + TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); + Assert.assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + + FeatureFlag.getInstance(process.main).setLicenseKeyAndSyncFeatures(OPAQUE_KEY_WITH_TOTP_FEATURE); + + // Get the stats without any users/activity + { + JsonObject response = HttpRequestForTesting.sendGETRequest(process.getProcess(), "", + "http://localhost:3567/ee/featureflag", + null, 1000, 1000, null, WebserverAPI.getLatestCDIVersion(), ""); + Assert.assertEquals("OK", response.get("status").getAsString()); + + JsonArray features = response.get("features").getAsJsonArray(); + JsonObject usageStats = response.get("usageStats").getAsJsonObject(); + JsonArray maus = usageStats.get("maus").getAsJsonArray(); + + assert features.size() == 1; + assert features.get(0).getAsString().equals("totp"); + assert maus.size() == 30; + assert maus.get(0).getAsInt() == 0; + assert maus.get(29).getAsInt() == 0; + + JsonObject totpStats = usageStats.get("totp").getAsJsonObject(); + JsonArray totpMaus = totpStats.get("maus").getAsJsonArray(); + int totalTotpUsers = totpStats.get("total_users").getAsInt(); + + assert totpMaus.size() == 30; + assert totpMaus.get(0).getAsInt() == 0; + assert totpMaus.get(29).getAsInt() == 0; + + assert totalTotpUsers == 0; + } + + // First register 2 users for emailpassword recipe. + // This also marks them as active. + JsonObject signUpResponse = Utils.signUpRequest_2_5(process, "random@gmail.com", "validPass123"); + assert signUpResponse.get("status").getAsString().equals("OK"); + + JsonObject signUpResponse2 = Utils.signUpRequest_2_5(process, "random2@gmail.com", "validPass123"); + assert signUpResponse2.get("status").getAsString().equals("OK"); + + // Now enable TOTP for the first user by registering a device. + JsonObject body = new JsonObject(); + body.addProperty("userId", signUpResponse.get("user").getAsJsonObject().get("id").getAsString()); + body.addProperty("deviceName", "d1"); + body.addProperty("skew", 0); + body.addProperty("period", 30); + JsonObject res = HttpRequestForTesting.sendJsonPOSTRequest( + process.getProcess(), + "", + "http://localhost:3567/recipe/totp/device", + body, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + "totp"); + assert res.get("status").getAsString().equals("OK"); + + // Now check the stats again: + { + JsonObject response = HttpRequestForTesting.sendGETRequest(process.getProcess(), "", + "http://localhost:3567/ee/featureflag", + null, 1000, 1000, null, WebserverAPI.getLatestCDIVersion(), ""); + Assert.assertEquals("OK", response.get("status").getAsString()); + + JsonArray features = response.get("features").getAsJsonArray(); + JsonObject usageStats = response.get("usageStats").getAsJsonObject(); + JsonArray maus = usageStats.get("maus").getAsJsonArray(); + + assert features.size() == 1; + assert features.get(0).getAsString().equals("totp"); + assert maus.size() == 30; + assert maus.get(0).getAsInt() == 2; // 2 users have signed up + assert maus.get(29).getAsInt() == 2; + + JsonObject totpStats = usageStats.get("totp").getAsJsonObject(); + JsonArray totpMaus = totpStats.get("maus").getAsJsonArray(); + int totalTotpUsers = totpStats.get("total_users").getAsInt(); + + assert totpMaus.size() == 30; + assert totpMaus.get(0).getAsInt() == 1; // only 1 user has TOTP enabled + assert totpMaus.get(29).getAsInt() == 1; + + assert totalTotpUsers == 1; + } + + process.kill(); + Assert.assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); + } } diff --git a/src/test/java/io/supertokens/test/emailpassword/api/SignUpAPITest2_7.java b/src/test/java/io/supertokens/test/emailpassword/api/SignUpAPITest2_7.java index d47ae0a79..bd4b97285 100644 --- a/src/test/java/io/supertokens/test/emailpassword/api/SignUpAPITest2_7.java +++ b/src/test/java/io/supertokens/test/emailpassword/api/SignUpAPITest2_7.java @@ -141,7 +141,7 @@ public void testGoodInput() throws Exception { assertNotNull(signUpUser.get("id")); int activeUsers = ActiveUsers.countUsersActiveSince(process.getProcess(), beforeSignUpTs); - assert (activeUsers == 0); + assert (activeUsers == 1); UserInfo user = StorageLayer.getEmailPasswordStorage(process.getProcess()) .getUserInfoUsingEmail("random@gmail.com"); diff --git a/src/test/java/io/supertokens/test/totp/api/CreateTotpDeviceAPITest.java b/src/test/java/io/supertokens/test/totp/api/CreateTotpDeviceAPITest.java index fd15bd114..154d96bf7 100644 --- a/src/test/java/io/supertokens/test/totp/api/CreateTotpDeviceAPITest.java +++ b/src/test/java/io/supertokens/test/totp/api/CreateTotpDeviceAPITest.java @@ -142,5 +142,4 @@ public void testApi() throws Exception { process.kill(); assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); } - } diff --git a/src/test/java/io/supertokens/test/totp/api/TotpUserIdMappingTest.java b/src/test/java/io/supertokens/test/totp/api/TotpUserIdMappingTest.java new file mode 100644 index 000000000..3f9970dbb --- /dev/null +++ b/src/test/java/io/supertokens/test/totp/api/TotpUserIdMappingTest.java @@ -0,0 +1,176 @@ +package io.supertokens.test.totp.api; + +import com.google.gson.JsonObject; +import io.supertokens.ProcessState; +import io.supertokens.emailpassword.EmailPassword; +import io.supertokens.pluginInterface.emailpassword.UserInfo; +import io.supertokens.pluginInterface.totp.TOTPDevice; +import io.supertokens.test.httpRequest.HttpResponseException; +import io.supertokens.pluginInterface.STORAGE_TYPE; +import io.supertokens.storageLayer.StorageLayer; +import io.supertokens.test.TestingProcessManager; +import io.supertokens.test.Utils; +import io.supertokens.test.httpRequest.HttpRequestForTesting; +import io.supertokens.test.totp.TOTPRecipeTest; +import io.supertokens.useridmapping.UserIdMapping; + +import static org.junit.Assert.assertNotNull; + +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TestRule; + +public class TotpUserIdMappingTest { + @Rule + public TestRule watchman = Utils.getOnFailure(); + + @AfterClass + public static void afterTesting() { + Utils.afterTesting(); + } + + @Before + public void beforeEach() { + Utils.reset(); + } + + @Test + public void testExternalUserIdTranslation() throws Exception { + String[] args = { "../" }; + + TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + + if (StorageLayer.getStorage(process.getProcess()).getType() != STORAGE_TYPE.SQL) { + return; + } + + JsonObject body = new JsonObject(); + + UserInfo user = EmailPassword.signUp(process.main, "test@example.com", "testPass123"); + String superTokensUserId = user.id; + String externalUserId = "external-user-id"; + + // Create user id mapping first: + UserIdMapping.createUserIdMapping(process.main, superTokensUserId, externalUserId, null, false); + + body.addProperty("userId", externalUserId); + body.addProperty("deviceName", "d1"); + body.addProperty("skew", 0); + body.addProperty("period", 30); + + // Register 1st device + JsonObject res1 = HttpRequestForTesting.sendJsonPOSTRequest( + process.getProcess(), + "", + "http://localhost:3567/recipe/totp/device", + body, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + "totp"); + assert res1.get("status").getAsString().equals("OK"); + String d1Secret = res1.get("secret").getAsString(); + TOTPDevice device1 = new TOTPDevice(externalUserId, "deviceName", d1Secret, 30, 0, false); + + body.addProperty("deviceName", "d2"); + + JsonObject res2 = HttpRequestForTesting.sendJsonPOSTRequest( + process.getProcess(), + "", + "http://localhost:3567/recipe/totp/device", + body, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + "totp"); + assert res2.get("status").getAsString().equals("OK"); + String d2Secret = res2.get("secret").getAsString(); + TOTPDevice device2 = new TOTPDevice(externalUserId, "deviceName", d2Secret, 30, 0, false); + + // Verify d1 but not d2: + JsonObject verifyD1Input = new JsonObject(); + verifyD1Input.addProperty("userId", externalUserId); + String d1Totp = TOTPRecipeTest.generateTotpCode(process.getProcess(), device1); + verifyD1Input.addProperty("deviceName", "d1"); + verifyD1Input.addProperty("totp", d1Totp ); + + JsonObject verifyD1Res = HttpRequestForTesting.sendJsonPOSTRequest( + process.getProcess(), + "", + "http://localhost:3567/recipe/totp/device/verify", + verifyD1Input, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + "totp"); + + assert verifyD1Res.get("status").getAsString().equals("OK"); + assert verifyD1Res.get("wasAlreadyVerified").getAsBoolean() == false; + + // use d2 to login in totp: + JsonObject loginInput = new JsonObject(); + loginInput.addProperty("userId", externalUserId); + String d2Totp = TOTPRecipeTest.generateTotpCode(process.getProcess(), device2); + loginInput.addProperty("totp", d2Totp); // use code from d2 which is unverified + loginInput.addProperty("allowUnverifiedDevices", true); + + JsonObject loginRes = HttpRequestForTesting.sendJsonPOSTRequest( + process.getProcess(), + "", + "http://localhost:3567/recipe/totp/verify", + loginInput, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + "totp"); + + assert loginRes.get("status").getAsString().equals("OK"); + + // Change the name of d1 to d3: + JsonObject updateDeviceNameInput = new JsonObject(); + updateDeviceNameInput.addProperty("userId", externalUserId); + updateDeviceNameInput.addProperty("existingDeviceName", "d1"); + updateDeviceNameInput.addProperty("newDeviceName", "d3"); + + JsonObject updateDeviceNameRes = HttpRequestForTesting.sendJsonPUTRequest( + process.getProcess(), + "", + "http://localhost:3567/recipe/totp/device", + updateDeviceNameInput, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + "totp"); + + assert updateDeviceNameRes.get("status").getAsString().equals("OK"); + + // Delete d3: + JsonObject deleteDeviceInput = new JsonObject(); + deleteDeviceInput.addProperty("userId", externalUserId); + deleteDeviceInput.addProperty("deviceName", "d3"); + + JsonObject deleteDeviceRes = HttpRequestForTesting.sendJsonPOSTRequest( + process.getProcess(), + "", + "http://localhost:3567/recipe/totp/device/remove", + deleteDeviceInput, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + "totp"); + + + assert deleteDeviceRes.get("status").getAsString().equals("OK"); + assert deleteDeviceRes.get("didDeviceExist").getAsBoolean() == true; + + } +} From 9aa94961d2ad24fc5f7309fb88c768005fb07c6e Mon Sep 17 00:00:00 2001 From: Kumar Shivendu Date: Thu, 23 Mar 2023 21:25:36 +0530 Subject: [PATCH 38/42] feat: Check TOTP feature flag in TOTP recipe functions (#592) * feat: Check TOTP feature flag in TOTP recipe functions * feat: Test feature flag error is handled by APIs --- src/main/java/io/supertokens/totp/Totp.java | 33 +++- .../api/totp/CreateOrUpdateTotpDeviceAPI.java | 3 +- .../webserver/api/totp/VerifyTotpAPI.java | 3 +- .../supertokens/test/totp/TOTPRecipeTest.java | 11 +- .../test/totp/TOTPStorageTest.java | 13 +- .../test/totp/TotpLicenseTest.java | 168 ++++++++++++++++++ .../totp/api/CreateTotpDeviceAPITest.java | 8 + .../test/totp/api/GetTotpDevicesAPITest.java | 6 + .../totp/api/RemoveTotpDeviceAPITest.java | 6 + .../test/totp/api/TotpUserIdMappingTest.java | 7 +- .../totp/api/UpdateTotpDeviceAPITest.java | 6 + .../test/totp/api/VerifyTotpAPITest.java | 6 + .../totp/api/VerifyTotpDeviceAPITest.java | 6 + 13 files changed, 269 insertions(+), 7 deletions(-) create mode 100644 src/test/java/io/supertokens/test/totp/TotpLicenseTest.java diff --git a/src/main/java/io/supertokens/totp/Totp.java b/src/main/java/io/supertokens/totp/Totp.java index 4c5040d93..9e5d43ace 100644 --- a/src/main/java/io/supertokens/totp/Totp.java +++ b/src/main/java/io/supertokens/totp/Totp.java @@ -15,6 +15,9 @@ import com.eatthepath.otp.TimeBasedOneTimePasswordGenerator; +import io.supertokens.featureflag.EE_FEATURES; +import io.supertokens.featureflag.FeatureFlag; +import io.supertokens.featureflag.exceptions.FeatureNotEnabledException; import io.supertokens.pluginInterface.exceptions.StorageQueryException; import io.supertokens.pluginInterface.exceptions.StorageTransactionLogicException; import io.supertokens.pluginInterface.totp.TOTPDevice; @@ -29,6 +32,9 @@ import io.supertokens.totp.exceptions.InvalidTotpException; import io.supertokens.totp.exceptions.LimitReachedException; import org.apache.commons.codec.binary.Base32; +import org.jetbrains.annotations.TestOnly; + +import static io.supertokens.featureflag.FeatureFlag.getInstance; public class Totp { private static String generateSecret() throws NoSuchAlgorithmException { @@ -66,8 +72,25 @@ private static boolean checkCode(TOTPDevice device, String code) { return false; } + private static boolean isTotpEnabled(Main main) throws StorageQueryException { + EE_FEATURES[] features = FeatureFlag.getInstance(main).getEnabledFeatures(); + for (EE_FEATURES f : features) { + if (f == EE_FEATURES.TOTP) { + return true; + } + } + return false; + } + + public static TOTPDevice registerDevice(Main main, String userId, String deviceName, int skew, int period) - throws StorageQueryException, DeviceAlreadyExistsException, NoSuchAlgorithmException { + throws StorageQueryException, DeviceAlreadyExistsException, NoSuchAlgorithmException, + FeatureNotEnabledException { + + if (!isTotpEnabled(main)){ + throw new FeatureNotEnabledException( + "TOTP feature is not enabled. Please subscribe to a SuperTokens core license key to enable this feature."); + } TOTPSQLStorage totpStorage = StorageLayer.getTOTPStorage(main); @@ -281,7 +304,13 @@ public static boolean verifyDevice(Main main, String userId, String deviceName, public static void verifyCode(Main main, String userId, String code, boolean allowUnverifiedDevices) throws TotpNotEnabledException, InvalidTotpException, LimitReachedException, - StorageQueryException, StorageTransactionLogicException { + StorageQueryException, StorageTransactionLogicException, FeatureNotEnabledException { + + if (!isTotpEnabled(main)){ + throw new FeatureNotEnabledException( + "TOTP feature is not enabled. Please subscribe to a SuperTokens core license key to enable this feature."); + } + TOTPSQLStorage totpStorage = StorageLayer.getTOTPStorage(main); // Check if the user has any devices: diff --git a/src/main/java/io/supertokens/webserver/api/totp/CreateOrUpdateTotpDeviceAPI.java b/src/main/java/io/supertokens/webserver/api/totp/CreateOrUpdateTotpDeviceAPI.java index bb04f7070..fea6e0725 100644 --- a/src/main/java/io/supertokens/webserver/api/totp/CreateOrUpdateTotpDeviceAPI.java +++ b/src/main/java/io/supertokens/webserver/api/totp/CreateOrUpdateTotpDeviceAPI.java @@ -5,6 +5,7 @@ import com.google.gson.JsonObject; import io.supertokens.Main; +import io.supertokens.featureflag.exceptions.FeatureNotEnabledException; import io.supertokens.pluginInterface.RECIPE_ID; import io.supertokens.pluginInterface.exceptions.StorageQueryException; import io.supertokens.pluginInterface.totp.TOTPDevice; @@ -75,7 +76,7 @@ protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws I } catch (DeviceAlreadyExistsException e) { result.addProperty("status", "DEVICE_ALREADY_EXISTS_ERROR"); super.sendJsonResponse(200, result, resp); - } catch (StorageQueryException | NoSuchAlgorithmException e) { + } catch (StorageQueryException | NoSuchAlgorithmException | FeatureNotEnabledException e) { throw new ServletException(e); } } diff --git a/src/main/java/io/supertokens/webserver/api/totp/VerifyTotpAPI.java b/src/main/java/io/supertokens/webserver/api/totp/VerifyTotpAPI.java index af9400863..d7c684c5b 100644 --- a/src/main/java/io/supertokens/webserver/api/totp/VerifyTotpAPI.java +++ b/src/main/java/io/supertokens/webserver/api/totp/VerifyTotpAPI.java @@ -5,6 +5,7 @@ import com.google.gson.JsonObject; import io.supertokens.Main; +import io.supertokens.featureflag.exceptions.FeatureNotEnabledException; import io.supertokens.pluginInterface.RECIPE_ID; import io.supertokens.pluginInterface.exceptions.StorageQueryException; import io.supertokens.pluginInterface.exceptions.StorageTransactionLogicException; @@ -72,7 +73,7 @@ protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws I result.addProperty("status", "LIMIT_REACHED_ERROR"); result.addProperty("retryAfterMs", e.retryAfterMs); super.sendJsonResponse(200, result, resp); - } catch (StorageQueryException | StorageTransactionLogicException e) { + } catch (StorageQueryException | StorageTransactionLogicException | FeatureNotEnabledException e) { throw new ServletException(e); } } diff --git a/src/test/java/io/supertokens/test/totp/TOTPRecipeTest.java b/src/test/java/io/supertokens/test/totp/TOTPRecipeTest.java index a22dc4d67..da54cc955 100644 --- a/src/test/java/io/supertokens/test/totp/TOTPRecipeTest.java +++ b/src/test/java/io/supertokens/test/totp/TOTPRecipeTest.java @@ -27,6 +27,11 @@ import javax.crypto.spec.SecretKeySpec; +import io.supertokens.featureflag.EE_FEATURES; +import io.supertokens.featureflag.FeatureFlag; +import io.supertokens.featureflag.FeatureFlagTestContent; +import io.supertokens.featureflag.exceptions.InvalidLicenseKeyException; +import io.supertokens.httpRequest.HttpResponseException; import org.apache.commons.codec.binary.Base32; import org.junit.AfterClass; import org.junit.Before; @@ -85,7 +90,9 @@ public TestSetupResult(TOTPStorage storage, TestingProcessManager.TestingProcess } } - public TestSetupResult defaultInit() throws InterruptedException, IOException { + public TestSetupResult defaultInit() + throws InterruptedException, IOException, StorageQueryException, InvalidLicenseKeyException, + HttpResponseException { String[] args = { "../" }; TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); @@ -96,6 +103,8 @@ public TestSetupResult defaultInit() throws InterruptedException, IOException { } TOTPStorage storage = StorageLayer.getTOTPStorage(process.getProcess()); + FeatureFlagTestContent.getInstance(process.main).setKeyValue(FeatureFlagTestContent.ENABLED_FEATURES, new EE_FEATURES[] { EE_FEATURES.TOTP }); + return new TestSetupResult(storage, process); } diff --git a/src/test/java/io/supertokens/test/totp/TOTPStorageTest.java b/src/test/java/io/supertokens/test/totp/TOTPStorageTest.java index 766ba5b88..1a0d3723b 100644 --- a/src/test/java/io/supertokens/test/totp/TOTPStorageTest.java +++ b/src/test/java/io/supertokens/test/totp/TOTPStorageTest.java @@ -3,6 +3,11 @@ import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertThrows; +import io.supertokens.featureflag.EE_FEATURES; +import io.supertokens.featureflag.FeatureFlag; +import io.supertokens.featureflag.FeatureFlagTestContent; +import io.supertokens.featureflag.exceptions.InvalidLicenseKeyException; +import io.supertokens.httpRequest.HttpResponseException; import org.junit.AfterClass; import org.junit.Before; import org.junit.Rule; @@ -27,6 +32,8 @@ import io.supertokens.pluginInterface.totp.exception.UsedCodeAlreadyExistsException; import io.supertokens.pluginInterface.totp.sqlStorage.TOTPSQLStorage; +import java.io.IOException; + public class TOTPStorageTest { public class TestSetupResult { @@ -52,7 +59,9 @@ public void beforeEach() { Utils.reset(); } - public TestSetupResult initSteps() throws InterruptedException { + public TestSetupResult initSteps() + throws InterruptedException, StorageQueryException, InvalidLicenseKeyException, HttpResponseException, + IOException { String[] args = { "../" }; TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); @@ -63,6 +72,8 @@ public TestSetupResult initSteps() throws InterruptedException { } TOTPSQLStorage storage = StorageLayer.getTOTPStorage(process.getProcess()); + FeatureFlagTestContent.getInstance(process.main).setKeyValue(FeatureFlagTestContent.ENABLED_FEATURES, new EE_FEATURES[] { EE_FEATURES.TOTP }); + return new TestSetupResult(storage, process); } diff --git a/src/test/java/io/supertokens/test/totp/TotpLicenseTest.java b/src/test/java/io/supertokens/test/totp/TotpLicenseTest.java new file mode 100644 index 000000000..92ebba597 --- /dev/null +++ b/src/test/java/io/supertokens/test/totp/TotpLicenseTest.java @@ -0,0 +1,168 @@ +/* + * Copyright (c) 2023, VRAI Labs and/or its affiliates. All rights reserved. + * + * This software is licensed under the Apache License, Version 2.0 (the + * "License") as published by the Apache Software Foundation. + * + * You may not use this file except in compliance with the License. You may + * obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.supertokens.test.totp; + +import com.google.gson.JsonObject; +import io.supertokens.Main; +import io.supertokens.ProcessState; +import io.supertokens.featureflag.EE_FEATURES; +import io.supertokens.featureflag.FeatureFlagTestContent; +import io.supertokens.featureflag.exceptions.FeatureNotEnabledException; +import io.supertokens.test.httpRequest.HttpResponseException; +import io.supertokens.pluginInterface.STORAGE_TYPE; +import io.supertokens.pluginInterface.totp.TOTPDevice; +import io.supertokens.pluginInterface.totp.TOTPStorage; +import io.supertokens.storageLayer.StorageLayer; +import io.supertokens.test.TestingProcessManager; +import io.supertokens.test.Utils; +import io.supertokens.test.httpRequest.HttpRequestForTesting; +import io.supertokens.totp.Totp; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TestRule; + + +import static io.supertokens.test.totp.TOTPRecipeTest.generateTotpCode; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertThrows; + +public class TotpLicenseTest { + public final static String OPAQUE_KEY_WITH_TOTP_FEATURE = "pXhNK=nYiEsb6gJEOYP2kIR6M0kn4XLvNqcwT1XbX8xHtm44K-lQfGCbaeN0Ieeza39fxkXr=tiiUU=DXxDH40Y=4FLT4CE-rG1ETjkXxO4yucLpJvw3uSegPayoISGL"; + + @Rule + public TestRule watchman = Utils.getOnFailure(); + + @AfterClass + public static void afterTesting() { + Utils.afterTesting(); + } + + @Before + public void beforeEach() { + Utils.reset(); + } + + public class TestSetupResult { + public TOTPStorage storage; + public TestingProcessManager.TestingProcess process; + + public TestSetupResult(TOTPStorage storage, TestingProcessManager.TestingProcess process) { + this.storage = storage; + this.process = process; + } + } + + public TestSetupResult defaultInit() throws InterruptedException { + String[] args = { "../" }; + + TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + + if (StorageLayer.getStorage(process.getProcess()).getType() != STORAGE_TYPE.SQL) { + assert (false); + } + TOTPStorage storage = StorageLayer.getTOTPStorage(process.getProcess()); + + return new TestSetupResult(storage, process); + } + + @Test + public void testTotpWithoutLicense() throws Exception { + TestSetupResult result = defaultInit(); + Main main = result.process.getProcess(); + + // Create device + assertThrows(FeatureNotEnabledException.class, () -> { + Totp.registerDevice(main, "user", "device1", 1, 30); + }); + // Verify code + assertThrows(FeatureNotEnabledException.class, () -> { + Totp.verifyCode(main, "user", "device1", true); + }); + + // Try to create device via API: + JsonObject body = new JsonObject(); + body.addProperty("userId", "user-id"); + body.addProperty("deviceName", "d1"); + body.addProperty("skew", 0); + body.addProperty("period", 30); + + + HttpResponseException e = assertThrows( + HttpResponseException.class, + () -> { + HttpRequestForTesting.sendJsonPOSTRequest( + result.process.getProcess(), + "", + "http://localhost:3567/recipe/totp/device", + body, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + "totp"); + } + ); + assert e.statusCode == 402; + assert e.getMessage().contains("TOTP feature is not enabled"); + + + // Try to verify code via API: + JsonObject body2 = new JsonObject(); + body2.addProperty("userId", "user-id"); + body2.addProperty("totp", "123456"); + body2.addProperty("allowUnverifiedDevices", true); + + + HttpResponseException e2 = assertThrows( + HttpResponseException.class, + () -> { + HttpRequestForTesting.sendJsonPOSTRequest( + result.process.getProcess(), + "", + "http://localhost:3567/recipe/totp/verify", + body2, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + "totp"); + } + ); + assert e2.statusCode == 402; + assert e2.getMessage().contains("TOTP feature is not enabled"); + } + + + @Test + public void testTotpWithLicense() throws Exception { + TestSetupResult result = defaultInit(); + FeatureFlagTestContent.getInstance(result.process.main).setKeyValue(FeatureFlagTestContent.ENABLED_FEATURES, new EE_FEATURES[] { EE_FEATURES.TOTP }); + + Main main = result.process.getProcess(); + + // Create device + TOTPDevice device = Totp.registerDevice(main, "user", "device1", 1, 30); + // Verify code + String code = generateTotpCode(main, device); + Totp.verifyCode(main, "user", code, true); + } + + +} diff --git a/src/test/java/io/supertokens/test/totp/api/CreateTotpDeviceAPITest.java b/src/test/java/io/supertokens/test/totp/api/CreateTotpDeviceAPITest.java index 154d96bf7..92e844840 100644 --- a/src/test/java/io/supertokens/test/totp/api/CreateTotpDeviceAPITest.java +++ b/src/test/java/io/supertokens/test/totp/api/CreateTotpDeviceAPITest.java @@ -2,12 +2,16 @@ import com.google.gson.JsonObject; import io.supertokens.ProcessState; +import io.supertokens.featureflag.EE_FEATURES; +import io.supertokens.featureflag.FeatureFlag; +import io.supertokens.featureflag.FeatureFlagTestContent; import io.supertokens.test.httpRequest.HttpResponseException; import io.supertokens.pluginInterface.STORAGE_TYPE; import io.supertokens.storageLayer.StorageLayer; import io.supertokens.test.TestingProcessManager; import io.supertokens.test.Utils; import io.supertokens.test.httpRequest.HttpRequestForTesting; +import io.supertokens.test.totp.TotpLicenseTest; import org.junit.AfterClass; import org.junit.Before; import org.junit.Rule; @@ -61,6 +65,7 @@ private void checkResponseErrorContains(Exception ex, String msg) { assertTrue(e.getMessage().contains(msg)); } + @Test public void testApi() throws Exception { String[] args = { "../" }; @@ -68,6 +73,9 @@ public void testApi() throws Exception { TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + FeatureFlag.getInstance(process.main).setLicenseKeyAndSyncFeatures(TotpLicenseTest.OPAQUE_KEY_WITH_TOTP_FEATURE); + FeatureFlagTestContent.getInstance(process.main).setKeyValue(FeatureFlagTestContent.ENABLED_FEATURES, new EE_FEATURES[] { EE_FEATURES.TOTP }); + if (StorageLayer.getStorage(process.getProcess()).getType() != STORAGE_TYPE.SQL) { return; } diff --git a/src/test/java/io/supertokens/test/totp/api/GetTotpDevicesAPITest.java b/src/test/java/io/supertokens/test/totp/api/GetTotpDevicesAPITest.java index 93d749c0f..c21f8a88d 100644 --- a/src/test/java/io/supertokens/test/totp/api/GetTotpDevicesAPITest.java +++ b/src/test/java/io/supertokens/test/totp/api/GetTotpDevicesAPITest.java @@ -3,12 +3,16 @@ import com.google.gson.JsonArray; import com.google.gson.JsonObject; import io.supertokens.ProcessState; +import io.supertokens.featureflag.EE_FEATURES; +import io.supertokens.featureflag.FeatureFlag; +import io.supertokens.featureflag.FeatureFlagTestContent; import io.supertokens.test.httpRequest.HttpResponseException; import io.supertokens.pluginInterface.STORAGE_TYPE; import io.supertokens.storageLayer.StorageLayer; import io.supertokens.test.TestingProcessManager; import io.supertokens.test.Utils; import io.supertokens.test.httpRequest.HttpRequestForTesting; +import io.supertokens.test.totp.TotpLicenseTest; import org.junit.AfterClass; import org.junit.Before; import org.junit.Rule; @@ -73,6 +77,8 @@ public void testApi() throws Exception { TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + FeatureFlagTestContent.getInstance(process.main).setKeyValue(FeatureFlagTestContent.ENABLED_FEATURES, new EE_FEATURES[] { EE_FEATURES.TOTP }); + if (StorageLayer.getStorage(process.getProcess()).getType() != STORAGE_TYPE.SQL) { return; } diff --git a/src/test/java/io/supertokens/test/totp/api/RemoveTotpDeviceAPITest.java b/src/test/java/io/supertokens/test/totp/api/RemoveTotpDeviceAPITest.java index abcc74451..5ef0d1b4c 100644 --- a/src/test/java/io/supertokens/test/totp/api/RemoveTotpDeviceAPITest.java +++ b/src/test/java/io/supertokens/test/totp/api/RemoveTotpDeviceAPITest.java @@ -2,12 +2,16 @@ import com.google.gson.JsonObject; import io.supertokens.ProcessState; +import io.supertokens.featureflag.EE_FEATURES; +import io.supertokens.featureflag.FeatureFlag; +import io.supertokens.featureflag.FeatureFlagTestContent; import io.supertokens.test.httpRequest.HttpResponseException; import io.supertokens.pluginInterface.STORAGE_TYPE; import io.supertokens.storageLayer.StorageLayer; import io.supertokens.test.TestingProcessManager; import io.supertokens.test.Utils; import io.supertokens.test.httpRequest.HttpRequestForTesting; +import io.supertokens.test.totp.TotpLicenseTest; import org.junit.AfterClass; import org.junit.Before; import org.junit.Rule; @@ -72,6 +76,8 @@ public void testApi() throws Exception { return; } + FeatureFlagTestContent.getInstance(process.main).setKeyValue(FeatureFlagTestContent.ENABLED_FEATURES, new EE_FEATURES[] { EE_FEATURES.TOTP }); + // Setup user and devices: JsonObject createDeviceReq = new JsonObject(); createDeviceReq.addProperty("userId", "user-id"); diff --git a/src/test/java/io/supertokens/test/totp/api/TotpUserIdMappingTest.java b/src/test/java/io/supertokens/test/totp/api/TotpUserIdMappingTest.java index 3f9970dbb..d8dd0d974 100644 --- a/src/test/java/io/supertokens/test/totp/api/TotpUserIdMappingTest.java +++ b/src/test/java/io/supertokens/test/totp/api/TotpUserIdMappingTest.java @@ -3,15 +3,18 @@ import com.google.gson.JsonObject; import io.supertokens.ProcessState; import io.supertokens.emailpassword.EmailPassword; +import io.supertokens.featureflag.EE_FEATURES; +import io.supertokens.featureflag.FeatureFlag; +import io.supertokens.featureflag.FeatureFlagTestContent; import io.supertokens.pluginInterface.emailpassword.UserInfo; import io.supertokens.pluginInterface.totp.TOTPDevice; -import io.supertokens.test.httpRequest.HttpResponseException; import io.supertokens.pluginInterface.STORAGE_TYPE; import io.supertokens.storageLayer.StorageLayer; import io.supertokens.test.TestingProcessManager; import io.supertokens.test.Utils; import io.supertokens.test.httpRequest.HttpRequestForTesting; import io.supertokens.test.totp.TOTPRecipeTest; +import io.supertokens.test.totp.TotpLicenseTest; import io.supertokens.useridmapping.UserIdMapping; import static org.junit.Assert.assertNotNull; @@ -47,6 +50,8 @@ public void testExternalUserIdTranslation() throws Exception { return; } + FeatureFlagTestContent.getInstance(process.main).setKeyValue(FeatureFlagTestContent.ENABLED_FEATURES, new EE_FEATURES[] { EE_FEATURES.TOTP }); + JsonObject body = new JsonObject(); UserInfo user = EmailPassword.signUp(process.main, "test@example.com", "testPass123"); diff --git a/src/test/java/io/supertokens/test/totp/api/UpdateTotpDeviceAPITest.java b/src/test/java/io/supertokens/test/totp/api/UpdateTotpDeviceAPITest.java index 551f92793..f021d2804 100644 --- a/src/test/java/io/supertokens/test/totp/api/UpdateTotpDeviceAPITest.java +++ b/src/test/java/io/supertokens/test/totp/api/UpdateTotpDeviceAPITest.java @@ -2,12 +2,16 @@ import com.google.gson.JsonObject; import io.supertokens.ProcessState; +import io.supertokens.featureflag.EE_FEATURES; +import io.supertokens.featureflag.FeatureFlag; +import io.supertokens.featureflag.FeatureFlagTestContent; import io.supertokens.test.httpRequest.HttpResponseException; import io.supertokens.pluginInterface.STORAGE_TYPE; import io.supertokens.storageLayer.StorageLayer; import io.supertokens.test.TestingProcessManager; import io.supertokens.test.Utils; import io.supertokens.test.httpRequest.HttpRequestForTesting; +import io.supertokens.test.totp.TotpLicenseTest; import org.junit.AfterClass; import org.junit.Before; import org.junit.Rule; @@ -72,6 +76,8 @@ public void testApi() throws Exception { return; } + FeatureFlagTestContent.getInstance(process.main).setKeyValue(FeatureFlagTestContent.ENABLED_FEATURES, new EE_FEATURES[] { EE_FEATURES.TOTP }); + // Setup user and devices: JsonObject createDeviceReq = new JsonObject(); createDeviceReq.addProperty("userId", "user-id"); diff --git a/src/test/java/io/supertokens/test/totp/api/VerifyTotpAPITest.java b/src/test/java/io/supertokens/test/totp/api/VerifyTotpAPITest.java index a6ee05aad..ecc0ec0cd 100644 --- a/src/test/java/io/supertokens/test/totp/api/VerifyTotpAPITest.java +++ b/src/test/java/io/supertokens/test/totp/api/VerifyTotpAPITest.java @@ -3,6 +3,9 @@ import com.google.gson.JsonObject; import io.supertokens.ProcessState; +import io.supertokens.featureflag.EE_FEATURES; +import io.supertokens.featureflag.FeatureFlag; +import io.supertokens.featureflag.FeatureFlagTestContent; import io.supertokens.test.httpRequest.HttpResponseException; import io.supertokens.test.totp.TOTPRecipeTest; import io.supertokens.pluginInterface.STORAGE_TYPE; @@ -11,6 +14,7 @@ import io.supertokens.test.TestingProcessManager; import io.supertokens.test.Utils; import io.supertokens.test.httpRequest.HttpRequestForTesting; +import io.supertokens.test.totp.TotpLicenseTest; import org.junit.AfterClass; import org.junit.Before; import org.junit.Rule; @@ -80,6 +84,8 @@ public void testApi() throws Exception { return; } + FeatureFlagTestContent.getInstance(process.main).setKeyValue(FeatureFlagTestContent.ENABLED_FEATURES, new EE_FEATURES[] { EE_FEATURES.TOTP }); + // Setup user and devices: JsonObject createDeviceReq = new JsonObject(); createDeviceReq.addProperty("userId", "user-id"); diff --git a/src/test/java/io/supertokens/test/totp/api/VerifyTotpDeviceAPITest.java b/src/test/java/io/supertokens/test/totp/api/VerifyTotpDeviceAPITest.java index 55da8c14e..7a62b0fbc 100644 --- a/src/test/java/io/supertokens/test/totp/api/VerifyTotpDeviceAPITest.java +++ b/src/test/java/io/supertokens/test/totp/api/VerifyTotpDeviceAPITest.java @@ -3,6 +3,9 @@ import com.google.gson.JsonObject; import io.supertokens.ProcessState; +import io.supertokens.featureflag.EE_FEATURES; +import io.supertokens.featureflag.FeatureFlag; +import io.supertokens.featureflag.FeatureFlagTestContent; import io.supertokens.test.httpRequest.HttpResponseException; import io.supertokens.test.totp.TOTPRecipeTest; import io.supertokens.pluginInterface.STORAGE_TYPE; @@ -11,6 +14,7 @@ import io.supertokens.test.TestingProcessManager; import io.supertokens.test.Utils; import io.supertokens.test.httpRequest.HttpRequestForTesting; +import io.supertokens.test.totp.TotpLicenseTest; import org.junit.AfterClass; import org.junit.Before; import org.junit.Rule; @@ -80,6 +84,8 @@ public void testApi() throws Exception { return; } + FeatureFlagTestContent.getInstance(process.main).setKeyValue(FeatureFlagTestContent.ENABLED_FEATURES, new EE_FEATURES[] { EE_FEATURES.TOTP }); + // Setup user and devices: JsonObject createDeviceReq = new JsonObject(); createDeviceReq.addProperty("userId", "user-id"); From d46fe69aaa0f5905e913755e5d844d00b62f1bc8 Mon Sep 17 00:00:00 2001 From: Joel Coutinho Date: Fri, 24 Mar 2023 16:22:06 +0530 Subject: [PATCH 39/42] updates CDI version info --- coreDriverInterfaceSupported.json | 3 ++- src/main/java/io/supertokens/webserver/WebserverAPI.java | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/coreDriverInterfaceSupported.json b/coreDriverInterfaceSupported.json index a2126e64c..01bb39ccb 100644 --- a/coreDriverInterfaceSupported.json +++ b/coreDriverInterfaceSupported.json @@ -12,6 +12,7 @@ "2.15", "2.16", "2.17", - "2.18" + "2.18", + "2.19" ] } \ No newline at end of file diff --git a/src/main/java/io/supertokens/webserver/WebserverAPI.java b/src/main/java/io/supertokens/webserver/WebserverAPI.java index 6c4a025e8..ece9509e8 100644 --- a/src/main/java/io/supertokens/webserver/WebserverAPI.java +++ b/src/main/java/io/supertokens/webserver/WebserverAPI.java @@ -51,10 +51,11 @@ public abstract class WebserverAPI extends HttpServlet { supportedVersions.add("2.16"); supportedVersions.add("2.17"); supportedVersions.add("2.18"); + supportedVersions.add("2.19"); } public static String getLatestCDIVersion() { - return "2.18"; + return "2.19"; } public WebserverAPI(Main main, String rid) { From 00ccbe6d390d3a82da78514ac0e2c00377950c15 Mon Sep 17 00:00:00 2001 From: Kumar Shivendu Date: Mon, 27 Mar 2023 20:34:50 +0530 Subject: [PATCH 40/42] test: Fix failing tests (#598) * test: Fix failing tests * fixes tests --------- Co-authored-by: rishabhpoddar --- .../java/io/supertokens/inmemorydb/Start.java | 21 +++++++++++++++++- .../useridmapping/UserIdMapping.java | 17 +++++++++++--- .../io/supertokens/test/ConfigTest2_6.java | 13 ----------- .../io/supertokens/test/FeatureFlagTest.java | 6 ++++- .../test/dashboard/DashboardTest.java | 22 ++++++++++++++----- .../supertokens/test/totp/TOTPRecipeTest.java | 2 ++ .../test/totp/api/VerifyTotpAPITest.java | 2 +- .../totp/api/VerifyTotpDeviceAPITest.java | 2 +- .../test/userIdMapping/UserIdMappingTest.java | 5 +++++ 9 files changed, 64 insertions(+), 26 deletions(-) diff --git a/src/main/java/io/supertokens/inmemorydb/Start.java b/src/main/java/io/supertokens/inmemorydb/Start.java index 2cf559eee..b24541294 100644 --- a/src/main/java/io/supertokens/inmemorydb/Start.java +++ b/src/main/java/io/supertokens/inmemorydb/Start.java @@ -22,6 +22,7 @@ import io.supertokens.ResourceDistributor; import io.supertokens.emailverification.EmailVerification; import io.supertokens.emailverification.exception.EmailAlreadyVerifiedException; +import io.supertokens.featureflag.exceptions.FeatureNotEnabledException; import io.supertokens.inmemorydb.config.Config; import io.supertokens.inmemorydb.queries.*; import io.supertokens.pluginInterface.ActiveUsersStorage; @@ -62,6 +63,7 @@ import io.supertokens.pluginInterface.sqlStorage.TransactionConnection; import io.supertokens.pluginInterface.thirdparty.exception.DuplicateThirdPartyUserException; import io.supertokens.pluginInterface.thirdparty.sqlStorage.ThirdPartySQLStorage; +import io.supertokens.pluginInterface.totp.TOTPStorage; import io.supertokens.pluginInterface.useridmapping.UserIdMapping; import io.supertokens.pluginInterface.useridmapping.UserIdMappingStorage; import io.supertokens.pluginInterface.totp.TOTPDevice; @@ -80,6 +82,7 @@ import io.supertokens.pluginInterface.userroles.exception.UnknownRoleException; import io.supertokens.pluginInterface.userroles.sqlStorage.UserRolesSQLStorage; import io.supertokens.session.Session; +import io.supertokens.totp.Totp; import io.supertokens.usermetadata.UserMetadata; import io.supertokens.userroles.UserRoles; import org.jetbrains.annotations.NotNull; @@ -1620,6 +1623,14 @@ public boolean isUserIdBeingUsedInNonAuthRecipe(String className, String userId) } } else if (className.equals(JWTRecipeStorage.class.getName())) { return false; + } else if (className.equals(TOTPStorage.class.getName())) { + try{ + TOTPDevice[] devices = TOTPQueries.getDevices(this, userId); + return devices.length > 0; + } + catch (SQLException e){ + throw new StorageQueryException(e); + } } else { throw new IllegalStateException("ClassName: " + className + " is not part of NonAuthRecipeStorage"); } @@ -1661,7 +1672,15 @@ public void addInfoToNonAuthRecipesBasedOnUserId(String className, String userId } catch (StorageTransactionLogicException e) { throw new StorageQueryException(e); } - } else if (className.equals(JWTRecipeStorage.class.getName())) { + } else if (className.equals(TOTPStorage.class.getName())) { + try { + Totp.registerDevice(this.main, userId, "testDevice", 0, 30); + } + catch (DeviceAlreadyExistsException | NoSuchAlgorithmException | FeatureNotEnabledException e) { + throw new StorageQueryException(e); + } + } + else if (className.equals(JWTRecipeStorage.class.getName())) { /* Since JWT recipe tables do not store userId we do not add any data to them */ } else { throw new IllegalStateException("ClassName: " + className + " is not part of NonAuthRecipeStorage"); diff --git a/src/main/java/io/supertokens/useridmapping/UserIdMapping.java b/src/main/java/io/supertokens/useridmapping/UserIdMapping.java index 61faddad4..b8a12e476 100644 --- a/src/main/java/io/supertokens/useridmapping/UserIdMapping.java +++ b/src/main/java/io/supertokens/useridmapping/UserIdMapping.java @@ -22,6 +22,7 @@ import io.supertokens.pluginInterface.exceptions.StorageQueryException; import io.supertokens.pluginInterface.jwt.JWTRecipeStorage; import io.supertokens.pluginInterface.session.SessionStorage; +import io.supertokens.pluginInterface.totp.TOTPStorage; import io.supertokens.pluginInterface.useridmapping.UserIdMappingStorage; import io.supertokens.pluginInterface.useridmapping.exception.UnknownSuperTokensUserIdException; import io.supertokens.pluginInterface.useridmapping.exception.UserIdMappingAlreadyExistsException; @@ -38,7 +39,8 @@ public class UserIdMapping { public static void createUserIdMapping(Main main, String superTokensUserId, String externalUserId, - String externalUserIdInfo, boolean force) throws UnknownSuperTokensUserIdException, + String externalUserIdInfo, boolean force) + throws UnknownSuperTokensUserIdException, UserIdMappingAlreadyExistsException, StorageQueryException, ServletException { // if a userIdMapping is created with force, then we skip the following checks if (!force) { @@ -64,7 +66,8 @@ public static void createUserIdMapping(Main main, String superTokensUserId, Stri } public static io.supertokens.pluginInterface.useridmapping.UserIdMapping getUserIdMapping(Main main, String userId, - UserIdType userIdType) throws StorageQueryException { + UserIdType userIdType) + throws StorageQueryException { UserIdMappingStorage storage = StorageLayer.getUserIdMappingStorage(main); if (userIdType == UserIdType.SUPERTOKENS) { @@ -139,7 +142,8 @@ public static boolean deleteUserIdMapping(Main main, String userId, UserIdType u } public static boolean updateOrDeleteExternalUserIdInfo(Main main, String userId, UserIdType userIdType, - @Nullable String externalUserIdInfo) throws StorageQueryException { + @Nullable String externalUserIdInfo) + throws StorageQueryException { UserIdMappingStorage storage = StorageLayer.getUserIdMappingStorage(main); if (userIdType == UserIdType.SUPERTOKENS) { @@ -192,6 +196,13 @@ private static void assertThatUserIdIsNotBeingUsedInNonAuthRecipes(Main main, St new WebserverAPI.BadRequestException("UserId is already in use in EmailVerification recipe")); } } + { + if (StorageLayer.getStorage(main).isUserIdBeingUsedInNonAuthRecipe(TOTPStorage.class.getName(), + userId)) { + throw new ServletException( + new WebserverAPI.BadRequestException("UserId is already in use in TOTP recipe")); + } + } { if (StorageLayer.getStorage(main).isUserIdBeingUsedInNonAuthRecipe(JWTRecipeStorage.class.getName(), userId)) { diff --git a/src/test/java/io/supertokens/test/ConfigTest2_6.java b/src/test/java/io/supertokens/test/ConfigTest2_6.java index e97f07b03..14db92a86 100644 --- a/src/test/java/io/supertokens/test/ConfigTest2_6.java +++ b/src/test/java/io/supertokens/test/ConfigTest2_6.java @@ -163,19 +163,6 @@ public void testInvalidTotpConfigThrowsExpectedError() throws Exception { process.kill(); assertNotNull(process.checkOrWaitForEvent(PROCESS_STATE.STOPPED)); - - Utils.reset(); - - Utils.setValueInConfig("totp_invalid_code_expiry_sec", "0"); - process = TestingProcessManager.start(args); - - e = process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.INIT_FAILURE); - assertNotNull(e); - assertEquals(e.exception.getMessage(), - "'totp_invalid_code_expiry_sec' must be > 0"); - - process.kill(); - assertNotNull(process.checkOrWaitForEvent(PROCESS_STATE.STOPPED)); } private String getConfigFileLocation(Main main) { diff --git a/src/test/java/io/supertokens/test/FeatureFlagTest.java b/src/test/java/io/supertokens/test/FeatureFlagTest.java index 926595b44..7dc0d45ab 100644 --- a/src/test/java/io/supertokens/test/FeatureFlagTest.java +++ b/src/test/java/io/supertokens/test/FeatureFlagTest.java @@ -63,7 +63,11 @@ public void noLicenseKeyShouldHaveEmptyFeatureFlag() throws InterruptedException } catch (NoLicenseKeyFoundException ignored) { } - Assert.assertEquals(FeatureFlag.getInstance(process.getProcess()).getPaidFeatureStats().entrySet().size(), 0); + JsonObject stats = FeatureFlag.getInstance(process.getProcess()).getPaidFeatureStats(); + Assert.assertEquals(stats.entrySet().size(), 1); + Assert.assertEquals(stats.get("maus").getAsJsonArray().size(), 30); + Assert.assertEquals(stats.get("maus").getAsJsonArray().get(0).getAsInt(), 0); + Assert.assertEquals(stats.get("maus").getAsJsonArray().get(29).getAsInt(), 0); process.kill(); assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); diff --git a/src/test/java/io/supertokens/test/dashboard/DashboardTest.java b/src/test/java/io/supertokens/test/dashboard/DashboardTest.java index d1dceafcc..a394fe7c9 100644 --- a/src/test/java/io/supertokens/test/dashboard/DashboardTest.java +++ b/src/test/java/io/supertokens/test/dashboard/DashboardTest.java @@ -285,7 +285,12 @@ public void testDashboardUsageStats() throws Exception { assertEquals(3, response.entrySet().size()); assertEquals("OK", response.get("status").getAsString()); assertEquals(0, response.get("features").getAsJsonArray().size()); - assertEquals(0, response.get("usageStats").getAsJsonObject().entrySet().size()); + JsonObject usageStats = response.get("usageStats").getAsJsonObject(); + JsonArray mauArr = usageStats.get("maus").getAsJsonArray(); + assertEquals(1, usageStats.entrySet().size()); + assertEquals(30, mauArr.size()); + assertEquals(0, mauArr.get(0).getAsInt()); + assertEquals(0, mauArr.get(29).getAsInt()); } // create a dashboard user @@ -298,7 +303,12 @@ public void testDashboardUsageStats() throws Exception { assertEquals(3, response.entrySet().size()); assertEquals("OK", response.get("status").getAsString()); assertEquals(0, response.get("features").getAsJsonArray().size()); - assertEquals(0, response.get("usageStats").getAsJsonObject().entrySet().size()); + JsonObject usageStats = response.get("usageStats").getAsJsonObject(); + JsonArray mauArr = usageStats.get("maus").getAsJsonArray(); + assertEquals(1, usageStats.entrySet().size()); + assertEquals(30, mauArr.size()); + assertEquals(0, mauArr.get(0).getAsInt()); + assertEquals(0, mauArr.get(29).getAsInt()); } // enable the dashboard feature @@ -315,9 +325,9 @@ public void testDashboardUsageStats() throws Exception { assertEquals(1, featuresArray.size()); assertEquals(EE_FEATURES.DASHBOARD_LOGIN.toString(), featuresArray.get(0).getAsString()); JsonObject usageStats = response.get("usageStats").getAsJsonObject(); - assertEquals(1, - usageStats.entrySet().size()); JsonObject dashboardLoginObject = usageStats.get("dashboard_login").getAsJsonObject(); + assertEquals(2, usageStats.entrySet().size()); + assertEquals(30, usageStats.get("maus").getAsJsonArray().size()); assertEquals(1, dashboardLoginObject.entrySet().size()); assertEquals(1, dashboardLoginObject.get("user_count").getAsInt()); } @@ -338,9 +348,9 @@ public void testDashboardUsageStats() throws Exception { assertEquals(1, featuresArray.size()); assertEquals(EE_FEATURES.DASHBOARD_LOGIN.toString(), featuresArray.get(0).getAsString()); JsonObject usageStats = response.get("usageStats").getAsJsonObject(); - assertEquals(1, - usageStats.entrySet().size()); JsonObject dashboardLoginObject = usageStats.get("dashboard_login").getAsJsonObject(); + assertEquals(2, usageStats.entrySet().size()); + assertEquals(30, usageStats.get("maus").getAsJsonArray().size()); assertEquals(1, dashboardLoginObject.entrySet().size()); assertEquals(4, dashboardLoginObject.get("user_count").getAsInt()); } diff --git a/src/test/java/io/supertokens/test/totp/TOTPRecipeTest.java b/src/test/java/io/supertokens/test/totp/TOTPRecipeTest.java index da54cc955..61551ebe7 100644 --- a/src/test/java/io/supertokens/test/totp/TOTPRecipeTest.java +++ b/src/test/java/io/supertokens/test/totp/TOTPRecipeTest.java @@ -281,6 +281,8 @@ public void rateLimitCooldownTest() throws Exception { assert (false); } + FeatureFlagTestContent.getInstance(process.main).setKeyValue(FeatureFlagTestContent.ENABLED_FEATURES, new EE_FEATURES[] { EE_FEATURES.TOTP }); + Main main = process.getProcess(); // Create device diff --git a/src/test/java/io/supertokens/test/totp/api/VerifyTotpAPITest.java b/src/test/java/io/supertokens/test/totp/api/VerifyTotpAPITest.java index ecc0ec0cd..5158f8046 100644 --- a/src/test/java/io/supertokens/test/totp/api/VerifyTotpAPITest.java +++ b/src/test/java/io/supertokens/test/totp/api/VerifyTotpAPITest.java @@ -172,7 +172,7 @@ public void testApi() throws Exception { Utils.getCdiVersionLatestForTests(), "totp"); assert res3.get("status").getAsString().equals("LIMIT_REACHED_ERROR"); - assert res3.get("retryAfterSec") != null; + assert res3.get("retryAfterMs") != null; // wait for cooldown to end (1s) Thread.sleep(1000); diff --git a/src/test/java/io/supertokens/test/totp/api/VerifyTotpDeviceAPITest.java b/src/test/java/io/supertokens/test/totp/api/VerifyTotpDeviceAPITest.java index 7a62b0fbc..53703b50b 100644 --- a/src/test/java/io/supertokens/test/totp/api/VerifyTotpDeviceAPITest.java +++ b/src/test/java/io/supertokens/test/totp/api/VerifyTotpDeviceAPITest.java @@ -176,7 +176,7 @@ public void testApi() throws Exception { Utils.getCdiVersionLatestForTests(), "totp"); assert res3.get("status").getAsString().equals("LIMIT_REACHED_ERROR"); - assert res3.get("retryAfterSec") != null; + assert res3.get("retryAfterMs") != null; // wait for cooldown to end (1s) Thread.sleep(1000); diff --git a/src/test/java/io/supertokens/test/userIdMapping/UserIdMappingTest.java b/src/test/java/io/supertokens/test/userIdMapping/UserIdMappingTest.java index cbaff9585..3b0484510 100644 --- a/src/test/java/io/supertokens/test/userIdMapping/UserIdMappingTest.java +++ b/src/test/java/io/supertokens/test/userIdMapping/UserIdMappingTest.java @@ -20,6 +20,8 @@ import io.supertokens.ProcessState; import io.supertokens.authRecipe.AuthRecipe; import io.supertokens.emailpassword.EmailPassword; +import io.supertokens.featureflag.EE_FEATURES; +import io.supertokens.featureflag.FeatureFlagTestContent; import io.supertokens.pluginInterface.STORAGE_TYPE; import io.supertokens.pluginInterface.emailpassword.UserInfo; import io.supertokens.pluginInterface.nonAuthRecipe.NonAuthRecipeStorage; @@ -783,6 +785,9 @@ public void checkThatCreateUserIdMappingHasAllNonAuthRecipeChecks() throws Excep if (StorageLayer.getStorage(process.getProcess()).getType() != STORAGE_TYPE.SQL) { return; } + + FeatureFlagTestContent.getInstance(process.main).setKeyValue(FeatureFlagTestContent.ENABLED_FEATURES, new EE_FEATURES[] { EE_FEATURES.TOTP }); + // this list contains the package names for recipes which dont use UserIdMapping ArrayList nonAuthRecipesWhichDontNeedUserIdMapping = new ArrayList<>( List.of("io.supertokens.pluginInterface.jwt.JWTRecipeStorage")); From f63a46211f4879f4ef2b8a873105645d235c3231 Mon Sep 17 00:00:00 2001 From: Kumar Shivendu Date: Mon, 27 Mar 2023 21:37:13 +0530 Subject: [PATCH 41/42] feat: Add new API and tests for counting active users (#596) * feat: Add new API and tests for counting active users * chores: Update CHANGELOG to mention new active user count API * test: Add bad input tests for active users count API * chores: Update the http method for active user count API in CHANGELOG --- CHANGELOG.md | 1 + .../io/supertokens/webserver/InputParser.java | 15 ++ .../io/supertokens/webserver/Webserver.java | 1 + .../api/core/ActiveUsersCountAPI.java | 62 +++++++ .../io/supertokens/test/ActiveUsersTest.java | 156 +++++++++++++++++- 5 files changed, 234 insertions(+), 1 deletion(-) create mode 100644 src/main/java/io/supertokens/webserver/api/core/ActiveUsersCountAPI.java diff --git a/CHANGELOG.md b/CHANGELOG.md index d14249006..2981274bf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). - `totp_used_codes` that stores used codes for each user. This is to implement rate limiting and prevent replay attacks. ### New APIs: +- `GET /users/count/active` to fetch the number of active users after the given timestamp. - `POST /recipe/totp/device` to create a new device as well as the user if it doesn't exist. - `POST /recipe/totp/device/verify` to verify a device. This is to ensure that the user has access to the device. - `POST /recipe/totp/verify` to verify a code and continue the login flow. diff --git a/src/main/java/io/supertokens/webserver/InputParser.java b/src/main/java/io/supertokens/webserver/InputParser.java index 6581b4c9f..98a1db6d8 100644 --- a/src/main/java/io/supertokens/webserver/InputParser.java +++ b/src/main/java/io/supertokens/webserver/InputParser.java @@ -86,6 +86,21 @@ public static Integer getIntQueryParamOrThrowError(HttpServletRequest request, S } } + public static Long getLongQueryParamOrThrowError(HttpServletRequest request, String fieldName, boolean nullable) + throws ServletException { + String key = getQueryParamOrThrowError(request, fieldName, nullable); + if (key == null && nullable) { + return null; + } + try { + assert key != null; + return Long.parseLong(key); + } catch (Exception e) { + throw new ServletException(new WebserverAPI.BadRequestException( + "Field name '" + fieldName + "' must be a long in the GET request")); + } + } + public static JsonObject parseJsonObjectOrThrowError(JsonObject element, String fieldName, boolean nullable) throws ServletException { try { diff --git a/src/main/java/io/supertokens/webserver/Webserver.java b/src/main/java/io/supertokens/webserver/Webserver.java index a0a6cfaf9..216996de2 100644 --- a/src/main/java/io/supertokens/webserver/Webserver.java +++ b/src/main/java/io/supertokens/webserver/Webserver.java @@ -238,6 +238,7 @@ private void setupRoutes() throws Exception { addAPI(new ConsumeCodeAPI(main)); addAPI(new TelemetryAPI(main)); addAPI(new UsersCountAPI(main)); + addAPI(new ActiveUsersCountAPI(main)); addAPI(new UsersAPI(main)); addAPI(new DeleteUserAPI(main)); addAPI(new RevokeAllTokensForUserAPI(main)); diff --git a/src/main/java/io/supertokens/webserver/api/core/ActiveUsersCountAPI.java b/src/main/java/io/supertokens/webserver/api/core/ActiveUsersCountAPI.java new file mode 100644 index 000000000..73efb131a --- /dev/null +++ b/src/main/java/io/supertokens/webserver/api/core/ActiveUsersCountAPI.java @@ -0,0 +1,62 @@ +/* + * Copyright (c) 2023, VRAI Labs and/or its affiliates. All rights reserved. + * + * This software is licensed under the Apache License, Version 2.0 (the + * "License") as published by the Apache Software Foundation. + * + * You may not use this file except in compliance with the License. You may + * obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.supertokens.webserver.api.core; + +import com.google.gson.JsonObject; +import io.supertokens.ActiveUsers; +import io.supertokens.Main; +import io.supertokens.pluginInterface.exceptions.StorageQueryException; +import io.supertokens.webserver.InputParser; +import io.supertokens.webserver.WebserverAPI; +import jakarta.servlet.ServletException; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; + +import java.io.IOException; + +public class ActiveUsersCountAPI extends WebserverAPI { + private static final long serialVersionUID = -2225750492558064634L; + + public ActiveUsersCountAPI(Main main) { + super(main, ""); + } + + @Override + public String getPath() { + return "/users/count/active"; + } + + @Override + protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws IOException, ServletException { + Long sinceTimestamp = InputParser.getLongQueryParamOrThrowError(req, "since", false); + + if (sinceTimestamp < 0) { + throw new ServletException(new BadRequestException("'since' query parameter must be >= 0")); + } + + try { + int count = ActiveUsers.countUsersActiveSince(main, sinceTimestamp); + JsonObject result = new JsonObject(); + result.addProperty("status", "OK"); + result.addProperty("count", count); + super.sendJsonResponse(200, result, resp); + } catch (StorageQueryException e) { + throw new ServletException(e); + } + } +} + diff --git a/src/test/java/io/supertokens/test/ActiveUsersTest.java b/src/test/java/io/supertokens/test/ActiveUsersTest.java index 2b27708bd..0a6856b04 100644 --- a/src/test/java/io/supertokens/test/ActiveUsersTest.java +++ b/src/test/java/io/supertokens/test/ActiveUsersTest.java @@ -1,7 +1,12 @@ package io.supertokens.test; import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertThrows; +import com.fasterxml.jackson.databind.jsonFormatVisitors.JsonBooleanFormatVisitor; +import com.google.gson.JsonObject; +import io.supertokens.test.httpRequest.HttpRequestForTesting; +import io.supertokens.test.httpRequest.HttpResponseException; import org.junit.AfterClass; import org.junit.Before; import org.junit.Rule; @@ -14,6 +19,8 @@ import io.supertokens.Main; import io.supertokens.ProcessState; +import java.util.HashMap; + public class ActiveUsersTest { @Rule @@ -54,8 +61,155 @@ public void updateAndCountUserLastActiveTest() throws Exception { ActiveUsers.updateLastActive(main, "user1"); - assert ActiveUsers.countUsersActiveSince(main, now) == 2; // user1 just got updated assert ActiveUsers.countUsersActiveSince(main, now2) == 1; // only user1 is counted + assert ActiveUsers.countUsersActiveSince(main, now) == 2; // user1 and user2 are counted + } + + @Test + public void activeUserCountAPITest() throws Exception { + String[] args = { "../" }; + + TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + + if (StorageLayer.getStorage(process.getProcess()).getType() != STORAGE_TYPE.SQL) { + assert (false); + } + + Main main = process.getProcess(); + long now = System.currentTimeMillis(); + + HashMap params = new HashMap<>(); + + HttpResponseException e = + assertThrows( + HttpResponseException.class, + () -> { + HttpRequestForTesting.sendGETRequest( + process.getProcess(), + "", + "http://localhost:3567/users/count/active", + params, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + ""); + } + ); + + assert e.statusCode == 400; + assert e.getMessage().contains("Field name 'since' is missing in GET request"); + + params.put("since", "not a number"); + e = + assertThrows( + HttpResponseException.class, + () -> { + HttpRequestForTesting.sendGETRequest( + process.getProcess(), + "", + "http://localhost:3567/users/count/active", + params, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + ""); + } + ); + + assert e.statusCode == 400; + assert e.getMessage().contains("Field name 'since' must be a long in the GET request"); + + params.put("since", "-1"); + e = + assertThrows( + HttpResponseException.class, + () -> { + HttpRequestForTesting.sendGETRequest( + process.getProcess(), + "", + "http://localhost:3567/users/count/active", + params, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + ""); + } + ); + + assert e.statusCode == 400; + assert e.getMessage().contains("'since' query parameter must be >= 0"); + + + params.put("since", Long.toString(now)); + + JsonObject res = HttpRequestForTesting.sendGETRequest( + process.getProcess(), + "", + "http://localhost:3567/users/count/active", + params, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + ""); + + assert res.get("status").getAsString().equals("OK"); + assert res.get("count").getAsInt() == 0; + + ActiveUsers.updateLastActive(main, "user1"); + ActiveUsers.updateLastActive(main, "user2"); + + res = HttpRequestForTesting.sendGETRequest( + process.getProcess(), + "", + "http://localhost:3567/users/count/active", + params, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + ""); + + assert res.get("status").getAsString().equals("OK"); + assert res.get("count").getAsInt() == 2; + + long now2 = System.currentTimeMillis(); + + ActiveUsers.updateLastActive(main, "user1"); + + params.put("since", Long.toString(now2)); + res = HttpRequestForTesting.sendGETRequest( + process.getProcess(), + "", + "http://localhost:3567/users/count/active", + params, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + ""); + + assert res.get("status").getAsString().equals("OK"); + assert res.get("count").getAsInt() == 1; + + params.put("since", Long.toString(now)); + res = HttpRequestForTesting.sendGETRequest( + process.getProcess(), + "", + "http://localhost:3567/users/count/active", + params, + 1000, + 1000, + null, + Utils.getCdiVersionLatestForTests(), + ""); + + assert res.get("status").getAsString().equals("OK"); + assert res.get("count").getAsInt() == 2; } } From 90c9af1031f16f0db162f66432c884fbfc2d0fd9 Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Mon, 27 Mar 2023 21:46:37 +0530 Subject: [PATCH 42/42] fixes tests and review comments --- CHANGELOG.md | 13 ++- .../api/session/RefreshSessionAPI.java | 26 ++--- .../webserver/api/session/SessionAPI.java | 30 +++--- .../api/session/SessionRemoveAPI.java | 27 ++--- .../io/supertokens/test/ActiveUsersTest.java | 32 +++--- .../io/supertokens/test/StorageLayerTest.java | 5 +- .../supertokens/test/totp/TOTPRecipeTest.java | 99 ++++++++++++------- .../test/totp/TOTPStorageTest.java | 84 ++++++++++------ .../test/totp/TotpLicenseTest.java | 19 ++-- 9 files changed, 203 insertions(+), 132 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2981274bf..17cfb26ae 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,17 +10,22 @@ to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). - Add TOTP recipe ### Database changes: + - Add new tables for TOTP recipe: - - `totp_users` that stores the users that have enabled TOTP - - `totp_user_devices` that stores devices (each device has its own secret) for each user - - `totp_used_codes` that stores used codes for each user. This is to implement rate limiting and prevent replay attacks. + - `totp_users` that stores the users that have enabled TOTP + - `totp_user_devices` that stores devices (each device has its own secret) for each user + - `totp_used_codes` that stores used codes for each user. This is to implement rate limiting and prevent replay + attacks. + - `user_last_active` that stores the last active time for each user. ### New APIs: + - `GET /users/count/active` to fetch the number of active users after the given timestamp. - `POST /recipe/totp/device` to create a new device as well as the user if it doesn't exist. - `POST /recipe/totp/device/verify` to verify a device. This is to ensure that the user has access to the device. - `POST /recipe/totp/verify` to verify a code and continue the login flow. -- `PUT /recipe/totp/device` to update the name of a device. Name is just a string that the user can set to identify the device. +- `PUT /recipe/totp/device` to update the name of a device. Name is just a string that the user can set to identify the + device. - `GET /recipe/totp/device/list` to get all devices for a user. - `POST /recipe/totp/device/remove` to remove a device. If the user has no more devices, the user is also removed. diff --git a/src/main/java/io/supertokens/webserver/api/session/RefreshSessionAPI.java b/src/main/java/io/supertokens/webserver/api/session/RefreshSessionAPI.java index c1220aa8a..20cf5001f 100644 --- a/src/main/java/io/supertokens/webserver/api/session/RefreshSessionAPI.java +++ b/src/main/java/io/supertokens/webserver/api/session/RefreshSessionAPI.java @@ -17,26 +17,27 @@ package io.supertokens.webserver.api.session; import com.google.gson.JsonObject; - import io.supertokens.ActiveUsers; import io.supertokens.Main; import io.supertokens.exceptions.TokenTheftDetectedException; import io.supertokens.exceptions.UnauthorisedException; import io.supertokens.output.Logging; import io.supertokens.pluginInterface.RECIPE_ID; +import io.supertokens.pluginInterface.STORAGE_TYPE; import io.supertokens.pluginInterface.exceptions.StorageQueryException; import io.supertokens.pluginInterface.exceptions.StorageTransactionLogicException; import io.supertokens.pluginInterface.useridmapping.UserIdMapping; -import io.supertokens.useridmapping.UserIdType; import io.supertokens.session.Session; import io.supertokens.session.info.SessionInformationHolder; +import io.supertokens.storageLayer.StorageLayer; +import io.supertokens.useridmapping.UserIdType; import io.supertokens.utils.Utils; import io.supertokens.webserver.InputParser; import io.supertokens.webserver.WebserverAPI; - import jakarta.servlet.ServletException; import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletResponse; + import java.io.IOException; public class RefreshSessionAPI extends WebserverAPI { @@ -64,15 +65,18 @@ protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws I SessionInformationHolder sessionInfo = Session.refreshSession(main, refreshToken, antiCsrfToken, enableAntiCsrf); - try{ - UserIdMapping userIdMapping = io.supertokens.useridmapping.UserIdMapping.getUserIdMapping(super.main, - sessionInfo.session.userId, UserIdType.ANY); - if (userIdMapping != null) { - ActiveUsers.updateLastActive(main, userIdMapping.superTokensUserId); - } else { - ActiveUsers.updateLastActive(main, sessionInfo.session.userId); + if (StorageLayer.getStorage(main).getType() == STORAGE_TYPE.SQL) { + try { + UserIdMapping userIdMapping = io.supertokens.useridmapping.UserIdMapping.getUserIdMapping( + super.main, + sessionInfo.session.userId, UserIdType.ANY); + if (userIdMapping != null) { + ActiveUsers.updateLastActive(main, userIdMapping.superTokensUserId); + } else { + ActiveUsers.updateLastActive(main, sessionInfo.session.userId); + } + } catch (StorageQueryException ignored) { } - } catch (StorageQueryException ignored){ } diff --git a/src/main/java/io/supertokens/webserver/api/session/SessionAPI.java b/src/main/java/io/supertokens/webserver/api/session/SessionAPI.java index fb297a8ff..6a4273ec8 100644 --- a/src/main/java/io/supertokens/webserver/api/session/SessionAPI.java +++ b/src/main/java/io/supertokens/webserver/api/session/SessionAPI.java @@ -19,31 +19,32 @@ import com.google.gson.Gson; import com.google.gson.JsonArray; import com.google.gson.JsonObject; - import io.supertokens.ActiveUsers; import io.supertokens.Main; import io.supertokens.exceptions.UnauthorisedException; import io.supertokens.output.Logging; import io.supertokens.pluginInterface.RECIPE_ID; +import io.supertokens.pluginInterface.STORAGE_TYPE; import io.supertokens.pluginInterface.exceptions.StorageQueryException; import io.supertokens.pluginInterface.exceptions.StorageTransactionLogicException; import io.supertokens.pluginInterface.session.SessionInfo; import io.supertokens.pluginInterface.useridmapping.UserIdMapping; -import io.supertokens.useridmapping.UserIdType; import io.supertokens.session.Session; import io.supertokens.session.accessToken.AccessTokenSigningKey; import io.supertokens.session.accessToken.AccessTokenSigningKey.KeyInfo; import io.supertokens.session.info.SessionInformationHolder; +import io.supertokens.storageLayer.StorageLayer; +import io.supertokens.useridmapping.UserIdType; import io.supertokens.utils.Utils; import io.supertokens.webserver.InputParser; import io.supertokens.webserver.WebserverAPI; +import jakarta.servlet.ServletException; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; import javax.crypto.BadPaddingException; import javax.crypto.IllegalBlockSizeException; import javax.crypto.NoSuchPaddingException; -import jakarta.servlet.ServletException; -import jakarta.servlet.http.HttpServletRequest; -import jakarta.servlet.http.HttpServletResponse; import java.io.IOException; import java.security.InvalidAlgorithmParameterException; import java.security.InvalidKeyException; @@ -80,15 +81,18 @@ protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws I SessionInformationHolder sessionInfo = Session.createNewSession(main, userId, userDataInJWT, userDataInDatabase, enableAntiCsrf); - try { - UserIdMapping userIdMapping = io.supertokens.useridmapping.UserIdMapping.getUserIdMapping(super.main, - sessionInfo.session.userId, UserIdType.ANY); - if (userIdMapping != null) { - ActiveUsers.updateLastActive(main, userIdMapping.superTokensUserId); - } else { - ActiveUsers.updateLastActive(main, sessionInfo.session.userId); + if (StorageLayer.getStorage(main).getType() == STORAGE_TYPE.SQL) { + try { + UserIdMapping userIdMapping = io.supertokens.useridmapping.UserIdMapping.getUserIdMapping( + super.main, + sessionInfo.session.userId, UserIdType.ANY); + if (userIdMapping != null) { + ActiveUsers.updateLastActive(main, userIdMapping.superTokensUserId); + } else { + ActiveUsers.updateLastActive(main, sessionInfo.session.userId); + } + } catch (StorageQueryException ignored) { } - } catch (StorageQueryException ignored) { } JsonObject result = sessionInfo.toJsonObject(); diff --git a/src/main/java/io/supertokens/webserver/api/session/SessionRemoveAPI.java b/src/main/java/io/supertokens/webserver/api/session/SessionRemoveAPI.java index 92a2d77a3..875034cc0 100644 --- a/src/main/java/io/supertokens/webserver/api/session/SessionRemoveAPI.java +++ b/src/main/java/io/supertokens/webserver/api/session/SessionRemoveAPI.java @@ -19,20 +19,21 @@ import com.google.gson.JsonArray; import com.google.gson.JsonObject; import com.google.gson.JsonPrimitive; - import io.supertokens.ActiveUsers; import io.supertokens.Main; import io.supertokens.pluginInterface.RECIPE_ID; +import io.supertokens.pluginInterface.STORAGE_TYPE; import io.supertokens.pluginInterface.exceptions.StorageQueryException; import io.supertokens.pluginInterface.useridmapping.UserIdMapping; -import io.supertokens.useridmapping.UserIdType; import io.supertokens.session.Session; +import io.supertokens.storageLayer.StorageLayer; +import io.supertokens.useridmapping.UserIdType; import io.supertokens.webserver.InputParser; import io.supertokens.webserver.WebserverAPI; - import jakarta.servlet.ServletException; import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletResponse; + import java.io.IOException; public class SessionRemoveAPI extends WebserverAPI { @@ -79,16 +80,18 @@ protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws I try { String[] sessionHandlesRevoked = Session.revokeAllSessionsForUser(main, userId); - try { - UserIdMapping userIdMapping = io.supertokens.useridmapping.UserIdMapping.getUserIdMapping( - super.main, - userId, UserIdType.ANY); - if (userIdMapping != null) { - ActiveUsers.updateLastActive(main, userIdMapping.superTokensUserId); - } else { - ActiveUsers.updateLastActive(main, userId); + if (StorageLayer.getStorage(main).getType() == STORAGE_TYPE.SQL) { + try { + UserIdMapping userIdMapping = io.supertokens.useridmapping.UserIdMapping.getUserIdMapping( + super.main, + userId, UserIdType.ANY); + if (userIdMapping != null) { + ActiveUsers.updateLastActive(main, userIdMapping.superTokensUserId); + } else { + ActiveUsers.updateLastActive(main, userId); + } + } catch (StorageQueryException ignored) { } - } catch (StorageQueryException ignored) { } JsonObject result = new JsonObject(); diff --git a/src/test/java/io/supertokens/test/ActiveUsersTest.java b/src/test/java/io/supertokens/test/ActiveUsersTest.java index 0a6856b04..d0ff48460 100644 --- a/src/test/java/io/supertokens/test/ActiveUsersTest.java +++ b/src/test/java/io/supertokens/test/ActiveUsersTest.java @@ -1,10 +1,11 @@ package io.supertokens.test; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertThrows; - -import com.fasterxml.jackson.databind.jsonFormatVisitors.JsonBooleanFormatVisitor; import com.google.gson.JsonObject; +import io.supertokens.ActiveUsers; +import io.supertokens.Main; +import io.supertokens.ProcessState; +import io.supertokens.pluginInterface.STORAGE_TYPE; +import io.supertokens.storageLayer.StorageLayer; import io.supertokens.test.httpRequest.HttpRequestForTesting; import io.supertokens.test.httpRequest.HttpResponseException; import org.junit.AfterClass; @@ -13,14 +14,11 @@ import org.junit.Test; import org.junit.rules.TestRule; -import io.supertokens.pluginInterface.STORAGE_TYPE; -import io.supertokens.storageLayer.StorageLayer; -import io.supertokens.ActiveUsers; -import io.supertokens.Main; -import io.supertokens.ProcessState; - import java.util.HashMap; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertThrows; + public class ActiveUsersTest { @Rule @@ -38,13 +36,13 @@ public void beforeEach() { @Test public void updateAndCountUserLastActiveTest() throws Exception { - String[] args = { "../" }; + String[] args = {"../"}; TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); if (StorageLayer.getStorage(process.getProcess()).getType() != STORAGE_TYPE.SQL) { - assert (false); + return; } Main main = process.getProcess(); @@ -67,13 +65,13 @@ public void updateAndCountUserLastActiveTest() throws Exception { @Test public void activeUserCountAPITest() throws Exception { - String[] args = { "../" }; + String[] args = {"../"}; TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); if (StorageLayer.getStorage(process.getProcess()).getType() != STORAGE_TYPE.SQL) { - assert (false); + return; } Main main = process.getProcess(); @@ -81,7 +79,7 @@ public void activeUserCountAPITest() throws Exception { HashMap params = new HashMap<>(); - HttpResponseException e = + HttpResponseException e = assertThrows( HttpResponseException.class, () -> { @@ -102,7 +100,7 @@ public void activeUserCountAPITest() throws Exception { assert e.getMessage().contains("Field name 'since' is missing in GET request"); params.put("since", "not a number"); - e = + e = assertThrows( HttpResponseException.class, () -> { @@ -123,7 +121,7 @@ public void activeUserCountAPITest() throws Exception { assert e.getMessage().contains("Field name 'since' must be a long in the GET request"); params.put("since", "-1"); - e = + e = assertThrows( HttpResponseException.class, () -> { diff --git a/src/test/java/io/supertokens/test/StorageLayerTest.java b/src/test/java/io/supertokens/test/StorageLayerTest.java index c9896d8e9..f90a79850 100644 --- a/src/test/java/io/supertokens/test/StorageLayerTest.java +++ b/src/test/java/io/supertokens/test/StorageLayerTest.java @@ -12,7 +12,6 @@ import io.supertokens.pluginInterface.totp.exception.UsedCodeAlreadyExistsException; import io.supertokens.pluginInterface.totp.sqlStorage.TOTPSQLStorage; import io.supertokens.storageLayer.StorageLayer; - import org.junit.AfterClass; import org.junit.Before; import org.junit.Rule; @@ -61,14 +60,14 @@ public static void insertUsedCodeUtil(TOTPSQLStorage storage, TOTPUsedCode usedC @Test public void totpCodeLengthTest() throws Exception { - String[] args = { "../" }; + String[] args = {"../"}; TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); process.getProcess().setForceInMemoryDB(); // this test is only for SQLite assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); if (StorageLayer.getStorage(process.getProcess()).getType() != STORAGE_TYPE.SQL) { - assert (false); + return; } TOTPSQLStorage storage = StorageLayer.getTOTPStorage(process.getProcess()); long now = System.currentTimeMillis(); diff --git a/src/test/java/io/supertokens/test/totp/TOTPRecipeTest.java b/src/test/java/io/supertokens/test/totp/TOTPRecipeTest.java index 61551ebe7..6841f9982 100644 --- a/src/test/java/io/supertokens/test/totp/TOTPRecipeTest.java +++ b/src/test/java/io/supertokens/test/totp/TOTPRecipeTest.java @@ -16,45 +16,18 @@ package io.supertokens.test.totp; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertThrows; - -import java.io.IOException; -import java.security.InvalidKeyException; -import java.security.Key; -import java.time.Duration; -import java.time.Instant; - -import javax.crypto.spec.SecretKeySpec; - -import io.supertokens.featureflag.EE_FEATURES; -import io.supertokens.featureflag.FeatureFlag; -import io.supertokens.featureflag.FeatureFlagTestContent; -import io.supertokens.featureflag.exceptions.InvalidLicenseKeyException; -import io.supertokens.httpRequest.HttpResponseException; -import org.apache.commons.codec.binary.Base32; -import org.junit.AfterClass; -import org.junit.Before; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TestRule; - import com.eatthepath.otp.TimeBasedOneTimePasswordGenerator; - -import io.supertokens.test.Utils; import io.supertokens.Main; import io.supertokens.ProcessState; import io.supertokens.config.Config; import io.supertokens.cronjobs.deleteExpiredTotpTokens.DeleteExpiredTotpTokens; +import io.supertokens.featureflag.EE_FEATURES; +import io.supertokens.featureflag.FeatureFlagTestContent; +import io.supertokens.featureflag.exceptions.InvalidLicenseKeyException; +import io.supertokens.httpRequest.HttpResponseException; import io.supertokens.pluginInterface.STORAGE_TYPE; import io.supertokens.pluginInterface.exceptions.StorageQueryException; import io.supertokens.pluginInterface.exceptions.StorageTransactionLogicException; -import io.supertokens.storageLayer.StorageLayer; -import io.supertokens.test.TestingProcessManager; - -import io.supertokens.totp.Totp; -import io.supertokens.totp.exceptions.InvalidTotpException; -import io.supertokens.totp.exceptions.LimitReachedException; import io.supertokens.pluginInterface.totp.TOTPDevice; import io.supertokens.pluginInterface.totp.TOTPStorage; import io.supertokens.pluginInterface.totp.TOTPUsedCode; @@ -62,6 +35,28 @@ import io.supertokens.pluginInterface.totp.exception.TotpNotEnabledException; import io.supertokens.pluginInterface.totp.exception.UnknownDeviceException; import io.supertokens.pluginInterface.totp.sqlStorage.TOTPSQLStorage; +import io.supertokens.storageLayer.StorageLayer; +import io.supertokens.test.TestingProcessManager; +import io.supertokens.test.Utils; +import io.supertokens.totp.Totp; +import io.supertokens.totp.exceptions.InvalidTotpException; +import io.supertokens.totp.exceptions.LimitReachedException; +import org.apache.commons.codec.binary.Base32; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TestRule; + +import javax.crypto.spec.SecretKeySpec; +import java.io.IOException; +import java.security.InvalidKeyException; +import java.security.Key; +import java.time.Duration; +import java.time.Instant; + +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertThrows; // TODO: Add test for UsedCodeAlreadyExistsException once we implement time mocking @@ -93,17 +88,18 @@ public TestSetupResult(TOTPStorage storage, TestingProcessManager.TestingProcess public TestSetupResult defaultInit() throws InterruptedException, IOException, StorageQueryException, InvalidLicenseKeyException, HttpResponseException { - String[] args = { "../" }; + String[] args = {"../"}; TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); if (StorageLayer.getStorage(process.getProcess()).getType() != STORAGE_TYPE.SQL) { - assert (false); + return null; } TOTPStorage storage = StorageLayer.getTOTPStorage(process.getProcess()); - FeatureFlagTestContent.getInstance(process.main).setKeyValue(FeatureFlagTestContent.ENABLED_FEATURES, new EE_FEATURES[] { EE_FEATURES.TOTP }); + FeatureFlagTestContent.getInstance(process.main) + .setKeyValue(FeatureFlagTestContent.ENABLED_FEATURES, new EE_FEATURES[]{EE_FEATURES.TOTP}); return new TestSetupResult(storage, process); } @@ -113,7 +109,9 @@ public static String generateTotpCode(Main main, TOTPDevice device) return generateTotpCode(main, device, 0); } - /** Generates TOTP code similar to apps like Google Authenticator and Authy */ + /** + * Generates TOTP code similar to apps like Google Authenticator and Authy + */ private static String generateTotpCode(Main main, TOTPDevice device, int step) throws InvalidKeyException, StorageQueryException { final TimeBasedOneTimePasswordGenerator totp = new TimeBasedOneTimePasswordGenerator( @@ -140,6 +138,9 @@ private static TOTPUsedCode[] getAllUsedCodesUtil(TOTPStorage storage, String us @Test public void createDeviceTest() throws Exception { TestSetupResult result = defaultInit(); + if (result == null) { + return; + } Main main = result.process.getProcess(); // Create device @@ -154,6 +155,9 @@ public void createDeviceTest() throws Exception { @Test public void createDeviceAndVerifyCodeTest() throws Exception { TestSetupResult result = defaultInit(); + if (result == null) { + return; + } Main main = result.process.getProcess(); // Create device @@ -267,7 +271,7 @@ public int triggerAndCheckRateLimit(Main main, TOTPDevice device) throws Excepti @Test public void rateLimitCooldownTest() throws Exception { - String[] args = { "../" }; + String[] args = {"../"}; // set rate limiting cooldown time to 1s Utils.setValueInConfig("totp_rate_limit_cooldown_sec", "1"); @@ -278,10 +282,11 @@ public void rateLimitCooldownTest() throws Exception { assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); if (StorageLayer.getStorage(process.getProcess()).getType() != STORAGE_TYPE.SQL) { - assert (false); + return; } - FeatureFlagTestContent.getInstance(process.main).setKeyValue(FeatureFlagTestContent.ENABLED_FEATURES, new EE_FEATURES[] { EE_FEATURES.TOTP }); + FeatureFlagTestContent.getInstance(process.main) + .setKeyValue(FeatureFlagTestContent.ENABLED_FEATURES, new EE_FEATURES[]{EE_FEATURES.TOTP}); Main main = process.getProcess(); @@ -311,6 +316,9 @@ public void rateLimitCooldownTest() throws Exception { public void cronRemovesCodesDuringRateLimitTest() throws Exception { // This test is flaky because of time. TestSetupResult result = defaultInit(); + if (result == null) { + return; + } Main main = result.process.getProcess(); // Create device @@ -335,6 +343,9 @@ public void cronRemovesCodesDuringRateLimitTest() throws Exception { @Test public void createAndVerifyDeviceTest() throws Exception { TestSetupResult result = defaultInit(); + if (result == null) { + return; + } Main main = result.process.getProcess(); // Create device @@ -377,6 +388,9 @@ public void createAndVerifyDeviceTest() throws Exception { public void removeDeviceTest() throws Exception { // Flaky test. TestSetupResult result = defaultInit(); + if (result == null) { + return; + } Main main = result.process.getProcess(); TOTPStorage storage = result.storage; @@ -441,6 +455,9 @@ public void removeDeviceTest() throws Exception { @Test public void updateDeviceNameTest() throws Exception { TestSetupResult result = defaultInit(); + if (result == null) { + return; + } Main main = result.process.getProcess(); Totp.registerDevice(main, "user", "device1", 1, 30); @@ -475,6 +492,9 @@ public void updateDeviceNameTest() throws Exception { @Test public void getDevicesTest() throws Exception { TestSetupResult result = defaultInit(); + if (result == null) { + return; + } Main main = result.process.getProcess(); // Try get devices for non-existent user: @@ -492,6 +512,9 @@ public void getDevicesTest() throws Exception { @Test public void deleteExpiredTokensCronIntervalTest() throws Exception { TestSetupResult result = defaultInit(); + if (result == null) { + return; + } Main main = result.process.getProcess(); // Ensure that delete expired tokens cron runs every hour: diff --git a/src/test/java/io/supertokens/test/totp/TOTPStorageTest.java b/src/test/java/io/supertokens/test/totp/TOTPStorageTest.java index 1a0d3723b..a9d3c0493 100644 --- a/src/test/java/io/supertokens/test/totp/TOTPStorageTest.java +++ b/src/test/java/io/supertokens/test/totp/TOTPStorageTest.java @@ -1,28 +1,14 @@ package io.supertokens.test.totp; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertThrows; - +import io.supertokens.ProcessState; +import io.supertokens.cronjobs.deleteExpiredTotpTokens.DeleteExpiredTotpTokens; import io.supertokens.featureflag.EE_FEATURES; -import io.supertokens.featureflag.FeatureFlag; import io.supertokens.featureflag.FeatureFlagTestContent; import io.supertokens.featureflag.exceptions.InvalidLicenseKeyException; import io.supertokens.httpRequest.HttpResponseException; -import org.junit.AfterClass; -import org.junit.Before; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TestRule; - -import io.supertokens.test.Utils; -import io.supertokens.ProcessState; -import io.supertokens.cronjobs.deleteExpiredTotpTokens.DeleteExpiredTotpTokens; import io.supertokens.pluginInterface.STORAGE_TYPE; import io.supertokens.pluginInterface.exceptions.StorageQueryException; import io.supertokens.pluginInterface.exceptions.StorageTransactionLogicException; -import io.supertokens.storageLayer.StorageLayer; -import io.supertokens.test.TestingProcessManager; - import io.supertokens.pluginInterface.totp.TOTPDevice; import io.supertokens.pluginInterface.totp.TOTPStorage; import io.supertokens.pluginInterface.totp.TOTPUsedCode; @@ -31,9 +17,20 @@ import io.supertokens.pluginInterface.totp.exception.UnknownDeviceException; import io.supertokens.pluginInterface.totp.exception.UsedCodeAlreadyExistsException; import io.supertokens.pluginInterface.totp.sqlStorage.TOTPSQLStorage; +import io.supertokens.storageLayer.StorageLayer; +import io.supertokens.test.TestingProcessManager; +import io.supertokens.test.Utils; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TestRule; import java.io.IOException; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertThrows; + public class TOTPStorageTest { public class TestSetupResult { @@ -62,17 +59,18 @@ public void beforeEach() { public TestSetupResult initSteps() throws InterruptedException, StorageQueryException, InvalidLicenseKeyException, HttpResponseException, IOException { - String[] args = { "../" }; + String[] args = {"../"}; TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); if (StorageLayer.getStorage(process.getProcess()).getType() != STORAGE_TYPE.SQL) { - assert (false); + return null; } TOTPSQLStorage storage = StorageLayer.getTOTPStorage(process.getProcess()); - FeatureFlagTestContent.getInstance(process.main).setKeyValue(FeatureFlagTestContent.ENABLED_FEATURES, new EE_FEATURES[] { EE_FEATURES.TOTP }); + FeatureFlagTestContent.getInstance(process.main) + .setKeyValue(FeatureFlagTestContent.ENABLED_FEATURES, new EE_FEATURES[]{EE_FEATURES.TOTP}); return new TestSetupResult(storage, process); } @@ -119,6 +117,9 @@ public static void insertUsedCodesUtil(TOTPSQLStorage storage, TOTPUsedCode[] us @Test public void createDeviceTests() throws Exception { TestSetupResult result = initSteps(); + if (result == null) { + return; + } TOTPSQLStorage storage = result.storage; TOTPDevice device1 = new TOTPDevice("user", "d1", "secret", 30, 1, false); @@ -147,6 +148,9 @@ public void createDeviceTests() throws Exception { @Test public void verifyDeviceTests() throws Exception { TestSetupResult result = initSteps(); + if (result == null) { + return; + } TOTPSQLStorage storage = result.storage; TOTPDevice device = new TOTPDevice("user", "device", "secretKey", 30, 1, false); @@ -173,6 +177,9 @@ public void verifyDeviceTests() throws Exception { @Test public void getDevicesCount_TransactionTests() throws Exception { TestSetupResult result = initSteps(); + if (result == null) { + return; + } TOTPSQLStorage storage = result.storage; // Try to get the count for a user that doesn't exist (Should pass because @@ -201,6 +208,9 @@ public void getDevicesCount_TransactionTests() throws Exception { @Test public void removeUser_TransactionTests() throws Exception { TestSetupResult result = initSteps(); + if (result == null) { + return; + } TOTPSQLStorage storage = result.storage; // Try to remove a user that doesn't exist (Should pass because @@ -223,7 +233,7 @@ public void removeUser_TransactionTests() throws Exception { TOTPUsedCode usedCode1 = new TOTPUsedCode("user", "code1", true, expiryAfter10mins, now); TOTPUsedCode usedCode2 = new TOTPUsedCode("user", "code2", false, expiryAfter10mins, now + 1); - insertUsedCodesUtil(storage, new TOTPUsedCode[] { usedCode1, usedCode2 }); + insertUsedCodesUtil(storage, new TOTPUsedCode[]{usedCode1, usedCode2}); TOTPDevice[] storedDevices = storage.getDevices("user"); assert (storedDevices.length == 2); @@ -247,6 +257,9 @@ public void removeUser_TransactionTests() throws Exception { @Test public void deleteDevice_TransactionTests() throws Exception { TestSetupResult result = initSteps(); + if (result == null) { + return; + } TOTPSQLStorage storage = result.storage; TOTPDevice device1 = new TOTPDevice("user", "device1", "sk1", 30, 1, false); @@ -291,6 +304,9 @@ public void deleteDevice_TransactionTests() throws Exception { @Test public void updateDeviceNameTests() throws Exception { TestSetupResult result = initSteps(); + if (result == null) { + return; + } TOTPSQLStorage storage = result.storage; TOTPDevice device = new TOTPDevice("user", "device", "secretKey", 30, 1, false); @@ -326,6 +342,9 @@ public void updateDeviceNameTests() throws Exception { @Test public void getDevicesTest() throws Exception { TestSetupResult result = initSteps(); + if (result == null) { + return; + } TOTPSQLStorage storage = result.storage; TOTPDevice device1 = new TOTPDevice("user", "d1", "secretKey", 30, 1, false); @@ -347,6 +366,9 @@ public void getDevicesTest() throws Exception { @Test public void insertUsedCodeTest() throws Exception { TestSetupResult result = initSteps(); + if (result == null) { + return; + } TOTPSQLStorage storage = result.storage; long nextDay = System.currentTimeMillis() + 1000 * 60 * 60 * 24; // 1 day from now long now = System.currentTimeMillis(); @@ -357,7 +379,7 @@ public void insertUsedCodeTest() throws Exception { TOTPUsedCode code = new TOTPUsedCode("user", "1234", true, nextDay, now); storage.createDevice(device); - insertUsedCodesUtil(storage, new TOTPUsedCode[] { code }); + insertUsedCodesUtil(storage, new TOTPUsedCode[]{code}); TOTPUsedCode[] usedCodes = getAllUsedCodesUtil(storage, "user"); assert (usedCodes.length == 1); @@ -368,13 +390,13 @@ public void insertUsedCodeTest() throws Exception { { TOTPUsedCode codeWithRepeatedCreatedTime = new TOTPUsedCode("user", "any-code", true, nextDay, now); assertThrows(UsedCodeAlreadyExistsException.class, - () -> insertUsedCodesUtil(storage, new TOTPUsedCode[] { codeWithRepeatedCreatedTime })); + () -> insertUsedCodesUtil(storage, new TOTPUsedCode[]{codeWithRepeatedCreatedTime})); } // Try to insert code when user doesn't have any device (i.e. TOTP not enabled) { assertThrows(TotpNotEnabledException.class, - () -> insertUsedCodesUtil(storage, new TOTPUsedCode[] { + () -> insertUsedCodesUtil(storage, new TOTPUsedCode[]{ new TOTPUsedCode("new-user-without-totp", "1234", true, nextDay, System.currentTimeMillis()) })); @@ -386,14 +408,14 @@ public void insertUsedCodeTest() throws Exception { storage.createDevice(newDevice); insertUsedCodesUtil( storage, - new TOTPUsedCode[] { + new TOTPUsedCode[]{ new TOTPUsedCode("user", "1234", true, nextDay, System.currentTimeMillis()) }); } // Try to insert code when user doesn't exist: assertThrows(TotpNotEnabledException.class, - () -> insertUsedCodesUtil(storage, new TOTPUsedCode[] { + () -> insertUsedCodesUtil(storage, new TOTPUsedCode[]{ new TOTPUsedCode("non-existent-user", "1234", true, nextDay, System.currentTimeMillis()) })); @@ -402,6 +424,9 @@ public void insertUsedCodeTest() throws Exception { @Test public void getAllUsedCodesTest() throws Exception { TestSetupResult result = initSteps(); + if (result == null) { + return; + } TOTPSQLStorage storage = result.storage; TOTPUsedCode[] usedCodes = getAllUsedCodesUtil(storage, "non-existent-user"); @@ -420,7 +445,7 @@ public void getAllUsedCodesTest() throws Exception { TOTPUsedCode validCode3 = new TOTPUsedCode("user", "valid3", true, nextDay, now + 6); storage.createDevice(device); - insertUsedCodesUtil(storage, new TOTPUsedCode[] { + insertUsedCodesUtil(storage, new TOTPUsedCode[]{ validCode1, invalidCode, expiredCode, expiredInvalidCode, validCode2, validCode3 @@ -428,7 +453,7 @@ public void getAllUsedCodesTest() throws Exception { // Try to create a code with same user and created time. It should fail: assertThrows(UsedCodeAlreadyExistsException.class, - () -> insertUsedCodesUtil(storage, new TOTPUsedCode[] { + () -> insertUsedCodesUtil(storage, new TOTPUsedCode[]{ new TOTPUsedCode("user", "any-code", true, nextDay, now + 1) })); @@ -448,6 +473,9 @@ public void getAllUsedCodesTest() throws Exception { @Test public void removeExpiredCodesTest() throws Exception { TestSetupResult result = initSteps(); + if (result == null) { + return; + } TOTPSQLStorage storage = result.storage; long now = System.currentTimeMillis(); @@ -461,7 +489,7 @@ public void removeExpiredCodesTest() throws Exception { TOTPUsedCode invalidCodeToExpire = new TOTPUsedCode("user", "invalid", false, halfSecond, now + 3); storage.createDevice(device); - insertUsedCodesUtil(storage, new TOTPUsedCode[] { + insertUsedCodesUtil(storage, new TOTPUsedCode[]{ validCodeToLive, invalidCodeToLive, validCodeToExpire, invalidCodeToExpire }); diff --git a/src/test/java/io/supertokens/test/totp/TotpLicenseTest.java b/src/test/java/io/supertokens/test/totp/TotpLicenseTest.java index 92ebba597..c0f0745c5 100644 --- a/src/test/java/io/supertokens/test/totp/TotpLicenseTest.java +++ b/src/test/java/io/supertokens/test/totp/TotpLicenseTest.java @@ -22,7 +22,6 @@ import io.supertokens.featureflag.EE_FEATURES; import io.supertokens.featureflag.FeatureFlagTestContent; import io.supertokens.featureflag.exceptions.FeatureNotEnabledException; -import io.supertokens.test.httpRequest.HttpResponseException; import io.supertokens.pluginInterface.STORAGE_TYPE; import io.supertokens.pluginInterface.totp.TOTPDevice; import io.supertokens.pluginInterface.totp.TOTPStorage; @@ -30,6 +29,7 @@ import io.supertokens.test.TestingProcessManager; import io.supertokens.test.Utils; import io.supertokens.test.httpRequest.HttpRequestForTesting; +import io.supertokens.test.httpRequest.HttpResponseException; import io.supertokens.totp.Totp; import org.junit.AfterClass; import org.junit.Before; @@ -37,13 +37,13 @@ import org.junit.Test; import org.junit.rules.TestRule; - import static io.supertokens.test.totp.TOTPRecipeTest.generateTotpCode; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertThrows; public class TotpLicenseTest { - public final static String OPAQUE_KEY_WITH_TOTP_FEATURE = "pXhNK=nYiEsb6gJEOYP2kIR6M0kn4XLvNqcwT1XbX8xHtm44K-lQfGCbaeN0Ieeza39fxkXr=tiiUU=DXxDH40Y=4FLT4CE-rG1ETjkXxO4yucLpJvw3uSegPayoISGL"; + public final static String OPAQUE_KEY_WITH_TOTP_FEATURE = "pXhNK=nYiEsb6gJEOYP2kIR6M0kn4XLvNqcwT1XbX8xHtm44K" + + "-lQfGCbaeN0Ieeza39fxkXr=tiiUU=DXxDH40Y=4FLT4CE-rG1ETjkXxO4yucLpJvw3uSegPayoISGL"; @Rule public TestRule watchman = Utils.getOnFailure(); @@ -69,13 +69,13 @@ public TestSetupResult(TOTPStorage storage, TestingProcessManager.TestingProcess } public TestSetupResult defaultInit() throws InterruptedException { - String[] args = { "../" }; + String[] args = {"../"}; TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); if (StorageLayer.getStorage(process.getProcess()).getType() != STORAGE_TYPE.SQL) { - assert (false); + return null; } TOTPStorage storage = StorageLayer.getTOTPStorage(process.getProcess()); @@ -85,6 +85,9 @@ public TestSetupResult defaultInit() throws InterruptedException { @Test public void testTotpWithoutLicense() throws Exception { TestSetupResult result = defaultInit(); + if (result == null) { + return; + } Main main = result.process.getProcess(); // Create device @@ -153,7 +156,11 @@ public void testTotpWithoutLicense() throws Exception { @Test public void testTotpWithLicense() throws Exception { TestSetupResult result = defaultInit(); - FeatureFlagTestContent.getInstance(result.process.main).setKeyValue(FeatureFlagTestContent.ENABLED_FEATURES, new EE_FEATURES[] { EE_FEATURES.TOTP }); + if (result == null) { + return; + } + FeatureFlagTestContent.getInstance(result.process.main) + .setKeyValue(FeatureFlagTestContent.ENABLED_FEATURES, new EE_FEATURES[]{EE_FEATURES.TOTP}); Main main = result.process.getProcess();