Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions services/httpoverrpc/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ type proxyCmd struct {
allowAnyHost bool
protocol string
hostname string
hostheader string
insecureSkipVerify bool
stream bool
}
Expand All @@ -91,6 +92,7 @@ func (p *proxyCmd) SetFlags(f *flag.FlagSet) {
f.BoolVar(&p.allowAnyHost, "allow-any-host", false, "Serve data regardless of the Host in HTTP requests instead of only allowing localhost and IPs. False by default to prevent DNS rebinding attacks.")
f.StringVar(&p.protocol, "protocol", "http", "protocol to communicate with specified hostname")
f.StringVar(&p.hostname, "hostname", "localhost", "ip address or domain name to specify host")
f.StringVar(&p.hostheader, "host-header", "", "if set, a host header to set on requests (overriding the value from the incoming request")
f.BoolVar(&p.insecureSkipVerify, "insecure-skip-tls-verify", false, "If true, skip TLS cert verification")
f.BoolVar(&p.stream, "stream", false, "If true, stream the response back to the client. Useful for large responses.")
}
Expand Down Expand Up @@ -204,6 +206,9 @@ func (p *proxyCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...interfa
return
}

if p.hostheader != "" {
httpReq.Header.Set("host", p.hostheader)
}
var reqHeaders []*pb.Header
for k, v := range httpReq.Header {
reqHeaders = append(reqHeaders, &pb.Header{Key: k, Values: v})
Expand Down
236 changes: 236 additions & 0 deletions services/httpoverrpc/client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -501,3 +501,239 @@ func TestGetPortDefaultHTTPS(t *testing.T) {
t.Fatalf("got wrong port: %d. Expected: %d", result, defaultHTTPSPort)
}
}

func TestProxyHostHeader(t *testing.T) {
ctx := context.Background()
receivedHeaders := make(map[string]string)

// Set up web server that captures headers
m := http.NewServeMux()
m.HandleFunc("/", func(httpResp http.ResponseWriter, httpReq *http.Request) {
receivedHeaders["request-host"] = httpReq.Host
_, _ = httpResp.Write([]byte("hello world"))
})
l, err := net.Listen("tcp4", "localhost:0")
if err != nil {
t.Fatal(err)
}
go func() { _ = http.Serve(l, m) }()

// Dial out to sansshell server set up in TestMain
conn, err := proxy.DialContext(ctx, "", []string{"bufnet"}, grpc.WithContextDialer(bufDialer), grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() { conn.Close() })

// Start proxying command with host-header flag
f := flag.NewFlagSet("proxy", flag.PanicOnError)
p := &proxyCmd{}
p.SetFlags(f)
_, port, err := net.SplitHostPort(l.Addr().String())
if err != nil {
t.Fatal(err)
}
customHostHeader := "custom.example.com"
if err := f.Parse([]string{"-host-header", customHostHeader, "-allow-any-host", port}); err != nil {
t.Fatal(err)
}
reader, writer := io.Pipe()
go p.Execute(ctx, f, &util.ExecuteState{
Conn: conn,
Out: []io.Writer{writer},
Err: []io.Writer{os.Stderr},
})

// Find the port to use
buf := make([]byte, 1024)
if _, err := reader.Read(buf); err != nil {
t.Fatal(err)
}
msg := strings.Fields(string(buf))
// Parse out "Listening on http://%v, "
addr := msg[2][:len(msg[2])-1]

// Make a call with original host header
req, err := http.NewRequest("GET", addr, nil)
if err != nil {
t.Fatal(err)
}
req.Host = "original.example.com:8080"

client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
t.Fatal(err)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}
want := "hello world"
if string(body) != want {
t.Errorf("got %q, want %q", body, want)
}

// Verify that the custom host header was used instead of the original
if receivedHeaders["request-host"] != customHostHeader {
t.Errorf("got host header %q, want %q", receivedHeaders["request-host"], customHostHeader)
}
}

