Skip to content

Commit

Permalink
Adds 'generic' InstitutionTypeDto and /v1/institutions/types/{type} e…
Browse files Browse the repository at this point in the history
…ndpoint

Adds endpoints for getting address-states and regulators
Adds institution repo get_ functions for the above
Let commented out code in for specific Dtos in case the general approach isn't desired
  • Loading branch information
jcadam14 committed Dec 28, 2023
1 parent 176ad6e commit 20771a1
Show file tree
Hide file tree
Showing 7 changed files with 151 additions and 18 deletions.
10 changes: 6 additions & 4 deletions src/entities/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@
"SBLInstitutionTypeDao",
"AddressStateDao",
"FederalRegulatorDto",
"HMDAInstitutionTypeDto",
"SBLInstitutionTypeDto",
#"HMDAInstitutionTypeDto",
"InstitutionTypeDto",
#"SBLInstitutionTypeDto",
"AddressStateDto",
]

Expand All @@ -41,7 +42,8 @@
UserProfile,
AuthenticatedUser,
FederalRegulatorDto,
HMDAInstitutionTypeDto,
SBLInstitutionTypeDto,
#HMDAInstitutionTypeDto,
InstitutionTypeDto,
#SBLInstitutionTypeDto,
AddressStateDto,
)
2 changes: 1 addition & 1 deletion src/entities/models/dao.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class FederalRegulatorDao(AuditMixin, Base):
__tablename__ = "federal_regulator"
id: Mapped[str] = mapped_column(String(4), index=True, primary_key=True, unique=True)
name: Mapped[str] = mapped_column(unique=True, nullable=False)


class HMDAInstitutionTypeDao(AuditMixin, Base):
__tablename__ = "hmda_institution_type"
Expand Down
31 changes: 20 additions & 11 deletions src/entities/models/dto.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,26 +80,35 @@ class Config:
from_attributes = True


class HMDAInstitutionTypeBase(BaseModel):
class InstitutionTypeDto(BaseModel ):
id: str


class HMDAInstitutionTypeDto(HMDAInstitutionTypeBase):
name: str

class Config:
from_attributes = True

# Let this in here just in case the 'generic' InstitutionTypeDto approach isn't desired
#
# class HMDAInstitutionTypeBase(BaseModel):
# id: str

class SBLInstitutionTypeBase(BaseModel):
id: str

# class HMDAInstitutionTypeDto(HMDAInstitutionTypeBase):
# name: str
#
# class Config:
# from_attributes = True

class SBLInstitutionTypeDto(SBLInstitutionTypeBase):
name: str

class Config:
from_attributes = True
# class SBLInstitutionTypeBase(BaseModel):
# id: str


# class SBLInstitutionTypeDto(SBLInstitutionTypeBase):
# name: str
#
# class Config:
# from_attributes = True


class AddressStateBase(BaseModel):
Expand Down
37 changes: 37 additions & 0 deletions src/entities/repos/institutions_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,11 @@
FinancialInstitutionDomainDao,
FinancialInstitutionDto,
FinancialInsitutionDomainCreate,
HMDAInstitutionTypeDao,
SBLInstitutionTypeDao,
DeniedDomainDao,
AddressStateDao,
FederalRegulatorDao
)


Expand Down Expand Up @@ -43,7 +47,40 @@ async def get_institution(session: AsyncSession, lei: str) -> FinancialInstituti
.filter(FinancialInstitutionDao.lei == lei)
)
return await session.scalar(stmt)


async def get_sbl_types(session: AsyncSession) -> SBLInstitutionTypeDao:
async with session.begin():
stmt = (
select(SBLInstitutionTypeDao)
)
res = await session.scalars(stmt)
return res.all()


async def get_hmda_types(session: AsyncSession) -> HMDAInstitutionTypeDao:
async with session.begin():
stmt = (
select(HMDAInstitutionTypeDao)
)
res = await session.scalars(stmt)
return res.all()

async def get_address_states(session: AsyncSession) -> AddressStateDao:
async with session.begin():
stmt = (
select(AddressStateDao)
)
res = await session.scalars(stmt)
return res.all()

async def get_federal_regulators(session: AsyncSession) -> FederalRegulatorDao:
async with session.begin():
stmt = (
select(FederalRegulatorDao)
)
res = await session.scalars(stmt)
return res.all()

