Skip to content

Commit cb0e52b

Browse files
authored
Fix memory leak issue on thread local in sharedLock #3633 (#3640)
1 parent 1ef3c73 commit cb0e52b

File tree

2 files changed

+186
-5
lines changed

2 files changed

+186
-5
lines changed

src/main/java/io/lettuce/core/protocol/SharedLock.java

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
package io.lettuce.core.protocol;
22

3+
import java.util.WeakHashMap;
34
import java.util.concurrent.atomic.AtomicLongFieldUpdater;
45
import java.util.concurrent.locks.Lock;
56
import java.util.concurrent.locks.ReentrantLock;
@@ -16,6 +17,15 @@
1617
* Exclusive locking is reentrant. An exclusive lock owner is permitted to acquire and release shared locks. Shared/exclusive
1718
* lock requests by other threads than the thread which holds the exclusive lock, are forced to wait until the exclusive lock is
1819
* released.
20+
* <p>
21+
* <b>Memory Management:</b> This implementation uses a static {@link ThreadLocal} containing a {@link WeakHashMap} to track
22+
* per-thread writer counts across all {@code SharedLock} instances. This design:
23+
* <ul>
24+
* <li>Creates only ONE ThreadLocal entry per thread (not per SharedLock instance)</li>
25+
* <li>Uses WeakHashMap so entries are automatically removed when SharedLock instances are garbage collected</li>
26+
* <li>Explicitly removes entries when writer count reaches zero for immediate cleanup</li>
27+
* <li>Eliminates the memory leak that occurred with per-instance ThreadLocal in connection pooling scenarios</li>
28+
* </ul>
1929
*
2030
* @author Mark Paluch
2131
*/
@@ -26,7 +36,8 @@ class SharedLock {
2636

2737
private final Lock lock = new ReentrantLock();
2838

29-
private final ThreadLocal<Integer> threadWriters = ThreadLocal.withInitial(() -> 0);
39+
private static final ThreadLocal<WeakHashMap<SharedLock, Integer>> THREAD_WRITERS = ThreadLocal
40+
.withInitial(WeakHashMap::new);
3041

3142
private volatile long writers = 0;
3243

@@ -47,7 +58,8 @@ void incrementWriters() {
4758

4859
if (WRITERS.get(this) >= 0) {
4960
WRITERS.incrementAndGet(this);
50-
threadWriters.set(threadWriters.get() + 1);
61+
WeakHashMap<SharedLock, Integer> map = THREAD_WRITERS.get();
62+
map.merge(this, 1, Integer::sum);
5163
return;
5264
}
5365
}
@@ -66,7 +78,8 @@ void decrementWriters() {
6678
}
6779

6880
WRITERS.decrementAndGet(this);
69-
threadWriters.set(threadWriters.get() - 1);
81+
WeakHashMap<SharedLock, Integer> map = THREAD_WRITERS.get();
82+
map.computeIfPresent(this, (lock, count) -> count <= 1 ? null : count - 1);
7083
}
7184

