Skip to content

Commit

Permalink
feat: cast params
Browse files Browse the repository at this point in the history
  • Loading branch information
wey-gu committed May 27, 2024
1 parent e876666 commit 909e610
Showing 1 changed file with 77 additions and 3 deletions.
80 changes: 77 additions & 3 deletions nebula3/gclient/net/Connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@
#
# This source code is licensed under Apache 2.0 License.


import datetime
import time
import ssl

from typing import Any

from nebula3.fbthrift.transport import (
TSocket,
TSSLSocket,
Expand All @@ -18,7 +20,7 @@
from nebula3.fbthrift.transport.TTransport import TTransportException
from nebula3.fbthrift.protocol import THeaderProtocol, TBinaryProtocol

from nebula3.common.ttypes import ErrorCode
from nebula3.common.ttypes import ErrorCode, Value, NList, Date, Time, DateTime
from nebula3.graph import GraphService
from nebula3.graph.ttypes import VerifyClientVersionReq
from nebula3.logger import logger
Expand Down Expand Up @@ -190,6 +192,77 @@ def execute(self, session_id, stmt):
:return: ExecutionResponse
"""
return self.execute_parameter(session_id, stmt, None)

@staticmethod
def _cast_value(value: Any) -> Value:
"""
Cast the value to nebula Value type
ref: https://github.com/vesoft-inc/nebula/blob/master/src/common/datatypes/Value.cpp
:param value: the value to be casted
:return: the casted value
"""
if isinstance(value, Value):
return value
casted_value = Value()
if isinstance(value, bool):
casted_value.set_bVal(value)
elif isinstance(value, int):
casted_value.set_iVal(value)
elif isinstance(value, str):
casted_value.set_sVal(value)
elif isinstance(value, float):
casted_value.set_fVal(value)
elif isinstance(value, datetime.date):
date_value = Date(
year=value.year,
month=value.month,
day=value.day
)
casted_value.set_dVal(date_value)
elif isinstance(value, datetime.time):
time_value = Time(
hour=value.hour,
minute=value.minute,
sec=value.second,
microsec=value.microsecond
)
casted_value.set_tVal(time_value)
elif isinstance(value, datetime.datetime):
datetime_value = DateTime(
year=value.year,
month=value.month,
day=value.day,
hour=value.hour,
minute=value.minute,
sec=value.second,
microsec=value.microsecond
)
casted_value.set_dtVal(datetime_value)
# TODO: add support for GeoSpatial
else:
raise TypeError(f"Unsupported type: {type(value)}")
return casted_value

@staticmethod
def _build_byte_param(params: dict) -> dict:
byte_params = {}
for k, v in params.items():
if isinstance(v, Value):
byte_params[k] = v
elif type(v).startswith("nebula3.common.ttypes"):
byte_params[k] = v
elif isinstance(v, list):
byte_list = []
for item in v:
byte_list.append(Connection._cast_value(item))
nlist = NList(values=byte_list)
byte_params[k] = nlist
elif isinstance(v, dict):
# TODO: add support for NMap
raise TypeError("Unsupported type: dict")
else:
byte_params[k] = Connection._cast_value(v)
return byte_params

def execute_parameter(self, session_id, stmt, params):
"""execute interface with session_id and ngql
Expand All @@ -198,8 +271,9 @@ def execute_parameter(self, session_id, stmt, params):
:param params: parameter map
:return: ExecutionResponse
"""
byte_params = Connection._build_byte_param(params)
try:
resp = self._connection.executeWithParameter(session_id, stmt, params)
resp = self._connection.executeWithParameter(session_id, stmt, byte_params)
return resp
except Exception as te:
if isinstance(te, TTransportException):
Expand Down

0 comments on commit 909e610

Please sign in to comment.