diff --git a/app/api/endpoints/claim_endpoints.py b/app/api/endpoints/claim_endpoints.py index bc3fe35..8ecfcb4 100644 --- a/app/api/endpoints/claim_endpoints.py +++ b/app/api/endpoints/claim_endpoints.py @@ -25,6 +25,7 @@ from app.services.analysis_orchestrator import AnalysisOrchestrator from app.core.exceptions import NotFoundException, NotAuthorizedException from app.services.interfaces.embedding_generator import EmbeddingGeneratorInterface +from app.core.exceptions import MonthlyLimitExceededError router = APIRouter(prefix="/claims", tags=["claims"]) logger = logging.getLogger(__name__) @@ -45,8 +46,12 @@ async def create_claim( language=data.language, batch_user_id=data.batch_user_id, batch_post_id=data.batch_post_id, + auth0_id=current_user.auth0_id, ) return ClaimRead.model_validate(claim) + except MonthlyLimitExceededError: + # We don't have 'e.limit' anymore, so we just say "Limit reached" + raise HTTPException(status_code=429, detail="You have reached your monthly claim limit.") except Exception as e: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to create claim: {str(e)}" @@ -65,12 +70,19 @@ async def create_claims_batch( raise HTTPException(status_code=400, detail="Maximum of 100 claims allowed.") try: - created_claims = await claim_service.create_claims_batch(claims, current_user.id) + created_claims = await claim_service.create_claims_batch( + claims, + current_user.id, + auth0_id=current_user.auth0_id, + ) claim_ids = [str(claim.id) for claim in created_claims] background_tasks.add_task( claim_service.process_claims_batch_async, created_claims, current_user.id, analysis_orchestrator ) return {"message": f"Processing {len(created_claims)} claims in the background.", "claim_ids": claim_ids} + except MonthlyLimitExceededError: + # We don't have 'e.limit' anymore, so we just say "Limit reached" + raise HTTPException(status_code=429, detail="You have reached your monthly claim limit.") except Exception as e: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to queue batch: {str(e)}" diff --git a/app/core/exceptions.py b/app/core/exceptions.py index 2f1a033..bd5e34a 100644 --- a/app/core/exceptions.py +++ b/app/core/exceptions.py @@ -38,6 +38,17 @@ class DuplicateUserError(Exception): pass +""" +Claim exceptions +""" + + +class MonthlyLimitExceededError(Exception): + """Raised when a user hits their monthly claim limit.""" + + pass + + """ Feedback exceptions """ diff --git a/app/models/domain/analysis.py b/app/models/domain/analysis.py index b107940..faa70be 100644 --- a/app/models/domain/analysis.py +++ b/app/models/domain/analysis.py @@ -1,6 +1,6 @@ from dataclasses import dataclass from datetime import datetime -from typing import Optional, List, Dict +from typing import Optional, List from uuid import UUID import pickle @@ -11,10 +11,8 @@ @dataclass class LogProbsData: - anth_conf_score: float tokens: List[str] probs: List[float] - alternatives: List[Dict[str, float]] @dataclass diff --git a/app/repositories/implementations/claim_repository.py b/app/repositories/implementations/claim_repository.py index 4ce5814..72a300f 100644 --- a/app/repositories/implementations/claim_repository.py +++ b/app/repositories/implementations/claim_repository.py @@ -3,7 +3,7 @@ from uuid import UUID from sqlalchemy import select, func, and_ from sqlalchemy.ext.asyncio import AsyncSession -from datetime import datetime +from datetime import datetime, UTC from app.models.database.models import ClaimModel, ClaimStatus from app.models.domain.claim import Claim @@ -94,6 +94,22 @@ async def get_claims_in_date_range(self, start_date: datetime, end_date: datetim result = await self._session.execute(stmt) return [self._to_domain(claim) for claim in result.scalars().all()] + async def get_monthly_claim_count(self, user_id: str) -> int: + """Counts how many claims a user has created this month.""" + + # Calculate the 1st of the current month + now = datetime.now(UTC) + start_of_month = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0) + + query = ( + select(func.count()) + .select_from(ClaimModel) + .where(ClaimModel.user_id == user_id, ClaimModel.created_at >= start_of_month) + ) + + result = await self._session.execute(query) + return result.scalar_one() + async def insert_many(self, claim: List[Claim]) -> List[Claim]: models = [self._to_model(claim) for claim in claim] self._session.add_all(models) diff --git a/app/services/analysis_orchestrator.py b/app/services/analysis_orchestrator.py index 7ecb9e7..1a9e2f3 100644 --- a/app/services/analysis_orchestrator.py +++ b/app/services/analysis_orchestrator.py @@ -248,7 +248,8 @@ async def _generate_analysis( logger.info(con_score) current_analysis.confidence_score = float(con_score) # log_data = await self._get_anth_confidence_score(statement=claim_text, veracity_score=veracity_score) - # current_analysis.log_probs = log_data + log_data = LogProbsData(tokens=analysis_text, probs=log_probs) + current_analysis.log_probs = log_data updated_analysis = await self._analysis_repo.update(current_analysis) diff --git a/app/services/claim_service.py b/app/services/claim_service.py index 21786f1..62a8b16 100644 --- a/app/services/claim_service.py +++ b/app/services/claim_service.py @@ -17,6 +17,7 @@ from app.repositories.implementations.claim_repository import ClaimRepository from app.repositories.implementations.analysis_repository import AnalysisRepository from app.services.analysis_orchestrator import AnalysisOrchestrator +from app.core.exceptions import MonthlyLimitExceededError from app.core.exceptions import NotFoundException, NotAuthorizedException @@ -31,6 +32,9 @@ logger = logging.getLogger(__name__) executor = ThreadPoolExecutor(max_workers=1) +RESTRICTED_CLIENT_ID = "hHRhJr5OoJhWumP87MHk5RldejycVAmC@clients" +MONTHLY_LIMIT = 3000 + class ClaimService: def __init__(self, claim_repository: ClaimRepository, analysis_repository: AnalysisRepository): @@ -45,9 +49,17 @@ async def create_claim( language: str, batch_user_id: str = None, batch_post_id: str = None, + auth0_id: str = None, ) -> Claim: """Create a new claim.""" now = datetime.now(UTC) + + if auth0_id is not None: + if auth0_id == RESTRICTED_CLIENT_ID: + current_count = await self._claim_repo.get_monthly_claim_count(user_id) + + if current_count >= MONTHLY_LIMIT: + raise MonthlyLimitExceededError() claim = Claim( id=uuid4(), user_id=user_id, @@ -234,9 +246,20 @@ def _heavy_clustering_math(claims, num_clusters): result = await loop.run_in_executor(executor, _heavy_clustering_math, claims, num_clusters) return result - async def create_claims_batch(self, claims: List[Claim], user_id: str) -> List[Claim]: + async def create_claims_batch( + self, + claims: List[Claim], + user_id: str, + auth0_id: str = None, + ) -> List[Claim]: # Map ClaimCreate + user_id → Claim DB objects now = datetime.now(UTC) + if auth0_id is not None: + if auth0_id == RESTRICTED_CLIENT_ID: + current_count = await self._claim_repo.get_monthly_claim_count(user_id) + + if current_count + len(claims) >= MONTHLY_LIMIT: + raise MonthlyLimitExceededError() claim_models = [ Claim( id=uuid4(),