1414import random
1515
1616from collections import namedtuple
17- from itertools import islice , cycle , groupby , repeat
17+ from itertools import islice , cycle , groupby , repeat , chain
1818import logging
1919from random import randint , shuffle
2020from threading import Lock
@@ -157,6 +157,18 @@ def make_query_plan(self, working_keyspace=None, query=None):
157157 """
158158 raise NotImplementedError ()
159159
160+ def make_query_plan_with_exclusion (self , working_keyspace = None , query = None , excluded = ()):
161+ """
162+ Same as :meth:`make_query_plan`, but with an additional `excluded` parameter.
163+ `excluded` should be a container (set, list, etc.) of hosts to skip.
164+
165+ The default implementation simply delegates to `make_query_plan` and filters the result.
166+ Subclasses may override this for performance.
167+ """
168+ for host in self .make_query_plan (working_keyspace , query ):
169+ if host not in excluded :
170+ yield host
171+
160172 def check_supported (self ):
161173 """
162174 This will be called after the cluster Metadata has been initialized.
@@ -198,6 +210,20 @@ def make_query_plan(self, working_keyspace=None, query=None):
198210 else :
199211 return []
200212
213+ def make_query_plan_with_exclusion (self , working_keyspace = None , query = None , excluded = ()):
214+ pos = self ._position
215+ self ._position += 1
216+
217+ hosts = self ._live_hosts
218+ length = len (hosts )
219+ if length :
220+ pos %= length
221+ for host in islice (cycle (hosts ), pos , pos + length ):
222+ if host not in excluded :
223+ yield host
224+ else :
225+ return
226+
201227 def on_up (self , host ):
202228 with self ._hosts_lock :
203229 self ._live_hosts = self ._live_hosts .union ((host , ))
@@ -289,6 +315,24 @@ def make_query_plan(self, working_keyspace=None, query=None):
289315 for host in self ._remote_hosts :
290316 yield host
291317
318+ def make_query_plan_with_exclusion (self , working_keyspace = None , query = None , excluded = ()):
319+ # not thread-safe, but we don't care much about lost increments
320+ # for the purposes of load balancing
321+ pos = self ._position
322+ self ._position += 1
323+
324+ local_live = self ._dc_live_hosts .get (self .local_dc , ())
325+ pos = (pos % len (local_live )) if local_live else 0
326+ for host in islice (cycle (local_live ), pos , pos + len (local_live )):
327+ if excluded and host in excluded :
328+ continue
329+ yield host
330+
331+ for host in self ._remote_hosts :
332+ if excluded and host in excluded :
333+ continue
334+ yield host
335+
292336 def on_up (self , host ):
293337 # not worrying about threads because this will happen during
294338 # control connection startup/refresh
@@ -424,6 +468,33 @@ def make_query_plan(self, working_keyspace=None, query=None):
424468
425469 for host in self ._remote_hosts :
426470 yield host
471+
472+ def make_query_plan_with_exclusion (self , working_keyspace = None , query = None , excluded = ()):
473+ pos = self ._position
474+ self ._position += 1
475+
476+ local_rack_live = self ._live_hosts .get ((self .local_dc , self .local_rack ), ())
477+ length = len (local_rack_live )
478+ if length :
479+ p = pos % length
480+ for host in islice (cycle (local_rack_live ), p , p + length ):
481+ if excluded and host in excluded :
482+ continue
483+ yield host
484+
485+ local_non_rack = self ._non_local_rack_hosts
486+ length = len (local_non_rack )
487+ if length :
488+ p = pos % length
489+ for host in islice (cycle (local_non_rack ), p , p + length ):
490+ if excluded and host in excluded :
491+ continue
492+ yield host
493+
494+ for host in self ._remote_hosts :
495+ if excluded and host in excluded :
496+ continue
497+ yield host
427498
428499 def on_up (self , host ):
429500 dc = self ._dc (host )
@@ -495,16 +566,12 @@ class TokenAwarePolicy(LoadBalancingPolicy):
495566 policy's query plan will be used as is.
496567 """
497568
498- _child_policy = None
499- _cluster_metadata = None
500- shuffle_replicas = True
501- """
502- Yield local replicas in a random order.
503- """
569+ __slots__ = ('_child_policy' , '_cluster_metadata' , 'shuffle_replicas' )
504570
505571 def __init__ (self , child_policy , shuffle_replicas = True ):
506572 self ._child_policy = child_policy
507573 self .shuffle_replicas = shuffle_replicas
574+ self ._cluster_metadata = None
508575
509576 def populate (self , cluster , hosts ):
510577 self ._cluster_metadata = cluster .metadata
@@ -527,35 +594,69 @@ def make_query_plan(self, working_keyspace=None, query=None):
527594
528595 child = self ._child_policy
529596 if query is None or query .routing_key is None or keyspace is None :
530- for host in child .make_query_plan (keyspace , query ):
531- yield host
597+ yield from child .make_query_plan (keyspace , query )
532598 return
533599
600+ cluster_metadata = self ._cluster_metadata
601+ token_map = cluster_metadata .token_map
534602 replicas = []
535- tablet = self ._cluster_metadata ._tablets .get_tablet_for_key (
536- keyspace , query .table , self ._cluster_metadata .token_map .token_class .from_key (query .routing_key ))
537603
538- if tablet is not None :
539- replicas_mapped = set (map (lambda r : r [0 ], tablet .replicas ))
540- child_plan = child .make_query_plan (keyspace , query )
604+ if token_map :
605+ try :
606+ token = token_map .token_class .from_key (query .routing_key )
607+ tablet = cluster_metadata ._tablets .get_tablet_for_key (
608+ keyspace , query .table , token )
609+
610+ if tablet is not None :
611+ replicas_mapped = set (map (lambda r : r [0 ], tablet .replicas ))
612+ for host_id in replicas_mapped :
613+ host = cluster_metadata .get_host_by_host_id (host_id )
614+ if host :
615+ replicas .append (host )
616+ else :
617+ try :
618+ replicas = list (token_map .get_replicas (keyspace , token ))
619+ except Exception :
620+ replicas = cluster_metadata .get_replicas (keyspace , query .routing_key )
621+ except Exception :
622+ pass
541623
542- replicas = [host for host in child_plan if host .host_id in replicas_mapped ]
543- else :
544- replicas = self ._cluster_metadata .get_replicas (keyspace , query .routing_key )
545624
546625 if self .shuffle_replicas and not query .is_lwt ():
547626 shuffle (replicas )
548627
549- def yield_in_order (hosts ):
550- for distance in [HostDistance .LOCAL_RACK , HostDistance .LOCAL , HostDistance .REMOTE ]:
551- for replica in hosts :
552- if replica .is_up and child .distance (replica ) == distance :
553- yield replica
554-
555- # yield replicas: local_rack, local, remote
556- yield from yield_in_order (replicas )
557- # yield rest of the cluster: local_rack, local, remote
558- yield from yield_in_order ([host for host in child .make_query_plan (keyspace , query ) if host not in replicas ])
628+ local_rack = []
629+ local = []
630+ remote = []
631+
632+ child_distance = child .distance
633+
634+ for replica in replicas :
635+ if replica .is_up :
636+ d = child_distance (replica )
637+ if d == HostDistance .LOCAL_RACK :
638+ local_rack .append (replica )
639+ elif d == HostDistance .LOCAL :
640+ local .append (replica )
641+ elif d == HostDistance .REMOTE :
642+ remote .append (replica )
643+
644+ yielded_sequence = tuple (chain (local_rack , local , remote ))
645+
646+ if yielded_sequence :
647+ yield from yielded_sequence
648+
649+ yielded = set (yielded_sequence )
650+
651+ # yield rest of the cluster
652+ try :
653+ yield from child .make_query_plan_with_exclusion (keyspace , query , yielded )
654+ except (AttributeError , TypeError ):
655+ for host in child .make_query_plan (keyspace , query ):
656+ if host not in yielded :
657+ yield host
658+ else :
659+ yield from child .make_query_plan (keyspace , query )
559660
560661 def on_up (self , * args , ** kwargs ):
561662 return self ._child_policy .on_up (* args , ** kwargs )
0 commit comments