Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions entrypoints/sidepanel/utils/agent/constants.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import { TagBuilder } from '@/utils/prompts/helpers'

/**
 * Guidance wrapped in an `initial_guidance` tag.
 *
 * The agent appends the built tag to the original user message on the first
 * iteration only (see `buildExtendedUserMessage`, which returns
 * `originalUserMessage + "\n\n" + AGENT_INITIAL_GUIDANCE.build()` when `iteration === 1`).
 * The text is sent to the model verbatim, so edits here change model behavior.
 */
export const AGENT_INITIAL_GUIDANCE = new TagBuilder('initial_guidance').insertContent(`Remember to analyze what information layer this question requires:
- Layer 0 (Surface): Immediately visible information
- Layer 1 (Interactive): Information behind clicks and expansions
- Layer 2 (External): Information from other sources

Most detailed questions require Layer 1 information. Surface information is rarely sufficient.`)

/**
 * Progress-check prompts injected into the agent loop at fixed points.
 *
 * The key is the `iteration` argument given to `buildExtendedUserMessage`, which
 * looks up `AGENT_CHECKPOINT_MESSAGES[iteration]`. After a tool-result turn the
 * agent calls it with `iteration + 1`, so the entry stored under key N travels
 * with the tool results produced by iteration N - 1. That is why the tag names
 * (`checkpoint_1`, `checkpoint_3`, ...) are one lower than their keys.
 *
 * Iterations without an entry get no checkpoint text.
 *
 * NOTE(review): the "3 more attempts" / "2 more iterations" counts in the last
 * entry are hard-coded in the prompt text and are not derived from the agent's
 * `maxIterations` option — confirm they match the configured limit.
 */
export const AGENT_CHECKPOINT_MESSAGES: Record<number, TagBuilder> = {
// Key 2: accompanies the results of iteration 1.
2: new TagBuilder('checkpoint_1').insertContent(`FIRST ACTION COMPLETE. Now assess your progress:
- Did you identify where the detailed information lives on this page?
- Can you see clickable elements that lead to the information the user needs?
- If yes to both: Your next 2-3 actions should be clicking those elements.
- If no: You may need to look elsewhere or search online.

Remember: Surface viewing (Layer 0) is just the beginning, not the end.
Most user questions require Layer 1 information behind clicks.`),

// Key 4: accompanies the results of iteration 3.
4: new TagBuilder('checkpoint_3').insertContent(`THIRD ITERATION. Progress check:
- Have you moved from surface (Layer 0) to deep (Layer 1) information?
- Are you clicking on information-rich elements or just browsing around?
- Is your information getting more complete or are you wandering aimlessly?

Critical reminder: If user asked about specific content (comments, details, analysis),
you MUST access that actual content, not just see that it exists.
Click on relevant elements that contain the information needed.`),

// Key 6: accompanies the results of iteration 5.
6: new TagBuilder('checkpoint_5').insertContent(`FIFTH ITERATION. Strategy review:
- List what concrete information you've learned so far
- List what you still don't know but need to know
- If gaps remain, identify exactly which elements to click next
- If information is complete, prepare to provide your final answer

Key question: Can you fully answer the user's question with your current information?
If NO: Continue clicking on relevant elements. If YES: Provide comprehensive answer.`),

// Key 8: accompanies the results of iteration 7.
8: new TagBuilder('checkpoint_7').insertContent(`APPROACHING ITERATION LIMIT. Final push needed:
- You have only 3 more attempts remaining
- Focus on the most critical missing information gaps
- Click on the most promising unexplored elements that likely contain needed info
- After 2 more iterations, you must provide the best answer possible with available data

Stop exploring only if: no more relevant clickable elements OR information is truly complete.`),
}

/**
 * Trailing instruction appended to every follow-up user message.
 *
 * `buildExtendedUserMessage` inserts it last, after the optional tool results
 * and the optional checkpoint text, for every iteration other than the first.
 */
