Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

v.db.join: handle existing columns properly #3765

Merged
merged 4 commits into from
Jun 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 84 additions & 0 deletions scripts/v.db.join/testsuite/test_v_db_join.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
"""
TEST: test_v_db_join.py

AUTHOR(S): Stefan Blumentrath

PURPOSE: Test for v.db.join

COPYRIGHT: (C) 2024 Stefan Blumentrath, and by the GRASS Development Team

This program is free software under the GNU General Public
License (>=v2). Read the file COPYING that comes with GRASS
for details.
"""

from grass.gunittest.case import TestCase
from grass.gunittest.main import test
from grass.gunittest.gmodules import SimpleModule


class TestVDbJoin(TestCase):
    """Test the v.db.join script.

    The fixture copies the ``firestations`` vector map and creates two
    stand-alone attribute tables keyed by ``CITY``: one that only contains
    columns which are new to the map, and one whose only extra column
    (``others``) is expected to already exist in the map's attribute table
    (NOTE(review): relies on the sample dataset's ``OTHERS`` column - confirm).
    """

    # Name of the private copy of the vector map that the tests modify.
    test_map = "test_firestations"
    # Table with columns that are not expected to exist in the vector map yet.
    new_columns_table = "firestation_test_table"
    # Table whose non-key column is expected to exist in the vector map.
    existing_columns_table = "firestation_test_table_update"

    @classmethod
    def setUpClass(cls):
        """Copy the vector map and create the tables that get joined to it."""
        # Each statement is sent through its own db.execute call so that no
        # statement depends on how the database driver treats a string that
        # contains several statements.
        statements = [
            f"""CREATE TABLE {cls.new_columns_table} (
            CITY text,
            some_number int,
            some_text text,
            some_double double precision,
            some_float real
            );""",
            # One value per declared column, no dangling commas.
            f"""INSERT INTO {cls.new_columns_table} VALUES
            ('Cary', 1, 'short', 1.1233445366756784345, 1.1),
            ('Apex', 2, 'longer', -111.1220390953406936354, -111.1),
            ('Garner', 3, 'short', 4.20529509802443234245, 4.2),
            ('Raleigh', 4, 'even longer than before', 32.913873948295837592, 32.9);""",
            f"""CREATE TABLE {cls.existing_columns_table} (
            CITY text,
            others int
            );""",
            f"""INSERT INTO {cls.existing_columns_table} VALUES
            ('Cary', 1),
            ('Apex', 2),
            ('Garner', 3),
            ('Raleigh', 4);""",
        ]
        cls.runModule("g.copy", vector=["firestations", cls.test_map])
        for statement in statements:
            cls.runModule("db.execute", sql=statement)

    @classmethod
    def tearDownClass(cls):
        """Remove the copied vector map and the tables created for the tests."""
        cls.runModule("g.remove", type="vector", name=cls.test_map, flags="f")
        for table in (cls.new_columns_table, cls.existing_columns_table):
            cls.runModule("db.execute", sql=f"DROP TABLE {table};")

    def test_join_firestations_table(self):
        """Join firestations table with new different columns"""
        module = SimpleModule(
            "v.db.join",
            map=self.test_map,
            column="CITY",
            other_table=self.new_columns_table,
            other_column="CITY",
        )
        self.assertModule(module)

    def test_join_firestations_table_existing(self):
        """Join firestations table with only existing columns"""
        module = SimpleModule(
            "v.db.join",
            map=self.test_map,
            column="CITY",
            other_table=self.existing_columns_table,
            other_column="CITY",
        )
        self.assertModule(module)


# Hand control to the gunittest runner when this file is executed directly.
if __name__ == "__main__":
    test()
178 changes: 93 additions & 85 deletions scripts/v.db.join/v.db.join.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,80 +64,87 @@
# %end

import atexit
import os
import sys
import grass.script as grass

from pathlib import Path

import grass.script as gs
from grass.exceptions import CalledModuleError

# Temporary files created while the script runs; removed by cleanup() at exit.
rm_files = []


def cleanup():
    """Remove every temporary file registered in ``rm_files``.

    Entries may be ``pathlib.Path`` objects or plain path strings. Files that
    are already gone are skipped silently. A file that cannot be removed is
    reported as a warning so the remaining entries are still processed.
    """
    for entry in rm_files:
        try:
            Path(entry).unlink(missing_ok=True)
        except OSError as e:
            gs.warning(
                _("Unable to remove file {file}: {message}").format(
                    file=entry, message=e
                )
            )


def _column_type(col_desc, driver):
    """Return the SQL type used to create a column in the map table.

    ``col_desc`` is a column description without the leading name, i.e. the
    type followed by the length when the description provides one.
    """
    col_type = f"{col_desc[0]}"
    # Sqlite 3 does not support the precision number any more
    if len(col_desc) < 2 or driver == "sqlite":
        return col_type
    # MySQL - expect format DOUBLE PRECISION(M,D), see #2792
    if driver == "mysql" and col_type == "DOUBLE PRECISION":
        return col_type
    return f"{col_type}({col_desc[1]})"


