main
1package proxy
2
import (
	"bytes"
	"crypto/sha1"
	"crypto/sha256"
	"encoding/base32"
	"fmt"
	"io"
	"net"
	"net/http"
	"strings"
	"time"

	"github.com/internetarchive/gowarcprox/internal/models"
	"github.com/zeebo/blake3"
)
17
// RecordingResponseWriter wraps an http.ResponseWriter, forwarding every call
// to it while keeping a copy of the status code and body that were sent.
//
// Only the http.ResponseWriter methods are exposed directly. Optional
// interfaces implemented by the wrapped writer (http.Flusher, http.Hijacker,
// ...) can be reached through Unwrap, for example via http.ResponseController.
type RecordingResponseWriter struct {
	http.ResponseWriter

	// statusCode is the final status recorded for the response. It starts
	// at http.StatusOK, which is what net/http sends when WriteHeader is
	// never called.
	statusCode int
	// wroteHeader is true once the final status has been committed, either
	// by WriteHeader with a non-informational code or implicitly by Write.
	wroteHeader bool
	// body accumulates a copy of every byte slice passed to Write.
	body *bytes.Buffer
	// written counts the bytes the wrapped writer reported as written.
	written int64
}

// NewRecordingResponseWriter returns a RecordingResponseWriter that forwards
// to w and starts with a recorded status of http.StatusOK and an empty body.
func NewRecordingResponseWriter(w http.ResponseWriter) *RecordingResponseWriter {
	return &RecordingResponseWriter{
		ResponseWriter: w,
		statusCode:     http.StatusOK, // implicit status if WriteHeader is never called
		body:           &bytes.Buffer{},
	}
}

// WriteHeader records the response status and forwards the call to the
// wrapped writer.
//
// net/http honours only the first final status of a response: later calls,
// and calls made after Write has already sent an implicit 200, are ignored.
// The recorded status follows the same rule so that it matches what the
// client actually received. Informational 1xx codes (other than 101) may
// precede the final status and are forwarded without being recorded.
func (r *RecordingResponseWriter) WriteHeader(statusCode int) {
	informational := statusCode >= 100 && statusCode <= 199 &&
		statusCode != http.StatusSwitchingProtocols
	if !r.wroteHeader && !informational {
		r.statusCode = statusCode
		r.wroteHeader = true
	}
	// Always forward, so the wrapped writer sees exactly the calls the
	// handler made (including superfluous ones it may want to report).
	r.ResponseWriter.WriteHeader(statusCode)
}

// Write appends data to the recorded body, forwards it to the wrapped writer
// and returns the wrapped writer's result. All of data is recorded even if
// the wrapped writer reports a short write.
func (r *RecordingResponseWriter) Write(data []byte) (int, error) {
	// A Write commits the headers; the status can no longer change.
	r.wroteHeader = true

	// bytes.Buffer.Write always returns a nil error.
	r.body.Write(data)

	n, err := r.ResponseWriter.Write(data)
	r.written += int64(n)
	return n, err
}

// Unwrap returns the wrapped http.ResponseWriter. http.ResponseController
// uses it to find optional interfaces such as http.Flusher on the original
// writer.
func (r *RecordingResponseWriter) Unwrap() http.ResponseWriter {
	return r.ResponseWriter
}

// StatusCode returns the recorded response status. It is http.StatusOK until
// a final status has been set through WriteHeader.
func (r *RecordingResponseWriter) StatusCode() int {
	return r.statusCode
}

