Skip to content

Commit 0bff8d4

Browse files
committed
Merge branch 'feat/add-chat-models' of https://github.com/UiPath/uipath-llamaindex-python into feat/add-chat-models
2 parents 103deb1 + 9ba6735 commit 0bff8d4

3 files changed

Lines changed: 89 additions & 212 deletions

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,3 +177,4 @@ cython_debug/
177177
**/.uipath
178178
**/**.nupkg
179179
**/__uipath/
180+
.claude/settings.local.json

playground.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
from llama_index.core.llms import ChatMessage
2+
from uipath_llamaindex.llms import UiPathVertex, GeminiModels
3+
4+
5+
def _run_check(name: str, fn) -> str:
    """Run one smoke check, print its outcome, and return "PASS" or "FAIL".

    Args:
        name: Label used in the console output.
        fn: Zero-argument callable that performs the call and returns the
            text to display on success.

    Returns:
        "PASS" if ``fn`` ran without raising, "FAIL" otherwise.
    """
    print(f"Testing {name}...")
    try:
        output = fn()
        print(f" {name}: {output.strip()}")
        return "PASS"
    # Broad catch is deliberate: this is a manual smoke script and we want
    # every method exercised even if an earlier one fails.
    except Exception as e:
        print(f" {name}: FAILED - {e}")
        return "FAIL"


def test_all_methods():
    """Smoke-test the four core UiPathVertex entry points.

    Exercises complete, chat, stream_complete and stream_chat against a
    live UiPath LLM Gateway endpoint, then prints a PASS/FAIL summary.
    Requires valid UiPath credentials in the environment.
    """
    llm = UiPathVertex(model=GeminiModels.gemini_2_5_flash, max_tokens=1024)
    prompt = "What is 2+2? Answer in one word."
    messages = [ChatMessage(role="user", content=prompt)]

    # Map each method name to a callable producing the text to display.
    # Streaming deltas can be None (e.g. empty first/last chunks), so they
    # are coalesced to "" before joining to avoid a TypeError.
    checks = {
        "complete": lambda: llm.complete(prompt).text,
        "chat": lambda: llm.chat(messages).message.content,
        "stream_complete": lambda: "".join(
            chunk.delta or "" for chunk in llm.stream_complete(prompt)
        ),
        "stream_chat": lambda: "".join(
            chunk.delta or "" for chunk in llm.stream_chat(messages)
        ),
    }

    results = {name: _run_check(name, fn) for name, fn in checks.items()}

    # Print summary
    print("\n" + "=" * 50)
    print("SUMMARY")
    print("=" * 50)

    passed = sum(1 for v in results.values() if v == "PASS")
    failed = sum(1 for v in results.values() if v == "FAIL")

    for method, status in results.items():
        icon = "+" if status == "PASS" else "x"
        print(f" [{icon}] {method}: {status}")

    print("-" * 50)
    print(f" Total: {len(results)} | Passed: {passed} | Failed: {failed}")
    print("=" * 50)


if __name__ == "__main__":
    test_all_methods()

src/uipath_llamaindex/llms/vertex.py

Lines changed: 14 additions & 212 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,6 @@ def _check_vertex_dependencies() -> None:
1818
if importlib.util.find_spec("llama_index.llms.vertex") is None:
1919
missing_packages.append("llama-index-llms-vertex")
2020

