Skip to content

Commit 216003e

Browse files
lingyinwcopybara-github
authored andcommitted
fix: Add embedding_metadata to MatchNeighbor from_embedding.
PiperOrigin-RevId: 879842103
1 parent 317bf40 commit 216003e

2 files changed

Lines changed: 19 additions & 11 deletions

File tree

google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ def __post_init__(self):
151151
@dataclass
152152
class HybridQuery:
153153
"""
154-
Hyrbid query. Could be used for dense-only or sparse-only or hybrid queries.
154+
Hybrid query. Could be used for dense-only or sparse-only or hybrid queries.
155155
156156
dense_embedding (List[float]):
157157
Optional. The dense part of the hybrid queries.
@@ -328,6 +328,10 @@ def from_embedding(self, embedding: match_service_pb2.Embedding) -> "MatchNeighb
328328
if embedding.sparse_embedding:
329329
self.sparse_embedding_values = embedding.sparse_embedding.float_val
330330
self.sparse_embedding_dimensions = embedding.sparse_embedding.dimension
331+
332+
# retrieve embedding metadata
333+
if embedding.embedding_metadata:
334+
self.embedding_metadata = embedding.embedding_metadata
331335
return self
332336

333337

@@ -1883,7 +1887,7 @@ def find_neighbors(
18831887
[
18841888
MatchNeighbor(
18851889
id=neighbor.datapoint.datapoint_id,
1886-
distance=neighbor.distance,
1890+
distance=neighbor.distance if neighbor.distance else None,
18871891
sparse_distance=(
18881892
neighbor.sparse_distance if neighbor.sparse_distance else None
18891893
),
@@ -2219,19 +2223,18 @@ def match(
22192223
# Wrap the results in MatchNeighbor objects and return
22202224
match_neighbors_response = []
22212225
for resp in response.responses[0].responses:
2222-
match_neighbors_id_map = {}
2226+
embedding_map = {embedding.id: embedding for embedding in resp.embeddings}
2227+
neighbors_list = []
22232228
for neighbor in resp.neighbor:
2224-
match_neighbors_id_map[neighbor.id] = MatchNeighbor(
2229+
match_neighbor = MatchNeighbor(
22252230
id=neighbor.id,
2226-
distance=neighbor.distance,
2231+
distance=neighbor.distance if neighbor.distance else None,
22272232
sparse_distance=(
22282233
neighbor.sparse_distance if neighbor.sparse_distance else None
22292234
),
22302235
)
2231-
for embedding in resp.embeddings:
2232-
if embedding.id in match_neighbors_id_map:
2233-
match_neighbors_id_map[embedding.id] = match_neighbors_id_map[
2234-
embedding.id
2235-
].from_embedding(embedding=embedding)
2236-
match_neighbors_response.append(list(match_neighbors_id_map.values()))
2236+
if neighbor.id in embedding_map:
2237+
match_neighbor.from_embedding(embedding=embedding_map[neighbor.id])
2238+
neighbors_list.append(match_neighbor)
2239+
match_neighbors_response.append(neighbors_list)
22372240
return match_neighbors_response

tests/unit/aiplatform/test_matching_engine_index_endpoint.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
import constants as test_constants
5252

5353
from google.protobuf import field_mask_pb2
54+
from google.protobuf import struct_pb2
5455

5556
import grpc
5657

@@ -2445,6 +2446,8 @@ def test_from_index_datapoint(self):
24452446
assert result.numeric_restricts[0].value_double is None
24462447

24472448
def test_from_embedding(self):
2449+
embedding_metadata_struct = struct_pb2.Struct()
2450+
embedding_metadata_struct["key"] = "value"
24482451
embedding = match_service_pb2.Embedding(
24492452
id="test_embedding_id",
24502453
float_val=[1.0, 2.0, 3.0],
@@ -2459,6 +2462,7 @@ def test_from_embedding(self):
24592462
name="namespace2", value_int=10, value_float=None, value_double=None
24602463
)
24612464
],
2465+
embedding_metadata=embedding_metadata_struct,
24622466
)
24632467

24642468
result = MatchNeighbor(id="embedding_id", distance=0.3).from_embedding(
@@ -2476,3 +2480,4 @@ def test_from_embedding(self):
24762480
assert result.numeric_restricts[0].value_int == 10
24772481
assert not result.numeric_restricts[0].value_float
24782482
assert not result.numeric_restricts[0].value_double
2483+
assert result.embedding_metadata == {"key": "value"}

0 commit comments

Comments
 (0)