main
1package proxy
2
3import (
4 "errors"
5 "fmt"
6 "log/slog"
7 "net/http"
8 "strings"
9 "time"
10
11 "github.com/internetarchive/gowarcprox/pkg/config"
12)
13
14// Handler handles HTTP proxy requests
15type Handler struct {
16 server *Server
17 config *config.Config
18 client *http.Client
19 logger *slog.Logger
20}
21
22// NewHandler creates a new HTTP handler
23func NewHandler(server *Server, cfg *config.Config, logger *slog.Logger) *Handler {
24 return &Handler{
25 server: server,
26 config: cfg,
27 client: &http.Client{
28 // Don't follow redirects - let the client handle them
29 CheckRedirect: func(req *http.Request, via []*http.Request) error {
30 return http.ErrUseLastResponse
31 },
32 Timeout: cfg.SocketTimeout,
33 },
34 logger: logger,
35 }
36}
37
38// ServeHTTP handles incoming HTTP requests
39func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
40 h.logger.Debug("received request",
41 "method", r.Method,
42 "url", r.URL.String(),
43 "host", r.Host)
44
45 // Handle CONNECT for HTTPS (will be implemented in Phase 2)
46 if r.Method == http.MethodConnect {
47 h.handleConnect(w, r)
48 return
49 }
50
51 // Handle regular HTTP proxy requests
52 h.handleHTTP(w, r)
53}
54
55// handleConnect handles CONNECT requests for HTTPS tunneling
56func (h *Handler) handleConnect(w http.ResponseWriter, r *http.Request) {
57 if err := h.handleConnectMITM(w, r); err != nil {
58 h.logger.Error("CONNECT handler error",
59 "host", r.Host,
60 "error", err)
61 }
62}
63
64// handleHTTP handles regular HTTP proxy requests
65func (h *Handler) handleHTTP(w http.ResponseWriter, r *http.Request) {
66 startTime := time.Now()
67
68 // For proxy requests, the URL should be absolute
69 // Make sure we have a valid URL
70 if r.URL.Scheme == "" {
71 h.logger.Error("invalid proxy request: missing scheme", "url", r.URL.String())
72 http.Error(w, "Bad Request: URL must be absolute for proxy requests", http.StatusBadRequest)
73 return
74 }
75
76 // Read and record request body (with size limit if configured)
77 reqBody, err := readRequestBody(r, h.config.MaxResourceSize)
78 if err != nil {
79 h.logger.Error("failed to read request body", "error", err)
80 if errors.Is(err, ErrResourceTooLarge) {
81 http.Error(w, "Request Entity Too Large", http.StatusRequestEntityTooLarge)
82 } else {
83 http.Error(w, "Bad Gateway", http.StatusBadGateway)
84 }
85 return
86 }
87
88 // Create a new request to the remote server
89 outReq, err := http.NewRequest(r.Method, r.URL.String(), r.Body)
90 if err != nil {
91 h.logger.Error("failed to create outbound request", "error", err)
92 http.Error(w, "Bad Gateway", http.StatusBadGateway)
93 return
94 }
95
96 // Copy headers from original request
97 // Remove hop-by-hop headers
98 for name, values := range r.Header {
99 if isHopByHopHeader(name) {
100 continue
101 }
102 for _, value := range values {
103 outReq.Header.Add(name, value)
104 }
105 }
106
107 // Set X-Forwarded-For header
108 if clientIP := getClientIP(r); clientIP != "" {
109 prior := outReq.Header.Get("X-Forwarded-For")
110 if prior != "" {
111 outReq.Header.Set("X-Forwarded-For", prior+", "+clientIP)
112 } else {
113 outReq.Header.Set("X-Forwarded-For", clientIP)
114 }
115 }
116
117 // Send the request to the remote server
118 resp, err := h.client.Do(outReq)
119 if err != nil {
120 h.logger.Error("failed to fetch from remote server",
121 "url", r.URL.String(),
122 "error", err)
123 http.Error(w, "Bad Gateway", http.StatusBadGateway)
124 return
125 }
126 defer resp.Body.Close()
127
128 // Read response body for recording (with size limit if configured)
129 respBody, err := readResponseBody(resp.Body, h.config.MaxResourceSize)
130 if err != nil {
131 h.logger.Error("failed to read response body",
132 "url", r.URL.String(),
133 "error", err)
134 if errors.Is(err, ErrResourceTooLarge) {
135 http.Error(w, "Response Too Large", http.StatusBadGateway)
136 } else {
137 http.Error(w, "Bad Gateway", http.StatusBadGateway)
138 }
139 return
140 }
141
142 // Copy response headers
143 for name, values := range resp.Header {
144 if isHopByHopHeader(name) {
145 continue
146 }
147 for _, value := range values {
148 w.Header().Add(name, value)
149 }
150 }
151
152 // Write status code
153 w.WriteHeader(resp.StatusCode)
154
155 // Write response body to client
156 written, err := w.Write(respBody)
157 if err != nil {
158 h.logger.Error("failed to write response body",
159 "url", r.URL.String(),
160 "error", err)
161 return
162 }
163
164 // Get remote address
165 remoteAddr := ""
166 if resp.Request != nil && resp.Request.RemoteAddr != "" {
167 remoteAddr = resp.Request.RemoteAddr
168 }
169
170 // Create RecordedURL and enqueue to pipeline
171 ru := h.createRecordedURL(r, resp, reqBody, respBody, startTime, remoteAddr)
172 if err := h.server.pipeline.Enqueue(ru); err != nil {
173 h.logger.Error("failed to enqueue recorded URL",
174 "url", r.URL.String(),
175 "error", err)
176 }
177
178 h.logger.Info("proxied request",
179 "method", r.Method,
180 "url", r.URL.String(),
181 "status", resp.StatusCode,
182 "bytes", written,
183 "digest", ru.PayloadDigest)
184}
185
186// isHopByHopHeader checks if a header is hop-by-hop
187// These headers should not be forwarded
188func isHopByHopHeader(name string) bool {
189 hopByHopHeaders := []string{
190 "Connection",
191 "Keep-Alive",
192 "Proxy-Authenticate",
193 "Proxy-Authorization",
194 "Te",
195 "Trailer",
196 "Transfer-Encoding",
197 "Upgrade",
198 }
199
200 nameLower := strings.ToLower(name)
201 for _, h := range hopByHopHeaders {
202 if strings.ToLower(h) == nameLower {
203 return true
204 }
205 }
206 return false
207}
208
209// getClientIP extracts the client IP address from the request
210func getClientIP(r *http.Request) string {
211 if ip := r.Header.Get("X-Real-IP"); ip != "" {
212 return ip
213 }
214 if ip := r.Header.Get("X-Forwarded-For"); ip != "" {
215 // X-Forwarded-For can contain multiple IPs, get the first one
216 if idx := strings.Index(ip, ","); idx != -1 {
217 return strings.TrimSpace(ip[:idx])
218 }
219 return ip
220 }
221 // Fall back to RemoteAddr
222 if idx := strings.LastIndex(r.RemoteAddr, ":"); idx != -1 {
223 return r.RemoteAddr[:idx]
224 }
225 return r.RemoteAddr
226}
227
228// logRequest logs detailed request information for debugging
229func (h *Handler) logRequest(r *http.Request) {
230 h.logger.Debug("request details",
231 "method", r.Method,
232 "url", r.URL.String(),
233 "proto", r.Proto,
234 "host", r.Host,
235 "remote_addr", r.RemoteAddr,
236 "content_length", r.ContentLength)
237
238 for name, values := range r.Header {
239 for _, value := range values {
240 h.logger.Debug(fmt.Sprintf("header: %s: %s", name, value))
241 }
242 }
243}