From 30284b07c90b50405eeca6e85ed58bcb770513c0 Mon Sep 17 00:00:00 2001 From: Paul Querna Date: Sun, 12 Oct 2025 13:17:06 -0700 Subject: [PATCH 1/9] rtun v0 --- pb/c1/connectorapi/rtun/v1/gateway.pb.go | 548 +++++++++++ .../rtun/v1/gateway.pb.validate.go | 900 ++++++++++++++++++ pb/c1/connectorapi/rtun/v1/gateway_grpc.pb.go | 119 +++ pb/c1/connectorapi/rtun/v1/rtun.pb.go | 574 +++++++++++ .../connectorapi/rtun/v1/rtun.pb.validate.go | 846 ++++++++++++++++ pb/c1/connectorapi/rtun/v1/rtun_grpc.pb.go | 113 +++ pkg/lambda/grpc/config/sts.go | 8 +- pkg/rtun/gateway/client.go | 521 ++++++++++ pkg/rtun/gateway/client_conn_test.go | 170 ++++ pkg/rtun/gateway/errors.go | 9 + pkg/rtun/gateway/grpc_options.go | 25 + pkg/rtun/gateway/integration_test.go | 135 +++ pkg/rtun/gateway/metrics.go | 44 + pkg/rtun/gateway/options.go | 16 + pkg/rtun/gateway/server.go | 302 ++++++ pkg/rtun/match/directory.go | 15 + pkg/rtun/match/errors.go | 7 + pkg/rtun/match/locator.go | 48 + pkg/rtun/match/locator_test.go | 84 ++ pkg/rtun/match/memory/directory.go | 56 ++ pkg/rtun/match/memory/directory_test.go | 54 ++ pkg/rtun/match/memory/presence.go | 93 ++ pkg/rtun/match/memory/presence_test.go | 139 +++ pkg/rtun/match/presence.go | 21 + pkg/rtun/match/route.go | 62 ++ pkg/rtun/match/route_test.go | 118 +++ pkg/rtun/server/auth.go | 16 + pkg/rtun/server/grpc_options.go | 28 + pkg/rtun/server/handler.go | 144 +++ pkg/rtun/server/metrics.go | 46 + pkg/rtun/server/options.go | 24 + pkg/rtun/server/registry.go | 91 ++ pkg/rtun/server/server_integration_test.go | 91 ++ pkg/rtun/transport/closedset.go | 151 +++ pkg/rtun/transport/closedset_test.go | 152 +++ pkg/rtun/transport/conn.go | 306 ++++++ pkg/rtun/transport/conn_test.go | 93 ++ pkg/rtun/transport/errors.go | 9 + pkg/rtun/transport/listener.go | 63 ++ pkg/rtun/transport/session.go | 474 +++++++++ pkg/rtun/transport/session_race_test.go | 96 ++ pkg/rtun/transport/session_test.go | 308 ++++++ proto/c1/connectorapi/rtun/v1/gateway.proto | 49 + proto/c1/connectorapi/rtun/v1/rtun.proto | 49 + .../google.golang.org/grpc/health/client.go | 117 +++ .../google.golang.org/grpc/health/logging.go | 23 + .../google.golang.org/grpc/health/producer.go | 106 +++ .../google.golang.org/grpc/health/server.go | 163 ++++ vendor/modules.txt | 1 + 49 files changed, 7626 insertions(+), 1 deletion(-) create mode 100644 pb/c1/connectorapi/rtun/v1/gateway.pb.go create mode 100644 pb/c1/connectorapi/rtun/v1/gateway.pb.validate.go create mode 100644 pb/c1/connectorapi/rtun/v1/gateway_grpc.pb.go create mode 100644 pb/c1/connectorapi/rtun/v1/rtun.pb.go create mode 100644 pb/c1/connectorapi/rtun/v1/rtun.pb.validate.go create mode 100644 pb/c1/connectorapi/rtun/v1/rtun_grpc.pb.go create mode 100644 pkg/rtun/gateway/client.go create mode 100644 pkg/rtun/gateway/client_conn_test.go create mode 100644 pkg/rtun/gateway/errors.go create mode 100644 pkg/rtun/gateway/grpc_options.go create mode 100644 pkg/rtun/gateway/integration_test.go create mode 100644 pkg/rtun/gateway/metrics.go create mode 100644 pkg/rtun/gateway/options.go create mode 100644 pkg/rtun/gateway/server.go create mode 100644 pkg/rtun/match/directory.go create mode 100644 pkg/rtun/match/errors.go create mode 100644 pkg/rtun/match/locator.go create mode 100644 pkg/rtun/match/locator_test.go create mode 100644 pkg/rtun/match/memory/directory.go create mode 100644 pkg/rtun/match/memory/directory_test.go create mode 100644 pkg/rtun/match/memory/presence.go create mode 100644 pkg/rtun/match/memory/presence_test.go create mode 100644 pkg/rtun/match/presence.go create mode 100644 pkg/rtun/match/route.go create mode 100644 pkg/rtun/match/route_test.go create mode 100644 pkg/rtun/server/auth.go create mode 100644 pkg/rtun/server/grpc_options.go create mode 100644 pkg/rtun/server/handler.go create mode 100644 pkg/rtun/server/metrics.go create mode 100644 pkg/rtun/server/options.go create mode 100644 pkg/rtun/server/registry.go create mode 100644 pkg/rtun/server/server_integration_test.go create mode 100644 pkg/rtun/transport/closedset.go create mode 100644 pkg/rtun/transport/closedset_test.go create mode 100644 pkg/rtun/transport/conn.go create mode 100644 pkg/rtun/transport/conn_test.go create mode 100644 pkg/rtun/transport/errors.go create mode 100644 pkg/rtun/transport/listener.go create mode 100644 pkg/rtun/transport/session.go create mode 100644 pkg/rtun/transport/session_race_test.go create mode 100644 pkg/rtun/transport/session_test.go create mode 100644 proto/c1/connectorapi/rtun/v1/gateway.proto create mode 100644 proto/c1/connectorapi/rtun/v1/rtun.proto create mode 100644 vendor/google.golang.org/grpc/health/client.go create mode 100644 vendor/google.golang.org/grpc/health/logging.go create mode 100644 vendor/google.golang.org/grpc/health/producer.go create mode 100644 vendor/google.golang.org/grpc/health/server.go diff --git a/pb/c1/connectorapi/rtun/v1/gateway.pb.go b/pb/c1/connectorapi/rtun/v1/gateway.pb.go new file mode 100644 index 000000000..25f0ec02b --- /dev/null +++ b/pb/c1/connectorapi/rtun/v1/gateway.pb.go @@ -0,0 +1,548 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.36.4 +// protoc (unknown) +// source: c1/connectorapi/rtun/v1/gateway.proto + +package v1 + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" + unsafe "unsafe" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +// GatewayRequest is sent from caller to gateway. +type GatewayRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + // Types that are valid to be assigned to Kind: + // + // *GatewayRequest_OpenReq + // *GatewayRequest_Frame + Kind isGatewayRequest_Kind `protobuf_oneof:"kind"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *GatewayRequest) Reset() { + *x = GatewayRequest{} + mi := &file_c1_connectorapi_rtun_v1_gateway_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *GatewayRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GatewayRequest) ProtoMessage() {} + +func (x *GatewayRequest) ProtoReflect() protoreflect.Message { + mi := &file_c1_connectorapi_rtun_v1_gateway_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GatewayRequest.ProtoReflect.Descriptor instead. +func (*GatewayRequest) Descriptor() ([]byte, []int) { + return file_c1_connectorapi_rtun_v1_gateway_proto_rawDescGZIP(), []int{0} +} + +func (x *GatewayRequest) GetKind() isGatewayRequest_Kind { + if x != nil { + return x.Kind + } + return nil +} + +func (x *GatewayRequest) GetOpenReq() *OpenRequest { + if x != nil { + if x, ok := x.Kind.(*GatewayRequest_OpenReq); ok { + return x.OpenReq + } + } + return nil +} + +func (x *GatewayRequest) GetFrame() *Frame { + if x != nil { + if x, ok := x.Kind.(*GatewayRequest_Frame); ok { + return x.Frame + } + } + return nil +} + +type isGatewayRequest_Kind interface { + isGatewayRequest_Kind() +} + +type GatewayRequest_OpenReq struct { + OpenReq *OpenRequest `protobuf:"bytes,1,opt,name=open_req,json=openReq,proto3,oneof"` // initiate a connection (first message, or concurrent opens) +} + +type GatewayRequest_Frame struct { + Frame *Frame `protobuf:"bytes,10,opt,name=frame,proto3,oneof"` // data/control frames (reuses rtun Frame) +} + +func (*GatewayRequest_OpenReq) isGatewayRequest_Kind() {} + +func (*GatewayRequest_Frame) isGatewayRequest_Kind() {} + +// GatewayResponse is sent from gateway to caller. +type GatewayResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + // Types that are valid to be assigned to Kind: + // + // *GatewayResponse_OpenResp + // *GatewayResponse_Frame + Kind isGatewayResponse_Kind `protobuf_oneof:"kind"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *GatewayResponse) Reset() { + *x = GatewayResponse{} + mi := &file_c1_connectorapi_rtun_v1_gateway_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *GatewayResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GatewayResponse) ProtoMessage() {} + +func (x *GatewayResponse) ProtoReflect() protoreflect.Message { + mi := &file_c1_connectorapi_rtun_v1_gateway_proto_msgTypes[1] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GatewayResponse.ProtoReflect.Descriptor instead. +func (*GatewayResponse) Descriptor() ([]byte, []int) { + return file_c1_connectorapi_rtun_v1_gateway_proto_rawDescGZIP(), []int{1} +} + +func (x *GatewayResponse) GetKind() isGatewayResponse_Kind { + if x != nil { + return x.Kind + } + return nil +} + +func (x *GatewayResponse) GetOpenResp() *OpenResponse { + if x != nil { + if x, ok := x.Kind.(*GatewayResponse_OpenResp); ok { + return x.OpenResp + } + } + return nil +} + +func (x *GatewayResponse) GetFrame() *Frame { + if x != nil { + if x, ok := x.Kind.(*GatewayResponse_Frame); ok { + return x.Frame + } + } + return nil +} + +type isGatewayResponse_Kind interface { + isGatewayResponse_Kind() +} + +type GatewayResponse_OpenResp struct { + OpenResp *OpenResponse `protobuf:"bytes,1,opt,name=open_resp,json=openResp,proto3,oneof"` // handshake result +} + +type GatewayResponse_Frame struct { + Frame *Frame `protobuf:"bytes,10,opt,name=frame,proto3,oneof"` // data/control frames (reuses rtun Frame) +} + +func (*GatewayResponse_OpenResp) isGatewayResponse_Kind() {} + +func (*GatewayResponse_Frame) isGatewayResponse_Kind() {} + +// OpenRequest initiates a reverse connection to a client. +// The caller proposes a gSID (gateway SID) for this connection to support concurrent opens. +type OpenRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + Gsid uint32 `protobuf:"varint,1,opt,name=gsid,proto3" json:"gsid,omitempty"` // caller-proposed SID for this connection (must be unique per stream) + ClientId string `protobuf:"bytes,2,opt,name=client_id,json=clientId,proto3" json:"client_id,omitempty"` // target client (must be URL-safe) + Port uint32 `protobuf:"varint,3,opt,name=port,proto3" json:"port,omitempty"` // target port on the client + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *OpenRequest) Reset() { + *x = OpenRequest{} + mi := &file_c1_connectorapi_rtun_v1_gateway_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *OpenRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*OpenRequest) ProtoMessage() {} + +func (x *OpenRequest) ProtoReflect() protoreflect.Message { + mi := &file_c1_connectorapi_rtun_v1_gateway_proto_msgTypes[2] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use OpenRequest.ProtoReflect.Descriptor instead. +func (*OpenRequest) Descriptor() ([]byte, []int) { + return file_c1_connectorapi_rtun_v1_gateway_proto_rawDescGZIP(), []int{2} +} + +func (x *OpenRequest) GetGsid() uint32 { + if x != nil { + return x.Gsid + } + return 0 +} + +func (x *OpenRequest) GetClientId() string { + if x != nil { + return x.ClientId + } + return "" +} + +func (x *OpenRequest) GetPort() uint32 { + if x != nil { + return x.Port + } + return 0 +} + +// OpenResponse indicates the result of an OpenRequest. +type OpenResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + Gsid uint32 `protobuf:"varint,1,opt,name=gsid,proto3" json:"gsid,omitempty"` // echoed from OpenRequest + // Types that are valid to be assigned to Result: + // + // *OpenResponse_NotFound + // *OpenResponse_Opened + Result isOpenResponse_Result `protobuf_oneof:"result"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *OpenResponse) Reset() { + *x = OpenResponse{} + mi := &file_c1_connectorapi_rtun_v1_gateway_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *OpenResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*OpenResponse) ProtoMessage() {} + +func (x *OpenResponse) ProtoReflect() protoreflect.Message { + mi := &file_c1_connectorapi_rtun_v1_gateway_proto_msgTypes[3] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use OpenResponse.ProtoReflect.Descriptor instead. +func (*OpenResponse) Descriptor() ([]byte, []int) { + return file_c1_connectorapi_rtun_v1_gateway_proto_rawDescGZIP(), []int{3} +} + +func (x *OpenResponse) GetGsid() uint32 { + if x != nil { + return x.Gsid + } + return 0 +} + +func (x *OpenResponse) GetResult() isOpenResponse_Result { + if x != nil { + return x.Result + } + return nil +} + +func (x *OpenResponse) GetNotFound() *NotFound { + if x != nil { + if x, ok := x.Result.(*OpenResponse_NotFound); ok { + return x.NotFound + } + } + return nil +} + +func (x *OpenResponse) GetOpened() *Opened { + if x != nil { + if x, ok := x.Result.(*OpenResponse_Opened); ok { + return x.Opened + } + } + return nil +} + +type isOpenResponse_Result interface { + isOpenResponse_Result() +} + +type OpenResponse_NotFound struct { + NotFound *NotFound `protobuf:"bytes,2,opt,name=not_found,json=notFound,proto3,oneof"` // gateway doesn't own this client; caller should re-resolve +} + +type OpenResponse_Opened struct { + Opened *Opened `protobuf:"bytes,3,opt,name=opened,proto3,oneof"` // success; use the gSID for subsequent frames +} + +func (*OpenResponse_NotFound) isOpenResponse_Result() {} + +func (*OpenResponse_Opened) isOpenResponse_Result() {} + +type NotFound struct { + state protoimpl.MessageState `protogen:"open.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *NotFound) Reset() { + *x = NotFound{} + mi := &file_c1_connectorapi_rtun_v1_gateway_proto_msgTypes[4] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *NotFound) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*NotFound) ProtoMessage() {} + +func (x *NotFound) ProtoReflect() protoreflect.Message { + mi := &file_c1_connectorapi_rtun_v1_gateway_proto_msgTypes[4] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use NotFound.ProtoReflect.Descriptor instead. +func (*NotFound) Descriptor() ([]byte, []int) { + return file_c1_connectorapi_rtun_v1_gateway_proto_rawDescGZIP(), []int{4} +} + +type Opened struct { + state protoimpl.MessageState `protogen:"open.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *Opened) Reset() { + *x = Opened{} + mi := &file_c1_connectorapi_rtun_v1_gateway_proto_msgTypes[5] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *Opened) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Opened) ProtoMessage() {} + +func (x *Opened) ProtoReflect() protoreflect.Message { + mi := &file_c1_connectorapi_rtun_v1_gateway_proto_msgTypes[5] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Opened.ProtoReflect.Descriptor instead. +func (*Opened) Descriptor() ([]byte, []int) { + return file_c1_connectorapi_rtun_v1_gateway_proto_rawDescGZIP(), []int{5} +} + +var File_c1_connectorapi_rtun_v1_gateway_proto protoreflect.FileDescriptor + +var file_c1_connectorapi_rtun_v1_gateway_proto_rawDesc = string([]byte{ + 0x0a, 0x25, 0x63, 0x31, 0x2f, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x61, 0x70, + 0x69, 0x2f, 0x72, 0x74, 0x75, 0x6e, 0x2f, 0x76, 0x31, 0x2f, 0x67, 0x61, 0x74, 0x65, 0x77, 0x61, + 0x79, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x17, 0x63, 0x31, 0x2e, 0x63, 0x6f, 0x6e, 0x6e, + 0x65, 0x63, 0x74, 0x6f, 0x72, 0x61, 0x70, 0x69, 0x2e, 0x72, 0x74, 0x75, 0x6e, 0x2e, 0x76, 0x31, + 0x1a, 0x22, 0x63, 0x31, 0x2f, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x61, 0x70, + 0x69, 0x2f, 0x72, 0x74, 0x75, 0x6e, 0x2f, 0x76, 0x31, 0x2f, 0x72, 0x74, 0x75, 0x6e, 0x2e, 0x70, + 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x93, 0x01, 0x0a, 0x0e, 0x47, 0x61, 0x74, 0x65, 0x77, 0x61, 0x79, + 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x41, 0x0a, 0x08, 0x6f, 0x70, 0x65, 0x6e, 0x5f, + 0x72, 0x65, 0x71, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x24, 0x2e, 0x63, 0x31, 0x2e, 0x63, + 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x61, 0x70, 0x69, 0x2e, 0x72, 0x74, 0x75, 0x6e, + 0x2e, 0x76, 0x31, 0x2e, 0x4f, 0x70, 0x65, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x48, + 0x00, 0x52, 0x07, 0x6f, 0x70, 0x65, 0x6e, 0x52, 0x65, 0x71, 0x12, 0x36, 0x0a, 0x05, 0x66, 0x72, + 0x61, 0x6d, 0x65, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1e, 0x2e, 0x63, 0x31, 0x2e, 0x63, + 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x61, 0x70, 0x69, 0x2e, 0x72, 0x74, 0x75, 0x6e, + 0x2e, 0x76, 0x31, 0x2e, 0x46, 0x72, 0x61, 0x6d, 0x65, 0x48, 0x00, 0x52, 0x05, 0x66, 0x72, 0x61, + 0x6d, 0x65, 0x42, 0x06, 0x0a, 0x04, 0x6b, 0x69, 0x6e, 0x64, 0x22, 0x97, 0x01, 0x0a, 0x0f, 0x47, + 0x61, 0x74, 0x65, 0x77, 0x61, 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x44, + 0x0a, 0x09, 0x6f, 0x70, 0x65, 0x6e, 0x5f, 0x72, 0x65, 0x73, 0x70, 0x18, 0x01, 0x20, 0x01, 0x28, + 0x0b, 0x32, 0x25, 0x2e, 0x63, 0x31, 0x2e, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x6f, 0x72, + 0x61, 0x70, 0x69, 0x2e, 0x72, 0x74, 0x75, 0x6e, 0x2e, 0x76, 0x31, 0x2e, 0x4f, 0x70, 0x65, 0x6e, + 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x48, 0x00, 0x52, 0x08, 0x6f, 0x70, 0x65, 0x6e, + 0x52, 0x65, 0x73, 0x70, 0x12, 0x36, 0x0a, 0x05, 0x66, 0x72, 0x61, 0x6d, 0x65, 0x18, 0x0a, 0x20, + 0x01, 0x28, 0x0b, 0x32, 0x1e, 0x2e, 0x63, 0x31, 0x2e, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, + 0x6f, 0x72, 0x61, 0x70, 0x69, 0x2e, 0x72, 0x74, 0x75, 0x6e, 0x2e, 0x76, 0x31, 0x2e, 0x46, 0x72, + 0x61, 0x6d, 0x65, 0x48, 0x00, 0x52, 0x05, 0x66, 0x72, 0x61, 0x6d, 0x65, 0x42, 0x06, 0x0a, 0x04, + 0x6b, 0x69, 0x6e, 0x64, 0x22, 0x52, 0x0a, 0x0b, 0x4f, 0x70, 0x65, 0x6e, 0x52, 0x65, 0x71, 0x75, + 0x65, 0x73, 0x74, 0x12, 0x12, 0x0a, 0x04, 0x67, 0x73, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, + 0x0d, 0x52, 0x04, 0x67, 0x73, 0x69, 0x64, 0x12, 0x1b, 0x0a, 0x09, 0x63, 0x6c, 0x69, 0x65, 0x6e, + 0x74, 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x63, 0x6c, 0x69, 0x65, + 0x6e, 0x74, 0x49, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x03, 0x20, 0x01, + 0x28, 0x0d, 0x52, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x22, 0xa9, 0x01, 0x0a, 0x0c, 0x4f, 0x70, 0x65, + 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x67, 0x73, 0x69, + 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x04, 0x67, 0x73, 0x69, 0x64, 0x12, 0x40, 0x0a, + 0x09, 0x6e, 0x6f, 0x74, 0x5f, 0x66, 0x6f, 0x75, 0x6e, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, + 0x32, 0x21, 0x2e, 0x63, 0x31, 0x2e, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x61, + 0x70, 0x69, 0x2e, 0x72, 0x74, 0x75, 0x6e, 0x2e, 0x76, 0x31, 0x2e, 0x4e, 0x6f, 0x74, 0x46, 0x6f, + 0x75, 0x6e, 0x64, 0x48, 0x00, 0x52, 0x08, 0x6e, 0x6f, 0x74, 0x46, 0x6f, 0x75, 0x6e, 0x64, 0x12, + 0x39, 0x0a, 0x06, 0x6f, 0x70, 0x65, 0x6e, 0x65, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, + 0x1f, 0x2e, 0x63, 0x31, 0x2e, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x61, 0x70, + 0x69, 0x2e, 0x72, 0x74, 0x75, 0x6e, 0x2e, 0x76, 0x31, 0x2e, 0x4f, 0x70, 0x65, 0x6e, 0x65, 0x64, + 0x48, 0x00, 0x52, 0x06, 0x6f, 0x70, 0x65, 0x6e, 0x65, 0x64, 0x42, 0x08, 0x0a, 0x06, 0x72, 0x65, + 0x73, 0x75, 0x6c, 0x74, 0x22, 0x0a, 0x0a, 0x08, 0x4e, 0x6f, 0x74, 0x46, 0x6f, 0x75, 0x6e, 0x64, + 0x22, 0x08, 0x0a, 0x06, 0x4f, 0x70, 0x65, 0x6e, 0x65, 0x64, 0x32, 0x6e, 0x0a, 0x0d, 0x52, 0x65, + 0x76, 0x65, 0x72, 0x73, 0x65, 0x44, 0x69, 0x61, 0x6c, 0x65, 0x72, 0x12, 0x5d, 0x0a, 0x04, 0x4f, + 0x70, 0x65, 0x6e, 0x12, 0x27, 0x2e, 0x63, 0x31, 0x2e, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, + 0x6f, 0x72, 0x61, 0x70, 0x69, 0x2e, 0x72, 0x74, 0x75, 0x6e, 0x2e, 0x76, 0x31, 0x2e, 0x47, 0x61, + 0x74, 0x65, 0x77, 0x61, 0x79, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x28, 0x2e, 0x63, + 0x31, 0x2e, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x61, 0x70, 0x69, 0x2e, 0x72, + 0x74, 0x75, 0x6e, 0x2e, 0x76, 0x31, 0x2e, 0x47, 0x61, 0x74, 0x65, 0x77, 0x61, 0x79, 0x52, 0x65, + 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x28, 0x01, 0x30, 0x01, 0x42, 0x3e, 0x5a, 0x3c, 0x67, 0x69, + 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x63, 0x6f, 0x6e, 0x64, 0x75, 0x63, 0x74, + 0x6f, 0x72, 0x6f, 0x6e, 0x65, 0x2f, 0x62, 0x61, 0x74, 0x6f, 0x6e, 0x2d, 0x73, 0x64, 0x6b, 0x2f, + 0x70, 0x62, 0x2f, 0x63, 0x31, 0x2f, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x61, + 0x70, 0x69, 0x2f, 0x72, 0x74, 0x75, 0x6e, 0x2f, 0x76, 0x31, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, + 0x6f, 0x33, +}) + +var ( + file_c1_connectorapi_rtun_v1_gateway_proto_rawDescOnce sync.Once + file_c1_connectorapi_rtun_v1_gateway_proto_rawDescData []byte +) + +func file_c1_connectorapi_rtun_v1_gateway_proto_rawDescGZIP() []byte { + file_c1_connectorapi_rtun_v1_gateway_proto_rawDescOnce.Do(func() { + file_c1_connectorapi_rtun_v1_gateway_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_c1_connectorapi_rtun_v1_gateway_proto_rawDesc), len(file_c1_connectorapi_rtun_v1_gateway_proto_rawDesc))) + }) + return file_c1_connectorapi_rtun_v1_gateway_proto_rawDescData +} + +var file_c1_connectorapi_rtun_v1_gateway_proto_msgTypes = make([]protoimpl.MessageInfo, 6) +var file_c1_connectorapi_rtun_v1_gateway_proto_goTypes = []any{ + (*GatewayRequest)(nil), // 0: c1.connectorapi.rtun.v1.GatewayRequest + (*GatewayResponse)(nil), // 1: c1.connectorapi.rtun.v1.GatewayResponse + (*OpenRequest)(nil), // 2: c1.connectorapi.rtun.v1.OpenRequest + (*OpenResponse)(nil), // 3: c1.connectorapi.rtun.v1.OpenResponse + (*NotFound)(nil), // 4: c1.connectorapi.rtun.v1.NotFound + (*Opened)(nil), // 5: c1.connectorapi.rtun.v1.Opened + (*Frame)(nil), // 6: c1.connectorapi.rtun.v1.Frame +} +var file_c1_connectorapi_rtun_v1_gateway_proto_depIdxs = []int32{ + 2, // 0: c1.connectorapi.rtun.v1.GatewayRequest.open_req:type_name -> c1.connectorapi.rtun.v1.OpenRequest + 6, // 1: c1.connectorapi.rtun.v1.GatewayRequest.frame:type_name -> c1.connectorapi.rtun.v1.Frame + 3, // 2: c1.connectorapi.rtun.v1.GatewayResponse.open_resp:type_name -> c1.connectorapi.rtun.v1.OpenResponse + 6, // 3: c1.connectorapi.rtun.v1.GatewayResponse.frame:type_name -> c1.connectorapi.rtun.v1.Frame + 4, // 4: c1.connectorapi.rtun.v1.OpenResponse.not_found:type_name -> c1.connectorapi.rtun.v1.NotFound + 5, // 5: c1.connectorapi.rtun.v1.OpenResponse.opened:type_name -> c1.connectorapi.rtun.v1.Opened + 0, // 6: c1.connectorapi.rtun.v1.ReverseDialer.Open:input_type -> c1.connectorapi.rtun.v1.GatewayRequest + 1, // 7: c1.connectorapi.rtun.v1.ReverseDialer.Open:output_type -> c1.connectorapi.rtun.v1.GatewayResponse + 7, // [7:8] is the sub-list for method output_type + 6, // [6:7] is the sub-list for method input_type + 6, // [6:6] is the sub-list for extension type_name + 6, // [6:6] is the sub-list for extension extendee + 0, // [0:6] is the sub-list for field type_name +} + +func init() { file_c1_connectorapi_rtun_v1_gateway_proto_init() } +func file_c1_connectorapi_rtun_v1_gateway_proto_init() { + if File_c1_connectorapi_rtun_v1_gateway_proto != nil { + return + } + file_c1_connectorapi_rtun_v1_rtun_proto_init() + file_c1_connectorapi_rtun_v1_gateway_proto_msgTypes[0].OneofWrappers = []any{ + (*GatewayRequest_OpenReq)(nil), + (*GatewayRequest_Frame)(nil), + } + file_c1_connectorapi_rtun_v1_gateway_proto_msgTypes[1].OneofWrappers = []any{ + (*GatewayResponse_OpenResp)(nil), + (*GatewayResponse_Frame)(nil), + } + file_c1_connectorapi_rtun_v1_gateway_proto_msgTypes[3].OneofWrappers = []any{ + (*OpenResponse_NotFound)(nil), + (*OpenResponse_Opened)(nil), + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_c1_connectorapi_rtun_v1_gateway_proto_rawDesc), len(file_c1_connectorapi_rtun_v1_gateway_proto_rawDesc)), + NumEnums: 0, + NumMessages: 6, + NumExtensions: 0, + NumServices: 1, + }, + GoTypes: file_c1_connectorapi_rtun_v1_gateway_proto_goTypes, + DependencyIndexes: file_c1_connectorapi_rtun_v1_gateway_proto_depIdxs, + MessageInfos: file_c1_connectorapi_rtun_v1_gateway_proto_msgTypes, + }.Build() + File_c1_connectorapi_rtun_v1_gateway_proto = out.File + file_c1_connectorapi_rtun_v1_gateway_proto_goTypes = nil + file_c1_connectorapi_rtun_v1_gateway_proto_depIdxs = nil +} diff --git a/pb/c1/connectorapi/rtun/v1/gateway.pb.validate.go b/pb/c1/connectorapi/rtun/v1/gateway.pb.validate.go new file mode 100644 index 000000000..6a6eea901 --- /dev/null +++ b/pb/c1/connectorapi/rtun/v1/gateway.pb.validate.go @@ -0,0 +1,900 @@ +// Code generated by protoc-gen-validate. DO NOT EDIT. +// source: c1/connectorapi/rtun/v1/gateway.proto + +package v1 + +import ( + "bytes" + "errors" + "fmt" + "net" + "net/mail" + "net/url" + "regexp" + "sort" + "strings" + "time" + "unicode/utf8" + + "google.golang.org/protobuf/types/known/anypb" +) + +// ensure the imports are used +var ( + _ = bytes.MinRead + _ = errors.New("") + _ = fmt.Print + _ = utf8.UTFMax + _ = (*regexp.Regexp)(nil) + _ = (*strings.Reader)(nil) + _ = net.IPv4len + _ = time.Duration(0) + _ = (*url.URL)(nil) + _ = (*mail.Address)(nil) + _ = anypb.Any{} + _ = sort.Sort +) + +// Validate checks the field values on GatewayRequest with the rules defined in +// the proto definition for this message. If any rules are violated, the first +// error encountered is returned, or nil if there are no violations. +func (m *GatewayRequest) Validate() error { + return m.validate(false) +} + +// ValidateAll checks the field values on GatewayRequest with the rules defined +// in the proto definition for this message. If any rules are violated, the +// result is a list of violation errors wrapped in GatewayRequestMultiError, +// or nil if none found. +func (m *GatewayRequest) ValidateAll() error { + return m.validate(true) +} + +func (m *GatewayRequest) validate(all bool) error { + if m == nil { + return nil + } + + var errors []error + + switch v := m.Kind.(type) { + case *GatewayRequest_OpenReq: + if v == nil { + err := GatewayRequestValidationError{ + field: "Kind", + reason: "oneof value cannot be a typed-nil", + } + if !all { + return err + } + errors = append(errors, err) + } + + if all { + switch v := interface{}(m.GetOpenReq()).(type) { + case interface{ ValidateAll() error }: + if err := v.ValidateAll(); err != nil { + errors = append(errors, GatewayRequestValidationError{ + field: "OpenReq", + reason: "embedded message failed validation", + cause: err, + }) + } + case interface{ Validate() error }: + if err := v.Validate(); err != nil { + errors = append(errors, GatewayRequestValidationError{ + field: "OpenReq", + reason: "embedded message failed validation", + cause: err, + }) + } + } + } else if v, ok := interface{}(m.GetOpenReq()).(interface{ Validate() error }); ok { + if err := v.Validate(); err != nil { + return GatewayRequestValidationError{ + field: "OpenReq", + reason: "embedded message failed validation", + cause: err, + } + } + } + + case *GatewayRequest_Frame: + if v == nil { + err := GatewayRequestValidationError{ + field: "Kind", + reason: "oneof value cannot be a typed-nil", + } + if !all { + return err + } + errors = append(errors, err) + } + + if all { + switch v := interface{}(m.GetFrame()).(type) { + case interface{ ValidateAll() error }: + if err := v.ValidateAll(); err != nil { + errors = append(errors, GatewayRequestValidationError{ + field: "Frame", + reason: "embedded message failed validation", + cause: err, + }) + } + case interface{ Validate() error }: + if err := v.Validate(); err != nil { + errors = append(errors, GatewayRequestValidationError{ + field: "Frame", + reason: "embedded message failed validation", + cause: err, + }) + } + } + } else if v, ok := interface{}(m.GetFrame()).(interface{ Validate() error }); ok { + if err := v.Validate(); err != nil { + return GatewayRequestValidationError{ + field: "Frame", + reason: "embedded message failed validation", + cause: err, + } + } + } + + default: + _ = v // ensures v is used + } + + if len(errors) > 0 { + return GatewayRequestMultiError(errors) + } + + return nil +} + +// GatewayRequestMultiError is an error wrapping multiple validation errors +// returned by GatewayRequest.ValidateAll() if the designated constraints +// aren't met. +type GatewayRequestMultiError []error + +// Error returns a concatenation of all the error messages it wraps. +func (m GatewayRequestMultiError) Error() string { + msgs := make([]string, 0, len(m)) + for _, err := range m { + msgs = append(msgs, err.Error()) + } + return strings.Join(msgs, "; ") +} + +// AllErrors returns a list of validation violation errors. +func (m GatewayRequestMultiError) AllErrors() []error { return m } + +// GatewayRequestValidationError is the validation error returned by +// GatewayRequest.Validate if the designated constraints aren't met. +type GatewayRequestValidationError struct { + field string + reason string + cause error + key bool +} + +// Field function returns field value. +func (e GatewayRequestValidationError) Field() string { return e.field } + +// Reason function returns reason value. +func (e GatewayRequestValidationError) Reason() string { return e.reason } + +// Cause function returns cause value. +func (e GatewayRequestValidationError) Cause() error { return e.cause } + +// Key function returns key value. +func (e GatewayRequestValidationError) Key() bool { return e.key } + +// ErrorName returns error name. +func (e GatewayRequestValidationError) ErrorName() string { return "GatewayRequestValidationError" } + +// Error satisfies the builtin error interface +func (e GatewayRequestValidationError) Error() string { + cause := "" + if e.cause != nil { + cause = fmt.Sprintf(" | caused by: %v", e.cause) + } + + key := "" + if e.key { + key = "key for " + } + + return fmt.Sprintf( + "invalid %sGatewayRequest.%s: %s%s", + key, + e.field, + e.reason, + cause) +} + +var _ error = GatewayRequestValidationError{} + +var _ interface { + Field() string + Reason() string + Key() bool + Cause() error + ErrorName() string +} = GatewayRequestValidationError{} + +// Validate checks the field values on GatewayResponse with the rules defined +// in the proto definition for this message. If any rules are violated, the +// first error encountered is returned, or nil if there are no violations. +func (m *GatewayResponse) Validate() error { + return m.validate(false) +} + +// ValidateAll checks the field values on GatewayResponse with the rules +// defined in the proto definition for this message. If any rules are +// violated, the result is a list of violation errors wrapped in +// GatewayResponseMultiError, or nil if none found. +func (m *GatewayResponse) ValidateAll() error { + return m.validate(true) +} + +func (m *GatewayResponse) validate(all bool) error { + if m == nil { + return nil + } + + var errors []error + + switch v := m.Kind.(type) { + case *GatewayResponse_OpenResp: + if v == nil { + err := GatewayResponseValidationError{ + field: "Kind", + reason: "oneof value cannot be a typed-nil", + } + if !all { + return err + } + errors = append(errors, err) + } + + if all { + switch v := interface{}(m.GetOpenResp()).(type) { + case interface{ ValidateAll() error }: + if err := v.ValidateAll(); err != nil { + errors = append(errors, GatewayResponseValidationError{ + field: "OpenResp", + reason: "embedded message failed validation", + cause: err, + }) + } + case interface{ Validate() error }: + if err := v.Validate(); err != nil { + errors = append(errors, GatewayResponseValidationError{ + field: "OpenResp", + reason: "embedded message failed validation", + cause: err, + }) + } + } + } else if v, ok := interface{}(m.GetOpenResp()).(interface{ Validate() error }); ok { + if err := v.Validate(); err != nil { + return GatewayResponseValidationError{ + field: "OpenResp", + reason: "embedded message failed validation", + cause: err, + } + } + } + + case *GatewayResponse_Frame: + if v == nil { + err := GatewayResponseValidationError{ + field: "Kind", + reason: "oneof value cannot be a typed-nil", + } + if !all { + return err + } + errors = append(errors, err) + } + + if all { + switch v := interface{}(m.GetFrame()).(type) { + case interface{ ValidateAll() error }: + if err := v.ValidateAll(); err != nil { + errors = append(errors, GatewayResponseValidationError{ + field: "Frame", + reason: "embedded message failed validation", + cause: err, + }) + } + case interface{ Validate() error }: + if err := v.Validate(); err != nil { + errors = append(errors, GatewayResponseValidationError{ + field: "Frame", + reason: "embedded message failed validation", + cause: err, + }) + } + } + } else if v, ok := interface{}(m.GetFrame()).(interface{ Validate() error }); ok { + if err := v.Validate(); err != nil { + return GatewayResponseValidationError{ + field: "Frame", + reason: "embedded message failed validation", + cause: err, + } + } + } + + default: + _ = v // ensures v is used + } + + if len(errors) > 0 { + return GatewayResponseMultiError(errors) + } + + return nil +} + +// GatewayResponseMultiError is an error wrapping multiple validation errors +// returned by GatewayResponse.ValidateAll() if the designated constraints +// aren't met. +type GatewayResponseMultiError []error + +// Error returns a concatenation of all the error messages it wraps. +func (m GatewayResponseMultiError) Error() string { + msgs := make([]string, 0, len(m)) + for _, err := range m { + msgs = append(msgs, err.Error()) + } + return strings.Join(msgs, "; ") +} + +// AllErrors returns a list of validation violation errors. +func (m GatewayResponseMultiError) AllErrors() []error { return m } + +// GatewayResponseValidationError is the validation error returned by +// GatewayResponse.Validate if the designated constraints aren't met. +type GatewayResponseValidationError struct { + field string + reason string + cause error + key bool +} + +// Field function returns field value. +func (e GatewayResponseValidationError) Field() string { return e.field } + +// Reason function returns reason value. +func (e GatewayResponseValidationError) Reason() string { return e.reason } + +// Cause function returns cause value. +func (e GatewayResponseValidationError) Cause() error { return e.cause } + +// Key function returns key value. +func (e GatewayResponseValidationError) Key() bool { return e.key } + +// ErrorName returns error name. +func (e GatewayResponseValidationError) ErrorName() string { return "GatewayResponseValidationError" } + +// Error satisfies the builtin error interface +func (e GatewayResponseValidationError) Error() string { + cause := "" + if e.cause != nil { + cause = fmt.Sprintf(" | caused by: %v", e.cause) + } + + key := "" + if e.key { + key = "key for " + } + + return fmt.Sprintf( + "invalid %sGatewayResponse.%s: %s%s", + key, + e.field, + e.reason, + cause) +} + +var _ error = GatewayResponseValidationError{} + +var _ interface { + Field() string + Reason() string + Key() bool + Cause() error + ErrorName() string +} = GatewayResponseValidationError{} + +// Validate checks the field values on OpenRequest with the rules defined in +// the proto definition for this message. If any rules are violated, the first +// error encountered is returned, or nil if there are no violations. +func (m *OpenRequest) Validate() error { + return m.validate(false) +} + +// ValidateAll checks the field values on OpenRequest with the rules defined in +// the proto definition for this message. If any rules are violated, the +// result is a list of violation errors wrapped in OpenRequestMultiError, or +// nil if none found. +func (m *OpenRequest) ValidateAll() error { + return m.validate(true) +} + +func (m *OpenRequest) validate(all bool) error { + if m == nil { + return nil + } + + var errors []error + + // no validation rules for Gsid + + // no validation rules for ClientId + + // no validation rules for Port + + if len(errors) > 0 { + return OpenRequestMultiError(errors) + } + + return nil +} + +// OpenRequestMultiError is an error wrapping multiple validation errors +// returned by OpenRequest.ValidateAll() if the designated constraints aren't met. +type OpenRequestMultiError []error + +// Error returns a concatenation of all the error messages it wraps. +func (m OpenRequestMultiError) Error() string { + msgs := make([]string, 0, len(m)) + for _, err := range m { + msgs = append(msgs, err.Error()) + } + return strings.Join(msgs, "; ") +} + +// AllErrors returns a list of validation violation errors. +func (m OpenRequestMultiError) AllErrors() []error { return m } + +// OpenRequestValidationError is the validation error returned by +// OpenRequest.Validate if the designated constraints aren't met. +type OpenRequestValidationError struct { + field string + reason string + cause error + key bool +} + +// Field function returns field value. +func (e OpenRequestValidationError) Field() string { return e.field } + +// Reason function returns reason value. +func (e OpenRequestValidationError) Reason() string { return e.reason } + +// Cause function returns cause value. +func (e OpenRequestValidationError) Cause() error { return e.cause } + +// Key function returns key value. +func (e OpenRequestValidationError) Key() bool { return e.key } + +// ErrorName returns error name. +func (e OpenRequestValidationError) ErrorName() string { return "OpenRequestValidationError" } + +// Error satisfies the builtin error interface +func (e OpenRequestValidationError) Error() string { + cause := "" + if e.cause != nil { + cause = fmt.Sprintf(" | caused by: %v", e.cause) + } + + key := "" + if e.key { + key = "key for " + } + + return fmt.Sprintf( + "invalid %sOpenRequest.%s: %s%s", + key, + e.field, + e.reason, + cause) +} + +var _ error = OpenRequestValidationError{} + +var _ interface { + Field() string + Reason() string + Key() bool + Cause() error + ErrorName() string +} = OpenRequestValidationError{} + +// Validate checks the field values on OpenResponse with the rules defined in +// the proto definition for this message. If any rules are violated, the first +// error encountered is returned, or nil if there are no violations. +func (m *OpenResponse) Validate() error { + return m.validate(false) +} + +// ValidateAll checks the field values on OpenResponse with the rules defined +// in the proto definition for this message. If any rules are violated, the +// result is a list of violation errors wrapped in OpenResponseMultiError, or +// nil if none found. +func (m *OpenResponse) ValidateAll() error { + return m.validate(true) +} + +func (m *OpenResponse) validate(all bool) error { + if m == nil { + return nil + } + + var errors []error + + // no validation rules for Gsid + + switch v := m.Result.(type) { + case *OpenResponse_NotFound: + if v == nil { + err := OpenResponseValidationError{ + field: "Result", + reason: "oneof value cannot be a typed-nil", + } + if !all { + return err + } + errors = append(errors, err) + } + + if all { + switch v := interface{}(m.GetNotFound()).(type) { + case interface{ ValidateAll() error }: + if err := v.ValidateAll(); err != nil { + errors = append(errors, OpenResponseValidationError{ + field: "NotFound", + reason: "embedded message failed validation", + cause: err, + }) + } + case interface{ Validate() error }: + if err := v.Validate(); err != nil { + errors = append(errors, OpenResponseValidationError{ + field: "NotFound", + reason: "embedded message failed validation", + cause: err, + }) + } + } + } else if v, ok := interface{}(m.GetNotFound()).(interface{ Validate() error }); ok { + if err := v.Validate(); err != nil { + return OpenResponseValidationError{ + field: "NotFound", + reason: "embedded message failed validation", + cause: err, + } + } + } + + case *OpenResponse_Opened: + if v == nil { + err := OpenResponseValidationError{ + field: "Result", + reason: "oneof value cannot be a typed-nil", + } + if !all { + return err + } + errors = append(errors, err) + } + + if all { + switch v := interface{}(m.GetOpened()).(type) { + case interface{ ValidateAll() error }: + if err := v.ValidateAll(); err != nil { + errors = append(errors, OpenResponseValidationError{ + field: "Opened", + reason: "embedded message failed validation", + cause: err, + }) + } + case interface{ Validate() error }: + if err := v.Validate(); err != nil { + errors = append(errors, OpenResponseValidationError{ + field: "Opened", + reason: "embedded message failed validation", + cause: err, + }) + } + } + } else if v, ok := interface{}(m.GetOpened()).(interface{ Validate() error }); ok { + if err := v.Validate(); err != nil { + return OpenResponseValidationError{ + field: "Opened", + reason: "embedded message failed validation", + cause: err, + } + } + } + + default: + _ = v // ensures v is used + } + + if len(errors) > 0 { + return OpenResponseMultiError(errors) + } + + return nil +} + +// OpenResponseMultiError is an error wrapping multiple validation errors +// returned by OpenResponse.ValidateAll() if the designated constraints aren't met. +type OpenResponseMultiError []error + +// Error returns a concatenation of all the error messages it wraps. +func (m OpenResponseMultiError) Error() string { + msgs := make([]string, 0, len(m)) + for _, err := range m { + msgs = append(msgs, err.Error()) + } + return strings.Join(msgs, "; ") +} + +// AllErrors returns a list of validation violation errors. +func (m OpenResponseMultiError) AllErrors() []error { return m } + +// OpenResponseValidationError is the validation error returned by +// OpenResponse.Validate if the designated constraints aren't met. +type OpenResponseValidationError struct { + field string + reason string + cause error + key bool +} + +// Field function returns field value. +func (e OpenResponseValidationError) Field() string { return e.field } + +// Reason function returns reason value. +func (e OpenResponseValidationError) Reason() string { return e.reason } + +// Cause function returns cause value. +func (e OpenResponseValidationError) Cause() error { return e.cause } + +// Key function returns key value. +func (e OpenResponseValidationError) Key() bool { return e.key } + +// ErrorName returns error name. +func (e OpenResponseValidationError) ErrorName() string { return "OpenResponseValidationError" } + +// Error satisfies the builtin error interface +func (e OpenResponseValidationError) Error() string { + cause := "" + if e.cause != nil { + cause = fmt.Sprintf(" | caused by: %v", e.cause) + } + + key := "" + if e.key { + key = "key for " + } + + return fmt.Sprintf( + "invalid %sOpenResponse.%s: %s%s", + key, + e.field, + e.reason, + cause) +} + +var _ error = OpenResponseValidationError{} + +var _ interface { + Field() string + Reason() string + Key() bool + Cause() error + ErrorName() string +} = OpenResponseValidationError{} + +// Validate checks the field values on NotFound with the rules defined in the +// proto definition for this message. If any rules are violated, the first +// error encountered is returned, or nil if there are no violations. +func (m *NotFound) Validate() error { + return m.validate(false) +} + +// ValidateAll checks the field values on NotFound with the rules defined in +// the proto definition for this message. If any rules are violated, the +// result is a list of violation errors wrapped in NotFoundMultiError, or nil +// if none found. +func (m *NotFound) ValidateAll() error { + return m.validate(true) +} + +func (m *NotFound) validate(all bool) error { + if m == nil { + return nil + } + + var errors []error + + if len(errors) > 0 { + return NotFoundMultiError(errors) + } + + return nil +} + +// NotFoundMultiError is an error wrapping multiple validation errors returned +// by NotFound.ValidateAll() if the designated constraints aren't met. +type NotFoundMultiError []error + +// Error returns a concatenation of all the error messages it wraps. +func (m NotFoundMultiError) Error() string { + msgs := make([]string, 0, len(m)) + for _, err := range m { + msgs = append(msgs, err.Error()) + } + return strings.Join(msgs, "; ") +} + +// AllErrors returns a list of validation violation errors. +func (m NotFoundMultiError) AllErrors() []error { return m } + +// NotFoundValidationError is the validation error returned by +// NotFound.Validate if the designated constraints aren't met. +type NotFoundValidationError struct { + field string + reason string + cause error + key bool +} + +// Field function returns field value. +func (e NotFoundValidationError) Field() string { return e.field } + +// Reason function returns reason value. +func (e NotFoundValidationError) Reason() string { return e.reason } + +// Cause function returns cause value. +func (e NotFoundValidationError) Cause() error { return e.cause } + +// Key function returns key value. +func (e NotFoundValidationError) Key() bool { return e.key } + +// ErrorName returns error name. +func (e NotFoundValidationError) ErrorName() string { return "NotFoundValidationError" } + +// Error satisfies the builtin error interface +func (e NotFoundValidationError) Error() string { + cause := "" + if e.cause != nil { + cause = fmt.Sprintf(" | caused by: %v", e.cause) + } + + key := "" + if e.key { + key = "key for " + } + + return fmt.Sprintf( + "invalid %sNotFound.%s: %s%s", + key, + e.field, + e.reason, + cause) +} + +var _ error = NotFoundValidationError{} + +var _ interface { + Field() string + Reason() string + Key() bool + Cause() error + ErrorName() string +} = NotFoundValidationError{} + +// Validate checks the field values on Opened with the rules defined in the +// proto definition for this message. If any rules are violated, the first +// error encountered is returned, or nil if there are no violations. +func (m *Opened) Validate() error { + return m.validate(false) +} + +// ValidateAll checks the field values on Opened with the rules defined in the +// proto definition for this message. If any rules are violated, the result is +// a list of violation errors wrapped in OpenedMultiError, or nil if none found. +func (m *Opened) ValidateAll() error { + return m.validate(true) +} + +func (m *Opened) validate(all bool) error { + if m == nil { + return nil + } + + var errors []error + + if len(errors) > 0 { + return OpenedMultiError(errors) + } + + return nil +} + +// OpenedMultiError is an error wrapping multiple validation errors returned by +// Opened.ValidateAll() if the designated constraints aren't met. +type OpenedMultiError []error + +// Error returns a concatenation of all the error messages it wraps. +func (m OpenedMultiError) Error() string { + msgs := make([]string, 0, len(m)) + for _, err := range m { + msgs = append(msgs, err.Error()) + } + return strings.Join(msgs, "; ") +} + +// AllErrors returns a list of validation violation errors. +func (m OpenedMultiError) AllErrors() []error { return m } + +// OpenedValidationError is the validation error returned by Opened.Validate if +// the designated constraints aren't met. +type OpenedValidationError struct { + field string + reason string + cause error + key bool +} + +// Field function returns field value. +func (e OpenedValidationError) Field() string { return e.field } + +// Reason function returns reason value. +func (e OpenedValidationError) Reason() string { return e.reason } + +// Cause function returns cause value. +func (e OpenedValidationError) Cause() error { return e.cause } + +// Key function returns key value. +func (e OpenedValidationError) Key() bool { return e.key } + +// ErrorName returns error name. +func (e OpenedValidationError) ErrorName() string { return "OpenedValidationError" } + +// Error satisfies the builtin error interface +func (e OpenedValidationError) Error() string { + cause := "" + if e.cause != nil { + cause = fmt.Sprintf(" | caused by: %v", e.cause) + } + + key := "" + if e.key { + key = "key for " + } + + return fmt.Sprintf( + "invalid %sOpened.%s: %s%s", + key, + e.field, + e.reason, + cause) +} + +var _ error = OpenedValidationError{} + +var _ interface { + Field() string + Reason() string + Key() bool + Cause() error + ErrorName() string +} = OpenedValidationError{} diff --git a/pb/c1/connectorapi/rtun/v1/gateway_grpc.pb.go b/pb/c1/connectorapi/rtun/v1/gateway_grpc.pb.go new file mode 100644 index 000000000..937dc00bc --- /dev/null +++ b/pb/c1/connectorapi/rtun/v1/gateway_grpc.pb.go @@ -0,0 +1,119 @@ +// Code generated by protoc-gen-go-grpc. DO NOT EDIT. +// versions: +// - protoc-gen-go-grpc v1.5.1 +// - protoc (unknown) +// source: c1/connectorapi/rtun/v1/gateway.proto + +package v1 + +import ( + context "context" + grpc "google.golang.org/grpc" + codes "google.golang.org/grpc/codes" + status "google.golang.org/grpc/status" +) + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +// Requires gRPC-Go v1.64.0 or later. +const _ = grpc.SupportPackageIsVersion9 + +const ( + ReverseDialer_Open_FullMethodName = "/c1.connectorapi.rtun.v1.ReverseDialer/Open" +) + +// ReverseDialerClient is the client API for ReverseDialer service. +// +// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. +// +// ReverseDialer allows callers to establish connections to clients via the gateway. +// The gateway bridges caller streams to rtun sessions on the owner server. +type ReverseDialerClient interface { + Open(ctx context.Context, opts ...grpc.CallOption) (grpc.BidiStreamingClient[GatewayRequest, GatewayResponse], error) +} + +type reverseDialerClient struct { + cc grpc.ClientConnInterface +} + +func NewReverseDialerClient(cc grpc.ClientConnInterface) ReverseDialerClient { + return &reverseDialerClient{cc} +} + +func (c *reverseDialerClient) Open(ctx context.Context, opts ...grpc.CallOption) (grpc.BidiStreamingClient[GatewayRequest, GatewayResponse], error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + stream, err := c.cc.NewStream(ctx, &ReverseDialer_ServiceDesc.Streams[0], ReverseDialer_Open_FullMethodName, cOpts...) + if err != nil { + return nil, err + } + x := &grpc.GenericClientStream[GatewayRequest, GatewayResponse]{ClientStream: stream} + return x, nil +} + +// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. +type ReverseDialer_OpenClient = grpc.BidiStreamingClient[GatewayRequest, GatewayResponse] + +// ReverseDialerServer is the server API for ReverseDialer service. +// All implementations should embed UnimplementedReverseDialerServer +// for forward compatibility. +// +// ReverseDialer allows callers to establish connections to clients via the gateway. +// The gateway bridges caller streams to rtun sessions on the owner server. +type ReverseDialerServer interface { + Open(grpc.BidiStreamingServer[GatewayRequest, GatewayResponse]) error +} + +// UnimplementedReverseDialerServer should be embedded to have +// forward compatible implementations. +// +// NOTE: this should be embedded by value instead of pointer to avoid a nil +// pointer dereference when methods are called. +type UnimplementedReverseDialerServer struct{} + +func (UnimplementedReverseDialerServer) Open(grpc.BidiStreamingServer[GatewayRequest, GatewayResponse]) error { + return status.Errorf(codes.Unimplemented, "method Open not implemented") +} +func (UnimplementedReverseDialerServer) testEmbeddedByValue() {} + +// UnsafeReverseDialerServer may be embedded to opt out of forward compatibility for this service. +// Use of this interface is not recommended, as added methods to ReverseDialerServer will +// result in compilation errors. +type UnsafeReverseDialerServer interface { + mustEmbedUnimplementedReverseDialerServer() +} + +func RegisterReverseDialerServer(s grpc.ServiceRegistrar, srv ReverseDialerServer) { + // If the following call pancis, it indicates UnimplementedReverseDialerServer was + // embedded by pointer and is nil. This will cause panics if an + // unimplemented method is ever invoked, so we test this at initialization + // time to prevent it from happening at runtime later due to I/O. + if t, ok := srv.(interface{ testEmbeddedByValue() }); ok { + t.testEmbeddedByValue() + } + s.RegisterService(&ReverseDialer_ServiceDesc, srv) +} + +func _ReverseDialer_Open_Handler(srv interface{}, stream grpc.ServerStream) error { + return srv.(ReverseDialerServer).Open(&grpc.GenericServerStream[GatewayRequest, GatewayResponse]{ServerStream: stream}) +} + +// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. +type ReverseDialer_OpenServer = grpc.BidiStreamingServer[GatewayRequest, GatewayResponse] + +// ReverseDialer_ServiceDesc is the grpc.ServiceDesc for ReverseDialer service. +// It's only intended for direct use with grpc.RegisterService, +// and not to be introspected or modified (even as a copy) +var ReverseDialer_ServiceDesc = grpc.ServiceDesc{ + ServiceName: "c1.connectorapi.rtun.v1.ReverseDialer", + HandlerType: (*ReverseDialerServer)(nil), + Methods: []grpc.MethodDesc{}, + Streams: []grpc.StreamDesc{ + { + StreamName: "Open", + Handler: _ReverseDialer_Open_Handler, + ServerStreams: true, + ClientStreams: true, + }, + }, + Metadata: "c1/connectorapi/rtun/v1/gateway.proto", +} diff --git a/pb/c1/connectorapi/rtun/v1/rtun.pb.go b/pb/c1/connectorapi/rtun/v1/rtun.pb.go new file mode 100644 index 000000000..5a3cf9fda --- /dev/null +++ b/pb/c1/connectorapi/rtun/v1/rtun.pb.go @@ -0,0 +1,574 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.36.4 +// protoc (unknown) +// source: c1/connectorapi/rtun/v1/rtun.proto + +package v1 + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" + unsafe "unsafe" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type RstCode int32 + +const ( + RstCode_RST_CODE_UNSPECIFIED RstCode = 0 + RstCode_RST_CODE_NO_LISTENER RstCode = 1 + RstCode_RST_CODE_PORT_NOT_ADVERTISED RstCode = 2 + RstCode_RST_CODE_TIMEOUT RstCode = 3 + RstCode_RST_CODE_INTERNAL RstCode = 4 +) + +// Enum value maps for RstCode. +var ( + RstCode_name = map[int32]string{ + 0: "RST_CODE_UNSPECIFIED", + 1: "RST_CODE_NO_LISTENER", + 2: "RST_CODE_PORT_NOT_ADVERTISED", + 3: "RST_CODE_TIMEOUT", + 4: "RST_CODE_INTERNAL", + } + RstCode_value = map[string]int32{ + "RST_CODE_UNSPECIFIED": 0, + "RST_CODE_NO_LISTENER": 1, + "RST_CODE_PORT_NOT_ADVERTISED": 2, + "RST_CODE_TIMEOUT": 3, + "RST_CODE_INTERNAL": 4, + } +) + +func (x RstCode) Enum() *RstCode { + p := new(RstCode) + *p = x + return p +} + +func (x RstCode) String() string { + return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) +} + +func (RstCode) Descriptor() protoreflect.EnumDescriptor { + return file_c1_connectorapi_rtun_v1_rtun_proto_enumTypes[0].Descriptor() +} + +func (RstCode) Type() protoreflect.EnumType { + return &file_c1_connectorapi_rtun_v1_rtun_proto_enumTypes[0] +} + +func (x RstCode) Number() protoreflect.EnumNumber { + return protoreflect.EnumNumber(x) +} + +// Deprecated: Use RstCode.Descriptor instead. +func (RstCode) EnumDescriptor() ([]byte, []int) { + return file_c1_connectorapi_rtun_v1_rtun_proto_rawDescGZIP(), []int{0} +} + +type Frame struct { + state protoimpl.MessageState `protogen:"open.v1"` + Sid uint32 `protobuf:"varint,1,opt,name=sid,proto3" json:"sid,omitempty"` // 0 reserved for control + // Types that are valid to be assigned to Kind: + // + // *Frame_Hello + // *Frame_Syn + // *Frame_Data + // *Frame_Fin + // *Frame_Rst + Kind isFrame_Kind `protobuf_oneof:"kind"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *Frame) Reset() { + *x = Frame{} + mi := &file_c1_connectorapi_rtun_v1_rtun_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *Frame) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Frame) ProtoMessage() {} + +func (x *Frame) ProtoReflect() protoreflect.Message { + mi := &file_c1_connectorapi_rtun_v1_rtun_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Frame.ProtoReflect.Descriptor instead. +func (*Frame) Descriptor() ([]byte, []int) { + return file_c1_connectorapi_rtun_v1_rtun_proto_rawDescGZIP(), []int{0} +} + +func (x *Frame) GetSid() uint32 { + if x != nil { + return x.Sid + } + return 0 +} + +func (x *Frame) GetKind() isFrame_Kind { + if x != nil { + return x.Kind + } + return nil +} + +func (x *Frame) GetHello() *Hello { + if x != nil { + if x, ok := x.Kind.(*Frame_Hello); ok { + return x.Hello + } + } + return nil +} + +func (x *Frame) GetSyn() *Syn { + if x != nil { + if x, ok := x.Kind.(*Frame_Syn); ok { + return x.Syn + } + } + return nil +} + +func (x *Frame) GetData() *Data { + if x != nil { + if x, ok := x.Kind.(*Frame_Data); ok { + return x.Data + } + } + return nil +} + +func (x *Frame) GetFin() *Fin { + if x != nil { + if x, ok := x.Kind.(*Frame_Fin); ok { + return x.Fin + } + } + return nil +} + +func (x *Frame) GetRst() *Rst { + if x != nil { + if x, ok := x.Kind.(*Frame_Rst); ok { + return x.Rst + } + } + return nil +} + +type isFrame_Kind interface { + isFrame_Kind() +} + +type Frame_Hello struct { + Hello *Hello `protobuf:"bytes,10,opt,name=hello,proto3,oneof"` // client -> server (first) +} + +type Frame_Syn struct { + Syn *Syn `protobuf:"bytes,11,opt,name=syn,proto3,oneof"` // server -> client (reverse open) +} + +type Frame_Data struct { + Data *Data `protobuf:"bytes,12,opt,name=data,proto3,oneof"` // either direction +} + +type Frame_Fin struct { + Fin *Fin `protobuf:"bytes,13,opt,name=fin,proto3,oneof"` // either direction +} + +type Frame_Rst struct { + Rst *Rst `protobuf:"bytes,14,opt,name=rst,proto3,oneof"` // either direction +} + +func (*Frame_Hello) isFrame_Kind() {} + +func (*Frame_Syn) isFrame_Kind() {} + +func (*Frame_Data) isFrame_Kind() {} + +func (*Frame_Fin) isFrame_Kind() {} + +func (*Frame_Rst) isFrame_Kind() {} + +type Hello struct { + state protoimpl.MessageState `protogen:"open.v1"` + Ports []uint32 `protobuf:"varint,1,rep,packed,name=ports,proto3" json:"ports,omitempty"` + Protocol uint32 `protobuf:"varint,2,opt,name=protocol,proto3" json:"protocol,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *Hello) Reset() { + *x = Hello{} + mi := &file_c1_connectorapi_rtun_v1_rtun_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *Hello) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Hello) ProtoMessage() {} + +func (x *Hello) ProtoReflect() protoreflect.Message { + mi := &file_c1_connectorapi_rtun_v1_rtun_proto_msgTypes[1] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Hello.ProtoReflect.Descriptor instead. +func (*Hello) Descriptor() ([]byte, []int) { + return file_c1_connectorapi_rtun_v1_rtun_proto_rawDescGZIP(), []int{1} +} + +func (x *Hello) GetPorts() []uint32 { + if x != nil { + return x.Ports + } + return nil +} + +func (x *Hello) GetProtocol() uint32 { + if x != nil { + return x.Protocol + } + return 0 +} + +type Syn struct { + state protoimpl.MessageState `protogen:"open.v1"` + Port uint32 `protobuf:"varint,1,opt,name=port,proto3" json:"port,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *Syn) Reset() { + *x = Syn{} + mi := &file_c1_connectorapi_rtun_v1_rtun_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *Syn) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Syn) ProtoMessage() {} + +func (x *Syn) ProtoReflect() protoreflect.Message { + mi := &file_c1_connectorapi_rtun_v1_rtun_proto_msgTypes[2] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Syn.ProtoReflect.Descriptor instead. +func (*Syn) Descriptor() ([]byte, []int) { + return file_c1_connectorapi_rtun_v1_rtun_proto_rawDescGZIP(), []int{2} +} + +func (x *Syn) GetPort() uint32 { + if x != nil { + return x.Port + } + return 0 +} + +type Data struct { + state protoimpl.MessageState `protogen:"open.v1"` + Payload []byte `protobuf:"bytes,1,opt,name=payload,proto3" json:"payload,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *Data) Reset() { + *x = Data{} + mi := &file_c1_connectorapi_rtun_v1_rtun_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *Data) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Data) ProtoMessage() {} + +func (x *Data) ProtoReflect() protoreflect.Message { + mi := &file_c1_connectorapi_rtun_v1_rtun_proto_msgTypes[3] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Data.ProtoReflect.Descriptor instead. +func (*Data) Descriptor() ([]byte, []int) { + return file_c1_connectorapi_rtun_v1_rtun_proto_rawDescGZIP(), []int{3} +} + +func (x *Data) GetPayload() []byte { + if x != nil { + return x.Payload + } + return nil +} + +type Fin struct { + state protoimpl.MessageState `protogen:"open.v1"` + Ack bool `protobuf:"varint,1,opt,name=ack,proto3" json:"ack,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *Fin) Reset() { + *x = Fin{} + mi := &file_c1_connectorapi_rtun_v1_rtun_proto_msgTypes[4] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *Fin) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Fin) ProtoMessage() {} + +func (x *Fin) ProtoReflect() protoreflect.Message { + mi := &file_c1_connectorapi_rtun_v1_rtun_proto_msgTypes[4] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Fin.ProtoReflect.Descriptor instead. +func (*Fin) Descriptor() ([]byte, []int) { + return file_c1_connectorapi_rtun_v1_rtun_proto_rawDescGZIP(), []int{4} +} + +func (x *Fin) GetAck() bool { + if x != nil { + return x.Ack + } + return false +} + +type Rst struct { + state protoimpl.MessageState `protogen:"open.v1"` + Code RstCode `protobuf:"varint,1,opt,name=code,proto3,enum=c1.connectorapi.rtun.v1.RstCode" json:"code,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *Rst) Reset() { + *x = Rst{} + mi := &file_c1_connectorapi_rtun_v1_rtun_proto_msgTypes[5] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *Rst) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Rst) ProtoMessage() {} + +func (x *Rst) ProtoReflect() protoreflect.Message { + mi := &file_c1_connectorapi_rtun_v1_rtun_proto_msgTypes[5] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Rst.ProtoReflect.Descriptor instead. +func (*Rst) Descriptor() ([]byte, []int) { + return file_c1_connectorapi_rtun_v1_rtun_proto_rawDescGZIP(), []int{5} +} + +func (x *Rst) GetCode() RstCode { + if x != nil { + return x.Code + } + return RstCode_RST_CODE_UNSPECIFIED +} + +var File_c1_connectorapi_rtun_v1_rtun_proto protoreflect.FileDescriptor + +var file_c1_connectorapi_rtun_v1_rtun_proto_rawDesc = string([]byte{ + 0x0a, 0x22, 0x63, 0x31, 0x2f, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x61, 0x70, + 0x69, 0x2f, 0x72, 0x74, 0x75, 0x6e, 0x2f, 0x76, 0x31, 0x2f, 0x72, 0x74, 0x75, 0x6e, 0x2e, 0x70, + 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x17, 0x63, 0x31, 0x2e, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, + 0x6f, 0x72, 0x61, 0x70, 0x69, 0x2e, 0x72, 0x74, 0x75, 0x6e, 0x2e, 0x76, 0x31, 0x22, 0xa4, 0x02, + 0x0a, 0x05, 0x46, 0x72, 0x61, 0x6d, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x73, 0x69, 0x64, 0x18, 0x01, + 0x20, 0x01, 0x28, 0x0d, 0x52, 0x03, 0x73, 0x69, 0x64, 0x12, 0x36, 0x0a, 0x05, 0x68, 0x65, 0x6c, + 0x6c, 0x6f, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1e, 0x2e, 0x63, 0x31, 0x2e, 0x63, 0x6f, + 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x61, 0x70, 0x69, 0x2e, 0x72, 0x74, 0x75, 0x6e, 0x2e, + 0x76, 0x31, 0x2e, 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x48, 0x00, 0x52, 0x05, 0x68, 0x65, 0x6c, 0x6c, + 0x6f, 0x12, 0x30, 0x0a, 0x03, 0x73, 0x79, 0x6e, 0x18, 0x0b, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1c, + 0x2e, 0x63, 0x31, 0x2e, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x61, 0x70, 0x69, + 0x2e, 0x72, 0x74, 0x75, 0x6e, 0x2e, 0x76, 0x31, 0x2e, 0x53, 0x79, 0x6e, 0x48, 0x00, 0x52, 0x03, + 0x73, 0x79, 0x6e, 0x12, 0x33, 0x0a, 0x04, 0x64, 0x61, 0x74, 0x61, 0x18, 0x0c, 0x20, 0x01, 0x28, + 0x0b, 0x32, 0x1d, 0x2e, 0x63, 0x31, 0x2e, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x6f, 0x72, + 0x61, 0x70, 0x69, 0x2e, 0x72, 0x74, 0x75, 0x6e, 0x2e, 0x76, 0x31, 0x2e, 0x44, 0x61, 0x74, 0x61, + 0x48, 0x00, 0x52, 0x04, 0x64, 0x61, 0x74, 0x61, 0x12, 0x30, 0x0a, 0x03, 0x66, 0x69, 0x6e, 0x18, + 0x0d, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1c, 0x2e, 0x63, 0x31, 0x2e, 0x63, 0x6f, 0x6e, 0x6e, 0x65, + 0x63, 0x74, 0x6f, 0x72, 0x61, 0x70, 0x69, 0x2e, 0x72, 0x74, 0x75, 0x6e, 0x2e, 0x76, 0x31, 0x2e, + 0x46, 0x69, 0x6e, 0x48, 0x00, 0x52, 0x03, 0x66, 0x69, 0x6e, 0x12, 0x30, 0x0a, 0x03, 0x72, 0x73, + 0x74, 0x18, 0x0e, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1c, 0x2e, 0x63, 0x31, 0x2e, 0x63, 0x6f, 0x6e, + 0x6e, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x61, 0x70, 0x69, 0x2e, 0x72, 0x74, 0x75, 0x6e, 0x2e, 0x76, + 0x31, 0x2e, 0x52, 0x73, 0x74, 0x48, 0x00, 0x52, 0x03, 0x72, 0x73, 0x74, 0x42, 0x06, 0x0a, 0x04, + 0x6b, 0x69, 0x6e, 0x64, 0x22, 0x39, 0x0a, 0x05, 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x12, 0x14, 0x0a, + 0x05, 0x70, 0x6f, 0x72, 0x74, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0d, 0x52, 0x05, 0x70, 0x6f, + 0x72, 0x74, 0x73, 0x12, 0x1a, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, + 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x22, + 0x19, 0x0a, 0x03, 0x53, 0x79, 0x6e, 0x12, 0x12, 0x0a, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x01, + 0x20, 0x01, 0x28, 0x0d, 0x52, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x22, 0x20, 0x0a, 0x04, 0x44, 0x61, + 0x74, 0x61, 0x12, 0x18, 0x0a, 0x07, 0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x0c, 0x52, 0x07, 0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x22, 0x17, 0x0a, 0x03, + 0x46, 0x69, 0x6e, 0x12, 0x10, 0x0a, 0x03, 0x61, 0x63, 0x6b, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, + 0x52, 0x03, 0x61, 0x63, 0x6b, 0x22, 0x3b, 0x0a, 0x03, 0x52, 0x73, 0x74, 0x12, 0x34, 0x0a, 0x04, + 0x63, 0x6f, 0x64, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x20, 0x2e, 0x63, 0x31, 0x2e, + 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x61, 0x70, 0x69, 0x2e, 0x72, 0x74, 0x75, + 0x6e, 0x2e, 0x76, 0x31, 0x2e, 0x52, 0x73, 0x74, 0x43, 0x6f, 0x64, 0x65, 0x52, 0x04, 0x63, 0x6f, + 0x64, 0x65, 0x2a, 0x8c, 0x01, 0x0a, 0x07, 0x52, 0x73, 0x74, 0x43, 0x6f, 0x64, 0x65, 0x12, 0x18, + 0x0a, 0x14, 0x52, 0x53, 0x54, 0x5f, 0x43, 0x4f, 0x44, 0x45, 0x5f, 0x55, 0x4e, 0x53, 0x50, 0x45, + 0x43, 0x49, 0x46, 0x49, 0x45, 0x44, 0x10, 0x00, 0x12, 0x18, 0x0a, 0x14, 0x52, 0x53, 0x54, 0x5f, + 0x43, 0x4f, 0x44, 0x45, 0x5f, 0x4e, 0x4f, 0x5f, 0x4c, 0x49, 0x53, 0x54, 0x45, 0x4e, 0x45, 0x52, + 0x10, 0x01, 0x12, 0x20, 0x0a, 0x1c, 0x52, 0x53, 0x54, 0x5f, 0x43, 0x4f, 0x44, 0x45, 0x5f, 0x50, + 0x4f, 0x52, 0x54, 0x5f, 0x4e, 0x4f, 0x54, 0x5f, 0x41, 0x44, 0x56, 0x45, 0x52, 0x54, 0x49, 0x53, + 0x45, 0x44, 0x10, 0x02, 0x12, 0x14, 0x0a, 0x10, 0x52, 0x53, 0x54, 0x5f, 0x43, 0x4f, 0x44, 0x45, + 0x5f, 0x54, 0x49, 0x4d, 0x45, 0x4f, 0x55, 0x54, 0x10, 0x03, 0x12, 0x15, 0x0a, 0x11, 0x52, 0x53, + 0x54, 0x5f, 0x43, 0x4f, 0x44, 0x45, 0x5f, 0x49, 0x4e, 0x54, 0x45, 0x52, 0x4e, 0x41, 0x4c, 0x10, + 0x04, 0x32, 0x5b, 0x0a, 0x0d, 0x52, 0x65, 0x76, 0x65, 0x72, 0x73, 0x65, 0x54, 0x75, 0x6e, 0x6e, + 0x65, 0x6c, 0x12, 0x4a, 0x0a, 0x04, 0x4c, 0x69, 0x6e, 0x6b, 0x12, 0x1e, 0x2e, 0x63, 0x31, 0x2e, + 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x61, 0x70, 0x69, 0x2e, 0x72, 0x74, 0x75, + 0x6e, 0x2e, 0x76, 0x31, 0x2e, 0x46, 0x72, 0x61, 0x6d, 0x65, 0x1a, 0x1e, 0x2e, 0x63, 0x31, 0x2e, + 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x61, 0x70, 0x69, 0x2e, 0x72, 0x74, 0x75, + 0x6e, 0x2e, 0x76, 0x31, 0x2e, 0x46, 0x72, 0x61, 0x6d, 0x65, 0x28, 0x01, 0x30, 0x01, 0x42, 0x3e, + 0x5a, 0x3c, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x63, 0x6f, 0x6e, + 0x64, 0x75, 0x63, 0x74, 0x6f, 0x72, 0x6f, 0x6e, 0x65, 0x2f, 0x62, 0x61, 0x74, 0x6f, 0x6e, 0x2d, + 0x73, 0x64, 0x6b, 0x2f, 0x70, 0x62, 0x2f, 0x63, 0x31, 0x2f, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, + 0x74, 0x6f, 0x72, 0x61, 0x70, 0x69, 0x2f, 0x72, 0x74, 0x75, 0x6e, 0x2f, 0x76, 0x31, 0x62, 0x06, + 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, +}) + +var ( + file_c1_connectorapi_rtun_v1_rtun_proto_rawDescOnce sync.Once + file_c1_connectorapi_rtun_v1_rtun_proto_rawDescData []byte +) + +func file_c1_connectorapi_rtun_v1_rtun_proto_rawDescGZIP() []byte { + file_c1_connectorapi_rtun_v1_rtun_proto_rawDescOnce.Do(func() { + file_c1_connectorapi_rtun_v1_rtun_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_c1_connectorapi_rtun_v1_rtun_proto_rawDesc), len(file_c1_connectorapi_rtun_v1_rtun_proto_rawDesc))) + }) + return file_c1_connectorapi_rtun_v1_rtun_proto_rawDescData +} + +var file_c1_connectorapi_rtun_v1_rtun_proto_enumTypes = make([]protoimpl.EnumInfo, 1) +var file_c1_connectorapi_rtun_v1_rtun_proto_msgTypes = make([]protoimpl.MessageInfo, 6) +var file_c1_connectorapi_rtun_v1_rtun_proto_goTypes = []any{ + (RstCode)(0), // 0: c1.connectorapi.rtun.v1.RstCode + (*Frame)(nil), // 1: c1.connectorapi.rtun.v1.Frame + (*Hello)(nil), // 2: c1.connectorapi.rtun.v1.Hello + (*Syn)(nil), // 3: c1.connectorapi.rtun.v1.Syn + (*Data)(nil), // 4: c1.connectorapi.rtun.v1.Data + (*Fin)(nil), // 5: c1.connectorapi.rtun.v1.Fin + (*Rst)(nil), // 6: c1.connectorapi.rtun.v1.Rst +} +var file_c1_connectorapi_rtun_v1_rtun_proto_depIdxs = []int32{ + 2, // 0: c1.connectorapi.rtun.v1.Frame.hello:type_name -> c1.connectorapi.rtun.v1.Hello + 3, // 1: c1.connectorapi.rtun.v1.Frame.syn:type_name -> c1.connectorapi.rtun.v1.Syn + 4, // 2: c1.connectorapi.rtun.v1.Frame.data:type_name -> c1.connectorapi.rtun.v1.Data + 5, // 3: c1.connectorapi.rtun.v1.Frame.fin:type_name -> c1.connectorapi.rtun.v1.Fin + 6, // 4: c1.connectorapi.rtun.v1.Frame.rst:type_name -> c1.connectorapi.rtun.v1.Rst + 0, // 5: c1.connectorapi.rtun.v1.Rst.code:type_name -> c1.connectorapi.rtun.v1.RstCode + 1, // 6: c1.connectorapi.rtun.v1.ReverseTunnel.Link:input_type -> c1.connectorapi.rtun.v1.Frame + 1, // 7: c1.connectorapi.rtun.v1.ReverseTunnel.Link:output_type -> c1.connectorapi.rtun.v1.Frame + 7, // [7:8] is the sub-list for method output_type + 6, // [6:7] is the sub-list for method input_type + 6, // [6:6] is the sub-list for extension type_name + 6, // [6:6] is the sub-list for extension extendee + 0, // [0:6] is the sub-list for field type_name +} + +func init() { file_c1_connectorapi_rtun_v1_rtun_proto_init() } +func file_c1_connectorapi_rtun_v1_rtun_proto_init() { + if File_c1_connectorapi_rtun_v1_rtun_proto != nil { + return + } + file_c1_connectorapi_rtun_v1_rtun_proto_msgTypes[0].OneofWrappers = []any{ + (*Frame_Hello)(nil), + (*Frame_Syn)(nil), + (*Frame_Data)(nil), + (*Frame_Fin)(nil), + (*Frame_Rst)(nil), + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_c1_connectorapi_rtun_v1_rtun_proto_rawDesc), len(file_c1_connectorapi_rtun_v1_rtun_proto_rawDesc)), + NumEnums: 1, + NumMessages: 6, + NumExtensions: 0, + NumServices: 1, + }, + GoTypes: file_c1_connectorapi_rtun_v1_rtun_proto_goTypes, + DependencyIndexes: file_c1_connectorapi_rtun_v1_rtun_proto_depIdxs, + EnumInfos: file_c1_connectorapi_rtun_v1_rtun_proto_enumTypes, + MessageInfos: file_c1_connectorapi_rtun_v1_rtun_proto_msgTypes, + }.Build() + File_c1_connectorapi_rtun_v1_rtun_proto = out.File + file_c1_connectorapi_rtun_v1_rtun_proto_goTypes = nil + file_c1_connectorapi_rtun_v1_rtun_proto_depIdxs = nil +} diff --git a/pb/c1/connectorapi/rtun/v1/rtun.pb.validate.go b/pb/c1/connectorapi/rtun/v1/rtun.pb.validate.go new file mode 100644 index 000000000..854f092e4 --- /dev/null +++ b/pb/c1/connectorapi/rtun/v1/rtun.pb.validate.go @@ -0,0 +1,846 @@ +// Code generated by protoc-gen-validate. DO NOT EDIT. +// source: c1/connectorapi/rtun/v1/rtun.proto + +package v1 + +import ( + "bytes" + "errors" + "fmt" + "net" + "net/mail" + "net/url" + "regexp" + "sort" + "strings" + "time" + "unicode/utf8" + + "google.golang.org/protobuf/types/known/anypb" +) + +// ensure the imports are used +var ( + _ = bytes.MinRead + _ = errors.New("") + _ = fmt.Print + _ = utf8.UTFMax + _ = (*regexp.Regexp)(nil) + _ = (*strings.Reader)(nil) + _ = net.IPv4len + _ = time.Duration(0) + _ = (*url.URL)(nil) + _ = (*mail.Address)(nil) + _ = anypb.Any{} + _ = sort.Sort +) + +// Validate checks the field values on Frame with the rules defined in the +// proto definition for this message. If any rules are violated, the first +// error encountered is returned, or nil if there are no violations. +func (m *Frame) Validate() error { + return m.validate(false) +} + +// ValidateAll checks the field values on Frame with the rules defined in the +// proto definition for this message. If any rules are violated, the result is +// a list of violation errors wrapped in FrameMultiError, or nil if none found. +func (m *Frame) ValidateAll() error { + return m.validate(true) +} + +func (m *Frame) validate(all bool) error { + if m == nil { + return nil + } + + var errors []error + + // no validation rules for Sid + + switch v := m.Kind.(type) { + case *Frame_Hello: + if v == nil { + err := FrameValidationError{ + field: "Kind", + reason: "oneof value cannot be a typed-nil", + } + if !all { + return err + } + errors = append(errors, err) + } + + if all { + switch v := interface{}(m.GetHello()).(type) { + case interface{ ValidateAll() error }: + if err := v.ValidateAll(); err != nil { + errors = append(errors, FrameValidationError{ + field: "Hello", + reason: "embedded message failed validation", + cause: err, + }) + } + case interface{ Validate() error }: + if err := v.Validate(); err != nil { + errors = append(errors, FrameValidationError{ + field: "Hello", + reason: "embedded message failed validation", + cause: err, + }) + } + } + } else if v, ok := interface{}(m.GetHello()).(interface{ Validate() error }); ok { + if err := v.Validate(); err != nil { + return FrameValidationError{ + field: "Hello", + reason: "embedded message failed validation", + cause: err, + } + } + } + + case *Frame_Syn: + if v == nil { + err := FrameValidationError{ + field: "Kind", + reason: "oneof value cannot be a typed-nil", + } + if !all { + return err + } + errors = append(errors, err) + } + + if all { + switch v := interface{}(m.GetSyn()).(type) { + case interface{ ValidateAll() error }: + if err := v.ValidateAll(); err != nil { + errors = append(errors, FrameValidationError{ + field: "Syn", + reason: "embedded message failed validation", + cause: err, + }) + } + case interface{ Validate() error }: + if err := v.Validate(); err != nil { + errors = append(errors, FrameValidationError{ + field: "Syn", + reason: "embedded message failed validation", + cause: err, + }) + } + } + } else if v, ok := interface{}(m.GetSyn()).(interface{ Validate() error }); ok { + if err := v.Validate(); err != nil { + return FrameValidationError{ + field: "Syn", + reason: "embedded message failed validation", + cause: err, + } + } + } + + case *Frame_Data: + if v == nil { + err := FrameValidationError{ + field: "Kind", + reason: "oneof value cannot be a typed-nil", + } + if !all { + return err + } + errors = append(errors, err) + } + + if all { + switch v := interface{}(m.GetData()).(type) { + case interface{ ValidateAll() error }: + if err := v.ValidateAll(); err != nil { + errors = append(errors, FrameValidationError{ + field: "Data", + reason: "embedded message failed validation", + cause: err, + }) + } + case interface{ Validate() error }: + if err := v.Validate(); err != nil { + errors = append(errors, FrameValidationError{ + field: "Data", + reason: "embedded message failed validation", + cause: err, + }) + } + } + } else if v, ok := interface{}(m.GetData()).(interface{ Validate() error }); ok { + if err := v.Validate(); err != nil { + return FrameValidationError{ + field: "Data", + reason: "embedded message failed validation", + cause: err, + } + } + } + + case *Frame_Fin: + if v == nil { + err := FrameValidationError{ + field: "Kind", + reason: "oneof value cannot be a typed-nil", + } + if !all { + return err + } + errors = append(errors, err) + } + + if all { + switch v := interface{}(m.GetFin()).(type) { + case interface{ ValidateAll() error }: + if err := v.ValidateAll(); err != nil { + errors = append(errors, FrameValidationError{ + field: "Fin", + reason: "embedded message failed validation", + cause: err, + }) + } + case interface{ Validate() error }: + if err := v.Validate(); err != nil { + errors = append(errors, FrameValidationError{ + field: "Fin", + reason: "embedded message failed validation", + cause: err, + }) + } + } + } else if v, ok := interface{}(m.GetFin()).(interface{ Validate() error }); ok { + if err := v.Validate(); err != nil { + return FrameValidationError{ + field: "Fin", + reason: "embedded message failed validation", + cause: err, + } + } + } + + case *Frame_Rst: + if v == nil { + err := FrameValidationError{ + field: "Kind", + reason: "oneof value cannot be a typed-nil", + } + if !all { + return err + } + errors = append(errors, err) + } + + if all { + switch v := interface{}(m.GetRst()).(type) { + case interface{ ValidateAll() error }: + if err := v.ValidateAll(); err != nil { + errors = append(errors, FrameValidationError{ + field: "Rst", + reason: "embedded message failed validation", + cause: err, + }) + } + case interface{ Validate() error }: + if err := v.Validate(); err != nil { + errors = append(errors, FrameValidationError{ + field: "Rst", + reason: "embedded message failed validation", + cause: err, + }) + } + } + } else if v, ok := interface{}(m.GetRst()).(interface{ Validate() error }); ok { + if err := v.Validate(); err != nil { + return FrameValidationError{ + field: "Rst", + reason: "embedded message failed validation", + cause: err, + } + } + } + + default: + _ = v // ensures v is used + } + + if len(errors) > 0 { + return FrameMultiError(errors) + } + + return nil +} + +// FrameMultiError is an error wrapping multiple validation errors returned by +// Frame.ValidateAll() if the designated constraints aren't met. +type FrameMultiError []error + +// Error returns a concatenation of all the error messages it wraps. +func (m FrameMultiError) Error() string { + msgs := make([]string, 0, len(m)) + for _, err := range m { + msgs = append(msgs, err.Error()) + } + return strings.Join(msgs, "; ") +} + +// AllErrors returns a list of validation violation errors. +func (m FrameMultiError) AllErrors() []error { return m } + +// FrameValidationError is the validation error returned by Frame.Validate if +// the designated constraints aren't met. +type FrameValidationError struct { + field string + reason string + cause error + key bool +} + +// Field function returns field value. +func (e FrameValidationError) Field() string { return e.field } + +// Reason function returns reason value. +func (e FrameValidationError) Reason() string { return e.reason } + +// Cause function returns cause value. +func (e FrameValidationError) Cause() error { return e.cause } + +// Key function returns key value. +func (e FrameValidationError) Key() bool { return e.key } + +// ErrorName returns error name. +func (e FrameValidationError) ErrorName() string { return "FrameValidationError" } + +// Error satisfies the builtin error interface +func (e FrameValidationError) Error() string { + cause := "" + if e.cause != nil { + cause = fmt.Sprintf(" | caused by: %v", e.cause) + } + + key := "" + if e.key { + key = "key for " + } + + return fmt.Sprintf( + "invalid %sFrame.%s: %s%s", + key, + e.field, + e.reason, + cause) +} + +var _ error = FrameValidationError{} + +var _ interface { + Field() string + Reason() string + Key() bool + Cause() error + ErrorName() string +} = FrameValidationError{} + +// Validate checks the field values on Hello with the rules defined in the +// proto definition for this message. If any rules are violated, the first +// error encountered is returned, or nil if there are no violations. +func (m *Hello) Validate() error { + return m.validate(false) +} + +// ValidateAll checks the field values on Hello with the rules defined in the +// proto definition for this message. If any rules are violated, the result is +// a list of violation errors wrapped in HelloMultiError, or nil if none found. +func (m *Hello) ValidateAll() error { + return m.validate(true) +} + +func (m *Hello) validate(all bool) error { + if m == nil { + return nil + } + + var errors []error + + // no validation rules for Protocol + + if len(errors) > 0 { + return HelloMultiError(errors) + } + + return nil +} + +// HelloMultiError is an error wrapping multiple validation errors returned by +// Hello.ValidateAll() if the designated constraints aren't met. +type HelloMultiError []error + +// Error returns a concatenation of all the error messages it wraps. +func (m HelloMultiError) Error() string { + msgs := make([]string, 0, len(m)) + for _, err := range m { + msgs = append(msgs, err.Error()) + } + return strings.Join(msgs, "; ") +} + +// AllErrors returns a list of validation violation errors. +func (m HelloMultiError) AllErrors() []error { return m } + +// HelloValidationError is the validation error returned by Hello.Validate if +// the designated constraints aren't met. +type HelloValidationError struct { + field string + reason string + cause error + key bool +} + +// Field function returns field value. +func (e HelloValidationError) Field() string { return e.field } + +// Reason function returns reason value. +func (e HelloValidationError) Reason() string { return e.reason } + +// Cause function returns cause value. +func (e HelloValidationError) Cause() error { return e.cause } + +// Key function returns key value. +func (e HelloValidationError) Key() bool { return e.key } + +// ErrorName returns error name. +func (e HelloValidationError) ErrorName() string { return "HelloValidationError" } + +// Error satisfies the builtin error interface +func (e HelloValidationError) Error() string { + cause := "" + if e.cause != nil { + cause = fmt.Sprintf(" | caused by: %v", e.cause) + } + + key := "" + if e.key { + key = "key for " + } + + return fmt.Sprintf( + "invalid %sHello.%s: %s%s", + key, + e.field, + e.reason, + cause) +} + +var _ error = HelloValidationError{} + +var _ interface { + Field() string + Reason() string + Key() bool + Cause() error + ErrorName() string +} = HelloValidationError{} + +// Validate checks the field values on Syn with the rules defined in the proto +// definition for this message. If any rules are violated, the first error +// encountered is returned, or nil if there are no violations. +func (m *Syn) Validate() error { + return m.validate(false) +} + +// ValidateAll checks the field values on Syn with the rules defined in the +// proto definition for this message. If any rules are violated, the result is +// a list of violation errors wrapped in SynMultiError, or nil if none found. +func (m *Syn) ValidateAll() error { + return m.validate(true) +} + +func (m *Syn) validate(all bool) error { + if m == nil { + return nil + } + + var errors []error + + // no validation rules for Port + + if len(errors) > 0 { + return SynMultiError(errors) + } + + return nil +} + +// SynMultiError is an error wrapping multiple validation errors returned by +// Syn.ValidateAll() if the designated constraints aren't met. +type SynMultiError []error + +// Error returns a concatenation of all the error messages it wraps. +func (m SynMultiError) Error() string { + msgs := make([]string, 0, len(m)) + for _, err := range m { + msgs = append(msgs, err.Error()) + } + return strings.Join(msgs, "; ") +} + +// AllErrors returns a list of validation violation errors. +func (m SynMultiError) AllErrors() []error { return m } + +// SynValidationError is the validation error returned by Syn.Validate if the +// designated constraints aren't met. +type SynValidationError struct { + field string + reason string + cause error + key bool +} + +// Field function returns field value. +func (e SynValidationError) Field() string { return e.field } + +// Reason function returns reason value. +func (e SynValidationError) Reason() string { return e.reason } + +// Cause function returns cause value. +func (e SynValidationError) Cause() error { return e.cause } + +// Key function returns key value. +func (e SynValidationError) Key() bool { return e.key } + +// ErrorName returns error name. +func (e SynValidationError) ErrorName() string { return "SynValidationError" } + +// Error satisfies the builtin error interface +func (e SynValidationError) Error() string { + cause := "" + if e.cause != nil { + cause = fmt.Sprintf(" | caused by: %v", e.cause) + } + + key := "" + if e.key { + key = "key for " + } + + return fmt.Sprintf( + "invalid %sSyn.%s: %s%s", + key, + e.field, + e.reason, + cause) +} + +var _ error = SynValidationError{} + +var _ interface { + Field() string + Reason() string + Key() bool + Cause() error + ErrorName() string +} = SynValidationError{} + +// Validate checks the field values on Data with the rules defined in the proto +// definition for this message. If any rules are violated, the first error +// encountered is returned, or nil if there are no violations. +func (m *Data) Validate() error { + return m.validate(false) +} + +// ValidateAll checks the field values on Data with the rules defined in the +// proto definition for this message. If any rules are violated, the result is +// a list of violation errors wrapped in DataMultiError, or nil if none found. +func (m *Data) ValidateAll() error { + return m.validate(true) +} + +func (m *Data) validate(all bool) error { + if m == nil { + return nil + } + + var errors []error + + // no validation rules for Payload + + if len(errors) > 0 { + return DataMultiError(errors) + } + + return nil +} + +// DataMultiError is an error wrapping multiple validation errors returned by +// Data.ValidateAll() if the designated constraints aren't met. +type DataMultiError []error + +// Error returns a concatenation of all the error messages it wraps. +func (m DataMultiError) Error() string { + msgs := make([]string, 0, len(m)) + for _, err := range m { + msgs = append(msgs, err.Error()) + } + return strings.Join(msgs, "; ") +} + +// AllErrors returns a list of validation violation errors. +func (m DataMultiError) AllErrors() []error { return m } + +// DataValidationError is the validation error returned by Data.Validate if the +// designated constraints aren't met. +type DataValidationError struct { + field string + reason string + cause error + key bool +} + +// Field function returns field value. +func (e DataValidationError) Field() string { return e.field } + +// Reason function returns reason value. +func (e DataValidationError) Reason() string { return e.reason } + +// Cause function returns cause value. +func (e DataValidationError) Cause() error { return e.cause } + +// Key function returns key value. +func (e DataValidationError) Key() bool { return e.key } + +// ErrorName returns error name. +func (e DataValidationError) ErrorName() string { return "DataValidationError" } + +// Error satisfies the builtin error interface +func (e DataValidationError) Error() string { + cause := "" + if e.cause != nil { + cause = fmt.Sprintf(" | caused by: %v", e.cause) + } + + key := "" + if e.key { + key = "key for " + } + + return fmt.Sprintf( + "invalid %sData.%s: %s%s", + key, + e.field, + e.reason, + cause) +} + +var _ error = DataValidationError{} + +var _ interface { + Field() string + Reason() string + Key() bool + Cause() error + ErrorName() string +} = DataValidationError{} + +// Validate checks the field values on Fin with the rules defined in the proto +// definition for this message. If any rules are violated, the first error +// encountered is returned, or nil if there are no violations. +func (m *Fin) Validate() error { + return m.validate(false) +} + +// ValidateAll checks the field values on Fin with the rules defined in the +// proto definition for this message. If any rules are violated, the result is +// a list of violation errors wrapped in FinMultiError, or nil if none found. +func (m *Fin) ValidateAll() error { + return m.validate(true) +} + +func (m *Fin) validate(all bool) error { + if m == nil { + return nil + } + + var errors []error + + // no validation rules for Ack + + if len(errors) > 0 { + return FinMultiError(errors) + } + + return nil +} + +// FinMultiError is an error wrapping multiple validation errors returned by +// Fin.ValidateAll() if the designated constraints aren't met. +type FinMultiError []error + +// Error returns a concatenation of all the error messages it wraps. +func (m FinMultiError) Error() string { + msgs := make([]string, 0, len(m)) + for _, err := range m { + msgs = append(msgs, err.Error()) + } + return strings.Join(msgs, "; ") +} + +// AllErrors returns a list of validation violation errors. +func (m FinMultiError) AllErrors() []error { return m } + +// FinValidationError is the validation error returned by Fin.Validate if the +// designated constraints aren't met. +type FinValidationError struct { + field string + reason string + cause error + key bool +} + +// Field function returns field value. +func (e FinValidationError) Field() string { return e.field } + +// Reason function returns reason value. +func (e FinValidationError) Reason() string { return e.reason } + +// Cause function returns cause value. +func (e FinValidationError) Cause() error { return e.cause } + +// Key function returns key value. +func (e FinValidationError) Key() bool { return e.key } + +// ErrorName returns error name. +func (e FinValidationError) ErrorName() string { return "FinValidationError" } + +// Error satisfies the builtin error interface +func (e FinValidationError) Error() string { + cause := "" + if e.cause != nil { + cause = fmt.Sprintf(" | caused by: %v", e.cause) + } + + key := "" + if e.key { + key = "key for " + } + + return fmt.Sprintf( + "invalid %sFin.%s: %s%s", + key, + e.field, + e.reason, + cause) +} + +var _ error = FinValidationError{} + +var _ interface { + Field() string + Reason() string + Key() bool + Cause() error + ErrorName() string +} = FinValidationError{} + +// Validate checks the field values on Rst with the rules defined in the proto +// definition for this message. If any rules are violated, the first error +// encountered is returned, or nil if there are no violations. +func (m *Rst) Validate() error { + return m.validate(false) +} + +// ValidateAll checks the field values on Rst with the rules defined in the +// proto definition for this message. If any rules are violated, the result is +// a list of violation errors wrapped in RstMultiError, or nil if none found. +func (m *Rst) ValidateAll() error { + return m.validate(true) +} + +func (m *Rst) validate(all bool) error { + if m == nil { + return nil + } + + var errors []error + + // no validation rules for Code + + if len(errors) > 0 { + return RstMultiError(errors) + } + + return nil +} + +// RstMultiError is an error wrapping multiple validation errors returned by +// Rst.ValidateAll() if the designated constraints aren't met. +type RstMultiError []error + +// Error returns a concatenation of all the error messages it wraps. +func (m RstMultiError) Error() string { + msgs := make([]string, 0, len(m)) + for _, err := range m { + msgs = append(msgs, err.Error()) + } + return strings.Join(msgs, "; ") +} + +// AllErrors returns a list of validation violation errors. +func (m RstMultiError) AllErrors() []error { return m } + +// RstValidationError is the validation error returned by Rst.Validate if the +// designated constraints aren't met. +type RstValidationError struct { + field string + reason string + cause error + key bool +} + +// Field function returns field value. +func (e RstValidationError) Field() string { return e.field } + +// Reason function returns reason value. +func (e RstValidationError) Reason() string { return e.reason } + +// Cause function returns cause value. +func (e RstValidationError) Cause() error { return e.cause } + +// Key function returns key value. +func (e RstValidationError) Key() bool { return e.key } + +// ErrorName returns error name. +func (e RstValidationError) ErrorName() string { return "RstValidationError" } + +// Error satisfies the builtin error interface +func (e RstValidationError) Error() string { + cause := "" + if e.cause != nil { + cause = fmt.Sprintf(" | caused by: %v", e.cause) + } + + key := "" + if e.key { + key = "key for " + } + + return fmt.Sprintf( + "invalid %sRst.%s: %s%s", + key, + e.field, + e.reason, + cause) +} + +var _ error = RstValidationError{} + +var _ interface { + Field() string + Reason() string + Key() bool + Cause() error + ErrorName() string +} = RstValidationError{} diff --git a/pb/c1/connectorapi/rtun/v1/rtun_grpc.pb.go b/pb/c1/connectorapi/rtun/v1/rtun_grpc.pb.go new file mode 100644 index 000000000..9f4ee09a5 --- /dev/null +++ b/pb/c1/connectorapi/rtun/v1/rtun_grpc.pb.go @@ -0,0 +1,113 @@ +// Code generated by protoc-gen-go-grpc. DO NOT EDIT. +// versions: +// - protoc-gen-go-grpc v1.5.1 +// - protoc (unknown) +// source: c1/connectorapi/rtun/v1/rtun.proto + +package v1 + +import ( + context "context" + grpc "google.golang.org/grpc" + codes "google.golang.org/grpc/codes" + status "google.golang.org/grpc/status" +) + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +// Requires gRPC-Go v1.64.0 or later. +const _ = grpc.SupportPackageIsVersion9 + +const ( + ReverseTunnel_Link_FullMethodName = "/c1.connectorapi.rtun.v1.ReverseTunnel/Link" +) + +// ReverseTunnelClient is the client API for ReverseTunnel service. +// +// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. +type ReverseTunnelClient interface { + Link(ctx context.Context, opts ...grpc.CallOption) (grpc.BidiStreamingClient[Frame, Frame], error) +} + +type reverseTunnelClient struct { + cc grpc.ClientConnInterface +} + +func NewReverseTunnelClient(cc grpc.ClientConnInterface) ReverseTunnelClient { + return &reverseTunnelClient{cc} +} + +func (c *reverseTunnelClient) Link(ctx context.Context, opts ...grpc.CallOption) (grpc.BidiStreamingClient[Frame, Frame], error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + stream, err := c.cc.NewStream(ctx, &ReverseTunnel_ServiceDesc.Streams[0], ReverseTunnel_Link_FullMethodName, cOpts...) + if err != nil { + return nil, err + } + x := &grpc.GenericClientStream[Frame, Frame]{ClientStream: stream} + return x, nil +} + +// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. +type ReverseTunnel_LinkClient = grpc.BidiStreamingClient[Frame, Frame] + +// ReverseTunnelServer is the server API for ReverseTunnel service. +// All implementations should embed UnimplementedReverseTunnelServer +// for forward compatibility. +type ReverseTunnelServer interface { + Link(grpc.BidiStreamingServer[Frame, Frame]) error +} + +// UnimplementedReverseTunnelServer should be embedded to have +// forward compatible implementations. +// +// NOTE: this should be embedded by value instead of pointer to avoid a nil +// pointer dereference when methods are called. +type UnimplementedReverseTunnelServer struct{} + +func (UnimplementedReverseTunnelServer) Link(grpc.BidiStreamingServer[Frame, Frame]) error { + return status.Errorf(codes.Unimplemented, "method Link not implemented") +} +func (UnimplementedReverseTunnelServer) testEmbeddedByValue() {} + +// UnsafeReverseTunnelServer may be embedded to opt out of forward compatibility for this service. +// Use of this interface is not recommended, as added methods to ReverseTunnelServer will +// result in compilation errors. +type UnsafeReverseTunnelServer interface { + mustEmbedUnimplementedReverseTunnelServer() +} + +func RegisterReverseTunnelServer(s grpc.ServiceRegistrar, srv ReverseTunnelServer) { + // If the following call pancis, it indicates UnimplementedReverseTunnelServer was + // embedded by pointer and is nil. This will cause panics if an + // unimplemented method is ever invoked, so we test this at initialization + // time to prevent it from happening at runtime later due to I/O. + if t, ok := srv.(interface{ testEmbeddedByValue() }); ok { + t.testEmbeddedByValue() + } + s.RegisterService(&ReverseTunnel_ServiceDesc, srv) +} + +func _ReverseTunnel_Link_Handler(srv interface{}, stream grpc.ServerStream) error { + return srv.(ReverseTunnelServer).Link(&grpc.GenericServerStream[Frame, Frame]{ServerStream: stream}) +} + +// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. +type ReverseTunnel_LinkServer = grpc.BidiStreamingServer[Frame, Frame] + +// ReverseTunnel_ServiceDesc is the grpc.ServiceDesc for ReverseTunnel service. +// It's only intended for direct use with grpc.RegisterService, +// and not to be introspected or modified (even as a copy) +var ReverseTunnel_ServiceDesc = grpc.ServiceDesc{ + ServiceName: "c1.connectorapi.rtun.v1.ReverseTunnel", + HandlerType: (*ReverseTunnelServer)(nil), + Methods: []grpc.MethodDesc{}, + Streams: []grpc.StreamDesc{ + { + StreamName: "Link", + Handler: _ReverseTunnel_Link_Handler, + ServerStreams: true, + ClientStreams: true, + }, + }, + Metadata: "c1/connectorapi/rtun/v1/rtun.proto", +} diff --git a/pkg/lambda/grpc/config/sts.go b/pkg/lambda/grpc/config/sts.go index d52644b76..084b2f755 100644 --- a/pkg/lambda/grpc/config/sts.go +++ b/pkg/lambda/grpc/config/sts.go @@ -5,6 +5,7 @@ import ( "context" "fmt" "net/http" + "net/url" "time" "github.com/aws/aws-sdk-go-v2/aws" @@ -21,7 +22,12 @@ func createSigv4STSGetCallerIdentityRequest(ctx context.Context, cfg *aws.Config region := cfg.Region body := "Action=GetCallerIdentity&Version=2011-06-15" service := "sts" - endpoint := fmt.Sprintf("https://sts.%s.amazonaws.com", region) + + endpoint := (&url.URL{ + Scheme: "https", + Host: fmt.Sprintf("sts.%s.amazonaws.com", region), + }).String() + method := "POST" reqHeaders := map[string][]string{ diff --git a/pkg/rtun/gateway/client.go b/pkg/rtun/gateway/client.go new file mode 100644 index 000000000..645461ea0 --- /dev/null +++ b/pkg/rtun/gateway/client.go @@ -0,0 +1,521 @@ +package gateway + +import ( + "context" + "errors" + "fmt" + "io" + "net" + "sync" + "time" + + rtunpb "github.com/conductorone/baton-sdk/pb/c1/connectorapi/rtun/v1" + "github.com/grpc-ecosystem/go-grpc-middleware/logging/zap/ctxzap" + "go.uber.org/zap" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" +) + +const ( + defaultReadBufferCap = 16 + defaultWriteQueueCap = 16 + maxChunkSize = 32 * 1024 +) + +// Dialer opens reverse connections to clients via a gateway server. +type Dialer struct { + gatewayAddr string + creds credentials.TransportCredentials + // configuration + readBufferCap int // number of frames to buffer for reads + writeQueueCap int // number of chunks to buffer for writes +} + +// DialerOption configures the Dialer. +type DialerOption func(*Dialer) + +// WithReadBufferCapacity sets the number of inbound frames to buffer before backpressure blocks producer. +func WithReadBufferCapacity(capacity int) DialerOption { + return func(d *Dialer) { + if capacity > 0 { + d.readBufferCap = capacity + } + } +} + +// WithWriteQueueCapacity sets the number of outbound chunks to queue before backpressure blocks writers. +func WithWriteQueueCapacity(capacity int) DialerOption { + return func(d *Dialer) { + if capacity > 0 { + d.writeQueueCap = capacity + } + } +} + +// NewDialerWithOptions creates a gateway client with extra configuration. +func NewDialer(gatewayAddr string, creds credentials.TransportCredentials, opts ...DialerOption) *Dialer { + d := &Dialer{ + gatewayAddr: gatewayAddr, + creds: creds, + readBufferCap: defaultReadBufferCap, + writeQueueCap: defaultWriteQueueCap, + } + + for _, opt := range opts { + if opt != nil { + opt(d) + } + } + return d +} + +// DialContext opens a reverse connection to clientID:port via the gateway. +// Returns ErrNotFound if the gateway doesn't own the client (caller should re-resolve owner). +func (d *Dialer) DialContext(ctx context.Context, clientID string, port uint32) (net.Conn, error) { + logger := ctxzap.Extract(ctx).With(zap.String("client_id", clientID), zap.Uint32("port", port)) + + // Dial gateway + cc, err := grpc.DialContext(ctx, d.gatewayAddr, + grpc.WithTransportCredentials(d.creds), + ) + if err != nil { + return nil, fmt.Errorf("gateway dial failed: %w", err) + } + + client := rtunpb.NewReverseDialerClient(cc) + // Create a cancellable stream context so Close() can interrupt Recv/Send. + streamCtx, cancel := context.WithCancel(ctx) + stream, err := client.Open(streamCtx) + if err != nil { + cancel() + cc.Close() + return nil, fmt.Errorf("gateway open stream failed: %w", err) + } + + // Send OpenRequest with gSID=1 (simple case: one connection per stream) + gsid := uint32(1) + if err := stream.Send(&rtunpb.GatewayRequest{Kind: &rtunpb.GatewayRequest_OpenReq{ + OpenReq: &rtunpb.OpenRequest{Gsid: gsid, ClientId: clientID, Port: port}, + }}); err != nil { + stream.CloseSend() + cancel() + cc.Close() + return nil, fmt.Errorf("gateway send OpenRequest failed: %w", err) + } + + // Recv OpenResponse + resp, err := stream.Recv() + if err != nil { + stream.CloseSend() + cancel() + cc.Close() + return nil, fmt.Errorf("gateway recv OpenResponse failed: %w", err) + } + + openResp := resp.GetOpenResp() + if openResp == nil { + stream.CloseSend() + cancel() + cc.Close() + return nil, ErrProtocol + } + if openResp.GetGsid() != gsid { + stream.CloseSend() + cancel() + cc.Close() + return nil, fmt.Errorf("gateway returned mismatched gSID: got %d, want %d", openResp.GetGsid(), gsid) + } + + switch openResp.Result.(type) { + case *rtunpb.OpenResponse_NotFound: + stream.CloseSend() + cancel() + cc.Close() + logger.Info("client not found on gateway") + return nil, ErrNotFound + case *rtunpb.OpenResponse_Opened: + logger.Info("gateway connection opened") + doneCh := make(chan struct{}) + gc := &gatewayConn{ + stream: stream, + cc: cc, + gsid: gsid, + cancel: cancel, + doneCh: doneCh, + } + gc.r = newReader(stream, gsid, d.readBufferCap, doneCh) + gc.w = newWriter(stream, gsid, d.writeQueueCap, doneCh) + return gc, nil + default: + stream.CloseSend() + cancel() + cc.Close() + return nil, ErrProtocol + } +} + +var _ net.Conn = (*gatewayConn)(nil) + +// gatewayConn implements net.Conn over a gateway stream. +type gatewayConn struct { + stream rtunpb.ReverseDialer_OpenClient + cc *grpc.ClientConn + gsid uint32 + cancel context.CancelFunc + + // reader/writer components + r *reader + w *writer + + writeMu sync.Mutex + writeClosed bool + + closeOnce sync.Once + doneCh chan struct{} + + rdDeadline time.Time + wrDeadline time.Time +} + +type writeMsg struct { + payload []byte + fin bool +} + +type reader struct { + stream rtunpb.ReverseDialer_OpenClient + gsid uint32 + bufCap int + ch chan []byte + doneCh <-chan struct{} + + mu sync.Mutex + rem []byte + err error + once sync.Once +} + +func newReader(stream rtunpb.ReverseDialer_OpenClient, gsid uint32, bufCap int, doneCh <-chan struct{}) *reader { + if bufCap <= 0 { + bufCap = defaultReadBufferCap + } + return &reader{ + stream: stream, + gsid: gsid, + bufCap: bufCap, + doneCh: doneCh, + } +} + +func (r *reader) start() { + r.once.Do(func() { + r.ch = make(chan []byte, r.bufCap) + go r.loop() + }) +} + +func (r *reader) loop() { + defer close(r.ch) + for { + resp, err := r.stream.Recv() + if err != nil { + r.mu.Lock() + if r.err == nil { + r.err = err + } + r.mu.Unlock() + return + } + fr := resp.GetFrame() + if fr == nil || fr.GetSid() != r.gsid { + continue + } + switch k := fr.Kind.(type) { + case *rtunpb.Frame_Data: + payload := append([]byte(nil), k.Data.GetPayload()...) + select { + case r.ch <- payload: + case <-r.doneCh: + return + case <-r.stream.Context().Done(): + return + } + case *rtunpb.Frame_Fin: + r.mu.Lock() + r.err = io.EOF + r.mu.Unlock() + return + case *rtunpb.Frame_Rst: + r.mu.Lock() + r.err = fmt.Errorf("gateway: connection reset (code %v)", k.Rst.GetCode()) + r.mu.Unlock() + return + } + } +} + +func (r *reader) next(ctx context.Context) ([]byte, error) { + r.start() + select { + case buf, ok := <-r.ch: + if !ok { + r.mu.Lock() + err := r.err + r.mu.Unlock() + if err == nil { + return nil, io.EOF + } + return nil, err + } + return buf, nil + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +type writer struct { + stream rtunpb.ReverseDialer_OpenClient + gsid uint32 + queueCap int + ch chan writeMsg + doneCh <-chan struct{} + + mu sync.Mutex + err error + once sync.Once +} + +func newWriter(stream rtunpb.ReverseDialer_OpenClient, gsid uint32, queueCap int, doneCh <-chan struct{}) *writer { + if queueCap <= 0 { + queueCap = defaultWriteQueueCap + } + return &writer{ + stream: stream, + gsid: gsid, + queueCap: queueCap, + doneCh: doneCh, + } +} + +func (w *writer) start() { + w.once.Do(func() { + w.ch = make(chan writeMsg, w.queueCap) + go w.loop() + }) +} + +func (w *writer) setErr(err error) { + w.mu.Lock() + if w.err == nil { + w.err = err + } + w.mu.Unlock() +} + +func (w *writer) getErr() error { + w.mu.Lock() + defer w.mu.Unlock() + return w.err +} + +func (w *writer) loop() { + for { + select { + case msg := <-w.ch: + if msg.fin { + _ = w.stream.Send(&rtunpb.GatewayRequest{Kind: &rtunpb.GatewayRequest_Frame{ + Frame: &rtunpb.Frame{Sid: w.gsid, Kind: &rtunpb.Frame_Fin{Fin: &rtunpb.Fin{}}}, + }}) + continue + } + if err := w.stream.Send(&rtunpb.GatewayRequest{Kind: &rtunpb.GatewayRequest_Frame{ + Frame: &rtunpb.Frame{Sid: w.gsid, Kind: &rtunpb.Frame_Data{Data: &rtunpb.Data{Payload: msg.payload}}}, + }}); err != nil { + w.setErr(err) + return + } + case <-w.doneCh: + return + case <-w.stream.Context().Done(): + return + } + } +} + +func (w *writer) enqueue(msg writeMsg, deadline time.Time) error { + w.start() + // fast-fail if previous send error + if err := w.getErr(); err != nil { + return err + } + + if !deadline.IsZero() { + until := time.Until(deadline) + if until <= 0 { + return context.DeadlineExceeded + } + timer := time.NewTimer(until) + defer timer.Stop() + select { + case w.ch <- msg: + return nil + case <-timer.C: + return context.DeadlineExceeded + case <-w.doneCh: + return fmt.Errorf("rtun/gateway: write after close: %w", net.ErrClosed) + case <-w.stream.Context().Done(): + return w.stream.Context().Err() + } + } + select { + case w.ch <- msg: + return nil + case <-w.doneCh: + return fmt.Errorf("rtun/gateway: write after close: %w", net.ErrClosed) + case <-w.stream.Context().Done(): + return w.stream.Context().Err() + } +} + +func (g *gatewayConn) Read(p []byte) (int, error) { + // Consume remainder first + g.r.mu.Lock() + if len(g.r.rem) > 0 { + n := copy(p, g.r.rem) + g.r.rem = g.r.rem[n:] + g.r.mu.Unlock() + return n, nil + } + g.r.mu.Unlock() + + // Compute deadline context + var ctx context.Context + var cancel context.CancelFunc + if g.rdDeadline.IsZero() { + ctx = context.Background() + cancel = func() {} + } else { + until := time.Until(g.rdDeadline) + if until <= 0 { + return 0, context.DeadlineExceeded + } + ctx, cancel = context.WithTimeout(context.Background(), until) + } + defer cancel() + + buf, err := g.r.next(ctx) + if err != nil { + if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) || errors.Is(err, io.EOF) { + return 0, err + } + if g.isClosed() { + return 0, fmt.Errorf("rtun/gateway: read on closed connection: %w", net.ErrClosed) + } + return 0, err + } + n := copy(p, buf) + if n < len(buf) { + g.r.mu.Lock() + g.r.rem = buf[n:] + g.r.mu.Unlock() + } + return n, nil +} + +func (g *gatewayConn) isClosed() bool { + select { + case <-g.doneCh: + return true + default: + return false + } +} + +// recvLoop moved into reader.loop + +func (g *gatewayConn) Write(p []byte) (int, error) { + g.writeMu.Lock() + if g.writeClosed { + g.writeMu.Unlock() + return 0, fmt.Errorf("rtun/gateway: write on closed connection: %w", net.ErrClosed) + } + g.writeMu.Unlock() + + if err := g.w.getErr(); err != nil { + return 0, err + } + + total := 0 + for len(p) > 0 { + chunk := p + if len(chunk) > maxChunkSize { + chunk = p[:maxChunkSize] + } + cp := append([]byte(nil), chunk...) + if err := g.w.enqueue(writeMsg{payload: cp}, g.wrDeadline); err != nil { + if total == 0 { + return 0, err + } + return total, err + } + total += len(chunk) + p = p[len(chunk):] + } + return total, nil +} + +// writer loop moved into writer.loop + +func (g *gatewayConn) Close() error { + g.closeOnce.Do(func() { + g.writeMu.Lock() + if !g.writeClosed { + g.writeClosed = true + // best-effort FIN without blocking; if writer not started yet, send directly + if g.w != nil && g.w.ch != nil { + select { + case g.w.ch <- writeMsg{fin: true}: + default: + } + } else { + _ = g.stream.Send(&rtunpb.GatewayRequest{Kind: &rtunpb.GatewayRequest_Frame{ + Frame: &rtunpb.Frame{Sid: g.gsid, Kind: &rtunpb.Frame_Fin{Fin: &rtunpb.Fin{}}}, + }}) + } + } + g.writeMu.Unlock() + + // Cancel stream context to unblock Recv/Send + if g.cancel != nil { + g.cancel() + } + close(g.doneCh) + _ = g.stream.CloseSend() + _ = g.cc.Close() + }) + return nil +} + +func (g *gatewayConn) LocalAddr() net.Addr { return gatewayAddr{"gateway-local"} } +func (g *gatewayConn) RemoteAddr() net.Addr { return gatewayAddr{"gateway-remote"} } + +func (g *gatewayConn) SetDeadline(t time.Time) error { + g.rdDeadline = t + g.wrDeadline = t + return nil +} + +func (g *gatewayConn) SetReadDeadline(t time.Time) error { + g.rdDeadline = t + return nil +} + +func (g *gatewayConn) SetWriteDeadline(t time.Time) error { + g.wrDeadline = t + return nil +} + +type gatewayAddr struct{ s string } + +func (a gatewayAddr) Network() string { return "gateway" } +func (a gatewayAddr) String() string { return a.s } diff --git a/pkg/rtun/gateway/client_conn_test.go b/pkg/rtun/gateway/client_conn_test.go new file mode 100644 index 000000000..2cc424efc --- /dev/null +++ b/pkg/rtun/gateway/client_conn_test.go @@ -0,0 +1,170 @@ +package gateway + +import ( + "context" + "errors" + "io" + "net" + "testing" + "time" + + rtunpb "github.com/conductorone/baton-sdk/pb/c1/connectorapi/rtun/v1" + "github.com/conductorone/baton-sdk/pkg/rtun/server" + "github.com/conductorone/baton-sdk/pkg/rtun/transport" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + health "google.golang.org/grpc/health" + healthpb "google.golang.org/grpc/health/grpc_health_v1" +) + +type gwEnv struct { + addr string + cleanup func() + accepted chan net.Conn +} + +// setupGateway spins up a real ReverseTunnel + ReverseDialer server and returns control. +// If silent is true, the client listener will accept a raw conn and never write to it unless tests do. +func setupGateway(t *testing.T, silent bool) *gwEnv { + t.Helper() + reg := server.NewRegistry() + handler := server.NewHandler(reg, "server-a", testValidator{id: "client-xyz"}) + gw := NewServer(reg, "server-a", nil) + + gsrv := grpc.NewServer() + rtunpb.RegisterReverseTunnelServer(gsrv, handler) + rtunpb.RegisterReverseDialerServer(gsrv, gw) + l, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + go func() { _ = gsrv.Serve(l) }() + + // Bring up a client link and listen on port 1 + cc, err := grpc.Dial(l.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err) + rtunClient := rtunpb.NewReverseTunnelClient(cc) + stream, err := rtunClient.Link(context.Background()) + require.NoError(t, err) + cl := &clientLink{cli: stream} + sess := transport.NewSession(cl) + require.NoError(t, cl.Send(&rtunpb.Frame{Sid: 0, Kind: &rtunpb.Frame_Hello{Hello: &rtunpb.Hello{Ports: []uint32{1}}}})) + ln, err := sess.Listen(context.Background(), 1) + require.NoError(t, err) + + accepted := make(chan net.Conn, 1) + var cgs *grpc.Server + if silent { + // Accept one conn and expose it to tests + go func() { c, _ := ln.Accept(); accepted <- c }() + } else { + // Serve a health server + cgs = grpc.NewServer() + hs := health.NewServer() + healthpb.RegisterHealthServer(cgs, hs) + hs.SetServingStatus("", healthpb.HealthCheckResponse_SERVING) + go func() { _ = cgs.Serve(ln) }() + } + + cleanup := func() { + if cgs != nil { + cgs.GracefulStop() + } + _ = ln.Close() + _ = cc.Close() + gsrv.GracefulStop() + _ = l.Close() + } + return &gwEnv{addr: l.Addr().String(), cleanup: cleanup, accepted: accepted} +} + +func TestGatewayConnReadDeadline(t *testing.T) { + env := setupGateway(t, true) + defer env.cleanup() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + d := NewDialer(env.addr, insecure.NewCredentials()) + gwc, err := d.DialContext(ctx, "client-xyz", 1) + require.NoError(t, err) + defer gwc.Close() + + _ = gwc.SetReadDeadline(time.Now().Add(10 * time.Millisecond)) + _, err = gwc.Read(make([]byte, 1)) + require.ErrorIs(t, err, context.DeadlineExceeded) +} + +func TestGatewayConnCloseIdempotentAndWriteAfterClose(t *testing.T) { + env := setupGateway(t, true) + defer env.cleanup() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + d := NewDialer(env.addr, insecure.NewCredentials()) + gwc, err := d.DialContext(ctx, "client-xyz", 1) + require.NoError(t, err) + + require.NoError(t, gwc.Close()) + require.NoError(t, gwc.Close()) + _, err = gwc.Write([]byte("x")) + require.True(t, errors.Is(err, net.ErrClosed)) +} + +func TestGatewayConnEOFOnRemoteClose(t *testing.T) { + env := setupGateway(t, true) + defer env.cleanup() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + d := NewDialer(env.addr, insecure.NewCredentials()) + gwc, err := d.DialContext(ctx, "client-xyz", 1) + require.NoError(t, err) + defer gwc.Close() + + // Remote close: wait for accept and then close the accepted conn to generate FIN. + select { + case rc := <-env.accepted: + _ = rc.Close() + case <-time.After(200 * time.Millisecond): + t.Fatal("no accepted conn") + } + + // Read should return EOF + _, err = gwc.Read(make([]byte, 1)) + require.ErrorIs(t, err, io.EOF) +} + +func TestGatewayConnWriteAndRemoteReceive(t *testing.T) { + env := setupGateway(t, true) + defer env.cleanup() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + d := NewDialer(env.addr, insecure.NewCredentials()) + gwc, err := d.DialContext(ctx, "client-xyz", 1) + require.NoError(t, err) + defer gwc.Close() + + // Wait for remote accept + var rc net.Conn + select { + case rc = <-env.accepted: + case <-time.After(200 * time.Millisecond): + t.Fatal("no accepted conn") + } + + // Write and verify bytes arrive remotely + msg := []byte("hello") + n, err := gwc.Write(msg) + require.NoError(t, err) + require.Equal(t, len(msg), n) + buf := make([]byte, len(msg)) + _, err = rc.Read(buf) + require.NoError(t, err) + require.Equal(t, msg, buf) + + // Local FIN: close gwc; remote should see EOF + _ = gwc.Close() + tmp := make([]byte, 1) + _, err = rc.Read(tmp) + require.ErrorIs(t, err, io.EOF) +} diff --git a/pkg/rtun/gateway/errors.go b/pkg/rtun/gateway/errors.go new file mode 100644 index 000000000..3deaf8d35 --- /dev/null +++ b/pkg/rtun/gateway/errors.go @@ -0,0 +1,9 @@ +package gateway + +import "errors" + +var ( + ErrNotFound = errors.New("rtun/gateway: client not found on this server") + ErrInvalidGSID = errors.New("rtun/gateway: invalid or duplicate gSID") + ErrProtocol = errors.New("rtun/gateway: protocol error") +) diff --git a/pkg/rtun/gateway/grpc_options.go b/pkg/rtun/gateway/grpc_options.go new file mode 100644 index 000000000..706577875 --- /dev/null +++ b/pkg/rtun/gateway/grpc_options.go @@ -0,0 +1,25 @@ +package gateway + +import ( + "time" + + "google.golang.org/grpc" + "google.golang.org/grpc/keepalive" +) + +// RecommendedGRPCServerOptions returns server options enabling basic keepalive and +// reasonable message size limits suitable for the gateway service. +func RecommendedGRPCServerOptions() []grpc.ServerOption { + return []grpc.ServerOption{ + grpc.KeepaliveParams(keepalive.ServerParameters{ + Time: 30 * time.Second, + Timeout: 10 * time.Second, + MaxConnectionIdle: 0, + MaxConnectionAge: 0, + }), + grpc.KeepaliveEnforcementPolicy(keepalive.EnforcementPolicy{ + MinTime: 10 * time.Second, + PermitWithoutStream: true, + }), + } +} diff --git a/pkg/rtun/gateway/integration_test.go b/pkg/rtun/gateway/integration_test.go new file mode 100644 index 000000000..63f66a163 --- /dev/null +++ b/pkg/rtun/gateway/integration_test.go @@ -0,0 +1,135 @@ +package gateway + +import ( + "context" + "net" + "testing" + "time" + + rtunpb "github.com/conductorone/baton-sdk/pb/c1/connectorapi/rtun/v1" + "github.com/conductorone/baton-sdk/pkg/rtun/server" + "github.com/conductorone/baton-sdk/pkg/rtun/transport" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + health "google.golang.org/grpc/health" + healthpb "google.golang.org/grpc/health/grpc_health_v1" +) + +// testValidator for integration tests. +type testValidator struct{ id string } + +func (t testValidator) ValidateAuth(ctx context.Context) (string, error) { return t.id, nil } +func (t testValidator) ValidateHello(ctx context.Context, hello *rtunpb.Hello) error { return nil } + +// clientLink adapts client bidi stream to transport.Link. +type clientLink struct { + cli rtunpb.ReverseTunnel_LinkClient +} + +func (c *clientLink) Send(f *rtunpb.Frame) error { return c.cli.Send(f) } +func (c *clientLink) Recv() (*rtunpb.Frame, error) { return c.cli.Recv() } +func (c *clientLink) Context() context.Context { return c.cli.Context() } + +// TestGatewayE2E validates the full gateway stack: +// - Client connects to server A (handler+registry). +// - Gateway server on A. +// - Remote caller uses gateway.Dialer to get net.Conn to client. +// - Caller performs gRPC health check over gateway conn. +func TestGatewayE2E(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + // Server A: rtun handler + registry + gateway + regA := server.NewRegistry() + handlerA := server.NewHandler(regA, "server-a", testValidator{id: "client-123"}) + gwA := NewServer(regA, "server-a", nil) + + gsrvA := grpc.NewServer() + rtunpb.RegisterReverseTunnelServer(gsrvA, handlerA) + rtunpb.RegisterReverseDialerServer(gsrvA, gwA) + + lA, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + go func() { _ = gsrvA.Serve(lA) }() + + // Client connects to server A + clientCtx, clientCancel := context.WithCancel(ctx) + defer clientCancel() + + ccA, err := grpc.Dial(lA.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err) + + rtunClient := rtunpb.NewReverseTunnelClient(ccA) + stream, err := rtunClient.Link(clientCtx) + require.NoError(t, err) + + // Client session: send HELLO, listen on port 1 + cl := &clientLink{cli: stream} + sess := transport.NewSession(cl) + require.NoError(t, cl.Send(&rtunpb.Frame{Sid: 0, Kind: &rtunpb.Frame_Hello{Hello: &rtunpb.Hello{Ports: []uint32{1}}}})) + ln, err := sess.Listen(ctx, 1) + require.NoError(t, err) + + // Client runs health service + cgs := grpc.NewServer() + hs := health.NewServer() + healthpb.RegisterHealthServer(cgs, hs) + hs.SetServingStatus("", healthpb.HealthCheckResponse_SERVING) + go func() { _ = cgs.Serve(ln) }() + + // Wait for registration + time.Sleep(50 * time.Millisecond) + + // Remote caller: use gateway.Dialer to get net.Conn to client + dialer := NewDialer(lA.Addr().String(), insecure.NewCredentials(), nil) + gwConn, err := dialer.DialContext(ctx, "client-123", 1) + require.NoError(t, err) + defer gwConn.Close() + + // Wrap gateway conn in grpc.Dial and perform health check + callerCC, err := grpc.DialContext(ctx, "ignored", + grpc.WithContextDialer(func(context.Context, string) (net.Conn, error) { return gwConn, nil }), + grpc.WithTransportCredentials(insecure.NewCredentials()), + ) + require.NoError(t, err) + defer callerCC.Close() + + hc := healthpb.NewHealthClient(callerCC) + resp, err := hc.Check(ctx, &healthpb.HealthCheckRequest{Service: ""}) + require.NoError(t, err) + require.Equal(t, healthpb.HealthCheckResponse_SERVING, resp.GetStatus()) + + // Cleanup + callerCC.Close() + gwConn.Close() + cgs.GracefulStop() + ln.Close() + clientCancel() + ccA.Close() + gsrvA.GracefulStop() + lA.Close() +} + +func TestGatewayNotFound(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + // Server with gateway but no client registered + regA := server.NewRegistry() + gwA := NewServer(regA, "server-a", nil) + + gsrvA := grpc.NewServer() + rtunpb.RegisterReverseDialerServer(gsrvA, gwA) + + lA, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer lA.Close() + go func() { _ = gsrvA.Serve(lA) }() + defer gsrvA.GracefulStop() + + // Caller tries to dial non-existent client + dialer := NewDialer(lA.Addr().String(), insecure.NewCredentials(), nil) + _, err = dialer.DialContext(ctx, "client-nonexistent", 1) + require.ErrorIs(t, err, ErrNotFound) +} diff --git a/pkg/rtun/gateway/metrics.go b/pkg/rtun/gateway/metrics.go new file mode 100644 index 000000000..112623bf9 --- /dev/null +++ b/pkg/rtun/gateway/metrics.go @@ -0,0 +1,44 @@ +package gateway + +import ( + "context" + + sdkmetrics "github.com/conductorone/baton-sdk/pkg/metrics" +) + +type gwMetrics struct { + h sdkmetrics.Handler + + openReq sdkmetrics.Int64Counter + openOK sdkmetrics.Int64Counter + openMiss sdkmetrics.Int64Counter + framesRx sdkmetrics.Int64Counter + framesTx sdkmetrics.Int64Counter + writerDrop sdkmetrics.Int64Counter + writeErr sdkmetrics.Int64Counter +} + +func newGwMetrics(h sdkmetrics.Handler) *gwMetrics { + return &gwMetrics{ + h: h, + openReq: h.Int64Counter("rtun.gateway.open_requests_total", "gateway open requests", sdkmetrics.Dimensionless), + openOK: h.Int64Counter("rtun.gateway.open_success_total", "gateway opens succeeded", sdkmetrics.Dimensionless), + openMiss: h.Int64Counter("rtun.gateway.open_not_found_total", "gateway opens not found", sdkmetrics.Dimensionless), + framesRx: h.Int64Counter("rtun.gateway.frame_rx_total", "gateway frames received", sdkmetrics.Dimensionless), + framesTx: h.Int64Counter("rtun.gateway.frame_tx_total", "gateway frames sent", sdkmetrics.Dimensionless), + writerDrop: h.Int64Counter("rtun.gateway.writer_queue_drops_total", "gateway writer queue drops", sdkmetrics.Dimensionless), + writeErr: h.Int64Counter("rtun.gateway.writer_write_errors_total", "gateway write errors", sdkmetrics.Dimensionless), + } +} + +func (m *gwMetrics) addOpenReq(ctx context.Context) { m.openReq.Add(ctx, 1, nil) } +func (m *gwMetrics) addOpenOK(ctx context.Context) { m.openOK.Add(ctx, 1, nil) } +func (m *gwMetrics) addOpenMiss(ctx context.Context) { m.openMiss.Add(ctx, 1, nil) } +func (m *gwMetrics) addFrameRx(ctx context.Context, k string) { + m.framesRx.Add(ctx, 1, map[string]string{"kind": k}) +} +func (m *gwMetrics) addFrameTx(ctx context.Context, k string) { + m.framesTx.Add(ctx, 1, map[string]string{"kind": k}) +} +func (m *gwMetrics) addWriterDrop(ctx context.Context) { m.writerDrop.Add(ctx, 1, nil) } +func (m *gwMetrics) addWriteErr(ctx context.Context) { m.writeErr.Add(ctx, 1, nil) } diff --git a/pkg/rtun/gateway/options.go b/pkg/rtun/gateway/options.go new file mode 100644 index 000000000..1caf04456 --- /dev/null +++ b/pkg/rtun/gateway/options.go @@ -0,0 +1,16 @@ +package gateway + +import ( + sdkmetrics "github.com/conductorone/baton-sdk/pkg/metrics" +) + +type Option func(*options) + +type options struct { + metrics sdkmetrics.Handler +} + +// WithMetricsHandler injects a metrics handler for the gateway service. +func WithMetricsHandler(h sdkmetrics.Handler) Option { + return func(o *options) { o.metrics = h } +} diff --git a/pkg/rtun/gateway/server.go b/pkg/rtun/gateway/server.go new file mode 100644 index 000000000..556a20b41 --- /dev/null +++ b/pkg/rtun/gateway/server.go @@ -0,0 +1,302 @@ +package gateway + +import ( + "context" + "io" + "net" + "net/url" + "strconv" + "sync" + "time" + + rtunpb "github.com/conductorone/baton-sdk/pb/c1/connectorapi/rtun/v1" + "github.com/conductorone/baton-sdk/pkg/rtun/server" + "github.com/grpc-ecosystem/go-grpc-middleware/logging/zap/ctxzap" + "go.uber.org/zap" +) + +// Server implements the ReverseDialer gateway service. +// It bridges caller streams to rtun sessions on the owner server process. +type Server struct { + rtunpb.UnimplementedReverseDialerServer + + reg *server.Registry + serverID string + m *gwMetrics +} + +// NewServer creates a gateway server that bridges callers to local rtun sessions. +func NewServer(reg *server.Registry, serverID string, opts ...Option) *Server { + var o options + for _, opt := range opts { + if opt == nil { + continue + } + opt(&o) + } + s := &Server{ + reg: reg, + serverID: serverID, + } + if o.metrics != nil { + s.m = newGwMetrics(o.metrics) + } + return s +} + +// Open handles a gateway stream: caller sends OpenRequest(s) and Frames; gateway bridges to rtun. +const ( + writerQueueCap = 256 + writeDeadline = 30 * time.Second +) + +type entry struct { + conn net.Conn + writeCh chan []byte + done chan struct{} + closeOnce sync.Once + m *gwMetrics + ctx context.Context +} + +func newEntry(conn net.Conn, m *gwMetrics, ctx context.Context) *entry { + e := &entry{ + conn: conn, + writeCh: make(chan []byte, writerQueueCap), + done: make(chan struct{}), + m: m, + ctx: ctx, + } + go e.writerLoop() + return e +} + +func (e *entry) writerLoop() { + for { + select { + case <-e.done: + return + case buf, ok := <-e.writeCh: + if !ok { + return + } + _ = e.conn.SetWriteDeadline(time.Now().Add(writeDeadline)) + if _, err := e.conn.Write(buf); err != nil { + if e.m != nil { + e.m.addWriteErr(e.ctx) + } + e.close() + return + } + } + } +} + +func (e *entry) send(b []byte) bool { + cp := append([]byte(nil), b...) + select { + case <-e.done: + return false + case e.writeCh <- cp: + return true + default: + e.close() + return false + } +} + +func (e *entry) close() { + e.closeOnce.Do(func() { + _ = e.conn.Close() + close(e.done) + close(e.writeCh) + }) +} + +func (s *Server) Open(stream rtunpb.ReverseDialer_OpenServer) error { + ctx := stream.Context() + logger := ctxzap.Extract(ctx).With(zap.String("server_id", s.serverID)) + + var mu sync.Mutex + entries := make(map[uint32]*entry) + var wg sync.WaitGroup + + defer func() { + // Cleanup: close all connections and wait for readers + mu.Lock() + for gsid, ent := range entries { + ent.close() + delete(entries, gsid) + } + mu.Unlock() + wg.Wait() + }() + + for { + req, err := stream.Recv() + if err != nil { + if err == io.EOF { + return nil + } + return err + } + + switch k := req.Kind.(type) { + case *rtunpb.GatewayRequest_OpenReq: + openReq := k.OpenReq + gsid := openReq.GetGsid() + clientID := openReq.GetClientId() + port := openReq.GetPort() + + logger := logger.With(zap.Uint32("gsid", gsid), zap.String("client_id", clientID), zap.Uint32("port", port)) + if s.m != nil { + s.m.addOpenReq(ctx) + s.m.addFrameRx(ctx, "OPEN_REQ") + } + + // Check for duplicate gSID + mu.Lock() + if _, exists := entries[gsid]; exists { + mu.Unlock() + logger.Warn("duplicate gSID in OpenRequest") + _ = stream.Send(&rtunpb.GatewayResponse{Kind: &rtunpb.GatewayResponse_Frame{ + Frame: &rtunpb.Frame{Sid: gsid, Kind: &rtunpb.Frame_Rst{Rst: &rtunpb.Rst{Code: rtunpb.RstCode_RST_CODE_INTERNAL}}}, + }}) + continue + } + mu.Unlock() + + // Open reverse connection via local registry + addr := formatRtunAddr(clientID, port) + conn, err := s.reg.DialContext(ctx, addr) + if err != nil { + logger.Info("client not found or dial failed", zap.Error(err)) + _ = stream.Send(&rtunpb.GatewayResponse{Kind: &rtunpb.GatewayResponse_OpenResp{ + OpenResp: &rtunpb.OpenResponse{Gsid: gsid, Result: &rtunpb.OpenResponse_NotFound{NotFound: &rtunpb.NotFound{}}}, + }}) + if s.m != nil { + s.m.addOpenMiss(ctx) + } + continue + } + + // Store conn and reply success + mu.Lock() + ent := newEntry(conn, s.m, ctx) + entries[gsid] = ent + mu.Unlock() + + logger.Info("opened reverse connection") + if err := stream.Send(&rtunpb.GatewayResponse{Kind: &rtunpb.GatewayResponse_OpenResp{ + OpenResp: &rtunpb.OpenResponse{Gsid: gsid, Result: &rtunpb.OpenResponse_Opened{Opened: &rtunpb.Opened{}}}, + }}); err != nil { + conn.Close() + return err + } + if s.m != nil { + s.m.addOpenOK(ctx) + } + + // Spawn reader: conn → caller + wg.Add(1) + go s.bridgeRead(ctx, conn, gsid, stream, &wg, logger) + + case *rtunpb.GatewayRequest_Frame: + fr := k.Frame + gsid := fr.GetSid() + + mu.Lock() + ent := entries[gsid] + mu.Unlock() + + if ent == nil { + // Unknown gSID; send RST + _ = stream.Send(&rtunpb.GatewayResponse{Kind: &rtunpb.GatewayResponse_Frame{ + Frame: &rtunpb.Frame{Sid: gsid, Kind: &rtunpb.Frame_Rst{Rst: &rtunpb.Rst{Code: rtunpb.RstCode_RST_CODE_INTERNAL}}}, + }}) + continue + } + + // Handle frame + switch fk := fr.Kind.(type) { + case *rtunpb.Frame_Data: + if s.m != nil { + s.m.addFrameRx(ctx, "DATA") + } + if ok := ent.send(fk.Data.GetPayload()); !ok { + if s.m != nil { + s.m.addWriterDrop(ctx) + } + mu.Lock() + delete(entries, gsid) + mu.Unlock() + } + case *rtunpb.Frame_Fin: + if s.m != nil { + s.m.addFrameRx(ctx, "FIN") + } + mu.Lock() + ent.close() + delete(entries, gsid) + mu.Unlock() + case *rtunpb.Frame_Rst: + if s.m != nil { + s.m.addFrameRx(ctx, "RST") + } + mu.Lock() + ent.close() + delete(entries, gsid) + mu.Unlock() + } + } + } +} + +// bridgeRead reads from rtun conn and sends frames to the caller stream. +func (s *Server) bridgeRead(ctx context.Context, conn net.Conn, gsid uint32, stream rtunpb.ReverseDialer_OpenServer, wg *sync.WaitGroup, logger *zap.Logger) { + defer wg.Done() + buf := make([]byte, 32*1024) + for { + n, err := conn.Read(buf) + if n > 0 { + if err := stream.Send(&rtunpb.GatewayResponse{Kind: &rtunpb.GatewayResponse_Frame{ + Frame: &rtunpb.Frame{Sid: gsid, Kind: &rtunpb.Frame_Data{Data: &rtunpb.Data{Payload: append([]byte(nil), buf[:n]...)}}}, + }}); err != nil { + logger.Warn("failed to send data to caller", zap.Error(err)) + return + } + if s.m != nil { + s.m.addFrameTx(ctx, "DATA") + } + } + if err != nil { + if err == io.EOF { + _ = stream.Send(&rtunpb.GatewayResponse{Kind: &rtunpb.GatewayResponse_Frame{ + Frame: &rtunpb.Frame{Sid: gsid, Kind: &rtunpb.Frame_Fin{Fin: &rtunpb.Fin{}}}, + }}) + if s.m != nil { + s.m.addFrameTx(ctx, "FIN") + } + } else { + _ = stream.Send(&rtunpb.GatewayResponse{Kind: &rtunpb.GatewayResponse_Frame{ + Frame: &rtunpb.Frame{Sid: gsid, Kind: &rtunpb.Frame_Rst{Rst: &rtunpb.Rst{Code: rtunpb.RstCode_RST_CODE_INTERNAL}}}, + }}) + if s.m != nil { + s.m.addFrameTx(ctx, "RST") + } + } + return + } + } +} + +func formatRtunAddr(clientID string, port uint32) string { + u := url.URL{ + Scheme: "rtun", + Host: net.JoinHostPort( + clientID, + strconv.Itoa(int(port)), + ), + } + return u.String() +} diff --git a/pkg/rtun/match/directory.go b/pkg/rtun/match/directory.go new file mode 100644 index 000000000..dc4197583 --- /dev/null +++ b/pkg/rtun/match/directory.go @@ -0,0 +1,15 @@ +package match + +import ( + "context" + "errors" + "time" +) + +var ErrNotImplemented = errors.New("rtun/match: not implemented") + +type Directory interface { + Advertise(ctx context.Context, serverID string, addr string, ttl time.Duration) error + Revoke(ctx context.Context, serverID string) error + Resolve(ctx context.Context, serverID string) (addr string, err error) +} diff --git a/pkg/rtun/match/errors.go b/pkg/rtun/match/errors.go new file mode 100644 index 000000000..d762a9278 --- /dev/null +++ b/pkg/rtun/match/errors.go @@ -0,0 +1,7 @@ +package match + +import "errors" + +var ( + ErrClientOffline = errors.New("rtun/match: client offline") +) diff --git a/pkg/rtun/match/locator.go b/pkg/rtun/match/locator.go new file mode 100644 index 000000000..864b7abbf --- /dev/null +++ b/pkg/rtun/match/locator.go @@ -0,0 +1,48 @@ +package match + +import ( + "context" + "hash/fnv" +) + +type Locator struct { + Presence Presence +} + +func (l *Locator) OwnerOf(ctx context.Context, clientID string) (serverID string, ports []uint32, err error) { + if l == nil || l.Presence == nil { + return "", nil, ErrNotImplemented + } + servers, err := l.Presence.Locations(ctx, clientID) + if err != nil { + return "", nil, err + } + if len(servers) == 0 { + return "", nil, ErrClientOffline + } + owner := rendezvousChoose(clientID, servers) + ports, err = l.Presence.Ports(ctx, clientID) + if err != nil { + return "", nil, err + } + return owner, ports, nil +} + +func rendezvousChoose(clientID string, servers []string) string { + var best string + var bestVal uint64 + var have bool + for _, s := range servers { + h := fnv.New64a() + _, _ = h.Write([]byte(clientID)) + _, _ = h.Write([]byte("|")) + _, _ = h.Write([]byte(s)) + v := h.Sum64() + if !have || v > bestVal { + bestVal = v + best = s + have = true + } + } + return best +} diff --git a/pkg/rtun/match/locator_test.go b/pkg/rtun/match/locator_test.go new file mode 100644 index 000000000..4032be2f6 --- /dev/null +++ b/pkg/rtun/match/locator_test.go @@ -0,0 +1,84 @@ +package match + +import ( + "context" + "testing" + "time" + + "github.com/conductorone/baton-sdk/pkg/rtun/match/memory" + "github.com/stretchr/testify/require" +) + +func TestLocatorOwnerOfSingleServer(t *testing.T) { + p := memory.NewPresence() + ctx := context.Background() + + err := p.SetPorts(ctx, "client-1", []uint32{1}) + require.NoError(t, err) + err = p.Announce(ctx, "client-1", "server-a", 10*time.Second) + require.NoError(t, err) + + loc := &Locator{Presence: p} + owner, ports, err := loc.OwnerOf(ctx, "client-1") + require.NoError(t, err) + require.Equal(t, "server-a", owner) + require.Equal(t, []uint32{1}, ports) +} + +func TestLocatorOwnerOfMultipleServersDeterministic(t *testing.T) { + p := memory.NewPresence() + ctx := context.Background() + + // Register two servers + _ = p.SetPorts(ctx, "client-1", []uint32{1}) + _ = p.Announce(ctx, "client-1", "server-a", 10*time.Second) + _ = p.Announce(ctx, "client-1", "server-b", 10*time.Second) + + loc := &Locator{Presence: p} + owner1, _, err := loc.OwnerOf(ctx, "client-1") + require.NoError(t, err) + + // Call again; should be same owner (deterministic rendezvous hashing) + owner2, _, err := loc.OwnerOf(ctx, "client-1") + require.NoError(t, err) + require.Equal(t, owner1, owner2) + + // owner should be one of the two servers + require.Contains(t, []string{"server-a", "server-b"}, owner1) +} + +func TestLocatorOwnerOfClientOffline(t *testing.T) { + p := memory.NewPresence() + ctx := context.Background() + + loc := &Locator{Presence: p} + _, _, err := loc.OwnerOf(ctx, "client-nonexistent") + require.ErrorIs(t, err, ErrClientOffline) +} + +func TestLocatorOwnerOfWithCustomChooser(t *testing.T) { + // Removed: locator no longer accepts a custom chooser; keep surface minimal. +} + +func TestLocatorOwnerOfTTLExpiry(t *testing.T) { + p := memory.NewPresence() + ctx := context.Background() + + // Set with very short TTL + err := p.SetPorts(ctx, "client-1", []uint32{1}) + require.NoError(t, err) + err = p.Announce(ctx, "client-1", "server-a", 10*time.Millisecond) + require.NoError(t, err) + + loc := &Locator{Presence: p} + owner, _, err := loc.OwnerOf(ctx, "client-1") + require.NoError(t, err) + require.Equal(t, "server-a", owner) + + // Wait for expiry + time.Sleep(20 * time.Millisecond) + + // Now should return ErrClientOffline + _, _, err = loc.OwnerOf(ctx, "client-1") + require.ErrorIs(t, err, ErrClientOffline) +} diff --git a/pkg/rtun/match/memory/directory.go b/pkg/rtun/match/memory/directory.go new file mode 100644 index 000000000..a95fe4d8f --- /dev/null +++ b/pkg/rtun/match/memory/directory.go @@ -0,0 +1,56 @@ +package memory + +import ( + "context" + "errors" + "sync" + "time" +) + +var ErrServerNotFound = errors.New("rtun/match: server not found") + +// Directory is an in-memory Directory for tests and single-node deployments. +type Directory struct { + mu sync.RWMutex + servers map[string]record // serverID -> record +} + +type record struct { + addr string + expires time.Time +} + +func NewDirectory() *Directory { + return &Directory{ + servers: make(map[string]record), + } +} + +func (d *Directory) Advertise(ctx context.Context, serverID string, addr string, ttl time.Duration) error { + d.mu.Lock() + defer d.mu.Unlock() + d.servers[serverID] = record{addr: addr, expires: time.Now().Add(ttl)} + return nil +} + +func (d *Directory) Revoke(ctx context.Context, serverID string) error { + d.mu.Lock() + defer d.mu.Unlock() + delete(d.servers, serverID) + return nil +} + +func (d *Directory) Resolve(ctx context.Context, serverID string) (addr string, err error) { + now := time.Now() + d.mu.Lock() + defer d.mu.Unlock() + rec, ok := d.servers[serverID] + if !ok { + return "", ErrServerNotFound + } + if !rec.expires.IsZero() && now.After(rec.expires) { + delete(d.servers, serverID) + return "", ErrServerNotFound + } + return rec.addr, nil +} diff --git a/pkg/rtun/match/memory/directory_test.go b/pkg/rtun/match/memory/directory_test.go new file mode 100644 index 000000000..70a694417 --- /dev/null +++ b/pkg/rtun/match/memory/directory_test.go @@ -0,0 +1,54 @@ +package memory + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestDirectoryResolve(t *testing.T) { + d := NewDirectory() + ctx := context.Background() + + // Initially not found + _, err := d.Resolve(ctx, "server-a") + require.ErrorIs(t, err, ErrServerNotFound) + + // Register server + require.NoError(t, d.Advertise(ctx, "server-a", "127.0.0.1:5000", 10*time.Second)) + + addr, err := d.Resolve(ctx, "server-a") + require.NoError(t, err) + require.Equal(t, "127.0.0.1:5000", addr) +} + +func TestDirectoryUnregister(t *testing.T) { + d := NewDirectory() + ctx := context.Background() + + require.NoError(t, d.Advertise(ctx, "server-a", "127.0.0.1:5000", 10*time.Second)) + addr, err := d.Resolve(ctx, "server-a") + require.NoError(t, err) + require.Equal(t, "127.0.0.1:5000", addr) + + // Unregister + require.NoError(t, d.Revoke(ctx, "server-a")) + _, err = d.Resolve(ctx, "server-a") + require.ErrorIs(t, err, ErrServerNotFound) +} + +func TestDirectoryMultipleServersIsolation(t *testing.T) { + d := NewDirectory() + ctx := context.Background() + + require.NoError(t, d.Advertise(ctx, "server-a", "a.local:1", 10*time.Second)) + require.NoError(t, d.Advertise(ctx, "server-b", "b.local:2", 10*time.Second)) + + addrA, _ := d.Resolve(ctx, "server-a") + addrB, _ := d.Resolve(ctx, "server-b") + + require.Equal(t, "a.local:1", addrA) + require.Equal(t, "b.local:2", addrB) +} diff --git a/pkg/rtun/match/memory/presence.go b/pkg/rtun/match/memory/presence.go new file mode 100644 index 000000000..177e9bed6 --- /dev/null +++ b/pkg/rtun/match/memory/presence.go @@ -0,0 +1,93 @@ +package memory + +import ( + "context" + "sync" + "time" +) + +// Presence is an in-memory implementation of match.Presence for tests and +// single-node deployments. It stores per-(client, server) leases and global +// client ports. +type Presence struct { + mu sync.RWMutex + leases map[string]map[string]time.Time // clientID -> serverID -> expiry + ports map[string][]uint32 // clientID -> ports +} + +func NewPresence() *Presence { + return &Presence{ + leases: make(map[string]map[string]time.Time), + ports: make(map[string][]uint32), + } +} + +func (p *Presence) Announce(ctx context.Context, clientID string, serverID string, ttl time.Duration) error { + p.mu.Lock() + defer p.mu.Unlock() + if p.leases[clientID] == nil { + p.leases[clientID] = make(map[string]time.Time) + } + p.leases[clientID][serverID] = time.Now().Add(ttl) + return nil +} + +func (p *Presence) Revoke(ctx context.Context, clientID string, serverID string) error { + p.mu.Lock() + defer p.mu.Unlock() + if inner := p.leases[clientID]; inner != nil { + delete(inner, serverID) + if len(inner) == 0 { + delete(p.leases, clientID) + delete(p.ports, clientID) + } + } + return nil +} + +func (p *Presence) Locations(ctx context.Context, clientID string) ([]string, error) { + now := time.Now() + p.mu.Lock() + defer p.mu.Unlock() + inner := p.leases[clientID] + if inner == nil { + return nil, nil + } + for s, exp := range inner { + if !exp.IsZero() && now.After(exp) { + delete(inner, s) + } + } + if len(inner) == 0 { + delete(p.leases, clientID) + delete(p.ports, clientID) + return nil, nil + } + servers := make([]string, 0, len(inner)) + for s := range inner { + servers = append(servers, s) + } + return servers, nil +} + +func (p *Presence) SetPorts(ctx context.Context, clientID string, ports []uint32) error { + p.mu.Lock() + defer p.mu.Unlock() + if ports == nil { + delete(p.ports, clientID) + return nil + } + cp := append([]uint32(nil), ports...) + p.ports[clientID] = cp + return nil +} + +func (p *Presence) Ports(ctx context.Context, clientID string) ([]uint32, error) { + p.mu.RLock() + defer p.mu.RUnlock() + ports := p.ports[clientID] + if ports == nil { + return nil, nil + } + return append([]uint32(nil), ports...), nil +} diff --git a/pkg/rtun/match/memory/presence_test.go b/pkg/rtun/match/memory/presence_test.go new file mode 100644 index 000000000..50b50928c --- /dev/null +++ b/pkg/rtun/match/memory/presence_test.go @@ -0,0 +1,139 @@ +package memory + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestPresenceSetOnlineGetLocations(t *testing.T) { + p := NewPresence() + ctx := context.Background() + + // Initially no locations + locs, err := p.Locations(ctx, "client-1") + require.NoError(t, err) + require.Empty(t, locs) + + // Set ports and announce on server-a + err = p.SetPorts(ctx, "client-1", []uint32{1}) + require.NoError(t, err) + err = p.Announce(ctx, "client-1", "server-a", 10*time.Second) + require.NoError(t, err) + + locs, err = p.Locations(ctx, "client-1") + require.NoError(t, err) + require.Len(t, locs, 1) + require.Contains(t, locs, "server-a") + + // Add second server + err = p.Announce(ctx, "client-1", "server-b", 10*time.Second) + require.NoError(t, err) + + locs, err = p.Locations(ctx, "client-1") + require.NoError(t, err) + require.Len(t, locs, 2) + require.Contains(t, locs, "server-a") + require.Contains(t, locs, "server-b") +} + +func TestPresenceSetOffline(t *testing.T) { + p := NewPresence() + ctx := context.Background() + + err := p.SetPorts(ctx, "client-1", []uint32{1}) + require.NoError(t, err) + err = p.Announce(ctx, "client-1", "server-a", 10*time.Second) + require.NoError(t, err) + err = p.Announce(ctx, "client-1", "server-b", 10*time.Second) + require.NoError(t, err) + + // Remove server-a + err = p.Revoke(ctx, "client-1", "server-a") + require.NoError(t, err) + + locs, err := p.Locations(ctx, "client-1") + require.NoError(t, err) + require.Len(t, locs, 1) + require.Contains(t, locs, "server-b") + + // Remove server-b (last server); client should disappear + err = p.Revoke(ctx, "client-1", "server-b") + require.NoError(t, err) + + locs, err = p.Locations(ctx, "client-1") + require.NoError(t, err) + require.Empty(t, locs) + + // Ports should also be gone + ports, err := p.Ports(ctx, "client-1") + require.NoError(t, err) + require.Empty(t, ports) +} + +func TestPresenceGetPorts(t *testing.T) { + p := NewPresence() + ctx := context.Background() + + err := p.SetPorts(ctx, "client-1", []uint32{1, 2, 3}) + require.NoError(t, err) + err = p.Announce(ctx, "client-1", "server-a", 10*time.Second) + require.NoError(t, err) + + ports, err := p.Ports(ctx, "client-1") + require.NoError(t, err) + require.Equal(t, []uint32{1, 2, 3}, ports) + + // Update ports independently of leases + err = p.SetPorts(ctx, "client-1", []uint32{5}) + require.NoError(t, err) + + ports, err = p.Ports(ctx, "client-1") + require.NoError(t, err) + require.Equal(t, []uint32{5}, ports) +} + +func TestPresenceTTLExpiry(t *testing.T) { + p := NewPresence() + ctx := context.Background() + + // Set with very short TTL + err := p.SetPorts(ctx, "client-1", []uint32{1}) + require.NoError(t, err) + err = p.Announce(ctx, "client-1", "server-a", 10*time.Millisecond) + require.NoError(t, err) + + locs, err := p.Locations(ctx, "client-1") + require.NoError(t, err) + require.Len(t, locs, 1) + + // Wait for expiry + time.Sleep(20 * time.Millisecond) + + // GetLocations should prune expired and return empty + locs, err = p.Locations(ctx, "client-1") + require.NoError(t, err) + require.Empty(t, locs) +} + +func TestPresenceMultipleClientsIsolation(t *testing.T) { + p := NewPresence() + ctx := context.Background() + + err := p.SetPorts(ctx, "client-1", []uint32{1}) + require.NoError(t, err) + err = p.Announce(ctx, "client-1", "server-a", 10*time.Second) + require.NoError(t, err) + err = p.SetPorts(ctx, "client-2", []uint32{2}) + require.NoError(t, err) + err = p.Announce(ctx, "client-2", "server-b", 10*time.Second) + require.NoError(t, err) + + locs1, _ := p.Locations(ctx, "client-1") + locs2, _ := p.Locations(ctx, "client-2") + + require.Equal(t, []string{"server-a"}, locs1) + require.Equal(t, []string{"server-b"}, locs2) +} diff --git a/pkg/rtun/match/presence.go b/pkg/rtun/match/presence.go new file mode 100644 index 000000000..06be0da20 --- /dev/null +++ b/pkg/rtun/match/presence.go @@ -0,0 +1,21 @@ +package match + +import ( + "context" + "time" +) + +// Presence tracks per-(client, server) leases and a client's global ports. +// +// Semantics: +// - Announce/refresh a lease for (clientID, serverID) with a TTL. +// - Revoke removes a single server's lease for a client. +// - Locations returns only non-expired serverIDs. +// - SetPorts sets the client's ports (typically once at hello); Ports returns them. +type Presence interface { + Announce(ctx context.Context, clientID string, serverID string, ttl time.Duration) error + Revoke(ctx context.Context, clientID string, serverID string) error + Locations(ctx context.Context, clientID string) ([]string, error) + SetPorts(ctx context.Context, clientID string, ports []uint32) error + Ports(ctx context.Context, clientID string) ([]uint32, error) +} diff --git a/pkg/rtun/match/route.go b/pkg/rtun/match/route.go new file mode 100644 index 000000000..6af9ed7e9 --- /dev/null +++ b/pkg/rtun/match/route.go @@ -0,0 +1,62 @@ +package match + +import ( + "context" + "fmt" + "net" + "net/url" + "strconv" + + "github.com/conductorone/baton-sdk/pkg/rtun/server" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" +) + +// OwnerRouter helps route work to the server process that owns a client's link. +type OwnerRouter struct { + Locator *Locator + Directory Directory + DialOpts []grpc.DialOption +} + +// DialOwner resolves the owner of clientID and returns a gRPC connection to that server. +// The caller should use this connection to invoke services on the owner, where the owner +// can use its local Registry.DialContext to perform reverse RPCs. +func (r *OwnerRouter) DialOwner(ctx context.Context, clientID string) (*grpc.ClientConn, string, error) { + owner, _, err := r.Locator.OwnerOf(ctx, clientID) + if err != nil { + return nil, "", fmt.Errorf("rtun: locate owner: %w", err) + } + addr, err := r.Directory.Resolve(ctx, owner) + if err != nil { + return nil, "", fmt.Errorf("rtun: resolve owner address: %w", err) + } + opts := r.DialOpts + conn, err := grpc.DialContext(ctx, addr, opts...) + if err != nil { + return nil, "", fmt.Errorf("rtun: dial owner: %w", err) + } + return conn, owner, nil +} + +// LocalReverseDial is a helper to be called ON the owner server process. +// It uses the local Registry to open a reverse connection to the client. +// clientID must be URL-safe; use url.PathEscape if it contains special characters. +func LocalReverseDial(ctx context.Context, reg *server.Registry, clientID string, port uint32) (*grpc.ClientConn, error) { + u := url.URL{ + Scheme: "rtun", + Host: net.JoinHostPort( + clientID, + strconv.FormatUint(uint64(port), 10), + ), + } + addr := u.String() + conn, err := grpc.DialContext(ctx, addr, + grpc.WithContextDialer(reg.DialContext), + grpc.WithTransportCredentials(insecure.NewCredentials()), + ) + if err != nil { + return nil, fmt.Errorf("rtun: reverse dial: %w", err) + } + return conn, nil +} diff --git a/pkg/rtun/match/route_test.go b/pkg/rtun/match/route_test.go new file mode 100644 index 000000000..572152ed1 --- /dev/null +++ b/pkg/rtun/match/route_test.go @@ -0,0 +1,118 @@ +package match + +import ( + "context" + "net" + "testing" + "time" + + rtunpb "github.com/conductorone/baton-sdk/pb/c1/connectorapi/rtun/v1" + "github.com/conductorone/baton-sdk/pkg/rtun/match/memory" + "github.com/conductorone/baton-sdk/pkg/rtun/server" + "github.com/conductorone/baton-sdk/pkg/rtun/transport" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + health "google.golang.org/grpc/health" + healthpb "google.golang.org/grpc/health/grpc_health_v1" +) + +// testValidator for integration tests. +type testValidator struct{ id string } + +func (t testValidator) ValidateAuth(ctx context.Context) (string, error) { return t.id, nil } +func (t testValidator) ValidateHello(ctx context.Context, hello *rtunpb.Hello) error { return nil } + +// TestOwnerRouterTwoServers simulates: +// - Server A owns client-123 (handler + registry). +// - Server B is a caller that uses OwnerRouter to find A, dials A, and uses A's registry to reverse-dial client-123. +func TestOwnerRouterTwoServers(t *testing.T) { + presence := memory.NewPresence() + directory := memory.NewDirectory() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // Server A setup + regA := server.NewRegistry() + handlerA := server.NewHandler(regA, "server-a", testValidator{id: "client-123"}) + gsrvA := grpc.NewServer() + rtunpb.RegisterReverseTunnelServer(gsrvA, handlerA) + lA, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + go func() { _ = gsrvA.Serve(lA) }() + + // Register server A in directory + require.NoError(t, directory.Advertise(ctx, "server-a", lA.Addr().String(), 10*time.Second)) + + // Client connects to server A with cancelable context + clientCtx, clientCancel := context.WithCancel(ctx) + defer clientCancel() + + ccA, err := grpc.Dial(lA.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err) + rtunClientA := rtunpb.NewReverseTunnelClient(ccA) + streamA, err := rtunClientA.Link(clientCtx) + require.NoError(t, err) + + // Client-side session and HELLO + clA := &clientLink{cli: streamA} + sessA := transport.NewSession(clA) + require.NoError(t, clA.Send(&rtunpb.Frame{Sid: 0, Kind: &rtunpb.Frame_Hello{Hello: &rtunpb.Hello{Ports: []uint32{1}}}})) + lnA, err := sessA.Listen(ctx, 1) + require.NoError(t, err) + + // Client runs health service + cgsA := grpc.NewServer() + hsA := health.NewServer() + healthpb.RegisterHealthServer(cgsA, hsA) + hsA.SetServingStatus("", healthpb.HealthCheckResponse_SERVING) + go func() { _ = cgsA.Serve(lnA) }() + + // Mark client-123 online on server-a in presence + _ = presence.SetPorts(ctx, "client-123", []uint32{1}) + _ = presence.Announce(ctx, "client-123", "server-a", 10*time.Second) + + // Server B (caller) uses OwnerRouter to find and dial server A + router := &OwnerRouter{ + Locator: &Locator{Presence: presence}, + Directory: directory, + DialOpts: []grpc.DialOption{grpc.WithTransportCredentials(insecure.NewCredentials())}, + } + // Wait briefly for registration + time.Sleep(50 * time.Millisecond) + + ownerConn, ownerID, err := router.DialOwner(ctx, "client-123") + require.NoError(t, err) + require.Equal(t, "server-a", ownerID) + + // Now server B has a connection to server A. In production, B would invoke a service on A + // that internally uses regA.DialContext. For this test, simulate by directly using regA + // (since we're in the same process). + rconn, err := LocalReverseDial(ctx, regA, "client-123", 1) + require.NoError(t, err) + + // Perform health check over reverse connection + hc := healthpb.NewHealthClient(rconn) + resp, err := hc.Check(ctx, &healthpb.HealthCheckRequest{Service: ""}) + require.NoError(t, err) + require.Equal(t, healthpb.HealthCheckResponse_SERVING, resp.GetStatus()) + + // Clean up in correct order + rconn.Close() + ownerConn.Close() + cgsA.GracefulStop() + lnA.Close() + clientCancel() // Close client stream + ccA.Close() + gsrvA.GracefulStop() + lA.Close() +} + +// clientLink adapts the client bidi stream to transport.Link. +type clientLink struct { + cli rtunpb.ReverseTunnel_LinkClient +} + +func (c *clientLink) Send(f *rtunpb.Frame) error { return c.cli.Send(f) } +func (c *clientLink) Recv() (*rtunpb.Frame, error) { return c.cli.Recv() } +func (c *clientLink) Context() context.Context { return c.cli.Context() } diff --git a/pkg/rtun/server/auth.go b/pkg/rtun/server/auth.go new file mode 100644 index 000000000..d1ae59ed3 --- /dev/null +++ b/pkg/rtun/server/auth.go @@ -0,0 +1,16 @@ +package server + +import ( + "context" + + rtunpb "github.com/conductorone/baton-sdk/pb/c1/connectorapi/rtun/v1" +) + +// TokenValidator decouples auth from the rtun transport. +type TokenValidator interface { + // ValidateAuth is invoked when a stream is first connected. It should authenticate the caller + // from the gRPC context (e.g., mTLS, headers) and return the bound clientID. + ValidateAuth(ctx context.Context) (clientID string, err error) + // ValidateHello validates the HELLO frame contents (e.g., ports, protocol negotiation). + ValidateHello(ctx context.Context, hello *rtunpb.Hello) error +} diff --git a/pkg/rtun/server/grpc_options.go b/pkg/rtun/server/grpc_options.go new file mode 100644 index 000000000..2dcdf23f5 --- /dev/null +++ b/pkg/rtun/server/grpc_options.go @@ -0,0 +1,28 @@ +package server + +import ( + "time" + + "google.golang.org/grpc" + "google.golang.org/grpc/keepalive" +) + +// RecommendedGRPCServerOptions returns server options enabling basic keepalive and +// reasonable message size limits suitable for RTUN services. +func RecommendedGRPCServerOptions() []grpc.ServerOption { + return []grpc.ServerOption{ + grpc.MaxRecvMsgSize(4 * 1024 * 1024), + grpc.MaxSendMsgSize(4 * 1024 * 1024), + grpc.MaxConcurrentStreams(250), + grpc.KeepaliveParams(keepalive.ServerParameters{ + Time: 30 * time.Second, + Timeout: 10 * time.Second, + MaxConnectionIdle: 0, + MaxConnectionAge: 0, + }), + grpc.KeepaliveEnforcementPolicy(keepalive.EnforcementPolicy{ + MinTime: 10 * time.Second, + PermitWithoutStream: true, + }), + } +} diff --git a/pkg/rtun/server/handler.go b/pkg/rtun/server/handler.go new file mode 100644 index 000000000..01819e7d4 --- /dev/null +++ b/pkg/rtun/server/handler.go @@ -0,0 +1,144 @@ +package server + +import ( + "context" + "sync" + "time" + + rtunpb "github.com/conductorone/baton-sdk/pb/c1/connectorapi/rtun/v1" + "github.com/conductorone/baton-sdk/pkg/rtun/transport" + "github.com/grpc-ecosystem/go-grpc-middleware/logging/zap/ctxzap" + "go.uber.org/zap" +) + +// Handler implements the ReverseTunnel gRPC service, binding Links to Sessions and the Registry. +type Handler struct { + rtunpb.UnimplementedReverseTunnelServer + + reg *Registry + serverID string + tv TokenValidator + + mu sync.Mutex + + // metrics (optional) + m *serverMetrics +} + +func NewHandler(reg *Registry, serverID string, tv TokenValidator, opts ...Option) rtunpb.ReverseTunnelServer { + var o options + for _, opt := range opts { + if opt == nil { + continue + } + opt(&o) + } + h := &Handler{reg: reg, serverID: serverID, tv: tv} + if o.metrics != nil { + h.m = newServerMetrics(o.metrics) + } + return h +} + +// Link accepts a bidi stream and binds it to a transport.Session after validating HELLO. +func (h *Handler) Link(stream rtunpb.ReverseTunnel_LinkServer) error { + // Wrap the gRPC stream as transport.Link + l := &grpcLink{srv: stream} + + // Authenticate and determine clientID via TokenValidator BEFORE waiting for HELLO. + if h.tv == nil { + return ErrProtocol + } + clientID, err := h.tv.ValidateAuth(stream.Context()) + if err != nil { + return err + } + logger := ctxzap.Extract(stream.Context()).With(zap.String("client_id", clientID)) + logger.Info("auth ok") + + // First frame must be HELLO with timeout; if not HELLO, protocol violation + type recvResult struct { + fr *rtunpb.Frame + err error + } + resCh := make(chan recvResult, 1) + go func() { + fr, err := l.Recv() + resCh <- recvResult{fr: fr, err: err} + }() + helloTimeout := 15 * time.Second + var hello *rtunpb.Hello + select { + case res := <-resCh: + if res.err != nil { + return res.err + } + hello = res.fr.GetHello() + if hello == nil { + logger.Warn("first frame not HELLO; closing") + if h.m != nil { + h.m.helloRejected(stream.Context(), "not_hello") + } + return ErrProtocol + } + logger.Info("HELLO received", zap.Uint32s("ports", hello.GetPorts())) + // enforce reasonable HELLO port count limit (2500) + if len(hello.GetPorts()) > 2500 { + logger.Warn("HELLO ports exceed limit", zap.Int("count", len(hello.GetPorts()))) + if h.m != nil { + h.m.helloPortsOverLimit(stream.Context()) + } + return ErrProtocol + } + if err := h.tv.ValidateHello(stream.Context(), hello); err != nil { + if h.m != nil { + h.m.helloRejected(stream.Context(), "validate_failed") + } + return err + } + case <-time.After(helloTimeout): + logger.Warn("HELLO timeout") + if h.m != nil { + h.m.helloTimeout(stream.Context()) + } + return ErrHelloTimeout + } + + // Bind Session and start Recv loop. + var sessOpts []transport.Option + ports := hello.GetPorts() + if len(ports) > 0 { + sessOpts = append(sessOpts, transport.WithAllowedPorts(ports)) + } + // Pass metrics down to transport if available + if h.m != nil && h.m.h != nil { + sessOpts = append(sessOpts, transport.WithMetricsHandler(h.m.h)) + } + s := transport.NewSession(l, sessOpts...) + if h.m != nil { + h.m.registryRegister(stream.Context()) + } + h.reg.Register(stream.Context(), clientID, s) + defer h.reg.Unregister(stream.Context(), clientID) + defer func() { + if h.m != nil { + h.m.registryUnregister(stream.Context()) + } + }() + + // Start the session; recvLoop will run until link Recv errors (stream close/cancel) + s.Start() + + // Block until stream context is done (client disconnect or server shutdown) + <-stream.Context().Done() + return stream.Context().Err() +} + +// grpcLink adapts the gRPC server stream to transport.Link +type grpcLink struct { + srv rtunpb.ReverseTunnel_LinkServer +} + +func (g *grpcLink) Send(f *rtunpb.Frame) error { return g.srv.Send(f) } +func (g *grpcLink) Recv() (*rtunpb.Frame, error) { return g.srv.Recv() } +func (g *grpcLink) Context() context.Context { return g.srv.Context() } diff --git a/pkg/rtun/server/metrics.go b/pkg/rtun/server/metrics.go new file mode 100644 index 000000000..7700605a0 --- /dev/null +++ b/pkg/rtun/server/metrics.go @@ -0,0 +1,46 @@ +package server + +import ( + "context" + + sdkmetrics "github.com/conductorone/baton-sdk/pkg/metrics" +) + +type serverMetrics struct { + h sdkmetrics.Handler + + helloTimeoutCtr sdkmetrics.Int64Counter + helloPortsOverCtr sdkmetrics.Int64Counter + helloRejectedCtr sdkmetrics.Int64Counter + regRegisterCtr sdkmetrics.Int64Counter + regUnregisterCtr sdkmetrics.Int64Counter + reverseDialOKCtr sdkmetrics.Int64Counter + reverseDialMissCtr sdkmetrics.Int64Counter +} + +func newServerMetrics(h sdkmetrics.Handler) *serverMetrics { + return &serverMetrics{ + h: h, + helloTimeoutCtr: h.Int64Counter("rtun.server.hello_timeout_total", "HELLO timeouts", sdkmetrics.Dimensionless), + helloPortsOverCtr: h.Int64Counter("rtun.server.hello_ports_over_limit_total", "HELLO ports over limit", sdkmetrics.Dimensionless), + helloRejectedCtr: h.Int64Counter("rtun.server.hello_rejected_total", "HELLO rejected for reason", sdkmetrics.Dimensionless), + regRegisterCtr: h.Int64Counter("rtun.server.registry_register_total", "registry register", sdkmetrics.Dimensionless), + regUnregisterCtr: h.Int64Counter("rtun.server.registry_unregister_total", "registry unregister", sdkmetrics.Dimensionless), + reverseDialOKCtr: h.Int64Counter("rtun.server.reverse_dial_success_total", "reverse dial success", sdkmetrics.Dimensionless), + reverseDialMissCtr: h.Int64Counter("rtun.server.reverse_dial_not_found_total", "reverse dial not found", sdkmetrics.Dimensionless), + } +} + +func (m *serverMetrics) helloTimeout(ctx context.Context) { m.helloTimeoutCtr.Add(ctx, 1, nil) } +func (m *serverMetrics) helloPortsOverLimit(ctx context.Context) { + m.helloPortsOverCtr.Add(ctx, 1, nil) +} +func (m *serverMetrics) helloRejected(ctx context.Context, reason string) { + m.helloRejectedCtr.Add(ctx, 1, map[string]string{"reason": reason}) +} +func (m *serverMetrics) registryRegister(ctx context.Context) { m.regRegisterCtr.Add(ctx, 1, nil) } +func (m *serverMetrics) registryUnregister(ctx context.Context) { m.regUnregisterCtr.Add(ctx, 1, nil) } +func (m *serverMetrics) reverseDialSuccess(ctx context.Context) { m.reverseDialOKCtr.Add(ctx, 1, nil) } +func (m *serverMetrics) reverseDialNotFound(ctx context.Context) { + m.reverseDialMissCtr.Add(ctx, 1, nil) +} diff --git a/pkg/rtun/server/options.go b/pkg/rtun/server/options.go new file mode 100644 index 000000000..091cc910e --- /dev/null +++ b/pkg/rtun/server/options.go @@ -0,0 +1,24 @@ +package server + +import ( + "errors" + + sdkmetrics "github.com/conductorone/baton-sdk/pkg/metrics" +) + +var ( + ErrNotImplemented = errors.New("rtun/server: not implemented") + ErrProtocol = errors.New("rtun/server: protocol error") + ErrHelloTimeout = errors.New("rtun/server: hello timeout") +) + +type Option func(*options) + +type options struct { + metrics sdkmetrics.Handler +} + +// WithMetricsHandler injects a metrics handler for server components (handler/registry). +func WithMetricsHandler(h sdkmetrics.Handler) Option { + return func(o *options) { o.metrics = h } +} diff --git a/pkg/rtun/server/registry.go b/pkg/rtun/server/registry.go new file mode 100644 index 000000000..35a16aff8 --- /dev/null +++ b/pkg/rtun/server/registry.go @@ -0,0 +1,91 @@ +package server + +import ( + "context" + "fmt" + "net" + "net/url" + "strconv" + "sync" + + "github.com/conductorone/baton-sdk/pkg/rtun/transport" +) + +type Registry struct { + mu sync.RWMutex + sessions map[string]*transport.Session + m *serverMetrics +} + +func NewRegistry(opts ...Option) *Registry { + var o options + for _, opt := range opts { + opt(&o) + } + r := &Registry{sessions: make(map[string]*transport.Session)} + if o.metrics != nil { + r.m = newServerMetrics(o.metrics) + } + return r +} + +// DialContext dials ONLY if this process owns the client link. +// addr: "rtun://:" where clientID is URL-safe and port is required. +// clientID must not contain unescaped colons or slashes; use url.PathEscape if needed. +func (r *Registry) DialContext(ctx context.Context, addr string) (net.Conn, error) { + // Parse rtun://clientID:port + u, err := url.Parse(addr) + if err != nil { + return nil, err + } + if u.Scheme != "rtun" { + return nil, fmt.Errorf("rtun: invalid scheme: %s", u.Scheme) + } + host := u.Host + clientID, portStr, err := net.SplitHostPort(host) + if err != nil { + return nil, fmt.Errorf("rtun: missing or invalid port in '%s': %w", addr, err) + } + pu, err := strconv.ParseUint(portStr, 10, 32) + if err != nil { + return nil, fmt.Errorf("rtun: invalid port in '%s': %w", addr, err) + } + port := uint32(pu) + + r.mu.RLock() + s := r.sessions[clientID] + r.mu.RUnlock() + if s == nil { + if r.m != nil { + r.m.reverseDialNotFound(ctx) + } + return nil, fmt.Errorf("rtun: client not connected: %s", clientID) + } + conn, err := s.Open(ctx, port) + if err == nil { + if r.m != nil { + r.m.reverseDialSuccess(ctx) + } + } + return conn, err +} + +// Register binds a client's Session to this Registry under the given clientID. +func (r *Registry) Register(ctx context.Context, clientID string, s *transport.Session) { + r.mu.Lock() + r.sessions[clientID] = s + r.mu.Unlock() + if r.m != nil { + r.m.registryRegister(ctx) + } +} + +// Unregister removes a client's Session binding. +func (r *Registry) Unregister(ctx context.Context, clientID string) { + r.mu.Lock() + delete(r.sessions, clientID) + r.mu.Unlock() + if r.m != nil { + r.m.registryUnregister(ctx) + } +} diff --git a/pkg/rtun/server/server_integration_test.go b/pkg/rtun/server/server_integration_test.go new file mode 100644 index 000000000..62d3db82e --- /dev/null +++ b/pkg/rtun/server/server_integration_test.go @@ -0,0 +1,91 @@ +package server + +import ( + "context" + "net" + "testing" + "time" + + rtunpb "github.com/conductorone/baton-sdk/pb/c1/connectorapi/rtun/v1" + "github.com/conductorone/baton-sdk/pkg/rtun/transport" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + health "google.golang.org/grpc/health" + healthpb "google.golang.org/grpc/health/grpc_health_v1" +) + +// testValidator is a production-shaped validator used only in tests. +// It authenticates the link and returns a fixed clientID; HELLO is allowed as-is. +type testValidator struct{ id string } + +func (t testValidator) ValidateAuth(ctx context.Context) (string, error) { return t.id, nil } +func (t testValidator) ValidateHello(ctx context.Context, hello *rtunpb.Hello) error { return nil } + +// clientLink adapts the client bidi stream to transport.Link on the client side. +type clientLink struct { + cli rtunpb.ReverseTunnel_LinkClient +} + +func (c clientLink) Send(f *rtunpb.Frame) error { return c.cli.Send(f) } +func (c clientLink) Recv() (*rtunpb.Frame, error) { return c.cli.Recv() } +func (c clientLink) Context() context.Context { return c.cli.Context() } + +// TestReverseGrpcE2E spins up a real gRPC server with Handler, connects a real gRPC client stream for Link, +// runs the standard gRPC health service over rtun on the client, and performs a health check from the owner via Registry.DialContext. +func TestReverseGrpcE2E(t *testing.T) { + // Server side: real gRPC server with our handler + reg := NewRegistry() + h := NewHandler(reg, "server-1", testValidator{id: "client-123"}) + gsrv := grpc.NewServer() + rtunpb.RegisterReverseTunnelServer(gsrv, h) + l, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer l.Close() + go func() { _ = gsrv.Serve(l) }() + defer gsrv.GracefulStop() + + // Client side: dial server and open Link stream + cc, err := grpc.Dial(l.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err) + defer cc.Close() + rtunClient := rtunpb.NewReverseTunnelClient(cc) + stream, err := rtunClient.Link(context.Background()) + require.NoError(t, err) + + // Wrap client stream as transport.Link and start Session + cl := clientLink{cli: stream} + sess := transport.NewSession(cl) + // Send HELLO announcing port 1 + require.NoError(t, cl.Send(&rtunpb.Frame{Sid: 0, Kind: &rtunpb.Frame_Hello{Hello: &rtunpb.Hello{Ports: []uint32{1}}}})) + ln, err := sess.Listen(context.Background(), 1) + require.NoError(t, err) + defer ln.Close() + + // Run the standard gRPC health service over the rtun listener + cgs := grpc.NewServer() + hs := health.NewServer() + healthpb.RegisterHealthServer(cgs, hs) + hs.SetServingStatus("", healthpb.HealthCheckResponse_SERVING) + go func() { _ = cgs.Serve(ln) }() + defer cgs.GracefulStop() + + // Owner side: reverse dial and perform health check + // Wait briefly for registration to finish + time.Sleep(50 * time.Millisecond) + rconn, err := reg.DialContext(context.Background(), "rtun://client-123:1") + require.NoError(t, err) + defer rconn.Close() + + ownerCC, err := grpc.DialContext(context.Background(), "ignored", + grpc.WithContextDialer(func(ctx context.Context, _ string) (net.Conn, error) { return rconn, nil }), + grpc.WithTransportCredentials(insecure.NewCredentials()), + ) + require.NoError(t, err) + defer ownerCC.Close() + + hc := healthpb.NewHealthClient(ownerCC) + resp, err := hc.Check(context.Background(), &healthpb.HealthCheckRequest{Service: ""}) + require.NoError(t, err) + require.Equal(t, healthpb.HealthCheckResponse_SERVING, resp.GetStatus()) +} diff --git a/pkg/rtun/transport/closedset.go b/pkg/rtun/transport/closedset.go new file mode 100644 index 000000000..b21df51e1 --- /dev/null +++ b/pkg/rtun/transport/closedset.go @@ -0,0 +1,151 @@ +package transport + +// closedSet maintains a compact set of closed SIDs using a run-length encoded +// interval list and a high-water mark for the largest contiguous closed prefix +// starting from 1. It is package-local and not concurrency-safe; callers should +// provide external synchronization. +// +// Invariants: +// - cs.ranges is sorted by start and contains disjoint, non-adjacent intervals +// with start <= end. +// - All SIDs in [1..highClosed] are considered closed and not represented in +// cs.ranges (auto-pruned on each Close call). +type closedSet struct { + highClosed uint32 + ranges []interval +} + +type interval struct { + start uint32 + end uint32 +} + +// IsClosed returns true if sid has been marked closed. +func (cs *closedSet) IsClosed(sid uint32) bool { + if sid == 0 { + return true + } + if sid <= cs.highClosed { + return true + } + i := cs.find(sid) + if i < len(cs.ranges) { + iv := cs.ranges[i] + if sid >= iv.start && sid <= iv.end { + return true + } + } + if i > 0 { + iv := cs.ranges[i-1] + return sid >= iv.start && sid <= iv.end + } + return false +} + +// Close marks sid as closed. This merges adjacent/overlapping intervals and +// auto-prunes intervals that become part of the contiguous closed prefix. +func (cs *closedSet) Close(sid uint32) { + if sid == 0 { + return + } + // If already in the contiguous prefix, nothing to do. + if sid <= cs.highClosed { + return + } + + i := cs.find(sid) + // Check if already closed by neighboring interval + if i < len(cs.ranges) { + iv := cs.ranges[i] + if sid >= iv.start && sid <= iv.end { + return + } + } + if i > 0 { + iv := cs.ranges[i-1] + if sid >= iv.start && sid <= iv.end { + return + } + } + + // New interval initially [sid, sid]. Merge with previous/next if adjacent or overlapping. + start, end := sid, sid + // Merge with previous + if i > 0 { + prev := cs.ranges[i-1] + if prev.end+1 >= sid { // adjacent or overlap + start = prev.start + if prev.end > end { + end = prev.end + } + // remove prev + cs.ranges = append(cs.ranges[:i-1], cs.ranges[i:]...) + i-- + } + } + // Merge with next (note: i points to the first interval with start > sid or the merged position) + if i < len(cs.ranges) { + next := cs.ranges[i] + if next.start <= end+1 { // adjacent or overlap + if next.end > end { + end = next.end + } + // remove next + cs.ranges = append(cs.ranges[:i], cs.ranges[i+1:]...) + } + } + // Insert merged interval at position i + cs.ranges = append(cs.ranges, interval{}) + copy(cs.ranges[i+1:], cs.ranges[i:]) + cs.ranges[i] = interval{start: start, end: end} + + // Update highClosed: if the merged interval starts at highClosed+1, extend the prefix. + cs.promotePrefix() +} + +// find returns the index of the first interval with start > sid, or len(ranges) +// if none; suitable for insertion and neighbor checks. +func (cs *closedSet) find(sid uint32) int { + lo, hi := 0, len(cs.ranges) + for lo < hi { + mid := (lo + hi) >> 1 + if cs.ranges[mid].start <= sid { + lo = mid + 1 + } else { + hi = mid + } + } + return lo +} + +// promotePrefix extends highClosed if there is an interval that begins at +// highClosed+1. It also prunes any intervals fully included in the prefix by +// removing them from the ranges slice. +func (cs *closedSet) promotePrefix() { + changed := true + for changed { + changed = false + if len(cs.ranges) == 0 { + return + } + // After insert/merge, the first interval that could extend the prefix is the earliest one + // whose start is <= highClosed+1. Because we always keep ranges sorted and disjoint, it must + // be at index 0 if it can extend the prefix. + iv := cs.ranges[0] + want := cs.highClosed + 1 + if iv.start == want { + // Extend prefix to this interval's end and prune it + cs.highClosed = iv.end + cs.ranges = cs.ranges[1:] + changed = true + continue + } + // Prune only intervals fully covered by the prefix. + if iv.end <= cs.highClosed { + cs.ranges = cs.ranges[1:] + changed = true + continue + } + // Otherwise, iv lies strictly above the prefix and cannot extend it now. + } +} diff --git a/pkg/rtun/transport/closedset_test.go b/pkg/rtun/transport/closedset_test.go new file mode 100644 index 000000000..36c257b5c --- /dev/null +++ b/pkg/rtun/transport/closedset_test.go @@ -0,0 +1,152 @@ +package transport + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestClosedSetBasicPrefixAndRanges(t *testing.T) { + var cs closedSet + + // Initially everything open except sid 0 considered closed by spec + require.True(t, cs.IsClosed(0)) + require.False(t, cs.IsClosed(1)) + + // Close(1) promotes prefix to 1 + cs.Close(1) + require.True(t, cs.IsClosed(1)) + require.Equal(t, uint32(1), cs.highClosed) + require.Empty(t, cs.ranges) + + // Close non-contiguous sid produces a range + cs.Close(3) + require.False(t, cs.IsClosed(2)) + require.True(t, cs.IsClosed(3)) + require.Len(t, cs.ranges, 1) + require.Equal(t, interval{start: 3, end: 3}, cs.ranges[0]) + + // Close(2) creates [2,3] which begins at highClosed+1=2, so prefix promotes to 3 and ranges empty + cs.Close(2) + require.True(t, cs.IsClosed(2)) + require.Equal(t, uint32(3), cs.highClosed) + require.Empty(t, cs.ranges) + + // Closing 4 promotes prefix to 4 (no ranges kept) + cs.Close(4) + require.True(t, cs.IsClosed(4)) + require.Equal(t, uint32(4), cs.highClosed) + require.Empty(t, cs.ranges) + + // Closing 5 promotes prefix to 5 (no ranges kept) + cs.Close(5) + require.True(t, cs.IsClosed(5)) + require.Equal(t, uint32(5), cs.highClosed) + require.Empty(t, cs.ranges) + + // Closing 0 ignored; does not change state (highClosed already 5 and no ranges) + cs.Close(0) + require.Equal(t, uint32(5), cs.highClosed) + require.Empty(t, cs.ranges) + + // Closing 6 promotes prefix further + cs.Close(6) + require.Equal(t, uint32(6), cs.highClosed) + require.Empty(t, cs.ranges) + + // Now closing 2..6 already closed does nothing + for sid := uint32(2); sid <= 6; sid++ { + cs.Close(sid) + } + require.Equal(t, uint32(6), cs.highClosed) + require.Empty(t, cs.ranges) + + // Close(7) continues extension + cs.Close(7) + require.Equal(t, uint32(7), cs.highClosed) + require.Empty(t, cs.ranges) + + // Close(8) continues extension + cs.Close(8) + require.Equal(t, uint32(8), cs.highClosed) + require.Empty(t, cs.ranges) +} + +func TestClosedSetMergeBothSides(t *testing.T) { + var cs closedSet + // Create two ranges [10,12] and [14,16], then close 13 to merge both into [10,16] + cs.Close(11) + cs.Close(12) + cs.Close(10) + require.Equal(t, []interval{{start: 10, end: 12}}, cs.ranges) + + cs.Close(15) + cs.Close(16) + cs.Close(14) + require.Equal(t, []interval{{start: 10, end: 12}, {start: 14, end: 16}}, cs.ranges) + + cs.Close(13) + require.Equal(t, []interval{{start: 10, end: 16}}, cs.ranges) +} + +func TestClosedSetPromotionFromRange(t *testing.T) { + var cs closedSet + // Close a distant block [5,7] + cs.Close(7) + cs.Close(6) + cs.Close(5) + require.Equal(t, uint32(0), cs.highClosed) + require.Equal(t, []interval{{start: 5, end: 7}}, cs.ranges) + + // Now close 1..4; once 4 closes, the first range starts at 5 which is highClosed+1 => promotion to 7 + cs.Close(1) + cs.Close(2) + cs.Close(3) + require.Equal(t, uint32(3), cs.highClosed) + cs.Close(4) + require.Equal(t, uint32(7), cs.highClosed) + require.Empty(t, cs.ranges) +} + +func TestClosedSetIdempotentAndOrderInvariant(t *testing.T) { + var a, b closedSet + order1 := []uint32{100, 1, 3, 2, 5, 4} + order2 := []uint32{1, 2, 3, 4, 5, 100} + for _, sid := range order1 { + a.Close(sid) + } + for _, sid := range order2 { + b.Close(sid) + } + require.Equal(t, a.highClosed, b.highClosed) + require.Equal(t, a.ranges, b.ranges) +} + +func TestClosedSetLargeValues(t *testing.T) { + var cs closedSet + cs.Close(1_000_000) + cs.Close(1_000_002) + require.False(t, cs.IsClosed(1)) + require.True(t, cs.IsClosed(1_000_000)) + require.True(t, cs.IsClosed(1_000_002)) + require.False(t, cs.IsClosed(1_000_001)) + + // Merge into single interval [1_000_000, 1_000_010] + for sid := uint32(1_000_001); sid <= 1_000_010; sid++ { + cs.Close(sid) + } + require.Equal(t, []interval{{start: 1_000_000, end: 1_000_010}}, cs.ranges) +} + +func TestClosedSetHighClosedQuery(t *testing.T) { + var cs closedSet + for sid := uint32(1); sid <= 50; sid++ { + cs.Close(sid) + } + require.Equal(t, uint32(50), cs.highClosed) + // Queries under prefix are true without ranges + for sid := uint32(1); sid <= 50; sid++ { + require.True(t, cs.IsClosed(sid)) + } + require.Empty(t, cs.ranges) +} diff --git a/pkg/rtun/transport/conn.go b/pkg/rtun/transport/conn.go new file mode 100644 index 000000000..6b2a3c249 --- /dev/null +++ b/pkg/rtun/transport/conn.go @@ -0,0 +1,306 @@ +package transport + +import ( + "errors" + "io" + "net" + "sync" + "time" + + rtunpb "github.com/conductorone/baton-sdk/pb/c1/connectorapi/rtun/v1" +) + +const maxWriteChunk = 32 * 1024 + +// virtConn implements net.Conn over a multiplexed SID. +type virtConn struct { + mux *Session + sid uint32 + + readCh chan []byte + readErr error + readMu sync.Mutex + readRem []byte // remainder from partial reads + + writeMu sync.Mutex + writeClosed bool + + rdDeadline time.Time + wrDeadline time.Time + + // closeReadOnce ensures the read side is closed exactly once, regardless of + // whether closure originates locally (Close), remotely (FIN), or via RST. + closeReadOnce sync.Once + + // idle timer management + idleMu sync.Mutex + idleTimer *time.Timer +} + +var _ net.Conn = (*virtConn)(nil) + +func newVirtConn(m *Session, sid uint32) *virtConn { + return &virtConn{ + mux: m, + sid: sid, + readCh: make(chan []byte, 16), + } +} + +func (c *virtConn) Read(p []byte) (int, error) { + // Check terminal error or remainder under lock + c.readMu.Lock() + if c.readErr != nil { + err := c.readErr + c.readMu.Unlock() + return 0, err + } + if len(c.readRem) > 0 { + n := copy(p, c.readRem) + c.readRem = c.readRem[n:] + c.readMu.Unlock() + return n, nil + } + // Snapshot deadline then release lock before blocking on channel + deadline := c.rdDeadline + c.readMu.Unlock() + + if deadline.IsZero() { + buf, ok := <-c.readCh + if !ok { + c.readMu.Lock() + err := c.readErr + c.readMu.Unlock() + if err == nil { + return 0, io.EOF + } + return 0, err + } + n := copy(p, buf) + c.onActivity() + if n < len(buf) { + c.readMu.Lock() + c.readRem = buf[n:] + c.readMu.Unlock() + } + return n, nil + } + + // Deadline set + until := time.Until(deadline) + if until <= 0 { + return 0, ErrTimeout + } + timer := time.NewTimer(until) + defer timer.Stop() + select { + case buf, ok := <-c.readCh: + if !ok { + c.readMu.Lock() + err := c.readErr + c.readMu.Unlock() + if err == nil { + return 0, io.EOF + } + return 0, err + } + n := copy(p, buf) + c.onActivity() + if n < len(buf) { + c.readMu.Lock() + c.readRem = buf[n:] + c.readMu.Unlock() + } + return n, nil + case <-timer.C: + return 0, ErrTimeout + } +} + +func (c *virtConn) Write(p []byte) (int, error) { + c.writeMu.Lock() + defer c.writeMu.Unlock() + if c.writeClosed { + return 0, ErrClosed + } + // Writes are allowed even after remote FIN (half-close), so we do not block based on remote state. + total := 0 + for len(p) > 0 { + if !c.wrDeadline.IsZero() && time.Until(c.wrDeadline) <= 0 { + if total == 0 { + return 0, ErrTimeout + } + return total, ErrTimeout + } + chunk := p + if len(chunk) > maxWriteChunk { + chunk = p[:maxWriteChunk] + } + frame := &rtunpb.Frame{Sid: c.sid, Kind: &rtunpb.Frame_Data{Data: &rtunpb.Data{Payload: append([]byte(nil), chunk...)}}} + if err := c.mux.link.Send(frame); err != nil { + return total, err + } + c.onActivity() + if c.mux.m != nil { + c.mux.m.recordFrameTx(c.mux.link.Context(), "DATA") + c.mux.m.recordBytesTx(c.mux.link.Context(), int64(len(chunk))) + } + total += len(chunk) + p = p[len(chunk):] + } + return total, nil +} + +func (c *virtConn) Close() error { + c.writeMu.Lock() + if c.writeClosed { + c.writeMu.Unlock() + return nil + } + c.writeClosed = true + c.writeMu.Unlock() + // send FIN (ack=false) + _ = c.mux.link.Send(&rtunpb.Frame{Sid: c.sid, Kind: &rtunpb.Frame_Fin{Fin: &rtunpb.Fin{Ack: false}}}) + // Unblock any pending Read by closing the read side + c.closeReadOnce.Do(func() { + c.readMu.Lock() + if c.readErr == nil { + c.readErr = io.EOF + } + c.readMu.Unlock() + close(c.readCh) + }) + c.stopIdleTimer() + c.mux.removeConn(c.sid) + return nil +} + +func (c *virtConn) LocalAddr() net.Addr { return rtunAddr{"rtun-local"} } +func (c *virtConn) RemoteAddr() net.Addr { return rtunAddr{"rtun-remote"} } + +func (c *virtConn) SetDeadline(t time.Time) error { + c.readMu.Lock() + c.rdDeadline = t + c.readMu.Unlock() + c.writeMu.Lock() + c.wrDeadline = t + c.writeMu.Unlock() + return nil +} + +func (c *virtConn) SetReadDeadline(t time.Time) error { + c.readMu.Lock() + c.rdDeadline = t + c.readMu.Unlock() + return nil +} + +func (c *virtConn) SetWriteDeadline(t time.Time) error { + c.writeMu.Lock() + c.wrDeadline = t + c.writeMu.Unlock() + return nil +} + +// feedData is called by the mux to deliver inbound bytes. +func (c *virtConn) feedData(b []byte) { + // If we've already observed a terminal read error (EOF, overflow, RST), drop incoming data. + c.readMu.Lock() + alreadyErr := c.readErr + c.readMu.Unlock() + if alreadyErr != nil { + return + } + select { + case c.readCh <- b: + // delivered + c.onActivity() + default: + // backpressure: mark error, close channel, send RST, and detach from session to avoid further deliveries + c.readMu.Lock() + if c.readErr == nil { + c.readErr = errors.New("rtun: inbound buffer overflow") + } + c.readMu.Unlock() + c.closeReadOnce.Do(func() { close(c.readCh) }) + _ = c.mux.link.Send(&rtunpb.Frame{Sid: c.sid, Kind: &rtunpb.Frame_Rst{Rst: &rtunpb.Rst{Code: rtunpb.RstCode_RST_CODE_INTERNAL}}}) + c.stopIdleTimer() + c.mux.removeConn(c.sid) + } +} + +func (c *virtConn) handleFin(ack bool) { + // mark remote closed; signal EOF + c.readMu.Lock() + if c.readErr == nil { + c.readErr = io.EOF + } + c.readMu.Unlock() + c.closeReadOnce.Do(func() { close(c.readCh) }) + c.stopIdleTimer() +} + +func (c *virtConn) handleRst(err error) { + c.readMu.Lock() + if c.readErr == nil { + c.readErr = err + } + c.readMu.Unlock() + c.closeReadOnce.Do(func() { close(c.readCh) }) + c.writeMu.Lock() + c.writeClosed = true + c.writeMu.Unlock() + c.stopIdleTimer() +} + +type rtunAddr struct{ s string } + +func (a rtunAddr) Network() string { return "rtun" } +func (a rtunAddr) String() string { return a.s } + +// startIdleTimer starts or resets the per-connection idle timer according to the Session's configuration. +func (c *virtConn) startIdleTimer() { + timeout := c.mux.idleTimeout + if timeout < 0 { + return + } + c.idleMu.Lock() + if c.idleTimer == nil { + c.idleTimer = time.AfterFunc(timeout, func() { + c.handleIdleTimeout() + }) + } else { + c.idleTimer.Reset(timeout) + } + c.idleMu.Unlock() +} + +// onActivity resets the idle timer if enabled to reflect recent I/O activity. +func (c *virtConn) onActivity() { + timeout := c.mux.idleTimeout + if timeout < 0 { + return + } + c.idleMu.Lock() + if c.idleTimer != nil { + _ = c.idleTimer.Reset(timeout) + } + c.idleMu.Unlock() +} + +func (c *virtConn) stopIdleTimer() { + c.idleMu.Lock() + if c.idleTimer != nil { + c.idleTimer.Stop() + c.idleTimer = nil + } + c.idleMu.Unlock() +} + +func (c *virtConn) handleIdleTimeout() { + // Timer fired: send RST and tear down connection. + _ = c.mux.link.Send(&rtunpb.Frame{Sid: c.sid, Kind: &rtunpb.Frame_Rst{Rst: &rtunpb.Rst{Code: rtunpb.RstCode_RST_CODE_INTERNAL}}}) + c.handleRst(ErrTimeout) + c.mux.removeConn(c.sid) + c.mux.markClosed(c.sid) +} diff --git a/pkg/rtun/transport/conn_test.go b/pkg/rtun/transport/conn_test.go new file mode 100644 index 000000000..2b4e61ba5 --- /dev/null +++ b/pkg/rtun/transport/conn_test.go @@ -0,0 +1,93 @@ +package transport + +import ( + "context" + "io" + "net" + "testing" + "time" + + rtunpb "github.com/conductorone/baton-sdk/pb/c1/connectorapi/rtun/v1" + "github.com/stretchr/testify/require" +) + +func TestVirtConnCloseIdempotentAndWriteAfterClose(t *testing.T) { + tl := newTestLink() + s := NewSession(tl) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + t.Cleanup(cancel) + ln, err := s.Listen(ctx, 1) + require.NoError(t, err) + t.Cleanup(func() { _ = ln.Close() }) + + accCh := make(chan net.Conn, 1) + go func() { c, _ := ln.Accept(); accCh <- c }() + tl.push(&rtunpb.Frame{Sid: 31, Kind: &rtunpb.Frame_Syn{Syn: &rtunpb.Syn{Port: 1}}}) + + var c net.Conn + select { + case c = <-accCh: + case <-ctx.Done(): + t.Fatal("accept timeout") + } + + // Close twice is ok + require.NoError(t, c.Close()) + require.NoError(t, c.Close()) + + // Write after close should fail with ErrClosed + _, err = c.Write([]byte("x")) + require.ErrorIs(t, err, ErrClosed) + + // Read after close yields EOF or ErrClosed + _, err = c.Read(make([]byte, 1)) + require.True(t, err == io.EOF || err == ErrClosed) +} + +func TestVirtConnRemoteRstPropagatesToRead(t *testing.T) { + tl := newTestLink() + s := NewSession(tl) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + t.Cleanup(cancel) + ln, err := s.Listen(ctx, 1) + require.NoError(t, err) + t.Cleanup(func() { _ = ln.Close() }) + + accCh := make(chan net.Conn, 1) + go func() { c, _ := ln.Accept(); accCh <- c }() + tl.push(&rtunpb.Frame{Sid: 32, Kind: &rtunpb.Frame_Syn{Syn: &rtunpb.Syn{Port: 1}}}) + var c net.Conn + select { + case c = <-accCh: + case <-ctx.Done(): + t.Fatal("accept timeout") + } + + // Remote sends RST + tl.push(&rtunpb.Frame{Sid: 32, Kind: &rtunpb.Frame_Rst{Rst: &rtunpb.Rst{Code: rtunpb.RstCode_RST_CODE_INTERNAL}}}) + _, err = c.Read(make([]byte, 1)) + require.ErrorIs(t, err, ErrConnReset) +} + +func TestVirtConnLocalRemoteAddr(t *testing.T) { + tl := newTestLink() + s := NewSession(tl) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + t.Cleanup(cancel) + ln, err := s.Listen(ctx, 1) + require.NoError(t, err) + t.Cleanup(func() { _ = ln.Close() }) + + accCh := make(chan net.Conn, 1) + go func() { c, _ := ln.Accept(); accCh <- c }() + tl.push(&rtunpb.Frame{Sid: 33, Kind: &rtunpb.Frame_Syn{Syn: &rtunpb.Syn{Port: 1}}}) + var c net.Conn + select { + case c = <-accCh: + case <-ctx.Done(): + t.Fatal("accept timeout") + } + + require.NotNil(t, c.LocalAddr()) + require.NotNil(t, c.RemoteAddr()) +} diff --git a/pkg/rtun/transport/errors.go b/pkg/rtun/transport/errors.go new file mode 100644 index 000000000..e18230dda --- /dev/null +++ b/pkg/rtun/transport/errors.go @@ -0,0 +1,9 @@ +package transport + +import "errors" + +var ( + ErrConnReset = errors.New("rtun: connection reset") + ErrClosed = errors.New("rtun: closed") + ErrTimeout = errors.New("rtun: deadline exceeded") +) diff --git a/pkg/rtun/transport/listener.go b/pkg/rtun/transport/listener.go new file mode 100644 index 000000000..c8fc1e3f4 --- /dev/null +++ b/pkg/rtun/transport/listener.go @@ -0,0 +1,63 @@ +package transport + +import ( + "net" + "sync" +) + +type rtunListener struct { + port uint32 + accepts chan net.Conn + mux *Session + mu sync.Mutex + closed bool + err error +} + +func (l *rtunListener) Accept() (net.Conn, error) { + if l.err != nil { + return nil, l.err + } + c, ok := <-l.accepts + if !ok { + if l.err != nil { + return nil, l.err + } + return nil, ErrClosed + } + return c, nil +} + +func (l *rtunListener) Close() error { + l.mu.Lock() + if l.closed { + l.mu.Unlock() + return nil + } + l.closed = true + l.mu.Unlock() + l.mux.removeListener(l.port) + close(l.accepts) + return nil +} + +func (l *rtunListener) Addr() net.Addr { return rtunAddr{"rtun-listener"} } + +func (l *rtunListener) enqueue(c *virtConn) { + select { + case l.accepts <- c: + default: + // listener full, drop + c.handleRst(ErrClosed) + } +} + +func (l *rtunListener) closeWithErr(err error) { + l.mu.Lock() + l.err = err + if !l.closed { + l.closed = true + close(l.accepts) + } + l.mu.Unlock() +} diff --git a/pkg/rtun/transport/session.go b/pkg/rtun/transport/session.go new file mode 100644 index 000000000..b3b8534d3 --- /dev/null +++ b/pkg/rtun/transport/session.go @@ -0,0 +1,474 @@ +package transport + +import ( + "context" + "errors" + "net" + "sync" + "time" + + rtunpb "github.com/conductorone/baton-sdk/pb/c1/connectorapi/rtun/v1" + sdkmetrics "github.com/conductorone/baton-sdk/pkg/metrics" + "go.uber.org/zap" +) + +// Option configures a Session. +type Option func(*options) + +type options struct { + logger *zap.Logger + // allowedPorts is an allowlist of ports that the server is permitted to Open() toward the client. + // If nil or empty, all ports are allowed. + allowedPorts map[uint32]bool + // maxPendingSIDs caps how many distinct SIDs may accumulate DATA-before-SYN buffers. + maxPendingSIDs int + // idleTimeout controls per-SID idle expiration; zero means use default (10m). Negative disables. + idleTimeout time.Duration + // metrics handler (optional) + metrics sdkmetrics.Handler +} + +// WithLogger sets a structured logger for the session. +func WithLogger(l *zap.Logger) Option { + return func(o *options) { + o.logger = l + } +} + +// WithAllowedPorts sets the explicit list of allowed ports for Session.Open(). +func WithAllowedPorts(ports []uint32) Option { + return func(o *options) { + if len(ports) == 0 { + o.allowedPorts = nil + return + } + if o.allowedPorts == nil { + o.allowedPorts = make(map[uint32]bool, len(ports)) + } + for _, p := range ports { + o.allowedPorts[p] = true + } + } +} + +// WithMaxPendingSIDs sets the maximum number of distinct SIDs allowed to accumulate +// DATA-before-SYN pending buffers. Values <= 0 select the default (64). +func WithMaxPendingSIDs(n int) Option { + return func(o *options) { + o.maxPendingSIDs = n + } +} + +// WithIdleTimeout sets the per-SID idle timeout. Zero selects the default (10m). Negative disables. +func WithIdleTimeout(d time.Duration) Option { + return func(o *options) { + o.idleTimeout = d + } +} + +// WithMetricsHandler injects a metrics handler for transport-level metrics. +func WithMetricsHandler(h sdkmetrics.Handler) Option { + return func(o *options) { + o.metrics = h + } +} + +// Link is the minimal adapter that the generated gRPC stream satisfies. +// It is intentionally small to decouple from gRPC specifics and simplify testing. +type Link interface { + Send(*rtunpb.Frame) error + Recv() (*rtunpb.Frame, error) + Context() context.Context +} + +// Session represents a per-link dispatcher with a single Recv loop and listener/conn registry. +type Session struct { + link Link + logger *zap.Logger + + mu sync.Mutex + started bool + closing bool + conns map[uint32]*virtConn + listeners map[uint32]*rtunListener + nextSID uint32 + pending map[uint32][][]byte // queued DATA before SYN processed + closed closedSet // closed SIDs for late-frame detection + + // configuration + allowedPorts map[uint32]bool + maxPendingSIDs int + idleTimeout time.Duration + + // metrics (optional) + m *transportMetrics +} + +// NewSession constructs a per-link session. The Recv loop starts on first listener registration. +func NewSession(link Link, opts ...Option) *Session { + var o options + for _, opt := range opts { + opt(&o) + } + logger := o.logger + if logger == nil { + logger = zap.NewNop() + } + // defaults + maxPending := o.maxPendingSIDs + if maxPending <= 0 { + maxPending = 64 + } + idle := o.idleTimeout + if idle == 0 { + idle = 10 * time.Minute + } + s := &Session{ + link: link, + conns: make(map[uint32]*virtConn), + listeners: make(map[uint32]*rtunListener), + nextSID: 1, + pending: make(map[uint32][][]byte), + logger: logger, + allowedPorts: o.allowedPorts, + maxPendingSIDs: maxPending, + idleTimeout: idle, + } + if o.metrics != nil { + s.m = newTransportMetrics(o.metrics) + } + return s +} + +// Listen exposes a net.Listener for a numeric port on this Session. +func (s *Session) Listen(ctx context.Context, port uint32, opts ...Option) (net.Listener, error) { + l := &rtunListener{ + port: port, + accepts: make(chan net.Conn, 512), // effectively listener backlog + mux: s, + } + if err := s.addListener(l); err != nil { + return nil, err + } + s.startOnce() + return l, nil +} + +func (s *Session) removeConn(sid uint32) { + s.mu.Lock() + delete(s.conns, sid) + s.mu.Unlock() +} + +func (s *Session) addListener(l *rtunListener) error { + s.mu.Lock() + defer s.mu.Unlock() + if s.closing { + return ErrClosed + } + if _, exists := s.listeners[l.port]; exists { + return errors.New("rtun: listener already exists for port") + } + if s.listeners == nil { + s.listeners = make(map[uint32]*rtunListener) + } + s.listeners[l.port] = l + return nil +} + +func (s *Session) removeListener(port uint32) { + s.mu.Lock() + delete(s.listeners, port) + s.mu.Unlock() +} + +func (s *Session) startOnce() { + s.mu.Lock() + if s.started { + s.mu.Unlock() + return + } + s.started = true + s.mu.Unlock() + go s.recvLoop() +} + +const maxPendingBufferSize = 64 * 1024 + +func (s *Session) recvLoop() { + for { + fr, err := s.link.Recv() + if err != nil { + s.mu.Lock() + for _, c := range s.conns { + c.handleRst(err) + } + for _, l := range s.listeners { + l.closeWithErr(err) + } + s.closing = true + s.mu.Unlock() + return + } + if fr == nil { + continue + } + sid := fr.GetSid() + if s.m != nil { + s.m.recordFrameRx(s.link.Context(), kindOf(fr)) + } + // sid==0 is invalid; treat SYN on sid 0 as fatal, drop other frames silently. + if sid == 0 { + if _, isSyn := fr.Kind.(*rtunpb.Frame_Syn); isSyn { + s.logger.Warn("protocol violation: SYN with sid 0; closing link") + s.failLocked(errors.New("rtun: protocol violation (sid 0)")) + return + } + continue + } + switch k := fr.Kind.(type) { + case *rtunpb.Frame_Syn: + port := k.Syn.GetPort() + s.mu.Lock() + l := s.listeners[port] + if l == nil { + s.mu.Unlock() + // Send RST outside lock + _ = s.link.Send(&rtunpb.Frame{Sid: sid, Kind: &rtunpb.Frame_Rst{Rst: &rtunpb.Rst{Code: rtunpb.RstCode_RST_CODE_NO_LISTENER}}}) + if s.m != nil { + s.m.recordRstSent(s.link.Context(), "no_listener") + } + continue + } + // Duplicate SYN on existing SID is a fatal protocol error for the link. + if _, exists := s.conns[sid]; exists || s.closed.IsClosed(sid) { + s.mu.Unlock() + s.logger.Warn("protocol violation: duplicate SYN for existing or closed SID; closing link", zap.Uint32("sid", sid)) + s.failLocked(errors.New("rtun: duplicate SYN")) + return + } + vc := newVirtConn(s, sid) + if s.conns == nil { + s.conns = make(map[uint32]*virtConn) + } + s.conns[sid] = vc + vc.startIdleTimer() + if s.m != nil { + s.m.incSidsActive(s.link.Context(), 1) + } + // Drain any pending data queued before SYN + if q := s.pending[sid]; len(q) > 0 { + for _, p := range q { + vc.feedData(p) + } + delete(s.pending, sid) + } + s.mu.Unlock() + l.enqueue(vc) + case *rtunpb.Frame_Data: + s.mu.Lock() + // Ignore if SID is closed (late frame) + if s.closed.IsClosed(sid) { + s.mu.Unlock() + continue + } + c := s.conns[sid] + payload := append([]byte(nil), k.Data.GetPayload()...) + if s.m != nil { + s.m.recordBytesRx(s.link.Context(), int64(len(payload))) + } + if c != nil { + s.mu.Unlock() + c.feedData(payload) + } else { + // defensive programming decision: DATA-before-SYN is a protocol error. RST and do not buffer. + s.mu.Unlock() + _ = s.link.Send(&rtunpb.Frame{Sid: sid, Kind: &rtunpb.Frame_Rst{Rst: &rtunpb.Rst{Code: rtunpb.RstCode_RST_CODE_INTERNAL}}}) + if s.m != nil { + s.m.recordRstSent(s.link.Context(), "protocol_violation") + } + continue + } + case *rtunpb.Frame_Fin: + s.mu.Lock() + c := s.conns[sid] + s.mu.Unlock() + if c != nil { + c.handleFin(k.Fin.GetAck()) + s.removeConn(sid) + s.mu.Lock() + s.closed.Close(sid) + s.mu.Unlock() + if s.m != nil { + s.m.incSidsActive(s.link.Context(), -1) + } + } + case *rtunpb.Frame_Rst: + s.mu.Lock() + c := s.conns[sid] + s.mu.Unlock() + if c != nil { + if s.m != nil { + s.m.recordRstRecv(s.link.Context(), k.Rst.GetCode().String()) + } + c.handleRst(ErrConnReset) + s.removeConn(sid) + s.mu.Lock() + s.closed.Close(sid) + s.mu.Unlock() + if s.m != nil { + s.m.incSidsActive(s.link.Context(), -1) + } + } + } + } +} + +// Start begins the Recv loop if not already started. +func (s *Session) Start() { s.startOnce() } + +// Open initiates a reverse connection to the remote client on the given port. +// It allocates a new SID, sends SYN, and returns a net.Conn bound to that SID. +// If ctx is canceled before SYN is sent, returns ctx.Err(). +func (s *Session) Open(ctx context.Context, port uint32) (net.Conn, error) { + // Check context before acquiring lock + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + + // Enforce allowed ports policy if configured + if s.allowedPorts != nil { + if !s.allowedPorts[port] { + return nil, errors.New("rtun: port not allowed by HELLO policy") + } + } + + s.mu.Lock() + if s.closing { + s.mu.Unlock() + return nil, ErrClosed + } + sid := s.nextSID + if sid == 0 { + sid = 1 + } + s.nextSID = sid + 1 + vc := newVirtConn(s, sid) + s.conns[sid] = vc + s.mu.Unlock() + + // Send SYN to remote + if err := s.link.Send(&rtunpb.Frame{Sid: sid, Kind: &rtunpb.Frame_Syn{Syn: &rtunpb.Syn{Port: port}}}); err != nil { + // Cleanup on failure + s.removeConn(sid) + return nil, err + } + vc.startIdleTimer() + if s.m != nil { + s.m.incSidsActive(s.link.Context(), 1) + s.m.recordFrameTx(s.link.Context(), "SYN") + } + return vc, nil +} + +// markClosed records the SID as closed to ignore late frames. +func (s *Session) markClosed(sid uint32) { + s.mu.Lock() + s.closed.Close(sid) + s.mu.Unlock() +} + +// failLocked closes all session resources and marks the session as closing. +// It must be called outside of s.mu (we only use it here where no lock is held), +// but it acquires the lock internally for safety. +func (s *Session) failLocked(err error) { + s.mu.Lock() + if s.closing { + s.mu.Unlock() + return + } + for _, c := range s.conns { + c.handleRst(err) + } + for _, l := range s.listeners { + l.closeWithErr(err) + } + s.closing = true + s.mu.Unlock() +} + +// metrics helpers +type transportMetrics struct { + framesRx sdkmetrics.Int64Counter + framesTx sdkmetrics.Int64Counter + bytesRx sdkmetrics.Int64Counter + bytesTx sdkmetrics.Int64Counter + rstSent sdkmetrics.Int64Counter + rstRecv sdkmetrics.Int64Counter + sidsGauge sdkmetrics.Int64Gauge + + sids int64 +} + +func newTransportMetrics(h sdkmetrics.Handler) *transportMetrics { + m := &transportMetrics{ + framesRx: h.Int64Counter("rtun.transport.frames_rx_total", "transport frames received", sdkmetrics.Dimensionless), + framesTx: h.Int64Counter("rtun.transport.frames_tx_total", "transport frames sent", sdkmetrics.Dimensionless), + bytesRx: h.Int64Counter("rtun.transport.data_bytes_rx_total", "transport bytes received", sdkmetrics.Bytes), + bytesTx: h.Int64Counter("rtun.transport.data_bytes_tx_total", "transport bytes sent", sdkmetrics.Bytes), + rstSent: h.Int64Counter("rtun.transport.rst_sent_total", "RST frames sent by code", sdkmetrics.Dimensionless), + rstRecv: h.Int64Counter("rtun.transport.rst_recv_total", "RST frames received by code", sdkmetrics.Dimensionless), + sidsGauge: h.Int64Gauge("rtun.transport.sids_active", "active SIDs per session", sdkmetrics.Dimensionless), + } + // initialize gauge to 0 + m.sidsGauge.Observe(context.Background(), 0, nil) + return m +} + +func (m *transportMetrics) incSidsActive(ctx context.Context, delta int64) { + m.sids += delta + m.sidsGauge.Observe(ctx, m.sids, nil) +} + +func (m *transportMetrics) recordFrameRx(ctx context.Context, kind string) { + m.framesRx.Add(ctx, 1, map[string]string{"kind": kind}) +} + +func (m *transportMetrics) recordFrameTx(ctx context.Context, kind string) { + m.framesTx.Add(ctx, 1, map[string]string{"kind": kind}) +} + +func (m *transportMetrics) recordBytesRx(ctx context.Context, n int64) { + m.bytesRx.Add(ctx, n, nil) +} + +func (m *transportMetrics) recordBytesTx(ctx context.Context, n int64) { + m.bytesTx.Add(ctx, n, nil) +} + +func (m *transportMetrics) recordRstSent(ctx context.Context, code string) { + m.rstSent.Add(ctx, 1, map[string]string{"code": code}) +} + +func (m *transportMetrics) recordRstRecv(ctx context.Context, code string) { + m.rstRecv.Add(ctx, 1, map[string]string{"code": code}) +} + +func kindOf(fr *rtunpb.Frame) string { + switch fr.Kind.(type) { + case *rtunpb.Frame_Hello: + return "HELLO" + case *rtunpb.Frame_Syn: + return "SYN" + case *rtunpb.Frame_Data: + return "DATA" + case *rtunpb.Frame_Fin: + return "FIN" + case *rtunpb.Frame_Rst: + return "RST" + default: + return "UNKNOWN" + } +} diff --git a/pkg/rtun/transport/session_race_test.go b/pkg/rtun/transport/session_race_test.go new file mode 100644 index 000000000..093c752ba --- /dev/null +++ b/pkg/rtun/transport/session_race_test.go @@ -0,0 +1,96 @@ +package transport + +import ( + "context" + "sync" + "testing" + "time" + + rtunpb "github.com/conductorone/baton-sdk/pb/c1/connectorapi/rtun/v1" +) + +// TestConcurrentOperationsNoPanic validates no races/panics under concurrent use. +// Run with -race to detect data races. +func TestConcurrentOperationsNoPanic(t *testing.T) { + tl := newTestLink() + s := NewSession(tl) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + var wg sync.WaitGroup + + // Concurrent Listen on ports 1-5 + for port := uint32(1); port <= 5; port++ { + wg.Add(1) + go func(p uint32) { + defer wg.Done() + ln, err := s.Listen(ctx, p) + if err == nil { + time.Sleep(10 * time.Millisecond) + ln.Close() + } + }(port) + } + + // Concurrent Open (reverse dial) + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + conn, err := s.Open(ctx, 1) + if err == nil { + time.Sleep(5 * time.Millisecond) + conn.Close() + } + }() + } + + // Feed random frames + for i := 0; i < 20; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + sid := uint32(100 + i) + tl.push(&rtunpb.Frame{Sid: sid, Kind: &rtunpb.Frame_Syn{Syn: &rtunpb.Syn{Port: 1}}}) + time.Sleep(2 * time.Millisecond) + tl.push(&rtunpb.Frame{Sid: sid, Kind: &rtunpb.Frame_Data{Data: &rtunpb.Data{Payload: []byte("x")}}}) + }(i) + } + + wg.Wait() + // No panic or deadlock = success +} + +func TestLateFrameAfterClose(t *testing.T) { + tl := newTestLink() + s := NewSession(tl) + ctx := context.Background() + + ln, _ := s.Listen(ctx, 1) + defer ln.Close() + + // Establish conn + tl.push(&rtunpb.Frame{Sid: 99, Kind: &rtunpb.Frame_Syn{Syn: &rtunpb.Syn{Port: 1}}}) + conn, _ := ln.Accept() + conn.Close() + + // Give close time to propagate + time.Sleep(20 * time.Millisecond) + + // Send late DATA; should not panic + tl.push(&rtunpb.Frame{Sid: 99, Kind: &rtunpb.Frame_Data{Data: &rtunpb.Data{Payload: []byte("late")}}}) + time.Sleep(10 * time.Millisecond) + // No panic = success +} + +func TestContextCanceledOpen(t *testing.T) { + tl := newTestLink() + s := NewSession(tl) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + _, err := s.Open(ctx, 1) + if err == nil { + t.Fatal("expected error on canceled context") + } +} diff --git a/pkg/rtun/transport/session_test.go b/pkg/rtun/transport/session_test.go new file mode 100644 index 000000000..7abe62794 --- /dev/null +++ b/pkg/rtun/transport/session_test.go @@ -0,0 +1,308 @@ +package transport + +import ( + "context" + "io" + "net" + "testing" + "time" + + rtunpb "github.com/conductorone/baton-sdk/pb/c1/connectorapi/rtun/v1" + "github.com/stretchr/testify/require" +) + +type testLink struct { + ctx context.Context + inCh chan *rtunpb.Frame + sent []*rtunpb.Frame +} + +func newTestLink() *testLink { + return &testLink{ctx: context.Background(), inCh: make(chan *rtunpb.Frame, 32)} +} + +func (t *testLink) Send(f *rtunpb.Frame) error { + t.sent = append(t.sent, f) + return nil +} + +func (t *testLink) Recv() (*rtunpb.Frame, error) { return <-t.inCh, nil } +func (t *testLink) Context() context.Context { return t.ctx } + +func (t *testLink) push(f *rtunpb.Frame) { t.inCh <- f } + +func TestSessionSynAcceptDataFin(t *testing.T) { + tl := newTestLink() + s := NewSession(tl) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + ln, err := s.Listen(ctx, 1) + require.NoError(t, err) + defer ln.Close() + + // Accept in background + accCh := make(chan net.Conn, 1) + go func() { + c, err := ln.Accept() + if err == nil { + accCh <- c + } + }() + + // Send SYN for port 1 on sid 5 + tl.push(&rtunpb.Frame{Sid: 5, Kind: &rtunpb.Frame_Syn{Syn: &rtunpb.Syn{Port: 1}}}) + var rwc net.Conn + select { + case rwc = <-accCh: + case <-ctx.Done(): + t.Fatal("accept timeout") + } + + // Deliver DATA from link to conn + tl.push(&rtunpb.Frame{Sid: 5, Kind: &rtunpb.Frame_Data{Data: &rtunpb.Data{Payload: []byte("hi")}}}) + buf := make([]byte, 2) + n, err := rwc.Read(buf) + require.NoError(t, err) + require.Equal(t, 2, n) + require.Equal(t, []byte("hi"), buf) + + // Write from conn to link; expect a DATA frame sent + _, err = rwc.Write([]byte("ok")) + require.NoError(t, err) + require.NotEmpty(t, tl.sent) + last := tl.sent[len(tl.sent)-1] + require.Equal(t, uint32(5), last.GetSid()) + require.IsType(t, &rtunpb.Frame_Data{}, last.Kind) + require.Equal(t, []byte("ok"), last.GetData().GetPayload()) + + // Send FIN from link; expect EOF on next read + tl.push(&rtunpb.Frame{Sid: 5, Kind: &rtunpb.Frame_Fin{Fin: &rtunpb.Fin{Ack: false}}}) + _, err = rwc.Read(make([]byte, 1)) + require.ErrorIs(t, err, io.EOF) + + // Close should send FIN + _ = rwc.Close() + require.NotEmpty(t, tl.sent) + foundFin := false + for _, f := range tl.sent { + if f.GetSid() == 5 { + if _, ok := f.Kind.(*rtunpb.Frame_Fin); ok { + foundFin = true + break + } + } + } + require.True(t, foundFin) +} + +func TestSessionSynNoListenerRst(t *testing.T) { + tl := newTestLink() + s := NewSession(tl) + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + + // Start session loop without listener by poking it: creating a throwaway listener on port 1 and closing is too involved, + // so instead we trigger start by creating and closing a short-lived listener which ensures the goroutine is running. + ln, err := s.Listen(ctx, 1) + require.NoError(t, err) + ln.Close() + + // Send SYN for a port without a listener + tl.push(&rtunpb.Frame{Sid: 9, Kind: &rtunpb.Frame_Syn{Syn: &rtunpb.Syn{Port: 2}}}) + // Allow loop to process + time.Sleep(10 * time.Millisecond) + require.NotEmpty(t, tl.sent) + last := tl.sent[len(tl.sent)-1] + require.Equal(t, uint32(9), last.GetSid()) + require.IsType(t, &rtunpb.Frame_Rst{}, last.Kind) + require.Equal(t, rtunpb.RstCode_RST_CODE_NO_LISTENER, last.GetRst().GetCode()) +} + +func TestVirtConnReadDeadline(t *testing.T) { + tl := newTestLink() + s := NewSession(tl) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + ln, err := s.Listen(ctx, 1) + require.NoError(t, err) + defer ln.Close() + + // Accept in background + accCh := make(chan net.Conn, 1) + go func() { + c, _ := ln.Accept() + accCh <- c + }() + + // Establish + tl.push(&rtunpb.Frame{Sid: 7, Kind: &rtunpb.Frame_Syn{Syn: &rtunpb.Syn{Port: 1}}}) + var rwc net.Conn + select { + case rwc = <-accCh: + case <-ctx.Done(): + t.Fatal("accept timeout") + } + + // Set a short read deadline and expect timeout + _ = rwc.SetReadDeadline(time.Now().Add(10 * time.Millisecond)) + buf := make([]byte, 1) + _, err = rwc.Read(buf) + require.ErrorIs(t, err, ErrTimeout) +} + +func TestVirtConnHalfCloseWriteAfterRemoteFin(t *testing.T) { + tl := newTestLink() + s := NewSession(tl) + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + ln, err := s.Listen(ctx, 1) + require.NoError(t, err) + defer ln.Close() + + accCh := make(chan net.Conn, 1) + go func() { + c, _ := ln.Accept() + accCh <- c + }() + tl.push(&rtunpb.Frame{Sid: 11, Kind: &rtunpb.Frame_Syn{Syn: &rtunpb.Syn{Port: 1}}}) + var rwc net.Conn + select { + case rwc = <-accCh: + case <-ctx.Done(): + t.Fatal("accept timeout") + } + // Remote sends FIN; we should still be able to Write() + tl.push(&rtunpb.Frame{Sid: 11, Kind: &rtunpb.Frame_Fin{Fin: &rtunpb.Fin{Ack: false}}}) + _, err = rwc.Write([]byte("after")) + require.NoError(t, err) +} + +func TestWriteFragmentationIntoMultipleDataFrames(t *testing.T) { + tl := newTestLink() + s := NewSession(tl) + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + ln, err := s.Listen(ctx, 1) + require.NoError(t, err) + defer ln.Close() + + accCh := make(chan net.Conn, 1) + go func() { c, _ := ln.Accept(); accCh <- c }() + tl.push(&rtunpb.Frame{Sid: 21, Kind: &rtunpb.Frame_Syn{Syn: &rtunpb.Syn{Port: 1}}}) + var rwc net.Conn + select { + case rwc = <-accCh: + case <-ctx.Done(): + t.Fatal("accept timeout") + } + + // Prepare a payload larger than maxWriteChunk (32KiB) to force fragmentation. + big := make([]byte, 32*1024+4096) + for i := range big { + big[i] = byte(i) + } + prev := len(tl.sent) + n, err := rwc.Write(big) + require.NoError(t, err) + require.Equal(t, len(big), n) + + // Collect new DATA frames for this SID and reassemble. + var collected []byte + for _, f := range tl.sent[prev:] { + if f.GetSid() == 21 { + if d := f.GetData(); d != nil { + collected = append(collected, d.GetPayload()...) + } + } + } + require.Equal(t, big, collected) +} + +func TestReadRemainderAcrossReads(t *testing.T) { + tl := newTestLink() + s := NewSession(tl) + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + ln, err := s.Listen(ctx, 1) + require.NoError(t, err) + defer ln.Close() + + accCh := make(chan net.Conn, 1) + go func() { c, _ := ln.Accept(); accCh <- c }() + tl.push(&rtunpb.Frame{Sid: 22, Kind: &rtunpb.Frame_Syn{Syn: &rtunpb.Syn{Port: 1}}}) + var rwc net.Conn + select { + case rwc = <-accCh: + case <-ctx.Done(): + t.Fatal("accept timeout") + } + + // Deliver 3 bytes, then read 2, then read 1 + tl.push(&rtunpb.Frame{Sid: 22, Kind: &rtunpb.Frame_Data{Data: &rtunpb.Data{Payload: []byte("xyz")}}}) + buf := make([]byte, 2) + n, err := rwc.Read(buf) + require.NoError(t, err) + require.Equal(t, 2, n) + require.Equal(t, []byte("xy"), buf) + buf2 := make([]byte, 2) + n, err = rwc.Read(buf2) + require.NoError(t, err) + require.Equal(t, 1, n) + require.Equal(t, []byte("z\x00"), buf2) // last byte remains zero +} + +func TestBackpressureInboundOverflow(t *testing.T) { + tl := newTestLink() + s := NewSession(tl) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + ln, err := s.Listen(ctx, 1) + require.NoError(t, err) + defer ln.Close() + + accCh := make(chan net.Conn, 1) + go func() { c, _ := ln.Accept(); accCh <- c }() + tl.push(&rtunpb.Frame{Sid: 23, Kind: &rtunpb.Frame_Syn{Syn: &rtunpb.Syn{Port: 1}}}) + var rwc net.Conn + select { + case rwc = <-accCh: + case <-ctx.Done(): + t.Fatal("accept timeout") + } + + // Push more frames than the read buffer can hold (capacity 16) without reading. + for i := 0; i < 20; i++ { + tl.push(&rtunpb.Frame{Sid: 23, Kind: &rtunpb.Frame_Data{Data: &rtunpb.Data{Payload: []byte{byte(i)}}}}) + } + // Give the loop a moment to overflow and close. + time.Sleep(20 * time.Millisecond) + buf := make([]byte, 1) + _, err = rwc.Read(buf) + require.Error(t, err) + require.NotEqual(t, io.EOF, err) +} + +func TestWriteDeadlineTimeout(t *testing.T) { + tl := newTestLink() + s := NewSession(tl) + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + ln, err := s.Listen(ctx, 1) + require.NoError(t, err) + defer ln.Close() + + accCh := make(chan net.Conn, 1) + go func() { c, _ := ln.Accept(); accCh <- c }() + tl.push(&rtunpb.Frame{Sid: 24, Kind: &rtunpb.Frame_Syn{Syn: &rtunpb.Syn{Port: 1}}}) + var rwc net.Conn + select { + case rwc = <-accCh: + case <-ctx.Done(): + t.Fatal("accept timeout") + } + // Set deadline in the past and attempt write. + _ = rwc.SetWriteDeadline(time.Now().Add(-1 * time.Millisecond)) + _, err = rwc.Write([]byte("x")) + require.ErrorIs(t, err, ErrTimeout) +} diff --git a/proto/c1/connectorapi/rtun/v1/gateway.proto b/proto/c1/connectorapi/rtun/v1/gateway.proto new file mode 100644 index 000000000..88cca67fd --- /dev/null +++ b/proto/c1/connectorapi/rtun/v1/gateway.proto @@ -0,0 +1,49 @@ +syntax = "proto3"; + +package c1.connectorapi.rtun.v1; + +import "c1/connectorapi/rtun/v1/rtun.proto"; + +option go_package = "github.com/conductorone/baton-sdk/pb/c1/connectorapi/rtun/v1"; + +// ReverseDialer allows callers to establish connections to clients via the gateway. +// The gateway bridges caller streams to rtun sessions on the owner server. +service ReverseDialer { + rpc Open(stream GatewayRequest) returns (stream GatewayResponse); +} + +// GatewayRequest is sent from caller to gateway. +message GatewayRequest { + oneof kind { + OpenRequest open_req = 1; // initiate a connection (first message, or concurrent opens) + Frame frame = 10; // data/control frames (reuses rtun Frame) + } +} + +// GatewayResponse is sent from gateway to caller. +message GatewayResponse { + oneof kind { + OpenResponse open_resp = 1; // handshake result + Frame frame = 10; // data/control frames (reuses rtun Frame) + } +} + +// OpenRequest initiates a reverse connection to a client. +// The caller proposes a gSID (gateway SID) for this connection to support concurrent opens. +message OpenRequest { + uint32 gsid = 1; // caller-proposed SID for this connection (must be unique per stream) + string client_id = 2; // target client (must be URL-safe) + uint32 port = 3; // target port on the client +} + +// OpenResponse indicates the result of an OpenRequest. +message OpenResponse { + uint32 gsid = 1; // echoed from OpenRequest + oneof result { + NotFound not_found = 2; // gateway doesn't own this client; caller should re-resolve + Opened opened = 3; // success; use the gSID for subsequent frames + } +} + +message NotFound {} +message Opened {} diff --git a/proto/c1/connectorapi/rtun/v1/rtun.proto b/proto/c1/connectorapi/rtun/v1/rtun.proto new file mode 100644 index 000000000..922d94b87 --- /dev/null +++ b/proto/c1/connectorapi/rtun/v1/rtun.proto @@ -0,0 +1,49 @@ +syntax = "proto3"; + +package c1.connectorapi.rtun.v1; + +option go_package = "github.com/conductorone/baton-sdk/pb/c1/connectorapi/rtun/v1"; + +service ReverseTunnel { + rpc Link(stream Frame) returns (stream Frame); +} + +message Frame { + uint32 sid = 1; // 0 reserved for control + oneof kind { + Hello hello = 10; // client -> server (first) + Syn syn = 11; // server -> client (reverse open) + Data data = 12; // either direction + Fin fin = 13; // either direction + Rst rst = 14; // either direction + } +} + +message Hello { + repeated uint32 ports = 1; + uint32 protocol = 2; +} + +message Syn { + uint32 port = 1; +} + +message Data { + bytes payload = 1; +} + +message Fin { + bool ack = 1; +} + +enum RstCode { + RST_CODE_UNSPECIFIED = 0; + RST_CODE_NO_LISTENER = 1; + RST_CODE_PORT_NOT_ADVERTISED = 2; + RST_CODE_TIMEOUT = 3; + RST_CODE_INTERNAL = 4; +} + +message Rst { + RstCode code = 1; +} diff --git a/vendor/google.golang.org/grpc/health/client.go b/vendor/google.golang.org/grpc/health/client.go new file mode 100644 index 000000000..740745c45 --- /dev/null +++ b/vendor/google.golang.org/grpc/health/client.go @@ -0,0 +1,117 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package health + +import ( + "context" + "fmt" + "io" + "time" + + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/connectivity" + healthpb "google.golang.org/grpc/health/grpc_health_v1" + "google.golang.org/grpc/internal" + "google.golang.org/grpc/internal/backoff" + "google.golang.org/grpc/status" +) + +var ( + backoffStrategy = backoff.DefaultExponential + backoffFunc = func(ctx context.Context, retries int) bool { + d := backoffStrategy.Backoff(retries) + timer := time.NewTimer(d) + select { + case <-timer.C: + return true + case <-ctx.Done(): + timer.Stop() + return false + } + } +) + +func init() { + internal.HealthCheckFunc = clientHealthCheck +} + +const healthCheckMethod = "/grpc.health.v1.Health/Watch" + +// This function implements the protocol defined at: +// https://github.com/grpc/grpc/blob/master/doc/health-checking.md +func clientHealthCheck(ctx context.Context, newStream func(string) (any, error), setConnectivityState func(connectivity.State, error), service string) error { + tryCnt := 0 + +retryConnection: + for { + // Backs off if the connection has failed in some way without receiving a message in the previous retry. + if tryCnt > 0 && !backoffFunc(ctx, tryCnt-1) { + return nil + } + tryCnt++ + + if ctx.Err() != nil { + return nil + } + setConnectivityState(connectivity.Connecting, nil) + rawS, err := newStream(healthCheckMethod) + if err != nil { + continue retryConnection + } + + s, ok := rawS.(grpc.ClientStream) + // Ideally, this should never happen. But if it happens, the server is marked as healthy for LBing purposes. + if !ok { + setConnectivityState(connectivity.Ready, nil) + return fmt.Errorf("newStream returned %v (type %T); want grpc.ClientStream", rawS, rawS) + } + + if err = s.SendMsg(&healthpb.HealthCheckRequest{Service: service}); err != nil && err != io.EOF { + // Stream should have been closed, so we can safely continue to create a new stream. + continue retryConnection + } + s.CloseSend() + + resp := new(healthpb.HealthCheckResponse) + for { + err = s.RecvMsg(resp) + + // Reports healthy for the LBing purposes if health check is not implemented in the server. + if status.Code(err) == codes.Unimplemented { + setConnectivityState(connectivity.Ready, nil) + return err + } + + // Reports unhealthy if server's Watch method gives an error other than UNIMPLEMENTED. + if err != nil { + setConnectivityState(connectivity.TransientFailure, fmt.Errorf("connection active but received health check RPC error: %v", err)) + continue retryConnection + } + + // As a message has been received, removes the need for backoff for the next retry by resetting the try count. + tryCnt = 0 + if resp.Status == healthpb.HealthCheckResponse_SERVING { + setConnectivityState(connectivity.Ready, nil) + } else { + setConnectivityState(connectivity.TransientFailure, fmt.Errorf("connection active but health check failed. status=%s", resp.Status)) + } + } + } +} diff --git a/vendor/google.golang.org/grpc/health/logging.go b/vendor/google.golang.org/grpc/health/logging.go new file mode 100644 index 000000000..83c6acf55 --- /dev/null +++ b/vendor/google.golang.org/grpc/health/logging.go @@ -0,0 +1,23 @@ +/* + * + * Copyright 2020 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package health + +import "google.golang.org/grpc/grpclog" + +var logger = grpclog.Component("health_service") diff --git a/vendor/google.golang.org/grpc/health/producer.go b/vendor/google.golang.org/grpc/health/producer.go new file mode 100644 index 000000000..f938e5790 --- /dev/null +++ b/vendor/google.golang.org/grpc/health/producer.go @@ -0,0 +1,106 @@ +/* + * + * Copyright 2024 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package health + +import ( + "context" + "sync" + + "google.golang.org/grpc" + "google.golang.org/grpc/balancer" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/connectivity" + "google.golang.org/grpc/internal" + "google.golang.org/grpc/status" +) + +func init() { + producerBuilderSingleton = &producerBuilder{} + internal.RegisterClientHealthCheckListener = registerClientSideHealthCheckListener +} + +type producerBuilder struct{} + +var producerBuilderSingleton *producerBuilder + +// Build constructs and returns a producer and its cleanup function. +func (*producerBuilder) Build(cci any) (balancer.Producer, func()) { + p := &healthServiceProducer{ + cc: cci.(grpc.ClientConnInterface), + cancel: func() {}, + } + return p, func() { + p.mu.Lock() + defer p.mu.Unlock() + p.cancel() + } +} + +type healthServiceProducer struct { + // The following fields are initialized at build time and read-only after + // that and therefore do not need to be guarded by a mutex. + cc grpc.ClientConnInterface + + mu sync.Mutex + cancel func() +} + +// registerClientSideHealthCheckListener accepts a listener to provide server +// health state via the health service. +func registerClientSideHealthCheckListener(ctx context.Context, sc balancer.SubConn, serviceName string, listener func(balancer.SubConnState)) func() { + pr, closeFn := sc.GetOrBuildProducer(producerBuilderSingleton) + p := pr.(*healthServiceProducer) + p.mu.Lock() + defer p.mu.Unlock() + p.cancel() + if listener == nil { + return closeFn + } + + ctx, cancel := context.WithCancel(ctx) + p.cancel = cancel + + go p.startHealthCheck(ctx, sc, serviceName, listener) + return closeFn +} + +func (p *healthServiceProducer) startHealthCheck(ctx context.Context, sc balancer.SubConn, serviceName string, listener func(balancer.SubConnState)) { + newStream := func(method string) (any, error) { + return p.cc.NewStream(ctx, &grpc.StreamDesc{ServerStreams: true}, method) + } + + setConnectivityState := func(state connectivity.State, err error) { + listener(balancer.SubConnState{ + ConnectivityState: state, + ConnectionError: err, + }) + } + + // Call the function through the internal variable as tests use it for + // mocking. + err := internal.HealthCheckFunc(ctx, newStream, setConnectivityState, serviceName) + if err == nil { + return + } + if status.Code(err) == codes.Unimplemented { + logger.Errorf("Subchannel health check is unimplemented at server side, thus health check is disabled for SubConn %p", sc) + } else { + logger.Errorf("Health checking failed for SubConn %p: %v", sc, err) + } +} diff --git a/vendor/google.golang.org/grpc/health/server.go b/vendor/google.golang.org/grpc/health/server.go new file mode 100644 index 000000000..d4b4b7081 --- /dev/null +++ b/vendor/google.golang.org/grpc/health/server.go @@ -0,0 +1,163 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package health provides a service that exposes server's health and it must be +// imported to enable support for client-side health checks. +package health + +import ( + "context" + "sync" + + "google.golang.org/grpc/codes" + healthgrpc "google.golang.org/grpc/health/grpc_health_v1" + healthpb "google.golang.org/grpc/health/grpc_health_v1" + "google.golang.org/grpc/status" +) + +// Server implements `service Health`. +type Server struct { + healthgrpc.UnimplementedHealthServer + mu sync.RWMutex + // If shutdown is true, it's expected all serving status is NOT_SERVING, and + // will stay in NOT_SERVING. + shutdown bool + // statusMap stores the serving status of the services this Server monitors. + statusMap map[string]healthpb.HealthCheckResponse_ServingStatus + updates map[string]map[healthgrpc.Health_WatchServer]chan healthpb.HealthCheckResponse_ServingStatus +} + +// NewServer returns a new Server. +func NewServer() *Server { + return &Server{ + statusMap: map[string]healthpb.HealthCheckResponse_ServingStatus{"": healthpb.HealthCheckResponse_SERVING}, + updates: make(map[string]map[healthgrpc.Health_WatchServer]chan healthpb.HealthCheckResponse_ServingStatus), + } +} + +// Check implements `service Health`. +func (s *Server) Check(_ context.Context, in *healthpb.HealthCheckRequest) (*healthpb.HealthCheckResponse, error) { + s.mu.RLock() + defer s.mu.RUnlock() + if servingStatus, ok := s.statusMap[in.Service]; ok { + return &healthpb.HealthCheckResponse{ + Status: servingStatus, + }, nil + } + return nil, status.Error(codes.NotFound, "unknown service") +} + +// Watch implements `service Health`. +func (s *Server) Watch(in *healthpb.HealthCheckRequest, stream healthgrpc.Health_WatchServer) error { + service := in.Service + // update channel is used for getting service status updates. + update := make(chan healthpb.HealthCheckResponse_ServingStatus, 1) + s.mu.Lock() + // Puts the initial status to the channel. + if servingStatus, ok := s.statusMap[service]; ok { + update <- servingStatus + } else { + update <- healthpb.HealthCheckResponse_SERVICE_UNKNOWN + } + + // Registers the update channel to the correct place in the updates map. + if _, ok := s.updates[service]; !ok { + s.updates[service] = make(map[healthgrpc.Health_WatchServer]chan healthpb.HealthCheckResponse_ServingStatus) + } + s.updates[service][stream] = update + defer func() { + s.mu.Lock() + delete(s.updates[service], stream) + s.mu.Unlock() + }() + s.mu.Unlock() + + var lastSentStatus healthpb.HealthCheckResponse_ServingStatus = -1 + for { + select { + // Status updated. Sends the up-to-date status to the client. + case servingStatus := <-update: + if lastSentStatus == servingStatus { + continue + } + lastSentStatus = servingStatus + err := stream.Send(&healthpb.HealthCheckResponse{Status: servingStatus}) + if err != nil { + return status.Error(codes.Canceled, "Stream has ended.") + } + // Context done. Removes the update channel from the updates map. + case <-stream.Context().Done(): + return status.Error(codes.Canceled, "Stream has ended.") + } + } +} + +// SetServingStatus is called when need to reset the serving status of a service +// or insert a new service entry into the statusMap. +func (s *Server) SetServingStatus(service string, servingStatus healthpb.HealthCheckResponse_ServingStatus) { + s.mu.Lock() + defer s.mu.Unlock() + if s.shutdown { + logger.Infof("health: status changing for %s to %v is ignored because health service is shutdown", service, servingStatus) + return + } + + s.setServingStatusLocked(service, servingStatus) +} + +func (s *Server) setServingStatusLocked(service string, servingStatus healthpb.HealthCheckResponse_ServingStatus) { + s.statusMap[service] = servingStatus + for _, update := range s.updates[service] { + // Clears previous updates, that are not sent to the client, from the channel. + // This can happen if the client is not reading and the server gets flow control limited. + select { + case <-update: + default: + } + // Puts the most recent update to the channel. + update <- servingStatus + } +} + +// Shutdown sets all serving status to NOT_SERVING, and configures the server to +// ignore all future status changes. +// +// This changes serving status for all services. To set status for a particular +// services, call SetServingStatus(). +func (s *Server) Shutdown() { + s.mu.Lock() + defer s.mu.Unlock() + s.shutdown = true + for service := range s.statusMap { + s.setServingStatusLocked(service, healthpb.HealthCheckResponse_NOT_SERVING) + } +} + +// Resume sets all serving status to SERVING, and configures the server to +// accept all future status changes. +// +// This changes serving status for all services. To set status for a particular +// services, call SetServingStatus(). +func (s *Server) Resume() { + s.mu.Lock() + defer s.mu.Unlock() + s.shutdown = false + for service := range s.statusMap { + s.setServingStatusLocked(service, healthpb.HealthCheckResponse_SERVING) + } +} diff --git a/vendor/modules.txt b/vendor/modules.txt index 299b72801..d7cad6e22 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -623,6 +623,7 @@ google.golang.org/grpc/encoding/proto google.golang.org/grpc/experimental/stats google.golang.org/grpc/grpclog google.golang.org/grpc/grpclog/internal +google.golang.org/grpc/health google.golang.org/grpc/health/grpc_health_v1 google.golang.org/grpc/internal google.golang.org/grpc/internal/backoff From 65d965231bc572b876352baa261718c0e5af9823 Mon Sep 17 00:00:00 2001 From: Paul Querna Date: Sun, 12 Oct 2025 19:14:43 -0700 Subject: [PATCH 2/9] review feedback --- pb/c1/connectorapi/rtun/v1/gateway.pb.go | 226 +++++++------- .../rtun/v1/gateway.pb.validate.go | 164 +++++----- pb/c1/connectorapi/rtun/v1/gateway_grpc.pb.go | 68 ++--- pb/c1/connectorapi/rtun/v1/rtun.pb.go | 279 ++++++++++++------ .../connectorapi/rtun/v1/rtun.pb.validate.go | 266 +++++++++++++++++ pb/c1/connectorapi/rtun/v1/rtun_grpc.pb.go | 68 ++--- pkg/rtun/gateway/client.go | 20 +- pkg/rtun/gateway/client_conn_test.go | 6 +- pkg/rtun/gateway/integration_test.go | 26 +- pkg/rtun/gateway/server.go | 29 +- pkg/rtun/match/route_test.go | 22 +- pkg/rtun/server/handler.go | 24 +- pkg/rtun/server/server_integration_test.go | 22 +- pkg/rtun/transport/conn.go | 2 +- pkg/rtun/transport/conn_test.go | 4 +- pkg/rtun/transport/errors.go | 1 - pkg/rtun/transport/listener.go | 18 +- pkg/rtun/transport/session.go | 4 +- proto/c1/connectorapi/rtun/v1/gateway.proto | 12 +- proto/c1/connectorapi/rtun/v1/rtun.proto | 12 +- 20 files changed, 856 insertions(+), 417 deletions(-) diff --git a/pb/c1/connectorapi/rtun/v1/gateway.pb.go b/pb/c1/connectorapi/rtun/v1/gateway.pb.go index 25f0ec02b..fb508fc42 100644 --- a/pb/c1/connectorapi/rtun/v1/gateway.pb.go +++ b/pb/c1/connectorapi/rtun/v1/gateway.pb.go @@ -21,32 +21,32 @@ const ( _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) ) -// GatewayRequest is sent from caller to gateway. -type GatewayRequest struct { +// ReverseDialerServiceOpenRequest is sent from caller to gateway for the Open RPC. +type ReverseDialerServiceOpenRequest struct { state protoimpl.MessageState `protogen:"open.v1"` // Types that are valid to be assigned to Kind: // - // *GatewayRequest_OpenReq - // *GatewayRequest_Frame - Kind isGatewayRequest_Kind `protobuf_oneof:"kind"` + // *ReverseDialerServiceOpenRequest_OpenReq + // *ReverseDialerServiceOpenRequest_Frame + Kind isReverseDialerServiceOpenRequest_Kind `protobuf_oneof:"kind"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } -func (x *GatewayRequest) Reset() { - *x = GatewayRequest{} +func (x *ReverseDialerServiceOpenRequest) Reset() { + *x = ReverseDialerServiceOpenRequest{} mi := &file_c1_connectorapi_rtun_v1_gateway_proto_msgTypes[0] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } -func (x *GatewayRequest) String() string { +func (x *ReverseDialerServiceOpenRequest) String() string { return protoimpl.X.MessageStringOf(x) } -func (*GatewayRequest) ProtoMessage() {} +func (*ReverseDialerServiceOpenRequest) ProtoMessage() {} -func (x *GatewayRequest) ProtoReflect() protoreflect.Message { +func (x *ReverseDialerServiceOpenRequest) ProtoReflect() protoreflect.Message { mi := &file_c1_connectorapi_rtun_v1_gateway_proto_msgTypes[0] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -58,78 +58,78 @@ func (x *GatewayRequest) ProtoReflect() protoreflect.Message { return mi.MessageOf(x) } -// Deprecated: Use GatewayRequest.ProtoReflect.Descriptor instead. -func (*GatewayRequest) Descriptor() ([]byte, []int) { +// Deprecated: Use ReverseDialerServiceOpenRequest.ProtoReflect.Descriptor instead. +func (*ReverseDialerServiceOpenRequest) Descriptor() ([]byte, []int) { return file_c1_connectorapi_rtun_v1_gateway_proto_rawDescGZIP(), []int{0} } -func (x *GatewayRequest) GetKind() isGatewayRequest_Kind { +func (x *ReverseDialerServiceOpenRequest) GetKind() isReverseDialerServiceOpenRequest_Kind { if x != nil { return x.Kind } return nil } -func (x *GatewayRequest) GetOpenReq() *OpenRequest { +func (x *ReverseDialerServiceOpenRequest) GetOpenReq() *OpenRequest { if x != nil { - if x, ok := x.Kind.(*GatewayRequest_OpenReq); ok { + if x, ok := x.Kind.(*ReverseDialerServiceOpenRequest_OpenReq); ok { return x.OpenReq } } return nil } -func (x *GatewayRequest) GetFrame() *Frame { +func (x *ReverseDialerServiceOpenRequest) GetFrame() *Frame { if x != nil { - if x, ok := x.Kind.(*GatewayRequest_Frame); ok { + if x, ok := x.Kind.(*ReverseDialerServiceOpenRequest_Frame); ok { return x.Frame } } return nil } -type isGatewayRequest_Kind interface { - isGatewayRequest_Kind() +type isReverseDialerServiceOpenRequest_Kind interface { + isReverseDialerServiceOpenRequest_Kind() } -type GatewayRequest_OpenReq struct { +type ReverseDialerServiceOpenRequest_OpenReq struct { OpenReq *OpenRequest `protobuf:"bytes,1,opt,name=open_req,json=openReq,proto3,oneof"` // initiate a connection (first message, or concurrent opens) } -type GatewayRequest_Frame struct { +type ReverseDialerServiceOpenRequest_Frame struct { Frame *Frame `protobuf:"bytes,10,opt,name=frame,proto3,oneof"` // data/control frames (reuses rtun Frame) } -func (*GatewayRequest_OpenReq) isGatewayRequest_Kind() {} +func (*ReverseDialerServiceOpenRequest_OpenReq) isReverseDialerServiceOpenRequest_Kind() {} -func (*GatewayRequest_Frame) isGatewayRequest_Kind() {} +func (*ReverseDialerServiceOpenRequest_Frame) isReverseDialerServiceOpenRequest_Kind() {} -// GatewayResponse is sent from gateway to caller. -type GatewayResponse struct { +// ReverseDialerServiceOpenResponse is sent from gateway to caller for the Open RPC. +type ReverseDialerServiceOpenResponse struct { state protoimpl.MessageState `protogen:"open.v1"` // Types that are valid to be assigned to Kind: // - // *GatewayResponse_OpenResp - // *GatewayResponse_Frame - Kind isGatewayResponse_Kind `protobuf_oneof:"kind"` + // *ReverseDialerServiceOpenResponse_OpenResp + // *ReverseDialerServiceOpenResponse_Frame + Kind isReverseDialerServiceOpenResponse_Kind `protobuf_oneof:"kind"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } -func (x *GatewayResponse) Reset() { - *x = GatewayResponse{} +func (x *ReverseDialerServiceOpenResponse) Reset() { + *x = ReverseDialerServiceOpenResponse{} mi := &file_c1_connectorapi_rtun_v1_gateway_proto_msgTypes[1] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } -func (x *GatewayResponse) String() string { +func (x *ReverseDialerServiceOpenResponse) String() string { return protoimpl.X.MessageStringOf(x) } -func (*GatewayResponse) ProtoMessage() {} +func (*ReverseDialerServiceOpenResponse) ProtoMessage() {} -func (x *GatewayResponse) ProtoReflect() protoreflect.Message { +func (x *ReverseDialerServiceOpenResponse) ProtoReflect() protoreflect.Message { mi := &file_c1_connectorapi_rtun_v1_gateway_proto_msgTypes[1] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -141,51 +141,51 @@ func (x *GatewayResponse) ProtoReflect() protoreflect.Message { return mi.MessageOf(x) } -// Deprecated: Use GatewayResponse.ProtoReflect.Descriptor instead. -func (*GatewayResponse) Descriptor() ([]byte, []int) { +// Deprecated: Use ReverseDialerServiceOpenResponse.ProtoReflect.Descriptor instead. +func (*ReverseDialerServiceOpenResponse) Descriptor() ([]byte, []int) { return file_c1_connectorapi_rtun_v1_gateway_proto_rawDescGZIP(), []int{1} } -func (x *GatewayResponse) GetKind() isGatewayResponse_Kind { +func (x *ReverseDialerServiceOpenResponse) GetKind() isReverseDialerServiceOpenResponse_Kind { if x != nil { return x.Kind } return nil } -func (x *GatewayResponse) GetOpenResp() *OpenResponse { +func (x *ReverseDialerServiceOpenResponse) GetOpenResp() *OpenResponse { if x != nil { - if x, ok := x.Kind.(*GatewayResponse_OpenResp); ok { + if x, ok := x.Kind.(*ReverseDialerServiceOpenResponse_OpenResp); ok { return x.OpenResp } } return nil } -func (x *GatewayResponse) GetFrame() *Frame { +func (x *ReverseDialerServiceOpenResponse) GetFrame() *Frame { if x != nil { - if x, ok := x.Kind.(*GatewayResponse_Frame); ok { + if x, ok := x.Kind.(*ReverseDialerServiceOpenResponse_Frame); ok { return x.Frame } } return nil } -type isGatewayResponse_Kind interface { - isGatewayResponse_Kind() +type isReverseDialerServiceOpenResponse_Kind interface { + isReverseDialerServiceOpenResponse_Kind() } -type GatewayResponse_OpenResp struct { +type ReverseDialerServiceOpenResponse_OpenResp struct { OpenResp *OpenResponse `protobuf:"bytes,1,opt,name=open_resp,json=openResp,proto3,oneof"` // handshake result } -type GatewayResponse_Frame struct { +type ReverseDialerServiceOpenResponse_Frame struct { Frame *Frame `protobuf:"bytes,10,opt,name=frame,proto3,oneof"` // data/control frames (reuses rtun Frame) } -func (*GatewayResponse_OpenResp) isGatewayResponse_Kind() {} +func (*ReverseDialerServiceOpenResponse_OpenResp) isReverseDialerServiceOpenResponse_Kind() {} -func (*GatewayResponse_Frame) isGatewayResponse_Kind() {} +func (*ReverseDialerServiceOpenResponse_Frame) isReverseDialerServiceOpenResponse_Kind() {} // OpenRequest initiates a reverse connection to a client. // The caller proposes a gSID (gateway SID) for this connection to support concurrent opens. @@ -421,55 +421,59 @@ var file_c1_connectorapi_rtun_v1_gateway_proto_rawDesc = string([]byte{ 0x65, 0x63, 0x74, 0x6f, 0x72, 0x61, 0x70, 0x69, 0x2e, 0x72, 0x74, 0x75, 0x6e, 0x2e, 0x76, 0x31, 0x1a, 0x22, 0x63, 0x31, 0x2f, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x61, 0x70, 0x69, 0x2f, 0x72, 0x74, 0x75, 0x6e, 0x2f, 0x76, 0x31, 0x2f, 0x72, 0x74, 0x75, 0x6e, 0x2e, 0x70, - 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x93, 0x01, 0x0a, 0x0e, 0x47, 0x61, 0x74, 0x65, 0x77, 0x61, 0x79, - 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x41, 0x0a, 0x08, 0x6f, 0x70, 0x65, 0x6e, 0x5f, - 0x72, 0x65, 0x71, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x24, 0x2e, 0x63, 0x31, 0x2e, 0x63, - 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x61, 0x70, 0x69, 0x2e, 0x72, 0x74, 0x75, 0x6e, - 0x2e, 0x76, 0x31, 0x2e, 0x4f, 0x70, 0x65, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x48, - 0x00, 0x52, 0x07, 0x6f, 0x70, 0x65, 0x6e, 0x52, 0x65, 0x71, 0x12, 0x36, 0x0a, 0x05, 0x66, 0x72, - 0x61, 0x6d, 0x65, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1e, 0x2e, 0x63, 0x31, 0x2e, 0x63, - 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x61, 0x70, 0x69, 0x2e, 0x72, 0x74, 0x75, 0x6e, - 0x2e, 0x76, 0x31, 0x2e, 0x46, 0x72, 0x61, 0x6d, 0x65, 0x48, 0x00, 0x52, 0x05, 0x66, 0x72, 0x61, - 0x6d, 0x65, 0x42, 0x06, 0x0a, 0x04, 0x6b, 0x69, 0x6e, 0x64, 0x22, 0x97, 0x01, 0x0a, 0x0f, 0x47, - 0x61, 0x74, 0x65, 0x77, 0x61, 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x44, - 0x0a, 0x09, 0x6f, 0x70, 0x65, 0x6e, 0x5f, 0x72, 0x65, 0x73, 0x70, 0x18, 0x01, 0x20, 0x01, 0x28, - 0x0b, 0x32, 0x25, 0x2e, 0x63, 0x31, 0x2e, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x6f, 0x72, + 0x72, 0x6f, 0x74, 0x6f, 0x22, 0xa4, 0x01, 0x0a, 0x1f, 0x52, 0x65, 0x76, 0x65, 0x72, 0x73, 0x65, + 0x44, 0x69, 0x61, 0x6c, 0x65, 0x72, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x4f, 0x70, 0x65, + 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x41, 0x0a, 0x08, 0x6f, 0x70, 0x65, 0x6e, + 0x5f, 0x72, 0x65, 0x71, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x24, 0x2e, 0x63, 0x31, 0x2e, + 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x61, 0x70, 0x69, 0x2e, 0x72, 0x74, 0x75, + 0x6e, 0x2e, 0x76, 0x31, 0x2e, 0x4f, 0x70, 0x65, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, + 0x48, 0x00, 0x52, 0x07, 0x6f, 0x70, 0x65, 0x6e, 0x52, 0x65, 0x71, 0x12, 0x36, 0x0a, 0x05, 0x66, + 0x72, 0x61, 0x6d, 0x65, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1e, 0x2e, 0x63, 0x31, 0x2e, + 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x61, 0x70, 0x69, 0x2e, 0x72, 0x74, 0x75, + 0x6e, 0x2e, 0x76, 0x31, 0x2e, 0x46, 0x72, 0x61, 0x6d, 0x65, 0x48, 0x00, 0x52, 0x05, 0x66, 0x72, + 0x61, 0x6d, 0x65, 0x42, 0x06, 0x0a, 0x04, 0x6b, 0x69, 0x6e, 0x64, 0x22, 0xa8, 0x01, 0x0a, 0x20, + 0x52, 0x65, 0x76, 0x65, 0x72, 0x73, 0x65, 0x44, 0x69, 0x61, 0x6c, 0x65, 0x72, 0x53, 0x65, 0x72, + 0x76, 0x69, 0x63, 0x65, 0x4f, 0x70, 0x65, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, + 0x12, 0x44, 0x0a, 0x09, 0x6f, 0x70, 0x65, 0x6e, 0x5f, 0x72, 0x65, 0x73, 0x70, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x0b, 0x32, 0x25, 0x2e, 0x63, 0x31, 0x2e, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, + 0x6f, 0x72, 0x61, 0x70, 0x69, 0x2e, 0x72, 0x74, 0x75, 0x6e, 0x2e, 0x76, 0x31, 0x2e, 0x4f, 0x70, + 0x65, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x48, 0x00, 0x52, 0x08, 0x6f, 0x70, + 0x65, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x12, 0x36, 0x0a, 0x05, 0x66, 0x72, 0x61, 0x6d, 0x65, 0x18, + 0x0a, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1e, 0x2e, 0x63, 0x31, 0x2e, 0x63, 0x6f, 0x6e, 0x6e, 0x65, + 0x63, 0x74, 0x6f, 0x72, 0x61, 0x70, 0x69, 0x2e, 0x72, 0x74, 0x75, 0x6e, 0x2e, 0x76, 0x31, 0x2e, + 0x46, 0x72, 0x61, 0x6d, 0x65, 0x48, 0x00, 0x52, 0x05, 0x66, 0x72, 0x61, 0x6d, 0x65, 0x42, 0x06, + 0x0a, 0x04, 0x6b, 0x69, 0x6e, 0x64, 0x22, 0x52, 0x0a, 0x0b, 0x4f, 0x70, 0x65, 0x6e, 0x52, 0x65, + 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x12, 0x0a, 0x04, 0x67, 0x73, 0x69, 0x64, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x0d, 0x52, 0x04, 0x67, 0x73, 0x69, 0x64, 0x12, 0x1b, 0x0a, 0x09, 0x63, 0x6c, 0x69, + 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x63, 0x6c, + 0x69, 0x65, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x03, + 0x20, 0x01, 0x28, 0x0d, 0x52, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x22, 0xa9, 0x01, 0x0a, 0x0c, 0x4f, + 0x70, 0x65, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x67, + 0x73, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x04, 0x67, 0x73, 0x69, 0x64, 0x12, + 0x40, 0x0a, 0x09, 0x6e, 0x6f, 0x74, 0x5f, 0x66, 0x6f, 0x75, 0x6e, 0x64, 0x18, 0x02, 0x20, 0x01, + 0x28, 0x0b, 0x32, 0x21, 0x2e, 0x63, 0x31, 0x2e, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x6f, + 0x72, 0x61, 0x70, 0x69, 0x2e, 0x72, 0x74, 0x75, 0x6e, 0x2e, 0x76, 0x31, 0x2e, 0x4e, 0x6f, 0x74, + 0x46, 0x6f, 0x75, 0x6e, 0x64, 0x48, 0x00, 0x52, 0x08, 0x6e, 0x6f, 0x74, 0x46, 0x6f, 0x75, 0x6e, + 0x64, 0x12, 0x39, 0x0a, 0x06, 0x6f, 0x70, 0x65, 0x6e, 0x65, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, + 0x0b, 0x32, 0x1f, 0x2e, 0x63, 0x31, 0x2e, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x61, 0x70, 0x69, 0x2e, 0x72, 0x74, 0x75, 0x6e, 0x2e, 0x76, 0x31, 0x2e, 0x4f, 0x70, 0x65, 0x6e, - 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x48, 0x00, 0x52, 0x08, 0x6f, 0x70, 0x65, 0x6e, - 0x52, 0x65, 0x73, 0x70, 0x12, 0x36, 0x0a, 0x05, 0x66, 0x72, 0x61, 0x6d, 0x65, 0x18, 0x0a, 0x20, - 0x01, 0x28, 0x0b, 0x32, 0x1e, 0x2e, 0x63, 0x31, 0x2e, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, - 0x6f, 0x72, 0x61, 0x70, 0x69, 0x2e, 0x72, 0x74, 0x75, 0x6e, 0x2e, 0x76, 0x31, 0x2e, 0x46, 0x72, - 0x61, 0x6d, 0x65, 0x48, 0x00, 0x52, 0x05, 0x66, 0x72, 0x61, 0x6d, 0x65, 0x42, 0x06, 0x0a, 0x04, - 0x6b, 0x69, 0x6e, 0x64, 0x22, 0x52, 0x0a, 0x0b, 0x4f, 0x70, 0x65, 0x6e, 0x52, 0x65, 0x71, 0x75, - 0x65, 0x73, 0x74, 0x12, 0x12, 0x0a, 0x04, 0x67, 0x73, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, - 0x0d, 0x52, 0x04, 0x67, 0x73, 0x69, 0x64, 0x12, 0x1b, 0x0a, 0x09, 0x63, 0x6c, 0x69, 0x65, 0x6e, - 0x74, 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x63, 0x6c, 0x69, 0x65, - 0x6e, 0x74, 0x49, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x03, 0x20, 0x01, - 0x28, 0x0d, 0x52, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x22, 0xa9, 0x01, 0x0a, 0x0c, 0x4f, 0x70, 0x65, - 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x67, 0x73, 0x69, - 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x04, 0x67, 0x73, 0x69, 0x64, 0x12, 0x40, 0x0a, - 0x09, 0x6e, 0x6f, 0x74, 0x5f, 0x66, 0x6f, 0x75, 0x6e, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, - 0x32, 0x21, 0x2e, 0x63, 0x31, 0x2e, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x61, - 0x70, 0x69, 0x2e, 0x72, 0x74, 0x75, 0x6e, 0x2e, 0x76, 0x31, 0x2e, 0x4e, 0x6f, 0x74, 0x46, 0x6f, - 0x75, 0x6e, 0x64, 0x48, 0x00, 0x52, 0x08, 0x6e, 0x6f, 0x74, 0x46, 0x6f, 0x75, 0x6e, 0x64, 0x12, - 0x39, 0x0a, 0x06, 0x6f, 0x70, 0x65, 0x6e, 0x65, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, - 0x1f, 0x2e, 0x63, 0x31, 0x2e, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x61, 0x70, - 0x69, 0x2e, 0x72, 0x74, 0x75, 0x6e, 0x2e, 0x76, 0x31, 0x2e, 0x4f, 0x70, 0x65, 0x6e, 0x65, 0x64, - 0x48, 0x00, 0x52, 0x06, 0x6f, 0x70, 0x65, 0x6e, 0x65, 0x64, 0x42, 0x08, 0x0a, 0x06, 0x72, 0x65, - 0x73, 0x75, 0x6c, 0x74, 0x22, 0x0a, 0x0a, 0x08, 0x4e, 0x6f, 0x74, 0x46, 0x6f, 0x75, 0x6e, 0x64, - 0x22, 0x08, 0x0a, 0x06, 0x4f, 0x70, 0x65, 0x6e, 0x65, 0x64, 0x32, 0x6e, 0x0a, 0x0d, 0x52, 0x65, - 0x76, 0x65, 0x72, 0x73, 0x65, 0x44, 0x69, 0x61, 0x6c, 0x65, 0x72, 0x12, 0x5d, 0x0a, 0x04, 0x4f, - 0x70, 0x65, 0x6e, 0x12, 0x27, 0x2e, 0x63, 0x31, 0x2e, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, - 0x6f, 0x72, 0x61, 0x70, 0x69, 0x2e, 0x72, 0x74, 0x75, 0x6e, 0x2e, 0x76, 0x31, 0x2e, 0x47, 0x61, - 0x74, 0x65, 0x77, 0x61, 0x79, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x28, 0x2e, 0x63, - 0x31, 0x2e, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x61, 0x70, 0x69, 0x2e, 0x72, - 0x74, 0x75, 0x6e, 0x2e, 0x76, 0x31, 0x2e, 0x47, 0x61, 0x74, 0x65, 0x77, 0x61, 0x79, 0x52, 0x65, - 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x28, 0x01, 0x30, 0x01, 0x42, 0x3e, 0x5a, 0x3c, 0x67, 0x69, - 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x63, 0x6f, 0x6e, 0x64, 0x75, 0x63, 0x74, - 0x6f, 0x72, 0x6f, 0x6e, 0x65, 0x2f, 0x62, 0x61, 0x74, 0x6f, 0x6e, 0x2d, 0x73, 0x64, 0x6b, 0x2f, - 0x70, 0x62, 0x2f, 0x63, 0x31, 0x2f, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x61, - 0x70, 0x69, 0x2f, 0x72, 0x74, 0x75, 0x6e, 0x2f, 0x76, 0x31, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, - 0x6f, 0x33, + 0x65, 0x64, 0x48, 0x00, 0x52, 0x06, 0x6f, 0x70, 0x65, 0x6e, 0x65, 0x64, 0x42, 0x08, 0x0a, 0x06, + 0x72, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x22, 0x0a, 0x0a, 0x08, 0x4e, 0x6f, 0x74, 0x46, 0x6f, 0x75, + 0x6e, 0x64, 0x22, 0x08, 0x0a, 0x06, 0x4f, 0x70, 0x65, 0x6e, 0x65, 0x64, 0x32, 0x97, 0x01, 0x0a, + 0x14, 0x52, 0x65, 0x76, 0x65, 0x72, 0x73, 0x65, 0x44, 0x69, 0x61, 0x6c, 0x65, 0x72, 0x53, 0x65, + 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x7f, 0x0a, 0x04, 0x4f, 0x70, 0x65, 0x6e, 0x12, 0x38, 0x2e, + 0x63, 0x31, 0x2e, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x61, 0x70, 0x69, 0x2e, + 0x72, 0x74, 0x75, 0x6e, 0x2e, 0x76, 0x31, 0x2e, 0x52, 0x65, 0x76, 0x65, 0x72, 0x73, 0x65, 0x44, + 0x69, 0x61, 0x6c, 0x65, 0x72, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x4f, 0x70, 0x65, 0x6e, + 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x39, 0x2e, 0x63, 0x31, 0x2e, 0x63, 0x6f, 0x6e, + 0x6e, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x61, 0x70, 0x69, 0x2e, 0x72, 0x74, 0x75, 0x6e, 0x2e, 0x76, + 0x31, 0x2e, 0x52, 0x65, 0x76, 0x65, 0x72, 0x73, 0x65, 0x44, 0x69, 0x61, 0x6c, 0x65, 0x72, 0x53, + 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x4f, 0x70, 0x65, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, + 0x73, 0x65, 0x28, 0x01, 0x30, 0x01, 0x42, 0x3e, 0x5a, 0x3c, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, + 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x63, 0x6f, 0x6e, 0x64, 0x75, 0x63, 0x74, 0x6f, 0x72, 0x6f, 0x6e, + 0x65, 0x2f, 0x62, 0x61, 0x74, 0x6f, 0x6e, 0x2d, 0x73, 0x64, 0x6b, 0x2f, 0x70, 0x62, 0x2f, 0x63, + 0x31, 0x2f, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x61, 0x70, 0x69, 0x2f, 0x72, + 0x74, 0x75, 0x6e, 0x2f, 0x76, 0x31, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, }) var ( @@ -486,23 +490,23 @@ func file_c1_connectorapi_rtun_v1_gateway_proto_rawDescGZIP() []byte { var file_c1_connectorapi_rtun_v1_gateway_proto_msgTypes = make([]protoimpl.MessageInfo, 6) var file_c1_connectorapi_rtun_v1_gateway_proto_goTypes = []any{ - (*GatewayRequest)(nil), // 0: c1.connectorapi.rtun.v1.GatewayRequest - (*GatewayResponse)(nil), // 1: c1.connectorapi.rtun.v1.GatewayResponse - (*OpenRequest)(nil), // 2: c1.connectorapi.rtun.v1.OpenRequest - (*OpenResponse)(nil), // 3: c1.connectorapi.rtun.v1.OpenResponse - (*NotFound)(nil), // 4: c1.connectorapi.rtun.v1.NotFound - (*Opened)(nil), // 5: c1.connectorapi.rtun.v1.Opened - (*Frame)(nil), // 6: c1.connectorapi.rtun.v1.Frame + (*ReverseDialerServiceOpenRequest)(nil), // 0: c1.connectorapi.rtun.v1.ReverseDialerServiceOpenRequest + (*ReverseDialerServiceOpenResponse)(nil), // 1: c1.connectorapi.rtun.v1.ReverseDialerServiceOpenResponse + (*OpenRequest)(nil), // 2: c1.connectorapi.rtun.v1.OpenRequest + (*OpenResponse)(nil), // 3: c1.connectorapi.rtun.v1.OpenResponse + (*NotFound)(nil), // 4: c1.connectorapi.rtun.v1.NotFound + (*Opened)(nil), // 5: c1.connectorapi.rtun.v1.Opened + (*Frame)(nil), // 6: c1.connectorapi.rtun.v1.Frame } var file_c1_connectorapi_rtun_v1_gateway_proto_depIdxs = []int32{ - 2, // 0: c1.connectorapi.rtun.v1.GatewayRequest.open_req:type_name -> c1.connectorapi.rtun.v1.OpenRequest - 6, // 1: c1.connectorapi.rtun.v1.GatewayRequest.frame:type_name -> c1.connectorapi.rtun.v1.Frame - 3, // 2: c1.connectorapi.rtun.v1.GatewayResponse.open_resp:type_name -> c1.connectorapi.rtun.v1.OpenResponse - 6, // 3: c1.connectorapi.rtun.v1.GatewayResponse.frame:type_name -> c1.connectorapi.rtun.v1.Frame + 2, // 0: c1.connectorapi.rtun.v1.ReverseDialerServiceOpenRequest.open_req:type_name -> c1.connectorapi.rtun.v1.OpenRequest + 6, // 1: c1.connectorapi.rtun.v1.ReverseDialerServiceOpenRequest.frame:type_name -> c1.connectorapi.rtun.v1.Frame + 3, // 2: c1.connectorapi.rtun.v1.ReverseDialerServiceOpenResponse.open_resp:type_name -> c1.connectorapi.rtun.v1.OpenResponse + 6, // 3: c1.connectorapi.rtun.v1.ReverseDialerServiceOpenResponse.frame:type_name -> c1.connectorapi.rtun.v1.Frame 4, // 4: c1.connectorapi.rtun.v1.OpenResponse.not_found:type_name -> c1.connectorapi.rtun.v1.NotFound 5, // 5: c1.connectorapi.rtun.v1.OpenResponse.opened:type_name -> c1.connectorapi.rtun.v1.Opened - 0, // 6: c1.connectorapi.rtun.v1.ReverseDialer.Open:input_type -> c1.connectorapi.rtun.v1.GatewayRequest - 1, // 7: c1.connectorapi.rtun.v1.ReverseDialer.Open:output_type -> c1.connectorapi.rtun.v1.GatewayResponse + 0, // 6: c1.connectorapi.rtun.v1.ReverseDialerService.Open:input_type -> c1.connectorapi.rtun.v1.ReverseDialerServiceOpenRequest + 1, // 7: c1.connectorapi.rtun.v1.ReverseDialerService.Open:output_type -> c1.connectorapi.rtun.v1.ReverseDialerServiceOpenResponse 7, // [7:8] is the sub-list for method output_type 6, // [6:7] is the sub-list for method input_type 6, // [6:6] is the sub-list for extension type_name @@ -517,12 +521,12 @@ func file_c1_connectorapi_rtun_v1_gateway_proto_init() { } file_c1_connectorapi_rtun_v1_rtun_proto_init() file_c1_connectorapi_rtun_v1_gateway_proto_msgTypes[0].OneofWrappers = []any{ - (*GatewayRequest_OpenReq)(nil), - (*GatewayRequest_Frame)(nil), + (*ReverseDialerServiceOpenRequest_OpenReq)(nil), + (*ReverseDialerServiceOpenRequest_Frame)(nil), } file_c1_connectorapi_rtun_v1_gateway_proto_msgTypes[1].OneofWrappers = []any{ - (*GatewayResponse_OpenResp)(nil), - (*GatewayResponse_Frame)(nil), + (*ReverseDialerServiceOpenResponse_OpenResp)(nil), + (*ReverseDialerServiceOpenResponse_Frame)(nil), } file_c1_connectorapi_rtun_v1_gateway_proto_msgTypes[3].OneofWrappers = []any{ (*OpenResponse_NotFound)(nil), diff --git a/pb/c1/connectorapi/rtun/v1/gateway.pb.validate.go b/pb/c1/connectorapi/rtun/v1/gateway.pb.validate.go index 6a6eea901..ba6b63815 100644 --- a/pb/c1/connectorapi/rtun/v1/gateway.pb.validate.go +++ b/pb/c1/connectorapi/rtun/v1/gateway.pb.validate.go @@ -35,22 +35,22 @@ var ( _ = sort.Sort ) -// Validate checks the field values on GatewayRequest with the rules defined in -// the proto definition for this message. If any rules are violated, the first -// error encountered is returned, or nil if there are no violations. -func (m *GatewayRequest) Validate() error { +// Validate checks the field values on ReverseDialerServiceOpenRequest with the +// rules defined in the proto definition for this message. If any rules are +// violated, the first error encountered is returned, or nil if there are no violations. +func (m *ReverseDialerServiceOpenRequest) Validate() error { return m.validate(false) } -// ValidateAll checks the field values on GatewayRequest with the rules defined -// in the proto definition for this message. If any rules are violated, the -// result is a list of violation errors wrapped in GatewayRequestMultiError, -// or nil if none found. -func (m *GatewayRequest) ValidateAll() error { +// ValidateAll checks the field values on ReverseDialerServiceOpenRequest with +// the rules defined in the proto definition for this message. If any rules +// are violated, the result is a list of violation errors wrapped in +// ReverseDialerServiceOpenRequestMultiError, or nil if none found. +func (m *ReverseDialerServiceOpenRequest) ValidateAll() error { return m.validate(true) } -func (m *GatewayRequest) validate(all bool) error { +func (m *ReverseDialerServiceOpenRequest) validate(all bool) error { if m == nil { return nil } @@ -58,9 +58,9 @@ func (m *GatewayRequest) validate(all bool) error { var errors []error switch v := m.Kind.(type) { - case *GatewayRequest_OpenReq: + case *ReverseDialerServiceOpenRequest_OpenReq: if v == nil { - err := GatewayRequestValidationError{ + err := ReverseDialerServiceOpenRequestValidationError{ field: "Kind", reason: "oneof value cannot be a typed-nil", } @@ -74,7 +74,7 @@ func (m *GatewayRequest) validate(all bool) error { switch v := interface{}(m.GetOpenReq()).(type) { case interface{ ValidateAll() error }: if err := v.ValidateAll(); err != nil { - errors = append(errors, GatewayRequestValidationError{ + errors = append(errors, ReverseDialerServiceOpenRequestValidationError{ field: "OpenReq", reason: "embedded message failed validation", cause: err, @@ -82,7 +82,7 @@ func (m *GatewayRequest) validate(all bool) error { } case interface{ Validate() error }: if err := v.Validate(); err != nil { - errors = append(errors, GatewayRequestValidationError{ + errors = append(errors, ReverseDialerServiceOpenRequestValidationError{ field: "OpenReq", reason: "embedded message failed validation", cause: err, @@ -91,7 +91,7 @@ func (m *GatewayRequest) validate(all bool) error { } } else if v, ok := interface{}(m.GetOpenReq()).(interface{ Validate() error }); ok { if err := v.Validate(); err != nil { - return GatewayRequestValidationError{ + return ReverseDialerServiceOpenRequestValidationError{ field: "OpenReq", reason: "embedded message failed validation", cause: err, @@ -99,9 +99,9 @@ func (m *GatewayRequest) validate(all bool) error { } } - case *GatewayRequest_Frame: + case *ReverseDialerServiceOpenRequest_Frame: if v == nil { - err := GatewayRequestValidationError{ + err := ReverseDialerServiceOpenRequestValidationError{ field: "Kind", reason: "oneof value cannot be a typed-nil", } @@ -115,7 +115,7 @@ func (m *GatewayRequest) validate(all bool) error { switch v := interface{}(m.GetFrame()).(type) { case interface{ ValidateAll() error }: if err := v.ValidateAll(); err != nil { - errors = append(errors, GatewayRequestValidationError{ + errors = append(errors, ReverseDialerServiceOpenRequestValidationError{ field: "Frame", reason: "embedded message failed validation", cause: err, @@ -123,7 +123,7 @@ func (m *GatewayRequest) validate(all bool) error { } case interface{ Validate() error }: if err := v.Validate(); err != nil { - errors = append(errors, GatewayRequestValidationError{ + errors = append(errors, ReverseDialerServiceOpenRequestValidationError{ field: "Frame", reason: "embedded message failed validation", cause: err, @@ -132,7 +132,7 @@ func (m *GatewayRequest) validate(all bool) error { } } else if v, ok := interface{}(m.GetFrame()).(interface{ Validate() error }); ok { if err := v.Validate(); err != nil { - return GatewayRequestValidationError{ + return ReverseDialerServiceOpenRequestValidationError{ field: "Frame", reason: "embedded message failed validation", cause: err, @@ -145,19 +145,19 @@ func (m *GatewayRequest) validate(all bool) error { } if len(errors) > 0 { - return GatewayRequestMultiError(errors) + return ReverseDialerServiceOpenRequestMultiError(errors) } return nil } -// GatewayRequestMultiError is an error wrapping multiple validation errors -// returned by GatewayRequest.ValidateAll() if the designated constraints -// aren't met. -type GatewayRequestMultiError []error +// ReverseDialerServiceOpenRequestMultiError is an error wrapping multiple +// validation errors returned by ReverseDialerServiceOpenRequest.ValidateAll() +// if the designated constraints aren't met. +type ReverseDialerServiceOpenRequestMultiError []error // Error returns a concatenation of all the error messages it wraps. -func (m GatewayRequestMultiError) Error() string { +func (m ReverseDialerServiceOpenRequestMultiError) Error() string { msgs := make([]string, 0, len(m)) for _, err := range m { msgs = append(msgs, err.Error()) @@ -166,11 +166,12 @@ func (m GatewayRequestMultiError) Error() string { } // AllErrors returns a list of validation violation errors. -func (m GatewayRequestMultiError) AllErrors() []error { return m } +func (m ReverseDialerServiceOpenRequestMultiError) AllErrors() []error { return m } -// GatewayRequestValidationError is the validation error returned by -// GatewayRequest.Validate if the designated constraints aren't met. -type GatewayRequestValidationError struct { +// ReverseDialerServiceOpenRequestValidationError is the validation error +// returned by ReverseDialerServiceOpenRequest.Validate if the designated +// constraints aren't met. +type ReverseDialerServiceOpenRequestValidationError struct { field string reason string cause error @@ -178,22 +179,24 @@ type GatewayRequestValidationError struct { } // Field function returns field value. -func (e GatewayRequestValidationError) Field() string { return e.field } +func (e ReverseDialerServiceOpenRequestValidationError) Field() string { return e.field } // Reason function returns reason value. -func (e GatewayRequestValidationError) Reason() string { return e.reason } +func (e ReverseDialerServiceOpenRequestValidationError) Reason() string { return e.reason } // Cause function returns cause value. -func (e GatewayRequestValidationError) Cause() error { return e.cause } +func (e ReverseDialerServiceOpenRequestValidationError) Cause() error { return e.cause } // Key function returns key value. -func (e GatewayRequestValidationError) Key() bool { return e.key } +func (e ReverseDialerServiceOpenRequestValidationError) Key() bool { return e.key } // ErrorName returns error name. -func (e GatewayRequestValidationError) ErrorName() string { return "GatewayRequestValidationError" } +func (e ReverseDialerServiceOpenRequestValidationError) ErrorName() string { + return "ReverseDialerServiceOpenRequestValidationError" +} // Error satisfies the builtin error interface -func (e GatewayRequestValidationError) Error() string { +func (e ReverseDialerServiceOpenRequestValidationError) Error() string { cause := "" if e.cause != nil { cause = fmt.Sprintf(" | caused by: %v", e.cause) @@ -205,14 +208,14 @@ func (e GatewayRequestValidationError) Error() string { } return fmt.Sprintf( - "invalid %sGatewayRequest.%s: %s%s", + "invalid %sReverseDialerServiceOpenRequest.%s: %s%s", key, e.field, e.reason, cause) } -var _ error = GatewayRequestValidationError{} +var _ error = ReverseDialerServiceOpenRequestValidationError{} var _ interface { Field() string @@ -220,24 +223,25 @@ var _ interface { Key() bool Cause() error ErrorName() string -} = GatewayRequestValidationError{} +} = ReverseDialerServiceOpenRequestValidationError{} -// Validate checks the field values on GatewayResponse with the rules defined -// in the proto definition for this message. If any rules are violated, the -// first error encountered is returned, or nil if there are no violations. -func (m *GatewayResponse) Validate() error { +// Validate checks the field values on ReverseDialerServiceOpenResponse with +// the rules defined in the proto definition for this message. If any rules +// are violated, the first error encountered is returned, or nil if there are +// no violations. +func (m *ReverseDialerServiceOpenResponse) Validate() error { return m.validate(false) } -// ValidateAll checks the field values on GatewayResponse with the rules -// defined in the proto definition for this message. If any rules are -// violated, the result is a list of violation errors wrapped in -// GatewayResponseMultiError, or nil if none found. -func (m *GatewayResponse) ValidateAll() error { +// ValidateAll checks the field values on ReverseDialerServiceOpenResponse with +// the rules defined in the proto definition for this message. If any rules +// are violated, the result is a list of violation errors wrapped in +// ReverseDialerServiceOpenResponseMultiError, or nil if none found. +func (m *ReverseDialerServiceOpenResponse) ValidateAll() error { return m.validate(true) } -func (m *GatewayResponse) validate(all bool) error { +func (m *ReverseDialerServiceOpenResponse) validate(all bool) error { if m == nil { return nil } @@ -245,9 +249,9 @@ func (m *GatewayResponse) validate(all bool) error { var errors []error switch v := m.Kind.(type) { - case *GatewayResponse_OpenResp: + case *ReverseDialerServiceOpenResponse_OpenResp: if v == nil { - err := GatewayResponseValidationError{ + err := ReverseDialerServiceOpenResponseValidationError{ field: "Kind", reason: "oneof value cannot be a typed-nil", } @@ -261,7 +265,7 @@ func (m *GatewayResponse) validate(all bool) error { switch v := interface{}(m.GetOpenResp()).(type) { case interface{ ValidateAll() error }: if err := v.ValidateAll(); err != nil { - errors = append(errors, GatewayResponseValidationError{ + errors = append(errors, ReverseDialerServiceOpenResponseValidationError{ field: "OpenResp", reason: "embedded message failed validation", cause: err, @@ -269,7 +273,7 @@ func (m *GatewayResponse) validate(all bool) error { } case interface{ Validate() error }: if err := v.Validate(); err != nil { - errors = append(errors, GatewayResponseValidationError{ + errors = append(errors, ReverseDialerServiceOpenResponseValidationError{ field: "OpenResp", reason: "embedded message failed validation", cause: err, @@ -278,7 +282,7 @@ func (m *GatewayResponse) validate(all bool) error { } } else if v, ok := interface{}(m.GetOpenResp()).(interface{ Validate() error }); ok { if err := v.Validate(); err != nil { - return GatewayResponseValidationError{ + return ReverseDialerServiceOpenResponseValidationError{ field: "OpenResp", reason: "embedded message failed validation", cause: err, @@ -286,9 +290,9 @@ func (m *GatewayResponse) validate(all bool) error { } } - case *GatewayResponse_Frame: + case *ReverseDialerServiceOpenResponse_Frame: if v == nil { - err := GatewayResponseValidationError{ + err := ReverseDialerServiceOpenResponseValidationError{ field: "Kind", reason: "oneof value cannot be a typed-nil", } @@ -302,7 +306,7 @@ func (m *GatewayResponse) validate(all bool) error { switch v := interface{}(m.GetFrame()).(type) { case interface{ ValidateAll() error }: if err := v.ValidateAll(); err != nil { - errors = append(errors, GatewayResponseValidationError{ + errors = append(errors, ReverseDialerServiceOpenResponseValidationError{ field: "Frame", reason: "embedded message failed validation", cause: err, @@ -310,7 +314,7 @@ func (m *GatewayResponse) validate(all bool) error { } case interface{ Validate() error }: if err := v.Validate(); err != nil { - errors = append(errors, GatewayResponseValidationError{ + errors = append(errors, ReverseDialerServiceOpenResponseValidationError{ field: "Frame", reason: "embedded message failed validation", cause: err, @@ -319,7 +323,7 @@ func (m *GatewayResponse) validate(all bool) error { } } else if v, ok := interface{}(m.GetFrame()).(interface{ Validate() error }); ok { if err := v.Validate(); err != nil { - return GatewayResponseValidationError{ + return ReverseDialerServiceOpenResponseValidationError{ field: "Frame", reason: "embedded message failed validation", cause: err, @@ -332,19 +336,20 @@ func (m *GatewayResponse) validate(all bool) error { } if len(errors) > 0 { - return GatewayResponseMultiError(errors) + return ReverseDialerServiceOpenResponseMultiError(errors) } return nil } -// GatewayResponseMultiError is an error wrapping multiple validation errors -// returned by GatewayResponse.ValidateAll() if the designated constraints -// aren't met. -type GatewayResponseMultiError []error +// ReverseDialerServiceOpenResponseMultiError is an error wrapping multiple +// validation errors returned by +// ReverseDialerServiceOpenResponse.ValidateAll() if the designated +// constraints aren't met. +type ReverseDialerServiceOpenResponseMultiError []error // Error returns a concatenation of all the error messages it wraps. -func (m GatewayResponseMultiError) Error() string { +func (m ReverseDialerServiceOpenResponseMultiError) Error() string { msgs := make([]string, 0, len(m)) for _, err := range m { msgs = append(msgs, err.Error()) @@ -353,11 +358,12 @@ func (m GatewayResponseMultiError) Error() string { } // AllErrors returns a list of validation violation errors. -func (m GatewayResponseMultiError) AllErrors() []error { return m } +func (m ReverseDialerServiceOpenResponseMultiError) AllErrors() []error { return m } -// GatewayResponseValidationError is the validation error returned by -// GatewayResponse.Validate if the designated constraints aren't met. -type GatewayResponseValidationError struct { +// ReverseDialerServiceOpenResponseValidationError is the validation error +// returned by ReverseDialerServiceOpenResponse.Validate if the designated +// constraints aren't met. +type ReverseDialerServiceOpenResponseValidationError struct { field string reason string cause error @@ -365,22 +371,24 @@ type GatewayResponseValidationError struct { } // Field function returns field value. -func (e GatewayResponseValidationError) Field() string { return e.field } +func (e ReverseDialerServiceOpenResponseValidationError) Field() string { return e.field } // Reason function returns reason value. -func (e GatewayResponseValidationError) Reason() string { return e.reason } +func (e ReverseDialerServiceOpenResponseValidationError) Reason() string { return e.reason } // Cause function returns cause value. -func (e GatewayResponseValidationError) Cause() error { return e.cause } +func (e ReverseDialerServiceOpenResponseValidationError) Cause() error { return e.cause } // Key function returns key value. -func (e GatewayResponseValidationError) Key() bool { return e.key } +func (e ReverseDialerServiceOpenResponseValidationError) Key() bool { return e.key } // ErrorName returns error name. -func (e GatewayResponseValidationError) ErrorName() string { return "GatewayResponseValidationError" } +func (e ReverseDialerServiceOpenResponseValidationError) ErrorName() string { + return "ReverseDialerServiceOpenResponseValidationError" +} // Error satisfies the builtin error interface -func (e GatewayResponseValidationError) Error() string { +func (e ReverseDialerServiceOpenResponseValidationError) Error() string { cause := "" if e.cause != nil { cause = fmt.Sprintf(" | caused by: %v", e.cause) @@ -392,14 +400,14 @@ func (e GatewayResponseValidationError) Error() string { } return fmt.Sprintf( - "invalid %sGatewayResponse.%s: %s%s", + "invalid %sReverseDialerServiceOpenResponse.%s: %s%s", key, e.field, e.reason, cause) } -var _ error = GatewayResponseValidationError{} +var _ error = ReverseDialerServiceOpenResponseValidationError{} var _ interface { Field() string @@ -407,7 +415,7 @@ var _ interface { Key() bool Cause() error ErrorName() string -} = GatewayResponseValidationError{} +} = ReverseDialerServiceOpenResponseValidationError{} // Validate checks the field values on OpenRequest with the rules defined in // the proto definition for this message. If any rules are violated, the first diff --git a/pb/c1/connectorapi/rtun/v1/gateway_grpc.pb.go b/pb/c1/connectorapi/rtun/v1/gateway_grpc.pb.go index 937dc00bc..a7b200ad7 100644 --- a/pb/c1/connectorapi/rtun/v1/gateway_grpc.pb.go +++ b/pb/c1/connectorapi/rtun/v1/gateway_grpc.pb.go @@ -19,98 +19,98 @@ import ( const _ = grpc.SupportPackageIsVersion9 const ( - ReverseDialer_Open_FullMethodName = "/c1.connectorapi.rtun.v1.ReverseDialer/Open" + ReverseDialerService_Open_FullMethodName = "/c1.connectorapi.rtun.v1.ReverseDialerService/Open" ) -// ReverseDialerClient is the client API for ReverseDialer service. +// ReverseDialerServiceClient is the client API for ReverseDialerService service. // // For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. // // ReverseDialer allows callers to establish connections to clients via the gateway. // The gateway bridges caller streams to rtun sessions on the owner server. -type ReverseDialerClient interface { - Open(ctx context.Context, opts ...grpc.CallOption) (grpc.BidiStreamingClient[GatewayRequest, GatewayResponse], error) +type ReverseDialerServiceClient interface { + Open(ctx context.Context, opts ...grpc.CallOption) (grpc.BidiStreamingClient[ReverseDialerServiceOpenRequest, ReverseDialerServiceOpenResponse], error) } -type reverseDialerClient struct { +type reverseDialerServiceClient struct { cc grpc.ClientConnInterface } -func NewReverseDialerClient(cc grpc.ClientConnInterface) ReverseDialerClient { - return &reverseDialerClient{cc} +func NewReverseDialerServiceClient(cc grpc.ClientConnInterface) ReverseDialerServiceClient { + return &reverseDialerServiceClient{cc} } -func (c *reverseDialerClient) Open(ctx context.Context, opts ...grpc.CallOption) (grpc.BidiStreamingClient[GatewayRequest, GatewayResponse], error) { +func (c *reverseDialerServiceClient) Open(ctx context.Context, opts ...grpc.CallOption) (grpc.BidiStreamingClient[ReverseDialerServiceOpenRequest, ReverseDialerServiceOpenResponse], error) { cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) - stream, err := c.cc.NewStream(ctx, &ReverseDialer_ServiceDesc.Streams[0], ReverseDialer_Open_FullMethodName, cOpts...) + stream, err := c.cc.NewStream(ctx, &ReverseDialerService_ServiceDesc.Streams[0], ReverseDialerService_Open_FullMethodName, cOpts...) if err != nil { return nil, err } - x := &grpc.GenericClientStream[GatewayRequest, GatewayResponse]{ClientStream: stream} + x := &grpc.GenericClientStream[ReverseDialerServiceOpenRequest, ReverseDialerServiceOpenResponse]{ClientStream: stream} return x, nil } // This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. -type ReverseDialer_OpenClient = grpc.BidiStreamingClient[GatewayRequest, GatewayResponse] +type ReverseDialerService_OpenClient = grpc.BidiStreamingClient[ReverseDialerServiceOpenRequest, ReverseDialerServiceOpenResponse] -// ReverseDialerServer is the server API for ReverseDialer service. -// All implementations should embed UnimplementedReverseDialerServer +// ReverseDialerServiceServer is the server API for ReverseDialerService service. +// All implementations should embed UnimplementedReverseDialerServiceServer // for forward compatibility. // // ReverseDialer allows callers to establish connections to clients via the gateway. // The gateway bridges caller streams to rtun sessions on the owner server. -type ReverseDialerServer interface { - Open(grpc.BidiStreamingServer[GatewayRequest, GatewayResponse]) error +type ReverseDialerServiceServer interface { + Open(grpc.BidiStreamingServer[ReverseDialerServiceOpenRequest, ReverseDialerServiceOpenResponse]) error } -// UnimplementedReverseDialerServer should be embedded to have +// UnimplementedReverseDialerServiceServer should be embedded to have // forward compatible implementations. // // NOTE: this should be embedded by value instead of pointer to avoid a nil // pointer dereference when methods are called. -type UnimplementedReverseDialerServer struct{} +type UnimplementedReverseDialerServiceServer struct{} -func (UnimplementedReverseDialerServer) Open(grpc.BidiStreamingServer[GatewayRequest, GatewayResponse]) error { +func (UnimplementedReverseDialerServiceServer) Open(grpc.BidiStreamingServer[ReverseDialerServiceOpenRequest, ReverseDialerServiceOpenResponse]) error { return status.Errorf(codes.Unimplemented, "method Open not implemented") } -func (UnimplementedReverseDialerServer) testEmbeddedByValue() {} +func (UnimplementedReverseDialerServiceServer) testEmbeddedByValue() {} -// UnsafeReverseDialerServer may be embedded to opt out of forward compatibility for this service. -// Use of this interface is not recommended, as added methods to ReverseDialerServer will +// UnsafeReverseDialerServiceServer may be embedded to opt out of forward compatibility for this service. +// Use of this interface is not recommended, as added methods to ReverseDialerServiceServer will // result in compilation errors. -type UnsafeReverseDialerServer interface { - mustEmbedUnimplementedReverseDialerServer() +type UnsafeReverseDialerServiceServer interface { + mustEmbedUnimplementedReverseDialerServiceServer() } -func RegisterReverseDialerServer(s grpc.ServiceRegistrar, srv ReverseDialerServer) { - // If the following call pancis, it indicates UnimplementedReverseDialerServer was +func RegisterReverseDialerServiceServer(s grpc.ServiceRegistrar, srv ReverseDialerServiceServer) { + // If the following call pancis, it indicates UnimplementedReverseDialerServiceServer was // embedded by pointer and is nil. This will cause panics if an // unimplemented method is ever invoked, so we test this at initialization // time to prevent it from happening at runtime later due to I/O. if t, ok := srv.(interface{ testEmbeddedByValue() }); ok { t.testEmbeddedByValue() } - s.RegisterService(&ReverseDialer_ServiceDesc, srv) + s.RegisterService(&ReverseDialerService_ServiceDesc, srv) } -func _ReverseDialer_Open_Handler(srv interface{}, stream grpc.ServerStream) error { - return srv.(ReverseDialerServer).Open(&grpc.GenericServerStream[GatewayRequest, GatewayResponse]{ServerStream: stream}) +func _ReverseDialerService_Open_Handler(srv interface{}, stream grpc.ServerStream) error { + return srv.(ReverseDialerServiceServer).Open(&grpc.GenericServerStream[ReverseDialerServiceOpenRequest, ReverseDialerServiceOpenResponse]{ServerStream: stream}) } // This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. -type ReverseDialer_OpenServer = grpc.BidiStreamingServer[GatewayRequest, GatewayResponse] +type ReverseDialerService_OpenServer = grpc.BidiStreamingServer[ReverseDialerServiceOpenRequest, ReverseDialerServiceOpenResponse] -// ReverseDialer_ServiceDesc is the grpc.ServiceDesc for ReverseDialer service. +// ReverseDialerService_ServiceDesc is the grpc.ServiceDesc for ReverseDialerService service. // It's only intended for direct use with grpc.RegisterService, // and not to be introspected or modified (even as a copy) -var ReverseDialer_ServiceDesc = grpc.ServiceDesc{ - ServiceName: "c1.connectorapi.rtun.v1.ReverseDialer", - HandlerType: (*ReverseDialerServer)(nil), +var ReverseDialerService_ServiceDesc = grpc.ServiceDesc{ + ServiceName: "c1.connectorapi.rtun.v1.ReverseDialerService", + HandlerType: (*ReverseDialerServiceServer)(nil), Methods: []grpc.MethodDesc{}, Streams: []grpc.StreamDesc{ { StreamName: "Open", - Handler: _ReverseDialer_Open_Handler, + Handler: _ReverseDialerService_Open_Handler, ServerStreams: true, ClientStreams: true, }, diff --git a/pb/c1/connectorapi/rtun/v1/rtun.pb.go b/pb/c1/connectorapi/rtun/v1/rtun.pb.go index 5a3cf9fda..81601ade4 100644 --- a/pb/c1/connectorapi/rtun/v1/rtun.pb.go +++ b/pb/c1/connectorapi/rtun/v1/rtun.pb.go @@ -76,6 +76,94 @@ func (RstCode) EnumDescriptor() ([]byte, []int) { return file_c1_connectorapi_rtun_v1_rtun_proto_rawDescGZIP(), []int{0} } +type ReverseTunnelServiceLinkRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + Frame *Frame `protobuf:"bytes,1,opt,name=frame,proto3" json:"frame,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ReverseTunnelServiceLinkRequest) Reset() { + *x = ReverseTunnelServiceLinkRequest{} + mi := &file_c1_connectorapi_rtun_v1_rtun_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ReverseTunnelServiceLinkRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ReverseTunnelServiceLinkRequest) ProtoMessage() {} + +func (x *ReverseTunnelServiceLinkRequest) ProtoReflect() protoreflect.Message { + mi := &file_c1_connectorapi_rtun_v1_rtun_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ReverseTunnelServiceLinkRequest.ProtoReflect.Descriptor instead. +func (*ReverseTunnelServiceLinkRequest) Descriptor() ([]byte, []int) { + return file_c1_connectorapi_rtun_v1_rtun_proto_rawDescGZIP(), []int{0} +} + +func (x *ReverseTunnelServiceLinkRequest) GetFrame() *Frame { + if x != nil { + return x.Frame + } + return nil +} + +type ReverseTunnelServiceLinkResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + Frame *Frame `protobuf:"bytes,1,opt,name=frame,proto3" json:"frame,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ReverseTunnelServiceLinkResponse) Reset() { + *x = ReverseTunnelServiceLinkResponse{} + mi := &file_c1_connectorapi_rtun_v1_rtun_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ReverseTunnelServiceLinkResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ReverseTunnelServiceLinkResponse) ProtoMessage() {} + +func (x *ReverseTunnelServiceLinkResponse) ProtoReflect() protoreflect.Message { + mi := &file_c1_connectorapi_rtun_v1_rtun_proto_msgTypes[1] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ReverseTunnelServiceLinkResponse.ProtoReflect.Descriptor instead. +func (*ReverseTunnelServiceLinkResponse) Descriptor() ([]byte, []int) { + return file_c1_connectorapi_rtun_v1_rtun_proto_rawDescGZIP(), []int{1} +} + +func (x *ReverseTunnelServiceLinkResponse) GetFrame() *Frame { + if x != nil { + return x.Frame + } + return nil +} + type Frame struct { state protoimpl.MessageState `protogen:"open.v1"` Sid uint32 `protobuf:"varint,1,opt,name=sid,proto3" json:"sid,omitempty"` // 0 reserved for control @@ -93,7 +181,7 @@ type Frame struct { func (x *Frame) Reset() { *x = Frame{} - mi := &file_c1_connectorapi_rtun_v1_rtun_proto_msgTypes[0] + mi := &file_c1_connectorapi_rtun_v1_rtun_proto_msgTypes[2] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -105,7 +193,7 @@ func (x *Frame) String() string { func (*Frame) ProtoMessage() {} func (x *Frame) ProtoReflect() protoreflect.Message { - mi := &file_c1_connectorapi_rtun_v1_rtun_proto_msgTypes[0] + mi := &file_c1_connectorapi_rtun_v1_rtun_proto_msgTypes[2] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -118,7 +206,7 @@ func (x *Frame) ProtoReflect() protoreflect.Message { // Deprecated: Use Frame.ProtoReflect.Descriptor instead. func (*Frame) Descriptor() ([]byte, []int) { - return file_c1_connectorapi_rtun_v1_rtun_proto_rawDescGZIP(), []int{0} + return file_c1_connectorapi_rtun_v1_rtun_proto_rawDescGZIP(), []int{2} } func (x *Frame) GetSid() uint32 { @@ -224,7 +312,7 @@ type Hello struct { func (x *Hello) Reset() { *x = Hello{} - mi := &file_c1_connectorapi_rtun_v1_rtun_proto_msgTypes[1] + mi := &file_c1_connectorapi_rtun_v1_rtun_proto_msgTypes[3] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -236,7 +324,7 @@ func (x *Hello) String() string { func (*Hello) ProtoMessage() {} func (x *Hello) ProtoReflect() protoreflect.Message { - mi := &file_c1_connectorapi_rtun_v1_rtun_proto_msgTypes[1] + mi := &file_c1_connectorapi_rtun_v1_rtun_proto_msgTypes[3] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -249,7 +337,7 @@ func (x *Hello) ProtoReflect() protoreflect.Message { // Deprecated: Use Hello.ProtoReflect.Descriptor instead. func (*Hello) Descriptor() ([]byte, []int) { - return file_c1_connectorapi_rtun_v1_rtun_proto_rawDescGZIP(), []int{1} + return file_c1_connectorapi_rtun_v1_rtun_proto_rawDescGZIP(), []int{3} } func (x *Hello) GetPorts() []uint32 { @@ -275,7 +363,7 @@ type Syn struct { func (x *Syn) Reset() { *x = Syn{} - mi := &file_c1_connectorapi_rtun_v1_rtun_proto_msgTypes[2] + mi := &file_c1_connectorapi_rtun_v1_rtun_proto_msgTypes[4] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -287,7 +375,7 @@ func (x *Syn) String() string { func (*Syn) ProtoMessage() {} func (x *Syn) ProtoReflect() protoreflect.Message { - mi := &file_c1_connectorapi_rtun_v1_rtun_proto_msgTypes[2] + mi := &file_c1_connectorapi_rtun_v1_rtun_proto_msgTypes[4] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -300,7 +388,7 @@ func (x *Syn) ProtoReflect() protoreflect.Message { // Deprecated: Use Syn.ProtoReflect.Descriptor instead. func (*Syn) Descriptor() ([]byte, []int) { - return file_c1_connectorapi_rtun_v1_rtun_proto_rawDescGZIP(), []int{2} + return file_c1_connectorapi_rtun_v1_rtun_proto_rawDescGZIP(), []int{4} } func (x *Syn) GetPort() uint32 { @@ -319,7 +407,7 @@ type Data struct { func (x *Data) Reset() { *x = Data{} - mi := &file_c1_connectorapi_rtun_v1_rtun_proto_msgTypes[3] + mi := &file_c1_connectorapi_rtun_v1_rtun_proto_msgTypes[5] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -331,7 +419,7 @@ func (x *Data) String() string { func (*Data) ProtoMessage() {} func (x *Data) ProtoReflect() protoreflect.Message { - mi := &file_c1_connectorapi_rtun_v1_rtun_proto_msgTypes[3] + mi := &file_c1_connectorapi_rtun_v1_rtun_proto_msgTypes[5] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -344,7 +432,7 @@ func (x *Data) ProtoReflect() protoreflect.Message { // Deprecated: Use Data.ProtoReflect.Descriptor instead. func (*Data) Descriptor() ([]byte, []int) { - return file_c1_connectorapi_rtun_v1_rtun_proto_rawDescGZIP(), []int{3} + return file_c1_connectorapi_rtun_v1_rtun_proto_rawDescGZIP(), []int{5} } func (x *Data) GetPayload() []byte { @@ -363,7 +451,7 @@ type Fin struct { func (x *Fin) Reset() { *x = Fin{} - mi := &file_c1_connectorapi_rtun_v1_rtun_proto_msgTypes[4] + mi := &file_c1_connectorapi_rtun_v1_rtun_proto_msgTypes[6] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -375,7 +463,7 @@ func (x *Fin) String() string { func (*Fin) ProtoMessage() {} func (x *Fin) ProtoReflect() protoreflect.Message { - mi := &file_c1_connectorapi_rtun_v1_rtun_proto_msgTypes[4] + mi := &file_c1_connectorapi_rtun_v1_rtun_proto_msgTypes[6] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -388,7 +476,7 @@ func (x *Fin) ProtoReflect() protoreflect.Message { // Deprecated: Use Fin.ProtoReflect.Descriptor instead. func (*Fin) Descriptor() ([]byte, []int) { - return file_c1_connectorapi_rtun_v1_rtun_proto_rawDescGZIP(), []int{4} + return file_c1_connectorapi_rtun_v1_rtun_proto_rawDescGZIP(), []int{6} } func (x *Fin) GetAck() bool { @@ -407,7 +495,7 @@ type Rst struct { func (x *Rst) Reset() { *x = Rst{} - mi := &file_c1_connectorapi_rtun_v1_rtun_proto_msgTypes[5] + mi := &file_c1_connectorapi_rtun_v1_rtun_proto_msgTypes[7] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -419,7 +507,7 @@ func (x *Rst) String() string { func (*Rst) ProtoMessage() {} func (x *Rst) ProtoReflect() protoreflect.Message { - mi := &file_c1_connectorapi_rtun_v1_rtun_proto_msgTypes[5] + mi := &file_c1_connectorapi_rtun_v1_rtun_proto_msgTypes[7] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -432,7 +520,7 @@ func (x *Rst) ProtoReflect() protoreflect.Message { // Deprecated: Use Rst.ProtoReflect.Descriptor instead. func (*Rst) Descriptor() ([]byte, []int) { - return file_c1_connectorapi_rtun_v1_rtun_proto_rawDescGZIP(), []int{5} + return file_c1_connectorapi_rtun_v1_rtun_proto_rawDescGZIP(), []int{7} } func (x *Rst) GetCode() RstCode { @@ -448,53 +536,68 @@ var file_c1_connectorapi_rtun_v1_rtun_proto_rawDesc = string([]byte{ 0x0a, 0x22, 0x63, 0x31, 0x2f, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x61, 0x70, 0x69, 0x2f, 0x72, 0x74, 0x75, 0x6e, 0x2f, 0x76, 0x31, 0x2f, 0x72, 0x74, 0x75, 0x6e, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x17, 0x63, 0x31, 0x2e, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, - 0x6f, 0x72, 0x61, 0x70, 0x69, 0x2e, 0x72, 0x74, 0x75, 0x6e, 0x2e, 0x76, 0x31, 0x22, 0xa4, 0x02, - 0x0a, 0x05, 0x46, 0x72, 0x61, 0x6d, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x73, 0x69, 0x64, 0x18, 0x01, - 0x20, 0x01, 0x28, 0x0d, 0x52, 0x03, 0x73, 0x69, 0x64, 0x12, 0x36, 0x0a, 0x05, 0x68, 0x65, 0x6c, - 0x6c, 0x6f, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1e, 0x2e, 0x63, 0x31, 0x2e, 0x63, 0x6f, + 0x6f, 0x72, 0x61, 0x70, 0x69, 0x2e, 0x72, 0x74, 0x75, 0x6e, 0x2e, 0x76, 0x31, 0x22, 0x57, 0x0a, + 0x1f, 0x52, 0x65, 0x76, 0x65, 0x72, 0x73, 0x65, 0x54, 0x75, 0x6e, 0x6e, 0x65, 0x6c, 0x53, 0x65, + 0x72, 0x76, 0x69, 0x63, 0x65, 0x4c, 0x69, 0x6e, 0x6b, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, + 0x12, 0x34, 0x0a, 0x05, 0x66, 0x72, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, + 0x1e, 0x2e, 0x63, 0x31, 0x2e, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x61, 0x70, + 0x69, 0x2e, 0x72, 0x74, 0x75, 0x6e, 0x2e, 0x76, 0x31, 0x2e, 0x46, 0x72, 0x61, 0x6d, 0x65, 0x52, + 0x05, 0x66, 0x72, 0x61, 0x6d, 0x65, 0x22, 0x58, 0x0a, 0x20, 0x52, 0x65, 0x76, 0x65, 0x72, 0x73, + 0x65, 0x54, 0x75, 0x6e, 0x6e, 0x65, 0x6c, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x4c, 0x69, + 0x6e, 0x6b, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x34, 0x0a, 0x05, 0x66, 0x72, + 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1e, 0x2e, 0x63, 0x31, 0x2e, 0x63, + 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x61, 0x70, 0x69, 0x2e, 0x72, 0x74, 0x75, 0x6e, + 0x2e, 0x76, 0x31, 0x2e, 0x46, 0x72, 0x61, 0x6d, 0x65, 0x52, 0x05, 0x66, 0x72, 0x61, 0x6d, 0x65, + 0x22, 0xa4, 0x02, 0x0a, 0x05, 0x46, 0x72, 0x61, 0x6d, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x73, 0x69, + 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x03, 0x73, 0x69, 0x64, 0x12, 0x36, 0x0a, 0x05, + 0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1e, 0x2e, 0x63, 0x31, + 0x2e, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x61, 0x70, 0x69, 0x2e, 0x72, 0x74, + 0x75, 0x6e, 0x2e, 0x76, 0x31, 0x2e, 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x48, 0x00, 0x52, 0x05, 0x68, + 0x65, 0x6c, 0x6c, 0x6f, 0x12, 0x30, 0x0a, 0x03, 0x73, 0x79, 0x6e, 0x18, 0x0b, 0x20, 0x01, 0x28, + 0x0b, 0x32, 0x1c, 0x2e, 0x63, 0x31, 0x2e, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x6f, 0x72, + 0x61, 0x70, 0x69, 0x2e, 0x72, 0x74, 0x75, 0x6e, 0x2e, 0x76, 0x31, 0x2e, 0x53, 0x79, 0x6e, 0x48, + 0x00, 0x52, 0x03, 0x73, 0x79, 0x6e, 0x12, 0x33, 0x0a, 0x04, 0x64, 0x61, 0x74, 0x61, 0x18, 0x0c, + 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1d, 0x2e, 0x63, 0x31, 0x2e, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, + 0x74, 0x6f, 0x72, 0x61, 0x70, 0x69, 0x2e, 0x72, 0x74, 0x75, 0x6e, 0x2e, 0x76, 0x31, 0x2e, 0x44, + 0x61, 0x74, 0x61, 0x48, 0x00, 0x52, 0x04, 0x64, 0x61, 0x74, 0x61, 0x12, 0x30, 0x0a, 0x03, 0x66, + 0x69, 0x6e, 0x18, 0x0d, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1c, 0x2e, 0x63, 0x31, 0x2e, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x61, 0x70, 0x69, 0x2e, 0x72, 0x74, 0x75, 0x6e, 0x2e, - 0x76, 0x31, 0x2e, 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x48, 0x00, 0x52, 0x05, 0x68, 0x65, 0x6c, 0x6c, - 0x6f, 0x12, 0x30, 0x0a, 0x03, 0x73, 0x79, 0x6e, 0x18, 0x0b, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1c, - 0x2e, 0x63, 0x31, 0x2e, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x61, 0x70, 0x69, - 0x2e, 0x72, 0x74, 0x75, 0x6e, 0x2e, 0x76, 0x31, 0x2e, 0x53, 0x79, 0x6e, 0x48, 0x00, 0x52, 0x03, - 0x73, 0x79, 0x6e, 0x12, 0x33, 0x0a, 0x04, 0x64, 0x61, 0x74, 0x61, 0x18, 0x0c, 0x20, 0x01, 0x28, - 0x0b, 0x32, 0x1d, 0x2e, 0x63, 0x31, 0x2e, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x6f, 0x72, - 0x61, 0x70, 0x69, 0x2e, 0x72, 0x74, 0x75, 0x6e, 0x2e, 0x76, 0x31, 0x2e, 0x44, 0x61, 0x74, 0x61, - 0x48, 0x00, 0x52, 0x04, 0x64, 0x61, 0x74, 0x61, 0x12, 0x30, 0x0a, 0x03, 0x66, 0x69, 0x6e, 0x18, - 0x0d, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1c, 0x2e, 0x63, 0x31, 0x2e, 0x63, 0x6f, 0x6e, 0x6e, 0x65, - 0x63, 0x74, 0x6f, 0x72, 0x61, 0x70, 0x69, 0x2e, 0x72, 0x74, 0x75, 0x6e, 0x2e, 0x76, 0x31, 0x2e, - 0x46, 0x69, 0x6e, 0x48, 0x00, 0x52, 0x03, 0x66, 0x69, 0x6e, 0x12, 0x30, 0x0a, 0x03, 0x72, 0x73, - 0x74, 0x18, 0x0e, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1c, 0x2e, 0x63, 0x31, 0x2e, 0x63, 0x6f, 0x6e, - 0x6e, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x61, 0x70, 0x69, 0x2e, 0x72, 0x74, 0x75, 0x6e, 0x2e, 0x76, - 0x31, 0x2e, 0x52, 0x73, 0x74, 0x48, 0x00, 0x52, 0x03, 0x72, 0x73, 0x74, 0x42, 0x06, 0x0a, 0x04, - 0x6b, 0x69, 0x6e, 0x64, 0x22, 0x39, 0x0a, 0x05, 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x12, 0x14, 0x0a, - 0x05, 0x70, 0x6f, 0x72, 0x74, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0d, 0x52, 0x05, 0x70, 0x6f, - 0x72, 0x74, 0x73, 0x12, 0x1a, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, - 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x22, - 0x19, 0x0a, 0x03, 0x53, 0x79, 0x6e, 0x12, 0x12, 0x0a, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x01, - 0x20, 0x01, 0x28, 0x0d, 0x52, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x22, 0x20, 0x0a, 0x04, 0x44, 0x61, - 0x74, 0x61, 0x12, 0x18, 0x0a, 0x07, 0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x0c, 0x52, 0x07, 0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x22, 0x17, 0x0a, 0x03, - 0x46, 0x69, 0x6e, 0x12, 0x10, 0x0a, 0x03, 0x61, 0x63, 0x6b, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, - 0x52, 0x03, 0x61, 0x63, 0x6b, 0x22, 0x3b, 0x0a, 0x03, 0x52, 0x73, 0x74, 0x12, 0x34, 0x0a, 0x04, - 0x63, 0x6f, 0x64, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x20, 0x2e, 0x63, 0x31, 0x2e, + 0x76, 0x31, 0x2e, 0x46, 0x69, 0x6e, 0x48, 0x00, 0x52, 0x03, 0x66, 0x69, 0x6e, 0x12, 0x30, 0x0a, + 0x03, 0x72, 0x73, 0x74, 0x18, 0x0e, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1c, 0x2e, 0x63, 0x31, 0x2e, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x61, 0x70, 0x69, 0x2e, 0x72, 0x74, 0x75, - 0x6e, 0x2e, 0x76, 0x31, 0x2e, 0x52, 0x73, 0x74, 0x43, 0x6f, 0x64, 0x65, 0x52, 0x04, 0x63, 0x6f, - 0x64, 0x65, 0x2a, 0x8c, 0x01, 0x0a, 0x07, 0x52, 0x73, 0x74, 0x43, 0x6f, 0x64, 0x65, 0x12, 0x18, - 0x0a, 0x14, 0x52, 0x53, 0x54, 0x5f, 0x43, 0x4f, 0x44, 0x45, 0x5f, 0x55, 0x4e, 0x53, 0x50, 0x45, - 0x43, 0x49, 0x46, 0x49, 0x45, 0x44, 0x10, 0x00, 0x12, 0x18, 0x0a, 0x14, 0x52, 0x53, 0x54, 0x5f, - 0x43, 0x4f, 0x44, 0x45, 0x5f, 0x4e, 0x4f, 0x5f, 0x4c, 0x49, 0x53, 0x54, 0x45, 0x4e, 0x45, 0x52, - 0x10, 0x01, 0x12, 0x20, 0x0a, 0x1c, 0x52, 0x53, 0x54, 0x5f, 0x43, 0x4f, 0x44, 0x45, 0x5f, 0x50, - 0x4f, 0x52, 0x54, 0x5f, 0x4e, 0x4f, 0x54, 0x5f, 0x41, 0x44, 0x56, 0x45, 0x52, 0x54, 0x49, 0x53, - 0x45, 0x44, 0x10, 0x02, 0x12, 0x14, 0x0a, 0x10, 0x52, 0x53, 0x54, 0x5f, 0x43, 0x4f, 0x44, 0x45, - 0x5f, 0x54, 0x49, 0x4d, 0x45, 0x4f, 0x55, 0x54, 0x10, 0x03, 0x12, 0x15, 0x0a, 0x11, 0x52, 0x53, - 0x54, 0x5f, 0x43, 0x4f, 0x44, 0x45, 0x5f, 0x49, 0x4e, 0x54, 0x45, 0x52, 0x4e, 0x41, 0x4c, 0x10, - 0x04, 0x32, 0x5b, 0x0a, 0x0d, 0x52, 0x65, 0x76, 0x65, 0x72, 0x73, 0x65, 0x54, 0x75, 0x6e, 0x6e, - 0x65, 0x6c, 0x12, 0x4a, 0x0a, 0x04, 0x4c, 0x69, 0x6e, 0x6b, 0x12, 0x1e, 0x2e, 0x63, 0x31, 0x2e, - 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x61, 0x70, 0x69, 0x2e, 0x72, 0x74, 0x75, - 0x6e, 0x2e, 0x76, 0x31, 0x2e, 0x46, 0x72, 0x61, 0x6d, 0x65, 0x1a, 0x1e, 0x2e, 0x63, 0x31, 0x2e, - 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x61, 0x70, 0x69, 0x2e, 0x72, 0x74, 0x75, - 0x6e, 0x2e, 0x76, 0x31, 0x2e, 0x46, 0x72, 0x61, 0x6d, 0x65, 0x28, 0x01, 0x30, 0x01, 0x42, 0x3e, + 0x6e, 0x2e, 0x76, 0x31, 0x2e, 0x52, 0x73, 0x74, 0x48, 0x00, 0x52, 0x03, 0x72, 0x73, 0x74, 0x42, + 0x06, 0x0a, 0x04, 0x6b, 0x69, 0x6e, 0x64, 0x22, 0x39, 0x0a, 0x05, 0x48, 0x65, 0x6c, 0x6c, 0x6f, + 0x12, 0x14, 0x0a, 0x05, 0x70, 0x6f, 0x72, 0x74, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0d, 0x52, + 0x05, 0x70, 0x6f, 0x72, 0x74, 0x73, 0x12, 0x1a, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, + 0x6f, 0x6c, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, + 0x6f, 0x6c, 0x22, 0x19, 0x0a, 0x03, 0x53, 0x79, 0x6e, 0x12, 0x12, 0x0a, 0x04, 0x70, 0x6f, 0x72, + 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x22, 0x20, 0x0a, + 0x04, 0x44, 0x61, 0x74, 0x61, 0x12, 0x18, 0x0a, 0x07, 0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x07, 0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x22, + 0x17, 0x0a, 0x03, 0x46, 0x69, 0x6e, 0x12, 0x10, 0x0a, 0x03, 0x61, 0x63, 0x6b, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x08, 0x52, 0x03, 0x61, 0x63, 0x6b, 0x22, 0x3b, 0x0a, 0x03, 0x52, 0x73, 0x74, 0x12, + 0x34, 0x0a, 0x04, 0x63, 0x6f, 0x64, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x20, 0x2e, + 0x63, 0x31, 0x2e, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x61, 0x70, 0x69, 0x2e, + 0x72, 0x74, 0x75, 0x6e, 0x2e, 0x76, 0x31, 0x2e, 0x52, 0x73, 0x74, 0x43, 0x6f, 0x64, 0x65, 0x52, + 0x04, 0x63, 0x6f, 0x64, 0x65, 0x2a, 0x8c, 0x01, 0x0a, 0x07, 0x52, 0x73, 0x74, 0x43, 0x6f, 0x64, + 0x65, 0x12, 0x18, 0x0a, 0x14, 0x52, 0x53, 0x54, 0x5f, 0x43, 0x4f, 0x44, 0x45, 0x5f, 0x55, 0x4e, + 0x53, 0x50, 0x45, 0x43, 0x49, 0x46, 0x49, 0x45, 0x44, 0x10, 0x00, 0x12, 0x18, 0x0a, 0x14, 0x52, + 0x53, 0x54, 0x5f, 0x43, 0x4f, 0x44, 0x45, 0x5f, 0x4e, 0x4f, 0x5f, 0x4c, 0x49, 0x53, 0x54, 0x45, + 0x4e, 0x45, 0x52, 0x10, 0x01, 0x12, 0x20, 0x0a, 0x1c, 0x52, 0x53, 0x54, 0x5f, 0x43, 0x4f, 0x44, + 0x45, 0x5f, 0x50, 0x4f, 0x52, 0x54, 0x5f, 0x4e, 0x4f, 0x54, 0x5f, 0x41, 0x44, 0x56, 0x45, 0x52, + 0x54, 0x49, 0x53, 0x45, 0x44, 0x10, 0x02, 0x12, 0x14, 0x0a, 0x10, 0x52, 0x53, 0x54, 0x5f, 0x43, + 0x4f, 0x44, 0x45, 0x5f, 0x54, 0x49, 0x4d, 0x45, 0x4f, 0x55, 0x54, 0x10, 0x03, 0x12, 0x15, 0x0a, + 0x11, 0x52, 0x53, 0x54, 0x5f, 0x43, 0x4f, 0x44, 0x45, 0x5f, 0x49, 0x4e, 0x54, 0x45, 0x52, 0x4e, + 0x41, 0x4c, 0x10, 0x04, 0x32, 0x97, 0x01, 0x0a, 0x14, 0x52, 0x65, 0x76, 0x65, 0x72, 0x73, 0x65, + 0x54, 0x75, 0x6e, 0x6e, 0x65, 0x6c, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x7f, 0x0a, + 0x04, 0x4c, 0x69, 0x6e, 0x6b, 0x12, 0x38, 0x2e, 0x63, 0x31, 0x2e, 0x63, 0x6f, 0x6e, 0x6e, 0x65, + 0x63, 0x74, 0x6f, 0x72, 0x61, 0x70, 0x69, 0x2e, 0x72, 0x74, 0x75, 0x6e, 0x2e, 0x76, 0x31, 0x2e, + 0x52, 0x65, 0x76, 0x65, 0x72, 0x73, 0x65, 0x54, 0x75, 0x6e, 0x6e, 0x65, 0x6c, 0x53, 0x65, 0x72, + 0x76, 0x69, 0x63, 0x65, 0x4c, 0x69, 0x6e, 0x6b, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, + 0x39, 0x2e, 0x63, 0x31, 0x2e, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x61, 0x70, + 0x69, 0x2e, 0x72, 0x74, 0x75, 0x6e, 0x2e, 0x76, 0x31, 0x2e, 0x52, 0x65, 0x76, 0x65, 0x72, 0x73, + 0x65, 0x54, 0x75, 0x6e, 0x6e, 0x65, 0x6c, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x4c, 0x69, + 0x6e, 0x6b, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x28, 0x01, 0x30, 0x01, 0x42, 0x3e, 0x5a, 0x3c, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x63, 0x6f, 0x6e, 0x64, 0x75, 0x63, 0x74, 0x6f, 0x72, 0x6f, 0x6e, 0x65, 0x2f, 0x62, 0x61, 0x74, 0x6f, 0x6e, 0x2d, 0x73, 0x64, 0x6b, 0x2f, 0x70, 0x62, 0x2f, 0x63, 0x31, 0x2f, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, @@ -515,30 +618,34 @@ func file_c1_connectorapi_rtun_v1_rtun_proto_rawDescGZIP() []byte { } var file_c1_connectorapi_rtun_v1_rtun_proto_enumTypes = make([]protoimpl.EnumInfo, 1) -var file_c1_connectorapi_rtun_v1_rtun_proto_msgTypes = make([]protoimpl.MessageInfo, 6) +var file_c1_connectorapi_rtun_v1_rtun_proto_msgTypes = make([]protoimpl.MessageInfo, 8) var file_c1_connectorapi_rtun_v1_rtun_proto_goTypes = []any{ - (RstCode)(0), // 0: c1.connectorapi.rtun.v1.RstCode - (*Frame)(nil), // 1: c1.connectorapi.rtun.v1.Frame - (*Hello)(nil), // 2: c1.connectorapi.rtun.v1.Hello - (*Syn)(nil), // 3: c1.connectorapi.rtun.v1.Syn - (*Data)(nil), // 4: c1.connectorapi.rtun.v1.Data - (*Fin)(nil), // 5: c1.connectorapi.rtun.v1.Fin - (*Rst)(nil), // 6: c1.connectorapi.rtun.v1.Rst + (RstCode)(0), // 0: c1.connectorapi.rtun.v1.RstCode + (*ReverseTunnelServiceLinkRequest)(nil), // 1: c1.connectorapi.rtun.v1.ReverseTunnelServiceLinkRequest + (*ReverseTunnelServiceLinkResponse)(nil), // 2: c1.connectorapi.rtun.v1.ReverseTunnelServiceLinkResponse + (*Frame)(nil), // 3: c1.connectorapi.rtun.v1.Frame + (*Hello)(nil), // 4: c1.connectorapi.rtun.v1.Hello + (*Syn)(nil), // 5: c1.connectorapi.rtun.v1.Syn + (*Data)(nil), // 6: c1.connectorapi.rtun.v1.Data + (*Fin)(nil), // 7: c1.connectorapi.rtun.v1.Fin + (*Rst)(nil), // 8: c1.connectorapi.rtun.v1.Rst } var file_c1_connectorapi_rtun_v1_rtun_proto_depIdxs = []int32{ - 2, // 0: c1.connectorapi.rtun.v1.Frame.hello:type_name -> c1.connectorapi.rtun.v1.Hello - 3, // 1: c1.connectorapi.rtun.v1.Frame.syn:type_name -> c1.connectorapi.rtun.v1.Syn - 4, // 2: c1.connectorapi.rtun.v1.Frame.data:type_name -> c1.connectorapi.rtun.v1.Data - 5, // 3: c1.connectorapi.rtun.v1.Frame.fin:type_name -> c1.connectorapi.rtun.v1.Fin - 6, // 4: c1.connectorapi.rtun.v1.Frame.rst:type_name -> c1.connectorapi.rtun.v1.Rst - 0, // 5: c1.connectorapi.rtun.v1.Rst.code:type_name -> c1.connectorapi.rtun.v1.RstCode - 1, // 6: c1.connectorapi.rtun.v1.ReverseTunnel.Link:input_type -> c1.connectorapi.rtun.v1.Frame - 1, // 7: c1.connectorapi.rtun.v1.ReverseTunnel.Link:output_type -> c1.connectorapi.rtun.v1.Frame - 7, // [7:8] is the sub-list for method output_type - 6, // [6:7] is the sub-list for method input_type - 6, // [6:6] is the sub-list for extension type_name - 6, // [6:6] is the sub-list for extension extendee - 0, // [0:6] is the sub-list for field type_name + 3, // 0: c1.connectorapi.rtun.v1.ReverseTunnelServiceLinkRequest.frame:type_name -> c1.connectorapi.rtun.v1.Frame + 3, // 1: c1.connectorapi.rtun.v1.ReverseTunnelServiceLinkResponse.frame:type_name -> c1.connectorapi.rtun.v1.Frame + 4, // 2: c1.connectorapi.rtun.v1.Frame.hello:type_name -> c1.connectorapi.rtun.v1.Hello + 5, // 3: c1.connectorapi.rtun.v1.Frame.syn:type_name -> c1.connectorapi.rtun.v1.Syn + 6, // 4: c1.connectorapi.rtun.v1.Frame.data:type_name -> c1.connectorapi.rtun.v1.Data + 7, // 5: c1.connectorapi.rtun.v1.Frame.fin:type_name -> c1.connectorapi.rtun.v1.Fin + 8, // 6: c1.connectorapi.rtun.v1.Frame.rst:type_name -> c1.connectorapi.rtun.v1.Rst + 0, // 7: c1.connectorapi.rtun.v1.Rst.code:type_name -> c1.connectorapi.rtun.v1.RstCode + 1, // 8: c1.connectorapi.rtun.v1.ReverseTunnelService.Link:input_type -> c1.connectorapi.rtun.v1.ReverseTunnelServiceLinkRequest + 2, // 9: c1.connectorapi.rtun.v1.ReverseTunnelService.Link:output_type -> c1.connectorapi.rtun.v1.ReverseTunnelServiceLinkResponse + 9, // [9:10] is the sub-list for method output_type + 8, // [8:9] is the sub-list for method input_type + 8, // [8:8] is the sub-list for extension type_name + 8, // [8:8] is the sub-list for extension extendee + 0, // [0:8] is the sub-list for field type_name } func init() { file_c1_connectorapi_rtun_v1_rtun_proto_init() } @@ -546,7 +653,7 @@ func file_c1_connectorapi_rtun_v1_rtun_proto_init() { if File_c1_connectorapi_rtun_v1_rtun_proto != nil { return } - file_c1_connectorapi_rtun_v1_rtun_proto_msgTypes[0].OneofWrappers = []any{ + file_c1_connectorapi_rtun_v1_rtun_proto_msgTypes[2].OneofWrappers = []any{ (*Frame_Hello)(nil), (*Frame_Syn)(nil), (*Frame_Data)(nil), @@ -559,7 +666,7 @@ func file_c1_connectorapi_rtun_v1_rtun_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: unsafe.Slice(unsafe.StringData(file_c1_connectorapi_rtun_v1_rtun_proto_rawDesc), len(file_c1_connectorapi_rtun_v1_rtun_proto_rawDesc)), NumEnums: 1, - NumMessages: 6, + NumMessages: 8, NumExtensions: 0, NumServices: 1, }, diff --git a/pb/c1/connectorapi/rtun/v1/rtun.pb.validate.go b/pb/c1/connectorapi/rtun/v1/rtun.pb.validate.go index 854f092e4..d3214846a 100644 --- a/pb/c1/connectorapi/rtun/v1/rtun.pb.validate.go +++ b/pb/c1/connectorapi/rtun/v1/rtun.pb.validate.go @@ -35,6 +35,272 @@ var ( _ = sort.Sort ) +// Validate checks the field values on ReverseTunnelServiceLinkRequest with the +// rules defined in the proto definition for this message. If any rules are +// violated, the first error encountered is returned, or nil if there are no violations. +func (m *ReverseTunnelServiceLinkRequest) Validate() error { + return m.validate(false) +} + +// ValidateAll checks the field values on ReverseTunnelServiceLinkRequest with +// the rules defined in the proto definition for this message. If any rules +// are violated, the result is a list of violation errors wrapped in +// ReverseTunnelServiceLinkRequestMultiError, or nil if none found. +func (m *ReverseTunnelServiceLinkRequest) ValidateAll() error { + return m.validate(true) +} + +func (m *ReverseTunnelServiceLinkRequest) validate(all bool) error { + if m == nil { + return nil + } + + var errors []error + + if all { + switch v := interface{}(m.GetFrame()).(type) { + case interface{ ValidateAll() error }: + if err := v.ValidateAll(); err != nil { + errors = append(errors, ReverseTunnelServiceLinkRequestValidationError{ + field: "Frame", + reason: "embedded message failed validation", + cause: err, + }) + } + case interface{ Validate() error }: + if err := v.Validate(); err != nil { + errors = append(errors, ReverseTunnelServiceLinkRequestValidationError{ + field: "Frame", + reason: "embedded message failed validation", + cause: err, + }) + } + } + } else if v, ok := interface{}(m.GetFrame()).(interface{ Validate() error }); ok { + if err := v.Validate(); err != nil { + return ReverseTunnelServiceLinkRequestValidationError{ + field: "Frame", + reason: "embedded message failed validation", + cause: err, + } + } + } + + if len(errors) > 0 { + return ReverseTunnelServiceLinkRequestMultiError(errors) + } + + return nil +} + +// ReverseTunnelServiceLinkRequestMultiError is an error wrapping multiple +// validation errors returned by ReverseTunnelServiceLinkRequest.ValidateAll() +// if the designated constraints aren't met. +type ReverseTunnelServiceLinkRequestMultiError []error + +// Error returns a concatenation of all the error messages it wraps. +func (m ReverseTunnelServiceLinkRequestMultiError) Error() string { + msgs := make([]string, 0, len(m)) + for _, err := range m { + msgs = append(msgs, err.Error()) + } + return strings.Join(msgs, "; ") +} + +// AllErrors returns a list of validation violation errors. +func (m ReverseTunnelServiceLinkRequestMultiError) AllErrors() []error { return m } + +// ReverseTunnelServiceLinkRequestValidationError is the validation error +// returned by ReverseTunnelServiceLinkRequest.Validate if the designated +// constraints aren't met. +type ReverseTunnelServiceLinkRequestValidationError struct { + field string + reason string + cause error + key bool +} + +// Field function returns field value. +func (e ReverseTunnelServiceLinkRequestValidationError) Field() string { return e.field } + +// Reason function returns reason value. +func (e ReverseTunnelServiceLinkRequestValidationError) Reason() string { return e.reason } + +// Cause function returns cause value. +func (e ReverseTunnelServiceLinkRequestValidationError) Cause() error { return e.cause } + +// Key function returns key value. +func (e ReverseTunnelServiceLinkRequestValidationError) Key() bool { return e.key } + +// ErrorName returns error name. +func (e ReverseTunnelServiceLinkRequestValidationError) ErrorName() string { + return "ReverseTunnelServiceLinkRequestValidationError" +} + +// Error satisfies the builtin error interface +func (e ReverseTunnelServiceLinkRequestValidationError) Error() string { + cause := "" + if e.cause != nil { + cause = fmt.Sprintf(" | caused by: %v", e.cause) + } + + key := "" + if e.key { + key = "key for " + } + + return fmt.Sprintf( + "invalid %sReverseTunnelServiceLinkRequest.%s: %s%s", + key, + e.field, + e.reason, + cause) +} + +var _ error = ReverseTunnelServiceLinkRequestValidationError{} + +var _ interface { + Field() string + Reason() string + Key() bool + Cause() error + ErrorName() string +} = ReverseTunnelServiceLinkRequestValidationError{} + +// Validate checks the field values on ReverseTunnelServiceLinkResponse with +// the rules defined in the proto definition for this message. If any rules +// are violated, the first error encountered is returned, or nil if there are +// no violations. +func (m *ReverseTunnelServiceLinkResponse) Validate() error { + return m.validate(false) +} + +// ValidateAll checks the field values on ReverseTunnelServiceLinkResponse with +// the rules defined in the proto definition for this message. If any rules +// are violated, the result is a list of violation errors wrapped in +// ReverseTunnelServiceLinkResponseMultiError, or nil if none found. +func (m *ReverseTunnelServiceLinkResponse) ValidateAll() error { + return m.validate(true) +} + +func (m *ReverseTunnelServiceLinkResponse) validate(all bool) error { + if m == nil { + return nil + } + + var errors []error + + if all { + switch v := interface{}(m.GetFrame()).(type) { + case interface{ ValidateAll() error }: + if err := v.ValidateAll(); err != nil { + errors = append(errors, ReverseTunnelServiceLinkResponseValidationError{ + field: "Frame", + reason: "embedded message failed validation", + cause: err, + }) + } + case interface{ Validate() error }: + if err := v.Validate(); err != nil { + errors = append(errors, ReverseTunnelServiceLinkResponseValidationError{ + field: "Frame", + reason: "embedded message failed validation", + cause: err, + }) + } + } + } else if v, ok := interface{}(m.GetFrame()).(interface{ Validate() error }); ok { + if err := v.Validate(); err != nil { + return ReverseTunnelServiceLinkResponseValidationError{ + field: "Frame", + reason: "embedded message failed validation", + cause: err, + } + } + } + + if len(errors) > 0 { + return ReverseTunnelServiceLinkResponseMultiError(errors) + } + + return nil +} + +// ReverseTunnelServiceLinkResponseMultiError is an error wrapping multiple +// validation errors returned by +// ReverseTunnelServiceLinkResponse.ValidateAll() if the designated +// constraints aren't met. +type ReverseTunnelServiceLinkResponseMultiError []error + +// Error returns a concatenation of all the error messages it wraps. +func (m ReverseTunnelServiceLinkResponseMultiError) Error() string { + msgs := make([]string, 0, len(m)) + for _, err := range m { + msgs = append(msgs, err.Error()) + } + return strings.Join(msgs, "; ") +} + +// AllErrors returns a list of validation violation errors. +func (m ReverseTunnelServiceLinkResponseMultiError) AllErrors() []error { return m } + +// ReverseTunnelServiceLinkResponseValidationError is the validation error +// returned by ReverseTunnelServiceLinkResponse.Validate if the designated +// constraints aren't met. +type ReverseTunnelServiceLinkResponseValidationError struct { + field string + reason string + cause error + key bool +} + +// Field function returns field value. +func (e ReverseTunnelServiceLinkResponseValidationError) Field() string { return e.field } + +// Reason function returns reason value. +func (e ReverseTunnelServiceLinkResponseValidationError) Reason() string { return e.reason } + +// Cause function returns cause value. +func (e ReverseTunnelServiceLinkResponseValidationError) Cause() error { return e.cause } + +// Key function returns key value. +func (e ReverseTunnelServiceLinkResponseValidationError) Key() bool { return e.key } + +// ErrorName returns error name. +func (e ReverseTunnelServiceLinkResponseValidationError) ErrorName() string { + return "ReverseTunnelServiceLinkResponseValidationError" +} + +// Error satisfies the builtin error interface +func (e ReverseTunnelServiceLinkResponseValidationError) Error() string { + cause := "" + if e.cause != nil { + cause = fmt.Sprintf(" | caused by: %v", e.cause) + } + + key := "" + if e.key { + key = "key for " + } + + return fmt.Sprintf( + "invalid %sReverseTunnelServiceLinkResponse.%s: %s%s", + key, + e.field, + e.reason, + cause) +} + +var _ error = ReverseTunnelServiceLinkResponseValidationError{} + +var _ interface { + Field() string + Reason() string + Key() bool + Cause() error + ErrorName() string +} = ReverseTunnelServiceLinkResponseValidationError{} + // Validate checks the field values on Frame with the rules defined in the // proto definition for this message. If any rules are violated, the first // error encountered is returned, or nil if there are no violations. diff --git a/pb/c1/connectorapi/rtun/v1/rtun_grpc.pb.go b/pb/c1/connectorapi/rtun/v1/rtun_grpc.pb.go index 9f4ee09a5..7e9ca2932 100644 --- a/pb/c1/connectorapi/rtun/v1/rtun_grpc.pb.go +++ b/pb/c1/connectorapi/rtun/v1/rtun_grpc.pb.go @@ -19,92 +19,92 @@ import ( const _ = grpc.SupportPackageIsVersion9 const ( - ReverseTunnel_Link_FullMethodName = "/c1.connectorapi.rtun.v1.ReverseTunnel/Link" + ReverseTunnelService_Link_FullMethodName = "/c1.connectorapi.rtun.v1.ReverseTunnelService/Link" ) -// ReverseTunnelClient is the client API for ReverseTunnel service. +// ReverseTunnelServiceClient is the client API for ReverseTunnelService service. // // For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. -type ReverseTunnelClient interface { - Link(ctx context.Context, opts ...grpc.CallOption) (grpc.BidiStreamingClient[Frame, Frame], error) +type ReverseTunnelServiceClient interface { + Link(ctx context.Context, opts ...grpc.CallOption) (grpc.BidiStreamingClient[ReverseTunnelServiceLinkRequest, ReverseTunnelServiceLinkResponse], error) } -type reverseTunnelClient struct { +type reverseTunnelServiceClient struct { cc grpc.ClientConnInterface } -func NewReverseTunnelClient(cc grpc.ClientConnInterface) ReverseTunnelClient { - return &reverseTunnelClient{cc} +func NewReverseTunnelServiceClient(cc grpc.ClientConnInterface) ReverseTunnelServiceClient { + return &reverseTunnelServiceClient{cc} } -func (c *reverseTunnelClient) Link(ctx context.Context, opts ...grpc.CallOption) (grpc.BidiStreamingClient[Frame, Frame], error) { +func (c *reverseTunnelServiceClient) Link(ctx context.Context, opts ...grpc.CallOption) (grpc.BidiStreamingClient[ReverseTunnelServiceLinkRequest, ReverseTunnelServiceLinkResponse], error) { cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) - stream, err := c.cc.NewStream(ctx, &ReverseTunnel_ServiceDesc.Streams[0], ReverseTunnel_Link_FullMethodName, cOpts...) + stream, err := c.cc.NewStream(ctx, &ReverseTunnelService_ServiceDesc.Streams[0], ReverseTunnelService_Link_FullMethodName, cOpts...) if err != nil { return nil, err } - x := &grpc.GenericClientStream[Frame, Frame]{ClientStream: stream} + x := &grpc.GenericClientStream[ReverseTunnelServiceLinkRequest, ReverseTunnelServiceLinkResponse]{ClientStream: stream} return x, nil } // This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. -type ReverseTunnel_LinkClient = grpc.BidiStreamingClient[Frame, Frame] +type ReverseTunnelService_LinkClient = grpc.BidiStreamingClient[ReverseTunnelServiceLinkRequest, ReverseTunnelServiceLinkResponse] -// ReverseTunnelServer is the server API for ReverseTunnel service. -// All implementations should embed UnimplementedReverseTunnelServer +// ReverseTunnelServiceServer is the server API for ReverseTunnelService service. +// All implementations should embed UnimplementedReverseTunnelServiceServer // for forward compatibility. -type ReverseTunnelServer interface { - Link(grpc.BidiStreamingServer[Frame, Frame]) error +type ReverseTunnelServiceServer interface { + Link(grpc.BidiStreamingServer[ReverseTunnelServiceLinkRequest, ReverseTunnelServiceLinkResponse]) error } -// UnimplementedReverseTunnelServer should be embedded to have +// UnimplementedReverseTunnelServiceServer should be embedded to have // forward compatible implementations. // // NOTE: this should be embedded by value instead of pointer to avoid a nil // pointer dereference when methods are called. -type UnimplementedReverseTunnelServer struct{} +type UnimplementedReverseTunnelServiceServer struct{} -func (UnimplementedReverseTunnelServer) Link(grpc.BidiStreamingServer[Frame, Frame]) error { +func (UnimplementedReverseTunnelServiceServer) Link(grpc.BidiStreamingServer[ReverseTunnelServiceLinkRequest, ReverseTunnelServiceLinkResponse]) error { return status.Errorf(codes.Unimplemented, "method Link not implemented") } -func (UnimplementedReverseTunnelServer) testEmbeddedByValue() {} +func (UnimplementedReverseTunnelServiceServer) testEmbeddedByValue() {} -// UnsafeReverseTunnelServer may be embedded to opt out of forward compatibility for this service. -// Use of this interface is not recommended, as added methods to ReverseTunnelServer will +// UnsafeReverseTunnelServiceServer may be embedded to opt out of forward compatibility for this service. +// Use of this interface is not recommended, as added methods to ReverseTunnelServiceServer will // result in compilation errors. -type UnsafeReverseTunnelServer interface { - mustEmbedUnimplementedReverseTunnelServer() +type UnsafeReverseTunnelServiceServer interface { + mustEmbedUnimplementedReverseTunnelServiceServer() } -func RegisterReverseTunnelServer(s grpc.ServiceRegistrar, srv ReverseTunnelServer) { - // If the following call pancis, it indicates UnimplementedReverseTunnelServer was +func RegisterReverseTunnelServiceServer(s grpc.ServiceRegistrar, srv ReverseTunnelServiceServer) { + // If the following call pancis, it indicates UnimplementedReverseTunnelServiceServer was // embedded by pointer and is nil. This will cause panics if an // unimplemented method is ever invoked, so we test this at initialization // time to prevent it from happening at runtime later due to I/O. if t, ok := srv.(interface{ testEmbeddedByValue() }); ok { t.testEmbeddedByValue() } - s.RegisterService(&ReverseTunnel_ServiceDesc, srv) + s.RegisterService(&ReverseTunnelService_ServiceDesc, srv) } -func _ReverseTunnel_Link_Handler(srv interface{}, stream grpc.ServerStream) error { - return srv.(ReverseTunnelServer).Link(&grpc.GenericServerStream[Frame, Frame]{ServerStream: stream}) +func _ReverseTunnelService_Link_Handler(srv interface{}, stream grpc.ServerStream) error { + return srv.(ReverseTunnelServiceServer).Link(&grpc.GenericServerStream[ReverseTunnelServiceLinkRequest, ReverseTunnelServiceLinkResponse]{ServerStream: stream}) } // This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. -type ReverseTunnel_LinkServer = grpc.BidiStreamingServer[Frame, Frame] +type ReverseTunnelService_LinkServer = grpc.BidiStreamingServer[ReverseTunnelServiceLinkRequest, ReverseTunnelServiceLinkResponse] -// ReverseTunnel_ServiceDesc is the grpc.ServiceDesc for ReverseTunnel service. +// ReverseTunnelService_ServiceDesc is the grpc.ServiceDesc for ReverseTunnelService service. // It's only intended for direct use with grpc.RegisterService, // and not to be introspected or modified (even as a copy) -var ReverseTunnel_ServiceDesc = grpc.ServiceDesc{ - ServiceName: "c1.connectorapi.rtun.v1.ReverseTunnel", - HandlerType: (*ReverseTunnelServer)(nil), +var ReverseTunnelService_ServiceDesc = grpc.ServiceDesc{ + ServiceName: "c1.connectorapi.rtun.v1.ReverseTunnelService", + HandlerType: (*ReverseTunnelServiceServer)(nil), Methods: []grpc.MethodDesc{}, Streams: []grpc.StreamDesc{ { StreamName: "Link", - Handler: _ReverseTunnel_Link_Handler, + Handler: _ReverseTunnelService_Link_Handler, ServerStreams: true, ClientStreams: true, }, diff --git a/pkg/rtun/gateway/client.go b/pkg/rtun/gateway/client.go index 645461ea0..4dc4f2795 100644 --- a/pkg/rtun/gateway/client.go +++ b/pkg/rtun/gateway/client.go @@ -82,7 +82,7 @@ func (d *Dialer) DialContext(ctx context.Context, clientID string, port uint32) return nil, fmt.Errorf("gateway dial failed: %w", err) } - client := rtunpb.NewReverseDialerClient(cc) + client := rtunpb.NewReverseDialerServiceClient(cc) // Create a cancellable stream context so Close() can interrupt Recv/Send. streamCtx, cancel := context.WithCancel(ctx) stream, err := client.Open(streamCtx) @@ -94,7 +94,7 @@ func (d *Dialer) DialContext(ctx context.Context, clientID string, port uint32) // Send OpenRequest with gSID=1 (simple case: one connection per stream) gsid := uint32(1) - if err := stream.Send(&rtunpb.GatewayRequest{Kind: &rtunpb.GatewayRequest_OpenReq{ + if err := stream.Send(&rtunpb.ReverseDialerServiceOpenRequest{Kind: &rtunpb.ReverseDialerServiceOpenRequest_OpenReq{ OpenReq: &rtunpb.OpenRequest{Gsid: gsid, ClientId: clientID, Port: port}, }}); err != nil { stream.CloseSend() @@ -158,7 +158,7 @@ var _ net.Conn = (*gatewayConn)(nil) // gatewayConn implements net.Conn over a gateway stream. type gatewayConn struct { - stream rtunpb.ReverseDialer_OpenClient + stream rtunpb.ReverseDialerService_OpenClient cc *grpc.ClientConn gsid uint32 cancel context.CancelFunc @@ -183,7 +183,7 @@ type writeMsg struct { } type reader struct { - stream rtunpb.ReverseDialer_OpenClient + stream rtunpb.ReverseDialerService_OpenClient gsid uint32 bufCap int ch chan []byte @@ -195,7 +195,7 @@ type reader struct { once sync.Once } -func newReader(stream rtunpb.ReverseDialer_OpenClient, gsid uint32, bufCap int, doneCh <-chan struct{}) *reader { +func newReader(stream rtunpb.ReverseDialerService_OpenClient, gsid uint32, bufCap int, doneCh <-chan struct{}) *reader { if bufCap <= 0 { bufCap = defaultReadBufferCap } @@ -274,7 +274,7 @@ func (r *reader) next(ctx context.Context) ([]byte, error) { } type writer struct { - stream rtunpb.ReverseDialer_OpenClient + stream rtunpb.ReverseDialerService_OpenClient gsid uint32 queueCap int ch chan writeMsg @@ -285,7 +285,7 @@ type writer struct { once sync.Once } -func newWriter(stream rtunpb.ReverseDialer_OpenClient, gsid uint32, queueCap int, doneCh <-chan struct{}) *writer { +func newWriter(stream rtunpb.ReverseDialerService_OpenClient, gsid uint32, queueCap int, doneCh <-chan struct{}) *writer { if queueCap <= 0 { queueCap = defaultWriteQueueCap } @@ -323,12 +323,12 @@ func (w *writer) loop() { select { case msg := <-w.ch: if msg.fin { - _ = w.stream.Send(&rtunpb.GatewayRequest{Kind: &rtunpb.GatewayRequest_Frame{ + _ = w.stream.Send(&rtunpb.ReverseDialerServiceOpenRequest{Kind: &rtunpb.ReverseDialerServiceOpenRequest_Frame{ Frame: &rtunpb.Frame{Sid: w.gsid, Kind: &rtunpb.Frame_Fin{Fin: &rtunpb.Fin{}}}, }}) continue } - if err := w.stream.Send(&rtunpb.GatewayRequest{Kind: &rtunpb.GatewayRequest_Frame{ + if err := w.stream.Send(&rtunpb.ReverseDialerServiceOpenRequest{Kind: &rtunpb.ReverseDialerServiceOpenRequest_Frame{ Frame: &rtunpb.Frame{Sid: w.gsid, Kind: &rtunpb.Frame_Data{Data: &rtunpb.Data{Payload: msg.payload}}}, }}); err != nil { w.setErr(err) @@ -478,7 +478,7 @@ func (g *gatewayConn) Close() error { default: } } else { - _ = g.stream.Send(&rtunpb.GatewayRequest{Kind: &rtunpb.GatewayRequest_Frame{ + _ = g.stream.Send(&rtunpb.ReverseDialerServiceOpenRequest{Kind: &rtunpb.ReverseDialerServiceOpenRequest_Frame{ Frame: &rtunpb.Frame{Sid: g.gsid, Kind: &rtunpb.Frame_Fin{Fin: &rtunpb.Fin{}}}, }}) } diff --git a/pkg/rtun/gateway/client_conn_test.go b/pkg/rtun/gateway/client_conn_test.go index 2cc424efc..936d6c6cb 100644 --- a/pkg/rtun/gateway/client_conn_test.go +++ b/pkg/rtun/gateway/client_conn_test.go @@ -33,8 +33,8 @@ func setupGateway(t *testing.T, silent bool) *gwEnv { gw := NewServer(reg, "server-a", nil) gsrv := grpc.NewServer() - rtunpb.RegisterReverseTunnelServer(gsrv, handler) - rtunpb.RegisterReverseDialerServer(gsrv, gw) + rtunpb.RegisterReverseTunnelServiceServer(gsrv, handler) + rtunpb.RegisterReverseDialerServiceServer(gsrv, gw) l, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) go func() { _ = gsrv.Serve(l) }() @@ -42,7 +42,7 @@ func setupGateway(t *testing.T, silent bool) *gwEnv { // Bring up a client link and listen on port 1 cc, err := grpc.Dial(l.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials())) require.NoError(t, err) - rtunClient := rtunpb.NewReverseTunnelClient(cc) + rtunClient := rtunpb.NewReverseTunnelServiceClient(cc) stream, err := rtunClient.Link(context.Background()) require.NoError(t, err) cl := &clientLink{cli: stream} diff --git a/pkg/rtun/gateway/integration_test.go b/pkg/rtun/gateway/integration_test.go index 63f66a163..cc7329172 100644 --- a/pkg/rtun/gateway/integration_test.go +++ b/pkg/rtun/gateway/integration_test.go @@ -24,12 +24,22 @@ func (t testValidator) ValidateHello(ctx context.Context, hello *rtunpb.Hello) e // clientLink adapts client bidi stream to transport.Link. type clientLink struct { - cli rtunpb.ReverseTunnel_LinkClient + cli rtunpb.ReverseTunnelService_LinkClient } -func (c *clientLink) Send(f *rtunpb.Frame) error { return c.cli.Send(f) } -func (c *clientLink) Recv() (*rtunpb.Frame, error) { return c.cli.Recv() } -func (c *clientLink) Context() context.Context { return c.cli.Context() } +func (c *clientLink) Send(f *rtunpb.Frame) error { + return c.cli.Send(&rtunpb.ReverseTunnelServiceLinkRequest{Frame: f}) +} + +func (c *clientLink) Recv() (*rtunpb.Frame, error) { + fr, err := c.cli.Recv() + if err != nil { + return nil, err + } + return fr.GetFrame(), nil +} + +func (c *clientLink) Context() context.Context { return c.cli.Context() } // TestGatewayE2E validates the full gateway stack: // - Client connects to server A (handler+registry). @@ -46,8 +56,8 @@ func TestGatewayE2E(t *testing.T) { gwA := NewServer(regA, "server-a", nil) gsrvA := grpc.NewServer() - rtunpb.RegisterReverseTunnelServer(gsrvA, handlerA) - rtunpb.RegisterReverseDialerServer(gsrvA, gwA) + rtunpb.RegisterReverseTunnelServiceServer(gsrvA, handlerA) + rtunpb.RegisterReverseDialerServiceServer(gsrvA, gwA) lA, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) @@ -60,7 +70,7 @@ func TestGatewayE2E(t *testing.T) { ccA, err := grpc.Dial(lA.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials())) require.NoError(t, err) - rtunClient := rtunpb.NewReverseTunnelClient(ccA) + rtunClient := rtunpb.NewReverseTunnelServiceClient(ccA) stream, err := rtunClient.Link(clientCtx) require.NoError(t, err) @@ -120,7 +130,7 @@ func TestGatewayNotFound(t *testing.T) { gwA := NewServer(regA, "server-a", nil) gsrvA := grpc.NewServer() - rtunpb.RegisterReverseDialerServer(gsrvA, gwA) + rtunpb.RegisterReverseDialerServiceServer(gsrvA, gwA) lA, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) diff --git a/pkg/rtun/gateway/server.go b/pkg/rtun/gateway/server.go index 556a20b41..afc9ecea6 100644 --- a/pkg/rtun/gateway/server.go +++ b/pkg/rtun/gateway/server.go @@ -2,6 +2,7 @@ package gateway import ( "context" + "errors" "io" "net" "net/url" @@ -18,7 +19,7 @@ import ( // Server implements the ReverseDialer gateway service. // It bridges caller streams to rtun sessions on the owner server process. type Server struct { - rtunpb.UnimplementedReverseDialerServer + rtunpb.UnimplementedReverseDialerServiceServer reg *server.Registry serverID string @@ -113,7 +114,7 @@ func (e *entry) close() { }) } -func (s *Server) Open(stream rtunpb.ReverseDialer_OpenServer) error { +func (s *Server) Open(stream rtunpb.ReverseDialerService_OpenServer) error { ctx := stream.Context() logger := ctxzap.Extract(ctx).With(zap.String("server_id", s.serverID)) @@ -135,14 +136,14 @@ func (s *Server) Open(stream rtunpb.ReverseDialer_OpenServer) error { for { req, err := stream.Recv() if err != nil { - if err == io.EOF { + if errors.Is(err, io.EOF) { return nil } return err } switch k := req.Kind.(type) { - case *rtunpb.GatewayRequest_OpenReq: + case *rtunpb.ReverseDialerServiceOpenRequest_OpenReq: openReq := k.OpenReq gsid := openReq.GetGsid() clientID := openReq.GetClientId() @@ -159,7 +160,7 @@ func (s *Server) Open(stream rtunpb.ReverseDialer_OpenServer) error { if _, exists := entries[gsid]; exists { mu.Unlock() logger.Warn("duplicate gSID in OpenRequest") - _ = stream.Send(&rtunpb.GatewayResponse{Kind: &rtunpb.GatewayResponse_Frame{ + _ = stream.Send(&rtunpb.ReverseDialerServiceOpenResponse{Kind: &rtunpb.ReverseDialerServiceOpenResponse_Frame{ Frame: &rtunpb.Frame{Sid: gsid, Kind: &rtunpb.Frame_Rst{Rst: &rtunpb.Rst{Code: rtunpb.RstCode_RST_CODE_INTERNAL}}}, }}) continue @@ -171,7 +172,7 @@ func (s *Server) Open(stream rtunpb.ReverseDialer_OpenServer) error { conn, err := s.reg.DialContext(ctx, addr) if err != nil { logger.Info("client not found or dial failed", zap.Error(err)) - _ = stream.Send(&rtunpb.GatewayResponse{Kind: &rtunpb.GatewayResponse_OpenResp{ + _ = stream.Send(&rtunpb.ReverseDialerServiceOpenResponse{Kind: &rtunpb.ReverseDialerServiceOpenResponse_OpenResp{ OpenResp: &rtunpb.OpenResponse{Gsid: gsid, Result: &rtunpb.OpenResponse_NotFound{NotFound: &rtunpb.NotFound{}}}, }}) if s.m != nil { @@ -187,7 +188,7 @@ func (s *Server) Open(stream rtunpb.ReverseDialer_OpenServer) error { mu.Unlock() logger.Info("opened reverse connection") - if err := stream.Send(&rtunpb.GatewayResponse{Kind: &rtunpb.GatewayResponse_OpenResp{ + if err := stream.Send(&rtunpb.ReverseDialerServiceOpenResponse{Kind: &rtunpb.ReverseDialerServiceOpenResponse_OpenResp{ OpenResp: &rtunpb.OpenResponse{Gsid: gsid, Result: &rtunpb.OpenResponse_Opened{Opened: &rtunpb.Opened{}}}, }}); err != nil { conn.Close() @@ -201,7 +202,7 @@ func (s *Server) Open(stream rtunpb.ReverseDialer_OpenServer) error { wg.Add(1) go s.bridgeRead(ctx, conn, gsid, stream, &wg, logger) - case *rtunpb.GatewayRequest_Frame: + case *rtunpb.ReverseDialerServiceOpenRequest_Frame: fr := k.Frame gsid := fr.GetSid() @@ -211,7 +212,7 @@ func (s *Server) Open(stream rtunpb.ReverseDialer_OpenServer) error { if ent == nil { // Unknown gSID; send RST - _ = stream.Send(&rtunpb.GatewayResponse{Kind: &rtunpb.GatewayResponse_Frame{ + _ = stream.Send(&rtunpb.ReverseDialerServiceOpenResponse{Kind: &rtunpb.ReverseDialerServiceOpenResponse_Frame{ Frame: &rtunpb.Frame{Sid: gsid, Kind: &rtunpb.Frame_Rst{Rst: &rtunpb.Rst{Code: rtunpb.RstCode_RST_CODE_INTERNAL}}}, }}) continue @@ -253,13 +254,13 @@ func (s *Server) Open(stream rtunpb.ReverseDialer_OpenServer) error { } // bridgeRead reads from rtun conn and sends frames to the caller stream. -func (s *Server) bridgeRead(ctx context.Context, conn net.Conn, gsid uint32, stream rtunpb.ReverseDialer_OpenServer, wg *sync.WaitGroup, logger *zap.Logger) { +func (s *Server) bridgeRead(ctx context.Context, conn net.Conn, gsid uint32, stream rtunpb.ReverseDialerService_OpenServer, wg *sync.WaitGroup, logger *zap.Logger) { defer wg.Done() buf := make([]byte, 32*1024) for { n, err := conn.Read(buf) if n > 0 { - if err := stream.Send(&rtunpb.GatewayResponse{Kind: &rtunpb.GatewayResponse_Frame{ + if err := stream.Send(&rtunpb.ReverseDialerServiceOpenResponse{Kind: &rtunpb.ReverseDialerServiceOpenResponse_Frame{ Frame: &rtunpb.Frame{Sid: gsid, Kind: &rtunpb.Frame_Data{Data: &rtunpb.Data{Payload: append([]byte(nil), buf[:n]...)}}}, }}); err != nil { logger.Warn("failed to send data to caller", zap.Error(err)) @@ -270,15 +271,15 @@ func (s *Server) bridgeRead(ctx context.Context, conn net.Conn, gsid uint32, str } } if err != nil { - if err == io.EOF { - _ = stream.Send(&rtunpb.GatewayResponse{Kind: &rtunpb.GatewayResponse_Frame{ + if errors.Is(err, io.EOF) { + _ = stream.Send(&rtunpb.ReverseDialerServiceOpenResponse{Kind: &rtunpb.ReverseDialerServiceOpenResponse_Frame{ Frame: &rtunpb.Frame{Sid: gsid, Kind: &rtunpb.Frame_Fin{Fin: &rtunpb.Fin{}}}, }}) if s.m != nil { s.m.addFrameTx(ctx, "FIN") } } else { - _ = stream.Send(&rtunpb.GatewayResponse{Kind: &rtunpb.GatewayResponse_Frame{ + _ = stream.Send(&rtunpb.ReverseDialerServiceOpenResponse{Kind: &rtunpb.ReverseDialerServiceOpenResponse_Frame{ Frame: &rtunpb.Frame{Sid: gsid, Kind: &rtunpb.Frame_Rst{Rst: &rtunpb.Rst{Code: rtunpb.RstCode_RST_CODE_INTERNAL}}}, }}) if s.m != nil { diff --git a/pkg/rtun/match/route_test.go b/pkg/rtun/match/route_test.go index 572152ed1..21264c616 100644 --- a/pkg/rtun/match/route_test.go +++ b/pkg/rtun/match/route_test.go @@ -36,7 +36,7 @@ func TestOwnerRouterTwoServers(t *testing.T) { regA := server.NewRegistry() handlerA := server.NewHandler(regA, "server-a", testValidator{id: "client-123"}) gsrvA := grpc.NewServer() - rtunpb.RegisterReverseTunnelServer(gsrvA, handlerA) + rtunpb.RegisterReverseTunnelServiceServer(gsrvA, handlerA) lA, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) go func() { _ = gsrvA.Serve(lA) }() @@ -50,7 +50,7 @@ func TestOwnerRouterTwoServers(t *testing.T) { ccA, err := grpc.Dial(lA.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials())) require.NoError(t, err) - rtunClientA := rtunpb.NewReverseTunnelClient(ccA) + rtunClientA := rtunpb.NewReverseTunnelServiceClient(ccA) streamA, err := rtunClientA.Link(clientCtx) require.NoError(t, err) @@ -110,9 +110,19 @@ func TestOwnerRouterTwoServers(t *testing.T) { // clientLink adapts the client bidi stream to transport.Link. type clientLink struct { - cli rtunpb.ReverseTunnel_LinkClient + cli rtunpb.ReverseTunnelService_LinkClient } -func (c *clientLink) Send(f *rtunpb.Frame) error { return c.cli.Send(f) } -func (c *clientLink) Recv() (*rtunpb.Frame, error) { return c.cli.Recv() } -func (c *clientLink) Context() context.Context { return c.cli.Context() } +func (c *clientLink) Send(f *rtunpb.Frame) error { + return c.cli.Send(&rtunpb.ReverseTunnelServiceLinkRequest{Frame: f}) +} + +func (c *clientLink) Recv() (*rtunpb.Frame, error) { + fr, err := c.cli.Recv() + if err != nil { + return nil, err + } + return fr.GetFrame(), nil +} + +func (c *clientLink) Context() context.Context { return c.cli.Context() } diff --git a/pkg/rtun/server/handler.go b/pkg/rtun/server/handler.go index 01819e7d4..0cf2da6e2 100644 --- a/pkg/rtun/server/handler.go +++ b/pkg/rtun/server/handler.go @@ -13,7 +13,7 @@ import ( // Handler implements the ReverseTunnel gRPC service, binding Links to Sessions and the Registry. type Handler struct { - rtunpb.UnimplementedReverseTunnelServer + rtunpb.UnimplementedReverseTunnelServiceServer reg *Registry serverID string @@ -25,7 +25,7 @@ type Handler struct { m *serverMetrics } -func NewHandler(reg *Registry, serverID string, tv TokenValidator, opts ...Option) rtunpb.ReverseTunnelServer { +func NewHandler(reg *Registry, serverID string, tv TokenValidator, opts ...Option) rtunpb.ReverseTunnelServiceServer { var o options for _, opt := range opts { if opt == nil { @@ -41,7 +41,7 @@ func NewHandler(reg *Registry, serverID string, tv TokenValidator, opts ...Optio } // Link accepts a bidi stream and binds it to a transport.Session after validating HELLO. -func (h *Handler) Link(stream rtunpb.ReverseTunnel_LinkServer) error { +func (h *Handler) Link(stream rtunpb.ReverseTunnelService_LinkServer) error { // Wrap the gRPC stream as transport.Link l := &grpcLink{srv: stream} @@ -136,9 +136,19 @@ func (h *Handler) Link(stream rtunpb.ReverseTunnel_LinkServer) error { // grpcLink adapts the gRPC server stream to transport.Link type grpcLink struct { - srv rtunpb.ReverseTunnel_LinkServer + srv rtunpb.ReverseTunnelService_LinkServer } -func (g *grpcLink) Send(f *rtunpb.Frame) error { return g.srv.Send(f) } -func (g *grpcLink) Recv() (*rtunpb.Frame, error) { return g.srv.Recv() } -func (g *grpcLink) Context() context.Context { return g.srv.Context() } +func (g *grpcLink) Send(f *rtunpb.Frame) error { + return g.srv.Send(&rtunpb.ReverseTunnelServiceLinkResponse{Frame: f}) +} + +func (g *grpcLink) Recv() (*rtunpb.Frame, error) { + fr, err := g.srv.Recv() + if err != nil { + return nil, err + } + return fr.GetFrame(), nil +} + +func (g *grpcLink) Context() context.Context { return g.srv.Context() } diff --git a/pkg/rtun/server/server_integration_test.go b/pkg/rtun/server/server_integration_test.go index 62d3db82e..c91485287 100644 --- a/pkg/rtun/server/server_integration_test.go +++ b/pkg/rtun/server/server_integration_test.go @@ -24,12 +24,22 @@ func (t testValidator) ValidateHello(ctx context.Context, hello *rtunpb.Hello) e // clientLink adapts the client bidi stream to transport.Link on the client side. type clientLink struct { - cli rtunpb.ReverseTunnel_LinkClient + cli rtunpb.ReverseTunnelService_LinkClient } -func (c clientLink) Send(f *rtunpb.Frame) error { return c.cli.Send(f) } -func (c clientLink) Recv() (*rtunpb.Frame, error) { return c.cli.Recv() } -func (c clientLink) Context() context.Context { return c.cli.Context() } +func (c clientLink) Send(f *rtunpb.Frame) error { + return c.cli.Send(&rtunpb.ReverseTunnelServiceLinkRequest{Frame: f}) +} + +func (c clientLink) Recv() (*rtunpb.Frame, error) { + resp, err := c.cli.Recv() + if err != nil { + return nil, err + } + return resp.GetFrame(), nil +} + +func (c clientLink) Context() context.Context { return c.cli.Context() } // TestReverseGrpcE2E spins up a real gRPC server with Handler, connects a real gRPC client stream for Link, // runs the standard gRPC health service over rtun on the client, and performs a health check from the owner via Registry.DialContext. @@ -38,7 +48,7 @@ func TestReverseGrpcE2E(t *testing.T) { reg := NewRegistry() h := NewHandler(reg, "server-1", testValidator{id: "client-123"}) gsrv := grpc.NewServer() - rtunpb.RegisterReverseTunnelServer(gsrv, h) + rtunpb.RegisterReverseTunnelServiceServer(gsrv, h) l, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) defer l.Close() @@ -49,7 +59,7 @@ func TestReverseGrpcE2E(t *testing.T) { cc, err := grpc.Dial(l.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials())) require.NoError(t, err) defer cc.Close() - rtunClient := rtunpb.NewReverseTunnelClient(cc) + rtunClient := rtunpb.NewReverseTunnelServiceClient(cc) stream, err := rtunClient.Link(context.Background()) require.NoError(t, err) diff --git a/pkg/rtun/transport/conn.go b/pkg/rtun/transport/conn.go index 6b2a3c249..7d29c1331 100644 --- a/pkg/rtun/transport/conn.go +++ b/pkg/rtun/transport/conn.go @@ -121,7 +121,7 @@ func (c *virtConn) Write(p []byte) (int, error) { c.writeMu.Lock() defer c.writeMu.Unlock() if c.writeClosed { - return 0, ErrClosed + return 0, net.ErrClosed } // Writes are allowed even after remote FIN (half-close), so we do not block based on remote state. total := 0 diff --git a/pkg/rtun/transport/conn_test.go b/pkg/rtun/transport/conn_test.go index 2b4e61ba5..44ea5721f 100644 --- a/pkg/rtun/transport/conn_test.go +++ b/pkg/rtun/transport/conn_test.go @@ -37,11 +37,11 @@ func TestVirtConnCloseIdempotentAndWriteAfterClose(t *testing.T) { // Write after close should fail with ErrClosed _, err = c.Write([]byte("x")) - require.ErrorIs(t, err, ErrClosed) + require.ErrorIs(t, err, net.ErrClosed) // Read after close yields EOF or ErrClosed _, err = c.Read(make([]byte, 1)) - require.True(t, err == io.EOF || err == ErrClosed) + require.True(t, err == io.EOF || err == net.ErrClosed) } func TestVirtConnRemoteRstPropagatesToRead(t *testing.T) { diff --git a/pkg/rtun/transport/errors.go b/pkg/rtun/transport/errors.go index e18230dda..a77acd39f 100644 --- a/pkg/rtun/transport/errors.go +++ b/pkg/rtun/transport/errors.go @@ -4,6 +4,5 @@ import "errors" var ( ErrConnReset = errors.New("rtun: connection reset") - ErrClosed = errors.New("rtun: closed") ErrTimeout = errors.New("rtun: deadline exceeded") ) diff --git a/pkg/rtun/transport/listener.go b/pkg/rtun/transport/listener.go index c8fc1e3f4..a6c3fcfe6 100644 --- a/pkg/rtun/transport/listener.go +++ b/pkg/rtun/transport/listener.go @@ -15,15 +15,21 @@ type rtunListener struct { } func (l *rtunListener) Accept() (net.Conn, error) { - if l.err != nil { - return nil, l.err + l.mu.Lock() + lerr := l.err + l.mu.Unlock() + if lerr != nil { + return nil, lerr } c, ok := <-l.accepts if !ok { - if l.err != nil { - return nil, l.err + l.mu.Lock() + lerr := l.err + l.mu.Unlock() + if lerr != nil { + return nil, lerr } - return nil, ErrClosed + return nil, net.ErrClosed } return c, nil } @@ -48,7 +54,7 @@ func (l *rtunListener) enqueue(c *virtConn) { case l.accepts <- c: default: // listener full, drop - c.handleRst(ErrClosed) + c.handleRst(net.ErrClosed) } } diff --git a/pkg/rtun/transport/session.go b/pkg/rtun/transport/session.go index b3b8534d3..e99ae09ec 100644 --- a/pkg/rtun/transport/session.go +++ b/pkg/rtun/transport/session.go @@ -164,7 +164,7 @@ func (s *Session) addListener(l *rtunListener) error { s.mu.Lock() defer s.mu.Unlock() if s.closing { - return ErrClosed + return net.ErrClosed } if _, exists := s.listeners[l.port]; exists { return errors.New("rtun: listener already exists for port") @@ -348,7 +348,7 @@ func (s *Session) Open(ctx context.Context, port uint32) (net.Conn, error) { s.mu.Lock() if s.closing { s.mu.Unlock() - return nil, ErrClosed + return nil, net.ErrClosed } sid := s.nextSID if sid == 0 { diff --git a/proto/c1/connectorapi/rtun/v1/gateway.proto b/proto/c1/connectorapi/rtun/v1/gateway.proto index 88cca67fd..d7564aa48 100644 --- a/proto/c1/connectorapi/rtun/v1/gateway.proto +++ b/proto/c1/connectorapi/rtun/v1/gateway.proto @@ -8,20 +8,20 @@ option go_package = "github.com/conductorone/baton-sdk/pb/c1/connectorapi/rtun/v // ReverseDialer allows callers to establish connections to clients via the gateway. // The gateway bridges caller streams to rtun sessions on the owner server. -service ReverseDialer { - rpc Open(stream GatewayRequest) returns (stream GatewayResponse); +service ReverseDialerService { + rpc Open(stream ReverseDialerServiceOpenRequest) returns (stream ReverseDialerServiceOpenResponse); } -// GatewayRequest is sent from caller to gateway. -message GatewayRequest { +// ReverseDialerServiceOpenRequest is sent from caller to gateway for the Open RPC. +message ReverseDialerServiceOpenRequest { oneof kind { OpenRequest open_req = 1; // initiate a connection (first message, or concurrent opens) Frame frame = 10; // data/control frames (reuses rtun Frame) } } -// GatewayResponse is sent from gateway to caller. -message GatewayResponse { +// ReverseDialerServiceOpenResponse is sent from gateway to caller for the Open RPC. +message ReverseDialerServiceOpenResponse { oneof kind { OpenResponse open_resp = 1; // handshake result Frame frame = 10; // data/control frames (reuses rtun Frame) diff --git a/proto/c1/connectorapi/rtun/v1/rtun.proto b/proto/c1/connectorapi/rtun/v1/rtun.proto index 922d94b87..459433ecc 100644 --- a/proto/c1/connectorapi/rtun/v1/rtun.proto +++ b/proto/c1/connectorapi/rtun/v1/rtun.proto @@ -4,8 +4,16 @@ package c1.connectorapi.rtun.v1; option go_package = "github.com/conductorone/baton-sdk/pb/c1/connectorapi/rtun/v1"; -service ReverseTunnel { - rpc Link(stream Frame) returns (stream Frame); +service ReverseTunnelService { + rpc Link(stream ReverseTunnelServiceLinkRequest) returns (stream ReverseTunnelServiceLinkResponse); +} + +message ReverseTunnelServiceLinkRequest { + Frame frame = 1; +} + +message ReverseTunnelServiceLinkResponse { + Frame frame = 1; } message Frame { From d16408b47344a4027b74e5d1b4c94afadbc2cbc2 Mon Sep 17 00:00:00 2001 From: Paul Querna Date: Sun, 12 Oct 2025 19:17:04 -0700 Subject: [PATCH 3/9] remove recommended grpc funcs - not needed --- pkg/rtun/gateway/grpc_options.go | 25 ------------------------- pkg/rtun/server/grpc_options.go | 28 ---------------------------- 2 files changed, 53 deletions(-) delete mode 100644 pkg/rtun/gateway/grpc_options.go delete mode 100644 pkg/rtun/server/grpc_options.go diff --git a/pkg/rtun/gateway/grpc_options.go b/pkg/rtun/gateway/grpc_options.go deleted file mode 100644 index 706577875..000000000 --- a/pkg/rtun/gateway/grpc_options.go +++ /dev/null @@ -1,25 +0,0 @@ -package gateway - -import ( - "time" - - "google.golang.org/grpc" - "google.golang.org/grpc/keepalive" -) - -// RecommendedGRPCServerOptions returns server options enabling basic keepalive and -// reasonable message size limits suitable for the gateway service. -func RecommendedGRPCServerOptions() []grpc.ServerOption { - return []grpc.ServerOption{ - grpc.KeepaliveParams(keepalive.ServerParameters{ - Time: 30 * time.Second, - Timeout: 10 * time.Second, - MaxConnectionIdle: 0, - MaxConnectionAge: 0, - }), - grpc.KeepaliveEnforcementPolicy(keepalive.EnforcementPolicy{ - MinTime: 10 * time.Second, - PermitWithoutStream: true, - }), - } -} diff --git a/pkg/rtun/server/grpc_options.go b/pkg/rtun/server/grpc_options.go deleted file mode 100644 index 2dcdf23f5..000000000 --- a/pkg/rtun/server/grpc_options.go +++ /dev/null @@ -1,28 +0,0 @@ -package server - -import ( - "time" - - "google.golang.org/grpc" - "google.golang.org/grpc/keepalive" -) - -// RecommendedGRPCServerOptions returns server options enabling basic keepalive and -// reasonable message size limits suitable for RTUN services. -func RecommendedGRPCServerOptions() []grpc.ServerOption { - return []grpc.ServerOption{ - grpc.MaxRecvMsgSize(4 * 1024 * 1024), - grpc.MaxSendMsgSize(4 * 1024 * 1024), - grpc.MaxConcurrentStreams(250), - grpc.KeepaliveParams(keepalive.ServerParameters{ - Time: 30 * time.Second, - Timeout: 10 * time.Second, - MaxConnectionIdle: 0, - MaxConnectionAge: 0, - }), - grpc.KeepaliveEnforcementPolicy(keepalive.EnforcementPolicy{ - MinTime: 10 * time.Second, - PermitWithoutStream: true, - }), - } -} From c39f9f346a7928f044a294fbcdef253cad8bb5e0 Mon Sep 17 00:00:00 2001 From: Paul Querna Date: Sun, 12 Oct 2025 20:01:21 -0700 Subject: [PATCH 4/9] review feedback --- pkg/rtun/gateway/client.go | 38 ++++++++++++++++------- pkg/rtun/match/directory.go | 5 ++++ pkg/rtun/match/errors.go | 1 + pkg/rtun/match/locator.go | 3 ++ pkg/rtun/match/memory/directory.go | 2 ++ pkg/rtun/match/memory/presence.go | 1 + pkg/rtun/match/route_test.go | 4 +-- pkg/rtun/server/handler.go | 4 ++- pkg/rtun/server/options.go | 8 +++-- pkg/rtun/server/registry.go | 2 ++ pkg/rtun/transport/conn.go | 24 ++++++++------- pkg/rtun/transport/conn_test.go | 3 +- pkg/rtun/transport/errors.go | 4 ++- pkg/rtun/transport/listener.go | 3 ++ pkg/rtun/transport/session.go | 40 +++++++++++++++---------- pkg/rtun/transport/session_race_test.go | 8 ++--- pkg/rtun/transport/session_test.go | 4 +-- 17 files changed, 104 insertions(+), 50 deletions(-) diff --git a/pkg/rtun/gateway/client.go b/pkg/rtun/gateway/client.go index 4dc4f2795..895209942 100644 --- a/pkg/rtun/gateway/client.go +++ b/pkg/rtun/gateway/client.go @@ -97,7 +97,7 @@ func (d *Dialer) DialContext(ctx context.Context, clientID string, port uint32) if err := stream.Send(&rtunpb.ReverseDialerServiceOpenRequest{Kind: &rtunpb.ReverseDialerServiceOpenRequest_OpenReq{ OpenReq: &rtunpb.OpenRequest{Gsid: gsid, ClientId: clientID, Port: port}, }}); err != nil { - stream.CloseSend() + _ = stream.CloseSend() cancel() cc.Close() return nil, fmt.Errorf("gateway send OpenRequest failed: %w", err) @@ -106,7 +106,7 @@ func (d *Dialer) DialContext(ctx context.Context, clientID string, port uint32) // Recv OpenResponse resp, err := stream.Recv() if err != nil { - stream.CloseSend() + _ = stream.CloseSend() cancel() cc.Close() return nil, fmt.Errorf("gateway recv OpenResponse failed: %w", err) @@ -114,13 +114,13 @@ func (d *Dialer) DialContext(ctx context.Context, clientID string, port uint32) openResp := resp.GetOpenResp() if openResp == nil { - stream.CloseSend() + _ = stream.CloseSend() cancel() cc.Close() return nil, ErrProtocol } if openResp.GetGsid() != gsid { - stream.CloseSend() + _ = stream.CloseSend() cancel() cc.Close() return nil, fmt.Errorf("gateway returned mismatched gSID: got %d, want %d", openResp.GetGsid(), gsid) @@ -128,7 +128,7 @@ func (d *Dialer) DialContext(ctx context.Context, clientID string, port uint32) switch openResp.Result.(type) { case *rtunpb.OpenResponse_NotFound: - stream.CloseSend() + _ = stream.CloseSend() cancel() cc.Close() logger.Info("client not found on gateway") @@ -147,7 +147,7 @@ func (d *Dialer) DialContext(ctx context.Context, clientID string, port uint32) gc.w = newWriter(stream, gsid, d.writeQueueCap, doneCh) return gc, nil default: - stream.CloseSend() + _ = stream.CloseSend() cancel() cc.Close() return nil, ErrProtocol @@ -167,6 +167,7 @@ type gatewayConn struct { r *reader w *writer + readMu sync.Mutex writeMu sync.Mutex writeClosed bool @@ -388,14 +389,18 @@ func (g *gatewayConn) Read(p []byte) (int, error) { } g.r.mu.Unlock() - // Compute deadline context + // Compute deadline context (snapshot under lock) + g.readMu.Lock() + deadline := g.rdDeadline + g.readMu.Unlock() + var ctx context.Context var cancel context.CancelFunc - if g.rdDeadline.IsZero() { + if deadline.IsZero() { ctx = context.Background() cancel = func() {} } else { - until := time.Until(g.rdDeadline) + until := time.Until(deadline) if until <= 0 { return 0, context.DeadlineExceeded } @@ -439,7 +444,10 @@ func (g *gatewayConn) Write(p []byte) (int, error) { g.writeMu.Unlock() return 0, fmt.Errorf("rtun/gateway: write on closed connection: %w", net.ErrClosed) } - g.writeMu.Unlock() + // snapshot write deadline while holding the write lock + deadline := g.wrDeadline + // keep writeMu locked for the duration to serialize writes and SetWriteDeadline + defer g.writeMu.Unlock() if err := g.w.getErr(); err != nil { return 0, err @@ -452,7 +460,7 @@ func (g *gatewayConn) Write(p []byte) (int, error) { chunk = p[:maxChunkSize] } cp := append([]byte(nil), chunk...) - if err := g.w.enqueue(writeMsg{payload: cp}, g.wrDeadline); err != nil { + if err := g.w.enqueue(writeMsg{payload: cp}, deadline); err != nil { if total == 0 { return 0, err } @@ -500,18 +508,26 @@ func (g *gatewayConn) LocalAddr() net.Addr { return gatewayAddr{"gateway-local" func (g *gatewayConn) RemoteAddr() net.Addr { return gatewayAddr{"gateway-remote"} } func (g *gatewayConn) SetDeadline(t time.Time) error { + g.readMu.Lock() g.rdDeadline = t + g.readMu.Unlock() + g.writeMu.Lock() g.wrDeadline = t + g.writeMu.Unlock() return nil } func (g *gatewayConn) SetReadDeadline(t time.Time) error { + g.readMu.Lock() g.rdDeadline = t + g.readMu.Unlock() return nil } func (g *gatewayConn) SetWriteDeadline(t time.Time) error { + g.writeMu.Lock() g.wrDeadline = t + g.writeMu.Unlock() return nil } diff --git a/pkg/rtun/match/directory.go b/pkg/rtun/match/directory.go index dc4197583..366ee07b7 100644 --- a/pkg/rtun/match/directory.go +++ b/pkg/rtun/match/directory.go @@ -6,10 +6,15 @@ import ( "time" ) +// ErrNotImplemented is returned by placeholder implementations. var ErrNotImplemented = errors.New("rtun/match: not implemented") +// Directory maps server IDs to dialable addresses, with TTL support. type Directory interface { + // Advertise publishes a server's address with a TTL. Advertise(ctx context.Context, serverID string, addr string, ttl time.Duration) error + // Revoke removes a server's advertisement. Revoke(ctx context.Context, serverID string) error + // Resolve returns the address for a server ID if present and not expired. Resolve(ctx context.Context, serverID string) (addr string, err error) } diff --git a/pkg/rtun/match/errors.go b/pkg/rtun/match/errors.go index d762a9278..6f9c4ae74 100644 --- a/pkg/rtun/match/errors.go +++ b/pkg/rtun/match/errors.go @@ -3,5 +3,6 @@ package match import "errors" var ( + // ErrClientOffline indicates no servers currently advertise the client. ErrClientOffline = errors.New("rtun/match: client offline") ) diff --git a/pkg/rtun/match/locator.go b/pkg/rtun/match/locator.go index 864b7abbf..e417317dc 100644 --- a/pkg/rtun/match/locator.go +++ b/pkg/rtun/match/locator.go @@ -5,10 +5,13 @@ import ( "hash/fnv" ) +// Locator selects the owning server for a client based on presence and rendezvous hashing. type Locator struct { Presence Presence } +// OwnerOf returns the server that currently owns the client's link along with the client's ports. +// It uses Presence to list available servers and rendezvous hashing to choose among them. func (l *Locator) OwnerOf(ctx context.Context, clientID string) (serverID string, ports []uint32, err error) { if l == nil || l.Presence == nil { return "", nil, ErrNotImplemented diff --git a/pkg/rtun/match/memory/directory.go b/pkg/rtun/match/memory/directory.go index a95fe4d8f..9e3c93074 100644 --- a/pkg/rtun/match/memory/directory.go +++ b/pkg/rtun/match/memory/directory.go @@ -7,6 +7,7 @@ import ( "time" ) +// ErrServerNotFound indicates a server ID has no active advertisement. var ErrServerNotFound = errors.New("rtun/match: server not found") // Directory is an in-memory Directory for tests and single-node deployments. @@ -20,6 +21,7 @@ type record struct { expires time.Time } +// NewDirectory returns an in-memory Directory. func NewDirectory() *Directory { return &Directory{ servers: make(map[string]record), diff --git a/pkg/rtun/match/memory/presence.go b/pkg/rtun/match/memory/presence.go index 177e9bed6..899d2dc91 100644 --- a/pkg/rtun/match/memory/presence.go +++ b/pkg/rtun/match/memory/presence.go @@ -15,6 +15,7 @@ type Presence struct { ports map[string][]uint32 // clientID -> ports } +// NewPresence returns an in-memory Presence tracker. func NewPresence() *Presence { return &Presence{ leases: make(map[string]map[string]time.Time), diff --git a/pkg/rtun/match/route_test.go b/pkg/rtun/match/route_test.go index 21264c616..4f82cd9fa 100644 --- a/pkg/rtun/match/route_test.go +++ b/pkg/rtun/match/route_test.go @@ -69,8 +69,8 @@ func TestOwnerRouterTwoServers(t *testing.T) { go func() { _ = cgsA.Serve(lnA) }() // Mark client-123 online on server-a in presence - _ = presence.SetPorts(ctx, "client-123", []uint32{1}) - _ = presence.Announce(ctx, "client-123", "server-a", 10*time.Second) + require.NoError(t, presence.SetPorts(ctx, "client-123", []uint32{1})) + require.NoError(t, presence.Announce(ctx, "client-123", "server-a", 10*time.Second)) // Server B (caller) uses OwnerRouter to find and dial server A router := &OwnerRouter{ diff --git a/pkg/rtun/server/handler.go b/pkg/rtun/server/handler.go index 0cf2da6e2..9f7968a94 100644 --- a/pkg/rtun/server/handler.go +++ b/pkg/rtun/server/handler.go @@ -25,6 +25,8 @@ type Handler struct { m *serverMetrics } +// NewHandler constructs a ReverseTunnel gRPC handler bound to a `Registry` and `TokenValidator`. +// It optionally enables metrics if provided via options. func NewHandler(reg *Registry, serverID string, tv TokenValidator, opts ...Option) rtunpb.ReverseTunnelServiceServer { var o options for _, opt := range opts { @@ -134,7 +136,7 @@ func (h *Handler) Link(stream rtunpb.ReverseTunnelService_LinkServer) error { return stream.Context().Err() } -// grpcLink adapts the gRPC server stream to transport.Link +// grpcLink adapts the gRPC server stream to transport.Link. type grpcLink struct { srv rtunpb.ReverseTunnelService_LinkServer } diff --git a/pkg/rtun/server/options.go b/pkg/rtun/server/options.go index 091cc910e..88538125d 100644 --- a/pkg/rtun/server/options.go +++ b/pkg/rtun/server/options.go @@ -7,11 +7,15 @@ import ( ) var ( + // ErrNotImplemented is returned when a feature isn't implemented. ErrNotImplemented = errors.New("rtun/server: not implemented") - ErrProtocol = errors.New("rtun/server: protocol error") - ErrHelloTimeout = errors.New("rtun/server: hello timeout") + // ErrProtocol indicates a protocol violation on the stream. + ErrProtocol = errors.New("rtun/server: protocol error") + // ErrHelloTimeout indicates the HELLO frame was not received in time. + ErrHelloTimeout = errors.New("rtun/server: hello timeout") ) +// Option configures server components such as the handler and registry. type Option func(*options) type options struct { diff --git a/pkg/rtun/server/registry.go b/pkg/rtun/server/registry.go index 35a16aff8..e1b9b766c 100644 --- a/pkg/rtun/server/registry.go +++ b/pkg/rtun/server/registry.go @@ -11,12 +11,14 @@ import ( "github.com/conductorone/baton-sdk/pkg/rtun/transport" ) +// Registry tracks active client Sessions by client ID and mediates reverse dials. type Registry struct { mu sync.RWMutex sessions map[string]*transport.Session m *serverMetrics } +// NewRegistry creates a new Registry. Metrics can be enabled via options. func NewRegistry(opts ...Option) *Registry { var o options for _, opt := range opts { diff --git a/pkg/rtun/transport/conn.go b/pkg/rtun/transport/conn.go index 7d29c1331..970e30744 100644 --- a/pkg/rtun/transport/conn.go +++ b/pkg/rtun/transport/conn.go @@ -43,10 +43,11 @@ func newVirtConn(m *Session, sid uint32) *virtConn { return &virtConn{ mux: m, sid: sid, - readCh: make(chan []byte, 16), + readCh: make(chan []byte, 256), } } +// Read implements net.Conn for virtConn by returning data delivered for this SID, honoring read deadlines. func (c *virtConn) Read(p []byte) (int, error) { // Check terminal error or remainder under lock c.readMu.Lock() @@ -117,6 +118,7 @@ func (c *virtConn) Read(p []byte) (int, error) { } } +// Write implements net.Conn for virtConn by sending DATA frames for this SID, honoring write deadlines. func (c *virtConn) Write(p []byte) (int, error) { c.writeMu.Lock() defer c.writeMu.Unlock() @@ -151,6 +153,7 @@ func (c *virtConn) Write(p []byte) (int, error) { return total, nil } +// Close implements net.Conn for virtConn by half-closing writes (sending FIN) and releasing resources. func (c *virtConn) Close() error { c.writeMu.Lock() if c.writeClosed { @@ -175,9 +178,13 @@ func (c *virtConn) Close() error { return nil } -func (c *virtConn) LocalAddr() net.Addr { return rtunAddr{"rtun-local"} } +// LocalAddr implements net.Conn. +func (c *virtConn) LocalAddr() net.Addr { return rtunAddr{"rtun-local"} } + +// RemoteAddr implements net.Conn. func (c *virtConn) RemoteAddr() net.Addr { return rtunAddr{"rtun-remote"} } +// SetDeadline implements net.Conn. func (c *virtConn) SetDeadline(t time.Time) error { c.readMu.Lock() c.rdDeadline = t @@ -188,6 +195,7 @@ func (c *virtConn) SetDeadline(t time.Time) error { return nil } +// SetReadDeadline implements net.Conn. func (c *virtConn) SetReadDeadline(t time.Time) error { c.readMu.Lock() c.rdDeadline = t @@ -195,6 +203,7 @@ func (c *virtConn) SetReadDeadline(t time.Time) error { return nil } +// SetWriteDeadline implements net.Conn. func (c *virtConn) SetWriteDeadline(t time.Time) error { c.writeMu.Lock() c.wrDeadline = t @@ -216,15 +225,10 @@ func (c *virtConn) feedData(b []byte) { // delivered c.onActivity() default: - // backpressure: mark error, close channel, send RST, and detach from session to avoid further deliveries - c.readMu.Lock() - if c.readErr == nil { - c.readErr = errors.New("rtun: inbound buffer overflow") - } - c.readMu.Unlock() - c.closeReadOnce.Do(func() { close(c.readCh) }) + // backpressure: send RST, perform full RST handling (including write-side close), and detach from session + err := errors.New("rtun: inbound buffer overflow") _ = c.mux.link.Send(&rtunpb.Frame{Sid: c.sid, Kind: &rtunpb.Frame_Rst{Rst: &rtunpb.Rst{Code: rtunpb.RstCode_RST_CODE_INTERNAL}}}) - c.stopIdleTimer() + c.handleRst(err) c.mux.removeConn(c.sid) } } diff --git a/pkg/rtun/transport/conn_test.go b/pkg/rtun/transport/conn_test.go index 44ea5721f..b9d310447 100644 --- a/pkg/rtun/transport/conn_test.go +++ b/pkg/rtun/transport/conn_test.go @@ -2,6 +2,7 @@ package transport import ( "context" + "errors" "io" "net" "testing" @@ -41,7 +42,7 @@ func TestVirtConnCloseIdempotentAndWriteAfterClose(t *testing.T) { // Read after close yields EOF or ErrClosed _, err = c.Read(make([]byte, 1)) - require.True(t, err == io.EOF || err == net.ErrClosed) + require.True(t, errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed)) } func TestVirtConnRemoteRstPropagatesToRead(t *testing.T) { diff --git a/pkg/rtun/transport/errors.go b/pkg/rtun/transport/errors.go index a77acd39f..c8b2d0757 100644 --- a/pkg/rtun/transport/errors.go +++ b/pkg/rtun/transport/errors.go @@ -3,6 +3,8 @@ package transport import "errors" var ( + // ErrConnReset indicates the remote reset the connection. ErrConnReset = errors.New("rtun: connection reset") - ErrTimeout = errors.New("rtun: deadline exceeded") + // ErrTimeout indicates an operation exceeded its deadline. + ErrTimeout = errors.New("rtun: deadline exceeded") ) diff --git a/pkg/rtun/transport/listener.go b/pkg/rtun/transport/listener.go index a6c3fcfe6..566f5cb65 100644 --- a/pkg/rtun/transport/listener.go +++ b/pkg/rtun/transport/listener.go @@ -14,6 +14,7 @@ type rtunListener struct { err error } +// Accept implements net.Listener and returns the next inbound connection. func (l *rtunListener) Accept() (net.Conn, error) { l.mu.Lock() lerr := l.err @@ -34,6 +35,7 @@ func (l *rtunListener) Accept() (net.Conn, error) { return c, nil } +// Close implements net.Listener and stops accepting new connections. func (l *rtunListener) Close() error { l.mu.Lock() if l.closed { @@ -47,6 +49,7 @@ func (l *rtunListener) Close() error { return nil } +// Addr implements net.Listener. func (l *rtunListener) Addr() net.Addr { return rtunAddr{"rtun-listener"} } func (l *rtunListener) enqueue(c *virtConn) { diff --git a/pkg/rtun/transport/session.go b/pkg/rtun/transport/session.go index e99ae09ec..e2258647d 100644 --- a/pkg/rtun/transport/session.go +++ b/pkg/rtun/transport/session.go @@ -154,10 +154,12 @@ func (s *Session) Listen(ctx context.Context, port uint32, opts ...Option) (net. return l, nil } -func (s *Session) removeConn(sid uint32) { +func (s *Session) removeConn(sid uint32) int { s.mu.Lock() delete(s.conns, sid) + n := len(s.conns) s.mu.Unlock() + return n } func (s *Session) addListener(l *rtunListener) error { @@ -253,9 +255,8 @@ func (s *Session) recvLoop() { } s.conns[sid] = vc vc.startIdleTimer() - if s.m != nil { - s.m.incSidsActive(s.link.Context(), 1) - } + // capture new connection count under lock for gauge update + newCount := len(s.conns) // Drain any pending data queued before SYN if q := s.pending[sid]; len(q) > 0 { for _, p := range q { @@ -264,6 +265,9 @@ func (s *Session) recvLoop() { delete(s.pending, sid) } s.mu.Unlock() + if s.m != nil { + s.m.setSidsActive(s.link.Context(), int64(newCount)) + } l.enqueue(vc) case *rtunpb.Frame_Data: s.mu.Lock() @@ -295,12 +299,12 @@ func (s *Session) recvLoop() { s.mu.Unlock() if c != nil { c.handleFin(k.Fin.GetAck()) - s.removeConn(sid) + n := s.removeConn(sid) s.mu.Lock() s.closed.Close(sid) s.mu.Unlock() if s.m != nil { - s.m.incSidsActive(s.link.Context(), -1) + s.m.setSidsActive(s.link.Context(), int64(n)) } } case *rtunpb.Frame_Rst: @@ -312,12 +316,12 @@ func (s *Session) recvLoop() { s.m.recordRstRecv(s.link.Context(), k.Rst.GetCode().String()) } c.handleRst(ErrConnReset) - s.removeConn(sid) + n := s.removeConn(sid) s.mu.Lock() s.closed.Close(sid) s.mu.Unlock() if s.m != nil { - s.m.incSidsActive(s.link.Context(), -1) + s.m.setSidsActive(s.link.Context(), int64(n)) } } } @@ -362,12 +366,19 @@ func (s *Session) Open(ctx context.Context, port uint32) (net.Conn, error) { // Send SYN to remote if err := s.link.Send(&rtunpb.Frame{Sid: sid, Kind: &rtunpb.Frame_Syn{Syn: &rtunpb.Syn{Port: port}}}); err != nil { // Cleanup on failure - s.removeConn(sid) + n := s.removeConn(sid) + if s.m != nil { + s.m.setSidsActive(s.link.Context(), int64(n)) + } return nil, err } vc.startIdleTimer() if s.m != nil { - s.m.incSidsActive(s.link.Context(), 1) + // set gauge to current number of conns after successful open + s.mu.Lock() + n := len(s.conns) + s.mu.Unlock() + s.m.setSidsActive(s.link.Context(), int64(n)) s.m.recordFrameTx(s.link.Context(), "SYN") } return vc, nil @@ -399,7 +410,7 @@ func (s *Session) failLocked(err error) { s.mu.Unlock() } -// metrics helpers +// metrics helpers. type transportMetrics struct { framesRx sdkmetrics.Int64Counter framesTx sdkmetrics.Int64Counter @@ -408,8 +419,6 @@ type transportMetrics struct { rstSent sdkmetrics.Int64Counter rstRecv sdkmetrics.Int64Counter sidsGauge sdkmetrics.Int64Gauge - - sids int64 } func newTransportMetrics(h sdkmetrics.Handler) *transportMetrics { @@ -427,9 +436,8 @@ func newTransportMetrics(h sdkmetrics.Handler) *transportMetrics { return m } -func (m *transportMetrics) incSidsActive(ctx context.Context, delta int64) { - m.sids += delta - m.sidsGauge.Observe(ctx, m.sids, nil) +func (m *transportMetrics) setSidsActive(ctx context.Context, value int64) { + m.sidsGauge.Observe(ctx, value, nil) } func (m *transportMetrics) recordFrameRx(ctx context.Context, kind string) { diff --git a/pkg/rtun/transport/session_race_test.go b/pkg/rtun/transport/session_race_test.go index 093c752ba..d41ce769b 100644 --- a/pkg/rtun/transport/session_race_test.go +++ b/pkg/rtun/transport/session_race_test.go @@ -46,15 +46,15 @@ func TestConcurrentOperationsNoPanic(t *testing.T) { } // Feed random frames - for i := 0; i < 20; i++ { + for u := uint32(0); u < 20; u++ { wg.Add(1) - go func(i int) { + go func(i uint32) { defer wg.Done() - sid := uint32(100 + i) + sid := 100 + i tl.push(&rtunpb.Frame{Sid: sid, Kind: &rtunpb.Frame_Syn{Syn: &rtunpb.Syn{Port: 1}}}) time.Sleep(2 * time.Millisecond) tl.push(&rtunpb.Frame{Sid: sid, Kind: &rtunpb.Frame_Data{Data: &rtunpb.Data{Payload: []byte("x")}}}) - }(i) + }(u) } wg.Wait() diff --git a/pkg/rtun/transport/session_test.go b/pkg/rtun/transport/session_test.go index 7abe62794..d36ea932f 100644 --- a/pkg/rtun/transport/session_test.go +++ b/pkg/rtun/transport/session_test.go @@ -271,8 +271,8 @@ func TestBackpressureInboundOverflow(t *testing.T) { t.Fatal("accept timeout") } - // Push more frames than the read buffer can hold (capacity 16) without reading. - for i := 0; i < 20; i++ { + // Push more frames than the read buffer can hold (capacity 256) without reading. + for i := 0; i < 300; i++ { tl.push(&rtunpb.Frame{Sid: 23, Kind: &rtunpb.Frame_Data{Data: &rtunpb.Data{Payload: []byte{byte(i)}}}}) } // Give the loop a moment to overflow and close. From 0b2c7dbdf65dd091d9c356a4d32d9acd6bc9232d Mon Sep 17 00:00:00 2001 From: Paul Querna Date: Sun, 12 Oct 2025 20:12:09 -0700 Subject: [PATCH 5/9] obserable guage for sids active --- pkg/metrics/metrics.go | 4 ++ pkg/metrics/noop.go | 3 ++ pkg/metrics/otel.go | 22 +++++++++++ pkg/rtun/transport/session.go | 69 +++++++++++++---------------------- 4 files changed, 54 insertions(+), 44 deletions(-) diff --git a/pkg/metrics/metrics.go b/pkg/metrics/metrics.go index 7af8bcd86..f22f3b561 100644 --- a/pkg/metrics/metrics.go +++ b/pkg/metrics/metrics.go @@ -8,6 +8,10 @@ type Handler interface { Int64Counter(name string, description string, unit Unit) Int64Counter Int64Gauge(name string, description string, unit Unit) Int64Gauge Int64Histogram(name string, description string, unit Unit) Int64Histogram + // RegisterInt64ObservableGauge registers an asynchronous gauge that will be observed + // during collection by invoking the provided callback. The callback should return + // the current value and optional tags at observation time. + RegisterInt64ObservableGauge(name string, description string, unit Unit, callback func(ctx context.Context) (int64, map[string]string)) WithTags(tags map[string]string) Handler } diff --git a/pkg/metrics/noop.go b/pkg/metrics/noop.go index eec13f8ae..7fde9d2af 100644 --- a/pkg/metrics/noop.go +++ b/pkg/metrics/noop.go @@ -28,6 +28,9 @@ func (*noopHandler) Int64Histogram(_ string, _ string, _ Unit) Int64Histogram { return &noopRecorder{} } +func (*noopHandler) RegisterInt64ObservableGauge(_ string, _ string, _ Unit, _ func(ctx context.Context) (int64, map[string]string)) { +} + func (*noopHandler) WithTags(_ map[string]string) Handler { return &noopHandler{} } diff --git a/pkg/metrics/otel.go b/pkg/metrics/otel.go index 9a43d79e7..71335330a 100644 --- a/pkg/metrics/otel.go +++ b/pkg/metrics/otel.go @@ -154,6 +154,28 @@ func (h *otelHandler) Int64Gauge(name string, description string, unit Unit) Int return c } +func (h *otelHandler) RegisterInt64ObservableGauge(name string, description string, unit Unit, callback func(ctx context.Context) (int64, map[string]string)) { + name = strings.ToLower(name) + gauge, err := h.meter.Int64ObservableGauge(name, otelmetric.WithDescription(description), otelmetric.WithUnit(string(unit))) + if err != nil { + panic(err) + } + // capture default attrs pointer for this handler + defaultAttrs := h.defaultAttrs + _, err = h.meter.RegisterCallback(func(ctx context.Context, observer otelmetric.Observer) error { + val, tags := callback(ctx) + attrs := makeAttrs(tags) + if defaultAttrs != nil { + attrs = append(attrs, *defaultAttrs...) + } + observer.ObserveInt64(gauge, val, otelmetric.WithAttributes(attrs...)) + return nil + }, gauge) + if err != nil { + panic(err) + } +} + func (h *otelHandler) WithTags(tags map[string]string) Handler { attrs := makeAttrs(tags) diff --git a/pkg/rtun/transport/session.go b/pkg/rtun/transport/session.go index e2258647d..deb478f48 100644 --- a/pkg/rtun/transport/session.go +++ b/pkg/rtun/transport/session.go @@ -135,7 +135,12 @@ func NewSession(link Link, opts ...Option) *Session { idleTimeout: idle, } if o.metrics != nil { - s.m = newTransportMetrics(o.metrics) + s.m = newTransportMetrics(o.metrics, func(ctx context.Context) (int64, map[string]string) { + s.mu.Lock() + n := len(s.conns) + s.mu.Unlock() + return int64(n), nil + }) } return s } @@ -255,8 +260,7 @@ func (s *Session) recvLoop() { } s.conns[sid] = vc vc.startIdleTimer() - // capture new connection count under lock for gauge update - newCount := len(s.conns) + // connection count is observed via metrics callback // Drain any pending data queued before SYN if q := s.pending[sid]; len(q) > 0 { for _, p := range q { @@ -265,9 +269,6 @@ func (s *Session) recvLoop() { delete(s.pending, sid) } s.mu.Unlock() - if s.m != nil { - s.m.setSidsActive(s.link.Context(), int64(newCount)) - } l.enqueue(vc) case *rtunpb.Frame_Data: s.mu.Lock() @@ -299,13 +300,10 @@ func (s *Session) recvLoop() { s.mu.Unlock() if c != nil { c.handleFin(k.Fin.GetAck()) - n := s.removeConn(sid) + _ = s.removeConn(sid) s.mu.Lock() s.closed.Close(sid) s.mu.Unlock() - if s.m != nil { - s.m.setSidsActive(s.link.Context(), int64(n)) - } } case *rtunpb.Frame_Rst: s.mu.Lock() @@ -316,13 +314,10 @@ func (s *Session) recvLoop() { s.m.recordRstRecv(s.link.Context(), k.Rst.GetCode().String()) } c.handleRst(ErrConnReset) - n := s.removeConn(sid) + _ = s.removeConn(sid) s.mu.Lock() s.closed.Close(sid) s.mu.Unlock() - if s.m != nil { - s.m.setSidsActive(s.link.Context(), int64(n)) - } } } } @@ -366,19 +361,11 @@ func (s *Session) Open(ctx context.Context, port uint32) (net.Conn, error) { // Send SYN to remote if err := s.link.Send(&rtunpb.Frame{Sid: sid, Kind: &rtunpb.Frame_Syn{Syn: &rtunpb.Syn{Port: port}}}); err != nil { // Cleanup on failure - n := s.removeConn(sid) - if s.m != nil { - s.m.setSidsActive(s.link.Context(), int64(n)) - } + _ = s.removeConn(sid) return nil, err } vc.startIdleTimer() if s.m != nil { - // set gauge to current number of conns after successful open - s.mu.Lock() - n := len(s.conns) - s.mu.Unlock() - s.m.setSidsActive(s.link.Context(), int64(n)) s.m.recordFrameTx(s.link.Context(), "SYN") } return vc, nil @@ -412,34 +399,28 @@ func (s *Session) failLocked(err error) { // metrics helpers. type transportMetrics struct { - framesRx sdkmetrics.Int64Counter - framesTx sdkmetrics.Int64Counter - bytesRx sdkmetrics.Int64Counter - bytesTx sdkmetrics.Int64Counter - rstSent sdkmetrics.Int64Counter - rstRecv sdkmetrics.Int64Counter - sidsGauge sdkmetrics.Int64Gauge + framesRx sdkmetrics.Int64Counter + framesTx sdkmetrics.Int64Counter + bytesRx sdkmetrics.Int64Counter + bytesTx sdkmetrics.Int64Counter + rstSent sdkmetrics.Int64Counter + rstRecv sdkmetrics.Int64Counter } -func newTransportMetrics(h sdkmetrics.Handler) *transportMetrics { +func newTransportMetrics(h sdkmetrics.Handler, observeSids func(ctx context.Context) (int64, map[string]string)) *transportMetrics { m := &transportMetrics{ - framesRx: h.Int64Counter("rtun.transport.frames_rx_total", "transport frames received", sdkmetrics.Dimensionless), - framesTx: h.Int64Counter("rtun.transport.frames_tx_total", "transport frames sent", sdkmetrics.Dimensionless), - bytesRx: h.Int64Counter("rtun.transport.data_bytes_rx_total", "transport bytes received", sdkmetrics.Bytes), - bytesTx: h.Int64Counter("rtun.transport.data_bytes_tx_total", "transport bytes sent", sdkmetrics.Bytes), - rstSent: h.Int64Counter("rtun.transport.rst_sent_total", "RST frames sent by code", sdkmetrics.Dimensionless), - rstRecv: h.Int64Counter("rtun.transport.rst_recv_total", "RST frames received by code", sdkmetrics.Dimensionless), - sidsGauge: h.Int64Gauge("rtun.transport.sids_active", "active SIDs per session", sdkmetrics.Dimensionless), + framesRx: h.Int64Counter("rtun.transport.frames_rx_total", "transport frames received", sdkmetrics.Dimensionless), + framesTx: h.Int64Counter("rtun.transport.frames_tx_total", "transport frames sent", sdkmetrics.Dimensionless), + bytesRx: h.Int64Counter("rtun.transport.data_bytes_rx_total", "transport bytes received", sdkmetrics.Bytes), + bytesTx: h.Int64Counter("rtun.transport.data_bytes_tx_total", "transport bytes sent", sdkmetrics.Bytes), + rstSent: h.Int64Counter("rtun.transport.rst_sent_total", "RST frames sent by code", sdkmetrics.Dimensionless), + rstRecv: h.Int64Counter("rtun.transport.rst_recv_total", "RST frames received by code", sdkmetrics.Dimensionless), } - // initialize gauge to 0 - m.sidsGauge.Observe(context.Background(), 0, nil) + // register observable gauge for active SIDs using provided callback + h.RegisterInt64ObservableGauge("rtun.transport.sids_active", "active SIDs per session", sdkmetrics.Dimensionless, observeSids) return m } -func (m *transportMetrics) setSidsActive(ctx context.Context, value int64) { - m.sidsGauge.Observe(ctx, value, nil) -} - func (m *transportMetrics) recordFrameRx(ctx context.Context, kind string) { m.framesRx.Add(ctx, 1, map[string]string{"kind": kind}) } From 400c35f2c9d9ae06f7046cfb2d02759e70320c82 Mon Sep 17 00:00:00 2001 From: Paul Querna Date: Sun, 12 Oct 2025 20:27:45 -0700 Subject: [PATCH 6/9] review feedback --- pkg/rtun/gateway/client.go | 2 +- pkg/rtun/gateway/client_conn_test.go | 4 ++-- pkg/rtun/gateway/integration_test.go | 8 ++++---- pkg/rtun/match/directory.go | 2 +- pkg/rtun/match/locator.go | 6 +++--- pkg/rtun/match/memory/directory.go | 2 +- pkg/rtun/match/route.go | 4 ++-- pkg/rtun/match/route_test.go | 4 ++-- pkg/rtun/server/server_integration_test.go | 6 +++--- 9 files changed, 19 insertions(+), 19 deletions(-) diff --git a/pkg/rtun/gateway/client.go b/pkg/rtun/gateway/client.go index 895209942..08efce73f 100644 --- a/pkg/rtun/gateway/client.go +++ b/pkg/rtun/gateway/client.go @@ -75,7 +75,7 @@ func (d *Dialer) DialContext(ctx context.Context, clientID string, port uint32) logger := ctxzap.Extract(ctx).With(zap.String("client_id", clientID), zap.Uint32("port", port)) // Dial gateway - cc, err := grpc.DialContext(ctx, d.gatewayAddr, + cc, err := grpc.NewClient("passthrough:///"+d.gatewayAddr, grpc.WithTransportCredentials(d.creds), ) if err != nil { diff --git a/pkg/rtun/gateway/client_conn_test.go b/pkg/rtun/gateway/client_conn_test.go index 936d6c6cb..627396dd1 100644 --- a/pkg/rtun/gateway/client_conn_test.go +++ b/pkg/rtun/gateway/client_conn_test.go @@ -35,12 +35,12 @@ func setupGateway(t *testing.T, silent bool) *gwEnv { gsrv := grpc.NewServer() rtunpb.RegisterReverseTunnelServiceServer(gsrv, handler) rtunpb.RegisterReverseDialerServiceServer(gsrv, gw) - l, err := net.Listen("tcp", "127.0.0.1:0") + l, err := (&net.ListenConfig{}).Listen(context.Background(), "tcp", "127.0.0.1:0") require.NoError(t, err) go func() { _ = gsrv.Serve(l) }() // Bring up a client link and listen on port 1 - cc, err := grpc.Dial(l.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials())) + cc, err := grpc.NewClient("passthrough:///"+l.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials())) require.NoError(t, err) rtunClient := rtunpb.NewReverseTunnelServiceClient(cc) stream, err := rtunClient.Link(context.Background()) diff --git a/pkg/rtun/gateway/integration_test.go b/pkg/rtun/gateway/integration_test.go index cc7329172..86e2dc8c4 100644 --- a/pkg/rtun/gateway/integration_test.go +++ b/pkg/rtun/gateway/integration_test.go @@ -59,7 +59,7 @@ func TestGatewayE2E(t *testing.T) { rtunpb.RegisterReverseTunnelServiceServer(gsrvA, handlerA) rtunpb.RegisterReverseDialerServiceServer(gsrvA, gwA) - lA, err := net.Listen("tcp", "127.0.0.1:0") + lA, err := (&net.ListenConfig{}).Listen(ctx, "tcp", "127.0.0.1:0") require.NoError(t, err) go func() { _ = gsrvA.Serve(lA) }() @@ -67,7 +67,7 @@ func TestGatewayE2E(t *testing.T) { clientCtx, clientCancel := context.WithCancel(ctx) defer clientCancel() - ccA, err := grpc.Dial(lA.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials())) + ccA, err := grpc.NewClient("passthrough:///"+lA.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials())) require.NoError(t, err) rtunClient := rtunpb.NewReverseTunnelServiceClient(ccA) @@ -98,7 +98,7 @@ func TestGatewayE2E(t *testing.T) { defer gwConn.Close() // Wrap gateway conn in grpc.Dial and perform health check - callerCC, err := grpc.DialContext(ctx, "ignored", + callerCC, err := grpc.NewClient("passthrough:///ignored", grpc.WithContextDialer(func(context.Context, string) (net.Conn, error) { return gwConn, nil }), grpc.WithTransportCredentials(insecure.NewCredentials()), ) @@ -132,7 +132,7 @@ func TestGatewayNotFound(t *testing.T) { gsrvA := grpc.NewServer() rtunpb.RegisterReverseDialerServiceServer(gsrvA, gwA) - lA, err := net.Listen("tcp", "127.0.0.1:0") + lA, err := (&net.ListenConfig{}).Listen(ctx, "tcp", "127.0.0.1:0") require.NoError(t, err) defer lA.Close() go func() { _ = gsrvA.Serve(lA) }() diff --git a/pkg/rtun/match/directory.go b/pkg/rtun/match/directory.go index 366ee07b7..67fcb8d95 100644 --- a/pkg/rtun/match/directory.go +++ b/pkg/rtun/match/directory.go @@ -16,5 +16,5 @@ type Directory interface { // Revoke removes a server's advertisement. Revoke(ctx context.Context, serverID string) error // Resolve returns the address for a server ID if present and not expired. - Resolve(ctx context.Context, serverID string) (addr string, err error) + Resolve(ctx context.Context, serverID string) (string, error) } diff --git a/pkg/rtun/match/locator.go b/pkg/rtun/match/locator.go index e417317dc..be652b2ea 100644 --- a/pkg/rtun/match/locator.go +++ b/pkg/rtun/match/locator.go @@ -12,7 +12,7 @@ type Locator struct { // OwnerOf returns the server that currently owns the client's link along with the client's ports. // It uses Presence to list available servers and rendezvous hashing to choose among them. -func (l *Locator) OwnerOf(ctx context.Context, clientID string) (serverID string, ports []uint32, err error) { +func (l *Locator) OwnerOf(ctx context.Context, clientID string) (string, []uint32, error) { if l == nil || l.Presence == nil { return "", nil, ErrNotImplemented } @@ -24,11 +24,11 @@ func (l *Locator) OwnerOf(ctx context.Context, clientID string) (serverID string return "", nil, ErrClientOffline } owner := rendezvousChoose(clientID, servers) - ports, err = l.Presence.Ports(ctx, clientID) + p, err := l.Presence.Ports(ctx, clientID) if err != nil { return "", nil, err } - return owner, ports, nil + return owner, p, nil } func rendezvousChoose(clientID string, servers []string) string { diff --git a/pkg/rtun/match/memory/directory.go b/pkg/rtun/match/memory/directory.go index 9e3c93074..bf4c680db 100644 --- a/pkg/rtun/match/memory/directory.go +++ b/pkg/rtun/match/memory/directory.go @@ -42,7 +42,7 @@ func (d *Directory) Revoke(ctx context.Context, serverID string) error { return nil } -func (d *Directory) Resolve(ctx context.Context, serverID string) (addr string, err error) { +func (d *Directory) Resolve(ctx context.Context, serverID string) (string, error) { now := time.Now() d.mu.Lock() defer d.mu.Unlock() diff --git a/pkg/rtun/match/route.go b/pkg/rtun/match/route.go index 6af9ed7e9..621b9170a 100644 --- a/pkg/rtun/match/route.go +++ b/pkg/rtun/match/route.go @@ -32,7 +32,7 @@ func (r *OwnerRouter) DialOwner(ctx context.Context, clientID string) (*grpc.Cli return nil, "", fmt.Errorf("rtun: resolve owner address: %w", err) } opts := r.DialOpts - conn, err := grpc.DialContext(ctx, addr, opts...) + conn, err := grpc.NewClient("passthrough:///"+addr, opts...) if err != nil { return nil, "", fmt.Errorf("rtun: dial owner: %w", err) } @@ -51,7 +51,7 @@ func LocalReverseDial(ctx context.Context, reg *server.Registry, clientID string ), } addr := u.String() - conn, err := grpc.DialContext(ctx, addr, + conn, err := grpc.NewClient("passthrough:///"+addr, grpc.WithContextDialer(reg.DialContext), grpc.WithTransportCredentials(insecure.NewCredentials()), ) diff --git a/pkg/rtun/match/route_test.go b/pkg/rtun/match/route_test.go index 4f82cd9fa..b1237ae79 100644 --- a/pkg/rtun/match/route_test.go +++ b/pkg/rtun/match/route_test.go @@ -37,7 +37,7 @@ func TestOwnerRouterTwoServers(t *testing.T) { handlerA := server.NewHandler(regA, "server-a", testValidator{id: "client-123"}) gsrvA := grpc.NewServer() rtunpb.RegisterReverseTunnelServiceServer(gsrvA, handlerA) - lA, err := net.Listen("tcp", "127.0.0.1:0") + lA, err := (&net.ListenConfig{}).Listen(ctx, "tcp", "127.0.0.1:0") require.NoError(t, err) go func() { _ = gsrvA.Serve(lA) }() @@ -48,7 +48,7 @@ func TestOwnerRouterTwoServers(t *testing.T) { clientCtx, clientCancel := context.WithCancel(ctx) defer clientCancel() - ccA, err := grpc.Dial(lA.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials())) + ccA, err := grpc.NewClient("passthrough:///"+lA.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials())) require.NoError(t, err) rtunClientA := rtunpb.NewReverseTunnelServiceClient(ccA) streamA, err := rtunClientA.Link(clientCtx) diff --git a/pkg/rtun/server/server_integration_test.go b/pkg/rtun/server/server_integration_test.go index c91485287..19de88649 100644 --- a/pkg/rtun/server/server_integration_test.go +++ b/pkg/rtun/server/server_integration_test.go @@ -49,14 +49,14 @@ func TestReverseGrpcE2E(t *testing.T) { h := NewHandler(reg, "server-1", testValidator{id: "client-123"}) gsrv := grpc.NewServer() rtunpb.RegisterReverseTunnelServiceServer(gsrv, h) - l, err := net.Listen("tcp", "127.0.0.1:0") + l, err := (&net.ListenConfig{}).Listen(context.Background(), "tcp", "127.0.0.1:0") require.NoError(t, err) defer l.Close() go func() { _ = gsrv.Serve(l) }() defer gsrv.GracefulStop() // Client side: dial server and open Link stream - cc, err := grpc.Dial(l.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials())) + cc, err := grpc.NewClient("passthrough:///"+l.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials())) require.NoError(t, err) defer cc.Close() rtunClient := rtunpb.NewReverseTunnelServiceClient(cc) @@ -87,7 +87,7 @@ func TestReverseGrpcE2E(t *testing.T) { require.NoError(t, err) defer rconn.Close() - ownerCC, err := grpc.DialContext(context.Background(), "ignored", + ownerCC, err := grpc.NewClient("passthrough:///ignored", grpc.WithContextDialer(func(ctx context.Context, _ string) (net.Conn, error) { return rconn, nil }), grpc.WithTransportCredentials(insecure.NewCredentials()), ) From 8745feca8e57a02f0c927e94705f90100adf5c41 Mon Sep 17 00:00:00 2001 From: Paul Querna Date: Sun, 12 Oct 2025 20:34:54 -0700 Subject: [PATCH 7/9] fix lint --- pkg/rtun/server/handler.go | 3 --- pkg/rtun/transport/session.go | 2 -- 2 files changed, 5 deletions(-) diff --git a/pkg/rtun/server/handler.go b/pkg/rtun/server/handler.go index 9f7968a94..2e8c5adb7 100644 --- a/pkg/rtun/server/handler.go +++ b/pkg/rtun/server/handler.go @@ -2,7 +2,6 @@ package server import ( "context" - "sync" "time" rtunpb "github.com/conductorone/baton-sdk/pb/c1/connectorapi/rtun/v1" @@ -19,8 +18,6 @@ type Handler struct { serverID string tv TokenValidator - mu sync.Mutex - // metrics (optional) m *serverMetrics } diff --git a/pkg/rtun/transport/session.go b/pkg/rtun/transport/session.go index deb478f48..7d54bf81c 100644 --- a/pkg/rtun/transport/session.go +++ b/pkg/rtun/transport/session.go @@ -200,8 +200,6 @@ func (s *Session) startOnce() { go s.recvLoop() } -const maxPendingBufferSize = 64 * 1024 - func (s *Session) recvLoop() { for { fr, err := s.link.Recv() From 52112bf218c053979f0aefa174cfa8876097046f Mon Sep 17 00:00:00 2001 From: Paul Querna Date: Sun, 12 Oct 2025 20:47:43 -0700 Subject: [PATCH 8/9] cleanup/lint more --- pkg/metrics/otel.go | 39 ++++++++++++++++++++++------ pkg/rtun/gateway/client.go | 31 +++++++++++++++++----- pkg/rtun/gateway/client_conn_test.go | 3 ++- pkg/rtun/transport/session.go | 11 +------- 4 files changed, 59 insertions(+), 25 deletions(-) diff --git a/pkg/metrics/otel.go b/pkg/metrics/otel.go index 71335330a..3b51d31fb 100644 --- a/pkg/metrics/otel.go +++ b/pkg/metrics/otel.go @@ -22,6 +22,11 @@ type otelHandler struct { provider otelmetric.MeterProvider defaultAttrs *[]attribute.KeyValue + // protects access to defaultAttrs and registeredGauges + defaultAttrsMtx sync.RWMutex + registeredGaugesMtx sync.Mutex + registeredGauges map[string]struct{} + int64CountersMtx sync.Mutex int64Counters map[string]*otelInt64Counter int64HistosMtx sync.Mutex @@ -156,15 +161,30 @@ func (h *otelHandler) Int64Gauge(name string, description string, unit Unit) Int func (h *otelHandler) RegisterInt64ObservableGauge(name string, description string, unit Unit, callback func(ctx context.Context) (int64, map[string]string)) { name = strings.ToLower(name) + + // prevent duplicate registrations for the same gauge name + h.registeredGaugesMtx.Lock() + if h.registeredGauges == nil { + h.registeredGauges = make(map[string]struct{}) + } + if _, exists := h.registeredGauges[name]; exists { + h.registeredGaugesMtx.Unlock() + return + } + h.registeredGauges[name] = struct{}{} + h.registeredGaugesMtx.Unlock() + gauge, err := h.meter.Int64ObservableGauge(name, otelmetric.WithDescription(description), otelmetric.WithUnit(string(unit))) if err != nil { panic(err) } - // capture default attrs pointer for this handler - defaultAttrs := h.defaultAttrs + // capture default attrs pointer for this handler under read lock _, err = h.meter.RegisterCallback(func(ctx context.Context, observer otelmetric.Observer) error { val, tags := callback(ctx) attrs := makeAttrs(tags) + h.defaultAttrsMtx.RLock() + defaultAttrs := h.defaultAttrs + h.defaultAttrsMtx.RUnlock() if defaultAttrs != nil { attrs = append(attrs, *defaultAttrs...) } @@ -179,7 +199,9 @@ func (h *otelHandler) RegisterInt64ObservableGauge(name string, description stri func (h *otelHandler) WithTags(tags map[string]string) Handler { attrs := makeAttrs(tags) + h.defaultAttrsMtx.Lock() h.defaultAttrs = &attrs + h.defaultAttrsMtx.Unlock() return h } @@ -195,11 +217,12 @@ func makeAttrs(tags map[string]string) []attribute.KeyValue { func NewOtelHandler(_ context.Context, provider otelmetric.MeterProvider, name string) Handler { return &otelHandler{ - name: name, - meter: provider.Meter(name), - provider: provider, - int64Counters: make(map[string]*otelInt64Counter), - int64Histos: make(map[string]*otelInt64Histogram), - int64Gauges: make(map[string]*otelInt64Gauge), + name: name, + meter: provider.Meter(name), + provider: provider, + int64Counters: make(map[string]*otelInt64Counter), + int64Histos: make(map[string]*otelInt64Histogram), + int64Gauges: make(map[string]*otelInt64Gauge), + registeredGauges: make(map[string]struct{}), } } diff --git a/pkg/rtun/gateway/client.go b/pkg/rtun/gateway/client.go index 08efce73f..06176c07b 100644 --- a/pkg/rtun/gateway/client.go +++ b/pkg/rtun/gateway/client.go @@ -324,14 +324,33 @@ func (w *writer) loop() { select { case msg := <-w.ch: if msg.fin { - _ = w.stream.Send(&rtunpb.ReverseDialerServiceOpenRequest{Kind: &rtunpb.ReverseDialerServiceOpenRequest_Frame{ - Frame: &rtunpb.Frame{Sid: w.gsid, Kind: &rtunpb.Frame_Fin{Fin: &rtunpb.Fin{}}}, - }}) + finFrame := &rtunpb.Frame{ + Sid: w.gsid, + Kind: &rtunpb.Frame_Fin{Fin: &rtunpb.Fin{}}, + } + _ = w.stream.Send( + &rtunpb.ReverseDialerServiceOpenRequest{ + Kind: &rtunpb.ReverseDialerServiceOpenRequest_Frame{ + Frame: finFrame, + }, + }, + ) continue } - if err := w.stream.Send(&rtunpb.ReverseDialerServiceOpenRequest{Kind: &rtunpb.ReverseDialerServiceOpenRequest_Frame{ - Frame: &rtunpb.Frame{Sid: w.gsid, Kind: &rtunpb.Frame_Data{Data: &rtunpb.Data{Payload: msg.payload}}}, - }}); err != nil { + + dataFrame := &rtunpb.Frame{ + Sid: w.gsid, + Kind: &rtunpb.Frame_Data{ + Data: &rtunpb.Data{Payload: msg.payload}, + }, + } + if err := w.stream.Send( + &rtunpb.ReverseDialerServiceOpenRequest{ + Kind: &rtunpb.ReverseDialerServiceOpenRequest_Frame{ + Frame: dataFrame, + }, + }, + ); err != nil { w.setErr(err) return } diff --git a/pkg/rtun/gateway/client_conn_test.go b/pkg/rtun/gateway/client_conn_test.go index 627396dd1..d8eb619d1 100644 --- a/pkg/rtun/gateway/client_conn_test.go +++ b/pkg/rtun/gateway/client_conn_test.go @@ -158,8 +158,9 @@ func TestGatewayConnWriteAndRemoteReceive(t *testing.T) { require.NoError(t, err) require.Equal(t, len(msg), n) buf := make([]byte, len(msg)) - _, err = rc.Read(buf) + rn, err := io.ReadFull(rc, buf) require.NoError(t, err) + require.Equal(t, len(msg), rn) require.Equal(t, msg, buf) // Local FIN: close gwc; remote should see EOF diff --git a/pkg/rtun/transport/session.go b/pkg/rtun/transport/session.go index 7d54bf81c..ba50c03ef 100644 --- a/pkg/rtun/transport/session.go +++ b/pkg/rtun/transport/session.go @@ -92,8 +92,7 @@ type Session struct { conns map[uint32]*virtConn listeners map[uint32]*rtunListener nextSID uint32 - pending map[uint32][][]byte // queued DATA before SYN processed - closed closedSet // closed SIDs for late-frame detection + closed closedSet // closed SIDs for late-frame detection // configuration allowedPorts map[uint32]bool @@ -128,7 +127,6 @@ func NewSession(link Link, opts ...Option) *Session { conns: make(map[uint32]*virtConn), listeners: make(map[uint32]*rtunListener), nextSID: 1, - pending: make(map[uint32][][]byte), logger: logger, allowedPorts: o.allowedPorts, maxPendingSIDs: maxPending, @@ -259,13 +257,6 @@ func (s *Session) recvLoop() { s.conns[sid] = vc vc.startIdleTimer() // connection count is observed via metrics callback - // Drain any pending data queued before SYN - if q := s.pending[sid]; len(q) > 0 { - for _, p := range q { - vc.feedData(p) - } - delete(s.pending, sid) - } s.mu.Unlock() l.enqueue(vc) case *rtunpb.Frame_Data: From fdbbaa7b6fac545f1538f2e4e936667d71c9af2d Mon Sep 17 00:00:00 2001 From: Paul Querna Date: Sun, 12 Oct 2025 20:54:54 -0700 Subject: [PATCH 9/9] remove legacy maxPending concept --- pkg/rtun/transport/session.go | 35 +++++++++-------------------------- 1 file changed, 9 insertions(+), 26 deletions(-) diff --git a/pkg/rtun/transport/session.go b/pkg/rtun/transport/session.go index ba50c03ef..fad824d9e 100644 --- a/pkg/rtun/transport/session.go +++ b/pkg/rtun/transport/session.go @@ -20,8 +20,6 @@ type options struct { // allowedPorts is an allowlist of ports that the server is permitted to Open() toward the client. // If nil or empty, all ports are allowed. allowedPorts map[uint32]bool - // maxPendingSIDs caps how many distinct SIDs may accumulate DATA-before-SYN buffers. - maxPendingSIDs int // idleTimeout controls per-SID idle expiration; zero means use default (10m). Negative disables. idleTimeout time.Duration // metrics handler (optional) @@ -51,14 +49,6 @@ func WithAllowedPorts(ports []uint32) Option { } } -// WithMaxPendingSIDs sets the maximum number of distinct SIDs allowed to accumulate -// DATA-before-SYN pending buffers. Values <= 0 select the default (64). -func WithMaxPendingSIDs(n int) Option { - return func(o *options) { - o.maxPendingSIDs = n - } -} - // WithIdleTimeout sets the per-SID idle timeout. Zero selects the default (10m). Negative disables. func WithIdleTimeout(d time.Duration) Option { return func(o *options) { @@ -95,9 +85,8 @@ type Session struct { closed closedSet // closed SIDs for late-frame detection // configuration - allowedPorts map[uint32]bool - maxPendingSIDs int - idleTimeout time.Duration + allowedPorts map[uint32]bool + idleTimeout time.Duration // metrics (optional) m *transportMetrics @@ -113,24 +102,18 @@ func NewSession(link Link, opts ...Option) *Session { if logger == nil { logger = zap.NewNop() } - // defaults - maxPending := o.maxPendingSIDs - if maxPending <= 0 { - maxPending = 64 - } idle := o.idleTimeout if idle == 0 { idle = 10 * time.Minute } s := &Session{ - link: link, - conns: make(map[uint32]*virtConn), - listeners: make(map[uint32]*rtunListener), - nextSID: 1, - logger: logger, - allowedPorts: o.allowedPorts, - maxPendingSIDs: maxPending, - idleTimeout: idle, + link: link, + conns: make(map[uint32]*virtConn), + listeners: make(map[uint32]*rtunListener), + nextSID: 1, + logger: logger, + allowedPorts: o.allowedPorts, + idleTimeout: idle, } if o.metrics != nil { s.m = newTransportMetrics(o.metrics, func(ctx context.Context) (int64, map[string]string) {