diff --git a/services/httpoverrpc/client/client.go b/services/httpoverrpc/client/client.go index 03c99bfa..506c23e8 100644 --- a/services/httpoverrpc/client/client.go +++ b/services/httpoverrpc/client/client.go @@ -72,6 +72,7 @@ type proxyCmd struct { allowAnyHost bool protocol string hostname string + hostheader string insecureSkipVerify bool stream bool } @@ -91,6 +92,7 @@ func (p *proxyCmd) SetFlags(f *flag.FlagSet) { f.BoolVar(&p.allowAnyHost, "allow-any-host", false, "Serve data regardless of the Host in HTTP requests instead of only allowing localhost and IPs. False by default to prevent DNS rebinding attacks.") f.StringVar(&p.protocol, "protocol", "http", "protocol to communicate with specified hostname") f.StringVar(&p.hostname, "hostname", "localhost", "ip address or domain name to specify host") + f.StringVar(&p.hostheader, "host-header", "", "if set, a host header to set on requests (overriding the value from the incoming request") f.BoolVar(&p.insecureSkipVerify, "insecure-skip-tls-verify", false, "If true, skip TLS cert verification") f.BoolVar(&p.stream, "stream", false, "If true, stream the response back to the client. Useful for large responses.") } @@ -204,6 +206,9 @@ func (p *proxyCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...interfa return } + if p.hostheader != "" { + httpReq.Header.Set("host", p.hostheader) + } var reqHeaders []*pb.Header for k, v := range httpReq.Header { reqHeaders = append(reqHeaders, &pb.Header{Key: k, Values: v}) diff --git a/services/httpoverrpc/client/client_test.go b/services/httpoverrpc/client/client_test.go index 45df23b9..7ef729c6 100644 --- a/services/httpoverrpc/client/client_test.go +++ b/services/httpoverrpc/client/client_test.go @@ -501,3 +501,239 @@ func TestGetPortDefaultHTTPS(t *testing.T) { t.Fatalf("got wrong port: %d. Expected: %d", result, defaultHTTPSPort) } } + +func TestProxyHostHeader(t *testing.T) { + ctx := context.Background() + receivedHeaders := make(map[string]string) + + // Set up web server that captures headers + m := http.NewServeMux() + m.HandleFunc("/", func(httpResp http.ResponseWriter, httpReq *http.Request) { + receivedHeaders["request-host"] = httpReq.Host + _, _ = httpResp.Write([]byte("hello world")) + }) + l, err := net.Listen("tcp4", "localhost:0") + if err != nil { + t.Fatal(err) + } + go func() { _ = http.Serve(l, m) }() + + // Dial out to sansshell server set up in TestMain + conn, err := proxy.DialContext(ctx, "", []string{"bufnet"}, grpc.WithContextDialer(bufDialer), grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { conn.Close() }) + + // Start proxying command with host-header flag + f := flag.NewFlagSet("proxy", flag.PanicOnError) + p := &proxyCmd{} + p.SetFlags(f) + _, port, err := net.SplitHostPort(l.Addr().String()) + if err != nil { + t.Fatal(err) + } + customHostHeader := "custom.example.com" + if err := f.Parse([]string{"-host-header", customHostHeader, "-allow-any-host", port}); err != nil { + t.Fatal(err) + } + reader, writer := io.Pipe() + go p.Execute(ctx, f, &util.ExecuteState{ + Conn: conn, + Out: []io.Writer{writer}, + Err: []io.Writer{os.Stderr}, + }) + + // Find the port to use + buf := make([]byte, 1024) + if _, err := reader.Read(buf); err != nil { + t.Fatal(err) + } + msg := strings.Fields(string(buf)) + // Parse out "Listening on http://%v, " + addr := msg[2][:len(msg[2])-1] + + // Make a call with original host header + req, err := http.NewRequest("GET", addr, nil) + if err != nil { + t.Fatal(err) + } + req.Host = "original.example.com:8080" + + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + t.Fatal(err) + } + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + want := "hello world" + if string(body) != want { + t.Errorf("got %q, want %q", body, want) + } + + // Verify that the custom host header was used instead of the original + if receivedHeaders["request-host"] != customHostHeader { + t.Errorf("got host header %q, want %q", receivedHeaders["request-host"], customHostHeader) + } +} + +func TestProxyHostHeaderStream(t *testing.T) { + ctx := context.Background() + receivedHeaders := make(map[string]string) + + // Set up web server that captures headers + m := http.NewServeMux() + m.HandleFunc("/", func(httpResp http.ResponseWriter, httpReq *http.Request) { + receivedHeaders["request-host"] = httpReq.Host + _, _ = httpResp.Write([]byte("hello world")) + }) + l, err := net.Listen("tcp4", "localhost:0") + if err != nil { + t.Fatal(err) + } + go func() { _ = http.Serve(l, m) }() + + // Dial out to sansshell server set up in TestMain + conn, err := proxy.DialContext(ctx, "", []string{"bufnet"}, grpc.WithContextDialer(bufDialer), grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { conn.Close() }) + + // Start proxying command with host-header flag and stream mode + f := flag.NewFlagSet("proxy", flag.PanicOnError) + p := &proxyCmd{} + p.SetFlags(f) + p.stream = true + _, port, err := net.SplitHostPort(l.Addr().String()) + if err != nil { + t.Fatal(err) + } + customHostHeader := "custom-stream.example.com" + if err := f.Parse([]string{"-host-header", customHostHeader, "-allow-any-host", port}); err != nil { + t.Fatal(err) + } + reader, writer := io.Pipe() + go p.Execute(ctx, f, &util.ExecuteState{ + Conn: conn, + Out: []io.Writer{writer}, + Err: []io.Writer{os.Stderr}, + }) + + // Find the port to use + buf := make([]byte, 1024) + if _, err := reader.Read(buf); err != nil { + t.Fatal(err) + } + msg := strings.Fields(string(buf)) + // Parse out "Listening on http://%v, " + addr := msg[2][:len(msg[2])-1] + + // Make a call with original host header + req, err := http.NewRequest("GET", addr, nil) + if err != nil { + t.Fatal(err) + } + req.Host = "original-stream.example.com:8080" + + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + t.Fatal(err) + } + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + want := "hello world" + if string(body) != want { + t.Errorf("got %q, want %q", body, want) + } + + // Verify that the custom host header was used instead of the original + if receivedHeaders["request-host"] != customHostHeader { + t.Errorf("got host header %q, want %q", receivedHeaders["request-host"], customHostHeader) + } +} + +func TestProxyNoHostHeader(t *testing.T) { + ctx := context.Background() + receivedHeaders := make(map[string]string) + + // Set up web server that captures headers + m := http.NewServeMux() + m.HandleFunc("/", func(httpResp http.ResponseWriter, httpReq *http.Request) { + receivedHeaders["request-host"] = httpReq.Host + _, _ = httpResp.Write([]byte("hello world")) + }) + l, err := net.Listen("tcp4", "localhost:0") + if err != nil { + t.Fatal(err) + } + go func() { _ = http.Serve(l, m) }() + + // Dial out to sansshell server set up in TestMain + conn, err := proxy.DialContext(ctx, "", []string{"bufnet"}, grpc.WithContextDialer(bufDialer), grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { conn.Close() }) + + // Start proxying command without host-header flag + f := flag.NewFlagSet("proxy", flag.PanicOnError) + p := &proxyCmd{} + p.SetFlags(f) + _, port, err := net.SplitHostPort(l.Addr().String()) + if err != nil { + t.Fatal(err) + } + if err := f.Parse([]string{"-allow-any-host", port}); err != nil { + t.Fatal(err) + } + reader, writer := io.Pipe() + go p.Execute(ctx, f, &util.ExecuteState{ + Conn: conn, + Out: []io.Writer{writer}, + Err: []io.Writer{os.Stderr}, + }) + + // Find the port to use + buf := make([]byte, 1024) + if _, err := reader.Read(buf); err != nil { + t.Fatal(err) + } + msg := strings.Fields(string(buf)) + // Parse out "Listening on http://%v, " + addr := msg[2][:len(msg[2])-1] + + // Make a call with original host header + req, err := http.NewRequest("GET", addr, nil) + if err != nil { + t.Fatal(err) + } + originalHost := "original.example.com:8080" + req.Host = originalHost + + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + t.Fatal(err) + } + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + want := "hello world" + if string(body) != want { + t.Errorf("got %q, want %q", body, want) + } + + // Verify that the proxy's address was used (no host override) + // When host-header flag is not set, the target receives the proxy's listening address + if !strings.HasPrefix(receivedHeaders["request-host"], "localhost:") { + t.Errorf("got host header %q, want localhost:port", receivedHeaders["request-host"]) + } +}