Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ public class ClusterTopologyRefreshOptions {

public static final int DEFAULT_REFRESH_TRIGGERS_RECONNECT_ATTEMPTS = 5;

public static final int DEFAULT_MAX_TOPOLOGY_REFRESH_SOURCES = Integer.MAX_VALUE;

private final Set<RefreshTrigger> adaptiveRefreshTriggers;

private final Duration adaptiveRefreshTimeout;
Expand All @@ -79,6 +81,8 @@ public class ClusterTopologyRefreshOptions {

private final int refreshTriggersReconnectAttempts;

private final int maxTopologyRefreshSources;

protected ClusterTopologyRefreshOptions(Builder builder) {

this.adaptiveRefreshTriggers = Collections.unmodifiableSet(new HashSet<>(builder.adaptiveRefreshTriggers));
Expand All @@ -88,6 +92,7 @@ protected ClusterTopologyRefreshOptions(Builder builder) {
this.periodicRefreshEnabled = builder.periodicRefreshEnabled;
this.refreshPeriod = builder.refreshPeriod;
this.refreshTriggersReconnectAttempts = builder.refreshTriggersReconnectAttempts;
this.maxTopologyRefreshSources = builder.maxTopologyRefreshSources;
}

protected ClusterTopologyRefreshOptions(ClusterTopologyRefreshOptions original) {
Expand All @@ -99,6 +104,7 @@ protected ClusterTopologyRefreshOptions(ClusterTopologyRefreshOptions original)
this.periodicRefreshEnabled = original.periodicRefreshEnabled;
this.refreshPeriod = original.refreshPeriod;
this.refreshTriggersReconnectAttempts = original.refreshTriggersReconnectAttempts;
this.maxTopologyRefreshSources = original.maxTopologyRefreshSources;
}

/**
Expand Down Expand Up @@ -157,6 +163,8 @@ public static class Builder {

private int refreshTriggersReconnectAttempts = DEFAULT_REFRESH_TRIGGERS_RECONNECT_ATTEMPTS;

private int maxTopologyRefreshSources = DEFAULT_MAX_TOPOLOGY_REFRESH_SOURCES;

private Builder() {
}

Expand Down Expand Up @@ -304,6 +312,19 @@ public Builder dynamicRefreshSources(boolean dynamicRefreshSources) {
return this;
}

/**
* @param maxTopologyRefreshSources maximum number of nodes to query for topology refresh. Use
* {@link ClusterTopologyRefreshOptions#DEFAULT_MAX_TOPOLOGY_REFRESH_SOURCES} for no limit.
* @return {@code this}
*/
public Builder maxTopologyRefreshSources(int maxTopologyRefreshSources) {

LettuceAssert.isTrue(maxTopologyRefreshSources > 0, "maxTopologyRefreshSources must be greater than 0");

this.maxTopologyRefreshSources = maxTopologyRefreshSources;
return this;
}

/**
* Enables periodic cluster topology updates. The client starts updating the cluster topology in the intervals of
* {@link Builder#refreshPeriod}. Defaults to {@code false}. See {@link #DEFAULT_PERIODIC_REFRESH_ENABLED}.
Expand Down Expand Up @@ -459,6 +480,19 @@ public boolean useDynamicRefreshSources() {
return dynamicRefreshSources;
}

/**
* Return the maximum number of additionally queried (discovered) nodes used as sources during topology refresh when
* {@link #useDynamicRefreshSources()} ()} is true.
*
* <p>
* A value of {@link #DEFAULT_MAX_TOPOLOGY_REFRESH_SOURCES} means no limit and will query all discovered nodes.
*
* @return the maximum number of additionally queried topology refresh sources
*/
public int getMaxTopologyRefreshSources() {
return maxTopologyRefreshSources;
}

/**
* Flag, whether regular cluster topology updates are updated. The client starts updating the cluster topology in the
* intervals of {@link #getRefreshPeriod()}. Defaults to {@code false}.
Expand Down
21 changes: 20 additions & 1 deletion src/main/java/io/lettuce/core/cluster/RedisClusterClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
*/
package io.lettuce.core.cluster;

import io.lettuce.core.cluster.models.partitions.RedisClusterNode.NodeFlag;
import java.io.Closeable;
import java.net.SocketAddress;
import java.net.URI;
Expand Down Expand Up @@ -1087,7 +1088,8 @@ protected CompletableFuture<Partitions> loadPartitionsAsync() {
private CompletionStage<Partitions> fetchPartitions(Iterable<RedisURI> topologyRefreshSource) {

CompletionStage<Map<RedisURI, Partitions>> topology = refresh.loadViews(topologyRefreshSource,
getClusterClientOptions().getSocketOptions().getConnectTimeout(), useDynamicRefreshSources());
getClusterClientOptions().getSocketOptions().getConnectTimeout(), useDynamicRefreshSources(),
getMaxTopologyRefreshSources());

return topology.thenApply(partitions -> {

Expand Down Expand Up @@ -1267,6 +1269,23 @@ protected boolean useDynamicRefreshSources() {
return topologyRefreshOptions.useDynamicRefreshSources();
}

/**
* Returns the maximum number of additionally queried (discovered) nodes used as topology refresh sources when
* {@link ClusterTopologyRefreshOptions#useDynamicRefreshSources() dynamic refresh sources} are enabled.
* <p>
* Subclasses of {@link RedisClusterClient} may override this method.
*
* @return the maximum number of additionally queried topology refresh sources.
* @see ClusterTopologyRefreshOptions#getMaxTopologyRefreshSources()
* @see ClusterTopologyRefreshOptions#useDynamicRefreshSources()
*/
protected int getMaxTopologyRefreshSources() {

ClusterTopologyRefreshOptions topologyRefreshOptions = getClusterClientOptions().getTopologyRefreshOptions();

return topologyRefreshOptions.getMaxTopologyRefreshSources();
}

/**
* Returns a {@link String} {@link RedisCodec codec}.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,18 @@ static ClusterTopologyRefresh create(NodeConnectionFactory nodeConnectionFactory
*/
CompletionStage<Map<RedisURI, Partitions>> loadViews(Iterable<RedisURI> seed, Duration connectTimeout, boolean discovery);

/**
* Load topology views from a collection of {@link RedisURI}s and return the view per {@link RedisURI}. Partitions contain
* an ordered list of {@link RedisClusterNode}s. The sort key is latency. Nodes with lower latency come first.
*
* @param seed collection of {@link RedisURI}s
* @param connectTimeout connect timeout
* @param discovery {@code true} to discover additional nodes
* @param maxTopologyRefreshSources maximum number of additionally queried (discovered) nodes. Use {@link Integer#MAX_VALUE}
* to query all discovered nodes.
* @return mapping between {@link RedisURI} and {@link Partitions}
*/
CompletionStage<Map<RedisURI, Partitions>> loadViews(Iterable<RedisURI> seed, Duration connectTimeout, boolean discovery,
int maxTopologyRefreshSources);

}
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,13 @@
*/
package io.lettuce.core.cluster.topology;

import io.lettuce.core.cluster.ClusterTopologyRefreshOptions;
import io.lettuce.core.cluster.topology.TopologyComparators.LatencyComparator;
import java.io.IOException;
import java.net.SocketAddress;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
Expand Down Expand Up @@ -92,6 +95,23 @@ public DefaultClusterTopologyRefresh(NodeConnectionFactory nodeConnectionFactory
@Override
public CompletionStage<Map<RedisURI, Partitions>> loadViews(Iterable<RedisURI> seed, Duration connectTimeout,
boolean discovery) {
return loadViews(seed, connectTimeout, discovery, Integer.MAX_VALUE);
}

/**
* Load topology views from a collection of {@link RedisURI}s and return the view per {@link RedisURI}. Partitions contain
* an ordered list of {@link RedisClusterNode}s. The sort key is latency. Nodes with lower latency come first.
*
* @param seed collection of {@link RedisURI}s
* @param connectTimeout connect timeout
* @param discovery {@code true} to discover additional nodes
* @param maxTopologyRefreshSources maximum number of additionally queried (discovered) nodes. Use {@link Integer#MAX_VALUE}
* to query all discovered nodes.
* @return mapping between {@link RedisURI} and {@link Partitions}
*/
@Override
public CompletionStage<Map<RedisURI, Partitions>> loadViews(Iterable<RedisURI> seed, Duration connectTimeout,
boolean discovery, int maxTopologyRefreshSources) {

if (!isEventLoopActive()) {
return CompletableFuture.completedFuture(Collections.emptyMap());
Expand All @@ -115,7 +135,8 @@ public CompletionStage<Map<RedisURI, Partitions>> loadViews(Iterable<RedisURI> s
if (discovery && isEventLoopActive()) {

Set<RedisURI> allKnownUris = views.getClusterNodes();
Set<RedisURI> discoveredNodes = difference(allKnownUris, toSet(seed));
Set<RedisURI> discoveredNodes = limit(difference(allKnownUris, toSet(seed)),
maxTopologyRefreshSources);

if (discoveredNodes.isEmpty()) {
return CompletableFuture.completedFuture(views);
Expand Down Expand Up @@ -390,6 +411,18 @@ private static Set<RedisURI> difference(Set<RedisURI> allKnown, Set<RedisURI> se
return result;
}

private static Set<RedisURI> limit(Set<RedisURI> uris, int maxTopologyRefreshSources) {
if (uris.size() <= maxTopologyRefreshSources) {
return uris;
}

List<RedisURI> uriList = new ArrayList<>(uris);
Collections.shuffle(uriList);

return uriList.stream().limit(maxTopologyRefreshSources)
.collect(Collectors.toCollection(() -> new TreeSet<>(TopologyComparators.RedisURIComparator.INSTANCE)));
}

private static long getCommandTimeoutNs(Iterable<RedisURI> redisURIs) {

RedisURI redisURI = redisURIs.iterator().next();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ void testBuilder() {
.adaptiveRefreshTriggersTimeout(15, TimeUnit.MILLISECONDS)//
.closeStaleConnections(false)//
.refreshTriggersReconnectAttempts(2)//
.maxTopologyRefreshSources(100)//
.build();

assertThat(options.getRefreshPeriod()).isEqualTo(Duration.ofMinutes(10));
Expand All @@ -39,6 +40,7 @@ void testBuilder() {
assertThat(options.getAdaptiveRefreshTimeout()).isEqualTo(Duration.ofMillis(15));
assertThat(options.getAdaptiveRefreshTriggers()).containsOnly(RefreshTrigger.MOVED_REDIRECT);
assertThat(options.getRefreshTriggersReconnectAttempts()).isEqualTo(2);
assertThat(options.getMaxTopologyRefreshSources()).isEqualTo(100);
}

@Test
Expand All @@ -52,6 +54,7 @@ void testCopy() {
.adaptiveRefreshTriggersTimeout(15, TimeUnit.MILLISECONDS)//
.closeStaleConnections(false)//
.refreshTriggersReconnectAttempts(2)//
.maxTopologyRefreshSources(100)//
.build();

ClusterTopologyRefreshOptions options = ClusterTopologyRefreshOptions.copyOf(master);
Expand All @@ -63,6 +66,7 @@ void testCopy() {
assertThat(options.getAdaptiveRefreshTimeout()).isEqualTo(Duration.ofMillis(15));
assertThat(options.getAdaptiveRefreshTriggers()).containsOnly(RefreshTrigger.MOVED_REDIRECT);
assertThat(options.getRefreshTriggersReconnectAttempts()).isEqualTo(2);
assertThat(options.getMaxTopologyRefreshSources()).isEqualTo(100);
}

@Test
Expand All @@ -82,6 +86,8 @@ void testDefault() {
.isEqualTo(ClusterTopologyRefreshOptions.DEFAULT_ADAPTIVE_REFRESH_TRIGGERS);
assertThat(options.getRefreshTriggersReconnectAttempts())
.isEqualTo(ClusterTopologyRefreshOptions.DEFAULT_REFRESH_TRIGGERS_RECONNECT_ATTEMPTS);
assertThat(options.getMaxTopologyRefreshSources())
.isEqualTo(ClusterTopologyRefreshOptions.DEFAULT_MAX_TOPOLOGY_REFRESH_SOURCES);
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.mockito.junit.jupiter.MockitoSettings;
Expand Down Expand Up @@ -602,6 +603,35 @@ void shouldPropagateCommandFailures() {
}
}

@Test
void shouldLimitDiscoveredNodesToOne() {

List<RedisURI> seed = Collections.singletonList(RedisURI.create("foobar", 7380));

when(nodeConnectionFactory.connectToNodeAsync(any(RedisCodec.class),
eq(InetSocketAddress.createUnresolved("foobar", 7380))))
.thenReturn(completedFuture((StatefulRedisConnection) connection1));

when(nodeConnectionFactory.connectToNodeAsync(any(RedisCodec.class),
eq(InetSocketAddress.createUnresolved("127.0.0.1", 7380))))
.thenReturn(completedFuture((StatefulRedisConnection) connection1));

when(nodeConnectionFactory.connectToNodeAsync(any(RedisCodec.class),
eq(InetSocketAddress.createUnresolved("127.0.0.1", 7381))))
.thenReturn(completedFuture((StatefulRedisConnection) connection2));

sut.loadViews(seed, Duration.ofSeconds(1), true, 1).toCompletableFuture().join();

ArgumentCaptor<InetSocketAddress> captor = ArgumentCaptor.forClass(InetSocketAddress.class);

verify(nodeConnectionFactory, times(2)).connectToNodeAsync(any(RedisCodec.class), captor.capture());

List<InetSocketAddress> called = captor.getAllValues();
assertThat(called).contains(InetSocketAddress.createUnresolved("foobar", 7380));

assertThat(called).anyMatch(a -> a.getHostString().equals("127.0.0.1") && (a.getPort() == 7380 || a.getPort() == 7381));
}

Requests createClusterNodesRequests(int duration, String nodes) {

RedisURI redisURI = RedisURI.create("redis://localhost:" + duration);
Expand Down
Loading