Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 63 additions & 31 deletions config.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package hot

import (
"errors"
"context"
"time"

"github.com/samber/hot/pkg/base"
Expand Down Expand Up @@ -67,7 +67,8 @@ type HotCacheConfig[K comparable, V any] struct {
cacheName string
collectors []metrics.Collector

warmUpFn func() (map[K]V, []K, error)
preloadFn func() (map[K]V, []K, error)
preloadPeriod time.Duration
loaderFns LoaderChain[K, V]
revalidationLoaderFns LoaderChain[K, V]
revalidationErrorPolicy revalidationErrorPolicy
Expand Down Expand Up @@ -140,37 +141,45 @@ func (cfg HotCacheConfig[K, V]) WithSharding(nbr uint64, fn sharded.Hasher[K]) H
return cfg
}

// WithWarmUp preloads the cache with data from the provided function.
// This is useful for initializing the cache with frequently accessed data.
func (cfg HotCacheConfig[K, V]) WithWarmUp(fn func() (map[K]V, []K, error)) HotCacheConfig[K, V] {
cfg.warmUpFn = fn
// WithPreload loads the cache with data from the provided loader.
// This is useful for eagerly set the cache with frequently access data.
func (cfg HotCacheConfig[K, V]) WithPreload(preloader Preloader[K, V]) HotCacheConfig[K, V] {
if preloader.Timeout > 0 {
cfg.preloadFn = func() (map[K]V, []K, error) {
done := make(chan struct{}, 1)

var result map[K]V
var missing []K
var err error

go func() {
result, missing, err = preloader.Fn()
done <- struct{}{}
close(done)
}()

select {
case <-time.After(preloader.Timeout):
return nil, nil, context.DeadlineExceeded
case <-done:
return result, missing, err
}
}
} else {
cfg.preloadFn = preloader.Fn
}
cfg.preloadPeriod = preloader.Period
return cfg
}

// WithWarmUpWithTimeout preloads the cache with data from the provided function with a timeout.
// This is useful when the inner callback does not have its own timeout strategy.
// WithWarmUp preloads the cache with data from the provided function.
// This is useful for initializing the cache with frequently accessed data.
//
// Deprecated: Use [HotCacheConfig.WithPreload] instead.
func (cfg HotCacheConfig[K, V]) WithWarmUp(fn func() (map[K]V, []K, error)) HotCacheConfig[K, V] {
	return cfg.WithPreload(Preloader[K, V]{Fn: fn})
}

// WithWarmUpWithTimeout preloads the cache with data from the provided
// function, aborting the load if it runs longer than the given timeout.
// This is useful when the inner callback does not have its own timeout strategy.
//
// Deprecated: Use [HotCacheConfig.WithPreload] instead.
func (cfg HotCacheConfig[K, V]) WithWarmUpWithTimeout(timeout time.Duration, fn func() (map[K]V, []K, error)) HotCacheConfig[K, V] {
	return cfg.WithPreload(Preloader[K, V]{Fn: fn, Timeout: timeout})
}

// WithoutLocking disables mutex for the cache and improves internal performance.
Expand Down Expand Up @@ -267,9 +276,14 @@ func (cfg HotCacheConfig[K, V]) Build() *HotCache[K, V] {
cfg.collectors,
)

if cfg.warmUpFn != nil {
if cfg.preloadFn != nil {
hot.preloadFn = cfg.preloadFn
// @TODO: Check error?
hot.WarmUp(cfg.warmUpFn) //nolint:errcheck
hot.Preload(cfg.preloadFn) //nolint:errcheck

if cfg.preloadPeriod > 0 {
hot.StartPeriodicPreload(cfg.preloadPeriod)
}
}

if cfg.janitorEnabled {
Expand Down Expand Up @@ -299,3 +313,21 @@ func (cfg *HotCacheConfig[K, V]) buildPrometheusCollector(mode base.CacheMode) f
return collector
}
}

// Preloader holds the configuration for preloading the cache.
type Preloader[K comparable, V any] struct {
	// Fn is the function that returns the data to preload into the cache.
	// It returns a map of key-value pairs to load, a list of missing keys, and
	// an error if any.
	Fn func() (map[K]V, []K, error)

	// Timeout is an optional duration to limit how long the preload function
	// can run.
	// If zero, no timeout is applied.
	Timeout time.Duration

	// Period is an optional duration for periodic cache refresh.
	// If set, the preload function will be called repeatedly at this interval
	// to refresh the cache.
	Period time.Duration
}
92 changes: 64 additions & 28 deletions config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func TestHotCacheConfig(t *testing.T) {
cacheAlgo: LRU, cacheCapacity: 42, missingSharedCache: false, missingCacheAlgo: "", missingCacheCapacity: 0,
ttl: 0, stale: 0, jitterLambda: 0, jitterUpperBound: 0, shards: 0, shardingFn: nil,
lockingDisabled: false, janitorEnabled: false, prometheusMetricsEnabled: false, cacheName: "",
warmUpFn: nil, loaderFns: nil, revalidationLoaderFns: nil, revalidationErrorPolicy: DropOnError,
preloadFn: nil, loaderFns: nil, revalidationLoaderFns: nil, revalidationErrorPolicy: DropOnError,
onEviction: nil, copyOnRead: nil, copyOnWrite: nil,
}, opts)

Expand All @@ -44,7 +44,7 @@ func TestHotCacheConfig(t *testing.T) {
cacheAlgo: LRU, cacheCapacity: 42, missingSharedCache: true, missingCacheAlgo: "", missingCacheCapacity: 0,
ttl: 0, stale: 0, jitterLambda: 0, jitterUpperBound: 0, shards: 0, shardingFn: nil,
lockingDisabled: false, janitorEnabled: false, prometheusMetricsEnabled: false, cacheName: "",
warmUpFn: nil, loaderFns: nil, revalidationLoaderFns: nil, revalidationErrorPolicy: DropOnError,
preloadFn: nil, loaderFns: nil, revalidationLoaderFns: nil, revalidationErrorPolicy: DropOnError,
onEviction: nil, copyOnRead: nil, copyOnWrite: nil,
}, opts)

