diff --git a/src/graph_notebook/configuration/generate_config.py b/src/graph_notebook/configuration/generate_config.py index 698b6c04..44270328 100644 --- a/src/graph_notebook/configuration/generate_config.py +++ b/src/graph_notebook/configuration/generate_config.py @@ -359,11 +359,14 @@ def generate_default_config(): auth_mode_arg = args.auth_mode if args.auth_mode != '' else AuthModeEnum.DEFAULT.value protocol_arg = args.gremlin_connection_protocol include_protocol = False + gremlin_service = '' if is_allowed_neptune_host(args.host, args.neptune_hosts): include_protocol = True + gremlin_service = args.neptune_service if not protocol_arg: protocol_arg = DEFAULT_HTTP_PROTOCOL \ if args.neptune_service == NEPTUNE_ANALYTICS_SERVICE_NAME else DEFAULT_WS_PROTOCOL + config = generate_config(args.host, int(args.port), AuthModeEnum(auth_mode_arg), args.ssl, args.ssl_verify, @@ -372,7 +375,7 @@ def generate_default_config(): SparqlSection(args.sparql_path, ''), GremlinSection(args.gremlin_traversal_source, args.gremlin_username, args.gremlin_password, args.gremlin_serializer, - protocol_arg, include_protocol), + protocol_arg, include_protocol, gremlin_service), Neo4JSection(args.neo4j_username, args.neo4j_password, args.neo4j_auth, args.neo4j_database), args.neptune_hosts) diff --git a/test/unit/configuration/test_configuration.py b/test/unit/configuration/test_configuration.py index 12420c16..4687d303 100644 --- a/test/unit/configuration/test_configuration.py +++ b/test/unit/configuration/test_configuration.py @@ -475,18 +475,40 @@ def test_configuration_gremlinsection_generic_default(self): self.assertEqual(config.gremlin.message_serializer, 'GraphSONMessageSerializerV3') self.assertFalse(hasattr(config.gremlin, "connection_protocol")) - def test_configuration_gremlinsection_generic_override(self): + def test_configuration_gremlinsection_generic_override_protocol(self): + config = Configuration('localhost', + self.port, + gremlin_section=GremlinSection(connection_protocol='http'), + ) + self.assertEqual(config.gremlin.traversal_source, 'g') + self.assertEqual(config.gremlin.username, '') + self.assertEqual(config.gremlin.password, '') + self.assertEqual(config.gremlin.message_serializer, 'GraphSONMessageSerializerV3') + self.assertFalse(hasattr(config.gremlin, "connection_protocol")) + + def test_configuration_gremlinsection_generic_override_serializer_invalid(self): + config = Configuration('localhost', + self.port, + gremlin_section=GremlinSection(message_serializer='not_a_serializer'), + ) + self.assertEqual(config.gremlin.traversal_source, 'g') + self.assertEqual(config.gremlin.username, '') + self.assertEqual(config.gremlin.password, '') + self.assertEqual(config.gremlin.message_serializer, 'GraphSONMessageSerializerV3') + self.assertFalse(hasattr(config.gremlin, "connection_protocol")) + + def test_configuration_gremlinsection_generic_override_serializer_http_only(self): config = Configuration('localhost', self.port, gremlin_section=GremlinSection(traversal_source='t', username='foo', password='bar', - message_serializer='graphbinary'), + message_serializer='GraphSONUntypedMessageSerializerV1'), ) self.assertEqual(config.gremlin.traversal_source, 't') self.assertEqual(config.gremlin.username, 'foo') self.assertEqual(config.gremlin.password, 'bar') - self.assertEqual(config.gremlin.message_serializer, 'GraphBinaryMessageSerializerV1') + self.assertEqual(config.gremlin.message_serializer, 'GraphSONMessageSerializerV3') self.assertFalse(hasattr(config.gremlin, "connection_protocol")) def test_configuration_gremlinsection_neptune_default(self): @@ -497,7 +519,7 @@ def test_configuration_gremlinsection_neptune_default(self): self.assertEqual(config.gremlin.message_serializer, 'GraphSONMessageSerializerV3') self.assertEqual(config.gremlin.connection_protocol, DEFAULT_WS_PROTOCOL) - def test_configuration_gremlinsection_neptune_override(self): + def test_configuration_gremlinsection_neptune_override_all(self): config = Configuration(self.neptune_host_reg, self.port, gremlin_section=GremlinSection(traversal_source='t', @@ -513,6 +535,240 @@ def test_configuration_gremlinsection_neptune_override(self): self.assertEqual(config.gremlin.message_serializer, 'GraphSONUntypedMessageSerializerV3') self.assertEqual(config.gremlin.connection_protocol, DEFAULT_HTTP_PROTOCOL) + def test_configuration_gremlinsection_neptune_default_db(self): + config = Configuration(self.neptune_host_reg, self.port, neptune_service=NEPTUNE_DB_SERVICE_NAME) + self.assertEqual(config.gremlin.traversal_source, 'g') + self.assertEqual(config.gremlin.username, '') + self.assertEqual(config.gremlin.password, '') + self.assertEqual(config.gremlin.message_serializer, 'GraphSONMessageSerializerV3') + self.assertEqual(config.gremlin.connection_protocol, DEFAULT_WS_PROTOCOL) + + def test_configuration_gremlinsection_neptune_db_override_protocol(self): + config = Configuration(self.neptune_host_reg, + self.port, + neptune_service=NEPTUNE_DB_SERVICE_NAME, + gremlin_section=GremlinSection(connection_protocol='http', + include_protocol=True, + neptune_service=NEPTUNE_DB_SERVICE_NAME), + ) + self.assertEqual(config.gremlin.traversal_source, 'g') + self.assertEqual(config.gremlin.username, '') + self.assertEqual(config.gremlin.password, '') + self.assertEqual(config.gremlin.message_serializer, 'GraphSONUntypedMessageSerializerV3') + self.assertEqual(config.gremlin.connection_protocol, DEFAULT_HTTP_PROTOCOL) + + def test_configuration_gremlinsection_neptune_db_override_protocol_invalid(self): + config = Configuration(self.neptune_host_reg, + self.port, + neptune_service=NEPTUNE_DB_SERVICE_NAME, + gremlin_section=GremlinSection(connection_protocol='not_a_protocol', + include_protocol=True, + neptune_service=NEPTUNE_DB_SERVICE_NAME), + ) + self.assertEqual(config.gremlin.traversal_source, 'g') + self.assertEqual(config.gremlin.username, '') + self.assertEqual(config.gremlin.password, '') + self.assertEqual(config.gremlin.message_serializer, 'GraphSONMessageSerializerV3') + self.assertEqual(config.gremlin.connection_protocol, DEFAULT_WS_PROTOCOL) + + def test_configuration_gremlinsection_neptune_db_override_serializer(self): + config = Configuration(self.neptune_host_reg, + self.port, + neptune_service=NEPTUNE_DB_SERVICE_NAME, + gremlin_section=GremlinSection(message_serializer='graphbinary', + include_protocol=True, + neptune_service=NEPTUNE_DB_SERVICE_NAME), + ) + self.assertEqual(config.gremlin.traversal_source, 'g') + self.assertEqual(config.gremlin.username, '') + self.assertEqual(config.gremlin.password, '') + self.assertEqual(config.gremlin.message_serializer, 'GraphBinaryMessageSerializerV1') + self.assertEqual(config.gremlin.connection_protocol, DEFAULT_WS_PROTOCOL) + + def test_configuration_gremlinsection_neptune_db_override_serializer_invalid(self): + config = Configuration(self.neptune_host_reg, + self.port, + neptune_service=NEPTUNE_DB_SERVICE_NAME, + gremlin_section=GremlinSection(message_serializer='not_a_serializer', + include_protocol=True, + neptune_service=NEPTUNE_DB_SERVICE_NAME), + ) + self.assertEqual(config.gremlin.traversal_source, 'g') + self.assertEqual(config.gremlin.username, '') + self.assertEqual(config.gremlin.password, '') + self.assertEqual(config.gremlin.message_serializer, 'GraphSONMessageSerializerV3') + self.assertEqual(config.gremlin.connection_protocol, DEFAULT_WS_PROTOCOL) + + def test_configuration_gremlinsection_neptune_db_override_http_protocol_and_serializer(self): + config = Configuration(self.neptune_host_reg, + self.port, + neptune_service=NEPTUNE_DB_SERVICE_NAME, + gremlin_section=GremlinSection(connection_protocol='http', + message_serializer='graphsonv1untyped', + include_protocol=True, + neptune_service=NEPTUNE_DB_SERVICE_NAME), + ) + self.assertEqual(config.gremlin.traversal_source, 'g') + self.assertEqual(config.gremlin.username, '') + self.assertEqual(config.gremlin.password, '') + self.assertEqual(config.gremlin.message_serializer, 'GraphSONUntypedMessageSerializerV1') + self.assertEqual(config.gremlin.connection_protocol, DEFAULT_HTTP_PROTOCOL) + + def test_configuration_gremlinsection_neptune_db_override_http_protocol_and_serializer_invalid(self): + config = Configuration(self.neptune_host_reg, + self.port, + neptune_service=NEPTUNE_DB_SERVICE_NAME, + gremlin_section=GremlinSection(connection_protocol='http', + message_serializer='not_a_serializer', + include_protocol=True, + neptune_service=NEPTUNE_DB_SERVICE_NAME), + ) + self.assertEqual(config.gremlin.traversal_source, 'g') + self.assertEqual(config.gremlin.username, '') + self.assertEqual(config.gremlin.password, '') + self.assertEqual(config.gremlin.message_serializer, 'GraphSONUntypedMessageSerializerV3') + self.assertEqual(config.gremlin.connection_protocol, DEFAULT_HTTP_PROTOCOL) + + def test_configuration_gremlinsection_neptune_db_override_http_protocol_and_serializer_not_graphson_untyped(self): + config = Configuration(self.neptune_host_reg, + self.port, + neptune_service=NEPTUNE_DB_SERVICE_NAME, + gremlin_section=GremlinSection(connection_protocol='http', + message_serializer='graphbinaryv1', + include_protocol=True, + neptune_service=NEPTUNE_DB_SERVICE_NAME), + ) + self.assertEqual(config.gremlin.traversal_source, 'g') + self.assertEqual(config.gremlin.username, '') + self.assertEqual(config.gremlin.password, '') + self.assertEqual(config.gremlin.message_serializer, 'GraphSONUntypedMessageSerializerV3') + self.assertEqual(config.gremlin.connection_protocol, DEFAULT_HTTP_PROTOCOL) + + def test_configuration_gremlinsection_neptune_db_override_ws_protocol_and_serializer(self): + config = Configuration(self.neptune_host_reg, + self.port, + neptune_service=NEPTUNE_DB_SERVICE_NAME, + gremlin_section=GremlinSection(connection_protocol='ws', + message_serializer='graphbinaryv1', + include_protocol=True, + neptune_service=NEPTUNE_DB_SERVICE_NAME), + ) + self.assertEqual(config.gremlin.traversal_source, 'g') + self.assertEqual(config.gremlin.username, '') + self.assertEqual(config.gremlin.password, '') + self.assertEqual(config.gremlin.message_serializer, 'GraphBinaryMessageSerializerV1') + self.assertEqual(config.gremlin.connection_protocol, DEFAULT_WS_PROTOCOL) + + def test_configuration_gremlinsection_neptune_db_override_ws_protocol_and_serializer_invalid(self): + config = Configuration(self.neptune_host_reg, + self.port, + neptune_service=NEPTUNE_DB_SERVICE_NAME, + gremlin_section=GremlinSection(connection_protocol='ws', + message_serializer='graphbinaryv1', + include_protocol=True, + neptune_service=NEPTUNE_DB_SERVICE_NAME), + ) + self.assertEqual(config.gremlin.traversal_source, 'g') + self.assertEqual(config.gremlin.username, '') + self.assertEqual(config.gremlin.password, '') + self.assertEqual(config.gremlin.message_serializer, 'GraphBinaryMessageSerializerV1') + self.assertEqual(config.gremlin.connection_protocol, DEFAULT_WS_PROTOCOL) + + def test_configuration_gremlinsection_neptune_db_override_ws_protocol_and_serializer_http_only(self): + config = Configuration(self.neptune_host_reg, + self.port, + neptune_service=NEPTUNE_DB_SERVICE_NAME, + gremlin_section=GremlinSection(connection_protocol='ws', + message_serializer='graphsonv3untyped', + include_protocol=True, + neptune_service=NEPTUNE_DB_SERVICE_NAME), + ) + self.assertEqual(config.gremlin.traversal_source, 'g') + self.assertEqual(config.gremlin.username, '') + self.assertEqual(config.gremlin.password, '') + self.assertEqual(config.gremlin.message_serializer, 'GraphSONMessageSerializerV3') + self.assertEqual(config.gremlin.connection_protocol, DEFAULT_WS_PROTOCOL) + + def test_configuration_gremlinsection_neptune_default_analytics(self): + config = Configuration(self.neptune_host_reg, + self.port, + neptune_service=NEPTUNE_ANALYTICS_SERVICE_NAME) + self.assertEqual(config.gremlin.traversal_source, 'g') + self.assertEqual(config.gremlin.username, '') + self.assertEqual(config.gremlin.password, '') + self.assertEqual(config.gremlin.message_serializer, 'GraphSONUntypedMessageSerializerV3') + self.assertEqual(config.gremlin.connection_protocol, DEFAULT_HTTP_PROTOCOL) + + def test_configuration_gremlinsection_neptune_analytics_override_ws_protocol(self): + config = Configuration(self.neptune_host_reg, + self.port, + neptune_service=NEPTUNE_ANALYTICS_SERVICE_NAME, + gremlin_section=GremlinSection(connection_protocol='ws', + include_protocol=True, + neptune_service=NEPTUNE_ANALYTICS_SERVICE_NAME), + ) + self.assertEqual(config.gremlin.traversal_source, 'g') + self.assertEqual(config.gremlin.username, '') + self.assertEqual(config.gremlin.password, '') + self.assertEqual(config.gremlin.message_serializer, 'GraphSONUntypedMessageSerializerV3') + self.assertEqual(config.gremlin.connection_protocol, DEFAULT_HTTP_PROTOCOL) + + def test_configuration_gremlinsection_neptune_analytics_override_serializer(self): + config = Configuration(self.neptune_host_reg, + self.port, + neptune_service=NEPTUNE_ANALYTICS_SERVICE_NAME, + gremlin_section=GremlinSection(message_serializer='graphsonv1untyped', + include_protocol=True, + neptune_service=NEPTUNE_ANALYTICS_SERVICE_NAME), + ) + self.assertEqual(config.gremlin.traversal_source, 'g') + self.assertEqual(config.gremlin.username, '') + self.assertEqual(config.gremlin.password, '') + self.assertEqual(config.gremlin.message_serializer, 'GraphSONUntypedMessageSerializerV1') + self.assertEqual(config.gremlin.connection_protocol, DEFAULT_HTTP_PROTOCOL) + + def test_configuration_gremlinsection_neptune_analytics_override_serializer_invalid(self): + config = Configuration(self.neptune_host_reg, + self.port, + neptune_service=NEPTUNE_ANALYTICS_SERVICE_NAME, + gremlin_section=GremlinSection(message_serializer='not_a_serializer', + include_protocol=True, + neptune_service=NEPTUNE_ANALYTICS_SERVICE_NAME), + ) + self.assertEqual(config.gremlin.traversal_source, 'g') + self.assertEqual(config.gremlin.username, '') + self.assertEqual(config.gremlin.password, '') + self.assertEqual(config.gremlin.message_serializer, 'GraphSONUntypedMessageSerializerV3') + self.assertEqual(config.gremlin.connection_protocol, DEFAULT_HTTP_PROTOCOL) + + def test_configuration_gremlinsection_neptune_analytics_override_serializer_not_graphson_untyped(self): + config = Configuration(self.neptune_host_reg, + self.port, + neptune_service=NEPTUNE_ANALYTICS_SERVICE_NAME, + gremlin_section=GremlinSection(message_serializer='graphbinaryv1', + include_protocol=True, + neptune_service=NEPTUNE_ANALYTICS_SERVICE_NAME), + ) + self.assertEqual(config.gremlin.traversal_source, 'g') + self.assertEqual(config.gremlin.username, '') + self.assertEqual(config.gremlin.password, '') + self.assertEqual(config.gremlin.message_serializer, 'GraphSONUntypedMessageSerializerV3') + self.assertEqual(config.gremlin.connection_protocol, DEFAULT_HTTP_PROTOCOL) + + def test_configuration_gremlinsection_neptune_analytics_override_http_protocol(self): + config = Configuration(self.neptune_host_reg, + self.port, + neptune_service=NEPTUNE_ANALYTICS_SERVICE_NAME, + gremlin_section=GremlinSection(connection_protocol='http', + include_protocol=True, + neptune_service=NEPTUNE_ANALYTICS_SERVICE_NAME), + ) + self.assertEqual(config.gremlin.traversal_source, 'g') + self.assertEqual(config.gremlin.username, '') + self.assertEqual(config.gremlin.password, '') + self.assertEqual(config.gremlin.message_serializer, 'GraphSONUntypedMessageSerializerV3') + self.assertEqual(config.gremlin.connection_protocol, DEFAULT_HTTP_PROTOCOL) + def test_configuration_gremlinsection_protocol_neptune_default_with_proxy(self): config = Configuration(self.neptune_host_reg, self.port, diff --git a/test/unit/configuration/test_configuration_from_main.py b/test/unit/configuration/test_configuration_from_main.py index b76dc4af..73dd2694 100644 --- a/test/unit/configuration/test_configuration_from_main.py +++ b/test/unit/configuration/test_configuration_from_main.py @@ -165,6 +165,48 @@ def test_generate_configuration_main_gremlin_protocol_analytics(self): config_dict = config.to_dict() self.assertEqual(DEFAULT_HTTP_PROTOCOL, config_dict['gremlin']['connection_protocol']) + def test_generate_configuration_main_gremlin_serializer_no_service(self): + result = os.system(f'{self.python_cmd} -m graph_notebook.configuration.generate_config ' + f'--host "{self.neptune_host_reg}" ' + f'--port "{self.port}" ' + f'--neptune_service "" ' + f'--auth_mode "" ' + f'--ssl "" ' + f'--load_from_s3_arn "" ' + f'--config_destination="{self.test_file_path}" ') + self.assertEqual(0, result) + config = get_config(self.test_file_path) + config_dict = config.to_dict() + self.assertEqual('GraphSONMessageSerializerV3', config_dict['gremlin']['message_serializer']) + + def test_generate_configuration_main_gremlin_serializer_db(self): + result = os.system(f'{self.python_cmd} -m graph_notebook.configuration.generate_config ' + f'--host "{self.neptune_host_reg}" ' + f'--port "{self.port}" ' + f'--neptune_service "{NEPTUNE_DB_SERVICE_NAME}" ' + f'--auth_mode "" ' + f'--ssl "" ' + f'--load_from_s3_arn "" ' + f'--config_destination="{self.test_file_path}" ') + self.assertEqual(0, result) + config = get_config(self.test_file_path) + config_dict = config.to_dict() + self.assertEqual('GraphSONMessageSerializerV3', config_dict['gremlin']['message_serializer']) + + def test_generate_configuration_main_gremlin_serializer_analytics(self): + result = os.system(f'{self.python_cmd} -m graph_notebook.configuration.generate_config ' + f'--host "{self.neptune_host_reg}" ' + f'--port "{self.port}" ' + f'--neptune_service "{NEPTUNE_ANALYTICS_SERVICE_NAME}" ' + f'--auth_mode "" ' + f'--ssl "" ' + f'--load_from_s3_arn "" ' + f'--config_destination="{self.test_file_path}" ') + self.assertEqual(0, result) + config = get_config(self.test_file_path) + config_dict = config.to_dict() + self.assertEqual('GraphSONUntypedMessageSerializerV3', config_dict['gremlin']['message_serializer']) + def test_generate_configuration_main_empty_args_custom(self): expected_config = Configuration(self.neptune_host_custom, self.port, neptune_hosts=self.custom_hosts_list) result = os.system(f'{self.python_cmd} -m graph_notebook.configuration.generate_config '