diff --git a/packages/core/CHANGELOG.md b/packages/core/CHANGELOG.md index 17b077b..4d4d8d4 100644 --- a/packages/core/CHANGELOG.md +++ b/packages/core/CHANGELOG.md @@ -18,6 +18,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed +- Prevent nonce poisoning by deferring nonce persistence until after successful decryption ([#69](https://github.com/MetaMask/mobile-wallet-protocol/pull/69)) +- Guard against `NaN` in nonce storage ([#69](https://github.com/MetaMask/mobile-wallet-protocol/pull/69)) - Fix `SessionStore` race conditions and fire-and-forget garbage collection ([#71](https://github.com/MetaMask/mobile-wallet-protocol/pull/71)) - Guard against `NaN` in session expiry timestamps ([#70](https://github.com/MetaMask/mobile-wallet-protocol/pull/70)) diff --git a/packages/core/src/base-client.ts b/packages/core/src/base-client.ts index ede243f..b6a0ee0 100644 --- a/packages/core/src/base-client.ts +++ b/packages/core/src/base-client.ts @@ -46,7 +46,12 @@ export abstract class BaseClient extends EventEmitter { this.transport.on("message", async (payload) => { if (!this.session?.keyPair.privateKey) return; const message = await this.decryptMessage(payload.data); - if (message) this.handleMessage(message); + if (message) { + // Confirm the nonce only after successful decryption to prevent + // attackers from poisoning the nonce tracker with invalid messages. + await payload.confirmNonce?.(); + this.handleMessage(message); + } }); } diff --git a/packages/core/src/domain/transport.ts b/packages/core/src/domain/transport.ts index b295b4c..47395aa 100644 --- a/packages/core/src/domain/transport.ts +++ b/packages/core/src/domain/transport.ts @@ -40,7 +40,7 @@ export interface ITransport { * @param event The name of the event to listen for. * @param handler The callback function to execute. */ - on(event: "message", handler: (payload: { channel: string; data: string }) => void): void; + on(event: "message", handler: (payload: { channel: string; data: string; confirmNonce?: () => Promise }) => void): void; on(event: "connecting" | "connected" | "disconnected", handler: () => void): void; on(event: "error", handler: (error: Error) => void): void; diff --git a/packages/core/src/transport/websocket/index.integration.test.ts b/packages/core/src/transport/websocket/index.integration.test.ts index b063e95..293b3d3 100644 --- a/packages/core/src/transport/websocket/index.integration.test.ts +++ b/packages/core/src/transport/websocket/index.integration.test.ts @@ -189,7 +189,8 @@ t.describe.each(testModes)("WebSocketTransport with $name", ({ useSharedConnecti const received = await messagePromise; // The subscriber transport should unwrap the envelope and emit the original payload. - t.expect(received).toEqual({ channel, data: payload }); + t.expect(received).toMatchObject({ channel, data: payload }); + t.expect(received.confirmNonce).toBeTypeOf("function"); await publisher.disconnect(); }); @@ -431,16 +432,18 @@ t.describe.each(testModes)("WebSocketTransport with $name", ({ useSharedConnecti const messagePayload = "dedup-test-message"; let messageCount = 0; - subscriber.on("message", ({ data }) => { + subscriber.on("message", ({ data, confirmNonce }) => { if (data === messagePayload) { messageCount++; + confirmNonce?.(); } }); // Send the message once using normal publish const firstMessagePromise = waitFor(subscriber, "message"); await rawPublisher.publish(channel, messagePayload); - await firstMessagePromise; + const firstMsg = await firstMessagePromise; + await firstMsg.confirmNonce?.(); t.expect(messageCount).toBe(1); // Create the exact same message envelope that was sent @@ -791,8 +794,11 @@ t.describe.each(testModes)("WebSocketTransport with $name", ({ useSharedConnecti await publisher.publish(channel, `message-${i}`); } - // Wait for all messages to be received - await Promise.all(messagePromises); + // Wait for all messages to be received and confirm nonces so they persist + const receivedMsgs = await Promise.all(messagePromises); + for (const msg of receivedMsgs) { + await msg.confirmNonce?.(); + } // Verify storage has accumulated data const storage = (transport as any).storage; diff --git a/packages/core/src/transport/websocket/index.ts b/packages/core/src/transport/websocket/index.ts index 8e7be2d..799a7a9 100644 --- a/packages/core/src/transport/websocket/index.ts +++ b/packages/core/src/transport/websocket/index.ts @@ -51,6 +51,19 @@ type TransportState = "disconnected" | "connecting" | "connected"; /** The maximum number of messages to fetch from history upon a new subscription. */ const HISTORY_FETCH_LIMIT = 50; +/** + * Maximum allowed nonce jump from a known sender. Messages with a nonce that + * jumps more than this from the last confirmed nonce are rejected as suspicious. + * Does not apply to the first message from a new sender (baseline is 0). This + * is safe because a spoofed first message would fail decryption and its nonce + * would never be confirmed to storage. + * + * Trade-off: if the receiver goes offline and misses more than this many + * messages from a known sender, legitimate messages will be permanently + * blocked. In practice this is unlikely given low message rates in MWP + * sessions, but worth noting. + */ +const MAX_NONCE_JUMP = 100; /** The maximum number of retry attempts for publishing a message. */ const MAX_RETRY_ATTEMPTS = 5; /** The base delay in milliseconds for exponential backoff between publish retries. */ @@ -65,6 +78,7 @@ export class WebSocketTransport extends EventEmitter implements ITransport { private readonly centrifuge: Centrifuge | SharedCentrifuge; private readonly storage: WebSocketTransportStorage; private readonly queue: QueuedItem[] = []; + private readonly pendingNonces = new Map>(); private isProcessingQueue = false; private state: TransportState = "disconnected"; @@ -214,6 +228,9 @@ export class WebSocketTransport extends EventEmitter implements ITransport { */ public async clear(channel: string): Promise { await this.storage.clear(channel); + for (const key of this.pendingNonces.keys()) { + if (key.startsWith(`${channel}:`)) this.pendingNonces.delete(key); + } const sub = this.centrifuge.getSubscription(channel); if (sub) this.centrifuge.removeSubscription(sub as Subscription); } @@ -229,6 +246,11 @@ export class WebSocketTransport extends EventEmitter implements ITransport { /** * Parses an incoming raw message, checks for duplicates, and emits it. + * + * The nonce is checked for deduplication but NOT persisted here. The emitted + * payload includes a `confirmNonce` callback that the consumer (BaseClient) + * must call after successful decryption. This prevents an attacker from + * poisoning the nonce tracker with high-nonce messages that fail decryption. */ private async _handleIncomingMessage(channel: string, rawData: string): Promise { try { @@ -246,13 +268,44 @@ export class WebSocketTransport extends EventEmitter implements ITransport { const latestNonces = await this.storage.getLatestNonces(channel); const latestNonce = latestNonces.get(message.clientId) || 0; - if (message.nonce > latestNonce) { - // This is a new message, update the latest nonce and emit the message. - latestNonces.set(message.clientId, message.nonce); - await this.storage.setLatestNonces(channel, latestNonces); - this.emit("message", { channel, data: message.payload }); + if (message.nonce <= latestNonce) { + return; } - // If message.nonce <= latestNonce, it's a duplicate and we ignore it. + + // Reject suspiciously large nonce jumps (but allow first message from a new sender). + if (latestNonce > 0 && message.nonce - latestNonce > MAX_NONCE_JUMP) { + this.emit("error", new TransportError(ErrorCode.TRANSPORT_PARSE_FAILED, `Nonce jump too large: ${latestNonce} -> ${message.nonce}`)); + return; + } + + // Guard against duplicate processing between emit and confirm. + // Without this, a message arriving via both live publication and + // _fetchHistory could pass the storage-based dedup check twice + // because the nonce hasn't been confirmed yet. + const pendingKey = `${channel}:${message.clientId}`; + const pending = this.pendingNonces.get(pendingKey); + if (pending?.has(message.nonce)) { + return; + } + if (!pending) { + this.pendingNonces.set(pendingKey, new Set([message.nonce])); + } else { + pending.add(message.nonce); + } + + const confirmNonce = async () => { + try { + await this.storage.confirmNonce(channel, message.clientId, message.nonce); + } catch (error) { + this.emit("error", new TransportError(ErrorCode.UNKNOWN, `Failed to confirm nonce: ${error instanceof Error ? error.message : String(error)}`)); + } + const p = this.pendingNonces.get(pendingKey); + if (p) { + p.delete(message.nonce); + if (p.size === 0) this.pendingNonces.delete(pendingKey); + } + }; + this.emit("message", { channel, data: message.payload, confirmNonce }); } catch (error) { this.emit("error", new TransportError(ErrorCode.TRANSPORT_PARSE_FAILED, `Failed to parse incoming message: ${error instanceof Error ? error.message : "Unknown error"}`)); } diff --git a/packages/core/src/transport/websocket/store.test.ts b/packages/core/src/transport/websocket/store.test.ts index 871e1a8..a58bcbe 100644 --- a/packages/core/src/transport/websocket/store.test.ts +++ b/packages/core/src/transport/websocket/store.test.ts @@ -116,6 +116,52 @@ t.describe("WebSocketTransportStorage", () => { t.expect(nextNonce).toBe(6); t.expect(await kvstore.get(newStorage.getNonceKey(channel))).toBe("6"); }); + + t.test("should recover from NaN nonce value in storage", async () => { + const channel = "session:nan-channel"; + + // Corrupt the stored nonce value + await kvstore.set(storage.getNonceKey(channel), "not-a-number"); + + const nonce = await storage.getNextNonce(channel); + t.expect(nonce).toBe(1); + t.expect(await kvstore.get(storage.getNonceKey(channel))).toBe("1"); + }); + }); + + t.describe("Nonce Confirmation", () => { + let storage: WebSocketTransportStorage; + + t.beforeEach(async () => { + storage = await WebSocketTransportStorage.create(kvstore); + }); + + t.test("should save nonce via confirmNonce", async () => { + const channel = "session:confirm-channel"; + await storage.confirmNonce(channel, "sender-1", 5); + + const nonces = await storage.getLatestNonces(channel); + t.expect(nonces.get("sender-1")).toBe(5); + }); + + t.test("should not regress nonce on confirmNonce with lower value", async () => { + const channel = "session:confirm-channel"; + await storage.confirmNonce(channel, "sender-1", 10); + await storage.confirmNonce(channel, "sender-1", 3); + + const nonces = await storage.getLatestNonces(channel); + t.expect(nonces.get("sender-1")).toBe(10); + }); + + t.test("should track nonces independently per sender", async () => { + const channel = "session:confirm-channel"; + await storage.confirmNonce(channel, "sender-1", 5); + await storage.confirmNonce(channel, "sender-2", 8); + + const nonces = await storage.getLatestNonces(channel); + t.expect(nonces.get("sender-1")).toBe(5); + t.expect(nonces.get("sender-2")).toBe(8); + }); }); t.describe("Latest Nonces Management", () => { diff --git a/packages/core/src/transport/websocket/store.ts b/packages/core/src/transport/websocket/store.ts index a7a7263..10362e9 100644 --- a/packages/core/src/transport/websocket/store.ts +++ b/packages/core/src/transport/websocket/store.ts @@ -1,3 +1,4 @@ +import { Mutex } from "async-mutex"; import { v4 as uuid } from "uuid"; import type { IKVStore } from "../../domain/kv-store"; @@ -8,6 +9,7 @@ import type { IKVStore } from "../../domain/kv-store"; export class WebSocketTransportStorage { private readonly kvstore: IKVStore; private readonly clientId: string; + private readonly nonceMutex = new Mutex(); /** * Creates a new WebSocketTransportStorage instance with a persistent client ID. @@ -42,12 +44,28 @@ export class WebSocketTransportStorage { async getNextNonce(channel: string): Promise { const key = this.getNonceKey(channel); const value = await this.kvstore.get(key); - const currentNonce = value ? parseInt(value, 10) : 0; + let currentNonce = value ? parseInt(value, 10) : 0; + if (Number.isNaN(currentNonce)) currentNonce = 0; const nextNonce = currentNonce + 1; await this.kvstore.set(key, nextNonce.toString()); return nextNonce; } + /** + * Confirms a received nonce after the message has been successfully processed + * (e.g., decrypted). Only updates if the nonce is higher than the current value. + */ + async confirmNonce(channel: string, clientId: string, nonce: number): Promise { + await this.nonceMutex.runExclusive(async () => { + const latestNonces = await this.getLatestNonces(channel); + const current = latestNonces.get(clientId) || 0; + if (nonce > current) { + latestNonces.set(clientId, nonce); + await this.setLatestNonces(channel, latestNonces); + } + }); + } + /** * Retrieves the latest received nonces from all senders on the specified channel. * Used for message deduplication - only messages with nonces greater than the