21-
if importlib.util.find_spec("httpx") is None:
22-
missing_packages.append("httpx")
23-
2421
if missing_packages:
2522
packages_str = ", ".join(missing_packages)
2623
raise ImportError(
@@ -35,126 +32,17 @@ def _check_vertex_dependencies() -> None:
3532

3633
_check_vertex_dependencies()
3734

38-
import httpx
39-
from google.auth.aio import credentials as aio_credentials
4035
from google.auth.credentials import AnonymousCredentials
4136
from google.cloud.aiplatform_v1beta1.services.prediction_service import (
4237
PredictionServiceClient as v1beta1PredictionServiceClient,
4338
)
4439
from google.cloud.aiplatform_v1beta1.services.prediction_service.transports.rest import (
4540
PredictionServiceRestTransport,
4641
)
47-
from google.cloud.aiplatform_v1beta1.services.prediction_service.transports.rest_asyncio import (
48-
AsyncPredictionServiceRestTransport,
49-
)
5042
from llama_index.core.bridge.pydantic import PrivateAttr
5143
from llama_index.llms.vertex import Vertex
5244

5345

54-
class AsyncAnonymousCredentials(aio_credentials.Credentials):
55-
"""Async-compatible anonymous credentials that don't provide authentication.
56-
57-
Used to satisfy Google's credential requirements while we handle
58-
authentication separately via UiPath Gateway headers.
59-
"""
60-
61-
# Class-level attribute to ensure _token exists before parent __init__ runs
62-
_token: Optional[str] = None
63-
64-
def __init__(self) -> None:
65-
"""Initialize credentials, ensuring _token is set before parent init."""
66-
# Set instance attribute before parent __init__ tries to set self.token
67-
self._token = None
68-
# Now call parent - it will call self.token = None which uses our setter
69-
super().__init__()
70-
71-
@property
72-
def token(self) -> Optional[str]:
73-
return self._token
74-
75-
@token.setter
76-
def token(self, value: Optional[str]) -> None:
77-
self._token = value
78-
79-
@property
80-
def expired(self) -> bool:
81-
return False
82-
83-
@property
84-
def valid(self) -> bool:
85-
return True
86-
87-
async def refresh(self, _request) -> None:
88-
"""No-op refresh for anonymous credentials."""
89-
pass
90-
91-
92-
class CustomSyncSession:
93-
"""
94-
Custom sync session that redirects all requests to UiPath LLM Gateway.
95-
96-
Uses httpx for HTTP requests, bypassing Google's AuthorizedSession.
97-
"""
98-
99-
def __init__(self, llmgw_url: str, custom_headers: dict[str, str]):
100-
self.llmgw_url = llmgw_url
101-
self.custom_headers = custom_headers or {}
102-
self._client = httpx.Client()
103-
104-
def request(self, method: str, url: str, **kwargs):
105-
"""Make an HTTP request, redirecting to UiPath gateway."""
106-
# Get headers from kwargs or use empty dict
107-
headers = kwargs.pop("headers", {}) or {}
108-
109-
# Update with our custom headers (including Authorization)
110-
headers.update(self.custom_headers)
111-
112-
# Detect streaming from kwargs
113-
is_streaming = kwargs.get("stream", False)
114-
headers["X-UiPath-Streaming-Enabled"] = "true" if is_streaming else "false"
115-
116-
# Convert 'data' to 'content' for httpx
117-
if "data" in kwargs:
118-
kwargs["content"] = kwargs.pop("data")
119-
120-
# Make request to our gateway URL instead of the original URL
121-
response = self._client.request(
122-
method, self.llmgw_url, headers=headers, **kwargs
123-
)
124-
125-
# Return a response wrapper compatible with Google's expectations
126-
return HttpxSyncResponseWrapper(response)
127-
128-
@property
129-
def verify(self):
130-
return self._client._transport._pool._ssl_context is not None
131-
132-
@verify.setter
133-
def verify(self, _value):
134-
# httpx doesn't support changing verify after creation
135-
pass
136-
137-
def close(self):
138-
self._client.close()
139-
140-
141-
class HttpxSyncResponseWrapper:
142-
"""Wrapper to make httpx sync response compatible with requests.Response interface."""
143-
144-
def __init__(self, response: httpx.Response):
145-
self._response = response
146-
self.status_code = response.status_code
147-
self.headers = dict(response.headers)
148-
self.content = response.content
149-
self.text = response.text
150-
151-
def json(self):
152-
return self._response.json()
153-
154-
def raise_for_status(self):
155-
self._response.raise_for_status()
156-
157-
15846
class CustomPredictionServiceRestTransport(PredictionServiceRestTransport):
15947
"""Custom REST transport that redirects requests to UiPath LLM Gateway."""
16048

@@ -165,9 +53,6 @@ def __init__(self, llmgw_url: str, custom_headers: dict[str, str], **kwargs):
16553
kwargs.setdefault("credentials", AnonymousCredentials())
16654
super().__init__(**kwargs)
16755

168-
# Disable SSL verification
169-
self._session.verify = False
170-
17156
# Monkey-patch the session's request method to redirect to UiPath Gateway
17257
# This preserves the session object identity while redirecting all requests
17358
original_request = self._session.request
@@ -176,6 +61,9 @@ def redirected_request(method, url, **kwargs_inner):
17661
headers = kwargs_inner.pop("headers", {})
17762
headers.update(self.custom_headers)
17863

64+
# Remove Google's internal query parameters - UiPath gateway doesn't need them
65+
kwargs_inner.pop("params", None)
66+
17967
is_streaming = kwargs_inner.get("stream", False)
18068
headers["X-UiPath-Streaming-Enabled"] = "true" if is_streaming else "false"
18169

@@ -184,92 +72,6 @@ def redirected_request(method, url, **kwargs_inner):
18472
self._session.request = redirected_request # type: ignore[method-assign]
18573

18674

187-
class CustomAsyncSession:
188-
"""
189-
Custom async session for redirecting requests to UiPath LLM Gateway.
190-
191-
Uses httpx for async HTTP requests, bypassing Google's AsyncAuthorizedSession.
192-
"""
193-
194-
def __init__(self, llmgw_url: str, custom_headers: dict[str, str]):
195-
self.llmgw_url = llmgw_url
196-
self.custom_headers = custom_headers or {}
197-
198-
async def request(
199-
self,
200-
method: str,
201-
url: str,
202-
data: Optional[bytes] = None,
203-
headers: Optional[dict] = None,
204-
**kwargs,
205-
):
206-
"""Make an async HTTP request, redirecting to UiPath gateway."""
207-
request_headers = dict(headers) if headers else {}
208-
209-
# Update with our custom headers (including Authorization)
210-
request_headers.update(self.custom_headers)
211-
212-
# Detect streaming from URL pattern
213-
is_streaming = "stream" in url.lower()
214-
request_headers["X-UiPath-Streaming-Enabled"] = (
215-
"true" if is_streaming else "false"
216-
)
217-
218-
async with httpx.AsyncClient() as client:
219-
response = await client.request(
220-
method,
221-
self.llmgw_url,
222-
content=data,
223-
headers=request_headers,
224-
**kwargs,
225-
)
226-
# Return a response wrapper compatible with Google's expectations
227-
return HttpxAsyncResponseWrapper(
228-
status=response.status_code,
229-
headers=dict(response.headers),
230-
body=response.content,
231-
)
232-
233-
async def close(self):
234-
"""Close the session (no-op for our implementation)."""
235-
pass
236-
237-
238-
class HttpxAsyncResponseWrapper:
239-
"""Wrapper to make httpx async response compatible with Google's expected interface."""
240-
241-
def __init__(self, status: int, headers: dict, body: bytes):
242-
self.status = status
243-
self.headers = headers
244-
self._body = body
245-
246-
async def read(self) -> bytes:
247-
"""Read the response body."""
248-
return self._body
249-
250-
async def content(self) -> bytes:
251-
"""Read the response content."""
252-
return self._body
253-
254-
255-
class CustomAsyncPredictionServiceRestTransport(AsyncPredictionServiceRestTransport):
256-
"""Custom async REST transport that redirects requests to UiPath LLM Gateway."""
257-
258-
def __init__(self, llmgw_url: str, custom_headers: dict[str, str], **kwargs):
259-
self.llmgw_url = llmgw_url
260-
self.custom_headers = custom_headers or {}
261-
262-
# Use async-compatible credentials for the async transport
263-
kwargs.setdefault("credentials", AsyncAnonymousCredentials())
264-
super().__init__(**kwargs)
265-
266-
# Replace the session with a custom one that redirects requests
267-
self._session = CustomAsyncSession(
268-
llmgw_url=llmgw_url,
269-
custom_headers=self.custom_headers,
270-
)
271-
272-
27375
class UiPathVertex(Vertex):
27476
"""
27577
UiPath Vertex AI LLM that routes requests through UiPath's LLM Gateway.
@@ -350,25 +152,25 @@ def _patch_generative_model_client(self) -> None:
350152
llmgw_url = self._build_base_url()
351153
custom_headers = self._build_headers(self._uipath_token)
352154

353-
# Create custom REST transport that routes to UiPath Gateway
354-
rest_transport = CustomPredictionServiceRestTransport(
155+
# Create custom sync REST transport that routes to UiPath Gateway
156+
sync_transport = CustomPredictionServiceRestTransport(
355157
llmgw_url=llmgw_url, custom_headers=custom_headers
356158
)
357159

358-
# Create the prediction client with our custom transport
359-
custom_prediction_client = v1beta1PredictionServiceClient(
360-
transport=rest_transport,
160+
# Create the sync prediction client with our custom transport
161+
custom_sync_client = v1beta1PredictionServiceClient(
162+
transport=sync_transport,
361163
)
362164

363165
# Inject our custom client into the GenerativeModel instance
364166
# This bypasses the cached_property and uses our client directly
365-
# The GenerativeModel uses _prediction_client as a @functools.cached_property
366-
# Setting it in __dict__ ensures it's used instead of creating a new one
367-
if hasattr(self, '_client') and self._client is not None:
368-
self._client.__dict__['_prediction_client'] = custom_prediction_client
167+
# The GenerativeModel uses _prediction_client as @functools.cached_property
168+
# Setting in __dict__ ensures it's used
169+
if hasattr(self, "_client") and self._client is not None:
170+
self._client.__dict__["_prediction_client"] = custom_sync_client
369171

370-
if hasattr(self, '_chat_client') and self._chat_client is not None:
371-
self._chat_client.__dict__['_prediction_client'] = custom_prediction_client
172+
if hasattr(self, "_chat_client") and self._chat_client is not None:
173+
self._chat_client.__dict__["_prediction_client"] = custom_sync_client
372174

373175
@property
374176
def _uipath_endpoint(self) -> str:

0 commit comments

Comments
 (0)