Files
fastapi/app/repositories/source_case_repository.py
T
2026-06-04 10:55:23 +08:00

70 lines
2.9 KiB
Python

from __future__ import annotations
from sqlalchemy import exists, select
from sqlalchemy.orm import Session, selectinload
from app.models.department import Department
from app.models.source_case import CaseBase, ScoringRule, TeachingCase, TraditionalCase
class SourceCaseRepository:
    """Repository over the source case library.

    Reads the ``case_base``, ``traditional_case``, ``teaching_case`` and
    ``scoring_rule`` data through the injected SQLAlchemy session.
    """

    def __init__(self, db: Session) -> None:
        # The session is supplied by the caller; every method below only
        # issues SELECT statements through it.
        self.db = db

    def list_active_cases(
        self,
        department_id: int | None = None,
        case_type: str | None = None,
        mode: str | None = None,
    ) -> list[CaseBase]:
        """List published source cases, ordered by descending id.

        A case counts as published when ``status == 1`` and
        ``publish_status == 1``. Each returned case has its traditional and
        teaching extensions eager-loaded.

        Args:
            department_id: Keep only cases of this department. A falsy value
                (``None`` / ``0``) disables this filter.
            case_type: Keep only cases of this category. A falsy value
                (``None`` / ``""``) disables this filter.
            mode: Training mode, normalized via ``normalize_mode``.
                ``"practice"`` (or legacy ``"novice"``) keeps only cases that
                have a ``TraditionalCase`` row. ``"teaching"`` keeps only
                cases that have a ``TeachingCase`` row. Any other value
                applies no mode filter.

        Returns:
            A concrete list of ``CaseBase`` rows.
        """
        filters = [CaseBase.status == 1, CaseBase.publish_status == 1]
        if department_id:
            filters.append(CaseBase.department_id == department_id)
        if case_type:
            filters.append(CaseBase.case_type == case_type)

        resolved_mode = self.normalize_mode(mode)
        # Correlated EXISTS: the case must own at least one extension row of
        # the kind matching the requested training mode.
        if resolved_mode == "practice":
            filters.append(exists().where(TraditionalCase.case_id == CaseBase.id))
        elif resolved_mode == "teaching":
            filters.append(exists().where(TeachingCase.case_id == CaseBase.id))

        query = (
            select(CaseBase)
            .options(
                selectinload(CaseBase.traditional_case),
                selectinload(CaseBase.teaching_case),
            )
            .where(*filters)
            .order_by(CaseBase.id.desc())
        )
        return list(self.db.scalars(query).all())

    def get_active_case_base(self, case_id: int) -> CaseBase | None:
        """Fetch one published case with its extensions and scoring rules.

        Eager-loads the traditional/teaching extension relationships and
        ``scoring_rules``.

        Returns:
            The matching ``CaseBase``, or ``None`` when no row has this id
            together with ``status == 1`` and ``publish_status == 1``.
        """
        eager_loads = (
            selectinload(CaseBase.traditional_case),
            selectinload(CaseBase.teaching_case),
            selectinload(CaseBase.scoring_rules),
        )
        query = (
            select(CaseBase)
            .options(*eager_loads)
            .where(
                CaseBase.id == case_id,
                CaseBase.status == 1,
                CaseBase.publish_status == 1,
            )
        )
        return self.db.scalar(query)

    def get_department_name(self, department_id: int | None) -> str:
        """Return the department name from the user-side ``Department`` table.

        A falsy id (``None`` / ``0``) and an id with no matching row both
        yield ``""``.
        """
        if not department_id:
            return ""
        found = self.db.scalar(select(Department).where(Department.id == department_id))
        if not found:
            return ""
        return found.name

    def get_scoring_rules(self, case_id: int) -> list[ScoringRule]:
        """Return the base scoring rules of ``case_id``, ordered by ascending id."""
        rows = self.db.scalars(
            select(ScoringRule)
            .where(ScoringRule.case_id == case_id)
            .order_by(ScoringRule.id)
        )
        return list(rows.all())

    @staticmethod
    def normalize_mode(mode: str | None) -> str | None:
        """Normalize the training mode.

        Legacy ``"novice"`` requests are treated as ``"practice"``. Every
        other value, including ``None``, is returned unchanged. The first
        version only exposes ``practice`` and ``teaching``.
        """
        return "practice" if mode == "novice" else mode