From d829c6be5254a45689d8bcdb52b28b8a5ed3b5b2 Mon Sep 17 00:00:00 2001 From: Ariana Barzinpour Date: Mon, 8 Apr 2024 07:04:28 +0000 Subject: [PATCH] add optional undefined_as_null param in to_filter --- pygeofilter/backends/sqlalchemy/evaluate.py | 17 ++++++++-------- pygeofilter/backends/sqlalchemy/filters.py | 16 ++++++++++----- tests/backends/sqlalchemy/test_evaluate.py | 22 +++++++++++++++++++-- 3 files changed, 39 insertions(+), 16 deletions(-) diff --git a/pygeofilter/backends/sqlalchemy/evaluate.py b/pygeofilter/backends/sqlalchemy/evaluate.py index 334b82a..c073a67 100644 --- a/pygeofilter/backends/sqlalchemy/evaluate.py +++ b/pygeofilter/backends/sqlalchemy/evaluate.py @@ -8,8 +8,9 @@ class SQLAlchemyFilterEvaluator(Evaluator): - def __init__(self, field_mapping): + def __init__(self, field_mapping, undefined_as_null): self.field_mapping = field_mapping + self.undefined_as_null = undefined_as_null @handle(ast.Not) def not_(self, node, sub): @@ -105,7 +106,7 @@ def bbox(self, node, lhs): @handle(ast.Attribute) def attribute(self, node): - return filters.attribute(node.name, self.field_mapping) + return filters.attribute(node.name, self.field_mapping, self.undefined_as_null) @handle(ast.Arithmetic, subclasses=True) def arithmetic(self, node, lhs, rhs): @@ -133,15 +134,13 @@ def envelope(self, node): return filters.parse_bbox([node.x1, node.y1, node.x2, node.y2]) -def to_filter(ast, field_mapping=None): - """Helper function to translate ECQL AST to Django Query expressions. +def to_filter(ast, field_mapping={}, undefined_as_null=None): + """Helper function to translate ECQL AST to SQLAlchemy Query expressions. :param ast: the abstract syntax tree - :param field_mapping: a dict mapping from the filter name to the Django + :param field_mapping: a dict mapping from the filter name to the SQLAlchemy field lookup. - :param mapping_choices: a dict mapping field lookups to choices. :type ast: :class:`Node` - :returns: a Django query object - :rtype: :class:`django.db.models.Q` + :returns: a SQLAlchemy query object """ - return SQLAlchemyFilterEvaluator(field_mapping).evaluate(ast) + return SQLAlchemyFilterEvaluator(field_mapping, undefined_as_null).evaluate(ast) diff --git a/pygeofilter/backends/sqlalchemy/filters.py b/pygeofilter/backends/sqlalchemy/filters.py index 5d5f32f..e59f8fe 100644 --- a/pygeofilter/backends/sqlalchemy/filters.py +++ b/pygeofilter/backends/sqlalchemy/filters.py @@ -4,7 +4,7 @@ from typing import Callable, Dict from pygeoif import shape -from sqlalchemy import and_, func, not_, or_ +from sqlalchemy import and_, func, not_, or_, null def parse_bbox(box, srid: int = None): minx, miny, maxx, maxy = box @@ -256,15 +256,21 @@ def bbox(lhs, minx, miny, maxx, maxy, crs=4326): return lhs.ST_Intersects(parse_bbox([minx, miny, maxx, maxy], crs)) -def attribute(name, field_mapping=None): +def attribute(name, field_mapping={}, undefined_as_null: bool = None): """Create an attribute lookup expression using a field mapping dictionary. :param name: the field filter name :param field_mapping: the dictionary to use as a lookup. + :param undefined_as_null: how to handle a name not present in field_mapping + (None (default) - leave as-is; True - treat as null; False - throw error) """ - field = field_mapping.get(name, name) - - return field + if undefined_as_null is None: + return field_mapping.get(name, name) + if undefined_as_null: + # return null object if name is not found in field_mapping + return field_mapping.get(name, null()) + # undefined_as_null is False, so raise KeyError if name not found + return field_mapping[name] def literal(value): diff --git a/tests/backends/sqlalchemy/test_evaluate.py b/tests/backends/sqlalchemy/test_evaluate.py index b12e019..12b8591 100644 --- a/tests/backends/sqlalchemy/test_evaluate.py +++ b/tests/backends/sqlalchemy/test_evaluate.py @@ -152,9 +152,9 @@ def db_session(setup_database, connection): transaction.rollback() -def evaluate(session, cql_expr, expected_ids): +def evaluate(session, cql_expr, expected_ids, filter_option=None): ast = parse(cql_expr) - filters = to_filter(ast, FIELD_MAPPING) + filters = to_filter(ast, FIELD_MAPPING, filter_option) q = session.query(Record).join(RecordMeta).filter(filters) results = [row.identifier for row in q] @@ -415,3 +415,21 @@ def test_arith_field_plus_mul_1(db_session): def test_arith_field_plus_mul_2(db_session): evaluate(db_session, "intMetaAttribute = 5 + intAttribute * 1.5", ("A",)) + + +# handling undefined/invalid attributes + + +def test_undef_comp(db_session): + # treat undefined/invalid attribute as null + evaluate(db_session, "missingAttribute > 10", (), True) + + +def test_undef_isnull(db_session): + evaluate(db_session, "missingAttribute IS NULL", ("A", "B"), True) + + +def test_undef_comp_error(db_session): + # error if undefined/invalid attribute + with pytest.raises(KeyError): + evaluate(db_session, "missingAttribute > 10", (), False)