main
Raw Download raw file
  1package proxy
  2
  3import (
  4	"bufio"
  5	"bytes"
  6	"crypto/tls"
  7	"fmt"
  8	"io"
  9	"net"
 10	"net/http"
 11	"strings"
 12	"time"
 13)
 14
 15// handleConnectMITM handles the HTTPS MITM proxy logic
 16func (h *Handler) handleConnectMITM(w http.ResponseWriter, r *http.Request) error {
 17	// Extract hostname and port from the request
 18	host := r.Host
 19	if !strings.Contains(host, ":") {
 20		host = host + ":443"
 21	}
 22
 23	hostname, _, err := net.SplitHostPort(host)
 24	if err != nil {
 25		h.logger.Error("failed to parse host", "host", host, "error", err)
 26		return fmt.Errorf("invalid host: %w", err)
 27	}
 28
 29	h.logger.Debug("CONNECT request",
 30		"host", host,
 31		"hostname", hostname)
 32
 33	// Hijack the connection
 34	hijacker, ok := w.(http.Hijacker)
 35	if !ok {
 36		h.logger.Error("ResponseWriter doesn't support hijacking")
 37		http.Error(w, "Internal Server Error", http.StatusInternalServerError)
 38		return fmt.Errorf("hijacking not supported")
 39	}
 40
 41	clientConn, clientBuf, err := hijacker.Hijack()
 42	if err != nil {
 43		h.logger.Error("failed to hijack connection", "error", err)
 44		http.Error(w, "Internal Server Error", http.StatusInternalServerError)
 45		return fmt.Errorf("hijack failed: %w", err)
 46	}
 47	defer clientConn.Close()
 48
 49	// Send "200 Connection Established" to client
 50	response := "HTTP/1.1 200 Connection Established\r\n\r\n"
 51	if _, err := clientConn.Write([]byte(response)); err != nil {
 52		h.logger.Error("failed to send connection established", "error", err)
 53		return fmt.Errorf("failed to send 200: %w", err)
 54	}
 55
 56	// Get or generate certificate for this hostname
 57	certPEM, keyPEM, err := h.server.certAuth.GetCertificate(hostname)
 58	if err != nil {
 59		h.logger.Error("failed to get certificate",
 60			"hostname", hostname,
 61			"error", err)
 62		return fmt.Errorf("certificate generation failed: %w", err)
 63	}
 64
 65	// Load the certificate
 66	cert, err := tls.X509KeyPair(certPEM, keyPEM)
 67	if err != nil {
 68		h.logger.Error("failed to load certificate",
 69			"hostname", hostname,
 70			"error", err)
 71		return fmt.Errorf("failed to load certificate: %w", err)
 72	}
 73
 74	// Wrap client connection with TLS (server-side)
 75	tlsConfig := &tls.Config{
 76		Certificates: []tls.Certificate{cert},
 77	}
 78
 79	tlsClientConn := tls.Server(clientConn, tlsConfig)
 80	defer tlsClientConn.Close()
 81
 82	// Perform TLS handshake
 83	if err := tlsClientConn.Handshake(); err != nil {
 84		h.logger.Error("TLS handshake failed",
 85			"hostname", hostname,
 86			"error", err)
 87		return fmt.Errorf("TLS handshake failed: %w", err)
 88	}
 89
 90	h.logger.Debug("TLS handshake successful", "hostname", hostname)
 91
 92	startTime := time.Now()
 93
 94	// Now read the actual HTTP request from the encrypted connection
 95	reader := bufio.NewReader(tlsClientConn)
 96	req, err := http.ReadRequest(reader)
 97	if err != nil {
 98		// EOF is normal when client closes connection
 99		if err != io.EOF {
100			h.logger.Error("failed to read request from TLS connection",
101				"hostname", hostname,
102				"error", err)
103		}
104		return nil // Don't return error for EOF
105	}
106
107	// Read request body for recording (with size limit if configured)
108	reqBody, err := readRequestBody(req, h.config.MaxResourceSize)
109	if err != nil {
110		h.logger.Error("failed to read request body", "error", err)
111		return fmt.Errorf("failed to read request body: %w", err)
112	}
113
114	// Fix up the request URL
115	// In HTTPS proxy mode, the request URL doesn't have scheme/host
116	if req.URL.Scheme == "" {
117		req.URL.Scheme = "https"
118	}
119	if req.URL.Host == "" {
120		req.URL.Host = hostname
121	}
122	req.Host = hostname
123
124	h.logger.Debug("decrypted request",
125		"method", req.Method,
126		"url", req.URL.String(),
127		"host", req.Host)
128
129	// Connect to the remote server
130	remoteConn, err := tls.Dial("tcp", host, &tls.Config{
131		ServerName: hostname,
132	})
133	if err != nil {
134		h.logger.Error("failed to connect to remote server",
135			"host", host,
136			"error", err)
137		// Send error response to client
138		resp := &http.Response{
139			StatusCode: http.StatusBadGateway,
140			ProtoMajor: 1,
141			ProtoMinor: 1,
142			Body:       io.NopCloser(strings.NewReader("Bad Gateway")),
143		}
144		resp.Write(tlsClientConn)
145		return fmt.Errorf("failed to connect to remote: %w", err)
146	}
147	defer remoteConn.Close()
148
149	// Send the request to the remote server
150	if err := req.Write(remoteConn); err != nil {
151		h.logger.Error("failed to write request to remote server",
152			"url", req.URL.String(),
153			"error", err)
154		resp := &http.Response{
155			StatusCode: http.StatusBadGateway,
156			ProtoMajor: 1,
157			ProtoMinor: 1,
158			Body:       io.NopCloser(strings.NewReader("Bad Gateway")),
159		}
160		resp.Write(tlsClientConn)
161		return fmt.Errorf("failed to write request: %w", err)
162	}
163
164	// Read response from remote server
165	remoteReader := bufio.NewReader(remoteConn)
166	resp, err := http.ReadResponse(remoteReader, req)
167	if err != nil {
168		h.logger.Error("failed to read response from remote server",
169			"url", req.URL.String(),
170			"error", err)
171		return fmt.Errorf("failed to read response: %w", err)
172	}
173	defer resp.Body.Close()
174
175	// Read response body for recording (with size limit if configured)
176	respBody, err := readResponseBody(resp.Body, h.config.MaxResourceSize)
177	if err != nil {
178		h.logger.Error("failed to read response body",
179			"url", req.URL.String(),
180			"error", err)
181		return fmt.Errorf("failed to read response body: %w", err)
182	}
183
184	// Replace response body with buffer for writing
185	resp.Body = io.NopCloser(bytes.NewReader(respBody))
186
187	// Send response back to client
188	if err := resp.Write(tlsClientConn); err != nil {
189		h.logger.Error("failed to write response to client",
190			"url", req.URL.String(),
191			"error", err)
192		return fmt.Errorf("failed to write response: %w", err)
193	}
194
195	// Create RecordedURL and enqueue to pipeline
196	ru := h.createRecordedURL(req, resp, reqBody, respBody, startTime, host)
197	if err := h.server.pipeline.Enqueue(ru); err != nil {
198		h.logger.Error("failed to enqueue recorded URL",
199			"url", req.URL.String(),
200			"error", err)
201	}
202
203	h.logger.Info("proxied HTTPS request",
204		"method", req.Method,
205		"url", req.URL.String(),
206		"status", resp.StatusCode,
207		"digest", ru.PayloadDigest)
208
209	// If there's buffered data from the hijack, it should be handled
210	// but for CONNECT, there usually isn't any
211	if clientBuf.Reader.Buffered() > 0 {
212		h.logger.Warn("unexpected buffered data after CONNECT")
213	}
214
215	return nil
216}