diff --git a/src/app/api/api_v1/endpoints/micro_learning.py b/src/app/api/api_v1/endpoints/micro_learning.py index 460fcc8..ea75559 100644 --- a/src/app/api/api_v1/endpoints/micro_learning.py +++ b/src/app/api/api_v1/endpoints/micro_learning.py @@ -30,10 +30,13 @@ async def get_subject_list( lang: str | None = None, sp: SearchService = Depends(get_search_service) ) -> list[str]: - collection_info, model_id = await collection_and_model_id_according_lang( + collection_info, models_ids = await collection_and_model_id_according_lang( sp=sp, lang=lang ) - ret = [md.title for md in get_subjects(embedding_model_id=model_id)] + ret = [ + md.title + for md in get_subjects(embedding_models_ids=[model.id for model in models_ids]) + ] if len(ret) == 0: raise HTTPException(status_code=404, detail="No subjects found.") return ret @@ -50,13 +53,13 @@ async def get_full_journey( lang: str | None = None, sp: SearchService = Depends(get_search_service), ): - collection_info, model_id = await collection_and_model_id_according_lang( + collection_info, models_id = await collection_and_model_id_according_lang( sp=sp, lang=lang ) - + emb_models_ids = [model.id for model in models_id] journey_part = [i.lower() for i in JourneySectionType] sdg_meta_documents = get_context_documents( - journey_part=journey_part, sdg=sdg, embedding_model_id=model_id + journey_part=journey_part, sdg=sdg, embedding_models_ids=emb_models_ids ) if not sdg_meta_documents: raise HTTPException( @@ -64,7 +67,7 @@ async def get_full_journey( ) subject_meta_document: ContextDocument | None = get_subject( - subject=subject, embedding_model_id=model_id + subject=subject, embedding_models_ids=emb_models_ids ) if not subject_meta_document: diff --git a/src/app/services/helpers.py b/src/app/services/helpers.py index a4cee7c..9fce3d5 100644 --- a/src/app/services/helpers.py +++ b/src/app/services/helpers.py @@ -1,4 +1,3 @@ -import uuid from functools import cache from typing import Any, List @@ -8,6 +7,7 @@ from json_repair 
import JSONReturnType from langdetect import detect_langs from qdrant_client.http.models import models +from welearn_database.data.models import EmbeddingModel from src.app.models.collections import Collection from src.app.models.documents import JourneySectionType @@ -184,7 +184,7 @@ def choose_readability_according_journey_section_type( @log_time_and_error_sync async def collection_and_model_id_according_lang( lang: str | None, sp -) -> tuple[Collection, uuid]: +) -> tuple[Collection, list[EmbeddingModel]]: """ Get the collection info and model id according to the language. Args: @@ -203,9 +203,9 @@ async def collection_and_model_id_according_lang( raise HTTPException( status_code=404, detail=f"No collection found for language '{lang}'." ) - model_id = get_embeddings_model_id_according_name(collection_info.model) - if not model_id: + emb_models = get_embeddings_model_id_according_name(collection_info.model) + if not emb_models: raise ValueError( f"Embedding model '{collection_info.model}' not found in the database." 
) - return collection_info, model_id + return collection_info, emb_models diff --git a/src/app/services/search.py b/src/app/services/search.py index d52fb56..fdad4c1 100644 --- a/src/app/services/search.py +++ b/src/app/services/search.py @@ -25,7 +25,10 @@ from src.app.services.data_quality import DataQualityChecker from src.app.services.exceptions import CollectionNotFoundError, ModelNotFoundError from src.app.services.helpers import convert_embedding_bytes -from src.app.services.sql_db.queries import get_subject +from src.app.services.sql_db.queries import ( + get_embeddings_model_id_according_name, + get_subject, +) from src.app.utils.decorators import log_time_and_error, log_time_and_error_sync from src.app.utils.logger import logger as logger_utils @@ -253,7 +256,9 @@ async def search_handler( assert isinstance(qp.query, str) collection = await self.get_collection_by_language(lang="mul") - subject_vector = await run_in_threadpool(get_subject_vector, qp.subject) + subject_vector = await run_in_threadpool( + get_subject_vector, qp.subject, collection.model + ) embedding = await self.get_query_embed( model=collection.model, query=qp.query, @@ -443,7 +448,7 @@ def concatenate_same_doc_id_slices( @log_time_and_error_sync -def get_subject_vector(subject: str | None) -> list[float] | None: +def get_subject_vector(subject: str | None, model_name: str) -> list[float] | None: """ Get the subject vector from the database. 
Args: @@ -455,7 +460,15 @@ def get_subject_vector(subject: str | None) -> list[float] | None: if not subject: return None - subject_from_db = get_subject(subject=subject) + emb_models = get_embeddings_model_id_according_name(model_name) + embedding_model_ids = [m.id for m in emb_models if m is not None] + if not embedding_model_ids: + return None + + subject_from_db = get_subject( + subject=subject, + embedding_models_ids=embedding_model_ids, + ) if not subject_from_db: return None diff --git a/src/app/services/sql_db/queries.py b/src/app/services/sql_db/queries.py index 170f7b4..370c44a 100644 --- a/src/app/services/sql_db/queries.py +++ b/src/app/services/sql_db/queries.py @@ -34,7 +34,6 @@ from src.app.services.constants import APP_NAME from src.app.services.sql_db.sql_service import session_maker -model_id_cache: dict[str, UUID] = {} model_id_lock = Lock() @@ -152,10 +151,13 @@ def register_endpoint(endpoint, session_id, http_code): session.commit() -def get_subject(subject: str, embedding_model_id: UUID) -> ContextDocument | None: +def get_subject( + subject: str, embedding_models_ids: list[UUID] +) -> ContextDocument | None: """ Get the subject meta document from the database. Args: + embedding_models_ids: Database IDs of the embedding models used to vectorize documents subject: The subject to get. Returns: The subject meta document. @@ -167,14 +169,14 @@ def get_subject(subject: str, embedding_model_id: UUID) -> ContextDocument | Non .filter( ContextDocument.context_type == ContextType.SUBJECT.value.lower(), ContextDocument.title == subject, - ContextDocument.embedding_model_id == embedding_model_id, + ContextDocument.embedding_model_id.in_(embedding_models_ids), ) .first() ) return subject_meta_document -def get_subjects(embedding_model_id: UUID) -> list[ContextDocument]: +def get_subjects(embedding_models_ids: list[UUID]) -> list[ContextDocument]: """ Get all the subject meta documents from the database. Returns: List of subject meta documents. 
@@ -184,7 +186,7 @@ def get_subjects(embedding_model_id: UUID) -> list[ContextDocument]: session.query(ContextDocument) .filter( ContextDocument.context_type == ContextType.SUBJECT.value.lower(), - ContextDocument.embedding_model_id == embedding_model_id, + ContextDocument.embedding_model_id.in_(embedding_models_ids), ) .all() ) @@ -192,12 +194,13 @@ def get_subjects(embedding_model_id: UUID) -> list[ContextDocument]: def get_context_documents( - journey_part: JourneySection, sdg: int, embedding_model_id: UUID + journey_part: list[JourneySection], sdg: int, embedding_models_ids: list[UUID] ): """ Get the context documents from the database. Args: + embedding_models_ids: Database IDs of the embedding models used to vectorize documents journey_part: The journey part to get the context documents for. sdg: The SDG to get the context documents for. Returns: List of context documents. @@ -208,14 +211,16 @@ def get_context_documents( .filter( ContextDocument.context_type.in_(journey_part), ContextDocument.sdg_related.contains([sdg]), - ContextDocument.embedding_model_id == embedding_model_id, + ContextDocument.embedding_model_id.in_(embedding_models_ids), ) .all() ) return sdg_meta_documents -def get_embeddings_model_id_according_name(model_name: str) -> UUID | None: +def get_embeddings_model_id_according_name( + model_name: str, +) -> list[EmbeddingModel | None]: """ Get the embeddings model ID according to its name. @@ -225,17 +230,12 @@ def get_embeddings_model_id_according_name(model_name: str) -> UUID | None: Returns: The ID of the embeddings model if found, otherwise None. 
""" - with model_id_lock: - if model_name in model_id_cache: - return model_id_cache[model_name] - with session_maker() as session: - model = ( + return ( session.query(EmbeddingModel) .filter(EmbeddingModel.title == model_name) - .first() + .all() ) - return model.id if model else None def write_new_data_quality_error( diff --git a/src/app/services/tutor/agents.py b/src/app/services/tutor/agents.py index 6bb1947..0c73a7a 100644 --- a/src/app/services/tutor/agents.py +++ b/src/app/services/tutor/agents.py @@ -1,5 +1,5 @@ -import time import json +import time from pathlib import Path from langchain_core.language_models import BaseChatModel diff --git a/src/app/tests/api/api_v1/test_micro_learning.py b/src/app/tests/api/api_v1/test_micro_learning.py index fcbdfd5..025c729 100644 --- a/src/app/tests/api/api_v1/test_micro_learning.py +++ b/src/app/tests/api/api_v1/test_micro_learning.py @@ -1,8 +1,9 @@ import unittest +import uuid from unittest import mock from fastapi.testclient import TestClient -from welearn_database.data.models import ContextDocument +from welearn_database.data.models import ContextDocument, EmbeddingModel from src.app.core.config import settings from src.app.models.collections import Collection @@ -38,7 +39,7 @@ async def test_get_full_journey( with TestClient(app) as client: mock_collection_and_model_id_according_lang.return_value = ( Collection(name="test_collection", lang="en", model="test_model"), - "model_id", + [EmbeddingModel(id=uuid.uuid4(), title="test_model", lang="en")], ) # Mock data mock_get_context_docs.return_value = [ @@ -114,7 +115,7 @@ async def test_get_subject_list( ] mock_collection_and_model_id_according_lang.return_value = ( None, - "model_id", + [], ) # API call