Skip to content

Commit 2a203e6

Browse files
refactor: proxy tests (#131)
* refactor: proxy tests * ci: fix linter
1 parent f35ee3b commit 2a203e6

File tree

2 files changed

+192
-152
lines changed

2 files changed

+192
-152
lines changed

proxy/proxy_framework_test.go

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
package proxy
2+
3+
import (
4+
"crypto/tls"
5+
"io"
6+
"log/slog"
7+
"net/http"
8+
"os"
9+
"os/user"
10+
"strconv"
11+
"testing"
12+
"time"
13+
14+
"github.com/coder/boundary/rulesengine"
15+
boundary_tls "github.com/coder/boundary/tls"
16+
"github.com/stretchr/testify/require"
17+
)
18+
19+
// ProxyTest is a high-level test framework for proxy tests
20+
type ProxyTest struct {
21+
t *testing.T
22+
server *Server
23+
client *http.Client
24+
port int
25+
useCertManager bool
26+
configDir string
27+
startupDelay time.Duration
28+
}
29+
30+
// ProxyTestOption is a function that configures ProxyTest
31+
type ProxyTestOption func(*ProxyTest)
32+
33+
// NewProxyTest creates a new ProxyTest instance
34+
func NewProxyTest(t *testing.T, opts ...ProxyTestOption) *ProxyTest {
35+
pt := &ProxyTest{
36+
t: t,
37+
port: 8080,
38+
useCertManager: false,
39+
configDir: "/tmp/boundary",
40+
startupDelay: 100 * time.Millisecond,
41+
}
42+
43+
// Apply options
44+
for _, opt := range opts {
45+
opt(pt)
46+
}
47+
48+
return pt
49+
}
50+
51+
// WithProxyPort sets the proxy server port
52+
func WithProxyPort(port int) ProxyTestOption {
53+
return func(pt *ProxyTest) {
54+
pt.port = port
55+
}
56+
}
57+
58+
// WithCertManager enables TLS certificate manager
59+
func WithCertManager(configDir string) ProxyTestOption {
60+
return func(pt *ProxyTest) {
61+
pt.useCertManager = true
62+
pt.configDir = configDir
63+
}
64+
}
65+
66+
// WithStartupDelay sets how long to wait after starting server before making requests
67+
func WithStartupDelay(delay time.Duration) ProxyTestOption {
68+
return func(pt *ProxyTest) {
69+
pt.startupDelay = delay
70+
}
71+
}
72+
73+
// Start starts the proxy server
74+
func (pt *ProxyTest) Start() *ProxyTest {
75+
pt.t.Helper()
76+
77+
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
78+
Level: slog.LevelError,
79+
}))
80+
81+
testRules, err := rulesengine.ParseAllowSpecs([]string{"method=*"})
82+
require.NoError(pt.t, err, "Failed to parse test rules")
83+
84+
ruleEngine := rulesengine.NewRuleEngine(testRules, logger)
85+
auditor := &mockAuditor{}
86+
87+
var tlsConfig *tls.Config
88+
if pt.useCertManager {
89+
currentUser, err := user.Current()
90+
require.NoError(pt.t, err, "Failed to get current user")
91+
92+
uid, _ := strconv.Atoi(currentUser.Uid)
93+
gid, _ := strconv.Atoi(currentUser.Gid)
94+
95+
certManager, err := boundary_tls.NewCertificateManager(boundary_tls.Config{
96+
Logger: logger,
97+
ConfigDir: pt.configDir,
98+
Uid: uid,
99+
Gid: gid,
100+
})
101+
require.NoError(pt.t, err, "Failed to create certificate manager")
102+
103+
tlsConfig, err = certManager.SetupTLSAndWriteCACert()
104+
require.NoError(pt.t, err, "Failed to setup TLS")
105+
} else {
106+
tlsConfig = &tls.Config{
107+
MinVersion: tls.VersionTLS12,
108+
}
109+
}
110+
111+
pt.server = NewProxyServer(Config{
112+
HTTPPort: pt.port,
113+
RuleEngine: ruleEngine,
114+
Auditor: auditor,
115+
Logger: logger,
116+
TLSConfig: tlsConfig,
117+
})
118+
119+
err = pt.server.Start()
120+
require.NoError(pt.t, err, "Failed to start server")
121+
122+
// Give server time to start
123+
time.Sleep(pt.startupDelay)
124+
125+
// Create HTTP client
126+
pt.client = &http.Client{
127+
Transport: &http.Transport{
128+
TLSClientConfig: &tls.Config{
129+
InsecureSkipVerify: true, // Skip cert verification for testing
130+
},
131+
},
132+
Timeout: 5 * time.Second,
133+
}
134+
135+
return pt
136+
}
137+
138+
// Stop gracefully stops the proxy server
139+
func (pt *ProxyTest) Stop() {
140+
if pt.server != nil {
141+
err := pt.server.Stop()
142+
if err != nil {
143+
pt.t.Logf("Failed to stop proxy server: %v", err)
144+
}
145+
}
146+
}
147+
148+
// ExpectAllowed makes a request through the proxy and expects it to be allowed with the given response body
149+
func (pt *ProxyTest) ExpectAllowed(proxyURL, hostHeader, expectedBody string) {
150+
pt.t.Helper()
151+
152+
req, err := http.NewRequest("GET", proxyURL, nil)
153+
require.NoError(pt.t, err, "Failed to create request")
154+
req.Host = hostHeader
155+
156+
resp, err := pt.client.Do(req)
157+
require.NoError(pt.t, err, "Failed to make request")
158+
defer resp.Body.Close() //nolint:errcheck
159+
160+
body, err := io.ReadAll(resp.Body)
161+
require.NoError(pt.t, err, "Failed to read response body")
162+
163+
require.Equal(pt.t, expectedBody, string(body), "Expected response body does not match")
164+
}
165+
166+
// ExpectAllowedContains makes a request through the proxy and expects it to be allowed, checking that response contains the given text
167+
func (pt *ProxyTest) ExpectAllowedContains(proxyURL, hostHeader, containsText string) {
168+
pt.t.Helper()
169+
170+
req, err := http.NewRequest("GET", proxyURL, nil)
171+
require.NoError(pt.t, err, "Failed to create request")
172+
req.Host = hostHeader
173+
174+
resp, err := pt.client.Do(req)
175+
require.NoError(pt.t, err, "Failed to make request")
176+
defer resp.Body.Close() //nolint:errcheck
177+
178+
body, err := io.ReadAll(resp.Body)
179+
require.NoError(pt.t, err, "Failed to read response body")
180+
181+
require.Contains(pt.t, string(body), containsText, "Response does not contain expected text")
182+
}

