From df495758411bbb99e06a575d0b627cc7c8d76bc3 Mon Sep 17 00:00:00 2001 From: Alem Tuzlak Date: Sat, 30 May 2026 16:33:02 +0200 Subject: [PATCH 01/10] feat(ai-client): add client-side chat persistence adapter Add an optional `persistence` adapter to ChatClient with the same getItem/setItem/removeItem shape used elsewhere in the SDK. When provided, the client hydrates from getItem(id) on construction (sync or async), saves to setItem(id, messages) on every message change via an ordered write queue, and calls removeItem(id) on clear(). Late chunks from a stream cleared mid-flight are suppressed so they can't repopulate cleared state. When omitted, behavior is unchanged. The persistence option is threaded through all framework wrappers (react, preact, solid, vue, svelte), which now read hydrated messages from the client after construction. Implements TanStack/ai#374. --- .changeset/chat-client-persistence.md | 10 + docs/chat/persistence.md | 142 ++ docs/config.json | 4 + .../typescript/ai-client/src/chat-client.ts | 347 +++- .../ai-client/src/connection-adapters.ts | 21 +- packages/typescript/ai-client/src/index.ts | 1 + packages/typescript/ai-client/src/types.ts | 22 + .../ai-client/tests/chat-client.test.ts | 1608 ++++++++++++++++- packages/typescript/ai-preact/src/use-chat.ts | 82 +- .../ai-preact/tests/use-chat.test.ts | 377 +++- packages/typescript/ai-react/src/use-chat.ts | 73 +- .../ai-react/tests/use-chat.test.ts | 322 +++- packages/typescript/ai-solid/src/use-chat.ts | 17 +- .../ai-solid/tests/use-chat.test.ts | 59 + .../ai-svelte/src/create-chat.svelte.ts | 5 + .../ai-svelte/tests/use-chat.test.ts | 53 + packages/typescript/ai-vue/src/use-chat.ts | 5 + .../typescript/ai-vue/tests/use-chat.test.ts | 55 + testing/e2e/src/routes/$provider/$feature.tsx | 30 +- testing/e2e/tests/chat.spec.ts | 31 + 20 files changed, 3116 insertions(+), 148 deletions(-) create mode 100644 .changeset/chat-client-persistence.md create mode 100644 docs/chat/persistence.md diff --git a/.changeset/chat-client-persistence.md b/.changeset/chat-client-persistence.md new file mode 100644 index 000000000..0a6124036 --- /dev/null +++ b/.changeset/chat-client-persistence.md @@ -0,0 +1,10 @@ +--- +"@tanstack/ai-client": minor +"@tanstack/ai-react": minor +"@tanstack/ai-preact": minor +"@tanstack/ai-solid": minor +"@tanstack/ai-svelte": minor +"@tanstack/ai-vue": minor +--- + +Add persistence support for chat messages. diff --git a/docs/chat/persistence.md b/docs/chat/persistence.md new file mode 100644 index 000000000..db877be0c --- /dev/null +++ b/docs/chat/persistence.md @@ -0,0 +1,142 @@ +--- +title: Persistence +id: chat-persistence +order: 5 +description: "Persist chat conversations on the client with TanStack AI — hydrate on load, save on change, and clear on reset using a simple getItem/setItem/removeItem adapter." +keywords: + - tanstack ai + - persistence + - chat history + - localStorage + - indexeddb + - offline + - hydration +--- + +By default a `ChatClient` (and every framework `useChat`/`createChat` wrapper) keeps messages in memory only — reload the page or navigate away and the conversation is gone. The optional **persistence adapter** wires the client to a storage backend so conversations survive reloads, with no manual `initialMessages` + `onFinish` boilerplate. + +This is especially useful for SPAs, Electron apps, and offline-first setups where the client is the source of truth and there's no server managing conversation state. + +## The adapter interface + +A persistence adapter is any object with three methods — the same `getItem`/`setItem`/`removeItem` shape used elsewhere in TanStack AI. Each method may be synchronous or return a `Promise`: + +```typescript +import type { ChatClientPersistence } from "@tanstack/ai-client"; + +interface ChatClientPersistence { + getItem: ( + id: string, + ) => + | Array + | null + | undefined + | Promise | null | undefined>; + setItem: (id: string, messages: Array) => void | Promise; + removeItem: (id: string) => void | Promise; +} +``` + +The `id` passed to each method is the client's `id` option. Provide a stable `id` per conversation so the right history is loaded back: + +```typescript +const client = new ChatClient({ + id: "conversation-123", + connection: adapter, + persistence: myPersistenceAdapter, +}); +``` + +## What the client does for you + +When a `persistence` adapter is provided, `ChatClient`: + +- **Hydrates on construction** — calls `getItem(id)`. If it returns an array, those messages populate the client (overriding `initialMessages`). Async adapters hydrate as soon as the promise resolves, unless you've already started a new conversation in the meantime. +- **Saves on every change** — calls `setItem(id, messages)` whenever the message list changes (new user message, streamed assistant content, tool calls/results, approval responses). Writes are queued so they never overlap or land out of order. +- **Clears on `clear()`** — calls `removeItem(id)` and discards any in-flight stream so a cleared conversation doesn't get repopulated by late chunks. + +When `persistence` is omitted, nothing changes — the client behaves exactly as before. The option is fully backwards compatible. + +Persistence is **best-effort**: if an adapter method throws or rejects, the error is swallowed so storage problems never break the chat. Handle and surface errors inside your adapter if you need to react to them. + +## Framework usage + +Every framework wrapper accepts the same `persistence` option and forwards it to the underlying `ChatClient`: + +```tsx +// React / Preact +const chat = useChat({ + id: "conversation-123", + connection: fetchServerSentEvents("/api/chat"), + persistence: myPersistenceAdapter, +}); +``` + +```ts +// Solid / Vue — same option +const chat = useChat({ + id: "conversation-123", + connection: fetchServerSentEvents("/api/chat"), + persistence: myPersistenceAdapter, +}); +``` + +```ts +// Svelte +const chat = createChat({ + id: "conversation-123", + connection: fetchServerSentEvents("/api/chat"), + persistence: myPersistenceAdapter, +}); +``` + +## Example: `localStorage` + +A synchronous adapter backed by `localStorage`. Note that `UIMessage.createdAt` is a `Date`, which `JSON.stringify` turns into a string — revive it on read if you depend on it: + +```typescript +import type { ChatClientPersistence, UIMessage } from "@tanstack/ai-client"; + +const localStoragePersistence: ChatClientPersistence = { + getItem: (id) => { + const raw = window.localStorage.getItem(id); + if (!raw) return null; + return (JSON.parse(raw) as Array).map((message) => ({ + ...message, + createdAt: + typeof message.createdAt === "string" + ? new Date(message.createdAt) + : message.createdAt, + })); + }, + setItem: (id, messages) => { + window.localStorage.setItem(id, JSON.stringify(messages)); + }, + removeItem: (id) => { + window.localStorage.removeItem(id); + }, +}; +``` + +## Example: IndexedDB (async) + +For larger histories or structured queries, back the adapter with an async store such as IndexedDB. The client awaits async methods automatically: + +```typescript +import type { ChatClientPersistence } from "@tanstack/ai-client"; + +const indexedDbPersistence: ChatClientPersistence = { + getItem: async (id) => { + const record = await db.conversations.get(id); + return record?.messages; + }, + setItem: async (id, messages) => { + await db.conversations.put({ id, messages, updatedAt: Date.now() }); + }, + removeItem: async (id) => { + await db.conversations.delete(id); + }, +}; +``` + +Any backend works — IndexedDB, SQLite (Electron/Tauri), a remote database, or an in-memory `Map` for tests — as long as it implements the three methods. diff --git a/docs/config.json b/docs/config.json index 370c839a5..0d29c160f 100644 --- a/docs/config.json +++ b/docs/config.json @@ -99,6 +99,10 @@ { "label": "Thinking & Reasoning", "to": "chat/thinking-content" + }, + { + "label": "Persistence", + "to": "chat/persistence" } ] }, diff --git a/packages/typescript/ai-client/src/chat-client.ts b/packages/typescript/ai-client/src/chat-client.ts index 505f36e3a..b604aa615 100644 --- a/packages/typescript/ai-client/src/chat-client.ts +++ b/packages/typescript/ai-client/src/chat-client.ts @@ -5,7 +5,10 @@ import { normalizeToUIMessage, } from '@tanstack/ai' import { DefaultChatClientEventEmitter } from './events' -import { normalizeConnectionAdapter } from './connection-adapters' +import { + getInternalRunContextIds, + normalizeConnectionAdapter, +} from './connection-adapters' import type { AnyClientTool, ContentPart, @@ -19,6 +22,7 @@ import type { import type { ChatClientEventEmitter } from './events' import type { ChatClientOptions, + ChatClientPersistence, ChatClientState, ConnectionStatus, MessagePart, @@ -32,6 +36,17 @@ export class ChatClient { private connection: SubscribeConnectionAdapter private readonly uniqueId: string private readonly threadId: string + private readonly persistence?: ChatClientPersistence + private skipNextPersist = false + private persistenceGeneration = 0 + private persistenceQueue: Promise = Promise.resolve() + private persistenceQueuePending = false + private currentRunId: string | null = null + private readonly clearedMessageIds = new Set() + private readonly clearedRunIds = new Set() + private readonly ignoredActiveRunIds = new Set() + private readonly clearedToolCallIds = new Set() + private currentRunlessRunId: string | null = null // Track the legacy `body` option and the canonical `forwardedProps` // option as separate slots so that `updateOptions({ forwardedProps })` // doesn't wipe a previously-set `body` (and vice versa). They are @@ -58,6 +73,7 @@ export class ChatClient { private processingResolve: (() => void) | null = null private errorReportedGeneration: number | null = null private streamGeneration = 0 + private messagesGeneration = 0 // Tracks whether a queued checkForContinuation was skipped because // continuationPending was true (chained approval scenario) private continuationSkipped = false @@ -89,6 +105,7 @@ export class ChatClient { constructor(options: ChatClientOptions) { this.uniqueId = options.id || this.generateUniqueId('chat') this.threadId = options.threadId || this.generateUniqueId('thread') + this.persistence = options.persistence // Both `body` (deprecated) and `forwardedProps` populate the AG-UI // `RunAgentInput.forwardedProps` wire field. They are stored // separately so `updateOptions` can replace one without touching the @@ -129,15 +146,20 @@ export class ChatClient { // Create StreamProcessor with event handlers. // Use conditional spreads so we don't pass `undefined` into // `StreamProcessorOptions` fields under `exactOptionalPropertyTypes`. + const persistedMessages = this.getPersistedMessages() + const initialMessages = Array.isArray(persistedMessages) + ? persistedMessages + : options.initialMessages + this.processor = new StreamProcessor({ ...(options.streamProcessor?.chunkStrategy ? { chunkStrategy: options.streamProcessor.chunkStrategy } : {}), - ...(options.initialMessages - ? { initialMessages: options.initialMessages } - : {}), + ...(initialMessages ? { initialMessages } : {}), events: { onMessagesChange: (messages: Array) => { + this.messagesGeneration++ + this.persistMessages(messages) this.callbacksRef.current.onMessagesChange(messages) }, onStreamStart: () => { @@ -276,12 +298,266 @@ export class ChatClient { }) this.events.clientCreated(this.processor.getMessages().length) + this.hydratePersistedMessagesAsync(persistedMessages) + } + + private getPersistedMessages(): + | Array + | null + | undefined + | Promise | null | undefined> { + if (!this.persistence) { + return undefined + } + try { + return this.persistence.getItem(this.uniqueId) + } catch { + return undefined + } + } + + private hydratePersistedMessagesAsync( + persistedMessages: + | Array + | null + | undefined + | Promise | null | undefined>, + ): void { + if (!(persistedMessages instanceof Promise)) { + return + } + + const hydrationGeneration = this.messagesGeneration + persistedMessages + .then((messages) => { + if ( + Array.isArray(messages) && + this.messagesGeneration === hydrationGeneration + ) { + this.processor.setMessages(messages) + } + }) + .catch(() => { + // Persistence adapters are best-effort and must not break chat setup. + }) + } + + private persistMessages(messages: Array): void { + if (this.skipNextPersist) { + this.skipNextPersist = false + return + } + if (!this.persistence) { + return + } + const persistence = this.persistence + const persistenceGeneration = this.persistenceGeneration + const messagesSnapshot = [...messages] + this.runPersistenceOperation(() => { + if (persistenceGeneration !== this.persistenceGeneration) { + return + } + return persistence.setItem(this.uniqueId, messagesSnapshot) + }) + } + + private runPersistenceOperation(operation: () => void | Promise): void { + if (this.persistenceQueuePending) { + const queued = this.persistenceQueue.then(operation).catch(() => { + // Persistence adapters are best-effort and must not break chat updates. + }) + this.persistenceQueue = queued + void queued.finally(() => { + if (this.persistenceQueue === queued) { + this.persistenceQueuePending = false + } + }) + return + } + + try { + const result = operation() + if (result instanceof Promise) { + this.persistenceQueuePending = true + const queued = result.catch(() => { + // Persistence adapters are best-effort and must not break chat updates. + }) + this.persistenceQueue = queued + void queued.finally(() => { + if (this.persistenceQueue === queued) { + this.persistenceQueuePending = false + } + }) + } + } catch { + // Persistence adapters are best-effort and must not break chat updates. + } + } + + private removePersistedMessages(): void { + if (!this.persistence) { + return + } + const persistence = this.persistence + const persistenceGeneration = ++this.persistenceGeneration + this.runPersistenceOperation(() => { + if (persistenceGeneration !== this.persistenceGeneration) { + return + } + return persistence.removeItem(this.uniqueId) + }) + } + + private snapshotClearedStreamState(): void { + if (!this.persistence) return + for (const message of this.processor.getMessages()) { + this.clearedMessageIds.add(message.id) + } + for (const runId of this.activeRunIds) { + this.clearedRunIds.add(runId) + this.ignoredActiveRunIds.add(runId) + } + if (this.currentRunId) { + this.clearedRunIds.add(this.currentRunId) + this.ignoredActiveRunIds.add(this.currentRunId) + } + } + + private shouldIgnoreChunk(chunk: StreamChunk): boolean { + if (!this.persistence) return false + + const runId = this.getChunkRunId(chunk) + if (runId && this.clearedRunIds.has(runId)) { + if (chunk.type === 'RUN_STARTED') { + this.ignoredActiveRunIds.add(runId) + this.currentRunlessRunId = runId + } + this.markIgnoredChunkIds(chunk) + return true + } + + if (runId && this.ignoredActiveRunIds.has(runId)) { + this.markIgnoredChunkIds(chunk) + return true + } + + if (this.isRunlessChunkFromIgnoredRun(chunk)) { + this.markIgnoredChunkIds(chunk) + return true + } + + const toolCallId = (chunk as { toolCallId?: string }).toolCallId + if (toolCallId && this.clearedToolCallIds.has(toolCallId)) { + return true + } + + const parentMessageId = (chunk as { parentMessageId?: string }) + .parentMessageId + if (parentMessageId && this.clearedMessageIds.has(parentMessageId)) { + if (toolCallId) { + this.clearedToolCallIds.add(toolCallId) + } + return true + } + + const messageId = (chunk as { messageId?: string }).messageId + if (!messageId) { + return false + } + if (this.clearedMessageIds.has(messageId)) { + return true + } + + return false + } + + private markIgnoredChunkIds(chunk: StreamChunk): void { + const messageId = (chunk as { messageId?: string }).messageId + if (messageId) { + this.clearedMessageIds.add(messageId) + } + const toolCallId = (chunk as { toolCallId?: string }).toolCallId + if (toolCallId) { + this.clearedToolCallIds.add(toolCallId) + } + } + + private isRunlessChunkFromIgnoredRun(chunk: StreamChunk): boolean { + const runId = this.getChunkRunId(chunk) + if (runId || !this.currentRunlessRunId) return false + if ( + !this.ignoredActiveRunIds.has(this.currentRunlessRunId) && + !this.clearedRunIds.has(this.currentRunlessRunId) + ) { + return false + } + return ( + chunk.type === 'TEXT_MESSAGE_START' || + chunk.type === 'TEXT_MESSAGE_CONTENT' || + chunk.type === 'TOOL_CALL_START' || + chunk.type === 'TOOL_CALL_ARGS' || + chunk.type === 'TOOL_CALL_END' || + chunk.type === 'TOOL_CALL_RESULT' || + chunk.type === 'MESSAGES_SNAPSHOT' || + chunk.type === 'RUN_ERROR' + ) + } + + private drainIgnoredRunlessChunk(chunk: StreamChunk): void { + if (!this.currentRunlessRunId || chunk.type !== 'RUN_ERROR') return + const runId = this.currentRunlessRunId + this.activeRunIds.delete(runId) + this.ignoredActiveRunIds.delete(runId) + this.clearedRunIds.delete(runId) + this.currentRunlessRunId = null + this.setSessionGenerating(this.activeRunIds.size > 0) + this.resolveProcessing() + } + + private updateRunLifecycle( + chunk: StreamChunk, + options?: { resolveProcessing?: boolean }, + ): void { + if (chunk.type === 'RUN_STARTED') { + const chunkRunId = this.getChunkRunId(chunk) ?? chunk.runId + this.activeRunIds.add(chunkRunId) + this.currentRunlessRunId = chunkRunId + this.setSessionGenerating(true) + return + } + + if (chunk.type !== 'RUN_FINISHED' && chunk.type !== 'RUN_ERROR') { + return + } + + const runId = this.getChunkRunId(chunk) + if (runId) { + this.activeRunIds.delete(runId) + this.ignoredActiveRunIds.delete(runId) + this.clearedRunIds.delete(runId) + if (this.currentRunlessRunId === runId) { + this.currentRunlessRunId = + this.ignoredActiveRunIds.values().next().value ?? null + } + } else if (chunk.type === 'RUN_ERROR') { + this.activeRunIds.clear() + this.ignoredActiveRunIds.clear() + this.currentRunlessRunId = null + } + this.setSessionGenerating(this.activeRunIds.size > 0) + if (options?.resolveProcessing !== false) { + this.resolveProcessing() + } } private generateUniqueId(prefix: string): string { return `${prefix}-${Date.now()}-${Math.random().toString(36).substring(7)}` } + private getChunkRunId(chunk: StreamChunk): string | undefined { + return (chunk as { runId?: string }).runId ?? getInternalRunContextIds(chunk)?.runId + } + private setIsLoading(isLoading: boolean): void { this.isLoading = isLoading this.callbacksRef.current.onLoadingChange(isLoading) @@ -311,6 +587,7 @@ export class ChatClient { private resetSessionGenerating(): void { this.activeRunIds.clear() + this.ignoredActiveRunIds.clear() this.setSessionGenerating(false) } @@ -408,32 +685,26 @@ export class ChatClient { if (this.connectionStatus === 'connecting') { this.setConnectionStatus('connected') } - this.callbacksRef.current.onChunk(chunk) - this.processor.processChunk(chunk) - if (chunk.type === 'RUN_STARTED') { - this.activeRunIds.add(chunk.runId) - this.setSessionGenerating(true) - } - // RUN_FINISHED / RUN_ERROR signal run completion — resolve processing - // (redundant if onStreamEnd already resolved it, harmless) - if (chunk.type === 'RUN_FINISHED' || chunk.type === 'RUN_ERROR') { - // RUN_FINISHED has runId in its schema; RUN_ERROR carries it via the - // AG-UI passthrough so adapters can correlate per-run errors. Extract - // both so a RUN_ERROR with a runId only clears that run, not every - // active run in the session. - const runId = - chunk.type === 'RUN_FINISHED' - ? chunk.runId - : (chunk as { runId?: string }).runId - if (runId) { - this.activeRunIds.delete(runId) - } else if (chunk.type === 'RUN_ERROR') { - // RUN_ERROR without runId is a session-level error; clear all runs - this.activeRunIds.clear() + const shouldIgnore = this.shouldIgnoreChunk(chunk) + if (shouldIgnore) { + if (chunk.type === 'RUN_FINISHED' || chunk.type === 'RUN_ERROR') { + if (this.getChunkRunId(chunk)) { + this.updateRunLifecycle(chunk, { resolveProcessing: false }) + } else { + this.drainIgnoredRunlessChunk(chunk) + } } - this.setSessionGenerating(this.activeRunIds.size > 0) - this.resolveProcessing() + continue } + this.callbacksRef.current.onChunk(chunk) + this.processor.processChunk(chunk) + // Run lifecycle (active-run tracking, session-generating state, and + // processing resolution for RUN_FINISHED / RUN_ERROR) is handled in a + // single place so the ignored-chunk path above and this path can't + // diverge. RUN_ERROR carries its runId via the AG-UI passthrough so a + // per-run error only clears that run, while a runId-less RUN_ERROR is + // treated as a session-level error that clears every active run. + this.updateRunLifecycle(chunk) // Yield control back to event loop for UI updates await new Promise((resolve) => setTimeout(resolve, 0)) } @@ -589,6 +860,8 @@ export class ChatClient { // Track generation so a superseded stream's cleanup doesn't clobber the new one const generation = ++this.streamGeneration + const runId = `run-${Date.now()}-${Math.random().toString(36).slice(2, 8)}` + this.currentRunId = runId this.setIsLoading(true) this.setStatus('submitted') @@ -662,7 +935,7 @@ export class ChatClient { // serialize to an unusable shape. const runContext = { threadId: this.threadId, - runId: `run-${Date.now()}-${Math.random().toString(36).slice(2, 8)}`, + runId, clientTools: Array.from(this.clientToolsRef.current.values()).map( (t) => ({ name: t.name, @@ -717,6 +990,7 @@ export class ChatClient { if (generation === this.streamGeneration) { this.currentStreamId = null this.currentMessageId = null + this.currentRunId = null this.abortController = null this.setIsLoading(false) this.pendingMessageBody = undefined // Ensure it's cleared even on error @@ -815,7 +1089,11 @@ export class ChatClient { * Stop the current stream */ stop(): void { + const hadLocalStream = this.abortController !== null this.cancelInFlightStream({ setReadyStatus: true }) + if (hadLocalStream) { + this.resetSessionGenerating() + } this.events.stopped() } @@ -823,7 +1101,18 @@ export class ChatClient { * Clear all messages */ clear(): void { + if (this.persistence) { + this.snapshotClearedStreamState() + } + if (this.persistence && this.isLoading) { + this.cancelInFlightStream({ setReadyStatus: true }) + this.resetSessionGenerating() + } else if (this.persistence && this.activeRunIds.size > 0) { + this.resetSessionGenerating() + } + this.skipNextPersist = true this.processor.clearMessages() + this.removePersistedMessages() this.setError(undefined) this.events.messagesCleared() } diff --git a/packages/typescript/ai-client/src/connection-adapters.ts b/packages/typescript/ai-client/src/connection-adapters.ts index 00da68103..6be665dcf 100644 --- a/packages/typescript/ai-client/src/connection-adapters.ts +++ b/packages/typescript/ai-client/src/connection-adapters.ts @@ -5,6 +5,17 @@ function generateRunId(prefix: string): string { return `${prefix}-${Date.now()}-${Math.random().toString(36).slice(2, 8)}` } +const chunkRunContextIds = new WeakMap< + StreamChunk, + Pick +>() + +export function getInternalRunContextIds( + chunk: StreamChunk, +): Pick | undefined { + return chunkRunContextIds.get(chunk) +} + /** * Merge custom headers into request headers */ @@ -156,7 +167,13 @@ export function normalizeConnectionAdapter( let activeBuffer: Array = [] let activeWaiters: Array<(chunk: StreamChunk | null) => void> = [] - function push(chunk: StreamChunk): void { + function push(chunk: StreamChunk, runContext?: RunAgentInputContext): void { + if (runContext) { + chunkRunContextIds.set(chunk, { + threadId: runContext.threadId, + runId: runContext.runId, + }) + } const waiter = activeWaiters.shift() if (waiter) { waiter(chunk) @@ -207,7 +224,7 @@ export function normalizeConnectionAdapter( if (chunk.type === 'RUN_FINISHED' || chunk.type === 'RUN_ERROR') { hasTerminalEvent = true } - push(chunk) + push(chunk, runContext) } // If the connect stream ended cleanly without a terminal event, diff --git a/packages/typescript/ai-client/src/index.ts b/packages/typescript/ai-client/src/index.ts index 98f75a999..0b5a25e7a 100644 --- a/packages/typescript/ai-client/src/index.ts +++ b/packages/typescript/ai-client/src/index.ts @@ -12,6 +12,7 @@ export type { ThinkingPart, StructuredOutputPart, // Client configuration types + ChatClientPersistence, ChatClientOptions, ChatRequestBody, InferChatMessages, diff --git a/packages/typescript/ai-client/src/types.ts b/packages/typescript/ai-client/src/types.ts index 815d6a838..8773acedd 100644 --- a/packages/typescript/ai-client/src/types.ts +++ b/packages/typescript/ai-client/src/types.ts @@ -201,6 +201,23 @@ export interface UIMessage< createdAt?: Date } +export interface ChatClientPersistence< + TTools extends ReadonlyArray = any, +> { + getItem: ( + id: string, + ) => + | Array> + | null + | undefined + | Promise> | null | undefined> + setItem: ( + id: string, + messages: Array>, + ) => void | Promise + removeItem: (id: string) => void | Promise +} + export interface ChatClientOptions< TTools extends ReadonlyArray = any, > { @@ -216,6 +233,11 @@ export interface ChatClientOptions< */ initialMessages?: Array> + /** + * Optional persistence adapter for chat messages. + */ + persistence?: ChatClientPersistence + /** * Unique identifier for this chat instance * Used for managing multiple chats diff --git a/packages/typescript/ai-client/tests/chat-client.test.ts b/packages/typescript/ai-client/tests/chat-client.test.ts index f49c05110..45a9bacc2 100644 --- a/packages/typescript/ai-client/tests/chat-client.test.ts +++ b/packages/typescript/ai-client/tests/chat-client.test.ts @@ -13,10 +13,44 @@ import type { ConnectConnectionAdapter, ConnectionAdapter, } from '../src/connection-adapters' -import type { StreamChunk } from '@tanstack/ai' -import type { UIMessage } from '../src/types' +import type { ModelMessage, StreamChunk } from '@tanstack/ai' +import type { ChatClientPersistence, UIMessage } from '../src/types' describe('ChatClient', () => { + const persistedMessage: UIMessage = { + id: 'persisted-1', + role: 'user', + parts: [{ type: 'text', content: 'Persisted hello' }], + createdAt: new Date('2024-01-01T00:00:00.000Z'), + } + + const initialMessage: UIMessage = { + id: 'initial-1', + role: 'user', + parts: [{ type: 'text', content: 'Initial hello' }], + createdAt: new Date('2024-01-02T00:00:00.000Z'), + } + + function createPersistence( + storedMessages?: Array | null, + ): ChatClientPersistence { + return { + getItem: vi.fn(() => storedMessages), + setItem: vi.fn(), + removeItem: vi.fn(), + } + } + + function createDeferred() { + let resolve!: (value: T) => void + let reject!: (reason?: unknown) => void + const promise = new Promise((promiseResolve, promiseReject) => { + resolve = promiseResolve + reject = promiseReject + }) + return { promise, resolve, reject } + } + describe('constructor', () => { it('should create a client with default options', () => { const adapter = createMockConnectionAdapter() @@ -48,6 +82,151 @@ describe('ChatClient', () => { expect(client.getMessages()).toEqual(initialMessages) }) + it('should hydrate messages from persistence', () => { + const adapter = createMockConnectionAdapter() + const persistence = createPersistence([persistedMessage]) + + const client = new ChatClient({ + connection: adapter, + id: 'chat-1', + persistence, + }) + + expect(persistence.getItem).toHaveBeenCalledWith('chat-1') + expect(client.getMessages()).toEqual([persistedMessage]) + }) + + it('should prefer persisted messages over initial messages', () => { + const adapter = createMockConnectionAdapter() + const persistence = createPersistence([persistedMessage]) + + const client = new ChatClient({ + connection: adapter, + id: 'chat-1', + initialMessages: [initialMessage], + persistence, + }) + + expect(client.getMessages()).toEqual([persistedMessage]) + }) + + it('should fall back to initial messages when persistence returns null', () => { + const adapter = createMockConnectionAdapter() + const persistence = createPersistence(null) + + const client = new ChatClient({ + connection: adapter, + id: 'chat-1', + initialMessages: [initialMessage], + persistence, + }) + + expect(client.getMessages()).toEqual([initialMessage]) + }) + + it('should fall back to initial messages when persistence returns undefined', () => { + const adapter = createMockConnectionAdapter() + const persistence = createPersistence(undefined) + + const client = new ChatClient({ + connection: adapter, + id: 'chat-1', + initialMessages: [initialMessage], + persistence, + }) + + expect(client.getMessages()).toEqual([initialMessage]) + }) + + it('should let persisted empty arrays override initial messages', () => { + const adapter = createMockConnectionAdapter() + const persistence = createPersistence([]) + + const client = new ChatClient({ + connection: adapter, + id: 'chat-1', + initialMessages: [initialMessage], + persistence, + }) + + expect(client.getMessages()).toEqual([]) + }) + + it('should hydrate from async persistence and notify message listeners', async () => { + const adapter = createMockConnectionAdapter() + const onMessagesChange = vi.fn() + const persistence = { + getItem: vi.fn(() => Promise.resolve([persistedMessage])), + setItem: vi.fn(), + removeItem: vi.fn(), + } + + const client = new ChatClient({ + connection: adapter, + id: 'chat-1', + initialMessages: [initialMessage], + onMessagesChange, + persistence, + }) + + expect(client.getMessages()).toEqual([initialMessage]) + + await vi.waitFor(() => { + expect(client.getMessages()).toEqual([persistedMessage]) + }) + + expect(onMessagesChange).toHaveBeenCalledWith([persistedMessage]) + }) + + it('should ignore async persistence hydration after local message changes', async () => { + const adapter = createMockConnectionAdapter() + const deferred = createDeferred>() + const persistence = { + getItem: vi.fn(() => deferred.promise), + setItem: vi.fn(), + removeItem: vi.fn(), + } + + const client = new ChatClient({ + connection: adapter, + id: 'chat-1', + initialMessages: [initialMessage], + persistence, + }) + + client.setMessagesManually([ + { + id: 'local-1', + role: 'user', + parts: [{ type: 'text', content: 'Local change' }], + createdAt: new Date('2024-01-03T00:00:00.000Z'), + }, + ]) + + deferred.resolve([persistedMessage]) + await new Promise((resolve) => setTimeout(resolve, 0)) + + expect(client.getMessages()).toEqual([ + { + id: 'local-1', + role: 'user', + parts: [{ type: 'text', content: 'Local change' }], + createdAt: new Date('2024-01-03T00:00:00.000Z'), + }, + ]) + }) + + it('should keep current constructor behavior when persistence is omitted', () => { + const adapter = createMockConnectionAdapter() + + const client = new ChatClient({ + connection: adapter, + initialMessages: [initialMessage], + }) + + expect(client.getMessages()).toEqual([initialMessage]) + }) + it('should use provided id or generate one', async () => { const adapter = createMockConnectionAdapter({ chunks: createTextChunks('Response'), @@ -92,60 +271,1044 @@ describe('ChatClient', () => { }) }) - describe('subscribe/send connection mode', () => { - function createSubscribeAdapter(chunksToSend: Array) { - let hasPendingSend = false - let wakeSubscriber: (() => void) | null = null - let removeAbortListener: (() => void) | null = null + describe('subscribe/send connection mode', () => { + function createSubscribeAdapter(chunksToSend: Array) { + let hasPendingSend = false + let wakeSubscriber: (() => void) | null = null + let removeAbortListener: (() => void) | null = null + + const subscribe = vi.fn((signal?: AbortSignal) => { + return (async function* () { + while (!signal?.aborted) { + if (!hasPendingSend) { + await new Promise((resolve) => { + removeAbortListener?.() + removeAbortListener = null + wakeSubscriber = resolve + const onAbort = () => resolve() + signal?.addEventListener('abort', onAbort, { once: true }) + removeAbortListener = () => { + signal?.removeEventListener('abort', onAbort) + } + }) + continue + } + + hasPendingSend = false + for (const chunk of chunksToSend) { + yield chunk + } + } + removeAbortListener?.() + removeAbortListener = null + })() + }) + + const send = vi.fn(async () => { + removeAbortListener?.() + removeAbortListener = null + hasPendingSend = true + wakeSubscriber?.() + wakeSubscriber = null + }) + + return { subscribe, send } + } + + it('should use subscribe/send adapter mode', async () => { + const adapter = createSubscribeAdapter( + createTextChunks('From subscribe/send mode'), + ) + const client = new ChatClient({ connection: adapter }) + + await client.sendMessage('Hello') + + expect(adapter.subscribe).toHaveBeenCalled() + expect(adapter.send).toHaveBeenCalled() + }) + + it('should ignore native subscribe/send chunks from a cleared persisted request without runId', async () => { + let storedMessages: Array | undefined + const releaseFirstResponse = createDeferred() + const queuedChunks: Array<{ prompt: string; chunks: Array }> = + [] + let wakeSubscriber: (() => void) | null = null + const adapter: ConnectionAdapter = { + subscribe: vi.fn((_signal?: AbortSignal) => { + return (async function* () { + while (true) { + if (queuedChunks.length === 0) { + await new Promise((resolve) => { + wakeSubscriber = resolve + }) + } + const next = queuedChunks.shift() + if (!next) continue + if (next.prompt === 'A') { + const [started, ...remainingChunks] = next.chunks + if (started) { + yield started + } + await releaseFirstResponse.promise + yield* remainingChunks + continue + } + yield* next.chunks + } + })() + }), + send: vi.fn( + async ( + messages: Array | Array, + _data, + _signal, + runContext, + ) => { + const prompt = messages + .flatMap((message) => ('parts' in message ? message.parts : [])) + .find((part) => part.type === 'text')?.content + + queuedChunks.push({ + prompt: prompt ?? '', + chunks: [ + { + type: EventType.RUN_STARTED, + threadId: runContext?.threadId ?? 'thread-1', + runId: runContext?.runId ?? 'run-1', + timestamp: Date.now(), + } as StreamChunk, + ...createTextChunks( + prompt === 'A' ? 'stale A' : 'fresh B', + prompt === 'A' ? 'msg-a' : 'msg-b', + ).map((chunk) => { + if (chunk.type === 'TEXT_MESSAGE_CONTENT') { + const { runId: _runId, ...withoutRunId } = chunk + return withoutRunId as StreamChunk + } + if (chunk.type === 'RUN_FINISHED') { + return { + ...chunk, + threadId: runContext?.threadId ?? chunk.threadId, + runId: runContext?.runId ?? chunk.runId, + } as StreamChunk + } + return chunk + }), + ], + }) + wakeSubscriber?.() + wakeSubscriber = null + }, + ), + } + const persistence = { + getItem: vi.fn(() => undefined), + setItem: vi.fn((_key: string, messages: Array) => { + storedMessages = messages + }), + removeItem: vi.fn(() => { + storedMessages = undefined + }), + } + const client = new ChatClient({ + connection: adapter, + id: 'chat-1', + persistence, + }) + + const firstSend = client.sendMessage('A') + await vi.waitFor(() => { + expect(client.getIsLoading()).toBe(true) + }) + + client.clear() + const secondSend = client.sendMessage('B') + releaseFirstResponse.resolve() + await firstSend + await secondSend + + const finalText = client + .getMessages() + .flatMap((message) => message.parts) + .filter((part) => part.type === 'text') + .map((part) => part.content) + .join('') + + expect(finalText).toContain('B') + expect(finalText).toContain('fresh B') + expect(finalText).not.toContain('A') + expect(finalText).not.toContain('stale A') + expect(storedMessages).toEqual(client.getMessages()) + expect( + storedMessages + ?.flatMap((message) => message.parts) + .filter((part) => part.type === 'text') + .map((part) => part.content) + .join(''), + ).not.toContain('stale A') + }) + + it('should ignore already-started runless chunks from a cleared persisted request', async () => { + let storedMessages: Array | undefined + const releaseStaleChunks = createDeferred() + const staleChunksAttempted = createDeferred() + let wakeSubscriber: (() => void) | null = null + let queued = false + const adapter: ConnectionAdapter = { + subscribe: vi.fn((_signal?: AbortSignal): AsyncIterable => { + return (async function* () { + while (true) { + if (!queued) { + await new Promise((resolve) => { + wakeSubscriber = resolve + }) + } + queued = false + yield { + type: EventType.RUN_STARTED, + threadId: 'thread-1', + runId: 'run-cleared', + timestamp: Date.now(), + } as StreamChunk + await releaseStaleChunks.promise + yield { + type: EventType.TEXT_MESSAGE_CONTENT, + messageId: 'stale-message', + timestamp: Date.now(), + delta: 'stale content', + content: 'stale content', + } as StreamChunk + staleChunksAttempted.resolve() + yield { + type: EventType.RUN_FINISHED, + threadId: 'thread-1', + runId: 'run-cleared', + timestamp: Date.now(), + } as StreamChunk + } + })() + }), + send: vi.fn(async () => { + queued = true + wakeSubscriber?.() + wakeSubscriber = null + }), + } + const persistence = { + getItem: vi.fn(() => undefined), + setItem: vi.fn((_key: string, messages: Array) => { + storedMessages = messages + }), + removeItem: vi.fn(() => { + storedMessages = undefined + }), + } + const client = new ChatClient({ + connection: adapter, + id: 'chat-1', + persistence, + }) + + const sendPromise = client.sendMessage('A') + await vi.waitFor(() => { + expect(client.getSessionGenerating()).toBe(true) + }) + + client.clear() + releaseStaleChunks.resolve() + await staleChunksAttempted.promise + await sendPromise + + expect(client.getMessages()).toEqual([]) + expect(storedMessages).toBeUndefined() + expect(client.getSessionGenerating()).toBe(false) + expect(client.getError()).toBeUndefined() + }) + + it('should keep fresh runless chunks after a clear when a fresh run starts before stale chunks drain', async () => { + let storedMessages: Array | undefined + const releaseStaleChunks = createDeferred() + const staleChunksAttempted = createDeferred() + const queuedChunks: Array<{ prompt: string; chunks: Array }> = + [] + let staleReleased = false + let wakeSubscriber: (() => void) | null = null + const wakeQueuedSubscriber = () => { + const wake = wakeSubscriber + wakeSubscriber = null + wake?.() + } + const adapter: ConnectionAdapter = { + subscribe: vi.fn((_signal?: AbortSignal): AsyncIterable => { + return (async function* () { + while (true) { + if (queuedChunks.length === 0) { + await new Promise((resolve) => { + wakeSubscriber = resolve + }) + } + const freshIndex = queuedChunks.findIndex( + (queued) => queued.prompt === 'B', + ) + const next = + freshIndex >= 0 + ? queuedChunks.splice(freshIndex, 1)[0] + : queuedChunks.shift() + if (!next) continue + yield next.chunks[0]! + if (next.prompt === 'A' && !staleReleased) { + queuedChunks.push({ + prompt: 'A-after-start', + chunks: next.chunks.slice(1), + }) + continue + } + if (next.prompt === 'A-after-start' && !staleReleased) { + queuedChunks.push(next) + await new Promise((resolve) => { + wakeSubscriber = resolve + }) + continue + } + for (const chunk of next.chunks.slice(1)) { + yield chunk + } + if (next.prompt === 'A-after-start') { + staleChunksAttempted.resolve() + } + } + })() + }), + send: vi.fn( + async ( + messages: Array | Array, + _data, + _signal, + runContext, + ) => { + const prompt = messages + .flatMap((message) => ('parts' in message ? message.parts : [])) + .find((part) => part.type === 'text')?.content + const messageId = prompt === 'A' ? 'stale-message' : 'fresh-message' + queuedChunks.push({ + prompt: prompt ?? '', + chunks: [ + { + type: EventType.RUN_STARTED, + threadId: runContext?.threadId ?? 'thread-1', + runId: + prompt === 'A' + ? 'run-cleared' + : (runContext?.runId ?? 'run-fresh'), + timestamp: Date.now(), + } as StreamChunk, + { + type: EventType.TEXT_MESSAGE_CONTENT, + messageId, + timestamp: Date.now(), + delta: prompt === 'A' ? 'stale content' : 'fresh content', + content: prompt === 'A' ? 'stale content' : 'fresh content', + } as StreamChunk, + { + type: EventType.RUN_FINISHED, + threadId: runContext?.threadId ?? 'thread-1', + runId: + prompt === 'A' + ? 'run-cleared' + : (runContext?.runId ?? 'run-fresh'), + timestamp: Date.now(), + } as StreamChunk, + ], + }) + wakeQueuedSubscriber() + }, + ), + } + const persistence = { + getItem: vi.fn(() => undefined), + setItem: vi.fn((_key: string, messages: Array) => { + storedMessages = messages + }), + removeItem: vi.fn(() => { + storedMessages = undefined + }), + } + const client = new ChatClient({ + connection: adapter, + id: 'chat-1', + persistence, + }) + + const firstSend = client.sendMessage('A') + await vi.waitFor(() => { + expect(client.getSessionGenerating()).toBe(true) + }) + + client.clear() + const secondSend = client.sendMessage('B') + await secondSend + staleReleased = true + releaseStaleChunks.resolve() + wakeQueuedSubscriber() + await staleChunksAttempted.promise + firstSend.catch(() => { + // The stale request may have been superseded by the fresh send. + }) + + const finalText = client + .getMessages() + .flatMap((message) => message.parts) + .filter((part) => part.type === 'text') + .map((part) => part.content) + .join('') + + expect(finalText).toContain('B') + expect(finalText).toContain('fresh content') + expect(finalText).not.toContain('A') + expect(finalText).not.toContain('stale content') + expect(storedMessages).toEqual(client.getMessages()) + }) + + it('should ignore stale messages snapshot after persisted clear', async () => { + let storedMessages: Array | undefined + const releaseSnapshot = createDeferred() + const snapshotAttempted = createDeferred() + const adapter: ConnectionAdapter = { + async *connect(_messages, _data, _signal, runContext) { + yield { + type: EventType.RUN_STARTED, + threadId: runContext?.threadId ?? 'thread-1', + runId: runContext?.runId ?? 'run-1', + timestamp: Date.now(), + } as StreamChunk + await releaseSnapshot.promise + yield { + type: EventType.MESSAGES_SNAPSHOT, + messages: [ + { + id: 'stale-assistant', + role: 'assistant', + content: 'stale snapshot', + }, + ], + } as StreamChunk + snapshotAttempted.resolve() + yield { + type: EventType.RUN_FINISHED, + threadId: runContext?.threadId ?? 'thread-1', + runId: runContext?.runId ?? 'run-1', + timestamp: Date.now(), + } as StreamChunk + }, + } + const persistence = { + getItem: vi.fn(() => undefined), + setItem: vi.fn((_key: string, messages: Array) => { + storedMessages = messages + }), + removeItem: vi.fn(() => { + storedMessages = undefined + }), + } + const client = new ChatClient({ + connection: adapter, + id: 'chat-1', + persistence, + }) + + const sendPromise = client.sendMessage('A') + await vi.waitFor(() => { + expect(client.getSessionGenerating()).toBe(true) + }) + + client.clear() + releaseSnapshot.resolve() + await snapshotAttempted.promise + await sendPromise + + expect(client.getMessages()).toEqual([]) + expect(storedMessages).toBeUndefined() + }) + + it('should ignore stale runless run error after persisted clear', async () => { + let storedMessages: Array | undefined + const onError = vi.fn() + const releaseError = createDeferred() + const errorAttempted = createDeferred() + const adapter: ConnectionAdapter = { + async *connect(_messages, _data, _signal, runContext) { + yield { + type: EventType.RUN_STARTED, + threadId: runContext?.threadId ?? 'thread-1', + runId: runContext?.runId ?? 'run-1', + timestamp: Date.now(), + } as StreamChunk + await releaseError.promise + yield { + type: EventType.RUN_ERROR, + threadId: runContext?.threadId ?? 'thread-1', + timestamp: Date.now(), + message: 'stale failure', + error: { message: 'stale failure' }, + } as StreamChunk + errorAttempted.resolve() + }, + } + const persistence = { + getItem: vi.fn(() => undefined), + setItem: vi.fn((_key: string, messages: Array) => { + storedMessages = messages + }), + removeItem: vi.fn(() => { + storedMessages = undefined + }), + } + const client = new ChatClient({ + connection: adapter, + id: 'chat-1', + persistence, + onError, + }) + + const sendPromise = client.sendMessage('A') + await vi.waitFor(() => { + expect(client.getSessionGenerating()).toBe(true) + }) + + client.clear() + releaseError.resolve() + await errorAttempted.promise + await sendPromise + + expect(client.getMessages()).toEqual([]) + expect(client.getError()).toBeUndefined() + expect(client.getStatus()).toBe('ready') + expect(onError).not.toHaveBeenCalled() + expect(storedMessages).toBeUndefined() + }) + + it('should keep fresh native subscribe/send chunks when they arrive before stale cleared chunks', async () => { + let storedMessages: Array | undefined + const releaseStaleResponse = createDeferred() + let staleReleased = false + const queuedChunks: Array<{ prompt: string; chunks: Array }> = + [] + let wakeSubscriber: (() => void) | null = null + const adapter: ConnectionAdapter = { + subscribe: vi.fn((_signal?: AbortSignal) => { + return (async function* () { + while (true) { + if (queuedChunks.length === 0) { + await new Promise((resolve) => { + wakeSubscriber = resolve + }) + } + const freshIndex = queuedChunks.findIndex( + (queued) => queued.prompt !== 'A', + ) + const next = + !staleReleased && freshIndex > 0 + ? queuedChunks.splice(freshIndex, 1)[0] + : queuedChunks.shift() + if (!next) continue + if (next.prompt === 'A') { + if (!staleReleased) { + queuedChunks.push(next) + await new Promise((resolve) => { + wakeSubscriber = resolve + }) + continue + } + await releaseStaleResponse.promise + } + yield* next.chunks + } + })() + }), + send: vi.fn( + async ( + messages: Array | Array, + _data, + _signal, + runContext, + ) => { + const prompt = messages + .flatMap((message) => ('parts' in message ? message.parts : [])) + .find((part) => part.type === 'text')?.content + const messageId = prompt === 'A' ? 'msg-a' : 'msg-b' + + queuedChunks.push({ + prompt: prompt ?? '', + chunks: [ + { + type: EventType.RUN_STARTED, + threadId: runContext?.threadId ?? 'thread-1', + runId: runContext?.runId ?? `run-${messageId}`, + timestamp: Date.now(), + } as StreamChunk, + ...createTextChunks( + prompt === 'A' ? 'stale A' : 'fresh B', + messageId, + ).map((chunk) => { + if (chunk.type === 'TEXT_MESSAGE_CONTENT') { + const { runId: _runId, ...withoutRunId } = chunk + return withoutRunId as StreamChunk + } + if (chunk.type === 'RUN_FINISHED') { + return { + ...chunk, + threadId: runContext?.threadId ?? chunk.threadId, + runId: runContext?.runId ?? chunk.runId, + } as StreamChunk + } + return chunk + }), + ], + }) + wakeSubscriber?.() + wakeSubscriber = null + }, + ), + } + const persistence = { + getItem: vi.fn(() => undefined), + setItem: vi.fn((_key: string, messages: Array) => { + storedMessages = messages + }), + removeItem: vi.fn(() => { + storedMessages = undefined + }), + } + const client = new ChatClient({ + connection: adapter, + id: 'chat-1', + persistence, + }) + + const firstSend = client.sendMessage('A') + await vi.waitFor(() => { + expect(client.getIsLoading()).toBe(true) + }) + + client.clear() + const secondSend = client.sendMessage('B') + await secondSend + staleReleased = true + releaseStaleResponse.resolve() + const wake = wakeSubscriber as (() => void) | null + wake?.() + wakeSubscriber = null + await vi.waitFor(() => { + expect(adapter.send).toHaveBeenCalledTimes(2) + }) + firstSend.catch(() => { + // The stale request may already have been cancelled by clear(). + }) + + const finalText = client + .getMessages() + .flatMap((message) => message.parts) + .filter((part) => part.type === 'text') + .map((part) => part.content) + .join('') + + expect(finalText).toContain('B') + expect(finalText).toContain('fresh B') + expect(finalText).not.toContain('A') + expect(finalText).not.toContain('stale A') + expect(storedMessages).toEqual(client.getMessages()) + expect( + storedMessages + ?.flatMap((message) => message.parts) + .filter((part) => part.type === 'text') + .map((part) => part.content) + .join(''), + ).toContain('fresh B') + expect( + storedMessages + ?.flatMap((message) => message.parts) + .filter((part) => part.type === 'text') + .map((part) => part.content) + .join(''), + ).not.toContain('stale A') + }) + + it('should ignore stale tool chunks by cleared parentMessageId after persisted clear', async () => { + const releaseToolChunks = createDeferred() + const adapter: ConnectionAdapter = { + async *connect(_messages, _data, _signal, runContext) { + yield { + type: EventType.RUN_STARTED, + threadId: runContext?.threadId ?? 'thread-1', + runId: runContext?.runId ?? 'run-1', + timestamp: Date.now(), + } as StreamChunk + yield { + type: 'TEXT_MESSAGE_CONTENT', + messageId: 'assistant-a', + model: 'test', + timestamp: Date.now(), + delta: '', + content: '', + } as StreamChunk + await releaseToolChunks.promise + yield { + type: 'TOOL_CALL_START', + toolCallId: 'stale-tool', + toolCallName: 'staleTool', + toolName: 'staleTool', + parentMessageId: 'assistant-a', + model: 'test', + timestamp: Date.now(), + index: 0, + } as StreamChunk + yield { + type: 'TOOL_CALL_ARGS', + toolCallId: 'stale-tool', + model: 'test', + timestamp: Date.now(), + delta: '{"stale":true}', + } as StreamChunk + }, + } + const persistence = createPersistence(undefined) + const client = new ChatClient({ + connection: adapter, + id: 'chat-1', + persistence, + }) + + const sendPromise = client.sendMessage('A') + await vi.waitFor(() => { + expect(client.getMessages().some((message) => message.id === 'assistant-a')).toBe( + true, + ) + }) + + client.clear() + releaseToolChunks.resolve() + await sendPromise + + expect(client.getMessages()).toEqual([]) + expect(persistence.setItem).not.toHaveBeenLastCalledWith( + 'chat-1', + expect.arrayContaining([ + expect.objectContaining({ + id: 'assistant-a', + parts: expect.arrayContaining([ + expect.objectContaining({ type: 'tool-call' }), + ]), + }), + ]), + ) + }) + + it('should remember ignored stale runless message ids so child tool chunks are ignored after persisted clear', async () => { + const releaseToolChunks = createDeferred() + const adapter: ConnectionAdapter = { + async *connect(_messages, _data, _signal, runContext) { + yield { + type: EventType.RUN_STARTED, + threadId: runContext?.threadId ?? 'thread-1', + runId: runContext?.runId ?? 'run-1', + timestamp: Date.now(), + } as StreamChunk + await releaseToolChunks.promise + yield { + type: EventType.TEXT_MESSAGE_CONTENT, + messageId: 'stale-runless-message', + timestamp: Date.now(), + delta: 'stale text', + content: 'stale text', + } as StreamChunk + yield { + type: 'TOOL_CALL_START', + toolCallId: 'stale-child-tool', + toolCallName: 'staleTool', + toolName: 'staleTool', + parentMessageId: 'stale-runless-message', + model: 'test', + timestamp: Date.now(), + index: 0, + } as StreamChunk + yield { + type: 'TOOL_CALL_ARGS', + toolCallId: 'stale-child-tool', + model: 'test', + timestamp: Date.now(), + delta: '{"stale":true}', + } as StreamChunk + }, + } + const persistence = createPersistence(undefined) + const client = new ChatClient({ + connection: adapter, + id: 'chat-1', + persistence, + }) + + const sendPromise = client.sendMessage('A') + await vi.waitFor(() => { + expect(client.getSessionGenerating()).toBe(true) + }) + + client.clear() + releaseToolChunks.resolve() + await sendPromise + + expect(client.getMessages()).toEqual([]) + expect(persistence.setItem).not.toHaveBeenLastCalledWith( + 'chat-1', + expect.arrayContaining([ + expect.objectContaining({ + id: 'stale-runless-message', + }), + ]), + ) + }) + + it('should ignore stale runless tool starts without parentMessageId after persisted clear', async () => { + const releaseToolChunks = createDeferred() + const adapter: ConnectionAdapter = { + async *connect(_messages, _data, _signal, runContext) { + yield { + type: EventType.RUN_STARTED, + threadId: runContext?.threadId ?? 'thread-1', + runId: runContext?.runId ?? 'run-1', + timestamp: Date.now(), + } as StreamChunk + await releaseToolChunks.promise + yield { + type: 'TOOL_CALL_START', + toolCallId: 'stale-parentless-tool', + toolCallName: 'staleTool', + toolName: 'staleTool', + model: 'test', + timestamp: Date.now(), + index: 0, + } as StreamChunk + yield { + type: 'TOOL_CALL_ARGS', + toolCallId: 'stale-parentless-tool', + model: 'test', + timestamp: Date.now(), + delta: '{"stale":true}', + } as StreamChunk + }, + } + const persistence = createPersistence(undefined) + const client = new ChatClient({ + connection: adapter, + id: 'chat-1', + persistence, + }) + + const sendPromise = client.sendMessage('A') + await vi.waitFor(() => { + expect(client.getSessionGenerating()).toBe(true) + }) + + client.clear() + releaseToolChunks.resolve() + await sendPromise + + expect(client.getMessages()).toEqual([]) + expect(persistence.setItem).not.toHaveBeenLastCalledWith( + 'chat-1', + expect.arrayContaining([ + expect.objectContaining({ + parts: expect.arrayContaining([ + expect.objectContaining({ type: 'tool-call' }), + ]), + }), + ]), + ) + }) + + it('should reset session generation when a persisted clear ignores terminal chunks', async () => { + const releaseResponse = createDeferred() + let wakeSubscriber: (() => void) | null = null + let queued = false + const adapter: ConnectionAdapter = { + subscribe: vi.fn((_signal?: AbortSignal): AsyncIterable => { + return (async function* () { + while (true) { + if (!queued) { + await new Promise((resolve) => { + wakeSubscriber = resolve + }) + } + queued = false + yield { + type: EventType.RUN_STARTED, + threadId: 'thread-1', + runId: 'run-1', + timestamp: Date.now(), + } as StreamChunk + await releaseResponse.promise + yield { + type: EventType.RUN_FINISHED, + threadId: 'thread-1', + runId: 'run-1', + model: 'test', + timestamp: Date.now(), + finishReason: 'stop', + } as StreamChunk + } + })() + }), + send: vi.fn(async () => { + queued = true + wakeSubscriber?.() + wakeSubscriber = null + }), + } + const client = new ChatClient({ + connection: adapter, + id: 'chat-1', + persistence: createPersistence(), + }) + + const sendPromise = client.sendMessage('A') + await vi.waitFor(() => { + expect(client.getSessionGenerating()).toBe(true) + }) + + client.clear() + releaseResponse.resolve() + await sendPromise + + expect(client.getSessionGenerating()).toBe(false) + }) + + it('should ignore live-only active run chunks after persisted clear and clean up session generation', async () => { + let storedMessages: Array | undefined + const releaseAfterClear = createDeferred() + const subscriberReady = createDeferred<() => void>() + const adapter: ConnectionAdapter = { + subscribe: vi.fn((_signal?: AbortSignal): AsyncIterable => { + return (async function* () { + await new Promise((resolve) => { + subscriberReady.resolve(resolve) + }) + yield { + type: EventType.RUN_STARTED, + threadId: 'thread-1', + runId: 'live-run-1', + timestamp: Date.now(), + } as StreamChunk + await releaseAfterClear.promise + yield { + type: EventType.TEXT_MESSAGE_START, + messageId: 'live-message-1', + role: 'assistant', + } as StreamChunk + yield { + type: EventType.TEXT_MESSAGE_CONTENT, + messageId: 'live-message-1', + delta: 'stale live content', + } as StreamChunk + yield { + type: EventType.RUN_FINISHED, + threadId: 'thread-1', + runId: 'live-run-1', + model: 'test', + timestamp: Date.now(), + finishReason: 'stop', + } as StreamChunk + })() + }), + send: vi.fn(), + } + const persistence = { + getItem: vi.fn(() => undefined), + setItem: vi.fn((_key: string, messages: Array) => { + storedMessages = messages + }), + removeItem: vi.fn(() => { + storedMessages = undefined + }), + } + const client = new ChatClient({ + connection: adapter, + id: 'chat-1', + persistence, + }) + + client.subscribe() + const wakeSubscriber = await subscriberReady.promise + wakeSubscriber() + await vi.waitFor(() => { + expect(client.getSessionGenerating()).toBe(true) + }) - const subscribe = vi.fn((signal?: AbortSignal) => { - return (async function* () { - while (!signal?.aborted) { - if (!hasPendingSend) { - await new Promise((resolve) => { - removeAbortListener?.() - removeAbortListener = null - wakeSubscriber = resolve - const onAbort = () => resolve() - signal?.addEventListener('abort', onAbort, { once: true }) - removeAbortListener = () => { - signal?.removeEventListener('abort', onAbort) - } - }) - continue - } + client.clear() + releaseAfterClear.resolve() - hasPendingSend = false - for (const chunk of chunksToSend) { - yield chunk - } - } - removeAbortListener?.() - removeAbortListener = null - })() + await vi.waitFor(() => { + expect(client.getSessionGenerating()).toBe(false) }) + expect(client.getMessages()).toEqual([]) + expect(storedMessages).toBeUndefined() + }) - const send = vi.fn(async () => { - removeAbortListener?.() - removeAbortListener = null - hasPendingSend = true - wakeSubscriber?.() - wakeSubscriber = null + it('should not expose request generation metadata on public chunks or run context', async () => { + const onChunk = vi.fn() + const runContextSpy = vi.fn() + const client = new ChatClient({ + connection: { + async *connect(_messages, _data, _abortSignal, runContext) { + runContextSpy(runContext) + yield* createTextChunks('Hello') + }, + }, + onChunk, }) - return { subscribe, send } - } + await client.sendMessage('Hello') - it('should use subscribe/send adapter mode', async () => { - const adapter = createSubscribeAdapter( - createTextChunks('From subscribe/send mode'), + expect(runContextSpy).toHaveBeenCalled() + expect(runContextSpy.mock.calls[0]![0]).not.toHaveProperty( + 'requestGeneration', ) - const client = new ChatClient({ connection: adapter }) + expect(onChunk).toHaveBeenCalled() + for (const [chunk] of onChunk.mock.calls) { + expect(Object.keys(chunk)).not.toContain('requestGeneration') + expect(chunk).not.toHaveProperty('requestGeneration') + } + }) + + it('should not add internal threadId or runId to public connect chunks', async () => { + const onChunk = vi.fn() + const client = new ChatClient({ + connection: { + async *connect() { + yield { + type: EventType.TEXT_MESSAGE_CONTENT, + messageId: 'public-message', + timestamp: Date.now(), + delta: 'Hello', + content: 'Hello', + } as StreamChunk + }, + }, + onChunk, + }) await client.sendMessage('Hello') - expect(adapter.subscribe).toHaveBeenCalled() - expect(adapter.send).toHaveBeenCalled() + expect(onChunk).toHaveBeenCalledWith( + expect.objectContaining({ + type: EventType.TEXT_MESSAGE_CONTENT, + messageId: 'public-message', + }), + ) + for (const [chunk] of onChunk.mock.calls) { + if (chunk.type === EventType.TEXT_MESSAGE_CONTENT) { + expect(chunk).not.toHaveProperty('threadId') + expect(chunk).not.toHaveProperty('runId') + } + } }) it('stop should not unsubscribe an active subscription', async () => { @@ -736,6 +1899,60 @@ describe('ChatClient', () => { client.unsubscribe() }) + it('should process future live subscription chunks after persistence clear', async () => { + const wake = { fn: null as (() => void) | null } + const chunks: Array = [] + const connection = { + subscribe: async function* (signal?: AbortSignal) { + while (!signal?.aborted) { + if (chunks.length > 0) { + const batch = chunks.splice(0) + for (const chunk of batch) { + yield chunk + } + } + await new Promise((resolve) => { + wake.fn = resolve + const onAbort = () => resolve() + signal?.addEventListener('abort', onAbort, { once: true }) + }) + } + }, + send: async () => { + wake.fn?.() + }, + } + const persistence = createPersistence() + const client = new ChatClient({ + connection, + id: 'chat-1', + persistence, + }) + + client.subscribe() + await vi.waitFor(() => { + expect(client.getIsSubscribed()).toBe(true) + }) + + client.clear() + chunks.push(...createTextChunks('future live', 'future-live')) + wake.fn?.() + + await vi.waitFor(() => { + expect( + client + .getMessages() + .flatMap((message) => message.parts) + .some( + (part) => + part.type === 'text' && part.content.includes('future live'), + ), + ).toBe(true) + }) + + client.unsubscribe() + }) + it('should clear all runs on RUN_ERROR without runId', async () => { const wake = { fn: null as (() => void) | null } const chunks: Array = [] @@ -1069,6 +2286,273 @@ describe('ChatClient', () => { expect(client.getMessages().length).toBe(0) expect(client.getError()).toBeUndefined() }) + + it('should remove persisted messages without saving an empty snapshot', async () => { + const chunks = createTextChunks('Response') + const adapter = createMockConnectionAdapter({ chunks }) + const persistence = createPersistence() + const client = new ChatClient({ + connection: adapter, + id: 'chat-1', + persistence, + }) + + await client.sendMessage('Hello') + vi.mocked(persistence.setItem).mockClear() + + client.clear() + + expect(persistence.removeItem).toHaveBeenCalledWith('chat-1') + expect(persistence.setItem).not.toHaveBeenCalled() + }) + + it('should not abort an in-flight stream when persistence is omitted', async () => { + let abortSignal: AbortSignal | undefined + + const adapter: ConnectConnectionAdapter = { + async *connect(_messages, _data, signal) { + abortSignal = signal + await new Promise((resolve) => setTimeout(resolve, 10)) + yield* createTextChunks('Delayed') + }, + } + const client = new ChatClient({ connection: adapter }) + + const sendPromise = client.sendMessage('Hello') + await vi.waitFor(() => { + expect(abortSignal).toBeDefined() + }) + + client.clear() + expect(abortSignal?.aborted).toBe(false) + + await sendPromise + + expect(client.getMessages()).toEqual( + expect.arrayContaining([ + expect.objectContaining({ role: 'assistant' }), + ]), + ) + }) + + it('should prevent delayed stream chunks from recreating messages after clear', async () => { + const adapter = createMockConnectionAdapter({ + chunks: createTextChunks('Delayed'), + chunkDelay: 20, + }) + const persistence = createPersistence() + const client = new ChatClient({ + connection: adapter, + id: 'chat-1', + persistence, + }) + + const sendPromise = client.sendMessage('Hello') + await new Promise((resolve) => setTimeout(resolve, 5)) + + client.clear() + await sendPromise + + expect(client.getMessages()).toEqual([]) + expect(persistence.removeItem).toHaveBeenCalledWith('chat-1') + expect(persistence.setItem).not.toHaveBeenLastCalledWith( + 'chat-1', + expect.arrayContaining([ + expect.objectContaining({ role: 'assistant' }), + ]), + ) + }) + + it('should not persist non-cooperative delayed chunks after clear removes storage', async () => { + const persistenceEvents: Array = [] + const persistence = { + getItem: vi.fn(() => undefined), + setItem: vi.fn(() => { + persistenceEvents.push('set') + }), + removeItem: vi.fn(() => { + persistenceEvents.push('remove') + }), + } + const adapter: ConnectConnectionAdapter = { + async *connect() { + await new Promise((resolve) => setTimeout(resolve, 10)) + yield* createTextChunks('Delayed') + }, + } + const client = new ChatClient({ + connection: adapter, + id: 'chat-1', + persistence, + }) + + const sendPromise = client.sendMessage('Hello') + await new Promise((resolve) => setTimeout(resolve, 0)) + + client.clear() + await sendPromise + + const removeIndex = persistenceEvents.indexOf('remove') + expect(removeIndex).toBeGreaterThanOrEqual(0) + expect(persistenceEvents.slice(removeIndex + 1)).not.toContain('set') + expect(client.getMessages()).toEqual([]) + expect(persistence.removeItem).toHaveBeenCalledWith('chat-1') + }) + + it('should ignore cleared request chunks after persistence resumes for a new send', async () => { + let storedMessages: Array | undefined + const releaseFirstResponse = createDeferred() + const connection: ConnectConnectionAdapter = { + async *connect(messages) { + const userText = messages + .flatMap((message) => ('parts' in message ? message.parts : [])) + .find((part) => part.type === 'text')?.content + + if (userText === 'A') { + await releaseFirstResponse.promise + yield* createTextChunks('stale A', 'msg-a') + return + } + + yield* createTextChunks('fresh B', 'msg-b') + }, + } + const persistence = { + getItem: vi.fn(() => undefined), + setItem: vi.fn((_key: string, messages: Array) => { + storedMessages = messages + }), + removeItem: vi.fn(() => { + storedMessages = undefined + }), + } + const client = new ChatClient({ + connection, + id: 'chat-1', + persistence, + }) + + const firstSend = client.sendMessage('A') + await vi.waitFor(() => { + expect(client.getIsLoading()).toBe(true) + }) + + client.clear() + await client.sendMessage('B') + releaseFirstResponse.resolve() + await firstSend + + const finalText = client + .getMessages() + .flatMap((message) => message.parts) + .filter((part) => part.type === 'text') + .map((part) => part.content) + .join('') + + expect(finalText).toContain('B') + expect(finalText).toContain('fresh B') + expect(finalText).not.toContain('A') + expect(finalText).not.toContain('stale A') + expect(storedMessages).toEqual(client.getMessages()) + expect( + storedMessages + ?.flatMap((message) => message.parts) + .filter((part) => part.type === 'text') + .map((part) => part.content) + .join(''), + ).not.toContain('stale A') + }) + + it('should ensure async setItem scheduled before clear cannot win after removeItem', async () => { + let storedMessages: Array | undefined + const releaseSet = createDeferred() + const persistence = { + getItem: vi.fn(() => undefined), + setItem: vi.fn(async (_key: string, messages: Array) => { + await releaseSet.promise + storedMessages = messages + }), + removeItem: vi.fn(() => { + storedMessages = undefined + }), + } + const client = new ChatClient({ + connection: createMockConnectionAdapter(), + id: 'chat-1', + persistence, + }) + + client.setMessagesManually([initialMessage]) + client.clear() + + releaseSet.resolve() + await vi.waitFor(() => { + expect(persistence.removeItem).toHaveBeenCalledWith('chat-1') + }) + + expect(storedMessages).toBeUndefined() + }) + }) + + describe('persistence', () => { + it('should save message snapshots after sendMessage changes messages', async () => { + const chunks = createTextChunks('Response') + const adapter = createMockConnectionAdapter({ chunks }) + const persistence = createPersistence() + const client = new ChatClient({ + connection: adapter, + id: 'chat-1', + persistence, + }) + + await client.sendMessage('Hello') + + expect(persistence.setItem).toHaveBeenCalled() + expect(persistence.setItem).toHaveBeenLastCalledWith( + 'chat-1', + client.getMessages(), + ) + }) + + it('should save message snapshots when messages are set manually', () => { + const adapter = createMockConnectionAdapter() + const persistence = createPersistence() + const client = new ChatClient({ + connection: adapter, + id: 'chat-1', + persistence, + }) + + client.setMessagesManually([initialMessage]) + + expect(persistence.setItem).toHaveBeenCalledWith('chat-1', [ + initialMessage, + ]) + }) + + it('should swallow async persistence write and remove failures', async () => { + const adapter = createMockConnectionAdapter({ + chunks: createTextChunks('Hi'), + }) + const persistence = { + getItem: vi.fn(() => undefined), + setItem: vi.fn(() => Promise.reject(new Error('set failed'))), + removeItem: vi.fn(() => Promise.reject(new Error('remove failed'))), + } + const client = new ChatClient({ + connection: adapter, + id: 'chat-1', + persistence, + }) + + await client.sendMessage('Hello') + client.clear() + await new Promise((resolve) => setTimeout(resolve, 0)) + + expect(client.getMessages()).toEqual([]) + expect(persistence.setItem).toHaveBeenCalled() + expect(persistence.removeItem).toHaveBeenCalledWith('chat-1') + }) }) describe('callbacks', () => { @@ -1088,6 +2572,40 @@ describe('ChatClient', () => { expect(onMessagesChange.mock.calls.length).toBeGreaterThan(0) }) + it('should preserve state updates and onMessagesChange when persistence throws', async () => { + const chunks = createTextChunks('Response') + const adapter = createMockConnectionAdapter({ chunks }) + const onMessagesChange = vi.fn() + const persistence: ChatClientPersistence = { + getItem: vi.fn(() => { + throw new Error('get failed') + }), + setItem: vi.fn(() => { + throw new Error('set failed') + }), + removeItem: vi.fn(() => { + throw new Error('remove failed') + }), + } + + const client = new ChatClient({ + connection: adapter, + id: 'chat-1', + initialMessages: [initialMessage], + onMessagesChange, + persistence, + }) + + expect(client.getMessages()).toEqual([initialMessage]) + + await client.sendMessage('Hello') + expect(client.getMessages().length).toBeGreaterThan(1) + expect(onMessagesChange).toHaveBeenCalled() + + expect(() => client.clear()).not.toThrow() + expect(client.getMessages()).toEqual([]) + }) + it('should call onLoadingChange when loading state changes', async () => { const chunks = createTextChunks('Response') const adapter = createMockConnectionAdapter({ chunks }) diff --git a/packages/typescript/ai-preact/src/use-chat.ts b/packages/typescript/ai-preact/src/use-chat.ts index 0913f3588..4effda908 100644 --- a/packages/typescript/ai-preact/src/use-chat.ts +++ b/packages/typescript/ai-preact/src/use-chat.ts @@ -39,6 +39,10 @@ export function useChat = any>( options.initialMessages || [], ) const isFirstMountRef = useRef(true) + const activeClientRef = useRef(null) + const cleanupInvalidationRef = useRef | null>( + null, + ) const optionsRef = useRef>(options) optionsRef.current = options @@ -48,11 +52,7 @@ export function useChat = any>( }, [messages]) const client = useMemo(() => { - // On first mount, use initialMessages. On subsequent recreations, preserve existing messages. - const messagesToUse = isFirstMountRef.current - ? options.initialMessages || [] - : messagesRef.current - + const messagesToUse = options.initialMessages || [] isFirstMountRef.current = false // Build options with conditional spreads for fields whose source @@ -60,7 +60,7 @@ export function useChat = any>( // optional (`field?: T`) — `exactOptionalPropertyTypes` rejects // assigning `undefined` to those, so we omit the key when absent. const initialOptions = optionsRef.current - return new ChatClient({ + const instance = new ChatClient({ connection: initialOptions.connection, id: clientId, initialMessages: messagesToUse, @@ -68,19 +68,32 @@ export function useChat = any>( ...(initialOptions.forwardedProps !== undefined && { forwardedProps: initialOptions.forwardedProps, }), + ...(initialOptions.persistence !== undefined && { + persistence: initialOptions.persistence, + }), // Wrap every callback so the latest options are read at call time. // Capturing the function reference directly would freeze it to whatever // the parent passed on the first render. - onResponse: (response) => optionsRef.current.onResponse?.(response), - onChunk: (chunk) => optionsRef.current.onChunk?.(chunk), + onResponse: (response) => { + if (activeClientRef.current !== instance) return + return optionsRef.current.onResponse?.(response) + }, + onChunk: (chunk) => { + if (activeClientRef.current !== instance) return + optionsRef.current.onChunk?.(chunk) + }, onFinish: (message) => { + if (activeClientRef.current !== instance) return optionsRef.current.onFinish?.(message) }, onError: (err) => { + if (activeClientRef.current !== instance) return optionsRef.current.onError?.(err) }, - onCustomEvent: (eventType, data, context) => - optionsRef.current.onCustomEvent?.(eventType, data, context), + onCustomEvent: (eventType, data, context) => { + if (activeClientRef.current !== instance) return + optionsRef.current.onCustomEvent?.(eventType, data, context) + }, ...(initialOptions.tools !== undefined && { tools: initialOptions.tools, }), @@ -88,29 +101,45 @@ export function useChat = any>( streamProcessor: options.streamProcessor, }), onMessagesChange: (newMessages: Array>) => { + if (activeClientRef.current !== instance) return setMessages(newMessages) }, onLoadingChange: (newIsLoading: boolean) => { + if (activeClientRef.current !== instance) return setIsLoading(newIsLoading) }, onStatusChange: (newStatus: ChatClientState) => { + if (activeClientRef.current !== instance) return setStatus(newStatus) }, onErrorChange: (newError: Error | undefined) => { + if (activeClientRef.current !== instance) return setError(newError) }, onSubscriptionChange: (nextIsSubscribed: boolean) => { + if (activeClientRef.current !== instance) return setIsSubscribed(nextIsSubscribed) }, onConnectionStatusChange: (nextStatus: ConnectionStatus) => { + if (activeClientRef.current !== instance) return setConnectionStatus(nextStatus) }, onSessionGeneratingChange: (isGenerating: boolean) => { + if (activeClientRef.current !== instance) return setSessionGenerating(isGenerating) }, }) + activeClientRef.current = instance + return instance }, [clientId]) + useEffect(() => { + const clientMessages = client.getMessages() as Array> + if (clientMessages !== messagesRef.current) { + setMessages(clientMessages) + } + }, [client]) + // Sync body / forwardedProps changes to the client. // Both populate the same wire payload; `forwardedProps` is preferred // and `body` is deprecated but still supported. @@ -125,19 +154,6 @@ export function useChat = any>( }) }, [client, options.body, options.forwardedProps]) - // Sync initial messages on mount only - // Note: initialMessages are passed to ChatClient constructor, but we also - // set them here to ensure Preact state is in sync - useEffect(() => { - if ( - options.initialMessages && - options.initialMessages.length && - !messages.length - ) { - client.setMessagesManually(options.initialMessages) - } - }, []) - useEffect(() => { if (options.live) { client.subscribe() @@ -151,14 +167,26 @@ export function useChat = any>( // DO NOT include isLoading in dependencies - that would cause the cleanup // to run when isLoading changes, aborting continuation requests. useEffect(() => { + if (cleanupInvalidationRef.current) { + clearTimeout(cleanupInvalidationRef.current) + cleanupInvalidationRef.current = null + } + activeClientRef.current = client + return () => { - if (options.live) { + cleanupInvalidationRef.current = setTimeout(() => { + if (activeClientRef.current === client) { + activeClientRef.current = null + } + cleanupInvalidationRef.current = null + }, 0) + if (optionsRef.current.live) { client.unsubscribe() } else { client.stop() } } - }, [client, options.live]) + }, [client]) // All callback options are read through optionsRef at call time, so fresh // closures from each render are picked up without recreating the client. @@ -215,8 +243,10 @@ export function useChat = any>( [client], ) + const renderedMessages = client.getMessages() as Array> + return { - messages, + messages: renderedMessages, sendMessage, append, reload, diff --git a/packages/typescript/ai-preact/tests/use-chat.test.ts b/packages/typescript/ai-preact/tests/use-chat.test.ts index e2e2b8f24..9963014cf 100644 --- a/packages/typescript/ai-preact/tests/use-chat.test.ts +++ b/packages/typescript/ai-preact/tests/use-chat.test.ts @@ -1,7 +1,12 @@ -import type { ModelMessage } from '@tanstack/ai' -import { act, waitFor } from '@testing-library/preact' +import type { ModelMessage, StreamChunk } from '@tanstack/ai' +import { EventType } from '@tanstack/ai' +import type { SubscribeConnectionAdapter } from '@tanstack/ai-client' +import { act, renderHook, waitFor } from '@testing-library/preact' +import { StrictMode } from 'preact/compat' +import { useState } from 'preact/hooks' import { describe, expect, it, vi } from 'vitest' import type { UIMessage } from '../src/types' +import { useChat } from '../src/use-chat' import { createMockConnectionAdapter, createTextChunks, @@ -10,6 +15,14 @@ import { } from './test-utils' describe('useChat', () => { + function createDeferred() { + let resolve!: (value: T) => void + const promise = new Promise((promiseResolve) => { + resolve = promiseResolve + }) + return { promise, resolve } + } + describe('initialization', () => { it('should initialize with default state', () => { const adapter = createMockConnectionAdapter() @@ -55,6 +68,120 @@ describe('useChat', () => { expect(result.current.messages).toEqual(initialMessages) }) + it('should initialize with persisted messages', async () => { + const adapter = createMockConnectionAdapter() + const persistedMessages: Array = [ + { + id: 'persisted-1', + role: 'user', + parts: [{ type: 'text', content: 'Persisted' }], + createdAt: new Date(), + }, + ] + const persistence = { + getItem: vi.fn(() => persistedMessages), + setItem: vi.fn(), + removeItem: vi.fn(), + } + + const { result } = renderUseChat({ + connection: adapter, + id: 'persisted-chat', + persistence, + }) + + await waitFor(() => { + expect(result.current.messages).toEqual(persistedMessages) + }) + expect(persistence.getItem).toHaveBeenCalledWith('persisted-chat') + }) + + it('should preserve persisted empty messages over provided initial messages', async () => { + const adapter = createMockConnectionAdapter() + const initialMessages: Array = [ + { + id: 'initial-1', + role: 'user', + parts: [{ type: 'text', content: 'Initial' }], + createdAt: new Date(), + }, + ] + const persistence = { + getItem: vi.fn(() => []), + setItem: vi.fn(), + removeItem: vi.fn(), + } + + const { result } = renderUseChat({ + connection: adapter, + id: 'persisted-empty-chat', + initialMessages, + persistence, + }) + + await waitFor(() => { + expect(persistence.getItem).toHaveBeenCalledWith( + 'persisted-empty-chat', + ) + }) + expect(result.current.messages).toEqual([]) + }) + + it('should ignore async persisted messages from a previous id', async () => { + const oldHydration = createDeferred>() + const oldMessages: Array = [ + { + id: 'old-persisted', + role: 'user', + parts: [{ type: 'text', content: 'Old persisted' }], + createdAt: new Date(), + }, + ] + const newMessages: Array = [ + { + id: 'new-persisted', + role: 'user', + parts: [{ type: 'text', content: 'New persisted' }], + createdAt: new Date(), + }, + ] + const persistence = { + getItem: vi.fn((id: string) => + id === 'old-chat' ? oldHydration.promise : newMessages, + ), + setItem: vi.fn(), + removeItem: vi.fn(), + } + + function useChangingChat() { + const [id, setId] = useState('old-chat') + const chat = useChat({ + connection: createMockConnectionAdapter(), + id, + persistence, + }) + + return { ...chat, setId } + } + + const { result } = renderHook(() => useChangingChat()) + + act(() => { + result.current.setId('new-chat') + }) + + await waitFor(() => { + expect(result.current.messages).toEqual(newMessages) + }) + + await act(async () => { + oldHydration.resolve(oldMessages) + await oldHydration.promise + }) + + expect(result.current.messages).toEqual(newMessages) + }) + it('should use provided id', async () => { const chunks = createTextChunks('Response') const adapter = createMockConnectionAdapter({ chunks }) @@ -876,6 +1003,78 @@ describe('useChat', () => { }) describe('edge cases and error handling', () => { + describe('callbacks', () => { + it('should ignore user callbacks from an old client after id changes', async () => { + const releaseOldStream = createDeferred() + const oldOnChunk = vi.fn() + const newOnChunk = vi.fn() + const adapter = { + async *connect(): AsyncIterable { + await releaseOldStream.promise + yield* createTextChunks('stale old client') + }, + } + + const { result, rerender } = renderHook( + (opts: { id: string; onChunk: (chunk: StreamChunk) => void }) => + useChat({ + connection: adapter, + id: opts.id, + onChunk: opts.onChunk, + }), + { + initialProps: { + id: 'old-client', + onChunk: oldOnChunk, + }, + }, + ) + + let sendPromise: Promise + act(() => { + sendPromise = result.current.sendMessage('Test') + }) + await waitFor(() => { + expect(result.current.isLoading).toBe(true) + }) + + rerender({ + id: 'new-client', + onChunk: newOnChunk, + }) + + releaseOldStream.resolve() + await sendPromise! + + expect(oldOnChunk).not.toHaveBeenCalled() + expect(newOnChunk).not.toHaveBeenCalled() + }) + + it('should keep callbacks live across StrictMode effect replay for the same client', async () => { + const onChunk = vi.fn() + const adapter = createMockConnectionAdapter({ + chunks: createTextChunks('strict response'), + }) + + const { result } = renderHook( + () => + useChat({ + connection: adapter, + onChunk, + }), + { wrapper: StrictMode }, + ) + + await act(async () => { + await result.current.sendMessage('Test') + }) + + expect(onChunk).toHaveBeenCalledWith( + expect.objectContaining({ type: EventType.TEXT_MESSAGE_CONTENT }), + ) + }) + }) + describe('options changes', () => { it('should maintain client instance when options change', () => { const adapter1 = createMockConnectionAdapter() @@ -926,6 +1125,93 @@ describe('useChat', () => { }) }) + describe('client recreation', () => { + it('should not pass previous id messages to a new client id without persisted messages', async () => { + const connectSpy = vi.fn() + const adapter = createMockConnectionAdapter({ + chunks: createTextChunks('Reply'), + onConnect: connectSpy, + }) + + const { result } = renderHook(() => { + const [id, setId] = useState('client-A') + const chat = useChat({ connection: adapter, id }) + return { ...chat, switchId: setId } + }) + + const messages: Array = [ + { + id: 'msg-1', + role: 'user', + parts: [{ type: 'text', content: 'Hello' }], + createdAt: new Date(), + }, + { + id: 'msg-2', + role: 'assistant', + parts: [{ type: 'text', content: 'Hi there!' }], + createdAt: new Date(), + }, + ] + + act(() => { + result.current.setMessages(messages) + result.current.switchId('client-B') + }) + + await act(async () => { + await result.current.sendMessage('Follow-up') + }) + + await waitFor(() => { + expect(connectSpy).toHaveBeenCalled() + }) + + const sentMessages = connectSpy.mock.calls[0]![0] as Array< + ModelMessage | UIMessage + > + const sentText = sentMessages + .flatMap((message) => ('parts' in message ? message.parts : [])) + .filter((part) => part.type === 'text') + .map((part) => part.content) + .join('') + + expect(sentText).toContain('Follow-up') + expect(sentText).not.toContain('Hello') + expect(result.current.messages).not.toEqual(messages) + }) + + it('should return new client messages during the id change render', () => { + const adapter = createMockConnectionAdapter() + const oldMessages: Array = [ + { + id: 'old-message', + role: 'user', + parts: [{ type: 'text', content: 'Old client message' }], + createdAt: new Date(), + }, + ] + + const { result } = renderHook(() => { + const [id, setId] = useState('client-A') + const chat = useChat({ connection: adapter, id }) + return { ...chat, switchId: setId } + }) + + act(() => { + result.current.setMessages(oldMessages) + }) + + expect(result.current.messages).toEqual(oldMessages) + + act(() => { + result.current.switchId('client-B') + }) + + expect(result.current.messages).toEqual([]) + }) + }) + describe('unmount behavior', () => { it('should not update state after unmount', async () => { const chunks = createTextChunks('Response') @@ -1374,5 +1660,92 @@ describe('useChat', () => { } }) }) + + describe('sessionGenerating', () => { + it('should keep receiving live updates and callbacks after live toggles for the same client', async () => { + const onChunk = vi.fn() + const chunks: Array = [] + // Wrapped in an object so the assignment inside the generator closure + // is visible to TS control-flow analysis. A bare `let` would narrow to + // `null` at the call site, and `?.()` would then strip it to `never`. + const subscriberControl: { wake: (() => void) | null } = { wake: null } + const adapter: SubscribeConnectionAdapter = { + subscribe: async function* (_signal?: AbortSignal) { + while (true) { + if (chunks.length === 0) { + await new Promise((resolve) => { + subscriberControl.wake = resolve + }) + } + const chunk = chunks.shift() + if (chunk) yield chunk + } + }, + send: vi.fn(async () => {}), + } + + const { result, rerender } = renderHook( + ({ live }) => + useChat({ + connection: adapter, + live, + onChunk, + }), + { initialProps: { live: true } }, + ) + + await waitFor(() => { + expect(result.current.isSubscribed).toBe(true) + }) + + rerender({ live: false }) + await waitFor(() => { + expect(result.current.isSubscribed).toBe(false) + }) + + rerender({ live: true }) + await waitFor(() => { + expect(result.current.isSubscribed).toBe(true) + }) + + chunks.push( + { + type: EventType.RUN_STARTED, + runId: 'run-after-toggle', + threadId: 'thread-1', + timestamp: Date.now(), + }, + { + type: EventType.TEXT_MESSAGE_CONTENT, + messageId: 'msg-after-toggle', + timestamp: Date.now(), + delta: 'after toggle', + content: 'after toggle', + }, + { + type: EventType.RUN_FINISHED, + runId: 'run-after-toggle', + threadId: 'thread-1', + timestamp: Date.now(), + }, + ) + subscriberControl.wake?.() + subscriberControl.wake = null + + await waitFor(() => { + expect( + result.current.messages.some((message) => + message.parts.some( + (part) => + part.type === 'text' && part.content === 'after toggle', + ), + ), + ).toBe(true) + }) + expect(onChunk).toHaveBeenCalledWith( + expect.objectContaining({ type: EventType.TEXT_MESSAGE_CONTENT }), + ) + }) + }) }) }) diff --git a/packages/typescript/ai-react/src/use-chat.ts b/packages/typescript/ai-react/src/use-chat.ts index d9bd2dce8..874d22189 100644 --- a/packages/typescript/ai-react/src/use-chat.ts +++ b/packages/typescript/ai-react/src/use-chat.ts @@ -47,6 +47,10 @@ export function useChat< options.initialMessages || [], ) const isFirstMountRef = useRef(true) + const activeClientRef = useRef(null) + const cleanupInvalidationRef = useRef | null>( + null, + ) // Update ref synchronously during render so it's always current when useMemo runs. messagesRef.current = messages @@ -57,10 +61,7 @@ export function useChat< // Create ChatClient instance with callbacks to sync state const client = useMemo(() => { - const messagesToUse = isFirstMountRef.current - ? options.initialMessages || [] - : messagesRef.current - + const messagesToUse = options.initialMessages || [] isFirstMountRef.current = false // Build options with conditional spreads for fields whose source @@ -68,7 +69,7 @@ export function useChat< // optional (`field?: T`) — `exactOptionalPropertyTypes` rejects // assigning `undefined` to those, so we omit the key when absent. const initialOptions = optionsRef.current - return new ChatClient({ + const instance = new ChatClient({ connection: initialOptions.connection, id: clientId, initialMessages: messagesToUse, @@ -76,51 +77,75 @@ export function useChat< ...(initialOptions.forwardedProps !== undefined && { forwardedProps: initialOptions.forwardedProps, }), + ...(initialOptions.persistence !== undefined && { + persistence: initialOptions.persistence, + }), onResponse: (response) => { + if (activeClientRef.current !== instance) return void optionsRef.current.onResponse?.(response) }, onChunk: (chunk: StreamChunk) => { + if (activeClientRef.current !== instance) return optionsRef.current.onChunk?.(chunk) }, onFinish: (message: UIMessage) => { + if (activeClientRef.current !== instance) return optionsRef.current.onFinish?.(message) }, onError: (error: Error) => { + if (activeClientRef.current !== instance) return optionsRef.current.onError?.(error) }, ...(initialOptions.tools !== undefined && { tools: initialOptions.tools, }), onCustomEvent: (eventType, data, context) => { + if (activeClientRef.current !== instance) return optionsRef.current.onCustomEvent?.(eventType, data, context) }, ...(options.streamProcessor !== undefined && { streamProcessor: options.streamProcessor, }), onMessagesChange: (newMessages: Array>) => { + if (activeClientRef.current !== instance) return setMessages(newMessages) }, onLoadingChange: (newIsLoading: boolean) => { + if (activeClientRef.current !== instance) return setIsLoading(newIsLoading) }, onErrorChange: (newError: Error | undefined) => { + if (activeClientRef.current !== instance) return setError(newError) }, onStatusChange: (status: ChatClientState) => { + if (activeClientRef.current !== instance) return setStatus(status) }, onSubscriptionChange: (nextIsSubscribed: boolean) => { + if (activeClientRef.current !== instance) return setIsSubscribed(nextIsSubscribed) }, onConnectionStatusChange: (nextStatus: ConnectionStatus) => { + if (activeClientRef.current !== instance) return setConnectionStatus(nextStatus) }, onSessionGeneratingChange: (isGenerating: boolean) => { + if (activeClientRef.current !== instance) return setSessionGenerating(isGenerating) }, }) + activeClientRef.current = instance + return instance }, [clientId]) + useEffect(() => { + const clientMessages = client.getMessages() as Array> + if (clientMessages !== messagesRef.current) { + setMessages(clientMessages) + } + }, [client]) + useEffect(() => { // Conditional spread: `updateOptions` declares strict-optional // fields and rejects explicit `undefined` under EOPT. @@ -132,14 +157,6 @@ export function useChat< }) }, [client, options.body, options.forwardedProps]) - useEffect(() => { - if (options.initialMessages && options.initialMessages.length > 0) { - if (messages.length === 0) { - client.setMessagesManually(options.initialMessages) - } - } - }, []) - useEffect(() => { if (options.live) { client.subscribe() @@ -149,14 +166,26 @@ export function useChat< }, [client, options.live]) useEffect(() => { + if (cleanupInvalidationRef.current) { + clearTimeout(cleanupInvalidationRef.current) + cleanupInvalidationRef.current = null + } + activeClientRef.current = client + return () => { - if (options.live) { + cleanupInvalidationRef.current = setTimeout(() => { + if (activeClientRef.current === client) { + activeClientRef.current = null + } + cleanupInvalidationRef.current = null + }, 0) + if (optionsRef.current.live) { client.unsubscribe() } else { client.stop() } } - }, [client, options.live]) + }, [client]) const sendMessage = useCallback( async (content: string | MultimodalContent) => { @@ -221,17 +250,19 @@ export function useChat< // a stale assistant turn or a system prompt) we deliberately return null // rather than scanning historical assistants — otherwise a `final` from a // previous session would leak into the hook value on first render. + const renderedMessages = client.getMessages() as Array> + const activeStructuredPart = useMemo(() => { let lastUserIndex = -1 - for (let i = messages.length - 1; i >= 0; i--) { - if (messages[i]?.role === 'user') { + for (let i = renderedMessages.length - 1; i >= 0; i--) { + if (renderedMessages[i]?.role === 'user') { lastUserIndex = i break } } if (lastUserIndex === -1) return null - for (let i = messages.length - 1; i > lastUserIndex; i--) { - const m = messages[i] + for (let i = renderedMessages.length - 1; i > lastUserIndex; i--) { + const m = renderedMessages[i] if (m?.role !== 'assistant') continue const part = m.parts.find( (p): p is StructuredOutputPart => p.type === 'structured-output', @@ -239,7 +270,7 @@ export function useChat< if (part) return part } return null - }, [messages]) + }, [renderedMessages]) const partial = useMemo(() => { if (!activeStructuredPart) return {} as Partial @@ -259,7 +290,7 @@ export function useChat< // structurally narrow across that conditional, so the `as` is the seam. // eslint-disable-next-line no-restricted-syntax -- hook return shape diverges from generic UseChatReturn due to conditional type on TSchema; TS can't structurally narrow return { - messages, + messages: renderedMessages, sendMessage, append, reload, diff --git a/packages/typescript/ai-react/tests/use-chat.test.ts b/packages/typescript/ai-react/tests/use-chat.test.ts index 9a536c1c7..0bbdee039 100644 --- a/packages/typescript/ai-react/tests/use-chat.test.ts +++ b/packages/typescript/ai-react/tests/use-chat.test.ts @@ -2,7 +2,7 @@ import type { ModelMessage, StreamChunk } from '@tanstack/ai' import { EventType } from '@tanstack/ai' import type { SubscribeConnectionAdapter } from '@tanstack/ai-client' import { act, renderHook, waitFor } from '@testing-library/react' -import { useState } from 'react' +import { StrictMode, useState } from 'react' import { describe, expect, it, vi } from 'vitest' import type { UIMessage, UseChatOptions } from '../src/types' import { useChat } from '../src/use-chat' @@ -14,6 +14,14 @@ import { } from './test-utils' describe('useChat', () => { + function createDeferred() { + let resolve!: (value: T) => void + const promise = new Promise((promiseResolve) => { + resolve = promiseResolve + }) + return { promise, resolve } + } + describe('initialization', () => { it('should initialize with default state', () => { const adapter = createMockConnectionAdapter() @@ -59,6 +67,120 @@ describe('useChat', () => { expect(result.current.messages).toEqual(initialMessages) }) + it('should initialize with persisted messages', async () => { + const adapter = createMockConnectionAdapter() + const persistedMessages: UIMessage[] = [ + { + id: 'persisted-1', + role: 'user', + parts: [{ type: 'text', content: 'Persisted' }], + createdAt: new Date(), + }, + ] + const persistence = { + getItem: vi.fn(() => persistedMessages), + setItem: vi.fn(), + removeItem: vi.fn(), + } + + const { result } = renderUseChat({ + connection: adapter, + id: 'persisted-chat', + persistence, + }) + + await waitFor(() => { + expect(result.current.messages).toEqual(persistedMessages) + }) + expect(persistence.getItem).toHaveBeenCalledWith('persisted-chat') + }) + + it('should preserve persisted empty messages over provided initial messages', async () => { + const adapter = createMockConnectionAdapter() + const initialMessages: UIMessage[] = [ + { + id: 'initial-1', + role: 'user', + parts: [{ type: 'text', content: 'Initial' }], + createdAt: new Date(), + }, + ] + const persistence = { + getItem: vi.fn(() => []), + setItem: vi.fn(), + removeItem: vi.fn(), + } + + const { result } = renderUseChat({ + connection: adapter, + id: 'persisted-empty-chat', + initialMessages, + persistence, + }) + + await waitFor(() => { + expect(persistence.getItem).toHaveBeenCalledWith( + 'persisted-empty-chat', + ) + }) + expect(result.current.messages).toEqual([]) + }) + + it('should ignore async persisted messages from a previous id', async () => { + const oldHydration = createDeferred>() + const oldMessages: Array = [ + { + id: 'old-persisted', + role: 'user', + parts: [{ type: 'text', content: 'Old persisted' }], + createdAt: new Date(), + }, + ] + const newMessages: Array = [ + { + id: 'new-persisted', + role: 'user', + parts: [{ type: 'text', content: 'New persisted' }], + createdAt: new Date(), + }, + ] + const persistence = { + getItem: vi.fn((id: string) => + id === 'old-chat' ? oldHydration.promise : newMessages, + ), + setItem: vi.fn(), + removeItem: vi.fn(), + } + + function useChangingChat() { + const [id, setId] = useState('old-chat') + const chat = useChat({ + connection: createMockConnectionAdapter(), + id, + persistence, + }) + + return { ...chat, setId } + } + + const { result } = renderHook(() => useChangingChat()) + + act(() => { + result.current.setId('new-chat') + }) + + await waitFor(() => { + expect(result.current.messages).toEqual(newMessages) + }) + + await act(async () => { + oldHydration.resolve(oldMessages) + await oldHydration.promise + }) + + expect(result.current.messages).toEqual(newMessages) + }) + it('should use provided id', async () => { const chunks = createTextChunks('Response') const adapter = createMockConnectionAdapter({ chunks }) @@ -812,7 +934,71 @@ describe('useChat', () => { }) expect(first).not.toHaveBeenCalled() }) - }) + + it('should ignore user callbacks from an old client after id changes', async () => { + const releaseOldStream = createDeferred() + const oldOnChunk = vi.fn() + const newOnChunk = vi.fn() + const adapter = { + async *connect() { + await releaseOldStream.promise + yield* createTextChunks('stale old client') + }, + } + + const { result, rerender } = renderHook( + (opts: UseChatOptions) => useChat(opts), + { + initialProps: { + connection: adapter, + id: 'old-client', + onChunk: oldOnChunk, + }, + }, + ) + + const sendPromise = result.current.sendMessage('Test') + await waitFor(() => { + expect(result.current.isLoading).toBe(true) + }) + + rerender({ + connection: adapter, + id: 'new-client', + onChunk: newOnChunk, + }) + + releaseOldStream.resolve() + await sendPromise + + expect(oldOnChunk).not.toHaveBeenCalled() + expect(newOnChunk).not.toHaveBeenCalled() + }) + + it('should keep callbacks live across StrictMode effect replay for the same client', async () => { + const onChunk = vi.fn() + const adapter = createMockConnectionAdapter({ + chunks: createTextChunks('strict response'), + }) + + const { result } = renderHook( + () => + useChat({ + connection: adapter, + onChunk, + }), + { wrapper: StrictMode }, + ) + + await act(async () => { + await result.current.sendMessage('Test') + }) + + expect(onChunk).toHaveBeenCalledWith( + expect.objectContaining({ type: EventType.TEXT_MESSAGE_CONTENT }), + ) + }) + }) describe('edge cases and error handling', () => { describe('options changes', () => { @@ -866,7 +1052,7 @@ describe('useChat', () => { }) describe('client recreation', () => { - it('should pass existing messages to new client when id changes in a batched update', async () => { + it('should not pass previous id messages to a new client id without persisted messages', async () => { const connectSpy = vi.fn() const chunks = createTextChunks('Reply') const adapter = createMockConnectionAdapter({ @@ -906,8 +1092,6 @@ describe('useChat', () => { result.current.switchId('client-B') }) - // Send a message through the new client. If the client lost the - // previous messages, the adapter only receives the new message. await act(async () => { await result.current.sendMessage('Follow-up') }) @@ -916,14 +1100,48 @@ describe('useChat', () => { expect(connectSpy).toHaveBeenCalled() }) - // The adapter should receive the previous conversation + new message. - // `connectSpy.mock.calls[0][0]` is the messages array passed to - // adapter.connect — typed via MockConnectionAdapterOptions.onConnect. const sentMessages = connectSpy.mock.calls[0]![0] as Array< ModelMessage | UIMessage > - const userMessages = sentMessages.filter((m) => m.role === 'user') - expect(userMessages.length).toBeGreaterThanOrEqual(2) + const sentText = sentMessages + .flatMap((message) => ('parts' in message ? message.parts : [])) + .filter((part) => part.type === 'text') + .map((part) => part.content) + .join('') + + expect(sentText).toContain('Follow-up') + expect(sentText).not.toContain('Hello') + expect(result.current.messages).not.toEqual(messages) + }) + + it('should return new client messages during the id change render', () => { + const adapter = createMockConnectionAdapter() + const oldMessages: Array = [ + { + id: 'old-message', + role: 'user', + parts: [{ type: 'text', content: 'Old client message' }], + createdAt: new Date(), + }, + ] + + const { result } = renderHook(() => { + const [id, setId] = useState('client-A') + const chat = useChat({ connection: adapter, id }) + return { ...chat, switchId: setId } + }) + + act(() => { + result.current.setMessages(oldMessages) + }) + + expect(result.current.messages).toEqual(oldMessages) + + act(() => { + result.current.switchId('client-B') + }) + + expect(result.current.messages).toEqual([]) }) }) @@ -1716,5 +1934,89 @@ describe('useChat', () => { expect(result.current.isLoading).toBe(false) }) }) + + it('should keep receiving live updates and callbacks after live toggles for the same client', async () => { + const onChunk = vi.fn() + const chunks: Array = [] + // Wrapped in an object so the assignment inside the generator closure + // is visible to TS control-flow analysis. A bare `let` would narrow to + // `null` at the call site, and `?.()` would then strip it to `never`. + const subscriberControl: { wake: (() => void) | null } = { wake: null } + const adapter: SubscribeConnectionAdapter = { + subscribe: async function* (_signal?: AbortSignal) { + while (true) { + if (chunks.length === 0) { + await new Promise((resolve) => { + subscriberControl.wake = resolve + }) + } + const chunk = chunks.shift() + if (chunk) yield chunk + } + }, + send: vi.fn(async () => {}), + } + + const { result, rerender } = renderHook( + ({ live }) => + useChat({ + connection: adapter, + live, + onChunk, + }), + { initialProps: { live: true } }, + ) + + await waitFor(() => { + expect(result.current.isSubscribed).toBe(true) + }) + + rerender({ live: false }) + await waitFor(() => { + expect(result.current.isSubscribed).toBe(false) + }) + + rerender({ live: true }) + await waitFor(() => { + expect(result.current.isSubscribed).toBe(true) + }) + + chunks.push( + { + type: EventType.RUN_STARTED, + runId: 'run-after-toggle', + threadId: 'thread-1', + timestamp: Date.now(), + }, + { + type: EventType.TEXT_MESSAGE_CONTENT, + messageId: 'msg-after-toggle', + timestamp: Date.now(), + delta: 'after toggle', + content: 'after toggle', + }, + { + type: EventType.RUN_FINISHED, + runId: 'run-after-toggle', + threadId: 'thread-1', + timestamp: Date.now(), + }, + ) + subscriberControl.wake?.() + subscriberControl.wake = null + + await waitFor(() => { + expect( + result.current.messages.some((message) => + message.parts.some( + (part) => part.type === 'text' && part.content === 'after toggle', + ), + ), + ).toBe(true) + }) + expect(onChunk).toHaveBeenCalledWith( + expect.objectContaining({ type: EventType.TEXT_MESSAGE_CONTENT }), + ) + }) }) }) diff --git a/packages/typescript/ai-solid/src/use-chat.ts b/packages/typescript/ai-solid/src/use-chat.ts index d2a2b8163..ac6ef1d80 100644 --- a/packages/typescript/ai-solid/src/use-chat.ts +++ b/packages/typescript/ai-solid/src/use-chat.ts @@ -74,6 +74,9 @@ export function useChat< ...(options.initialMessages !== undefined && { initialMessages: options.initialMessages, }), + ...(options.persistence !== undefined && { + persistence: options.persistence, + }), body: options.body, ...(options.forwardedProps !== undefined && { forwardedProps: options.forwardedProps, @@ -120,6 +123,8 @@ export function useChat< // Connection and other options are captured at creation time }, [clientId]) + setMessages(client().getMessages() as Array>) + // Sync body / forwardedProps changes to the client. // Both populate the same wire payload; `forwardedProps` is preferred // and `body` is deprecated but still supported. @@ -134,18 +139,6 @@ export function useChat< }) }) - // Sync initial messages on mount only - // Note: initialMessages are passed to ChatClient constructor, but we also - // set them here to ensure React state is in sync - createEffect(() => { - if (options.initialMessages && options.initialMessages.length > 0) { - // Only set if current messages are empty (initial state) - if (messages().length === 0) { - client().setMessagesManually(options.initialMessages) - } - } - }) // Only run on mount - initialMessages are handled by ChatClient constructor - // Apply initial live mode immediately on hook creation. if (options.live) { client().subscribe() diff --git a/packages/typescript/ai-solid/tests/use-chat.test.ts b/packages/typescript/ai-solid/tests/use-chat.test.ts index f6afaef4e..350fb7fce 100644 --- a/packages/typescript/ai-solid/tests/use-chat.test.ts +++ b/packages/typescript/ai-solid/tests/use-chat.test.ts @@ -55,6 +55,65 @@ describe('useChat', () => { expect(result.current.messages).toEqual(initialMessages) }) + it('should initialize with persisted messages', async () => { + const adapter = createMockConnectionAdapter() + const persistedMessages: UIMessage[] = [ + { + id: 'persisted-1', + role: 'user', + parts: [{ type: 'text', content: 'Persisted' }], + createdAt: new Date(), + }, + ] + const persistence = { + getItem: vi.fn(() => persistedMessages), + setItem: vi.fn(), + removeItem: vi.fn(), + } + + const { result } = renderUseChat({ + connection: adapter, + id: 'persisted-chat', + persistence, + }) + + await waitFor(() => { + expect(result.current.messages).toEqual(persistedMessages) + }) + expect(persistence.getItem).toHaveBeenCalledWith('persisted-chat') + }) + + it('should preserve persisted empty messages over provided initial messages', async () => { + const adapter = createMockConnectionAdapter() + const initialMessages: UIMessage[] = [ + { + id: 'initial-1', + role: 'user', + parts: [{ type: 'text', content: 'Initial' }], + createdAt: new Date(), + }, + ] + const persistence = { + getItem: vi.fn(() => []), + setItem: vi.fn(), + removeItem: vi.fn(), + } + + const { result } = renderUseChat({ + connection: adapter, + id: 'persisted-empty-chat', + initialMessages, + persistence, + }) + + await waitFor(() => { + expect(persistence.getItem).toHaveBeenCalledWith( + 'persisted-empty-chat', + ) + }) + expect(result.current.messages).toEqual([]) + }) + it('should use provided id', async () => { const chunks = createTextChunks('Response') const adapter = createMockConnectionAdapter({ chunks }) diff --git a/packages/typescript/ai-svelte/src/create-chat.svelte.ts b/packages/typescript/ai-svelte/src/create-chat.svelte.ts index a10aef2a3..a3a869a36 100644 --- a/packages/typescript/ai-svelte/src/create-chat.svelte.ts +++ b/packages/typescript/ai-svelte/src/create-chat.svelte.ts @@ -92,6 +92,9 @@ export function createChat< ...(options.initialMessages !== undefined && { initialMessages: options.initialMessages, }), + ...(options.persistence !== undefined && { + persistence: options.persistence, + }), ...(options.body !== undefined && { body: options.body }), ...(options.forwardedProps !== undefined && { forwardedProps: options.forwardedProps, @@ -136,6 +139,8 @@ export function createChat< }, }) + messages = client.getMessages() as Array> + if (options.live) { client.subscribe() } diff --git a/packages/typescript/ai-svelte/tests/use-chat.test.ts b/packages/typescript/ai-svelte/tests/use-chat.test.ts index 1a6c4b777..ff5b846ba 100644 --- a/packages/typescript/ai-svelte/tests/use-chat.test.ts +++ b/packages/typescript/ai-svelte/tests/use-chat.test.ts @@ -54,6 +54,59 @@ describe('createChat', () => { expect(chat.messages[0]!.role).toBe('user') }) + it('should initialize with persisted messages', () => { + const mockConnection = createMockConnectionAdapter({ chunks: [] }) + const persistedMessages = [ + { + id: 'persisted-1', + role: 'user' as const, + parts: [{ type: 'text' as const, content: 'Persisted' }], + createdAt: new Date(), + }, + ] + const persistence = { + getItem: vi.fn(() => persistedMessages), + setItem: vi.fn(), + removeItem: vi.fn(), + } + + const chat = createChat({ + connection: mockConnection, + id: 'persisted-chat', + persistence, + }) + + expect(chat.messages).toEqual(persistedMessages) + expect(persistence.getItem).toHaveBeenCalledWith('persisted-chat') + }) + + it('should let persisted empty arrays override initial messages', () => { + const mockConnection = createMockConnectionAdapter({ chunks: [] }) + const initialMessages = [ + { + id: 'initial-1', + role: 'user' as const, + parts: [{ type: 'text' as const, content: 'Initial' }], + createdAt: new Date(), + }, + ] + const persistence = { + getItem: vi.fn(() => []), + setItem: vi.fn(), + removeItem: vi.fn(), + } + + const chat = createChat({ + connection: mockConnection, + id: 'persisted-chat', + initialMessages, + persistence, + }) + + expect(chat.messages).toEqual([]) + expect(persistence.getItem).toHaveBeenCalledWith('persisted-chat') + }) + it('should have sendMessage method', () => { const mockConnection = createMockConnectionAdapter({ chunks: [] }) diff --git a/packages/typescript/ai-vue/src/use-chat.ts b/packages/typescript/ai-vue/src/use-chat.ts index 0a7eae106..ed1203516 100644 --- a/packages/typescript/ai-vue/src/use-chat.ts +++ b/packages/typescript/ai-vue/src/use-chat.ts @@ -76,6 +76,9 @@ export function useChat< ...(options.initialMessages !== undefined && { initialMessages: options.initialMessages, }), + ...(options.persistence !== undefined && { + persistence: options.persistence, + }), ...(options.body !== undefined && { body: options.body }), ...(options.forwardedProps !== undefined && { forwardedProps: options.forwardedProps, @@ -119,6 +122,8 @@ export function useChat< }, }) + messages.value = client.getMessages() as Array> + // Sync body / forwardedProps changes to the client. // Both populate the same wire payload; `forwardedProps` is preferred // and `body` is deprecated but still supported. diff --git a/packages/typescript/ai-vue/tests/use-chat.test.ts b/packages/typescript/ai-vue/tests/use-chat.test.ts index 6492fb396..0e227afe5 100644 --- a/packages/typescript/ai-vue/tests/use-chat.test.ts +++ b/packages/typescript/ai-vue/tests/use-chat.test.ts @@ -54,6 +54,61 @@ describe('useChat', () => { expect(result.current.messages).toEqual(initialMessages) }) + it('should initialize with persisted messages', async () => { + const adapter = createMockConnectionAdapter() + const persistedMessages: Array = [ + { + id: 'persisted-1', + role: 'user', + parts: [{ type: 'text', content: 'Persisted' }], + createdAt: new Date(), + }, + ] + const persistence = { + getItem: vi.fn(() => persistedMessages), + setItem: vi.fn(), + removeItem: vi.fn(), + } + + const { result } = renderUseChat({ + connection: adapter, + id: 'persisted-chat', + persistence, + }) + await flushPromises() + + expect(result.current.messages).toEqual(persistedMessages) + expect(persistence.getItem).toHaveBeenCalledWith('persisted-chat') + }) + + it('should let persisted empty arrays override initial messages', async () => { + const adapter = createMockConnectionAdapter() + const initialMessages: Array = [ + { + id: 'initial-1', + role: 'user', + parts: [{ type: 'text', content: 'Initial' }], + createdAt: new Date(), + }, + ] + const persistence = { + getItem: vi.fn(() => []), + setItem: vi.fn(), + removeItem: vi.fn(), + } + + const { result } = renderUseChat({ + connection: adapter, + id: 'persisted-chat', + initialMessages, + persistence, + }) + await flushPromises() + + expect(result.current.messages).toEqual([]) + expect(persistence.getItem).toHaveBeenCalledWith('persisted-chat') + }) + it('should use provided id', async () => { const chunks = createTextChunks('Response') const adapter = createMockConnectionAdapter({ chunks }) diff --git a/testing/e2e/src/routes/$provider/$feature.tsx b/testing/e2e/src/routes/$provider/$feature.tsx index f4c7478e4..26b2f4d42 100644 --- a/testing/e2e/src/routes/$provider/$feature.tsx +++ b/testing/e2e/src/routes/$provider/$feature.tsx @@ -2,6 +2,7 @@ import { createFileRoute } from '@tanstack/react-router' import { useState } from 'react' import { fetchServerSentEvents, useChat } from '@tanstack/ai-react' import { clientTools } from '@tanstack/ai-client' +import type { ChatClientPersistence, UIMessage } from '@tanstack/ai-client' import type { Feature, Mode, Provider } from '@/lib/types' import { ALL_PROVIDERS } from '@/lib/types' import { isSupported } from '@/lib/feature-support' @@ -24,6 +25,8 @@ export const Route = createFileRoute('/$provider/$feature')({ testId: typeof search.testId === 'string' ? search.testId : undefined, aimockPort: port != null && !isNaN(port) ? port : undefined, mode: typeof search.mode === 'string' ? (search.mode as Mode) : undefined, + persistence: + search.persistence === 'localStorage' ? 'localStorage' : undefined, } }, }) @@ -42,6 +45,27 @@ const addToCartClient = addToCartToolDef.client((args) => ({ quantity: args.quantity, })) +const localStoragePersistence: ChatClientPersistence = { + getItem: (id) => { + const item = window.localStorage.getItem(id) + return item + ? (JSON.parse(item) as Array).map((message) => ({ + ...message, + createdAt: + typeof message.createdAt === 'string' + ? new Date(message.createdAt) + : message.createdAt, + })) + : null + }, + setItem: (id, messages) => { + window.localStorage.setItem(id, JSON.stringify(messages)) + }, + removeItem: (id) => { + window.localStorage.removeItem(id) + }, +} + function FeaturePage() { const { provider, feature } = Route.useParams() as { provider: Provider @@ -136,7 +160,8 @@ function ChatFeature({ const tools = needsApproval ? clientTools(addToCartClient) : undefined - const { testId, aimockPort } = Route.useSearch() + const { testId, aimockPort, persistence } = Route.useSearch() + const chatId = `e2e-chat-${testId ?? `${provider}-${feature}`}` // Tracks streaming-structured-output observability for e2e tests: // - structuredObject: captured from the terminal CUSTOM event @@ -148,9 +173,12 @@ function ChatFeature({ const { messages, sendMessage, isLoading, addToolApprovalResponse, stop } = useChat({ + id: chatId, connection: fetchServerSentEvents('/api/chat'), tools, body: { provider, feature, testId, aimockPort }, + persistence: + persistence === 'localStorage' ? localStoragePersistence : undefined, onCustomEvent: (eventType, data) => { if (eventType === 'structured-output.complete') { const value = data as { object: unknown; raw: string } | undefined diff --git a/testing/e2e/tests/chat.spec.ts b/testing/e2e/tests/chat.spec.ts index cdec5109f..e83c926af 100644 --- a/testing/e2e/tests/chat.spec.ts +++ b/testing/e2e/tests/chat.spec.ts @@ -24,3 +24,34 @@ for (const provider of providersFor('chat')) { }) }) } + +test.describe('openai chat persistence', () => { + test('persists chat messages across browser reload with localStorage', async ({ + page, + testId, + aimockPort, + }) => { + await page.goto( + `${featureUrl('openai', 'chat', testId, aimockPort)}&persistence=localStorage`, + ) + + await sendMessage(page, '[chat] recommend a guitar') + await waitForResponse(page) + + await expect(page.getByTestId('user-message')).toContainText( + '[chat] recommend a guitar', + ) + await expect(page.getByTestId('assistant-message')).toContainText( + 'Fender Stratocaster', + ) + + await page.reload() + + await expect(page.getByTestId('user-message')).toContainText( + '[chat] recommend a guitar', + ) + await expect(page.getByTestId('assistant-message')).toContainText( + 'Fender Stratocaster', + ) + }) +}) From 6d982223c8bba2328901efcc2a58e12d13184482 Mon Sep 17 00:00:00 2001 From: Alem Tuzlak Date: Sat, 30 May 2026 16:57:48 +0200 Subject: [PATCH 02/10] refactor(ai-client): extract persistence into ChatPersistor class Move all persistence concerns out of ChatClient into a dedicated ChatPersistor: storage orchestration (hydrate / save / remove via an ordered write queue with generation guards) and clear-during-stream chunk suppression (cleared/ignored id tracking). ChatClient now holds a single optional `persistor` and delegates, keeping it focused on streaming and message state. The run-lifecycle methods keep their activeRunIds / session-generating / processing side-effects and call persistor hooks (onRunStarted, onRunSettled, onSessionRunError, resetIgnored, takeRunlessRunId) for the cleared-set bookkeeping. A shared getChunkRunId helper is exported from connection-adapters so both modules use one source. Pure refactor: no public API or behavior change. All ai-client unit tests, framework wrapper tests, and the persistence E2E test pass unchanged; chat-client.js shrinks ~34kB -> ~28kB. --- .../typescript/ai-client/src/chat-client.ts | 302 +++-------------- .../ai-client/src/client-persistor.ts | 312 ++++++++++++++++++ .../ai-client/src/connection-adapters.ts | 12 + 3 files changed, 373 insertions(+), 253 deletions(-) create mode 100644 packages/typescript/ai-client/src/client-persistor.ts diff --git a/packages/typescript/ai-client/src/chat-client.ts b/packages/typescript/ai-client/src/chat-client.ts index b604aa615..83bb0ba21 100644 --- a/packages/typescript/ai-client/src/chat-client.ts +++ b/packages/typescript/ai-client/src/chat-client.ts @@ -6,9 +6,10 @@ import { } from '@tanstack/ai' import { DefaultChatClientEventEmitter } from './events' import { - getInternalRunContextIds, + getChunkRunId, normalizeConnectionAdapter, } from './connection-adapters' +import { ChatPersistor } from './client-persistor' import type { AnyClientTool, ContentPart, @@ -22,7 +23,6 @@ import type { import type { ChatClientEventEmitter } from './events' import type { ChatClientOptions, - ChatClientPersistence, ChatClientState, ConnectionStatus, MessagePart, @@ -36,17 +36,11 @@ export class ChatClient { private connection: SubscribeConnectionAdapter private readonly uniqueId: string private readonly threadId: string - private readonly persistence?: ChatClientPersistence - private skipNextPersist = false - private persistenceGeneration = 0 - private persistenceQueue: Promise = Promise.resolve() - private persistenceQueuePending = false + // All persistence concerns (hydrate / save / clear, plus suppression of late + // chunks after a mid-stream clear) live in ChatPersistor so this class stays + // focused on streaming. Undefined when no `persistence` adapter is configured. + private readonly persistor?: ChatPersistor private currentRunId: string | null = null - private readonly clearedMessageIds = new Set() - private readonly clearedRunIds = new Set() - private readonly ignoredActiveRunIds = new Set() - private readonly clearedToolCallIds = new Set() - private currentRunlessRunId: string | null = null // Track the legacy `body` option and the canonical `forwardedProps` // option as separate slots so that `updateOptions({ forwardedProps })` // doesn't wipe a previously-set `body` (and vice versa). They are @@ -73,7 +67,6 @@ export class ChatClient { private processingResolve: (() => void) | null = null private errorReportedGeneration: number | null = null private streamGeneration = 0 - private messagesGeneration = 0 // Tracks whether a queued checkForContinuation was skipped because // continuationPending was true (chained approval scenario) private continuationSkipped = false @@ -105,7 +98,13 @@ export class ChatClient { constructor(options: ChatClientOptions) { this.uniqueId = options.id || this.generateUniqueId('chat') this.threadId = options.threadId || this.generateUniqueId('thread') - this.persistence = options.persistence + if (options.persistence) { + this.persistor = new ChatPersistor( + options.persistence, + this.uniqueId, + (messages) => this.processor.setMessages(messages), + ) + } // Both `body` (deprecated) and `forwardedProps` populate the AG-UI // `RunAgentInput.forwardedProps` wire field. They are stored // separately so `updateOptions` can replace one without touching the @@ -146,7 +145,7 @@ export class ChatClient { // Create StreamProcessor with event handlers. // Use conditional spreads so we don't pass `undefined` into // `StreamProcessorOptions` fields under `exactOptionalPropertyTypes`. - const persistedMessages = this.getPersistedMessages() + const persistedMessages = this.persistor?.readInitial() const initialMessages = Array.isArray(persistedMessages) ? persistedMessages : options.initialMessages @@ -158,8 +157,7 @@ export class ChatClient { ...(initialMessages ? { initialMessages } : {}), events: { onMessagesChange: (messages: Array) => { - this.messagesGeneration++ - this.persistMessages(messages) + this.persistor?.notifyMessagesChanged(messages) this.callbacksRef.current.onMessagesChange(messages) }, onStreamStart: () => { @@ -298,218 +296,19 @@ export class ChatClient { }) this.events.clientCreated(this.processor.getMessages().length) - this.hydratePersistedMessagesAsync(persistedMessages) - } - - private getPersistedMessages(): - | Array - | null - | undefined - | Promise | null | undefined> { - if (!this.persistence) { - return undefined - } - try { - return this.persistence.getItem(this.uniqueId) - } catch { - return undefined - } - } - - private hydratePersistedMessagesAsync( - persistedMessages: - | Array - | null - | undefined - | Promise | null | undefined>, - ): void { - if (!(persistedMessages instanceof Promise)) { - return - } - - const hydrationGeneration = this.messagesGeneration - persistedMessages - .then((messages) => { - if ( - Array.isArray(messages) && - this.messagesGeneration === hydrationGeneration - ) { - this.processor.setMessages(messages) - } - }) - .catch(() => { - // Persistence adapters are best-effort and must not break chat setup. - }) - } - - private persistMessages(messages: Array): void { - if (this.skipNextPersist) { - this.skipNextPersist = false - return - } - if (!this.persistence) { - return - } - const persistence = this.persistence - const persistenceGeneration = this.persistenceGeneration - const messagesSnapshot = [...messages] - this.runPersistenceOperation(() => { - if (persistenceGeneration !== this.persistenceGeneration) { - return - } - return persistence.setItem(this.uniqueId, messagesSnapshot) - }) - } - - private runPersistenceOperation(operation: () => void | Promise): void { - if (this.persistenceQueuePending) { - const queued = this.persistenceQueue.then(operation).catch(() => { - // Persistence adapters are best-effort and must not break chat updates. - }) - this.persistenceQueue = queued - void queued.finally(() => { - if (this.persistenceQueue === queued) { - this.persistenceQueuePending = false - } - }) - return - } - - try { - const result = operation() - if (result instanceof Promise) { - this.persistenceQueuePending = true - const queued = result.catch(() => { - // Persistence adapters are best-effort and must not break chat updates. - }) - this.persistenceQueue = queued - void queued.finally(() => { - if (this.persistenceQueue === queued) { - this.persistenceQueuePending = false - } - }) - } - } catch { - // Persistence adapters are best-effort and must not break chat updates. - } - } - - private removePersistedMessages(): void { - if (!this.persistence) { - return - } - const persistence = this.persistence - const persistenceGeneration = ++this.persistenceGeneration - this.runPersistenceOperation(() => { - if (persistenceGeneration !== this.persistenceGeneration) { - return - } - return persistence.removeItem(this.uniqueId) - }) - } - - private snapshotClearedStreamState(): void { - if (!this.persistence) return - for (const message of this.processor.getMessages()) { - this.clearedMessageIds.add(message.id) - } - for (const runId of this.activeRunIds) { - this.clearedRunIds.add(runId) - this.ignoredActiveRunIds.add(runId) - } - if (this.currentRunId) { - this.clearedRunIds.add(this.currentRunId) - this.ignoredActiveRunIds.add(this.currentRunId) - } - } - - private shouldIgnoreChunk(chunk: StreamChunk): boolean { - if (!this.persistence) return false - - const runId = this.getChunkRunId(chunk) - if (runId && this.clearedRunIds.has(runId)) { - if (chunk.type === 'RUN_STARTED') { - this.ignoredActiveRunIds.add(runId) - this.currentRunlessRunId = runId - } - this.markIgnoredChunkIds(chunk) - return true - } - - if (runId && this.ignoredActiveRunIds.has(runId)) { - this.markIgnoredChunkIds(chunk) - return true - } - - if (this.isRunlessChunkFromIgnoredRun(chunk)) { - this.markIgnoredChunkIds(chunk) - return true - } - - const toolCallId = (chunk as { toolCallId?: string }).toolCallId - if (toolCallId && this.clearedToolCallIds.has(toolCallId)) { - return true - } - - const parentMessageId = (chunk as { parentMessageId?: string }) - .parentMessageId - if (parentMessageId && this.clearedMessageIds.has(parentMessageId)) { - if (toolCallId) { - this.clearedToolCallIds.add(toolCallId) - } - return true - } - - const messageId = (chunk as { messageId?: string }).messageId - if (!messageId) { - return false - } - if (this.clearedMessageIds.has(messageId)) { - return true - } - - return false - } - - private markIgnoredChunkIds(chunk: StreamChunk): void { - const messageId = (chunk as { messageId?: string }).messageId - if (messageId) { - this.clearedMessageIds.add(messageId) - } - const toolCallId = (chunk as { toolCallId?: string }).toolCallId - if (toolCallId) { - this.clearedToolCallIds.add(toolCallId) - } - } - - private isRunlessChunkFromIgnoredRun(chunk: StreamChunk): boolean { - const runId = this.getChunkRunId(chunk) - if (runId || !this.currentRunlessRunId) return false - if ( - !this.ignoredActiveRunIds.has(this.currentRunlessRunId) && - !this.clearedRunIds.has(this.currentRunlessRunId) - ) { - return false - } - return ( - chunk.type === 'TEXT_MESSAGE_START' || - chunk.type === 'TEXT_MESSAGE_CONTENT' || - chunk.type === 'TOOL_CALL_START' || - chunk.type === 'TOOL_CALL_ARGS' || - chunk.type === 'TOOL_CALL_END' || - chunk.type === 'TOOL_CALL_RESULT' || - chunk.type === 'MESSAGES_SNAPSHOT' || - chunk.type === 'RUN_ERROR' - ) + this.persistor?.hydrateAsync(persistedMessages) } + /** + * Drain a runId-less RUN_ERROR that belongs to a cleared run the client is + * still tracking. The persistor owns the cleared-run bookkeeping; the client + * owns the active-run / session / processing state. + */ private drainIgnoredRunlessChunk(chunk: StreamChunk): void { - if (!this.currentRunlessRunId || chunk.type !== 'RUN_ERROR') return - const runId = this.currentRunlessRunId + if (chunk.type !== 'RUN_ERROR') return + const runId = this.persistor?.takeRunlessRunId() + if (!runId) return this.activeRunIds.delete(runId) - this.ignoredActiveRunIds.delete(runId) - this.clearedRunIds.delete(runId) - this.currentRunlessRunId = null this.setSessionGenerating(this.activeRunIds.size > 0) this.resolveProcessing() } @@ -519,9 +318,9 @@ export class ChatClient { options?: { resolveProcessing?: boolean }, ): void { if (chunk.type === 'RUN_STARTED') { - const chunkRunId = this.getChunkRunId(chunk) ?? chunk.runId + const chunkRunId = getChunkRunId(chunk) ?? chunk.runId this.activeRunIds.add(chunkRunId) - this.currentRunlessRunId = chunkRunId + this.persistor?.onRunStarted(chunkRunId) this.setSessionGenerating(true) return } @@ -530,19 +329,14 @@ export class ChatClient { return } - const runId = this.getChunkRunId(chunk) + const runId = getChunkRunId(chunk) if (runId) { this.activeRunIds.delete(runId) - this.ignoredActiveRunIds.delete(runId) - this.clearedRunIds.delete(runId) - if (this.currentRunlessRunId === runId) { - this.currentRunlessRunId = - this.ignoredActiveRunIds.values().next().value ?? null - } + this.persistor?.onRunSettled(runId) } else if (chunk.type === 'RUN_ERROR') { + // RUN_ERROR without runId is a session-level error; clear all runs. this.activeRunIds.clear() - this.ignoredActiveRunIds.clear() - this.currentRunlessRunId = null + this.persistor?.onSessionRunError() } this.setSessionGenerating(this.activeRunIds.size > 0) if (options?.resolveProcessing !== false) { @@ -554,10 +348,6 @@ export class ChatClient { return `${prefix}-${Date.now()}-${Math.random().toString(36).substring(7)}` } - private getChunkRunId(chunk: StreamChunk): string | undefined { - return (chunk as { runId?: string }).runId ?? getInternalRunContextIds(chunk)?.runId - } - private setIsLoading(isLoading: boolean): void { this.isLoading = isLoading this.callbacksRef.current.onLoadingChange(isLoading) @@ -587,7 +377,7 @@ export class ChatClient { private resetSessionGenerating(): void { this.activeRunIds.clear() - this.ignoredActiveRunIds.clear() + this.persistor?.resetIgnored() this.setSessionGenerating(false) } @@ -685,10 +475,10 @@ export class ChatClient { if (this.connectionStatus === 'connecting') { this.setConnectionStatus('connected') } - const shouldIgnore = this.shouldIgnoreChunk(chunk) + const shouldIgnore = this.persistor?.shouldIgnoreChunk(chunk) ?? false if (shouldIgnore) { if (chunk.type === 'RUN_FINISHED' || chunk.type === 'RUN_ERROR') { - if (this.getChunkRunId(chunk)) { + if (getChunkRunId(chunk)) { this.updateRunLifecycle(chunk, { resolveProcessing: false }) } else { this.drainIgnoredRunlessChunk(chunk) @@ -1101,18 +891,24 @@ export class ChatClient { * Clear all messages */ clear(): void { - if (this.persistence) { - this.snapshotClearedStreamState() - } - if (this.persistence && this.isLoading) { - this.cancelInFlightStream({ setReadyStatus: true }) - this.resetSessionGenerating() - } else if (this.persistence && this.activeRunIds.size > 0) { - this.resetSessionGenerating() + if (this.persistor) { + this.persistor.snapshotClear({ + messages: this.processor.getMessages(), + activeRunIds: this.activeRunIds, + currentRunId: this.currentRunId, + }) + if (this.isLoading) { + this.cancelInFlightStream({ setReadyStatus: true }) + this.resetSessionGenerating() + } else if (this.activeRunIds.size > 0) { + this.resetSessionGenerating() + } + // Suppress persisting the empty snapshot that clearMessages emits, then + // remove the stored conversation outright. + this.persistor.beginClear() } - this.skipNextPersist = true this.processor.clearMessages() - this.removePersistedMessages() + this.persistor?.remove() this.setError(undefined) this.events.messagesCleared() } diff --git a/packages/typescript/ai-client/src/client-persistor.ts b/packages/typescript/ai-client/src/client-persistor.ts new file mode 100644 index 000000000..fcc050491 --- /dev/null +++ b/packages/typescript/ai-client/src/client-persistor.ts @@ -0,0 +1,312 @@ +import { getChunkRunId } from './connection-adapters' +import type { StreamChunk } from '@tanstack/ai' +import type { ChatClientPersistence, UIMessage } from './types' + +/** + * Encapsulates everything persistence-related for `ChatClient` so the client + * itself stays focused on streaming and message state. + * + * Two responsibilities live here: + * + * 1. **Storage orchestration** — hydrate from `getItem(id)` on creation, save to + * `setItem(id, messages)` on every change through an ordered write queue, and + * `removeItem(id)` on clear. A generation counter discards stale writes when a + * removal or a newer conversation supersedes an in-flight async operation. + * 2. **Clear-during-stream suppression** — when a conversation is cleared while a + * stream is still producing, late chunks for the cleared run(s) must not + * repopulate the now-empty state. The persistor tracks the cleared ids and + * decides, per chunk, whether the client should ignore it. + * + * All adapter calls are best-effort: a throwing or rejecting adapter is swallowed + * so storage problems never break the chat. + */ +export class ChatPersistor { + // --- storage queue state --- + private skipNextPersist = false + private generation = 0 + private queue: Promise = Promise.resolve() + private queuePending = false + // Bumped on every message change; lets an in-flight async hydration detect + // that the message list moved on and avoid clobbering it. + private messagesGeneration = 0 + + // --- clear-during-stream suppression state --- + private readonly clearedMessageIds = new Set() + private readonly clearedRunIds = new Set() + private readonly ignoredActiveRunIds = new Set() + private readonly clearedToolCallIds = new Set() + private currentRunlessRunId: string | null = null + + constructor( + private readonly adapter: ChatClientPersistence, + private readonly id: string, + private readonly applyMessages: (messages: Array) => void, + ) {} + + // --------------------------------------------------------------------------- + // Storage orchestration + // --------------------------------------------------------------------------- + + /** + * Synchronously read the persisted messages for constructor-time hydration. + * Returns the raw `getItem` result (which may be a promise for async stores). + */ + readInitial(): + | Array + | null + | undefined + | Promise | null | undefined> { + try { + return this.adapter.getItem(this.id) + } catch { + return undefined + } + } + + /** + * Apply messages from an async `getItem` once it resolves, unless the message + * list has already changed since hydration began. + */ + hydrateAsync( + persistedMessages: + | Array + | null + | undefined + | Promise | null | undefined>, + ): void { + if (!(persistedMessages instanceof Promise)) { + return + } + + const hydrationGeneration = this.messagesGeneration + persistedMessages + .then((messages) => { + if ( + Array.isArray(messages) && + this.messagesGeneration === hydrationGeneration + ) { + this.applyMessages(messages) + } + }) + .catch(() => { + // Persistence adapters are best-effort and must not break chat setup. + }) + } + + /** + * Record a message-list change and queue a `setItem` write for it. Skips a + * single write after {@link beginClear} so the clear's empty snapshot isn't + * persisted between `clearMessages()` and {@link remove}. + */ + notifyMessagesChanged(messages: Array): void { + this.messagesGeneration++ + if (this.skipNextPersist) { + this.skipNextPersist = false + return + } + const generation = this.generation + const messagesSnapshot = [...messages] + this.runOperation(() => { + if (generation !== this.generation) { + return + } + return this.adapter.setItem(this.id, messagesSnapshot) + }) + } + + /** Remove the persisted conversation. Invalidates any queued writes. */ + remove(): void { + const generation = ++this.generation + this.runOperation(() => { + if (generation !== this.generation) { + return + } + return this.adapter.removeItem(this.id) + }) + } + + private runOperation(operation: () => void | Promise): void { + if (this.queuePending) { + const queued = this.queue.then(operation).catch(() => { + // Persistence adapters are best-effort and must not break chat updates. + }) + this.queue = queued + void queued.finally(() => { + if (this.queue === queued) { + this.queuePending = false + } + }) + return + } + + try { + const result = operation() + if (result instanceof Promise) { + this.queuePending = true + const queued = result.catch(() => { + // Persistence adapters are best-effort and must not break chat updates. + }) + this.queue = queued + void queued.finally(() => { + if (this.queue === queued) { + this.queuePending = false + } + }) + } + } catch { + // Persistence adapters are best-effort and must not break chat updates. + } + } + + // --------------------------------------------------------------------------- + // Clear-during-stream suppression + // --------------------------------------------------------------------------- + + /** + * Capture the message/run ids that exist at the moment of a clear so chunks + * still arriving for them can be ignored. + */ + snapshotClear(context: { + messages: Array + activeRunIds: Set + currentRunId: string | null + }): void { + for (const message of context.messages) { + this.clearedMessageIds.add(message.id) + } + for (const runId of context.activeRunIds) { + this.clearedRunIds.add(runId) + this.ignoredActiveRunIds.add(runId) + } + if (context.currentRunId) { + this.clearedRunIds.add(context.currentRunId) + this.ignoredActiveRunIds.add(context.currentRunId) + } + } + + /** Mark that the next persisted message change (the clear itself) is skipped. */ + beginClear(): void { + this.skipNextPersist = true + } + + /** Whether a chunk belongs to cleared state and should not be processed. */ + shouldIgnoreChunk(chunk: StreamChunk): boolean { + const runId = getChunkRunId(chunk) + if (runId && this.clearedRunIds.has(runId)) { + if (chunk.type === 'RUN_STARTED') { + this.ignoredActiveRunIds.add(runId) + this.currentRunlessRunId = runId + } + this.markIgnoredChunkIds(chunk) + return true + } + + if (runId && this.ignoredActiveRunIds.has(runId)) { + this.markIgnoredChunkIds(chunk) + return true + } + + if (this.isRunlessChunkFromIgnoredRun(chunk)) { + this.markIgnoredChunkIds(chunk) + return true + } + + const toolCallId = (chunk as { toolCallId?: string }).toolCallId + if (toolCallId && this.clearedToolCallIds.has(toolCallId)) { + return true + } + + const parentMessageId = (chunk as { parentMessageId?: string }) + .parentMessageId + if (parentMessageId && this.clearedMessageIds.has(parentMessageId)) { + if (toolCallId) { + this.clearedToolCallIds.add(toolCallId) + } + return true + } + + const messageId = (chunk as { messageId?: string }).messageId + if (!messageId) { + return false + } + if (this.clearedMessageIds.has(messageId)) { + return true + } + + return false + } + + /** + * The owning client calls this when a run starts so runless content chunks + * (adapters that omit `runId` on content events) can be attributed to it. + */ + onRunStarted(runId: string): void { + this.currentRunlessRunId = runId + } + + /** Forget a settled run, advancing the runless pointer to another ignored run. */ + onRunSettled(runId: string): void { + this.ignoredActiveRunIds.delete(runId) + this.clearedRunIds.delete(runId) + if (this.currentRunlessRunId === runId) { + this.currentRunlessRunId = + this.ignoredActiveRunIds.values().next().value ?? null + } + } + + /** A session-level (runId-less) RUN_ERROR clears all ignored-run tracking. */ + onSessionRunError(): void { + this.ignoredActiveRunIds.clear() + this.currentRunlessRunId = null + } + + /** Clear the ignored-active-run markers (mirrors a session-generating reset). */ + resetIgnored(): void { + this.ignoredActiveRunIds.clear() + } + + /** + * Consume the current runless run id (if any), forgetting it. Used when an + * ignored, runId-less RUN_ERROR drains the run the client is still tracking. + */ + takeRunlessRunId(): string | null { + const runId = this.currentRunlessRunId + if (!runId) return null + this.ignoredActiveRunIds.delete(runId) + this.clearedRunIds.delete(runId) + this.currentRunlessRunId = null + return runId + } + + private markIgnoredChunkIds(chunk: StreamChunk): void { + const messageId = (chunk as { messageId?: string }).messageId + if (messageId) { + this.clearedMessageIds.add(messageId) + } + const toolCallId = (chunk as { toolCallId?: string }).toolCallId + if (toolCallId) { + this.clearedToolCallIds.add(toolCallId) + } + } + + private isRunlessChunkFromIgnoredRun(chunk: StreamChunk): boolean { + const runId = getChunkRunId(chunk) + if (runId || !this.currentRunlessRunId) return false + if ( + !this.ignoredActiveRunIds.has(this.currentRunlessRunId) && + !this.clearedRunIds.has(this.currentRunlessRunId) + ) { + return false + } + return ( + chunk.type === 'TEXT_MESSAGE_START' || + chunk.type === 'TEXT_MESSAGE_CONTENT' || + chunk.type === 'TOOL_CALL_START' || + chunk.type === 'TOOL_CALL_ARGS' || + chunk.type === 'TOOL_CALL_END' || + chunk.type === 'TOOL_CALL_RESULT' || + chunk.type === 'MESSAGES_SNAPSHOT' || + chunk.type === 'RUN_ERROR' + ) + } +} diff --git a/packages/typescript/ai-client/src/connection-adapters.ts b/packages/typescript/ai-client/src/connection-adapters.ts index 6be665dcf..bb303d44f 100644 --- a/packages/typescript/ai-client/src/connection-adapters.ts +++ b/packages/typescript/ai-client/src/connection-adapters.ts @@ -16,6 +16,18 @@ export function getInternalRunContextIds( return chunkRunContextIds.get(chunk) } +/** + * Resolve a chunk's run id, preferring the value on the chunk itself and + * falling back to the internal run-context map populated by the connection + * adapter for events whose wire schema omits `runId`. + */ +export function getChunkRunId(chunk: StreamChunk): string | undefined { + return ( + (chunk as { runId?: string }).runId ?? + getInternalRunContextIds(chunk)?.runId + ) +} + /** * Merge custom headers into request headers */ From 92c44c9329d6df7bae383c885a118382726d6796 Mon Sep 17 00:00:00 2001 From: Alem Tuzlak Date: Sat, 30 May 2026 17:18:42 +0200 Subject: [PATCH 03/10] test(ai-client): add ChatPersistor unit coverage Add a dedicated isolation suite for the extracted ChatPersistor (31 tests) covering edge cases that are awkward to trigger through the full ChatClient: - readInitial: sync read, async promise passthrough, throwing adapter - hydrateAsync: array / null / undefined / non-promise / rejection, plus the generation guard that drops a stale async hydration after a local change - notifyMessagesChanged: persists a snapshot (not the live array), skips exactly one write after beginClear, swallows sync errors, runs async writes FIFO with no overlap, and isolates rejections - remove: removeItem, swallowed errors, and generation invalidation of a queued write a removal supersedes - shouldIgnoreChunk: every branch (cleared message id, cleared run id, in-flight currentRunId, runless content/snapshot, parentMessageId with toolCallId memoization) - run lifecycle hooks: onRunSettled forget + runless-pointer advance, onSessionRunError, takeRunlessRunId, resetIgnored (partial reset), and the guard that a non-cleared run is never suppressed Adds shared createUIMessage / createMockPersistence helpers to test-utils. Full ai-client suite: 279 passing. --- .../ai-client/tests/client-persistor.test.ts | 471 ++++++++++++++++++ .../typescript/ai-client/tests/test-utils.ts | 29 +- 2 files changed, 499 insertions(+), 1 deletion(-) create mode 100644 packages/typescript/ai-client/tests/client-persistor.test.ts diff --git a/packages/typescript/ai-client/tests/client-persistor.test.ts b/packages/typescript/ai-client/tests/client-persistor.test.ts new file mode 100644 index 000000000..a600fc44c --- /dev/null +++ b/packages/typescript/ai-client/tests/client-persistor.test.ts @@ -0,0 +1,471 @@ +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { EventType } from '@tanstack/ai' +import { ChatPersistor } from '../src/client-persistor' +import { createMockPersistence, createUIMessage } from './test-utils' +import type { StreamChunk } from '@tanstack/ai' +import type { ChatClientPersistence, UIMessage } from '../src/types' + +const CHAT_ID = 'chat-1' + +/** Resolve after pending micro- and macro-tasks have drained. */ +function flushAsync(): Promise { + return new Promise((resolve) => setTimeout(resolve, 0)) +} + +/** A promise with externally accessible resolve/reject. */ +function createDeferred() { + let resolve!: (value: T) => void + let reject!: (reason?: unknown) => void + const promise = new Promise((res, rej) => { + resolve = res + reject = rej + }) + return { promise, resolve, reject } +} + +/** + * Build a persistor wired to `adapter` with a spy `applyMessages` callback so + * tests can assert what (if anything) gets re-applied to the client. + */ +function createPersistor(adapter: ChatClientPersistence, id: string = CHAT_ID) { + const applyMessages = vi.fn<(messages: Array) => void>() + const persistor = new ChatPersistor(adapter, id, applyMessages) + return { persistor, applyMessages } +} + +// --- typed StreamChunk fixtures ------------------------------------------- +// These satisfy the StreamChunk union directly (no casts). The persistor only +// ever reads `type` and the relevant id field off a chunk, and unit-test chunks +// are never run through the connection adapter's run-context map, so the run id +// must live on the chunk itself (only RUN_STARTED/RUN_FINISHED carry one). + +function runStarted(runId: string, threadId: string = 'thread-1'): StreamChunk { + return { type: EventType.RUN_STARTED, threadId, runId } +} + +function runFinished(runId: string, threadId: string = 'thread-1'): StreamChunk { + return { type: EventType.RUN_FINISHED, threadId, runId } +} + +function runError(message: string = 'boom'): StreamChunk { + return { type: EventType.RUN_ERROR, message } +} + +function textContent(messageId: string, delta: string = 'hi'): StreamChunk { + return { type: EventType.TEXT_MESSAGE_CONTENT, messageId, delta } +} + +function toolCallStart( + toolCallId: string, + parentMessageId?: string, +): StreamChunk { + return { + type: EventType.TOOL_CALL_START, + toolCallId, + toolCallName: 'tool', + toolName: 'tool', + ...(parentMessageId ? { parentMessageId } : {}), + } +} + +function messagesSnapshot(): StreamChunk { + return { type: EventType.MESSAGES_SNAPSHOT, messages: [] } +} + +describe('ChatPersistor', () => { + describe('readInitial', () => { + it('returns the stored messages from a synchronous getItem', () => { + const stored = [createUIMessage('m-1')] + const adapter = createMockPersistence(stored) + const { persistor } = createPersistor(adapter) + + expect(persistor.readInitial()).toBe(stored) + expect(adapter.getItem).toHaveBeenCalledWith(CHAT_ID) + }) + + it('returns the promise from an asynchronous getItem', () => { + const stored = [createUIMessage('m-1')] + const adapter = createMockPersistence() + adapter.getItem = vi.fn(() => Promise.resolve(stored)) + const { persistor } = createPersistor(adapter) + + expect(persistor.readInitial()).toBeInstanceOf(Promise) + }) + + it('swallows a throwing getItem and returns undefined', () => { + const adapter = createMockPersistence() + adapter.getItem = vi.fn(() => { + throw new Error('storage unavailable') + }) + const { persistor } = createPersistor(adapter) + + expect(persistor.readInitial()).toBeUndefined() + }) + }) + + describe('hydrateAsync', () => { + it('applies messages once the promise resolves to an array', async () => { + const stored = [createUIMessage('persisted-1')] + const { persistor, applyMessages } = createPersistor( + createMockPersistence(), + ) + + persistor.hydrateAsync(Promise.resolve(stored)) + await flushAsync() + + expect(applyMessages).toHaveBeenCalledWith(stored) + }) + + it.each([ + ['null', null], + ['undefined', undefined], + ])('does not apply when the promise resolves to %s', async (_label, value) => { + const { persistor, applyMessages } = createPersistor( + createMockPersistence(), + ) + + persistor.hydrateAsync(Promise.resolve(value)) + await flushAsync() + + expect(applyMessages).not.toHaveBeenCalled() + }) + + it('does nothing for a synchronous (non-promise) value', async () => { + const { persistor, applyMessages } = createPersistor( + createMockPersistence(), + ) + + persistor.hydrateAsync([createUIMessage('m-1')]) + await flushAsync() + + expect(applyMessages).not.toHaveBeenCalled() + }) + + it('does not apply if messages changed before hydration resolves', async () => { + const deferred = createDeferred>() + const { persistor, applyMessages } = createPersistor( + createMockPersistence(), + ) + + persistor.hydrateAsync(deferred.promise) + // A local change lands before the slow getItem resolves. + persistor.notifyMessagesChanged([createUIMessage('local-1')]) + deferred.resolve([createUIMessage('persisted-1')]) + await flushAsync() + + expect(applyMessages).not.toHaveBeenCalled() + }) + + it('swallows a rejected hydration promise', async () => { + const { persistor, applyMessages } = createPersistor( + createMockPersistence(), + ) + + persistor.hydrateAsync(Promise.reject(new Error('read failed'))) + await flushAsync() + + expect(applyMessages).not.toHaveBeenCalled() + }) + }) + + describe('notifyMessagesChanged', () => { + it('persists the messages on change', () => { + const adapter = createMockPersistence() + const { persistor } = createPersistor(adapter) + const messages = [createUIMessage('m-1')] + + persistor.notifyMessagesChanged(messages) + + expect(adapter.setItem).toHaveBeenCalledWith(CHAT_ID, messages) + }) + + it('persists a snapshot, not the live array reference', () => { + const adapter = createMockPersistence() + const { persistor } = createPersistor(adapter) + const messages = [createUIMessage('m-1')] + + persistor.notifyMessagesChanged(messages) + messages.push(createUIMessage('m-2')) + + const persisted = vi.mocked(adapter.setItem).mock.calls[0]?.[1] + expect(persisted).toHaveLength(1) + }) + + it('skips exactly one write after beginClear, then resumes', () => { + const adapter = createMockPersistence() + const { persistor } = createPersistor(adapter) + + persistor.notifyMessagesChanged([createUIMessage('m-1')]) + persistor.beginClear() + persistor.notifyMessagesChanged([]) // the clear's empty snapshot — skipped + persistor.notifyMessagesChanged([createUIMessage('m-2')]) + + expect(adapter.setItem).toHaveBeenCalledTimes(2) + }) + + it('swallows a synchronous setItem error', () => { + const adapter = createMockPersistence() + adapter.setItem = vi.fn(() => { + throw new Error('quota exceeded') + }) + const { persistor } = createPersistor(adapter) + + expect(() => + persistor.notifyMessagesChanged([createUIMessage('m-1')]), + ).not.toThrow() + }) + + it('runs async writes sequentially (FIFO, no overlap)', async () => { + const first = createDeferred() + const second = createDeferred() + const adapter = createMockPersistence() + adapter.setItem = vi + .fn() + .mockReturnValueOnce(first.promise) + .mockReturnValueOnce(second.promise) + const { persistor } = createPersistor(adapter) + + persistor.notifyMessagesChanged([createUIMessage('a')]) + persistor.notifyMessagesChanged([createUIMessage('b')]) + // Second write is queued behind the first and must not have started yet. + expect(adapter.setItem).toHaveBeenCalledTimes(1) + + first.resolve() + await flushAsync() + + expect(adapter.setItem).toHaveBeenCalledTimes(2) + expect(vi.mocked(adapter.setItem).mock.calls[1]?.[1]).toEqual([ + createUIMessage('b'), + ]) + }) + + it('keeps writing after an async write rejects', async () => { + const adapter = createMockPersistence() + adapter.setItem = vi + .fn() + .mockRejectedValueOnce(new Error('transient')) + .mockResolvedValue(undefined) + const { persistor } = createPersistor(adapter) + + persistor.notifyMessagesChanged([createUIMessage('a')]) + persistor.notifyMessagesChanged([createUIMessage('b')]) + await flushAsync() + + expect(adapter.setItem).toHaveBeenCalledTimes(2) + }) + }) + + describe('remove', () => { + it('removes the persisted conversation', () => { + const adapter = createMockPersistence() + const { persistor } = createPersistor(adapter) + + persistor.remove() + + expect(adapter.removeItem).toHaveBeenCalledWith(CHAT_ID) + }) + + it('swallows a synchronous removeItem error', () => { + const adapter = createMockPersistence() + adapter.removeItem = vi.fn(() => { + throw new Error('locked') + }) + const { persistor } = createPersistor(adapter) + + expect(() => persistor.remove()).not.toThrow() + }) + + it('invalidates a queued write that a removal supersedes', async () => { + const inFlight = createDeferred() + const adapter = createMockPersistence() + adapter.setItem = vi + .fn() + .mockReturnValueOnce(inFlight.promise) // write A (in flight) + .mockResolvedValue(undefined) // write B (would run later) + const { persistor } = createPersistor(adapter) + + persistor.notifyMessagesChanged([createUIMessage('a')]) // starts, in flight + persistor.notifyMessagesChanged([createUIMessage('b')]) // queued behind A + persistor.remove() // bumps generation, queued behind B + + inFlight.resolve() + await flushAsync() + + // B's generation was superseded by remove(), so only A was written. + expect(adapter.setItem).toHaveBeenCalledTimes(1) + expect(adapter.removeItem).toHaveBeenCalledTimes(1) + }) + }) + + describe('clear-during-stream suppression', () => { + it('does not ignore chunks before a clear', () => { + const { persistor } = createPersistor(createMockPersistence()) + + expect(persistor.shouldIgnoreChunk(runStarted('run-1'))).toBe(false) + expect(persistor.shouldIgnoreChunk(textContent('msg-1'))).toBe(false) + }) + + it('ignores chunks for a message id captured at clear time', () => { + const { persistor } = createPersistor(createMockPersistence()) + persistor.snapshotClear({ + messages: [createUIMessage('msg-1')], + activeRunIds: new Set(), + currentRunId: null, + }) + + expect(persistor.shouldIgnoreChunk(textContent('msg-1'))).toBe(true) + expect(persistor.shouldIgnoreChunk(textContent('msg-other'))).toBe(false) + }) + + it('ignores a RUN_STARTED for an active run captured at clear time', () => { + const { persistor } = createPersistor(createMockPersistence()) + persistor.snapshotClear({ + messages: [], + activeRunIds: new Set(['run-1']), + currentRunId: null, + }) + + expect(persistor.shouldIgnoreChunk(runStarted('run-1'))).toBe(true) + }) + + it('ignores the in-flight currentRunId captured at clear time', () => { + const { persistor } = createPersistor(createMockPersistence()) + persistor.snapshotClear({ + messages: [], + activeRunIds: new Set(), + currentRunId: 'run-current', + }) + + expect(persistor.shouldIgnoreChunk(runStarted('run-current'))).toBe(true) + }) + + it('ignores runless content chunks belonging to a cleared run', () => { + const { persistor } = createPersistor(createMockPersistence()) + persistor.snapshotClear({ + messages: [], + activeRunIds: new Set(['run-1']), + currentRunId: null, + }) + // The cleared run starts again, pinning it as the current runless run. + persistor.shouldIgnoreChunk(runStarted('run-1')) + + expect(persistor.shouldIgnoreChunk(textContent('late-msg'))).toBe(true) + expect(persistor.shouldIgnoreChunk(messagesSnapshot())).toBe(true) + }) + + it('ignores tool chunks by cleared parentMessageId and remembers the toolCallId', () => { + const { persistor } = createPersistor(createMockPersistence()) + persistor.snapshotClear({ + messages: [createUIMessage('parent-msg')], + activeRunIds: new Set(), + currentRunId: null, + }) + + // First seen via its parent message id... + expect( + persistor.shouldIgnoreChunk(toolCallStart('tc-1', 'parent-msg')), + ).toBe(true) + // ...then a later chunk for the same tool call (no parent) is still ignored. + expect(persistor.shouldIgnoreChunk(toolCallStart('tc-1'))).toBe(true) + }) + }) + + describe('run lifecycle hooks', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('forgets a settled run so a later run with the same id is not ignored', () => { + const { persistor } = createPersistor(createMockPersistence()) + persistor.snapshotClear({ + messages: [], + activeRunIds: new Set(['run-1']), + currentRunId: null, + }) + expect(persistor.shouldIgnoreChunk(runStarted('run-1'))).toBe(true) + + persistor.onRunSettled('run-1') + + expect(persistor.shouldIgnoreChunk(runStarted('run-1'))).toBe(false) + }) + + it('stops ignoring runless chunks after a session-level run error', () => { + const { persistor } = createPersistor(createMockPersistence()) + persistor.snapshotClear({ + messages: [], + activeRunIds: new Set(['run-1']), + currentRunId: null, + }) + persistor.shouldIgnoreChunk(runStarted('run-1')) + expect(persistor.shouldIgnoreChunk(textContent('late-1'))).toBe(true) + + persistor.onSessionRunError() + + // A fresh runless chunk is no longer attributed to the cleared run. (An + // already-seen message id stays ignored — shouldIgnoreChunk records it.) + expect(persistor.shouldIgnoreChunk(textContent('late-2'))).toBe(false) + }) + + it('takeRunlessRunId returns and clears the current runless run, then null', () => { + const { persistor } = createPersistor(createMockPersistence()) + persistor.snapshotClear({ + messages: [], + activeRunIds: new Set(['run-1']), + currentRunId: null, + }) + persistor.shouldIgnoreChunk(runStarted('run-1')) + + expect(persistor.takeRunlessRunId()).toBe('run-1') + expect(persistor.takeRunlessRunId()).toBeNull() + // The drained run is no longer pinned, so runless chunks aren't ignored. + expect(persistor.shouldIgnoreChunk(textContent('late'))).toBe(false) + }) + + it('resetIgnored only clears active-run markers, not cleared ids', () => { + const { persistor } = createPersistor(createMockPersistence()) + persistor.snapshotClear({ + messages: [createUIMessage('msg-1')], + activeRunIds: new Set(['run-1']), + currentRunId: null, + }) + + persistor.resetIgnored() + + // Cleared run/message ids survive a reset (they outlive the active run). + expect(persistor.shouldIgnoreChunk(runStarted('run-1'))).toBe(true) + expect(persistor.shouldIgnoreChunk(textContent('msg-1'))).toBe(true) + }) + + it('does not ignore an unrelated runId-less run error', () => { + const { persistor } = createPersistor(createMockPersistence()) + + expect(persistor.shouldIgnoreChunk(runError())).toBe(false) + }) + + it('onRunStarted for a non-cleared run does not suppress runless chunks', () => { + const { persistor } = createPersistor(createMockPersistence()) + + // Pinning a live (never-cleared) run must not start swallowing its output. + persistor.onRunStarted('run-live') + + expect(persistor.shouldIgnoreChunk(textContent('content'))).toBe(false) + }) + + it('advances the runless pointer when the active run finishes', () => { + const { persistor } = createPersistor(createMockPersistence()) + persistor.snapshotClear({ + messages: [], + activeRunIds: new Set(['run-1', 'run-2']), + currentRunId: null, + }) + persistor.shouldIgnoreChunk(runStarted('run-1')) + persistor.shouldIgnoreChunk(runStarted('run-2')) + + // run-2 was the pinned runless run; settling it falls back to run-1. + persistor.onRunSettled('run-2') + + expect(persistor.shouldIgnoreChunk(runFinished('run-2'))).toBe(false) + expect(persistor.shouldIgnoreChunk(textContent('late'))).toBe(true) + }) + }) +}) diff --git a/packages/typescript/ai-client/tests/test-utils.ts b/packages/typescript/ai-client/tests/test-utils.ts index 20a728a8f..84d474ed2 100644 --- a/packages/typescript/ai-client/tests/test-utils.ts +++ b/packages/typescript/ai-client/tests/test-utils.ts @@ -1,6 +1,33 @@ +import { vi } from 'vitest' import type { ConnectConnectionAdapter } from '../src/connection-adapters' import type { ModelMessage, StreamChunk } from '@tanstack/ai' -import type { UIMessage } from '../src/types' +import type { ChatClientPersistence, UIMessage } from '../src/types' + +/** + * Build a minimal text {@link UIMessage} for tests. + */ +export function createUIMessage( + id: string, + text: string = 'hello', + role: UIMessage['role'] = 'user', +): UIMessage { + return { id, role, parts: [{ type: 'text', content: text }] } +} + +/** + * Create a persistence adapter whose three methods are vitest spies. `getItem` + * synchronously returns `initial` (defaults to `undefined`); override individual + * methods via the returned object's `.mock*` helpers for async/error scenarios. + */ +export function createMockPersistence( + initial?: Array | null, +): ChatClientPersistence { + return { + getItem: vi.fn(() => initial), + setItem: vi.fn(), + removeItem: vi.fn(), + } +} /** * Options for creating a mock connection adapter */ From 1f67db32f8b369cada91f26c894a5871ca8789e9 Mon Sep 17 00:00:00 2001 From: "autofix-ci[bot]" <114827586+autofix-ci[bot]@users.noreply.github.com> Date: Sat, 30 May 2026 15:53:57 +0000 Subject: [PATCH 04/10] ci: apply automated fixes --- .changeset/chat-client-persistence.md | 12 +- packages/ai-client/tests/chat-client.test.ts | 290 +++++++++--------- .../ai-client/tests/client-persistor.test.ts | 28 +- packages/ai-preact/tests/use-chat.test.ts | 4 +- packages/ai-react/tests/use-chat.test.ts | 52 ++-- packages/ai-solid/tests/use-chat.test.ts | 4 +- 6 files changed, 202 insertions(+), 188 deletions(-) diff --git a/.changeset/chat-client-persistence.md b/.changeset/chat-client-persistence.md index 0a6124036..a040e656c 100644 --- a/.changeset/chat-client-persistence.md +++ b/.changeset/chat-client-persistence.md @@ -1,10 +1,10 @@ --- -"@tanstack/ai-client": minor -"@tanstack/ai-react": minor -"@tanstack/ai-preact": minor -"@tanstack/ai-solid": minor -"@tanstack/ai-svelte": minor -"@tanstack/ai-vue": minor +'@tanstack/ai-client': minor +'@tanstack/ai-react': minor +'@tanstack/ai-preact': minor +'@tanstack/ai-solid': minor +'@tanstack/ai-svelte': minor +'@tanstack/ai-vue': minor --- Add persistence support for chat messages. diff --git a/packages/ai-client/tests/chat-client.test.ts b/packages/ai-client/tests/chat-client.test.ts index d44abe9d6..dd9e88408 100644 --- a/packages/ai-client/tests/chat-client.test.ts +++ b/packages/ai-client/tests/chat-client.test.ts @@ -330,8 +330,10 @@ describe('ChatClient', () => { it('should ignore native subscribe/send chunks from a cleared persisted request without runId', async () => { let storedMessages: Array | undefined const releaseFirstResponse = createDeferred() - const queuedChunks: Array<{ prompt: string; chunks: Array }> = - [] + const queuedChunks: Array<{ + prompt: string + chunks: Array + }> = [] let wakeSubscriber: (() => void) | null = null const adapter: ConnectionAdapter = { subscribe: vi.fn((_signal?: AbortSignal) => { @@ -455,39 +457,41 @@ describe('ChatClient', () => { let wakeSubscriber: (() => void) | null = null let queued = false const adapter: ConnectionAdapter = { - subscribe: vi.fn((_signal?: AbortSignal): AsyncIterable => { - return (async function* () { - while (true) { - if (!queued) { - await new Promise((resolve) => { - wakeSubscriber = resolve - }) + subscribe: vi.fn( + (_signal?: AbortSignal): AsyncIterable => { + return (async function* () { + while (true) { + if (!queued) { + await new Promise((resolve) => { + wakeSubscriber = resolve + }) + } + queued = false + yield { + type: EventType.RUN_STARTED, + threadId: 'thread-1', + runId: 'run-cleared', + timestamp: Date.now(), + } as StreamChunk + await releaseStaleChunks.promise + yield { + type: EventType.TEXT_MESSAGE_CONTENT, + messageId: 'stale-message', + timestamp: Date.now(), + delta: 'stale content', + content: 'stale content', + } as StreamChunk + staleChunksAttempted.resolve() + yield { + type: EventType.RUN_FINISHED, + threadId: 'thread-1', + runId: 'run-cleared', + timestamp: Date.now(), + } as StreamChunk } - queued = false - yield { - type: EventType.RUN_STARTED, - threadId: 'thread-1', - runId: 'run-cleared', - timestamp: Date.now(), - } as StreamChunk - await releaseStaleChunks.promise - yield { - type: EventType.TEXT_MESSAGE_CONTENT, - messageId: 'stale-message', - timestamp: Date.now(), - delta: 'stale content', - content: 'stale content', - } as StreamChunk - staleChunksAttempted.resolve() - yield { - type: EventType.RUN_FINISHED, - threadId: 'thread-1', - runId: 'run-cleared', - timestamp: Date.now(), - } as StreamChunk - } - })() - }), + })() + }, + ), send: vi.fn(async () => { queued = true wakeSubscriber?.() @@ -529,8 +533,10 @@ describe('ChatClient', () => { let storedMessages: Array | undefined const releaseStaleChunks = createDeferred() const staleChunksAttempted = createDeferred() - const queuedChunks: Array<{ prompt: string; chunks: Array }> = - [] + const queuedChunks: Array<{ + prompt: string + chunks: Array + }> = [] let staleReleased = false let wakeSubscriber: (() => void) | null = null const wakeQueuedSubscriber = () => { @@ -539,46 +545,48 @@ describe('ChatClient', () => { wake?.() } const adapter: ConnectionAdapter = { - subscribe: vi.fn((_signal?: AbortSignal): AsyncIterable => { - return (async function* () { - while (true) { - if (queuedChunks.length === 0) { - await new Promise((resolve) => { - wakeSubscriber = resolve - }) - } - const freshIndex = queuedChunks.findIndex( - (queued) => queued.prompt === 'B', - ) - const next = - freshIndex >= 0 - ? queuedChunks.splice(freshIndex, 1)[0] - : queuedChunks.shift() - if (!next) continue - yield next.chunks[0]! - if (next.prompt === 'A' && !staleReleased) { - queuedChunks.push({ - prompt: 'A-after-start', - chunks: next.chunks.slice(1), - }) - continue - } - if (next.prompt === 'A-after-start' && !staleReleased) { - queuedChunks.push(next) - await new Promise((resolve) => { - wakeSubscriber = resolve - }) - continue - } - for (const chunk of next.chunks.slice(1)) { - yield chunk - } - if (next.prompt === 'A-after-start') { - staleChunksAttempted.resolve() + subscribe: vi.fn( + (_signal?: AbortSignal): AsyncIterable => { + return (async function* () { + while (true) { + if (queuedChunks.length === 0) { + await new Promise((resolve) => { + wakeSubscriber = resolve + }) + } + const freshIndex = queuedChunks.findIndex( + (queued) => queued.prompt === 'B', + ) + const next = + freshIndex >= 0 + ? queuedChunks.splice(freshIndex, 1)[0] + : queuedChunks.shift() + if (!next) continue + yield next.chunks[0]! + if (next.prompt === 'A' && !staleReleased) { + queuedChunks.push({ + prompt: 'A-after-start', + chunks: next.chunks.slice(1), + }) + continue + } + if (next.prompt === 'A-after-start' && !staleReleased) { + queuedChunks.push(next) + await new Promise((resolve) => { + wakeSubscriber = resolve + }) + continue + } + for (const chunk of next.chunks.slice(1)) { + yield chunk + } + if (next.prompt === 'A-after-start') { + staleChunksAttempted.resolve() + } } - } - })() - }), + })() + }, + ), send: vi.fn( async ( messages: Array | Array, @@ -791,8 +799,10 @@ describe('ChatClient', () => { let storedMessages: Array | undefined const releaseStaleResponse = createDeferred() let staleReleased = false - const queuedChunks: Array<{ prompt: string; chunks: Array }> = - [] + const queuedChunks: Array<{ + prompt: string + chunks: Array + }> = [] let wakeSubscriber: (() => void) | null = null const adapter: ConnectionAdapter = { subscribe: vi.fn((_signal?: AbortSignal) => { @@ -980,9 +990,9 @@ describe('ChatClient', () => { const sendPromise = client.sendMessage('A') await vi.waitFor(() => { - expect(client.getMessages().some((message) => message.id === 'assistant-a')).toBe( - true, - ) + expect( + client.getMessages().some((message) => message.id === 'assistant-a'), + ).toBe(true) }) client.clear() @@ -1130,33 +1140,35 @@ describe('ChatClient', () => { let wakeSubscriber: (() => void) | null = null let queued = false const adapter: ConnectionAdapter = { - subscribe: vi.fn((_signal?: AbortSignal): AsyncIterable => { - return (async function* () { - while (true) { - if (!queued) { - await new Promise((resolve) => { - wakeSubscriber = resolve - }) + subscribe: vi.fn( + (_signal?: AbortSignal): AsyncIterable => { + return (async function* () { + while (true) { + if (!queued) { + await new Promise((resolve) => { + wakeSubscriber = resolve + }) + } + queued = false + yield { + type: EventType.RUN_STARTED, + threadId: 'thread-1', + runId: 'run-1', + timestamp: Date.now(), + } as StreamChunk + await releaseResponse.promise + yield { + type: EventType.RUN_FINISHED, + threadId: 'thread-1', + runId: 'run-1', + model: 'test', + timestamp: Date.now(), + finishReason: 'stop', + } as StreamChunk } - queued = false - yield { - type: EventType.RUN_STARTED, - threadId: 'thread-1', - runId: 'run-1', - timestamp: Date.now(), - } as StreamChunk - await releaseResponse.promise - yield { - type: EventType.RUN_FINISHED, - threadId: 'thread-1', - runId: 'run-1', - model: 'test', - timestamp: Date.now(), - finishReason: 'stop', - } as StreamChunk - } - })() - }), + })() + }, + ), send: vi.fn(async () => { queued = true wakeSubscriber?.() @@ -1186,38 +1198,40 @@ describe('ChatClient', () => { const releaseAfterClear = createDeferred() const subscriberReady = createDeferred<() => void>() const adapter: ConnectionAdapter = { - subscribe: vi.fn((_signal?: AbortSignal): AsyncIterable => { - return (async function* () { - await new Promise((resolve) => { - subscriberReady.resolve(resolve) - }) - yield { - type: EventType.RUN_STARTED, - threadId: 'thread-1', - runId: 'live-run-1', - timestamp: Date.now(), - } as StreamChunk - await releaseAfterClear.promise - yield { - type: EventType.TEXT_MESSAGE_START, - messageId: 'live-message-1', - role: 'assistant', - } as StreamChunk - yield { - type: EventType.TEXT_MESSAGE_CONTENT, - messageId: 'live-message-1', - delta: 'stale live content', - } as StreamChunk - yield { - type: EventType.RUN_FINISHED, - threadId: 'thread-1', - runId: 'live-run-1', - model: 'test', - timestamp: Date.now(), - finishReason: 'stop', - } as StreamChunk - })() - }), + subscribe: vi.fn( + (_signal?: AbortSignal): AsyncIterable => { + return (async function* () { + await new Promise((resolve) => { + subscriberReady.resolve(resolve) + }) + yield { + type: EventType.RUN_STARTED, + threadId: 'thread-1', + runId: 'live-run-1', + timestamp: Date.now(), + } as StreamChunk + await releaseAfterClear.promise + yield { + type: EventType.TEXT_MESSAGE_START, + messageId: 'live-message-1', + role: 'assistant', + } as StreamChunk + yield { + type: EventType.TEXT_MESSAGE_CONTENT, + messageId: 'live-message-1', + delta: 'stale live content', + } as StreamChunk + yield { + type: EventType.RUN_FINISHED, + threadId: 'thread-1', + runId: 'live-run-1', + model: 'test', + timestamp: Date.now(), + finishReason: 'stop', + } as StreamChunk + })() + }, + ), send: vi.fn(), } const persistence = { diff --git a/packages/ai-client/tests/client-persistor.test.ts b/packages/ai-client/tests/client-persistor.test.ts index dc6806d39..34431a10c 100644 --- a/packages/ai-client/tests/client-persistor.test.ts +++ b/packages/ai-client/tests/client-persistor.test.ts @@ -43,7 +43,10 @@ function runStarted(runId: string, threadId: string = 'thread-1'): StreamChunk { return { type: EventType.RUN_STARTED, threadId, runId } } -function runFinished(runId: string, threadId: string = 'thread-1'): StreamChunk { +function runFinished( + runId: string, + threadId: string = 'thread-1', +): StreamChunk { return { type: EventType.RUN_FINISHED, threadId, runId } } @@ -119,16 +122,19 @@ describe('ChatPersistor', () => { it.each([ ['null', null], ['undefined', undefined], - ])('does not apply when the promise resolves to %s', async (_label, value) => { - const { persistor, applyMessages } = createPersistor( - createMockPersistence(), - ) - - persistor.hydrateAsync(Promise.resolve(value)) - await flushAsync() - - expect(applyMessages).not.toHaveBeenCalled() - }) + ])( + 'does not apply when the promise resolves to %s', + async (_label, value) => { + const { persistor, applyMessages } = createPersistor( + createMockPersistence(), + ) + + persistor.hydrateAsync(Promise.resolve(value)) + await flushAsync() + + expect(applyMessages).not.toHaveBeenCalled() + }, + ) it('does nothing for a synchronous (non-promise) value', async () => { const { persistor, applyMessages } = createPersistor( diff --git a/packages/ai-preact/tests/use-chat.test.ts b/packages/ai-preact/tests/use-chat.test.ts index 9963014cf..11e874c41 100644 --- a/packages/ai-preact/tests/use-chat.test.ts +++ b/packages/ai-preact/tests/use-chat.test.ts @@ -120,9 +120,7 @@ describe('useChat', () => { }) await waitFor(() => { - expect(persistence.getItem).toHaveBeenCalledWith( - 'persisted-empty-chat', - ) + expect(persistence.getItem).toHaveBeenCalledWith('persisted-empty-chat') }) expect(result.current.messages).toEqual([]) }) diff --git a/packages/ai-react/tests/use-chat.test.ts b/packages/ai-react/tests/use-chat.test.ts index 7f9e58665..8e43eedc1 100644 --- a/packages/ai-react/tests/use-chat.test.ts +++ b/packages/ai-react/tests/use-chat.test.ts @@ -119,9 +119,7 @@ describe('useChat', () => { }) await waitFor(() => { - expect(persistence.getItem).toHaveBeenCalledWith( - 'persisted-empty-chat', - ) + expect(persistence.getItem).toHaveBeenCalledWith('persisted-empty-chat') }) expect(result.current.messages).toEqual([]) }) @@ -935,7 +933,7 @@ describe('useChat', () => { expect(first).not.toHaveBeenCalled() }) - it('should ignore user callbacks from an old client after id changes', async () => { + it('should ignore user callbacks from an old client after id changes', async () => { const releaseOldStream = createDeferred() const oldOnChunk = vi.fn() const newOnChunk = vi.fn() @@ -971,34 +969,34 @@ describe('useChat', () => { releaseOldStream.resolve() await sendPromise - expect(oldOnChunk).not.toHaveBeenCalled() - expect(newOnChunk).not.toHaveBeenCalled() - }) - - it('should keep callbacks live across StrictMode effect replay for the same client', async () => { - const onChunk = vi.fn() - const adapter = createMockConnectionAdapter({ - chunks: createTextChunks('strict response'), - }) + expect(oldOnChunk).not.toHaveBeenCalled() + expect(newOnChunk).not.toHaveBeenCalled() + }) - const { result } = renderHook( - () => - useChat({ - connection: adapter, - onChunk, - }), - { wrapper: StrictMode }, - ) + it('should keep callbacks live across StrictMode effect replay for the same client', async () => { + const onChunk = vi.fn() + const adapter = createMockConnectionAdapter({ + chunks: createTextChunks('strict response'), + }) - await act(async () => { - await result.current.sendMessage('Test') - }) + const { result } = renderHook( + () => + useChat({ + connection: adapter, + onChunk, + }), + { wrapper: StrictMode }, + ) - expect(onChunk).toHaveBeenCalledWith( - expect.objectContaining({ type: EventType.TEXT_MESSAGE_CONTENT }), - ) + await act(async () => { + await result.current.sendMessage('Test') }) + + expect(onChunk).toHaveBeenCalledWith( + expect.objectContaining({ type: EventType.TEXT_MESSAGE_CONTENT }), + ) }) + }) describe('edge cases and error handling', () => { describe('options changes', () => { diff --git a/packages/ai-solid/tests/use-chat.test.ts b/packages/ai-solid/tests/use-chat.test.ts index 350fb7fce..31ba2f26f 100644 --- a/packages/ai-solid/tests/use-chat.test.ts +++ b/packages/ai-solid/tests/use-chat.test.ts @@ -107,9 +107,7 @@ describe('useChat', () => { }) await waitFor(() => { - expect(persistence.getItem).toHaveBeenCalledWith( - 'persisted-empty-chat', - ) + expect(persistence.getItem).toHaveBeenCalledWith('persisted-empty-chat') }) expect(result.current.messages).toEqual([]) }) From 787c99ade922c2675a1f2324506868727bf1f56a Mon Sep 17 00:00:00 2001 From: Alem Tuzlak Date: Sat, 30 May 2026 18:37:10 +0200 Subject: [PATCH 05/10] test(ai-client): de-flake clear-without-persistence stream test The test relied on a 10ms timer racing against vi.waitFor to ensure clear() ran before the delayed chunks. On faster machines/CI the chunks were processed before clear(), so clear() wiped them and the assertion saw an empty message list. Gate the chunks on a deferred released after clear() instead, making the ordering deterministic across platforms. --- packages/ai-client/tests/chat-client.test.ts | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/packages/ai-client/tests/chat-client.test.ts b/packages/ai-client/tests/chat-client.test.ts index dd9e88408..7f8cb66c4 100644 --- a/packages/ai-client/tests/chat-client.test.ts +++ b/packages/ai-client/tests/chat-client.test.ts @@ -2322,11 +2322,15 @@ describe('ChatClient', () => { it('should not abort an in-flight stream when persistence is omitted', async () => { let abortSignal: AbortSignal | undefined + // Gate the chunks on a deferred (instead of a fixed timer) so they are + // released strictly after clear() runs — otherwise the assertion races + // the stream and is flaky on faster machines/CI. + const releaseChunks = createDeferred() const adapter: ConnectConnectionAdapter = { async *connect(_messages, _data, signal) { abortSignal = signal - await new Promise((resolve) => setTimeout(resolve, 10)) + await releaseChunks.promise yield* createTextChunks('Delayed') }, } @@ -2340,6 +2344,9 @@ describe('ChatClient', () => { client.clear() expect(abortSignal?.aborted).toBe(false) + // Without persistence, clear() does not abort the in-flight stream, so + // its chunks still populate messages once they arrive. + releaseChunks.resolve() await sendPromise expect(client.getMessages()).toEqual( From bf19249f816d556ced575ea5b33a5a50ac82b1af Mon Sep 17 00:00:00 2001 From: Alem Tuzlak Date: Sun, 31 May 2026 14:58:33 +0200 Subject: [PATCH 06/10] feat(examples): add ChatGPT-style multi-thread persistence demo to ts-react-chat Adds a /threads route that demonstrates the new useChat `persistence` option: per-thread message history keyed by ChatClient id (namespaced in localStorage), plus an app-owned thread index (sidebar catalog) for listing, titling, and ordering conversations. Supports new chat, switch (with re-hydration), delete, title-from-first-message, and persistence across page reloads. Also removes a duplicate `TranscriptionGenerateInput` import in generations.transcription.tsx that broke route generation and prevented the example from booting. --- .../ts-react-chat/src/components/Header.tsx | 14 + examples/ts-react-chat/src/routeTree.gen.ts | 21 + .../src/routes/generations.transcription.tsx | 1 - examples/ts-react-chat/src/routes/threads.tsx | 390 ++++++++++++++++++ 4 files changed, 425 insertions(+), 1 deletion(-) create mode 100644 examples/ts-react-chat/src/routes/threads.tsx diff --git a/examples/ts-react-chat/src/components/Header.tsx b/examples/ts-react-chat/src/components/Header.tsx index bd4a1822d..2123e03b7 100644 --- a/examples/ts-react-chat/src/components/Header.tsx +++ b/examples/ts-react-chat/src/components/Header.tsx @@ -11,6 +11,7 @@ import { Image, Menu, Mic, + MessageSquare, Music, Server, Video, @@ -213,6 +214,19 @@ export default function Header() { Guitar Demo + setIsOpen(false)} + className="flex items-center gap-3 p-3 rounded-lg hover:bg-gray-800 transition-colors mb-2" + activeProps={{ + className: + 'flex items-center gap-3 p-3 rounded-lg bg-cyan-600 hover:bg-cyan-700 transition-colors mb-2', + }} + > + + Persistent Chats + + setIsOpen(false)} diff --git a/examples/ts-react-chat/src/routeTree.gen.ts b/examples/ts-react-chat/src/routeTree.gen.ts index 6bbf55373..d52d923c5 100644 --- a/examples/ts-react-chat/src/routeTree.gen.ts +++ b/examples/ts-react-chat/src/routeTree.gen.ts @@ -9,6 +9,7 @@ // Additionally, you should also exclude this file from your linter and/or formatter to prevent it from being checked or modified. import { Route as rootRouteImport } from './routes/__root' +import { Route as ThreadsRouteImport } from './routes/threads' import { Route as ServerFnChatRouteImport } from './routes/server-fn-chat' import { Route as RealtimeRouteImport } from './routes/realtime' import { Route as Issue176ToolResultRouteImport } from './routes/issue-176-tool-result' @@ -36,6 +37,11 @@ import { Route as ApiGenerateSpeechRouteImport } from './routes/api.generate.spe import { Route as ApiGenerateImageRouteImport } from './routes/api.generate.image' import { Route as ApiGenerateAudioRouteImport } from './routes/api.generate.audio' +const ThreadsRoute = ThreadsRouteImport.update({ + id: '/threads', + path: '/threads', + getParentRoute: () => rootRouteImport, +} as any) const ServerFnChatRoute = ServerFnChatRouteImport.update({ id: '/server-fn-chat', path: '/server-fn-chat', @@ -177,6 +183,7 @@ export interface FileRoutesByFullPath { '/issue-176-tool-result': typeof Issue176ToolResultRoute '/realtime': typeof RealtimeRoute '/server-fn-chat': typeof ServerFnChatRoute + '/threads': typeof ThreadsRoute '/api/image-gen': typeof ApiImageGenRoute '/api/structured-chat': typeof ApiStructuredChatRoute '/api/structured-output': typeof ApiStructuredOutputRoute @@ -205,6 +212,7 @@ export interface FileRoutesByTo { '/issue-176-tool-result': typeof Issue176ToolResultRoute '/realtime': typeof RealtimeRoute '/server-fn-chat': typeof ServerFnChatRoute + '/threads': typeof ThreadsRoute '/api/image-gen': typeof ApiImageGenRoute '/api/structured-chat': typeof ApiStructuredChatRoute '/api/structured-output': typeof ApiStructuredOutputRoute @@ -234,6 +242,7 @@ export interface FileRoutesById { '/issue-176-tool-result': typeof Issue176ToolResultRoute '/realtime': typeof RealtimeRoute '/server-fn-chat': typeof ServerFnChatRoute + '/threads': typeof ThreadsRoute '/api/image-gen': typeof ApiImageGenRoute '/api/structured-chat': typeof ApiStructuredChatRoute '/api/structured-output': typeof ApiStructuredOutputRoute @@ -264,6 +273,7 @@ export interface FileRouteTypes { | '/issue-176-tool-result' | '/realtime' | '/server-fn-chat' + | '/threads' | '/api/image-gen' | '/api/structured-chat' | '/api/structured-output' @@ -292,6 +302,7 @@ export interface FileRouteTypes { | '/issue-176-tool-result' | '/realtime' | '/server-fn-chat' + | '/threads' | '/api/image-gen' | '/api/structured-chat' | '/api/structured-output' @@ -320,6 +331,7 @@ export interface FileRouteTypes { | '/issue-176-tool-result' | '/realtime' | '/server-fn-chat' + | '/threads' | '/api/image-gen' | '/api/structured-chat' | '/api/structured-output' @@ -349,6 +361,7 @@ export interface RootRouteChildren { Issue176ToolResultRoute: typeof Issue176ToolResultRoute RealtimeRoute: typeof RealtimeRoute ServerFnChatRoute: typeof ServerFnChatRoute + ThreadsRoute: typeof ThreadsRoute ApiImageGenRoute: typeof ApiImageGenRoute ApiStructuredChatRoute: typeof ApiStructuredChatRoute ApiStructuredOutputRoute: typeof ApiStructuredOutputRoute @@ -373,6 +386,13 @@ export interface RootRouteChildren { declare module '@tanstack/react-router' { interface FileRoutesByPath { + '/threads': { + id: '/threads' + path: '/threads' + fullPath: '/threads' + preLoaderRoute: typeof ThreadsRouteImport + parentRoute: typeof rootRouteImport + } '/server-fn-chat': { id: '/server-fn-chat' path: '/server-fn-chat' @@ -565,6 +585,7 @@ const rootRouteChildren: RootRouteChildren = { Issue176ToolResultRoute: Issue176ToolResultRoute, RealtimeRoute: RealtimeRoute, ServerFnChatRoute: ServerFnChatRoute, + ThreadsRoute: ThreadsRoute, ApiImageGenRoute: ApiImageGenRoute, ApiStructuredChatRoute: ApiStructuredChatRoute, ApiStructuredOutputRoute: ApiStructuredOutputRoute, diff --git a/examples/ts-react-chat/src/routes/generations.transcription.tsx b/examples/ts-react-chat/src/routes/generations.transcription.tsx index d67e0de02..6a7000e33 100644 --- a/examples/ts-react-chat/src/routes/generations.transcription.tsx +++ b/examples/ts-react-chat/src/routes/generations.transcription.tsx @@ -4,7 +4,6 @@ import { useTranscription } from '@tanstack/ai-react' import type { UseTranscriptionReturn } from '@tanstack/ai-react' import type { TranscriptionGenerateInput } from '@tanstack/ai-client' import { fetchServerSentEvents } from '@tanstack/ai-client' -import type { TranscriptionGenerateInput } from '@tanstack/ai-client' import { transcribeFn, transcribeStreamFn } from '../lib/server-fns' import { TRANSCRIPTION_PROVIDERS, diff --git a/examples/ts-react-chat/src/routes/threads.tsx b/examples/ts-react-chat/src/routes/threads.tsx new file mode 100644 index 000000000..b7a2cec19 --- /dev/null +++ b/examples/ts-react-chat/src/routes/threads.tsx @@ -0,0 +1,390 @@ +import { useEffect, useRef, useState } from 'react' +import { createFileRoute } from '@tanstack/react-router' +import { fetchServerSentEvents, useChat } from '@tanstack/ai-react' +import { MessageSquare, Plus, Trash2 } from 'lucide-react' +import type { ChatClientPersistence, UIMessage } from '@tanstack/ai-client' + +export const Route = createFileRoute('/threads')({ + component: ThreadsRoute, +}) + +// --------------------------------------------------------------------------- +// Storage layer +// +// Two concerns, two key spaces — this is the whole point of the demo: +// +// 1. Per-thread message history → the PR's persistence adapter, keyed by the +// ChatClient `id`. We namespace it under `THREAD_KEY_PREFIX + id` so every +// thread is its own localStorage entry and threads never overwrite each +// other. +// 2. The thread catalog (which threads exist, their titles, recency) → a +// separate index the example owns. The persistence adapter is deliberately +// single-conversation and knows nothing about "all my chats", so listing +// and ordering threads is an app-level layer built on top. +// --------------------------------------------------------------------------- + +const INDEX_KEY = 'tanstack-ai:threads' +const THREAD_KEY_PREFIX = 'tanstack-ai:thread:' +const NEW_CHAT_TITLE = 'New chat' +const TITLE_MAX = 40 + +interface ThreadMeta { + id: string + title: string + updatedAt: number +} + +const hasWindow = () => typeof window !== 'undefined' + +/** + * Per-thread history adapter. Each thread's messages live under their own + * namespaced key, so this satisfies the PR's `ChatClientPersistence` contract + * (get/set/remove by id) while keeping threads isolated from each other. + */ +const threadPersistence: ChatClientPersistence = { + getItem: (id) => { + if (!hasWindow()) return null + const raw = window.localStorage.getItem(THREAD_KEY_PREFIX + id) + if (!raw) return null + // `UIMessage.createdAt` is a Date that JSON.stringify turned into a string — + // revive it on read. + return (JSON.parse(raw) as Array).map((message) => ({ + ...message, + createdAt: + typeof message.createdAt === 'string' + ? new Date(message.createdAt) + : message.createdAt, + })) + }, + setItem: (id, messages) => { + if (!hasWindow()) return + window.localStorage.setItem( + THREAD_KEY_PREFIX + id, + JSON.stringify(messages), + ) + }, + removeItem: (id) => { + if (!hasWindow()) return + window.localStorage.removeItem(THREAD_KEY_PREFIX + id) + }, +} + +function readIndex(): Array { + if (!hasWindow()) return [] + try { + const raw = window.localStorage.getItem(INDEX_KEY) + return raw ? (JSON.parse(raw) as Array) : [] + } catch { + return [] + } +} + +function writeIndex(threads: Array): void { + if (!hasWindow()) return + try { + window.localStorage.setItem(INDEX_KEY, JSON.stringify(threads)) + } catch { + // Best-effort, mirroring the persistence adapter: a full or unavailable + // store should never break the chat. + } +} + +const byRecency = (a: ThreadMeta, b: ThreadMeta) => b.updatedAt - a.updatedAt + +function truncateTitle(text: string): string { + const trimmed = text.trim() + return trimmed.length > TITLE_MAX + ? `${trimmed.slice(0, TITLE_MAX).trimEnd()}…` + : trimmed +} + +function relativeTime(ts: number): string { + const seconds = Math.floor((Date.now() - ts) / 1000) + if (seconds < 60) return 'just now' + const minutes = Math.floor(seconds / 60) + if (minutes < 60) return `${minutes}m ago` + const hours = Math.floor(minutes / 60) + if (hours < 24) return `${hours}h ago` + return `${Math.floor(hours / 24)}d ago` +} + +/** + * Owns the thread catalog in localStorage. This is the layer the persistence + * adapter doesn't provide: the list of conversations, their titles, and recency. + */ +function useThreadIndex() { + const [threads, setThreads] = useState>([]) + // SSR-safety: the server has no localStorage, so we render an empty list and + // load the real catalog on the client after mount. `loaded` gates bootstrap + // logic so we don't auto-create a spurious thread before the real one loads. + const [loaded, setLoaded] = useState(false) + // Mirror of the catalog read synchronously by mutators. State closures are + // stale between two calls in the same event handler (e.g. set-title then + // bump-recency), which would make the second call clobber the first; the ref + // is always current, so sequential mutations compose correctly. + const threadsRef = useRef>([]) + + useEffect(() => { + const initial = readIndex().sort(byRecency) + threadsRef.current = initial + setThreads(initial) + setLoaded(true) + }, []) + + const commit = (next: Array) => { + const sorted = [...next].sort(byRecency) + threadsRef.current = sorted + setThreads(sorted) + writeIndex(sorted) + } + + const createThread = (): ThreadMeta => { + const thread: ThreadMeta = { + id: crypto.randomUUID(), + title: NEW_CHAT_TITLE, + updatedAt: Date.now(), + } + commit([thread, ...threadsRef.current]) + return thread + } + + const deleteThread = (id: string) => { + threadPersistence.removeItem(id) + commit(threadsRef.current.filter((t) => t.id !== id)) + } + + /** Bump recency, and set the title the first time a thread gets one. */ + const touchThread = (id: string, title?: string) => { + commit( + threadsRef.current.map((t) => + t.id === id + ? { + ...t, + updatedAt: Date.now(), + title: + title && t.title === NEW_CHAT_TITLE + ? truncateTitle(title) + : t.title, + } + : t, + ), + ) + } + + return { threads, loaded, createThread, deleteThread, touchThread } +} + +// --------------------------------------------------------------------------- +// Route +// --------------------------------------------------------------------------- + +function ThreadsRoute() { + const { threads, loaded, createThread, deleteThread, touchThread } = + useThreadIndex() + const [activeId, setActiveId] = useState(null) + + // Ensure there's always an active thread: keep the current one if it still + // exists, otherwise fall back to the most recent, otherwise create one. + useEffect(() => { + if (!loaded) return + if (activeId && threads.some((t) => t.id === activeId)) return + if (threads.length > 0) { + setActiveId(threads[0].id) + return + } + setActiveId(createThread().id) + // `threads`/`createThread` are derived from the same state; depending on + // `activeId` + `loaded` is enough to re-run when the active thread vanishes. + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [loaded, activeId, threads.length]) + + const handleNewChat = () => { + // Don't pile up empty chats — if an untouched "New chat" exists, focus it. + const empty = threads.find((t) => t.title === NEW_CHAT_TITLE) + setActiveId(empty ? empty.id : createThread().id) + } + + const handleDelete = (id: string) => { + deleteThread(id) + if (id === activeId) { + const next = threads.find((t) => t.id !== id) + // If nothing remains the bootstrap effect creates a fresh thread. + setActiveId(next ? next.id : null) + } + } + + return ( +
+ + +
+ {activeId ? ( + touchThread(activeId, text)} + onActivity={() => touchThread(activeId)} + /> + ) : ( +
+ Loading… +
+ )} +
+
+ ) +} + +// --------------------------------------------------------------------------- +// Chat panel — one ChatClient per thread, hydrated from + saved to the adapter. +// Remounted via `key={threadId}` on the parent so switching threads is a clean +// swap; useChat then hydrates the newly-keyed thread's history. +// --------------------------------------------------------------------------- + +function ThreadChat({ + threadId, + onFirstMessage, + onActivity, +}: { + threadId: string + onFirstMessage: (text: string) => void + onActivity: () => void +}) { + const { messages, sendMessage, isLoading, error } = useChat({ + id: threadId, + connection: fetchServerSentEvents('/api/tanchat'), + persistence: threadPersistence, + body: { provider: 'openai', model: 'gpt-4o-mini' }, + onFinish: () => onActivity(), + }) + const [input, setInput] = useState('') + + const handleSubmit = (e: React.FormEvent) => { + e.preventDefault() + const text = input.trim() + if (!text || isLoading) return + if (messages.length === 0) onFirstMessage(text) + onActivity() + setInput('') + void sendMessage(text) + } + + return ( + <> +
+
+ {messages.length === 0 && ( +
+ +

Send a message to start this chat.

+

+ Reload the page or switch threads — it's restored from + localStorage. +

+
+ )} + {messages.map((message) => ( +
+
+ {message.role} +
+
+ {message.parts.map((part, i) => + part.type === 'text' ? ( + {part.content} + ) : null, + )} +
+
+ ))} + {error && ( +
+ {String(error.message ?? error)} +
+ )} +
+
+ +
+
+ setInput(e.target.value)} + placeholder="Type a message…" + className="flex-1 rounded-lg border border-gray-700 bg-gray-800 px-4 py-2 text-white placeholder-gray-500 focus:border-orange-500/50 focus:outline-none" + /> + +
+
+ + ) +} From ae1370302cc9a0441311fd1a8dded2d40d0fd4cf Mon Sep 17 00:00:00 2001 From: Alem Tuzlak Date: Sun, 31 May 2026 15:15:17 +0200 Subject: [PATCH 07/10] feat(examples): render reasoning + tool calls in the threads persistence demo Upgrades the /threads chat panel from plain text to the full guitar-store experience so the persistence demo verifies that every UIMessage part type round-trips: reasoning (`thinking`), `tool-call` + `tool-result` (server `getGuitars` and client `recommendGuitar`), and markdown text. Reuses the index route's client tools, `ThinkingPart`, `GuitarRecommendation`, and the ReactMarkdown plugin set; adds a Stop button. Hardcoded to OpenRouter GPT-5.1 because /api/tanchat enables a reasoning summary for the openrouter provider, so reasoning parts actually stream. Deliberately omits the model picker / image attachments / Gemini-Interactions chaining, and does NOT carry over index.tsx's `setMessages([])`-on-model-change (which would wipe a thread's persisted history). --- examples/ts-react-chat/src/routes/threads.tsx | 226 ++++++++++++++++-- 1 file changed, 200 insertions(+), 26 deletions(-) diff --git a/examples/ts-react-chat/src/routes/threads.tsx b/examples/ts-react-chat/src/routes/threads.tsx index b7a2cec19..32e88fe70 100644 --- a/examples/ts-react-chat/src/routes/threads.tsx +++ b/examples/ts-react-chat/src/routes/threads.tsx @@ -1,8 +1,60 @@ import { useEffect, useRef, useState } from 'react' import { createFileRoute } from '@tanstack/react-router' import { fetchServerSentEvents, useChat } from '@tanstack/ai-react' -import { MessageSquare, Plus, Trash2 } from 'lucide-react' +import { clientTools } from '@tanstack/ai-client' +import { ThinkingPart } from '@tanstack/ai-react-ui' +import ReactMarkdown from 'react-markdown' +import rehypeRaw from 'rehype-raw' +import rehypeSanitize from 'rehype-sanitize' +import rehypeHighlight from 'rehype-highlight' +import remarkGfm from 'remark-gfm' +import { MessageSquare, Plus, Square, Trash2 } from 'lucide-react' import type { ChatClientPersistence, UIMessage } from '@tanstack/ai-client' +import GuitarRecommendation from '@/components/example-GuitarRecommendation' +import { + addToCartToolDef, + addToWishListToolDef, + getPersonalGuitarPreferenceToolDef, + recommendGuitarToolDef, +} from '@/lib/guitar-tools' + +// --------------------------------------------------------------------------- +// Client tool implementations (same set the index route uses). The +// /api/tanchat endpoint always merges the guitar *server* tools, and its +// system prompt forces a `recommendGuitar` call — a *client* tool — so the +// agent loop only completes if we provide these client implementations. +// Reasoning + tool-call parts both live on the persisted UIMessage, so they +// round-trip through the persistence adapter for free. +// --------------------------------------------------------------------------- + +const getPersonalGuitarPreferenceToolClient = + getPersonalGuitarPreferenceToolDef.client(() => ({ preference: 'acoustic' })) + +const addToWishListToolClient = addToWishListToolDef.client((args) => { + const wishList = JSON.parse(localStorage.getItem('wishList') || '[]') + wishList.push(args.guitarId) + localStorage.setItem('wishList', JSON.stringify(wishList)) + return { success: true, guitarId: args.guitarId, totalItems: wishList.length } +}) + +const addToCartToolClient = addToCartToolDef.client((args) => ({ + success: true, + cartId: 'CART_CLIENT_' + Date.now(), + guitarId: args.guitarId, + quantity: args.quantity, + totalItems: args.quantity, +})) + +const recommendGuitarToolClient = recommendGuitarToolDef.client(({ id }) => ({ + id: +id, +})) + +const tools = clientTools( + getPersonalGuitarPreferenceToolClient, + addToWishListToolClient, + addToCartToolClient, + recommendGuitarToolClient, +) export const Route = createFileRoute('/threads')({ component: ThreadsRoute, @@ -195,7 +247,6 @@ function ThreadsRoute() { setActiveId(createThread().id) // `threads`/`createThread` are derived from the same state; depending on // `activeId` + `loaded` is enough to re-run when the active thread vanishes. - // eslint-disable-next-line react-hooks/exhaustive-deps }, [loaded, activeId, threads.length]) const handleNewChat = () => { @@ -292,9 +343,105 @@ function ThreadsRoute() { // --------------------------------------------------------------------------- // Chat panel — one ChatClient per thread, hydrated from + saved to the adapter. // Remounted via `key={threadId}` on the parent so switching threads is a clean -// swap; useChat then hydrates the newly-keyed thread's history. +// swap; useChat then hydrates the newly-keyed thread's history. Reasoning +// (`thinking`) and `tool-call` parts live on the persisted UIMessage, so they +// are restored on reload / thread-switch alongside the text. Hardcoded to +// OpenRouter GPT-5.1 because the /api/tanchat endpoint enables a reasoning +// summary for the openrouter provider, so reasoning parts actually stream. // --------------------------------------------------------------------------- +const CHAT_BODY = { provider: 'openrouter', model: 'openai/gpt-5.1' } as const + +const MARKDOWN_PLUGINS = [rehypeRaw, rehypeSanitize, rehypeHighlight, remarkGfm] + +/** Render a single UIMessage part: reasoning, text, approval prompt, or guitar card. */ +function MessagePart({ + part, + index, + message, + addToolApprovalResponse, +}: { + part: UIMessage['parts'][number] + index: number + message: UIMessage + addToolApprovalResponse: (response: { + id: string + approved: boolean + }) => Promise +}) { + if (part.type === 'thinking') { + // "Complete" once any text part follows it in the same message. + const isComplete = message.parts + .slice(index + 1) + .some((p) => p.type === 'text') + return ( + + ) + } + + if (part.type === 'text' && part.content) { + return ( +
+ + {part.content} + +
+ ) + } + + if ( + part.type === 'tool-call' && + part.state === 'approval-requested' && + part.approval + ) { + return ( +
+

+ 🔒 Approval required: {part.name} +

+
+          {JSON.stringify(JSON.parse(part.arguments), null, 2)}
+        
+
+ + +
+
+ ) + } + + if ( + part.type === 'tool-call' && + part.name === 'recommendGuitar' && + part.output + ) { + return + } + + return null +} + function ThreadChat({ threadId, onFirstMessage, @@ -304,11 +451,19 @@ function ThreadChat({ onFirstMessage: (text: string) => void onActivity: () => void }) { - const { messages, sendMessage, isLoading, error } = useChat({ + const { + messages, + sendMessage, + isLoading, + error, + addToolApprovalResponse, + stop, + } = useChat({ id: threadId, connection: fetchServerSentEvents('/api/tanchat'), persistence: threadPersistence, - body: { provider: 'openai', model: 'gpt-4o-mini' }, + tools, + body: CHAT_BODY, onFinish: () => onActivity(), }) const [input, setInput] = useState('') @@ -332,7 +487,8 @@ function ThreadChat({

Send a message to start this chat.

- Reload the page or switch threads — it's restored from + Try “Recommend me an acoustic guitar” to see reasoning + tool + calls — then reload or switch threads and they're restored from localStorage.

@@ -340,29 +496,37 @@ function ThreadChat({ {messages.map((message) => (
-
- {message.role} -
- {message.parts.map((part, i) => - part.type === 'text' ? ( - {part.content} - ) : null, - )} +
+ {message.role} +
+ {message.parts.map((part, index) => ( + + ))}
))} {error && (
- {String(error.message ?? error)} + {error.message}
)} @@ -373,16 +537,26 @@ function ThreadChat({ setInput(e.target.value)} - placeholder="Type a message…" + placeholder="Try “Recommend me an acoustic guitar”…" className="flex-1 rounded-lg border border-gray-700 bg-gray-800 px-4 py-2 text-white placeholder-gray-500 focus:border-orange-500/50 focus:outline-none" /> - + {isLoading ? ( + + ) : ( + + )} From 5ebe7a3d6a7e8b243fa6421c3a74f94f1fd5d86b Mon Sep 17 00:00:00 2001 From: Tom Beckenham <34339192+tombeckenham@users.noreply.github.com> Date: Tue, 2 Jun 2026 10:48:53 +1000 Subject: [PATCH 08/10] refactor(ai-client): replace chunk shape-casts with `in` narrowing MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ChatPersistor read `toolCallId` / `messageId` / `parentMessageId` off a `StreamChunk` via `(chunk as { toolCallId?: string }).toolCallId` shape assertions. Those bypass the discriminated union and would silently read `undefined` (or a wrong type) if the AG-UI field were renamed/retyped. Extract three small `in`-narrowing helpers — getChunkToolCallId / getChunkMessageId / getChunkParentMessageId — mirroring the existing getChunkRunId in connection-adapters. Removes all five `as` casts and DRYs the duplicate reads in shouldIgnoreChunk / markIgnoredChunkIds. No behavior change; ai-client typecheck, eslint, and 376 unit tests pass. Co-Authored-By: Claude Opus 4.8 (1M context) --- packages/ai-client/src/client-persistor.ts | 34 ++++++++++++++++++---- 1 file changed, 28 insertions(+), 6 deletions(-) diff --git a/packages/ai-client/src/client-persistor.ts b/packages/ai-client/src/client-persistor.ts index 731ff4e01..426e8c2be 100644 --- a/packages/ai-client/src/client-persistor.ts +++ b/packages/ai-client/src/client-persistor.ts @@ -2,6 +2,29 @@ import { getChunkRunId } from './connection-adapters' import type { StreamChunk } from '@tanstack/ai/client' import type { ChatClientPersistence, UIMessage } from './types' +// `StreamChunk` is a discriminated union; `toolCallId` / `messageId` / +// `parentMessageId` exist on only some members. Narrow with `in` (matching +// `getChunkRunId`) instead of asserting a shape, so the field's real type is +// preserved and a protocol rename can't be read past silently. +function getChunkToolCallId(chunk: StreamChunk): string | undefined { + return 'toolCallId' in chunk && typeof chunk.toolCallId === 'string' + ? chunk.toolCallId + : undefined +} + +function getChunkMessageId(chunk: StreamChunk): string | undefined { + return 'messageId' in chunk && typeof chunk.messageId === 'string' + ? chunk.messageId + : undefined +} + +function getChunkParentMessageId(chunk: StreamChunk): string | undefined { + return 'parentMessageId' in chunk && + typeof chunk.parentMessageId === 'string' + ? chunk.parentMessageId + : undefined +} + /** * Encapsulates everything persistence-related for `ChatClient` so the client * itself stays focused on streaming and message state. @@ -211,13 +234,12 @@ export class ChatPersistor { return true } - const toolCallId = (chunk as { toolCallId?: string }).toolCallId + const toolCallId = getChunkToolCallId(chunk) if (toolCallId && this.clearedToolCallIds.has(toolCallId)) { return true } - const parentMessageId = (chunk as { parentMessageId?: string }) - .parentMessageId + const parentMessageId = getChunkParentMessageId(chunk) if (parentMessageId && this.clearedMessageIds.has(parentMessageId)) { if (toolCallId) { this.clearedToolCallIds.add(toolCallId) @@ -225,7 +247,7 @@ export class ChatPersistor { return true } - const messageId = (chunk as { messageId?: string }).messageId + const messageId = getChunkMessageId(chunk) if (!messageId) { return false } @@ -279,11 +301,11 @@ export class ChatPersistor { } private markIgnoredChunkIds(chunk: StreamChunk): void { - const messageId = (chunk as { messageId?: string }).messageId + const messageId = getChunkMessageId(chunk) if (messageId) { this.clearedMessageIds.add(messageId) } - const toolCallId = (chunk as { toolCallId?: string }).toolCallId + const toolCallId = getChunkToolCallId(chunk) if (toolCallId) { this.clearedToolCallIds.add(toolCallId) } From 0fce6f1f8c68ef7effd1213bdf91da1239030a17 Mon Sep 17 00:00:00 2001 From: "autofix-ci[bot]" <114827586+autofix-ci[bot]@users.noreply.github.com> Date: Tue, 2 Jun 2026 00:50:16 +0000 Subject: [PATCH 09/10] ci: apply automated fixes --- packages/ai-client/src/client-persistor.ts | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/packages/ai-client/src/client-persistor.ts b/packages/ai-client/src/client-persistor.ts index 426e8c2be..bc3edb632 100644 --- a/packages/ai-client/src/client-persistor.ts +++ b/packages/ai-client/src/client-persistor.ts @@ -19,8 +19,7 @@ function getChunkMessageId(chunk: StreamChunk): string | undefined { } function getChunkParentMessageId(chunk: StreamChunk): string | undefined { - return 'parentMessageId' in chunk && - typeof chunk.parentMessageId === 'string' + return 'parentMessageId' in chunk && typeof chunk.parentMessageId === 'string' ? chunk.parentMessageId : undefined } From 1fb316caf8e88b951f0f127199247e5432b4a9d9 Mon Sep 17 00:00:00 2001 From: Tom Beckenham <34339192+tombeckenham@users.noreply.github.com> Date: Tue, 2 Jun 2026 13:33:19 +1000 Subject: [PATCH 10/10] fix(ai-client): advance runless pointer in takeRunlessRunId; add clear/thread-switch e2e takeRunlessRunId nulled currentRunlessRunId instead of advancing to the next still-ignored run (unlike onRunSettled). With two runs cleared mid-stream concurrently, draining one via a runId-less RUN_ERROR stopped suppressing the other's runless content, letting it repopulate a cleared conversation. Advance the pointer so a second cleared run stays suppressed. Also add e2e coverage for the two highest-risk persistence behaviors that were previously untested: clear() removing the stored conversation across a reload, and switching the chat id in place loading that id's own history. Co-Authored-By: Claude Opus 4.8 (1M context) --- packages/ai-client/src/client-persistor.ts | 6 +- .../ai-client/tests/client-persistor.test.ts | 20 ++++ testing/e2e/src/routes/$provider/$feature.tsx | 101 ++++++++++++------ testing/e2e/tests/chat.spec.ts | 63 +++++++++++ 4 files changed, 157 insertions(+), 33 deletions(-) diff --git a/packages/ai-client/src/client-persistor.ts b/packages/ai-client/src/client-persistor.ts index bc3edb632..517aaf0a7 100644 --- a/packages/ai-client/src/client-persistor.ts +++ b/packages/ai-client/src/client-persistor.ts @@ -295,7 +295,11 @@ export class ChatPersistor { if (!runId) return null this.ignoredActiveRunIds.delete(runId) this.clearedRunIds.delete(runId) - this.currentRunlessRunId = null + // Advance to another still-ignored run (mirroring `onRunSettled`) so that + // when two cleared runs drain concurrently, draining one via a runId-less + // RUN_ERROR doesn't stop suppressing the other's runless content. + this.currentRunlessRunId = + this.ignoredActiveRunIds.values().next().value ?? null return runId } diff --git a/packages/ai-client/tests/client-persistor.test.ts b/packages/ai-client/tests/client-persistor.test.ts index 34431a10c..05cc4380c 100644 --- a/packages/ai-client/tests/client-persistor.test.ts +++ b/packages/ai-client/tests/client-persistor.test.ts @@ -473,5 +473,25 @@ describe('ChatPersistor', () => { expect(persistor.shouldIgnoreChunk(runFinished('run-2'))).toBe(false) expect(persistor.shouldIgnoreChunk(textContent('late'))).toBe(true) }) + + it('keeps suppressing a second cleared run after the first drains via a runless error', () => { + const { persistor } = createPersistor(createMockPersistence()) + persistor.snapshotClear({ + messages: [], + activeRunIds: new Set(['run-1', 'run-2']), + currentRunId: null, + }) + persistor.shouldIgnoreChunk(runStarted('run-1')) + persistor.shouldIgnoreChunk(runStarted('run-2')) // run-2 now pinned + + // A runId-less RUN_ERROR drains the pinned run (run-2)... + expect(persistor.takeRunlessRunId()).toBe('run-2') + + // ...but run-1 is still cleared, so its runless content stays suppressed + // instead of leaking through once the pointer advances back to it. + expect(persistor.shouldIgnoreChunk(textContent('late-from-run-1'))).toBe( + true, + ) + }) }) }) diff --git a/testing/e2e/src/routes/$provider/$feature.tsx b/testing/e2e/src/routes/$provider/$feature.tsx index 94ba0ba51..ea080c4fc 100644 --- a/testing/e2e/src/routes/$provider/$feature.tsx +++ b/testing/e2e/src/routes/$provider/$feature.tsx @@ -191,7 +191,17 @@ function ChatFeature({ const tools = needsApproval ? clientTools(addToCartClient) : undefined const { testId, aimockPort, persistence } = Route.useSearch() - const chatId = `e2e-chat-${testId ?? `${provider}-${feature}`}` + const persistenceEnabled = persistence === 'localStorage' + const baseChatId = `e2e-chat-${testId ?? `${provider}-${feature}`}` + // When persistence is on, expose a tiny thread switcher so e2e can verify that + // changing the `id` in place swaps to that id's own persisted history (the + // render-from-getMessages + activeClientRef path), keyed per thread. Start on + // thread "a" (not null) so the page loads already on a thread id — switching + // is then a pure in-place id swap with no initial null→thread transition. + const [activeThread, setActiveThread] = useState( + persistenceEnabled ? 'a' : null, + ) + const chatId = activeThread ? `${baseChatId}:${activeThread}` : baseChatId const [structuredObject, setStructuredObject] = useState(null) const [contentDeltaCount, setContentDeltaCount] = useState(0) @@ -240,37 +250,43 @@ function ChatFeature({ } : { connection: fetchServerSentEvents('/api/chat') } - const { messages, sendMessage, isLoading, addToolApprovalResponse, stop } = - useChat({ - id: chatId, - ...transport, - tools, - body: { - provider, - feature, - testId, - aimockPort, - previousInteractionId: interactionId, - }, - persistence: - persistence === 'localStorage' ? localStoragePersistence : undefined, - onCustomEvent: (eventType, data) => { - if (eventType === 'structured-output.complete') { - const value = data as { object: unknown; raw: string } | undefined - setStructuredObject(value?.object ?? null) - } else if (eventType === 'gemini.interactionId') { - const value = data as - | GeminiInteractionsCustomEventValue<'gemini.interactionId'> - | undefined - if (value?.interactionId) setInteractionId(value.interactionId) - } - }, - onChunk: (chunk) => { - if (chunk.type === 'TEXT_MESSAGE_CONTENT') { - setContentDeltaCount((n) => n + 1) - } - }, - }) + const { + messages, + sendMessage, + isLoading, + addToolApprovalResponse, + stop, + clear, + } = useChat({ + id: chatId, + ...transport, + tools, + body: { + provider, + feature, + testId, + aimockPort, + previousInteractionId: interactionId, + }, + persistence: + persistence === 'localStorage' ? localStoragePersistence : undefined, + onCustomEvent: (eventType, data) => { + if (eventType === 'structured-output.complete') { + const value = data as { object: unknown; raw: string } | undefined + setStructuredObject(value?.object ?? null) + } else if (eventType === 'gemini.interactionId') { + const value = data as + | GeminiInteractionsCustomEventValue<'gemini.interactionId'> + | undefined + if (value?.interactionId) setInteractionId(value.interactionId) + } + }, + onChunk: (chunk) => { + if (chunk.type === 'TEXT_MESSAGE_CONTENT') { + setContentDeltaCount((n) => n + 1) + } + }, + }) return ( <> @@ -279,6 +295,27 @@ function ChatFeature({ {interactionId} )} + {persistenceEnabled && ( +
+ + + +
+ )} { 'Fender Stratocaster', ) }) + + test('clear() removes the persisted conversation so a reload starts empty', async ({ + page, + testId, + aimockPort, + }) => { + await page.goto( + `${featureUrl('openai', 'chat', testId, aimockPort)}&persistence=localStorage`, + ) + + await sendMessage(page, '[chat] recommend a guitar') + await waitForResponse(page) + await expect(page.getByTestId('user-message')).toContainText( + '[chat] recommend a guitar', + ) + + await page.getByTestId('clear-button').click() + await expect(page.getByTestId('user-message')).toHaveCount(0) + await expect(page.getByTestId('assistant-message')).toHaveCount(0) + + // The conversation was removed from storage, not just from memory — a + // reload must not resurrect it. + await page.reload() + await expect(page.getByTestId('message-list')).toBeVisible() + await expect(page.getByTestId('user-message')).toHaveCount(0) + await expect(page.getByTestId('assistant-message')).toHaveCount(0) + }) + + test('switches per-thread history when the chat id changes in place', async ({ + page, + testId, + aimockPort, + }) => { + await page.goto( + `${featureUrl('openai', 'chat', testId, aimockPort)}&persistence=localStorage`, + ) + + // The page loads on thread A. Send a message (persisted under A's own id). + await sendMessage(page, '[chat] recommend a guitar') + await waitForResponse(page) + await expect(page.getByTestId('user-message')).toHaveCount(1) + + // Switch to thread B in place — its own (empty) history loads, proving the + // id swap doesn't leak thread A's messages into thread B. + await page.getByTestId('select-thread-b').click() + await expect(page.getByTestId('user-message')).toHaveCount(0) + await expect(page.getByTestId('assistant-message')).toHaveCount(0) + + await sendMessage(page, '[chat] recommend a guitar') + await waitForResponse(page) + await expect(page.getByTestId('user-message')).toHaveCount(1) + + // Switch back to thread A — its persisted history is restored from storage + // on the in-place swap (render-from-getMessages), exactly one message. + await page.getByTestId('select-thread-a').click() + await expect(page.getByTestId('user-message')).toHaveCount(1) + await expect(page.getByTestId('user-message')).toContainText( + '[chat] recommend a guitar', + ) + await expect(page.getByTestId('assistant-message')).toContainText( + 'Fender Stratocaster', + ) + }) })