11import logging
22import os
3- from typing import Optional
3+ from typing import Any , Optional
44
5+ import httpx
6+ from llama_index .core .callbacks import CallbackManager
7+ from llama_index .core .constants import DEFAULT_NUM_OUTPUTS , DEFAULT_TEMPERATURE
58from uipath .utils import EndpointManager
69
710from .supported_models import GeminiModels
@@ -15,8 +18,8 @@ def _check_vertex_dependencies() -> None:
1518
1619 missing_packages = []
1720
18- if importlib .util .find_spec ("llama_index.llms.vertex " ) is None :
19- missing_packages .append ("llama-index-llms-vertex " )
21+ if importlib .util .find_spec ("llama_index.llms.google_genai " ) is None :
22+ missing_packages .append ("llama-index-llms-google-genai " )
2023
2124 if missing_packages :
2225 packages_str = ", " .join (missing_packages )
@@ -32,61 +35,62 @@ def _check_vertex_dependencies() -> None:
3235
3336_check_vertex_dependencies ()
3437
35- from google .auth .credentials import AnonymousCredentials
36- from google .cloud .aiplatform_v1beta1 .services .prediction_service import (
37- PredictionServiceClient as v1beta1PredictionServiceClient ,
38- )
39- from google .cloud .aiplatform_v1beta1 .services .prediction_service .transports .rest import (
40- PredictionServiceRestTransport ,
41- )
38+ import google .genai
39+ import google .genai .types as genai_types
4240from llama_index .core .bridge .pydantic import PrivateAttr
43- from llama_index .llms .vertex import Vertex
41+ from llama_index .llms .google_genai import GoogleGenAI
4442
4543
class UiPathHttpxClient(httpx.Client):
    """Custom httpx client that reroutes Gemini content-generation calls to the UiPath LLM Gateway.

    Only ``send`` is overridden.  Every synchronous httpx call — including
    ``Client.request`` — is ultimately dispatched through ``send``, so doing the
    rewrite there covers both entry points.  (Rewriting the URL earlier, in a
    ``request`` override, would hand ``send`` a URL that no longer contains
    ``generateContent``, silently skipping the streaming-flag and Host fixes.)
    """

    def __init__(self, gateway_url: str, **kwargs):
        """Initialize the client.

        Args:
            gateway_url: Fully-built UiPath LLM Gateway URL that replaces the
                Google ``generateContent`` / ``streamGenerateContent`` endpoint.
            **kwargs: Forwarded unchanged to ``httpx.Client`` (headers,
                timeouts, redirect policy, ...).
        """
        self.gateway_url = gateway_url
        super().__init__(**kwargs)

    def send(self, request: httpx.Request, **kwargs) -> httpx.Response:
        """Redirect generateContent/streamGenerateContent requests to the UiPath gateway.

        Any other request (e.g. model metadata lookups) passes through
        untouched.
        """
        url_str = str(request.url)
        if "generateContent" in url_str or "streamGenerateContent" in url_str:
            is_streaming = "streamGenerateContent" in url_str
            headers = dict(request.headers)
            # Set the streaming flag explicitly in BOTH cases so the gateway
            # never has to fall back to a default for non-streaming calls.
            headers["X-UiPath-Streaming-Enabled"] = "true" if is_streaming else "false"
            # The Host header was computed from the Google endpoint when the
            # request was built; it must match the gateway instead.
            headers["host"] = httpx.URL(self.gateway_url).host
            # Rebuild the request against the gateway URL, preserving method,
            # body and extensions.
            # NOTE(review): ``request.content`` assumes the body is fully
            # buffered (non-streamed upload) — confirm streamed request bodies
            # are never used on this path.
            request = httpx.Request(
                method=request.method,
                url=self.gateway_url,
                headers=headers,
                content=request.content,
                extensions=request.extensions,
            )
        return super().send(request, **kwargs)
