diff --git a/internal/cmd/gmail_watch_cmds.go b/internal/cmd/gmail_watch_cmds.go index 887fdc8a..2006089f 100644 --- a/internal/cmd/gmail_watch_cmds.go +++ b/internal/cmd/gmail_watch_cmds.go @@ -14,6 +14,7 @@ import ( "google.golang.org/api/gmail/v1" "google.golang.org/api/idtoken" + "github.com/steipete/gogcli/internal/authclient" "github.com/steipete/gogcli/internal/outfmt" "github.com/steipete/gogcli/internal/ui" ) @@ -345,12 +346,20 @@ func (c *GmailWatchServeCmd) Run(ctx context.Context, kctx *kong.Context, flags cfg.MaxBodyBytes = defaultHookMaxBytes } + selectedClient := strings.TrimSpace(flags.Client) + serviceFactory := func(ctx context.Context, account string) (*gmail.Service, error) { + if selectedClient != "" { + ctx = authclient.WithClient(ctx, selectedClient) + } + return newGmailService(ctx, account) + } + hookClient := &http.Client{Timeout: cfg.HookTimeout} server := &gmailWatchServer{ cfg: cfg, store: store, validator: validator, - newService: newGmailService, + newService: serviceFactory, hookClient: hookClient, excludeLabelIDs: stringSet(cfg.ExcludeLabels), logf: u.Err().Printf, diff --git a/internal/cmd/gmail_watch_serve_test.go b/internal/cmd/gmail_watch_serve_test.go index 63b052d5..3e39386d 100644 --- a/internal/cmd/gmail_watch_serve_test.go +++ b/internal/cmd/gmail_watch_serve_test.go @@ -7,8 +7,10 @@ import ( "testing" "time" + "google.golang.org/api/gmail/v1" "google.golang.org/api/idtoken" + "github.com/steipete/gogcli/internal/authclient" "github.com/steipete/gogcli/internal/ui" ) @@ -223,3 +225,57 @@ func TestGmailWatchServeCmd_SaveHookAndOIDC(t *testing.T) { t.Fatalf("expected hook saved, got %#v", loaded.Get().Hook) } } + +func TestGmailWatchServeCmd_PreservesClientOverrideForRequestContexts(t *testing.T) { + origListen := listenAndServe + origNew := newGmailService + t.Cleanup(func() { + listenAndServe = origListen + newGmailService = origNew + }) + + home := t.TempDir() + t.Setenv("HOME", home) + + store, err := newGmailWatchStore("a@b.com") + if err != nil { + t.Fatalf("store: %v", err) + } + updateErr := store.Update(func(s *gmailWatchState) error { + s.Account = "a@b.com" + return nil + }) + if updateErr != nil { + t.Fatalf("seed: %v", updateErr) + } + + flags := &RootFlags{Account: "a@b.com", Client: "personal"} + var got *gmailWatchServer + listenAndServe = func(srv *http.Server) error { + if gs, ok := srv.Handler.(*gmailWatchServer); ok { + got = gs + } + return nil + } + + newGmailService = func(ctx context.Context, _ string) (*gmail.Service, error) { + if client := authclient.ClientOverrideFromContext(ctx); client != "personal" { + t.Fatalf("expected client override personal, got %q", client) + } + return nil, nil + } + + u, err := ui.New(ui.Options{Stdout: io.Discard, Stderr: io.Discard, Color: "never"}) + if err != nil { + t.Fatalf("ui.New: %v", err) + } + if execErr := runKong(t, &GmailWatchServeCmd{}, []string{"--port", "9999", "--path", "/hook"}, ui.WithUI(context.Background(), u), flags); execErr != nil { + t.Fatalf("execute: %v", execErr) + } + if got == nil { + t.Fatalf("expected server") + } + if _, callErr := got.newService(context.Background(), "a@b.com"); callErr != nil { + t.Fatalf("newService: %v", callErr) + } +}