Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support http2 and custom headers #322

Merged
merged 9 commits into from
Mar 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions .github/workflows/run_test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,10 @@ jobs:
python-version: ${{ matrix.python-version }}
- name: Install nebulagraph-python from source and test dependencies
run: |
python setup.py install
python -m pip install --upgrade pip
# remove pyproject.toml to avoid pdm install
rm pyproject.toml
pip install .
pip install pip-tools pytest
- name: Test with pytest
run: |
Expand All @@ -39,7 +42,7 @@ jobs:
strategy:
max-parallel: 2
matrix:
python-version: [3.7, 3.8, 3.9, '3.10', 3.11, 3.12]
python-version: [3.7, 3.8, 3.9, '3.10', 3.11]

steps:
- name: Maximize runner space
Expand All @@ -59,6 +62,7 @@ jobs:

- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install .
pdm install -G:dev
pdm install -G:test
Expand Down
12 changes: 8 additions & 4 deletions nebula3/Config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,18 @@
class Config(object):
# the min connection always in pool
min_connection_pool_size = 0

# the max connection in pool
max_connection_pool_size = 10

# connection or execute timeout, unit ms, 0 means no timeout
timeout = 0

# 0 means will never close the idle connection, unit ms,
idle_time = 0

# the interval to check idle time connection, unit second, -1 means no check
interval_check = -1
# use http2 or not
use_http2 = False
# headers for http2, dict type
http_headers = None


class SSL_config(object):
Expand Down Expand Up @@ -89,3 +89,7 @@ class SessionPoolConfig(object):
max_size = 30
min_size = 1
interval_check = -1
# use http2 or not
use_http2 = False
# headers for http2, dict type
http_headers = None
85 changes: 85 additions & 0 deletions nebula3/fbthrift/transport/THttp2Client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# Copyright (c) 2024 vesoft inc. All rights reserved.
#
# This source code is licensed under Apache 2.0 License.
#
from nebula3.fbthrift.transport.TTransport import *
import httpx

default_timeout = 60

class THttp2Client(TTransportBase):
def __init__(self, url,
timeout=None,
verify=None,
certfile=None,
keyfile=None,
password=None,
http_headers=None,
):
self.__wbuf = StringIO()
self.__rbuf = StringIO()
self.__http = None
if timeout is not None and timeout > 0 :
self.timeout = timeout
if timeout is None:
self.timeout = default_timeout

self.url = url
if verify is None:
self.verify = False
else:
self.verify = verify
if certfile is not None :
self.cert = (certfile, keyfile, password)
else:
self.cert = None
self.response = None
self.http_headers = http_headers

def isOpen(self):
return self.__http is not None and self.__http.is_closed is False

def open(self):
if self.cert is None:
self.__http = httpx.Client(http1=False,http2=True, verify=False, timeout=self.timeout)
else:
self.__http = httpx.Client(http1=False,http2=True, verify=self.verify, cert=self.cert, timeout=self.timeout)

def close(self):
self.__http.close()
self.__http = None

def read(self, sz):
return self.__rbuf.read(sz)


def write(self, buf):
self.__wbuf.write(buf)

def flush(self):
if self.isOpen():
self.close()
self.open()

# Pull data out of buffer
data = self.__wbuf.getvalue()
self.__wbuf = StringIO()

# HTTP2 request
header = {
'Content-Type': 'application/x-thrift',
'Content-Length': str(len(data)),
'User-Agent': 'Python/THttpClient',
}
if self.http_headers is not None and isinstance(self.http_headers, dict):
header.update(self.http_headers)
try:
self.response= self.__http.post(self.url, headers=header, data=data)
except Exception as e:
raise TTransportException(TTransportException.UNKNOWN, str(e))
# Get reply to flush the request
self.code = self.response.status_code
self.headers = self.response.headers
self.__rbuf = StringIO()
self.__rbuf.write(self.response.read())
self.__rbuf.seek(0)
76 changes: 63 additions & 13 deletions nebula3/gclient/net/Connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,17 @@


import time
import ssl

from nebula3.fbthrift.transport import (
TSocket,
TSSLSocket,
TTransport,
THeaderTransport,
THttp2Client,
)
from nebula3.fbthrift.transport.TTransport import TTransportException
from nebula3.fbthrift.protocol import THeaderProtocol
from nebula3.fbthrift.protocol import THeaderProtocol, TBinaryProtocol

from nebula3.common.ttypes import ErrorCode
from nebula3.graph import GraphService
Expand Down Expand Up @@ -42,30 +44,54 @@ def __init__(self):
self._port = None
self._timeout = 0
self._ssl_conf = None
self.use_http2 = False
self.http_headers = None

def open(self, ip, port, timeout):
def open(self, ip, port, timeout, use_http2=False, http_headers=None):
"""open the connection

:param ip: the server ip
:param port: the server port
:param timeout: the timeout for connect and execute
:param use_http2: use http2 or not
:param http_headers: http headers
:return: void
"""
self.open_SSL(ip, port, timeout, None)
self.open_SSL(ip, port, timeout, None, use_http2, http_headers)

