Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

removes use of necessary sqlite3 connection cursors and cached connections #478

Merged
merged 15 commits into from
Dec 14, 2023
76 changes: 36 additions & 40 deletions aequilibrae/project/about.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from os.path import join, dirname, realpath
import string
import uuid
from os.path import join, dirname, realpath

from aequilibrae.project.project_creation import run_queries_from_sql_file
from aequilibrae.utils.db_utils import commit_and_close
from aequilibrae.utils.spatialite_utils import connect_spatialite


class About:
Expand All @@ -27,29 +30,28 @@ class About:
def __init__(self, project):
self.__characteristics = []
self.__original = {}
self.__conn = project.conn
self.__path_to_file = project.path_to_file
self.logger = project.logger
if self.__has_about():
self.__load()

with commit_and_close(connect_spatialite(self.__path_to_file)) as conn:
if self.__has_about(conn):
self.__load(conn)

def create(self):
"""Creates the 'about' table for project files that did not previously contain it"""

if not self.__has_about():
qry_file = join(dirname(realpath(__file__)), "database_specification", "tables", "about.sql")
run_queries_from_sql_file(self.__conn, qry_file)
with commit_and_close(connect_spatialite(self.__path_to_file)) as conn:
if not self.__has_about(conn):
qry_file = join(dirname(realpath(__file__)), "database_specification", "tables", "about.sql")
run_queries_from_sql_file(conn, self.logger, qry_file)

cursor = self.__conn.cursor()

sql = "SELECT count(*) as num_records from about;"
if self.__conn.execute(sql).fetchone()[0] == 0:
cursor.execute(f"UPDATE 'about' set infovalue='{uuid.uuid4().hex}' where infoname='project_ID'")
cursor.execute("UPDATE 'about' set infovalue='right' where infoname='driving_side'")
self.__conn.commit()

self.__load()
else:
self.logger.warning("About table already exists. Nothing was done")
sql = "SELECT count(*) as num_records from about;"
if conn.execute(sql).fetchone()[0] == 0:
conn.execute(f"UPDATE 'about' set infovalue='{uuid.uuid4().hex}' where infoname='project_ID'")
conn.execute("UPDATE 'about' set infovalue='right' where infoname='driving_side'")
self.__load(conn)
else:
self.logger.warning("About table already exists. Nothing was done")

def list_fields(self) -> list:
"""Returns a list of all characteristics the about table holds"""
Expand Down Expand Up @@ -78,10 +80,8 @@ def add_info_field(self, info_field: str) -> None:
if has_forbidden:
raise ValueError(f"{info_field} is not valid as a metadata field. Should be a lower case ascii letter or _")

sql = "INSERT INTO 'about' (infoname) VALUES(?)"
curr = self.__conn.cursor()
curr.execute(sql, [info_field])
self.__conn.commit()
with commit_and_close(connect_spatialite(self.__path_to_file)) as conn:
conn.execute("INSERT INTO 'about' (infoname) VALUES(?)", [info_field])
self.__characteristics.append(info_field)
self.__original[info_field] = None

