-
Notifications
You must be signed in to change notification settings - Fork 2
Release 0.15.4 #446
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Release 0.15.4 #446
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1 +1 @@ | ||
| __version__ = "0.15.3" | ||
| __version__ = "0.15.4" |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -15,7 +15,9 @@ | |
| get_qdrant, | ||
| get_sentence_transformer, | ||
| ) | ||
| from ...models import SearchQuery | ||
| from ...models import SearchQuery, SearchReturnModel | ||
| from qdrant_client.models import ScoredPoint | ||
| from pepdbagent.models import Namespace | ||
|
|
||
| load_dotenv() | ||
|
|
||
|
|
@@ -35,142 +37,72 @@ async def search_for_namespaces( | |
|
|
||
|
|
||
| # perform a search | ||
| @search.post("/", summary="Search for a PEP") | ||
| @search.post("/", summary="Search for a PEP", response_model=SearchReturnModel) | ||
| async def search_for_pep( | ||
| query: SearchQuery, | ||
| qdrant: QdrantClient = Depends(get_qdrant), | ||
| model: Embedding = Depends(get_sentence_transformer), | ||
| agent: PEPDatabaseAgent = Depends(get_db), | ||
| namespace_access: List[str] = Depends(get_namespace_access_list), | ||
| ): | ||
| ) -> SearchReturnModel: | ||
| """ | ||
| Perform a search for PEPs. This can be done using qdrant (semantic search), | ||
| or with basic SQL string matches. | ||
| """ | ||
| limit = query.limit | ||
| offset = query.offset | ||
| score_threshold = query.score_threshold | ||
| if qdrant is not None: | ||
| try: | ||
| # get the embeding for the query | ||
| query_vec = list(model.embed(query.query))[0] | ||
|
|
||
| # get actual results using the limit and offset | ||
| vector_results = qdrant.search( | ||
| collection_name=( | ||
| query.collection_name or DEFAULT_QDRANT_COLLECTION_NAME | ||
| ), | ||
| query_vector=query_vec, | ||
| limit=limit, | ||
| offset=offset, | ||
| score_threshold=score_threshold, | ||
| ) | ||
| # get namespaces: | ||
| namespaces: list[Namespace] = agent.namespace.get( | ||
| query=query.query, admin=namespace_access, limit=limit, offset=offset | ||
| ).results | ||
|
|
||
| # get sql results using the limit and offset | ||
| sql_results = agent.annotation.get( | ||
| query=query.query, | ||
| limit=limit, | ||
| offset=offset, | ||
| namespace=None, | ||
| admin=namespace_access, | ||
| ) | ||
| if qdrant is not None: | ||
| query_vec = list(model.embed(query.query))[0] | ||
|
|
||
| # map the results to the format we want | ||
| vector_results_mapped = [r.model_dump() for r in vector_results] | ||
| sql_results_mapped = [ | ||
| { | ||
| "id": r.digest, | ||
| "version": 0, | ||
| "score": 1.0, # Its a SQL search, so we just set the score to 1.0 | ||
| "payload": { | ||
| "description": r.description, | ||
| "registry": f"{r.namespace}/{r.name}:{r.tag}", | ||
| }, | ||
| "vector": None, | ||
| } | ||
| for r in sql_results.results | ||
| ] | ||
| results = vector_results_mapped + sql_results_mapped | ||
| namespaces = agent.namespace.get(admin=namespace_access) | ||
| namespace_hits = [ | ||
| n.namespace | ||
| for n in namespaces.results | ||
| if query.query.lower() in n.namespace.lower() | ||
| ] | ||
| namespace_hits.extend( | ||
| [ | ||
| n | ||
| for n in list( | ||
| set( | ||
| [ | ||
| r.model_dump()["payload"]["registry"].split("/")[0] | ||
| for r in vector_results | ||
| ] | ||
| ) | ||
| ) | ||
| if n not in namespace_hits | ||
| ] | ||
| ) | ||
| vector_results = qdrant.query_points( | ||
| collection_name=(query.collection_name or DEFAULT_QDRANT_COLLECTION_NAME), | ||
| query=query_vec, | ||
| limit=limit, | ||
| offset=offset, | ||
| score_threshold=score_threshold, | ||
| ).points | ||
khoroshevskyi marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| # finally, sort the results by score | ||
| results = sorted(results, key=lambda x: x["score"], reverse=True) | ||
| return SearchReturnModel( | ||
| query=query.query, | ||
| results=vector_results, | ||
| namespace_hits=namespaces, | ||
| limit=limit, | ||
| offset=offset, | ||
| total=len(vector_results), | ||
|
||
| ) | ||
|
|
||
| return JSONResponse( | ||
| content={ | ||
| "query": query.query, | ||
| "results": results, | ||
| "namespace_hits": namespace_hits, | ||
| "limit": limit, | ||
| "offset": offset, | ||
| "total": len(vector_results) + sql_results.count, | ||
| } | ||
| ) | ||
| except Exception as e: | ||
| # TODO: this isnt proper error handling. Also we need to use a logger | ||
| print("Qdrant search failed, falling back to SQL search. Reason: ", e) | ||
| else: | ||
| # fallback to SQL search | ||
| namespaces = agent.namespace.get(admin=namespace_access).results | ||
| results = agent.annotation.get( | ||
| query=query.query, limit=limit, offset=offset | ||
| ).results | ||
| results = agent.annotation.get(query=query.query, limit=limit, offset=offset) | ||
|
|
||
| # emulate qdrant response from the SQL search | ||
| # for frontend compatibility | ||
| parsed_results = [ | ||
| { | ||
| "id": None, | ||
| "version": 0, | ||
| "score": None, | ||
| "payload": { | ||
| ScoredPoint( | ||
| id=f"{r.namespace}/{r.name}:{r.tag}", | ||
| version=0, | ||
| score=1.0, # SQL search, so we just set the score to 1.0 | ||
| payload={ | ||
| "description": r.description, | ||
| "registry": f"{r.namespace}/{r.name}:{r.tag}", | ||
| }, | ||
| "vector": None, | ||
| } | ||
| for r in results | ||
| vector=None, | ||
| ) | ||
| for r in results.results | ||
| ] | ||
|
|
||
| namespace_hits = [ | ||
| n.namespace | ||
| for n in namespaces | ||
| if query.query.lower() in n.namespace.lower() | ||
| ] | ||
| namespace_hits.extend( | ||
| [ | ||
| n | ||
| for n in list( | ||
| set( | ||
| [r["payload"]["registry"].split("/")[0] for r in parsed_results] | ||
| ) | ||
| ) | ||
| if n not in namespace_hits | ||
| ] | ||
| ) | ||
| return JSONResponse( | ||
| content={ | ||
| "query": query.query, | ||
| "results": parsed_results, | ||
| "namespace_hits": namespace_hits, | ||
| } | ||
| return SearchReturnModel( | ||
| query=query.query, | ||
| results=parsed_results, | ||
| namespace_hits=namespaces, | ||
| limit=limit, | ||
| offset=offset, | ||
| total=results.count, | ||
| ) | ||
Uh oh!
There was an error while loading. Please reload this page.