Skip to content

Commit

Permalink
fix: driver object pickle error (#1944)
Browse files Browse the repository at this point in the history
* fix: driver object pickle error

Signed-off-by: Allison Suarez Miranda <asuarezmiranda@lyft.com>

* always use conf with fallback in neo4j extractor

Signed-off-by: Allison Suarez Miranda <asuarezmiranda@lyft.com>
  • Loading branch information
allisonsuarez authored Jul 28, 2022
1 parent e97b74d commit 879a002
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 93 deletions.
63 changes: 30 additions & 33 deletions databuilder/databuilder/extractor/neo4j_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ class Neo4jExtractor(Extractor):
"""NEO4J_ENCRYPTED is a boolean indicating whether to use SSL/TLS when connecting."""
NEO4J_VALIDATE_SSL = 'neo4j_validate_ssl'
"""NEO4J_VALIDATE_SSL is a boolean indicating whether to validate the server's SSL/TLS cert against system CAs."""
NEO4J_DRIVER = 'neo4j_driver'

DEFAULT_CONFIG = ConfigFactory.from_dict({
NEO4J_MAX_CONN_LIFE_TIME_SEC: 50,
Expand All @@ -48,41 +47,39 @@ def init(self, conf: ConfigTree) -> None:
:param conf:
"""
self.conf = conf.with_fallback(Neo4jExtractor.DEFAULT_CONFIG)
self.graph_url = conf.get_string(Neo4jExtractor.GRAPH_URL_CONFIG_KEY)
self.cypher_query = conf.get_string(Neo4jExtractor.CYPHER_QUERY_CONFIG_KEY)
self.graph_url = self.conf.get_string(Neo4jExtractor.GRAPH_URL_CONFIG_KEY)
self.cypher_query = self.conf.get_string(Neo4jExtractor.CYPHER_QUERY_CONFIG_KEY)
self.db_name = self.conf.get_string(Neo4jExtractor.NEO4J_DATABASE_NAME)
driver = conf.get(Neo4jExtractor.NEO4J_DRIVER, None)
if driver:
self.driver = driver
else:
uri = conf.get_string(Neo4jExtractor.GRAPH_URL_CONFIG_KEY)
driver_args = {
'uri': uri,
'max_connection_lifetime': self.conf.get_int(Neo4jExtractor.NEO4J_MAX_CONN_LIFE_TIME_SEC),
'auth': (conf.get_string(Neo4jExtractor.NEO4J_AUTH_USER),
conf.get_string(Neo4jExtractor.NEO4J_AUTH_PW)),
}

# if the URI scheme is not secure, set `trust` and `encrypted` to default values
# https://neo4j.com/docs/api/python-driver/current/api.html#uri
_, security_type, _ = parse_neo4j_uri(uri=uri)
if security_type not in [SECURITY_TYPE_SELF_SIGNED_CERTIFICATE, SECURITY_TYPE_SECURE]:
default_security_conf = {'trust': neo4j.TRUST_ALL_CERTIFICATES, 'encrypted': True}
driver_args.update(default_security_conf)

# if NEO4J_VALIDATE_SSL or NEO4J_ENCRYPTED are set in config pass them to the driver
validate_ssl_conf = conf.get(Neo4jExtractor.NEO4J_VALIDATE_SSL, None)
encrypted_conf = conf.get(Neo4jExtractor.NEO4J_ENCRYPTED, None)
if validate_ssl_conf is not None:
driver_args['trust'] = neo4j.TRUST_SYSTEM_CA_SIGNED_CERTIFICATES if validate_ssl_conf \
else neo4j.TRUST_ALL_CERTIFICATES
if encrypted_conf is not None:
driver_args['encrypted'] = encrypted_conf

self.driver = GraphDatabase.driver(**driver_args)

uri = self.conf.get_string(Neo4jExtractor.GRAPH_URL_CONFIG_KEY)
driver_args = {
'uri': uri,
'max_connection_lifetime': self.conf.get_int(Neo4jExtractor.NEO4J_MAX_CONN_LIFE_TIME_SEC),
'auth': (self.conf.get_string(Neo4jExtractor.NEO4J_AUTH_USER),
self.conf.get_string(Neo4jExtractor.NEO4J_AUTH_PW)),
}

# if the URI scheme is not secure, set `trust` and `encrypted` to default values
# https://neo4j.com/docs/api/python-driver/current/api.html#uri
_, security_type, _ = parse_neo4j_uri(uri=uri)
if security_type not in [SECURITY_TYPE_SELF_SIGNED_CERTIFICATE, SECURITY_TYPE_SECURE]:
default_security_conf = {'trust': neo4j.TRUST_ALL_CERTIFICATES, 'encrypted': True}
driver_args.update(default_security_conf)

# if NEO4J_VALIDATE_SSL or NEO4J_ENCRYPTED are set in config pass them to the driver
validate_ssl_conf = self.conf.get(Neo4jExtractor.NEO4J_VALIDATE_SSL, None)
encrypted_conf = self.conf.get(Neo4jExtractor.NEO4J_ENCRYPTED, None)
if validate_ssl_conf is not None:
driver_args['trust'] = neo4j.TRUST_SYSTEM_CA_SIGNED_CERTIFICATES if validate_ssl_conf \
else neo4j.TRUST_ALL_CERTIFICATES
if encrypted_conf is not None:
driver_args['encrypted'] = encrypted_conf

self.driver = GraphDatabase.driver(**driver_args)

self._extract_iter: Union[None, Iterator] = None

model_class = conf.get(Neo4jExtractor.MODEL_CLASS_CONFIG_KEY, None)
model_class = self.conf.get(Neo4jExtractor.MODEL_CLASS_CONFIG_KEY, None)
if model_class:
module_name, class_name = model_class.rsplit(".", 1)
mod = importlib.import_module(module_name)
Expand Down
54 changes: 24 additions & 30 deletions databuilder/databuilder/publisher/neo4j_csv_publisher.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,6 @@
# in Neo4j (v4.0+), we can create and use more than one active database at the same time
NEO4J_DATABASE_NAME = 'neo4j_database'

NEO4J_DRIVER = 'neo4j_driver'

# NEO4J_ENCRYPTED is a boolean indicating whether to use SSL/TLS when connecting
NEO4J_ENCRYPTED = 'neo4j_encrypted'
# NEO4J_VALIDATE_SSL is a boolean indicating whether to validate the server's SSL/TLS
Expand Down Expand Up @@ -154,34 +152,30 @@ def init(self, conf: ConfigTree) -> None:
self._relation_files = self._list_files(conf, RELATION_FILES_DIR)
self._relation_files_iter = iter(self._relation_files)

driver = conf.get(NEO4J_DRIVER, None)
if driver:
self._driver = driver
else:
uri = conf.get_string(NEO4J_END_POINT_KEY)
driver_args = {
'uri': uri,
'max_connection_lifetime': conf.get_int(NEO4J_MAX_CONN_LIFE_TIME_SEC),
'auth': (conf.get_string(NEO4J_USER), conf.get_string(NEO4J_PASSWORD)),
}

# if the URI scheme is not secure, set `trust` and `encrypted` to default values
# https://neo4j.com/docs/api/python-driver/current/api.html#uri
_, security_type, _ = parse_neo4j_uri(uri=uri)
if security_type not in [SECURITY_TYPE_SELF_SIGNED_CERTIFICATE, SECURITY_TYPE_SECURE]:
default_security_conf = {'trust': neo4j.TRUST_ALL_CERTIFICATES, 'encrypted': True}
driver_args.update(default_security_conf)

# if NEO4J_VALIDATE_SSL or NEO4J_ENCRYPTED are set in config pass them to the driver
validate_ssl_conf = conf.get(NEO4J_VALIDATE_SSL, None)
encrypted_conf = conf.get(NEO4J_ENCRYPTED, None)
if validate_ssl_conf is not None:
driver_args['trust'] = neo4j.TRUST_SYSTEM_CA_SIGNED_CERTIFICATES if validate_ssl_conf \
else neo4j.TRUST_ALL_CERTIFICATES
if encrypted_conf is not None:
driver_args['encrypted'] = encrypted_conf

self._driver = GraphDatabase.driver(**driver_args)
uri = conf.get_string(NEO4J_END_POINT_KEY)
driver_args = {
'uri': uri,
'max_connection_lifetime': conf.get_int(NEO4J_MAX_CONN_LIFE_TIME_SEC),
'auth': (conf.get_string(NEO4J_USER), conf.get_string(NEO4J_PASSWORD)),
}

# if the URI scheme is not secure, set `trust` and `encrypted` to default values
# https://neo4j.com/docs/api/python-driver/current/api.html#uri
_, security_type, _ = parse_neo4j_uri(uri=uri)
if security_type not in [SECURITY_TYPE_SELF_SIGNED_CERTIFICATE, SECURITY_TYPE_SECURE]:
default_security_conf = {'trust': neo4j.TRUST_ALL_CERTIFICATES, 'encrypted': True}
driver_args.update(default_security_conf)

# if NEO4J_VALIDATE_SSL or NEO4J_ENCRYPTED are set in config pass them to the driver
validate_ssl_conf = conf.get(NEO4J_VALIDATE_SSL, None)
encrypted_conf = conf.get(NEO4J_ENCRYPTED, None)
if validate_ssl_conf is not None:
driver_args['trust'] = neo4j.TRUST_SYSTEM_CA_SIGNED_CERTIFICATES if validate_ssl_conf \
else neo4j.TRUST_ALL_CERTIFICATES
if encrypted_conf is not None:
driver_args['encrypted'] = encrypted_conf

self._driver = GraphDatabase.driver(**driver_args)

self._db_name = conf.get_string(NEO4J_DATABASE_NAME)
self._session = self._driver.session(database=self._db_name)
Expand Down
53 changes: 24 additions & 29 deletions databuilder/databuilder/task/neo4j_staleness_removal_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
NEO4J_PASSWORD = 'neo4j_password'
# in Neo4j (v4.0+), we can create and use more than one active database at the same time
NEO4J_DATABASE_NAME = 'neo4j_database'
NEO4J_DRIVER = 'neo4j_driver'
NEO4J_ENCRYPTED = 'neo4j_encrypted'
"""NEO4J_ENCRYPTED is a boolean indicating whether to use SSL/TLS when connecting."""
NEO4J_VALIDATE_SSL = 'neo4j_validate_ssl'
Expand Down Expand Up @@ -131,34 +130,30 @@ def init(self, conf: ConfigTree) -> None:
else:
self.marker = conf.get_string(JOB_PUBLISH_TAG)

driver = conf.get(NEO4J_DRIVER, None)
if driver:
self._driver = driver
else:
uri = conf.get_string(NEO4J_END_POINT_KEY)
driver_args = {
'uri': uri,
'max_connection_lifetime': conf.get_int(NEO4J_MAX_CONN_LIFE_TIME_SEC),
'auth': (conf.get_string(NEO4J_USER), conf.get_string(NEO4J_PASSWORD)),
}

# if the URI scheme is not secure, set `trust` and `encrypted` to default values
# https://neo4j.com/docs/api/python-driver/current/api.html#uri
_, security_type, _ = parse_neo4j_uri(uri=uri)
if security_type not in [SECURITY_TYPE_SELF_SIGNED_CERTIFICATE, SECURITY_TYPE_SECURE]:
default_security_conf = {'trust': neo4j.TRUST_ALL_CERTIFICATES, 'encrypted': True}
driver_args.update(default_security_conf)

# if NEO4J_VALIDATE_SSL or NEO4J_ENCRYPTED are set in config pass them to the driver
validate_ssl_conf = conf.get(NEO4J_VALIDATE_SSL, None)
encrypted_conf = conf.get(NEO4J_ENCRYPTED, None)
if validate_ssl_conf is not None:
driver_args['trust'] = neo4j.TRUST_SYSTEM_CA_SIGNED_CERTIFICATES if validate_ssl_conf \
else neo4j.TRUST_ALL_CERTIFICATES
if encrypted_conf is not None:
driver_args['encrypted'] = encrypted_conf

self._driver = GraphDatabase.driver(**driver_args)
uri = conf.get_string(NEO4J_END_POINT_KEY)
driver_args = {
'uri': uri,
'max_connection_lifetime': conf.get_int(NEO4J_MAX_CONN_LIFE_TIME_SEC),
'auth': (conf.get_string(NEO4J_USER), conf.get_string(NEO4J_PASSWORD)),
}

# if the URI scheme is not secure, set `trust` and `encrypted` to default values
# https://neo4j.com/docs/api/python-driver/current/api.html#uri
_, security_type, _ = parse_neo4j_uri(uri=uri)
if security_type not in [SECURITY_TYPE_SELF_SIGNED_CERTIFICATE, SECURITY_TYPE_SECURE]:
default_security_conf = {'trust': neo4j.TRUST_ALL_CERTIFICATES, 'encrypted': True}
driver_args.update(default_security_conf)

# if NEO4J_VALIDATE_SSL or NEO4J_ENCRYPTED are set in config pass them to the driver
validate_ssl_conf = conf.get(NEO4J_VALIDATE_SSL, None)
encrypted_conf = conf.get(NEO4J_ENCRYPTED, None)
if validate_ssl_conf is not None:
driver_args['trust'] = neo4j.TRUST_SYSTEM_CA_SIGNED_CERTIFICATES if validate_ssl_conf \
else neo4j.TRUST_ALL_CERTIFICATES
if encrypted_conf is not None:
driver_args['encrypted'] = encrypted_conf

self._driver = GraphDatabase.driver(**driver_args)

self.db_name = conf.get(NEO4J_DATABASE_NAME)

Expand Down
2 changes: 1 addition & 1 deletion databuilder/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from setuptools import find_packages, setup

__version__ = '7.1.0'
__version__ = '7.1.1'

requirements_path = os.path.join(os.path.dirname(os.path.realpath(__file__)),
'requirements.txt')
Expand Down

0 comments on commit 879a002

Please sign in to comment.