Skip to content

Commit

Permalink
Updated pydanja usage, corrected SQLAlchemy session commits and re-fo…
Browse files Browse the repository at this point in the history
…rmatted scalar invocations
  • Loading branch information
Centurix committed Aug 16, 2023
1 parent 7bb1b47 commit c01817d
Show file tree
Hide file tree
Showing 9 changed files with 60 additions and 44 deletions.
6 changes: 3 additions & 3 deletions pdm.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 4 additions & 2 deletions src/traffcap/dto/inbound_request.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from pydantic import BaseModel
from pydantic import BaseModel, Field
from fastapi import Request
import json
from fastapi.encoders import jsonable_encoder
Expand Down Expand Up @@ -47,7 +47,9 @@ class InboundRequestCreate(InboundRequestBase):


class InboundRequest(InboundRequestBase):
id: int
id: int = Field(json_schema_extra={
"resource_id": True
})

class Config:
from_attributes = True
Expand Down
6 changes: 4 additions & 2 deletions src/traffcap/dto/outbound_response.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from pydantic import BaseModel
from pydantic import BaseModel, Field


class OutboundResponseBase(BaseModel):
Expand All @@ -12,7 +12,9 @@ class OutboundResponseCreate(OutboundResponseBase):


class OutboundResponse(OutboundResponseBase):
id: int
id: int = Field(json_schema_extra={
"resource_id": True
})

class Config:
from_attributes = True
Expand Down
6 changes: 4 additions & 2 deletions src/traffcap/dto/rule.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from pydantic import BaseModel
from pydantic import BaseModel, Field


class RuleBase(BaseModel):
Expand All @@ -10,7 +10,9 @@ class RuleCreate(RuleBase):


class Rule(RuleBase):
id: int
id: int = Field(json_schema_extra={
"resource_id": True
})

class Config:
from_attributes = True
Expand Down
6 changes: 4 additions & 2 deletions src/traffcap/dto/user.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from pydantic import BaseModel
from pydantic import BaseModel, Field


class UserBase(BaseModel):
Expand All @@ -11,7 +11,9 @@ class UserCreate(UserBase):


class User(UserBase):
id: int
id: int = Field(json_schema_extra={
"resource_id": True
})

class Config:
from_attributes = True
Expand Down
26 changes: 12 additions & 14 deletions src/traffcap/repositories/inbound_request_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,25 +17,23 @@ async def store_request(
Store the components of the request
"""
async with cls.session() as session:
async with session.begin():
new_inbound_request = await InboundRequestCreate.from_request(
endpoint_code,
request
)
session.add(InboundRequestModel(
endpoint_code=endpoint_code,
method=new_inbound_request.method,
headers=new_inbound_request.headers,
query_params=new_inbound_request.query_params,
body=new_inbound_request.body
))
new_inbound_request = await InboundRequestCreate.from_request(
endpoint_code,
request
)
await session.add(InboundRequestModel(**new_inbound_request.model_dump()))
await session.commit()

@classmethod
async def get_all_inbound_requests(cls) -> List[InboundRequest]:
requests = []
async with cls.session() as session:
results = await session.scalars(select(InboundRequestModel).order_by(InboundRequestModel.id.desc()))
results = await session.scalars(
select(InboundRequestModel).order_by(InboundRequestModel.id.desc())
)
for request in results.all():
requests.append(InboundRequest.model_validate(request, from_attributes=True))
requests.append(
InboundRequest.model_validate(request, from_attributes=True)
)

return requests
3 changes: 1 addition & 2 deletions src/traffcap/repositories/outbound_response_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,11 @@ async def get_by_rule_and_content_type(
"""
responses = []
async with cls.session() as session:
stmnt = (
results = await session.scalars(
select(OutboundResponseModel)
.where(OutboundResponseModel.rule_id == rule.id)
.where(OutboundResponseModel.content_type == content_type)
)
results = await session.scalars(stmnt)
for response in results.all():
responses.append(Rule.model_validate(response))

Expand Down
33 changes: 20 additions & 13 deletions src/traffcap/repositories/rule_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,19 @@ class RuleRepository(Repository):
@classmethod
async def get_rule_by_id(cls, rule_id: int) -> Optional[Rule]:
async with cls.session() as session:
rule = await session.get(RuleModel, rule_id)
return Rule.model_validate(rule)
return Rule.model_validate(
await session.get(RuleModel, rule_id)
)

return None

@classmethod
async def get_all_rules(cls) -> List[Rule]:
rules = []
async with cls.session() as session:
results = await session.scalars(select(RuleModel))
results = await session.scalars(
select(RuleModel)
)
for rule in results.all():
rules.append(Rule.model_validate(rule))

Expand All @@ -28,21 +31,23 @@ async def get_all_rules(cls) -> List[Rule]:
@classmethod
async def create_rule(cls, rule: str = ".*") -> Optional[Rule]:
async with cls.session() as session:
async with session.begin():
new_rule = RuleModel(rule=rule)
session.add(new_rule)
new_rule = RuleModel(rule=rule)
await session.add(new_rule)
await session.commit()

return_rule = await cls.get_rule_by_id(new_rule.id)
return Rule.model_validate(return_rule)
return Rule.model_validate(
await cls.get_rule_by_id(new_rule.id)
)

return None

@classmethod
async def delete_rule_by_id(cls, rule_id: int) -> None:
async with cls.session() as session:
async with session.begin():
rule = await session.get(RuleModel, rule_id)
await session.delete(rule)
await session.delete(
await session.get(RuleModel, rule_id)
)
await session.commit()

@classmethod
async def find_matching_rules(cls, rule: str) -> List[Rule]:
Expand All @@ -51,8 +56,10 @@ async def find_matching_rules(cls, rule: str) -> List[Rule]:
"""
rules = []
async with cls.session() as session:
stmnt = select(RuleModel).where(RuleModel.rule == rule)
results = await session.scalars(stmnt)
results = await session.scalars(
select(RuleModel)
.where(RuleModel.rule == rule)
)
for rule_item in results.all():
rules.append(Rule.model_validate(rule_item))

Expand Down
12 changes: 8 additions & 4 deletions src/traffcap/repositories/user_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,13 @@ class UserRepository(Repository):
async def add_a_test_user(cls) -> Optional[User]:
user = None
async with cls.session() as session:
async with session.begin():
user = UserModel(
await session.add(
UserModel(
email="centurix@gmail.com",
fullname="Chris Read"
)
session.add(user)
)
await session.commit()

return User.model_validate(user)

Expand All @@ -32,7 +33,10 @@ async def get_user_by_id(cls, user_id: int) -> Optional[User]:
async def get_all_users(cls) -> List[User]:
users = []
async with cls.session() as session:
results = await session.scalars(select(UserModel))
results = await session.scalars(
select(UserModel)
)
for user in results.all():
users.append(User.model_validate(user))

return users

0 comments on commit c01817d

Please sign in to comment.