Skip to content

Commit 484f24d

Browse files
committed
add vertex
1 parent 494ce8e commit 484f24d

4 files changed

Lines changed: 363 additions & 1 deletion

File tree

pyproject.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,11 @@ bedrock = [
3030
"llama-index-llms-bedrock-converse>=0.3.0",
3131
"boto3>=1.28.0",
3232
]
33+
vertex = [
34+
"llama-index-llms-vertex>=0.4.0",
35+
"google-cloud-aiplatform>=1.38.0",
36+
"aiohttp>=3.8.0",
37+
]
3338

3439
[project.entry-points."uipath.middlewares"]
3540
register = "uipath_llamaindex.middlewares:register_middleware"

src/uipath_llamaindex/llms/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,16 @@
22
OpenAIModel,
33
UiPathOpenAI,
44
)
5+
from .supported_models import (
6+
BedrockModels,
7+
GeminiModels,
8+
OpenAIModels,
9+
)
510

611
__all__ = [
712
"UiPathOpenAI",
813
"OpenAIModel",
14+
"OpenAIModels",
15+
"GeminiModels",
16+
"BedrockModels",
917
]
Lines changed: 339 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,339 @@
1+
import logging
2+
import os
3+
from typing import Optional, Union
4+
5+
import aiohttp
6+
from uipath.utils import EndpointManager
7+
8+
from .supported_models import GeminiModels
9+
10+
logger = logging.getLogger(__name__)
11+
12+
13+
def _check_vertex_dependencies() -> None:
14+
"""Check if required dependencies for UiPath Vertex LLMs are installed."""
15+
import importlib.util
16+
17+
missing_packages = []
18+
19+
if importlib.util.find_spec("llama_index.llms.vertex") is None:
20+
missing_packages.append("llama-index-llms-vertex")
21+
22+
if missing_packages:
23+
packages_str = ", ".join(missing_packages)
24+
raise ImportError(
25+
f"The following packages are required to use UiPath Vertex LLMs: {packages_str}\n"
26+
"Please install them using one of the following methods:\n\n"
27+
" # Using pip:\n"
28+
f" pip install uipath-llamaindex[vertex]\n\n"
29+
" # Using uv:\n"
30+
f" uv add 'uipath-llamaindex[vertex]'\n\n"
31+
)
32+
33+
34+
_check_vertex_dependencies()
35+
36+
from google.auth.credentials import AnonymousCredentials
37+
from google.cloud.aiplatform_v1.services.prediction_service import (
38+
PredictionServiceAsyncClient as v1PredictionServiceAsyncClient,
39+
)
40+
from google.cloud.aiplatform_v1.services.prediction_service import (
41+
PredictionServiceClient as v1PredictionServiceClient,
42+
)
43+
from google.cloud.aiplatform_v1beta1.services.prediction_service import (
44+
PredictionServiceAsyncClient as v1beta1PredictionServiceAsyncClient,
45+
)
46+
from google.cloud.aiplatform_v1beta1.services.prediction_service import (
47+
PredictionServiceClient as v1beta1PredictionServiceClient,
48+
)
49+
from google.cloud.aiplatform_v1beta1.services.prediction_service.transports.rest import (
50+
PredictionServiceRestTransport,
51+
)
52+
from llama_index.llms.vertex import Vertex
53+
54+
55+
class CustomPredictionServiceRestTransport(PredictionServiceRestTransport):
    """Custom REST transport that redirects requests to UiPath LLM Gateway.

    Works by monkey-patching the underlying ``requests`` session created by
    the base transport: every outgoing call keeps its HTTP method and body
    but has its URL replaced with the gateway URL and the UiPath headers
    merged in.
    """

    def __init__(self, llmgw_url: str, custom_headers: dict[str, str], **kwargs):
        # Gateway endpoint that every request is rewritten to target.
        self.llmgw_url = llmgw_url
        self.custom_headers = custom_headers or {}

        # Real authentication is done by the gateway headers, so satisfy
        # google-auth with a no-op credential unless the caller supplied one.
        kwargs.setdefault("credentials", AnonymousCredentials())
        super().__init__(**kwargs)

        # Disable SSL verification for testing
        # NOTE(review): verify=False disables TLS certificate checking for
        # ALL gateway traffic — confirm this is not meant to ship to
        # production before release.
        self._session.verify = False

        # Keep a reference to the untouched session method; the wrapper below
        # delegates to it after rewriting URL and headers.
        original_request = self._session.request

        def redirected_request(method, url, **kwargs_inner):
            # Merge caller-supplied headers with the UiPath auth/context headers.
            headers = kwargs_inner.pop("headers", {})
            headers.update(self.custom_headers)

            # Tell the gateway whether the caller expects a streamed response.
            is_streaming = kwargs_inner.get("stream", False)
            headers["X-UiPath-Streaming-Enabled"] = "true" if is_streaming else "false"

            # The Google endpoint `url` is intentionally ignored: everything
            # goes to the UiPath gateway instead.
            return original_request(
                method, self.llmgw_url, headers=headers, **kwargs_inner
            )

        self._session.request = redirected_request  # type: ignore[method-assign, assignment]
