From 81ffc6230e9e1490e5b68931582cf198ef1cb6a8 Mon Sep 17 00:00:00 2001 From: Andrew Olsen Date: Wed, 3 Mar 2021 14:41:04 +1300 Subject: [PATCH] Address SQL Server review comments --- sno/base_dataset.py | 4 + sno/checkout.py | 2 +- sno/clone.py | 2 +- sno/geometry.py | 8 + sno/init.py | 2 +- sno/repo.py | 4 +- sno/sqlalchemy.py | 32 ++- sno/working_copy/base.py | 123 +++++++----- sno/working_copy/db_server.py | 103 ++++++++++ sno/working_copy/gpkg.py | 172 +++++++++-------- sno/working_copy/postgis.py | 161 ++++------------ sno/working_copy/sqlserver.py | 267 ++++++++++---------------- sno/working_copy/sqlserver_adapter.py | 39 +--- sno/working_copy/table_defs.py | 4 +- tests/conftest.py | 6 +- tests/test_working_copy_sqlserver.py | 2 +- 16 files changed, 466 insertions(+), 465 deletions(-) create mode 100644 sno/working_copy/db_server.py diff --git a/sno/base_dataset.py b/sno/base_dataset.py index 56325600f..0a25f0592 100644 --- a/sno/base_dataset.py +++ b/sno/base_dataset.py @@ -9,6 +9,10 @@ class BaseDataset(ImportSource): """ Common interface for all datasets - mainly Dataset2, but there is also Dataset0 and Dataset1 used by `sno upgrade`. + + A Dataset instance is immutable since it is a view of a particular git tree. + To get a new version of a dataset, commit the desired changes, + then instantiate a new Dataset instance that references the new git tree. """ # Constants that subclasses should generally define. diff --git a/sno/checkout.py b/sno/checkout.py index cf3f80ace..587a49011 100644 --- a/sno/checkout.py +++ b/sno/checkout.py @@ -378,7 +378,7 @@ def create_workingcopy(ctx, discard_changes, wc_path): wc_path = WorkingCopy.default_path(repo.workdir_path) if wc_path != old_wc_path: - WorkingCopy.check_valid_creation_path(repo.workdir_path, wc_path) + WorkingCopy.check_valid_creation_path(wc_path, repo.workdir_path) # Finished sanity checks - start work: if old_wc and wc_path != old_wc_path: diff --git a/sno/clone.py b/sno/clone.py index c598746ad..63574566a 100644 --- a/sno/clone.py +++ b/sno/clone.py @@ -106,7 +106,7 @@ def clone( if repo_path.exists() and any(repo_path.iterdir()): raise InvalidOperation(f'"{repo_path}" isn\'t empty', param_hint="directory") - WorkingCopy.check_valid_creation_path(repo_path, wc_path) + WorkingCopy.check_valid_creation_path(wc_path, repo_path) if not repo_path.exists(): repo_path.mkdir(parents=True) diff --git a/sno/geometry.py b/sno/geometry.py index ca42f0d75..8873c8ddd 100644 --- a/sno/geometry.py +++ b/sno/geometry.py @@ -59,6 +59,11 @@ def with_crs_id(self, crs_id): @property def crs_id(self): + """ + Returns the CRS ID as it is embedded in the GPKG header - before the WKB. + Note that datasets V2 zeroes this field before committing, + so will return zero when called on Geometry where it has been zeroed. + """ wkb_offset, is_le, crs_id = parse_gpkg_geom(self) return crs_id @@ -296,6 +301,7 @@ def gpkg_geom_to_ogr(gpkg_geom, parse_crs=False): def wkt_to_gpkg_geom(wkt, **kwargs): + """Given a well-known-text string, returns a GPKG Geometry object.""" if wkt is None: return None @@ -304,6 +310,7 @@ def wkt_to_gpkg_geom(wkt, **kwargs): def wkb_to_gpkg_geom(wkb, **kwargs): + """Given a well-known-binary bytestring, returns a GPKG Geometry object.""" if wkb is None: return None @@ -312,6 +319,7 @@ def wkb_to_gpkg_geom(wkb, **kwargs): def hex_wkb_to_gpkg_geom(hex_wkb, **kwargs): + """Given a hex-encoded well-known-binary bytestring, returns a GPKG Geometry object.""" if hex_wkb is None: return None diff --git a/sno/init.py b/sno/init.py index 4f7dcc5de..a9d88b4bc 100644 --- a/sno/init.py +++ b/sno/init.py @@ -369,7 +369,7 @@ def init( if repo_path.exists() and any(repo_path.iterdir()): raise InvalidOperation(f'"{repo_path}" isn\'t empty', param_hint="directory") - WorkingCopy.check_valid_creation_path(repo_path, wc_path) + WorkingCopy.check_valid_creation_path(wc_path, repo_path) if not repo_path.exists(): repo_path.mkdir(parents=True) diff --git a/sno/repo.py b/sno/repo.py index 7fb303fa6..1e5420d1a 100644 --- a/sno/repo.py +++ b/sno/repo.py @@ -177,7 +177,7 @@ def init_repository( repo_root_path = repo_root_path.resolve() cls._ensure_exists_and_empty(repo_root_path) if not bare: - WorkingCopy.check_valid_creation_path(repo_root_path, wc_path) + WorkingCopy.check_valid_creation_path(wc_path, repo_root_path) extra_args = [] if initial_branch is not None: @@ -224,7 +224,7 @@ def clone_repository( repo_root_path = repo_root_path.resolve() cls._ensure_exists_and_empty(repo_root_path) if not bare: - WorkingCopy.check_valid_creation_path(repo_root_path, wc_path) + WorkingCopy.check_valid_creation_path(wc_path, repo_root_path) if bare: sno_repo = cls._create_with_git_command( diff --git a/sno/sqlalchemy.py b/sno/sqlalchemy.py index c3208f462..959bedadb 100644 --- a/sno/sqlalchemy.py +++ b/sno/sqlalchemy.py @@ -1,6 +1,6 @@ import re import socket -from urllib.parse import urlsplit, urlunsplit +from urllib.parse import urlsplit, urlunsplit, urlencode, parse_qs import sqlalchemy @@ -11,6 +11,7 @@ from sno import spatialite_path from sno.geometry import Geometry +from sno.exceptions import NotFound def gpkg_engine(path): @@ -110,7 +111,6 @@ def _on_connect(psycopg2_conn, connection_record): CANONICAL_SQL_SERVER_SCHEME = "mssql" INTERNAL_SQL_SERVER_SCHEME = "mssql+pyodbc" -SQL_SERVER_DRIVER_LIB = "ODBC+Driver+17+for+SQL+Server" def sqlserver_engine(msurl): @@ -122,7 +122,7 @@ def sqlserver_engine(msurl): url_netloc = re.sub(r"\blocalhost\b", _replace_with_localhost, url.netloc) url_query = _append_to_query( - url.query, {"driver": SQL_SERVER_DRIVER_LIB, "Application+Name": "sno"} + url.query, {"driver": get_sqlserver_driver(), "Application Name": "sno"} ) msurl = urlunsplit( @@ -133,18 +133,30 @@ def sqlserver_engine(msurl): return engine +def get_sqlserver_driver(): + import pyodbc + + drivers = [ + d for d in pyodbc.drivers() if re.search("SQL Server", d, flags=re.IGNORECASE) + ] + if not drivers: + drivers = pyodbc.drivers() + if not drivers: + raise NotFound("SQL Server driver was not found") + return sorted(drivers)[-1] # Latest driver + + def _replace_with_localhost(*args, **kwargs): return socket.gethostbyname("localhost") -def _append_query_to_url(uri, query_dict): +def _append_query_to_url(uri, new_query_dict): url = urlsplit(uri) - url_query = _append_to_query(url.query, query_dict) + url_query = _append_to_query(url.query, new_query_dict) return urlunsplit([url.scheme, url.netloc, url.path, url_query, ""]) -def _append_to_query(url_query, query_dict): - for key, value in query_dict.items(): - if key not in url_query: - url_query = "&".join(filter(None, [url_query, f"{key}={value}"])) - return url_query +def _append_to_query(existing_query, new_query_dict): + query_dict = parse_qs(existing_query) + # ignore new keys if they're already set in the querystring + return urlencode({**new_query_dict, **query_dict}) diff --git a/sno/working_copy/base.py b/sno/working_copy/base.py index 4dffdcac2..7332e3015 100644 --- a/sno/working_copy/base.py +++ b/sno/working_copy/base.py @@ -1,14 +1,13 @@ +from enum import Enum, auto import contextlib import functools import itertools import logging import time -from enum import Enum, auto - import click import pygit2 -import sqlalchemy +import sqlalchemy as sa from sno.base_dataset import BaseDataset from sno.diff_structs import RepoDiff, DatasetDiff, DeltaDiff, Delta @@ -48,7 +47,8 @@ def from_path(cls, path, allow_invalid=False): f"Unrecognised working copy type: {path}\n" "Try one of:\n" " PATH.gpkg\n" - " postgresql://[HOST]/DBNAME/SCHEMA\n" + " postgresql://[HOST]/DBNAME/DBSCHEMA\n" + " mssql://[HOST]/DBNAME/DBSCHEMA" ) @property @@ -103,12 +103,9 @@ class WorkingCopy: SNO_WORKINGCOPY_PATH = "sno.workingcopy.path" @property - @functools.lru_cache(maxsize=1) - def DB_SCHEMA(self): - """Escaped, dialect-specific name of the database-schema owned by this working copy (if any).""" - if self.db_schema is None: - raise RuntimeError("No schema to escape.") - return self.preparer.format_schema(self.db_schema) + def WORKING_COPY_TYPE_NAME(self): + """Human readable name of this type of working copy, eg "PostGIS".""" + raise NotImplementedError() @property @functools.lru_cache(maxsize=1) @@ -147,9 +144,7 @@ def table_identifier(self, dataset_or_table): else dataset_or_table ) sqlalchemy_table = ( - sqlalchemy.table(table, schema=self.db_schema) - if isinstance(table, str) - else table + sa.table(table, schema=self.db_schema) if isinstance(table, str) else table ) return self.preparer.format_table(sqlalchemy_table) @@ -157,7 +152,7 @@ def _table_def_for_column_schema(self, col, dataset): # This is just used for selects/inserts/updates - we don't need the full type information. # We only need type information if some automatic conversion needs to happen on read or write. # TODO: return the full type information so we can also use it for CREATE TABLE. - return sqlalchemy.column(col.name) + return sa.column(col.name) @functools.lru_cache() def _table_def_for_dataset(self, dataset, schema=None): @@ -166,7 +161,7 @@ def _table_def_for_dataset(self, dataset, schema=None): Specifies table and column names only - the minimum required such that an insert() or update() will work. """ schema = schema or dataset.schema - return sqlalchemy.table( + return sa.table( dataset.table_name, *[self._table_def_for_column_schema(c, dataset) for c in schema], schema=self.db_schema, @@ -261,7 +256,16 @@ def write_config(cls, repo, path=None, bare=False): repo_cfg[path_key] = str(path) @classmethod - def check_valid_creation_path(cls, workdir_path, wc_path): + def subclass_from_path(cls, wc_path): + wct = WorkingCopyType.from_path(wc_path) + if wct.class_ is cls: + raise RuntimeError( + f"No subclass found - don't call subclass_from_path on concrete implementation {cls}." + ) + return wct.class_ + + @classmethod + def check_valid_creation_path(cls, wc_path, workdir_path=None): """ Given a user-supplied string describing where to put the working copy, ensures it is a valid location, and nothing already exists there that prevents us from creating it. Raises InvalidOperation if it is not. @@ -269,20 +273,16 @@ def check_valid_creation_path(cls, workdir_path, wc_path): """ if not wc_path: wc_path = cls.default_path(workdir_path) - WorkingCopyType.from_path(wc_path).class_.check_valid_creation_path( - workdir_path, wc_path - ) + cls.subclass_from_path(wc_path).check_valid_creation_path(wc_path, workdir_path) @classmethod - def check_valid_path(cls, workdir_path, wc_path): + def check_valid_path(cls, wc_path, workdir_path=None): """ Given a user-supplied string describing where to put the working copy, ensures it is a valid location, and nothing already exists there that prevents us from creating it. Raises InvalidOperation if it is not. Doesn't check if we have permissions to create a working copy there. """ - WorkingCopyType.from_path(wc_path).class_.check_valid_path( - workdir_path, wc_path - ) + cls.subclass_from_path(wc_path).check_valid_path(wc_path, workdir_path) def check_valid_state(self): if self.is_created(): @@ -304,19 +304,39 @@ def default_path(cls, workdir_path): return f"{stem}.gpkg" @classmethod - def normalise_path(cls, repo, path): + def normalise_path(cls, repo, wc_path): """If the path is in a non-standard form, normalise it to the equivalent standard form.""" - return WorkingCopyType.from_path(path).class_.normalise_path(repo, path) + return cls.subclass_from_path(wc_path).normalise_path(repo, wc_path) @contextlib.contextmanager def session(self, bulk=0): """ - Context manager for DB sessions, yields a session object inside a transaction - Calling again yields the _same_ session, the transaction/etc only happen in the outer one. + Context manager for GeoPackage DB sessions, yields a connection object inside a transaction - @bulk controls bulk-loading operating mode - subject to change. See ./gpkg.py + Calling again yields the _same_ session, the transaction/etc only happen in the outer one. """ - raise NotImplementedError() + L = logging.getLogger(f"{self.__class__.__qualname__}.session") + + if hasattr(self, "_session"): + # Inner call - reuse existing session. + L.debug("session: existing...") + yield self._session + L.debug("session: existing/done") + return + + L.debug("session: new...") + self._session = self.sessionmaker() + try: + # TODO - use tidier syntax for opening transactions from sqlalchemy. + yield self._session + self._session.commit() + except Exception: + self._session.rollback() + raise + finally: + self._session.close() + del self._session + L.debug("session: new/done") def is_created(self): """ @@ -569,10 +589,10 @@ def _execute_dirty_rows_query( pk_column = table.columns[schema.pk_columns[0].name] tracking_col_type = sno_track.c.pk.type - base_query = sqlalchemy.select(columns=cols_to_select).select_from( + base_query = sa.select(columns=cols_to_select).select_from( sno_track.outerjoin( table, - sno_track.c.pk == sqlalchemy.cast(pk_column, tracking_col_type), + sno_track.c.pk == sa.cast(pk_column, tracking_col_type), ) ) @@ -581,7 +601,7 @@ def _execute_dirty_rows_query( else: pks = list(feature_filter) query = base_query.where( - sqlalchemy.and_( + sa.and_( sno_track.c.table_name == dataset.table_name, sno_track.c.pk.in_(pks), ) @@ -611,10 +631,10 @@ def reset_tracking_table(self, reset_filter=UNFILTERED): continue pks = list(dataset_filter.get("feature", [])) - t = sqlalchemy.text( + t = sa.text( f"DELETE FROM {self.SNO_TRACK} WHERE table_name=:table_name AND pk IN :pks;" ) - t = t.bindparams(sqlalchemy.bindparam("pks", expanding=True)) + t = t.bindparams(sa.bindparam("pks", expanding=True)) sess.execute(t, {"table_name": table_name, "pks": pks}) def can_find_renames(self, meta_diff): @@ -701,8 +721,7 @@ def write_full(self, commit, *datasets, **kwargs): self._write_meta(sess, dataset) if dataset.has_geometry: - # This should be called while the table is still empty. - self._create_spatial_index(sess, dataset) + self._create_spatial_index_pre(sess, dataset) L.info("Creating features...") sql = self._insert_into_dataset(dataset) @@ -739,6 +758,9 @@ def write_full(self, commit, *datasets, **kwargs): "Overall rate: %d features/s", (feat_progress / (t1 - t0 or 0.001)) ) + if dataset.has_geometry: + self._create_spatial_index_post(sess, dataset) + self._create_triggers(sess, dataset) self._update_last_write_time(sess, dataset, commit) @@ -754,18 +776,29 @@ def _write_meta(self, sess, dataset): """Write any non-feature data relating to dataset - title, description, CRS, etc.""" raise NotImplementedError() - def _create_spatial_index(self, sess, dataset): + def _create_spatial_index_pre(self, sess, dataset): """ Creates a spatial index for the table for the given dataset. - The spatial index is configured so that it is automatically updated when the table is modified. - It is not guaranteed that the spatial index will take into account features that are already present - in the table when this function is called - therefore, this should be called while the table is still empty. + This function comes in a pair - _pre is called before features are written, and _post is called afterwards. + Once both are called, the index must contain all the features currently in the table, and, be + configured such that any further writes cause the index to be updated automatically. """ - raise NotImplementedError() + + # Note that the simplest implementation is to add a trigger here so that any further writes update + # the index. Then _create_spatial_index_post needn't be implemented. + pass + + def _create_spatial_index_post(self, sess, dataset): + """Like _create_spatial_index_pre, but runs AFTER the bulk of features have been written.""" + + # Being able to create the index after the bulk of features have been written could be useful for two reasons: + # 1. It might be more efficient to write the features first, then index afterwards. + # 2. Certain working copies are not able to create an index without first knowing a rough bounding box. + pass def _drop_spatial_index(self, sess, dataset): - """Inverse of _create_spatial_index - deletes the spatial index.""" - raise NotImplementedError() + """Inverse of _create_spatial_index_* - deletes the spatial index.""" + pass def _update_last_write_time(self, sess, dataset, commit=None): """Hook for updating the last-modified timestamp stored for a particular dataset, if there is one.""" @@ -795,9 +828,7 @@ def _delete_features(self, sess, dataset, pk_list): pk_column = self.preparer.quote(dataset.primary_key) sql = f"""DELETE FROM {self.table_identifier(dataset)} WHERE {pk_column} IN :pks;""" - stmt = sqlalchemy.text(sql).bindparams( - sqlalchemy.bindparam("pks", expanding=True) - ) + stmt = sa.text(sql).bindparams(sa.bindparam("pks", expanding=True)) feat_count = 0 CHUNK_SIZE = 100 for pks in self._chunk(pk_list, CHUNK_SIZE): diff --git a/sno/working_copy/db_server.py b/sno/working_copy/db_server.py new file mode 100644 index 000000000..b5d73017d --- /dev/null +++ b/sno/working_copy/db_server.py @@ -0,0 +1,103 @@ +import functools +import re +from urllib.parse import urlsplit, urlunsplit + +import click + +from .base import WorkingCopy +from sno.exceptions import InvalidOperation + + +class DatabaseServer_WorkingCopy(WorkingCopy): + """Functionality common to working copies that connect to a database server.""" + + @property + def URI_SCHEME(self): + """The URI scheme to connect to this type of database, eg "postgresql".""" + raise NotImplementedError() + + @classmethod + def check_valid_creation_path(cls, wc_path, workdir_path=None): + cls.check_valid_path(wc_path, workdir_path) + + working_copy = cls(None, wc_path) + if working_copy.has_data(): + db_schema = working_copy.db_schema + container_text = f"schema '{db_schema}'" if db_schema else "working copy" + raise InvalidOperation( + f"Error creating {cls.WORKING_COPY_TYPE_NAME} working copy at {wc_path} - " + f"non-empty {container_text} already exists" + ) + + @classmethod + def check_valid_path(cls, wc_path, workdir_path=None): + cls.check_valid_db_uri(wc_path, workdir_path) + + @classmethod + def normalise_path(cls, repo, wc_path): + return wc_path + + @classmethod + def check_valid_db_uri(cls, db_uri, workdir_path=None): + """ + For working copies that connect to a database - checks the given URI is in the required form: + >>> URI_SCHEME::[HOST]/DBNAME/DBSCHEMA + """ + url = urlsplit(db_uri) + + if url.scheme != cls.URI_SCHEME: + raise click.UsageError( + f"Invalid {cls.WORKING_COPY_TYPE_NAME} URI - " + f"Expecting URI in form: {cls.URI_SCHEME}://[HOST]/DBNAME/DBSCHEMA" + ) + + url_path = url.path + path_parts = url_path[1:].split("/", 3) if url_path else [] + + suggestion_message = "" + if len(path_parts) == 1 and workdir_path is not None: + suggested_path = f"/{path_parts[0]}/{cls.default_db_schema(workdir_path)}" + suggested_uri = urlunsplit( + [url.scheme, url.netloc, suggested_path, url.query, ""] + ) + suggestion_message = f"\nFor example: {suggested_uri}" + + if len(path_parts) != 2: + raise click.UsageError( + f"Invalid {cls.WORKING_COPY_TYPE_NAME} URI - URI requires both database name and database schema:\n" + f"Expecting URI in form: {cls.URI_SCHEME}://[HOST]/DBNAME/DBSCHEMA" + + suggestion_message + ) + + @classmethod + def _separate_db_schema(cls, db_uri): + """ + Removes the DBSCHEMA part off the end of a uri in the form URI_SCHEME::[HOST]/DBNAME/DBSCHEMA - + and returns the URI and the DBSCHEMA separately. + Useful since generally, URI_SCHEME::[HOST]/DBNAME is what is needed to connect to the database, + and then DBSCHEMA must be specified in each query. + """ + url = urlsplit(db_uri) + url_path = url.path + path_parts = url_path[1:].split("/", 3) if url_path else [] + assert len(path_parts) == 2 + url_path = "/" + path_parts[0] + db_schema = path_parts[1] + return urlunsplit([url.scheme, url.netloc, url_path, url.query, ""]), db_schema + + @classmethod + def default_db_schema(cls, workdir_path): + """Returns a suitable default database schema - named after the folder this Sno repo is in.""" + stem = workdir_path.stem + schema = re.sub("[^a-z0-9]+", "_", stem.lower()) + "_sno" + if schema[0].isdigit(): + schema = "_" + schema + return schema + + @property + @functools.lru_cache(maxsize=1) + def DB_SCHEMA(self): + """Escaped, dialect-specific name of the database-schema owned by this working copy (if any).""" + if self.db_schema is None: + raise RuntimeError("No schema to escape.") + return self.preparer.format_schema(self.db_schema) diff --git a/sno/working_copy/gpkg.py b/sno/working_copy/gpkg.py index 9bcbb00dd..2a0846cd8 100644 --- a/sno/working_copy/gpkg.py +++ b/sno/working_copy/gpkg.py @@ -7,7 +7,7 @@ import click from osgeo import gdal -import sqlalchemy +import sqlalchemy as sa from sqlalchemy.orm import sessionmaker from sqlalchemy.sql.compiler import IdentifierPreparer from sqlalchemy.types import UserDefinedType @@ -26,6 +26,15 @@ class WorkingCopy_GPKG(WorkingCopy): + """ + GPKG working copy implementation. + + Requirements: + 1. Can read and write to the filesystem at the specified path. + """ + + WORKING_COPY_TYPE_NAME = "GPKG" + def __init__(self, repo, path): self.repo = repo self.path = path @@ -37,34 +46,34 @@ def __init__(self, repo, path): self.sno_tables = GpkgSnoTables @classmethod - def check_valid_creation_path(cls, workdir_path, path): - cls.check_valid_path(workdir_path, path) + def check_valid_creation_path(cls, wc_path, workdir_path=None): + cls.check_valid_path(wc_path, workdir_path) - gpkg_path = (workdir_path / path).resolve() + gpkg_path = (workdir_path / wc_path).resolve() if gpkg_path.exists(): desc = "path" if gpkg_path.is_dir() else "GPKG file" raise InvalidOperation( - f"Error creating GPKG working copy at {path} - {desc} already exists" + f"Error creating GPKG working copy at {wc_path} - {desc} already exists" ) @classmethod - def check_valid_path(cls, workdir_path, path): - if not str(path).endswith(".gpkg"): - suggested_path = f"{os.path.splitext(str(path))[0]}.gpkg" + def check_valid_path(cls, wc_path, workdir_path=None): + if not str(wc_path).endswith(".gpkg"): + suggested_path = f"{os.path.splitext(str(wc_path))[0]}.gpkg" raise click.UsageError( f"Invalid GPKG path - expected .gpkg suffix, eg {suggested_path}" ) @classmethod - def normalise_path(cls, repo, path): + def normalise_path(cls, repo, wc_path): """Rewrites a relative path (relative to the current directory) as relative to the repo.workdir_path.""" - path = Path(path) - if not path.is_absolute(): + wc_path = Path(wc_path) + if not wc_path.is_absolute(): try: - return str(path.resolve().relative_to(repo.workdir_path.resolve())) + return str(wc_path.resolve().relative_to(repo.workdir_path.resolve())) except ValueError: pass - return str(path) + return str(wc_path) @property def full_path(self): @@ -82,11 +91,11 @@ def _insert_or_replace_into_dataset(self, dataset): def _table_def_for_column_schema(self, col, dataset): if col.data_type == "geometry": - # This user-defined Geography type adapts WKB to SQL Server's native geography type. - return sqlalchemy.column(col.name, GeometryType) + # This user-defined GeometryType normalises GPKG geometry to the Sno V2 GPKG geometry. + return sa.column(col.name, GeometryType) else: # Don't need to specify type information for other columns at present, since we just pass through the values. - return sqlalchemy.column(col.name) + return sa.column(col.name) def _insert_or_replace_state_table_tree(self, sess, tree_id): r = sess.execute( @@ -119,37 +128,36 @@ def session(self, bulk=0): # - do something consistent and safe from then on. if hasattr(self, "_session"): - # inner - reuse + # Inner call - reuse existing session. L.debug(f"session(bulk={bulk}): existing...") yield self._session L.debug(f"session(bulk={bulk}): existing/done") + return - else: - L.debug(f"session(bulk={bulk}): new...") + # Outer call - create new session: + L.debug(f"session(bulk={bulk}): new...") + self._session = self.sessionmaker() - try: - self._session = self.sessionmaker() - - if bulk: - self._session.execute("PRAGMA synchronous = OFF;") - self._session.execute( - "PRAGMA cache_size = -1048576;" - ) # -KiB => 1GiB - if bulk >= 2: - self._session.execute("PRAGMA journal_mode = MEMORY;") - self._session.execute("PRAGMA locking_mode = EXCLUSIVE;") - - # TODO - use tidier syntax for opening transactions from sqlalchemy. - self._session.execute("BEGIN TRANSACTION;") - yield self._session - self._session.commit() - except Exception: - self._session.rollback() - raise - finally: - self._session.close() - del self._session - L.debug(f"session(bulk={bulk}): new/done") + try: + if bulk: + self._session.execute("PRAGMA synchronous = OFF;") + self._session.execute("PRAGMA cache_size = -1048576;") # -KiB => 1GiB + if bulk >= 2: + self._session.execute("PRAGMA journal_mode = MEMORY;") + self._session.execute("PRAGMA locking_mode = EXCLUSIVE;") + + # TODO - use tidier syntax for opening transactions from sqlalchemy. + self._session.execute("BEGIN TRANSACTION;") + yield self._session + self._session.commit() + + except Exception: + self._session.rollback() + raise + finally: + self._session.close() + del self._session + L.debug(f"session(bulk={bulk}): new/done") def delete(self, keep_db_schema_if_possible=False): """Delete the working copy files.""" @@ -428,12 +436,15 @@ def _delete_meta_metadata(self, sess, table_name): """DELETE FROM gpkg_metadata WHERE id IN :ids;""", ) for sql in sqls: - stmt = sqlalchemy.text(sql).bindparams( - sqlalchemy.bindparam("ids", expanding=True) - ) + stmt = sa.text(sql).bindparams(sa.bindparam("ids", expanding=True)) sess.execute(stmt, {"ids": ids}) - def _create_spatial_index(self, sess, dataset): + def _create_spatial_index_pre(self, sess, dataset): + # Implementing only _create_spatial_index_pre: + # gpkgAddSpatialIndex has to be called before writing any features, + # since it only adds on-write triggers to update the index - it doesn't + # add any pre-existing features to the index. + L = logging.getLogger(f"{self.__class__.__qualname__}._create_spatial_index") geom_col = dataset.geom_column_name @@ -480,56 +491,42 @@ def _create_triggers(self, sess, dataset): table_identifier = self.table_identifier(dataset) pk_column = self.quote(dataset.primary_key) - # SQLite doesn't let you do param substitutions in CREATE TRIGGER: - escaped_table_name = dataset.table_name.replace("'", "''") - - sess.execute( + insert_trigger = sa.text( f""" CREATE TRIGGER {self._quoted_trigger_name(dataset, 'ins')} AFTER INSERT ON {table_identifier} BEGIN - INSERT OR REPLACE INTO {self.SNO_TRACK} - (table_name, pk) - VALUES ('{escaped_table_name}', NEW.{pk_column}); + INSERT OR REPLACE INTO {self.SNO_TRACK} (table_name, pk) + VALUES (:table_name, NEW.{pk_column}); END; """ ) - sess.execute( + update_trigger = sa.text( f""" CREATE TRIGGER {self._quoted_trigger_name(dataset, 'upd')} AFTER UPDATE ON {table_identifier} BEGIN - INSERT OR REPLACE INTO {self.SNO_TRACK} - (table_name, pk) - VALUES - ('{escaped_table_name}', NEW.{pk_column}), - ('{escaped_table_name}', OLD.{pk_column}); + INSERT OR REPLACE INTO {self.SNO_TRACK} (table_name, pk) + VALUES (:table_name, NEW.{pk_column}), (:table_name, OLD.{pk_column}); END; """ ) - sess.execute( + delete_trigger = sa.text( f""" CREATE TRIGGER {self._quoted_trigger_name(dataset, 'del')} AFTER DELETE ON {table_identifier} BEGIN - INSERT OR REPLACE INTO {self.SNO_TRACK} - (table_name, pk) - VALUES - ('{escaped_table_name}', OLD.{pk_column}); + INSERT OR REPLACE INTO {self.SNO_TRACK} (table_name, pk) + VALUES (:table_name, OLD.{pk_column}); END; """ ) - - def _db_geom_to_gpkg_geom(self, g): - # Its possible in GPKG to put arbitrary values in columns, regardless of type. - # We don't try to convert them here - we let the commit validation step report this as an error. - if not isinstance(g, bytes): - return g - # We normalise geometries to avoid spurious diffs - diffs where nothing - # of any consequence has changed (eg, only endianness has changed). - # This includes setting the SRID to zero for each geometry so that we don't store a separate SRID per geometry, - # but only one per column at most. - return normalise_gpkg_geom(g) + for trigger in (insert_trigger, update_trigger, delete_trigger): + # Placeholders not allowed in CREATE TRIGGER - have to use literal_binds. + # See https://docs.sqlalchemy.org/en/13/faq/sqlexpressions.html#faq-sql-expression-string + trigger.bindparams(table_name=dataset.table_name).compile( + sess.connection(), compile_kwargs={"literal_binds": True} + ).execute() def _is_meta_update_supported(self, dataset_version, meta_diff): """ @@ -623,6 +620,21 @@ def _apply_meta_metadata_dataset_json(self, sess, dataset, src_value, dest_value def _update_last_write_time(self, sess, dataset, commit=None): self._update_gpkg_contents(sess, dataset, commit) + def _get_geom_extent(self, sess, dataset, default=None): + """Returns the envelope around the entire dataset as (min_x, min_y, max_x, max_y).""" + # FIXME: Why doesn't Extent(geom) work here as an aggregate? + geom_col = dataset.geom_column_name + r = sess.execute( + f""" + WITH _E AS ( + SELECT Extent({self.quote(geom_col)}) AS extent FROM {self.table_identifier(dataset)} + ) + SELECT ST_MinX(extent), ST_MinY(extent), ST_MaxX(extent), ST_MaxY(extent) FROM _E; + """ + ) + result = r.fetchone() + return default if result == (None, None, None, None) else result + def _update_gpkg_contents(self, sess, dataset, commit=None): """ Update the metadata for the given table in gpkg_contents to have the new bounding-box / last-updated timestamp. @@ -634,17 +646,11 @@ def _update_gpkg_contents(self, sess, dataset, commit=None): # GPKG Spec Req. 15: gpkg_change_time = change_time.strftime("%Y-%m-%dT%H:%M:%S.%fZ") - table_identifer = self.table_identifier(dataset) geom_col = dataset.geom_column_name if geom_col is not None: - # FIXME: Why doesn't Extent(geom) work here as an aggregate? - r = sess.execute( - f""" - WITH _E AS (SELECT extent({self.quote(geom_col)}) AS extent FROM {table_identifer}) - SELECT ST_MinX(extent), ST_MinY(extent), ST_MaxX(extent), ST_MaxY(extent) FROM _E - """ + min_x, min_y, max_x, max_y = self._get_geom_extent( + sess, dataset, default=(None, None, None, None) ) - min_x, min_y, max_x, max_y = r.fetchone() rc = sess.execute( """ UPDATE gpkg_contents diff --git a/sno/working_copy/postgis.py b/sno/working_copy/postgis.py index 611bb63ee..4d2fd2700 100644 --- a/sno/working_copy/postgis.py +++ b/sno/working_copy/postgis.py @@ -1,37 +1,39 @@ import contextlib import logging -import re import time -from urllib.parse import urlsplit, urlunsplit +from urllib.parse import urlsplit -import click + +from sqlalchemy import Index from sqlalchemy.dialects.postgresql import insert as postgresql_insert from sqlalchemy.sql.compiler import IdentifierPreparer from sqlalchemy.orm import sessionmaker -from .base import WorkingCopy from . import postgis_adapter +from .db_server import DatabaseServer_WorkingCopy from .table_defs import PostgisSnoTables from sno import crs_util -from sno.exceptions import InvalidOperation from sno.schema import Schema from sno.sqlalchemy import postgis_engine -""" -* database needs to exist -* database needs to have postgis enabled -* database user needs to be able to: - 1. create 'sno' schema & tables - 2. create & alter tables in the default (or specified) schema - 3. create triggers -""" +class WorkingCopy_Postgis(DatabaseServer_WorkingCopy): + """ + PosttGIS working copy implementation. -L = logging.getLogger("sno.working_copy.postgis") + Requirements: + 1. The database needs to exist + 2. If the dataset has geometry, then PostGIS (https://postgis.net/) v2.4 or newer needs + to be installed into the database and available in the database user's search path + 3. The database user needs to be able to: + - Create the specified schema (unless it already exists). + - Create, delete and alter tables and triggers in the specified schema. + """ + WORKING_COPY_TYPE_NAME = "PostGIS" + URI_SCHEME = "postgresql" -class WorkingCopy_Postgis(WorkingCopy): def __init__(self, repo, uri): """ uri: connection string of the form postgresql://[user[:password]@][netloc][:port][/dbname/schema][?param1=value1&...] @@ -42,75 +44,18 @@ def __init__(self, repo, uri): self.uri = uri self.path = uri - url = urlsplit(uri) - - if url.scheme != "postgresql": - raise ValueError("Expecting postgresql://") + self.check_valid_db_uri(uri) + self.db_uri, self.db_schema = self._separate_db_schema(uri) - url_path = url.path - path_parts = url_path[1:].split("/", 3) if url_path else [] - if len(path_parts) != 2: - raise ValueError("Expecting postgresql://[HOST]/DBNAME/SCHEMA") - url_path = f"/{path_parts[0]}" - self.db_schema = path_parts[1] - - # Rebuild DB URL suitable for postgres - self.dburl = urlunsplit([url.scheme, url.netloc, url_path, url.query, ""]) - self.engine = postgis_engine(self.dburl) + self.engine = postgis_engine(self.db_uri) self.sessionmaker = sessionmaker(bind=self.engine) self.preparer = IdentifierPreparer(self.engine.dialect) self.sno_tables = PostgisSnoTables(self.db_schema) @classmethod - def check_valid_creation_path(cls, workdir_path, path): - cls.check_valid_path(workdir_path, path) - postgis_wc = cls(None, path) - - # Less strict on Postgis - we are okay with the schema being already created, so long as its empty. - if postgis_wc.has_data(): - raise InvalidOperation( - f"Error creating Postgis working copy at {path} - non-empty schema already exists" - ) - - @classmethod - def check_valid_path(cls, workdir_path, path): - url = urlsplit(path) - - if url.scheme != "postgresql": - raise click.UsageError( - "Invalid postgres URI - Expecting URI in form: postgresql://[HOST]/DBNAME/SCHEMA" - ) - - url_path = url.path - path_parts = url_path[1:].split("/", 3) if url_path else [] - - suggestion_message = "" - if len(path_parts) == 1 and workdir_path is not None: - suggested_path = f"/{path_parts[0]}/{cls.default_schema(workdir_path)}" - suggested_uri = urlunsplit( - [url.scheme, url.netloc, suggested_path, url.query, ""] - ) - suggestion_message = f"\nFor example: {suggested_uri}" - - if len(path_parts) != 2: - raise click.UsageError( - "Invalid postgres URI - postgis working copy requires both dbname and schema:\n" - "Expecting URI in form: postgresql://[HOST]/DBNAME/SCHEMA" - + suggestion_message - ) - - @classmethod - def normalise_path(cls, repo, path): - return path - - @classmethod - def default_schema(cls, workdir_path): - stem = workdir_path.stem - schema = re.sub("[^a-z0-9]+", "_", stem.lower()) + "_sno" - if schema[0].isdigit(): - schema = "_" + schema - return schema + def check_valid_path(cls, wc_path, workdir_path=None): + cls.check_valid_db_uri(wc_path, workdir_path) def __str__(self): p = urlsplit(self.uri) @@ -124,44 +69,11 @@ def __str__(self): p._replace(netloc=nl) return p.geturl() - @contextlib.contextmanager - def session(self, bulk=0): - """ - Context manager for GeoPackage DB sessions, yields a connection object inside a transaction - - Calling again yields the _same_ connection, the transaction/etc only happen in the outer one. - """ - L = logging.getLogger(f"{self.__class__.__qualname__}.session") - - if hasattr(self, "_session"): - # inner - reuse - L.debug("session: existing...") - yield self._session - L.debug("session: existing/done") - - else: - L.debug("session: new...") - - try: - self._session = self.sessionmaker() - - # TODO - use tidier syntax for opening transactions from sqlalchemy. - self._session.execute("BEGIN TRANSACTION;") - yield self._session - self._session.commit() - except Exception: - self._session.rollback() - raise - finally: - self._session.close() - del self._session - L.debug("session: new/done") - def is_created(self): """ - Returns true if the postgres schema referred to by this working copy exists and + Returns true if the DB schema referred to by this working copy exists and contains at least one table. If it exists but is empty, it is treated as uncreated. - This is so the postgres schema can be created ahead of time before a repo is created + This is so the DB schema can be created ahead of time before a repo is created or configured, without it triggering code that checks for corrupted working copies. Note that it might not be initialised as a working copy - see self.is_initialised. """ @@ -177,7 +89,7 @@ def is_created(self): def is_initialised(self): """ - Returns true if the postgis working copy is initialised - + Returns true if the PostGIS working copy is initialised - the schema exists and has the necessary sno tables, _sno_state and _sno_track. """ with self.session() as sess: @@ -192,7 +104,7 @@ def is_initialised(self): def has_data(self): """ - Returns true if the postgis working copy seems to have user-created content already. + Returns true if the PostGIS working copy seems to have user-created content already. """ with self.session() as sess: count = sess.scalar( @@ -283,6 +195,7 @@ def _create_table_for_dataset(self, sess, dataset): ) def _insert_or_replace_into_dataset(self, dataset): + # See https://docs.sqlalchemy.org/en/14/dialects/postgresql.html#insert-on-conflict-upsert pk_col_names = [c.name for c in dataset.schema.pk_columns] stmt = postgresql_insert(self._table_def_for_dataset(dataset)) update_dict = {c.name: c for c in stmt.excluded if c.name not in pk_col_names} @@ -326,21 +239,25 @@ def delete_meta(self, dataset): """Delete any metadata that is only needed by this dataset.""" pass # There is no metadata except for the spatial_ref_sys table. - def _create_spatial_index(self, sess, dataset): + def _create_spatial_index_post(self, sess, dataset): + # Only implemented as _create_spatial_index_post: + # It is more efficient to write the features first, then index them all in bulk. L = logging.getLogger(f"{self.__class__.__qualname__}._create_spatial_index") geom_col = dataset.geom_column_name + index_name = f"{dataset.table_name}_idx_{geom_col}" + table = self._table_def_for_dataset(dataset) - # Create the PostGIS Spatial Index L.debug("Creating spatial index for %s.%s", dataset.table_name, geom_col) t0 = time.monotonic() - index_name = f"{dataset.table_name}_idx_{geom_col}" - sess.execute( - f""" - CREATE INDEX {self.quote(index_name)} - ON {self.table_identifier(dataset)} USING GIST ({self.quote(geom_col)}); - """ + + spatial_index = Index( + index_name, table.columns[geom_col], postgres_using="GIST" ) + spatial_index.table = table + spatial_index.create(sess.connection()) + sess.execute(f"""ANALYSE {self.table_identifier(dataset)};""") + L.info("Created spatial index in %ss", time.monotonic() - t0) def _drop_spatial_index(self, sess, dataset): diff --git a/sno/working_copy/sqlserver.py b/sno/working_copy/sqlserver.py index ee4d8bc33..6156a4c5b 100644 --- a/sno/working_copy/sqlserver.py +++ b/sno/working_copy/sqlserver.py @@ -1,11 +1,9 @@ import contextlib import logging -import re import time -from urllib.parse import urlsplit, urlunsplit +from urllib.parse import urlsplit -import click -import sqlalchemy +import sqlalchemy as sa from sqlalchemy import literal_column from sqlalchemy.ext.compiler import compiles from sqlalchemy.orm import sessionmaker @@ -15,28 +13,28 @@ from sqlalchemy.sql.compiler import IdentifierPreparer from sqlalchemy.types import UserDefinedType -from .base import WorkingCopy from . import sqlserver_adapter +from .db_server import DatabaseServer_WorkingCopy from .table_defs import SqlServerSnoTables from sno import crs_util from sno.geometry import Geometry -from sno.exceptions import InvalidOperation from sno.sqlalchemy import sqlserver_engine -""" -* database needs to exist -* database needs to have postgis enabled -* database user needs to be able to: - 1. create 'sno' schema & tables - 2. create & alter tables in the default (or specified) schema - 3. create triggers -""" +class WorkingCopy_SqlServer(DatabaseServer_WorkingCopy): + """ + SQL Server working copy implementation. -L = logging.getLogger("sno.working_copy.postgis") + Requirements: + 1. The database needs to exist + 2. The database user needs to be able to: + - Create the specified schema (unless it already exists). + - Create, delete and alter tables and triggers in the specified schema. + """ + WORKING_COPY_TYPE_NAME = "SQL Server" + URI_SCHEME = "mssql" -class WorkingCopy_SqlServer(WorkingCopy): def __init__(self, repo, uri): """ uri: connection string of the form mssql://[user[:password]@][netloc][:port][/dbname/schema][?param1=value1&...] @@ -47,78 +45,15 @@ def __init__(self, repo, uri): self.uri = uri self.path = uri - url = urlsplit(uri) - - if url.scheme != "mssql": - raise ValueError("Expecting mssql://") - - url_path = url.path - path_parts = url_path[1:].split("/", 3) if url_path else [] - if len(path_parts) != 2: - raise ValueError("Expecting mssql://[HOST]/DBNAME/SCHEMA") - url_path = f"/{path_parts[0]}" - self.db_schema = path_parts[1] + self.check_valid_db_uri(uri) + self.db_uri, self.db_schema = self._separate_db_schema(uri) - # Rebuild DB URL suitable for sqlserver_engine. - self.dburl = urlunsplit([url.scheme, url.netloc, url_path, url.query, ""]) - self.engine = sqlserver_engine(self.dburl) + self.engine = sqlserver_engine(self.db_uri) self.sessionmaker = sessionmaker(bind=self.engine) self.preparer = IdentifierPreparer(self.engine.dialect) self.sno_tables = SqlServerSnoTables(self.db_schema) - @classmethod - def check_valid_creation_path(cls, workdir_path, path): - # TODO - promote to superclass - cls.check_valid_path(workdir_path, path) - sqlserver_wc = cls(None, path) - - # We are okay with the schema being already created, so long as its empty. - if sqlserver_wc.has_data(): - raise InvalidOperation( - f"Error creating SQL Server working copy at {path} - non-empty schema already exists" - ) - - @classmethod - def check_valid_path(cls, workdir_path, path): - url = urlsplit(path) - - if url.scheme != "mssql": - raise click.UsageError( - "Invalid postgres URI - Expecting URI in form: mssql://[HOST]/DBNAME/SCHEMA" - ) - - url_path = url.path - path_parts = url_path[1:].split("/", 3) if url_path else [] - - suggestion_message = "" - if len(path_parts) == 1 and workdir_path is not None: - suggested_path = f"/{path_parts[0]}/{cls.default_schema(workdir_path)}" - suggested_uri = urlunsplit( - [url.scheme, url.netloc, suggested_path, url.query, ""] - ) - suggestion_message = f"\nFor example: {suggested_uri}" - - if len(path_parts) != 2: - raise click.UsageError( - "Invalid mssql URI - SQL Server working copy requires both dbname and schema:\n" - "Expecting URI in form: mssql://[HOST]/DBNAME/SCHEMA" - + suggestion_message - ) - - @classmethod - def normalise_path(cls, repo, path): - return path - - @classmethod - def default_schema(cls, workdir_path): - # TODO - promote to superclass - stem = workdir_path.stem - schema = re.sub("[^a-z0-9]+", "_", stem.lower()) + "_sno" - if schema[0].isdigit(): - schema = "_" + schema - return schema - def __str__(self): p = urlsplit(self.uri) if p.password is not None: @@ -131,38 +66,6 @@ def __str__(self): p._replace(netloc=nl) return p.geturl() - @contextlib.contextmanager - def session(self, bulk=0): - """ - Context manager for GeoPackage DB sessions, yields a connection object inside a transaction - - Calling again yields the _same_ connection, the transaction/etc only happen in the outer one. - """ - L = logging.getLogger(f"{self.__class__.__qualname__}.session") - - if hasattr(self, "_session"): - # inner - reuse - L.debug("session: existing...") - yield self._session - L.debug("session: existing/done") - - else: - L.debug("session: new...") - - try: - self._session = self.sessionmaker() - - # TODO - use tidier syntax for opening transactions from sqlalchemy. - yield self._session - self._session.commit() - except Exception: - self._session.rollback() - raise - finally: - self._session.close() - del self._session - L.debug("session: new/done") - def is_created(self): """ Returns true if the db schema referred to by this working copy exists and @@ -180,7 +83,7 @@ def is_created(self): def is_initialised(self): """ - Returns true if the postgis working copy is initialised - + Returns true if the SQL server working copy is initialised - the schema exists and has the necessary sno tables, _sno_state and _sno_track. """ with self.session() as sess: @@ -196,7 +99,7 @@ def is_initialised(self): def has_data(self): """ - Returns true if the postgis working copy seems to have user-created content already. + Returns true if the SQL server working copy seems to have user-created content already. """ with self.session() as sess: count = sess.scalar( @@ -245,15 +148,15 @@ def _create_table_for_dataset(self, sess, dataset): def _table_def_for_column_schema(self, col, dataset): if col.data_type == "geometry": - # This user-defined Geography type adapts WKB to SQL Server's native geography type. crs_name = col.extra_type_info.get("geometryCRS", None) crs_id = crs_util.get_identifier_int_from_dataset(dataset, crs_name) or 0 - return sqlalchemy.column(col.name, GeographyType(crs_id)) + # This user-defined GeometryType adapts Sno's GPKG geometry to SQL Server's native geometry type. + return sa.column(col.name, GeometryType(crs_id)) elif col.data_type in ("date", "time", "timestamp"): - return sqlalchemy.column(col.name, BaseDateOrTimeType) + return sa.column(col.name, BaseDateOrTimeType) else: # Don't need to specify type information for other columns at present, since we just pass through the values. - return sqlalchemy.column(col.name) + return sa.column(col.name) def _insert_or_replace_into_dataset(self, dataset): pk_col_names = [c.name for c in dataset.schema.pk_columns] @@ -282,33 +185,85 @@ def _insert_or_replace_state_table_tree(self, sess, tree_id): return r.rowcount def _write_meta(self, sess, dataset): - """Write the title (as a comment) and the CRS. Other metadata is not stored in a PostGIS WC.""" + """Write the title. Other metadata is not stored in a SQL Server WC.""" self._write_meta_title(sess, dataset) def _write_meta_title(self, sess, dataset): """Write the dataset title as a comment on the table.""" - # TODO - probably need to use sp_addextendedproperty @name=N'MS_Description' + # TODO - dataset title is not stored anywhere in SQL server working copy right now. + # We can probably store it using function sp_addextendedproperty to add property 'MS_Description' pass def delete_meta(self, dataset): """Delete any metadata that is only needed by this dataset.""" - pass # There is no metadata except for the spatial_ref_sys table. + # There is no metadata stored anywhere except the table itself. + pass + + def _get_geom_extent(self, sess, dataset, default=None): + """Returns the envelope around the entire dataset as (min_x, min_y, max_x, max_y).""" + geom_col = dataset.geom_column_name + r = sess.execute( + f""" + WITH _E AS ( + SELECT geometry::EnvelopeAggregate({self.quote(geom_col)}) AS envelope + FROM {self.table_identifier(dataset)} + ) + SELECT + envelope.STPointN(1).STX AS min_x, + envelope.STPointN(1).STY AS min_y, + envelope.STPointN(3).STX AS max_x, + envelope.STPointN(3).STY AS max_y + FROM _E; + """ + ) + result = r.fetchone() + return default if result == (None, None, None, None) else result + + def _grow_rectangle(self, rectangle, scale_factor): + # scale_factor = 1 -> no change, >1 -> grow, <1 -> shrink. + min_x, min_y, max_x, max_y = rectangle + centre_x, centre_y = (min_x + max_x) / 2, (min_y + max_y) / 2 + min_x = (min_x - centre_x) * scale_factor + centre_x + min_y = (min_y - centre_y) * scale_factor + centre_y + max_x = (max_x - centre_x) * scale_factor + centre_x + max_y = (max_y - centre_y) * scale_factor + centre_y + return min_x, min_y, max_x, max_y + + def _create_spatial_index_post(self, sess, dataset): + # Only implementing _create_spatial_index_post: + # We need to know the rough extent of the data to create an index in that area, + # so we create the spatial index once the bulk of the features have been written. - def _create_spatial_index(self, sess, dataset): L = logging.getLogger(f"{self.__class__.__qualname__}._create_spatial_index") + extent = self._get_geom_extent(sess, dataset) + if not extent: + # Can't create a spatial index if we don't know the rough bounding box we need to index. + return + + # Add 20% room to grow. + GROW_FACTOR = 1.2 + min_x, min_y, max_x, max_y = self._grow_rectangle(extent, GROW_FACTOR) + geom_col = dataset.geom_column_name + index_name = f"{dataset.table_name}_idx_{geom_col}" - # Create the SQL Server Spatial Index L.debug("Creating spatial index for %s.%s", dataset.table_name, geom_col) t0 = time.monotonic() - index_name = f"{dataset.table_name}_idx_{geom_col}" - sess.execute( + + create_index = sa.text( f""" CREATE SPATIAL INDEX {self.quote(index_name)} - ON {self.table_identifier(dataset)} ({self.quote(geom_col)}); + ON {self.table_identifier(dataset)} ({self.quote(geom_col)}) + WITH (BOUNDING_BOX = (:min_x, :min_y, :max_x, :max_y)) """ - ) + ).bindparams(min_x=min_x, min_y=min_y, max_x=max_x, max_y=max_y) + # Placeholders not allowed in CREATE SPATIAL INDEX - have to use literal_binds. + # See https://docs.sqlalchemy.org/en/13/faq/sqlexpressions.html#faq-sql-expression-string + create_index.compile( + sess.connection(), compile_kwargs={"literal_binds": True} + ).execute() + L.info("Created spatial index in %ss", time.monotonic() - t0) def _drop_spatial_index(self, sess, dataset): @@ -321,24 +276,26 @@ def _quoted_trigger_name(self, dataset): def _create_triggers(self, sess, dataset): pk_name = dataset.primary_key - escaped_table_name = dataset.table_name.replace("'", "''") - - sess.execute( + create_trigger = sa.text( f""" CREATE TRIGGER {self._quoted_trigger_name(dataset)} ON {self.table_identifier(dataset)} AFTER INSERT, UPDATE, DELETE AS BEGIN MERGE {self.SNO_TRACK} TRA USING - (SELECT '{escaped_table_name}', {self.quote(pk_name)} FROM inserted - UNION SELECT '{escaped_table_name}', {self.quote(pk_name)} FROM deleted) + (SELECT :table_name, {self.quote(pk_name)} FROM inserted + UNION SELECT :table_name, {self.quote(pk_name)} FROM deleted) AS SRC (table_name, pk) ON SRC.table_name = TRA.table_name AND SRC.pk = TRA.pk WHEN NOT MATCHED THEN INSERT (table_name, pk) VALUES (SRC.table_name, SRC.pk); END; - """, - {"table_name": dataset.table_name}, - ) + """ + ).bindparams(table_name=dataset.table_name) + # Placeholders not allowed in CREATE TRIGGER - have to use literal_binds. + # See https://docs.sqlalchemy.org/en/13/faq/sqlexpressions.html#faq-sql-expression-string + create_trigger.compile( + sess.connection(), compile_kwargs={"literal_binds": True} + ).execute() @contextlib.contextmanager def _suspend_triggers(self, sess, dataset): @@ -352,24 +309,6 @@ def _suspend_triggers(self, sess, dataset): def meta_items(self, dataset): with self.session() as sess: - table_info_sql = """ - SELECT - C.column_name, C.ordinal_position, C.data_type, C.udt_name, - C.character_maximum_length, C.numeric_precision, C.numeric_scale, - KCU.ordinal_position AS pk_ordinal_position, - upper(postgis_typmod_type(A.atttypmod)) AS geometry_type, - postgis_typmod_srid(A.atttypmod) AS geometry_srid - FROM information_schema.columns C - LEFT OUTER JOIN information_schema.key_column_usage KCU - ON (KCU.table_schema = C.table_schema) - AND (KCU.table_name = C.table_name) - AND (KCU.column_name = C.column_name) - LEFT OUTER JOIN pg_attribute A - ON (A.attname = C.column_name) - AND (A.attrelid = (C.table_schema || '.' || C.table_name)::regclass::oid) - WHERE C.table_schema=:table_schema AND C.table_name=:table_name - ORDER BY C.ordinal_position; - """ table_info_sql = """ SELECT C.column_name, C.ordinal_position, C.data_type, @@ -405,10 +344,6 @@ def try_align_schema_col(cls, old_col_dict, new_col_dict): old_type = old_col_dict["dataType"] new_type = new_col_dict["dataType"] - # Some types have to be approximated as other types in SQL Server. - if sqlserver_adapter.APPROXIMATED_TYPES.get(old_type) == new_type: - new_col_dict["dataType"] = new_type = old_type - # Geometry type loses its extra type info when roundtripped through SQL Server. if new_type == "geometry": new_col_dict["geometryType"] = old_col_dict.get("geometryType") @@ -419,7 +354,7 @@ def try_align_schema_col(cls, old_col_dict, new_col_dict): def _remove_hidden_meta_diffs(self, dataset, ds_meta_items, wc_meta_items): super()._remove_hidden_meta_diffs(dataset, ds_meta_items, wc_meta_items) - # Nowhere to put these in postgis WC + # Nowhere to put these in SQL Server WC for key in self._UNSUPPORTED_META_ITEMS: if key in ds_meta_items: del ds_meta_items[key] @@ -447,15 +382,13 @@ class InstanceFunction(Function): >>> function(element) """ - pass - @compiles(InstanceFunction) def compile_instance_function(element, compiler, **kw): return "(%s).%s()" % (element.clauses, element.name) -class GeographyType(UserDefinedType): +class GeometryType(UserDefinedType): """UserDefinedType so that V2 geometry is adapted to MS binary format.""" def __init__(self, crs_id): @@ -468,7 +401,7 @@ def bind_processor(self, dialect): def bind_expression(self, bindvalue): # 2. Writing - SQL layer - wrap in call to STGeomFromWKB to convert WKB to MS binary. return Function( - quoted_name("geography::STGeomFromWKB", False), + quoted_name("geometry::STGeomFromWKB", False), bindvalue, self.crs_id, type_=self, @@ -484,10 +417,14 @@ def result_processor(self, dialect, coltype): class BaseDateOrTimeType(UserDefinedType): - """UserDefinedType so we read dates, times, and datetimes as text.""" + """ + UserDefinedType so we read dates, times, and datetimes as text. + They are stored as date / time / datetime in SQL Server, but read back out as text. + """ def column_expression(self, col): # When reading, convert dates and times to strings using style 127: ISO8601 with time zone Z. + # https://docs.microsoft.com/en-us/sql/t-sql/functions/cast-and-convert-transact-sql return Function( "CONVERT", literal_column("NVARCHAR"), diff --git a/sno/working_copy/sqlserver_adapter.py b/sno/working_copy/sqlserver_adapter.py index de64fe778..86f10b329 100644 --- a/sno/working_copy/sqlserver_adapter.py +++ b/sno/working_copy/sqlserver_adapter.py @@ -1,4 +1,3 @@ -from sno import crs_util from sno.schema import Schema, ColumnSchema from sqlalchemy.sql.compiler import IdentifierPreparer @@ -17,7 +16,7 @@ def quote(ident): "blob": "varbinary", "date": "date", "float": {0: "real", 32: "real", 64: "float"}, - "geometry": "geography", + "geometry": "geometry", "integer": { 0: "int", 8: "tinyint", @@ -25,7 +24,7 @@ def quote(ident): 32: "int", 64: "bigint", }, - "interval": "nvarchar", + "interval": "text", "numeric": "numeric", "text": "nvarchar", "time": "time", @@ -40,6 +39,8 @@ def quote(ident): "bigint": ("integer", 64), "real": ("float", 32), "float": ("float", 64), + "binary": "blob", + "char": "text", "date": "date", "datetime": "timestamp", "datetime2": "timestamp", @@ -47,9 +48,10 @@ def quote(ident): "decimal": "numeric", "geography": "geometry", "geometry": "geometry", - "interval": "interval", + "nchar": "text", "numeric": "numeric", "nvarchar": "text", + "ntext": "text", "text": "text", "time": "time", "varchar": "text", @@ -91,13 +93,6 @@ def v2_type_to_ms_type(column_schema, v2_obj): return ms_type_info.get(extra_type_info.get("size", 0)) ms_type = ms_type_info - if ms_type == "geometry": - geometry_type = extra_type_info.get("geometryType") - crs_name = extra_type_info.get("geometryCRS") - crs_id = None - if crs_name is not None: - crs_id = crs_util.get_identifier_int_from_dataset(v2_obj, crs_name) - return _v2_geometry_type_to_ms_type(geometry_type, crs_id) if ms_type in ("varchar", "nvarchar", "varbinary"): length = extra_type_info.get("length", None) @@ -116,33 +111,19 @@ def v2_type_to_ms_type(column_schema, v2_obj): return ms_type -def _v2_geometry_type_to_ms_type(geometry_type, crs_id): - if geometry_type is not None: - geometry_type = geometry_type.replace(" ", "") - - if geometry_type is not None and crs_id is not None: - return f"geometry({geometry_type},{crs_id})" - elif geometry_type is not None: - return f"geometry({geometry_type})" - else: - return "geometry" - - def sqlserver_to_v2_schema(ms_table_info, id_salt): - """Generate a V2 schema from the given postgis metadata tables.""" + """Generate a V2 schema from the given SQL server metadata.""" return Schema([_sqlserver_to_column_schema(col, id_salt) for col in ms_table_info]) def _sqlserver_to_column_schema(ms_col_info, id_salt): """ - Given the postgis column info for a particular column, and some extra context in - case it is a geometry column, converts it to a ColumnSchema. The extra context will - only be used if the given ms_col_info is the geometry column. + Given the MS column info for a particular column, converts it to a ColumnSchema. + Parameters: ms_col_info - info about a single column from ms_table_info. - ms_spatial_ref_sys - rows of the "spatial_ref_sys" table that are referenced by this dataset. id_salt - the UUIDs of the generated ColumnSchema are deterministic and depend on - the name and type of the column, and on this salt. + the name and type of the column, and on this salt. """ name = ms_col_info["column_name"] pk_index = ms_col_info["pk_ordinal_position"] diff --git a/sno/working_copy/table_defs.py b/sno/working_copy/table_defs.py index 5f8fbd2dc..ef500911c 100644 --- a/sno/working_copy/table_defs.py +++ b/sno/working_copy/table_defs.py @@ -104,9 +104,11 @@ def create_all(self, session): class SqlServerSnoTables(TableSet): """ - Tables for sno-specific metadata - PostGIS variant. + Tables for sno-specific metadata - SQL Server variant. Table names have a user-defined schema, and so unlike other table sets, we need to construct an instance with the appropriate schema. + Primary keys have to be NVARCHAR of a fixed maximum length - + if the total maximum length is too long, SQL Server cannot generate an index. """ def __init__(self, schema=None): diff --git a/tests/conftest.py b/tests/conftest.py index 014dbda6c..47cdf6deb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -930,8 +930,8 @@ def ctx(create=False): def sqlserver_db(): """ Using docker, you can run a SQL Server test - such as those in test_working_copy_sqlserver - as follows: - docker run -it --rm -d -p 11433:1433 -e ACCEPT_EULA=Y -e 'SA_PASSWORD=Sql(server)' mcr.microsoft.com/mssql/server - SNO_SQLSERVER_URL='mssql://sa:Sql(server)@127.0.0.1:11433/master' spytest -k sqlserver --pdb -vvs + docker run -it --rm -d -p 11433:1433 -e ACCEPT_EULA=Y -e 'SA_PASSWORD=PassWord1' mcr.microsoft.com/mssql/server + SNO_SQLSERVER_URL='mssql://sa:PassWord1@127.0.0.1:11433/master' pytest -k sqlserver --pdb -vvs """ if "SNO_SQLSERVER_URL" not in os.environ: raise pytest.skip( @@ -939,7 +939,7 @@ def sqlserver_db(): ) engine = sqlserver_engine(os.environ["SNO_SQLSERVER_URL"]) with engine.connect() as conn: - # test connection and postgis support + # Test connection try: conn.execute("SELECT @@version;") except sqlalchemy.exc.DBAPIError: diff --git a/tests/test_working_copy_sqlserver.py b/tests/test_working_copy_sqlserver.py index 2853f2263..aa5e5f6a0 100644 --- a/tests/test_working_copy_sqlserver.py +++ b/tests/test_working_copy_sqlserver.py @@ -308,6 +308,6 @@ def test_types_roundtrip(data_archive, cli_runner, new_sqlserver_db_schema): r = cli_runner.invoke(["checkout"]) # If type-approximation roundtrip code isn't working, - # we would get spurious diffs on types that PostGIS doesn't support. + # we would get spurious diffs on types that SQL server doesn't support. r = cli_runner.invoke(["diff", "--exit-code"]) assert r.exit_code == 0, r.stdout