diff --git a/aequilibrae/project/about.py b/aequilibrae/project/about.py index f7ffd1471..64c7280e9 100644 --- a/aequilibrae/project/about.py +++ b/aequilibrae/project/about.py @@ -1,7 +1,10 @@ -from os.path import join, dirname, realpath import string import uuid +from os.path import join, dirname, realpath + from aequilibrae.project.project_creation import run_queries_from_sql_file +from aequilibrae.utils.db_utils import commit_and_close +from aequilibrae.utils.spatialite_utils import connect_spatialite class About: @@ -27,29 +30,28 @@ class About: def __init__(self, project): self.__characteristics = [] self.__original = {} - self.__conn = project.conn + self.__path_to_file = project.path_to_file self.logger = project.logger - if self.__has_about(): - self.__load() + + with commit_and_close(connect_spatialite(self.__path_to_file)) as conn: + if self.__has_about(conn): + self.__load(conn) def create(self): """Creates the 'about' table for project files that did not previously contain it""" - if not self.__has_about(): - qry_file = join(dirname(realpath(__file__)), "database_specification", "tables", "about.sql") - run_queries_from_sql_file(self.__conn, qry_file) + with commit_and_close(connect_spatialite(self.__path_to_file)) as conn: + if not self.__has_about(conn): + qry_file = join(dirname(realpath(__file__)), "database_specification", "tables", "about.sql") + run_queries_from_sql_file(conn, self.logger, qry_file) - cursor = self.__conn.cursor() - - sql = "SELECT count(*) as num_records from about;" - if self.__conn.execute(sql).fetchone()[0] == 0: - cursor.execute(f"UPDATE 'about' set infovalue='{uuid.uuid4().hex}' where infoname='project_ID'") - cursor.execute("UPDATE 'about' set infovalue='right' where infoname='driving_side'") - self.__conn.commit() - - self.__load() - else: - self.logger.warning("About table already exists. Nothing was done") + sql = "SELECT count(*) as num_records from about;" + if conn.execute(sql).fetchone()[0] == 0: + conn.execute(f"UPDATE 'about' set infovalue='{uuid.uuid4().hex}' where infoname='project_ID'") + conn.execute("UPDATE 'about' set infovalue='right' where infoname='driving_side'") + self.__load(conn) + else: + self.logger.warning("About table already exists. Nothing was done") def list_fields(self) -> list: """Returns a list of all characteristics the about table holds""" @@ -78,10 +80,8 @@ def add_info_field(self, info_field: str) -> None: if has_forbidden: raise ValueError(f"{info_field} is not valid as a metadata field. Should be a lower case ascii letter or _") - sql = "INSERT INTO 'about' (infoname) VALUES(?)" - curr = self.__conn.cursor() - curr.execute(sql, [info_field]) - self.__conn.commit() + with commit_and_close(connect_spatialite(self.__path_to_file)) as conn: + conn.execute("INSERT INTO 'about' (infoname) VALUES(?)", [info_field]) self.__characteristics.append(info_field) self.__original[info_field] = None @@ -96,25 +96,21 @@ def write_back(self): >>> p.about.description = 'This is the example project. Do not use for forecast' >>> p.about.write_back() """ - curr = self.__conn.cursor() - for k in self.__characteristics: - v = self.__dict__[k] - if v != self.__original[k]: - curr.execute("UPDATE 'about' set infovalue = ? where infoname=?", [v, k]) - self.logger.info(f"Updated {k} on About_Table to {v}") - self.__conn.commit() - - def __has_about(self): - curr = self.__conn.cursor() - curr.execute("SELECT name FROM sqlite_master WHERE type='table';") - return any(["about" in x[0] for x in curr.fetchall()]) - - def __load(self): + with commit_and_close(connect_spatialite(self.__path_to_file)) as conn: + for k in self.__characteristics: + v = self.__dict__[k] + if v != self.__original[k]: + conn.execute("UPDATE 'about' set infovalue = ? where infoname=?", [v, k]) + self.logger.info(f"Updated {k} on About_Table to {v}") + + def __has_about(self, conn): + sql = "SELECT name FROM sqlite_master WHERE type='table';" + return any(["about" in x[0] for x in conn.execute(sql).fetchall()]) + + def __load(self, conn): self.__characteristics = [] - curr = self.__conn.cursor() - curr.execute("select infoname, infovalue from 'about'") - - for x in curr.fetchall(): + sql = "select infoname, infovalue from 'about'" + for x in conn.execute(sql).fetchall(): self.__characteristics.append(x[0]) self.__dict__[x[0]] = x[1] self.__original[x[0]] = x[1] diff --git a/aequilibrae/project/basic_table.py b/aequilibrae/project/basic_table.py index 26c4e3aad..9845c5154 100644 --- a/aequilibrae/project/basic_table.py +++ b/aequilibrae/project/basic_table.py @@ -2,6 +2,7 @@ from shapely.geometry import Polygon from aequilibrae.project.field_editor import FieldEditor +from aequilibrae.utils.db_utils import commit_and_close class BasicTable: @@ -12,8 +13,6 @@ class BasicTable: def __init__(self, project): self.project = project self.__table_type__ = "" - self.conn = project.connect() - self._curr = self.conn.cursor() def extent(self) -> Polygon: """Queries the extent of thelayer included in the model @@ -21,20 +20,15 @@ def extent(self) -> Polygon: Returns: *model extent* (:obj:`Polygon`): Shapely polygon with the bounding box of the layer. """ - self.__curr.execute(f'Select ST_asBinary(GetLayerExtent("{self.__table_type__}"))') - poly = shapely.wkb.loads(self.__curr.fetchone()[0]) - return poly + with commit_and_close(self.project.connect()) as conn: + data = conn.execute(f'Select ST_asBinary(GetLayerExtent("{self.__table_type__}"))').fetchone()[0] + return shapely.wkb.loads(data) @property def fields(self) -> FieldEditor: """Returns a FieldEditor class instance to edit the zones table fields and their metadata""" return FieldEditor(self.project, self.__table_type__) - def refresh_connection(self): - """Opens a new database connection to avoid thread conflict""" - self.conn = self.project.connect() - self.__curr = self.conn.cursor() - def __copy__(self): raise Exception(f"{self.__table_type__} object cannot be copied") diff --git a/aequilibrae/project/data/matrices.py b/aequilibrae/project/data/matrices.py index efd01b2f8..e2ec90e79 100644 --- a/aequilibrae/project/data/matrices.py +++ b/aequilibrae/project/data/matrices.py @@ -1,9 +1,12 @@ import os from os.path import isfile, join + import pandas as pd -from aequilibrae.project.table_loader import TableLoader + from aequilibrae.matrix import AequilibraeMatrix from aequilibrae.project.data.matrix_record import MatrixRecord +from aequilibrae.project.table_loader import TableLoader +from aequilibrae.utils.db_utils import commit_and_close class Matrices: @@ -17,10 +20,9 @@ def __init__(self, project): self.fldr = os.path.join(project.project_base_path, "matrices") - self.conn = project.connect() - self.curr = self.conn.cursor() tl = TableLoader() - matrices_list = tl.load_table(self.curr, "matrices") + with commit_and_close(self.project.connect()) as conn: + matrices_list = tl.load_table(conn, "matrices") self.__fields = [x for x in tl.fields] if matrices_list: self.__properties = list(matrices_list[0].keys()) @@ -41,16 +43,16 @@ def reload(self): def clear_database(self) -> None: """Removes records from the matrices database that do not exist in disk""" - self.curr.execute("Select name, file_name from matrices;") + with commit_and_close(self.project.connect()) as conn: + mats = conn.execute("Select name, file_name from matrices;").fetchall() - remove = [nm for nm, file in self.curr.fetchall() if not isfile(join(self.fldr, file))] + remove = [nm for nm, file in mats if not isfile(join(self.fldr, file))] - if remove: - self.logger.warning(f'Matrix records not found in disk cleaned from database: {",".join(remove)}') + if remove: + self.logger.warning(f'Matrix records not found in disk cleaned from database: {",".join(remove)}') - remove = [[x] for x in remove] - self.curr.executemany("DELETE from matrices where name=?;", remove) - self.conn.commit() + remove = [[x] for x in remove] + conn.executemany("DELETE from matrices where name=?;", remove) def update_database(self) -> None: """Adds records to the matrices database for matrix files found on disk""" @@ -96,9 +98,10 @@ def check_if_exists(file_name): else: return "file missing" - df = pd.read_sql_query("Select * from matrices;", self.conn) - df = df.assign(status="") - df.status = df.file_name.apply(check_if_exists) + with commit_and_close(self.project.connect()) as conn: + df = pd.read_sql_query("Select * from matrices;", conn) + df = df.assign(status="") + df.status = df.file_name.apply(check_if_exists) return df diff --git a/aequilibrae/project/data/matrix_record.py b/aequilibrae/project/data/matrix_record.py index 71a64c76c..dec285bef 100644 --- a/aequilibrae/project/data/matrix_record.py +++ b/aequilibrae/project/data/matrix_record.py @@ -1,48 +1,47 @@ from os import unlink from os.path import isfile, join -from aequilibrae.project.network.safe_class import SafeClass + from aequilibrae.matrix.aequilibrae_matrix import AequilibraeMatrix +from aequilibrae.project.network.safe_class import SafeClass +from aequilibrae.utils.db_utils import commit_and_close class MatrixRecord(SafeClass): def __init__(self, data_set: dict, project): super().__init__(data_set, project) - self._exists = True - self.fldr = join(project.project_base_path, "matrices") + self._exists: bool + self.fldr: str + self.__dict__["_exists"] = True + self.__dict__["fldr"] = join(project.project_base_path, "matrices") def save(self): """Saves matrix record to the project database""" - conn = self.connect_db() - curr = conn.cursor() + with commit_and_close(self.connect_db()) as conn: + sql = "select count(*) from matrices where name=?" - curr.execute("select count(*) from matrices where name=?", [self.name]) - if curr.fetchone()[0] == 0: - data = [str(self.name), str(self.file_name), int(self.cores)] - curr.execute("Insert into matrices (name, file_name, cores) values(?,?,?)", data) + if conn.execute(sql, [self.name]).fetchone()[0] == 0: + data = [str(self.name), str(self.file_name), int(self.cores)] + conn.execute("Insert into matrices (name, file_name, cores) values(?,?,?)", data) - for key, value in self.__dict__.items(): - if key != "name" and key in self.__original__: - v_old = self.__original__.get(key, None) - if value != v_old and value: - self.__original__[key] = value - curr.execute(f"update matrices set '{key}'=? where name=?", [value, self.name]) - conn.commit() - conn.close() + for key, value in self.__dict__.items(): + if key != "name" and key in self.__original__: + v_old = self.__original__.get(key, None) + if value != v_old and value: + self.__original__[key] = value + conn.execute(f"update matrices set '{key}'=? where name=?", [value, self.name]) def delete(self): """Deletes this matrix record and the underlying data from disk""" - conn = self.connect_db() - curr = conn.cursor() - curr.execute("DELETE FROM matrices where name=?", [self.name]) - conn.commit() + with commit_and_close(self.connect_db()) as conn: + conn.execute("DELETE FROM matrices where name=?", [self.name]) + if isfile(join(self.fldr, self.file_name)): try: unlink(join(self.fldr, self.file_name)) except Exception as e: self._logger.error(f"Could not remove matrix from disk: {e.args}") - conn.close() - self._exists = False + self.__dict__["_exists"] = False def update_cores(self): """Updates this matrix record with the matrix core count in disk""" @@ -59,20 +58,14 @@ def get_data(self) -> AequilibraeMatrix: return mat def __setattr__(self, instance, value) -> None: - if instance == "name": - value = str(value).lower() - conn = self.connect_db() - curr = conn.cursor() - curr.execute("Select count(*) from matrices where LOWER(name)=?", [value]) - if sum(curr.fetchone()) > 0: - raise ValueError("Another matrix with this name already exists") - conn.close() - elif instance == "file_name": - conn = self.connect_db() - curr = conn.cursor() - curr.execute("Select count(*) from matrices where LOWER(file_name)=?", [str(value).lower()]) - if sum(curr.fetchone()) > 0: - raise ValueError("There is another matrix record for this file") + with commit_and_close(self.connect_db()) as conn: + sql = f"Select count(*) from matrices where LOWER({instance})=?" + qry_value = sum(conn.execute(sql, [str(value).lower()]).fetchone()) + if qry_value > 0: + if instance == "name": + raise ValueError("Another matrix with this name already exists") + elif instance == "file_name": + raise ValueError("There is another matrix record for this file") self.__dict__[instance] = value if instance in ["file_name", "cores"]: diff --git a/aequilibrae/project/data_loader.py b/aequilibrae/project/data_loader.py index cfe7bb21e..bd13d023f 100644 --- a/aequilibrae/project/data_loader.py +++ b/aequilibrae/project/data_loader.py @@ -1,28 +1,32 @@ -from sqlite3 import Connection -import shapely.wkb +from os import PathLike + import pandas as pd +import shapely.wkb + +from aequilibrae.utils.db_utils import commit_and_close +from aequilibrae.utils.spatialite_utils import connect_spatialite class DataLoader: - def __init__(self, conn: Connection, table_name: str): - self.conn = conn - self.curr = conn.cursor() + def __init__(self, path_to_file: PathLike, table_name: str): + self.__pth_file = path_to_file self.table_name = table_name def load_table(self) -> pd.DataFrame: - fields, _, geo_field = self.__find_table_fields() - fields = [f'"{x}"' for x in fields] - if geo_field is not None: - fields.append('ST_AsBinary("geometry") geometry') - keys = ",".join(fields) - df = pd.read_sql_query(f"select {keys} from '{self.table_name}'", self.conn) + with commit_and_close(connect_spatialite(self.__pth_file)) as conn: + fields, _, geo_field = self.__find_table_fields() + fields = [f'"{x}"' for x in fields] + if geo_field is not None: + fields.append('ST_AsBinary("geometry") geometry') + keys = ",".join(fields) + df = pd.read_sql_query(f"select {keys} from '{self.table_name}'", conn) df.geometry = df.geometry.apply(shapely.wkb.loads) return df def __find_table_fields(self): + with commit_and_close(connect_spatialite(self.__pth_file)) as conn: + structure = conn.execute(f"pragma table_info({self.table_name})").fetchall() geotypes = ["LINESTRING", "POINT", "POLYGON", "MULTIPOLYGON"] - self.curr.execute(f"pragma table_info({self.table_name})") - structure = self.curr.fetchall() fields = [x[1].lower() for x in structure] geotype = geo_field = None for x in structure: diff --git a/aequilibrae/project/field_editor.py b/aequilibrae/project/field_editor.py index 12bd77b4c..7fb1b0ac7 100644 --- a/aequilibrae/project/field_editor.py +++ b/aequilibrae/project/field_editor.py @@ -2,6 +2,8 @@ import string from typing import List +from aequilibrae.utils.db_utils import commit_and_close + ALLOWED_CHARACTERS = string.ascii_letters + "_0123456789" @@ -129,18 +131,13 @@ def __adds_to_attribute_table(self, attribute_name, attribute_value): self.__run_query_commit(qry, vals) def __run_query_fetch_all(self, qry: str): - conn = self.project.connect() - curr = conn.cursor() - curr.execute(qry) - dt = curr.fetchall() - conn.close() + with commit_and_close(self.project.connect()) as conn: + dt = conn.execute(qry).fetchall() return dt def __run_query_commit(self, qry: str, values=None) -> None: - conn = self.project.connect() - if values is None: - conn.execute(qry) - else: - conn.execute(qry, values) - conn.commit() - conn.close() + with commit_and_close(self.project.connect()) as conn: + if values is None: + conn.execute(qry) + else: + conn.execute(qry, values) diff --git a/aequilibrae/project/network/connector_creation.py b/aequilibrae/project/network/connector_creation.py index e995497d8..1b3b2fd67 100644 --- a/aequilibrae/project/network/connector_creation.py +++ b/aequilibrae/project/network/connector_creation.py @@ -1,114 +1,112 @@ from math import pi, sqrt +from sqlite3 import Connection +from typing import Optional + import numpy as np from scipy.cluster.vq import kmeans2, whiten from scipy.spatial.distance import cdist import shapely.wkb from shapely.geometry import LineString +from aequilibrae.utils.db_utils import commit_and_close INFINITE_CAPACITY = 99999 -def connector_creation(geo, zone_id: int, srid: int, mode_id: str, network, link_types="", connectors=1): +def connector_creation( + geo, zone_id: int, srid: int, mode_id: str, network, link_types="", connectors=1, conn_: Optional[Connection] = None +): if len(mode_id) > 1: raise Exception("We can only add centroid connectors for one mode at a time") - conn = network.project.connect() - curr = conn.cursor() - logger = network.project.logger - curr.execute("select count(*) from nodes where node_id=?", [zone_id]) - if curr.fetchone() is None: - logger.warning("This centroid does not exist. Please create it first") - return - - proj_nodes = network.nodes - node = proj_nodes.get(zone_id) - curr.execute("select count(*) from links where a_node=? and instr(modes,?) > 0", [zone_id, mode_id]) - if curr.fetchone()[0] > 0: - logger.warning("Mode is already connected") - return - - if len(link_types) > 0: - lt = f"*[{link_types}]*" - else: - curr.execute("Select link_type_id from link_types") - lt = "".join([x[0] for x in curr.fetchall()]) - lt = f"*[{lt}]*" - - sql = """select node_id, ST_asBinary(geometry), modes, link_types from nodes where ST_Within(geometry, GeomFromWKB(?, ?)) and - (nodes.rowid in (select rowid from SpatialIndex where f_table_name = 'nodes' and - search_frame = GeomFromWKB(?, ?))) - and link_types glob ? and instr(modes, ?)>0""" - - # We expand the area by its average radius until it is 20 times - # beginning with a strict search within the zone - buffer = 0 - increase = sqrt(geo.area / pi) - dt = [] - while dt == [] and buffer <= increase * 10: - wkb = geo.buffer(buffer).wkb - curr.execute(sql, [wkb, srid, wkb, srid, lt, mode_id]) - dt = curr.fetchall() - buffer += increase - - if buffer > increase: - msg = f"Could not find node inside zone {zone_id}. Search area was expanded until we found a suitable node" - logger.warning(msg) - if dt == []: - logger.warning( - f"FAILED! Could not find suitable nodes to connect within 5 times the diameter of zone {zone_id}." - ) - return - - coords = [] - nodes = [] - for node_id, wkb, modes, link_types in dt: - geo = shapely.wkb.loads(wkb) - coords.append([geo.x, geo.y]) - nodes.append(node_id) - - num_connectors = connectors - if len(nodes) == 0: - raise Exception("We could not find any candidate nodes that satisfied your criteria") - elif len(nodes) < connectors: - logger.warning( - f"We have fewer possible nodes than required connectors for zone {zone_id}. Will connect all of them." - ) - num_connectors = len(nodes) - - if num_connectors == len(coords): - all_nodes = nodes - else: - features = np.array(coords) - whitened = whiten(features) - centroids, allocation = kmeans2(whitened, num_connectors) - - all_nodes = set() - for i in range(num_connectors): - nds = [x for x, y in zip(nodes, list(allocation)) if y == i] - centr = centroids[i] - positions = [x for x, y in zip(whitened, allocation) if y == i] - if positions: - dist = cdist(np.array([centr]), np.array(positions)).flatten() - node_to_connect = nds[dist.argmin()] - all_nodes.add(node_to_connect) - - nds = list(all_nodes) - data = [zone_id] + nds - curr.execute(f'select b_node from links where a_node=? and b_node in ({",".join(["?"] * len(nds))})', data) - - data = [x[0] for x in curr.fetchall()] - - if data: - qry = ",".join(["?"] * len(data)) - dt = [mode_id, zone_id] + data - curr.execute(f"Update links set modes=modes || ? where a_node=? and b_node in ({qry})", dt) - nds = [x for x in nds if x not in data] - logger.warning(f"Mode {mode_id} added to {len(data)} existing centroid connectors for zone {zone_id}") - conn.commit() - - curr.close() - links = network.links + with conn_ or commit_and_close(network.project.connect()) as conn: + logger = network.project.logger + if conn.execute("select count(*) from nodes where node_id=?", [zone_id]).fetchone() is None: + logger.warning("This centroid does not exist. Please create it first") + return + + proj_nodes = network.nodes + node = proj_nodes.get(zone_id) + sql = "select count(*) from links where a_node=? and instr(modes,?) > 0" + if conn.execute(sql, [zone_id, mode_id]).fetchone()[0] > 0: + logger.warning("Mode is already connected") + return + + if len(link_types) > 0: + lt = f"*[{link_types}]*" + else: + lt = "".join([x[0] for x in conn.execute("Select link_type_id from link_types").fetchall()]) + lt = f"*[{lt}]*" + + sql = """select node_id, ST_asBinary(geometry), modes, link_types from nodes where ST_Within(geometry, GeomFromWKB(?, ?)) and + (nodes.rowid in (select rowid from SpatialIndex where f_table_name = 'nodes' and + search_frame = GeomFromWKB(?, ?))) + and link_types glob ? and instr(modes, ?)>0""" + + # We expand the area by its average radius until it is 20 times + # beginning with a strict search within the zone + buffer = 0 + increase = sqrt(geo.area / pi) + dt = [] + while dt == [] and buffer <= increase * 10: + wkb = geo.buffer(buffer).wkb + dt = conn.execute(sql, [wkb, srid, wkb, srid, lt, mode_id]).fetchall() + buffer += increase + + if buffer > increase: + msg = f"Could not find node inside zone {zone_id}. Search area was expanded until we found a suitable node" + logger.warning(msg) + if dt == []: + logger.warning( + f"FAILED! Could not find suitable nodes to connect within 5 times the diameter of zone {zone_id}." + ) + return + + coords = [] + nodes = [] + for node_id, wkb, modes, link_types in dt: + geo = shapely.wkb.loads(wkb) + coords.append([geo.x, geo.y]) + nodes.append(node_id) + + num_connectors = connectors + if len(nodes) == 0: + raise Exception("We could not find any candidate nodes that satisfied your criteria") + elif len(nodes) < connectors: + logger.warning( + f"We have fewer possible nodes than required connectors for zone {zone_id}. Will connect all of them." + ) + num_connectors = len(nodes) + + if num_connectors == len(coords): + all_nodes = nodes + else: + features = np.array(coords) + whitened = whiten(features) + centroids, allocation = kmeans2(whitened, num_connectors) + + all_nodes = set() + for i in range(num_connectors): + nds = [x for x, y in zip(nodes, list(allocation)) if y == i] + centr = centroids[i] + positions = [x for x, y in zip(whitened, allocation) if y == i] + if positions: + dist = cdist(np.array([centr]), np.array(positions)).flatten() + node_to_connect = nds[dist.argmin()] + all_nodes.add(node_to_connect) + + nds = list(all_nodes) + data = [zone_id] + nds + sql = f'select b_node from links where a_node=? and b_node in ({",".join(["?"] * len(nds))})' + data = [x[0] for x in conn.execute(sql, data).fetchall()] + + if data: + qry = ",".join(["?"] * len(data)) + dt = [mode_id, zone_id] + data + conn.execute(f"Update links set modes=modes || ? where a_node=? and b_node in ({qry})", dt) + nds = [x for x in nds if x not in data] + logger.warning(f"Mode {mode_id} added to {len(data)} existing centroid connectors for zone {zone_id}") + links = network.links for node_to_connect in nds: link = links.new() node_to = proj_nodes.get(node_to_connect) @@ -122,5 +120,3 @@ def connector_creation(geo, zone_id: int, srid: int, mode_id: str, network, link link.save() if nds: logger.warning(f"{len(nds)} new centroid connectors for mode {mode_id} added for centroid {zone_id}") - - conn.commit() diff --git a/aequilibrae/project/network/gmns_builder.py b/aequilibrae/project/network/gmns_builder.py index de8d8d1a8..5f97acec6 100644 --- a/aequilibrae/project/network/gmns_builder.py +++ b/aequilibrae/project/network/gmns_builder.py @@ -11,6 +11,8 @@ from aequilibrae import logger from aequilibrae.parameters import Parameters +from aequilibrae.utils.db_utils import commit_and_close +from aequilibrae.utils.spatialite_utils import connect_spatialite class GMNSBuilder(WorkerThread): @@ -23,7 +25,7 @@ def __init__( self.nodes = net.nodes self.link_types = net.link_types self.modes = net.modes - self.conn = net.conn + self.__pth_file = net.project.path_to_file self.link_df = pd.read_csv(link_path).fillna("") self.node_df = pd.read_csv(node_path).fillna("") @@ -378,9 +380,10 @@ def save_modes_to_aeq(self): mode_ids_list = [x for x in modes_list] saved_modes = list(self.modes.all_modes()) - modes_df = pd.DataFrame( - self.conn.execute("select mode_name, mode_id from modes").fetchall(), columns=["name", "id"] - ) + with commit_and_close(connect_spatialite(self.__pth_file)) as conn: + modes_df = pd.DataFrame( + conn.execute("select mode_name, mode_id from modes").fetchall(), columns=["name", "id"] + ) for mode in list(dict.fromkeys(modes_list)): if mode in groups_dict.keys(): modes_gathered = [m.replace(" ", "") for m in groups_dict[mode].split(sep=",")] @@ -494,16 +497,15 @@ def save_to_database(self, links_fields, nodes_fields): ) n_params_list = aeq_nodes_df.to_records(index=False) - self.conn.executemany(n_query, n_params_list) - self.conn.commit() + with commit_and_close(connect_spatialite(self.__pth_file)) as conn: + conn.executemany(n_query, n_params_list) - l_query = "insert into links(" + ", ".join(list(links_fields.keys())) + ")" - l_query += ( - " values(" - + ", ".join(["GeomFromTEXT(?,4326)" if x == "geometry" else "?" for x in list(links_fields.keys())]) - + ")" - ) - l_params_list = aeq_links_df.to_records(index=False) + l_query = "insert into links(" + ", ".join(list(links_fields.keys())) + ")" + l_query += ( + " values(" + + ", ".join(["GeomFromTEXT(?,4326)" if x == "geometry" else "?" for x in list(links_fields.keys())]) + + ")" + ) + l_params_list = aeq_links_df.to_records(index=False) - self.conn.executemany(l_query, l_params_list) - self.conn.commit() + conn.executemany(l_query, l_params_list) diff --git a/aequilibrae/project/network/gmns_exporter.py b/aequilibrae/project/network/gmns_exporter.py index 85251fc8b..712cff0d3 100644 --- a/aequilibrae/project/network/gmns_exporter.py +++ b/aequilibrae/project/network/gmns_exporter.py @@ -3,6 +3,7 @@ from ...utils import WorkerThread from aequilibrae.parameters import Parameters +from aequilibrae.utils.db_utils import commit_and_close class GMNSExporter(WorkerThread): @@ -12,14 +13,14 @@ def __init__(self, net, path) -> None: self.links_df = net.links.data self.nodes_df = net.nodes.data self.source = net.source - self.conn = net.conn self.output_path = path self.gmns_parameters = self.p.parameters["network"]["gmns"] self.gmns_links = self.gmns_parameters["link"] self.gmns_nodes = self.gmns_parameters["node"] - cur = self.conn.execute("select mode_name, mode_id, description, pce, vot, ppv from modes").fetchall() + with commit_and_close(net.project.connect()) as conn: + cur = conn.execute("select mode_name, mode_id, description, pce, vot, ppv from modes").fetchall() self.modes_df = pd.DataFrame(cur, columns=["mode_name", "mode_id", "description", "pce", "vot", "ppv"]) def doWork(self): diff --git a/aequilibrae/project/network/link.py b/aequilibrae/project/network/link.py index 1037c21eb..01834e6e7 100644 --- a/aequilibrae/project/network/link.py +++ b/aequilibrae/project/network/link.py @@ -1,6 +1,8 @@ from typing import Union -from .safe_class import SafeClass + from aequilibrae.project.network.mode import Mode +from aequilibrae.utils.db_utils import commit_and_close +from .safe_class import SafeClass class Link(SafeClass): @@ -53,27 +55,19 @@ def __init__(self, dataset, project): def delete(self): """Deletes link from database""" - conn = self.connect_db() - curr = conn.cursor() - curr.execute(f'DELETE FROM links where link_id="{self.link_id}"') - conn.commit() + with commit_and_close(self.connect_db()) as conn: + conn.execute(f'DELETE FROM links where link_id="{self.link_id}"') self.__stil_exists = False def save(self): """Saves link to database""" - conn = self.connect_db() - curr = conn.cursor() - if self.__new: - data, sql = self._save_new_with_geometry() - else: - data, sql = self.__save_existing_link() + data, sql = self._save_new_with_geometry() if self.__new else self.__save_existing_link() if data: - curr.execute(sql, data) + with commit_and_close(self.connect_db()) as conn: + conn.execute(sql, data) - conn.commit() - conn.close() self.__new = False for key in self.__original__.keys(): diff --git a/aequilibrae/project/network/link_type.py b/aequilibrae/project/network/link_type.py index eb163b6ca..857ac326b 100644 --- a/aequilibrae/project/network/link_type.py +++ b/aequilibrae/project/network/link_type.py @@ -1,5 +1,6 @@ import string from .safe_class import SafeClass +from aequilibrae.utils.db_utils import commit_and_close class LinkType(SafeClass): @@ -8,10 +9,8 @@ class LinkType(SafeClass): __alowed_characters = string.ascii_letters + "_" def delete(self): - conn = self.connect_db() - curr = conn.cursor() - curr.execute(f'DELETE FROM link_types where link_type_id="{self.link_type_id}"') - conn.commit() + with commit_and_close(self.connect_db()) as conn: + conn.execute(f'DELETE FROM link_types where link_type_id="{self.link_type_id}"') del self def save(self): diff --git a/aequilibrae/project/network/link_types.py b/aequilibrae/project/network/link_types.py index 2a636a8be..9dfe11a90 100644 --- a/aequilibrae/project/network/link_types.py +++ b/aequilibrae/project/network/link_types.py @@ -2,6 +2,8 @@ from aequilibrae.project.network.link_type import LinkType from aequilibrae.project.field_editor import FieldEditor from aequilibrae.project.table_loader import TableLoader +from aequilibrae.utils.db_utils import commit_and_close +from aequilibrae.utils.spatialite_utils import connect_spatialite class LinkTypes: @@ -59,11 +61,10 @@ def __init__(self, net): self.__items = {} self.project = net.project self.logger = net.project.logger - self.conn = net.conn # type: Connection - self.curr = net.conn.cursor() tl = TableLoader() - link_types_list = tl.load_table(self.curr, "link_types") + with commit_and_close(connect_spatialite(self.project.path_to_file)) as conn: + link_types_list = tl.load_table(conn, "link_types") existing_list = [lt["link_type_id"] for lt in link_types_list] self.__fields = [x for x in tl.fields] @@ -92,7 +93,6 @@ def delete(self, link_type_id: str) -> None: lt = self.__items[link_type_id] # type: LinkType lt.delete() del self.__items[link_type_id] - self.conn.commit() except IntegrityError as e: self.logger.error(f"Failed to remove link_type {link_type_id}. {e.args}") raise e diff --git a/aequilibrae/project/network/links.py b/aequilibrae/project/network/links.py index 1af7e4f29..3cd532367 100644 --- a/aequilibrae/project/network/links.py +++ b/aequilibrae/project/network/links.py @@ -7,6 +7,8 @@ from aequilibrae.project.data_loader import DataLoader from aequilibrae.project.network.link import Link from aequilibrae.project.table_loader import TableLoader +from aequilibrae.utils.db_utils import commit_and_close +from aequilibrae.utils.spatialite_utils import connect_spatialite class Links(BasicTable): @@ -115,9 +117,8 @@ def delete(self, link_id: int) -> None: link = self.__items.pop(link_id) # type: Link link.delete() else: - self._curr.execute("Delete from Links where link_id=?", [link_id]) - d = self._curr.rowcount - self.conn.commit() + with commit_and_close(connect_spatialite(self.project.path_to_file)) as conn: + d = conn.execute("Delete from Links where link_id=?", [link_id]).rowcount if d: self.project.logger.warning(f"Link {link_id} was successfully removed from the project database") else: @@ -125,10 +126,10 @@ def delete(self, link_id: int) -> None: def refresh_fields(self) -> None: """After adding a field one needs to refresh all the fields recognized by the software""" - self._curr.execute("select coalesce(max(link_id),0) from Links") - self.__max_id = self._curr.fetchone()[0] tl = TableLoader() - tl.load_structure(self._curr, "links") + with commit_and_close(connect_spatialite(self.project.path_to_file)) as conn: + self.__max_id = conn.execute("select coalesce(max(link_id),0) from Links").fetchone()[0] + tl.load_structure(conn, "links") self.sql = tl.sql self.__fields = deepcopy(tl.fields) @@ -139,7 +140,7 @@ def data(self) -> pd.DataFrame: :Returns: **table** (:obj:`DataFrame`): Pandas dataframe with all the links, complete with Geometry """ - dl = DataLoader(self.conn, "links") + dl = DataLoader(self.project.path_to_file, "links") return dl.load_table() def refresh(self): @@ -159,8 +160,8 @@ def __existence_error(self, link_id): raise ValueError(f"Link {link_id} does not exist in the model") def __link_data(self, link_id: int) -> dict: - self._curr.execute(f"{self.sql} where link_id=?", [link_id]) - data = self._curr.fetchone() + with commit_and_close(connect_spatialite(self.project.path_to_file)) as conn: + data = conn.execute(f"{self.sql} where link_id=?", [link_id]).fetchone() if data: return {key: val for key, val in zip(self.__fields, data)} raise ValueError("Link_id does not exist on the network") diff --git a/aequilibrae/project/network/mode.py b/aequilibrae/project/network/mode.py index fe95c8c40..6da2dee9e 100644 --- a/aequilibrae/project/network/mode.py +++ b/aequilibrae/project/network/mode.py @@ -1,5 +1,8 @@ import string +from aequilibrae.utils.db_utils import commit_and_close +from aequilibrae.utils.spatialite_utils import connect_spatialite + class Mode: """A mode object represents a single record in the *modes* table""" @@ -13,17 +16,14 @@ def __init__(self, mode_id: str, project) -> None: if len(mode_id) != 1 or mode_id not in string.ascii_letters: raise ValueError("Mode IDs must be a single ascii character") - conn = self.project.connect() - curr = conn.cursor() - curr.execute("pragma table_info(modes)") - table_struct = curr.fetchall() - self.__fields = [x[1] for x in table_struct] - self.__original__ = {} + with commit_and_close(connect_spatialite(self.project.path_to_file)) as conn: + table_struct = conn.execute("pragma table_info(modes)").fetchall() + self.__fields = [x[1] for x in table_struct] + self.__original__ = {} + # data for the mode + dt = conn.execute(f"select * from 'modes' where mode_id='{mode_id}'").fetchone() - # data for the mode - curr.execute(f"select * from 'modes' where mode_id='{mode_id}'") - dt = curr.fetchone() if dt is None: # if the mode was not found, we return a new one for k in self.__fields: @@ -35,7 +35,6 @@ def __init__(self, mode_id: str, project) -> None: for k, v in zip(self.__fields, dt): self.__dict__[k] = v self.__original__[k] = v - conn.close() def __setattr__(self, instance, value) -> None: if instance == "mode_name" and value is None: @@ -54,21 +53,15 @@ def save(self): if letter not in self.__alowed_characters: raise ValueError('mode_name can only contain letters and "_"') - conn = self.project.connect() - curr = conn.cursor() - - curr.execute(f'select count(*) from modes where mode_id="{self.mode_id}"') - if curr.fetchone()[0] == 0: - raise ValueError("Mode does not exist in the model. You need to explicitly add it") + with commit_and_close(connect_spatialite(self.project.path_to_file)) as conn: + if conn.execute(f'select count(*) from modes where mode_id="{self.mode_id}"').fetchone()[0] == 0: + raise ValueError("Mode does not exist in the model. You need to explicitly add it") - curr.execute("pragma table_info(modes)") - table_struct = [x[1] for x in curr.fetchall()] + table_struct = [x[1] for x in conn.execute("pragma table_info(modes)").fetchall()] - for key, value in self.__dict__.items(): - if key in table_struct and key != "mode_id": - v_old = self.__original__.get(key, None) - if value != v_old and value is not None: - self.__original__[key] = value - curr.execute(f"update 'modes' set '{key}'=? where mode_id='{self.mode_id}'", [value]) - conn.commit() - conn.close() + for key, value in self.__dict__.items(): + if key in table_struct and key != "mode_id": + v_old = self.__original__.get(key, None) + if value != v_old and value is not None: + self.__original__[key] = value + conn.execute(f"update 'modes' set '{key}'=? where mode_id='{self.mode_id}'", [value]) diff --git a/aequilibrae/project/network/modes.py b/aequilibrae/project/network/modes.py index d5d98402d..7fd12f77d 100644 --- a/aequilibrae/project/network/modes.py +++ b/aequilibrae/project/network/modes.py @@ -1,6 +1,9 @@ -from sqlite3 import IntegrityError, Connection -from aequilibrae.project.network.mode import Mode +from sqlite3 import IntegrityError + from aequilibrae.project.field_editor import FieldEditor +from aequilibrae.project.network.mode import Mode +from aequilibrae.utils.db_utils import commit_and_close +from aequilibrae.utils.spatialite_utils import connect_spatialite class Modes: @@ -55,32 +58,33 @@ def __init__(self, net): self.__items = {} self.project = net.project self.logger = net.logger - self.conn = net.conn # type: Connection - self.curr = net.conn.cursor() - self.__update_list_of_modes() + with commit_and_close(connect_spatialite(self.project.path_to_file)) as conn: + self.__update_list_of_modes(conn) def add(self, mode: Mode) -> None: """We add a mode to the project""" - self.__update_list_of_modes() + with commit_and_close(connect_spatialite(self.project.path_to_file)) as conn: + self.__update_list_of_modes(conn) if mode.mode_id in self.__all_modes: raise ValueError("Mode already exists in the model") - self.curr.execute("insert into 'modes'(mode_id, mode_name) Values(?,?)", [mode.mode_id, mode.mode_name]) - self.conn.commit() - self.logger.info(f"mode {mode.mode_name}({mode.mode_id}) was added to the project") - mode.save() - self.__update_list_of_modes() + with commit_and_close(connect_spatialite(self.project.path_to_file)) as conn: + conn.execute("insert into 'modes'(mode_id, mode_name) Values(?,?)", [mode.mode_id, mode.mode_name]) + self.logger.info(f"mode {mode.mode_name}({mode.mode_id}) was added to the project") + conn.commit() + mode.save() + self.__update_list_of_modes(conn) def delete(self, mode_id: str) -> None: """Removes the mode with *mode_id* from the project""" - try: - self.curr.execute(f'delete from modes where mode_id="{mode_id}"') - self.conn.commit() - except IntegrityError as e: - self.logger.error(f"Failed to remove mode {mode_id}. {e.args}") - raise e - self.logger.warning(f"Mode {mode_id} was successfully removed from the database") - self.__update_list_of_modes() + with commit_and_close(connect_spatialite(self.project.path_to_file)) as conn: + try: + conn.execute(f'delete from modes where mode_id="{mode_id}"') + except IntegrityError as e: + self.logger.error(f"Failed to remove mode {mode_id}. {e.args}") + raise e + self.logger.warning(f"Mode {mode_id} was successfully removed from the database") + self.__update_list_of_modes(conn) @property def fields(self) -> FieldEditor: @@ -89,23 +93,25 @@ def fields(self) -> FieldEditor: def get(self, mode_id: str) -> Mode: """Get a mode from the network by its *mode_id*""" - self.__update_list_of_modes() + with commit_and_close(connect_spatialite(self.project.path_to_file)) as conn: + self.__update_list_of_modes(conn) if mode_id not in self.__all_modes: raise ValueError(f"Mode {mode_id} does not exist in the model") return Mode(mode_id, self.project) def get_by_name(self, mode: str) -> Mode: """Get a mode from the network by its *mode_name*""" - self.__update_list_of_modes() - self.curr.execute(f"select mode_id from 'modes' where mode_name='{mode}'") - found = self.curr.fetchone() + with commit_and_close(connect_spatialite(self.project.path_to_file)) as conn: + self.__update_list_of_modes(conn) + found = conn.execute(f"select mode_id from 'modes' where mode_name='{mode}'").fetchone() if len(found) == 0: raise ValueError(f"Mode {mode} does not exist in the model") return Mode(found[0], self.project) def all_modes(self) -> dict: """Returns a dictionary with all mode objects available in the model. mode_id as key""" - self.__update_list_of_modes() + with commit_and_close(connect_spatialite(self.project.path_to_file)) as conn: + self.__update_list_of_modes(conn) return {x: Mode(x, self.project) for x in self.__all_modes} def new(self, mode_id: str) -> Mode: @@ -115,9 +121,8 @@ def new(self, mode_id: str) -> Mode: return Mode(mode_id, self.project) - def __update_list_of_modes(self) -> None: - self.curr.execute("select mode_id from 'modes'") - self.__all_modes = [x[0] for x in self.curr.fetchall()] + def __update_list_of_modes(self, conn) -> None: + self.__all_modes = [x[0] for x in conn.execute("select mode_id from 'modes'").fetchall()] def __copy__(self): raise Exception("Modes object cannot be copied") @@ -127,8 +132,3 @@ def __deepcopy__(self, memodict=None): def __del__(self): self.__items.clear() - - def __has_mode(self): - curr = self.conn.cursor() - curr.execute("SELECT name FROM sqlite_master WHERE type='table';") - return any(["modes" in x[0] for x in curr.fetchall()]) diff --git a/aequilibrae/project/network/network.py b/aequilibrae/project/network/network.py index c205a80b3..367a4c42c 100644 --- a/aequilibrae/project/network/network.py +++ b/aequilibrae/project/network/network.py @@ -24,6 +24,8 @@ from aequilibrae.project.network.osm_utils.place_getter import placegetter from aequilibrae.project.project_creation import req_link_flds, req_node_flds, protected_fields from aequilibrae.utils import WorkerThread +from aequilibrae.utils.db_utils import commit_and_close +from aequilibrae.utils.spatialite_utils import connect_spatialite spec = iutil.find_spec("PyQt5") pyqt = spec is not None @@ -48,7 +50,6 @@ def __init__(self, project) -> None: from aequilibrae.paths import Graph WorkerThread.__init__(self, None) - self.conn = project.conn # type: sqlc self.source = project.source # type: sqlc self.graphs = {} # type: Dict[Graph] self.project = project @@ -65,11 +66,11 @@ def skimmable_fields(self): :Returns: :obj:`list`: List of all fields that can be skimmed """ - curr = self.conn.cursor() - curr.execute("PRAGMA table_info(links);") - field_names = curr.fetchall() - ignore_fields = ["ogc_fid", "geometry"] + self.req_link_flds + with commit_and_close(connect_spatialite(self.project.path_to_file)) as conn: + field_names = conn.execute("PRAGMA table_info(links);").fetchall() + + ignore_fields = ["ogc_fid", "geometry"] + self.req_link_flds skimmable = [ "INT", "INTEGER", @@ -115,9 +116,10 @@ def list_modes(self): :Returns: :obj:`list`: List of all modes """ - curr = self.conn.cursor() - curr.execute("""select mode_id from modes""") - return [x[0] for x in curr.fetchall()] + + with commit_and_close(connect_spatialite(self.project.path_to_file)) as conn: + all_modes = [x[0] for x in conn.execute("""select mode_id from modes""").fetchall()] + return all_modes def create_from_osm( self, @@ -173,10 +175,9 @@ def create_from_osm( if self.count_links() > 0: raise FileExistsError("You can only import an OSM network into a brand new model file") - curr = self.conn.cursor() - curr.execute("""ALTER TABLE links ADD COLUMN osm_id integer""") - curr.execute("""ALTER TABLE nodes ADD COLUMN osm_id integer""") - self.conn.commit() + with commit_and_close(connect_spatialite(self.project.path_to_file)) as conn: + conn.execute("""ALTER TABLE links ADD COLUMN osm_id integer""") + conn.execute("""ALTER TABLE nodes ADD COLUMN osm_id integer""") if isinstance(modes, (tuple, list)): modes = list(modes) @@ -313,34 +314,31 @@ def build_graphs(self, fields: list = None, modes: list = None) -> None: """ from aequilibrae.paths import Graph - curr = self.conn.cursor() - - if fields is None: - curr.execute("PRAGMA table_info(links);") - field_names = curr.fetchall() + with commit_and_close(connect_spatialite(self.project.path_to_file)) as conn: + if fields is None: + field_names = conn.execute("PRAGMA table_info(links);").fetchall() - ignore_fields = ["ogc_fid", "geometry"] - all_fields = [f[1] for f in field_names if f[1] not in ignore_fields] - else: - fields.extend(["link_id", "a_node", "b_node", "direction", "modes"]) - all_fields = list(set(fields)) + ignore_fields = ["ogc_fid", "geometry"] + all_fields = [f[1] for f in field_names if f[1] not in ignore_fields] + else: + fields.extend(["link_id", "a_node", "b_node", "direction", "modes"]) + all_fields = list(set(fields)) - if modes is None: - modes = curr.execute("select mode_id from modes;").fetchall() - modes = [m[0] for m in modes] - elif isinstance(modes, str): - modes = [modes] + if modes is None: + modes = conn.execute("select mode_id from modes;").fetchall() + modes = [m[0] for m in modes] + elif isinstance(modes, str): + modes = [modes] - sql = f"select {','.join(all_fields)} from links" + sql = f"select {','.join(all_fields)} from links" - df = pd.read_sql(sql, self.conn).fillna(value=np.nan) - valid_fields = list(df.select_dtypes(np.number).columns) + ["modes"] - curr.execute("select node_id from nodes where is_centroid=1 order by node_id;") - centroids = np.array([i[0] for i in curr.fetchall()], np.uint32) - centroids = centroids if centroids.shape[0] else None + df = pd.read_sql(sql, conn).fillna(value=np.nan) + valid_fields = list(df.select_dtypes(np.number).columns) + ["modes"] + sql = "select node_id from nodes where is_centroid=1 order by node_id;" + centroids = np.array([i[0] for i in conn.execute(sql).fetchall()], np.uint32) + centroids = centroids if centroids.shape[0] else None lonlat = self.nodes.lonlat.set_index("node_id") - data = df[valid_fields] for m in modes: net = pd.DataFrame(data, copy=True) @@ -402,9 +400,8 @@ def extent(self): :Returns: **model extent** (:obj:`Polygon`): Shapely polygon with the bounding box of the model network. """ - curr = self.conn.cursor() - curr.execute('Select ST_asBinary(GetLayerExtent("Links"))') - poly = shapely.wkb.loads(curr.fetchone()[0]) + with commit_and_close(connect_spatialite(self.project.path_to_file)) as conn: + poly = shapely.wkb.loads(conn.execute('Select ST_asBinary(GetLayerExtent("Links"))').fetchone()[0]) return poly def convex_hull(self) -> Polygon: @@ -413,15 +410,12 @@ def convex_hull(self) -> Polygon: :Returns: **model coverage** (:obj:`Polygon`): Shapely (Multi)polygon of the model network. """ - curr = self.conn.cursor() - curr.execute('Select ST_asBinary("geometry") from Links where ST_Length("geometry") > 0;') - links = [shapely.wkb.loads(x[0]) for x in curr.fetchall()] + with commit_and_close(connect_spatialite(self.project.path_to_file)) as conn: + sql = 'Select ST_asBinary("geometry") from Links where ST_Length("geometry") > 0;' + links = [shapely.wkb.loads(x[0]) for x in conn.execute(sql).fetchall()] return unary_union(links).convex_hull - def refresh_connection(self): - """Opens a new database connection to avoid thread conflict""" - self.conn = self.project.connect() - def __count_items(self, field: str, table: str, condition: str) -> int: - c = self.conn.execute(f"select count({field}) from {table} where {condition};").fetchone()[0] + with commit_and_close(connect_spatialite(self.project.path_to_file)) as conn: + c = conn.execute(f"select count({field}) from {table} where {condition};").fetchone()[0] return c diff --git a/aequilibrae/project/network/node.py b/aequilibrae/project/network/node.py index dce36168f..14c86a82d 100644 --- a/aequilibrae/project/network/node.py +++ b/aequilibrae/project/network/node.py @@ -1,3 +1,6 @@ +from sqlite3 import Connection +from typing import Optional + from shapely.geometry import Polygon from .safe_class import SafeClass from .connector_creation import connector_creation @@ -111,7 +114,7 @@ def __save_existing_node(self): sql = f"Update Nodes set {txts}" return data, sql - def connect_mode(self, area: Polygon, mode_id: str, link_types="", connectors=1): + def connect_mode(self, area: Polygon, mode_id: str, link_types="", connectors=1, conn: Optional[Connection] = None): """Adds centroid connectors for the desired mode to the network file Centroid connectors are created by connecting the zone centroid to one or more nodes selected from @@ -147,7 +150,8 @@ def connect_mode(self, area: Polygon, mode_id: str, link_types="", connectors=1) mode_id, link_types=link_types, connectors=connectors, - network=self._project.network, + network=self.project.network, + conn_=conn, ) def __setattr__(self, instance, value) -> None: diff --git a/aequilibrae/project/network/nodes.py b/aequilibrae/project/network/nodes.py index 91c3d106d..d09d515ec 100644 --- a/aequilibrae/project/network/nodes.py +++ b/aequilibrae/project/network/nodes.py @@ -6,6 +6,8 @@ from aequilibrae.project.data_loader import DataLoader from aequilibrae.project.network.node import Node from aequilibrae.project.table_loader import TableLoader +from aequilibrae.utils.db_utils import commit_and_close +from aequilibrae.utils.spatialite_utils import connect_spatialite class Nodes(BasicTable): @@ -61,8 +63,8 @@ def get(self, node_id: int) -> Node: else: self.__items[node.node_id] = self.__items.pop(node_id) - self._curr.execute(f"{self.sql} where node_id=?", [node_id]) - data = self._curr.fetchone() + with commit_and_close(connect_spatialite(self.project.path_to_file)) as conn: + data = conn.execute(f"{self.sql} where node_id=?", [node_id]).fetchone() if data: data = {key: val for key, val in zip(self.__fields, data)} node = Node(data, self.project) @@ -74,7 +76,8 @@ def get(self, node_id: int) -> Node: def refresh_fields(self) -> None: """After adding a field one needs to refresh all the fields recognized by the software""" tl = TableLoader() - tl.load_structure(self._curr, "nodes") + with commit_and_close(connect_spatialite(self.project.path_to_file)) as conn: + tl.load_structure(conn, "nodes") self.sql = tl.sql self.__fields = deepcopy(tl.fields) @@ -91,8 +94,9 @@ def new_centroid(self, node_id: int) -> Node: **node_id** (:obj:`int`): Id of the centroid to be created """ - self._curr.execute("select count(*) from nodes where node_id=?", [node_id]) - if self._curr.fetchone()[0] > 0: + with commit_and_close(connect_spatialite(self.project.path_to_file)) as conn: + ct = conn.execute("select count(*) from nodes where node_id=?", [node_id]).fetchone()[0] + if ct > 0: raise Exception("Node_id already exists. Failed to create it") data = {key: None for key in self.__fields} @@ -113,7 +117,7 @@ def data(self) -> pd.DataFrame: :Returns: **table** (:obj:`DataFrame`): Pandas DataFrame with all the nodes, complete with Geometry """ - dl = DataLoader(self.conn, "nodes") + dl = DataLoader(self.project.path_to_file, "nodes") return dl.load_table() @property @@ -123,7 +127,9 @@ def lonlat(self) -> pd.DataFrame: :Returns: **table** (:obj:`DataFrame`): Pandas DataFrame with all the nodes, with geometry as lon/lat """ - return pd.read_sql("SELECT node_id, ST_X(geometry) AS lon, ST_Y(geometry) AS lat FROM nodes", self.conn) + with commit_and_close(connect_spatialite(self.project.path_to_file)) as conn: + df = pd.read_sql("SELECT node_id, ST_X(geometry) AS lon, ST_Y(geometry) AS lat FROM nodes", conn) + return df def __del__(self): self.__items.clear() diff --git a/aequilibrae/project/network/osm_builder.py b/aequilibrae/project/network/osm_builder.py index 6d33a5bb3..a3f0a44b3 100644 --- a/aequilibrae/project/network/osm_builder.py +++ b/aequilibrae/project/network/osm_builder.py @@ -14,6 +14,8 @@ from .haversine import haversine from ...utils import WorkerThread +from aequilibrae.utils.db_utils import commit_and_close + spec = iutil.find_spec("PyQt5") pyqt = spec is not None if pyqt: @@ -35,7 +37,6 @@ def __init__(self, osm_items: List, path: str, node_start=10000, project=None) - self.logger = self.project.logger self.osm_items = osm_items self.path = path - self.conn = None self.node_start = node_start self.__link_types = None # type: LinkTypes self.report = [] @@ -52,12 +53,11 @@ def __emit_all(self, *args): self.building.emit(*args) def doWork(self): - self.conn = connect_spatialite(self.path) - self.curr = self.conn.cursor() - self.__worksetup() - node_count = self.data_structures() - self.importing_links(node_count) - self.__emit_all(["finished_threaded_procedure", 0]) + with commit_and_close(connect_spatialite(self.path)) as conn: + self.__worksetup() + node_count = self.data_structures() + self.importing_links(node_count, conn) + self.__emit_all(["finished_threaded_procedure", 0]) def data_structures(self): self.logger.info("Separating nodes and links") @@ -117,21 +117,21 @@ def data_structures(self): return node_count - def importing_links(self, node_count): + def importing_links(self, node_count, conn): node_ids = {} vars = {} vars["link_id"] = 1 table = "links" fields = self.get_link_fields() - self.__update_table_structure() + self.__update_table_structure(conn) field_names = ",".join(fields) self.logger.info("Adding network nodes") self.__emit_all(["text", "Adding network nodes"]) sql = "insert into nodes(node_id, is_centroid, osm_id, geometry) Values(?, 0, ?, MakePoint(?,?, 4326))" - self.conn.executemany(sql, self.node_df) - self.conn.commit() + conn.executemany(sql, self.node_df) + conn.commit() del self.node_df self.logger.info("Adding network links") @@ -140,7 +140,7 @@ def importing_links(self, node_count): self.__emit_all(["maxValue", L]) counter = 0 - mode_codes, not_found_tags = self.modes_per_link_type() + mode_codes, not_found_tags = self.modes_per_link_type(conn) owf, twf = self.field_osm_source() all_attrs = [] all_osm_ids = list(self.links.keys()) @@ -199,15 +199,12 @@ def importing_links(self, node_count): self.logger.info("Adding network links") self.__emit_all(["text", "Adding network links"]) try: - self.curr.executemany(sql, all_attrs) + conn.executemany(sql, all_attrs) except Exception as e: self.logger.error("error when inserting link {}. Error {}".format(all_attrs[0], e.args)) self.logger.error(sql) raise e - self.conn.commit() - self.curr.close() - def __worksetup(self): self.__link_types = self.project.network.link_types lts = self.__link_types.all_types() @@ -215,16 +212,14 @@ def __worksetup(self): self.__model_link_types.append(lt.link_type) self.__model_link_type_ids.append(lt_id) - def __update_table_structure(self): - curr = self.conn.cursor() - curr.execute("pragma table_info(Links)") - structure = curr.fetchall() + def __update_table_structure(self, conn): + structure = conn.execute("pragma table_info(Links)").fetchall() has_fields = [x[1].lower() for x in structure] fields = [field.lower() for field in self.get_link_fields()] + ["osm_id"] for field in [f for f in fields if f not in has_fields]: ltype = self.get_link_field_type(field).upper() - curr.execute(f"Alter table Links add column {field} {ltype}") - self.conn.commit() + conn.execute(f"Alter table Links add column {field} {ltype}") + conn.commit() def __build_link_data(self, vars, intersections, i, linknodes, node_ids, fields): ii = intersections[i] @@ -358,13 +353,11 @@ def field_osm_source(): } return owf, twf - def modes_per_link_type(self): + def modes_per_link_type(self, conn): p = Parameters() modes = p.parameters["network"]["osm"]["modes"] - cursor = self.conn.cursor() - cursor.execute("SELECT mode_name, mode_id from modes") - mode_codes = cursor.fetchall() + mode_codes = conn.execute("SELECT mode_name, mode_id from modes").fetchall() mode_codes = {p[0]: p[1] for p in mode_codes} type_list = {} diff --git a/aequilibrae/project/network/safe_class.py b/aequilibrae/project/network/safe_class.py index f005484c1..1c7b4218f 100644 --- a/aequilibrae/project/network/safe_class.py +++ b/aequilibrae/project/network/safe_class.py @@ -5,11 +5,11 @@ class SafeClass: _srid = 4326 def __init__(self, data_set: dict, project) -> None: - self.__original__ = {} - self._project = project - self._logger = project.logger - self._table = "" - self.__srid__ = 4326 + self.__dict__["__original__"] = {} + self.__dict__["project"] = project + self.__dict__["_logger"] = project.logger + self.__dict__["_table"] = "" + self.__dict__["__srid__"] = 4326 for k, v in data_set.items(): if k == "geometry" and v is not None: v = shapely.wkb.loads(v) @@ -37,4 +37,4 @@ def _save_new_with_geometry(self): return data, sql def connect_db(self): - return self._project.connect() + return self.project.connect() diff --git a/aequilibrae/project/project.py b/aequilibrae/project/project.py index bb464c2e9..c2d627f72 100644 --- a/aequilibrae/project/project.py +++ b/aequilibrae/project/project.py @@ -202,8 +202,7 @@ def __create_empty_network(self): p.write_back() # Create actual tables - cursor = self.conn.cursor() - cursor.execute("PRAGMA foreign_keys = ON;") + self.conn.execute("PRAGMA foreign_keys = ON;") self.conn.commit() initialize_tables(self, "network") diff --git a/aequilibrae/project/table_loader.py b/aequilibrae/project/table_loader.py index 01a4a79c5..37db69e0e 100644 --- a/aequilibrae/project/table_loader.py +++ b/aequilibrae/project/table_loader.py @@ -1,4 +1,4 @@ -from sqlite3 import Cursor +from sqlite3 import Connection from typing import List @@ -7,16 +7,15 @@ def __init__(self): self.fields = [] self.sql = "" - def load_table(self, curr: Cursor, table_name: str) -> List[dict]: - self.__get_table_struct(curr, table_name) - curr.execute(self.sql) - return [dict(zip(self.fields, row)) for row in curr.fetchall()] + def load_table(self, conn: Connection, table_name: str) -> List[dict]: + self.__get_table_struct(conn, table_name) + return [dict(zip(self.fields, row)) for row in conn.execute(self.sql).fetchall()] - def load_structure(self, curr: Cursor, table_name: str) -> None: - self.__get_table_struct(curr, table_name) + def load_structure(self, conn: Connection, table_name: str) -> None: + self.__get_table_struct(conn, table_name) - def __get_table_struct(self, curr: Cursor, table_name: str) -> None: - curr.execute(f"pragma table_info({table_name})") - self.fields = [x[1].lower() for x in curr.fetchall() if x[1].lower() != "ogc_fid"] + def __get_table_struct(self, conn: Connection, table_name: str) -> None: + dt = conn.execute(f"pragma table_info({table_name})").fetchall() + self.fields = [x[1].lower() for x in dt if x[1].lower() != "ogc_fid"] keys = [f'"{fld}"' if fld != "geometry" else 'ST_AsBinary("geometry")' for fld in self.fields] self.sql = f'select {",".join(keys)} from "{table_name}"' diff --git a/aequilibrae/project/zone.py b/aequilibrae/project/zone.py index b16d55cb8..4bb0a5d1d 100644 --- a/aequilibrae/project/zone.py +++ b/aequilibrae/project/zone.py @@ -1,8 +1,12 @@ import random from sqlite3 import Connection +from typing import Optional + from shapely.geometry import Point, MultiPolygon -from .network.safe_class import SafeClass + +from aequilibrae.utils.db_utils import commit_and_close from .network.connector_creation import connector_creation +from .network.safe_class import SafeClass class Zone(SafeClass): @@ -13,18 +17,14 @@ def __init__(self, dataset: dict, zoning): self.zone_id = -1 super().__init__(dataset, zoning.project) self.__zoning = zoning - self.conn = zoning.conn # type: Connection self.__new = dataset["geometry"] is None self.__network_links = zoning.network.links self.__network_nodes = zoning.network.nodes def delete(self): """Removes the zone from the database""" - conn = self._project.connect() - curr = conn.cursor() - curr.execute(f'DELETE FROM zones where zone_id="{self.zone_id}"') - conn.commit() - conn.close() + with commit_and_close(self.connect_db()) as conn: + conn.execute(f'DELETE FROM zones where zone_id="{self.zone_id}"') self.__zoning._remove_zone(self.zone_id) del self @@ -34,26 +34,21 @@ def save(self): if self.zone_id != self.__original__["zone_id"]: raise ValueError("One cannot change the zone_id") - conn = self._project.connect() - curr = conn.cursor() - - curr.execute(f'select count(*) from zones where zone_id="{self.zone_id}"') - if curr.fetchone()[0] == 0: - data = [self.zone_id, self.geometry.wkb] - curr.execute("Insert into zones (zone_id, geometry) values(?, ST_Multi(GeomFromWKB(?, 4326)))", data) - - for key, value in self.__dict__.items(): - if key != "zone_id" and key in self.__original__: - v_old = self.__original__.get(key, None) - if value != v_old and value is not None: - self.__original__[key] = value - if key == "geometry": - sql = "update 'zones' set geometry=ST_Multi(GeomFromWKB(?, 4326)) where zone_id=?" - curr.execute(sql, [value.wkb, self.zone_id]) - else: - curr.execute(f"update 'zones' set '{key}'=? where zone_id=?", [value, self.zone_id]) - conn.commit() - conn.close() + with commit_and_close(self.connect_db()) as conn: + if conn.execute(f'select count(*) from zones where zone_id="{self.zone_id}"').fetchone()[0] == 0: + data = [self.zone_id, self.geometry.wkb] + conn.execute("Insert into zones (zone_id, geometry) values(?, ST_Multi(GeomFromWKB(?, 4326)))", data) + + for key, value in self.__dict__.items(): + if key != "zone_id" and key in self.__original__: + v_old = self.__original__.get(key, None) + if value != v_old and value is not None: + self.__original__[key] = value + if key == "geometry": + sql = "update 'zones' set geometry=ST_Multi(GeomFromWKB(?, 4326)) where zone_id=?" + conn.execute(sql, [value.wkb, self.zone_id]) + else: + conn.execute(f"update 'zones' set '{key}'=? where zone_id=?", [value, self.zone_id]) def add_centroid(self, point: Point, robust=True) -> None: """Adds a centroid to the network file @@ -68,36 +63,33 @@ def add_centroid(self, point: Point, robust=True) -> None: # This is VERY small in real-world terms (between zero and 11cm) shift = 0.000001 - curr = self.conn.cursor() - - curr.execute("select count(*) from nodes where node_id=?", [self.zone_id]) - if curr.fetchone()[0] > 0: - self._project.logger.warning("Centroid already exists. Failed to create it") - return + with commit_and_close(self.connect_db()) as conn: + if conn.execute("select count(*) from nodes where node_id=?", [self.zone_id]).fetchone()[0] > 0: + self.project.logger.warning("Centroid already exists. Failed to create it") + return - sql = "INSERT into nodes (node_id, is_centroid, geometry) VALUES(?,1,GeomFromWKB(?, ?));" + sql = "INSERT into nodes (node_id, is_centroid, geometry) VALUES(?,1,GeomFromWKB(?, ?));" - if point is None: - point = self.geometry.centroid + if point is None: + point = self.geometry.centroid - if robust: - check_sql = """SELECT count(*) FROM nodes - WHERE nodes.geometry = GeomFromWKB(?, 4326) AND - nodes.ROWID IN ( - SELECT ROWID FROM SpatialIndex WHERE f_table_name = 'nodes' AND - search_frame = GeomFromWKB(?, 4326)) - """ + if robust: + check_sql = """SELECT count(*) FROM nodes + WHERE nodes.geometry = GeomFromWKB(?, 4326) AND + nodes.ROWID IN ( + SELECT ROWID FROM SpatialIndex WHERE f_table_name = 'nodes' AND + search_frame = GeomFromWKB(?, 4326)) + """ - test_list = self.conn.execute(check_sql, [point.wkb, point.wkb]).fetchone() - while sum(test_list): - test_list = self.conn.execute(check_sql, [point.wkb, point.wkb]).fetchone() - point = Point(point.x + random.random() * shift, point.y + random.random() * shift) + test_list = conn.execute(check_sql, [point.wkb, point.wkb]).fetchone() + while sum(test_list): + test_list = conn.execute(check_sql, [point.wkb, point.wkb]).fetchone() + point = Point(point.x + random.random() * shift, point.y + random.random() * shift) - data = [self.zone_id, point.wkb, self.__srid__] - self.conn.execute(sql, data) - self.conn.commit() + data = [self.zone_id, point.wkb, self.__srid__] + conn.execute(sql, data) - def connect_mode(self, mode_id: str, link_types="", connectors=1) -> None: + def connect_mode(self, mode_id: str, link_types="", connectors=1, conn: Optional[Connection] = None) -> None: """Adds centroid connectors for the desired mode to the network file Centroid connectors are created by connecting the zone centroid to one or more nodes selected from @@ -127,7 +119,8 @@ def connect_mode(self, mode_id: str, link_types="", connectors=1) -> None: mode_id=mode_id, link_types=link_types, connectors=connectors, - network=self._project.network, + network=self.project.network, + conn_=conn, ) def disconnect_mode(self, mode_id: str) -> None: @@ -137,17 +130,17 @@ def disconnect_mode(self, mode_id: str) -> None: **mode_id** (:obj:`str`): Mode ID we are trying to disconnect from this zone """ - curr = self.conn.cursor() - data = [self.zone_id, mode_id] - curr.execute("Delete from links where a_node=? and modes=?", data) - row_count = curr.rowcount + with commit_and_close(self.connect_db()) as conn: + data = [self.zone_id, mode_id] + row_count = conn.execute("Delete from links where a_node=? and modes=?", data).rowcount - data = [mode_id, self.zone_id, mode_id] - curr.execute('Update links set modes = replace(modes, ?, "") where a_node=? and instr(modes,?) > 0', data) - row_count += curr.rowcount + data = [mode_id, self.zone_id, mode_id] + sql = 'Update links set modes = replace(modes, ?, "") where a_node=? and instr(modes,?) > 0' + row_count += conn.execute(sql, data).rowcount - if row_count: - self._project.logger.warning(f"Deleted {row_count} connectors for mode {mode_id} for zone {self.zone_id}") - else: - self._project.warning("No centroid connectors for this mode") - self.conn.commit() + if row_count: + self.project.logger.warning( + f"Deleted {row_count} connectors for mode {mode_id} for zone {self.zone_id}" + ) + else: + self.project.warning("No centroid connectors for this mode") diff --git a/aequilibrae/project/zoning.py b/aequilibrae/project/zoning.py index 5f519dedd..630cb2077 100644 --- a/aequilibrae/project/zoning.py +++ b/aequilibrae/project/zoning.py @@ -9,8 +9,10 @@ from aequilibrae.project.basic_table import BasicTable from aequilibrae.project.project_creation import run_queries_from_sql_file from aequilibrae.project.table_loader import TableLoader -from aequilibrae.utils.geo_index import GeoIndex from aequilibrae.project.zone import Zone +from aequilibrae.utils.db_utils import commit_and_close +from aequilibrae.utils.geo_index import GeoIndex +from aequilibrae.utils.spatialite_utils import connect_spatialite class Zoning(BasicTable): @@ -71,7 +73,8 @@ def create_zoning_layer(self): if not self.__has_zoning(): qry_file = join(realpath(__file__), "database_specification", "tables", "zones.sql") - run_queries_from_sql_file(self.conn, qry_file) + with commit_and_close(connect_spatialite(self.project.path_to_file)) as conn: + run_queries_from_sql_file(conn, self.project.logger, qry_file) self.__load() else: self.project.warning("zones table already exists. Nothing was done", Warning) @@ -82,8 +85,9 @@ def coverage(self) -> Polygon: :Returns: **model coverage** (:obj:`Polygon`): Shapely (Multi)polygon of the zoning system. """ - self._curr.execute('Select ST_asBinary("geometry") from zones;') - polygons = [shapely.wkb.loads(x[0]) for x in self._curr.fetchall()] + with commit_and_close(connect_spatialite(self.project.path_to_file)) as conn: + dt = conn.execute('Select ST_asBinary("geometry") from zones;').fetchall() + polygons = [shapely.wkb.loads(x[0]) for x in dt] return unary_union(polygons) def get(self, zone_id: str) -> Zone: @@ -128,13 +132,14 @@ def refresh_geo_index(self): self.__geo_index.insert(feature_id=zone_id, geometry=zone.geometry) def __has_zoning(self): - curr = self.conn.cursor() - curr.execute("SELECT name FROM sqlite_master WHERE type='table';") - return any(["zone" in x[0].lower() for x in curr.fetchall()]) + with commit_and_close(connect_spatialite(self.project.path_to_file)) as conn: + dt = conn.execute("SELECT name FROM sqlite_master WHERE type='table';").fetchall() + return any(["zone" in x[0].lower() for x in dt]) def __load(self): tl = TableLoader() - zones_list = tl.load_table(self._curr, "zones") + with commit_and_close(connect_spatialite(self.project.path_to_file)) as conn: + zones_list = tl.load_table(conn, "zones") self.__fields = deepcopy(tl.fields) existing_list = [zn["zone_id"] for zn in zones_list] diff --git a/aequilibrae/transit/route_system.py b/aequilibrae/transit/route_system.py index 662cc09c5..47cde528e 100644 --- a/aequilibrae/transit/route_system.py +++ b/aequilibrae/transit/route_system.py @@ -1,25 +1,22 @@ import os -import sqlite3 import zipfile from os.path import join import pandas as pd from pyproj import Transformer -# from aequilibrae.tools.geo import Geo +from aequilibrae.project.database_connection import database_connection from aequilibrae.transit.functions.get_srid import get_srid -from aequilibrae.log import logger from aequilibrae.transit.gtfs_writer import write_routes, write_agencies, write_fares from aequilibrae.transit.gtfs_writer import write_stops, write_trips, write_stop_times, write_shapes from aequilibrae.transit.route_system_reader import read_agencies, read_patterns from aequilibrae.transit.route_system_reader import read_stop_times, read_stops, read_trips, read_routes -from aequilibrae.project.database_connection import database_connection +from aequilibrae.utils.db_utils import commit_and_close class RouteSystem: def __init__(self, database_path): self.__database_path = database_path - self.__conn: sqlite3.Connection = None self.agencies = [] self.stops = [] @@ -35,60 +32,45 @@ def __init__(self, database_path): self.transformer = Transformer.from_crs(f"epsg:{get_srid()}", "epsg:4326", always_xy=True) def load_route_system(self): - self._read_agencies() - self._read_stops() - self._read_routes() - self._read_patterns() - self._read_trips() - self._read_stop_times() - - def _read_agencies(self): - self.agencies = read_agencies(self.conn) + with commit_and_close(database_connection(join(self.__database_path, "public_transport.sqlite"))) as conn: + self._read_agencies(conn) + self._read_stops(conn) + self._read_routes(conn) + self._read_patterns(conn) + self._read_trips(conn) + self._read_stop_times(conn) - def _read_stops(self): - self.stops = read_stops(self.conn, self.transformer) + def _read_agencies(self, conn): + self.agencies = read_agencies(conn) - def _read_routes(self): - self.routes = read_routes(self.conn) + def _read_stops(self, conn): + self.stops = read_stops(conn, self.transformer) - def _read_patterns(self): - self.patterns = self.patterns or read_patterns(self.conn, self.transformer) + def _read_routes(self, conn): + self.routes = read_routes(conn) - def _read_trips(self): - self.trips = self.trips or read_trips(self.conn) + def _read_patterns(self, conn): + self.patterns = self.patterns or read_patterns(conn, self.transformer) - def _read_stop_times(self): - self.stop_times = read_stop_times(self.conn) + def _read_trips(self, conn): + self.trips = self.trips or read_trips(conn) - @property - def conn(self) -> sqlite3.Connection: - self.__conn = self.__conn or database_connection(join(self.__database_path, "public_transport.sqlite")) - return self.__conn + def _read_stop_times(self, conn): + self.stop_times = read_stop_times(conn) def write_GTFS(self, path_to_folder: str): """ """ - # timezone = self._timezone() - - write_agencies(self.agencies, path_to_folder) - write_stops(self.stops, path_to_folder) - write_routes(self.routes, path_to_folder) - write_shapes(self.patterns, path_to_folder) - - write_trips(self.trips, path_to_folder, self.conn) - write_stop_times(self.stop_times, path_to_folder) - write_fares(path_to_folder, self.conn) - self._zip_feed(path_to_folder) - - # def _timezone(self, allow_error=True): - # geotool = Geo() - # geotool.conn = self.conn - # try: - # return geotool.get_timezone() - # except Exception as e: - # logger.error("Could not retrieve the correct time zone for GTFS exporter. Using Chicago instead") - # if not allow_error: - # raise e - # return "America/Chicago" + + with commit_and_close(database_connection(join(self.__database_path, "public_transport.sqlite"))) as conn: + write_agencies(self.agencies, path_to_folder) + write_stops(self.stops, path_to_folder) + write_routes(self.routes, path_to_folder) + write_shapes(self.patterns, path_to_folder) + + write_trips(self.trips, path_to_folder, conn) + write_stop_times(self.stop_times, path_to_folder) + write_fares(path_to_folder, conn) + self._zip_feed(path_to_folder) def _zip_feed(self, path_to_folder: str): filename = join(path_to_folder, "polaris_gtfs.zip") diff --git a/aequilibrae/utils/db_utils.py b/aequilibrae/utils/db_utils.py index d1bee1fb8..23534138f 100644 --- a/aequilibrae/utils/db_utils.py +++ b/aequilibrae/utils/db_utils.py @@ -1,7 +1,7 @@ from dataclasses import dataclass from os import PathLike from pathlib import Path -from sqlite3 import Connection, Cursor, connect +from sqlite3 import Connection, connect from typing import Union import pandas as pd @@ -19,12 +19,6 @@ def safe_connect(filepath: PathLike, missing_ok=False): raise FileNotFoundError(f"Attempting to open non-existant SQLite database: {filepath}") -def normalise_conn(curr: Union[Cursor, Connection, PathLike]): - if isinstance(curr, Cursor) or isinstance(curr, Connection): - return curr - return safe_connect(curr) - - class commit_and_close: """A context manager for sqlite connections which closes and commits.""" diff --git a/aequilibrae/utils/spatialite_utils.py b/aequilibrae/utils/spatialite_utils.py index 1a9ae5da8..160b27297 100644 --- a/aequilibrae/utils/spatialite_utils.py +++ b/aequilibrae/utils/spatialite_utils.py @@ -99,12 +99,11 @@ def _download_and_extract_spatialite(directory: os.PathLike) -> None: def spatialize_db(conn, logger=None): - logger = logging.getLogger("aequilibrae") + logger = logger or logging.getLogger("aequilibrae") logger.info("Adding Spatialite infrastructure to the database") - curr = conn.cursor() if not inside_qgis and not is_spatialite(conn): try: - curr.execute("SELECT InitSpatialMetaData();") + conn.execute("SELECT InitSpatialMetaData();") conn.commit() except Exception as e: logger.error("Problem with spatialite", e.args)