Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions apps/integration-tests/src/end-to-end.integration.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,32 @@ t.describe("E2E Integration Test", () => {
await resumedDappClient.sendRequest(testPayload);
await t.expect(messagePromise).resolves.toEqual(testPayload);
});

t.test("should discard inbound messages when the receiver's session has expired", async () => {
await connectClients(dappClient, walletClient, "trusted");

// Verify the connection works before expiry
const preExpiryPayload = { method: "before_expiry" };
const preExpiryPromise = new Promise((resolve) => walletClient.once("message", resolve));
await dappClient.sendRequest(preExpiryPayload);
await t.expect(preExpiryPromise).resolves.toEqual(preExpiryPayload);

// Force-expire the wallet's session by setting expiresAt to the past
(walletClient as any).session.expiresAt = Date.now() - 1000;

// Listen for the SESSION_EXPIRED error on the wallet
const errorPromise = new Promise<any>((resolve) => walletClient.once("error", resolve));

// Dapp sends another message - wallet should reject it
const postExpiryPayload = { method: "after_expiry" };
await dappClient.sendRequest(postExpiryPayload);

const error = await errorPromise;
t.expect(error.code).toBe("SESSION_EXPIRED");

// Give time for any message processing
await new Promise((resolve) => setTimeout(resolve, 500));
});
});

