Skip to content

Commit

Permalink
fix: Improving typing
Browse files Browse the repository at this point in the history
Implemented inverse comparisons in Django
Adding tests for inverse comparisons
  • Loading branch information
constantinius committed Nov 25, 2021
1 parent 9117412 commit 6c3584b
Show file tree
Hide file tree
Showing 4 changed files with 230 additions and 60 deletions.
116 changes: 61 additions & 55 deletions pygeofilter/backends/django/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from operator import and_, or_, add, sub, mul, truediv
from datetime import datetime, timedelta
from functools import reduce
from typing import Dict, List, Optional, Union

from django.db.models import Q, F, Value
from django.db.models.expressions import Expression
Expand All @@ -37,39 +38,23 @@
from django.contrib.gis.geos import Polygon
from django.contrib.gis.measure import D

ARITHMETIC_TYPES = (Expression, F, Value, int, float)
ArithmeticType = Union[Expression, F, Value, int, float]

# ------------------------------------------------------------------------------
# Filters
# ------------------------------------------------------------------------------


def combine(sub_filters, combinator="AND"):
def combine(sub_filters: List[Q], combinator: str = "AND") -> Q:
""" Combine filters using a logical combinator
:param sub_filters: the filters to combine
:param combinator: a string: "AND" / "OR"
:type sub_filters: list[django.db.models.Q]
:return: the combined filter
:rtype: :class:`django.db.models.Q`
"""
for sub_filter in sub_filters:
assert isinstance(sub_filter, Q)

assert combinator in ("AND", "OR")
op = and_ if combinator == "AND" else or_
return reduce(lambda acc, q: op(acc, q) if acc else q, sub_filters)


def negate(sub_filter):
def negate(sub_filter: Q) -> Q:
""" Negate a filter, opposing its meaning.
:param sub_filter: the filter to negate
:type sub_filter: :class:`django.db.models.Q`
:return: the negated filter
:rtype: :class:`django.db.models.Q`
"""
assert isinstance(sub_filter, Q)
return ~sub_filter


Expand All @@ -82,8 +67,16 @@ def negate(sub_filter):
"=": "exact"
}

INVERT_COMP = {
"lt": "gt",
"lte": "gte",
"gt": "lt",
"gte": "lte",
}


def compare(lhs, rhs, op, mapping_choices=None):
def compare(lhs: Union[F, Value], rhs: Union[F, Value], op: str,
mapping_choices: Optional[Dict[str, str]] = None) -> Q:
""" Compare a filter with an expression using a comparison operation
:param lhs: the field to compare
Expand All @@ -99,11 +92,19 @@ def compare(lhs, rhs, op, mapping_choices=None):
:return: a comparison expression object
:rtype: :class:`django.db.models.Q`
"""
assert isinstance(lhs, F)
# assert isinstance(rhs, Q) # TODO!!
assert op in OP_TO_COMP
comp = OP_TO_COMP[op]

# if the left hand side is not a field reference, the comparison
# can be be inverted to try if the right hand side is a field
# reference.
if not isinstance(lhs, F):
lhs, rhs = rhs, lhs
comp = INVERT_COMP.get(comp, comp)

# if neither lhs and rhs are fields, we have to fail here
if not isinstance(lhs, F):
raise ValueError(f'Unable to compare non-field {lhs}')

field_name = lhs.name

if mapping_choices and field_name in mapping_choices:
Expand All @@ -121,7 +122,8 @@ def compare(lhs, rhs, op, mapping_choices=None):
return ~Q(**{field_name: rhs})


def between(lhs, low, high, not_=False):
def between(lhs: F, low: Union[F, Value], high: Union[F, Value],
not_: bool = False) -> Q:
""" Create a filter to match elements that have a value within a certain
range.
Expand All @@ -137,15 +139,12 @@ def between(lhs, low, high, not_=False):
:return: a comparison expression object
:rtype: :class:`django.db.models.Q`
"""
assert isinstance(lhs, F)
# assert isinstance(low, BaseExpression)
# assert isinstance(high, BaseExpression) # TODO

q = Q(**{"%s__range" % lhs.name: (low, high)})
return ~q if not_ else q


