Skip to content

Commit

Permalink
[v2.6.1 patch] Refactor connection pool and add more tests (#251)
Browse files Browse the repository at this point in the history
* Add more tests

* Tune tests and refactor

Lower concurrence for ci so the github action can pass
  • Loading branch information
Aiee authored Dec 13, 2022
1 parent d4d305d commit ef9e8d5
Show file tree
Hide file tree
Showing 4 changed files with 186 additions and 72 deletions.
100 changes: 61 additions & 39 deletions nebula2/gclient/net/ConnectionPool.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,7 @@
from collections import deque
from threading import RLock, Timer

from nebula2.Exception import (
NotValidConnectionException,
InValidHostname
)
from nebula2.Exception import NotValidConnectionException, InValidHostname

from nebula2.gclient.net.Session import Session
from nebula2.gclient.net.Connection import Connection
Expand Down Expand Up @@ -65,7 +62,7 @@ def init(self, addresses, configs, ssl_conf=None):
self._addresses.append(ip_port)
self._addresses_status[ip_port] = self.S_BAD
self._connections[ip_port] = deque()

self._ssl_configs = ssl_conf
self.update_servers_status()

# detect the services
Expand All @@ -74,25 +71,19 @@ def init(self, addresses, configs, ssl_conf=None):
# init min connections
ok_num = self.get_ok_servers_num()
if ok_num < len(self._addresses):
raise RuntimeError('The services status exception: {}'.format(
self._get_services_status()))

conns_per_address = int(
self._configs.min_connection_pool_size / ok_num)

if ssl_conf is None:
for addr in self._addresses:
for i in range(0, conns_per_address):
connection = Connection()
connection.open(addr[0], addr[1], self._configs.timeout)
self._connections[addr].append(connection)
else:
for addr in self._addresses:
for i in range(0, conns_per_address):
connection = Connection()
connection.open_SSL(
addr[0], addr[1], self._configs.timeout, self._ssl_configs)
self._connections[addr].append(connection)
raise RuntimeError(
'The services status exception: {}'.format(self._get_services_status())
)

conns_per_address = int(self._configs.min_connection_pool_size / ok_num)

for addr in self._addresses:
for i in range(0, conns_per_address):
connection = Connection()
connection.open_SSL(
addr[0], addr[1], self._configs.timeout, self._ssl_configs
)
self._connections[addr].append(connection)
return True

def get_session(self, user_name, password, retry_connect=True):
Expand Down Expand Up @@ -151,25 +142,45 @@ def get_connection(self):
if ok_num == 0:
logging.error('No available server')
return None
max_con_per_address = int(self._configs.max_connection_pool_size / ok_num)
max_con_per_address = int(
self._configs.max_connection_pool_size / ok_num
)
try_count = 0
while try_count <= len(self._addresses):
self._pos = (self._pos + 1) % len(self._addresses)
addr = self._addresses[self._pos]
if self._addresses_status[addr] == self.S_OK:
invalid_connections = list()

# iterate all connections to find an available connection
for connection in self._connections[addr]:
if not connection.is_used:
# ping to check the connection is valid
if connection.ping():
connection.is_used = True
logging.info('Get connection to {}'.format(addr))
return connection
# remove unusable connection
self._connections[addr].remove(connection)
else:
invalid_connections.append(connection)

# remove invalid connections
for connection in invalid_connections:
self._connections[addr].remove(connection)

# check if the server is still alive
if not self.ping(addr):
self._addresses_status[addr] = self.S_BAD
continue

# create new connection if the number of connections is less than max_con_per_address
if len(self._connections[addr]) < max_con_per_address:
connection = Connection()
connection.open_SSL(
addr[0], addr[1], self._configs.timeout, self._ssl_configs)
addr[0],
addr[1],
self._configs.timeout,
self._ssl_configs,
)
connection.is_used = True
self._connections[addr].append(connection)
logging.info('Get connection to {}'.format(addr))
Expand All @@ -179,6 +190,8 @@ def get_connection(self):
if not connection.is_used:
self._connections[addr].remove(connection)
try_count = try_count + 1

