Skip to content

Commit 72e2ae2

Browse files
cleop-googlecopybara-github
authored andcommitted
feat: GenAI SDK client(multimodal) - Add single_turn_template helper to GeminiRequestReadConfig.
PiperOrigin-RevId: 890422497
1 parent e164b19 commit 72e2ae2

2 files changed

Lines changed: 155 additions & 0 deletions

File tree

tests/unit/vertexai/genai/test_multimodal_datasets_genai.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"""Tests for multimodal datasets."""
1616

1717
from vertexai._genai import types
18+
from google.genai import types as genai_types
1819

1920

2021
class TestMultimodalDataset:
@@ -126,3 +127,59 @@ def test_set_bigquery_uri_preserves_other_fields(self):
126127
dataset.metadata.gemini_request_read_config.assembled_request_column_name
127128
== "test_column"
128129
)
130+
131+
132+
class TestGeminiRequestReadConfig:
133+
def test_single_turn_template(self):
134+
read_config = types.GeminiRequestReadConfig.single_turn_template(
135+
model="gemini-1.5-flash",
136+
prompt="test_prompt",
137+
response="test_response",
138+
system_instruction="test_system_instruction",
139+
cached_content="test_cached_content",
140+
tools={"function_declarations": [{"name": "test_tool"}]},
141+
tool_config={"function_calling_config": {"mode": "ANY"}},
142+
safety_settings={"category": "HARM_CATEGORY_DANGEROUS_CONTENT"},
143+
generation_config={"temperature": 0.5},
144+
field_mapping={"test_placeholder": "test_column"},
145+
)
146+
147+
expected_read_config = types.GeminiRequestReadConfig(
148+
template_config=types.GeminiTemplateConfig(
149+
gemini_example=types.GeminiExample(
150+
model="gemini-1.5-flash",
151+
contents=[
152+
genai_types.Content(
153+
role="user",
154+
parts=[genai_types.Part.from_text(text="test_prompt")],
155+
),
156+
genai_types.Content(
157+
role="model",
158+
parts=[genai_types.Part.from_text(text="test_response")],
159+
),
160+
],
161+
system_instruction=genai_types.Content(
162+
parts=[
163+
genai_types.Part.from_text(text="test_system_instruction")
164+
],
165+
),
166+
cached_content="test_cached_content",
167+
tools=genai_types.Tool(
168+
function_declarations=[
169+
genai_types.FunctionDeclaration(name="test_tool")
170+
]
171+
),
172+
tool_config=genai_types.ToolConfig(
173+
function_calling_config=genai_types.FunctionCallingConfig(
174+
mode="ANY"
175+
)
176+
),
177+
safety_settings=genai_types.SafetySetting(
178+
category="HARM_CATEGORY_DANGEROUS_CONTENT"
179+
),
180+
generation_config=genai_types.GenerationConfig(temperature=0.5),
181+
),
182+
field_mapping={"test_placeholder": "test_column"},
183+
),
184+
)
185+
assert read_config == expected_read_config

vertexai/_genai/types/common.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11873,6 +11873,104 @@ class GeminiRequestReadConfig(_common.BaseModel):
1187311873
description="""Column name in the underlying BigQuery table that contains already fully assembled Gemini requests.""",
1187411874
)
1187511875

11876+
@classmethod
11877+
def single_turn_template(
11878+
cls,
11879+
*,
11880+
prompt: Optional[str] = None,
11881+
response: Optional[str] = None,
11882+
system_instruction: Optional[str] = None,
11883+
model: Optional[str] = None,
11884+
cached_content: Optional[str] = None,
11885+
tools: Optional[list[Union[genai_types.Tool, dict[str, Any]]]] = None,
11886+
tool_config: Optional[Union[genai_types.ToolConfig, dict[str, Any]]] = None,
11887+
safety_settings: Optional[
11888+
list[Union[genai_types.SafetySetting, dict[str, Any]]]
11889+
] = None,
11890+
generation_config: Optional[
11891+
Union[genai_types.GenerationConfig, dict[str, Any]]
11892+
] = None,
11893+
field_mapping: Optional[dict[str, str]] = None,
11894+
) -> "GeminiRequestReadConfig":
11895+
"""Constructs a GeminiRequestReadConfig object for single-turn cases.
11896+
11897+
Example:
11898+
template_config = datasets.construct_single_turn_template(
11899+
prompt = "Which flower is this {flower_image} ?",
11900+
response="This is a {label}.",
11901+
system_instruction="You are a botanical classifier."
11902+
)
11903+
11904+
Args:
11905+
11906+
prompt (str):
11907+
Required. User input.
11908+
response (str):
11909+
Optional. Model response to user input.
11910+
system_instruction (str):
11911+
Optional. System instructions for the model.
11912+
model (str):
11913+
Optional. The model to use for the GeminiExample.
11914+
cached_content (str):
11915+
Optional. The cached content to use for the GeminiExample.
11916+
tools (List[Tool]):
11917+
Optional. The tools to use for the GeminiExample.
11918+
tool_config (ToolConfig):
11919+
Optional. The tool config to use for the GeminiExample.
11920+
safety_settings (List[SafetySetting]):
11921+
Optional. The safety settings to use for the GeminiExample.
11922+
generation_config (GenerationConfig):
11923+
Optional. The generation config to use for the GeminiExample.
11924+
field_mapping (dict[str, str]):
11925+
Optional. Mapping of placeholders to dataset columns.
11926+
11927+
Returns:
11928+
A GeminiRequestReadConfig object.
11929+
"""
11930+
contents = []
11931+
if prompt:
11932+
contents.append(
11933+
genai_types.Content(
11934+
role="user",
11935+
parts=[
11936+
genai_types.Part.from_text(text=prompt),
11937+
],
11938+
)
11939+
)
11940+
if response:
11941+
contents.append(
11942+
genai_types.Content(
11943+
role="model",
11944+
parts=[
11945+
genai_types.Part.from_text(text=response),
11946+
],
11947+
)
11948+
)
11949+
11950+
sys_inst = None
11951+
if system_instruction:
11952+
sys_inst = genai_types.Content(
11953+
parts=[
11954+
genai_types.Part.from_text(text=system_instruction),
11955+
],
11956+
)
11957+
11958+
return cls(
11959+
template_config=GeminiTemplateConfig(
11960+
gemini_example=GeminiExample(
11961+
model=model,
11962+
contents=contents,
11963+
system_instruction=sys_inst,
11964+
cached_content=cached_content,
11965+
tools=tools,
11966+
tool_config=tool_config,
11967+
safety_settings=safety_settings,
11968+
generation_config=generation_config,
11969+
),
11970+
field_mapping=field_mapping,
11971+
),
11972+
)
11973+
1187611974

1187711975
class GeminiRequestReadConfigDict(TypedDict, total=False):
1187811976
"""Represents the config for reading Gemini requests."""

0 commit comments

Comments
 (0)