Skip to content

Commit

Permalink
Merge pull request #355 from SubstraFoundation/improve-process-db-connection
Browse files Browse the repository at this point in the history

Close db connection before launching Process under celery task.
  • Loading branch information
Kelvin-M authored Dec 2, 2020
2 parents 388e255 + cf42dae commit e6ee19e
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 32 deletions.
42 changes: 16 additions & 26 deletions backend/substrapp/tasks/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@
query_tuples, get_object_from_ledger)
from substrapp.ledger.exceptions import LedgerError, LedgerStatusError
from substrapp.tasks.utils import (compute_job, get_asset_content, get_and_put_asset_content,
list_files, do_not_raise, get_or_create_local_volume, remove_image,
authenticate_worker)
list_files, do_not_raise, get_or_create_local_volume, remove_image)

from substrapp.tasks.exception_handler import compute_error_code

Expand Down Expand Up @@ -163,7 +162,7 @@ def get_tuple_status(channel_name, tuple_type, key):
return metadata['status']


def get_and_put_model_content(channel_name, tuple_type, hash_key, tuple_, out_model, model_dst_path, auth=None):
def get_and_put_model_content(channel_name, tuple_type, hash_key, tuple_, out_model, model_dst_path):
"""Get out model content."""
owner = tuple_get_owner(tuple_type, tuple_)
return get_and_put_asset_content(
Expand All @@ -172,9 +171,7 @@ def get_and_put_model_content(channel_name, tuple_type, hash_key, tuple_, out_mo
owner,
out_model['checksum'],
content_dst_path=model_dst_path,
hash_key=hash_key,
auth=auth
)
hash_key=hash_key)


def get_and_put_local_model_content(hash_key, out_model, model_dst_path):
Expand All @@ -195,7 +192,7 @@ def get_and_put_local_model_content(hash_key, out_model, model_dst_path):


@timeit
def fetch_model(channel_name, parent_tuple_type, authorized_types, input_model, directory, auth=None):
def fetch_model(channel_name, parent_tuple_type, authorized_types, input_model, directory):

tuple_type, metadata = find_training_step_tuple_from_key(channel_name, input_model['traintuple_key'])

Expand All @@ -207,24 +204,18 @@ def fetch_model(channel_name, parent_tuple_type, authorized_types, input_model,

if tuple_type == TRAINTUPLE_TYPE:
get_and_put_model_content(
channel_name, tuple_type, input_model['traintuple_key'], metadata, metadata['out_model'], model_dst_path,
auth
)
channel_name, tuple_type, input_model['traintuple_key'], metadata, metadata['out_model'], model_dst_path)
elif tuple_type == AGGREGATETUPLE_TYPE:
get_and_put_model_content(
channel_name, tuple_type, input_model['traintuple_key'], metadata, metadata['out_model'], model_dst_path,
auth
)
channel_name, tuple_type, input_model['traintuple_key'], metadata, metadata['out_model'], model_dst_path)
elif tuple_type == COMPOSITE_TRAINTUPLE_TYPE:
get_and_put_model_content(
channel_name,
tuple_type,
input_model['traintuple_key'],
metadata,
metadata['out_trunk_model']['out_model'],
model_dst_path,
auth
)
model_dst_path)
else:
raise TasksError(f'Traintuple: invalid input model: type={tuple_type}')

Expand All @@ -234,12 +225,13 @@ def fetch_models(channel_name, tuple_type, authorized_types, input_models, direc
models = []
exceptions = []

# Close the Django DB connections to force each Process to create its own,
# as Django ORM connections are not fork-safe: https://code.djangoproject.com/ticket/20562
from django import db
db.connections.close_all()

for input_model in input_models:
# Authentication should be retrieved before passing the fetch_model function to a subprocess,
# as fetching it from the Django models gives random results.
# The cause is unknown.
auth = get_model_auth(channel_name, input_model)
args = (channel_name, tuple_type, authorized_types, input_model, directory, auth)
args = (channel_name, tuple_type, authorized_types, input_model, directory)
proc = Process(target=fetch_model, args=args)
models.append((proc, args))
proc.start()
Expand All @@ -249,15 +241,13 @@ def fetch_models(channel_name, tuple_type, authorized_types, input_models, direc
if proc.exitcode != 0:
exceptions.append(Exception(f'fetch model failed for args {args}'))

# Close old Django connections to avoid a potential leak
db.close_old_connections()

if exceptions:
raise Exception(exceptions)


def get_model_auth(channel_name, input_model):
tuple_type, metadata = find_training_step_tuple_from_key(channel_name, input_model['traintuple_key'])
return authenticate_worker(tuple_get_owner(tuple_type, metadata))


def prepare_traintuple_input_models(channel_name, directory, tuple_):
"""Get traintuple input models content."""
input_models = tuple_.get('in_models')
Expand Down
13 changes: 7 additions & 6 deletions backend/substrapp/tasks/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@ def authenticate_worker(node_id):
except OutgoingNode.DoesNotExist:
raise NodeError(f'Unauthorized to call node_id: {node_id}')

if node_id != outgoing.node_id:
# Ensure the response is valid. This is a safety net for the case when the DB connection is shared
# across processes running in parallel.
raise NodeError(f'Wrong response: Request {node_id} - Get {outgoing.node_id}')

auth = HTTPBasicAuth(owner, outgoing.secret)

return auth
Expand All @@ -39,12 +44,8 @@ def get_asset_content(channel_name, url, node_id, content_checksum, salt=None):
return get_remote_file_content(channel_name, url, authenticate_worker(node_id), content_checksum, salt=salt)


def get_and_put_asset_content(channel_name, url, node_id, content_checksum, content_dst_path, hash_key, auth=None):

if auth is None:
auth = authenticate_worker(node_id)

return get_and_put_remote_file_content(channel_name, url, auth, content_checksum,
def get_and_put_asset_content(channel_name, url, node_id, content_checksum, content_dst_path, hash_key):
return get_and_put_remote_file_content(channel_name, url, authenticate_worker(node_id), content_checksum,
content_dst_path=content_dst_path, hash_key=hash_key)


Expand Down

0 comments on commit e6ee19e

Please sign in to comment.