82+
83+
84+
class CustomPredictionServiceRestAsyncTransport:
    """
    Custom async transport for calling UiPath LLM Gateway.

    Uses aiohttp for REST/HTTP communication instead of gRPC.
    Handles both regular and streaming responses from the gateway.
    """

    def __init__(self, llmgw_url: str, custom_headers: dict[str, str], **kwargs):
        # Gateway endpoint every request is POSTed to.
        self.llmgw_url = llmgw_url
        self.custom_headers = custom_headers or {}

    def _serialize_request(self, request) -> str:
        """Convert proto-plus request to JSON string."""
        import json

        from proto import Message as ProtoMessage

        if isinstance(request, ProtoMessage):
            # proto-plus message: go through a dict so json.dumps works.
            request_dict = type(request).to_dict(
                request, preserving_proto_field_name=False
            )
            return json.dumps(request_dict)
        else:
            # Raw protobuf message.
            from google.protobuf.json_format import MessageToJson

            return MessageToJson(request, preserving_proto_field_name=False)

    def _get_response_class(self, request):
        """Get the response class corresponding to the request class.

        Relies on the Google API naming convention that ``FooRequest`` and
        ``FooResponse`` are declared in the same module. Returns ``None``
        when no matching response class exists.
        """
        import importlib

        response_class_name = request.__class__.__name__.replace("Request", "Response")
        # BUGFIX: the previous code first called getattr() on
        # request.__class__.__module__, which is a *string* (the dotted module
        # name), not a module object, so that lookup always returned None.
        # Import the module and resolve the attribute on it directly.
        module = importlib.import_module(request.__class__.__module__)
        return getattr(module, response_class_name, None)

    def _deserialize_response(self, response_json: str, request):
        """Convert JSON string to proto-plus response object."""
        import json

        from proto import Message as ProtoMessage

        response_class = self._get_response_class(request)

        if response_class and isinstance(request, ProtoMessage):
            return response_class.from_json(response_json, ignore_unknown_fields=True)
        elif response_class:
            from google.protobuf.json_format import Parse

            return Parse(response_json, response_class(), ignore_unknown_fields=True)
        else:
            # No known response type: fall back to a plain dict.
            return json.loads(response_json)

    async def _make_request(self, request_json: str, streaming: bool = False):
        """Make HTTP POST request to UiPath gateway.

        Raises:
            Exception: when the gateway returns a non-200 status, with the
                status code and response body in the message.
        """
        headers = self.custom_headers.copy()
        headers["Content-Type"] = "application/json"

        if streaming:
            headers["X-UiPath-Streaming-Enabled"] = "true"

        # NOTE(review): ssl=False disables TLS certificate verification for
        # gateway traffic — confirm this is not intended for production.
        connector = aiohttp.TCPConnector(ssl=False)
        async with aiohttp.ClientSession(connector=connector) as session:
            async with session.post(
                self.llmgw_url, headers=headers, data=request_json
            ) as response:
                if response.status != 200:
                    error_text = await response.text()
                    raise Exception(f"HTTP {response.status}: {error_text}")

                return await response.text()

    async def generate_content(self, request, **kwargs):
        """Handle non-streaming generate_content calls."""
        request_json = self._serialize_request(request)
        response_text = await self._make_request(request_json, streaming=False)
        return self._deserialize_response(response_text, request)

    def stream_generate_content(self, request, **kwargs):
        """
        Handle streaming generate_content calls.

        Returns a coroutine that yields an async iterator.
        """
        return self._create_stream_awaitable(request)

    async def _create_stream_awaitable(self, request):
        """Awaitable wrapper that returns the async generator."""
        return self._stream_implementation(request)

    async def _stream_implementation(self, request):
        """
        Async generator that yields streaming response chunks.

        Parses the array and yields each chunk individually.
        """
        import json

        request_json = self._serialize_request(request)
        response_text = await self._make_request(request_json, streaming=True)

        try:
            chunks_array = json.loads(response_text)
            if isinstance(chunks_array, list):
                logger.info(f"Streaming: yielding {len(chunks_array)} chunks")
                for chunk_data in chunks_array:
                    # Re-serialize each chunk so the deserializer sees JSON text.
                    chunk_json = json.dumps(chunk_data)
                    yield self._deserialize_response(chunk_json, request)
                return
        except Exception as e:
            logger.info(f"Not a JSON array, trying single response: {e}")

        # Fallback: treat the whole body as one non-streamed response.
        try:
            yield self._deserialize_response(response_text, request)
        except Exception as e:
            logger.error(f"Failed to parse streaming response: {e}")