Expand All @@ -53,7 +53,7 @@ func TestHotCacheConfig(t *testing.T) {
cacheAlgo: LRU, cacheCapacity: 42, missingSharedCache: false, missingCacheAlgo: LFU, missingCacheCapacity: 21,
ttl: 0, stale: 0, jitterLambda: 0, jitterUpperBound: 0, shards: 0, shardingFn: nil,
lockingDisabled: false, janitorEnabled: false, prometheusMetricsEnabled: false, cacheName: "",
warmUpFn: nil, loaderFns: nil, revalidationLoaderFns: nil, revalidationErrorPolicy: DropOnError,
preloadFn: nil, loaderFns: nil, revalidationLoaderFns: nil, revalidationErrorPolicy: DropOnError,
onEviction: nil, copyOnRead: nil, copyOnWrite: nil,
}, opts)

Expand All @@ -65,7 +65,7 @@ func TestHotCacheConfig(t *testing.T) {
cacheAlgo: LRU, cacheCapacity: 42, missingSharedCache: false, missingCacheAlgo: LFU, missingCacheCapacity: 21,
ttl: 42 * time.Second, stale: 0, jitterLambda: 0, jitterUpperBound: 0, shards: 0, shardingFn: nil,
lockingDisabled: false, janitorEnabled: false, prometheusMetricsEnabled: false, cacheName: "",
warmUpFn: nil, loaderFns: nil, revalidationLoaderFns: nil, revalidationErrorPolicy: DropOnError,
preloadFn: nil, loaderFns: nil, revalidationLoaderFns: nil, revalidationErrorPolicy: DropOnError,
onEviction: nil, copyOnRead: nil, copyOnWrite: nil,
}, opts)

Expand All @@ -84,7 +84,7 @@ func TestHotCacheConfig(t *testing.T) {
cacheAlgo: LRU, cacheCapacity: 42, missingSharedCache: false, missingCacheAlgo: LFU, missingCacheCapacity: 21,
ttl: 42 * time.Second, stale: 0, jitterLambda: 2, jitterUpperBound: time.Second, shards: 0, shardingFn: nil,
lockingDisabled: false, janitorEnabled: false, prometheusMetricsEnabled: false, cacheName: "",
warmUpFn: nil, loaderFns: nil, revalidationLoaderFns: nil, revalidationErrorPolicy: DropOnError,
preloadFn: nil, loaderFns: nil, revalidationLoaderFns: nil, revalidationErrorPolicy: DropOnError,
onEviction: nil, copyOnRead: nil, copyOnWrite: nil,
}, opts)

Expand All @@ -96,7 +96,7 @@ func TestHotCacheConfig(t *testing.T) {
cacheAlgo: LRU, cacheCapacity: 42, missingSharedCache: false, missingCacheAlgo: LFU, missingCacheCapacity: 21,
ttl: 42 * time.Second, stale: 0, jitterLambda: 2, jitterUpperBound: time.Second, shards: 0, shardingFn: nil,
lockingDisabled: true, janitorEnabled: false, prometheusMetricsEnabled: false, cacheName: "",
warmUpFn: nil, loaderFns: nil, revalidationLoaderFns: nil, revalidationErrorPolicy: DropOnError,
preloadFn: nil, loaderFns: nil, revalidationLoaderFns: nil, revalidationErrorPolicy: DropOnError,
onEviction: nil, copyOnRead: nil, copyOnWrite: nil,
}, opts)

