Skip to content

Commit

Permalink
Add RackAwareRoundRobinPolicy for host selection
Browse files Browse the repository at this point in the history
  • Loading branch information
sylwiaszunejko committed Jul 9, 2024
1 parent c9b24b7 commit bc88863
Show file tree
Hide file tree
Showing 6 changed files with 313 additions and 48 deletions.
9 changes: 7 additions & 2 deletions cassandra/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,7 +496,8 @@ def _profiles_without_explicit_lbps(self):

def distance(self, host):
distances = set(p.load_balancing_policy.distance(host) for p in self.profiles.values())
return HostDistance.LOCAL if HostDistance.LOCAL in distances else \
return HostDistance.LOCAL_RACK if HostDistance.LOCAL_RACK in distances else \
HostDistance.LOCAL if HostDistance.LOCAL in distances else \
HostDistance.REMOTE if HostDistance.REMOTE in distances else \
HostDistance.IGNORED

Expand Down Expand Up @@ -613,7 +614,7 @@ class Cluster(object):
Defaults to loopback interface.
Note: When using :class:`.DCAwareLoadBalancingPolicy` with no explicit
Note: When using :class:`.DCAwareRoundRobinPolicy` with no explicit
local_dc set (as is the default), the DC is chosen from an arbitrary
host in contact_points. In this case, contact_points should contain
only nodes from a single, local DC.
Expand Down Expand Up @@ -1373,21 +1374,25 @@ def __init__(self,
self._user_types = defaultdict(dict)

self._min_requests_per_connection = {
HostDistance.LOCAL_RACK: DEFAULT_MIN_REQUESTS,
HostDistance.LOCAL: DEFAULT_MIN_REQUESTS,
HostDistance.REMOTE: DEFAULT_MIN_REQUESTS
}

self._max_requests_per_connection = {
HostDistance.LOCAL_RACK: DEFAULT_MAX_REQUESTS,
HostDistance.LOCAL: DEFAULT_MAX_REQUESTS,
HostDistance.REMOTE: DEFAULT_MAX_REQUESTS
}

self._core_connections_per_host = {
HostDistance.LOCAL_RACK: DEFAULT_MIN_CONNECTIONS_PER_LOCAL_HOST,
HostDistance.LOCAL: DEFAULT_MIN_CONNECTIONS_PER_LOCAL_HOST,
HostDistance.REMOTE: DEFAULT_MIN_CONNECTIONS_PER_REMOTE_HOST
}

self._max_connections_per_host = {
HostDistance.LOCAL_RACK: DEFAULT_MAX_CONNECTIONS_PER_LOCAL_HOST,
HostDistance.LOCAL: DEFAULT_MAX_CONNECTIONS_PER_LOCAL_HOST,
HostDistance.REMOTE: DEFAULT_MAX_CONNECTIONS_PER_REMOTE_HOST
}
Expand Down
2 changes: 1 addition & 1 deletion cassandra/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -3440,7 +3440,7 @@ def group_keys_by_replica(session, keyspace, table, keys):
all_replicas = cluster.metadata.get_replicas(keyspace, routing_key)
# First check if there are local replicas
valid_replicas = [host for host in all_replicas if
host.is_up and distance(host) == HostDistance.LOCAL]
host.is_up and (distance(host) == HostDistance.LOCAL or distance(host) == HostDistance.LOCAL_RACK)]
if not valid_replicas:
valid_replicas = [host for host in all_replicas if host.is_up]

Expand Down
135 changes: 129 additions & 6 deletions cassandra/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,18 @@ class HostDistance(object):
connections opened to it.
"""

LOCAL = 0
LOCAL_RACK = 0
"""
Nodes with ``LOCAL_RACK`` distance will be preferred for operations
under some load balancing policies (such as :class:`.RackAwareRoundRobinPolicy`)
and will have a greater number of connections opened against
them by default.
This distance is typically used for nodes within the same
datacenter and the same rack as the client.
"""

LOCAL = 1
"""
Nodes with ``LOCAL`` distance will be preferred for operations
under some load balancing policies (such as :class:`.DCAwareRoundRobinPolicy`)
Expand All @@ -57,12 +68,12 @@ class HostDistance(object):
datacenter as the client.
"""

REMOTE = 1
REMOTE = 2
"""
Nodes with ``REMOTE`` distance will be treated as a last resort
by some load balancing policies (such as :class:`.DCAwareRoundRobinPolicy`)
and will have a smaller number of connections opened against
them by default.
by some load balancing policies (such as :class:`.DCAwareRoundRobinPolicy`
and :class:`.RackAwareRoundRobinPolicy`)and will have a smaller number of
connections opened against them by default.
This distance is typically used for nodes outside of the
datacenter that the client is running in.
Expand Down Expand Up @@ -316,6 +327,118 @@ def on_add(self, host):
def on_remove(self, host):
self.on_down(host)

