make case catalog read-only

This commit is contained in:
刘金宝
2026-06-04 17:50:22 +08:00
parent b46e43aadc
commit 7f1803f9fa
15 changed files with 35 additions and 1268 deletions
+1 -28
View File
@@ -4,13 +4,7 @@ from sqlalchemy.orm import Session
from app.core.response import ApiResponse, ok
from app.core.user_context import UserContext, get_user_context
from app.db.session import get_db
from app.schemas.case import (
CaseDeletePreviewResponse,
CaseDeleteRequest,
CaseDeleteResponse,
CaseDetailResponse,
CaseListResponse,
)
from app.schemas.case import CaseDetailResponse, CaseListResponse
from app.services.case_service import CaseService
router = APIRouter()
@@ -36,24 +30,3 @@ def get_case_detail(
):
"""病例详情:返回训练入口信息和可申请检查类型。"""
return ok(CaseService(db).get_case_detail(case_id))
@router.get("/{case_id}/delete-preview", response_model=ApiResponse[CaseDeletePreviewResponse])
def get_case_delete_preview(
case_id: int,
_: UserContext = Depends(get_user_context),
db: Session = Depends(get_db),
):
"""病例删除预览:返回删除该病例会影响的训练与病例数据数量。"""
return ok(CaseService(db).get_delete_preview(case_id))
@router.delete("/{case_id}", response_model=ApiResponse[CaseDeleteResponse])
def delete_case(
case_id: int,
payload: CaseDeleteRequest,
ctx: UserContext = Depends(get_user_context),
db: Session = Depends(get_db),
):
"""病例删除:确认后级联删除病例、扩展表、评分规则、检查项和关联训练数据。"""
return ok(CaseService(db).delete_case(case_id, payload, ctx))
-26
View File
@@ -1,26 +0,0 @@
from fastapi import APIRouter, Depends, File, UploadFile
from app.core.response import ApiResponse, ok
from app.core.user_context import UserContext, get_user_context
from app.schemas.imports import CaseSqlImportApplyResponse, CaseSqlImportPreviewResponse
from app.services.case_sql_import_service import CaseSqlImportService
router = APIRouter()
@router.post("/case-sql/preview", response_model=ApiResponse[CaseSqlImportPreviewResponse])
async def preview_case_sql(
file: UploadFile = File(...),
_: UserContext = Depends(get_user_context),
):
"""病例 SQL 预检:上传接口 SQL 文件,解析可导入病例数据但不写入数据库。"""
return ok(await CaseSqlImportService().preview(file))
@router.post("/case-sql/apply", response_model=ApiResponse[CaseSqlImportApplyResponse])
async def apply_case_sql(
file: UploadFile = File(...),
_: UserContext = Depends(get_user_context),
):
"""病例 SQL 导入:确认后把 SQL 中的病例表数据映射写入当前本地数据库。"""
return ok(await CaseSqlImportService().apply(file))
+1 -2
View File
@@ -1,6 +1,6 @@
from fastapi import APIRouter
from app.api import agent, auth, cases, evaluations, imports, knowledge, llm_test, sessions
from app.api import agent, auth, cases, evaluations, knowledge, llm_test, sessions
api_router = APIRouter()
api_router.include_router(agent.router, tags=["agent"])
@@ -10,4 +10,3 @@ api_router.include_router(sessions.router, prefix="/sessions", tags=["sessions"]
api_router.include_router(evaluations.router, prefix="/evaluations", tags=["evaluations"])
api_router.include_router(knowledge.router, prefix="/knowledge", tags=["knowledge"])
api_router.include_router(llm_test.router, prefix="/llm/test", tags=["llm-test"])
api_router.include_router(imports.router, prefix="/imports", tags=["imports"])
+9 -132
View File
@@ -1,13 +1,11 @@
from sqlalchemy import delete, exists, func, or_, select
from sqlalchemy import exists, select
from sqlalchemy.orm import Session, selectinload
from app.models.source_case import CaseBase, CaseExamItem, ScoringRule, TeachingCase, TraditionalCase
from app.models.training import SessionOrder, SessionSubmission, TrainingSession
from app.models.training_record import TrainingRecord, TrainingScoreDetail
from app.models.source_case import CaseBase, CaseExamItem, TeachingCase, TraditionalCase
class CaseRepository:
"""病例仓储:基于 case_base 新表体系读取病例、扩展表和检查项目。"""
"""病例只读仓储:读取已发布病例、模式扩展数据和固定检查项目。"""
def __init__(self, db: Session) -> None:
self.db = db
@@ -18,7 +16,8 @@ class CaseRepository:
training_type: str | None = None,
mode: str | None = None,
) -> list[CaseBase]:
"""病例列表:从 case_base 读取已发布病例,并按模式匹配传统/教学扩展表。"""
"""病例列表:从 case_base 读取已发布病例,并按模式匹配扩展表。"""
normalized_mode = "practice" if mode == "novice" else mode
stmt = (
select(CaseBase)
.options(selectinload(CaseBase.traditional_case), selectinload(CaseBase.teaching_case))
@@ -28,14 +27,14 @@ class CaseRepository:
stmt = stmt.where(CaseBase.department_id == department_id)
if training_type:
stmt = stmt.where(CaseBase.case_type == training_type)
if mode == "practice":
if normalized_mode == "practice":
stmt = stmt.where(exists().where(TraditionalCase.case_id == CaseBase.id))
if mode == "teaching":
if normalized_mode == "teaching":
stmt = stmt.where(exists().where(TeachingCase.case_id == CaseBase.id))
return list(self.db.scalars(stmt.order_by(CaseBase.id.desc())).all())
def get_active_case(self, case_id: int) -> CaseBase | None:
"""病例详情:读取 case_base,并加载传统病例、教学病例、评分规则和检查项目。"""
"""病例详情:读取病例主表及训练所需的扩展表、评分规则和检查项目。"""
stmt = (
select(CaseBase)
.options(
@@ -48,134 +47,12 @@ class CaseRepository:
)
return self.db.scalar(stmt)
def get_case_by_id(self, case_id: int) -> CaseBase | None:
"""病例删除:按主键读取病例,不限制发布状态,用于删除前校验。"""
return self.db.get(CaseBase, case_id)
def get_delete_preview_counts(self, case_id: int) -> dict[str, int]:
"""病例删除预览:统计删除病例会影响的业务数据数量。"""
session_ids = self._session_ids(case_id)
return {
"case_base": self._count(CaseBase, CaseBase.id == case_id),
"traditional_case": self._count(TraditionalCase, TraditionalCase.case_id == case_id),
"teaching_case": self._count(TeachingCase, TeachingCase.case_id == case_id),
"scoring_rule": self._count(ScoringRule, ScoringRule.case_id == case_id),
"case_exam_item": self._count(CaseExamItem, CaseExamItem.case_id == case_id),
"training_session": len(session_ids),
"training_order": self._count_training_orders(case_id, session_ids),
"training_submission": self._count_by_sessions(SessionSubmission, SessionSubmission.session_id, session_ids),
"training_score_detail": self._count_score_details(case_id, session_ids),
"training_record": self._count_training_records(case_id, session_ids),
}
def delete_case_cascade(self, case_id: int) -> dict[str, int]:
"""病例删除执行:按外键依赖顺序清理训练数据、检查项、评分规则和病例主表。"""
session_ids = self._session_ids(case_id)
deleted: dict[str, int] = {}
deleted["training_order"] = self._delete_training_orders(case_id, session_ids)
deleted["training_submission"] = self._delete_by_sessions(
SessionSubmission, SessionSubmission.session_id, session_ids
)
deleted["training_score_detail"] = self._delete_score_details(case_id, session_ids)
deleted["training_record"] = self._delete_training_records(case_id, session_ids)
deleted["training_session"] = self._delete_where(TrainingSession, TrainingSession.case_id == case_id)
deleted["case_exam_item"] = self._delete_where(CaseExamItem, CaseExamItem.case_id == case_id)
deleted["scoring_rule"] = self._delete_where(ScoringRule, ScoringRule.case_id == case_id)
deleted["traditional_case"] = self._delete_where(TraditionalCase, TraditionalCase.case_id == case_id)
deleted["teaching_case"] = self._delete_where(TeachingCase, TeachingCase.case_id == case_id)
deleted["case_base"] = self._delete_where(CaseBase, CaseBase.id == case_id)
return deleted
def _session_ids(self, case_id: int) -> list[int]:
"""病例删除:读取该病例关联的训练会话 ID 集合。"""
stmt = select(TrainingSession.id).where(TrainingSession.case_id == case_id)
return [int(item) for item in self.db.scalars(stmt).all()]
def _count(self, model: type, *criteria) -> int:
"""病例删除预览:按条件统计单表记录数。"""
stmt = select(func.count()).select_from(model).where(*criteria)
return int(self.db.scalar(stmt) or 0)
def _count_by_sessions(self, model: type, session_column, session_ids: list[int]) -> int:
"""病例删除预览:按训练会话集合统计从表记录数。"""
if not session_ids:
return 0
return self._count(model, session_column.in_(session_ids))
def _count_training_orders(self, case_id: int, session_ids: list[int]) -> int:
"""病例删除预览:统计检查申请记录,兼容按病例和按会话两种关联。"""
if session_ids:
return self._count(SessionOrder, or_(SessionOrder.case_id == case_id, SessionOrder.session_id.in_(session_ids)))
return self._count(SessionOrder, SessionOrder.case_id == case_id)
def _count_training_records(self, case_id: int, session_ids: list[int]) -> int:
"""病例删除预览:统计完整训练记录,兼容按病例和按会话两种关联。"""
if session_ids:
return self._count(
TrainingRecord,
or_(TrainingRecord.case_id == case_id, TrainingRecord.session_id.in_(session_ids)),
)
return self._count(TrainingRecord, TrainingRecord.case_id == case_id)
def _count_score_details(self, case_id: int, session_ids: list[int]) -> int:
"""病例删除预览:统计该病例评价记录下的评分明细。"""
record_ids = self._record_ids(case_id, session_ids)
if not record_ids:
return 0
return self._count(TrainingScoreDetail, TrainingScoreDetail.record_id.in_(record_ids))
def _delete_where(self, model: type, *criteria) -> int:
"""病例删除执行:按条件删除单表记录并返回影响行数。"""
result = self.db.execute(delete(model).where(*criteria))
return int(result.rowcount or 0)
def _delete_by_sessions(self, model: type, session_column, session_ids: list[int]) -> int:
"""病例删除执行:按训练会话集合删除从表记录。"""
if not session_ids:
return 0
return self._delete_where(model, session_column.in_(session_ids))
def _delete_training_orders(self, case_id: int, session_ids: list[int]) -> int:
"""病例删除执行:删除该病例下所有检查申请记录,避免阻塞检查项删除。"""
if session_ids:
return self._delete_where(
SessionOrder,
or_(SessionOrder.case_id == case_id, SessionOrder.session_id.in_(session_ids)),
)
return self._delete_where(SessionOrder, SessionOrder.case_id == case_id)
def _delete_training_records(self, case_id: int, session_ids: list[int]) -> int:
"""病例删除执行:删除该病例完整训练后沉淀的评价记录。"""
if session_ids:
return self._delete_where(
TrainingRecord,
or_(TrainingRecord.case_id == case_id, TrainingRecord.session_id.in_(session_ids)),
)
return self._delete_where(TrainingRecord, TrainingRecord.case_id == case_id)
def _delete_score_details(self, case_id: int, session_ids: list[int]) -> int:
"""病例删除执行:先删除评价明细,避免阻塞训练记录删除。"""
record_ids = self._record_ids(case_id, session_ids)
if not record_ids:
return 0
return self._delete_where(TrainingScoreDetail, TrainingScoreDetail.record_id.in_(record_ids))
def _record_ids(self, case_id: int, session_ids: list[int]) -> list[int]:
"""病例删除:读取该病例关联的训练记录 ID 集合。"""
if session_ids:
stmt = select(TrainingRecord.id).where(
or_(TrainingRecord.case_id == case_id, TrainingRecord.session_id.in_(session_ids))
)
else:
stmt = select(TrainingRecord.id).where(TrainingRecord.case_id == case_id)
return [int(item) for item in self.db.scalars(stmt).all()]
def get_exam_items(self, case_id: int) -> list[CaseExamItem]:
"""检查项目:读取当前病例下全部可申请检查检验项目。"""
stmt = select(CaseExamItem).where(CaseExamItem.case_id == case_id).order_by(CaseExamItem.display_order)
return list(self.db.scalars(stmt).all())
def get_exam_item(self, case_id: int, item_code: str) -> CaseExamItem | None:
"""检查结果:按病例和项目编码读取固定检查检验结果。"""
"""检查结果:按病例和项目编码读取数据库中的固定结果。"""
stmt = select(CaseExamItem).where(CaseExamItem.case_id == case_id, CaseExamItem.item_code == item_code)
return self.db.scalar(stmt)
+5 -50
View File
@@ -1,56 +1,18 @@
from __future__ import annotations
from sqlalchemy import exists, select
from sqlalchemy.orm import Session, selectinload
from sqlalchemy import select
from sqlalchemy.orm import Session
from app.models.department import Department
from app.models.source_case import CaseBase, ScoringRule, TeachingCase, TraditionalCase
from app.models.source_case import ScoringRule
class SourceCaseRepository:
"""源库病例仓储:读取 case_base、traditional_case、teaching_case 和 scoring_rule"""
"""病例辅助只读仓储:读取科室名称和病例评分规则"""
def __init__(self, db: Session) -> None:
self.db = db
def list_active_cases(
self,
department_id: int | None = None,
case_type: str | None = None,
mode: str | None = None,
) -> list[CaseBase]:
"""源库病例列表:按科室、病例分类和训练模式读取已发布病例。"""
stmt = (
select(CaseBase)
.options(selectinload(CaseBase.traditional_case), selectinload(CaseBase.teaching_case))
.where(CaseBase.status == 1, CaseBase.publish_status == 1)
)
if department_id:
stmt = stmt.where(CaseBase.department_id == department_id)
if case_type:
stmt = stmt.where(CaseBase.case_type == case_type)
normalized_mode = self.normalize_mode(mode)
if normalized_mode == "practice":
stmt = stmt.where(exists().where(TraditionalCase.case_id == CaseBase.id))
if normalized_mode == "teaching":
stmt = stmt.where(exists().where(TeachingCase.case_id == CaseBase.id))
return list(self.db.scalars(stmt.order_by(CaseBase.id.desc())).all())
def get_active_case_base(self, case_id: int) -> CaseBase | None:
"""源库病例详情:读取病例主表及传统/教学扩展表。"""
stmt = (
select(CaseBase)
.options(
selectinload(CaseBase.traditional_case),
selectinload(CaseBase.teaching_case),
selectinload(CaseBase.scoring_rules),
)
.where(CaseBase.id == case_id, CaseBase.status == 1, CaseBase.publish_status == 1)
)
return self.db.scalar(stmt)
def get_department_name(self, department_id: int | None) -> str:
"""科室名称:按用户端 department 表读取科室名称。"""
"""科室名称:按 department 表读取当前病例所属科室名称。"""
if not department_id:
return ""
department = self.db.scalar(select(Department).where(Department.id == department_id))
@@ -60,10 +22,3 @@ class SourceCaseRepository:
"""评分规则:读取当前病例对应的基础评分细则。"""
stmt = select(ScoringRule).where(ScoringRule.case_id == case_id).order_by(ScoringRule.id)
return list(self.db.scalars(stmt).all())
@staticmethod
def normalize_mode(mode: str | None) -> str | None:
"""模式归一:旧 novice 请求按练习模式处理,第一版只暴露 practice/teaching。"""
if mode == "novice":
return "practice"
return mode
-24
View File
@@ -50,27 +50,3 @@ class CaseDetailResponse(BaseModel):
has_knowledge_points: bool
has_quiz: bool
order_item_types: list[str]
class CaseDeletePreviewResponse(BaseModel):
"""病例删除预览:返回删除该病例会影响的业务数据数量。"""
case_id: int
case_title: str
can_delete: bool
affected: dict[str, int]
class CaseDeleteRequest(BaseModel):
"""病例删除请求:前端必须显式确认,并默认同时删除该病例训练数据。"""
confirm: bool = False
delete_training_data: bool = True
class CaseDeleteResponse(BaseModel):
"""病例删除结果:返回已删除的各表记录数量。"""
deleted: bool
case_id: int
deleted_counts: dict[str, int]
-36
View File
@@ -1,36 +0,0 @@
from pydantic import BaseModel, Field
class CaseSqlPreviewCase(BaseModel):
"""病例 SQL 预览病例:展示导入文件中识别到的病例摘要。"""
id: int
title: str
case_type: str
difficulty: str
class CaseSqlImportPreviewResponse(BaseModel):
"""病例 SQL 预检响应:只展示解析结果,不写入数据库。"""
file_name: str
encoding: str | None = None
tables: dict[str, int] = Field(default_factory=dict)
can_import: bool = False
warnings: list[str] = Field(default_factory=list)
errors: list[str] = Field(default_factory=list)
preview_cases: list[CaseSqlPreviewCase] = Field(default_factory=list)
class CaseSqlImportApplyResponse(BaseModel):
"""病例 SQL 导入响应:展示实际写库结果。"""
imported: bool
file_name: str
encoding: str
inserted_or_updated_cases: int
imported_traditional_cases: int
imported_teaching_cases: int
imported_scoring_rules: int
generated_exam_items: int
warnings: list[str] = Field(default_factory=list)
+1 -56
View File
@@ -1,20 +1,10 @@
from sqlalchemy.orm import Session
from app.core.context import UserContext
from app.core.exceptions import AppError
from app.models.source_case import CaseBase
from app.repositories.case_repository import CaseRepository
from app.repositories.source_case_repository import SourceCaseRepository
from app.schemas.case import (
CaseDeletePreviewResponse,
CaseDeleteRequest,
CaseDeleteResponse,
CaseDetailResponse,
CaseListItem,
CaseListResponse,
CasePatientInfo,
)
from app.services.audit_service import AuditService
from app.schemas.case import CaseDetailResponse, CaseListItem, CaseListResponse, CasePatientInfo
class CaseService:
@@ -62,51 +52,6 @@ class CaseService:
order_item_types=sorted({item.item_type for item in order_items}),
)
def get_delete_preview(self, case_id: int) -> CaseDeletePreviewResponse:
"""病例删除预览:返回删除病例前端需要展示的影响范围。"""
case = self.repo.get_case_by_id(case_id)
if not case:
raise AppError("CASE_NOT_FOUND", "case not found", 404)
return CaseDeletePreviewResponse(
case_id=case.id,
case_title=case.title,
can_delete=True,
affected=self.repo.get_delete_preview_counts(case.id),
)
def delete_case(self, case_id: int, payload: CaseDeleteRequest, ctx: UserContext) -> CaseDeleteResponse:
"""病例删除:级联删除病例业务数据,并保留审计日志用于追踪操作。"""
case = self.repo.get_case_by_id(case_id)
if not case:
raise AppError("CASE_NOT_FOUND", "case not found", 404)
if not payload.confirm:
raise AppError("CASE_DELETE_CONFIRM_REQUIRED", "case delete confirmation is required", 400)
preview = self.repo.get_delete_preview_counts(case_id)
training_rows = (
preview.get("training_session", 0)
+ preview.get("training_order", 0)
+ preview.get("training_submission", 0)
+ preview.get("training_record", 0)
)
if training_rows and not payload.delete_training_data:
raise AppError("CASE_DELETE_TRAINING_DATA_EXISTS", "case has training data; delete_training_data must be true", 400)
try:
deleted_counts = self.repo.delete_case_cascade(case_id)
AuditService(self.db).log(
ctx,
action="case.delete",
resource_type="case",
resource_id=str(case_id),
metadata={"case_title": case.title, "deleted_counts": deleted_counts},
)
self.db.commit()
except Exception:
self.db.rollback()
raise
return CaseDeleteResponse(deleted=True, case_id=case_id, deleted_counts=deleted_counts)
def _to_list_item(self, case: CaseBase) -> CaseListItem:
"""病例卡片转换:把 case_base 映射为当前前端病例列表结构。"""
return CaseListItem(
-101
View File
@@ -1,101 +0,0 @@
from __future__ import annotations
import tempfile
from pathlib import Path
from fastapi import UploadFile
from app.core.exceptions import AppError
from app.schemas.imports import CaseSqlImportApplyResponse, CaseSqlImportPreviewResponse, CaseSqlPreviewCase
from scripts.import_source_case_sql import ImportValidationError, import_source_sql, parse_source_dump
MAX_SQL_UPLOAD_BYTES = 5 * 1024 * 1024
ALLOWED_TABLES = {"case_base", "traditional_case", "teaching_case", "scoring_rule"}
class CaseSqlImportService:
"""病例 SQL 导入服务:解析接口 SQL 文件并安全映射到当前病例表。"""
async def preview(self, file: UploadFile) -> CaseSqlImportPreviewResponse:
"""导入预检:上传 SQL 后只解析结构和数据,不写入数据库。"""
temp_path = await self._save_upload_to_temp(file)
try:
parsed, warnings, encoding = parse_source_dump(temp_path)
table_rows = self._allowed_table_counts(parsed)
preview_cases = [
CaseSqlPreviewCase(
id=int(row["id"]),
title=str(row.get("title") or ""),
case_type=str(row.get("case_type") or ""),
difficulty=str(row.get("difficulty") or ""),
)
for row in parsed.get("case_base", [])
]
return CaseSqlImportPreviewResponse(
file_name=file.filename or "case.sql",
encoding=encoding,
tables=table_rows,
can_import=True,
warnings=warnings,
errors=[],
preview_cases=preview_cases,
)
except ImportValidationError as exc:
return CaseSqlImportPreviewResponse(
file_name=file.filename or "case.sql",
can_import=False,
errors=[str(exc)],
)
finally:
self._remove_temp_file(temp_path)
async def apply(self, file: UploadFile) -> CaseSqlImportApplyResponse:
"""确认导入:解析通过后以事务方式写入 case_base/traditional_case/teaching_case/scoring_rule。"""
temp_path = await self._save_upload_to_temp(file)
try:
report = import_source_sql(temp_path, apply=True)
return CaseSqlImportApplyResponse(
imported=True,
file_name=file.filename or "case.sql",
encoding=report.encoding,
inserted_or_updated_cases=report.upserted_cases,
imported_traditional_cases=report.upserted_traditional_cases,
imported_teaching_cases=report.upserted_teaching_cases,
imported_scoring_rules=report.replaced_scoring_rules,
generated_exam_items=report.generated_exam_items,
warnings=report.warnings,
)
except ImportValidationError as exc:
raise AppError("CASE_SQL_IMPORT_INVALID", str(exc), 400) from exc
finally:
self._remove_temp_file(temp_path)
async def _save_upload_to_temp(self, file: UploadFile) -> Path:
"""上传落盘:校验 SQL 后缀和大小后保存到临时文件,供解析器读取。"""
filename = file.filename or ""
if not filename.lower().endswith(".sql"):
raise AppError("CASE_SQL_FILE_INVALID", "only .sql files are supported", 400)
content = await file.read()
if not content:
raise AppError("CASE_SQL_FILE_EMPTY", "uploaded SQL file is empty", 400)
if len(content) > MAX_SQL_UPLOAD_BYTES:
raise AppError("CASE_SQL_FILE_TOO_LARGE", "SQL file is larger than 5MB", 400)
handle = tempfile.NamedTemporaryFile(delete=False, suffix=".sql")
try:
handle.write(content)
return Path(handle.name)
finally:
handle.close()
def _allowed_table_counts(self, parsed: dict[str, list[dict]]) -> dict[str, int]:
"""表计数:只暴露当前业务允许导入的病例表。"""
return {table: len(rows) for table, rows in parsed.items() if table in ALLOWED_TABLES}
def _remove_temp_file(self, path: Path) -> None:
"""临时文件清理:解析或导入完成后删除上传副本。"""
try:
path.unlink(missing_ok=True)
except OSError:
pass