From 9997b1c563517aeedc1e05be9574d865e9842ebd Mon Sep 17 00:00:00 2001 From: Michael Chin Date: Thu, 19 Sep 2024 21:56:38 -0700 Subject: [PATCH] Enforce Gremlin protocol and serializer based on database type (#697) * Set allowed and default Gremlin protocol and serializer dynamically * Add unit test suite * update changelog --- ChangeLog.md | 1 + .../configuration/generate_config.py | 131 +++++--- .../configuration/get_config.py | 18 +- src/graph_notebook/magics/graph_magic.py | 15 +- src/graph_notebook/neptune/client.py | 41 ++- test/unit/configuration/test_configuration.py | 284 +++++++++++++++++- .../test_configuration_from_main.py | 48 ++- 7 files changed, 457 insertions(+), 81 deletions(-) diff --git a/ChangeLog.md b/ChangeLog.md index edcf5a4e..f4c7493b 100644 --- a/ChangeLog.md +++ b/ChangeLog.md @@ -5,6 +5,7 @@ Starting with v1.31.6, this file will contain a record of major features and upd ## Upcoming - Updated Gremlin config `message_serializer` to accept all TinkerPop serializers ([Link to PR](https://github.com/aws/graph-notebook/pull/685)) +- Implemented service-based dynamic allowlists and defaults for Gremlin serializer and protocol combinations ([Link to PR](https://github.com/aws/graph-notebook/pull/697)) - Added `%get_import_task` line magic ([Link to PR](https://github.com/aws/graph-notebook/pull/668)) - Added `--export-to` JSON file option to `%%graph_notebook_config` ([Link to PR](https://github.com/aws/graph-notebook/pull/684)) - Deprecated Python 3.8 support ([Link to PR](https://github.com/aws/graph-notebook/pull/683)) diff --git a/src/graph_notebook/configuration/generate_config.py b/src/graph_notebook/configuration/generate_config.py index 83f58038..44270328 100644 --- a/src/graph_notebook/configuration/generate_config.py +++ b/src/graph_notebook/configuration/generate_config.py @@ -14,9 +14,12 @@ HTTP_PROTOCOL_FORMATS, WS_PROTOCOL_FORMATS, DEFAULT_NEO4J_USERNAME, DEFAULT_NEO4J_PASSWORD, DEFAULT_NEO4J_DATABASE, NEPTUNE_CONFIG_HOST_IDENTIFIERS, is_allowed_neptune_host, false_str_variants, - GRAPHBINARYV1, GREMLIN_SERIALIZERS_HTTP, + GRAPHBINARYV1, GREMLIN_SERIALIZERS_HTTP, GREMLIN_SERIALIZERS_WS, + GREMLIN_SERIALIZERS_ALL, NEPTUNE_GREMLIN_SERIALIZERS_HTTP, + DEFAULT_GREMLIN_WS_SERIALIZER, DEFAULT_GREMLIN_HTTP_SERIALIZER, NEPTUNE_DB_SERVICE_NAME, NEPTUNE_ANALYTICS_SERVICE_NAME, - normalize_service_name) + normalize_service_name, normalize_protocol_name, + normalize_serializer_class_name) DEFAULT_CONFIG_LOCATION = os.path.expanduser('~/graph_notebook_config.json') @@ -57,7 +60,8 @@ class GremlinSection(object): """ def __init__(self, traversal_source: str = '', username: str = '', password: str = '', - message_serializer: str = '', connection_protocol: str = '', include_protocol: bool = False): + message_serializer: str = '', connection_protocol: str = '', + include_protocol: bool = False, neptune_service: str = ''): """ :param traversal_source: used to specify the traversal source for a Gremlin traversal, in the case that we are connected to an endpoint that can access multiple graphs. @@ -71,57 +75,78 @@ def __init__(self, traversal_source: str = '', username: str = '', password: str if traversal_source == '': traversal_source = DEFAULT_GREMLIN_TRAVERSAL_SOURCE - serializer_lower = message_serializer.lower() - # TODO: Update with untyped serializers once supported in GremlinPython - # Accept TinkerPop serializer class name - # https://github.com/apache/tinkerpop/blob/fd040c94a66516e473811fe29eaeaf4081cf104c/docs/src/reference/gremlin-applications.asciidoc#graphson - # https://github.com/apache/tinkerpop/blob/fd040c94a66516e473811fe29eaeaf4081cf104c/docs/src/reference/gremlin-applications.asciidoc#graphbinary - if serializer_lower == '': - message_serializer = DEFAULT_GREMLIN_SERIALIZER - elif 'graphson' in serializer_lower: - message_serializer = 'GraphSON' - if 'untyped' in serializer_lower: - message_serializer += 'Untyped' - if 'v1' in serializer_lower: - if 'untyped' in serializer_lower: - message_serializer += 'MessageSerializerV1' - else: - message_serializer += 'MessageSerializerGremlinV1' - elif 'v2' in serializer_lower: - message_serializer += 'MessageSerializerV2' + invalid_serializer_input = False + if message_serializer != '': + message_serializer, invalid_serializer_input = normalize_serializer_class_name(message_serializer) + + if include_protocol: + # Neptune endpoint + invalid_protocol_input = False + if connection_protocol != '': + connection_protocol, invalid_protocol_input = normalize_protocol_name(connection_protocol) + + if neptune_service == NEPTUNE_ANALYTICS_SERVICE_NAME: + if connection_protocol != DEFAULT_HTTP_PROTOCOL: + if invalid_protocol_input: + print(f"Invalid connection protocol specified, you must use {DEFAULT_HTTP_PROTOCOL}. ") + elif connection_protocol == DEFAULT_WS_PROTOCOL: + print(f"Enforcing HTTP protocol.") + connection_protocol = DEFAULT_HTTP_PROTOCOL + # temporary restriction until GraphSON-typed and GraphBinary results are supported + if message_serializer not in NEPTUNE_GREMLIN_SERIALIZERS_HTTP: + if message_serializer not in GREMLIN_SERIALIZERS_ALL: + if invalid_serializer_input: + print(f"Invalid serializer specified, defaulting to {DEFAULT_GREMLIN_HTTP_SERIALIZER}. " + f"Valid serializers: {NEPTUNE_GREMLIN_SERIALIZERS_HTTP}") + else: + print(f"{message_serializer} is not currently supported for HTTP connections, " + f"defaulting to {DEFAULT_GREMLIN_HTTP_SERIALIZER}. " + f"Please use one of: {NEPTUNE_GREMLIN_SERIALIZERS_HTTP}") + message_serializer = DEFAULT_GREMLIN_HTTP_SERIALIZER else: - message_serializer += 'MessageSerializerV3' - elif 'graphbinary' in serializer_lower: - message_serializer = GRAPHBINARYV1 + if connection_protocol not in [DEFAULT_WS_PROTOCOL, DEFAULT_HTTP_PROTOCOL]: + if invalid_protocol_input: + print(f"Invalid connection protocol specified, defaulting to {DEFAULT_WS_PROTOCOL}. " + f"Valid protocols: [websockets, http].") + connection_protocol = DEFAULT_WS_PROTOCOL + + if connection_protocol == DEFAULT_HTTP_PROTOCOL: + # temporary restriction until GraphSON-typed and GraphBinary results are supported + if message_serializer not in NEPTUNE_GREMLIN_SERIALIZERS_HTTP: + if message_serializer not in GREMLIN_SERIALIZERS_ALL: + if invalid_serializer_input: + print(f"Invalid serializer specified, defaulting to {DEFAULT_GREMLIN_HTTP_SERIALIZER}. " + f"Valid serializers: {NEPTUNE_GREMLIN_SERIALIZERS_HTTP}") + else: + print(f"{message_serializer} is not currently supported for HTTP connections, " + f"defaulting to {DEFAULT_GREMLIN_HTTP_SERIALIZER}. " + f"Please use one of: {NEPTUNE_GREMLIN_SERIALIZERS_HTTP}") + message_serializer = DEFAULT_GREMLIN_HTTP_SERIALIZER + else: + if message_serializer not in GREMLIN_SERIALIZERS_WS: + if invalid_serializer_input: + print(f"Invalid serializer specified, defaulting to {DEFAULT_GREMLIN_WS_SERIALIZER}. " + f"Valid serializers: {GREMLIN_SERIALIZERS_WS}") + elif message_serializer != '': + print(f"{message_serializer} is not currently supported by Gremlin Python driver, " + f"defaulting to {DEFAULT_GREMLIN_WS_SERIALIZER}. " + f"Valid serializers: {GREMLIN_SERIALIZERS_WS}") + message_serializer = DEFAULT_GREMLIN_WS_SERIALIZER + + self.connection_protocol = connection_protocol else: - print(f'Invalid Gremlin serializer specified, defaulting to graphsonv3. ' - f'Valid serializers: {GREMLIN_SERIALIZERS_HTTP}.') - message_serializer = DEFAULT_GREMLIN_SERIALIZER + # Non-Neptune database - check and set valid WebSockets serializer if invalid/empty + if message_serializer not in GREMLIN_SERIALIZERS_WS: + message_serializer = DEFAULT_GREMLIN_WS_SERIALIZER + if invalid_serializer_input: + print(f'Invalid Gremlin serializer specified, defaulting to {DEFAULT_GREMLIN_WS_SERIALIZER}. ' + f'Valid serializers: {GREMLIN_SERIALIZERS_WS}.') self.traversal_source = traversal_source self.username = username self.password = password self.message_serializer = message_serializer - if include_protocol: - protocol_lower = connection_protocol.lower() - if message_serializer in GREMLIN_SERIALIZERS_HTTP: - connection_protocol = DEFAULT_HTTP_PROTOCOL - if protocol_lower != '' and protocol_lower not in HTTP_PROTOCOL_FORMATS: - print(f"Enforcing HTTP protocol usage for serializer: {message_serializer}.") - else: - if protocol_lower == '': - connection_protocol = DEFAULT_GREMLIN_PROTOCOL - elif protocol_lower in HTTP_PROTOCOL_FORMATS: - connection_protocol = DEFAULT_HTTP_PROTOCOL - elif protocol_lower in WS_PROTOCOL_FORMATS: - connection_protocol = DEFAULT_WS_PROTOCOL - else: - print(f"Invalid connection protocol specified, defaulting to {DEFAULT_GREMLIN_PROTOCOL}. " - f"Valid protocols: [websockets, http].") - connection_protocol = DEFAULT_GREMLIN_PROTOCOL - self.connection_protocol = connection_protocol - def to_dict(self): return self.__dict__ @@ -178,8 +203,8 @@ def __init__(self, host: str, port: int, self.auth_mode = auth_mode self.load_from_s3_arn = load_from_s3_arn self.aws_region = aws_region - default_protocol = DEFAULT_HTTP_PROTOCOL if self._proxy_host != '' else DEFAULT_GREMLIN_PROTOCOL if gremlin_section is not None: + default_protocol = DEFAULT_HTTP_PROTOCOL if self._proxy_host != '' else '' if hasattr(gremlin_section, "connection_protocol"): if self._proxy_host != '' and gremlin_section.connection_protocol != DEFAULT_HTTP_PROTOCOL: print("Enforcing HTTP connection protocol for proxy connections.") @@ -189,9 +214,12 @@ def __init__(self, host: str, port: int, else: final_protocol = default_protocol self.gremlin = GremlinSection(message_serializer=gremlin_section.message_serializer, - connection_protocol=final_protocol, include_protocol=True) + connection_protocol=final_protocol, + include_protocol=True, + neptune_service=self.neptune_service) else: - self.gremlin = GremlinSection(connection_protocol=default_protocol, include_protocol=True) + self.gremlin = GremlinSection(include_protocol=True, + neptune_service=self.neptune_service) self.neo4j = Neo4JSection() else: self.is_neptune_config = False @@ -331,11 +359,14 @@ def generate_default_config(): auth_mode_arg = args.auth_mode if args.auth_mode != '' else AuthModeEnum.DEFAULT.value protocol_arg = args.gremlin_connection_protocol include_protocol = False + gremlin_service = '' if is_allowed_neptune_host(args.host, args.neptune_hosts): include_protocol = True + gremlin_service = args.neptune_service if not protocol_arg: protocol_arg = DEFAULT_HTTP_PROTOCOL \ if args.neptune_service == NEPTUNE_ANALYTICS_SERVICE_NAME else DEFAULT_WS_PROTOCOL + config = generate_config(args.host, int(args.port), AuthModeEnum(auth_mode_arg), args.ssl, args.ssl_verify, @@ -344,7 +375,7 @@ def generate_default_config(): SparqlSection(args.sparql_path, ''), GremlinSection(args.gremlin_traversal_source, args.gremlin_username, args.gremlin_password, args.gremlin_serializer, - protocol_arg, include_protocol), + protocol_arg, include_protocol, gremlin_service), Neo4JSection(args.neo4j_username, args.neo4j_password, args.neo4j_auth, args.neo4j_database), args.neptune_hosts) diff --git a/src/graph_notebook/configuration/get_config.py b/src/graph_notebook/configuration/get_config.py index 7bcbc142..dd4e4908 100644 --- a/src/graph_notebook/configuration/get_config.py +++ b/src/graph_notebook/configuration/get_config.py @@ -9,7 +9,9 @@ SparqlSection, GremlinSection, Neo4JSection from graph_notebook.neptune.client import NEPTUNE_CONFIG_HOST_IDENTIFIERS, is_allowed_neptune_host, false_str_variants, \ DEFAULT_NEO4J_USERNAME, DEFAULT_NEO4J_PASSWORD, DEFAULT_NEO4J_DATABASE, \ - NEPTUNE_DB_SERVICE_NAME, DEFAULT_WS_PROTOCOL, DEFAULT_HTTP_PROTOCOL + NEPTUNE_DB_SERVICE_NAME, DEFAULT_WS_PROTOCOL, DEFAULT_HTTP_PROTOCOL, \ + DEFAULT_GREMLIN_HTTP_SERIALIZER, DEFAULT_GREMLIN_WS_SERIALIZER, \ + normalize_service_name neptune_params = ['neptune_service', 'auth_mode', 'load_from_s3_arn', 'aws_region'] neptune_gremlin_params = ['connection_protocol'] @@ -30,18 +32,24 @@ def get_config_from_dict(data: dict, neptune_hosts: list = NEPTUNE_CONFIG_HOST_I is_neptune_host = is_allowed_neptune_host(hostname=data["host"], host_allowlist=neptune_hosts) if is_neptune_host: - neptune_service = data['neptune_service'] if 'neptune_service' in data else NEPTUNE_DB_SERVICE_NAME + if 'neptune_service' in data: + neptune_service = normalize_service_name(data['neptune_service']) + else: + neptune_service = NEPTUNE_DB_SERVICE_NAME if 'gremlin' in data: - data['gremlin']['include_protocol'] = True if 'connection_protocol' not in data['gremlin']: data['gremlin']['connection_protocol'] = DEFAULT_WS_PROTOCOL \ if neptune_service == NEPTUNE_DB_SERVICE_NAME else DEFAULT_HTTP_PROTOCOL - gremlin_section = GremlinSection(**data['gremlin']) + gremlin_section = GremlinSection(**data['gremlin'], + include_protocol=True, + neptune_service=neptune_service) if gremlin_section.to_dict()['traversal_source'] != 'g': print('Ignoring custom traversal source, Amazon Neptune does not support this functionality.\n') else: protocol = DEFAULT_WS_PROTOCOL if neptune_service == NEPTUNE_DB_SERVICE_NAME else DEFAULT_HTTP_PROTOCOL - gremlin_section = GremlinSection(include_protocol=True, connection_protocol=protocol) + gremlin_section = GremlinSection(include_protocol=True, + connection_protocol=protocol, + neptune_service=neptune_service) if neo4j_section.to_dict()['username'] != DEFAULT_NEO4J_USERNAME \ or neo4j_section.to_dict()['password'] != DEFAULT_NEO4J_PASSWORD: print('Ignoring Neo4J custom authentication, Amazon Neptune does not support this functionality.\n') diff --git a/src/graph_notebook/magics/graph_magic.py b/src/graph_notebook/magics/graph_magic.py index 1ca6b78e..dcc5ce01 100644 --- a/src/graph_notebook/magics/graph_magic.py +++ b/src/graph_notebook/magics/graph_magic.py @@ -54,7 +54,7 @@ SPARQL_EXPLAIN_MODES, OPENCYPHER_EXPLAIN_MODES, GREMLIN_EXPLAIN_MODES, \ OPENCYPHER_PLAN_CACHE_MODES, OPENCYPHER_DEFAULT_TIMEOUT, OPENCYPHER_STATUS_STATE_MODES, \ normalize_service_name, NEPTUNE_DB_SERVICE_NAME, NEPTUNE_ANALYTICS_SERVICE_NAME, GRAPH_PG_INFO_METRICS, \ - DEFAULT_GREMLIN_PROTOCOL, GREMLIN_PROTOCOL_FORMATS, DEFAULT_HTTP_PROTOCOL, DEFAULT_WS_PROTOCOL, \ + GREMLIN_PROTOCOL_FORMATS, DEFAULT_HTTP_PROTOCOL, DEFAULT_WS_PROTOCOL, \ GREMLIN_SERIALIZERS_WS, GREMLIN_SERIALIZERS_CLASS_TO_MIME_MAP, normalize_protocol_name, generate_snapshot_name) from graph_notebook.network import SPARQLNetwork from graph_notebook.network.gremlin.GremlinNetwork import parse_pattern_list_str, GremlinNetwork @@ -1250,11 +1250,18 @@ def gremlin(self, line, cell, local_ns: dict = None): query_start = time.time() * 1000 # time.time() returns time in seconds w/high precision; x1000 to get in ms if self.client.is_neptune_domain(): if args.connection_protocol != '': - connection_protocol = normalize_protocol_name(args.connection_protocol) + connection_protocol, bad_protocol_input = normalize_protocol_name(args.connection_protocol) + if bad_protocol_input: + if self.client.is_analytics_domain(): + connection_protocol = DEFAULT_HTTP_PROTOCOL + else: + connection_protocol = DEFAULT_WS_PROTOCOL + print(f"Connection protocol input is invalid for Neptune, " + f"defaulting to {connection_protocol}.") if connection_protocol == DEFAULT_WS_PROTOCOL and \ self.graph_notebook_config.gremlin.message_serializer not in GREMLIN_SERIALIZERS_WS: - print("Unsupported serializer for GremlinPython client, " - "compatible serializers are: {GREMLIN_SERIALIZERS_WS}") + print(f"Serializer is unsupported for GremlinPython client, " + f"compatible serializers are: {GREMLIN_SERIALIZERS_WS}") print("Defaulting to HTTP protocol.") connection_protocol = DEFAULT_HTTP_PROTOCOL else: diff --git a/src/graph_notebook/neptune/client.py b/src/graph_notebook/neptune/client.py index e113694f..c82d5c60 100644 --- a/src/graph_notebook/neptune/client.py +++ b/src/graph_notebook/neptune/client.py @@ -25,6 +25,7 @@ from neo4j.exceptions import AuthError from base64 import b64encode import nest_asyncio +from networkx import is_valid_directed_joint_degree from graph_notebook.neptune.bolt_auth_token import NeptuneBoltAuthToken @@ -139,7 +140,10 @@ GREMLIN_SERIALIZERS_WS = [GRAPHSONV2, GRAPHSONV3, GRAPHBINARYV1] GREMLIN_SERIALIZERS_HTTP = [GRAPHSONV1, GRAPHSONV1_UNTYPED, GRAPHSONV2_UNTYPED, GRAPHSONV3_UNTYPED] GREMLIN_SERIALIZERS_ALL = GREMLIN_SERIALIZERS_WS + GREMLIN_SERIALIZERS_HTTP -DEFAULT_GREMLIN_SERIALIZER = GRAPHSONV1_UNTYPED +NEPTUNE_GREMLIN_SERIALIZERS_HTTP = [GRAPHSONV1_UNTYPED, GRAPHSONV2_UNTYPED, GRAPHSONV3_UNTYPED] +DEFAULT_GREMLIN_WS_SERIALIZER = GRAPHSONV3 +DEFAULT_GREMLIN_HTTP_SERIALIZER = GRAPHSONV3_UNTYPED +DEFAULT_GREMLIN_SERIALIZER = GRAPHSONV3_UNTYPED DEFAULT_WS_PROTOCOL = "websockets" DEFAULT_HTTP_PROTOCOL = "http" @@ -188,13 +192,40 @@ def get_gremlin_serializer_mime(serializer_str: str): def normalize_protocol_name(protocol: str): + protocol = protocol.lower() + is_bad_protocol = False if protocol in WS_PROTOCOL_FORMATS: - return DEFAULT_WS_PROTOCOL + protocol = DEFAULT_WS_PROTOCOL elif protocol in HTTP_PROTOCOL_FORMATS: - return DEFAULT_HTTP_PROTOCOL + protocol = DEFAULT_HTTP_PROTOCOL else: - print(f"Provided connection protocol is invalid for Neptune, defaulting to {DEFAULT_GREMLIN_PROTOCOL}.") - return DEFAULT_GREMLIN_PROTOCOL + protocol = '' + is_bad_protocol = True + return protocol, is_bad_protocol + + +def normalize_serializer_class_name(serializer: str): + serializer_lower = serializer.lower() + is_bad_serializer = False + if 'graphson' in serializer_lower: + message_serializer = 'GraphSON' + if 'untyped' in serializer_lower: + message_serializer += 'Untyped' + if 'v1' in serializer_lower: + if 'untyped' in serializer_lower: + message_serializer += 'MessageSerializerV1' + else: + message_serializer += 'MessageSerializerGremlinV1' + elif 'v2' in serializer_lower: + message_serializer += 'MessageSerializerV2' + else: + message_serializer += 'MessageSerializerV3' + elif 'graphbinary' in serializer_lower: + message_serializer = GRAPHBINARYV1 + else: + message_serializer = '' + is_bad_serializer = True + return message_serializer, is_bad_serializer def normalize_service_name(neptune_service: str): diff --git a/test/unit/configuration/test_configuration.py b/test/unit/configuration/test_configuration.py index cf95a37d..4687d303 100644 --- a/test/unit/configuration/test_configuration.py +++ b/test/unit/configuration/test_configuration.py @@ -10,7 +10,7 @@ from graph_notebook.configuration.generate_config import Configuration, DEFAULT_AUTH_MODE, AuthModeEnum, \ generate_config, generate_default_config, GremlinSection from graph_notebook.neptune.client import NEPTUNE_DB_SERVICE_NAME, NEPTUNE_ANALYTICS_SERVICE_NAME, \ - DEFAULT_GREMLIN_PROTOCOL, DEFAULT_HTTP_PROTOCOL, NEPTUNE_CONFIG_HOST_IDENTIFIERS + DEFAULT_WS_PROTOCOL, DEFAULT_HTTP_PROTOCOL, NEPTUNE_CONFIG_HOST_IDENTIFIERS class TestGenerateConfiguration(unittest.TestCase): @@ -50,8 +50,8 @@ def test_generate_default_config(self): self.assertEqual('g', config.gremlin.traversal_source) self.assertEqual('', config.gremlin.username) self.assertEqual('', config.gremlin.password) - self.assertEqual(DEFAULT_GREMLIN_PROTOCOL, config.gremlin.connection_protocol) - self.assertEqual('GraphSONUntypedMessageSerializerV1', config.gremlin.message_serializer) + self.assertEqual(DEFAULT_WS_PROTOCOL, config.gremlin.connection_protocol) + self.assertEqual('GraphSONMessageSerializerV3', config.gremlin.message_serializer) self.assertEqual('neo4j', config.neo4j.username) self.assertEqual('password', config.neo4j.password) self.assertEqual(True, config.neo4j.auth) @@ -170,7 +170,7 @@ def test_get_configuration_generic_required_input(self): 'traversal_source': 'g', 'username': '', 'password': '', - 'message_serializer': 'GraphSONUntypedMessageSerializerV1' + 'message_serializer': 'GraphSONMessageSerializerV3' }, 'neo4j': { 'username': 'neo4j', @@ -267,8 +267,8 @@ def test_get_configuration_neptune_required_input(self): 'traversal_source': 'g', 'username': '', 'password': '', - 'message_serializer': 'GraphSONUntypedMessageSerializerV1', - 'connection_protocol': 'http' + 'message_serializer': 'GraphSONMessageSerializerV3', + 'connection_protocol': 'websockets' }, 'neo4j': { 'username': 'neo4j', @@ -472,21 +472,43 @@ def test_configuration_gremlinsection_generic_default(self): self.assertEqual(config.gremlin.traversal_source, 'g') self.assertEqual(config.gremlin.username, '') self.assertEqual(config.gremlin.password, '') - self.assertEqual(config.gremlin.message_serializer, 'GraphSONUntypedMessageSerializerV1') + self.assertEqual(config.gremlin.message_serializer, 'GraphSONMessageSerializerV3') + self.assertFalse(hasattr(config.gremlin, "connection_protocol")) + + def test_configuration_gremlinsection_generic_override_protocol(self): + config = Configuration('localhost', + self.port, + gremlin_section=GremlinSection(connection_protocol='http'), + ) + self.assertEqual(config.gremlin.traversal_source, 'g') + self.assertEqual(config.gremlin.username, '') + self.assertEqual(config.gremlin.password, '') + self.assertEqual(config.gremlin.message_serializer, 'GraphSONMessageSerializerV3') self.assertFalse(hasattr(config.gremlin, "connection_protocol")) - def test_configuration_gremlinsection_generic_override(self): + def test_configuration_gremlinsection_generic_override_serializer_invalid(self): + config = Configuration('localhost', + self.port, + gremlin_section=GremlinSection(message_serializer='not_a_serializer'), + ) + self.assertEqual(config.gremlin.traversal_source, 'g') + self.assertEqual(config.gremlin.username, '') + self.assertEqual(config.gremlin.password, '') + self.assertEqual(config.gremlin.message_serializer, 'GraphSONMessageSerializerV3') + self.assertFalse(hasattr(config.gremlin, "connection_protocol")) + + def test_configuration_gremlinsection_generic_override_serializer_http_only(self): config = Configuration('localhost', self.port, gremlin_section=GremlinSection(traversal_source='t', username='foo', password='bar', - message_serializer='graphbinary'), + message_serializer='GraphSONUntypedMessageSerializerV1'), ) self.assertEqual(config.gremlin.traversal_source, 't') self.assertEqual(config.gremlin.username, 'foo') self.assertEqual(config.gremlin.password, 'bar') - self.assertEqual(config.gremlin.message_serializer, 'GraphBinaryMessageSerializerV1') + self.assertEqual(config.gremlin.message_serializer, 'GraphSONMessageSerializerV3') self.assertFalse(hasattr(config.gremlin, "connection_protocol")) def test_configuration_gremlinsection_neptune_default(self): @@ -494,10 +516,10 @@ def test_configuration_gremlinsection_neptune_default(self): self.assertEqual(config.gremlin.traversal_source, 'g') self.assertEqual(config.gremlin.username, '') self.assertEqual(config.gremlin.password, '') - self.assertEqual(config.gremlin.message_serializer, 'GraphSONUntypedMessageSerializerV1') - self.assertEqual(config.gremlin.connection_protocol, DEFAULT_GREMLIN_PROTOCOL) + self.assertEqual(config.gremlin.message_serializer, 'GraphSONMessageSerializerV3') + self.assertEqual(config.gremlin.connection_protocol, DEFAULT_WS_PROTOCOL) - def test_configuration_gremlinsection_neptune_override(self): + def test_configuration_gremlinsection_neptune_override_all(self): config = Configuration(self.neptune_host_reg, self.port, gremlin_section=GremlinSection(traversal_source='t', @@ -510,14 +532,248 @@ def test_configuration_gremlinsection_neptune_override(self): self.assertEqual(config.gremlin.traversal_source, 'g') self.assertEqual(config.gremlin.username, '') self.assertEqual(config.gremlin.password, '') + self.assertEqual(config.gremlin.message_serializer, 'GraphSONUntypedMessageSerializerV3') + self.assertEqual(config.gremlin.connection_protocol, DEFAULT_HTTP_PROTOCOL) + + def test_configuration_gremlinsection_neptune_default_db(self): + config = Configuration(self.neptune_host_reg, self.port, neptune_service=NEPTUNE_DB_SERVICE_NAME) + self.assertEqual(config.gremlin.traversal_source, 'g') + self.assertEqual(config.gremlin.username, '') + self.assertEqual(config.gremlin.password, '') + self.assertEqual(config.gremlin.message_serializer, 'GraphSONMessageSerializerV3') + self.assertEqual(config.gremlin.connection_protocol, DEFAULT_WS_PROTOCOL) + + def test_configuration_gremlinsection_neptune_db_override_protocol(self): + config = Configuration(self.neptune_host_reg, + self.port, + neptune_service=NEPTUNE_DB_SERVICE_NAME, + gremlin_section=GremlinSection(connection_protocol='http', + include_protocol=True, + neptune_service=NEPTUNE_DB_SERVICE_NAME), + ) + self.assertEqual(config.gremlin.traversal_source, 'g') + self.assertEqual(config.gremlin.username, '') + self.assertEqual(config.gremlin.password, '') + self.assertEqual(config.gremlin.message_serializer, 'GraphSONUntypedMessageSerializerV3') + self.assertEqual(config.gremlin.connection_protocol, DEFAULT_HTTP_PROTOCOL) + + def test_configuration_gremlinsection_neptune_db_override_protocol_invalid(self): + config = Configuration(self.neptune_host_reg, + self.port, + neptune_service=NEPTUNE_DB_SERVICE_NAME, + gremlin_section=GremlinSection(connection_protocol='not_a_protocol', + include_protocol=True, + neptune_service=NEPTUNE_DB_SERVICE_NAME), + ) + self.assertEqual(config.gremlin.traversal_source, 'g') + self.assertEqual(config.gremlin.username, '') + self.assertEqual(config.gremlin.password, '') + self.assertEqual(config.gremlin.message_serializer, 'GraphSONMessageSerializerV3') + self.assertEqual(config.gremlin.connection_protocol, DEFAULT_WS_PROTOCOL) + + def test_configuration_gremlinsection_neptune_db_override_serializer(self): + config = Configuration(self.neptune_host_reg, + self.port, + neptune_service=NEPTUNE_DB_SERVICE_NAME, + gremlin_section=GremlinSection(message_serializer='graphbinary', + include_protocol=True, + neptune_service=NEPTUNE_DB_SERVICE_NAME), + ) + self.assertEqual(config.gremlin.traversal_source, 'g') + self.assertEqual(config.gremlin.username, '') + self.assertEqual(config.gremlin.password, '') + self.assertEqual(config.gremlin.message_serializer, 'GraphBinaryMessageSerializerV1') + self.assertEqual(config.gremlin.connection_protocol, DEFAULT_WS_PROTOCOL) + + def test_configuration_gremlinsection_neptune_db_override_serializer_invalid(self): + config = Configuration(self.neptune_host_reg, + self.port, + neptune_service=NEPTUNE_DB_SERVICE_NAME, + gremlin_section=GremlinSection(message_serializer='not_a_serializer', + include_protocol=True, + neptune_service=NEPTUNE_DB_SERVICE_NAME), + ) + self.assertEqual(config.gremlin.traversal_source, 'g') + self.assertEqual(config.gremlin.username, '') + self.assertEqual(config.gremlin.password, '') + self.assertEqual(config.gremlin.message_serializer, 'GraphSONMessageSerializerV3') + self.assertEqual(config.gremlin.connection_protocol, DEFAULT_WS_PROTOCOL) + + def test_configuration_gremlinsection_neptune_db_override_http_protocol_and_serializer(self): + config = Configuration(self.neptune_host_reg, + self.port, + neptune_service=NEPTUNE_DB_SERVICE_NAME, + gremlin_section=GremlinSection(connection_protocol='http', + message_serializer='graphsonv1untyped', + include_protocol=True, + neptune_service=NEPTUNE_DB_SERVICE_NAME), + ) + self.assertEqual(config.gremlin.traversal_source, 'g') + self.assertEqual(config.gremlin.username, '') + self.assertEqual(config.gremlin.password, '') + self.assertEqual(config.gremlin.message_serializer, 'GraphSONUntypedMessageSerializerV1') + self.assertEqual(config.gremlin.connection_protocol, DEFAULT_HTTP_PROTOCOL) + + def test_configuration_gremlinsection_neptune_db_override_http_protocol_and_serializer_invalid(self): + config = Configuration(self.neptune_host_reg, + self.port, + neptune_service=NEPTUNE_DB_SERVICE_NAME, + gremlin_section=GremlinSection(connection_protocol='http', + message_serializer='not_a_serializer', + include_protocol=True, + neptune_service=NEPTUNE_DB_SERVICE_NAME), + ) + self.assertEqual(config.gremlin.traversal_source, 'g') + self.assertEqual(config.gremlin.username, '') + self.assertEqual(config.gremlin.password, '') + self.assertEqual(config.gremlin.message_serializer, 'GraphSONUntypedMessageSerializerV3') + self.assertEqual(config.gremlin.connection_protocol, DEFAULT_HTTP_PROTOCOL) + + def test_configuration_gremlinsection_neptune_db_override_http_protocol_and_serializer_not_graphson_untyped(self): + config = Configuration(self.neptune_host_reg, + self.port, + neptune_service=NEPTUNE_DB_SERVICE_NAME, + gremlin_section=GremlinSection(connection_protocol='http', + message_serializer='graphbinaryv1', + include_protocol=True, + neptune_service=NEPTUNE_DB_SERVICE_NAME), + ) + self.assertEqual(config.gremlin.traversal_source, 'g') + self.assertEqual(config.gremlin.username, '') + self.assertEqual(config.gremlin.password, '') + self.assertEqual(config.gremlin.message_serializer, 'GraphSONUntypedMessageSerializerV3') + self.assertEqual(config.gremlin.connection_protocol, DEFAULT_HTTP_PROTOCOL) + + def test_configuration_gremlinsection_neptune_db_override_ws_protocol_and_serializer(self): + config = Configuration(self.neptune_host_reg, + self.port, + neptune_service=NEPTUNE_DB_SERVICE_NAME, + gremlin_section=GremlinSection(connection_protocol='ws', + message_serializer='graphbinaryv1', + include_protocol=True, + neptune_service=NEPTUNE_DB_SERVICE_NAME), + ) + self.assertEqual(config.gremlin.traversal_source, 'g') + self.assertEqual(config.gremlin.username, '') + self.assertEqual(config.gremlin.password, '') + self.assertEqual(config.gremlin.message_serializer, 'GraphBinaryMessageSerializerV1') + self.assertEqual(config.gremlin.connection_protocol, DEFAULT_WS_PROTOCOL) + + def test_configuration_gremlinsection_neptune_db_override_ws_protocol_and_serializer_invalid(self): + config = Configuration(self.neptune_host_reg, + self.port, + neptune_service=NEPTUNE_DB_SERVICE_NAME, + gremlin_section=GremlinSection(connection_protocol='ws', + message_serializer='graphbinaryv1', + include_protocol=True, + neptune_service=NEPTUNE_DB_SERVICE_NAME), + ) + self.assertEqual(config.gremlin.traversal_source, 'g') + self.assertEqual(config.gremlin.username, '') + self.assertEqual(config.gremlin.password, '') self.assertEqual(config.gremlin.message_serializer, 'GraphBinaryMessageSerializerV1') + self.assertEqual(config.gremlin.connection_protocol, DEFAULT_WS_PROTOCOL) + + def test_configuration_gremlinsection_neptune_db_override_ws_protocol_and_serializer_http_only(self): + config = Configuration(self.neptune_host_reg, + self.port, + neptune_service=NEPTUNE_DB_SERVICE_NAME, + gremlin_section=GremlinSection(connection_protocol='ws', + message_serializer='graphsonv3untyped', + include_protocol=True, + neptune_service=NEPTUNE_DB_SERVICE_NAME), + ) + self.assertEqual(config.gremlin.traversal_source, 'g') + self.assertEqual(config.gremlin.username, '') + self.assertEqual(config.gremlin.password, '') + self.assertEqual(config.gremlin.message_serializer, 'GraphSONMessageSerializerV3') + self.assertEqual(config.gremlin.connection_protocol, DEFAULT_WS_PROTOCOL) + + def test_configuration_gremlinsection_neptune_default_analytics(self): + config = Configuration(self.neptune_host_reg, + self.port, + neptune_service=NEPTUNE_ANALYTICS_SERVICE_NAME) + self.assertEqual(config.gremlin.traversal_source, 'g') + self.assertEqual(config.gremlin.username, '') + self.assertEqual(config.gremlin.password, '') + self.assertEqual(config.gremlin.message_serializer, 'GraphSONUntypedMessageSerializerV3') + self.assertEqual(config.gremlin.connection_protocol, DEFAULT_HTTP_PROTOCOL) + + def test_configuration_gremlinsection_neptune_analytics_override_ws_protocol(self): + config = Configuration(self.neptune_host_reg, + self.port, + neptune_service=NEPTUNE_ANALYTICS_SERVICE_NAME, + gremlin_section=GremlinSection(connection_protocol='ws', + include_protocol=True, + neptune_service=NEPTUNE_ANALYTICS_SERVICE_NAME), + ) + self.assertEqual(config.gremlin.traversal_source, 'g') + self.assertEqual(config.gremlin.username, '') + self.assertEqual(config.gremlin.password, '') + self.assertEqual(config.gremlin.message_serializer, 'GraphSONUntypedMessageSerializerV3') + self.assertEqual(config.gremlin.connection_protocol, DEFAULT_HTTP_PROTOCOL) + + def test_configuration_gremlinsection_neptune_analytics_override_serializer(self): + config = Configuration(self.neptune_host_reg, + self.port, + neptune_service=NEPTUNE_ANALYTICS_SERVICE_NAME, + gremlin_section=GremlinSection(message_serializer='graphsonv1untyped', + include_protocol=True, + neptune_service=NEPTUNE_ANALYTICS_SERVICE_NAME), + ) + self.assertEqual(config.gremlin.traversal_source, 'g') + self.assertEqual(config.gremlin.username, '') + self.assertEqual(config.gremlin.password, '') + self.assertEqual(config.gremlin.message_serializer, 'GraphSONUntypedMessageSerializerV1') + self.assertEqual(config.gremlin.connection_protocol, DEFAULT_HTTP_PROTOCOL) + + def test_configuration_gremlinsection_neptune_analytics_override_serializer_invalid(self): + config = Configuration(self.neptune_host_reg, + self.port, + neptune_service=NEPTUNE_ANALYTICS_SERVICE_NAME, + gremlin_section=GremlinSection(message_serializer='not_a_serializer', + include_protocol=True, + neptune_service=NEPTUNE_ANALYTICS_SERVICE_NAME), + ) + self.assertEqual(config.gremlin.traversal_source, 'g') + self.assertEqual(config.gremlin.username, '') + self.assertEqual(config.gremlin.password, '') + self.assertEqual(config.gremlin.message_serializer, 'GraphSONUntypedMessageSerializerV3') + self.assertEqual(config.gremlin.connection_protocol, DEFAULT_HTTP_PROTOCOL) + + def test_configuration_gremlinsection_neptune_analytics_override_serializer_not_graphson_untyped(self): + config = Configuration(self.neptune_host_reg, + self.port, + neptune_service=NEPTUNE_ANALYTICS_SERVICE_NAME, + gremlin_section=GremlinSection(message_serializer='graphbinaryv1', + include_protocol=True, + neptune_service=NEPTUNE_ANALYTICS_SERVICE_NAME), + ) + self.assertEqual(config.gremlin.traversal_source, 'g') + self.assertEqual(config.gremlin.username, '') + self.assertEqual(config.gremlin.password, '') + self.assertEqual(config.gremlin.message_serializer, 'GraphSONUntypedMessageSerializerV3') + self.assertEqual(config.gremlin.connection_protocol, DEFAULT_HTTP_PROTOCOL) + + def test_configuration_gremlinsection_neptune_analytics_override_http_protocol(self): + config = Configuration(self.neptune_host_reg, + self.port, + neptune_service=NEPTUNE_ANALYTICS_SERVICE_NAME, + gremlin_section=GremlinSection(connection_protocol='http', + include_protocol=True, + neptune_service=NEPTUNE_ANALYTICS_SERVICE_NAME), + ) + self.assertEqual(config.gremlin.traversal_source, 'g') + self.assertEqual(config.gremlin.username, '') + self.assertEqual(config.gremlin.password, '') + self.assertEqual(config.gremlin.message_serializer, 'GraphSONUntypedMessageSerializerV3') self.assertEqual(config.gremlin.connection_protocol, DEFAULT_HTTP_PROTOCOL) def test_configuration_gremlinsection_protocol_neptune_default_with_proxy(self): config = Configuration(self.neptune_host_reg, self.port, proxy_host='test_proxy') - self.assertEqual(config.gremlin.connection_protocol, DEFAULT_HTTP_PROTOCOL) + self.assertEqual(config.gremlin.connection_protocol, DEFAULT_WS_PROTOCOL) def test_configuration_gremlinsection_protocol_neptune_override_with_proxy(self): config = Configuration(self.neptune_host_reg, diff --git a/test/unit/configuration/test_configuration_from_main.py b/test/unit/configuration/test_configuration_from_main.py index 0d625517..73dd2694 100644 --- a/test/unit/configuration/test_configuration_from_main.py +++ b/test/unit/configuration/test_configuration_from_main.py @@ -9,7 +9,7 @@ from graph_notebook.configuration.generate_config import AuthModeEnum, Configuration, GremlinSection from graph_notebook.configuration.get_config import get_config from graph_notebook.neptune.client import (NEPTUNE_DB_SERVICE_NAME, NEPTUNE_ANALYTICS_SERVICE_NAME, - DEFAULT_HTTP_PROTOCOL) + DEFAULT_HTTP_PROTOCOL, DEFAULT_WS_PROTOCOL) class TestGenerateConfigurationMain(unittest.TestCase): @@ -135,7 +135,7 @@ def test_generate_configuration_main_gremlin_protocol_no_service(self): self.assertEqual(0, result) config = get_config(self.test_file_path) config_dict = config.to_dict() - self.assertEqual(DEFAULT_HTTP_PROTOCOL, config_dict['gremlin']['connection_protocol']) + self.assertEqual(DEFAULT_WS_PROTOCOL, config_dict['gremlin']['connection_protocol']) def test_generate_configuration_main_gremlin_protocol_db(self): result = os.system(f'{self.python_cmd} -m graph_notebook.configuration.generate_config ' @@ -149,7 +149,7 @@ def test_generate_configuration_main_gremlin_protocol_db(self): self.assertEqual(0, result) config = get_config(self.test_file_path) config_dict = config.to_dict() - self.assertEqual(DEFAULT_HTTP_PROTOCOL, config_dict['gremlin']['connection_protocol']) + self.assertEqual(DEFAULT_WS_PROTOCOL, config_dict['gremlin']['connection_protocol']) def test_generate_configuration_main_gremlin_protocol_analytics(self): result = os.system(f'{self.python_cmd} -m graph_notebook.configuration.generate_config ' @@ -165,6 +165,48 @@ def test_generate_configuration_main_gremlin_protocol_analytics(self): config_dict = config.to_dict() self.assertEqual(DEFAULT_HTTP_PROTOCOL, config_dict['gremlin']['connection_protocol']) + def test_generate_configuration_main_gremlin_serializer_no_service(self): + result = os.system(f'{self.python_cmd} -m graph_notebook.configuration.generate_config ' + f'--host "{self.neptune_host_reg}" ' + f'--port "{self.port}" ' + f'--neptune_service "" ' + f'--auth_mode "" ' + f'--ssl "" ' + f'--load_from_s3_arn "" ' + f'--config_destination="{self.test_file_path}" ') + self.assertEqual(0, result) + config = get_config(self.test_file_path) + config_dict = config.to_dict() + self.assertEqual('GraphSONMessageSerializerV3', config_dict['gremlin']['message_serializer']) + + def test_generate_configuration_main_gremlin_serializer_db(self): + result = os.system(f'{self.python_cmd} -m graph_notebook.configuration.generate_config ' + f'--host "{self.neptune_host_reg}" ' + f'--port "{self.port}" ' + f'--neptune_service "{NEPTUNE_DB_SERVICE_NAME}" ' + f'--auth_mode "" ' + f'--ssl "" ' + f'--load_from_s3_arn "" ' + f'--config_destination="{self.test_file_path}" ') + self.assertEqual(0, result) + config = get_config(self.test_file_path) + config_dict = config.to_dict() + self.assertEqual('GraphSONMessageSerializerV3', config_dict['gremlin']['message_serializer']) + + def test_generate_configuration_main_gremlin_serializer_analytics(self): + result = os.system(f'{self.python_cmd} -m graph_notebook.configuration.generate_config ' + f'--host "{self.neptune_host_reg}" ' + f'--port "{self.port}" ' + f'--neptune_service "{NEPTUNE_ANALYTICS_SERVICE_NAME}" ' + f'--auth_mode "" ' + f'--ssl "" ' + f'--load_from_s3_arn "" ' + f'--config_destination="{self.test_file_path}" ') + self.assertEqual(0, result) + config = get_config(self.test_file_path) + config_dict = config.to_dict() + self.assertEqual('GraphSONUntypedMessageSerializerV3', config_dict['gremlin']['message_serializer']) + def test_generate_configuration_main_empty_args_custom(self): expected_config = Configuration(self.neptune_host_custom, self.port, neptune_hosts=self.custom_hosts_list) result = os.system(f'{self.python_cmd} -m graph_notebook.configuration.generate_config '