async def upsert_institution(session: AsyncSession, fi: FinancialInstitutionDto) -> FinancialInstitutionDao:
async with session.begin():
Expand Down
33 changes: 31 additions & 2 deletions src/routers/institutions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from oauth2 import oauth2_admin
from util import Router
from dependencies import check_domain, parse_leis, get_email_domain
from typing import Annotated, List, Tuple
from typing import Annotated, List, Tuple, Literal
from entities.engine import get_session
from entities.repos import institutions_repo as repo
from entities.models import (
Expand All @@ -12,11 +12,17 @@
FinancialInsitutionDomainDto,
FinancialInsitutionDomainCreate,
FinanicialInstitutionAssociationDto,
# HMDAInstitutionTypeDto,
InstitutionTypeDto,
# SBLInstitutionTypeDto,
AuthenticatedUser,
AddressStateDto,
FederalRegulatorDto
)
from sqlalchemy.ext.asyncio import AsyncSession
from starlette.authentication import requires

InstitutionType = Literal["sbl", "hmda"]

async def set_db(request: Request, session: Annotated[AsyncSession, Depends(get_session)]):
request.state.db_session = session
Expand Down Expand Up @@ -79,6 +85,29 @@ async def get_associated_institutions(request: Request):
for institution in associated_institutions
]

@router.get("/types/{type}", response_model=List[InstitutionTypeDto])
@requires("authenticated")
async def get_institution_types(
request: Request,
type: InstitutionType
):
if type == "sbl":
return await repo.get_sbl_types(request.state.db_session)
else:
return await repo.get_hmda_types(request.state.db_session)


@router.get("/address-states", response_model=List[AddressStateDto])
@requires("authenticated")
async def get_address_states(request: Request):
return await repo.get_address_states(request.state.db_session)


@router.get("/regulators", response_model=List[FederalRegulatorDto])
@requires("authenticated")
async def get_federal_regulators(request: Request):
return await repo.get_federal_regulators(request.state.db_session)


@router.get("/{lei}", response_model=FinancialInstitutionWithDomainsDto)
@requires("authenticated")
Expand All @@ -104,4 +133,4 @@ async def add_domains(

@router.get("/domains/allowed", response_model=bool)
async def is_domain_allowed(request: Request, domain: str):
return await repo.is_domain_allowed(request.state.db_session, domain)
return await repo.is_domain_allowed(request.state.db_session, domain)
32 changes: 32 additions & 0 deletions tests/api/routers/test_institutions_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,3 +347,35 @@ def test_get_associated_institutions_with_no_institutions(
assert res.status_code == 200
get_institutions_mock.assert_called_once_with(ANY, [])
assert res.json() == []

def test_get_institution_types(
self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock):
mock = mocker.patch("entities.repos.institutions_repo.get_sbl_types")
mock.return_value = []
client = TestClient(app_fixture)
res = client.get("/v1/institutions/types/sbl")
assert res.status_code == 200

mock = mocker.patch("entities.repos.institutions_repo.get_hmda_types")
mock.return_value = []
res = client.get("/v1/institutions/types/hmda")
assert res.status_code == 200

res = client.get("/v1/institutions/types/blah")
assert res.status_code == 422

def test_get_address_states(
self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock):
mock = mocker.patch("entities.repos.institutions_repo.get_address_states")
mock.return_value = []
client = TestClient(app_fixture)
res = client.get("/v1/institutions/address-states")
assert res.status_code == 200

def test_get_federal_regulators(
self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock):
mock = mocker.patch("entities.repos.institutions_repo.get_federal_regulators")
mock.return_value = []
client = TestClient(app_fixture)
res = client.get("/v1/institutions/regulators")
assert res.status_code == 200
24 changes: 24 additions & 0 deletions tests/entities/repos/test_institutions_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,31 @@ async def setup(
transaction_session.add(fi_dao_456)
transaction_session.add(fi_dao_sub_456)
await transaction_session.commit()

async def test_get_sbl_types(self, query_session: AsyncSession):
expected_ids = {"SIT1", "SIT2", "SIT3"}
res = await repo.get_sbl_types(query_session)
assert len(res) == 3
assert set([r.id for r in res]) == expected_ids

async def test_get_hmda_types(self, query_session: AsyncSession):
expected_ids = {"HIT1", "HIT2", "HIT3"}
res = await repo.get_hmda_types(query_session)
assert len(res) == 3
assert set([r.id for r in res]) == expected_ids

async def test_get_address_states(self, query_session: AsyncSession):
expected_codes = {"CA", "GA", "FL"}
res = await repo.get_address_states(query_session)
assert len(res) == 3
assert set([r.code for r in res]) == expected_codes

async def test_get_federal_regulators(self, query_session: AsyncSession):
expected_ids = {"FRI1", "FRI2", "FRI3"}
res = await repo.get_federal_regulators(query_session)
assert len(res) == 3
assert set([r.id for r in res]) == expected_ids

async def test_get_institutions(self, query_session: AsyncSession):
res = await repo.get_institutions(query_session)
assert len(res) == 3
Expand Down

0 comments on commit 20771a1

Please sign in to comment.