Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 14 additions & 7 deletions internal/pool/helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func (d *testDialer) Dial(ctx context.Context, network, address string,
return dialer.DialContext(ctx, network, address)
}

func startLocalTCPServer(t *testing.T, handleConn func(net.Conn) error) (
func startLocalTCPServer(t *testing.T, handleConn func(net.Conn) error) ( //nolint:cyclop
dialer *testDialer, runError <-chan error,
) {
t.Helper()
Expand Down Expand Up @@ -57,12 +57,12 @@ func startLocalTCPServer(t *testing.T, handleConn func(net.Conn) error) (
runErrorCh <- fmt.Errorf("accepting connection: %w", err)
return
}
connsInFlightMutex.Lock()
connsInFlight[conn.RemoteAddr().String()] = conn
connsInFlightMutex.Unlock()
handleConnWg.Add(1)
handleConnWg.Go(func() {
defer handleConnWg.Done()
connsInFlightMutex.Lock()
connsInFlight[conn.RemoteAddr().String()] = conn
connsInFlightMutex.Unlock()
err := handleConn(conn)
if err != nil {
select {
Expand All @@ -74,17 +74,24 @@ func startLocalTCPServer(t *testing.T, handleConn func(net.Conn) error) (
}
}()

stop := func() {
t.Cleanup(func() {
_ = listener.Close()
// drain error channel in case test exited with fatal and did not read the runError
// channel return, and one or more goroutines are trying to write an error to runErrorCh
for range len(connsInFlight) {
select {
case <-runErrorCh:
default:
}
}
<-listenerDone
connsInFlightMutex.Lock()
for _, conn := range connsInFlight {
_ = conn.Close()
}
connsInFlightMutex.Unlock()
handleConnWg.Wait()
}
t.Cleanup(stop)
})

select {
case <-ready:
Expand Down
14 changes: 14 additions & 0 deletions internal/pool/renew.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,20 @@ func (p *Pool) renew(ctx context.Context, conn poolConn, network, reason string)
netConn, err := p.dialer.Dial(ctx, network, address)
p.mutex.Lock()
defer p.mutex.Unlock()
// The pool may have been reset (p.addrConns emptied) while we were
// dialing outside the lock. Without this guard, p.addrConns[conn.addrIndex]
// panics with 'index out of range [0] with length 0' and crashes
// the DNS server goroutine, taking gluetun and its network
// namespace down with it.
if conn.addrIndex < 0 || conn.addrIndex >= len(p.addrConns) {
if netConn != nil {
_ = netConn.Close()
}
if err == nil {
err = fmt.Errorf("pool addrIndex %d out of range; pool reset while dialing", conn.addrIndex)
}
return poolConn{}, err
}
addrConns := p.addrConns[conn.addrIndex]
p.ensureConnIDToIndex(&addrConns)
connIndex, found := addrConns.connIDToIndex[conn.id]
Expand Down
85 changes: 85 additions & 0 deletions internal/pool/stress_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
package pool

import (
"context"
"fmt"
"math/rand/v2"
"testing"

noopmetrics "github.com/qdm12/dns/v2/internal/pool/metrics/noop"
"github.com/stretchr/testify/require"
)

// Test_Pool_stress launches concurrent workers that perform randomized
// Get/Put/PutDead/Renew operations against a single pool backed by a
// local TCP echo server, to flush out data races and pool-state
// corruption. It is skipped in -short mode.
func Test_Pool_stress(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping stress test in short mode")
	}

	t.Parallel()

	dialer, runErr := startLocalTCPServer(t, handleConnCopy)
	pool := New(dialer, noopmetrics.New())

	const (
		workers    = 16
		iterations = 100
	)

	start := make(chan struct{})
	resultCh := make(chan error)
	for worker := 0; worker < workers; worker++ {
		go runStressWorker(pool, worker, iterations, resultCh, start)
	}

	close(start)
	// Collect exactly one result per worker; keep only failures.
	var errs []error
	for remaining := workers; remaining > 0; remaining-- {
		if err := <-resultCh; err != nil {
			errs = append(errs, err)
		}
	}
	require.Empty(t, errs)

	pool.mutex.Lock()
	currentConns := len(pool.addrConns[0].conns)
	pool.mutex.Unlock()
	require.LessOrEqual(t, currentConns, workers,
		"pool retained too many live connections after stress run")

	// Non-blocking check: surface any server-side error already reported.
	select {
	case err := <-runErr:
		require.NoError(t, err)
	default:
	}
}

// runStressWorker waits for the start signal, then hammers the pool:
// each iteration Gets a connection and applies a randomly chosen
// follow-up (Put, close+PutDead, Renew+Put). It sends exactly one value
// on resultCh — the first error encountered, or nil once all iterations
// completed — so the caller can collect one result per worker.
func runStressWorker(pool *Pool, worker int, iterations int,
	resultCh chan<- error, start <-chan struct{},
) {
	<-start
	for i := range iterations {
		conn, err := pool.Get(context.Background(), "tcp")
		if err != nil {
			resultCh <- fmt.Errorf("worker %d: iteration %d: get: %w", worker, i, err)
			return
		}

		// Weak randomness is fine here: this only varies the stress pattern.
		switch rand.IntN(4) { //nolint:gosec
		case 0:
			pool.Put(conn)
		case 1:
			_ = conn.Close()
			pool.PutDead(conn)
		case 2:
			renewedConn, err := pool.Renew(context.Background(), "tcp", conn)
			if err != nil {
				// Failed renew already marks the slot as dead in pool state.
				continue
			}
			pool.Put(renewedConn)
		default:
			pool.Put(conn)
		}
	}
	resultCh <- nil
}