11import logging
22from typing import Any , Callable , Literal , Optional
33
4+ from langchain_core .language_models import BaseChatModel
45from langchain_core .messages import HumanMessage
56from langchain_core .tools import tool
67from langgraph .checkpoint .memory import MemorySaver
7- from langgraph .graph import END , START , StateGraph , MessagesState
8+ from langgraph .graph import END , START , MessagesState , StateGraph
89from pydantic import BaseModel , Field
9- from langchain_core .language_models import BaseChatModel
10-
1110
12- from uipath_langchain_client .clients .openai import UiPathAzureChatOpenAI
13- from uipath_langchain_client .clients .bedrock import UiPathChatBedrock , UiPathChatBedrockConverse
14- from uipath_langchain_client .clients .google import UiPathChatGoogleGenerativeAI
11+ from uipath_langchain .chat import (
12+ UiPathAzureChatOpenAI ,
13+ UiPathChat ,
14+ UiPathChatAnthropic ,
15+ UiPathChatAnthropicVertex ,
16+ UiPathChatBedrock ,
17+ UiPathChatBedrockConverse ,
18+ UiPathChatGoogleGenerativeAI ,
19+ )
1520
1621logger = logging .getLogger (__name__ )
1722
1823
def create_test_models(max_tokens: int = 100) -> list[tuple[str, Any]]:
    """Create all test chat models with the specified max_tokens.

    Args:
        max_tokens: Maximum number of tokens each model may generate. This was
            previously accepted but silently ignored (callers pass
            ``max_tokens=2000``); it is now forwarded to every client.

    Returns:
        List of ``(display_name, chat_model)`` pairs, one per UiPath client
        under test.
    """
    # NOTE(review): assumes every UiPath client constructor accepts
    # `max_tokens` like the underlying LangChain chat models do — confirm
    # against the uipath_langchain.chat client signatures.
    return [
        (
            "UiPathChat",
            UiPathChat(model="gpt-4o-2024-11-20", max_tokens=max_tokens),
        ),
        (
            "UiPathAzureChatOpenAI",
            UiPathAzureChatOpenAI(model="gpt-4o-2024-11-20", max_tokens=max_tokens),
        ),
        (
            "UiPathChatBedrock",
            UiPathChatBedrock(
                model="anthropic.claude-haiku-4-5-20251001-v1:0",
                max_tokens=max_tokens,
            ),
        ),
        (
            "UiPathChatBedrockConverse",
            UiPathChatBedrockConverse(
                model="anthropic.claude-haiku-4-5-20251001-v1:0",
                max_tokens=max_tokens,
            ),
        ),
        (
            "UiPathChatGoogleGenerativeAI",
            UiPathChatGoogleGenerativeAI(
                model="gemini-2.5-flash", max_tokens=max_tokens
            ),
        ),
        (
            "UiPathChatAnthropic",
            UiPathChatAnthropic(
                model="anthropic.claude-haiku-4-5-20251001-v1:0",
                max_tokens=max_tokens,
            ),
        ),
        (
            "UiPathChatAnthropicVertex",
            UiPathChatAnthropicVertex(
                model="claude-haiku-4-5@20251001", max_tokens=max_tokens
            ),
        ),
    ]
