Skip to content

Commit

Permalink
Add kwarg labels to search funcs
Browse files Browse the repository at this point in the history
  • Loading branch information
devdupont committed Jul 1, 2024
1 parent 0639796 commit e9509b5
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 46 deletions.
22 changes: 11 additions & 11 deletions avwx_api/api/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
"""
Functional API endpoints separate from static views
"""
"""Functional API endpoints separate from static views."""

# stdlib
import json
Expand All @@ -17,7 +15,7 @@
# module
import avwx
from avwx_api_core.views import AuthView, Token, make_token_check
from avwx_api import app, handle, structs, validate
from avwx_api import app, structs, validate
from avwx_api.handle.base import ManagerHandler, ReportHandler


Expand All @@ -28,7 +26,7 @@


def parse_params(func):
"""Collects and parses endpoint parameters"""
"""Collect and parses endpoint parameters."""

@wraps(func)
async def wrapper(self, **kwargs):
Expand All @@ -50,9 +48,9 @@ class Base(AuthView):

validator: validate.Schema
struct: structs.Params
report_type: str = None
handler: handle.base.ReportHandler = None
handlers: dict[str, handle.base.ReportHandler] = None
report_type: str | None = None
handler: ReportHandler | None = None
handlers: dict[str, ReportHandler] = None

# Name of parameter used for report location
loc_param: str = "station"
Expand Down Expand Up @@ -98,7 +96,7 @@ class Report(Base):
@crossdomain(origin="*", headers=HEADERS)
@parse_params
@token_check
async def get(self, params: structs.Params, token: Optional[Token]) -> Response:
async def get(self, params: structs.Report, token: Optional[Token]) -> Response:
"""GET handler returning reports"""
config = structs.ParseConfig.from_params(params, token)
await app.station.from_params(params, params.report_type)
Expand All @@ -123,7 +121,9 @@ class Parse(Base):
async def post(self, token: Optional[Token], **kwargs) -> Response:
"""POST handler to parse given reports"""
data = await request.data
params = self.validate_params(report=data.decode() or None, **kwargs)
params: structs.ReportGiven = self.validate_params(
report=data.decode() or None, **kwargs
)
if isinstance(params, dict):
return self.make_response(params, code=400)
config = structs.ParseConfig.from_params(params, token)
Expand Down Expand Up @@ -168,7 +168,7 @@ def split_distances(data: list[avwx.Station | dict]) -> tuple[list, dict]:
@crossdomain(origin="*", headers=HEADERS)
@parse_params
@token_check
async def get(self, params: structs.Params, token: Optional[Token]) -> Response:
async def get(self, params: structs.Report, token: Optional[Token]) -> Response:
"""GET handler returning multiple reports"""
locations, distances = self.split_distances(self.get_locations(params))
config = structs.ParseConfig.from_params(params, token)
Expand Down
67 changes: 32 additions & 35 deletions avwx_api/api/search.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
"""
Search API endpoints
"""

# pylint: disable=arguments-differ,too-many-ancestors
"""Search API endpoints."""

# stdlib
from typing import Any, Optional
Expand All @@ -12,14 +8,13 @@
from quart_openapi.cors import crossdomain

# module
import avwx
import avwx.station
from avwx_api_core.token import Token
import avwx_api.handle.current as handle
from avwx_api import app, structs, validate
from avwx_api.api.base import Base, HEADERS, MultiReport, parse_params, token_check
from avwx_api.station_manager import station_data_for


