From e77bf503ce97d48037a9abe4abfcfa02eb6c4e14 Mon Sep 17 00:00:00 2001 From: Tony Hu Date: Wed, 3 Sep 2025 14:21:24 +0800 Subject: [PATCH 1/2] feat(agent): add initial guidance and checkpoint messages for user interactions --- .../sidepanel/utils/agent/constants.ts | 53 +++++++++++++++++ entrypoints/sidepanel/utils/agent/index.ts | 58 ++++++++++--------- entrypoints/sidepanel/utils/chat/chat.ts | 2 +- utils/error.ts | 2 +- utils/llm/models.ts | 2 +- utils/llm/tools/prompt-based/helpers.ts | 13 +++++ utils/rpc/background-fns.ts | 4 ++ 7 files changed, 105 insertions(+), 29 deletions(-) create mode 100644 entrypoints/sidepanel/utils/agent/constants.ts diff --git a/entrypoints/sidepanel/utils/agent/constants.ts b/entrypoints/sidepanel/utils/agent/constants.ts new file mode 100644 index 00000000..c5b47932 --- /dev/null +++ b/entrypoints/sidepanel/utils/agent/constants.ts @@ -0,0 +1,53 @@ +import { TagBuilder } from '@/utils/prompts/helpers' + +export const AGENT_INITIAL_GUIDANCE = new TagBuilder('initial_guidance').insertContent(`Remember to analyze what information layer this question requires: +- Layer 0 (Surface): Immediately visible information +- Layer 1 (Interactive): Information behind clicks and expansions +- Layer 2 (External): Information from other sources + +Most detailed questions require Layer 1 information. Surface information is rarely sufficient.`) + +export const AGENT_CHECKPOINT_MESSAGES: Record = { + 2: new TagBuilder('checkpoint_1').insertContent(`FIRST ACTION COMPLETE. Now assess your progress: +- Did you identify where the detailed information lives on this page? +- Can you see clickable elements that lead to the information the user needs? +- If yes to both: Your next 2-3 actions should be clicking those elements. +- If no: You may need to look elsewhere or search online. + +Remember: Surface viewing (Layer 0) is just the beginning, not the end. +Most user questions require Layer 1 information behind clicks.`), + + 4: new TagBuilder('checkpoint_3').insertContent(`THIRD ITERATION. Progress check: +- Have you moved from surface (Layer 0) to deep (Layer 1) information? +- Are you clicking on information-rich elements or just browsing around? +- Is your information getting more complete or are you wandering aimlessly? + +Critical reminder: If user asked about specific content (comments, details, analysis), +you MUST access that actual content, not just see that it exists. +Click on relevant elements that contain the information needed.`), + + 6: new TagBuilder('checkpoint_5').insertContent(`FIFTH ITERATION. Strategy review: +- List what concrete information you've learned so far +- List what you still don't know but need to know +- If gaps remain, identify exactly which elements to click next +- If information is complete, prepare to provide your final answer + +Key question: Can you fully answer the user's question with your current information? +If NO: Continue clicking on relevant elements. If YES: Provide comprehensive answer.`), + + 8: new TagBuilder('checkpoint_7').insertContent(`APPROACHING ITERATION LIMIT. Final push needed: +- You have only 3 more attempts remaining +- Focus on the most critical missing information gaps +- Click on the most promising unexplored elements that likely contain needed info +- After 2 more iterations, you must provide the best answer possible with available data + +Stop exploring only if: no more relevant clickable elements OR information is truly complete.`), +} + +export const AGENT_TOOL_CALL_RESULT_GUIDANCE = `Based on the tool results above, follow your self-assessment protocol: + +- Analyze the results before taking further action. +- Evaluate what you have learned and identify any remaining gaps. +- Proceed with the next steps based on your assessment of information completeness.` + +export const AGENT_FORCE_FINAL_ANSWER = `Answer Language: Strictly follow the LANGUAGE POLICY above.\nBased on all the information collected above, please provide a comprehensive final answer.\nDo not use any tools.` diff --git a/entrypoints/sidepanel/utils/agent/index.ts b/entrypoints/sidepanel/utils/agent/index.ts index d8957641..1fecc893 100644 --- a/entrypoints/sidepanel/utils/agent/index.ts +++ b/entrypoints/sidepanel/utils/agent/index.ts @@ -1,4 +1,4 @@ -import { CoreAssistantMessage, CoreMessage, CoreUserMessage } from 'ai' +import { CoreAssistantMessage, CoreMessage, CoreUserMessage, FilePart, ImagePart, TextPart } from 'ai' import { isEqual } from 'es-toolkit' import EventEmitter from 'events' import { Ref, ref } from 'vue' @@ -8,17 +8,17 @@ import { AssistantMessageV1 } from '@/types/chat' import { PromiseOr } from '@/types/common' import { Base64ImageData, ImageDataWithId } from '@/types/image' import { TagBuilderJSON } from '@/types/prompt' -import { AGENT_LOOP_COUNT_REFILL_USER_PROMPT } from '@/utils/constants' import { AbortError, AiSDKError, AppError, ErrorCode, fromError, ModelNotFoundError, ModelRequestError, ParseFunctionCallError, UnknownError } from '@/utils/error' import { useGlobalI18n } from '@/utils/i18n' import { generateRandomId } from '@/utils/id' import { InferredParams } from '@/utils/llm/tools/prompt-based/helpers' import { GetPromptBasedTool, PromptBasedToolName, PromptBasedToolNameAndParams } from '@/utils/llm/tools/prompt-based/tools' import logger from '@/utils/logger' -import { renderPrompt, TagBuilder, UserPrompt } from '@/utils/prompts/helpers' +import { renderPrompt, TagBuilder, TextBuilder, UserPrompt } from '@/utils/prompts/helpers' import { ReactiveHistoryManager } from '../chat' import { streamTextInBackground } from '../llm' +import { AGENT_CHECKPOINT_MESSAGES, AGENT_FORCE_FINAL_ANSWER, AGENT_INITIAL_GUIDANCE, AGENT_TOOL_CALL_RESULT_GUIDANCE } from './constants' import { AgentStorage } from './strorage' type AgentToolCallHandOffResult = { @@ -114,7 +114,7 @@ export class Agent { } } - injectImagesLastMessage(messages: CoreMessage[], images: Base64ImageData[]) { + injectImagesToLastMessage(messages: CoreMessage[], images: Base64ImageData[]) { const lastMessage = messages[messages.length - 1] if (lastMessage && lastMessage.role === 'user') { if (typeof lastMessage.content === 'string') { @@ -140,25 +140,21 @@ export class Agent { return clonedMessages } - injectContentToLastUserMessage(messages: CoreMessage[], content?: string, newLine = true) { + replaceLastUserMessage(messages: CoreMessage[], content: string) { if (!content) return messages - content = newLine ? `\n${content}` : content const lastMessage = messages[messages.length - 1] if (lastMessage && lastMessage.role === 'user') { if (typeof lastMessage.content === 'string') { - lastMessage.content = [ - { type: 'text', text: lastMessage.content }, - { type: 'text', text: content }, - ] + lastMessage.content = [{ type: 'text', text: content }] } else { - const text = lastMessage.content.map((c) => c.type === 'text' ? c.text : '').join('') - lastMessage.content = [ - { type: 'text', text: text }, - { type: 'text', text: content }, - ] + const nonTextParts = lastMessage.content.filter((part) => part.type !== 'text') as (TextPart | ImagePart | FilePart)[] + lastMessage.content = [...nonTextParts, { type: 'text', text: content }] } } + else { + throw new UnknownError('Last message is not a user message') + } return messages } @@ -257,13 +253,26 @@ export class Agent { } } - async runWithPrompt(baseMessages: CoreMessage[]) { + // iteration starts from 1 + buildExtendedUserMessage(iteration: number, originalUserMessage: string, toolResults?: string) { + if (iteration === 1) return `${originalUserMessage}\n\n${AGENT_INITIAL_GUIDANCE.build()}` + const textBuilder = new TextBuilder(`${originalUserMessage}`) + if (toolResults) textBuilder.insertContent(toolResults) + const checkPointMessageBuilder = AGENT_CHECKPOINT_MESSAGES[iteration] + if (checkPointMessageBuilder) textBuilder.insertContent(checkPointMessageBuilder.build()) + textBuilder.insertContent(AGENT_TOOL_CALL_RESULT_GUIDANCE) + return textBuilder.build() + } + + async run(rawBaseMessages: CoreMessage[]) { this.stop() const abortController = this.createAbortController() let reasoningStart: number | undefined + const baseMessages = structuredClone(rawBaseMessages) this.log.debug('baseMessages', baseMessages) const originalUserMessageText = this.extractTheLastUserMessageText(baseMessages) - if (!originalUserMessageText) this.log.warn('Missing original user message') + if (!originalUserMessageText) throw new UnknownError('Missing original user message') + this.replaceLastUserMessage(baseMessages, this.buildExtendedUserMessage(1, originalUserMessageText)) // clone the message to avoid ui changes in agent's running process // this messages only used for the agent iteration but not user-facing @@ -282,20 +291,16 @@ export class Agent { } iteration++ const shouldForceAnswer = iteration === this.maxIterations - const shouldRefillOriginalUserPrompt = iteration >= AGENT_LOOP_COUNT_REFILL_USER_PROMPT - this.log.debug('Agent iteration', iteration, { shouldForceAnswer, shouldRefillOriginalUserPrompt }) + this.log.debug('Agent iteration', iteration, { shouldForceAnswer }) const thisLoopMessages: CoreMessage[] = [...baseMessages, ...loopMessages] - if (shouldForceAnswer) { - thisLoopMessages.push({ role: 'user', content: `Answer Language: Strictly follow the LANGUAGE POLICY above.\nBased on all the information collected above, please provide a comprehensive final answer.\nDo not use any tools.` }) - } + if (shouldForceAnswer) thisLoopMessages.push({ role: 'user', content: AGENT_FORCE_FINAL_ANSWER }) let taskMessageModifier = this.makeTaskMessageGroupProxy(abortController.signal) const agentMessageManager = this.makeTempAgentMessageManager() const agentMessage = agentMessageManager.getOrAddAgentMessage() const response = streamTextInBackground({ abortSignal: abortController.signal, - // do not modify the original messages to avoid duplicated images in history - messages: this.injectContentToLastUserMessage(this.injectImagesLastMessage(thisLoopMessages, loopImages), shouldRefillOriginalUserPrompt ? originalUserMessageText : undefined), + messages: this.injectImagesToLastMessage(thisLoopMessages, loopImages), }) let hasError = false let text = '' @@ -349,13 +354,14 @@ export class Agent { const subAgent = new Agent({ tools: this.tools, agentStorage: this.agentStorage, historyManager: this.historyManager, maxIterations: this.maxIterations }) abortController.signal.addEventListener('abort', () => subAgent.stop()) loopMessages.push({ role: 'user', content: handoffResult.userPrompt }) - const lastMsg = await subAgent.runWithPrompt(this.overrideSystemPrompt([...baseMessages, ...loopMessages], handoffResult.overrideSystemPrompt)) + const lastMsg = await subAgent.run(this.overrideSystemPrompt([...baseMessages, ...loopMessages], handoffResult.overrideSystemPrompt)) this.log.debug('Sub-agent finished', lastMsg) if (lastMsg?.content) loopMessages.push(lastMsg) } } else { - loopMessages.push({ role: 'user', content: this.toolResultsToPrompt(toolResults.filter((t) => t.type === 'tool-result')) }) + const toolResultPart = this.toolResultsToPrompt(toolResults.filter((t) => t.type === 'tool-result')) + loopMessages.push({ role: 'user', content: this.buildExtendedUserMessage(iteration + 1, originalUserMessageText, toolResultPart) }) } } } diff --git a/entrypoints/sidepanel/utils/chat/chat.ts b/entrypoints/sidepanel/utils/chat/chat.ts index 5774f1bd..14a3735d 100644 --- a/entrypoints/sidepanel/utils/chat/chat.ts +++ b/entrypoints/sidepanel/utils/chat/chat.ts @@ -573,7 +573,7 @@ export class Chat { }, }) this.currentAgent = agent - await agent.runWithPrompt(baseMessages) + await agent.run(baseMessages) } private async generateEnvironmentDetails(currentUserMessageId: string) { diff --git a/utils/error.ts b/utils/error.ts index 00d291fa..545e2758 100644 --- a/utils/error.ts +++ b/utils/error.ts @@ -134,7 +134,7 @@ export class TimeoutError extends AppError<'timeoutError'> { } export class ParseFunctionCallError extends AppError<'parseFunctionCallError'> { - constructor(message: string) { + constructor(message: string, public type?: 'toolNotFound' | 'invalidFormat', public toolName?: string) { super('parseFunctionCallError', message) } diff --git a/utils/llm/models.ts b/utils/llm/models.ts index d15bd495..6644149f 100644 --- a/utils/llm/models.ts +++ b/utils/llm/models.ts @@ -97,7 +97,7 @@ export async function getModel(options: { }) } -export type LLMEndpointType = 'ollama' | 'web-llm' +export type LLMEndpointType = 'ollama' | 'lm-studio' | 'web-llm' export function parseErrorMessageFromChunk(error: unknown): string | null { if (error && typeof error === 'object' && 'message' in error && typeof (error as { message: unknown }).message === 'string') { diff --git a/utils/llm/tools/prompt-based/helpers.ts b/utils/llm/tools/prompt-based/helpers.ts index 79cf501a..c63259c0 100644 --- a/utils/llm/tools/prompt-based/helpers.ts +++ b/utils/llm/tools/prompt-based/helpers.ts @@ -212,6 +212,19 @@ export class PromptBasedTool): { errors: string[], success: boolean } { + const errors: string[] = [] + for (const key in this.parameters) { + const schema = this.parameters[key] + const value = params[key] + const result = schema.safeParse(value) + if (!result.success) { + errors.push(`Invalid parameter <${key}>: ${result.error.message}`) + } + } + return { errors, success: errors.length === 0 } + } + static createToolCallsStreamParser(tools: Tools) { type ToolWithParams = ExtractToolWithParams & { tagText: string } let accText = '' diff --git a/utils/rpc/background-fns.ts b/utils/rpc/background-fns.ts index 57490871..18eff385 100644 --- a/utils/rpc/background-fns.ts +++ b/utils/rpc/background-fns.ts @@ -635,6 +635,7 @@ async function checkModelReady(modelId: string) { const userConfig = await getUserConfig() const endpointType = userConfig.llm.endpointType.get() if (endpointType === 'ollama') return true + else if (endpointType === 'lm-studio') return true else if (endpointType === 'web-llm') { return await hasWebLLMModelInCache(modelId as WebLLMSupportedModel) } @@ -653,6 +654,9 @@ async function initCurrentModel() { if (endpointType === 'ollama') { return false } + else if (endpointType === 'lm-studio') { + return false + } else if (endpointType === 'web-llm') { const connectInfo = initWebLLMEngine(model as WebLLMSupportedModel) return connectInfo.portName From 839e0d78864424b74ffef66f3e63c2c40edc1ef2 Mon Sep 17 00:00:00 2001 From: Tony Hu Date: Wed, 3 Sep 2025 14:22:36 +0800 Subject: [PATCH 2/2] fix(agent): enhance tool call parsing in responses --- utils/llm/middlewares.ts | 89 ++++++++++++++++++++++++++++++++-------- 1 file changed, 71 insertions(+), 18 deletions(-) diff --git a/utils/llm/middlewares.ts b/utils/llm/middlewares.ts index 8c66fc3e..54a70f36 100644 --- a/utils/llm/middlewares.ts +++ b/utils/llm/middlewares.ts @@ -4,13 +4,17 @@ import type { LanguageModelV1Middleware, LanguageModelV1StreamPart } from 'ai' import { extractReasoningMiddleware } from 'ai' import { z } from 'zod' +import { nonNullable } from '../array' import { debounce } from '../debounce' import { ParseFunctionCallError } from '../error' import { generateRandomId } from '../id' -import logger from '../logger' +import Logger from '../logger' +import { TagBuilder } from '../prompts/helpers' import { PromptBasedTool } from './tools/prompt-based/helpers' import { promptBasedTools } from './tools/prompt-based/tools' +const logger = Logger.child('Agent') + export const reasoningMiddleware = extractReasoningMiddleware({ tagName: 'think', separator: '\n\n', @@ -56,7 +60,7 @@ export const extractPromptBasedToolCallsMiddleware: LanguageModelV1Middleware = log.warn('Error parsing prompt-based tool calls', errors) controller.enqueue({ type: 'error', - error: new ParseFunctionCallError(`${errors.map((e) => e).join(', ')}`), + error: new ParseFunctionCallError(`${errors.map((e) => e).join(', ')}`, 'invalidFormat'), }) controller.terminate() return @@ -87,27 +91,68 @@ export const extractPromptBasedToolCallsMiddleware: LanguageModelV1Middleware = }, } +function normalizeToolCall(toolCall: T) { + const log = logger.child('normalizeToolCallsMiddleware') + + const normalizeToolName = (toolName: string) => { + // sometimes gpt-oss return tool named xxx.toolName, we should normalize it to toolName + return toolName.split('.').pop()! + } + if (toolCall.toolName === 'tool_calls' && toolCall.args) { + // 1. {name: 'tool_calls', arguments: { name: , arguments: { a: 1 } }} + const newToolCall = safeParseJSON({ text: toolCall.args, schema: z.object({ name: z.string(), arguments: z.any() }) }) + if (newToolCall.success) { + const { name: toolName, ...restArgs } = newToolCall.value + toolCall.toolName = normalizeToolName(toolName) + toolCall.args = JSON.stringify(restArgs.arguments ?? restArgs) + } + // 2. {name: 'tool_calls', arguments: { tool: , ...otherArgs }} + else { + const newToolCall = safeParseJSON({ text: toolCall.args, schema: z.record(z.any(), z.any()) }) + if (newToolCall.success && newToolCall.value.tool) { + const { tool: toolName, ...restArgs } = newToolCall.value + toolCall.toolName = normalizeToolName(toolName) + toolCall.args = JSON.stringify(typeof restArgs.arguments === 'object' && restArgs.arguments ? restArgs.arguments : restArgs) + } + } + } + else { + toolCall.toolName = normalizeToolName(toolCall.toolName) + } + const paramsParsedResult = safeParseJSON({ text: toolCall.args, schema: z.record(z.any(), z.any()).optional() }) + const params = paramsParsedResult.success ? (paramsParsedResult.value ?? {}) : {} + const tool = promptBasedTools.find((t) => t.toolName === toolCall.toolName) + // ignore invalid tool calls + if (tool) { + const { success, errors } = tool.validateParameters(params) + if (!success) { + log.warn('Tool call validation failed', { toolCall, errors }) + return { errors, toolCall, params } + } + } + return { toolCall, params } +} + export const normalizeToolCallsMiddleware: LanguageModelV1Middleware = { wrapGenerate: async ({ doGenerate }) => { + const log = logger.child('normalizeToolCallsMiddleware') + const result = await doGenerate() if (result.toolCalls?.length) { + const originalToolCalls = structuredClone(result.toolCalls) result.toolCalls = result.toolCalls?.map((toolCall) => { - if (toolCall.toolName === 'tool_calls' && toolCall.args) { - const newToolCall = safeParseJSON({ text: toolCall.args, schema: z.object({ name: z.string(), arguments: z.any() }) }) - if (newToolCall.success) { - toolCall.toolName = newToolCall.value.name - toolCall.args = JSON.stringify(newToolCall.value.arguments) - } - } - return toolCall - }) + const { toolCall: normalizedToolCall } = normalizeToolCall(toolCall) + return normalizedToolCall + }).filter(nonNullable) + log.debug('Normalized tool calls', { originalToolCalls, normalizedToolCalls: result.toolCalls }) } return result }, wrapStream: async ({ doStream }) => { + const log = logger.child('normalizeToolCallsMiddleware') const { stream, ...rest } = await doStream() const transformStream = new TransformStream< @@ -115,15 +160,23 @@ export const normalizeToolCallsMiddleware: LanguageModelV1Middleware = { LanguageModelV1StreamPart >({ transform(chunk, controller) { - if (chunk.type === 'tool-call' && chunk.toolName === 'tool_calls' && chunk.args) { - const newToolCall = safeParseJSON({ text: chunk.args, schema: z.object({ name: z.string(), arguments: z.any() }) }) - logger.debug('Normalizing tool call', chunk, newToolCall) - if (newToolCall.success) { - chunk.toolName = newToolCall.value.name - chunk.args = JSON.stringify(newToolCall.value.arguments) + if (chunk.type === 'tool-call') { + const originalToolCalls = structuredClone(chunk) + const { toolCall: normalizedToolCall, errors, params } = normalizeToolCall(chunk) + log.debug('Normalized tool call', { originalToolCalls, normalizedToolCall }) + if (errors?.length) { + controller.enqueue({ + type: 'error', + error: new ParseFunctionCallError(`${TagBuilder.fromStructured(normalizedToolCall.toolName, params).build()}\n\nError: ${errors.join(',')}`, 'toolNotFound', normalizedToolCall.toolName), + }) + } + else { + controller.enqueue(normalizedToolCall) } } - controller.enqueue(chunk) + else { + controller.enqueue(chunk) + } }, })