Skip to content

Commit

Permalink
Add RackAwareRoundRobinPolicy for host selection
Browse files Browse the repository at this point in the history
  • Loading branch information
sylwiaszunejko committed Jun 28, 2024
1 parent c9b24b7 commit a45adb2
Show file tree
Hide file tree
Showing 5 changed files with 216 additions and 48 deletions.
9 changes: 7 additions & 2 deletions cassandra/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,7 +496,8 @@ def _profiles_without_explicit_lbps(self):

def distance(self, host):
    """
    Return the closest :class:`.HostDistance` that any profile's load
    balancing policy assigns to `host`.

    Every policy in ``self.profiles`` is asked for its distance to the
    host; the result is the first of ``LOCAL_RACK``, ``LOCAL`` and
    ``REMOTE`` that at least one policy reported, or ``IGNORED`` when
    none of them did.
    """
    distances = set(p.load_balancing_policy.distance(host) for p in self.profiles.values())
    # Checked from nearest to farthest, so the nearest reported distance wins.
    for candidate in (HostDistance.LOCAL_RACK, HostDistance.LOCAL, HostDistance.REMOTE):
        if candidate in distances:
            return candidate
    return HostDistance.IGNORED

Expand Down Expand Up @@ -613,7 +614,7 @@ class Cluster(object):
Defaults to loopback interface.
Note: When using :class:`.DCAwareLoadBalancingPolicy` with no explicit
Note: When using :class:`.DCAwareRoundRobinPolicy` with no explicit
local_dc set (as is the default), the DC is chosen from an arbitrary
host in contact_points. In this case, contact_points should contain
only nodes from a single, local DC.
Expand Down Expand Up @@ -1373,21 +1374,25 @@ def __init__(self,
self._user_types = defaultdict(dict)

self._min_requests_per_connection = {
HostDistance.LOCAL_RACK: DEFAULT_MIN_REQUESTS,
HostDistance.LOCAL: DEFAULT_MIN_REQUESTS,
HostDistance.REMOTE: DEFAULT_MIN_REQUESTS
}

self._max_requests_per_connection = {
HostDistance.LOCAL_RACK: DEFAULT_MAX_REQUESTS,
HostDistance.LOCAL: DEFAULT_MAX_REQUESTS,
HostDistance.REMOTE: DEFAULT_MAX_REQUESTS
}

self._core_connections_per_host = {
HostDistance.LOCAL_RACK: DEFAULT_MIN_CONNECTIONS_PER_LOCAL_HOST,
HostDistance.LOCAL: DEFAULT_MIN_CONNECTIONS_PER_LOCAL_HOST,
HostDistance.REMOTE: DEFAULT_MIN_CONNECTIONS_PER_REMOTE_HOST
}

self._max_connections_per_host = {
HostDistance.LOCAL_RACK: DEFAULT_MAX_CONNECTIONS_PER_LOCAL_HOST,
HostDistance.LOCAL: DEFAULT_MAX_CONNECTIONS_PER_LOCAL_HOST,
HostDistance.REMOTE: DEFAULT_MAX_CONNECTIONS_PER_REMOTE_HOST
}
Expand Down
2 changes: 1 addition & 1 deletion cassandra/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -3440,7 +3440,7 @@ def group_keys_by_replica(session, keyspace, table, keys):
all_replicas = cluster.metadata.get_replicas(keyspace, routing_key)
# First check if there are local replicas
valid_replicas = [host for host in all_replicas if
host.is_up and distance(host) == HostDistance.LOCAL]
host.is_up and (distance(host) == HostDistance.LOCAL or distance(host) == HostDistance.LOCAL_RACK)]
if not valid_replicas:
valid_replicas = [host for host in all_replicas if host.is_up]