def like(lhs, pattern, nocase=False, not_=False, mapping_choices=None):
def like(lhs: F, pattern: str, nocase: bool = False, not_: bool = False,
mapping_choices: Optional[Dict[str, str]] = None) -> Q:
""" Create a filter to filter elements according to a string attribute using
wildcard expressions.
Expand All @@ -165,9 +164,6 @@ def like(lhs, pattern, nocase=False, not_=False, mapping_choices=None):
:return: a comparison expression object
:rtype: :class:`django.db.models.Q`
"""
assert isinstance(lhs, F)
assert isinstance(pattern, str)

parts = pattern.split("%")
length = len(parts)

Expand Down Expand Up @@ -231,7 +227,8 @@ def like(lhs, pattern, nocase=False, not_=False, mapping_choices=None):
return ~q if not_ else q


def contains(lhs, items, not_=False, mapping_choices=None):
def contains(lhs: F, items: List[Union[F, Value]], not_: bool = False,
mapping_choices: Optional[Dict[str, str]] = None) -> Q:
""" Create a filter to match elements attribute to be in a list of choices.
:param lhs: the field to compare
Expand All @@ -247,9 +244,6 @@ def contains(lhs, items, not_=False, mapping_choices=None):
:return: a comparison expression object
:rtype: :class:`django.db.models.Q`
"""
assert isinstance(lhs, F)
# for item in items:
# assert isinstance(item, BaseExpression)

if mapping_choices and lhs.name in mapping_choices:
def map_value(item):
Expand All @@ -269,7 +263,7 @@ def map_value(item):
return ~q if not_ else q


def null(lhs, not_=False):
def null(lhs: F, not_: bool = False) -> Q:
""" Create a filter to match elements whose attribute is (not) null
:param lhs: the field to compare
Expand All @@ -280,11 +274,10 @@ def null(lhs, not_=False):
:return: a comparison expression object
:rtype: :class:`django.db.models.Q`
"""
assert isinstance(lhs, F)
return Q(**{"%s__isnull" % lhs.name: not not_})


def temporal(lhs, time_or_period, op):
def temporal(lhs: F, time_or_period: Value, op: str) -> Q:
""" Create a temporal filter for the given temporal attribute.
:param lhs: the field to compare
Expand All @@ -300,8 +293,6 @@ def temporal(lhs, time_or_period, op):
:return: a comparison expression object
:rtype: :class:`django.db.models.Q`
"""
assert isinstance(lhs, F)
assert isinstance(time_or_period, Value)
assert op in (
"BEFORE", "BEFORE OR DURING", "DURING", "DURING OR AFTER", "AFTER"
)
Expand Down Expand Up @@ -331,8 +322,9 @@ def temporal(lhs, time_or_period, op):
return Q(**{"%s__lte" % lhs.name: high})


def time_interval(time_or_period, containment='overlaps',
begin_time_field='begin_time', end_time_field='end_time'):
def time_interval(time_or_period: Value, containment: str = 'overlaps',
begin_time_field: str = 'begin_time',
end_time_field: str = 'end_time') -> Q:
"""
"""

Expand Down Expand Up @@ -390,7 +382,15 @@ def time_interval(time_or_period, containment='overlaps',
}


def spatial(lhs, rhs, op, pattern=None, distance=None, units=None):
INVERT_SPATIAL_OP = {
"WITHIN": "CONTAINS",
"CONTAINS": "WITHIN",
}


def spatial(lhs: Union[F, Value], rhs: Union[F, Value], op: str,
pattern: Optional[str] = None, distance: Optional[float] = None,
units: Optional[str] = None) -> Q:
""" Create a spatial filter for the given spatial attribute.
:param lhs: the field to compare
Expand All @@ -412,8 +412,6 @@ def spatial(lhs, rhs, op, pattern=None, distance=None, units=None):
:return: a comparison expression object
:rtype: :class:`django.db.models.Q`
"""
assert isinstance(lhs, F)
# assert isinstance(rhs, BaseExpression) # TODO