class RackAwareRoundRobinPolicy(LoadBalancingPolicy):
"""
Similar to :class:`.DCAwareRoundRobinPolicy`, but prefers hosts
in the local rack, before hosts in the local datacenter but a
different rack, before hosts in all other datercentres
"""

local_dc = None
local_rack = None
used_hosts_per_remote_dc = 0

def __init__(self, local_dc, local_rack, used_hosts_per_remote_dc=0):
"""
The `local_dc` and `local_rack` parameters should be the name of the
datacenter and rack (such as is reported by ``nodetool ring``) that
should be considered local.
`used_hosts_per_remote_dc` controls how many nodes in
each remote datacenter will have connections opened
against them. In other words, `used_hosts_per_remote_dc` hosts
will be considered :attr:`~.HostDistance.REMOTE` and the
rest will be considered :attr:`~.HostDistance.IGNORED`.
By default, all remote hosts are ignored.
"""
self.local_rack = local_rack
self.local_dc = local_dc
self.used_hosts_per_remote_dc = used_hosts_per_remote_dc
self._live_hosts = {}
self._dc_live_hosts = {}
self._position = 0
self._endpoints = []
LoadBalancingPolicy.__init__(self)

def _rack(self, host):
return host.rack or self.local_rack

def _dc(self, host):
return host.datacenter or self.local_dc

def populate(self, cluster, hosts):
for (dc, rack), dc_hosts in groupby(hosts, lambda host: (self._dc(host), self._rack(host))):
self._live_hosts[(dc, rack)] = list(dc_hosts)
for dc, dc_hosts in groupby(hosts, lambda host: self._dc(host)):
self._dc_live_hosts[dc] = list(dc_hosts)

# as in other policies choose random position for better distributing queries across hosts
self._position = randint(0, len(hosts) - 1) if hosts else 0

def distance(self, host):
rack = self._rack(host)
dc = self._dc(host)
if rack == self.local_rack and dc == self.local_dc:
return HostDistance.LOCAL_RACK

if dc == self.local_dc:
return HostDistance.LOCAL

if not self.used_hosts_per_remote_dc:
return HostDistance.IGNORED
else:
dc_hosts = self._dc_live_hosts.get(dc, ())
if not dc_hosts:
return HostDistance.IGNORED

if host in dc_hosts[:self.used_hosts_per_remote_dc]:

return HostDistance.REMOTE
else:
return HostDistance.IGNORED

def make_query_plan(self, working_keyspace=None, query=None):
pos = self._position
self._position += 1

local_rack_live = self._live_hosts.get((self.local_dc, self.local_rack), ())
pos = (pos % len(local_rack_live)) if local_rack_live else 0
# Slice the cyclic iterator to start from pos and include the next len(local_live) elements
# This ensures we get exactly one full cycle starting from pos
for host in islice(cycle(local_rack_live), pos, pos + len(local_rack_live)):
yield host

local_live = [host for host in self._dc_live_hosts.get(self.local_dc, ()) if host.rack != self.local_rack]
pos = (pos % len(local_live)) if local_live else 0
for host in islice(cycle(local_live), pos, pos + len(local_live)):
yield host

# the dict can change, so get candidate DCs iterating over keys of a copy
other_dcs = [dc for dc in self._dc_live_hosts.copy().keys() if dc != self.local_dc]
for dc in other_dcs:
remote_live = self._dc_live_hosts.get(dc, ())
for host in remote_live[:self.used_hosts_per_remote_dc]:
yield host

def on_up(self, host):
dc = self._dc(host)
rack = self._rack(host)
with self._hosts_lock:
self._live_hosts[(dc, rack)].append(host)
self._dc_live_hosts[dc].append(host)

