main
Raw Download raw file
  1package proxy
  2
import (
	"errors"
	"fmt"
	"log/slog"
	"net"
	"net/http"
	"strings"
	"time"

	"github.com/internetarchive/gowarcprox/pkg/config"
)
 13
 14// Handler handles HTTP proxy requests
 15type Handler struct {
 16	server *Server
 17	config *config.Config
 18	client *http.Client
 19	logger *slog.Logger
 20}
 21
 22// NewHandler creates a new HTTP handler
 23func NewHandler(server *Server, cfg *config.Config, logger *slog.Logger) *Handler {
 24	return &Handler{
 25		server: server,
 26		config: cfg,
 27		client: &http.Client{
 28			// Don't follow redirects - let the client handle them
 29			CheckRedirect: func(req *http.Request, via []*http.Request) error {
 30				return http.ErrUseLastResponse
 31			},
 32			Timeout: cfg.SocketTimeout,
 33		},
 34		logger: logger,
 35	}
 36}
 37
 38// ServeHTTP handles incoming HTTP requests
 39func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 40	h.logger.Debug("received request",
 41		"method", r.Method,
 42		"url", r.URL.String(),
 43		"host", r.Host)
 44
 45	// Handle CONNECT for HTTPS (will be implemented in Phase 2)
 46	if r.Method == http.MethodConnect {
 47		h.handleConnect(w, r)
 48		return
 49	}
 50
 51	// Handle regular HTTP proxy requests
 52	h.handleHTTP(w, r)
 53}
 54
 55// handleConnect handles CONNECT requests for HTTPS tunneling
 56func (h *Handler) handleConnect(w http.ResponseWriter, r *http.Request) {
 57	if err := h.handleConnectMITM(w, r); err != nil {
 58		h.logger.Error("CONNECT handler error",
 59			"host", r.Host,
 60			"error", err)
 61	}
 62}
 63
 64// handleHTTP handles regular HTTP proxy requests
 65func (h *Handler) handleHTTP(w http.ResponseWriter, r *http.Request) {
 66	startTime := time.Now()
 67
 68	// For proxy requests, the URL should be absolute
 69	// Make sure we have a valid URL
 70	if r.URL.Scheme == "" {
 71		h.logger.Error("invalid proxy request: missing scheme", "url", r.URL.String())
 72		http.Error(w, "Bad Request: URL must be absolute for proxy requests", http.StatusBadRequest)
 73		return
 74	}
 75
 76	// Read and record request body (with size limit if configured)
 77	reqBody, err := readRequestBody(r, h.config.MaxResourceSize)
 78	if err != nil {
 79		h.logger.Error("failed to read request body", "error", err)
 80		if errors.Is(err, ErrResourceTooLarge) {
 81			http.Error(w, "Request Entity Too Large", http.StatusRequestEntityTooLarge)
 82		} else {
 83			http.Error(w, "Bad Gateway", http.StatusBadGateway)
 84		}
 85		return
 86	}
 87
 88	// Create a new request to the remote server
 89	outReq, err := http.NewRequest(r.Method, r.URL.String(), r.Body)
 90	if err != nil {
 91		h.logger.Error("failed to create outbound request", "error", err)
 92		http.Error(w, "Bad Gateway", http.StatusBadGateway)
 93		return
 94	}
 95
 96	// Copy headers from original request
 97	// Remove hop-by-hop headers
 98	for name, values := range r.Header {
 99		if isHopByHopHeader(name) {
100			continue
101		}
102		for _, value := range values {
103			outReq.Header.Add(name, value)
104		}
105	}
106
107	// Set X-Forwarded-For header
108	if clientIP := getClientIP(r); clientIP != "" {
109		prior := outReq.Header.Get("X-Forwarded-For")
110		if prior != "" {
111			outReq.Header.Set("X-Forwarded-For", prior+", "+clientIP)
112		} else {
113			outReq.Header.Set("X-Forwarded-For", clientIP)
114		}
115	}
116
117	// Send the request to the remote server
118	resp, err := h.client.Do(outReq)
119	if err != nil {
120		h.logger.Error("failed to fetch from remote server",
121			"url", r.URL.String(),
122			"error", err)
123		http.Error(w, "Bad Gateway", http.StatusBadGateway)
124		return
125	}
126	defer resp.Body.Close()
127
128	// Read response body for recording (with size limit if configured)
129	respBody, err := readResponseBody(resp.Body, h.config.MaxResourceSize)
130	if err != nil {
131		h.logger.Error("failed to read response body",
132			"url", r.URL.String(),
133			"error", err)
134		if errors.Is(err, ErrResourceTooLarge) {
135			http.Error(w, "Response Too Large", http.StatusBadGateway)
136		} else {
137			http.Error(w, "Bad Gateway", http.StatusBadGateway)
138		}
139		return
140	}
141
142	// Copy response headers
143	for name, values := range resp.Header {
144		if isHopByHopHeader(name) {
145			continue
146		}
147		for _, value := range values {
148			w.Header().Add(name, value)
149		}
150	}
151
152	// Write status code
153	w.WriteHeader(resp.StatusCode)
154
155	// Write response body to client
156	written, err := w.Write(respBody)
157	if err != nil {
158		h.logger.Error("failed to write response body",
159			"url", r.URL.String(),
160			"error", err)
161		return
162	}
163
164	// Get remote address
165	remoteAddr := ""
166	if resp.Request != nil && resp.Request.RemoteAddr != "" {
167		remoteAddr = resp.Request.RemoteAddr
168	}
169
170	// Create RecordedURL and enqueue to pipeline
171	ru := h.createRecordedURL(r, resp, reqBody, respBody, startTime, remoteAddr)
172	if err := h.server.pipeline.Enqueue(ru); err != nil {
173		h.logger.Error("failed to enqueue recorded URL",
174			"url", r.URL.String(),
175			"error", err)
176	}
177
178	h.logger.Info("proxied request",
179		"method", r.Method,
180		"url", r.URL.String(),
181		"status", resp.StatusCode,
182		"bytes", written,
183		"digest", ru.PayloadDigest)
184}
185
// isHopByHopHeader reports whether name is a hop-by-hop header, i.e. one that
// is meaningful only for a single transport-level connection and must not be
// forwarded by a proxy. The comparison is case-insensitive.
//
// In addition to the standard connection-level headers, the non-standard
// Proxy-Connection header (sent by some clients when talking to a proxy) is
// treated as hop-by-hop.
func isHopByHopHeader(name string) bool {
	switch strings.ToLower(name) {
	case "connection",
		"proxy-connection",
		"keep-alive",
		"proxy-authenticate",
		"proxy-authorization",
		"te",
		"trailer",
		"transfer-encoding",
		"upgrade":
		return true
	}
	return false
}
208
// getClientIP returns the client address for r, chosen in this order:
//
//  1. the X-Real-IP header, if set;
//  2. the first entry of the X-Forwarded-For header, whitespace-trimmed;
//  3. the host part of r.RemoteAddr.
//
// RemoteAddr is split with net.SplitHostPort so that IPv6 peers such as
// "[::1]:1234" yield "::1" rather than a bracketed or truncated value. If
// RemoteAddr is not in host:port form it is returned unchanged.
func getClientIP(r *http.Request) string {
	if ip := r.Header.Get("X-Real-IP"); ip != "" {
		return ip
	}

	if forwarded := r.Header.Get("X-Forwarded-For"); forwarded != "" {
		// The header may hold a comma-separated chain; the first entry is the
		// original client.
		first, _, _ := strings.Cut(forwarded, ",")
		return strings.TrimSpace(first)
	}

	host, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		// No port (or otherwise unparsable): hand back what we have.
		return r.RemoteAddr
	}
	return host
}
227
228// logRequest logs detailed request information for debugging
229func (h *Handler) logRequest(r *http.Request) {
230	h.logger.Debug("request details",
231		"method", r.Method,
232		"url", r.URL.String(),
233		"proto", r.Proto,
234		"host", r.Host,
235		"remote_addr", r.RemoteAddr,
236		"content_length", r.ContentLength)
237
238	for name, values := range r.Header {
239		for _, value := range values {
240			h.logger.Debug(fmt.Sprintf("header: %s: %s", name, value))
241		}
242	}
243}