Commit 74125e2: test fixes
1 parent 724025c commit 74125e2

5 files changed
Lines changed: 118 additions & 49 deletions

testcases/chat-models/src/main.py
Lines changed: 94 additions & 34 deletions
@@ -1,28 +1,51 @@
 import logging
 from typing import Any, Callable, Literal, Optional
 
+from langchain_core.language_models import BaseChatModel
 from langchain_core.messages import HumanMessage
 from langchain_core.tools import tool
 from langgraph.checkpoint.memory import MemorySaver
-from langgraph.graph import END, START, StateGraph, MessagesState
+from langgraph.graph import END, START, MessagesState, StateGraph
 from pydantic import BaseModel, Field
-from langchain_core.language_models import BaseChatModel
-
 
-from uipath_langchain_client.clients.openai import UiPathAzureChatOpenAI
-from uipath_langchain_client.clients.bedrock import UiPathChatBedrock, UiPathChatBedrockConverse
-from uipath_langchain_client.clients.google import UiPathChatGoogleGenerativeAI
+from uipath_langchain.chat import (
+    UiPathAzureChatOpenAI,
+    UiPathChat,
+    UiPathChatAnthropic,
+    UiPathChatAnthropicVertex,
+    UiPathChatBedrock,
+    UiPathChatBedrockConverse,
+    UiPathChatGoogleGenerativeAI,
+)
 
 logger = logging.getLogger(__name__)
 
 
 def create_test_models(max_tokens: int = 100) -> list[tuple[str, Any]]:
     """Create all test chat models with the specified max_tokens."""
     return [
-        ("UiPathChatGoogleGenerativeAI", UiPathChatGoogleGenerativeAI()),
-        ("UiPathChatBedrockConverse", UiPathChatBedrockConverse()),
-        ("UiPathChatBedrock", UiPathChatBedrock()),
-        ("UiPathAzureChatOpenAI", UiPathAzureChatOpenAI())
+        ("UiPathChat", UiPathChat(model="gpt-4o-2024-11-20")),
+        ("UiPathAzureChatOpenAI", UiPathAzureChatOpenAI(model="gpt-4o-2024-11-20")),
+        (
+            "UiPathChatBedrock",
+            UiPathChatBedrock(model="anthropic.claude-haiku-4-5-20251001-v1:0"),
+        ),
+        (
+            "UiPathChatBedrockConverse",
+            UiPathChatBedrockConverse(model="anthropic.claude-haiku-4-5-20251001-v1:0"),
+        ),
+        (
+            "UiPathChatGoogleGenerativeAI",
+            UiPathChatGoogleGenerativeAI(model="gemini-2.5-flash"),
+        ),
+        (
+            "UiPathChatAnthropic",
+            UiPathChatAnthropic(model="anthropic.claude-haiku-4-5-20251001-v1:0"),
+        ),
+        (
+            "UiPathChatAnthropicVertex",
+            UiPathChatAnthropicVertex(model="claude-haiku-4-5@20251001"),
+        ),
     ]
 
 
@@ -41,7 +64,9 @@ def format_error_message(error: str, max_length: int = 60) -> str:
 
 
 @tool
-def get_weather(location: str, unit: Literal["celsius", "fahrenheit"] = "celsius") -> str:
+def get_weather(
+    location: str, unit: Literal["celsius", "fahrenheit"] = "celsius"
+) -> str:
     """Get the current weather for a location.
 
     Args:
@@ -67,22 +92,24 @@ def calculate(expression: str) -> str:
 
 class PersonInfo(BaseModel):
     """Information about a person."""
+
     name: str = Field(description="The person's full name")
     age: int = Field(description="The person's age in years")
     city: str = Field(description="The city where the person lives")
 
 
 class TestResult:
     """Accumulates test metrics across all test runs."""
+
     def __init__(self):
         self.chunks = 0
         self.content_length = 0
         self.tool_calls = 0
 
     def add_response(self, response: Any) -> None:
-        if hasattr(response, 'content') and response.content:
+        if hasattr(response, "content") and response.content:
             self.content_length += len(response.content)
-        if hasattr(response, 'tool_calls') and response.tool_calls:
+        if hasattr(response, "tool_calls") and response.tool_calls:
             self.tool_calls += len(response.tool_calls)
 
     def add_chunks(self, count: int) -> None:
@@ -120,14 +147,15 @@ async def run_test_method(
 
 class GraphInput(BaseModel):
     """Input model for the testing graph."""
+
     prompt: str = Field(
-        default="Count from 1 to 5.",
-        description="The prompt to send to the LLM"
+        default="Count from 1 to 5.", description="The prompt to send to the LLM"
     )
 
 
 class GraphOutput(BaseModel):
     """Output model for the testing graph."""
+
     success: bool
     result_summary: str
     chunks_received: Optional[int] = None
@@ -137,6 +165,7 @@ class GraphOutput(BaseModel):
 
 class GraphState(MessagesState):
     """State model for the testing workflow."""
+
     prompt: str
     success: bool
     result_summary: str
@@ -177,7 +206,7 @@ async def test_single_model_all(
         ("invoke", False, False),
         ("ainvoke", True, False),
         ("stream", False, True),
-        ("astream", True, True)
+        ("astream", True, True),
     ]
 
     for method_name, is_async, is_streaming in test_methods:
@@ -192,7 +221,7 @@ async def test_single_model_all(
            model_results[method_name] = "✓"
 
    # Test tool calling
-    logger.info(f" Testing tool_calling...")
+    logger.info(" Testing tool_calling...")
    try:
        llm_with_tools = model.bind_tools(tools)
        chunks = []
@@ -203,20 +232,24 @@ async def test_single_model_all(
        for chunk in chunks:
            accumulated = chunk if accumulated is None else accumulated + chunk
 
-        if accumulated and hasattr(accumulated, 'tool_calls') and accumulated.tool_calls:
+        if (
+            accumulated
+            and hasattr(accumulated, "tool_calls")
+            and accumulated.tool_calls
+        ):
            tool_calls_count = len(accumulated.tool_calls)
            result.add_tool_calls(tool_calls_count)
            logger.info(f" Tool calls detected: {tool_calls_count}")
            model_results["tool_calling"] = f"✓ ({tool_calls_count} calls)"
        else:
-            logger.warning(f" No tool calls detected")
+            logger.warning(" No tool calls detected")
            model_results["tool_calling"] = "✗ No tool calls detected"
    except Exception as e:
        logger.error(f" Tool calling failed: {e}")
        model_results["tool_calling"] = f"✗ {format_error_message(str(e))}"
 
    # Test structured output
-    logger.info(f" Testing structured_output...")
+    logger.info(" Testing structured_output...")
    try:
        llm_with_structure = model.with_structured_output(PersonInfo)
        response = await llm_with_structure.ainvoke(structured_messages)
@@ -247,18 +280,28 @@ async def run_all_tests(state: GraphState) -> dict:
    """Run all tests for all chat models in parallel."""
    import asyncio
 
-    logger.info("="*80)
+    logger.info("=" * 80)
    logger.info("Running All Tests")
-    logger.info("="*80)
+    logger.info("=" * 80)
 
    models = create_test_models(max_tokens=2000)
    tools = [get_weather, calculate]
-    tool_messages = [HumanMessage(content="What's the weather in San Francisco? Also calculate 15 * 23.")]
-    structured_messages = [HumanMessage(content="Tell me about John Smith, a 35 year old software engineer living in New York.")]
+    tool_messages = [
+        HumanMessage(
+            content="What's the weather in San Francisco? Also calculate 15 * 23."
+        )
+    ]
+    structured_messages = [
+        HumanMessage(
+            content="Tell me about John Smith, a 35 year old software engineer living in New York."
+        )
+    ]
 
    # Run all models in parallel
    tasks = [
-        test_single_model_all(name, model, state["messages"], tools, tool_messages, structured_messages)
+        test_single_model_all(
+            name, model, state["messages"], tools, tool_messages, structured_messages
+        )
        for name, model in models
    ]
    results_list = await asyncio.gather(*tasks)
@@ -274,17 +317,34 @@ async def run_all_tests(state: GraphState) -> dict:
        total_result.tool_calls += result.tool_calls
 
    # Build summary
-    logger.info("="*80)
+    logger.info("=" * 80)
    summary_lines = []
-    for model_name in ["UiPathChatOpenAI", "UiPathChatGoogleGenerativeAI", "UiPathChatBedrockConverse", "UiPathChatBedrock", "UiPathChat", "UiPathAzureChatOpenAI"]:
+    for model_name in [
+        "UiPathChat",
+        "UiPathAzureChatOpenAI",
+        "UiPathChatBedrock",
+        "UiPathChatBedrockConverse",
+        "UiPathChatGoogleGenerativeAI",
+        "UiPathChatAnthropic",
+        "UiPathChatAnthropicVertex",
+    ]:
        if model_name in all_model_results:
            summary_lines.append(f"{model_name}:")
            results = all_model_results[model_name]
-            for test_name in ["invoke", "ainvoke", "stream", "astream", "tool_calling", "structured_output"]:
+            for test_name in [
+                "invoke",
+                "ainvoke",
+                "stream",
+                "astream",
+                "tool_calling",
+                "structured_output",
+            ]:
                if test_name in results:
                    summary_lines.append(f" {test_name}: {results[test_name]}")
 
-    has_failures = any("✗" in str(v) for r in all_model_results.values() for v in r.values())
+    has_failures = any(
+        "✗" in str(v) for r in all_model_results.values() for v in r.values()
+    )
 
    return {
        "success": not has_failures,
@@ -298,16 +358,16 @@ async def run_all_tests(state: GraphState) -> dict:
 
 async def return_results(state: GraphState) -> GraphOutput:
    """Return final test results."""
-    logger.info("="*80)
+    logger.info("=" * 80)
    logger.info("TEST RESULTS")
-    logger.info("="*80)
+    logger.info("=" * 80)
    logger.info(f"Success: {state['success']}")
    logger.info(f"Summary: {state['result_summary']}")
-    if state.get('chunks_received'):
+    if state.get("chunks_received"):
        logger.info(f"Chunks Received: {state['chunks_received']}")
-    if state.get('content_length'):
+    if state.get("content_length"):
        logger.info(f"Content Length: {state['content_length']}")
-    if state.get('tool_calls_count'):
+    if state.get("tool_calls_count"):
        logger.info(f"Tool Calls: {state['tool_calls_count']}")
 
    return GraphOutput(
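A quick way to sanity-check one of the migrated wrappers outside the test graph is a minimal standalone script along these lines. This is an illustrative sketch, not part of the commit: it assumes the uipath_langchain package is installed and UiPath credentials are configured in the environment, and it reuses the gpt-4o-2024-11-20 model name from the diff above.

# Minimal standalone check of a migrated chat wrapper (illustrative sketch;
# assumes UiPath credentials are available in the environment).
import asyncio

from langchain_core.messages import HumanMessage
from uipath_langchain.chat import UiPathChat


async def main() -> None:
    llm = UiPathChat(model="gpt-4o-2024-11-20")  # same model name as in the tests
    reply = await llm.ainvoke([HumanMessage(content="Count from 1 to 5.")])
    print(reply.content)


if __name__ == "__main__":
    asyncio.run(main())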

testcases/company-research-agent/src/graph.py
Lines changed: 3 additions & 3 deletions
@@ -2,8 +2,8 @@
 from langchain_community.tools import DuckDuckGoSearchResults
 from langgraph.graph import END, START, MessagesState, StateGraph
 from pydantic import BaseModel
+from uipath_langchain.chat import UiPathChat
 
-from uipath_langchain_client.clients.openai import UiPathAzureChatOpenAI
 # Configuration constants
 
 
@@ -37,9 +37,9 @@ def get_search_tool() -> DuckDuckGoSearchResults:
    """
 
 
-def create_llm() -> UiPathAzureChatOpenAI:
+def create_llm() -> UiPathChat:
    """Create and configure the language model."""
-    return UiPathAzureChatOpenAI(streaming=False)
+    return UiPathChat(model="gpt-4o-2024-11-20", streaming=False)
 
 
 def create_research_agent():
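A hedged smoke test of the updated factory could look like the sketch below; the prompt text is made up for illustration, and it again assumes uipath_langchain is installed and UiPath credentials are configured.

# Illustrative smoke test for create_llm() after the migration (not part of
# this commit; assumes UiPath credentials are configured in the environment).
from uipath_langchain.chat import UiPathChat


def create_llm() -> UiPathChat:
    """Mirror of the factory defined in the diff above."""
    return UiPathChat(model="gpt-4o-2024-11-20", streaming=False)


if __name__ == "__main__":
    llm = create_llm()
    print(llm.invoke("In one sentence, what does a research agent do?").content)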

testcases/ticket-classification/src/main.py
Lines changed: 17 additions & 9 deletions
@@ -7,13 +7,12 @@
 from langchain_core.messages import HumanMessage, SystemMessage
 from langchain_core.output_parsers import PydanticOutputParser
 from langgraph.checkpoint.memory import MemorySaver
-from langgraph.graph import END, START, StateGraph, MessagesState
+from langgraph.graph import END, START, MessagesState, StateGraph
 from langgraph.types import Command, interrupt
 from pydantic import BaseModel, Field
-
 from uipath.platform import UiPath
-from uipath.platform.common import CreateTask
-from uipath_langchain_client.clients.openai import UiPathAzureChatOpenAI
+
+from uipath_langchain.chat import UiPathChat
 
 # Configuration
 logger = logging.getLogger(__name__)
@@ -27,22 +26,26 @@
 TicketCategory = Literal["security", "error", "system", "billing", "performance"]
 NextNode = Literal["classify", "notify_team"]
 
+
 # Data Models
 class GraphInput(BaseModel):
    """Input model for the ticket classification graph."""
+
    message: str
    ticket_id: str
    assignee: str | None = None
 
 
 class GraphOutput(BaseModel):
    """Output model for the ticket classification graph."""
+
    label: str
    confidence: float
 
 
 class GraphState(MessagesState):
    """State model for the ticket classification workflow."""
+
    message: str
    ticket_id: str
    assignee: str | None
@@ -54,6 +57,7 @@ class GraphState(MessagesState):
 
 class TicketClassification(BaseModel):
    """Model for ticket classification results."""
+
    label: TicketCategory = Field(
        description="The classification label for the support ticket"
    )
@@ -90,6 +94,7 @@ def create_system_message() -> str:
        format_instructions=output_parser.get_format_instructions()
    )
 
+
 # Node Functions
 def prepare_input(graph_input: GraphInput) -> GraphState:
    """Prepare the initial state from graph input."""
@@ -99,7 +104,7 @@ def prepare_input(graph_input: GraphInput) -> GraphState:
        assignee=graph_input.assignee,
        messages=[
            SystemMessage(content=create_system_message()),
-            HumanMessage(content=graph_input.message)
+            HumanMessage(content=graph_input.message),
        ],
        last_predicted_category=None,
        human_approval=None,
@@ -112,9 +117,10 @@ def decide_next_node(state: GraphState) -> NextNode:
        return "notify_team"
    return "classify"
 
+
 async def classify(state: GraphState) -> Command:
    """Classify the support ticket using LLM."""
-    llm = UiPathAzureChatOpenAI(model="gpt-4o-mini-2024-07-18")
+    llm = UiPathChat(model="gpt-4o-mini-2024-07-18")
 
    # Add rejection message if there was a previous prediction
    if state.get("last_predicted_category"):
@@ -151,7 +157,10 @@ async def classify(state: GraphState) -> Command:
        }
    )
 
-def create_approval_message(ticket_id: str, ticket_message: str, label: str, confidence: float) -> str:
+
+def create_approval_message(
+    ticket_id: str, ticket_message: str, label: str, confidence: float
+) -> str:
    """Create formatted message for human approval."""
    return (
        f"This is how I classified the ticket: '{ticket_id}', "
@@ -170,8 +179,6 @@ async def wait_for_human(state: GraphState) -> Command:
    confidence = state["confidence"]
    is_resume = state.get("human_approval") is not None
 
-
-
    if not is_resume:
        logger.info("Waiting for human approval via regular interrupt")
        interrupt_message = (
@@ -187,6 +194,7 @@ async def wait_for_human(state: GraphState) -> Command:
        }
    )
 
+
 async def notify_team(state: GraphState) -> GraphOutput:
    """Send team notification and return final output."""
    logger.info("Sending team email notification")
