diff --git a/internal/cmd/auth.go b/internal/cmd/auth.go index 657bc9e7..a56e23f9 100644 --- a/internal/cmd/auth.go +++ b/internal/cmd/auth.go @@ -478,6 +478,7 @@ type AuthAddCmd struct { Manual bool `name:"manual" help:"Browserless auth flow (paste redirect URL)"` Remote bool `name:"remote" help:"Remote/server-friendly manual flow (print URL, then exchange code)"` Step int `name:"step" help:"Remote auth step: 1=print URL, 2=exchange code"` + RedirectURI string `name:"redirect-uri" help:"Override OAuth redirect URI for manual/remote flows (for example https://host.example/oauth2/callback)"` AuthURL string `name:"auth-url" help:"Redirect URL from browser (manual flow; required for --remote --step 2)"` AuthCode string `name:"auth-code" hidden:"" help:"UNSAFE: Authorization code from browser (manual flow; skips state check; not valid with --remote)"` Timeout time.Duration `name:"timeout" help:"Authorization timeout (manual flows default to 5m)"` @@ -525,6 +526,7 @@ func (c *AuthAddCmd) Run(ctx context.Context, flags *RootFlags) error { authURL := strings.TrimSpace(c.AuthURL) authCode := strings.TrimSpace(c.AuthCode) + redirectURI := strings.TrimSpace(c.RedirectURI) if authURL != "" && authCode != "" { return usage("cannot combine --auth-url with --auth-code") } @@ -535,7 +537,7 @@ func (c *AuthAddCmd) Run(ctx context.Context, flags *RootFlags) error { return usage("--step requires --remote") } - manual := c.Manual || c.Remote || authURL != "" || authCode != "" + manual := c.Manual || c.Remote || authURL != "" || authCode != "" || redirectURI != "" if c.Remote { step := c.Step @@ -558,6 +560,7 @@ func (c *AuthAddCmd) Run(ctx context.Context, flags *RootFlags) error { ForceConsent: c.ForceConsent, DisableIncludeGrantedScopes: disableIncludeGrantedScopes, Client: client, + RedirectURI: redirectURI, }) if manualErr != nil { return manualErr @@ -595,6 +598,7 @@ func (c *AuthAddCmd) Run(ctx context.Context, flags *RootFlags) error { "manual": c.Manual, "remote": c.Remote, "step": c.Step, + "redirect_uri": redirectURI, "force_consent": c.ForceConsent, "readonly": c.Readonly, "drive_scope": c.DriveScope, @@ -618,6 +622,7 @@ func (c *AuthAddCmd) Run(ctx context.Context, flags *RootFlags) error { Client: client, AuthURL: authURL, AuthCode: authCode, + RedirectURI: redirectURI, RequireState: c.Remote, }) if err != nil { diff --git a/internal/cmd/auth_add_test.go b/internal/cmd/auth_add_test.go index 0ec3e303..78192363 100644 --- a/internal/cmd/auth_add_test.go +++ b/internal/cmd/auth_add_test.go @@ -654,6 +654,52 @@ func TestAuthAddCmd_RemoteStep1_PrintsAuthURL(t *testing.T) { } } +func TestAuthAddCmd_RemoteStep1_PassesRedirectURI(t *testing.T) { + origManualURL := manualAuthURL + origAuth := authorizeGoogle + origKeychain := ensureKeychainAccess + t.Cleanup(func() { + manualAuthURL = origManualURL + authorizeGoogle = origAuth + ensureKeychainAccess = origKeychain + }) + + var gotOpts googleauth.AuthorizeOptions + manualAuthURL = func(_ context.Context, opts googleauth.AuthorizeOptions) (googleauth.ManualAuthURLResult, error) { + gotOpts = opts + return googleauth.ManualAuthURLResult{ + URL: "https://example.com/auth", + }, nil + } + authorizeGoogle = func(context.Context, googleauth.AuthorizeOptions) (string, error) { + t.Fatal("authorizeGoogle should not be called in remote step 1") + return "", nil + } + ensureKeychainAccess = func() error { + t.Fatal("keychain access should not be checked in remote step 1") + return nil + } + + if err := Execute([]string{ + "auth", + "add", + "user@example.com", + "--services", + "gmail", + "--remote", + "--step", + "1", + "--redirect-uri", + "https://molty2.tail8108.ts.net/oauth2/callback", + }); err != nil { + t.Fatalf("Execute: %v", err) + } + + if gotOpts.RedirectURI != "https://molty2.tail8108.ts.net/oauth2/callback" { + t.Fatalf("unexpected redirect uri: %q", gotOpts.RedirectURI) + } +} + func TestAuthAddCmd_RemoteStep2_RejectsAuthCode(t *testing.T) { err := Execute([]string{ "auth", @@ -729,6 +775,52 @@ func TestAuthAddCmd_RemoteStep2_PassesAuthURL(t *testing.T) { } } +func TestAuthAddCmd_RemoteStep2_PassesRedirectURI(t *testing.T) { + origAuth := authorizeGoogle + origOpen := openSecretsStore + origKeychain := ensureKeychainAccess + origFetch := fetchAuthorizedEmail + t.Cleanup(func() { + authorizeGoogle = origAuth + openSecretsStore = origOpen + ensureKeychainAccess = origKeychain + fetchAuthorizedEmail = origFetch + }) + + ensureKeychainAccess = func() error { return nil } + openSecretsStore = func() (secrets.Store, error) { return newMemSecretsStore(), nil } + + var gotOpts googleauth.AuthorizeOptions + authorizeGoogle = func(ctx context.Context, opts googleauth.AuthorizeOptions) (string, error) { + gotOpts = opts + return "rt", nil + } + fetchAuthorizedEmail = func(context.Context, string, string, []string, time.Duration) (string, error) { + return "user@example.com", nil + } + + if err := Execute([]string{ + "auth", + "add", + "user@example.com", + "--services", + "gmail", + "--remote", + "--step", + "2", + "--redirect-uri", + "https://molty2.tail8108.ts.net/oauth2/callback", + "--auth-url", + "https://molty2.tail8108.ts.net/oauth2/callback?code=abc&state=state123", + }); err != nil { + t.Fatalf("Execute: %v", err) + } + + if gotOpts.RedirectURI != "https://molty2.tail8108.ts.net/oauth2/callback" { + t.Fatalf("unexpected redirect uri: %q", gotOpts.RedirectURI) + } +} + func TestAuthAddCmd_AuthCode_PassesThrough(t *testing.T) { origAuth := authorizeGoogle origOpen := openSecretsStore diff --git a/internal/googleauth/manual_state_test.go b/internal/googleauth/manual_state_test.go index c690de3b..9580e474 100644 --- a/internal/googleauth/manual_state_test.go +++ b/internal/googleauth/manual_state_test.go @@ -67,6 +67,109 @@ func TestManualAuthURL_ReusesState(t *testing.T) { } } +func TestManualAuthURL_UsesRedirectURIOverride(t *testing.T) { + origRead := readClientCredentials + origEndpoint := oauthEndpoint + origState := randomStateFn + origManualRedirect := manualRedirectURIFn + + t.Cleanup(func() { + readClientCredentials = origRead + oauthEndpoint = origEndpoint + randomStateFn = origState + manualRedirectURIFn = origManualRedirect + }) + + useTempManualStatePath(t) + + readClientCredentials = func(string) (config.ClientCredentials, error) { + return config.ClientCredentials{ClientID: "id", ClientSecret: "secret"}, nil + } + oauthEndpoint = oauth2EndpointForTest("http://example.com") + randomStateFn = func() (string, error) { return "state1", nil } + manualRedirectURIFn = func(context.Context) (string, error) { + t.Fatal("manualRedirectURIFn should not be called when redirect-uri is provided") + return "", nil + } + + res, err := ManualAuthURL(context.Background(), AuthorizeOptions{ + Scopes: []string{"s1"}, + Manual: true, + RedirectURI: "https://host.example/oauth2/callback", + }) + if err != nil { + t.Fatalf("ManualAuthURL: %v", err) + } + + if got := authURLRedirectURI(t, res.URL); got != "https://host.example/oauth2/callback" { + t.Fatalf("unexpected redirect uri: %q", got) + } +} + +func TestManualAuthURL_ChangesStateWhenRedirectURIOverrideChanges(t *testing.T) { + origRead := readClientCredentials + origEndpoint := oauthEndpoint + origState := randomStateFn + origManualRedirect := manualRedirectURIFn + + t.Cleanup(func() { + readClientCredentials = origRead + oauthEndpoint = origEndpoint + randomStateFn = origState + manualRedirectURIFn = origManualRedirect + }) + + useTempManualStatePath(t) + + readClientCredentials = func(string) (config.ClientCredentials, error) { + return config.ClientCredentials{ClientID: "id", ClientSecret: "secret"}, nil + } + oauthEndpoint = oauth2EndpointForTest("http://example.com") + stateCalls := 0 + randomStateFn = func() (string, error) { + stateCalls++ + if stateCalls == 1 { + return "state1", nil + } + + return "state2", nil + } + manualRedirectURIFn = func(context.Context) (string, error) { + t.Fatal("manualRedirectURIFn should not be called when redirect-uri is provided") + return "", nil + } + + res1, err := ManualAuthURL(context.Background(), AuthorizeOptions{ + Scopes: []string{"s1"}, + Manual: true, + RedirectURI: "https://host.example/oauth2/callback", + }) + if err != nil { + t.Fatalf("ManualAuthURL first: %v", err) + } + + res2, err := ManualAuthURL(context.Background(), AuthorizeOptions{ + Scopes: []string{"s1"}, + Manual: true, + RedirectURI: "https://other.example/oauth2/callback", + }) + if err != nil { + t.Fatalf("ManualAuthURL second: %v", err) + } + + if authURLState(t, res1.URL) == authURLState(t, res2.URL) { + t.Fatalf("expected a new state when redirect uri changes") + } + + if res2.StateReused { + t.Fatalf("expected state_reused false when redirect uri changes") + } + + if stateCalls != 2 { + t.Fatalf("expected randomStateFn called twice, got %d", stateCalls) + } +} + func authURLState(t *testing.T, rawURL string) string { t.Helper() @@ -77,3 +180,14 @@ func authURLState(t *testing.T, rawURL string) string { return parsed.Query().Get("state") } + +func authURLRedirectURI(t *testing.T, rawURL string) string { + t.Helper() + + parsed, err := url.Parse(rawURL) + if err != nil { + t.Fatalf("parse auth URL: %v", err) + } + + return parsed.Query().Get("redirect_uri") +} diff --git a/internal/googleauth/oauth_flow.go b/internal/googleauth/oauth_flow.go index c31f7884..b6deb35a 100644 --- a/internal/googleauth/oauth_flow.go +++ b/internal/googleauth/oauth_flow.go @@ -29,6 +29,7 @@ type AuthorizeOptions struct { Client string AuthCode string AuthURL string + RedirectURI string RequireState bool } @@ -79,6 +80,15 @@ func Authorize(ctx context.Context, opts AuthorizeOptions) (string, error) { opts.Timeout = 2 * time.Minute } + if strings.TrimSpace(opts.RedirectURI) != "" { + redirectURI, err := normalizeRedirectURI(opts.RedirectURI) + if err != nil { + return "", err + } + + opts.RedirectURI = redirectURI + } + if strings.TrimSpace(opts.AuthURL) != "" && strings.TrimSpace(opts.AuthCode) != "" { return "", errInvalidAuthorizeOptionsAuthURLAndCode } diff --git a/internal/googleauth/oauth_flow_authorize_test.go b/internal/googleauth/oauth_flow_authorize_test.go index 70c310f8..acd9ef59 100644 --- a/internal/googleauth/oauth_flow_authorize_test.go +++ b/internal/googleauth/oauth_flow_authorize_test.go @@ -313,6 +313,123 @@ func TestAuthorize_Manual_AuthCode(t *testing.T) { } } +func TestAuthorize_Manual_AuthCode_WithRedirectURI(t *testing.T) { + origRead := readClientCredentials + origEndpoint := oauthEndpoint + + t.Cleanup(func() { + readClientCredentials = origRead + oauthEndpoint = origEndpoint + }) + useTempManualStatePath(t) + + readClientCredentials = func(string) (config.ClientCredentials, error) { + return config.ClientCredentials{ClientID: "id", ClientSecret: "secret"}, nil + } + + wantRedirectURI := "https://host.example/oauth2/callback" + + tokenSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/token" { + http.NotFound(w, r) + return + } + + if err := r.ParseForm(); err != nil { + http.Error(w, "bad form", http.StatusBadRequest) + return + } + + if r.Form.Get("redirect_uri") != wantRedirectURI { + http.Error(w, "bad redirect_uri", http.StatusBadRequest) + return + } + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "access_token": "at", + "refresh_token": "rt", + "token_type": "Bearer", + "expires_in": 3600, + }) + })) + defer tokenSrv.Close() + oauthEndpoint = oauth2EndpointForTest(tokenSrv.URL) + + rt, err := Authorize(context.Background(), AuthorizeOptions{ + Scopes: []string{"s1"}, + Manual: true, + AuthCode: "abc", + RedirectURI: wantRedirectURI, + Timeout: 2 * time.Second, + }) + if err != nil { + t.Fatalf("Authorize: %v", err) + } + + if rt != "rt" { + t.Fatalf("unexpected refresh token: %q", rt) + } +} + +func TestAuthorize_Manual_AuthURL_PrefersAuthURLRedirectOverOverride(t *testing.T) { + origRead := readClientCredentials + origEndpoint := oauthEndpoint + + t.Cleanup(func() { + readClientCredentials = origRead + oauthEndpoint = origEndpoint + }) + useTempManualStatePath(t) + + readClientCredentials = func(string) (config.ClientCredentials, error) { + return config.ClientCredentials{ClientID: "id", ClientSecret: "secret"}, nil + } + + redirectFromAuthURL := "https://from-auth-url.example/oauth2/callback" + tokenSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/token" { + http.NotFound(w, r) + return + } + + if err := r.ParseForm(); err != nil { + http.Error(w, "bad form", http.StatusBadRequest) + return + } + + if r.Form.Get("redirect_uri") != redirectFromAuthURL { + http.Error(w, "bad redirect_uri", http.StatusBadRequest) + return + } + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "access_token": "at", + "refresh_token": "rt", + "token_type": "Bearer", + "expires_in": 3600, + }) + })) + defer tokenSrv.Close() + oauthEndpoint = oauth2EndpointForTest(tokenSrv.URL) + + rt, err := Authorize(context.Background(), AuthorizeOptions{ + Scopes: []string{"s1"}, + Manual: true, + AuthURL: redirectFromAuthURL + "?code=abc", + RedirectURI: "https://override.example/oauth2/callback", + Timeout: 2 * time.Second, + }) + if err != nil { + t.Fatalf("Authorize: %v", err) + } + + if rt != "rt" { + t.Fatalf("unexpected refresh token: %q", rt) + } +} + func TestAuthorize_Manual_AuthURL_RequireStateMissing(t *testing.T) { origRead := readClientCredentials origEndpoint := oauthEndpoint diff --git a/internal/googleauth/oauth_flow_manual.go b/internal/googleauth/oauth_flow_manual.go index 6c7f9140..60d6c685 100644 --- a/internal/googleauth/oauth_flow_manual.go +++ b/internal/googleauth/oauth_flow_manual.go @@ -82,6 +82,10 @@ func authorizeManualWithCode( cfg.RedirectURL = gotRedirectURI } + if cfg.RedirectURL == "" && strings.TrimSpace(opts.RedirectURI) != "" { + cfg.RedirectURL = strings.TrimSpace(opts.RedirectURI) + } + if cfg.RedirectURL == "" { if cached, ok, err := loadManualState(opts.Client, opts.Scopes, opts.ForceConsent); err != nil { return "", err @@ -259,6 +263,16 @@ type manualAuthSetupResult struct { } func manualAuthSetup(ctx context.Context, opts AuthorizeOptions) (manualAuthSetupResult, error) { + redirectURIOverride := strings.TrimSpace(opts.RedirectURI) + if redirectURIOverride != "" { + normalized, err := normalizeRedirectURI(redirectURIOverride) + if err != nil { + return manualAuthSetupResult{}, err + } + + redirectURIOverride = normalized + } + st, reused, err := loadManualState(opts.Client, opts.Scopes, opts.ForceConsent) if err != nil { return manualAuthSetupResult{}, err @@ -267,10 +281,19 @@ func manualAuthSetup(ctx context.Context, opts AuthorizeOptions) (manualAuthSetu state := st.State redirectURI := st.RedirectURI + if redirectURIOverride != "" { + if !reused || st.RedirectURI != redirectURIOverride { + reused = false + redirectURI = redirectURIOverride + } + } + if !reused { - redirectURI, err = manualRedirectURIFn(ctx) - if err != nil { - return manualAuthSetupResult{}, err + if redirectURI == "" { + redirectURI, err = manualRedirectURIFn(ctx) + if err != nil { + return manualAuthSetupResult{}, err + } } state, err = randomStateFn() diff --git a/internal/googleauth/oauth_flow_manual_redirect.go b/internal/googleauth/oauth_flow_manual_redirect.go index 3f50e458..e69140ae 100644 --- a/internal/googleauth/oauth_flow_manual_redirect.go +++ b/internal/googleauth/oauth_flow_manual_redirect.go @@ -30,6 +30,19 @@ func redirectURIFromParsedURL(u *url.URL) string { return fmt.Sprintf("%s://%s%s", u.Scheme, u.Host, path) } +func normalizeRedirectURI(rawURI string) (string, error) { + parsed, err := url.Parse(strings.TrimSpace(rawURI)) + if err != nil { + return "", fmt.Errorf("parse redirect uri: %w", err) + } + + if parsed.Scheme == "" || parsed.Host == "" || parsed.RawQuery != "" || parsed.Fragment != "" { + return "", fmt.Errorf("parse redirect uri: %w", errInvalidRedirectURL) + } + + return redirectURIFromParsedURL(parsed), nil +} + func parseRedirectURL(rawURL string) (code string, state string, redirectURI string, err error) { parsed, err := url.Parse(strings.TrimSpace(rawURL)) if err != nil { diff --git a/internal/googleauth/oauth_flow_more_test.go b/internal/googleauth/oauth_flow_more_test.go index d4e8798c..c831b855 100644 --- a/internal/googleauth/oauth_flow_more_test.go +++ b/internal/googleauth/oauth_flow_more_test.go @@ -1,6 +1,7 @@ package googleauth import ( + "context" "net/url" "strings" "testing" @@ -93,3 +94,49 @@ func TestRandomState(t *testing.T) { t.Fatalf("unexpected charset: %q %q", s1, s2) } } + +func TestNormalizeRedirectURI(t *testing.T) { + t.Parallel() + + got, err := normalizeRedirectURI("https://host.example/oauth2/callback") + if err != nil { + t.Fatalf("normalizeRedirectURI: %v", err) + } + + if got != "https://host.example/oauth2/callback" { + t.Fatalf("unexpected redirect uri: %q", got) + } + + got, err = normalizeRedirectURI("https://host.example") + if err != nil { + t.Fatalf("normalizeRedirectURI host-only: %v", err) + } + + if got != "https://host.example/" { + t.Fatalf("expected trailing slash for host-only uri, got: %q", got) + } + + if _, err := normalizeRedirectURI("host-only/path"); err == nil { + t.Fatalf("expected error for invalid redirect uri") + } + + if _, err := normalizeRedirectURI("https://host.example/cb?x=1"); err == nil { + t.Fatalf("expected error when redirect uri has query") + } +} + +func TestAuthorize_InvalidRedirectURI(t *testing.T) { + t.Parallel() + + _, err := Authorize(context.Background(), AuthorizeOptions{ + Scopes: []string{"s1"}, + Manual: true, + RedirectURI: "host-only/path", + }) + if err == nil { + t.Fatalf("expected invalid redirect uri error") + } + if !strings.Contains(err.Error(), "parse redirect uri") { + t.Fatalf("unexpected error: %v", err) + } +}