Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Datadir override #118

Merged
merged 9 commits into from
Mar 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ on:
pull_request:
workflow_dispatch:
schedule:
# every other sunday at noon.
# twice per month
- cron: "0 10 1,15 * *"

jobs:
Expand Down
30 changes: 17 additions & 13 deletions nereid/nereid/api/api_v1/async_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Union

from celery import Task
from celery.exceptions import TimeoutError
from celery.result import AsyncResult
from fastapi import APIRouter
from fastapi import Request

from nereid.core.config import settings

Expand All @@ -21,33 +21,37 @@ def wait_a_sec_and_see_if_we_can_return_some_data(


def run_task(
request: Request,
task: Task,
router: APIRouter,
get_route: str,
get_route: str = "get_task",
force_foreground: Optional[bool] = False,
timeout: float = 0.2,
) -> Dict[str, Any]:

if force_foreground or settings.FORCE_FOREGROUND: # pragma: no cover
response = dict(data=task(), task_id="foreground", result_route="foreground")
task_ret: Union[bytes, str] = task()
response = {
"data": task_ret,
"task_id": "foreground",
"result_route": "foreground",
}

else:
response = standard_json_response(task.apply_async(), router, get_route)
response = standard_json_response(
request, task.apply_async(), get_route=get_route, timeout=timeout
)

return response


def standard_json_response(
request: Request,
task: AsyncResult,
router: APIRouter,
get_route: str,
get_route: str = "get_task",
timeout: float = 0.2,
api_version: str = settings.API_LATEST,
) -> Dict[str, Any]:
router_path = router.url_path_for(get_route, task_id=task.id)

result_route = f"{api_version}{router_path}"

_ = wait_a_sec_and_see_if_we_can_return_some_data(task, timeout=timeout)
result_route = str(request.url_for(get_route, task_id=task.id))

response = dict(task_id=task.task_id, status=task.status, result_route=result_route)

Expand Down
14 changes: 8 additions & 6 deletions nereid/nereid/api/api_v1/endpoints_async/land_surface_loading.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any, Dict

from fastapi import APIRouter, Body, Depends
from fastapi import APIRouter, Body, Depends, Request
from fastapi.responses import ORJSONResponse

import nereid.bg_worker as bg
Expand All @@ -21,6 +21,7 @@
response_class=ORJSONResponse,
)
async def calculate_loading(
request: Request,
land_surfaces: LandSurfaces = Body(...),
details: bool = False,
context: dict = Depends(get_valid_context),
Expand All @@ -32,9 +33,7 @@ async def calculate_loading(
land_surfaces=land_surfaces_req, details=details, context=context
)

return run_task(
task=task, router=router, get_route="get_land_surface_loading_result"
)
return run_task(request, task, "get_land_surface_loading_result")


@router.get(
Expand All @@ -43,6 +42,9 @@ async def calculate_loading(
response_model=LandSurfaceResponse,
response_class=ORJSONResponse,
)
async def get_land_surface_loading_result(task_id: str) -> Dict[str, Any]:
async def get_land_surface_loading_result(
request: Request,
task_id: str,
) -> Dict[str, Any]:
task = bg.land_surface_loading.AsyncResult(task_id, app=router)
return standard_json_response(task, router, "get_land_surface_loading_result")
return standard_json_response(request, task, "get_land_surface_loading_result")
25 changes: 15 additions & 10 deletions nereid/nereid/api/api_v1/endpoints_async/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,12 @@
response_class=ORJSONResponse,
)
async def validate_network(
graph: network_models.Graph = Body(..., examples=network_models.GraphExamples)
request: Request,
graph: network_models.Graph = Body(..., examples=network_models.GraphExamples),
) -> Dict[str, Any]:

task = bg.validate_network.s(graph=graph.dict(by_alias=True))
return run_task(task=task, router=router, get_route="get_validate_network_result")
return run_task(request, task, "get_validate_network_result")


@router.get(
Expand All @@ -36,10 +37,10 @@ async def validate_network(
response_model=network_models.NetworkValidationResponse,
response_class=ORJSONResponse,
)
async def get_validate_network_result(task_id: str) -> Dict[str, Any]:
async def get_validate_network_result(request: Request, task_id: str) -> Dict[str, Any]:

task = bg.validate_network.AsyncResult(task_id, app=router)
return standard_json_response(task, router, "get_validate_network_result")
return standard_json_response(request, task, "get_validate_network_result")


@router.post(
Expand All @@ -49,12 +50,13 @@ async def get_validate_network_result(task_id: str) -> Dict[str, Any]:
response_class=ORJSONResponse,
)
async def subgraph_network(
request: Request,
subgraph_req: network_models.SubgraphRequest = Body(...),
) -> Dict[str, Any]:

task = bg.network_subgraphs.s(**subgraph_req.dict(by_alias=True))

return run_task(task=task, router=router, get_route="get_subgraph_network_result")
return run_task(request, task, "get_subgraph_network_result")


@router.get(
Expand All @@ -63,10 +65,10 @@ async def subgraph_network(
response_model=network_models.SubgraphResponse,
response_class=ORJSONResponse,
)
async def get_subgraph_network_result(task_id: str) -> Dict[str, Any]:
async def get_subgraph_network_result(request: Request, task_id: str) -> Dict[str, Any]:

task = bg.network_subgraphs.AsyncResult(task_id, app=router)
return standard_json_response(task, router, "get_subgraph_network_result")
return standard_json_response(request, task, "get_subgraph_network_result")


@router.get(
Expand Down Expand Up @@ -125,6 +127,7 @@ async def get_subgraph_network_as_img(
response_class=ORJSONResponse,
)
async def network_solution_sequence(
request: Request,
graph: network_models.Graph = Body(..., examples=network_models.GraphExamples),
min_branch_size: int = Query(4),
) -> Dict[str, Any]:
Expand All @@ -133,7 +136,7 @@ async def network_solution_sequence(
graph=graph.dict(by_alias=True), min_branch_size=min_branch_size
)

return run_task(task=task, router=router, get_route="get_network_solution_sequence")
return run_task(request, task, "get_network_solution_sequence")


@router.get(
Expand All @@ -142,10 +145,12 @@ async def network_solution_sequence(
response_model=network_models.SolutionSequenceResponse,
response_class=ORJSONResponse,
)
async def get_network_solution_sequence(task_id: str) -> Dict[str, Any]:
async def get_network_solution_sequence(
request: Request, task_id: str
) -> Dict[str, Any]:

task = bg.solution_sequence.AsyncResult(task_id, app=router)
return standard_json_response(task, router, "get_network_solution_sequence")
return standard_json_response(request, task, "get_network_solution_sequence")


@router.get(
Expand Down
16 changes: 6 additions & 10 deletions nereid/nereid/api/api_v1/endpoints_async/reference_data.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import base64
from io import BytesIO
from pathlib import Path
from typing import Any, Dict, Optional, Union
from typing import Any, Callable, Dict, Optional, Union

from fastapi import APIRouter, Depends, HTTPException
from fastapi.requests import Request
Expand Down Expand Up @@ -45,10 +45,11 @@ async def get_reference_data_json(
state, region = context["state"], context["region"]

if filepath.is_file():
filedata: Union[Dict[str, Any], str] = ""
loader: Callable[[Union[Path, str]], Union[Dict[str, Any], str]] = load_file
if "json" in filepath.suffix.lower():
filedata: Dict[str, Any] = load_json(filepath)
else:
filedata: str = load_file(filepath) # type: ignore
loader = load_json
filedata = loader(filepath)

else:
detail = f"state '{state}', region '{region}', or filename '{filename}' not found. {filepath}"
Expand All @@ -62,12 +63,7 @@ async def get_reference_data_json(
return response


@router.get(
"/reference_data/nomograph",
tags=["reference_data"],
# response_model=ReferenceDataResponse,
# response_class=ORJSONResponse,
)
@router.get("/reference_data/nomograph", tags=["reference_data"])
async def get_nomograph(
request: Request,
context: dict = Depends(get_valid_context),
Expand Down
6 changes: 3 additions & 3 deletions nereid/nereid/api/api_v1/endpoints_async/tasks.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any, Dict

from fastapi import APIRouter
from fastapi import APIRouter, Request
from fastapi.responses import ORJSONResponse

from nereid.api.api_v1.async_utils import standard_json_response
Expand All @@ -11,6 +11,6 @@


@router.get("/{task_id}", response_model=JSONAPIResponse)
async def get_task(task_id: str) -> Dict[str, Any]:
async def get_task(request: Request, task_id: str) -> Dict[str, Any]:
task = celery_app.AsyncResult(task_id)
return standard_json_response(task, router=router, get_route="get_task")
return standard_json_response(request, task)
15 changes: 8 additions & 7 deletions nereid/nereid/api/api_v1/endpoints_async/treatment_facilities.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any, Dict, Tuple, Union

from fastapi import APIRouter, Body, Depends
from fastapi import APIRouter, Body, Depends, Request
from fastapi.responses import ORJSONResponse

import nereid.bg_worker as bg
Expand Down Expand Up @@ -40,9 +40,10 @@ def validate_facility_request(
response_class=ORJSONResponse,
)
async def initialize_treatment_facility_parameters(
request: Request,
tmnt_facility_req: Tuple[TreatmentFacilities, Dict[str, Any]] = Depends(
validate_facility_request
)
),
) -> Dict[str, Any]:

treatment_facilities, context = tmnt_facility_req
Expand All @@ -52,9 +53,7 @@ async def initialize_treatment_facility_parameters(
pre_validated=True,
context=context,
)
return run_task(
task=task, router=router, get_route="get_treatment_facility_parameters"
)
return run_task(request, task, "get_treatment_facility_parameters")


@router.get(
Expand All @@ -63,6 +62,8 @@ async def initialize_treatment_facility_parameters(
response_model=TreatmentFacilitiesResponse,
response_class=ORJSONResponse,
)
async def get_treatment_facility_parameters(task_id: str) -> Dict[str, Any]:
async def get_treatment_facility_parameters(
request: Request, task_id: str
) -> Dict[str, Any]:
task = bg.initialize_treatment_facilities.AsyncResult(task_id, app=router)
return standard_json_response(task, router, "get_treatment_facility_parameters")
return standard_json_response(request, task, "get_treatment_facility_parameters")
9 changes: 5 additions & 4 deletions nereid/nereid/api/api_v1/endpoints_async/watershed.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any, Dict, Tuple

from fastapi import APIRouter, Depends
from fastapi import APIRouter, Depends, Request
from fastapi.responses import ORJSONResponse

import nereid.bg_worker as bg
Expand Down Expand Up @@ -39,6 +39,7 @@ def validate_watershed_request(
response_class=ORJSONResponse,
)
async def post_solve_watershed(
request: Request,
watershed_pkg: Tuple[Dict[str, Any], Dict[str, Any]] = Depends(
validate_watershed_request
),
Expand All @@ -48,7 +49,7 @@ async def post_solve_watershed(
task = bg.solve_watershed.s(
watershed=watershed, treatment_pre_validated=True, context=context
)
return run_task(task=task, router=router, get_route="get_watershed_result")
return run_task(request, task, "get_watershed_result")


@router.get(
Expand All @@ -57,6 +58,6 @@ async def post_solve_watershed(
response_model=WatershedResponse,
response_class=ORJSONResponse,
)
async def get_watershed_result(task_id: str) -> Dict[str, Any]:
async def get_watershed_result(request: Request, task_id: str) -> Dict[str, Any]:
task = bg.solve_watershed.AsyncResult(task_id, app=router)
return standard_json_response(task, router, "get_watershed_result")
return standard_json_response(request, task, "get_watershed_result")
26 changes: 11 additions & 15 deletions nereid/nereid/api/api_v1/endpoints_sync/reference_data.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import base64
from io import BytesIO
from pathlib import Path
from typing import Any, Dict, Optional, Union
from typing import Any, Callable, Dict, Optional, Union

from fastapi import APIRouter, Depends, HTTPException
from fastapi.requests import Request
Expand All @@ -17,8 +17,7 @@

@router.get("/reference_data_file", tags=["reference_data"])
async def get_reference_data_file(
context: dict = Depends(get_valid_context),
filename: str = "",
context: dict = Depends(get_valid_context), filename: str = ""
) -> FileResponse:

filepath = Path(context.get("data_path", "")) / filename
Expand All @@ -39,21 +38,23 @@ async def get_reference_data_file(
response_class=ORJSONResponse,
)
async def get_reference_data_json(
context: dict = Depends(get_valid_context),
filename: str = "",
context: dict = Depends(get_valid_context), filename: str = ""
) -> Dict[str, Any]:

filepath = Path(context.get("data_path", "")) / filename
state, region = context["state"], context["region"]

if filepath.is_file():
filedata: Union[Dict[str, Any], str] = ""
loader: Callable[[Union[Path, str]], Union[Dict[str, Any], str]] = load_file
if "json" in filepath.suffix.lower():
filedata: Dict[str, Any] = load_json(filepath)
else:
filedata: str = load_file(filepath) # type: ignore
loader = load_json
filedata = loader(filepath)

else:
detail = f"state '{state}', region '{region}', or filename '{filename}' not found. {filepath}"
detail = (
f"state '{state}', region '{region}', or filename '{filename}' not found."
)
raise HTTPException(status_code=400, detail=detail)

response = dict(
Expand All @@ -64,12 +65,7 @@ async def get_reference_data_json(
return response


@router.get(
"/reference_data/nomograph",
tags=["reference_data"],
# response_model=ReferenceDataResponse,
# response_class=ORJSONResponse,
)
@router.get("/reference_data/nomograph", tags=["reference_data"])
async def get_nomograph(
request: Request,
context: dict = Depends(get_valid_context),
Expand Down
Loading