main
1package proxy
2
import (
	"bufio"
	"bytes"
	"crypto/tls"
	"errors"
	"fmt"
	"io"
	"net"
	"net/http"
	"strings"
	"time"
)
14
15// handleConnectMITM handles the HTTPS MITM proxy logic
16func (h *Handler) handleConnectMITM(w http.ResponseWriter, r *http.Request) error {
17 // Extract hostname and port from the request
18 host := r.Host
19 if !strings.Contains(host, ":") {
20 host = host + ":443"
21 }
22
23 hostname, _, err := net.SplitHostPort(host)
24 if err != nil {
25 h.logger.Error("failed to parse host", "host", host, "error", err)
26 return fmt.Errorf("invalid host: %w", err)
27 }
28
29 h.logger.Debug("CONNECT request",
30 "host", host,
31 "hostname", hostname)
32
33 // Hijack the connection
34 hijacker, ok := w.(http.Hijacker)
35 if !ok {
36 h.logger.Error("ResponseWriter doesn't support hijacking")
37 http.Error(w, "Internal Server Error", http.StatusInternalServerError)
38 return fmt.Errorf("hijacking not supported")
39 }
40
41 clientConn, clientBuf, err := hijacker.Hijack()
42 if err != nil {
43 h.logger.Error("failed to hijack connection", "error", err)
44 http.Error(w, "Internal Server Error", http.StatusInternalServerError)
45 return fmt.Errorf("hijack failed: %w", err)
46 }
47 defer clientConn.Close()
48
49 // Send "200 Connection Established" to client
50 response := "HTTP/1.1 200 Connection Established\r\n\r\n"
51 if _, err := clientConn.Write([]byte(response)); err != nil {
52 h.logger.Error("failed to send connection established", "error", err)
53 return fmt.Errorf("failed to send 200: %w", err)
54 }
55
56 // Get or generate certificate for this hostname
57 certPEM, keyPEM, err := h.server.certAuth.GetCertificate(hostname)
58 if err != nil {
59 h.logger.Error("failed to get certificate",
60 "hostname", hostname,
61 "error", err)
62 return fmt.Errorf("certificate generation failed: %w", err)
63 }
64
65 // Load the certificate
66 cert, err := tls.X509KeyPair(certPEM, keyPEM)
67 if err != nil {
68 h.logger.Error("failed to load certificate",
69 "hostname", hostname,
70 "error", err)
71 return fmt.Errorf("failed to load certificate: %w", err)
72 }
73
74 // Wrap client connection with TLS (server-side)
75 tlsConfig := &tls.Config{
76 Certificates: []tls.Certificate{cert},
77 }
78
79 tlsClientConn := tls.Server(clientConn, tlsConfig)
80 defer tlsClientConn.Close()
81
82 // Perform TLS handshake
83 if err := tlsClientConn.Handshake(); err != nil {
84 h.logger.Error("TLS handshake failed",
85 "hostname", hostname,
86 "error", err)
87 return fmt.Errorf("TLS handshake failed: %w", err)
88 }
89
90 h.logger.Debug("TLS handshake successful", "hostname", hostname)
91
92 startTime := time.Now()
93
94 // Now read the actual HTTP request from the encrypted connection
95 reader := bufio.NewReader(tlsClientConn)
96 req, err := http.ReadRequest(reader)
97 if err != nil {
98 // EOF is normal when client closes connection
99 if err != io.EOF {
100 h.logger.Error("failed to read request from TLS connection",
101 "hostname", hostname,
102 "error", err)
103 }
104 return nil // Don't return error for EOF
105 }
106
107 // Read request body for recording (with size limit if configured)
108 reqBody, err := readRequestBody(req, h.config.MaxResourceSize)
109 if err != nil {
110 h.logger.Error("failed to read request body", "error", err)
111 return fmt.Errorf("failed to read request body: %w", err)
112 }
113
114 // Fix up the request URL
115 // In HTTPS proxy mode, the request URL doesn't have scheme/host
116 if req.URL.Scheme == "" {
117 req.URL.Scheme = "https"
118 }
119 if req.URL.Host == "" {
120 req.URL.Host = hostname
121 }
122 req.Host = hostname
123
124 h.logger.Debug("decrypted request",
125 "method", req.Method,
126 "url", req.URL.String(),
127 "host", req.Host)
128
129 // Connect to the remote server
130 remoteConn, err := tls.Dial("tcp", host, &tls.Config{
131 ServerName: hostname,
132 })
133 if err != nil {
134 h.logger.Error("failed to connect to remote server",
135 "host", host,
136 "error", err)
137 // Send error response to client
138 resp := &http.Response{
139 StatusCode: http.StatusBadGateway,
140 ProtoMajor: 1,
141 ProtoMinor: 1,
142 Body: io.NopCloser(strings.NewReader("Bad Gateway")),
143 }
144 resp.Write(tlsClientConn)
145 return fmt.Errorf("failed to connect to remote: %w", err)
146 }
147 defer remoteConn.Close()
148
149 // Send the request to the remote server
150 if err := req.Write(remoteConn); err != nil {
151 h.logger.Error("failed to write request to remote server",
152 "url", req.URL.String(),
153 "error", err)
154 resp := &http.Response{
155 StatusCode: http.StatusBadGateway,
156 ProtoMajor: 1,
157 ProtoMinor: 1,
158 Body: io.NopCloser(strings.NewReader("Bad Gateway")),
159 }
160 resp.Write(tlsClientConn)
161 return fmt.Errorf("failed to write request: %w", err)
162 }
163
164 // Read response from remote server
165 remoteReader := bufio.NewReader(remoteConn)
166 resp, err := http.ReadResponse(remoteReader, req)
167 if err != nil {
168 h.logger.Error("failed to read response from remote server",
169 "url", req.URL.String(),
170 "error", err)
171 return fmt.Errorf("failed to read response: %w", err)
172 }
173 defer resp.Body.Close()
174
175 // Read response body for recording (with size limit if configured)
176 respBody, err := readResponseBody(resp.Body, h.config.MaxResourceSize)
177 if err != nil {
178 h.logger.Error("failed to read response body",
179 "url", req.URL.String(),
180 "error", err)
181 return fmt.Errorf("failed to read response body: %w", err)
182 }
183
184 // Replace response body with buffer for writing
185 resp.Body = io.NopCloser(bytes.NewReader(respBody))
186
187 // Send response back to client
188 if err := resp.Write(tlsClientConn); err != nil {
189 h.logger.Error("failed to write response to client",
190 "url", req.URL.String(),
191 "error", err)
192 return fmt.Errorf("failed to write response: %w", err)
193 }
194
195 // Create RecordedURL and enqueue to pipeline
196 ru := h.createRecordedURL(req, resp, reqBody, respBody, startTime, host)
197 if err := h.server.pipeline.Enqueue(ru); err != nil {
198 h.logger.Error("failed to enqueue recorded URL",
199 "url", req.URL.String(),
200 "error", err)
201 }
202
203 h.logger.Info("proxied HTTPS request",
204 "method", req.Method,
205 "url", req.URL.String(),
206 "status", resp.StatusCode,
207 "digest", ru.PayloadDigest)
208
209 // If there's buffered data from the hijack, it should be handled
210 // but for CONNECT, there usually isn't any
211 if clientBuf.Reader.Buffered() > 0 {
212 h.logger.Warn("unexpected buffered data after CONNECT")
213 }
214
215 return nil
216}