Skip to content

Commit 415a888

Browse files
committed
add proxy tests
1 parent faed335 commit 415a888

File tree

1 file changed

+394
-0
lines changed

1 file changed

+394
-0
lines changed

internal/api/proxy_test.go

Lines changed: 394 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,394 @@
1+
package api
2+
3+
import (
4+
"crypto/ecdsa"
5+
"crypto/elliptic"
6+
"crypto/rand"
7+
"crypto/tls"
8+
"crypto/x509"
9+
"crypto/x509/pkix"
10+
"encoding/base64"
11+
"fmt"
12+
"io"
13+
"math/big"
14+
"net"
15+
"net/http"
16+
"net/http/httptest"
17+
"net/url"
18+
"strings"
19+
"testing"
20+
"time"
21+
)
22+
23+
// startCONNECTProxy starts an HTTP or HTTPS CONNECT proxy on a random port.
24+
// It returns the proxy URL and a channel that receives the protocol observed by
25+
// the proxy handler for each CONNECT request.
26+
func startCONNECTProxy(t *testing.T, useTLS bool) (proxyURL *url.URL, obsCh <-chan string) {
27+
t.Helper()
28+
29+
ch := make(chan string, 10)
30+
31+
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
32+
select {
33+
case ch <- r.Proto:
34+
default:
35+
}
36+
37+
if r.Method != http.MethodConnect {
38+
http.Error(w, "expected CONNECT", http.StatusMethodNotAllowed)
39+
return
40+
}
41+
42+
destConn, err := net.DialTimeout("tcp", r.Host, 10*time.Second)
43+
if err != nil {
44+
http.Error(w, err.Error(), http.StatusBadGateway)
45+
return
46+
}
47+
defer destConn.Close()
48+
49+
hijacker, ok := w.(http.Hijacker)
50+
if !ok {
51+
http.Error(w, "hijacking not supported", http.StatusInternalServerError)
52+
return
53+
}
54+
55+
w.WriteHeader(http.StatusOK)
56+
clientConn, _, err := hijacker.Hijack()
57+
if err != nil {
58+
return
59+
}
60+
defer clientConn.Close()
61+
62+
done := make(chan struct{}, 2)
63+
go func() { io.Copy(destConn, clientConn); done <- struct{}{} }()
64+
go func() { io.Copy(clientConn, destConn); done <- struct{}{} }()
65+
<-done
66+
})
67+
68+
ln, err := net.Listen("tcp", "127.0.0.1:0")
69+
if err != nil {
70+
t.Fatalf("proxy listen: %v", err)
71+
}
72+
73+
srv := &http.Server{Handler: handler}
74+
75+
if useTLS {
76+
cert := generateTestCert(t, "127.0.0.1")
77+
srv.TLSConfig = &tls.Config{Certificates: []tls.Certificate{cert}}
78+
go srv.ServeTLS(ln, "", "")
79+
} else {
80+
go srv.Serve(ln)
81+
}
82+
t.Cleanup(func() { srv.Close() })
83+
84+
scheme := "http"
85+
if useTLS {
86+
scheme = "https"
87+
}
88+
pURL, _ := url.Parse(fmt.Sprintf("%s://%s", scheme, ln.Addr().String()))
89+
return pURL, ch
90+
}
91+
92+
// startCONNECTProxyWithAuth is like startCONNECTProxy but requires
93+
// Proxy-Authorization with the given username and password.
94+
func startCONNECTProxyWithAuth(t *testing.T, useTLS bool, wantUser, wantPass string) (proxyURL *url.URL) {
95+
t.Helper()
96+
97+
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
98+
if r.Method != http.MethodConnect {
99+
http.Error(w, "expected CONNECT", http.StatusMethodNotAllowed)
100+
return
101+
}
102+
103+
authHeader := r.Header.Get("Proxy-Authorization")
104+
wantAuth := "Basic " + base64.StdEncoding.EncodeToString([]byte(wantUser+":"+wantPass))
105+
if authHeader != wantAuth {
106+
http.Error(w, "proxy auth required", http.StatusProxyAuthRequired)
107+
return
108+
}
109+
110+
destConn, err := net.DialTimeout("tcp", r.Host, 10*time.Second)
111+
if err != nil {
112+
http.Error(w, err.Error(), http.StatusBadGateway)
113+
return
114+
}
115+
defer destConn.Close()
116+
117+
hijacker, ok := w.(http.Hijacker)
118+
if !ok {
119+
http.Error(w, "hijacking not supported", http.StatusInternalServerError)
120+
return
121+
}
122+
123+
w.WriteHeader(http.StatusOK)
124+
clientConn, _, err := hijacker.Hijack()
125+
if err != nil {
126+
return
127+
}
128+
defer clientConn.Close()
129+
130+
done := make(chan struct{}, 2)
131+
go func() { io.Copy(destConn, clientConn); done <- struct{}{} }()
132+
go func() { io.Copy(clientConn, destConn); done <- struct{}{} }()
133+
<-done
134+
})
135+
136+
ln, err := net.Listen("tcp", "127.0.0.1:0")
137+
if err != nil {
138+
t.Fatalf("proxy listen: %v", err)
139+
}
140+
141+
srv := &http.Server{Handler: handler}
142+
143+
if useTLS {
144+
cert := generateTestCert(t, "127.0.0.1")
145+
srv.TLSConfig = &tls.Config{Certificates: []tls.Certificate{cert}}
146+
go srv.ServeTLS(ln, "", "")
147+
} else {
148+
go srv.Serve(ln)
149+
}
150+
t.Cleanup(func() { srv.Close() })
151+
152+
scheme := "http"
153+
if useTLS {
154+
scheme = "https"
155+
}
156+
pURL, _ := url.Parse(fmt.Sprintf("%s://%s@%s", scheme, url.UserPassword(wantUser, wantPass).String(), ln.Addr().String()))
157+
return pURL
158+
}
159+
160+
func generateTestCert(t *testing.T, host string) tls.Certificate {
161+
t.Helper()
162+
163+
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
164+
if err != nil {
165+
t.Fatalf("generate key: %v", err)
166+
}
167+
template := &x509.Certificate{
168+
SerialNumber: big.NewInt(1),
169+
Subject: pkix.Name{CommonName: host},
170+
NotBefore: time.Now(),
171+
NotAfter: time.Now().Add(1 * time.Hour),
172+
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
173+
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
174+
IPAddresses: []net.IP{net.ParseIP(host)},
175+
}
176+
certDER, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key)
177+
if err != nil {
178+
t.Fatalf("create cert: %v", err)
179+
}
180+
return tls.Certificate{
181+
Certificate: [][]byte{certDER},
182+
PrivateKey: key,
183+
}
184+
}
185+
186+
// newTestTransport creates a base transport suitable for proxy tests.
187+
func newTestTransport() *http.Transport {
188+
transport := http.DefaultTransport.(*http.Transport).Clone()
189+
transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
190+
return transport
191+
}
192+
193+
// startTargetServer starts an HTTPS server (with HTTP/2 enabled) that
194+
// responds with "ok" to GET /.
195+
func startTargetServer(t *testing.T) *httptest.Server {
196+
t.Helper()
197+
srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
198+
fmt.Fprintln(w, "ok")
199+
}))
200+
srv.EnableHTTP2 = true
201+
srv.StartTLS()
202+
t.Cleanup(srv.Close)
203+
return srv
204+
}
205+
206+
func TestWithProxyTransport_HTTPProxy(t *testing.T) {
207+
target := startTargetServer(t)
208+
proxyURL, obsCh := startCONNECTProxy(t, false)
209+
210+
transport := withProxyTransport(newTestTransport(), proxyURL, "")
211+
client := &http.Client{Transport: transport, Timeout: 10 * time.Second}
212+
213+
resp, err := client.Get(target.URL)
214+
if err != nil {
215+
t.Fatalf("GET through http proxy: %v", err)
216+
}
217+
defer resp.Body.Close()
218+
body, _ := io.ReadAll(resp.Body)
219+
220+
if resp.StatusCode != http.StatusOK {
221+
t.Errorf("expected 200, got %d", resp.StatusCode)
222+
}
223+
if got := strings.TrimSpace(string(body)); got != "ok" {
224+
t.Errorf("expected body 'ok', got %q", got)
225+
}
226+
227+
select {
228+
case proto := <-obsCh:
229+
if proto != "HTTP/1.1" {
230+
t.Errorf("expected proxy to see HTTP/1.1 CONNECT, got %s", proto)
231+
}
232+
case <-time.After(2 * time.Second):
233+
t.Fatal("proxy handler was never invoked")
234+
}
235+
}
236+
237+
func TestWithProxyTransport_HTTPSProxy(t *testing.T) {
238+
target := startTargetServer(t)
239+
proxyURL, obsCh := startCONNECTProxy(t, true)
240+
241+
transport := withProxyTransport(newTestTransport(), proxyURL, "")
242+
client := &http.Client{Transport: transport, Timeout: 10 * time.Second}
243+
244+
resp, err := client.Get(target.URL)
245+
if err != nil {
246+
t.Fatalf("GET through https proxy: %v", err)
247+
}
248+
defer resp.Body.Close()
249+
body, _ := io.ReadAll(resp.Body)
250+
251+
if resp.StatusCode != http.StatusOK {
252+
t.Errorf("expected 200, got %d", resp.StatusCode)
253+
}
254+
if got := strings.TrimSpace(string(body)); got != "ok" {
255+
t.Errorf("expected body 'ok', got %q", got)
256+
}
257+
258+
select {
259+
case proto := <-obsCh:
260+
if proto != "HTTP/1.1" {
261+
t.Errorf("expected proxy to see HTTP/1.1 CONNECT, got %s", proto)
262+
}
263+
case <-time.After(2 * time.Second):
264+
t.Fatal("proxy handler was never invoked")
265+
}
266+
}
267+
268+
func TestWithProxyTransport_ProxyAuth(t *testing.T) {
269+
target := startTargetServer(t)
270+
271+
t.Run("http proxy with auth", func(t *testing.T) {
272+
proxyURL := startCONNECTProxyWithAuth(t, false, "user", "pass")
273+
transport := withProxyTransport(newTestTransport(), proxyURL, "")
274+
client := &http.Client{Transport: transport, Timeout: 10 * time.Second}
275+
276+
resp, err := client.Get(target.URL)
277+
if err != nil {
278+
t.Fatalf("GET through authenticated http proxy: %v", err)
279+
}
280+
defer resp.Body.Close()
281+
io.ReadAll(resp.Body)
282+
283+
if resp.StatusCode != http.StatusOK {
284+
t.Errorf("expected 200, got %d", resp.StatusCode)
285+
}
286+
})
287+
288+
t.Run("https proxy with auth", func(t *testing.T) {
289+
proxyURL := startCONNECTProxyWithAuth(t, true, "user", "s3cret")
290+
transport := withProxyTransport(newTestTransport(), proxyURL, "")
291+
client := &http.Client{Transport: transport, Timeout: 10 * time.Second}
292+
293+
resp, err := client.Get(target.URL)
294+
if err != nil {
295+
t.Fatalf("GET through authenticated https proxy: %v", err)
296+
}
297+
defer resp.Body.Close()
298+
io.ReadAll(resp.Body)
299+
300+
if resp.StatusCode != http.StatusOK {
301+
t.Errorf("expected 200, got %d", resp.StatusCode)
302+
}
303+
})
304+
}
305+
306+
func TestWithProxyTransport_HTTPSProxy_HTTP2ToOrigin(t *testing.T) {
307+
// Verify that when tunneling through an HTTPS proxy, the connection to
308+
// the origin target still negotiates HTTP/2 (not downgraded to HTTP/1.1).
309+
target := startTargetServer(t)
310+
proxyURL, _ := startCONNECTProxy(t, true)
311+
312+
transport := withProxyTransport(newTestTransport(), proxyURL, "")
313+
client := &http.Client{Transport: transport, Timeout: 10 * time.Second}
314+
315+
resp, err := client.Get(target.URL)
316+
if err != nil {
317+
t.Fatalf("GET through https proxy: %v", err)
318+
}
319+
defer resp.Body.Close()
320+
io.ReadAll(resp.Body)
321+
322+
if resp.Proto != "HTTP/2.0" {
323+
t.Errorf("expected HTTP/2.0 to origin, got %s", resp.Proto)
324+
}
325+
}
326+
327+
func TestWithProxyTransport_ProxyRejectsConnect(t *testing.T) {
328+
tests := []struct {
329+
name string
330+
statusCode int
331+
body string
332+
wantStatus string
333+
}{
334+
{"407 proxy auth required", http.StatusProxyAuthRequired, "proxy auth required", "407 Proxy Authentication Required"},
335+
{"403 forbidden", http.StatusForbidden, "access denied by policy", "403 Forbidden"},
336+
{"502 bad gateway", http.StatusBadGateway, "upstream unreachable", "502 Bad Gateway"},
337+
}
338+
339+
for _, tt := range tests {
340+
t.Run(tt.name, func(t *testing.T) {
341+
// Start a proxy that always rejects CONNECT with the given status.
342+
ln, err := net.Listen("tcp", "127.0.0.1:0")
343+
if err != nil {
344+
t.Fatalf("listen: %v", err)
345+
}
346+
srv := &http.Server{Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
347+
http.Error(w, tt.body, tt.statusCode)
348+
})}
349+
go srv.Serve(ln)
350+
t.Cleanup(func() { srv.Close() })
351+
352+
proxyURL, _ := url.Parse(fmt.Sprintf("http://%s", ln.Addr().String()))
353+
transport := withProxyTransport(newTestTransport(), proxyURL, "")
354+
client := &http.Client{Transport: transport, Timeout: 10 * time.Second}
355+
356+
_, err = client.Get("https://example.com")
357+
if err == nil {
358+
t.Fatal("expected error, got nil")
359+
}
360+
if !strings.Contains(err.Error(), tt.wantStatus) {
361+
t.Errorf("error should contain status %q, got: %v", tt.wantStatus, err)
362+
}
363+
if !strings.Contains(err.Error(), tt.body) {
364+
t.Errorf("error should contain body %q, got: %v", tt.body, err)
365+
}
366+
})
367+
}
368+
}
369+
370+
371+
func TestProxyHostPort(t *testing.T) {
372+
tests := []struct {
373+
name string
374+
url string
375+
want string
376+
}{
377+
{"https with port", "https://proxy.example.com:8443", "proxy.example.com:8443"},
378+
{"https without port", "https://proxy.example.com", "proxy.example.com:443"},
379+
{"http with port", "http://proxy.example.com:8080", "proxy.example.com:8080"},
380+
{"http without port", "http://proxy.example.com", "proxy.example.com:80"},
381+
}
382+
for _, tt := range tests {
383+
t.Run(tt.name, func(t *testing.T) {
384+
u, err := url.Parse(tt.url)
385+
if err != nil {
386+
t.Fatalf("parse URL: %v", err)
387+
}
388+
got := proxyDialAddr(u)
389+
if got != tt.want {
390+
t.Errorf("proxyHostPort(%s) = %q, want %q", tt.url, got, tt.want)
391+
}
392+
})
393+
}
394+
}

0 commit comments

Comments
 (0)