Merge pull request #27 from richpsharp/bugfix/TASKGRAPH-26-corrupt-table
Bugfix/taskgraph 26 corrupt table
davemfish authored Jun 4, 2020
2 parents 516e9ce + 8bfdeb2 commit 9296d7f
Showing 4 changed files with 206 additions and 33 deletions.
8 changes: 7 additions & 1 deletion HISTORY.rst
@@ -4,8 +4,14 @@
TaskGraph Release History
=========================

Unreleased Changes
Unreleased changes
------------------
* Fixed issue that would cause an infinite loop if a ``TaskGraph`` object were
created with a database from an incompatible previous version. Behavior now
is to log the issue, delete the old database, and create a new compatible
one.
* Fixed issue that would cause some rare infinite loops if ``TaskGraph`` were
to fail due to some kinds of task exceptions.
* Adding open source BSD-3-Clause license.

0.9.0 (2020-03-05)
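The first entry above describes the new recovery behavior: when a TaskGraph is pointed at a cache directory whose database was written by an incompatible earlier version, it now logs the problem, deletes the stale database, and creates a compatible one rather than looping forever. A minimal sketch of what that looks like from the caller's side, assuming only the public API touched in this diff (the cache directory name and the stale schema are illustrative):

import logging
import os
import sqlite3

import taskgraph

logging.basicConfig(level=logging.DEBUG)

# Illustrative cache directory holding a database with an old/unknown
# layout; the filename matches the one taskgraph expects.
cache_dir = 'taskgraph_cache'
os.makedirs(cache_dir, exist_ok=True)
db_path = os.path.join(cache_dir, taskgraph._TASKGRAPH_DATABASE_FILENAME)
connection = sqlite3.connect(db_path)
connection.executescript(
    'CREATE TABLE taskgraph_data (some_old_column TEXT NOT NULL);')
connection.commit()
connection.close()

# With this fix, constructing the TaskGraph logs the incompatibility,
# removes the stale database, and recreates a compatible one instead of
# spinning in an infinite loop.
graph = taskgraph.TaskGraph(cache_dir, -1)  # -1: run in the calling thread
graph.close()
graph.join()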
168 changes: 139 additions & 29 deletions taskgraph/Task.py
@@ -1,4 +1,5 @@
"""Task graph framework."""
from pkg_resources import get_distribution
import collections
import hashlib
import inspect
@@ -19,6 +20,9 @@

import retrying

__version__ = get_distribution('taskgraph').version


_VALID_PATH_TYPES = (str, pathlib.Path)
_TASKGRAPH_DATABASE_FILENAME = 'taskgraph_data.db'

@@ -82,7 +86,7 @@ def _initialize_logging_to_queue(logging_queue):
``multiprocessing.Pool`` to establish logging from a Pool worker to the
main python process via a multiprocessing Queue.
Parameters:
Args:
logging_queue (multiprocessing.Queue): The queue to use for passing
log records back to the main process.
@@ -105,6 +109,106 @@ def _initialize_logging_to_queue(logging_queue):
root_logger.addHandler(handler)


def _create_taskgraph_table_schema(taskgraph_database_path):
"""Create database exists and/or ensures it is compatible and recreate.
Args:
taskgraph_database_path (str): path to an existing database or desired
location of a new database.
Returns:
None.
"""
sql_create_projects_table_script = (
"""
CREATE TABLE taskgraph_data (
task_reexecution_hash TEXT NOT NULL,
target_path_stats BLOB NOT NULL,
result BLOB NOT NULL,
PRIMARY KEY (task_reexecution_hash)
);
CREATE TABLE global_variables (
key TEXT NOT NULL,
value BLOB,
PRIMARY KEY (key)
);
""")

table_valid = True
expected_table_column_name_map = {
'taskgraph_data': [
'task_reexecution_hash', 'target_path_stats', 'result'],
'global_variables': ['key', 'value']}
if os.path.exists(taskgraph_database_path):
try:
# check that the tables exist and the column names are as expected
for expected_table_name in expected_table_column_name_map:
table_result = _execute_sqlite(
'''
SELECT name
FROM sqlite_master
WHERE type='table' AND name=?
''', taskgraph_database_path,
argument_list=[expected_table_name],
mode='read_only', execute='execute', fetch='all')
if not table_result:
raise ValueError(f'missing table {expected_table_name}')

# this query returns a list of results of the form
# [(0, 'task_reexecution_hash', 'TEXT', 1, None, 1), ... ]
# we'll just check that the header names are the same, no
# need to be super aggressive, also need to construct the
# PRAGMA string directly since it doesn't take arguments
table_info_result = _execute_sqlite(
f'PRAGMA table_info({expected_table_name})',
taskgraph_database_path, mode='read_only',
execute='execute', fetch='all')

expected_column_names = expected_table_column_name_map[
expected_table_name]
header_count = 0
for header_line in table_info_result:
column_name = header_line[1]
if column_name not in expected_column_names:
raise ValueError(
f'expected {column_name} in table '
f'{expected_table_name} but not found')
header_count += 1
if header_count < len(expected_column_names):
raise ValueError(
f'found only {header_count} of an expected '
f'{len(expected_column_names)} columns in table '
f'{expected_table_name}')
if not table_info_result:
raise ValueError(f'missing table {expected_table_name}')
except Exception:
# catch all "Exception"s because anything that goes wrong while
# checking the database should be considered a bad database and we
# should make a new one.
LOGGER.exception(
f'{taskgraph_database_path} exists, but is incompatible '
'somehow. Deleting and making a new one.')
os.remove(taskgraph_database_path)
table_valid = False
else:
# table does not exist
table_valid = False

if not table_valid:
# create the base table
_execute_sqlite(
sql_create_projects_table_script, taskgraph_database_path,
mode='modify', execute='script')
# set the database version
_execute_sqlite(
'''
INSERT OR REPLACE INTO global_variables
VALUES ("version", ?)
''', taskgraph_database_path, mode='modify',
argument_list=(__version__,))


class TaskGraph(object):
"""Encapsulates the worker and tasks states for parallel processing."""

@@ -116,7 +220,7 @@ def __init__(
Creates an object for building task graphs, executing them,
parallelizing independent work notes, and avoiding repeated calls.
Parameters:
Args:
taskgraph_cache_dir_path (string): path to a directory that
either contains a taskgraph cache from a previous instance or
will create a new one if none exists.
@@ -204,21 +308,22 @@ def __init__(
self._task_database_path = os.path.join(
self._taskgraph_cache_dir_path, _TASKGRAPH_DATABASE_FILENAME)

sql_create_projects_table = (
"""
CREATE TABLE IF NOT EXISTS taskgraph_data (
task_reexecution_hash TEXT NOT NULL,
target_path_stats BLOB NOT NULL,
result BLOB NOT NULL,
PRIMARY KEY (task_reexecution_hash)
);
CREATE UNIQUE INDEX IF NOT EXISTS task_reexecution_hash_index
ON taskgraph_data (task_reexecution_hash);
""")

_execute_sqlite(
sql_create_projects_table, self._task_database_path, mode='modify',
execute='script')
# create new table if needed
_create_taskgraph_table_schema(self._task_database_path)

# check the version of the database and warn if a problem
local_version = _execute_sqlite(
'''
SELECT value
FROM global_variables
WHERE key=?
''', self._task_database_path, mode='read_only',
fetch='one', argument_list=['version'])[0]
if local_version != __version__:
LOGGER.warn(
f'the database located at {self._task_database_path} was '
f'created with TaskGraph version {local_version} but the '
f'current version is {__version__}')

# no need to set up schedulers if n_workers is single threaded
self._n_workers = n_workers
@@ -432,7 +537,7 @@ def add_task(
transient_run=False):
"""Add a task to the task graph.
Parameters:
Args:
func (callable): target function
args (list): argument list for `func`
kwargs (dict): keyword arguments for `func`
@@ -640,7 +745,7 @@ def _handle_logs_from_processes(self, queue_):
def _execution_monitor(self, monitor_wait_event):
"""Log state of taskgraph every `self._reporting_interval` seconds.
Parameters:
Args:
monitor_wait_event (threading.Event): used to sleep the monitor
for `self._reporting_interval` seconds, or to wake up to
terminate for shutdown.
@@ -684,7 +789,7 @@ def _execution_monitor(self, monitor_wait_event):
def join(self, timeout=None):
"""Join all threads in the graph.
Parameters:
Args:
timeout (float): if not none will attempt to join subtasks with
this value. If a subtask times out, the whole function will
timeout.
@@ -780,7 +885,7 @@ def __init__(
task_database_path):
"""Make a Task.
Parameters:
Args:
task_name (int): unique task id from the task graph.
func (function): a function that takes the argument list
`args`
@@ -1240,7 +1345,7 @@ def get(self, timeout=None):
determined by a call to `.join()`. Otherwise will wait up to `timeout`
seconds before raising a `RuntimeError` if exceeded.
Parameters:
Args:
timeout (float): if not None this parameter is a floating point
number specifying a timeout for the operation in seconds.
@@ -1261,7 +1366,7 @@ def _get_file_stats(
ignore_directories):
"""Return fingerprints of any filepaths in `base_value`.
Parameters:
Args:
base_value: any python value. Any file paths in `base_value`
should be "os.path.norm"ed before this function is called.
contains filepaths in any nested structure.
@@ -1327,7 +1432,7 @@ def _filter_non_files(
base_value, keep_list, ignore_list, keep_directories):
"""Remove any values that are files not in ignore list or directories.
Parameters:
Args:
base_value: any python value. Any file paths in `base_value`
should be "os.path.norm"ed before this function is called.
contains filepaths in any nested structure.
@@ -1383,7 +1488,7 @@ def _scrub_task_args(base_value, target_path_list):
This function can be called before the Task dependencies are satisfied
since it doesn't inspect any file stats on disk.
Parameters:
Args:
base_value: any python value
target_path_list (list): a list of strings that if found in
`base_value` should be replaced with 'in_target_path' so
@@ -1433,7 +1538,7 @@ def _hash_file(file_path, hash_algorithm, buf_size=2**20):
def _hash_file(file_path, hash_algorithm, buf_size=2**20):
"""Return a hex digest of `file_path`.
Parameters:
Args:
file_path (string): path to file to hash.
hash_algorithm (string): a hash function id that exists in
hashlib.algorithms_available or 'sizetimestamp'. If function id
@@ -1478,13 +1583,15 @@ def _normalize_path(path):
return os.path.normcase(abs_path)


@retrying.retry(wait_exponential_multiplier=1000, wait_exponential_max=5000)
@retrying.retry(
wait_exponential_multiplier=100, wait_exponential_max=3200,
stop_max_attempt_number=5)
def _execute_sqlite(
sqlite_command, database_path, argument_list=None,
mode='read_only', execute='execute', fetch=None):
"""Execute SQLite command and attempt retries on a failure.
Parameters:
Args:
sqlite_command (str): a well formatted SQLite command.
database_path (str): path to the SQLite database to operate on.
argument_list (list): `execute == 'execute` then this list is passed to
@@ -1515,7 +1622,10 @@ def _execute_sqlite(
raise ValueError('Unknown mode: %s' % mode)

if execute == 'execute':
cursor = connection.execute(sqlite_command, argument_list)
if argument_list is None:
cursor = connection.execute(sqlite_command)
else:
cursor = connection.execute(sqlite_command, argument_list)
elif execute == 'script':
cursor = connection.executescript(sqlite_command)
else:
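Two smaller changes in Task.py above are easy to miss: the retrying decorator on _execute_sqlite now uses a shorter exponential backoff and gives up after five attempts (so a permanently broken database surfaces an error instead of retrying indefinitely, one source of the rare infinite loops noted in HISTORY.rst), and argument_list may now be omitted when the SQL command takes no parameters. A rough sketch of how that retry schedule behaves, reusing the same decorator parameters; the failing function here is hypothetical:

import retrying

ATTEMPTS = []

# Same parameters as the new decorator on _execute_sqlite: the wait grows
# exponentially (roughly 100 ms * 2**attempt), is capped at 3.2 seconds,
# and the call stops after 5 attempts instead of retrying forever.
@retrying.retry(
    wait_exponential_multiplier=100, wait_exponential_max=3200,
    stop_max_attempt_number=5)
def flaky_database_call():
    """Hypothetical stand-in for a sqlite call that keeps failing."""
    ATTEMPTS.append(1)
    raise RuntimeError('database is locked')

try:
    flaky_database_call()
except RuntimeError:
    # After the fifth failure the original exception propagates.
    print(f'gave up after {len(ATTEMPTS)} attempts')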
5 changes: 2 additions & 3 deletions taskgraph/__init__.py
@@ -1,9 +1,8 @@
"""TaskGraph init module."""
from pkg_resources import get_distribution

from .Task import TaskGraph
from .Task import Task
from .Task import _TASKGRAPH_DATABASE_FILENAME
from .Task import __version__

__all__ = ['TaskGraph', 'Task', '_TASKGRAPH_DATABASE_FILENAME']
__version__ = get_distribution(__name__).version
__all__ = ['__version__', 'TaskGraph', 'Task', '_TASKGRAPH_DATABASE_FILENAME']
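With this change, __version__ is defined once in taskgraph.Task (so the module-level code can stamp new databases with the package version) and re-exported here, leaving the public attribute unchanged. A small usage check, assuming the taskgraph package is installed:

import taskgraph

# Still resolves to the installed distribution's version, now sourced
# from taskgraph.Task rather than computed in __init__.py.
print(taskgraph.__version__)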
58 changes: 58 additions & 0 deletions tests/test_task.py
@@ -1377,6 +1377,64 @@ def test_return_value(self):
task_graph.join()
task_graph = None

def test_malformed_taskgraph_database(self):
"""TaskGraph: Test an empty task."""
db_schema_test_list = [
'''
CREATE TABLE taskgraph_data (
bad_name_1 TEXT NOT NULL,
bad_name_2 BLOB NOT NULL,
bad_name_3 BLOB NOT NULL);
''',
'''
CREATE TABLE taskgraph_data (
task_reexecution_hash TEXT NOT NULL,
target_path_stats BLOB NOT NULL);
''',
'''
CREATE TABLE bad_table_name (
task_reexecution_hash TEXT NOT NULL,
target_path_stats BLOB NOT NULL,
result BLOB NOT NULL,
PRIMARY KEY (task_reexecution_hash));
'''
]

for db_schema in db_schema_test_list:
database_path = os.path.join(
self.workspace_dir, taskgraph._TASKGRAPH_DATABASE_FILENAME)
if os.path.exists(database_path):
os.remove(database_path)
connection = sqlite3.connect(database_path)
cursor = connection.cursor()
cursor.executescript(db_schema)
cursor.close()
connection.commit()
connection.close()

task_graph = taskgraph.TaskGraph(self.workspace_dir, 0)
_ = task_graph.add_task()
task_graph.close()
task_graph.join()
del task_graph

expected_column_name_list = [
'task_reexecution_hash', 'target_path_stats', 'result']
connection = sqlite3.connect(database_path)
cursor = connection.cursor()
cursor.execute('PRAGMA table_info(taskgraph_data)')
result = list(cursor.fetchall())
cursor.close()
connection.commit()
connection.close()
for header_line in result:
column_name = header_line[1]
if column_name not in expected_column_name_list:
raise ValueError(
f'unexpected column name {column_name} in '
'taskgraph_data ')
self.assertEqual(len(result), len(expected_column_name_list))


def Fail(n_tries, result_path):
"""Create a function that fails after `n_tries`."""
