diff --git a/app/api/dependencies.py b/app/api/dependencies.py index 358ecb6..fd93701 100644 --- a/app/api/dependencies.py +++ b/app/api/dependencies.py @@ -20,6 +20,8 @@ from app.repositories.implementations.source_repository import SourceRepository from app.repositories.implementations.search_repository import SearchRepository from app.repositories.implementations.feedback_repository import FeedbackRepository +from app.repositories.implementations.discussion_repository import DiscussionRepository +from app.repositories.implementations.post_repository import PostRepository from app.core.config import settings from app.services.analysis_orchestrator import AnalysisOrchestrator from app.services.claim_conversation_service import ClaimConversationService @@ -37,6 +39,8 @@ from app.services.source_service import SourceService from app.services.search_service import SearchService from app.services.feedback_service import FeedbackService +from app.services.discussion_service import DiscussionService +from app.services.post_service import PostService from app.db.session import AsyncSessionLocal logger = logging.getLogger(__name__) @@ -97,6 +101,14 @@ async def get_feedback_repository(session: AsyncSession = Depends(get_db)) -> Fe return FeedbackRepository(session) +async def get_discussion_repository(session: AsyncSession = Depends(get_db)) -> DiscussionRepository: + return DiscussionRepository(session) + + +async def get_post_repository(session: AsyncSession = Depends(get_db)) -> PostRepository: + return PostRepository(session) + + async def get_embedding_generator() -> EmbeddingGeneratorInterface: return EmbeddingGenerator() @@ -227,6 +239,19 @@ async def get_orchestrator_service( ) +async def get_discussion_service( + discussion_repository: DiscussionRepository = Depends(get_discussion_repository), +) -> DiscussionService: + return DiscussionService(discussion_repository=discussion_repository) + + +async def get_post_service( + post_repository: PostRepository = Depends(get_post_repository), + discussion_repository: DiscussionRepository = Depends(get_discussion_repository), +) -> PostService: + return PostService(post_repository=post_repository, discussion_repository=discussion_repository) + + async def get_together_orchestrator_service( claim_repository: ClaimRepository = Depends(get_claim_repository), analysis_repository: AnalysisRepository = Depends(get_analysis_repository), diff --git a/app/api/endpoints/discussion_endpoints.py b/app/api/endpoints/discussion_endpoints.py new file mode 100644 index 0000000..f458831 --- /dev/null +++ b/app/api/endpoints/discussion_endpoints.py @@ -0,0 +1,85 @@ +from fastapi import APIRouter, Depends, Query, status, HTTPException + +# from typing import Optional +from uuid import UUID +import logging + +from app.models.domain.user import User +from app.services.discussion_service import DiscussionService +from app.schemas.discussion_schema import DiscussionResponse, PaginatedDiscussionsResponse, DiscussionCreate +from app.api.dependencies import get_discussion_service, get_current_user + +# 'get_current_user_optional' is useful if you want to know WHO is asking, +# but allow anonymous users to read discussions. + +router = APIRouter(prefix="/discussions", tags=["discussions"]) +logger = logging.getLogger(__name__) + + +@router.get("/user", response_model=PaginatedDiscussionsResponse, status_code=status.HTTP_200_OK) +async def get_discussions_per_user( + current_user: User = Depends(get_current_user), + limit: int = Query(10, ge=1, le=100), + offset: int = Query(0, ge=0), + service: DiscussionService = Depends(get_discussion_service), +): + """ + Get a list of discussions. + - If `user_id` is provided, returns discussions for that user. + - Otherwise, returns the most recent discussions system-wide. + """ + if current_user.id: + discussions, total = await service.list_user_discussions(current_user.id, limit=limit, offset=offset) + + return {"items": discussions, "total": total, "limit": limit, "offset": offset} + + +@router.get("/", response_model=PaginatedDiscussionsResponse, status_code=status.HTTP_200_OK) +async def get_recent_discussions( + limit: int = Query(10, ge=1, le=100), + offset: int = Query(0, ge=0), + service: DiscussionService = Depends(get_discussion_service), +): + """ + Get a list of discussions. + - If `user_id` is provided, returns discussions for that user. + - Otherwise, returns the most recent discussions system-wide. + """ + + discussions, total = await service.list_recent_discussions(limit=limit, offset=offset) + + return {"items": discussions, "total": total, "limit": limit, "offset": offset} + + +@router.get("/{discussion_id}", response_model=DiscussionResponse) +async def get_discussion_by_id( + discussion_id: UUID, + service: DiscussionService = Depends(get_discussion_service), +): + """Get a single discussion by ID.""" + try: + return await service.get_discussion(discussion_id) + except Exception as e: + # Assuming your service raises NotFoundException + raise HTTPException(status_code=404, detail=str(e)) + + +@router.post("/", response_model=DiscussionResponse, status_code=status.HTTP_201_CREATED) +async def create_discussion( + payload: DiscussionCreate, + current_user: User = Depends(get_current_user), + service: DiscussionService = Depends(get_discussion_service), +): + """ + Create a new discussion. + - Requires authentication. + - 'analysis_id' is optional (use it if you want to attach the discussion to a specific claim analysis). + """ + discussion = await service.create_discussion( + title=payload.title, + description=payload.description, + analysis_id=payload.analysis_id, + user_id=current_user.id, # Matches your User model field + ) + + return discussion diff --git a/app/api/endpoints/post_endpoints.py b/app/api/endpoints/post_endpoints.py new file mode 100644 index 0000000..0be4a83 --- /dev/null +++ b/app/api/endpoints/post_endpoints.py @@ -0,0 +1,110 @@ +from fastapi import APIRouter, Depends, status, HTTPException, Query +from typing import List +from uuid import UUID +import logging + +from app.models.domain.user import User +from app.services.post_service import PostService +from app.schemas.post_schema import PostCreate, PostUpdate, PostVote, PostResponse +from app.api.dependencies import get_post_service, get_current_user +from app.core.exceptions import NotFoundException, NotAuthorizedException + +router = APIRouter(prefix="/posts", tags=["Posts"]) +logger = logging.getLogger(__name__) + + +# --- CREATE POST --- +@router.post("/", response_model=PostResponse, status_code=status.HTTP_201_CREATED) +async def create_post( + payload: PostCreate, + current_user: User = Depends(get_current_user), + service: PostService = Depends(get_post_service), +): + """Create a new post in a discussion.""" + try: + return await service.create_post( + discussion_id=payload.discussion_id, user_id=current_user.id, text=payload.text + ) + except NotFoundException as e: + raise HTTPException(status_code=404, detail=str(e)) + except Exception as e: + logger.exception("Error creating post") + raise HTTPException(status_code=500, detail=str(e)) + + +# --- UPDATE POST TEXT --- +@router.put("/{post_id}", response_model=PostResponse) +async def update_post_content( + post_id: UUID, + payload: PostUpdate, + current_user: User = Depends(get_current_user), + service: PostService = Depends(get_post_service), +): + """ + Update the text content of a post. + Only the creator of the post can do this. + """ + try: + return await service.update_post_text(post_id=post_id, user_id=current_user.id, new_text=payload.text) + except NotFoundException as e: + raise HTTPException(status_code=404, detail=str(e)) + except NotAuthorizedException as e: + raise HTTPException(status_code=403, detail=str(e)) + except Exception as e: + logger.exception("Error updating post") + raise HTTPException(status_code=500, detail=str(e)) + + +# --- VOTE ON POST --- +@router.put("/{post_id}/vote", response_model=PostResponse) +async def vote_on_post( + post_id: UUID, + payload: PostVote, + current_user: User = Depends(get_current_user), + service: PostService = Depends(get_post_service), +): + """ + Increment upvotes or downvotes. + Payload: {"vote_type": "up"} or {"vote_type": "down"} + """ + try: + return await service.vote_post(post_id=post_id, vote_type=payload.vote_type) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + except NotFoundException as e: + raise HTTPException(status_code=404, detail=str(e)) + except Exception as e: + logger.exception("Error voting on post") + raise HTTPException(status_code=500, detail=str(e)) + + +# --- DELETE POST --- +@router.delete("/{post_id}", status_code=status.HTTP_204_NO_CONTENT) +async def delete_post( + post_id: UUID, + current_user: User = Depends(get_current_user), + service: PostService = Depends(get_post_service), +): + """Delete a post (Owner only).""" + try: + await service.delete_post(post_id=post_id, user_id=current_user.id) + except NotFoundException as e: + raise HTTPException(status_code=404, detail=str(e)) + except NotAuthorizedException as e: + raise HTTPException(status_code=403, detail=str(e)) + except Exception as e: + logger.exception("Error deleting post") + raise HTTPException(status_code=500, detail=str(e)) + + +# --- GET POSTS BY DISCUSSION --- +@router.get("/discussion/{discussion_id}", response_model=List[PostResponse]) +async def list_posts_for_discussion( + discussion_id: UUID, + limit: int = Query(50, ge=1, le=100), + offset: int = Query(0, ge=0), + service: PostService = Depends(get_post_service), +): + """List all posts belonging to a specific discussion.""" + posts, _ = await service.list_discussion_posts(discussion_id=discussion_id, limit=limit, offset=offset) + return posts diff --git a/app/api/router.py b/app/api/router.py index 0569bf0..4774254 100644 --- a/app/api/router.py +++ b/app/api/router.py @@ -11,6 +11,8 @@ domain_endpoints, health_endpoints, claim_conversation_endpoints, + discussion_endpoints, + post_endpoints, ) router = APIRouter() @@ -24,5 +26,7 @@ router.include_router(conversation_endpoints.router, tags=["conversations"]) router.include_router(message_endpoints.router, tags=["messages"]) router.include_router(domain_endpoints.router, tags=["domains"]) +router.include_router(discussion_endpoints.router, tags=["discussions"]) +router.include_router(post_endpoints.router, tags=["posts"]) router.include_router(claim_conversation_endpoints.router, tags=["claim-conversations"]) router.include_router(health_endpoints.router, tags=["health"]) diff --git a/app/models/database/models.py b/app/models/database/models.py index 625f428..e79b881 100644 --- a/app/models/database/models.py +++ b/app/models/database/models.py @@ -64,6 +64,8 @@ class UserModel(Base): claims: Mapped[List["ClaimModel"]] = relationship(back_populates="user", cascade="all, delete-orphan") conversations: Mapped[List["ConversationModel"]] = relationship(back_populates="user", cascade="all, delete-orphan") feedbacks: Mapped[List["FeedbackModel"]] = relationship(back_populates="user", cascade="all, delete-orphan") + discussions: Mapped[List["DiscussionModel"]] = relationship(back_populates="user", passive_deletes=True) + posts: Mapped[List["PostModel"]] = relationship(back_populates="user", passive_deletes=True) class DomainModel(Base): @@ -150,6 +152,7 @@ class AnalysisModel(Base): cascade="all, delete-orphan", primaryjoin="FeedbackModel.analysis_id == AnalysisModel.id", ) + discussions: Mapped[List["DiscussionModel"]] = relationship(back_populates="analysis", passive_deletes=True) messages: Mapped[List["MessageModel"]] = relationship(back_populates="analysis", doc="Related analysis, if any") __table_args__ = ( @@ -158,6 +161,79 @@ class AnalysisModel(Base): ) +class DiscussionModel(Base): + __tablename__ = "discussions" + + analysis_id: Mapped[UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("analysis.id", ondelete="SET NULL"), + nullable=True, + index=True, + ) + + title: Mapped[str] = mapped_column( + Text, + nullable=False, + ) + + description: Mapped[str] = mapped_column( + Text, + nullable=True, + ) + + user_id: Mapped[UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("users.id", ondelete="SET NULL"), # If User dies, set this to NULL + nullable=True, # Must be True for SET NULL to work + index=True, + ) + + # Relationships + user: Mapped["UserModel"] = relationship(back_populates="discussions") + + analysis: Mapped["AnalysisModel"] = relationship(back_populates="discussions") + + posts: Mapped[List["PostModel"]] = relationship(back_populates="discussion", cascade="all, delete-orphan") + + +class PostModel(Base): + __tablename__ = "posts" + + discussion_id: Mapped[UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("discussions.id"), + nullable=False, + index=True, + ) + + text: Mapped[str] = mapped_column( + Text, + nullable=False, + ) + + up_votes: Mapped[str] = mapped_column( + Integer, + nullable=True, + ) + + down_votes: Mapped[str] = mapped_column( + Integer, + nullable=True, + ) + + user_id: Mapped[UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("users.id", ondelete="SET NULL"), # If User dies, set this to NULL + nullable=True, # Must be True for SET NULL to work + index=True, + ) + + # Relationships + user: Mapped["UserModel"] = relationship(back_populates="posts") + + discussion: Mapped["DiscussionModel"] = relationship(back_populates="posts") + + class SearchModel(Base): __tablename__ = "searches" diff --git a/app/models/domain/discussion.py b/app/models/domain/discussion.py new file mode 100644 index 0000000..c77042d --- /dev/null +++ b/app/models/domain/discussion.py @@ -0,0 +1,43 @@ +from dataclasses import dataclass +from datetime import datetime +from typing import Optional +from uuid import UUID + +from app.models.database.models import DiscussionModel + + +@dataclass +class Discussion: + """Domain model for discussions.""" + + id: UUID + title: str + created_at: datetime + updated_at: datetime + + analysis_id: Optional[UUID] = None + user_id: Optional[UUID] = None + description: Optional[str] = None + + @classmethod + def from_model(cls, model: "DiscussionModel") -> "Discussion": + """Create domain model from database model.""" + return cls( + id=model.id, + title=model.title, + description=model.description, + analysis_id=model.analysis_id, + user_id=model.user_id, + created_at=model.created_at, + updated_at=model.updated_at, + ) + + def to_model(self) -> "DiscussionModel": + """Convert to database model.""" + return DiscussionModel( + id=self.id, + title=self.title, + description=self.description, + analysis_id=self.analysis_id, + user_id=self.user_id, + ) diff --git a/app/models/domain/post.py b/app/models/domain/post.py new file mode 100644 index 0000000..7e1bb1c --- /dev/null +++ b/app/models/domain/post.py @@ -0,0 +1,45 @@ +from dataclasses import dataclass +from datetime import datetime +from typing import Optional +from uuid import UUID + +from app.models.database.models import PostModel + + +@dataclass +class Post: + """Domain model for discussion posts.""" + + id: UUID + discussion_id: UUID + text: str + created_at: datetime + updated_at: datetime + user_id: Optional[UUID] = None + up_votes: Optional[int] = 0 + down_votes: Optional[int] = 0 + + @classmethod + def from_model(cls, model: "PostModel") -> "Post": + """Create domain model from database model.""" + return cls( + id=model.id, + discussion_id=model.discussion_id, + text=model.text, + user_id=model.user_id, + up_votes=model.up_votes, + down_votes=model.down_votes, + created_at=model.created_at, + updated_at=model.updated_at, + ) + + def to_model(self) -> "PostModel": + """Convert to database model.""" + return PostModel( + id=self.id, + discussion_id=self.discussion_id, + text=self.text, + user_id=self.user_id, + up_votes=self.up_votes, + down_votes=self.down_votes, + ) diff --git a/app/repositories/implementations/discussion_repository.py b/app/repositories/implementations/discussion_repository.py new file mode 100644 index 0000000..fb2100c --- /dev/null +++ b/app/repositories/implementations/discussion_repository.py @@ -0,0 +1,78 @@ +import logging +from typing import Optional, List, Tuple +from uuid import UUID +from sqlalchemy import select, func +from sqlalchemy.ext.asyncio import AsyncSession + +from app.models.database.models import DiscussionModel +from app.models.domain.discussion import Discussion +from app.repositories.base import BaseRepository +from app.repositories.interfaces.discussion_repository import DiscussionRepositoryInterface + +logger = logging.getLogger(__name__) + + +class DiscussionRepository(BaseRepository[DiscussionModel, Discussion], DiscussionRepositoryInterface): + def __init__(self, session: AsyncSession): + super().__init__(session, DiscussionModel) + + def _to_model(self, domain: Discussion) -> DiscussionModel: + return DiscussionModel( + id=domain.id, + title=domain.title, + description=domain.description, + analysis_id=domain.analysis_id, + user_id=domain.user_id, + # created_at/updated_at are typically handled by DB defaults, + # but can be passed if the domain sets them explicitly + ) + + def _to_domain(self, model: DiscussionModel) -> Discussion: + return Discussion( + id=model.id, + title=model.title, + description=model.description, + analysis_id=model.analysis_id, + user_id=model.user_id, + created_at=model.created_at, + updated_at=model.updated_at, + ) + + async def get_by_analysis_id(self, analysis_id: UUID) -> Optional[Discussion]: + stmt = select(self._model_class).where(self._model_class.analysis_id == analysis_id) + result = await self._session.execute(stmt) + model = result.scalar_one_or_none() + return self._to_domain(model) if model else None + + async def get_user_discussions( + self, user_id: UUID, limit: int = 20, offset: int = 0 + ) -> Tuple[List[Discussion], int]: + + # 1. Build Query + query = select(self._model_class).where(self._model_class.user_id == user_id) + + # 2. Get Total Count + count_query = select(func.count()).select_from(self._model_class).where(self._model_class.user_id == user_id) + total = await self._session.scalar(count_query) + + # 3. Apply Limit/Offset/Order + query = query.order_by(self._model_class.created_at.desc()).limit(limit).offset(offset) + + result = await self._session.execute(query) + discussions = [self._to_domain(model) for model in result.scalars().all()] + + return discussions, total + + async def get_recent_discussions(self, limit: int = 20, offset: int = 0) -> Tuple[List[Discussion], int]: + + query = select(self._model_class) + count_query = select(func.count()).select_from(self._model_class) + + total = await self._session.scalar(count_query) + + query = query.order_by(self._model_class.created_at.desc()).limit(limit).offset(offset) + + result = await self._session.execute(query) + discussions = [self._to_domain(model) for model in result.scalars().all()] + + return discussions, total diff --git a/app/repositories/implementations/post_repository.py b/app/repositories/implementations/post_repository.py new file mode 100644 index 0000000..1fa14e4 --- /dev/null +++ b/app/repositories/implementations/post_repository.py @@ -0,0 +1,78 @@ +import logging +from typing import List, Tuple +from uuid import UUID +from sqlalchemy import select, func +from sqlalchemy.ext.asyncio import AsyncSession + +from app.models.database.models import PostModel +from app.models.domain.post import Post +from app.repositories.base import BaseRepository +from app.repositories.interfaces.post_repository import PostRepositoryInterface + +logger = logging.getLogger(__name__) + + +class PostRepository(BaseRepository[PostModel, Post], PostRepositoryInterface): + def __init__(self, session: AsyncSession): + super().__init__(session, PostModel) + + def _to_model(self, domain: Post) -> PostModel: + return PostModel( + id=domain.id, + discussion_id=domain.discussion_id, + text=domain.text, + user_id=domain.user_id, + up_votes=domain.up_votes, + down_votes=domain.down_votes, + # Database usually handles created_at, but we map if provided + ) + + def _to_domain(self, model: PostModel) -> Post: + return Post( + id=model.id, + discussion_id=model.discussion_id, + text=model.text, + created_at=model.created_at, + updated_at=model.updated_at, + user_id=model.user_id, + up_votes=model.up_votes, + down_votes=model.down_votes, + ) + + async def get_by_discussion_id( + self, discussion_id: UUID, limit: int = 50, offset: int = 0 + ) -> Tuple[List[Post], int]: + + # Base query + query = select(self._model_class).where(self._model_class.discussion_id == discussion_id) + + # Count query + count_query = ( + select(func.count()).select_from(self._model_class).where(self._model_class.discussion_id == discussion_id) + ) + total = await self._session.scalar(count_query) + + # Ordering (Usually oldest first for chats/threads, or usually newest first? + # Standard forums usually do Oldest -> Newest. Comments sections usually do Newest -> Oldest. + # I will default to Created At ASC (Oldest first) for reading a thread properly). + query = query.order_by(self._model_class.created_at.asc()).limit(limit).offset(offset) + + result = await self._session.execute(query) + posts = [self._to_domain(model) for model in result.scalars().all()] + + return posts, total + + async def get_user_posts(self, user_id: UUID, limit: int = 50, offset: int = 0) -> Tuple[List[Post], int]: + + query = select(self._model_class).where(self._model_class.user_id == user_id) + + count_query = select(func.count()).select_from(self._model_class).where(self._model_class.user_id == user_id) + total = await self._session.scalar(count_query) + + # For user history, Newest first is standard + query = query.order_by(self._model_class.created_at.desc()).limit(limit).offset(offset) + + result = await self._session.execute(query) + posts = [self._to_domain(model) for model in result.scalars().all()] + + return posts, total diff --git a/app/repositories/interfaces/discussion_repository.py b/app/repositories/interfaces/discussion_repository.py new file mode 100644 index 0000000..9528e84 --- /dev/null +++ b/app/repositories/interfaces/discussion_repository.py @@ -0,0 +1,46 @@ +from abc import ABC, abstractmethod +from typing import Optional, List, Tuple +from uuid import UUID + +from app.models.domain.discussion import Discussion + + +class DiscussionRepositoryInterface(ABC): + """Interface for discussion repository operations.""" + + @abstractmethod + async def create(self, discussion: Discussion) -> Discussion: + """Create a new discussion.""" + pass + + @abstractmethod + async def get(self, discussion_id: UUID) -> Optional[Discussion]: + """Get discussion by ID.""" + pass + + @abstractmethod + async def update(self, discussion: Discussion) -> Discussion: + """Update a discussion.""" + pass + + @abstractmethod + async def delete(self, discussion_id: UUID) -> bool: + """Delete a discussion.""" + pass + + @abstractmethod + async def get_by_analysis_id(self, analysis_id: UUID) -> Optional[Discussion]: + """Get a discussion associated with a specific analysis.""" + pass + + @abstractmethod + async def get_user_discussions( + self, user_id: UUID, limit: int = 20, offset: int = 0 + ) -> Tuple[List[Discussion], int]: + """Get discussions started by a user with pagination.""" + pass + + @abstractmethod + async def get_recent_discussions(self, limit: int = 20, offset: int = 0) -> Tuple[List[Discussion], int]: + """Get a list of all discussions, ordered by recency.""" + pass diff --git a/app/repositories/interfaces/post_repository.py b/app/repositories/interfaces/post_repository.py new file mode 100644 index 0000000..a0d3a7a --- /dev/null +++ b/app/repositories/interfaces/post_repository.py @@ -0,0 +1,41 @@ +from abc import ABC, abstractmethod +from typing import Optional, List, Tuple +from uuid import UUID + +from app.models.domain.post import Post + + +class PostRepositoryInterface(ABC): + """Interface for discussion post repository operations.""" + + @abstractmethod + async def create(self, post: Post) -> Post: + """Create a new post.""" + pass + + @abstractmethod + async def get(self, post_id: UUID) -> Optional[Post]: + """Get post by ID.""" + pass + + @abstractmethod + async def update(self, post: Post) -> Post: + """Update a post.""" + pass + + @abstractmethod + async def delete(self, post_id: UUID) -> bool: + """Delete a post.""" + pass + + @abstractmethod + async def get_by_discussion_id( + self, discussion_id: UUID, limit: int = 50, offset: int = 0 + ) -> Tuple[List[Post], int]: + """Get posts belonging to a discussion (paginated).""" + pass + + @abstractmethod + async def get_user_posts(self, user_id: UUID, limit: int = 50, offset: int = 0) -> Tuple[List[Post], int]: + """Get posts created by a specific user.""" + pass diff --git a/app/schemas/discussion_schema.py b/app/schemas/discussion_schema.py new file mode 100644 index 0000000..e453727 --- /dev/null +++ b/app/schemas/discussion_schema.py @@ -0,0 +1,33 @@ +from typing import List, Optional +from uuid import UUID +from datetime import datetime +from pydantic import BaseModel + + +class DiscussionResponse(BaseModel): + id: UUID + title: str + description: Optional[str] = None + analysis_id: Optional[UUID] = None + user_id: Optional[UUID] = None + created_at: datetime + updated_at: datetime + + class Config: + from_attributes = True # Allows mapping from your Domain/DB models + + +class PaginatedDiscussionsResponse(BaseModel): + items: List[DiscussionResponse] + total: int + limit: int + offset: int + + +class DiscussionCreate(BaseModel): + title: str + description: Optional[str] = None + analysis_id: Optional[UUID] = None + + class Config: + from_attributes = True diff --git a/app/schemas/post_schema.py b/app/schemas/post_schema.py new file mode 100644 index 0000000..f2e025d --- /dev/null +++ b/app/schemas/post_schema.py @@ -0,0 +1,31 @@ +from uuid import UUID +from datetime import datetime +from pydantic import BaseModel, Field + + +class PostCreate(BaseModel): + discussion_id: UUID = Field(..., description="ID of the discussion this post belongs to") + text: str = Field(..., min_length=1, max_length=10000) + + +class PostUpdate(BaseModel): + text: str = Field(..., min_length=1, max_length=10000) + + +# TODO confirm with Dorsaf +class PostVote(BaseModel): + vote_type: str = Field(..., pattern="^(up|down)$", description="Must be 'up' or 'down'") + + +class PostResponse(BaseModel): + id: UUID + discussion_id: UUID + user_id: UUID + text: str + up_votes: int + down_votes: int + created_at: datetime + updated_at: datetime + + class Config: + from_attributes = True diff --git a/app/services/discussion_service.py b/app/services/discussion_service.py new file mode 100644 index 0000000..db49fb5 --- /dev/null +++ b/app/services/discussion_service.py @@ -0,0 +1,94 @@ +import logging +from datetime import datetime, UTC +from typing import List, Optional, Tuple +from uuid import UUID, uuid4 + +from app.models.domain.discussion import Discussion +from app.repositories.implementations.discussion_repository import DiscussionRepository +from app.core.exceptions import NotFoundException, NotAuthorizedException + +logger = logging.getLogger(__name__) + + +class DiscussionService: + def __init__(self, discussion_repository: DiscussionRepository): + self._discussion_repo = discussion_repository + + async def create_discussion( + self, + title: str, + user_id: UUID, + analysis_id: Optional[UUID] = None, + description: Optional[str] = None, + ) -> Discussion: + """Create a new discussion.""" + now = datetime.now(UTC) + + # Optional: You could check if a discussion already exists for this analysis_id + # if analysis_id: + # existing = await self._discussion_repo.get_by_analysis_id(analysis_id) + # if existing: + # return existing + + discussion = Discussion( + id=uuid4(), + title=title, + description=description, + analysis_id=analysis_id, + user_id=user_id, + created_at=now, + updated_at=now, + ) + + return await self._discussion_repo.create(discussion) + + async def get_discussion(self, discussion_id: UUID) -> Discussion: + """Get a discussion by ID.""" + discussion = await self._discussion_repo.get(discussion_id) + if not discussion: + raise NotFoundException("Discussion not found") + return discussion + + async def get_discussion_by_analysis(self, analysis_id: UUID) -> Discussion: + """Get the discussion associated with a specific analysis.""" + discussion = await self._discussion_repo.get_by_analysis_id(analysis_id) + if not discussion: + raise NotFoundException("Discussion for this analysis not found") + return discussion + + async def list_recent_discussions(self, limit: int = 20, offset: int = 0) -> Tuple[List[Discussion], int]: + """List all discussions ordered by recency.""" + return await self._discussion_repo.get_recent_discussions(limit=limit, offset=offset) + + async def list_user_discussions( + self, user_id: UUID, limit: int = 20, offset: int = 0 + ) -> Tuple[List[Discussion], int]: + """List discussions created by a specific user.""" + return await self._discussion_repo.get_user_discussions(user_id=user_id, limit=limit, offset=offset) + + async def delete_discussion(self, discussion_id: UUID, user_id: UUID) -> bool: + """Delete a discussion (Owner only).""" + discussion = await self.get_discussion(discussion_id) + + if discussion.user_id != user_id: + raise NotAuthorizedException("You are not authorized to delete this discussion") + + return await self._discussion_repo.delete(discussion_id) + + async def update_discussion( + self, discussion_id: UUID, user_id: UUID, title: Optional[str] = None, description: Optional[str] = None + ) -> Discussion: + """Update discussion details (Owner only).""" + discussion = await self.get_discussion(discussion_id) + + if discussion.user_id != user_id: + raise NotAuthorizedException("You are not authorized to edit this discussion") + + if title: + discussion.title = title + if description is not None: + discussion.description = description + + discussion.updated_at = datetime.now(UTC) + + return await self._discussion_repo.update(discussion) diff --git a/app/services/post_service.py b/app/services/post_service.py new file mode 100644 index 0000000..d942819 --- /dev/null +++ b/app/services/post_service.py @@ -0,0 +1,101 @@ +import logging +from datetime import datetime, UTC +from typing import List, Tuple +from uuid import UUID, uuid4 + +from app.models.domain.post import Post +from app.repositories.implementations.post_repository import PostRepository +from app.repositories.implementations.discussion_repository import DiscussionRepository +from app.core.exceptions import NotFoundException, NotAuthorizedException + +logger = logging.getLogger(__name__) + + +class PostService: + def __init__(self, post_repository: PostRepository, discussion_repository: DiscussionRepository): + self._post_repo = post_repository + self._discussion_repo = discussion_repository + + async def create_post( + self, + discussion_id: UUID, + user_id: UUID, + text: str, + ) -> Post: + """Create a new post in a discussion.""" + # Validate discussion exists + discussion = await self._discussion_repo.get(discussion_id) + if not discussion: + raise NotFoundException("Discussion not found") + + now = datetime.now(UTC) + + post = Post( + id=uuid4(), + discussion_id=discussion_id, + user_id=user_id, + text=text, + up_votes=0, + down_votes=0, + created_at=now, + updated_at=now, + ) + + return await self._post_repo.create(post) + + async def get_post(self, post_id: UUID) -> Post: + """Get a single post.""" + post = await self._post_repo.get(post_id) + if not post: + raise NotFoundException("Post not found") + return post + + async def list_discussion_posts( + self, discussion_id: UUID, limit: int = 50, offset: int = 0 + ) -> Tuple[List[Post], int]: + """List posts for a specific discussion.""" + # Optional: verify discussion exists first if you want strict checking + return await self._post_repo.get_by_discussion_id(discussion_id=discussion_id, limit=limit, offset=offset) + + async def list_user_posts(self, user_id: UUID, limit: int = 50, offset: int = 0) -> Tuple[List[Post], int]: + """List posts made by a specific user.""" + return await self._post_repo.get_user_posts(user_id=user_id, limit=limit, offset=offset) + + async def update_post_text(self, post_id: UUID, user_id: UUID, new_text: str) -> Post: + """Update the content of a post (Owner only).""" + post = await self.get_post(post_id) + + if post.user_id != user_id: + raise NotAuthorizedException("You are not authorized to edit this post") + + post.text = new_text + post.updated_at = datetime.now(UTC) + return await self._post_repo.update(post) + + async def delete_post(self, post_id: UUID, user_id: UUID) -> bool: + """Delete a post (Owner only).""" + post = await self.get_post(post_id) + + if post.user_id != user_id: + raise NotAuthorizedException("You are not authorized to delete this post") + + return await self._post_repo.delete(post_id) + + async def vote_post(self, post_id: UUID, vote_type: str) -> Post: + """ + Increment upvotes or downvotes. + vote_type must be 'up' or 'down'. + Note: In a production system, you should track *who* voted to prevent double voting. + This simple implementation just increments counters. + """ + post = await self.get_post(post_id) + + if vote_type == "up": + post.up_votes = (post.up_votes or 0) + 1 + elif vote_type == "down": + post.down_votes = (post.down_votes or 0) + 1 + else: + raise ValueError("vote_type must be 'up' or 'down'") + + # We don't update 'updated_at' for votes usually, but that depends on preference + return await self._post_repo.update(post) diff --git a/migrations/versions/345aea7c066f_added_new_tables.py b/migrations/versions/345aea7c066f_added_new_tables.py new file mode 100644 index 0000000..2c34b36 --- /dev/null +++ b/migrations/versions/345aea7c066f_added_new_tables.py @@ -0,0 +1,69 @@ +"""Added new tables + +Revision ID: 345aea7c066f +Revises: d2ffae797992 +Create Date: 2026-01-19 18:57:32.700667 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = "345aea7c066f" +down_revision: Union[str, None] = "d2ffae797992" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "discussions", + sa.Column("analysis_id", sa.UUID(), nullable=True), + sa.Column("title", sa.Text(), nullable=False), + sa.Column("description", sa.Text(), nullable=True), + sa.Column("user_id", sa.UUID(), nullable=True), + sa.Column("id", sa.UUID(), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False), + sa.ForeignKeyConstraint( + ["analysis_id"], ["analysis.id"], name=op.f("fk_discussions_analysis_id_analysis"), ondelete="SET NULL" + ), + sa.ForeignKeyConstraint( + ["user_id"], ["users.id"], name=op.f("fk_discussions_user_id_users"), ondelete="SET NULL" + ), + sa.PrimaryKeyConstraint("id", name=op.f("pk_discussions")), + ) + op.create_index(op.f("ix_discussions_analysis_id"), "discussions", ["analysis_id"], unique=False) + op.create_index(op.f("ix_discussions_user_id"), "discussions", ["user_id"], unique=False) + op.create_table( + "posts", + sa.Column("discussion_id", sa.UUID(), nullable=False), + sa.Column("text", sa.Text(), nullable=False), + sa.Column("up_votes", sa.Integer(), nullable=True), + sa.Column("down_votes", sa.Integer(), nullable=True), + sa.Column("user_id", sa.UUID(), nullable=True), + sa.Column("id", sa.UUID(), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False), + sa.ForeignKeyConstraint(["discussion_id"], ["discussions.id"], name=op.f("fk_posts_discussion_id_discussions")), + sa.ForeignKeyConstraint(["user_id"], ["users.id"], name=op.f("fk_posts_user_id_users"), ondelete="SET NULL"), + sa.PrimaryKeyConstraint("id", name=op.f("pk_posts")), + ) + op.create_index(op.f("ix_posts_discussion_id"), "posts", ["discussion_id"], unique=False) + op.create_index(op.f("ix_posts_user_id"), "posts", ["user_id"], unique=False) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index(op.f("ix_posts_user_id"), table_name="posts") + op.drop_index(op.f("ix_posts_discussion_id"), table_name="posts") + op.drop_table("posts") + op.drop_index(op.f("ix_discussions_user_id"), table_name="discussions") + op.drop_index(op.f("ix_discussions_analysis_id"), table_name="discussions") + op.drop_table("discussions") + # ### end Alembic commands ###