Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions main/managers/partner_request/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,10 @@ def publish(cls, pr_id: PartnerRequestRef, account_id: AccountRef) -> Opt[Partne
if not pr.is_admin(account_id):
raise ValueError("Must be admin")

# Idempotent: if already published, just return
if pr.status == PartnerRequestStatus.JOINABLE.value:
return pr

if pr.status != PartnerRequestStatus.DRAFT.value:
raise ValueError("Can only publish draft")

Expand Down Expand Up @@ -140,6 +144,10 @@ def cancel(cls, pr_id: PartnerRequestRef, account_id: AccountRef) -> Opt[Partner
if not pr.is_admin(account_id):
raise ValueError("Must be admin")

# Idempotent: if already cancelled, just return
if pr.status == PartnerRequestStatus.CANCELLED.value:
return pr

if pr.status not in [
PartnerRequestStatus.JOINABLE.value,
PartnerRequestStatus.READY.value,
Expand Down
68 changes: 62 additions & 6 deletions main/managers/partner_request/trip/commute.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,15 @@
__all__ = ["CommutePRManager"]

import structlog
import sqlmodel
from typing import Optional as Opt

from .base import TripPRManager
from ....schemas.partner_request import PartnerRequestRef
from ....schemas.partner_request import PartnerRequestRef, PartnerRequest, PartnerRequestL2Type
from ....schemas.partner_request.trip.commute import CommutePRContent
from ....schemas.partner_request.trip.create import CommutePRCreate
from core.engine import SessionLocal
from account.schemas import AccountRef


logger = structlog.get_logger(__name__)
Expand All @@ -30,11 +36,61 @@ def get(cls, pr_id: PartnerRequestRef):
pass

@classmethod
def create(cls, *args, **kwargs):
"""Create a commute partner request."""
# TODO: Implement when commute schema is refactored
logger.info("Create commute PR")
raise NotImplementedError("Commute PR creation pending schema refactor")
def create(
cls,
account_id: AccountRef,
data: CommutePRCreate,
db: Opt[sqlmodel.Session] = None,
) -> PartnerRequestRef:
"""Create a commute partner request.

Creates both base PartnerRequest and CommutePRContent records.

:param account_id: Account ID of the creator
:param data: Create request data
:param db: Optional database session
:return: Created partner request ID
"""
logger.info("Create commute PR", account_id=account_id)

if db is None:
db = SessionLocal()
should_close_session = True
else:
should_close_session = False

try:
# Create base partner request
pr = PartnerRequest(
type=PartnerRequestL2Type.COMMUTE.value,
created_by=account_id,
title=data.title,
introduction=data.introduction,
)
db.add(pr)
db.flush() # Get the ID without committing

# Create commute specific content
content = CommutePRContent(
id=pr.id,
route=data.route,
trip_preference=data.trip_preference,
on_at=data.on_at,
off_at=data.off_at,
workdays=data.workdays,
)
db.add(content)
db.commit()

logger.info("Created commute PR", pr_id=pr.id)
return pr.id
except Exception as e:
db.rollback()
logger.error("Failed to create commute PR", error=str(e))
raise
finally:
if should_close_session:
db.close()

@classmethod
def update(cls, *args, **kwargs):
Expand Down
66 changes: 60 additions & 6 deletions main/managers/partner_request/trip/ride_hailing.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,15 @@
__all__ = ["RideHailingPRManager"]

import structlog
import sqlmodel
from typing import Optional as Opt

from .base import TripPRManager
from ....schemas.partner_request import PartnerRequestRef
from ....schemas.partner_request import PartnerRequestRef, PartnerRequest, PartnerRequestL2Type
from ....schemas.partner_request.trip.ride_hailing import RideHailingPRContent
from ....schemas.partner_request.trip.create import RideHailingPRCreate
from core.engine import SessionLocal
from account.schemas import AccountRef


logger = structlog.get_logger(__name__)
Expand All @@ -30,11 +36,59 @@ def get(cls, pr_id: PartnerRequestRef):
pass

@classmethod
def create(cls, *args, **kwargs):
"""Create a ride-hailing partner request."""
# TODO: Implement when ride-hailing schema is refactored
logger.info("Create ride-hailing PR")
raise NotImplementedError("Ride-hailing PR creation pending schema refactor")
def create(
cls,
account_id: AccountRef,
data: RideHailingPRCreate,
db: Opt[sqlmodel.Session] = None,
) -> PartnerRequestRef:
"""Create a ride-hailing partner request.

Creates both base PartnerRequest and RideHailingPRContent records.

:param account_id: Account ID of the creator
:param data: Create request data
:param db: Optional database session
:return: Created partner request ID
"""
logger.info("Create ride-hailing PR", account_id=account_id)

if db is None:
db = SessionLocal()
should_close_session = True
else:
should_close_session = False

try:
# Create base partner request
pr = PartnerRequest(
type=PartnerRequestL2Type.RIDE_HAILING.value,
created_by=account_id,
title=data.title,
introduction=data.introduction,
)
db.add(pr)
db.flush() # Get the ID without committing

# Create ride-hailing specific content
content = RideHailingPRContent(
id=pr.id,
route=data.route,
trip_preference=data.trip_preference,
ride_hailing_preference=data.ride_hailing_preference,
)
db.add(content)
db.commit()

logger.info("Created ride-hailing PR", pr_id=pr.id)
return pr.id
except Exception as e:
db.rollback()
logger.error("Failed to create ride-hailing PR", error=str(e))
raise
finally:
if should_close_session:
db.close()

@classmethod
def update(cls, *args, **kwargs):
Expand Down
79 changes: 49 additions & 30 deletions main/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,21 @@
PartnerRequestRef,
PartnerRequestStatus,
PartnerRequestListType,
PartnerRequestL2Type,
)
from .schemas.partner_request.trip.create import (
PartnerRequestCreate,
RideHailingPRCreate,
CommutePRCreate,
)
from .managers.partner_request.trip.ride_hailing import RideHailingPRManager
from .managers.partner_request.trip.commute import CommutePRManager

router = fastapi.APIRouter()

# Supported partner request types for creation
SUPPORTED_CREATE_TYPES = [PartnerRequestL2Type.RIDE_HAILING, PartnerRequestL2Type.COMMUTE]


@router.get("/partner_request/list/{list_type}")
def get_partner_request_list(
Expand Down Expand Up @@ -62,6 +73,34 @@ def get_partner_request_list(
return result


@router.post("/partner_request/{pr_type}")
def create_partner_request(
pr_type: PartnerRequestL2Type,
data: PartnerRequestCreate,
auth: AuthInfo = Depends(require_auth),
db: sqlmodel.Session = Depends(get_db_session),
) -> PartnerRequestRef:
"""Create a partner request.

Creates both base PartnerRequest and type-specific content records.
The type of partner request is determined by the pr_type path parameter.
"""
account_id = auth.user_id

if pr_type == PartnerRequestL2Type.RIDE_HAILING:
pr_id = RideHailingPRManager.create(account_id, data, db)
elif pr_type == PartnerRequestL2Type.COMMUTE:
pr_id = CommutePRManager.create(account_id, data, db)
else:
supported = ", ".join([t.value for t in SUPPORTED_CREATE_TYPES])
raise HTTPException(
status_code=400,
detail=f"Unsupported partner request type: {pr_type}. Supported types: {supported}"
)

return pr_id


@router.put("/partner_request/{pr_id}/publish")
def publish_partner_request(
pr_id: PartnerRequestRef,
Expand All @@ -72,6 +111,7 @@ def publish_partner_request(

Changes status from DRAFT to JOINABLE.
Only the admin (creator) can publish.
Idempotent: if already JOINABLE, returns the request without error.
"""
pr = db.get(PartnerRequest, pr_id)
if not pr:
Expand All @@ -80,6 +120,10 @@ def publish_partner_request(
if not pr.is_admin(auth.user_id):
raise HTTPException(status_code=403, detail="Must be admin")

# Idempotent: if already published, just return
if pr.status == PartnerRequestStatus.JOINABLE.value:
return pr

if pr.status != PartnerRequestStatus.DRAFT.value:
raise HTTPException(status_code=409, detail="Can only publish draft")

Expand All @@ -101,6 +145,7 @@ def cancel_partner_request(
Changes status to CANCELLED.
Only the admin (creator) can cancel.
Can only cancel if status is JOINABLE or READY.
Idempotent: if already CANCELLED, returns the request without error.
"""
pr = db.get(PartnerRequest, pr_id)
if not pr:
Expand All @@ -109,6 +154,10 @@ def cancel_partner_request(
if not pr.is_admin(auth.user_id):
raise HTTPException(status_code=403, detail="Must be admin")

# Idempotent: if already cancelled, just return
if pr.status == PartnerRequestStatus.CANCELLED.value:
return pr

if pr.status not in [PartnerRequestStatus.JOINABLE.value, PartnerRequestStatus.READY.value]:
raise HTTPException(status_code=409, detail="Cannot cancel")

Expand All @@ -117,33 +166,3 @@ def cancel_partner_request(
db.commit()
db.refresh(pr)
return pr


@router.put("/partner_request/{pr_id}/status/next")
def next_status(
pr_id: PartnerRequestRef,
auth: AuthInfo = Depends(require_auth),
db: sqlmodel.Session = Depends(get_db_session),
) -> PartnerRequest:
"""Move partner request to next status.

Status progression: DRAFT -> JOINABLE -> READY -> PERFORMING -> SETTLING -> CLOSED
Only the admin (creator) can change status.
"""
pr = db.get(PartnerRequest, pr_id)
if not pr:
raise HTTPException(status_code=404, detail="Partner request not found")

if not pr.is_admin(auth.user_id):
raise HTTPException(status_code=403, detail="Must be admin")

current_status = PartnerRequestStatus(pr.status)
try:
next_status = current_status.next()
pr.status = next_status.value
db.add(pr)
db.commit()
db.refresh(pr)
return pr
except ValueError:
raise HTTPException(status_code=409, detail="No next status available")
41 changes: 41 additions & 0 deletions main/schemas/partner_request/trip/create.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""Create request schemas for trip partner requests"""

import datetime
from typing import Optional as Opt, List, Literal, Annotated
from pydantic import BaseModel, Field

from .base import TripPreference
from .ride_hailing import RideHailingPreference
from ...base.route import RouteItem
from ...base import Weekday


class RideHailingPRCreate(BaseModel):
"""网约车搭子请求创建请求"""

type: Literal["ride_hailing"] = "ride_hailing"
title: Opt[str] = None
introduction: Opt[str] = None
route: List[RouteItem] = []
trip_preference: Opt[TripPreference] = None
ride_hailing_preference: RideHailingPreference = RideHailingPreference()


class CommutePRCreate(BaseModel):
"""通勤搭子请求创建请求"""

type: Literal["commute"] = "commute"
title: Opt[str] = None
introduction: Opt[str] = None
route: List[RouteItem] = []
trip_preference: Opt[TripPreference] = None
on_at: Opt[datetime.time] = None
off_at: Opt[datetime.time] = None
workdays: Opt[List[Weekday]] = None


PartnerRequestCreate = Annotated[
RideHailingPRCreate | CommutePRCreate,
Field(discriminator="type")
]

Loading