|
2 | 2 | ApiClient, |
3 | 3 | ApiRequestOptions, |
4 | 4 | makeIdempotencyKey, |
| 5 | + SSEStreamPart, |
| 6 | + SSEStreamSubscription, |
5 | 7 | stringifyIO, |
6 | 8 | TriggerOptions, |
7 | 9 | } from "@trigger.dev/core/v3"; |
@@ -205,47 +207,61 @@ export class TriggerChatTransport< |
205 | 207 | runState: TriggerChatRunState, |
206 | 208 | abortSignal: AbortSignal | undefined, |
207 | 209 | lastEventId?: string |
208 | | - ): Promise<ReadableStream<UIMessageChunk>> { |
| 210 | + ): Promise<ReadableStream<SSEStreamPart<InferUIMessageChunk<UI_MESSAGE>>>> { |
209 | 211 | const streamClient = new ApiClient( |
210 | 212 | this.baseURL, |
211 | 213 | runState.publicAccessToken, |
212 | 214 | this.previewBranch, |
213 | 215 | this.requestOptions |
214 | 216 | ); |
215 | 217 |
|
216 | | - const stream = await streamClient.fetchStream<InferUIMessageChunk<UI_MESSAGE>>( |
217 | | - runState.runId, |
218 | | - runState.streamKey, |
| 218 | + const subscription = new SSEStreamSubscription( |
| 219 | + this.createStreamUrl(runState.runId, runState.streamKey), |
219 | 220 | { |
| 221 | + headers: streamClient.getHeaders(), |
220 | 222 | signal: abortSignal, |
221 | 223 | timeoutInSeconds: this.timeoutInSeconds, |
222 | 224 | lastEventId, |
223 | 225 | } |
224 | 226 | ); |
225 | 227 |
|
226 | | - return stream as unknown as ReadableStream<UIMessageChunk>; |
| 228 | + return (await subscription.subscribe()) as ReadableStream< |
| 229 | + SSEStreamPart<InferUIMessageChunk<UI_MESSAGE>> |
| 230 | + >; |
227 | 231 | } |
228 | 232 |
|
229 | | - private createTrackedStream(chatId: string, stream: ReadableStream<UIMessageChunk>) { |
| 233 | + private createTrackedStream( |
| 234 | + chatId: string, |
| 235 | + stream: ReadableStream<SSEStreamPart<InferUIMessageChunk<UI_MESSAGE>>> |
| 236 | + ) { |
230 | 237 | const teeStreams = stream.tee(); |
231 | 238 | const trackingStream = teeStreams[0]; |
232 | 239 | const consumerStream = teeStreams[1]; |
233 | 240 |
|
234 | 241 | this.consumeTrackingStream(chatId, trackingStream); |
235 | 242 |
|
236 | | - return consumerStream; |
| 243 | + return consumerStream.pipeThrough( |
| 244 | + new TransformStream<SSEStreamPart<InferUIMessageChunk<UI_MESSAGE>>, UIMessageChunk>({ |
| 245 | + transform(part, controller) { |
| 246 | + controller.enqueue(part.chunk as UIMessageChunk); |
| 247 | + }, |
| 248 | + }) |
| 249 | + ); |
237 | 250 | } |
238 | 251 |
|
239 | | - private async consumeTrackingStream(chatId: string, stream: ReadableStream<UIMessageChunk>) { |
| 252 | + private async consumeTrackingStream( |
| 253 | + chatId: string, |
| 254 | + stream: ReadableStream<SSEStreamPart<InferUIMessageChunk<UI_MESSAGE>>> |
| 255 | + ) { |
240 | 256 | try { |
241 | | - for await (const _chunk of stream) { |
| 257 | + for await (const part of stream) { |
242 | 258 | const runState = await this.runStore.get(chatId); |
243 | 259 |
|
244 | 260 | if (!runState) { |
245 | 261 | return; |
246 | 262 | } |
247 | 263 |
|
248 | | - runState.lastEventId = incrementLastEventId(runState.lastEventId); |
| 264 | + runState.lastEventId = part.id; |
249 | 265 | await this.runStore.set(runState); |
250 | 266 | } |
251 | 267 |
|
@@ -274,6 +290,14 @@ export class TriggerChatTransport< |
274 | 290 |
|
275 | 291 | return handle as TriggerTaskResponse; |
276 | 292 | } |
| 293 | + |
| 294 | + private createStreamUrl(runId: string, streamKey: string): string { |
| 295 | + const normalizedBaseUrl = this.baseURL.replace(/\/$/, ""); |
| 296 | + const encodedRunId = encodeURIComponent(runId); |
| 297 | + const encodedStreamKey = encodeURIComponent(streamKey); |
| 298 | + |
| 299 | + return `${normalizedBaseUrl}/realtime/v1/streams/${encodedRunId}/${encodedStreamKey}`; |
| 300 | + } |
277 | 301 | } |
278 | 302 |
|
279 | 303 | export function createTriggerChatTransport< |
@@ -426,17 +450,4 @@ async function createTriggerTaskOptions( |
426 | 450 | }; |
427 | 451 | } |
428 | 452 |
|
429 | | -function incrementLastEventId(lastEventId: string | undefined): string { |
430 | | - if (!lastEventId) { |
431 | | - return "0"; |
432 | | - } |
433 | | - |
434 | | - const numberValue = Number.parseInt(lastEventId, 10); |
435 | | - if (Number.isNaN(numberValue)) { |
436 | | - return "0"; |
437 | | - } |
438 | | - |
439 | | - return String(numberValue + 1); |
440 | | -} |
441 | | - |
442 | 453 | export type { TriggerChatTaskContext }; |
0 commit comments