assert op in (
"INTERSECTS", "DISJOINT", "CONTAINS", "WITHIN", "TOUCHES", "CROSSES",
Expand All @@ -425,6 +423,17 @@ def spatial(lhs, rhs, op, pattern=None, distance=None, units=None):
assert distance
assert units

# if the left hand side is not a field reference, the comparison
# can be be inverted to try if the right hand side is a field
# reference.
if not isinstance(lhs, F):
lhs, rhs = rhs, lhs
op = INVERT_SPATIAL_OP.get(op, op)

# if neither lhs and rhs are fields, we have to fail here
if not isinstance(lhs, F):
raise ValueError(f'Unable to compare non-field {lhs}')

if op in (
"INTERSECTS", "DISJOINT", "CONTAINS", "WITHIN", "TOUCHES",
"CROSSES", "OVERLAPS", "EQUALS"):
Expand All @@ -439,7 +448,8 @@ def spatial(lhs, rhs, op, pattern=None, distance=None, units=None):
return Q(**{"%s__distance_gte" % lhs.name: (rhs, d, 'spheroid')})


def bbox(lhs, minx, miny, maxx, maxy, crs=None, bboverlaps=True):
def bbox(lhs: F, minx: float, miny: float, maxx, maxy: float,
crs: Optional[str] = None, bboverlaps: bool = True) -> Q:
""" Create a bounding box filter for the given spatial attribute.
:param lhs: the field to compare
Expand All @@ -457,7 +467,6 @@ def bbox(lhs, minx, miny, maxx, maxy, crs=None, bboverlaps=True):
:return: a comparison expression object
:rtype: :class:`django.db.models.Q`
"""
assert isinstance(lhs, F)
box = Polygon.from_bbox((minx, miny, maxx, maxy))

if crs:
Expand All @@ -469,7 +478,7 @@ def bbox(lhs, minx, miny, maxx, maxy, crs=None, bboverlaps=True):
return Q(**{"%s__intersects" % lhs.name: box})


def attribute(name, field_mapping=None):
def attribute(name: str, field_mapping: Optional[Dict[str, str]] = None) -> F:
""" Create an attribute lookup expression using a field mapping dictionary.
:param name: the field filter name
Expand All @@ -485,7 +494,7 @@ def attribute(name, field_mapping=None):
return F(field)


def literal(value):
def literal(value) -> Value:
return Value(value)


Expand All @@ -497,7 +506,8 @@ def literal(value):
}


def arithmetic(lhs, rhs, op):
def arithmetic(lhs: ArithmeticType, rhs: ArithmeticType,
op: str) -> ArithmeticType:
""" Create an arithmetic filter
:param lhs: left hand side of the arithmetic expression. either a
Expand All @@ -507,9 +517,5 @@ def arithmetic(lhs, rhs, op):
``"/"``
:rtype: :class:`django.db.models.F`
"""

assert isinstance(lhs, ARITHMETIC_TYPES), f'{lhs} is not a compatible type'
assert isinstance(rhs, ARITHMETIC_TYPES), f'{rhs} is not a compatible type'
assert op in OP_TO_FUNC
func = OP_TO_FUNC[op]
return func(lhs, rhs)
15 changes: 11 additions & 4 deletions pygeofilter/backends/native/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def __init__(self, function_map: Dict[str, Callable] = None,
self.attribute_map = attribute_map
self.use_getattr = use_getattr
self.allow_nested_attributes = allow_nested_attributes
self.locals = {}
self.locals: Dict[str, Any] = {}
self.local_count = 0

def _add_local(self, value: Any) -> str:
Expand Down Expand Up @@ -328,9 +328,15 @@ def to_interval(value: MaybeInterval) -> InternalInterval:
high = datetime.combine(high, time.max, timezone.utc)

if isinstance(low, timedelta):
low = high - low
if isinstance(high, (date, datetime)):
low = high - low
else:
raise ValueError(f'Cannot combine {low} with {high}')
elif isinstance(high, timedelta):
high = low + high
if isinstance(high, (date, datetime)):
high = low + high
else:
raise ValueError(f'Cannot combine {low} with {high}')

return (low, high)

Expand All @@ -356,7 +362,8 @@ def relate_intervals(lhs: InternalInterval,
"""
ll, lh = lhs
rl, rh = rhs
if None in (ll, lh, rl, rh):
if ll is None or lh is None or rl is None or rh is None:
# TODO: handle open ended intervals (None on either side)
return ast.TemporalComparisonOp.DISJOINT
elif lh < rl:
return ast.TemporalComparisonOp.BEFORE
Expand Down
1 change: 0 additions & 1 deletion tests/backends/django/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@


def pytest_configure():
print('configuring')
settings.configure(
SECRET_KEY="secret",
INSTALLED_APPS=[
Expand Down
Loading

0 comments on commit 6c3584b

Please sign in to comment.