diff --git a/pb/c1/connectorapi/rtun/v1/gateway.pb.go b/pb/c1/connectorapi/rtun/v1/gateway.pb.go new file mode 100644 index 000000000..fb508fc42 --- /dev/null +++ b/pb/c1/connectorapi/rtun/v1/gateway.pb.go @@ -0,0 +1,552 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.36.4 +// protoc (unknown) +// source: c1/connectorapi/rtun/v1/gateway.proto + +package v1 + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" + unsafe "unsafe" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +// ReverseDialerServiceOpenRequest is sent from caller to gateway for the Open RPC. +type ReverseDialerServiceOpenRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + // Types that are valid to be assigned to Kind: + // + // *ReverseDialerServiceOpenRequest_OpenReq + // *ReverseDialerServiceOpenRequest_Frame + Kind isReverseDialerServiceOpenRequest_Kind `protobuf_oneof:"kind"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ReverseDialerServiceOpenRequest) Reset() { + *x = ReverseDialerServiceOpenRequest{} + mi := &file_c1_connectorapi_rtun_v1_gateway_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ReverseDialerServiceOpenRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ReverseDialerServiceOpenRequest) ProtoMessage() {} + +func (x *ReverseDialerServiceOpenRequest) ProtoReflect() protoreflect.Message { + mi := &file_c1_connectorapi_rtun_v1_gateway_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ReverseDialerServiceOpenRequest.ProtoReflect.Descriptor instead. +func (*ReverseDialerServiceOpenRequest) Descriptor() ([]byte, []int) { + return file_c1_connectorapi_rtun_v1_gateway_proto_rawDescGZIP(), []int{0} +} + +func (x *ReverseDialerServiceOpenRequest) GetKind() isReverseDialerServiceOpenRequest_Kind { + if x != nil { + return x.Kind + } + return nil +} + +func (x *ReverseDialerServiceOpenRequest) GetOpenReq() *OpenRequest { + if x != nil { + if x, ok := x.Kind.(*ReverseDialerServiceOpenRequest_OpenReq); ok { + return x.OpenReq + } + } + return nil +} + +func (x *ReverseDialerServiceOpenRequest) GetFrame() *Frame { + if x != nil { + if x, ok := x.Kind.(*ReverseDialerServiceOpenRequest_Frame); ok { + return x.Frame + } + } + return nil +} + +type isReverseDialerServiceOpenRequest_Kind interface { + isReverseDialerServiceOpenRequest_Kind() +} + +type ReverseDialerServiceOpenRequest_OpenReq struct { + OpenReq *OpenRequest `protobuf:"bytes,1,opt,name=open_req,json=openReq,proto3,oneof"` // initiate a connection (first message, or concurrent opens) +} + +type ReverseDialerServiceOpenRequest_Frame struct { + Frame *Frame `protobuf:"bytes,10,opt,name=frame,proto3,oneof"` // data/control frames (reuses rtun Frame) +} + +func (*ReverseDialerServiceOpenRequest_OpenReq) isReverseDialerServiceOpenRequest_Kind() {} + +func (*ReverseDialerServiceOpenRequest_Frame) isReverseDialerServiceOpenRequest_Kind() {} + +// ReverseDialerServiceOpenResponse is sent from gateway to caller for the Open RPC. +type ReverseDialerServiceOpenResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + // Types that are valid to be assigned to Kind: + // + // *ReverseDialerServiceOpenResponse_OpenResp + // *ReverseDialerServiceOpenResponse_Frame + Kind isReverseDialerServiceOpenResponse_Kind `protobuf_oneof:"kind"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ReverseDialerServiceOpenResponse) Reset() { + *x = ReverseDialerServiceOpenResponse{} + mi := &file_c1_connectorapi_rtun_v1_gateway_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ReverseDialerServiceOpenResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ReverseDialerServiceOpenResponse) ProtoMessage() {} + +func (x *ReverseDialerServiceOpenResponse) ProtoReflect() protoreflect.Message { + mi := &file_c1_connectorapi_rtun_v1_gateway_proto_msgTypes[1] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ReverseDialerServiceOpenResponse.ProtoReflect.Descriptor instead. +func (*ReverseDialerServiceOpenResponse) Descriptor() ([]byte, []int) { + return file_c1_connectorapi_rtun_v1_gateway_proto_rawDescGZIP(), []int{1} +} + +func (x *ReverseDialerServiceOpenResponse) GetKind() isReverseDialerServiceOpenResponse_Kind { + if x != nil { + return x.Kind + } + return nil +} + +func (x *ReverseDialerServiceOpenResponse) GetOpenResp() *OpenResponse { + if x != nil { + if x, ok := x.Kind.(*ReverseDialerServiceOpenResponse_OpenResp); ok { + return x.OpenResp + } + } + return nil +} + +func (x *ReverseDialerServiceOpenResponse) GetFrame() *Frame { + if x != nil { + if x, ok := x.Kind.(*ReverseDialerServiceOpenResponse_Frame); ok { + return x.Frame + } + } + return nil +} + +type isReverseDialerServiceOpenResponse_Kind interface { + isReverseDialerServiceOpenResponse_Kind() +} + +type ReverseDialerServiceOpenResponse_OpenResp struct { + OpenResp *OpenResponse `protobuf:"bytes,1,opt,name=open_resp,json=openResp,proto3,oneof"` // handshake result +} + +type ReverseDialerServiceOpenResponse_Frame struct { + Frame *Frame `protobuf:"bytes,10,opt,name=frame,proto3,oneof"` // data/control frames (reuses rtun Frame) +} + +func (*ReverseDialerServiceOpenResponse_OpenResp) isReverseDialerServiceOpenResponse_Kind() {} + +func (*ReverseDialerServiceOpenResponse_Frame) isReverseDialerServiceOpenResponse_Kind() {} + +// OpenRequest initiates a reverse connection to a client. +// The caller proposes a gSID (gateway SID) for this connection to support concurrent opens. +type OpenRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + Gsid uint32 `protobuf:"varint,1,opt,name=gsid,proto3" json:"gsid,omitempty"` // caller-proposed SID for this connection (must be unique per stream) + ClientId string `protobuf:"bytes,2,opt,name=client_id,json=clientId,proto3" json:"client_id,omitempty"` // target client (must be URL-safe) + Port uint32 `protobuf:"varint,3,opt,name=port,proto3" json:"port,omitempty"` // target port on the client + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *OpenRequest) Reset() { + *x = OpenRequest{} + mi := &file_c1_connectorapi_rtun_v1_gateway_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *OpenRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*OpenRequest) ProtoMessage() {} + +func (x *OpenRequest) ProtoReflect() protoreflect.Message { + mi := &file_c1_connectorapi_rtun_v1_gateway_proto_msgTypes[2] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use OpenRequest.ProtoReflect.Descriptor instead. +func (*OpenRequest) Descriptor() ([]byte, []int) { + return file_c1_connectorapi_rtun_v1_gateway_proto_rawDescGZIP(), []int{2} +} + +func (x *OpenRequest) GetGsid() uint32 { + if x != nil { + return x.Gsid + } + return 0 +} + +func (x *OpenRequest) GetClientId() string { + if x != nil { + return x.ClientId + } + return "" +} + +func (x *OpenRequest) GetPort() uint32 { + if x != nil { + return x.Port + } + return 0 +} + +// OpenResponse indicates the result of an OpenRequest. +type OpenResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + Gsid uint32 `protobuf:"varint,1,opt,name=gsid,proto3" json:"gsid,omitempty"` // echoed from OpenRequest + // Types that are valid to be assigned to Result: + // + // *OpenResponse_NotFound + // *OpenResponse_Opened + Result isOpenResponse_Result `protobuf_oneof:"result"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *OpenResponse) Reset() { + *x = OpenResponse{} + mi := &file_c1_connectorapi_rtun_v1_gateway_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *OpenResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*OpenResponse) ProtoMessage() {} + +func (x *OpenResponse) ProtoReflect() protoreflect.Message { + mi := &file_c1_connectorapi_rtun_v1_gateway_proto_msgTypes[3] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use OpenResponse.ProtoReflect.Descriptor instead. +func (*OpenResponse) Descriptor() ([]byte, []int) { + return file_c1_connectorapi_rtun_v1_gateway_proto_rawDescGZIP(), []int{3} +} + +func (x *OpenResponse) GetGsid() uint32 { + if x != nil { + return x.Gsid + } + return 0 +} + +func (x *OpenResponse) GetResult() isOpenResponse_Result { + if x != nil { + return x.Result + } + return nil +} + +func (x *OpenResponse) GetNotFound() *NotFound { + if x != nil { + if x, ok := x.Result.(*OpenResponse_NotFound); ok { + return x.NotFound + } + } + return nil +} + +func (x *OpenResponse) GetOpened() *Opened { + if x != nil { + if x, ok := x.Result.(*OpenResponse_Opened); ok { + return x.Opened + } + } + return nil +} + +type isOpenResponse_Result interface { + isOpenResponse_Result() +} + +type OpenResponse_NotFound struct { + NotFound *NotFound `protobuf:"bytes,2,opt,name=not_found,json=notFound,proto3,oneof"` // gateway doesn't own this client; caller should re-resolve +} + +type OpenResponse_Opened struct { + Opened *Opened `protobuf:"bytes,3,opt,name=opened,proto3,oneof"` // success; use the gSID for subsequent frames +} + +func (*OpenResponse_NotFound) isOpenResponse_Result() {} + +func (*OpenResponse_Opened) isOpenResponse_Result() {} + +type NotFound struct { + state protoimpl.MessageState `protogen:"open.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *NotFound) Reset() { + *x = NotFound{} + mi := &file_c1_connectorapi_rtun_v1_gateway_proto_msgTypes[4] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *NotFound) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*NotFound) ProtoMessage() {} + +func (x *NotFound) ProtoReflect() protoreflect.Message { + mi := &file_c1_connectorapi_rtun_v1_gateway_proto_msgTypes[4] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use NotFound.ProtoReflect.Descriptor instead. +func (*NotFound) Descriptor() ([]byte, []int) { + return file_c1_connectorapi_rtun_v1_gateway_proto_rawDescGZIP(), []int{4} +} + +type Opened struct { + state protoimpl.MessageState `protogen:"open.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *Opened) Reset() { + *x = Opened{} + mi := &file_c1_connectorapi_rtun_v1_gateway_proto_msgTypes[5] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *Opened) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Opened) ProtoMessage() {} + +func (x *Opened) ProtoReflect() protoreflect.Message { + mi := &file_c1_connectorapi_rtun_v1_gateway_proto_msgTypes[5] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Opened.ProtoReflect.Descriptor instead. +func (*Opened) Descriptor() ([]byte, []int) { + return file_c1_connectorapi_rtun_v1_gateway_proto_rawDescGZIP(), []int{5} +} + +var File_c1_connectorapi_rtun_v1_gateway_proto protoreflect.FileDescriptor + +var file_c1_connectorapi_rtun_v1_gateway_proto_rawDesc = string([]byte{ + 0x0a, 0x25, 0x63, 0x31, 0x2f, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x61, 0x70, + 0x69, 0x2f, 0x72, 0x74, 0x75, 0x6e, 0x2f, 0x76, 0x31, 0x2f, 0x67, 0x61, 0x74, 0x65, 0x77, 0x61, + 0x79, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x17, 0x63, 0x31, 0x2e, 0x63, 0x6f, 0x6e, 0x6e, + 0x65, 0x63, 0x74, 0x6f, 0x72, 0x61, 0x70, 0x69, 0x2e, 0x72, 0x74, 0x75, 0x6e, 0x2e, 0x76, 0x31, + 0x1a, 0x22, 0x63, 0x31, 0x2f, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x61, 0x70, + 0x69, 0x2f, 0x72, 0x74, 0x75, 0x6e, 0x2f, 0x76, 0x31, 0x2f, 0x72, 0x74, 0x75, 0x6e, 0x2e, 0x70, + 0x72, 0x6f, 0x74, 0x6f, 0x22, 0xa4, 0x01, 0x0a, 0x1f, 0x52, 0x65, 0x76, 0x65, 0x72, 0x73, 0x65, + 0x44, 0x69, 0x61, 0x6c, 0x65, 0x72, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x4f, 0x70, 0x65, + 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x41, 0x0a, 0x08, 0x6f, 0x70, 0x65, 0x6e, + 0x5f, 0x72, 0x65, 0x71, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x24, 0x2e, 0x63, 0x31, 0x2e, + 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x61, 0x70, 0x69, 0x2e, 0x72, 0x74, 0x75, + 0x6e, 0x2e, 0x76, 0x31, 0x2e, 0x4f, 0x70, 0x65, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, + 0x48, 0x00, 0x52, 0x07, 0x6f, 0x70, 0x65, 0x6e, 0x52, 0x65, 0x71, 0x12, 0x36, 0x0a, 0x05, 0x66, + 0x72, 0x61, 0x6d, 0x65, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1e, 0x2e, 0x63, 0x31, 0x2e, + 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x61, 0x70, 0x69, 0x2e, 0x72, 0x74, 0x75, + 0x6e, 0x2e, 0x76, 0x31, 0x2e, 0x46, 0x72, 0x61, 0x6d, 0x65, 0x48, 0x00, 0x52, 0x05, 0x66, 0x72, + 0x61, 0x6d, 0x65, 0x42, 0x06, 0x0a, 0x04, 0x6b, 0x69, 0x6e, 0x64, 0x22, 0xa8, 0x01, 0x0a, 0x20, + 0x52, 0x65, 0x76, 0x65, 0x72, 0x73, 0x65, 0x44, 0x69, 0x61, 0x6c, 0x65, 0x72, 0x53, 0x65, 0x72, + 0x76, 0x69, 0x63, 0x65, 0x4f, 0x70, 0x65, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, + 0x12, 0x44, 0x0a, 0x09, 0x6f, 0x70, 0x65, 0x6e, 0x5f, 0x72, 0x65, 0x73, 0x70, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x0b, 0x32, 0x25, 0x2e, 0x63, 0x31, 0x2e, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, + 0x6f, 0x72, 0x61, 0x70, 0x69, 0x2e, 0x72, 0x74, 0x75, 0x6e, 0x2e, 0x76, 0x31, 0x2e, 0x4f, 0x70, + 0x65, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x48, 0x00, 0x52, 0x08, 0x6f, 0x70, + 0x65, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x12, 0x36, 0x0a, 0x05, 0x66, 0x72, 0x61, 0x6d, 0x65, 0x18, + 0x0a, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1e, 0x2e, 0x63, 0x31, 0x2e, 0x63, 0x6f, 0x6e, 0x6e, 0x65, + 0x63, 0x74, 0x6f, 0x72, 0x61, 0x70, 0x69, 0x2e, 0x72, 0x74, 0x75, 0x6e, 0x2e, 0x76, 0x31, 0x2e, + 0x46, 0x72, 0x61, 0x6d, 0x65, 0x48, 0x00, 0x52, 0x05, 0x66, 0x72, 0x61, 0x6d, 0x65, 0x42, 0x06, + 0x0a, 0x04, 0x6b, 0x69, 0x6e, 0x64, 0x22, 0x52, 0x0a, 0x0b, 0x4f, 0x70, 0x65, 0x6e, 0x52, 0x65, + 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x12, 0x0a, 0x04, 0x67, 0x73, 0x69, 0x64, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x0d, 0x52, 0x04, 0x67, 0x73, 0x69, 0x64, 0x12, 0x1b, 0x0a, 0x09, 0x63, 0x6c, 0x69, + 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x63, 0x6c, + 0x69, 0x65, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x03, + 0x20, 0x01, 0x28, 0x0d, 0x52, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x22, 0xa9, 0x01, 0x0a, 0x0c, 0x4f, + 0x70, 0x65, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x67, + 0x73, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x04, 0x67, 0x73, 0x69, 0x64, 0x12, + 0x40, 0x0a, 0x09, 0x6e, 0x6f, 0x74, 0x5f, 0x66, 0x6f, 0x75, 0x6e, 0x64, 0x18, 0x02, 0x20, 0x01, + 0x28, 0x0b, 0x32, 0x21, 0x2e, 0x63, 0x31, 0x2e, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x6f, + 0x72, 0x61, 0x70, 0x69, 0x2e, 0x72, 0x74, 0x75, 0x6e, 0x2e, 0x76, 0x31, 0x2e, 0x4e, 0x6f, 0x74, + 0x46, 0x6f, 0x75, 0x6e, 0x64, 0x48, 0x00, 0x52, 0x08, 0x6e, 0x6f, 0x74, 0x46, 0x6f, 0x75, 0x6e, + 0x64, 0x12, 0x39, 0x0a, 0x06, 0x6f, 0x70, 0x65, 0x6e, 0x65, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, + 0x0b, 0x32, 0x1f, 0x2e, 0x63, 0x31, 0x2e, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x6f, 0x72, + 0x61, 0x70, 0x69, 0x2e, 0x72, 0x74, 0x75, 0x6e, 0x2e, 0x76, 0x31, 0x2e, 0x4f, 0x70, 0x65, 0x6e, + 0x65, 0x64, 0x48, 0x00, 0x52, 0x06, 0x6f, 0x70, 0x65, 0x6e, 0x65, 0x64, 0x42, 0x08, 0x0a, 0x06, + 0x72, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x22, 0x0a, 0x0a, 0x08, 0x4e, 0x6f, 0x74, 0x46, 0x6f, 0x75, + 0x6e, 0x64, 0x22, 0x08, 0x0a, 0x06, 0x4f, 0x70, 0x65, 0x6e, 0x65, 0x64, 0x32, 0x97, 0x01, 0x0a, + 0x14, 0x52, 0x65, 0x76, 0x65, 0x72, 0x73, 0x65, 0x44, 0x69, 0x61, 0x6c, 0x65, 0x72, 0x53, 0x65, + 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x7f, 0x0a, 0x04, 0x4f, 0x70, 0x65, 0x6e, 0x12, 0x38, 0x2e, + 0x63, 0x31, 0x2e, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x61, 0x70, 0x69, 0x2e, + 0x72, 0x74, 0x75, 0x6e, 0x2e, 0x76, 0x31, 0x2e, 0x52, 0x65, 0x76, 0x65, 0x72, 0x73, 0x65, 0x44, + 0x69, 0x61, 0x6c, 0x65, 0x72, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x4f, 0x70, 0x65, 0x6e, + 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x39, 0x2e, 0x63, 0x31, 0x2e, 0x63, 0x6f, 0x6e, + 0x6e, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x61, 0x70, 0x69, 0x2e, 0x72, 0x74, 0x75, 0x6e, 0x2e, 0x76, + 0x31, 0x2e, 0x52, 0x65, 0x76, 0x65, 0x72, 0x73, 0x65, 0x44, 0x69, 0x61, 0x6c, 0x65, 0x72, 0x53, + 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x4f, 0x70, 0x65, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, + 0x73, 0x65, 0x28, 0x01, 0x30, 0x01, 0x42, 0x3e, 0x5a, 0x3c, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, + 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x63, 0x6f, 0x6e, 0x64, 0x75, 0x63, 0x74, 0x6f, 0x72, 0x6f, 0x6e, + 0x65, 0x2f, 0x62, 0x61, 0x74, 0x6f, 0x6e, 0x2d, 0x73, 0x64, 0x6b, 0x2f, 0x70, 0x62, 0x2f, 0x63, + 0x31, 0x2f, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x61, 0x70, 0x69, 0x2f, 0x72, + 0x74, 0x75, 0x6e, 0x2f, 0x76, 0x31, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, +}) + +var ( + file_c1_connectorapi_rtun_v1_gateway_proto_rawDescOnce sync.Once + file_c1_connectorapi_rtun_v1_gateway_proto_rawDescData []byte +) + +func file_c1_connectorapi_rtun_v1_gateway_proto_rawDescGZIP() []byte { + file_c1_connectorapi_rtun_v1_gateway_proto_rawDescOnce.Do(func() { + file_c1_connectorapi_rtun_v1_gateway_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_c1_connectorapi_rtun_v1_gateway_proto_rawDesc), len(file_c1_connectorapi_rtun_v1_gateway_proto_rawDesc))) + }) + return file_c1_connectorapi_rtun_v1_gateway_proto_rawDescData +} + +var file_c1_connectorapi_rtun_v1_gateway_proto_msgTypes = make([]protoimpl.MessageInfo, 6) +var file_c1_connectorapi_rtun_v1_gateway_proto_goTypes = []any{ + (*ReverseDialerServiceOpenRequest)(nil), // 0: c1.connectorapi.rtun.v1.ReverseDialerServiceOpenRequest + (*ReverseDialerServiceOpenResponse)(nil), // 1: c1.connectorapi.rtun.v1.ReverseDialerServiceOpenResponse + (*OpenRequest)(nil), // 2: c1.connectorapi.rtun.v1.OpenRequest + (*OpenResponse)(nil), // 3: c1.connectorapi.rtun.v1.OpenResponse + (*NotFound)(nil), // 4: c1.connectorapi.rtun.v1.NotFound + (*Opened)(nil), // 5: c1.connectorapi.rtun.v1.Opened + (*Frame)(nil), // 6: c1.connectorapi.rtun.v1.Frame +} +var file_c1_connectorapi_rtun_v1_gateway_proto_depIdxs = []int32{ + 2, // 0: c1.connectorapi.rtun.v1.ReverseDialerServiceOpenRequest.open_req:type_name -> c1.connectorapi.rtun.v1.OpenRequest + 6, // 1: c1.connectorapi.rtun.v1.ReverseDialerServiceOpenRequest.frame:type_name -> c1.connectorapi.rtun.v1.Frame + 3, // 2: c1.connectorapi.rtun.v1.ReverseDialerServiceOpenResponse.open_resp:type_name -> c1.connectorapi.rtun.v1.OpenResponse + 6, // 3: c1.connectorapi.rtun.v1.ReverseDialerServiceOpenResponse.frame:type_name -> c1.connectorapi.rtun.v1.Frame + 4, // 4: c1.connectorapi.rtun.v1.OpenResponse.not_found:type_name -> c1.connectorapi.rtun.v1.NotFound + 5, // 5: c1.connectorapi.rtun.v1.OpenResponse.opened:type_name -> c1.connectorapi.rtun.v1.Opened + 0, // 6: c1.connectorapi.rtun.v1.ReverseDialerService.Open:input_type -> c1.connectorapi.rtun.v1.ReverseDialerServiceOpenRequest + 1, // 7: c1.connectorapi.rtun.v1.ReverseDialerService.Open:output_type -> c1.connectorapi.rtun.v1.ReverseDialerServiceOpenResponse + 7, // [7:8] is the sub-list for method output_type + 6, // [6:7] is the sub-list for method input_type + 6, // [6:6] is the sub-list for extension type_name + 6, // [6:6] is the sub-list for extension extendee + 0, // [0:6] is the sub-list for field type_name +} + +func init() { file_c1_connectorapi_rtun_v1_gateway_proto_init() } +func file_c1_connectorapi_rtun_v1_gateway_proto_init() { + if File_c1_connectorapi_rtun_v1_gateway_proto != nil { + return + } + file_c1_connectorapi_rtun_v1_rtun_proto_init() + file_c1_connectorapi_rtun_v1_gateway_proto_msgTypes[0].OneofWrappers = []any{ + (*ReverseDialerServiceOpenRequest_OpenReq)(nil), + (*ReverseDialerServiceOpenRequest_Frame)(nil), + } + file_c1_connectorapi_rtun_v1_gateway_proto_msgTypes[1].OneofWrappers = []any{ + (*ReverseDialerServiceOpenResponse_OpenResp)(nil), + (*ReverseDialerServiceOpenResponse_Frame)(nil), + } + file_c1_connectorapi_rtun_v1_gateway_proto_msgTypes[3].OneofWrappers = []any{ + (*OpenResponse_NotFound)(nil), + (*OpenResponse_Opened)(nil), + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_c1_connectorapi_rtun_v1_gateway_proto_rawDesc), len(file_c1_connectorapi_rtun_v1_gateway_proto_rawDesc)), + NumEnums: 0, + NumMessages: 6, + NumExtensions: 0, + NumServices: 1, + }, + GoTypes: file_c1_connectorapi_rtun_v1_gateway_proto_goTypes, + DependencyIndexes: file_c1_connectorapi_rtun_v1_gateway_proto_depIdxs, + MessageInfos: file_c1_connectorapi_rtun_v1_gateway_proto_msgTypes, + }.Build() + File_c1_connectorapi_rtun_v1_gateway_proto = out.File + file_c1_connectorapi_rtun_v1_gateway_proto_goTypes = nil + file_c1_connectorapi_rtun_v1_gateway_proto_depIdxs = nil +} diff --git a/pb/c1/connectorapi/rtun/v1/gateway.pb.validate.go b/pb/c1/connectorapi/rtun/v1/gateway.pb.validate.go new file mode 100644 index 000000000..ba6b63815 --- /dev/null +++ b/pb/c1/connectorapi/rtun/v1/gateway.pb.validate.go @@ -0,0 +1,908 @@ +// Code generated by protoc-gen-validate. DO NOT EDIT. +// source: c1/connectorapi/rtun/v1/gateway.proto + +package v1 + +import ( + "bytes" + "errors" + "fmt" + "net" + "net/mail" + "net/url" + "regexp" + "sort" + "strings" + "time" + "unicode/utf8" + + "google.golang.org/protobuf/types/known/anypb" +) + +// ensure the imports are used +var ( + _ = bytes.MinRead + _ = errors.New("") + _ = fmt.Print + _ = utf8.UTFMax + _ = (*regexp.Regexp)(nil) + _ = (*strings.Reader)(nil) + _ = net.IPv4len + _ = time.Duration(0) + _ = (*url.URL)(nil) + _ = (*mail.Address)(nil) + _ = anypb.Any{} + _ = sort.Sort +) + +// Validate checks the field values on ReverseDialerServiceOpenRequest with the +// rules defined in the proto definition for this message. If any rules are +// violated, the first error encountered is returned, or nil if there are no violations. +func (m *ReverseDialerServiceOpenRequest) Validate() error { + return m.validate(false) +} + +// ValidateAll checks the field values on ReverseDialerServiceOpenRequest with +// the rules defined in the proto definition for this message. If any rules +// are violated, the result is a list of violation errors wrapped in +// ReverseDialerServiceOpenRequestMultiError, or nil if none found. +func (m *ReverseDialerServiceOpenRequest) ValidateAll() error { + return m.validate(true) +} + +func (m *ReverseDialerServiceOpenRequest) validate(all bool) error { + if m == nil { + return nil + } + + var errors []error + + switch v := m.Kind.(type) { + case *ReverseDialerServiceOpenRequest_OpenReq: + if v == nil { + err := ReverseDialerServiceOpenRequestValidationError{ + field: "Kind", + reason: "oneof value cannot be a typed-nil", + } + if !all { + return err + } + errors = append(errors, err) + } + + if all { + switch v := interface{}(m.GetOpenReq()).(type) { + case interface{ ValidateAll() error }: + if err := v.ValidateAll(); err != nil { + errors = append(errors, ReverseDialerServiceOpenRequestValidationError{ + field: "OpenReq", + reason: "embedded message failed validation", + cause: err, + }) + } + case interface{ Validate() error }: + if err := v.Validate(); err != nil { + errors = append(errors, ReverseDialerServiceOpenRequestValidationError{ + field: "OpenReq", + reason: "embedded message failed validation", + cause: err, + }) + } + } + } else if v, ok := interface{}(m.GetOpenReq()).(interface{ Validate() error }); ok { + if err := v.Validate(); err != nil { + return ReverseDialerServiceOpenRequestValidationError{ + field: "OpenReq", + reason: "embedded message failed validation", + cause: err, + } + } + } + + case *ReverseDialerServiceOpenRequest_Frame: + if v == nil { + err := ReverseDialerServiceOpenRequestValidationError{ + field: "Kind", + reason: "oneof value cannot be a typed-nil", + } + if !all { + return err + } + errors = append(errors, err) + } + + if all { + switch v := interface{}(m.GetFrame()).(type) { + case interface{ ValidateAll() error }: + if err := v.ValidateAll(); err != nil { + errors = append(errors, ReverseDialerServiceOpenRequestValidationError{ + field: "Frame", + reason: "embedded message failed validation", + cause: err, + }) + } + case interface{ Validate() error }: + if err := v.Validate(); err != nil { + errors = append(errors, ReverseDialerServiceOpenRequestValidationError{ + field: "Frame", + reason: "embedded message failed validation", + cause: err, + }) + } + } + } else if v, ok := interface{}(m.GetFrame()).(interface{ Validate() error }); ok { + if err := v.Validate(); err != nil { + return ReverseDialerServiceOpenRequestValidationError{ + field: "Frame", + reason: "embedded message failed validation", + cause: err, + } + } + } + + default: + _ = v // ensures v is used + } + + if len(errors) > 0 { + return ReverseDialerServiceOpenRequestMultiError(errors) + } + + return nil +} + +// ReverseDialerServiceOpenRequestMultiError is an error wrapping multiple +// validation errors returned by ReverseDialerServiceOpenRequest.ValidateAll() +// if the designated constraints aren't met. +type ReverseDialerServiceOpenRequestMultiError []error + +// Error returns a concatenation of all the error messages it wraps. +func (m ReverseDialerServiceOpenRequestMultiError) Error() string { + msgs := make([]string, 0, len(m)) + for _, err := range m { + msgs = append(msgs, err.Error()) + } + return strings.Join(msgs, "; ") +} + +// AllErrors returns a list of validation violation errors. +func (m ReverseDialerServiceOpenRequestMultiError) AllErrors() []error { return m } + +// ReverseDialerServiceOpenRequestValidationError is the validation error +// returned by ReverseDialerServiceOpenRequest.Validate if the designated +// constraints aren't met. +type ReverseDialerServiceOpenRequestValidationError struct { + field string + reason string + cause error + key bool +} + +// Field function returns field value. +func (e ReverseDialerServiceOpenRequestValidationError) Field() string { return e.field } + +// Reason function returns reason value. +func (e ReverseDialerServiceOpenRequestValidationError) Reason() string { return e.reason } + +// Cause function returns cause value. +func (e ReverseDialerServiceOpenRequestValidationError) Cause() error { return e.cause } + +// Key function returns key value. +func (e ReverseDialerServiceOpenRequestValidationError) Key() bool { return e.key } + +// ErrorName returns error name. +func (e ReverseDialerServiceOpenRequestValidationError) ErrorName() string { + return "ReverseDialerServiceOpenRequestValidationError" +} + +// Error satisfies the builtin error interface +func (e ReverseDialerServiceOpenRequestValidationError) Error() string { + cause := "" + if e.cause != nil { + cause = fmt.Sprintf(" | caused by: %v", e.cause) + } + + key := "" + if e.key { + key = "key for " + } + + return fmt.Sprintf( + "invalid %sReverseDialerServiceOpenRequest.%s: %s%s", + key, + e.field, + e.reason, + cause) +} + +var _ error = ReverseDialerServiceOpenRequestValidationError{} + +var _ interface { + Field() string + Reason() string + Key() bool + Cause() error + ErrorName() string +} = ReverseDialerServiceOpenRequestValidationError{} + +// Validate checks the field values on ReverseDialerServiceOpenResponse with +// the rules defined in the proto definition for this message. If any rules +// are violated, the first error encountered is returned, or nil if there are +// no violations. +func (m *ReverseDialerServiceOpenResponse) Validate() error { + return m.validate(false) +} + +// ValidateAll checks the field values on ReverseDialerServiceOpenResponse with +// the rules defined in the proto definition for this message. If any rules +// are violated, the result is a list of violation errors wrapped in +// ReverseDialerServiceOpenResponseMultiError, or nil if none found. +func (m *ReverseDialerServiceOpenResponse) ValidateAll() error { + return m.validate(true) +} + +func (m *ReverseDialerServiceOpenResponse) validate(all bool) error { + if m == nil { + return nil + } + + var errors []error + + switch v := m.Kind.(type) { + case *ReverseDialerServiceOpenResponse_OpenResp: + if v == nil { + err := ReverseDialerServiceOpenResponseValidationError{ + field: "Kind", + reason: "oneof value cannot be a typed-nil", + } + if !all { + return err + } + errors = append(errors, err) + } + + if all { + switch v := interface{}(m.GetOpenResp()).(type) { + case interface{ ValidateAll() error }: + if err := v.ValidateAll(); err != nil { + errors = append(errors, ReverseDialerServiceOpenResponseValidationError{ + field: "OpenResp", + reason: "embedded message failed validation", + cause: err, + }) + } + case interface{ Validate() error }: + if err := v.Validate(); err != nil { + errors = append(errors, ReverseDialerServiceOpenResponseValidationError{ + field: "OpenResp", + reason: "embedded message failed validation", + cause: err, + }) + } + } + } else if v, ok := interface{}(m.GetOpenResp()).(interface{ Validate() error }); ok { + if err := v.Validate(); err != nil { + return ReverseDialerServiceOpenResponseValidationError{ + field: "OpenResp", + reason: "embedded message failed validation", + cause: err, + } + } + } + + case *ReverseDialerServiceOpenResponse_Frame: + if v == nil { + err := ReverseDialerServiceOpenResponseValidationError{ + field: "Kind", + reason: "oneof value cannot be a typed-nil", + } + if !all { + return err + } + errors = append(errors, err) + } + + if all { + switch v := interface{}(m.GetFrame()).(type) { + case interface{ ValidateAll() error }: + if err := v.ValidateAll(); err != nil { + errors = append(errors, ReverseDialerServiceOpenResponseValidationError{ + field: "Frame", + reason: "embedded message failed validation", + cause: err, + }) + } + case interface{ Validate() error }: + if err := v.Validate(); err != nil { + errors = append(errors, ReverseDialerServiceOpenResponseValidationError{ + field: "Frame", + reason: "embedded message failed validation", + cause: err, + }) + } + } + } else if v, ok := interface{}(m.GetFrame()).(interface{ Validate() error }); ok { + if err := v.Validate(); err != nil { + return ReverseDialerServiceOpenResponseValidationError{ + field: "Frame", + reason: "embedded message failed validation", + cause: err, + } + } + } + + default: + _ = v // ensures v is used + } + + if len(errors) > 0 { + return ReverseDialerServiceOpenResponseMultiError(errors) + } + + return nil +} + +// ReverseDialerServiceOpenResponseMultiError is an error wrapping multiple +// validation errors returned by +// ReverseDialerServiceOpenResponse.ValidateAll() if the designated +// constraints aren't met. +type ReverseDialerServiceOpenResponseMultiError []error + +// Error returns a concatenation of all the error messages it wraps. +func (m ReverseDialerServiceOpenResponseMultiError) Error() string { + msgs := make([]string, 0, len(m)) + for _, err := range m { + msgs = append(msgs, err.Error()) + } + return strings.Join(msgs, "; ") +} + +// AllErrors returns a list of validation violation errors. +func (m ReverseDialerServiceOpenResponseMultiError) AllErrors() []error { return m } + +// ReverseDialerServiceOpenResponseValidationError is the validation error +// returned by ReverseDialerServiceOpenResponse.Validate if the designated +// constraints aren't met. +type ReverseDialerServiceOpenResponseValidationError struct { + field string + reason string + cause error + key bool +} + +// Field function returns field value. +func (e ReverseDialerServiceOpenResponseValidationError) Field() string { return e.field } + +// Reason function returns reason value. +func (e ReverseDialerServiceOpenResponseValidationError) Reason() string { return e.reason } + +// Cause function returns cause value. +func (e ReverseDialerServiceOpenResponseValidationError) Cause() error { return e.cause } + +// Key function returns key value. +func (e ReverseDialerServiceOpenResponseValidationError) Key() bool { return e.key } + +// ErrorName returns error name. +func (e ReverseDialerServiceOpenResponseValidationError) ErrorName() string { + return "ReverseDialerServiceOpenResponseValidationError" +} + +// Error satisfies the builtin error interface +func (e ReverseDialerServiceOpenResponseValidationError) Error() string { + cause := "" + if e.cause != nil { + cause = fmt.Sprintf(" | caused by: %v", e.cause) + } + + key := "" + if e.key { + key = "key for " + } + + return fmt.Sprintf( + "invalid %sReverseDialerServiceOpenResponse.%s: %s%s", + key, + e.field, + e.reason, + cause) +} + +var _ error = ReverseDialerServiceOpenResponseValidationError{} + +var _ interface { + Field() string + Reason() string + Key() bool + Cause() error + ErrorName() string +} = ReverseDialerServiceOpenResponseValidationError{} + +// Validate checks the field values on OpenRequest with the rules defined in +// the proto definition for this message. If any rules are violated, the first +// error encountered is returned, or nil if there are no violations. +func (m *OpenRequest) Validate() error { + return m.validate(false) +} + +// ValidateAll checks the field values on OpenRequest with the rules defined in +// the proto definition for this message. If any rules are violated, the +// result is a list of violation errors wrapped in OpenRequestMultiError, or +// nil if none found. +func (m *OpenRequest) ValidateAll() error { + return m.validate(true) +} + +func (m *OpenRequest) validate(all bool) error { + if m == nil { + return nil + } + + var errors []error + + // no validation rules for Gsid + + // no validation rules for ClientId + + // no validation rules for Port + + if len(errors) > 0 { + return OpenRequestMultiError(errors) + } + + return nil +} + +// OpenRequestMultiError is an error wrapping multiple validation errors +// returned by OpenRequest.ValidateAll() if the designated constraints aren't met. +type OpenRequestMultiError []error + +// Error returns a concatenation of all the error messages it wraps. +func (m OpenRequestMultiError) Error() string { + msgs := make([]string, 0, len(m)) + for _, err := range m { + msgs = append(msgs, err.Error()) + } + return strings.Join(msgs, "; ") +} + +// AllErrors returns a list of validation violation errors. +func (m OpenRequestMultiError) AllErrors() []error { return m } + +// OpenRequestValidationError is the validation error returned by +// OpenRequest.Validate if the designated constraints aren't met. +type OpenRequestValidationError struct { + field string + reason string + cause error + key bool +} + +// Field function returns field value. +func (e OpenRequestValidationError) Field() string { return e.field } + +// Reason function returns reason value. +func (e OpenRequestValidationError) Reason() string { return e.reason } + +// Cause function returns cause value. +func (e OpenRequestValidationError) Cause() error { return e.cause } + +// Key function returns key value. +func (e OpenRequestValidationError) Key() bool { return e.key } + +// ErrorName returns error name. +func (e OpenRequestValidationError) ErrorName() string { return "OpenRequestValidationError" } + +// Error satisfies the builtin error interface +func (e OpenRequestValidationError) Error() string { + cause := "" + if e.cause != nil { + cause = fmt.Sprintf(" | caused by: %v", e.cause) + } + + key := "" + if e.key { + key = "key for " + } + + return fmt.Sprintf( + "invalid %sOpenRequest.%s: %s%s", + key, + e.field, + e.reason, + cause) +} + +var _ error = OpenRequestValidationError{} + +var _ interface { + Field() string + Reason() string + Key() bool + Cause() error + ErrorName() string +} = OpenRequestValidationError{} + +// Validate checks the field values on OpenResponse with the rules defined in +// the proto definition for this message. If any rules are violated, the first +// error encountered is returned, or nil if there are no violations. +func (m *OpenResponse) Validate() error { + return m.validate(false) +} + +// ValidateAll checks the field values on OpenResponse with the rules defined +// in the proto definition for this message. If any rules are violated, the +// result is a list of violation errors wrapped in OpenResponseMultiError, or +// nil if none found. +func (m *OpenResponse) ValidateAll() error { + return m.validate(true) +} + +func (m *OpenResponse) validate(all bool) error { + if m == nil { + return nil + } + + var errors []error + + // no validation rules for Gsid + + switch v := m.Result.(type) { + case *OpenResponse_NotFound: + if v == nil { + err := OpenResponseValidationError{ + field: "Result", + reason: "oneof value cannot be a typed-nil", + } + if !all { + return err + } + errors = append(errors, err) + } + + if all { + switch v := interface{}(m.GetNotFound()).(type) { + case interface{ ValidateAll() error }: + if err := v.ValidateAll(); err != nil { + errors = append(errors, OpenResponseValidationError{ + field: "NotFound", + reason: "embedded message failed validation", + cause: err, + }) + } + case interface{ Validate() error }: + if err := v.Validate(); err != nil { + errors = append(errors, OpenResponseValidationError{ + field: "NotFound", + reason: "embedded message failed validation", + cause: err, + }) + } + } + } else if v, ok := interface{}(m.GetNotFound()).(interface{ Validate() error }); ok { + if err := v.Validate(); err != nil { + return OpenResponseValidationError{ + field: "NotFound", + reason: "embedded message failed validation", + cause: err, + } + } + } + + case *OpenResponse_Opened: + if v == nil { + err := OpenResponseValidationError{ + field: "Result", + reason: "oneof value cannot be a typed-nil", + } + if !all { + return err + } + errors = append(errors, err) + } + + if all { + switch v := interface{}(m.GetOpened()).(type) { + case interface{ ValidateAll() error }: + if err := v.ValidateAll(); err != nil { + errors = append(errors, OpenResponseValidationError{ + field: "Opened", + reason: "embedded message failed validation", + cause: err, + }) + } + case interface{ Validate() error }: + if err := v.Validate(); err != nil { + errors = append(errors, OpenResponseValidationError{ + field: "Opened", + reason: "embedded message failed validation", + cause: err, + }) + } + } + } else if v, ok := interface{}(m.GetOpened()).(interface{ Validate() error }); ok { + if err := v.Validate(); err != nil { + return OpenResponseValidationError{ + field: "Opened", + reason: "embedded message failed validation", + cause: err, + } + } + } + + default: + _ = v // ensures v is used + } + + if len(errors) > 0 { + return OpenResponseMultiError(errors) + } + + return nil +} + +// OpenResponseMultiError is an error wrapping multiple validation errors +// returned by OpenResponse.ValidateAll() if the designated constraints aren't met. +type OpenResponseMultiError []error + +// Error returns a concatenation of all the error messages it wraps. +func (m OpenResponseMultiError) Error() string { + msgs := make([]string, 0, len(m)) + for _, err := range m { + msgs = append(msgs, err.Error()) + } + return strings.Join(msgs, "; ") +} + +// AllErrors returns a list of validation violation errors. +func (m OpenResponseMultiError) AllErrors() []error { return m } + +// OpenResponseValidationError is the validation error returned by +// OpenResponse.Validate if the designated constraints aren't met. +type OpenResponseValidationError struct { + field string + reason string + cause error + key bool +} + +// Field function returns field value. +func (e OpenResponseValidationError) Field() string { return e.field } + +// Reason function returns reason value. +func (e OpenResponseValidationError) Reason() string { return e.reason } + +// Cause function returns cause value. +func (e OpenResponseValidationError) Cause() error { return e.cause } + +// Key function returns key value. +func (e OpenResponseValidationError) Key() bool { return e.key } + +// ErrorName returns error name. +func (e OpenResponseValidationError) ErrorName() string { return "OpenResponseValidationError" } + +// Error satisfies the builtin error interface +func (e OpenResponseValidationError) Error() string { + cause := "" + if e.cause != nil { + cause = fmt.Sprintf(" | caused by: %v", e.cause) + } + + key := "" + if e.key { + key = "key for " + } + + return fmt.Sprintf( + "invalid %sOpenResponse.%s: %s%s", + key, + e.field, + e.reason, + cause) +} + +var _ error = OpenResponseValidationError{} + +var _ interface { + Field() string + Reason() string + Key() bool + Cause() error + ErrorName() string +} = OpenResponseValidationError{} + +// Validate checks the field values on NotFound with the rules defined in the +// proto definition for this message. If any rules are violated, the first +// error encountered is returned, or nil if there are no violations. +func (m *NotFound) Validate() error { + return m.validate(false) +} + +// ValidateAll checks the field values on NotFound with the rules defined in +// the proto definition for this message. If any rules are violated, the +// result is a list of violation errors wrapped in NotFoundMultiError, or nil +// if none found. +func (m *NotFound) ValidateAll() error { + return m.validate(true) +} + +func (m *NotFound) validate(all bool) error { + if m == nil { + return nil + } + + var errors []error + + if len(errors) > 0 { + return NotFoundMultiError(errors) + } + + return nil +} + +// NotFoundMultiError is an error wrapping multiple validation errors returned +// by NotFound.ValidateAll() if the designated constraints aren't met. +type NotFoundMultiError []error + +// Error returns a concatenation of all the error messages it wraps. +func (m NotFoundMultiError) Error() string { + msgs := make([]string, 0, len(m)) + for _, err := range m { + msgs = append(msgs, err.Error()) + } + return strings.Join(msgs, "; ") +} + +// AllErrors returns a list of validation violation errors. +func (m NotFoundMultiError) AllErrors() []error { return m } + +// NotFoundValidationError is the validation error returned by +// NotFound.Validate if the designated constraints aren't met. +type NotFoundValidationError struct { + field string + reason string + cause error + key bool +} + +// Field function returns field value. +func (e NotFoundValidationError) Field() string { return e.field } + +// Reason function returns reason value. +func (e NotFoundValidationError) Reason() string { return e.reason } + +// Cause function returns cause value. +func (e NotFoundValidationError) Cause() error { return e.cause } + +// Key function returns key value. +func (e NotFoundValidationError) Key() bool { return e.key } + +// ErrorName returns error name. +func (e NotFoundValidationError) ErrorName() string { return "NotFoundValidationError" } + +// Error satisfies the builtin error interface +func (e NotFoundValidationError) Error() string { + cause := "" + if e.cause != nil { + cause = fmt.Sprintf(" | caused by: %v", e.cause) + } + + key := "" + if e.key { + key = "key for " + } + + return fmt.Sprintf( + "invalid %sNotFound.%s: %s%s", + key, + e.field, + e.reason, + cause) +} + +var _ error = NotFoundValidationError{} + +var _ interface { + Field() string + Reason() string + Key() bool + Cause() error + ErrorName() string +} = NotFoundValidationError{} + +// Validate checks the field values on Opened with the rules defined in the +// proto definition for this message. If any rules are violated, the first +// error encountered is returned, or nil if there are no violations. +func (m *Opened) Validate() error { + return m.validate(false) +} + +// ValidateAll checks the field values on Opened with the rules defined in the +// proto definition for this message. If any rules are violated, the result is +// a list of violation errors wrapped in OpenedMultiError, or nil if none found. +func (m *Opened) ValidateAll() error { + return m.validate(true) +} + +func (m *Opened) validate(all bool) error { + if m == nil { + return nil + } + + var errors []error + + if len(errors) > 0 { + return OpenedMultiError(errors) + } + + return nil +} + +// OpenedMultiError is an error wrapping multiple validation errors returned by +// Opened.ValidateAll() if the designated constraints aren't met. +type OpenedMultiError []error + +// Error returns a concatenation of all the error messages it wraps. +func (m OpenedMultiError) Error() string { + msgs := make([]string, 0, len(m)) + for _, err := range m { + msgs = append(msgs, err.Error()) + } + return strings.Join(msgs, "; ") +} + +// AllErrors returns a list of validation violation errors. +func (m OpenedMultiError) AllErrors() []error { return m } + +// OpenedValidationError is the validation error returned by Opened.Validate if +// the designated constraints aren't met. +type OpenedValidationError struct { + field string + reason string + cause error + key bool +} + +// Field function returns field value. +func (e OpenedValidationError) Field() string { return e.field } + +// Reason function returns reason value. +func (e OpenedValidationError) Reason() string { return e.reason } + +// Cause function returns cause value. +func (e OpenedValidationError) Cause() error { return e.cause } + +// Key function returns key value. +func (e OpenedValidationError) Key() bool { return e.key } + +// ErrorName returns error name. +func (e OpenedValidationError) ErrorName() string { return "OpenedValidationError" } + +// Error satisfies the builtin error interface +func (e OpenedValidationError) Error() string { + cause := "" + if e.cause != nil { + cause = fmt.Sprintf(" | caused by: %v", e.cause) + } + + key := "" + if e.key { + key = "key for " + } + + return fmt.Sprintf( + "invalid %sOpened.%s: %s%s", + key, + e.field, + e.reason, + cause) +} + +var _ error = OpenedValidationError{} + +var _ interface { + Field() string + Reason() string + Key() bool + Cause() error + ErrorName() string +} = OpenedValidationError{} diff --git a/pb/c1/connectorapi/rtun/v1/gateway_grpc.pb.go b/pb/c1/connectorapi/rtun/v1/gateway_grpc.pb.go new file mode 100644 index 000000000..a7b200ad7 --- /dev/null +++ b/pb/c1/connectorapi/rtun/v1/gateway_grpc.pb.go @@ -0,0 +1,119 @@ +// Code generated by protoc-gen-go-grpc. DO NOT EDIT. +// versions: +// - protoc-gen-go-grpc v1.5.1 +// - protoc (unknown) +// source: c1/connectorapi/rtun/v1/gateway.proto + +package v1 + +import ( + context "context" + grpc "google.golang.org/grpc" + codes "google.golang.org/grpc/codes" + status "google.golang.org/grpc/status" +) + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +// Requires gRPC-Go v1.64.0 or later. +const _ = grpc.SupportPackageIsVersion9 + +const ( + ReverseDialerService_Open_FullMethodName = "/c1.connectorapi.rtun.v1.ReverseDialerService/Open" +) + +// ReverseDialerServiceClient is the client API for ReverseDialerService service. +// +// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. +// +// ReverseDialer allows callers to establish connections to clients via the gateway. +// The gateway bridges caller streams to rtun sessions on the owner server. +type ReverseDialerServiceClient interface { + Open(ctx context.Context, opts ...grpc.CallOption) (grpc.BidiStreamingClient[ReverseDialerServiceOpenRequest, ReverseDialerServiceOpenResponse], error) +} + +type reverseDialerServiceClient struct { + cc grpc.ClientConnInterface +} + +func NewReverseDialerServiceClient(cc grpc.ClientConnInterface) ReverseDialerServiceClient { + return &reverseDialerServiceClient{cc} +} + +func (c *reverseDialerServiceClient) Open(ctx context.Context, opts ...grpc.CallOption) (grpc.BidiStreamingClient[ReverseDialerServiceOpenRequest, ReverseDialerServiceOpenResponse], error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + stream, err := c.cc.NewStream(ctx, &ReverseDialerService_ServiceDesc.Streams[0], ReverseDialerService_Open_FullMethodName, cOpts...) + if err != nil { + return nil, err + } + x := &grpc.GenericClientStream[ReverseDialerServiceOpenRequest, ReverseDialerServiceOpenResponse]{ClientStream: stream} + return x, nil +} + +// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. +type ReverseDialerService_OpenClient = grpc.BidiStreamingClient[ReverseDialerServiceOpenRequest, ReverseDialerServiceOpenResponse] + +// ReverseDialerServiceServer is the server API for ReverseDialerService service. +// All implementations should embed UnimplementedReverseDialerServiceServer +// for forward compatibility. +// +// ReverseDialer allows callers to establish connections to clients via the gateway. +// The gateway bridges caller streams to rtun sessions on the owner server. +type ReverseDialerServiceServer interface { + Open(grpc.BidiStreamingServer[ReverseDialerServiceOpenRequest, ReverseDialerServiceOpenResponse]) error +} + +// UnimplementedReverseDialerServiceServer should be embedded to have +// forward compatible implementations. +// +// NOTE: this should be embedded by value instead of pointer to avoid a nil +// pointer dereference when methods are called. +type UnimplementedReverseDialerServiceServer struct{} + +func (UnimplementedReverseDialerServiceServer) Open(grpc.BidiStreamingServer[ReverseDialerServiceOpenRequest, ReverseDialerServiceOpenResponse]) error { + return status.Errorf(codes.Unimplemented, "method Open not implemented") +} +func (UnimplementedReverseDialerServiceServer) testEmbeddedByValue() {} + +// UnsafeReverseDialerServiceServer may be embedded to opt out of forward compatibility for this service. +// Use of this interface is not recommended, as added methods to ReverseDialerServiceServer will +// result in compilation errors. +type UnsafeReverseDialerServiceServer interface { + mustEmbedUnimplementedReverseDialerServiceServer() +} + +func RegisterReverseDialerServiceServer(s grpc.ServiceRegistrar, srv ReverseDialerServiceServer) { + // If the following call pancis, it indicates UnimplementedReverseDialerServiceServer was + // embedded by pointer and is nil. This will cause panics if an + // unimplemented method is ever invoked, so we test this at initialization + // time to prevent it from happening at runtime later due to I/O. + if t, ok := srv.(interface{ testEmbeddedByValue() }); ok { + t.testEmbeddedByValue() + } + s.RegisterService(&ReverseDialerService_ServiceDesc, srv) +} + +func _ReverseDialerService_Open_Handler(srv interface{}, stream grpc.ServerStream) error { + return srv.(ReverseDialerServiceServer).Open(&grpc.GenericServerStream[ReverseDialerServiceOpenRequest, ReverseDialerServiceOpenResponse]{ServerStream: stream}) +} + +// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. +type ReverseDialerService_OpenServer = grpc.BidiStreamingServer[ReverseDialerServiceOpenRequest, ReverseDialerServiceOpenResponse] + +// ReverseDialerService_ServiceDesc is the grpc.ServiceDesc for ReverseDialerService service. +// It's only intended for direct use with grpc.RegisterService, +// and not to be introspected or modified (even as a copy) +var ReverseDialerService_ServiceDesc = grpc.ServiceDesc{ + ServiceName: "c1.connectorapi.rtun.v1.ReverseDialerService", + HandlerType: (*ReverseDialerServiceServer)(nil), + Methods: []grpc.MethodDesc{}, + Streams: []grpc.StreamDesc{ + { + StreamName: "Open", + Handler: _ReverseDialerService_Open_Handler, + ServerStreams: true, + ClientStreams: true, + }, + }, + Metadata: "c1/connectorapi/rtun/v1/gateway.proto", +} diff --git a/pb/c1/connectorapi/rtun/v1/rtun.pb.go b/pb/c1/connectorapi/rtun/v1/rtun.pb.go new file mode 100644 index 000000000..81601ade4 --- /dev/null +++ b/pb/c1/connectorapi/rtun/v1/rtun.pb.go @@ -0,0 +1,681 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.36.4 +// protoc (unknown) +// source: c1/connectorapi/rtun/v1/rtun.proto + +package v1 + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" + unsafe "unsafe" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type RstCode int32 + +const ( + RstCode_RST_CODE_UNSPECIFIED RstCode = 0 + RstCode_RST_CODE_NO_LISTENER RstCode = 1 + RstCode_RST_CODE_PORT_NOT_ADVERTISED RstCode = 2 + RstCode_RST_CODE_TIMEOUT RstCode = 3 + RstCode_RST_CODE_INTERNAL RstCode = 4 +) + +// Enum value maps for RstCode. +var ( + RstCode_name = map[int32]string{ + 0: "RST_CODE_UNSPECIFIED", + 1: "RST_CODE_NO_LISTENER", + 2: "RST_CODE_PORT_NOT_ADVERTISED", + 3: "RST_CODE_TIMEOUT", + 4: "RST_CODE_INTERNAL", + } + RstCode_value = map[string]int32{ + "RST_CODE_UNSPECIFIED": 0, + "RST_CODE_NO_LISTENER": 1, + "RST_CODE_PORT_NOT_ADVERTISED": 2, + "RST_CODE_TIMEOUT": 3, + "RST_CODE_INTERNAL": 4, + } +) + +func (x RstCode) Enum() *RstCode { + p := new(RstCode) + *p = x + return p +} + +func (x RstCode) String() string { + return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) +} + +func (RstCode) Descriptor() protoreflect.EnumDescriptor { + return file_c1_connectorapi_rtun_v1_rtun_proto_enumTypes[0].Descriptor() +} + +func (RstCode) Type() protoreflect.EnumType { + return &file_c1_connectorapi_rtun_v1_rtun_proto_enumTypes[0] +} + +func (x RstCode) Number() protoreflect.EnumNumber { + return protoreflect.EnumNumber(x) +} + +// Deprecated: Use RstCode.Descriptor instead. +func (RstCode) EnumDescriptor() ([]byte, []int) { + return file_c1_connectorapi_rtun_v1_rtun_proto_rawDescGZIP(), []int{0} +} + +type ReverseTunnelServiceLinkRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + Frame *Frame `protobuf:"bytes,1,opt,name=frame,proto3" json:"frame,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ReverseTunnelServiceLinkRequest) Reset() { + *x = ReverseTunnelServiceLinkRequest{} + mi := &file_c1_connectorapi_rtun_v1_rtun_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ReverseTunnelServiceLinkRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ReverseTunnelServiceLinkRequest) ProtoMessage() {} + +func (x *ReverseTunnelServiceLinkRequest) ProtoReflect() protoreflect.Message { + mi := &file_c1_connectorapi_rtun_v1_rtun_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ReverseTunnelServiceLinkRequest.ProtoReflect.Descriptor instead. +func (*ReverseTunnelServiceLinkRequest) Descriptor() ([]byte, []int) { + return file_c1_connectorapi_rtun_v1_rtun_proto_rawDescGZIP(), []int{0} +} + +func (x *ReverseTunnelServiceLinkRequest) GetFrame() *Frame { + if x != nil { + return x.Frame + } + return nil +} + +type ReverseTunnelServiceLinkResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + Frame *Frame `protobuf:"bytes,1,opt,name=frame,proto3" json:"frame,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ReverseTunnelServiceLinkResponse) Reset() { + *x = ReverseTunnelServiceLinkResponse{} + mi := &file_c1_connectorapi_rtun_v1_rtun_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ReverseTunnelServiceLinkResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ReverseTunnelServiceLinkResponse) ProtoMessage() {} + +func (x *ReverseTunnelServiceLinkResponse) ProtoReflect() protoreflect.Message { + mi := &file_c1_connectorapi_rtun_v1_rtun_proto_msgTypes[1] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ReverseTunnelServiceLinkResponse.ProtoReflect.Descriptor instead. +func (*ReverseTunnelServiceLinkResponse) Descriptor() ([]byte, []int) { + return file_c1_connectorapi_rtun_v1_rtun_proto_rawDescGZIP(), []int{1} +} + +func (x *ReverseTunnelServiceLinkResponse) GetFrame() *Frame { + if x != nil { + return x.Frame + } + return nil +} + +type Frame struct { + state protoimpl.MessageState `protogen:"open.v1"` + Sid uint32 `protobuf:"varint,1,opt,name=sid,proto3" json:"sid,omitempty"` // 0 reserved for control + // Types that are valid to be assigned to Kind: + // + // *Frame_Hello + // *Frame_Syn + // *Frame_Data + // *Frame_Fin + // *Frame_Rst + Kind isFrame_Kind `protobuf_oneof:"kind"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *Frame) Reset() { + *x = Frame{} + mi := &file_c1_connectorapi_rtun_v1_rtun_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *Frame) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Frame) ProtoMessage() {} + +func (x *Frame) ProtoReflect() protoreflect.Message { + mi := &file_c1_connectorapi_rtun_v1_rtun_proto_msgTypes[2] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Frame.ProtoReflect.Descriptor instead. +func (*Frame) Descriptor() ([]byte, []int) { + return file_c1_connectorapi_rtun_v1_rtun_proto_rawDescGZIP(), []int{2} +} + +func (x *Frame) GetSid() uint32 { + if x != nil { + return x.Sid + } + return 0 +} + +func (x *Frame) GetKind() isFrame_Kind { + if x != nil { + return x.Kind + } + return nil +} + +func (x *Frame) GetHello() *Hello { + if x != nil { + if x, ok := x.Kind.(*Frame_Hello); ok { + return x.Hello + } + } + return nil +} + +func (x *Frame) GetSyn() *Syn { + if x != nil { + if x, ok := x.Kind.(*Frame_Syn); ok { + return x.Syn + } + } + return nil +} + +func (x *Frame) GetData() *Data { + if x != nil { + if x, ok := x.Kind.(*Frame_Data); ok { + return x.Data + } + } + return nil +} + +func (x *Frame) GetFin() *Fin { + if x != nil { + if x, ok := x.Kind.(*Frame_Fin); ok { + return x.Fin + } + } + return nil +} + +func (x *Frame) GetRst() *Rst { + if x != nil { + if x, ok := x.Kind.(*Frame_Rst); ok { + return x.Rst + } + } + return nil +} + +type isFrame_Kind interface { + isFrame_Kind() +} + +type Frame_Hello struct { + Hello *Hello `protobuf:"bytes,10,opt,name=hello,proto3,oneof"` // client -> server (first) +} + +type Frame_Syn struct { + Syn *Syn `protobuf:"bytes,11,opt,name=syn,proto3,oneof"` // server -> client (reverse open) +} + +type Frame_Data struct { + Data *Data `protobuf:"bytes,12,opt,name=data,proto3,oneof"` // either direction +} + +type Frame_Fin struct { + Fin *Fin `protobuf:"bytes,13,opt,name=fin,proto3,oneof"` // either direction +} + +type Frame_Rst struct { + Rst *Rst `protobuf:"bytes,14,opt,name=rst,proto3,oneof"` // either direction +} + +func (*Frame_Hello) isFrame_Kind() {} + +func (*Frame_Syn) isFrame_Kind() {} + +func (*Frame_Data) isFrame_Kind() {} + +func (*Frame_Fin) isFrame_Kind() {} + +func (*Frame_Rst) isFrame_Kind() {} + +type Hello struct { + state protoimpl.MessageState `protogen:"open.v1"` + Ports []uint32 `protobuf:"varint,1,rep,packed,name=ports,proto3" json:"ports,omitempty"` + Protocol uint32 `protobuf:"varint,2,opt,name=protocol,proto3" json:"protocol,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *Hello) Reset() { + *x = Hello{} + mi := &file_c1_connectorapi_rtun_v1_rtun_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *Hello) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Hello) ProtoMessage() {} + +func (x *Hello) ProtoReflect() protoreflect.Message { + mi := &file_c1_connectorapi_rtun_v1_rtun_proto_msgTypes[3] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Hello.ProtoReflect.Descriptor instead. +func (*Hello) Descriptor() ([]byte, []int) { + return file_c1_connectorapi_rtun_v1_rtun_proto_rawDescGZIP(), []int{3} +} + +func (x *Hello) GetPorts() []uint32 { + if x != nil { + return x.Ports + } + return nil +} + +func (x *Hello) GetProtocol() uint32 { + if x != nil { + return x.Protocol + } + return 0 +} + +type Syn struct { + state protoimpl.MessageState `protogen:"open.v1"` + Port uint32 `protobuf:"varint,1,opt,name=port,proto3" json:"port,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *Syn) Reset() { + *x = Syn{} + mi := &file_c1_connectorapi_rtun_v1_rtun_proto_msgTypes[4] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *Syn) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Syn) ProtoMessage() {} + +func (x *Syn) ProtoReflect() protoreflect.Message { + mi := &file_c1_connectorapi_rtun_v1_rtun_proto_msgTypes[4] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Syn.ProtoReflect.Descriptor instead. +func (*Syn) Descriptor() ([]byte, []int) { + return file_c1_connectorapi_rtun_v1_rtun_proto_rawDescGZIP(), []int{4} +} + +func (x *Syn) GetPort() uint32 { + if x != nil { + return x.Port + } + return 0 +} + +type Data struct { + state protoimpl.MessageState `protogen:"open.v1"` + Payload []byte `protobuf:"bytes,1,opt,name=payload,proto3" json:"payload,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *Data) Reset() { + *x = Data{} + mi := &file_c1_connectorapi_rtun_v1_rtun_proto_msgTypes[5] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *Data) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Data) ProtoMessage() {} + +func (x *Data) ProtoReflect() protoreflect.Message { + mi := &file_c1_connectorapi_rtun_v1_rtun_proto_msgTypes[5] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Data.ProtoReflect.Descriptor instead. +func (*Data) Descriptor() ([]byte, []int) { + return file_c1_connectorapi_rtun_v1_rtun_proto_rawDescGZIP(), []int{5} +} + +func (x *Data) GetPayload() []byte { + if x != nil { + return x.Payload + } + return nil +} + +type Fin struct { + state protoimpl.MessageState `protogen:"open.v1"` + Ack bool `protobuf:"varint,1,opt,name=ack,proto3" json:"ack,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *Fin) Reset() { + *x = Fin{} + mi := &file_c1_connectorapi_rtun_v1_rtun_proto_msgTypes[6] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *Fin) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Fin) ProtoMessage() {} + +func (x *Fin) ProtoReflect() protoreflect.Message { + mi := &file_c1_connectorapi_rtun_v1_rtun_proto_msgTypes[6] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Fin.ProtoReflect.Descriptor instead. +func (*Fin) Descriptor() ([]byte, []int) { + return file_c1_connectorapi_rtun_v1_rtun_proto_rawDescGZIP(), []int{6} +} + +func (x *Fin) GetAck() bool { + if x != nil { + return x.Ack + } + return false +} + +type Rst struct { + state protoimpl.MessageState `protogen:"open.v1"` + Code RstCode `protobuf:"varint,1,opt,name=code,proto3,enum=c1.connectorapi.rtun.v1.RstCode" json:"code,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *Rst) Reset() { + *x = Rst{} + mi := &file_c1_connectorapi_rtun_v1_rtun_proto_msgTypes[7] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *Rst) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Rst) ProtoMessage() {} + +func (x *Rst) ProtoReflect() protoreflect.Message { + mi := &file_c1_connectorapi_rtun_v1_rtun_proto_msgTypes[7] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Rst.ProtoReflect.Descriptor instead. +func (*Rst) Descriptor() ([]byte, []int) { + return file_c1_connectorapi_rtun_v1_rtun_proto_rawDescGZIP(), []int{7} +} + +func (x *Rst) GetCode() RstCode { + if x != nil { + return x.Code + } + return RstCode_RST_CODE_UNSPECIFIED +} + +var File_c1_connectorapi_rtun_v1_rtun_proto protoreflect.FileDescriptor + +var file_c1_connectorapi_rtun_v1_rtun_proto_rawDesc = string([]byte{ + 0x0a, 0x22, 0x63, 0x31, 0x2f, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x61, 0x70, + 0x69, 0x2f, 0x72, 0x74, 0x75, 0x6e, 0x2f, 0x76, 0x31, 0x2f, 0x72, 0x74, 0x75, 0x6e, 0x2e, 0x70, + 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x17, 0x63, 0x31, 0x2e, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, + 0x6f, 0x72, 0x61, 0x70, 0x69, 0x2e, 0x72, 0x74, 0x75, 0x6e, 0x2e, 0x76, 0x31, 0x22, 0x57, 0x0a, + 0x1f, 0x52, 0x65, 0x76, 0x65, 0x72, 0x73, 0x65, 0x54, 0x75, 0x6e, 0x6e, 0x65, 0x6c, 0x53, 0x65, + 0x72, 0x76, 0x69, 0x63, 0x65, 0x4c, 0x69, 0x6e, 0x6b, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, + 0x12, 0x34, 0x0a, 0x05, 0x66, 0x72, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, + 0x1e, 0x2e, 0x63, 0x31, 0x2e, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x61, 0x70, + 0x69, 0x2e, 0x72, 0x74, 0x75, 0x6e, 0x2e, 0x76, 0x31, 0x2e, 0x46, 0x72, 0x61, 0x6d, 0x65, 0x52, + 0x05, 0x66, 0x72, 0x61, 0x6d, 0x65, 0x22, 0x58, 0x0a, 0x20, 0x52, 0x65, 0x76, 0x65, 0x72, 0x73, + 0x65, 0x54, 0x75, 0x6e, 0x6e, 0x65, 0x6c, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x4c, 0x69, + 0x6e, 0x6b, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x34, 0x0a, 0x05, 0x66, 0x72, + 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1e, 0x2e, 0x63, 0x31, 0x2e, 0x63, + 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x61, 0x70, 0x69, 0x2e, 0x72, 0x74, 0x75, 0x6e, + 0x2e, 0x76, 0x31, 0x2e, 0x46, 0x72, 0x61, 0x6d, 0x65, 0x52, 0x05, 0x66, 0x72, 0x61, 0x6d, 0x65, + 0x22, 0xa4, 0x02, 0x0a, 0x05, 0x46, 0x72, 0x61, 0x6d, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x73, 0x69, + 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x03, 0x73, 0x69, 0x64, 0x12, 0x36, 0x0a, 0x05, + 0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1e, 0x2e, 0x63, 0x31, + 0x2e, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x61, 0x70, 0x69, 0x2e, 0x72, 0x74, + 0x75, 0x6e, 0x2e, 0x76, 0x31, 0x2e, 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x48, 0x00, 0x52, 0x05, 0x68, + 0x65, 0x6c, 0x6c, 0x6f, 0x12, 0x30, 0x0a, 0x03, 0x73, 0x79, 0x6e, 0x18, 0x0b, 0x20, 0x01, 0x28, + 0x0b, 0x32, 0x1c, 0x2e, 0x63, 0x31, 0x2e, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x6f, 0x72, + 0x61, 0x70, 0x69, 0x2e, 0x72, 0x74, 0x75, 0x6e, 0x2e, 0x76, 0x31, 0x2e, 0x53, 0x79, 0x6e, 0x48, + 0x00, 0x52, 0x03, 0x73, 0x79, 0x6e, 0x12, 0x33, 0x0a, 0x04, 0x64, 0x61, 0x74, 0x61, 0x18, 0x0c, + 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1d, 0x2e, 0x63, 0x31, 0x2e, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, + 0x74, 0x6f, 0x72, 0x61, 0x70, 0x69, 0x2e, 0x72, 0x74, 0x75, 0x6e, 0x2e, 0x76, 0x31, 0x2e, 0x44, + 0x61, 0x74, 0x61, 0x48, 0x00, 0x52, 0x04, 0x64, 0x61, 0x74, 0x61, 0x12, 0x30, 0x0a, 0x03, 0x66, + 0x69, 0x6e, 0x18, 0x0d, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1c, 0x2e, 0x63, 0x31, 0x2e, 0x63, 0x6f, + 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x61, 0x70, 0x69, 0x2e, 0x72, 0x74, 0x75, 0x6e, 0x2e, + 0x76, 0x31, 0x2e, 0x46, 0x69, 0x6e, 0x48, 0x00, 0x52, 0x03, 0x66, 0x69, 0x6e, 0x12, 0x30, 0x0a, + 0x03, 0x72, 0x73, 0x74, 0x18, 0x0e, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1c, 0x2e, 0x63, 0x31, 0x2e, + 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x61, 0x70, 0x69, 0x2e, 0x72, 0x74, 0x75, + 0x6e, 0x2e, 0x76, 0x31, 0x2e, 0x52, 0x73, 0x74, 0x48, 0x00, 0x52, 0x03, 0x72, 0x73, 0x74, 0x42, + 0x06, 0x0a, 0x04, 0x6b, 0x69, 0x6e, 0x64, 0x22, 0x39, 0x0a, 0x05, 0x48, 0x65, 0x6c, 0x6c, 0x6f, + 0x12, 0x14, 0x0a, 0x05, 0x70, 0x6f, 0x72, 0x74, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0d, 0x52, + 0x05, 0x70, 0x6f, 0x72, 0x74, 0x73, 0x12, 0x1a, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, + 0x6f, 0x6c, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, + 0x6f, 0x6c, 0x22, 0x19, 0x0a, 0x03, 0x53, 0x79, 0x6e, 0x12, 0x12, 0x0a, 0x04, 0x70, 0x6f, 0x72, + 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x22, 0x20, 0x0a, + 0x04, 0x44, 0x61, 0x74, 0x61, 0x12, 0x18, 0x0a, 0x07, 0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x07, 0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x22, + 0x17, 0x0a, 0x03, 0x46, 0x69, 0x6e, 0x12, 0x10, 0x0a, 0x03, 0x61, 0x63, 0x6b, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x08, 0x52, 0x03, 0x61, 0x63, 0x6b, 0x22, 0x3b, 0x0a, 0x03, 0x52, 0x73, 0x74, 0x12, + 0x34, 0x0a, 0x04, 0x63, 0x6f, 0x64, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x20, 0x2e, + 0x63, 0x31, 0x2e, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x61, 0x70, 0x69, 0x2e, + 0x72, 0x74, 0x75, 0x6e, 0x2e, 0x76, 0x31, 0x2e, 0x52, 0x73, 0x74, 0x43, 0x6f, 0x64, 0x65, 0x52, + 0x04, 0x63, 0x6f, 0x64, 0x65, 0x2a, 0x8c, 0x01, 0x0a, 0x07, 0x52, 0x73, 0x74, 0x43, 0x6f, 0x64, + 0x65, 0x12, 0x18, 0x0a, 0x14, 0x52, 0x53, 0x54, 0x5f, 0x43, 0x4f, 0x44, 0x45, 0x5f, 0x55, 0x4e, + 0x53, 0x50, 0x45, 0x43, 0x49, 0x46, 0x49, 0x45, 0x44, 0x10, 0x00, 0x12, 0x18, 0x0a, 0x14, 0x52, + 0x53, 0x54, 0x5f, 0x43, 0x4f, 0x44, 0x45, 0x5f, 0x4e, 0x4f, 0x5f, 0x4c, 0x49, 0x53, 0x54, 0x45, + 0x4e, 0x45, 0x52, 0x10, 0x01, 0x12, 0x20, 0x0a, 0x1c, 0x52, 0x53, 0x54, 0x5f, 0x43, 0x4f, 0x44, + 0x45, 0x5f, 0x50, 0x4f, 0x52, 0x54, 0x5f, 0x4e, 0x4f, 0x54, 0x5f, 0x41, 0x44, 0x56, 0x45, 0x52, + 0x54, 0x49, 0x53, 0x45, 0x44, 0x10, 0x02, 0x12, 0x14, 0x0a, 0x10, 0x52, 0x53, 0x54, 0x5f, 0x43, + 0x4f, 0x44, 0x45, 0x5f, 0x54, 0x49, 0x4d, 0x45, 0x4f, 0x55, 0x54, 0x10, 0x03, 0x12, 0x15, 0x0a, + 0x11, 0x52, 0x53, 0x54, 0x5f, 0x43, 0x4f, 0x44, 0x45, 0x5f, 0x49, 0x4e, 0x54, 0x45, 0x52, 0x4e, + 0x41, 0x4c, 0x10, 0x04, 0x32, 0x97, 0x01, 0x0a, 0x14, 0x52, 0x65, 0x76, 0x65, 0x72, 0x73, 0x65, + 0x54, 0x75, 0x6e, 0x6e, 0x65, 0x6c, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x7f, 0x0a, + 0x04, 0x4c, 0x69, 0x6e, 0x6b, 0x12, 0x38, 0x2e, 0x63, 0x31, 0x2e, 0x63, 0x6f, 0x6e, 0x6e, 0x65, + 0x63, 0x74, 0x6f, 0x72, 0x61, 0x70, 0x69, 0x2e, 0x72, 0x74, 0x75, 0x6e, 0x2e, 0x76, 0x31, 0x2e, + 0x52, 0x65, 0x76, 0x65, 0x72, 0x73, 0x65, 0x54, 0x75, 0x6e, 0x6e, 0x65, 0x6c, 0x53, 0x65, 0x72, + 0x76, 0x69, 0x63, 0x65, 0x4c, 0x69, 0x6e, 0x6b, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, + 0x39, 0x2e, 0x63, 0x31, 0x2e, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x61, 0x70, + 0x69, 0x2e, 0x72, 0x74, 0x75, 0x6e, 0x2e, 0x76, 0x31, 0x2e, 0x52, 0x65, 0x76, 0x65, 0x72, 0x73, + 0x65, 0x54, 0x75, 0x6e, 0x6e, 0x65, 0x6c, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x4c, 0x69, + 0x6e, 0x6b, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x28, 0x01, 0x30, 0x01, 0x42, 0x3e, + 0x5a, 0x3c, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x63, 0x6f, 0x6e, + 0x64, 0x75, 0x63, 0x74, 0x6f, 0x72, 0x6f, 0x6e, 0x65, 0x2f, 0x62, 0x61, 0x74, 0x6f, 0x6e, 0x2d, + 0x73, 0x64, 0x6b, 0x2f, 0x70, 0x62, 0x2f, 0x63, 0x31, 0x2f, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, + 0x74, 0x6f, 0x72, 0x61, 0x70, 0x69, 0x2f, 0x72, 0x74, 0x75, 0x6e, 0x2f, 0x76, 0x31, 0x62, 0x06, + 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, +}) + +var ( + file_c1_connectorapi_rtun_v1_rtun_proto_rawDescOnce sync.Once + file_c1_connectorapi_rtun_v1_rtun_proto_rawDescData []byte +) + +func file_c1_connectorapi_rtun_v1_rtun_proto_rawDescGZIP() []byte { + file_c1_connectorapi_rtun_v1_rtun_proto_rawDescOnce.Do(func() { + file_c1_connectorapi_rtun_v1_rtun_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_c1_connectorapi_rtun_v1_rtun_proto_rawDesc), len(file_c1_connectorapi_rtun_v1_rtun_proto_rawDesc))) + }) + return file_c1_connectorapi_rtun_v1_rtun_proto_rawDescData +} + +var file_c1_connectorapi_rtun_v1_rtun_proto_enumTypes = make([]protoimpl.EnumInfo, 1) +var file_c1_connectorapi_rtun_v1_rtun_proto_msgTypes = make([]protoimpl.MessageInfo, 8) +var file_c1_connectorapi_rtun_v1_rtun_proto_goTypes = []any{ + (RstCode)(0), // 0: c1.connectorapi.rtun.v1.RstCode + (*ReverseTunnelServiceLinkRequest)(nil), // 1: c1.connectorapi.rtun.v1.ReverseTunnelServiceLinkRequest + (*ReverseTunnelServiceLinkResponse)(nil), // 2: c1.connectorapi.rtun.v1.ReverseTunnelServiceLinkResponse + (*Frame)(nil), // 3: c1.connectorapi.rtun.v1.Frame + (*Hello)(nil), // 4: c1.connectorapi.rtun.v1.Hello + (*Syn)(nil), // 5: c1.connectorapi.rtun.v1.Syn + (*Data)(nil), // 6: c1.connectorapi.rtun.v1.Data + (*Fin)(nil), // 7: c1.connectorapi.rtun.v1.Fin + (*Rst)(nil), // 8: c1.connectorapi.rtun.v1.Rst +} +var file_c1_connectorapi_rtun_v1_rtun_proto_depIdxs = []int32{ + 3, // 0: c1.connectorapi.rtun.v1.ReverseTunnelServiceLinkRequest.frame:type_name -> c1.connectorapi.rtun.v1.Frame + 3, // 1: c1.connectorapi.rtun.v1.ReverseTunnelServiceLinkResponse.frame:type_name -> c1.connectorapi.rtun.v1.Frame + 4, // 2: c1.connectorapi.rtun.v1.Frame.hello:type_name -> c1.connectorapi.rtun.v1.Hello + 5, // 3: c1.connectorapi.rtun.v1.Frame.syn:type_name -> c1.connectorapi.rtun.v1.Syn + 6, // 4: c1.connectorapi.rtun.v1.Frame.data:type_name -> c1.connectorapi.rtun.v1.Data + 7, // 5: c1.connectorapi.rtun.v1.Frame.fin:type_name -> c1.connectorapi.rtun.v1.Fin + 8, // 6: c1.connectorapi.rtun.v1.Frame.rst:type_name -> c1.connectorapi.rtun.v1.Rst + 0, // 7: c1.connectorapi.rtun.v1.Rst.code:type_name -> c1.connectorapi.rtun.v1.RstCode + 1, // 8: c1.connectorapi.rtun.v1.ReverseTunnelService.Link:input_type -> c1.connectorapi.rtun.v1.ReverseTunnelServiceLinkRequest + 2, // 9: c1.connectorapi.rtun.v1.ReverseTunnelService.Link:output_type -> c1.connectorapi.rtun.v1.ReverseTunnelServiceLinkResponse + 9, // [9:10] is the sub-list for method output_type + 8, // [8:9] is the sub-list for method input_type + 8, // [8:8] is the sub-list for extension type_name + 8, // [8:8] is the sub-list for extension extendee + 0, // [0:8] is the sub-list for field type_name +} + +func init() { file_c1_connectorapi_rtun_v1_rtun_proto_init() } +func file_c1_connectorapi_rtun_v1_rtun_proto_init() { + if File_c1_connectorapi_rtun_v1_rtun_proto != nil { + return + } + file_c1_connectorapi_rtun_v1_rtun_proto_msgTypes[2].OneofWrappers = []any{ + (*Frame_Hello)(nil), + (*Frame_Syn)(nil), + (*Frame_Data)(nil), + (*Frame_Fin)(nil), + (*Frame_Rst)(nil), + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_c1_connectorapi_rtun_v1_rtun_proto_rawDesc), len(file_c1_connectorapi_rtun_v1_rtun_proto_rawDesc)), + NumEnums: 1, + NumMessages: 8, + NumExtensions: 0, + NumServices: 1, + }, + GoTypes: file_c1_connectorapi_rtun_v1_rtun_proto_goTypes, + DependencyIndexes: file_c1_connectorapi_rtun_v1_rtun_proto_depIdxs, + EnumInfos: file_c1_connectorapi_rtun_v1_rtun_proto_enumTypes, + MessageInfos: file_c1_connectorapi_rtun_v1_rtun_proto_msgTypes, + }.Build() + File_c1_connectorapi_rtun_v1_rtun_proto = out.File + file_c1_connectorapi_rtun_v1_rtun_proto_goTypes = nil + file_c1_connectorapi_rtun_v1_rtun_proto_depIdxs = nil +} diff --git a/pb/c1/connectorapi/rtun/v1/rtun.pb.validate.go b/pb/c1/connectorapi/rtun/v1/rtun.pb.validate.go new file mode 100644 index 000000000..d3214846a --- /dev/null +++ b/pb/c1/connectorapi/rtun/v1/rtun.pb.validate.go @@ -0,0 +1,1112 @@ +// Code generated by protoc-gen-validate. DO NOT EDIT. +// source: c1/connectorapi/rtun/v1/rtun.proto + +package v1 + +import ( + "bytes" + "errors" + "fmt" + "net" + "net/mail" + "net/url" + "regexp" + "sort" + "strings" + "time" + "unicode/utf8" + + "google.golang.org/protobuf/types/known/anypb" +) + +// ensure the imports are used +var ( + _ = bytes.MinRead + _ = errors.New("") + _ = fmt.Print + _ = utf8.UTFMax + _ = (*regexp.Regexp)(nil) + _ = (*strings.Reader)(nil) + _ = net.IPv4len + _ = time.Duration(0) + _ = (*url.URL)(nil) + _ = (*mail.Address)(nil) + _ = anypb.Any{} + _ = sort.Sort +) + +// Validate checks the field values on ReverseTunnelServiceLinkRequest with the +// rules defined in the proto definition for this message. If any rules are +// violated, the first error encountered is returned, or nil if there are no violations. +func (m *ReverseTunnelServiceLinkRequest) Validate() error { + return m.validate(false) +} + +// ValidateAll checks the field values on ReverseTunnelServiceLinkRequest with +// the rules defined in the proto definition for this message. If any rules +// are violated, the result is a list of violation errors wrapped in +// ReverseTunnelServiceLinkRequestMultiError, or nil if none found. +func (m *ReverseTunnelServiceLinkRequest) ValidateAll() error { + return m.validate(true) +} + +func (m *ReverseTunnelServiceLinkRequest) validate(all bool) error { + if m == nil { + return nil + } + + var errors []error + + if all { + switch v := interface{}(m.GetFrame()).(type) { + case interface{ ValidateAll() error }: + if err := v.ValidateAll(); err != nil { + errors = append(errors, ReverseTunnelServiceLinkRequestValidationError{ + field: "Frame", + reason: "embedded message failed validation", + cause: err, + }) + } + case interface{ Validate() error }: + if err := v.Validate(); err != nil { + errors = append(errors, ReverseTunnelServiceLinkRequestValidationError{ + field: "Frame", + reason: "embedded message failed validation", + cause: err, + }) + } + } + } else if v, ok := interface{}(m.GetFrame()).(interface{ Validate() error }); ok { + if err := v.Validate(); err != nil { + return ReverseTunnelServiceLinkRequestValidationError{ + field: "Frame", + reason: "embedded message failed validation", + cause: err, + } + } + } + + if len(errors) > 0 { + return ReverseTunnelServiceLinkRequestMultiError(errors) + } + + return nil +} + +// ReverseTunnelServiceLinkRequestMultiError is an error wrapping multiple +// validation errors returned by ReverseTunnelServiceLinkRequest.ValidateAll() +// if the designated constraints aren't met. +type ReverseTunnelServiceLinkRequestMultiError []error + +// Error returns a concatenation of all the error messages it wraps. +func (m ReverseTunnelServiceLinkRequestMultiError) Error() string { + msgs := make([]string, 0, len(m)) + for _, err := range m { + msgs = append(msgs, err.Error()) + } + return strings.Join(msgs, "; ") +} + +// AllErrors returns a list of validation violation errors. +func (m ReverseTunnelServiceLinkRequestMultiError) AllErrors() []error { return m } + +// ReverseTunnelServiceLinkRequestValidationError is the validation error +// returned by ReverseTunnelServiceLinkRequest.Validate if the designated +// constraints aren't met. +type ReverseTunnelServiceLinkRequestValidationError struct { + field string + reason string + cause error + key bool +} + +// Field function returns field value. +func (e ReverseTunnelServiceLinkRequestValidationError) Field() string { return e.field } + +// Reason function returns reason value. +func (e ReverseTunnelServiceLinkRequestValidationError) Reason() string { return e.reason } + +// Cause function returns cause value. +func (e ReverseTunnelServiceLinkRequestValidationError) Cause() error { return e.cause } + +// Key function returns key value. +func (e ReverseTunnelServiceLinkRequestValidationError) Key() bool { return e.key } + +// ErrorName returns error name. +func (e ReverseTunnelServiceLinkRequestValidationError) ErrorName() string { + return "ReverseTunnelServiceLinkRequestValidationError" +} + +// Error satisfies the builtin error interface +func (e ReverseTunnelServiceLinkRequestValidationError) Error() string { + cause := "" + if e.cause != nil { + cause = fmt.Sprintf(" | caused by: %v", e.cause) + } + + key := "" + if e.key { + key = "key for " + } + + return fmt.Sprintf( + "invalid %sReverseTunnelServiceLinkRequest.%s: %s%s", + key, + e.field, + e.reason, + cause) +} + +var _ error = ReverseTunnelServiceLinkRequestValidationError{} + +var _ interface { + Field() string + Reason() string + Key() bool + Cause() error + ErrorName() string +} = ReverseTunnelServiceLinkRequestValidationError{} + +// Validate checks the field values on ReverseTunnelServiceLinkResponse with +// the rules defined in the proto definition for this message. If any rules +// are violated, the first error encountered is returned, or nil if there are +// no violations. +func (m *ReverseTunnelServiceLinkResponse) Validate() error { + return m.validate(false) +} + +// ValidateAll checks the field values on ReverseTunnelServiceLinkResponse with +// the rules defined in the proto definition for this message. If any rules +// are violated, the result is a list of violation errors wrapped in +// ReverseTunnelServiceLinkResponseMultiError, or nil if none found. +func (m *ReverseTunnelServiceLinkResponse) ValidateAll() error { + return m.validate(true) +} + +func (m *ReverseTunnelServiceLinkResponse) validate(all bool) error { + if m == nil { + return nil + } + + var errors []error + + if all { + switch v := interface{}(m.GetFrame()).(type) { + case interface{ ValidateAll() error }: + if err := v.ValidateAll(); err != nil { + errors = append(errors, ReverseTunnelServiceLinkResponseValidationError{ + field: "Frame", + reason: "embedded message failed validation", + cause: err, + }) + } + case interface{ Validate() error }: + if err := v.Validate(); err != nil { + errors = append(errors, ReverseTunnelServiceLinkResponseValidationError{ + field: "Frame", + reason: "embedded message failed validation", + cause: err, + }) + } + } + } else if v, ok := interface{}(m.GetFrame()).(interface{ Validate() error }); ok { + if err := v.Validate(); err != nil { + return ReverseTunnelServiceLinkResponseValidationError{ + field: "Frame", + reason: "embedded message failed validation", + cause: err, + } + } + } + + if len(errors) > 0 { + return ReverseTunnelServiceLinkResponseMultiError(errors) + } + + return nil +} + +// ReverseTunnelServiceLinkResponseMultiError is an error wrapping multiple +// validation errors returned by +// ReverseTunnelServiceLinkResponse.ValidateAll() if the designated +// constraints aren't met. +type ReverseTunnelServiceLinkResponseMultiError []error + +// Error returns a concatenation of all the error messages it wraps. +func (m ReverseTunnelServiceLinkResponseMultiError) Error() string { + msgs := make([]string, 0, len(m)) + for _, err := range m { + msgs = append(msgs, err.Error()) + } + return strings.Join(msgs, "; ") +} + +// AllErrors returns a list of validation violation errors. +func (m ReverseTunnelServiceLinkResponseMultiError) AllErrors() []error { return m } + +// ReverseTunnelServiceLinkResponseValidationError is the validation error +// returned by ReverseTunnelServiceLinkResponse.Validate if the designated +// constraints aren't met. +type ReverseTunnelServiceLinkResponseValidationError struct { + field string + reason string + cause error + key bool +} + +// Field function returns field value. +func (e ReverseTunnelServiceLinkResponseValidationError) Field() string { return e.field } + +// Reason function returns reason value. +func (e ReverseTunnelServiceLinkResponseValidationError) Reason() string { return e.reason } + +// Cause function returns cause value. +func (e ReverseTunnelServiceLinkResponseValidationError) Cause() error { return e.cause } + +// Key function returns key value. +func (e ReverseTunnelServiceLinkResponseValidationError) Key() bool { return e.key } + +// ErrorName returns error name. +func (e ReverseTunnelServiceLinkResponseValidationError) ErrorName() string { + return "ReverseTunnelServiceLinkResponseValidationError" +} + +// Error satisfies the builtin error interface +func (e ReverseTunnelServiceLinkResponseValidationError) Error() string { + cause := "" + if e.cause != nil { + cause = fmt.Sprintf(" | caused by: %v", e.cause) + } + + key := "" + if e.key { + key = "key for " + } + + return fmt.Sprintf( + "invalid %sReverseTunnelServiceLinkResponse.%s: %s%s", + key, + e.field, + e.reason, + cause) +} + +var _ error = ReverseTunnelServiceLinkResponseValidationError{} + +var _ interface { + Field() string + Reason() string + Key() bool + Cause() error + ErrorName() string +} = ReverseTunnelServiceLinkResponseValidationError{} + +// Validate checks the field values on Frame with the rules defined in the +// proto definition for this message. If any rules are violated, the first +// error encountered is returned, or nil if there are no violations. +func (m *Frame) Validate() error { + return m.validate(false) +} + +// ValidateAll checks the field values on Frame with the rules defined in the +// proto definition for this message. If any rules are violated, the result is +// a list of violation errors wrapped in FrameMultiError, or nil if none found. +func (m *Frame) ValidateAll() error { + return m.validate(true) +} + +func (m *Frame) validate(all bool) error { + if m == nil { + return nil + } + + var errors []error + + // no validation rules for Sid + + switch v := m.Kind.(type) { + case *Frame_Hello: + if v == nil { + err := FrameValidationError{ + field: "Kind", + reason: "oneof value cannot be a typed-nil", + } + if !all { + return err + } + errors = append(errors, err) + } + + if all { + switch v := interface{}(m.GetHello()).(type) { + case interface{ ValidateAll() error }: + if err := v.ValidateAll(); err != nil { + errors = append(errors, FrameValidationError{ + field: "Hello", + reason: "embedded message failed validation", + cause: err, + }) + } + case interface{ Validate() error }: + if err := v.Validate(); err != nil { + errors = append(errors, FrameValidationError{ + field: "Hello", + reason: "embedded message failed validation", + cause: err, + }) + } + } + } else if v, ok := interface{}(m.GetHello()).(interface{ Validate() error }); ok { + if err := v.Validate(); err != nil { + return FrameValidationError{ + field: "Hello", + reason: "embedded message failed validation", + cause: err, + } + } + } + + case *Frame_Syn: + if v == nil { + err := FrameValidationError{ + field: "Kind", + reason: "oneof value cannot be a typed-nil", + } + if !all { + return err + } + errors = append(errors, err) + } + + if all { + switch v := interface{}(m.GetSyn()).(type) { + case interface{ ValidateAll() error }: + if err := v.ValidateAll(); err != nil { + errors = append(errors, FrameValidationError{ + field: "Syn", + reason: "embedded message failed validation", + cause: err, + }) + } + case interface{ Validate() error }: + if err := v.Validate(); err != nil { + errors = append(errors, FrameValidationError{ + field: "Syn", + reason: "embedded message failed validation", + cause: err, + }) + } + } + } else if v, ok := interface{}(m.GetSyn()).(interface{ Validate() error }); ok { + if err := v.Validate(); err != nil { + return FrameValidationError{ + field: "Syn", + reason: "embedded message failed validation", + cause: err, + } + } + } + + case *Frame_Data: + if v == nil { + err := FrameValidationError{ + field: "Kind", + reason: "oneof value cannot be a typed-nil", + } + if !all { + return err + } + errors = append(errors, err) + } + + if all { + switch v := interface{}(m.GetData()).(type) { + case interface{ ValidateAll() error }: + if err := v.ValidateAll(); err != nil { + errors = append(errors, FrameValidationError{ + field: "Data", + reason: "embedded message failed validation", + cause: err, + }) + } + case interface{ Validate() error }: + if err := v.Validate(); err != nil { + errors = append(errors, FrameValidationError{ + field: "Data", + reason: "embedded message failed validation", + cause: err, + }) + } + } + } else if v, ok := interface{}(m.GetData()).(interface{ Validate() error }); ok { + if err := v.Validate(); err != nil { + return FrameValidationError{ + field: "Data", + reason: "embedded message failed validation", + cause: err, + } + } + } + + case *Frame_Fin: + if v == nil { + err := FrameValidationError{ + field: "Kind", + reason: "oneof value cannot be a typed-nil", + } + if !all { + return err + } + errors = append(errors, err) + } + + if all { + switch v := interface{}(m.GetFin()).(type) { + case interface{ ValidateAll() error }: + if err := v.ValidateAll(); err != nil { + errors = append(errors, FrameValidationError{ + field: "Fin", + reason: "embedded message failed validation", + cause: err, + }) + } + case interface{ Validate() error }: + if err := v.Validate(); err != nil { + errors = append(errors, FrameValidationError{ + field: "Fin", + reason: "embedded message failed validation", + cause: err, + }) + } + } + } else if v, ok := interface{}(m.GetFin()).(interface{ Validate() error }); ok { + if err := v.Validate(); err != nil { + return FrameValidationError{ + field: "Fin", + reason: "embedded message failed validation", + cause: err, + } + } + } + + case *Frame_Rst: + if v == nil { + err := FrameValidationError{ + field: "Kind", + reason: "oneof value cannot be a typed-nil", + } + if !all { + return err + } + errors = append(errors, err) + } + + if all { + switch v := interface{}(m.GetRst()).(type) { + case interface{ ValidateAll() error }: + if err := v.ValidateAll(); err != nil { + errors = append(errors, FrameValidationError{ + field: "Rst", + reason: "embedded message failed validation", + cause: err, + }) + } + case interface{ Validate() error }: + if err := v.Validate(); err != nil { + errors = append(errors, FrameValidationError{ + field: "Rst", + reason: "embedded message failed validation", + cause: err, + }) + } + } + } else if v, ok := interface{}(m.GetRst()).(interface{ Validate() error }); ok { + if err := v.Validate(); err != nil { + return FrameValidationError{ + field: "Rst", + reason: "embedded message failed validation", + cause: err, + } + } + } + + default: + _ = v // ensures v is used + } + + if len(errors) > 0 { + return FrameMultiError(errors) + } + + return nil +} + +// FrameMultiError is an error wrapping multiple validation errors returned by +// Frame.ValidateAll() if the designated constraints aren't met. +type FrameMultiError []error + +// Error returns a concatenation of all the error messages it wraps. +func (m FrameMultiError) Error() string { + msgs := make([]string, 0, len(m)) + for _, err := range m { + msgs = append(msgs, err.Error()) + } + return strings.Join(msgs, "; ") +} + +// AllErrors returns a list of validation violation errors. +func (m FrameMultiError) AllErrors() []error { return m } + +// FrameValidationError is the validation error returned by Frame.Validate if +// the designated constraints aren't met. +type FrameValidationError struct { + field string + reason string + cause error + key bool +} + +// Field function returns field value. +func (e FrameValidationError) Field() string { return e.field } + +// Reason function returns reason value. +func (e FrameValidationError) Reason() string { return e.reason } + +// Cause function returns cause value. +func (e FrameValidationError) Cause() error { return e.cause } + +// Key function returns key value. +func (e FrameValidationError) Key() bool { return e.key } + +// ErrorName returns error name. +func (e FrameValidationError) ErrorName() string { return "FrameValidationError" } + +// Error satisfies the builtin error interface +func (e FrameValidationError) Error() string { + cause := "" + if e.cause != nil { + cause = fmt.Sprintf(" | caused by: %v", e.cause) + } + + key := "" + if e.key { + key = "key for " + } + + return fmt.Sprintf( + "invalid %sFrame.%s: %s%s", + key, + e.field, + e.reason, + cause) +} + +var _ error = FrameValidationError{} + +var _ interface { + Field() string + Reason() string + Key() bool + Cause() error + ErrorName() string +} = FrameValidationError{} + +// Validate checks the field values on Hello with the rules defined in the +// proto definition for this message. If any rules are violated, the first +// error encountered is returned, or nil if there are no violations. +func (m *Hello) Validate() error { + return m.validate(false) +} + +// ValidateAll checks the field values on Hello with the rules defined in the +// proto definition for this message. If any rules are violated, the result is +// a list of violation errors wrapped in HelloMultiError, or nil if none found. +func (m *Hello) ValidateAll() error { + return m.validate(true) +} + +func (m *Hello) validate(all bool) error { + if m == nil { + return nil + } + + var errors []error + + // no validation rules for Protocol + + if len(errors) > 0 { + return HelloMultiError(errors) + } + + return nil +} + +// HelloMultiError is an error wrapping multiple validation errors returned by +// Hello.ValidateAll() if the designated constraints aren't met. +type HelloMultiError []error + +// Error returns a concatenation of all the error messages it wraps. +func (m HelloMultiError) Error() string { + msgs := make([]string, 0, len(m)) + for _, err := range m { + msgs = append(msgs, err.Error()) + } + return strings.Join(msgs, "; ") +} + +// AllErrors returns a list of validation violation errors. +func (m HelloMultiError) AllErrors() []error { return m } + +// HelloValidationError is the validation error returned by Hello.Validate if +// the designated constraints aren't met. +type HelloValidationError struct { + field string + reason string + cause error + key bool +} + +// Field function returns field value. +func (e HelloValidationError) Field() string { return e.field } + +// Reason function returns reason value. +func (e HelloValidationError) Reason() string { return e.reason } + +// Cause function returns cause value. +func (e HelloValidationError) Cause() error { return e.cause } + +// Key function returns key value. +func (e HelloValidationError) Key() bool { return e.key } + +// ErrorName returns error name. +func (e HelloValidationError) ErrorName() string { return "HelloValidationError" } + +// Error satisfies the builtin error interface +func (e HelloValidationError) Error() string { + cause := "" + if e.cause != nil { + cause = fmt.Sprintf(" | caused by: %v", e.cause) + } + + key := "" + if e.key { + key = "key for " + } + + return fmt.Sprintf( + "invalid %sHello.%s: %s%s", + key, + e.field, + e.reason, + cause) +} + +var _ error = HelloValidationError{} + +var _ interface { + Field() string + Reason() string + Key() bool + Cause() error + ErrorName() string +} = HelloValidationError{} + +// Validate checks the field values on Syn with the rules defined in the proto +// definition for this message. If any rules are violated, the first error +// encountered is returned, or nil if there are no violations. +func (m *Syn) Validate() error { + return m.validate(false) +} + +// ValidateAll checks the field values on Syn with the rules defined in the +// proto definition for this message. If any rules are violated, the result is +// a list of violation errors wrapped in SynMultiError, or nil if none found. +func (m *Syn) ValidateAll() error { + return m.validate(true) +} + +func (m *Syn) validate(all bool) error { + if m == nil { + return nil + } + + var errors []error + + // no validation rules for Port + + if len(errors) > 0 { + return SynMultiError(errors) + } + + return nil +} + +// SynMultiError is an error wrapping multiple validation errors returned by +// Syn.ValidateAll() if the designated constraints aren't met. +type SynMultiError []error + +// Error returns a concatenation of all the error messages it wraps. +func (m SynMultiError) Error() string { + msgs := make([]string, 0, len(m)) + for _, err := range m { + msgs = append(msgs, err.Error()) + } + return strings.Join(msgs, "; ") +} + +// AllErrors returns a list of validation violation errors. +func (m SynMultiError) AllErrors() []error { return m } + +// SynValidationError is the validation error returned by Syn.Validate if the +// designated constraints aren't met. +type SynValidationError struct { + field string + reason string + cause error + key bool +} + +// Field function returns field value. +func (e SynValidationError) Field() string { return e.field } + +// Reason function returns reason value. +func (e SynValidationError) Reason() string { return e.reason } + +// Cause function returns cause value. +func (e SynValidationError) Cause() error { return e.cause } + +// Key function returns key value. +func (e SynValidationError) Key() bool { return e.key } + +// ErrorName returns error name. +func (e SynValidationError) ErrorName() string { return "SynValidationError" } + +// Error satisfies the builtin error interface +func (e SynValidationError) Error() string { + cause := "" + if e.cause != nil { + cause = fmt.Sprintf(" | caused by: %v", e.cause) + } + + key := "" + if e.key { + key = "key for " + } + + return fmt.Sprintf( + "invalid %sSyn.%s: %s%s", + key, + e.field, + e.reason, + cause) +} + +var _ error = SynValidationError{} + +var _ interface { + Field() string + Reason() string + Key() bool + Cause() error + ErrorName() string +} = SynValidationError{} + +// Validate checks the field values on Data with the rules defined in the proto +// definition for this message. If any rules are violated, the first error +// encountered is returned, or nil if there are no violations. +func (m *Data) Validate() error { + return m.validate(false) +} + +// ValidateAll checks the field values on Data with the rules defined in the +// proto definition for this message. If any rules are violated, the result is +// a list of violation errors wrapped in DataMultiError, or nil if none found. +func (m *Data) ValidateAll() error { + return m.validate(true) +} + +func (m *Data) validate(all bool) error { + if m == nil { + return nil + } + + var errors []error + + // no validation rules for Payload + + if len(errors) > 0 { + return DataMultiError(errors) + } + + return nil +} + +// DataMultiError is an error wrapping multiple validation errors returned by +// Data.ValidateAll() if the designated constraints aren't met. +type DataMultiError []error + +// Error returns a concatenation of all the error messages it wraps. +func (m DataMultiError) Error() string { + msgs := make([]string, 0, len(m)) + for _, err := range m { + msgs = append(msgs, err.Error()) + } + return strings.Join(msgs, "; ") +} + +// AllErrors returns a list of validation violation errors. +func (m DataMultiError) AllErrors() []error { return m } + +// DataValidationError is the validation error returned by Data.Validate if the +// designated constraints aren't met. +type DataValidationError struct { + field string + reason string + cause error + key bool +} + +// Field function returns field value. +func (e DataValidationError) Field() string { return e.field } + +// Reason function returns reason value. +func (e DataValidationError) Reason() string { return e.reason } + +// Cause function returns cause value. +func (e DataValidationError) Cause() error { return e.cause } + +// Key function returns key value. +func (e DataValidationError) Key() bool { return e.key } + +// ErrorName returns error name. +func (e DataValidationError) ErrorName() string { return "DataValidationError" } + +// Error satisfies the builtin error interface +func (e DataValidationError) Error() string { + cause := "" + if e.cause != nil { + cause = fmt.Sprintf(" | caused by: %v", e.cause) + } + + key := "" + if e.key { + key = "key for " + } + + return fmt.Sprintf( + "invalid %sData.%s: %s%s", + key, + e.field, + e.reason, + cause) +} + +var _ error = DataValidationError{} + +var _ interface { + Field() string + Reason() string + Key() bool + Cause() error + ErrorName() string +} = DataValidationError{} + +// Validate checks the field values on Fin with the rules defined in the proto +// definition for this message. If any rules are violated, the first error +// encountered is returned, or nil if there are no violations. +func (m *Fin) Validate() error { + return m.validate(false) +} + +// ValidateAll checks the field values on Fin with the rules defined in the +// proto definition for this message. If any rules are violated, the result is +// a list of violation errors wrapped in FinMultiError, or nil if none found. +func (m *Fin) ValidateAll() error { + return m.validate(true) +} + +func (m *Fin) validate(all bool) error { + if m == nil { + return nil + } + + var errors []error + + // no validation rules for Ack + + if len(errors) > 0 { + return FinMultiError(errors) + } + + return nil +} + +// FinMultiError is an error wrapping multiple validation errors returned by +// Fin.ValidateAll() if the designated constraints aren't met. +type FinMultiError []error + +// Error returns a concatenation of all the error messages it wraps. +func (m FinMultiError) Error() string { + msgs := make([]string, 0, len(m)) + for _, err := range m { + msgs = append(msgs, err.Error()) + } + return strings.Join(msgs, "; ") +} + +// AllErrors returns a list of validation violation errors. +func (m FinMultiError) AllErrors() []error { return m } + +// FinValidationError is the validation error returned by Fin.Validate if the +// designated constraints aren't met. +type FinValidationError struct { + field string + reason string + cause error + key bool +} + +// Field function returns field value. +func (e FinValidationError) Field() string { return e.field } + +// Reason function returns reason value. +func (e FinValidationError) Reason() string { return e.reason } + +// Cause function returns cause value. +func (e FinValidationError) Cause() error { return e.cause } + +// Key function returns key value. +func (e FinValidationError) Key() bool { return e.key } + +// ErrorName returns error name. +func (e FinValidationError) ErrorName() string { return "FinValidationError" } + +// Error satisfies the builtin error interface +func (e FinValidationError) Error() string { + cause := "" + if e.cause != nil { + cause = fmt.Sprintf(" | caused by: %v", e.cause) + } + + key := "" + if e.key { + key = "key for " + } + + return fmt.Sprintf( + "invalid %sFin.%s: %s%s", + key, + e.field, + e.reason, + cause) +} + +var _ error = FinValidationError{} + +var _ interface { + Field() string + Reason() string + Key() bool + Cause() error + ErrorName() string +} = FinValidationError{} + +// Validate checks the field values on Rst with the rules defined in the proto +// definition for this message. If any rules are violated, the first error +// encountered is returned, or nil if there are no violations. +func (m *Rst) Validate() error { + return m.validate(false) +} + +// ValidateAll checks the field values on Rst with the rules defined in the +// proto definition for this message. If any rules are violated, the result is +// a list of violation errors wrapped in RstMultiError, or nil if none found. +func (m *Rst) ValidateAll() error { + return m.validate(true) +} + +func (m *Rst) validate(all bool) error { + if m == nil { + return nil + } + + var errors []error + + // no validation rules for Code + + if len(errors) > 0 { + return RstMultiError(errors) + } + + return nil +} + +// RstMultiError is an error wrapping multiple validation errors returned by +// Rst.ValidateAll() if the designated constraints aren't met. +type RstMultiError []error + +// Error returns a concatenation of all the error messages it wraps. +func (m RstMultiError) Error() string { + msgs := make([]string, 0, len(m)) + for _, err := range m { + msgs = append(msgs, err.Error()) + } + return strings.Join(msgs, "; ") +} + +// AllErrors returns a list of validation violation errors. +func (m RstMultiError) AllErrors() []error { return m } + +// RstValidationError is the validation error returned by Rst.Validate if the +// designated constraints aren't met. +type RstValidationError struct { + field string + reason string + cause error + key bool +} + +// Field function returns field value. +func (e RstValidationError) Field() string { return e.field } + +// Reason function returns reason value. +func (e RstValidationError) Reason() string { return e.reason } + +// Cause function returns cause value. +func (e RstValidationError) Cause() error { return e.cause } + +// Key function returns key value. +func (e RstValidationError) Key() bool { return e.key } + +// ErrorName returns error name. +func (e RstValidationError) ErrorName() string { return "RstValidationError" } + +// Error satisfies the builtin error interface +func (e RstValidationError) Error() string { + cause := "" + if e.cause != nil { + cause = fmt.Sprintf(" | caused by: %v", e.cause) + } + + key := "" + if e.key { + key = "key for " + } + + return fmt.Sprintf( + "invalid %sRst.%s: %s%s", + key, + e.field, + e.reason, + cause) +} + +var _ error = RstValidationError{} + +var _ interface { + Field() string + Reason() string + Key() bool + Cause() error + ErrorName() string +} = RstValidationError{} diff --git a/pb/c1/connectorapi/rtun/v1/rtun_grpc.pb.go b/pb/c1/connectorapi/rtun/v1/rtun_grpc.pb.go new file mode 100644 index 000000000..7e9ca2932 --- /dev/null +++ b/pb/c1/connectorapi/rtun/v1/rtun_grpc.pb.go @@ -0,0 +1,113 @@ +// Code generated by protoc-gen-go-grpc. DO NOT EDIT. +// versions: +// - protoc-gen-go-grpc v1.5.1 +// - protoc (unknown) +// source: c1/connectorapi/rtun/v1/rtun.proto + +package v1 + +import ( + context "context" + grpc "google.golang.org/grpc" + codes "google.golang.org/grpc/codes" + status "google.golang.org/grpc/status" +) + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +// Requires gRPC-Go v1.64.0 or later. +const _ = grpc.SupportPackageIsVersion9 + +const ( + ReverseTunnelService_Link_FullMethodName = "/c1.connectorapi.rtun.v1.ReverseTunnelService/Link" +) + +// ReverseTunnelServiceClient is the client API for ReverseTunnelService service. +// +// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. +type ReverseTunnelServiceClient interface { + Link(ctx context.Context, opts ...grpc.CallOption) (grpc.BidiStreamingClient[ReverseTunnelServiceLinkRequest, ReverseTunnelServiceLinkResponse], error) +} + +type reverseTunnelServiceClient struct { + cc grpc.ClientConnInterface +} + +func NewReverseTunnelServiceClient(cc grpc.ClientConnInterface) ReverseTunnelServiceClient { + return &reverseTunnelServiceClient{cc} +} + +func (c *reverseTunnelServiceClient) Link(ctx context.Context, opts ...grpc.CallOption) (grpc.BidiStreamingClient[ReverseTunnelServiceLinkRequest, ReverseTunnelServiceLinkResponse], error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + stream, err := c.cc.NewStream(ctx, &ReverseTunnelService_ServiceDesc.Streams[0], ReverseTunnelService_Link_FullMethodName, cOpts...) + if err != nil { + return nil, err + } + x := &grpc.GenericClientStream[ReverseTunnelServiceLinkRequest, ReverseTunnelServiceLinkResponse]{ClientStream: stream} + return x, nil +} + +// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. +type ReverseTunnelService_LinkClient = grpc.BidiStreamingClient[ReverseTunnelServiceLinkRequest, ReverseTunnelServiceLinkResponse] + +// ReverseTunnelServiceServer is the server API for ReverseTunnelService service. +// All implementations should embed UnimplementedReverseTunnelServiceServer +// for forward compatibility. +type ReverseTunnelServiceServer interface { + Link(grpc.BidiStreamingServer[ReverseTunnelServiceLinkRequest, ReverseTunnelServiceLinkResponse]) error +} + +// UnimplementedReverseTunnelServiceServer should be embedded to have +// forward compatible implementations. +// +// NOTE: this should be embedded by value instead of pointer to avoid a nil +// pointer dereference when methods are called. +type UnimplementedReverseTunnelServiceServer struct{} + +func (UnimplementedReverseTunnelServiceServer) Link(grpc.BidiStreamingServer[ReverseTunnelServiceLinkRequest, ReverseTunnelServiceLinkResponse]) error { + return status.Errorf(codes.Unimplemented, "method Link not implemented") +} +func (UnimplementedReverseTunnelServiceServer) testEmbeddedByValue() {} + +// UnsafeReverseTunnelServiceServer may be embedded to opt out of forward compatibility for this service. +// Use of this interface is not recommended, as added methods to ReverseTunnelServiceServer will +// result in compilation errors. +type UnsafeReverseTunnelServiceServer interface { + mustEmbedUnimplementedReverseTunnelServiceServer() +} + +func RegisterReverseTunnelServiceServer(s grpc.ServiceRegistrar, srv ReverseTunnelServiceServer) { + // If the following call pancis, it indicates UnimplementedReverseTunnelServiceServer was + // embedded by pointer and is nil. This will cause panics if an + // unimplemented method is ever invoked, so we test this at initialization + // time to prevent it from happening at runtime later due to I/O. + if t, ok := srv.(interface{ testEmbeddedByValue() }); ok { + t.testEmbeddedByValue() + } + s.RegisterService(&ReverseTunnelService_ServiceDesc, srv) +} + +func _ReverseTunnelService_Link_Handler(srv interface{}, stream grpc.ServerStream) error { + return srv.(ReverseTunnelServiceServer).Link(&grpc.GenericServerStream[ReverseTunnelServiceLinkRequest, ReverseTunnelServiceLinkResponse]{ServerStream: stream}) +} + +// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. +type ReverseTunnelService_LinkServer = grpc.BidiStreamingServer[ReverseTunnelServiceLinkRequest, ReverseTunnelServiceLinkResponse] + +// ReverseTunnelService_ServiceDesc is the grpc.ServiceDesc for ReverseTunnelService service. +// It's only intended for direct use with grpc.RegisterService, +// and not to be introspected or modified (even as a copy) +var ReverseTunnelService_ServiceDesc = grpc.ServiceDesc{ + ServiceName: "c1.connectorapi.rtun.v1.ReverseTunnelService", + HandlerType: (*ReverseTunnelServiceServer)(nil), + Methods: []grpc.MethodDesc{}, + Streams: []grpc.StreamDesc{ + { + StreamName: "Link", + Handler: _ReverseTunnelService_Link_Handler, + ServerStreams: true, + ClientStreams: true, + }, + }, + Metadata: "c1/connectorapi/rtun/v1/rtun.proto", +} diff --git a/pkg/lambda/grpc/config/sts.go b/pkg/lambda/grpc/config/sts.go index d52644b76..084b2f755 100644 --- a/pkg/lambda/grpc/config/sts.go +++ b/pkg/lambda/grpc/config/sts.go @@ -5,6 +5,7 @@ import ( "context" "fmt" "net/http" + "net/url" "time" "github.com/aws/aws-sdk-go-v2/aws" @@ -21,7 +22,12 @@ func createSigv4STSGetCallerIdentityRequest(ctx context.Context, cfg *aws.Config region := cfg.Region body := "Action=GetCallerIdentity&Version=2011-06-15" service := "sts" - endpoint := fmt.Sprintf("https://sts.%s.amazonaws.com", region) + + endpoint := (&url.URL{ + Scheme: "https", + Host: fmt.Sprintf("sts.%s.amazonaws.com", region), + }).String() + method := "POST" reqHeaders := map[string][]string{ diff --git a/pkg/metrics/metrics.go b/pkg/metrics/metrics.go index 7af8bcd86..f22f3b561 100644 --- a/pkg/metrics/metrics.go +++ b/pkg/metrics/metrics.go @@ -8,6 +8,10 @@ type Handler interface { Int64Counter(name string, description string, unit Unit) Int64Counter Int64Gauge(name string, description string, unit Unit) Int64Gauge Int64Histogram(name string, description string, unit Unit) Int64Histogram + // RegisterInt64ObservableGauge registers an asynchronous gauge that will be observed + // during collection by invoking the provided callback. The callback should return + // the current value and optional tags at observation time. + RegisterInt64ObservableGauge(name string, description string, unit Unit, callback func(ctx context.Context) (int64, map[string]string)) WithTags(tags map[string]string) Handler } diff --git a/pkg/metrics/noop.go b/pkg/metrics/noop.go index eec13f8ae..7fde9d2af 100644 --- a/pkg/metrics/noop.go +++ b/pkg/metrics/noop.go @@ -28,6 +28,9 @@ func (*noopHandler) Int64Histogram(_ string, _ string, _ Unit) Int64Histogram { return &noopRecorder{} } +func (*noopHandler) RegisterInt64ObservableGauge(_ string, _ string, _ Unit, _ func(ctx context.Context) (int64, map[string]string)) { +} + func (*noopHandler) WithTags(_ map[string]string) Handler { return &noopHandler{} } diff --git a/pkg/metrics/otel.go b/pkg/metrics/otel.go index 9a43d79e7..3b51d31fb 100644 --- a/pkg/metrics/otel.go +++ b/pkg/metrics/otel.go @@ -22,6 +22,11 @@ type otelHandler struct { provider otelmetric.MeterProvider defaultAttrs *[]attribute.KeyValue + // protects access to defaultAttrs and registeredGauges + defaultAttrsMtx sync.RWMutex + registeredGaugesMtx sync.Mutex + registeredGauges map[string]struct{} + int64CountersMtx sync.Mutex int64Counters map[string]*otelInt64Counter int64HistosMtx sync.Mutex @@ -154,10 +159,49 @@ func (h *otelHandler) Int64Gauge(name string, description string, unit Unit) Int return c } +func (h *otelHandler) RegisterInt64ObservableGauge(name string, description string, unit Unit, callback func(ctx context.Context) (int64, map[string]string)) { + name = strings.ToLower(name) + + // prevent duplicate registrations for the same gauge name + h.registeredGaugesMtx.Lock() + if h.registeredGauges == nil { + h.registeredGauges = make(map[string]struct{}) + } + if _, exists := h.registeredGauges[name]; exists { + h.registeredGaugesMtx.Unlock() + return + } + h.registeredGauges[name] = struct{}{} + h.registeredGaugesMtx.Unlock() + + gauge, err := h.meter.Int64ObservableGauge(name, otelmetric.WithDescription(description), otelmetric.WithUnit(string(unit))) + if err != nil { + panic(err) + } + // capture default attrs pointer for this handler under read lock + _, err = h.meter.RegisterCallback(func(ctx context.Context, observer otelmetric.Observer) error { + val, tags := callback(ctx) + attrs := makeAttrs(tags) + h.defaultAttrsMtx.RLock() + defaultAttrs := h.defaultAttrs + h.defaultAttrsMtx.RUnlock() + if defaultAttrs != nil { + attrs = append(attrs, *defaultAttrs...) + } + observer.ObserveInt64(gauge, val, otelmetric.WithAttributes(attrs...)) + return nil + }, gauge) + if err != nil { + panic(err) + } +} + func (h *otelHandler) WithTags(tags map[string]string) Handler { attrs := makeAttrs(tags) + h.defaultAttrsMtx.Lock() h.defaultAttrs = &attrs + h.defaultAttrsMtx.Unlock() return h } @@ -173,11 +217,12 @@ func makeAttrs(tags map[string]string) []attribute.KeyValue { func NewOtelHandler(_ context.Context, provider otelmetric.MeterProvider, name string) Handler { return &otelHandler{ - name: name, - meter: provider.Meter(name), - provider: provider, - int64Counters: make(map[string]*otelInt64Counter), - int64Histos: make(map[string]*otelInt64Histogram), - int64Gauges: make(map[string]*otelInt64Gauge), + name: name, + meter: provider.Meter(name), + provider: provider, + int64Counters: make(map[string]*otelInt64Counter), + int64Histos: make(map[string]*otelInt64Histogram), + int64Gauges: make(map[string]*otelInt64Gauge), + registeredGauges: make(map[string]struct{}), } } diff --git a/pkg/rtun/gateway/client.go b/pkg/rtun/gateway/client.go new file mode 100644 index 000000000..06176c07b --- /dev/null +++ b/pkg/rtun/gateway/client.go @@ -0,0 +1,556 @@ +package gateway + +import ( + "context" + "errors" + "fmt" + "io" + "net" + "sync" + "time" + + rtunpb "github.com/conductorone/baton-sdk/pb/c1/connectorapi/rtun/v1" + "github.com/grpc-ecosystem/go-grpc-middleware/logging/zap/ctxzap" + "go.uber.org/zap" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" +) + +const ( + defaultReadBufferCap = 16 + defaultWriteQueueCap = 16 + maxChunkSize = 32 * 1024 +) + +// Dialer opens reverse connections to clients via a gateway server. +type Dialer struct { + gatewayAddr string + creds credentials.TransportCredentials + // configuration + readBufferCap int // number of frames to buffer for reads + writeQueueCap int // number of chunks to buffer for writes +} + +// DialerOption configures the Dialer. +type DialerOption func(*Dialer) + +// WithReadBufferCapacity sets the number of inbound frames to buffer before backpressure blocks producer. +func WithReadBufferCapacity(capacity int) DialerOption { + return func(d *Dialer) { + if capacity > 0 { + d.readBufferCap = capacity + } + } +} + +// WithWriteQueueCapacity sets the number of outbound chunks to queue before backpressure blocks writers. +func WithWriteQueueCapacity(capacity int) DialerOption { + return func(d *Dialer) { + if capacity > 0 { + d.writeQueueCap = capacity + } + } +} + +// NewDialerWithOptions creates a gateway client with extra configuration. +func NewDialer(gatewayAddr string, creds credentials.TransportCredentials, opts ...DialerOption) *Dialer { + d := &Dialer{ + gatewayAddr: gatewayAddr, + creds: creds, + readBufferCap: defaultReadBufferCap, + writeQueueCap: defaultWriteQueueCap, + } + + for _, opt := range opts { + if opt != nil { + opt(d) + } + } + return d +} + +// DialContext opens a reverse connection to clientID:port via the gateway. +// Returns ErrNotFound if the gateway doesn't own the client (caller should re-resolve owner). +func (d *Dialer) DialContext(ctx context.Context, clientID string, port uint32) (net.Conn, error) { + logger := ctxzap.Extract(ctx).With(zap.String("client_id", clientID), zap.Uint32("port", port)) + + // Dial gateway + cc, err := grpc.NewClient("passthrough:///"+d.gatewayAddr, + grpc.WithTransportCredentials(d.creds), + ) + if err != nil { + return nil, fmt.Errorf("gateway dial failed: %w", err) + } + + client := rtunpb.NewReverseDialerServiceClient(cc) + // Create a cancellable stream context so Close() can interrupt Recv/Send. + streamCtx, cancel := context.WithCancel(ctx) + stream, err := client.Open(streamCtx) + if err != nil { + cancel() + cc.Close() + return nil, fmt.Errorf("gateway open stream failed: %w", err) + } + + // Send OpenRequest with gSID=1 (simple case: one connection per stream) + gsid := uint32(1) + if err := stream.Send(&rtunpb.ReverseDialerServiceOpenRequest{Kind: &rtunpb.ReverseDialerServiceOpenRequest_OpenReq{ + OpenReq: &rtunpb.OpenRequest{Gsid: gsid, ClientId: clientID, Port: port}, + }}); err != nil { + _ = stream.CloseSend() + cancel() + cc.Close() + return nil, fmt.Errorf("gateway send OpenRequest failed: %w", err) + } + + // Recv OpenResponse + resp, err := stream.Recv() + if err != nil { + _ = stream.CloseSend() + cancel() + cc.Close() + return nil, fmt.Errorf("gateway recv OpenResponse failed: %w", err) + } + + openResp := resp.GetOpenResp() + if openResp == nil { + _ = stream.CloseSend() + cancel() + cc.Close() + return nil, ErrProtocol + } + if openResp.GetGsid() != gsid { + _ = stream.CloseSend() + cancel() + cc.Close() + return nil, fmt.Errorf("gateway returned mismatched gSID: got %d, want %d", openResp.GetGsid(), gsid) + } + + switch openResp.Result.(type) { + case *rtunpb.OpenResponse_NotFound: + _ = stream.CloseSend() + cancel() + cc.Close() + logger.Info("client not found on gateway") + return nil, ErrNotFound + case *rtunpb.OpenResponse_Opened: + logger.Info("gateway connection opened") + doneCh := make(chan struct{}) + gc := &gatewayConn{ + stream: stream, + cc: cc, + gsid: gsid, + cancel: cancel, + doneCh: doneCh, + } + gc.r = newReader(stream, gsid, d.readBufferCap, doneCh) + gc.w = newWriter(stream, gsid, d.writeQueueCap, doneCh) + return gc, nil + default: + _ = stream.CloseSend() + cancel() + cc.Close() + return nil, ErrProtocol + } +} + +var _ net.Conn = (*gatewayConn)(nil) + +// gatewayConn implements net.Conn over a gateway stream. +type gatewayConn struct { + stream rtunpb.ReverseDialerService_OpenClient + cc *grpc.ClientConn + gsid uint32 + cancel context.CancelFunc + + // reader/writer components + r *reader + w *writer + + readMu sync.Mutex + writeMu sync.Mutex + writeClosed bool + + closeOnce sync.Once + doneCh chan struct{} + + rdDeadline time.Time + wrDeadline time.Time +} + +type writeMsg struct { + payload []byte + fin bool +} + +type reader struct { + stream rtunpb.ReverseDialerService_OpenClient + gsid uint32 + bufCap int + ch chan []byte + doneCh <-chan struct{} + + mu sync.Mutex + rem []byte + err error + once sync.Once +} + +func newReader(stream rtunpb.ReverseDialerService_OpenClient, gsid uint32, bufCap int, doneCh <-chan struct{}) *reader { + if bufCap <= 0 { + bufCap = defaultReadBufferCap + } + return &reader{ + stream: stream, + gsid: gsid, + bufCap: bufCap, + doneCh: doneCh, + } +} + +func (r *reader) start() { + r.once.Do(func() { + r.ch = make(chan []byte, r.bufCap) + go r.loop() + }) +} + +func (r *reader) loop() { + defer close(r.ch) + for { + resp, err := r.stream.Recv() + if err != nil { + r.mu.Lock() + if r.err == nil { + r.err = err + } + r.mu.Unlock() + return + } + fr := resp.GetFrame() + if fr == nil || fr.GetSid() != r.gsid { + continue + } + switch k := fr.Kind.(type) { + case *rtunpb.Frame_Data: + payload := append([]byte(nil), k.Data.GetPayload()...) + select { + case r.ch <- payload: + case <-r.doneCh: + return + case <-r.stream.Context().Done(): + return + } + case *rtunpb.Frame_Fin: + r.mu.Lock() + r.err = io.EOF + r.mu.Unlock() + return + case *rtunpb.Frame_Rst: + r.mu.Lock() + r.err = fmt.Errorf("gateway: connection reset (code %v)", k.Rst.GetCode()) + r.mu.Unlock() + return + } + } +} + +func (r *reader) next(ctx context.Context) ([]byte, error) { + r.start() + select { + case buf, ok := <-r.ch: + if !ok { + r.mu.Lock() + err := r.err + r.mu.Unlock() + if err == nil { + return nil, io.EOF + } + return nil, err + } + return buf, nil + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +type writer struct { + stream rtunpb.ReverseDialerService_OpenClient + gsid uint32 + queueCap int + ch chan writeMsg + doneCh <-chan struct{} + + mu sync.Mutex + err error + once sync.Once +} + +func newWriter(stream rtunpb.ReverseDialerService_OpenClient, gsid uint32, queueCap int, doneCh <-chan struct{}) *writer { + if queueCap <= 0 { + queueCap = defaultWriteQueueCap + } + return &writer{ + stream: stream, + gsid: gsid, + queueCap: queueCap, + doneCh: doneCh, + } +} + +func (w *writer) start() { + w.once.Do(func() { + w.ch = make(chan writeMsg, w.queueCap) + go w.loop() + }) +} + +func (w *writer) setErr(err error) { + w.mu.Lock() + if w.err == nil { + w.err = err + } + w.mu.Unlock() +} + +func (w *writer) getErr() error { + w.mu.Lock() + defer w.mu.Unlock() + return w.err +} + +func (w *writer) loop() { + for { + select { + case msg := <-w.ch: + if msg.fin { + finFrame := &rtunpb.Frame{ + Sid: w.gsid, + Kind: &rtunpb.Frame_Fin{Fin: &rtunpb.Fin{}}, + } + _ = w.stream.Send( + &rtunpb.ReverseDialerServiceOpenRequest{ + Kind: &rtunpb.ReverseDialerServiceOpenRequest_Frame{ + Frame: finFrame, + }, + }, + ) + continue + } + + dataFrame := &rtunpb.Frame{ + Sid: w.gsid, + Kind: &rtunpb.Frame_Data{ + Data: &rtunpb.Data{Payload: msg.payload}, + }, + } + if err := w.stream.Send( + &rtunpb.ReverseDialerServiceOpenRequest{ + Kind: &rtunpb.ReverseDialerServiceOpenRequest_Frame{ + Frame: dataFrame, + }, + }, + ); err != nil { + w.setErr(err) + return + } + case <-w.doneCh: + return + case <-w.stream.Context().Done(): + return + } + } +} + +func (w *writer) enqueue(msg writeMsg, deadline time.Time) error { + w.start() + // fast-fail if previous send error + if err := w.getErr(); err != nil { + return err + } + + if !deadline.IsZero() { + until := time.Until(deadline) + if until <= 0 { + return context.DeadlineExceeded + } + timer := time.NewTimer(until) + defer timer.Stop() + select { + case w.ch <- msg: + return nil + case <-timer.C: + return context.DeadlineExceeded + case <-w.doneCh: + return fmt.Errorf("rtun/gateway: write after close: %w", net.ErrClosed) + case <-w.stream.Context().Done(): + return w.stream.Context().Err() + } + } + select { + case w.ch <- msg: + return nil + case <-w.doneCh: + return fmt.Errorf("rtun/gateway: write after close: %w", net.ErrClosed) + case <-w.stream.Context().Done(): + return w.stream.Context().Err() + } +} + +func (g *gatewayConn) Read(p []byte) (int, error) { + // Consume remainder first + g.r.mu.Lock() + if len(g.r.rem) > 0 { + n := copy(p, g.r.rem) + g.r.rem = g.r.rem[n:] + g.r.mu.Unlock() + return n, nil + } + g.r.mu.Unlock() + + // Compute deadline context (snapshot under lock) + g.readMu.Lock() + deadline := g.rdDeadline + g.readMu.Unlock() + + var ctx context.Context + var cancel context.CancelFunc + if deadline.IsZero() { + ctx = context.Background() + cancel = func() {} + } else { + until := time.Until(deadline) + if until <= 0 { + return 0, context.DeadlineExceeded + } + ctx, cancel = context.WithTimeout(context.Background(), until) + } + defer cancel() + + buf, err := g.r.next(ctx) + if err != nil { + if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) || errors.Is(err, io.EOF) { + return 0, err + } + if g.isClosed() { + return 0, fmt.Errorf("rtun/gateway: read on closed connection: %w", net.ErrClosed) + } + return 0, err + } + n := copy(p, buf) + if n < len(buf) { + g.r.mu.Lock() + g.r.rem = buf[n:] + g.r.mu.Unlock() + } + return n, nil +} + +func (g *gatewayConn) isClosed() bool { + select { + case <-g.doneCh: + return true + default: + return false + } +} + +// recvLoop moved into reader.loop + +func (g *gatewayConn) Write(p []byte) (int, error) { + g.writeMu.Lock() + if g.writeClosed { + g.writeMu.Unlock() + return 0, fmt.Errorf("rtun/gateway: write on closed connection: %w", net.ErrClosed) + } + // snapshot write deadline while holding the write lock + deadline := g.wrDeadline + // keep writeMu locked for the duration to serialize writes and SetWriteDeadline + defer g.writeMu.Unlock() + + if err := g.w.getErr(); err != nil { + return 0, err + } + + total := 0 + for len(p) > 0 { + chunk := p + if len(chunk) > maxChunkSize { + chunk = p[:maxChunkSize] + } + cp := append([]byte(nil), chunk...) + if err := g.w.enqueue(writeMsg{payload: cp}, deadline); err != nil { + if total == 0 { + return 0, err + } + return total, err + } + total += len(chunk) + p = p[len(chunk):] + } + return total, nil +} + +// writer loop moved into writer.loop + +func (g *gatewayConn) Close() error { + g.closeOnce.Do(func() { + g.writeMu.Lock() + if !g.writeClosed { + g.writeClosed = true + // best-effort FIN without blocking; if writer not started yet, send directly + if g.w != nil && g.w.ch != nil { + select { + case g.w.ch <- writeMsg{fin: true}: + default: + } + } else { + _ = g.stream.Send(&rtunpb.ReverseDialerServiceOpenRequest{Kind: &rtunpb.ReverseDialerServiceOpenRequest_Frame{ + Frame: &rtunpb.Frame{Sid: g.gsid, Kind: &rtunpb.Frame_Fin{Fin: &rtunpb.Fin{}}}, + }}) + } + } + g.writeMu.Unlock() + + // Cancel stream context to unblock Recv/Send + if g.cancel != nil { + g.cancel() + } + close(g.doneCh) + _ = g.stream.CloseSend() + _ = g.cc.Close() + }) + return nil +} + +func (g *gatewayConn) LocalAddr() net.Addr { return gatewayAddr{"gateway-local"} } +func (g *gatewayConn) RemoteAddr() net.Addr { return gatewayAddr{"gateway-remote"} } + +func (g *gatewayConn) SetDeadline(t time.Time) error { + g.readMu.Lock() + g.rdDeadline = t + g.readMu.Unlock() + g.writeMu.Lock() + g.wrDeadline = t + g.writeMu.Unlock() + return nil +} + +func (g *gatewayConn) SetReadDeadline(t time.Time) error { + g.readMu.Lock() + g.rdDeadline = t + g.readMu.Unlock() + return nil +} + +func (g *gatewayConn) SetWriteDeadline(t time.Time) error { + g.writeMu.Lock() + g.wrDeadline = t + g.writeMu.Unlock() + return nil +} + +type gatewayAddr struct{ s string } + +func (a gatewayAddr) Network() string { return "gateway" } +func (a gatewayAddr) String() string { return a.s } diff --git a/pkg/rtun/gateway/client_conn_test.go b/pkg/rtun/gateway/client_conn_test.go new file mode 100644 index 000000000..d8eb619d1 --- /dev/null +++ b/pkg/rtun/gateway/client_conn_test.go @@ -0,0 +1,171 @@ +package gateway + +import ( + "context" + "errors" + "io" + "net" + "testing" + "time" + + rtunpb "github.com/conductorone/baton-sdk/pb/c1/connectorapi/rtun/v1" + "github.com/conductorone/baton-sdk/pkg/rtun/server" + "github.com/conductorone/baton-sdk/pkg/rtun/transport" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + health "google.golang.org/grpc/health" + healthpb "google.golang.org/grpc/health/grpc_health_v1" +) + +type gwEnv struct { + addr string + cleanup func() + accepted chan net.Conn +} + +// setupGateway spins up a real ReverseTunnel + ReverseDialer server and returns control. +// If silent is true, the client listener will accept a raw conn and never write to it unless tests do. +func setupGateway(t *testing.T, silent bool) *gwEnv { + t.Helper() + reg := server.NewRegistry() + handler := server.NewHandler(reg, "server-a", testValidator{id: "client-xyz"}) + gw := NewServer(reg, "server-a", nil) + + gsrv := grpc.NewServer() + rtunpb.RegisterReverseTunnelServiceServer(gsrv, handler) + rtunpb.RegisterReverseDialerServiceServer(gsrv, gw) + l, err := (&net.ListenConfig{}).Listen(context.Background(), "tcp", "127.0.0.1:0") + require.NoError(t, err) + go func() { _ = gsrv.Serve(l) }() + + // Bring up a client link and listen on port 1 + cc, err := grpc.NewClient("passthrough:///"+l.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err) + rtunClient := rtunpb.NewReverseTunnelServiceClient(cc) + stream, err := rtunClient.Link(context.Background()) + require.NoError(t, err) + cl := &clientLink{cli: stream} + sess := transport.NewSession(cl) + require.NoError(t, cl.Send(&rtunpb.Frame{Sid: 0, Kind: &rtunpb.Frame_Hello{Hello: &rtunpb.Hello{Ports: []uint32{1}}}})) + ln, err := sess.Listen(context.Background(), 1) + require.NoError(t, err) + + accepted := make(chan net.Conn, 1) + var cgs *grpc.Server + if silent { + // Accept one conn and expose it to tests + go func() { c, _ := ln.Accept(); accepted <- c }() + } else { + // Serve a health server + cgs = grpc.NewServer() + hs := health.NewServer() + healthpb.RegisterHealthServer(cgs, hs) + hs.SetServingStatus("", healthpb.HealthCheckResponse_SERVING) + go func() { _ = cgs.Serve(ln) }() + } + + cleanup := func() { + if cgs != nil { + cgs.GracefulStop() + } + _ = ln.Close() + _ = cc.Close() + gsrv.GracefulStop() + _ = l.Close() + } + return &gwEnv{addr: l.Addr().String(), cleanup: cleanup, accepted: accepted} +} + +func TestGatewayConnReadDeadline(t *testing.T) { + env := setupGateway(t, true) + defer env.cleanup() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + d := NewDialer(env.addr, insecure.NewCredentials()) + gwc, err := d.DialContext(ctx, "client-xyz", 1) + require.NoError(t, err) + defer gwc.Close() + + _ = gwc.SetReadDeadline(time.Now().Add(10 * time.Millisecond)) + _, err = gwc.Read(make([]byte, 1)) + require.ErrorIs(t, err, context.DeadlineExceeded) +} + +func TestGatewayConnCloseIdempotentAndWriteAfterClose(t *testing.T) { + env := setupGateway(t, true) + defer env.cleanup() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + d := NewDialer(env.addr, insecure.NewCredentials()) + gwc, err := d.DialContext(ctx, "client-xyz", 1) + require.NoError(t, err) + + require.NoError(t, gwc.Close()) + require.NoError(t, gwc.Close()) + _, err = gwc.Write([]byte("x")) + require.True(t, errors.Is(err, net.ErrClosed)) +} + +func TestGatewayConnEOFOnRemoteClose(t *testing.T) { + env := setupGateway(t, true) + defer env.cleanup() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + d := NewDialer(env.addr, insecure.NewCredentials()) + gwc, err := d.DialContext(ctx, "client-xyz", 1) + require.NoError(t, err) + defer gwc.Close() + + // Remote close: wait for accept and then close the accepted conn to generate FIN. + select { + case rc := <-env.accepted: + _ = rc.Close() + case <-time.After(200 * time.Millisecond): + t.Fatal("no accepted conn") + } + + // Read should return EOF + _, err = gwc.Read(make([]byte, 1)) + require.ErrorIs(t, err, io.EOF) +} + +func TestGatewayConnWriteAndRemoteReceive(t *testing.T) { + env := setupGateway(t, true) + defer env.cleanup() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + d := NewDialer(env.addr, insecure.NewCredentials()) + gwc, err := d.DialContext(ctx, "client-xyz", 1) + require.NoError(t, err) + defer gwc.Close() + + // Wait for remote accept + var rc net.Conn + select { + case rc = <-env.accepted: + case <-time.After(200 * time.Millisecond): + t.Fatal("no accepted conn") + } + + // Write and verify bytes arrive remotely + msg := []byte("hello") + n, err := gwc.Write(msg) + require.NoError(t, err) + require.Equal(t, len(msg), n) + buf := make([]byte, len(msg)) + rn, err := io.ReadFull(rc, buf) + require.NoError(t, err) + require.Equal(t, len(msg), rn) + require.Equal(t, msg, buf) + + // Local FIN: close gwc; remote should see EOF + _ = gwc.Close() + tmp := make([]byte, 1) + _, err = rc.Read(tmp) + require.ErrorIs(t, err, io.EOF) +} diff --git a/pkg/rtun/gateway/errors.go b/pkg/rtun/gateway/errors.go new file mode 100644 index 000000000..3deaf8d35 --- /dev/null +++ b/pkg/rtun/gateway/errors.go @@ -0,0 +1,9 @@ +package gateway + +import "errors" + +var ( + ErrNotFound = errors.New("rtun/gateway: client not found on this server") + ErrInvalidGSID = errors.New("rtun/gateway: invalid or duplicate gSID") + ErrProtocol = errors.New("rtun/gateway: protocol error") +) diff --git a/pkg/rtun/gateway/integration_test.go b/pkg/rtun/gateway/integration_test.go new file mode 100644 index 000000000..86e2dc8c4 --- /dev/null +++ b/pkg/rtun/gateway/integration_test.go @@ -0,0 +1,145 @@ +package gateway + +import ( + "context" + "net" + "testing" + "time" + + rtunpb "github.com/conductorone/baton-sdk/pb/c1/connectorapi/rtun/v1" + "github.com/conductorone/baton-sdk/pkg/rtun/server" + "github.com/conductorone/baton-sdk/pkg/rtun/transport" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + health "google.golang.org/grpc/health" + healthpb "google.golang.org/grpc/health/grpc_health_v1" +) + +// testValidator for integration tests. +type testValidator struct{ id string } + +func (t testValidator) ValidateAuth(ctx context.Context) (string, error) { return t.id, nil } +func (t testValidator) ValidateHello(ctx context.Context, hello *rtunpb.Hello) error { return nil } + +// clientLink adapts client bidi stream to transport.Link. +type clientLink struct { + cli rtunpb.ReverseTunnelService_LinkClient +} + +func (c *clientLink) Send(f *rtunpb.Frame) error { + return c.cli.Send(&rtunpb.ReverseTunnelServiceLinkRequest{Frame: f}) +} + +func (c *clientLink) Recv() (*rtunpb.Frame, error) { + fr, err := c.cli.Recv() + if err != nil { + return nil, err + } + return fr.GetFrame(), nil +} + +func (c *clientLink) Context() context.Context { return c.cli.Context() } + +// TestGatewayE2E validates the full gateway stack: +// - Client connects to server A (handler+registry). +// - Gateway server on A. +// - Remote caller uses gateway.Dialer to get net.Conn to client. +// - Caller performs gRPC health check over gateway conn. +func TestGatewayE2E(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + // Server A: rtun handler + registry + gateway + regA := server.NewRegistry() + handlerA := server.NewHandler(regA, "server-a", testValidator{id: "client-123"}) + gwA := NewServer(regA, "server-a", nil) + + gsrvA := grpc.NewServer() + rtunpb.RegisterReverseTunnelServiceServer(gsrvA, handlerA) + rtunpb.RegisterReverseDialerServiceServer(gsrvA, gwA) + + lA, err := (&net.ListenConfig{}).Listen(ctx, "tcp", "127.0.0.1:0") + require.NoError(t, err) + go func() { _ = gsrvA.Serve(lA) }() + + // Client connects to server A + clientCtx, clientCancel := context.WithCancel(ctx) + defer clientCancel() + + ccA, err := grpc.NewClient("passthrough:///"+lA.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err) + + rtunClient := rtunpb.NewReverseTunnelServiceClient(ccA) + stream, err := rtunClient.Link(clientCtx) + require.NoError(t, err) + + // Client session: send HELLO, listen on port 1 + cl := &clientLink{cli: stream} + sess := transport.NewSession(cl) + require.NoError(t, cl.Send(&rtunpb.Frame{Sid: 0, Kind: &rtunpb.Frame_Hello{Hello: &rtunpb.Hello{Ports: []uint32{1}}}})) + ln, err := sess.Listen(ctx, 1) + require.NoError(t, err) + + // Client runs health service + cgs := grpc.NewServer() + hs := health.NewServer() + healthpb.RegisterHealthServer(cgs, hs) + hs.SetServingStatus("", healthpb.HealthCheckResponse_SERVING) + go func() { _ = cgs.Serve(ln) }() + + // Wait for registration + time.Sleep(50 * time.Millisecond) + + // Remote caller: use gateway.Dialer to get net.Conn to client + dialer := NewDialer(lA.Addr().String(), insecure.NewCredentials(), nil) + gwConn, err := dialer.DialContext(ctx, "client-123", 1) + require.NoError(t, err) + defer gwConn.Close() + + // Wrap gateway conn in grpc.Dial and perform health check + callerCC, err := grpc.NewClient("passthrough:///ignored", + grpc.WithContextDialer(func(context.Context, string) (net.Conn, error) { return gwConn, nil }), + grpc.WithTransportCredentials(insecure.NewCredentials()), + ) + require.NoError(t, err) + defer callerCC.Close() + + hc := healthpb.NewHealthClient(callerCC) + resp, err := hc.Check(ctx, &healthpb.HealthCheckRequest{Service: ""}) + require.NoError(t, err) + require.Equal(t, healthpb.HealthCheckResponse_SERVING, resp.GetStatus()) + + // Cleanup + callerCC.Close() + gwConn.Close() + cgs.GracefulStop() + ln.Close() + clientCancel() + ccA.Close() + gsrvA.GracefulStop() + lA.Close() +} + +func TestGatewayNotFound(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + // Server with gateway but no client registered + regA := server.NewRegistry() + gwA := NewServer(regA, "server-a", nil) + + gsrvA := grpc.NewServer() + rtunpb.RegisterReverseDialerServiceServer(gsrvA, gwA) + + lA, err := (&net.ListenConfig{}).Listen(ctx, "tcp", "127.0.0.1:0") + require.NoError(t, err) + defer lA.Close() + go func() { _ = gsrvA.Serve(lA) }() + defer gsrvA.GracefulStop() + + // Caller tries to dial non-existent client + dialer := NewDialer(lA.Addr().String(), insecure.NewCredentials(), nil) + _, err = dialer.DialContext(ctx, "client-nonexistent", 1) + require.ErrorIs(t, err, ErrNotFound) +} diff --git a/pkg/rtun/gateway/metrics.go b/pkg/rtun/gateway/metrics.go new file mode 100644 index 000000000..112623bf9 --- /dev/null +++ b/pkg/rtun/gateway/metrics.go @@ -0,0 +1,44 @@ +package gateway + +import ( + "context" + + sdkmetrics "github.com/conductorone/baton-sdk/pkg/metrics" +) + +type gwMetrics struct { + h sdkmetrics.Handler + + openReq sdkmetrics.Int64Counter + openOK sdkmetrics.Int64Counter + openMiss sdkmetrics.Int64Counter + framesRx sdkmetrics.Int64Counter + framesTx sdkmetrics.Int64Counter + writerDrop sdkmetrics.Int64Counter + writeErr sdkmetrics.Int64Counter +} + +func newGwMetrics(h sdkmetrics.Handler) *gwMetrics { + return &gwMetrics{ + h: h, + openReq: h.Int64Counter("rtun.gateway.open_requests_total", "gateway open requests", sdkmetrics.Dimensionless), + openOK: h.Int64Counter("rtun.gateway.open_success_total", "gateway opens succeeded", sdkmetrics.Dimensionless), + openMiss: h.Int64Counter("rtun.gateway.open_not_found_total", "gateway opens not found", sdkmetrics.Dimensionless), + framesRx: h.Int64Counter("rtun.gateway.frame_rx_total", "gateway frames received", sdkmetrics.Dimensionless), + framesTx: h.Int64Counter("rtun.gateway.frame_tx_total", "gateway frames sent", sdkmetrics.Dimensionless), + writerDrop: h.Int64Counter("rtun.gateway.writer_queue_drops_total", "gateway writer queue drops", sdkmetrics.Dimensionless), + writeErr: h.Int64Counter("rtun.gateway.writer_write_errors_total", "gateway write errors", sdkmetrics.Dimensionless), + } +} + +func (m *gwMetrics) addOpenReq(ctx context.Context) { m.openReq.Add(ctx, 1, nil) } +func (m *gwMetrics) addOpenOK(ctx context.Context) { m.openOK.Add(ctx, 1, nil) } +func (m *gwMetrics) addOpenMiss(ctx context.Context) { m.openMiss.Add(ctx, 1, nil) } +func (m *gwMetrics) addFrameRx(ctx context.Context, k string) { + m.framesRx.Add(ctx, 1, map[string]string{"kind": k}) +} +func (m *gwMetrics) addFrameTx(ctx context.Context, k string) { + m.framesTx.Add(ctx, 1, map[string]string{"kind": k}) +} +func (m *gwMetrics) addWriterDrop(ctx context.Context) { m.writerDrop.Add(ctx, 1, nil) } +func (m *gwMetrics) addWriteErr(ctx context.Context) { m.writeErr.Add(ctx, 1, nil) } diff --git a/pkg/rtun/gateway/options.go b/pkg/rtun/gateway/options.go new file mode 100644 index 000000000..1caf04456 --- /dev/null +++ b/pkg/rtun/gateway/options.go @@ -0,0 +1,16 @@ +package gateway + +import ( + sdkmetrics "github.com/conductorone/baton-sdk/pkg/metrics" +) + +type Option func(*options) + +type options struct { + metrics sdkmetrics.Handler +} + +// WithMetricsHandler injects a metrics handler for the gateway service. +func WithMetricsHandler(h sdkmetrics.Handler) Option { + return func(o *options) { o.metrics = h } +} diff --git a/pkg/rtun/gateway/server.go b/pkg/rtun/gateway/server.go new file mode 100644 index 000000000..afc9ecea6 --- /dev/null +++ b/pkg/rtun/gateway/server.go @@ -0,0 +1,303 @@ +package gateway + +import ( + "context" + "errors" + "io" + "net" + "net/url" + "strconv" + "sync" + "time" + + rtunpb "github.com/conductorone/baton-sdk/pb/c1/connectorapi/rtun/v1" + "github.com/conductorone/baton-sdk/pkg/rtun/server" + "github.com/grpc-ecosystem/go-grpc-middleware/logging/zap/ctxzap" + "go.uber.org/zap" +) + +// Server implements the ReverseDialer gateway service. +// It bridges caller streams to rtun sessions on the owner server process. +type Server struct { + rtunpb.UnimplementedReverseDialerServiceServer + + reg *server.Registry + serverID string + m *gwMetrics +} + +// NewServer creates a gateway server that bridges callers to local rtun sessions. +func NewServer(reg *server.Registry, serverID string, opts ...Option) *Server { + var o options + for _, opt := range opts { + if opt == nil { + continue + } + opt(&o) + } + s := &Server{ + reg: reg, + serverID: serverID, + } + if o.metrics != nil { + s.m = newGwMetrics(o.metrics) + } + return s +} + +// Open handles a gateway stream: caller sends OpenRequest(s) and Frames; gateway bridges to rtun. +const ( + writerQueueCap = 256 + writeDeadline = 30 * time.Second +) + +type entry struct { + conn net.Conn + writeCh chan []byte + done chan struct{} + closeOnce sync.Once + m *gwMetrics + ctx context.Context +} + +func newEntry(conn net.Conn, m *gwMetrics, ctx context.Context) *entry { + e := &entry{ + conn: conn, + writeCh: make(chan []byte, writerQueueCap), + done: make(chan struct{}), + m: m, + ctx: ctx, + } + go e.writerLoop() + return e +} + +func (e *entry) writerLoop() { + for { + select { + case <-e.done: + return + case buf, ok := <-e.writeCh: + if !ok { + return + } + _ = e.conn.SetWriteDeadline(time.Now().Add(writeDeadline)) + if _, err := e.conn.Write(buf); err != nil { + if e.m != nil { + e.m.addWriteErr(e.ctx) + } + e.close() + return + } + } + } +} + +func (e *entry) send(b []byte) bool { + cp := append([]byte(nil), b...) + select { + case <-e.done: + return false + case e.writeCh <- cp: + return true + default: + e.close() + return false + } +} + +func (e *entry) close() { + e.closeOnce.Do(func() { + _ = e.conn.Close() + close(e.done) + close(e.writeCh) + }) +} + +func (s *Server) Open(stream rtunpb.ReverseDialerService_OpenServer) error { + ctx := stream.Context() + logger := ctxzap.Extract(ctx).With(zap.String("server_id", s.serverID)) + + var mu sync.Mutex + entries := make(map[uint32]*entry) + var wg sync.WaitGroup + + defer func() { + // Cleanup: close all connections and wait for readers + mu.Lock() + for gsid, ent := range entries { + ent.close() + delete(entries, gsid) + } + mu.Unlock() + wg.Wait() + }() + + for { + req, err := stream.Recv() + if err != nil { + if errors.Is(err, io.EOF) { + return nil + } + return err + } + + switch k := req.Kind.(type) { + case *rtunpb.ReverseDialerServiceOpenRequest_OpenReq: + openReq := k.OpenReq + gsid := openReq.GetGsid() + clientID := openReq.GetClientId() + port := openReq.GetPort() + + logger := logger.With(zap.Uint32("gsid", gsid), zap.String("client_id", clientID), zap.Uint32("port", port)) + if s.m != nil { + s.m.addOpenReq(ctx) + s.m.addFrameRx(ctx, "OPEN_REQ") + } + + // Check for duplicate gSID + mu.Lock() + if _, exists := entries[gsid]; exists { + mu.Unlock() + logger.Warn("duplicate gSID in OpenRequest") + _ = stream.Send(&rtunpb.ReverseDialerServiceOpenResponse{Kind: &rtunpb.ReverseDialerServiceOpenResponse_Frame{ + Frame: &rtunpb.Frame{Sid: gsid, Kind: &rtunpb.Frame_Rst{Rst: &rtunpb.Rst{Code: rtunpb.RstCode_RST_CODE_INTERNAL}}}, + }}) + continue + } + mu.Unlock() + + // Open reverse connection via local registry + addr := formatRtunAddr(clientID, port) + conn, err := s.reg.DialContext(ctx, addr) + if err != nil { + logger.Info("client not found or dial failed", zap.Error(err)) + _ = stream.Send(&rtunpb.ReverseDialerServiceOpenResponse{Kind: &rtunpb.ReverseDialerServiceOpenResponse_OpenResp{ + OpenResp: &rtunpb.OpenResponse{Gsid: gsid, Result: &rtunpb.OpenResponse_NotFound{NotFound: &rtunpb.NotFound{}}}, + }}) + if s.m != nil { + s.m.addOpenMiss(ctx) + } + continue + } + + // Store conn and reply success + mu.Lock() + ent := newEntry(conn, s.m, ctx) + entries[gsid] = ent + mu.Unlock() + + logger.Info("opened reverse connection") + if err := stream.Send(&rtunpb.ReverseDialerServiceOpenResponse{Kind: &rtunpb.ReverseDialerServiceOpenResponse_OpenResp{ + OpenResp: &rtunpb.OpenResponse{Gsid: gsid, Result: &rtunpb.OpenResponse_Opened{Opened: &rtunpb.Opened{}}}, + }}); err != nil { + conn.Close() + return err + } + if s.m != nil { + s.m.addOpenOK(ctx) + } + + // Spawn reader: conn → caller + wg.Add(1) + go s.bridgeRead(ctx, conn, gsid, stream, &wg, logger) + + case *rtunpb.ReverseDialerServiceOpenRequest_Frame: + fr := k.Frame + gsid := fr.GetSid() + + mu.Lock() + ent := entries[gsid] + mu.Unlock() + + if ent == nil { + // Unknown gSID; send RST + _ = stream.Send(&rtunpb.ReverseDialerServiceOpenResponse{Kind: &rtunpb.ReverseDialerServiceOpenResponse_Frame{ + Frame: &rtunpb.Frame{Sid: gsid, Kind: &rtunpb.Frame_Rst{Rst: &rtunpb.Rst{Code: rtunpb.RstCode_RST_CODE_INTERNAL}}}, + }}) + continue + } + + // Handle frame + switch fk := fr.Kind.(type) { + case *rtunpb.Frame_Data: + if s.m != nil { + s.m.addFrameRx(ctx, "DATA") + } + if ok := ent.send(fk.Data.GetPayload()); !ok { + if s.m != nil { + s.m.addWriterDrop(ctx) + } + mu.Lock() + delete(entries, gsid) + mu.Unlock() + } + case *rtunpb.Frame_Fin: + if s.m != nil { + s.m.addFrameRx(ctx, "FIN") + } + mu.Lock() + ent.close() + delete(entries, gsid) + mu.Unlock() + case *rtunpb.Frame_Rst: + if s.m != nil { + s.m.addFrameRx(ctx, "RST") + } + mu.Lock() + ent.close() + delete(entries, gsid) + mu.Unlock() + } + } + } +} + +// bridgeRead reads from rtun conn and sends frames to the caller stream. +func (s *Server) bridgeRead(ctx context.Context, conn net.Conn, gsid uint32, stream rtunpb.ReverseDialerService_OpenServer, wg *sync.WaitGroup, logger *zap.Logger) { + defer wg.Done() + buf := make([]byte, 32*1024) + for { + n, err := conn.Read(buf) + if n > 0 { + if err := stream.Send(&rtunpb.ReverseDialerServiceOpenResponse{Kind: &rtunpb.ReverseDialerServiceOpenResponse_Frame{ + Frame: &rtunpb.Frame{Sid: gsid, Kind: &rtunpb.Frame_Data{Data: &rtunpb.Data{Payload: append([]byte(nil), buf[:n]...)}}}, + }}); err != nil { + logger.Warn("failed to send data to caller", zap.Error(err)) + return + } + if s.m != nil { + s.m.addFrameTx(ctx, "DATA") + } + } + if err != nil { + if errors.Is(err, io.EOF) { + _ = stream.Send(&rtunpb.ReverseDialerServiceOpenResponse{Kind: &rtunpb.ReverseDialerServiceOpenResponse_Frame{ + Frame: &rtunpb.Frame{Sid: gsid, Kind: &rtunpb.Frame_Fin{Fin: &rtunpb.Fin{}}}, + }}) + if s.m != nil { + s.m.addFrameTx(ctx, "FIN") + } + } else { + _ = stream.Send(&rtunpb.ReverseDialerServiceOpenResponse{Kind: &rtunpb.ReverseDialerServiceOpenResponse_Frame{ + Frame: &rtunpb.Frame{Sid: gsid, Kind: &rtunpb.Frame_Rst{Rst: &rtunpb.Rst{Code: rtunpb.RstCode_RST_CODE_INTERNAL}}}, + }}) + if s.m != nil { + s.m.addFrameTx(ctx, "RST") + } + } + return + } + } +} + +func formatRtunAddr(clientID string, port uint32) string { + u := url.URL{ + Scheme: "rtun", + Host: net.JoinHostPort( + clientID, + strconv.Itoa(int(port)), + ), + } + return u.String() +} diff --git a/pkg/rtun/match/directory.go b/pkg/rtun/match/directory.go new file mode 100644 index 000000000..67fcb8d95 --- /dev/null +++ b/pkg/rtun/match/directory.go @@ -0,0 +1,20 @@ +package match + +import ( + "context" + "errors" + "time" +) + +// ErrNotImplemented is returned by placeholder implementations. +var ErrNotImplemented = errors.New("rtun/match: not implemented") + +// Directory maps server IDs to dialable addresses, with TTL support. +type Directory interface { + // Advertise publishes a server's address with a TTL. + Advertise(ctx context.Context, serverID string, addr string, ttl time.Duration) error + // Revoke removes a server's advertisement. + Revoke(ctx context.Context, serverID string) error + // Resolve returns the address for a server ID if present and not expired. + Resolve(ctx context.Context, serverID string) (string, error) +} diff --git a/pkg/rtun/match/errors.go b/pkg/rtun/match/errors.go new file mode 100644 index 000000000..6f9c4ae74 --- /dev/null +++ b/pkg/rtun/match/errors.go @@ -0,0 +1,8 @@ +package match + +import "errors" + +var ( + // ErrClientOffline indicates no servers currently advertise the client. + ErrClientOffline = errors.New("rtun/match: client offline") +) diff --git a/pkg/rtun/match/locator.go b/pkg/rtun/match/locator.go new file mode 100644 index 000000000..be652b2ea --- /dev/null +++ b/pkg/rtun/match/locator.go @@ -0,0 +1,51 @@ +package match + +import ( + "context" + "hash/fnv" +) + +// Locator selects the owning server for a client based on presence and rendezvous hashing. +type Locator struct { + Presence Presence +} + +// OwnerOf returns the server that currently owns the client's link along with the client's ports. +// It uses Presence to list available servers and rendezvous hashing to choose among them. +func (l *Locator) OwnerOf(ctx context.Context, clientID string) (string, []uint32, error) { + if l == nil || l.Presence == nil { + return "", nil, ErrNotImplemented + } + servers, err := l.Presence.Locations(ctx, clientID) + if err != nil { + return "", nil, err + } + if len(servers) == 0 { + return "", nil, ErrClientOffline + } + owner := rendezvousChoose(clientID, servers) + p, err := l.Presence.Ports(ctx, clientID) + if err != nil { + return "", nil, err + } + return owner, p, nil +} + +func rendezvousChoose(clientID string, servers []string) string { + var best string + var bestVal uint64 + var have bool + for _, s := range servers { + h := fnv.New64a() + _, _ = h.Write([]byte(clientID)) + _, _ = h.Write([]byte("|")) + _, _ = h.Write([]byte(s)) + v := h.Sum64() + if !have || v > bestVal { + bestVal = v + best = s + have = true + } + } + return best +} diff --git a/pkg/rtun/match/locator_test.go b/pkg/rtun/match/locator_test.go new file mode 100644 index 000000000..4032be2f6 --- /dev/null +++ b/pkg/rtun/match/locator_test.go @@ -0,0 +1,84 @@ +package match + +import ( + "context" + "testing" + "time" + + "github.com/conductorone/baton-sdk/pkg/rtun/match/memory" + "github.com/stretchr/testify/require" +) + +func TestLocatorOwnerOfSingleServer(t *testing.T) { + p := memory.NewPresence() + ctx := context.Background() + + err := p.SetPorts(ctx, "client-1", []uint32{1}) + require.NoError(t, err) + err = p.Announce(ctx, "client-1", "server-a", 10*time.Second) + require.NoError(t, err) + + loc := &Locator{Presence: p} + owner, ports, err := loc.OwnerOf(ctx, "client-1") + require.NoError(t, err) + require.Equal(t, "server-a", owner) + require.Equal(t, []uint32{1}, ports) +} + +func TestLocatorOwnerOfMultipleServersDeterministic(t *testing.T) { + p := memory.NewPresence() + ctx := context.Background() + + // Register two servers + _ = p.SetPorts(ctx, "client-1", []uint32{1}) + _ = p.Announce(ctx, "client-1", "server-a", 10*time.Second) + _ = p.Announce(ctx, "client-1", "server-b", 10*time.Second) + + loc := &Locator{Presence: p} + owner1, _, err := loc.OwnerOf(ctx, "client-1") + require.NoError(t, err) + + // Call again; should be same owner (deterministic rendezvous hashing) + owner2, _, err := loc.OwnerOf(ctx, "client-1") + require.NoError(t, err) + require.Equal(t, owner1, owner2) + + // owner should be one of the two servers + require.Contains(t, []string{"server-a", "server-b"}, owner1) +} + +func TestLocatorOwnerOfClientOffline(t *testing.T) { + p := memory.NewPresence() + ctx := context.Background() + + loc := &Locator{Presence: p} + _, _, err := loc.OwnerOf(ctx, "client-nonexistent") + require.ErrorIs(t, err, ErrClientOffline) +} + +func TestLocatorOwnerOfWithCustomChooser(t *testing.T) { + // Removed: locator no longer accepts a custom chooser; keep surface minimal. +} + +func TestLocatorOwnerOfTTLExpiry(t *testing.T) { + p := memory.NewPresence() + ctx := context.Background() + + // Set with very short TTL + err := p.SetPorts(ctx, "client-1", []uint32{1}) + require.NoError(t, err) + err = p.Announce(ctx, "client-1", "server-a", 10*time.Millisecond) + require.NoError(t, err) + + loc := &Locator{Presence: p} + owner, _, err := loc.OwnerOf(ctx, "client-1") + require.NoError(t, err) + require.Equal(t, "server-a", owner) + + // Wait for expiry + time.Sleep(20 * time.Millisecond) + + // Now should return ErrClientOffline + _, _, err = loc.OwnerOf(ctx, "client-1") + require.ErrorIs(t, err, ErrClientOffline) +} diff --git a/pkg/rtun/match/memory/directory.go b/pkg/rtun/match/memory/directory.go new file mode 100644 index 000000000..bf4c680db --- /dev/null +++ b/pkg/rtun/match/memory/directory.go @@ -0,0 +1,58 @@ +package memory + +import ( + "context" + "errors" + "sync" + "time" +) + +// ErrServerNotFound indicates a server ID has no active advertisement. +var ErrServerNotFound = errors.New("rtun/match: server not found") + +// Directory is an in-memory Directory for tests and single-node deployments. +type Directory struct { + mu sync.RWMutex + servers map[string]record // serverID -> record +} + +type record struct { + addr string + expires time.Time +} + +// NewDirectory returns an in-memory Directory. +func NewDirectory() *Directory { + return &Directory{ + servers: make(map[string]record), + } +} + +func (d *Directory) Advertise(ctx context.Context, serverID string, addr string, ttl time.Duration) error { + d.mu.Lock() + defer d.mu.Unlock() + d.servers[serverID] = record{addr: addr, expires: time.Now().Add(ttl)} + return nil +} + +func (d *Directory) Revoke(ctx context.Context, serverID string) error { + d.mu.Lock() + defer d.mu.Unlock() + delete(d.servers, serverID) + return nil +} + +func (d *Directory) Resolve(ctx context.Context, serverID string) (string, error) { + now := time.Now() + d.mu.Lock() + defer d.mu.Unlock() + rec, ok := d.servers[serverID] + if !ok { + return "", ErrServerNotFound + } + if !rec.expires.IsZero() && now.After(rec.expires) { + delete(d.servers, serverID) + return "", ErrServerNotFound + } + return rec.addr, nil +} diff --git a/pkg/rtun/match/memory/directory_test.go b/pkg/rtun/match/memory/directory_test.go new file mode 100644 index 000000000..70a694417 --- /dev/null +++ b/pkg/rtun/match/memory/directory_test.go @@ -0,0 +1,54 @@ +package memory + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestDirectoryResolve(t *testing.T) { + d := NewDirectory() + ctx := context.Background() + + // Initially not found + _, err := d.Resolve(ctx, "server-a") + require.ErrorIs(t, err, ErrServerNotFound) + + // Register server + require.NoError(t, d.Advertise(ctx, "server-a", "127.0.0.1:5000", 10*time.Second)) + + addr, err := d.Resolve(ctx, "server-a") + require.NoError(t, err) + require.Equal(t, "127.0.0.1:5000", addr) +} + +func TestDirectoryUnregister(t *testing.T) { + d := NewDirectory() + ctx := context.Background() + + require.NoError(t, d.Advertise(ctx, "server-a", "127.0.0.1:5000", 10*time.Second)) + addr, err := d.Resolve(ctx, "server-a") + require.NoError(t, err) + require.Equal(t, "127.0.0.1:5000", addr) + + // Unregister + require.NoError(t, d.Revoke(ctx, "server-a")) + _, err = d.Resolve(ctx, "server-a") + require.ErrorIs(t, err, ErrServerNotFound) +} + +func TestDirectoryMultipleServersIsolation(t *testing.T) { + d := NewDirectory() + ctx := context.Background() + + require.NoError(t, d.Advertise(ctx, "server-a", "a.local:1", 10*time.Second)) + require.NoError(t, d.Advertise(ctx, "server-b", "b.local:2", 10*time.Second)) + + addrA, _ := d.Resolve(ctx, "server-a") + addrB, _ := d.Resolve(ctx, "server-b") + + require.Equal(t, "a.local:1", addrA) + require.Equal(t, "b.local:2", addrB) +} diff --git a/pkg/rtun/match/memory/presence.go b/pkg/rtun/match/memory/presence.go new file mode 100644 index 000000000..899d2dc91 --- /dev/null +++ b/pkg/rtun/match/memory/presence.go @@ -0,0 +1,94 @@ +package memory + +import ( + "context" + "sync" + "time" +) + +// Presence is an in-memory implementation of match.Presence for tests and +// single-node deployments. It stores per-(client, server) leases and global +// client ports. +type Presence struct { + mu sync.RWMutex + leases map[string]map[string]time.Time // clientID -> serverID -> expiry + ports map[string][]uint32 // clientID -> ports +} + +// NewPresence returns an in-memory Presence tracker. +func NewPresence() *Presence { + return &Presence{ + leases: make(map[string]map[string]time.Time), + ports: make(map[string][]uint32), + } +} + +func (p *Presence) Announce(ctx context.Context, clientID string, serverID string, ttl time.Duration) error { + p.mu.Lock() + defer p.mu.Unlock() + if p.leases[clientID] == nil { + p.leases[clientID] = make(map[string]time.Time) + } + p.leases[clientID][serverID] = time.Now().Add(ttl) + return nil +} + +func (p *Presence) Revoke(ctx context.Context, clientID string, serverID string) error { + p.mu.Lock() + defer p.mu.Unlock() + if inner := p.leases[clientID]; inner != nil { + delete(inner, serverID) + if len(inner) == 0 { + delete(p.leases, clientID) + delete(p.ports, clientID) + } + } + return nil +} + +func (p *Presence) Locations(ctx context.Context, clientID string) ([]string, error) { + now := time.Now() + p.mu.Lock() + defer p.mu.Unlock() + inner := p.leases[clientID] + if inner == nil { + return nil, nil + } + for s, exp := range inner { + if !exp.IsZero() && now.After(exp) { + delete(inner, s) + } + } + if len(inner) == 0 { + delete(p.leases, clientID) + delete(p.ports, clientID) + return nil, nil + } + servers := make([]string, 0, len(inner)) + for s := range inner { + servers = append(servers, s) + } + return servers, nil +} + +func (p *Presence) SetPorts(ctx context.Context, clientID string, ports []uint32) error { + p.mu.Lock() + defer p.mu.Unlock() + if ports == nil { + delete(p.ports, clientID) + return nil + } + cp := append([]uint32(nil), ports...) + p.ports[clientID] = cp + return nil +} + +func (p *Presence) Ports(ctx context.Context, clientID string) ([]uint32, error) { + p.mu.RLock() + defer p.mu.RUnlock() + ports := p.ports[clientID] + if ports == nil { + return nil, nil + } + return append([]uint32(nil), ports...), nil +} diff --git a/pkg/rtun/match/memory/presence_test.go b/pkg/rtun/match/memory/presence_test.go new file mode 100644 index 000000000..50b50928c --- /dev/null +++ b/pkg/rtun/match/memory/presence_test.go @@ -0,0 +1,139 @@ +package memory + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestPresenceSetOnlineGetLocations(t *testing.T) { + p := NewPresence() + ctx := context.Background() + + // Initially no locations + locs, err := p.Locations(ctx, "client-1") + require.NoError(t, err) + require.Empty(t, locs) + + // Set ports and announce on server-a + err = p.SetPorts(ctx, "client-1", []uint32{1}) + require.NoError(t, err) + err = p.Announce(ctx, "client-1", "server-a", 10*time.Second) + require.NoError(t, err) + + locs, err = p.Locations(ctx, "client-1") + require.NoError(t, err) + require.Len(t, locs, 1) + require.Contains(t, locs, "server-a") + + // Add second server + err = p.Announce(ctx, "client-1", "server-b", 10*time.Second) + require.NoError(t, err) + + locs, err = p.Locations(ctx, "client-1") + require.NoError(t, err) + require.Len(t, locs, 2) + require.Contains(t, locs, "server-a") + require.Contains(t, locs, "server-b") +} + +func TestPresenceSetOffline(t *testing.T) { + p := NewPresence() + ctx := context.Background() + + err := p.SetPorts(ctx, "client-1", []uint32{1}) + require.NoError(t, err) + err = p.Announce(ctx, "client-1", "server-a", 10*time.Second) + require.NoError(t, err) + err = p.Announce(ctx, "client-1", "server-b", 10*time.Second) + require.NoError(t, err) + + // Remove server-a + err = p.Revoke(ctx, "client-1", "server-a") + require.NoError(t, err) + + locs, err := p.Locations(ctx, "client-1") + require.NoError(t, err) + require.Len(t, locs, 1) + require.Contains(t, locs, "server-b") + + // Remove server-b (last server); client should disappear + err = p.Revoke(ctx, "client-1", "server-b") + require.NoError(t, err) + + locs, err = p.Locations(ctx, "client-1") + require.NoError(t, err) + require.Empty(t, locs) + + // Ports should also be gone + ports, err := p.Ports(ctx, "client-1") + require.NoError(t, err) + require.Empty(t, ports) +} + +func TestPresenceGetPorts(t *testing.T) { + p := NewPresence() + ctx := context.Background() + + err := p.SetPorts(ctx, "client-1", []uint32{1, 2, 3}) + require.NoError(t, err) + err = p.Announce(ctx, "client-1", "server-a", 10*time.Second) + require.NoError(t, err) + + ports, err := p.Ports(ctx, "client-1") + require.NoError(t, err) + require.Equal(t, []uint32{1, 2, 3}, ports) + + // Update ports independently of leases + err = p.SetPorts(ctx, "client-1", []uint32{5}) + require.NoError(t, err) + + ports, err = p.Ports(ctx, "client-1") + require.NoError(t, err) + require.Equal(t, []uint32{5}, ports) +} + +func TestPresenceTTLExpiry(t *testing.T) { + p := NewPresence() + ctx := context.Background() + + // Set with very short TTL + err := p.SetPorts(ctx, "client-1", []uint32{1}) + require.NoError(t, err) + err = p.Announce(ctx, "client-1", "server-a", 10*time.Millisecond) + require.NoError(t, err) + + locs, err := p.Locations(ctx, "client-1") + require.NoError(t, err) + require.Len(t, locs, 1) + + // Wait for expiry + time.Sleep(20 * time.Millisecond) + + // GetLocations should prune expired and return empty + locs, err = p.Locations(ctx, "client-1") + require.NoError(t, err) + require.Empty(t, locs) +} + +func TestPresenceMultipleClientsIsolation(t *testing.T) { + p := NewPresence() + ctx := context.Background() + + err := p.SetPorts(ctx, "client-1", []uint32{1}) + require.NoError(t, err) + err = p.Announce(ctx, "client-1", "server-a", 10*time.Second) + require.NoError(t, err) + err = p.SetPorts(ctx, "client-2", []uint32{2}) + require.NoError(t, err) + err = p.Announce(ctx, "client-2", "server-b", 10*time.Second) + require.NoError(t, err) + + locs1, _ := p.Locations(ctx, "client-1") + locs2, _ := p.Locations(ctx, "client-2") + + require.Equal(t, []string{"server-a"}, locs1) + require.Equal(t, []string{"server-b"}, locs2) +} diff --git a/pkg/rtun/match/presence.go b/pkg/rtun/match/presence.go new file mode 100644 index 000000000..06be0da20 --- /dev/null +++ b/pkg/rtun/match/presence.go @@ -0,0 +1,21 @@ +package match + +import ( + "context" + "time" +) + +// Presence tracks per-(client, server) leases and a client's global ports. +// +// Semantics: +// - Announce/refresh a lease for (clientID, serverID) with a TTL. +// - Revoke removes a single server's lease for a client. +// - Locations returns only non-expired serverIDs. +// - SetPorts sets the client's ports (typically once at hello); Ports returns them. +type Presence interface { + Announce(ctx context.Context, clientID string, serverID string, ttl time.Duration) error + Revoke(ctx context.Context, clientID string, serverID string) error + Locations(ctx context.Context, clientID string) ([]string, error) + SetPorts(ctx context.Context, clientID string, ports []uint32) error + Ports(ctx context.Context, clientID string) ([]uint32, error) +} diff --git a/pkg/rtun/match/route.go b/pkg/rtun/match/route.go new file mode 100644 index 000000000..621b9170a --- /dev/null +++ b/pkg/rtun/match/route.go @@ -0,0 +1,62 @@ +package match + +import ( + "context" + "fmt" + "net" + "net/url" + "strconv" + + "github.com/conductorone/baton-sdk/pkg/rtun/server" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" +) + +// OwnerRouter helps route work to the server process that owns a client's link. +type OwnerRouter struct { + Locator *Locator + Directory Directory + DialOpts []grpc.DialOption +} + +// DialOwner resolves the owner of clientID and returns a gRPC connection to that server. +// The caller should use this connection to invoke services on the owner, where the owner +// can use its local Registry.DialContext to perform reverse RPCs. +func (r *OwnerRouter) DialOwner(ctx context.Context, clientID string) (*grpc.ClientConn, string, error) { + owner, _, err := r.Locator.OwnerOf(ctx, clientID) + if err != nil { + return nil, "", fmt.Errorf("rtun: locate owner: %w", err) + } + addr, err := r.Directory.Resolve(ctx, owner) + if err != nil { + return nil, "", fmt.Errorf("rtun: resolve owner address: %w", err) + } + opts := r.DialOpts + conn, err := grpc.NewClient("passthrough:///"+addr, opts...) + if err != nil { + return nil, "", fmt.Errorf("rtun: dial owner: %w", err) + } + return conn, owner, nil +} + +// LocalReverseDial is a helper to be called ON the owner server process. +// It uses the local Registry to open a reverse connection to the client. +// clientID must be URL-safe; use url.PathEscape if it contains special characters. +func LocalReverseDial(ctx context.Context, reg *server.Registry, clientID string, port uint32) (*grpc.ClientConn, error) { + u := url.URL{ + Scheme: "rtun", + Host: net.JoinHostPort( + clientID, + strconv.FormatUint(uint64(port), 10), + ), + } + addr := u.String() + conn, err := grpc.NewClient("passthrough:///"+addr, + grpc.WithContextDialer(reg.DialContext), + grpc.WithTransportCredentials(insecure.NewCredentials()), + ) + if err != nil { + return nil, fmt.Errorf("rtun: reverse dial: %w", err) + } + return conn, nil +} diff --git a/pkg/rtun/match/route_test.go b/pkg/rtun/match/route_test.go new file mode 100644 index 000000000..b1237ae79 --- /dev/null +++ b/pkg/rtun/match/route_test.go @@ -0,0 +1,128 @@ +package match + +import ( + "context" + "net" + "testing" + "time" + + rtunpb "github.com/conductorone/baton-sdk/pb/c1/connectorapi/rtun/v1" + "github.com/conductorone/baton-sdk/pkg/rtun/match/memory" + "github.com/conductorone/baton-sdk/pkg/rtun/server" + "github.com/conductorone/baton-sdk/pkg/rtun/transport" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + health "google.golang.org/grpc/health" + healthpb "google.golang.org/grpc/health/grpc_health_v1" +) + +// testValidator for integration tests. +type testValidator struct{ id string } + +func (t testValidator) ValidateAuth(ctx context.Context) (string, error) { return t.id, nil } +func (t testValidator) ValidateHello(ctx context.Context, hello *rtunpb.Hello) error { return nil } + +// TestOwnerRouterTwoServers simulates: +// - Server A owns client-123 (handler + registry). +// - Server B is a caller that uses OwnerRouter to find A, dials A, and uses A's registry to reverse-dial client-123. +func TestOwnerRouterTwoServers(t *testing.T) { + presence := memory.NewPresence() + directory := memory.NewDirectory() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // Server A setup + regA := server.NewRegistry() + handlerA := server.NewHandler(regA, "server-a", testValidator{id: "client-123"}) + gsrvA := grpc.NewServer() + rtunpb.RegisterReverseTunnelServiceServer(gsrvA, handlerA) + lA, err := (&net.ListenConfig{}).Listen(ctx, "tcp", "127.0.0.1:0") + require.NoError(t, err) + go func() { _ = gsrvA.Serve(lA) }() + + // Register server A in directory + require.NoError(t, directory.Advertise(ctx, "server-a", lA.Addr().String(), 10*time.Second)) + + // Client connects to server A with cancelable context + clientCtx, clientCancel := context.WithCancel(ctx) + defer clientCancel() + + ccA, err := grpc.NewClient("passthrough:///"+lA.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err) + rtunClientA := rtunpb.NewReverseTunnelServiceClient(ccA) + streamA, err := rtunClientA.Link(clientCtx) + require.NoError(t, err) + + // Client-side session and HELLO + clA := &clientLink{cli: streamA} + sessA := transport.NewSession(clA) + require.NoError(t, clA.Send(&rtunpb.Frame{Sid: 0, Kind: &rtunpb.Frame_Hello{Hello: &rtunpb.Hello{Ports: []uint32{1}}}})) + lnA, err := sessA.Listen(ctx, 1) + require.NoError(t, err) + + // Client runs health service + cgsA := grpc.NewServer() + hsA := health.NewServer() + healthpb.RegisterHealthServer(cgsA, hsA) + hsA.SetServingStatus("", healthpb.HealthCheckResponse_SERVING) + go func() { _ = cgsA.Serve(lnA) }() + + // Mark client-123 online on server-a in presence + require.NoError(t, presence.SetPorts(ctx, "client-123", []uint32{1})) + require.NoError(t, presence.Announce(ctx, "client-123", "server-a", 10*time.Second)) + + // Server B (caller) uses OwnerRouter to find and dial server A + router := &OwnerRouter{ + Locator: &Locator{Presence: presence}, + Directory: directory, + DialOpts: []grpc.DialOption{grpc.WithTransportCredentials(insecure.NewCredentials())}, + } + // Wait briefly for registration + time.Sleep(50 * time.Millisecond) + + ownerConn, ownerID, err := router.DialOwner(ctx, "client-123") + require.NoError(t, err) + require.Equal(t, "server-a", ownerID) + + // Now server B has a connection to server A. In production, B would invoke a service on A + // that internally uses regA.DialContext. For this test, simulate by directly using regA + // (since we're in the same process). + rconn, err := LocalReverseDial(ctx, regA, "client-123", 1) + require.NoError(t, err) + + // Perform health check over reverse connection + hc := healthpb.NewHealthClient(rconn) + resp, err := hc.Check(ctx, &healthpb.HealthCheckRequest{Service: ""}) + require.NoError(t, err) + require.Equal(t, healthpb.HealthCheckResponse_SERVING, resp.GetStatus()) + + // Clean up in correct order + rconn.Close() + ownerConn.Close() + cgsA.GracefulStop() + lnA.Close() + clientCancel() // Close client stream + ccA.Close() + gsrvA.GracefulStop() + lA.Close() +} + +// clientLink adapts the client bidi stream to transport.Link. +type clientLink struct { + cli rtunpb.ReverseTunnelService_LinkClient +} + +func (c *clientLink) Send(f *rtunpb.Frame) error { + return c.cli.Send(&rtunpb.ReverseTunnelServiceLinkRequest{Frame: f}) +} + +func (c *clientLink) Recv() (*rtunpb.Frame, error) { + fr, err := c.cli.Recv() + if err != nil { + return nil, err + } + return fr.GetFrame(), nil +} + +func (c *clientLink) Context() context.Context { return c.cli.Context() } diff --git a/pkg/rtun/server/auth.go b/pkg/rtun/server/auth.go new file mode 100644 index 000000000..d1ae59ed3 --- /dev/null +++ b/pkg/rtun/server/auth.go @@ -0,0 +1,16 @@ +package server + +import ( + "context" + + rtunpb "github.com/conductorone/baton-sdk/pb/c1/connectorapi/rtun/v1" +) + +// TokenValidator decouples auth from the rtun transport. +type TokenValidator interface { + // ValidateAuth is invoked when a stream is first connected. It should authenticate the caller + // from the gRPC context (e.g., mTLS, headers) and return the bound clientID. + ValidateAuth(ctx context.Context) (clientID string, err error) + // ValidateHello validates the HELLO frame contents (e.g., ports, protocol negotiation). + ValidateHello(ctx context.Context, hello *rtunpb.Hello) error +} diff --git a/pkg/rtun/server/handler.go b/pkg/rtun/server/handler.go new file mode 100644 index 000000000..2e8c5adb7 --- /dev/null +++ b/pkg/rtun/server/handler.go @@ -0,0 +1,153 @@ +package server + +import ( + "context" + "time" + + rtunpb "github.com/conductorone/baton-sdk/pb/c1/connectorapi/rtun/v1" + "github.com/conductorone/baton-sdk/pkg/rtun/transport" + "github.com/grpc-ecosystem/go-grpc-middleware/logging/zap/ctxzap" + "go.uber.org/zap" +) + +// Handler implements the ReverseTunnel gRPC service, binding Links to Sessions and the Registry. +type Handler struct { + rtunpb.UnimplementedReverseTunnelServiceServer + + reg *Registry + serverID string + tv TokenValidator + + // metrics (optional) + m *serverMetrics +} + +// NewHandler constructs a ReverseTunnel gRPC handler bound to a `Registry` and `TokenValidator`. +// It optionally enables metrics if provided via options. +func NewHandler(reg *Registry, serverID string, tv TokenValidator, opts ...Option) rtunpb.ReverseTunnelServiceServer { + var o options + for _, opt := range opts { + if opt == nil { + continue + } + opt(&o) + } + h := &Handler{reg: reg, serverID: serverID, tv: tv} + if o.metrics != nil { + h.m = newServerMetrics(o.metrics) + } + return h +} + +// Link accepts a bidi stream and binds it to a transport.Session after validating HELLO. +func (h *Handler) Link(stream rtunpb.ReverseTunnelService_LinkServer) error { + // Wrap the gRPC stream as transport.Link + l := &grpcLink{srv: stream} + + // Authenticate and determine clientID via TokenValidator BEFORE waiting for HELLO. + if h.tv == nil { + return ErrProtocol + } + clientID, err := h.tv.ValidateAuth(stream.Context()) + if err != nil { + return err + } + logger := ctxzap.Extract(stream.Context()).With(zap.String("client_id", clientID)) + logger.Info("auth ok") + + // First frame must be HELLO with timeout; if not HELLO, protocol violation + type recvResult struct { + fr *rtunpb.Frame + err error + } + resCh := make(chan recvResult, 1) + go func() { + fr, err := l.Recv() + resCh <- recvResult{fr: fr, err: err} + }() + helloTimeout := 15 * time.Second + var hello *rtunpb.Hello + select { + case res := <-resCh: + if res.err != nil { + return res.err + } + hello = res.fr.GetHello() + if hello == nil { + logger.Warn("first frame not HELLO; closing") + if h.m != nil { + h.m.helloRejected(stream.Context(), "not_hello") + } + return ErrProtocol + } + logger.Info("HELLO received", zap.Uint32s("ports", hello.GetPorts())) + // enforce reasonable HELLO port count limit (2500) + if len(hello.GetPorts()) > 2500 { + logger.Warn("HELLO ports exceed limit", zap.Int("count", len(hello.GetPorts()))) + if h.m != nil { + h.m.helloPortsOverLimit(stream.Context()) + } + return ErrProtocol + } + if err := h.tv.ValidateHello(stream.Context(), hello); err != nil { + if h.m != nil { + h.m.helloRejected(stream.Context(), "validate_failed") + } + return err + } + case <-time.After(helloTimeout): + logger.Warn("HELLO timeout") + if h.m != nil { + h.m.helloTimeout(stream.Context()) + } + return ErrHelloTimeout + } + + // Bind Session and start Recv loop. + var sessOpts []transport.Option + ports := hello.GetPorts() + if len(ports) > 0 { + sessOpts = append(sessOpts, transport.WithAllowedPorts(ports)) + } + // Pass metrics down to transport if available + if h.m != nil && h.m.h != nil { + sessOpts = append(sessOpts, transport.WithMetricsHandler(h.m.h)) + } + s := transport.NewSession(l, sessOpts...) + if h.m != nil { + h.m.registryRegister(stream.Context()) + } + h.reg.Register(stream.Context(), clientID, s) + defer h.reg.Unregister(stream.Context(), clientID) + defer func() { + if h.m != nil { + h.m.registryUnregister(stream.Context()) + } + }() + + // Start the session; recvLoop will run until link Recv errors (stream close/cancel) + s.Start() + + // Block until stream context is done (client disconnect or server shutdown) + <-stream.Context().Done() + return stream.Context().Err() +} + +// grpcLink adapts the gRPC server stream to transport.Link. +type grpcLink struct { + srv rtunpb.ReverseTunnelService_LinkServer +} + +func (g *grpcLink) Send(f *rtunpb.Frame) error { + return g.srv.Send(&rtunpb.ReverseTunnelServiceLinkResponse{Frame: f}) +} + +func (g *grpcLink) Recv() (*rtunpb.Frame, error) { + fr, err := g.srv.Recv() + if err != nil { + return nil, err + } + return fr.GetFrame(), nil +} + +func (g *grpcLink) Context() context.Context { return g.srv.Context() } diff --git a/pkg/rtun/server/metrics.go b/pkg/rtun/server/metrics.go new file mode 100644 index 000000000..7700605a0 --- /dev/null +++ b/pkg/rtun/server/metrics.go @@ -0,0 +1,46 @@ +package server + +import ( + "context" + + sdkmetrics "github.com/conductorone/baton-sdk/pkg/metrics" +) + +type serverMetrics struct { + h sdkmetrics.Handler + + helloTimeoutCtr sdkmetrics.Int64Counter + helloPortsOverCtr sdkmetrics.Int64Counter + helloRejectedCtr sdkmetrics.Int64Counter + regRegisterCtr sdkmetrics.Int64Counter + regUnregisterCtr sdkmetrics.Int64Counter + reverseDialOKCtr sdkmetrics.Int64Counter + reverseDialMissCtr sdkmetrics.Int64Counter +} + +func newServerMetrics(h sdkmetrics.Handler) *serverMetrics { + return &serverMetrics{ + h: h, + helloTimeoutCtr: h.Int64Counter("rtun.server.hello_timeout_total", "HELLO timeouts", sdkmetrics.Dimensionless), + helloPortsOverCtr: h.Int64Counter("rtun.server.hello_ports_over_limit_total", "HELLO ports over limit", sdkmetrics.Dimensionless), + helloRejectedCtr: h.Int64Counter("rtun.server.hello_rejected_total", "HELLO rejected for reason", sdkmetrics.Dimensionless), + regRegisterCtr: h.Int64Counter("rtun.server.registry_register_total", "registry register", sdkmetrics.Dimensionless), + regUnregisterCtr: h.Int64Counter("rtun.server.registry_unregister_total", "registry unregister", sdkmetrics.Dimensionless), + reverseDialOKCtr: h.Int64Counter("rtun.server.reverse_dial_success_total", "reverse dial success", sdkmetrics.Dimensionless), + reverseDialMissCtr: h.Int64Counter("rtun.server.reverse_dial_not_found_total", "reverse dial not found", sdkmetrics.Dimensionless), + } +} + +func (m *serverMetrics) helloTimeout(ctx context.Context) { m.helloTimeoutCtr.Add(ctx, 1, nil) } +func (m *serverMetrics) helloPortsOverLimit(ctx context.Context) { + m.helloPortsOverCtr.Add(ctx, 1, nil) +} +func (m *serverMetrics) helloRejected(ctx context.Context, reason string) { + m.helloRejectedCtr.Add(ctx, 1, map[string]string{"reason": reason}) +} +func (m *serverMetrics) registryRegister(ctx context.Context) { m.regRegisterCtr.Add(ctx, 1, nil) } +func (m *serverMetrics) registryUnregister(ctx context.Context) { m.regUnregisterCtr.Add(ctx, 1, nil) } +func (m *serverMetrics) reverseDialSuccess(ctx context.Context) { m.reverseDialOKCtr.Add(ctx, 1, nil) } +func (m *serverMetrics) reverseDialNotFound(ctx context.Context) { + m.reverseDialMissCtr.Add(ctx, 1, nil) +} diff --git a/pkg/rtun/server/options.go b/pkg/rtun/server/options.go new file mode 100644 index 000000000..88538125d --- /dev/null +++ b/pkg/rtun/server/options.go @@ -0,0 +1,28 @@ +package server + +import ( + "errors" + + sdkmetrics "github.com/conductorone/baton-sdk/pkg/metrics" +) + +var ( + // ErrNotImplemented is returned when a feature isn't implemented. + ErrNotImplemented = errors.New("rtun/server: not implemented") + // ErrProtocol indicates a protocol violation on the stream. + ErrProtocol = errors.New("rtun/server: protocol error") + // ErrHelloTimeout indicates the HELLO frame was not received in time. + ErrHelloTimeout = errors.New("rtun/server: hello timeout") +) + +// Option configures server components such as the handler and registry. +type Option func(*options) + +type options struct { + metrics sdkmetrics.Handler +} + +// WithMetricsHandler injects a metrics handler for server components (handler/registry). +func WithMetricsHandler(h sdkmetrics.Handler) Option { + return func(o *options) { o.metrics = h } +} diff --git a/pkg/rtun/server/registry.go b/pkg/rtun/server/registry.go new file mode 100644 index 000000000..e1b9b766c --- /dev/null +++ b/pkg/rtun/server/registry.go @@ -0,0 +1,93 @@ +package server + +import ( + "context" + "fmt" + "net" + "net/url" + "strconv" + "sync" + + "github.com/conductorone/baton-sdk/pkg/rtun/transport" +) + +// Registry tracks active client Sessions by client ID and mediates reverse dials. +type Registry struct { + mu sync.RWMutex + sessions map[string]*transport.Session + m *serverMetrics +} + +// NewRegistry creates a new Registry. Metrics can be enabled via options. +func NewRegistry(opts ...Option) *Registry { + var o options + for _, opt := range opts { + opt(&o) + } + r := &Registry{sessions: make(map[string]*transport.Session)} + if o.metrics != nil { + r.m = newServerMetrics(o.metrics) + } + return r +} + +// DialContext dials ONLY if this process owns the client link. +// addr: "rtun://:" where clientID is URL-safe and port is required. +// clientID must not contain unescaped colons or slashes; use url.PathEscape if needed. +func (r *Registry) DialContext(ctx context.Context, addr string) (net.Conn, error) { + // Parse rtun://clientID:port + u, err := url.Parse(addr) + if err != nil { + return nil, err + } + if u.Scheme != "rtun" { + return nil, fmt.Errorf("rtun: invalid scheme: %s", u.Scheme) + } + host := u.Host + clientID, portStr, err := net.SplitHostPort(host) + if err != nil { + return nil, fmt.Errorf("rtun: missing or invalid port in '%s': %w", addr, err) + } + pu, err := strconv.ParseUint(portStr, 10, 32) + if err != nil { + return nil, fmt.Errorf("rtun: invalid port in '%s': %w", addr, err) + } + port := uint32(pu) + + r.mu.RLock() + s := r.sessions[clientID] + r.mu.RUnlock() + if s == nil { + if r.m != nil { + r.m.reverseDialNotFound(ctx) + } + return nil, fmt.Errorf("rtun: client not connected: %s", clientID) + } + conn, err := s.Open(ctx, port) + if err == nil { + if r.m != nil { + r.m.reverseDialSuccess(ctx) + } + } + return conn, err +} + +// Register binds a client's Session to this Registry under the given clientID. +func (r *Registry) Register(ctx context.Context, clientID string, s *transport.Session) { + r.mu.Lock() + r.sessions[clientID] = s + r.mu.Unlock() + if r.m != nil { + r.m.registryRegister(ctx) + } +} + +// Unregister removes a client's Session binding. +func (r *Registry) Unregister(ctx context.Context, clientID string) { + r.mu.Lock() + delete(r.sessions, clientID) + r.mu.Unlock() + if r.m != nil { + r.m.registryUnregister(ctx) + } +} diff --git a/pkg/rtun/server/server_integration_test.go b/pkg/rtun/server/server_integration_test.go new file mode 100644 index 000000000..19de88649 --- /dev/null +++ b/pkg/rtun/server/server_integration_test.go @@ -0,0 +1,101 @@ +package server + +import ( + "context" + "net" + "testing" + "time" + + rtunpb "github.com/conductorone/baton-sdk/pb/c1/connectorapi/rtun/v1" + "github.com/conductorone/baton-sdk/pkg/rtun/transport" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + health "google.golang.org/grpc/health" + healthpb "google.golang.org/grpc/health/grpc_health_v1" +) + +// testValidator is a production-shaped validator used only in tests. +// It authenticates the link and returns a fixed clientID; HELLO is allowed as-is. +type testValidator struct{ id string } + +func (t testValidator) ValidateAuth(ctx context.Context) (string, error) { return t.id, nil } +func (t testValidator) ValidateHello(ctx context.Context, hello *rtunpb.Hello) error { return nil } + +// clientLink adapts the client bidi stream to transport.Link on the client side. +type clientLink struct { + cli rtunpb.ReverseTunnelService_LinkClient +} + +func (c clientLink) Send(f *rtunpb.Frame) error { + return c.cli.Send(&rtunpb.ReverseTunnelServiceLinkRequest{Frame: f}) +} + +func (c clientLink) Recv() (*rtunpb.Frame, error) { + resp, err := c.cli.Recv() + if err != nil { + return nil, err + } + return resp.GetFrame(), nil +} + +func (c clientLink) Context() context.Context { return c.cli.Context() } + +// TestReverseGrpcE2E spins up a real gRPC server with Handler, connects a real gRPC client stream for Link, +// runs the standard gRPC health service over rtun on the client, and performs a health check from the owner via Registry.DialContext. +func TestReverseGrpcE2E(t *testing.T) { + // Server side: real gRPC server with our handler + reg := NewRegistry() + h := NewHandler(reg, "server-1", testValidator{id: "client-123"}) + gsrv := grpc.NewServer() + rtunpb.RegisterReverseTunnelServiceServer(gsrv, h) + l, err := (&net.ListenConfig{}).Listen(context.Background(), "tcp", "127.0.0.1:0") + require.NoError(t, err) + defer l.Close() + go func() { _ = gsrv.Serve(l) }() + defer gsrv.GracefulStop() + + // Client side: dial server and open Link stream + cc, err := grpc.NewClient("passthrough:///"+l.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err) + defer cc.Close() + rtunClient := rtunpb.NewReverseTunnelServiceClient(cc) + stream, err := rtunClient.Link(context.Background()) + require.NoError(t, err) + + // Wrap client stream as transport.Link and start Session + cl := clientLink{cli: stream} + sess := transport.NewSession(cl) + // Send HELLO announcing port 1 + require.NoError(t, cl.Send(&rtunpb.Frame{Sid: 0, Kind: &rtunpb.Frame_Hello{Hello: &rtunpb.Hello{Ports: []uint32{1}}}})) + ln, err := sess.Listen(context.Background(), 1) + require.NoError(t, err) + defer ln.Close() + + // Run the standard gRPC health service over the rtun listener + cgs := grpc.NewServer() + hs := health.NewServer() + healthpb.RegisterHealthServer(cgs, hs) + hs.SetServingStatus("", healthpb.HealthCheckResponse_SERVING) + go func() { _ = cgs.Serve(ln) }() + defer cgs.GracefulStop() + + // Owner side: reverse dial and perform health check + // Wait briefly for registration to finish + time.Sleep(50 * time.Millisecond) + rconn, err := reg.DialContext(context.Background(), "rtun://client-123:1") + require.NoError(t, err) + defer rconn.Close() + + ownerCC, err := grpc.NewClient("passthrough:///ignored", + grpc.WithContextDialer(func(ctx context.Context, _ string) (net.Conn, error) { return rconn, nil }), + grpc.WithTransportCredentials(insecure.NewCredentials()), + ) + require.NoError(t, err) + defer ownerCC.Close() + + hc := healthpb.NewHealthClient(ownerCC) + resp, err := hc.Check(context.Background(), &healthpb.HealthCheckRequest{Service: ""}) + require.NoError(t, err) + require.Equal(t, healthpb.HealthCheckResponse_SERVING, resp.GetStatus()) +} diff --git a/pkg/rtun/transport/closedset.go b/pkg/rtun/transport/closedset.go new file mode 100644 index 000000000..b21df51e1 --- /dev/null +++ b/pkg/rtun/transport/closedset.go @@ -0,0 +1,151 @@ +package transport + +// closedSet maintains a compact set of closed SIDs using a run-length encoded +// interval list and a high-water mark for the largest contiguous closed prefix +// starting from 1. It is package-local and not concurrency-safe; callers should +// provide external synchronization. +// +// Invariants: +// - cs.ranges is sorted by start and contains disjoint, non-adjacent intervals +// with start <= end. +// - All SIDs in [1..highClosed] are considered closed and not represented in +// cs.ranges (auto-pruned on each Close call). +type closedSet struct { + highClosed uint32 + ranges []interval +} + +type interval struct { + start uint32 + end uint32 +} + +// IsClosed returns true if sid has been marked closed. +func (cs *closedSet) IsClosed(sid uint32) bool { + if sid == 0 { + return true + } + if sid <= cs.highClosed { + return true + } + i := cs.find(sid) + if i < len(cs.ranges) { + iv := cs.ranges[i] + if sid >= iv.start && sid <= iv.end { + return true + } + } + if i > 0 { + iv := cs.ranges[i-1] + return sid >= iv.start && sid <= iv.end + } + return false +} + +// Close marks sid as closed. This merges adjacent/overlapping intervals and +// auto-prunes intervals that become part of the contiguous closed prefix. +func (cs *closedSet) Close(sid uint32) { + if sid == 0 { + return + } + // If already in the contiguous prefix, nothing to do. + if sid <= cs.highClosed { + return + } + + i := cs.find(sid) + // Check if already closed by neighboring interval + if i < len(cs.ranges) { + iv := cs.ranges[i] + if sid >= iv.start && sid <= iv.end { + return + } + } + if i > 0 { + iv := cs.ranges[i-1] + if sid >= iv.start && sid <= iv.end { + return + } + } + + // New interval initially [sid, sid]. Merge with previous/next if adjacent or overlapping. + start, end := sid, sid + // Merge with previous + if i > 0 { + prev := cs.ranges[i-1] + if prev.end+1 >= sid { // adjacent or overlap + start = prev.start + if prev.end > end { + end = prev.end + } + // remove prev + cs.ranges = append(cs.ranges[:i-1], cs.ranges[i:]...) + i-- + } + } + // Merge with next (note: i points to the first interval with start > sid or the merged position) + if i < len(cs.ranges) { + next := cs.ranges[i] + if next.start <= end+1 { // adjacent or overlap + if next.end > end { + end = next.end + } + // remove next + cs.ranges = append(cs.ranges[:i], cs.ranges[i+1:]...) + } + } + // Insert merged interval at position i + cs.ranges = append(cs.ranges, interval{}) + copy(cs.ranges[i+1:], cs.ranges[i:]) + cs.ranges[i] = interval{start: start, end: end} + + // Update highClosed: if the merged interval starts at highClosed+1, extend the prefix. + cs.promotePrefix() +} + +// find returns the index of the first interval with start > sid, or len(ranges) +// if none; suitable for insertion and neighbor checks. +func (cs *closedSet) find(sid uint32) int { + lo, hi := 0, len(cs.ranges) + for lo < hi { + mid := (lo + hi) >> 1 + if cs.ranges[mid].start <= sid { + lo = mid + 1 + } else { + hi = mid + } + } + return lo +} + +// promotePrefix extends highClosed if there is an interval that begins at +// highClosed+1. It also prunes any intervals fully included in the prefix by +// removing them from the ranges slice. +func (cs *closedSet) promotePrefix() { + changed := true + for changed { + changed = false + if len(cs.ranges) == 0 { + return + } + // After insert/merge, the first interval that could extend the prefix is the earliest one + // whose start is <= highClosed+1. Because we always keep ranges sorted and disjoint, it must + // be at index 0 if it can extend the prefix. + iv := cs.ranges[0] + want := cs.highClosed + 1 + if iv.start == want { + // Extend prefix to this interval's end and prune it + cs.highClosed = iv.end + cs.ranges = cs.ranges[1:] + changed = true + continue + } + // Prune only intervals fully covered by the prefix. + if iv.end <= cs.highClosed { + cs.ranges = cs.ranges[1:] + changed = true + continue + } + // Otherwise, iv lies strictly above the prefix and cannot extend it now. + } +} diff --git a/pkg/rtun/transport/closedset_test.go b/pkg/rtun/transport/closedset_test.go new file mode 100644 index 000000000..36c257b5c --- /dev/null +++ b/pkg/rtun/transport/closedset_test.go @@ -0,0 +1,152 @@ +package transport + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestClosedSetBasicPrefixAndRanges(t *testing.T) { + var cs closedSet + + // Initially everything open except sid 0 considered closed by spec + require.True(t, cs.IsClosed(0)) + require.False(t, cs.IsClosed(1)) + + // Close(1) promotes prefix to 1 + cs.Close(1) + require.True(t, cs.IsClosed(1)) + require.Equal(t, uint32(1), cs.highClosed) + require.Empty(t, cs.ranges) + + // Close non-contiguous sid produces a range + cs.Close(3) + require.False(t, cs.IsClosed(2)) + require.True(t, cs.IsClosed(3)) + require.Len(t, cs.ranges, 1) + require.Equal(t, interval{start: 3, end: 3}, cs.ranges[0]) + + // Close(2) creates [2,3] which begins at highClosed+1=2, so prefix promotes to 3 and ranges empty + cs.Close(2) + require.True(t, cs.IsClosed(2)) + require.Equal(t, uint32(3), cs.highClosed) + require.Empty(t, cs.ranges) + + // Closing 4 promotes prefix to 4 (no ranges kept) + cs.Close(4) + require.True(t, cs.IsClosed(4)) + require.Equal(t, uint32(4), cs.highClosed) + require.Empty(t, cs.ranges) + + // Closing 5 promotes prefix to 5 (no ranges kept) + cs.Close(5) + require.True(t, cs.IsClosed(5)) + require.Equal(t, uint32(5), cs.highClosed) + require.Empty(t, cs.ranges) + + // Closing 0 ignored; does not change state (highClosed already 5 and no ranges) + cs.Close(0) + require.Equal(t, uint32(5), cs.highClosed) + require.Empty(t, cs.ranges) + + // Closing 6 promotes prefix further + cs.Close(6) + require.Equal(t, uint32(6), cs.highClosed) + require.Empty(t, cs.ranges) + + // Now closing 2..6 already closed does nothing + for sid := uint32(2); sid <= 6; sid++ { + cs.Close(sid) + } + require.Equal(t, uint32(6), cs.highClosed) + require.Empty(t, cs.ranges) + + // Close(7) continues extension + cs.Close(7) + require.Equal(t, uint32(7), cs.highClosed) + require.Empty(t, cs.ranges) + + // Close(8) continues extension + cs.Close(8) + require.Equal(t, uint32(8), cs.highClosed) + require.Empty(t, cs.ranges) +} + +func TestClosedSetMergeBothSides(t *testing.T) { + var cs closedSet + // Create two ranges [10,12] and [14,16], then close 13 to merge both into [10,16] + cs.Close(11) + cs.Close(12) + cs.Close(10) + require.Equal(t, []interval{{start: 10, end: 12}}, cs.ranges) + + cs.Close(15) + cs.Close(16) + cs.Close(14) + require.Equal(t, []interval{{start: 10, end: 12}, {start: 14, end: 16}}, cs.ranges) + + cs.Close(13) + require.Equal(t, []interval{{start: 10, end: 16}}, cs.ranges) +} + +func TestClosedSetPromotionFromRange(t *testing.T) { + var cs closedSet + // Close a distant block [5,7] + cs.Close(7) + cs.Close(6) + cs.Close(5) + require.Equal(t, uint32(0), cs.highClosed) + require.Equal(t, []interval{{start: 5, end: 7}}, cs.ranges) + + // Now close 1..4; once 4 closes, the first range starts at 5 which is highClosed+1 => promotion to 7 + cs.Close(1) + cs.Close(2) + cs.Close(3) + require.Equal(t, uint32(3), cs.highClosed) + cs.Close(4) + require.Equal(t, uint32(7), cs.highClosed) + require.Empty(t, cs.ranges) +} + +func TestClosedSetIdempotentAndOrderInvariant(t *testing.T) { + var a, b closedSet + order1 := []uint32{100, 1, 3, 2, 5, 4} + order2 := []uint32{1, 2, 3, 4, 5, 100} + for _, sid := range order1 { + a.Close(sid) + } + for _, sid := range order2 { + b.Close(sid) + } + require.Equal(t, a.highClosed, b.highClosed) + require.Equal(t, a.ranges, b.ranges) +} + +func TestClosedSetLargeValues(t *testing.T) { + var cs closedSet + cs.Close(1_000_000) + cs.Close(1_000_002) + require.False(t, cs.IsClosed(1)) + require.True(t, cs.IsClosed(1_000_000)) + require.True(t, cs.IsClosed(1_000_002)) + require.False(t, cs.IsClosed(1_000_001)) + + // Merge into single interval [1_000_000, 1_000_010] + for sid := uint32(1_000_001); sid <= 1_000_010; sid++ { + cs.Close(sid) + } + require.Equal(t, []interval{{start: 1_000_000, end: 1_000_010}}, cs.ranges) +} + +func TestClosedSetHighClosedQuery(t *testing.T) { + var cs closedSet + for sid := uint32(1); sid <= 50; sid++ { + cs.Close(sid) + } + require.Equal(t, uint32(50), cs.highClosed) + // Queries under prefix are true without ranges + for sid := uint32(1); sid <= 50; sid++ { + require.True(t, cs.IsClosed(sid)) + } + require.Empty(t, cs.ranges) +} diff --git a/pkg/rtun/transport/conn.go b/pkg/rtun/transport/conn.go new file mode 100644 index 000000000..970e30744 --- /dev/null +++ b/pkg/rtun/transport/conn.go @@ -0,0 +1,310 @@ +package transport + +import ( + "errors" + "io" + "net" + "sync" + "time" + + rtunpb "github.com/conductorone/baton-sdk/pb/c1/connectorapi/rtun/v1" +) + +const maxWriteChunk = 32 * 1024 + +// virtConn implements net.Conn over a multiplexed SID. +type virtConn struct { + mux *Session + sid uint32 + + readCh chan []byte + readErr error + readMu sync.Mutex + readRem []byte // remainder from partial reads + + writeMu sync.Mutex + writeClosed bool + + rdDeadline time.Time + wrDeadline time.Time + + // closeReadOnce ensures the read side is closed exactly once, regardless of + // whether closure originates locally (Close), remotely (FIN), or via RST. + closeReadOnce sync.Once + + // idle timer management + idleMu sync.Mutex + idleTimer *time.Timer +} + +var _ net.Conn = (*virtConn)(nil) + +func newVirtConn(m *Session, sid uint32) *virtConn { + return &virtConn{ + mux: m, + sid: sid, + readCh: make(chan []byte, 256), + } +} + +// Read implements net.Conn for virtConn by returning data delivered for this SID, honoring read deadlines. +func (c *virtConn) Read(p []byte) (int, error) { + // Check terminal error or remainder under lock + c.readMu.Lock() + if c.readErr != nil { + err := c.readErr + c.readMu.Unlock() + return 0, err + } + if len(c.readRem) > 0 { + n := copy(p, c.readRem) + c.readRem = c.readRem[n:] + c.readMu.Unlock() + return n, nil + } + // Snapshot deadline then release lock before blocking on channel + deadline := c.rdDeadline + c.readMu.Unlock() + + if deadline.IsZero() { + buf, ok := <-c.readCh + if !ok { + c.readMu.Lock() + err := c.readErr + c.readMu.Unlock() + if err == nil { + return 0, io.EOF + } + return 0, err + } + n := copy(p, buf) + c.onActivity() + if n < len(buf) { + c.readMu.Lock() + c.readRem = buf[n:] + c.readMu.Unlock() + } + return n, nil + } + + // Deadline set + until := time.Until(deadline) + if until <= 0 { + return 0, ErrTimeout + } + timer := time.NewTimer(until) + defer timer.Stop() + select { + case buf, ok := <-c.readCh: + if !ok { + c.readMu.Lock() + err := c.readErr + c.readMu.Unlock() + if err == nil { + return 0, io.EOF + } + return 0, err + } + n := copy(p, buf) + c.onActivity() + if n < len(buf) { + c.readMu.Lock() + c.readRem = buf[n:] + c.readMu.Unlock() + } + return n, nil + case <-timer.C: + return 0, ErrTimeout + } +} + +// Write implements net.Conn for virtConn by sending DATA frames for this SID, honoring write deadlines. +func (c *virtConn) Write(p []byte) (int, error) { + c.writeMu.Lock() + defer c.writeMu.Unlock() + if c.writeClosed { + return 0, net.ErrClosed + } + // Writes are allowed even after remote FIN (half-close), so we do not block based on remote state. + total := 0 + for len(p) > 0 { + if !c.wrDeadline.IsZero() && time.Until(c.wrDeadline) <= 0 { + if total == 0 { + return 0, ErrTimeout + } + return total, ErrTimeout + } + chunk := p + if len(chunk) > maxWriteChunk { + chunk = p[:maxWriteChunk] + } + frame := &rtunpb.Frame{Sid: c.sid, Kind: &rtunpb.Frame_Data{Data: &rtunpb.Data{Payload: append([]byte(nil), chunk...)}}} + if err := c.mux.link.Send(frame); err != nil { + return total, err + } + c.onActivity() + if c.mux.m != nil { + c.mux.m.recordFrameTx(c.mux.link.Context(), "DATA") + c.mux.m.recordBytesTx(c.mux.link.Context(), int64(len(chunk))) + } + total += len(chunk) + p = p[len(chunk):] + } + return total, nil +} + +// Close implements net.Conn for virtConn by half-closing writes (sending FIN) and releasing resources. +func (c *virtConn) Close() error { + c.writeMu.Lock() + if c.writeClosed { + c.writeMu.Unlock() + return nil + } + c.writeClosed = true + c.writeMu.Unlock() + // send FIN (ack=false) + _ = c.mux.link.Send(&rtunpb.Frame{Sid: c.sid, Kind: &rtunpb.Frame_Fin{Fin: &rtunpb.Fin{Ack: false}}}) + // Unblock any pending Read by closing the read side + c.closeReadOnce.Do(func() { + c.readMu.Lock() + if c.readErr == nil { + c.readErr = io.EOF + } + c.readMu.Unlock() + close(c.readCh) + }) + c.stopIdleTimer() + c.mux.removeConn(c.sid) + return nil +} + +// LocalAddr implements net.Conn. +func (c *virtConn) LocalAddr() net.Addr { return rtunAddr{"rtun-local"} } + +// RemoteAddr implements net.Conn. +func (c *virtConn) RemoteAddr() net.Addr { return rtunAddr{"rtun-remote"} } + +// SetDeadline implements net.Conn. +func (c *virtConn) SetDeadline(t time.Time) error { + c.readMu.Lock() + c.rdDeadline = t + c.readMu.Unlock() + c.writeMu.Lock() + c.wrDeadline = t + c.writeMu.Unlock() + return nil +} + +// SetReadDeadline implements net.Conn. +func (c *virtConn) SetReadDeadline(t time.Time) error { + c.readMu.Lock() + c.rdDeadline = t + c.readMu.Unlock() + return nil +} + +// SetWriteDeadline implements net.Conn. +func (c *virtConn) SetWriteDeadline(t time.Time) error { + c.writeMu.Lock() + c.wrDeadline = t + c.writeMu.Unlock() + return nil +} + +// feedData is called by the mux to deliver inbound bytes. +func (c *virtConn) feedData(b []byte) { + // If we've already observed a terminal read error (EOF, overflow, RST), drop incoming data. + c.readMu.Lock() + alreadyErr := c.readErr + c.readMu.Unlock() + if alreadyErr != nil { + return + } + select { + case c.readCh <- b: + // delivered + c.onActivity() + default: + // backpressure: send RST, perform full RST handling (including write-side close), and detach from session + err := errors.New("rtun: inbound buffer overflow") + _ = c.mux.link.Send(&rtunpb.Frame{Sid: c.sid, Kind: &rtunpb.Frame_Rst{Rst: &rtunpb.Rst{Code: rtunpb.RstCode_RST_CODE_INTERNAL}}}) + c.handleRst(err) + c.mux.removeConn(c.sid) + } +} + +func (c *virtConn) handleFin(ack bool) { + // mark remote closed; signal EOF + c.readMu.Lock() + if c.readErr == nil { + c.readErr = io.EOF + } + c.readMu.Unlock() + c.closeReadOnce.Do(func() { close(c.readCh) }) + c.stopIdleTimer() +} + +func (c *virtConn) handleRst(err error) { + c.readMu.Lock() + if c.readErr == nil { + c.readErr = err + } + c.readMu.Unlock() + c.closeReadOnce.Do(func() { close(c.readCh) }) + c.writeMu.Lock() + c.writeClosed = true + c.writeMu.Unlock() + c.stopIdleTimer() +} + +type rtunAddr struct{ s string } + +func (a rtunAddr) Network() string { return "rtun" } +func (a rtunAddr) String() string { return a.s } + +// startIdleTimer starts or resets the per-connection idle timer according to the Session's configuration. +func (c *virtConn) startIdleTimer() { + timeout := c.mux.idleTimeout + if timeout < 0 { + return + } + c.idleMu.Lock() + if c.idleTimer == nil { + c.idleTimer = time.AfterFunc(timeout, func() { + c.handleIdleTimeout() + }) + } else { + c.idleTimer.Reset(timeout) + } + c.idleMu.Unlock() +} + +// onActivity resets the idle timer if enabled to reflect recent I/O activity. +func (c *virtConn) onActivity() { + timeout := c.mux.idleTimeout + if timeout < 0 { + return + } + c.idleMu.Lock() + if c.idleTimer != nil { + _ = c.idleTimer.Reset(timeout) + } + c.idleMu.Unlock() +} + +func (c *virtConn) stopIdleTimer() { + c.idleMu.Lock() + if c.idleTimer != nil { + c.idleTimer.Stop() + c.idleTimer = nil + } + c.idleMu.Unlock() +} + +func (c *virtConn) handleIdleTimeout() { + // Timer fired: send RST and tear down connection. + _ = c.mux.link.Send(&rtunpb.Frame{Sid: c.sid, Kind: &rtunpb.Frame_Rst{Rst: &rtunpb.Rst{Code: rtunpb.RstCode_RST_CODE_INTERNAL}}}) + c.handleRst(ErrTimeout) + c.mux.removeConn(c.sid) + c.mux.markClosed(c.sid) +} diff --git a/pkg/rtun/transport/conn_test.go b/pkg/rtun/transport/conn_test.go new file mode 100644 index 000000000..b9d310447 --- /dev/null +++ b/pkg/rtun/transport/conn_test.go @@ -0,0 +1,94 @@ +package transport + +import ( + "context" + "errors" + "io" + "net" + "testing" + "time" + + rtunpb "github.com/conductorone/baton-sdk/pb/c1/connectorapi/rtun/v1" + "github.com/stretchr/testify/require" +) + +func TestVirtConnCloseIdempotentAndWriteAfterClose(t *testing.T) { + tl := newTestLink() + s := NewSession(tl) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + t.Cleanup(cancel) + ln, err := s.Listen(ctx, 1) + require.NoError(t, err) + t.Cleanup(func() { _ = ln.Close() }) + + accCh := make(chan net.Conn, 1) + go func() { c, _ := ln.Accept(); accCh <- c }() + tl.push(&rtunpb.Frame{Sid: 31, Kind: &rtunpb.Frame_Syn{Syn: &rtunpb.Syn{Port: 1}}}) + + var c net.Conn + select { + case c = <-accCh: + case <-ctx.Done(): + t.Fatal("accept timeout") + } + + // Close twice is ok + require.NoError(t, c.Close()) + require.NoError(t, c.Close()) + + // Write after close should fail with ErrClosed + _, err = c.Write([]byte("x")) + require.ErrorIs(t, err, net.ErrClosed) + + // Read after close yields EOF or ErrClosed + _, err = c.Read(make([]byte, 1)) + require.True(t, errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed)) +} + +func TestVirtConnRemoteRstPropagatesToRead(t *testing.T) { + tl := newTestLink() + s := NewSession(tl) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + t.Cleanup(cancel) + ln, err := s.Listen(ctx, 1) + require.NoError(t, err) + t.Cleanup(func() { _ = ln.Close() }) + + accCh := make(chan net.Conn, 1) + go func() { c, _ := ln.Accept(); accCh <- c }() + tl.push(&rtunpb.Frame{Sid: 32, Kind: &rtunpb.Frame_Syn{Syn: &rtunpb.Syn{Port: 1}}}) + var c net.Conn + select { + case c = <-accCh: + case <-ctx.Done(): + t.Fatal("accept timeout") + } + + // Remote sends RST + tl.push(&rtunpb.Frame{Sid: 32, Kind: &rtunpb.Frame_Rst{Rst: &rtunpb.Rst{Code: rtunpb.RstCode_RST_CODE_INTERNAL}}}) + _, err = c.Read(make([]byte, 1)) + require.ErrorIs(t, err, ErrConnReset) +} + +func TestVirtConnLocalRemoteAddr(t *testing.T) { + tl := newTestLink() + s := NewSession(tl) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + t.Cleanup(cancel) + ln, err := s.Listen(ctx, 1) + require.NoError(t, err) + t.Cleanup(func() { _ = ln.Close() }) + + accCh := make(chan net.Conn, 1) + go func() { c, _ := ln.Accept(); accCh <- c }() + tl.push(&rtunpb.Frame{Sid: 33, Kind: &rtunpb.Frame_Syn{Syn: &rtunpb.Syn{Port: 1}}}) + var c net.Conn + select { + case c = <-accCh: + case <-ctx.Done(): + t.Fatal("accept timeout") + } + + require.NotNil(t, c.LocalAddr()) + require.NotNil(t, c.RemoteAddr()) +} diff --git a/pkg/rtun/transport/errors.go b/pkg/rtun/transport/errors.go new file mode 100644 index 000000000..c8b2d0757 --- /dev/null +++ b/pkg/rtun/transport/errors.go @@ -0,0 +1,10 @@ +package transport + +import "errors" + +var ( + // ErrConnReset indicates the remote reset the connection. + ErrConnReset = errors.New("rtun: connection reset") + // ErrTimeout indicates an operation exceeded its deadline. + ErrTimeout = errors.New("rtun: deadline exceeded") +) diff --git a/pkg/rtun/transport/listener.go b/pkg/rtun/transport/listener.go new file mode 100644 index 000000000..566f5cb65 --- /dev/null +++ b/pkg/rtun/transport/listener.go @@ -0,0 +1,72 @@ +package transport + +import ( + "net" + "sync" +) + +type rtunListener struct { + port uint32 + accepts chan net.Conn + mux *Session + mu sync.Mutex + closed bool + err error +} + +// Accept implements net.Listener and returns the next inbound connection. +func (l *rtunListener) Accept() (net.Conn, error) { + l.mu.Lock() + lerr := l.err + l.mu.Unlock() + if lerr != nil { + return nil, lerr + } + c, ok := <-l.accepts + if !ok { + l.mu.Lock() + lerr := l.err + l.mu.Unlock() + if lerr != nil { + return nil, lerr + } + return nil, net.ErrClosed + } + return c, nil +} + +// Close implements net.Listener and stops accepting new connections. +func (l *rtunListener) Close() error { + l.mu.Lock() + if l.closed { + l.mu.Unlock() + return nil + } + l.closed = true + l.mu.Unlock() + l.mux.removeListener(l.port) + close(l.accepts) + return nil +} + +// Addr implements net.Listener. +func (l *rtunListener) Addr() net.Addr { return rtunAddr{"rtun-listener"} } + +func (l *rtunListener) enqueue(c *virtConn) { + select { + case l.accepts <- c: + default: + // listener full, drop + c.handleRst(net.ErrClosed) + } +} + +func (l *rtunListener) closeWithErr(err error) { + l.mu.Lock() + l.err = err + if !l.closed { + l.closed = true + close(l.accepts) + } + l.mu.Unlock() +} diff --git a/pkg/rtun/transport/session.go b/pkg/rtun/transport/session.go new file mode 100644 index 000000000..fad824d9e --- /dev/null +++ b/pkg/rtun/transport/session.go @@ -0,0 +1,435 @@ +package transport + +import ( + "context" + "errors" + "net" + "sync" + "time" + + rtunpb "github.com/conductorone/baton-sdk/pb/c1/connectorapi/rtun/v1" + sdkmetrics "github.com/conductorone/baton-sdk/pkg/metrics" + "go.uber.org/zap" +) + +// Option configures a Session. +type Option func(*options) + +type options struct { + logger *zap.Logger + // allowedPorts is an allowlist of ports that the server is permitted to Open() toward the client. + // If nil or empty, all ports are allowed. + allowedPorts map[uint32]bool + // idleTimeout controls per-SID idle expiration; zero means use default (10m). Negative disables. + idleTimeout time.Duration + // metrics handler (optional) + metrics sdkmetrics.Handler +} + +// WithLogger sets a structured logger for the session. +func WithLogger(l *zap.Logger) Option { + return func(o *options) { + o.logger = l + } +} + +// WithAllowedPorts sets the explicit list of allowed ports for Session.Open(). +func WithAllowedPorts(ports []uint32) Option { + return func(o *options) { + if len(ports) == 0 { + o.allowedPorts = nil + return + } + if o.allowedPorts == nil { + o.allowedPorts = make(map[uint32]bool, len(ports)) + } + for _, p := range ports { + o.allowedPorts[p] = true + } + } +} + +// WithIdleTimeout sets the per-SID idle timeout. Zero selects the default (10m). Negative disables. +func WithIdleTimeout(d time.Duration) Option { + return func(o *options) { + o.idleTimeout = d + } +} + +// WithMetricsHandler injects a metrics handler for transport-level metrics. +func WithMetricsHandler(h sdkmetrics.Handler) Option { + return func(o *options) { + o.metrics = h + } +} + +// Link is the minimal adapter that the generated gRPC stream satisfies. +// It is intentionally small to decouple from gRPC specifics and simplify testing. +type Link interface { + Send(*rtunpb.Frame) error + Recv() (*rtunpb.Frame, error) + Context() context.Context +} + +// Session represents a per-link dispatcher with a single Recv loop and listener/conn registry. +type Session struct { + link Link + logger *zap.Logger + + mu sync.Mutex + started bool + closing bool + conns map[uint32]*virtConn + listeners map[uint32]*rtunListener + nextSID uint32 + closed closedSet // closed SIDs for late-frame detection + + // configuration + allowedPorts map[uint32]bool + idleTimeout time.Duration + + // metrics (optional) + m *transportMetrics +} + +// NewSession constructs a per-link session. The Recv loop starts on first listener registration. +func NewSession(link Link, opts ...Option) *Session { + var o options + for _, opt := range opts { + opt(&o) + } + logger := o.logger + if logger == nil { + logger = zap.NewNop() + } + idle := o.idleTimeout + if idle == 0 { + idle = 10 * time.Minute + } + s := &Session{ + link: link, + conns: make(map[uint32]*virtConn), + listeners: make(map[uint32]*rtunListener), + nextSID: 1, + logger: logger, + allowedPorts: o.allowedPorts, + idleTimeout: idle, + } + if o.metrics != nil { + s.m = newTransportMetrics(o.metrics, func(ctx context.Context) (int64, map[string]string) { + s.mu.Lock() + n := len(s.conns) + s.mu.Unlock() + return int64(n), nil + }) + } + return s +} + +// Listen exposes a net.Listener for a numeric port on this Session. +func (s *Session) Listen(ctx context.Context, port uint32, opts ...Option) (net.Listener, error) { + l := &rtunListener{ + port: port, + accepts: make(chan net.Conn, 512), // effectively listener backlog + mux: s, + } + if err := s.addListener(l); err != nil { + return nil, err + } + s.startOnce() + return l, nil +} + +func (s *Session) removeConn(sid uint32) int { + s.mu.Lock() + delete(s.conns, sid) + n := len(s.conns) + s.mu.Unlock() + return n +} + +func (s *Session) addListener(l *rtunListener) error { + s.mu.Lock() + defer s.mu.Unlock() + if s.closing { + return net.ErrClosed + } + if _, exists := s.listeners[l.port]; exists { + return errors.New("rtun: listener already exists for port") + } + if s.listeners == nil { + s.listeners = make(map[uint32]*rtunListener) + } + s.listeners[l.port] = l + return nil +} + +func (s *Session) removeListener(port uint32) { + s.mu.Lock() + delete(s.listeners, port) + s.mu.Unlock() +} + +func (s *Session) startOnce() { + s.mu.Lock() + if s.started { + s.mu.Unlock() + return + } + s.started = true + s.mu.Unlock() + go s.recvLoop() +} + +func (s *Session) recvLoop() { + for { + fr, err := s.link.Recv() + if err != nil { + s.mu.Lock() + for _, c := range s.conns { + c.handleRst(err) + } + for _, l := range s.listeners { + l.closeWithErr(err) + } + s.closing = true + s.mu.Unlock() + return + } + if fr == nil { + continue + } + sid := fr.GetSid() + if s.m != nil { + s.m.recordFrameRx(s.link.Context(), kindOf(fr)) + } + // sid==0 is invalid; treat SYN on sid 0 as fatal, drop other frames silently. + if sid == 0 { + if _, isSyn := fr.Kind.(*rtunpb.Frame_Syn); isSyn { + s.logger.Warn("protocol violation: SYN with sid 0; closing link") + s.failLocked(errors.New("rtun: protocol violation (sid 0)")) + return + } + continue + } + switch k := fr.Kind.(type) { + case *rtunpb.Frame_Syn: + port := k.Syn.GetPort() + s.mu.Lock() + l := s.listeners[port] + if l == nil { + s.mu.Unlock() + // Send RST outside lock + _ = s.link.Send(&rtunpb.Frame{Sid: sid, Kind: &rtunpb.Frame_Rst{Rst: &rtunpb.Rst{Code: rtunpb.RstCode_RST_CODE_NO_LISTENER}}}) + if s.m != nil { + s.m.recordRstSent(s.link.Context(), "no_listener") + } + continue + } + // Duplicate SYN on existing SID is a fatal protocol error for the link. + if _, exists := s.conns[sid]; exists || s.closed.IsClosed(sid) { + s.mu.Unlock() + s.logger.Warn("protocol violation: duplicate SYN for existing or closed SID; closing link", zap.Uint32("sid", sid)) + s.failLocked(errors.New("rtun: duplicate SYN")) + return + } + vc := newVirtConn(s, sid) + if s.conns == nil { + s.conns = make(map[uint32]*virtConn) + } + s.conns[sid] = vc + vc.startIdleTimer() + // connection count is observed via metrics callback + s.mu.Unlock() + l.enqueue(vc) + case *rtunpb.Frame_Data: + s.mu.Lock() + // Ignore if SID is closed (late frame) + if s.closed.IsClosed(sid) { + s.mu.Unlock() + continue + } + c := s.conns[sid] + payload := append([]byte(nil), k.Data.GetPayload()...) + if s.m != nil { + s.m.recordBytesRx(s.link.Context(), int64(len(payload))) + } + if c != nil { + s.mu.Unlock() + c.feedData(payload) + } else { + // defensive programming decision: DATA-before-SYN is a protocol error. RST and do not buffer. + s.mu.Unlock() + _ = s.link.Send(&rtunpb.Frame{Sid: sid, Kind: &rtunpb.Frame_Rst{Rst: &rtunpb.Rst{Code: rtunpb.RstCode_RST_CODE_INTERNAL}}}) + if s.m != nil { + s.m.recordRstSent(s.link.Context(), "protocol_violation") + } + continue + } + case *rtunpb.Frame_Fin: + s.mu.Lock() + c := s.conns[sid] + s.mu.Unlock() + if c != nil { + c.handleFin(k.Fin.GetAck()) + _ = s.removeConn(sid) + s.mu.Lock() + s.closed.Close(sid) + s.mu.Unlock() + } + case *rtunpb.Frame_Rst: + s.mu.Lock() + c := s.conns[sid] + s.mu.Unlock() + if c != nil { + if s.m != nil { + s.m.recordRstRecv(s.link.Context(), k.Rst.GetCode().String()) + } + c.handleRst(ErrConnReset) + _ = s.removeConn(sid) + s.mu.Lock() + s.closed.Close(sid) + s.mu.Unlock() + } + } + } +} + +// Start begins the Recv loop if not already started. +func (s *Session) Start() { s.startOnce() } + +// Open initiates a reverse connection to the remote client on the given port. +// It allocates a new SID, sends SYN, and returns a net.Conn bound to that SID. +// If ctx is canceled before SYN is sent, returns ctx.Err(). +func (s *Session) Open(ctx context.Context, port uint32) (net.Conn, error) { + // Check context before acquiring lock + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + + // Enforce allowed ports policy if configured + if s.allowedPorts != nil { + if !s.allowedPorts[port] { + return nil, errors.New("rtun: port not allowed by HELLO policy") + } + } + + s.mu.Lock() + if s.closing { + s.mu.Unlock() + return nil, net.ErrClosed + } + sid := s.nextSID + if sid == 0 { + sid = 1 + } + s.nextSID = sid + 1 + vc := newVirtConn(s, sid) + s.conns[sid] = vc + s.mu.Unlock() + + // Send SYN to remote + if err := s.link.Send(&rtunpb.Frame{Sid: sid, Kind: &rtunpb.Frame_Syn{Syn: &rtunpb.Syn{Port: port}}}); err != nil { + // Cleanup on failure + _ = s.removeConn(sid) + return nil, err + } + vc.startIdleTimer() + if s.m != nil { + s.m.recordFrameTx(s.link.Context(), "SYN") + } + return vc, nil +} + +// markClosed records the SID as closed to ignore late frames. +func (s *Session) markClosed(sid uint32) { + s.mu.Lock() + s.closed.Close(sid) + s.mu.Unlock() +} + +// failLocked closes all session resources and marks the session as closing. +// It must be called outside of s.mu (we only use it here where no lock is held), +// but it acquires the lock internally for safety. +func (s *Session) failLocked(err error) { + s.mu.Lock() + if s.closing { + s.mu.Unlock() + return + } + for _, c := range s.conns { + c.handleRst(err) + } + for _, l := range s.listeners { + l.closeWithErr(err) + } + s.closing = true + s.mu.Unlock() +} + +// metrics helpers. +type transportMetrics struct { + framesRx sdkmetrics.Int64Counter + framesTx sdkmetrics.Int64Counter + bytesRx sdkmetrics.Int64Counter + bytesTx sdkmetrics.Int64Counter + rstSent sdkmetrics.Int64Counter + rstRecv sdkmetrics.Int64Counter +} + +func newTransportMetrics(h sdkmetrics.Handler, observeSids func(ctx context.Context) (int64, map[string]string)) *transportMetrics { + m := &transportMetrics{ + framesRx: h.Int64Counter("rtun.transport.frames_rx_total", "transport frames received", sdkmetrics.Dimensionless), + framesTx: h.Int64Counter("rtun.transport.frames_tx_total", "transport frames sent", sdkmetrics.Dimensionless), + bytesRx: h.Int64Counter("rtun.transport.data_bytes_rx_total", "transport bytes received", sdkmetrics.Bytes), + bytesTx: h.Int64Counter("rtun.transport.data_bytes_tx_total", "transport bytes sent", sdkmetrics.Bytes), + rstSent: h.Int64Counter("rtun.transport.rst_sent_total", "RST frames sent by code", sdkmetrics.Dimensionless), + rstRecv: h.Int64Counter("rtun.transport.rst_recv_total", "RST frames received by code", sdkmetrics.Dimensionless), + } + // register observable gauge for active SIDs using provided callback + h.RegisterInt64ObservableGauge("rtun.transport.sids_active", "active SIDs per session", sdkmetrics.Dimensionless, observeSids) + return m +} + +func (m *transportMetrics) recordFrameRx(ctx context.Context, kind string) { + m.framesRx.Add(ctx, 1, map[string]string{"kind": kind}) +} + +func (m *transportMetrics) recordFrameTx(ctx context.Context, kind string) { + m.framesTx.Add(ctx, 1, map[string]string{"kind": kind}) +} + +func (m *transportMetrics) recordBytesRx(ctx context.Context, n int64) { + m.bytesRx.Add(ctx, n, nil) +} + +func (m *transportMetrics) recordBytesTx(ctx context.Context, n int64) { + m.bytesTx.Add(ctx, n, nil) +} + +func (m *transportMetrics) recordRstSent(ctx context.Context, code string) { + m.rstSent.Add(ctx, 1, map[string]string{"code": code}) +} + +func (m *transportMetrics) recordRstRecv(ctx context.Context, code string) { + m.rstRecv.Add(ctx, 1, map[string]string{"code": code}) +} + +func kindOf(fr *rtunpb.Frame) string { + switch fr.Kind.(type) { + case *rtunpb.Frame_Hello: + return "HELLO" + case *rtunpb.Frame_Syn: + return "SYN" + case *rtunpb.Frame_Data: + return "DATA" + case *rtunpb.Frame_Fin: + return "FIN" + case *rtunpb.Frame_Rst: + return "RST" + default: + return "UNKNOWN" + } +} diff --git a/pkg/rtun/transport/session_race_test.go b/pkg/rtun/transport/session_race_test.go new file mode 100644 index 000000000..d41ce769b --- /dev/null +++ b/pkg/rtun/transport/session_race_test.go @@ -0,0 +1,96 @@ +package transport + +import ( + "context" + "sync" + "testing" + "time" + + rtunpb "github.com/conductorone/baton-sdk/pb/c1/connectorapi/rtun/v1" +) + +// TestConcurrentOperationsNoPanic validates no races/panics under concurrent use. +// Run with -race to detect data races. +func TestConcurrentOperationsNoPanic(t *testing.T) { + tl := newTestLink() + s := NewSession(tl) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + var wg sync.WaitGroup + + // Concurrent Listen on ports 1-5 + for port := uint32(1); port <= 5; port++ { + wg.Add(1) + go func(p uint32) { + defer wg.Done() + ln, err := s.Listen(ctx, p) + if err == nil { + time.Sleep(10 * time.Millisecond) + ln.Close() + } + }(port) + } + + // Concurrent Open (reverse dial) + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + conn, err := s.Open(ctx, 1) + if err == nil { + time.Sleep(5 * time.Millisecond) + conn.Close() + } + }() + } + + // Feed random frames + for u := uint32(0); u < 20; u++ { + wg.Add(1) + go func(i uint32) { + defer wg.Done() + sid := 100 + i + tl.push(&rtunpb.Frame{Sid: sid, Kind: &rtunpb.Frame_Syn{Syn: &rtunpb.Syn{Port: 1}}}) + time.Sleep(2 * time.Millisecond) + tl.push(&rtunpb.Frame{Sid: sid, Kind: &rtunpb.Frame_Data{Data: &rtunpb.Data{Payload: []byte("x")}}}) + }(u) + } + + wg.Wait() + // No panic or deadlock = success +} + +func TestLateFrameAfterClose(t *testing.T) { + tl := newTestLink() + s := NewSession(tl) + ctx := context.Background() + + ln, _ := s.Listen(ctx, 1) + defer ln.Close() + + // Establish conn + tl.push(&rtunpb.Frame{Sid: 99, Kind: &rtunpb.Frame_Syn{Syn: &rtunpb.Syn{Port: 1}}}) + conn, _ := ln.Accept() + conn.Close() + + // Give close time to propagate + time.Sleep(20 * time.Millisecond) + + // Send late DATA; should not panic + tl.push(&rtunpb.Frame{Sid: 99, Kind: &rtunpb.Frame_Data{Data: &rtunpb.Data{Payload: []byte("late")}}}) + time.Sleep(10 * time.Millisecond) + // No panic = success +} + +func TestContextCanceledOpen(t *testing.T) { + tl := newTestLink() + s := NewSession(tl) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + _, err := s.Open(ctx, 1) + if err == nil { + t.Fatal("expected error on canceled context") + } +} diff --git a/pkg/rtun/transport/session_test.go b/pkg/rtun/transport/session_test.go new file mode 100644 index 000000000..d36ea932f --- /dev/null +++ b/pkg/rtun/transport/session_test.go @@ -0,0 +1,308 @@ +package transport + +import ( + "context" + "io" + "net" + "testing" + "time" + + rtunpb "github.com/conductorone/baton-sdk/pb/c1/connectorapi/rtun/v1" + "github.com/stretchr/testify/require" +) + +type testLink struct { + ctx context.Context + inCh chan *rtunpb.Frame + sent []*rtunpb.Frame +} + +func newTestLink() *testLink { + return &testLink{ctx: context.Background(), inCh: make(chan *rtunpb.Frame, 32)} +} + +func (t *testLink) Send(f *rtunpb.Frame) error { + t.sent = append(t.sent, f) + return nil +} + +func (t *testLink) Recv() (*rtunpb.Frame, error) { return <-t.inCh, nil } +func (t *testLink) Context() context.Context { return t.ctx } + +func (t *testLink) push(f *rtunpb.Frame) { t.inCh <- f } + +func TestSessionSynAcceptDataFin(t *testing.T) { + tl := newTestLink() + s := NewSession(tl) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + ln, err := s.Listen(ctx, 1) + require.NoError(t, err) + defer ln.Close() + + // Accept in background + accCh := make(chan net.Conn, 1) + go func() { + c, err := ln.Accept() + if err == nil { + accCh <- c + } + }() + + // Send SYN for port 1 on sid 5 + tl.push(&rtunpb.Frame{Sid: 5, Kind: &rtunpb.Frame_Syn{Syn: &rtunpb.Syn{Port: 1}}}) + var rwc net.Conn + select { + case rwc = <-accCh: + case <-ctx.Done(): + t.Fatal("accept timeout") + } + + // Deliver DATA from link to conn + tl.push(&rtunpb.Frame{Sid: 5, Kind: &rtunpb.Frame_Data{Data: &rtunpb.Data{Payload: []byte("hi")}}}) + buf := make([]byte, 2) + n, err := rwc.Read(buf) + require.NoError(t, err) + require.Equal(t, 2, n) + require.Equal(t, []byte("hi"), buf) + + // Write from conn to link; expect a DATA frame sent + _, err = rwc.Write([]byte("ok")) + require.NoError(t, err) + require.NotEmpty(t, tl.sent) + last := tl.sent[len(tl.sent)-1] + require.Equal(t, uint32(5), last.GetSid()) + require.IsType(t, &rtunpb.Frame_Data{}, last.Kind) + require.Equal(t, []byte("ok"), last.GetData().GetPayload()) + + // Send FIN from link; expect EOF on next read + tl.push(&rtunpb.Frame{Sid: 5, Kind: &rtunpb.Frame_Fin{Fin: &rtunpb.Fin{Ack: false}}}) + _, err = rwc.Read(make([]byte, 1)) + require.ErrorIs(t, err, io.EOF) + + // Close should send FIN + _ = rwc.Close() + require.NotEmpty(t, tl.sent) + foundFin := false + for _, f := range tl.sent { + if f.GetSid() == 5 { + if _, ok := f.Kind.(*rtunpb.Frame_Fin); ok { + foundFin = true + break + } + } + } + require.True(t, foundFin) +} + +func TestSessionSynNoListenerRst(t *testing.T) { + tl := newTestLink() + s := NewSession(tl) + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + + // Start session loop without listener by poking it: creating a throwaway listener on port 1 and closing is too involved, + // so instead we trigger start by creating and closing a short-lived listener which ensures the goroutine is running. + ln, err := s.Listen(ctx, 1) + require.NoError(t, err) + ln.Close() + + // Send SYN for a port without a listener + tl.push(&rtunpb.Frame{Sid: 9, Kind: &rtunpb.Frame_Syn{Syn: &rtunpb.Syn{Port: 2}}}) + // Allow loop to process + time.Sleep(10 * time.Millisecond) + require.NotEmpty(t, tl.sent) + last := tl.sent[len(tl.sent)-1] + require.Equal(t, uint32(9), last.GetSid()) + require.IsType(t, &rtunpb.Frame_Rst{}, last.Kind) + require.Equal(t, rtunpb.RstCode_RST_CODE_NO_LISTENER, last.GetRst().GetCode()) +} + +func TestVirtConnReadDeadline(t *testing.T) { + tl := newTestLink() + s := NewSession(tl) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + ln, err := s.Listen(ctx, 1) + require.NoError(t, err) + defer ln.Close() + + // Accept in background + accCh := make(chan net.Conn, 1) + go func() { + c, _ := ln.Accept() + accCh <- c + }() + + // Establish + tl.push(&rtunpb.Frame{Sid: 7, Kind: &rtunpb.Frame_Syn{Syn: &rtunpb.Syn{Port: 1}}}) + var rwc net.Conn + select { + case rwc = <-accCh: + case <-ctx.Done(): + t.Fatal("accept timeout") + } + + // Set a short read deadline and expect timeout + _ = rwc.SetReadDeadline(time.Now().Add(10 * time.Millisecond)) + buf := make([]byte, 1) + _, err = rwc.Read(buf) + require.ErrorIs(t, err, ErrTimeout) +} + +func TestVirtConnHalfCloseWriteAfterRemoteFin(t *testing.T) { + tl := newTestLink() + s := NewSession(tl) + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + ln, err := s.Listen(ctx, 1) + require.NoError(t, err) + defer ln.Close() + + accCh := make(chan net.Conn, 1) + go func() { + c, _ := ln.Accept() + accCh <- c + }() + tl.push(&rtunpb.Frame{Sid: 11, Kind: &rtunpb.Frame_Syn{Syn: &rtunpb.Syn{Port: 1}}}) + var rwc net.Conn + select { + case rwc = <-accCh: + case <-ctx.Done(): + t.Fatal("accept timeout") + } + // Remote sends FIN; we should still be able to Write() + tl.push(&rtunpb.Frame{Sid: 11, Kind: &rtunpb.Frame_Fin{Fin: &rtunpb.Fin{Ack: false}}}) + _, err = rwc.Write([]byte("after")) + require.NoError(t, err) +} + +func TestWriteFragmentationIntoMultipleDataFrames(t *testing.T) { + tl := newTestLink() + s := NewSession(tl) + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + ln, err := s.Listen(ctx, 1) + require.NoError(t, err) + defer ln.Close() + + accCh := make(chan net.Conn, 1) + go func() { c, _ := ln.Accept(); accCh <- c }() + tl.push(&rtunpb.Frame{Sid: 21, Kind: &rtunpb.Frame_Syn{Syn: &rtunpb.Syn{Port: 1}}}) + var rwc net.Conn + select { + case rwc = <-accCh: + case <-ctx.Done(): + t.Fatal("accept timeout") + } + + // Prepare a payload larger than maxWriteChunk (32KiB) to force fragmentation. + big := make([]byte, 32*1024+4096) + for i := range big { + big[i] = byte(i) + } + prev := len(tl.sent) + n, err := rwc.Write(big) + require.NoError(t, err) + require.Equal(t, len(big), n) + + // Collect new DATA frames for this SID and reassemble. + var collected []byte + for _, f := range tl.sent[prev:] { + if f.GetSid() == 21 { + if d := f.GetData(); d != nil { + collected = append(collected, d.GetPayload()...) + } + } + } + require.Equal(t, big, collected) +} + +func TestReadRemainderAcrossReads(t *testing.T) { + tl := newTestLink() + s := NewSession(tl) + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + ln, err := s.Listen(ctx, 1) + require.NoError(t, err) + defer ln.Close() + + accCh := make(chan net.Conn, 1) + go func() { c, _ := ln.Accept(); accCh <- c }() + tl.push(&rtunpb.Frame{Sid: 22, Kind: &rtunpb.Frame_Syn{Syn: &rtunpb.Syn{Port: 1}}}) + var rwc net.Conn + select { + case rwc = <-accCh: + case <-ctx.Done(): + t.Fatal("accept timeout") + } + + // Deliver 3 bytes, then read 2, then read 1 + tl.push(&rtunpb.Frame{Sid: 22, Kind: &rtunpb.Frame_Data{Data: &rtunpb.Data{Payload: []byte("xyz")}}}) + buf := make([]byte, 2) + n, err := rwc.Read(buf) + require.NoError(t, err) + require.Equal(t, 2, n) + require.Equal(t, []byte("xy"), buf) + buf2 := make([]byte, 2) + n, err = rwc.Read(buf2) + require.NoError(t, err) + require.Equal(t, 1, n) + require.Equal(t, []byte("z\x00"), buf2) // last byte remains zero +} + +func TestBackpressureInboundOverflow(t *testing.T) { + tl := newTestLink() + s := NewSession(tl) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + ln, err := s.Listen(ctx, 1) + require.NoError(t, err) + defer ln.Close() + + accCh := make(chan net.Conn, 1) + go func() { c, _ := ln.Accept(); accCh <- c }() + tl.push(&rtunpb.Frame{Sid: 23, Kind: &rtunpb.Frame_Syn{Syn: &rtunpb.Syn{Port: 1}}}) + var rwc net.Conn + select { + case rwc = <-accCh: + case <-ctx.Done(): + t.Fatal("accept timeout") + } + + // Push more frames than the read buffer can hold (capacity 256) without reading. + for i := 0; i < 300; i++ { + tl.push(&rtunpb.Frame{Sid: 23, Kind: &rtunpb.Frame_Data{Data: &rtunpb.Data{Payload: []byte{byte(i)}}}}) + } + // Give the loop a moment to overflow and close. + time.Sleep(20 * time.Millisecond) + buf := make([]byte, 1) + _, err = rwc.Read(buf) + require.Error(t, err) + require.NotEqual(t, io.EOF, err) +} + +func TestWriteDeadlineTimeout(t *testing.T) { + tl := newTestLink() + s := NewSession(tl) + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + ln, err := s.Listen(ctx, 1) + require.NoError(t, err) + defer ln.Close() + + accCh := make(chan net.Conn, 1) + go func() { c, _ := ln.Accept(); accCh <- c }() + tl.push(&rtunpb.Frame{Sid: 24, Kind: &rtunpb.Frame_Syn{Syn: &rtunpb.Syn{Port: 1}}}) + var rwc net.Conn + select { + case rwc = <-accCh: + case <-ctx.Done(): + t.Fatal("accept timeout") + } + // Set deadline in the past and attempt write. + _ = rwc.SetWriteDeadline(time.Now().Add(-1 * time.Millisecond)) + _, err = rwc.Write([]byte("x")) + require.ErrorIs(t, err, ErrTimeout) +} diff --git a/proto/c1/connectorapi/rtun/v1/gateway.proto b/proto/c1/connectorapi/rtun/v1/gateway.proto new file mode 100644 index 000000000..d7564aa48 --- /dev/null +++ b/proto/c1/connectorapi/rtun/v1/gateway.proto @@ -0,0 +1,49 @@ +syntax = "proto3"; + +package c1.connectorapi.rtun.v1; + +import "c1/connectorapi/rtun/v1/rtun.proto"; + +option go_package = "github.com/conductorone/baton-sdk/pb/c1/connectorapi/rtun/v1"; + +// ReverseDialer allows callers to establish connections to clients via the gateway. +// The gateway bridges caller streams to rtun sessions on the owner server. +service ReverseDialerService { + rpc Open(stream ReverseDialerServiceOpenRequest) returns (stream ReverseDialerServiceOpenResponse); +} + +// ReverseDialerServiceOpenRequest is sent from caller to gateway for the Open RPC. +message ReverseDialerServiceOpenRequest { + oneof kind { + OpenRequest open_req = 1; // initiate a connection (first message, or concurrent opens) + Frame frame = 10; // data/control frames (reuses rtun Frame) + } +} + +// ReverseDialerServiceOpenResponse is sent from gateway to caller for the Open RPC. +message ReverseDialerServiceOpenResponse { + oneof kind { + OpenResponse open_resp = 1; // handshake result + Frame frame = 10; // data/control frames (reuses rtun Frame) + } +} + +// OpenRequest initiates a reverse connection to a client. +// The caller proposes a gSID (gateway SID) for this connection to support concurrent opens. +message OpenRequest { + uint32 gsid = 1; // caller-proposed SID for this connection (must be unique per stream) + string client_id = 2; // target client (must be URL-safe) + uint32 port = 3; // target port on the client +} + +// OpenResponse indicates the result of an OpenRequest. +message OpenResponse { + uint32 gsid = 1; // echoed from OpenRequest + oneof result { + NotFound not_found = 2; // gateway doesn't own this client; caller should re-resolve + Opened opened = 3; // success; use the gSID for subsequent frames + } +} + +message NotFound {} +message Opened {} diff --git a/proto/c1/connectorapi/rtun/v1/rtun.proto b/proto/c1/connectorapi/rtun/v1/rtun.proto new file mode 100644 index 000000000..459433ecc --- /dev/null +++ b/proto/c1/connectorapi/rtun/v1/rtun.proto @@ -0,0 +1,57 @@ +syntax = "proto3"; + +package c1.connectorapi.rtun.v1; + +option go_package = "github.com/conductorone/baton-sdk/pb/c1/connectorapi/rtun/v1"; + +service ReverseTunnelService { + rpc Link(stream ReverseTunnelServiceLinkRequest) returns (stream ReverseTunnelServiceLinkResponse); +} + +message ReverseTunnelServiceLinkRequest { + Frame frame = 1; +} + +message ReverseTunnelServiceLinkResponse { + Frame frame = 1; +} + +message Frame { + uint32 sid = 1; // 0 reserved for control + oneof kind { + Hello hello = 10; // client -> server (first) + Syn syn = 11; // server -> client (reverse open) + Data data = 12; // either direction + Fin fin = 13; // either direction + Rst rst = 14; // either direction + } +} + +message Hello { + repeated uint32 ports = 1; + uint32 protocol = 2; +} + +message Syn { + uint32 port = 1; +} + +message Data { + bytes payload = 1; +} + +message Fin { + bool ack = 1; +} + +enum RstCode { + RST_CODE_UNSPECIFIED = 0; + RST_CODE_NO_LISTENER = 1; + RST_CODE_PORT_NOT_ADVERTISED = 2; + RST_CODE_TIMEOUT = 3; + RST_CODE_INTERNAL = 4; +} + +message Rst { + RstCode code = 1; +} diff --git a/vendor/google.golang.org/grpc/health/client.go b/vendor/google.golang.org/grpc/health/client.go new file mode 100644 index 000000000..740745c45 --- /dev/null +++ b/vendor/google.golang.org/grpc/health/client.go @@ -0,0 +1,117 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package health + +import ( + "context" + "fmt" + "io" + "time" + + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/connectivity" + healthpb "google.golang.org/grpc/health/grpc_health_v1" + "google.golang.org/grpc/internal" + "google.golang.org/grpc/internal/backoff" + "google.golang.org/grpc/status" +) + +var ( + backoffStrategy = backoff.DefaultExponential + backoffFunc = func(ctx context.Context, retries int) bool { + d := backoffStrategy.Backoff(retries) + timer := time.NewTimer(d) + select { + case <-timer.C: + return true + case <-ctx.Done(): + timer.Stop() + return false + } + } +) + +func init() { + internal.HealthCheckFunc = clientHealthCheck +} + +const healthCheckMethod = "/grpc.health.v1.Health/Watch" + +// This function implements the protocol defined at: +// https://github.com/grpc/grpc/blob/master/doc/health-checking.md +func clientHealthCheck(ctx context.Context, newStream func(string) (any, error), setConnectivityState func(connectivity.State, error), service string) error { + tryCnt := 0 + +retryConnection: + for { + // Backs off if the connection has failed in some way without receiving a message in the previous retry. + if tryCnt > 0 && !backoffFunc(ctx, tryCnt-1) { + return nil + } + tryCnt++ + + if ctx.Err() != nil { + return nil + } + setConnectivityState(connectivity.Connecting, nil) + rawS, err := newStream(healthCheckMethod) + if err != nil { + continue retryConnection + } + + s, ok := rawS.(grpc.ClientStream) + // Ideally, this should never happen. But if it happens, the server is marked as healthy for LBing purposes. + if !ok { + setConnectivityState(connectivity.Ready, nil) + return fmt.Errorf("newStream returned %v (type %T); want grpc.ClientStream", rawS, rawS) + } + + if err = s.SendMsg(&healthpb.HealthCheckRequest{Service: service}); err != nil && err != io.EOF { + // Stream should have been closed, so we can safely continue to create a new stream. + continue retryConnection + } + s.CloseSend() + + resp := new(healthpb.HealthCheckResponse) + for { + err = s.RecvMsg(resp) + + // Reports healthy for the LBing purposes if health check is not implemented in the server. + if status.Code(err) == codes.Unimplemented { + setConnectivityState(connectivity.Ready, nil) + return err + } + + // Reports unhealthy if server's Watch method gives an error other than UNIMPLEMENTED. + if err != nil { + setConnectivityState(connectivity.TransientFailure, fmt.Errorf("connection active but received health check RPC error: %v", err)) + continue retryConnection + } + + // As a message has been received, removes the need for backoff for the next retry by resetting the try count. + tryCnt = 0 + if resp.Status == healthpb.HealthCheckResponse_SERVING { + setConnectivityState(connectivity.Ready, nil) + } else { + setConnectivityState(connectivity.TransientFailure, fmt.Errorf("connection active but health check failed. status=%s", resp.Status)) + } + } + } +} diff --git a/vendor/google.golang.org/grpc/health/logging.go b/vendor/google.golang.org/grpc/health/logging.go new file mode 100644 index 000000000..83c6acf55 --- /dev/null +++ b/vendor/google.golang.org/grpc/health/logging.go @@ -0,0 +1,23 @@ +/* + * + * Copyright 2020 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package health + +import "google.golang.org/grpc/grpclog" + +var logger = grpclog.Component("health_service") diff --git a/vendor/google.golang.org/grpc/health/producer.go b/vendor/google.golang.org/grpc/health/producer.go new file mode 100644 index 000000000..f938e5790 --- /dev/null +++ b/vendor/google.golang.org/grpc/health/producer.go @@ -0,0 +1,106 @@ +/* + * + * Copyright 2024 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package health + +import ( + "context" + "sync" + + "google.golang.org/grpc" + "google.golang.org/grpc/balancer" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/connectivity" + "google.golang.org/grpc/internal" + "google.golang.org/grpc/status" +) + +func init() { + producerBuilderSingleton = &producerBuilder{} + internal.RegisterClientHealthCheckListener = registerClientSideHealthCheckListener +} + +type producerBuilder struct{} + +var producerBuilderSingleton *producerBuilder + +// Build constructs and returns a producer and its cleanup function. +func (*producerBuilder) Build(cci any) (balancer.Producer, func()) { + p := &healthServiceProducer{ + cc: cci.(grpc.ClientConnInterface), + cancel: func() {}, + } + return p, func() { + p.mu.Lock() + defer p.mu.Unlock() + p.cancel() + } +} + +type healthServiceProducer struct { + // The following fields are initialized at build time and read-only after + // that and therefore do not need to be guarded by a mutex. + cc grpc.ClientConnInterface + + mu sync.Mutex + cancel func() +} + +// registerClientSideHealthCheckListener accepts a listener to provide server +// health state via the health service. +func registerClientSideHealthCheckListener(ctx context.Context, sc balancer.SubConn, serviceName string, listener func(balancer.SubConnState)) func() { + pr, closeFn := sc.GetOrBuildProducer(producerBuilderSingleton) + p := pr.(*healthServiceProducer) + p.mu.Lock() + defer p.mu.Unlock() + p.cancel() + if listener == nil { + return closeFn + } + + ctx, cancel := context.WithCancel(ctx) + p.cancel = cancel + + go p.startHealthCheck(ctx, sc, serviceName, listener) + return closeFn +} + +func (p *healthServiceProducer) startHealthCheck(ctx context.Context, sc balancer.SubConn, serviceName string, listener func(balancer.SubConnState)) { + newStream := func(method string) (any, error) { + return p.cc.NewStream(ctx, &grpc.StreamDesc{ServerStreams: true}, method) + } + + setConnectivityState := func(state connectivity.State, err error) { + listener(balancer.SubConnState{ + ConnectivityState: state, + ConnectionError: err, + }) + } + + // Call the function through the internal variable as tests use it for + // mocking. + err := internal.HealthCheckFunc(ctx, newStream, setConnectivityState, serviceName) + if err == nil { + return + } + if status.Code(err) == codes.Unimplemented { + logger.Errorf("Subchannel health check is unimplemented at server side, thus health check is disabled for SubConn %p", sc) + } else { + logger.Errorf("Health checking failed for SubConn %p: %v", sc, err) + } +} diff --git a/vendor/google.golang.org/grpc/health/server.go b/vendor/google.golang.org/grpc/health/server.go new file mode 100644 index 000000000..d4b4b7081 --- /dev/null +++ b/vendor/google.golang.org/grpc/health/server.go @@ -0,0 +1,163 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package health provides a service that exposes server's health and it must be +// imported to enable support for client-side health checks. +package health + +import ( + "context" + "sync" + + "google.golang.org/grpc/codes" + healthgrpc "google.golang.org/grpc/health/grpc_health_v1" + healthpb "google.golang.org/grpc/health/grpc_health_v1" + "google.golang.org/grpc/status" +) + +// Server implements `service Health`. +type Server struct { + healthgrpc.UnimplementedHealthServer + mu sync.RWMutex + // If shutdown is true, it's expected all serving status is NOT_SERVING, and + // will stay in NOT_SERVING. + shutdown bool + // statusMap stores the serving status of the services this Server monitors. + statusMap map[string]healthpb.HealthCheckResponse_ServingStatus + updates map[string]map[healthgrpc.Health_WatchServer]chan healthpb.HealthCheckResponse_ServingStatus +} + +// NewServer returns a new Server. +func NewServer() *Server { + return &Server{ + statusMap: map[string]healthpb.HealthCheckResponse_ServingStatus{"": healthpb.HealthCheckResponse_SERVING}, + updates: make(map[string]map[healthgrpc.Health_WatchServer]chan healthpb.HealthCheckResponse_ServingStatus), + } +} + +// Check implements `service Health`. +func (s *Server) Check(_ context.Context, in *healthpb.HealthCheckRequest) (*healthpb.HealthCheckResponse, error) { + s.mu.RLock() + defer s.mu.RUnlock() + if servingStatus, ok := s.statusMap[in.Service]; ok { + return &healthpb.HealthCheckResponse{ + Status: servingStatus, + }, nil + } + return nil, status.Error(codes.NotFound, "unknown service") +} + +// Watch implements `service Health`. +func (s *Server) Watch(in *healthpb.HealthCheckRequest, stream healthgrpc.Health_WatchServer) error { + service := in.Service + // update channel is used for getting service status updates. + update := make(chan healthpb.HealthCheckResponse_ServingStatus, 1) + s.mu.Lock() + // Puts the initial status to the channel. + if servingStatus, ok := s.statusMap[service]; ok { + update <- servingStatus + } else { + update <- healthpb.HealthCheckResponse_SERVICE_UNKNOWN + } + + // Registers the update channel to the correct place in the updates map. + if _, ok := s.updates[service]; !ok { + s.updates[service] = make(map[healthgrpc.Health_WatchServer]chan healthpb.HealthCheckResponse_ServingStatus) + } + s.updates[service][stream] = update + defer func() { + s.mu.Lock() + delete(s.updates[service], stream) + s.mu.Unlock() + }() + s.mu.Unlock() + + var lastSentStatus healthpb.HealthCheckResponse_ServingStatus = -1 + for { + select { + // Status updated. Sends the up-to-date status to the client. + case servingStatus := <-update: + if lastSentStatus == servingStatus { + continue + } + lastSentStatus = servingStatus + err := stream.Send(&healthpb.HealthCheckResponse{Status: servingStatus}) + if err != nil { + return status.Error(codes.Canceled, "Stream has ended.") + } + // Context done. Removes the update channel from the updates map. + case <-stream.Context().Done(): + return status.Error(codes.Canceled, "Stream has ended.") + } + } +} + +// SetServingStatus is called when need to reset the serving status of a service +// or insert a new service entry into the statusMap. +func (s *Server) SetServingStatus(service string, servingStatus healthpb.HealthCheckResponse_ServingStatus) { + s.mu.Lock() + defer s.mu.Unlock() + if s.shutdown { + logger.Infof("health: status changing for %s to %v is ignored because health service is shutdown", service, servingStatus) + return + } + + s.setServingStatusLocked(service, servingStatus) +} + +func (s *Server) setServingStatusLocked(service string, servingStatus healthpb.HealthCheckResponse_ServingStatus) { + s.statusMap[service] = servingStatus + for _, update := range s.updates[service] { + // Clears previous updates, that are not sent to the client, from the channel. + // This can happen if the client is not reading and the server gets flow control limited. + select { + case <-update: + default: + } + // Puts the most recent update to the channel. + update <- servingStatus + } +} + +// Shutdown sets all serving status to NOT_SERVING, and configures the server to +// ignore all future status changes. +// +// This changes serving status for all services. To set status for a particular +// services, call SetServingStatus(). +func (s *Server) Shutdown() { + s.mu.Lock() + defer s.mu.Unlock() + s.shutdown = true + for service := range s.statusMap { + s.setServingStatusLocked(service, healthpb.HealthCheckResponse_NOT_SERVING) + } +} + +// Resume sets all serving status to SERVING, and configures the server to +// accept all future status changes. +// +// This changes serving status for all services. To set status for a particular +// services, call SetServingStatus(). +func (s *Server) Resume() { + s.mu.Lock() + defer s.mu.Unlock() + s.shutdown = false + for service := range s.statusMap { + s.setServingStatusLocked(service, healthpb.HealthCheckResponse_SERVING) + } +} diff --git a/vendor/modules.txt b/vendor/modules.txt index 299b72801..d7cad6e22 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -623,6 +623,7 @@ google.golang.org/grpc/encoding/proto google.golang.org/grpc/experimental/stats google.golang.org/grpc/grpclog google.golang.org/grpc/grpclog/internal +google.golang.org/grpc/health google.golang.org/grpc/health/grpc_health_v1 google.golang.org/grpc/internal google.golang.org/grpc/internal/backoff