Files

68 lines
2.6 KiB
Python
Raw Permalink Normal View History

from datetime import datetime, timezone

from sqlalchemy import select
from sqlalchemy.orm import Session, selectinload

from app.models.training import SessionOrder, SessionSubmission, TrainingSession


class SessionRepository:
    """Session repository: persistence for training sessions, exam/lab orders,
    and diagnosis/treatment submissions.

    Write methods call ``flush()`` but never ``commit()``; committing or
    rolling back is left to whoever owns the injected ``Session``.
    """

    def __init__(self, db: Session) -> None:
        """Bind the repository to an existing SQLAlchemy session.

        Args:
            db: Session used for every query and write issued by this repository.
        """
        self.db = db

    @staticmethod
    def _utcnow() -> datetime:
        """Return the current UTC time as a naive ``datetime``.

        Drop-in replacement for ``datetime.utcnow()``, which is deprecated
        since Python 3.12. The tzinfo is stripped so stored timestamps keep
        the same naive-UTC form they had before.
        """
        return datetime.now(timezone.utc).replace(tzinfo=None)

    def create_session(self, session: TrainingSession) -> TrainingSession:
        """Persist a new training-session master record.

        The object is added and flushed so that database-generated values
        (e.g. its primary key, if generated by the DB) are populated on the
        returned instance.

        Args:
            session: The new, not-yet-persisted training session.

        Returns:
            The same ``session`` instance, now flushed.
        """
        self.db.add(session)
        self.db.flush()
        return session

    def get_owned_session(self, session_id: int, user_id: str) -> TrainingSession | None:
        """Fetch a session only if it belongs to the given user (ownership check).

        The related ``case``, ``orders`` and ``submission`` are eager-loaded
        with ``selectinload`` so callers can access them without extra lazy
        loads.

        Args:
            session_id: Primary key of the training session.
            user_id: Identifier of the user who must own the session.

        Returns:
            The matching session, or ``None`` when no session has this id
            for this user.
        """
        stmt = (
            select(TrainingSession)
            .options(
                selectinload(TrainingSession.case),
                selectinload(TrainingSession.orders),
                selectinload(TrainingSession.submission),
            )
            .where(TrainingSession.id == session_id, TrainingSession.user_id == user_id)
        )
        return self.db.scalar(stmt)

    def update_status(self, session: TrainingSession, status: str) -> TrainingSession:
        """Move a training session to a new stage status.

        Entering ``"diagnosis"`` stamps ``inquiry_completed_at``; entering
        ``"completed"`` stamps ``completed_at``. Timestamps are naive UTC.
        Applying the same status again overwrites the existing timestamp.

        Args:
            session: The session to update.
            status: The new status value.

        Returns:
            The same ``session`` instance, now flushed.
        """
        session.status = status
        if status == "diagnosis":
            session.inquiry_completed_at = self._utcnow()
        elif status == "completed":
            session.completed_at = self._utcnow()
        self.db.flush()
        return session

    def create_order(self, order: SessionOrder) -> SessionOrder:
        """Persist an exam/lab order (and its result) requested by the user.

        Args:
            order: The new order record.

        Returns:
            The same ``order`` instance, now flushed.
        """
        self.db.add(order)
        self.db.flush()
        return order

    def get_order_by_item(self, session_id: int, item_code: str) -> SessionOrder | None:
        """Look up an already-requested order by session and item code.

        Used so that a repeated request can return the existing order
        (idempotent response) instead of creating a duplicate.

        Args:
            session_id: The owning training session's id.
            item_code: Code of the exam/lab item.

        Returns:
            The first matching order, or ``None`` if the item was never
            requested in this session.
        """
        stmt = select(SessionOrder).where(
            SessionOrder.session_id == session_id,
            SessionOrder.item_code == item_code,
        )
        return self.db.scalar(stmt)

    def get_submission(self, session_id: int) -> SessionSubmission | None:
        """Fetch the diagnosis/treatment submission for a session.

        Args:
            session_id: The owning training session's id.

        Returns:
            The first matching submission, or ``None`` if nothing has been
            submitted yet.
        """
        stmt = select(SessionSubmission).where(SessionSubmission.session_id == session_id)
        return self.db.scalar(stmt)

    def upsert_submission(self, submission: SessionSubmission) -> SessionSubmission:
        """Save a diagnosis/treatment submission, inserting or updating it.

        This method does no lookup itself. A new (transient) instance is
        INSERTed on flush. An instance already attached to the session
        (e.g. one obtained from ``get_submission`` and then modified) has its
        pending changes UPDATEd. To update rather than duplicate, callers
        must pass the existing persistent instance.

        Args:
            submission: The submission to save.

        Returns:
            The same ``submission`` instance, now flushed.
        """
        # NOTE(review): passing a fresh instance for a session that already has
        # a submission will attempt a second INSERT; confirm callers reuse the
        # instance returned by get_submission.
        self.db.add(submission)
        self.db.flush()
        return submission