Skip to content

Commit

Permalink
fix(datatypes): proper handling of srid in geospatial datatypes (#9519)
Browse files Browse the repository at this point in the history
Co-authored-by: Phillip Cloud <417981+cpcloud@users.noreply.github.com>
  • Loading branch information
ncclementi and cpcloud authored Jul 5, 2024
1 parent 6a748c4 commit a3ceb59
Show file tree
Hide file tree
Showing 7 changed files with 82 additions and 50 deletions.
27 changes: 27 additions & 0 deletions ibis/backends/postgres/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from __future__ import annotations

import os
import string
from urllib.parse import quote_plus

import hypothesis as h
Expand All @@ -30,6 +31,7 @@
import ibis.expr.datatypes as dt
import ibis.expr.types as ir
from ibis.backends.tests.errors import PsycoPg2OperationalError
from ibis.util import gen_name

pytest.importorskip("psycopg2")

Expand Down Expand Up @@ -390,3 +392,28 @@ def test_password_with_bracket():
match=f'password authentication failed for user "{IBIS_POSTGRES_USER}"',
):
ibis.connect(url)


def test_create_geospatial_table_with_srid(con):
name = gen_name("geospatial")
column_names = string.ascii_lowercase
column_types = [
"Point",
"LineString",
"Polygon",
"MultiLineString",
"MultiPoint",
"MultiPolygon",
]
schema_string = ", ".join(
f"{column} geometry({dtype}, 4326)"
for column, dtype in zip(column_names, column_types)
)
con.raw_sql(f"CREATE TEMP TABLE {name} ({schema_string})")
schema = con.get_schema(name)
assert schema == ibis.schema(
{
column: getattr(dt, dtype)(srid=4326)
for column, dtype in zip(column_names, column_types)
}
)
53 changes: 41 additions & 12 deletions ibis/backends/sql/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,6 @@
typecode.ENUM16: dt.String,
typecode.FLOAT: dt.Float32,
typecode.FIXEDSTRING: dt.String,
typecode.GEOMETRY: partial(dt.GeoSpatial, geotype="geometry"),
typecode.GEOGRAPHY: partial(dt.GeoSpatial, geotype="geography"),
typecode.HSTORE: partial(dt.Map, dt.string, dt.string),
typecode.INET: dt.INET,
typecode.INT128: partial(dt.Decimal, 38, 0),
Expand Down Expand Up @@ -298,15 +296,27 @@ def _from_sqlglot_DECIMAL(

@classmethod
def _from_sqlglot_GEOMETRY(
cls, arg: sge.DataTypeParam | None = None
cls, arg: sge.DataTypeParam | None = None, srid: sge.DataTypeParam | None = None
) -> sge.DataType:
if arg is not None:
return _geotypes[str(arg).upper()](nullable=cls.default_nullable)
return dt.GeoSpatial(geotype="geometry", nullable=cls.default_nullable)
typeclass = _geotypes[arg.this.this]
else:
typeclass = dt.GeoSpatial
if srid is not None:
srid = int(srid.this.this)
return typeclass(geotype="geometry", nullable=cls.default_nullable, srid=srid)

@classmethod
def _from_sqlglot_GEOGRAPHY(cls) -> sge.DataType:
return dt.GeoSpatial(geotype="geography", nullable=cls.default_nullable)
def _from_sqlglot_GEOGRAPHY(
cls, arg: sge.DataTypeParam | None = None, srid: sge.DataTypeParam | None = None
) -> sge.DataType:
if arg is not None:
typeclass = _geotypes[arg.this.this]
else:
typeclass = dt.GeoSpatial
if srid is not None:
srid = int(srid.this.this)
return typeclass(geotype="geography", nullable=cls.default_nullable, srid=srid)

@classmethod
def _from_ibis_Interval(cls, dtype: dt.Interval) -> sge.DataType:
Expand Down Expand Up @@ -374,13 +384,30 @@ def _from_ibis_Timestamp(cls, dtype: dt.Timestamp) -> sge.DataType:

@classmethod
def _from_ibis_GeoSpatial(cls, dtype: dt.GeoSpatial):
if (geotype := dtype.geotype) is not None:
return sge.DataType(this=getattr(typecode, geotype.upper()))
return sge.DataType(this=typecode.GEOMETRY)
expressions = [None]

if (srid := dtype.srid) is not None:
expressions.append(sge.DataTypeParam(this=sge.convert(srid)))

this = getattr(typecode, dtype.geotype.upper())

return sge.DataType(this=this, expressions=expressions)

@classmethod
def _from_ibis_SpecificGeometry(cls, dtype: dt.GeoSpatial):
expressions = [
sge.DataTypeParam(this=sge.Var(this=dtype.__class__.__name__.upper()))
]

if (srid := dtype.srid) is not None:
expressions.append(sge.DataTypeParam(this=sge.convert(srid)))

this = getattr(typecode, dtype.geotype.upper())
return sge.DataType(this=this, expressions=expressions)

_from_ibis_Point = _from_ibis_LineString = _from_ibis_Polygon = (
_from_ibis_MultiLineString
) = _from_ibis_MultiPoint = _from_ibis_MultiPolygon = _from_ibis_GeoSpatial
) = _from_ibis_MultiPoint = _from_ibis_MultiPolygon = _from_ibis_SpecificGeometry


class PostgresType(SqlglotType):
Expand Down Expand Up @@ -780,7 +807,9 @@ def _from_sqlglot_TIMESTAMPTZ(cls) -> dt.Timestamp:
return dt.Timestamp(timezone="UTC", nullable=cls.default_nullable)

