Skip to content

Commit

Permalink
Enforce Gremlin protocol and serializer based on database type (#697)
Browse files Browse the repository at this point in the history
* Set allowed and default Gremlin protocol and serializer dynamically

* Add unit test suite

* update changelog
  • Loading branch information
michaelnchin authored Sep 20, 2024
1 parent 8fa1749 commit 9997b1c
Show file tree
Hide file tree
Showing 7 changed files with 457 additions and 81 deletions.
1 change: 1 addition & 0 deletions ChangeLog.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ Starting with v1.31.6, this file will contain a record of major features and upd
## Upcoming

- Updated Gremlin config `message_serializer` to accept all TinkerPop serializers ([Link to PR](https://github.com/aws/graph-notebook/pull/685))
- Implemented service-based dynamic allowlists and defaults for Gremlin serializer and protocol combinations ([Link to PR](https://github.com/aws/graph-notebook/pull/697))
- Added `%get_import_task` line magic ([Link to PR](https://github.com/aws/graph-notebook/pull/668))
- Added `--export-to` JSON file option to `%%graph_notebook_config` ([Link to PR](https://github.com/aws/graph-notebook/pull/684))
- Deprecated Python 3.8 support ([Link to PR](https://github.com/aws/graph-notebook/pull/683))
Expand Down
131 changes: 81 additions & 50 deletions src/graph_notebook/configuration/generate_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,12 @@
HTTP_PROTOCOL_FORMATS, WS_PROTOCOL_FORMATS,
DEFAULT_NEO4J_USERNAME, DEFAULT_NEO4J_PASSWORD, DEFAULT_NEO4J_DATABASE,
NEPTUNE_CONFIG_HOST_IDENTIFIERS, is_allowed_neptune_host, false_str_variants,
GRAPHBINARYV1, GREMLIN_SERIALIZERS_HTTP,
GRAPHBINARYV1, GREMLIN_SERIALIZERS_HTTP, GREMLIN_SERIALIZERS_WS,
GREMLIN_SERIALIZERS_ALL, NEPTUNE_GREMLIN_SERIALIZERS_HTTP,
DEFAULT_GREMLIN_WS_SERIALIZER, DEFAULT_GREMLIN_HTTP_SERIALIZER,
NEPTUNE_DB_SERVICE_NAME, NEPTUNE_ANALYTICS_SERVICE_NAME,
normalize_service_name)
normalize_service_name, normalize_protocol_name,
normalize_serializer_class_name)

DEFAULT_CONFIG_LOCATION = os.path.expanduser('~/graph_notebook_config.json')

Expand Down Expand Up @@ -57,7 +60,8 @@ class GremlinSection(object):
"""

def __init__(self, traversal_source: str = '', username: str = '', password: str = '',
message_serializer: str = '', connection_protocol: str = '', include_protocol: bool = False):
message_serializer: str = '', connection_protocol: str = '',
include_protocol: bool = False, neptune_service: str = ''):
"""
:param traversal_source: used to specify the traversal source for a Gremlin traversal, in the case that we are
connected to an endpoint that can access multiple graphs.
Expand All @@ -71,57 +75,78 @@ def __init__(self, traversal_source: str = '', username: str = '', password: str
if traversal_source == '':
traversal_source = DEFAULT_GREMLIN_TRAVERSAL_SOURCE

serializer_lower = message_serializer.lower()
# TODO: Update with untyped serializers once supported in GremlinPython
# Accept TinkerPop serializer class name
# https://github.com/apache/tinkerpop/blob/fd040c94a66516e473811fe29eaeaf4081cf104c/docs/src/reference/gremlin-applications.asciidoc#graphson
# https://github.com/apache/tinkerpop/blob/fd040c94a66516e473811fe29eaeaf4081cf104c/docs/src/reference/gremlin-applications.asciidoc#graphbinary
if serializer_lower == '':
message_serializer = DEFAULT_GREMLIN_SERIALIZER
elif 'graphson' in serializer_lower:
message_serializer = 'GraphSON'
if 'untyped' in serializer_lower:
message_serializer += 'Untyped'
if 'v1' in serializer_lower:
if 'untyped' in serializer_lower:
message_serializer += 'MessageSerializerV1'
else:
message_serializer += 'MessageSerializerGremlinV1'
elif 'v2' in serializer_lower:
message_serializer += 'MessageSerializerV2'
invalid_serializer_input = False
if message_serializer != '':
message_serializer, invalid_serializer_input = normalize_serializer_class_name(message_serializer)

if include_protocol:
# Neptune endpoint
invalid_protocol_input = False
if connection_protocol != '':
connection_protocol, invalid_protocol_input = normalize_protocol_name(connection_protocol)

if neptune_service == NEPTUNE_ANALYTICS_SERVICE_NAME:
if connection_protocol != DEFAULT_HTTP_PROTOCOL:
if invalid_protocol_input:
print(f"Invalid connection protocol specified, you must use {DEFAULT_HTTP_PROTOCOL}. ")
elif connection_protocol == DEFAULT_WS_PROTOCOL:
print(f"Enforcing HTTP protocol.")
connection_protocol = DEFAULT_HTTP_PROTOCOL
# temporary restriction until GraphSON-typed and GraphBinary results are supported
if message_serializer not in NEPTUNE_GREMLIN_SERIALIZERS_HTTP:
if message_serializer not in GREMLIN_SERIALIZERS_ALL:
if invalid_serializer_input:
print(f"Invalid serializer specified, defaulting to {DEFAULT_GREMLIN_HTTP_SERIALIZER}. "
f"Valid serializers: {NEPTUNE_GREMLIN_SERIALIZERS_HTTP}")
else:
print(f"{message_serializer} is not currently supported for HTTP connections, "
f"defaulting to {DEFAULT_GREMLIN_HTTP_SERIALIZER}. "
f"Please use one of: {NEPTUNE_GREMLIN_SERIALIZERS_HTTP}")
message_serializer = DEFAULT_GREMLIN_HTTP_SERIALIZER
else:
message_serializer += 'MessageSerializerV3'
elif 'graphbinary' in serializer_lower:
message_serializer = GRAPHBINARYV1
if connection_protocol not in [DEFAULT_WS_PROTOCOL, DEFAULT_HTTP_PROTOCOL]:
if invalid_protocol_input:
print(f"Invalid connection protocol specified, defaulting to {DEFAULT_WS_PROTOCOL}. "
f"Valid protocols: [websockets, http].")
connection_protocol = DEFAULT_WS_PROTOCOL

if connection_protocol == DEFAULT_HTTP_PROTOCOL:
# temporary restriction until GraphSON-typed and GraphBinary results are supported
if message_serializer not in NEPTUNE_GREMLIN_SERIALIZERS_HTTP:
if message_serializer not in GREMLIN_SERIALIZERS_ALL:
if invalid_serializer_input:
print(f"Invalid serializer specified, defaulting to {DEFAULT_GREMLIN_HTTP_SERIALIZER}. "
f"Valid serializers: {NEPTUNE_GREMLIN_SERIALIZERS_HTTP}")
else:
print(f"{message_serializer} is not currently supported for HTTP connections, "
f"defaulting to {DEFAULT_GREMLIN_HTTP_SERIALIZER}. "
f"Please use one of: {NEPTUNE_GREMLIN_SERIALIZERS_HTTP}")
message_serializer = DEFAULT_GREMLIN_HTTP_SERIALIZER
else:
if message_serializer not in GREMLIN_SERIALIZERS_WS:
if invalid_serializer_input:
print(f"Invalid serializer specified, defaulting to {DEFAULT_GREMLIN_WS_SERIALIZER}. "
f"Valid serializers: {GREMLIN_SERIALIZERS_WS}")
elif message_serializer != '':
print(f"{message_serializer} is not currently supported by Gremlin Python driver, "
f"defaulting to {DEFAULT_GREMLIN_WS_SERIALIZER}. "
f"Valid serializers: {GREMLIN_SERIALIZERS_WS}")
message_serializer = DEFAULT_GREMLIN_WS_SERIALIZER

self.connection_protocol = connection_protocol
else:
print(f'Invalid Gremlin serializer specified, defaulting to graphsonv3. '
f'Valid serializers: {GREMLIN_SERIALIZERS_HTTP}.')
message_serializer = DEFAULT_GREMLIN_SERIALIZER
# Non-Neptune database - check and set valid WebSockets serializer if invalid/empty
if message_serializer not in GREMLIN_SERIALIZERS_WS:
message_serializer = DEFAULT_GREMLIN_WS_SERIALIZER
if invalid_serializer_input:
print(f'Invalid Gremlin serializer specified, defaulting to {DEFAULT_GREMLIN_WS_SERIALIZER}. '
f'Valid serializers: {GREMLIN_SERIALIZERS_WS}.')

self.traversal_source = traversal_source
self.username = username
self.password = password
self.message_serializer = message_serializer

if include_protocol:
protocol_lower = connection_protocol.lower()
if message_serializer in GREMLIN_SERIALIZERS_HTTP:
connection_protocol = DEFAULT_HTTP_PROTOCOL
if protocol_lower != '' and protocol_lower not in HTTP_PROTOCOL_FORMATS:
print(f"Enforcing HTTP protocol usage for serializer: {message_serializer}.")
else:
if protocol_lower == '':
connection_protocol = DEFAULT_GREMLIN_PROTOCOL
elif protocol_lower in HTTP_PROTOCOL_FORMATS:
connection_protocol = DEFAULT_HTTP_PROTOCOL
elif protocol_lower in WS_PROTOCOL_FORMATS:
connection_protocol = DEFAULT_WS_PROTOCOL
else:
print(f"Invalid connection protocol specified, defaulting to {DEFAULT_GREMLIN_PROTOCOL}. "
f"Valid protocols: [websockets, http].")
connection_protocol = DEFAULT_GREMLIN_PROTOCOL
self.connection_protocol = connection_protocol

def to_dict(self):
return self.__dict__

Expand Down Expand Up @@ -178,8 +203,8 @@ def __init__(self, host: str, port: int,
self.auth_mode = auth_mode
self.load_from_s3_arn = load_from_s3_arn
self.aws_region = aws_region
default_protocol = DEFAULT_HTTP_PROTOCOL if self._proxy_host != '' else DEFAULT_GREMLIN_PROTOCOL
if gremlin_section is not None:
default_protocol = DEFAULT_HTTP_PROTOCOL if self._proxy_host != '' else ''
if hasattr(gremlin_section, "connection_protocol"):
if self._proxy_host != '' and gremlin_section.connection_protocol != DEFAULT_HTTP_PROTOCOL:
print("Enforcing HTTP connection protocol for proxy connections.")
Expand All @@ -189,9 +214,12 @@ def __init__(self, host: str, port: int,
else:
final_protocol = default_protocol
self.gremlin = GremlinSection(message_serializer=gremlin_section.message_serializer,
connection_protocol=final_protocol, include_protocol=True)
connection_protocol=final_protocol,
include_protocol=True,
neptune_service=self.neptune_service)
else:
self.gremlin = GremlinSection(connection_protocol=default_protocol, include_protocol=True)
self.gremlin = GremlinSection(include_protocol=True,
neptune_service=self.neptune_service)
self.neo4j = Neo4JSection()
else:
self.is_neptune_config = False
Expand Down Expand Up @@ -331,11 +359,14 @@ def generate_default_config():
auth_mode_arg = args.auth_mode if args.auth_mode != '' else AuthModeEnum.DEFAULT.value
protocol_arg = args.gremlin_connection_protocol
include_protocol = False
gremlin_service = ''
if is_allowed_neptune_host(args.host, args.neptune_hosts):
include_protocol = True
gremlin_service = args.neptune_service
if not protocol_arg:
protocol_arg = DEFAULT_HTTP_PROTOCOL \
if args.neptune_service == NEPTUNE_ANALYTICS_SERVICE_NAME else DEFAULT_WS_PROTOCOL

config = generate_config(args.host, int(args.port),
AuthModeEnum(auth_mode_arg),
args.ssl, args.ssl_verify,
Expand All @@ -344,7 +375,7 @@ def generate_default_config():
SparqlSection(args.sparql_path, ''),
GremlinSection(args.gremlin_traversal_source, args.gremlin_username,
args.gremlin_password, args.gremlin_serializer,
protocol_arg, include_protocol),
protocol_arg, include_protocol, gremlin_service),
Neo4JSection(args.neo4j_username, args.neo4j_password,
args.neo4j_auth, args.neo4j_database),
args.neptune_hosts)
Expand Down
18 changes: 13 additions & 5 deletions src/graph_notebook/configuration/get_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
SparqlSection, GremlinSection, Neo4JSection
from graph_notebook.neptune.client import NEPTUNE_CONFIG_HOST_IDENTIFIERS, is_allowed_neptune_host, false_str_variants, \
DEFAULT_NEO4J_USERNAME, DEFAULT_NEO4J_PASSWORD, DEFAULT_NEO4J_DATABASE, \
NEPTUNE_DB_SERVICE_NAME, DEFAULT_WS_PROTOCOL, DEFAULT_HTTP_PROTOCOL
NEPTUNE_DB_SERVICE_NAME, DEFAULT_WS_PROTOCOL, DEFAULT_HTTP_PROTOCOL, \
DEFAULT_GREMLIN_HTTP_SERIALIZER, DEFAULT_GREMLIN_WS_SERIALIZER, \
normalize_service_name

neptune_params = ['neptune_service', 'auth_mode', 'load_from_s3_arn', 'aws_region']
neptune_gremlin_params = ['connection_protocol']
Expand All @@ -30,18 +32,24 @@ def get_config_from_dict(data: dict, neptune_hosts: list = NEPTUNE_CONFIG_HOST_I
is_neptune_host = is_allowed_neptune_host(hostname=data["host"], host_allowlist=neptune_hosts)

if is_neptune_host:
neptune_service = data['neptune_service'] if 'neptune_service' in data else NEPTUNE_DB_SERVICE_NAME
if 'neptune_service' in data:
neptune_service = normalize_service_name(data['neptune_service'])
else:
neptune_service = NEPTUNE_DB_SERVICE_NAME
if 'gremlin' in data:
data['gremlin']['include_protocol'] = True
if 'connection_protocol' not in data['gremlin']:
data['gremlin']['connection_protocol'] = DEFAULT_WS_PROTOCOL \
if neptune_service == NEPTUNE_DB_SERVICE_NAME else DEFAULT_HTTP_PROTOCOL
gremlin_section = GremlinSection(**data['gremlin'])
gremlin_section = GremlinSection(**data['gremlin'],
include_protocol=True,
neptune_service=neptune_service)
if gremlin_section.to_dict()['traversal_source'] != 'g':
print('Ignoring custom traversal source, Amazon Neptune does not support this functionality.\n')
else:
protocol = DEFAULT_WS_PROTOCOL if neptune_service == NEPTUNE_DB_SERVICE_NAME else DEFAULT_HTTP_PROTOCOL
gremlin_section = GremlinSection(include_protocol=True, connection_protocol=protocol)
gremlin_section = GremlinSection(include_protocol=True,
connection_protocol=protocol,
neptune_service=neptune_service)
if neo4j_section.to_dict()['username'] != DEFAULT_NEO4J_USERNAME \
or neo4j_section.to_dict()['password'] != DEFAULT_NEO4J_PASSWORD:
print('Ignoring Neo4J custom authentication, Amazon Neptune does not support this functionality.\n')
Expand Down
15 changes: 11 additions & 4 deletions src/graph_notebook/magics/graph_magic.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
SPARQL_EXPLAIN_MODES, OPENCYPHER_EXPLAIN_MODES, GREMLIN_EXPLAIN_MODES, \
OPENCYPHER_PLAN_CACHE_MODES, OPENCYPHER_DEFAULT_TIMEOUT, OPENCYPHER_STATUS_STATE_MODES, \
normalize_service_name, NEPTUNE_DB_SERVICE_NAME, NEPTUNE_ANALYTICS_SERVICE_NAME, GRAPH_PG_INFO_METRICS, \
DEFAULT_GREMLIN_PROTOCOL, GREMLIN_PROTOCOL_FORMATS, DEFAULT_HTTP_PROTOCOL, DEFAULT_WS_PROTOCOL, \
GREMLIN_PROTOCOL_FORMATS, DEFAULT_HTTP_PROTOCOL, DEFAULT_WS_PROTOCOL, \
GREMLIN_SERIALIZERS_WS, GREMLIN_SERIALIZERS_CLASS_TO_MIME_MAP, normalize_protocol_name, generate_snapshot_name)
from graph_notebook.network import SPARQLNetwork
from graph_notebook.network.gremlin.GremlinNetwork import parse_pattern_list_str, GremlinNetwork
Expand Down Expand Up @@ -1250,11 +1250,18 @@ def gremlin(self, line, cell, local_ns: dict = None):
query_start = time.time() * 1000 # time.time() returns time in seconds w/high precision; x1000 to get in ms
if self.client.is_neptune_domain():
if args.connection_protocol != '':
connection_protocol = normalize_protocol_name(args.connection_protocol)
connection_protocol, bad_protocol_input = normalize_protocol_name(args.connection_protocol)
if bad_protocol_input:
if self.client.is_analytics_domain():
connection_protocol = DEFAULT_HTTP_PROTOCOL
else:
connection_protocol = DEFAULT_WS_PROTOCOL
print(f"Connection protocol input is invalid for Neptune, "
f"defaulting to {connection_protocol}.")
if connection_protocol == DEFAULT_WS_PROTOCOL and \
self.graph_notebook_config.gremlin.message_serializer not in GREMLIN_SERIALIZERS_WS:
print("Unsupported serializer for GremlinPython client, "
"compatible serializers are: {GREMLIN_SERIALIZERS_WS}")
print(f"Serializer is unsupported for GremlinPython client, "
f"compatible serializers are: {GREMLIN_SERIALIZERS_WS}")
print("Defaulting to HTTP protocol.")
connection_protocol = DEFAULT_HTTP_PROTOCOL
else:
Expand Down
Loading

0 comments on commit 9997b1c

Please sign in to comment.