diff --git a/src/core/container.ts b/src/core/container.ts index acf2d854..6411ef46 100644 --- a/src/core/container.ts +++ b/src/core/container.ts @@ -48,6 +48,7 @@ export class ServiceContainer implements vscode.Disposable { this.mementoManager, this.vscodeProposed, this.logger, + context.extension.id, ); } @@ -89,5 +90,6 @@ export class ServiceContainer implements vscode.Disposable { dispose(): void { this.contextManager.dispose(); this.logger.dispose(); + this.loginCoordinator.dispose(); } } diff --git a/src/core/secretsManager.ts b/src/core/secretsManager.ts index e6558299..618ee308 100644 --- a/src/core/secretsManager.ts +++ b/src/core/secretsManager.ts @@ -1,4 +1,5 @@ import { type Logger } from "../logging/logger"; +import { type ClientRegistrationResponse } from "../oauth/types"; import { toSafeHost } from "../util"; import type { Memento, SecretStorage, Disposable } from "vscode"; @@ -7,7 +8,12 @@ import type { Deployment } from "../deployment/types"; // Each deployment has its own key to ensure atomic operations (multiple windows // writing to a shared key could drop data) and to receive proper VS Code events. -const SESSION_KEY_PREFIX = "coder.session."; +const SESSION_KEY_PREFIX = "coder.session." as const; +const OAUTH_CLIENT_PREFIX = "coder.oauth.client." as const; + +type SecretKeyPrefix = typeof SESSION_KEY_PREFIX | typeof OAUTH_CLIENT_PREFIX; + +const OAUTH_CALLBACK_KEY = "coder.oauthCallback"; const CURRENT_DEPLOYMENT_KEY = "coder.currentDeployment"; @@ -20,9 +26,22 @@ export interface CurrentDeploymentState { deployment: Deployment | null; } +/** + * OAuth token data stored alongside session auth. + * When present, indicates the session is authenticated via OAuth. + */ +export interface OAuthTokenData { + token_type: "Bearer"; + refresh_token?: string; + scope?: string; + expiry_timestamp: number; +} + export interface SessionAuth { url: string; token: string; + /** If present, this session uses OAuth authentication */ + oauth?: OAuthTokenData; } // Tracks when a deployment was last accessed for LRU pruning. @@ -31,6 +50,12 @@ interface DeploymentUsage { lastAccessedAt: string; } +interface OAuthCallbackData { + state: string; + code: string | null; + error: string | null; +} + export class SecretsManager { constructor( private readonly secrets: SecretStorage, @@ -38,6 +63,44 @@ export class SecretsManager { private readonly logger: Logger, ) {} + private buildKey(prefix: SecretKeyPrefix, safeHostname: string): string { + return `${prefix}${safeHostname || ""}`; + } + + private async getSecret( + prefix: SecretKeyPrefix, + safeHostname: string, + ): Promise { + try { + const data = await this.secrets.get(this.buildKey(prefix, safeHostname)); + if (!data) { + return undefined; + } + return JSON.parse(data) as T; + } catch { + return undefined; + } + } + + private async setSecret( + prefix: SecretKeyPrefix, + safeHostname: string, + value: T, + ): Promise { + await this.secrets.store( + this.buildKey(prefix, safeHostname), + JSON.stringify(value), + ); + await this.recordDeploymentAccess(safeHostname); + } + + private async clearSecret( + prefix: SecretKeyPrefix, + safeHostname: string, + ): Promise { + await this.secrets.delete(this.buildKey(prefix, safeHostname)); + } + /** * Sets the current deployment and triggers a cross-window sync event. */ @@ -104,7 +167,7 @@ export class SecretsManager { safeHostname: string, listener: (auth: SessionAuth | undefined) => void | Promise, ): Disposable { - const sessionKey = this.getSessionKey(safeHostname); + const sessionKey = this.buildKey(SESSION_KEY_PREFIX, safeHostname); return this.secrets.onDidChange(async (e) => { if (e.key !== sessionKey) { return; @@ -118,39 +181,27 @@ export class SecretsManager { }); } - public async getSessionAuth( + public getSessionAuth( safeHostname: string, ): Promise { - const sessionKey = this.getSessionKey(safeHostname); - try { - const data = await this.secrets.get(sessionKey); - if (!data) { - return undefined; - } - return JSON.parse(data) as SessionAuth; - } catch { - return undefined; - } + return this.getSecret(SESSION_KEY_PREFIX, safeHostname); } public async setSessionAuth( safeHostname: string, auth: SessionAuth, ): Promise { - const sessionKey = this.getSessionKey(safeHostname); - // Extract only url and token before serializing - const state: SessionAuth = { url: auth.url, token: auth.token }; - await this.secrets.store(sessionKey, JSON.stringify(state)); - await this.recordDeploymentAccess(safeHostname); - } - - private async clearSessionAuth(safeHostname: string): Promise { - const sessionKey = this.getSessionKey(safeHostname); - await this.secrets.delete(sessionKey); + // Extract relevant fields before serializing + const state: SessionAuth = { + url: auth.url, + token: auth.token, + ...(auth.oauth && { oauth: auth.oauth }), + }; + await this.setSecret(SESSION_KEY_PREFIX, safeHostname, state); } - private getSessionKey(safeHostname: string): string { - return `${SESSION_KEY_PREFIX}${safeHostname || ""}`; + private clearSessionAuth(safeHostname: string): Promise { + return this.clearSecret(SESSION_KEY_PREFIX, safeHostname); } /** @@ -181,7 +232,10 @@ export class SecretsManager { * Clear all auth data for a deployment and remove it from the usage list. */ public async clearAllAuthData(safeHostname: string): Promise { - await this.clearSessionAuth(safeHostname); + await Promise.all([ + this.clearSessionAuth(safeHostname), + this.clearOAuthClientRegistration(safeHostname), + ]); const usage = this.getDeploymentUsage().filter( (u) => u.safeHostname !== safeHostname, ); @@ -234,4 +288,56 @@ export class SecretsManager { return safeHostname; } + + /** + * Write an OAuth callback result to secrets storage. + * Used for cross-window communication when OAuth callback arrives in a different window. + */ + public async setOAuthCallback(data: OAuthCallbackData): Promise { + await this.secrets.store(OAUTH_CALLBACK_KEY, JSON.stringify(data)); + } + + /** + * Listen for OAuth callback results from any VS Code window. + * The listener receives the state parameter, code (if success), and error (if failed). + */ + public onDidChangeOAuthCallback( + listener: (data: OAuthCallbackData) => void, + ): Disposable { + return this.secrets.onDidChange(async (e) => { + if (e.key !== OAUTH_CALLBACK_KEY) { + return; + } + + try { + const data = await this.secrets.get(OAUTH_CALLBACK_KEY); + if (data) { + const parsed = JSON.parse(data) as OAuthCallbackData; + listener(parsed); + } + } catch { + // Ignore parse errors + } + }); + } + + public getOAuthClientRegistration( + safeHostname: string, + ): Promise { + return this.getSecret( + OAUTH_CLIENT_PREFIX, + safeHostname, + ); + } + + public setOAuthClientRegistration( + safeHostname: string, + registration: ClientRegistrationResponse, + ): Promise { + return this.setSecret(OAUTH_CLIENT_PREFIX, safeHostname, registration); + } + + public clearOAuthClientRegistration(safeHostname: string): Promise { + return this.clearSecret(OAUTH_CLIENT_PREFIX, safeHostname); + } } diff --git a/src/deployment/deploymentManager.ts b/src/deployment/deploymentManager.ts index 850d2176..a2fa241e 100644 --- a/src/deployment/deploymentManager.ts +++ b/src/deployment/deploymentManager.ts @@ -1,17 +1,18 @@ import { CoderApi } from "../api/coderApi"; +import { type ServiceContainer } from "../core/container"; +import { type ContextManager } from "../core/contextManager"; +import { type MementoManager } from "../core/mementoManager"; +import { type SecretsManager } from "../core/secretsManager"; +import { type Logger } from "../logging/logger"; +import { type OAuthInterceptor } from "../oauth/axiosInterceptor"; +import { type OAuthSessionManager } from "../oauth/sessionManager"; +import { type WorkspaceProvider } from "../workspace/workspacesProvider"; + +import { type Deployment, type DeploymentWithAuth } from "./types"; import type { User } from "coder/site/src/api/typesGenerated"; import type * as vscode from "vscode"; -import type { ServiceContainer } from "../core/container"; -import type { ContextManager } from "../core/contextManager"; -import type { MementoManager } from "../core/mementoManager"; -import type { SecretsManager } from "../core/secretsManager"; -import type { Logger } from "../logging/logger"; -import type { WorkspaceProvider } from "../workspace/workspacesProvider"; - -import type { Deployment, DeploymentWithAuth } from "./types"; - /** * Internal state type that allows mutation of user property. */ @@ -23,6 +24,7 @@ type DeploymentWithUser = Deployment & { user: User }; * Centralizes: * - In-memory deployment state (url, label, token, user) * - Client credential updates + * - OAuth session management * - Auth listener registration * - Context updates (coder.authenticated, coder.isOwner) * - Workspace provider refresh @@ -41,6 +43,8 @@ export class DeploymentManager implements vscode.Disposable { private constructor( serviceContainer: ServiceContainer, private readonly client: CoderApi, + private readonly oauthSessionManager: OAuthSessionManager, + private readonly oauthInterceptor: OAuthInterceptor, private readonly workspaceProviders: WorkspaceProvider[], ) { this.secretsManager = serviceContainer.getSecretsManager(); @@ -52,11 +56,15 @@ export class DeploymentManager implements vscode.Disposable { public static create( serviceContainer: ServiceContainer, client: CoderApi, + oauthSessionManager: OAuthSessionManager, + oauthInterceptor: OAuthInterceptor, workspaceProviders: WorkspaceProvider[], ): DeploymentManager { const manager = new DeploymentManager( serviceContainer, client, + oauthSessionManager, + oauthInterceptor, workspaceProviders, ); manager.subscribeToCrossWindowChanges(); @@ -125,9 +133,14 @@ export class DeploymentManager implements vscode.Disposable { this.client.setCredentials(deployment.url, deployment.token); } + // Register auth listener before setDeployment so background token refresh + // can update client credentials via the listener this.registerAuthListener(); this.updateAuthContexts(); this.refreshWorkspaces(); + + await this.oauthSessionManager.setDeployment(deployment); + await this.oauthInterceptor.setDeployment(deployment.safeHostname); await this.persistDeployment(deployment); } @@ -140,6 +153,8 @@ export class DeploymentManager implements vscode.Disposable { this.#deployment = null; this.client.setCredentials(undefined, undefined); + this.oauthSessionManager.clearDeployment(); + this.oauthInterceptor.clearDeployment(); this.updateAuthContexts(); this.refreshWorkspaces(); diff --git a/src/extension.ts b/src/extension.ts index eceb112f..a448a73b 100644 --- a/src/extension.ts +++ b/src/extension.ts @@ -13,6 +13,8 @@ import { ServiceContainer } from "./core/container"; import { type SecretsManager } from "./core/secretsManager"; import { DeploymentManager } from "./deployment/deploymentManager"; import { CertificateError, getErrorDetail } from "./error"; +import { OAuthInterceptor } from "./oauth/axiosInterceptor"; +import { OAuthSessionManager } from "./oauth/sessionManager"; import { Remote } from "./remote/remote"; import { getRemoteSshExtension } from "./remote/sshExtension"; import { registerUriHandler } from "./uri/uriHandler"; @@ -67,6 +69,13 @@ export async function activate(ctx: vscode.ExtensionContext): Promise { const deployment = await secretsManager.getCurrentDeployment(); + // Create OAuth session manager with login coordinator + const oauthSessionManager = OAuthSessionManager.create( + deployment, + serviceContainer, + ); + ctx.subscriptions.push(oauthSessionManager); + // This client tracks the current login and will be used through the life of // the plugin to poll workspaces for the current login, as well as being used // in commands that operate on the current login. @@ -78,6 +87,16 @@ export async function activate(ctx: vscode.ExtensionContext): Promise { ); ctx.subscriptions.push(client); + // Create OAuth interceptor - auto attaches/detaches based on token state + const oauthInterceptor = await OAuthInterceptor.create( + client, + output, + oauthSessionManager, + secretsManager, + deployment?.safeHostname ?? "", + ); + ctx.subscriptions.push(oauthInterceptor); + const myWorkspacesProvider = new WorkspaceProvider( WorkspaceQuery.Mine, client, @@ -122,10 +141,13 @@ export async function activate(ctx: vscode.ExtensionContext): Promise { ); // Create deployment manager to centralize deployment state management - const deploymentManager = DeploymentManager.create(serviceContainer, client, [ - myWorkspacesProvider, - allWorkspacesProvider, - ]); + const deploymentManager = DeploymentManager.create( + serviceContainer, + client, + oauthSessionManager, + oauthInterceptor, + [myWorkspacesProvider, allWorkspacesProvider], + ); ctx.subscriptions.push(deploymentManager); // Register globally available commands. Many of these have visibility diff --git a/src/login/loginCoordinator.ts b/src/login/loginCoordinator.ts index 7e5a66d7..aaf275a6 100644 --- a/src/login/loginCoordinator.ts +++ b/src/login/loginCoordinator.ts @@ -5,18 +5,20 @@ import * as vscode from "vscode"; import { CoderApi } from "../api/coderApi"; import { needToken } from "../api/utils"; import { CertificateError } from "../error"; -import { maybeAskUrl } from "../promptUtils"; +import { OAuthAuthorizer } from "../oauth/authorizer"; +import { buildOAuthTokenData } from "../oauth/utils"; +import { maybeAskAuthMethod, maybeAskUrl } from "../promptUtils"; import type { User } from "coder/site/src/api/typesGenerated"; import type { MementoManager } from "../core/mementoManager"; -import type { SecretsManager } from "../core/secretsManager"; +import type { OAuthTokenData, SecretsManager } from "../core/secretsManager"; import type { Deployment } from "../deployment/types"; import type { Logger } from "../logging/logger"; type LoginResult = | { success: false } - | { success: true; user: User; token: string }; + | { success: true; user: User; token: string; oauth?: OAuthTokenData }; interface LoginOptions { safeHostname: string; @@ -28,15 +30,23 @@ interface LoginOptions { /** * Coordinates login prompts across windows and prevents duplicate dialogs. */ -export class LoginCoordinator { - private readonly inProgressLogins = new Map>(); +export class LoginCoordinator implements vscode.Disposable { + private loginQueue: Promise = Promise.resolve(); + private readonly oauthAuthorizer: OAuthAuthorizer; constructor( private readonly secretsManager: SecretsManager, private readonly mementoManager: MementoManager, private readonly vscodeProposed: typeof vscode, private readonly logger: Logger, - ) {} + extensionId: string, + ) { + this.oauthAuthorizer = new OAuthAuthorizer( + secretsManager, + logger, + extensionId, + ); + } /** * Direct login - for user-initiated login via commands. @@ -46,7 +56,7 @@ export class LoginCoordinator { options: LoginOptions & { url: string }, ): Promise { const { safeHostname, url } = options; - return this.executeWithGuard(safeHostname, async () => { + return this.executeWithGuard(async () => { const result = await this.attemptLogin( { safeHostname, url }, options.autoLogin ?? false, @@ -60,13 +70,13 @@ export class LoginCoordinator { } /** - * Shows dialog then login - for system-initiated auth (remote). + * Shows dialog then login - for system-initiated auth (remote, OAuth refresh). */ public async ensureLoggedInWithDialog( options: LoginOptions & { message?: string; detailPrefix?: string }, ): Promise { const { safeHostname, url, detailPrefix, message } = options; - return this.executeWithGuard(safeHostname, async () => { + return this.executeWithGuard(async () => { // Show dialog promise const dialogPromise = this.vscodeProposed.window .showErrorMessage( @@ -132,31 +142,21 @@ export class LoginCoordinator { await this.secretsManager.setSessionAuth(safeHostname, { url, token: result.token, + oauth: result.oauth, // undefined for non-OAuth logins }); await this.mementoManager.addToUrlHistory(url); } } /** - * Same-window guard wrapper. + * Chains login attempts to prevent overlapping UI. */ - private async executeWithGuard( - safeHostname: string, + private executeWithGuard( executeFn: () => Promise, ): Promise { - const existingLogin = this.inProgressLogins.get(safeHostname); - if (existingLogin) { - return existingLogin; - } - - const loginPromise = executeFn(); - this.inProgressLogins.set(safeHostname, loginPromise); - - try { - return await loginPromise; - } finally { - this.inProgressLogins.delete(safeHostname); - } + const result = this.loginQueue.then(executeFn); + this.loginQueue = result.catch(() => {}); // Keep chain going on error + return result; } /** @@ -193,7 +193,7 @@ export class LoginCoordinator { } /** - * Attempt to authenticate using token, or mTLS. If necessary, prompts + * Attempt to authenticate using OAuth, token, or mTLS. If necessary, prompts * for authentication method and credentials. Returns the token and user upon * successful authentication. Null means the user aborted or authentication * failed (in which case an error notification will have been displayed). @@ -234,7 +234,15 @@ export class LoginCoordinator { } // Prompt user for token - return this.loginWithToken(client); + const authMethod = await maybeAskAuthMethod(client); + switch (authMethod) { + case "oauth": + return this.loginWithOAuth(deployment); + case "legacy": + return this.loginWithToken(client); + case undefined: + return { success: false }; // User aborted + } } private async tryMtlsAuth( @@ -346,4 +354,49 @@ export class LoginCoordinator { return { success: false }; } + + /** + * OAuth authentication flow. + */ + private async loginWithOAuth(deployment: Deployment): Promise { + try { + this.logger.info("Starting OAuth authentication"); + + const { tokenResponse, user } = await vscode.window.withProgress( + { + location: vscode.ProgressLocation.Notification, + title: "Authenticating", + cancellable: true, + }, + async (progress, cancellationToken) => + await this.oauthAuthorizer.login( + deployment, + progress, + cancellationToken, + ), + ); + + return { + success: true, + token: tokenResponse.access_token, + user, + oauth: buildOAuthTokenData(tokenResponse), + }; + } catch (error) { + const title = "OAuth authentication failed"; + this.logger.error(title, error); + if (error instanceof CertificateError) { + error.showNotification(title); + } else { + vscode.window.showErrorMessage( + `${title}: ${getErrorMessage(error, "Unknown error")}`, + ); + } + return { success: false }; + } + } + + public dispose(): void { + this.oauthAuthorizer.dispose(); + } } diff --git a/src/oauth/authorizer.ts b/src/oauth/authorizer.ts new file mode 100644 index 00000000..b03847af --- /dev/null +++ b/src/oauth/authorizer.ts @@ -0,0 +1,347 @@ +import { type AxiosInstance } from "axios"; +import { type User } from "coder/site/src/api/typesGenerated"; +import * as vscode from "vscode"; + +import { CoderApi } from "../api/coderApi"; +import { type SecretsManager } from "../core/secretsManager"; +import { type Deployment } from "../deployment/types"; +import { type Logger } from "../logging/logger"; + +import { OAuthMetadataClient } from "./metadataClient"; +import { + CALLBACK_PATH, + generatePKCE, + generateState, + toUrlSearchParams, +} from "./utils"; + +import type { + ClientRegistrationRequest, + ClientRegistrationResponse, + OAuthServerMetadata, + TokenRequestParams, + TokenResponse, +} from "./types"; + +const AUTH_GRANT_TYPE = "authorization_code"; +const RESPONSE_TYPE = "code"; +const PKCE_CHALLENGE_METHOD = "S256"; + +/** + * Minimal scopes required by the VS Code extension. + */ +const DEFAULT_OAUTH_SCOPES = [ + "workspace:read", + "workspace:update", + "workspace:start", + "workspace:ssh", + "workspace:application_connect", + "template:read", + "user:read_personal", +].join(" "); + +/** + * Handles the OAuth authorization code flow for authenticating with Coder deployments. + * Encapsulates client registration, PKCE challenge, and token exchange. + */ +export class OAuthAuthorizer implements vscode.Disposable { + private pendingAuthReject: ((error: Error) => void) | null = null; + + constructor( + private readonly secretsManager: SecretsManager, + private readonly logger: Logger, + private readonly extensionId: string, + ) {} + + /** + * Perform complete OAuth login flow. + * Creates CoderApi internally from deployment. + * Returns the token response and user - does not persist tokens. + */ + public async login( + deployment: Deployment, + progress: vscode.Progress<{ message?: string; increment?: number }>, + cancellationToken: vscode.CancellationToken, + ): Promise<{ tokenResponse: TokenResponse; user: User }> { + const reportProgress = (message?: string, increment?: number): void => { + if (cancellationToken.isCancellationRequested) { + throw new Error("OAuth login cancelled by user"); + } + progress.report({ message, increment }); + }; + + const client = CoderApi.create(deployment.url, undefined, this.logger); + const axiosInstance = client.getAxiosInstance(); + + reportProgress("fetching metadata...", 10); + const metadataClient = new OAuthMetadataClient(axiosInstance, this.logger); + const metadata = await metadataClient.getMetadata(); + + reportProgress("registering client...", 10); + const registration = await this.registerClient( + deployment, + axiosInstance, + metadata, + ); + + reportProgress("waiting for authorization...", 30); + const { code, verifier } = await this.startAuthorization( + metadata, + registration, + cancellationToken, + ); + + reportProgress("exchanging token...", 30); + const tokenResponse = await this.exchangeToken( + code, + verifier, + axiosInstance, + metadata, + registration, + ); + + // Set token on client to fetch user + client.setSessionToken(tokenResponse.access_token); + + reportProgress("fetching user...", 20); + const user = await client.getAuthenticatedUser(); + + this.logger.info("OAuth login flow completed successfully"); + + return { + tokenResponse, + user, + }; + } + + /** + * Get the redirect URI for OAuth callbacks. + */ + private getRedirectUri(): string { + return `${vscode.env.uriScheme}://${this.extensionId}${CALLBACK_PATH}`; + } + + /** + * Register OAuth client or return existing if still valid. + * Re-registers if redirect URI has changed. + */ + private async registerClient( + deployment: Deployment, + axiosInstance: AxiosInstance, + metadata: OAuthServerMetadata, + ): Promise { + const redirectUri = this.getRedirectUri(); + + const existing = await this.secretsManager.getOAuthClientRegistration( + deployment.safeHostname, + ); + if (existing?.client_id) { + if (existing.redirect_uris.includes(redirectUri)) { + this.logger.debug( + "Using existing client registration:", + existing.client_id, + ); + return existing; + } + this.logger.debug("Redirect URI changed, re-registering client"); + } + + if (!metadata.registration_endpoint) { + throw new Error("Server does not support dynamic client registration"); + } + + const registrationRequest: ClientRegistrationRequest = { + redirect_uris: [redirectUri], + application_type: "web", + grant_types: ["authorization_code"], + response_types: ["code"], + client_name: "VS Code Coder Extension", + token_endpoint_auth_method: "client_secret_post", + }; + + const response = await axiosInstance.post( + metadata.registration_endpoint, + registrationRequest, + ); + + await this.secretsManager.setOAuthClientRegistration( + deployment.safeHostname, + response.data, + ); + this.logger.info( + "Saved OAuth client registration:", + response.data.client_id, + ); + + return response.data; + } + + /** + * Build authorization URL with all required OAuth 2.1 parameters. + */ + private buildAuthorizationUrl( + metadata: OAuthServerMetadata, + clientId: string, + state: string, + challenge: string, + ): string { + if (metadata.scopes_supported) { + const requestedScopes = DEFAULT_OAUTH_SCOPES.split(" "); + const unsupportedScopes = requestedScopes.filter( + (s) => !metadata.scopes_supported?.includes(s), + ); + if (unsupportedScopes.length > 0) { + this.logger.warn( + `Requested scopes not in server's supported scopes: ${unsupportedScopes.join(", ")}. Server may still accept them.`, + { supported_scopes: metadata.scopes_supported }, + ); + } + } + + const params = new URLSearchParams({ + client_id: clientId, + response_type: RESPONSE_TYPE, + redirect_uri: this.getRedirectUri(), + scope: DEFAULT_OAUTH_SCOPES, + state, + code_challenge: challenge, + code_challenge_method: PKCE_CHALLENGE_METHOD, + }); + + const url = `${metadata.authorization_endpoint}?${params.toString()}`; + + this.logger.debug("Built OAuth authorization URL:", { + client_id: clientId, + redirect_uri: this.getRedirectUri(), + scope: DEFAULT_OAUTH_SCOPES, + }); + + return url; + } + + /** + * Start OAuth authorization flow. + * Opens browser for user authentication and waits for callback. + * Returns authorization code and PKCE verifier on success. + */ + private async startAuthorization( + metadata: OAuthServerMetadata, + registration: ClientRegistrationResponse, + cancellationToken: vscode.CancellationToken, + ): Promise<{ code: string; verifier: string }> { + const state = generateState(); + const { verifier, challenge } = generatePKCE(); + + const authUrl = this.buildAuthorizationUrl( + metadata, + registration.client_id, + state, + challenge, + ); + + const callbackPromise = new Promise<{ code: string; verifier: string }>( + (resolve, reject) => { + // Track reject for disposal + this.pendingAuthReject = reject; + + const timeoutMins = 5; + const timeoutHandle = setTimeout( + () => { + cleanup(); + reject( + new Error(`OAuth flow timed out after ${timeoutMins} minutes`), + ); + }, + timeoutMins * 60 * 1000, + ); + + const listener = this.secretsManager.onDidChangeOAuthCallback( + ({ state: callbackState, code, error }) => { + if (callbackState !== state) { + return; + } + + cleanup(); + + if (error) { + reject(new Error(`OAuth error: ${error}`)); + } else if (code) { + resolve({ code, verifier }); + } else { + reject(new Error("No authorization code received")); + } + }, + ); + + const cancellationListener = cancellationToken.onCancellationRequested( + () => { + cleanup(); + reject(new Error("OAuth flow cancelled by user")); + }, + ); + + const cleanup = () => { + this.pendingAuthReject = null; + clearTimeout(timeoutHandle); + listener.dispose(); + cancellationListener.dispose(); + }; + }, + ); + + try { + await vscode.env.openExternal(vscode.Uri.parse(authUrl)); + } catch (error) { + throw error instanceof Error + ? error + : new Error("Failed to open browser"); + } + + return callbackPromise; + } + + /** + * Exchange authorization code for access token. + */ + private async exchangeToken( + code: string, + verifier: string, + axiosInstance: AxiosInstance, + metadata: OAuthServerMetadata, + registration: ClientRegistrationResponse, + ): Promise { + this.logger.debug("Exchanging authorization code for token"); + + const params: TokenRequestParams = { + grant_type: AUTH_GRANT_TYPE, + code, + redirect_uri: this.getRedirectUri(), + client_id: registration.client_id, + client_secret: registration.client_secret, + code_verifier: verifier, + }; + + const tokenRequest = toUrlSearchParams(params); + + const response = await axiosInstance.post( + metadata.token_endpoint, + tokenRequest, + { + headers: { + "Content-Type": "application/x-www-form-urlencoded", + }, + }, + ); + + this.logger.debug("Token exchange successful"); + + return response.data; + } + + public dispose(): void { + if (this.pendingAuthReject) { + this.pendingAuthReject(new Error("OAuthAuthorizer disposed")); + this.pendingAuthReject = null; + } + } +} diff --git a/src/oauth/axiosInterceptor.ts b/src/oauth/axiosInterceptor.ts new file mode 100644 index 00000000..f2ba68a2 --- /dev/null +++ b/src/oauth/axiosInterceptor.ts @@ -0,0 +1,179 @@ +import { type AxiosError, isAxiosError } from "axios"; + +import type * as vscode from "vscode"; + +import type { CoderApi } from "../api/coderApi"; +import type { SecretsManager } from "../core/secretsManager"; +import type { Logger } from "../logging/logger"; +import type { RequestConfigWithMeta } from "../logging/types"; + +import type { OAuthSessionManager } from "./sessionManager"; + +const coderSessionTokenHeader = "Coder-Session-Token"; + +/** + * Manages OAuth interceptor lifecycle reactively based on token presence. + * + * Automatically attaches/detaches the interceptor when OAuth tokens appear/disappear + * in secrets storage. This ensures the interceptor state always matches the actual + * OAuth authentication state. + */ +export class OAuthInterceptor implements vscode.Disposable { + private interceptorId: number | null = null; + private tokenListener: vscode.Disposable | undefined; + private safeHostname: string; + + private constructor( + private readonly client: CoderApi, + private readonly logger: Logger, + private readonly oauthSessionManager: OAuthSessionManager, + private readonly secretsManager: SecretsManager, + safeHostname: string, + ) { + this.safeHostname = safeHostname; + } + + public static async create( + client: CoderApi, + logger: Logger, + oauthSessionManager: OAuthSessionManager, + secretsManager: SecretsManager, + safeHostname: string, + ): Promise { + const instance = new OAuthInterceptor( + client, + logger, + oauthSessionManager, + secretsManager, + safeHostname, + ); + + instance.setupTokenListener(); + await instance.syncWithTokenState(); + return instance; + } + + public async setDeployment(safeHostname: string): Promise { + if (this.safeHostname === safeHostname) { + return; + } + + this.safeHostname = safeHostname; + this.detach(); + this.setupTokenListener(); + await this.syncWithTokenState(); + } + + public clearDeployment(): void { + this.tokenListener?.dispose(); + this.tokenListener = undefined; + this.detach(); + } + + private setupTokenListener(): void { + this.tokenListener?.dispose(); + + if (!this.safeHostname) { + this.tokenListener = undefined; + return; + } + + this.tokenListener = this.secretsManager.onDidChangeSessionAuth( + this.safeHostname, + () => { + this.syncWithTokenState().catch((err) => { + this.logger.error("Error syncing OAuth interceptor state:", err); + }); + }, + ); + } + + /** + * Sync interceptor state with OAuth token presence. + * Attaches when tokens exist, detaches when they don't. + */ + private async syncWithTokenState(): Promise { + const isOAuth = await this.oauthSessionManager.isLoggedInWithOAuth(); + if (isOAuth && this.interceptorId === null) { + this.attach(); + } else if (!isOAuth && this.interceptorId !== null) { + this.detach(); + } + } + + private attach(): void { + if (this.interceptorId !== null) { + return; + } + + this.interceptorId = this.client + .getAxiosInstance() + .interceptors.response.use( + (r) => r, + (error: unknown) => this.handleError(error), + ); + + this.logger.debug("OAuth interceptor attached"); + } + + private detach(): void { + if (this.interceptorId === null) { + return; + } + + this.client + .getAxiosInstance() + .interceptors.response.eject(this.interceptorId); + this.interceptorId = null; + this.logger.debug("OAuth interceptor detached"); + } + + private async handleError(error: unknown): Promise { + if (!isAxiosError(error)) { + throw error; + } + + if (error.config) { + const config = error.config as { _oauthRetryAttempted?: boolean }; + if (config._oauthRetryAttempted) { + throw error; + } + } + + if (error.response?.status === 401) { + return this.handle401Error(error); + } + + throw error; + } + + private async handle401Error(error: AxiosError): Promise { + this.logger.info("Received 401 response, attempting token refresh"); + + try { + const newTokens = await this.oauthSessionManager.refreshToken(); + this.client.setSessionToken(newTokens.access_token); + + this.logger.info("Token refresh successful, retrying request"); + + if (error.config) { + const config = error.config as RequestConfigWithMeta & { + _oauthRetryAttempted?: boolean; + }; + config._oauthRetryAttempted = true; + config.headers[coderSessionTokenHeader] = newTokens.access_token; + return this.client.getAxiosInstance().request(config); + } + + throw error; + } catch (refreshError) { + this.logger.error("Token refresh failed:", refreshError); + throw error; + } + } + + public dispose(): void { + this.tokenListener?.dispose(); + this.detach(); + } +} diff --git a/src/oauth/errors.ts b/src/oauth/errors.ts new file mode 100644 index 00000000..9b7ee3ac --- /dev/null +++ b/src/oauth/errors.ts @@ -0,0 +1,166 @@ +import { isAxiosError } from "axios"; + +import type { OAuthErrorResponse } from "./types"; + +/** + * Base class for OAuth errors + */ +export class OAuthError extends Error { + constructor( + message: string, + public readonly errorCode: string, + public readonly description?: string, + public readonly errorUri?: string, + ) { + super(message); + this.name = "OAuthError"; + } +} + +/** + * Refresh token is invalid, expired, or revoked. Requires re-authentication. + */ +export class InvalidGrantError extends OAuthError { + constructor(description?: string, errorUri?: string) { + super( + "OAuth refresh token is invalid, expired, or revoked", + "invalid_grant", + description, + errorUri, + ); + this.name = "InvalidGrantError"; + } +} + +/** + * Client credentials are invalid. Requires re-registration. + */ +export class InvalidClientError extends OAuthError { + constructor(description?: string, errorUri?: string) { + super( + "OAuth client credentials are invalid", + "invalid_client", + description, + errorUri, + ); + this.name = "InvalidClientError"; + } +} + +/** + * Invalid request error - malformed OAuth request + */ +export class InvalidRequestError extends OAuthError { + constructor(description?: string, errorUri?: string) { + super( + "OAuth request is malformed or invalid", + "invalid_request", + description, + errorUri, + ); + this.name = "InvalidRequestError"; + } +} + +/** + * Client is not authorized for this grant type. + */ +export class UnauthorizedClientError extends OAuthError { + constructor(description?: string, errorUri?: string) { + super( + "OAuth client is not authorized for this grant type", + "unauthorized_client", + description, + errorUri, + ); + this.name = "UnauthorizedClientError"; + } +} + +/** + * Unsupported grant type error. + */ +export class UnsupportedGrantTypeError extends OAuthError { + constructor(description?: string, errorUri?: string) { + super( + "OAuth grant type is not supported", + "unsupported_grant_type", + description, + errorUri, + ); + this.name = "UnsupportedGrantTypeError"; + } +} + +/** + * Invalid scope error. + */ +export class InvalidScopeError extends OAuthError { + constructor(description?: string, errorUri?: string) { + super( + "OAuth scope is invalid, unknown, malformed, or exceeds the scope granted by the resource owner", + "invalid_scope", + description, + errorUri, + ); + this.name = "InvalidScopeError"; + } +} + +/** + * Parses an axios error to extract OAuth error information + * Returns an OAuthError instance if the error is OAuth-related, otherwise returns null + */ +export function parseOAuthError(error: unknown): OAuthError | null { + if (!isAxiosError(error)) { + return null; + } + + const data = error.response?.data; + + if (!isOAuthErrorResponse(data)) { + return null; + } + + const { error: errorCode, error_description, error_uri } = data; + + switch (errorCode) { + case "invalid_grant": + return new InvalidGrantError(error_description, error_uri); + case "invalid_client": + return new InvalidClientError(error_description, error_uri); + case "invalid_request": + return new InvalidRequestError(error_description, error_uri); + case "unauthorized_client": + return new UnauthorizedClientError(error_description, error_uri); + case "unsupported_grant_type": + return new UnsupportedGrantTypeError(error_description, error_uri); + case "invalid_scope": + return new InvalidScopeError(error_description, error_uri); + default: + return new OAuthError( + `OAuth error: ${errorCode}`, + errorCode, + error_description, + error_uri, + ); + } +} + +function isOAuthErrorResponse(data: unknown): data is OAuthErrorResponse { + return ( + data !== null && + typeof data === "object" && + "error" in data && + typeof data.error === "string" + ); +} + +/** + * Checks if an error requires re-authentication + */ +export function requiresReAuthentication(error: OAuthError): boolean { + return ( + error instanceof InvalidGrantError || error instanceof InvalidClientError + ); +} diff --git a/src/oauth/metadataClient.ts b/src/oauth/metadataClient.ts new file mode 100644 index 00000000..38e25e7b --- /dev/null +++ b/src/oauth/metadataClient.ts @@ -0,0 +1,142 @@ +import type { AxiosInstance } from "axios"; + +import type { Logger } from "../logging/logger"; + +import type { + GrantType, + OAuthServerMetadata, + ResponseType, + TokenEndpointAuthMethod, +} from "./types"; + +const OAUTH_DISCOVERY_ENDPOINT = "/.well-known/oauth-authorization-server"; + +const AUTH_GRANT_TYPE = "authorization_code" as const; +const REFRESH_GRANT_TYPE = "refresh_token" as const; +const RESPONSE_TYPE = "code" as const; +const OAUTH_METHOD = "client_secret_post" as const; +const PKCE_CHALLENGE_METHOD = "S256" as const; + +const REQUIRED_GRANT_TYPES = [AUTH_GRANT_TYPE, REFRESH_GRANT_TYPE] as const; + +// RFC 8414 defaults when fields are omitted +const DEFAULT_GRANT_TYPES = [AUTH_GRANT_TYPE] as GrantType[]; +const DEFAULT_RESPONSE_TYPES = [RESPONSE_TYPE] as ResponseType[]; +const DEFAULT_AUTH_METHODS = [ + "client_secret_basic", +] as TokenEndpointAuthMethod[]; + +/** + * Client for discovering and validating OAuth server metadata. + */ +export class OAuthMetadataClient { + constructor( + private readonly axiosInstance: AxiosInstance, + private readonly logger: Logger, + ) {} + + /** + * Check if a server supports OAuth by attempting to fetch the well-known endpoint. + */ + public static async checkOAuthSupport( + axiosInstance: AxiosInstance, + ): Promise { + try { + await axiosInstance.get(OAUTH_DISCOVERY_ENDPOINT); + return true; + } catch { + return false; + } + } + + /** + * Fetch and validate OAuth server metadata. + * Throws detailed errors if server doesn't meet OAuth 2.1 requirements. + */ + async getMetadata(): Promise { + this.logger.debug("Discovering OAuth endpoints..."); + + const response = await this.axiosInstance.get( + OAUTH_DISCOVERY_ENDPOINT, + ); + + const metadata = response.data; + + this.validateRequiredEndpoints(metadata); + this.validateGrantTypes(metadata); + this.validateResponseTypes(metadata); + this.validateAuthMethods(metadata); + this.validatePKCEMethods(metadata); + + this.logger.debug("OAuth endpoints discovered:", { + authorization: metadata.authorization_endpoint, + token: metadata.token_endpoint, + registration: metadata.registration_endpoint, + revocation: metadata.revocation_endpoint, + }); + + return metadata; + } + + private validateRequiredEndpoints(metadata: OAuthServerMetadata): void { + if ( + !metadata.authorization_endpoint || + !metadata.token_endpoint || + !metadata.issuer + ) { + throw new Error( + "OAuth server metadata missing required endpoints: " + + JSON.stringify(metadata), + ); + } + } + + private validateGrantTypes(metadata: OAuthServerMetadata): void { + const supported = metadata.grant_types_supported ?? DEFAULT_GRANT_TYPES; + if (!includesAllTypes(supported, REQUIRED_GRANT_TYPES)) { + throw new Error( + `Server does not support required grant types: ${REQUIRED_GRANT_TYPES.join(", ")}. Supported: ${supported.join(", ")}`, + ); + } + } + + private validateResponseTypes(metadata: OAuthServerMetadata): void { + const supported = + metadata.response_types_supported ?? DEFAULT_RESPONSE_TYPES; + if (!includesAllTypes(supported, [RESPONSE_TYPE])) { + throw new Error( + `Server does not support required response type: ${RESPONSE_TYPE}. Supported: ${supported.join(", ")}`, + ); + } + } + + private validateAuthMethods(metadata: OAuthServerMetadata): void { + const supported = + metadata.token_endpoint_auth_methods_supported ?? DEFAULT_AUTH_METHODS; + if (!includesAllTypes(supported, [OAUTH_METHOD])) { + throw new Error( + `Server does not support required auth method: ${OAUTH_METHOD}. Supported: ${supported.join(", ")}`, + ); + } + } + + private validatePKCEMethods(metadata: OAuthServerMetadata): void { + // PKCE has no RFC 8414 default - if undefined, server doesn't advertise support + const supported = metadata.code_challenge_methods_supported ?? []; + if (!includesAllTypes(supported, [PKCE_CHALLENGE_METHOD])) { + throw new Error( + `Server does not support required PKCE method: ${PKCE_CHALLENGE_METHOD}. Supported: ${supported.length > 0 ? supported.join(", ") : "none"}`, + ); + } + } +} + +/** + * Check if an array includes all required types. + */ +function includesAllTypes( + arr: readonly string[], + requiredTypes: readonly string[], +): boolean { + return requiredTypes.every((type) => arr.includes(type)); +} diff --git a/src/oauth/sessionManager.ts b/src/oauth/sessionManager.ts new file mode 100644 index 00000000..7d3b1b51 --- /dev/null +++ b/src/oauth/sessionManager.ts @@ -0,0 +1,601 @@ +import { type AxiosInstance } from "axios"; + +import { CoderApi } from "../api/coderApi"; +import { type ServiceContainer } from "../core/container"; +import { + type OAuthTokenData, + type SecretsManager, + type SessionAuth, +} from "../core/secretsManager"; +import { type Deployment } from "../deployment/types"; +import { type Logger } from "../logging/logger"; +import { type LoginCoordinator } from "../login/loginCoordinator"; + +import { + type OAuthError, + parseOAuthError, + requiresReAuthentication, +} from "./errors"; +import { OAuthMetadataClient } from "./metadataClient"; +import { buildOAuthTokenData, toUrlSearchParams } from "./utils"; + +import type * as vscode from "vscode"; + +import type { + ClientRegistrationResponse, + OAuthServerMetadata, + RefreshTokenRequestParams, + TokenResponse, + TokenRevocationRequest, +} from "./types"; + +const REFRESH_GRANT_TYPE = "refresh_token"; + +/** + * Token refresh threshold: refresh when token expires in less than this time. + */ +const TOKEN_REFRESH_THRESHOLD_MS = 10 * 60 * 1000; + +/** + * Minimum time between refresh attempts to prevent thrashing. + */ +const REFRESH_THROTTLE_MS = 30 * 1000; + +/** + * Background token refresh check interval. + */ +const BACKGROUND_REFRESH_INTERVAL_MS = 5 * 60 * 1000; + +/** + * Minimal scopes required by the VS Code extension. + */ +const DEFAULT_OAUTH_SCOPES = [ + "workspace:read", + "workspace:update", + "workspace:start", + "workspace:ssh", + "workspace:application_connect", + "template:read", + "user:read_personal", +].join(" "); + +/** + * Internal type combining access token with OAuth-specific data. + * Used by getStoredTokens() for token refresh and validation. + */ +type StoredTokens = OAuthTokenData & { + access_token: string; +}; + +/** + * Manages OAuth session lifecycle for a Coder deployment. + * Coordinates authorization flow, token management, and automatic refresh. + */ +export class OAuthSessionManager implements vscode.Disposable { + private refreshPromise: Promise | null = null; + private lastRefreshAttempt = 0; + private refreshTimer: NodeJS.Timeout | undefined; + private tokenChangeListener: vscode.Disposable | undefined; + private disposed = false; + + /** + * Create and initialize a new OAuth session manager. + */ + public static create( + deployment: Deployment | null, + container: ServiceContainer, + ): OAuthSessionManager { + const manager = new OAuthSessionManager( + deployment, + container.getSecretsManager(), + container.getLogger(), + container.getLoginCoordinator(), + ); + manager.setupTokenListener(); + manager.scheduleNextRefresh(); + return manager; + } + + private constructor( + private deployment: Deployment | null, + private readonly secretsManager: SecretsManager, + private readonly logger: Logger, + private readonly loginCoordinator: LoginCoordinator, + ) {} + + /** + * Get current deployment, throwing if not set. + * Use this in methods that require a deployment to be configured. + */ + private requireDeployment(): Deployment { + if (!this.deployment) { + throw new Error("No deployment configured for OAuth session manager"); + } + return this.deployment; + } + + /** + * Get stored tokens fresh from secrets manager. + * Always reads from storage to ensure cross-window synchronization. + * Validates that tokens match current deployment URL and have required scopes. + * Invalid tokens are cleared and undefined is returned. + */ + private async getStoredTokens(): Promise { + if (!this.deployment) { + return undefined; + } + + const auth = await this.secretsManager.getSessionAuth( + this.deployment.safeHostname, + ); + if (!auth?.oauth) { + return undefined; + } + + // Validate deployment URL matches + if (auth.url !== this.deployment.url) { + this.logger.warn( + "Stored tokens have mismatched deployment URL, clearing OAuth", + { stored: auth.url, current: this.deployment.url }, + ); + await this.clearOAuthFromSessionAuth(auth); + return undefined; + } + + if (!this.hasRequiredScopes(auth.oauth.scope)) { + this.logger.warn("Stored tokens have insufficient scopes, clearing", { + scope: auth.oauth.scope, + }); + await this.clearOAuthFromSessionAuth(auth); + return undefined; + } + + return { + access_token: auth.token, + ...auth.oauth, + }; + } + + /** + * Clear OAuth data from session auth while preserving the session token. + */ + private async clearOAuthFromSessionAuth(auth: SessionAuth): Promise { + if (!this.deployment) { + return; + } + await this.secretsManager.setSessionAuth(this.deployment.safeHostname, { + url: auth.url, + token: auth.token, + }); + } + + /** + * Clear all refresh-related state: in-flight promise, throttle, timer, and listener. + */ + private clearRefreshState(): void { + this.refreshPromise = null; + this.lastRefreshAttempt = 0; + if (this.refreshTimer) { + clearTimeout(this.refreshTimer); + this.refreshTimer = undefined; + } + this.tokenChangeListener?.dispose(); + this.tokenChangeListener = undefined; + } + + /** + * Setup listener for token changes. Disposes existing listener first. + * Reschedules refresh when tokens change (e.g., from another window). + */ + private setupTokenListener(): void { + this.tokenChangeListener?.dispose(); + this.tokenChangeListener = undefined; + + if (!this.deployment) { + return; + } + + this.tokenChangeListener = this.secretsManager.onDidChangeSessionAuth( + this.deployment.safeHostname, + (auth) => { + if (auth?.oauth) { + this.scheduleNextRefresh(); + } else { + this.clearRefreshState(); + } + }, + ); + } + + /** + * Schedule the next token refresh based on expiry time. + * - Far from expiry: schedule once at threshold + * - Near/past expiry: attempt refresh immediately + */ + private scheduleNextRefresh(): void { + if (this.refreshTimer) { + clearTimeout(this.refreshTimer); + this.refreshTimer = undefined; + } + + this.getStoredTokens() + .then((storedTokens) => { + if (!storedTokens?.refresh_token) { + return; + } + + const now = Date.now(); + const timeUntilExpiry = storedTokens.expiry_timestamp - now; + + if (timeUntilExpiry <= TOKEN_REFRESH_THRESHOLD_MS) { + // Within threshold or expired, attempt refresh now + this.attemptRefreshWithRetry(); + } else { + // Schedule for when we reach the threshold + const delay = timeUntilExpiry - TOKEN_REFRESH_THRESHOLD_MS; + this.logger.debug( + `Scheduling token refresh in ${Math.round(delay / 1000 / 60)} minutes`, + ); + this.refreshTimer = setTimeout( + () => this.attemptRefreshWithRetry(), + delay, + ); + } + }) + .catch((error) => { + this.logger.warn("Failed to schedule token refresh:", error); + }); + } + + /** + * Attempt refresh, falling back to polling on failure. + */ + private attemptRefreshWithRetry(): void { + if (this.disposed) { + return; + } + + this.refreshTimer = undefined; + + this.refreshToken() + .then(() => { + this.logger.debug("Background token refresh succeeded"); + }) + .catch((error) => { + if (this.disposed) { + return; + } + this.logger.warn("Background token refresh failed, will retry:", error); + this.refreshTimer = setTimeout( + () => this.attemptRefreshWithRetry(), + BACKGROUND_REFRESH_INTERVAL_MS, + ); + }); + } + + /** + * Check if granted scopes cover all required scopes. + * Supports wildcard scopes like "workspace:*". + */ + private hasRequiredScopes(grantedScope: string | undefined): boolean { + if (!grantedScope) { + // TODO server always returns empty scopes + return true; + } + + const grantedScopes = new Set(grantedScope.split(" ")); + const requiredScopes = DEFAULT_OAUTH_SCOPES.split(" "); + + for (const required of requiredScopes) { + if (grantedScopes.has(required)) { + continue; + } + + // Check wildcard match (e.g., "workspace:*" grants "workspace:read") + const colonIndex = required.indexOf(":"); + if (colonIndex !== -1) { + const prefix = required.substring(0, colonIndex); + const wildcard = `${prefix}:*`; + if (grantedScopes.has(wildcard)) { + continue; + } + } + + return false; + } + + return true; + } + + /** + * Prepare common OAuth operation setup: client, metadata, and registration. + * Used by refresh and revoke operations to reduce duplication. + */ + private async prepareOAuthOperation(token?: string): Promise<{ + axiosInstance: AxiosInstance; + metadata: OAuthServerMetadata; + registration: ClientRegistrationResponse; + }> { + const deployment = this.requireDeployment(); + const client = CoderApi.create(deployment.url, token, this.logger); + const axiosInstance = client.getAxiosInstance(); + + const metadataClient = new OAuthMetadataClient(axiosInstance, this.logger); + const metadata = await metadataClient.getMetadata(); + + const registration = await this.secretsManager.getOAuthClientRegistration( + deployment.safeHostname, + ); + if (!registration) { + throw new Error("No client registration found"); + } + + return { axiosInstance, metadata, registration }; + } + + public async setDeployment(deployment: Deployment): Promise { + if ( + this.deployment && + deployment.safeHostname === this.deployment.safeHostname && + deployment.url === this.deployment.url + ) { + return; + } + this.logger.debug("Switching OAuth deployment", deployment); + this.deployment = deployment; + this.clearRefreshState(); + + // Block on refresh if token is expired to ensure valid state for callers + const storedTokens = await this.getStoredTokens(); + if (storedTokens && Date.now() >= storedTokens.expiry_timestamp) { + try { + await this.refreshToken(); + } catch (error) { + this.logger.warn("Token refresh failed (expired):", error); + } + } + + // Schedule after blocking refresh to avoid concurrent attempts + this.setupTokenListener(); + this.scheduleNextRefresh(); + } + + public clearDeployment(): void { + this.logger.debug("Clearing OAuth deployment state"); + this.deployment = null; + this.clearRefreshState(); + } + + /** + * Refresh the access token using the stored refresh token. + * Uses a shared promise to handle concurrent refresh attempts. + */ + public async refreshToken(): Promise { + if (this.refreshPromise) { + this.logger.debug( + "Token refresh already in progress, waiting for result", + ); + return this.refreshPromise; + } + + const deployment = this.requireDeployment(); + // Assign synchronously before any async work to prevent race conditions + this.refreshPromise = this.executeTokenRefresh(deployment); + return this.refreshPromise; + } + + private async executeTokenRefresh( + deployment: Deployment, + ): Promise { + try { + const storedTokens = await this.getStoredTokens(); + if (!storedTokens?.refresh_token) { + throw new Error("No refresh token available"); + } + + const refreshToken = storedTokens.refresh_token; + const accessToken = storedTokens.access_token; + + this.lastRefreshAttempt = Date.now(); + + const { axiosInstance, metadata, registration } = + await this.prepareOAuthOperation(accessToken); + + this.logger.debug("Refreshing access token"); + + const params: RefreshTokenRequestParams = { + grant_type: REFRESH_GRANT_TYPE, + refresh_token: refreshToken, + client_id: registration.client_id, + client_secret: registration.client_secret, + }; + + const tokenRequest = toUrlSearchParams(params); + + const response = await axiosInstance.post( + metadata.token_endpoint, + tokenRequest, + { + headers: { + "Content-Type": "application/x-www-form-urlencoded", + }, + }, + ); + + this.logger.debug("Token refresh successful"); + + const oauthData = buildOAuthTokenData(response.data); + await this.secretsManager.setSessionAuth(deployment.safeHostname, { + url: deployment.url, + token: response.data.access_token, + oauth: oauthData, + }); + + return response.data; + } catch (error) { + this.handleOAuthError(error); + throw error; + } finally { + this.refreshPromise = null; + } + } + + /** + * Refreshes the token if it is approaching expiry. + */ + public async refreshIfAlmostExpired(): Promise { + if (await this.shouldRefreshToken()) { + this.logger.debug("Token approaching expiry, triggering refresh"); + await this.refreshToken(); + } + } + + /** + * Check if token should be refreshed. + * Returns true if: + * 1. Stored tokens exist with a refresh token + * 2. Token expires in less than TOKEN_REFRESH_THRESHOLD_MS + * 3. Last refresh attempt was more than REFRESH_THROTTLE_MS ago + * 4. No refresh is currently in progress + */ + private async shouldRefreshToken(): Promise { + const storedTokens = await this.getStoredTokens(); + if (!storedTokens?.refresh_token || this.refreshPromise !== null) { + return false; + } + + const now = Date.now(); + if (now - this.lastRefreshAttempt < REFRESH_THROTTLE_MS) { + return false; + } + + const timeUntilExpiry = storedTokens.expiry_timestamp - now; + return timeUntilExpiry < TOKEN_REFRESH_THRESHOLD_MS; + } + + public async revokeRefreshToken(): Promise { + const storedTokens = await this.getStoredTokens(); + if (!storedTokens?.refresh_token) { + this.logger.info("No refresh token to revoke"); + return; + } + + await this.revokeToken( + storedTokens.access_token, + storedTokens.refresh_token, + "refresh_token", + ); + } + + /** + * Revoke a token using the OAuth server's revocation endpoint. + * + * @param authToken - Token for authenticating the revocation request + * @param tokenToRevoke - The token to be revoked + * @param tokenTypeHint - Hint about the token type being revoked + */ + private async revokeToken( + authToken: string, + tokenToRevoke: string, + tokenTypeHint: "access_token" | "refresh_token" = "refresh_token", + ): Promise { + const { axiosInstance, metadata, registration } = + await this.prepareOAuthOperation(authToken); + + if (!metadata.revocation_endpoint) { + this.logger.info("No revocation endpoint available, skipping revocation"); + return; + } + + this.logger.info("Revoking refresh token"); + + const params: TokenRevocationRequest = { + token: tokenToRevoke, + client_id: registration.client_id, + client_secret: registration.client_secret, + token_type_hint: tokenTypeHint, + }; + + const revocationRequest = toUrlSearchParams(params); + + try { + await axiosInstance.post( + metadata.revocation_endpoint, + revocationRequest, + { + headers: { + "Content-Type": "application/x-www-form-urlencoded", + }, + }, + ); + + this.logger.info("Token revocation successful"); + } catch (error) { + this.logger.error("Token revocation failed:", error); + throw error; + } + } + + /** + * Returns true if OAuth tokens exist for the current deployment. + * Always reads fresh from secrets to ensure cross-window synchronization. + */ + public async isLoggedInWithOAuth(): Promise { + const storedTokens = await this.getStoredTokens(); + return storedTokens !== undefined; + } + + /** + * Handle OAuth errors that may require re-authentication. + * Parses the error and triggers re-authentication modal if needed. + */ + private handleOAuthError(error: unknown): void { + const oauthError = parseOAuthError(error); + if (oauthError && requiresReAuthentication(oauthError)) { + this.logger.error( + `OAuth operation failed with error: ${oauthError.errorCode}`, + ); + // Fire and forget - don't block on showing the modal + this.showReAuthenticationModal(oauthError).catch((err) => { + this.logger.error("Failed to show re-auth modal:", err); + }); + } + } + + /** + * Show a modal dialog to the user when OAuth re-authentication is required. + * This is called when the refresh token is invalid or the client credentials are invalid. + * Clears tokens directly and lets listeners handle updates. + */ + public async showReAuthenticationModal(error: OAuthError): Promise { + const deployment = this.requireDeployment(); + const errorMessage = + error.description || + "Your session is no longer valid. This could be due to token expiration or revocation."; + + this.clearRefreshState(); + // Clear client registration and tokens to force full re-authentication + await this.secretsManager.clearOAuthClientRegistration( + deployment.safeHostname, + ); + await this.secretsManager.setSessionAuth(deployment.safeHostname, { + url: deployment.url, + token: "", + }); + + await this.loginCoordinator.ensureLoggedInWithDialog({ + safeHostname: deployment.safeHostname, + url: deployment.url, + detailPrefix: errorMessage, + }); + } + + /** + * Clears all in-memory state. + */ + public dispose(): void { + this.disposed = true; + this.clearDeployment(); + this.logger.debug("OAuth session manager disposed"); + } +} diff --git a/src/oauth/types.ts b/src/oauth/types.ts new file mode 100644 index 00000000..6ecaa0ff --- /dev/null +++ b/src/oauth/types.ts @@ -0,0 +1,163 @@ +// OAuth 2.1 Grant Types +export type GrantType = + | "authorization_code" + | "refresh_token" + | "client_credentials"; + +// OAuth 2.1 Response Types +export type ResponseType = "code"; + +// Token Endpoint Authentication Methods +export type TokenEndpointAuthMethod = + | "client_secret_post" + | "client_secret_basic" + | "none"; + +// Application Types +export type ApplicationType = "native" | "web"; + +// PKCE Code Challenge Methods (OAuth 2.1 requires S256) +export type CodeChallengeMethod = "S256"; + +// Token Types +export type TokenType = "Bearer" | "DPoP"; + +// Client Registration Request (RFC 7591 + OAuth 2.1) +export interface ClientRegistrationRequest { + redirect_uris: string[]; + token_endpoint_auth_method: TokenEndpointAuthMethod; + application_type: ApplicationType; + grant_types: GrantType[]; + response_types: ResponseType[]; + client_name?: string; + client_uri?: string; + logo_uri?: string; + scope?: string; + contacts?: string[]; + tos_uri?: string; + policy_uri?: string; + jwks_uri?: string; + software_id?: string; + software_version?: string; +} + +// Client Registration Response (RFC 7591) +export interface ClientRegistrationResponse { + client_id: string; + client_secret?: string; + client_id_issued_at?: number; + client_secret_expires_at?: number; + redirect_uris: string[]; + token_endpoint_auth_method: TokenEndpointAuthMethod; + application_type?: ApplicationType; + grant_types: GrantType[]; + response_types: ResponseType[]; + client_name?: string; + client_uri?: string; + logo_uri?: string; + scope?: string; + contacts?: string[]; + tos_uri?: string; + policy_uri?: string; + jwks_uri?: string; + software_id?: string; + software_version?: string; + registration_client_uri?: string; + registration_access_token?: string; +} + +// OAuth 2.1 Authorization Server Metadata (RFC 8414) +export interface OAuthServerMetadata { + issuer: string; + authorization_endpoint: string; + token_endpoint: string; + registration_endpoint?: string; + jwks_uri?: string; + response_types_supported: ResponseType[]; + grant_types_supported?: GrantType[]; + code_challenge_methods_supported: CodeChallengeMethod[]; + scopes_supported?: string[]; + token_endpoint_auth_methods_supported?: TokenEndpointAuthMethod[]; + revocation_endpoint?: string; + revocation_endpoint_auth_methods_supported?: TokenEndpointAuthMethod[]; + introspection_endpoint?: string; + introspection_endpoint_auth_methods_supported?: TokenEndpointAuthMethod[]; + service_documentation?: string; + ui_locales_supported?: string[]; +} + +// Token Response (RFC 6749 Section 5.1) +export interface TokenResponse { + access_token: string; + token_type: TokenType; + expires_in?: number; + refresh_token?: string; + scope?: string; +} + +// Authorization Request Parameters (OAuth 2.1) +export interface AuthorizationRequestParams { + client_id: string; + response_type: ResponseType; + redirect_uri: string; + scope?: string; + state: string; + code_challenge: string; + code_challenge_method: CodeChallengeMethod; +} + +// Token Request Parameters - Authorization Code Grant (OAuth 2.1) +export interface TokenRequestParams { + grant_type: "authorization_code"; + code: string; + redirect_uri: string; + client_id: string; + code_verifier: string; + client_secret?: string; +} + +// Token Request Parameters - Refresh Token Grant +export interface RefreshTokenRequestParams { + grant_type: "refresh_token"; + refresh_token: string; + client_id: string; + client_secret?: string; + scope?: string; +} + +// Token Request Parameters - Client Credentials Grant +export interface ClientCredentialsRequestParams { + grant_type: "client_credentials"; + client_id: string; + client_secret: string; + scope?: string; +} + +// Union type for all token request types +export type TokenRequestParamsUnion = + | TokenRequestParams + | RefreshTokenRequestParams + | ClientCredentialsRequestParams; + +// Token Revocation Request (RFC 7009) +export interface TokenRevocationRequest { + token: string; + token_type_hint?: "access_token" | "refresh_token"; + client_id: string; + client_secret?: string; +} + +// Error Response (RFC 6749 Section 5.2) +export interface OAuthErrorResponse { + error: + | "invalid_request" + | "invalid_client" + | "invalid_grant" + | "unauthorized_client" + | "unsupported_grant_type" + | "invalid_scope" + | "server_error" + | "temporarily_unavailable"; + error_description?: string; + error_uri?: string; +} diff --git a/src/oauth/utils.ts b/src/oauth/utils.ts new file mode 100644 index 00000000..733041df --- /dev/null +++ b/src/oauth/utils.ts @@ -0,0 +1,79 @@ +import { createHash, randomBytes } from "node:crypto"; + +import type { OAuthTokenData } from "../core/secretsManager"; + +import type { TokenResponse } from "./types"; + +/** + * OAuth callback path for handling authorization responses (RFC 6749). + */ +export const CALLBACK_PATH = "/oauth/callback"; + +/** + * Default expiry time for OAuth access tokens when the server doesn't provide one. + */ +const ACCESS_TOKEN_DEFAULT_EXPIRY_MS = 60 * 60 * 1000; + +export interface PKCEChallenge { + verifier: string; + challenge: string; +} + +/** + * Generates a PKCE challenge pair (RFC 7636). + * Creates a code verifier and its SHA256 challenge for secure OAuth flows. + */ +export function generatePKCE(): PKCEChallenge { + const verifier = randomBytes(32).toString("base64url"); + const challenge = createHash("sha256").update(verifier).digest("base64url"); + return { verifier, challenge }; +} + +/** + * Generates a cryptographically secure state parameter to prevent CSRF attacks (RFC 6749). + */ +export function generateState(): string { + return randomBytes(16).toString("base64url"); +} + +/** + * Converts an object with string properties to URLSearchParams, + * filtering out undefined values for use with OAuth requests. + */ +export function toUrlSearchParams(obj: object): URLSearchParams { + const params = Object.fromEntries( + Object.entries(obj).filter( + ([, value]) => value !== undefined && typeof value === "string", + ), + ) as Record; + + return new URLSearchParams(params); +} + +/** + * Build OAuthTokenData from a token response. + * Used by LoginCoordinator (initial login) and OAuthSessionManager (refresh). + */ +export function buildOAuthTokenData( + tokenResponse: TokenResponse, +): OAuthTokenData { + if (tokenResponse.token_type !== "Bearer") { + throw new Error( + `Unsupported token type: ${tokenResponse.token_type}. Only Bearer tokens are supported.`, + ); + } + + const expiresIn = tokenResponse.expires_in; + const hasValidExpiry = + expiresIn && expiresIn > 0 && Number.isFinite(expiresIn); + const expiryTimestamp = hasValidExpiry + ? Date.now() + expiresIn * 1000 + : Date.now() + ACCESS_TOKEN_DEFAULT_EXPIRY_MS; + + return { + token_type: tokenResponse.token_type, + refresh_token: tokenResponse.refresh_token, + scope: tokenResponse.scope, + expiry_timestamp: expiryTimestamp, + }; +} diff --git a/src/promptUtils.ts b/src/promptUtils.ts index 3fb31475..9e3d8895 100644 --- a/src/promptUtils.ts +++ b/src/promptUtils.ts @@ -1,7 +1,11 @@ import { type WorkspaceAgent } from "coder/site/src/api/typesGenerated"; import * as vscode from "vscode"; +import { type CoderApi } from "./api/coderApi"; import { type MementoManager } from "./core/mementoManager"; +import { OAuthMetadataClient } from "./oauth/metadataClient"; + +type AuthMethod = "oauth" | "legacy"; /** * Find the requested agent if specified, otherwise return the agent if there @@ -130,3 +134,54 @@ export async function maybeAskUrl( } return url; } + +export async function maybeAskAuthMethod( + client: CoderApi, +): Promise { + // Check if server supports OAuth with progress indication + const supportsOAuth = await vscode.window.withProgress( + { + location: vscode.ProgressLocation.Notification, + title: "Checking authentication methods", + cancellable: false, + }, + async () => { + return await OAuthMetadataClient.checkOAuthSupport( + client.getAxiosInstance(), + ); + }, + ); + + if (supportsOAuth) { + return await askAuthMethod(); + } else { + return "legacy"; + } +} + +/** + * Ask user to choose between OAuth and legacy API token authentication. + */ +async function askAuthMethod(): Promise { + const choice = await vscode.window.showQuickPick( + [ + { + label: "OAuth (Recommended)", + description: "Secure authentication with automatic token refresh", + value: "oauth" as const, + }, + { + label: "Session Token (Legacy)", + description: "Generate and paste a session token manually", + value: "legacy" as const, + }, + ], + { + title: "Select authentication method", + placeHolder: "How would you like to authenticate?", + ignoreFocusOut: true, + }, + ); + + return choice?.value; +} diff --git a/src/remote/remote.ts b/src/remote/remote.ts index 8dee8f1c..2f72cd1b 100644 --- a/src/remote/remote.ts +++ b/src/remote/remote.ts @@ -34,6 +34,8 @@ import { getHeaderCommand } from "../headers"; import { Inbox } from "../inbox"; import { type Logger } from "../logging/logger"; import { type LoginCoordinator } from "../login/loginCoordinator"; +import { OAuthInterceptor } from "../oauth/axiosInterceptor"; +import { OAuthSessionManager } from "../oauth/sessionManager"; import { AuthorityPrefix, escapeCommandArg, @@ -69,7 +71,7 @@ export class Remote { private readonly loginCoordinator: LoginCoordinator; public constructor( - serviceContainer: ServiceContainer, + private readonly serviceContainer: ServiceContainer, private readonly commands: Commands, private readonly extensionContext: vscode.ExtensionContext, ) { @@ -115,6 +117,13 @@ export class Remote { const disposables: vscode.Disposable[] = []; try { + // Create OAuth session manager for this remote deployment + const remoteOAuthManager = OAuthSessionManager.create( + { url: baseUrlRaw, safeHostname: parts.safeHostname }, + this.serviceContainer, + ); + disposables.push(remoteOAuthManager); + const ensureLoggedInAndRetry = async ( message: string, url: string | undefined, @@ -159,6 +168,17 @@ export class Remote { // client to remain unaffected by whatever the plugin is doing. const workspaceClient = CoderApi.create(baseUrlRaw, token, this.logger); disposables.push(workspaceClient); + + // Create OAuth interceptor - auto attaches/detaches based on token state + const oauthInterceptor = await OAuthInterceptor.create( + workspaceClient, + this.logger, + remoteOAuthManager, + this.secretsManager, + parts.safeHostname, + ); + disposables.push(oauthInterceptor); + // Store for use in commands. this.commands.remoteWorkspaceClient = workspaceClient; diff --git a/src/uri/uriHandler.ts b/src/uri/uriHandler.ts index 1e6eeff9..b54531a5 100644 --- a/src/uri/uriHandler.ts +++ b/src/uri/uriHandler.ts @@ -19,6 +19,7 @@ type UriRouteHandler = (ctx: UriRouteContext) => Promise; const routes: Record = { "/open": handleOpen, "/openDevContainer": handleOpenDevContainer, + CALLBACK_PATH: handleOAuthCallback, }; /** @@ -177,3 +178,25 @@ async function setupDeployment( user: result.user, }); } + +async function handleOAuthCallback(ctx: UriRouteContext): Promise { + const { params, serviceContainer } = ctx; + const logger = serviceContainer.getLogger(); + const secretsManager = serviceContainer.getSecretsManager(); + + const code = params.get("code"); + const state = params.get("state"); + const error = params.get("error"); + + if (!state) { + logger.warn("Received OAuth callback with no state parameter"); + return; + } + + try { + await secretsManager.setOAuthCallback({ state, code, error }); + logger.debug("OAuth callback processed successfully"); + } catch (err) { + logger.error("Failed to process OAuth callback:", err); + } +} diff --git a/test/mocks/testHelpers.ts b/test/mocks/testHelpers.ts index 21978b13..31c643e9 100644 --- a/test/mocks/testHelpers.ts +++ b/test/mocks/testHelpers.ts @@ -1,3 +1,4 @@ +import axios, { AxiosError, AxiosHeaders } from "axios"; import { vi } from "vitest"; import * as vscode from "vscode"; @@ -528,6 +529,32 @@ export class MockCoderApi } } +/** + * Mock OAuthSessionManager for testing. + * Provides no-op implementations of all public methods. + */ +export class MockOAuthSessionManager { + readonly setDeployment = vi.fn().mockResolvedValue(undefined); + readonly clearDeployment = vi.fn(); + readonly login = vi.fn().mockResolvedValue({ access_token: "test-token" }); + readonly handleCallback = vi.fn().mockResolvedValue(undefined); + readonly refreshToken = vi + .fn() + .mockResolvedValue({ access_token: "test-token" }); + readonly refreshIfAlmostExpired = vi.fn().mockResolvedValue(undefined); + readonly revokeRefreshToken = vi.fn().mockResolvedValue(undefined); + readonly isLoggedInWithOAuth = vi.fn().mockReturnValue(false); + readonly clearOAuthState = vi.fn().mockResolvedValue(undefined); + readonly showReAuthenticationModal = vi.fn().mockResolvedValue(undefined); + readonly dispose = vi.fn(); +} + +export class MockOAuthInterceptor { + readonly setDeployment = vi.fn().mockResolvedValue(undefined); + readonly clearDeployment = vi.fn(); + readonly dispose = vi.fn(); +} + /** * Create a mock User for testing. */ @@ -549,3 +576,211 @@ export function createMockUser(overrides: Partial = {}): User { ...overrides, }; } + +/** + * Creates an AxiosError for testing. + */ +export function createAxiosError( + status: number, + message: string, + config: Record = {}, +): AxiosError { + const error = new AxiosError( + message, + "ERR_BAD_REQUEST", + undefined, + undefined, + { + status, + statusText: message, + headers: {}, + config: { headers: new AxiosHeaders() }, + data: {}, + }, + ); + error.config = { headers: new AxiosHeaders(), ...config }; + return error; +} + +type MockAdapterFn = ReturnType; + +const AXIOS_MOCK_SETUP_EXAMPLE = ` +vi.mock("axios", async () => { + const actual = await vi.importActual("axios"); + const mockAdapter = vi.fn(); + return { + ...actual, + default: { + ...actual.default, + create: vi.fn((config) => + actual.default.create({ ...config, adapter: mockAdapter }), + ), + __mockAdapter: mockAdapter, + }, + }; +});`; + +/** + * Gets the mock axios adapter from the mocked axios module. + * The axios module must be mocked with __mockAdapter exposed. + * + * @throws Error if axios mock is not set up correctly, with instructions on how to fix it + */ +export function getAxiosMockAdapter(): MockAdapterFn { + const axiosWithMock = axios as typeof axios & { + __mockAdapter?: MockAdapterFn; + }; + const mockAdapter = axiosWithMock.__mockAdapter; + + if (!mockAdapter) { + throw new Error( + "Axios mock adapter not found. Make sure to mock axios with __mockAdapter:\n" + + AXIOS_MOCK_SETUP_EXAMPLE, + ); + } + + return mockAdapter; +} + +/** + * Sets up mock routes for the axios adapter. + * + * Route values can be: + * - Any data: Returns 200 OK with that data + * - Error instance: Rejects with that error + * + * If no route matches, rejects with a 404 AxiosError. + * + * @example + * ```ts + * setupAxiosMockRoutes(mockAdapter, { + * "/.well-known/oauth": metadata, // Returns 200 with metadata + * "/oauth2/register": new Error("Registration failed"), // Throws error + * "/api/v2/users/me": user, // Returns 200 with user + * }); + * ``` + */ +export function setupAxiosMockRoutes( + mockAdapter: MockAdapterFn, + routes: Record, +): void { + mockAdapter.mockImplementation(async (config: { url?: string }) => { + for (const [pattern, value] of Object.entries(routes)) { + if (config.url?.includes(pattern)) { + if (value instanceof Error) { + throw value; + } + const data = typeof value === "function" ? await value() : value; + return { + data, + status: 200, + statusText: "OK", + headers: {}, + config, + }; + } + } + const error = new AxiosError( + `Request failed with status code 404`, + "ERR_BAD_REQUEST", + undefined, + undefined, + { + status: 404, + statusText: "Not Found", + headers: {}, + config: { headers: new AxiosHeaders() }, + data: { + message: "Not found", + detail: `No route matched: ${config.url}`, + }, + }, + ); + throw error; + }); +} + +/** + * A mock vscode.Progress implementation that tracks all reported progress. + * Use this when testing code that accepts a Progress parameter directly. + */ +export class MockProgress + implements vscode.Progress +{ + private readonly reports: T[] = []; + readonly report = vi.fn((value: T) => { + this.reports.push(value); + }); + + /** + * Get all progress reports that have been made. + */ + getReports(): readonly T[] { + return this.reports; + } + + /** + * Clear all recorded reports. + */ + clear(): void { + this.reports.length = 0; + this.report.mockClear(); + } +} + +/** + * A mock vscode.CancellationToken that can be programmatically cancelled. + * Use this when testing code that accepts a CancellationToken parameter directly. + */ +export class MockCancellationToken implements vscode.CancellationToken { + private _isCancellationRequested: boolean; + private readonly listeners: Array<(e: unknown) => void> = []; + + constructor(initialCancelled = false) { + this._isCancellationRequested = initialCancelled; + } + + get isCancellationRequested(): boolean { + return this._isCancellationRequested; + } + + onCancellationRequested: vscode.Event = ( + listener: (e: unknown) => void, + ) => { + this.listeners.push(listener); + // If already cancelled, fire immediately (async to match VS Code behavior) + if (this._isCancellationRequested) { + setTimeout(() => listener(undefined), 0); + } + return { + dispose: () => { + const index = this.listeners.indexOf(listener); + if (index > -1) { + this.listeners.splice(index, 1); + } + }, + }; + }; + + /** + * Trigger cancellation. This will: + * - Set isCancellationRequested to true + * - Fire all registered cancellation listeners + */ + cancel(): void { + if (this._isCancellationRequested) { + return; // Already cancelled + } + this._isCancellationRequested = true; + for (const listener of this.listeners) { + listener(undefined); + } + } + + /** + * Reset to uncancelled state. Useful for reusing the token across tests. + */ + reset(): void { + this._isCancellationRequested = false; + } +} diff --git a/test/mocks/vscode.runtime.ts b/test/mocks/vscode.runtime.ts index cc557d09..8d5f35d8 100644 --- a/test/mocks/vscode.runtime.ts +++ b/test/mocks/vscode.runtime.ts @@ -132,6 +132,7 @@ export const env = { sessionId: "test-session-id", remoteName: undefined as string | undefined, shell: "/bin/bash", + uriScheme: "vscode", openExternal: vi.fn(), }; diff --git a/test/unit/deployment/deploymentManager.test.ts b/test/unit/deployment/deploymentManager.test.ts index 4f0ca52d..33c8cb95 100644 --- a/test/unit/deployment/deploymentManager.test.ts +++ b/test/unit/deployment/deploymentManager.test.ts @@ -11,10 +11,14 @@ import { InMemoryMemento, InMemorySecretStorage, MockCoderApi, + MockOAuthInterceptor, + MockOAuthSessionManager, } from "../../mocks/testHelpers"; import type { ServiceContainer } from "@/core/container"; import type { ContextManager } from "@/core/contextManager"; +import type { OAuthInterceptor } from "@/oauth/axiosInterceptor"; +import type { OAuthSessionManager } from "@/oauth/sessionManager"; import type { WorkspaceProvider } from "@/workspace/workspacesProvider"; // Mock CoderApi.create to return our mock client for validation @@ -64,6 +68,8 @@ function createTestContext() { // For setDeploymentIfValid, we use a separate mock for validation const validationMockClient = new MockCoderApi(); const mockWorkspaceProvider = new MockWorkspaceProvider(); + const mockOAuthSessionManager = new MockOAuthSessionManager(); + const mockOAuthInterceptor = new MockOAuthInterceptor(); const secretStorage = new InMemorySecretStorage(); const memento = new InMemoryMemento(); const logger = createMockLogger(); @@ -86,6 +92,8 @@ function createTestContext() { const manager = DeploymentManager.create( container as unknown as ServiceContainer, mockClient as unknown as CoderApi, + mockOAuthSessionManager as unknown as OAuthSessionManager, + mockOAuthInterceptor as unknown as OAuthInterceptor, [mockWorkspaceProvider as unknown as WorkspaceProvider], ); diff --git a/test/unit/login/loginCoordinator.test.ts b/test/unit/login/loginCoordinator.test.ts index 6044dc90..0c1d4a30 100644 --- a/test/unit/login/loginCoordinator.test.ts +++ b/test/unit/login/loginCoordinator.test.ts @@ -6,8 +6,10 @@ import { MementoManager } from "@/core/mementoManager"; import { SecretsManager } from "@/core/secretsManager"; import { getHeaders } from "@/headers"; import { LoginCoordinator } from "@/login/loginCoordinator"; +import { maybeAskAuthMethod } from "@/promptUtils"; import { + createAxiosError, createMockLogger, createMockUser, InMemoryMemento, @@ -58,7 +60,29 @@ vi.mock("@/api/streamingFetchAdapter", () => ({ createStreamingFetchAdapter: vi.fn(() => fetch), })); -vi.mock("@/promptUtils"); +vi.mock("@/promptUtils", () => ({ + maybeAskAuthMethod: vi.fn().mockResolvedValue("legacy"), + maybeAskUrl: vi.fn(), +})); + +// Mock CoderApi to control getAuthenticatedUser behavior +const mockGetAuthenticatedUser = vi.hoisted(() => vi.fn()); +vi.mock("@/api/coderApi", async (importOriginal) => { + const original = await importOriginal(); + return { + ...original, + CoderApi: { + ...original.CoderApi, + create: vi.fn(() => ({ + getAxiosInstance: () => ({ + defaults: { baseURL: "https://coder.example.com" }, + }), + setSessionToken: vi.fn(), + getAuthenticatedUser: mockGetAuthenticatedUser, + })), + }, + }; +}); // Type for axios with our mock adapter type MockedAxios = typeof axios & { __mockAdapter: ReturnType }; @@ -75,8 +99,8 @@ function createTestContext() { const mockAdapter = (axios as MockedAxios).__mockAdapter; mockAdapter.mockImplementation(mockAxiosAdapterImpl); vi.mocked(getHeaders).mockResolvedValue({}); + vi.mocked(maybeAskAuthMethod).mockResolvedValue("legacy"); - // MockConfigurationProvider sets sensible defaults (httpClientLogLevel, tlsCertFile, tlsKeyFile) const mockConfig = new MockConfigurationProvider(); // MockUserInteraction sets up vscode.window dialogs and input boxes const userInteraction = new MockUserInteraction(); @@ -92,9 +116,12 @@ function createTestContext() { mementoManager, vscode, logger, + "coder.coder-remote", ); const mockSuccessfulAuth = (user = createMockUser()) => { + // Configure both the axios adapter (for tests that bypass CoderApi mock) + // and mockGetAuthenticatedUser (for tests that use the CoderApi mock) mockAdapter.mockResolvedValue({ data: user, status: 200, @@ -102,18 +129,18 @@ function createTestContext() { headers: {}, config: {}, }); + mockGetAuthenticatedUser.mockResolvedValue(user); return user; }; const mockAuthFailure = (message = "Unauthorized") => { - mockAdapter.mockRejectedValue({ - response: { status: 401, data: { message } }, - message, - }); + mockAdapter.mockRejectedValue(createAxiosError(401, message)); + mockGetAuthenticatedUser.mockRejectedValue(createAxiosError(401, message)); }; return { mockAdapter, + mockGetAuthenticatedUser, mockConfig, userInteraction, secretsManager, @@ -149,21 +176,16 @@ describe("LoginCoordinator", () => { }); it("prompts for token when no stored auth exists", async () => { - const { mockAdapter, userInteraction, secretsManager, coordinator } = - createTestContext(); - const user = createMockUser(); - - // No stored token, so goes directly to input box flow - // Mock succeeds when validateInput calls getAuthenticatedUser - mockAdapter.mockResolvedValueOnce({ - data: user, - status: 200, - statusText: "OK", - headers: {}, - config: {}, - }); + const { + userInteraction, + secretsManager, + coordinator, + mockSuccessfulAuth, + } = createTestContext(); + const user = mockSuccessfulAuth(); // User enters a new token in the input box + vi.mocked(maybeAskAuthMethod).mockResolvedValue("legacy"); userInteraction.setInputBoxValue("new-token"); const result = await coordinator.ensureLoggedIn({ @@ -195,19 +217,14 @@ describe("LoginCoordinator", () => { describe("same-window guard", () => { it("prevents duplicate login calls for same hostname", async () => { - const { mockAdapter, userInteraction, coordinator } = createTestContext(); - const user = createMockUser(); + const { userInteraction, coordinator, mockSuccessfulAuth } = + createTestContext(); + mockSuccessfulAuth(); // User enters a token in the input box + vi.mocked(maybeAskAuthMethod).mockResolvedValue("legacy"); userInteraction.setInputBoxValue("new-token"); - let resolveAuth: (value: unknown) => void; - mockAdapter.mockReturnValue( - new Promise((resolve) => { - resolveAuth = resolve; - }), - ); - // Start first login const login1 = coordinator.ensureLoggedIn({ url: TEST_URL, @@ -220,15 +237,6 @@ describe("LoginCoordinator", () => { safeHostname: TEST_HOSTNAME, }); - // Resolve the auth (this validates the token from input box) - resolveAuth!({ - data: user, - status: 200, - statusText: "OK", - headers: {}, - config: {}, - }); - // Both should complete with the same result const [result1, result2] = await Promise.all([login1, login2]); expect(result1.success).toBe(true); @@ -297,6 +305,7 @@ describe("LoginCoordinator", () => { mementoManager, vscode, logger, + "coder.coder-remote", ); mockAuthFailure("Certificate error"); @@ -354,21 +363,14 @@ describe("LoginCoordinator", () => { }); it("falls back to stored token when provided token is invalid", async () => { - const { mockAdapter, secretsManager, coordinator } = createTestContext(); + const { mockGetAuthenticatedUser, secretsManager, coordinator } = + createTestContext(); const user = createMockUser(); - mockAdapter - .mockRejectedValueOnce({ - isAxiosError: true, - response: { status: 401 }, // Fail the provided token with 401 - message: "Unauthorized", - }) - .mockResolvedValueOnce({ - data: user, - status: 200, // Succeed the stored token - headers: {}, - config: {}, - }); + // First call (provided token) fails with 401, second call (stored token) succeeds + mockGetAuthenticatedUser + .mockRejectedValueOnce(createAxiosError(401, "Unauthorized")) + .mockResolvedValueOnce(user); await secretsManager.setSessionAuth(TEST_HOSTNAME, { url: TEST_URL, @@ -385,27 +387,20 @@ describe("LoginCoordinator", () => { }); it("prompts user when both provided and stored tokens are invalid", async () => { - const { mockAdapter, userInteraction, secretsManager, coordinator } = - createTestContext(); + const { + mockGetAuthenticatedUser, + userInteraction, + secretsManager, + coordinator, + } = createTestContext(); const user = createMockUser(); - mockAdapter - .mockRejectedValueOnce({ - isAxiosError: true, - response: { status: 401 }, // provided token - message: "Unauthorized", - }) - .mockRejectedValueOnce({ - isAxiosError: true, - response: { status: 401 }, // stored token - message: "Unauthorized", - }) - .mockResolvedValueOnce({ - data: user, - status: 200, // user-entered token - headers: {}, - config: {}, - }); + // First call (provided token) fails, second call (stored token) fails, + // third call (user-entered token) succeeds + mockGetAuthenticatedUser + .mockRejectedValueOnce(createAxiosError(401, "Unauthorized")) + .mockRejectedValueOnce(createAxiosError(401, "Unauthorized")) + .mockResolvedValueOnce(user); await secretsManager.setSessionAuth(TEST_HOSTNAME, { url: TEST_URL, @@ -429,22 +424,19 @@ describe("LoginCoordinator", () => { }); it("skips stored token check when same as provided token", async () => { - const { mockAdapter, userInteraction, secretsManager, coordinator } = - createTestContext(); + const { + mockGetAuthenticatedUser, + userInteraction, + secretsManager, + coordinator, + } = createTestContext(); const user = createMockUser(); - mockAdapter - .mockRejectedValueOnce({ - isAxiosError: true, - response: { status: 401 }, // provided token - message: "Unauthorized", - }) - .mockResolvedValueOnce({ - data: user, - status: 200, // user-entered token - headers: {}, - config: {}, - }); + // First call (provided token = stored token) fails with 401, + // second call (user-entered token) succeeds + mockGetAuthenticatedUser + .mockRejectedValueOnce(createAxiosError(401, "Unauthorized")) + .mockResolvedValueOnce(user); // Store the SAME token as will be provided await secretsManager.setSessionAuth(TEST_HOSTNAME, { @@ -466,7 +458,7 @@ describe("LoginCoordinator", () => { token: "user-entered-token", }); // Provided/stored token check only called once + user prompt - expect(mockAdapter).toHaveBeenCalledTimes(2); + expect(mockGetAuthenticatedUser).toHaveBeenCalledTimes(2); }); }); }); diff --git a/test/unit/oauth/authorizer.test.ts b/test/unit/oauth/authorizer.test.ts new file mode 100644 index 00000000..95a0a822 --- /dev/null +++ b/test/unit/oauth/authorizer.test.ts @@ -0,0 +1,381 @@ +import { describe, expect, it, vi } from "vitest"; +import * as vscode from "vscode"; + +import { getHeaders } from "@/headers"; +import { OAuthAuthorizer } from "@/oauth/authorizer"; + +import { + MockCancellationToken, + MockProgress, + setupAxiosMockRoutes, +} from "../../mocks/testHelpers"; + +import { + createMockTokenResponse, + createBaseTestContext, + createMockClientRegistration, + createMockOAuthMetadata, + createTestDeployment, + TEST_HOSTNAME, + TEST_URL, +} from "./testUtils"; + +vi.mock("axios", async () => { + const actual = await vi.importActual("axios"); + const mockAdapter = vi.fn(); + return { + ...actual, + default: { + ...actual.default, + create: vi.fn((config) => + actual.default.create({ ...config, adapter: mockAdapter }), + ), + __mockAdapter: mockAdapter, + }, + }; +}); + +vi.mock("@/headers", () => ({ + getHeaders: vi.fn().mockResolvedValue({}), + getHeaderCommand: vi.fn(), +})); + +vi.mock("@/api/utils", async () => { + const actual = + await vi.importActual("@/api/utils"); + return { ...actual, createHttpAgent: vi.fn() }; +}); + +vi.mock("@/api/streamingFetchAdapter", () => ({ + createStreamingFetchAdapter: vi.fn(() => fetch), +})); + +const EXTENSION_ID = "coder.coder-remote"; + +function createTestContext() { + vi.resetAllMocks(); + vi.mocked(getHeaders).mockResolvedValue({}); + + const base = createBaseTestContext(); + const authorizer = new OAuthAuthorizer( + base.secretsManager, + base.logger, + EXTENSION_ID, + ); + + /** Starts login flow and waits for browser to open. Returns promise and state for completing flow. */ + const startLogin = async (options?: { + progress?: MockProgress; + token?: MockCancellationToken; + }) => { + const progress = options?.progress ?? new MockProgress(); + const token = options?.token ?? new MockCancellationToken(); + const loginPromise = authorizer.login( + createTestDeployment(), + progress, + token, + ); + const { state, authUrl } = await waitForBrowserToOpen(); + return { loginPromise, state, authUrl, progress, token }; + }; + + /** Completes login by sending successful OAuth callback */ + const completeLogin = async (state: string) => { + await base.secretsManager.setOAuthCallback({ + state, + code: "code", + error: null, + }); + }; + + return { ...base, authorizer, startLogin, completeLogin }; +} + +/** + * Wait for openExternal to be called and return the auth URL and state. + */ +async function waitForBrowserToOpen(): Promise<{ + authUrl: URL; + state: string; +}> { + await vi.waitFor(() => { + expect(vscode.env.openExternal).toHaveBeenCalled(); + }); + const openExternalCall = vi.mocked(vscode.env.openExternal).mock.calls[0][0]; + const authUrl = new URL(openExternalCall.toString()); + return { authUrl, state: authUrl.searchParams.get("state")! }; +} + +describe("OAuthAuthorizer", () => { + describe("login flow", () => { + it("completes full OAuth login flow successfully", async () => { + const { mockAdapter, secretsManager, authorizer } = createTestContext(); + + setupAxiosMockRoutes(mockAdapter, { + "/.well-known/oauth-authorization-server": + createMockOAuthMetadata(TEST_URL), + "/oauth2/register": createMockClientRegistration({ + client_id: "registered-client-id", + }), + "/oauth2/token": createMockTokenResponse({ + access_token: "oauth-access-token", + }), + "/api/v2/users/me": { username: "oauth-user" }, + }); + + const deployment = createTestDeployment(); + const progress = new MockProgress(); + const cancellationToken = new MockCancellationToken(); + + const loginPromise = authorizer.login( + deployment, + progress, + cancellationToken, + ); + + const { state } = await waitForBrowserToOpen(); + + // Set the callback with the correct state (simulate user clicking authorize) + await secretsManager.setOAuthCallback({ + state, + code: "auth-code-123", + error: null, + }); + + const result = await loginPromise; + + expect(result.tokenResponse.access_token).toBe("oauth-access-token"); + expect(result.user.username).toBe("oauth-user"); + + // Verify client registration was stored + const storedRegistration = + await secretsManager.getOAuthClientRegistration(TEST_HOSTNAME); + expect(storedRegistration?.client_id).toBe("registered-client-id"); + }); + + it("uses existing client registration when redirect URI matches", async () => { + const { mockAdapter, secretsManager, authorizer } = createTestContext(); + + // Pre-store a client registration with matching redirect URI + await secretsManager.setOAuthClientRegistration( + TEST_HOSTNAME, + createMockClientRegistration({ + client_id: "existing-client-id", + redirect_uris: [`vscode://${EXTENSION_ID}/oauth/callback`], + }), + ); + + // Registration endpoint should throw if called (existing registration should be reused) + setupAxiosMockRoutes(mockAdapter, { + "/oauth2/register": new Error("Should not re-register"), + "/.well-known/oauth-authorization-server": + createMockOAuthMetadata(TEST_URL), + "/oauth2/token": createMockTokenResponse(), + "/api/v2/users/me": { username: "test-user" }, + }); + + const loginPromise = authorizer.login( + createTestDeployment(), + new MockProgress(), + new MockCancellationToken(), + ); + + const { authUrl, state } = await waitForBrowserToOpen(); + expect(authUrl.searchParams.get("client_id")).toBe("existing-client-id"); + + await secretsManager.setOAuthCallback({ + state, + code: "code", + error: null, + }); + await loginPromise; + }); + + it("re-registers client when redirect URI has changed", async () => { + const { mockAdapter, secretsManager, authorizer } = createTestContext(); + + // Pre-store a client registration with different redirect URI + await secretsManager.setOAuthClientRegistration( + TEST_HOSTNAME, + createMockClientRegistration({ + client_id: "old-client-id", + redirect_uris: ["vscode://different-extension/oauth/callback"], + }), + ); + + // Server will return new registration + setupAxiosMockRoutes(mockAdapter, { + "/.well-known/oauth-authorization-server": + createMockOAuthMetadata(TEST_URL), + "/oauth2/register": createMockClientRegistration({ + client_id: "new-client-id", + }), + "/oauth2/token": createMockTokenResponse(), + "/api/v2/users/me": { username: "test-user" }, + }); + + const loginPromise = authorizer.login( + createTestDeployment(), + new MockProgress(), + new MockCancellationToken(), + ); + + const { authUrl, state } = await waitForBrowserToOpen(); + expect(authUrl.searchParams.get("client_id")).toBe("new-client-id"); + + await secretsManager.setOAuthCallback({ + state, + code: "code", + error: null, + }); + await loginPromise; + + const stored = + await secretsManager.getOAuthClientRegistration(TEST_HOSTNAME); + expect(stored?.client_id).toBe("new-client-id"); + }); + + it("reports progress during login flow", async () => { + const { setupOAuthRoutes, startLogin, completeLogin } = + createTestContext(); + setupOAuthRoutes(); + + const progress = new MockProgress(); + const { loginPromise, state } = await startLogin({ progress }); + await completeLogin(state); + await loginPromise; + + const messages = progress.getReports().map((r) => r.message); + expect(messages).toEqual([ + "fetching metadata...", + "registering client...", + "waiting for authorization...", + "exchanging token...", + "fetching user...", + ]); + }); + }); + + describe("callback handling", () => { + it("ignores callback with wrong state", async () => { + const { secretsManager, setupOAuthRoutes, startLogin, completeLogin } = + createTestContext(); + setupOAuthRoutes(); + + const { loginPromise, state } = await startLogin(); + + // Send callback with wrong state - should be ignored + await secretsManager.setOAuthCallback({ + state: "wrong-state", + code: "code", + error: null, + }); + + // Login should still be waiting + const raceResult = await Promise.race([ + loginPromise.then(() => "completed"), + new Promise((resolve) => setTimeout(() => resolve("timeout"), 100)), + ]); + expect(raceResult).toBe("timeout"); + + // Now send correct callback + await completeLogin(state); + const result = await loginPromise; + expect(result.tokenResponse.access_token).toBeDefined(); + }); + + it("rejects on OAuth error callback", async () => { + const { secretsManager, setupOAuthRoutes, startLogin } = + createTestContext(); + setupOAuthRoutes(); + + const { loginPromise, state } = await startLogin(); + await secretsManager.setOAuthCallback({ + state, + code: null, + error: "access_denied", + }); + + await expect(loginPromise).rejects.toThrow("OAuth error: access_denied"); + }); + + it("rejects when no code is received", async () => { + const { secretsManager, setupOAuthRoutes, startLogin } = + createTestContext(); + setupOAuthRoutes(); + + const { loginPromise, state } = await startLogin(); + await secretsManager.setOAuthCallback({ state, code: null, error: null }); + + await expect(loginPromise).rejects.toThrow( + "No authorization code received", + ); + }); + }); + + describe("cancellation", () => { + it("rejects when cancelled before callback", async () => { + const { setupOAuthRoutes, startLogin } = createTestContext(); + setupOAuthRoutes(); + + const { loginPromise, token } = await startLogin(); + token.cancel(); + + await expect(loginPromise).rejects.toThrow( + "OAuth flow cancelled by user", + ); + }); + + it("rejects immediately when already cancelled", async () => { + const { authorizer, setupOAuthRoutes } = createTestContext(); + setupOAuthRoutes(); + + // Can't use startLogin() here because login rejects before browser opens + await expect( + authorizer.login( + createTestDeployment(), + new MockProgress(), + new MockCancellationToken(true), + ), + ).rejects.toThrow("OAuth login cancelled by user"); + }); + }); + + describe("dispose", () => { + it("rejects pending auth when disposed", async () => { + const { authorizer, setupOAuthRoutes, startLogin } = createTestContext(); + setupOAuthRoutes(); + + const { loginPromise } = await startLogin(); + authorizer.dispose(); + + await expect(loginPromise).rejects.toThrow("OAuthAuthorizer disposed"); + }); + + it("does nothing when disposed without pending auth", () => { + const { authorizer } = createTestContext(); + expect(() => authorizer.dispose()).not.toThrow(); + }); + }); + + describe("error handling", () => { + it("throws when server does not support dynamic client registration", async () => { + const { mockAdapter, authorizer } = createTestContext(); + + setupAxiosMockRoutes(mockAdapter, { + "/.well-known/oauth-authorization-server": createMockOAuthMetadata( + TEST_URL, + { registration_endpoint: undefined }, + ), + }); + + await expect( + authorizer.login( + createTestDeployment(), + new MockProgress(), + new MockCancellationToken(), + ), + ).rejects.toThrow("Server does not support dynamic client registration"); + }); + }); +}); diff --git a/test/unit/oauth/axiosInterceptor.test.ts b/test/unit/oauth/axiosInterceptor.test.ts new file mode 100644 index 00000000..ccf50afd --- /dev/null +++ b/test/unit/oauth/axiosInterceptor.test.ts @@ -0,0 +1,277 @@ +import axios, { type AxiosInstance } from "axios"; +import { describe, expect, it, vi } from "vitest"; + +import { SecretsManager } from "@/core/secretsManager"; +import { OAuthInterceptor } from "@/oauth/axiosInterceptor"; + +import { + createAxiosError, + createMockLogger, + InMemoryMemento, + InMemorySecretStorage, + MockOAuthSessionManager, +} from "../../mocks/testHelpers"; + +import { createMockTokenResponse, TEST_HOSTNAME, TEST_URL } from "./testUtils"; + +import type { CoderApi } from "@/api/coderApi"; +import type { OAuthSessionManager } from "@/oauth/sessionManager"; + +/** + * Creates a mock axios instance with controllable interceptors. + * Simplified to track count and last handler only. + */ +function createMockAxiosInstance(): AxiosInstance & { + triggerResponseError: (error: unknown) => Promise; + getInterceptorCount: () => number; +} { + const instance = axios.create(); + let interceptorCount = 0; + let lastRejectedHandler: ((error: unknown) => unknown) | null = null; + + vi.spyOn(instance.interceptors.response, "use").mockImplementation( + (_onFulfilled, onRejected) => { + interceptorCount++; + lastRejectedHandler = onRejected ?? ((e) => Promise.reject(e)); + return interceptorCount; + }, + ); + + vi.spyOn(instance.interceptors.response, "eject").mockImplementation(() => { + interceptorCount = Math.max(0, interceptorCount - 1); + if (interceptorCount === 0) { + lastRejectedHandler = null; + } + }); + + return Object.assign(instance, { + triggerResponseError: (error: unknown): Promise => { + if (!lastRejectedHandler) { + return Promise.reject(error); + } + return Promise.resolve(lastRejectedHandler(error)); + }, + getInterceptorCount: () => interceptorCount, + }); +} + +function createMockCoderApi(axiosInstance: AxiosInstance): CoderApi { + let sessionToken: string | undefined; + return { + getAxiosInstance: () => axiosInstance, + setSessionToken: vi.fn((token: string) => { + sessionToken = token; + }), + getSessionToken: () => sessionToken, + } as unknown as CoderApi; +} + +const ONE_HOUR_MS = 60 * 60 * 1000; + +function createTestContext() { + vi.resetAllMocks(); + + const secretStorage = new InMemorySecretStorage(); + const memento = new InMemoryMemento(); + const logger = createMockLogger(); + const secretsManager = new SecretsManager(secretStorage, memento, logger); + + const axiosInstance = createMockAxiosInstance(); + const mockCoderApi = createMockCoderApi(axiosInstance); + const mockOAuthManager = new MockOAuthSessionManager(); + + // Make isLoggedInWithOAuth check actual storage instead of returning a fixed value + vi.spyOn(mockOAuthManager, "isLoggedInWithOAuth").mockImplementation( + async () => { + const auth = await secretsManager.getSessionAuth(TEST_HOSTNAME); + return auth?.oauth !== undefined; + }, + ); + + /** Sets up OAuth tokens and creates interceptor */ + const setupOAuthInterceptor = async () => { + await secretsManager.setSessionAuth(TEST_HOSTNAME, { + url: TEST_URL, + token: "access-token", + oauth: { + token_type: "Bearer", + refresh_token: "refresh-token", + expiry_timestamp: Date.now() + ONE_HOUR_MS, + }, + }); + return OAuthInterceptor.create( + mockCoderApi, + logger, + mockOAuthManager as unknown as OAuthSessionManager, + secretsManager, + TEST_HOSTNAME, + ); + }; + + /** Sets up session token only (no OAuth) */ + const setupSessionToken = async () => { + await secretsManager.setSessionAuth(TEST_HOSTNAME, { + url: TEST_URL, + token: "session-token", + }); + }; + + /** Creates interceptor without any pre-existing auth */ + const createInterceptor = () => + OAuthInterceptor.create( + mockCoderApi, + logger, + mockOAuthManager as unknown as OAuthSessionManager, + secretsManager, + TEST_HOSTNAME, + ); + + return { + secretsManager, + logger, + axiosInstance, + mockCoderApi, + mockOAuthManager: mockOAuthManager as unknown as OAuthSessionManager & + MockOAuthSessionManager, + setupOAuthInterceptor, + setupSessionToken, + createInterceptor, + }; +} + +describe("OAuthInterceptor", () => { + describe("attach/detach based on token state", () => { + it("attaches when OAuth tokens stored", async () => { + const { axiosInstance, setupOAuthInterceptor } = createTestContext(); + + await setupOAuthInterceptor(); + + expect(axiosInstance.getInterceptorCount()).toBe(1); + }); + + it("does not attach when no OAuth tokens", async () => { + const { axiosInstance, setupSessionToken, createInterceptor } = + createTestContext(); + + await setupSessionToken(); + await createInterceptor(); + + expect(axiosInstance.getInterceptorCount()).toBe(0); + }); + + it("detaches when OAuth tokens cleared", async () => { + const { axiosInstance, setupOAuthInterceptor, setupSessionToken } = + createTestContext(); + + await setupOAuthInterceptor(); + expect(axiosInstance.getInterceptorCount()).toBe(1); + + await setupSessionToken(); + await vi.waitFor(() => { + expect(axiosInstance.getInterceptorCount()).toBe(0); + }); + }); + + it("attaches when OAuth tokens added", async () => { + const { + secretsManager, + axiosInstance, + setupSessionToken, + createInterceptor, + } = createTestContext(); + + await setupSessionToken(); + await createInterceptor(); + expect(axiosInstance.getInterceptorCount()).toBe(0); + + // Add OAuth tokens + await secretsManager.setSessionAuth(TEST_HOSTNAME, { + url: TEST_URL, + token: "access-token", + oauth: { + token_type: "Bearer", + refresh_token: "refresh-token", + expiry_timestamp: Date.now() + ONE_HOUR_MS, + }, + }); + + await vi.waitFor(() => { + expect(axiosInstance.getInterceptorCount()).toBe(1); + }); + }); + }); + + describe("401 handling", () => { + it("refreshes token and retries request", async () => { + const { + mockCoderApi, + mockOAuthManager, + axiosInstance, + setupOAuthInterceptor, + } = createTestContext(); + + const newTokens = createMockTokenResponse({ + access_token: "new-access-token", + }); + mockOAuthManager.refreshToken.mockResolvedValue(newTokens); + + const retryResponse = { data: "success", status: 200 }; + vi.spyOn(axiosInstance, "request").mockResolvedValue(retryResponse); + + await setupOAuthInterceptor(); + + const error = createAxiosError(401, "Unauthorized"); + const result = await axiosInstance.triggerResponseError(error); + + expect(mockCoderApi.getSessionToken()).toBe("new-access-token"); + expect(result).toEqual(retryResponse); + }); + + it("does not retry if already retried", async () => { + const { mockOAuthManager, axiosInstance, setupOAuthInterceptor } = + createTestContext(); + + await setupOAuthInterceptor(); + + const error = createAxiosError(401, "Unauthorized", { + _oauthRetryAttempted: true, + }); + + await expect(axiosInstance.triggerResponseError(error)).rejects.toThrow(); + expect(mockOAuthManager.refreshToken).not.toHaveBeenCalled(); + }); + + it("rethrows original error if refresh fails", async () => { + const { mockOAuthManager, axiosInstance, setupOAuthInterceptor } = + createTestContext(); + + mockOAuthManager.refreshToken.mockRejectedValue( + new Error("Refresh failed"), + ); + + await setupOAuthInterceptor(); + + const error = createAxiosError(401, "Unauthorized"); + + await expect(axiosInstance.triggerResponseError(error)).rejects.toThrow( + "Unauthorized", + ); + }); + + it.each<{ name: string; error: Error }>([ + { + name: "non-401 axios error", + error: createAxiosError(500, "Server Error"), + }, + { name: "non-axios error", error: new Error("Network failure") }, + ])("ignores $name", async ({ error }) => { + const { mockOAuthManager, axiosInstance, setupOAuthInterceptor } = + createTestContext(); + + await setupOAuthInterceptor(); + + await expect(axiosInstance.triggerResponseError(error)).rejects.toThrow(); + expect(mockOAuthManager.refreshToken).not.toHaveBeenCalled(); + }); + }); +}); diff --git a/test/unit/oauth/sessionManager.test.ts b/test/unit/oauth/sessionManager.test.ts new file mode 100644 index 00000000..dc780b9f --- /dev/null +++ b/test/unit/oauth/sessionManager.test.ts @@ -0,0 +1,372 @@ +import { describe, expect, it, vi } from "vitest"; + +import { type SecretsManager, type SessionAuth } from "@/core/secretsManager"; +import { InvalidGrantError } from "@/oauth/errors"; +import { OAuthSessionManager } from "@/oauth/sessionManager"; + +import { + type createMockLogger, + setupAxiosMockRoutes, +} from "../../mocks/testHelpers"; + +import { + createBaseTestContext, + createMockClientRegistration, + createMockOAuthMetadata, + createMockTokenResponse, + createTestDeployment, + TEST_HOSTNAME, + TEST_URL, +} from "./testUtils"; + +import type { ServiceContainer } from "@/core/container"; +import type { Deployment } from "@/deployment/types"; +import type { LoginCoordinator } from "@/login/loginCoordinator"; + +vi.mock("axios", async () => { + const actual = await vi.importActual("axios"); + const mockAdapter = vi.fn(); + return { + ...actual, + default: { + ...actual.default, + create: vi.fn((config) => + actual.default.create({ ...config, adapter: mockAdapter }), + ), + __mockAdapter: mockAdapter, + }, + }; +}); + +vi.mock("@/headers", () => ({ + getHeaders: vi.fn().mockResolvedValue({}), + getHeaderCommand: vi.fn(), +})); + +vi.mock("@/api/utils", async () => { + const actual = + await vi.importActual("@/api/utils"); + return { ...actual, createHttpAgent: vi.fn() }; +}); + +const REFRESH_BUFFER_MS = 5 * 60 * 1000; // Tokens refresh 5 minutes before expiry +const ONE_HOUR_MS = 60 * 60 * 1000; + +function createMockLoginCoordinator(): LoginCoordinator { + return { + ensureLoggedIn: vi.fn(), + ensureLoggedInWithDialog: vi.fn(), + } as unknown as LoginCoordinator; +} + +function createMockServiceContainer( + secretsManager: SecretsManager, + logger: ReturnType, + loginCoordinator: LoginCoordinator, +): ServiceContainer { + return { + getSecretsManager: () => secretsManager, + getLogger: () => logger, + getLoginCoordinator: () => loginCoordinator, + } as ServiceContainer; +} + +function createTestContext(deployment: Deployment = createTestDeployment()) { + vi.resetAllMocks(); + + const base = createBaseTestContext(); + const loginCoordinator = createMockLoginCoordinator(); + const container = createMockServiceContainer( + base.secretsManager, + base.logger, + loginCoordinator, + ); + const manager = OAuthSessionManager.create(deployment, container); + + /** Sets up OAuth session auth */ + const setupOAuthSession = async ( + overrides: { + token?: string; + refreshToken?: string; + expiryMs?: number; + scope?: string; + } = {}, + ) => { + await base.secretsManager.setSessionAuth(TEST_HOSTNAME, { + url: TEST_URL, + token: overrides.token ?? "access-token", + oauth: { + token_type: "Bearer", + refresh_token: overrides.refreshToken ?? "refresh-token", + expiry_timestamp: Date.now() + (overrides.expiryMs ?? ONE_HOUR_MS), + scope: overrides.scope ?? "", + }, + }); + }; + + /** Creates a new manager (for tests that need manager created after OAuth setup) */ + const createManager = (d: Deployment = deployment) => + OAuthSessionManager.create(d, container); + + return { + ...base, + loginCoordinator, + manager, + setupOAuthSession, + createManager, + }; +} + +describe("OAuthSessionManager", () => { + describe("isLoggedInWithOAuth", () => { + type IsLoggedInTestCase = { + name: string; + auth: SessionAuth | null; + expected: boolean; + }; + + it.each([ + { + name: "returns true when OAuth tokens exist", + auth: { + url: TEST_URL, + token: "access-token", + oauth: { + token_type: "Bearer", + refresh_token: "refresh-token", + expiry_timestamp: Date.now() + ONE_HOUR_MS, + }, + }, + expected: true, + }, + { + name: "returns false when no tokens exist", + auth: null, + expected: false, + }, + { + name: "returns false when session auth has no OAuth data", + auth: { url: TEST_URL, token: "session-token" }, + expected: false, + }, + ])("$name", async ({ auth, expected }) => { + const { secretsManager, manager } = createTestContext(); + + if (auth) { + await secretsManager.setSessionAuth(TEST_HOSTNAME, auth); + } + + const result = await manager.isLoggedInWithOAuth(); + expect(result).toBe(expected); + }); + }); + + describe("refreshToken", () => { + it("throws when no refresh token available", async () => { + const { manager } = createTestContext(); + + await expect(manager.refreshToken()).rejects.toThrow( + "No refresh token available", + ); + }); + + it("refreshes token successfully", async () => { + const { secretsManager, mockAdapter, manager, setupOAuthSession } = + createTestContext(); + + await setupOAuthSession({ token: "old-token" }); + await secretsManager.setOAuthClientRegistration( + TEST_HOSTNAME, + createMockClientRegistration(), + ); + + setupAxiosMockRoutes(mockAdapter, { + "/.well-known/oauth-authorization-server": + createMockOAuthMetadata(TEST_URL), + "/oauth2/token": createMockTokenResponse({ + access_token: "refreshed-token", + }), + }); + + const result = await manager.refreshToken(); + expect(result.access_token).toBe("refreshed-token"); + }); + }); + + describe("getStoredTokens validation", () => { + it("returns undefined when URL mismatches", async () => { + const { secretsManager, manager } = createTestContext(); + + // Manually set auth with different URL (can't use helper) + await secretsManager.setSessionAuth(TEST_HOSTNAME, { + url: "https://different-coder.example.com", + token: "access-token", + oauth: { + token_type: "Bearer", + refresh_token: "refresh-token", + expiry_timestamp: Date.now() + ONE_HOUR_MS, + scope: "", + }, + }); + + const result = await manager.isLoggedInWithOAuth(); + expect(result).toBe(false); + }); + }); + + describe("setDeployment", () => { + it("switches to new deployment", async () => { + const { manager } = createTestContext(); + + const newDeployment: Deployment = { + url: "https://new-coder.example.com", + safeHostname: "new-coder.example.com", + }; + + await manager.setDeployment(newDeployment); + + const result = await manager.isLoggedInWithOAuth(); + expect(result).toBe(false); + }); + }); + + describe("clearDeployment", () => { + it("clears all deployment state", async () => { + const { manager } = createTestContext(); + + manager.clearDeployment(); + + const result = await manager.isLoggedInWithOAuth(); + expect(result).toBe(false); + }); + }); + + describe("background refresh", () => { + it("schedules refresh before token expiry", async () => { + vi.useFakeTimers(); + + const { secretsManager, mockAdapter, setupOAuthSession, createManager } = + createTestContext(); + + await setupOAuthSession({ token: "original-token" }); + await secretsManager.setOAuthClientRegistration( + TEST_HOSTNAME, + createMockClientRegistration(), + ); + + setupAxiosMockRoutes(mockAdapter, { + "/.well-known/oauth-authorization-server": + createMockOAuthMetadata(TEST_URL), + "/oauth2/token": createMockTokenResponse({ + access_token: "background-refreshed-token", + }), + }); + + // Create manager AFTER OAuth session is set up so it schedules refresh + createManager(); + + // Advance to when refresh should trigger + await vi.advanceTimersByTimeAsync(ONE_HOUR_MS - REFRESH_BUFFER_MS); + + const auth = await secretsManager.getSessionAuth(TEST_HOSTNAME); + expect(auth?.token).toBe("background-refreshed-token"); + }); + }); + + describe("showReAuthenticationModal", () => { + it("clears OAuth state and prompts for re-login", async () => { + const { secretsManager, loginCoordinator, manager, setupOAuthSession } = + createTestContext(); + + await setupOAuthSession(); + await secretsManager.setOAuthClientRegistration( + TEST_HOSTNAME, + createMockClientRegistration(), + ); + + await manager.showReAuthenticationModal( + new InvalidGrantError("Token expired"), + ); + + const auth = await secretsManager.getSessionAuth(TEST_HOSTNAME); + expect(auth?.oauth).toBeUndefined(); + expect(auth?.token).toBe(""); + expect(loginCoordinator.ensureLoggedInWithDialog).toHaveBeenCalled(); + }); + }); + + describe("concurrent refresh", () => { + it("deduplicates concurrent calls", async () => { + const { secretsManager, mockAdapter, manager, setupOAuthSession } = + createTestContext(); + + await setupOAuthSession(); + await secretsManager.setOAuthClientRegistration( + TEST_HOSTNAME, + createMockClientRegistration(), + ); + + let callCount = 0; + setupAxiosMockRoutes(mockAdapter, { + "/.well-known/oauth-authorization-server": + createMockOAuthMetadata(TEST_URL), + "/oauth2/token": () => { + callCount++; + return createMockTokenResponse({ + access_token: `token-${callCount}`, + }); + }, + }); + + const results = await Promise.all([ + manager.refreshToken(), + manager.refreshToken(), + manager.refreshToken(), + ]); + + expect(callCount).toBe(1); + expect(results[0]).toEqual(results[1]); + expect(results[1]).toEqual(results[2]); + }); + }); + + describe("deployment switch during refresh", () => { + it("completes in-flight refresh after switch", async () => { + const { secretsManager, mockAdapter, manager, setupOAuthSession } = + createTestContext(); + + await setupOAuthSession(); + await secretsManager.setOAuthClientRegistration( + TEST_HOSTNAME, + createMockClientRegistration(), + ); + + let resolveToken: (v: unknown) => void; + const tokenEndpointCalled = new Promise((resolve) => { + setupAxiosMockRoutes(mockAdapter, { + "/.well-known/oauth-authorization-server": + createMockOAuthMetadata(TEST_URL), + "/oauth2/token": () => + new Promise((r) => { + resolveToken = r; + resolve(); + }), + }); + }); + + const refreshPromise = manager.refreshToken(); + await tokenEndpointCalled; + + await manager.setDeployment({ + url: "https://new.example.com", + safeHostname: "new.example.com", + }); + + resolveToken!(createMockTokenResponse({ access_token: "new-token" })); + const result = await refreshPromise; + + expect(result.access_token).toBe("new-token"); + expect(await manager.isLoggedInWithOAuth()).toBe(false); + }); + }); +}); diff --git a/test/unit/oauth/testUtils.ts b/test/unit/oauth/testUtils.ts new file mode 100644 index 00000000..23062d7f --- /dev/null +++ b/test/unit/oauth/testUtils.ts @@ -0,0 +1,113 @@ +import { vi } from "vitest"; + +import { SecretsManager } from "@/core/secretsManager"; +import { getHeaders } from "@/headers"; + +import { + createMockLogger, + getAxiosMockAdapter, + InMemoryMemento, + InMemorySecretStorage, + MockConfigurationProvider, + setupAxiosMockRoutes, +} from "../../mocks/testHelpers"; + +import type { Deployment } from "@/deployment/types"; +import type { + ClientRegistrationResponse, + OAuthServerMetadata, + TokenResponse, +} from "@/oauth/types"; + +export const TEST_URL = "https://coder.example.com"; +export const TEST_HOSTNAME = "coder.example.com"; + +export function createMockOAuthMetadata( + issuer: string, + overrides: Partial = {}, +): OAuthServerMetadata { + return { + issuer, + authorization_endpoint: `${issuer}/oauth2/authorize`, + token_endpoint: `${issuer}/oauth2/token`, + revocation_endpoint: `${issuer}/oauth2/revoke`, + registration_endpoint: `${issuer}/oauth2/register`, + scopes_supported: [ + "workspace:read", + "workspace:update", + "workspace:start", + "workspace:ssh", + "workspace:application_connect", + "template:read", + "user:read_personal", + ], + response_types_supported: ["code"], + grant_types_supported: ["authorization_code", "refresh_token"], + code_challenge_methods_supported: ["S256"], + token_endpoint_auth_methods_supported: ["client_secret_post"], + ...overrides, + }; +} + +export function createMockClientRegistration( + overrides: Partial = {}, +): ClientRegistrationResponse { + return { + client_id: "test-client-id", + client_secret: "test-client-secret", + redirect_uris: ["vscode://coder.coder-remote/oauth/callback"], + token_endpoint_auth_method: "client_secret_post", + grant_types: ["authorization_code", "refresh_token"], + response_types: ["code"], + ...overrides, + }; +} + +/** + * Creates a mock OAuth token response for testing. + */ +export function createMockTokenResponse( + overrides: Partial = {}, +): TokenResponse { + return { + access_token: "test-access-token", + refresh_token: "test-refresh-token", + token_type: "Bearer", + expires_in: 3600, + scope: "workspace:read workspace:update", + ...overrides, + }; +} + +export function createTestDeployment(): Deployment { + return { + url: TEST_URL, + safeHostname: TEST_HOSTNAME, + }; +} + +export function createBaseTestContext() { + const mockAdapter = getAxiosMockAdapter(); + vi.mocked(getHeaders).mockResolvedValue({}); + + // Constructor sets up vscode.workspace mock + new MockConfigurationProvider(); + + const secretStorage = new InMemorySecretStorage(); + const memento = new InMemoryMemento(); + const logger = createMockLogger(); + const secretsManager = new SecretsManager(secretStorage, memento, logger); + + /** Sets up default OAuth routes - use explicit routes when asserting on values */ + const setupOAuthRoutes = () => { + setupAxiosMockRoutes(mockAdapter, { + "/.well-known/oauth-authorization-server": + createMockOAuthMetadata(TEST_URL), + "/oauth2/register": createMockClientRegistration(), + "/oauth2/token": createMockTokenResponse(), + "/api/v2/users/me": { username: "test-user" }, + }); + }; + + return { mockAdapter, secretsManager, logger, setupOAuthRoutes }; +} diff --git a/test/unit/oauth/utils.test.ts b/test/unit/oauth/utils.test.ts new file mode 100644 index 00000000..3e5d603e --- /dev/null +++ b/test/unit/oauth/utils.test.ts @@ -0,0 +1,100 @@ +import { describe, expect, it } from "vitest"; + +import { buildOAuthTokenData } from "@/oauth/utils"; + +import type { TokenResponse } from "@/oauth/types"; + +const ACCESS_TOKEN_DEFAULT_EXPIRY_MS = 60 * 60 * 1000; + +function createTokenResponse( + overrides: Partial = {}, +): TokenResponse { + return { + access_token: "test-token", + token_type: "Bearer", + expires_in: 3600, + refresh_token: "refresh-token", + scope: "workspace:read", + ...overrides, + }; +} + +describe("buildOAuthTokenData", () => { + describe("expires_in validation", () => { + it("uses expires_in when valid", () => { + const result = buildOAuthTokenData( + createTokenResponse({ expires_in: 7200 }), + ); + const expectedExpiry = Date.now() + 7200 * 1000; + expect(result.expiry_timestamp).toBeGreaterThanOrEqual( + expectedExpiry - 100, + ); + expect(result.expiry_timestamp).toBeLessThanOrEqual(expectedExpiry + 100); + }); + + it("uses default when expires_in is zero", () => { + const before = Date.now(); + const result = buildOAuthTokenData( + createTokenResponse({ expires_in: 0 }), + ); + expect(result.expiry_timestamp).toBeGreaterThanOrEqual( + before + ACCESS_TOKEN_DEFAULT_EXPIRY_MS, + ); + }); + + it("uses default when expires_in is negative", () => { + const before = Date.now(); + const result = buildOAuthTokenData( + createTokenResponse({ expires_in: -100 }), + ); + expect(result.expiry_timestamp).toBeGreaterThanOrEqual( + before + ACCESS_TOKEN_DEFAULT_EXPIRY_MS, + ); + }); + + it("uses default when expires_in is undefined", () => { + const before = Date.now(); + const result = buildOAuthTokenData( + createTokenResponse({ expires_in: undefined }), + ); + expect(result.expiry_timestamp).toBeGreaterThanOrEqual( + before + ACCESS_TOKEN_DEFAULT_EXPIRY_MS, + ); + }); + + it("uses default when expires_in is Infinity", () => { + const before = Date.now(); + const result = buildOAuthTokenData( + createTokenResponse({ expires_in: Infinity }), + ); + expect(result.expiry_timestamp).toBeGreaterThanOrEqual( + before + ACCESS_TOKEN_DEFAULT_EXPIRY_MS, + ); + }); + }); + + describe("token_type validation", () => { + it("accepts Bearer tokens", () => { + const result = buildOAuthTokenData( + createTokenResponse({ token_type: "Bearer" }), + ); + expect(result.token_type).toBe("Bearer"); + }); + + it("rejects DPoP tokens", () => { + expect(() => + buildOAuthTokenData( + createTokenResponse({ token_type: "DPoP" as "Bearer" }), + ), + ).toThrow("Unsupported token type: DPoP"); + }); + + it("rejects unknown token types", () => { + expect(() => + buildOAuthTokenData( + createTokenResponse({ token_type: "unknown" as "Bearer" }), + ), + ).toThrow("Unsupported token type: unknown"); + }); + }); +});