Add support for SQL Server working copy
olsen232 committed Mar 4, 2021
1 parent afa7eeb commit d0c07d0
Showing 14 changed files with 1,364 additions and 114 deletions.
2 changes: 1 addition & 1 deletion sno/checkout.py
@@ -20,7 +20,7 @@

def reset_wc_if_needed(repo, target_tree_or_commit, *, discard_changes=False):
"""Resets the working copy to the target if it does not already match, or if discard_changes is True."""
-working_copy = WorkingCopy.get(repo, allow_uncreated=True)
+working_copy = WorkingCopy.get(repo, allow_uncreated=True, allow_invalid_state=True)
if working_copy is None:
click.echo(
"(Bare sno repository - to create a working copy, use `sno create-workingcopy`)"
15 changes: 13 additions & 2 deletions sno/geometry.py
@@ -57,6 +57,11 @@ def with_crs_id(self, crs_id):
crs_id_bytes = struct.pack("<i", crs_id)
return Geometry.of(self[:4] + crs_id_bytes + self[8:])

+@property
+def crs_id(self):
+wkb_offset, is_le, crs_id = parse_gpkg_geom(self)
+return crs_id

@classmethod
def from_wkt(cls, wkt):
return wkt_to_gpkg_geom(wkt)
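Note: the new crs_id property reads the SRS ID straight from the GeoPackage binary header - the same bytes 4-8 that with_crs_id overwrites above. A minimal sketch of that parse, assuming a well-formed GPKG blob (the repo's own parse_gpkg_geom also returns the WKB offset and byte order):

import struct

def parse_crs_id(gpkg_geom):
    # GPKG binary header: magic "GP", version byte, flags byte, then srs_id.
    assert gpkg_geom[:2] == b"GP"
    is_le = bool(gpkg_geom[3] & 0x01)  # flags bit 0: byte order of header ints
    (crs_id,) = struct.unpack("<i" if is_le else ">i", gpkg_geom[4:8])
    return crs_id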
@@ -290,12 +295,18 @@ def gpkg_geom_to_ogr(gpkg_geom, parse_crs=False):
return geom


-def wkt_to_gpkg_geom(wkb, **kwargs):
-ogr_geom = ogr.CreateGeometryFromWkt(wkb)
+def wkt_to_gpkg_geom(wkt, **kwargs):
+if wkt is None:
+return None

+ogr_geom = ogr.CreateGeometryFromWkt(wkt)
return ogr_to_gpkg_geom(ogr_geom, **kwargs)


def wkb_to_gpkg_geom(wkb, **kwargs):
+if wkb is None:
+return None

ogr_geom = ogr.CreateGeometryFromWkb(wkb)
return ogr_to_gpkg_geom(ogr_geom, **kwargs)
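Note: with the None guards above, nullable geometry values round-trip without special-casing at the call site. Illustrative usage (the WKT value is made up):

blob = wkt_to_gpkg_geom("POINT (174.78 -41.29)")  # GPKG geometry blob
assert wkt_to_gpkg_geom(None) is None  # NULL geometries pass straight through
assert wkb_to_gpkg_geom(None) is None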

49 changes: 49 additions & 0 deletions sno/sqlalchemy.py
@@ -1,3 +1,8 @@
+import re
+import socket
+from urllib.parse import urlsplit, urlunsplit


import sqlalchemy
from pysqlite3 import dbapi2 as sqlite
import psycopg2
@@ -95,7 +100,51 @@ def _on_connect(psycopg2_conn, connection_record):
geometry_type = new_type((r[0],), "GEOMETRY", _adapt_geometry_from_pg)
register_type(geometry_type, psycopg2_conn)

+pgurl = _append_query_to_url(pgurl, {"fallback_application_name": "sno"})

engine = sqlalchemy.create_engine(pgurl, module=psycopg2)
sqlalchemy.event.listen(engine, "connect", _on_connect)

return engine
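Note: fallback_application_name labels sno's connections (e.g. in pg_stat_activity) without overriding any application_name the user set themselves, since _append_to_query below only adds keys that aren't already present. Illustrative, with a made-up host:

pgurl = "postgresql://me@db.example.com/snodb"
# after _append_query_to_url(pgurl, {"fallback_application_name": "sno"}):
# "postgresql://me@db.example.com/snodb?fallback_application_name=sno"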


+CANONICAL_SQL_SERVER_SCHEME = "mssql"
+INTERNAL_SQL_SERVER_SCHEME = "mssql+pyodbc"
+SQL_SERVER_DRIVER_LIB = "ODBC+Driver+17+for+SQL+Server"


+def sqlserver_engine(msurl):
+url = urlsplit(msurl)
+if url.scheme != CANONICAL_SQL_SERVER_SCHEME:
+raise ValueError("Expecting mssql://")

+# SQL Server driver is fussy - doesn't like localhost, prefers 127.0.0.1
+url_netloc = re.sub(r"\blocalhost\b", _replace_with_localhost, url.netloc)

+url_query = _append_to_query(
+url.query, {"driver": SQL_SERVER_DRIVER_LIB, "Application+Name": "sno"}
+)

+msurl = urlunsplit(
+[INTERNAL_SQL_SERVER_SCHEME, url_netloc, url.path, url_query, ""]
+)

+engine = sqlalchemy.create_engine(msurl)
+return engine
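Note: end to end, sqlserver_engine rewrites the canonical URL a user supplies into the pyodbc form SQLAlchemy expects. An illustrative transformation (credentials made up; localhost typically resolves to 127.0.0.1):

engine = sqlserver_engine("mssql://me:secret@localhost:1433/snodb")
# internally becomes:
# "mssql+pyodbc://me:secret@127.0.0.1:1433/snodb?driver=ODBC+Driver+17+for+SQL+Server&Application+Name=sno"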


+def _replace_with_localhost(*args, **kwargs):
+return socket.gethostbyname("localhost")


+def _append_query_to_url(uri, query_dict):
+url = urlsplit(uri)
+url_query = _append_to_query(url.query, query_dict)
+return urlunsplit([url.scheme, url.netloc, url.path, url_query, ""])


+def _append_to_query(url_query, query_dict):
+for key, value in query_dict.items():
+if key not in url_query:
+url_query = "&".join(filter(None, [url_query, f"{key}={value}"]))
+return url_query
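Note: _append_to_query leaves existing keys alone, so user-supplied settings always win over sno's defaults. The `key not in url_query` test is a plain substring check on the raw query string, not a parse - cheap, though a key appearing inside another key or value would also suppress the append. Illustrative behaviour:

_append_to_query("", {"driver": "X"})
# -> "driver=X"
_append_to_query("driver=FreeTDS", {"driver": "X", "Application+Name": "sno"})
# -> "driver=FreeTDS&Application+Name=sno"  (existing driver key preserved)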
116 changes: 58 additions & 58 deletions sno/working_copy/base.py
@@ -9,8 +9,6 @@
import click
import pygit2
import sqlalchemy
-from sqlalchemy.dialects.postgresql import insert as postgresql_insert


from sno.base_dataset import BaseDataset
from sno.diff_structs import RepoDiff, DatasetDiff, DeltaDiff, Delta
@@ -39,6 +37,8 @@ def from_path(cls, path, allow_invalid=False):
path = str(path)
if path.startswith("postgresql:"):
return WorkingCopyType.POSTGIS
elif path.startswith("mssql:"):
return WorkingCopyType.SQL_SERVER
elif path.lower().endswith(".gpkg"):
return WorkingCopyType.GPKG
elif allow_invalid:
@@ -61,6 +61,10 @@ def class_(self):
from .postgis import WorkingCopy_Postgis

return WorkingCopy_Postgis
+elif self is WorkingCopyType.SQL_SERVER:
+from .sqlserver import WorkingCopy_SqlServer

+return WorkingCopy_SqlServer
raise RuntimeError("Invalid WorkingCopyType")
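Note: resolution happens in two steps - the path prefix picks the enum member, then class_ imports the backend module lazily, so SQL Server dependencies are only loaded when such a working copy is actually used. Illustrative usage, assuming class_ is a property as the body suggests:

wc_type = WorkingCopyType.from_path("mssql://server/snodb")  # WorkingCopyType.SQL_SERVER
wc_class = wc_type.class_  # imports sno.working_copy.sqlserver on first use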


@@ -149,32 +153,36 @@ def table_identifier(self, dataset_or_table):
)
return self.preparer.format_table(sqlalchemy_table)

+def _table_def_for_column_schema(self, col, dataset):
+# This is just used for selects/inserts/updates - we don't need the full type information.
+# We only need type information if some automatic conversion needs to happen on read or write.
+# TODO: return the full type information so we can also use it for CREATE TABLE.
+return sqlalchemy.column(col.name)

@functools.lru_cache()
-def table_def_for_dataset(self, dataset):
+def _table_def_for_dataset(self, dataset, schema=None):
"""
A minimal table definition for a dataset.
Specifies table and column names only - the minimum required such that an insert() or update() will work.
"""
+schema = schema or dataset.schema
return sqlalchemy.table(
dataset.table_name,
-*[sqlalchemy.column(c.name) for c in dataset.schema.columns],
+*[self._table_def_for_column_schema(c, dataset) for c in schema],
schema=self.db_schema,
)
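Note: the new hook means a backend can hand back a typed column wherever values need converting on read or write - presumably why the per-backend _db_geom_to_gpkg_geom hook further down is removed. A sketch of how a subclass might use it; GeometryType and gpkg_geom_from_db here are hypothetical, not the actual SQL Server adapter:

import sqlalchemy
from sqlalchemy.types import UserDefinedType

class GeometryType(UserDefinedType):
    # Hypothetical type: converts DB-native geometry values to sno Geometry blobs on read.
    def get_col_spec(self):
        return "GEOMETRY"

    def result_processor(self, dialect, coltype):
        return lambda value: gpkg_geom_from_db(value)  # assumed conversion helper

def _table_def_for_column_schema(self, col, dataset):
    if col.data_type == "geometry":
        return sqlalchemy.column(col.name, GeometryType())
    return sqlalchemy.column(col.name)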

-def insert_into_dataset(self, dataset):
+def _insert_into_dataset(self, dataset):
"""Returns a SQL command for inserting features into the table for that dataset."""
-return self.table_def_for_dataset(dataset).insert()
+return self._table_def_for_dataset(dataset).insert()

-def insert_or_replace_into_dataset(self, dataset):
+def _insert_or_replace_into_dataset(self, dataset):
"""
Returns a SQL command for inserting/replacing features that may or may not already exist in the table
for that dataset.
"""
-# Even though this uses the postgresql_insert sqlalchemy feature, it results in generic SQL.
-pk_col_names = [c.name for c in dataset.schema.pk_columns]
-stmt = postgresql_insert(self.table_def_for_dataset(dataset))
-update_dict = {c.name: c for c in stmt.excluded if c.name not in pk_col_names}
-return stmt.on_conflict_do_update(index_elements=pk_col_names, set_=update_dict)
+# It's not possible to do this in a DB-agnostic way.
+raise NotImplementedError()
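Note: each backend now supplies its own upsert. A PostgreSQL-style override can keep what this base class used to do - essentially the removed INSERT ... ON CONFLICT construction, updated for the renamed helper - while SQL Server needs a different statement such as MERGE:

from sqlalchemy.dialects.postgresql import insert as postgresql_insert

def _insert_or_replace_into_dataset(self, dataset):
    pk_col_names = [c.name for c in dataset.schema.pk_columns]
    stmt = postgresql_insert(self._table_def_for_dataset(dataset))
    update_dict = {c.name: c for c in stmt.excluded if c.name not in pk_col_names}
    return stmt.on_conflict_do_update(index_elements=pk_col_names, set_=update_dict)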

@classmethod
def get(cls, repo, *, allow_uncreated=False, allow_invalid_state=False):
@@ -342,17 +350,16 @@ def delete(self, keep_db_schema_if_possible=False):
"""
raise NotImplementedError()

-def get_db_tree(self, table_name="*"):
+def get_db_tree(self):
"""Returns the hex tree ID from the state table."""
with self.session() as sess:
return sess.scalar(
-f"""SELECT value FROM {self.SNO_STATE} WHERE table_name=:table_name AND key='tree';""",
-{"table_name": table_name},
+f"""SELECT value FROM {self.SNO_STATE} WHERE "table_name"='*' AND "key"='tree';""",
)

-def assert_db_tree_match(self, tree, *, table_name="*"):
+def assert_db_tree_match(self, tree):
"""Raises a Mismatch if sno_state refers to a different tree than the given tree."""
-wc_tree_id = self.get_db_tree(table_name)
+wc_tree_id = self.get_db_tree()
expected_tree_id = tree.id.hex if isinstance(tree, pygit2.Tree) else tree

if wc_tree_id != expected_tree_id:
@@ -499,18 +506,13 @@ def diff_db_to_tree_feature(
feature_diff = DeltaDiff()
insert_count = delete_count = 0

-geom_col = dataset.geom_column_name

for row in r:
track_pk = row[0] # This is always a str
db_obj = {k: row[k] for k in row.keys() if k != ".__track_pk"}

if db_obj[pk_field] is None:
db_obj = None

-if db_obj is not None and geom_col is not None:
-db_obj[geom_col] = self._db_geom_to_gpkg_geom(db_obj[geom_col])

try:
repo_obj = dataset.get_feature(track_pk)
except KeyError:
@@ -560,29 +562,32 @@ def _execute_dirty_rows_query(
else:
schema = dataset.schema

-table_identifier = self.table_identifier(dataset)
-pk_column = self.quote(schema.pk_columns[0].name)
-col_names = ",".join(
-[f"{table_identifier}.{self.quote(col.name)}" for col in schema]
-)
+sno_track = self.sno_tables.sno_track
+table = self._table_def_for_dataset(dataset, schema=schema)

-sql = f"""
-SELECT
-{self.SNO_TRACK}.pk AS ".__track_pk",
-{col_names}
-FROM {self.SNO_TRACK} LEFT OUTER JOIN {table_identifier}
-ON ({self.SNO_TRACK}.pk = CAST({table_identifier}.{pk_column} AS TEXT))
-WHERE ({self.SNO_TRACK}.table_name = :table_name)
-"""
+cols_to_select = [sno_track.c.pk.label(".__track_pk"), *table.columns]
+pk_column = table.columns[schema.pk_columns[0].name]
+tracking_col_type = sno_track.c.pk.type

+base_query = sqlalchemy.select(columns=cols_to_select).select_from(
+sno_track.outerjoin(
+table,
+sno_track.c.pk == sqlalchemy.cast(pk_column, tracking_col_type),
+)
+)

if feature_filter is UNFILTERED:
-return sess.execute(sql, {"table_name": dataset.table_name})
+query = base_query.where(sno_track.c.table_name == dataset.table_name)
else:
pks = list(feature_filter)
-sql += f" AND {self.SNO_TRACK}.pk IN :pks"
-t = sqlalchemy.text(sql)
-t = t.bindparams(sqlalchemy.bindparam("pks", expanding=True))
-return sess.execute(t, {"table_name": dataset.table_name, "pks": pks})
+query = base_query.where(
+sqlalchemy.and_(
+sno_track.c.table_name == dataset.table_name,
+sno_track.c.pk.in_(pks),
+)
+)

+return sess.execute(query)
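Note: building the query with SQLAlchemy Core rather than an f-string lets each dialect do its own identifier quoting, and the CAST target now comes from the tracking table's actual pk column type instead of a hard-coded TEXT. For a table with pk fid, the unfiltered query renders roughly as follows (illustrative, PostgreSQL-flavoured):

# SELECT sno_track.pk AS ".__track_pk", my_table.fid, my_table.geom, ...
# FROM sno_track
# LEFT OUTER JOIN my_schema.my_table
#   ON sno_track.pk = CAST(my_table.fid AS VARCHAR)
# WHERE sno_track.table_name = :table_name_1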

def reset_tracking_table(self, reset_filter=UNFILTERED):
"""Delete the rows from the tracking table that match the given filter."""
Expand Down Expand Up @@ -612,10 +617,6 @@ def reset_tracking_table(self, reset_filter=UNFILTERED):
t = t.bindparams(sqlalchemy.bindparam("pks", expanding=True))
sess.execute(t, {"table_name": table_name, "pks": pks})

-def _db_geom_to_gpkg_geom(self, g):
-"""Convert a geometry as returned by the database to a sno geometry.Geometry object."""
-raise NotImplementedError()

def can_find_renames(self, meta_diff):
"""Can we find a renamed (aka moved) feature? There's no point looking for renames if the schema has changed."""
if "schema.json" not in meta_diff:
@@ -665,18 +666,21 @@ def update_state_table_tree(self, tree):
tree_id = tree.id.hex if isinstance(tree, pygit2.Tree) else tree
L.info(f"Tree sha: {tree_id}")
with self.session() as sess:
-changes = self._update_state_table_tree_impl(sess, tree_id)
+changes = self._insert_or_replace_state_table_tree(sess, tree_id)
assert changes == 1, f"{self.SNO_STATE} update: expected 1Δ, got {changes}"

-def _update_state_table_tree_impl(self, sess, tree_id):
+def _insert_or_replace_state_table_tree(self, sess, tree_id):
"""
Write the given tree ID to the state table.
sess - sqlalchemy session.
tree_id - str, the hex SHA of the tree at HEAD.
"""
r = sess.execute(
f"UPDATE {self.SNO_STATE} SET value=:value WHERE table_name='*' AND key='tree';",
f"""
INSERT INTO {self.SNO_STATE} ("table_name", "key", "value") VALUES ('*', 'tree', :value)
ON CONFLICT ("table_name", "key") DO UPDATE SET value=EXCLUDED.value;
""",
{"value": tree_id},
)
return r.rowcount
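Note: INSERT ... ON CONFLICT DO UPDATE is understood by SQLite and PostgreSQL but not by SQL Server, so the SQL Server working copy presumably overrides _insert_or_replace_state_table_tree with its own statement - hypothetically something like:

# MERGE sno_state AS t
# USING (VALUES ('*', 'tree', :value)) AS s ("table_name", "key", "value")
#   ON t."table_name" = s."table_name" AND t."key" = s."key"
# WHEN MATCHED THEN UPDATE SET "value" = s."value"
# WHEN NOT MATCHED THEN INSERT ("table_name", "key", "value")
#   VALUES (s."table_name", s."key", s."value");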
@@ -701,7 +705,7 @@ def write_full(self, commit, *datasets, **kwargs):
self._create_spatial_index(sess, dataset)

L.info("Creating features...")
-sql = self.insert_into_dataset(dataset)
+sql = self._insert_into_dataset(dataset)
feat_progress = 0
t0 = time.monotonic()
t0p = t0
Expand Down Expand Up @@ -738,12 +742,8 @@ def write_full(self, commit, *datasets, **kwargs):
self._create_triggers(sess, dataset)
self._update_last_write_time(sess, dataset, commit)

-sess.execute(
-f"""
-INSERT INTO {self.SNO_STATE} (table_name, key, value) VALUES ('*', 'tree', :value)
-ON CONFLICT (table_name, key) DO UPDATE SET value=EXCLUDED.value;
-""",
-{"value": commit.peel(pygit2.Tree).hex},
+self._insert_or_replace_state_table_tree(
+sess, commit.peel(pygit2.Tree).id.hex
)

def _create_table_for_dataset(self, sess, dataset):
@@ -776,15 +776,15 @@ def _write_features(self, sess, dataset, pk_list, *, ignore_missing=False):
if not pk_list:
return 0

-sql = self.insert_or_replace_into_dataset(dataset)
+sql = self._insert_or_replace_into_dataset(dataset)
feat_count = 0
CHUNK_SIZE = 10000
for row_dicts in self._chunk(
dataset.get_features_with_crs_ids(pk_list, ignore_missing=ignore_missing),
CHUNK_SIZE,
):
-r = sess.execute(sql, row_dicts)
-feat_count += r.rowcount
+sess.execute(sql, row_dicts)
+feat_count += len(row_dicts)

return feat_count
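Note: counting len(row_dicts) instead of reading r.rowcount sidesteps drivers where executemany reports an unreliable rowcount (pyodbc, for instance, can return -1). _chunk itself isn't shown in this diff; a typical shape for it, assuming it just batches an iterable into lists:

import itertools

def _chunk(self, iterable, size):
    # Yield successive lists of up to `size` items (a sketch, not the actual implementation).
    it = iter(iterable)
    while True:
        chunk = list(itertools.islice(it, size))
        if not chunk:
            return
        yield chunk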

@@ -953,7 +953,7 @@ def reset(

if not track_changes_as_dirty:
# update the tree id
-self._update_state_table_tree_impl(sess, target_tree_id)
+self._insert_or_replace_state_table_tree(sess, target_tree_id)

def _filter_by_paths(self, datasets, paths):
"""Filters the datasets so that only those matching the paths are returned."""
(Diff for the remaining 10 changed files not shown.)
