make case catalog read-only

This commit is contained in:
刘金宝
2026-06-04 17:50:22 +08:00
parent b46e43aadc
commit 7f1803f9fa
15 changed files with 35 additions and 1268 deletions
+4 -3
View File
@@ -2,6 +2,8 @@
医疗问诊 Agent 是医疗教学平台中的问诊训练服务。后端负责 Django 用户身份校验、病例读取、多轮问诊、检查申请、诊断治疗提交、AI 评价、评分明细、PDF 报告和历史训练记录。
病例库在本服务中为只读数据源。病例新增、解析、修改和删除由外部病例管理系统负责;本服务只读取已发布病例及其训练扩展、检查项和评分规则。
## 项目结构
仓库根目录可以直接部署为服务器的 `fastapi/` 目录:
@@ -127,14 +129,14 @@ docker compose exec fastapi python scripts/check_final_schema.py
docker compose exec fastapi python scripts/check_final_demo_readiness.py
```
仅在结构检查确认缺少 Agent 所需表时,备份数据库后执行
以下迁移脚本只用于独立本地开发库或旧环境升级,不得用于共享生产病例库
```bash
docker compose exec fastapi python scripts/migrate_to_new_schema.py
docker compose exec fastapi python scripts/migrate_user_department_score_detail.py
```
迁移脚本使用 `create_all` 补齐 Agent 所需表,不删除 Django 或现有业务表;`migrate_to_new_schema.py`在缺少对应数据时写入 Demo 病例和基础数据。
迁移脚本使用 `create_all` 补齐 Agent 所需表,不删除 Django 或现有业务表;`migrate_to_new_schema.py` 会写入 Demo 病例和基础数据。共享环境中的病例数据由外部病例管理系统维护。
## 部署验证
@@ -170,5 +172,4 @@ python -m compileall app scripts tests
python tests\test_core_logic.py
python tests\test_api_contract.py
python tests\test_demo_flow.py
python tests\test_import_source_case_sql.py
```
+1 -28
View File
@@ -4,13 +4,7 @@ from sqlalchemy.orm import Session
from app.core.response import ApiResponse, ok
from app.core.user_context import UserContext, get_user_context
from app.db.session import get_db
from app.schemas.case import (
CaseDeletePreviewResponse,
CaseDeleteRequest,
CaseDeleteResponse,
CaseDetailResponse,
CaseListResponse,
)
from app.schemas.case import CaseDetailResponse, CaseListResponse
from app.services.case_service import CaseService
router = APIRouter()
@@ -36,24 +30,3 @@ def get_case_detail(
):
"""病例详情:返回训练入口信息和可申请检查类型。"""
return ok(CaseService(db).get_case_detail(case_id))
@router.get("/{case_id}/delete-preview", response_model=ApiResponse[CaseDeletePreviewResponse])
def get_case_delete_preview(
case_id: int,
_: UserContext = Depends(get_user_context),
db: Session = Depends(get_db),
):
"""病例删除预览:返回删除该病例会影响的训练与病例数据数量。"""
return ok(CaseService(db).get_delete_preview(case_id))
@router.delete("/{case_id}", response_model=ApiResponse[CaseDeleteResponse])
def delete_case(
case_id: int,
payload: CaseDeleteRequest,
ctx: UserContext = Depends(get_user_context),
db: Session = Depends(get_db),
):
"""病例删除:确认后级联删除病例、扩展表、评分规则、检查项和关联训练数据。"""
return ok(CaseService(db).delete_case(case_id, payload, ctx))
-26
View File
@@ -1,26 +0,0 @@
from fastapi import APIRouter, Depends, File, UploadFile
from app.core.response import ApiResponse, ok
from app.core.user_context import UserContext, get_user_context
from app.schemas.imports import CaseSqlImportApplyResponse, CaseSqlImportPreviewResponse
from app.services.case_sql_import_service import CaseSqlImportService
router = APIRouter()
@router.post("/case-sql/preview", response_model=ApiResponse[CaseSqlImportPreviewResponse])
async def preview_case_sql(
file: UploadFile = File(...),
_: UserContext = Depends(get_user_context),
):
"""病例 SQL 预检:上传接口 SQL 文件,解析可导入病例数据但不写入数据库。"""
return ok(await CaseSqlImportService().preview(file))
@router.post("/case-sql/apply", response_model=ApiResponse[CaseSqlImportApplyResponse])
async def apply_case_sql(
file: UploadFile = File(...),
_: UserContext = Depends(get_user_context),
):
"""病例 SQL 导入:确认后把 SQL 中的病例表数据映射写入当前本地数据库。"""
return ok(await CaseSqlImportService().apply(file))
+1 -2
View File
@@ -1,6 +1,6 @@
from fastapi import APIRouter
from app.api import agent, auth, cases, evaluations, imports, knowledge, llm_test, sessions
from app.api import agent, auth, cases, evaluations, knowledge, llm_test, sessions
api_router = APIRouter()
api_router.include_router(agent.router, tags=["agent"])
@@ -10,4 +10,3 @@ api_router.include_router(sessions.router, prefix="/sessions", tags=["sessions"]
api_router.include_router(evaluations.router, prefix="/evaluations", tags=["evaluations"])
api_router.include_router(knowledge.router, prefix="/knowledge", tags=["knowledge"])
api_router.include_router(llm_test.router, prefix="/llm/test", tags=["llm-test"])
api_router.include_router(imports.router, prefix="/imports", tags=["imports"])
+9 -132
View File
@@ -1,13 +1,11 @@
from sqlalchemy import delete, exists, func, or_, select
from sqlalchemy import exists, select
from sqlalchemy.orm import Session, selectinload
from app.models.source_case import CaseBase, CaseExamItem, ScoringRule, TeachingCase, TraditionalCase
from app.models.training import SessionOrder, SessionSubmission, TrainingSession
from app.models.training_record import TrainingRecord, TrainingScoreDetail
from app.models.source_case import CaseBase, CaseExamItem, TeachingCase, TraditionalCase
class CaseRepository:
"""病例仓储:基于 case_base 新表体系读取病例、扩展表和检查项目。"""
"""病例只读仓储:读取已发布病例、模式扩展数据和固定检查项目。"""
def __init__(self, db: Session) -> None:
self.db = db
@@ -18,7 +16,8 @@ class CaseRepository:
training_type: str | None = None,
mode: str | None = None,
) -> list[CaseBase]:
"""病例列表:从 case_base 读取已发布病例,并按模式匹配传统/教学扩展表。"""
"""病例列表:从 case_base 读取已发布病例,并按模式匹配扩展表。"""
normalized_mode = "practice" if mode == "novice" else mode
stmt = (
select(CaseBase)
.options(selectinload(CaseBase.traditional_case), selectinload(CaseBase.teaching_case))
@@ -28,14 +27,14 @@ class CaseRepository:
stmt = stmt.where(CaseBase.department_id == department_id)
if training_type:
stmt = stmt.where(CaseBase.case_type == training_type)
if mode == "practice":
if normalized_mode == "practice":
stmt = stmt.where(exists().where(TraditionalCase.case_id == CaseBase.id))
if mode == "teaching":
if normalized_mode == "teaching":
stmt = stmt.where(exists().where(TeachingCase.case_id == CaseBase.id))
return list(self.db.scalars(stmt.order_by(CaseBase.id.desc())).all())
def get_active_case(self, case_id: int) -> CaseBase | None:
"""病例详情:读取 case_base,并加载传统病例、教学病例、评分规则和检查项目。"""
"""病例详情:读取病例主表及训练所需的扩展表、评分规则和检查项目。"""
stmt = (
select(CaseBase)
.options(
@@ -48,134 +47,12 @@ class CaseRepository:
)
return self.db.scalar(stmt)
def get_case_by_id(self, case_id: int) -> CaseBase | None:
"""病例删除:按主键读取病例,不限制发布状态,用于删除前校验。"""
return self.db.get(CaseBase, case_id)
def get_delete_preview_counts(self, case_id: int) -> dict[str, int]:
"""病例删除预览:统计删除病例会影响的业务数据数量。"""
session_ids = self._session_ids(case_id)
return {
"case_base": self._count(CaseBase, CaseBase.id == case_id),
"traditional_case": self._count(TraditionalCase, TraditionalCase.case_id == case_id),
"teaching_case": self._count(TeachingCase, TeachingCase.case_id == case_id),
"scoring_rule": self._count(ScoringRule, ScoringRule.case_id == case_id),
"case_exam_item": self._count(CaseExamItem, CaseExamItem.case_id == case_id),
"training_session": len(session_ids),
"training_order": self._count_training_orders(case_id, session_ids),
"training_submission": self._count_by_sessions(SessionSubmission, SessionSubmission.session_id, session_ids),
"training_score_detail": self._count_score_details(case_id, session_ids),
"training_record": self._count_training_records(case_id, session_ids),
}
def delete_case_cascade(self, case_id: int) -> dict[str, int]:
"""病例删除执行:按外键依赖顺序清理训练数据、检查项、评分规则和病例主表。"""
session_ids = self._session_ids(case_id)
deleted: dict[str, int] = {}
deleted["training_order"] = self._delete_training_orders(case_id, session_ids)
deleted["training_submission"] = self._delete_by_sessions(
SessionSubmission, SessionSubmission.session_id, session_ids
)
deleted["training_score_detail"] = self._delete_score_details(case_id, session_ids)
deleted["training_record"] = self._delete_training_records(case_id, session_ids)
deleted["training_session"] = self._delete_where(TrainingSession, TrainingSession.case_id == case_id)
deleted["case_exam_item"] = self._delete_where(CaseExamItem, CaseExamItem.case_id == case_id)
deleted["scoring_rule"] = self._delete_where(ScoringRule, ScoringRule.case_id == case_id)
deleted["traditional_case"] = self._delete_where(TraditionalCase, TraditionalCase.case_id == case_id)
deleted["teaching_case"] = self._delete_where(TeachingCase, TeachingCase.case_id == case_id)
deleted["case_base"] = self._delete_where(CaseBase, CaseBase.id == case_id)
return deleted
def _session_ids(self, case_id: int) -> list[int]:
"""病例删除:读取该病例关联的训练会话 ID 集合。"""
stmt = select(TrainingSession.id).where(TrainingSession.case_id == case_id)
return [int(item) for item in self.db.scalars(stmt).all()]
def _count(self, model: type, *criteria) -> int:
"""病例删除预览:按条件统计单表记录数。"""
stmt = select(func.count()).select_from(model).where(*criteria)
return int(self.db.scalar(stmt) or 0)
def _count_by_sessions(self, model: type, session_column, session_ids: list[int]) -> int:
"""病例删除预览:按训练会话集合统计从表记录数。"""
if not session_ids:
return 0
return self._count(model, session_column.in_(session_ids))
def _count_training_orders(self, case_id: int, session_ids: list[int]) -> int:
"""病例删除预览:统计检查申请记录,兼容按病例和按会话两种关联。"""
if session_ids:
return self._count(SessionOrder, or_(SessionOrder.case_id == case_id, SessionOrder.session_id.in_(session_ids)))
return self._count(SessionOrder, SessionOrder.case_id == case_id)
def _count_training_records(self, case_id: int, session_ids: list[int]) -> int:
"""病例删除预览:统计完整训练记录,兼容按病例和按会话两种关联。"""
if session_ids:
return self._count(
TrainingRecord,
or_(TrainingRecord.case_id == case_id, TrainingRecord.session_id.in_(session_ids)),
)
return self._count(TrainingRecord, TrainingRecord.case_id == case_id)
def _count_score_details(self, case_id: int, session_ids: list[int]) -> int:
"""病例删除预览:统计该病例评价记录下的评分明细。"""
record_ids = self._record_ids(case_id, session_ids)
if not record_ids:
return 0
return self._count(TrainingScoreDetail, TrainingScoreDetail.record_id.in_(record_ids))
def _delete_where(self, model: type, *criteria) -> int:
"""病例删除执行:按条件删除单表记录并返回影响行数。"""
result = self.db.execute(delete(model).where(*criteria))
return int(result.rowcount or 0)
def _delete_by_sessions(self, model: type, session_column, session_ids: list[int]) -> int:
"""病例删除执行:按训练会话集合删除从表记录。"""
if not session_ids:
return 0
return self._delete_where(model, session_column.in_(session_ids))
def _delete_training_orders(self, case_id: int, session_ids: list[int]) -> int:
"""病例删除执行:删除该病例下所有检查申请记录,避免阻塞检查项删除。"""
if session_ids:
return self._delete_where(
SessionOrder,
or_(SessionOrder.case_id == case_id, SessionOrder.session_id.in_(session_ids)),
)
return self._delete_where(SessionOrder, SessionOrder.case_id == case_id)
def _delete_training_records(self, case_id: int, session_ids: list[int]) -> int:
"""病例删除执行:删除该病例完整训练后沉淀的评价记录。"""
if session_ids:
return self._delete_where(
TrainingRecord,
or_(TrainingRecord.case_id == case_id, TrainingRecord.session_id.in_(session_ids)),
)
return self._delete_where(TrainingRecord, TrainingRecord.case_id == case_id)
def _delete_score_details(self, case_id: int, session_ids: list[int]) -> int:
"""病例删除执行:先删除评价明细,避免阻塞训练记录删除。"""
record_ids = self._record_ids(case_id, session_ids)
if not record_ids:
return 0
return self._delete_where(TrainingScoreDetail, TrainingScoreDetail.record_id.in_(record_ids))
def _record_ids(self, case_id: int, session_ids: list[int]) -> list[int]:
"""病例删除:读取该病例关联的训练记录 ID 集合。"""
if session_ids:
stmt = select(TrainingRecord.id).where(
or_(TrainingRecord.case_id == case_id, TrainingRecord.session_id.in_(session_ids))
)
else:
stmt = select(TrainingRecord.id).where(TrainingRecord.case_id == case_id)
return [int(item) for item in self.db.scalars(stmt).all()]
def get_exam_items(self, case_id: int) -> list[CaseExamItem]:
"""检查项目:读取当前病例下全部可申请检查检验项目。"""
stmt = select(CaseExamItem).where(CaseExamItem.case_id == case_id).order_by(CaseExamItem.display_order)
return list(self.db.scalars(stmt).all())
def get_exam_item(self, case_id: int, item_code: str) -> CaseExamItem | None:
"""检查结果:按病例和项目编码读取固定检查检验结果。"""
"""检查结果:按病例和项目编码读取数据库中的固定结果。"""
stmt = select(CaseExamItem).where(CaseExamItem.case_id == case_id, CaseExamItem.item_code == item_code)
return self.db.scalar(stmt)
+5 -50
View File
@@ -1,56 +1,18 @@
from __future__ import annotations
from sqlalchemy import exists, select
from sqlalchemy.orm import Session, selectinload
from sqlalchemy import select
from sqlalchemy.orm import Session
from app.models.department import Department
from app.models.source_case import CaseBase, ScoringRule, TeachingCase, TraditionalCase
from app.models.source_case import ScoringRule
class SourceCaseRepository:
"""源库病例仓储:读取 case_base、traditional_case、teaching_case 和 scoring_rule"""
"""病例辅助只读仓储:读取科室名称和病例评分规则"""
def __init__(self, db: Session) -> None:
self.db = db
def list_active_cases(
self,
department_id: int | None = None,
case_type: str | None = None,
mode: str | None = None,
) -> list[CaseBase]:
"""源库病例列表:按科室、病例分类和训练模式读取已发布病例。"""
stmt = (
select(CaseBase)
.options(selectinload(CaseBase.traditional_case), selectinload(CaseBase.teaching_case))
.where(CaseBase.status == 1, CaseBase.publish_status == 1)
)
if department_id:
stmt = stmt.where(CaseBase.department_id == department_id)
if case_type:
stmt = stmt.where(CaseBase.case_type == case_type)
normalized_mode = self.normalize_mode(mode)
if normalized_mode == "practice":
stmt = stmt.where(exists().where(TraditionalCase.case_id == CaseBase.id))
if normalized_mode == "teaching":
stmt = stmt.where(exists().where(TeachingCase.case_id == CaseBase.id))
return list(self.db.scalars(stmt.order_by(CaseBase.id.desc())).all())
def get_active_case_base(self, case_id: int) -> CaseBase | None:
"""源库病例详情:读取病例主表及传统/教学扩展表。"""
stmt = (
select(CaseBase)
.options(
selectinload(CaseBase.traditional_case),
selectinload(CaseBase.teaching_case),
selectinload(CaseBase.scoring_rules),
)
.where(CaseBase.id == case_id, CaseBase.status == 1, CaseBase.publish_status == 1)
)
return self.db.scalar(stmt)
def get_department_name(self, department_id: int | None) -> str:
"""科室名称:按用户端 department 表读取科室名称。"""
"""科室名称:按 department 表读取当前病例所属科室名称。"""
if not department_id:
return ""
department = self.db.scalar(select(Department).where(Department.id == department_id))
@@ -60,10 +22,3 @@ class SourceCaseRepository:
"""评分规则:读取当前病例对应的基础评分细则。"""
stmt = select(ScoringRule).where(ScoringRule.case_id == case_id).order_by(ScoringRule.id)
return list(self.db.scalars(stmt).all())
@staticmethod
def normalize_mode(mode: str | None) -> str | None:
"""模式归一:旧 novice 请求按练习模式处理,第一版只暴露 practice/teaching。"""
if mode == "novice":
return "practice"
return mode
-24
View File
@@ -50,27 +50,3 @@ class CaseDetailResponse(BaseModel):
has_knowledge_points: bool
has_quiz: bool
order_item_types: list[str]
class CaseDeletePreviewResponse(BaseModel):
"""病例删除预览:返回删除该病例会影响的业务数据数量。"""
case_id: int
case_title: str
can_delete: bool
affected: dict[str, int]
class CaseDeleteRequest(BaseModel):
"""病例删除请求:前端必须显式确认,并默认同时删除该病例训练数据。"""
confirm: bool = False
delete_training_data: bool = True
class CaseDeleteResponse(BaseModel):
"""病例删除结果:返回已删除的各表记录数量。"""
deleted: bool
case_id: int
deleted_counts: dict[str, int]
-36
View File
@@ -1,36 +0,0 @@
from pydantic import BaseModel, Field
class CaseSqlPreviewCase(BaseModel):
"""病例 SQL 预览病例:展示导入文件中识别到的病例摘要。"""
id: int
title: str
case_type: str
difficulty: str
class CaseSqlImportPreviewResponse(BaseModel):
"""病例 SQL 预检响应:只展示解析结果,不写入数据库。"""
file_name: str
encoding: str | None = None
tables: dict[str, int] = Field(default_factory=dict)
can_import: bool = False
warnings: list[str] = Field(default_factory=list)
errors: list[str] = Field(default_factory=list)
preview_cases: list[CaseSqlPreviewCase] = Field(default_factory=list)
class CaseSqlImportApplyResponse(BaseModel):
"""病例 SQL 导入响应:展示实际写库结果。"""
imported: bool
file_name: str
encoding: str
inserted_or_updated_cases: int
imported_traditional_cases: int
imported_teaching_cases: int
imported_scoring_rules: int
generated_exam_items: int
warnings: list[str] = Field(default_factory=list)
+1 -56
View File
@@ -1,20 +1,10 @@
from sqlalchemy.orm import Session
from app.core.context import UserContext
from app.core.exceptions import AppError
from app.models.source_case import CaseBase
from app.repositories.case_repository import CaseRepository
from app.repositories.source_case_repository import SourceCaseRepository
from app.schemas.case import (
CaseDeletePreviewResponse,
CaseDeleteRequest,
CaseDeleteResponse,
CaseDetailResponse,
CaseListItem,
CaseListResponse,
CasePatientInfo,
)
from app.services.audit_service import AuditService
from app.schemas.case import CaseDetailResponse, CaseListItem, CaseListResponse, CasePatientInfo
class CaseService:
@@ -62,51 +52,6 @@ class CaseService:
order_item_types=sorted({item.item_type for item in order_items}),
)
def get_delete_preview(self, case_id: int) -> CaseDeletePreviewResponse:
"""病例删除预览:返回删除病例前端需要展示的影响范围。"""
case = self.repo.get_case_by_id(case_id)
if not case:
raise AppError("CASE_NOT_FOUND", "case not found", 404)
return CaseDeletePreviewResponse(
case_id=case.id,
case_title=case.title,
can_delete=True,
affected=self.repo.get_delete_preview_counts(case.id),
)
def delete_case(self, case_id: int, payload: CaseDeleteRequest, ctx: UserContext) -> CaseDeleteResponse:
"""病例删除:级联删除病例业务数据,并保留审计日志用于追踪操作。"""
case = self.repo.get_case_by_id(case_id)
if not case:
raise AppError("CASE_NOT_FOUND", "case not found", 404)
if not payload.confirm:
raise AppError("CASE_DELETE_CONFIRM_REQUIRED", "case delete confirmation is required", 400)
preview = self.repo.get_delete_preview_counts(case_id)
training_rows = (
preview.get("training_session", 0)
+ preview.get("training_order", 0)
+ preview.get("training_submission", 0)
+ preview.get("training_record", 0)
)
if training_rows and not payload.delete_training_data:
raise AppError("CASE_DELETE_TRAINING_DATA_EXISTS", "case has training data; delete_training_data must be true", 400)
try:
deleted_counts = self.repo.delete_case_cascade(case_id)
AuditService(self.db).log(
ctx,
action="case.delete",
resource_type="case",
resource_id=str(case_id),
metadata={"case_title": case.title, "deleted_counts": deleted_counts},
)
self.db.commit()
except Exception:
self.db.rollback()
raise
return CaseDeleteResponse(deleted=True, case_id=case_id, deleted_counts=deleted_counts)
def _to_list_item(self, case: CaseBase) -> CaseListItem:
"""病例卡片转换:把 case_base 映射为当前前端病例列表结构。"""
return CaseListItem(
-101
View File
@@ -1,101 +0,0 @@
from __future__ import annotations
import tempfile
from pathlib import Path
from fastapi import UploadFile
from app.core.exceptions import AppError
from app.schemas.imports import CaseSqlImportApplyResponse, CaseSqlImportPreviewResponse, CaseSqlPreviewCase
from scripts.import_source_case_sql import ImportValidationError, import_source_sql, parse_source_dump
MAX_SQL_UPLOAD_BYTES = 5 * 1024 * 1024
ALLOWED_TABLES = {"case_base", "traditional_case", "teaching_case", "scoring_rule"}
class CaseSqlImportService:
"""病例 SQL 导入服务:解析接口 SQL 文件并安全映射到当前病例表。"""
async def preview(self, file: UploadFile) -> CaseSqlImportPreviewResponse:
"""导入预检:上传 SQL 后只解析结构和数据,不写入数据库。"""
temp_path = await self._save_upload_to_temp(file)
try:
parsed, warnings, encoding = parse_source_dump(temp_path)
table_rows = self._allowed_table_counts(parsed)
preview_cases = [
CaseSqlPreviewCase(
id=int(row["id"]),
title=str(row.get("title") or ""),
case_type=str(row.get("case_type") or ""),
difficulty=str(row.get("difficulty") or ""),
)
for row in parsed.get("case_base", [])
]
return CaseSqlImportPreviewResponse(
file_name=file.filename or "case.sql",
encoding=encoding,
tables=table_rows,
can_import=True,
warnings=warnings,
errors=[],
preview_cases=preview_cases,
)
except ImportValidationError as exc:
return CaseSqlImportPreviewResponse(
file_name=file.filename or "case.sql",
can_import=False,
errors=[str(exc)],
)
finally:
self._remove_temp_file(temp_path)
async def apply(self, file: UploadFile) -> CaseSqlImportApplyResponse:
"""确认导入:解析通过后以事务方式写入 case_base/traditional_case/teaching_case/scoring_rule。"""
temp_path = await self._save_upload_to_temp(file)
try:
report = import_source_sql(temp_path, apply=True)
return CaseSqlImportApplyResponse(
imported=True,
file_name=file.filename or "case.sql",
encoding=report.encoding,
inserted_or_updated_cases=report.upserted_cases,
imported_traditional_cases=report.upserted_traditional_cases,
imported_teaching_cases=report.upserted_teaching_cases,
imported_scoring_rules=report.replaced_scoring_rules,
generated_exam_items=report.generated_exam_items,
warnings=report.warnings,
)
except ImportValidationError as exc:
raise AppError("CASE_SQL_IMPORT_INVALID", str(exc), 400) from exc
finally:
self._remove_temp_file(temp_path)
async def _save_upload_to_temp(self, file: UploadFile) -> Path:
"""上传落盘:校验 SQL 后缀和大小后保存到临时文件,供解析器读取。"""
filename = file.filename or ""
if not filename.lower().endswith(".sql"):
raise AppError("CASE_SQL_FILE_INVALID", "only .sql files are supported", 400)
content = await file.read()
if not content:
raise AppError("CASE_SQL_FILE_EMPTY", "uploaded SQL file is empty", 400)
if len(content) > MAX_SQL_UPLOAD_BYTES:
raise AppError("CASE_SQL_FILE_TOO_LARGE", "SQL file is larger than 5MB", 400)
handle = tempfile.NamedTemporaryFile(delete=False, suffix=".sql")
try:
handle.write(content)
return Path(handle.name)
finally:
handle.close()
def _allowed_table_counts(self, parsed: dict[str, list[dict]]) -> dict[str, int]:
"""表计数:只暴露当前业务允许导入的病例表。"""
return {table: len(rows) for table, rows in parsed.items() if table in ALLOWED_TABLES}
def _remove_temp_file(self, path: Path) -> None:
"""临时文件清理:解析或导入完成后删除上传副本。"""
try:
path.unlink(missing_ok=True)
except OSError:
pass
-612
View File
@@ -1,612 +0,0 @@
from __future__ import annotations
import argparse
import json
import re
import sys
from dataclasses import dataclass
from datetime import datetime
from decimal import Decimal
from pathlib import Path
from typing import Any
from sqlalchemy import delete, inspect, select
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
from app.db.session import SessionLocal
from app.models.source_case import CaseBase, CaseExamItem, ScoringRule, TeachingCase, TraditionalCase
SOURCE_TABLES = ("case_base",)
OPTIONAL_SOURCE_TABLES = ("traditional_case", "teaching_case", "scoring_rule", "case_exam_item")
DANGEROUS_PATTERNS = (
r"\bDROP\s+DATABASE\b",
r"\bDROP\s+TABLE\b",
r"\bTRUNCATE\b",
r"\bDELETE\s+FROM\b",
r"\bCREATE\s+DATABASE\b",
r"\bALTER\s+TABLE\b",
)
JSON_COLUMNS = {
"case_base": {"symptom_tags", "disease_tags", "competency_tags", "guideline_tags", "knowledge_points", "multimodal_assets"},
"scoring_rule": {"rubric_json"},
}
DATETIME_COLUMNS = {"created_at", "updated_at"}
DECIMAL_COLUMNS = {"score_weight"}
INT_COLUMNS = {
"id",
"difficulty_score",
"patient_age",
"estimated_minutes",
"vector_status",
"publish_status",
"status",
"created_by_id",
"department_id",
"case_id",
}
BOOL_COLUMNS = {"osce_enabled", "rag_enabled", "ai_auto_score", "osce_dimension"}
class ImportValidationError(Exception):
"""导入校验错误:源 SQL 不满足安全导入要求时中止。"""
@dataclass
class ImportReport:
"""导入报告:记录检查、写入和跳过情况。"""
source_path: str
encoding: str
table_rows: dict[str, int]
warnings: list[str]
applied: bool
upserted_cases: int = 0
upserted_traditional_cases: int = 0
upserted_teaching_cases: int = 0
replaced_scoring_rules: int = 0
generated_exam_items: int = 0
def as_dict(self) -> dict[str, Any]:
"""报告输出:转换为前端和命令行均可读的结构。"""
return {
"source_path": self.source_path,
"encoding": self.encoding,
"table_rows": self.table_rows,
"warnings": self.warnings,
"applied": self.applied,
"upserted_cases": self.upserted_cases,
"upserted_traditional_cases": self.upserted_traditional_cases,
"upserted_teaching_cases": self.upserted_teaching_cases,
"replaced_scoring_rules": self.replaced_scoring_rules,
"generated_exam_items": self.generated_exam_items,
}
def detect_encoding(path: Path) -> str:
"""编码识别:根据 BOM 和试读结果判断 SQL 文件编码。"""
raw = path.read_bytes()
if raw.startswith(b"\xff\xfe") or raw.startswith(b"\xfe\xff"):
return "utf-16"
if raw.startswith(b"\xef\xbb\xbf"):
return "utf-8-sig"
for encoding in ("utf-8", "utf-16"):
try:
raw.decode(encoding)
return encoding
except UnicodeDecodeError:
continue
raise ImportValidationError("SQL 文件编码无法识别,请提供 UTF-8 或 UTF-16 文件。")
def load_sql_text(path: Path) -> tuple[str, str]:
"""文件读取:读取接口提供的 SQL dump,并返回文本与编码。"""
if not path.exists():
raise ImportValidationError(f"SQL 文件不存在:{path}")
encoding = detect_encoding(path)
return path.read_text(encoding=encoding), encoding
def find_dangerous_statements(sql_text: str) -> list[str]:
"""安全扫描:识别源 dump 中不能在正式库直接执行的 DDL/DML。"""
hits: list[str] = []
for pattern in DANGEROUS_PATTERNS:
if re.search(pattern, sql_text, flags=re.IGNORECASE):
hits.append(pattern.replace(r"\b", "").replace("\\s+", " "))
return hits
def extract_create_columns(sql_text: str, table_name: str) -> list[str]:
"""字段提取:从 CREATE TABLE 中按源顺序读取字段名。"""
match = re.search(rf"CREATE TABLE `{re.escape(table_name)}` \((.*?)\) ENGINE=", sql_text, flags=re.S | re.I)
if not match:
return []
columns: list[str] = []
for line in match.group(1).splitlines():
stripped = line.strip().rstrip(",")
if stripped.startswith("`"):
columns.append(stripped.split("`", 2)[1])
return columns
def extract_insert_rows(sql_text: str, table_name: str, columns: list[str]) -> list[dict[str, Any]]:
"""数据提取:只解析 INSERT VALUES 数据,不执行源 SQL。"""
rows: list[dict[str, Any]] = []
pattern = re.compile(rf"INSERT INTO `{re.escape(table_name)}` VALUES\s*(.*?);", flags=re.S | re.I)
for match in pattern.finditer(sql_text):
tuples = parse_values_clause(match.group(1))
for values in tuples:
if len(values) != len(columns):
raise ImportValidationError(
f"{table_name} INSERT 字段数不匹配:期望 {len(columns)},实际 {len(values)}"
)
row = {columns[index]: normalize_value(table_name, columns[index], value) for index, value in enumerate(values)}
rows.append(row)
return rows
def parse_values_clause(values_clause: str) -> list[list[Any]]:
"""VALUES 解析:把 SQL values 子句解析为二维数组,损坏字符串会直接报错。"""
rows: list[list[Any]] = []
index = 0
length = len(values_clause)
while index < length:
while index < length and values_clause[index].isspace():
index += 1
if index >= length:
break
if values_clause[index] == ",":
index += 1
continue
if values_clause[index] != "(":
raise ImportValidationError(f"VALUES 子句格式错误:第 {index} 个字符不是 '('")
row, index = _parse_tuple(values_clause, index)
rows.append(row)
return rows
def _parse_tuple(text: str, start: int) -> tuple[list[Any], int]:
"""元组解析:解析一组括号内的 SQL 字段值。"""
values: list[Any] = []
token: list[str] = []
in_string = False
was_string = False
escaped = False
index = start + 1
while index < len(text):
char = text[index]
if in_string:
if escaped:
token.append(_unescape_char(char))
escaped = False
elif char == "\\":
escaped = True
elif char == ")":
recovered_at_end = _recover_unclosed_string_at_tuple_end(token)
if recovered_at_end:
string_value, raw_values = recovered_at_end
values.append(string_value)
values.extend(raw_values)
return values, index + 1
token.append(char)
elif char == "'":
recovered = _recover_misplaced_quote_separator(token, text[index + 1] if index + 1 < len(text) else "")
if recovered:
string_value, raw_values = recovered
values.append(string_value)
values.extend(raw_values)
token = []
was_string = True
index += 1
continue
in_string = False
else:
token.append(char)
index += 1
continue
if char == "'":
stripped = "".join(token).strip()
if stripped:
raise ImportValidationError(f"字符串前存在非法未引用内容:{stripped[:20]}")
in_string = True
was_string = True
index += 1
continue
if was_string and char.isspace():
index += 1
continue
if was_string and char not in {",", ")"}:
raise ImportValidationError(f"字符串后存在非法未引用内容:{char}{text[index + 1:index + 20]}")
if char == ",":
values.append(_coerce_raw_token("".join(token), was_string))
token = []
was_string = False
index += 1
continue
if char == ")":
values.append(_coerce_raw_token("".join(token), was_string))
return values, index + 1
token.append(char)
index += 1
raise ImportValidationError("VALUES 子句存在未闭合括号或未闭合字符串。")
def _recover_misplaced_quote_separator(token: list[str], next_char: str) -> tuple[str, list[Any]] | None:
"""兼容解析:修复接口 SQL 文本字段结尾写成 `文本,'next'` 的引号/逗号错位。"""
if not next_char or next_char in {",", ")"} or next_char.isspace():
return None
raw = "".join(token)
if not raw.endswith(","):
return None
body = raw[:-1]
trailing_values: list[Any] = []
while True:
head, separator, tail = body.rpartition(",")
if not separator or not _looks_like_unquoted_scalar(tail):
break
trailing_values.insert(0, _coerce_raw_token(tail, was_string=False))
body = head
if not body:
return None
return body, trailing_values
def _recover_unclosed_string_at_tuple_end(token: list[str]) -> tuple[str, list[Any]] | None:
"""兼容解析:修复文本字段缺少结束引号且后接 `,数字)` 的源 SQL。"""
body = "".join(token)
trailing_values: list[Any] = []
while True:
head, separator, tail = body.rpartition(",")
if not separator or not _looks_like_unquoted_scalar(tail):
break
trailing_values.insert(0, _coerce_raw_token(tail, was_string=False))
body = head
if not body or not trailing_values:
return None
return body, trailing_values
def _looks_like_unquoted_scalar(value: str) -> bool:
"""兼容解析:判断错位引号前夹带的字段是否是可安全恢复的未引用标量。"""
stripped = value.strip()
if not stripped:
return False
if stripped.upper() == "NULL":
return True
return bool(re.fullmatch(r"-?\d+(?:\.\d+)?", stripped))
def _unescape_char(char: str) -> str:
"""SQL 转义:处理 dump 中的常见反斜杠转义。"""
return {
"0": "\0",
"b": "\b",
"n": "\n",
"r": "\r",
"t": "\t",
"Z": "\x1a",
"\\": "\\",
"'": "'",
'"': '"',
}.get(char, char)
def _coerce_raw_token(raw: str, was_string: bool) -> Any:
"""原始字段转换:区分 SQL NULL、数字和字符串。"""
if was_string:
return raw
value = raw.strip()
if not value:
return ""
if value.upper() == "NULL":
return None
if re.fullmatch(r"-?\d+", value):
return int(value)
if re.fullmatch(r"-?\d+\.\d+", value):
return Decimal(value)
if re.search(r"[A-Za-z\u4e00-\u9fff]", value):
raise ImportValidationError(f"发现未引用文本字段:{value[:30]}")
return value
def normalize_value(table_name: str, column: str, value: Any) -> Any:
"""字段归一:按当前 ORM 需要转换 JSON、时间、布尔和数值。"""
if value is None:
return None
if column in JSON_COLUMNS.get(table_name, set()):
if isinstance(value, (list, dict)):
return value
try:
return json.loads(value)
except (TypeError, json.JSONDecodeError) as exc:
raise ImportValidationError(f"{table_name}.{column} 不是合法 JSON。") from exc
if column in DATETIME_COLUMNS and isinstance(value, str):
try:
return datetime.fromisoformat(value)
except ValueError as exc:
raise ImportValidationError(f"{table_name}.{column} 不是合法 datetime。") from exc
if column in BOOL_COLUMNS:
return bool(int(value))
if column in INT_COLUMNS and value is not None:
return int(value)
if column in DECIMAL_COLUMNS and value is not None:
return Decimal(str(value))
return value
def parse_source_dump(path: Path) -> tuple[dict[str, list[dict[str, Any]]], list[str], str]:
"""源文件解析:提取可导入表数据并返回兼容性警告。"""
sql_text, encoding = load_sql_text(path)
warnings = []
dangerous = find_dangerous_statements(sql_text)
if dangerous:
warnings.append("源 SQL 包含 DDL/DML 覆盖语句,导入器会忽略这些语句:" + ", ".join(sorted(set(dangerous))))
parsed: dict[str, list[dict[str, Any]]] = {}
for table_name in SOURCE_TABLES:
columns = extract_create_columns(sql_text, table_name)
if not columns:
raise ImportValidationError(f"源 SQL 缺少必需表结构:{table_name}")
parsed[table_name] = extract_insert_rows(sql_text, table_name, columns)
for table_name in OPTIONAL_SOURCE_TABLES:
columns = extract_create_columns(sql_text, table_name)
if not columns:
warnings.append(f"源 SQL 未包含 {table_name},导入器会按当前业务规则处理。")
continue
parsed[table_name] = extract_insert_rows(sql_text, table_name, columns)
if not parsed.get("traditional_case") and not parsed.get("teaching_case"):
warnings.append("源 SQL 未包含 traditional_case 或 teaching_case,病例导入后暂时缺少训练模式扩展数据。")
if not parsed.get("scoring_rule"):
warnings.append("源 SQL 未包含 scoring_rule,评价时将缺少接口侧基础评分规则。")
return parsed, warnings, encoding
def validate_target_schema(parsed: dict[str, list[dict[str, Any]]]) -> None:
"""目标结构校验:确认源字段可映射到当前数据库表。"""
with SessionLocal() as db:
inspector = inspect(db.bind)
for table_name, rows in parsed.items():
if not inspector.has_table(table_name):
raise ImportValidationError(f"当前数据库缺少目标表:{table_name}")
target_columns = {column["name"] for column in inspector.get_columns(table_name)}
for row in rows:
extra_columns = sorted(set(row) - target_columns)
if extra_columns:
raise ImportValidationError(f"{table_name} 存在当前库不支持的字段:{extra_columns}")
def import_source_sql(path: Path, apply: bool = False, generate_exam_items: bool = True) -> ImportReport:
"""安全导入:解析源 SQL,并在显式 apply 时写入当前新表。"""
parsed, warnings, encoding = parse_source_dump(path)
validate_target_schema(parsed)
report = ImportReport(
source_path=str(path),
encoding=encoding,
table_rows={table: len(rows) for table, rows in parsed.items()},
warnings=warnings,
applied=apply,
)
if not apply:
return report
with SessionLocal() as db:
try:
case_rows = parsed.get("case_base", [])
report.upserted_cases = _upsert_cases(db, case_rows)
if "traditional_case" in parsed:
report.upserted_traditional_cases = _sync_traditional_cases(db, case_rows, parsed.get("traditional_case", []))
if "teaching_case" in parsed:
report.upserted_teaching_cases = _sync_teaching_cases(db, case_rows, parsed.get("teaching_case", []))
if "scoring_rule" in parsed:
report.replaced_scoring_rules = _replace_scoring_rules(db, case_rows, parsed.get("scoring_rule", []))
if generate_exam_items:
report.generated_exam_items = _upsert_generated_exam_items(db, case_rows)
db.commit()
except Exception:
db.rollback()
raise
return report
def _upsert_cases(db, rows: list[dict[str, Any]]) -> int:
"""病例导入:按 case_base.id 更新或插入病例主表。"""
count = 0
for row in rows:
row = dict(row)
row["status"] = 1
row["publish_status"] = 1
entity = db.get(CaseBase, row["id"])
if not entity:
db.add(CaseBase(**row))
else:
for key, value in row.items():
setattr(entity, key, value)
count += 1
return count
def _sync_traditional_cases(db, case_rows: list[dict[str, Any]], rows: list[dict[str, Any]]) -> int:
"""传统病例导入:按本次 SQL 同步练习模式扩展表,避免同 case_id 旧数据残留。"""
case_ids = _imported_case_ids(case_rows)
if case_ids:
db.execute(delete(TraditionalCase).where(TraditionalCase.case_id.in_(case_ids)))
count = 0
for row in rows:
db.add(TraditionalCase(**row))
count += 1
return count
def _sync_teaching_cases(db, case_rows: list[dict[str, Any]], rows: list[dict[str, Any]]) -> int:
"""教学互动病例导入:源 SQL 明确提供 teaching_case 时,以源数据为准同步扩展表。"""
case_ids = _imported_case_ids(case_rows)
if case_ids:
db.execute(delete(TeachingCase).where(TeachingCase.case_id.in_(case_ids)))
count = 0
for row in rows:
db.add(TeachingCase(**row))
count += 1
return count
def _replace_scoring_rules(db, case_rows: list[dict[str, Any]], rows: list[dict[str, Any]]) -> int:
"""评分规则导入:按本次导入病例替换 scoring_rule,保持评分规则与源数据一致。"""
case_ids = _imported_case_ids(case_rows)
for case_id in case_ids:
db.execute(delete(ScoringRule).where(ScoringRule.case_id == case_id))
for row in rows:
db.add(ScoringRule(**row))
return len(rows)
def _upsert_generated_exam_items(db, case_rows: list[dict[str, Any]]) -> int:
"""检查项目补齐:按 item_code 更新或补齐固定检查结果,避免删除历史训练引用的检查项。"""
changed = 0
for row in case_rows:
case_id = row["id"]
items = build_exam_items_from_case(row)
for item in items:
entity = db.scalar(
select(CaseExamItem).where(CaseExamItem.case_id == case_id, CaseExamItem.item_code == item["item_code"])
)
if not entity:
db.add(CaseExamItem(case_id=case_id, **item))
else:
for key, value in item.items():
setattr(entity, key, value)
changed += 1
return changed
def _imported_case_ids(case_rows: list[dict[str, Any]]) -> list[int]:
"""导入范围:提取本次 SQL 涉及的病例 ID,用于同步扩展表。"""
return sorted({int(row["id"]) for row in case_rows if row.get("id") is not None})
def build_exam_items_from_case(case_row: dict[str, Any]) -> list[dict[str, Any]]:
"""检查项目生成:从病例文本识别常见检查结果,保障 Demo 可继续申请检查。"""
text = " ".join(str(case_row.get(key) or "") for key in ("chief_complaint", "description", "knowledge_points"))
candidates = [
(
"blood_routine",
"血常规",
"lab",
"实验室检查",
_find_result(text, r"(WBC[^,。;;]*[,]\s*[^,。;;]*(?:中性|neutrophil)[^,。;;]*)") or "血常规结果见病例资料。",
{"source": "case_description"},
"WBC" in text or "血常规" in text,
),
(
"crp",
"CRP",
"lab",
"实验室检查",
_find_result(text, r"(CRP\s*[^,。;;]*)") or "CRP 结果见病例资料。",
{"source": "case_description"},
"CRP" in text.upper(),
),
(
"chest_xray",
"胸片",
"imaging",
"影像检查",
_find_result(text, r"(胸片[^。;;]*)") or "胸片结果见病例资料。",
{"source": "case_description"},
"胸片" in text or "X线" in text,
),
(
"spo2",
"血氧饱和度",
"vital_sign",
"生命体征",
_find_result(text, r"(SpO2\s*\d+%?)") or "血氧饱和度结果见病例资料。",
{"source": "case_description"},
"SpO2" in text or "血氧" in text,
),
(
"chest_auscultation",
"肺部体格检查",
"physical_exam",
"体格检查",
_find_result(text, r"(肺[^。;;]*(?:湿啰音|哮鸣音|呼吸音)[^。;;]*)") or "肺部体格检查结果见病例资料。",
{"source": "case_description"},
"湿啰音" in text or "哮鸣音" in text or "" in text,
),
(
"mp_igm",
"肺炎支原体抗体IgM",
"lab",
"实验室检查",
_find_result(text, r"(肺炎支原体抗体IgM[^,。;;]*)") or "肺炎支原体抗体IgM结果见病例资料。",
{"source": "case_description"},
"肺炎支原体抗体IgM" in text or "支原体" in text,
),
]
items: list[dict[str, Any]] = []
for index, (code, name, item_type, category, result_text, structured, detected) in enumerate(candidates, start=1):
if detected:
items.append(
{
"item_code": code,
"item_name": name,
"item_type": item_type,
"category": category,
"result_text": result_text,
"result_structured": structured,
"is_key": True,
"is_abnormal": True,
"score_weight": Decimal("5.00"),
"display_order": index,
}
)
if items:
return items
return [
{
"item_code": "basic_exam",
"item_name": "基础检查",
"item_type": "physical_exam",
"category": "体格检查",
"result_text": "基础检查结果见病例资料。",
"result_structured": {"source": "case_description"},
"is_key": True,
"is_abnormal": False,
"score_weight": Decimal("1.00"),
"display_order": 1,
}
]
def _find_result(text: str, pattern: str) -> str | None:
"""文本抽取:从病例描述中抽取简短检查结果。"""
match = re.search(pattern, text, flags=re.I)
return match.group(1).strip() if match else None
def main() -> None:
"""命令入口:执行接口 SQL 的安全检查或导入。"""
parser = argparse.ArgumentParser(description="Safely import parsed case SQL into current medical agent schema.")
parser.add_argument("sql_path", type=Path, help="接口提供的 SQL dump 文件路径")
parser.add_argument("--apply", action="store_true", help="确认写入当前数据库;默认只检查不写入")
parser.add_argument("--no-generate-exam-items", action="store_true", help="不自动补齐 case_exam_item")
args = parser.parse_args()
try:
report = import_source_sql(
args.sql_path,
apply=args.apply,
generate_exam_items=not args.no_generate_exam_items,
)
except ImportValidationError as exc:
print(json.dumps({"ok": False, "error": str(exc)}, ensure_ascii=False, indent=2))
raise SystemExit(2) from exc
print(json.dumps({"ok": True, "report": report.as_dict()}, ensure_ascii=False, indent=2, default=str))
if __name__ == "__main__":
main()
+1 -1
View File
@@ -32,7 +32,7 @@ def init_database() -> None:
def seed_demo_data(db) -> None:
"""病例导入:写入儿科支气管肺炎病例、检查项目、评分规则和提示词元数据。"""
"""Demo 数据初始化:仅为本地/测试库写入儿科病例和基础训练数据。"""
department = _get_or_create_department(db)
case = _get_or_create_case_base(db, department.id)
_seed_traditional_case(db, case.id)
+8 -119
View File
@@ -1,15 +1,16 @@
import os
import sys
from decimal import Decimal
import tempfile
from pathlib import Path
os.environ.setdefault("DATABASE_URL", "sqlite:///./storage/test_api_contract.db")
TEST_DB_PATH = Path(tempfile.gettempdir()) / "medical_agent_test_api_contract.db"
TEST_DB_PATH.unlink(missing_ok=True)
os.environ["DATABASE_URL"] = f"sqlite:///{TEST_DB_PATH.as_posix()}"
os.environ.setdefault("RUNTIME_MEMORY_BACKEND", "memory")
os.environ.setdefault("LLM_MOCK_ENABLED", "true")
os.environ.setdefault("AUTH_USER_ME_URL", "http://django-user-center.test/api/user/users/me/")
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
Path("storage").mkdir(exist_ok=True)
try:
from fastapi.testclient import TestClient
@@ -25,9 +26,6 @@ def run_api_contract_tests() -> None:
from app.main import app
from app.services.external_auth_service import AuthenticatedUser, ExternalAuthService
from app.db.session import SessionLocal
from app.models.source_case import CaseBase, CaseExamItem, ScoringRule, TraditionalCase
from app.repositories.case_repository import CaseRepository
from scripts.init_demo_db import init_database
async def fake_authenticate(self, request): # noqa: ARG001
@@ -87,6 +85,10 @@ def run_api_contract_tests() -> None:
auth_me_operation = openapi_payload["paths"]["/api/v1/auth/me"]["get"]
assert any("HTTPBearer" in item for item in auth_me_operation.get("security", []))
assert "HTTPBearer" in openapi_payload["components"]["securitySchemes"]
assert "/api/v1/imports/case-sql/preview" not in openapi_payload["paths"]
assert "/api/v1/imports/case-sql/apply" not in openapi_payload["paths"]
assert "/api/v1/cases/{case_id}/delete-preview" not in openapi_payload["paths"]
assert "delete" not in openapi_payload["paths"]["/api/v1/cases/{case_id}"]
cases = client.get("/api/v1/cases", headers=headers)
assert cases.status_code == 200
@@ -158,119 +160,6 @@ def run_api_contract_tests() -> None:
assert llm_reason.json()["code"] == "OK"
assert "total_latency_ms" in llm_reason.json()["data"]
import_preview = client.post(
"/api/v1/imports/case-sql/preview",
headers=headers,
files={"file": ("bad.sql", b"not sql", "application/sql")},
)
assert import_preview.status_code == 200
assert import_preview.json()["code"] == "OK"
assert import_preview.json()["data"]["can_import"] is False
assert import_preview.json()["data"]["errors"]
import_apply = client.post(
"/api/v1/imports/case-sql/apply",
headers=headers,
files={"file": ("bad.sql", b"not sql", "application/sql")},
)
assert import_apply.status_code == 400
assert import_apply.json()["code"] == "CASE_SQL_IMPORT_INVALID"
temp_case_id = 880001
with SessionLocal() as db:
CaseRepository(db).delete_case_cascade(temp_case_id)
db.add(
CaseBase(
id=temp_case_id,
title="删除测试病例",
case_type="diagnosis_treatment",
difficulty="medium",
chief_complaint="发热、咳嗽",
description="用于删除接口测试的临时病例",
patient_age=6,
patient_gender="male",
tags="test",
symptom_tags=[],
disease_tags=[],
competency_tags=[],
guideline_tags=[],
knowledge_points=[],
icd_codes="",
osce_enabled=False,
rag_enabled=False,
ai_prompt_template="",
multimodal_assets=[],
vector_status=0,
publish_status=1,
status=1,
department_id=1,
)
)
db.add(
TraditionalCase(
id=temp_case_id,
case_id=temp_case_id,
standard_diagnosis="测试诊断",
standard_treatment="测试治疗",
guideline_reference="测试指南",
)
)
db.add(
ScoringRule(
id=temp_case_id,
case_id=temp_case_id,
dimension="信息采集",
competency_dimension="问诊完整性",
score_weight=Decimal("10.00"),
ai_auto_score=True,
osce_dimension=False,
scoring_standard="测试评分标准",
rubric_json={},
)
)
db.add(
CaseExamItem(
id=temp_case_id,
case_id=temp_case_id,
item_code="temp_exam",
item_name="测试检查",
item_type="lab",
result_text="测试结果",
result_structured={},
is_key=True,
is_abnormal=False,
score_weight=Decimal("1.00"),
display_order=1,
)
)
db.commit()
delete_preview = client.get(f"/api/v1/cases/{temp_case_id}/delete-preview", headers=headers)
assert delete_preview.status_code == 200
assert delete_preview.json()["data"]["affected"]["case_base"] == 1
assert delete_preview.json()["data"]["affected"]["case_exam_item"] == 1
delete_without_confirm = client.request(
"DELETE",
f"/api/v1/cases/{temp_case_id}",
headers=headers,
json={"confirm": False, "delete_training_data": True},
)
assert delete_without_confirm.status_code == 400
assert delete_without_confirm.json()["code"] == "CASE_DELETE_CONFIRM_REQUIRED"
delete_case = client.request(
"DELETE",
f"/api/v1/cases/{temp_case_id}",
headers=headers,
json={"confirm": True, "delete_training_data": True},
)
assert delete_case.status_code == 200
assert delete_case.json()["data"]["deleted_counts"]["case_base"] == 1
deleted_detail = client.get(f"/api/v1/cases/{temp_case_id}", headers=headers)
assert deleted_detail.status_code == 404
assert deleted_detail.json()["code"] == "CASE_NOT_FOUND"
if __name__ == "__main__":
+5 -3
View File
@@ -1,15 +1,17 @@
import asyncio
import os
import sys
import tempfile
from pathlib import Path
os.environ.setdefault("DATABASE_URL", "sqlite:///./storage/test_demo_flow.db")
TEST_DB_PATH = Path(tempfile.gettempdir()) / "medical_agent_test_demo_flow.db"
TEST_DB_PATH.unlink(missing_ok=True)
os.environ["DATABASE_URL"] = f"sqlite:///{TEST_DB_PATH.as_posix()}"
os.environ["REPORT_STORAGE_DIR"] = str(Path(tempfile.gettempdir()) / "medical_agent_test_reports")
os.environ.setdefault("RUNTIME_MEMORY_BACKEND", "memory")
os.environ.setdefault("LLM_MOCK_ENABLED", "true")
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
if os.getenv("DATABASE_URL") == "sqlite:///./storage/test_demo_flow.db":
Path("storage/test_demo_flow.db").unlink(missing_ok=True)
from sqlalchemy import select
-75
View File
@@ -1,75 +0,0 @@
import tempfile
from pathlib import Path
import os
import sys
os.environ.setdefault("DATABASE_URL", "sqlite:///./storage/test_import.db")
os.environ.setdefault("RUNTIME_MEMORY_BACKEND", "memory")
os.environ.setdefault("LLM_MOCK_ENABLED", "true")
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
from scripts.import_source_case_sql import ImportValidationError, extract_create_columns, extract_insert_rows, parse_values_clause
def test_extract_insert_rows_maps_by_create_columns() -> None:
"""导入解析:验证 INSERT VALUES 按 CREATE TABLE 字段顺序映射。"""
sql = """
CREATE TABLE `traditional_case` (
`created_at` datetime(6) NOT NULL,
`updated_at` datetime(6) NOT NULL,
`id` bigint NOT NULL AUTO_INCREMENT,
`standard_diagnosis` longtext NOT NULL,
`standard_treatment` longtext NOT NULL,
`guideline_reference` longtext NOT NULL,
`case_id` bigint NOT NULL
) ENGINE=InnoDB;
INSERT INTO `traditional_case` VALUES ('2026-05-28 09:34:08','2026-05-28 09:34:09',1,'支气管肺炎','抗感染','指南',1);
"""
columns = extract_create_columns(sql, "traditional_case")
rows = extract_insert_rows(sql, "traditional_case", columns)
assert rows[0]["standard_diagnosis"] == "支气管肺炎"
assert rows[0]["case_id"] == 1
def test_extract_insert_rows_rejects_broken_sql() -> None:
"""导入解析:验证损坏 SQL 被拒绝,避免半导入。"""
sql = """
CREATE TABLE `traditional_case` (
`created_at` datetime(6) NOT NULL,
`updated_at` datetime(6) NOT NULL,
`id` bigint NOT NULL AUTO_INCREMENT,
`standard_diagnosis` longtext NOT NULL,
`standard_treatment` longtext NOT NULL,
`guideline_reference` longtext NOT NULL,
`case_id` bigint NOT NULL
) ENGINE=InnoDB;
INSERT INTO `traditional_case` VALUES ('2026-05-28','2026-05-28',1,'支气管肺炎'bad,'抗感染','指南',1);
"""
columns = extract_create_columns(sql, "traditional_case")
try:
extract_insert_rows(sql, "traditional_case", columns)
except ImportValidationError:
return
raise AssertionError("broken SQL should be rejected")
def test_parse_values_clause_recovers_misplaced_quote_separator() -> None:
"""导入解析:兼容接口 SQL 中文本字段写成 `文本,'next'` 的引号/逗号错位。"""
rows = parse_values_clause("(1,'标题?,'traditional','medium',NULL,'主诉?,'描述?,6,'male')")
assert rows[0] == [1, "标题?", "traditional", "medium", None, "主诉?", "描述?", 6, "male"]
def test_parse_values_clause_recovers_unclosed_string_before_tuple_end() -> None:
"""导入解析:兼容接口 SQL 最后文本字段缺少结束引号且后接外键的情况。"""
rows = parse_values_clause("(1,'诊断?,'治疗?,'指南?,1)")
assert rows[0] == [1, "诊断?", "治疗?", "指南?", 1]
if __name__ == "__main__":
test_extract_insert_rows_maps_by_create_columns()
test_extract_insert_rows_rejects_broken_sql()
test_parse_values_clause_recovers_misplaced_quote_separator()
test_parse_values_clause_recovers_unclosed_string_before_tuple_end()
print("import source case sql tests passed")