Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
177 changes: 174 additions & 3 deletions ssh-agent-filter/ssh-agent-filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@ import (
"os"
"path/filepath"
"strings"
"sync"
"syscall"

"github.com/jnoxon/ssh-agent-utils/filter"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
)

Expand Down Expand Up @@ -38,6 +40,171 @@ TODO:

*/

// ReconnectableAgent wraps an agent.Agent with automatic reconnection capability
type ReconnectableAgent struct {
upstreamName string
mu sync.RWMutex
agent agent.Agent
conn net.Conn
}

// NewReconnectableAgent creates a new reconnectable agent
func NewReconnectableAgent(upstreamName string) (*ReconnectableAgent, error) {
ra := &ReconnectableAgent{upstreamName: upstreamName}
err := ra.connect()
return ra, err
}

// connect establishes a new connection to the upstream agent
func (ra *ReconnectableAgent) connect() error {
ra.mu.Lock()
defer ra.mu.Unlock()

// Close existing connection if any
if ra.conn != nil {
ra.conn.Close()
}

// Establish new connection
conn, err := net.Dial("unix", ra.upstreamName)
if err != nil {
return fmt.Errorf("failed to connect to upstream agent: %w", err)
}

ra.conn = conn
ra.agent = agent.NewClient(conn)
return nil
}

// reconnect attempts to reconnect and returns true if successful
func (ra *ReconnectableAgent) reconnect() bool {
err := ra.connect()
return err == nil
}

// withReconnect wraps an operation with automatic reconnection on failure
func (ra *ReconnectableAgent) withReconnect(operation func(agent.Agent) error) error {
// Try the operation first
ra.mu.RLock()
agent := ra.agent
ra.mu.RUnlock()

err := operation(agent)
if err == nil {
return nil
}

// Check if it's a connection-related error
if isConnectionError(err) {
fmt.Fprintf(os.Stderr, "Connection error detected, attempting reconnect: %v\n", err)

// Try to reconnect
if ra.reconnect() {
fmt.Fprintf(os.Stderr, "Reconnection successful, retrying operation\n")
// Retry the operation with the new connection
ra.mu.RLock()
agent = ra.agent
ra.mu.RUnlock()
return operation(agent)
} else {
fmt.Fprintf(os.Stderr, "Reconnection failed\n")
}
}

return err
}

// isConnectionError checks if an error is related to connection issues
func isConnectionError(err error) bool {
if err == nil {
return false
}

errStr := err.Error()
connectionErrors := []string{
"agent refused operation",
"connection refused",
"broken pipe",
"connection reset",
"use of closed network connection",
"EOF",
"no such file or directory", // Socket file doesn't exist
"permission denied", // Socket permission issues
"network is unreachable", // Network issues
"timeout", // Connection timeouts
"i/o timeout", // I/O timeouts
}

for _, connErr := range connectionErrors {
if strings.Contains(strings.ToLower(errStr), strings.ToLower(connErr)) {
return true
}
}

return false
}

// Agent interface methods with reconnection support
func (ra *ReconnectableAgent) List() ([]*agent.Key, error) {
var keys []*agent.Key
err := ra.withReconnect(func(a agent.Agent) error {
var listErr error
keys, listErr = a.List()
return listErr
})
return keys, err
}

func (ra *ReconnectableAgent) Sign(key ssh.PublicKey, data []byte) (*ssh.Signature, error) {
var signature *ssh.Signature
err := ra.withReconnect(func(a agent.Agent) error {
var signErr error
signature, signErr = a.Sign(key, data)
return signErr
})
return signature, err
}

func (ra *ReconnectableAgent) Add(key agent.AddedKey) error {
return ra.withReconnect(func(a agent.Agent) error {
return a.Add(key)
})
}

func (ra *ReconnectableAgent) Remove(key ssh.PublicKey) error {
return ra.withReconnect(func(a agent.Agent) error {
return a.Remove(key)
})
}

func (ra *ReconnectableAgent) RemoveAll() error {
return ra.withReconnect(func(a agent.Agent) error {
return a.RemoveAll()
})
}

func (ra *ReconnectableAgent) Lock(passphrase []byte) error {
return ra.withReconnect(func(a agent.Agent) error {
return a.Lock(passphrase)
})
}

func (ra *ReconnectableAgent) Unlock(passphrase []byte) error {
return ra.withReconnect(func(a agent.Agent) error {
return a.Unlock(passphrase)
})
}

func (ra *ReconnectableAgent) Signers() ([]ssh.Signer, error) {
var signers []ssh.Signer
err := ra.withReconnect(func(a agent.Agent) error {
var signersErr error
signers, signersErr = a.Signers()
return signersErr
})
return signers, err
}

var (
listenName = flag.String("listen-socket", "", "path to listening socket")
upstreamName = flag.String("upstream-agent", os.Getenv("SSH_AUTH_SOCK"), "path to the ssh agent socket")
Expand All @@ -48,15 +215,19 @@ func main() {
flag.Parse()

if *listenName == "" || *upstreamName == "" || *fingerprints == "" {
flag.Usage()
fmt.Fprintf(os.Stderr, "Usage: %s -listen-socket <socket> -upstream-agent <socket> -fingerprints <fingerprints>\n", os.Args[0])
fmt.Fprintf(os.Stderr, "\nParameters:\n")
flag.PrintDefaults()
os.Exit(1)
}

up, err := net.Dial("unix", *upstreamName)
// Create reconnectable upstream agent
upstream, err := NewReconnectableAgent(*upstreamName)
if err != nil {
panic(err)
}
upstream := agent.NewClient(up)

fmt.Fprintf(os.Stderr, "Starting ssh-agent-filter with automatic reconnection\n")

// fixme, this is common code from agent-mux, stick it somewhere else
st, err := os.Stat(*listenName)
Expand Down