Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,10 @@ COPY . .

ENV PYTHONPATH=/app

ENV HF_HOME="/app/hf_cache"

# 4. RUN THE PRELOAD SCRIPT
# This downloads the 80MB model and saves it into the image layers
RUN python app/preload_model.py

CMD ["./docker-entrypoint.sh"]
36 changes: 36 additions & 0 deletions app/api/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from app.core.auth.auth0_middleware import Auth0Middleware
from app.core.llm.vertex_ai_llama import VertexAILlamaProvider
from app.core.llm.together_ai_llama import TogetherAIProvider

# from app.db.session import get_session
from app.models.domain.user import User
Expand Down Expand Up @@ -177,6 +178,15 @@ async def get_llm_provider():
raise


async def get_together_llm_provider():
    """FastAPI dependency that builds a Together AI LLM provider from settings.

    Returns:
        TogetherAIProvider: a configured provider instance.

    Raises:
        Exception: re-raised after logging when construction fails
            (e.g. TOGETHER_API_KEY missing from settings).
    """
    try:
        return TogetherAIProvider(settings)
    except Exception as e:
        # Name the concrete provider so this log line is distinguishable from
        # get_llm_provider()'s otherwise-identical failure message.
        logger.error(f"Failed to initialize Together AI LLM provider: {str(e)}", exc_info=True)
        raise


