@@ -2241,6 +2241,79 @@ def mock_model_fn(contents):
22412241 assert inference_result .candidate_name == "mock_model_fn"
22422242 assert inference_result .gcs_source is None
22432243
@mock.patch.object(_evals_utils, "EvalDatasetLoader")
def test_inference_overwrites_existing_response_column_with_callable(
    self, mock_eval_dataset_loader
):
    """Checks that a callable model's output replaces a pre-existing 'response'.

    The input frame already carries a 'response' column. After
    run_inference, the resulting frame must contain exactly one 'response'
    column holding the callable's return value, and must still contain the
    'prompt' column.
    """
    seed_rows = [{"prompt": "test prompt", "response": "old response"}]
    input_df = pd.DataFrame(seed_rows)

    # The patched loader hands back the same rows in record form.
    loader_instance = mock_eval_dataset_loader.return_value
    loader_instance.load.return_value = input_df.to_dict(orient="records")

    # Name kept as-is: this function object is what gets passed as the model.
    def mock_model_fn(contents):
        return "new response"

    outcome = self.client.evals.run_inference(model=mock_model_fn, src=input_df)
    output_df = outcome.eval_dataset_df

    # Exactly one 'response' column may exist (no duplicates).
    column_names = list(output_df.columns)
    assert column_names.count("response") == 1

    # That column must carry the freshly generated value.
    response_values = output_df["response"]
    assert response_values[0] == "new response"
    assert "prompt" in output_df.columns
2273+
@mock.patch.object(_evals_common, "Models")
@mock.patch.object(_evals_utils, "EvalDatasetLoader")
def test_inference_overwrites_existing_response_column_with_gemini(
    self, mock_eval_dataset_loader, mock_models
):
    """Checks that a Gemini model's output replaces a pre-existing 'response'.

    The patched Models object returns a canned GenerateContentResponse.
    After run_inference, the resulting frame must contain exactly one
    'response' column holding the canned text, and must still contain the
    'prompt' column.
    """
    seed_rows = [{"prompt": "test prompt", "response": "old response"}]
    input_df = pd.DataFrame(seed_rows)

    # The patched loader hands back the same rows in record form.
    loader_instance = mock_eval_dataset_loader.return_value
    loader_instance.load.return_value = input_df.to_dict(orient="records")

    # Assemble the canned model reply from the inside out.
    reply_part = genai_types.Part(text="new gemini response")
    reply_content = genai_types.Content(parts=[reply_part])
    reply_candidate = genai_types.Candidate(
        content=reply_content,
        finish_reason=genai_types.FinishReason.STOP,
    )
    canned_response = genai_types.GenerateContentResponse(
        candidates=[reply_candidate],
        prompt_feedback=None,
    )
    mock_models.return_value.generate_content.return_value = canned_response

    outcome = self.client.evals.run_inference(model="gemini-pro", src=input_df)
    output_df = outcome.eval_dataset_df

    # Exactly one 'response' column may exist (no duplicates).
    column_names = list(output_df.columns)
    assert column_names.count("response") == 1

    # That column must carry the freshly generated value.
    response_values = output_df["response"]
    assert response_values[0] == "new gemini response"
    assert "prompt" in output_df.columns
2316+
22442317 @mock .patch .object (_evals_common , "Models" )
22452318 @mock .patch .object (_evals_utils , "EvalDatasetLoader" )
22462319 def test_inference_with_prompt_template (
0 commit comments