2750
2851
@@ -41,7 +64,9 @@ def format_error_message(error: str, max_length: int = 60) -> str:
4164
4265
4366@tool
44- def get_weather (location : str , unit : Literal ["celsius" , "fahrenheit" ] = "celsius" ) -> str :
67+ def get_weather (
68+ location : str , unit : Literal ["celsius" , "fahrenheit" ] = "celsius"
69+ ) -> str :
4570 """Get the current weather for a location.
4671
4772 Args:
@@ -67,22 +92,24 @@ def calculate(expression: str) -> str:
6792
class PersonInfo(BaseModel):
    """Information about a person.

    Pydantic model used as the target schema for the structured-output test;
    each ``Field`` description is emitted into the JSON schema the LLM fills.
    """

    # Full name as extracted from the prompt text.
    name: str = Field(description="The person's full name")
    # Age in whole years.
    age: int = Field(description="The person's age in years")
    # City of residence.
    city: str = Field(description="The city where the person lives")
7399
74100
75101class TestResult :
76102 """Accumulates test metrics across all test runs."""
103+
77104 def __init__ (self ):
78105 self .chunks = 0
79106 self .content_length = 0
80107 self .tool_calls = 0
81108
82109 def add_response (self , response : Any ) -> None :
83- if hasattr (response , ' content' ) and response .content :
110+ if hasattr (response , " content" ) and response .content :
84111 self .content_length += len (response .content )
85- if hasattr (response , ' tool_calls' ) and response .tool_calls :
112+ if hasattr (response , " tool_calls" ) and response .tool_calls :
86113 self .tool_calls += len (response .tool_calls )
87114
88115 def add_chunks (self , count : int ) -> None :
@@ -120,14 +147,15 @@ async def run_test_method(
120147
class GraphInput(BaseModel):
    """Input model for the testing graph."""

    # Prompt forwarded to every chat model under test; a short counting task
    # by default.
    prompt: str = Field(
        default="Count from 1 to 5.", description="The prompt to send to the LLM"
    )
128155
129156class GraphOutput (BaseModel ):
130157 """Output model for the testing graph."""
158+
131159 success : bool
132160 result_summary : str
133161 chunks_received : Optional [int ] = None
@@ -137,6 +165,7 @@ class GraphOutput(BaseModel):
137165
138166class GraphState (MessagesState ):
139167 """State model for the testing workflow."""
168+
140169 prompt : str
141170 success : bool
142171 result_summary : str
@@ -177,7 +206,7 @@ async def test_single_model_all(
177206 ("invoke" , False , False ),
178207 ("ainvoke" , True , False ),
179208 ("stream" , False , True ),
180- ("astream" , True , True )
209+ ("astream" , True , True ),
181210 ]
182211
183212 for method_name , is_async , is_streaming in test_methods :
@@ -192,7 +221,7 @@ async def test_single_model_all(
192221 model_results [method_name ] = "✓"
193222
194223 # Test tool calling
195- logger .info (f " Testing tool_calling..." )
224+ logger .info (" Testing tool_calling..." )
196225 try :
197226 llm_with_tools = model .bind_tools (tools )
198227 chunks = []
@@ -203,20 +232,24 @@ async def test_single_model_all(
203232 for chunk in chunks :
204233 accumulated = chunk if accumulated is None else accumulated + chunk
205234
206- if accumulated and hasattr (accumulated , 'tool_calls' ) and accumulated .tool_calls :
235+ if (
236+ accumulated
237+ and hasattr (accumulated , "tool_calls" )
238+ and accumulated .tool_calls
239+ ):
207240 tool_calls_count = len (accumulated .tool_calls )
208241 result .add_tool_calls (tool_calls_count )
209242 logger .info (f" Tool calls detected: { tool_calls_count } " )
210243 model_results ["tool_calling" ] = f"✓ ({ tool_calls_count } calls)"
211244 else :
212- logger .warning (f " No tool calls detected" )
245+ logger .warning (" No tool calls detected" )
213246 model_results ["tool_calling" ] = "✗ No tool calls detected"
214247 except Exception as e :
215248 logger .error (f" Tool calling failed: { e } " )
216249 model_results ["tool_calling" ] = f"✗ { format_error_message (str (e ))} "
217250
218251 # Test structured output
219- logger .info (f " Testing structured_output..." )
252+ logger .info (" Testing structured_output..." )
220253 try :
221254 llm_with_structure = model .with_structured_output (PersonInfo )
222255 response = await llm_with_structure .ainvoke (structured_messages )
@@ -247,18 +280,28 @@ async def run_all_tests(state: GraphState) -> dict:
247280 """Run all tests for all chat models in parallel."""
248281 import asyncio
249282
250- logger .info ("=" * 80 )
283+ logger .info ("=" * 80 )
251284 logger .info ("Running All Tests" )
252- logger .info ("=" * 80 )
285+ logger .info ("=" * 80 )
253286
254287 models = create_test_models (max_tokens = 2000 )
255288 tools = [get_weather , calculate ]
256- tool_messages = [HumanMessage (content = "What's the weather in San Francisco? Also calculate 15 * 23." )]
257- structured_messages = [HumanMessage (content = "Tell me about John Smith, a 35 year old software engineer living in New York." )]
289+ tool_messages = [
290+ HumanMessage (
291+ content = "What's the weather in San Francisco? Also calculate 15 * 23."
292+ )
293+ ]
294+ structured_messages = [
295+ HumanMessage (
296+ content = "Tell me about John Smith, a 35 year old software engineer living in New York."
297+ )
298+ ]
258299
259300 # Run all models in parallel
260301 tasks = [
261- test_single_model_all (name , model , state ["messages" ], tools , tool_messages , structured_messages )
302+ test_single_model_all (
303+ name , model , state ["messages" ], tools , tool_messages , structured_messages
304+ )
262305 for name , model in models
263306 ]
264307 results_list = await asyncio .gather (* tasks )
@@ -274,17 +317,34 @@ async def run_all_tests(state: GraphState) -> dict:
274317 total_result .tool_calls += result .tool_calls
275318
276319 # Build summary
277- logger .info ("=" * 80 )
320+ logger .info ("=" * 80 )
278321 summary_lines = []
279- for model_name in ["UiPathChatOpenAI" , "UiPathChatGoogleGenerativeAI" , "UiPathChatBedrockConverse" , "UiPathChatBedrock" , "UiPathChat" , "UiPathAzureChatOpenAI" ]:
322+ for model_name in [
323+ "UiPathChat" ,
324+ "UiPathAzureChatOpenAI" ,
325+ "UiPathChatBedrock" ,
326+ "UiPathChatBedrockConverse" ,
327+ "UiPathChatGoogleGenerativeAI" ,
328+ "UiPathChatAnthropic" ,
329+ "UiPathChatAnthropicVertex" ,
330+ ]:
280331 if model_name in all_model_results :
281332 summary_lines .append (f"{ model_name } :" )
282333 results = all_model_results [model_name ]
283- for test_name in ["invoke" , "ainvoke" , "stream" , "astream" , "tool_calling" , "structured_output" ]:
334+ for test_name in [
335+ "invoke" ,
336+ "ainvoke" ,
337+ "stream" ,
338+ "astream" ,
339+ "tool_calling" ,
340+ "structured_output" ,
341+ ]:
284342 if test_name in results :
285343 summary_lines .append (f" { test_name } : { results [test_name ]} " )
286344
287- has_failures = any ("✗" in str (v ) for r in all_model_results .values () for v in r .values ())
345+ has_failures = any (
346+ "✗" in str (v ) for r in all_model_results .values () for v in r .values ()
347+ )
288348
289349 return {
290350 "success" : not has_failures ,
@@ -298,16 +358,16 @@ async def run_all_tests(state: GraphState) -> dict:
298358
299359async def return_results (state : GraphState ) -> GraphOutput :
300360 """Return final test results."""
301- logger .info ("=" * 80 )
361+ logger .info ("=" * 80 )
302362 logger .info ("TEST RESULTS" )
303- logger .info ("=" * 80 )
363+ logger .info ("=" * 80 )
304364 logger .info (f"Success: { state ['success' ]} " )
305365 logger .info (f"Summary: { state ['result_summary' ]} " )
306- if state .get (' chunks_received' ):
366+ if state .get (" chunks_received" ):
307367 logger .info (f"Chunks Received: { state ['chunks_received' ]} " )
308- if state .get (' content_length' ):
368+ if state .get (" content_length" ):
309369 logger .info (f"Content Length: { state ['content_length' ]} " )
310- if state .get (' tool_calls_count' ):
370+ if state .get (" tool_calls_count" ):
311371 logger .info (f"Tool Calls: { state ['tool_calls_count' ]} " )
312372
313373 return GraphOutput (
0 commit comments