Commit 8161535

xyz

committed
1 parent dc9ca33 commit 8161535

3 files changed
Lines changed: 136 additions & 97 deletions

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -9,6 +9,7 @@ dependencies = [
     "llama-index>=0.14.8",
     "llama-index-embeddings-azure-openai>=0.4.1",
     "llama-index-llms-azure-openai>=0.4.2",
+    "llama-index-llms-google-genai>=0.8.0",
     "openinference-instrumentation-llama-index>=4.3.9",
     "uipath>=2.2.26, <2.3.0",
 ]
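
The added package is exactly what the import guard in the next file checks for. A minimal sanity-check sketch (the module path `llama_index.llms.google_genai` comes from the `_check_vertex_dependencies` change below):

```python
import importlib.util

# Mirrors the commit's dependency guard: find_spec returns None when
# llama-index-llms-google-genai is not installed in the environment.
if importlib.util.find_spec("llama_index.llms.google_genai") is None:
    raise ImportError("Missing package: llama-index-llms-google-genai")
print("llama-index-llms-google-genai is importable")
```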
Lines changed: 119 additions & 97 deletions
@@ -1,7 +1,10 @@
 import logging
 import os
-from typing import Optional
+from typing import Any, Optional
 
+import httpx
+from llama_index.core.callbacks import CallbackManager
+from llama_index.core.constants import DEFAULT_NUM_OUTPUTS, DEFAULT_TEMPERATURE
 from uipath.utils import EndpointManager
 
 from .supported_models import GeminiModels
@@ -15,8 +18,8 @@ def _check_vertex_dependencies() -> None:
 
     missing_packages = []
 
-    if importlib.util.find_spec("llama_index.llms.vertex") is None:
-        missing_packages.append("llama-index-llms-vertex")
+    if importlib.util.find_spec("llama_index.llms.google_genai") is None:
+        missing_packages.append("llama-index-llms-google-genai")
 
     if missing_packages:
         packages_str = ", ".join(missing_packages)
@@ -32,61 +35,62 @@ def _check_vertex_dependencies() -> None:
 
 _check_vertex_dependencies()
 
-from google.auth.credentials import AnonymousCredentials
-from google.cloud.aiplatform_v1beta1.services.prediction_service import (
-    PredictionServiceClient as v1beta1PredictionServiceClient,
-)
-from google.cloud.aiplatform_v1beta1.services.prediction_service.transports.rest import (
-    PredictionServiceRestTransport,
-)
+import google.genai
+import google.genai.types as genai_types
 from llama_index.core.bridge.pydantic import PrivateAttr
-from llama_index.llms.vertex import Vertex
+from llama_index.llms.google_genai import GoogleGenAI
 
 
-class CustomPredictionServiceRestTransport(PredictionServiceRestTransport):
-    """Custom REST transport that redirects requests to UiPath LLM Gateway."""
+class UiPathHttpxClient(httpx.Client):
+    """Custom httpx client that redirects generateContent requests to UiPath gateway."""
 
-    def __init__(self, llmgw_url: str, custom_headers: dict[str, str], **kwargs):
-        self.llmgw_url = llmgw_url
-        self.custom_headers = custom_headers or {}
-
-        kwargs.setdefault("credentials", AnonymousCredentials())
+    def __init__(self, gateway_url: str, **kwargs):
+        self.gateway_url = gateway_url
         super().__init__(**kwargs)
 
-        # Monkey-patch the session's request method to redirect to UiPath Gateway
-        # This preserves the session object identity while redirecting all requests
-        original_request = self._session.request
-
-        def redirected_request(method, url, **kwargs_inner):
-            headers = kwargs_inner.pop("headers", {})
-            headers.update(self.custom_headers)
-
-            # Remove Google's internal query parameters - UiPath gateway doesn't need them
-            kwargs_inner.pop("params", None)
-
-            is_streaming = kwargs_inner.get("stream", False)
-            headers["X-UiPath-Streaming-Enabled"] = "true" if is_streaming else "false"
-
-            return original_request(
-                method, self.llmgw_url, headers=headers, **kwargs_inner
+    def request(self, method: str, url: Any, **kwargs) -> httpx.Response:
+        """Override request to redirect generateContent/streamGenerateContent to UiPath gateway."""
+        url_str = str(url)
+        if "generateContent" in url_str or "streamGenerateContent" in url_str:
+            url = self.gateway_url
+        return super().request(method, url, **kwargs)
+
+    def send(self, request: httpx.Request, **kwargs) -> httpx.Response:
+        """Override send to redirect generateContent/streamGenerateContent to UiPath gateway."""
+        url_str = str(request.url)
+        if "generateContent" in url_str or "streamGenerateContent" in url_str:
+            is_streaming = "streamGenerateContent" in url_str
+            # Build headers with streaming flag and correct host
+            headers = dict(request.headers)
+            if is_streaming:
+                headers["X-UiPath-Streaming-Enabled"] = "true"
+            # Update host header to match the gateway URL
+            gateway_url_parsed = httpx.URL(self.gateway_url)
+            headers["host"] = gateway_url_parsed.host
+            # Create new request with rewritten URL
+            request = httpx.Request(
+                method=request.method,
+                url=self.gateway_url,
+                headers=headers,
+                content=request.content,
+                extensions=request.extensions,
             )
-
-        self._session.request = redirected_request  # type: ignore[method-assign]
+        return super().send(request, **kwargs)
 
 
-class UiPathVertex(Vertex):
+class UiPathVertex(GoogleGenAI):
     """
     UiPath Vertex AI LLM that routes requests through UiPath's LLM Gateway.
 
-    This class wraps LlamaIndex's Vertex class and redirects all API calls
+    This class wraps LlamaIndex's GoogleGenAI class and redirects all API calls
     to UiPath's LLM Gateway for authentication and routing.
 
     Args:
         org_id: UiPath organization ID. Falls back to UIPATH_ORGANIZATION_ID env var.
         tenant_id: UiPath tenant ID. Falls back to UIPATH_TENANT_ID env var.
        token: UiPath access token. Falls back to UIPATH_ACCESS_TOKEN env var.
        model: Model identifier. Defaults to gemini-2.5-flash.
-        **kwargs: Additional arguments passed to the Vertex base class.
+        **kwargs: Additional arguments passed to the GoogleGenAI base class.
 
     Example:
         ```python
@@ -110,7 +114,14 @@ def __init__(
         tenant_id: Optional[str] = None,
         token: Optional[str] = None,
         model: str = GeminiModels.gemini_2_5_flash,
-        **kwargs,
+        temperature: float = DEFAULT_TEMPERATURE,
+        max_tokens: Optional[int] = None,
+        context_window: Optional[int] = None,
+        max_retries: int = 3,
+        generation_config: Optional[genai_types.GenerateContentConfig] = None,
+        callback_manager: Optional[CallbackManager] = None,
+        is_function_calling_model: bool = True,
+        **kwargs: Any,
     ):
         org_id = org_id or os.getenv("UIPATH_ORGANIZATION_ID")
         tenant_id = tenant_id or os.getenv("UIPATH_TENANT_ID")
@@ -129,62 +140,72 @@ def __init__(
                 "UIPATH_ACCESS_TOKEN environment variable or token parameter is required"
             )
 
-        # Initialize base Vertex class with dummy credentials
-        # The actual auth is handled by UiPath Gateway
-        super().__init__(
-            model=model,
-            project=os.getenv("VERTEXAI_PROJECT", "none"),
-            location=os.getenv("VERTEXAI_LOCATION", "us-central1"),
-            credentials=AnonymousCredentials(),
-            **kwargs,
-        )
-
-        # Set private attributes after super().__init__
-        self._uipath_vendor = "vertexai"
-        self._uipath_model_name = model
-        self._uipath_url = None
-        self._uipath_token = token
-
-        # After super().__init__, self._client is a GenerativeModel instance
-        # We need to patch its _prediction_client to use our custom transport
-        self._patch_generative_model_client()
+        # Build UiPath gateway URL and headers
+        uipath_url = self._build_base_url_static(model)
+        headers = self._build_headers_static(token)
 
-    def _patch_generative_model_client(self) -> None:
-        """Patch the GenerativeModel's internal prediction client to use UiPath Gateway."""
-        llmgw_url = self._build_base_url()
-        custom_headers = self._build_headers(self._uipath_token)
-
-        # Create custom sync REST transport that routes to UiPath Gateway
-        sync_transport = CustomPredictionServiceRestTransport(
-            llmgw_url=llmgw_url, custom_headers=custom_headers
+        # Create custom httpx client that redirects requests to UiPath gateway
+        custom_httpx_client = UiPathHttpxClient(
+            gateway_url=uipath_url,
+            headers=headers,
+            follow_redirects=True,
         )
 
-        # Create the sync prediction client with our custom transport
-        custom_sync_client = v1beta1PredictionServiceClient(
-            transport=sync_transport,
+        # Configure HTTP options with our custom client
+        http_options = genai_types.HttpOptions(
+            httpxClient=custom_httpx_client,
        )
 
-        # Inject our custom client into the GenerativeModel instance
-        # This bypasses the cached_property and uses our client directly
-        # The GenerativeModel uses _prediction_client as @functools.cached_property
-        # Setting in __dict__ ensures it's used
-        if hasattr(self, "_client") and self._client is not None:
-            self._client.__dict__["_prediction_client"] = custom_sync_client
+        # Create google.genai client with custom httpx client
+        # We pass a dummy api_key since auth is handled by UiPath headers
+        client = google.genai.Client(
+            api_key="uipath-gateway",
+            http_options=http_options,
+        )
 
-        if hasattr(self, "_chat_client") and self._chat_client is not None:
-            self._chat_client.__dict__["_prediction_client"] = custom_sync_client
+        # Skip calling GoogleGenAI.__init__ which tries to fetch model metadata
+        # Instead, initialize the grandparent (FunctionCallingLLM) directly
+        # and set up the attributes ourselves
+        from llama_index.core.llms.function_calling import FunctionCallingLLM
 
-    @property
-    def _uipath_endpoint(self) -> str:
-        """Get the UiPath LLM Gateway endpoint for this model."""
-        vendor_endpoint = EndpointManager.get_vendor_endpoint()
-        formatted_endpoint = vendor_endpoint.format(
-            vendor=self._uipath_vendor,
-            model=self._uipath_model_name,
+        FunctionCallingLLM.__init__(
+            self,
+            callback_manager=callback_manager,
+            **kwargs,
         )
-        return formatted_endpoint
 
-    def _build_headers(self, token: str) -> dict[str, str]:
+        # Set GoogleGenAI public attributes
+        self.model = model
+        self.temperature = temperature
+        self.context_window = context_window
+        self.max_retries = max_retries
+        self.is_function_calling_model = is_function_calling_model
+        self.cached_content = None
+        self.built_in_tool = None
+        self.file_mode = "hybrid"
+
+        # Set GoogleGenAI private attributes
+        self._client = client
+        self._model_meta = None  # We skip model metadata fetch
+        self._max_tokens = max_tokens or DEFAULT_NUM_OUTPUTS
+
+        # Set up generation config
+        if generation_config:
+            self._generation_config = generation_config.model_dump()
+        else:
+            self._generation_config = genai_types.GenerateContentConfig(
+                temperature=temperature,
+                max_output_tokens=max_tokens,
+            ).model_dump()
+
+        # Set UiPath private attributes
+        self._uipath_vendor = "vertexai"
+        self._uipath_model_name = model
+        self._uipath_url = uipath_url
+        self._uipath_token = token
+
+    @staticmethod
+    def _build_headers_static(token: str) -> dict[str, str]:
         """Build HTTP headers for UiPath Gateway requests."""
        headers = {
            "Authorization": f"Bearer {token}",
@@ -195,16 +216,17 @@ def _build_headers(self, token: str) -> dict[str, str]:
            headers["X-UiPath-ProcessKey"] = process_key
        return headers
 
-    def _build_base_url(self) -> str:
+    @staticmethod
+    def _build_base_url_static(model: str) -> str:
         """Build the full URL for the UiPath LLM Gateway."""
-        if not self._uipath_url:
-            env_uipath_url = os.getenv("UIPATH_URL")
+        env_uipath_url = os.getenv("UIPATH_URL")
 
-            if env_uipath_url:
-                self._uipath_url = (
-                    f"{env_uipath_url.rstrip('/')}/{self._uipath_endpoint}"
-                )
-            else:
-                raise ValueError("UIPATH_URL environment variable is required")
+        if not env_uipath_url:
+            raise ValueError("UIPATH_URL environment variable is required")
 
-        return self._uipath_url
+        vendor_endpoint = EndpointManager.get_vendor_endpoint()
+        formatted_endpoint = vendor_endpoint.format(
+            vendor="vertexai",
+            model=model,
+        )
+        return f"{env_uipath_url.rstrip('/')}/{formatted_endpoint}"

uv.lock

Lines changed: 16 additions & 0 deletions
Some generated files are not rendered by default.