207+
208+
209+
class UiPathVertex(Vertex):
    """
    UiPath Vertex AI LLM that routes requests through UiPath's LLM Gateway.

    This class wraps LlamaIndex's Vertex class and redirects all API calls
    to UiPath's LLM Gateway for authentication and routing.

    Args:
        org_id: UiPath organization ID. Falls back to UIPATH_ORGANIZATION_ID env var.
        tenant_id: UiPath tenant ID. Falls back to UIPATH_TENANT_ID env var.
        token: UiPath access token. Falls back to UIPATH_ACCESS_TOKEN env var.
        model: Model identifier. Defaults to gemini-2.5-flash.
        **kwargs: Additional arguments passed to the Vertex base class.

    Raises:
        ValueError: if org_id, tenant_id, token, or the UIPATH_URL env var
            cannot be resolved.

    Example:
        ```python
        from uipath_llamaindex.llms import UiPathVertex, GeminiModels

        llm = UiPathVertex(model=GeminiModels.gemini_2_5_pro)
        response = llm.complete("What is the capital of France?")
        ```
    """

    # Custom transports that reroute traffic to the UiPath gateway; both are
    # created at the end of __init__.
    _transport: Optional[CustomPredictionServiceRestTransport] = None
    _async_transport: Optional[CustomPredictionServiceRestAsyncTransport] = None
    # Lazily-built sync client (see the _client property). The async side
    # bypasses the Google client entirely (_aclient returns the transport),
    # so _async_client is declared but never assigned in this class.
    _sync_client: Optional[
        Union[v1beta1PredictionServiceClient, v1PredictionServiceClient]
    ] = None
    _async_client: Optional[
        Union[v1beta1PredictionServiceAsyncClient, v1PredictionServiceAsyncClient]
    ] = None
    # Cached gateway URL and the headers attached to every gateway request.
    _llmgw_url: Optional[str] = None
    _custom_headers: Optional[dict[str, str]] = None

    def __init__(
        self,
        org_id: Optional[str] = None,
        tenant_id: Optional[str] = None,
        token: Optional[str] = None,
        model: str = GeminiModels.gemini_2_5_flash,
        **kwargs,
    ):
        # Explicit arguments win; environment variables are the fallback.
        org_id = org_id or os.getenv("UIPATH_ORGANIZATION_ID")
        tenant_id = tenant_id or os.getenv("UIPATH_TENANT_ID")
        token = token or os.getenv("UIPATH_ACCESS_TOKEN")

        if not org_id:
            raise ValueError(
                "UIPATH_ORGANIZATION_ID environment variable or org_id parameter is required"
            )
        if not tenant_id:
            raise ValueError(
                "UIPATH_TENANT_ID environment variable or tenant_id parameter is required"
            )
        if not token:
            raise ValueError(
                "UIPATH_ACCESS_TOKEN environment variable or token parameter is required"
            )

        # NOTE(review): these attributes are set BEFORE super().__init__();
        # presumably the (Pydantic-based) LlamaIndex base class permits
        # assigning private attributes pre-init — confirm, since _build_base_url
        # below depends on _vendor/_model_name via the endpoint property.
        self._vendor = "vertexai"
        self._model_name = model
        self._url: Optional[str] = None

        self._llmgw_url = self._build_base_url()
        self._custom_headers = self._build_headers(token)

        # Initialize base Vertex class with dummy credentials
        # The actual auth is handled by UiPath Gateway
        super().__init__(
            model=model,
            project=os.getenv("VERTEXAI_PROJECT", "none"),
            location=os.getenv("VERTEXAI_LOCATION", "us-central1"),
            credentials=AnonymousCredentials(),
            **kwargs,
        )

        # Set up custom transports
        self._transport = CustomPredictionServiceRestTransport(
            llmgw_url=self._llmgw_url, custom_headers=self._custom_headers
        )

        self._async_transport = CustomPredictionServiceRestAsyncTransport(
            llmgw_url=self._llmgw_url, custom_headers=self._custom_headers
        )

    @property
    def endpoint(self) -> str:
        """Get the UiPath LLM Gateway endpoint for this model."""
        # EndpointManager supplies a template with {vendor}/{model} slots.
        vendor_endpoint = EndpointManager.get_vendor_endpoint()
        formatted_endpoint = vendor_endpoint.format(
            vendor=self._vendor,
            model=self._model_name,
        )
        return formatted_endpoint

    def _build_headers(self, token: str) -> dict[str, str]:
        """Build HTTP headers for UiPath Gateway requests.

        Always includes the bearer token; job/process context headers are
        added only when the corresponding env vars are set (i.e. when
        running inside a UiPath job).
        """
        headers = {
            "Authorization": f"Bearer {token}",
        }
        if job_key := os.getenv("UIPATH_JOB_KEY"):
            headers["X-UiPath-JobKey"] = job_key
        if process_key := os.getenv("UIPATH_PROCESS_KEY"):
            headers["X-UiPath-ProcessKey"] = process_key
        return headers

    def _build_base_url(self) -> str:
        """Build the full URL for the UiPath LLM Gateway.

        Computed once and cached on self._url.

        Raises:
            ValueError: if the UIPATH_URL environment variable is unset.
        """
        if not self._url:
            env_uipath_url = os.getenv("UIPATH_URL")

            if env_uipath_url:
                self._url = f"{env_uipath_url.rstrip('/')}/{self.endpoint}"
            else:
                raise ValueError("UIPATH_URL environment variable is required")

        return self._url

    @property
    def _client(self):
        """Get the sync prediction client with custom transport."""
        # Lazily constructed; always the v1beta1 client wrapping the
        # redirecting REST transport built in __init__.
        if self._sync_client is None:
            self._sync_client = v1beta1PredictionServiceClient(
                transport=self._transport,
            )
        return self._sync_client

    @property
    def _aclient(self):
        """Get the async prediction client (uses custom async transport)."""
        # The async path skips the Google client and talks to the gateway
        # through the aiohttp-based transport directly.
        return self._async_transport

0 commit comments

Comments
 (0)