func TestProxyHostHeaderStream(t *testing.T) {
ctx := context.Background()
receivedHeaders := make(map[string]string)

// Set up web server that captures headers
m := http.NewServeMux()
m.HandleFunc("/", func(httpResp http.ResponseWriter, httpReq *http.Request) {
receivedHeaders["request-host"] = httpReq.Host
_, _ = httpResp.Write([]byte("hello world"))
})
l, err := net.Listen("tcp4", "localhost:0")
if err != nil {
t.Fatal(err)
}
go func() { _ = http.Serve(l, m) }()

// Dial out to sansshell server set up in TestMain
conn, err := proxy.DialContext(ctx, "", []string{"bufnet"}, grpc.WithContextDialer(bufDialer), grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() { conn.Close() })

// Start proxying command with host-header flag and stream mode
f := flag.NewFlagSet("proxy", flag.PanicOnError)
p := &proxyCmd{}
p.SetFlags(f)
p.stream = true
_, port, err := net.SplitHostPort(l.Addr().String())
if err != nil {
t.Fatal(err)
}
customHostHeader := "custom-stream.example.com"
if err := f.Parse([]string{"-host-header", customHostHeader, "-allow-any-host", port}); err != nil {
t.Fatal(err)
}
reader, writer := io.Pipe()
go p.Execute(ctx, f, &util.ExecuteState{
Conn: conn,
Out: []io.Writer{writer},
Err: []io.Writer{os.Stderr},
})

// Find the port to use
buf := make([]byte, 1024)
if _, err := reader.Read(buf); err != nil {
t.Fatal(err)
}
msg := strings.Fields(string(buf))
// Parse out "Listening on http://%v, "
addr := msg[2][:len(msg[2])-1]

// Make a call with original host header
req, err := http.NewRequest("GET", addr, nil)
if err != nil {
t.Fatal(err)
}
req.Host = "original-stream.example.com:8080"

client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
t.Fatal(err)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}
want := "hello world"
if string(body) != want {
t.Errorf("got %q, want %q", body, want)
}

// Verify that the custom host header was used instead of the original
if receivedHeaders["request-host"] != customHostHeader {
t.Errorf("got host header %q, want %q", receivedHeaders["request-host"], customHostHeader)
}
}

func TestProxyNoHostHeader(t *testing.T) {
ctx := context.Background()
receivedHeaders := make(map[string]string)

// Set up web server that captures headers
m := http.NewServeMux()
m.HandleFunc("/", func(httpResp http.ResponseWriter, httpReq *http.Request) {
receivedHeaders["request-host"] = httpReq.Host
_, _ = httpResp.Write([]byte("hello world"))
})
l, err := net.Listen("tcp4", "localhost:0")
if err != nil {
t.Fatal(err)
}
go func() { _ = http.Serve(l, m) }()

// Dial out to sansshell server set up in TestMain
conn, err := proxy.DialContext(ctx, "", []string{"bufnet"}, grpc.WithContextDialer(bufDialer), grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() { conn.Close() })

// Start proxying command without host-header flag
f := flag.NewFlagSet("proxy", flag.PanicOnError)
p := &proxyCmd{}
p.SetFlags(f)
_, port, err := net.SplitHostPort(l.Addr().String())
if err != nil {
t.Fatal(err)
}
if err := f.Parse([]string{"-allow-any-host", port}); err != nil {
t.Fatal(err)
}
reader, writer := io.Pipe()
go p.Execute(ctx, f, &util.ExecuteState{
Conn: conn,
Out: []io.Writer{writer},
Err: []io.Writer{os.Stderr},
})

// Find the port to use
buf := make([]byte, 1024)
if _, err := reader.Read(buf); err != nil {
t.Fatal(err)
}
msg := strings.Fields(string(buf))
// Parse out "Listening on http://%v, "
addr := msg[2][:len(msg[2])-1]

// Make a call with original host header
req, err := http.NewRequest("GET", addr, nil)
if err != nil {
t.Fatal(err)
}
originalHost := "original.example.com:8080"
req.Host = originalHost

client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
t.Fatal(err)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}
want := "hello world"
if string(body) != want {
t.Errorf("got %q, want %q", body, want)
}

// Verify that the proxy's address was used (no host override)
// When host-header flag is not set, the target receives the proxy's listening address
if !strings.HasPrefix(receivedHeaders["request-host"], "localhost:") {
t.Errorf("got host header %q, want localhost:port", receivedHeaders["request-host"])
}
}
Loading