diff --git a/neon_users_service/databases/sqlite.py b/neon_users_service/databases/sqlite.py index 90ff37b..ce6d28c 100644 --- a/neon_users_service/databases/sqlite.py +++ b/neon_users_service/databases/sqlite.py @@ -3,10 +3,11 @@ from os import makedirs from os.path import expanduser, dirname from sqlite3 import connect, Cursor -from typing import Optional +from threading import Lock +from typing import Optional, List from neon_users_service.databases import UserDatabase -from neon_users_service.exceptions import UserNotFoundError, UserExistsError +from neon_users_service.exceptions import UserNotFoundError, UserExistsError, DatabaseError from neon_users_service.models import User, AccessRoles @@ -15,6 +16,7 @@ def __init__(self, db_path: Optional[str] = None): db_path = expanduser(db_path or "~/.local/share/neon/user-db.sqlite") makedirs(dirname(db_path), exist_ok=True) self.connection = connect(db_path) + self._db_lock = Lock() self.connection.execute( '''CREATE TABLE IF NOT EXISTS users (user_id text, @@ -27,46 +29,48 @@ def __init__(self, db_path: Optional[str] = None): def create_user(self, user: User) -> User: if self._check_user_exists(user): raise UserExistsError(user) - - self.connection.execute( - f'''INSERT INTO users VALUES - ('{user.user_id}', - '{user.created_timestamp}', - '{user.username}', - '{user.model_dump_json()}')''' - ) - self.connection.commit() + with self._db_lock: + self.connection.execute( + f'''INSERT INTO users VALUES + ('{user.user_id}', + '{user.created_timestamp}', + '{user.username}', + '{user.model_dump_json()}')''' + ) + self.connection.commit() return user @staticmethod - def _parse_lookup_results(user_spec: str, cursor: Cursor): - rows = cursor.fetchall() - cursor.close() - + def _parse_lookup_results(user_spec: str, rows: List[tuple]) -> str: if len(rows) > 1: - # TODO: Custom exception - raise RuntimeError(f"User with spec '{user_spec}' has duplicate entries!") + raise DatabaseError(f"User with spec '{user_spec}' has duplicate entries!") elif len(rows) == 0: raise UserNotFoundError(user_spec) return rows[0][0] def read_user_by_id(self, user_id: str) -> User: - cursor = self.connection.cursor() - cursor.execute( - f'''SELECT user_object FROM users WHERE - user_id = '{user_id}' - ''' - ) - return User(**json.loads(self._parse_lookup_results(user_id, cursor))) + with self._db_lock: + cursor = self.connection.cursor() + cursor.execute( + f'''SELECT user_object FROM users WHERE + user_id = '{user_id}' + ''' + ) + rows = cursor.fetchall() + cursor.close() + return User(**json.loads(self._parse_lookup_results(user_id, rows))) def read_user_by_username(self, username: str) -> User: - cursor = self.connection.cursor() - cursor.execute( - f'''SELECT user_object FROM users WHERE - username = '{username}' - ''' - ) - return User(**json.loads(self._parse_lookup_results(username, cursor))) + with self._db_lock: + cursor = self.connection.cursor() + cursor.execute( + f'''SELECT user_object FROM users WHERE + username = '{username}' + ''' + ) + rows = cursor.fetchall() + cursor.close() + return User(**json.loads(self._parse_lookup_results(username, rows))) def update_user(self, user: User) -> User: # Lookup user to ensure they exist in the database @@ -77,20 +81,22 @@ def update_user(self, user: User) -> User: f"'{user.username}' already exists") except UserNotFoundError: pass - self.connection.execute( - f'''UPDATE users SET username = '{user.username}', - user_object = '{user.model_dump_json()}' - WHERE user_id = '{user.user_id}' - ''' - ) - self.connection.commit() + with self._db_lock: + self.connection.execute( + f'''UPDATE users SET username = '{user.username}', + user_object = '{user.model_dump_json()}' + WHERE user_id = '{user.user_id}' + ''' + ) + self.connection.commit() return self.read_user_by_id(user.user_id) def delete_user(self, user_id: str) -> User: # Lookup user to ensure they exist in the database user_to_delete = self.read_user_by_id(user_id) - self.connection.execute(f"DELETE FROM users WHERE user_id = '{user_id}'") - self.connection.commit() + with self._db_lock: + self.connection.execute(f"DELETE FROM users WHERE user_id = '{user_id}'") + self.connection.commit() return user_to_delete def shutdown(self): diff --git a/neon_users_service/exceptions.py b/neon_users_service/exceptions.py index 6df8bee..38f82a1 100644 --- a/neon_users_service/exceptions.py +++ b/neon_users_service/exceptions.py @@ -20,4 +20,10 @@ class ConfigurationError(KeyError): class AuthenticationError(ValueError): """ Raised when authentication fails for an existing valid user. + """ + + +class DatabaseError(RuntimeError): + """ + Raised when a database-related error occurs. """ \ No newline at end of file diff --git a/neon_users_service/models.py b/neon_users_service/models.py index 38a0fa7..1cb203a 100644 --- a/neon_users_service/models.py +++ b/neon_users_service/models.py @@ -121,14 +121,14 @@ class TokenConfig(BaseModel): class User(BaseModel): username: str - password_hash: str + password_hash: Optional[str] user_id: str = Field(default_factory=lambda: str(uuid4())) created_timestamp: int = Field(default_factory=lambda: round(time())) neon: NeonUserConfig = NeonUserConfig() klat: KlatConfig = KlatConfig() llm: BrainForgeConfig = BrainForgeConfig() permissions: PermissionsConfig = PermissionsConfig() - tokens: List[TokenConfig] = [] + tokens: Optional[List[TokenConfig]] = [] def __eq__(self, other): return self.model_dump() == other.model_dump() diff --git a/neon_users_service/service.py b/neon_users_service/service.py index 3aad1d7..985a61e 100644 --- a/neon_users_service/service.py +++ b/neon_users_service/service.py @@ -1,10 +1,11 @@ import hashlib import re + from copy import copy from typing import Optional from ovos_config import Configuration from neon_users_service.databases import UserDatabase -from neon_users_service.exceptions import ConfigurationError, AuthenticationError +from neon_users_service.exceptions import ConfigurationError, AuthenticationError, UserNotFoundError from neon_users_service.models import User @@ -51,7 +52,20 @@ def create_user(self, user: User) -> User: user.password_hash = self._ensure_hashed(user.password_hash) return self.database.create_user(user) - def authenticate_user(self, username: str, password: str) -> User: + def read_unauthenticated_user(self, user_spec: str) -> User: + """ + Helper to get a user from the database with sensitive data removed. + This is what most lookups should return; `authenticate_user` can be + used to get an un-redacted User object. + @param user_spec: username or user_id to retrieve + @returns: Redacted User object with sensitive information removed + """ + user = self.database.read_user(user_spec) + user.password_hash = None + user.tokens = [] + return user + + def read_authenticated_user(self, username: str, password: str) -> User: """ Helper to get a user from the database, only if the requested username and password match a database entry. @@ -67,6 +81,32 @@ def authenticate_user(self, username: str, password: str) -> User: raise AuthenticationError(f"Invalid password for {username}") return user + def update_user(self, user: User) -> User: + """ + Helper to update a user. If the supplied user's password is not defined, + an `AuthenticationError` will be raised. + @param user: The updated user object to update in the database + @retruns: User object as it exists in the database, after updating + """ + if not user.password_hash: + raise ValueError("Supplied user password is empty") + if not isinstance(user.tokens, list): + raise ValueError("Supplied tokens configuration is not a list") + # This will raise a `UserNotFound` exception if the user doesn't exist + return self.database.update_user(user) + + def delete_user(self, user: User) -> User: + """ + Helper to remove a user from the database. If the supplied user does not + match any database entry, a `UserNotFoundError` will be raised. + @param user: The user object to remove from the database + @returns: User object removed from the database + """ + db_user = self.database.read_user_by_id(user.user_id) + if db_user != user: + raise UserNotFoundError(user) + return self.database.delete_user(user.user_id) + def shutdown(self): """ Shutdown the service. diff --git a/tests/test_service.py b/tests/test_service.py index 78b52c3..86b5234 100644 --- a/tests/test_service.py +++ b/tests/test_service.py @@ -50,20 +50,69 @@ def test_create_user(self): self.assertNotEqual(user_2, input_user_2) service.shutdown() - def test_authenticate_user(self): + def test_read_authenticated_user(self): service = NeonUsersService(self.test_config) string_password = "super secret password" hashed_password = hashlib.sha256(string_password.encode()).hexdigest() user_1 = service.create_user(User(username="user", password_hash=hashed_password)) - auth_1 = service.authenticate_user("user", string_password) + auth_1 = service.read_authenticated_user("user", string_password) self.assertEqual(auth_1, user_1) - auth_2 = service.authenticate_user("user", hashed_password) + auth_2 = service.read_authenticated_user("user", hashed_password) self.assertEqual(auth_2, user_1) with self.assertRaises(AuthenticationError): - service.authenticate_user("user", "bad password") + service.read_authenticated_user("user", "bad password") with self.assertRaises(UserNotFoundError): - service.authenticate_user("user_1", hashed_password) + service.read_authenticated_user("user_1", hashed_password) + service.shutdown() + + def test_read_unauthenticated_user(self): + service = NeonUsersService(self.test_config) + user_1 = service.create_user(User(username="user", + password_hash="test")) + read_user = service.read_unauthenticated_user("user") + self.assertEqual(read_user, service.read_unauthenticated_user(user_1.user_id)) + self.assertIsNone(read_user.password_hash) + self.assertEqual(read_user.tokens, []) + read_user.password_hash = user_1.password_hash + read_user.tokens = user_1.tokens + self.assertEqual(user_1, read_user) + + with self.assertRaises(UserNotFoundError): + service.read_unauthenticated_user("not_a_user") + service.shutdown() + + def test_update_user(self): + service = NeonUsersService(self.test_config) + user_1 = service.create_user(User(username="user", + password_hash="test")) + + # Valid update + user_1.username = "new_username" + updated_user = service.update_user(user_1) + self.assertEqual(updated_user, user_1) + + # Invalid password values + updated_user.password_hash = None + with self.assertRaises(ValueError): + service.update_user(updated_user) + updated_user.password_hash = "" + with self.assertRaises(ValueError): + service.update_user(updated_user) + + # Valid password values + updated_user.password_hash = user_1.password_hash + updated_user = service.update_user(updated_user) + self.assertEqual(updated_user.password_hash, user_1.password_hash) + + # Invalid token values + updated_user.tokens = None + with self.assertRaises(ValueError): + service.update_user(updated_user) + + service.shutdown() + def test_delete_user(self): + pass