Skip to content

Commit

Permalink
[v2.6.1 patch] Refactor connection pool and add more tests (vesoft-in…
Browse files Browse the repository at this point in the history
…c#251)

* Add more tests

* Tune tests and refactor

Lower concurrence for ci so the github action can pass
  • Loading branch information
Aiee committed Jan 11, 2023
1 parent 76103df commit 27d5ee5
Show file tree
Hide file tree
Showing 2 changed files with 134 additions and 41 deletions.
54 changes: 29 additions & 25 deletions nebula3/gclient/net/ConnectionPool.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def init(self, addresses, configs, ssl_conf=None):
self._addresses.append(ip_port)
self._addresses_status[ip_port] = self.S_BAD
self._connections[ip_port] = deque()

self._ssl_configs = ssl_conf
self.update_servers_status()

# detect the services
Expand All @@ -78,20 +78,13 @@ def init(self, addresses, configs, ssl_conf=None):

conns_per_address = int(self._configs.min_connection_pool_size / ok_num)

if self._ssl_configs is None:
for addr in self._addresses:
for i in range(0, conns_per_address):
connection = Connection()
connection.open(addr[0], addr[1], self._configs.timeout)
self._connections[addr].append(connection)
else:
for addr in self._addresses:
for i in range(0, conns_per_address):
connection = Connection()
connection.open_SSL(
addr[0], addr[1], self._configs.timeout, self._ssl_configs
)
self._connections[addr].append(connection)
for addr in self._addresses:
for i in range(0, conns_per_address):
connection = Connection()
connection.open_SSL(
addr[0], addr[1], self._configs.timeout, self._ssl_configs
)
self._connections[addr].append(connection)
return True

def get_session(self, user_name, password, retry_connect=True):
Expand Down Expand Up @@ -158,15 +151,29 @@ def get_connection(self):
self._pos = (self._pos + 1) % len(self._addresses)
addr = self._addresses[self._pos]
if self._addresses_status[addr] == self.S_OK:
invalid_connections = list()

# iterate all connections to find an available connection
for connection in self._connections[addr]:
if not connection.is_used:
# ping to check the connection is valid
if connection.ping():
connection.is_used = True
logger.info('Get connection to {}'.format(addr))
return connection
# remove unusable connection
self._connections[addr].remove(connection)
else:
invalid_connections.append(connection)

# remove invalid connections
for connection in invalid_connections:
self._connections[addr].remove(connection)

# check if the server is still alive
if not self.ping(addr):
self._addresses_status[addr] = self.S_BAD
continue

# create new connection if the number of connections is less than max_con_per_address
if len(self._connections[addr]) < max_con_per_address:
connection = Connection()
connection.open_SSL(
Expand All @@ -184,6 +191,8 @@ def get_connection(self):
if not connection.is_used:
self._connections[addr].remove(connection)
try_count = try_count + 1

logging.error('No available connection')
return None
except Exception as ex:
logger.error('Get connection failed: {}'.format(ex))
Expand All @@ -197,10 +206,7 @@ def ping(self, address):
"""
try:
conn = Connection()
if self._ssl_configs is None:
conn.open(address[0], address[1], 1000)
else:
conn.open_SSL(address[0], address[1], 1000, self._ssl_configs)
conn.open_SSL(address[0], address[1], 1000, self._ssl_configs)
conn.close()
return True
except Exception as ex:
Expand All @@ -218,9 +224,7 @@ def close(self):
for addr in self._connections.keys():
for connection in self._connections[addr]:
if connection.is_used:
logger.error(
'The connection using by someone, but now want to close it'
)
logger.warning('Closing a connection that is in use')
connection.close()
self._close = True

Expand Down Expand Up @@ -286,7 +290,7 @@ def _remove_idle_unusable_connection(self):
if not connection.is_used:
if not connection.ping():
logger.debug(
'Remove the not unusable connection to {}'.format(
'Remove the unusable connection to {}'.format(
connection.get_address()
)
)
Expand Down
121 changes: 105 additions & 16 deletions tests/test_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,14 +170,15 @@ def test_multi_thread():
# Test multi thread
addresses = [('127.0.0.1', 9669), ('127.0.0.1', 9670)]
configs = Config()
configs.max_connection_pool_size = 4
thread_num = 50
configs.max_connection_pool_size = thread_num
pool = ConnectionPool()
assert pool.init(addresses, configs)

global success_flag
success_flag = True

def main_test():
def pool_multi_thread_test():
session = None
global success_flag
try:
Expand All @@ -187,7 +188,7 @@ def main_test():
return
space_name = 'space_' + threading.current_thread().getName()

session.execute('DROP SPACE %s' % space_name)
session.execute('DROP SPACE IF EXISTS %s' % space_name)
resp = session.execute(
'CREATE SPACE IF NOT EXISTS %s(vid_type=FIXED_STRING(8))' % space_name
)
Expand All @@ -207,20 +208,108 @@ def main_test():
if session is not None:
session.release()

thread1 = threading.Thread(target=main_test, name='thread1')
thread2 = threading.Thread(target=main_test, name='thread2')
thread3 = threading.Thread(target=main_test, name='thread3')
thread4 = threading.Thread(target=main_test, name='thread4')
threads = []
for num in range(0, thread_num):
thread = threading.Thread(
target=pool_multi_thread_test, name='test_pool_thread' + str(num)
)
thread.start()
threads.append(thread)

thread1.start()
thread2.start()
thread3.start()
thread4.start()

thread1.join()
thread2.join()
thread3.join()
thread4.join()
for t in threads:
t.join()
assert success_flag

pool.close()


def test_session_context_multi_thread():
# Test multi thread
addresses = [('127.0.0.1', 9669), ('127.0.0.1', 9670)]
configs = Config()
thread_num = 50
configs.max_connection_pool_size = thread_num
pool = ConnectionPool()
assert pool.init(addresses, configs)

global success_flag
success_flag = True

def pool_session_context_multi_thread_test():
session = None
global success_flag
try:
with pool.session_context('root', 'nebula') as session:
if session is None:
success_flag = False
return
space_name = 'space_' + threading.current_thread().getName()

session.execute('DROP SPACE IF EXISTS %s' % space_name)
resp = session.execute(
'CREATE SPACE IF NOT EXISTS %s(vid_type=FIXED_STRING(8))'
% space_name
)
if not resp.is_succeeded():
raise RuntimeError(
'CREATE SPACE failed: {}'.format(resp.error_msg())
)

time.sleep(3)
resp = session.execute('USE %s' % space_name)
if not resp.is_succeeded():
raise RuntimeError('USE SPACE failed:{}'.format(resp.error_msg()))

except Exception as x:
print(x)
success_flag = False
return

threads = []
for num in range(0, thread_num):
thread = threading.Thread(
target=pool_session_context_multi_thread_test,
name='test_session_context_thread' + str(num),
)
thread.start()
threads.append(thread)

for t in threads:
t.join()
assert success_flag

pool.close()


def test_remove_invalid_connection():
addresses = [('127.0.0.1', 9669), ('127.0.0.1', 9670), ('127.0.0.1', 9671)]
configs = Config()
configs.min_connection_pool_size = 30
configs.max_connection_pool_size = 45
pool = ConnectionPool()

try:
assert pool.init(addresses, configs)

# turn down one server('127.0.0.1', 9669) so the connection to it is invalid
os.system('docker stop tests_graphd0_1')
time.sleep(3)

# get connection from the pool, we should be able to still get 30 connections even though one server is down
for i in range(0, 30):
conn = pool.get_connection()
assert conn is not None

# total connection should still be 30
assert pool.connects() == 30

# the number of connections to the down server should be 0
assert len(pool._connections[addresses[0]]) == 0

# the number of connections to the 2nd('127.0.0.1', 9670) and 3rd server('127.0.0.1', 9671) should be 15
assert len(pool._connections[addresses[1]]) == 15
assert len(pool._connections[addresses[2]]) == 15

finally:
os.system('docker start tests_graphd0_1')
time.sleep(3)

0 comments on commit 27d5ee5

Please sign in to comment.