-
Notifications
You must be signed in to change notification settings - Fork 0
/
connection.py
122 lines (102 loc) · 4.51 KB
/
connection.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import logging
import threading
from time import sleep
from typing import Iterable

import requests
class Connection:
    """Client for a cluster of manager servers with one primary.

    Writes (``post``) go to the primary server. Reads (``get``,
    ``post_readonly``) are spread round-robin over the secondary servers and
    fall back to the primary when no secondary is reachable. A background
    daemon thread periodically re-fetches ``<server>/managers`` so the client
    follows topology changes.

    Use it as a context manager; leaving the ``with`` block stops the refresh
    thread. Because the thread is a daemon, a Connection that is never closed
    does not keep the interpreter alive.
    """

    # Server-list refreshes per second (0.1 -> one refresh every 10 seconds).
    SERVER_LIST_FETCH_FREQUENCY = .1
    # Seconds to wait for a ``/managers`` probe before treating a server as down.
    MANAGERS_TIMEOUT = 5
    # How many times ``_send_to_primary`` re-resolves the primary after a
    # connection error before giving up.
    MAX_FAILOVER_ATTEMPTS = 3

    def __init__(self, uris: Iterable[str]):
        """Discover the cluster topology from *uris* and start the refresh thread.

        :param uris: base URIs of one or more known servers; a single trailing
            ``/`` is stripped from each.
        :raises Exception: ``'No primary server found'`` if none of *uris*
            returns a manager list that names a primary.
        """
        self._primary_server = None
        # Every known server, the primary included (``_secondary_servers``
        # filters the primary back out).
        self._servers = []
        self._servers_lock = threading.Lock()
        self._round_robin_index = 0
        # Set by ``__exit__``; also wakes the worker out of its wait at once.
        self._stop_event = threading.Event()
        # endswith() instead of uri[-1] so an empty string does not raise.
        self._update_servers(uri[:-1] if uri.endswith('/') else uri for uri in uris)
        self._worker_thread = threading.Thread(target=self._worker_routine, daemon=True)
        self._worker_thread.start()

    @property
    def _secondary_servers(self):
        """Return the known servers other than the primary, in manager-list order.

        Hold ``_servers_lock`` while calling this when a consistent view is needed.
        """
        return [server for server in self._servers if server != self._primary_server]

    def _fetch_topology(self, uri: str):
        """Ask *uri* for its manager list.

        The response is expected to be JSON shaped like
        ``{"managers": [[address, is_primary], ...]}``.

        :returns: ``(primary, servers)``, where *servers* lists every
            manager address. Returns ``None`` if *uri* is unreachable,
            answers with an error status or a malformed body, or names no
            primary.
        """
        try:
            res = requests.get(uri + '/managers', timeout=self.MANAGERS_TIMEOUT)
        except requests.exceptions.RequestException:
            return None
        if not res.ok:
            return None
        try:
            managers = res.json()['managers']
            servers = [server[0] for server in managers]
            primaries = [server[0] for server in managers if server[1]]
        except (ValueError, KeyError, TypeError, IndexError):
            return None
        if not primaries:
            return None
        # If several entries are flagged, the last one wins (original behaviour).
        return primaries[-1], servers

    def _update_servers(self, uris: Iterable[str]):
        """Replace the known topology with the one reported by the first usable URI.

        The URIs are tried in order. The first one whose manager list names a
        primary is committed atomically. If none succeeds, the current state
        is left untouched.

        :raises Exception: ``'No primary server found'`` if no URI yields a
            usable manager list.
        """
        for uri in list(uris):
            topology = self._fetch_topology(uri)
            if topology is None:
                continue
            primary, servers = topology
            with self._servers_lock:
                self._primary_server = primary
                self._servers = servers
            return
        raise Exception('No primary server found')

    def _send_to_primary(self, action, path: str, data: dict = None, params: dict = None):
        """Send a request to the primary, failing over to a new primary if it is down.

        On a connection error the unreachable primary is dropped, the
        remaining servers are asked for a fresh topology, and the request is
        retried at most ``MAX_FAILOVER_ATTEMPTS`` times.

        :param action: a ``requests``-style callable such as ``requests.post``.
        :raises Exception: ``'No primary server found'`` when no new primary
            can be resolved or the retry budget is used up.
        """
        attempts_left = self.MAX_FAILOVER_ATTEMPTS
        while True:
            primary = self._primary_server
            try:
                return action(primary + path, data=data, params=params)
            except requests.exceptions.ConnectionError as err:
                if attempts_left <= 0:
                    raise Exception('No primary server found') from err
                attempts_left -= 1
            with self._servers_lock:
                self._servers = [server for server in self._servers if server != primary]
                remaining = list(self._servers)
            # Raises if none of the remaining servers can name a primary, which
            # prevents retrying a dead primary forever.
            self._update_servers(remaining)

    def _send_to_secondary(self, action, path: str, data: dict = None, params: dict = None):
        """Send a request to a secondary server chosen round-robin.

        Secondaries that raise a connection error are removed from the known
        server list and the next one is tried. If no secondary is known or
        reachable, the request goes to the primary instead.

        :param action: a ``requests``-style callable such as ``requests.get``.
        :raises Exception: ``'No server found'`` if the primary fallback also
            fails to connect.
        """
        # Take a snapshot of the rotation under the lock, and never hold the
        # lock across network I/O.
        with self._servers_lock:
            candidates = self._secondary_servers
            if candidates:
                # The list may have shrunk since the index was last advanced.
                start = self._round_robin_index % len(candidates)
                candidates = candidates[start:] + candidates[:start]
        for server in candidates:
            try:
                res = action(server + path, data=data, params=params)
            except requests.exceptions.ConnectionError:
                with self._servers_lock:
                    # The refresh thread may already have replaced the list.
                    if server in self._servers:
                        self._servers.remove(server)
                continue
            with self._servers_lock:
                count = len(self._secondary_servers)
                self._round_robin_index = (self._round_robin_index + 1) % count if count else 0
            return res
        try:
            return action(self._primary_server + path, data=data, params=params)
        except requests.exceptions.ConnectionError as err:
            raise Exception('No server found') from err

    def get(self, path: str, data: dict = None, params: dict = None):
        """GET *path* from a secondary server (primary as fallback)."""
        return self._send_to_secondary(requests.get, path, data=data, params=params)

    def post_readonly(self, path: str, data: dict = None, params: dict = None):
        """POST *path* to a secondary server, like ``get``; meant for read-only calls."""
        return self._send_to_secondary(requests.post, path, data=data, params=params)

    def post(self, path: str, data: dict = None, params: dict = None):
        """POST *path* to the primary server, failing over if it is unreachable."""
        return self._send_to_primary(requests.post, path, data=data, params=params)

    def _worker_routine(self):
        """Refresh the server list every ``1 / SERVER_LIST_FETCH_FREQUENCY`` seconds.

        Runs until ``_stop_event`` is set. A failed refresh is logged and
        retried on the next tick instead of killing the thread.
        """
        interval = 1 / self.SERVER_LIST_FETCH_FREQUENCY
        # Event.wait returns True as soon as the event is set, so shutdown
        # does not have to wait out the remainder of the interval.
        while not self._stop_event.wait(interval):
            with self._servers_lock:
                known = list(self._servers)
            try:
                self._update_servers(known)
            except Exception:
                logging.getLogger(__name__).warning(
                    'Failed to refresh the server list; keeping the previous one',
                    exc_info=True,
                )

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Stop the refresh thread and wait for it to finish."""
        self._stop_event.set()
        print('Closing connection...',)
        if hasattr(self, '_worker_thread'):
            self._worker_thread.join()

    def __enter__(self):
        """Return this connection for use in a ``with`` block."""
        return self