Skip to content

Commit

Permalink
feat: primitive params (#351)
Browse files Browse the repository at this point in the history
feat: primitive params

---------

Co-authored-by: Wey Gu <weyl.gu@gmail.com>
  • Loading branch information
BeautyyuYanli and wey-gu authored Jul 15, 2024
1 parent 5d562fb commit f1b75ce
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 38 deletions.
7 changes: 3 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -137,12 +137,11 @@ params = {
"ids": ["player100", "player101"], # second query
}

result = client.execute_py_params(
"RETURN abs($p1)+3 AS col1, (toBoolean($p2) AND false) AS col2, toLower($p3)+1 AS col3",
resp = client.execute_py(
"RETURN abs($p1)+3 AS col1, (toBoolean($p2) and false) AS col2, toLower($p3)+1 AS col3",
params,
)

result = client.execute_py_params(
resp = client.execute_py(
"MATCH (v) WHERE id(v) in $ids RETURN id(v) AS vertex_id",
params,
)
Expand Down
7 changes: 4 additions & 3 deletions example/Params.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import time
from typing import Any, Dict, List

from nebula3.gclient.net import ConnectionPool
from nebula3.Config import Config
from nebula3.common import ttypes
from nebula3.data.ResultSet import ResultSet

# define a config
config = Config()
Expand Down Expand Up @@ -51,12 +53,11 @@
"p4": ["Bob", "Lily"],
}

resp = client.execute_py_params(
resp = client.execute_py(
"RETURN abs($p1)+3 AS col1, (toBoolean($p2) and false) AS col2, toLower($p3)+1 AS col3",
params_premitive,
)

resp = client.execute_py_params(
resp = client.execute_py(
"MATCH (v) WHERE id(v) in $p4 RETURN id(v) AS vertex_id",
params_premitive,
)
1 change: 1 addition & 0 deletions nebula3/gclient/net/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,4 @@
from nebula3.gclient.net.Session import Session
from nebula3.gclient.net.Connection import Connection
from nebula3.gclient.net.ConnectionPool import ConnectionPool
from nebula3.gclient.net.base import BaseExecutor, ExecuteError
33 changes: 29 additions & 4 deletions nebula3/gclient/net/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,21 @@
from nebula3.common.ttypes import ErrorCode, Value, NList, Date, Time, DateTime


class ExecuteError(Exception):
def __init__(self, stmt: str, param: Any, code: ErrorCode, msg: str):
self.stmt = stmt
self.param = param
self.code = code
self.msg = msg

def __str__(self):
return (
f"ExecuteError. err_code: {self.code}, err_msg: {self.msg}.\n"
+ f"Statement: \n{self.stmt}\n"
+ f"Parameter: \n{self.param}"
)


class BaseExecutor:
@abstractmethod
def execute_parameter(
Expand All @@ -24,11 +39,21 @@ def execute(self, stmt: str) -> ResultSet:
def execute_json(self, stmt: str) -> bytes:
return self.execute_json_with_parameter(stmt, None)

def execute_py_params(
self, stmt: str, params: Optional[Dict[str, Any]]
) -> ResultSet:
def execute_py(
self,
stmt: str,
params: Optional[Dict[str, Any]] = None,
):
"""**Recommended** Execute a statement with parameters in Python type instead of thrift type."""
return self.execute_parameter(stmt, _build_byte_param(params))
if params is None:
result = self.execute_parameter(stmt, None)
else:
result = self.execute_parameter(stmt, _build_byte_param(params))

if not result.is_succeeded():
raise ExecuteError(stmt, params, result.error_code(), result.error_msg())

return result


def _build_byte_param(params: dict) -> dict:
Expand Down
59 changes: 32 additions & 27 deletions tests/test_parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import time
import json

from nebula3.gclient.net import ConnectionPool
from nebula3.gclient.net import ConnectionPool, ExecuteError
from nebula3.Config import Config
from nebula3.common import *
from unittest import TestCase
Expand Down Expand Up @@ -94,18 +94,6 @@ def test_parameter(self):
assert False == resp.row_values(0)[1].as_bool()
assert "bob1" == resp.row_values(0)[2].as_string()

# same test with premitive params
resp = client.execute_py_params(
"RETURN abs($p1)+3 AS col1, (toBoolean($p2) and false) AS col2, toLower($p3)+1 AS col3",
self.params_premitive,
)
assert resp.is_succeeded(), resp.error_msg()
assert 1 == resp.row_size()
names = ["col1", "col2", "col3"]
assert names == resp.keys()
assert 6 == resp.row_values(0)[0].as_int()
assert False == resp.row_values(0)[1].as_bool()
assert "bob1" == resp.row_values(0)[2].as_string()
# test cypher parameter
resp = client.execute_parameter(
f"""MATCH (v:person)--() WHERE v.person.age>abs($p1)+3
Expand All @@ -126,21 +114,11 @@ def test_parameter(self):
self.params,
)
assert not resp.is_succeeded()
resp = client.execute_py_params(
'$p1=go from "Bob" over like yield like._dst;',
self.params_premitive,
)
assert not resp.is_succeeded()
resp = client.execute_parameter(
"go from $p3 over like yield like._dst;",
self.params,
)
assert not resp.is_succeeded()
resp = client.execute_py_params(
"go from $p3 over like yield like._dst;",
self.params_premitive,
)
assert not resp.is_succeeded()
resp = client.execute_parameter(
"fetch prop on person $p3 yield vertex as v",
self.params,
Expand All @@ -162,12 +140,39 @@ def test_parameter(self):
)
assert not resp.is_succeeded()

resp = client.execute_py_params(
# same test with premitive params
resp = client.execute_py(
"RETURN abs($p1)+3 AS col1, (toBoolean($p2) and false) AS col2, toLower($p3)+1 AS col3",
self.params_premitive,
).as_primitive()
assert 1 == len(resp)
assert ["col1", "col2", "col3"] == list(resp[0].keys())
assert resp[0]["col1"] == 6
assert resp[0]["col2"] == False
assert resp[0]["col3"] == "bob1"
try:
resp = client.execute_py(
'$p1=go from "Bob" over like yield like._dst;',
self.params_premitive,
)
except ExecuteError:
pass
else:
raise AssertionError("should raise exception")
try:
resp = client.execute_py(
"go from $p3 over like yield like._dst;",
self.params_premitive,
)
except ExecuteError:
pass
else:
raise AssertionError("should raise exception")
resp = client.execute_py(
"MATCH (v) WHERE id(v) in $p4 RETURN id(v) AS vertex_id",
self.params_premitive,
)
assert resp.is_succeeded(), resp.error_msg()
assert 2 == resp.row_size()
).as_primitive()
assert 2 == len(resp)

def tearDown(self) -> None:
client = self.pool.get_session("root", "nebula")
Expand Down

0 comments on commit f1b75ce

Please sign in to comment.