SEARCH_HANDLERS = {
"metar": handle.MetarHandler,
"taf": handle.TafHandler,
Expand All @@ -32,7 +27,7 @@
def check_count_limit(
count: int, token: Optional[Token], plans: tuple[str]
) -> Optional[dict]:
"""Returns an error payload if the count is greater than the user is allowed"""
"""Return an error payload if the count is greater than the user is allowed."""
if count <= COUNT_MAX or token is None:
return None
if token.is_developer or token.valid_type(plans):
Expand All @@ -45,36 +40,38 @@ def check_count_limit(


def arg_matching(target: Any, args: tuple[Any]) -> Any:
"""Returns the first arg matching the target type"""
"""Return the first arg matching the target type."""
return next((arg for arg in args if isinstance(arg, target)), None)


@app.route("/api/station/near/<coord>")
class Near(Base):
"""Returns stations near a coordinate pair"""
"""Return stations near a coordinate pair."""

validator = validate.coord_search
struct = structs.CoordSearch
loc_param = "coord"
example = "stations_near"

def validate_token_parameters(self, token: Token, *args) -> Optional[dict]:
"""Returns an error payload if parameter validation doesn't match plan level"""
params = arg_matching(self.struct, args) # pylint: disable=no-member
"""Return an error payload if parameter validation doesn't match plan level."""
params: structs.StationSearch = arg_matching(self.struct, args)
return check_count_limit(params.n, token, PAID_PLANS)

@crossdomain(origin="*", headers=HEADERS)
@parse_params
@token_check
async def get(self, params: structs.Params, token: Optional[Token]) -> Response:
"""Returns stations near a coordinate pair"""
async def get(
self, params: structs.CoordSearch, token: Optional[Token]
) -> Response:
"""Return stations near a coordinate pair."""
stations = avwx.station.nearest(
params.coord.lat,
params.coord.lon,
params.n,
params.airport,
params.reporting,
params.maxdist,
is_airport=params.airport,
sends_reports=params.reporting,
max_coord_distance=params.maxdist,
)
if isinstance(stations, dict):
stations = [stations]
Expand All @@ -85,32 +82,32 @@ async def get(self, params: structs.Params, token: Optional[Token]) -> Response:

@app.route("/api/search/station")
class TextSearch(Base):
"""Returns stations from a text-based search"""
"""Return stations from a text-based search."""

validator = validate.text_search
struct = structs.TextSearch
example = "station_search"

def validate_token_parameters(self, token: Token, *args) -> Optional[dict]:
"""Returns an error payload if parameter validation doesn't match plan level"""
params = arg_matching(self.struct, args) # pylint: disable=no-member
"""Return an error payload if parameter validation doesn't match plan level."""
params: structs.StationSearch = arg_matching(self.struct, args)
return check_count_limit(params.n, token, PAID_PLANS)

@crossdomain(origin="*", headers=HEADERS)
@parse_params
@token_check
async def get(self, params: structs.Params, token: Optional[Token]) -> Response:
"""Returns stations from a text-based search"""
async def get(self, params: structs.TextSearch, token: Optional[Token]) -> Response:
"""Return stations from a text-based search."""
stations = avwx.station.search(
params.text, params.n, params.airport, params.reporting
params.text, params.n, is_airport=params.airport, sends_reports=params.reporting
)
stations = [await station_data_for(s, token=token) for s in stations]
return self.make_response(stations, params)


@app.route("/api/<report_type>/near/<coord>")
class ReportCoordSearch(MultiReport):
"""Returns reports nearest to a coordinate"""
"""Return reports nearest to a coordinate."""

validator = validate.report_coord_search
struct = structs.ReportCoordSearch
Expand All @@ -123,18 +120,18 @@ class ReportCoordSearch(MultiReport):
log_postfix = "coord"

def validate_token_parameters(self, token: Token, *args) -> Optional[dict]:
"""Returns an error payload if parameter validation doesn't match plan level"""
params = arg_matching(self.struct, args) # pylint: disable=no-member
"""Return an error payload if parameter validation doesn't match plan level."""
params = arg_matching(self.struct, args)
return check_count_limit(params.n, token, ("enterprise",))

def get_locations(self, params: structs.Params) -> list[dict]:
def get_locations(self, params: structs.CoordSearch) -> list[dict]:
stations = avwx.station.nearest(
params.coord.lat,
params.coord.lon,
params.n,
params.airport,
params.reporting,
params.maxdist,
is_airport=params.airport,
sends_reports=params.reporting,
max_coord_distance=params.maxdist,
)
if isinstance(stations, dict):
stations = [stations]
Expand All @@ -143,7 +140,7 @@ def get_locations(self, params: structs.Params) -> list[dict]:

@app.route("/api/search/<report_type>")
class ReportTextSearch(MultiReport):
"""Returns reports from a text-based search"""
"""Return reports from a text-based search."""

validator = validate.report_text_search
struct = structs.ReportTextSearch
Expand All @@ -155,11 +152,11 @@ class ReportTextSearch(MultiReport):
log_postfix = "search"

def validate_token_parameters(self, token: Token, *args) -> Optional[dict]:
"""Returns an error payload if parameter validation doesn't match plan level"""
params = arg_matching(self.struct, args) # pylint: disable=no-member
"""Return an error payload if parameter validation doesn't match plan level."""
params: structs.StationSearch = arg_matching(self.struct, args)
return check_count_limit(params.n, token, ("enterprise",))

def get_locations(self, params: structs.Params) -> list[dict]:
def get_locations(self, params: structs.TextSearch) -> list[dict]:
return avwx.station.search(
params.text, params.n, params.airport, params.reporting
params.text, params.n, is_airport=params.airport, sends_reports=params.reporting,
)

0 comments on commit e9509b5

Please sign in to comment.