diff --git a/packages/sdk/src/realtime/client.ts b/packages/sdk/src/realtime/client.ts index 7a5391a..02326c3 100644 --- a/packages/sdk/src/realtime/client.ts +++ b/packages/sdk/src/realtime/client.ts @@ -79,6 +79,8 @@ export type RealTimeClientOptions = { const realTimeClientInitialStateSchema = modelStateSchema; type OnRemoteStreamFn = (stream: MediaStream) => void; +type OnStatusFn = (status: string) => void; +type OnQueuePositionFn = (data: { position: number; queueSize: number }) => void; export type RealTimeClientInitialState = z.infer; // ugly workaround to add an optional function to the schema @@ -93,6 +95,16 @@ const realTimeClientConnectOptionsSchema = z.object({ }), initialState: realTimeClientInitialStateSchema.optional(), customizeOffer: createAsyncFunctionSchema(z.function()).optional(), + onStatus: z + .custom((val) => typeof val === "function", { + message: "onStatus must be a function", + }) + .optional(), + onQueuePosition: z + .custom((val) => typeof val === "function", { + message: "onQueuePosition must be a function", + }) + .optional(), }); export type RealTimeClientConnectOptions = z.infer; @@ -100,6 +112,8 @@ export type Events = { connectionChange: ConnectionState; error: DecartSDKError; generationTick: { seconds: number }; + status: string; + queuePosition: { position: number; queueSize: number }; diagnostic: DiagnosticEvent; stats: WebRTCStats; }; @@ -135,7 +149,7 @@ export const createRealTimeClient = (opts: RealTimeClientOptions) => { const isAvatarLive = options.model.name === "live_avatar"; - const { onRemoteStream, initialState } = parsedOptions.data; + const { onRemoteStream, initialState, onStatus, onQueuePosition } = parsedOptions.data; // For live_avatar without user-provided stream: create AudioStreamManager for continuous silent stream with audio injection // If user provides their own stream (e.g., mic input), use it directly @@ -256,6 +270,17 @@ export const createRealTimeClient = (opts: RealTimeClientOptions) => { }; manager.getWebsocketMessageEmitter().on("generationTick", tickListener); + const wsEmitter = manager.getWebsocketMessageEmitter(); + wsEmitter.on("status", (msg) => { + emitOrBuffer("status", msg.status); + onStatus?.(msg.status); + }); + wsEmitter.on("queuePosition", (msg) => { + const data = { position: msg.position, queueSize: msg.queue_size }; + emitOrBuffer("queuePosition", data); + onQueuePosition?.(data); + }); + await manager.connect(inputStream); const methods = realtimeMethods(manager, imageToBase64); diff --git a/packages/sdk/src/realtime/types.ts b/packages/sdk/src/realtime/types.ts index acd6b39..349da04 100644 --- a/packages/sdk/src/realtime/types.ts +++ b/packages/sdk/src/realtime/types.ts @@ -60,6 +60,17 @@ export type SetImageAckMessage = { error: null | string; }; +export type StatusMessage = { + type: "status"; + status: string; +}; + +export type QueuePositionMessage = { + type: "queue_position"; + position: number; + queue_size: number; +}; + export type GenerationStartedMessage = { type: "generation_started"; }; @@ -97,7 +108,9 @@ export type IncomingWebRTCMessage = | GenerationStartedMessage | GenerationTickMessage | GenerationEndedMessage - | SessionIdMessage; + | SessionIdMessage + | StatusMessage + | QueuePositionMessage; // Outgoing message types (to server) export type OutgoingWebRTCMessage = diff --git a/packages/sdk/src/realtime/webrtc-connection.ts b/packages/sdk/src/realtime/webrtc-connection.ts index 908b76f..d843124 100644 --- a/packages/sdk/src/realtime/webrtc-connection.ts +++ b/packages/sdk/src/realtime/webrtc-connection.ts @@ -9,8 +9,10 @@ import type { IncomingWebRTCMessage, OutgoingWebRTCMessage, PromptAckMessage, + QueuePositionMessage, SessionIdMessage, SetImageAckMessage, + StatusMessage, TurnConfig, } from "./types"; @@ -36,6 +38,8 @@ type WsMessageEvents = { setImageAck: SetImageAckMessage; sessionId: SessionIdMessage; generationTick: GenerationTickMessage; + status: StatusMessage; + queuePosition: QueuePositionMessage; }; const noopDiagnostic: DiagnosticEmitter = () => {}; @@ -246,6 +250,16 @@ export class WebRTCConnection { return; } + if (msg.type === "status") { + this.websocketMessagesEmitter.emit("status", msg); + return; + } + + if (msg.type === "queue_position") { + this.websocketMessagesEmitter.emit("queuePosition", msg); + return; + } + // All other messages require peer connection if (!this.pc) return; diff --git a/packages/sdk/tests/unit.test.ts b/packages/sdk/tests/unit.test.ts index c1cbde6..d592cf0 100644 --- a/packages/sdk/tests/unit.test.ts +++ b/packages/sdk/tests/unit.test.ts @@ -1522,6 +1522,167 @@ describe("Subscribe Client", () => { } }); + it("exposes status and queue_position websocket messages as realtime client events", async () => { + const { createRealTimeClient } = await import("../src/realtime/client.js"); + const { WebRTCManager } = await import("../src/realtime/webrtc-manager.js"); + + const statusListeners = new Set< + ( + msg: { type: "status"; status: string } | { type: "queue_position"; position: number; queue_size: number }, + ) => void + >(); + const queuePositionListeners = new Set< + ( + msg: { type: "status"; status: string } | { type: "queue_position"; position: number; queue_size: number }, + ) => void + >(); + const websocketEmitter = { + on: ( + event: string, + listener: ( + msg: { type: "status"; status: string } | { type: "queue_position"; position: number; queue_size: number }, + ) => void, + ) => { + if (event === "status") statusListeners.add(listener); + if (event === "queuePosition") queuePositionListeners.add(listener); + }, + off: ( + event: string, + listener: ( + msg: { type: "status"; status: string } | { type: "queue_position"; position: number; queue_size: number }, + ) => void, + ) => { + if (event === "status") statusListeners.delete(listener); + if (event === "queuePosition") queuePositionListeners.delete(listener); + }, + }; + + const connectSpy = vi.spyOn(WebRTCManager.prototype, "connect").mockImplementation(async function () { + const mgr = this as unknown as { + config: { onConnectionStateChange?: (state: import("../src/realtime/types").ConnectionState) => void }; + managerState: import("../src/realtime/types").ConnectionState; + }; + mgr.managerState = "connected"; + mgr.config.onConnectionStateChange?.("connected"); + return true; + }); + const stateSpy = vi.spyOn(WebRTCManager.prototype, "getConnectionState").mockReturnValue("connected"); + const emitterSpy = vi + .spyOn(WebRTCManager.prototype, "getWebsocketMessageEmitter") + .mockReturnValue(websocketEmitter as never); + const cleanupSpy = vi.spyOn(WebRTCManager.prototype, "cleanup").mockImplementation(() => {}); + + try { + const realtime = createRealTimeClient({ baseUrl: "wss://api3.decart.ai", apiKey: "test-key" }); + const client = await realtime.connect({} as MediaStream, { + model: models.realtime("mirage_v2"), + onRemoteStream: vi.fn(), + }); + + const statusEvents: string[] = []; + const queueEvents: Array<{ position: number; queueSize: number }> = []; + + client.on("status", (status) => statusEvents.push(status)); + client.on("queuePosition", (data) => queueEvents.push(data)); + + for (const listener of statusListeners) { + listener({ type: "status", status: "queued" }); + } + for (const listener of queuePositionListeners) { + listener({ type: "queue_position", position: 2, queue_size: 11 }); + } + + await new Promise((resolve) => setTimeout(resolve, 0)); + + expect(statusEvents).toEqual(["queued"]); + expect(queueEvents).toEqual([{ position: 2, queueSize: 11 }]); + } finally { + connectSpy.mockRestore(); + stateSpy.mockRestore(); + emitterSpy.mockRestore(); + cleanupSpy.mockRestore(); + } + }); + + it("calls onStatus and onQueuePosition callbacks when websocket updates arrive", async () => { + const { createRealTimeClient } = await import("../src/realtime/client.js"); + const { WebRTCManager } = await import("../src/realtime/webrtc-manager.js"); + + const statusListeners = new Set< + ( + msg: { type: "status"; status: string } | { type: "queue_position"; position: number; queue_size: number }, + ) => void + >(); + const queuePositionListeners = new Set< + ( + msg: { type: "status"; status: string } | { type: "queue_position"; position: number; queue_size: number }, + ) => void + >(); + const websocketEmitter = { + on: ( + event: string, + listener: ( + msg: { type: "status"; status: string } | { type: "queue_position"; position: number; queue_size: number }, + ) => void, + ) => { + if (event === "status") statusListeners.add(listener); + if (event === "queuePosition") queuePositionListeners.add(listener); + }, + off: ( + event: string, + listener: ( + msg: { type: "status"; status: string } | { type: "queue_position"; position: number; queue_size: number }, + ) => void, + ) => { + if (event === "status") statusListeners.delete(listener); + if (event === "queuePosition") queuePositionListeners.delete(listener); + }, + }; + + const connectSpy = vi.spyOn(WebRTCManager.prototype, "connect").mockImplementation(async function () { + const mgr = this as unknown as { + config: { onConnectionStateChange?: (state: import("../src/realtime/types").ConnectionState) => void }; + managerState: import("../src/realtime/types").ConnectionState; + }; + mgr.managerState = "connected"; + mgr.config.onConnectionStateChange?.("connected"); + return true; + }); + const stateSpy = vi.spyOn(WebRTCManager.prototype, "getConnectionState").mockReturnValue("connected"); + const emitterSpy = vi + .spyOn(WebRTCManager.prototype, "getWebsocketMessageEmitter") + .mockReturnValue(websocketEmitter as never); + const cleanupSpy = vi.spyOn(WebRTCManager.prototype, "cleanup").mockImplementation(() => {}); + + try { + const onStatus = vi.fn(); + const onQueuePosition = vi.fn(); + + const realtime = createRealTimeClient({ baseUrl: "wss://api3.decart.ai", apiKey: "test-key" }); + await realtime.connect({} as MediaStream, { + model: models.realtime("mirage_v2"), + onRemoteStream: vi.fn(), + onStatus, + onQueuePosition, + }); + + for (const listener of statusListeners) { + listener({ type: "status", status: "initializing" }); + } + for (const listener of queuePositionListeners) { + listener({ type: "queue_position", position: 4, queue_size: 19 }); + } + + expect(onStatus).toHaveBeenCalledWith("initializing"); + expect(onQueuePosition).toHaveBeenCalledWith({ position: 4, queueSize: 19 }); + } finally { + connectSpy.mockRestore(); + stateSpy.mockRestore(); + emitterSpy.mockRestore(); + cleanupSpy.mockRestore(); + } + }); + it("buffers pre-session telemetry diagnostics and flushes them after session_id", async () => { const { createRealTimeClient } = await import("../src/realtime/client.js"); const { WebRTCManager } = await import("../src/realtime/webrtc-manager.js");