From 6afb1eb1a8a11070c2868ed5abff9ff70283975b Mon Sep 17 00:00:00 2001 From: Su Ko Date: Wed, 17 Dec 2025 09:23:50 +0900 Subject: [PATCH] implement `MaxTopologyRefresh` --- .../ClusterTopologyRefreshOptions.java | 34 ++++++++++++++++++ .../core/cluster/RedisClusterClient.java | 21 ++++++++++- .../topology/ClusterTopologyRefresh.java | 14 ++++++++ .../DefaultClusterTopologyRefresh.java | 35 ++++++++++++++++++- ...lusterTopologyRefreshOptionsUnitTests.java | 6 ++++ .../ClusterTopologyRefreshUnitTests.java | 30 ++++++++++++++++ 6 files changed, 138 insertions(+), 2 deletions(-) diff --git a/src/main/java/io/lettuce/core/cluster/ClusterTopologyRefreshOptions.java b/src/main/java/io/lettuce/core/cluster/ClusterTopologyRefreshOptions.java index 13bae0d998..40263b0c9e 100644 --- a/src/main/java/io/lettuce/core/cluster/ClusterTopologyRefreshOptions.java +++ b/src/main/java/io/lettuce/core/cluster/ClusterTopologyRefreshOptions.java @@ -65,6 +65,8 @@ public class ClusterTopologyRefreshOptions { public static final int DEFAULT_REFRESH_TRIGGERS_RECONNECT_ATTEMPTS = 5; + public static final int DEFAULT_MAX_TOPOLOGY_REFRESH_SOURCES = Integer.MAX_VALUE; + private final Set adaptiveRefreshTriggers; private final Duration adaptiveRefreshTimeout; @@ -79,6 +81,8 @@ public class ClusterTopologyRefreshOptions { private final int refreshTriggersReconnectAttempts; + private final int maxTopologyRefreshSources; + protected ClusterTopologyRefreshOptions(Builder builder) { this.adaptiveRefreshTriggers = Collections.unmodifiableSet(new HashSet<>(builder.adaptiveRefreshTriggers)); @@ -88,6 +92,7 @@ protected ClusterTopologyRefreshOptions(Builder builder) { this.periodicRefreshEnabled = builder.periodicRefreshEnabled; this.refreshPeriod = builder.refreshPeriod; this.refreshTriggersReconnectAttempts = builder.refreshTriggersReconnectAttempts; + this.maxTopologyRefreshSources = builder.maxTopologyRefreshSources; } protected ClusterTopologyRefreshOptions(ClusterTopologyRefreshOptions original) { @@ -99,6 +104,7 @@ protected ClusterTopologyRefreshOptions(ClusterTopologyRefreshOptions original) this.periodicRefreshEnabled = original.periodicRefreshEnabled; this.refreshPeriod = original.refreshPeriod; this.refreshTriggersReconnectAttempts = original.refreshTriggersReconnectAttempts; + this.maxTopologyRefreshSources = original.maxTopologyRefreshSources; } /** @@ -157,6 +163,8 @@ public static class Builder { private int refreshTriggersReconnectAttempts = DEFAULT_REFRESH_TRIGGERS_RECONNECT_ATTEMPTS; + private int maxTopologyRefreshSources = DEFAULT_MAX_TOPOLOGY_REFRESH_SOURCES; + private Builder() { } @@ -304,6 +312,19 @@ public Builder dynamicRefreshSources(boolean dynamicRefreshSources) { return this; } + /** + * @param maxTopologyRefreshSources maximum number of nodes to query for topology refresh. Use + * {@link ClusterTopologyRefreshOptions#DEFAULT_MAX_TOPOLOGY_REFRESH_SOURCES} for no limit. + * @return {@code this} + */ + public Builder maxTopologyRefreshSources(int maxTopologyRefreshSources) { + + LettuceAssert.isTrue(maxTopologyRefreshSources > 0, "maxTopologyRefreshSources must be greater than 0"); + + this.maxTopologyRefreshSources = maxTopologyRefreshSources; + return this; + } + /** * Enables periodic cluster topology updates. The client starts updating the cluster topology in the intervals of * {@link Builder#refreshPeriod}. Defaults to {@code false}. See {@link #DEFAULT_PERIODIC_REFRESH_ENABLED}. @@ -459,6 +480,19 @@ public boolean useDynamicRefreshSources() { return dynamicRefreshSources; } + /** + * Return the maximum number of additionally queried (discovered) nodes used as sources during topology refresh when + * {@link #useDynamicRefreshSources()} ()} is true. + * + *

+ * A value of {@link #DEFAULT_MAX_TOPOLOGY_REFRESH_SOURCES} means no limit and will query all discovered nodes. + * + * @return the maximum number of additionally queried topology refresh sources + */ + public int getMaxTopologyRefreshSources() { + return maxTopologyRefreshSources; + } + /** * Flag, whether regular cluster topology updates are updated. The client starts updating the cluster topology in the * intervals of {@link #getRefreshPeriod()}. Defaults to {@code false}. diff --git a/src/main/java/io/lettuce/core/cluster/RedisClusterClient.java b/src/main/java/io/lettuce/core/cluster/RedisClusterClient.java index 8ac05116b7..845dae9ebf 100644 --- a/src/main/java/io/lettuce/core/cluster/RedisClusterClient.java +++ b/src/main/java/io/lettuce/core/cluster/RedisClusterClient.java @@ -19,6 +19,7 @@ */ package io.lettuce.core.cluster; +import io.lettuce.core.cluster.models.partitions.RedisClusterNode.NodeFlag; import java.io.Closeable; import java.net.SocketAddress; import java.net.URI; @@ -1087,7 +1088,8 @@ protected CompletableFuture loadPartitionsAsync() { private CompletionStage fetchPartitions(Iterable topologyRefreshSource) { CompletionStage> topology = refresh.loadViews(topologyRefreshSource, - getClusterClientOptions().getSocketOptions().getConnectTimeout(), useDynamicRefreshSources()); + getClusterClientOptions().getSocketOptions().getConnectTimeout(), useDynamicRefreshSources(), + getMaxTopologyRefreshSources()); return topology.thenApply(partitions -> { @@ -1267,6 +1269,23 @@ protected boolean useDynamicRefreshSources() { return topologyRefreshOptions.useDynamicRefreshSources(); } + /** + * Returns the maximum number of additionally queried (discovered) nodes used as topology refresh sources when + * {@link ClusterTopologyRefreshOptions#useDynamicRefreshSources() dynamic refresh sources} are enabled. + *

+ * Subclasses of {@link RedisClusterClient} may override this method. + * + * @return the maximum number of additionally queried topology refresh sources. + * @see ClusterTopologyRefreshOptions#getMaxTopologyRefreshSources() + * @see ClusterTopologyRefreshOptions#useDynamicRefreshSources() + */ + protected int getMaxTopologyRefreshSources() { + + ClusterTopologyRefreshOptions topologyRefreshOptions = getClusterClientOptions().getTopologyRefreshOptions(); + + return topologyRefreshOptions.getMaxTopologyRefreshSources(); + } + /** * Returns a {@link String} {@link RedisCodec codec}. * diff --git a/src/main/java/io/lettuce/core/cluster/topology/ClusterTopologyRefresh.java b/src/main/java/io/lettuce/core/cluster/topology/ClusterTopologyRefresh.java index bb2694d18c..9b5dc416bf 100644 --- a/src/main/java/io/lettuce/core/cluster/topology/ClusterTopologyRefresh.java +++ b/src/main/java/io/lettuce/core/cluster/topology/ClusterTopologyRefresh.java @@ -38,4 +38,18 @@ static ClusterTopologyRefresh create(NodeConnectionFactory nodeConnectionFactory */ CompletionStage> loadViews(Iterable seed, Duration connectTimeout, boolean discovery); + /** + * Load topology views from a collection of {@link RedisURI}s and return the view per {@link RedisURI}. Partitions contain + * an ordered list of {@link RedisClusterNode}s. The sort key is latency. Nodes with lower latency come first. + * + * @param seed collection of {@link RedisURI}s + * @param connectTimeout connect timeout + * @param discovery {@code true} to discover additional nodes + * @param maxTopologyRefreshSources maximum number of additionally queried (discovered) nodes. Use {@link Integer#MAX_VALUE} + * to query all discovered nodes. + * @return mapping between {@link RedisURI} and {@link Partitions} + */ + CompletionStage> loadViews(Iterable seed, Duration connectTimeout, boolean discovery, + int maxTopologyRefreshSources); + } diff --git a/src/main/java/io/lettuce/core/cluster/topology/DefaultClusterTopologyRefresh.java b/src/main/java/io/lettuce/core/cluster/topology/DefaultClusterTopologyRefresh.java index 2f4b04b4c0..922fff59ca 100644 --- a/src/main/java/io/lettuce/core/cluster/topology/DefaultClusterTopologyRefresh.java +++ b/src/main/java/io/lettuce/core/cluster/topology/DefaultClusterTopologyRefresh.java @@ -19,10 +19,13 @@ */ package io.lettuce.core.cluster.topology; +import io.lettuce.core.cluster.ClusterTopologyRefreshOptions; +import io.lettuce.core.cluster.topology.TopologyComparators.LatencyComparator; import java.io.IOException; import java.net.SocketAddress; import java.time.Duration; import java.util.ArrayList; +import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; @@ -92,6 +95,23 @@ public DefaultClusterTopologyRefresh(NodeConnectionFactory nodeConnectionFactory @Override public CompletionStage> loadViews(Iterable seed, Duration connectTimeout, boolean discovery) { + return loadViews(seed, connectTimeout, discovery, Integer.MAX_VALUE); + } + + /** + * Load topology views from a collection of {@link RedisURI}s and return the view per {@link RedisURI}. Partitions contain + * an ordered list of {@link RedisClusterNode}s. The sort key is latency. Nodes with lower latency come first. + * + * @param seed collection of {@link RedisURI}s + * @param connectTimeout connect timeout + * @param discovery {@code true} to discover additional nodes + * @param maxTopologyRefreshSources maximum number of additionally queried (discovered) nodes. Use {@link Integer#MAX_VALUE} + * to query all discovered nodes. + * @return mapping between {@link RedisURI} and {@link Partitions} + */ + @Override + public CompletionStage> loadViews(Iterable seed, Duration connectTimeout, + boolean discovery, int maxTopologyRefreshSources) { if (!isEventLoopActive()) { return CompletableFuture.completedFuture(Collections.emptyMap()); @@ -115,7 +135,8 @@ public CompletionStage> loadViews(Iterable s if (discovery && isEventLoopActive()) { Set allKnownUris = views.getClusterNodes(); - Set discoveredNodes = difference(allKnownUris, toSet(seed)); + Set discoveredNodes = limit(difference(allKnownUris, toSet(seed)), + maxTopologyRefreshSources); if (discoveredNodes.isEmpty()) { return CompletableFuture.completedFuture(views); @@ -390,6 +411,18 @@ private static Set difference(Set allKnown, Set se return result; } + private static Set limit(Set uris, int maxTopologyRefreshSources) { + if (uris.size() <= maxTopologyRefreshSources) { + return uris; + } + + List uriList = new ArrayList<>(uris); + Collections.shuffle(uriList); + + return uriList.stream().limit(maxTopologyRefreshSources) + .collect(Collectors.toCollection(() -> new TreeSet<>(TopologyComparators.RedisURIComparator.INSTANCE))); + } + private static long getCommandTimeoutNs(Iterable redisURIs) { RedisURI redisURI = redisURIs.iterator().next(); diff --git a/src/test/java/io/lettuce/core/cluster/ClusterTopologyRefreshOptionsUnitTests.java b/src/test/java/io/lettuce/core/cluster/ClusterTopologyRefreshOptionsUnitTests.java index 4d980fea0c..8588b24acb 100644 --- a/src/test/java/io/lettuce/core/cluster/ClusterTopologyRefreshOptionsUnitTests.java +++ b/src/test/java/io/lettuce/core/cluster/ClusterTopologyRefreshOptionsUnitTests.java @@ -30,6 +30,7 @@ void testBuilder() { .adaptiveRefreshTriggersTimeout(15, TimeUnit.MILLISECONDS)// .closeStaleConnections(false)// .refreshTriggersReconnectAttempts(2)// + .maxTopologyRefreshSources(100)// .build(); assertThat(options.getRefreshPeriod()).isEqualTo(Duration.ofMinutes(10)); @@ -39,6 +40,7 @@ void testBuilder() { assertThat(options.getAdaptiveRefreshTimeout()).isEqualTo(Duration.ofMillis(15)); assertThat(options.getAdaptiveRefreshTriggers()).containsOnly(RefreshTrigger.MOVED_REDIRECT); assertThat(options.getRefreshTriggersReconnectAttempts()).isEqualTo(2); + assertThat(options.getMaxTopologyRefreshSources()).isEqualTo(100); } @Test @@ -52,6 +54,7 @@ void testCopy() { .adaptiveRefreshTriggersTimeout(15, TimeUnit.MILLISECONDS)// .closeStaleConnections(false)// .refreshTriggersReconnectAttempts(2)// + .maxTopologyRefreshSources(100)// .build(); ClusterTopologyRefreshOptions options = ClusterTopologyRefreshOptions.copyOf(master); @@ -63,6 +66,7 @@ void testCopy() { assertThat(options.getAdaptiveRefreshTimeout()).isEqualTo(Duration.ofMillis(15)); assertThat(options.getAdaptiveRefreshTriggers()).containsOnly(RefreshTrigger.MOVED_REDIRECT); assertThat(options.getRefreshTriggersReconnectAttempts()).isEqualTo(2); + assertThat(options.getMaxTopologyRefreshSources()).isEqualTo(100); } @Test @@ -82,6 +86,8 @@ void testDefault() { .isEqualTo(ClusterTopologyRefreshOptions.DEFAULT_ADAPTIVE_REFRESH_TRIGGERS); assertThat(options.getRefreshTriggersReconnectAttempts()) .isEqualTo(ClusterTopologyRefreshOptions.DEFAULT_REFRESH_TRIGGERS_RECONNECT_ATTEMPTS); + assertThat(options.getMaxTopologyRefreshSources()) + .isEqualTo(ClusterTopologyRefreshOptions.DEFAULT_MAX_TOPOLOGY_REFRESH_SOURCES); } @Test diff --git a/src/test/java/io/lettuce/core/cluster/topology/ClusterTopologyRefreshUnitTests.java b/src/test/java/io/lettuce/core/cluster/topology/ClusterTopologyRefreshUnitTests.java index 7c4a920584..ee3be32aaf 100644 --- a/src/test/java/io/lettuce/core/cluster/topology/ClusterTopologyRefreshUnitTests.java +++ b/src/test/java/io/lettuce/core/cluster/topology/ClusterTopologyRefreshUnitTests.java @@ -46,6 +46,7 @@ import org.junit.jupiter.api.Tag; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; import org.mockito.junit.jupiter.MockitoSettings; @@ -602,6 +603,35 @@ void shouldPropagateCommandFailures() { } } + @Test + void shouldLimitDiscoveredNodesToOne() { + + List seed = Collections.singletonList(RedisURI.create("foobar", 7380)); + + when(nodeConnectionFactory.connectToNodeAsync(any(RedisCodec.class), + eq(InetSocketAddress.createUnresolved("foobar", 7380)))) + .thenReturn(completedFuture((StatefulRedisConnection) connection1)); + + when(nodeConnectionFactory.connectToNodeAsync(any(RedisCodec.class), + eq(InetSocketAddress.createUnresolved("127.0.0.1", 7380)))) + .thenReturn(completedFuture((StatefulRedisConnection) connection1)); + + when(nodeConnectionFactory.connectToNodeAsync(any(RedisCodec.class), + eq(InetSocketAddress.createUnresolved("127.0.0.1", 7381)))) + .thenReturn(completedFuture((StatefulRedisConnection) connection2)); + + sut.loadViews(seed, Duration.ofSeconds(1), true, 1).toCompletableFuture().join(); + + ArgumentCaptor captor = ArgumentCaptor.forClass(InetSocketAddress.class); + + verify(nodeConnectionFactory, times(2)).connectToNodeAsync(any(RedisCodec.class), captor.capture()); + + List called = captor.getAllValues(); + assertThat(called).contains(InetSocketAddress.createUnresolved("foobar", 7380)); + + assertThat(called).anyMatch(a -> a.getHostString().equals("127.0.0.1") && (a.getPort() == 7380 || a.getPort() == 7381)); + } + Requests createClusterNodesRequests(int duration, String nodes) { RedisURI redisURI = RedisURI.create("redis://localhost:" + duration);