Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions src/app/api/api_v1/endpoints/micro_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,13 @@
async def get_subject_list(
lang: str | None = None, sp: SearchService = Depends(get_search_service)
) -> list[str]:
collection_info, model_id = await collection_and_model_id_according_lang(
collection_info, models_ids = await collection_and_model_id_according_lang(
sp=sp, lang=lang
)
ret = [md.title for md in get_subjects(embedding_model_id=model_id)]
ret = [
md.title
for md in get_subjects(embedding_models_ids=[model.id for model in models_ids])
]
if len(ret) == 0:
raise HTTPException(status_code=404, detail="No subjects found.")
return ret
Expand All @@ -50,21 +53,21 @@ async def get_full_journey(
lang: str | None = None,
sp: SearchService = Depends(get_search_service),
):
collection_info, model_id = await collection_and_model_id_according_lang(
collection_info, models_id = await collection_and_model_id_according_lang(
sp=sp, lang=lang
)

emb_models_ids = [model.id for model in models_id]
journey_part = [i.lower() for i in JourneySectionType]
sdg_meta_documents = get_context_documents(
journey_part=journey_part, sdg=sdg, embedding_model_id=model_id
journey_part=journey_part, sdg=sdg, embedding_models_ids=emb_models_ids
)
if not sdg_meta_documents:
raise HTTPException(
status_code=404, detail=f"SDG '{sdg}' not found in meta documents."
)

subject_meta_document: ContextDocument | None = get_subject(
subject=subject, embedding_model_id=model_id
subject=subject, embedding_models_ids=emb_models_ids
)

if not subject_meta_document:
Expand Down
10 changes: 5 additions & 5 deletions src/app/services/helpers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import uuid
from functools import cache
from typing import Any, List

Expand All @@ -8,6 +7,7 @@
from json_repair import JSONReturnType
from langdetect import detect_langs
from qdrant_client.http.models import models
from welearn_database.data.models import EmbeddingModel

from src.app.models.collections import Collection
from src.app.models.documents import JourneySectionType
Expand Down Expand Up @@ -184,7 +184,7 @@ def choose_readability_according_journey_section_type(
@log_time_and_error_sync
async def collection_and_model_id_according_lang(
lang: str | None, sp
) -> tuple[Collection, uuid]:
) -> tuple[Collection, list[EmbeddingModel]]:
"""
Get the collection info and model id according to the language.
Args:
Expand All @@ -203,9 +203,9 @@ async def collection_and_model_id_according_lang(
raise HTTPException(
status_code=404, detail=f"No collection found for language '{lang}'."
)
model_id = get_embeddings_model_id_according_name(collection_info.model)
if not model_id:
emb_models = get_embeddings_model_id_according_name(collection_info.model)
if not emb_models:
raise ValueError(
f"Embedding model '{collection_info.model}' not found in the database."
)
return collection_info, model_id
return collection_info, emb_models
21 changes: 17 additions & 4 deletions src/app/services/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,10 @@
from src.app.services.data_quality import DataQualityChecker
from src.app.services.exceptions import CollectionNotFoundError, ModelNotFoundError
from src.app.services.helpers import convert_embedding_bytes
from src.app.services.sql_db.queries import get_subject
from src.app.services.sql_db.queries import (
get_embeddings_model_id_according_name,
get_subject,
)
from src.app.utils.decorators import log_time_and_error, log_time_and_error_sync
from src.app.utils.logger import logger as logger_utils

Expand Down Expand Up @@ -253,7 +256,9 @@ async def search_handler(
assert isinstance(qp.query, str)

collection = await self.get_collection_by_language(lang="mul")
subject_vector = await run_in_threadpool(get_subject_vector, qp.subject)
subject_vector = await run_in_threadpool(
get_subject_vector, qp.subject, collection.model
)
embedding = await self.get_query_embed(
model=collection.model,
query=qp.query,
Expand Down Expand Up @@ -443,7 +448,7 @@ def concatenate_same_doc_id_slices(


@log_time_and_error_sync
def get_subject_vector(subject: str | None) -> list[float] | None:
def get_subject_vector(subject: str | None, model_name: str) -> list[float] | None:
"""
Get the subject vector from the database.
Args:
Expand All @@ -455,7 +460,15 @@ def get_subject_vector(subject: str | None) -> list[float] | None:
if not subject:
return None

subject_from_db = get_subject(subject=subject)
emb_models = get_embeddings_model_id_according_name(model_name)
embedding_model_ids = [m.id for m in emb_models if m is not None]
if not embedding_model_ids:
return None

subject_from_db = get_subject(
subject=subject,
embedding_models_ids=embedding_model_ids,
)
if not subject_from_db:
return None

Expand Down
30 changes: 15 additions & 15 deletions src/app/services/sql_db/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
from src.app.services.constants import APP_NAME
from src.app.services.sql_db.sql_service import session_maker

model_id_cache: dict[str, UUID] = {}
model_id_lock = Lock()


Expand Down Expand Up @@ -152,10 +151,13 @@ def register_endpoint(endpoint, session_id, http_code):
session.commit()


def get_subject(subject: str, embedding_model_id: UUID) -> ContextDocument | None:
def get_subject(
subject: str, embedding_models_ids: list[UUID]
) -> ContextDocument | None:
"""
Get the subject meta document from the database.
Args:
embedding_models_ids: Database IDs of embeddings models used for vectorize documents
subject: The subject to get.

Returns: The subject meta document.
Expand All @@ -167,14 +169,14 @@ def get_subject(subject: str, embedding_model_id: UUID) -> ContextDocument | Non
.filter(
ContextDocument.context_type == ContextType.SUBJECT.value.lower(),
ContextDocument.title == subject,
ContextDocument.embedding_model_id == embedding_model_id,
ContextDocument.embedding_model_id.in_(embedding_models_ids),
)
.first()
)
return subject_meta_document


def get_subjects(embedding_model_id: UUID) -> list[ContextDocument]:
def get_subjects(embedding_models_ids: list[UUID]) -> list[ContextDocument]:
"""
Get all the subject meta documents from the database.
Returns: List of subject meta documents.
Expand All @@ -184,20 +186,21 @@ def get_subjects(embedding_model_id: UUID) -> list[ContextDocument]:
session.query(ContextDocument)
.filter(
ContextDocument.context_type == ContextType.SUBJECT.value.lower(),
ContextDocument.embedding_model_id == embedding_model_id,
ContextDocument.embedding_model_id.in_(embedding_models_ids),
)
.all()
)
return sdg_meta_documents


def get_context_documents(
journey_part: JourneySection, sdg: int, embedding_model_id: UUID
journey_part: list[JourneySection], sdg: int, embedding_models_ids: list[UUID]
):
"""
Get the context documents from the database.

Args:
embedding_models_ids: Database IDs of embeddings models used for vectorize documents
journey_part: The journey part to get the context documents for.
sdg: The SDG to get the context documents for.
Returns: List of context documents.
Expand All @@ -208,14 +211,16 @@ def get_context_documents(
.filter(
ContextDocument.context_type.in_(journey_part),
ContextDocument.sdg_related.contains([sdg]),
ContextDocument.embedding_model_id == embedding_model_id,
ContextDocument.embedding_model_id.in_(embedding_models_ids),
)
.all()
)
return sdg_meta_documents


def get_embeddings_model_id_according_name(model_name: str) -> UUID | None:
def get_embeddings_model_id_according_name(
model_name: str,
) -> list[EmbeddingModel | None]:
"""
Get the embeddings model ID according to its name.

Expand All @@ -225,17 +230,12 @@ def get_embeddings_model_id_according_name(model_name: str) -> UUID | None:
Returns:
The ID of the embeddings model if found, otherwise None.
"""
with model_id_lock:
if model_name in model_id_cache:
return model_id_cache[model_name]

with session_maker() as session:
model = (
return (
session.query(EmbeddingModel)
.filter(EmbeddingModel.title == model_name)
.first()
.all()
)
return model.id if model else None


def write_new_data_quality_error(
Expand Down
2 changes: 1 addition & 1 deletion src/app/services/tutor/agents.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import time
import json
import time
from pathlib import Path

from langchain_core.language_models import BaseChatModel
Expand Down
7 changes: 4 additions & 3 deletions src/app/tests/api/api_v1/test_micro_learning.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import unittest
import uuid
from unittest import mock

from fastapi.testclient import TestClient
from welearn_database.data.models import ContextDocument
from welearn_database.data.models import ContextDocument, EmbeddingModel

from src.app.core.config import settings
from src.app.models.collections import Collection
Expand Down Expand Up @@ -38,7 +39,7 @@ async def test_get_full_journey(
with TestClient(app) as client:
mock_collection_and_model_id_according_lang.return_value = (
Collection(name="test_collection", lang="en", model="test_model"),
"model_id",
[EmbeddingModel(id=uuid.uuid4(), title="test_model", lang="en")],
)
# Mock data
mock_get_context_docs.return_value = [
Expand Down Expand Up @@ -114,7 +115,7 @@ async def test_get_subject_list(
]
mock_collection_and_model_id_according_lang.return_value = (
None,
"model_id",
[],
)

# API call
Expand Down