diff --git a/presto-elasticsearch/src/main/java/io/prestosql/elasticsearch/ElasticsearchClient.java b/presto-elasticsearch/src/main/java/io/prestosql/elasticsearch/ElasticsearchClient.java index 676e1823c828..b1863d059403 100644 --- a/presto-elasticsearch/src/main/java/io/prestosql/elasticsearch/ElasticsearchClient.java +++ b/presto-elasticsearch/src/main/java/io/prestosql/elasticsearch/ElasticsearchClient.java @@ -38,6 +38,7 @@ import org.elasticsearch.client.transport.TransportClient; import org.elasticsearch.cluster.metadata.MappingMetaData; import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.cluster.routing.ShardRouting; import org.elasticsearch.common.collect.ImmutableOpenMap; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.transport.TransportAddress; @@ -50,6 +51,7 @@ import java.io.UncheckedIOException; import java.net.InetAddress; import java.util.ArrayList; +import java.util.Arrays; import java.util.Comparator; import java.util.HashMap; import java.util.Iterator; @@ -58,6 +60,7 @@ import java.util.Optional; import java.util.TreeMap; import java.util.concurrent.ExecutorService; +import java.util.stream.Collectors; import static com.floragunn.searchguard.ssl.util.SSLConfigConstants.SEARCHGUARD_SSL_TRANSPORT_ENFORCE_HOSTNAME_VERIFICATION; import static com.floragunn.searchguard.ssl.util.SSLConfigConstants.SEARCHGUARD_SSL_TRANSPORT_KEYSTORE_FILEPATH; @@ -193,10 +196,27 @@ public List getSearchShards(String index) .actionGet(requestTimeout.toMillis())); ImmutableList.Builder shards = ImmutableList.builder(); + DiscoveryNode[] nodes = result.getNodes(); + Map nodeById = Arrays.stream(nodes) + .collect(Collectors.toMap(DiscoveryNode::getId, node -> node)); + for (ClusterSearchShardsGroup group : result.getGroups()) { - int nodeIndex = group.getShardId().getId() % nodes.length; - shards.add(new Shard(group.getShardId().getId(), nodes[nodeIndex].getHostName(), nodes[nodeIndex].getAddress().getPort())); + Optional routing = Arrays.stream(group.getShards()) + .filter(ShardRouting::assignedToNode) + .sorted(this::shardPreference) + .findFirst(); + + DiscoveryNode node; + if (!routing.isPresent()) { + // pick an arbitrary node + node = nodes[group.getShardId().getId() % nodes.length]; + } + else { + node = nodeById.get(routing.get().currentNodeId()); + } + + shards.add(new Shard(group.getShardId().getId(), node.getHostName(), node.getAddress().getPort())); } return shards.build(); @@ -206,6 +226,15 @@ public List getSearchShards(String index) } } + private int shardPreference(ShardRouting left, ShardRouting right) + { + // Favor non-primary shards + if (left.primary() == right.primary()) { + return 0; + } + return left.primary() ? 1 : -1; + } + private List buildMetadata(List columns) { List result = new ArrayList<>();