proxy/proxy_test.go

Lines changed: 10 additions & 152 deletions
Original file line numberDiff line numberDiff line change
@@ -29,176 +29,34 @@ func (m *mockAuditor) AuditRequest(req audit.Request) {
2929

3030
// TestProxyServerBasicHTTP tests basic HTTP request handling
3131
func TestProxyServerBasicHTTP(t *testing.T) {
32-
// Create test logger
33-
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
34-
Level: slog.LevelError,
35-
}))
36-
37-
// Create test rules (allow all for testing)
38-
testRules, err := rulesengine.ParseAllowSpecs([]string{"method=*"})
39-
if err != nil {
40-
t.Fatalf("Failed to parse test rules: %v", err)
41-
}
42-
43-
// Create rule engine
44-
ruleEngine := rulesengine.NewRuleEngine(testRules, logger)
45-
46-
// Create mock auditor
47-
auditor := &mockAuditor{}
48-
49-
// Create TLS config (minimal for testing)
50-
tlsConfig := &tls.Config{
51-
MinVersion: tls.VersionTLS12,
52-
}
53-
54-
// Create proxy server
55-
server := NewProxyServer(Config{
56-
HTTPPort: 8080,
57-
RuleEngine: ruleEngine,
58-
Auditor: auditor,
59-
Logger: logger,
60-
TLSConfig: tlsConfig,
61-
})
62-
63-
// Start server
64-
err = server.Start()
65-
require.NoError(t, err)
66-
67-
// Give server time to start
68-
time.Sleep(100 * time.Millisecond)
32+
pt := NewProxyTest(t).
33+
Start()
34+
defer pt.Stop()
6935

