44 changes: 36 additions & 8 deletions proxy/internal/lb/mem.go
@@ -7,6 +7,7 @@ import (
 	"context"
 	"fmt"
 	"math/rand"
+	stdsync "sync"
 	"time"
 
 	"proxy/internal/env"
@@ -23,6 +24,8 @@ type MemoryLoadBalancer struct {
 	servers sync.Map[string, *SRSServer]
 	// The picked server to service client by specified stream URL, key is stream url.
 	picked sync.Map[string, *SRSServer]
+	// Mutex to protect the Pick operation when reselecting servers.
+	pickMutex stdsync.Mutex
 	// The HLS streaming, key is stream URL.
 	hlsStreamURL sync.Map[string, HLSPlayStream]
 	// The HLS streaming, key is SPBHID.
@@ -75,12 +78,36 @@ func (v *MemoryLoadBalancer) Update(ctx context.Context, server *SRSServer) error {
 }
 
 func (v *MemoryLoadBalancer) Pick(ctx context.Context, streamURL string) (*SRSServer, error) {
-	// Always proxy to the same server for the same stream URL.
-	if server, ok := v.picked.Load(streamURL); ok {
-		return server, nil
+	// First check (without lock): fast path for healthy servers.
+	// Try to load the previously picked server for this stream URL.
+	if pickedServer, ok := v.picked.Load(streamURL); ok {
+		// Check if the server still exists and is healthy by getting its latest state from the servers map.
+		if actualServer, exists := v.servers.Load(pickedServer.ID()); exists {
+			if time.Since(actualServer.UpdatedAt) < ServerAliveDuration {
+				// Server is still healthy, return the latest server state.
+				// Most requests will return here without acquiring the lock.
+				return actualServer, nil
+			}
+		}
 	}
 
-	// Gather all servers that were alive within the last few seconds.
+	// Server is unhealthy or doesn't exist, so we need to pick a new one.
+	// Acquire the lock to ensure only one goroutine picks a new server at a time.
+	v.pickMutex.Lock()
+	defer v.pickMutex.Unlock()
+
+	// Second check (with lock): another goroutine might have already updated the server.
+	if pickedServer, ok := v.picked.Load(streamURL); ok {
+		if actualServer, exists := v.servers.Load(pickedServer.ID()); exists {
+			if time.Since(actualServer.UpdatedAt) < ServerAliveDuration {
+				// Another goroutine has already picked a healthy server.
+				return actualServer, nil
+			}
+		}
+	}
+
+	// Now we're certain we need to pick a new server.
+	// Gather all servers that are alive within the last few seconds.
 	var servers []*SRSServer
 	v.servers.Range(func(key string, server *SRSServer) bool {
 		if time.Since(server.UpdatedAt) < ServerAliveDuration {
@@ -89,7 +116,7 @@ func (v *MemoryLoadBalancer) Pick(ctx context.Context, streamURL string) (*SRSServer, error) {
 		return true
 	})
 
-	// If no servers available, use all possible servers.
+	// If no healthy servers available, use all possible servers as fallback.
 	if len(servers) == 0 {
 		v.servers.Range(func(key string, server *SRSServer) bool {
 			servers = append(servers, server)
@@ -104,9 +131,10 @@ func (v *MemoryLoadBalancer) Pick(ctx context.Context, streamURL string) (*SRSServer, error) {
 
 	// Pick a server randomly from servers. Use global rand which is thread-safe since Go 1.20.
 	// For older Go versions, this is still safe as we're only reading from the servers slice.
-	server := servers[rand.Intn(len(servers))]
-	v.picked.Store(streamURL, server)
-	return server, nil
+	newServer := servers[rand.Intn(len(servers))]
+	v.picked.Store(streamURL, newServer)
+
+	return newServer, nil
 }
 
 func (v *MemoryLoadBalancer) LoadHLSBySPBHID(ctx context.Context, spbhid string) (HLSPlayStream, error) {
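The Pick change above is classic double-checked locking: a lock-free fast path through the picked map, then a mutex-guarded re-check before the expensive reselection, so at most one goroutine reselects per expired pick. A minimal, self-contained sketch of the same pattern follows; the names (picker, pick, healthy, choose) are hypothetical illustrations, not code from this PR:

// sketch.go: double-checked locking around a cached pick.
// Hypothetical illustration of the pattern, not part of the PR.
package main

import (
	"fmt"
	"sync"
)

type picker struct {
	mu     sync.Mutex // serializes reselection only, never the fast path
	picked sync.Map   // stream URL -> server name
}

func (p *picker) pick(url string, healthy func(string) bool, choose func() string) string {
	// First check (no lock): the common, healthy case pays no contention.
	if s, ok := p.picked.Load(url); ok && healthy(s.(string)) {
		return s.(string)
	}

	p.mu.Lock()
	defer p.mu.Unlock()

	// Second check (with lock): a concurrent caller may have already
	// stored a fresh pick while we waited on the mutex.
	if s, ok := p.picked.Load(url); ok && healthy(s.(string)) {
		return s.(string)
	}

	// Only one goroutine reaches here per expired pick.
	s := choose()
	p.picked.Store(url, s)
	return s
}

func main() {
	p := &picker{}
	alive := map[string]bool{"srs-1": false, "srs-2": true}
	fmt.Println(p.pick("rtmp://test/live/s",
		func(s string) bool { return alive[s] },
		func() string { return "srs-2" })) // prints srs-2
}

The second check is what keeps the result stable under contention: whichever goroutine wins the mutex stores exactly once, and every waiter observes that store instead of reselecting.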
286 changes: 286 additions & 0 deletions proxy/internal/lb/mem_test.go
@@ -0,0 +1,286 @@
// Copyright (c) 2025 Winlin
//
// SPDX-License-Identifier: MIT
package lb

import (
"context"
"sync"
"testing"
"time"
)

// mockEnvironment is a mock implementation of the Environment interface for testing.
type mockEnvironment struct{}

func (m *mockEnvironment) GoPprof() string { return "" }
func (m *mockEnvironment) GraceQuitTimeout() string { return "20s" }
func (m *mockEnvironment) ForceQuitTimeout() string { return "30s" }
func (m *mockEnvironment) HttpAPI() string { return "1985" }
func (m *mockEnvironment) HttpServer() string { return "8080" }
func (m *mockEnvironment) RtmpServer() string { return "1935" }
func (m *mockEnvironment) WebRTCServer() string { return "8000" }
func (m *mockEnvironment) SRTServer() string { return "10080" }
func (m *mockEnvironment) SystemAPI() string { return "12025" }
func (m *mockEnvironment) StaticFiles() string { return "" }
func (m *mockEnvironment) LoadBalancerType() string { return "memory" }
func (m *mockEnvironment) RedisHost() string { return "127.0.0.1" }
func (m *mockEnvironment) RedisPort() string { return "6379" }
func (m *mockEnvironment) RedisPassword() string { return "" }
func (m *mockEnvironment) RedisDB() string { return "0" }
func (m *mockEnvironment) DefaultBackendEnabled() string { return "off" }
func (m *mockEnvironment) DefaultBackendIP() string { return "127.0.0.1" }
func (m *mockEnvironment) DefaultBackendRTMP() string { return "1935" }
func (m *mockEnvironment) DefaultBackendHttp() string { return "8080" }
func (m *mockEnvironment) DefaultBackendAPI() string { return "1985" }
func (m *mockEnvironment) DefaultBackendRTC() string { return "8000" }
func (m *mockEnvironment) DefaultBackendSRT() string { return "10080" }

// TestPick_HealthyServer tests that Pick returns the same server when it's healthy.
func TestPick_HealthyServer(t *testing.T) {
ctx := context.Background()
lb := NewMemoryLoadBalancer(&mockEnvironment{}).(*MemoryLoadBalancer)

// Create and register a healthy server
server1 := &SRSServer{
IP: "192.168.1.1",
ServerID: "server1",
ServiceID: "service1",
PID: "1234",
UpdatedAt: time.Now(), // Fresh timestamp
}
err := lb.Update(ctx, server1)
if err != nil {
t.Fatalf("Failed to update server: %v", err)
}

streamURL := "rtmp://test/live/stream1"

// First pick
picked1, err := lb.Pick(ctx, streamURL)
if err != nil {
t.Fatalf("First pick failed: %v", err)
}
if picked1.ID() != server1.ID() {
t.Errorf("Expected server %v, got %v", server1.ID(), picked1.ID())
}

// Second pick should return the same server
picked2, err := lb.Pick(ctx, streamURL)
if err != nil {
t.Fatalf("Second pick failed: %v", err)
}
if picked2.ID() != server1.ID() {
t.Errorf("Expected same server %v, got %v", server1.ID(), picked2.ID())
}

// Verify both picks returned the same server
if picked1.ID() != picked2.ID() {
t.Errorf("Picks should return same server, got %v and %v", picked1.ID(), picked2.ID())
}
}

// TestPick_ExpiredServerSwitchesToNew tests that when a server expires, Pick switches to a new healthy server.
func TestPick_ExpiredServerSwitchesToNew(t *testing.T) {
ctx := context.Background()
lb := NewMemoryLoadBalancer(&mockEnvironment{}).(*MemoryLoadBalancer)

// Create an expired server (updated 400 seconds ago, beyond the 300s threshold)
oldServer := &SRSServer{
IP: "192.168.1.1",
ServerID: "server-old",
ServiceID: "service-old",
PID: "1111",
UpdatedAt: time.Now().Add(-400 * time.Second), // Expired
}
err := lb.Update(ctx, oldServer)
if err != nil {
t.Fatalf("Failed to update old server: %v", err)
}

// Create a healthy server
newServer := &SRSServer{
IP: "192.168.1.2",
ServerID: "server-new",
ServiceID: "service-new",
PID: "2222",
UpdatedAt: time.Now(), // Fresh timestamp
}
err = lb.Update(ctx, newServer)
if err != nil {
t.Fatalf("Failed to update new server: %v", err)
}

streamURL := "rtmp://test/live/stream1"

	// Simulate a stale pick: store the expired server directly, as if it
	// had been picked before it expired.
	lb.picked.Store(streamURL, oldServer)

// Now pick - should detect old server is expired and switch to new server
picked, err := lb.Pick(ctx, streamURL)
if err != nil {
t.Fatalf("Pick failed: %v", err)
}

// Should have switched to the new healthy server
if picked.ID() != newServer.ID() {
t.Errorf("Expected to switch to new server %v, but got %v", newServer.ID(), picked.ID())
}

// Verify the server is healthy
if time.Since(picked.UpdatedAt) >= ServerAliveDuration {
t.Errorf("Picked server should be healthy, but UpdatedAt is %v", picked.UpdatedAt)
}
}

// TestPick_ConcurrentAccess tests thread safety when multiple goroutines call Pick simultaneously.
func TestPick_ConcurrentAccess(t *testing.T) {
ctx := context.Background()
lb := NewMemoryLoadBalancer(&mockEnvironment{}).(*MemoryLoadBalancer)

// Create multiple healthy servers
for i := 1; i <= 3; i++ {
server := &SRSServer{
IP: "192.168.1." + string(rune('0'+i)),
ServerID: "server" + string(rune('0'+i)),
ServiceID: "service" + string(rune('0'+i)),
PID: string(rune('0' + i)),
UpdatedAt: time.Now(),
}
err := lb.Update(ctx, server)
if err != nil {
t.Fatalf("Failed to update server%d: %v", i, err)
}
}

streamURL := "rtmp://test/live/concurrent-stream"
numGoroutines := 100
var wg sync.WaitGroup
results := make(chan string, numGoroutines)

// Launch multiple goroutines to call Pick concurrently
for i := 0; i < numGoroutines; i++ {
wg.Add(1)
go func() {
defer wg.Done()
picked, err := lb.Pick(ctx, streamURL)
if err != nil {
t.Errorf("Pick failed in goroutine: %v", err)
return
}
results <- picked.ID()
}()
}

wg.Wait()
close(results)

// Collect all results
serverIDs := make(map[string]int)
for id := range results {
serverIDs[id]++
}

// All goroutines should have picked the same server
if len(serverIDs) != 1 {
t.Errorf("Expected all goroutines to pick the same server, but got %d different servers: %v",
len(serverIDs), serverIDs)
}
}

// TestPick_ConcurrentExpiration tests thread safety when server expires during concurrent access.
func TestPick_ConcurrentExpiration(t *testing.T) {
ctx := context.Background()
lb := NewMemoryLoadBalancer(&mockEnvironment{}).(*MemoryLoadBalancer)

// Create an expired server
oldServer := &SRSServer{
IP: "192.168.1.1",
ServerID: "server-old",
ServiceID: "service-old",
PID: "1111",
UpdatedAt: time.Now().Add(-400 * time.Second), // Expired
}
err := lb.Update(ctx, oldServer)
if err != nil {
t.Fatalf("Failed to update old server: %v", err)
}

// Create healthy servers
for i := 1; i <= 3; i++ {
server := &SRSServer{
IP: "192.168.1." + string(rune('1'+i)),
ServerID: "server-new" + string(rune('0'+i)),
ServiceID: "service-new" + string(rune('0'+i)),
PID: string(rune('0' + i)),
UpdatedAt: time.Now(),
}
err := lb.Update(ctx, server)
if err != nil {
t.Fatalf("Failed to update server%d: %v", i, err)
}
}

streamURL := "rtmp://test/live/expiration-stream"
// Manually set picked to old expired server
lb.picked.Store(streamURL, oldServer)

numGoroutines := 100
var wg sync.WaitGroup
results := make(chan string, numGoroutines)

// Launch multiple goroutines that will all detect expiration simultaneously
for i := 0; i < numGoroutines; i++ {
wg.Add(1)
go func() {
defer wg.Done()
picked, err := lb.Pick(ctx, streamURL)
if err != nil {
t.Errorf("Pick failed in goroutine: %v", err)
return
}
// Should not get the expired server
if picked.ID() == oldServer.ID() {
t.Errorf("Got expired server %v", picked.ID())
return
}
results <- picked.ID()
}()
}

wg.Wait()
close(results)

// Collect all results
serverIDs := make(map[string]int)
for id := range results {
serverIDs[id]++
}

// All goroutines should have picked the same new healthy server
// (thanks to double-checked locking)
if len(serverIDs) != 1 {
t.Errorf("Expected all goroutines to pick the same new server, but got %d different servers: %v",
len(serverIDs), serverIDs)
}

// Verify none picked the old expired server
if _, hasOld := serverIDs[oldServer.ID()]; hasOld {
t.Errorf("Some goroutines picked the expired server, this should not happen")
}
}

// TestPick_NoServersAvailable tests error handling when no servers are available.
func TestPick_NoServersAvailable(t *testing.T) {
ctx := context.Background()
lb := NewMemoryLoadBalancer(&mockEnvironment{}).(*MemoryLoadBalancer)

streamURL := "rtmp://test/live/no-server-stream"

// Try to pick when no servers are registered
_, err := lb.Pick(ctx, streamURL)
if err == nil {
t.Error("Expected error when no servers available, got nil")
}
}
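Assuming the standard Go toolchain, the concurrency tests above are most meaningful under the race detector, for example:

go test -race -run TestPick ./proxy/internal/lb/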