from datetime import datetime from sqlalchemy import select from sqlalchemy.orm import Session, selectinload from app.models.training import SessionOrder, SessionSubmission, TrainingSession class SessionRepository: """会话仓储:负责训练会话、检查申请和诊断治疗提交数据。""" def __init__(self, db: Session) -> None: self.db = db def create_session(self, session: TrainingSession) -> TrainingSession: """会话创建:保存训练会话主记录。""" self.db.add(session) self.db.flush() return session def get_owned_session(self, session_id: int, user_id: str) -> TrainingSession | None: """会话归属校验:根据 session_id 和 user_id 查询会话。""" stmt = ( select(TrainingSession) .options( selectinload(TrainingSession.case), selectinload(TrainingSession.orders), selectinload(TrainingSession.submission), ) .where(TrainingSession.id == session_id, TrainingSession.user_id == user_id) ) return self.db.scalar(stmt) def update_status(self, session: TrainingSession, status: str) -> TrainingSession: """状态流转:更新训练会话阶段状态。""" session.status = status if status == "diagnosis": session.inquiry_completed_at = datetime.utcnow() if status == "completed": session.completed_at = datetime.utcnow() self.db.flush() return session def create_order(self, order: SessionOrder) -> SessionOrder: """检查申请保存:保存用户申请过的检查检验结果。""" self.db.add(order) self.db.flush() return order def get_order_by_item(self, session_id: int, item_code: str) -> SessionOrder | None: """检查申请读取:按会话和检查编码获取已申请结果,用于幂等返回。""" stmt = select(SessionOrder).where( SessionOrder.session_id == session_id, SessionOrder.item_code == item_code, ) return self.db.scalar(stmt) def get_submission(self, session_id: int) -> SessionSubmission | None: """提交读取:获取当前会话的诊断治疗提交记录。""" stmt = select(SessionSubmission).where(SessionSubmission.session_id == session_id) return self.db.scalar(stmt) def upsert_submission(self, submission: SessionSubmission) -> SessionSubmission: """诊断治疗保存:创建或更新当前会话的提交记录。""" self.db.add(submission) self.db.flush() return submission