From 26d7859c49fcc6c69f0afe5ffd8b47f7f850d179 Mon Sep 17 00:00:00 2001 From: Joanna Wang Date: Mon, 6 Apr 2026 22:32:33 -0700 Subject: [PATCH] feat: support periodic preload --- config.go | 94 +++++++++++++++++++++++++++++++++----------------- config_test.go | 92 +++++++++++++++++++++++++++++++++--------------- hot.go | 77 +++++++++++++++++++++++++++++++++++++++-- hot_test.go | 56 +++++++++++++++++++++++------- 4 files changed, 245 insertions(+), 74 deletions(-) diff --git a/config.go b/config.go index bbfa334..c43f630 100644 --- a/config.go +++ b/config.go @@ -1,7 +1,7 @@ package hot import ( - "errors" + "context" "time" "github.com/samber/hot/pkg/base" @@ -67,7 +67,8 @@ type HotCacheConfig[K comparable, V any] struct { cacheName string collectors []metrics.Collector - warmUpFn func() (map[K]V, []K, error) + preloadFn func() (map[K]V, []K, error) + preloadPeriod time.Duration loaderFns LoaderChain[K, V] revalidationLoaderFns LoaderChain[K, V] revalidationErrorPolicy revalidationErrorPolicy @@ -140,37 +141,45 @@ func (cfg HotCacheConfig[K, V]) WithSharding(nbr uint64, fn sharded.Hasher[K]) H return cfg } -// WithWarmUp preloads the cache with data from the provided function. -// This is useful for initializing the cache with frequently accessed data. -func (cfg HotCacheConfig[K, V]) WithWarmUp(fn func() (map[K]V, []K, error)) HotCacheConfig[K, V] { - cfg.warmUpFn = fn +// WithPreload loads the cache with data from the provided loader. +// This is useful for eagerly setting the cache with frequently accessed data. 
+func (cfg HotCacheConfig[K, V]) WithPreload(preloader Preloader[K, V]) HotCacheConfig[K, V] { + if preloader.Timeout > 0 { + cfg.preloadFn = func() (map[K]V, []K, error) { + done := make(chan struct{}, 1) + + var result map[K]V + var missing []K + var err error + + go func() { + result, missing, err = preloader.Fn() + done <- struct{}{} + close(done) + }() + + select { + case <-time.After(preloader.Timeout): + return nil, nil, context.DeadlineExceeded + case <-done: + return result, missing, err + } + } + } else { + cfg.preloadFn = preloader.Fn + } + cfg.preloadPeriod = preloader.Period return cfg } -// WithWarmUpWithTimeout preloads the cache with data from the provided function with a timeout. -// This is useful when the inner callback does not have its own timeout strategy. +// Deprecated: Use [HotCacheConfig.WithPreload] instead. +func (cfg HotCacheConfig[K, V]) WithWarmUp(fn func() (map[K]V, []K, error)) HotCacheConfig[K, V] { + return cfg.WithPreload(Preloader[K, V]{Fn: fn}) +} + +// Deprecated: Use [HotCacheConfig.WithPreload] instead. func (cfg HotCacheConfig[K, V]) WithWarmUpWithTimeout(timeout time.Duration, fn func() (map[K]V, []K, error)) HotCacheConfig[K, V] { - cfg.warmUpFn = func() (map[K]V, []K, error) { - done := make(chan struct{}, 1) - - var result map[K]V - var missing []K - var err error - - go func() { - result, missing, err = fn() - done <- struct{}{} - close(done) - }() - - select { - case <-time.After(timeout): - return nil, nil, errors.New("WarmUp timeout") - case <-done: - return result, missing, err - } - } - return cfg + return cfg.WithPreload(Preloader[K, V]{Fn: fn, Timeout: timeout}) } // WithoutLocking disables mutex for the cache and improves internal performance. @@ -267,9 +276,14 @@ func (cfg HotCacheConfig[K, V]) Build() *HotCache[K, V] { cfg.collectors, ) - if cfg.warmUpFn != nil { + if cfg.preloadFn != nil { + hot.preloadFn = cfg.preloadFn // @TODO: Check error? 
- hot.WarmUp(cfg.warmUpFn) //nolint:errcheck + hot.Preload(cfg.preloadFn) //nolint:errcheck + + if cfg.preloadPeriod > 0 { + hot.StartPeriodicPreload(cfg.preloadPeriod) + } } if cfg.janitorEnabled { @@ -299,3 +313,21 @@ func (cfg *HotCacheConfig[K, V]) buildPrometheusCollector(mode base.CacheMode) f return collector } } + +// Preloader holds the configuration for preloading the cache. +type Preloader[K comparable, V any] struct { + // Fn is the function that returns the data to preload into the cache. + // It returns a map of key-value pairs to load, a list of missing keys, and + // an error if any. + Fn func() (map[K]V, []K, error) + + // Timeout is an optional duration to limit how long the preload function + // can run. + // If zero, no timeout is applied. + Timeout time.Duration + + // Period is an optional duration for periodic cache refresh. + // If set, the preload function will be called repeatedly at this interval + // to refresh the cache. + Period time.Duration +} diff --git a/config_test.go b/config_test.go index cb34a4e..fb85159 100644 --- a/config_test.go +++ b/config_test.go @@ -35,7 +35,7 @@ func TestHotCacheConfig(t *testing.T) { cacheAlgo: LRU, cacheCapacity: 42, missingSharedCache: false, missingCacheAlgo: "", missingCacheCapacity: 0, ttl: 0, stale: 0, jitterLambda: 0, jitterUpperBound: 0, shards: 0, shardingFn: nil, lockingDisabled: false, janitorEnabled: false, prometheusMetricsEnabled: false, cacheName: "", - warmUpFn: nil, loaderFns: nil, revalidationLoaderFns: nil, revalidationErrorPolicy: DropOnError, + preloadFn: nil, loaderFns: nil, revalidationLoaderFns: nil, revalidationErrorPolicy: DropOnError, onEviction: nil, copyOnRead: nil, copyOnWrite: nil, }, opts) @@ -44,7 +44,7 @@ func TestHotCacheConfig(t *testing.T) { cacheAlgo: LRU, cacheCapacity: 42, missingSharedCache: true, missingCacheAlgo: "", missingCacheCapacity: 0, ttl: 0, stale: 0, jitterLambda: 0, jitterUpperBound: 0, shards: 0, shardingFn: nil, lockingDisabled: false, 
janitorEnabled: false, prometheusMetricsEnabled: false, cacheName: "", - warmUpFn: nil, loaderFns: nil, revalidationLoaderFns: nil, revalidationErrorPolicy: DropOnError, + preloadFn: nil, loaderFns: nil, revalidationLoaderFns: nil, revalidationErrorPolicy: DropOnError, onEviction: nil, copyOnRead: nil, copyOnWrite: nil, }, opts) @@ -53,7 +53,7 @@ func TestHotCacheConfig(t *testing.T) { cacheAlgo: LRU, cacheCapacity: 42, missingSharedCache: false, missingCacheAlgo: LFU, missingCacheCapacity: 21, ttl: 0, stale: 0, jitterLambda: 0, jitterUpperBound: 0, shards: 0, shardingFn: nil, lockingDisabled: false, janitorEnabled: false, prometheusMetricsEnabled: false, cacheName: "", - warmUpFn: nil, loaderFns: nil, revalidationLoaderFns: nil, revalidationErrorPolicy: DropOnError, + preloadFn: nil, loaderFns: nil, revalidationLoaderFns: nil, revalidationErrorPolicy: DropOnError, onEviction: nil, copyOnRead: nil, copyOnWrite: nil, }, opts) @@ -65,7 +65,7 @@ func TestHotCacheConfig(t *testing.T) { cacheAlgo: LRU, cacheCapacity: 42, missingSharedCache: false, missingCacheAlgo: LFU, missingCacheCapacity: 21, ttl: 42 * time.Second, stale: 0, jitterLambda: 0, jitterUpperBound: 0, shards: 0, shardingFn: nil, lockingDisabled: false, janitorEnabled: false, prometheusMetricsEnabled: false, cacheName: "", - warmUpFn: nil, loaderFns: nil, revalidationLoaderFns: nil, revalidationErrorPolicy: DropOnError, + preloadFn: nil, loaderFns: nil, revalidationLoaderFns: nil, revalidationErrorPolicy: DropOnError, onEviction: nil, copyOnRead: nil, copyOnWrite: nil, }, opts) @@ -84,7 +84,7 @@ func TestHotCacheConfig(t *testing.T) { cacheAlgo: LRU, cacheCapacity: 42, missingSharedCache: false, missingCacheAlgo: LFU, missingCacheCapacity: 21, ttl: 42 * time.Second, stale: 0, jitterLambda: 2, jitterUpperBound: time.Second, shards: 0, shardingFn: nil, lockingDisabled: false, janitorEnabled: false, prometheusMetricsEnabled: false, cacheName: "", - warmUpFn: nil, loaderFns: nil, revalidationLoaderFns: nil, 
revalidationErrorPolicy: DropOnError, + preloadFn: nil, loaderFns: nil, revalidationLoaderFns: nil, revalidationErrorPolicy: DropOnError, onEviction: nil, copyOnRead: nil, copyOnWrite: nil, }, opts) @@ -96,7 +96,7 @@ func TestHotCacheConfig(t *testing.T) { cacheAlgo: LRU, cacheCapacity: 42, missingSharedCache: false, missingCacheAlgo: LFU, missingCacheCapacity: 21, ttl: 42 * time.Second, stale: 0, jitterLambda: 2, jitterUpperBound: time.Second, shards: 0, shardingFn: nil, lockingDisabled: true, janitorEnabled: false, prometheusMetricsEnabled: false, cacheName: "", - warmUpFn: nil, loaderFns: nil, revalidationLoaderFns: nil, revalidationErrorPolicy: DropOnError, + preloadFn: nil, loaderFns: nil, revalidationLoaderFns: nil, revalidationErrorPolicy: DropOnError, onEviction: nil, copyOnRead: nil, copyOnWrite: nil, }, opts) @@ -105,7 +105,7 @@ func TestHotCacheConfig(t *testing.T) { cacheAlgo: LRU, cacheCapacity: 42, missingSharedCache: false, missingCacheAlgo: LFU, missingCacheCapacity: 21, ttl: 42 * time.Second, stale: 0, jitterLambda: 2, jitterUpperBound: time.Second, shards: 0, shardingFn: nil, lockingDisabled: true, janitorEnabled: true, prometheusMetricsEnabled: false, cacheName: "", - warmUpFn: nil, loaderFns: nil, revalidationLoaderFns: nil, revalidationErrorPolicy: DropOnError, + preloadFn: nil, loaderFns: nil, revalidationLoaderFns: nil, revalidationErrorPolicy: DropOnError, onEviction: nil, copyOnRead: nil, copyOnWrite: nil, }, opts) @@ -159,28 +159,64 @@ func TestWithWarmUpWithTimeout(t *testing.T) { is := assert.New(t) t.Parallel() + // Minimal test for backward compatibility - deprecated API opts := NewHotCache[string, int](LRU, 42) - - // Test successful warmup - warmUpFn := func() (map[string]int, []string, error) { - return map[string]int{"key1": 1, "key2": 2}, []string{"missing1"}, nil + preloadFn := func() (map[string]int, []string, error) { + return map[string]int{"key1": 1}, []string{}, nil } - opts = opts.WithWarmUpWithTimeout(100*time.Millisecond, 
warmUpFn) - is.NotNil(opts.warmUpFn) + opts = opts.WithWarmUpWithTimeout(100*time.Millisecond, preloadFn) + is.NotNil(opts.preloadFn) +} - // Test timeout - slowWarmUpFn := func() (map[string]int, []string, error) { - time.Sleep(200 * time.Millisecond) - return map[string]int{"key1": 1}, []string{}, nil +func TestWithPreload(t *testing.T) { + is := assert.New(t) + t.Parallel() + + counter := 0 + preloadFn := func() (map[string]int, []string, error) { + counter++ + if counter > 5 { + time.Sleep(200 * time.Millisecond) + } + return map[string]int{"key1": counter}, []string{}, nil } - opts = opts.WithWarmUpWithTimeout(50*time.Millisecond, slowWarmUpFn) - result, missing, err := opts.warmUpFn() - is.Nil(result) - is.Nil(missing) - is.Error(err) - is.Contains(err.Error(), "WarmUp timeout") + cache := NewHotCache[string, int](LRU, 42). + WithTTL(500 * time.Millisecond).WithPreload(Preloader[string, int]{ + Fn: preloadFn, + Timeout: 100 * time.Millisecond, + Period: 50 * time.Millisecond, + }).Build() + + // Verify initial preload + v, ok, err := cache.Get("key1") + is.True(ok) + is.NoError(err) + is.Equal(1, v) + + // Wait for periodic refresh + time.Sleep(200 * time.Millisecond) + + // Verify cache was refreshed + v, ok, err = cache.Get("key1") + is.True(ok) + is.NoError(err) + is.Greater(v, 1) + + // Stop periodic preload + cache.StopPeriodicPreload() + vAfterStop, ok, err := cache.Get("key1") + is.True(ok) + is.NoError(err) + + time.Sleep(150 * time.Millisecond) + + // Verify cache was NOT refreshed after stop + vFinal, ok, err := cache.Get("key1") + is.True(ok) + is.NoError(err) + is.Equal(vAfterStop, vFinal, "Cache should not be refreshed after StopPeriodicPreload") } func TestWithEvictionCallback(t *testing.T) { @@ -272,23 +308,23 @@ func TestBuildWithJanitorAndLockingConflict(t *testing.T) { }) } -func TestBuildWithWarmUp(t *testing.T) { +func TestBuildWithPreload(t *testing.T) { is := assert.New(t) t.Parallel() - warmUpFn := func() (map[string]int, []string, 
error) { + preloader := Preloader[string, int]{Fn: func() (map[string]int, []string, error) { return map[string]int{"key1": 1, "key2": 2}, []string{"missing1"}, nil - } + }} is.Panics(func() { _ = NewHotCache[string, int](LRU, 42). - WithWarmUp(warmUpFn). + WithPreload(preloader). Build() }) cache := NewHotCache[string, int](LRU, 42). WithMissingCache(LFU, 21). - WithWarmUp(warmUpFn).Build() + WithPreload(preloader).Build() is.NotNil(cache) } diff --git a/hot.go b/hot.go index 7a84647..596b5f2 100644 --- a/hot.go +++ b/hot.go @@ -70,6 +70,14 @@ type HotCache[K comparable, V any] struct { stopJanitor chan struct{} janitorDone chan struct{} + // preloadMutex protects the periodic preload state + preloadMutex sync.RWMutex + preloadTicker *time.Ticker + preloadStopOnce *sync.Once + stopPreload chan struct{} + preloadDone chan struct{} + preloadFn func() (map[K]V, []K, error) + cache base.InMemoryCache[K, *item[V]] missingSharedCache bool missingCache base.InMemoryCache[K, *item[V]] @@ -508,10 +516,10 @@ func (c *HotCache[K, V]) Len() int { return c.cache.Len() } -// WarmUp preloads the cache with data from the provided loader function. +// Preload preloads the cache with data from the provided loader function. // This is useful for initializing the cache with frequently accessed data. // The loader function should return a map of key-value pairs and a slice of missing keys. -func (c *HotCache[K, V]) WarmUp(loader func() (map[K]V, []K, error)) error { +func (c *HotCache[K, V]) Preload(loader func() (map[K]V, []K, error)) error { if loader == nil { return nil } @@ -535,6 +543,71 @@ func (c *HotCache[K, V]) WarmUp(loader func() (map[K]V, []K, error)) error { return nil } +// Deprecated: Use [HotCache.Preload] instead. +func (c *HotCache[K, V]) WarmUp(loader func() (map[K]V, []K, error)) error { + return c.Preload(loader) +} + +// StartPeriodicPreload starts a background goroutine that periodically preloads +// the cache. 
+// This method is safe to call multiple times but only the first call will start +// the periodic preload. +func (c *HotCache[K, V]) StartPeriodicPreload(period time.Duration) { + c.preloadMutex.Lock() + defer c.preloadMutex.Unlock() + + if c.preloadTicker != nil { + return + } + + if c.preloadFn == nil { + return + } + + c.preloadTicker = time.NewTicker(period) + c.preloadStopOnce = &sync.Once{} + c.stopPreload = make(chan struct{}) + c.preloadDone = make(chan struct{}) + + // Start the periodic preload goroutine + go func() { + defer func() { + c.preloadMutex.Lock() + c.preloadTicker = nil + c.preloadMutex.Unlock() + close(c.preloadDone) + }() + + for { + select { + case <-c.stopPreload: + return + case <-c.preloadTicker.C: + c.Preload(c.preloadFn) //nolint:errcheck + } + } + }() +} + +// StopPeriodicPreload stops the background periodic preload goroutine and +// cleans up resources. +// This method is safe to call multiple times and will wait for the periodic +// preload to fully stop. +func (c *HotCache[K, V]) StopPeriodicPreload() { + c.preloadMutex.RLock() + if c.preloadTicker == nil { + c.preloadMutex.RUnlock() + return + } + c.preloadMutex.RUnlock() + + c.preloadStopOnce.Do(func() { + close(c.stopPreload) + c.preloadTicker.Stop() + <-c.preloadDone + }) +} + // Janitor starts a background goroutine that periodically removes expired items from the cache. // The janitor runs until StopJanitor() is called or the cache is garbage collected. // This method is safe to call multiple times, but only the first call will start the janitor. 
diff --git a/hot_test.go b/hot_test.go index 1357106..a23ef60 100644 --- a/hot_test.go +++ b/hot_test.go @@ -30,7 +30,16 @@ func TestNewHotCache(t *testing.T) { // ttl, stale, jitter cache = newHotCache(safeLru, false, nil, 42_000, 21_000, 2, time.Second, nil, nil, DropOnError, nil, nil, nil, nil) - is.Equal(&HotCache[int, int]{sync.RWMutex{}, nil, nil, nil, nil, safeLru, false, nil, 42_000, 21_000, 2, time.Second, nil, nil, DropOnError, nil, nil, nil, singleflightx.Group[int, int]{}, nil}, cache) + is.Equal(&HotCache[int, int]{ + janitorMutex: sync.RWMutex{}, ticker: nil, stopOnce: nil, stopJanitor: nil, janitorDone: nil, + preloadMutex: sync.RWMutex{}, preloadTicker: nil, preloadStopOnce: nil, stopPreload: nil, preloadDone: nil, preloadFn: nil, + cache: safeLru, missingSharedCache: false, missingCache: nil, + ttlNano: 42_000, staleNano: 21_000, jitterLambda: 2, jitterUpperBound: time.Second, + loaderFns: nil, revalidationLoaderFns: nil, revalidationErrorPolicy: DropOnError, + onEviction: nil, copyOnRead: nil, copyOnWrite: nil, + group: singleflightx.Group[int, int]{}, prometheusCollectors: nil}, + cache, + ) // @TODO: test locks // @TODO: more tests @@ -1224,18 +1233,22 @@ func TestHotCache_Len(t *testing.T) { is.Equal(4, cache.Len()) } -func TestHotCache_WarmUp(t *testing.T) { +func TestHotCache_Preload(t *testing.T) { is := assert.New(t) t.Parallel() + preloader := Preloader[string, int]{ + Fn: func() (map[string]int, []string, error) { + return map[string]int{"a": 1}, []string{"b"}, nil + }, + } + is.Panics(func() { _ = NewHotCache[string, int](LRU, 10). WithCopyOnWrite(func(nb int) int { return nb * 2 }). - WithWarmUp(func() (map[string]int, []string, error) { - return map[string]int{"a": 1}, []string{"b"}, nil - }). + WithPreload(preloader). Build() }) @@ -1244,8 +1257,10 @@ func TestHotCache_WarmUp(t *testing.T) { WithCopyOnWrite(func(nb int) int { return nb * 2 }). 
- WithWarmUp(func() (map[string]int, []string, error) { - return map[string]int{"a": 1}, []string{}, nil + WithPreload(Preloader[string, int]{ + Fn: func() (map[string]int, []string, error) { + return map[string]int{"a": 1}, []string{}, nil + }, }). Build() time.Sleep(5 * time.Millisecond) @@ -1259,9 +1274,7 @@ func TestHotCache_WarmUp(t *testing.T) { WithCopyOnWrite(func(nb int) int { return nb * 2 }). - WithWarmUp(func() (map[string]int, []string, error) { - return map[string]int{"a": 1}, []string{"b"}, nil - }). + WithPreload(preloader). WithMissingSharedCache(). Build() time.Sleep(5 * time.Millisecond) @@ -1279,9 +1292,7 @@ func TestHotCache_WarmUp(t *testing.T) { WithCopyOnWrite(func(nb int) int { return nb * 2 }). - WithWarmUp(func() (map[string]int, []string, error) { - return map[string]int{"a": 1}, []string{"b"}, nil - }). + WithPreload(preloader). WithMissingCache(LRU, 10). Build() time.Sleep(5 * time.Millisecond) @@ -1297,6 +1308,25 @@ func TestHotCache_WarmUp(t *testing.T) { time.Sleep(10 * time.Millisecond) // purge revalidation goroutine } +func TestHotCache_WarmUp(t *testing.T) { + is := assert.New(t) + t.Parallel() + + // Minimal test for backward compatibility - deprecated API + cache := NewHotCache[string, int](LRU, 10). + WithWarmUp(func() (map[string]int, []string, error) { + return map[string]int{"a": 1}, []string{}, nil + }). + Build() + time.Sleep(5 * time.Millisecond) + v, ok, err := cache.Get("a") + is.True(ok) + is.NoError(err) + is.Equal(1, v) + + time.Sleep(10 * time.Millisecond) // purge revalidation goroutine +} + func TestHotCache_Janitor(t *testing.T) { is := assert.New(t) t.Parallel()