diff --git a/proxy.go b/proxy.go index d1474189..d272785a 100644 --- a/proxy.go +++ b/proxy.go @@ -33,6 +33,7 @@ import ( "github.com/google/martian/v3/nosigpipe" "github.com/google/martian/v3/proxyutil" "github.com/google/martian/v3/trafficshape" + "golang.org/x/net/proxy" ) var errClose = errors.New("closing connection") @@ -606,30 +607,31 @@ func (p *Proxy) roundTrip(ctx *Context, req *http.Request) (*http.Response, erro } func (p *Proxy) connect(req *http.Request) (*http.Response, net.Conn, error) { + var ( + conn net.Conn + err error + ) + if p.proxyURL != nil { log.Debugf("martian: CONNECT with downstream proxy: %s", p.proxyURL.Host) - conn, err := p.dial("tcp", p.proxyURL.Host) + dialer, err := proxy.FromURL( + p.proxyURL, &net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + }, + ) if err != nil { return nil, nil, err } - pbw := bufio.NewWriter(conn) - pbr := bufio.NewReader(conn) - - req.Write(pbw) - pbw.Flush() - res, err := http.ReadResponse(pbr, req) - if err != nil { - return nil, nil, err - } + conn, err = dialer.Dial("tcp", req.URL.Host) + } else { + log.Debugf("martian: CONNECT to host directly: %s", req.URL.Host) - return res, conn, nil + conn, err = p.dial("tcp", req.URL.Host) } - log.Debugf("martian: CONNECT to host directly: %s", req.URL.Host) - - conn, err := p.dial("tcp", req.URL.Host) if err != nil { return nil, nil, err }