def open_SSL(self, ip, port, timeout, ssl_config=None):
def open_SSL(
self, ip, port, timeout, ssl_config=None, use_http2=False, http_headers=None
):
"""open the SSL connection

:param ip: the server ip
:param port: the server port
:param timeout: the timeout for connect and execute
:ssl_config: configs for SSL
:param use_http2: use http2 or not
:param http_headers: http headers
:return: void
"""
self._ip = ip
self._port = port
self._timeout = timeout
self._ssl_conf = ssl_config
self.use_http2 = use_http2
self.http_headers = http_headers
try:
if use_http2 is False:
protocol = self.__get_protocol(timeout, ssl_config)
else:
protocol = self.__get_protocal_http2(timeout, ssl_config, http_headers)
self._connection = GraphService.Client(protocol)
resp = self._connection.verifyClientVersion(VerifyClientVersionReq())
if resp.error_code != ErrorCode.SUCCEEDED:
self._connection._iprot.trans.close()
raise ClientServerIncompatibleException(resp.error_msg)
except Exception as e:
raise

def __get_protocol(self, timeout, ssl_config):
try:
if ssl_config is not None:
s = TSSLSocket.TSSLSocket(
Expand All @@ -89,14 +115,29 @@ def open_SSL(self, ip, port, timeout, ssl_config=None):
header_transport = THeaderTransport.THeaderTransport(buffered_transport)
protocol = THeaderProtocol.THeaderProtocol(header_transport)
header_transport.open()

self._connection = GraphService.Client(protocol)
resp = self._connection.verifyClientVersion(VerifyClientVersionReq())
if resp.error_code != ErrorCode.SUCCEEDED:
self._connection._iprot.trans.close()
raise ClientServerIncompatibleException(resp.error_msg)
except Exception:
except Exception as e:
raise
return protocol

def __get_protocal_http2(self, timeout, ssl_config, http_headers):
verify, certfile, keyfile, password = None, None, None, None
if ssl_config is not None:
# verify could be a boolean or ssl.SSLContext in httpx.
verify = ssl.create_default_context(cafile=ssl_config.ca_certs)
certfile = ssl_config.certfile
keyfile = ssl_config.keyfile
url = "https://" + self._ip + ":" + str(self._port)
else:
url = "http://" + self._ip + ":" + str(self._port)
try:
transport = THttp2Client.THttp2Client(
url, timeout, verify, certfile, keyfile, password, http_headers
)
transport.open()
protocol = TBinaryProtocol.TBinaryProtocol(transport)
except Exception as e:
raise
return protocol

def _reopen(self):
"""reopen the connection
Expand All @@ -105,9 +146,18 @@ def _reopen(self):
"""
self.close()
if self._ssl_conf is not None:
self.open_SSL(self._ip, self._port, self._timeout, self._ssl_conf)
self.open_SSL(
self._ip,
self._port,
self._timeout,
self._ssl_conf,
self.use_http2,
self.http_headers,
)
else:
self.open(self._ip, self._port, self._timeout)
self.open(
self._ip, self._port, self._timeout, self.use_http2, self.http_headers
)

def authenticate(self, user_name, password):
"""authenticate to graphd
Expand Down
20 changes: 18 additions & 2 deletions nebula3/gclient/net/ConnectionPool.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from nebula3.gclient.net.Session import Session
from nebula3.gclient.net.Connection import Connection
from nebula3.Config import Config
from nebula3.logger import logger


Expand Down Expand Up @@ -51,6 +52,7 @@ def init(self, addresses, configs, ssl_conf=None):
if self._close:
logger.error('The pool has init or closed.')
raise RuntimeError('The pool has init or closed.')
assert isinstance(configs, Config)
self._configs = configs
self._ssl_configs = ssl_conf
for address in addresses:
Expand Down Expand Up @@ -82,7 +84,12 @@ def init(self, addresses, configs, ssl_conf=None):
for i in range(0, conns_per_address):
connection = Connection()
connection.open_SSL(
addr[0], addr[1], self._configs.timeout, self._ssl_configs
addr[0],
addr[1],
self._configs.timeout,
self._ssl_configs,
self._configs.use_http2,
self._configs.http_headers,
)
self._connections[addr].append(connection)
return True
Expand Down Expand Up @@ -181,6 +188,8 @@ def get_connection(self):
addr[1],
self._configs.timeout,
self._ssl_configs,
self._configs.use_http2,
self._configs.http_headers,
)
connection.is_used = True
self._connections[addr].append(connection)
Expand All @@ -206,7 +215,14 @@ def ping(self, address):
"""
try:
conn = Connection()
conn.open_SSL(address[0], address[1], 1000, self._ssl_configs)
conn.open_SSL(
address[0],
address[1],
1000,
self._ssl_configs,
self._configs.use_http2,
self._configs.http_headers,
)
conn.close()
return True
except Exception as ex:
Expand Down
25 changes: 22 additions & 3 deletions nebula3/gclient/net/SessionPool.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,22 @@ def ping(self, address):
try:
conn = Connection()
if self._ssl_configs is None:
conn.open(address[0], address[1], 1000)
conn.open(
address[0],
address[1],
1000,
self._configs.use_http2,
self._configs.http_headers,
)
else:
conn.open_SSL(address[0], address[1], 1000, self._ssl_configs)
conn.open_SSL(
address[0],
address[1],
1000,
self._ssl_configs,
self._configs.use_http2,
self._configs.http_headers,
)
conn.close()
return True
except Exception as ex:
Expand Down Expand Up @@ -381,7 +394,13 @@ def _new_session(self):
# connect to the valid service
connection = Connection()
try:
connection.open(addr[0], addr[1], self._configs.timeout)
connection.open(
addr[0],
addr[1],
self._configs.timeout,
self._configs.use_http2,
self._configs.http_headers,
)
auth_result = connection.authenticate(self._username, self._password)
session = Session(connection, auth_result, self, False)

Expand Down
Loading
Loading