
Commit ba5020d

vertex-sdk-bot authored and copybara-github committed
feat: Add support for metric_resource_name in rubric generation
PiperOrigin-RevId: 880805629
1 parent 1ecaa9b commit ba5020d

6 files changed

Lines changed: 157 additions & 10 deletions
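In short: a metric registered as a Vertex AI EvaluationMetric resource can now drive rubric generation directly. A minimal usage sketch, assuming the `vertexai.Client` / `vertexai.types` entry points of the GenAI SDK; the project ID, metric ID, and client setup are placeholders, not taken from this commit:

    import pandas as pd
    from vertexai import Client, types  # assumed entry points

    client = Client(project="my-project", location="us-central1")  # placeholder setup
    prompts_df = pd.DataFrame({"prompt": ["Explain relativity in one sentence."]})

    # A metric previously registered as an EvaluationMetric resource (placeholder ID).
    metric = types.Metric(
        name="my_custom_metric",
        metric_resource_name="projects/my-project/locations/us-central1/evaluationMetrics/123",
    )

    # The resource name is forwarded to the backend, so no prompt_template
    # or predefined_spec_name is needed.
    eval_dataset = client.evals.generate_rubrics(
        src=prompts_df,
        rubric_group_name="my_registered_rubrics",
        metric=metric,
    )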


tests/unit/vertexai/genai/replays/test_create_evaluation_run.py

Lines changed: 22 additions & 0 deletions
@@ -238,6 +238,28 @@ def test_create_eval_run_with_inference_configs(client):
     assert evaluation_run.error is None


+def test_create_eval_run_with_metric_resource_name(client):
+    """Tests create_evaluation_run with metric_resource_name."""
+    client._api_client._http_options.api_version = "v1beta1"
+    client._api_client._http_options.base_url = (
+        "https://us-central1-autopush-aiplatform.sandbox.googleapis.com/"
+    )
+    metric_resource_name = "projects/977012026409/locations/us-central1/evaluationMetrics/6048334299558576128"
+    metric = types.EvaluationRunMetric(
+        metric="my_custom_metric",
+        metric_resource_name=metric_resource_name,
+    )
+    evaluation_run = client.evals.create_evaluation_run(
+        dataset=types.EvaluationDataset(
+            eval_dataset_df=INPUT_DF_WITH_CONTEXT_AND_HISTORY
+        ),
+        metrics=[metric],
+        dest=GCS_DEST,
+    )
+    assert isinstance(evaluation_run, types.EvaluationRun)
+    assert evaluation_run.evaluation_config.metrics[0].metric == "my_custom_metric"
+
+
 # Dataframe tests fail in replay mode because of UUID generation mismatch.
 # def test_create_eval_run_data_source_evaluation_dataset(client):
 #     """Tests that create_evaluation_run() creates a correctly structured

tests/unit/vertexai/genai/replays/test_public_generate_rubrics.py

Lines changed: 38 additions & 5 deletions
@@ -143,19 +143,21 @@
 User prompt:
 {prompt}"""

-
-def test_public_method_generate_rubrics(client):
-    """Tests the public generate_rubrics method."""
-    prompts_df = pd.DataFrame(
+_PROMPTS_DF = pd.DataFrame(
     {
         "prompt": [
             "Explain the theory of relativity in one sentence.",
             "Write a short poem about a cat.",
         ]
     }
 )
+
+
+def test_public_method_generate_rubrics(client):
+    """Tests the public generate_rubrics method."""
+
     eval_dataset = client.evals.generate_rubrics(
-        src=prompts_df,
+        src=_PROMPTS_DF,
         prompt_template=_TEST_RUBRIC_GENERATION_PROMPT,
         rubric_group_name="text_quality_rubrics",
     )
@@ -176,6 +178,37 @@ def test_public_method_generate_rubrics(client):
     assert isinstance(first_rubric_group["text_quality_rubrics"][0], types.evals.Rubric)


+def test_public_method_generate_rubrics_with_metric(client):
+    """Tests the public generate_rubrics method with a metric."""
+    client._api_client._http_options.api_version = "v1beta1"
+    client._api_client._http_options.base_url = (
+        "https://us-central1-staging-aiplatform.sandbox.googleapis.com/"
+    )
+    metric_resource_name = "projects/977012026409/locations/us-central1/evaluationMetrics/6048334299558576128"
+    metric = types.Metric(
+        name="my_custom_metric",
+        metric_resource_name=metric_resource_name,
+    )
+    eval_dataset = client.evals.generate_rubrics(
+        src=_PROMPTS_DF,
+        rubric_group_name="my_registered_rubrics",
+        metric=metric,
+    )
+    eval_dataset_df = eval_dataset.eval_dataset_df
+
+    assert isinstance(eval_dataset, types.EvaluationDataset)
+    assert isinstance(eval_dataset_df, pd.DataFrame)
+    assert "rubric_groups" in eval_dataset_df.columns
+    assert len(eval_dataset_df) == 2
+
+    first_rubric_group = eval_dataset_df["rubric_groups"][0]
+    assert isinstance(first_rubric_group, dict)
+    assert "my_registered_rubrics" in first_rubric_group
+    assert isinstance(first_rubric_group["my_registered_rubrics"], list)
+    assert first_rubric_group["my_registered_rubrics"]
+    assert isinstance(first_rubric_group["my_registered_rubrics"][0], types.evals.Rubric)
+
+
 pytestmark = pytest_helper.setup(
     file=__file__,
     globals_for_file=globals(),

vertexai/_genai/_evals_common.py

Lines changed: 12 additions & 1 deletion
@@ -45,6 +45,7 @@
 from . import _gcs_utils
 from . import evals
 from . import types
+from . import _transformers as t

 logger = logging.getLogger(__name__)

@@ -1328,7 +1329,7 @@ def _resolve_dataset_inputs(


 def _resolve_evaluation_run_metrics(
-    metrics: list[types.EvaluationRunMetric], api_client: Any
+    metrics: Union[list[types.EvaluationRunMetric], list[types.Metric]], api_client: Any
 ) -> list[types.EvaluationRunMetric]:
     """Resolves a list of evaluation run metric instances, loading RubricMetric if necessary."""
     if not metrics:
@@ -1361,6 +1362,16 @@
                     e,
                 )
                 raise
+        elif isinstance(metric_instance, types.Metric):
+            config_dict = t.t_metrics([metric_instance])[0]
+            res_name = config_dict.pop("metric_resource_name", None)
+            resolved_metrics_list.append(
+                types.EvaluationRunMetric(
+                    metric=metric_instance.name,
+                    metric_config=config_dict if config_dict else None,
+                    metric_resource_name=res_name,
+                )
+            )
         else:
             try:
                 metric_name_str = str(metric_instance)
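A practical consequence of the change above: create_evaluation_run can now be fed plain types.Metric objects, which the new elif branch converts via t_metrics. A hedged sketch of the two equivalent call styles, assuming a configured `client`, a DataFrame `df`, and placeholder resource IDs:

    # Style 1: the explicit EvaluationRunMetric used by the test above.
    run_metric = types.EvaluationRunMetric(
        metric="my_custom_metric",
        metric_resource_name="projects/p/locations/l/evaluationMetrics/m",
    )

    # Style 2: a plain Metric, resolved through the new elif branch; the
    # resource name is popped from the t_metrics payload and any remaining
    # payload keys ride along as metric_config.
    plain_metric = types.Metric(
        name="my_custom_metric",
        metric_resource_name="projects/p/locations/l/evaluationMetrics/m",
    )

    evaluation_run = client.evals.create_evaluation_run(
        dataset=types.EvaluationDataset(eval_dataset_df=df),
        metrics=[plain_metric],  # a list of EvaluationRunMetric works the same way
        dest="gs://my-bucket/eval-output",  # assumed GCS destination
    )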

vertexai/_genai/_transformers.py

Lines changed: 5 additions & 0 deletions
@@ -38,6 +38,8 @@ def t_metrics(

     for metric in metrics:
         metric_payload_item: dict[str, Any] = {}
+        if hasattr(metric, "metric_resource_name") and metric.metric_resource_name:
+            metric_payload_item["metric_resource_name"] = metric.metric_resource_name

         metric_name = getv(metric, ["name"]).lower()

@@ -79,6 +81,9 @@
                 "return_raw_output": return_raw_output
             }
             metric_payload_item["pointwise_metric_spec"] = pointwise_spec
+        elif "metric_resource_name" in metric_payload_item:
+            # Valid case: the metric is identified by resource name; no inline spec is required.
+            pass
         else:
             raise ValueError(
                 f"Unsupported metric type or invalid metric name: {metric_name}"

vertexai/_genai/evals.py

Lines changed: 59 additions & 4 deletions
@@ -392,6 +392,13 @@ def _EvaluationRunMetric_from_vertex(
     if getv(from_object, ["metric"]) is not None:
         setv(to_object, ["metric"], getv(from_object, ["metric"]))

+    if getv(from_object, ["metricResourceName"]) is not None:
+        setv(
+            to_object,
+            ["metric_resource_name"],
+            getv(from_object, ["metricResourceName"]),
+        )
+
     if getv(from_object, ["metricConfig"]) is not None:
         setv(
             to_object,
@@ -410,6 +417,13 @@ def _EvaluationRunMetric_to_vertex(
     if getv(from_object, ["metric"]) is not None:
         setv(to_object, ["metric"], getv(from_object, ["metric"]))

+    if getv(from_object, ["metric_resource_name"]) is not None:
+        setv(
+            to_object,
+            ["metricResourceName"],
+            getv(from_object, ["metric_resource_name"]),
+        )
+
     if getv(from_object, ["metric_config"]) is not None:
         setv(
             to_object,
@@ -512,6 +526,13 @@ def _GenerateInstanceRubricsRequest_to_vertex(
         ),
     )

+    if getv(from_object, ["metric_resource_name"]) is not None:
+        setv(
+            to_object,
+            ["metricResourceName"],
+            getv(from_object, ["metric_resource_name"]),
+        )
+
     if getv(from_object, ["config"]) is not None:
         setv(to_object, ["config"], getv(from_object, ["config"]))

@@ -1049,6 +1070,7 @@ def _generate_rubrics(
             types.PredefinedMetricSpecOrDict
         ] = None,
         rubric_generation_spec: Optional[types.RubricGenerationSpecOrDict] = None,
+        metric_resource_name: Optional[str] = None,
         config: Optional[types.RubricGenerationConfigOrDict] = None,
     ) -> types.GenerateInstanceRubricsResponse:
         """
@@ -1059,6 +1081,7 @@ def _generate_rubrics(
             contents=contents,
             predefined_rubric_generation_spec=predefined_rubric_generation_spec,
             rubric_generation_spec=rubric_generation_spec,
+            metric_resource_name=metric_resource_name,
             config=config,
         )

@@ -1561,16 +1584,20 @@ def generate_rubrics(
         rubric_type_ontology: Optional[list[str]] = None,
         predefined_spec_name: Optional[Union[str, "types.PrebuiltMetric"]] = None,
         metric_spec_parameters: Optional[dict[str, Any]] = None,
+        metric: Optional[types.MetricOrDict] = None,
         config: Optional[types.RubricGenerationConfigOrDict] = None,
     ) -> types.EvaluationDataset:
         """Generates rubrics for each prompt in the source and adds them as a new column
         structured as a dictionary.

         You can generate rubrics by providing either:
-        1. A `predefined_spec_name` to use a Vertex AI backend recipe.
-        2. A `prompt_template` along with other configuration parameters
+        1. A `metric` to use a pre-registered metric resource.
+        2. A `predefined_spec_name` to use a Vertex AI backend recipe.
+        3. A `prompt_template` along with other configuration parameters
            (`generator_model_config`, `rubric_content_type`, `rubric_type_ontology`)
            for custom rubric generation.
+        `metric` takes precedence over `predefined_spec_name`, which in turn
+        takes precedence over `prompt_template`.

         These two modes are mutually exclusive.

@@ -1600,6 +1627,9 @@
             metric_spec_parameters: Optional. Parameters for the Predefined Metric,
               used to customize rubric generation. Only used if `predefined_spec_name` is set.
               Example: {"guidelines": ["The response must be in Japanese."]}
+            metric: Optional. A types.Metric object containing a metric_resource_name,
+              or a resource name string. If provided, this will take precedence over
+              predefined_spec_name and prompt_template.
             config: Optional. Configuration for the rubric generation process.

         Returns:
@@ -1639,10 +1669,32 @@
         )
         all_rubric_groups: list[dict[str, list[types.Rubric]]] = []

+        actual_metric_resource_name = None
+        if metric:
+            if isinstance(metric, str) and metric.startswith("projects/"):
+                actual_metric_resource_name = metric
+            else:
+                metric_obj = (
+                    types.Metric.model_validate(metric)
+                    if isinstance(metric, dict)
+                    else metric
+                )
+                actual_metric_resource_name = getattr(
+                    metric_obj, "metric_resource_name", None
+                )
+                if not actual_metric_resource_name:
+                    raise ValueError(
+                        "The provided Metric object must have metric_resource_name set."
+                    )
+
         rubric_gen_spec = None
         predefined_spec = None

-        if predefined_spec_name:
+        if actual_metric_resource_name:
+            # Precedence: the registered metric resource overrides everything else.
+            predefined_spec = None
+            rubric_gen_spec = None
+        elif predefined_spec_name:
             if prompt_template:
                 logger.warning(
                     "prompt_template is ignored when predefined_spec_name is provided."
@@ -1699,7 +1751,7 @@
             rubric_gen_spec = types.RubricGenerationSpec.model_validate(spec_dict)
         else:
             raise ValueError(
-                "Either predefined_spec_name or prompt_template must be provided."
+                "Either metric, predefined_spec_name or prompt_template must be provided."
             )

         for _, row in prompts_df.iterrows():
@@ -1722,6 +1774,7 @@
                 contents=contents,
                 rubric_generation_spec=rubric_gen_spec,
                 predefined_rubric_generation_spec=predefined_spec,
+                metric_resource_name=actual_metric_resource_name,
                 config=config,
             )
             rubric_group = {rubric_group_name: response.generated_rubrics}
@@ -2307,6 +2360,7 @@ async def _generate_rubrics(
             types.PredefinedMetricSpecOrDict
         ] = None,
         rubric_generation_spec: Optional[types.RubricGenerationSpecOrDict] = None,
+        metric_resource_name: Optional[str] = None,
         config: Optional[types.RubricGenerationConfigOrDict] = None,
     ) -> types.GenerateInstanceRubricsResponse:
         """
@@ -2317,6 +2371,7 @@ async def _generate_rubrics(
             contents=contents,
             predefined_rubric_generation_spec=predefined_rubric_generation_spec,
             rubric_generation_spec=rubric_generation_spec,
+            metric_resource_name=metric_resource_name,
             config=config,
         )

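Taken together, generate_rubrics gains a third input mode with a defined precedence order: metric, then predefined_spec_name, then prompt_template. A hedged sketch of the two accepted forms of the metric argument, reusing the `client`, `prompts_df`, and `metric` placeholders from the earlier sketch:

    # A bare resource-name string is accepted directly:
    eval_dataset = client.evals.generate_rubrics(
        src=prompts_df,
        rubric_group_name="my_registered_rubrics",
        metric="projects/my-project/locations/us-central1/evaluationMetrics/123",
    )

    # When several inputs are supplied, the registered metric wins:
    eval_dataset = client.evals.generate_rubrics(
        src=prompts_df,
        rubric_group_name="my_registered_rubrics",
        metric=metric,                        # used; its resource name is forwarded
        predefined_spec_name="some_recipe",   # ignored (placeholder name)
        prompt_template="Rate: {prompt}",     # ignored
    )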
vertexai/_genai/types/common.py

Lines changed: 21 additions & 0 deletions
@@ -2479,6 +2479,10 @@ class EvaluationRunMetric(_common.BaseModel):
     metric: Optional[str] = Field(
         default=None, description="""The name of the metric."""
     )
+    metric_resource_name: Optional[str] = Field(
+        default=None,
+        description="""The resource name of the metric definition. Example: projects/{project}/locations/{location}/evaluationMetrics/{evaluation_metric_id}""",
+    )
     metric_config: Optional[UnifiedMetric] = Field(
         default=None, description="""The unified metric used for evaluation run."""
     )
@@ -2490,6 +2494,9 @@ class EvaluationRunMetricDict(TypedDict, total=False):
     metric: Optional[str]
     """The name of the metric."""

+    metric_resource_name: Optional[str]
+    """The resource name of the metric definition. Example: projects/{project}/locations/{location}/evaluationMetrics/{evaluation_metric_id}"""
+
     metric_config: Optional[UnifiedMetricDict]
     """The unified metric used for evaluation run."""

@@ -4439,6 +4446,10 @@ class Metric(_common.BaseModel):
         default=None,
         description="""Optional steering instruction parameters for the automated predefined metric.""",
     )
+    metric_resource_name: Optional[str] = Field(
+        default=None,
+        description="""The resource name of the metric definition. Example: projects/{project}/locations/{location}/evaluationMetrics/{evaluation_metric_id}""",
+    )

     # Allow extra fields to support metric-specific config fields.
     model_config = ConfigDict(extra="allow")
@@ -4643,6 +4654,9 @@ class MetricDict(TypedDict, total=False):
     metric_spec_parameters: Optional[dict[str, Any]]
     """Optional steering instruction parameters for the automated predefined metric."""

+    metric_resource_name: Optional[str]
+    """The resource name of the metric definition. Example: projects/{project}/locations/{location}/evaluationMetrics/{evaluation_metric_id}"""
+

 MetricOrDict = Union[Metric, MetricDict]

@@ -5354,6 +5368,10 @@ class _GenerateInstanceRubricsRequest(_common.BaseModel):
         default=None,
         description="""Specification for how the rubrics should be generated.""",
     )
+    metric_resource_name: Optional[str] = Field(
+        default=None,
+        description="""Registered metric resource name. If set, the metric configuration referenced by this resource name is used for rubric generation, and the `predefined_rubric_generation_spec` and `rubric_generation_spec` fields are ignored.""",
+    )
     config: Optional[RubricGenerationConfig] = Field(default=None, description="""""")


@@ -5374,6 +5392,9 @@ class _GenerateInstanceRubricsRequestDict(TypedDict, total=False):
     rubric_generation_spec: Optional[RubricGenerationSpecDict]
     """Specification for how the rubrics should be generated."""

+    metric_resource_name: Optional[str]
+    """Registered metric resource name. If set, the metric configuration referenced by this resource name is used for rubric generation, and the `predefined_rubric_generation_spec` and `rubric_generation_spec` fields are ignored."""
+
     config: Optional[RubricGenerationConfigDict]
     """"""