logging.error('No available connection')
return None
except Exception as ex:
logging.error('Get connection failed: {}'.format(ex))
Expand All @@ -192,14 +205,13 @@ def ping(self, address):
"""
try:
conn = Connection()
if self._ssl_configs is None:
conn.open(address[0], address[1], 1000)
else:
conn.open_SSL(address[0], address[1], 1000, self._ssl_configs)
conn.open_SSL(address[0], address[1], 1000, self._ssl_configs)
conn.close()
return True
except Exception as ex:
logging.warning('Connect {}:{} failed: {}'.format(address[0], address[1], ex))
logging.warning(
'Connect {}:{} failed: {}'.format(address[0], address[1], ex)
)
return False

def close(self):
Expand All @@ -211,7 +223,7 @@ def close(self):
for addr in self._connections.keys():
for connection in self._connections[addr]:
if connection.is_used:
logging.error('The connection using by someone, but now want to close it')
logging.warning('Closing a connection that is in use')
connection.close()
self._close = True

Expand Down Expand Up @@ -260,8 +272,7 @@ def _get_services_status(self):
return ', '.join(msg_list)

def update_servers_status(self):
"""update the servers' status
"""
"""update the servers' status"""
for address in self._addresses:
if self.ping(address):
self._addresses_status[address] = self.S_OK
Expand All @@ -277,11 +288,22 @@ def _remove_idle_unusable_connection(self):
for connection in list(conns):
if not connection.is_used:
if not connection.ping():
logging.debug('Remove the not unusable connection to {}'.format(connection.get_address()))
logging.debug(
'Remove the unusable connection to {}'.format(
connection.get_address()
)
)
conns.remove(connection)
continue
if self._configs.idle_time != 0 and connection.idle_time() > self._configs.idle_time:
logging.debug('Remove the idle connection to {}'.format(connection.get_address()))
if (
self._configs.idle_time != 0
and connection.idle_time() > self._configs.idle_time
):
logging.debug(
'Remove the idle connection to {}'.format(
connection.get_address()
)
)
conns.remove(connection)

def _period_detect(self):
Expand Down
18 changes: 9 additions & 9 deletions tests/docker-compose.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ services:
- --log_dir=/logs
- --v=0
- --minloglevel=0
- --heartbeat_interval_secs=2
- --heartbeat_interval_secs=1
# ssl
- --ca_path=${ca_path}
- --cert_path=${cert_path}
Expand Down Expand Up @@ -55,7 +55,7 @@ services:
- --log_dir=/logs
- --v=0
- --minloglevel=0
- --heartbeat_interval_secs=2
- --heartbeat_interval_secs=1
# ssl
- --ca_path=${ca_path}
- --cert_path=${cert_path}
Expand Down Expand Up @@ -96,7 +96,7 @@ services:
- --log_dir=/logs
- --v=0
- --minloglevel=0
- --heartbeat_interval_secs=2
- --heartbeat_interval_secs=1
# ssl
- --ca_path=${ca_path}
- --cert_path=${cert_path}
Expand Down Expand Up @@ -137,7 +137,7 @@ services:
- --log_dir=/logs
- --v=0
- --minloglevel=0
- --heartbeat_interval_secs=2
- --heartbeat_interval_secs=1
- --timezone_name=+08:00
# ssl
- --ca_path=${ca_path}
Expand Down Expand Up @@ -183,7 +183,7 @@ services:
- --log_dir=/logs
- --v=0
- --minloglevel=0
- --heartbeat_interval_secs=2
- --heartbeat_interval_secs=1
- --timezone_name=+08:00
# ssl
- --ca_path=${ca_path}
Expand Down Expand Up @@ -229,7 +229,7 @@ services:
- --log_dir=/logs
- --v=0
- --minloglevel=0
- --heartbeat_interval_secs=2
- --heartbeat_interval_secs=1
- --timezone_name=+08:00
# ssl
- --ca_path=${ca_path}
Expand Down Expand Up @@ -273,7 +273,7 @@ services:
- --log_dir=/logs
- --v=0
- --minloglevel=0
- --heartbeat_interval_secs=2
- --heartbeat_interval_secs=1
- --timezone_name=+08:00
# ssl
- --ca_path=${ca_path}
Expand Down Expand Up @@ -316,7 +316,7 @@ services:
- --log_dir=/logs
- --v=0
- --minloglevel=0
- --heartbeat_interval_secs=2
- --heartbeat_interval_secs=1
- --timezone_name=+08:00
# ssl
- --ca_path=${ca_path}
Expand Down Expand Up @@ -359,7 +359,7 @@ services:
- --log_dir=/logs
- --v=0
- --minloglevel=0
- --heartbeat_interval_secs=2
- --heartbeat_interval_secs=1
- --timezone_name=+08:00
# ssl
- --ca_path=${ca_path}
Expand Down
Loading

0 comments on commit ef9e8d5

Please sign in to comment.