70-
// Test basic HTTP request
7136
t.Run("BasicHTTPRequest", func(t *testing.T) {
72-
// Create HTTP client
73-
client := &http.Client{
74-
Transport: &http.Transport{
75-
TLSClientConfig: &tls.Config{
76-
InsecureSkipVerify: true, // Skip cert verification for testing
77-
},
78-
},
79-
Timeout: 5 * time.Second,
80-
}
81-
82-
// Make request to proxy
83-
req, err := http.NewRequest("GET", "http://localhost:8080/todos/1", nil)
84-
if err != nil {
85-
t.Fatalf("Failed to create request: %v", err)
86-
}
87-
// Override the Host header
88-
req.Host = "jsonplaceholder.typicode.com"
89-
90-
// Make the request
91-
resp, err := client.Do(req)
92-
require.NoError(t, err)
93-
94-
body, err := io.ReadAll(resp.Body)
95-
require.NoError(t, err)
96-
require.NoError(t, resp.Body.Close())
97-
9837
expectedResponse := `{
9938
"userId": 1,
10039
"id": 1,
10140
"title": "delectus aut autem",
10241
"completed": false
10342
}`
104-
require.Equal(t, expectedResponse, string(body))
43+
pt.ExpectAllowed("http://localhost:8080/todos/1", "jsonplaceholder.typicode.com", expectedResponse)
10544
})
106-
107-
err = server.Stop()
108-
require.NoError(t, err)
10945
}
11046

11147
// TestProxyServerBasicHTTPS tests basic HTTPS request handling
11248
func TestProxyServerBasicHTTPS(t *testing.T) {
113-
// Create test logger
114-
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
115-
Level: slog.LevelError,
116-
}))
49+
pt := NewProxyTest(t,
50+
WithCertManager("/tmp/boundary"),
51+
).
52+
Start()
53+
defer pt.Stop()
11754

118-
// Create test rules (allow all for testing)
119-
testRules, err := rulesengine.ParseAllowSpecs([]string{"method=*"})
120-
if err != nil {
121-
t.Fatalf("Failed to parse test rules: %v", err)
122-
}
123-
124-
// Create rule engine
125-
ruleEngine := rulesengine.NewRuleEngine(testRules, logger)
126-
127-
// Create mock auditor
128-
auditor := &mockAuditor{}
129-
130-
currentUser, err := user.Current()
131-
if err != nil {
132-
log.Fatal(err)
133-
}
134-
135-
uid, _ := strconv.Atoi(currentUser.Uid)
136-
gid, _ := strconv.Atoi(currentUser.Gid)
137-
138-
// Create TLS certificate manager
139-
certManager, err := boundary_tls.NewCertificateManager(boundary_tls.Config{
140-
Logger: logger,
141-
ConfigDir: "/tmp/boundary",
142-
Uid: uid,
143-
Gid: gid,
144-
})
145-
require.NoError(t, err)
146-
147-
// Setup TLS to get cert path for jailer
148-
tlsConfig, err := certManager.SetupTLSAndWriteCACert()
149-
require.NoError(t, err)
150-
151-
// Create proxy server
152-
server := NewProxyServer(Config{
153-
HTTPPort: 8080,
154-
RuleEngine: ruleEngine,
155-
Auditor: auditor,
156-
Logger: logger,
157-
TLSConfig: tlsConfig,
158-
})
159-
160-
// Start server
161-
err = server.Start()
162-
require.NoError(t, err)
163-
164-
// Give server time to start
165-
time.Sleep(100 * time.Millisecond)
166-
167-
// Test basic HTTPS request
16855
t.Run("BasicHTTPSRequest", func(t *testing.T) {
169-
// Create HTTP client
170-
client := &http.Client{
171-
Transport: &http.Transport{
172-
TLSClientConfig: &tls.Config{
173-
InsecureSkipVerify: true, // Skip cert verification for testing
174-
},
175-
},
176-
Timeout: 5 * time.Second,
177-
}
178-
179-
// Make request to proxy
180-
req, err := http.NewRequest("GET", "https://localhost:8080/api/v2", nil)
181-
if err != nil {
182-
t.Fatalf("Failed to create request: %v", err)
183-
}
184-
// Override the Host header
185-
req.Host = "dev.coder.com"
186-
187-
// Make the request
188-
resp, err := client.Do(req)
189-
require.NoError(t, err)
190-
191-
body, err := io.ReadAll(resp.Body)
192-
require.NoError(t, err)
193-
require.NoError(t, resp.Body.Close())
194-
19556
expectedResponse := `{"message":"👋"}
19657
`
197-
require.Equal(t, expectedResponse, string(body))
58+
pt.ExpectAllowed("https://localhost:8080/api/v2", "dev.coder.com", expectedResponse)
19859
})
199-
200-
err = server.Stop()
201-
require.NoError(t, err)
20260
}
20361

20462
// TestProxyServerCONNECT tests HTTP CONNECT method for HTTPS tunneling

0 commit comments

Comments
 (0)