diff --git a/cassandra/cluster.py b/cassandra/cluster.py
index 5f2669c0b..b910c20e0 100644
--- a/cassandra/cluster.py
+++ b/cassandra/cluster.py
@@ -496,7 +496,8 @@ def _profiles_without_explicit_lbps(self):
     def distance(self, host):
         distances = set(p.load_balancing_policy.distance(host) for p in self.profiles.values())
-        return HostDistance.LOCAL if HostDistance.LOCAL in distances else \
+        return HostDistance.LOCAL_RACK if HostDistance.LOCAL_RACK in distances else \
+            HostDistance.LOCAL if HostDistance.LOCAL in distances else \
             HostDistance.REMOTE if HostDistance.REMOTE in distances else \
             HostDistance.IGNORED
 
@@ -613,7 +614,7 @@ class Cluster(object):
 
     Defaults to loopback interface.
 
-    Note: When using :class:`.DCAwareLoadBalancingPolicy` with no explicit
+    Note: When using :class:`.DCAwareRoundRobinPolicy` with no explicit
     local_dc set (as is the default), the DC is chosen from an arbitrary
     host in contact_points. In this case, contact_points should contain
     only nodes from a single, local DC.
 
@@ -1373,21 +1374,25 @@ def __init__(self,
         self._user_types = defaultdict(dict)
 
         self._min_requests_per_connection = {
+            HostDistance.LOCAL_RACK: DEFAULT_MIN_REQUESTS,
             HostDistance.LOCAL: DEFAULT_MIN_REQUESTS,
             HostDistance.REMOTE: DEFAULT_MIN_REQUESTS
         }
 
         self._max_requests_per_connection = {
+            HostDistance.LOCAL_RACK: DEFAULT_MAX_REQUESTS,
             HostDistance.LOCAL: DEFAULT_MAX_REQUESTS,
             HostDistance.REMOTE: DEFAULT_MAX_REQUESTS
         }
 
         self._core_connections_per_host = {
+            HostDistance.LOCAL_RACK: DEFAULT_MIN_CONNECTIONS_PER_LOCAL_HOST,
             HostDistance.LOCAL: DEFAULT_MIN_CONNECTIONS_PER_LOCAL_HOST,
             HostDistance.REMOTE: DEFAULT_MIN_CONNECTIONS_PER_REMOTE_HOST
         }
 
         self._max_connections_per_host = {
+            HostDistance.LOCAL_RACK: DEFAULT_MAX_CONNECTIONS_PER_LOCAL_HOST,
             HostDistance.LOCAL: DEFAULT_MAX_CONNECTIONS_PER_LOCAL_HOST,
             HostDistance.REMOTE: DEFAULT_MAX_CONNECTIONS_PER_REMOTE_HOST
         }
diff --git a/cassandra/metadata.py b/cassandra/metadata.py
index 9ef24b981..b5a622ae5 100644
--- a/cassandra/metadata.py
+++ b/cassandra/metadata.py
@@ -3440,7 +3440,7 @@ def group_keys_by_replica(session, keyspace, table, keys):
         all_replicas = cluster.metadata.get_replicas(keyspace, routing_key)
         # First check if there are local replicas
         valid_replicas = [host for host in all_replicas if
-                          host.is_up and distance(host) == HostDistance.LOCAL]
+                          host.is_up and distance(host) in (HostDistance.LOCAL_RACK, HostDistance.LOCAL)]
         if not valid_replicas:
             valid_replicas = [host for host in all_replicas if host.is_up]
 
diff --git a/cassandra/policies.py b/cassandra/policies.py
index 691287745..3387107c1 100644
--- a/cassandra/policies.py
+++ b/cassandra/policies.py
@@ -46,7 +46,18 @@ class HostDistance(object):
     connections opened to it.
     """
 
-    LOCAL = 0
+    LOCAL_RACK = 0
+    """
+    Nodes with ``LOCAL_RACK`` distance will be preferred for operations
+    under some load balancing policies (such as :class:`.RackAwareRoundRobinPolicy`)
+    and will have a greater number of connections opened against
+    them by default.
+
+    This distance is typically used for nodes within the same
+    datacenter and the same rack as the client.
+    """
+
+    LOCAL = 1
     """
     Nodes with ``LOCAL`` distance will be preferred for operations
     under some load balancing policies (such as :class:`.DCAwareRoundRobinPolicy`)
     and will have a greater number of connections opened against
     them by default.
 
     This distance is typically used for nodes within the same
     datacenter as the client.
     """
 
-    REMOTE = 1
+    REMOTE = 2
     """
     Nodes with ``REMOTE`` distance will be treated as a last resort
-    by some load balancing policies (such as :class:`.DCAwareRoundRobinPolicy`)
-    and will have a smaller number of connections opened against
-    them by default.
+    by some load balancing policies (such as :class:`.DCAwareRoundRobinPolicy`
+    and :class:`.RackAwareRoundRobinPolicy`) and will have a smaller number of
+    connections opened against them by default.
 
     This distance is typically used for nodes outside of the
     datacenter that the client is running in.
@@ -316,6 +327,118 @@ def on_add(self, host):
     def on_remove(self, host):
         self.on_down(host)
 
+
+class RackAwareRoundRobinPolicy(LoadBalancingPolicy):
+    """
+    Similar to :class:`.DCAwareRoundRobinPolicy`, but prefers hosts in the
+    local rack, then hosts in the local datacenter but a different rack,
+    then hosts in all other datacenters.
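+
+    For example (a sketch; ``dc1`` and ``rack1`` stand in for whatever
+    datacenter and rack the client is actually running in)::
+
+        from cassandra.cluster import Cluster, ExecutionProfile, EXEC_PROFILE_DEFAULT
+        from cassandra.policies import RackAwareRoundRobinPolicy
+
+        profile = ExecutionProfile(
+            load_balancing_policy=RackAwareRoundRobinPolicy("dc1", "rack1"))
+        cluster = Cluster(execution_profiles={EXEC_PROFILE_DEFAULT: profile})
+        session = cluster.connect()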
""" - REMOTE = 1 + REMOTE = 2 """ Nodes with ``REMOTE`` distance will be treated as a last resort - by some load balancing policies (such as :class:`.DCAwareRoundRobinPolicy`) - and will have a smaller number of connections opened against - them by default. + by some load balancing policies (such as :class:`.DCAwareRoundRobinPolicy` + and :class:`.RackAwareRoundRobinPolicy`)and will have a smaller number of + connections opened against them by default. This distance is typically used for nodes outside of the datacenter that the client is running in. @@ -316,6 +327,118 @@ def on_add(self, host): def on_remove(self, host): self.on_down(host) +class RackAwareRoundRobinPolicy(LoadBalancingPolicy): + """ + Similar to :class:`.DCAwareRoundRobinPolicy`, but prefers hosts + in the local rack, before hosts in the local datacenter but a + different rack, before hosts in all other datercentres + """ + + local_dc = None + local_rack = None + used_hosts_per_remote_dc = 0 + + def __init__(self, local_dc, local_rack, used_hosts_per_remote_dc=0): + """ + The `local_dc` and `local_rack` parameters should be the name of the + datacenter and rack (such as is reported by ``nodetool ring``) that + should be considered local. + + `used_hosts_per_remote_dc` controls how many nodes in + each remote datacenter will have connections opened + against them. In other words, `used_hosts_per_remote_dc` hosts + will be considered :attr:`~.HostDistance.REMOTE` and the + rest will be considered :attr:`~.HostDistance.IGNORED`. + By default, all remote hosts are ignored. + """ + self.local_rack = local_rack + self.local_dc = local_dc + self.used_hosts_per_remote_dc = used_hosts_per_remote_dc + self._live_hosts = {} + self._dc_live_hosts = {} + self._position = 0 + self._endpoints = [] + LoadBalancingPolicy.__init__(self) + + def _rack(self, host): + return host.rack or self.local_rack + + def _dc(self, host): + return host.datacenter or self.local_dc + + def populate(self, cluster, hosts): + for (dc, rack), dc_hosts in groupby(hosts, lambda host: (self._dc(host), self._rack(host))): + self._live_hosts[(dc, rack)] = dc_hosts + for dc, dc_hosts in groupby(hosts, lambda host: self._dc(host)): + self._dc_live_hosts[dc] = dc_hosts + + # as in other policies choose random position for better distributing queries across hosts + self._position = randint(0, len(hosts) - 1) if hosts else 0 + + def distance(self, host): + rack = self._rack(host) + dc = self._dc(host) + if rack == self.local_rack and dc == self.local_dc: + return HostDistance.LOCAL_RACK + + if dc == self.local_dc: + return HostDistance.LOCAL + + if not self.used_hosts_per_remote_dc: + return HostDistance.IGNORED + else: + dc_hosts = self._dc_live_hosts.get(dc, ()) + if not dc_hosts: + return HostDistance.IGNORED + + if host in dc_hosts[:self.used_hosts_per_remote_dc]: + + return HostDistance.REMOTE + else: + return HostDistance.IGNORED + + def make_query_plan(self, working_keyspace=None, query=None): + pos = self._position + self._position += 1 + + local_rack_live = self._live_hosts.get((self.local_dc, self.local_rack), ()) + pos = (pos % len(local_rack_live)) if local_rack_live else 0 + # Slice the cyclic iterator to start from pos and include the next len(local_live) elements + # This ensures we get exactly one full cycle starting from pos + for host in islice(cycle(local_rack_live), pos, pos + len(local_rack_live)): + yield host + + local_live = [host for host in self._dc_live_hosts.get(self.local_dc, ()) if host.rack != self.local_rack] + pos = (pos % 
+        """
+        self.local_rack = local_rack
+        self.local_dc = local_dc
+        self.used_hosts_per_remote_dc = used_hosts_per_remote_dc
+        self._live_hosts = {}
+        self._dc_live_hosts = {}
+        self._position = 0
+        LoadBalancingPolicy.__init__(self)
+
+    def _rack(self, host):
+        return host.rack or self.local_rack
+
+    def _dc(self, host):
+        return host.datacenter or self.local_dc
+
+    def populate(self, cluster, hosts):
+        # group hosts by (dc, rack) and by dc; accumulate into lists rather
+        # than relying on the input being pre-grouped, and materialize the
+        # groups, since they are later sliced and mutated by on_up()/on_down()
+        self._live_hosts = {}
+        self._dc_live_hosts = {}
+        for host in hosts:
+            dc, rack = self._dc(host), self._rack(host)
+            self._live_hosts.setdefault((dc, rack), []).append(host)
+            self._dc_live_hosts.setdefault(dc, []).append(host)
+
+        # as in other policies, choose a random start position for better
+        # distribution of queries across hosts
+        self._position = randint(0, len(hosts) - 1) if hosts else 0
+
+    def distance(self, host):
+        rack = self._rack(host)
+        dc = self._dc(host)
+        if dc == self.local_dc and rack == self.local_rack:
+            return HostDistance.LOCAL_RACK
+
+        if dc == self.local_dc:
+            return HostDistance.LOCAL
+
+        if not self.used_hosts_per_remote_dc:
+            return HostDistance.IGNORED
+
+        dc_hosts = self._dc_live_hosts.get(dc, ())
+        if host in dc_hosts[:self.used_hosts_per_remote_dc]:
+            return HostDistance.REMOTE
+        return HostDistance.IGNORED
+
+    def make_query_plan(self, working_keyspace=None, query=None):
+        pos = self._position
+        self._position += 1
+
+        # try live hosts in the local rack first
+        local_rack_live = self._live_hosts.get((self.local_dc, self.local_rack), [])
+        pos = (pos % len(local_rack_live)) if local_rack_live else 0
+        # slice the cyclic iterator to start from pos and include the next
+        # len(local_rack_live) elements; this yields exactly one full cycle
+        # beginning at pos
+        for host in islice(cycle(local_rack_live), pos, pos + len(local_rack_live)):
+            yield host
+
+        # then live hosts in the local datacenter but a different rack
+        local_live = [host for host in self._dc_live_hosts.get(self.local_dc, []) if host.rack != self.local_rack]
+        pos = (pos % len(local_live)) if local_live else 0
+        for host in islice(cycle(local_live), pos, pos + len(local_live)):
+            yield host
+
+        # the dict can change during iteration, so get candidate DCs from the
+        # keys of a copy
+        other_dcs = [dc for dc in self._dc_live_hosts.copy().keys() if dc != self.local_dc]
+        for dc in other_dcs:
+            remote_live = self._dc_live_hosts.get(dc, ())
+            for host in remote_live[:self.used_hosts_per_remote_dc]:
+                yield host
+
+    def on_up(self, host):
+        dc = self._dc(host)
+        rack = self._rack(host)
+        with self._hosts_lock:
+            # the host may belong to a (dc, rack) pair we have not seen yet
+            self._live_hosts.setdefault((dc, rack), []).append(host)
+            self._dc_live_hosts.setdefault(dc, []).append(host)
+
+    def on_down(self, host):
+        dc = self._dc(host)
+        rack = self._rack(host)
+        with self._hosts_lock:
+            rack_hosts = self._live_hosts.get((dc, rack), [])
+            if host in rack_hosts:
+                rack_hosts.remove(host)
+            dc_hosts = self._dc_live_hosts.get(dc, [])
+            if host in dc_hosts:
+                dc_hosts.remove(host)
+
+    def on_add(self, host):
+        self.on_up(host)
+
+    def on_remove(self, host):
+        self.on_down(host)
+
 
 class TokenAwarePolicy(LoadBalancingPolicy):
     """
@@ -396,7 +519,7 @@ def make_query_plan(self, working_keyspace=None, query=None):
                 shuffle(replicas)
             for replica in replicas:
                 if replica.is_up and \
-                        child.distance(replica) == HostDistance.LOCAL:
+                        child.distance(replica) in (HostDistance.LOCAL_RACK, HostDistance.LOCAL):
                     yield replica
 
             for host in child.make_query_plan(keyspace, query):
diff --git a/docs/api/cassandra/policies.rst b/docs/api/cassandra/policies.rst
index 387b19ed9..ea3b19d79 100644
--- a/docs/api/cassandra/policies.rst
+++ b/docs/api/cassandra/policies.rst
@@ -18,6 +18,9 @@ Load Balancing
 .. autoclass:: DCAwareRoundRobinPolicy
    :members:
 
+.. autoclass:: RackAwareRoundRobinPolicy
+   :members:
+
 .. autoclass:: WhiteListRoundRobinPolicy
    :members:
 
diff --git a/tests/unit/test_policies.py b/tests/unit/test_policies.py
index db9eae632..96daae151 100644
--- a/tests/unit/test_policies.py
+++ b/tests/unit/test_policies.py
@@ -17,6 +17,7 @@
 from itertools import islice, cycle
 from mock import Mock, patch, call
 from random import randint
+import pytest
 import six
 from six.moves._thread import LockType
 import sys
@@ -26,7 +27,7 @@
 from cassandra import ConsistencyLevel
 from cassandra.cluster import Cluster, ControlConnection
 from cassandra.metadata import Metadata
-from cassandra.policies import (RoundRobinPolicy, WhiteListRoundRobinPolicy, DCAwareRoundRobinPolicy,
+from cassandra.policies import (RackAwareRoundRobinPolicy, RoundRobinPolicy, WhiteListRoundRobinPolicy, DCAwareRoundRobinPolicy,
                                 TokenAwarePolicy, SimpleConvictionPolicy,
                                 HostDistance, ExponentialReconnectionPolicy,
                                 RetryPolicy, WriteType,
@@ -180,61 +181,93 @@ def test_no_live_nodes(self):
         qplan = list(policy.make_query_plan())
         self.assertEqual(qplan, [])
 
+
+@pytest.mark.parametrize("policy_specialization, constructor_args",
+                         [(DCAwareRoundRobinPolicy, ("dc1", )),
+                          (RackAwareRoundRobinPolicy, ("dc1", "rack1"))])
+class TestRackOrDCAwareRoundRobinPolicy:
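+    """
+    Shared round-robin tests, run once per policy: each test receives the
+    policy class and its constructor arguments from the class-level
+    parametrize mark above, so the same assertions cover both
+    :class:`.DCAwareRoundRobinPolicy` and :class:`.RackAwareRoundRobinPolicy`.
+    """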
-class DCAwareRoundRobinPolicyTest(unittest.TestCase):
-
-    def test_no_remote(self):
+
+    def test_no_remote(self, policy_specialization, constructor_args):
         hosts = []
-        for i in range(4):
+        for i in range(2):
             h = Host(DefaultEndPoint(i), SimpleConvictionPolicy)
+            h.set_location_info("dc1", "rack2")
+            hosts.append(h)
+        for i in range(2):
+            h = Host(DefaultEndPoint(i + 2), SimpleConvictionPolicy)
             h.set_location_info("dc1", "rack1")
             hosts.append(h)
 
-        policy = DCAwareRoundRobinPolicy("dc1")
+        policy = policy_specialization(*constructor_args)
         policy.populate(None, hosts)
 
         qplan = list(policy.make_query_plan())
-        self.assertEqual(sorted(qplan), sorted(hosts))
+        assert sorted(qplan) == sorted(hosts)
 
-    def test_with_remotes(self):
-        hosts = [Host(DefaultEndPoint(i), SimpleConvictionPolicy) for i in range(4)]
+    def test_with_remotes(self, policy_specialization, constructor_args):
+        hosts = [Host(DefaultEndPoint(i), SimpleConvictionPolicy) for i in range(6)]
         for h in hosts[:2]:
             h.set_location_info("dc1", "rack1")
-        for h in hosts[2:]:
+        for h in hosts[2:4]:
+            h.set_location_info("dc1", "rack2")
+        for h in hosts[4:]:
             h.set_location_info("dc2", "rack1")
 
-        local_hosts = set(h for h in hosts if h.datacenter == "dc1")
+        local_rack_hosts = set(h for h in hosts if h.datacenter == "dc1" and h.rack == "rack1")
+        local_hosts = set(h for h in hosts if h.datacenter == "dc1" and h.rack != "rack1")
         remote_hosts = set(h for h in hosts if h.datacenter != "dc1")
 
         # allow all of the remote hosts to be used
-        policy = DCAwareRoundRobinPolicy("dc1", used_hosts_per_remote_dc=2)
+        policy = policy_specialization(*constructor_args, used_hosts_per_remote_dc=2)
         policy.populate(Mock(), hosts)
         qplan = list(policy.make_query_plan())
-        self.assertEqual(set(qplan[:2]), local_hosts)
-        self.assertEqual(set(qplan[2:]), remote_hosts)
+        if policy_specialization is DCAwareRoundRobinPolicy:
+            assert set(qplan[:4]) == local_rack_hosts | local_hosts
+        elif policy_specialization is RackAwareRoundRobinPolicy:
+            assert set(qplan[:2]) == local_rack_hosts
+            assert set(qplan[2:4]) == local_hosts
+        assert set(qplan[4:]) == remote_hosts
 
         # allow only one of the remote hosts to be used
-        policy = DCAwareRoundRobinPolicy("dc1", used_hosts_per_remote_dc=1)
+        policy = policy_specialization(*constructor_args, used_hosts_per_remote_dc=1)
         policy.populate(Mock(), hosts)
         qplan = list(policy.make_query_plan())
-        self.assertEqual(set(qplan[:2]), local_hosts)
+        if policy_specialization is DCAwareRoundRobinPolicy:
+            assert set(qplan[:4]) == local_rack_hosts | local_hosts
+        elif policy_specialization is RackAwareRoundRobinPolicy:
+            assert set(qplan[:2]) == local_rack_hosts
+            assert set(qplan[2:4]) == local_hosts
 
-        used_remotes = set(qplan[2:])
-        self.assertEqual(1, len(used_remotes))
-        self.assertIn(qplan[2], remote_hosts)
+        used_remotes = set(qplan[4:])
+        assert len(used_remotes) == 1
+        assert qplan[4] in remote_hosts
 
         # allow no remote hosts to be used
-        policy = DCAwareRoundRobinPolicy("dc1", used_hosts_per_remote_dc=0)
+        policy = policy_specialization(*constructor_args, used_hosts_per_remote_dc=0)
         policy.populate(Mock(), hosts)
         qplan = list(policy.make_query_plan())
-        self.assertEqual(2, len(qplan))
-        self.assertEqual(local_hosts, set(qplan))
+        assert len(qplan) == 4
+        if policy_specialization is DCAwareRoundRobinPolicy:
+            assert set(qplan) == local_rack_hosts | local_hosts
+        elif policy_specialization is RackAwareRoundRobinPolicy:
+            assert set(qplan[:2]) == local_rack_hosts
+            assert set(qplan[2:4]) == local_hosts
 
-    def test_get_distance(self):
-        policy = DCAwareRoundRobinPolicy("dc1", used_hosts_per_remote_dc=0)
+    def test_get_distance(self, policy_specialization, constructor_args):
+        policy = policy_specialization(*constructor_args, used_hosts_per_remote_dc=0)
+
+        # same dc, same rack
         host = Host(DefaultEndPoint("ip1"), SimpleConvictionPolicy)
         host.set_location_info("dc1", "rack1")
         policy.populate(Mock(), [host])
 
-        self.assertEqual(policy.distance(host), HostDistance.LOCAL)
+        if policy_specialization is DCAwareRoundRobinPolicy:
+            assert policy.distance(host) == HostDistance.LOCAL
+        elif policy_specialization is RackAwareRoundRobinPolicy:
+            assert policy.distance(host) == HostDistance.LOCAL_RACK
+
+        # same dc, different rack
+        host = Host(DefaultEndPoint("ip1"), SimpleConvictionPolicy)
+        host.set_location_info("dc1", "rack2")
+        policy.populate(Mock(), [host])
+        assert policy.distance(host) == HostDistance.LOCAL
 
         # used_hosts_per_remote_dc is set to 0, so ignore it
@@ -258,30 +291,34 @@ def test_get_distance(self):
         distances = set([policy.distance(remote_host), policy.distance(second_remote_host)])
-        self.assertEqual(distances, set([HostDistance.REMOTE, HostDistance.IGNORED]))
+        assert distances == set([HostDistance.REMOTE, HostDistance.IGNORED])
 
-    def test_status_updates(self):
-        hosts = [Host(DefaultEndPoint(i), SimpleConvictionPolicy) for i in range(4)]
+    def test_status_updates(self, policy_specialization, constructor_args):
+        hosts = [Host(DefaultEndPoint(i), SimpleConvictionPolicy) for i in range(5)]
         for h in hosts[:2]:
             h.set_location_info("dc1", "rack1")
-        for h in hosts[2:]:
+        for h in hosts[2:4]:
+            h.set_location_info("dc1", "rack2")
+        for h in hosts[4:]:
             h.set_location_info("dc2", "rack1")
 
-        policy = DCAwareRoundRobinPolicy("dc1", used_hosts_per_remote_dc=1)
+        policy = policy_specialization(*constructor_args, used_hosts_per_remote_dc=1)
         policy.populate(Mock(), hosts)
         policy.on_down(hosts[0])
         policy.on_remove(hosts[2])
 
-        new_local_host = Host(DefaultEndPoint(4), SimpleConvictionPolicy)
+        new_local_host = Host(DefaultEndPoint(5), SimpleConvictionPolicy)
         new_local_host.set_location_info("dc1", "rack1")
         policy.on_up(new_local_host)
 
-        new_remote_host = Host(DefaultEndPoint(5), SimpleConvictionPolicy)
+        new_remote_host = Host(DefaultEndPoint(6), SimpleConvictionPolicy)
         new_remote_host.set_location_info("dc9000", "rack1")
         policy.on_add(new_remote_host)
 
-        # we now have two local hosts and two remote hosts in separate dcs
+        # we now have three local hosts and two remote hosts in separate dcs
         qplan = list(policy.make_query_plan())
-        self.assertEqual(set(qplan[:2]), set([hosts[1], new_local_host]))
-        self.assertEqual(set(qplan[2:]), set([hosts[3], new_remote_host]))
+        assert set(qplan[:3]) == set([hosts[1], new_local_host, hosts[3]])
+        assert set(qplan[3:]) == set([hosts[4], new_remote_host])
 
         # since we have hosts in dc9000, the distance shouldn't be IGNORED
-        self.assertEqual(policy.distance(new_remote_host), HostDistance.REMOTE)
+        assert policy.distance(new_remote_host) == HostDistance.REMOTE
 
@@ -289,21 +326,22 @@ def test_status_updates(self):
         policy.on_down(new_local_host)
         policy.on_down(hosts[1])
         qplan = list(policy.make_query_plan())
-        self.assertEqual(set(qplan), set([hosts[3], new_remote_host]))
+        assert set(qplan) == set([hosts[3], hosts[4], new_remote_host])
 
         policy.on_down(new_remote_host)
         policy.on_down(hosts[3])
+        policy.on_down(hosts[4])
         qplan = list(policy.make_query_plan())
-        self.assertEqual(qplan, [])
+        assert qplan == []
 
-    def test_modification_during_generation(self):
+    def test_modification_during_generation(self, policy_specialization, constructor_args):
         hosts = [Host(DefaultEndPoint(i), SimpleConvictionPolicy) for i in range(4)]
         for h in hosts[:2]:
             h.set_location_info("dc1", "rack1")
         for h in hosts[2:]:
             h.set_location_info("dc2", "rack1")
 
-        policy = DCAwareRoundRobinPolicy("dc1", used_hosts_per_remote_dc=3)
+        policy = policy_specialization(*constructor_args, used_hosts_per_remote_dc=3)
         policy.populate(Mock(), hosts)
 
-        # The general concept here is to change thee internal state of the
+        # The general concept here is to change the internal state of the
@@ -449,7 +487,7 @@ def test_modification_during_generation(self):
 
         # the last DC has two
-        self.assertEqual(len(list(plan)), 0 + 2)
+        assert len(list(plan)) == 0 + 2
 
-    def test_no_live_nodes(self):
+    def test_no_live_nodes(self, policy_specialization, constructor_args):
         """
         Ensure query plan for a downed cluster will execute without errors
         """
@@ -460,7 +498,7 @@
             h.set_location_info("dc1", "rack1")
             hosts.append(h)
 
-        policy = DCAwareRoundRobinPolicy("dc1", used_hosts_per_remote_dc=1)
+        policy = policy_specialization(*constructor_args, used_hosts_per_remote_dc=1)
         policy.populate(Mock(), hosts)
 
         for host in hosts:
@@ -469,12 +507,12 @@
 
         qplan = list(policy.make_query_plan())
-        self.assertEqual(qplan, [])
+        assert qplan == []
 
-    def test_no_nodes(self):
+    def test_no_nodes(self, policy_specialization, constructor_args):
         """
         Ensure query plan for an empty cluster will execute without errors
         """
 
-        policy = DCAwareRoundRobinPolicy("dc1", used_hosts_per_remote_dc=1)
+        policy = policy_specialization(*constructor_args, used_hosts_per_remote_dc=1)
         policy.populate(None, [])
 
         qplan = list(policy.make_query_plan())