7285
/**
@@ -126,7 +139,8 @@ private void lockWritersExclusive() {
126139
for (;;) {
127140

128141
// allow reentrant exclusive lock by comparing writers count and threadWriters count
129-
if (WRITERS.compareAndSet(this, threadWriters.get(), -1)) {
142+
int threadWriterCount = getThreadWriterCount();
143+
if (WRITERS.compareAndSet(this, threadWriterCount, -1)) {
130144
exclusiveLockOwner = Thread.currentThread();
131145
return;
132146
}
@@ -142,8 +156,10 @@ private void lockWritersExclusive() {
142156
private void unlockWritersExclusive() {
143157

144158
if (exclusiveLockOwner == Thread.currentThread()) {
159+
int threadWriterCount = getThreadWriterCount();
160+
145161
// check exclusive look not reentrant first
146-
if (WRITERS.compareAndSet(this, -1, threadWriters.get())) {
162+
if (WRITERS.compareAndSet(this, -1, threadWriterCount)) {
147163
exclusiveLockOwner = null;
148164
return;
149165
}
@@ -152,4 +168,8 @@ private void unlockWritersExclusive() {
152168
}
153169
}
154170

171+
private int getThreadWriterCount() {
172+
return THREAD_WRITERS.get().getOrDefault(this, 0);
173+
}
174+
155175
}

src/test/java/io/lettuce/core/protocol/SharedLockUnitTests.java

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,13 @@
44
import org.junit.jupiter.api.Tag;
55
import org.junit.jupiter.api.Test;
66

7+
import java.lang.reflect.Field;
8+
import java.util.WeakHashMap;
79
import java.util.concurrent.CountDownLatch;
810
import java.util.concurrent.TimeUnit;
911

1012
import static io.lettuce.TestTags.UNIT_TEST;
13+
import java.util.concurrent.atomic.AtomicInteger;
1114

1215
@Tag(UNIT_TEST)
1316
public class SharedLockUnitTests {
@@ -58,4 +61,162 @@ public void safety_on_reentrant_lock_exclusive_on_writers() throws InterruptedEx
5861
Assertions.assertTrue(await);
5962
}
6063

64+
@Test
65+
public void writerCountsAreIndependentPerSharedLockInstance() {
66+
final SharedLock lock1 = new SharedLock();
67+
final SharedLock lock2 = new SharedLock();
68+
69+
// Increment writers on lock1
70+
lock1.incrementWriters();
71+
72+
// lock2 should still be able to get exclusive lock (its writer count is 0)
73+
String result = lock2.doExclusive(() -> "exclusive-on-lock2");
74+
Assertions.assertEquals("exclusive-on-lock2", result);
75+
76+
// Cleanup
77+
lock1.decrementWriters();
78+
79+
// Now lock1 should also work
80+
result = lock1.doExclusive(() -> "exclusive-on-lock1");
81+
Assertions.assertEquals("exclusive-on-lock1", result);
82+
}
83+
84+
@Test
85+
@SuppressWarnings("unchecked")
86+
public void entryShouldBeRemovedWhenWriterCountReachesZero() throws Exception {
87+
final SharedLock sharedLock = new SharedLock();
88+
89+
// Access the static THREAD_WRITERS field
90+
Field threadWritersField = SharedLock.class.getDeclaredField("THREAD_WRITERS");
91+
threadWritersField.setAccessible(true);
92+
ThreadLocal<WeakHashMap<SharedLock, Integer>> threadLocal = (ThreadLocal<WeakHashMap<SharedLock, Integer>>) threadWritersField
93+
.get(null);
94+
95+
WeakHashMap<SharedLock, Integer> map = threadLocal.get();
96+
97+
// Initially, the map should not contain this SharedLock
98+
Assertions.assertFalse(map.containsKey(sharedLock), "Map should not contain SharedLock initially");
99+
100+
// Increment should add an entry
101+
sharedLock.incrementWriters();
102+
Assertions.assertTrue(map.containsKey(sharedLock), "Map should contain SharedLock after increment");
103+
Assertions.assertEquals(1, map.get(sharedLock), "Writer count should be 1");
104+
105+
// Decrement to zero should remove the entry
106+
sharedLock.decrementWriters();
107+
Assertions.assertFalse(map.containsKey(sharedLock), "Map entry should be removed when count reaches zero");
108+
}
109+
110+
@Test
111+
@SuppressWarnings("unchecked")
112+
public void nestedWriterCountsShouldWorkCorrectly() throws Exception {
113+
final SharedLock sharedLock = new SharedLock();
114+
115+
// Access the static THREAD_WRITERS field
116+
Field threadWritersField = SharedLock.class.getDeclaredField("THREAD_WRITERS");
117+
threadWritersField.setAccessible(true);
118+
ThreadLocal<WeakHashMap<SharedLock, Integer>> threadLocal = (ThreadLocal<WeakHashMap<SharedLock, Integer>>) threadWritersField
119+
.get(null);
120+
121+
WeakHashMap<SharedLock, Integer> map = threadLocal.get();
122+
123+
// Nested increments
124+
sharedLock.incrementWriters();
125+
Assertions.assertEquals(1, map.get(sharedLock));
126+
127+
sharedLock.incrementWriters();
128+
Assertions.assertEquals(2, map.get(sharedLock));
129+
130+
sharedLock.incrementWriters();
131+
Assertions.assertEquals(3, map.get(sharedLock));
132+
133+
// Decrements - entry should NOT be removed until count reaches 0
134+
sharedLock.decrementWriters();
135+
Assertions.assertEquals(2, map.get(sharedLock));
136+
137+
sharedLock.decrementWriters();
138+
Assertions.assertEquals(1, map.get(sharedLock));
139+
140+
// Final decrement to zero - entry should be removed
141+
sharedLock.decrementWriters();
142+
Assertions.assertFalse(map.containsKey(sharedLock), "Entry should be removed when count reaches zero");
143+
}
144+
145+
@Test
146+
public void multipleThreadsShouldNotInterfere() throws InterruptedException {
147+
final SharedLock sharedLock = new SharedLock();
148+
final int threadCount = 100;
149+
final CountDownLatch startLatch = new CountDownLatch(1);
150+
final CountDownLatch completionLatch = new CountDownLatch(threadCount);
151+
final AtomicInteger errorCount = new AtomicInteger(0);
152+
153+
// Create multiple threads that will use the SharedLock
154+
for (int i = 0; i < threadCount; i++) {
155+
new Thread(() -> {
156+
try {
157+
startLatch.await();
158+
159+
// Each thread increments and decrements multiple times
160+
for (int j = 0; j < 10; j++) {
161+
sharedLock.incrementWriters();
162+
sharedLock.decrementWriters();
163+
}
164+
} catch (Exception e) {
165+
errorCount.incrementAndGet();
166+
} finally {
167+
completionLatch.countDown();
168+
}
169+
}).start();
170+
}
171+
172+
// Start all threads at once
173+
startLatch.countDown();
174+
175+
// Wait for all threads to complete
176+
boolean completed = completionLatch.await(10, TimeUnit.SECONDS);
177+
Assertions.assertTrue(completed, "All threads should complete within timeout");
178+
Assertions.assertEquals(0, errorCount.get(), "No errors should occur during concurrent operations");
179+
180+
// After all threads complete, the SharedLock should be in a clean state
181+
String result = sharedLock.doExclusive(() -> "success");
182+
Assertions.assertEquals("success", result);
183+
}
184+
185+
@Test
186+
@SuppressWarnings("unchecked")
187+
public void singleThreadLocalEntryPerThread() throws Exception {
188+
// Access the static THREAD_WRITERS field
189+
Field threadWritersField = SharedLock.class.getDeclaredField("THREAD_WRITERS");
190+
threadWritersField.setAccessible(true);
191+
ThreadLocal<WeakHashMap<SharedLock, Integer>> threadLocal = (ThreadLocal<WeakHashMap<SharedLock, Integer>>) threadWritersField
192+
.get(null);
193+
194+
// Get the map for this thread
195+
WeakHashMap<SharedLock, Integer> map1 = threadLocal.get();
196+
WeakHashMap<SharedLock, Integer> map2 = threadLocal.get();
197+
198+
// Should be the SAME map instance
199+
Assertions.assertSame(map1, map2, "ThreadLocal should return the same map instance");
200+
201+
// Create multiple SharedLocks
202+
SharedLock lock1 = new SharedLock();
203+
SharedLock lock2 = new SharedLock();
204+
SharedLock lock3 = new SharedLock();
205+
206+
lock1.incrementWriters();
207+
lock2.incrementWriters();
208+
lock3.incrementWriters();
209+
210+
// All should be in the SAME map
211+
WeakHashMap<SharedLock, Integer> map = threadLocal.get();
212+
Assertions.assertTrue(map.containsKey(lock1));
213+
Assertions.assertTrue(map.containsKey(lock2));
214+
Assertions.assertTrue(map.containsKey(lock3));
215+
216+
// Cleanup
217+
lock1.decrementWriters();
218+
lock2.decrementWriters();
219+
lock3.decrementWriters();
220+
}
221+
61222
}

0 commit comments

Comments
 (0)