Expand All @@ -105,7 +105,7 @@ func TestHotCacheConfig(t *testing.T) {
cacheAlgo: LRU, cacheCapacity: 42, missingSharedCache: false, missingCacheAlgo: LFU, missingCacheCapacity: 21,
ttl: 42 * time.Second, stale: 0, jitterLambda: 2, jitterUpperBound: time.Second, shards: 0, shardingFn: nil,
lockingDisabled: true, janitorEnabled: true, prometheusMetricsEnabled: false, cacheName: "",
warmUpFn: nil, loaderFns: nil, revalidationLoaderFns: nil, revalidationErrorPolicy: DropOnError,
preloadFn: nil, loaderFns: nil, revalidationLoaderFns: nil, revalidationErrorPolicy: DropOnError,
onEviction: nil, copyOnRead: nil, copyOnWrite: nil,
}, opts)

Expand Down Expand Up @@ -159,28 +159,64 @@ func TestWithWarmUpWithTimeout(t *testing.T) {
is := assert.New(t)
t.Parallel()

// Minimal test for backward compatibility - deprecated API
opts := NewHotCache[string, int](LRU, 42)

// Test successful warmup
warmUpFn := func() (map[string]int, []string, error) {
return map[string]int{"key1": 1, "key2": 2}, []string{"missing1"}, nil
preloadFn := func() (map[string]int, []string, error) {
return map[string]int{"key1": 1}, []string{}, nil
}

opts = opts.WithWarmUpWithTimeout(100*time.Millisecond, warmUpFn)
is.NotNil(opts.warmUpFn)
opts = opts.WithWarmUpWithTimeout(100*time.Millisecond, preloadFn)
is.NotNil(opts.preloadFn)
}

// Test timeout
slowWarmUpFn := func() (map[string]int, []string, error) {
time.Sleep(200 * time.Millisecond)
return map[string]int{"key1": 1}, []string{}, nil
// TestWithPreload verifies initial preload, periodic refresh, and that
// StopPeriodicPreload halts further refreshes.
// NOTE(review): this test is timing-based (sleeps against a 50ms period) and
// may be flaky on heavily loaded CI machines.
func TestWithPreload(t *testing.T) {
	is := assert.New(t)
	t.Parallel()

	// Each call bumps the counter; after 5 calls the loader becomes slower
	// than the configured timeout, exercising the timeout path.
	counter := 0
	preloadFn := func() (map[string]int, []string, error) {
		counter++
		if counter > 5 {
			time.Sleep(200 * time.Millisecond)
		}
		return map[string]int{"key1": counter}, []string{}, nil
	}

	cache := NewHotCache[string, int](LRU, 42).
		WithTTL(500 * time.Millisecond).WithPreload(Preloader[string, int]{
		Fn:      preloadFn,
		Timeout: 100 * time.Millisecond,
		Period:  50 * time.Millisecond,
	}).Build()

	// Verify initial preload.
	v, ok, err := cache.Get("key1")
	is.True(ok)
	is.NoError(err)
	is.Equal(1, v)

	// Wait for at least one periodic refresh.
	time.Sleep(200 * time.Millisecond)

	// Verify the cache was refreshed (counter advanced past the initial load).
	v, ok, err = cache.Get("key1")
	is.True(ok)
	is.NoError(err)
	is.Greater(v, 1)

	// Stop the periodic preload and capture the current value.
	cache.StopPeriodicPreload()
	vAfterStop, ok, err := cache.Get("key1")
	is.True(ok)
	is.NoError(err)

	time.Sleep(150 * time.Millisecond)

	// Verify the cache was NOT refreshed after stop.
	vFinal, ok, err := cache.Get("key1")
	is.True(ok)
	is.NoError(err)
	is.Equal(vAfterStop, vFinal, "Cache should not be refreshed after StopPeriodicPreload")
}

func TestWithEvictionCallback(t *testing.T) {
Expand Down Expand Up @@ -272,23 +308,23 @@ func TestBuildWithJanitorAndLockingConflict(t *testing.T) {
})
}

func TestBuildWithWarmUp(t *testing.T) {
func TestBuildWithPreload(t *testing.T) {
is := assert.New(t)
t.Parallel()

warmUpFn := func() (map[string]int, []string, error) {
preloader := Preloader[string, int]{Fn: func() (map[string]int, []string, error) {
return map[string]int{"key1": 1, "key2": 2}, []string{"missing1"}, nil
}
}}

is.Panics(func() {
_ = NewHotCache[string, int](LRU, 42).
WithWarmUp(warmUpFn).
WithPreload(preloader).
Build()
})

cache := NewHotCache[string, int](LRU, 42).
WithMissingCache(LFU, 21).
WithWarmUp(warmUpFn).Build()
WithPreload(preloader).Build()
is.NotNil(cache)
}

Expand Down
Loading