Skip to content

Commit fe98121

Browse files
refactor: avoid code-duplication in CONNECT impl (#129)
1 parent db38e4b commit fe98121

File tree

1 file changed

+8
-114
lines changed

1 file changed

+8
-114
lines changed

proxy/connect.go

Lines changed: 8 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,10 @@ package proxy
22

33
import (
44
"bufio"
5-
"bytes"
65
"crypto/tls"
76
"io"
87
"net"
98
"net/http"
10-
"net/url"
11-
"strconv"
12-
"strings"
139

1410
"github.com/coder/boundary/audit"
1511
)
@@ -18,12 +14,8 @@ import (
1814
func (p *Server) handleCONNECT(conn net.Conn, req *http.Request) {
1915
// Extract target from CONNECT request
2016
// CONNECT requests have the target in req.Host (format: hostname:port)
21-
target := req.Host
22-
if target == "" {
23-
target = req.URL.Host
24-
}
2517

26-
p.logger.Debug("🔌 CONNECT request", "target", target)
18+
p.logger.Debug("🔌 CONNECT request", "target", req.Host)
2719

2820
// Send 200 Connection established response
2921
response := "HTTP/1.1 200 Connection established\r\n\r\n"
@@ -33,15 +25,15 @@ func (p *Server) handleCONNECT(conn net.Conn, req *http.Request) {
3325
return
3426
}
3527

36-
p.logger.Debug("CONNECT tunnel established", "target", target)
28+
p.logger.Debug("CONNECT tunnel established", "target", req.Host)
3729

3830
// Handle the tunnel - decrypt TLS and process each HTTP request
39-
p.handleCONNECTTunnel(conn, target)
31+
p.handleCONNECTTunnel(conn)
4032
}
4133

4234
// handleCONNECTTunnel handles the tunnel after CONNECT is established
4335
// It decrypts TLS traffic and processes each HTTP request separately
44-
func (p *Server) handleCONNECTTunnel(conn net.Conn, target string) {
36+
func (p *Server) handleCONNECTTunnel(conn net.Conn) {
4537
defer func() {
4638
err := conn.Close()
4739
if err != nil {
@@ -74,15 +66,15 @@ func (p *Server) handleCONNECTTunnel(conn net.Conn, target string) {
7466
break
7567
}
7668

77-
p.logger.Debug("🔒 HTTP Request in CONNECT tunnel", "method", req.Method, "url", req.URL.String(), "target", target)
69+
p.logger.Debug("🔒 HTTP Request in CONNECT tunnel", "method", req.Method, "url", req.URL.String(), "target", req.Host)
7870

7971
// Process this request - check if allowed and forward to target
80-
p.processTunnelRequest(tlsConn, req, target)
72+
p.processTunnelRequest(tlsConn, req)
8173
}
8274
}
8375

8476
// processTunnelRequest processes a single HTTP request from the CONNECT tunnel
85-
func (p *Server) processTunnelRequest(conn net.Conn, req *http.Request, targetHost string) {
77+
func (p *Server) processTunnelRequest(conn net.Conn, req *http.Request) {
8678
// Check if request should be allowed
8779
// Use the original request URL but evaluate against rules
8880
urlStr := req.Host + req.URL.String()
@@ -105,103 +97,5 @@ func (p *Server) processTunnelRequest(conn net.Conn, req *http.Request, targetHo
10597

10698
// Forward request to target
10799
// The target is the original CONNECT target, but we use the request's host/path
108-
p.forwardTunnelRequest(conn, req, targetHost)
109-
}
110-
111-
// forwardTunnelRequest forwards a request from the tunnel to the target
112-
func (p *Server) forwardTunnelRequest(conn net.Conn, req *http.Request, targetHost string) {
113-
// Create HTTP client
114-
client := &http.Client{
115-
CheckRedirect: func(req *http.Request, via []*http.Request) error {
116-
return http.ErrUseLastResponse // Don't follow redirects
117-
},
118-
}
119-
120-
// Extract hostname and port from targetHost
121-
hostname := targetHost
122-
port := "443" // Default HTTPS port
123-
if strings.Contains(targetHost, ":") {
124-
parts := strings.Split(targetHost, ":")
125-
hostname = parts[0]
126-
port = parts[1]
127-
}
128-
129-
scheme := "https"
130-
if port == "80" {
131-
scheme = "http"
132-
}
133-
134-
// Build target URL using the request's path but the CONNECT target's host
135-
// URL.Host can include port for connection, but Host header should not
136-
targetURL := &url.URL{
137-
Scheme: scheme,
138-
Host: targetHost, // Include port for connection
139-
Path: req.URL.Path,
140-
RawQuery: req.URL.RawQuery,
141-
}
142-
143-
var body = req.Body
144-
if req.Method == http.MethodGet || req.Method == http.MethodHead {
145-
body = nil
146-
}
147-
148-
newReq, err := http.NewRequest(req.Method, targetURL.String(), body)
149-
if err != nil {
150-
p.logger.Error("can't create HTTP request for tunnel", "error", err)
151-
return
152-
}
153-
154-
// Set Host header to just the hostname (without port)
155-
// The Host header should not include the port number for HTTPS
156-
newReq.Host = hostname
157-
158-
// Copy headers (but skip Host since we set it explicitly above)
159-
for name, values := range req.Header {
160-
// Skip connection-specific headers and Host header
161-
lowerName := strings.ToLower(name)
162-
if lowerName == "connection" || lowerName == "proxy-connection" || lowerName == "host" {
163-
continue
164-
}
165-
for _, value := range values {
166-
newReq.Header.Add(name, value)
167-
}
168-
}
169-
170-
// Make request to destination
171-
resp, err := client.Do(newReq)
172-
if err != nil {
173-
p.logger.Error("Failed to forward request from CONNECT tunnel", "error", err)
174-
return
175-
}
176-
177-
p.logger.Debug("Response from target", "status", resp.StatusCode, "target", targetHost)
178-
179-
// Read the body and set Content-Length
180-
bodyBytes, err := io.ReadAll(resp.Body)
181-
if err != nil {
182-
p.logger.Error("can't read response body from tunnel", "error", err)
183-
return
184-
}
185-
resp.Header.Set("Content-Length", strconv.Itoa(len(bodyBytes)))
186-
resp.ContentLength = int64(len(bodyBytes))
187-
err = resp.Body.Close()
188-
if err != nil {
189-
p.logger.Error("Failed to close response body", "error", err)
190-
return
191-
}
192-
resp.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
193-
194-
// Normalize to HTTP/1.1
195-
resp.Proto = "HTTP/1.1"
196-
resp.ProtoMajor = 1
197-
resp.ProtoMinor = 1
198-
199-
// Write response back to tunnel
200-
err = resp.Write(conn)
201-
if err != nil {
202-
p.logger.Error("Failed to write response to CONNECT tunnel", "error", err)
203-
return
204-
}
205-
206-
p.logger.Debug("Successfully forwarded response in CONNECT tunnel")
100+
p.forwardRequest(conn, req, true)
207101
}

0 commit comments

Comments
 (0)