Expand All @@ -96,25 +96,21 @@ def write_back(self):
>>> p.about.description = 'This is the example project. Do not use for forecast'
>>> p.about.write_back()
"""
curr = self.__conn.cursor()
for k in self.__characteristics:
v = self.__dict__[k]
if v != self.__original[k]:
curr.execute("UPDATE 'about' set infovalue = ? where infoname=?", [v, k])
self.logger.info(f"Updated {k} on About_Table to {v}")
self.__conn.commit()

def __has_about(self):
curr = self.__conn.cursor()
curr.execute("SELECT name FROM sqlite_master WHERE type='table';")
return any(["about" in x[0] for x in curr.fetchall()])

def __load(self):
with commit_and_close(connect_spatialite(self.__path_to_file)) as conn:
for k in self.__characteristics:
v = self.__dict__[k]
if v != self.__original[k]:
conn.execute("UPDATE 'about' set infovalue = ? where infoname=?", [v, k])
self.logger.info(f"Updated {k} on About_Table to {v}")

def __has_about(self, conn):
sql = "SELECT name FROM sqlite_master WHERE type='table';"
return any(["about" in x[0] for x in conn.execute(sql).fetchall()])

def __load(self, conn):
self.__characteristics = []
curr = self.__conn.cursor()
curr.execute("select infoname, infovalue from 'about'")

for x in curr.fetchall():
sql = "select infoname, infovalue from 'about'"
for x in conn.execute(sql).fetchall():
self.__characteristics.append(x[0])
self.__dict__[x[0]] = x[1]
self.__original[x[0]] = x[1]
14 changes: 4 additions & 10 deletions aequilibrae/project/basic_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from shapely.geometry import Polygon

from aequilibrae.project.field_editor import FieldEditor
from aequilibrae.utils.db_utils import commit_and_close


class BasicTable:
Expand All @@ -12,29 +13,22 @@ class BasicTable:
def __init__(self, project):
self.project = project
self.__table_type__ = ""
self.conn = project.connect()
self._curr = self.conn.cursor()

def extent(self) -> Polygon:
"""Queries the extent of thelayer included in the model

