Skip to content

Commit

Permalink
Add locking around database operations
Browse files Browse the repository at this point in the history
Add separate method for reading user entries without authentication data
Add helpers for update/delete operations that perform some degree of input validation
Add unit test coverage for changes
  • Loading branch information
NeonDaniel committed Oct 25, 2024
1 parent 9816572 commit 499c635
Show file tree
Hide file tree
Showing 5 changed files with 150 additions and 49 deletions.
86 changes: 46 additions & 40 deletions neon_users_service/databases/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
from os import makedirs
from os.path import expanduser, dirname
from sqlite3 import connect, Cursor
from typing import Optional
from threading import Lock
from typing import Optional, List

from neon_users_service.databases import UserDatabase
from neon_users_service.exceptions import UserNotFoundError, UserExistsError
from neon_users_service.exceptions import UserNotFoundError, UserExistsError, DatabaseError
from neon_users_service.models import User, AccessRoles


Expand All @@ -15,6 +16,7 @@ def __init__(self, db_path: Optional[str] = None):
db_path = expanduser(db_path or "~/.local/share/neon/user-db.sqlite")
makedirs(dirname(db_path), exist_ok=True)
self.connection = connect(db_path)
self._db_lock = Lock()
self.connection.execute(
'''CREATE TABLE IF NOT EXISTS users
(user_id text,
Expand All @@ -27,46 +29,48 @@ def __init__(self, db_path: Optional[str] = None):
def create_user(self, user: User) -> User:
if self._check_user_exists(user):
raise UserExistsError(user)

self.connection.execute(
f'''INSERT INTO users VALUES
('{user.user_id}',
'{user.created_timestamp}',
'{user.username}',
'{user.model_dump_json()}')'''
)
self.connection.commit()
with self._db_lock:
self.connection.execute(
f'''INSERT INTO users VALUES
('{user.user_id}',
'{user.created_timestamp}',
'{user.username}',
'{user.model_dump_json()}')'''
)
self.connection.commit()
return user

@staticmethod
def _parse_lookup_results(user_spec: str, cursor: Cursor):
rows = cursor.fetchall()
cursor.close()

def _parse_lookup_results(user_spec: str, rows: List[tuple]) -> str:
if len(rows) > 1:
# TODO: Custom exception
raise RuntimeError(f"User with spec '{user_spec}' has duplicate entries!")
raise DatabaseError(f"User with spec '{user_spec}' has duplicate entries!")
elif len(rows) == 0:
raise UserNotFoundError(user_spec)
return rows[0][0]

def read_user_by_id(self, user_id: str) -> User:
cursor = self.connection.cursor()
cursor.execute(
f'''SELECT user_object FROM users WHERE
user_id = '{user_id}'
'''
)
return User(**json.loads(self._parse_lookup_results(user_id, cursor)))
with self._db_lock:
cursor = self.connection.cursor()
cursor.execute(
f'''SELECT user_object FROM users WHERE
user_id = '{user_id}'
'''
)
rows = cursor.fetchall()
cursor.close()
return User(**json.loads(self._parse_lookup_results(user_id, rows)))

def read_user_by_username(self, username: str) -> User:
cursor = self.connection.cursor()
cursor.execute(
f'''SELECT user_object FROM users WHERE
username = '{username}'
'''
)
return User(**json.loads(self._parse_lookup_results(username, cursor)))
with self._db_lock:
cursor = self.connection.cursor()
cursor.execute(
f'''SELECT user_object FROM users WHERE
username = '{username}'
'''
)
rows = cursor.fetchall()
cursor.close()
return User(**json.loads(self._parse_lookup_results(username, rows)))

def update_user(self, user: User) -> User:
# Lookup user to ensure they exist in the database
Expand All @@ -77,20 +81,22 @@ def update_user(self, user: User) -> User:
f"'{user.username}' already exists")
except UserNotFoundError:
pass
self.connection.execute(
f'''UPDATE users SET username = '{user.username}',
user_object = '{user.model_dump_json()}'
WHERE user_id = '{user.user_id}'
'''
)
self.connection.commit()
with self._db_lock:
self.connection.execute(
f'''UPDATE users SET username = '{user.username}',
user_object = '{user.model_dump_json()}'
WHERE user_id = '{user.user_id}'
'''
)
self.connection.commit()
return self.read_user_by_id(user.user_id)

def delete_user(self, user_id: str) -> User:
# Lookup user to ensure they exist in the database
user_to_delete = self.read_user_by_id(user_id)
self.connection.execute(f"DELETE FROM users WHERE user_id = '{user_id}'")
self.connection.commit()
with self._db_lock:
self.connection.execute(f"DELETE FROM users WHERE user_id = '{user_id}'")
self.connection.commit()
return user_to_delete

def shutdown(self):
Expand Down
6 changes: 6 additions & 0 deletions neon_users_service/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,10 @@ class ConfigurationError(KeyError):
class AuthenticationError(ValueError):
"""
Raised when authentication fails for an existing valid user.
"""


class DatabaseError(RuntimeError):
"""
Raised when a database-related error occurs.
"""
4 changes: 2 additions & 2 deletions neon_users_service/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,14 +121,14 @@ class TokenConfig(BaseModel):

class User(BaseModel):
username: str
password_hash: str
password_hash: Optional[str]
user_id: str = Field(default_factory=lambda: str(uuid4()))
created_timestamp: int = Field(default_factory=lambda: round(time()))
neon: NeonUserConfig = NeonUserConfig()
klat: KlatConfig = KlatConfig()
llm: BrainForgeConfig = BrainForgeConfig()
permissions: PermissionsConfig = PermissionsConfig()
tokens: List[TokenConfig] = []
tokens: Optional[List[TokenConfig]] = []

def __eq__(self, other):
return self.model_dump() == other.model_dump()
44 changes: 42 additions & 2 deletions neon_users_service/service.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import hashlib
import re

from copy import copy
from typing import Optional
from ovos_config import Configuration
from neon_users_service.databases import UserDatabase
from neon_users_service.exceptions import ConfigurationError, AuthenticationError
from neon_users_service.exceptions import ConfigurationError, AuthenticationError, UserNotFoundError
from neon_users_service.models import User


Expand Down Expand Up @@ -51,7 +52,20 @@ def create_user(self, user: User) -> User:
user.password_hash = self._ensure_hashed(user.password_hash)
return self.database.create_user(user)

def authenticate_user(self, username: str, password: str) -> User:
def read_unauthenticated_user(self, user_spec: str) -> User:
"""
Helper to get a user from the database with sensitive data removed.
This is what most lookups should return; `authenticate_user` can be
used to get an un-redacted User object.
@param user_spec: username or user_id to retrieve
@returns: Redacted User object with sensitive information removed
"""
user = self.database.read_user(user_spec)
user.password_hash = None
user.tokens = []
return user

def read_authenticated_user(self, username: str, password: str) -> User:
"""
Helper to get a user from the database, only if the requested username
and password match a database entry.
Expand All @@ -67,6 +81,32 @@ def authenticate_user(self, username: str, password: str) -> User:
raise AuthenticationError(f"Invalid password for {username}")
return user

def update_user(self, user: User) -> User:
"""
Helper to update a user. If the supplied user's password is not defined,
an `AuthenticationError` will be raised.
@param user: The updated user object to update in the database
@retruns: User object as it exists in the database, after updating
"""
if not user.password_hash:
raise ValueError("Supplied user password is empty")
if not isinstance(user.tokens, list):
raise ValueError("Supplied tokens configuration is not a list")
# This will raise a `UserNotFound` exception if the user doesn't exist
return self.database.update_user(user)

def delete_user(self, user: User) -> User:
"""
Helper to remove a user from the database. If the supplied user does not
match any database entry, a `UserNotFoundError` will be raised.
@param user: The user object to remove from the database
@returns: User object removed from the database
"""
db_user = self.database.read_user_by_id(user.user_id)
if db_user != user:
raise UserNotFoundError(user)
return self.database.delete_user(user.user_id)

def shutdown(self):
"""
Shutdown the service.
Expand Down
59 changes: 54 additions & 5 deletions tests/test_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,20 +50,69 @@ def test_create_user(self):
self.assertNotEqual(user_2, input_user_2)
service.shutdown()

def test_authenticate_user(self):
def test_read_authenticated_user(self):
service = NeonUsersService(self.test_config)
string_password = "super secret password"
hashed_password = hashlib.sha256(string_password.encode()).hexdigest()
user_1 = service.create_user(User(username="user",
password_hash=hashed_password))
auth_1 = service.authenticate_user("user", string_password)
auth_1 = service.read_authenticated_user("user", string_password)
self.assertEqual(auth_1, user_1)
auth_2 = service.authenticate_user("user", hashed_password)
auth_2 = service.read_authenticated_user("user", hashed_password)
self.assertEqual(auth_2, user_1)

with self.assertRaises(AuthenticationError):
service.authenticate_user("user", "bad password")
service.read_authenticated_user("user", "bad password")

with self.assertRaises(UserNotFoundError):
service.authenticate_user("user_1", hashed_password)
service.read_authenticated_user("user_1", hashed_password)
service.shutdown()

def test_read_unauthenticated_user(self):
service = NeonUsersService(self.test_config)
user_1 = service.create_user(User(username="user",
password_hash="test"))
read_user = service.read_unauthenticated_user("user")
self.assertEqual(read_user, service.read_unauthenticated_user(user_1.user_id))
self.assertIsNone(read_user.password_hash)
self.assertEqual(read_user.tokens, [])
read_user.password_hash = user_1.password_hash
read_user.tokens = user_1.tokens
self.assertEqual(user_1, read_user)

with self.assertRaises(UserNotFoundError):
service.read_unauthenticated_user("not_a_user")
service.shutdown()

def test_update_user(self):
service = NeonUsersService(self.test_config)
user_1 = service.create_user(User(username="user",
password_hash="test"))

# Valid update
user_1.username = "new_username"
updated_user = service.update_user(user_1)
self.assertEqual(updated_user, user_1)

# Invalid password values
updated_user.password_hash = None
with self.assertRaises(ValueError):
service.update_user(updated_user)
updated_user.password_hash = ""
with self.assertRaises(ValueError):
service.update_user(updated_user)

# Valid password values
updated_user.password_hash = user_1.password_hash
updated_user = service.update_user(updated_user)
self.assertEqual(updated_user.password_hash, user_1.password_hash)

# Invalid token values
updated_user.tokens = None
with self.assertRaises(ValueError):
service.update_user(updated_user)

service.shutdown()

def test_delete_user(self):
pass

0 comments on commit 499c635

Please sign in to comment.