From c01817d756dfebf635bbe2e96308df44b524d285 Mon Sep 17 00:00:00 2001 From: Chris Read Date: Thu, 17 Aug 2023 00:52:53 +1000 Subject: [PATCH] Updated pydanja usage, corrected SQLAlchemy session commits and re-formatted scalar invocations --- pdm.lock | 6 ++-- src/traffcap/dto/inbound_request.py | 6 ++-- src/traffcap/dto/outbound_response.py | 6 ++-- src/traffcap/dto/rule.py | 6 ++-- src/traffcap/dto/user.py | 6 ++-- .../inbound_request_repository.py | 26 +++++++-------- .../outbound_response_repository.py | 3 +- src/traffcap/repositories/rule_repository.py | 33 +++++++++++-------- src/traffcap/repositories/user_repository.py | 12 ++++--- 9 files changed, 60 insertions(+), 44 deletions(-) diff --git a/pdm.lock b/pdm.lock index 63a466c..b39aa8c 100644 --- a/pdm.lock +++ b/pdm.lock @@ -517,15 +517,15 @@ files = [ [[package]] name = "pydanja" -version = "0.1.12" +version = "0.1.13" requires_python = ">=3.8" summary = "JSON:API Support for Pydantic" dependencies = [ "pydantic>=2.1.1", ] files = [ - {file = "pydanja-0.1.12-py3-none-any.whl", hash = "sha256:efd30a722ced417094c7ab08a9a03f5f623e67ae904d166de8320a1e2941edea"}, - {file = "pydanja-0.1.12.tar.gz", hash = "sha256:6b6175f45550676f65d394705cf0f6cc2d7c34c0b9b50233fee6254195bf47dd"}, + {file = "pydanja-0.1.13-py3-none-any.whl", hash = "sha256:d461717226c3e1f1c2a05fea44b606b611dc645aec838b6d5524de000436ebad"}, + {file = "pydanja-0.1.13.tar.gz", hash = "sha256:9430d0d4b00f1e29c257de96d5a9c595d57d79e808f9390bf3a7af01e1e523f3"}, ] [[package]] diff --git a/src/traffcap/dto/inbound_request.py b/src/traffcap/dto/inbound_request.py index 2ee5894..0484967 100644 --- a/src/traffcap/dto/inbound_request.py +++ b/src/traffcap/dto/inbound_request.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel +from pydantic import BaseModel, Field from fastapi import Request import json from fastapi.encoders import jsonable_encoder @@ -47,7 +47,9 @@ class InboundRequestCreate(InboundRequestBase): class InboundRequest(InboundRequestBase): - id: int + id: int = Field(json_schema_extra={ + "resource_id": True + }) class Config: from_attributes = True diff --git a/src/traffcap/dto/outbound_response.py b/src/traffcap/dto/outbound_response.py index 584e931..b669c07 100644 --- a/src/traffcap/dto/outbound_response.py +++ b/src/traffcap/dto/outbound_response.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel +from pydantic import BaseModel, Field class OutboundResponseBase(BaseModel): @@ -12,7 +12,9 @@ class OutboundResponseCreate(OutboundResponseBase): class OutboundResponse(OutboundResponseBase): - id: int + id: int = Field(json_schema_extra={ + "resource_id": True + }) class Config: from_attributes = True diff --git a/src/traffcap/dto/rule.py b/src/traffcap/dto/rule.py index 8884081..42784e9 100644 --- a/src/traffcap/dto/rule.py +++ b/src/traffcap/dto/rule.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel +from pydantic import BaseModel, Field class RuleBase(BaseModel): @@ -10,7 +10,9 @@ class RuleCreate(RuleBase): class Rule(RuleBase): - id: int + id: int = Field(json_schema_extra={ + "resource_id": True + }) class Config: from_attributes = True diff --git a/src/traffcap/dto/user.py b/src/traffcap/dto/user.py index f58ea90..048bcd8 100644 --- a/src/traffcap/dto/user.py +++ b/src/traffcap/dto/user.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel +from pydantic import BaseModel, Field class UserBase(BaseModel): @@ -11,7 +11,9 @@ class UserCreate(UserBase): class User(UserBase): - id: int + id: int = Field(json_schema_extra={ + "resource_id": True + }) class Config: from_attributes = True diff --git a/src/traffcap/repositories/inbound_request_repository.py b/src/traffcap/repositories/inbound_request_repository.py index d1db4f3..1e9108f 100644 --- a/src/traffcap/repositories/inbound_request_repository.py +++ b/src/traffcap/repositories/inbound_request_repository.py @@ -17,25 +17,23 @@ async def store_request( Store the components of the request """ async with cls.session() as session: - async with session.begin(): - new_inbound_request = await InboundRequestCreate.from_request( - endpoint_code, - request - ) - session.add(InboundRequestModel( - endpoint_code=endpoint_code, - method=new_inbound_request.method, - headers=new_inbound_request.headers, - query_params=new_inbound_request.query_params, - body=new_inbound_request.body - )) + new_inbound_request = await InboundRequestCreate.from_request( + endpoint_code, + request + ) + await session.add(InboundRequestModel(**new_inbound_request.model_dump())) + await session.commit() @classmethod async def get_all_inbound_requests(cls) -> List[InboundRequest]: requests = [] async with cls.session() as session: - results = await session.scalars(select(InboundRequestModel).order_by(InboundRequestModel.id.desc())) + results = await session.scalars( + select(InboundRequestModel).order_by(InboundRequestModel.id.desc()) + ) for request in results.all(): - requests.append(InboundRequest.model_validate(request, from_attributes=True)) + requests.append( + InboundRequest.model_validate(request, from_attributes=True) + ) return requests diff --git a/src/traffcap/repositories/outbound_response_repository.py b/src/traffcap/repositories/outbound_response_repository.py index 8a195f0..7609a97 100644 --- a/src/traffcap/repositories/outbound_response_repository.py +++ b/src/traffcap/repositories/outbound_response_repository.py @@ -17,12 +17,11 @@ async def get_by_rule_and_content_type( """ responses = [] async with cls.session() as session: - stmnt = ( + results = await session.scalars( select(OutboundResponseModel) .where(OutboundResponseModel.rule_id == rule.id) .where(OutboundResponseModel.content_type == content_type) ) - results = await session.scalars(stmnt) for response in results.all(): responses.append(Rule.model_validate(response)) diff --git a/src/traffcap/repositories/rule_repository.py b/src/traffcap/repositories/rule_repository.py index 8aadc8e..3fa3e86 100644 --- a/src/traffcap/repositories/rule_repository.py +++ b/src/traffcap/repositories/rule_repository.py @@ -10,8 +10,9 @@ class RuleRepository(Repository): @classmethod async def get_rule_by_id(cls, rule_id: int) -> Optional[Rule]: async with cls.session() as session: - rule = await session.get(RuleModel, rule_id) - return Rule.model_validate(rule) + return Rule.model_validate( + await session.get(RuleModel, rule_id) + ) return None @@ -19,7 +20,9 @@ async def get_rule_by_id(cls, rule_id: int) -> Optional[Rule]: async def get_all_rules(cls) -> List[Rule]: rules = [] async with cls.session() as session: - results = await session.scalars(select(RuleModel)) + results = await session.scalars( + select(RuleModel) + ) for rule in results.all(): rules.append(Rule.model_validate(rule)) @@ -28,21 +31,23 @@ async def get_all_rules(cls) -> List[Rule]: @classmethod async def create_rule(cls, rule: str = ".*") -> Optional[Rule]: async with cls.session() as session: - async with session.begin(): - new_rule = RuleModel(rule=rule) - session.add(new_rule) + new_rule = RuleModel(rule=rule) + await session.add(new_rule) + await session.commit() - return_rule = await cls.get_rule_by_id(new_rule.id) - return Rule.model_validate(return_rule) + return Rule.model_validate( + await cls.get_rule_by_id(new_rule.id) + ) return None @classmethod async def delete_rule_by_id(cls, rule_id: int) -> None: async with cls.session() as session: - async with session.begin(): - rule = await session.get(RuleModel, rule_id) - await session.delete(rule) + await session.delete( + await session.get(RuleModel, rule_id) + ) + await session.commit() @classmethod async def find_matching_rules(cls, rule: str) -> List[Rule]: @@ -51,8 +56,10 @@ async def find_matching_rules(cls, rule: str) -> List[Rule]: """ rules = [] async with cls.session() as session: - stmnt = select(RuleModel).where(RuleModel.rule == rule) - results = await session.scalars(stmnt) + results = await session.scalars( + select(RuleModel) + .where(RuleModel.rule == rule) + ) for rule_item in results.all(): rules.append(Rule.model_validate(rule_item)) diff --git a/src/traffcap/repositories/user_repository.py b/src/traffcap/repositories/user_repository.py index 50754ed..29b3fb7 100644 --- a/src/traffcap/repositories/user_repository.py +++ b/src/traffcap/repositories/user_repository.py @@ -11,12 +11,13 @@ class UserRepository(Repository): async def add_a_test_user(cls) -> Optional[User]: user = None async with cls.session() as session: - async with session.begin(): - user = UserModel( + await session.add( + UserModel( email="centurix@gmail.com", fullname="Chris Read" ) - session.add(user) + ) + await session.commit() return User.model_validate(user) @@ -32,7 +33,10 @@ async def get_user_by_id(cls, user_id: int) -> Optional[User]: async def get_all_users(cls) -> List[User]: users = [] async with cls.session() as session: - results = await session.scalars(select(UserModel)) + results = await session.scalars( + select(UserModel) + ) for user in results.all(): users.append(User.model_validate(user)) + return users