t.describe("E2E Integration Test via Proxy", () => {
Expand Down
6 changes: 6 additions & 0 deletions packages/core/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Fixed

- Prevent nonce poisoning by deferring nonce persistence until after successful decryption ([#69](https://github.com/MetaMask/mobile-wallet-protocol/pull/69))
- Guard against `NaN` in nonce storage ([#69](https://github.com/MetaMask/mobile-wallet-protocol/pull/69))
- Reject inbound messages on expired sessions instead of processing them

## [0.3.1]

### Fixed
Expand Down
1 change: 1 addition & 0 deletions packages/core/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
"registry": "https://registry.npmjs.org/"
},
"dependencies": {
"async-mutex": "^0.5.0",
"centrifuge": "^5.3.5",
"eventemitter3": "^5.0.1",
"uuid": "^11.1.0"
Expand Down
42 changes: 42 additions & 0 deletions packages/core/src/base-client.integration.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,48 @@ t.describe("BaseClient", () => {
publishSpy.mockRestore();
});

t.test("should discard inbound messages on an expired session", async () => {
const keyManagerA = new KeyManager();
const keyManagerB = new KeyManager();
const keyPairA = keyManagerA.generateKeyPair();
const keyPairB = keyManagerB.generateKeyPair();

const sessionA: Session = {
id: "session-inbound-expiry",
channel,
keyPair: keyPairA,
theirPublicKey: keyPairB.publicKey,
expiresAt: Date.now() + 60000,
};
const sessionB: Session = {
id: "session-inbound-expiry",
channel,
keyPair: keyPairB,
theirPublicKey: keyPairA.publicKey,
expiresAt: Date.now() - 1000, // Already expired
};

clientA.setSession(sessionA);
clientB.setSession(sessionB);

await clientA["transport"].subscribe(channel);
await clientB["transport"].subscribe(channel);

const errorPromise = new Promise<any>((resolve) => {
clientB.once("error", resolve);
});

const messageToSend: ProtocolMessage = { type: "message", payload: { method: "should_be_dropped" } };
await clientA.sendMessage(channel, messageToSend);

const error = await errorPromise;
t.expect(error.code).toBe("SESSION_EXPIRED");

// Give a moment for any message processing to complete
await new Promise((resolve) => setTimeout(resolve, 200));
t.expect(clientB.receivedMessages).toHaveLength(0);
});

t.test("should reject resume() when client is already connected", async () => {
// 1. Create and store a valid session
const keyManagerA = new KeyManager();
Expand Down
28 changes: 18 additions & 10 deletions packages/core/src/base-client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,17 @@ export abstract class BaseClient extends EventEmitter {

this.transport.on("message", async (payload) => {
if (!this.session?.keyPair.privateKey) return;
if (await this.checkSessionExpiry()) {
this.emit("error", new SessionError(ErrorCode.SESSION_EXPIRED, "Session expired"));
return;
}
const message = await this.decryptMessage(payload.data);
if (message) this.handleMessage(message);
if (message) {
// Confirm the nonce only after successful decryption to prevent
// attackers from poisoning the nonce tracker with invalid messages.
await payload.confirmNonce?.();
this.handleMessage(message);
}
});
}

Expand Down Expand Up @@ -139,23 +148,22 @@ export abstract class BaseClient extends EventEmitter {
*/
protected async sendMessage(channel: string, message: ProtocolMessage): Promise<void> {
if (!this.session) throw new SessionError(ErrorCode.SESSION_INVALID_STATE, "Cannot send message: session is not initialized.");
await this.checkSessionExpiry();
if (await this.checkSessionExpiry()) throw new SessionError(ErrorCode.SESSION_EXPIRED, "Session expired");
const plaintext = JSON.stringify(message);
const encrypted = await this.keymanager.encrypt(plaintext, this.session.theirPublicKey);
const ok = await this.transport.publish(channel, encrypted);
if (!ok) throw new TransportError(ErrorCode.TRANSPORT_DISCONNECTED, "Message could not be sent because the transport is disconnected.");
}

/**
* Checks if the current session is expired. If it is, triggers a disconnect.
* @throws {SessionError} if the session is expired.
* Checks if the current session has expired. If so, triggers a disconnect.
*
* @returns true if the session was expired (and cleanup was triggered), false otherwise.
*/
private async checkSessionExpiry(): Promise<void> {
if (!this.session) return;
if (this.session.expiresAt < Date.now()) {
await this.disconnect();
throw new SessionError(ErrorCode.SESSION_EXPIRED, "Session expired");
}
private async checkSessionExpiry(): Promise<boolean> {
if (!this.session || this.session.expiresAt >= Date.now()) return false;
await this.disconnect();
return true;
}

/**
Expand Down
2 changes: 1 addition & 1 deletion packages/core/src/domain/transport.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ export interface ITransport {
* @param event The name of the event to listen for.
* @param handler The callback function to execute.
*/
on(event: "message", handler: (payload: { channel: string; data: string }) => void): void;
on(event: "message", handler: (payload: { channel: string; data: string; confirmNonce?: () => Promise<void> }) => void): void;
on(event: "connecting" | "connected" | "disconnected", handler: () => void): void;
on(event: "error", handler: (error: Error) => void): void;

Expand Down
16 changes: 11 additions & 5 deletions packages/core/src/transport/websocket/index.integration.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,8 @@ t.describe.each(testModes)("WebSocketTransport with $name", ({ useSharedConnecti
const received = await messagePromise;

// The subscriber transport should unwrap the envelope and emit the original payload.
t.expect(received).toEqual({ channel, data: payload });
t.expect(received).toMatchObject({ channel, data: payload });
t.expect(received.confirmNonce).toBeTypeOf("function");

await publisher.disconnect();
});
Expand Down Expand Up @@ -431,16 +432,18 @@ t.describe.each(testModes)("WebSocketTransport with $name", ({ useSharedConnecti
const messagePayload = "dedup-test-message";
let messageCount = 0;

subscriber.on("message", ({ data }) => {
subscriber.on("message", ({ data, confirmNonce }) => {
if (data === messagePayload) {
messageCount++;
confirmNonce?.();
}
});

// Send the message once using normal publish
const firstMessagePromise = waitFor(subscriber, "message");
await rawPublisher.publish(channel, messagePayload);
await firstMessagePromise;
const firstMsg = await firstMessagePromise;
await firstMsg.confirmNonce?.();
t.expect(messageCount).toBe(1);

// Create the exact same message envelope that was sent
Expand Down Expand Up @@ -791,8 +794,11 @@ t.describe.each(testModes)("WebSocketTransport with $name", ({ useSharedConnecti
await publisher.publish(channel, `message-${i}`);
}

// Wait for all messages to be received
await Promise.all(messagePromises);
// Wait for all messages to be received and confirm nonces so they persist
const receivedMsgs = await Promise.all(messagePromises);
for (const msg of receivedMsgs) {
await msg.confirmNonce?.();
}

// Verify storage has accumulated data
const storage = (transport as any).storage;
Expand Down
59 changes: 53 additions & 6 deletions packages/core/src/transport/websocket/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,17 @@ type TransportState = "disconnected" | "connecting" | "connected";

/** The maximum number of messages to fetch from history upon a new subscription. */
const HISTORY_FETCH_LIMIT = 50;
/**
* Maximum allowed nonce jump from a known sender. Messages with a nonce that
* jumps more than this from the last confirmed nonce are rejected as suspicious.
* Does not apply to the first message from a new sender (baseline is 0).
*
* Trade-off: if the receiver goes offline and misses more than this many
* messages from a known sender, legitimate messages will be permanently
* blocked. In practice this is unlikely given low message rates in MWP
* sessions, but worth noting.
*/
const MAX_NONCE_JUMP = 100;
/** The maximum number of retry attempts for publishing a message. */
const MAX_RETRY_ATTEMPTS = 5;
/** The base delay in milliseconds for exponential backoff between publish retries. */
Expand All @@ -65,6 +76,7 @@ export class WebSocketTransport extends EventEmitter implements ITransport {
private readonly centrifuge: Centrifuge | SharedCentrifuge;
private readonly storage: WebSocketTransportStorage;
private readonly queue: QueuedItem[] = [];
private readonly pendingNonces = new Map<string, Set<number>>();
private isProcessingQueue = false;
private state: TransportState = "disconnected";

Expand Down Expand Up @@ -214,6 +226,9 @@ export class WebSocketTransport extends EventEmitter implements ITransport {
*/
public async clear(channel: string): Promise<void> {
await this.storage.clear(channel);
for (const key of this.pendingNonces.keys()) {
if (key.startsWith(`${channel}:`)) this.pendingNonces.delete(key);
}
const sub = this.centrifuge.getSubscription(channel);
if (sub) this.centrifuge.removeSubscription(sub as Subscription);
}
Expand All @@ -229,6 +244,11 @@ export class WebSocketTransport extends EventEmitter implements ITransport {

/**
* Parses an incoming raw message, checks for duplicates, and emits it.
*
* The nonce is checked for deduplication but NOT persisted here. The emitted
* payload includes a `confirmNonce` callback that the consumer (BaseClient)
* must call after successful decryption. This prevents an attacker from
* poisoning the nonce tracker with high-nonce messages that fail decryption.
*/
private async _handleIncomingMessage(channel: string, rawData: string): Promise<void> {
try {
Expand All @@ -246,13 +266,40 @@ export class WebSocketTransport extends EventEmitter implements ITransport {
const latestNonces = await this.storage.getLatestNonces(channel);
const latestNonce = latestNonces.get(message.clientId) || 0;

if (message.nonce > latestNonce) {
// This is a new message, update the latest nonce and emit the message.
latestNonces.set(message.clientId, message.nonce);
await this.storage.setLatestNonces(channel, latestNonces);
this.emit("message", { channel, data: message.payload });
if (message.nonce <= latestNonce) {
return;
}

// Reject suspiciously large nonce jumps (but allow first message from a new sender).
if (latestNonce > 0 && message.nonce - latestNonce > MAX_NONCE_JUMP) {
this.emit("error", new TransportError(ErrorCode.TRANSPORT_PARSE_FAILED, `Nonce jump too large: ${latestNonce} -> ${message.nonce}`));
return;
}

// Guard against duplicate processing between emit and confirm.
// Without this, a message arriving via both live publication and
// _fetchHistory could pass the storage-based dedup check twice
// because the nonce hasn't been confirmed yet.
const pendingKey = `${channel}:${message.clientId}`;
const pending = this.pendingNonces.get(pendingKey);
if (pending?.has(message.nonce)) {
return;
}
// If message.nonce <= latestNonce, it's a duplicate and we ignore it.
if (!pending) {
this.pendingNonces.set(pendingKey, new Set([message.nonce]));
} else {
pending.add(message.nonce);
}

const confirmNonce = async () => {
await this.storage.confirmNonce(channel, message.clientId, message.nonce);
const p = this.pendingNonces.get(pendingKey);
if (p) {
p.delete(message.nonce);
if (p.size === 0) this.pendingNonces.delete(pendingKey);
}
};
this.emit("message", { channel, data: message.payload, confirmNonce });
} catch (error) {
this.emit("error", new TransportError(ErrorCode.TRANSPORT_PARSE_FAILED, `Failed to parse incoming message: ${error instanceof Error ? error.message : "Unknown error"}`));
}
Expand Down
46 changes: 46 additions & 0 deletions packages/core/src/transport/websocket/store.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,52 @@ t.describe("WebSocketTransportStorage", () => {
t.expect(nextNonce).toBe(6);
t.expect(await kvstore.get(newStorage.getNonceKey(channel))).toBe("6");
});

t.test("should recover from NaN nonce value in storage", async () => {
const channel = "session:nan-channel";

// Corrupt the stored nonce value
await kvstore.set(storage.getNonceKey(channel), "not-a-number");

const nonce = await storage.getNextNonce(channel);
t.expect(nonce).toBe(1);
t.expect(await kvstore.get(storage.getNonceKey(channel))).toBe("1");
});
});

t.describe("Nonce Confirmation", () => {
let storage: WebSocketTransportStorage;

t.beforeEach(async () => {
storage = await WebSocketTransportStorage.create(kvstore);
});

t.test("should save nonce via confirmNonce", async () => {
const channel = "session:confirm-channel";
await storage.confirmNonce(channel, "sender-1", 5);

const nonces = await storage.getLatestNonces(channel);
t.expect(nonces.get("sender-1")).toBe(5);
});

t.test("should not regress nonce on confirmNonce with lower value", async () => {
const channel = "session:confirm-channel";
await storage.confirmNonce(channel, "sender-1", 10);
await storage.confirmNonce(channel, "sender-1", 3);

const nonces = await storage.getLatestNonces(channel);
t.expect(nonces.get("sender-1")).toBe(10);
});

t.test("should track nonces independently per sender", async () => {
const channel = "session:confirm-channel";
await storage.confirmNonce(channel, "sender-1", 5);
await storage.confirmNonce(channel, "sender-2", 8);

const nonces = await storage.getLatestNonces(channel);
t.expect(nonces.get("sender-1")).toBe(5);
t.expect(nonces.get("sender-2")).toBe(8);
});
});

t.describe("Latest Nonces Management", () => {
Expand Down
20 changes: 19 additions & 1 deletion packages/core/src/transport/websocket/store.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import { Mutex } from "async-mutex";
import { v4 as uuid } from "uuid";
import type { IKVStore } from "../../domain/kv-store";

Expand All @@ -8,6 +9,7 @@ import type { IKVStore } from "../../domain/kv-store";
export class WebSocketTransportStorage {
private readonly kvstore: IKVStore;
private readonly clientId: string;
private readonly nonceMutex = new Mutex();

/**
* Creates a new WebSocketTransportStorage instance with a persistent client ID.
Expand Down Expand Up @@ -42,12 +44,28 @@ export class WebSocketTransportStorage {
async getNextNonce(channel: string): Promise<number> {
const key = this.getNonceKey(channel);
const value = await this.kvstore.get(key);
const currentNonce = value ? parseInt(value, 10) : 0;
let currentNonce = value ? parseInt(value, 10) : 0;
if (Number.isNaN(currentNonce)) currentNonce = 0;
const nextNonce = currentNonce + 1;
await this.kvstore.set(key, nextNonce.toString());
return nextNonce;
}

/**
* Confirms a received nonce after the message has been successfully processed
* (e.g., decrypted). Only updates if the nonce is higher than the current value.
*/
async confirmNonce(channel: string, clientId: string, nonce: number): Promise<void> {
await this.nonceMutex.runExclusive(async () => {
const latestNonces = await this.getLatestNonces(channel);
const current = latestNonces.get(clientId) || 0;
if (nonce > current) {
latestNonces.set(clientId, nonce);
await this.setLatestNonces(channel, latestNonces);
}
});
}

/**
* Retrieves the latest received nonces from all senders on the specified channel.
* Used for message deduplication - only messages with nonces greater than the
Expand Down
Loading