def main():
    """Join columns of another table to the attribute table of a vector map.

    Reads the module options from the global ``options`` dictionary, adds the
    selected columns of ``other_table`` that are missing in the map table, and
    then fills all selected columns (new and already existing ones) with the
    values of the rows whose ``other_column`` equals the map's ``column``.

    Column names are compared case-insensitively because they are unquoted
    SQL identifiers. Exits through ``gs.fatal`` on invalid input or failing
    database operations, and through ``sys.exit(1)`` when the database
    connection of the map cannot be read.

    Returns 0 on success.
    """
    # Include mapset into the name, so we avoid multiple messages about
    # found in more mapsets.
    vector_map = gs.find_file(options["map"], element="vector")["fullname"]
    layer = options["layer"]
    column = options["column"]
    otable = options["other_table"]
    ocolumn = options["other_column"]
    scolumns = None
    if options["subset_columns"]:
        scolumns = [name.strip() for name in options["subset_columns"].split(",")]
    ecolumns = None
    if options["exclude_columns"]:
        ecolumns = [name.strip() for name in options["exclude_columns"].split(",")]

    try:
        f = gs.vector_layer_db(vector_map, layer)
    except CalledModuleError:
        sys.exit(1)

    maptable = f["table"]
    database = f["database"]
    driver = f["driver"]

    if driver == "dbf":
        gs.fatal(_("JOIN is not supported for tables stored in DBF format"))

    if not maptable:
        gs.fatal(
            _("There is no table connected to this map. Unable to join any column.")
        )

    # This is used for testing presence (and potential name conflict) with
    # the newly added columns, but the test needs to be case-insensitive since
    # it is SQL, so we lowercase the names here and in the test.
    # An alternative is quoting identifiers (as in e.g. #3634)
    all_cols_tt = {name.lower() for name in gs.vector_columns(vector_map, int(layer))}

    # check if column is in map table
    if column.lower() not in all_cols_tt:
        gs.fatal(
            _("Column <{column}> not found in table <{table}>").format(
                column=column, table=maptable
            )
        )

    # describe other table: lower-cased column name -> rest of the description
    all_cols_ot = {
        col_desc[0].lower(): col_desc[1:]
        for col_desc in gs.db_describe(otable, driver=driver, database=database)["cols"]
    }

    # check if ocolumn is on other table
    if ocolumn.lower() not in all_cols_ot:
        gs.fatal(
            _("Column <{column}> not found in table <{table}>").format(
                column=ocolumn, table=otable
            )
        )

    # determine columns subset from other table
    if not scolumns:
        # select all columns from other table; work on a copy so that removing
        # columns below does not alter the description of the other table
        cols_to_update = dict(all_cols_ot)
    else:
        cols_to_update = {}
        # check if scolumns exists in the other table
        for scol in scolumns:
            key = scol.lower()
            if key not in all_cols_ot:
                gs.warning(
                    _("Column <{column}> not found in table <{table}>").format(
                        column=scol, table=otable
                    )
                )
            else:
                cols_to_update[key] = all_cols_ot[key]

    # skip the vector column which is used for join
    cols_to_update.pop(column.lower(), None)

    # exclude columns from other table
    if ecolumns:
        for ecol in ecolumns:
            key = ecol.lower()
            if key not in all_cols_ot:
                gs.warning(
                    _("Column <{column}> not found in table <{table}>").format(
                        column=ecol, table=otable
                    )
                )
            else:
                # The column may legitimately be absent from the selection
                # (not part of the subset, or the join column itself).
                cols_to_update.pop(key, None)

    # add only the new columns to the table
    new_col_names = [name for name in cols_to_update if name not in all_cols_tt]
    cols_to_add = [
        f"{name} {_column_type(cols_to_update[name], driver)}"
        for name in new_col_names
    ]

    if cols_to_add:
        try:
            gs.run_command(
                "v.db.addcolumn",
                map=vector_map,
                columns=",".join(cols_to_add),
                layer=layer,
            )
        except CalledModuleError:
            gs.fatal(
                _("Error creating columns <{}>").format(", ".join(new_col_names))
            )

    # One UPDATE per column, new and pre-existing ones alike.
    update_statements = [
        f"UPDATE {maptable} SET {col} = (SELECT {col} FROM "
        f"{otable} WHERE "
        f"{otable}.{ocolumn}={maptable}.{column});\n"
        for col in cols_to_update
    ]
    update_str = "BEGIN TRANSACTION\n" + "".join(update_statements) + "END TRANSACTION"
    updated_cols_str = ", ".join(cols_to_update)
    gs.debug(update_str, 1)
    gs.verbose(
        _("Updating columns {columns} of vector map {map_name}...").format(
            columns=updated_cols_str, map_name=vector_map
        )
    )
    sql_file = Path(gs.tempfile())
    # Registered before writing so that cleanup() removes it in any case.
    rm_files.append(sql_file)
    sql_file.write_text(update_str, encoding="UTF8")

    try:
        gs.run_command(
            "db.execute",
            input=str(sql_file),
            database=database,
            driver=driver,
        )
    except CalledModuleError:
        gs.fatal(_("Error filling columns {}").format(updated_cols_str))

    # write cmd history
    gs.vector_history(vector_map)

    return 0


if __name__ == "__main__":
    # Parse the command line once; main() reads the module-level ``options``.
    options, flags = gs.parser()
    # Make sure temporary files are removed however the script terminates.
    atexit.register(cleanup)
    sys.exit(main())
Loading