7579
7680
77- class UiPathVertex (Vertex ):
81+ class UiPathVertex (GoogleGenAI ):
7882 """
7983 UiPath Vertex AI LLM that routes requests through UiPath's LLM Gateway.
8084
81- This class wraps LlamaIndex's Vertex class and redirects all API calls
85+ This class wraps LlamaIndex's GoogleGenAI class and redirects all API calls
8286 to UiPath's LLM Gateway for authentication and routing.
8387
8488 Args:
8589 org_id: UiPath organization ID. Falls back to UIPATH_ORGANIZATION_ID env var.
8690 tenant_id: UiPath tenant ID. Falls back to UIPATH_TENANT_ID env var.
8791 token: UiPath access token. Falls back to UIPATH_ACCESS_TOKEN env var.
8892 model: Model identifier. Defaults to gemini-2.5-flash.
89- **kwargs: Additional arguments passed to the Vertex base class.
93+ **kwargs: Additional arguments passed to the GoogleGenAI base class.
9094
9195 Example:
9296 ```python
@@ -110,7 +114,14 @@ def __init__(
110114 tenant_id : Optional [str ] = None ,
111115 token : Optional [str ] = None ,
112116 model : str = GeminiModels .gemini_2_5_flash ,
113- ** kwargs ,
117+ temperature : float = DEFAULT_TEMPERATURE ,
118+ max_tokens : Optional [int ] = None ,
119+ context_window : Optional [int ] = None ,
120+ max_retries : int = 3 ,
121+ generation_config : Optional [genai_types .GenerateContentConfig ] = None ,
122+ callback_manager : Optional [CallbackManager ] = None ,
123+ is_function_calling_model : bool = True ,
124+ ** kwargs : Any ,
114125 ):
115126 org_id = org_id or os .getenv ("UIPATH_ORGANIZATION_ID" )
116127 tenant_id = tenant_id or os .getenv ("UIPATH_TENANT_ID" )
@@ -129,62 +140,72 @@ def __init__(
129140 "UIPATH_ACCESS_TOKEN environment variable or token parameter is required"
130141 )
131142
132- # Initialize base Vertex class with dummy credentials
133- # The actual auth is handled by UiPath Gateway
134- super ().__init__ (
135- model = model ,
136- project = os .getenv ("VERTEXAI_PROJECT" , "none" ),
137- location = os .getenv ("VERTEXAI_LOCATION" , "us-central1" ),
138- credentials = AnonymousCredentials (),
139- ** kwargs ,
140- )
141-
142- # Set private attributes after super().__init__
143- self ._uipath_vendor = "vertexai"
144- self ._uipath_model_name = model
145- self ._uipath_url = None
146- self ._uipath_token = token
147-
148- # After super().__init__, self._client is a GenerativeModel instance
149- # We need to patch its _prediction_client to use our custom transport
150- self ._patch_generative_model_client ()
143+ # Build UiPath gateway URL and headers
144+ uipath_url = self ._build_base_url_static (model )
145+ headers = self ._build_headers_static (token )
151146
152- def _patch_generative_model_client (self ) -> None :
153- """Patch the GenerativeModel's internal prediction client to use UiPath Gateway."""
154- llmgw_url = self ._build_base_url ()
155- custom_headers = self ._build_headers (self ._uipath_token )
156-
157- # Create custom sync REST transport that routes to UiPath Gateway
158- sync_transport = CustomPredictionServiceRestTransport (
159- llmgw_url = llmgw_url , custom_headers = custom_headers
147+ # Create custom httpx client that redirects requests to UiPath gateway
148+ custom_httpx_client = UiPathHttpxClient (
149+ gateway_url = uipath_url ,
150+ headers = headers ,
151+ follow_redirects = True ,
160152 )
161153
162- # Create the sync prediction client with our custom transport
163- custom_sync_client = v1beta1PredictionServiceClient (
164- transport = sync_transport ,
154+ # Configure HTTP options with our custom client
155+ http_options = genai_types . HttpOptions (
156+ httpxClient = custom_httpx_client ,
165157 )
166158
167- # Inject our custom client into the GenerativeModel instance
168- # This bypasses the cached_property and uses our client directly
169- # The GenerativeModel uses _prediction_client as @functools.cached_property
170- # Setting in __dict__ ensures it's used
171- if hasattr ( self , "_client" ) and self . _client is not None :
172- self . _client . __dict__ [ "_prediction_client" ] = custom_sync_client
159+ # Create google.genai client with custom httpx client
160+ # We pass a dummy api_key since auth is handled by UiPath headers
161+ client = google . genai . Client (
162+ api_key = "uipath-gateway" ,
163+ http_options = http_options ,
164+ )
173165
174- if hasattr (self , "_chat_client" ) and self ._chat_client is not None :
175- self ._chat_client .__dict__ ["_prediction_client" ] = custom_sync_client
166+ # Skip calling GoogleGenAI.__init__ which tries to fetch model metadata
167+ # Instead, initialize the grandparent (FunctionCallingLLM) directly
168+ # and set up the attributes ourselves
169+ from llama_index .core .llms .function_calling import FunctionCallingLLM
176170
177- @property
178- def _uipath_endpoint (self ) -> str :
179- """Get the UiPath LLM Gateway endpoint for this model."""
180- vendor_endpoint = EndpointManager .get_vendor_endpoint ()
181- formatted_endpoint = vendor_endpoint .format (
182- vendor = self ._uipath_vendor ,
183- model = self ._uipath_model_name ,
171+ FunctionCallingLLM .__init__ (
172+ self ,
173+ callback_manager = callback_manager ,
174+ ** kwargs ,
184175 )
185- return formatted_endpoint
186176
187- def _build_headers (self , token : str ) -> dict [str , str ]:
177+ # Set GoogleGenAI public attributes
178+ self .model = model
179+ self .temperature = temperature
180+ self .context_window = context_window
181+ self .max_retries = max_retries
182+ self .is_function_calling_model = is_function_calling_model
183+ self .cached_content = None
184+ self .built_in_tool = None
185+ self .file_mode = "hybrid"
186+
187+ # Set GoogleGenAI private attributes
188+ self ._client = client
189+ self ._model_meta = None # We skip model metadata fetch
190+ self ._max_tokens = max_tokens or DEFAULT_NUM_OUTPUTS
191+
192+ # Set up generation config
193+ if generation_config :
194+ self ._generation_config = generation_config .model_dump ()
195+ else :
196+ self ._generation_config = genai_types .GenerateContentConfig (
197+ temperature = temperature ,
198+ max_output_tokens = max_tokens ,
199+ ).model_dump ()
200+
201+ # Set UiPath private attributes
202+ self ._uipath_vendor = "vertexai"
203+ self ._uipath_model_name = model
204+ self ._uipath_url = uipath_url
205+ self ._uipath_token = token
206+
207+ @staticmethod
208+ def _build_headers_static (token : str ) -> dict [str , str ]:
188209 """Build HTTP headers for UiPath Gateway requests."""
189210 headers = {
190211 "Authorization" : f"Bearer { token } " ,
@@ -195,16 +216,17 @@ def _build_headers(self, token: str) -> dict[str, str]:
195216 headers ["X-UiPath-ProcessKey" ] = process_key
196217 return headers
197218
198- def _build_base_url (self ) -> str :
219+ @staticmethod
220+ def _build_base_url_static (model : str ) -> str :
199221 """Build the full URL for the UiPath LLM Gateway."""
200- if not self ._uipath_url :
201- env_uipath_url = os .getenv ("UIPATH_URL" )
222+ env_uipath_url = os .getenv ("UIPATH_URL" )
202223
203- if env_uipath_url :
204- self ._uipath_url = (
205- f"{ env_uipath_url .rstrip ('/' )} /{ self ._uipath_endpoint } "
206- )
207- else :
208- raise ValueError ("UIPATH_URL environment variable is required" )
224+ if not env_uipath_url :
225+ raise ValueError ("UIPATH_URL environment variable is required" )
209226
210- return self ._uipath_url
227+ vendor_endpoint = EndpointManager .get_vendor_endpoint ()
228+ formatted_endpoint = vendor_endpoint .format (
229+ vendor = "vertexai" ,
230+ model = model ,
231+ )
232+ return f"{ env_uipath_url .rstrip ('/' )} /{ formatted_endpoint } "
0 commit comments