Patch notes (text above the "diff --git" line is not part of the patch):

Adds "import json" and, in the three tests named in the hunk headers below,
replaces the predefined side_effect values (_MOCK_EXACT_MATCH_RESULT,
_MOCK_POINTWISE_RESULT, _MOCK_SUMMARIZATION_QUALITY_RESULT) on the patched
EvaluationServiceClient.evaluate_instances with local functions that build
the EvaluateInstancesResponse from the request they receive:

  * _exact_match_side_effect: score 1.0 when the first instance's prediction
    equals its reference, otherwise 0.0.
  * _pointwise_side_effect: parses the request's JSON instance; score 5 when
    its "abc" field is "test_prompt", otherwise 4.
  * _summarization_side_effect: parses the request's JSON instance; score 5
    when its "response" field is "test", otherwise 4.

NOTE(review): each helper reads the request with kwargs.get("request"), so
the patched method has to be called with request= as a keyword; otherwise
the helper fails with an AttributeError on None.
NOTE(review): line breaks and leading indentation were reconstructed from a
whitespace-collapsed copy of this patch; confirm with "git apply --check"
before applying.

diff --git a/tests/unit/vertexai/test_evaluation.py b/tests/unit/vertexai/test_evaluation.py
index 1e0b0f7d56..70775dfdc9 100644
--- a/tests/unit/vertexai/test_evaluation.py
+++ b/tests/unit/vertexai/test_evaluation.py
@@ -15,6 +15,7 @@
 # limitations under the License.
 #
 
+import json
 import re
 import sys
 import threading
@@ -851,11 +852,26 @@ def test_compute_exact_match_metric(self, api_transport):
         )
         test_metrics = ["exact_match"]
         test_eval_task = EvalTask(dataset=eval_dataset, metrics=test_metrics)
-        mock_metric_results = _MOCK_EXACT_MATCH_RESULT
+
+        def _exact_match_side_effect(**kwargs):
+            request = kwargs.get("request")
+            prediction = request.exact_match_input.instances[0].prediction
+            reference = request.exact_match_input.instances[0].reference
+            score = 1.0 if prediction == reference else 0.0
+            return gapic_evaluation_service_types.EvaluateInstancesResponse(
+                exact_match_results=gapic_evaluation_service_types.ExactMatchResults(
+                    exact_match_metric_values=[
+                        gapic_evaluation_service_types.ExactMatchMetricValue(
+                            score=score
+                        ),
+                    ]
+                )
+            )
+
         with mock.patch.object(
             target=gapic_evaluation_services.EvaluationServiceClient,
             attribute="evaluate_instances",
-            side_effect=mock_metric_results,
+            side_effect=_exact_match_side_effect,
         ):
             test_result = test_eval_task.evaluate()
 
@@ -932,11 +948,26 @@ def test_compute_pointwise_metrics_free_string(self):
             metrics=[_TEST_POINTWISE_METRIC_FREE_STRING],
             metric_column_mapping={"abc": "prompt"},
         )
-        mock_metric_results = _MOCK_POINTWISE_RESULT
+
+        def _pointwise_side_effect(**kwargs):
+            request = kwargs.get("request")
+            instance_data = json.loads(
+                request.pointwise_metric_input.instance.json_instance
+            )
+            # Row with prompt "test_prompt" gets score 5, "text_prompt" gets 4.
+            score = 5 if instance_data.get("abc") == "test_prompt" else 4
+            return gapic_evaluation_service_types.EvaluateInstancesResponse(
+                pointwise_metric_result=(
+                    gapic_evaluation_service_types.PointwiseMetricResult(
+                        score=score, explanation="explanation"
+                    )
+                )
+            )
+
         with mock.patch.object(
             target=gapic_evaluation_services.EvaluationServiceClient,
             attribute="evaluate_instances",
-            side_effect=mock_metric_results,
+            side_effect=_pointwise_side_effect,
         ):
             test_result = test_eval_task.evaluate()
 
@@ -1095,11 +1126,26 @@ def test_compute_pointwise_metrics_without_model_inference(self, api_transport):
         test_eval_task = EvalTask(
             dataset=_TEST_EVAL_DATASET_ALL_INCLUDED, metrics=test_metrics
         )
-        mock_metric_results = _MOCK_SUMMARIZATION_QUALITY_RESULT
+
+        def _summarization_side_effect(**kwargs):
+            request = kwargs.get("request")
+            instance_data = json.loads(
+                request.pointwise_metric_input.instance.json_instance
+            )
+            # Row with response "test" gets score 5, "text" gets score 4.
+            score = 5 if instance_data.get("response") == "test" else 4
+            return gapic_evaluation_service_types.EvaluateInstancesResponse(
+                pointwise_metric_result=(
+                    gapic_evaluation_service_types.PointwiseMetricResult(
+                        score=score, explanation="explanation"
+                    )
+                )
+            )
+
         with mock.patch.object(
             target=gapic_evaluation_services.EvaluationServiceClient,
             attribute="evaluate_instances",
-            side_effect=mock_metric_results,
+            side_effect=_summarization_side_effect,
         ):
             test_result = test_eval_task.evaluate()
 