package proxy

import (
	"bytes"
	"crypto/sha1"
	"crypto/sha256"
	"encoding/base32"
	"fmt"
	"io"
	"net"
	"net/http"
	"strings"
	"time"

	"github.com/internetarchive/gowarcprox/internal/models"
	"github.com/zeebo/blake3"
)

// RecordingResponseWriter wraps http.ResponseWriter to capture response data
type RecordingResponseWriter struct {
	http.ResponseWriter
	statusCode int
	body       *bytes.Buffer
	written    int64
}

// NewRecordingResponseWriter creates a new recording response writer
func NewRecordingResponseWriter(w http.ResponseWriter) *RecordingResponseWriter {
	return &RecordingResponseWriter{
		ResponseWriter: w,
		statusCode:     http.StatusOK, // Default status
		body:           &bytes.Buffer{},
	}
}

// WriteHeader captures the status code and forwards to underlying writer
func (r *RecordingResponseWriter) WriteHeader(statusCode int) {
	r.statusCode = statusCode
	r.ResponseWriter.WriteHeader(statusCode)
}

// Write captures the response body and forwards to underlying writer
func (r *RecordingResponseWriter) Write(data []byte) (int, error) {
	// Write to buffer for recording
	r.body.Write(data)

	// Write to actual response
	n, err := r.ResponseWriter.Write(data)
	r.written += int64(n)
	return n, err
}

// StatusCode returns the captured status code
func (r *RecordingResponseWriter) StatusCode() int {
	return r.statusCode
}

// Body returns the captured response body
func (r *RecordingResponseWriter) Body() []byte {
	return r.body.Bytes()
}

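// The sketch below is illustrative only (the wiring and function name are
// hypothetical, not part of this package's real handler chain): it shows how a
// downstream http.Handler can be wrapped so the status and body it writes are
// still available for recording after it returns.
func exampleRecordResponse(next http.Handler, w http.ResponseWriter, req *http.Request) (int, []byte) {
	rec := NewRecordingResponseWriter(w)
	next.ServeHTTP(rec, req)
	// The client has already received the response; rec keeps a copy of it
	return rec.StatusCode(), rec.Body()
}
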
// createRecordedURL creates a RecordedURL from request and response data
func (h *Handler) createRecordedURL(
	req *http.Request,
	resp *http.Response,
	reqBody []byte,
	respBody []byte,
	startTime time.Time,
	remoteAddr string,
) *models.RecordedURL {

	duration := time.Since(startTime)

	// Create RecordedURL
	ru := &models.RecordedURL{
		URL:            req.URL.String(),
		Method:         req.Method,
		RequestHeader:  req.Header.Clone(),
		RequestBody:    reqBody,
		StatusCode:     resp.StatusCode,
		StatusMessage:  resp.Status,
		ResponseHeader: resp.Header.Clone(),
		ResponseBody:   respBody,
		Timestamp:      startTime,
		Duration:       duration,
		RemoteAddr:     remoteAddr,
		ClientAddr:     req.RemoteAddr,
		ContentType:    resp.Header.Get("Content-Type"),
		ContentLength:  int64(len(respBody)),
	}

	// Extract client IP (fall back to the raw address if it has no port)
	if host, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
		ru.ClientIP = host
	} else {
		ru.ClientIP = req.RemoteAddr
	}

	// Extract remote IP
	if host, _, err := net.SplitHostPort(remoteAddr); err == nil {
		ru.RemoteIP = host
	} else {
		ru.RemoteIP = remoteAddr
	}

	// Calculate payload digest (just the response body)
	ru.PayloadDigest = h.calculateDigest(respBody)

	// Calculate block digest (entire HTTP response block)
	blockData := h.buildResponseBlock(resp, respBody)
	ru.BlockDigest = h.calculateDigest(blockData)

	// Parse Warcprox-Meta header if present
	if metaHeader := req.Header.Get("Warcprox-Meta"); metaHeader != "" {
		// TODO: Parse Warcprox-Meta JSON in Phase 7
		h.logger.Debug("Warcprox-Meta header present (parsing not yet implemented)",
			"header", metaHeader)
	}

	return ru
}

// calculateDigest calculates the digest of data using the configured algorithm
func (h *Handler) calculateDigest(data []byte) string {
	algorithm := h.config.DigestAlgorithm

	switch strings.ToLower(algorithm) {
	case "sha256":
		hash := sha256.Sum256(data)
		encoded := base32.StdEncoding.EncodeToString(hash[:])
		// Remove padding
		encoded = strings.TrimRight(encoded, "=")
		return "sha256:" + encoded

	case "blake3":
		// Note: blake3 digests are hex-encoded, unlike the base32-encoded SHA digests
		hash := blake3.Sum256(data)
		encoded := fmt.Sprintf("%x", hash)
		return "blake3:" + encoded

	default:
		// "sha1" and any unrecognized algorithm fall back to SHA-1
		hash := sha1.Sum(data)
		encoded := base32.StdEncoding.EncodeToString(hash[:])
		// Remove padding
		encoded = strings.TrimRight(encoded, "=")
		return "sha1:" + encoded
	}
}

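// Digest labels follow the WARC-style "algorithm:value" convention; with the
// SHA-1 default a label is "sha1:" followed by 32 unpadded base32 characters.
// A minimal sketch of splitting a label back into its parts (the helper name
// is illustrative, not an existing API in this package):
func splitDigestLabel(label string) (algorithm, value string) {
	if idx := strings.Index(label, ":"); idx != -1 {
		return label[:idx], label[idx+1:]
	}
	return "", label
}
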
// buildResponseBlock builds the complete HTTP response block (headers + body).
// Note that Go stores headers in a map, so header order (and canonicalized
// header names) in the rebuilt block may differ from the bytes on the wire.
func (h *Handler) buildResponseBlock(resp *http.Response, body []byte) []byte {
	var buf bytes.Buffer

	// Write status line
	fmt.Fprintf(&buf, "%s %s\r\n", resp.Proto, resp.Status)

	// Write headers
	for name, values := range resp.Header {
		for _, value := range values {
			fmt.Fprintf(&buf, "%s: %s\r\n", name, value)
		}
	}

	// Write blank line separating headers from body
	buf.WriteString("\r\n")

	// Write body
	buf.Write(body)

	return buf.Bytes()
}

// ErrResourceTooLarge is returned when a resource exceeds the maximum allowed size
var ErrResourceTooLarge = fmt.Errorf("resource size exceeds maximum allowed")

// readLimitedBody reads from a reader with an optional size limit.
// If maxSize is 0, no limit is applied.
// If maxSize is positive and the content exceeds it, returns ErrResourceTooLarge.
func readLimitedBody(r io.Reader, maxSize int64) ([]byte, error) {
	if r == nil {
		return nil, nil
	}

	if maxSize > 0 {
		// Read at most maxSize+1 bytes; the extra byte lets us detect
		// that the content exceeds maxSize without reading all of it
		limited := io.LimitReader(r, maxSize+1)
		body, err := io.ReadAll(limited)
		if err != nil {
			return nil, err
		}
		if int64(len(body)) > maxSize {
			return nil, ErrResourceTooLarge
		}
		return body, nil
	}

	return io.ReadAll(r)
}

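// Illustrative sketch (the caller below is hypothetical, not part of this
// package's real flow): capping a body at an assumed 1 MiB and skipping the
// capture when the limit is exceeded.
func exampleReadCappedBody(r io.Reader) (body []byte, recorded bool, err error) {
	const maxBody = 1 << 20 // assumed cap for the sake of the example
	body, err = readLimitedBody(r, maxBody)
	if err == ErrResourceTooLarge {
		// Too large to record; a caller would typically proxy without capturing
		return nil, false, nil
	}
	if err != nil {
		return nil, false, err
	}
	return body, true, nil
}
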
// readRequestBody reads and returns the request body, replacing it with a buffer.
// If maxSize is positive, it limits the body size and returns ErrResourceTooLarge if exceeded.
func readRequestBody(req *http.Request, maxSize int64) ([]byte, error) {
	if req.Body == nil {
		return nil, nil
	}

	body, err := readLimitedBody(req.Body, maxSize)
	if err != nil {
		return nil, fmt.Errorf("failed to read request body: %w", err)
	}

	// Replace body with a buffer so it can be read again
	req.Body = io.NopCloser(bytes.NewReader(body))

	return body, nil
}

// readResponseBody reads and returns the response body.
// If maxSize is positive, it limits the body size and returns ErrResourceTooLarge if exceeded.
func readResponseBody(r io.Reader, maxSize int64) ([]byte, error) {
	body, err := readLimitedBody(r, maxSize)
	if err != nil {
		return nil, fmt.Errorf("failed to read response body: %w", err)
	}
	return body, nil
}