Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion internal/cmd/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,7 @@ type AuthAddCmd struct {
Manual bool `name:"manual" help:"Browserless auth flow (paste redirect URL)"`
Remote bool `name:"remote" help:"Remote/server-friendly manual flow (print URL, then exchange code)"`
Step int `name:"step" help:"Remote auth step: 1=print URL, 2=exchange code"`
RedirectURI string `name:"redirect-uri" help:"Override OAuth redirect URI for manual/remote flows (for example https://host.example/oauth2/callback)"`
AuthURL string `name:"auth-url" help:"Redirect URL from browser (manual flow; required for --remote --step 2)"`
AuthCode string `name:"auth-code" hidden:"" help:"UNSAFE: Authorization code from browser (manual flow; skips state check; not valid with --remote)"`
Timeout time.Duration `name:"timeout" help:"Authorization timeout (manual flows default to 5m)"`
Expand Down Expand Up @@ -525,6 +526,7 @@ func (c *AuthAddCmd) Run(ctx context.Context, flags *RootFlags) error {

authURL := strings.TrimSpace(c.AuthURL)
authCode := strings.TrimSpace(c.AuthCode)
redirectURI := strings.TrimSpace(c.RedirectURI)
if authURL != "" && authCode != "" {
return usage("cannot combine --auth-url with --auth-code")
}
Expand All @@ -535,7 +537,7 @@ func (c *AuthAddCmd) Run(ctx context.Context, flags *RootFlags) error {
return usage("--step requires --remote")
}

manual := c.Manual || c.Remote || authURL != "" || authCode != ""
manual := c.Manual || c.Remote || authURL != "" || authCode != "" || redirectURI != ""

if c.Remote {
step := c.Step
Expand All @@ -558,6 +560,7 @@ func (c *AuthAddCmd) Run(ctx context.Context, flags *RootFlags) error {
ForceConsent: c.ForceConsent,
DisableIncludeGrantedScopes: disableIncludeGrantedScopes,
Client: client,
RedirectURI: redirectURI,
})
if manualErr != nil {
return manualErr
Expand Down Expand Up @@ -595,6 +598,7 @@ func (c *AuthAddCmd) Run(ctx context.Context, flags *RootFlags) error {
"manual": c.Manual,
"remote": c.Remote,
"step": c.Step,
"redirect_uri": redirectURI,
"force_consent": c.ForceConsent,
"readonly": c.Readonly,
"drive_scope": c.DriveScope,
Expand All @@ -618,6 +622,7 @@ func (c *AuthAddCmd) Run(ctx context.Context, flags *RootFlags) error {
Client: client,
AuthURL: authURL,
AuthCode: authCode,
RedirectURI: redirectURI,
RequireState: c.Remote,
})
if err != nil {
Expand Down
92 changes: 92 additions & 0 deletions internal/cmd/auth_add_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -654,6 +654,52 @@ func TestAuthAddCmd_RemoteStep1_PrintsAuthURL(t *testing.T) {
}
}

func TestAuthAddCmd_RemoteStep1_PassesRedirectURI(t *testing.T) {
origManualURL := manualAuthURL
origAuth := authorizeGoogle
origKeychain := ensureKeychainAccess
t.Cleanup(func() {
manualAuthURL = origManualURL
authorizeGoogle = origAuth
ensureKeychainAccess = origKeychain
})

var gotOpts googleauth.AuthorizeOptions
manualAuthURL = func(_ context.Context, opts googleauth.AuthorizeOptions) (googleauth.ManualAuthURLResult, error) {
gotOpts = opts
return googleauth.ManualAuthURLResult{
URL: "https://example.com/auth",
}, nil
}
authorizeGoogle = func(context.Context, googleauth.AuthorizeOptions) (string, error) {
t.Fatal("authorizeGoogle should not be called in remote step 1")
return "", nil
}
ensureKeychainAccess = func() error {
t.Fatal("keychain access should not be checked in remote step 1")
return nil
}

if err := Execute([]string{
"auth",
"add",
"user@example.com",
"--services",
"gmail",
"--remote",
"--step",
"1",
"--redirect-uri",
"https://molty2.tail8108.ts.net/oauth2/callback",
}); err != nil {
t.Fatalf("Execute: %v", err)
}

if gotOpts.RedirectURI != "https://molty2.tail8108.ts.net/oauth2/callback" {
t.Fatalf("unexpected redirect uri: %q", gotOpts.RedirectURI)
}
}

func TestAuthAddCmd_RemoteStep2_RejectsAuthCode(t *testing.T) {
err := Execute([]string{
"auth",
Expand Down Expand Up @@ -729,6 +775,52 @@ func TestAuthAddCmd_RemoteStep2_PassesAuthURL(t *testing.T) {
}
}

func TestAuthAddCmd_RemoteStep2_PassesRedirectURI(t *testing.T) {
origAuth := authorizeGoogle
origOpen := openSecretsStore
origKeychain := ensureKeychainAccess
origFetch := fetchAuthorizedEmail
t.Cleanup(func() {
authorizeGoogle = origAuth
openSecretsStore = origOpen
ensureKeychainAccess = origKeychain
fetchAuthorizedEmail = origFetch
})

ensureKeychainAccess = func() error { return nil }
openSecretsStore = func() (secrets.Store, error) { return newMemSecretsStore(), nil }

var gotOpts googleauth.AuthorizeOptions
authorizeGoogle = func(ctx context.Context, opts googleauth.AuthorizeOptions) (string, error) {
gotOpts = opts
return "rt", nil
}
fetchAuthorizedEmail = func(context.Context, string, string, []string, time.Duration) (string, error) {
return "user@example.com", nil
}

if err := Execute([]string{
"auth",
"add",
"user@example.com",
"--services",
"gmail",
"--remote",
"--step",
"2",
"--redirect-uri",
"https://molty2.tail8108.ts.net/oauth2/callback",
"--auth-url",
"https://molty2.tail8108.ts.net/oauth2/callback?code=abc&state=state123",
}); err != nil {
t.Fatalf("Execute: %v", err)
}

if gotOpts.RedirectURI != "https://molty2.tail8108.ts.net/oauth2/callback" {
t.Fatalf("unexpected redirect uri: %q", gotOpts.RedirectURI)
}
}

func TestAuthAddCmd_AuthCode_PassesThrough(t *testing.T) {
origAuth := authorizeGoogle
origOpen := openSecretsStore
Expand Down
114 changes: 114 additions & 0 deletions internal/googleauth/manual_state_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,109 @@ func TestManualAuthURL_ReusesState(t *testing.T) {
}
}

func TestManualAuthURL_UsesRedirectURIOverride(t *testing.T) {
origRead := readClientCredentials
origEndpoint := oauthEndpoint
origState := randomStateFn
origManualRedirect := manualRedirectURIFn

t.Cleanup(func() {
readClientCredentials = origRead
oauthEndpoint = origEndpoint
randomStateFn = origState
manualRedirectURIFn = origManualRedirect
})

useTempManualStatePath(t)

readClientCredentials = func(string) (config.ClientCredentials, error) {
return config.ClientCredentials{ClientID: "id", ClientSecret: "secret"}, nil
}
oauthEndpoint = oauth2EndpointForTest("http://example.com")
randomStateFn = func() (string, error) { return "state1", nil }
manualRedirectURIFn = func(context.Context) (string, error) {
t.Fatal("manualRedirectURIFn should not be called when redirect-uri is provided")
return "", nil
}

res, err := ManualAuthURL(context.Background(), AuthorizeOptions{
Scopes: []string{"s1"},
Manual: true,
RedirectURI: "https://host.example/oauth2/callback",
})
if err != nil {
t.Fatalf("ManualAuthURL: %v", err)
}

if got := authURLRedirectURI(t, res.URL); got != "https://host.example/oauth2/callback" {
t.Fatalf("unexpected redirect uri: %q", got)
}
}

func TestManualAuthURL_ChangesStateWhenRedirectURIOverrideChanges(t *testing.T) {
origRead := readClientCredentials
origEndpoint := oauthEndpoint
origState := randomStateFn
origManualRedirect := manualRedirectURIFn

t.Cleanup(func() {
readClientCredentials = origRead
oauthEndpoint = origEndpoint
randomStateFn = origState
manualRedirectURIFn = origManualRedirect
})

useTempManualStatePath(t)

readClientCredentials = func(string) (config.ClientCredentials, error) {
return config.ClientCredentials{ClientID: "id", ClientSecret: "secret"}, nil
}
oauthEndpoint = oauth2EndpointForTest("http://example.com")
stateCalls := 0
randomStateFn = func() (string, error) {
stateCalls++
if stateCalls == 1 {
return "state1", nil
}

return "state2", nil
}
manualRedirectURIFn = func(context.Context) (string, error) {
t.Fatal("manualRedirectURIFn should not be called when redirect-uri is provided")
return "", nil
}

res1, err := ManualAuthURL(context.Background(), AuthorizeOptions{
Scopes: []string{"s1"},
Manual: true,
RedirectURI: "https://host.example/oauth2/callback",
})
if err != nil {
t.Fatalf("ManualAuthURL first: %v", err)
}

res2, err := ManualAuthURL(context.Background(), AuthorizeOptions{
Scopes: []string{"s1"},
Manual: true,
RedirectURI: "https://other.example/oauth2/callback",
})
if err != nil {
t.Fatalf("ManualAuthURL second: %v", err)
}

if authURLState(t, res1.URL) == authURLState(t, res2.URL) {
t.Fatalf("expected a new state when redirect uri changes")
}

if res2.StateReused {
t.Fatalf("expected state_reused false when redirect uri changes")
}

if stateCalls != 2 {
t.Fatalf("expected randomStateFn called twice, got %d", stateCalls)
}
}

func authURLState(t *testing.T, rawURL string) string {
t.Helper()

Expand All @@ -77,3 +180,14 @@ func authURLState(t *testing.T, rawURL string) string {

return parsed.Query().Get("state")
}

func authURLRedirectURI(t *testing.T, rawURL string) string {
t.Helper()

parsed, err := url.Parse(rawURL)
if err != nil {
t.Fatalf("parse auth URL: %v", err)
}

return parsed.Query().Get("redirect_uri")
}
10 changes: 10 additions & 0 deletions internal/googleauth/oauth_flow.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ type AuthorizeOptions struct {
Client string
AuthCode string
AuthURL string
RedirectURI string
RequireState bool
}

Expand Down Expand Up @@ -79,6 +80,15 @@ func Authorize(ctx context.Context, opts AuthorizeOptions) (string, error) {
opts.Timeout = 2 * time.Minute
}

if strings.TrimSpace(opts.RedirectURI) != "" {
redirectURI, err := normalizeRedirectURI(opts.RedirectURI)
if err != nil {
return "", err
}

opts.RedirectURI = redirectURI
}

if strings.TrimSpace(opts.AuthURL) != "" && strings.TrimSpace(opts.AuthCode) != "" {
return "", errInvalidAuthorizeOptionsAuthURLAndCode
}
Expand Down
Loading
Loading