Expand Down
135 changes: 129 additions & 6 deletions cassandra/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,18 @@ class HostDistance(object):
connections opened to it.
"""

LOCAL = 0
LOCAL_RACK = 0
"""
Nodes with ``LOCAL_RACK`` distance will be preferred for operations
under some load balancing policies (such as :class:`.RackAwareRoundRobinPolicy`)
and will have a greater number of connections opened against
them by default.
This distance is typically used for nodes within the same
datacenter and the same rack as the client.
"""

LOCAL = 1
"""
Nodes with ``LOCAL`` distance will be preferred for operations
under some load balancing policies (such as :class:`.DCAwareRoundRobinPolicy`)
Expand All @@ -57,12 +68,12 @@ class HostDistance(object):
datacenter as the client.
"""

REMOTE = 1
REMOTE = 2
"""
Nodes with ``REMOTE`` distance will be treated as a last resort
by some load balancing policies (such as :class:`.DCAwareRoundRobinPolicy`)
and will have a smaller number of connections opened against
them by default.
by some load balancing policies (such as :class:`.DCAwareRoundRobinPolicy`
and :class:`.RackAwareRoundRobinPolicy`) and will have a smaller number of
connections opened against them by default.
This distance is typically used for nodes outside of the
datacenter that the client is running in.
Expand Down Expand Up @@ -316,6 +327,118 @@ def on_add(self, host):
def on_remove(self, host):
self.on_down(host)

