diff --git a/README.md b/README.md index 23178f4..585c821 100644 --- a/README.md +++ b/README.md @@ -80,6 +80,25 @@ if err != nil { defer client.Close() ``` ++ Dial by SOCKS5 proxy + +```go +client, err := DialWithPasswd(addr, user, passwd, WithDialFuncOption(func(network string, address string) (net.Conn, error) { + // get proxy address from env or config + proxyAddress := os.Getenv("socks5_proxy") + dial, err := proxy.SOCKS5(network, proxyAddress, nil, nil) + if err != nil { + t.Fatal(err) + } + c, err := dial.Dial(network, address) + return c, err +})) +if err != nil { + handleErr(err) +} +defer client.Close() +``` + ## execute commmand + Don't care about output, calling Run diff --git a/sshclient.go b/sshclient.go index cf604f2..e42c084 100644 --- a/sshclient.go +++ b/sshclient.go @@ -20,6 +20,21 @@ import ( type remoteScriptType byte type remoteShellType byte +// DialOption represents an option for Dial connection. +type DialOption func(opt *dialOptions) + +// dialOptions contains all the options set by WithXxxOpt func as flow. +type dialOptions struct { + dialFunc func(network string, address string) (net.Conn, error) +} + +// WithDialFuncOption append dialFunc field to dialOptions +func WithDialFuncOption(dialFunc func(network string, address string) (net.Conn, error)) DialOption { + return func(opt *dialOptions) { + opt.dialFunc = dialFunc + } +} + const ( cmdLine remoteScriptType = iota rawScript @@ -36,7 +51,11 @@ type Client struct { } // DialWithPasswd starts a client connection to the given SSH server with passwd authmethod. -func DialWithPasswd(addr, user, passwd string) (*Client, error) { +func DialWithPasswd(addr, user, passwd string, options ...DialOption) (*Client, error) { + opts := &dialOptions{} + for _, option := range options { + option(opts) + } config := &ssh.ClientConfig{ User: user, Auth: []ssh.AuthMethod{ @@ -45,11 +64,15 @@ func DialWithPasswd(addr, user, passwd string) (*Client, error) { HostKeyCallback: ssh.HostKeyCallback(func(hostname string, remote net.Addr, key ssh.PublicKey) error { return nil }), } - return Dial("tcp", addr, config) + return Dial("tcp", addr, config, opts) } // DialWithKey starts a client connection to the given SSH server with key authmethod. -func DialWithKey(addr, user, keyfile string) (*Client, error) { +func DialWithKey(addr, user, keyfile string, options ...DialOption) (*Client, error) { + opts := &dialOptions{} + for _, option := range options { + option(opts) + } key, err := ioutil.ReadFile(keyfile) if err != nil { return nil, err @@ -67,12 +90,15 @@ func DialWithKey(addr, user, keyfile string) (*Client, error) { }, HostKeyCallback: ssh.HostKeyCallback(func(hostname string, remote net.Addr, key ssh.PublicKey) error { return nil }), } - - return Dial("tcp", addr, config) + return Dial("tcp", addr, config, opts) } // DialWithKeyWithPassphrase same as DialWithKey but with a passphrase to decrypt the private key -func DialWithKeyWithPassphrase(addr, user, keyfile string, passphrase string) (*Client, error) { +func DialWithKeyWithPassphrase(addr, user, keyfile string, passphrase string, options ...DialOption) (*Client, error) { + opts := &dialOptions{} + for _, option := range options { + option(opts) + } key, err := ioutil.ReadFile(keyfile) if err != nil { return nil, err @@ -91,12 +117,19 @@ func DialWithKeyWithPassphrase(addr, user, keyfile string, passphrase string) (* HostKeyCallback: ssh.HostKeyCallback(func(hostname string, remote net.Addr, key ssh.PublicKey) error { return nil }), } - return Dial("tcp", addr, config) + return Dial("tcp", addr, config, opts) } // Dial starts a client connection to the given SSH server. // This wraps ssh.Dial. -func Dial(network, addr string, config *ssh.ClientConfig) (*Client, error) { +func Dial(network, addr string, config *ssh.ClientConfig, opts *dialOptions) (*Client, error) { + if opts != nil && opts.dialFunc != nil { + conn, err := opts.dialFunc(network, addr) + if err != nil { + return nil, err + } + return DialWithConnection(conn, network, config) + } sshClient, err := ssh.Dial(network, addr, config) if err != nil { return nil, err @@ -114,9 +147,7 @@ func DialWithConnection(conn net.Conn, addr string, config *ssh.ClientConfig) (* client := ssh.NewClient(ncc, chans, reqs) - return &Client{ - client: client, - }, nil + return &Client{sshClient: client}, nil } // Close closes the underlying client network connection.