diff --git a/cmd/broker/acl_test.go b/cmd/broker/acl_test.go index 65edcc8..025237c 100644 --- a/cmd/broker/acl_test.go +++ b/cmd/broker/acl_test.go @@ -20,8 +20,10 @@ import ( "context" "encoding/binary" "io" + "net" "testing" + "github.com/KafScale/platform/pkg/broker" "github.com/KafScale/platform/pkg/metadata" "github.com/KafScale/platform/pkg/protocol" "github.com/twmb/franz-go/pkg/kmsg" @@ -152,6 +154,108 @@ func TestACLOffsetFetchDenied(t *testing.T) { } } +func TestACLDescribeGroupsMixed(t *testing.T) { + t.Setenv("KAFSCALE_ACL_ENABLED", "true") + t.Setenv("KAFSCALE_ACL_JSON", `{"default_policy":"deny","principals":[{"name":"client-a","allow":[{"action":"group_read","resource":"group","name":"group-allowed"}]}]}`) + + store := metadata.NewInMemoryStore(defaultMetadata()) + handler := newTestHandler(store) + + clientID := "client-a" + req := &protocol.DescribeGroupsRequest{Groups: []string{"group-allowed", "group-denied"}} + payload, err := handler.Handle(context.Background(), &protocol.RequestHeader{CorrelationID: 20, APIVersion: 5, ClientID: &clientID}, req) + if err != nil { + t.Fatalf("Handle DescribeGroups: %v", err) + } + resp := decodeDescribeGroupsResponse(t, payload, 5) + if len(resp.Groups) != 2 { + t.Fatalf("expected 2 groups, got %d", len(resp.Groups)) + } + if resp.Groups[0].Group != "group-allowed" || resp.Groups[1].Group != "group-denied" { + t.Fatalf("unexpected group order: %+v", resp.Groups) + } + if resp.Groups[0].ErrorCode == protocol.GROUP_AUTHORIZATION_FAILED { + t.Fatalf("expected allowed group not to be auth failed, got %d", resp.Groups[0].ErrorCode) + } + if resp.Groups[1].ErrorCode != protocol.GROUP_AUTHORIZATION_FAILED { + t.Fatalf("expected denied group auth failed, got %d", resp.Groups[1].ErrorCode) + } +} + +func TestACLDeleteGroupsMixed(t *testing.T) { + t.Setenv("KAFSCALE_ACL_ENABLED", "true") + t.Setenv("KAFSCALE_ACL_JSON", `{"default_policy":"deny","principals":[{"name":"client-a","allow":[{"action":"group_admin","resource":"group","name":"group-allowed"}]}]}`) + + store := metadata.NewInMemoryStore(defaultMetadata()) + handler := newTestHandler(store) + + clientID := "client-a" + req := &protocol.DeleteGroupsRequest{Groups: []string{"group-allowed", "group-denied"}} + payload, err := handler.Handle(context.Background(), &protocol.RequestHeader{CorrelationID: 21, APIVersion: 2, ClientID: &clientID}, req) + if err != nil { + t.Fatalf("Handle DeleteGroups: %v", err) + } + resp := decodeDeleteGroupsResponse(t, payload, 2) + if len(resp.Groups) != 2 { + t.Fatalf("expected 2 groups, got %d", len(resp.Groups)) + } + if resp.Groups[0].Group != "group-allowed" || resp.Groups[1].Group != "group-denied" { + t.Fatalf("unexpected group order: %+v", resp.Groups) + } + if resp.Groups[0].ErrorCode == protocol.GROUP_AUTHORIZATION_FAILED { + t.Fatalf("expected allowed group not to be auth failed, got %d", resp.Groups[0].ErrorCode) + } + if resp.Groups[1].ErrorCode != protocol.GROUP_AUTHORIZATION_FAILED { + t.Fatalf("expected denied group auth failed, got %d", resp.Groups[1].ErrorCode) + } +} + +func TestACLProxyAddrProduceAllowed(t *testing.T) { + t.Setenv("KAFSCALE_ACL_ENABLED", "true") + t.Setenv("KAFSCALE_ACL_JSON", `{"default_policy":"deny","principals":[{"name":"10.0.0.1","allow":[{"action":"produce","resource":"topic","name":"orders"}]}]}`) + t.Setenv("KAFSCALE_PRINCIPAL_SOURCE", "proxy_addr") + t.Setenv("KAFSCALE_PROXY_PROTOCOL", "true") + + store := metadata.NewInMemoryStore(defaultMetadata()) + handler := newTestHandler(store) + + conn, peer := net.Pipe() + defer conn.Close() + defer peer.Close() + go func() { + _, _ = peer.Write([]byte("PROXY TCP4 10.0.0.1 10.0.0.2 12345 9092\r\n")) + }() + + connCtx := buildConnContextFunc(testLogger()) + _, info, err := connCtx(conn) + if err != nil { + t.Fatalf("proxy conn context: %v", err) + } + ctx := broker.ContextWithConnInfo(context.Background(), info) + + clientID := "spoofed-client" + req := &protocol.ProduceRequest{ + Acks: -1, + TimeoutMs: 1000, + Topics: []protocol.ProduceTopic{ + { + Name: "orders", + Partitions: []protocol.ProducePartition{ + {Partition: 0, Records: testBatchBytes(0, 0, 1)}, + }, + }, + }, + } + payload, err := handler.handleProduce(ctx, &protocol.RequestHeader{CorrelationID: 22, APIVersion: 0, ClientID: &clientID}, req) + if err != nil { + t.Fatalf("handleProduce: %v", err) + } + resp := decodeProduceResponse(t, payload, 0) + if resp.Topics[0].Partitions[0].ErrorCode != protocol.NONE { + t.Fatalf("expected produce allowed, got %d", resp.Topics[0].Partitions[0].ErrorCode) + } +} + func TestACLSyncGroupDenied(t *testing.T) { t.Setenv("KAFSCALE_ACL_ENABLED", "true") t.Setenv("KAFSCALE_ACL_JSON", `{"default_policy":"deny","principals":[{"name":"client-a","allow":[]}]}`) @@ -401,6 +505,28 @@ func decodeOffsetFetchResponse(t *testing.T, payload []byte, version int16) *pro return resp } +func decodeDescribeGroupsResponse(t *testing.T, payload []byte, version int16) *kmsg.DescribeGroupsResponse { + t.Helper() + reader := bytes.NewReader(payload) + var corr int32 + if err := binary.Read(reader, binary.BigEndian, &corr); err != nil { + t.Fatalf("read correlation id: %v", err) + } + if version >= 5 { + skipTaggedFields(t, reader) + } + body, err := io.ReadAll(reader) + if err != nil { + t.Fatalf("read response body: %v", err) + } + resp := kmsg.NewPtrDescribeGroupsResponse() + resp.Version = version + if err := resp.ReadFrom(body); err != nil { + t.Fatalf("decode describe groups response: %v", err) + } + return resp +} + func decodeSyncGroupResponse(t *testing.T, payload []byte, version int16) *kmsg.SyncGroupResponse { t.Helper() reader := bytes.NewReader(payload) diff --git a/cmd/broker/main.go b/cmd/broker/main.go index 2a6bf02..61928ae 100644 --- a/cmd/broker/main.go +++ b/cmd/broker/main.go @@ -259,41 +259,64 @@ func (h *handler) Handle(ctx context.Context, header *protocol.RequestHeader, re return protocol.EncodeSyncGroupResponse(resp, header.APIVersion) case *protocol.DescribeGroupsRequest: req := req.(*protocol.DescribeGroupsRequest) - if !h.allowGroups(principal, req.Groups, acl.ActionGroupRead) { - h.recordAuthzDeniedWithPrincipal(principal, acl.ActionGroupRead, acl.ResourceGroup, strings.Join(req.Groups, ",")) + return h.withAdminMetrics(header.APIKey, func() ([]byte, error) { + allowed := make([]string, 0, len(req.Groups)) + denied := make(map[string]struct{}) + for _, groupID := range req.Groups { + if h.allowGroup(principal, groupID, acl.ActionGroupRead) { + allowed = append(allowed, groupID) + } else { + denied[groupID] = struct{}{} + h.recordAuthzDeniedWithPrincipal(principal, acl.ActionGroupRead, acl.ResourceGroup, groupID) + } + } + + responseByGroup := make(map[string]protocol.DescribeGroupsResponseGroup, len(req.Groups)) + if len(allowed) > 0 { + if !h.etcdAvailable() { + for _, groupID := range allowed { + responseByGroup[groupID] = protocol.DescribeGroupsResponseGroup{ + ErrorCode: protocol.REQUEST_TIMED_OUT, + GroupID: groupID, + } + } + } else { + allowedReq := *req + allowedReq.Groups = allowed + resp, err := h.coordinator.DescribeGroups(ctx, &allowedReq, header.CorrelationID) + if err != nil { + return nil, err + } + for _, group := range resp.Groups { + responseByGroup[group.GroupID] = group + } + } + } + results := make([]protocol.DescribeGroupsResponseGroup, 0, len(req.Groups)) for _, groupID := range req.Groups { - results = append(results, protocol.DescribeGroupsResponseGroup{ - ErrorCode: protocol.GROUP_AUTHORIZATION_FAILED, - GroupID: groupID, - }) + if _, denied := denied[groupID]; denied { + results = append(results, protocol.DescribeGroupsResponseGroup{ + ErrorCode: protocol.GROUP_AUTHORIZATION_FAILED, + GroupID: groupID, + }) + continue + } + if group, ok := responseByGroup[groupID]; ok { + results = append(results, group) + } else { + results = append(results, protocol.DescribeGroupsResponseGroup{ + ErrorCode: protocol.UNKNOWN_SERVER_ERROR, + GroupID: groupID, + }) + } } + return protocol.EncodeDescribeGroupsResponse(&protocol.DescribeGroupsResponse{ CorrelationID: header.CorrelationID, ThrottleMs: 0, Groups: results, }, header.APIVersion) - } - return h.withAdminMetrics(header.APIKey, func() ([]byte, error) { - if !h.etcdAvailable() { - results := make([]protocol.DescribeGroupsResponseGroup, 0, len(req.Groups)) - for _, groupID := range req.Groups { - results = append(results, protocol.DescribeGroupsResponseGroup{ - ErrorCode: protocol.REQUEST_TIMED_OUT, - GroupID: groupID, - }) - } - return protocol.EncodeDescribeGroupsResponse(&protocol.DescribeGroupsResponse{ - CorrelationID: header.CorrelationID, - ThrottleMs: 0, - Groups: results, - }, header.APIVersion) - } - resp, err := h.coordinator.DescribeGroups(ctx, req, header.CorrelationID) - if err != nil { - return nil, err - } - return protocol.EncodeDescribeGroupsResponse(resp, header.APIVersion) }) case *protocol.ListGroupsRequest: if !h.allowGroup(principal, "*", acl.ActionGroupRead) { @@ -493,28 +516,63 @@ func (h *handler) Handle(ctx context.Context, header *protocol.RequestHeader, re case *protocol.DeleteGroupsRequest: return h.withAdminMetrics(header.APIKey, func() ([]byte, error) { deleteReq := req.(*protocol.DeleteGroupsRequest) - if !h.allowGroups(principal, deleteReq.Groups, acl.ActionGroupAdmin) { - return h.unauthorizedDeleteGroups(principal, header, deleteReq) + allowed := make([]string, 0, len(deleteReq.Groups)) + denied := make(map[string]struct{}) + for _, groupID := range deleteReq.Groups { + if h.allowGroup(principal, groupID, acl.ActionGroupAdmin) { + allowed = append(allowed, groupID) + } else { + denied[groupID] = struct{}{} + h.recordAuthzDeniedWithPrincipal(principal, acl.ActionGroupAdmin, acl.ResourceGroup, groupID) + } } - if !h.etcdAvailable() { - results := make([]protocol.DeleteGroupsResponseGroup, 0, len(deleteReq.Groups)) - for _, groupID := range deleteReq.Groups { + + responseByGroup := make(map[string]protocol.DeleteGroupsResponseGroup, len(deleteReq.Groups)) + if len(allowed) > 0 { + if !h.etcdAvailable() { + for _, groupID := range allowed { + responseByGroup[groupID] = protocol.DeleteGroupsResponseGroup{ + Group: groupID, + ErrorCode: protocol.REQUEST_TIMED_OUT, + } + } + } else { + allowedReq := *deleteReq + allowedReq.Groups = allowed + resp, err := h.coordinator.DeleteGroups(ctx, &allowedReq, header.CorrelationID) + if err != nil { + return nil, err + } + for _, group := range resp.Groups { + responseByGroup[group.Group] = group + } + } + } + + results := make([]protocol.DeleteGroupsResponseGroup, 0, len(deleteReq.Groups)) + for _, groupID := range deleteReq.Groups { + if _, denied := denied[groupID]; denied { results = append(results, protocol.DeleteGroupsResponseGroup{ Group: groupID, - ErrorCode: protocol.REQUEST_TIMED_OUT, + ErrorCode: protocol.GROUP_AUTHORIZATION_FAILED, + }) + continue + } + if group, ok := responseByGroup[groupID]; ok { + results = append(results, group) + } else { + results = append(results, protocol.DeleteGroupsResponseGroup{ + Group: groupID, + ErrorCode: protocol.UNKNOWN_SERVER_ERROR, }) } - return protocol.EncodeDeleteGroupsResponse(&protocol.DeleteGroupsResponse{ - CorrelationID: header.CorrelationID, - ThrottleMs: 0, - Groups: results, - }, header.APIVersion) - } - resp, err := h.coordinator.DeleteGroups(ctx, deleteReq, header.CorrelationID) - if err != nil { - return nil, err } - return protocol.EncodeDeleteGroupsResponse(resp, header.APIVersion) + + return protocol.EncodeDeleteGroupsResponse(&protocol.DeleteGroupsResponse{ + CorrelationID: header.CorrelationID, + ThrottleMs: 0, + Groups: results, + }, header.APIVersion) }) case *protocol.CreateTopicsRequest: createReq := req.(*protocol.CreateTopicsRequest) diff --git a/pkg/broker/proxyproto.go b/pkg/broker/proxyproto.go index b5b1780..fa9757c 100644 --- a/pkg/broker/proxyproto.go +++ b/pkg/broker/proxyproto.go @@ -74,6 +74,9 @@ func parseProxyV1(br *bufio.Reader) (*ProxyInfo, error) { return nil, err } parts := bytes.Fields([]byte(line)) + if len(parts) >= 2 && bytes.Equal(bytes.ToUpper(parts[1]), []byte("UNKNOWN")) { + return &ProxyInfo{Local: true}, nil + } if len(parts) < 6 { return nil, fmt.Errorf("proxy v1 header malformed") } diff --git a/pkg/broker/proxyproto_test.go b/pkg/broker/proxyproto_test.go new file mode 100644 index 0000000..96f8ec4 --- /dev/null +++ b/pkg/broker/proxyproto_test.go @@ -0,0 +1,80 @@ +// Copyright 2026 Alexander Alten (novatechflow), NovaTechflow (novatechflow.com). +// This project is supported and financed by Scalytics, Inc. (www.scalytics.io). +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package broker + +import ( + "bytes" + "io" + "net" + "testing" +) + +func TestProxyProtocolV1Unknown(t *testing.T) { + conn, peer := net.Pipe() + defer conn.Close() + defer peer.Close() + + payload := []byte("PROXY UNKNOWN\r\nping") + go func() { + _, _ = peer.Write(payload) + }() + + wrapped, info, err := ReadProxyProtocol(conn) + if err != nil { + t.Fatalf("ReadProxyProtocol: %v", err) + } + if info == nil || !info.Local { + t.Fatalf("expected local proxy info, got %+v", info) + } + buf := make([]byte, 4) + if _, err := io.ReadFull(wrapped, buf); err != nil { + t.Fatalf("read payload: %v", err) + } + if !bytes.Equal(buf, []byte("ping")) { + t.Fatalf("unexpected payload %q", string(buf)) + } +} + +func TestProxyProtocolV2Local(t *testing.T) { + conn, peer := net.Pipe() + defer conn.Close() + defer peer.Close() + + header := append([]byte{}, proxyV2Signature...) + header = append(header, 0x20) // v2 + LOCAL + header = append(header, 0x00) // UNSPEC + header = append(header, 0x00, 0x00) // length 0 + payload := append(header, []byte("pong")...) + + go func() { + _, _ = peer.Write(payload) + }() + + wrapped, info, err := ReadProxyProtocol(conn) + if err != nil { + t.Fatalf("ReadProxyProtocol: %v", err) + } + if info == nil || !info.Local { + t.Fatalf("expected local proxy info, got %+v", info) + } + buf := make([]byte, 4) + if _, err := io.ReadFull(wrapped, buf); err != nil { + t.Fatalf("read payload: %v", err) + } + if !bytes.Equal(buf, []byte("pong")) { + t.Fatalf("unexpected payload %q", string(buf)) + } +}