async def get_web_search_service(
domain_service: DomainService = Depends(get_domain_service),
source_repository: SourceRepository = Depends(get_source_repository),
Expand Down Expand Up @@ -217,6 +227,32 @@ async def get_orchestrator_service(
)


async def get_together_orchestrator_service(
    claim_repository: ClaimRepository = Depends(get_claim_repository),
    analysis_repository: AnalysisRepository = Depends(get_analysis_repository),
    conversation_repository: ConversationRepository = Depends(get_conversation_repository),
    claim_conversation_repository: ClaimConversationRepository = Depends(get_claim_conversation_repository),
    message_repository: MessageRepository = Depends(get_message_repository),
    source_repository: SourceRepository = Depends(get_source_repository),
    search_repository: SearchRepository = Depends(get_search_repository),
    web_search_service: WebSearchServiceInterface = Depends(get_web_search_service),
    llm_provider=Depends(get_together_llm_provider),
) -> AnalysisOrchestrator:
    """FastAPI dependency that assembles an AnalysisOrchestrator backed by Together AI.

    All collaborators are injected via Depends; the LLM provider comes from
    get_together_llm_provider so its construction (and error logging) happens
    in exactly one place.
    """
    # NOTE: do not re-instantiate TogetherAIProvider here — the injected
    # llm_provider is already built; constructing a second one would discard
    # the dependency-injected instance and bypass its error handling.
    return AnalysisOrchestrator(
        claim_repo=claim_repository,
        analysis_repo=analysis_repository,
        conversation_repo=conversation_repository,
        claim_conversation_repo=claim_conversation_repository,
        message_repo=message_repository,
        source_repo=source_repository,
        search_repo=search_repository,
        web_search_service=web_search_service,
        llm_provider=llm_provider,
    )


async def get_serper_orchestrator_service(
claim_repository: ClaimRepository = Depends(get_claim_repository),
analysis_repository: AnalysisRepository = Depends(get_analysis_repository),
Expand Down
3 changes: 2 additions & 1 deletion app/api/endpoints/analysis_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
get_orchestrator_service,
get_current_user,
get_claim_service,
get_together_orchestrator_service,
)
from app.models.domain.user import User
from app.services.analysis_service import AnalysisService
Expand Down Expand Up @@ -93,7 +94,7 @@ async def stream_claim_analysis_exp(
request: Request,
claim_id: UUID,
current_user: User = Depends(get_current_user),
analysis_orchestrator: AnalysisOrchestrator = Depends(get_orchestrator_service),
analysis_orchestrator: AnalysisOrchestrator = Depends(get_together_orchestrator_service),
claim_service: ClaimService = Depends(get_claim_service),
) -> StreamingResponse:
"""Stream the analysis process for a claim in real-time."""
Expand Down
3 changes: 3 additions & 0 deletions app/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ class Settings(BaseSettings):

SERPER_API_KEY: str = ""

TOGETHER_API_KEY: str = ""

LLAMA_MODEL_NAME: str = "meta/llama-3.3-70b-instruct-maas"

AUTH0_DOMAIN: str = "veri-fact.ca.auth0.com"
Expand Down Expand Up @@ -76,6 +78,7 @@ def mask_password_in_url(url: str) -> str:
print(f"GOOGLE_SEARCH_ENGINE_ID: {self.GOOGLE_SEARCH_ENGINE_ID}")
print(f"Google Search API configured: {bool(self.GOOGLE_SEARCH_API_KEY)}")
print(f"Serper Search API configured: {bool(self.SERPER_API_KEY)}")
print(f"Together AI API configured: {bool(self.TOGETHER_API_KEY)}")
print(f"LLAMA_MODEL_NAME: {self.LLAMA_MODEL_NAME}")
print("=====================================\n")

Expand Down
128 changes: 128 additions & 0 deletions app/core/llm/together_ai_llama.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
import logging
import math
from typing import AsyncGenerator, List
from datetime import datetime, timezone
import openai

from app.core.llm.interfaces import LLMProvider
from app.core.llm.messages import Message, Response, ResponseChunk

logger = logging.getLogger(__name__)


class TogetherAIProvider(LLMProvider):
    """LLM provider backed by Together AI's OpenAI-compatible chat API.

    Uses the async OpenAI SDK client so that the async generate_response /
    generate_stream methods do not block the event loop with synchronous
    HTTP calls. Requests always set ``logprobs=1`` so that a confidence
    score can be derived from token log-probabilities.
    """

    def __init__(self, settings):
        """Configure the Together AI client.

        Args:
            settings: application settings; must provide TOGETHER_API_KEY.

        Raises:
            ValueError: if TOGETHER_API_KEY is not set.
            Exception: re-raised after logging for any other init failure.
        """
        try:
            self.api_key = settings.TOGETHER_API_KEY
            self.base_url = "https://api.together.xyz/v1"
            self.model_id = "meta-llama/Llama-3.3-70B-Instruct-Turbo"

            if not self.api_key:
                raise ValueError("TOGETHER_API_KEY is not set in settings")

            logger.info(f"Initializing Together AI provider with model: {self.model_id}")

            # Together AI is wire-compatible with the OpenAI SDK. Use the
            # *async* client: the provider's public methods are coroutines,
            # and the sync client would block the event loop per request.
            self.client = openai.AsyncOpenAI(
                base_url=self.base_url,
                api_key=self.api_key,
            )

            logger.info("Successfully initialized Together AI provider")

        except Exception as e:
            logger.error(f"Failed to initialize Together AI provider: {str(e)}", exc_info=True)
            raise

    def _calculate_confidence(self, logprobs: List[float] | None) -> float:
        """Return an average confidence score in [0.0, 1.0] from token logprobs.

        The mean of the log-probabilities (length-normalized) is exponentiated
        back into probability space — i.e. the geometric mean of the per-token
        probabilities. Returns 0.0 for missing/empty input or math errors.
        """
        if not logprobs:
            return 0.0

        try:
            # Length normalization: mean logprob, then back to probability space.
            avg_logprob = sum(logprobs) / len(logprobs)
            return math.exp(avg_logprob)

        except Exception as e:
            logger.warning(f"Math error calculating confidence: {e}")
            return 0.0

    async def generate_response(self, messages: List[Message], temperature: float = 0.7) -> Response:
        """Run a single (non-streaming) chat completion.

        Args:
            messages: conversation history as provider-agnostic Message objects.
            temperature: sampling temperature forwarded to the API.

        Returns:
            Response with the generated text, a logprob-derived confidence
            score, and model/usage metadata.

        Raises:
            Exception: re-raised after logging on any API failure.
        """
        try:
            logger.debug(f"Generating response with temperature {temperature}")

            response = await self.client.chat.completions.create(
                model=self.model_id,
                messages=[{"role": m.role, "content": m.content} for m in messages],
                temperature=temperature,
                # Request per-token logprobs: needed for the confidence score.
                logprobs=1,
            )

            choice = response.choices[0]

            # Together returns logprobs under choice.logprobs.token_logprobs;
            # both lookups are defensive in case the field is absent.
            logprobs_obj = getattr(choice, "logprobs", None)
            logprobs = getattr(logprobs_obj, "token_logprobs", None)

            confidence = self._calculate_confidence(logprobs)

            return Response(
                text=choice.message.content,
                confidence_score=confidence,
                created_at=datetime.now(timezone.utc),
                metadata={
                    "model": self.model_id,
                    "finish_reason": choice.finish_reason,
                    "usage": response.usage.model_dump() if response.usage else None,
                    # Full logprobs object kept for per-token debugging later.
                    "raw_logprobs": logprobs_obj if logprobs_obj else None,
                },
            )

        except Exception as e:
            logger.error(f"Error generating response: {str(e)}", exc_info=True)
            raise

    async def generate_stream(
        self, messages: List[Message], temperature: float = 0.7
    ) -> AsyncGenerator[ResponseChunk, None]:
        """Stream a chat completion as ResponseChunk objects.

        Yields one chunk per content delta (is_complete=False), each carrying
        the raw chunk-level logprobs object in its metadata, then a final
        empty chunk with is_complete=True.

        Raises:
            Exception: re-raised after logging on any API failure.
        """
        try:
            logger.debug("Starting stream generation")

            stream = await self.client.chat.completions.create(
                model=self.model_id,
                messages=[{"role": m.role, "content": m.content} for m in messages],
                temperature=temperature,
                stream=True,
                logprobs=1,
            )

            async for chunk in stream:
                if chunk.choices and chunk.choices[0].delta.content:
                    content = chunk.choices[0].delta.content

                    # Chunk-level logprobs object (may be None for some chunks).
                    chunk_logprobs = getattr(chunk.choices[0], "logprobs", None)

                    yield ResponseChunk(
                        text=content, is_complete=False, metadata={"model": self.model_id, "logprobs": chunk_logprobs}
                    )

            yield ResponseChunk(text="", is_complete=True, metadata={"model": self.model_id})

        except Exception as e:
            logger.error(f"Error in generate_stream: {str(e)}", exc_info=True)
            raise
11 changes: 11 additions & 0 deletions app/preload_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# app/preload_model.py
"""Bake the sentence-transformers embedding model into the Docker image.

Run at image build time (see Dockerfile) so the ~80MB model is cached in an
image layer instead of being downloaded on every container start.
"""
import os

# Point the Hugging Face cache inside the image BEFORE importing
# sentence_transformers: huggingface_hub resolves HF_HOME-derived cache
# paths at import time, so setting the variable after the import can
# silently have no effect and the model would land in the default cache.
os.environ["HF_HOME"] = "/app/hf_cache"

from sentence_transformers import SentenceTransformer  # noqa: E402

print("⏳ Downloading embedding model to build...")
# Instantiating the model triggers the download into HF_HOME.
model = SentenceTransformer("all-MiniLM-L6-v2")
print("✅ Model downloaded successfully.")
33 changes: 25 additions & 8 deletions app/services/analysis_orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import json
import re
from copy import deepcopy
import math

from app.core.exceptions import NotAuthorizedException, NotFoundException, ValidationError
from app.core.llm.interfaces import LLMProvider
Expand Down Expand Up @@ -193,13 +194,17 @@ async def _generate_analysis(
messages += [LLMMessage(role="user", content=AnalysisPrompt.GET_VERACITY_FR)]

analysis_text = []
log_probs = []

async for chunk in self._llm.generate_stream(messages):
if not chunk.is_complete:
analysis_text.append(chunk.text)
log_probs.append(chunk.metadata.get("logprobs"))
yield {"type": "content", "content": chunk.text}
else:
full_text = "".join(analysis_text)
logger.warning(f"length {len(analysis_text)}, {analysis_text}")
logger.warning(f"length {len(log_probs)}, {log_probs}")

try:
# Clean the text before parsing
Expand Down Expand Up @@ -238,14 +243,10 @@ async def _generate_analysis(
current_analysis.updated_at = datetime.now(UTC)

if not default:
con_score = await self._generate_confidence_score(
statement=claim_text,
analysis=analysis_content,
sources=sources_text,
veracity=veracity_score,
)
logger.warning(con_score)
current_analysis.confidence_score = float(con_score) / 100.0

con_score = await self._generate_logprob_confidence_score(log_probs=log_probs)
logger.info(con_score)
current_analysis.confidence_score = float(con_score)

updated_analysis = await self._analysis_repo.update(current_analysis)

Expand Down Expand Up @@ -707,6 +708,22 @@ def _query_initial(self, statement: str, language: str):
else:
raise ValidationError("Claim Language is invalid")

async def _generate_logprob_confidence_score(self, log_probs: list[float]):
if not log_probs:
return 0.0

try:
# 1. Calculate the average of the log-probabilities
# (This represents the geometric mean in log-space)
avg_logprob = sum(log_probs) / len(log_probs)

# 2. Convert back to linear probability (0.0 to 1.0)
return math.exp(avg_logprob)

except Exception:
# Safety fallback for edge cases like overflow (unlikely with logprobs)
return 0.0

async def _generate_confidence_score(self, statement: str, analysis: str, sources: str, veracity: str):
messages = [
LLMMessage(
Expand Down
1 change: 1 addition & 0 deletions docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ services:
- .:/app
environment:
- DATABASE_URL=${DATABASE_URL:-postgresql://will:nordai123@misinformation_mitigation_db/mitigation_misinformation_db}
- HF_HOME=/app/hf_cache
depends_on:
misinformation_mitigation_db:
condition: service_healthy
Expand Down
Binary file not shown.
Loading