Commit eacc60f
Changed files (5)
internal
certauth
proxy
internal/certauth/certauth.go
@@ -0,0 +1,324 @@
+package certauth
+
import (
	"crypto/rand"
	"crypto/rsa"
	"crypto/x509"
	"crypto/x509/pkix"
	"encoding/pem"
	"fmt"
	"log/slog"
	"math/big"
	"net"
	"os"
	"path/filepath"
	"sync"
	"time"
)
+
// CertificateAuthority holds a signing CA (certificate plus RSA key) and
// issues per-host leaf certificates from it, caching every issued pair in
// memory.
type CertificateAuthority struct {
	caFile   string // path of the PEM file holding the CA certificate and key
	certsDir string // directory that receives a copy of each issued host certificate
	caName   string // CommonName used when a new CA certificate is created

	caCert *x509.Certificate // CA certificate used as issuer for host certificates
	caKey  *rsa.PrivateKey   // CA private key used to sign host certificates

	certCache sync.Map   // hostname (string) -> certKeyPair (PEM-encoded certificate and key)
	mu        sync.Mutex // held while a host certificate is being generated
	logger    *slog.Logger
}
+
+// NewCertificateAuthority creates or loads a certificate authority
+func NewCertificateAuthority(caFile, certsDir, caName string, logger *slog.Logger) (*CertificateAuthority, error) {
+ if logger == nil {
+ logger = slog.Default()
+ }
+
+ ca := &CertificateAuthority{
+ caFile: caFile,
+ certsDir: certsDir,
+ caName: caName,
+ logger: logger,
+ }
+
+ // Create certs directory if it doesn't exist
+ if err := os.MkdirAll(certsDir, 0755); err != nil {
+ return nil, fmt.Errorf("failed to create certs directory: %w", err)
+ }
+
+ // Load or create CA certificate
+ if err := ca.loadOrCreateCA(); err != nil {
+ return nil, fmt.Errorf("failed to initialize CA: %w", err)
+ }
+
+ return ca, nil
+}
+
+// loadOrCreateCA loads an existing CA or creates a new one
+func (ca *CertificateAuthority) loadOrCreateCA() error {
+ // Check if CA file exists
+ if _, err := os.Stat(ca.caFile); err == nil {
+ // Load existing CA
+ return ca.loadCA()
+ }
+
+ // Create new CA
+ return ca.createCA()
+}
+
+// loadCA loads an existing CA certificate and key
+func (ca *CertificateAuthority) loadCA() error {
+ ca.logger.Info("loading CA certificate", "file", ca.caFile)
+
+ // Read CA file
+ caData, err := os.ReadFile(ca.caFile)
+ if err != nil {
+ return fmt.Errorf("failed to read CA file: %w", err)
+ }
+
+ // Parse PEM blocks
+ var certPEM, keyPEM *pem.Block
+ remaining := caData
+
+ for {
+ block, rest := pem.Decode(remaining)
+ if block == nil {
+ break
+ }
+
+ switch block.Type {
+ case "CERTIFICATE":
+ certPEM = block
+ case "RSA PRIVATE KEY":
+ keyPEM = block
+ }
+
+ remaining = rest
+ }
+
+ if certPEM == nil || keyPEM == nil {
+ return fmt.Errorf("CA file must contain both certificate and private key")
+ }
+
+ // Parse certificate
+ cert, err := x509.ParseCertificate(certPEM.Bytes)
+ if err != nil {
+ return fmt.Errorf("failed to parse CA certificate: %w", err)
+ }
+ ca.caCert = cert
+
+ // Parse private key
+ key, err := x509.ParsePKCS1PrivateKey(keyPEM.Bytes)
+ if err != nil {
+ return fmt.Errorf("failed to parse CA private key: %w", err)
+ }
+ ca.caKey = key
+
+ ca.logger.Info("loaded CA certificate",
+ "subject", cert.Subject.CommonName,
+ "valid_from", cert.NotBefore,
+ "valid_to", cert.NotAfter)
+
+ return nil
+}
+
+// createCA creates a new CA certificate and key
+func (ca *CertificateAuthority) createCA() error {
+ ca.logger.Info("creating new CA certificate", "file", ca.caFile)
+
+ // Generate private key
+ key, err := rsa.GenerateKey(rand.Reader, 2048)
+ if err != nil {
+ return fmt.Errorf("failed to generate CA private key: %w", err)
+ }
+ ca.caKey = key
+
+ // Create CA certificate template
+ serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
+ if err != nil {
+ return fmt.Errorf("failed to generate serial number: %w", err)
+ }
+
+ template := &x509.Certificate{
+ SerialNumber: serialNumber,
+ Subject: pkix.Name{
+ CommonName: ca.caName,
+ Organization: []string{"gowarcprox"},
+ },
+ NotBefore: time.Now().Add(-24 * time.Hour), // Valid from 1 day ago
+ NotAfter: time.Now().Add(3 * 365 * 24 * time.Hour), // 3 years
+ KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageDigitalSignature,
+ ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
+ BasicConstraintsValid: true,
+ IsCA: true,
+ MaxPathLen: 0,
+ }
+
+ // Self-sign the certificate
+ certDER, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key)
+ if err != nil {
+ return fmt.Errorf("failed to create CA certificate: %w", err)
+ }
+
+ // Parse the certificate
+ cert, err := x509.ParseCertificate(certDER)
+ if err != nil {
+ return fmt.Errorf("failed to parse created certificate: %w", err)
+ }
+ ca.caCert = cert
+
+ // Save to file
+ f, err := os.Create(ca.caFile)
+ if err != nil {
+ return fmt.Errorf("failed to create CA file: %w", err)
+ }
+ defer f.Close()
+
+ // Write certificate
+ if err := pem.Encode(f, &pem.Block{
+ Type: "CERTIFICATE",
+ Bytes: certDER,
+ }); err != nil {
+ return fmt.Errorf("failed to write CA certificate: %w", err)
+ }
+
+ // Write private key
+ if err := pem.Encode(f, &pem.Block{
+ Type: "RSA PRIVATE KEY",
+ Bytes: x509.MarshalPKCS1PrivateKey(key),
+ }); err != nil {
+ return fmt.Errorf("failed to write CA private key: %w", err)
+ }
+
+ ca.logger.Info("created new CA certificate",
+ "subject", cert.Subject.CommonName,
+ "valid_from", cert.NotBefore,
+ "valid_to", cert.NotAfter)
+
+ return nil
+}
+
+// GetCertificate returns a certificate for the given hostname
+// Generates a new one if not in cache
+func (ca *CertificateAuthority) GetCertificate(hostname string) ([]byte, []byte, error) {
+ // Check cache first
+ if cached, ok := ca.certCache.Load(hostname); ok {
+ certData := cached.(certKeyPair)
+ return certData.cert, certData.key, nil
+ }
+
+ // Generate new certificate
+ return ca.generateHostCert(hostname)
+}
+
// certKeyPair is the value stored in the certificate cache: a host
// certificate together with its private key, both PEM-encoded.
type certKeyPair struct {
	cert, key []byte
}
+
+// generateHostCert generates a new certificate for the hostname
+func (ca *CertificateAuthority) generateHostCert(hostname string) ([]byte, []byte, error) {
+ ca.mu.Lock()
+ defer ca.mu.Unlock()
+
+ // Double-check cache after acquiring lock
+ if cached, ok := ca.certCache.Load(hostname); ok {
+ certData := cached.(certKeyPair)
+ return certData.cert, certData.key, nil
+ }
+
+ ca.logger.Debug("generating certificate for hostname", "hostname", hostname)
+
+ // Generate private key for host
+ key, err := rsa.GenerateKey(rand.Reader, 2048)
+ if err != nil {
+ return nil, nil, fmt.Errorf("failed to generate host private key: %w", err)
+ }
+
+ // Generate serial number
+ serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
+ if err != nil {
+ return nil, nil, fmt.Errorf("failed to generate serial number: %w", err)
+ }
+
+ // Create certificate template
+ template := &x509.Certificate{
+ SerialNumber: serialNumber,
+ Subject: pkix.Name{
+ CommonName: hostname,
+ Organization: []string{"gowarcprox"},
+ },
+ NotBefore: time.Now().Add(-24 * time.Hour),
+ NotAfter: time.Now().Add(3 * 365 * 24 * time.Hour), // 3 years
+ KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
+ ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
+ BasicConstraintsValid: true,
+ DNSNames: []string{hostname},
+ }
+
+ // Support wildcard certificates
+ if len(hostname) > 0 && hostname[0] != '*' {
+ template.DNSNames = append(template.DNSNames, "*."+hostname)
+ }
+
+ // Sign with CA
+ certDER, err := x509.CreateCertificate(rand.Reader, template, ca.caCert, &key.PublicKey, ca.caKey)
+ if err != nil {
+ return nil, nil, fmt.Errorf("failed to create host certificate: %w", err)
+ }
+
+ // Encode certificate to PEM
+ certPEM := pem.EncodeToMemory(&pem.Block{
+ Type: "CERTIFICATE",
+ Bytes: certDER,
+ })
+
+ // Encode private key to PEM
+ keyPEM := pem.EncodeToMemory(&pem.Block{
+ Type: "RSA PRIVATE KEY",
+ Bytes: x509.MarshalPKCS1PrivateKey(key),
+ })
+
+ // Save to cache
+ ca.certCache.Store(hostname, certKeyPair{cert: certPEM, key: keyPEM})
+
+ // Optionally save to disk
+ if err := ca.saveCertToDisk(hostname, certPEM, keyPEM); err != nil {
+ ca.logger.Warn("failed to save certificate to disk",
+ "hostname", hostname,
+ "error", err)
+ // Don't return error - caching in memory is sufficient
+ }
+
+ ca.logger.Info("generated certificate for hostname", "hostname", hostname)
+ return certPEM, keyPEM, nil
+}
+
+// saveCertToDisk saves a certificate and key to disk
+func (ca *CertificateAuthority) saveCertToDisk(hostname string, certPEM, keyPEM []byte) error {
+ // Sanitize hostname for filename
+ filename := filepath.Join(ca.certsDir, hostname+".pem")
+
+ f, err := os.Create(filename)
+ if err != nil {
+ return fmt.Errorf("failed to create cert file: %w", err)
+ }
+ defer f.Close()
+
+ if _, err := f.Write(certPEM); err != nil {
+ return fmt.Errorf("failed to write certificate: %w", err)
+ }
+
+ if _, err := f.Write(keyPEM); err != nil {
+ return fmt.Errorf("failed to write private key: %w", err)
+ }
+
+ return nil
+}
+
+// GetCACert returns the CA certificate in PEM format
+func (ca *CertificateAuthority) GetCACert() []byte {
+ return pem.EncodeToMemory(&pem.Block{
+ Type: "CERTIFICATE",
+ Bytes: ca.caCert.Raw,
+ })
+}
internal/proxy/handler.go
@@ -12,14 +12,16 @@ import (
// Handler handles HTTP proxy requests
type Handler struct {
+ server *Server
config *config.Config
client *http.Client
logger *slog.Logger
}
// NewHandler creates a new HTTP handler
-func NewHandler(cfg *config.Config, logger *slog.Logger) *Handler {
+func NewHandler(server *Server, cfg *config.Config, logger *slog.Logger) *Handler {
return &Handler{
+ server: server,
config: cfg,
client: &http.Client{
// Don't follow redirects - let the client handle them
@@ -50,11 +52,12 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
// handleConnect handles CONNECT requests for HTTPS tunneling
-// This is a stub for Phase 1 - will be fully implemented in Phase 2
func (h *Handler) handleConnect(w http.ResponseWriter, r *http.Request) {
- h.logger.Warn("CONNECT not yet supported (HTTPS tunneling coming in Phase 2)",
- "host", r.Host)
- http.Error(w, "CONNECT method not yet implemented", http.StatusNotImplemented)
+ if err := h.handleConnectMITM(w, r); err != nil {
+ h.logger.Error("CONNECT handler error",
+ "host", r.Host,
+ "error", err)
+ }
}
// handleHTTP handles regular HTTP proxy requests
internal/proxy/mitm.go
@@ -0,0 +1,184 @@
+package proxy
+
+import (
+ "bufio"
+ "crypto/tls"
+ "fmt"
+ "io"
+ "net"
+ "net/http"
+ "strings"
+)
+
+// handleConnectMITM handles the HTTPS MITM proxy logic
+func (h *Handler) handleConnectMITM(w http.ResponseWriter, r *http.Request) error {
+ // Extract hostname and port from the request
+ host := r.Host
+ if !strings.Contains(host, ":") {
+ host = host + ":443"
+ }
+
+ hostname, _, err := net.SplitHostPort(host)
+ if err != nil {
+ h.logger.Error("failed to parse host", "host", host, "error", err)
+ return fmt.Errorf("invalid host: %w", err)
+ }
+
+ h.logger.Debug("CONNECT request",
+ "host", host,
+ "hostname", hostname)
+
+ // Hijack the connection
+ hijacker, ok := w.(http.Hijacker)
+ if !ok {
+ h.logger.Error("ResponseWriter doesn't support hijacking")
+ http.Error(w, "Internal Server Error", http.StatusInternalServerError)
+ return fmt.Errorf("hijacking not supported")
+ }
+
+ clientConn, clientBuf, err := hijacker.Hijack()
+ if err != nil {
+ h.logger.Error("failed to hijack connection", "error", err)
+ http.Error(w, "Internal Server Error", http.StatusInternalServerError)
+ return fmt.Errorf("hijack failed: %w", err)
+ }
+ defer clientConn.Close()
+
+ // Send "200 Connection Established" to client
+ response := "HTTP/1.1 200 Connection Established\r\n\r\n"
+ if _, err := clientConn.Write([]byte(response)); err != nil {
+ h.logger.Error("failed to send connection established", "error", err)
+ return fmt.Errorf("failed to send 200: %w", err)
+ }
+
+ // Get or generate certificate for this hostname
+ certPEM, keyPEM, err := h.server.certAuth.GetCertificate(hostname)
+ if err != nil {
+ h.logger.Error("failed to get certificate",
+ "hostname", hostname,
+ "error", err)
+ return fmt.Errorf("certificate generation failed: %w", err)
+ }
+
+ // Load the certificate
+ cert, err := tls.X509KeyPair(certPEM, keyPEM)
+ if err != nil {
+ h.logger.Error("failed to load certificate",
+ "hostname", hostname,
+ "error", err)
+ return fmt.Errorf("failed to load certificate: %w", err)
+ }
+
+ // Wrap client connection with TLS (server-side)
+ tlsConfig := &tls.Config{
+ Certificates: []tls.Certificate{cert},
+ }
+
+ tlsClientConn := tls.Server(clientConn, tlsConfig)
+ defer tlsClientConn.Close()
+
+ // Perform TLS handshake
+ if err := tlsClientConn.Handshake(); err != nil {
+ h.logger.Error("TLS handshake failed",
+ "hostname", hostname,
+ "error", err)
+ return fmt.Errorf("TLS handshake failed: %w", err)
+ }
+
+ h.logger.Debug("TLS handshake successful", "hostname", hostname)
+
+ // Now read the actual HTTP request from the encrypted connection
+ reader := bufio.NewReader(tlsClientConn)
+ req, err := http.ReadRequest(reader)
+ if err != nil {
+ // EOF is normal when client closes connection
+ if err != io.EOF {
+ h.logger.Error("failed to read request from TLS connection",
+ "hostname", hostname,
+ "error", err)
+ }
+ return nil // Don't return error for EOF
+ }
+
+ // Fix up the request URL
+ // In HTTPS proxy mode, the request URL doesn't have scheme/host
+ if req.URL.Scheme == "" {
+ req.URL.Scheme = "https"
+ }
+ if req.URL.Host == "" {
+ req.URL.Host = hostname
+ }
+ req.Host = hostname
+
+ h.logger.Debug("decrypted request",
+ "method", req.Method,
+ "url", req.URL.String(),
+ "host", req.Host)
+
+ // Connect to the remote server
+ remoteConn, err := tls.Dial("tcp", host, &tls.Config{
+ ServerName: hostname,
+ })
+ if err != nil {
+ h.logger.Error("failed to connect to remote server",
+ "host", host,
+ "error", err)
+ // Send error response to client
+ resp := &http.Response{
+ StatusCode: http.StatusBadGateway,
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Body: io.NopCloser(strings.NewReader("Bad Gateway")),
+ }
+ resp.Write(tlsClientConn)
+ return fmt.Errorf("failed to connect to remote: %w", err)
+ }
+ defer remoteConn.Close()
+
+ // Send the request to the remote server
+ if err := req.Write(remoteConn); err != nil {
+ h.logger.Error("failed to write request to remote server",
+ "url", req.URL.String(),
+ "error", err)
+ resp := &http.Response{
+ StatusCode: http.StatusBadGateway,
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Body: io.NopCloser(strings.NewReader("Bad Gateway")),
+ }
+ resp.Write(tlsClientConn)
+ return fmt.Errorf("failed to write request: %w", err)
+ }
+
+ // Read response from remote server
+ remoteReader := bufio.NewReader(remoteConn)
+ resp, err := http.ReadResponse(remoteReader, req)
+ if err != nil {
+ h.logger.Error("failed to read response from remote server",
+ "url", req.URL.String(),
+ "error", err)
+ return fmt.Errorf("failed to read response: %w", err)
+ }
+ defer resp.Body.Close()
+
+ // Send response back to client
+ if err := resp.Write(tlsClientConn); err != nil {
+ h.logger.Error("failed to write response to client",
+ "url", req.URL.String(),
+ "error", err)
+ return fmt.Errorf("failed to write response: %w", err)
+ }
+
+ h.logger.Info("proxied HTTPS request",
+ "method", req.Method,
+ "url", req.URL.String(),
+ "status", resp.StatusCode)
+
+ // If there's buffered data from the hijack, it should be handled
+ // but for CONNECT, there usually isn't any
+ if clientBuf.Reader.Buffered() > 0 {
+ h.logger.Warn("unexpected buffered data after CONNECT")
+ }
+
+ return nil
+}
internal/proxy/proxy.go
@@ -9,6 +9,7 @@ import (
"sync"
"time"
+ "github.com/internetarchive/gowarcprox/internal/certauth"
"github.com/internetarchive/gowarcprox/pkg/config"
)
@@ -18,6 +19,7 @@ type Server struct {
listener net.Listener
server *http.Server
handler *Handler
+ certAuth *certauth.CertificateAuthority
wg sync.WaitGroup
ctx context.Context
cancel context.CancelFunc
@@ -39,8 +41,21 @@ func NewServer(cfg *config.Config, logger *slog.Logger) (*Server, error) {
logger: logger,
}
+ // Initialize certificate authority
+ certAuth, err := certauth.NewCertificateAuthority(
+ cfg.CACertFile,
+ cfg.CertsDir,
+ "gowarcprox CA",
+ logger,
+ )
+ if err != nil {
+ cancel()
+ return nil, fmt.Errorf("failed to initialize certificate authority: %w", err)
+ }
+ s.certAuth = certAuth
+
// Create the HTTP handler
- s.handler = NewHandler(cfg, logger)
+ s.handler = NewHandler(s, cfg, logger)
// Create the HTTP server
s.server = &http.Server{
.gitignore
@@ -2,3 +2,6 @@
/venv/
/scratch/
/gowarcprox
+warcprox-ca/
+warcprox-ca.pem
+*.sqlite