// Body returns the bytes recorded so far. The slice aliases the internal
// buffer and is only valid until the next call to Write; copy it if it must
// outlive further writes.
func (r *RecordingResponseWriter) Body() []byte {
	return r.body.Bytes()
}
61
62// createRecordedURL creates a RecordedURL from request and response data
63func (h *Handler) createRecordedURL(
64 req *http.Request,
65 resp *http.Response,
66 reqBody []byte,
67 respBody []byte,
68 startTime time.Time,
69 remoteAddr string,
70) *models.RecordedURL {
71
72 duration := time.Since(startTime)
73
74 // Create RecordedURL
75 ru := &models.RecordedURL{
76 URL: req.URL.String(),
77 Method: req.Method,
78 RequestHeader: req.Header.Clone(),
79 RequestBody: reqBody,
80 StatusCode: resp.StatusCode,
81 StatusMessage: resp.Status,
82 ResponseHeader: resp.Header.Clone(),
83 ResponseBody: respBody,
84 Timestamp: startTime,
85 Duration: duration,
86 RemoteAddr: remoteAddr,
87 ClientAddr: req.RemoteAddr,
88 ContentType: resp.Header.Get("Content-Type"),
89 ContentLength: int64(len(respBody)),
90 }
91
92 // Extract client IP
93 if idx := strings.LastIndex(req.RemoteAddr, ":"); idx != -1 {
94 ru.ClientIP = req.RemoteAddr[:idx]
95 } else {
96 ru.ClientIP = req.RemoteAddr
97 }
98
99 // Extract remote IP
100 if idx := strings.LastIndex(remoteAddr, ":"); idx != -1 {
101 ru.RemoteIP = remoteAddr[:idx]
102 } else {
103 ru.RemoteIP = remoteAddr
104 }
105
106 // Calculate payload digest (just the response body)
107 ru.PayloadDigest = h.calculateDigest(respBody)
108
109 // Calculate block digest (entire HTTP response block)
110 blockData := h.buildResponseBlock(resp, respBody)
111 ru.BlockDigest = h.calculateDigest(blockData)
112
113 // Parse Warcprox-Meta header if present
114 if metaHeader := req.Header.Get("Warcprox-Meta"); metaHeader != "" {
115 // TODO: Parse Warcprox-Meta JSON in Phase 7
116 h.logger.Debug("Warcprox-Meta header present (parsing not yet implemented)",
117 "header", metaHeader)
118 }
119
120 return ru
121}
122
123// calculateDigest calculates the digest of data using the configured algorithm
124func (h *Handler) calculateDigest(data []byte) string {
125 algorithm := h.config.DigestAlgorithm
126
127 switch strings.ToLower(algorithm) {
128 case "sha1":
129 hash := sha1.Sum(data)
130 encoded := base32.StdEncoding.EncodeToString(hash[:])
131 // Remove padding
132 encoded = strings.TrimRight(encoded, "=")
133 return "sha1:" + encoded
134
135 case "sha256":
136 hash := sha256.Sum256(data)
137 encoded := base32.StdEncoding.EncodeToString(hash[:])
138 encoded = strings.TrimRight(encoded, "=")
139 return "sha256:" + encoded
140
141 case "blake3":
142 hash := blake3.Sum256(data)
143 encoded := fmt.Sprintf("%x", hash)
144 return "blake3:" + encoded
145
146 default:
147 // Default to SHA1
148 hash := sha1.Sum(data)
149 encoded := base32.StdEncoding.EncodeToString(hash[:])
150 encoded = strings.TrimRight(encoded, "=")
151 return "sha1:" + encoded
152 }
153}
154
155// buildResponseBlock builds the complete HTTP response block (headers + body)
156func (h *Handler) buildResponseBlock(resp *http.Response, body []byte) []byte {
157 var buf bytes.Buffer
158
159 // Write status line
160 fmt.Fprintf(&buf, "%s %s\r\n", resp.Proto, resp.Status)
161
162 // Write headers
163 for name, values := range resp.Header {
164 for _, value := range values {
165 fmt.Fprintf(&buf, "%s: %s\r\n", name, value)
166 }
167 }
168
169 // Write blank line separating headers from body
170 buf.WriteString("\r\n")
171
172 // Write body
173 buf.Write(body)
174
175 return buf.Bytes()
176}
177
// ErrResourceTooLarge is returned when a resource exceeds the maximum allowed size.
var ErrResourceTooLarge = fmt.Errorf("resource size exceeds maximum allowed")

// readLimitedBody reads r until EOF, optionally enforcing a size limit.
//
// A nil r yields (nil, nil). If maxSize is zero or negative, no limit is
// applied. If maxSize is positive and r yields more than maxSize bytes, it
// returns ErrResourceTooLarge after consuming at most maxSize+1 bytes of r.
// Errors from reading r are returned unchanged.
func readLimitedBody(r io.Reader, maxSize int64) ([]byte, error) {
	// Largest int64 (math.MaxInt64), spelled out to avoid an extra import.
	const maxInt64 = 1<<63 - 1

	if r == nil {
		return nil, nil
	}

	// With the largest possible limit, maxSize+1 would wrap to a negative
	// value, which io.LimitReader treats as "read nothing". No input can
	// exceed that limit anyway, so read it unbounded.
	if maxSize <= 0 || maxSize == maxInt64 {
		return io.ReadAll(r)
	}

	// Read one byte past the limit so that an oversized body can be told
	// apart from one that is exactly maxSize bytes long.
	body, err := io.ReadAll(io.LimitReader(r, maxSize+1))
	if err != nil {
		return nil, err
	}
	if int64(len(body)) > maxSize {
		return nil, ErrResourceTooLarge
	}
	return body, nil
}
205
206// readRequestBody reads and returns the request body, replacing it with a buffer.
207// If maxSize is positive, it limits the body size and returns ErrResourceTooLarge if exceeded.
208func readRequestBody(req *http.Request, maxSize int64) ([]byte, error) {
209 if req.Body == nil {
210 return nil, nil
211 }
212
213 body, err := readLimitedBody(req.Body, maxSize)
214 if err != nil {
215 return nil, fmt.Errorf("failed to read request body: %w", err)
216 }
217
218 // Replace body with a buffer so it can be read again
219 req.Body = io.NopCloser(bytes.NewReader(body))
220
221 return body, nil
222}
223
224// readResponseBody reads and returns the response body.
225// If maxSize is positive, it limits the body size and returns ErrResourceTooLarge if exceeded.
226func readResponseBody(r io.Reader, maxSize int64) ([]byte, error) {
227 body, err := readLimitedBody(r, maxSize)
228 if err != nil {
229 return nil, fmt.Errorf("failed to read response body: %w", err)
230 }
231 return body, nil
232}