From 8b5b2f659c2bf848918648f8564dcb832d79f57e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o?= Date: Thu, 5 Mar 2026 12:04:33 +0100 Subject: [PATCH 01/14] Remove unused cache dict --- src/app/services/sql_db/queries.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/src/app/services/sql_db/queries.py b/src/app/services/sql_db/queries.py index 170f7b4..031709b 100644 --- a/src/app/services/sql_db/queries.py +++ b/src/app/services/sql_db/queries.py @@ -34,7 +34,6 @@ from src.app.services.constants import APP_NAME from src.app.services.sql_db.sql_service import session_maker -model_id_cache: dict[str, UUID] = {} model_id_lock = Lock() @@ -215,7 +214,7 @@ def get_context_documents( return sdg_meta_documents -def get_embeddings_model_id_according_name(model_name: str) -> UUID | None: +def get_embeddings_model_id_according_name(model_name: str) -> list[UUID] | None: """ Get the embeddings model ID according to its name. @@ -225,10 +224,6 @@ def get_embeddings_model_id_according_name(model_name: str) -> UUID | None: Returns: The ID of the embeddings model if found, otherwise None. """ - with model_id_lock: - if model_name in model_id_cache: - return model_id_cache[model_name] - with session_maker() as session: model = ( session.query(EmbeddingModel) From 27279659e16b802ba9263307ad70d879e3de59d6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o?= Date: Thu, 5 Mar 2026 12:20:55 +0100 Subject: [PATCH 02/14] feat(micro-learning): update embedding model handling to support multiple models --- src/app/api/api_v1/endpoints/micro_learning.py | 7 +++++-- src/app/services/helpers.py | 9 +++++---- src/app/services/sql_db/queries.py | 13 +++++++------ 3 files changed, 17 insertions(+), 12 deletions(-) diff --git a/src/app/api/api_v1/endpoints/micro_learning.py b/src/app/api/api_v1/endpoints/micro_learning.py index 460fcc8..b54a7c3 100644 --- a/src/app/api/api_v1/endpoints/micro_learning.py +++ b/src/app/api/api_v1/endpoints/micro_learning.py @@ -30,10 +30,13 @@ async def get_subject_list( lang: str | None = None, sp: SearchService = Depends(get_search_service) ) -> list[str]: - collection_info, model_id = await collection_and_model_id_according_lang( + collection_info, models_ids = await collection_and_model_id_according_lang( sp=sp, lang=lang ) - ret = [md.title for md in get_subjects(embedding_model_id=model_id)] + ret = [ + md.title + for md in get_subjects(embedding_models_ids=[model.id for model in models_ids]) + ] if len(ret) == 0: raise HTTPException(status_code=404, detail="No subjects found.") return ret diff --git a/src/app/services/helpers.py b/src/app/services/helpers.py index a4cee7c..d5374a8 100644 --- a/src/app/services/helpers.py +++ b/src/app/services/helpers.py @@ -8,6 +8,7 @@ from json_repair import JSONReturnType from langdetect import detect_langs from qdrant_client.http.models import models +from welearn_database.data.models import EmbeddingModel from src.app.models.collections import Collection from src.app.models.documents import JourneySectionType @@ -184,7 +185,7 @@ def choose_readability_according_journey_section_type( @log_time_and_error_sync async def collection_and_model_id_according_lang( lang: str | None, sp -) -> tuple[Collection, uuid]: +) -> tuple[Collection, list[EmbeddingModel]]: """ Get the collection info and model id according to the language. Args: @@ -203,9 +204,9 @@ async def collection_and_model_id_according_lang( raise HTTPException( status_code=404, detail=f"No collection found for language '{lang}'." ) - model_id = get_embeddings_model_id_according_name(collection_info.model) - if not model_id: + emb_models = get_embeddings_model_id_according_name(collection_info.model) + if not emb_models: raise ValueError( f"Embedding model '{collection_info.model}' not found in the database." ) - return collection_info, model_id + return collection_info, emb_models diff --git a/src/app/services/sql_db/queries.py b/src/app/services/sql_db/queries.py index 031709b..1af13aa 100644 --- a/src/app/services/sql_db/queries.py +++ b/src/app/services/sql_db/queries.py @@ -173,7 +173,7 @@ def get_subject(subject: str, embedding_model_id: UUID) -> ContextDocument | Non return subject_meta_document -def get_subjects(embedding_model_id: UUID) -> list[ContextDocument]: +def get_subjects(embedding_models_ids: list[UUID]) -> list[ContextDocument]: """ Get all the subject meta documents from the database. Returns: List of subject meta documents. @@ -183,7 +183,7 @@ def get_subjects(embedding_model_id: UUID) -> list[ContextDocument]: session.query(ContextDocument) .filter( ContextDocument.context_type == ContextType.SUBJECT.value.lower(), - ContextDocument.embedding_model_id == embedding_model_id, + ContextDocument.embedding_model_id.in_(embedding_models_ids), ) .all() ) @@ -214,7 +214,9 @@ def get_context_documents( return sdg_meta_documents -def get_embeddings_model_id_according_name(model_name: str) -> list[UUID] | None: +def get_embeddings_model_id_according_name( + model_name: str, +) -> list[EmbeddingModel | None]: """ Get the embeddings model ID according to its name. @@ -225,12 +227,11 @@ def get_embeddings_model_id_according_name(model_name: str) -> list[UUID] | None The ID of the embeddings model if found, otherwise None. """ with session_maker() as session: - model = ( + return ( session.query(EmbeddingModel) .filter(EmbeddingModel.title == model_name) - .first() + .all() ) - return model.id if model else None def write_new_data_quality_error( From dbcfde0ed8b8891545ac4fd89820e1dfadf7fabf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o?= Date: Thu, 5 Mar 2026 15:01:50 +0100 Subject: [PATCH 03/14] feat(micro-learning): update functions to support multiple embedding models --- src/app/api/api_v1/endpoints/micro_learning.py | 8 ++++---- src/app/services/search.py | 8 +++++--- src/app/services/sql_db/queries.py | 12 ++++++++---- 3 files changed, 17 insertions(+), 11 deletions(-) diff --git a/src/app/api/api_v1/endpoints/micro_learning.py b/src/app/api/api_v1/endpoints/micro_learning.py index b54a7c3..ea75559 100644 --- a/src/app/api/api_v1/endpoints/micro_learning.py +++ b/src/app/api/api_v1/endpoints/micro_learning.py @@ -53,13 +53,13 @@ async def get_full_journey( lang: str | None = None, sp: SearchService = Depends(get_search_service), ): - collection_info, model_id = await collection_and_model_id_according_lang( + collection_info, models_id = await collection_and_model_id_according_lang( sp=sp, lang=lang ) - + emb_models_ids = [model.id for model in models_id] journey_part = [i.lower() for i in JourneySectionType] sdg_meta_documents = get_context_documents( - journey_part=journey_part, sdg=sdg, embedding_model_id=model_id + journey_part=journey_part, sdg=sdg, embedding_models_ids=emb_models_ids ) if not sdg_meta_documents: raise HTTPException( @@ -67,7 +67,7 @@ async def get_full_journey( ) subject_meta_document: ContextDocument | None = get_subject( - subject=subject, embedding_model_id=model_id + subject=subject, embedding_models_ids=emb_models_ids ) if not subject_meta_document: diff --git a/src/app/services/search.py b/src/app/services/search.py index d52fb56..d98d461 100644 --- a/src/app/services/search.py +++ b/src/app/services/search.py @@ -25,7 +25,7 @@ from src.app.services.data_quality import DataQualityChecker from src.app.services.exceptions import CollectionNotFoundError, ModelNotFoundError from src.app.services.helpers import convert_embedding_bytes -from src.app.services.sql_db.queries import get_subject +from src.app.services.sql_db.queries import get_subject, get_embeddings_model_id_according_name from src.app.utils.decorators import log_time_and_error, log_time_and_error_sync from src.app.utils.logger import logger as logger_utils @@ -443,7 +443,7 @@ def concatenate_same_doc_id_slices( @log_time_and_error_sync -def get_subject_vector(subject: str | None) -> list[float] | None: +def get_subject_vector(subject: str | None, model_name: str) -> list[float] | None: """ Get the subject vector from the database. Args: @@ -455,7 +455,9 @@ def get_subject_vector(subject: str | None) -> list[float] | None: if not subject: return None - subject_from_db = get_subject(subject=subject) + emb_models = get_embeddings_model_id_according_name(model_name) + + subject_from_db = get_subject(subject=subject, embedding_models_ids=[m.id for m in emb_models]) if not subject_from_db: return None diff --git a/src/app/services/sql_db/queries.py b/src/app/services/sql_db/queries.py index 1af13aa..370c44a 100644 --- a/src/app/services/sql_db/queries.py +++ b/src/app/services/sql_db/queries.py @@ -151,10 +151,13 @@ def register_endpoint(endpoint, session_id, http_code): session.commit() -def get_subject(subject: str, embedding_model_id: UUID) -> ContextDocument | None: +def get_subject( + subject: str, embedding_models_ids: list[UUID] +) -> ContextDocument | None: """ Get the subject meta document from the database. Args: + embedding_models_ids: Database IDs of embeddings models used for vectorize documents subject: The subject to get. Returns: The subject meta document. @@ -166,7 +169,7 @@ def get_subject(subject: str, embedding_model_id: UUID) -> ContextDocument | Non .filter( ContextDocument.context_type == ContextType.SUBJECT.value.lower(), ContextDocument.title == subject, - ContextDocument.embedding_model_id == embedding_model_id, + ContextDocument.embedding_model_id.in_(embedding_models_ids), ) .first() ) @@ -191,12 +194,13 @@ def get_subjects(embedding_models_ids: list[UUID]) -> list[ContextDocument]: def get_context_documents( - journey_part: JourneySection, sdg: int, embedding_model_id: UUID + journey_part: list[JourneySection], sdg: int, embedding_models_ids: list[UUID] ): """ Get the context documents from the database. Args: + embedding_models_ids: Database IDs of embeddings models used for vectorize documents journey_part: The journey part to get the context documents for. sdg: The SDG to get the context documents for. Returns: List of context documents. @@ -207,7 +211,7 @@ def get_context_documents( .filter( ContextDocument.context_type.in_(journey_part), ContextDocument.sdg_related.contains([sdg]), - ContextDocument.embedding_model_id == embedding_model_id, + ContextDocument.embedding_model_id.in_(embedding_models_ids), ) .all() ) From cf5b0a2be959f2f5819b13a092069fbd0992b776 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o?= Date: Thu, 5 Mar 2026 15:02:19 +0100 Subject: [PATCH 04/14] type ignore --- src/app/services/search.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/app/services/search.py b/src/app/services/search.py index d98d461..14cde2d 100644 --- a/src/app/services/search.py +++ b/src/app/services/search.py @@ -25,7 +25,10 @@ from src.app.services.data_quality import DataQualityChecker from src.app.services.exceptions import CollectionNotFoundError, ModelNotFoundError from src.app.services.helpers import convert_embedding_bytes -from src.app.services.sql_db.queries import get_subject, get_embeddings_model_id_according_name +from src.app.services.sql_db.queries import ( + get_embeddings_model_id_according_name, + get_subject, +) from src.app.utils.decorators import log_time_and_error, log_time_and_error_sync from src.app.utils.logger import logger as logger_utils @@ -457,7 +460,9 @@ def get_subject_vector(subject: str | None, model_name: str) -> list[float] | No emb_models = get_embeddings_model_id_according_name(model_name) - subject_from_db = get_subject(subject=subject, embedding_models_ids=[m.id for m in emb_models]) + subject_from_db = get_subject( + subject=subject, embedding_models_ids=[m.id for m in emb_models] # type: ignore + ) if not subject_from_db: return None From df64c55f0e1c7c405c04886b75197d8f087813b3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o?= Date: Thu, 5 Mar 2026 15:50:24 +0100 Subject: [PATCH 05/14] refactor(queries): clean up unused imports in queries.py --- src/app/services/sql_db/queries.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/app/services/sql_db/queries.py b/src/app/services/sql_db/queries.py index 370c44a..f70abf9 100644 --- a/src/app/services/sql_db/queries.py +++ b/src/app/services/sql_db/queries.py @@ -7,7 +7,7 @@ from qdrant_client.http.models import ScoredPoint from sqlalchemy import func, select from welearn_database.data.enumeration import Step -from welearn_database.data.models import ( +from welearn_database.data.models import ( # FilterType,; FilterUsedInQuery, ChatMessage, ContextDocument, Corpus, @@ -17,8 +17,6 @@ EmbeddingModel, EndpointRequest, ErrorDataQuality, - FilterType, - FilterUsedInQuery, ProcessState, QtyDocumentInQdrant, QtyDocumentInQdrantPerCorpus, From d5b2a66a438fdfe65d238eb1ab747e9512c7298c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o?= Date: Thu, 5 Mar 2026 12:04:33 +0100 Subject: [PATCH 06/14] Remove unused cache dict --- src/app/services/sql_db/queries.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/src/app/services/sql_db/queries.py b/src/app/services/sql_db/queries.py index 170f7b4..031709b 100644 --- a/src/app/services/sql_db/queries.py +++ b/src/app/services/sql_db/queries.py @@ -34,7 +34,6 @@ from src.app.services.constants import APP_NAME from src.app.services.sql_db.sql_service import session_maker -model_id_cache: dict[str, UUID] = {} model_id_lock = Lock() @@ -215,7 +214,7 @@ def get_context_documents( return sdg_meta_documents -def get_embeddings_model_id_according_name(model_name: str) -> UUID | None: +def get_embeddings_model_id_according_name(model_name: str) -> list[UUID] | None: """ Get the embeddings model ID according to its name. @@ -225,10 +224,6 @@ def get_embeddings_model_id_according_name(model_name: str) -> UUID | None: Returns: The ID of the embeddings model if found, otherwise None. """ - with model_id_lock: - if model_name in model_id_cache: - return model_id_cache[model_name] - with session_maker() as session: model = ( session.query(EmbeddingModel) From f8b14cde6d69303a283229e3462753976dae9503 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o?= Date: Thu, 5 Mar 2026 12:20:55 +0100 Subject: [PATCH 07/14] feat(micro-learning): update embedding model handling to support multiple models --- src/app/api/api_v1/endpoints/micro_learning.py | 7 +++++-- src/app/services/helpers.py | 9 +++++---- src/app/services/sql_db/queries.py | 13 +++++++------ 3 files changed, 17 insertions(+), 12 deletions(-) diff --git a/src/app/api/api_v1/endpoints/micro_learning.py b/src/app/api/api_v1/endpoints/micro_learning.py index 460fcc8..b54a7c3 100644 --- a/src/app/api/api_v1/endpoints/micro_learning.py +++ b/src/app/api/api_v1/endpoints/micro_learning.py @@ -30,10 +30,13 @@ async def get_subject_list( lang: str | None = None, sp: SearchService = Depends(get_search_service) ) -> list[str]: - collection_info, model_id = await collection_and_model_id_according_lang( + collection_info, models_ids = await collection_and_model_id_according_lang( sp=sp, lang=lang ) - ret = [md.title for md in get_subjects(embedding_model_id=model_id)] + ret = [ + md.title + for md in get_subjects(embedding_models_ids=[model.id for model in models_ids]) + ] if len(ret) == 0: raise HTTPException(status_code=404, detail="No subjects found.") return ret diff --git a/src/app/services/helpers.py b/src/app/services/helpers.py index a4cee7c..d5374a8 100644 --- a/src/app/services/helpers.py +++ b/src/app/services/helpers.py @@ -8,6 +8,7 @@ from json_repair import JSONReturnType from langdetect import detect_langs from qdrant_client.http.models import models +from welearn_database.data.models import EmbeddingModel from src.app.models.collections import Collection from src.app.models.documents import JourneySectionType @@ -184,7 +185,7 @@ def choose_readability_according_journey_section_type( @log_time_and_error_sync async def collection_and_model_id_according_lang( lang: str | None, sp -) -> tuple[Collection, uuid]: +) -> tuple[Collection, list[EmbeddingModel]]: """ Get the collection info and model id according to the language. Args: @@ -203,9 +204,9 @@ async def collection_and_model_id_according_lang( raise HTTPException( status_code=404, detail=f"No collection found for language '{lang}'." ) - model_id = get_embeddings_model_id_according_name(collection_info.model) - if not model_id: + emb_models = get_embeddings_model_id_according_name(collection_info.model) + if not emb_models: raise ValueError( f"Embedding model '{collection_info.model}' not found in the database." ) - return collection_info, model_id + return collection_info, emb_models diff --git a/src/app/services/sql_db/queries.py b/src/app/services/sql_db/queries.py index 031709b..1af13aa 100644 --- a/src/app/services/sql_db/queries.py +++ b/src/app/services/sql_db/queries.py @@ -173,7 +173,7 @@ def get_subject(subject: str, embedding_model_id: UUID) -> ContextDocument | Non return subject_meta_document -def get_subjects(embedding_model_id: UUID) -> list[ContextDocument]: +def get_subjects(embedding_models_ids: list[UUID]) -> list[ContextDocument]: """ Get all the subject meta documents from the database. Returns: List of subject meta documents. @@ -183,7 +183,7 @@ def get_subjects(embedding_model_id: UUID) -> list[ContextDocument]: session.query(ContextDocument) .filter( ContextDocument.context_type == ContextType.SUBJECT.value.lower(), - ContextDocument.embedding_model_id == embedding_model_id, + ContextDocument.embedding_model_id.in_(embedding_models_ids), ) .all() ) @@ -214,7 +214,9 @@ def get_context_documents( return sdg_meta_documents -def get_embeddings_model_id_according_name(model_name: str) -> list[UUID] | None: +def get_embeddings_model_id_according_name( + model_name: str, +) -> list[EmbeddingModel | None]: """ Get the embeddings model ID according to its name. @@ -225,12 +227,11 @@ def get_embeddings_model_id_according_name(model_name: str) -> list[UUID] | None The ID of the embeddings model if found, otherwise None. """ with session_maker() as session: - model = ( + return ( session.query(EmbeddingModel) .filter(EmbeddingModel.title == model_name) - .first() + .all() ) - return model.id if model else None def write_new_data_quality_error( From cd3a57a6180c98537115d928099b3bc776e49bac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o?= Date: Thu, 5 Mar 2026 15:01:50 +0100 Subject: [PATCH 08/14] feat(micro-learning): update functions to support multiple embedding models --- src/app/api/api_v1/endpoints/micro_learning.py | 8 ++++---- src/app/services/search.py | 8 +++++--- src/app/services/sql_db/queries.py | 12 ++++++++---- 3 files changed, 17 insertions(+), 11 deletions(-) diff --git a/src/app/api/api_v1/endpoints/micro_learning.py b/src/app/api/api_v1/endpoints/micro_learning.py index b54a7c3..ea75559 100644 --- a/src/app/api/api_v1/endpoints/micro_learning.py +++ b/src/app/api/api_v1/endpoints/micro_learning.py @@ -53,13 +53,13 @@ async def get_full_journey( lang: str | None = None, sp: SearchService = Depends(get_search_service), ): - collection_info, model_id = await collection_and_model_id_according_lang( + collection_info, models_id = await collection_and_model_id_according_lang( sp=sp, lang=lang ) - + emb_models_ids = [model.id for model in models_id] journey_part = [i.lower() for i in JourneySectionType] sdg_meta_documents = get_context_documents( - journey_part=journey_part, sdg=sdg, embedding_model_id=model_id + journey_part=journey_part, sdg=sdg, embedding_models_ids=emb_models_ids ) if not sdg_meta_documents: raise HTTPException( @@ -67,7 +67,7 @@ async def get_full_journey( ) subject_meta_document: ContextDocument | None = get_subject( - subject=subject, embedding_model_id=model_id + subject=subject, embedding_models_ids=emb_models_ids ) if not subject_meta_document: diff --git a/src/app/services/search.py b/src/app/services/search.py index d52fb56..d98d461 100644 --- a/src/app/services/search.py +++ b/src/app/services/search.py @@ -25,7 +25,7 @@ from src.app.services.data_quality import DataQualityChecker from src.app.services.exceptions import CollectionNotFoundError, ModelNotFoundError from src.app.services.helpers import convert_embedding_bytes -from src.app.services.sql_db.queries import get_subject +from src.app.services.sql_db.queries import get_subject, get_embeddings_model_id_according_name from src.app.utils.decorators import log_time_and_error, log_time_and_error_sync from src.app.utils.logger import logger as logger_utils @@ -443,7 +443,7 @@ def concatenate_same_doc_id_slices( @log_time_and_error_sync -def get_subject_vector(subject: str | None) -> list[float] | None: +def get_subject_vector(subject: str | None, model_name: str) -> list[float] | None: """ Get the subject vector from the database. Args: @@ -455,7 +455,9 @@ def get_subject_vector(subject: str | None) -> list[float] | None: if not subject: return None - subject_from_db = get_subject(subject=subject) + emb_models = get_embeddings_model_id_according_name(model_name) + + subject_from_db = get_subject(subject=subject, embedding_models_ids=[m.id for m in emb_models]) if not subject_from_db: return None diff --git a/src/app/services/sql_db/queries.py b/src/app/services/sql_db/queries.py index 1af13aa..370c44a 100644 --- a/src/app/services/sql_db/queries.py +++ b/src/app/services/sql_db/queries.py @@ -151,10 +151,13 @@ def register_endpoint(endpoint, session_id, http_code): session.commit() -def get_subject(subject: str, embedding_model_id: UUID) -> ContextDocument | None: +def get_subject( + subject: str, embedding_models_ids: list[UUID] +) -> ContextDocument | None: """ Get the subject meta document from the database. Args: + embedding_models_ids: Database IDs of embeddings models used for vectorize documents subject: The subject to get. Returns: The subject meta document. @@ -166,7 +169,7 @@ def get_subject(subject: str, embedding_model_id: UUID) -> ContextDocument | Non .filter( ContextDocument.context_type == ContextType.SUBJECT.value.lower(), ContextDocument.title == subject, - ContextDocument.embedding_model_id == embedding_model_id, + ContextDocument.embedding_model_id.in_(embedding_models_ids), ) .first() ) @@ -191,12 +194,13 @@ def get_subjects(embedding_models_ids: list[UUID]) -> list[ContextDocument]: def get_context_documents( - journey_part: JourneySection, sdg: int, embedding_model_id: UUID + journey_part: list[JourneySection], sdg: int, embedding_models_ids: list[UUID] ): """ Get the context documents from the database. Args: + embedding_models_ids: Database IDs of embeddings models used for vectorize documents journey_part: The journey part to get the context documents for. sdg: The SDG to get the context documents for. Returns: List of context documents. @@ -207,7 +211,7 @@ def get_context_documents( .filter( ContextDocument.context_type.in_(journey_part), ContextDocument.sdg_related.contains([sdg]), - ContextDocument.embedding_model_id == embedding_model_id, + ContextDocument.embedding_model_id.in_(embedding_models_ids), ) .all() ) From 94291d9893c9663fd6a1db133ef754d4d0979302 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o?= Date: Thu, 5 Mar 2026 15:02:19 +0100 Subject: [PATCH 09/14] type ignore --- src/app/services/search.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/app/services/search.py b/src/app/services/search.py index d98d461..14cde2d 100644 --- a/src/app/services/search.py +++ b/src/app/services/search.py @@ -25,7 +25,10 @@ from src.app.services.data_quality import DataQualityChecker from src.app.services.exceptions import CollectionNotFoundError, ModelNotFoundError from src.app.services.helpers import convert_embedding_bytes -from src.app.services.sql_db.queries import get_subject, get_embeddings_model_id_according_name +from src.app.services.sql_db.queries import ( + get_embeddings_model_id_according_name, + get_subject, +) from src.app.utils.decorators import log_time_and_error, log_time_and_error_sync from src.app.utils.logger import logger as logger_utils @@ -457,7 +460,9 @@ def get_subject_vector(subject: str | None, model_name: str) -> list[float] | No emb_models = get_embeddings_model_id_according_name(model_name) - subject_from_db = get_subject(subject=subject, embedding_models_ids=[m.id for m in emb_models]) + subject_from_db = get_subject( + subject=subject, embedding_models_ids=[m.id for m in emb_models] # type: ignore + ) if not subject_from_db: return None From c04374e8914453da1b720fb83bc7fb8b54817b48 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o?= Date: Thu, 5 Mar 2026 15:50:24 +0100 Subject: [PATCH 10/14] refactor(queries): clean up unused imports in queries.py --- src/app/services/sql_db/queries.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/app/services/sql_db/queries.py b/src/app/services/sql_db/queries.py index 370c44a..f70abf9 100644 --- a/src/app/services/sql_db/queries.py +++ b/src/app/services/sql_db/queries.py @@ -7,7 +7,7 @@ from qdrant_client.http.models import ScoredPoint from sqlalchemy import func, select from welearn_database.data.enumeration import Step -from welearn_database.data.models import ( +from welearn_database.data.models import ( # FilterType,; FilterUsedInQuery, ChatMessage, ContextDocument, Corpus, @@ -17,8 +17,6 @@ EmbeddingModel, EndpointRequest, ErrorDataQuality, - FilterType, - FilterUsedInQuery, ProcessState, QtyDocumentInQdrant, QtyDocumentInQdrantPerCorpus, From ebe50045128a15369d405bb4f5e0f9365c694e60 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o?= Date: Mon, 9 Mar 2026 16:02:40 +0100 Subject: [PATCH 11/14] Fix test --- src/app/tests/api/api_v1/test_micro_learning.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/app/tests/api/api_v1/test_micro_learning.py b/src/app/tests/api/api_v1/test_micro_learning.py index fcbdfd5..7162a24 100644 --- a/src/app/tests/api/api_v1/test_micro_learning.py +++ b/src/app/tests/api/api_v1/test_micro_learning.py @@ -1,8 +1,9 @@ import unittest +import uuid from unittest import mock from fastapi.testclient import TestClient -from welearn_database.data.models import ContextDocument +from welearn_database.data.models import ContextDocument, EmbeddingModel from src.app.core.config import settings from src.app.models.collections import Collection @@ -38,7 +39,7 @@ async def test_get_full_journey( with TestClient(app) as client: mock_collection_and_model_id_according_lang.return_value = ( Collection(name="test_collection", lang="en", model="test_model"), - "model_id", + [EmbeddingModel(id=uuid.uuid4(), title="test_model", lang="en")], ) # Mock data mock_get_context_docs.return_value = [ From be8668e47f41f61b3da092e230e79bb1824e3a46 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o?= Date: Mon, 9 Mar 2026 16:09:25 +0100 Subject: [PATCH 12/14] fix(micro-learning): pass model to get_subject_vector in search.py --- src/app/services/search.py | 4 +++- src/app/tests/api/api_v1/test_micro_learning.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/app/services/search.py b/src/app/services/search.py index 14cde2d..d30c62b 100644 --- a/src/app/services/search.py +++ b/src/app/services/search.py @@ -256,7 +256,9 @@ async def search_handler( assert isinstance(qp.query, str) collection = await self.get_collection_by_language(lang="mul") - subject_vector = await run_in_threadpool(get_subject_vector, qp.subject) + subject_vector = await run_in_threadpool( + get_subject_vector, qp.subject, collection.model + ) embedding = await self.get_query_embed( model=collection.model, query=qp.query, diff --git a/src/app/tests/api/api_v1/test_micro_learning.py b/src/app/tests/api/api_v1/test_micro_learning.py index 7162a24..025c729 100644 --- a/src/app/tests/api/api_v1/test_micro_learning.py +++ b/src/app/tests/api/api_v1/test_micro_learning.py @@ -115,7 +115,7 @@ async def test_get_subject_list( ] mock_collection_and_model_id_according_lang.return_value = ( None, - "model_id", + [], ) # API call From 23908cf17126401bdb341ecf660c7d4fce6ca1cf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o?= Date: Tue, 10 Mar 2026 10:49:53 +0100 Subject: [PATCH 13/14] refactor: organize imports in agents.py, helpers.py, and queries.py --- src/app/services/helpers.py | 1 - src/app/services/sql_db/queries.py | 4 +++- src/app/services/tutor/agents.py | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/app/services/helpers.py b/src/app/services/helpers.py index d5374a8..9fce3d5 100644 --- a/src/app/services/helpers.py +++ b/src/app/services/helpers.py @@ -1,4 +1,3 @@ -import uuid from functools import cache from typing import Any, List diff --git a/src/app/services/sql_db/queries.py b/src/app/services/sql_db/queries.py index f70abf9..370c44a 100644 --- a/src/app/services/sql_db/queries.py +++ b/src/app/services/sql_db/queries.py @@ -7,7 +7,7 @@ from qdrant_client.http.models import ScoredPoint from sqlalchemy import func, select from welearn_database.data.enumeration import Step -from welearn_database.data.models import ( # FilterType,; FilterUsedInQuery, +from welearn_database.data.models import ( ChatMessage, ContextDocument, Corpus, @@ -17,6 +17,8 @@ EmbeddingModel, EndpointRequest, ErrorDataQuality, + FilterType, + FilterUsedInQuery, ProcessState, QtyDocumentInQdrant, QtyDocumentInQdrantPerCorpus, diff --git a/src/app/services/tutor/agents.py b/src/app/services/tutor/agents.py index 6bb1947..0c73a7a 100644 --- a/src/app/services/tutor/agents.py +++ b/src/app/services/tutor/agents.py @@ -1,5 +1,5 @@ -import time import json +import time from pathlib import Path from langchain_core.language_models import BaseChatModel From e0435feca3a3f57450ea31191db158298bbec695 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o?= <133012334+lpi-tn@users.noreply.github.com> Date: Tue, 10 Mar 2026 11:13:10 +0100 Subject: [PATCH 14/14] Apply suggestions from code review Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- src/app/services/search.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/app/services/search.py b/src/app/services/search.py index d30c62b..fdad4c1 100644 --- a/src/app/services/search.py +++ b/src/app/services/search.py @@ -461,9 +461,13 @@ def get_subject_vector(subject: str | None, model_name: str) -> list[float] | No return None emb_models = get_embeddings_model_id_according_name(model_name) + embedding_model_ids = [m.id for m in emb_models if m is not None] + if not embedding_model_ids: + return None subject_from_db = get_subject( - subject=subject, embedding_models_ids=[m.id for m in emb_models] # type: ignore + subject=subject, + embedding_models_ids=embedding_model_ids, ) if not subject_from_db: return None