Skip to content

Commit

Permalink
support http2 and custom headers (#322)
Browse files Browse the repository at this point in the history
* update

* fix pip install pytest error

* update

* update

* update

* update

* update

* update

* update
  • Loading branch information
HarrisChu authored Mar 11, 2024
1 parent dfe1534 commit bcc60e7
Show file tree
Hide file tree
Showing 24 changed files with 704 additions and 269 deletions.
8 changes: 6 additions & 2 deletions .github/workflows/run_test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,10 @@ jobs:
python-version: ${{ matrix.python-version }}
- name: Install nebulagraph-python from source and test dependencies
run: |
python setup.py install
python -m pip install --upgrade pip
# remove pyproject.toml to avoid pdm install
rm pyproject.toml
pip install .
pip install pip-tools pytest
- name: Test with pytest
run: |
Expand All @@ -39,7 +42,7 @@ jobs:
strategy:
max-parallel: 2
matrix:
python-version: [3.7, 3.8, 3.9, '3.10', 3.11, 3.12]
python-version: [3.7, 3.8, 3.9, '3.10', 3.11]

steps:
- name: Maximize runner space
Expand All @@ -59,6 +62,7 @@ jobs:

- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install .
pdm install -G:dev
pdm install -G:test
Expand Down
12 changes: 8 additions & 4 deletions nebula3/Config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,18 @@
class Config(object):
# the min connection always in pool
min_connection_pool_size = 0

# the max connection in pool
max_connection_pool_size = 10

# connection or execute timeout, unit ms, 0 means no timeout
timeout = 0

# 0 means will never close the idle connection, unit ms,
idle_time = 0

# the interval to check idle time connection, unit second, -1 means no check
interval_check = -1
# use http2 or not
use_http2 = False
# headers for http2, dict type
http_headers = None


class SSL_config(object):
Expand Down Expand Up @@ -89,3 +89,7 @@ class SessionPoolConfig(object):
max_size = 30
min_size = 1
interval_check = -1
# use http2 or not
use_http2 = False
# headers for http2, dict type
http_headers = None
85 changes: 85 additions & 0 deletions nebula3/fbthrift/transport/THttp2Client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# Copyright (c) 2024 vesoft inc. All rights reserved.
#
# This source code is licensed under Apache 2.0 License.
#
from nebula3.fbthrift.transport.TTransport import *
import httpx

default_timeout = 60

class THttp2Client(TTransportBase):
def __init__(self, url,
timeout=None,
verify=None,
certfile=None,
keyfile=None,
password=None,
http_headers=None,
):
self.__wbuf = StringIO()
self.__rbuf = StringIO()
self.__http = None
if timeout is not None and timeout > 0 :
self.timeout = timeout
if timeout is None:
self.timeout = default_timeout

self.url = url
if verify is None:
self.verify = False
else:
self.verify = verify
if certfile is not None :
self.cert = (certfile, keyfile, password)
else:
self.cert = None
self.response = None
self.http_headers = http_headers

def isOpen(self):
return self.__http is not None and self.__http.is_closed is False

def open(self):
if self.cert is None:
self.__http = httpx.Client(http1=False,http2=True, verify=False, timeout=self.timeout)
else:
self.__http = httpx.Client(http1=False,http2=True, verify=self.verify, cert=self.cert, timeout=self.timeout)

def close(self):
self.__http.close()
self.__http = None

def read(self, sz):
return self.__rbuf.read(sz)


def write(self, buf):
self.__wbuf.write(buf)

def flush(self):
if self.isOpen():
self.close()
self.open()

# Pull data out of buffer
data = self.__wbuf.getvalue()
self.__wbuf = StringIO()

# HTTP2 request
header = {
'Content-Type': 'application/x-thrift',
'Content-Length': str(len(data)),
'User-Agent': 'Python/THttpClient',
}
if self.http_headers is not None and isinstance(self.http_headers, dict):
header.update(self.http_headers)
try:
self.response= self.__http.post(self.url, headers=header, data=data)
except Exception as e:
raise TTransportException(TTransportException.UNKNOWN, str(e))
# Get reply to flush the request
self.code = self.response.status_code
self.headers = self.response.headers
self.__rbuf = StringIO()
self.__rbuf.write(self.response.read())
self.__rbuf.seek(0)
76 changes: 63 additions & 13 deletions nebula3/gclient/net/Connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,17 @@


import time
import ssl

from nebula3.fbthrift.transport import (
TSocket,
TSSLSocket,
TTransport,
THeaderTransport,
THttp2Client,
)
from nebula3.fbthrift.transport.TTransport import TTransportException
from nebula3.fbthrift.protocol import THeaderProtocol
from nebula3.fbthrift.protocol import THeaderProtocol, TBinaryProtocol

from nebula3.common.ttypes import ErrorCode
from nebula3.graph import GraphService
Expand Down Expand Up @@ -42,30 +44,54 @@ def __init__(self):
self._port = None
self._timeout = 0
self._ssl_conf = None
self.use_http2 = False
self.http_headers = None

def open(self, ip, port, timeout):
def open(self, ip, port, timeout, use_http2=False, http_headers=None):
"""open the connection
:param ip: the server ip
:param port: the server port
:param timeout: the timeout for connect and execute
:param use_http2: use http2 or not
:param http_headers: http headers
:return: void
"""
self.open_SSL(ip, port, timeout, None)
self.open_SSL(ip, port, timeout, None, use_http2, http_headers)

def open_SSL(self, ip, port, timeout, ssl_config=None):
def open_SSL(
self, ip, port, timeout, ssl_config=None, use_http2=False, http_headers=None
):
"""open the SSL connection
:param ip: the server ip
:param port: the server port
:param timeout: the timeout for connect and execute
:ssl_config: configs for SSL
:param use_http2: use http2 or not
:param http_headers: http headers
:return: void
"""
self._ip = ip
self._port = port
self._timeout = timeout
self._ssl_conf = ssl_config
self.use_http2 = use_http2
self.http_headers = http_headers
try:
if use_http2 is False:
protocol = self.__get_protocol(timeout, ssl_config)
else:
protocol = self.__get_protocal_http2(timeout, ssl_config, http_headers)
self._connection = GraphService.Client(protocol)
resp = self._connection.verifyClientVersion(VerifyClientVersionReq())
if resp.error_code != ErrorCode.SUCCEEDED:
self._connection._iprot.trans.close()
raise ClientServerIncompatibleException(resp.error_msg)
except Exception as e:
raise

def __get_protocol(self, timeout, ssl_config):
try:
if ssl_config is not None:
s = TSSLSocket.TSSLSocket(
Expand All @@ -89,14 +115,29 @@ def open_SSL(self, ip, port, timeout, ssl_config=None):
header_transport = THeaderTransport.THeaderTransport(buffered_transport)
protocol = THeaderProtocol.THeaderProtocol(header_transport)
header_transport.open()

self._connection = GraphService.Client(protocol)
resp = self._connection.verifyClientVersion(VerifyClientVersionReq())
if resp.error_code != ErrorCode.SUCCEEDED:
self._connection._iprot.trans.close()
raise ClientServerIncompatibleException(resp.error_msg)
except Exception:
except Exception as e:
raise
return protocol

def __get_protocal_http2(self, timeout, ssl_config, http_headers):
verify, certfile, keyfile, password = None, None, None, None
if ssl_config is not None:
# verify could be a boolean or ssl.SSLContext in httpx.
verify = ssl.create_default_context(cafile=ssl_config.ca_certs)
certfile = ssl_config.certfile
keyfile = ssl_config.keyfile
url = "https://" + self._ip + ":" + str(self._port)
else:
url = "http://" + self._ip + ":" + str(self._port)
try:
transport = THttp2Client.THttp2Client(
url, timeout, verify, certfile, keyfile, password, http_headers
)
transport.open()
protocol = TBinaryProtocol.TBinaryProtocol(transport)
except Exception as e:
raise
return protocol

def _reopen(self):
"""reopen the connection
Expand All @@ -105,9 +146,18 @@ def _reopen(self):
"""
self.close()
if self._ssl_conf is not None:
self.open_SSL(self._ip, self._port, self._timeout, self._ssl_conf)
self.open_SSL(
self._ip,
self._port,
self._timeout,
self._ssl_conf,
self.use_http2,
self.http_headers,
)
else:
self.open(self._ip, self._port, self._timeout)
self.open(
self._ip, self._port, self._timeout, self.use_http2, self.http_headers
)

def authenticate(self, user_name, password):
"""authenticate to graphd
Expand Down
20 changes: 18 additions & 2 deletions nebula3/gclient/net/ConnectionPool.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from nebula3.gclient.net.Session import Session
from nebula3.gclient.net.Connection import Connection
from nebula3.Config import Config
from nebula3.logger import logger


Expand Down Expand Up @@ -51,6 +52,7 @@ def init(self, addresses, configs, ssl_conf=None):
if self._close:
logger.error('The pool has init or closed.')
raise RuntimeError('The pool has init or closed.')
assert isinstance(configs, Config)
self._configs = configs
self._ssl_configs = ssl_conf
for address in addresses:
Expand Down Expand Up @@ -82,7 +84,12 @@ def init(self, addresses, configs, ssl_conf=None):
for i in range(0, conns_per_address):
connection = Connection()
connection.open_SSL(
addr[0], addr[1], self._configs.timeout, self._ssl_configs
addr[0],
addr[1],
self._configs.timeout,
self._ssl_configs,
self._configs.use_http2,
self._configs.http_headers,
)
self._connections[addr].append(connection)
return True
Expand Down Expand Up @@ -181,6 +188,8 @@ def get_connection(self):
addr[1],
self._configs.timeout,
self._ssl_configs,
self._configs.use_http2,
self._configs.http_headers,
)
connection.is_used = True
self._connections[addr].append(connection)
Expand All @@ -206,7 +215,14 @@ def ping(self, address):
"""
try:
conn = Connection()
conn.open_SSL(address[0], address[1], 1000, self._ssl_configs)
conn.open_SSL(
address[0],
address[1],
1000,
self._ssl_configs,
self._configs.use_http2,
self._configs.http_headers,
)
conn.close()
return True
except Exception as ex:
Expand Down
25 changes: 22 additions & 3 deletions nebula3/gclient/net/SessionPool.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,22 @@ def ping(self, address):
try:
conn = Connection()
if self._ssl_configs is None:
conn.open(address[0], address[1], 1000)
conn.open(
address[0],
address[1],
1000,
self._configs.use_http2,
self._configs.http_headers,
)
else:
conn.open_SSL(address[0], address[1], 1000, self._ssl_configs)
conn.open_SSL(
address[0],
address[1],
1000,
self._ssl_configs,
self._configs.use_http2,
self._configs.http_headers,
)
conn.close()
return True
except Exception as ex:
Expand Down Expand Up @@ -381,7 +394,13 @@ def _new_session(self):
# connect to the valid service
connection = Connection()
try:
connection.open(addr[0], addr[1], self._configs.timeout)
connection.open(
addr[0],
addr[1],
self._configs.timeout,
self._configs.use_http2,
self._configs.http_headers,
)
auth_result = connection.authenticate(self._username, self._password)
session = Session(connection, auth_result, self, False)

Expand Down
Loading

0 comments on commit bcc60e7

Please sign in to comment.