diff --git a/libs/common/Collections/IndexedPriorityQueue.cs b/libs/common/Collections/IndexedPriorityQueue.cs new file mode 100644 index 00000000000..8816eb44d92 --- /dev/null +++ b/libs/common/Collections/IndexedPriorityQueue.cs @@ -0,0 +1,323 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Runtime.CompilerServices; + +namespace Garnet.common.Collections +{ + /// + /// In-place updatable min-heap. With methods to access priority in constant time. + /// + public class IndexedPriorityQueue + { + // element -> index in heap + private readonly Dictionary _index; + + private const int DefaultCapacity = 4; + + // binary heap + private (TElement element, TPriority priority)[] _heap = []; + private int _count; + + /// + /// Raw heap access to do iteration really fast + /// + public (TElement element, TPriority priority)[] RawHeap => _heap; + + /// + /// Number of elements in the priority queue. + /// + public int Count => _count; + + /// + /// Creates an IndexedPriorityQueue using the default equality comparer for elements. + /// + public IndexedPriorityQueue() : this(null) { } + + /// + /// Creates an IndexedPriorityQueue using the specified equality comparer for elements. + /// + public IndexedPriorityQueue(IEqualityComparer comparer) + { + _index = new Dictionary(comparer); + } + + /// + /// Determines whether the specified element exists in the priority queue. + /// + /// The element to look up. + /// if the element exists; otherwise, . + public bool Exists(TElement element) => _index.ContainsKey(element); + + /// + /// O(log N) - Enqueue or update the priority of a key + /// + /// + /// + public void EnqueueOrUpdate(TElement element, TPriority value) + { + if (_index.TryGetValue(element, out int idxInHeap)) + { + _index[element] = UpdateHeap(idxInHeap, value); + return; + } + + _index[element] = InsertIntoHeap(element, value); + } + + /// + /// O(log N) - Dequeue Key with Lowest Priority + /// + /// Element with lowest priority + public TElement Dequeue() + { + if (_count == 0) + throw new InvalidOperationException("The queue is empty."); + + return DequeueFromHeap(); + } + + /// + /// + /// O(1) - Try to peek at the element with the lowest priority + /// + /// The element with the lowest priority + /// The priority of the element + /// True if the queue is not empty, otherwise false + public bool TryPeek(out TElement key, out TPriority value) + { + if (_count == 0) + { + key = default!; + value = default!; + return false; + } + (key, value) = _heap[0]; + return true; + } + + /// + /// O(log N) - Change the priority of an element + /// + /// The element whose priority is to be changed + /// The new priority value + public void ChangePriority(TElement key, TPriority newValue) => _index[key] = UpdateHeap(_index[key], newValue); + + /// + /// O(1) - Get the priority of an element + /// + /// The element whose priority is to be retrieved + /// The priority of the element + public void GetPriority(TElement key, out TPriority value) => value = _heap[_index[key]].priority; + + + /// + /// O(1) - Try to get the priority of an element. Returns false if the element is not in the queue. + /// + /// + /// + /// + public bool TryGetPriority(TElement key, out TPriority value) + { + if (_index.TryGetValue(key, out int idxInHeap)) + { + value = _heap[idxInHeap].priority; + return true; + } + value = default!; + return false; + } + + + /// + /// O(log N) - Try to remove an element from the queue. Returns false if the element is not in the queue. + /// + /// + /// + public bool TryRemove(TElement key) + { + if (!_index.TryGetValue(key, out int idxInHeap)) + { + return false; + } + + _index.Remove(key); + _count--; + + if (idxInHeap != _count) + { + _heap[idxInHeap] = _heap[_count]; + _index[_heap[idxInHeap].element] = idxInHeap; + + // Try sifting down, if it doesn't move then try sifting up + if (SiftDown(idxInHeap) == idxInHeap) + { + SiftUp(idxInHeap); + } + } + + _heap[_count] = default; + + if (_heap.Length > DefaultCapacity && _count < _heap.Length / 2) + { + Shrink(); + } + return true; + } + + + // helper - methods + + private int InsertIntoHeap(TElement key, TPriority value) + { + if (_count == _heap.Length) + { + Grow(_count + 1); + } + + _heap[_count] = (key, value); + _index[key] = _count; + _count++; + return SiftUp(_count - 1); + } + + private TElement DequeueFromHeap() + { + TElement element = _heap[0].element; + _index.Remove(element); + _count--; + + if (_count > 0) + { + _heap[0] = _heap[_count]; + _heap[_count] = default; + _index[_heap[0].element] = 0; + SiftDown(0); + } + else + { + _heap[0] = default; + } + + if (_heap.Length > DefaultCapacity && _count < _heap.Length / 2) + { + Shrink(); + } + + return element; + } + + private int UpdateHeap(int idxInHeap, TPriority newValue) + { + TPriority oldValue = _heap[idxInHeap].priority; + TElement element = _heap[idxInHeap].element; + _heap[idxInHeap] = (element, newValue); + + int cmp = Comparer.Default.Compare(newValue, oldValue); + if (cmp < 0) + { + // new priority is smaller – sift up + return SiftUp(idxInHeap); + } + else if (cmp > 0) + { + // new priority is larger – sift down + return SiftDown(idxInHeap); + } + + return idxInHeap; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private int SiftUp(int currIdx) + { + var entry = _heap[currIdx]; + while (currIdx > 0) + { + int parentIdx = GetParentIndex(currIdx); + if (Comparer.Default.Compare(_heap[parentIdx].priority, entry.priority) <= 0) + break; + + _heap[currIdx] = _heap[parentIdx]; + _index[_heap[currIdx].element] = currIdx; + currIdx = parentIdx; + } + _heap[currIdx] = entry; + _index[entry.element] = currIdx; + return currIdx; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private int SiftDown(int currIdx) + { + var entry = _heap[currIdx]; + while (true) + { + int smallerChildIdx = GetLeftChildIndex(currIdx); + if (smallerChildIdx >= _count) + break; + + int rightChildIdx = smallerChildIdx + 1; + if (rightChildIdx < _count && Comparer.Default.Compare(_heap[rightChildIdx].priority, _heap[smallerChildIdx].priority) < 0) + smallerChildIdx = rightChildIdx; + + if (Comparer.Default.Compare(entry.priority, _heap[smallerChildIdx].priority) <= 0) + break; + + _heap[currIdx] = _heap[smallerChildIdx]; + _index[_heap[currIdx].element] = currIdx; + currIdx = smallerChildIdx; + } + _heap[currIdx] = entry; + _index[entry.element] = currIdx; + return currIdx; + } + + private int GetParentIndex(int i) => (i - 1) / 2; + + private int GetLeftChildIndex(int i) => (2 * i) + 1; + + /// + /// Grows the priority queue to match the specified min capacity. + /// + private void Grow(int minCapacity) + { + Debug.Assert(_heap.Length < minCapacity); + + const int GrowFactor = 2; + const int MinimumGrow = 4; + + int newcapacity = GrowFactor * _heap.Length; + + // Allow the queue to grow to maximum possible capacity (~2G elements) before encountering overflow. + // Note that this check works even when _heap.Length overflowed thanks to the (uint) cast + if ((uint)newcapacity > Array.MaxLength) newcapacity = Array.MaxLength; + + // Ensure minimum growth is respected. + newcapacity = Math.Max(newcapacity, _heap.Length + MinimumGrow); + + // If the computed capacity is still less than specified, set to the original argument. + // Capacities exceeding Array.MaxLength will be surfaced as OutOfMemoryException by Array.Resize. + if (newcapacity < minCapacity) newcapacity = minCapacity; + + Array.Resize(ref _heap, newcapacity); + } + + /// + /// Shrinks the backing array when more than half the space is unoccupied. + /// + private void Shrink() + { + int newCapacity = _heap.Length / 2; + newCapacity = Math.Max(newCapacity, DefaultCapacity); + + if (newCapacity < _heap.Length) + { + Array.Resize(ref _heap, newCapacity); + } + } + } +} \ No newline at end of file diff --git a/libs/server/Objects/SortedSet/SortedSetObject.cs b/libs/server/Objects/SortedSet/SortedSetObject.cs index ce20b6e34f5..dd83af0158f 100644 --- a/libs/server/Objects/SortedSet/SortedSetObject.cs +++ b/libs/server/Objects/SortedSet/SortedSetObject.cs @@ -9,6 +9,7 @@ using System.Linq; using System.Runtime.CompilerServices; using Garnet.common; +using Garnet.common.Collections; using Tsavorite.core; namespace Garnet.server @@ -141,8 +142,7 @@ public partial class SortedSetObject : GarnetObjectBase { private readonly SortedSet<(double Score, byte[] Element)> sortedSet; private readonly Dictionary sortedSetDict; - private Dictionary expirationTimes; - private PriorityQueue expirationQueue; + private IndexedPriorityQueue expirationQueue; // Byte #31 is used to denote if key has expiration (1) or not (0) private const int ExpirationBitMask = 1 << 31; @@ -192,8 +192,7 @@ public SortedSetObject(BinaryReader reader) if (expiration > 0) { InitializeExpirationStructures(); - expirationTimes.Add(item, expiration); - expirationQueue.Enqueue(item, expiration); + expirationQueue.EnqueueOrUpdate(item, expiration); UpdateExpirationSize(add: true); } } @@ -208,7 +207,6 @@ public SortedSetObject(SortedSetObject sortedSetObject) { sortedSet = sortedSetObject.sortedSet; sortedSetDict = sortedSetObject.sortedSetDict; - expirationTimes = sortedSetObject.expirationTimes; expirationQueue = sortedSetObject.expirationQueue; } @@ -252,7 +250,7 @@ public override void DoSerialize(BinaryWriter writer) writer.Write(count); foreach (var kvp in sortedSetDict) { - if (expirationTimes is not null && expirationTimes.TryGetValue(kvp.Key, out var expiration)) + if (expirationQueue is not null && expirationQueue.TryGetPriority(kvp.Key, out var expiration)) { writer.Write(kvp.Key.Length | ExpirationBitMask); writer.Write(kvp.Key); @@ -277,7 +275,7 @@ public override void DoSerialize(BinaryWriter writer) /// public void Add(byte[] item, double score) { - DeleteExpiredItems(); + DeleteExpiredItems(bound: 16); sortedSetDict.Add(item, score); _ = sortedSet.Add((score, item)); @@ -295,10 +293,10 @@ public bool Equals(SortedSetObject other) foreach (var key in sortedSetDict) { - if (IsExpired(key.Key) && IsExpired(key.Key)) + if (IsExpired(key.Key) && other.IsExpired(key.Key)) continue; - if (IsExpired(key.Key) || IsExpired(key.Key)) + if (IsExpired(key.Key) || other.IsExpired(key.Key)) return false; if (!other.sortedSetDict.TryGetValue(key.Key, out var otherValue) || key.Value != otherValue) @@ -509,7 +507,7 @@ public static Dictionary CopyDiff(SortedSetObject sortedSetObjec if (sortedSetObject2 == null) { - if (sortedSetObject1.expirationTimes is null) + if (!sortedSetObject1.HasExpirableItems()) { return new Dictionary(sortedSetObject1.sortedSetDict, ByteArrayComparer.Instance); } @@ -543,11 +541,14 @@ public static void InPlaceDiff(Dictionary dict1, SortedSetObject if (sortedSetObject2 != null) { + var keysToRemove = new List(); foreach (var item in dict1) { if (!sortedSetObject2.IsExpired(item.Key) && sortedSetObject2.sortedSetDict.ContainsKey(item.Key)) - _ = dict1.Remove(item.Key); + keysToRemove.Add(item.Key); } + foreach (var key in keysToRemove) + _ = dict1.Remove(key); } } @@ -573,14 +574,18 @@ public bool TryGetScore(byte[] key, out double value) /// The count of elements in the sorted set. public int Count() { - if (!HasExpirableItems()) + if (!HasExpirableItems() || (expirationQueue.TryPeek(out _, out var minExpiration) && minExpiration > DateTimeOffset.UtcNow.Ticks)) return sortedSetDict.Count; var expiredKeysCount = 0; - foreach (var item in expirationTimes) + var rawHeap = expirationQueue.RawHeap; + var heapCount = expirationQueue.Count; + for (int i = 0; i < heapCount; i++) { - if (IsExpired(item.Key)) + if (rawHeap[i].priority < DateTimeOffset.UtcNow.Ticks) + { expiredKeysCount++; + } } return sortedSetDict.Count - expiredKeysCount; } @@ -591,88 +596,72 @@ public int Count() /// The key to check for expiration. /// True if the key is expired; otherwise, false. [MethodImpl(MethodImplOptions.AggressiveInlining)] - public bool IsExpired(byte[] key) => expirationTimes is not null && expirationTimes.TryGetValue(key, out var expiration) && expiration < DateTimeOffset.UtcNow.Ticks; + public bool IsExpired(byte[] key) => expirationQueue is not null && expirationQueue.TryGetPriority(key, out var expiration) && expiration < DateTimeOffset.UtcNow.Ticks; /// /// Determines whether the sorted set has expirable items. /// /// True if the sorted set has expirable items; otherwise, false. [MethodImpl(MethodImplOptions.AggressiveInlining)] - public bool HasExpirableItems() => expirationTimes is not null; + public bool HasExpirableItems() => expirationQueue is not null && expirationQueue.Count > 0; #endregion private void InitializeExpirationStructures() { - if (expirationTimes is null) + if (expirationQueue is null) { - expirationTimes = new Dictionary(ByteArrayComparer.Instance); - expirationQueue = new PriorityQueue(); - HeapMemorySize += MemoryUtils.DictionaryOverhead + MemoryUtils.PriorityQueueOverhead; - // No DiskSize adjustment needed yet; wait until keys are added or removed + expirationQueue = new IndexedPriorityQueue(ByteArrayComparer.Instance); + HeapMemorySize += MemoryUtils.IndexedPriorityQueueOverhead; } } [MethodImpl(MethodImplOptions.AggressiveInlining)] - private void UpdateExpirationSize(bool add, bool includePQ = true) + private void UpdateExpirationSize(bool add) { - // Account for dictionary entry and priority queue entry - var memorySize = IntPtr.Size + sizeof(long) + MemoryUtils.DictionaryEntryOverhead; - if (includePQ) - memorySize += IntPtr.Size + sizeof(long) + MemoryUtils.PriorityQueueEntryOverhead; - if (add) - HeapMemorySize += memorySize; + HeapMemorySize += MemoryUtils.IndexedPriorityQueueEntryOverhead; else { - HeapMemorySize -= memorySize; - Debug.Assert(HeapMemorySize >= MemoryUtils.DictionaryOverhead); + HeapMemorySize -= MemoryUtils.IndexedPriorityQueueEntryOverhead; + Debug.Assert(HeapMemorySize >= MemoryUtils.DictionaryOverhead + MemoryUtils.SortedSetOverhead); } } [MethodImpl(MethodImplOptions.AggressiveInlining)] private void CleanupExpirationStructuresIfEmpty() { - if (expirationTimes.Count != 0) + if (expirationQueue.Count != 0) return; - HeapMemorySize -= (IntPtr.Size + sizeof(long) + MemoryUtils.PriorityQueueEntryOverhead) * expirationQueue.Count; - HeapMemorySize -= MemoryUtils.DictionaryOverhead + MemoryUtils.PriorityQueueOverhead; - expirationTimes = null; + HeapMemorySize -= MemoryUtils.IndexedPriorityQueueOverhead; expirationQueue = null; } [MethodImpl(MethodImplOptions.AggressiveInlining)] - private void DeleteExpiredItems() + private void DeleteExpiredItems(int bound = 0) { - if (expirationTimes is null) + if (expirationQueue is null) return; - DeleteExpiredItemsWorker(); + DeleteExpiredItemsWorker(bound); } - private void DeleteExpiredItemsWorker() + private void DeleteExpiredItemsWorker(int bound) { + int i = 0; while (expirationQueue.TryPeek(out var key, out var expiration) && expiration < DateTimeOffset.UtcNow.Ticks) { - if (expirationTimes.TryGetValue(key, out var actualExpiration) && actualExpiration == expiration) - { - _ = expirationTimes.Remove(key); - _ = expirationQueue.Dequeue(); - UpdateExpirationSize(add: false); - if (sortedSetDict.TryGetValue(key, out var value)) - { - _ = sortedSetDict.Remove(key); - _ = sortedSet.Remove((value, key)); - UpdateSize(key, add: false); - } - } - else - { - // The key was not in expirationTimes. It may have been Remove()d. - _ = expirationQueue.Dequeue(); + if (bound > 0 && i >= bound) + break; - // Adjust memory size for the priority queue entry removal. No DiskSize change needed as it was not in expirationTimes. - HeapMemorySize -= MemoryUtils.PriorityQueueEntryOverhead + IntPtr.Size + sizeof(long); + _ = expirationQueue.Dequeue(); + UpdateExpirationSize(add: false); + if (sortedSetDict.TryGetValue(key, out var value)) + { + _ = sortedSetDict.Remove(key); + _ = sortedSet.Remove((value, key)); + UpdateSize(key, add: false); } + i++; } CleanupExpirationStructuresIfEmpty(); @@ -687,13 +676,14 @@ private int SetExpiration(byte[] key, long expiration, ExpireOption expireOption { _ = sortedSetDict.Remove(key, out var value); _ = sortedSet.Remove((value, key)); + TryRemoveExpiration(key); UpdateSize(key, add: false); return (int)SortedSetExpireResult.KeyAlreadyExpired; } InitializeExpirationStructures(); - if (expirationTimes.TryGetValue(key, out var currentExpiration)) + if (expirationQueue.TryGetPriority(key, out var currentExpiration)) { if (expireOption.HasFlag(ExpireOption.NX) || (expireOption.HasFlag(ExpireOption.GT) && expiration <= currentExpiration) || @@ -702,20 +692,15 @@ private int SetExpiration(byte[] key, long expiration, ExpireOption expireOption return (int)SortedSetExpireResult.ExpireConditionNotMet; } - expirationTimes[key] = expiration; - expirationQueue.Enqueue(key, expiration); - - // LogMemorySize of dictionary entry already accounted for as the key already exists. - // DiskSize of expiration already accounted for as the key already exists in expirationTimes. - HeapMemorySize += IntPtr.Size + sizeof(long) + MemoryUtils.PriorityQueueEntryOverhead; + expirationQueue.EnqueueOrUpdate(key, expiration); + // In-place update — no size change needed } else { if ((expireOption & ExpireOption.XX) == ExpireOption.XX || (expireOption & ExpireOption.GT) == ExpireOption.GT) return (int)SortedSetExpireResult.ExpireConditionNotMet; - expirationTimes[key] = expiration; - expirationQueue.Enqueue(key, expiration); + expirationQueue.EnqueueOrUpdate(key, expiration); UpdateExpirationSize(add: true); } @@ -732,19 +717,17 @@ private int Persist(byte[] key) [MethodImpl(MethodImplOptions.AggressiveInlining)] private bool TryRemoveExpiration(byte[] key) { - if (expirationTimes is null) + if (expirationQueue is null) return false; return TryRemoveExpirationWorker(key); } private bool TryRemoveExpirationWorker(byte[] key) { - if (!expirationTimes.TryGetValue(key, out _)) + if (!expirationQueue.TryRemove(key)) return false; - _ = expirationTimes.Remove(key); - - UpdateExpirationSize(add: false, includePQ: false); + UpdateExpirationSize(add: false); CleanupExpirationStructuresIfEmpty(); return true; } @@ -753,7 +736,7 @@ private long GetExpiration(byte[] key) { if (!sortedSetDict.ContainsKey(key)) return -2; - if (expirationTimes is not null && expirationTimes.TryGetValue(key, out var expiration)) + if (expirationQueue is not null && expirationQueue.TryGetPriority(key, out var expiration)) return expiration; return -1; } diff --git a/libs/server/Objects/SortedSet/SortedSetObjectImpl.cs b/libs/server/Objects/SortedSet/SortedSetObjectImpl.cs index 51fbb325d4e..74055be0dac 100644 --- a/libs/server/Objects/SortedSet/SortedSetObjectImpl.cs +++ b/libs/server/Objects/SortedSet/SortedSetObjectImpl.cs @@ -89,7 +89,7 @@ bool GetOptions(ref ObjectInput input, ref int currTokenIdx, out SortedSetAddOpt private void SortedSetAdd(ref ObjectInput input, ref ObjectOutput output, byte respProtocolVersion) { - DeleteExpiredItems(); + DeleteExpiredItems(bound: 16); var addedOrChanged = 0; double incrResult = 0; @@ -210,7 +210,7 @@ private void SortedSetAdd(ref ObjectInput input, ref ObjectOutput output, byte r private void SortedSetRemove(ref ObjectInput input, ref ObjectOutput output) { - DeleteExpiredItems(); + DeleteExpiredItems(bound: 16); for (var i = 0; i < input.parseState.Count; i++) { @@ -316,7 +316,7 @@ private void SortedSetCount(ref ObjectInput input, ref ObjectOutput output, byte private void SortedSetIncrement(ref ObjectInput input, ref ObjectOutput output, byte respProtocolVersion) { - DeleteExpiredItems(); + DeleteExpiredItems(bound: 16); // It's useful to fix RESP2 in the internal API as that just reads back the output. if (input.arg2 > 0) @@ -499,7 +499,7 @@ private void SortedSetRange(ref ObjectInput input, ref ObjectOutput output, byte var n = maxIndex - minIndex + 1; var iterator = options.Reverse ? sortedSet.Reverse() : sortedSet; - if (expirationTimes is not null) + if (HasExpirableItems()) { iterator = iterator.Where(x => !IsExpired(x.Element)); } @@ -565,7 +565,7 @@ void WriteSortedSetResult(bool withScores, int count, byte respProtocolVersion, private void SortedSetRemoveRangeByRank(ref ObjectInput input, ref ObjectOutput output, byte respProtocolVersion) { - DeleteExpiredItems(); + DeleteExpiredItems(bound: 16); using var writer = new RespMemoryWriter(respProtocolVersion, ref output.SpanByteAndMemory); @@ -607,7 +607,7 @@ private void SortedSetRemoveRangeByRank(ref ObjectInput input, ref ObjectOutput private void SortedSetRemoveRangeByScore(ref ObjectInput input, ref ObjectOutput output, byte respProtocolVersion) { - DeleteExpiredItems(); + DeleteExpiredItems(bound: 16); using var writer = new RespMemoryWriter(respProtocolVersion, ref output.SpanByteAndMemory); @@ -690,7 +690,7 @@ private void SortedSetRemoveOrCountRangeByLex(ref ObjectInput input, ref ObjectO if (isRemove) { - DeleteExpiredItems(); + DeleteExpiredItems(bound: 16); } var rem = GetElementsInRangeByLex(minParamBytes, maxParamBytes, false, false, isRemove, out int errorCode); @@ -840,7 +840,7 @@ private void SortedSetPopMinOrMaxCount(ref ObjectInput input, ref ObjectOutput o private void SortedSetPersist(ref ObjectInput input, ref ObjectOutput output, byte respProtocolVersion) { - DeleteExpiredItems(); + DeleteExpiredItems(bound: 16); var numFields = input.parseState.Count; @@ -904,7 +904,7 @@ private void SortedSetTimeToLive(ref ObjectInput input, ref ObjectOutput output, private void SortedSetExpire(ref ObjectInput input, ref ObjectOutput output, byte respProtocolVersion) { - DeleteExpiredItems(); + DeleteExpiredItems(bound: 16); var expirationWithOption = new ExpirationWithOption(input.arg1, input.arg2); diff --git a/libs/storage/Tsavorite/cs/src/core/Utilities/MemoryUtils.cs b/libs/storage/Tsavorite/cs/src/core/Utilities/MemoryUtils.cs index 045a27d4fbc..02e86e56976 100644 --- a/libs/storage/Tsavorite/cs/src/core/Utilities/MemoryUtils.cs +++ b/libs/storage/Tsavorite/cs/src/core/Utilities/MemoryUtils.cs @@ -43,6 +43,18 @@ public static class MemoryUtils /// .Net object avg. overhead for holding a priority queue entry public const int PriorityQueueEntryOverhead = 48; + /// + /// .Net object overhead for IndexedPriorityQueue (Dictionary + array + count). + /// Dictionary(80) + array object header(24) + int(4) ≈ 108, rounded to 112. + /// + public const int IndexedPriorityQueueOverhead = 112; + + /// + /// .Net object avg. overhead per entry in IndexedPriorityQueue. + /// Dictionary entry(64) + heap array slot (ref 8 + long 8 = 16) = 80. + /// + public const int IndexedPriorityQueueEntryOverhead = 80; + /// This is but that is a static expression, not a constant public const int ArrayMaxLength = 0x7FFFFFC7; diff --git a/test/Garnet.test/IndexedPriorityQueueTests.cs b/test/Garnet.test/IndexedPriorityQueueTests.cs new file mode 100644 index 00000000000..16e1dd0d101 --- /dev/null +++ b/test/Garnet.test/IndexedPriorityQueueTests.cs @@ -0,0 +1,423 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +using System; +using System.Collections.Generic; +using Allure.NUnit; +using Garnet.common.Collections; +using NUnit.Framework; +using NUnit.Framework.Legacy; + +namespace Garnet.test +{ + [AllureNUnit] + [TestFixture] + public class IndexedPriorityQueueTests : AllureTestBase + { + #region Basic Operations + + [Test] + public void EmptyQueueHasCountZero() + { + var q = new IndexedPriorityQueue(); + ClassicAssert.AreEqual(0, q.Count); + } + + [Test] + public void EnqueueIncreasesCount() + { + var q = new IndexedPriorityQueue(); + q.EnqueueOrUpdate("a", 10); + ClassicAssert.AreEqual(1, q.Count); + q.EnqueueOrUpdate("b", 20); + ClassicAssert.AreEqual(2, q.Count); + } + + [Test] + public void DequeueReturnsMinPriority() + { + var q = new IndexedPriorityQueue(); + q.EnqueueOrUpdate("high", 100); + q.EnqueueOrUpdate("low", 1); + q.EnqueueOrUpdate("mid", 50); + + ClassicAssert.AreEqual("low", q.Dequeue()); + ClassicAssert.AreEqual("mid", q.Dequeue()); + ClassicAssert.AreEqual("high", q.Dequeue()); + } + + [Test] + public void DequeueOnEmptyThrows() + { + var q = new IndexedPriorityQueue(); + Assert.Throws(() => q.Dequeue()); + } + + [Test] + public void TryPeekReturnsFalseWhenEmpty() + { + var q = new IndexedPriorityQueue(); + ClassicAssert.IsFalse(q.TryPeek(out _, out _)); + } + + [Test] + public void TryPeekReturnsMinWithoutRemoving() + { + var q = new IndexedPriorityQueue(); + q.EnqueueOrUpdate("a", 5); + q.EnqueueOrUpdate("b", 3); + + ClassicAssert.IsTrue(q.TryPeek(out var key, out var priority)); + ClassicAssert.AreEqual("b", key); + ClassicAssert.AreEqual(3, priority); + ClassicAssert.AreEqual(2, q.Count); + } + + [Test] + public void ExistsReturnsTrueForEnqueuedElement() + { + var q = new IndexedPriorityQueue(); + q.EnqueueOrUpdate("a", 10); + + ClassicAssert.IsTrue(q.Exists("a")); + ClassicAssert.IsFalse(q.Exists("b")); + } + + #endregion + + #region In-Place Update + + [Test] + public void EnqueueOrUpdateUpdatesExistingElement() + { + var q = new IndexedPriorityQueue(); + q.EnqueueOrUpdate("a", 100); + q.EnqueueOrUpdate("b", 50); + + // Update "a" to lower priority — should become the new min + q.EnqueueOrUpdate("a", 1); + + ClassicAssert.AreEqual(2, q.Count, "Update should not add a new entry"); + ClassicAssert.IsTrue(q.TryPeek(out var key, out var priority)); + ClassicAssert.AreEqual("a", key); + ClassicAssert.AreEqual(1, priority); + } + + [Test] + public void ChangePriorityMovesElementDown() + { + var q = new IndexedPriorityQueue(); + q.EnqueueOrUpdate("a", 1); + q.EnqueueOrUpdate("b", 10); + q.EnqueueOrUpdate("c", 20); + + // Move "a" to highest priority value — "b" should become min + q.ChangePriority("a", 100); + + q.TryPeek(out var key, out _); + ClassicAssert.AreEqual("b", key); + } + + [Test] + public void ChangePriorityMovesElementUp() + { + var q = new IndexedPriorityQueue(); + q.EnqueueOrUpdate("a", 50); + q.EnqueueOrUpdate("b", 10); + q.EnqueueOrUpdate("c", 30); + + // Move "a" to lowest — should become min + q.ChangePriority("a", 1); + + q.TryPeek(out var key, out _); + ClassicAssert.AreEqual("a", key); + } + + [Test] + public void RepeatedUpdatesDoNotBloatCount() + { + var q = new IndexedPriorityQueue(); + q.EnqueueOrUpdate("a", 10); + + for (int i = 0; i < 100; i++) + { + q.EnqueueOrUpdate("a", i); + } + + ClassicAssert.AreEqual(1, q.Count, "Repeated updates should not increase count"); + } + + #endregion + + #region Priority Lookup + + [Test] + public void GetPriorityReturnsCurrentPriority() + { + var q = new IndexedPriorityQueue(); + q.EnqueueOrUpdate("a", 42); + + q.GetPriority("a", out var priority); + ClassicAssert.AreEqual(42, priority); + } + + [Test] + public void TryGetPriorityReturnsFalseForMissing() + { + var q = new IndexedPriorityQueue(); + ClassicAssert.IsFalse(q.TryGetPriority("missing", out _)); + } + + [Test] + public void TryGetPriorityReflectsUpdates() + { + var q = new IndexedPriorityQueue(); + q.EnqueueOrUpdate("a", 10); + q.EnqueueOrUpdate("a", 99); + + ClassicAssert.IsTrue(q.TryGetPriority("a", out var priority)); + ClassicAssert.AreEqual(99, priority); + } + + #endregion + + #region Removal + + [Test] + public void TryRemoveReturnsFalseForMissing() + { + var q = new IndexedPriorityQueue(); + ClassicAssert.IsFalse(q.TryRemove("missing")); + } + + [Test] + public void TryRemoveRemovesElement() + { + var q = new IndexedPriorityQueue(); + q.EnqueueOrUpdate("a", 10); + q.EnqueueOrUpdate("b", 20); + + ClassicAssert.IsTrue(q.TryRemove("a")); + ClassicAssert.AreEqual(1, q.Count); + ClassicAssert.IsFalse(q.Exists("a")); + ClassicAssert.AreEqual("b", q.Dequeue()); + } + + [Test] + public void TryRemoveMiddleElementMaintainsHeapOrder() + { + var q = new IndexedPriorityQueue(); + q.EnqueueOrUpdate("a", 1); + q.EnqueueOrUpdate("b", 5); + q.EnqueueOrUpdate("c", 3); + q.EnqueueOrUpdate("d", 10); + q.EnqueueOrUpdate("e", 7); + + q.TryRemove("c"); + + // Remaining should dequeue in order: a(1), b(5), e(7), d(10) + ClassicAssert.AreEqual("a", q.Dequeue()); + ClassicAssert.AreEqual("b", q.Dequeue()); + ClassicAssert.AreEqual("e", q.Dequeue()); + ClassicAssert.AreEqual("d", q.Dequeue()); + } + + [Test] + public void TryRemoveLastElement() + { + var q = new IndexedPriorityQueue(); + q.EnqueueOrUpdate("only", 1); + + ClassicAssert.IsTrue(q.TryRemove("only")); + ClassicAssert.AreEqual(0, q.Count); + ClassicAssert.IsFalse(q.TryPeek(out _, out _)); + } + + #endregion + + #region Custom Comparer (byte[] keys) + + [Test] + public void ByteArrayComparerMatchesByContent() + { + var comparer = new ByteArrayComparer(); + var q = new IndexedPriorityQueue(comparer); + + var key1 = new byte[] { 1, 2, 3 }; + var key1Copy = new byte[] { 1, 2, 3 }; + var key2 = new byte[] { 4, 5, 6 }; + + q.EnqueueOrUpdate(key1, 100); + q.EnqueueOrUpdate(key2, 200); + + // Update using a different byte[] with same content + q.EnqueueOrUpdate(key1Copy, 50); + + ClassicAssert.AreEqual(2, q.Count, "Should match by content, not reference"); + + q.TryPeek(out var minKey, out var minPriority); + ClassicAssert.AreEqual(50, minPriority); + ClassicAssert.IsTrue(comparer.Equals(key1, minKey)); + } + + [Test] + public void ByteArrayComparerTryGetPriorityByContent() + { + var comparer = new ByteArrayComparer(); + var q = new IndexedPriorityQueue(comparer); + + var key = new byte[] { 10, 20 }; + var keyCopy = new byte[] { 10, 20 }; + + q.EnqueueOrUpdate(key, 42); + + ClassicAssert.IsTrue(q.TryGetPriority(keyCopy, out var priority)); + ClassicAssert.AreEqual(42, priority); + } + + [Test] + public void ByteArrayComparerTryRemoveByContent() + { + var comparer = new ByteArrayComparer(); + var q = new IndexedPriorityQueue(comparer); + + var key = new byte[] { 7, 8, 9 }; + var keyCopy = new byte[] { 7, 8, 9 }; + + q.EnqueueOrUpdate(key, 100); + ClassicAssert.IsTrue(q.TryRemove(keyCopy)); + ClassicAssert.AreEqual(0, q.Count); + } + + [Test] + public void WithoutComparerByteArrayUsesReferenceEquality() + { + var q = new IndexedPriorityQueue(); + + var key1 = new byte[] { 1, 2, 3 }; + var key1Copy = new byte[] { 1, 2, 3 }; + + q.EnqueueOrUpdate(key1, 100); + q.EnqueueOrUpdate(key1Copy, 50); + + // Without comparer, these are different keys + ClassicAssert.AreEqual(2, q.Count, "Default comparer uses reference equality for byte[]"); + } + + #endregion + + #region Stress / Ordering + + [Test] + public void LargeInsertDequeueMaintainsOrder() + { + var q = new IndexedPriorityQueue(); + var rng = new Random(42); + var count = 1000; + + for (int i = 0; i < count; i++) + { + q.EnqueueOrUpdate(i, rng.Next(0, 100000)); + } + + ClassicAssert.AreEqual(count, q.Count); + + int prev = int.MinValue; + while (q.Count > 0) + { + q.TryPeek(out _, out var priority); + ClassicAssert.GreaterOrEqual(priority, prev, "Dequeue order should be non-decreasing"); + prev = priority; + q.Dequeue(); + } + } + + [Test] + public void InterleavedInsertUpdateRemoveDequeue() + { + var q = new IndexedPriorityQueue(); + + q.EnqueueOrUpdate("a", 50); + q.EnqueueOrUpdate("b", 30); + q.EnqueueOrUpdate("c", 70); + q.EnqueueOrUpdate("d", 10); + + // Update + q.EnqueueOrUpdate("c", 5); // c becomes min + q.TryPeek(out var min, out _); + ClassicAssert.AreEqual("c", min); + + // Remove min + q.TryRemove("c"); + q.TryPeek(out min, out _); + ClassicAssert.AreEqual("d", min); + + // Add new + q.EnqueueOrUpdate("e", 1); + q.TryPeek(out min, out _); + ClassicAssert.AreEqual("e", min); + + // Drain and verify order + ClassicAssert.AreEqual("e", q.Dequeue()); // 1 + ClassicAssert.AreEqual("d", q.Dequeue()); // 10 + ClassicAssert.AreEqual("b", q.Dequeue()); // 30 + ClassicAssert.AreEqual("a", q.Dequeue()); // 50 + ClassicAssert.AreEqual(0, q.Count); + } + + [Test] + public void GrowAndShrinkBehavior() + { + var q = new IndexedPriorityQueue(); + + // Grow + for (int i = 0; i < 100; i++) + q.EnqueueOrUpdate(i, i); + + ClassicAssert.AreEqual(100, q.Count); + + // Shrink by removing most + for (int i = 0; i < 90; i++) + q.Dequeue(); + + ClassicAssert.AreEqual(10, q.Count); + + // Remaining should still be ordered + int prev = int.MinValue; + while (q.Count > 0) + { + q.TryPeek(out _, out var p); + ClassicAssert.GreaterOrEqual(p, prev); + prev = p; + q.Dequeue(); + } + } + + #endregion + + /// + /// Simple byte[] equality comparer for tests. + /// + private class ByteArrayComparer : IEqualityComparer + { + public bool Equals(byte[] x, byte[] y) + { + if (x == null && y == null) return true; + if (x == null || y == null) return false; + if (x.Length != y.Length) return false; + for (int i = 0; i < x.Length; i++) + if (x[i] != y[i]) return false; + return true; + } + + public int GetHashCode(byte[] obj) + { + if (obj == null) return 0; + int hash = 17; + foreach (var b in obj) + hash = hash * 31 + b; + return hash; + } + } + } +} \ No newline at end of file