Returns:
*model extent* (:obj:`Polygon`): Shapely polygon with the bounding box of the layer.
"""
self.__curr.execute(f'Select ST_asBinary(GetLayerExtent("{self.__table_type__}"))')
poly = shapely.wkb.loads(self.__curr.fetchone()[0])
return poly
with commit_and_close(self.project.connect()) as conn:
data = conn.execute(f'Select ST_asBinary(GetLayerExtent("{self.__table_type__}"))').fetchone()[0]
return shapely.wkb.loads(data)

@property
def fields(self) -> FieldEditor:
"""Returns a FieldEditor class instance to edit the zones table fields and their metadata"""
return FieldEditor(self.project, self.__table_type__)

def refresh_connection(self):
"""Opens a new database connection to avoid thread conflict"""
self.conn = self.project.connect()
self.__curr = self.conn.cursor()

def __copy__(self):
raise Exception(f"{self.__table_type__} object cannot be copied")

Expand Down
31 changes: 17 additions & 14 deletions aequilibrae/project/data/matrices.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import os
from os.path import isfile, join

import pandas as pd
from aequilibrae.project.table_loader import TableLoader

from aequilibrae.matrix import AequilibraeMatrix
from aequilibrae.project.data.matrix_record import MatrixRecord
from aequilibrae.project.table_loader import TableLoader
from aequilibrae.utils.db_utils import commit_and_close


class Matrices:
Expand All @@ -17,10 +20,9 @@ def __init__(self, project):

self.fldr = os.path.join(project.project_base_path, "matrices")

self.conn = project.connect()
self.curr = self.conn.cursor()
tl = TableLoader()
matrices_list = tl.load_table(self.curr, "matrices")
with commit_and_close(self.project.connect()) as conn:
matrices_list = tl.load_table(conn, "matrices")
self.__fields = [x for x in tl.fields]
if matrices_list:
self.__properties = list(matrices_list[0].keys())
Expand All @@ -41,16 +43,16 @@ def reload(self):
def clear_database(self) -> None:
"""Removes records from the matrices database that do not exist in disk"""

self.curr.execute("Select name, file_name from matrices;")
with commit_and_close(self.project.connect()) as conn:
mats = conn.execute("Select name, file_name from matrices;").fetchall()

remove = [nm for nm, file in self.curr.fetchall() if not isfile(join(self.fldr, file))]
remove = [nm for nm, file in mats if not isfile(join(self.fldr, file))]

if remove:
self.logger.warning(f'Matrix records not found in disk cleaned from database: {",".join(remove)}')
if remove:
self.logger.warning(f'Matrix records not found in disk cleaned from database: {",".join(remove)}')

remove = [[x] for x in remove]
self.curr.executemany("DELETE from matrices where name=?;", remove)
self.conn.commit()
remove = [[x] for x in remove]
conn.executemany("DELETE from matrices where name=?;", remove)

def update_database(self) -> None:
"""Adds records to the matrices database for matrix files found on disk"""
Expand Down Expand Up @@ -96,9 +98,10 @@ def check_if_exists(file_name):
else:
return "file missing"

df = pd.read_sql_query("Select * from matrices;", self.conn)
df = df.assign(status="")
df.status = df.file_name.apply(check_if_exists)
with commit_and_close(self.project.connect()) as conn:
df = pd.read_sql_query("Select * from matrices;", conn)
df = df.assign(status="")
df.status = df.file_name.apply(check_if_exists)

return df

Expand Down
67 changes: 30 additions & 37 deletions aequilibrae/project/data/matrix_record.py
Original file line number Diff line number Diff line change
@@ -1,48 +1,47 @@
from os import unlink
from os.path import isfile, join
from aequilibrae.project.network.safe_class import SafeClass

from aequilibrae.matrix.aequilibrae_matrix import AequilibraeMatrix
from aequilibrae.project.network.safe_class import SafeClass
from aequilibrae.utils.db_utils import commit_and_close


class MatrixRecord(SafeClass):
def __init__(self, data_set: dict, project):
super().__init__(data_set, project)
self._exists = True
self.fldr = join(project.project_base_path, "matrices")
self._exists: bool
self.fldr: str
self.__dict__["_exists"] = True
self.__dict__["fldr"] = join(project.project_base_path, "matrices")

def save(self):
"""Saves matrix record to the project database"""
conn = self.connect_db()
curr = conn.cursor()
with commit_and_close(self.connect_db()) as conn:
sql = "select count(*) from matrices where name=?"

curr.execute("select count(*) from matrices where name=?", [self.name])
if curr.fetchone()[0] == 0:
data = [str(self.name), str(self.file_name), int(self.cores)]
curr.execute("Insert into matrices (name, file_name, cores) values(?,?,?)", data)
if conn.execute(sql, [self.name]).fetchone()[0] == 0:
data = [str(self.name), str(self.file_name), int(self.cores)]
conn.execute("Insert into matrices (name, file_name, cores) values(?,?,?)", data)

for key, value in self.__dict__.items():
if key != "name" and key in self.__original__:
v_old = self.__original__.get(key, None)
if value != v_old and value:
self.__original__[key] = value
curr.execute(f"update matrices set '{key}'=? where name=?", [value, self.name])
conn.commit()
conn.close()
for key, value in self.__dict__.items():
if key != "name" and key in self.__original__:
v_old = self.__original__.get(key, None)
if value != v_old and value:
self.__original__[key] = value
conn.execute(f"update matrices set '{key}'=? where name=?", [value, self.name])

def delete(self):
"""Deletes this matrix record and the underlying data from disk"""
conn = self.connect_db()
curr = conn.cursor()
curr.execute("DELETE FROM matrices where name=?", [self.name])
conn.commit()
with commit_and_close(self.connect_db()) as conn:
conn.execute("DELETE FROM matrices where name=?", [self.name])

if isfile(join(self.fldr, self.file_name)):
try:
unlink(join(self.fldr, self.file_name))
except Exception as e:
self._logger.error(f"Could not remove matrix from disk: {e.args}")

conn.close()
self._exists = False
self.__dict__["_exists"] = False

def update_cores(self):
"""Updates this matrix record with the matrix core count in disk"""
Expand All @@ -59,20 +58,14 @@ def get_data(self) -> AequilibraeMatrix:
return mat

def __setattr__(self, instance, value) -> None:
if instance == "name":
value = str(value).lower()
conn = self.connect_db()
curr = conn.cursor()
curr.execute("Select count(*) from matrices where LOWER(name)=?", [value])
if sum(curr.fetchone()) > 0:
raise ValueError("Another matrix with this name already exists")
conn.close()
elif instance == "file_name":
conn = self.connect_db()
curr = conn.cursor()
curr.execute("Select count(*) from matrices where LOWER(file_name)=?", [str(value).lower()])
if sum(curr.fetchone()) > 0:
raise ValueError("There is another matrix record for this file")
with commit_and_close(self.connect_db()) as conn:
sql = f"Select count(*) from matrices where LOWER({instance})=?"
qry_value = sum(conn.execute(sql, [str(value).lower()]).fetchone())
if qry_value > 0:
if instance == "name":
raise ValueError("Another matrix with this name already exists")
elif instance == "file_name":
raise ValueError("There is another matrix record for this file")

self.__dict__[instance] = value
if instance in ["file_name", "cores"]:
Expand Down
30 changes: 17 additions & 13 deletions aequilibrae/project/data_loader.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,32 @@
from sqlite3 import Connection
import shapely.wkb
from os import PathLike

import pandas as pd
import shapely.wkb

from aequilibrae.utils.db_utils import commit_and_close
from aequilibrae.utils.spatialite_utils import connect_spatialite


class DataLoader:
def __init__(self, conn: Connection, table_name: str):
self.conn = conn
self.curr = conn.cursor()
def __init__(self, path_to_file: PathLike, table_name: str):
self.__pth_file = path_to_file
self.table_name = table_name

def load_table(self) -> pd.DataFrame:
fields, _, geo_field = self.__find_table_fields()
fields = [f'"{x}"' for x in fields]
if geo_field is not None:
fields.append('ST_AsBinary("geometry") geometry')
keys = ",".join(fields)
df = pd.read_sql_query(f"select {keys} from '{self.table_name}'", self.conn)
with commit_and_close(connect_spatialite(self.__pth_file)) as conn:
fields, _, geo_field = self.__find_table_fields()
fields = [f'"{x}"' for x in fields]
if geo_field is not None:
fields.append('ST_AsBinary("geometry") geometry')
keys = ",".join(fields)
df = pd.read_sql_query(f"select {keys} from '{self.table_name}'", conn)
df.geometry = df.geometry.apply(shapely.wkb.loads)
return df

def __find_table_fields(self):
with commit_and_close(connect_spatialite(self.__pth_file)) as conn:
structure = conn.execute(f"pragma table_info({self.table_name})").fetchall()
geotypes = ["LINESTRING", "POINT", "POLYGON", "MULTIPOLYGON"]
self.curr.execute(f"pragma table_info({self.table_name})")
structure = self.curr.fetchall()
fields = [x[1].lower() for x in structure]
geotype = geo_field = None
for x in structure:
Expand Down
21 changes: 9 additions & 12 deletions aequilibrae/project/field_editor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import string
from typing import List

from aequilibrae.utils.db_utils import commit_and_close

ALLOWED_CHARACTERS = string.ascii_letters + "_0123456789"


Expand Down Expand Up @@ -129,18 +131,13 @@ def __adds_to_attribute_table(self, attribute_name, attribute_value):
self.__run_query_commit(qry, vals)

def __run_query_fetch_all(self, qry: str):
conn = self.project.connect()
curr = conn.cursor()
curr.execute(qry)
dt = curr.fetchall()
conn.close()
with commit_and_close(self.project.connect()) as conn:
dt = conn.execute(qry).fetchall()
return dt

def __run_query_commit(self, qry: str, values=None) -> None:
conn = self.project.connect()
if values is None:
conn.execute(qry)
else:
conn.execute(qry, values)
conn.commit()
conn.close()
with commit_and_close(self.project.connect()) as conn:
if values is None:
conn.execute(qry)
else:
conn.execute(qry, values)
Loading