diff --git a/ibis/backends/postgres/tests/test_client.py b/ibis/backends/postgres/tests/test_client.py index ee94731cecf5..bde3637cfac0 100644 --- a/ibis/backends/postgres/tests/test_client.py +++ b/ibis/backends/postgres/tests/test_client.py @@ -14,6 +14,7 @@ from __future__ import annotations import os +import string from urllib.parse import quote_plus import hypothesis as h @@ -30,6 +31,7 @@ import ibis.expr.datatypes as dt import ibis.expr.types as ir from ibis.backends.tests.errors import PsycoPg2OperationalError +from ibis.util import gen_name pytest.importorskip("psycopg2") @@ -390,3 +392,28 @@ def test_password_with_bracket(): match=f'password authentication failed for user "{IBIS_POSTGRES_USER}"', ): ibis.connect(url) + + +def test_create_geospatial_table_with_srid(con): + name = gen_name("geospatial") + column_names = string.ascii_lowercase + column_types = [ + "Point", + "LineString", + "Polygon", + "MultiLineString", + "MultiPoint", + "MultiPolygon", + ] + schema_string = ", ".join( + f"{column} geometry({dtype}, 4326)" + for column, dtype in zip(column_names, column_types) + ) + con.raw_sql(f"CREATE TEMP TABLE {name} ({schema_string})") + schema = con.get_schema(name) + assert schema == ibis.schema( + { + column: getattr(dt, dtype)(srid=4326) + for column, dtype in zip(column_names, column_types) + } + ) diff --git a/ibis/backends/sql/datatypes.py b/ibis/backends/sql/datatypes.py index 432642fa1060..b8bff67a1cec 100644 --- a/ibis/backends/sql/datatypes.py +++ b/ibis/backends/sql/datatypes.py @@ -27,8 +27,6 @@ typecode.ENUM16: dt.String, typecode.FLOAT: dt.Float32, typecode.FIXEDSTRING: dt.String, - typecode.GEOMETRY: partial(dt.GeoSpatial, geotype="geometry"), - typecode.GEOGRAPHY: partial(dt.GeoSpatial, geotype="geography"), typecode.HSTORE: partial(dt.Map, dt.string, dt.string), typecode.INET: dt.INET, typecode.INT128: partial(dt.Decimal, 38, 0), @@ -298,15 +296,27 @@ def _from_sqlglot_DECIMAL( @classmethod def _from_sqlglot_GEOMETRY( - cls, arg: sge.DataTypeParam | None = None + cls, arg: sge.DataTypeParam | None = None, srid: sge.DataTypeParam | None = None ) -> sge.DataType: if arg is not None: - return _geotypes[str(arg).upper()](nullable=cls.default_nullable) - return dt.GeoSpatial(geotype="geometry", nullable=cls.default_nullable) + typeclass = _geotypes[arg.this.this] + else: + typeclass = dt.GeoSpatial + if srid is not None: + srid = int(srid.this.this) + return typeclass(geotype="geometry", nullable=cls.default_nullable, srid=srid) @classmethod - def _from_sqlglot_GEOGRAPHY(cls) -> sge.DataType: - return dt.GeoSpatial(geotype="geography", nullable=cls.default_nullable) + def _from_sqlglot_GEOGRAPHY( + cls, arg: sge.DataTypeParam | None = None, srid: sge.DataTypeParam | None = None + ) -> sge.DataType: + if arg is not None: + typeclass = _geotypes[arg.this.this] + else: + typeclass = dt.GeoSpatial + if srid is not None: + srid = int(srid.this.this) + return typeclass(geotype="geography", nullable=cls.default_nullable, srid=srid) @classmethod def _from_ibis_Interval(cls, dtype: dt.Interval) -> sge.DataType: @@ -374,13 +384,30 @@ def _from_ibis_Timestamp(cls, dtype: dt.Timestamp) -> sge.DataType: @classmethod def _from_ibis_GeoSpatial(cls, dtype: dt.GeoSpatial): - if (geotype := dtype.geotype) is not None: - return sge.DataType(this=getattr(typecode, geotype.upper())) - return sge.DataType(this=typecode.GEOMETRY) + expressions = [None] + + if (srid := dtype.srid) is not None: + expressions.append(sge.DataTypeParam(this=sge.convert(srid))) + + this = getattr(typecode, dtype.geotype.upper()) + + return sge.DataType(this=this, expressions=expressions) + + @classmethod + def _from_ibis_SpecificGeometry(cls, dtype: dt.GeoSpatial): + expressions = [ + sge.DataTypeParam(this=sge.Var(this=dtype.__class__.__name__.upper())) + ] + + if (srid := dtype.srid) is not None: + expressions.append(sge.DataTypeParam(this=sge.convert(srid))) + + this = getattr(typecode, dtype.geotype.upper()) + return sge.DataType(this=this, expressions=expressions) _from_ibis_Point = _from_ibis_LineString = _from_ibis_Polygon = ( _from_ibis_MultiLineString - ) = _from_ibis_MultiPoint = _from_ibis_MultiPolygon = _from_ibis_GeoSpatial + ) = _from_ibis_MultiPoint = _from_ibis_MultiPolygon = _from_ibis_SpecificGeometry class PostgresType(SqlglotType): @@ -780,7 +807,9 @@ def _from_sqlglot_TIMESTAMPTZ(cls) -> dt.Timestamp: return dt.Timestamp(timezone="UTC", nullable=cls.default_nullable) @classmethod - def _from_sqlglot_GEOGRAPHY(cls) -> dt.GeoSpatial: + def _from_sqlglot_GEOGRAPHY( + cls, arg: sge.DataTypeParam | None = None, srid: sge.DataTypeParam | None = None + ) -> dt.GeoSpatial: return dt.GeoSpatial( geotype="geography", srid=4326, nullable=cls.default_nullable ) diff --git a/ibis/backends/sql/tests/test_datatypes.py b/ibis/backends/sql/tests/test_datatypes.py index 07c50dbeb4f2..b772cdddd4fd 100644 --- a/ibis/backends/sql/tests/test_datatypes.py +++ b/ibis/backends/sql/tests/test_datatypes.py @@ -41,8 +41,7 @@ def assert_dtype_roundtrip(ibis_type, sqlglot_expected=None): | its.array_dtypes(roundtripable_types, nullable=true) | its.map_dtypes(roundtripable_types, roundtripable_types, nullable=true) | its.struct_dtypes(roundtripable_types, nullable=true) - | its.geometry_dtypes(nullable=true) - | its.geography_dtypes(nullable=true) + | its.geospatial_dtypes(nullable=true) | its.decimal_dtypes(nullable=true) | its.interval_dtype(nullable=true) ) @@ -59,16 +58,6 @@ def test_roundtripable_types(ibis_type): assert_dtype_roundtrip(ibis_type) -@h.given(its.specific_geometry_dtypes(nullable=true)) -def test_specific_geometry_types(ibis_type): - sqlglot_result = SqlglotType.from_ibis(ibis_type) - assert isinstance(sqlglot_result, sge.DataType) - assert sqlglot_result == sge.DataType(this=sge.DataType.Type.GEOMETRY) - assert SqlglotType.to_ibis(sqlglot_result) == dt.GeoSpatial( - geotype="geometry", nullable=ibis_type.nullable - ) - - def test_interval_without_unit(): with pytest.raises(com.IbisTypeError, match="precision is None"): SqlglotType.from_string("INTERVAL") diff --git a/ibis/expr/datatypes/core.py b/ibis/expr/datatypes/core.py index 2a2cfa4eae78..ff0a42ba85c5 100644 --- a/ibis/expr/datatypes/core.py +++ b/ibis/expr/datatypes/core.py @@ -935,7 +935,7 @@ class JSON(Variadic): class GeoSpatial(DataType): """Geospatial values.""" - geotype: Optional[Literal["geography", "geometry"]] = None + geotype: Literal["geography", "geometry"] = "geometry" """The specific geospatial type.""" srid: Optional[int] = None diff --git a/ibis/expr/types/geospatial.py b/ibis/expr/types/geospatial.py index 5ae14ee2af7f..558dfb877a1f 100644 --- a/ibis/expr/types/geospatial.py +++ b/ibis/expr/types/geospatial.py @@ -1227,7 +1227,7 @@ def centroid(self) -> PointValue: ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ ┃ GeoCentroid(geom) ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ - │ point │ + │ point:geometry │ ├──────────────────────────────────┤ │ │ │ │ @@ -1261,7 +1261,7 @@ def envelope(self) -> ir.PolygonValue: ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ ┃ GeoEnvelope(geom) ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ - │ polygon │ + │ polygon:geometry │ ├──────────────────────────────────────────────────────────────────────────────┤ │ │ diff --git a/ibis/expr/types/numeric.py b/ibis/expr/types/numeric.py index 91609dcc2999..36da3740a944 100644 --- a/ibis/expr/types/numeric.py +++ b/ibis/expr/types/numeric.py @@ -742,7 +742,7 @@ def point(self, right: int | float | NumericValue) -> ir.PointValue: ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ ┃ GeoPoint(x_cent, y_cent) ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ - │ point │ + │ point:geometry │ ├──────────────────────────────────┤ │ │ │ │ diff --git a/ibis/tests/strategies.py b/ibis/tests/strategies.py index f417b3719a12..d31e1ca035ab 100644 --- a/ibis/tests/strategies.py +++ b/ibis/tests/strategies.py @@ -178,30 +178,17 @@ def struct_dtypes( return dt.Struct(fields, nullable=draw(nullable)) -def geometry_dtypes(nullable=_nullable): - return st.builds(dt.GeoSpatial, geotype=st.just("geometry"), nullable=nullable) - - -def geography_dtypes(nullable=_nullable): - return st.builds(dt.GeoSpatial, geotype=st.just("geography"), nullable=nullable) - - -def specific_geometry_dtypes(nullable=_nullable): - return st.one_of( - st.builds(dt.Point, nullable=nullable), - st.builds(dt.LineString, nullable=nullable), - st.builds(dt.Polygon, nullable=nullable), - st.builds(dt.MultiPoint, nullable=nullable), - st.builds(dt.MultiLineString, nullable=nullable), - st.builds(dt.MultiPolygon, nullable=nullable), - ) - - def geospatial_dtypes(nullable=_nullable): + geotype = st.one_of(st.just("geography"), st.just("geometry")) + srid = st.one_of(st.just(None), st.integers(min_value=0)) return st.one_of( - specific_geometry_dtypes(nullable=nullable), - geometry_dtypes(nullable=nullable), - geography_dtypes(nullable=nullable), + st.builds(dt.Point, geotype=geotype, nullable=nullable, srid=srid), + st.builds(dt.LineString, geotype=geotype, nullable=nullable, srid=srid), + st.builds(dt.Polygon, geotype=geotype, nullable=nullable, srid=srid), + st.builds(dt.MultiPoint, geotype=geotype, nullable=nullable, srid=srid), + st.builds(dt.MultiLineString, geotype=geotype, nullable=nullable, srid=srid), + st.builds(dt.MultiPolygon, geotype=geotype, nullable=nullable, srid=srid), + st.builds(dt.GeoSpatial, geotype=geotype, nullable=nullable, srid=srid), )