diff --git a/docs/CONFIG.md b/docs/CONFIG.md index fce18f8..8d3160a 100644 --- a/docs/CONFIG.md +++ b/docs/CONFIG.md @@ -362,10 +362,27 @@ Log verbosity: `none`, `errors` (default), `info`, `debug`. CA certificate configuration for MITM TLS termination: - `path`: Directory for CA cert/key storage (default: `~/.claw-wrap/ca` on macOS, `/etc/openclaw/ca` on Linux) -- `validity_days`: Certificate validity period (default: 365) -- `organization`: CA organization name in certificate +- `cert_file`: Certificate filename (default: `ca.crt`). Use `tls.crt` for cert-manager compatibility. +- `key_file`: Key filename (default: `ca.key`). Use `tls.key` for cert-manager compatibility. +- `external`: Enable external CA mode (default: `false`). When `true`: + - Fails fast if CA files are missing (never auto-generates) + - Relaxes key permission check for k8s secret mounts (allows 0644) + - Watches files for changes and hot-reloads on rotation +- `validity_days`: Certificate validity period (default: 365, ignored in external mode) +- `organization`: CA organization name in certificate (ignored in external mode) -The CA cert is auto-generated on first start and auto-rotated 30 days before expiry. +**Self-managed mode** (default): The CA cert is auto-generated on first start and auto-rotated 30 days before expiry. + +**External mode** (`external: true`): Use with cert-manager or k8s secrets: + +```yaml +http_proxy: + ca: + path: /etc/claw/ca + cert_file: tls.crt + key_file: tls.key + external: true +``` ### `strip_response_headers` diff --git a/go.mod b/go.mod index 8dd27d8..722eb1f 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,7 @@ require ( require ( filippo.io/hpke v0.4.0 // indirect + github.com/fsnotify/fsnotify v1.9.0 // indirect github.com/itchyny/timefmt-go v0.1.7 // indirect golang.org/x/crypto v0.45.0 // indirect golang.org/x/net v0.47.0 // indirect diff --git a/go.sum b/go.sum index bdc097d..5045fb6 100644 --- a/go.sum +++ b/go.sum @@ -10,6 +10,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/elazarl/goproxy v1.8.1 h1:/qGpPJGgIPOTZ7IoIQvjavocp//qYSe9LQnIGCgRY5k= github.com/elazarl/goproxy v1.8.1/go.mod h1:b5xm6W48AUHNpRTCvlnd0YVh+JafCCtsLsJZvvNTz+E= +github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k= +github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= github.com/itchyny/gojq v0.12.18 h1:gFGHyt/MLbG9n6dqnvlliiya2TaMMh6FFaR2b1H6Drc= github.com/itchyny/gojq v0.12.18/go.mod h1:4hPoZ/3lN9fDL1D+aK7DY1f39XZpY9+1Xpjz8atrEkg= github.com/itchyny/timefmt-go v0.1.7 h1:xyftit9Tbw+Dc/huSSPJaEmX1TVL8lw5vxjJLK4GMMA= diff --git a/internal/config/config.go b/internal/config/config.go index cb8ebbc..55694be 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -90,6 +90,9 @@ func (c *HTTPProxyConfig) GetRequireAuth() bool { // CAConfig holds CA certificate configuration for MITM proxy. type CAConfig struct { Path string `yaml:"path"` + CertFile string `yaml:"cert_file"` // default "ca.crt" + KeyFile string `yaml:"key_file"` // default "ca.key" + External bool `yaml:"external"` // external management mode (cert-manager, etc.) ValidityDays int `yaml:"validity_days"` Organization string `yaml:"organization"` } @@ -171,9 +174,9 @@ type CredentialDef struct { // ToolDef defines a wrapped tool. type ToolDef struct { - Binary string `yaml:"binary"` - Timeout string `yaml:"timeout,omitempty"` - Env map[string]string `yaml:"env,omitempty"` // Unified env: credential refs, {{ interpolation }}, or literals + Binary string `yaml:"binary"` + Timeout string `yaml:"timeout,omitempty"` + Env map[string]string `yaml:"env,omitempty"` // Unified env: credential refs, {{ interpolation }}, or literals // Deprecated: Use Env instead. ForcedEnv values are always treated as literals. // Will be removed in a future version. ForcedEnv map[string]string `yaml:"forced_env,omitempty"` @@ -457,6 +460,14 @@ func (c *Config) validateHTTPProxy() error { return fmt.Errorf("ca.validity_days must be non-negative") } + // Validate CA filenames (prevent path traversal) + if err := validateCAFilename(cfg.CA.CertFile, "cert_file"); err != nil { + return err + } + if err := validateCAFilename(cfg.CA.KeyFile, "key_file"); err != nil { + return err + } + // Validate and compile routes for i := range cfg.Routes { route := &cfg.Routes[i] @@ -656,6 +667,31 @@ func validateSafeRelativePath(value string, allowNested bool) error { return nil } +// validateCAFilename ensures a CA filename is a simple basename without path traversal. +// Empty values are allowed (defaults apply). +func validateCAFilename(filename, field string) error { + if filename == "" { + return nil // Defaults apply + } + // Must be a simple basename (no directory components) + if strings.ContainsAny(filename, "/\\") { + return fmt.Errorf("ca.%s: must be a filename, not a path", field) + } + if filename == "." || filename == ".." { + return fmt.Errorf("ca.%s: invalid filename %q", field, filename) + } + if strings.ContainsRune(filename, '\x00') { + return fmt.Errorf("ca.%s: contains NUL byte", field) + } + // Reject control characters (log injection prevention) + for _, r := range filename { + if r < 32 || r == 127 { + return fmt.Errorf("ca.%s: contains control character", field) + } + } + return nil +} + // LoadDefault loads the configuration from the default path. func LoadDefault() (*Config, error) { return Load(DefaultConfigPath) @@ -969,6 +1005,30 @@ func (c *Config) GetHTTPProxyRequireAuth() bool { return c.HTTPProxy.GetRequireAuth() } +// GetHTTPProxyCACertFile returns the CA certificate filename (default "ca.crt"). +func (c *Config) GetHTTPProxyCACertFile() string { + if c.HTTPProxy == nil || c.HTTPProxy.CA.CertFile == "" { + return "ca.crt" + } + return c.HTTPProxy.CA.CertFile +} + +// GetHTTPProxyCAKeyFile returns the CA key filename (default "ca.key"). +func (c *Config) GetHTTPProxyCAKeyFile() string { + if c.HTTPProxy == nil || c.HTTPProxy.CA.KeyFile == "" { + return "ca.key" + } + return c.HTTPProxy.CA.KeyFile +} + +// GetHTTPProxyCAExternal returns whether the CA is externally managed. +func (c *Config) GetHTTPProxyCAExternal() bool { + if c.HTTPProxy == nil { + return false + } + return c.HTTPProxy.CA.External +} + // GetTimeout returns the tool-specific timeout or falls back to the global default. func (t *ToolDef) GetTimeout(globalDefault time.Duration) time.Duration { if t.Timeout == "" { diff --git a/internal/config/config_test.go b/internal/config/config_test.go index b4db55b..3276329 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -1706,3 +1706,136 @@ func TestAuditConfig_BoolDefaults(t *testing.T) { t.Error("GetIncludeDuration() should be false when explicitly set") } } + +// --- CAConfig tests --- + +func TestCAConfig_DefaultFilenames(t *testing.T) { + cfg := &Config{ + HTTPProxy: &HTTPProxyConfig{ + CA: CAConfig{ + Path: "/etc/claw/ca", + // No CertFile or KeyFile specified + }, + }, + } + + if got := cfg.GetHTTPProxyCACertFile(); got != "ca.crt" { + t.Errorf("GetHTTPProxyCACertFile() = %q, want ca.crt", got) + } + if got := cfg.GetHTTPProxyCAKeyFile(); got != "ca.key" { + t.Errorf("GetHTTPProxyCAKeyFile() = %q, want ca.key", got) + } +} + +func TestCAConfig_CustomFilenames(t *testing.T) { + cfg := &Config{ + HTTPProxy: &HTTPProxyConfig{ + CA: CAConfig{ + Path: "/etc/claw/ca", + CertFile: "tls.crt", + KeyFile: "tls.key", + }, + }, + } + + if got := cfg.GetHTTPProxyCACertFile(); got != "tls.crt" { + t.Errorf("GetHTTPProxyCACertFile() = %q, want tls.crt", got) + } + if got := cfg.GetHTTPProxyCAKeyFile(); got != "tls.key" { + t.Errorf("GetHTTPProxyCAKeyFile() = %q, want tls.key", got) + } +} + +func TestCAConfig_ExternalMode(t *testing.T) { + cfg := &Config{ + HTTPProxy: &HTTPProxyConfig{ + CA: CAConfig{ + Path: "/etc/claw/ca", + External: true, + }, + }, + } + + if !cfg.GetHTTPProxyCAExternal() { + t.Error("GetHTTPProxyCAExternal() = false, want true") + } + + // Default is false + cfg2 := &Config{ + HTTPProxy: &HTTPProxyConfig{ + CA: CAConfig{ + Path: "/etc/claw/ca", + }, + }, + } + if cfg2.GetHTTPProxyCAExternal() { + t.Error("GetHTTPProxyCAExternal() default = true, want false") + } +} + +func TestCAConfig_NilHTTPProxy(t *testing.T) { + cfg := &Config{} + + if got := cfg.GetHTTPProxyCACertFile(); got != "ca.crt" { + t.Errorf("GetHTTPProxyCACertFile() with nil HTTPProxy = %q, want ca.crt", got) + } + if got := cfg.GetHTTPProxyCAKeyFile(); got != "ca.key" { + t.Errorf("GetHTTPProxyCAKeyFile() with nil HTTPProxy = %q, want ca.key", got) + } + if cfg.GetHTTPProxyCAExternal() { + t.Error("GetHTTPProxyCAExternal() with nil HTTPProxy = true, want false") + } +} + +// --- CA filename validation (path traversal prevention) --- + +func TestCAConfig_FilenameValidation_PathTraversal(t *testing.T) { + tests := []struct { + name string + certFile string + keyFile string + wantErr bool + }{ + {"valid defaults", "", "", false}, + {"valid custom", "tls.crt", "tls.key", false}, + {"path traversal cert", "../../../etc/passwd", "key.pem", true}, + {"path traversal key", "cert.pem", "../secret.key", true}, + {"absolute path cert", "/etc/ssl/ca.crt", "ca.key", true}, + {"directory in cert", "foo/bar.crt", "ca.key", true}, + {"dot filename", ".", "ca.key", true}, + {"dotdot filename", "..", "ca.key", true}, + {"backslash path", "foo\\bar.crt", "ca.key", true}, + {"nul byte cert", "ca\x00.crt", "ca.key", true}, + {"control char cert", "ca\n.crt", "ca.key", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := Config{ + Credentials: map[string]CredentialDef{ + "test": {Source: "env:TEST"}, + }, + HTTPProxy: &HTTPProxyConfig{ + Enabled: true, + CA: CAConfig{ + Path: "/tmp/ca", + CertFile: tt.certFile, + KeyFile: tt.keyFile, + }, + Routes: []ProxyRoute{ + {Host: "api.example.com", Inject: InjectSpec{Header: "X-Api-Key", Value: "{{test}}"}}, + }, + }, + } + err := cfg.Validate() + if (err != nil) != tt.wantErr { + t.Errorf("Validate() error = %v, wantErr %v", err, tt.wantErr) + } + if tt.wantErr && err != nil { + if !strings.Contains(err.Error(), "ca.cert_file") && !strings.Contains(err.Error(), "ca.key_file") { + t.Errorf("error should mention ca.cert_file or ca.key_file, got: %v", err) + } + } + }) + } +} diff --git a/internal/config/interpolate.go b/internal/config/interpolate.go index 78946ce..477c259 100644 --- a/internal/config/interpolate.go +++ b/internal/config/interpolate.go @@ -135,4 +135,3 @@ func CredentialNamesSet(credentials map[string]CredentialDef) map[string]struct{ } return names } - diff --git a/internal/credentials/parser.go b/internal/credentials/parser.go index 7294075..4d81c11 100644 --- a/internal/credentials/parser.go +++ b/internal/credentials/parser.go @@ -10,11 +10,11 @@ import ( type Backend string const ( - BackendPass Backend = "pass" - BackendEnv Backend = "env" + BackendPass Backend = "pass" + BackendEnv Backend = "env" Backend1Password Backend = "op" - BackendAge Backend = "age" - BackendKeychain Backend = "keychain" + BackendAge Backend = "age" + BackendKeychain Backend = "keychain" BackendBitwarden Backend = "bw" ) diff --git a/internal/httpproxy/ca.go b/internal/httpproxy/ca.go index c1a3e84..1f779fd 100644 --- a/internal/httpproxy/ca.go +++ b/internal/httpproxy/ca.go @@ -13,8 +13,11 @@ import ( "math/big" "os" "path/filepath" + "sync" "time" + "github.com/fsnotify/fsnotify" + "claw-wrap/internal/config" "claw-wrap/internal/paths" ) @@ -35,8 +38,16 @@ type CAManager struct { certPath string keyPath string config config.CAConfig + external bool // external management mode (cert-manager, k8s secrets, etc.) - cert *tls.Certificate + cert *tls.Certificate + certMu sync.RWMutex // protects cert during hot-reload + watcher *fsnotify.Watcher // file watcher for external mode + stopCh chan struct{} // stop signal for watcher goroutine + watcherWg sync.WaitGroup // wait for watcher goroutine to exit + watcherMu sync.Mutex // protects watcher lifecycle (start/stop) + watcherInit bool // true if watcher was started + permLogOnce sync.Once // ensures permission relaxation logged only once } // NewCAManager creates a new CA manager with the given configuration. @@ -46,10 +57,21 @@ func NewCAManager(cfg config.CAConfig) *CAManager { path = DefaultCAPath() } + // Use custom filenames or defaults + certFile := cfg.CertFile + if certFile == "" { + certFile = "ca.crt" + } + keyFile := cfg.KeyFile + if keyFile == "" { + keyFile = "ca.key" + } + return &CAManager{ - certPath: filepath.Join(path, "ca.crt"), - keyPath: filepath.Join(path, "ca.key"), + certPath: filepath.Join(path, certFile), + keyPath: filepath.Join(path, keyFile), config: cfg, + external: cfg.External, } } @@ -64,22 +86,72 @@ func (m *CAManager) EnsureCA() (*tls.Certificate, error) { // Try to load existing CA cert, err := m.loadCA() if err == nil { - // Check if rotation is needed - if m.needsRotation(cert) { + // Validate it's actually a CA certificate + if err := m.validateCA(cert); err != nil { + return nil, fmt.Errorf("CA validation failed: %w", err) + } + + // Check if rotation is needed (only for self-managed CAs) + if !m.external && m.needsRotation(cert) { log.Printf("[INFO] CA certificate expires soon, regenerating") return m.generateAndSaveCA() } + + m.certMu.Lock() m.cert = cert + m.certMu.Unlock() + + if m.external { + log.Printf("[INFO] Using external CA from %s", sanitizePath(m.certPath)) + } return cert, nil } + // External mode: fail if CA files don't exist + if m.external { + return nil, fmt.Errorf("external CA not found at %s: %w (hint: ensure cert-manager secret is mounted)", sanitizePath(m.certPath), err) + } + // Generate new CA log.Printf("[INFO] Generating new CA certificate") return m.generateAndSaveCA() } -// Certificate returns the loaded CA certificate. +// validateCA checks that the loaded certificate is actually a CA. +func (m *CAManager) validateCA(cert *tls.Certificate) error { + if len(cert.Certificate) == 0 { + return fmt.Errorf("certificate chain is empty") + } + + x509Cert, err := x509.ParseCertificate(cert.Certificate[0]) + if err != nil { + return fmt.Errorf("parse certificate: %w", err) + } + + if !x509Cert.IsCA { + return fmt.Errorf("certificate is not a CA (IsCA=false)") + } + + if x509Cert.KeyUsage&x509.KeyUsageCertSign == 0 { + return fmt.Errorf("certificate lacks KeyUsageCertSign") + } + + // Check certificate expiry + now := time.Now() + if now.Before(x509Cert.NotBefore) { + return fmt.Errorf("certificate not yet valid (NotBefore: %s)", x509Cert.NotBefore.Format(time.RFC3339)) + } + if now.After(x509Cert.NotAfter) { + return fmt.Errorf("certificate expired on %s", x509Cert.NotAfter.Format(time.RFC3339)) + } + + return nil +} + +// Certificate returns the loaded CA certificate (thread-safe). func (m *CAManager) Certificate() *tls.Certificate { + m.certMu.RLock() + defer m.certMu.RUnlock() return m.cert } @@ -90,10 +162,14 @@ func (m *CAManager) CertPath() string { // NeedsRotation checks if the CA certificate needs rotation. func (m *CAManager) NeedsRotation() bool { - if m.cert == nil { + m.certMu.RLock() + cert := m.cert + m.certMu.RUnlock() + + if cert == nil { return true } - return m.needsRotation(m.cert) + return m.needsRotation(cert) } func (m *CAManager) needsRotation(cert *tls.Certificate) bool { @@ -113,13 +189,20 @@ func (m *CAManager) needsRotation(cert *tls.Certificate) bool { func (m *CAManager) loadCA() (*tls.Certificate, error) { // Check key file permissions before loading (security: detect compromised keys) - info, err := os.Stat(m.keyPath) - if err != nil { - return nil, err - } - perm := info.Mode().Perm() - if perm > 0o600 { - return nil, fmt.Errorf("CA key %s has insecure permissions %04o (want 0600 or stricter)", m.keyPath, perm) + // Skip permission check for external mode (k8s secrets mount with 0644) + if !m.external { + info, err := os.Stat(m.keyPath) + if err != nil { + return nil, err + } + perm := info.Mode().Perm() + if perm > 0o600 { + return nil, fmt.Errorf("CA key has insecure permissions %04o (want 0600 or stricter)", perm) + } + } else { + m.permLogOnce.Do(func() { + log.Printf("[INFO] External CA mode: key permission check relaxed (k8s compat)") + }) } cert, err := tls.LoadX509KeyPair(m.certPath, m.keyPath) @@ -196,8 +279,14 @@ func (m *CAManager) generateAndSaveCA() (*tls.Certificate, error) { os.Remove(m.certPath) return nil, fmt.Errorf("write key: %w", err) } + // Ensure permissions are correct even if file existed with different perms + if err := os.Chmod(m.keyPath, 0600); err != nil { + os.Remove(m.certPath) + os.Remove(m.keyPath) + return nil, fmt.Errorf("chmod key: %w", err) + } - log.Printf("[INFO] CA certificate saved to %s (valid for %d days)", m.certPath, validityDays) + log.Printf("[INFO] CA certificate saved to %s (valid for %d days)", sanitizePath(m.certPath), validityDays) // Load the saved certificate cert, err := tls.LoadX509KeyPair(m.certPath, m.keyPath) @@ -205,6 +294,158 @@ func (m *CAManager) generateAndSaveCA() (*tls.Certificate, error) { return nil, fmt.Errorf("load generated CA: %w", err) } + m.certMu.Lock() m.cert = &cert + m.certMu.Unlock() return &cert, nil } + +// StartWatcher starts watching CA files for changes (external mode only). +// Returns nil if not in external mode. Safe to call multiple times. +func (m *CAManager) StartWatcher() error { + if !m.external { + return nil + } + + m.watcherMu.Lock() + defer m.watcherMu.Unlock() + + if m.watcherInit { + return nil // Already running + } + + watcher, err := fsnotify.NewWatcher() + if err != nil { + return fmt.Errorf("create watcher: %w", err) + } + + // Watch the directory containing the cert (handles k8s secret updates) + dir := filepath.Dir(m.certPath) + if err := watcher.Add(dir); err != nil { + watcher.Close() + return fmt.Errorf("watch directory %s: %w", sanitizePath(dir), err) + } + + m.watcher = watcher + m.stopCh = make(chan struct{}) + m.watcherInit = true + + // Capture references for goroutine to avoid races with StopWatcher + stopCh := m.stopCh + events := watcher.Events + errors := watcher.Errors + + m.watcherWg.Add(1) + go m.watchLoop(stopCh, events, errors) + + log.Printf("[INFO] CA file watcher started for %s", sanitizePath(dir)) + return nil +} + +// StopWatcher stops the file watcher. Safe to call multiple times. +func (m *CAManager) StopWatcher() { + m.watcherMu.Lock() + if !m.watcherInit { + m.watcherMu.Unlock() + return + } + + // Signal stop and mark as not initialized + if m.stopCh != nil { + close(m.stopCh) + m.stopCh = nil + } + m.watcherInit = false + watcher := m.watcher + m.watcher = nil + m.watcherMu.Unlock() + + // Wait outside lock to avoid deadlock with watchLoop + m.watcherWg.Wait() + + if watcher != nil { + watcher.Close() + } +} + +func (m *CAManager) watchLoop(stopCh <-chan struct{}, events <-chan fsnotify.Event, errors <-chan error) { + defer m.watcherWg.Done() + + certFile := filepath.Base(m.certPath) + keyFile := filepath.Base(m.keyPath) + + for { + select { + case <-stopCh: + return + case event, ok := <-events: + if !ok { + return + } + + // Check if this event is relevant: + // - Direct file changes (tls.crt, tls.key, ca.crt, ca.key) + // - k8s secret symlink updates (..data symlink change) + eventFile := filepath.Base(event.Name) + isRelevantFile := eventFile == certFile || eventFile == keyFile + isK8sSymlink := eventFile == "..data" // k8s atomic secret update + + if !isRelevantFile && !isK8sSymlink { + continue + } + + // React to write, create, or chmod events + if event.Op&(fsnotify.Write|fsnotify.Create|fsnotify.Chmod) == 0 { + continue + } + + // Reload the certificate + if err := m.reloadCA(); err != nil { + log.Printf("[WARN] CA reload failed: %v", err) + } else { + log.Printf("[INFO] CA certificate reloaded") + } + + case err, ok := <-errors: + if !ok { + return + } + log.Printf("[WARN] CA watcher error: %v", err) + } + } +} + +func (m *CAManager) reloadCA() error { + cert, err := m.loadCA() + if err != nil { + return err + } + + if err := m.validateCA(cert); err != nil { + return fmt.Errorf("validation failed: %w", err) + } + + m.certMu.Lock() + m.cert = cert + m.certMu.Unlock() + + return nil +} + +// External returns whether this CA is externally managed. +func (m *CAManager) External() bool { + return m.external +} + +// sanitizePath removes control characters from a path for safe logging. +func sanitizePath(path string) string { + var result []rune + for _, r := range path { + if r < 32 || r == 127 { + result = append(result, '?') + } else { + result = append(result, r) + } + } + return string(result) +} diff --git a/internal/httpproxy/ca_test.go b/internal/httpproxy/ca_test.go index f0b43ac..caeaffe 100644 --- a/internal/httpproxy/ca_test.go +++ b/internal/httpproxy/ca_test.go @@ -1,10 +1,16 @@ package httpproxy import ( + "crypto/rand" + "crypto/rsa" "crypto/tls" "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "math/big" "os" "path/filepath" + "strings" "testing" "time" @@ -255,3 +261,470 @@ func TestCAManager_ValidCertificate(t *testing.T) { t.Error("empty certificate chain") } } + +// --- External CA mode tests --- + +func TestCAManager_External_FailsOnMissingFiles(t *testing.T) { + tmpDir := t.TempDir() + cfg := config.CAConfig{ + Path: tmpDir, + External: true, + } + + mgr := NewCAManager(cfg) + _, err := mgr.EnsureCA() + if err == nil { + t.Error("expected error when external CA files missing") + } + // Should mention "external CA not found" + if !strings.Contains(err.Error(), "external CA not found") { + t.Errorf("error message should mention 'external CA not found': %v", err) + } +} + +func TestCAManager_External_LoadsExistingCA(t *testing.T) { + tmpDir := t.TempDir() + + // Create a valid CA cert/key + createTestCA(t, tmpDir, "ca.crt", "ca.key") + + cfg := config.CAConfig{ + Path: tmpDir, + External: true, + } + + mgr := NewCAManager(cfg) + cert, err := mgr.EnsureCA() + if err != nil { + t.Fatalf("EnsureCA error: %v", err) + } + if cert == nil { + t.Error("expected non-nil certificate") + } + if !mgr.External() { + t.Error("External() should return true") + } +} + +func TestCAManager_External_CustomFilenames(t *testing.T) { + tmpDir := t.TempDir() + + // Create CA with custom filenames (like cert-manager) + createTestCA(t, tmpDir, "tls.crt", "tls.key") + + cfg := config.CAConfig{ + Path: tmpDir, + CertFile: "tls.crt", + KeyFile: "tls.key", + External: true, + } + + mgr := NewCAManager(cfg) + if mgr.CertPath() != filepath.Join(tmpDir, "tls.crt") { + t.Errorf("CertPath = %q, want %q", mgr.CertPath(), filepath.Join(tmpDir, "tls.crt")) + } + + cert, err := mgr.EnsureCA() + if err != nil { + t.Fatalf("EnsureCA error: %v", err) + } + if cert == nil { + t.Error("expected non-nil certificate") + } +} + +func TestCAManager_External_SkipsPermissionCheck(t *testing.T) { + tmpDir := t.TempDir() + + // Create CA with 0644 permissions (like k8s secrets) + createTestCAWithPerms(t, tmpDir, "ca.crt", "ca.key", 0644) + + // External mode should accept 0644 + cfg := config.CAConfig{ + Path: tmpDir, + External: true, + } + mgr := NewCAManager(cfg) + cert, err := mgr.EnsureCA() + if err != nil { + t.Fatalf("external mode should accept 0644 perms: %v", err) + } + if cert == nil { + t.Error("expected non-nil certificate") + } +} + +func TestCAManager_SelfManaged_RegeneratesOnInsecurePerms(t *testing.T) { + tmpDir := t.TempDir() + + // Create CA with 0644 permissions + createTestCAWithPerms(t, tmpDir, "ca.crt", "ca.key", 0644) + + // Get original serial + origCert, _ := tls.LoadX509KeyPair( + filepath.Join(tmpDir, "ca.crt"), + filepath.Join(tmpDir, "ca.key"), + ) + origX509, _ := x509.ParseCertificate(origCert.Certificate[0]) + origSerial := origX509.SerialNumber + + // Self-managed mode should regenerate CA (not use the insecure one) + cfg := config.CAConfig{ + Path: tmpDir, + External: false, + } + mgr := NewCAManager(cfg) + cert, err := mgr.EnsureCA() + if err != nil { + t.Fatalf("EnsureCA error: %v", err) + } + + // Should have generated a new CA with different serial + newX509, _ := x509.ParseCertificate(cert.Certificate[0]) + if origSerial.Cmp(newX509.SerialNumber) == 0 { + t.Error("expected new CA to be generated due to insecure permissions") + } + + // New key should have 0600 permissions + keyInfo, err := os.Stat(filepath.Join(tmpDir, "ca.key")) + if err != nil { + t.Fatalf("stat key: %v", err) + } + if keyInfo.Mode().Perm() != 0600 { + t.Errorf("regenerated key perms = %o, want 0600", keyInfo.Mode().Perm()) + } +} + +func TestCAManager_External_ValidatesCAFlag(t *testing.T) { + tmpDir := t.TempDir() + + // Create a non-CA certificate (end-entity) + createTestNonCACert(t, tmpDir, "ca.crt", "ca.key") + + cfg := config.CAConfig{ + Path: tmpDir, + External: true, + } + mgr := NewCAManager(cfg) + _, err := mgr.EnsureCA() + if err == nil { + t.Error("expected error for non-CA certificate") + } + if !strings.Contains(err.Error(), "IsCA=false") { + t.Errorf("error should mention IsCA: %v", err) + } +} + +// --- Watcher tests --- + +func TestCAManager_Watcher_StartsOnlyInExternalMode(t *testing.T) { + tmpDir := t.TempDir() + createTestCA(t, tmpDir, "ca.crt", "ca.key") + + // Self-managed mode: watcher should not start + cfg := config.CAConfig{ + Path: tmpDir, + External: false, + } + mgr := NewCAManager(cfg) + if err := mgr.StartWatcher(); err != nil { + t.Fatalf("StartWatcher error: %v", err) + } + // Should be a no-op, no watcher created + mgr.StopWatcher() // Should not panic + + // External mode: watcher should start + cfg.External = true + mgr2 := NewCAManager(cfg) + _, err := mgr2.EnsureCA() + if err != nil { + t.Fatalf("EnsureCA error: %v", err) + } + if err := mgr2.StartWatcher(); err != nil { + t.Fatalf("StartWatcher error: %v", err) + } + mgr2.StopWatcher() +} + +func TestCAManager_Watcher_ReloadsOnChange(t *testing.T) { + tmpDir := t.TempDir() + createTestCA(t, tmpDir, "ca.crt", "ca.key") + + cfg := config.CAConfig{ + Path: tmpDir, + External: true, + } + mgr := NewCAManager(cfg) + cert1, err := mgr.EnsureCA() + if err != nil { + t.Fatalf("EnsureCA error: %v", err) + } + + if err := mgr.StartWatcher(); err != nil { + t.Fatalf("StartWatcher error: %v", err) + } + defer mgr.StopWatcher() + + // Get original serial + x509Cert1, _ := x509.ParseCertificate(cert1.Certificate[0]) + origSerial := x509Cert1.SerialNumber + + // Replace with new CA + createTestCA(t, tmpDir, "ca.crt", "ca.key") + + // Poll for certificate change with timeout (avoid flaky fixed sleep) + deadline := time.Now().Add(2 * time.Second) + var reloaded bool + for time.Now().Before(deadline) { + cert2 := mgr.Certificate() + if cert2 != nil { + x509Cert2, _ := x509.ParseCertificate(cert2.Certificate[0]) + if origSerial.Cmp(x509Cert2.SerialNumber) != 0 { + reloaded = true + break + } + } + time.Sleep(50 * time.Millisecond) + } + + if !reloaded { + t.Error("certificate serial unchanged after 2s, watcher may not have reloaded") + } +} + +func TestCAManager_Watcher_StopsCleanly(t *testing.T) { + tmpDir := t.TempDir() + createTestCA(t, tmpDir, "ca.crt", "ca.key") + + cfg := config.CAConfig{ + Path: tmpDir, + External: true, + } + mgr := NewCAManager(cfg) + _, err := mgr.EnsureCA() + if err != nil { + t.Fatalf("EnsureCA error: %v", err) + } + + if err := mgr.StartWatcher(); err != nil { + t.Fatalf("StartWatcher error: %v", err) + } + + // Stop should not panic or hang + mgr.StopWatcher() + + // Double stop should be safe + mgr.StopWatcher() +} + +func TestCAManager_ValidatesExpiry(t *testing.T) { + tmpDir := t.TempDir() + + // Create an expired certificate + createExpiredTestCA(t, tmpDir, "ca.crt", "ca.key") + + cfg := config.CAConfig{ + Path: tmpDir, + External: true, + } + mgr := NewCAManager(cfg) + _, err := mgr.EnsureCA() + if err == nil { + t.Fatal("EnsureCA should fail on expired certificate") + } + if !strings.Contains(err.Error(), "expired") { + t.Errorf("error should mention 'expired', got: %v", err) + } +} + +func TestCAManager_ValidatesNotYetValid(t *testing.T) { + tmpDir := t.TempDir() + + // Create a certificate that's not yet valid + createFutureTestCA(t, tmpDir, "ca.crt", "ca.key") + + cfg := config.CAConfig{ + Path: tmpDir, + External: true, + } + mgr := NewCAManager(cfg) + _, err := mgr.EnsureCA() + if err == nil { + t.Fatal("EnsureCA should fail on not-yet-valid certificate") + } + if !strings.Contains(err.Error(), "not yet valid") { + t.Errorf("error should mention 'not yet valid', got: %v", err) + } +} + +func TestSanitizePath(t *testing.T) { + tests := []struct { + input string + want string + }{ + {"/normal/path/ca.crt", "/normal/path/ca.crt"}, + {"/path/with\nnewline", "/path/with?newline"}, + {"/path/with\ttab", "/path/with?tab"}, + {"/path/with\x00nul", "/path/with?nul"}, + {"/path/with\x7fdel", "/path/with?del"}, + {"/path\r\n/crlf", "/path??/crlf"}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + got := sanitizePath(tt.input) + if got != tt.want { + t.Errorf("sanitizePath(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} + +// --- Helper functions --- + +func createTestCA(t *testing.T, dir, certFile, keyFile string) { + t.Helper() + createTestCAWithPerms(t, dir, certFile, keyFile, 0600) +} + +func createTestCAWithPerms(t *testing.T, dir, certFile, keyFile string, keyPerm os.FileMode) { + t.Helper() + + key, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("generate key: %v", err) + } + + serial, _ := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128)) + template := &x509.Certificate{ + SerialNumber: serial, + Subject: pkix.Name{Organization: []string{"Test CA"}, CommonName: "Test CA"}, + NotBefore: time.Now(), + NotAfter: time.Now().AddDate(1, 0, 0), + KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign, + BasicConstraintsValid: true, + IsCA: true, + } + + certDER, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key) + if err != nil { + t.Fatalf("create cert: %v", err) + } + + certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}) + if err := os.WriteFile(filepath.Join(dir, certFile), certPEM, 0644); err != nil { + t.Fatalf("write cert: %v", err) + } + + keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)}) + if err := os.WriteFile(filepath.Join(dir, keyFile), keyPEM, keyPerm); err != nil { + t.Fatalf("write key: %v", err) + } +} + +func createTestNonCACert(t *testing.T, dir, certFile, keyFile string) { + t.Helper() + + key, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("generate key: %v", err) + } + + serial, _ := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128)) + template := &x509.Certificate{ + SerialNumber: serial, + Subject: pkix.Name{Organization: []string{"Not A CA"}, CommonName: "End Entity"}, + NotBefore: time.Now(), + NotAfter: time.Now().AddDate(1, 0, 0), + KeyUsage: x509.KeyUsageDigitalSignature, + BasicConstraintsValid: true, + IsCA: false, // Not a CA! + } + + certDER, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key) + if err != nil { + t.Fatalf("create cert: %v", err) + } + + certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}) + if err := os.WriteFile(filepath.Join(dir, certFile), certPEM, 0644); err != nil { + t.Fatalf("write cert: %v", err) + } + + keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)}) + if err := os.WriteFile(filepath.Join(dir, keyFile), keyPEM, 0644); err != nil { + t.Fatalf("write key: %v", err) + } +} + +func createExpiredTestCA(t *testing.T, dir, certFile, keyFile string) { + t.Helper() + + key, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("generate key: %v", err) + } + + serial, _ := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128)) + // Certificate that expired yesterday + template := &x509.Certificate{ + SerialNumber: serial, + Subject: pkix.Name{Organization: []string{"Test CA"}, CommonName: "Expired CA"}, + NotBefore: time.Now().AddDate(0, 0, -30), + NotAfter: time.Now().AddDate(0, 0, -1), // Expired yesterday + KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign, + BasicConstraintsValid: true, + IsCA: true, + } + + certDER, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key) + if err != nil { + t.Fatalf("create cert: %v", err) + } + + certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}) + if err := os.WriteFile(filepath.Join(dir, certFile), certPEM, 0644); err != nil { + t.Fatalf("write cert: %v", err) + } + + keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)}) + if err := os.WriteFile(filepath.Join(dir, keyFile), keyPEM, 0644); err != nil { + t.Fatalf("write key: %v", err) + } +} + +func createFutureTestCA(t *testing.T, dir, certFile, keyFile string) { + t.Helper() + + key, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("generate key: %v", err) + } + + serial, _ := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128)) + // Certificate that starts tomorrow + template := &x509.Certificate{ + SerialNumber: serial, + Subject: pkix.Name{Organization: []string{"Test CA"}, CommonName: "Future CA"}, + NotBefore: time.Now().AddDate(0, 0, 1), // Valid from tomorrow + NotAfter: time.Now().AddDate(1, 0, 0), + KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign, + BasicConstraintsValid: true, + IsCA: true, + } + + certDER, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key) + if err != nil { + t.Fatalf("create cert: %v", err) + } + + certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}) + if err := os.WriteFile(filepath.Join(dir, certFile), certPEM, 0644); err != nil { + t.Fatalf("write cert: %v", err) + } + + keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)}) + if err := os.WriteFile(filepath.Join(dir, keyFile), keyPEM, 0644); err != nil { + t.Fatalf("write key: %v", err) + } +} diff --git a/internal/httpproxy/proxy.go b/internal/httpproxy/proxy.go index 47383af..0bc748c 100644 --- a/internal/httpproxy/proxy.go +++ b/internal/httpproxy/proxy.go @@ -4,6 +4,7 @@ package httpproxy import ( "context" "crypto/subtle" + "crypto/tls" "encoding/base64" "fmt" "log" @@ -150,8 +151,14 @@ func (p *Proxy) Start(addr string) error { return fmt.Errorf("setup MITM: %w", err) } + // Start file watcher for external CA hot-reload + if err := p.ca.StartWatcher(); err != nil { + return fmt.Errorf("start CA watcher: %w", err) + } + listener, err := net.Listen("tcp", addr) if err != nil { + p.ca.StopWatcher() // Clean up watcher if listen fails return fmt.Errorf("listen: %w", err) } p.listener = listener @@ -181,16 +188,36 @@ func (p *Proxy) setupMITM() error { return fmt.Errorf("ensure CA: %w", err) } - // Set the CA for goproxy + // Set the CA for goproxy (required for some internal checks) goproxy.GoproxyCa = *cert - goproxy.OkConnect = &goproxy.ConnectAction{Action: goproxy.ConnectMitm, TLSConfig: goproxy.TLSConfigFromCA(cert)} - goproxy.MitmConnect = &goproxy.ConnectAction{Action: goproxy.ConnectMitm, TLSConfig: goproxy.TLSConfigFromCA(cert)} - goproxy.RejectConnect = &goproxy.ConnectAction{Action: goproxy.ConnectReject, TLSConfig: goproxy.TLSConfigFromCA(cert)} + + // Use lazy TLSConfig that reads fresh cert on each CONNECT. + // This enables hot-reload: when CAManager reloads cert from disk, + // new HTTPS connections automatically use the updated certificate. + lazyTLSConfig := p.makeLazyTLSConfig() + goproxy.OkConnect = &goproxy.ConnectAction{Action: goproxy.ConnectMitm, TLSConfig: lazyTLSConfig} + goproxy.MitmConnect = &goproxy.ConnectAction{Action: goproxy.ConnectMitm, TLSConfig: lazyTLSConfig} + goproxy.RejectConnect = &goproxy.ConnectAction{Action: goproxy.ConnectReject, TLSConfig: lazyTLSConfig} log.Printf("[INFO] MITM enabled with CA from %s", p.ca.CertPath()) return nil } +// makeLazyTLSConfig returns a TLS config function that reads the latest +// certificate from CAManager on each call. This enables hot-reload of +// CA certificates without restarting the proxy. +func (p *Proxy) makeLazyTLSConfig() func(host string, ctx *goproxy.ProxyCtx) (*tls.Config, error) { + return func(host string, ctx *goproxy.ProxyCtx) (*tls.Config, error) { + cert := p.ca.Certificate() + if cert == nil { + return nil, fmt.Errorf("CA certificate not loaded") + } + // Delegate to goproxy's standard TLSConfigFromCA which generates + // per-domain certificates signed by the CA + return goproxy.TLSConfigFromCA(cert)(host, ctx) + } +} + // CAPath returns the path to the CA certificate for trust injection. func (p *Proxy) CAPath() string { return p.ca.CertPath() @@ -202,6 +229,9 @@ func (p *Proxy) Stop() error { p.stopOnce.Do(func() { close(p.shutdownCh) + // Stop CA file watcher + p.ca.StopWatcher() + if p.listener != nil { if closeErr := p.listener.Close(); closeErr != nil { err = fmt.Errorf("close listener: %w", closeErr) diff --git a/internal/httpproxy/security_test.go b/internal/httpproxy/security_test.go index 0ffba27..39d0d2e 100644 --- a/internal/httpproxy/security_test.go +++ b/internal/httpproxy/security_test.go @@ -137,7 +137,7 @@ func TestSanitizeForLog(t *testing.T) { {"normal log message", "normal log message"}, {"token=abc123", "token=[REDACTED]"}, {"Token: abc123", "Token=[REDACTED]"}, - {"Authorization: Bearer-abc", "Authorization=[REDACTED]"}, // single word with dash + {"Authorization: Bearer-abc", "Authorization=[REDACTED]"}, // single word with dash {"api_key=secret123", "api_key=[REDACTED]"}, {"API-KEY: xyz", "API-KEY=[REDACTED]"}, {"password: hunter2", "password=[REDACTED]"},