68 lines
2.6 KiB
Python
68 lines
2.6 KiB
Python
from datetime import datetime, timezone

from sqlalchemy import select
from sqlalchemy.orm import Session, selectinload

from app.models.training import SessionOrder, SessionSubmission, TrainingSession
|
|
|
|
|
|
class SessionRepository:
    """Session repository: persistence for training sessions, orders and submissions.

    Wraps a SQLAlchemy ``Session`` and covers training sessions, examination
    orders, and diagnosis/treatment submissions. Write methods call
    ``flush()`` only; no method in this class commits the transaction.
    """

    def __init__(self, db: Session) -> None:
        """Store the SQLAlchemy session that every method operates on."""
        self.db = db

    def create_session(self, session: TrainingSession) -> TrainingSession:
        """Session creation: save the main training-session record.

        Adds ``session`` to the database session and flushes it.

        Returns:
            The same ``session`` object that was passed in.
        """
        self.db.add(session)
        self.db.flush()
        return session

    def get_owned_session(self, session_id: int, user_id: str) -> TrainingSession | None:
        """Ownership check: look up a session by ``session_id`` and ``user_id``.

        The ``case``, ``orders`` and ``submission`` relationships are loaded
        with ``selectinload`` as part of the lookup.

        Returns:
            The matching ``TrainingSession``, or ``None`` when no row matches
            both the id and the user.
        """
        eager_loads = (
            selectinload(TrainingSession.case),
            selectinload(TrainingSession.orders),
            selectinload(TrainingSession.submission),
        )
        stmt = (
            select(TrainingSession)
            .options(*eager_loads)
            .where(TrainingSession.id == session_id, TrainingSession.user_id == user_id)
        )
        return self.db.scalar(stmt)

    def update_status(self, session: TrainingSession, status: str) -> TrainingSession:
        """State transition: update the stage status of a training session.

        Sets ``session.status`` and stamps the matching timestamp:
        ``"diagnosis"`` sets ``inquiry_completed_at`` and ``"completed"`` sets
        ``completed_at``. Any other status changes no timestamp. The change
        is flushed before returning.

        Returns:
            The same ``session`` object that was passed in.
        """
        session.status = status

        if status in ("diagnosis", "completed"):
            # datetime.utcnow() is deprecated since Python 3.12. Build the same
            # naive-UTC value from an aware "now" so the stored value keeps the
            # exact shape it had before (no tzinfo attached).
            now_utc = datetime.now(timezone.utc).replace(tzinfo=None)
            if status == "diagnosis":
                session.inquiry_completed_at = now_utc
            else:
                session.completed_at = now_utc

        self.db.flush()
        return session

    def create_order(self, order: SessionOrder) -> SessionOrder:
        """Order saving: save an examination/lab result the user has requested.

        Adds ``order`` to the database session and flushes it.

        Returns:
            The same ``order`` object that was passed in.
        """
        self.db.add(order)
        self.db.flush()
        return order

    def get_order_by_item(self, session_id: int, item_code: str) -> SessionOrder | None:
        """Order lookup: fetch an already-requested result by session and item code.

        Intended for idempotent responses when the same item is requested again.

        Returns:
            The ``SessionOrder`` for this session and item code, or ``None``
            when there is none.
        """
        stmt = select(SessionOrder).where(
            SessionOrder.session_id == session_id,
            SessionOrder.item_code == item_code,
        )
        return self.db.scalar(stmt)

    def get_submission(self, session_id: int) -> SessionSubmission | None:
        """Submission lookup: fetch the diagnosis/treatment submission of a session.

        Returns:
            The ``SessionSubmission`` for ``session_id``, or ``None`` when
            there is none.
        """
        stmt = select(SessionSubmission).where(SessionSubmission.session_id == session_id)
        return self.db.scalar(stmt)

    def upsert_submission(self, submission: SessionSubmission) -> SessionSubmission:
        """Submission saving: create or update the submission of the current session.

        Adds ``submission`` to the database session and flushes it.

        NOTE(review): only ``add()`` and ``flush()`` are issued here. Whether
        this yields an INSERT or an UPDATE presumably depends on whether the
        caller passes a new object or one already loaded from this session;
        confirm against callers.

        Returns:
            The same ``submission`` object that was passed in.
        """
        self.db.add(submission)
        self.db.flush()
        return submission
|