From f93cf5a798f113207890778261690c7bc067c3b0 Mon Sep 17 00:00:00 2001 From: Adrien Berchet Date: Sun, 3 Nov 2024 12:42:45 +0100 Subject: [PATCH] Improve MariaDB support (#524) * Improve MariaDB support * Fix pre-commit config * Fix CI and minor improvements --------- Co-authored-by: Tom Cook Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .github/workflows/test_and_publish.yml | 82 +++++++++++++++++-------- .pre-commit-config.yaml | 2 - doc/admin.rst | 4 +- doc/index.rst | 2 +- geoalchemy2/admin/dialects/mysql.py | 80 ++++++++++++++++++++++-- geoalchemy2/types/__init__.py | 8 +-- geoalchemy2/types/dialects/__init__.py | 1 + geoalchemy2/types/dialects/mariadb.py | 47 ++++++++++++++ pyproject.toml | 3 + requirements.txt | 2 +- test_container/Dockerfile | 4 +- test_container/Dockerfile_mariadb | 19 ++++++ test_container/build_mariadb.sh | 7 +++ test_container/helpers/init_mariadb.sh | 22 +++++++ test_container/helpers/init_mysql.sh | 3 +- test_container/run_mariadb.sh | 12 ++++ tests/conftest.py | 84 ++++++++++++++++++-------- tests/gallery/test_orm_mapped_v2.py | 2 +- tests/gallery/test_summarystatsagg.py | 2 +- tests/schema_fixtures.py | 2 +- tests/test_alembic_migrations.py | 2 +- tests/test_functional.py | 38 ++++++------ tests/test_functional_mysql.py | 27 ++++++--- tests/test_functional_postgresql.py | 2 +- tests/test_pickle.py | 2 +- tox.ini | 4 ++ 26 files changed, 362 insertions(+), 101 deletions(-) create mode 100644 geoalchemy2/types/dialects/mariadb.py create mode 100644 test_container/Dockerfile_mariadb create mode 100644 test_container/build_mariadb.sh create mode 100755 test_container/helpers/init_mariadb.sh create mode 100755 test_container/run_mariadb.sh diff --git a/.github/workflows/test_and_publish.yml b/.github/workflows/test_and_publish.yml index 10b1be79..c2b40ec4 100644 --- a/.github/workflows/test_and_publish.yml +++ b/.github/workflows/test_and_publish.yml @@ -54,28 +54,53 @@ jobs: --health-interval 10s --health-timeout 5s --health-retries 5 + mysql: + image: mysql:latest + ports: + - 3307:3306 + env: + MYSQL_USER: gis + MYSQL_PASSWORD: gis + MYSQL_DATABASE: gis + MYSQL_ROOT_PASSWORD: gis + # Set health checks to wait until MySQL has started + options: >- + --health-cmd="mysqladmin ping" + --health-interval=10s + --health-timeout=5s + --health-retries=3 + mariadb: + image: mariadb:latest + ports: + - 3308:3306 + env: + MARIADB_USER: gis + MARIADB_PASSWORD: gis + MARIADB_DATABASE: gis + MARIADB_ROOT_PASSWORD: gis + # Set health checks to wait until MariaDB has started + options: >- + --health-cmd="healthcheck.sh --connect --innodb_initialized" + --health-interval=10s + --health-timeout=5s + --health-retries=3 + steps: # Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it - uses: actions/checkout@v4 - # Setup Conda for Python and Pypy - - name: Install Conda environment with Micromamba - uses: mamba-org/setup-micromamba@v1 + # Setup Python + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 with: - environment-name: test_${{ matrix.python-version.flag }} - cache-environment: true - create-args: >- - ${{ matrix.python-version.pkg_name }} - libgdal - libspatialite==5.0.1 - pyproj - condarc: | - channels: - - conda-forge - - defaults - channel_priority: strict + python-version: ${{ matrix.python-version.flag }} + + # Install MariaDB + - name: Install MariaDB and SpatiaLite + run: | + sudo apt-get install -y mariadb-server mariadb-client libsqlite3-mod-spatialite libgdal-dev gdal-bin rasterio # Config PostgreSQL - name: Configure PostgreSQL @@ -95,13 +120,17 @@ jobs: # Drop PostGIS Topology extension to "gis" database psql -h localhost -p 5432 -U gis -d gis -c 'DROP EXTENSION IF EXISTS postgis_topology;' - # Setup MySQL - - name: Set up MySQL + # Check MySQL + - name: Check MySQL + run: | + mysql --user=gis --password=gis --host=127.0.0.1 -P 3307 -e "SELECT VERSION();" + mysql --user=root --password=gis --host=127.0.0.1 -P 3307 -e "GRANT ALL PRIVILEGES ON *.* TO 'gis'@'%' WITH GRANT OPTION;" + + # Check MariaDB + - name: Check MariaDB run: | - sudo systemctl start mysql - sudo mysql --user=root --password=root --host=127.0.0.1 -e "CREATE USER 'gis'@'%' IDENTIFIED BY 'gis';" - sudo mysql --user=root --password=root --host=127.0.0.1 -e "GRANT ALL PRIVILEGES ON *.* TO 'gis'@'%' WITH GRANT OPTION;" - mysql --user=gis --password=gis -e "CREATE DATABASE gis;" + mysql --user=gis --password=gis --host=127.0.0.1 -P 3308 -e "SELECT VERSION();" + mysql --user=root --password=gis --host=127.0.0.1 -P 3308 -e "GRANT ALL PRIVILEGES ON *.* TO 'gis'@'%' WITH GRANT OPTION;" # Check python version - name: Display Python version @@ -114,16 +143,19 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip setuptools - pip install tox-gh-actions + python -m pip install tox-gh-actions # Run the test suite - name: Run the tests env: - SPATIALITE_LIBRARY_PATH: /home/runner/micromamba/envs/test_${{ matrix.python-version.flag }}/lib/mod_spatialite.so - PROJ_LIB: /home/runner/micromamba/envs/test_${{ matrix.python-version.flag }}/share/proj + SPATIALITE_LIBRARY_PATH: /usr/lib/x86_64-linux-gnu/mod_spatialite.so COVERAGE_FILE: .coverage - PYTEST_MYSQL_DB_URL: mysql://gis:gis@127.0.0.1/gis + PYTEST_MYSQL_DB_URL: mysql://gis:gis@127.0.0.1:3307/gis + PYTEST_MARIADB_DB_URL: mariadb://gis:gis@127.0.0.1:3308/gis run: | + if [[ ${{ matrix.python-version.flag }} == 'pypy3.8' ]]; then + export PYTEST_ADDOPTS='--ignore=tests/gallery/test_insert_raster.py' + fi; # Run the unit test suite with SQLAlchemy=1.4.* and then with the latest version of SQLAlchemy tox -vv diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6e2ed3c8..ab22c9fb 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,5 +1,3 @@ -default_language_version: - python: python3.8 repos: - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.6.0 diff --git a/doc/admin.rst b/doc/admin.rst index 76575a3d..4c0a6fc1 100644 --- a/doc/admin.rst +++ b/doc/admin.rst @@ -25,8 +25,8 @@ PostgreSQL-specific objects :private-members: :show-inheritance: -MySQL-specific objects ---------------------------- +MySQL/MariadDB-specific objects +------------------------------- .. automodule:: geoalchemy2.admin.dialects.mysql :members: diff --git a/doc/index.rst b/doc/index.rst index 4f090d74..4b2711d3 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -17,7 +17,7 @@ GeoAlchemy 2 also supports the following dialects: * `GeoPackage `_ Note that using GeoAlchemy 2 with these dialects may require some specific configuration on the -application side. +application side. It also may not be optimal for performance. GeoAlchemy 2 aims to be simpler than its predecessor, `GeoAlchemy `_. Simpler to use, and simpler diff --git a/geoalchemy2/admin/dialects/mysql.py b/geoalchemy2/admin/dialects/mysql.py index 92aad792..5b7ef6c5 100644 --- a/geoalchemy2/admin/dialects/mysql.py +++ b/geoalchemy2/admin/dialects/mysql.py @@ -8,6 +8,9 @@ from geoalchemy2.admin.dialects.common import _check_spatial_type from geoalchemy2.admin.dialects.common import _spatial_idx_name from geoalchemy2.admin.dialects.common import setup_create_drop +from geoalchemy2.elements import WKBElement +from geoalchemy2.elements import WKTElement +from geoalchemy2.shape import to_shape from geoalchemy2.types import Geography from geoalchemy2.types import Geometry @@ -31,11 +34,16 @@ def reflect_geometry_column(inspector, table, column_info): column_name = column_info.get("name") schema = table.schema or inspector.default_schema_name + if inspector.dialect.name == "mariadb": + select_srid = "-1, " + else: + select_srid = "SRS_ID, " + # Check geometry type, SRID and if the column is nullable - geometry_type_query = """SELECT DATA_TYPE, SRS_ID, IS_NULLABLE + geometry_type_query = """SELECT DATA_TYPE, {}IS_NULLABLE FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_NAME = '{}' and COLUMN_NAME = '{}'""".format( - table.name, column_name + select_srid, table.name, column_name ) if schema is not None: geometry_type_query += """ and table_schema = '{}'""".format(schema) @@ -176,25 +184,85 @@ def _compile_GeomFromWKB_MySql(element, compiler, **kw): return "{}({})".format(element.identifier, compiled) +def _compile_GeomFromText_MariaDB(element, compiler, **kw): + element.identifier = "ST_GeomFromText" + compiled = compiler.process(element.clauses, **kw) + try: + clauses = list(element.clauses) + data_element = WKTElement(clauses[0].value) + srid = max(0, data_element.srid) + if srid <= 0: + srid = max(0, element.type.srid) + if len(clauses) > 1 and srid > 0: + clauses[1].value = srid + except Exception: + srid = max(0, element.type.srid) + + if srid > 0: + res = "{}({}, {})".format(element.identifier, compiled, srid) + else: + res = "{}({})".format(element.identifier, compiled) + return res + + +def _compile_GeomFromWKB_MariaDB(element, compiler, **kw): + element.identifier = "ST_GeomFromText" + + try: + clauses = list(element.clauses) + data_element = WKBElement(clauses[0].value) + srid = max(0, data_element.srid) + if srid <= 0: + srid = max(0, element.type.srid) + clauses[0].value = to_shape(data_element).wkt.encode("utf-8") + if len(clauses) > 1 and srid > 0: + clauses[1].value = srid + except Exception: + srid = max(0, element.type.srid) + compiled = compiler.process(element.clauses, **kw) + + if srid > 0: + res = "{}({}, {})".format(element.identifier, compiled, srid) + else: + res = "{}({})".format(element.identifier, compiled) + return res + + @compiles(functions.ST_GeomFromText, "mysql") # type: ignore -@compiles(functions.ST_GeomFromText, "mariadb") # type: ignore def _MySQL_ST_GeomFromText(element, compiler, **kw): return _compile_GeomFromText_MySql(element, compiler, **kw) @compiles(functions.ST_GeomFromEWKT, "mysql") # type: ignore -@compiles(functions.ST_GeomFromEWKT, "mariadb") # type: ignore def _MySQL_ST_GeomFromEWKT(element, compiler, **kw): return _compile_GeomFromText_MySql(element, compiler, **kw) +@compiles(functions.ST_GeomFromText, "mariadb") # type: ignore +def _MariaDB_ST_GeomFromText(element, compiler, **kw): + return _compile_GeomFromText_MariaDB(element, compiler, **kw) + + +@compiles(functions.ST_GeomFromEWKT, "mariadb") # type: ignore +def _MariaDB_ST_GeomFromEWKT(element, compiler, **kw): + return _compile_GeomFromText_MariaDB(element, compiler, **kw) + + @compiles(functions.ST_GeomFromWKB, "mysql") # type: ignore -@compiles(functions.ST_GeomFromWKB, "mariadb") # type: ignore def _MySQL_ST_GeomFromWKB(element, compiler, **kw): return _compile_GeomFromWKB_MySql(element, compiler, **kw) @compiles(functions.ST_GeomFromEWKB, "mysql") # type: ignore -@compiles(functions.ST_GeomFromEWKB, "mariadb") # type: ignore def _MySQL_ST_GeomFromEWKB(element, compiler, **kw): return _compile_GeomFromWKB_MySql(element, compiler, **kw) + + +@compiles(functions.ST_GeomFromWKB, "mariadb") # type: ignore +def _MariaDB_ST_GeomFromWKB(element, compiler, **kw): + return _compile_GeomFromWKB_MariaDB(element, compiler, **kw) + + +@compiles(functions.ST_GeomFromEWKB, "mariadb") # type: ignore +def _MariaDB_ST_GeomFromEWKB(element, compiler, **kw): + return _compile_GeomFromWKB_MariaDB(element, compiler, **kw) diff --git a/geoalchemy2/types/__init__.py b/geoalchemy2/types/__init__.py index 35eb4fe2..76966986 100644 --- a/geoalchemy2/types/__init__.py +++ b/geoalchemy2/types/__init__.py @@ -40,7 +40,7 @@ def select_dialect(dialect_name): known_dialects = { "geopackage": dialects.geopackage, "mysql": dialects.mysql, - "mariadb": dialects.mysql, + "mariadb": dialects.mariadb, "postgresql": dialects.postgresql, "sqlite": dialects.sqlite, } @@ -198,9 +198,9 @@ def check_ctor_args(geometry_type, srid, dimension, use_typmod, nullable): return geometry_type, srid -@compiles(_GISType, "mariadb") @compiles(_GISType, "mysql") -def get_col_spec(self, *args, **kwargs): +@compiles(_GISType, "mariadb") +def get_col_spec_mysql(self, compiler, *args, **kwargs): if self.geometry_type is not None: spec = "%s" % self.geometry_type else: @@ -208,7 +208,7 @@ def get_col_spec(self, *args, **kwargs): if not self.nullable or self.spatial_index: spec += " NOT NULL" - if self.srid > 0: + if self.srid > 0 and compiler.dialect.name != "mariadb": spec += " SRID %d" % self.srid return spec diff --git a/geoalchemy2/types/dialects/__init__.py b/geoalchemy2/types/dialects/__init__.py index c343b2ab..cc135210 100644 --- a/geoalchemy2/types/dialects/__init__.py +++ b/geoalchemy2/types/dialects/__init__.py @@ -2,6 +2,7 @@ from geoalchemy2.types.dialects import common # noqa from geoalchemy2.types.dialects import geopackage # noqa +from geoalchemy2.types.dialects import mariadb # noqa from geoalchemy2.types.dialects import mysql # noqa from geoalchemy2.types.dialects import postgresql # noqa from geoalchemy2.types.dialects import sqlite # noqa diff --git a/geoalchemy2/types/dialects/mariadb.py b/geoalchemy2/types/dialects/mariadb.py new file mode 100644 index 00000000..32363efe --- /dev/null +++ b/geoalchemy2/types/dialects/mariadb.py @@ -0,0 +1,47 @@ +"""This module defines specific functions for MySQL dialect.""" + +from geoalchemy2.elements import WKBElement +from geoalchemy2.elements import WKTElement +from geoalchemy2.elements import _SpatialElement +from geoalchemy2.exc import ArgumentError +from geoalchemy2.shape import to_shape + + +def bind_processor_process(spatial_type, bindvalue): + if isinstance(bindvalue, str): + wkt_match = WKTElement._REMOVE_SRID.match(bindvalue) + srid = wkt_match.group(2) + try: + if srid is not None: + srid = int(srid) + except (ValueError, TypeError): # pragma: no cover + raise ArgumentError( + f"The SRID ({srid}) of the supplied value can not be casted to integer" + ) + + if srid is not None and srid != spatial_type.srid: + raise ArgumentError( + f"The SRID ({srid}) of the supplied value is different " + f"from the one of the column ({spatial_type.srid})" + ) + return wkt_match.group(3) + + if ( + isinstance(bindvalue, _SpatialElement) + and bindvalue.srid != -1 + and bindvalue.srid != spatial_type.srid + ): + raise ArgumentError( + f"The SRID ({bindvalue.srid}) of the supplied value is different " + f"from the one of the column ({spatial_type.srid})" + ) + + if isinstance(bindvalue, WKTElement): + bindvalue = bindvalue.as_wkt() + if bindvalue.srid <= 0: + bindvalue.srid = spatial_type.srid + return bindvalue + elif isinstance(bindvalue, WKBElement): + # With MariaDB we use Shapely to convert the WKBElement to an EWKT string + return to_shape(bindvalue).wkt + return bindvalue diff --git a/pyproject.toml b/pyproject.toml index 8655959a..7989b5bd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,6 +5,9 @@ requires = [ "wheel", "setuptools_scm[toml]>=3.4", ] +build-backend = "setuptools.build_meta" + +[tool.setuptools_scm] # BLACK [tool.black] diff --git a/requirements.txt b/requirements.txt index 3895f610..ade78067 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,4 +6,4 @@ pytest pytest-cov pytest-html pytest-mypy -rasterio +rasterio;implementation_name!='pypy' diff --git a/test_container/Dockerfile b/test_container/Dockerfile index 329235fe..4563a04b 100644 --- a/test_container/Dockerfile +++ b/test_container/Dockerfile @@ -4,8 +4,8 @@ COPY ./helpers/install_requirements.sh / RUN /install_requirements.sh COPY ./helpers/init_postgres.sh / -env PGDATA="/var/lib/postgresql/data" -env POSTGRES_PATH="/usr/lib/postgresql/16" +ENV PGDATA="/var/lib/postgresql/data" +ENV POSTGRES_PATH="/usr/lib/postgresql/16" RUN su postgres -c /init_postgres.sh ENV SPATIALITE_LIBRARY_PATH="/usr/lib/x86_64-linux-gnu/mod_spatialite.so" diff --git a/test_container/Dockerfile_mariadb b/test_container/Dockerfile_mariadb new file mode 100644 index 00000000..a248706f --- /dev/null +++ b/test_container/Dockerfile_mariadb @@ -0,0 +1,19 @@ +FROM ubuntu:22.04 + +COPY ./helpers/install_requirements.sh / +RUN /install_requirements.sh + +RUN apt-get update; apt-get install -y mariadb-server mariadb-client; rm -rf /var/lib/apt/lists/*; + +COPY ./helpers/init_postgres.sh / +ENV PGDATA="/var/lib/postgresql/data" +ENV POSTGRES_PATH="/usr/lib/postgresql/16" +RUN su postgres -c /init_postgres.sh + +ENV SPATIALITE_LIBRARY_PATH="/usr/lib/x86_64-linux-gnu/mod_spatialite.so" + +COPY ./helpers/init_mariadb.sh / +RUN /init_mariadb.sh + +COPY ./helpers/entrypoint.sh / +ENTRYPOINT ["/entrypoint.sh"] diff --git a/test_container/build_mariadb.sh b/test_container/build_mariadb.sh new file mode 100644 index 00000000..9f0fc616 --- /dev/null +++ b/test_container/build_mariadb.sh @@ -0,0 +1,7 @@ +#!/usr/bin/env bash +set -e + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) + +cd "${SCRIPT_DIR}" +docker build -f Dockerfile_mariadb -t geoalchemy2-mariadb . diff --git a/test_container/helpers/init_mariadb.sh b/test_container/helpers/init_mariadb.sh new file mode 100755 index 00000000..e723d931 --- /dev/null +++ b/test_container/helpers/init_mariadb.sh @@ -0,0 +1,22 @@ +#!/usr/bin/env bash +set -e + +if [ $(whoami) != "root" ]; then + echo "must run as the root user" + exit 1 +fi + +echo "Starting mysql server" +/etc/init.d/mariadb start + +echo "Waiting for mysql to start" +while ! mysqladmin ping -h 127.0.0.1 --silent; do + sleep 0.2 +done + +echo "Create the 'gis' role" +mariadb -e "CREATE USER 'gis'@'%' IDENTIFIED BY 'gis';" +mariadb -e "GRANT ALL PRIVILEGES ON *.* TO 'gis'@'%' WITH GRANT OPTION;" + +echo "Create the 'gis' database" +mariadb -u gis --password=gis -e "CREATE DATABASE gis;" diff --git a/test_container/helpers/init_mysql.sh b/test_container/helpers/init_mysql.sh index a6b7c86f..452cb081 100755 --- a/test_container/helpers/init_mysql.sh +++ b/test_container/helpers/init_mysql.sh @@ -6,9 +6,10 @@ if [ $(whoami) != "root" ]; then exit 1 fi +echo "Starting mysql server" /etc/init.d/mysql start -echo "waiting for mysql to start" +echo "Waiting for mysql to start" while ! mysqladmin ping -h 127.0.0.1 --silent; do sleep 0.2 done diff --git a/test_container/run_mariadb.sh b/test_container/run_mariadb.sh new file mode 100755 index 00000000..6d08331d --- /dev/null +++ b/test_container/run_mariadb.sh @@ -0,0 +1,12 @@ +#!/usr/bin/env bash + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) + +ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)" + +mkdir -p "${SCRIPT_DIR}/output" + +docker run --rm -it \ + --mount type=bind,source="${ROOT}",target=/geoalchemy2_read_only,ro \ + --mount type=bind,source="${SCRIPT_DIR}/output",target=/output \ + geoalchemy2-mariadb diff --git a/tests/conftest.py b/tests/conftest.py index e4a4813e..5d31e415 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,8 +4,10 @@ import pytest from sqlalchemy import MetaData from sqlalchemy import create_engine +from sqlalchemy import text from sqlalchemy.dialects.mysql.base import MySQLDialect from sqlalchemy.dialects.sqlite.base import SQLiteDialect +from sqlalchemy.exc import InvalidRequestError from sqlalchemy.orm import declarative_base from sqlalchemy.orm import sessionmaker @@ -42,7 +44,12 @@ def pytest_addoption(parser): parser.addoption( "--mysql_dburl", action="store", - help="MySQL DB URL used for tests with MySQL (`mysql://user:password@host:port/dbname`).", + help="MySQL DB URL used for tests (`mysql://user:password@host:port/dbname`).", + ) + parser.addoption( + "--mariadb_dburl", + action="store", + help="MariaDB DB URL used for tests (`mariadb://user:password@host:port/dbname`).", ) parser.addoption( "--engine-echo", @@ -62,7 +69,7 @@ def pytest_generate_tests(metafunc): elif metafunc.module.__name__ == "tests.test_functional_sqlite": dialects = sqlite_dialects elif metafunc.module.__name__ == "tests.test_functional_mysql": - dialects = ["mysql"] + dialects = ["mysql", "mariadb"] elif metafunc.module.__name__ == "tests.test_functional_geopackage": dialects = ["geopackage"] @@ -72,7 +79,7 @@ def pytest_generate_tests(metafunc): dialects = metafunc.cls.tested_dialects if dialects is None: - dialects = ["mysql", "postgresql"] + sqlite_dialects + dialects = ["mysql", "mariadb", "postgresql"] + sqlite_dialects if "sqlite" in dialects: # Order dialects @@ -99,6 +106,15 @@ def db_url_mysql(request, tmpdir_factory): ) +@pytest.fixture(scope="session") +def db_url_mariadb(request, tmpdir_factory): + return ( + request.config.getoption("--mariadb_dburl") + or os.getenv("PYTEST_MARIADB_DB_URL") + or "mariadb://gis:gis@localhost/gis" + ) + + @pytest.fixture(scope="session") def db_url_sqlite_spatialite3(request, tmpdir_factory): return ( @@ -134,16 +150,19 @@ def db_url( db_url_sqlite_spatialite4, db_url_geopackage, db_url_mysql, + db_url_mariadb, ): if request.param == "postgresql": return db_url_postgresql if request.param == "mysql": return db_url_mysql - elif request.param == "sqlite-spatialite3": + if request.param == "mariadb": + return db_url_mariadb + if request.param == "sqlite-spatialite3": return db_url_sqlite_spatialite3 - elif request.param == "sqlite-spatialite4": + if request.param == "sqlite-spatialite4": return db_url_sqlite_spatialite4 - elif request.param == "geopackage": + if request.param == "geopackage": return db_url_geopackage return None @@ -157,24 +176,41 @@ def _engine_echo(request): @pytest.fixture def engine(tmpdir, db_url, _engine_echo): """Provide an engine to test database.""" - if db_url.startswith("sqlite:///"): - # Copy the input SQLite DB to a temporary file and return an engine to it - input_url = str(db_url)[10:] - output_file = "test_spatial_db.sqlite" - current_engine = copy_and_connect_sqlite_db( - input_url, tmpdir / output_file, _engine_echo, "sqlite" - ) - elif db_url.startswith("gpkg:///"): - # Copy the input SQLite DB to a temporary file and return an engine to it - input_url = str(db_url)[8:] - output_file = "test_spatial_db.gpkg" - current_engine = copy_and_connect_sqlite_db( - input_url, tmpdir / output_file, _engine_echo, "gpkg" - ) - else: - # For other dialects the engine is directly returned - current_engine = create_engine(db_url, echo=_engine_echo) - current_engine.update_execution_options(search_path=["gis", "public"]) + try: + if db_url.startswith("sqlite:///"): + # Copy the input SQLite DB to a temporary file and return an engine to it + input_url = str(db_url)[10:] + output_file = "test_spatial_db.sqlite" + current_engine = copy_and_connect_sqlite_db( + input_url, tmpdir / output_file, _engine_echo, "sqlite" + ) + elif db_url.startswith("gpkg:///"): + # Copy the input GeoPackage to a temporary file and return an engine to it + input_url = str(db_url)[8:] + output_file = "test_spatial_db.gpkg" + current_engine = copy_and_connect_sqlite_db( + input_url, tmpdir / output_file, _engine_echo, "gpkg" + ) + else: + # For other dialects the engine is directly returned + current_engine = create_engine(db_url, echo=_engine_echo) + current_engine.update_execution_options(search_path=["gis", "public"]) + except Exception: + pytest.skip(reason=f"Could not create engine for this URL: {db_url}") + + # Disambiguate MySQL and MariaDB + if current_engine.dialect.name in ["mysql", "mariadb"]: + try: + with current_engine.begin() as connection: + mysql_type = ( + "MariaDB" + if "mariadb" in connection.execute(text("SELECT VERSION();")).scalar().lower() + else "MySQL" + ) + if current_engine.dialect.name != mysql_type.lower(): + pytest.skip(reason=f"Can not execute {mysql_type} queries on {db_url}") + except InvalidRequestError: + pytest.skip(reason=f"Can not execute MariaDB queries on {db_url}") yield current_engine current_engine.dispose() diff --git a/tests/gallery/test_orm_mapped_v2.py b/tests/gallery/test_orm_mapped_v2.py index c1bd492e..870744d2 100644 --- a/tests/gallery/test_orm_mapped_v2.py +++ b/tests/gallery/test_orm_mapped_v2.py @@ -8,7 +8,7 @@ """ import pytest -from pkg_resources import parse_version +from packaging.version import parse as parse_version from sqlalchemy import __version__ as SA_VERSION try: diff --git a/tests/gallery/test_summarystatsagg.py b/tests/gallery/test_summarystatsagg.py index a229708a..30c2d024 100644 --- a/tests/gallery/test_summarystatsagg.py +++ b/tests/gallery/test_summarystatsagg.py @@ -7,7 +7,7 @@ """ import pytest -from pkg_resources import parse_version +from packaging.version import parse as parse_version from sqlalchemy import Column from sqlalchemy import Float from sqlalchemy import Integer diff --git a/tests/schema_fixtures.py b/tests/schema_fixtures.py index bcd3d549..1e00bf8c 100644 --- a/tests/schema_fixtures.py +++ b/tests/schema_fixtures.py @@ -189,7 +189,7 @@ class Lake(base): geom_no_idx = Column( Geometry(geometry_type="LINESTRING", srid=4326, spatial_index=False) ) - if dialect_name != "mysql": + if dialect_name not in ["mysql", "mariadb"]: geom_z = Column(Geometry(geometry_type="LINESTRINGZ", srid=4326, dimension=3)) geom_m = Column(Geometry(geometry_type="LINESTRINGM", srid=4326, dimension=3)) geom_zm = Column(Geometry(geometry_type="LINESTRINGZM", srid=4326, dimension=4)) diff --git a/tests/test_alembic_migrations.py b/tests/test_alembic_migrations.py index 025ed69f..0d5fa09e 100644 --- a/tests/test_alembic_migrations.py +++ b/tests/test_alembic_migrations.py @@ -39,7 +39,7 @@ def test_no_diff(self, conn, Lake, setup_tables, use_alembic_monkeypatch, dialec Geometry( geometry_type="LINESTRING", srid=4326, - nullable=dialect_name != "mysql", + nullable=dialect_name not in ["mysql", "mariadb"], ), ), schema=Lake.__table__.schema, diff --git a/tests/test_functional.py b/tests/test_functional.py index 4702cbdb..19b23b8d 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -11,7 +11,7 @@ compat.register() del compat -from pkg_resources import parse_version +from packaging.version import parse as parse_version from shapely.geometry import LineString from shapely.geometry import Point from sqlalchemy import CheckConstraint @@ -249,6 +249,7 @@ def test_insert(self, conn, Lake, setup_tables): row = rows[0] assert isinstance(row[1], WKBElement) + wkt = conn.execute(from_shape(LineString([[0, 0], [3, 3]]), srid=4326).ST_AsText()).scalar() wkt = conn.execute(row[1].ST_AsText()).scalar() assert format_wkt(wkt) == "LINESTRING(0 0,1 1)" srid = conn.execute(row[1].ST_SRID()).scalar() @@ -355,7 +356,7 @@ def test_insert_all_geom_types( else: has_m = False - if ndims > 2 and dialect_name == "mysql": + if ndims > 2 and dialect_name in ["mysql", "mariadb"]: # Explicitly skip MySQL dialect to show that it can only work with 2D geometries pytest.xfail(reason="MySQL only supports 2D geometry types") @@ -372,12 +373,6 @@ class GeomTypeTable(base): inserted_wkt = f"{geom_type}{wkt}" - # Use the DB to generate the corresponding raw WKB - raw_wkb = conn.execute( - text("SELECT ST_AsBinary(ST_GeomFromText('{}', 4326))".format(inserted_wkt)) - ).scalar() - - wkb_elem = WKBElement(raw_wkb, srid=4326) inserted_elements = [ {"geom": inserted_wkt}, {"geom": f"SRID=4326;{inserted_wkt}"}, @@ -385,6 +380,13 @@ class GeomTypeTable(base): {"geom": WKTElement(f"SRID=4326;{inserted_wkt}")}, ] if dialect_name not in ["postgresql", "sqlite"] or not has_m: + # Use the DB to generate the corresponding raw WKB + raw_wkb = conn.execute( + text("SELECT ST_AsBinary(ST_GeomFromText('{}', 4326))".format(inserted_wkt)) + ).scalar() + + wkb_elem = WKBElement(raw_wkb, srid=4326) + # Currently Shapely does not support geometry types with M dimension inserted_elements.append({"geom": wkb_elem}) inserted_elements.append({"geom": wkb_elem.as_ewkb()}) @@ -466,10 +468,8 @@ def test_insert_negative_coords(self, conn, Poi, setup_tables, dialect_name): assert format_wkt(wkt) == "POINT(-1 1)" srid = conn.execute(row[1].ST_SRID()).scalar() assert srid == 4326 - if dialect_name == "mysql": - assert row[1] == from_shape(Point(-1, 1), srid=4326) - else: - assert row[1] == from_shape(Point(-1, 1), srid=4326, extended=True) + extended = dialect_name not in ["mysql", "mariadb"] + assert row[1] == from_shape(Point(-1, 1), srid=4326, extended=extended) class TestSelectBindParam: @@ -516,7 +516,7 @@ def test_select_bindparam_WKBElement_extented(self, conn, Lake, setup_one_lake, rows = results.fetchall() geom = rows[0][1] assert isinstance(geom, WKBElement) - if dialect_name == "mysql": + if dialect_name in ["mysql", "mariadb"]: assert geom.extended is False else: assert geom.extended is True @@ -564,7 +564,7 @@ def test_WKTElement(self, session, Lake, setup_tables, dialect_name): session.flush() session.expire(lake) assert isinstance(lake.geom, WKBElement) - if dialect_name == "mysql": + if dialect_name in ["mysql", "mariadb"]: # Not extended case assert str(lake.geom) == ( "0102000000020000000000000000000000000000000000000000000" @@ -587,7 +587,7 @@ def test_WKBElement(self, session, Lake, setup_tables, dialect_name): session.flush() session.expire(lake) assert isinstance(lake.geom, WKBElement) - if dialect_name == "mysql": + if dialect_name in ["mysql", "mariadb"]: # Not extended case assert str(lake.geom) == ( "0102000000020000000000000000000000000000000000000000000" @@ -666,7 +666,7 @@ def test_WKTElement(self, session, Lake, setup_tables, dialect_name): srid = session.execute(lake.geom.ST_SRID()).scalar() assert srid == 4326 - if dialect_name != "mysql": + if dialect_name not in ["mysql", "mariadb"]: # Set geometry to None lake.geom = None @@ -708,7 +708,7 @@ def test_WKBElement(self, session, Lake, setup_tables, dialect_name): srid = session.execute(lake.geom.ST_SRID()).scalar() assert srid == 4326 - if dialect_name != "mysql": + if dialect_name not in ["mysql", "mariadb"]: # Set geometry to None lake.geom = None @@ -759,7 +759,7 @@ def test_other_type_fail(self, session, Lake, setup_tables, dialect_name): session.flush() session.refresh(lake) assert lake.geom is None - elif dialect_name == "mysql": + elif dialect_name in ["mysql", "mariadb"]: with pytest.raises(OperationalError): session.flush() else: @@ -925,7 +925,7 @@ class TestShapely: def test_to_shape(self, session, Lake, setup_tables, dialect_name): if dialect_name in ["sqlite", "geopackage"]: data_type = str - elif dialect_name == "mysql": + elif dialect_name in ["mysql", "mariadb"]: data_type = bytes else: data_type = memoryview diff --git a/tests/test_functional_mysql.py b/tests/test_functional_mysql.py index ccdc1c3e..42d8184c 100644 --- a/tests/test_functional_mysql.py +++ b/tests/test_functional_mysql.py @@ -1,7 +1,7 @@ from json import loads import pytest -from pkg_resources import parse_version +from packaging.version import parse as parse_version from shapely.geometry import LineString from sqlalchemy import MetaData from sqlalchemy import Table @@ -9,6 +9,7 @@ from sqlalchemy import bindparam from sqlalchemy import create_engine from sqlalchemy import text +from sqlalchemy.engine import URL from sqlalchemy.exc import OperationalError from sqlalchemy.exc import StatementError from sqlalchemy.sql import func @@ -42,7 +43,7 @@ def test_create_drop_tables( class TestInsertionCore: @pytest.mark.parametrize("use_executemany", [True, False]) - def test_insert(self, conn, Lake, setup_tables, use_executemany): + def test_insert_mysql(self, conn, Lake, setup_tables, use_executemany): # Issue several inserts using DBAPI's executemany() method or single inserts. This tests # the Geometry type's bind_processor and bind_expression functions. elements = [ @@ -211,6 +212,7 @@ def test_ST_GeometryType(self, session, Lake, setup_one_lake): assert isinstance(r4, Lake) assert r4.id == lake_id + @test_only_with_dialects("mysql") def test_ST_Transform(self, session, Lake, setup_one_lake): lake_id = setup_one_lake @@ -277,6 +279,7 @@ def test_ST_GeoJSON_feature(self, session, Lake, setup_tables): parse_version(SA_VERSION) < parse_version("1.3.4"), reason="Case-insensitivity is only available for sqlalchemy>=1.3.4", ) + @test_only_with_dialects("mysql") def test_comparator_case_insensitivity(self, session, Lake, setup_one_lake): lake_id = setup_one_lake @@ -327,14 +330,22 @@ def test_insert(self, conn, Lake, setup_tables): class TestReflection: @pytest.fixture - def create_temp_db(self, request, conn, reflection_tables_metadata): + def create_temp_db(self, _engine_echo, engine, conn, reflection_tables_metadata): """Temporary database, that is dropped on fixture teardown. Used to make sure reflection methods always uses the correct schema. """ temp_db_name = "geoalchemy_test_reflection" + temp_db_url = URL.create( + engine.url.drivername, + engine.url.username, + engine.url.password, + engine.url.host, + engine.url.port, + temp_db_name, + ) engine = create_engine( - f"mysql://gis:gis@localhost/{temp_db_name}", - echo=request.config.getoption("--engine-echo"), + temp_db_url, + echo=_engine_echo, ) conn.execute(text(f"CREATE DATABASE IF NOT EXISTS {temp_db_name};")) with engine.connect() as connection: @@ -349,19 +360,19 @@ def setup_reflection_tables(self, reflection_tables_metadata, conn): reflection_tables_metadata.drop_all(conn, checkfirst=True) reflection_tables_metadata.create_all(conn) - def test_reflection_mysql(self, conn, setup_reflection_tables, create_temp_db): + def test_reflection_mysql(self, conn, setup_reflection_tables, create_temp_db, dialect_name): t = Table("lake", MetaData(), autoload_with=conn) type_ = t.c.geom.type assert isinstance(type_, Geometry) assert type_.geometry_type == "LINESTRING" - assert type_.srid == 4326 + assert type_.srid == 4326 if dialect_name == "mysql" else -1 assert type_.dimension == 2 type_ = t.c.geom_no_idx.type assert isinstance(type_, Geometry) assert type_.geometry_type == "LINESTRING" - assert type_.srid == 4326 + assert type_.srid == 4326 if dialect_name == "mysql" else -1 assert type_.dimension == 2 # Drop the table diff --git a/tests/test_functional_postgresql.py b/tests/test_functional_postgresql.py index 62d584c4..00d796b4 100644 --- a/tests/test_functional_postgresql.py +++ b/tests/test_functional_postgresql.py @@ -11,7 +11,7 @@ compat.register() del compat -from pkg_resources import parse_version +from packaging.version import parse as parse_version from shapely.geometry import Point from sqlalchemy import Column from sqlalchemy import Integer diff --git a/tests/test_pickle.py b/tests/test_pickle.py index fe916970..4e67e49b 100644 --- a/tests/test_pickle.py +++ b/tests/test_pickle.py @@ -51,7 +51,7 @@ def test_pickle_unpickle(self, session, setup_one_lake, dialect_name): unpickled = pickle.loads(pickled) assert unpickled.geom.srid == 4326 assert str(unpickled.geom) == data_desc - if dialect_name == "mysql": + if dialect_name in ["mysql", "mariadb"]: assert unpickled.geom.extended is False else: assert unpickled.geom.extended is True diff --git a/tox.ini b/tox.ini index 0c355e73..95320aab 100644 --- a/tox.ini +++ b/tox.ini @@ -16,7 +16,9 @@ python = [testenv] passenv= PROJ_LIB + PYTEST_ADDOPTS PYTEST_POSTGRESQL_DB_URL + PYTEST_MARIADB_DB_URL PYTEST_MYSQL_DB_URL PYTEST_SPATIALITE3_DB_URL PYTEST_SPATIALITE4_DB_URL @@ -50,6 +52,8 @@ commands= --html reports/pytest-{envname}.html \ --junit-xml=reports/pytest-{envname}.xml \ --self-contained-html \ + --durations 10 \ + --durations-min=2.0 \ --mypy \ {posargs}