|
15 | 15 | # limitations under the License. |
16 | 16 | # |
17 | 17 |
|
18 | | -import uuid |
19 | 18 | from importlib import reload |
20 | 19 | from unittest import mock |
21 | 20 | from unittest.mock import patch |
| 21 | +import uuid |
22 | 22 |
|
23 | 23 | from google.api_core import operation |
24 | 24 | from google.cloud import aiplatform |
25 | 25 | from google.cloud.aiplatform import base |
26 | 26 | from google.cloud.aiplatform import initializer |
27 | | -from google.cloud.aiplatform.matching_engine._protos import ( |
28 | | - match_service_pb2, |
29 | | - match_service_pb2_grpc, |
30 | | -) |
31 | | -from google.cloud.aiplatform.matching_engine.matching_engine_index_endpoint import ( |
32 | | - Namespace, |
33 | | - NumericNamespace, |
34 | | - MatchNeighbor, |
35 | | - HybridQuery, |
| 27 | +from google.cloud.aiplatform.compat.services import ( |
| 28 | + index_endpoint_service_client, |
| 29 | + index_service_client, |
| 30 | + match_service_client_v1beta1, |
36 | 31 | ) |
37 | 32 | from google.cloud.aiplatform.compat.types import ( |
38 | | - matching_engine_deployed_index_ref as gca_matching_engine_deployed_index_ref, |
| 33 | + encryption_spec as gca_encryption_spec, |
39 | 34 | index_endpoint as gca_index_endpoint, |
| 35 | + index_v1beta1 as gca_index_v1beta1, |
40 | 36 | index as gca_index, |
41 | 37 | match_service_v1beta1 as gca_match_service_v1beta1, |
42 | | - index_v1beta1 as gca_index_v1beta1, |
| 38 | + matching_engine_deployed_index_ref as gca_matching_engine_deployed_index_ref, |
43 | 39 | service_networking as gca_service_networking, |
44 | | - encryption_spec as gca_encryption_spec, |
45 | 40 | ) |
46 | | -from google.cloud.aiplatform.compat.services import ( |
47 | | - index_endpoint_service_client, |
48 | | - index_service_client, |
49 | | - match_service_client_v1beta1, |
| 41 | +from google.cloud.aiplatform.matching_engine._protos import ( |
| 42 | + match_service_pb2, |
| 43 | + match_service_pb2_grpc, |
| 44 | +) |
| 45 | +from google.cloud.aiplatform.matching_engine.matching_engine_index_endpoint import ( |
| 46 | + HybridQuery, |
| 47 | + MatchNeighbor, |
| 48 | + Namespace, |
| 49 | + NumericNamespace, |
50 | 50 | ) |
51 | 51 | import constants as test_constants |
52 | | - |
53 | | -from google.protobuf import field_mask_pb2 |
54 | | - |
55 | 52 | import grpc |
56 | | - |
57 | 53 | import pytest |
58 | 54 |
|
| 55 | +from google.protobuf import field_mask_pb2 |
| 56 | +from google.protobuf import struct_pb2 |
| 57 | + |
59 | 58 | # project |
60 | 59 | _TEST_PROJECT = test_constants.ProjectConstants._TEST_PROJECT |
61 | 60 | _TEST_LOCATION = test_constants.ProjectConstants._TEST_LOCATION |
@@ -2409,70 +2408,82 @@ def test_index_endpoint_read_index_datapoints_for_private_service_connect_automa |
2409 | 2408 |
|
2410 | 2409 |
|
2411 | 2410 | class TestMatchNeighbor: |
2412 | | - def test_from_index_datapoint(self): |
2413 | | - index_datapoint = gca_index_v1beta1.IndexDatapoint() |
2414 | | - index_datapoint.datapoint_id = "test_datapoint_id" |
2415 | | - index_datapoint.feature_vector = [1.0, 2.0, 3.0] |
2416 | | - index_datapoint.crowding_tag = gca_index_v1beta1.IndexDatapoint.CrowdingTag( |
| 2411 | + def test_from_index_datapoint(self): |
| 2412 | + index_datapoint = gca_index_v1beta1.IndexDatapoint() |
| 2413 | + index_datapoint.datapoint_id = "test_datapoint_id" |
| 2414 | + index_datapoint.feature_vector = [1.0, 2.0, 3.0] |
| 2415 | + index_datapoint.crowding_tag = gca_index_v1beta1.IndexDatapoint.CrowdingTag( |
2417 | 2416 | crowding_attribute="test_crowding" |
2418 | 2417 | ) |
2419 | | - index_datapoint.restricts = [ |
| 2418 | + index_datapoint.restricts = [ |
2420 | 2419 | gca_index_v1beta1.IndexDatapoint.Restriction( |
2421 | 2420 | namespace="namespace1", allow_list=["token1"], deny_list=["token2"] |
2422 | 2421 | ), |
2423 | 2422 | ] |
2424 | | - index_datapoint.numeric_restricts = [ |
| 2423 | + index_datapoint.numeric_restricts = [ |
2425 | 2424 | gca_index_v1beta1.IndexDatapoint.NumericRestriction( |
2426 | 2425 | namespace="namespace2", |
2427 | 2426 | value_int=0, |
2428 | 2427 | ) |
2429 | 2428 | ] |
| 2429 | + index_datapoint.embedding_metadata = {"key": "value", "key2": "value2"} |
| 2430 | + |
| 2431 | + result = MatchNeighbor( |
| 2432 | + id="index_datapoint_id", distance=0.3 |
| 2433 | + ).from_index_datapoint(index_datapoint) |
| 2434 | + |
| 2435 | + assert result.feature_vector == [1.0, 2.0, 3.0] |
| 2436 | + assert result.crowding_tag == "test_crowding" |
| 2437 | + assert len(result.restricts) == 1 |
| 2438 | + assert result.restricts[0].name == "namespace1" |
| 2439 | + assert result.restricts[0].allow_tokens == ["token1"] |
| 2440 | + assert result.restricts[0].deny_tokens == ["token2"] |
| 2441 | + assert len(result.numeric_restricts) == 1 |
| 2442 | + assert result.numeric_restricts[0].name == "namespace2" |
| 2443 | + assert result.numeric_restricts[0].value_int == 0 |
| 2444 | + assert result.numeric_restricts[0].value_float is None |
| 2445 | + assert result.numeric_restricts[0].value_double is None |
| 2446 | + assert result.embedding_metadata == {"key": "value", "key2": "value2"} |
| 2447 | + |
| 2448 | + def test_from_embedding(self): |
| 2449 | + embedding_metadata_struct = struct_pb2.Struct() |
| 2450 | + embedding_metadata_struct.update({"key": "value", "key2": "value2"}) |
| 2451 | + |
| 2452 | + embedding = match_service_pb2.Embedding( |
| 2453 | + id="test_embedding_id", |
| 2454 | + float_val=[1.0, 2.0, 3.0], |
| 2455 | + crowding_attribute=1, |
| 2456 | + restricts=[ |
| 2457 | + match_service_pb2.Namespace( |
| 2458 | + name="namespace1", |
| 2459 | + allow_tokens=["token1"], |
| 2460 | + deny_tokens=["token2"], |
| 2461 | + ), |
| 2462 | + ], |
| 2463 | + numeric_restricts=[ |
| 2464 | + match_service_pb2.NumericNamespace( |
| 2465 | + name="namespace2", |
| 2466 | + value_int=10, |
| 2467 | + value_float=None, |
| 2468 | + value_double=None, |
| 2469 | + ) |
| 2470 | + ], |
| 2471 | + embedding_metadata=embedding_metadata_struct, |
| 2472 | + ) |
2430 | 2473 |
|
2431 | | - result = MatchNeighbor( |
2432 | | - id="index_datapoint_id", distance=0.3 |
2433 | | - ).from_index_datapoint(index_datapoint) |
2434 | | - |
2435 | | - assert result.feature_vector == [1.0, 2.0, 3.0] |
2436 | | - assert result.crowding_tag == "test_crowding" |
2437 | | - assert len(result.restricts) == 1 |
2438 | | - assert result.restricts[0].name == "namespace1" |
2439 | | - assert result.restricts[0].allow_tokens == ["token1"] |
2440 | | - assert result.restricts[0].deny_tokens == ["token2"] |
2441 | | - assert len(result.numeric_restricts) == 1 |
2442 | | - assert result.numeric_restricts[0].name == "namespace2" |
2443 | | - assert result.numeric_restricts[0].value_int == 0 |
2444 | | - assert result.numeric_restricts[0].value_float is None |
2445 | | - assert result.numeric_restricts[0].value_double is None |
2446 | | - |
2447 | | - def test_from_embedding(self): |
2448 | | - embedding = match_service_pb2.Embedding( |
2449 | | - id="test_embedding_id", |
2450 | | - float_val=[1.0, 2.0, 3.0], |
2451 | | - crowding_attribute=1, |
2452 | | - restricts=[ |
2453 | | - match_service_pb2.Namespace( |
2454 | | - name="namespace1", allow_tokens=["token1"], deny_tokens=["token2"] |
2455 | | - ), |
2456 | | - ], |
2457 | | - numeric_restricts=[ |
2458 | | - match_service_pb2.NumericNamespace( |
2459 | | - name="namespace2", value_int=10, value_float=None, value_double=None |
2460 | | - ) |
2461 | | - ], |
2462 | | - ) |
2463 | | - |
2464 | | - result = MatchNeighbor(id="embedding_id", distance=0.3).from_embedding( |
2465 | | - embedding |
2466 | | - ) |
| 2474 | + result = MatchNeighbor(id="embedding_id", distance=0.3).from_embedding( |
| 2475 | + embedding |
| 2476 | + ) |
2467 | 2477 |
|
2468 | | - assert result.feature_vector == [1.0, 2.0, 3.0] |
2469 | | - assert result.crowding_tag == "1" |
2470 | | - assert len(result.restricts) == 1 |
2471 | | - assert result.restricts[0].name == "namespace1" |
2472 | | - assert result.restricts[0].allow_tokens == ["token1"] |
2473 | | - assert result.restricts[0].deny_tokens == ["token2"] |
2474 | | - assert len(result.numeric_restricts) == 1 |
2475 | | - assert result.numeric_restricts[0].name == "namespace2" |
2476 | | - assert result.numeric_restricts[0].value_int == 10 |
2477 | | - assert not result.numeric_restricts[0].value_float |
2478 | | - assert not result.numeric_restricts[0].value_double |
| 2478 | + assert result.feature_vector == [1.0, 2.0, 3.0] |
| 2479 | + assert result.crowding_tag == "1" |
| 2480 | + assert len(result.restricts) == 1 |
| 2481 | + assert result.restricts[0].name == "namespace1" |
| 2482 | + assert result.restricts[0].allow_tokens == ["token1"] |
| 2483 | + assert result.restricts[0].deny_tokens == ["token2"] |
| 2484 | + assert len(result.numeric_restricts) == 1 |
| 2485 | + assert result.numeric_restricts[0].name == "namespace2" |
| 2486 | + assert result.numeric_restricts[0].value_int == 10 |
| 2487 | + assert not result.numeric_restricts[0].value_float |
| 2488 | + assert not result.numeric_restricts[0].value_double |
| 2489 | + assert result.embedding_metadata == {"key": "value", "key2": "value2"} |
0 commit comments