class RackAwareRoundRobinPolicy(LoadBalancingPolicy):
    """
    Similar to :class:`.DCAwareRoundRobinPolicy`, but prefers hosts
    in the local rack, before hosts in the local datacenter but a
    different rack, before hosts in all other datacenters.
    """

    local_dc = None
    local_rack = None
    used_hosts_per_remote_dc = 0

    def __init__(self, local_dc, local_rack, used_hosts_per_remote_dc=0):
        """
        The `local_dc` and `local_rack` parameters should be the name of the
        datacenter and rack (such as is reported by ``nodetool ring``) that
        should be considered local.

        `used_hosts_per_remote_dc` controls how many nodes in
        each remote datacenter will have connections opened
        against them. In other words, `used_hosts_per_remote_dc` hosts
        will be considered :attr:`~.HostDistance.REMOTE` and the
        rest will be considered :attr:`~.HostDistance.IGNORED`.
        By default, all remote hosts are ignored.
        """
        self.local_rack = local_rack
        self.local_dc = local_dc
        self.used_hosts_per_remote_dc = used_hosts_per_remote_dc
        # (dc, rack) -> tuple of live hosts in that rack
        self._live_hosts = {}
        # dc -> tuple of live hosts in that datacenter (all racks)
        self._dc_live_hosts = {}
        self._position = 0
        self._endpoints = []
        # NOTE(review): `_hosts_lock`, used by on_up/on_down, is not created
        # in this class; presumably the base initializer provides it.
        LoadBalancingPolicy.__init__(self)

    def _rack(self, host):
        """Return the host's rack, treating an unset rack as the local rack."""
        return host.rack or self.local_rack

    def _dc(self, host):
        """Return the host's datacenter, treating an unset one as the local DC."""
        return host.datacenter or self.local_dc

    @staticmethod
    def _rotate(hosts, offset):
        """Iterate over `hosts` exactly once, starting at ``offset % len(hosts)``."""
        if not hosts:
            return ()
        start = offset % len(hosts)
        return islice(cycle(hosts), start, start + len(hosts))

    def populate(self, cluster, hosts):
        """
        Record the initial set of live hosts, grouped per (dc, rack) and
        per dc, and pick a random starting position for the round robin.

        `cluster` is accepted for interface compatibility and is not used.
        """
        # Materialize once: `hosts` is walked and then measured with len().
        hosts = list(hosts)

        # Group by hand rather than with itertools.groupby: groupby only
        # merges *adjacent* items, and the groups it yields are one-shot
        # iterators that cannot be stored, sliced or measured later.
        by_rack = {}
        by_dc = {}
        for host in hosts:
            dc = self._dc(host)
            rack = self._rack(host)
            for groups, key in ((by_rack, (dc, rack)), (by_dc, dc)):
                members = groups.setdefault(key, [])
                if host not in members:
                    members.append(host)

        # Stored as tuples so a query plan that is being consumed never sees
        # its host sequence mutated underneath it; updates swap in new tuples.
        for key, members in by_rack.items():
            self._live_hosts[key] = tuple(members)
        for dc, members in by_dc.items():
            self._dc_live_hosts[dc] = tuple(members)

        # as in other policies choose random position for better distributing queries across hosts
        self._position = randint(0, len(hosts) - 1) if hosts else 0

    def distance(self, host):
        """
        Classify `host` as ``LOCAL_RACK`` (local DC and local rack),
        ``LOCAL`` (local DC, other rack), ``REMOTE`` (one of the first
        `used_hosts_per_remote_dc` live hosts of another DC) or ``IGNORED``.
        """
        dc = self._dc(host)
        if dc == self.local_dc:
            if self._rack(host) == self.local_rack:
                return HostDistance.LOCAL_RACK
            return HostDistance.LOCAL

        if not self.used_hosts_per_remote_dc:
            return HostDistance.IGNORED

        dc_hosts = self._dc_live_hosts.get(dc, ())
        if host in dc_hosts[:self.used_hosts_per_remote_dc]:
            return HostDistance.REMOTE
        return HostDistance.IGNORED

    def make_query_plan(self, working_keyspace=None, query=None):
        """
        Yield live hosts in preference order: the local rack, then the rest
        of the local datacenter, then up to `used_hosts_per_remote_dc` hosts
        of every other datacenter.

        The two local tiers are rotated by a counter that advances on every
        call. `working_keyspace` and `query` are not used by this policy.
        """
        # The counter is read and bumped without the lock; a race only
        # changes where the rotation starts.
        offset = self._position
        self._position += 1

        local_rack_live = self._live_hosts.get((self.local_dc, self.local_rack), ())
        for host in self._rotate(local_rack_live, offset):
            yield host

        # Use _rack() (not host.rack) so hosts without a rack, which are
        # counted as local-rack hosts above, are not yielded a second time.
        local_live = [host for host in self._dc_live_hosts.get(self.local_dc, ())
                      if self._rack(host) != self.local_rack]
        # Rotate by the original counter, not by the value already reduced
        # modulo the local-rack size, so this tier keeps rotating on its own.
        for host in self._rotate(local_live, offset):
            yield host

        # the dict can change, so get candidate DCs iterating over keys of a copy
        other_dcs = [dc for dc in self._dc_live_hosts.copy() if dc != self.local_dc]
        for dc in other_dcs:
            remote_live = self._dc_live_hosts.get(dc, ())
            for host in remote_live[:self.used_hosts_per_remote_dc]:
                yield host

    def on_up(self, host):
        """Add `host` to its rack and datacenter groups if not already there."""
        dc = self._dc(host)
        rack = self._rack(host)
        with self._hosts_lock:
            # .get() with an empty default: the rack or DC may be new.
            rack_hosts = self._live_hosts.get((dc, rack), ())
            if host not in rack_hosts:
                self._live_hosts[(dc, rack)] = rack_hosts + (host,)
            dc_hosts = self._dc_live_hosts.get(dc, ())
            if host not in dc_hosts:
                self._dc_live_hosts[dc] = dc_hosts + (host,)

    def on_down(self, host):
        """Remove `host` from its rack and datacenter groups; unknown hosts are a no-op."""
        dc = self._dc(host)
        rack = self._rack(host)
        with self._hosts_lock:
            for groups, key in ((self._live_hosts, (dc, rack)), (self._dc_live_hosts, dc)):
                current = groups.get(key, ())
                if host not in current:
                    continue
                remaining = tuple(h for h in current if h != host)
                if remaining:
                    groups[key] = remaining
                else:
                    # Drop empty groups so they are not offered as candidates.
                    del groups[key]

    def on_add(self, host):
        """Treat a newly added host like a host that came up."""
        self.on_up(host)

    def on_remove(self, host):
        """Treat a removed host like a host that went down."""
        self.on_down(host)

class TokenAwarePolicy(LoadBalancingPolicy):
"""
Expand Down Expand Up @@ -396,7 +519,7 @@ def make_query_plan(self, working_keyspace=None, query=None):
shuffle(replicas)
for replica in replicas:
if replica.is_up and \
child.distance(replica) == HostDistance.LOCAL:
(child.distance(replica) == HostDistance.LOCAL or child.distance(replica) == HostDistance.LOCAL_RACK):
yield replica

for host in child.make_query_plan(keyspace, query):
Expand Down
3 changes: 3 additions & 0 deletions docs/api/cassandra/policies.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ Load Balancing
.. autoclass:: DCAwareRoundRobinPolicy
:members:

.. autoclass:: RackAwareRoundRobinPolicy
:members:

.. autoclass:: WhiteListRoundRobinPolicy
:members:

Expand Down
Loading

0 comments on commit a45adb2

Please sign in to comment.