export const AGENT_TOOL_CALL_RESULT_GUIDANCE = `Based on the tool results above, follow your self-assessment protocol:

- Analyze the results before taking further action.
- Evaluate what you have learned and identify any remaining gaps.
- Proceed with the next steps based on your assessment of information completeness.`

/**
 * User message pushed onto the conversation when the agent reaches its last
 * allowed iteration (`iteration === maxIterations`), telling the model to stop
 * calling tools and answer with what it has.
 */
export const AGENT_FORCE_FINAL_ANSWER = `Answer Language: Strictly follow the LANGUAGE POLICY above.\nBased on all the information collected above, please provide a comprehensive final answer.\nDo not use any tools.`
58 changes: 32 additions & 26 deletions entrypoints/sidepanel/utils/agent/index.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { CoreAssistantMessage, CoreMessage, CoreUserMessage } from 'ai'
import { CoreAssistantMessage, CoreMessage, CoreUserMessage, FilePart, ImagePart, TextPart } from 'ai'
import { isEqual } from 'es-toolkit'
import EventEmitter from 'events'
import { Ref, ref } from 'vue'
Expand All @@ -8,17 +8,17 @@ import { AssistantMessageV1 } from '@/types/chat'
import { PromiseOr } from '@/types/common'
import { Base64ImageData, ImageDataWithId } from '@/types/image'
import { TagBuilderJSON } from '@/types/prompt'
import { AGENT_LOOP_COUNT_REFILL_USER_PROMPT } from '@/utils/constants'
import { AbortError, AiSDKError, AppError, ErrorCode, fromError, ModelNotFoundError, ModelRequestError, ParseFunctionCallError, UnknownError } from '@/utils/error'
import { useGlobalI18n } from '@/utils/i18n'
import { generateRandomId } from '@/utils/id'
import { InferredParams } from '@/utils/llm/tools/prompt-based/helpers'
import { GetPromptBasedTool, PromptBasedToolName, PromptBasedToolNameAndParams } from '@/utils/llm/tools/prompt-based/tools'
import logger from '@/utils/logger'
import { renderPrompt, TagBuilder, UserPrompt } from '@/utils/prompts/helpers'
import { renderPrompt, TagBuilder, TextBuilder, UserPrompt } from '@/utils/prompts/helpers'

import { ReactiveHistoryManager } from '../chat'
import { streamTextInBackground } from '../llm'
import { AGENT_CHECKPOINT_MESSAGES, AGENT_FORCE_FINAL_ANSWER, AGENT_INITIAL_GUIDANCE, AGENT_TOOL_CALL_RESULT_GUIDANCE } from './constants'
import { AgentStorage } from './strorage'

type AgentToolCallHandOffResult = {
Expand Down Expand Up @@ -114,7 +114,7 @@ export class Agent<T extends PromptBasedToolName> {
}
}