@classmethod
def _from_sqlglot_GEOGRAPHY(cls) -> dt.GeoSpatial:
def _from_sqlglot_GEOGRAPHY(
cls, arg: sge.DataTypeParam | None = None, srid: sge.DataTypeParam | None = None
) -> dt.GeoSpatial:
return dt.GeoSpatial(
geotype="geography", srid=4326, nullable=cls.default_nullable
)
Expand Down
13 changes: 1 addition & 12 deletions ibis/backends/sql/tests/test_datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,7 @@ def assert_dtype_roundtrip(ibis_type, sqlglot_expected=None):
| its.array_dtypes(roundtripable_types, nullable=true)
| its.map_dtypes(roundtripable_types, roundtripable_types, nullable=true)
| its.struct_dtypes(roundtripable_types, nullable=true)
| its.geometry_dtypes(nullable=true)
| its.geography_dtypes(nullable=true)
| its.geospatial_dtypes(nullable=true)
| its.decimal_dtypes(nullable=true)
| its.interval_dtype(nullable=true)
)
Expand All @@ -59,16 +58,6 @@ def test_roundtripable_types(ibis_type):
assert_dtype_roundtrip(ibis_type)


@h.given(its.specific_geometry_dtypes(nullable=true))
def test_specific_geometry_types(ibis_type):
sqlglot_result = SqlglotType.from_ibis(ibis_type)
assert isinstance(sqlglot_result, sge.DataType)
assert sqlglot_result == sge.DataType(this=sge.DataType.Type.GEOMETRY)
assert SqlglotType.to_ibis(sqlglot_result) == dt.GeoSpatial(
geotype="geometry", nullable=ibis_type.nullable
)


def test_interval_without_unit():
with pytest.raises(com.IbisTypeError, match="precision is None"):
SqlglotType.from_string("INTERVAL")
Expand Down
2 changes: 1 addition & 1 deletion ibis/expr/datatypes/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -935,7 +935,7 @@ class JSON(Variadic):
class GeoSpatial(DataType):
"""Geospatial values."""

geotype: Optional[Literal["geography", "geometry"]] = None
geotype: Literal["geography", "geometry"] = "geometry"
"""The specific geospatial type."""

srid: Optional[int] = None
Expand Down
4 changes: 2 additions & 2 deletions ibis/expr/types/geospatial.py
Original file line number Diff line number Diff line change
Expand Up @@ -1227,7 +1227,7 @@ def centroid(self) -> PointValue:
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ GeoCentroid(geom) ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ point
│ point:geometry
├──────────────────────────────────┤
│ <POINT (935996.821 191376.75)> │
│ <POINT (1031085.719 164018.754)> │
Expand Down Expand Up @@ -1261,7 +1261,7 @@ def envelope(self) -> ir.PolygonValue:
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ GeoEnvelope(geom) ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ polygon
│ polygon:geometry
├──────────────────────────────────────────────────────────────────────────────┤
│ <POLYGON ((931553.491 183788.05, 941810.009 183788.05, 941810.009 │
│ 197256.211...> │
Expand Down
2 changes: 1 addition & 1 deletion ibis/expr/types/numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -742,7 +742,7 @@ def point(self, right: int | float | NumericValue) -> ir.PointValue:
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ GeoPoint(x_cent, y_cent) ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ point
│ point:geometry
├──────────────────────────────────┤
│ <POINT (935996.821 191376.75)> │
│ <POINT (1031085.719 164018.754)> │
Expand Down
31 changes: 9 additions & 22 deletions ibis/tests/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,30 +178,17 @@ def struct_dtypes(
return dt.Struct(fields, nullable=draw(nullable))


def geometry_dtypes(nullable=_nullable):
return st.builds(dt.GeoSpatial, geotype=st.just("geometry"), nullable=nullable)


def geography_dtypes(nullable=_nullable):
return st.builds(dt.GeoSpatial, geotype=st.just("geography"), nullable=nullable)


def specific_geometry_dtypes(nullable=_nullable):
return st.one_of(
st.builds(dt.Point, nullable=nullable),
st.builds(dt.LineString, nullable=nullable),
st.builds(dt.Polygon, nullable=nullable),
st.builds(dt.MultiPoint, nullable=nullable),
st.builds(dt.MultiLineString, nullable=nullable),
st.builds(dt.MultiPolygon, nullable=nullable),
)


def geospatial_dtypes(nullable=_nullable):
geotype = st.one_of(st.just("geography"), st.just("geometry"))
srid = st.one_of(st.just(None), st.integers(min_value=0))
return st.one_of(
specific_geometry_dtypes(nullable=nullable),
geometry_dtypes(nullable=nullable),
geography_dtypes(nullable=nullable),
st.builds(dt.Point, geotype=geotype, nullable=nullable, srid=srid),
st.builds(dt.LineString, geotype=geotype, nullable=nullable, srid=srid),
st.builds(dt.Polygon, geotype=geotype, nullable=nullable, srid=srid),
st.builds(dt.MultiPoint, geotype=geotype, nullable=nullable, srid=srid),
st.builds(dt.MultiLineString, geotype=geotype, nullable=nullable, srid=srid),
st.builds(dt.MultiPolygon, geotype=geotype, nullable=nullable, srid=srid),
st.builds(dt.GeoSpatial, geotype=geotype, nullable=nullable, srid=srid),
)


Expand Down

0 comments on commit a3ceb59

Please sign in to comment.