Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reorganize tests and fix nullable propagation and some other bugs #433

Merged
merged 1 commit into from
Mar 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions geoalchemy2/admin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,11 @@ def after_parent_attach(column, table):
# with a selectable as table, so we want to skip this case.
return

if not getattr(column.type, "nullable", True):
column.nullable = column.type.nullable
elif hasattr(column.type, "nullable"):
column.type.nullable = column.nullable

if not getattr(column.type, "spatial_index", False) and getattr(
column.type, "use_N_D_index", False
):
Expand Down
14 changes: 9 additions & 5 deletions geoalchemy2/admin/dialects/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,17 @@ def reflect_geometry_column(inspector, table, column_info):
)
if table.schema is not None:
geometry_type_query += """ and table_schema = '{}'""".format(table.schema)
geometry_type, srid, nullable = inspector.bind.execute(text(geometry_type_query)).one()
geometry_type, srid, nullable_str = inspector.bind.execute(text(geometry_type_query)).one()
is_nullable = str(nullable_str).lower() == "yes"

if geometry_type not in _POSSIBLE_TYPES:
return

# Check if the column has spatial index
has_index_query = """SELECT DISTINCT
INDEX_TYPE
FROM INFORMATION_SCHEMA.STATISTICS
WHERE TABLE_NAME = '{}' and COLUMN_NAME = '{}'""".format(
INDEX_TYPE
FROM INFORMATION_SCHEMA.STATISTICS
WHERE TABLE_NAME = '{}' and COLUMN_NAME = '{}'""".format(
table.name, column_name
)
if table.schema is not None:
Expand All @@ -60,7 +61,7 @@ def reflect_geometry_column(inspector, table, column_info):
geometry_type=geometry_type.upper(),
srid=srid,
spatial_index=spatial_index,
nullable=str(nullable).lower() == "yes",
nullable=is_nullable,
_spatial_index_reflected=True,
)

Expand Down Expand Up @@ -168,6 +169,9 @@ def _compile_GeomFromText_MySql(element, compiler, **kw):

def _compile_GeomFromWKB_MySql(element, compiler, **kw):
element.identifier = "ST_GeomFromWKB"
wkb_data = list(element.clauses)[0].value
if isinstance(wkb_data, memoryview):
list(element.clauses)[0].value = wkb_data.tobytes()
compiled = compiler.process(element.clauses, **kw)
srid = element.type.srid

Expand Down
1 change: 0 additions & 1 deletion geoalchemy2/admin/dialects/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ def _setup_dummy_type(table, gis_cols):
# Add dummy columns with GEOMETRY type
col._actual_type = col.type
col.type = _DummyGeometry()
col.nullable = col._actual_type.nullable
table.columns = table.info["_saved_columns"]


Expand Down
46 changes: 46 additions & 0 deletions geoalchemy2/alembic_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from alembic.operations import ops
from sqlalchemy import Column
from sqlalchemy import text
from sqlalchemy.dialects.mysql.base import MySQLDialect
from sqlalchemy.dialects.sqlite.base import SQLiteDialect
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.schema import DropTable
Expand Down Expand Up @@ -88,6 +89,51 @@ def spatial_behavior(self, connection, table_name, schema=None, **kw):
_monkey_patch_get_indexes_for_sqlite()


def _monkey_patch_get_indexes_for_mysql():
"""Monkey patch SQLAlchemy to fix spatial index reflection."""
normal_behavior = MySQLDialect.get_indexes

def spatial_behavior(self, connection, table_name, schema=None, **kw):
indexes = self._get_indexes_normal_behavior(connection, table_name, schema=None, **kw)

# Get spatial indexes
has_index_query = """SELECT DISTINCT
COLUMN_NAME
FROM INFORMATION_SCHEMA.STATISTICS
WHERE TABLE_NAME = '{}' AND INDEX_TYPE = 'SPATIAL'""".format(
table_name
)
if schema is not None:
has_index_query += """ AND TABLE_SCHEMA = '{}'""".format(schema)
spatial_indexes = connection.execute(text(has_index_query)).fetchall()

if spatial_indexes:
reflected_names = set([i["name"] for i in indexes])
for idx in spatial_indexes:
idx_col = idx[0]
idx_name = _spatial_idx_name(table_name, idx_col)
if idx_name in reflected_names:
continue
indexes.append(
{
"name": idx_name,
"column_names": [idx_col],
"unique": 0,
"dialect_options": {"_column_flag": True},
}
)
reflected_names.add(idx_name)

return indexes

spatial_behavior.__doc__ = normal_behavior.__doc__
MySQLDialect.get_indexes = spatial_behavior
MySQLDialect._get_indexes_normal_behavior = normal_behavior


_monkey_patch_get_indexes_for_mysql()


def render_item(obj_type, obj, autogen_context):
"""Add proper imports for spatial types."""
if obj_type == "type" and isinstance(obj, (Geometry, Geography, Raster)):
Expand Down
10 changes: 5 additions & 5 deletions geoalchemy2/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,12 +237,12 @@ def check_ctor_args(geometry_type, srid, dimension, management, use_typmod, null

@compiles(_GISType, "mysql")
def get_col_spec(self, *args, **kwargs):
if not self.geometry_type:
return self.name
if self.geometry_type is not None:
spec = "%s" % self.geometry_type
else:
spec = "GEOMETRY"

spec = "%s" % self.geometry_type

if not self.nullable:
if not self.nullable or self.spatial_index:
spec += " NOT NULL"
if self.srid > 0:
spec += " SRID %d" % self.srid
Expand Down
12 changes: 6 additions & 6 deletions tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,11 @@ def __call__(self, test_obj):
return test_obj


def get_postgis_version(bind):
def get_postgis_major_version(bind):
try:
return bind.execute(func.postgis_lib_version()).scalar()
return version.parse(bind.execute(func.postgis_lib_version()).scalar()).major
except OperationalError:
return "0"
return version.parse("0").major


def get_postgres_major_version(bind):
Expand All @@ -44,17 +44,17 @@ def get_postgres_major_version(bind):


def skip_postgis1(bind):
if get_postgis_version(bind).startswith("1."):
if get_postgis_major_version(bind) == 1:
pytest.skip("requires PostGIS != 1")


def skip_postgis2(bind):
if get_postgis_version(bind).startswith("2."):
if get_postgis_major_version(bind) == 2:
pytest.skip("requires PostGIS != 2")


def skip_postgis3(bind):
if get_postgis_version(bind).startswith("3."):
if get_postgis_major_version(bind) == 3:
pytest.skip("requires PostGIS != 3")


Expand Down
49 changes: 31 additions & 18 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,16 @@
import pytest
from sqlalchemy import MetaData
from sqlalchemy import create_engine
from sqlalchemy.dialects.mysql.base import MySQLDialect
from sqlalchemy.dialects.sqlite.base import SQLiteDialect
from sqlalchemy.orm import declarative_base
from sqlalchemy.orm import sessionmaker

from geoalchemy2.alembic_helpers import _monkey_patch_get_indexes_for_mysql
from geoalchemy2.alembic_helpers import _monkey_patch_get_indexes_for_sqlite

from . import copy_and_connect_sqlite_db
from . import get_postgis_version
from . import get_postgis_major_version
from . import get_postgres_major_version
from .schema_fixtures import * # noqa

Expand Down Expand Up @@ -63,7 +65,7 @@ def pytest_generate_tests(metafunc):
dialects = metafunc.cls.tested_dialects

if dialects is None:
dialects = ["postgresql", "sqlite-spatialite3", "sqlite-spatialite4"]
dialects = ["mysql", "postgresql", "sqlite-spatialite3", "sqlite-spatialite4"]

if "sqlite" in dialects:
dialects = [i for i in dialects if i != "sqlite"] + sqlite_dialects
Expand Down Expand Up @@ -144,17 +146,24 @@ def engine(tmpdir, db_url, _engine_echo):


@pytest.fixture
def session(engine):
session = sessionmaker(bind=engine)()
yield session
session.rollback()
def dialect_name(engine):
return engine.dialect.name


@pytest.fixture
def conn(session):
def conn(engine):
"""Provide a connection to test database."""
conn = session.connection()
yield conn
with engine.connect() as connection:
trans = connection.begin()
yield connection
trans.rollback()


@pytest.fixture
def session(engine, conn):
Session = sessionmaker(bind=conn)
with Session(bind=conn) as session:
yield session


@pytest.fixture
Expand All @@ -177,7 +186,7 @@ def base(metadata):

@pytest.fixture
def postgis_version(conn):
return get_postgis_version(conn)
return get_postgis_major_version(conn)


@pytest.fixture
Expand All @@ -186,25 +195,29 @@ def postgres_major_version(conn):


@pytest.fixture(autouse=True)
def reset_sqlite_monkeypatch():
def reset_alembic_monkeypatch():
"""Disable Alembic monkeypatching by default."""
try:
normal_behavior = SQLiteDialect._get_indexes_normal_behavior
SQLiteDialect.get_indexes = normal_behavior
SQLiteDialect._get_indexes_normal_behavior = normal_behavior
normal_behavior_sqlite = SQLiteDialect._get_indexes_normal_behavior
SQLiteDialect.get_indexes = normal_behavior_sqlite
SQLiteDialect._get_indexes_normal_behavior = normal_behavior_sqlite

normal_behavior_mysql = MySQLDialect._get_indexes_normal_behavior
MySQLDialect.get_indexes = normal_behavior_mysql
MySQLDialect._get_indexes_normal_behavior = normal_behavior_mysql
except AttributeError:
pass


@pytest.fixture(autouse=True)
def use_sqlite_monkeypatch():
@pytest.fixture()
def use_alembic_monkeypatch():
"""Enable Alembic monkeypatching ."""
_monkey_patch_get_indexes_for_sqlite()
_monkey_patch_get_indexes_for_mysql()


@pytest.fixture
def setup_tables(session, metadata):
conn = session.connection()
def setup_tables(conn, metadata):
metadata.drop_all(conn, checkfirst=True)
metadata.create_all(conn)
yield
2 changes: 2 additions & 0 deletions tests/gallery/test_length_at_insert.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

# Tests imports
from tests import select
from tests import test_only_with_dialects

metadata = MetaData()

Expand All @@ -32,6 +33,7 @@


class TestLengthAtInsert:
@test_only_with_dialects("postgresql", "sqlite")
def test_query(self, conn):
metadata.drop_all(conn, checkfirst=True)
metadata.create_all(conn)
Expand Down
2 changes: 2 additions & 0 deletions tests/gallery/test_specific_compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
# Tests imports
from tests import format_wkt
from tests import select
from tests import test_only_with_dialects

metadata = MetaData()
Base = declarative_base(metadata=metadata)
Expand Down Expand Up @@ -87,6 +88,7 @@ def _compile_buffer_sqlite(element, compiler, **kw):
compiles(functions.ST_Buffer, "sqlite")(_compile_buffer_sqlite)


@test_only_with_dialects("postgresql", "sqlite")
def test_specific_compilation(conn):
# Build a query with a sided buffer
query = select(
Expand Down
6 changes: 3 additions & 3 deletions tests/schema_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

@pytest.fixture
def Lake(base, postgis_version, schema):
with_management = postgis_version.startswith("1.")
with_management = postgis_version == 1

class Lake(base):
__tablename__ = "lake"
Expand Down Expand Up @@ -49,7 +49,7 @@ def __init__(self, geog):

@pytest.fixture
def Summit(base, postgis_version, schema):
with_use_typemod = postgis_version.startswith("1.")
with_use_typemod = postgis_version == 1

class Summit(base):
__tablename__ = "summit"
Expand All @@ -73,7 +73,7 @@ def __init__(self, geom):
@pytest.fixture
def Ocean(base, postgis_version):
# The raster type is only available on PostGIS 2.0 and above
if postgis_version.startswith("1."):
if postgis_version == 1:
pytest.skip("The raster type is only available on PostGIS 2.0 and above")

class Ocean(base):
Expand Down
7 changes: 4 additions & 3 deletions tests/test_alembic_migrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def filter_tables(name, type_, parent_names):


class TestAutogenerate:
def test_no_diff(self, conn, Lake, setup_tables, use_sqlite_monkeypatch):
def test_no_diff(self, conn, Lake, setup_tables, use_alembic_monkeypatch, dialect_name):
"""Check that the autogeneration detects spatial types properly."""
metadata = MetaData()

Expand All @@ -39,6 +39,7 @@ def test_no_diff(self, conn, Lake, setup_tables, use_sqlite_monkeypatch):
geometry_type="LINESTRING",
srid=4326,
management=Lake.__table__.c.geom.type.management,
nullable=dialect_name != "mysql",
),
),
schema=Lake.__table__.schema,
Expand All @@ -56,7 +57,7 @@ def test_no_diff(self, conn, Lake, setup_tables, use_sqlite_monkeypatch):

assert diff == []

def test_diff(self, conn, Lake, setup_tables, use_sqlite_monkeypatch):
def test_diff(self, conn, Lake, setup_tables, use_alembic_monkeypatch):
"""Check that the autogeneration detects spatial types properly."""
metadata = MetaData()

Expand Down Expand Up @@ -253,7 +254,7 @@ class = StreamHandler

@test_only_with_dialects("postgresql", "sqlite-spatialite4")
def test_migration_revision(
conn, metadata, alembic_config, alembic_env_path, test_script_path, use_sqlite_monkeypatch
conn, metadata, alembic_config, alembic_env_path, test_script_path, use_alembic_monkeypatch
):
initial_rev = command.revision(
alembic_config,
Expand Down
Loading