diff --git a/apps/integration-tests/src/end-to-end.integration.test.ts b/apps/integration-tests/src/end-to-end.integration.test.ts index b7fd767..1d9128d 100644 --- a/apps/integration-tests/src/end-to-end.integration.test.ts +++ b/apps/integration-tests/src/end-to-end.integration.test.ts @@ -209,6 +209,32 @@ t.describe("E2E Integration Test", () => { await resumedDappClient.sendRequest(testPayload); await t.expect(messagePromise).resolves.toEqual(testPayload); }); + + t.test("should discard inbound messages when the receiver's session has expired", async () => { + await connectClients(dappClient, walletClient, "trusted"); + + // Verify the connection works before expiry + const preExpiryPayload = { method: "before_expiry" }; + const preExpiryPromise = new Promise((resolve) => walletClient.once("message", resolve)); + await dappClient.sendRequest(preExpiryPayload); + await t.expect(preExpiryPromise).resolves.toEqual(preExpiryPayload); + + // Force-expire the wallet's session by setting expiresAt to the past + (walletClient as any).session.expiresAt = Date.now() - 1000; + + // Listen for the SESSION_EXPIRED error on the wallet + const errorPromise = new Promise((resolve) => walletClient.once("error", resolve)); + + // Dapp sends another message - wallet should reject it + const postExpiryPayload = { method: "after_expiry" }; + await dappClient.sendRequest(postExpiryPayload); + + const error = await errorPromise; + t.expect(error.code).toBe("SESSION_EXPIRED"); + + // Give time for any message processing + await new Promise((resolve) => setTimeout(resolve, 500)); + }); }); t.describe("E2E Integration Test via Proxy", () => { diff --git a/packages/core/CHANGELOG.md b/packages/core/CHANGELOG.md index ae1a5c6..6b8f309 100644 --- a/packages/core/CHANGELOG.md +++ b/packages/core/CHANGELOG.md @@ -7,6 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Fixed + +- Prevent nonce poisoning by deferring nonce persistence until after successful decryption ([#69](https://github.com/MetaMask/mobile-wallet-protocol/pull/69)) +- Guard against `NaN` in nonce storage ([#69](https://github.com/MetaMask/mobile-wallet-protocol/pull/69)) +- Reject inbound messages on expired sessions instead of processing them + ## [0.3.1] ### Fixed diff --git a/packages/core/package.json b/packages/core/package.json index 6cfe023..64093cd 100644 --- a/packages/core/package.json +++ b/packages/core/package.json @@ -25,6 +25,7 @@ "registry": "https://registry.npmjs.org/" }, "dependencies": { + "async-mutex": "^0.5.0", "centrifuge": "^5.3.5", "eventemitter3": "^5.0.1", "uuid": "^11.1.0" diff --git a/packages/core/src/base-client.integration.test.ts b/packages/core/src/base-client.integration.test.ts index e536cee..851697c 100644 --- a/packages/core/src/base-client.integration.test.ts +++ b/packages/core/src/base-client.integration.test.ts @@ -352,6 +352,48 @@ t.describe("BaseClient", () => { publishSpy.mockRestore(); }); + t.test("should discard inbound messages on an expired session", async () => { + const keyManagerA = new KeyManager(); + const keyManagerB = new KeyManager(); + const keyPairA = keyManagerA.generateKeyPair(); + const keyPairB = keyManagerB.generateKeyPair(); + + const sessionA: Session = { + id: "session-inbound-expiry", + channel, + keyPair: keyPairA, + theirPublicKey: keyPairB.publicKey, + expiresAt: Date.now() + 60000, + }; + const sessionB: Session = { + id: "session-inbound-expiry", + channel, + keyPair: keyPairB, + theirPublicKey: keyPairA.publicKey, + expiresAt: Date.now() - 1000, // Already expired + }; + + clientA.setSession(sessionA); + clientB.setSession(sessionB); + + await clientA["transport"].subscribe(channel); + await clientB["transport"].subscribe(channel); + + const errorPromise = new Promise((resolve) => { + clientB.once("error", resolve); + }); + + const messageToSend: ProtocolMessage = { type: "message", payload: { method: "should_be_dropped" } }; + await clientA.sendMessage(channel, messageToSend); + + const error = await errorPromise; + t.expect(error.code).toBe("SESSION_EXPIRED"); + + // Give a moment for any message processing to complete + await new Promise((resolve) => setTimeout(resolve, 200)); + t.expect(clientB.receivedMessages).toHaveLength(0); + }); + t.test("should reject resume() when client is already connected", async () => { // 1. Create and store a valid session const keyManagerA = new KeyManager(); diff --git a/packages/core/src/base-client.ts b/packages/core/src/base-client.ts index 7173fcf..2d07357 100644 --- a/packages/core/src/base-client.ts +++ b/packages/core/src/base-client.ts @@ -45,8 +45,17 @@ export abstract class BaseClient extends EventEmitter { this.transport.on("message", async (payload) => { if (!this.session?.keyPair.privateKey) return; + if (await this.checkSessionExpiry()) { + this.emit("error", new SessionError(ErrorCode.SESSION_EXPIRED, "Session expired")); + return; + } const message = await this.decryptMessage(payload.data); - if (message) this.handleMessage(message); + if (message) { + // Confirm the nonce only after successful decryption to prevent + // attackers from poisoning the nonce tracker with invalid messages. + await payload.confirmNonce?.(); + this.handleMessage(message); + } }); } @@ -139,7 +148,7 @@ export abstract class BaseClient extends EventEmitter { */ protected async sendMessage(channel: string, message: ProtocolMessage): Promise { if (!this.session) throw new SessionError(ErrorCode.SESSION_INVALID_STATE, "Cannot send message: session is not initialized."); - await this.checkSessionExpiry(); + if (await this.checkSessionExpiry()) throw new SessionError(ErrorCode.SESSION_EXPIRED, "Session expired"); const plaintext = JSON.stringify(message); const encrypted = await this.keymanager.encrypt(plaintext, this.session.theirPublicKey); const ok = await this.transport.publish(channel, encrypted); @@ -147,15 +156,14 @@ export abstract class BaseClient extends EventEmitter { } /** - * Checks if the current session is expired. If it is, triggers a disconnect. - * @throws {SessionError} if the session is expired. + * Checks if the current session has expired. If so, triggers a disconnect. + * + * @returns true if the session was expired (and cleanup was triggered), false otherwise. */ - private async checkSessionExpiry(): Promise { - if (!this.session) return; - if (this.session.expiresAt < Date.now()) { - await this.disconnect(); - throw new SessionError(ErrorCode.SESSION_EXPIRED, "Session expired"); - } + private async checkSessionExpiry(): Promise { + if (!this.session || this.session.expiresAt >= Date.now()) return false; + await this.disconnect(); + return true; } /** diff --git a/packages/core/src/domain/transport.ts b/packages/core/src/domain/transport.ts index b295b4c..47395aa 100644 --- a/packages/core/src/domain/transport.ts +++ b/packages/core/src/domain/transport.ts @@ -40,7 +40,7 @@ export interface ITransport { * @param event The name of the event to listen for. * @param handler The callback function to execute. */ - on(event: "message", handler: (payload: { channel: string; data: string }) => void): void; + on(event: "message", handler: (payload: { channel: string; data: string; confirmNonce?: () => Promise }) => void): void; on(event: "connecting" | "connected" | "disconnected", handler: () => void): void; on(event: "error", handler: (error: Error) => void): void; diff --git a/packages/core/src/transport/websocket/index.integration.test.ts b/packages/core/src/transport/websocket/index.integration.test.ts index b063e95..293b3d3 100644 --- a/packages/core/src/transport/websocket/index.integration.test.ts +++ b/packages/core/src/transport/websocket/index.integration.test.ts @@ -189,7 +189,8 @@ t.describe.each(testModes)("WebSocketTransport with $name", ({ useSharedConnecti const received = await messagePromise; // The subscriber transport should unwrap the envelope and emit the original payload. - t.expect(received).toEqual({ channel, data: payload }); + t.expect(received).toMatchObject({ channel, data: payload }); + t.expect(received.confirmNonce).toBeTypeOf("function"); await publisher.disconnect(); }); @@ -431,16 +432,18 @@ t.describe.each(testModes)("WebSocketTransport with $name", ({ useSharedConnecti const messagePayload = "dedup-test-message"; let messageCount = 0; - subscriber.on("message", ({ data }) => { + subscriber.on("message", ({ data, confirmNonce }) => { if (data === messagePayload) { messageCount++; + confirmNonce?.(); } }); // Send the message once using normal publish const firstMessagePromise = waitFor(subscriber, "message"); await rawPublisher.publish(channel, messagePayload); - await firstMessagePromise; + const firstMsg = await firstMessagePromise; + await firstMsg.confirmNonce?.(); t.expect(messageCount).toBe(1); // Create the exact same message envelope that was sent @@ -791,8 +794,11 @@ t.describe.each(testModes)("WebSocketTransport with $name", ({ useSharedConnecti await publisher.publish(channel, `message-${i}`); } - // Wait for all messages to be received - await Promise.all(messagePromises); + // Wait for all messages to be received and confirm nonces so they persist + const receivedMsgs = await Promise.all(messagePromises); + for (const msg of receivedMsgs) { + await msg.confirmNonce?.(); + } // Verify storage has accumulated data const storage = (transport as any).storage; diff --git a/packages/core/src/transport/websocket/index.ts b/packages/core/src/transport/websocket/index.ts index 8e7be2d..c1b3f92 100644 --- a/packages/core/src/transport/websocket/index.ts +++ b/packages/core/src/transport/websocket/index.ts @@ -51,6 +51,17 @@ type TransportState = "disconnected" | "connecting" | "connected"; /** The maximum number of messages to fetch from history upon a new subscription. */ const HISTORY_FETCH_LIMIT = 50; +/** + * Maximum allowed nonce jump from a known sender. Messages with a nonce that + * jumps more than this from the last confirmed nonce are rejected as suspicious. + * Does not apply to the first message from a new sender (baseline is 0). + * + * Trade-off: if the receiver goes offline and misses more than this many + * messages from a known sender, legitimate messages will be permanently + * blocked. In practice this is unlikely given low message rates in MWP + * sessions, but worth noting. + */ +const MAX_NONCE_JUMP = 100; /** The maximum number of retry attempts for publishing a message. */ const MAX_RETRY_ATTEMPTS = 5; /** The base delay in milliseconds for exponential backoff between publish retries. */ @@ -65,6 +76,7 @@ export class WebSocketTransport extends EventEmitter implements ITransport { private readonly centrifuge: Centrifuge | SharedCentrifuge; private readonly storage: WebSocketTransportStorage; private readonly queue: QueuedItem[] = []; + private readonly pendingNonces = new Map>(); private isProcessingQueue = false; private state: TransportState = "disconnected"; @@ -214,6 +226,9 @@ export class WebSocketTransport extends EventEmitter implements ITransport { */ public async clear(channel: string): Promise { await this.storage.clear(channel); + for (const key of this.pendingNonces.keys()) { + if (key.startsWith(`${channel}:`)) this.pendingNonces.delete(key); + } const sub = this.centrifuge.getSubscription(channel); if (sub) this.centrifuge.removeSubscription(sub as Subscription); } @@ -229,6 +244,11 @@ export class WebSocketTransport extends EventEmitter implements ITransport { /** * Parses an incoming raw message, checks for duplicates, and emits it. + * + * The nonce is checked for deduplication but NOT persisted here. The emitted + * payload includes a `confirmNonce` callback that the consumer (BaseClient) + * must call after successful decryption. This prevents an attacker from + * poisoning the nonce tracker with high-nonce messages that fail decryption. */ private async _handleIncomingMessage(channel: string, rawData: string): Promise { try { @@ -246,13 +266,40 @@ export class WebSocketTransport extends EventEmitter implements ITransport { const latestNonces = await this.storage.getLatestNonces(channel); const latestNonce = latestNonces.get(message.clientId) || 0; - if (message.nonce > latestNonce) { - // This is a new message, update the latest nonce and emit the message. - latestNonces.set(message.clientId, message.nonce); - await this.storage.setLatestNonces(channel, latestNonces); - this.emit("message", { channel, data: message.payload }); + if (message.nonce <= latestNonce) { + return; + } + + // Reject suspiciously large nonce jumps (but allow first message from a new sender). + if (latestNonce > 0 && message.nonce - latestNonce > MAX_NONCE_JUMP) { + this.emit("error", new TransportError(ErrorCode.TRANSPORT_PARSE_FAILED, `Nonce jump too large: ${latestNonce} -> ${message.nonce}`)); + return; + } + + // Guard against duplicate processing between emit and confirm. + // Without this, a message arriving via both live publication and + // _fetchHistory could pass the storage-based dedup check twice + // because the nonce hasn't been confirmed yet. + const pendingKey = `${channel}:${message.clientId}`; + const pending = this.pendingNonces.get(pendingKey); + if (pending?.has(message.nonce)) { + return; } - // If message.nonce <= latestNonce, it's a duplicate and we ignore it. + if (!pending) { + this.pendingNonces.set(pendingKey, new Set([message.nonce])); + } else { + pending.add(message.nonce); + } + + const confirmNonce = async () => { + await this.storage.confirmNonce(channel, message.clientId, message.nonce); + const p = this.pendingNonces.get(pendingKey); + if (p) { + p.delete(message.nonce); + if (p.size === 0) this.pendingNonces.delete(pendingKey); + } + }; + this.emit("message", { channel, data: message.payload, confirmNonce }); } catch (error) { this.emit("error", new TransportError(ErrorCode.TRANSPORT_PARSE_FAILED, `Failed to parse incoming message: ${error instanceof Error ? error.message : "Unknown error"}`)); } diff --git a/packages/core/src/transport/websocket/store.test.ts b/packages/core/src/transport/websocket/store.test.ts index 871e1a8..a58bcbe 100644 --- a/packages/core/src/transport/websocket/store.test.ts +++ b/packages/core/src/transport/websocket/store.test.ts @@ -116,6 +116,52 @@ t.describe("WebSocketTransportStorage", () => { t.expect(nextNonce).toBe(6); t.expect(await kvstore.get(newStorage.getNonceKey(channel))).toBe("6"); }); + + t.test("should recover from NaN nonce value in storage", async () => { + const channel = "session:nan-channel"; + + // Corrupt the stored nonce value + await kvstore.set(storage.getNonceKey(channel), "not-a-number"); + + const nonce = await storage.getNextNonce(channel); + t.expect(nonce).toBe(1); + t.expect(await kvstore.get(storage.getNonceKey(channel))).toBe("1"); + }); + }); + + t.describe("Nonce Confirmation", () => { + let storage: WebSocketTransportStorage; + + t.beforeEach(async () => { + storage = await WebSocketTransportStorage.create(kvstore); + }); + + t.test("should save nonce via confirmNonce", async () => { + const channel = "session:confirm-channel"; + await storage.confirmNonce(channel, "sender-1", 5); + + const nonces = await storage.getLatestNonces(channel); + t.expect(nonces.get("sender-1")).toBe(5); + }); + + t.test("should not regress nonce on confirmNonce with lower value", async () => { + const channel = "session:confirm-channel"; + await storage.confirmNonce(channel, "sender-1", 10); + await storage.confirmNonce(channel, "sender-1", 3); + + const nonces = await storage.getLatestNonces(channel); + t.expect(nonces.get("sender-1")).toBe(10); + }); + + t.test("should track nonces independently per sender", async () => { + const channel = "session:confirm-channel"; + await storage.confirmNonce(channel, "sender-1", 5); + await storage.confirmNonce(channel, "sender-2", 8); + + const nonces = await storage.getLatestNonces(channel); + t.expect(nonces.get("sender-1")).toBe(5); + t.expect(nonces.get("sender-2")).toBe(8); + }); }); t.describe("Latest Nonces Management", () => { diff --git a/packages/core/src/transport/websocket/store.ts b/packages/core/src/transport/websocket/store.ts index a7a7263..10362e9 100644 --- a/packages/core/src/transport/websocket/store.ts +++ b/packages/core/src/transport/websocket/store.ts @@ -1,3 +1,4 @@ +import { Mutex } from "async-mutex"; import { v4 as uuid } from "uuid"; import type { IKVStore } from "../../domain/kv-store"; @@ -8,6 +9,7 @@ import type { IKVStore } from "../../domain/kv-store"; export class WebSocketTransportStorage { private readonly kvstore: IKVStore; private readonly clientId: string; + private readonly nonceMutex = new Mutex(); /** * Creates a new WebSocketTransportStorage instance with a persistent client ID. @@ -42,12 +44,28 @@ export class WebSocketTransportStorage { async getNextNonce(channel: string): Promise { const key = this.getNonceKey(channel); const value = await this.kvstore.get(key); - const currentNonce = value ? parseInt(value, 10) : 0; + let currentNonce = value ? parseInt(value, 10) : 0; + if (Number.isNaN(currentNonce)) currentNonce = 0; const nextNonce = currentNonce + 1; await this.kvstore.set(key, nextNonce.toString()); return nextNonce; } + /** + * Confirms a received nonce after the message has been successfully processed + * (e.g., decrypted). Only updates if the nonce is higher than the current value. + */ + async confirmNonce(channel: string, clientId: string, nonce: number): Promise { + await this.nonceMutex.runExclusive(async () => { + const latestNonces = await this.getLatestNonces(channel); + const current = latestNonces.get(clientId) || 0; + if (nonce > current) { + latestNonces.set(clientId, nonce); + await this.setLatestNonces(channel, latestNonces); + } + }); + } + /** * Retrieves the latest received nonces from all senders on the specified channel. * Used for message deduplication - only messages with nonces greater than the diff --git a/yarn.lock b/yarn.lock index 1289ad8..5c7b9b6 100644 --- a/yarn.lock +++ b/yarn.lock @@ -2764,6 +2764,7 @@ __metadata: "@metamask/auto-changelog": ^5.0.2 "@types/uuid": ^10.0.0 "@types/ws": ^8.18.1 + async-mutex: ^0.5.0 centrifuge: ^5.3.5 eciesjs: ^0.4.15 eventemitter3: ^5.0.1 @@ -5421,6 +5422,15 @@ __metadata: languageName: node linkType: hard +"async-mutex@npm:^0.5.0": + version: 0.5.0 + resolution: "async-mutex@npm:0.5.0" + dependencies: + tslib: ^2.4.0 + checksum: be1587f4875f3bb15e34e9fcce82eac2966daef4432c8d0046e61947fb9a1b95405284601bc7ce4869319249bc07c75100880191db6af11d1498931ac2a2f9ea + languageName: node + linkType: hard + "asynckit@npm:^0.4.0": version: 0.4.0 resolution: "asynckit@npm:0.4.0"