def on_down(self, host):
dc = self._dc(host)
rack = self._rack(host)
with self._hosts_lock:
self._live_hosts[(dc, rack)].remove(host)
self._dc_live_hosts[dc].remove(host)

def on_add(self, host):
self.on_up(host)

def on_remove(self, host):
self.on_down(host)

class TokenAwarePolicy(LoadBalancingPolicy):
"""
Expand Down Expand Up @@ -396,7 +519,7 @@ def make_query_plan(self, working_keyspace=None, query=None):
shuffle(replicas)
for replica in replicas:
if replica.is_up and \
child.distance(replica) == HostDistance.LOCAL:
(child.distance(replica) == HostDistance.LOCAL or child.distance(replica) == HostDistance.LOCAL_RACK):
yield replica

for host in child.make_query_plan(keyspace, query):
Expand Down
3 changes: 3 additions & 0 deletions docs/api/cassandra/policies.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ Load Balancing
.. autoclass:: DCAwareRoundRobinPolicy
:members:

.. autoclass:: RackAwareRoundRobinPolicy
:members:

.. autoclass:: WhiteListRoundRobinPolicy
:members:

Expand Down
87 changes: 87 additions & 0 deletions tests/integration/standard/test_rack_aware_policy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import logging
import unittest

from cassandra.cluster import Cluster
from cassandra.policies import ConstantReconnectionPolicy, RackAwareRoundRobinPolicy

from tests.integration import PROTOCOL_VERSION, get_cluster, use_multidc

LOGGER = logging.getLogger(__name__)

def setup_module():
use_multidc({'DC1': {'RC1': 2, 'RC2': 2}, 'DC2': {'RC1': 2}})

class RackAwareRoundRobinPolicyTests(unittest.TestCase):
@classmethod
def setup_class(cls):
cls.cluster = Cluster(contact_points=[node.address() for node in get_cluster().nodelist()], protocol_version=PROTOCOL_VERSION,
load_balancing_policy=RackAwareRoundRobinPolicy("DC1", "RC1", used_hosts_per_remote_dc=0),
reconnection_policy=ConstantReconnectionPolicy(1))
cls.session = cls.cluster.connect()
cls.create_ks_and_cf(cls)
cls.create_data(cls.session)

@classmethod
def teardown_class(cls):
cls.cluster.shutdown()

def create_ks_and_cf(self):
self.session.execute(
"""
DROP KEYSPACE IF EXISTS test1
"""
)
self.session.execute(
"""
CREATE KEYSPACE test1
WITH replication = {
'class': 'NetworkTopologyStrategy',
'replication_factor': 1
}
""")

self.session.execute(
"""
CREATE TABLE test1.table1 (pk int, ck int, v int, PRIMARY KEY (pk, ck));
""")

@staticmethod
def create_data(session):
prepared = session.prepare(
"""
INSERT INTO test1.table1 (pk, ck, v) VALUES (?, ?, ?)
""")

for i in range(50):
bound = prepared.bind((i, i%5, i%2))
session.execute(bound)

def get_hosts_from_tracing(self, results):
traces = results.get_query_trace()
events = traces.events
host_set = set()
for event in events:
LOGGER.info("TRACE EVENT: %s %s %s", event.source, event.thread_name, event.description)
host_set.add(event.source)

trace_id = results.response_future.get_query_trace_ids()[0]
traces = self.session.execute("SELECT * FROM system_traces.events WHERE session_id = %s", (trace_id,))
events = [event for event in traces]
host_set = set()
for event in events:
LOGGER.info("TRACE EVENT: %s %s", event.source, event.activity)
host_set.add(event.source)

return host_set

def test_rack_aware(self):
prepared = self.session.prepare(
"""
SELECT pk, ck, v FROM test1.table1 WHERE pk = ?
""")
for i in range(20):
bound = prepared.bind([(i)])
results = self.session.execute(bound, trace=True)
self.assertEqual(results, [(i, i%5, i%2)])
queried_hosts = self.get_hosts_from_tracing(results)
self.assertTrue(queried_hosts.issubset(set(["127.0.0.1", "127.0.0.2", "127.0.0.3", "127.0.0.4"])))
Loading

0 comments on commit bc88863

Please sign in to comment.