Skip to content

Commit 657bb26

Browse files
jsondai authored and copybara-github committed
chore: GenAI Client(evals) - Fix run_inference producing duplicate 'response' column
PiperOrigin-RevId: 900941216
1 parent 04c5e02 commit 657bb26

2 files changed

Lines changed: 88 additions & 1 deletion

File tree

tests/unit/vertexai/genai/test_evals.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2241,6 +2241,79 @@ def mock_model_fn(contents):
22412241
assert inference_result.candidate_name == "mock_model_fn"
22422242
assert inference_result.gcs_source is None
22432243

2244+
@mock.patch.object(_evals_utils, "EvalDatasetLoader")
2245+
def test_inference_overwrites_existing_response_column_with_callable(
2246+
self, mock_eval_dataset_loader
2247+
):
2248+
"""Tests that run_inference overwrites an existing 'response' column."""
2249+
mock_df = pd.DataFrame(
2250+
{
2251+
"prompt": ["test prompt"],
2252+
"response": ["old response"],
2253+
}
2254+
)
2255+
mock_eval_dataset_loader.return_value.load.return_value = mock_df.to_dict(
2256+
orient="records"
2257+
)
2258+
2259+
def mock_model_fn(contents):
2260+
return "new response"
2261+
2262+
inference_result = self.client.evals.run_inference(
2263+
model=mock_model_fn,
2264+
src=mock_df,
2265+
)
2266+
2267+
result_df = inference_result.eval_dataset_df
2268+
# Assert there is exactly one 'response' column (no duplicates).
2269+
assert list(result_df.columns).count("response") == 1
2270+
# Assert the 'response' column contains the new inference result.
2271+
assert result_df["response"][0] == "new response"
2272+
assert "prompt" in result_df.columns
2273+
2274+
@mock.patch.object(_evals_common, "Models")
2275+
@mock.patch.object(_evals_utils, "EvalDatasetLoader")
2276+
def test_inference_overwrites_existing_response_column_with_gemini(
2277+
self, mock_eval_dataset_loader, mock_models
2278+
):
2279+
"""Tests that run_inference with Gemini overwrites an existing 'response' column."""
2280+
mock_df = pd.DataFrame(
2281+
{
2282+
"prompt": ["test prompt"],
2283+
"response": ["old response"],
2284+
}
2285+
)
2286+
mock_eval_dataset_loader.return_value.load.return_value = mock_df.to_dict(
2287+
orient="records"
2288+
)
2289+
2290+
mock_generate_content_response = genai_types.GenerateContentResponse(
2291+
candidates=[
2292+
genai_types.Candidate(
2293+
content=genai_types.Content(
2294+
parts=[genai_types.Part(text="new gemini response")]
2295+
),
2296+
finish_reason=genai_types.FinishReason.STOP,
2297+
)
2298+
],
2299+
prompt_feedback=None,
2300+
)
2301+
mock_models.return_value.generate_content.return_value = (
2302+
mock_generate_content_response
2303+
)
2304+
2305+
inference_result = self.client.evals.run_inference(
2306+
model="gemini-pro",
2307+
src=mock_df,
2308+
)
2309+
2310+
result_df = inference_result.eval_dataset_df
2311+
# Assert there is exactly one 'response' column (no duplicates).
2312+
assert list(result_df.columns).count("response") == 1
2313+
# Assert the 'response' column contains the new inference result.
2314+
assert result_df["response"][0] == "new gemini response"
2315+
assert "prompt" in result_df.columns
2316+
22442317
@mock.patch.object(_evals_common, "Models")
22452318
@mock.patch.object(_evals_utils, "EvalDatasetLoader")
22462319
def test_inference_with_prompt_template(

vertexai/_genai/_evals_common.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -949,11 +949,25 @@ def _run_inference_internal(
949949

950950
results_df_responses_only = pd.DataFrame(
951951
{
952-
"response": responses,
952+
_evals_constant.RESPONSE: responses,
953953
}
954954
)
955955

956956
prompt_dataset_indexed = prompt_dataset.reset_index(drop=True)
957+
958+
# Drop existing 'response' column to prevent duplicate column names when
959+
# re-running inference on a dataset that already has responses.
960+
if _evals_constant.RESPONSE in prompt_dataset_indexed.columns:
961+
logger.warning(
962+
"A column named '%s' already exists in the prompt dataset. "
963+
"The existing column will be dropped and replaced with the new "
964+
"inference results.",
965+
_evals_constant.RESPONSE,
966+
)
967+
prompt_dataset_indexed = prompt_dataset_indexed.drop(
968+
columns=[_evals_constant.RESPONSE]
969+
)
970+
957971
results_df_responses_only_indexed = results_df_responses_only.reset_index(drop=True)
958972

959973
results_df = pd.concat(

0 commit comments

Comments (0)