156 lines
8.3 KiB
Python
156 lines
8.3 KiB
Python
|
|
from sqlalchemy import delete, exists, func, or_, select
|
||
|
|
from sqlalchemy.orm import Session, selectinload
|
||
|
|
|
||
|
|
from app.models.source_case import CaseBase, CaseExamItem, ScoringRule, TeachingCase, TraditionalCase
|
||
|
|
from app.models.training import SessionOrder, SessionSubmission, TrainingSession
|
||
|
|
from app.models.training_record import TrainingRecord
|
||
|
|
|
||
|
|
|
||
|
|
class CaseRepository:
    """Case repository built on the case_base table family.

    Reads cases together with their traditional/teaching extension rows,
    scoring rules and exam items, and provides the preview/execute pair used
    when a case and its dependent training data are deleted.
    """

    def __init__(self, db: Session) -> None:
        # The session is owned by the caller; nothing in this class commits.
        self.db = db

    def list_active_cases(
        self,
        department_id: int | None = None,
        training_type: str | None = None,
        mode: str | None = None,
    ) -> list[CaseBase]:
        """Case list: read cases with status == 1 and publish_status == 1.

        Optional filters are applied only when the argument is truthy.
        ``mode`` restricts the result to cases that have a matching extension
        row: ``"practice"`` requires a traditional case, ``"teaching"``
        requires a teaching case; any other value adds no restriction.
        Results are ordered by case id, newest first.
        """
        filters = [CaseBase.status == 1, CaseBase.publish_status == 1]
        if department_id:
            filters.append(CaseBase.department_id == department_id)
        if training_type:
            filters.append(CaseBase.case_type == training_type)
        if mode == "practice":
            filters.append(exists().where(TraditionalCase.case_id == CaseBase.id))
        elif mode == "teaching":
            filters.append(exists().where(TeachingCase.case_id == CaseBase.id))

        query = (
            select(CaseBase)
            .options(
                selectinload(CaseBase.traditional_case),
                selectinload(CaseBase.teaching_case),
            )
            .where(*filters)
            .order_by(CaseBase.id.desc())
        )
        return list(self.db.scalars(query).all())

    def get_active_case(self, case_id: int) -> CaseBase | None:
        """Case detail: load one case with status == 1 and publish_status == 1.

        The traditional case, teaching case, scoring rules and exam items are
        eagerly loaded. Returns ``None`` when no such case matches.
        """
        eager_loads = (
            selectinload(CaseBase.traditional_case),
            selectinload(CaseBase.teaching_case),
            selectinload(CaseBase.scoring_rules),
            selectinload(CaseBase.exam_items),
        )
        query = (
            select(CaseBase)
            .options(*eager_loads)
            .where(
                CaseBase.id == case_id,
                CaseBase.status == 1,
                CaseBase.publish_status == 1,
            )
        )
        return self.db.scalar(query)

    def get_case_by_id(self, case_id: int) -> CaseBase | None:
        """Case deletion: fetch a case by primary key, ignoring publish state.

        Intended for the existence check that precedes a delete.
        """
        case = self.db.get(CaseBase, case_id)
        return case

    def get_delete_preview_counts(self, case_id: int) -> dict[str, int]:
        """Case deletion preview: count the rows a cascade delete would touch.

        Returns a mapping from table label to row count.
        """
        session_ids = self._session_ids(case_id)

        # Tables that reference the case directly through a single column.
        direct_tables = (
            ("case_base", CaseBase, CaseBase.id),
            ("traditional_case", TraditionalCase, TraditionalCase.case_id),
            ("teaching_case", TeachingCase, TeachingCase.case_id),
            ("scoring_rule", ScoringRule, ScoringRule.case_id),
            ("case_exam_item", CaseExamItem, CaseExamItem.case_id),
        )
        counts: dict[str, int] = {}
        for label, model, column in direct_tables:
            counts[label] = self._count(model, column == case_id)

        counts["training_session"] = len(session_ids)
        counts["training_order"] = self._count_training_orders(case_id, session_ids)
        counts["training_submission"] = self._count_by_sessions(
            SessionSubmission, SessionSubmission.session_id, session_ids
        )
        counts["training_record"] = self._count_training_records(case_id, session_ids)
        return counts

    def delete_case_cascade(self, case_id: int) -> dict[str, int]:
        """Case deletion: remove the case and everything that depends on it.

        Rows are deleted in dependency order: training orders, submissions and
        records first, then training sessions, exam items, scoring rules, the
        two extension tables and finally the case_base row. Returns a mapping
        from table label to the number of rows deleted.
        """
        # Session ids must be captured before the sessions themselves go away.
        session_ids = self._session_ids(case_id)

        # A dict literal evaluates its values top to bottom, so the statements
        # below execute in exactly the order they are written.
        return {
            "training_order": self._delete_training_orders(case_id, session_ids),
            "training_submission": self._delete_by_sessions(
                SessionSubmission, SessionSubmission.session_id, session_ids
            ),
            "training_record": self._delete_training_records(case_id, session_ids),
            "training_session": self._delete_where(
                TrainingSession, TrainingSession.case_id == case_id
            ),
            "case_exam_item": self._delete_where(
                CaseExamItem, CaseExamItem.case_id == case_id
            ),
            "scoring_rule": self._delete_where(
                ScoringRule, ScoringRule.case_id == case_id
            ),
            "traditional_case": self._delete_where(
                TraditionalCase, TraditionalCase.case_id == case_id
            ),
            "teaching_case": self._delete_where(
                TeachingCase, TeachingCase.case_id == case_id
            ),
            "case_base": self._delete_where(CaseBase, CaseBase.id == case_id),
        }

    def _session_ids(self, case_id: int) -> list[int]:
        """Case deletion: list the ids of training sessions tied to the case."""
        query = select(TrainingSession.id).where(TrainingSession.case_id == case_id)
        rows = self.db.scalars(query).all()
        return list(map(int, rows))

    @staticmethod
    def _case_or_session_filter(case_column, session_column, case_id: int, session_ids: list[int]):
        """Build the criterion "belongs to the case, or to one of its sessions".

        When there are no session ids only the case comparison is returned, so
        no ``IN`` clause is built from an empty list.
        """
        by_case = case_column == case_id
        if not session_ids:
            return by_case
        return or_(by_case, session_column.in_(session_ids))

    def _count(self, model: type, *criteria) -> int:
        """Case deletion preview: count rows of ``model`` matching ``criteria``."""
        query = select(func.count()).select_from(model).where(*criteria)
        total = self.db.scalar(query)
        return int(total or 0)

    def _count_by_sessions(self, model: type, session_column, session_ids: list[int]) -> int:
        """Case deletion preview: count child rows belonging to the sessions.

        Returns 0 without querying when ``session_ids`` is empty.
        """
        return self._count(model, session_column.in_(session_ids)) if session_ids else 0

    def _count_training_orders(self, case_id: int, session_ids: list[int]) -> int:
        """Case deletion preview: count exam orders linked by case or by session."""
        criterion = self._case_or_session_filter(
            SessionOrder.case_id, SessionOrder.session_id, case_id, session_ids
        )
        return self._count(SessionOrder, criterion)

    def _count_training_records(self, case_id: int, session_ids: list[int]) -> int:
        """Case deletion preview: count training records linked by case or by session."""
        criterion = self._case_or_session_filter(
            TrainingRecord.case_id, TrainingRecord.session_id, case_id, session_ids
        )
        return self._count(TrainingRecord, criterion)

    def _delete_where(self, model: type, *criteria) -> int:
        """Case deletion: delete rows of ``model`` matching ``criteria``.

        Returns the affected row count, or 0 when the result reports none.
        """
        outcome = self.db.execute(delete(model).where(*criteria))
        return int(outcome.rowcount or 0)

    def _delete_by_sessions(self, model: type, session_column, session_ids: list[int]) -> int:
        """Case deletion: delete child rows belonging to the given sessions.

        Returns 0 without executing anything when ``session_ids`` is empty.
        """
        return self._delete_where(model, session_column.in_(session_ids)) if session_ids else 0

    def _delete_training_orders(self, case_id: int, session_ids: list[int]) -> int:
        """Case deletion: delete every exam order of the case.

        Per the original author's note, these are removed so they do not block
        deleting the exam items.
        """
        criterion = self._case_or_session_filter(
            SessionOrder.case_id, SessionOrder.session_id, case_id, session_ids
        )
        return self._delete_where(SessionOrder, criterion)

    def _delete_training_records(self, case_id: int, session_ids: list[int]) -> int:
        """Case deletion: delete the evaluation records kept from completed trainings of the case."""
        criterion = self._case_or_session_filter(
            TrainingRecord.case_id, TrainingRecord.session_id, case_id, session_ids
        )
        return self._delete_where(TrainingRecord, criterion)

    def get_exam_items(self, case_id: int) -> list[CaseExamItem]:
        """Exam items: list all exam items of the case, sorted by display_order."""
        query = (
            select(CaseExamItem)
            .where(CaseExamItem.case_id == case_id)
            .order_by(CaseExamItem.display_order)
        )
        return list(self.db.scalars(query).all())

    def get_exam_item(self, case_id: int, item_code: str) -> CaseExamItem | None:
        """Exam result: fetch one exam item by case id and item code.

        Returns ``None`` when no row matches.
        """
        query = select(CaseExamItem).where(
            CaseExamItem.case_id == case_id,
            CaseExamItem.item_code == item_code,
        )
        return self.db.scalar(query)
|