Update magics for new Neptune Analytics API (#560)
* Support new Analytics API

* Fix %load, better %summary mode messaging

* update changelog
michaelnchin authored Feb 1, 2024
1 parent bb96dd8 commit 5705154
Showing 3 changed files with 142 additions and 45 deletions.
4 changes: 3 additions & 1 deletion ChangeLog.md
@@ -5,8 +5,10 @@ Starting with v1.31.6, this file will contain a record of major features and updates
## Upcoming
- New Neptune Analytics notebook - Vector Similarity Algorithms ([Link to PR](https://github.com/aws/graph-notebook/pull/555))
- Path: 02-Neptune-Analytics > 02-Graph-Algorithms > 06-Vector-Similarity-Algorithms
- Deprecated Python 3.7 support ([Link to PR](https://github.com/aws/graph-notebook/pull/551))
- Updated various Neptune magics for new Analytics API ([Link to PR](https://github.com/aws/graph-notebook/pull/560))
- Added `%graph_notebook_service` line magic ([Link to PR](https://github.com/aws/graph-notebook/pull/560))
- Added unit abbreviation support to `--max-content-length` ([Link to PR](https://github.com/aws/graph-notebook/pull/553))
- Deprecated Python 3.7 support ([Link to PR](https://github.com/aws/graph-notebook/pull/551))

## Release 4.0.2 (Dec 14, 2023)
- Fixed `neptune_ml_utils` imports in `03-Neptune-ML` samples ([Link to PR](https://github.com/aws/graph-notebook/pull/546))
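
For a concrete sense of the `--max-content-length` item above, a hedged sketch: the `50MB` value illustrates the abbreviation syntax assumed to be added in PR #553, the flag placement on `%%gremlin` is likewise an assumption, and the query is a placeholder.

```
%%gremlin --max-content-length 50MB
g.V().limit(10)
```
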
141 changes: 105 additions & 36 deletions src/graph_notebook/magics/graph_magic.py
@@ -42,13 +42,14 @@
from graph_notebook.decorators.decorators import display_exceptions, magic_variables, neptune_db_only
from graph_notebook.magics.ml import neptune_ml_magic_handler, generate_neptune_ml_parser
from graph_notebook.magics.streams import StreamViewer
from graph_notebook.neptune.client import ClientBuilder, Client,PARALLELISM_OPTIONS, PARALLELISM_HIGH, \
from graph_notebook.neptune.client import ClientBuilder, Client, PARALLELISM_OPTIONS, PARALLELISM_HIGH, \
LOAD_JOB_MODES, MODE_AUTO, FINAL_LOAD_STATUSES, SPARQL_ACTION, FORMAT_CSV, FORMAT_OPENCYPHER, FORMAT_NTRIPLE, \
DB_LOAD_TYPES, ANALYTICS_LOAD_TYPES, VALID_BULK_FORMATS, VALID_INCREMENTAL_FORMATS, \
FORMAT_NQUADS, FORMAT_RDFXML, FORMAT_TURTLE, STREAM_RDF, STREAM_PG, STREAM_ENDPOINTS, \
NEPTUNE_CONFIG_HOST_IDENTIFIERS, is_allowed_neptune_host, \
STATISTICS_LANGUAGE_INPUTS, STATISTICS_MODES, SUMMARY_MODES, \
SPARQL_EXPLAIN_MODES, OPENCYPHER_EXPLAIN_MODES, OPENCYPHER_PLAN_CACHE_MODES, OPENCYPHER_DEFAULT_TIMEOUT
STATISTICS_LANGUAGE_INPUTS, STATISTICS_LANGUAGE_INPUTS_SPARQL, STATISTICS_MODES, SUMMARY_MODES, \
SPARQL_EXPLAIN_MODES, OPENCYPHER_EXPLAIN_MODES, OPENCYPHER_PLAN_CACHE_MODES, OPENCYPHER_DEFAULT_TIMEOUT, \
OPENCYPHER_STATUS_STATE_MODES, normalize_service_name
from graph_notebook.network import SPARQLNetwork
from graph_notebook.network.gremlin.GremlinNetwork import parse_pattern_list_str, GremlinNetwork
from graph_notebook.visualization.rows_and_columns import sparql_get_rows_and_columns, opencypher_get_rows_and_columns
@@ -255,22 +256,31 @@ def get_load_ids(neptune_client):
return ids, res


def process_statistics_400(is_summary: bool, response):
def process_statistics_400(response, is_summary: bool = False, is_analytics: bool = False):
bad_request_res = json.loads(response.text)
res_code = bad_request_res['code']
if res_code == 'StatisticsNotAvailableException':
print("No statistics found. Please ensure that auto-generation of DFE statistics is enabled by running "
"'%statistics' and checking if 'autoCompute' if set to True. Alternately, you can manually "
"trigger statistics generation by running: '%statistics --mode refresh'.")
print("No statistics found. ", end="")
if not is_analytics:
print("Please ensure that auto-generation of DFE statistics is enabled by running '%statistics' and "
"checking if 'autoCompute' if set to True. Alternately, you can manually trigger statistics "
"generation by running: '%statistics --mode refresh'.")
return
elif res_code == "BadRequestException":
print("Unable to query the statistics endpoint. Please check that your Neptune instance is of size r5.large or "
"greater in order to have DFE statistics enabled.")
if is_summary and "Statistics is disabled" not in bad_request_res["detailedMessage"]:
print("\nPlease also note that the Graph Summary API is only available in Neptune engine version 1.2.1.0 "
"and later.")
else:
print("Query encountered 400 error, please see below.")
if is_analytics:
if bad_request_res["message"] == 'Bad route: /summary':
logger.debug("Encountered bad route exception for Analytics, retrying with legacy statistics endpoint.")
return 1
else:
print("Unable to query the statistics endpoint. Please check that your Neptune instance is of size "
"r5.large or greater in order to have DFE statistics enabled.")
if is_summary and "Statistics is disabled" not in bad_request_res["detailedMessage"]:
print("\nPlease also note that the Graph Summary API is only available in Neptune engine version "
"1.2.1.0 and later.")
return
print("Query encountered 400 error, please see below.")
print(f"\nFull response: {bad_request_res}")
return
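
A minimal caller sketch for the reworked `process_statistics_400` (an illustration, assuming `client` is an initialized `Client`): the `1` return value is a sentinel meaning an Analytics host rejected the `/summary` route, so the caller should retry against the legacy statistics endpoint, as the `%summary` magic below does.

```python
res = client.statistics('propertygraph', True, 'basic', use_analytics_endpoint=True)
if res.status_code == 400:
    retry_legacy = process_statistics_400(res, is_summary=True, is_analytics=True)
    if retry_legacy == 1:
        # Fall back to the legacy /<language>/statistics/summary route
        res = client.statistics('propertygraph', True, 'basic', False)
```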


def mcl_to_bytes(mcl):
@@ -445,6 +455,7 @@ def stream_viewer(self,line):
@line_magic
@needs_local_scope
@display_exceptions
@neptune_db_only
def statistics(self, line, local_ns: dict = None):
parser = argparse.ArgumentParser()
parser.add_argument('language', nargs='?', type=str.lower, default="propertygraph",
@@ -476,9 +487,9 @@ def statistics(self, line, local_ns: dict = None):
statistics_res = self.client.statistics(args.language, args.summary, mode)
if statistics_res.status_code == 400:
if args.summary:
process_statistics_400(True, statistics_res)
process_statistics_400(statistics_res)
else:
process_statistics_400(False, statistics_res)
process_statistics_400(statistics_res)
return
statistics_res.raise_for_status()
statistics_res_json = statistics_res.json()
@@ -508,10 +519,21 @@ def summary(self, line, local_ns: dict = None):
else:
mode = "basic"

summary_res = self.client.statistics(args.language, True, mode)
language_ep = args.language
if self.client.is_analytics_domain():
is_analytics = True
if language_ep in STATISTICS_LANGUAGE_INPUTS_SPARQL:
print("SPARQL is not supported for Neptune Analytics, defaulting to PropertyGraph.")
language_ep = 'propertygraph'
else:
is_analytics = False
summary_res = self.client.statistics(language_ep, True, mode, is_analytics)
if summary_res.status_code == 400:
process_statistics_400(True, summary_res)
return
retry_legacy = process_statistics_400(summary_res, is_summary=True, is_analytics=is_analytics)
if retry_legacy == 1:
summary_res = self.client.statistics(language_ep, True, mode, False)
else:
return
summary_res.raise_for_status()
summary_res_json = summary_res.json()
if not args.silent:
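
A usage note for the `%summary` changes above; the printed message is verbatim from this diff, while the Analytics host is assumed:

```
%summary sparql
# On a Neptune Analytics host this prints:
#   SPARQL is not supported for Neptune Analytics, defaulting to PropertyGraph.
# then requests the propertygraph summary; a 400 with 'Bad route: /summary'
# triggers a single retry against the legacy statistics endpoint.
```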
@@ -530,6 +552,16 @@ def graph_notebook_host(self, line):
self._generate_client_from_config(self.graph_notebook_config)
print(f'set host to {self.graph_notebook_config.host}')

@line_magic
def graph_notebook_service(self, line):
if line == '':
print(f'current service name: {self.graph_notebook_config.neptune_service}')
return

self.graph_notebook_config.neptune_service = normalize_service_name(line)
self._generate_client_from_config(self.graph_notebook_config)
print(f'set service name to {self.graph_notebook_config.neptune_service}')
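
The new line magic in action, as a sketch; it assumes `neptune-graph` is a service name accepted by `normalize_service_name`, with `neptune-db` as the Neptune Database counterpart:

```
%graph_notebook_service                # no argument: print the current service name
%graph_notebook_service neptune-graph # switch services and rebuild the client
```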

@magic_variables
@cell_magic
@needs_local_scope
@@ -1177,6 +1209,7 @@ def opencypher_status(self, line='', local_ns: dict = None):
@line_magic
@needs_local_scope
@display_exceptions
@neptune_db_only
def status(self, line='', local_ns: dict = None):
logger.info(f'calling for status on endpoint {self.graph_notebook_config.host}')
parser = argparse.ArgumentParser()
@@ -1547,6 +1580,7 @@ def load(self, line='', local_ns: dict = None):
value=str(args.concurrency),
placeholder=1,
min=1,
max=2**16,
disabled=False,
layout=widgets.Layout(display=concurrency_hbox_visibility,
width=widget_width)
@@ -1556,6 +1590,7 @@
value=args.periodic_commit,
placeholder=0,
min=0,
max=1000000,
disabled=False,
layout=widgets.Layout(display=periodic_commit_hbox_visibility,
width=widget_width)
@@ -1770,13 +1805,12 @@ def on_button_clicked(b):
source_format_validation_label = widgets.HTML('<p style="color:red;">Format cannot be blank.</p>')
source_format_hbox.children += (source_format_validation_label,)

if not arn.value.startswith('arn:aws') and source.value.startswith(
"s3://"): # only do this validation if we are using an s3 bucket.
validated = False
arn_validation_label = widgets.HTML('<p style="color:red;">Load ARN must start with "arn:aws"</p>')
arn_hbox.children += (arn_validation_label,)

if load_type == 'bulk':
if not arn.value.startswith('arn:aws') and source.value.startswith(
"s3://"): # only do this validation if we are using an s3 bucket.
validated = False
arn_validation_label = widgets.HTML('<p style="color:red;">Load ARN must start with "arn:aws"</p>')
arn_hbox.children += (arn_validation_label,)
dependencies_list = list(filter(None, dependencies.value.split('\n')))
if not len(dependencies_list) < 64:
validated = False
@@ -3105,9 +3139,15 @@ def handle_opencypher_status(self, line, local_ns):
parser.add_argument('-c', '--cancelQuery', action='store_true', default=False,
help='Tells the status command to cancel a query. This parameter does not take a value.')
parser.add_argument('-w', '--includeWaiting', action='store_true', default=False,
help='When set to true and other parameters are not present, causes status information '
'for waiting queries to be returned as well as for running queries. '
'This parameter does not take a value.')
help='Neptune DB only. When set to true and other parameters are not present, causes '
'status information for waiting queries to be returned as well as for running '
'queries. This parameter does not take a value.')
parser.add_argument('--state', type=str.upper, default='ALL',
help=f'Neptune Analytics only. Specifies what subset of query states to retrieve the '
f'status of. Default is ALL. Accepted values: {OPENCYPHER_STATUS_STATE_MODES}')
parser.add_argument('-m', '--maxResults', type=int, default=200,
help='Neptune Analytics only. Sets an upper limit on the set of returned queries whose '
'status matches --state. Default is 200.')
parser.add_argument('-s', '--silent-cancel', action='store_true', default=False,
help='If silent_cancel=true then the running query is cancelled and the HTTP response '
'code is 200. If silent_cancel is not present or silent_cancel=false, '
@@ -3116,21 +3156,50 @@ def handle_opencypher_status(self, line, local_ns):
parser.add_argument('--store-to', type=str, default='', help='store query result to this variable')
args = parser.parse_args(line.split())

using_analytics = self.client.is_analytics_domain()
if not args.cancelQuery:
if args.includeWaiting and not args.queryId:
res = self.client.opencypher_status(include_waiting=args.includeWaiting)
query_id = ''
include_waiting = None
state = ''
max_results = None
if args.includeWaiting and not args.queryId and not self.client.is_analytics_domain():
include_waiting = args.includeWaiting
elif args.state and not args.queryId and self.client.is_analytics_domain():
state = args.state
max_results = args.maxResults
else:
res = self.client.opencypher_status(query_id=args.queryId)
query_id = args.queryId
res = self.client.opencypher_status(query_id=query_id,
include_waiting=include_waiting,
state=state,
max_results=max_results,
use_analytics_endpoint=using_analytics)
if using_analytics and res.status_code == 400 and 'Bad route: /queries' in res.json()["message"]:
res = self.client.opencypher_status(query_id=query_id,
include_waiting=include_waiting,
state=state,
max_results=max_results,
use_analytics_endpoint=False)
res.raise_for_status()
else:
if args.queryId == '':
if not args.silent:
print(OPENCYPHER_CANCEL_HINT_MSG)
return
else:
res = self.client.opencypher_cancel(args.queryId, args.silent_cancel)
res = self.client.opencypher_cancel(args.queryId,
silent=args.silent_cancel,
use_analytics_endpoint=using_analytics)
if using_analytics and res.status_code == 400 and 'Bad route: /queries' in res.json()["message"]:
res = self.client.opencypher_cancel(args.queryId,
silent=args.silent_cancel,
use_analytics_endpoint=False)
res.raise_for_status()
js = res.json()
store_to_ns(args.store_to, js, local_ns)
if not args.silent:
print(json.dumps(js, indent=2))
if using_analytics and args.cancelQuery:
if not args.silent:
print(f'Submitted cancellation request for query ID: {args.queryId}')
else:
js = res.json()
store_to_ns(args.store_to, js, local_ns)
if not args.silent:
print(json.dumps(js, indent=2))
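
Hedged usage sketches for the reworked status flow, using only the flags visible in this hunk:

```
%opencypher_status --includeWaiting       # Neptune DB only: also report waiting queries
%opencypher_status --state RUNNING -m 50  # Neptune Analytics only: up to 50 RUNNING queries
```

With `--cancelQuery` and a query ID, the same magic issues a cancellation instead; on Analytics this is a `DELETE` against `/queries/<id>`, with one fallback to the legacy endpoint if the `/queries` route is rejected.
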
42 changes: 34 additions & 8 deletions src/graph_notebook/neptune/client.py
@@ -122,12 +122,15 @@

STATISTICS_MODES = ["", "status", "disableAutoCompute", "enableAutoCompute", "refresh", "delete"]
SUMMARY_MODES = ["", "basic", "detailed"]
STATISTICS_LANGUAGE_INPUTS = ["propertygraph", "pg", "gremlin", "oc", "opencypher", "sparql", "rdf"]
STATISTICS_LANGUAGE_INPUTS_PG = ["propertygraph", "pg", "gremlin", "oc", "opencypher"]
STATISTICS_LANGUAGE_INPUTS_SPARQL = ["sparql", "rdf"]
STATISTICS_LANGUAGE_INPUTS = STATISTICS_LANGUAGE_INPUTS_PG + STATISTICS_LANGUAGE_INPUTS_SPARQL

SPARQL_EXPLAIN_MODES = ['dynamic', 'static', 'details']
OPENCYPHER_EXPLAIN_MODES = ['dynamic', 'static', 'details']
OPENCYPHER_PLAN_CACHE_MODES = ['auto', 'enabled', 'disabled']
OPENCYPHER_DEFAULT_TIMEOUT = 120000
OPENCYPHER_STATUS_STATE_MODES = ['ALL', 'RUNNING', 'WAITING', 'CANCELLING']


def is_allowed_neptune_host(hostname: str, host_allowlist: list):
@@ -405,7 +408,7 @@ def opencypher_http(self, query: str, headers: dict = None, explain: str = None,
if plan_cache:
data['planCache'] = plan_cache
if query_timeout:
headers['query_timeout_millis'] = str(query_timeout)
data['queryTimeoutMilliseconds'] = str(query_timeout)
else:
url += 'db/neo4j/tx/commit'
headers['content-type'] = 'application/json'
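
The timeout fix above moves the value out of the request headers and into the request body; a sketch of the resulting payload for an Analytics openCypher call, with placeholder values:

```python
data = {
    'query': 'MATCH (n) RETURN count(n)',  # placeholder query
    'planCache': 'auto',                   # only present when plan_cache is set
    'queryTimeoutMilliseconds': '120000',  # previously the query_timeout_millis header
}
```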
@@ -441,16 +444,20 @@ def opencyper_bolt(self, query: str, **kwargs):
driver.close()
return data

def opencypher_status(self, query_id: str = '', include_waiting: bool = False):
def opencypher_status(self, query_id: str = '', include_waiting: bool = False, state: str = '',
max_results: int = None, use_analytics_endpoint: bool = False):
if use_analytics_endpoint:
return self._analytics_query_status(query_id=query_id, state=state, max_results=max_results)
kwargs = {}
if include_waiting:
kwargs['includeWaiting'] = True
return self._query_status('openCypher', query_id=query_id, **kwargs)

def opencypher_cancel(self, query_id, silent: bool = False):
def opencypher_cancel(self, query_id, silent: bool = False, use_analytics_endpoint: bool = False):
if type(query_id) is not str or query_id == '':
raise ValueError('query_id must be a non-empty string')

if use_analytics_endpoint:
return self._analytics_query_status(query_id=query_id, cancel_query=True)
return self._query_status('openCypher', query_id=query_id, cancelQuery=True, silent=silent)

def get_opencypher_driver(self):
@@ -808,7 +815,25 @@ def _query_status(self, language: str, *, query_id: str = '', **kwargs) -> requests.Response:
res = self._http_session.send(req, verify=self.ssl_verify)
return res

def statistics(self, language: str, summary: bool = False, mode: str = '') -> requests.Response:
def _analytics_query_status(self, query_id: str = '', state: str = '', max_results: int = None,
cancel_query: bool = False) -> requests.Response:
url = f'{self._http_protocol}://{self.host}:{self.port}/queries'
if query_id != '':
url += f'/{query_id}'
elif state != '':
url += f'?state={state}&maxResults={max_results}'

method = 'DELETE' if cancel_query else 'GET'

headers = {
'Content-Type': 'application/x-www-form-urlencoded'
}
req = self._prepare_request(method, url, headers=headers)
res = self._http_session.send(req, verify=self.ssl_verify)
return res
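
Spelled out, the routes `_analytics_query_status` issues (host, port, and query ID are placeholders):

```python
# GET    https://<host>:<port>/queries?state=RUNNING&maxResults=200  (list queries by state)
# GET    https://<host>:<port>/queries/<query-id>                    (status of one query)
# DELETE https://<host>:<port>/queries/<query-id>                    (cancel_query=True)
```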

def statistics(self, language: str, summary: bool = False, mode: str = '',
use_analytics_endpoint: bool = False) -> requests.Response:
headers = {
'Accept': 'application/json'
}
@@ -817,11 +842,12 @@ def statistics(self, language: str, summary: bool = False, mode: str = '') -> requests.Response:
elif language == "sparql":
language = "rdf"

url = f'{self._http_protocol}://{self.host}:{self.port}/{language}/statistics'
base_url = f'{self._http_protocol}://{self.host}:{self.port}'
url = base_url + f'/{language}/statistics'
data = {'mode': mode}

if summary:
summary_url = url + '/summary'
summary_url = (base_url if use_analytics_endpoint else url) + '/summary'
if mode:
summary_mode_param = '?mode=' + mode
summary_url += summary_mode_param
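
With this change the summary request resolves to a different route per backend; a sketch with placeholder hosts:

```python
# Neptune DB:        GET https://<db-host>:8182/propertygraph/statistics/summary?mode=detailed
# Neptune Analytics: GET https://<graph-host>:8182/summary?mode=detailed
```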
