From d0c07d063efaa27335c73ac10e23bfecbdc576c1 Mon Sep 17 00:00:00 2001 From: Andrew Olsen Date: Wed, 17 Feb 2021 09:06:30 +1300 Subject: [PATCH 1/2] Add support for SQL Server working copy --- sno/checkout.py | 2 +- sno/geometry.py | 15 +- sno/sqlalchemy.py | 49 +++ sno/working_copy/base.py | 116 +++--- sno/working_copy/gpkg.py | 42 +- sno/working_copy/gpkg_adapter.py | 2 +- sno/working_copy/postgis.py | 27 +- sno/working_copy/postgis_adapter.py | 2 +- sno/working_copy/sqlserver.py | 553 ++++++++++++++++++++++++++ sno/working_copy/sqlserver_adapter.py | 179 +++++++++ sno/working_copy/table_defs.py | 32 ++ tests/conftest.py | 140 +++++-- tests/test_working_copy_postgis.py | 6 +- tests/test_working_copy_sqlserver.py | 313 +++++++++++++++ 14 files changed, 1364 insertions(+), 114 deletions(-) create mode 100644 sno/working_copy/sqlserver.py create mode 100644 sno/working_copy/sqlserver_adapter.py create mode 100644 tests/test_working_copy_sqlserver.py diff --git a/sno/checkout.py b/sno/checkout.py index 2031fc6c2..cf3f80ace 100644 --- a/sno/checkout.py +++ b/sno/checkout.py @@ -20,7 +20,7 @@ def reset_wc_if_needed(repo, target_tree_or_commit, *, discard_changes=False): """Resets the working copy to the target if it does not already match, or if discard_changes is True.""" - working_copy = WorkingCopy.get(repo, allow_uncreated=True) + working_copy = WorkingCopy.get(repo, allow_uncreated=True, allow_invalid_state=True) if working_copy is None: click.echo( "(Bare sno repository - to create a working copy, use `sno create-workingcopy`)" diff --git a/sno/geometry.py b/sno/geometry.py index 6caf3d860..ca42f0d75 100644 --- a/sno/geometry.py +++ b/sno/geometry.py @@ -57,6 +57,11 @@ def with_crs_id(self, crs_id): crs_id_bytes = struct.pack(" 0 + + def is_initialised(self): + """ + Returns true if the postgis working copy is initialised - + the schema exists and has the necessary sno tables, _sno_state and _sno_track. 
+ """ + with self.session() as sess: + count = sess.scalar( + f""" + SELECT COUNT(*) FROM sys.tables + WHERE schema_id = SCHEMA_ID(:schema_name) + AND name IN ('{self.SNO_STATE_NAME}', '{self.SNO_TRACK_NAME}'); + """, + {"schema_name": self.db_schema}, + ) + return count == 2 + + def has_data(self): + """ + Returns true if the postgis working copy seems to have user-created content already. + """ + with self.session() as sess: + count = sess.scalar( + f""" + SELECT COUNT(*) FROM sys.tables + WHERE schema_id = SCHEMA_ID(:schema_name) + AND name NOT IN ('{self.SNO_STATE_NAME}', '{self.SNO_TRACK_NAME}'); + """, + {"schema_name": self.db_schema}, + ) + return count > 0 + + def create_and_initialise(self): + with self.session() as sess: + if not self.is_created(): + sess.execute(f"CREATE SCHEMA {self.DB_SCHEMA};") + + with self.session() as sess: + self.sno_tables.create_all(sess) + + def delete(self, keep_db_schema_if_possible=False): + """Delete all tables in the schema.""" + with self.session() as sess: + # Drop tables + r = sess.execute( + "SELECT name FROM sys.tables WHERE schema_id=SCHEMA_ID(:schema);", + {"schema": self.db_schema}, + ) + table_identifiers = ", ".join((self.table_identifier(row[0]) for row in r)) + if table_identifiers: + sess.execute(f"DROP TABLE IF EXISTS {table_identifiers};") + + # Drop schema, unless keep_db_schema_if_possible=True + if not keep_db_schema_if_possible: + sess.execute( + f"DROP SCHEMA IF EXISTS {self.DB_SCHEMA};", + ) + + def _create_table_for_dataset(self, sess, dataset): + table_spec = sqlserver_adapter.v2_schema_to_sqlserver_spec( + dataset.schema, dataset + ) + sess.execute( + f"""CREATE TABLE {self.table_identifier(dataset)} ({table_spec});""" + ) + + def _table_def_for_column_schema(self, col, dataset): + if col.data_type == "geometry": + # This user-defined Geography type adapts WKB to SQL Server's native geography type. 
+ crs_name = col.extra_type_info.get("geometryCRS", None) + crs_id = crs_util.get_identifier_int_from_dataset(dataset, crs_name) or 0 + return sqlalchemy.column(col.name, GeographyType(crs_id)) + elif col.data_type in ("date", "time", "timestamp"): + return sqlalchemy.column(col.name, BaseDateOrTimeType) + else: + # Don't need to specify type information for other columns at present, since we just pass through the values. + return sqlalchemy.column(col.name) + + def _insert_or_replace_into_dataset(self, dataset): + pk_col_names = [c.name for c in dataset.schema.pk_columns] + non_pk_col_names = [ + c.name for c in dataset.schema.columns if c.pk_index is None + ] + return sqlserver_upsert( + self._table_def_for_dataset(dataset), + index_elements=pk_col_names, + set_=non_pk_col_names, + ) + + def _insert_or_replace_state_table_tree(self, sess, tree_id): + r = sess.execute( + f""" + MERGE {self.SNO_STATE} STA + USING (VALUES ('*', 'tree', :value)) AS SRC("table_name", "key", "value") + ON SRC."table_name" = STA."table_name" AND SRC."key" = STA."key" + WHEN MATCHED THEN + UPDATE SET "value" = SRC."value" + WHEN NOT MATCHED THEN + INSERT ("table_name", "key", "value") VALUES (SRC."table_name", SRC."key", SRC."value"); + """, + {"value": tree_id}, + ) + return r.rowcount + + def _write_meta(self, sess, dataset): + """Write the title (as a comment) and the CRS. Other metadata is not stored in a PostGIS WC.""" + self._write_meta_title(sess, dataset) + + def _write_meta_title(self, sess, dataset): + """Write the dataset title as a comment on the table.""" + # TODO - probably need to use sp_addextendedproperty @name=N'MS_Description' + pass + + def delete_meta(self, dataset): + """Delete any metadata that is only needed by this dataset.""" + pass # There is no metadata except for the spatial_ref_sys table. 
+ + def _create_spatial_index(self, sess, dataset): + L = logging.getLogger(f"{self.__class__.__qualname__}._create_spatial_index") + + geom_col = dataset.geom_column_name + + # Create the SQL Server Spatial Index + L.debug("Creating spatial index for %s.%s", dataset.table_name, geom_col) + t0 = time.monotonic() + index_name = f"{dataset.table_name}_idx_{geom_col}" + sess.execute( + f""" + CREATE SPATIAL INDEX {self.quote(index_name)} + ON {self.table_identifier(dataset)} ({self.quote(geom_col)}); + """ + ) + L.info("Created spatial index in %ss", time.monotonic() - t0) + + def _drop_spatial_index(self, sess, dataset): + # SQL server deletes the spatial index automatically when the table is deleted. + pass + + def _quoted_trigger_name(self, dataset): + trigger_name = f"{dataset.table_name}_sno_track" + return f"{self.DB_SCHEMA}.{self.quote(trigger_name)}" + + def _create_triggers(self, sess, dataset): + pk_name = dataset.primary_key + escaped_table_name = dataset.table_name.replace("'", "''") + + sess.execute( + f""" + CREATE TRIGGER {self._quoted_trigger_name(dataset)} ON {self.table_identifier(dataset)} + AFTER INSERT, UPDATE, DELETE AS + BEGIN + MERGE {self.SNO_TRACK} TRA + USING + (SELECT '{escaped_table_name}', {self.quote(pk_name)} FROM inserted + UNION SELECT '{escaped_table_name}', {self.quote(pk_name)} FROM deleted) + AS SRC (table_name, pk) + ON SRC.table_name = TRA.table_name AND SRC.pk = TRA.pk + WHEN NOT MATCHED THEN INSERT (table_name, pk) VALUES (SRC.table_name, SRC.pk); + END; + """, + {"table_name": dataset.table_name}, + ) + + @contextlib.contextmanager + def _suspend_triggers(self, sess, dataset): + sess.execute( + f"""DISABLE TRIGGER {self._quoted_trigger_name(dataset)} ON {self.table_identifier(dataset)};""" + ) + yield + sess.execute( + f"""ENABLE TRIGGER {self._quoted_trigger_name(dataset)} ON {self.table_identifier(dataset)};""" + ) + + def meta_items(self, dataset): + with self.session() as sess: + table_info_sql = """ + SELECT + 
C.column_name, C.ordinal_position, C.data_type, C.udt_name, + C.character_maximum_length, C.numeric_precision, C.numeric_scale, + KCU.ordinal_position AS pk_ordinal_position, + upper(postgis_typmod_type(A.atttypmod)) AS geometry_type, + postgis_typmod_srid(A.atttypmod) AS geometry_srid + FROM information_schema.columns C + LEFT OUTER JOIN information_schema.key_column_usage KCU + ON (KCU.table_schema = C.table_schema) + AND (KCU.table_name = C.table_name) + AND (KCU.column_name = C.column_name) + LEFT OUTER JOIN pg_attribute A + ON (A.attname = C.column_name) + AND (A.attrelid = (C.table_schema || '.' || C.table_name)::regclass::oid) + WHERE C.table_schema=:table_schema AND C.table_name=:table_name + ORDER BY C.ordinal_position; + """ + table_info_sql = """ + SELECT + C.column_name, C.ordinal_position, C.data_type, + C.character_maximum_length, C.numeric_precision, C.numeric_scale, + KCU.ordinal_position AS pk_ordinal_position + FROM information_schema.columns C + LEFT OUTER JOIN information_schema.key_column_usage KCU + ON (KCU.table_schema = C.table_schema) + AND (KCU.table_name = C.table_name) + AND (KCU.column_name = C.column_name) + WHERE C.table_schema=:table_schema AND C.table_name=:table_name + ORDER BY C.ordinal_position; + """ + r = sess.execute( + table_info_sql, + {"table_schema": self.db_schema, "table_name": dataset.table_name}, + ) + ms_table_info = list(r) + + id_salt = f"{self.db_schema} {dataset.table_name} {self.get_db_tree()}" + schema = sqlserver_adapter.sqlserver_to_v2_schema(ms_table_info, id_salt) + yield "schema.json", schema.to_column_dicts() + + _UNSUPPORTED_META_ITEMS = ( + "title", + "description", + "metadata/dataset.json", + "metadata.xml", + ) + + @classmethod + def try_align_schema_col(cls, old_col_dict, new_col_dict): + old_type = old_col_dict["dataType"] + new_type = new_col_dict["dataType"] + + # Some types have to be approximated as other types in SQL Server. 
+ if sqlserver_adapter.APPROXIMATED_TYPES.get(old_type) == new_type: + new_col_dict["dataType"] = new_type = old_type + + # Geometry type loses its extra type info when roundtripped through SQL Server. + if new_type == "geometry": + new_col_dict["geometryType"] = old_col_dict.get("geometryType") + new_col_dict["geometryCRS"] = old_col_dict.get("geometryCRS") + + return new_type == old_type + + def _remove_hidden_meta_diffs(self, dataset, ds_meta_items, wc_meta_items): + super()._remove_hidden_meta_diffs(dataset, ds_meta_items, wc_meta_items) + + # Nowhere to put these in postgis WC + for key in self._UNSUPPORTED_META_ITEMS: + if key in ds_meta_items: + del ds_meta_items[key] + + # Diffing CRS is not yet supported. + for key in list(ds_meta_items.keys()): + if key.startswith("crs/"): + del ds_meta_items[key] + + def _is_meta_update_supported(self, dataset_version, meta_diff): + """ + Returns True if the given meta-diff is supported *without* dropping and rewriting the table. + (Any meta change is supported if we drop and rewrite the table, but of course it is less efficient). + meta_diff - DeltaDiff object containing the meta changes. + """ + # For now, just always drop and rewrite. + return not meta_diff + + +class InstanceFunction(Function): + """ + An instance function that compiles like this when applied to an element: + >>> element.function() + Unlike a normal sqlalchemy function which would compile as follows: + >>> function(element) + """ + + pass + + +@compiles(InstanceFunction) +def compile_instance_function(element, compiler, **kw): + return "(%s).%s()" % (element.clauses, element.name) + + +class GeographyType(UserDefinedType): + """UserDefinedType so that V2 geometry is adapted to MS binary format.""" + + def __init__(self, crs_id): + self.crs_id = crs_id + + def bind_processor(self, dialect): + # 1. Writing - Python layer - convert sno geometry to WKB + return lambda geom: geom.to_wkb() + + def bind_expression(self, bindvalue): + # 2. 
Writing - SQL layer - wrap in call to STGeomFromWKB to convert WKB to MS binary. + return Function( + quoted_name("geography::STGeomFromWKB", False), + bindvalue, + self.crs_id, + type_=self, + ) + + def column_expression(self, col): + # 3. Reading - SQL layer - append with call to .STAsBinary() to convert MS binary to WKB. + return InstanceFunction("STAsBinary", col, type_=self) + + def result_processor(self, dialect, coltype): + # 4. Reading - Python layer - convert WKB to sno geometry. + return lambda wkb: Geometry.from_wkb(wkb) + + +class BaseDateOrTimeType(UserDefinedType): + """UserDefinedType so we read dates, times, and datetimes as text.""" + + def column_expression(self, col): + # When reading, convert dates and times to strings using style 127: ISO8601 with time zone Z. + return Function( + "CONVERT", + literal_column("NVARCHAR"), + col, + literal_column("127"), + type_=self, + ) + + +def sqlserver_upsert(*args, **kwargs): + return Upsert(*args, **kwargs) + + +class Upsert(ValuesBase): + """A SQL server custom upsert command that compiles to a merge statement.""" + + def __init__( + self, + table, + values=None, + prefixes=None, + index_elements=None, + set_=None, + **dialect_kw, + ): + ValuesBase.__init__(self, table, values, prefixes) + self._validate_dialect_kwargs(dialect_kw) + self.index_elements = index_elements + self.set_ = set_ + self.select = self.select_names = None + self._returning = None + + +@compiles(Upsert) +def compile_upsert(upsert_stmt, compiler, **kw): + preparer = compiler.preparer + + def list_cols(col_names, prefix=""): + return ", ".join([prefix + c for c in col_names]) + + crud_params = crud._setup_crud_params(compiler, upsert_stmt, crud.ISINSERT, **kw) + crud_values = ", ".join([c[1] for c in crud_params]) + + table = preparer.format_table(upsert_stmt.table) + all_columns = [preparer.quote(c[0].name) for c in crud_params] + index_elements = [preparer.quote(c) for c in upsert_stmt.index_elements] + set_ = [preparer.quote(c) for 
c in upsert_stmt.set_] + + result = f"MERGE {table} TARGET" + result += f" USING (VALUES ({crud_values})) AS SOURCE ({list_cols(all_columns)})" + + result += " ON " + result += " AND ".join([f"SOURCE.{c} = TARGET.{c}" for c in index_elements]) + + result += " WHEN MATCHED THEN UPDATE SET " + result += ", ".join([f"{c} = SOURCE.{c}" for c in set_]) + + result += " WHEN NOT MATCHED THEN INSERT " + result += ( + f"({list_cols(all_columns)}) VALUES ({list_cols(all_columns, 'SOURCE.')});" + ) + + return result diff --git a/sno/working_copy/sqlserver_adapter.py b/sno/working_copy/sqlserver_adapter.py new file mode 100644 index 000000000..de64fe778 --- /dev/null +++ b/sno/working_copy/sqlserver_adapter.py @@ -0,0 +1,179 @@ +from sno import crs_util +from sno.schema import Schema, ColumnSchema + +from sqlalchemy.sql.compiler import IdentifierPreparer +from sqlalchemy.dialects.mssql.base import MSDialect + + +_PREPARER = IdentifierPreparer(MSDialect()) + + +def quote(ident): + return _PREPARER.quote(ident) + + +V2_TYPE_TO_MS_TYPE = { + "boolean": "bit", + "blob": "varbinary", + "date": "date", + "float": {0: "real", 32: "real", 64: "float"}, + "geometry": "geography", + "integer": { + 0: "int", + 8: "tinyint", + 16: "smallint", + 32: "int", + 64: "bigint", + }, + "interval": "nvarchar", + "numeric": "numeric", + "text": "nvarchar", + "time": "time", + "timestamp": "datetimeoffset", +} + +MS_TYPE_TO_V2_TYPE = { + "bit": "boolean", + "tinyint": ("integer", 8), + "smallint": ("integer", 16), + "int": ("integer", 32), + "bigint": ("integer", 64), + "real": ("float", 32), + "float": ("float", 64), + "date": "date", + "datetime": "timestamp", + "datetime2": "timestamp", + "datetimeoffset": "timestamp", + "decimal": "numeric", + "geography": "geometry", + "geometry": "geometry", + "interval": "interval", + "numeric": "numeric", + "nvarchar": "text", + "text": "text", + "time": "time", + "varchar": "text", + "varbinary": "blob", +} + +# Types that can't be roundtripped perfectly in 
SQL Server, and what they end up as. +APPROXIMATED_TYPES = {"interval": "text"} +# Note that although this means that all other V2 types above can be roundtripped, it +# doesn't mean that extra type info is always preserved. Specifically, extra +# geometry type info - the geometry type and CRS - is not roundtripped. + + +def v2_schema_to_sqlserver_spec(schema, v2_obj): + """ + Generate the SQL CREATE TABLE spec from a V2 object eg: + 'fid INTEGER, geom GEOMETRY(POINT,2136), desc VARCHAR(128), PRIMARY KEY(fid)' + """ + result = [f"{quote(col.name)} {v2_type_to_ms_type(col, v2_obj)}" for col in schema] + + if schema.pk_columns: + pk_col_names = ", ".join((quote(col.name) for col in schema.pk_columns)) + result.append(f"PRIMARY KEY({pk_col_names})") + + return ", ".join(result) + + +def v2_type_to_ms_type(column_schema, v2_obj): + """Convert a v2 schema type to a SQL server type.""" + + v2_type = column_schema.data_type + extra_type_info = column_schema.extra_type_info + + ms_type_info = V2_TYPE_TO_MS_TYPE.get(v2_type) + if ms_type_info is None: + raise ValueError(f"Unrecognised data type: {v2_type}") + + if isinstance(ms_type_info, dict): + return ms_type_info.get(extra_type_info.get("size", 0)) + + ms_type = ms_type_info + if ms_type == "geometry": + geometry_type = extra_type_info.get("geometryType") + crs_name = extra_type_info.get("geometryCRS") + crs_id = None + if crs_name is not None: + crs_id = crs_util.get_identifier_int_from_dataset(v2_obj, crs_name) + return _v2_geometry_type_to_ms_type(geometry_type, crs_id) + + if ms_type in ("varchar", "nvarchar", "varbinary"): + length = extra_type_info.get("length", None) + return f"{ms_type}({length})" if length is not None else f"{ms_type}(max)" + + if ms_type == "numeric": + precision = extra_type_info.get("precision", None) + scale = extra_type_info.get("scale", None) + if precision is not None and scale is not None: + return f"numeric({precision},{scale})" + elif precision is not None: + return 
f"numeric({precision})" + else: + return "numeric" + + return ms_type + + +def _v2_geometry_type_to_ms_type(geometry_type, crs_id): + if geometry_type is not None: + geometry_type = geometry_type.replace(" ", "") + + if geometry_type is not None and crs_id is not None: + return f"geometry({geometry_type},{crs_id})" + elif geometry_type is not None: + return f"geometry({geometry_type})" + else: + return "geometry" + + +def sqlserver_to_v2_schema(ms_table_info, id_salt): + """Generate a V2 schema from the given postgis metadata tables.""" + return Schema([_sqlserver_to_column_schema(col, id_salt) for col in ms_table_info]) + + +def _sqlserver_to_column_schema(ms_col_info, id_salt): + """ + Given the postgis column info for a particular column, and some extra context in + case it is a geometry column, converts it to a ColumnSchema. The extra context will + only be used if the given ms_col_info is the geometry column. + Parameters: + ms_col_info - info about a single column from ms_table_info. + ms_spatial_ref_sys - rows of the "spatial_ref_sys" table that are referenced by this dataset. + id_salt - the UUIDs of the generated ColumnSchema are deterministic and depend on + the name and type of the column, and on this salt. 
+ """ + name = ms_col_info["column_name"] + pk_index = ms_col_info["pk_ordinal_position"] + if pk_index is not None: + pk_index -= 1 + data_type, extra_type_info = _ms_type_to_v2_type(ms_col_info) + + col_id = ColumnSchema.deterministic_id(name, data_type, id_salt) + return ColumnSchema(col_id, name, data_type, pk_index, **extra_type_info) + + +def _ms_type_to_v2_type(ms_col_info): + v2_type_info = MS_TYPE_TO_V2_TYPE.get(ms_col_info["data_type"]) + + if isinstance(v2_type_info, tuple): + v2_type = v2_type_info[0] + extra_type_info = {"size": v2_type_info[1]} + else: + v2_type = v2_type_info + extra_type_info = {} + + if v2_type == "geometry": + return v2_type, extra_type_info + + if v2_type == "text": + length = ms_col_info["character_maximum_length"] or None + if length is not None: + extra_type_info["length"] = length + + if v2_type == "numeric": + extra_type_info["precision"] = ms_col_info["numeric_precision"] or None + extra_type_info["scale"] = ms_col_info["numeric_scale"] or None + + return v2_type, extra_type_info diff --git a/sno/working_copy/table_defs.py b/sno/working_copy/table_defs.py index 6488985d8..5f8fbd2dc 100644 --- a/sno/working_copy/table_defs.py +++ b/sno/working_copy/table_defs.py @@ -9,6 +9,8 @@ UniqueConstraint, ) +from sqlalchemy.types import NVARCHAR + class TinyInt(Integer): __visit_name__ = "TINYINT" @@ -100,6 +102,36 @@ def create_all(self, session): return self._SQLALCHEMY_METADATA.create_all(session.connection()) +class SqlServerSnoTables(TableSet): + """ + Tables for sno-specific metadata - PostGIS variant. + Table names have a user-defined schema, and so unlike other table sets, + we need to construct an instance with the appropriate schema. 
+ """ + + def __init__(self, schema=None): + self._SQLALCHEMY_METADATA = MetaData() + + self.sno_state = Table( + "_sno_state", + self._SQLALCHEMY_METADATA, + Column("table_name", NVARCHAR(400), nullable=False, primary_key=True), + Column("key", NVARCHAR(400), nullable=False, primary_key=True), + Column("value", Text, nullable=False), + schema=schema, + ) + self.sno_track = Table( + "_sno_track", + self._SQLALCHEMY_METADATA, + Column("table_name", NVARCHAR(400), nullable=False, primary_key=True), + Column("pk", NVARCHAR(400), nullable=True, primary_key=True), + schema=schema, + ) + + def create_all(self, session): + return self._SQLALCHEMY_METADATA.create_all(session.connection()) + + class GpkgTables(TableSet): """GPKG spec tables - see http://www.geopackage.org/spec/#table_definition_sql""" diff --git a/tests/conftest.py b/tests/conftest.py index d5684481a..014dbda6c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -24,7 +24,7 @@ from sno.geometry import Geometry from sno.repo import SnoRepo -from sno.sqlalchemy import gpkg_engine, postgis_engine +from sno.sqlalchemy import gpkg_engine, postgis_engine, sqlserver_engine from sno.working_copy import WorkingCopy @@ -739,24 +739,29 @@ def func(conn, pk, update_str, layer=None, commit=True): return func -def _insert_command(table_name, col_names, schema=None): +def _insert_command(table_name, col_names): return sqlalchemy.table( - table_name, - *[sqlalchemy.column(c) for c in col_names], - schema=schema, + table_name, *[sqlalchemy.column(c) for c in col_names] ).insert() -def _edit_points(conn, schema=None): +def _edit_points(conn, dataset=None, working_copy=None): H = pytest.helpers.helpers() - layer = f'"{schema}"."{H.POINTS.LAYER}"' if schema else H.POINTS.LAYER - r = conn.execute( - _insert_command(H.POINTS.LAYER, H.POINTS.RECORD.keys(), schema=schema), - H.POINTS.RECORD, - ) - assert r.rowcount == 1 + + if working_copy is None: + layer = H.POINTS.LAYER + insert_cmd = _insert_command(H.POINTS.LAYER, 
H.POINTS.RECORD.keys()) + else: + layer = f"{working_copy.DB_SCHEMA}.{H.POINTS.LAYER}" + insert_cmd = working_copy._insert_into_dataset(dataset) + + # Note - different DB backends support and interpret rowcount differently. + # Sometimes rowcount is not supported for inserts, so it just returns -1. + # Rowcount can be 1 or 2 if 1 row has changed its PK + r = conn.execute(insert_cmd, H.POINTS.RECORD) + assert r.rowcount in (1, -1) r = conn.execute(f"UPDATE {layer} SET fid=9998 WHERE fid=1;") - assert r.rowcount == 1 + assert r.rowcount in (1, 2) r = conn.execute(f"UPDATE {layer} SET name='test' WHERE fid=2;") assert r.rowcount == 1 r = conn.execute(f"DELETE FROM {layer} WHERE fid IN (3,30,31,32,33);") @@ -770,16 +775,20 @@ def edit_points(): return _edit_points -def _edit_polygons(conn, schema=None): +def _edit_polygons(conn, dataset=None, working_copy=None): H = pytest.helpers.helpers() - layer = f'"{schema}"."{H.POLYGONS.LAYER}"' if schema else H.POLYGONS.LAYER - r = conn.execute( - _insert_command(H.POLYGONS.LAYER, H.POLYGONS.RECORD.keys(), schema=schema), - H.POLYGONS.RECORD, - ) - assert r.rowcount == 1 + if working_copy is None: + layer = H.POLYGONS.LAYER + insert_cmd = _insert_command(H.POLYGONS.LAYER, H.POLYGONS.RECORD.keys()) + else: + layer = f"{working_copy.DB_SCHEMA}.{H.POLYGONS.LAYER}" + insert_cmd = working_copy._insert_into_dataset(dataset) + + # See note on rowcount at _edit_points + r = conn.execute(insert_cmd, H.POLYGONS.RECORD) + assert r.rowcount in (1, -1) r = conn.execute(f"UPDATE {layer} SET id=9998 WHERE id=1424927;") - assert r.rowcount == 1 + assert r.rowcount in (1, 2) r = conn.execute(f"UPDATE {layer} SET survey_reference='test' WHERE id=1443053;") assert r.rowcount == 1 r = conn.execute( @@ -795,16 +804,21 @@ def edit_polygons(): return _edit_polygons -def _edit_table(conn, schema=None): +def _edit_table(conn, dataset=None, working_copy=None): H = pytest.helpers.helpers() - layer = f'"{schema}"."{H.TABLE.LAYER}"' if schema else 
H.TABLE.LAYER - r = conn.execute( - _insert_command(H.TABLE.LAYER, H.TABLE.RECORD.keys(), schema=schema), - H.TABLE.RECORD, - ) - assert r.rowcount == 1 + + if working_copy is None: + layer = H.TABLE.LAYER + insert_cmd = _insert_command(H.TABLE.LAYER, H.TABLE.RECORD.keys()) + else: + layer = f"{working_copy.DB_SCHEMA}.{H.TABLE.LAYER}" + insert_cmd = working_copy._insert_into_dataset(dataset) + + r = conn.execute(insert_cmd, H.TABLE.RECORD) + # rowcount is not actually supported for inserts, but works in certain DB types - otherwise is -1. + assert r.rowcount in (1, -1) r = conn.execute(f"""UPDATE {layer} SET "OBJECTID"=9998 WHERE "OBJECTID"=1;""") - assert r.rowcount == 1 + assert r.rowcount in (1, 2) r = conn.execute(f"""UPDATE {layer} SET "NAME"='test' WHERE "OBJECTID"=2;""") assert r.rowcount == 1 r = conn.execute(f"""DELETE FROM {layer} WHERE "OBJECTID" IN (3,30,31,32,33);""") @@ -869,14 +883,12 @@ def disable_editor(): @pytest.fixture() def postgis_db(): """ - Using docker, you can run a PostGres test - such as test_postgis_import - as follows: + Using docker, you can run a PostGIS test - such as test_postgis_import - as follows: docker run -it --rm -d -p 15432:5432 -e POSTGRES_HOST_AUTH_METHOD=trust kartoza/postgis SNO_POSTGRES_URL='postgresql://docker:docker@localhost:15432/gis' pytest -k postgis --pdb -vvs """ if "SNO_POSTGRES_URL" not in os.environ: - raise pytest.skip( - "Requires postgres - read docstring at sno.test_structure.postgis_db" - ) + raise pytest.skip("Requires PostGIS - read docstring at conftest.postgis_db") engine = postgis_engine(os.environ["SNO_POSTGRES_URL"]) with engine.connect() as conn: # test connection and postgis support @@ -912,3 +924,65 @@ def ctx(create=False): conn.execute(f"""DROP SCHEMA IF EXISTS "{schema}" CASCADE;""") return ctx + + +@pytest.fixture() +def sqlserver_db(): + """ + Using docker, you can run a SQL Server test - such as those in test_working_copy_sqlserver - as follows: + docker run -it --rm -d -p 11433:1433 -e 
ACCEPT_EULA=Y -e 'SA_PASSWORD=Sql(server)' mcr.microsoft.com/mssql/server + SNO_SQLSERVER_URL='mssql://sa:Sql(server)@127.0.0.1:11433/master' spytest -k sqlserver --pdb -vvs + """ + if "SNO_SQLSERVER_URL" not in os.environ: + raise pytest.skip( + "Requires SQL Server - read docstring at conftest.sqlserver_db" + ) + engine = sqlserver_engine(os.environ["SNO_SQLSERVER_URL"]) + with engine.connect() as conn: + # test connection and postgis support + try: + conn.execute("SELECT @@version;") + except sqlalchemy.exc.DBAPIError: + raise pytest.skip("Requires SQL Server") + yield engine + + +@pytest.fixture() +def new_sqlserver_db_schema(request, sqlserver_db): + @contextlib.contextmanager + def ctx(create=False): + sha = hashlib.sha1(request.node.nodeid.encode("utf8")).hexdigest()[:20] + schema = f"sno_test_{sha}" + with sqlserver_db.connect() as conn: + # Start by deleting in case it is left over from last test-run... + _sqlserver_drop_schema_cascade(conn, schema) + # Actually create only if create=True, otherwise the test will create it + if create: + conn.execute(f"""CREATE SCHEMA "{schema}";""") + try: + url = urlsplit(os.environ["SNO_SQLSERVER_URL"]) + url_path = url.path.rstrip("/") + "/" + schema + new_schema_url = urlunsplit( + [url.scheme, url.netloc, url_path, url.query, ""] + ) + yield new_schema_url, schema + finally: + # Clean up - delete it again if it exists. 
+ with sqlserver_db.connect() as conn: + _sqlserver_drop_schema_cascade(conn, schema) + + return ctx + + +def _sqlserver_drop_schema_cascade(conn, db_schema): + r = conn.execute( + sqlalchemy.text( + "SELECT name FROM sys.tables WHERE schema_id=SCHEMA_ID(:schema);" + ), + {"schema": db_schema}, + ) + table_identifiers = ", ".join([f"{db_schema}.{row[0]}" for row in r]) + if table_identifiers: + conn.execute(f"DROP TABLE IF EXISTS {table_identifiers};") + + conn.execute(f"DROP SCHEMA IF EXISTS {db_schema};") diff --git a/tests/test_working_copy_postgis.py b/tests/test_working_copy_postgis.py index cfbaae24a..0b09efcd8 100644 --- a/tests/test_working_copy_postgis.py +++ b/tests/test_working_copy_postgis.py @@ -151,11 +151,11 @@ def test_commit_edits( with wc.session() as sess: if archive == "points": - edit_points(sess, postgres_schema) + edit_points(sess, repo.datasets()[H.POINTS.LAYER], wc) elif archive == "polygons": - edit_polygons(sess, postgres_schema) + edit_polygons(sess, repo.datasets()[H.POLYGONS.LAYER], wc) elif archive == "table": - edit_table(sess, postgres_schema) + edit_table(sess, repo.datasets()[H.TABLE.LAYER], wc) r = cli_runner.invoke(["status"]) assert r.exit_code == 0, r.stderr diff --git a/tests/test_working_copy_sqlserver.py b/tests/test_working_copy_sqlserver.py new file mode 100644 index 000000000..2853f2263 --- /dev/null +++ b/tests/test_working_copy_sqlserver.py @@ -0,0 +1,313 @@ +import pytest + +import pygit2 + +from sno.repo import SnoRepo +from sno.working_copy import sqlserver_adapter +from test_working_copy import compute_approximated_types + + +H = pytest.helpers.helpers() + + +@pytest.mark.parametrize( + "existing_schema", + [ + pytest.param(True, id="existing-schema"), + pytest.param(False, id="brand-new-schema"), + ], +) +@pytest.mark.parametrize( + "archive,table,commit_sha", + [ + pytest.param("points", H.POINTS.LAYER, H.POINTS.HEAD_SHA, id="points"), + pytest.param("polygons", H.POLYGONS.LAYER, H.POLYGONS.HEAD_SHA, 
id="polygons"), + pytest.param("table", H.TABLE.LAYER, H.TABLE.HEAD_SHA, id="table"), + ], +) +def test_checkout_workingcopy( + archive, + table, + commit_sha, + existing_schema, + data_archive, + cli_runner, + new_sqlserver_db_schema, +): + """ Checkout a working copy """ + with data_archive(archive) as repo_path: + repo = SnoRepo(repo_path) + H.clear_working_copy() + + with new_sqlserver_db_schema(create=existing_schema) as ( + sqlserver_url, + sqlserver_schema, + ): + r = cli_runner.invoke(["create-workingcopy", sqlserver_url]) + assert r.exit_code == 0, r.stderr + assert ( + r.stdout.splitlines()[-1] + == f"Creating working copy at {sqlserver_url} ..." + ) + + r = cli_runner.invoke(["status"]) + assert r.exit_code == 0, r.stderr + assert r.stdout.splitlines() == [ + "On branch main", + "", + "Nothing to commit, working copy clean", + ] + + wc = repo.working_copy + assert wc.is_created() + + head_tree_id = repo.head_tree.hex + assert wc.assert_db_tree_match(head_tree_id) + + +@pytest.mark.parametrize( + "existing_schema", + [ + pytest.param(True, id="existing-schema"), + pytest.param(False, id="brand-new-schema"), + ], +) +def test_init_import( + existing_schema, + new_sqlserver_db_schema, + data_archive, + tmp_path, + cli_runner, +): + """ Import the GeoPackage (eg. `kx-foo-layer.gpkg`) into a Sno repository. 
""" + repo_path = tmp_path / "data.sno" + repo_path.mkdir() + + with data_archive("gpkg-points") as data: + with new_sqlserver_db_schema(create=existing_schema) as ( + sqlserver_url, + sqlserver_schema, + ): + r = cli_runner.invoke( + [ + "init", + "--import", + f"gpkg:{data / 'nz-pa-points-topo-150k.gpkg'}", + str(repo_path), + f"--workingcopy-path={sqlserver_url}", + ] + ) + assert r.exit_code == 0, r.stderr + assert (repo_path / ".sno" / "HEAD").exists() + + repo = SnoRepo(repo_path) + wc = repo.working_copy + + assert wc.is_created() + assert wc.is_initialised() + assert wc.has_data() + + assert wc.path == sqlserver_url + + +@pytest.mark.parametrize( + "archive,table,commit_sha", + [ + pytest.param("points", H.POINTS.LAYER, H.POINTS.HEAD_SHA, id="points"), + pytest.param("polygons", H.POLYGONS.LAYER, H.POLYGONS.HEAD_SHA, id="polygons"), + pytest.param("table", H.TABLE.LAYER, H.TABLE.HEAD_SHA, id="table"), + ], +) +def test_commit_edits( + archive, + table, + commit_sha, + data_archive, + cli_runner, + new_sqlserver_db_schema, + edit_points, + edit_polygons, + edit_table, +): + """ Checkout a working copy and make some edits """ + with data_archive(archive) as repo_path: + repo = SnoRepo(repo_path) + H.clear_working_copy() + + with new_sqlserver_db_schema() as (sqlserver_url, sqlserver_schema): + r = cli_runner.invoke(["create-workingcopy", sqlserver_url]) + assert r.exit_code == 0, r.stderr + + r = cli_runner.invoke(["status"]) + assert r.exit_code == 0, r.stderr + assert r.stdout.splitlines() == [ + "On branch main", + "", + "Nothing to commit, working copy clean", + ] + + wc = repo.working_copy + assert wc.is_created() + + with wc.session() as sess: + if archive == "points": + edit_points(sess, repo.datasets()[H.POINTS.LAYER], wc) + elif archive == "polygons": + edit_polygons(sess, repo.datasets()[H.POLYGONS.LAYER], wc) + elif archive == "table": + edit_table(sess, repo.datasets()[H.TABLE.LAYER], wc) + + r = cli_runner.invoke(["status"]) + assert r.exit_code 
== 0, r.stderr + assert r.stdout.splitlines() == [ + "On branch main", + "", + "Changes in working copy:", + ' (use "sno commit" to commit)', + ' (use "sno reset" to discard changes)', + "", + f" {table}:", + " feature:", + " 1 inserts", + " 2 updates", + " 5 deletes", + ] + orig_head = repo.head.peel(pygit2.Commit).hex + + r = cli_runner.invoke(["commit", "-m", "test_commit"]) + assert r.exit_code == 0, r.stderr + + r = cli_runner.invoke(["status"]) + assert r.exit_code == 0, r.stderr + assert r.stdout.splitlines() == [ + "On branch main", + "", + "Nothing to commit, working copy clean", + ] + + new_head = repo.head.peel(pygit2.Commit).hex + assert new_head != orig_head + + r = cli_runner.invoke(["checkout", "HEAD^"]) + + assert repo.head.peel(pygit2.Commit).hex == orig_head + + +def test_edit_schema(data_archive, cli_runner, new_sqlserver_db_schema): + with data_archive("polygons") as repo_path: + repo = SnoRepo(repo_path) + H.clear_working_copy() + + with new_sqlserver_db_schema() as (sqlserver_url, sqlserver_schema): + r = cli_runner.invoke(["create-workingcopy", sqlserver_url]) + assert r.exit_code == 0, r.stderr + + wc = repo.working_copy + assert wc.is_created() + + r = cli_runner.invoke(["diff", "--output-format=quiet"]) + assert r.exit_code == 0, r.stderr + + with wc.session() as sess: + sess.execute( + f"""ALTER TABLE "{sqlserver_schema}"."{H.POLYGONS.LAYER}" ADD colour NVARCHAR(32);""" + ) + sess.execute( + f"""ALTER TABLE "{sqlserver_schema}"."{H.POLYGONS.LAYER}" DROP COLUMN survey_reference;""" + ) + + r = cli_runner.invoke(["diff"]) + assert r.exit_code == 0, r.stderr + diff = r.stdout.splitlines() + + # New column "colour" has an ID is deterministically generated from the commit hash, + # but we don't care exactly what it is. 
+ try: + colour_id_line = diff[-6] + except KeyError: + colour_id_line = "" + + assert diff[-46:] == [ + "--- nz_waca_adjustments:meta:schema.json", + "+++ nz_waca_adjustments:meta:schema.json", + " [", + " {", + ' "id": "79d3c4ca-3abd-0a30-2045-45169357113c",', + ' "name": "id",', + ' "dataType": "integer",', + ' "primaryKeyIndex": 0,', + ' "size": 64', + " },", + " {", + ' "id": "c1d4dea1-c0ad-0255-7857-b5695e3ba2e9",', + ' "name": "geom",', + ' "dataType": "geometry",', + ' "geometryType": "MULTIPOLYGON",', + ' "geometryCRS": "EPSG:4167"', + " },", + " {", + ' "id": "d3d4b64b-d48e-4069-4bb5-dfa943d91e6b",', + ' "name": "date_adjusted",', + ' "dataType": "timestamp"', + " },", + "- {", + '- "id": "dff34196-229d-f0b5-7fd4-b14ecf835b2c",', + '- "name": "survey_reference",', + '- "dataType": "text",', + '- "length": 50', + "- },", + " {", + ' "id": "13dc4918-974e-978f-05ce-3b4321077c50",', + ' "name": "adjusted_nodes",', + ' "dataType": "integer",', + ' "size": 32', + " },", + "+ {", + colour_id_line, + '+ "name": "colour",', + '+ "dataType": "text",', + '+ "length": 32', + "+ },", + " ]", + ] + + orig_head = repo.head.peel(pygit2.Commit).hex + + r = cli_runner.invoke(["commit", "-m", "test_commit"]) + assert r.exit_code == 0, r.stderr + + r = cli_runner.invoke(["status"]) + assert r.exit_code == 0, r.stderr + assert r.stdout.splitlines() == [ + "On branch main", + "", + "Nothing to commit, working copy clean", + ] + + new_head = repo.head.peel(pygit2.Commit).hex + assert new_head != orig_head + + r = cli_runner.invoke(["checkout", "HEAD^"]) + + assert repo.head.peel(pygit2.Commit).hex == orig_head + + +def test_approximated_types(): + assert sqlserver_adapter.APPROXIMATED_TYPES == compute_approximated_types( + sqlserver_adapter.V2_TYPE_TO_MS_TYPE, sqlserver_adapter.MS_TYPE_TO_V2_TYPE + ) + + +def test_types_roundtrip(data_archive, cli_runner, new_sqlserver_db_schema): + with data_archive("types") as repo_path: + repo = SnoRepo(repo_path) + H.clear_working_copy() + 
+ with new_sqlserver_db_schema() as (sqlserver_url, sqlserver_schema): + repo.config["sno.workingcopy.path"] = sqlserver_url + r = cli_runner.invoke(["checkout"]) + + # If type-approximation roundtrip code isn't working, + # we would get spurious diffs on types that PostGIS doesn't support. + r = cli_runner.invoke(["diff", "--exit-code"]) + assert r.exit_code == 0, r.stdout From 1c39cada9affa734b36a0717b40cbfc50d8bc0ba Mon Sep 17 00:00:00 2001 From: Andrew Olsen Date: Wed, 3 Mar 2021 14:41:04 +1300 Subject: [PATCH 2/2] Address SQL Server review comments --- sno/base_dataset.py | 4 + sno/checkout.py | 2 +- sno/clone.py | 2 +- sno/geometry.py | 8 + sno/init.py | 2 +- sno/repo.py | 4 +- sno/sqlalchemy.py | 32 ++- sno/working_copy/base.py | 123 +++++++----- sno/working_copy/db_server.py | 103 ++++++++++ sno/working_copy/gpkg.py | 172 +++++++++-------- sno/working_copy/postgis.py | 161 ++++------------ sno/working_copy/sqlserver.py | 267 ++++++++++---------------- sno/working_copy/sqlserver_adapter.py | 39 +--- sno/working_copy/table_defs.py | 4 +- tests/conftest.py | 6 +- tests/test_working_copy_sqlserver.py | 2 +- 16 files changed, 466 insertions(+), 465 deletions(-) create mode 100644 sno/working_copy/db_server.py diff --git a/sno/base_dataset.py b/sno/base_dataset.py index 56325600f..0a25f0592 100644 --- a/sno/base_dataset.py +++ b/sno/base_dataset.py @@ -9,6 +9,10 @@ class BaseDataset(ImportSource): """ Common interface for all datasets - mainly Dataset2, but there is also Dataset0 and Dataset1 used by `sno upgrade`. + + A Dataset instance is immutable since it is a view of a particular git tree. + To get a new version of a dataset, commit the desired changes, + then instantiate a new Dataset instance that references the new git tree. """ # Constants that subclasses should generally define. 
diff --git a/sno/checkout.py b/sno/checkout.py index cf3f80ace..587a49011 100644 --- a/sno/checkout.py +++ b/sno/checkout.py @@ -378,7 +378,7 @@ def create_workingcopy(ctx, discard_changes, wc_path): wc_path = WorkingCopy.default_path(repo.workdir_path) if wc_path != old_wc_path: - WorkingCopy.check_valid_creation_path(repo.workdir_path, wc_path) + WorkingCopy.check_valid_creation_path(wc_path, repo.workdir_path) # Finished sanity checks - start work: if old_wc and wc_path != old_wc_path: diff --git a/sno/clone.py b/sno/clone.py index c598746ad..63574566a 100644 --- a/sno/clone.py +++ b/sno/clone.py @@ -106,7 +106,7 @@ def clone( if repo_path.exists() and any(repo_path.iterdir()): raise InvalidOperation(f'"{repo_path}" isn\'t empty', param_hint="directory") - WorkingCopy.check_valid_creation_path(repo_path, wc_path) + WorkingCopy.check_valid_creation_path(wc_path, repo_path) if not repo_path.exists(): repo_path.mkdir(parents=True) diff --git a/sno/geometry.py b/sno/geometry.py index ca42f0d75..8873c8ddd 100644 --- a/sno/geometry.py +++ b/sno/geometry.py @@ -59,6 +59,11 @@ def with_crs_id(self, crs_id): @property def crs_id(self): + """ + Returns the CRS ID as it is embedded in the GPKG header - before the WKB. + Note that datasets V2 zeroes this field before committing, + so will return zero when called on Geometry where it has been zeroed. 
+ """ wkb_offset, is_le, crs_id = parse_gpkg_geom(self) return crs_id @@ -296,6 +301,7 @@ def gpkg_geom_to_ogr(gpkg_geom, parse_crs=False): def wkt_to_gpkg_geom(wkt, **kwargs): + """Given a well-known-text string, returns a GPKG Geometry object.""" if wkt is None: return None @@ -304,6 +310,7 @@ def wkt_to_gpkg_geom(wkt, **kwargs): def wkb_to_gpkg_geom(wkb, **kwargs): + """Given a well-known-binary bytestring, returns a GPKG Geometry object.""" if wkb is None: return None @@ -312,6 +319,7 @@ def wkb_to_gpkg_geom(wkb, **kwargs): def hex_wkb_to_gpkg_geom(hex_wkb, **kwargs): + """Given a hex-encoded well-known-binary bytestring, returns a GPKG Geometry object.""" if hex_wkb is None: return None diff --git a/sno/init.py b/sno/init.py index 4f7dcc5de..a9d88b4bc 100644 --- a/sno/init.py +++ b/sno/init.py @@ -369,7 +369,7 @@ def init( if repo_path.exists() and any(repo_path.iterdir()): raise InvalidOperation(f'"{repo_path}" isn\'t empty', param_hint="directory") - WorkingCopy.check_valid_creation_path(repo_path, wc_path) + WorkingCopy.check_valid_creation_path(wc_path, repo_path) if not repo_path.exists(): repo_path.mkdir(parents=True) diff --git a/sno/repo.py b/sno/repo.py index 7fb303fa6..1e5420d1a 100644 --- a/sno/repo.py +++ b/sno/repo.py @@ -177,7 +177,7 @@ def init_repository( repo_root_path = repo_root_path.resolve() cls._ensure_exists_and_empty(repo_root_path) if not bare: - WorkingCopy.check_valid_creation_path(repo_root_path, wc_path) + WorkingCopy.check_valid_creation_path(wc_path, repo_root_path) extra_args = [] if initial_branch is not None: @@ -224,7 +224,7 @@ def clone_repository( repo_root_path = repo_root_path.resolve() cls._ensure_exists_and_empty(repo_root_path) if not bare: - WorkingCopy.check_valid_creation_path(repo_root_path, wc_path) + WorkingCopy.check_valid_creation_path(wc_path, repo_root_path) if bare: sno_repo = cls._create_with_git_command( diff --git a/sno/sqlalchemy.py b/sno/sqlalchemy.py index c3208f462..959bedadb 100644 --- 
a/sno/sqlalchemy.py +++ b/sno/sqlalchemy.py @@ -1,6 +1,6 @@ import re import socket -from urllib.parse import urlsplit, urlunsplit +from urllib.parse import urlsplit, urlunsplit, urlencode, parse_qs import sqlalchemy @@ -11,6 +11,7 @@ from sno import spatialite_path from sno.geometry import Geometry +from sno.exceptions import NotFound def gpkg_engine(path): @@ -110,7 +111,6 @@ def _on_connect(psycopg2_conn, connection_record): CANONICAL_SQL_SERVER_SCHEME = "mssql" INTERNAL_SQL_SERVER_SCHEME = "mssql+pyodbc" -SQL_SERVER_DRIVER_LIB = "ODBC+Driver+17+for+SQL+Server" def sqlserver_engine(msurl): @@ -122,7 +122,7 @@ def sqlserver_engine(msurl): url_netloc = re.sub(r"\blocalhost\b", _replace_with_localhost, url.netloc) url_query = _append_to_query( - url.query, {"driver": SQL_SERVER_DRIVER_LIB, "Application+Name": "sno"} + url.query, {"driver": get_sqlserver_driver(), "Application Name": "sno"} ) msurl = urlunsplit( @@ -133,18 +133,30 @@ def sqlserver_engine(msurl): return engine +def get_sqlserver_driver(): + import pyodbc + + drivers = [ + d for d in pyodbc.drivers() if re.search("SQL Server", d, flags=re.IGNORECASE) + ] + if not drivers: + drivers = pyodbc.drivers() + if not drivers: + raise NotFound("SQL Server driver was not found") + return sorted(drivers)[-1] # Latest driver + + def _replace_with_localhost(*args, **kwargs): return socket.gethostbyname("localhost") -def _append_query_to_url(uri, query_dict): +def _append_query_to_url(uri, new_query_dict): url = urlsplit(uri) - url_query = _append_to_query(url.query, query_dict) + url_query = _append_to_query(url.query, new_query_dict) return urlunsplit([url.scheme, url.netloc, url.path, url_query, ""]) -def _append_to_query(url_query, query_dict): - for key, value in query_dict.items(): - if key not in url_query: - url_query = "&".join(filter(None, [url_query, f"{key}={value}"])) - return url_query +def _append_to_query(existing_query, new_query_dict): + query_dict = parse_qs(existing_query) + # ignore new keys if 
they're already set in the querystring + return urlencode({**new_query_dict, **query_dict}) diff --git a/sno/working_copy/base.py b/sno/working_copy/base.py index 4dffdcac2..7332e3015 100644 --- a/sno/working_copy/base.py +++ b/sno/working_copy/base.py @@ -1,14 +1,13 @@ +from enum import Enum, auto import contextlib import functools import itertools import logging import time -from enum import Enum, auto - import click import pygit2 -import sqlalchemy +import sqlalchemy as sa from sno.base_dataset import BaseDataset from sno.diff_structs import RepoDiff, DatasetDiff, DeltaDiff, Delta @@ -48,7 +47,8 @@ def from_path(cls, path, allow_invalid=False): f"Unrecognised working copy type: {path}\n" "Try one of:\n" " PATH.gpkg\n" - " postgresql://[HOST]/DBNAME/SCHEMA\n" + " postgresql://[HOST]/DBNAME/DBSCHEMA\n" + " mssql://[HOST]/DBNAME/DBSCHEMA" ) @property @@ -103,12 +103,9 @@ class WorkingCopy: SNO_WORKINGCOPY_PATH = "sno.workingcopy.path" @property - @functools.lru_cache(maxsize=1) - def DB_SCHEMA(self): - """Escaped, dialect-specific name of the database-schema owned by this working copy (if any).""" - if self.db_schema is None: - raise RuntimeError("No schema to escape.") - return self.preparer.format_schema(self.db_schema) + def WORKING_COPY_TYPE_NAME(self): + """Human readable name of this type of working copy, eg "PostGIS".""" + raise NotImplementedError() @property @functools.lru_cache(maxsize=1) @@ -147,9 +144,7 @@ def table_identifier(self, dataset_or_table): else dataset_or_table ) sqlalchemy_table = ( - sqlalchemy.table(table, schema=self.db_schema) - if isinstance(table, str) - else table + sa.table(table, schema=self.db_schema) if isinstance(table, str) else table ) return self.preparer.format_table(sqlalchemy_table) @@ -157,7 +152,7 @@ def _table_def_for_column_schema(self, col, dataset): # This is just used for selects/inserts/updates - we don't need the full type information. 
# We only need type information if some automatic conversion needs to happen on read or write. # TODO: return the full type information so we can also use it for CREATE TABLE. - return sqlalchemy.column(col.name) + return sa.column(col.name) @functools.lru_cache() def _table_def_for_dataset(self, dataset, schema=None): @@ -166,7 +161,7 @@ def _table_def_for_dataset(self, dataset, schema=None): Specifies table and column names only - the minimum required such that an insert() or update() will work. """ schema = schema or dataset.schema - return sqlalchemy.table( + return sa.table( dataset.table_name, *[self._table_def_for_column_schema(c, dataset) for c in schema], schema=self.db_schema, @@ -261,7 +256,16 @@ def write_config(cls, repo, path=None, bare=False): repo_cfg[path_key] = str(path) @classmethod - def check_valid_creation_path(cls, workdir_path, wc_path): + def subclass_from_path(cls, wc_path): + wct = WorkingCopyType.from_path(wc_path) + if wct.class_ is cls: + raise RuntimeError( + f"No subclass found - don't call subclass_from_path on concrete implementation {cls}." + ) + return wct.class_ + + @classmethod + def check_valid_creation_path(cls, wc_path, workdir_path=None): """ Given a user-supplied string describing where to put the working copy, ensures it is a valid location, and nothing already exists there that prevents us from creating it. Raises InvalidOperation if it is not. 
@@ -269,20 +273,16 @@ def check_valid_creation_path(cls, workdir_path, wc_path): """ if not wc_path: wc_path = cls.default_path(workdir_path) - WorkingCopyType.from_path(wc_path).class_.check_valid_creation_path( - workdir_path, wc_path - ) + cls.subclass_from_path(wc_path).check_valid_creation_path(wc_path, workdir_path) @classmethod - def check_valid_path(cls, workdir_path, wc_path): + def check_valid_path(cls, wc_path, workdir_path=None): """ Given a user-supplied string describing where to put the working copy, ensures it is a valid location, and nothing already exists there that prevents us from creating it. Raises InvalidOperation if it is not. Doesn't check if we have permissions to create a working copy there. """ - WorkingCopyType.from_path(wc_path).class_.check_valid_path( - workdir_path, wc_path - ) + cls.subclass_from_path(wc_path).check_valid_path(wc_path, workdir_path) def check_valid_state(self): if self.is_created(): @@ -304,19 +304,39 @@ def default_path(cls, workdir_path): return f"{stem}.gpkg" @classmethod - def normalise_path(cls, repo, path): + def normalise_path(cls, repo, wc_path): """If the path is in a non-standard form, normalise it to the equivalent standard form.""" - return WorkingCopyType.from_path(path).class_.normalise_path(repo, path) + return cls.subclass_from_path(wc_path).normalise_path(repo, wc_path) @contextlib.contextmanager def session(self, bulk=0): """ - Context manager for DB sessions, yields a session object inside a transaction - Calling again yields the _same_ session, the transaction/etc only happen in the outer one. + Context manager for GeoPackage DB sessions, yields a connection object inside a transaction - @bulk controls bulk-loading operating mode - subject to change. See ./gpkg.py + Calling again yields the _same_ session, the transaction/etc only happen in the outer one. 
""" - raise NotImplementedError() + L = logging.getLogger(f"{self.__class__.__qualname__}.session") + + if hasattr(self, "_session"): + # Inner call - reuse existing session. + L.debug("session: existing...") + yield self._session + L.debug("session: existing/done") + return + + L.debug("session: new...") + self._session = self.sessionmaker() + try: + # TODO - use tidier syntax for opening transactions from sqlalchemy. + yield self._session + self._session.commit() + except Exception: + self._session.rollback() + raise + finally: + self._session.close() + del self._session + L.debug("session: new/done") def is_created(self): """ @@ -569,10 +589,10 @@ def _execute_dirty_rows_query( pk_column = table.columns[schema.pk_columns[0].name] tracking_col_type = sno_track.c.pk.type - base_query = sqlalchemy.select(columns=cols_to_select).select_from( + base_query = sa.select(columns=cols_to_select).select_from( sno_track.outerjoin( table, - sno_track.c.pk == sqlalchemy.cast(pk_column, tracking_col_type), + sno_track.c.pk == sa.cast(pk_column, tracking_col_type), ) ) @@ -581,7 +601,7 @@ def _execute_dirty_rows_query( else: pks = list(feature_filter) query = base_query.where( - sqlalchemy.and_( + sa.and_( sno_track.c.table_name == dataset.table_name, sno_track.c.pk.in_(pks), ) @@ -611,10 +631,10 @@ def reset_tracking_table(self, reset_filter=UNFILTERED): continue pks = list(dataset_filter.get("feature", [])) - t = sqlalchemy.text( + t = sa.text( f"DELETE FROM {self.SNO_TRACK} WHERE table_name=:table_name AND pk IN :pks;" ) - t = t.bindparams(sqlalchemy.bindparam("pks", expanding=True)) + t = t.bindparams(sa.bindparam("pks", expanding=True)) sess.execute(t, {"table_name": table_name, "pks": pks}) def can_find_renames(self, meta_diff): @@ -701,8 +721,7 @@ def write_full(self, commit, *datasets, **kwargs): self._write_meta(sess, dataset) if dataset.has_geometry: - # This should be called while the table is still empty. 
- self._create_spatial_index(sess, dataset) + self._create_spatial_index_pre(sess, dataset) L.info("Creating features...") sql = self._insert_into_dataset(dataset) @@ -739,6 +758,9 @@ def write_full(self, commit, *datasets, **kwargs): "Overall rate: %d features/s", (feat_progress / (t1 - t0 or 0.001)) ) + if dataset.has_geometry: + self._create_spatial_index_post(sess, dataset) + self._create_triggers(sess, dataset) self._update_last_write_time(sess, dataset, commit) @@ -754,18 +776,29 @@ def _write_meta(self, sess, dataset): """Write any non-feature data relating to dataset - title, description, CRS, etc.""" raise NotImplementedError() - def _create_spatial_index(self, sess, dataset): + def _create_spatial_index_pre(self, sess, dataset): """ Creates a spatial index for the table for the given dataset. - The spatial index is configured so that it is automatically updated when the table is modified. - It is not guaranteed that the spatial index will take into account features that are already present - in the table when this function is called - therefore, this should be called while the table is still empty. + This function comes in a pair - _pre is called before features are written, and _post is called afterwards. + Once both are called, the index must contain all the features currently in the table, and, be + configured such that any further writes cause the index to be updated automatically. """ - raise NotImplementedError() + + # Note that the simplest implementation is to add a trigger here so that any further writes update + # the index. Then _create_spatial_index_post needn't be implemented. + pass + + def _create_spatial_index_post(self, sess, dataset): + """Like _create_spatial_index_pre, but runs AFTER the bulk of features have been written.""" + + # Being able to create the index after the bulk of features have been written could be useful for two reasons: + # 1. It might be more efficient to write the features first, then index afterwards. + # 2. 
Certain working copies are not able to create an index without first knowing a rough bounding box. + pass def _drop_spatial_index(self, sess, dataset): - """Inverse of _create_spatial_index - deletes the spatial index.""" - raise NotImplementedError() + """Inverse of _create_spatial_index_* - deletes the spatial index.""" + pass def _update_last_write_time(self, sess, dataset, commit=None): """Hook for updating the last-modified timestamp stored for a particular dataset, if there is one.""" @@ -795,9 +828,7 @@ def _delete_features(self, sess, dataset, pk_list): pk_column = self.preparer.quote(dataset.primary_key) sql = f"""DELETE FROM {self.table_identifier(dataset)} WHERE {pk_column} IN :pks;""" - stmt = sqlalchemy.text(sql).bindparams( - sqlalchemy.bindparam("pks", expanding=True) - ) + stmt = sa.text(sql).bindparams(sa.bindparam("pks", expanding=True)) feat_count = 0 CHUNK_SIZE = 100 for pks in self._chunk(pk_list, CHUNK_SIZE): diff --git a/sno/working_copy/db_server.py b/sno/working_copy/db_server.py new file mode 100644 index 000000000..b5d73017d --- /dev/null +++ b/sno/working_copy/db_server.py @@ -0,0 +1,103 @@ +import functools +import re +from urllib.parse import urlsplit, urlunsplit + +import click + +from .base import WorkingCopy +from sno.exceptions import InvalidOperation + + +class DatabaseServer_WorkingCopy(WorkingCopy): + """Functionality common to working copies that connect to a database server.""" + + @property + def URI_SCHEME(self): + """The URI scheme to connect to this type of database, eg "postgresql".""" + raise NotImplementedError() + + @classmethod + def check_valid_creation_path(cls, wc_path, workdir_path=None): + cls.check_valid_path(wc_path, workdir_path) + + working_copy = cls(None, wc_path) + if working_copy.has_data(): + db_schema = working_copy.db_schema + container_text = f"schema '{db_schema}'" if db_schema else "working copy" + raise InvalidOperation( + f"Error creating {cls.WORKING_COPY_TYPE_NAME} working copy at {wc_path} - " 
+ f"non-empty {container_text} already exists" + ) + + @classmethod + def check_valid_path(cls, wc_path, workdir_path=None): + cls.check_valid_db_uri(wc_path, workdir_path) + + @classmethod + def normalise_path(cls, repo, wc_path): + return wc_path + + @classmethod + def check_valid_db_uri(cls, db_uri, workdir_path=None): + """ + For working copies that connect to a database - checks the given URI is in the required form: + >>> URI_SCHEME::[HOST]/DBNAME/DBSCHEMA + """ + url = urlsplit(db_uri) + + if url.scheme != cls.URI_SCHEME: + raise click.UsageError( + f"Invalid {cls.WORKING_COPY_TYPE_NAME} URI - " + f"Expecting URI in form: {cls.URI_SCHEME}://[HOST]/DBNAME/DBSCHEMA" + ) + + url_path = url.path + path_parts = url_path[1:].split("/", 3) if url_path else [] + + suggestion_message = "" + if len(path_parts) == 1 and workdir_path is not None: + suggested_path = f"/{path_parts[0]}/{cls.default_db_schema(workdir_path)}" + suggested_uri = urlunsplit( + [url.scheme, url.netloc, suggested_path, url.query, ""] + ) + suggestion_message = f"\nFor example: {suggested_uri}" + + if len(path_parts) != 2: + raise click.UsageError( + f"Invalid {cls.WORKING_COPY_TYPE_NAME} URI - URI requires both database name and database schema:\n" + f"Expecting URI in form: {cls.URI_SCHEME}://[HOST]/DBNAME/DBSCHEMA" + + suggestion_message + ) + + @classmethod + def _separate_db_schema(cls, db_uri): + """ + Removes the DBSCHEMA part off the end of a uri in the form URI_SCHEME::[HOST]/DBNAME/DBSCHEMA - + and returns the URI and the DBSCHEMA separately. + Useful since generally, URI_SCHEME::[HOST]/DBNAME is what is needed to connect to the database, + and then DBSCHEMA must be specified in each query. 
+ """ + url = urlsplit(db_uri) + url_path = url.path + path_parts = url_path[1:].split("/", 3) if url_path else [] + assert len(path_parts) == 2 + url_path = "/" + path_parts[0] + db_schema = path_parts[1] + return urlunsplit([url.scheme, url.netloc, url_path, url.query, ""]), db_schema + + @classmethod + def default_db_schema(cls, workdir_path): + """Returns a suitable default database schema - named after the folder this Sno repo is in.""" + stem = workdir_path.stem + schema = re.sub("[^a-z0-9]+", "_", stem.lower()) + "_sno" + if schema[0].isdigit(): + schema = "_" + schema + return schema + + @property + @functools.lru_cache(maxsize=1) + def DB_SCHEMA(self): + """Escaped, dialect-specific name of the database-schema owned by this working copy (if any).""" + if self.db_schema is None: + raise RuntimeError("No schema to escape.") + return self.preparer.format_schema(self.db_schema) diff --git a/sno/working_copy/gpkg.py b/sno/working_copy/gpkg.py index 9bcbb00dd..2a0846cd8 100644 --- a/sno/working_copy/gpkg.py +++ b/sno/working_copy/gpkg.py @@ -7,7 +7,7 @@ import click from osgeo import gdal -import sqlalchemy +import sqlalchemy as sa from sqlalchemy.orm import sessionmaker from sqlalchemy.sql.compiler import IdentifierPreparer from sqlalchemy.types import UserDefinedType @@ -26,6 +26,15 @@ class WorkingCopy_GPKG(WorkingCopy): + """ + GPKG working copy implementation. + + Requirements: + 1. Can read and write to the filesystem at the specified path. 
+ """ + + WORKING_COPY_TYPE_NAME = "GPKG" + def __init__(self, repo, path): self.repo = repo self.path = path @@ -37,34 +46,34 @@ def __init__(self, repo, path): self.sno_tables = GpkgSnoTables @classmethod - def check_valid_creation_path(cls, workdir_path, path): - cls.check_valid_path(workdir_path, path) + def check_valid_creation_path(cls, wc_path, workdir_path=None): + cls.check_valid_path(wc_path, workdir_path) - gpkg_path = (workdir_path / path).resolve() + gpkg_path = (workdir_path / wc_path).resolve() if gpkg_path.exists(): desc = "path" if gpkg_path.is_dir() else "GPKG file" raise InvalidOperation( - f"Error creating GPKG working copy at {path} - {desc} already exists" + f"Error creating GPKG working copy at {wc_path} - {desc} already exists" ) @classmethod - def check_valid_path(cls, workdir_path, path): - if not str(path).endswith(".gpkg"): - suggested_path = f"{os.path.splitext(str(path))[0]}.gpkg" + def check_valid_path(cls, wc_path, workdir_path=None): + if not str(wc_path).endswith(".gpkg"): + suggested_path = f"{os.path.splitext(str(wc_path))[0]}.gpkg" raise click.UsageError( f"Invalid GPKG path - expected .gpkg suffix, eg {suggested_path}" ) @classmethod - def normalise_path(cls, repo, path): + def normalise_path(cls, repo, wc_path): """Rewrites a relative path (relative to the current directory) as relative to the repo.workdir_path.""" - path = Path(path) - if not path.is_absolute(): + wc_path = Path(wc_path) + if not wc_path.is_absolute(): try: - return str(path.resolve().relative_to(repo.workdir_path.resolve())) + return str(wc_path.resolve().relative_to(repo.workdir_path.resolve())) except ValueError: pass - return str(path) + return str(wc_path) @property def full_path(self): @@ -82,11 +91,11 @@ def _insert_or_replace_into_dataset(self, dataset): def _table_def_for_column_schema(self, col, dataset): if col.data_type == "geometry": - # This user-defined Geography type adapts WKB to SQL Server's native geography type. 
- return sqlalchemy.column(col.name, GeometryType) + # This user-defined GeometryType normalises GPKG geometry to the Sno V2 GPKG geometry. + return sa.column(col.name, GeometryType) else: # Don't need to specify type information for other columns at present, since we just pass through the values. - return sqlalchemy.column(col.name) + return sa.column(col.name) def _insert_or_replace_state_table_tree(self, sess, tree_id): r = sess.execute( @@ -119,37 +128,36 @@ def session(self, bulk=0): # - do something consistent and safe from then on. if hasattr(self, "_session"): - # inner - reuse + # Inner call - reuse existing session. L.debug(f"session(bulk={bulk}): existing...") yield self._session L.debug(f"session(bulk={bulk}): existing/done") + return - else: - L.debug(f"session(bulk={bulk}): new...") + # Outer call - create new session: + L.debug(f"session(bulk={bulk}): new...") + self._session = self.sessionmaker() - try: - self._session = self.sessionmaker() - - if bulk: - self._session.execute("PRAGMA synchronous = OFF;") - self._session.execute( - "PRAGMA cache_size = -1048576;" - ) # -KiB => 1GiB - if bulk >= 2: - self._session.execute("PRAGMA journal_mode = MEMORY;") - self._session.execute("PRAGMA locking_mode = EXCLUSIVE;") - - # TODO - use tidier syntax for opening transactions from sqlalchemy. - self._session.execute("BEGIN TRANSACTION;") - yield self._session - self._session.commit() - except Exception: - self._session.rollback() - raise - finally: - self._session.close() - del self._session - L.debug(f"session(bulk={bulk}): new/done") + try: + if bulk: + self._session.execute("PRAGMA synchronous = OFF;") + self._session.execute("PRAGMA cache_size = -1048576;") # -KiB => 1GiB + if bulk >= 2: + self._session.execute("PRAGMA journal_mode = MEMORY;") + self._session.execute("PRAGMA locking_mode = EXCLUSIVE;") + + # TODO - use tidier syntax for opening transactions from sqlalchemy. 
+ self._session.execute("BEGIN TRANSACTION;") + yield self._session + self._session.commit() + + except Exception: + self._session.rollback() + raise + finally: + self._session.close() + del self._session + L.debug(f"session(bulk={bulk}): new/done") def delete(self, keep_db_schema_if_possible=False): """Delete the working copy files.""" @@ -428,12 +436,15 @@ def _delete_meta_metadata(self, sess, table_name): """DELETE FROM gpkg_metadata WHERE id IN :ids;""", ) for sql in sqls: - stmt = sqlalchemy.text(sql).bindparams( - sqlalchemy.bindparam("ids", expanding=True) - ) + stmt = sa.text(sql).bindparams(sa.bindparam("ids", expanding=True)) sess.execute(stmt, {"ids": ids}) - def _create_spatial_index(self, sess, dataset): + def _create_spatial_index_pre(self, sess, dataset): + # Implementing only _create_spatial_index_pre: + # gpkgAddSpatialIndex has to be called before writing any features, + # since it only adds on-write triggers to update the index - it doesn't + # add any pre-existing features to the index. 
+ L = logging.getLogger(f"{self.__class__.__qualname__}._create_spatial_index") geom_col = dataset.geom_column_name @@ -480,56 +491,42 @@ def _create_triggers(self, sess, dataset): table_identifier = self.table_identifier(dataset) pk_column = self.quote(dataset.primary_key) - # SQLite doesn't let you do param substitutions in CREATE TRIGGER: - escaped_table_name = dataset.table_name.replace("'", "''") - - sess.execute( + insert_trigger = sa.text( f""" CREATE TRIGGER {self._quoted_trigger_name(dataset, 'ins')} AFTER INSERT ON {table_identifier} BEGIN - INSERT OR REPLACE INTO {self.SNO_TRACK} - (table_name, pk) - VALUES ('{escaped_table_name}', NEW.{pk_column}); + INSERT OR REPLACE INTO {self.SNO_TRACK} (table_name, pk) + VALUES (:table_name, NEW.{pk_column}); END; """ ) - sess.execute( + update_trigger = sa.text( f""" CREATE TRIGGER {self._quoted_trigger_name(dataset, 'upd')} AFTER UPDATE ON {table_identifier} BEGIN - INSERT OR REPLACE INTO {self.SNO_TRACK} - (table_name, pk) - VALUES - ('{escaped_table_name}', NEW.{pk_column}), - ('{escaped_table_name}', OLD.{pk_column}); + INSERT OR REPLACE INTO {self.SNO_TRACK} (table_name, pk) + VALUES (:table_name, NEW.{pk_column}), (:table_name, OLD.{pk_column}); END; """ ) - sess.execute( + delete_trigger = sa.text( f""" CREATE TRIGGER {self._quoted_trigger_name(dataset, 'del')} AFTER DELETE ON {table_identifier} BEGIN - INSERT OR REPLACE INTO {self.SNO_TRACK} - (table_name, pk) - VALUES - ('{escaped_table_name}', OLD.{pk_column}); + INSERT OR REPLACE INTO {self.SNO_TRACK} (table_name, pk) + VALUES (:table_name, OLD.{pk_column}); END; """ ) - - def _db_geom_to_gpkg_geom(self, g): - # Its possible in GPKG to put arbitrary values in columns, regardless of type. - # We don't try to convert them here - we let the commit validation step report this as an error. 
- if not isinstance(g, bytes): - return g - # We normalise geometries to avoid spurious diffs - diffs where nothing - # of any consequence has changed (eg, only endianness has changed). - # This includes setting the SRID to zero for each geometry so that we don't store a separate SRID per geometry, - # but only one per column at most. - return normalise_gpkg_geom(g) + for trigger in (insert_trigger, update_trigger, delete_trigger): + # Placeholders not allowed in CREATE TRIGGER - have to use literal_binds. + # See https://docs.sqlalchemy.org/en/13/faq/sqlexpressions.html#faq-sql-expression-string + trigger.bindparams(table_name=dataset.table_name).compile( + sess.connection(), compile_kwargs={"literal_binds": True} + ).execute() def _is_meta_update_supported(self, dataset_version, meta_diff): """ @@ -623,6 +620,21 @@ def _apply_meta_metadata_dataset_json(self, sess, dataset, src_value, dest_value def _update_last_write_time(self, sess, dataset, commit=None): self._update_gpkg_contents(sess, dataset, commit) + def _get_geom_extent(self, sess, dataset, default=None): + """Returns the envelope around the entire dataset as (min_x, min_y, max_x, max_y).""" + # FIXME: Why doesn't Extent(geom) work here as an aggregate? + geom_col = dataset.geom_column_name + r = sess.execute( + f""" + WITH _E AS ( + SELECT Extent({self.quote(geom_col)}) AS extent FROM {self.table_identifier(dataset)} + ) + SELECT ST_MinX(extent), ST_MinY(extent), ST_MaxX(extent), ST_MaxY(extent) FROM _E; + """ + ) + result = r.fetchone() + return default if result == (None, None, None, None) else result + def _update_gpkg_contents(self, sess, dataset, commit=None): """ Update the metadata for the given table in gpkg_contents to have the new bounding-box / last-updated timestamp. @@ -634,17 +646,11 @@ def _update_gpkg_contents(self, sess, dataset, commit=None): # GPKG Spec Req. 
15: gpkg_change_time = change_time.strftime("%Y-%m-%dT%H:%M:%S.%fZ") - table_identifer = self.table_identifier(dataset) geom_col = dataset.geom_column_name if geom_col is not None: - # FIXME: Why doesn't Extent(geom) work here as an aggregate? - r = sess.execute( - f""" - WITH _E AS (SELECT extent({self.quote(geom_col)}) AS extent FROM {table_identifer}) - SELECT ST_MinX(extent), ST_MinY(extent), ST_MaxX(extent), ST_MaxY(extent) FROM _E - """ + min_x, min_y, max_x, max_y = self._get_geom_extent( + sess, dataset, default=(None, None, None, None) ) - min_x, min_y, max_x, max_y = r.fetchone() rc = sess.execute( """ UPDATE gpkg_contents diff --git a/sno/working_copy/postgis.py b/sno/working_copy/postgis.py index 611bb63ee..975045fdf 100644 --- a/sno/working_copy/postgis.py +++ b/sno/working_copy/postgis.py @@ -1,37 +1,39 @@ import contextlib import logging -import re import time -from urllib.parse import urlsplit, urlunsplit +from urllib.parse import urlsplit -import click + +from sqlalchemy import Index from sqlalchemy.dialects.postgresql import insert as postgresql_insert from sqlalchemy.sql.compiler import IdentifierPreparer from sqlalchemy.orm import sessionmaker -from .base import WorkingCopy from . import postgis_adapter +from .db_server import DatabaseServer_WorkingCopy from .table_defs import PostgisSnoTables from sno import crs_util -from sno.exceptions import InvalidOperation from sno.schema import Schema from sno.sqlalchemy import postgis_engine -""" -* database needs to exist -* database needs to have postgis enabled -* database user needs to be able to: - 1. create 'sno' schema & tables - 2. create & alter tables in the default (or specified) schema - 3. create triggers -""" +class WorkingCopy_Postgis(DatabaseServer_WorkingCopy): + """ + PostGIS working copy implementation. -L = logging.getLogger("sno.working_copy.postgis") + Requirements: + 1. The database needs to exist + 2. 
If the dataset has geometry, then PostGIS (https://postgis.net/) v2.4 or newer needs + to be installed into the database and available in the database user's search path + 3. The database user needs to be able to: + - Create the specified schema (unless it already exists). + - Create, delete and alter tables and triggers in the specified schema. + """ + WORKING_COPY_TYPE_NAME = "PostGIS" + URI_SCHEME = "postgresql" -class WorkingCopy_Postgis(WorkingCopy): def __init__(self, repo, uri): """ uri: connection string of the form postgresql://[user[:password]@][netloc][:port][/dbname/schema][?param1=value1&...] @@ -42,75 +44,18 @@ def __init__(self, repo, uri): self.uri = uri self.path = uri - url = urlsplit(uri) - - if url.scheme != "postgresql": - raise ValueError("Expecting postgresql://") + self.check_valid_db_uri(uri) + self.db_uri, self.db_schema = self._separate_db_schema(uri) - url_path = url.path - path_parts = url_path[1:].split("/", 3) if url_path else [] - if len(path_parts) != 2: - raise ValueError("Expecting postgresql://[HOST]/DBNAME/SCHEMA") - url_path = f"/{path_parts[0]}" - self.db_schema = path_parts[1] - - # Rebuild DB URL suitable for postgres - self.dburl = urlunsplit([url.scheme, url.netloc, url_path, url.query, ""]) - self.engine = postgis_engine(self.dburl) + self.engine = postgis_engine(self.db_uri) self.sessionmaker = sessionmaker(bind=self.engine) self.preparer = IdentifierPreparer(self.engine.dialect) self.sno_tables = PostgisSnoTables(self.db_schema) @classmethod - def check_valid_creation_path(cls, workdir_path, path): - cls.check_valid_path(workdir_path, path) - postgis_wc = cls(None, path) - - # Less strict on Postgis - we are okay with the schema being already created, so long as its empty. 
- if postgis_wc.has_data(): - raise InvalidOperation( - f"Error creating Postgis working copy at {path} - non-empty schema already exists" - ) - - @classmethod - def check_valid_path(cls, workdir_path, path): - url = urlsplit(path) - - if url.scheme != "postgresql": - raise click.UsageError( - "Invalid postgres URI - Expecting URI in form: postgresql://[HOST]/DBNAME/SCHEMA" - ) - - url_path = url.path - path_parts = url_path[1:].split("/", 3) if url_path else [] - - suggestion_message = "" - if len(path_parts) == 1 and workdir_path is not None: - suggested_path = f"/{path_parts[0]}/{cls.default_schema(workdir_path)}" - suggested_uri = urlunsplit( - [url.scheme, url.netloc, suggested_path, url.query, ""] - ) - suggestion_message = f"\nFor example: {suggested_uri}" - - if len(path_parts) != 2: - raise click.UsageError( - "Invalid postgres URI - postgis working copy requires both dbname and schema:\n" - "Expecting URI in form: postgresql://[HOST]/DBNAME/SCHEMA" - + suggestion_message - ) - - @classmethod - def normalise_path(cls, repo, path): - return path - - @classmethod - def default_schema(cls, workdir_path): - stem = workdir_path.stem - schema = re.sub("[^a-z0-9]+", "_", stem.lower()) + "_sno" - if schema[0].isdigit(): - schema = "_" + schema - return schema + def check_valid_path(cls, wc_path, workdir_path=None): + cls.check_valid_db_uri(wc_path, workdir_path) def __str__(self): p = urlsplit(self.uri) @@ -124,44 +69,11 @@ def __str__(self): p._replace(netloc=nl) return p.geturl() - @contextlib.contextmanager - def session(self, bulk=0): - """ - Context manager for GeoPackage DB sessions, yields a connection object inside a transaction - - Calling again yields the _same_ connection, the transaction/etc only happen in the outer one. 
- """ - L = logging.getLogger(f"{self.__class__.__qualname__}.session") - - if hasattr(self, "_session"): - # inner - reuse - L.debug("session: existing...") - yield self._session - L.debug("session: existing/done") - - else: - L.debug("session: new...") - - try: - self._session = self.sessionmaker() - - # TODO - use tidier syntax for opening transactions from sqlalchemy. - self._session.execute("BEGIN TRANSACTION;") - yield self._session - self._session.commit() - except Exception: - self._session.rollback() - raise - finally: - self._session.close() - del self._session - L.debug("session: new/done") - def is_created(self): """ - Returns true if the postgres schema referred to by this working copy exists and + Returns true if the DB schema referred to by this working copy exists and contains at least one table. If it exists but is empty, it is treated as uncreated. - This is so the postgres schema can be created ahead of time before a repo is created + This is so the DB schema can be created ahead of time before a repo is created or configured, without it triggering code that checks for corrupted working copies. Note that it might not be initialised as a working copy - see self.is_initialised. """ @@ -177,7 +89,7 @@ def is_created(self): def is_initialised(self): """ - Returns true if the postgis working copy is initialised - + Returns true if the PostGIS working copy is initialised - the schema exists and has the necessary sno tables, _sno_state and _sno_track. """ with self.session() as sess: @@ -192,7 +104,7 @@ def is_initialised(self): def has_data(self): """ - Returns true if the postgis working copy seems to have user-created content already. + Returns true if the PostGIS working copy seems to have user-created content already. 
""" with self.session() as sess: count = sess.scalar( @@ -283,6 +195,7 @@ def _create_table_for_dataset(self, sess, dataset): ) def _insert_or_replace_into_dataset(self, dataset): + # See https://docs.sqlalchemy.org/en/14/dialects/postgresql.html#insert-on-conflict-upsert pk_col_names = [c.name for c in dataset.schema.pk_columns] stmt = postgresql_insert(self._table_def_for_dataset(dataset)) update_dict = {c.name: c for c in stmt.excluded if c.name not in pk_col_names} @@ -326,21 +239,25 @@ def delete_meta(self, dataset): """Delete any metadata that is only needed by this dataset.""" pass # There is no metadata except for the spatial_ref_sys table. - def _create_spatial_index(self, sess, dataset): + def _create_spatial_index_post(self, sess, dataset): + # Only implemented as _create_spatial_index_post: + # It is more efficient to write the features first, then index them all in bulk. L = logging.getLogger(f"{self.__class__.__qualname__}._create_spatial_index") geom_col = dataset.geom_column_name + index_name = f"{dataset.table_name}_idx_{geom_col}" + table = self._table_def_for_dataset(dataset) - # Create the PostGIS Spatial Index L.debug("Creating spatial index for %s.%s", dataset.table_name, geom_col) t0 = time.monotonic() - index_name = f"{dataset.table_name}_idx_{geom_col}" - sess.execute( - f""" - CREATE INDEX {self.quote(index_name)} - ON {self.table_identifier(dataset)} USING GIST ({self.quote(geom_col)}); - """ + + spatial_index = Index( + index_name, table.columns[geom_col], postgres_using="GIST" ) + spatial_index.table = table + spatial_index.create(sess.connection()) + sess.execute(f"""ANALYZE {self.table_identifier(dataset)};""") + L.info("Created spatial index in %ss", time.monotonic() - t0) def _drop_spatial_index(self, sess, dataset): diff --git a/sno/working_copy/sqlserver.py b/sno/working_copy/sqlserver.py index ee4d8bc33..6156a4c5b 100644 --- a/sno/working_copy/sqlserver.py +++ b/sno/working_copy/sqlserver.py @@ -1,11 +1,9 @@ import contextlib 
import logging -import re import time -from urllib.parse import urlsplit, urlunsplit +from urllib.parse import urlsplit -import click -import sqlalchemy +import sqlalchemy as sa from sqlalchemy import literal_column from sqlalchemy.ext.compiler import compiles from sqlalchemy.orm import sessionmaker @@ -15,28 +13,28 @@ from sqlalchemy.sql.compiler import IdentifierPreparer from sqlalchemy.types import UserDefinedType -from .base import WorkingCopy from . import sqlserver_adapter +from .db_server import DatabaseServer_WorkingCopy from .table_defs import SqlServerSnoTables from sno import crs_util from sno.geometry import Geometry -from sno.exceptions import InvalidOperation from sno.sqlalchemy import sqlserver_engine -""" -* database needs to exist -* database needs to have postgis enabled -* database user needs to be able to: - 1. create 'sno' schema & tables - 2. create & alter tables in the default (or specified) schema - 3. create triggers -""" +class WorkingCopy_SqlServer(DatabaseServer_WorkingCopy): + """ + SQL Server working copy implementation. -L = logging.getLogger("sno.working_copy.postgis") + Requirements: + 1. The database needs to exist + 2. The database user needs to be able to: + - Create the specified schema (unless it already exists). + - Create, delete and alter tables and triggers in the specified schema. + """ + WORKING_COPY_TYPE_NAME = "SQL Server" + URI_SCHEME = "mssql" -class WorkingCopy_SqlServer(WorkingCopy): def __init__(self, repo, uri): """ uri: connection string of the form mssql://[user[:password]@][netloc][:port][/dbname/schema][?param1=value1&...] 
@@ -47,78 +45,15 @@ def __init__(self, repo, uri): self.uri = uri self.path = uri - url = urlsplit(uri) - - if url.scheme != "mssql": - raise ValueError("Expecting mssql://") - - url_path = url.path - path_parts = url_path[1:].split("/", 3) if url_path else [] - if len(path_parts) != 2: - raise ValueError("Expecting mssql://[HOST]/DBNAME/SCHEMA") - url_path = f"/{path_parts[0]}" - self.db_schema = path_parts[1] + self.check_valid_db_uri(uri) + self.db_uri, self.db_schema = self._separate_db_schema(uri) - # Rebuild DB URL suitable for sqlserver_engine. - self.dburl = urlunsplit([url.scheme, url.netloc, url_path, url.query, ""]) - self.engine = sqlserver_engine(self.dburl) + self.engine = sqlserver_engine(self.db_uri) self.sessionmaker = sessionmaker(bind=self.engine) self.preparer = IdentifierPreparer(self.engine.dialect) self.sno_tables = SqlServerSnoTables(self.db_schema) - @classmethod - def check_valid_creation_path(cls, workdir_path, path): - # TODO - promote to superclass - cls.check_valid_path(workdir_path, path) - sqlserver_wc = cls(None, path) - - # We are okay with the schema being already created, so long as its empty. 
- if sqlserver_wc.has_data(): - raise InvalidOperation( - f"Error creating SQL Server working copy at {path} - non-empty schema already exists" - ) - - @classmethod - def check_valid_path(cls, workdir_path, path): - url = urlsplit(path) - - if url.scheme != "mssql": - raise click.UsageError( - "Invalid postgres URI - Expecting URI in form: mssql://[HOST]/DBNAME/SCHEMA" - ) - - url_path = url.path - path_parts = url_path[1:].split("/", 3) if url_path else [] - - suggestion_message = "" - if len(path_parts) == 1 and workdir_path is not None: - suggested_path = f"/{path_parts[0]}/{cls.default_schema(workdir_path)}" - suggested_uri = urlunsplit( - [url.scheme, url.netloc, suggested_path, url.query, ""] - ) - suggestion_message = f"\nFor example: {suggested_uri}" - - if len(path_parts) != 2: - raise click.UsageError( - "Invalid mssql URI - SQL Server working copy requires both dbname and schema:\n" - "Expecting URI in form: mssql://[HOST]/DBNAME/SCHEMA" - + suggestion_message - ) - - @classmethod - def normalise_path(cls, repo, path): - return path - - @classmethod - def default_schema(cls, workdir_path): - # TODO - promote to superclass - stem = workdir_path.stem - schema = re.sub("[^a-z0-9]+", "_", stem.lower()) + "_sno" - if schema[0].isdigit(): - schema = "_" + schema - return schema - def __str__(self): p = urlsplit(self.uri) if p.password is not None: @@ -131,38 +66,6 @@ def __str__(self): p._replace(netloc=nl) return p.geturl() - @contextlib.contextmanager - def session(self, bulk=0): - """ - Context manager for GeoPackage DB sessions, yields a connection object inside a transaction - - Calling again yields the _same_ connection, the transaction/etc only happen in the outer one. 
- """ - L = logging.getLogger(f"{self.__class__.__qualname__}.session") - - if hasattr(self, "_session"): - # inner - reuse - L.debug("session: existing...") - yield self._session - L.debug("session: existing/done") - - else: - L.debug("session: new...") - - try: - self._session = self.sessionmaker() - - # TODO - use tidier syntax for opening transactions from sqlalchemy. - yield self._session - self._session.commit() - except Exception: - self._session.rollback() - raise - finally: - self._session.close() - del self._session - L.debug("session: new/done") - def is_created(self): """ Returns true if the db schema referred to by this working copy exists and @@ -180,7 +83,7 @@ def is_created(self): def is_initialised(self): """ - Returns true if the postgis working copy is initialised - + Returns true if the SQL server working copy is initialised - the schema exists and has the necessary sno tables, _sno_state and _sno_track. """ with self.session() as sess: @@ -196,7 +99,7 @@ def is_initialised(self): def has_data(self): """ - Returns true if the postgis working copy seems to have user-created content already. + Returns true if the SQL server working copy seems to have user-created content already. """ with self.session() as sess: count = sess.scalar( @@ -245,15 +148,15 @@ def _create_table_for_dataset(self, sess, dataset): def _table_def_for_column_schema(self, col, dataset): if col.data_type == "geometry": - # This user-defined Geography type adapts WKB to SQL Server's native geography type. crs_name = col.extra_type_info.get("geometryCRS", None) crs_id = crs_util.get_identifier_int_from_dataset(dataset, crs_name) or 0 - return sqlalchemy.column(col.name, GeographyType(crs_id)) + # This user-defined GeometryType adapts Sno's GPKG geometry to SQL Server's native geometry type. 
+ return sa.column(col.name, GeometryType(crs_id)) elif col.data_type in ("date", "time", "timestamp"): - return sqlalchemy.column(col.name, BaseDateOrTimeType) + return sa.column(col.name, BaseDateOrTimeType) else: # Don't need to specify type information for other columns at present, since we just pass through the values. - return sqlalchemy.column(col.name) + return sa.column(col.name) def _insert_or_replace_into_dataset(self, dataset): pk_col_names = [c.name for c in dataset.schema.pk_columns] @@ -282,33 +185,85 @@ def _insert_or_replace_state_table_tree(self, sess, tree_id): return r.rowcount def _write_meta(self, sess, dataset): - """Write the title (as a comment) and the CRS. Other metadata is not stored in a PostGIS WC.""" + """Write the title. Other metadata is not stored in a SQL Server WC.""" self._write_meta_title(sess, dataset) def _write_meta_title(self, sess, dataset): """Write the dataset title as a comment on the table.""" - # TODO - probably need to use sp_addextendedproperty @name=N'MS_Description' + # TODO - dataset title is not stored anywhere in SQL server working copy right now. + # We can probably store it using function sp_addextendedproperty to add property 'MS_Description' pass def delete_meta(self, dataset): """Delete any metadata that is only needed by this dataset.""" - pass # There is no metadata except for the spatial_ref_sys table. + # There is no metadata stored anywhere except the table itself. 
+ pass + + def _get_geom_extent(self, sess, dataset, default=None): + """Returns the envelope around the entire dataset as (min_x, min_y, max_x, max_y).""" + geom_col = dataset.geom_column_name + r = sess.execute( + f""" + WITH _E AS ( + SELECT geometry::EnvelopeAggregate({self.quote(geom_col)}) AS envelope + FROM {self.table_identifier(dataset)} + ) + SELECT + envelope.STPointN(1).STX AS min_x, + envelope.STPointN(1).STY AS min_y, + envelope.STPointN(3).STX AS max_x, + envelope.STPointN(3).STY AS max_y + FROM _E; + """ + ) + result = r.fetchone() + return default if result == (None, None, None, None) else result + + def _grow_rectangle(self, rectangle, scale_factor): + # scale_factor = 1 -> no change, >1 -> grow, <1 -> shrink. + min_x, min_y, max_x, max_y = rectangle + centre_x, centre_y = (min_x + max_x) / 2, (min_y + max_y) / 2 + min_x = (min_x - centre_x) * scale_factor + centre_x + min_y = (min_y - centre_y) * scale_factor + centre_y + max_x = (max_x - centre_x) * scale_factor + centre_x + max_y = (max_y - centre_y) * scale_factor + centre_y + return min_x, min_y, max_x, max_y + + def _create_spatial_index_post(self, sess, dataset): + # Only implementing _create_spatial_index_post: + # We need to know the rough extent of the data to create an index in that area, + # so we create the spatial index once the bulk of the features have been written. - def _create_spatial_index(self, sess, dataset): L = logging.getLogger(f"{self.__class__.__qualname__}._create_spatial_index") + extent = self._get_geom_extent(sess, dataset) + if not extent: + # Can't create a spatial index if we don't know the rough bounding box we need to index. + return + + # Add 20% room to grow. 
+ GROW_FACTOR = 1.2 + min_x, min_y, max_x, max_y = self._grow_rectangle(extent, GROW_FACTOR) + geom_col = dataset.geom_column_name + index_name = f"{dataset.table_name}_idx_{geom_col}" - # Create the SQL Server Spatial Index L.debug("Creating spatial index for %s.%s", dataset.table_name, geom_col) t0 = time.monotonic() - index_name = f"{dataset.table_name}_idx_{geom_col}" - sess.execute( + + create_index = sa.text( f""" CREATE SPATIAL INDEX {self.quote(index_name)} - ON {self.table_identifier(dataset)} ({self.quote(geom_col)}); + ON {self.table_identifier(dataset)} ({self.quote(geom_col)}) + WITH (BOUNDING_BOX = (:min_x, :min_y, :max_x, :max_y)) """ - ) + ).bindparams(min_x=min_x, min_y=min_y, max_x=max_x, max_y=max_y) + # Placeholders not allowed in CREATE SPATIAL INDEX - have to use literal_binds. + # See https://docs.sqlalchemy.org/en/13/faq/sqlexpressions.html#faq-sql-expression-string + create_index.compile( + sess.connection(), compile_kwargs={"literal_binds": True} + ).execute() + L.info("Created spatial index in %ss", time.monotonic() - t0) def _drop_spatial_index(self, sess, dataset): @@ -321,24 +276,26 @@ def _quoted_trigger_name(self, dataset): def _create_triggers(self, sess, dataset): pk_name = dataset.primary_key - escaped_table_name = dataset.table_name.replace("'", "''") - - sess.execute( + create_trigger = sa.text( f""" CREATE TRIGGER {self._quoted_trigger_name(dataset)} ON {self.table_identifier(dataset)} AFTER INSERT, UPDATE, DELETE AS BEGIN MERGE {self.SNO_TRACK} TRA USING - (SELECT '{escaped_table_name}', {self.quote(pk_name)} FROM inserted - UNION SELECT '{escaped_table_name}', {self.quote(pk_name)} FROM deleted) + (SELECT :table_name, {self.quote(pk_name)} FROM inserted + UNION SELECT :table_name, {self.quote(pk_name)} FROM deleted) AS SRC (table_name, pk) ON SRC.table_name = TRA.table_name AND SRC.pk = TRA.pk WHEN NOT MATCHED THEN INSERT (table_name, pk) VALUES (SRC.table_name, SRC.pk); END; - """, - {"table_name": dataset.table_name}, - ) + 
""" + ).bindparams(table_name=dataset.table_name) + # Placeholders not allowed in CREATE TRIGGER - have to use literal_binds. + # See https://docs.sqlalchemy.org/en/13/faq/sqlexpressions.html#faq-sql-expression-string + create_trigger.compile( + sess.connection(), compile_kwargs={"literal_binds": True} + ).execute() @contextlib.contextmanager def _suspend_triggers(self, sess, dataset): @@ -352,24 +309,6 @@ def _suspend_triggers(self, sess, dataset): def meta_items(self, dataset): with self.session() as sess: - table_info_sql = """ - SELECT - C.column_name, C.ordinal_position, C.data_type, C.udt_name, - C.character_maximum_length, C.numeric_precision, C.numeric_scale, - KCU.ordinal_position AS pk_ordinal_position, - upper(postgis_typmod_type(A.atttypmod)) AS geometry_type, - postgis_typmod_srid(A.atttypmod) AS geometry_srid - FROM information_schema.columns C - LEFT OUTER JOIN information_schema.key_column_usage KCU - ON (KCU.table_schema = C.table_schema) - AND (KCU.table_name = C.table_name) - AND (KCU.column_name = C.column_name) - LEFT OUTER JOIN pg_attribute A - ON (A.attname = C.column_name) - AND (A.attrelid = (C.table_schema || '.' || C.table_name)::regclass::oid) - WHERE C.table_schema=:table_schema AND C.table_name=:table_name - ORDER BY C.ordinal_position; - """ table_info_sql = """ SELECT C.column_name, C.ordinal_position, C.data_type, @@ -405,10 +344,6 @@ def try_align_schema_col(cls, old_col_dict, new_col_dict): old_type = old_col_dict["dataType"] new_type = new_col_dict["dataType"] - # Some types have to be approximated as other types in SQL Server. - if sqlserver_adapter.APPROXIMATED_TYPES.get(old_type) == new_type: - new_col_dict["dataType"] = new_type = old_type - # Geometry type loses its extra type info when roundtripped through SQL Server. 
if new_type == "geometry": new_col_dict["geometryType"] = old_col_dict.get("geometryType") @@ -419,7 +354,7 @@ def try_align_schema_col(cls, old_col_dict, new_col_dict): def _remove_hidden_meta_diffs(self, dataset, ds_meta_items, wc_meta_items): super()._remove_hidden_meta_diffs(dataset, ds_meta_items, wc_meta_items) - # Nowhere to put these in postgis WC + # Nowhere to put these in SQL Server WC for key in self._UNSUPPORTED_META_ITEMS: if key in ds_meta_items: del ds_meta_items[key] @@ -447,15 +382,13 @@ class InstanceFunction(Function): >>> function(element) """ - pass - @compiles(InstanceFunction) def compile_instance_function(element, compiler, **kw): return "(%s).%s()" % (element.clauses, element.name) -class GeographyType(UserDefinedType): +class GeometryType(UserDefinedType): """UserDefinedType so that V2 geometry is adapted to MS binary format.""" def __init__(self, crs_id): @@ -468,7 +401,7 @@ def bind_processor(self, dialect): def bind_expression(self, bindvalue): # 2. Writing - SQL layer - wrap in call to STGeomFromWKB to convert WKB to MS binary. return Function( - quoted_name("geography::STGeomFromWKB", False), + quoted_name("geometry::STGeomFromWKB", False), bindvalue, self.crs_id, type_=self, @@ -484,10 +417,14 @@ def result_processor(self, dialect, coltype): class BaseDateOrTimeType(UserDefinedType): - """UserDefinedType so we read dates, times, and datetimes as text.""" + """ + UserDefinedType so we read dates, times, and datetimes as text. + They are stored as date / time / datetime in SQL Server, but read back out as text. + """ def column_expression(self, col): # When reading, convert dates and times to strings using style 127: ISO8601 with time zone Z. 
+ # https://docs.microsoft.com/en-us/sql/t-sql/functions/cast-and-convert-transact-sql return Function( "CONVERT", literal_column("NVARCHAR"), diff --git a/sno/working_copy/sqlserver_adapter.py b/sno/working_copy/sqlserver_adapter.py index de64fe778..86f10b329 100644 --- a/sno/working_copy/sqlserver_adapter.py +++ b/sno/working_copy/sqlserver_adapter.py @@ -1,4 +1,3 @@ -from sno import crs_util from sno.schema import Schema, ColumnSchema from sqlalchemy.sql.compiler import IdentifierPreparer @@ -17,7 +16,7 @@ def quote(ident): "blob": "varbinary", "date": "date", "float": {0: "real", 32: "real", 64: "float"}, - "geometry": "geography", + "geometry": "geometry", "integer": { 0: "int", 8: "tinyint", @@ -25,7 +24,7 @@ def quote(ident): 32: "int", 64: "bigint", }, - "interval": "nvarchar", + "interval": "text", "numeric": "numeric", "text": "nvarchar", "time": "time", @@ -40,6 +39,8 @@ def quote(ident): "bigint": ("integer", 64), "real": ("float", 32), "float": ("float", 64), + "binary": "blob", + "char": "text", "date": "date", "datetime": "timestamp", "datetime2": "timestamp", @@ -47,9 +48,10 @@ def quote(ident): "decimal": "numeric", "geography": "geometry", "geometry": "geometry", - "interval": "interval", + "nchar": "text", "numeric": "numeric", "nvarchar": "text", + "ntext": "text", "text": "text", "time": "time", "varchar": "text", @@ -91,13 +93,6 @@ def v2_type_to_ms_type(column_schema, v2_obj): return ms_type_info.get(extra_type_info.get("size", 0)) ms_type = ms_type_info - if ms_type == "geometry": - geometry_type = extra_type_info.get("geometryType") - crs_name = extra_type_info.get("geometryCRS") - crs_id = None - if crs_name is not None: - crs_id = crs_util.get_identifier_int_from_dataset(v2_obj, crs_name) - return _v2_geometry_type_to_ms_type(geometry_type, crs_id) if ms_type in ("varchar", "nvarchar", "varbinary"): length = extra_type_info.get("length", None) @@ -116,33 +111,19 @@ def v2_type_to_ms_type(column_schema, v2_obj): return ms_type -def 
_v2_geometry_type_to_ms_type(geometry_type, crs_id): - if geometry_type is not None: - geometry_type = geometry_type.replace(" ", "") - - if geometry_type is not None and crs_id is not None: - return f"geometry({geometry_type},{crs_id})" - elif geometry_type is not None: - return f"geometry({geometry_type})" - else: - return "geometry" - - def sqlserver_to_v2_schema(ms_table_info, id_salt): - """Generate a V2 schema from the given postgis metadata tables.""" + """Generate a V2 schema from the given SQL server metadata.""" return Schema([_sqlserver_to_column_schema(col, id_salt) for col in ms_table_info]) def _sqlserver_to_column_schema(ms_col_info, id_salt): """ - Given the postgis column info for a particular column, and some extra context in - case it is a geometry column, converts it to a ColumnSchema. The extra context will - only be used if the given ms_col_info is the geometry column. + Given the MS column info for a particular column, converts it to a ColumnSchema. + Parameters: ms_col_info - info about a single column from ms_table_info. - ms_spatial_ref_sys - rows of the "spatial_ref_sys" table that are referenced by this dataset. id_salt - the UUIDs of the generated ColumnSchema are deterministic and depend on - the name and type of the column, and on this salt. + the name and type of the column, and on this salt. """ name = ms_col_info["column_name"] pk_index = ms_col_info["pk_ordinal_position"] diff --git a/sno/working_copy/table_defs.py b/sno/working_copy/table_defs.py index 5f8fbd2dc..ef500911c 100644 --- a/sno/working_copy/table_defs.py +++ b/sno/working_copy/table_defs.py @@ -104,9 +104,11 @@ def create_all(self, session): class SqlServerSnoTables(TableSet): """ - Tables for sno-specific metadata - PostGIS variant. + Tables for sno-specific metadata - SQL Server variant. Table names have a user-defined schema, and so unlike other table sets, we need to construct an instance with the appropriate schema. 
+ Primary keys have to be NVARCHAR of a fixed maximum length - + if the total maximum length is too long, SQL Server cannot generate an index. """ def __init__(self, schema=None): diff --git a/tests/conftest.py b/tests/conftest.py index 014dbda6c..47cdf6deb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -930,8 +930,8 @@ def ctx(create=False): def sqlserver_db(): """ Using docker, you can run a SQL Server test - such as those in test_working_copy_sqlserver - as follows: - docker run -it --rm -d -p 11433:1433 -e ACCEPT_EULA=Y -e 'SA_PASSWORD=Sql(server)' mcr.microsoft.com/mssql/server - SNO_SQLSERVER_URL='mssql://sa:Sql(server)@127.0.0.1:11433/master' spytest -k sqlserver --pdb -vvs + docker run -it --rm -d -p 11433:1433 -e ACCEPT_EULA=Y -e 'SA_PASSWORD=PassWord1' mcr.microsoft.com/mssql/server + SNO_SQLSERVER_URL='mssql://sa:PassWord1@127.0.0.1:11433/master' pytest -k sqlserver --pdb -vvs """ if "SNO_SQLSERVER_URL" not in os.environ: raise pytest.skip( @@ -939,7 +939,7 @@ def sqlserver_db(): ) engine = sqlserver_engine(os.environ["SNO_SQLSERVER_URL"]) with engine.connect() as conn: - # test connection and postgis support + # Test connection try: conn.execute("SELECT @@version;") except sqlalchemy.exc.DBAPIError: diff --git a/tests/test_working_copy_sqlserver.py b/tests/test_working_copy_sqlserver.py index 2853f2263..aa5e5f6a0 100644 --- a/tests/test_working_copy_sqlserver.py +++ b/tests/test_working_copy_sqlserver.py @@ -308,6 +308,6 @@ def test_types_roundtrip(data_archive, cli_runner, new_sqlserver_db_schema): r = cli_runner.invoke(["checkout"]) # If type-approximation roundtrip code isn't working, - # we would get spurious diffs on types that PostGIS doesn't support. + # we would get spurious diffs on types that SQL server doesn't support. r = cli_runner.invoke(["diff", "--exit-code"]) assert r.exit_code == 0, r.stdout