Skip to content

Commit 88ca7e3

Browse files
committed
remove check_tool_choice because "any" works for all providers
1 parent 8604b3c commit 88ca7e3

7 files changed

Lines changed: 2 additions & 46 deletions

File tree

src/uipath_langchain/agent/react/llm_node.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,6 @@ def create_llm_node(
6060
"""
6161
bindable_tools = list(tools) if tools else []
6262
payload_handler = get_payload_handler(model)
63-
tool_choice_required_value = payload_handler.get_required_tool_choice()
6463

6564
async def llm_node(state: StateT):
6665
messages: list[AnyMessage] = state.messages
@@ -78,16 +77,15 @@ async def llm_node(state: StateT):
7877
static_schema_tools = _apply_tool_argument_properties(
7978
bindable_tools, state, input_schema
8079
)
81-
base_llm = model.bind_tools(static_schema_tools)
8280

8381
if (
8482
not is_conversational
8583
and bindable_tools
8684
and consecutive_thinking_messages >= thinking_messages_limit
8785
):
88-
llm = base_llm.bind(tool_choice=tool_choice_required_value)
86+
llm = model.bind_tools(static_schema_tools, tool_choice="any")
8987
else:
90-
llm = base_llm
88+
llm = model.bind_tools(static_schema_tools)
9189

9290
response = await llm.ainvoke(messages)
9391
if not isinstance(response, AIMessage):

src/uipath_langchain/chat/handlers/base.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
"""Abstract base class for model payload handlers."""
22

33
from abc import ABC, abstractmethod
4-
from typing import Any
54

65
from langchain_core.messages import AIMessage
76

@@ -12,17 +11,6 @@ class ModelPayloadHandler(ABC):
1211
Each handler provides provider-specific parameter values for LLM operations.
1312
"""
1413

15-
@abstractmethod
16-
def get_required_tool_choice(self) -> str | dict[str, Any]:
17-
"""Get the tool_choice value that enforces tool usage.
18-
19-
Returns:
20-
Provider-specific value to force tool usage:
21-
- "required" for OpenAI-compatible models
22-
- "any" for Bedrock Converse and Vertex models (string format)
23-
- {"type": "any"} for Bedrock Invoke API (dict format required)
24-
"""
25-
2614
@abstractmethod
2715
def check_stop_reason(self, response: AIMessage) -> None:
2816
"""Check response stop reason and raise exception for faulty terminations.

src/uipath_langchain/chat/handlers/bedrock_converse.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
"""Bedrock Converse payload handler."""
22

3-
from typing import Any
4-
53
from langchain_core.messages import AIMessage
64
from uipath.runtime.errors import UiPathErrorCode
75

@@ -42,10 +40,6 @@
4240
class BedrockConversePayloadHandler(ModelPayloadHandler):
4341
"""Payload handler for AWS Bedrock Converse API."""
4442

45-
def get_required_tool_choice(self) -> str | dict[str, Any]:
46-
"""Get tool_choice value for Bedrock Converse API."""
47-
return "any"
48-
4943
def check_stop_reason(self, response: AIMessage) -> None:
5044
"""Check Bedrock Converse stopReason and raise exception for faulty terminations.
5145

src/uipath_langchain/chat/handlers/bedrock_invoke.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
"""Bedrock Invoke payload handler."""
22

3-
from typing import Any
4-
53
from langchain_core.messages import AIMessage
64
from uipath.runtime.errors import UiPathErrorCode
75

@@ -37,10 +35,6 @@
3735
class BedrockInvokePayloadHandler(ModelPayloadHandler):
3836
"""Payload handler for AWS Bedrock Invoke API."""
3937

40-
def get_required_tool_choice(self) -> str | dict[str, Any]:
41-
"""Get tool_choice value for Bedrock Invoke API."""
42-
return {"type": "any"}
43-
4438
def check_stop_reason(self, response: AIMessage) -> None:
4539
"""Check Bedrock Invoke stop_reason and raise exception for faulty terminations.
4640

src/uipath_langchain/chat/handlers/openai_completions.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
"""OpenAI Chat Completions payload handler."""
22

3-
from typing import Any
4-
53
from langchain_core.messages import AIMessage
64
from uipath.runtime.errors import UiPathErrorCode
75

@@ -31,10 +29,6 @@
3129
class OpenAICompletionsPayloadHandler(ModelPayloadHandler):
3230
"""Payload handler for OpenAI Chat Completions API."""
3331

34-
def get_required_tool_choice(self) -> str | dict[str, Any]:
35-
"""Get tool_choice value for OpenAI Completions API."""
36-
return "required"
37-
3832
def check_stop_reason(self, response: AIMessage) -> None:
3933
"""Check OpenAI finish_reason and raise exception for faulty terminations.
4034

src/uipath_langchain/chat/handlers/openai_responses.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
"""OpenAI payload handlers."""
22

3-
from typing import Any
4-
53
from langchain_core.messages import AIMessage
64
from uipath.runtime.errors import UiPathErrorCode
75

@@ -35,10 +33,6 @@
3533
class OpenAIResponsesPayloadHandler(ModelPayloadHandler):
3634
"""Payload handler for OpenAI Responses API."""
3735

38-
def get_required_tool_choice(self) -> str | dict[str, Any]:
39-
"""Get tool_choice value for OpenAI Responses API."""
40-
return "required"
41-
4236
def check_stop_reason(self, response: AIMessage) -> None:
4337
"""Check OpenAI Responses API status and raise exception for faulty terminations.
4438

src/uipath_langchain/chat/handlers/vertex_gemini.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
"""Vertex Gemini payload handler."""
22

3-
from typing import Any
4-
53
from langchain_core.messages import AIMessage
64
from uipath.runtime.errors import UiPathErrorCode
75

@@ -115,10 +113,6 @@
115113
class VertexGeminiPayloadHandler(ModelPayloadHandler):
116114
"""Payload handler for Google Vertex AI Gemini API."""
117115

118-
def get_required_tool_choice(self) -> str | dict[str, Any]:
119-
"""Get tool_choice value for Vertex Gemini API."""
120-
return "any"
121-
122116
def check_stop_reason(self, response: AIMessage) -> None:
123117
"""Check Vertex Gemini finishReason and raise exception for faulty terminations.
124118

0 commit comments

Comments
 (0)