diff --git a/yafsm/constants.go b/yafsm/constants.go new file mode 100644 index 0000000..a2ce750 --- /dev/null +++ b/yafsm/constants.go @@ -0,0 +1,6 @@ +package yafsm + +const ( + stateKey = "state" + stateDataKey = "stateData" +) diff --git a/yafsm/entityfsm.go b/yafsm/entityfsm.go new file mode 100644 index 0000000..b387bd7 --- /dev/null +++ b/yafsm/entityfsm.go @@ -0,0 +1,77 @@ +package yafsm + +import ( + "context" + + "github.com/YaCodeDev/GoYaCodeDevUtils/yaerrors" +) + +// EntityFSMStorage is a wrapper over FSM to work with specific entity (user, chat, etc). +type EntityFSMStorage struct { + storage FSM + uid string +} + +// NewUserFSMStorage creates a new EntityFSMStorage for a specific user ID. +// +// Example usage: +// +// userFSMStorage := NewUserFSMStorage(fsmStorage, "user123") +func NewUserFSMStorage( + storage FSM, + uid string, +) *EntityFSMStorage { + return &EntityFSMStorage{ + storage: storage, + uid: uid, + } +} + +// SetState sets the state for the entity. +// +// Example usage: +// +// err := userFSMStorage.SetState(ctx, &SomeState{Field: "value"}) +// +// if err != nil { +// // handle error +// } +func (b *EntityFSMStorage) SetState( + ctx context.Context, + stateData State, +) yaerrors.Error { + return b.storage.SetState(ctx, b.uid, stateData) +} + +// GetState retrieves the current state and its data for the entity. +// +// Example usage: +// +// stateName, stateData, err := userFSMStorage.GetState(ctx) +// +// if err != nil { +// // handle error +// } +func (b *EntityFSMStorage) GetState( + ctx context.Context, +) (string, stateDataMarshalled, yaerrors.Error) { // nolint: revive + return b.storage.GetState(ctx, b.uid) +} + +// GetStateData unmarshals the state data into the provided empty state struct. +// +// Example usage: +// +// var stateData SomeState +// +// err := userFSMStorage.GetStateData(marshalledData, &stateData) +// +// if err != nil { +// // handle error +// } +func (b *EntityFSMStorage) GetStateData( + stateData stateDataMarshalled, + emptyState State, +) yaerrors.Error { + return b.storage.GetStateData(stateData, emptyState) +} diff --git a/yafsm/fsm.go b/yafsm/fsm.go new file mode 100644 index 0000000..8bcd15a --- /dev/null +++ b/yafsm/fsm.go @@ -0,0 +1,196 @@ +package yafsm + +import ( + "context" + "encoding/json" + "net/http" + "reflect" + + "github.com/YaCodeDev/GoYaCodeDevUtils/yacache" + "github.com/YaCodeDev/GoYaCodeDevUtils/yaerrors" +) + +// State is an interface that all states must implement. +type State interface { + StateName() string +} + +// BaseState provides a default implementation of the State interface. +type BaseState[T State] struct{} + +// StateName returns the name of the state type. +func (BaseState[T]) StateName() string { + var zero T + + t := reflect.TypeOf(zero) + if t.Kind() == reflect.Pointer { + t = t.Elem() + } + + return t.Name() +} + +// Empty state is implementation of State interface with no data. +type EmptyState struct { + BaseState[EmptyState] +} + +// stateDataMarshalled is a type alias for marshalled state data. +type stateDataMarshalled string + +// StateAndData is a struct that holds the state name and its marshalled data. +type StateAndData struct { + State string `json:"state"` + StateData string `json:"stateData"` +} + +// FSM is an interface for finite state machine storage.\ +type FSM interface { + SetState(ctx context.Context, uid string, state State) yaerrors.Error + GetState(ctx context.Context, uid string) (string, stateDataMarshalled, yaerrors.Error) + GetStateData(stateData stateDataMarshalled, emptyState State) yaerrors.Error +} + +// DefaultFSMStorage is a default implementation of the FSM interface using yacache. +type DefaultFSMStorage[T yacache.Container] struct { + storage yacache.Cache[T] + defaultState State +} + +// NewDefaultFSMStorage creates a new instance of DefaultFSMStorage. +// +// Example usage: +// +// cache := yacache.NewCache(redisClient) +// +// fsmStorage := fsm.NewDefaultFSMStorage(cache, fsm.EmptyState{}) +func NewDefaultFSMStorage[T yacache.Container]( + storage yacache.Cache[T], + defaultState State, +) *DefaultFSMStorage[T] { + return &DefaultFSMStorage[T]{ + storage: storage, + defaultState: defaultState, + } +} + +// SetState sets the state for a given user ID. +// The state data is marshalled to JSON before being stored. +// +// Example usage: +// +// err := fsmStorage.SetState(ctx, "123", &SomeState{Field: "value"}) +func (b *DefaultFSMStorage[T]) SetState( + ctx context.Context, + uid string, + stateData State, +) yaerrors.Error { + val, err := json.Marshal(stateData) + if err != nil { + return yaerrors.FromError( + http.StatusInternalServerError, + err, + "failed to marshal state data", + ) + } + + val, err = json.Marshal(StateAndData{ + State: stateData.StateName(), + StateData: string(val), + }) + if err != nil { + return yaerrors.FromError( + http.StatusInternalServerError, + err, + "failed to marshal state data", + ) + } + + return b.storage.Set(ctx, uid, string(val), 0) +} + +// GetState retrieves the current state and its marshalled data for a given user ID. +// If no state is found, it returns the default state. +// +// Example usage: +// +// stateName, stateData, err := fsmStorage.GetState(ctx, "123") +func (b *DefaultFSMStorage[T]) GetState( + ctx context.Context, + uid string, +) (string, stateDataMarshalled, yaerrors.Error) { // nolint: revive + data, err := b.storage.Get(ctx, uid) + if err != nil { + return b.defaultState.StateName(), "", nil + } + + var stateAndData map[string]string + + if err := json.Unmarshal([]byte(data), &stateAndData); err != nil { + return "", "", yaerrors.FromError( + http.StatusInternalServerError, + err, + "failed to unmarshal state data map", + ) + } + + state, ok := stateAndData[stateKey] + + if !ok { + return "", "", yaerrors.FromString( + http.StatusNotFound, + "failed to get state", + ) + } + + return state, stateDataMarshalled(data), nil +} + +// GetStateData unmarshals the state data into the provided empty state struct. +// +// Example usage: +// +// var stateData SomeState +// +// err := fsmStorage.GetStateData(marshalledData, &stateData) +// +// if err != nil { +// // handle error +// } +func (b *DefaultFSMStorage[T]) GetStateData( + stateData stateDataMarshalled, + emptyState State, +) yaerrors.Error { + if stateData == "" { + return nil + } + + var stateAndData map[string]string + + if err := json.Unmarshal([]byte(stateData), &stateAndData); err != nil { + return yaerrors.FromError( + http.StatusInternalServerError, + err, + "failed to unmarshal state data map", + ) + } + + stateDataMarshalled, ok := stateAndData[stateDataKey] + + if !ok { + return yaerrors.FromString( + http.StatusNotFound, + "failed to get state data", + ) + } + + if err := json.Unmarshal([]byte(stateDataMarshalled), emptyState); err != nil { + return yaerrors.FromError( + http.StatusInternalServerError, + err, + "failed to unmarshal state data", + ) + } + + return nil +} diff --git a/yafsm/fsm_test.go b/yafsm/fsm_test.go new file mode 100644 index 0000000..c8b3c51 --- /dev/null +++ b/yafsm/fsm_test.go @@ -0,0 +1,101 @@ +package yafsm_test + +import ( + "context" + "encoding/json" + "errors" + "testing" + + "github.com/YaCodeDev/GoYaCodeDevUtils/yacache" + "github.com/YaCodeDev/GoYaCodeDevUtils/yafsm" +) + +type ExampleState struct { + yafsm.BaseState[ExampleState] + + Param string `json:"param"` +} + +func TestFSMStorage_SetGetRoundTrip(t *testing.T) { + ctx := context.Background() + + cache := yacache.NewCache(yacache.NewMemoryContainer()) + + fsm := yafsm.NewDefaultFSMStorage(cache, yafsm.EmptyState{}) + + uid := "12345" + wantParam := "exampleparam" + + // 1) set + if err := fsm.SetState(ctx, uid, ExampleState{Param: wantParam}); err != nil { + t.Fatalf("SetState failed: %v", err) + } + + // 2) get state name + raw payload + stateName, raw, err := fsm.GetState(ctx, uid) + if err != nil { + t.Fatalf("GetState failed: %v", err) + } + + if stateName != (ExampleState{}).StateName() { + t.Fatalf("unexpected state name: want %q, got %q", + (ExampleState{}).StateName(), stateName) + } + + // 3) unmarshal into struct + var got ExampleState + if err := fsm.GetStateData(raw, &got); err != nil { + t.Fatalf("GetStateData failed: %v", err) + } + + if got.Param != wantParam { + t.Fatalf("unexpected param: want %q, got %q", wantParam, got.Param) + } +} + +func TestFSMStorage_DefaultStateReturned(t *testing.T) { + ctx := context.Background() + + cache := yacache.NewCache(yacache.NewMemoryContainer()) + fsm := yafsm.NewDefaultFSMStorage(cache, yafsm.EmptyState{}) + + uid := "non-existent" + + name, raw, err := fsm.GetState(ctx, uid) + if err != nil { + t.Fatalf("GetState failed: %v", err) + } + + if name != (yafsm.EmptyState{}).StateName() { + t.Fatalf("expected default state name %q, got %q", + (yafsm.EmptyState{}).StateName(), name) + } + + if raw != "" { + t.Fatalf("expected empty raw data, got %q", raw) + } +} + +func TestFSMStorage_CorruptedPayload(t *testing.T) { + ctx := context.Background() + + cache := yacache.NewCache(yacache.NewMemoryContainer()) + fsm := yafsm.NewDefaultFSMStorage(cache, yafsm.EmptyState{}) + + uid := "bad:user" + + err := cache.Set(ctx, uid, "{not:a:json}", 0) + if err != nil { + t.Fatalf("failed to set corrupted data: %v", err) + } + + _, _, err = fsm.GetState(ctx, uid) + if err == nil { + t.Fatal("expected error on corrupted JSON, got nil") + } + + var syntaxErr *json.SyntaxError + if !errors.As(err, &syntaxErr) { + t.Fatalf("expected json.SyntaxError, got %v", err) + } +} diff --git a/yatgbot/dispatcher.go b/yatgbot/dispatcher.go new file mode 100644 index 0000000..0ea0677 --- /dev/null +++ b/yatgbot/dispatcher.go @@ -0,0 +1,247 @@ +package yatgbot + +import ( + "context" + "errors" + "net/http" + "strconv" + + "github.com/YaCodeDev/GoYaCodeDevUtils/yaerrors" + "github.com/YaCodeDev/GoYaCodeDevUtils/yafsm" + "github.com/YaCodeDev/GoYaCodeDevUtils/yalocales" + "github.com/YaCodeDev/GoYaCodeDevUtils/yalogger" + "github.com/YaCodeDev/GoYaCodeDevUtils/yatgbot/messagequeue" + "github.com/YaCodeDev/GoYaCodeDevUtils/yatgclient" + "github.com/gotd/td/tg" +) + +// Dispatcher is responsible for routing updates to the appropriate handlers based on defined routes and filters. +type Dispatcher struct { + FSMStore yafsm.FSM + Log yalogger.Logger + BotUser *tg.User + MessageDispatcher *messagequeue.Dispatcher + Localizer yalocales.Localizer + Client *yatgclient.Client + MainRouter *RouterGroup +} + +// UpdateData holds the dependencies required for dispatching an update. +type UpdateData struct { + userID int64 + chatID int64 + ent tg.Entities + update tg.UpdateClass + inputPeer tg.InputPeerClass +} + +// dispatch processes the update by checking filters and executing the appropriate handler. +// It also supports nested routers by dispatching to sub-routers if no local route matches. +func (r *Dispatcher) dispatch(ctx context.Context, deps UpdateData) yaerrors.Error { + userFSMStorage := yafsm.NewUserFSMStorage( + r.FSMStore, + strconv.FormatInt(deps.chatID, 10), + ) + + r.Log.Debugf("Processing update: %+v with entities: %+v", deps.update, deps.ent) + + for _, rt := range r.MainRouter.routes { + ok, err := r.checkFilters( + ctx, + FilterDependencies{ + update: deps.update, + storage: *userFSMStorage, + userID: deps.userID, + }, + rt.filters) + if err != nil { + return yaerrors.FromErrorWithLog( + http.StatusInternalServerError, + err, + "failed to apply filters", + r.Log, + ) + } + + if !ok { + r.Log.Debugf("Filters not passed for %T", deps.update) + + continue + } + + var localizer yalocales.Localizer + + if user, ok := deps.ent.Users[deps.userID]; ok && user.LangCode != "" { + localizer, err = r.Localizer.DeriveNewDefaultLang(user.LangCode) + if err != nil { + if !errors.Is(err, yalocales.ErrInvalidLanguage) { + return yaerrors.FromErrorWithLog( + http.StatusInternalServerError, + err, + "failed to derive localizer", + r.Log, + ) + } + + localizer = r.Localizer + } + + r.Log.Debugf("Using user %d language: %s", deps.userID, user.LangCode) + } + + hdata := &HandlerData{ + Entities: deps.ent, + Update: deps.update, + UserID: deps.userID, + Peer: deps.inputPeer, + StateStorage: userFSMStorage, + Log: r.Log, + Dispatcher: r.MessageDispatcher, + Localizer: localizer, + Client: r.Client, + } + + err = chainMiddleware( + rt.handler, + r.MainRouter.collectMiddlewares()...)( + ctx, + hdata, + deps.update, + ) + if err != nil { + if errors.Is(err, ErrRouteMismatch) { + continue + } + + return err.Wrap("handler execution failed") + } + + return nil + } + + for _, sub := range r.MainRouter.sub { + r.MainRouter = sub + + err := r.dispatch(ctx, deps) + if err != nil { + return err.Wrap("sub-router dispatch failed") + } + } + + return nil +} + +// checkFilters checks the filters of the current router and its parents recursively. +func (r *Dispatcher) checkFilters( + ctx context.Context, + deps FilterDependencies, + local []Filter, +) (bool, yaerrors.Error) { + // 1) Build the chain from current group up to root. + var chain []*RouterGroup + for g := r.MainRouter; g != nil; g = g.parent { + chain = append(chain, g) + } + + // 2) Reverse so we run filters from root -> current. + for i, j := 0, len(chain)-1; i < j; i, j = i+1, j-1 { + chain[i], chain[j] = chain[j], chain[i] + } + + // 3) Run base filters in that order. + for _, grp := range chain { + for _, f := range grp.base { + ok, err := f(ctx, deps) + if err != nil { + return false, err.Wrap("base filter check failed") + } + + if !ok { + return false, nil + } + } + } + + // 4) Run local (route) filters last. + for _, f := range local { + ok, err := f(ctx, deps) + if err != nil { + return false, err.Wrap("local filter check failed") + } + + if !ok { + return false, nil + } + } + + return true, nil +} + +// makeInputPeer converts a tg.PeerClass to a tg.InputPeerClass using the provided entities. +func makeInputPeer(p tg.PeerClass, ents tg.Entities) (tg.InputPeerClass, bool) { + switch v := p.(type) { + case *tg.PeerUser: + u, ok := ents.Users[v.UserID] + if !ok { + return nil, false + } + + return &tg.InputPeerUser{ + UserID: v.UserID, + AccessHash: u.AccessHash, + }, true + + case *tg.PeerChat: + return &tg.InputPeerChat{ChatID: v.ChatID}, true + + case *tg.PeerChannel: + c, ok := ents.Channels[v.ChannelID] + if !ok { + return nil, false + } + + return &tg.InputPeerChannel{ + ChannelID: v.ChannelID, + AccessHash: c.AccessHash, + }, true + } + + return nil, false +} + +// getChatID extracts the chat ID from a tg.PeerClass using the provided entities. +func getChatID(peer tg.PeerClass, ents tg.Entities) (int64, bool) { + switch v := peer.(type) { + case *tg.PeerUser: + return v.UserID, true + case *tg.PeerChat: + return v.ChatID, true + case *tg.PeerChannel: + c, ok := ents.Channels[v.ChannelID] + if !ok { + return 0, false + } + + return c.ID, true + default: + return 0, false + } +} + +// getUserID extracts the user ID from a tg.PeerClass or from the FromID field if available. +func getUserID(peer tg.PeerClass, fromID tg.PeerClass) (int64, bool) { + switch v := peer.(type) { + case *tg.PeerUser: + return v.UserID, true + + case *tg.PeerChat: + if fromUser, ok := fromID.(*tg.PeerUser); ok { + return fromUser.UserID, true + } + + return 0, false + + default: + return 0, false + } +} diff --git a/yatgbot/filter.go b/yatgbot/filter.go new file mode 100644 index 0000000..760c47f --- /dev/null +++ b/yatgbot/filter.go @@ -0,0 +1,186 @@ +package yatgbot + +import ( + "context" + "net/http" + "regexp" + "strings" + + "github.com/gotd/td/tg" + + "github.com/YaCodeDev/GoYaCodeDevUtils/yaerrors" + "github.com/YaCodeDev/GoYaCodeDevUtils/yafsm" +) + +// Filter is a function that determines whether a given update should be processed +type Filter func(ctx context.Context, deps FilterDependencies) (bool, yaerrors.Error) + +// FilterDependencies holds the dependencies required by filters +type FilterDependencies struct { + storage yafsm.EntityFSMStorage + userID int64 + update tg.UpdateClass +} + +// StateIs creates a filter that checks if the user's state matches any of the provided states. +// +// Example usage: +// +// router.OnMessage(YourMessageHandler, router.StateIs("StateA", "StateB")) +func StateIs(want ...string) Filter { + wanted := make(map[string]struct{}, len(want)) + + for _, s := range want { + wanted[s] = struct{}{} + } + + return func(ctx context.Context, deps FilterDependencies) (bool, yaerrors.Error) { + state, _, err := deps.storage.GetState(ctx) + if err != nil { + return false, yaerrors.FromError( + http.StatusInternalServerError, + err, "failed to get state for user %d", + ) + } + + _, ok := wanted[state] + + return ok, nil + } +} + +// TextEq creates a filter that checks if the message text equals the specified string. +// +// Example usage: +// +// router.OnMessage(YourMessageHandler, router.TextEq("Hello")) +func TextEq(want string) Filter { + return func(_ context.Context, deps FilterDependencies) (bool, yaerrors.Error) { + if m, ok := ExtractMessageFromUpdate(deps.update); ok && m.Message == want { + return true, nil + } + + return false, nil + } +} + +// TextRegex creates a filter that checks if the message text matches the specified regex. +// +// Example usage: +// +// router.OnMessage(YourMessageHandler, router.TextRegex(regexp.MustCompile(`^Hello.*`))) +func TextRegex(re *regexp.Regexp) Filter { + return func(_ context.Context, deps FilterDependencies) (bool, yaerrors.Error) { + if m, ok := ExtractMessageFromUpdate(deps.update); ok && re.MatchString(m.Message) { + return true, nil + } + + return false, nil + } +} + +// CallbackEq creates a filter that checks if the callback query data equals the specified string. +// +// Example usage: +// +// router.OnCallback(YourCallbackHandler, router.CallbackEq("some_data")) +func CallbackEq(data string) Filter { + return func(_ context.Context, deps FilterDependencies) (bool, yaerrors.Error) { + if q, ok := deps.update.(*tg.UpdateBotCallbackQuery); ok && string(q.Data) == data { + return true, nil + } + + return false, nil + } +} + +// CallbackPrefix creates a filter that checks if the callback query data starts with the specified prefix. +// +// Example usage: +// +// router.OnCallback(YourCallbackHandler, router.CallbackPrefix("prefix_")) +func CallbackPrefix(prefix string) Filter { + return func(_ context.Context, deps FilterDependencies) (bool, yaerrors.Error) { + if q, ok := deps.update.(*tg.UpdateBotCallbackQuery); ok && + strings.HasPrefix(string(q.Data), prefix) { + return true, nil + } + + return false, nil + } +} + +// MessageServiceActionFilter creates a filter that checks if the message service action +// matches the specified type T. +// +// Example usage: +// +// router.OnMessageService(YourMessageServiceHandler, router.MessageServiceActionFilter[*tg.MessageActionChatCreate]()) +func MessageServiceActionFilter[T tg.MessageActionClass]() Filter { + return func(_ context.Context, deps FilterDependencies) (bool, yaerrors.Error) { + if messageService, ok := ExtractMessageServiceFromUpdate(deps.update); ok { + _, ok := messageService.Action.(T) + + return ok, nil + } + + return false, nil + } +} + +// MessageServiceFilter creates a filter that checks if the update contains a MessageService. +// +// Example usage: +// +// router.OnMessageService(YourMessageServiceHandler, router.MessageServiceFilter()) +func MessageServiceFilter() Filter { + return func(_ context.Context, deps FilterDependencies) (bool, yaerrors.Error) { + _, ok := ExtractMessageServiceFromUpdate(deps.update) + + return ok, nil + } +} + +// OneOfFilter creates a filter that passes if any of the provided filters pass. +// +// Example usage: +// +// router.OnMessage(YourMessageHandler, router.OneOfFilter(filter1, filter2)) +func OneOfFilter(filters ...Filter) Filter { + return func(ctx context.Context, deps FilterDependencies) (bool, yaerrors.Error) { + for _, f := range filters { + ok, err := f(ctx, deps) + if err != nil { + return false, err.Wrap("or-filter check failed") + } + + if ok { + return true, nil + } + } + + return false, nil + } +} + +// AllOfFilter creates a filter that passes only if all of the provided filters pass. +// +// Example usage: +// +// router.OnMessage(YourMessageHandler, router.AllOfFilter(filter1, filter2)) +func AllOfFilter(filters ...Filter) Filter { + return func(ctx context.Context, deps FilterDependencies) (bool, yaerrors.Error) { + for _, f := range filters { + ok, err := f(ctx, deps) + if err != nil { + return false, err.Wrap("and-filter check failed") + } + + if !ok { + return false, nil + } + } + + return true, nil + } +} diff --git a/yatgbot/messagequeue/constants.go b/yatgbot/messagequeue/constants.go new file mode 100644 index 0000000..040e1c1 --- /dev/null +++ b/yatgbot/messagequeue/constants.go @@ -0,0 +1,6 @@ +package messagequeue + +const ( + PriorityQueueAllocSize = 1024 + SingleMessage = 1 +) diff --git a/yatgbot/messagequeue/heap.go b/yatgbot/messagequeue/heap.go new file mode 100644 index 0000000..d1aba94 --- /dev/null +++ b/yatgbot/messagequeue/heap.go @@ -0,0 +1,238 @@ +package messagequeue + +import ( + "cmp" + "context" + "fmt" + "net/http" + "slices" + "sync" + "time" + + "github.com/YaCodeDev/GoYaCodeDevUtils/yaerrors" + "github.com/gotd/td/bin" + "github.com/gotd/td/tg" +) + +// MessageJob represents a job to send a message with a certain priority. +type MessageJob struct { + ID uint64 + Priority uint16 + Request bin.Encoder + ResultCh chan JobResult + Timestamp time.Time + IsPlaceholder bool + TaskCount uint +} + +// JobResult represents the result of a message job execution. +type JobResult struct { + Updates tg.UpdatesClass + Err yaerrors.Error +} + +// Execute performs the message sending operation. +// If the job is a placeholder, it returns an empty result. +// If the job has multiple tasks, it adds empty jobs to the dispatcher to compensate for it. +// +// Example usage: +// +// result := job.Execute(ctx, dispatcher, workerID) +// +// if result.Err != nil { +// // Handle error +// } else { +// // Process result.Updates +// } +func (j MessageJob) Execute( + ctx context.Context, + dispatcher *Dispatcher, + workerID uint, +) JobResult { + var yaErr yaerrors.Error + + if j.IsPlaceholder { + return JobResult{} + } + + if j.TaskCount > 1 { + dispatcher.AddEmptyJob(j.TaskCount - 1) + } + + var result tg.UpdatesBox + + err := dispatcher.Client.Invoke(ctx, j.Request, &result) + if err != nil { + yaErr = yaerrors.FromError( + http.StatusInternalServerError, + err, + fmt.Sprintf("worker %d: failed to send message", workerID), + ) + } + + return JobResult{ + Updates: result.Updates, + Err: yaErr, + } +} + +// messageHeap is a thread-safe priority queue for MessageJob. +type messageHeap struct { + jobs []MessageJob + mu sync.Mutex +} + +// newMessageHeap creates a new instance of messageHeap. +// +// Example usage: +// +// heap := newMessageHeap() +func newMessageHeap() messageHeap { + return messageHeap{ + jobs: make([]MessageJob, 0, PriorityQueueAllocSize), + } +} + +// sort sorts the jobs in the heap based on priority and timestamp. +// Placeholders are always sorted to the end. +// Higher priority jobs come first, and for equal priority, older jobs come first. +func (h *messageHeap) sort() { + slices.SortFunc(h.jobs, func(a, b MessageJob) int { + if a.IsPlaceholder && b.IsPlaceholder { + return 0 + } + + if a.IsPlaceholder { + return 1 + } + + if b.IsPlaceholder { + return -1 + } + + if a.Priority != b.Priority { + return cmp.Compare(b.Priority, a.Priority) + } + + switch { + case a.Timestamp.Before(b.Timestamp): + return 1 + case a.Timestamp.After(b.Timestamp): + return -1 + default: + return 0 + } + }) +} + +// Push adds a new job to the heap and sorts it. +// +// Example usage: +// +// heap.Push(job) +func (h *messageHeap) Push(job MessageJob) { + h.mu.Lock() + + h.jobs = append(h.jobs, job) + h.sort() + + h.mu.Unlock() +} + +// Len returns the number of jobs in the heap. +// +// Example usage: +// +// length := heap.Len() +func (h *messageHeap) Len() int { + h.mu.Lock() + defer h.mu.Unlock() + + return len(h.jobs) +} + +// Pop removes and returns the highest priority job from the heap. +// +// Example usage: +// +// job, ok := heap.Pop() +// +// if !ok { +// // Handle empty heap +// } +func (h *messageHeap) Pop() (MessageJob, bool) { + if h.Len() == 0 { + return MessageJob{}, false + } + + h.mu.Lock() + + last := len(h.jobs) - 1 + job := h.jobs[last] + h.jobs = h.jobs[:last] + + h.mu.Unlock() + + return job, true +} + +// Delete removes a job with the specified ID from the heap. +// Returns true if the job was found and deleted, false otherwise. +// +// Example usage: +// +// deleted := heap.Delete(jobID) +// +// if !deleted { +// // Handle job not found +// } +func (h *messageHeap) Delete(id uint64) bool { + h.mu.Lock() + defer h.mu.Unlock() + + for i, job := range h.jobs { + if job.ID == id { + h.jobs = append(h.jobs[:i], h.jobs[i+1:]...) + + return true + } + } + + return false +} + +// DeleteFunc removes jobs that satisfy the given condition from the heap. +// Returns a slice of IDs of the deleted jobs. +// +// Example usage: +// +// deletedIDs := heap.DeleteFunc(func(job MessageJob) bool { +// return job.Priority < 10 +// }) +// +// if len(deletedIDs) == 0 { +// // Handle no jobs deleted +// } +func (h *messageHeap) DeleteFunc(deleteFunc func(MessageJob) bool) []uint64 { + var deletedEntries []uint64 + + h.mu.Lock() + + newJobs := make([]MessageJob, 0, len(h.jobs)) + + for _, job := range h.jobs { + if deleteFunc(job) { + deletedEntries = append(deletedEntries, job.ID) + + continue + } + + newJobs = append(newJobs, job) + } + + h.jobs = newJobs + + h.mu.Unlock() + + return deletedEntries +} diff --git a/yatgbot/messagequeue/heap_test.go b/yatgbot/messagequeue/heap_test.go new file mode 100644 index 0000000..e5c0f13 --- /dev/null +++ b/yatgbot/messagequeue/heap_test.go @@ -0,0 +1,93 @@ +package messagequeue + +import ( + "testing" + "time" +) + +func mustPop(t *testing.T, h *messageHeap) MessageJob { + t.Helper() + + job, ok := h.Pop() + if !ok { + t.Fatalf("expected job, heap is empty") + } + + return job +} + +func TestHeap_PushPopOrdering(t *testing.T) { + h := newMessageHeap() + now := time.Now() + + // ID 3: highest priority (1) and *oldest* timestamp + h.Push(MessageJob{ID: 3, Priority: 1, Timestamp: now.Add(-2 * time.Minute)}) + // ID 2: same priority 1 but newer than ID 3 + h.Push(MessageJob{ID: 2, Priority: 1, Timestamp: now}) + // ID 1: lower priority (2) + h.Push(MessageJob{ID: 1, Priority: 2, Timestamp: now}) + // ID 4: placeholder – always first + h.Push(MessageJob{ID: 4, IsPlaceholder: true}) + + if h.Len() != 4 { + t.Fatalf("expected heap len 4, got %d", h.Len()) + } + + wantOrder := []uint64{4, 3, 2, 1} + for i, wantID := range wantOrder { + got := mustPop(t, &h) + if got.ID != wantID { + t.Fatalf("pop #%d: want ID %d, got %d", i+1, wantID, got.ID) + } + } + + if h.Len() != 0 { + t.Fatalf("expected empty heap, len=%d", h.Len()) + } +} + +func TestHeap_DeleteByID(t *testing.T) { + h := newMessageHeap() + h.Push(MessageJob{ID: 10}) + h.Push(MessageJob{ID: 20}) + + if !h.Delete(10) { + t.Fatalf("Delete should return true for existing ID") + } + + if h.Len() != 1 { + t.Fatalf("expected len 1 after delete, got %d", h.Len()) + } + + if h.Delete(42) { + t.Fatalf("Delete should return false for missing ID") + } +} + +func TestHeap_DeleteFunc(t *testing.T) { + h := newMessageHeap() + + h.Push(MessageJob{ID: 1, Priority: 5}) + h.Push(MessageJob{ID: 2, Priority: 3}) + h.Push(MessageJob{ID: 3, Priority: 1}) + + deleted := h.DeleteFunc(func(j MessageJob) bool { return j.Priority < 4 }) + if len(deleted) != 2 { + t.Fatalf("expected 2 jobs deleted, got %d", len(deleted)) + } + + if h.Len() != 1 { + t.Fatalf("expected heap len 1 after DeleteFunc, got %d", h.Len()) + } + + if deleted[0] == deleted[1] { + t.Fatalf("deleted IDs should be unique, got %+v", deleted) + } +} + +func TestHeap_PopOnEmpty(t *testing.T) { + h := newMessageHeap() + if _, ok := h.Pop(); ok { + t.Fatalf("expected ok==false on empty Pop") + } +} diff --git a/yatgbot/messagequeue/messagequeue.go b/yatgbot/messagequeue/messagequeue.go new file mode 100644 index 0000000..a2ebe57 --- /dev/null +++ b/yatgbot/messagequeue/messagequeue.go @@ -0,0 +1,253 @@ +package messagequeue + +import ( + "context" + "math/rand" + "sync" + "time" + + "github.com/YaCodeDev/GoYaCodeDevUtils/yalogger" + "github.com/YaCodeDev/GoYaCodeDevUtils/yatgclient" + "github.com/gotd/td/bin" + "github.com/gotd/td/tg" +) + +// Dispatcher handles message sending with priority and concurrency control. +type Dispatcher struct { + Client *yatgclient.Client + messageJobChannel chan MessageJob // TODO: rename channel name + heap messageHeap + cond sync.Cond + log yalogger.Logger +} + +// NewDispatcher creates a new Dispatcher with the given number of workers. +// Each worker processes jobs from the priority queue channel. +// The dispatcher uses a condition variable to signal when new jobs are added to the heap. +// It also initializes the message heap and starts the worker goroutines. +// +// Example usage: +// +// dispatcher := NewDispatcher(ctx, 5, log) +func NewDispatcher( + ctx context.Context, + client *yatgclient.Client, + workerCount uint, + log yalogger.Logger, +) *Dispatcher { + dispatcher := &Dispatcher{ + Client: client, + messageJobChannel: make(chan MessageJob), + log: log, + heap: newMessageHeap(), + cond: *sync.NewCond(&sync.Mutex{}), + } + + go dispatcher.proccessMessagesQueue() + + for i := range workerCount { + go dispatcher.worker(ctx, i) + } + + return dispatcher +} + +// DeleteJob removes a job from the heap by its ID. +// Returns true if the job was found and deleted, false otherwise. +// +// Example usage: +// +// deleted := dispatcher.DeleteJob(jobID) +// +// if !deleted { +// // Handle job not found +// } +func (d *Dispatcher) DeleteJob(id uint64) bool { + return d.heap.Delete(id) +} + +// DeleteJobFunc removes jobs from the heap that satisfy the given condition. +// +// Example usage: +// +// deletedIDs := dispatcher.DeleteJobFunc(func(job MessageJob) bool { +// return job.Priority < 10 +// }) +// +// for _, id := range deletedIDs { +// // Handle deleted job ID +// } +func (d *Dispatcher) DeleteJobFunc(deleteFunc func(MessageJob) bool) []uint64 { + return d.heap.DeleteFunc(deleteFunc) +} + +// AddRawJob adds a raw job to the dispatcher with the specified request, priority, and task count. +// It returns the job ID and a channel to receive the job result. +// +// Example usage: +// +// jobID, resultCh := dispatcher.AddRawJob(request, priority, taskCount) +// +// // Wait for the job result +// result := <-resultCh +// +// if result.Err != nil { +// // Handle job error +// } +func (d *Dispatcher) AddRawJob( + request bin.Encoder, + priority uint16, + taskCount uint, +) (uint64, <-chan JobResult) { + job := MessageJob{ + ID: rand.Uint64(), + Priority: priority, + Request: request, + ResultCh: make(chan JobResult, 1), + Timestamp: time.Now(), + TaskCount: taskCount, + } + + d.heap.Push(job) + + d.cond.Signal() + + return job.ID, job.ResultCh +} + +// AddEmptyJob adds the specified number of placeholder jobs to the dispatcher. +// +// Example usage: +// +// dispatcher.AddEmptyJob(5) // Adds 5 placeholder jobs +func (d *Dispatcher) AddEmptyJob(count uint) { + for range count { + d.heap.Push(MessageJob{ + IsPlaceholder: true, + }) + } +} + +// AddForwardMessagesJob adds a message forwarding job to the dispatcher. +// +// Example usage: +// +// jobID, resultCh := dispatcher.AddForwardMessagesJob(messagesForwardMessagesRequest, priority) +// +// // Wait for the job result +// result := <-resultCh +// +// if result.Err != nil { +// // Handle job error +// } +func (d *Dispatcher) AddForwardMessagesJob( + req *tg.MessagesForwardMessagesRequest, + priority uint16, +) (uint64, <-chan JobResult) { + req.RandomID = make([]int64, len(req.ID)) + for i := range req.RandomID { + req.RandomID[i] = rand.Int63() + } + + return d.AddRawJob(req, priority, uint(len(req.RandomID))) +} + +// AddSendMessageJob adds a message sending job to the dispatcher. +// +// Example usage: +// +// jobID, resultCh := dispatcher.AddSendMessageJob(messagesSendMessageRequest, priority) +// +// // Wait for the job result +// result := <-resultCh +// +// if result.Err != nil { +// // Handle job error +// } +func (d *Dispatcher) AddSendMessageJob( + req *tg.MessagesSendMessageRequest, + priority uint16, +) (uint64, <-chan JobResult) { + if req.RandomID == 0 { + req.RandomID = rand.Int63() + } + + return d.AddRawJob(req, priority, SingleMessage) +} + +// AddSendMultiMediaJob adds a media sending job to the dispatcher. +// +// Example usage: +// +// jobID, resultCh := dispatcher.AddSendMultiMediaJob(messagesSendMediaRequest, priority) +// +// // Wait for the job result +// result := <-resultCh +// +// if result.Err != nil { +// // Handle job error +// } +func (d *Dispatcher) AddSendMultiMediaJob( + req *tg.MessagesSendMultiMediaRequest, + priority uint16, +) (uint64, <-chan JobResult) { + for i := range req.MultiMedia { + req.MultiMedia[i].RandomID = rand.Int63() + } + + return d.AddRawJob(req, priority, uint(len(req.MultiMedia))) +} + +func (d *Dispatcher) AddSendMediaJob( + req *tg.MessagesSendMediaRequest, + priority uint16, +) (uint64, <-chan JobResult) { + if req.RandomID == 0 { + req.RandomID = rand.Int63() + } + + return d.AddRawJob(req, priority, SingleMessage) +} + +// proccessMessagesQueue continuously processes jobs from the heap and sends them to the priority queue channel. +// It waits for new jobs if the heap is empty. +func (d *Dispatcher) proccessMessagesQueue() { + for { + if d.heap.Len() == 0 { + d.cond.L.Lock() + d.cond.Wait() + d.cond.L.Unlock() + } + + job, ok := d.heap.Pop() + if !ok { + continue + } + + d.messageJobChannel <- job + } +} + +// worker processes jobs from the priority queue channel. +// It executes each job and sends the result back through the job's ResultCh. +func (d *Dispatcher) worker(ctx context.Context, id uint) { + for { + select { + case job := <-d.messageJobChannel: + start := time.Now() + + jobResult := job.Execute(ctx, d, id) + + select { + case job.ResultCh <- jobResult: + case <-ctx.Done(): + return + } + + time.Sleep(time.Second - time.Since(start)) + + case <-ctx.Done(): + return + } + } +} diff --git a/yatgbot/middlewares.go b/yatgbot/middlewares.go new file mode 100644 index 0000000..397ab1c --- /dev/null +++ b/yatgbot/middlewares.go @@ -0,0 +1,72 @@ +package yatgbot + +import ( + "context" + "net/http" + + "github.com/YaCodeDev/GoYaCodeDevUtils/yaerrors" + "github.com/gotd/td/tg" +) + +var ErrRouteMismatch = yaerrors.FromString(http.StatusContinue, "route: handler type mismatch") + +// HandlerNext is a function that represents the next handler in the middleware chain. +type HandlerNext = func(ctx context.Context, handlerData *HandlerData, upd tg.UpdateClass) yaerrors.Error + +// HandlerMiddleware is a middleware function that can process an update before or after the main handler. +type HandlerMiddleware func( + ctx context.Context, + handlerData *HandlerData, + upd tg.UpdateClass, + next HandlerNext, +) yaerrors.Error + +// AddMiddleware adds one or more middlewares to the router. +// +// Example usage: +// +// r.AddMiddleware(loggingMiddleware, authMiddleware) +func (r *RouterGroup) AddMiddleware(mw ...HandlerMiddleware) { + r.middlewares = append(r.middlewares, mw...) +} + +// chainMiddleware chains the provided middlewares and returns a single HandlerNext function. +func chainMiddleware(final HandlerNext, middlewares ...HandlerMiddleware) HandlerNext { + if len(middlewares) == 0 { + return final + } + + for _, mw := range middlewares { + middleware := mw + next := final + + final = func(ctx context.Context, hd *HandlerData, upd tg.UpdateClass) yaerrors.Error { + return middleware(ctx, hd, upd, next) + } + } + + return final +} + +// wrapHandler wraps a specific handler function into a generic HandlerNext function. +func wrapHandler[T tg.UpdateClass]( + h func(context.Context, *HandlerData, T) yaerrors.Error, +) HandlerNext { + return func(ctx context.Context, handlerData *HandlerData, upd tg.UpdateClass) yaerrors.Error { + t, ok := upd.(T) + if !ok { + return ErrRouteMismatch + } + + return h(ctx, handlerData, t) + } +} + +// collectMiddlewares collects middlewares from the current router and its parent routers. +func (r *RouterGroup) collectMiddlewares() []HandlerMiddleware { + if r.parent == nil { + return r.middlewares + } + + return append(r.parent.collectMiddlewares(), r.middlewares...) +} diff --git a/yatgbot/router.go b/yatgbot/router.go new file mode 100644 index 0000000..0826793 --- /dev/null +++ b/yatgbot/router.go @@ -0,0 +1,278 @@ +package yatgbot + +import ( + "context" + + "github.com/YaCodeDev/GoYaCodeDevUtils/yaerrors" + "github.com/YaCodeDev/GoYaCodeDevUtils/yafsm" + "github.com/YaCodeDev/GoYaCodeDevUtils/yalocales" + "github.com/YaCodeDev/GoYaCodeDevUtils/yalogger" + "github.com/YaCodeDev/GoYaCodeDevUtils/yatgbot/messagequeue" + "github.com/YaCodeDev/GoYaCodeDevUtils/yatgclient" + "github.com/gotd/td/tg" +) + +// HandlerData holds the dependencies and context for a handler execution. +type HandlerData struct { + Entities tg.Entities + Client *yatgclient.Client + Update tg.UpdateClass + UserID int64 + Peer tg.InputPeerClass + StateStorage *yafsm.EntityFSMStorage + Log yalogger.Logger + Dispatcher *messagequeue.Dispatcher + Localizer yalocales.Localizer + JobResults []messagequeue.JobResult +} + +type ( + // CallbackHandler is a function that processes incoming callback queries. + CallbackHandler func( + ctx context.Context, + handlerData *HandlerData, + cb *tg.UpdateBotCallbackQuery, + ) yaerrors.Error + + // NewMessageHandler is a function that processes incoming messages. + NewMessageHandler func( + ctx context.Context, + handlerData *HandlerData, + msg *tg.UpdateNewMessage, + ) yaerrors.Error + + // EditMessageHandler is a function that processes edited messages. + EditMessageHandler func( + ctx context.Context, + handlerData *HandlerData, + msg *tg.UpdateEditMessage, + ) yaerrors.Error + + // DeleteMessageHandler is a function that processes deleted messages. + DeleteMessageHandler func( + ctx context.Context, + handlerData *HandlerData, + msg *tg.UpdateDeleteMessages, + ) yaerrors.Error + + // NewChannelMessageHandler is a function that processes new channel messages. + NewChannelMessageHandler func( + ctx context.Context, + handlerData *HandlerData, + msg *tg.UpdateNewChannelMessage, + ) yaerrors.Error + + // EditChannelMessageHandler is a function that processes edited channel messages. + EditChannelMessageHandler func( + ctx context.Context, + handlerData *HandlerData, + msg *tg.UpdateEditChannelMessage, + ) yaerrors.Error + + // DeleteChannelMessagesHandler is a function that processes deleted channel messages. + DeleteChannelMessagesHandler func( + ctx context.Context, + handlerData *HandlerData, + msg *tg.UpdateDeleteChannelMessages, + ) yaerrors.Error + + // MessageReactionsHandler is a function that processes message reactions updates. + MessageReactionsHandler func( + ctx context.Context, + handlerData *HandlerData, + msg *tg.UpdateMessageReactions, + ) yaerrors.Error + + // ChannelParticipantHandler is a function that processes channel participant updates. + ChannelParticipantHandler func( + ctx context.Context, + handlerData *HandlerData, + msg *tg.UpdateChannelParticipant, + ) yaerrors.Error + + // PrecheckoutQueryHandler is a function that processes incoming pre-checkout queries. + PrecheckoutQueryHandler func( + tx context.Context, + handlerData *HandlerData, + query *tg.UpdateBotPrecheckoutQuery, + ) yaerrors.Error + + // InlineQueryHandler is a function that processes incoming inline queries. + InlineQueryHandler func( + ctx context.Context, + handlerData *HandlerData, + query *tg.UpdateBotInlineQuery, + ) yaerrors.Error +) + +// RouterGroup is the main struct that holds routes, sub-routers, and middlewares. +type RouterGroup struct { + parent *RouterGroup + base []Filter + sub []*RouterGroup + routes []route + middlewares []HandlerMiddleware +} + +// route represents a single route with its associated filters and handler. +type route struct { + filters []Filter + handler HandlerNext +} + +// NewRouterGroup creates a new Router instance with the given name. +// +// Example usage: +// +// r := router.NewRouterGroup("main", YourDependencies) +func NewRouterGroup() *RouterGroup { + return &RouterGroup{} +} + +// IncludeRouter includes sub-routers into the current router. +// It sets the parent and inherits dependencies if they are not set. +// +// Example usage: +// +// subRouter := router.NewRouterGroup() +// +// router.IncludeRouter(subRouter) +func (r *RouterGroup) IncludeRouter(subs ...*RouterGroup) { + for _, s := range subs { + s.parent = r + + r.sub = append(r.sub, s) + } +} + +// OnCallback registers a callback handler with optional filters. +// +// Example usage: +// +// router.OnCallback(YourCallbackHandler, YourFilter1, YourFilter2) +func (r *RouterGroup) OnCallback(h CallbackHandler, filters ...Filter) { + r.routes = append(r.routes, route{ + handler: wrapHandler(h), + filters: filters, + }) +} + +// OnMessage registers a message handler with optional filters. +// +// Example usage: +// +// router.OnMessage(YourMessageHandler, YourFilter1, YourFilter2) +func (r *RouterGroup) OnMessage(h NewMessageHandler, filters ...Filter) { + r.routes = append(r.routes, route{ + handler: wrapHandler(h), + filters: filters, + }) +} + +// OnEditMessage registers an edit message handler with optional filters. +// +// Example usage: +// +// router.OnEditMessage(YourEditMessageHandler, YourFilter1, YourFilter2) +func (r *RouterGroup) OnEditMessage(h EditMessageHandler, filters ...Filter) { + r.routes = append(r.routes, route{ + handler: wrapHandler(h), + filters: filters, + }) +} + +// OnDeleteMessage registers a delete message handler with optional filters. +// +// Example usage: +// +// router.OnDeleteMessage(YourDeleteMessageHandler, YourFilter1, YourFilter2) +func (r *RouterGroup) OnDeleteMessage(h DeleteMessageHandler, filters ...Filter) { + r.routes = append(r.routes, route{ + handler: wrapHandler(h), + filters: filters, + }) +} + +// OnNewChannelMessage registers a new channel message handler with optional filters. +// +// Example usage: +// +// router.OnNewChannelMessage(YourNewChannelMessageHandler, YourFilter1, YourFilter2) +func (r *RouterGroup) OnNewChannelMessage(h NewChannelMessageHandler, filters ...Filter) { + r.routes = append(r.routes, route{ + handler: wrapHandler(h), + filters: filters, + }) +} + +// OnEditChannelMessage registers an edit channel message handler with optional filters. +// +// Example usage: +// +// router.OnEditChannelMessage(YourEditChannelMessageHandler, YourFilter1, YourFilter2) +func (r *RouterGroup) OnEditChannelMessage(h EditChannelMessageHandler, filters ...Filter) { + r.routes = append(r.routes, route{ + handler: wrapHandler(h), + filters: filters, + }) +} + +// OnDeleteChannelMessages registers a delete channel messages handler with optional filters. +// +// Example usage: +// +// router.OnDeleteChannelMessages(YourDeleteChannelMessagesHandler, YourFilter1, YourFilter2) +func (r *RouterGroup) OnDeleteChannelMessages(h DeleteChannelMessagesHandler, filters ...Filter) { + r.routes = append(r.routes, route{ + handler: wrapHandler(h), + filters: filters, + }) +} + +// OnMessageReactions registers a message reactions handler with optional filters. +// +// Example usage: +// +// router.OnMessageReactions(YourMessageReactionsHandler, YourFilter1, YourFilter2) +func (r *RouterGroup) OnMessageReactions(h MessageReactionsHandler, filters ...Filter) { + r.routes = append(r.routes, route{ + handler: wrapHandler(h), + filters: filters, + }) +} + +// OnChannelParticipant registers a channel participant handler with optional filters. +// +// Example usage: +// +// router.OnChannelParticipant(YourChannelParticipantHandler, YourFilter1, YourFilter2) +func (r *RouterGroup) OnChannelParticipant(h ChannelParticipantHandler, filters ...Filter) { + r.routes = append(r.routes, route{ + handler: wrapHandler(h), + filters: filters, + }) +} + +// OnPrecheckoutQuery registers a pre-checkout query handler with optional filters. +// +// Example usage: +// +// router.OnPrecheckoutQuery(YourPrecheckoutQueryHandler, YourFilter1, YourFilter2) +func (r *RouterGroup) OnPrecheckoutQuery(h PrecheckoutQueryHandler, filters ...Filter) { + r.routes = append(r.routes, route{ + handler: wrapHandler(h), + filters: filters, + }) +} + +// OnInlineQuery registers an inline query handler with optional filters. +// +// Example usage: +// +// router.OnInlineQuery(YourInlineQueryHandler, YourFilter1, YourFilter2) +func (r *RouterGroup) OnInlineQuery(h InlineQueryHandler, filters ...Filter) { + r.routes = append(r.routes, route{ + handler: wrapHandler(h), + filters: filters, + }) +} diff --git a/yatgbot/updates.go b/yatgbot/updates.go new file mode 100644 index 0000000..a378f66 --- /dev/null +++ b/yatgbot/updates.go @@ -0,0 +1,353 @@ +package yatgbot + +import ( + "context" + + "github.com/gotd/td/tg" +) + +// Bind binds the router to the given update dispatcher. +// It sets up updates handling for bot. +// It should be called once during the bot setup. +// After calling this method, the router will start receiving updates +// and dispatching them to the appropriate handlers based on the defined routes and filters. +// +// Example usage: +// +// router := yatgbot.NewRouterGroup() +// +// dispatcher := tg.NewUpdateDispatcher(yourClient) +// +// router.Bind(dispatcher) +func (r *Dispatcher) Bind(tgDispatcher *tg.UpdateDispatcher) { + tgDispatcher.OnNewMessage(r.handleNewMessage) + tgDispatcher.OnBotCallbackQuery(r.handleBotCallbackQuery) + tgDispatcher.OnDeleteMessages(r.handleDeleteMessages) + tgDispatcher.OnEditMessage(r.handleEditMessage) + tgDispatcher.OnNewChannelMessage(r.handleNewChannelMessage) + tgDispatcher.OnEditChannelMessage(r.handleEditChannelMessage) + tgDispatcher.OnChannelParticipant(r.handleChannelParticipant) + tgDispatcher.OnDeleteChannelMessages(r.handleDeleteChannelMessages) + tgDispatcher.OnBotMessageReactions(r.handleBotMessageReactions) + tgDispatcher.OnBotPrecheckoutQuery(r.handleBotPrecheckoutQuery) + tgDispatcher.OnBotInlineQuery(r.handleBotInlineQuery) +} + +// handleNewMessage wraps the new message handler to match the expected signature for the update dispatcher. +func (r *Dispatcher) handleNewMessage( + ctx context.Context, + ent tg.Entities, + upd *tg.UpdateNewMessage, +) error { + var ( + uid int64 + chatID int64 + peer tg.InputPeerClass + ) + + switch msg := upd.Message.(type) { + case *tg.Message: + if msg.FromID != nil { + if fromUser, ok := msg.FromID.(*tg.PeerUser); ok { + if fromUser.UserID == r.BotUser.ID { + return nil + } + } + } + + uid, _ = getUserID(msg.PeerID, msg.FromID) + + chatID, _ = getChatID(msg.PeerID, ent) + + peer, _ = makeInputPeer(msg.PeerID, ent) + + case *tg.MessageService: + uid, _ = getUserID(msg.PeerID, msg.FromID) + + chatID, _ = getChatID(msg.PeerID, ent) + + peer, _ = makeInputPeer(msg.PeerID, ent) + default: + return nil + } + + return r.dispatch(ctx, UpdateData{ + userID: uid, + chatID: chatID, + ent: ent, + update: upd, + inputPeer: peer, + }) +} + +// handleBotCallbackQuery wraps the callback query handler to match the expected signature for the update dispatcher. +func (r *Dispatcher) handleBotCallbackQuery( + ctx context.Context, + ent tg.Entities, + q *tg.UpdateBotCallbackQuery, +) error { + chatID, _ := getChatID(q.Peer, ent) + + peer, _ := makeInputPeer(q.Peer, ent) + + return r.dispatch(ctx, UpdateData{ + userID: q.UserID, + chatID: chatID, + ent: ent, + update: q, + inputPeer: peer, + }) +} + +// handleNewChannelMessage wraps the new channel message handler to match +// the expected signature for the update dispatcher. +func (r *Dispatcher) handleNewChannelMessage( + ctx context.Context, + ent tg.Entities, + upd *tg.UpdateNewChannelMessage, +) error { + var ( + uid int64 + chatID int64 + peer tg.InputPeerClass + ) + + switch msg := upd.Message.(type) { + case *tg.Message: + uid, _ = getUserID(msg.PeerID, msg.FromID) + + chatID, _ = getChatID(msg.PeerID, ent) + + peer, _ = makeInputPeer(msg.PeerID, ent) + + case *tg.MessageService: + uid, _ = getUserID(msg.PeerID, msg.FromID) + + chatID, _ = getChatID(msg.PeerID, ent) + + peer, _ = makeInputPeer(msg.PeerID, ent) + default: + return nil + } + + return r.dispatch(ctx, UpdateData{ + userID: uid, + chatID: chatID, + ent: ent, + update: upd, + inputPeer: peer, + }) +} + +// handleBotPrecheckoutQuery wraps the pre-checkout query handler to match +// the expected signature for the update dispatcher. +func (r *Dispatcher) handleBotPrecheckoutQuery( + ctx context.Context, + ent tg.Entities, + upd *tg.UpdateBotPrecheckoutQuery, +) error { + user, ok := ent.Users[upd.UserID] + if !ok { + return nil + } + + var ( + chatID int64 + inputPeer tg.InputPeerClass + ) + + if len(ent.Chats) > 0 { + chatID = ent.Chats[0].ID + inputPeer = &tg.InputPeerChat{ + ChatID: chatID, + } + } else { + chatID = upd.UserID + inputPeer = &tg.InputPeerUser{ + UserID: upd.UserID, + AccessHash: user.AccessHash, + } + } + + return r.dispatch(ctx, UpdateData{ + userID: upd.UserID, + chatID: chatID, + ent: ent, + update: upd, + inputPeer: inputPeer, + }) +} + +// handleEditMessage wraps the edit message handler to match the expected signature for the update dispatcher. +func (r *Dispatcher) handleEditMessage( + ctx context.Context, + ent tg.Entities, + upd *tg.UpdateEditMessage, +) error { + r.Log.Infof("EditMessage: %+v", upd) + + msg, ok := upd.Message.(*tg.Message) + if !ok { + return nil + } + + if msg.FromID != nil { + if fromUser, okPeer := msg.FromID.(*tg.PeerUser); okPeer { + if fromUser.UserID == r.BotUser.ID { + return nil + } + } + } + + invoice, ok := msg.Media.(*tg.MessageMediaInvoice) + + if ok { + r.Log.Infof("Invoice received: %+v", invoice) + } + + uid, _ := getUserID(msg.PeerID, msg.FromID) + + chatID, _ := getChatID(msg.PeerID, ent) + + peer, _ := makeInputPeer(msg.PeerID, ent) + + return r.dispatch(ctx, UpdateData{ + userID: uid, + chatID: chatID, + ent: ent, + update: upd, + inputPeer: peer, + }) +} + +// handleBotInlineQuery wraps the inline query handler to match the expected signature for the update dispatcher. +func (r *Dispatcher) handleBotInlineQuery( + ctx context.Context, + ent tg.Entities, + upd *tg.UpdateBotInlineQuery, +) error { + user, ok := ent.Users[upd.UserID] + if !ok { + return nil + } + + var ( + chatID int64 + inputPeer tg.InputPeerClass + ) + + if len(ent.Chats) > 0 { + chatID = ent.Chats[0].ID + inputPeer = &tg.InputPeerChat{ + ChatID: chatID, + } + } else { + chatID = upd.UserID + inputPeer = &tg.InputPeerUser{ + UserID: upd.UserID, + AccessHash: user.AccessHash, + } + } + + return r.dispatch(ctx, UpdateData{ + userID: upd.UserID, + chatID: chatID, + ent: ent, + update: upd, + inputPeer: inputPeer, + }) +} + +// handleEditChannelMessage wraps the edit channel message handler to match +// the expected signature for the update dispatcher. +func (r *Dispatcher) handleEditChannelMessage( + ctx context.Context, + ent tg.Entities, + upd *tg.UpdateEditChannelMessage, +) error { + msg, ok := upd.Message.(*tg.Message) + if !ok { + return nil + } + + uid, _ := getUserID(msg.PeerID, msg.FromID) + + chatID, _ := getChatID(msg.PeerID, ent) + + peer, _ := makeInputPeer(msg.PeerID, ent) + + return r.dispatch(ctx, UpdateData{ + userID: uid, + chatID: chatID, + ent: ent, + update: upd, + inputPeer: peer, + }) +} + +// handleDeleteMessages wraps the delete messages handler to match the expected signature for the update dispatcher. +func (r *Dispatcher) handleDeleteMessages( + ctx context.Context, + ent tg.Entities, + upd *tg.UpdateDeleteMessages, +) error { + return r.dispatch(ctx, UpdateData{ + userID: 0, + chatID: 0, + ent: ent, + update: upd, + inputPeer: nil, + }) +} + +// handleDeleteChannelMessages wraps the delete channel messages handler to match +// the expected signature for the update dispatcher. +func (r *Dispatcher) handleDeleteChannelMessages( + ctx context.Context, + ent tg.Entities, + upd *tg.UpdateDeleteChannelMessages, +) error { + return r.dispatch(ctx, UpdateData{ + userID: 0, + chatID: upd.ChannelID, + ent: ent, + update: upd, + inputPeer: nil, + }) +} + +// handleChannelParticipant wraps the channel participant handler to match +// the expected signature for the update dispatcher. +func (r *Dispatcher) handleChannelParticipant( + ctx context.Context, + ent tg.Entities, + upd *tg.UpdateChannelParticipant, +) error { + return r.dispatch(ctx, UpdateData{ + userID: upd.UserID, + chatID: upd.ChannelID, + ent: ent, + update: upd, + inputPeer: nil, + }) +} + +// handleBotMessageReactions wraps the message reactions handler to match +// the expected signature for the update dispatcher. +func (r *Dispatcher) handleBotMessageReactions( + ctx context.Context, + ent tg.Entities, + upd *tg.UpdateBotMessageReactions, +) error { + chatID, _ := getChatID(upd.Peer, ent) + + peer, _ := makeInputPeer(upd.Peer, ent) + + return r.dispatch(ctx, UpdateData{ + userID: 0, + chatID: chatID, + ent: ent, + update: upd, + inputPeer: peer, + }) +} diff --git a/yatgbot/utils.go b/yatgbot/utils.go new file mode 100644 index 0000000..9d926b1 --- /dev/null +++ b/yatgbot/utils.go @@ -0,0 +1,53 @@ +package yatgbot + +import "github.com/gotd/td/tg" + +// ExtractMessageFromUpdate tries to extract a *tg.Message from the given update. +// It returns the message and true if successful, otherwise nil and false. +// +// Example usage: +// +// msg, ok := ExtractMessageFromUpdate(update) +// +// if ok { +// // process msg +// } +func ExtractMessageFromUpdate(upd tg.UpdateClass) (*tg.Message, bool) { + switch u := upd.(type) { + case *tg.UpdateNewMessage: + if msg, ok := u.Message.(*tg.Message); ok { + return msg, true + } + case *tg.UpdateNewChannelMessage: + if msg, ok := u.Message.(*tg.Message); ok { + return msg, true + } + } + + return nil, false +} + +// ExtractMessageServiceFromUpdate tries to extract a *tg.MessageService from the given update. +// It returns the message service and true if successful, otherwise nil and false. +// +// Example usage: +// +// msgService, ok := ExtractMessageServiceFromUpdate(update) +// +// if ok { +// // process msgService +// } +func ExtractMessageServiceFromUpdate(upd tg.UpdateClass) (*tg.MessageService, bool) { + switch u := upd.(type) { + case *tg.UpdateNewMessage: + if msg, ok := u.Message.(*tg.MessageService); ok { + return msg, true + } + case *tg.UpdateNewChannelMessage: + if msg, ok := u.Message.(*tg.MessageService); ok { + return msg, true + } + } + + return nil, false +} diff --git a/yatgbot/yatgbot.go b/yatgbot/yatgbot.go new file mode 100644 index 0000000..a18d44e --- /dev/null +++ b/yatgbot/yatgbot.go @@ -0,0 +1,142 @@ +package yatgbot + +import ( + "context" + "io/fs" + "net/http" + "strconv" + "strings" + + "github.com/YaCodeDev/GoYaCodeDevUtils/yacache" + "github.com/YaCodeDev/GoYaCodeDevUtils/yaerrors" + "github.com/YaCodeDev/GoYaCodeDevUtils/yafsm" + "github.com/YaCodeDev/GoYaCodeDevUtils/yalocales" + "github.com/YaCodeDev/GoYaCodeDevUtils/yalogger" + "github.com/YaCodeDev/GoYaCodeDevUtils/yatgbot/messagequeue" + "github.com/YaCodeDev/GoYaCodeDevUtils/yatgclient" + "github.com/YaCodeDev/GoYaCodeDevUtils/yatgstorage" + "github.com/gotd/td/telegram" + "github.com/gotd/td/telegram/updates" + "github.com/gotd/td/tg" + "github.com/redis/go-redis/v9" + "gorm.io/gorm" +) + +// InitYaTgBot initializes and returns a Dispatcher for the Telegram bot. +// It sets up the necessary components such as the Telegram client, session storage, +// FSM storage, localizer, and message dispatcher. +// +// Example usage: +// +// dispatcher, err := InitYaTgBot( +// ctx, +// "en", +// appID, +// appHash, +// botToken, +// poolDB, +// 10, +// embeddedLocales, +// log, +// cache, +// mainRouter, +// ) +// +// If err != nil { +// // Handle error +// } +func InitYaTgBot( + ctx context.Context, + defaultLang string, + appID int, + appHash string, + botToken string, + poolDB *gorm.DB, + messageQueueRatePerSecond uint, + embeddedLocales fs.FS, + log yalogger.Logger, + cache yacache.Cache[*redis.Client], + mainRouter *RouterGroup, +) (Dispatcher, yaerrors.Error) { + head, _, _ := strings.Cut(botToken, ":") + + BotID, err := strconv.ParseInt(strings.TrimSpace(head), 10, 64) + if err != nil || BotID <= 0 { + return Dispatcher{}, yaerrors.FromError( + http.StatusBadRequest, + err, + "invalid bot token provided", + ) + } + + telegramDispatcher := tg.NewUpdateDispatcher() + + fsmStorage := yafsm.NewDefaultFSMStorage(cache, yafsm.EmptyState{}) + + localizer := yalocales.NewLocalizer(defaultLang, true) + if yaErr := localizer.LoadLocales(embeddedLocales); yaErr != nil { + return Dispatcher{}, yaErr + } + + gormSessionRepo, yaErr := yatgstorage.NewGormSessionStorage(poolDB) + if yaErr != nil { + return Dispatcher{}, yaErr + } + + sessionStorage := yatgstorage.NewSessionStorageWithCustomRepo(BotID, botToken, gormSessionRepo) + stateStorage := yatgstorage.NewStorage(cache, log) + + gaps := yatgclient.NewUpdateManagerWithYaStorage( + BotID, + telegramDispatcher, + stateStorage, + ) + + client := yatgclient.NewClient( + yatgclient.ClientOptions{ + AppID: appID, + AppHash: appHash, + EntityID: BotID, + TelegramOptions: telegram.Options{ + SessionStorage: sessionStorage.TelegramSessionStorageCompatible(), + UpdateHandler: gaps, + }, + }, + log, + ) + + msgDispatcher := messagequeue.NewDispatcher(ctx, client, messageQueueRatePerSecond, log) + + if err := client.BackgroundConnect(ctx); err != nil { + return Dispatcher{}, err + } + + if err := client.BotAuthorization(ctx, botToken); err != nil { + return Dispatcher{}, err + } + + _ = client.RunUpdatesManager(ctx, gaps, updates.AuthOptions{IsBot: true}, nil) + + botUser, err := client.Self(ctx) + if err != nil { + return Dispatcher{}, yaerrors.FromError( + http.StatusInternalServerError, + err, + "failed to get bot user", + ) + } + + dispatcher := Dispatcher{ + FSMStore: fsmStorage, + Log: log, + BotUser: botUser, + MessageDispatcher: msgDispatcher, + Localizer: localizer, + Client: client, + MainRouter: mainRouter, + } + + dispatcher.Bind(&telegramDispatcher) + + return dispatcher, nil +}