injectImagesLastMessage(messages: CoreMessage[], images: Base64ImageData[]) {
injectImagesToLastMessage(messages: CoreMessage[], images: Base64ImageData[]) {
const lastMessage = messages[messages.length - 1]
if (lastMessage && lastMessage.role === 'user') {
if (typeof lastMessage.content === 'string') {
Expand All @@ -140,25 +140,21 @@ export class Agent<T extends PromptBasedToolName> {
return clonedMessages
}

injectContentToLastUserMessage(messages: CoreMessage[], content?: string, newLine = true) {
replaceLastUserMessage(messages: CoreMessage[], content: string) {
if (!content) return messages
content = newLine ? `\n${content}` : content
const lastMessage = messages[messages.length - 1]
if (lastMessage && lastMessage.role === 'user') {
if (typeof lastMessage.content === 'string') {
lastMessage.content = [
{ type: 'text', text: lastMessage.content },
{ type: 'text', text: content },
]
lastMessage.content = [{ type: 'text', text: content }]
}
else {
const text = lastMessage.content.map((c) => c.type === 'text' ? c.text : '').join('')
lastMessage.content = [
{ type: 'text', text: text },
{ type: 'text', text: content },
]
const nonTextParts = lastMessage.content.filter((part) => part.type !== 'text') as (TextPart | ImagePart | FilePart)[]
lastMessage.content = [...nonTextParts, { type: 'text', text: content }]
}
}
else {
throw new UnknownError('Last message is not a user message')
}
return messages
}

Expand Down Expand Up @@ -257,13 +253,26 @@ export class Agent<T extends PromptBasedToolName> {
}
}

async runWithPrompt(baseMessages: CoreMessage[]) {
/**
 * Builds the text of the user message sent to the model for one agent iteration.
 *
 * - Iteration 1 (iterations are 1-based): the original user message followed by
 *   a blank line and the built initial-guidance tag.
 * - Any later iteration: the original user message, then the tool results (when
 *   provided and non-empty), then the checkpoint registered for this iteration
 *   (when one exists), and finally the tool-call result guidance.
 *
 * @param iteration 1-based iteration number the message is built for.
 * @param originalUserMessage Text of the user's original request; repeated in every message.
 * @param toolResults Rendered tool results of the previous step, if any.
 * @returns The assembled message text.
 */
buildExtendedUserMessage(iteration: number, originalUserMessage: string, toolResults?: string) {
  const isFirstIteration = iteration === 1
  if (isFirstIteration) {
    return `${originalUserMessage}\n\n${AGENT_INITIAL_GUIDANCE.build()}`
  }

  const message = new TextBuilder(originalUserMessage)
  if (toolResults) {
    message.insertContent(toolResults)
  }

  // Only some iterations have a checkpoint registered.
  const checkpoint = AGENT_CHECKPOINT_MESSAGES[iteration]
  if (checkpoint) {
    message.insertContent(checkpoint.build())
  }

  message.insertContent(AGENT_TOOL_CALL_RESULT_GUIDANCE)
  return message.build()
}

async run(rawBaseMessages: CoreMessage[]) {
this.stop()
const abortController = this.createAbortController()
let reasoningStart: number | undefined
const baseMessages = structuredClone(rawBaseMessages)
this.log.debug('baseMessages', baseMessages)
const originalUserMessageText = this.extractTheLastUserMessageText(baseMessages)
if (!originalUserMessageText) this.log.warn('Missing original user message')
if (!originalUserMessageText) throw new UnknownError('Missing original user message')
this.replaceLastUserMessage(baseMessages, this.buildExtendedUserMessage(1, originalUserMessageText))
// clone the message to avoid ui changes in agent's running process

// these messages are only used for the agent's iterations and are not user-facing
Expand All @@ -282,20 +291,16 @@ export class Agent<T extends PromptBasedToolName> {
}
iteration++
const shouldForceAnswer = iteration === this.maxIterations
const shouldRefillOriginalUserPrompt = iteration >= AGENT_LOOP_COUNT_REFILL_USER_PROMPT
this.log.debug('Agent iteration', iteration, { shouldForceAnswer, shouldRefillOriginalUserPrompt })
this.log.debug('Agent iteration', iteration, { shouldForceAnswer })

const thisLoopMessages: CoreMessage[] = [...baseMessages, ...loopMessages]
if (shouldForceAnswer) {
thisLoopMessages.push({ role: 'user', content: `Answer Language: Strictly follow the LANGUAGE POLICY above.\nBased on all the information collected above, please provide a comprehensive final answer.\nDo not use any tools.` })
}
if (shouldForceAnswer) thisLoopMessages.push({ role: 'user', content: AGENT_FORCE_FINAL_ANSWER })
let taskMessageModifier = this.makeTaskMessageGroupProxy(abortController.signal)
const agentMessageManager = this.makeTempAgentMessageManager()
const agentMessage = agentMessageManager.getOrAddAgentMessage()
const response = streamTextInBackground({
abortSignal: abortController.signal,
// do not modify the original messages to avoid duplicated images in history
messages: this.injectContentToLastUserMessage(this.injectImagesLastMessage(thisLoopMessages, loopImages), shouldRefillOriginalUserPrompt ? originalUserMessageText : undefined),
messages: this.injectImagesToLastMessage(thisLoopMessages, loopImages),
})
let hasError = false
let text = ''
Expand Down Expand Up @@ -349,13 +354,14 @@ export class Agent<T extends PromptBasedToolName> {
const subAgent = new Agent({ tools: this.tools, agentStorage: this.agentStorage, historyManager: this.historyManager, maxIterations: this.maxIterations })
abortController.signal.addEventListener('abort', () => subAgent.stop())
loopMessages.push({ role: 'user', content: handoffResult.userPrompt })
const lastMsg = await subAgent.runWithPrompt(this.overrideSystemPrompt([...baseMessages, ...loopMessages], handoffResult.overrideSystemPrompt))
const lastMsg = await subAgent.run(this.overrideSystemPrompt([...baseMessages, ...loopMessages], handoffResult.overrideSystemPrompt))
this.log.debug('Sub-agent finished', lastMsg)
if (lastMsg?.content) loopMessages.push(lastMsg)
}
}
else {
loopMessages.push({ role: 'user', content: this.toolResultsToPrompt(toolResults.filter((t) => t.type === 'tool-result')) })
const toolResultPart = this.toolResultsToPrompt(toolResults.filter((t) => t.type === 'tool-result'))
loopMessages.push({ role: 'user', content: this.buildExtendedUserMessage(iteration + 1, originalUserMessageText, toolResultPart) })
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion entrypoints/sidepanel/utils/chat/chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -573,7 +573,7 @@ export class Chat {
},
})
this.currentAgent = agent
await agent.runWithPrompt(baseMessages)
await agent.run(baseMessages)
}

private async generateEnvironmentDetails(currentUserMessageId: string) {
Expand Down
2 changes: 1 addition & 1 deletion utils/error.ts
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ export class TimeoutError extends AppError<'timeoutError'> {
}

export class ParseFunctionCallError extends AppError<'parseFunctionCallError'> {
constructor(message: string) {
constructor(message: string, public type?: 'toolNotFound' | 'invalidFormat', public toolName?: string) {
super('parseFunctionCallError', message)
}

Expand Down
89 changes: 71 additions & 18 deletions utils/llm/middlewares.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,17 @@ import type { LanguageModelV1Middleware, LanguageModelV1StreamPart } from 'ai'
import { extractReasoningMiddleware } from 'ai'
import { z } from 'zod'

import { nonNullable } from '../array'
import { debounce } from '../debounce'
import { ParseFunctionCallError } from '../error'
import { generateRandomId } from '../id'
import logger from '../logger'
import Logger from '../logger'
import { TagBuilder } from '../prompts/helpers'
import { PromptBasedTool } from './tools/prompt-based/helpers'
import { promptBasedTools } from './tools/prompt-based/tools'

const logger = Logger.child('Agent')

export const reasoningMiddleware = extractReasoningMiddleware({
tagName: 'think',
separator: '\n\n',
Expand Down Expand Up @@ -56,7 +60,7 @@ export const extractPromptBasedToolCallsMiddleware: LanguageModelV1Middleware =
log.warn('Error parsing prompt-based tool calls', errors)
controller.enqueue({
type: 'error',
error: new ParseFunctionCallError(`${errors.map((e) => e).join(', ')}`),
error: new ParseFunctionCallError(`${errors.map((e) => e).join(', ')}`, 'invalidFormat'),
})
controller.terminate()
return
Expand Down Expand Up @@ -87,43 +91,92 @@ export const extractPromptBasedToolCallsMiddleware: LanguageModelV1Middleware =
},
}

/**
 * Normalizes a tool call emitted by the model and validates its parameters.
 *
 * Handles three shapes:
 * 1. A wrapper call named `tool_calls` whose args are `{ name, arguments }`.
 * 2. A wrapper call named `tool_calls` whose args are `{ tool, ...otherArgs }`
 *    (optionally with a nested `arguments` object).
 * 3. A regular call, where only the tool name is normalized.
 *
 * Namespaced names such as `xxx.toolName` are reduced to their last segment.
 * The given `toolCall` object is updated in place and also returned.
 *
 * Parameters are validated only when the normalized name matches one of
 * `promptBasedTools`; calls to unknown tools are returned without validation.
 *
 * @param toolCall Tool call to normalize (mutated).
 * @returns `{ toolCall, params }`, plus `errors` when a known tool's parameters
 *          fail validation. `params` is `{}` when the args are not a JSON object.
 */
function normalizeToolCall<T extends LanguageModelV1FunctionToolCall>(toolCall: T) {
  const log = logger.child('normalizeToolCallsMiddleware')

  // Some models (e.g. gpt-oss) return a tool named `xxx.toolName`; keep only the last segment.
  const normalizeToolName = (toolName: string) => toolName.split('.').pop() ?? toolName

  if (toolCall.toolName === 'tool_calls' && toolCall.args) {
    // 1. {name: 'tool_calls', arguments: { name: <toolName>, arguments: { a: 1 } }}
    const wrappedCall = safeParseJSON({ text: toolCall.args, schema: z.object({ name: z.string(), arguments: z.any() }) })
    if (wrappedCall.success) {
      const { name: toolName, ...restArgs } = wrappedCall.value
      toolCall.toolName = normalizeToolName(toolName)
      toolCall.args = JSON.stringify(restArgs.arguments ?? restArgs)
    }
    // 2. {name: 'tool_calls', arguments: { tool: <toolName>, ...otherArgs }}
    else {
      const flatCall = safeParseJSON({ text: toolCall.args, schema: z.record(z.any(), z.any()) })
      // `tool` is unvalidated model output: only accept a non-empty string, since
      // any other truthy value (number, object, ...) would make `.split` throw.
      if (flatCall.success && typeof flatCall.value.tool === 'string' && flatCall.value.tool) {
        const { tool: toolName, ...restArgs } = flatCall.value
        toolCall.toolName = normalizeToolName(toolName)
        toolCall.args = JSON.stringify(typeof restArgs.arguments === 'object' && restArgs.arguments ? restArgs.arguments : restArgs)
      }
    }
  }
  else {
    toolCall.toolName = normalizeToolName(toolCall.toolName)
  }

  // Args that are missing or not a JSON object are treated as "no parameters".
  const paramsParsedResult = safeParseJSON({ text: toolCall.args, schema: z.record(z.any(), z.any()).optional() })
  const params = paramsParsedResult.success ? (paramsParsedResult.value ?? {}) : {}

  // Only known tools can be validated; unknown tool names are passed through unchanged.
  const tool = promptBasedTools.find((t) => t.toolName === toolCall.toolName)
  if (tool) {
    const { success, errors } = tool.validateParameters(params)
    if (!success) {
      log.warn('Tool call validation failed', { toolCall, errors })
      return { errors, toolCall, params }
    }
  }
  return { toolCall, params }
}

export const normalizeToolCallsMiddleware: LanguageModelV1Middleware = {
wrapGenerate: async ({ doGenerate }) => {
const log = logger.child('normalizeToolCallsMiddleware')

const result = await doGenerate()

if (result.toolCalls?.length) {
const originalToolCalls = structuredClone(result.toolCalls)
result.toolCalls = result.toolCalls?.map((toolCall) => {
if (toolCall.toolName === 'tool_calls' && toolCall.args) {
const newToolCall = safeParseJSON({ text: toolCall.args, schema: z.object({ name: z.string(), arguments: z.any() }) })
if (newToolCall.success) {
toolCall.toolName = newToolCall.value.name
toolCall.args = JSON.stringify(newToolCall.value.arguments)
}
}
return toolCall
})
const { toolCall: normalizedToolCall } = normalizeToolCall(toolCall)
return normalizedToolCall
}).filter(nonNullable)
log.debug('Normalized tool calls', { originalToolCalls, normalizedToolCalls: result.toolCalls })
}

return result
},

wrapStream: async ({ doStream }) => {
const log = logger.child('normalizeToolCallsMiddleware')
const { stream, ...rest } = await doStream()

const transformStream = new TransformStream<
LanguageModelV1StreamPart,
LanguageModelV1StreamPart
>({
transform(chunk, controller) {
if (chunk.type === 'tool-call' && chunk.toolName === 'tool_calls' && chunk.args) {
const newToolCall = safeParseJSON({ text: chunk.args, schema: z.object({ name: z.string(), arguments: z.any() }) })
logger.debug('Normalizing tool call', chunk, newToolCall)
if (newToolCall.success) {
chunk.toolName = newToolCall.value.name
chunk.args = JSON.stringify(newToolCall.value.arguments)
if (chunk.type === 'tool-call') {
const originalToolCalls = structuredClone(chunk)
const { toolCall: normalizedToolCall, errors, params } = normalizeToolCall(chunk)
log.debug('Normalized tool call', { originalToolCalls, normalizedToolCall })
if (errors?.length) {
controller.enqueue({
type: 'error',
error: new ParseFunctionCallError(`${TagBuilder.fromStructured(normalizedToolCall.toolName, params).build()}\n\nError: ${errors.join(',')}`, 'toolNotFound', normalizedToolCall.toolName),
})
}
else {
controller.enqueue(normalizedToolCall)
}
}
controller.enqueue(chunk)
else {
controller.enqueue(chunk)
}
},
})

Expand Down
2 changes: 1 addition & 1 deletion utils/llm/models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ export async function getModel(options: {
})
}

export type LLMEndpointType = 'ollama' | 'web-llm'
export type LLMEndpointType = 'ollama' | 'lm-studio' | 'web-llm'

export function parseErrorMessageFromChunk(error: unknown): string | null {
if (error && typeof error === 'object' && 'message' in error && typeof (error as { message: unknown }).message === 'string') {
Expand Down
13 changes: 13 additions & 0 deletions utils/llm/tools/prompt-based/helpers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,19 @@ export class PromptBasedTool<Name extends string, T extends PromptBasedToolParam
return null
}

/**
 * Checks the given parameters against every schema declared in `this.parameters`.
 *
 * Each declared key is looked up in `params` and passed to that key's schema via
 * `safeParse`; a key that is absent from `params` is validated as `undefined`,
 * so whether it may be omitted is decided by the schema. Keys in `params` that
 * have no declared schema are not inspected.
 *
 * @param params Parameter values keyed by parameter name.
 * @returns `errors` holds one message per failing parameter; `success` is true
 *          when no parameter failed.
 */
validateParameters(params: Record<string, unknown>): { errors: string[], success: boolean } {
  const problems: string[] = []
  for (const name in this.parameters) {
    const parsed = this.parameters[name].safeParse(params[name])
    if (parsed.success) continue
    problems.push(`Invalid parameter <${name}>: ${parsed.error.message}`)
  }
  return { errors: problems, success: problems.length === 0 }
}

static createToolCallsStreamParser<Tools extends PromptBasedToolType[]>(tools: Tools) {
type ToolWithParams = ExtractToolWithParams<Tools[number]> & { tagText: string }
let accText = ''
Expand Down
4 changes: 4 additions & 0 deletions utils/rpc/background-fns.ts
Original file line number Diff line number Diff line change
Expand Up @@ -635,6 +635,7 @@ async function checkModelReady(modelId: string) {
const userConfig = await getUserConfig()
const endpointType = userConfig.llm.endpointType.get()
if (endpointType === 'ollama') return true
else if (endpointType === 'lm-studio') return true
else if (endpointType === 'web-llm') {
return await hasWebLLMModelInCache(modelId as WebLLMSupportedModel)
}
Expand All @@ -653,6 +654,9 @@ async function initCurrentModel() {
if (endpointType === 'ollama') {
return false
}
else if (endpointType === 'lm-studio') {
return false
}
else if (endpointType === 'web-llm') {
const connectInfo = initWebLLMEngine(model as WebLLMSupportedModel)
return connectInfo.portName
Expand Down
Loading