diff --git a/examples/sdk-core/README.md b/examples/sdk-core/README.md index ad24a49..d3b1059 100644 --- a/examples/sdk-core/README.md +++ b/examples/sdk-core/README.md @@ -114,8 +114,10 @@ realtimeClient.disconnect(); const realtimeClient = await client.realtime.connect(null, { model: models.realtime("live_avatar"), onRemoteStream: (videoStream) => { ... }, - avatar: { avatarImage: "https://example.com/avatar.png" }, - initialState: { prompt: { text: "A friendly assistant", enhance: true } }, + initialState: { + image: "https://example.com/avatar.png", + prompt: { text: "A friendly assistant", enhance: true }, + }, }); await realtimeClient.playAudio(audioBlob); @@ -124,7 +126,9 @@ const micStream = await navigator.mediaDevices.getUserMedia({ audio: true, video const realtimeClient = await client.realtime.connect(micStream, { model: models.realtime("live_avatar"), onRemoteStream: (videoStream) => { ... }, - avatar: { avatarImage: avatarFile }, - initialState: { prompt: { text: "A friendly assistant", enhance: true } }, + initialState: { + image: avatarFile, + prompt: { text: "A friendly assistant", enhance: true }, + }, }); ``` diff --git a/examples/sdk-core/realtime/live-avatar.ts b/examples/sdk-core/realtime/live-avatar.ts index 9d55584..10cb72e 100644 --- a/examples/sdk-core/realtime/live-avatar.ts +++ b/examples/sdk-core/realtime/live-avatar.ts @@ -22,10 +22,8 @@ async function withPlayAudio() { const video = document.getElementById("output") as HTMLVideoElement; video.srcObject = videoStream; }, - avatar: { - avatarImage: "https://example.com/avatar.png", // or File/Blob - }, initialState: { + image: "https://example.com/avatar.png", // or File/Blob prompt: { text: "A friendly assistant", enhance: true }, }, }); @@ -63,10 +61,8 @@ async function withMicInput() { const video = document.getElementById("output") as HTMLVideoElement; video.srcObject = videoStream; }, - avatar: { - avatarImage: "https://example.com/avatar.png", - }, initialState: { + image: "https://example.com/avatar.png", prompt: { text: "A friendly assistant", enhance: true }, }, }); diff --git a/packages/sdk/index.html b/packages/sdk/index.html index fbe0c4c..fe1fd0d 100644 --- a/packages/sdk/index.html +++ b/packages/sdk/index.html @@ -498,6 +498,13 @@

Console Logs

addLog('Connecting to Decart server...', 'info'); + // Load initial reference image for models that support it + let initialImage; + if (model.name === 'lucy_2_rt' || model.name === 'mirage_v2') { + const initialImageResponse = await fetch('./tests/fixtures/image.png'); + initialImage = await initialImageResponse.blob(); + } + decartRealtime = await decartClient.realtime.connect(localStream, { model, onRemoteStream: (stream) => { @@ -507,7 +514,8 @@

Console Logs

initialState: { prompt: { text: "Lego World", - } + }, + ...(initialImage && { image: initialImage }), } }); diff --git a/packages/sdk/src/index.ts b/packages/sdk/src/index.ts index 0a86932..6f2b506 100644 --- a/packages/sdk/src/index.ts +++ b/packages/sdk/src/index.ts @@ -18,7 +18,6 @@ export type { QueueSubmitOptions, } from "./queue/types"; export type { - AvatarOptions, Events as RealTimeEvents, RealTimeClient, RealTimeClientConnectOptions, diff --git a/packages/sdk/src/realtime/client.ts b/packages/sdk/src/realtime/client.ts index 017239c..498dd12 100644 --- a/packages/sdk/src/realtime/client.ts +++ b/packages/sdk/src/realtime/client.ts @@ -1,5 +1,5 @@ import { z } from "zod"; -import { modelDefinitionSchema } from "../shared/model"; +import { modelDefinitionSchema, type RealTimeModels } from "../shared/model"; import { modelStateSchema } from "../shared/types"; import { createWebrtcError, type DecartSDKError } from "../utils/errors"; import { AudioStreamManager } from "./audio-stream-manager"; @@ -80,11 +80,6 @@ export type RealTimeClientInitialState = z.infer(schema: T) => z.custom[0]>((fn) => schema.implementAsync(fn as Parameters[0])); -const avatarOptionsSchema = z.object({ - avatarImage: z.union([z.instanceof(Blob), z.instanceof(File), z.string()]), -}); -export type AvatarOptions = z.infer; - const realTimeClientConnectOptionsSchema = z.object({ model: modelDefinitionSchema, onRemoteStream: z.custom((val) => typeof val === "function", { @@ -92,7 +87,6 @@ const realTimeClientConnectOptionsSchema = z.object({ }), initialState: realTimeClientInitialStateSchema.optional(), customizeOffer: createAsyncFunctionSchema(z.function()).optional(), - avatar: avatarOptionsSchema.optional(), }); export type RealTimeClientConnectOptions = z.infer; @@ -133,7 +127,7 @@ export const createRealTimeClient = (opts: RealTimeClientOptions) => { const isAvatarLive = options.model.name === "live_avatar"; - const { onRemoteStream, initialState, avatar } = parsedOptions.data; + const { onRemoteStream, initialState } = parsedOptions.data; // For live_avatar without user-provided stream: create AudioStreamManager for continuous silent stream with audio injection // If user provides their own stream (e.g., mic input), use it directly @@ -150,26 +144,16 @@ export const createRealTimeClient = (opts: RealTimeClientOptions) => { let webrtcManager: WebRTCManager | undefined; try { - // For live_avatar: prepare avatar image base64 before connection - let avatarImageBase64: string | undefined; - if (isAvatarLive && avatar?.avatarImage) { - if (typeof avatar.avatarImage === "string") { - const response = await fetch(avatar.avatarImage); - if (!response.ok) { - throw new Error(`Failed to fetch image: ${response.status} ${response.statusText}`); + // Prepare initial image base64 before connection + const initialImage = initialState?.image ? await imageToBase64(initialState.image) : undefined; + + // Prepare initial prompt to send via WebSocket before WebRTC handshake + const initialPrompt = initialState?.prompt + ? { + text: initialState.prompt.text, + enhance: initialState.prompt.enhance, } - const imageBlob = await response.blob(); - avatarImageBase64 = await blobToBase64(imageBlob); - } else { - avatarImageBase64 = await blobToBase64(avatar.avatarImage); - } - } - - // For live_avatar: prepare initial prompt to send before WebRTC handshake - const initialPrompt = - isAvatarLive && initialState?.prompt - ? { text: initialState.prompt.text, enhance: initialState.prompt.enhance } - : undefined; + : undefined; const url = `${baseUrl}${options.model.urlPath}`; @@ -189,8 +173,8 @@ export const createRealTimeClient = (opts: RealTimeClientOptions) => { customizeOffer: options.customizeOffer as ((offer: RTCSessionDescriptionInit) => Promise) | undefined, vp8MinBitrate: 300, vp8StartBitrate: 600, - isAvatarLive, - avatarImageBase64, + modelName: options.model.name as RealTimeModels, + initialImage, initialPrompt, }); @@ -213,12 +197,6 @@ export const createRealTimeClient = (opts: RealTimeClientOptions) => { const methods = realtimeMethods(manager, imageToBase64); - // For non-live_avatar models: send initial prompt after connection is established - if (!isAvatarLive && initialState?.prompt) { - const { text, enhance } = initialState.prompt; - await methods.setPrompt(text, { enhance }); - } - const client: RealTimeClient = { set: methods.set, setPrompt: methods.setPrompt, diff --git a/packages/sdk/src/realtime/webrtc-connection.ts b/packages/sdk/src/realtime/webrtc-connection.ts index a412615..000936c 100644 --- a/packages/sdk/src/realtime/webrtc-connection.ts +++ b/packages/sdk/src/realtime/webrtc-connection.ts @@ -1,4 +1,5 @@ import mitt from "mitt"; +import type { RealTimeModels } from "../shared/model"; import { buildUserAgent } from "../utils/user-agent"; import type { ConnectionState, @@ -21,8 +22,8 @@ interface ConnectionCallbacks { customizeOffer?: (offer: RTCSessionDescriptionInit) => Promise; vp8MinBitrate?: number; vp8StartBitrate?: number; - isAvatarLive?: boolean; - avatarImageBase64?: string; + modelName?: RealTimeModels; + initialImage?: string; initialPrompt?: { text: string; enhance?: boolean }; } @@ -96,12 +97,17 @@ export class WebRTCConnection { connectAbort, ]); - // Phase 2: Pre-handshake setup (avatar image + initial prompt) + // Phase 2: Pre-handshake setup (initial image and/or prompt) // connectionReject is already active, so ws.onclose or server errors abort these too - if (this.callbacks.avatarImageBase64) { - await Promise.race([this.sendAvatarImage(this.callbacks.avatarImageBase64), connectAbort]); - } - if (this.callbacks.initialPrompt) { + if (this.callbacks.initialImage) { + await Promise.race([ + this.setImageBase64(this.callbacks.initialImage, { + prompt: this.callbacks.initialPrompt?.text, + enhance: this.callbacks.initialPrompt?.enhance, + }), + connectAbort, + ]); + } else if (this.callbacks.initialPrompt) { await Promise.race([this.sendInitialPrompt(this.callbacks.initialPrompt), connectAbort]); } @@ -228,16 +234,6 @@ export class WebRTCConnection { return false; } - private async sendAvatarImage(imageBase64: string): Promise { - return this.setImageBase64(imageBase64); - } - - /** - * Send an image to the server (e.g., as a reference for inference). - * Can be called after connection is established. - * Pass null to clear the reference image or use a placeholder. - * Optionally include a prompt to send with the image. - */ async setImageBase64( imageBase64: string | null, options?: { prompt?: string; enhance?: boolean; timeout?: number }, @@ -260,7 +256,12 @@ export class WebRTCConnection { this.websocketMessagesEmitter.on("setImageAck", listener); - const message: { type: "set_image"; image_data: string | null; prompt?: string; enhance_prompt?: boolean } = { + const message: { + type: "set_image"; + image_data: string | null; + prompt?: string; + enhance_prompt?: boolean; + } = { type: "set_image", image_data: imageBase64, }; @@ -304,7 +305,13 @@ export class WebRTCConnection { this.websocketMessagesEmitter.on("promptAck", listener); - if (!this.send({ type: "prompt", prompt: prompt.text, enhance_prompt: prompt.enhance ?? true })) { + if ( + !this.send({ + type: "prompt", + prompt: prompt.text, + enhance_prompt: prompt.enhance ?? true, + }) + ) { clearTimeout(timeoutId); this.websocketMessagesEmitter.off("promptAck", listener); reject(new Error("WebSocket is not open")); @@ -341,7 +348,7 @@ export class WebRTCConnection { if (this.localStream) { // For live_avatar: add receive-only video transceiver (sends audio only, receives audio+video) - if (this.callbacks.isAvatarLive) { + if (this.callbacks.modelName === "live_avatar") { this.pc.addTransceiver("video", { direction: "recvonly" }); } diff --git a/packages/sdk/src/realtime/webrtc-manager.ts b/packages/sdk/src/realtime/webrtc-manager.ts index f2ae2c7..b29b276 100644 --- a/packages/sdk/src/realtime/webrtc-manager.ts +++ b/packages/sdk/src/realtime/webrtc-manager.ts @@ -1,4 +1,5 @@ import pRetry, { AbortError } from "p-retry"; +import type { RealTimeModels } from "../shared/model"; import type { ConnectionState, OutgoingMessage } from "./types"; import { WebRTCConnection } from "./webrtc-connection"; @@ -11,8 +12,8 @@ export interface WebRTCConfig { customizeOffer?: (offer: RTCSessionDescriptionInit) => Promise; vp8MinBitrate?: number; vp8StartBitrate?: number; - isAvatarLive?: boolean; - avatarImageBase64?: string; + modelName?: RealTimeModels; + initialImage?: string; initialPrompt?: { text: string; enhance?: boolean }; } @@ -54,8 +55,8 @@ export class WebRTCManager { customizeOffer: config.customizeOffer, vp8MinBitrate: config.vp8MinBitrate, vp8StartBitrate: config.vp8StartBitrate, - isAvatarLive: config.isAvatarLive, - avatarImageBase64: config.avatarImageBase64, + modelName: config.modelName, + initialImage: config.initialImage, initialPrompt: config.initialPrompt, }); } diff --git a/packages/sdk/src/shared/types.ts b/packages/sdk/src/shared/types.ts index c725c93..37f7d0a 100644 --- a/packages/sdk/src/shared/types.ts +++ b/packages/sdk/src/shared/types.ts @@ -7,5 +7,6 @@ export const modelStateSchema = z.object({ enhance: z.boolean().optional().default(true), }) .optional(), + image: z.union([z.instanceof(Blob), z.instanceof(File), z.string()]).optional(), }); export type ModelState = z.infer; diff --git a/packages/sdk/tests/unit.test.ts b/packages/sdk/tests/unit.test.ts index bcf4623..eecc208 100644 --- a/packages/sdk/tests/unit.test.ts +++ b/packages/sdk/tests/unit.test.ts @@ -1084,7 +1084,7 @@ describe("WebRTCConnection", () => { }); describe("RealTimeClient cleanup", () => { - it("cleans up AudioStreamManager when avatar fetch fails before WebRTC connect", async () => { + it("cleans up AudioStreamManager when initial image fetch fails before WebRTC connect", async () => { class FakeAudioContext { createMediaStreamDestination() { return { stream: {} }; @@ -1121,7 +1121,7 @@ describe("RealTimeClient cleanup", () => { realtime.connect(null, { model: models.realtime("live_avatar"), onRemoteStream: vi.fn(), - avatar: { avatarImage: "https://example.com/avatar.png" }, + initialState: { image: "https://example.com/avatar.png" }, }), ).rejects.toThrow("Failed to fetch image: 404 Not Found"); @@ -1705,11 +1705,22 @@ describe("WebSockets Connection", () => { const connectSpy = vi.spyOn(WebRTCManager.prototype, "connect").mockImplementation(async function () { const manager = this as unknown as { - config: { onConnectionStateChange?: (state: import("../src/realtime/types").ConnectionState) => void }; + config: { + onConnectionStateChange?: (state: import("../src/realtime/types").ConnectionState) => void; + initialPrompt?: { text: string; enhance?: boolean }; + }; managerState: import("../src/realtime/types").ConnectionState; }; manager.managerState = "connected"; manager.config.onConnectionStateChange?.("connected"); + + // Simulate initial prompt sent via WebSocket during connection setup + if (manager.config.initialPrompt) { + await new Promise((resolve) => setTimeout(resolve, 0)); + manager.managerState = "generating"; + manager.config.onConnectionStateChange?.("generating"); + } + return true; }); const stateSpy = vi.spyOn(WebRTCManager.prototype, "getConnectionState").mockImplementation(function () {