diff --git a/.changeset/chat-client-persistence.md b/.changeset/chat-client-persistence.md new file mode 100644 index 000000000..a040e656c --- /dev/null +++ b/.changeset/chat-client-persistence.md @@ -0,0 +1,10 @@ +--- +'@tanstack/ai-client': minor +'@tanstack/ai-react': minor +'@tanstack/ai-preact': minor +'@tanstack/ai-solid': minor +'@tanstack/ai-svelte': minor +'@tanstack/ai-vue': minor +--- + +Add persistence support for chat messages. diff --git a/docs/chat/persistence.md b/docs/chat/persistence.md new file mode 100644 index 000000000..db877be0c --- /dev/null +++ b/docs/chat/persistence.md @@ -0,0 +1,142 @@ +--- +title: Persistence +id: chat-persistence +order: 5 +description: "Persist chat conversations on the client with TanStack AI — hydrate on load, save on change, and clear on reset using a simple getItem/setItem/removeItem adapter." +keywords: + - tanstack ai + - persistence + - chat history + - localStorage + - indexeddb + - offline + - hydration +--- + +By default a `ChatClient` (and every framework `useChat`/`createChat` wrapper) keeps messages in memory only — reload the page or navigate away and the conversation is gone. The optional **persistence adapter** wires the client to a storage backend so conversations survive reloads, with no manual `initialMessages` + `onFinish` boilerplate. + +This is especially useful for SPAs, Electron apps, and offline-first setups where the client is the source of truth and there's no server managing conversation state. + +## The adapter interface + +A persistence adapter is any object with three methods — the same `getItem`/`setItem`/`removeItem` shape used elsewhere in TanStack AI. Each method may be synchronous or return a `Promise`: + +```typescript +import type { ChatClientPersistence } from "@tanstack/ai-client"; + +interface ChatClientPersistence { + getItem: ( + id: string, + ) => + | Array + | null + | undefined + | Promise | null | undefined>; + setItem: (id: string, messages: Array) => void | Promise; + removeItem: (id: string) => void | Promise; +} +``` + +The `id` passed to each method is the client's `id` option. Provide a stable `id` per conversation so the right history is loaded back: + +```typescript +const client = new ChatClient({ + id: "conversation-123", + connection: adapter, + persistence: myPersistenceAdapter, +}); +``` + +## What the client does for you + +When a `persistence` adapter is provided, `ChatClient`: + +- **Hydrates on construction** — calls `getItem(id)`. If it returns an array, those messages populate the client (overriding `initialMessages`). Async adapters hydrate as soon as the promise resolves, unless you've already started a new conversation in the meantime. +- **Saves on every change** — calls `setItem(id, messages)` whenever the message list changes (new user message, streamed assistant content, tool calls/results, approval responses). Writes are queued so they never overlap or land out of order. +- **Clears on `clear()`** — calls `removeItem(id)` and discards any in-flight stream so a cleared conversation doesn't get repopulated by late chunks. + +When `persistence` is omitted, nothing changes — the client behaves exactly as before. The option is fully backwards compatible. + +Persistence is **best-effort**: if an adapter method throws or rejects, the error is swallowed so storage problems never break the chat. Handle and surface errors inside your adapter if you need to react to them. + +## Framework usage + +Every framework wrapper accepts the same `persistence` option and forwards it to the underlying `ChatClient`: + +```tsx +// React / Preact +const chat = useChat({ + id: "conversation-123", + connection: fetchServerSentEvents("/api/chat"), + persistence: myPersistenceAdapter, +}); +``` + +```ts +// Solid / Vue — same option +const chat = useChat({ + id: "conversation-123", + connection: fetchServerSentEvents("/api/chat"), + persistence: myPersistenceAdapter, +}); +``` + +```ts +// Svelte +const chat = createChat({ + id: "conversation-123", + connection: fetchServerSentEvents("/api/chat"), + persistence: myPersistenceAdapter, +}); +``` + +## Example: `localStorage` + +A synchronous adapter backed by `localStorage`. Note that `UIMessage.createdAt` is a `Date`, which `JSON.stringify` turns into a string — revive it on read if you depend on it: + +```typescript +import type { ChatClientPersistence, UIMessage } from "@tanstack/ai-client"; + +const localStoragePersistence: ChatClientPersistence = { + getItem: (id) => { + const raw = window.localStorage.getItem(id); + if (!raw) return null; + return (JSON.parse(raw) as Array).map((message) => ({ + ...message, + createdAt: + typeof message.createdAt === "string" + ? new Date(message.createdAt) + : message.createdAt, + })); + }, + setItem: (id, messages) => { + window.localStorage.setItem(id, JSON.stringify(messages)); + }, + removeItem: (id) => { + window.localStorage.removeItem(id); + }, +}; +``` + +## Example: IndexedDB (async) + +For larger histories or structured queries, back the adapter with an async store such as IndexedDB. The client awaits async methods automatically: + +```typescript +import type { ChatClientPersistence } from "@tanstack/ai-client"; + +const indexedDbPersistence: ChatClientPersistence = { + getItem: async (id) => { + const record = await db.conversations.get(id); + return record?.messages; + }, + setItem: async (id, messages) => { + await db.conversations.put({ id, messages, updatedAt: Date.now() }); + }, + removeItem: async (id) => { + await db.conversations.delete(id); + }, +}; +``` + +Any backend works — IndexedDB, SQLite (Electron/Tauri), a remote database, or an in-memory `Map` for tests — as long as it implements the three methods. diff --git a/docs/config.json b/docs/config.json index 097ecdc41..b712f229c 100644 --- a/docs/config.json +++ b/docs/config.json @@ -103,6 +103,10 @@ { "label": "Thinking & Reasoning", "to": "chat/thinking-content" + }, + { + "label": "Persistence", + "to": "chat/persistence" } ] }, diff --git a/examples/ts-react-chat/src/components/Header.tsx b/examples/ts-react-chat/src/components/Header.tsx index 2224e4578..7d7b7f333 100644 --- a/examples/ts-react-chat/src/components/Header.tsx +++ b/examples/ts-react-chat/src/components/Header.tsx @@ -12,6 +12,7 @@ import { Image, Menu, Mic, + MessageSquare, Music, Server, Video, @@ -214,6 +215,19 @@ export default function Header() { Guitar Demo + setIsOpen(false)} + className="flex items-center gap-3 p-3 rounded-lg hover:bg-gray-800 transition-colors mb-2" + activeProps={{ + className: + 'flex items-center gap-3 p-3 rounded-lg bg-cyan-600 hover:bg-cyan-700 transition-colors mb-2', + }} + > + + Persistent Chats + + setIsOpen(false)} diff --git a/examples/ts-react-chat/src/routeTree.gen.ts b/examples/ts-react-chat/src/routeTree.gen.ts index 59ad81fc8..467b07418 100644 --- a/examples/ts-react-chat/src/routeTree.gen.ts +++ b/examples/ts-react-chat/src/routeTree.gen.ts @@ -9,6 +9,7 @@ // Additionally, you should also exclude this file from your linter and/or formatter to prevent it from being checked or modified. import { Route as rootRouteImport } from './routes/__root' +import { Route as ThreadsRouteImport } from './routes/threads' import { Route as ServerFnChatRouteImport } from './routes/server-fn-chat' import { Route as RealtimeRouteImport } from './routes/realtime' import { Route as Issue176ToolResultRouteImport } from './routes/issue-176-tool-result' @@ -39,6 +40,11 @@ import { Route as ApiGenerateSpeechRouteImport } from './routes/api.generate.spe import { Route as ApiGenerateImageRouteImport } from './routes/api.generate.image' import { Route as ApiGenerateAudioRouteImport } from './routes/api.generate.audio' +const ThreadsRoute = ThreadsRouteImport.update({ + id: '/threads', + path: '/threads', + getParentRoute: () => rootRouteImport, +} as any) const ServerFnChatRoute = ServerFnChatRouteImport.update({ id: '/server-fn-chat', path: '/server-fn-chat', @@ -196,6 +202,7 @@ export interface FileRoutesByFullPath { '/issue-176-tool-result': typeof Issue176ToolResultRoute '/realtime': typeof RealtimeRoute '/server-fn-chat': typeof ServerFnChatRoute + '/threads': typeof ThreadsRoute '/api/image-gen': typeof ApiImageGenRoute '/api/image-tool-repro': typeof ApiImageToolReproRoute '/api/structured-chat': typeof ApiStructuredChatRoute @@ -227,6 +234,7 @@ export interface FileRoutesByTo { '/issue-176-tool-result': typeof Issue176ToolResultRoute '/realtime': typeof RealtimeRoute '/server-fn-chat': typeof ServerFnChatRoute + '/threads': typeof ThreadsRoute '/api/image-gen': typeof ApiImageGenRoute '/api/image-tool-repro': typeof ApiImageToolReproRoute '/api/structured-chat': typeof ApiStructuredChatRoute @@ -259,6 +267,7 @@ export interface FileRoutesById { '/issue-176-tool-result': typeof Issue176ToolResultRoute '/realtime': typeof RealtimeRoute '/server-fn-chat': typeof ServerFnChatRoute + '/threads': typeof ThreadsRoute '/api/image-gen': typeof ApiImageGenRoute '/api/image-tool-repro': typeof ApiImageToolReproRoute '/api/structured-chat': typeof ApiStructuredChatRoute @@ -292,6 +301,7 @@ export interface FileRouteTypes { | '/issue-176-tool-result' | '/realtime' | '/server-fn-chat' + | '/threads' | '/api/image-gen' | '/api/image-tool-repro' | '/api/structured-chat' @@ -323,6 +333,7 @@ export interface FileRouteTypes { | '/issue-176-tool-result' | '/realtime' | '/server-fn-chat' + | '/threads' | '/api/image-gen' | '/api/image-tool-repro' | '/api/structured-chat' @@ -354,6 +365,7 @@ export interface FileRouteTypes { | '/issue-176-tool-result' | '/realtime' | '/server-fn-chat' + | '/threads' | '/api/image-gen' | '/api/image-tool-repro' | '/api/structured-chat' @@ -386,6 +398,7 @@ export interface RootRouteChildren { Issue176ToolResultRoute: typeof Issue176ToolResultRoute RealtimeRoute: typeof RealtimeRoute ServerFnChatRoute: typeof ServerFnChatRoute + ThreadsRoute: typeof ThreadsRoute ApiImageGenRoute: typeof ApiImageGenRoute ApiImageToolReproRoute: typeof ApiImageToolReproRoute ApiStructuredChatRoute: typeof ApiStructuredChatRoute @@ -412,6 +425,13 @@ export interface RootRouteChildren { declare module '@tanstack/react-router' { interface FileRoutesByPath { + '/threads': { + id: '/threads' + path: '/threads' + fullPath: '/threads' + preLoaderRoute: typeof ThreadsRouteImport + parentRoute: typeof rootRouteImport + } '/server-fn-chat': { id: '/server-fn-chat' path: '/server-fn-chat' @@ -626,6 +646,7 @@ const rootRouteChildren: RootRouteChildren = { Issue176ToolResultRoute: Issue176ToolResultRoute, RealtimeRoute: RealtimeRoute, ServerFnChatRoute: ServerFnChatRoute, + ThreadsRoute: ThreadsRoute, ApiImageGenRoute: ApiImageGenRoute, ApiImageToolReproRoute: ApiImageToolReproRoute, ApiStructuredChatRoute: ApiStructuredChatRoute, diff --git a/examples/ts-react-chat/src/routes/threads.tsx b/examples/ts-react-chat/src/routes/threads.tsx new file mode 100644 index 000000000..32e88fe70 --- /dev/null +++ b/examples/ts-react-chat/src/routes/threads.tsx @@ -0,0 +1,564 @@ +import { useEffect, useRef, useState } from 'react' +import { createFileRoute } from '@tanstack/react-router' +import { fetchServerSentEvents, useChat } from '@tanstack/ai-react' +import { clientTools } from '@tanstack/ai-client' +import { ThinkingPart } from '@tanstack/ai-react-ui' +import ReactMarkdown from 'react-markdown' +import rehypeRaw from 'rehype-raw' +import rehypeSanitize from 'rehype-sanitize' +import rehypeHighlight from 'rehype-highlight' +import remarkGfm from 'remark-gfm' +import { MessageSquare, Plus, Square, Trash2 } from 'lucide-react' +import type { ChatClientPersistence, UIMessage } from '@tanstack/ai-client' +import GuitarRecommendation from '@/components/example-GuitarRecommendation' +import { + addToCartToolDef, + addToWishListToolDef, + getPersonalGuitarPreferenceToolDef, + recommendGuitarToolDef, +} from '@/lib/guitar-tools' + +// --------------------------------------------------------------------------- +// Client tool implementations (same set the index route uses). The +// /api/tanchat endpoint always merges the guitar *server* tools, and its +// system prompt forces a `recommendGuitar` call — a *client* tool — so the +// agent loop only completes if we provide these client implementations. +// Reasoning + tool-call parts both live on the persisted UIMessage, so they +// round-trip through the persistence adapter for free. +// --------------------------------------------------------------------------- + +const getPersonalGuitarPreferenceToolClient = + getPersonalGuitarPreferenceToolDef.client(() => ({ preference: 'acoustic' })) + +const addToWishListToolClient = addToWishListToolDef.client((args) => { + const wishList = JSON.parse(localStorage.getItem('wishList') || '[]') + wishList.push(args.guitarId) + localStorage.setItem('wishList', JSON.stringify(wishList)) + return { success: true, guitarId: args.guitarId, totalItems: wishList.length } +}) + +const addToCartToolClient = addToCartToolDef.client((args) => ({ + success: true, + cartId: 'CART_CLIENT_' + Date.now(), + guitarId: args.guitarId, + quantity: args.quantity, + totalItems: args.quantity, +})) + +const recommendGuitarToolClient = recommendGuitarToolDef.client(({ id }) => ({ + id: +id, +})) + +const tools = clientTools( + getPersonalGuitarPreferenceToolClient, + addToWishListToolClient, + addToCartToolClient, + recommendGuitarToolClient, +) + +export const Route = createFileRoute('/threads')({ + component: ThreadsRoute, +}) + +// --------------------------------------------------------------------------- +// Storage layer +// +// Two concerns, two key spaces — this is the whole point of the demo: +// +// 1. Per-thread message history → the PR's persistence adapter, keyed by the +// ChatClient `id`. We namespace it under `THREAD_KEY_PREFIX + id` so every +// thread is its own localStorage entry and threads never overwrite each +// other. +// 2. The thread catalog (which threads exist, their titles, recency) → a +// separate index the example owns. The persistence adapter is deliberately +// single-conversation and knows nothing about "all my chats", so listing +// and ordering threads is an app-level layer built on top. +// --------------------------------------------------------------------------- + +const INDEX_KEY = 'tanstack-ai:threads' +const THREAD_KEY_PREFIX = 'tanstack-ai:thread:' +const NEW_CHAT_TITLE = 'New chat' +const TITLE_MAX = 40 + +interface ThreadMeta { + id: string + title: string + updatedAt: number +} + +const hasWindow = () => typeof window !== 'undefined' + +/** + * Per-thread history adapter. Each thread's messages live under their own + * namespaced key, so this satisfies the PR's `ChatClientPersistence` contract + * (get/set/remove by id) while keeping threads isolated from each other. + */ +const threadPersistence: ChatClientPersistence = { + getItem: (id) => { + if (!hasWindow()) return null + const raw = window.localStorage.getItem(THREAD_KEY_PREFIX + id) + if (!raw) return null + // `UIMessage.createdAt` is a Date that JSON.stringify turned into a string — + // revive it on read. + return (JSON.parse(raw) as Array).map((message) => ({ + ...message, + createdAt: + typeof message.createdAt === 'string' + ? new Date(message.createdAt) + : message.createdAt, + })) + }, + setItem: (id, messages) => { + if (!hasWindow()) return + window.localStorage.setItem( + THREAD_KEY_PREFIX + id, + JSON.stringify(messages), + ) + }, + removeItem: (id) => { + if (!hasWindow()) return + window.localStorage.removeItem(THREAD_KEY_PREFIX + id) + }, +} + +function readIndex(): Array { + if (!hasWindow()) return [] + try { + const raw = window.localStorage.getItem(INDEX_KEY) + return raw ? (JSON.parse(raw) as Array) : [] + } catch { + return [] + } +} + +function writeIndex(threads: Array): void { + if (!hasWindow()) return + try { + window.localStorage.setItem(INDEX_KEY, JSON.stringify(threads)) + } catch { + // Best-effort, mirroring the persistence adapter: a full or unavailable + // store should never break the chat. + } +} + +const byRecency = (a: ThreadMeta, b: ThreadMeta) => b.updatedAt - a.updatedAt + +function truncateTitle(text: string): string { + const trimmed = text.trim() + return trimmed.length > TITLE_MAX + ? `${trimmed.slice(0, TITLE_MAX).trimEnd()}…` + : trimmed +} + +function relativeTime(ts: number): string { + const seconds = Math.floor((Date.now() - ts) / 1000) + if (seconds < 60) return 'just now' + const minutes = Math.floor(seconds / 60) + if (minutes < 60) return `${minutes}m ago` + const hours = Math.floor(minutes / 60) + if (hours < 24) return `${hours}h ago` + return `${Math.floor(hours / 24)}d ago` +} + +/** + * Owns the thread catalog in localStorage. This is the layer the persistence + * adapter doesn't provide: the list of conversations, their titles, and recency. + */ +function useThreadIndex() { + const [threads, setThreads] = useState>([]) + // SSR-safety: the server has no localStorage, so we render an empty list and + // load the real catalog on the client after mount. `loaded` gates bootstrap + // logic so we don't auto-create a spurious thread before the real one loads. + const [loaded, setLoaded] = useState(false) + // Mirror of the catalog read synchronously by mutators. State closures are + // stale between two calls in the same event handler (e.g. set-title then + // bump-recency), which would make the second call clobber the first; the ref + // is always current, so sequential mutations compose correctly. + const threadsRef = useRef>([]) + + useEffect(() => { + const initial = readIndex().sort(byRecency) + threadsRef.current = initial + setThreads(initial) + setLoaded(true) + }, []) + + const commit = (next: Array) => { + const sorted = [...next].sort(byRecency) + threadsRef.current = sorted + setThreads(sorted) + writeIndex(sorted) + } + + const createThread = (): ThreadMeta => { + const thread: ThreadMeta = { + id: crypto.randomUUID(), + title: NEW_CHAT_TITLE, + updatedAt: Date.now(), + } + commit([thread, ...threadsRef.current]) + return thread + } + + const deleteThread = (id: string) => { + threadPersistence.removeItem(id) + commit(threadsRef.current.filter((t) => t.id !== id)) + } + + /** Bump recency, and set the title the first time a thread gets one. */ + const touchThread = (id: string, title?: string) => { + commit( + threadsRef.current.map((t) => + t.id === id + ? { + ...t, + updatedAt: Date.now(), + title: + title && t.title === NEW_CHAT_TITLE + ? truncateTitle(title) + : t.title, + } + : t, + ), + ) + } + + return { threads, loaded, createThread, deleteThread, touchThread } +} + +// --------------------------------------------------------------------------- +// Route +// --------------------------------------------------------------------------- + +function ThreadsRoute() { + const { threads, loaded, createThread, deleteThread, touchThread } = + useThreadIndex() + const [activeId, setActiveId] = useState(null) + + // Ensure there's always an active thread: keep the current one if it still + // exists, otherwise fall back to the most recent, otherwise create one. + useEffect(() => { + if (!loaded) return + if (activeId && threads.some((t) => t.id === activeId)) return + if (threads.length > 0) { + setActiveId(threads[0].id) + return + } + setActiveId(createThread().id) + // `threads`/`createThread` are derived from the same state; depending on + // `activeId` + `loaded` is enough to re-run when the active thread vanishes. + }, [loaded, activeId, threads.length]) + + const handleNewChat = () => { + // Don't pile up empty chats — if an untouched "New chat" exists, focus it. + const empty = threads.find((t) => t.title === NEW_CHAT_TITLE) + setActiveId(empty ? empty.id : createThread().id) + } + + const handleDelete = (id: string) => { + deleteThread(id) + if (id === activeId) { + const next = threads.find((t) => t.id !== id) + // If nothing remains the bootstrap effect creates a fresh thread. + setActiveId(next ? next.id : null) + } + } + + return ( +
+ + +
+ {activeId ? ( + touchThread(activeId, text)} + onActivity={() => touchThread(activeId)} + /> + ) : ( +
+ Loading… +
+ )} +
+
+ ) +} + +// --------------------------------------------------------------------------- +// Chat panel — one ChatClient per thread, hydrated from + saved to the adapter. +// Remounted via `key={threadId}` on the parent so switching threads is a clean +// swap; useChat then hydrates the newly-keyed thread's history. Reasoning +// (`thinking`) and `tool-call` parts live on the persisted UIMessage, so they +// are restored on reload / thread-switch alongside the text. Hardcoded to +// OpenRouter GPT-5.1 because the /api/tanchat endpoint enables a reasoning +// summary for the openrouter provider, so reasoning parts actually stream. +// --------------------------------------------------------------------------- + +const CHAT_BODY = { provider: 'openrouter', model: 'openai/gpt-5.1' } as const + +const MARKDOWN_PLUGINS = [rehypeRaw, rehypeSanitize, rehypeHighlight, remarkGfm] + +/** Render a single UIMessage part: reasoning, text, approval prompt, or guitar card. */ +function MessagePart({ + part, + index, + message, + addToolApprovalResponse, +}: { + part: UIMessage['parts'][number] + index: number + message: UIMessage + addToolApprovalResponse: (response: { + id: string + approved: boolean + }) => Promise +}) { + if (part.type === 'thinking') { + // "Complete" once any text part follows it in the same message. + const isComplete = message.parts + .slice(index + 1) + .some((p) => p.type === 'text') + return ( + + ) + } + + if (part.type === 'text' && part.content) { + return ( +
+ + {part.content} + +
+ ) + } + + if ( + part.type === 'tool-call' && + part.state === 'approval-requested' && + part.approval + ) { + return ( +
+

+ 🔒 Approval required: {part.name} +

+
+          {JSON.stringify(JSON.parse(part.arguments), null, 2)}
+        
+
+ + +
+
+ ) + } + + if ( + part.type === 'tool-call' && + part.name === 'recommendGuitar' && + part.output + ) { + return + } + + return null +} + +function ThreadChat({ + threadId, + onFirstMessage, + onActivity, +}: { + threadId: string + onFirstMessage: (text: string) => void + onActivity: () => void +}) { + const { + messages, + sendMessage, + isLoading, + error, + addToolApprovalResponse, + stop, + } = useChat({ + id: threadId, + connection: fetchServerSentEvents('/api/tanchat'), + persistence: threadPersistence, + tools, + body: CHAT_BODY, + onFinish: () => onActivity(), + }) + const [input, setInput] = useState('') + + const handleSubmit = (e: React.FormEvent) => { + e.preventDefault() + const text = input.trim() + if (!text || isLoading) return + if (messages.length === 0) onFirstMessage(text) + onActivity() + setInput('') + void sendMessage(text) + } + + return ( + <> +
+
+ {messages.length === 0 && ( +
+ +

Send a message to start this chat.

+

+ Try “Recommend me an acoustic guitar” to see reasoning + tool + calls — then reload or switch threads and they're restored from + localStorage. +

+
+ )} + {messages.map((message) => ( +
+
+
+ {message.role} +
+ {message.parts.map((part, index) => ( + + ))} +
+
+ ))} + {error && ( +
+ {error.message} +
+ )} +
+
+ +
+
+ setInput(e.target.value)} + placeholder="Try “Recommend me an acoustic guitar”…" + className="flex-1 rounded-lg border border-gray-700 bg-gray-800 px-4 py-2 text-white placeholder-gray-500 focus:border-orange-500/50 focus:outline-none" + /> + {isLoading ? ( + + ) : ( + + )} +
+
+ + ) +} diff --git a/packages/ai-client/src/chat-client.ts b/packages/ai-client/src/chat-client.ts index cdf29ef78..6ec7ce563 100644 --- a/packages/ai-client/src/chat-client.ts +++ b/packages/ai-client/src/chat-client.ts @@ -9,8 +9,10 @@ import { import { createNoOpChatDevtoolsBridge } from './devtools-noop' import { fetcherToConnectionAdapter, + getChunkRunId, normalizeConnectionAdapter, } from './connection-adapters' +import { ChatPersistor } from './client-persistor' import type { AnyClientTool, ContentPart, @@ -95,6 +97,11 @@ export class ChatClient< private connection: SubscribeConnectionAdapter private readonly uniqueId: string private readonly threadId: string + // All persistence concerns (hydrate / save / clear, plus suppression of late + // chunks after a mid-stream clear) live in ChatPersistor so this class stays + // focused on streaming. Undefined when no `persistence` adapter is configured. + private readonly persistor?: ChatPersistor + private currentRunId: string | null = null // Track the legacy `body` option and the canonical `forwardedProps` // option as separate slots so that `updateOptions({ forwardedProps })` // doesn't wipe a previously-set `body` (and vice versa). They are @@ -163,6 +170,13 @@ export class ChatClient< constructor(options: ChatClientOptions) { this.uniqueId = options.id || this.generateUniqueId('chat') this.threadId = options.threadId || this.generateUniqueId('thread') + if (options.persistence) { + this.persistor = new ChatPersistor( + options.persistence, + this.uniqueId, + (messages) => this.processor.setMessages(messages), + ) + } // Both `body` (deprecated) and `forwardedProps` populate the AG-UI // `RunAgentInput.forwardedProps` wire field. They are stored // separately so `updateOptions` can replace one without touching the @@ -208,15 +222,19 @@ export class ChatClient< // Create StreamProcessor with event handlers. // Use conditional spreads so we don't pass `undefined` into // `StreamProcessorOptions` fields under `exactOptionalPropertyTypes`. + const persistedMessages = this.persistor?.readInitial() + const initialMessages = Array.isArray(persistedMessages) + ? persistedMessages + : options.initialMessages + this.processor = new StreamProcessor({ ...(options.streamProcessor?.chunkStrategy ? { chunkStrategy: options.streamProcessor.chunkStrategy } : {}), - ...(options.initialMessages - ? { initialMessages: options.initialMessages } - : {}), + ...(initialMessages ? { initialMessages } : {}), events: { onMessagesChange: (messages: Array) => { + this.persistor?.notifyMessagesChanged(messages) this.callbacksRef.current.onMessagesChange(messages) }, onStreamStart: () => { @@ -413,6 +431,8 @@ export class ChatClient< }, }, }) + + this.persistor?.hydrateAsync(persistedMessages) } mountDevtools(): void { @@ -424,6 +444,51 @@ export class ChatClient< this.devtoolsBridge.mountWithTools(this.processor.getMessages().length) } + /** + * Drain a runId-less RUN_ERROR that belongs to a cleared run the client is + * still tracking. The persistor owns the cleared-run bookkeeping; the client + * owns the active-run / session / processing state. + */ + private drainIgnoredRunlessChunk(chunk: StreamChunk): void { + if (chunk.type !== 'RUN_ERROR') return + const runId = this.persistor?.takeRunlessRunId() + if (!runId) return + this.activeRunIds.delete(runId) + this.setSessionGenerating(this.activeRunIds.size > 0) + this.resolveProcessing() + } + + private updateRunLifecycle( + chunk: StreamChunk, + options?: { resolveProcessing?: boolean }, + ): void { + if (chunk.type === 'RUN_STARTED') { + const chunkRunId = getChunkRunId(chunk) ?? chunk.runId + this.activeRunIds.add(chunkRunId) + this.persistor?.onRunStarted(chunkRunId) + this.setSessionGenerating(true) + return + } + + if (chunk.type !== 'RUN_FINISHED' && chunk.type !== 'RUN_ERROR') { + return + } + + const runId = getChunkRunId(chunk) + if (runId) { + this.activeRunIds.delete(runId) + this.persistor?.onRunSettled(runId) + } else if (chunk.type === 'RUN_ERROR') { + // RUN_ERROR without runId is a session-level error; clear all runs. + this.activeRunIds.clear() + this.persistor?.onSessionRunError() + } + this.setSessionGenerating(this.activeRunIds.size > 0) + if (options?.resolveProcessing !== false) { + this.resolveProcessing() + } + } + private generateUniqueId(prefix: string): string { return `${prefix}-${Date.now()}-${Math.random().toString(36).substring(7)}` } @@ -461,6 +526,7 @@ export class ChatClient< private resetSessionGenerating(): void { this.activeRunIds.clear() + this.persistor?.resetIgnored() this.setSessionGenerating(false) } @@ -609,33 +675,27 @@ export class ChatClient< if (this.connectionStatus === 'connecting') { this.setConnectionStatus('connected') } - this.callbacksRef.current.onChunk(chunk) - if (chunk.type === 'RUN_STARTED') { - this.activeRunIds.add(chunk.runId) - this.setSessionGenerating(true) + const shouldIgnore = this.persistor?.shouldIgnoreChunk(chunk) ?? false + if (shouldIgnore) { + if (chunk.type === 'RUN_FINISHED' || chunk.type === 'RUN_ERROR') { + if (getChunkRunId(chunk)) { + this.updateRunLifecycle(chunk, { resolveProcessing: false }) + } else { + this.drainIgnoredRunlessChunk(chunk) + } + } + continue } + this.callbacksRef.current.onChunk(chunk) this.devtoolsBridge.observeChunk(chunk) this.processor.processChunk(chunk) - // RUN_FINISHED / RUN_ERROR signal run completion — resolve processing - // (redundant if onStreamEnd already resolved it, harmless) - if (chunk.type === 'RUN_FINISHED' || chunk.type === 'RUN_ERROR') { - // RUN_FINISHED has runId in its schema; RUN_ERROR carries it via the - // AG-UI passthrough so adapters can correlate per-run errors. Extract - // both so a RUN_ERROR with a runId only clears that run, not every - // active run in the session. - const runId = - 'runId' in chunk && typeof chunk.runId === 'string' - ? chunk.runId - : undefined - if (runId) { - this.activeRunIds.delete(runId) - } else if (chunk.type === 'RUN_ERROR') { - // RUN_ERROR without runId is a session-level error; clear all runs - this.activeRunIds.clear() - } - this.setSessionGenerating(this.activeRunIds.size > 0) - this.resolveProcessing() - } + // Run lifecycle (active-run tracking, session-generating state, and + // processing resolution for RUN_FINISHED / RUN_ERROR) is handled in a + // single place so the ignored-chunk path above and this path can't + // diverge. RUN_ERROR carries its runId via the AG-UI passthrough so a + // per-run error only clears that run, while a runId-less RUN_ERROR is + // treated as a session-level error that clears every active run. + this.updateRunLifecycle(chunk) // Yield control back to event loop for UI updates await new Promise((resolve) => setTimeout(resolve, 0)) } @@ -794,6 +854,8 @@ export class ChatClient< // Track generation so a superseded stream's cleanup doesn't clobber the new one const generation = ++this.streamGeneration + const runId = `run-${Date.now()}-${Math.random().toString(36).slice(2, 8)}` + this.currentRunId = runId this.setIsLoading(true) this.setStatus('submitted') @@ -874,7 +936,7 @@ export class ChatClient< // serialize to an unusable shape. const runContext = { threadId: this.threadId, - runId: `run-${Date.now()}-${Math.random().toString(36).slice(2, 8)}`, + runId, clientTools: Array.from(clientTools.values()).map((t) => ({ name: t.name, description: t.description, @@ -967,6 +1029,7 @@ export class ChatClient< this.currentStreamId = null this.devtoolsBridge.setCurrentStreamId(null) this.currentMessageId = null + this.currentRunId = null this.activeClientTools = null this.activeContext = undefined this.abortController = null @@ -1084,7 +1147,11 @@ export class ChatClient< * Stop the current stream */ stop(): void { + const hadLocalStream = this.abortController !== null this.cancelInFlightStream({ setReadyStatus: true }) + if (hadLocalStream) { + this.resetSessionGenerating() + } this.events.stopped() } @@ -1092,7 +1159,24 @@ export class ChatClient< * Clear all messages */ clear(): void { + if (this.persistor) { + this.persistor.snapshotClear({ + messages: this.processor.getMessages(), + activeRunIds: this.activeRunIds, + currentRunId: this.currentRunId, + }) + if (this.isLoading) { + this.cancelInFlightStream({ setReadyStatus: true }) + this.resetSessionGenerating() + } else if (this.activeRunIds.size > 0) { + this.resetSessionGenerating() + } + // Suppress persisting the empty snapshot that clearMessages emits, then + // remove the stored conversation outright. + this.persistor.beginClear() + } this.processor.clearMessages() + this.persistor?.remove() this.setError(undefined) this.events.messagesCleared() } diff --git a/packages/ai-client/src/client-persistor.ts b/packages/ai-client/src/client-persistor.ts new file mode 100644 index 000000000..517aaf0a7 --- /dev/null +++ b/packages/ai-client/src/client-persistor.ts @@ -0,0 +1,337 @@ +import { getChunkRunId } from './connection-adapters' +import type { StreamChunk } from '@tanstack/ai/client' +import type { ChatClientPersistence, UIMessage } from './types' + +// `StreamChunk` is a discriminated union; `toolCallId` / `messageId` / +// `parentMessageId` exist on only some members. Narrow with `in` (matching +// `getChunkRunId`) instead of asserting a shape, so the field's real type is +// preserved and a protocol rename can't be read past silently. +function getChunkToolCallId(chunk: StreamChunk): string | undefined { + return 'toolCallId' in chunk && typeof chunk.toolCallId === 'string' + ? chunk.toolCallId + : undefined +} + +function getChunkMessageId(chunk: StreamChunk): string | undefined { + return 'messageId' in chunk && typeof chunk.messageId === 'string' + ? chunk.messageId + : undefined +} + +function getChunkParentMessageId(chunk: StreamChunk): string | undefined { + return 'parentMessageId' in chunk && typeof chunk.parentMessageId === 'string' + ? chunk.parentMessageId + : undefined +} + +/** + * Encapsulates everything persistence-related for `ChatClient` so the client + * itself stays focused on streaming and message state. + * + * Two responsibilities live here: + * + * 1. **Storage orchestration** — hydrate from `getItem(id)` on creation, save to + * `setItem(id, messages)` on every change through an ordered write queue, and + * `removeItem(id)` on clear. A generation counter discards stale writes when a + * removal or a newer conversation supersedes an in-flight async operation. + * 2. **Clear-during-stream suppression** — when a conversation is cleared while a + * stream is still producing, late chunks for the cleared run(s) must not + * repopulate the now-empty state. The persistor tracks the cleared ids and + * decides, per chunk, whether the client should ignore it. + * + * All adapter calls are best-effort: a throwing or rejecting adapter is swallowed + * so storage problems never break the chat. + */ +export class ChatPersistor { + // --- storage queue state --- + private skipNextPersist = false + private generation = 0 + private queue: Promise = Promise.resolve() + private queuePending = false + // Bumped on every message change; lets an in-flight async hydration detect + // that the message list moved on and avoid clobbering it. + private messagesGeneration = 0 + + // --- clear-during-stream suppression state --- + private readonly clearedMessageIds = new Set() + private readonly clearedRunIds = new Set() + private readonly ignoredActiveRunIds = new Set() + private readonly clearedToolCallIds = new Set() + private currentRunlessRunId: string | null = null + + constructor( + private readonly adapter: ChatClientPersistence, + private readonly id: string, + private readonly applyMessages: (messages: Array) => void, + ) {} + + // --------------------------------------------------------------------------- + // Storage orchestration + // --------------------------------------------------------------------------- + + /** + * Synchronously read the persisted messages for constructor-time hydration. + * Returns the raw `getItem` result (which may be a promise for async stores). + */ + readInitial(): + | Array + | null + | undefined + | Promise | null | undefined> { + try { + return this.adapter.getItem(this.id) + } catch { + return undefined + } + } + + /** + * Apply messages from an async `getItem` once it resolves, unless the message + * list has already changed since hydration began. + */ + hydrateAsync( + persistedMessages: + | Array + | null + | undefined + | Promise | null | undefined>, + ): void { + if (!(persistedMessages instanceof Promise)) { + return + } + + const hydrationGeneration = this.messagesGeneration + persistedMessages + .then((messages) => { + if ( + Array.isArray(messages) && + this.messagesGeneration === hydrationGeneration + ) { + this.applyMessages(messages) + } + }) + .catch(() => { + // Persistence adapters are best-effort and must not break chat setup. + }) + } + + /** + * Record a message-list change and queue a `setItem` write for it. Skips a + * single write after {@link beginClear} so the clear's empty snapshot isn't + * persisted between `clearMessages()` and {@link remove}. + */ + notifyMessagesChanged(messages: Array): void { + this.messagesGeneration++ + if (this.skipNextPersist) { + this.skipNextPersist = false + return + } + const generation = this.generation + const messagesSnapshot = [...messages] + this.runOperation(() => { + if (generation !== this.generation) { + return + } + return this.adapter.setItem(this.id, messagesSnapshot) + }) + } + + /** Remove the persisted conversation. Invalidates any queued writes. */ + remove(): void { + const generation = ++this.generation + this.runOperation(() => { + if (generation !== this.generation) { + return + } + return this.adapter.removeItem(this.id) + }) + } + + private runOperation(operation: () => void | Promise): void { + if (this.queuePending) { + const queued = this.queue.then(operation).catch(() => { + // Persistence adapters are best-effort and must not break chat updates. + }) + this.queue = queued + void queued.finally(() => { + if (this.queue === queued) { + this.queuePending = false + } + }) + return + } + + try { + const result = operation() + if (result instanceof Promise) { + this.queuePending = true + const queued = result.catch(() => { + // Persistence adapters are best-effort and must not break chat updates. + }) + this.queue = queued + void queued.finally(() => { + if (this.queue === queued) { + this.queuePending = false + } + }) + } + } catch { + // Persistence adapters are best-effort and must not break chat updates. + } + } + + // --------------------------------------------------------------------------- + // Clear-during-stream suppression + // --------------------------------------------------------------------------- + + /** + * Capture the message/run ids that exist at the moment of a clear so chunks + * still arriving for them can be ignored. + */ + snapshotClear(context: { + messages: Array + activeRunIds: Set + currentRunId: string | null + }): void { + for (const message of context.messages) { + this.clearedMessageIds.add(message.id) + } + for (const runId of context.activeRunIds) { + this.clearedRunIds.add(runId) + this.ignoredActiveRunIds.add(runId) + } + if (context.currentRunId) { + this.clearedRunIds.add(context.currentRunId) + this.ignoredActiveRunIds.add(context.currentRunId) + } + } + + /** Mark that the next persisted message change (the clear itself) is skipped. */ + beginClear(): void { + this.skipNextPersist = true + } + + /** Whether a chunk belongs to cleared state and should not be processed. */ + shouldIgnoreChunk(chunk: StreamChunk): boolean { + const runId = getChunkRunId(chunk) + if (runId && this.clearedRunIds.has(runId)) { + if (chunk.type === 'RUN_STARTED') { + this.ignoredActiveRunIds.add(runId) + this.currentRunlessRunId = runId + } + this.markIgnoredChunkIds(chunk) + return true + } + + if (runId && this.ignoredActiveRunIds.has(runId)) { + this.markIgnoredChunkIds(chunk) + return true + } + + if (this.isRunlessChunkFromIgnoredRun(chunk)) { + this.markIgnoredChunkIds(chunk) + return true + } + + const toolCallId = getChunkToolCallId(chunk) + if (toolCallId && this.clearedToolCallIds.has(toolCallId)) { + return true + } + + const parentMessageId = getChunkParentMessageId(chunk) + if (parentMessageId && this.clearedMessageIds.has(parentMessageId)) { + if (toolCallId) { + this.clearedToolCallIds.add(toolCallId) + } + return true + } + + const messageId = getChunkMessageId(chunk) + if (!messageId) { + return false + } + if (this.clearedMessageIds.has(messageId)) { + return true + } + + return false + } + + /** + * The owning client calls this when a run starts so runless content chunks + * (adapters that omit `runId` on content events) can be attributed to it. + */ + onRunStarted(runId: string): void { + this.currentRunlessRunId = runId + } + + /** Forget a settled run, advancing the runless pointer to another ignored run. */ + onRunSettled(runId: string): void { + this.ignoredActiveRunIds.delete(runId) + this.clearedRunIds.delete(runId) + if (this.currentRunlessRunId === runId) { + this.currentRunlessRunId = + this.ignoredActiveRunIds.values().next().value ?? null + } + } + + /** A session-level (runId-less) RUN_ERROR clears all ignored-run tracking. */ + onSessionRunError(): void { + this.ignoredActiveRunIds.clear() + this.currentRunlessRunId = null + } + + /** Clear the ignored-active-run markers (mirrors a session-generating reset). */ + resetIgnored(): void { + this.ignoredActiveRunIds.clear() + } + + /** + * Consume the current runless run id (if any), forgetting it. Used when an + * ignored, runId-less RUN_ERROR drains the run the client is still tracking. + */ + takeRunlessRunId(): string | null { + const runId = this.currentRunlessRunId + if (!runId) return null + this.ignoredActiveRunIds.delete(runId) + this.clearedRunIds.delete(runId) + // Advance to another still-ignored run (mirroring `onRunSettled`) so that + // when two cleared runs drain concurrently, draining one via a runId-less + // RUN_ERROR doesn't stop suppressing the other's runless content. + this.currentRunlessRunId = + this.ignoredActiveRunIds.values().next().value ?? null + return runId + } + + private markIgnoredChunkIds(chunk: StreamChunk): void { + const messageId = getChunkMessageId(chunk) + if (messageId) { + this.clearedMessageIds.add(messageId) + } + const toolCallId = getChunkToolCallId(chunk) + if (toolCallId) { + this.clearedToolCallIds.add(toolCallId) + } + } + + private isRunlessChunkFromIgnoredRun(chunk: StreamChunk): boolean { + const runId = getChunkRunId(chunk) + if (runId || !this.currentRunlessRunId) return false + if ( + !this.ignoredActiveRunIds.has(this.currentRunlessRunId) && + !this.clearedRunIds.has(this.currentRunlessRunId) + ) { + return false + } + return ( + chunk.type === 'TEXT_MESSAGE_START' || + chunk.type === 'TEXT_MESSAGE_CONTENT' || + chunk.type === 'TOOL_CALL_START' || + chunk.type === 'TOOL_CALL_ARGS' || + chunk.type === 'TOOL_CALL_END' || + chunk.type === 'TOOL_CALL_RESULT' || + chunk.type === 'MESSAGES_SNAPSHOT' || + chunk.type === 'RUN_ERROR' + ) + } +} diff --git a/packages/ai-client/src/connection-adapters.ts b/packages/ai-client/src/connection-adapters.ts index 53e4d629d..3c4010047 100644 --- a/packages/ai-client/src/connection-adapters.ts +++ b/packages/ai-client/src/connection-adapters.ts @@ -13,6 +13,26 @@ import type { } from '@tanstack/ai/client' import type { ChatFetcher } from './types' +/** + * Associates connect-wrapped chunks with the run they were produced under. + * Content events (TEXT_MESSAGE_CONTENT, TOOL_CALL_*, …) carry no `runId` of + * their own, so the connect wrapper stamps the caller's run id here. Lets + * run-scoped consumers (e.g. clear-during-stream suppression) attribute those + * otherwise-runless chunks to their originating request. + */ +const chunkRunIds = new WeakMap() + +/** + * Resolve a chunk's run id, preferring the value on the chunk itself + * (RUN_STARTED / RUN_FINISHED / RUN_ERROR carry one) and falling back to the + * run the connect wrapper stamped it with. + */ +export function getChunkRunId(chunk: StreamChunk): string | undefined { + return 'runId' in chunk && typeof chunk.runId === 'string' + ? chunk.runId + : chunkRunIds.get(chunk) +} + /** * Thrown when an SSE/HTTP stream ends with a non-empty unterminated buffer. * Indicates the connection was cut mid-line (server crash, dropped TCP, proxy @@ -265,7 +285,10 @@ export function normalizeConnectionAdapter( let activeBuffer: Array = [] let activeWaiters: Array<(chunk: StreamChunk | null) => void> = [] - function push(chunk: StreamChunk): void { + function push(chunk: StreamChunk, runId?: string): void { + if (runId) { + chunkRunIds.set(chunk, runId) + } const waiter = activeWaiters.shift() if (waiter) { waiter(chunk) @@ -324,7 +347,7 @@ export function normalizeConnectionAdapter( if (chunk.type === 'RUN_FINISHED' || chunk.type === 'RUN_ERROR') { hasTerminalEvent = true } - push(chunk) + push(chunk, runContext?.runId) } // If the connect stream ended cleanly without a terminal event, diff --git a/packages/ai-client/src/index.ts b/packages/ai-client/src/index.ts index 21bdf9ce5..aa191a4b2 100644 --- a/packages/ai-client/src/index.ts +++ b/packages/ai-client/src/index.ts @@ -12,6 +12,7 @@ export type { ThinkingPart, StructuredOutputPart, // Client configuration types + ChatClientPersistence, ChatClientOptions, ClientContextOptionFromTools, ChatRequestBody, diff --git a/packages/ai-client/src/types.ts b/packages/ai-client/src/types.ts index 426b05537..5127a33de 100644 --- a/packages/ai-client/src/types.ts +++ b/packages/ai-client/src/types.ts @@ -266,6 +266,23 @@ export interface UIMessage< createdAt?: Date } +export interface ChatClientPersistence< + TTools extends ReadonlyArray = any, +> { + getItem: ( + id: string, + ) => + | Array> + | null + | undefined + | Promise> | null | undefined> + setItem: ( + id: string, + messages: Array>, + ) => void | Promise + removeItem: (id: string) => void | Promise +} + type IsUnknown = unknown extends T ? [T] extends [unknown] ? true @@ -356,6 +373,11 @@ export interface ChatClientBaseOptions< */ initialMessages?: Array> + /** + * Optional persistence adapter for chat messages. + */ + persistence?: ChatClientPersistence + /** * Unique identifier for this chat instance * Used for managing multiple chats diff --git a/packages/ai-client/tests/chat-client.test.ts b/packages/ai-client/tests/chat-client.test.ts index 83e139e47..7f8cb66c4 100644 --- a/packages/ai-client/tests/chat-client.test.ts +++ b/packages/ai-client/tests/chat-client.test.ts @@ -13,10 +13,44 @@ import type { ConnectConnectionAdapter, ConnectionAdapter, } from '../src/connection-adapters' -import type { StreamChunk } from '@tanstack/ai/client' -import type { UIMessage } from '../src/types' +import type { ModelMessage, StreamChunk } from '@tanstack/ai/client' +import type { ChatClientPersistence, UIMessage } from '../src/types' describe('ChatClient', () => { + const persistedMessage: UIMessage = { + id: 'persisted-1', + role: 'user', + parts: [{ type: 'text', content: 'Persisted hello' }], + createdAt: new Date('2024-01-01T00:00:00.000Z'), + } + + const initialMessage: UIMessage = { + id: 'initial-1', + role: 'user', + parts: [{ type: 'text', content: 'Initial hello' }], + createdAt: new Date('2024-01-02T00:00:00.000Z'), + } + + function createPersistence( + storedMessages?: Array | null, + ): ChatClientPersistence { + return { + getItem: vi.fn(() => storedMessages), + setItem: vi.fn(), + removeItem: vi.fn(), + } + } + + function createDeferred() { + let resolve!: (value: T) => void + let reject!: (reason?: unknown) => void + const promise = new Promise((promiseResolve, promiseReject) => { + resolve = promiseResolve + reject = promiseReject + }) + return { promise, resolve, reject } + } + describe('constructor', () => { it('should create a client with default options', () => { const adapter = createMockConnectionAdapter() @@ -48,6 +82,151 @@ describe('ChatClient', () => { expect(client.getMessages()).toEqual(initialMessages) }) + it('should hydrate messages from persistence', () => { + const adapter = createMockConnectionAdapter() + const persistence = createPersistence([persistedMessage]) + + const client = new ChatClient({ + connection: adapter, + id: 'chat-1', + persistence, + }) + + expect(persistence.getItem).toHaveBeenCalledWith('chat-1') + expect(client.getMessages()).toEqual([persistedMessage]) + }) + + it('should prefer persisted messages over initial messages', () => { + const adapter = createMockConnectionAdapter() + const persistence = createPersistence([persistedMessage]) + + const client = new ChatClient({ + connection: adapter, + id: 'chat-1', + initialMessages: [initialMessage], + persistence, + }) + + expect(client.getMessages()).toEqual([persistedMessage]) + }) + + it('should fall back to initial messages when persistence returns null', () => { + const adapter = createMockConnectionAdapter() + const persistence = createPersistence(null) + + const client = new ChatClient({ + connection: adapter, + id: 'chat-1', + initialMessages: [initialMessage], + persistence, + }) + + expect(client.getMessages()).toEqual([initialMessage]) + }) + + it('should fall back to initial messages when persistence returns undefined', () => { + const adapter = createMockConnectionAdapter() + const persistence = createPersistence(undefined) + + const client = new ChatClient({ + connection: adapter, + id: 'chat-1', + initialMessages: [initialMessage], + persistence, + }) + + expect(client.getMessages()).toEqual([initialMessage]) + }) + + it('should let persisted empty arrays override initial messages', () => { + const adapter = createMockConnectionAdapter() + const persistence = createPersistence([]) + + const client = new ChatClient({ + connection: adapter, + id: 'chat-1', + initialMessages: [initialMessage], + persistence, + }) + + expect(client.getMessages()).toEqual([]) + }) + + it('should hydrate from async persistence and notify message listeners', async () => { + const adapter = createMockConnectionAdapter() + const onMessagesChange = vi.fn() + const persistence = { + getItem: vi.fn(() => Promise.resolve([persistedMessage])), + setItem: vi.fn(), + removeItem: vi.fn(), + } + + const client = new ChatClient({ + connection: adapter, + id: 'chat-1', + initialMessages: [initialMessage], + onMessagesChange, + persistence, + }) + + expect(client.getMessages()).toEqual([initialMessage]) + + await vi.waitFor(() => { + expect(client.getMessages()).toEqual([persistedMessage]) + }) + + expect(onMessagesChange).toHaveBeenCalledWith([persistedMessage]) + }) + + it('should ignore async persistence hydration after local message changes', async () => { + const adapter = createMockConnectionAdapter() + const deferred = createDeferred>() + const persistence = { + getItem: vi.fn(() => deferred.promise), + setItem: vi.fn(), + removeItem: vi.fn(), + } + + const client = new ChatClient({ + connection: adapter, + id: 'chat-1', + initialMessages: [initialMessage], + persistence, + }) + + client.setMessagesManually([ + { + id: 'local-1', + role: 'user', + parts: [{ type: 'text', content: 'Local change' }], + createdAt: new Date('2024-01-03T00:00:00.000Z'), + }, + ]) + + deferred.resolve([persistedMessage]) + await new Promise((resolve) => setTimeout(resolve, 0)) + + expect(client.getMessages()).toEqual([ + { + id: 'local-1', + role: 'user', + parts: [{ type: 'text', content: 'Local change' }], + createdAt: new Date('2024-01-03T00:00:00.000Z'), + }, + ]) + }) + + it('should keep current constructor behavior when persistence is omitted', () => { + const adapter = createMockConnectionAdapter() + + const client = new ChatClient({ + connection: adapter, + initialMessages: [initialMessage], + }) + + expect(client.getMessages()).toEqual([initialMessage]) + }) + it('should use provided id or generate one', async () => { const adapter = createMockConnectionAdapter({ chunks: createTextChunks('Response'), @@ -103,49 +282,1047 @@ describe('ChatClient', () => { while (!signal?.aborted) { if (!hasPendingSend) { await new Promise((resolve) => { - removeAbortListener?.() - removeAbortListener = null - wakeSubscriber = resolve - const onAbort = () => resolve() - signal?.addEventListener('abort', onAbort, { once: true }) - removeAbortListener = () => { - signal?.removeEventListener('abort', onAbort) - } + removeAbortListener?.() + removeAbortListener = null + wakeSubscriber = resolve + const onAbort = () => resolve() + signal?.addEventListener('abort', onAbort, { once: true }) + removeAbortListener = () => { + signal?.removeEventListener('abort', onAbort) + } + }) + continue + } + + hasPendingSend = false + for (const chunk of chunksToSend) { + yield chunk + } + } + removeAbortListener?.() + removeAbortListener = null + })() + }) + + const send = vi.fn(async () => { + removeAbortListener?.() + removeAbortListener = null + hasPendingSend = true + wakeSubscriber?.() + wakeSubscriber = null + }) + + return { subscribe, send } + } + + it('should use subscribe/send adapter mode', async () => { + const adapter = createSubscribeAdapter( + createTextChunks('From subscribe/send mode'), + ) + const client = new ChatClient({ connection: adapter }) + + await client.sendMessage('Hello') + + expect(adapter.subscribe).toHaveBeenCalled() + expect(adapter.send).toHaveBeenCalled() + }) + + it('should ignore native subscribe/send chunks from a cleared persisted request without runId', async () => { + let storedMessages: Array | undefined + const releaseFirstResponse = createDeferred() + const queuedChunks: Array<{ + prompt: string + chunks: Array + }> = [] + let wakeSubscriber: (() => void) | null = null + const adapter: ConnectionAdapter = { + subscribe: vi.fn((_signal?: AbortSignal) => { + return (async function* () { + while (true) { + if (queuedChunks.length === 0) { + await new Promise((resolve) => { + wakeSubscriber = resolve + }) + } + const next = queuedChunks.shift() + if (!next) continue + if (next.prompt === 'A') { + const [started, ...remainingChunks] = next.chunks + if (started) { + yield started + } + await releaseFirstResponse.promise + yield* remainingChunks + continue + } + yield* next.chunks + } + })() + }), + send: vi.fn( + async ( + messages: Array | Array, + _data, + _signal, + runContext, + ) => { + const prompt = messages + .flatMap((message) => ('parts' in message ? message.parts : [])) + .find((part) => part.type === 'text')?.content + + queuedChunks.push({ + prompt: prompt ?? '', + chunks: [ + { + type: EventType.RUN_STARTED, + threadId: runContext?.threadId ?? 'thread-1', + runId: runContext?.runId ?? 'run-1', + timestamp: Date.now(), + } as StreamChunk, + ...createTextChunks( + prompt === 'A' ? 'stale A' : 'fresh B', + prompt === 'A' ? 'msg-a' : 'msg-b', + ).map((chunk) => { + if (chunk.type === 'TEXT_MESSAGE_CONTENT') { + const { runId: _runId, ...withoutRunId } = chunk + return withoutRunId as StreamChunk + } + if (chunk.type === 'RUN_FINISHED') { + return { + ...chunk, + threadId: runContext?.threadId ?? chunk.threadId, + runId: runContext?.runId ?? chunk.runId, + } as StreamChunk + } + return chunk + }), + ], + }) + wakeSubscriber?.() + wakeSubscriber = null + }, + ), + } + const persistence = { + getItem: vi.fn(() => undefined), + setItem: vi.fn((_key: string, messages: Array) => { + storedMessages = messages + }), + removeItem: vi.fn(() => { + storedMessages = undefined + }), + } + const client = new ChatClient({ + connection: adapter, + id: 'chat-1', + persistence, + }) + + const firstSend = client.sendMessage('A') + await vi.waitFor(() => { + expect(client.getIsLoading()).toBe(true) + }) + + client.clear() + const secondSend = client.sendMessage('B') + releaseFirstResponse.resolve() + await firstSend + await secondSend + + const finalText = client + .getMessages() + .flatMap((message) => message.parts) + .filter((part) => part.type === 'text') + .map((part) => part.content) + .join('') + + expect(finalText).toContain('B') + expect(finalText).toContain('fresh B') + expect(finalText).not.toContain('A') + expect(finalText).not.toContain('stale A') + expect(storedMessages).toEqual(client.getMessages()) + expect( + storedMessages + ?.flatMap((message) => message.parts) + .filter((part) => part.type === 'text') + .map((part) => part.content) + .join(''), + ).not.toContain('stale A') + }) + + it('should ignore already-started runless chunks from a cleared persisted request', async () => { + let storedMessages: Array | undefined + const releaseStaleChunks = createDeferred() + const staleChunksAttempted = createDeferred() + let wakeSubscriber: (() => void) | null = null + let queued = false + const adapter: ConnectionAdapter = { + subscribe: vi.fn( + (_signal?: AbortSignal): AsyncIterable => { + return (async function* () { + while (true) { + if (!queued) { + await new Promise((resolve) => { + wakeSubscriber = resolve + }) + } + queued = false + yield { + type: EventType.RUN_STARTED, + threadId: 'thread-1', + runId: 'run-cleared', + timestamp: Date.now(), + } as StreamChunk + await releaseStaleChunks.promise + yield { + type: EventType.TEXT_MESSAGE_CONTENT, + messageId: 'stale-message', + timestamp: Date.now(), + delta: 'stale content', + content: 'stale content', + } as StreamChunk + staleChunksAttempted.resolve() + yield { + type: EventType.RUN_FINISHED, + threadId: 'thread-1', + runId: 'run-cleared', + timestamp: Date.now(), + } as StreamChunk + } + })() + }, + ), + send: vi.fn(async () => { + queued = true + wakeSubscriber?.() + wakeSubscriber = null + }), + } + const persistence = { + getItem: vi.fn(() => undefined), + setItem: vi.fn((_key: string, messages: Array) => { + storedMessages = messages + }), + removeItem: vi.fn(() => { + storedMessages = undefined + }), + } + const client = new ChatClient({ + connection: adapter, + id: 'chat-1', + persistence, + }) + + const sendPromise = client.sendMessage('A') + await vi.waitFor(() => { + expect(client.getSessionGenerating()).toBe(true) + }) + + client.clear() + releaseStaleChunks.resolve() + await staleChunksAttempted.promise + await sendPromise + + expect(client.getMessages()).toEqual([]) + expect(storedMessages).toBeUndefined() + expect(client.getSessionGenerating()).toBe(false) + expect(client.getError()).toBeUndefined() + }) + + it('should keep fresh runless chunks after a clear when a fresh run starts before stale chunks drain', async () => { + let storedMessages: Array | undefined + const releaseStaleChunks = createDeferred() + const staleChunksAttempted = createDeferred() + const queuedChunks: Array<{ + prompt: string + chunks: Array + }> = [] + let staleReleased = false + let wakeSubscriber: (() => void) | null = null + const wakeQueuedSubscriber = () => { + const wake = wakeSubscriber + wakeSubscriber = null + wake?.() + } + const adapter: ConnectionAdapter = { + subscribe: vi.fn( + (_signal?: AbortSignal): AsyncIterable => { + return (async function* () { + while (true) { + if (queuedChunks.length === 0) { + await new Promise((resolve) => { + wakeSubscriber = resolve + }) + } + const freshIndex = queuedChunks.findIndex( + (queued) => queued.prompt === 'B', + ) + const next = + freshIndex >= 0 + ? queuedChunks.splice(freshIndex, 1)[0] + : queuedChunks.shift() + if (!next) continue + yield next.chunks[0]! + if (next.prompt === 'A' && !staleReleased) { + queuedChunks.push({ + prompt: 'A-after-start', + chunks: next.chunks.slice(1), + }) + continue + } + if (next.prompt === 'A-after-start' && !staleReleased) { + queuedChunks.push(next) + await new Promise((resolve) => { + wakeSubscriber = resolve + }) + continue + } + for (const chunk of next.chunks.slice(1)) { + yield chunk + } + if (next.prompt === 'A-after-start') { + staleChunksAttempted.resolve() + } + } + })() + }, + ), + send: vi.fn( + async ( + messages: Array | Array, + _data, + _signal, + runContext, + ) => { + const prompt = messages + .flatMap((message) => ('parts' in message ? message.parts : [])) + .find((part) => part.type === 'text')?.content + const messageId = prompt === 'A' ? 'stale-message' : 'fresh-message' + queuedChunks.push({ + prompt: prompt ?? '', + chunks: [ + { + type: EventType.RUN_STARTED, + threadId: runContext?.threadId ?? 'thread-1', + runId: + prompt === 'A' + ? 'run-cleared' + : (runContext?.runId ?? 'run-fresh'), + timestamp: Date.now(), + } as StreamChunk, + { + type: EventType.TEXT_MESSAGE_CONTENT, + messageId, + timestamp: Date.now(), + delta: prompt === 'A' ? 'stale content' : 'fresh content', + content: prompt === 'A' ? 'stale content' : 'fresh content', + } as StreamChunk, + { + type: EventType.RUN_FINISHED, + threadId: runContext?.threadId ?? 'thread-1', + runId: + prompt === 'A' + ? 'run-cleared' + : (runContext?.runId ?? 'run-fresh'), + timestamp: Date.now(), + } as StreamChunk, + ], + }) + wakeQueuedSubscriber() + }, + ), + } + const persistence = { + getItem: vi.fn(() => undefined), + setItem: vi.fn((_key: string, messages: Array) => { + storedMessages = messages + }), + removeItem: vi.fn(() => { + storedMessages = undefined + }), + } + const client = new ChatClient({ + connection: adapter, + id: 'chat-1', + persistence, + }) + + const firstSend = client.sendMessage('A') + await vi.waitFor(() => { + expect(client.getSessionGenerating()).toBe(true) + }) + + client.clear() + const secondSend = client.sendMessage('B') + await secondSend + staleReleased = true + releaseStaleChunks.resolve() + wakeQueuedSubscriber() + await staleChunksAttempted.promise + firstSend.catch(() => { + // The stale request may have been superseded by the fresh send. + }) + + const finalText = client + .getMessages() + .flatMap((message) => message.parts) + .filter((part) => part.type === 'text') + .map((part) => part.content) + .join('') + + expect(finalText).toContain('B') + expect(finalText).toContain('fresh content') + expect(finalText).not.toContain('A') + expect(finalText).not.toContain('stale content') + expect(storedMessages).toEqual(client.getMessages()) + }) + + it('should ignore stale messages snapshot after persisted clear', async () => { + let storedMessages: Array | undefined + const releaseSnapshot = createDeferred() + const snapshotAttempted = createDeferred() + const adapter: ConnectionAdapter = { + async *connect(_messages, _data, _signal, runContext) { + yield { + type: EventType.RUN_STARTED, + threadId: runContext?.threadId ?? 'thread-1', + runId: runContext?.runId ?? 'run-1', + timestamp: Date.now(), + } as StreamChunk + await releaseSnapshot.promise + yield { + type: EventType.MESSAGES_SNAPSHOT, + messages: [ + { + id: 'stale-assistant', + role: 'assistant', + content: 'stale snapshot', + }, + ], + } as StreamChunk + snapshotAttempted.resolve() + yield { + type: EventType.RUN_FINISHED, + threadId: runContext?.threadId ?? 'thread-1', + runId: runContext?.runId ?? 'run-1', + timestamp: Date.now(), + } as StreamChunk + }, + } + const persistence = { + getItem: vi.fn(() => undefined), + setItem: vi.fn((_key: string, messages: Array) => { + storedMessages = messages + }), + removeItem: vi.fn(() => { + storedMessages = undefined + }), + } + const client = new ChatClient({ + connection: adapter, + id: 'chat-1', + persistence, + }) + + const sendPromise = client.sendMessage('A') + await vi.waitFor(() => { + expect(client.getSessionGenerating()).toBe(true) + }) + + client.clear() + releaseSnapshot.resolve() + await snapshotAttempted.promise + await sendPromise + + expect(client.getMessages()).toEqual([]) + expect(storedMessages).toBeUndefined() + }) + + it('should ignore stale runless run error after persisted clear', async () => { + let storedMessages: Array | undefined + const onError = vi.fn() + const releaseError = createDeferred() + const errorAttempted = createDeferred() + const adapter: ConnectionAdapter = { + async *connect(_messages, _data, _signal, runContext) { + yield { + type: EventType.RUN_STARTED, + threadId: runContext?.threadId ?? 'thread-1', + runId: runContext?.runId ?? 'run-1', + timestamp: Date.now(), + } as StreamChunk + await releaseError.promise + yield { + type: EventType.RUN_ERROR, + threadId: runContext?.threadId ?? 'thread-1', + timestamp: Date.now(), + message: 'stale failure', + error: { message: 'stale failure' }, + } as StreamChunk + errorAttempted.resolve() + }, + } + const persistence = { + getItem: vi.fn(() => undefined), + setItem: vi.fn((_key: string, messages: Array) => { + storedMessages = messages + }), + removeItem: vi.fn(() => { + storedMessages = undefined + }), + } + const client = new ChatClient({ + connection: adapter, + id: 'chat-1', + persistence, + onError, + }) + + const sendPromise = client.sendMessage('A') + await vi.waitFor(() => { + expect(client.getSessionGenerating()).toBe(true) + }) + + client.clear() + releaseError.resolve() + await errorAttempted.promise + await sendPromise + + expect(client.getMessages()).toEqual([]) + expect(client.getError()).toBeUndefined() + expect(client.getStatus()).toBe('ready') + expect(onError).not.toHaveBeenCalled() + expect(storedMessages).toBeUndefined() + }) + + it('should keep fresh native subscribe/send chunks when they arrive before stale cleared chunks', async () => { + let storedMessages: Array | undefined + const releaseStaleResponse = createDeferred() + let staleReleased = false + const queuedChunks: Array<{ + prompt: string + chunks: Array + }> = [] + let wakeSubscriber: (() => void) | null = null + const adapter: ConnectionAdapter = { + subscribe: vi.fn((_signal?: AbortSignal) => { + return (async function* () { + while (true) { + if (queuedChunks.length === 0) { + await new Promise((resolve) => { + wakeSubscriber = resolve + }) + } + const freshIndex = queuedChunks.findIndex( + (queued) => queued.prompt !== 'A', + ) + const next = + !staleReleased && freshIndex > 0 + ? queuedChunks.splice(freshIndex, 1)[0] + : queuedChunks.shift() + if (!next) continue + if (next.prompt === 'A') { + if (!staleReleased) { + queuedChunks.push(next) + await new Promise((resolve) => { + wakeSubscriber = resolve + }) + continue + } + await releaseStaleResponse.promise + } + yield* next.chunks + } + })() + }), + send: vi.fn( + async ( + messages: Array | Array, + _data, + _signal, + runContext, + ) => { + const prompt = messages + .flatMap((message) => ('parts' in message ? message.parts : [])) + .find((part) => part.type === 'text')?.content + const messageId = prompt === 'A' ? 'msg-a' : 'msg-b' + + queuedChunks.push({ + prompt: prompt ?? '', + chunks: [ + { + type: EventType.RUN_STARTED, + threadId: runContext?.threadId ?? 'thread-1', + runId: runContext?.runId ?? `run-${messageId}`, + timestamp: Date.now(), + } as StreamChunk, + ...createTextChunks( + prompt === 'A' ? 'stale A' : 'fresh B', + messageId, + ).map((chunk) => { + if (chunk.type === 'TEXT_MESSAGE_CONTENT') { + const { runId: _runId, ...withoutRunId } = chunk + return withoutRunId as StreamChunk + } + if (chunk.type === 'RUN_FINISHED') { + return { + ...chunk, + threadId: runContext?.threadId ?? chunk.threadId, + runId: runContext?.runId ?? chunk.runId, + } as StreamChunk + } + return chunk + }), + ], + }) + wakeSubscriber?.() + wakeSubscriber = null + }, + ), + } + const persistence = { + getItem: vi.fn(() => undefined), + setItem: vi.fn((_key: string, messages: Array) => { + storedMessages = messages + }), + removeItem: vi.fn(() => { + storedMessages = undefined + }), + } + const client = new ChatClient({ + connection: adapter, + id: 'chat-1', + persistence, + }) + + const firstSend = client.sendMessage('A') + await vi.waitFor(() => { + expect(client.getIsLoading()).toBe(true) + }) + + client.clear() + const secondSend = client.sendMessage('B') + await secondSend + staleReleased = true + releaseStaleResponse.resolve() + const wake = wakeSubscriber as (() => void) | null + wake?.() + wakeSubscriber = null + await vi.waitFor(() => { + expect(adapter.send).toHaveBeenCalledTimes(2) + }) + firstSend.catch(() => { + // The stale request may already have been cancelled by clear(). + }) + + const finalText = client + .getMessages() + .flatMap((message) => message.parts) + .filter((part) => part.type === 'text') + .map((part) => part.content) + .join('') + + expect(finalText).toContain('B') + expect(finalText).toContain('fresh B') + expect(finalText).not.toContain('A') + expect(finalText).not.toContain('stale A') + expect(storedMessages).toEqual(client.getMessages()) + expect( + storedMessages + ?.flatMap((message) => message.parts) + .filter((part) => part.type === 'text') + .map((part) => part.content) + .join(''), + ).toContain('fresh B') + expect( + storedMessages + ?.flatMap((message) => message.parts) + .filter((part) => part.type === 'text') + .map((part) => part.content) + .join(''), + ).not.toContain('stale A') + }) + + it('should ignore stale tool chunks by cleared parentMessageId after persisted clear', async () => { + const releaseToolChunks = createDeferred() + const adapter: ConnectionAdapter = { + async *connect(_messages, _data, _signal, runContext) { + yield { + type: EventType.RUN_STARTED, + threadId: runContext?.threadId ?? 'thread-1', + runId: runContext?.runId ?? 'run-1', + timestamp: Date.now(), + } as StreamChunk + yield { + type: 'TEXT_MESSAGE_CONTENT', + messageId: 'assistant-a', + model: 'test', + timestamp: Date.now(), + delta: '', + content: '', + } as StreamChunk + await releaseToolChunks.promise + yield { + type: 'TOOL_CALL_START', + toolCallId: 'stale-tool', + toolCallName: 'staleTool', + toolName: 'staleTool', + parentMessageId: 'assistant-a', + model: 'test', + timestamp: Date.now(), + index: 0, + } as StreamChunk + yield { + type: 'TOOL_CALL_ARGS', + toolCallId: 'stale-tool', + model: 'test', + timestamp: Date.now(), + delta: '{"stale":true}', + } as StreamChunk + }, + } + const persistence = createPersistence(undefined) + const client = new ChatClient({ + connection: adapter, + id: 'chat-1', + persistence, + }) + + const sendPromise = client.sendMessage('A') + await vi.waitFor(() => { + expect( + client.getMessages().some((message) => message.id === 'assistant-a'), + ).toBe(true) + }) + + client.clear() + releaseToolChunks.resolve() + await sendPromise + + expect(client.getMessages()).toEqual([]) + expect(persistence.setItem).not.toHaveBeenLastCalledWith( + 'chat-1', + expect.arrayContaining([ + expect.objectContaining({ + id: 'assistant-a', + parts: expect.arrayContaining([ + expect.objectContaining({ type: 'tool-call' }), + ]), + }), + ]), + ) + }) + + it('should remember ignored stale runless message ids so child tool chunks are ignored after persisted clear', async () => { + const releaseToolChunks = createDeferred() + const adapter: ConnectionAdapter = { + async *connect(_messages, _data, _signal, runContext) { + yield { + type: EventType.RUN_STARTED, + threadId: runContext?.threadId ?? 'thread-1', + runId: runContext?.runId ?? 'run-1', + timestamp: Date.now(), + } as StreamChunk + await releaseToolChunks.promise + yield { + type: EventType.TEXT_MESSAGE_CONTENT, + messageId: 'stale-runless-message', + timestamp: Date.now(), + delta: 'stale text', + content: 'stale text', + } as StreamChunk + yield { + type: 'TOOL_CALL_START', + toolCallId: 'stale-child-tool', + toolCallName: 'staleTool', + toolName: 'staleTool', + parentMessageId: 'stale-runless-message', + model: 'test', + timestamp: Date.now(), + index: 0, + } as StreamChunk + yield { + type: 'TOOL_CALL_ARGS', + toolCallId: 'stale-child-tool', + model: 'test', + timestamp: Date.now(), + delta: '{"stale":true}', + } as StreamChunk + }, + } + const persistence = createPersistence(undefined) + const client = new ChatClient({ + connection: adapter, + id: 'chat-1', + persistence, + }) + + const sendPromise = client.sendMessage('A') + await vi.waitFor(() => { + expect(client.getSessionGenerating()).toBe(true) + }) + + client.clear() + releaseToolChunks.resolve() + await sendPromise + + expect(client.getMessages()).toEqual([]) + expect(persistence.setItem).not.toHaveBeenLastCalledWith( + 'chat-1', + expect.arrayContaining([ + expect.objectContaining({ + id: 'stale-runless-message', + }), + ]), + ) + }) + + it('should ignore stale runless tool starts without parentMessageId after persisted clear', async () => { + const releaseToolChunks = createDeferred() + const adapter: ConnectionAdapter = { + async *connect(_messages, _data, _signal, runContext) { + yield { + type: EventType.RUN_STARTED, + threadId: runContext?.threadId ?? 'thread-1', + runId: runContext?.runId ?? 'run-1', + timestamp: Date.now(), + } as StreamChunk + await releaseToolChunks.promise + yield { + type: 'TOOL_CALL_START', + toolCallId: 'stale-parentless-tool', + toolCallName: 'staleTool', + toolName: 'staleTool', + model: 'test', + timestamp: Date.now(), + index: 0, + } as StreamChunk + yield { + type: 'TOOL_CALL_ARGS', + toolCallId: 'stale-parentless-tool', + model: 'test', + timestamp: Date.now(), + delta: '{"stale":true}', + } as StreamChunk + }, + } + const persistence = createPersistence(undefined) + const client = new ChatClient({ + connection: adapter, + id: 'chat-1', + persistence, + }) + + const sendPromise = client.sendMessage('A') + await vi.waitFor(() => { + expect(client.getSessionGenerating()).toBe(true) + }) + + client.clear() + releaseToolChunks.resolve() + await sendPromise + + expect(client.getMessages()).toEqual([]) + expect(persistence.setItem).not.toHaveBeenLastCalledWith( + 'chat-1', + expect.arrayContaining([ + expect.objectContaining({ + parts: expect.arrayContaining([ + expect.objectContaining({ type: 'tool-call' }), + ]), + }), + ]), + ) + }) + + it('should reset session generation when a persisted clear ignores terminal chunks', async () => { + const releaseResponse = createDeferred() + let wakeSubscriber: (() => void) | null = null + let queued = false + const adapter: ConnectionAdapter = { + subscribe: vi.fn( + (_signal?: AbortSignal): AsyncIterable => { + return (async function* () { + while (true) { + if (!queued) { + await new Promise((resolve) => { + wakeSubscriber = resolve + }) + } + queued = false + yield { + type: EventType.RUN_STARTED, + threadId: 'thread-1', + runId: 'run-1', + timestamp: Date.now(), + } as StreamChunk + await releaseResponse.promise + yield { + type: EventType.RUN_FINISHED, + threadId: 'thread-1', + runId: 'run-1', + model: 'test', + timestamp: Date.now(), + finishReason: 'stop', + } as StreamChunk + } + })() + }, + ), + send: vi.fn(async () => { + queued = true + wakeSubscriber?.() + wakeSubscriber = null + }), + } + const client = new ChatClient({ + connection: adapter, + id: 'chat-1', + persistence: createPersistence(), + }) + + const sendPromise = client.sendMessage('A') + await vi.waitFor(() => { + expect(client.getSessionGenerating()).toBe(true) + }) + + client.clear() + releaseResponse.resolve() + await sendPromise + + expect(client.getSessionGenerating()).toBe(false) + }) + + it('should ignore live-only active run chunks after persisted clear and clean up session generation', async () => { + let storedMessages: Array | undefined + const releaseAfterClear = createDeferred() + const subscriberReady = createDeferred<() => void>() + const adapter: ConnectionAdapter = { + subscribe: vi.fn( + (_signal?: AbortSignal): AsyncIterable => { + return (async function* () { + await new Promise((resolve) => { + subscriberReady.resolve(resolve) }) - continue - } + yield { + type: EventType.RUN_STARTED, + threadId: 'thread-1', + runId: 'live-run-1', + timestamp: Date.now(), + } as StreamChunk + await releaseAfterClear.promise + yield { + type: EventType.TEXT_MESSAGE_START, + messageId: 'live-message-1', + role: 'assistant', + } as StreamChunk + yield { + type: EventType.TEXT_MESSAGE_CONTENT, + messageId: 'live-message-1', + delta: 'stale live content', + } as StreamChunk + yield { + type: EventType.RUN_FINISHED, + threadId: 'thread-1', + runId: 'live-run-1', + model: 'test', + timestamp: Date.now(), + finishReason: 'stop', + } as StreamChunk + })() + }, + ), + send: vi.fn(), + } + const persistence = { + getItem: vi.fn(() => undefined), + setItem: vi.fn((_key: string, messages: Array) => { + storedMessages = messages + }), + removeItem: vi.fn(() => { + storedMessages = undefined + }), + } + const client = new ChatClient({ + connection: adapter, + id: 'chat-1', + persistence, + }) - hasPendingSend = false - for (const chunk of chunksToSend) { - yield chunk - } - } - removeAbortListener?.() - removeAbortListener = null - })() + client.subscribe() + const wakeSubscriber = await subscriberReady.promise + wakeSubscriber() + await vi.waitFor(() => { + expect(client.getSessionGenerating()).toBe(true) }) - const send = vi.fn(async () => { - removeAbortListener?.() - removeAbortListener = null - hasPendingSend = true - wakeSubscriber?.() - wakeSubscriber = null + client.clear() + releaseAfterClear.resolve() + + await vi.waitFor(() => { + expect(client.getSessionGenerating()).toBe(false) }) + expect(client.getMessages()).toEqual([]) + expect(storedMessages).toBeUndefined() + }) - return { subscribe, send } - } + it('should not expose request generation metadata on public chunks or run context', async () => { + const onChunk = vi.fn() + const runContextSpy = vi.fn() + const client = new ChatClient({ + connection: { + async *connect(_messages, _data, _abortSignal, runContext) { + runContextSpy(runContext) + yield* createTextChunks('Hello') + }, + }, + onChunk, + }) - it('should use subscribe/send adapter mode', async () => { - const adapter = createSubscribeAdapter( - createTextChunks('From subscribe/send mode'), + await client.sendMessage('Hello') + + expect(runContextSpy).toHaveBeenCalled() + expect(runContextSpy.mock.calls[0]![0]).not.toHaveProperty( + 'requestGeneration', ) - const client = new ChatClient({ connection: adapter }) + expect(onChunk).toHaveBeenCalled() + for (const [chunk] of onChunk.mock.calls) { + expect(Object.keys(chunk)).not.toContain('requestGeneration') + expect(chunk).not.toHaveProperty('requestGeneration') + } + }) + + it('should not add internal threadId or runId to public connect chunks', async () => { + const onChunk = vi.fn() + const client = new ChatClient({ + connection: { + async *connect() { + yield { + type: EventType.TEXT_MESSAGE_CONTENT, + messageId: 'public-message', + timestamp: Date.now(), + delta: 'Hello', + content: 'Hello', + } as StreamChunk + }, + }, + onChunk, + }) await client.sendMessage('Hello') - expect(adapter.subscribe).toHaveBeenCalled() - expect(adapter.send).toHaveBeenCalled() + expect(onChunk).toHaveBeenCalledWith( + expect.objectContaining({ + type: EventType.TEXT_MESSAGE_CONTENT, + messageId: 'public-message', + }), + ) + for (const [chunk] of onChunk.mock.calls) { + if (chunk.type === EventType.TEXT_MESSAGE_CONTENT) { + expect(chunk).not.toHaveProperty('threadId') + expect(chunk).not.toHaveProperty('runId') + } + } }) it('stop should not unsubscribe an active subscription', async () => { @@ -736,6 +1913,60 @@ describe('ChatClient', () => { client.unsubscribe() }) + it('should process future live subscription chunks after persistence clear', async () => { + const wake = { fn: null as (() => void) | null } + const chunks: Array = [] + const connection = { + subscribe: async function* (signal?: AbortSignal) { + while (!signal?.aborted) { + if (chunks.length > 0) { + const batch = chunks.splice(0) + for (const chunk of batch) { + yield chunk + } + } + await new Promise((resolve) => { + wake.fn = resolve + const onAbort = () => resolve() + signal?.addEventListener('abort', onAbort, { once: true }) + }) + } + }, + send: async () => { + wake.fn?.() + }, + } + const persistence = createPersistence() + const client = new ChatClient({ + connection, + id: 'chat-1', + persistence, + }) + + client.subscribe() + await vi.waitFor(() => { + expect(client.getIsSubscribed()).toBe(true) + }) + + client.clear() + chunks.push(...createTextChunks('future live', 'future-live')) + wake.fn?.() + + await vi.waitFor(() => { + expect( + client + .getMessages() + .flatMap((message) => message.parts) + .some( + (part) => + part.type === 'text' && part.content.includes('future live'), + ), + ).toBe(true) + }) + + client.unsubscribe() + }) + it('should clear all runs on RUN_ERROR without runId', async () => { const wake = { fn: null as (() => void) | null } const chunks: Array = [] @@ -1069,6 +2300,280 @@ describe('ChatClient', () => { expect(client.getMessages().length).toBe(0) expect(client.getError()).toBeUndefined() }) + + it('should remove persisted messages without saving an empty snapshot', async () => { + const chunks = createTextChunks('Response') + const adapter = createMockConnectionAdapter({ chunks }) + const persistence = createPersistence() + const client = new ChatClient({ + connection: adapter, + id: 'chat-1', + persistence, + }) + + await client.sendMessage('Hello') + vi.mocked(persistence.setItem).mockClear() + + client.clear() + + expect(persistence.removeItem).toHaveBeenCalledWith('chat-1') + expect(persistence.setItem).not.toHaveBeenCalled() + }) + + it('should not abort an in-flight stream when persistence is omitted', async () => { + let abortSignal: AbortSignal | undefined + // Gate the chunks on a deferred (instead of a fixed timer) so they are + // released strictly after clear() runs — otherwise the assertion races + // the stream and is flaky on faster machines/CI. + const releaseChunks = createDeferred() + + const adapter: ConnectConnectionAdapter = { + async *connect(_messages, _data, signal) { + abortSignal = signal + await releaseChunks.promise + yield* createTextChunks('Delayed') + }, + } + const client = new ChatClient({ connection: adapter }) + + const sendPromise = client.sendMessage('Hello') + await vi.waitFor(() => { + expect(abortSignal).toBeDefined() + }) + + client.clear() + expect(abortSignal?.aborted).toBe(false) + + // Without persistence, clear() does not abort the in-flight stream, so + // its chunks still populate messages once they arrive. + releaseChunks.resolve() + await sendPromise + + expect(client.getMessages()).toEqual( + expect.arrayContaining([ + expect.objectContaining({ role: 'assistant' }), + ]), + ) + }) + + it('should prevent delayed stream chunks from recreating messages after clear', async () => { + const adapter = createMockConnectionAdapter({ + chunks: createTextChunks('Delayed'), + chunkDelay: 20, + }) + const persistence = createPersistence() + const client = new ChatClient({ + connection: adapter, + id: 'chat-1', + persistence, + }) + + const sendPromise = client.sendMessage('Hello') + await new Promise((resolve) => setTimeout(resolve, 5)) + + client.clear() + await sendPromise + + expect(client.getMessages()).toEqual([]) + expect(persistence.removeItem).toHaveBeenCalledWith('chat-1') + expect(persistence.setItem).not.toHaveBeenLastCalledWith( + 'chat-1', + expect.arrayContaining([ + expect.objectContaining({ role: 'assistant' }), + ]), + ) + }) + + it('should not persist non-cooperative delayed chunks after clear removes storage', async () => { + const persistenceEvents: Array = [] + const persistence = { + getItem: vi.fn(() => undefined), + setItem: vi.fn(() => { + persistenceEvents.push('set') + }), + removeItem: vi.fn(() => { + persistenceEvents.push('remove') + }), + } + const adapter: ConnectConnectionAdapter = { + async *connect() { + await new Promise((resolve) => setTimeout(resolve, 10)) + yield* createTextChunks('Delayed') + }, + } + const client = new ChatClient({ + connection: adapter, + id: 'chat-1', + persistence, + }) + + const sendPromise = client.sendMessage('Hello') + await new Promise((resolve) => setTimeout(resolve, 0)) + + client.clear() + await sendPromise + + const removeIndex = persistenceEvents.indexOf('remove') + expect(removeIndex).toBeGreaterThanOrEqual(0) + expect(persistenceEvents.slice(removeIndex + 1)).not.toContain('set') + expect(client.getMessages()).toEqual([]) + expect(persistence.removeItem).toHaveBeenCalledWith('chat-1') + }) + + it('should ignore cleared request chunks after persistence resumes for a new send', async () => { + let storedMessages: Array | undefined + const releaseFirstResponse = createDeferred() + const connection: ConnectConnectionAdapter = { + async *connect(messages) { + const userText = messages + .flatMap((message) => ('parts' in message ? message.parts : [])) + .find((part) => part.type === 'text')?.content + + if (userText === 'A') { + await releaseFirstResponse.promise + yield* createTextChunks('stale A', 'msg-a') + return + } + + yield* createTextChunks('fresh B', 'msg-b') + }, + } + const persistence = { + getItem: vi.fn(() => undefined), + setItem: vi.fn((_key: string, messages: Array) => { + storedMessages = messages + }), + removeItem: vi.fn(() => { + storedMessages = undefined + }), + } + const client = new ChatClient({ + connection, + id: 'chat-1', + persistence, + }) + + const firstSend = client.sendMessage('A') + await vi.waitFor(() => { + expect(client.getIsLoading()).toBe(true) + }) + + client.clear() + await client.sendMessage('B') + releaseFirstResponse.resolve() + await firstSend + + const finalText = client + .getMessages() + .flatMap((message) => message.parts) + .filter((part) => part.type === 'text') + .map((part) => part.content) + .join('') + + expect(finalText).toContain('B') + expect(finalText).toContain('fresh B') + expect(finalText).not.toContain('A') + expect(finalText).not.toContain('stale A') + expect(storedMessages).toEqual(client.getMessages()) + expect( + storedMessages + ?.flatMap((message) => message.parts) + .filter((part) => part.type === 'text') + .map((part) => part.content) + .join(''), + ).not.toContain('stale A') + }) + + it('should ensure async setItem scheduled before clear cannot win after removeItem', async () => { + let storedMessages: Array | undefined + const releaseSet = createDeferred() + const persistence = { + getItem: vi.fn(() => undefined), + setItem: vi.fn(async (_key: string, messages: Array) => { + await releaseSet.promise + storedMessages = messages + }), + removeItem: vi.fn(() => { + storedMessages = undefined + }), + } + const client = new ChatClient({ + connection: createMockConnectionAdapter(), + id: 'chat-1', + persistence, + }) + + client.setMessagesManually([initialMessage]) + client.clear() + + releaseSet.resolve() + await vi.waitFor(() => { + expect(persistence.removeItem).toHaveBeenCalledWith('chat-1') + }) + + expect(storedMessages).toBeUndefined() + }) + }) + + describe('persistence', () => { + it('should save message snapshots after sendMessage changes messages', async () => { + const chunks = createTextChunks('Response') + const adapter = createMockConnectionAdapter({ chunks }) + const persistence = createPersistence() + const client = new ChatClient({ + connection: adapter, + id: 'chat-1', + persistence, + }) + + await client.sendMessage('Hello') + + expect(persistence.setItem).toHaveBeenCalled() + expect(persistence.setItem).toHaveBeenLastCalledWith( + 'chat-1', + client.getMessages(), + ) + }) + + it('should save message snapshots when messages are set manually', () => { + const adapter = createMockConnectionAdapter() + const persistence = createPersistence() + const client = new ChatClient({ + connection: adapter, + id: 'chat-1', + persistence, + }) + + client.setMessagesManually([initialMessage]) + + expect(persistence.setItem).toHaveBeenCalledWith('chat-1', [ + initialMessage, + ]) + }) + + it('should swallow async persistence write and remove failures', async () => { + const adapter = createMockConnectionAdapter({ + chunks: createTextChunks('Hi'), + }) + const persistence = { + getItem: vi.fn(() => undefined), + setItem: vi.fn(() => Promise.reject(new Error('set failed'))), + removeItem: vi.fn(() => Promise.reject(new Error('remove failed'))), + } + const client = new ChatClient({ + connection: adapter, + id: 'chat-1', + persistence, + }) + + await client.sendMessage('Hello') + client.clear() + await new Promise((resolve) => setTimeout(resolve, 0)) + + expect(client.getMessages()).toEqual([]) + expect(persistence.setItem).toHaveBeenCalled() + expect(persistence.removeItem).toHaveBeenCalledWith('chat-1') + }) }) describe('callbacks', () => { @@ -1088,6 +2593,40 @@ describe('ChatClient', () => { expect(onMessagesChange.mock.calls.length).toBeGreaterThan(0) }) + it('should preserve state updates and onMessagesChange when persistence throws', async () => { + const chunks = createTextChunks('Response') + const adapter = createMockConnectionAdapter({ chunks }) + const onMessagesChange = vi.fn() + const persistence: ChatClientPersistence = { + getItem: vi.fn(() => { + throw new Error('get failed') + }), + setItem: vi.fn(() => { + throw new Error('set failed') + }), + removeItem: vi.fn(() => { + throw new Error('remove failed') + }), + } + + const client = new ChatClient({ + connection: adapter, + id: 'chat-1', + initialMessages: [initialMessage], + onMessagesChange, + persistence, + }) + + expect(client.getMessages()).toEqual([initialMessage]) + + await client.sendMessage('Hello') + expect(client.getMessages().length).toBeGreaterThan(1) + expect(onMessagesChange).toHaveBeenCalled() + + expect(() => client.clear()).not.toThrow() + expect(client.getMessages()).toEqual([]) + }) + it('should call onLoadingChange when loading state changes', async () => { const chunks = createTextChunks('Response') const adapter = createMockConnectionAdapter({ chunks }) diff --git a/packages/ai-client/tests/client-persistor.test.ts b/packages/ai-client/tests/client-persistor.test.ts new file mode 100644 index 000000000..05cc4380c --- /dev/null +++ b/packages/ai-client/tests/client-persistor.test.ts @@ -0,0 +1,497 @@ +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { EventType } from '@tanstack/ai/client' +import { ChatPersistor } from '../src/client-persistor' +import { createMockPersistence, createUIMessage } from './test-utils' +import type { StreamChunk } from '@tanstack/ai/client' +import type { ChatClientPersistence, UIMessage } from '../src/types' + +const CHAT_ID = 'chat-1' + +/** Resolve after pending micro- and macro-tasks have drained. */ +function flushAsync(): Promise { + return new Promise((resolve) => setTimeout(resolve, 0)) +} + +/** A promise with externally accessible resolve/reject. */ +function createDeferred() { + let resolve!: (value: T) => void + let reject!: (reason?: unknown) => void + const promise = new Promise((res, rej) => { + resolve = res + reject = rej + }) + return { promise, resolve, reject } +} + +/** + * Build a persistor wired to `adapter` with a spy `applyMessages` callback so + * tests can assert what (if anything) gets re-applied to the client. + */ +function createPersistor(adapter: ChatClientPersistence, id: string = CHAT_ID) { + const applyMessages = vi.fn<(messages: Array) => void>() + const persistor = new ChatPersistor(adapter, id, applyMessages) + return { persistor, applyMessages } +} + +// --- typed StreamChunk fixtures ------------------------------------------- +// These satisfy the StreamChunk union directly (no casts). The persistor only +// ever reads `type` and the relevant id field off a chunk, and unit-test chunks +// are never run through the connection adapter's run-context map, so the run id +// must live on the chunk itself (only RUN_STARTED/RUN_FINISHED carry one). + +function runStarted(runId: string, threadId: string = 'thread-1'): StreamChunk { + return { type: EventType.RUN_STARTED, threadId, runId } +} + +function runFinished( + runId: string, + threadId: string = 'thread-1', +): StreamChunk { + return { type: EventType.RUN_FINISHED, threadId, runId } +} + +function runError(message: string = 'boom'): StreamChunk { + return { type: EventType.RUN_ERROR, message } +} + +function textContent(messageId: string, delta: string = 'hi'): StreamChunk { + return { type: EventType.TEXT_MESSAGE_CONTENT, messageId, delta } +} + +function toolCallStart( + toolCallId: string, + parentMessageId?: string, +): StreamChunk { + return { + type: EventType.TOOL_CALL_START, + toolCallId, + toolCallName: 'tool', + toolName: 'tool', + ...(parentMessageId ? { parentMessageId } : {}), + } +} + +function messagesSnapshot(): StreamChunk { + return { type: EventType.MESSAGES_SNAPSHOT, messages: [] } +} + +describe('ChatPersistor', () => { + describe('readInitial', () => { + it('returns the stored messages from a synchronous getItem', () => { + const stored = [createUIMessage('m-1')] + const adapter = createMockPersistence(stored) + const { persistor } = createPersistor(adapter) + + expect(persistor.readInitial()).toBe(stored) + expect(adapter.getItem).toHaveBeenCalledWith(CHAT_ID) + }) + + it('returns the promise from an asynchronous getItem', () => { + const stored = [createUIMessage('m-1')] + const adapter = createMockPersistence() + adapter.getItem = vi.fn(() => Promise.resolve(stored)) + const { persistor } = createPersistor(adapter) + + expect(persistor.readInitial()).toBeInstanceOf(Promise) + }) + + it('swallows a throwing getItem and returns undefined', () => { + const adapter = createMockPersistence() + adapter.getItem = vi.fn(() => { + throw new Error('storage unavailable') + }) + const { persistor } = createPersistor(adapter) + + expect(persistor.readInitial()).toBeUndefined() + }) + }) + + describe('hydrateAsync', () => { + it('applies messages once the promise resolves to an array', async () => { + const stored = [createUIMessage('persisted-1')] + const { persistor, applyMessages } = createPersistor( + createMockPersistence(), + ) + + persistor.hydrateAsync(Promise.resolve(stored)) + await flushAsync() + + expect(applyMessages).toHaveBeenCalledWith(stored) + }) + + it.each([ + ['null', null], + ['undefined', undefined], + ])( + 'does not apply when the promise resolves to %s', + async (_label, value) => { + const { persistor, applyMessages } = createPersistor( + createMockPersistence(), + ) + + persistor.hydrateAsync(Promise.resolve(value)) + await flushAsync() + + expect(applyMessages).not.toHaveBeenCalled() + }, + ) + + it('does nothing for a synchronous (non-promise) value', async () => { + const { persistor, applyMessages } = createPersistor( + createMockPersistence(), + ) + + persistor.hydrateAsync([createUIMessage('m-1')]) + await flushAsync() + + expect(applyMessages).not.toHaveBeenCalled() + }) + + it('does not apply if messages changed before hydration resolves', async () => { + const deferred = createDeferred>() + const { persistor, applyMessages } = createPersistor( + createMockPersistence(), + ) + + persistor.hydrateAsync(deferred.promise) + // A local change lands before the slow getItem resolves. + persistor.notifyMessagesChanged([createUIMessage('local-1')]) + deferred.resolve([createUIMessage('persisted-1')]) + await flushAsync() + + expect(applyMessages).not.toHaveBeenCalled() + }) + + it('swallows a rejected hydration promise', async () => { + const { persistor, applyMessages } = createPersistor( + createMockPersistence(), + ) + + persistor.hydrateAsync(Promise.reject(new Error('read failed'))) + await flushAsync() + + expect(applyMessages).not.toHaveBeenCalled() + }) + }) + + describe('notifyMessagesChanged', () => { + it('persists the messages on change', () => { + const adapter = createMockPersistence() + const { persistor } = createPersistor(adapter) + const messages = [createUIMessage('m-1')] + + persistor.notifyMessagesChanged(messages) + + expect(adapter.setItem).toHaveBeenCalledWith(CHAT_ID, messages) + }) + + it('persists a snapshot, not the live array reference', () => { + const adapter = createMockPersistence() + const { persistor } = createPersistor(adapter) + const messages = [createUIMessage('m-1')] + + persistor.notifyMessagesChanged(messages) + messages.push(createUIMessage('m-2')) + + const persisted = vi.mocked(adapter.setItem).mock.calls[0]?.[1] + expect(persisted).toHaveLength(1) + }) + + it('skips exactly one write after beginClear, then resumes', () => { + const adapter = createMockPersistence() + const { persistor } = createPersistor(adapter) + + persistor.notifyMessagesChanged([createUIMessage('m-1')]) + persistor.beginClear() + persistor.notifyMessagesChanged([]) // the clear's empty snapshot — skipped + persistor.notifyMessagesChanged([createUIMessage('m-2')]) + + expect(adapter.setItem).toHaveBeenCalledTimes(2) + }) + + it('swallows a synchronous setItem error', () => { + const adapter = createMockPersistence() + adapter.setItem = vi.fn(() => { + throw new Error('quota exceeded') + }) + const { persistor } = createPersistor(adapter) + + expect(() => + persistor.notifyMessagesChanged([createUIMessage('m-1')]), + ).not.toThrow() + }) + + it('runs async writes sequentially (FIFO, no overlap)', async () => { + const first = createDeferred() + const second = createDeferred() + const adapter = createMockPersistence() + adapter.setItem = vi + .fn() + .mockReturnValueOnce(first.promise) + .mockReturnValueOnce(second.promise) + const { persistor } = createPersistor(adapter) + + persistor.notifyMessagesChanged([createUIMessage('a')]) + persistor.notifyMessagesChanged([createUIMessage('b')]) + // Second write is queued behind the first and must not have started yet. + expect(adapter.setItem).toHaveBeenCalledTimes(1) + + first.resolve() + await flushAsync() + + expect(adapter.setItem).toHaveBeenCalledTimes(2) + expect(vi.mocked(adapter.setItem).mock.calls[1]?.[1]).toEqual([ + createUIMessage('b'), + ]) + }) + + it('keeps writing after an async write rejects', async () => { + const adapter = createMockPersistence() + adapter.setItem = vi + .fn() + .mockRejectedValueOnce(new Error('transient')) + .mockResolvedValue(undefined) + const { persistor } = createPersistor(adapter) + + persistor.notifyMessagesChanged([createUIMessage('a')]) + persistor.notifyMessagesChanged([createUIMessage('b')]) + await flushAsync() + + expect(adapter.setItem).toHaveBeenCalledTimes(2) + }) + }) + + describe('remove', () => { + it('removes the persisted conversation', () => { + const adapter = createMockPersistence() + const { persistor } = createPersistor(adapter) + + persistor.remove() + + expect(adapter.removeItem).toHaveBeenCalledWith(CHAT_ID) + }) + + it('swallows a synchronous removeItem error', () => { + const adapter = createMockPersistence() + adapter.removeItem = vi.fn(() => { + throw new Error('locked') + }) + const { persistor } = createPersistor(adapter) + + expect(() => persistor.remove()).not.toThrow() + }) + + it('invalidates a queued write that a removal supersedes', async () => { + const inFlight = createDeferred() + const adapter = createMockPersistence() + adapter.setItem = vi + .fn() + .mockReturnValueOnce(inFlight.promise) // write A (in flight) + .mockResolvedValue(undefined) // write B (would run later) + const { persistor } = createPersistor(adapter) + + persistor.notifyMessagesChanged([createUIMessage('a')]) // starts, in flight + persistor.notifyMessagesChanged([createUIMessage('b')]) // queued behind A + persistor.remove() // bumps generation, queued behind B + + inFlight.resolve() + await flushAsync() + + // B's generation was superseded by remove(), so only A was written. + expect(adapter.setItem).toHaveBeenCalledTimes(1) + expect(adapter.removeItem).toHaveBeenCalledTimes(1) + }) + }) + + describe('clear-during-stream suppression', () => { + it('does not ignore chunks before a clear', () => { + const { persistor } = createPersistor(createMockPersistence()) + + expect(persistor.shouldIgnoreChunk(runStarted('run-1'))).toBe(false) + expect(persistor.shouldIgnoreChunk(textContent('msg-1'))).toBe(false) + }) + + it('ignores chunks for a message id captured at clear time', () => { + const { persistor } = createPersistor(createMockPersistence()) + persistor.snapshotClear({ + messages: [createUIMessage('msg-1')], + activeRunIds: new Set(), + currentRunId: null, + }) + + expect(persistor.shouldIgnoreChunk(textContent('msg-1'))).toBe(true) + expect(persistor.shouldIgnoreChunk(textContent('msg-other'))).toBe(false) + }) + + it('ignores a RUN_STARTED for an active run captured at clear time', () => { + const { persistor } = createPersistor(createMockPersistence()) + persistor.snapshotClear({ + messages: [], + activeRunIds: new Set(['run-1']), + currentRunId: null, + }) + + expect(persistor.shouldIgnoreChunk(runStarted('run-1'))).toBe(true) + }) + + it('ignores the in-flight currentRunId captured at clear time', () => { + const { persistor } = createPersistor(createMockPersistence()) + persistor.snapshotClear({ + messages: [], + activeRunIds: new Set(), + currentRunId: 'run-current', + }) + + expect(persistor.shouldIgnoreChunk(runStarted('run-current'))).toBe(true) + }) + + it('ignores runless content chunks belonging to a cleared run', () => { + const { persistor } = createPersistor(createMockPersistence()) + persistor.snapshotClear({ + messages: [], + activeRunIds: new Set(['run-1']), + currentRunId: null, + }) + // The cleared run starts again, pinning it as the current runless run. + persistor.shouldIgnoreChunk(runStarted('run-1')) + + expect(persistor.shouldIgnoreChunk(textContent('late-msg'))).toBe(true) + expect(persistor.shouldIgnoreChunk(messagesSnapshot())).toBe(true) + }) + + it('ignores tool chunks by cleared parentMessageId and remembers the toolCallId', () => { + const { persistor } = createPersistor(createMockPersistence()) + persistor.snapshotClear({ + messages: [createUIMessage('parent-msg')], + activeRunIds: new Set(), + currentRunId: null, + }) + + // First seen via its parent message id... + expect( + persistor.shouldIgnoreChunk(toolCallStart('tc-1', 'parent-msg')), + ).toBe(true) + // ...then a later chunk for the same tool call (no parent) is still ignored. + expect(persistor.shouldIgnoreChunk(toolCallStart('tc-1'))).toBe(true) + }) + }) + + describe('run lifecycle hooks', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('forgets a settled run so a later run with the same id is not ignored', () => { + const { persistor } = createPersistor(createMockPersistence()) + persistor.snapshotClear({ + messages: [], + activeRunIds: new Set(['run-1']), + currentRunId: null, + }) + expect(persistor.shouldIgnoreChunk(runStarted('run-1'))).toBe(true) + + persistor.onRunSettled('run-1') + + expect(persistor.shouldIgnoreChunk(runStarted('run-1'))).toBe(false) + }) + + it('stops ignoring runless chunks after a session-level run error', () => { + const { persistor } = createPersistor(createMockPersistence()) + persistor.snapshotClear({ + messages: [], + activeRunIds: new Set(['run-1']), + currentRunId: null, + }) + persistor.shouldIgnoreChunk(runStarted('run-1')) + expect(persistor.shouldIgnoreChunk(textContent('late-1'))).toBe(true) + + persistor.onSessionRunError() + + // A fresh runless chunk is no longer attributed to the cleared run. (An + // already-seen message id stays ignored — shouldIgnoreChunk records it.) + expect(persistor.shouldIgnoreChunk(textContent('late-2'))).toBe(false) + }) + + it('takeRunlessRunId returns and clears the current runless run, then null', () => { + const { persistor } = createPersistor(createMockPersistence()) + persistor.snapshotClear({ + messages: [], + activeRunIds: new Set(['run-1']), + currentRunId: null, + }) + persistor.shouldIgnoreChunk(runStarted('run-1')) + + expect(persistor.takeRunlessRunId()).toBe('run-1') + expect(persistor.takeRunlessRunId()).toBeNull() + // The drained run is no longer pinned, so runless chunks aren't ignored. + expect(persistor.shouldIgnoreChunk(textContent('late'))).toBe(false) + }) + + it('resetIgnored only clears active-run markers, not cleared ids', () => { + const { persistor } = createPersistor(createMockPersistence()) + persistor.snapshotClear({ + messages: [createUIMessage('msg-1')], + activeRunIds: new Set(['run-1']), + currentRunId: null, + }) + + persistor.resetIgnored() + + // Cleared run/message ids survive a reset (they outlive the active run). + expect(persistor.shouldIgnoreChunk(runStarted('run-1'))).toBe(true) + expect(persistor.shouldIgnoreChunk(textContent('msg-1'))).toBe(true) + }) + + it('does not ignore an unrelated runId-less run error', () => { + const { persistor } = createPersistor(createMockPersistence()) + + expect(persistor.shouldIgnoreChunk(runError())).toBe(false) + }) + + it('onRunStarted for a non-cleared run does not suppress runless chunks', () => { + const { persistor } = createPersistor(createMockPersistence()) + + // Pinning a live (never-cleared) run must not start swallowing its output. + persistor.onRunStarted('run-live') + + expect(persistor.shouldIgnoreChunk(textContent('content'))).toBe(false) + }) + + it('advances the runless pointer when the active run finishes', () => { + const { persistor } = createPersistor(createMockPersistence()) + persistor.snapshotClear({ + messages: [], + activeRunIds: new Set(['run-1', 'run-2']), + currentRunId: null, + }) + persistor.shouldIgnoreChunk(runStarted('run-1')) + persistor.shouldIgnoreChunk(runStarted('run-2')) + + // run-2 was the pinned runless run; settling it falls back to run-1. + persistor.onRunSettled('run-2') + + expect(persistor.shouldIgnoreChunk(runFinished('run-2'))).toBe(false) + expect(persistor.shouldIgnoreChunk(textContent('late'))).toBe(true) + }) + + it('keeps suppressing a second cleared run after the first drains via a runless error', () => { + const { persistor } = createPersistor(createMockPersistence()) + persistor.snapshotClear({ + messages: [], + activeRunIds: new Set(['run-1', 'run-2']), + currentRunId: null, + }) + persistor.shouldIgnoreChunk(runStarted('run-1')) + persistor.shouldIgnoreChunk(runStarted('run-2')) // run-2 now pinned + + // A runId-less RUN_ERROR drains the pinned run (run-2)... + expect(persistor.takeRunlessRunId()).toBe('run-2') + + // ...but run-1 is still cleared, so its runless content stays suppressed + // instead of leaking through once the pointer advances back to it. + expect(persistor.shouldIgnoreChunk(textContent('late-from-run-1'))).toBe( + true, + ) + }) + }) +}) diff --git a/packages/ai-client/tests/test-utils.ts b/packages/ai-client/tests/test-utils.ts index dc515d43e..82d37ddc2 100644 --- a/packages/ai-client/tests/test-utils.ts +++ b/packages/ai-client/tests/test-utils.ts @@ -1,6 +1,33 @@ +import { vi } from 'vitest' import type { ConnectConnectionAdapter } from '../src/connection-adapters' import type { ModelMessage, StreamChunk } from '@tanstack/ai/client' -import type { UIMessage } from '../src/types' +import type { ChatClientPersistence, UIMessage } from '../src/types' + +/** + * Build a minimal text {@link UIMessage} for tests. + */ +export function createUIMessage( + id: string, + text: string = 'hello', + role: UIMessage['role'] = 'user', +): UIMessage { + return { id, role, parts: [{ type: 'text', content: text }] } +} + +/** + * Create a persistence adapter whose three methods are vitest spies. `getItem` + * synchronously returns `initial` (defaults to `undefined`); override individual + * methods via the returned object's `.mock*` helpers for async/error scenarios. + */ +export function createMockPersistence( + initial?: Array | null, +): ChatClientPersistence { + return { + getItem: vi.fn(() => initial), + setItem: vi.fn(), + removeItem: vi.fn(), + } +} /** * Options for creating a mock connection adapter */ diff --git a/packages/ai-preact/src/use-chat.ts b/packages/ai-preact/src/use-chat.ts index 6a9f2f8ce..cd5a38a1d 100644 --- a/packages/ai-preact/src/use-chat.ts +++ b/packages/ai-preact/src/use-chat.ts @@ -45,6 +45,10 @@ export function useChat< options.initialMessages || [], ) const isFirstMountRef = useRef(true) + const activeClientRef = useRef(null) + const cleanupInvalidationRef = useRef | null>( + null, + ) const optionsRef = useRef>(options) optionsRef.current = options @@ -54,11 +58,7 @@ export function useChat< }, [messages]) const client = useMemo(() => { - // On first mount, use initialMessages. On subsequent recreations, preserve existing messages. - const messagesToUse = isFirstMountRef.current - ? options.initialMessages || [] - : messagesRef.current - + const messagesToUse = options.initialMessages || [] isFirstMountRef.current = false // Build options with conditional spreads for fields whose source @@ -70,7 +70,7 @@ export function useChat< ? { connection: initialOptions.connection } : { fetcher: initialOptions.fetcher } - return new ChatClient({ + const instance = new ChatClient({ devtoolsBridgeFactory: createChatDevtoolsBridge, ...transport, id: clientId, @@ -79,6 +79,9 @@ export function useChat< ...(initialOptions.forwardedProps !== undefined && { forwardedProps: initialOptions.forwardedProps, }), + ...(initialOptions.persistence !== undefined && { + persistence: initialOptions.persistence, + }), ...(initialOptions.context !== undefined && { context: initialOptions.context, }), @@ -91,16 +94,26 @@ export function useChat< // Wrap every callback so the latest options are read at call time. // Capturing the function reference directly would freeze it to whatever // the parent passed on the first render. - onResponse: (response) => optionsRef.current.onResponse?.(response), - onChunk: (chunk) => optionsRef.current.onChunk?.(chunk), + onResponse: (response) => { + if (activeClientRef.current !== instance) return + return optionsRef.current.onResponse?.(response) + }, + onChunk: (chunk) => { + if (activeClientRef.current !== instance) return + optionsRef.current.onChunk?.(chunk) + }, onFinish: (message) => { + if (activeClientRef.current !== instance) return optionsRef.current.onFinish?.(message) }, onError: (err) => { + if (activeClientRef.current !== instance) return optionsRef.current.onError?.(err) }, - onCustomEvent: (eventType, data, context) => - optionsRef.current.onCustomEvent?.(eventType, data, context), + onCustomEvent: (eventType, data, context) => { + if (activeClientRef.current !== instance) return + optionsRef.current.onCustomEvent?.(eventType, data, context) + }, ...(initialOptions.tools !== undefined && { tools: initialOptions.tools, }), @@ -108,29 +121,45 @@ export function useChat< streamProcessor: options.streamProcessor, }), onMessagesChange: (newMessages: Array>) => { + if (activeClientRef.current !== instance) return setMessages(newMessages) }, onLoadingChange: (newIsLoading: boolean) => { + if (activeClientRef.current !== instance) return setIsLoading(newIsLoading) }, onStatusChange: (newStatus: ChatClientState) => { + if (activeClientRef.current !== instance) return setStatus(newStatus) }, onErrorChange: (newError: Error | undefined) => { + if (activeClientRef.current !== instance) return setError(newError) }, onSubscriptionChange: (nextIsSubscribed: boolean) => { + if (activeClientRef.current !== instance) return setIsSubscribed(nextIsSubscribed) }, onConnectionStatusChange: (nextStatus: ConnectionStatus) => { + if (activeClientRef.current !== instance) return setConnectionStatus(nextStatus) }, onSessionGeneratingChange: (isGenerating: boolean) => { + if (activeClientRef.current !== instance) return setSessionGenerating(isGenerating) }, }) + activeClientRef.current = instance + return instance }, [clientId]) + useEffect(() => { + const clientMessages = client.getMessages() + if (clientMessages !== messagesRef.current) { + setMessages(clientMessages) + } + }, [client]) + // Sync body / forwardedProps changes to the client. // Both populate the same wire payload; `forwardedProps` is preferred // and `body` is deprecated but still supported. @@ -146,19 +175,6 @@ export function useChat< }) }, [client, options.body, options.forwardedProps, options.context]) - // Sync initial messages on mount only - // Note: initialMessages are passed to ChatClient constructor, but we also - // set them here to ensure Preact state is in sync - useEffect(() => { - if ( - options.initialMessages && - options.initialMessages.length && - !messages.length - ) { - client.setMessagesManually(options.initialMessages) - } - }, []) - useEffect(() => { if (options.live) { client.subscribe() @@ -172,9 +188,20 @@ export function useChat< // DO NOT include isLoading in dependencies - that would cause the cleanup // to run when isLoading changes, aborting continuation requests. useEffect(() => { + if (cleanupInvalidationRef.current) { + clearTimeout(cleanupInvalidationRef.current) + cleanupInvalidationRef.current = null + } + activeClientRef.current = client client.mountDevtools() return () => { + cleanupInvalidationRef.current = setTimeout(() => { + if (activeClientRef.current === client) { + activeClientRef.current = null + } + cleanupInvalidationRef.current = null + }, 0) // Subscribe/unsubscribe on `options.live` is owned by the dedicated // effect above. This cleanup only fires on unmount or client swap, // so read `live` through the ref to avoid disposing the client every @@ -243,8 +270,10 @@ export function useChat< [client], ) + const renderedMessages = client.getMessages() + return { - messages, + messages: renderedMessages, sendMessage, append, reload, diff --git a/packages/ai-preact/tests/use-chat.test.ts b/packages/ai-preact/tests/use-chat.test.ts index e2e2b8f24..11e874c41 100644 --- a/packages/ai-preact/tests/use-chat.test.ts +++ b/packages/ai-preact/tests/use-chat.test.ts @@ -1,7 +1,12 @@ -import type { ModelMessage } from '@tanstack/ai' -import { act, waitFor } from '@testing-library/preact' +import type { ModelMessage, StreamChunk } from '@tanstack/ai' +import { EventType } from '@tanstack/ai' +import type { SubscribeConnectionAdapter } from '@tanstack/ai-client' +import { act, renderHook, waitFor } from '@testing-library/preact' +import { StrictMode } from 'preact/compat' +import { useState } from 'preact/hooks' import { describe, expect, it, vi } from 'vitest' import type { UIMessage } from '../src/types' +import { useChat } from '../src/use-chat' import { createMockConnectionAdapter, createTextChunks, @@ -10,6 +15,14 @@ import { } from './test-utils' describe('useChat', () => { + function createDeferred() { + let resolve!: (value: T) => void + const promise = new Promise((promiseResolve) => { + resolve = promiseResolve + }) + return { promise, resolve } + } + describe('initialization', () => { it('should initialize with default state', () => { const adapter = createMockConnectionAdapter() @@ -55,6 +68,118 @@ describe('useChat', () => { expect(result.current.messages).toEqual(initialMessages) }) + it('should initialize with persisted messages', async () => { + const adapter = createMockConnectionAdapter() + const persistedMessages: Array = [ + { + id: 'persisted-1', + role: 'user', + parts: [{ type: 'text', content: 'Persisted' }], + createdAt: new Date(), + }, + ] + const persistence = { + getItem: vi.fn(() => persistedMessages), + setItem: vi.fn(), + removeItem: vi.fn(), + } + + const { result } = renderUseChat({ + connection: adapter, + id: 'persisted-chat', + persistence, + }) + + await waitFor(() => { + expect(result.current.messages).toEqual(persistedMessages) + }) + expect(persistence.getItem).toHaveBeenCalledWith('persisted-chat') + }) + + it('should preserve persisted empty messages over provided initial messages', async () => { + const adapter = createMockConnectionAdapter() + const initialMessages: Array = [ + { + id: 'initial-1', + role: 'user', + parts: [{ type: 'text', content: 'Initial' }], + createdAt: new Date(), + }, + ] + const persistence = { + getItem: vi.fn(() => []), + setItem: vi.fn(), + removeItem: vi.fn(), + } + + const { result } = renderUseChat({ + connection: adapter, + id: 'persisted-empty-chat', + initialMessages, + persistence, + }) + + await waitFor(() => { + expect(persistence.getItem).toHaveBeenCalledWith('persisted-empty-chat') + }) + expect(result.current.messages).toEqual([]) + }) + + it('should ignore async persisted messages from a previous id', async () => { + const oldHydration = createDeferred>() + const oldMessages: Array = [ + { + id: 'old-persisted', + role: 'user', + parts: [{ type: 'text', content: 'Old persisted' }], + createdAt: new Date(), + }, + ] + const newMessages: Array = [ + { + id: 'new-persisted', + role: 'user', + parts: [{ type: 'text', content: 'New persisted' }], + createdAt: new Date(), + }, + ] + const persistence = { + getItem: vi.fn((id: string) => + id === 'old-chat' ? oldHydration.promise : newMessages, + ), + setItem: vi.fn(), + removeItem: vi.fn(), + } + + function useChangingChat() { + const [id, setId] = useState('old-chat') + const chat = useChat({ + connection: createMockConnectionAdapter(), + id, + persistence, + }) + + return { ...chat, setId } + } + + const { result } = renderHook(() => useChangingChat()) + + act(() => { + result.current.setId('new-chat') + }) + + await waitFor(() => { + expect(result.current.messages).toEqual(newMessages) + }) + + await act(async () => { + oldHydration.resolve(oldMessages) + await oldHydration.promise + }) + + expect(result.current.messages).toEqual(newMessages) + }) + it('should use provided id', async () => { const chunks = createTextChunks('Response') const adapter = createMockConnectionAdapter({ chunks }) @@ -876,6 +1001,78 @@ describe('useChat', () => { }) describe('edge cases and error handling', () => { + describe('callbacks', () => { + it('should ignore user callbacks from an old client after id changes', async () => { + const releaseOldStream = createDeferred() + const oldOnChunk = vi.fn() + const newOnChunk = vi.fn() + const adapter = { + async *connect(): AsyncIterable { + await releaseOldStream.promise + yield* createTextChunks('stale old client') + }, + } + + const { result, rerender } = renderHook( + (opts: { id: string; onChunk: (chunk: StreamChunk) => void }) => + useChat({ + connection: adapter, + id: opts.id, + onChunk: opts.onChunk, + }), + { + initialProps: { + id: 'old-client', + onChunk: oldOnChunk, + }, + }, + ) + + let sendPromise: Promise + act(() => { + sendPromise = result.current.sendMessage('Test') + }) + await waitFor(() => { + expect(result.current.isLoading).toBe(true) + }) + + rerender({ + id: 'new-client', + onChunk: newOnChunk, + }) + + releaseOldStream.resolve() + await sendPromise! + + expect(oldOnChunk).not.toHaveBeenCalled() + expect(newOnChunk).not.toHaveBeenCalled() + }) + + it('should keep callbacks live across StrictMode effect replay for the same client', async () => { + const onChunk = vi.fn() + const adapter = createMockConnectionAdapter({ + chunks: createTextChunks('strict response'), + }) + + const { result } = renderHook( + () => + useChat({ + connection: adapter, + onChunk, + }), + { wrapper: StrictMode }, + ) + + await act(async () => { + await result.current.sendMessage('Test') + }) + + expect(onChunk).toHaveBeenCalledWith( + expect.objectContaining({ type: EventType.TEXT_MESSAGE_CONTENT }), + ) + }) + }) + describe('options changes', () => { it('should maintain client instance when options change', () => { const adapter1 = createMockConnectionAdapter() @@ -926,6 +1123,93 @@ describe('useChat', () => { }) }) + describe('client recreation', () => { + it('should not pass previous id messages to a new client id without persisted messages', async () => { + const connectSpy = vi.fn() + const adapter = createMockConnectionAdapter({ + chunks: createTextChunks('Reply'), + onConnect: connectSpy, + }) + + const { result } = renderHook(() => { + const [id, setId] = useState('client-A') + const chat = useChat({ connection: adapter, id }) + return { ...chat, switchId: setId } + }) + + const messages: Array = [ + { + id: 'msg-1', + role: 'user', + parts: [{ type: 'text', content: 'Hello' }], + createdAt: new Date(), + }, + { + id: 'msg-2', + role: 'assistant', + parts: [{ type: 'text', content: 'Hi there!' }], + createdAt: new Date(), + }, + ] + + act(() => { + result.current.setMessages(messages) + result.current.switchId('client-B') + }) + + await act(async () => { + await result.current.sendMessage('Follow-up') + }) + + await waitFor(() => { + expect(connectSpy).toHaveBeenCalled() + }) + + const sentMessages = connectSpy.mock.calls[0]![0] as Array< + ModelMessage | UIMessage + > + const sentText = sentMessages + .flatMap((message) => ('parts' in message ? message.parts : [])) + .filter((part) => part.type === 'text') + .map((part) => part.content) + .join('') + + expect(sentText).toContain('Follow-up') + expect(sentText).not.toContain('Hello') + expect(result.current.messages).not.toEqual(messages) + }) + + it('should return new client messages during the id change render', () => { + const adapter = createMockConnectionAdapter() + const oldMessages: Array = [ + { + id: 'old-message', + role: 'user', + parts: [{ type: 'text', content: 'Old client message' }], + createdAt: new Date(), + }, + ] + + const { result } = renderHook(() => { + const [id, setId] = useState('client-A') + const chat = useChat({ connection: adapter, id }) + return { ...chat, switchId: setId } + }) + + act(() => { + result.current.setMessages(oldMessages) + }) + + expect(result.current.messages).toEqual(oldMessages) + + act(() => { + result.current.switchId('client-B') + }) + + expect(result.current.messages).toEqual([]) + }) + }) + describe('unmount behavior', () => { it('should not update state after unmount', async () => { const chunks = createTextChunks('Response') @@ -1374,5 +1658,92 @@ describe('useChat', () => { } }) }) + + describe('sessionGenerating', () => { + it('should keep receiving live updates and callbacks after live toggles for the same client', async () => { + const onChunk = vi.fn() + const chunks: Array = [] + // Wrapped in an object so the assignment inside the generator closure + // is visible to TS control-flow analysis. A bare `let` would narrow to + // `null` at the call site, and `?.()` would then strip it to `never`. + const subscriberControl: { wake: (() => void) | null } = { wake: null } + const adapter: SubscribeConnectionAdapter = { + subscribe: async function* (_signal?: AbortSignal) { + while (true) { + if (chunks.length === 0) { + await new Promise((resolve) => { + subscriberControl.wake = resolve + }) + } + const chunk = chunks.shift() + if (chunk) yield chunk + } + }, + send: vi.fn(async () => {}), + } + + const { result, rerender } = renderHook( + ({ live }) => + useChat({ + connection: adapter, + live, + onChunk, + }), + { initialProps: { live: true } }, + ) + + await waitFor(() => { + expect(result.current.isSubscribed).toBe(true) + }) + + rerender({ live: false }) + await waitFor(() => { + expect(result.current.isSubscribed).toBe(false) + }) + + rerender({ live: true }) + await waitFor(() => { + expect(result.current.isSubscribed).toBe(true) + }) + + chunks.push( + { + type: EventType.RUN_STARTED, + runId: 'run-after-toggle', + threadId: 'thread-1', + timestamp: Date.now(), + }, + { + type: EventType.TEXT_MESSAGE_CONTENT, + messageId: 'msg-after-toggle', + timestamp: Date.now(), + delta: 'after toggle', + content: 'after toggle', + }, + { + type: EventType.RUN_FINISHED, + runId: 'run-after-toggle', + threadId: 'thread-1', + timestamp: Date.now(), + }, + ) + subscriberControl.wake?.() + subscriberControl.wake = null + + await waitFor(() => { + expect( + result.current.messages.some((message) => + message.parts.some( + (part) => + part.type === 'text' && part.content === 'after toggle', + ), + ), + ).toBe(true) + }) + expect(onChunk).toHaveBeenCalledWith( + expect.objectContaining({ type: EventType.TEXT_MESSAGE_CONTENT }), + ) + }) + }) }) }) diff --git a/packages/ai-react/src/use-chat.ts b/packages/ai-react/src/use-chat.ts index 733ee86f0..2c652bc51 100644 --- a/packages/ai-react/src/use-chat.ts +++ b/packages/ai-react/src/use-chat.ts @@ -52,6 +52,10 @@ export function useChat< options.initialMessages || [], ) const isFirstMountRef = useRef(true) + const activeClientRef = useRef(null) + const cleanupInvalidationRef = useRef | null>( + null, + ) // Update ref synchronously during render so it's always current when useMemo runs. messagesRef.current = messages @@ -62,10 +66,7 @@ export function useChat< // Create ChatClient instance with callbacks to sync state const client = useMemo(() => { - const messagesToUse = isFirstMountRef.current - ? options.initialMessages || [] - : messagesRef.current - + const messagesToUse = options.initialMessages || [] isFirstMountRef.current = false // Build options with conditional spreads for fields whose source @@ -77,7 +78,7 @@ export function useChat< ? { connection: initialOptions.connection } : { fetcher: initialOptions.fetcher } - return new ChatClient({ + const instance = new ChatClient({ devtoolsBridgeFactory: createChatDevtoolsBridge, ...transport, id: clientId, @@ -86,6 +87,9 @@ export function useChat< ...(initialOptions.forwardedProps !== undefined && { forwardedProps: initialOptions.forwardedProps, }), + ...(initialOptions.persistence !== undefined && { + persistence: initialOptions.persistence, + }), ...(initialOptions.context !== undefined && { context: initialOptions.context, }), @@ -96,50 +100,71 @@ export function useChat< outputKind: initialOptions.outputSchema ? 'structured' : 'chat', }, onResponse: (response) => { + if (activeClientRef.current !== instance) return void optionsRef.current.onResponse?.(response) }, onChunk: (chunk: StreamChunk) => { + if (activeClientRef.current !== instance) return optionsRef.current.onChunk?.(chunk) }, onFinish: (message: UIMessage) => { + if (activeClientRef.current !== instance) return optionsRef.current.onFinish?.(message) }, onError: (error: Error) => { + if (activeClientRef.current !== instance) return optionsRef.current.onError?.(error) }, ...(initialOptions.tools !== undefined && { tools: initialOptions.tools, }), onCustomEvent: (eventType, data, context) => { + if (activeClientRef.current !== instance) return optionsRef.current.onCustomEvent?.(eventType, data, context) }, ...(options.streamProcessor !== undefined && { streamProcessor: options.streamProcessor, }), onMessagesChange: (newMessages: Array>) => { + if (activeClientRef.current !== instance) return setMessages(newMessages) }, onLoadingChange: (newIsLoading: boolean) => { + if (activeClientRef.current !== instance) return setIsLoading(newIsLoading) }, onErrorChange: (newError: Error | undefined) => { + if (activeClientRef.current !== instance) return setError(newError) }, onStatusChange: (status: ChatClientState) => { + if (activeClientRef.current !== instance) return setStatus(status) }, onSubscriptionChange: (nextIsSubscribed: boolean) => { + if (activeClientRef.current !== instance) return setIsSubscribed(nextIsSubscribed) }, onConnectionStatusChange: (nextStatus: ConnectionStatus) => { + if (activeClientRef.current !== instance) return setConnectionStatus(nextStatus) }, onSessionGeneratingChange: (isGenerating: boolean) => { + if (activeClientRef.current !== instance) return setSessionGenerating(isGenerating) }, }) + activeClientRef.current = instance + return instance }, [clientId]) + useEffect(() => { + const clientMessages = client.getMessages() + if (clientMessages !== messagesRef.current) { + setMessages(clientMessages) + } + }, [client]) + useEffect(() => { // Conditional spread: `updateOptions` declares strict-optional // fields and rejects explicit `undefined` under EOPT. @@ -152,14 +177,6 @@ export function useChat< }) }, [client, options.body, options.forwardedProps, options.context]) - useEffect(() => { - if (options.initialMessages && options.initialMessages.length > 0) { - if (messages.length === 0) { - client.setMessagesManually(options.initialMessages) - } - } - }, []) - useEffect(() => { if (options.live) { client.subscribe() @@ -169,9 +186,20 @@ export function useChat< }, [client, options.live]) useEffect(() => { + if (cleanupInvalidationRef.current) { + clearTimeout(cleanupInvalidationRef.current) + cleanupInvalidationRef.current = null + } + activeClientRef.current = client client.mountDevtools() return () => { + cleanupInvalidationRef.current = setTimeout(() => { + if (activeClientRef.current === client) { + activeClientRef.current = null + } + cleanupInvalidationRef.current = null + }, 0) // Subscribe/unsubscribe on `options.live` is owned by the dedicated // effect above. This cleanup only fires on unmount or client swap, // so read `live` through the ref to avoid disposing the client every @@ -248,17 +276,19 @@ export function useChat< // a stale assistant turn or a system prompt) we deliberately return null // rather than scanning historical assistants — otherwise a `final` from a // previous session would leak into the hook value on first render. + const renderedMessages = client.getMessages() + const activeStructuredPart = useMemo(() => { let lastUserIndex = -1 - for (let i = messages.length - 1; i >= 0; i--) { - if (messages[i]?.role === 'user') { + for (let i = renderedMessages.length - 1; i >= 0; i--) { + if (renderedMessages[i]?.role === 'user') { lastUserIndex = i break } } if (lastUserIndex === -1) return null - for (let i = messages.length - 1; i > lastUserIndex; i--) { - const m = messages[i] + for (let i = renderedMessages.length - 1; i > lastUserIndex; i--) { + const m = renderedMessages[i] if (m?.role !== 'assistant') continue const part = m.parts.find( (p): p is StructuredOutputPart => p.type === 'structured-output', @@ -266,7 +296,7 @@ export function useChat< if (part) return part } return null - }, [messages]) + }, [renderedMessages]) const partial = useMemo(() => { if (!activeStructuredPart) return {} as Partial @@ -286,7 +316,7 @@ export function useChat< // structurally narrow across that conditional, so the `as` is the seam. // eslint-disable-next-line no-restricted-syntax -- hook return shape diverges from generic UseChatReturn due to conditional type on TSchema; TS can't structurally narrow return { - messages, + messages: renderedMessages, sendMessage, append, reload, diff --git a/packages/ai-react/tests/use-chat.test.ts b/packages/ai-react/tests/use-chat.test.ts index 687e3b84b..8e43eedc1 100644 --- a/packages/ai-react/tests/use-chat.test.ts +++ b/packages/ai-react/tests/use-chat.test.ts @@ -2,7 +2,7 @@ import type { ModelMessage, StreamChunk } from '@tanstack/ai' import { EventType } from '@tanstack/ai' import type { SubscribeConnectionAdapter } from '@tanstack/ai-client' import { act, renderHook, waitFor } from '@testing-library/react' -import { useState } from 'react' +import { StrictMode, useState } from 'react' import { describe, expect, it, vi } from 'vitest' import type { UIMessage, UseChatOptions } from '../src/types' import { useChat } from '../src/use-chat' @@ -14,6 +14,14 @@ import { } from './test-utils' describe('useChat', () => { + function createDeferred() { + let resolve!: (value: T) => void + const promise = new Promise((promiseResolve) => { + resolve = promiseResolve + }) + return { promise, resolve } + } + describe('initialization', () => { it('should initialize with default state', () => { const adapter = createMockConnectionAdapter() @@ -59,6 +67,118 @@ describe('useChat', () => { expect(result.current.messages).toEqual(initialMessages) }) + it('should initialize with persisted messages', async () => { + const adapter = createMockConnectionAdapter() + const persistedMessages: UIMessage[] = [ + { + id: 'persisted-1', + role: 'user', + parts: [{ type: 'text', content: 'Persisted' }], + createdAt: new Date(), + }, + ] + const persistence = { + getItem: vi.fn(() => persistedMessages), + setItem: vi.fn(), + removeItem: vi.fn(), + } + + const { result } = renderUseChat({ + connection: adapter, + id: 'persisted-chat', + persistence, + }) + + await waitFor(() => { + expect(result.current.messages).toEqual(persistedMessages) + }) + expect(persistence.getItem).toHaveBeenCalledWith('persisted-chat') + }) + + it('should preserve persisted empty messages over provided initial messages', async () => { + const adapter = createMockConnectionAdapter() + const initialMessages: UIMessage[] = [ + { + id: 'initial-1', + role: 'user', + parts: [{ type: 'text', content: 'Initial' }], + createdAt: new Date(), + }, + ] + const persistence = { + getItem: vi.fn(() => []), + setItem: vi.fn(), + removeItem: vi.fn(), + } + + const { result } = renderUseChat({ + connection: adapter, + id: 'persisted-empty-chat', + initialMessages, + persistence, + }) + + await waitFor(() => { + expect(persistence.getItem).toHaveBeenCalledWith('persisted-empty-chat') + }) + expect(result.current.messages).toEqual([]) + }) + + it('should ignore async persisted messages from a previous id', async () => { + const oldHydration = createDeferred>() + const oldMessages: Array = [ + { + id: 'old-persisted', + role: 'user', + parts: [{ type: 'text', content: 'Old persisted' }], + createdAt: new Date(), + }, + ] + const newMessages: Array = [ + { + id: 'new-persisted', + role: 'user', + parts: [{ type: 'text', content: 'New persisted' }], + createdAt: new Date(), + }, + ] + const persistence = { + getItem: vi.fn((id: string) => + id === 'old-chat' ? oldHydration.promise : newMessages, + ), + setItem: vi.fn(), + removeItem: vi.fn(), + } + + function useChangingChat() { + const [id, setId] = useState('old-chat') + const chat = useChat({ + connection: createMockConnectionAdapter(), + id, + persistence, + }) + + return { ...chat, setId } + } + + const { result } = renderHook(() => useChangingChat()) + + act(() => { + result.current.setId('new-chat') + }) + + await waitFor(() => { + expect(result.current.messages).toEqual(newMessages) + }) + + await act(async () => { + oldHydration.resolve(oldMessages) + await oldHydration.promise + }) + + expect(result.current.messages).toEqual(newMessages) + }) + it('should use provided id', async () => { const chunks = createTextChunks('Response') const adapter = createMockConnectionAdapter({ chunks }) @@ -812,6 +932,70 @@ describe('useChat', () => { }) expect(first).not.toHaveBeenCalled() }) + + it('should ignore user callbacks from an old client after id changes', async () => { + const releaseOldStream = createDeferred() + const oldOnChunk = vi.fn() + const newOnChunk = vi.fn() + const adapter = { + async *connect() { + await releaseOldStream.promise + yield* createTextChunks('stale old client') + }, + } + + const { result, rerender } = renderHook( + (opts: UseChatOptions) => useChat(opts), + { + initialProps: { + connection: adapter, + id: 'old-client', + onChunk: oldOnChunk, + }, + }, + ) + + const sendPromise = result.current.sendMessage('Test') + await waitFor(() => { + expect(result.current.isLoading).toBe(true) + }) + + rerender({ + connection: adapter, + id: 'new-client', + onChunk: newOnChunk, + }) + + releaseOldStream.resolve() + await sendPromise + + expect(oldOnChunk).not.toHaveBeenCalled() + expect(newOnChunk).not.toHaveBeenCalled() + }) + + it('should keep callbacks live across StrictMode effect replay for the same client', async () => { + const onChunk = vi.fn() + const adapter = createMockConnectionAdapter({ + chunks: createTextChunks('strict response'), + }) + + const { result } = renderHook( + () => + useChat({ + connection: adapter, + onChunk, + }), + { wrapper: StrictMode }, + ) + + await act(async () => { + await result.current.sendMessage('Test') + }) + + expect(onChunk).toHaveBeenCalledWith( + expect.objectContaining({ type: EventType.TEXT_MESSAGE_CONTENT }), + ) + }) }) describe('edge cases and error handling', () => { @@ -866,7 +1050,7 @@ describe('useChat', () => { }) describe('client recreation', () => { - it('should pass existing messages to new client when id changes in a batched update', async () => { + it('should not pass previous id messages to a new client id without persisted messages', async () => { const connectSpy = vi.fn() const chunks = createTextChunks('Reply') const adapter = createMockConnectionAdapter({ @@ -906,8 +1090,6 @@ describe('useChat', () => { result.current.switchId('client-B') }) - // Send a message through the new client. If the client lost the - // previous messages, the adapter only receives the new message. await act(async () => { await result.current.sendMessage('Follow-up') }) @@ -916,14 +1098,48 @@ describe('useChat', () => { expect(connectSpy).toHaveBeenCalled() }) - // The adapter should receive the previous conversation + new message. - // `connectSpy.mock.calls[0][0]` is the messages array passed to - // adapter.connect — typed via MockConnectionAdapterOptions.onConnect. const sentMessages = connectSpy.mock.calls[0]![0] as Array< ModelMessage | UIMessage > - const userMessages = sentMessages.filter((m) => m.role === 'user') - expect(userMessages.length).toBeGreaterThanOrEqual(2) + const sentText = sentMessages + .flatMap((message) => ('parts' in message ? message.parts : [])) + .filter((part) => part.type === 'text') + .map((part) => part.content) + .join('') + + expect(sentText).toContain('Follow-up') + expect(sentText).not.toContain('Hello') + expect(result.current.messages).not.toEqual(messages) + }) + + it('should return new client messages during the id change render', () => { + const adapter = createMockConnectionAdapter() + const oldMessages: Array = [ + { + id: 'old-message', + role: 'user', + parts: [{ type: 'text', content: 'Old client message' }], + createdAt: new Date(), + }, + ] + + const { result } = renderHook(() => { + const [id, setId] = useState('client-A') + const chat = useChat({ connection: adapter, id }) + return { ...chat, switchId: setId } + }) + + act(() => { + result.current.setMessages(oldMessages) + }) + + expect(result.current.messages).toEqual(oldMessages) + + act(() => { + result.current.switchId('client-B') + }) + + expect(result.current.messages).toEqual([]) }) }) @@ -1716,5 +1932,89 @@ describe('useChat', () => { expect(result.current.isLoading).toBe(false) }) }) + + it('should keep receiving live updates and callbacks after live toggles for the same client', async () => { + const onChunk = vi.fn() + const chunks: Array = [] + // Wrapped in an object so the assignment inside the generator closure + // is visible to TS control-flow analysis. A bare `let` would narrow to + // `null` at the call site, and `?.()` would then strip it to `never`. + const subscriberControl: { wake: (() => void) | null } = { wake: null } + const adapter: SubscribeConnectionAdapter = { + subscribe: async function* (_signal?: AbortSignal) { + while (true) { + if (chunks.length === 0) { + await new Promise((resolve) => { + subscriberControl.wake = resolve + }) + } + const chunk = chunks.shift() + if (chunk) yield chunk + } + }, + send: vi.fn(async () => {}), + } + + const { result, rerender } = renderHook( + ({ live }) => + useChat({ + connection: adapter, + live, + onChunk, + }), + { initialProps: { live: true } }, + ) + + await waitFor(() => { + expect(result.current.isSubscribed).toBe(true) + }) + + rerender({ live: false }) + await waitFor(() => { + expect(result.current.isSubscribed).toBe(false) + }) + + rerender({ live: true }) + await waitFor(() => { + expect(result.current.isSubscribed).toBe(true) + }) + + chunks.push( + { + type: EventType.RUN_STARTED, + runId: 'run-after-toggle', + threadId: 'thread-1', + timestamp: Date.now(), + }, + { + type: EventType.TEXT_MESSAGE_CONTENT, + messageId: 'msg-after-toggle', + timestamp: Date.now(), + delta: 'after toggle', + content: 'after toggle', + }, + { + type: EventType.RUN_FINISHED, + runId: 'run-after-toggle', + threadId: 'thread-1', + timestamp: Date.now(), + }, + ) + subscriberControl.wake?.() + subscriberControl.wake = null + + await waitFor(() => { + expect( + result.current.messages.some((message) => + message.parts.some( + (part) => part.type === 'text' && part.content === 'after toggle', + ), + ), + ).toBe(true) + }) + expect(onChunk).toHaveBeenCalledWith( + expect.objectContaining({ type: EventType.TEXT_MESSAGE_CONTENT }), + ) + }) }) }) diff --git a/packages/ai-solid/src/use-chat.ts b/packages/ai-solid/src/use-chat.ts index 992317e01..14937e792 100644 --- a/packages/ai-solid/src/use-chat.ts +++ b/packages/ai-solid/src/use-chat.ts @@ -83,6 +83,9 @@ export function useChat< ...(options.initialMessages !== undefined && { initialMessages: options.initialMessages, }), + ...(options.persistence !== undefined && { + persistence: options.persistence, + }), body: options.body, ...(options.forwardedProps !== undefined && { forwardedProps: options.forwardedProps, @@ -136,6 +139,8 @@ export function useChat< // Connection and other options are captured at creation time }, [clientId]) + setMessages(client().getMessages()) + // Sync body / forwardedProps changes to the client. // Both populate the same wire payload; `forwardedProps` is preferred // and `body` is deprecated but still supported. @@ -151,18 +156,6 @@ export function useChat< }) }) - // Sync initial messages on mount only - // Note: initialMessages are passed to ChatClient constructor, but we also - // set them here to ensure React state is in sync - createEffect(() => { - if (options.initialMessages && options.initialMessages.length > 0) { - // Only set if current messages are empty (initial state) - if (messages().length === 0) { - client().setMessagesManually(options.initialMessages) - } - } - }) // Only run on mount - initialMessages are handled by ChatClient constructor - // Apply initial live mode immediately on hook creation. if (options.live) { client().subscribe() diff --git a/packages/ai-solid/tests/use-chat.test.ts b/packages/ai-solid/tests/use-chat.test.ts index f6afaef4e..31ba2f26f 100644 --- a/packages/ai-solid/tests/use-chat.test.ts +++ b/packages/ai-solid/tests/use-chat.test.ts @@ -55,6 +55,63 @@ describe('useChat', () => { expect(result.current.messages).toEqual(initialMessages) }) + it('should initialize with persisted messages', async () => { + const adapter = createMockConnectionAdapter() + const persistedMessages: UIMessage[] = [ + { + id: 'persisted-1', + role: 'user', + parts: [{ type: 'text', content: 'Persisted' }], + createdAt: new Date(), + }, + ] + const persistence = { + getItem: vi.fn(() => persistedMessages), + setItem: vi.fn(), + removeItem: vi.fn(), + } + + const { result } = renderUseChat({ + connection: adapter, + id: 'persisted-chat', + persistence, + }) + + await waitFor(() => { + expect(result.current.messages).toEqual(persistedMessages) + }) + expect(persistence.getItem).toHaveBeenCalledWith('persisted-chat') + }) + + it('should preserve persisted empty messages over provided initial messages', async () => { + const adapter = createMockConnectionAdapter() + const initialMessages: UIMessage[] = [ + { + id: 'initial-1', + role: 'user', + parts: [{ type: 'text', content: 'Initial' }], + createdAt: new Date(), + }, + ] + const persistence = { + getItem: vi.fn(() => []), + setItem: vi.fn(), + removeItem: vi.fn(), + } + + const { result } = renderUseChat({ + connection: adapter, + id: 'persisted-empty-chat', + initialMessages, + persistence, + }) + + await waitFor(() => { + expect(persistence.getItem).toHaveBeenCalledWith('persisted-empty-chat') + }) + expect(result.current.messages).toEqual([]) + }) + it('should use provided id', async () => { const chunks = createTextChunks('Response') const adapter = createMockConnectionAdapter({ chunks }) diff --git a/packages/ai-svelte/src/create-chat.svelte.ts b/packages/ai-svelte/src/create-chat.svelte.ts index f7e991141..b7e4539b3 100644 --- a/packages/ai-svelte/src/create-chat.svelte.ts +++ b/packages/ai-svelte/src/create-chat.svelte.ts @@ -100,6 +100,9 @@ export function createChat< ...(options.initialMessages !== undefined && { initialMessages: options.initialMessages, }), + ...(options.persistence !== undefined && { + persistence: options.persistence, + }), ...(options.body !== undefined && { body: options.body }), ...(options.forwardedProps !== undefined && { forwardedProps: options.forwardedProps, @@ -151,6 +154,8 @@ export function createChat< }, }) + messages = client.getMessages() + if (options.live) { client.subscribe() } diff --git a/packages/ai-svelte/tests/use-chat.test.ts b/packages/ai-svelte/tests/use-chat.test.ts index 1a6c4b777..ff5b846ba 100644 --- a/packages/ai-svelte/tests/use-chat.test.ts +++ b/packages/ai-svelte/tests/use-chat.test.ts @@ -54,6 +54,59 @@ describe('createChat', () => { expect(chat.messages[0]!.role).toBe('user') }) + it('should initialize with persisted messages', () => { + const mockConnection = createMockConnectionAdapter({ chunks: [] }) + const persistedMessages = [ + { + id: 'persisted-1', + role: 'user' as const, + parts: [{ type: 'text' as const, content: 'Persisted' }], + createdAt: new Date(), + }, + ] + const persistence = { + getItem: vi.fn(() => persistedMessages), + setItem: vi.fn(), + removeItem: vi.fn(), + } + + const chat = createChat({ + connection: mockConnection, + id: 'persisted-chat', + persistence, + }) + + expect(chat.messages).toEqual(persistedMessages) + expect(persistence.getItem).toHaveBeenCalledWith('persisted-chat') + }) + + it('should let persisted empty arrays override initial messages', () => { + const mockConnection = createMockConnectionAdapter({ chunks: [] }) + const initialMessages = [ + { + id: 'initial-1', + role: 'user' as const, + parts: [{ type: 'text' as const, content: 'Initial' }], + createdAt: new Date(), + }, + ] + const persistence = { + getItem: vi.fn(() => []), + setItem: vi.fn(), + removeItem: vi.fn(), + } + + const chat = createChat({ + connection: mockConnection, + id: 'persisted-chat', + initialMessages, + persistence, + }) + + expect(chat.messages).toEqual([]) + expect(persistence.getItem).toHaveBeenCalledWith('persisted-chat') + }) + it('should have sendMessage method', () => { const mockConnection = createMockConnectionAdapter({ chunks: [] }) diff --git a/packages/ai-vue/src/use-chat.ts b/packages/ai-vue/src/use-chat.ts index a5e5301f1..2d7f0ccae 100644 --- a/packages/ai-vue/src/use-chat.ts +++ b/packages/ai-vue/src/use-chat.ts @@ -86,6 +86,9 @@ export function useChat< ...(options.initialMessages !== undefined && { initialMessages: options.initialMessages, }), + ...(options.persistence !== undefined && { + persistence: options.persistence, + }), ...(options.body !== undefined && { body: options.body }), ...(options.forwardedProps !== undefined && { forwardedProps: options.forwardedProps, @@ -136,6 +139,8 @@ export function useChat< }, }) + messages.value = client.getMessages() + // Sync body / forwardedProps changes to the client. // Both populate the same wire payload; `forwardedProps` is preferred // and `body` is deprecated but still supported. diff --git a/packages/ai-vue/tests/use-chat.test.ts b/packages/ai-vue/tests/use-chat.test.ts index 6492fb396..0e227afe5 100644 --- a/packages/ai-vue/tests/use-chat.test.ts +++ b/packages/ai-vue/tests/use-chat.test.ts @@ -54,6 +54,61 @@ describe('useChat', () => { expect(result.current.messages).toEqual(initialMessages) }) + it('should initialize with persisted messages', async () => { + const adapter = createMockConnectionAdapter() + const persistedMessages: Array = [ + { + id: 'persisted-1', + role: 'user', + parts: [{ type: 'text', content: 'Persisted' }], + createdAt: new Date(), + }, + ] + const persistence = { + getItem: vi.fn(() => persistedMessages), + setItem: vi.fn(), + removeItem: vi.fn(), + } + + const { result } = renderUseChat({ + connection: adapter, + id: 'persisted-chat', + persistence, + }) + await flushPromises() + + expect(result.current.messages).toEqual(persistedMessages) + expect(persistence.getItem).toHaveBeenCalledWith('persisted-chat') + }) + + it('should let persisted empty arrays override initial messages', async () => { + const adapter = createMockConnectionAdapter() + const initialMessages: Array = [ + { + id: 'initial-1', + role: 'user', + parts: [{ type: 'text', content: 'Initial' }], + createdAt: new Date(), + }, + ] + const persistence = { + getItem: vi.fn(() => []), + setItem: vi.fn(), + removeItem: vi.fn(), + } + + const { result } = renderUseChat({ + connection: adapter, + id: 'persisted-chat', + initialMessages, + persistence, + }) + await flushPromises() + + expect(result.current.messages).toEqual([]) + expect(persistence.getItem).toHaveBeenCalledWith('persisted-chat') + }) + it('should use provided id', async () => { const chunks = createTextChunks('Response') const adapter = createMockConnectionAdapter({ chunks }) diff --git a/testing/e2e/src/routes/$provider/$feature.tsx b/testing/e2e/src/routes/$provider/$feature.tsx index a3c4a8996..ea080c4fc 100644 --- a/testing/e2e/src/routes/$provider/$feature.tsx +++ b/testing/e2e/src/routes/$provider/$feature.tsx @@ -3,7 +3,7 @@ import { createFileRoute } from '@tanstack/react-router' import { uiMessagesToWire } from '@tanstack/ai' import { fetchServerSentEvents, useChat } from '@tanstack/ai-react' import { clientTools } from '@tanstack/ai-client' -import type { UIMessage } from '@tanstack/ai-client' +import type { ChatClientPersistence, UIMessage } from '@tanstack/ai-client' import type { GeminiInteractionsCustomEventValue } from '@tanstack/ai-gemini/experimental' import type { Feature, Mode, Provider } from '@/lib/types' import { ALL_FEATURES, ALL_PROVIDERS } from '@/lib/types' @@ -34,6 +34,8 @@ export const Route = createFileRoute('/$provider/$feature')({ rawMode && VALID_MODES.has(rawMode as Mode) ? (rawMode as Mode) : undefined, + persistence: + search.persistence === 'localStorage' ? 'localStorage' : undefined, } }, }) @@ -54,6 +56,27 @@ const addToCartClient = addToCartToolDef.client((args) => ({ quantity: args.quantity, })) +const localStoragePersistence: ChatClientPersistence = { + getItem: (id) => { + const item = window.localStorage.getItem(id) + return item + ? (JSON.parse(item) as Array).map((message) => ({ + ...message, + createdAt: + typeof message.createdAt === 'string' + ? new Date(message.createdAt) + : message.createdAt, + })) + : null + }, + setItem: (id, messages) => { + window.localStorage.setItem(id, JSON.stringify(messages)) + }, + removeItem: (id) => { + window.localStorage.removeItem(id) + }, +} + const isProvider = (s: string): s is Provider => (ALL_PROVIDERS as ReadonlyArray).includes(s) const isFeature = (s: string): s is Feature => @@ -167,7 +190,18 @@ function ChatFeature({ const tools = needsApproval ? clientTools(addToCartClient) : undefined - const { testId, aimockPort } = Route.useSearch() + const { testId, aimockPort, persistence } = Route.useSearch() + const persistenceEnabled = persistence === 'localStorage' + const baseChatId = `e2e-chat-${testId ?? `${provider}-${feature}`}` + // When persistence is on, expose a tiny thread switcher so e2e can verify that + // changing the `id` in place swaps to that id's own persisted history (the + // render-from-getMessages + activeClientRef path), keyed per thread. Start on + // thread "a" (not null) so the page loads already on a thread id — switching + // is then a pure in-place id swap with no initial null→thread transition. + const [activeThread, setActiveThread] = useState( + persistenceEnabled ? 'a' : null, + ) + const chatId = activeThread ? `${baseChatId}:${activeThread}` : baseChatId const [structuredObject, setStructuredObject] = useState(null) const [contentDeltaCount, setContentDeltaCount] = useState(0) @@ -216,34 +250,43 @@ function ChatFeature({ } : { connection: fetchServerSentEvents('/api/chat') } - const { messages, sendMessage, isLoading, addToolApprovalResponse, stop } = - useChat({ - ...transport, - tools, - body: { - provider, - feature, - testId, - aimockPort, - previousInteractionId: interactionId, - }, - onCustomEvent: (eventType, data) => { - if (eventType === 'structured-output.complete') { - const value = data as { object: unknown; raw: string } | undefined - setStructuredObject(value?.object ?? null) - } else if (eventType === 'gemini.interactionId') { - const value = data as - | GeminiInteractionsCustomEventValue<'gemini.interactionId'> - | undefined - if (value?.interactionId) setInteractionId(value.interactionId) - } - }, - onChunk: (chunk) => { - if (chunk.type === 'TEXT_MESSAGE_CONTENT') { - setContentDeltaCount((n) => n + 1) - } - }, - }) + const { + messages, + sendMessage, + isLoading, + addToolApprovalResponse, + stop, + clear, + } = useChat({ + id: chatId, + ...transport, + tools, + body: { + provider, + feature, + testId, + aimockPort, + previousInteractionId: interactionId, + }, + persistence: + persistence === 'localStorage' ? localStoragePersistence : undefined, + onCustomEvent: (eventType, data) => { + if (eventType === 'structured-output.complete') { + const value = data as { object: unknown; raw: string } | undefined + setStructuredObject(value?.object ?? null) + } else if (eventType === 'gemini.interactionId') { + const value = data as + | GeminiInteractionsCustomEventValue<'gemini.interactionId'> + | undefined + if (value?.interactionId) setInteractionId(value.interactionId) + } + }, + onChunk: (chunk) => { + if (chunk.type === 'TEXT_MESSAGE_CONTENT') { + setContentDeltaCount((n) => n + 1) + } + }, + }) return ( <> @@ -252,6 +295,27 @@ function ChatFeature({ {interactionId} )} + {persistenceEnabled && ( +
+ + + +
+ )} { + test('persists chat messages across browser reload with localStorage', async ({ + page, + testId, + aimockPort, + }) => { + await page.goto( + `${featureUrl('openai', 'chat', testId, aimockPort)}&persistence=localStorage`, + ) + + await sendMessage(page, '[chat] recommend a guitar') + await waitForResponse(page) + + await expect(page.getByTestId('user-message')).toContainText( + '[chat] recommend a guitar', + ) + await expect(page.getByTestId('assistant-message')).toContainText( + 'Fender Stratocaster', + ) + + await page.reload() + + await expect(page.getByTestId('user-message')).toContainText( + '[chat] recommend a guitar', + ) + await expect(page.getByTestId('assistant-message')).toContainText( + 'Fender Stratocaster', + ) + }) + + test('clear() removes the persisted conversation so a reload starts empty', async ({ + page, + testId, + aimockPort, + }) => { + await page.goto( + `${featureUrl('openai', 'chat', testId, aimockPort)}&persistence=localStorage`, + ) + + await sendMessage(page, '[chat] recommend a guitar') + await waitForResponse(page) + await expect(page.getByTestId('user-message')).toContainText( + '[chat] recommend a guitar', + ) + + await page.getByTestId('clear-button').click() + await expect(page.getByTestId('user-message')).toHaveCount(0) + await expect(page.getByTestId('assistant-message')).toHaveCount(0) + + // The conversation was removed from storage, not just from memory — a + // reload must not resurrect it. + await page.reload() + await expect(page.getByTestId('message-list')).toBeVisible() + await expect(page.getByTestId('user-message')).toHaveCount(0) + await expect(page.getByTestId('assistant-message')).toHaveCount(0) + }) + + test('switches per-thread history when the chat id changes in place', async ({ + page, + testId, + aimockPort, + }) => { + await page.goto( + `${featureUrl('openai', 'chat', testId, aimockPort)}&persistence=localStorage`, + ) + + // The page loads on thread A. Send a message (persisted under A's own id). + await sendMessage(page, '[chat] recommend a guitar') + await waitForResponse(page) + await expect(page.getByTestId('user-message')).toHaveCount(1) + + // Switch to thread B in place — its own (empty) history loads, proving the + // id swap doesn't leak thread A's messages into thread B. + await page.getByTestId('select-thread-b').click() + await expect(page.getByTestId('user-message')).toHaveCount(0) + await expect(page.getByTestId('assistant-message')).toHaveCount(0) + + await sendMessage(page, '[chat] recommend a guitar') + await waitForResponse(page) + await expect(page.getByTestId('user-message')).toHaveCount(1) + + // Switch back to thread A — its persisted history is restored from storage + // on the in-place swap (render-from-getMessages), exactly one message. + await page.getByTestId('select-thread-a').click() + await expect(page.getByTestId('user-message')).toHaveCount(1) + await expect(page.getByTestId('user-message')).toContainText( + '[chat] recommend a guitar', + ) + await expect(page.getByTestId('assistant-message')).toContainText( + 'Fender Stratocaster', + ) + }) +})