Skip to content

Commit

Permalink
Add enable port option to %%gremlin and %%oc
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelnchin committed Oct 18, 2024
1 parent db7d7d3 commit d0e0552
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 13 deletions.
22 changes: 17 additions & 5 deletions src/graph_notebook/magics/graph_magic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1139,6 +1139,9 @@ def gremlin(self, line, cell, local_ns: dict = None):
help='Enable debug mode.')
parser.add_argument('--profile-misc-args', type=str, default='{}',
help='Additional profile options, passed in as a map.')
parser.add_argument('--use-port', action='store_true', default=False,
help='Includes the port in the URI for applicable Neptune HTTP requests where it is '
'excluded by default.')
parser.add_argument('-sp', '--stop-physics', action='store_true', default=False,
help="Disable visualization physics after the initial simulation stabilizes.")
parser.add_argument('-sd', '--simulation-duration', type=int, default=1500,
Expand Down Expand Up @@ -1208,7 +1211,8 @@ def gremlin(self, line, cell, local_ns: dict = None):
if self.client.is_analytics_domain() and query_params:
explain_args['parameters'] = query_params
res = self.client.gremlin_explain(cell,
args=explain_args)
args=explain_args,
use_port=args.use_port)
res.raise_for_status()
except Exception as e:
if self.client.is_analytics_domain():
Expand Down Expand Up @@ -1251,7 +1255,9 @@ def gremlin(self, line, cell, local_ns: dict = None):
print('--profile-misc-args received invalid input, please check that you are passing in a valid '
'string representation of a map, ex. "{\'profile.x\':\'true\'}"')
try:
res = self.client.gremlin_profile(query=cell, args=profile_args)
res = self.client.gremlin_profile(query=cell,
args=profile_args,
use_port=args.use_port)
res.raise_for_status()
except Exception as e:
if self.client.is_analytics_domain():
Expand Down Expand Up @@ -1302,7 +1308,8 @@ def gremlin(self, line, cell, local_ns: dict = None):
passed_params = query_params if self.client.is_analytics_domain() else None
query_res_http = self.client.gremlin_http_query(cell,
headers=headers,
query_params=passed_params)
query_params=passed_params,
use_port=args.use_port)
query_res_http.raise_for_status()
try:
query_res_http_json = query_res_http.json()
Expand Down Expand Up @@ -3550,6 +3557,9 @@ def handle_opencypher_query(self, line, cell, local_ns):
parser.add_argument('-qp', '--query-parameters', type=str, default='',
help='Parameter definitions to apply to the query. This option can accept a local variable '
'name, or a string representation of the map.')
parser.add_argument('--use-port', action='store_true', default=False,
help='Includes the port in the URI for applicable Neptune HTTP requests where it is '
'excluded by default.')
parser.add_argument('-g', '--group-by', type=str, default='~labels',
help='Property used to group nodes (e.g. code, ~id) default is ~labels')
parser.add_argument('-gd', '--group-by-depth', action='store_true', default=False,
Expand Down Expand Up @@ -3638,7 +3648,8 @@ def handle_opencypher_query(self, line, cell, local_ns):
explain=args.explain_type,
query_params=query_params,
plan_cache=args.plan_cache,
query_timeout=args.query_timeout)
query_timeout=args.query_timeout,
use_port=args.use_port)
query_time = time.time() * 1000 - query_start
res_replace_chars = res.content.replace(b'$', b'\$')
explain = res_replace_chars.decode("utf-8")
Expand All @@ -3660,7 +3671,8 @@ def handle_opencypher_query(self, line, cell, local_ns):
oc_http = self.client.opencypher_http(cell,
query_params=query_params,
plan_cache=args.plan_cache,
query_timeout=args.query_timeout)
query_timeout=args.query_timeout,
use_port=args.use_port)
query_time = time.time() * 1000 - query_start
if oc_http.status_code == 400 and not self.client.is_analytics_domain() and args.plan_cache != "auto":
try:
Expand Down
23 changes: 15 additions & 8 deletions src/graph_notebook/neptune/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,14 +466,15 @@ def gremlin_query(self, query, transport_args=None, bindings=None):
c.close()
raise e

def gremlin_http_query(self, query, headers=None, query_params: dict = None) -> requests.Response:
def gremlin_http_query(self, query, headers=None, query_params: dict = None,
use_port: bool = False) -> requests.Response:
if headers is None:
headers = {}

data = {}
use_proxy = True if self.proxy_host != '' else False
if self.is_analytics_domain():
uri = f'{self.get_uri(use_websocket=False, use_proxy=use_proxy, include_port=False)}/queries'
uri = f'{self.get_uri(use_websocket=False, use_proxy=use_proxy, include_port=use_port)}/queries'
data['query'] = query
data['language'] = 'gremlin'
headers['content-type'] = 'application/json'
Expand All @@ -498,17 +499,20 @@ def gremlin_cancel(self, query_id: str):
raise ValueError('query_id must be a non-empty string')
return self._query_status('gremlin', query_id=query_id, cancelQuery=True)

def gremlin_explain(self, query: str, args={}) -> requests.Response:
return self._gremlin_query_plan(query=query, plan_type='explain', args=args)
def gremlin_explain(self, query: str, use_port: bool = False, args={}) -> requests.Response:
return self._gremlin_query_plan(query=query, plan_type='explain', args=args, use_port=use_port)

def gremlin_profile(self, query: str, args={}) -> requests.Response:
return self._gremlin_query_plan(query=query, plan_type='profile', args=args)
def gremlin_profile(self, query: str, use_port: bool = False, args={}) -> requests.Response:
return self._gremlin_query_plan(query=query, plan_type='profile', args=args, use_port=use_port)

def _gremlin_query_plan(self, query: str, plan_type: str, args: dict, ) -> requests.Response:
def _gremlin_query_plan(self, query: str, plan_type: str, args: dict,
use_port: bool = False) -> requests.Response:
data = {}
headers = {}
url = f'{self._http_protocol}://{self.host}'
if self.is_analytics_domain():
if use_port:
url += f':{self.port}'
url += '/queries'
data['query'] = query
data['language'] = 'gremlin'
Expand Down Expand Up @@ -537,7 +541,8 @@ def _gremlin_query_plan(self, query: str, plan_type: str, args: dict, ) -> reque
def opencypher_http(self, query: str, headers: dict = None, explain: str = None,
query_params: dict = None,
plan_cache: str = None,
query_timeout: int = None) -> requests.Response:
query_timeout: int = None,
use_port: bool = False) -> requests.Response:
if headers is None:
headers = {}

Expand All @@ -546,6 +551,8 @@ def opencypher_http(self, query: str, headers: dict = None, explain: str = None,
if self.is_neptune_domain():
data = {}
if self.is_analytics_domain():
if use_port:
url += f':{self.port}'
url += f'/queries'
data['language'] = 'opencypher'
else:
Expand Down

0 comments on commit d0e0552

Please sign in to comment.