From aeb763fe71455a394ff298d07871db61fa1130e0 Mon Sep 17 00:00:00 2001 From: William Cheung Date: Fri, 20 Dec 2024 17:35:21 -0500 Subject: [PATCH] Update `query` to return `embedding: Sequence[float]` in `QueryResult`. Note: this increases memory usage, because every result now carries its full embedding vector, but it is useful for exporting vectors. TODO: Add a `return_embedding` flag, off by default, so that embeddings are only returned on request. --- tidb_vector/integrations/vector_client.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tidb_vector/integrations/vector_client.py b/tidb_vector/integrations/vector_client.py index 174c2a9..cd6b846 100644 --- a/tidb_vector/integrations/vector_client.py +++ b/tidb_vector/integrations/vector_client.py @@ -4,7 +4,7 @@ import logging import enum import uuid -from typing import Type, Tuple, Any, Dict, Generator, Iterable, List, Optional +from typing import Sequence, Type, Tuple, Any, Dict, Generator, Iterable, List, Optional import sqlalchemy from sqlalchemy.orm import Session, declarative_base @@ -73,6 +73,7 @@ class QueryResult: document: str metadata: dict distance: float + embedding: Sequence[float] class TiDBVectorClient: @@ -303,6 +304,7 @@ def query( metadata=doc.meta, id=doc.id, distance=doc.distance, + embedding=doc.embedding, ) for doc in relevant_docs ] @@ -326,6 +328,7 @@ def _vector_search( self._table_model.id, self._table_model.meta, self._table_model.document, + self._table_model.embedding, self.distance_strategy(query_embedding).label("distance"), ) .filter(filter_by) @@ -342,6 +345,7 @@ def _vector_search( self._table_model.id, self._table_model.meta, self._table_model.document, + self._table_model.embedding, self.distance_strategy(query_embedding).label("distance"), ) .order_by(sqlalchemy.asc("distance")) @@ -354,6 +358,7 @@ def _vector_search( subquery.c.id, subquery.c.meta, subquery.c.document, + subquery.c.embedding, subquery.c.distance, ) .filter(filter_by)