Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion app/api/endpoints/claim_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from app.services.analysis_orchestrator import AnalysisOrchestrator
from app.core.exceptions import NotFoundException, NotAuthorizedException
from app.services.interfaces.embedding_generator import EmbeddingGeneratorInterface
from app.core.exceptions import MonthlyLimitExceededError

router = APIRouter(prefix="/claims", tags=["claims"])
logger = logging.getLogger(__name__)
Expand All @@ -45,8 +46,12 @@ async def create_claim(
language=data.language,
batch_user_id=data.batch_user_id,
batch_post_id=data.batch_post_id,
auth0_id=current_user.auth0_id,
)
return ClaimRead.model_validate(claim)
except MonthlyLimitExceededError:
# We don't have 'e.limit' anymore, so we just say "Limit reached"
raise HTTPException(status_code=429, detail="You have reached your monthly claim limit.")
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to create claim: {str(e)}"
Expand All @@ -65,12 +70,19 @@ async def create_claims_batch(
raise HTTPException(status_code=400, detail="Maximum of 100 claims allowed.")

try:
created_claims = await claim_service.create_claims_batch(claims, current_user.id)
created_claims = await claim_service.create_claims_batch(
claims,
current_user.id,
auth0_id=current_user.auth0_id,
)
claim_ids = [str(claim.id) for claim in created_claims]
background_tasks.add_task(
claim_service.process_claims_batch_async, created_claims, current_user.id, analysis_orchestrator
)
return {"message": f"Processing {len(created_claims)} claims in the background.", "claim_ids": claim_ids}
except MonthlyLimitExceededError:
# We don't have 'e.limit' anymore, so we just say "Limit reached"
raise HTTPException(status_code=429, detail="You have reached your monthly claim limit.")
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to queue batch: {str(e)}"
Expand Down
11 changes: 11 additions & 0 deletions app/core/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,17 @@ class DuplicateUserError(Exception):
pass


"""
Claim exceptions
"""


class MonthlyLimitExceededError(Exception):
"""Raised when a user hits their monthly claim limit."""

pass


"""
Feedback exceptions
"""
Expand Down
4 changes: 1 addition & 3 deletions app/models/domain/analysis.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from dataclasses import dataclass
from datetime import datetime
from typing import Optional, List, Dict
from typing import Optional, List
from uuid import UUID
import pickle

Expand All @@ -11,10 +11,8 @@

@dataclass
class LogProbsData:
anth_conf_score: float
tokens: List[str]
probs: List[float]
alternatives: List[Dict[str, float]]


@dataclass
Expand Down
18 changes: 17 additions & 1 deletion app/repositories/implementations/claim_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from uuid import UUID
from sqlalchemy import select, func, and_
from sqlalchemy.ext.asyncio import AsyncSession
from datetime import datetime
from datetime import datetime, UTC

from app.models.database.models import ClaimModel, ClaimStatus
from app.models.domain.claim import Claim
Expand Down Expand Up @@ -94,6 +94,22 @@ async def get_claims_in_date_range(self, start_date: datetime, end_date: datetim
result = await self._session.execute(stmt)
return [self._to_domain(claim) for claim in result.scalars().all()]

async def get_monthly_claim_count(self, user_id: str) -> int:
"""Counts how many claims a user has created this month."""

# Calculate the 1st of the current month
now = datetime.now(UTC)
start_of_month = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0)

query = (
select(func.count())
.select_from(ClaimModel)
.where(ClaimModel.user_id == user_id, ClaimModel.created_at >= start_of_month)
)

result = await self._session.execute(query)
return result.scalar_one()

async def insert_many(self, claim: List[Claim]) -> List[Claim]:
models = [self._to_model(claim) for claim in claim]
self._session.add_all(models)
Expand Down
3 changes: 2 additions & 1 deletion app/services/analysis_orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,8 @@ async def _generate_analysis(
logger.info(con_score)
current_analysis.confidence_score = float(con_score)
# log_data = await self._get_anth_confidence_score(statement=claim_text, veracity_score=veracity_score)
# current_analysis.log_probs = log_data
log_data = LogProbsData(tokens=analysis_text, probs=log_probs)
current_analysis.log_probs = log_data

updated_analysis = await self._analysis_repo.update(current_analysis)

Expand Down
25 changes: 24 additions & 1 deletion app/services/claim_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from app.repositories.implementations.claim_repository import ClaimRepository
from app.repositories.implementations.analysis_repository import AnalysisRepository
from app.services.analysis_orchestrator import AnalysisOrchestrator
from app.core.exceptions import MonthlyLimitExceededError

from app.core.exceptions import NotFoundException, NotAuthorizedException

Expand All @@ -31,6 +32,9 @@
logger = logging.getLogger(__name__)
executor = ThreadPoolExecutor(max_workers=1)

RESTRICTED_CLIENT_ID = "hHRhJr5OoJhWumP87MHk5RldejycVAmC@clients"
MONTHLY_LIMIT = 3000


class ClaimService:
def __init__(self, claim_repository: ClaimRepository, analysis_repository: AnalysisRepository):
Expand All @@ -45,9 +49,17 @@ async def create_claim(
language: str,
batch_user_id: str = None,
batch_post_id: str = None,
auth0_id: str = None,
) -> Claim:
"""Create a new claim."""
now = datetime.now(UTC)

if auth0_id is not None:
if auth0_id == RESTRICTED_CLIENT_ID:
current_count = await self._claim_repo.get_monthly_claim_count(user_id)

if current_count >= MONTHLY_LIMIT:
raise MonthlyLimitExceededError()
claim = Claim(
id=uuid4(),
user_id=user_id,
Expand Down Expand Up @@ -234,9 +246,20 @@ def _heavy_clustering_math(claims, num_clusters):
result = await loop.run_in_executor(executor, _heavy_clustering_math, claims, num_clusters)
return result

async def create_claims_batch(self, claims: List[Claim], user_id: str) -> List[Claim]:
async def create_claims_batch(
self,
claims: List[Claim],
user_id: str,
auth0_id: str = None,
) -> List[Claim]:
# Map ClaimCreate + user_id → Claim DB objects
now = datetime.now(UTC)
if auth0_id is not None:
if auth0_id == RESTRICTED_CLIENT_ID:
current_count = await self._claim_repo.get_monthly_claim_count(user_id)

if current_count + len(claims) >= MONTHLY_LIMIT:
raise MonthlyLimitExceededError()
claim_models = [
Claim(
id=uuid4(),
Expand Down