diff --git a/README.md b/README.md index eca5ce8..4932358 100644 --- a/README.md +++ b/README.md @@ -2,6 +2,8 @@ 医疗问诊 Agent 是医疗教学平台中的问诊训练服务。后端负责 Django 用户身份校验、病例读取、多轮问诊、检查申请、诊断治疗提交、AI 评价、评分明细、PDF 报告和历史训练记录。 +病例库在本服务中为只读数据源。病例新增、解析、修改和删除由外部病例管理系统负责;本服务只读取已发布病例及其训练扩展、检查项和评分规则。 + ## 项目结构 仓库根目录可以直接部署为服务器的 `fastapi/` 目录: @@ -127,14 +129,14 @@ docker compose exec fastapi python scripts/check_final_schema.py docker compose exec fastapi python scripts/check_final_demo_readiness.py ``` -仅在结构检查确认缺少 Agent 所需表时,备份数据库后执行: +以下迁移脚本只用于独立本地开发库或旧环境升级,不得用于共享生产病例库: ```bash docker compose exec fastapi python scripts/migrate_to_new_schema.py docker compose exec fastapi python scripts/migrate_user_department_score_detail.py ``` -迁移脚本使用 `create_all` 补齐 Agent 所需表,不删除 Django 或现有业务表;`migrate_to_new_schema.py` 会在缺少对应数据时写入 Demo 病例和基础数据。 +迁移脚本使用 `create_all` 补齐 Agent 所需表,不删除 Django 或现有业务表;`migrate_to_new_schema.py` 会写入 Demo 病例和基础数据。共享环境中的病例数据由外部病例管理系统维护。 ## 部署验证 @@ -170,5 +172,4 @@ python -m compileall app scripts tests python tests\test_core_logic.py python tests\test_api_contract.py python tests\test_demo_flow.py -python tests\test_import_source_case_sql.py ``` diff --git a/app/api/cases.py b/app/api/cases.py index c46cd97..58d638c 100644 --- a/app/api/cases.py +++ b/app/api/cases.py @@ -4,13 +4,7 @@ from sqlalchemy.orm import Session from app.core.response import ApiResponse, ok from app.core.user_context import UserContext, get_user_context from app.db.session import get_db -from app.schemas.case import ( - CaseDeletePreviewResponse, - CaseDeleteRequest, - CaseDeleteResponse, - CaseDetailResponse, - CaseListResponse, -) +from app.schemas.case import CaseDetailResponse, CaseListResponse from app.services.case_service import CaseService router = APIRouter() @@ -36,24 +30,3 @@ def get_case_detail( ): """病例详情:返回训练入口信息和可申请检查类型。""" return ok(CaseService(db).get_case_detail(case_id)) - - -@router.get("/{case_id}/delete-preview", response_model=ApiResponse[CaseDeletePreviewResponse]) -def get_case_delete_preview( - case_id: int, - _: UserContext = Depends(get_user_context), - db: Session = Depends(get_db), -): - """病例删除预览:返回删除该病例会影响的训练与病例数据数量。""" - return ok(CaseService(db).get_delete_preview(case_id)) - - -@router.delete("/{case_id}", response_model=ApiResponse[CaseDeleteResponse]) -def delete_case( - case_id: int, - payload: CaseDeleteRequest, - ctx: UserContext = Depends(get_user_context), - db: Session = Depends(get_db), -): - """病例删除:确认后级联删除病例、扩展表、评分规则、检查项和关联训练数据。""" - return ok(CaseService(db).delete_case(case_id, payload, ctx)) diff --git a/app/api/imports.py b/app/api/imports.py deleted file mode 100644 index c5c7eae..0000000 --- a/app/api/imports.py +++ /dev/null @@ -1,26 +0,0 @@ -from fastapi import APIRouter, Depends, File, UploadFile - -from app.core.response import ApiResponse, ok -from app.core.user_context import UserContext, get_user_context -from app.schemas.imports import CaseSqlImportApplyResponse, CaseSqlImportPreviewResponse -from app.services.case_sql_import_service import CaseSqlImportService - -router = APIRouter() - - -@router.post("/case-sql/preview", response_model=ApiResponse[CaseSqlImportPreviewResponse]) -async def preview_case_sql( - file: UploadFile = File(...), - _: UserContext = Depends(get_user_context), -): - """病例 SQL 预检:上传接口 SQL 文件,解析可导入病例数据但不写入数据库。""" - return ok(await CaseSqlImportService().preview(file)) - - -@router.post("/case-sql/apply", response_model=ApiResponse[CaseSqlImportApplyResponse]) -async def apply_case_sql( - file: UploadFile = File(...), - _: UserContext = Depends(get_user_context), -): - """病例 SQL 导入:确认后把 SQL 中的病例表数据映射写入当前本地数据库。""" - return ok(await CaseSqlImportService().apply(file)) diff --git a/app/api/router.py b/app/api/router.py index 9059984..0b95689 100644 --- a/app/api/router.py +++ b/app/api/router.py @@ -1,6 +1,6 @@ from fastapi import APIRouter -from app.api import agent, auth, cases, evaluations, imports, knowledge, llm_test, sessions +from app.api import agent, auth, cases, evaluations, knowledge, llm_test, sessions api_router = APIRouter() api_router.include_router(agent.router, tags=["agent"]) @@ -10,4 +10,3 @@ api_router.include_router(sessions.router, prefix="/sessions", tags=["sessions"] api_router.include_router(evaluations.router, prefix="/evaluations", tags=["evaluations"]) api_router.include_router(knowledge.router, prefix="/knowledge", tags=["knowledge"]) api_router.include_router(llm_test.router, prefix="/llm/test", tags=["llm-test"]) -api_router.include_router(imports.router, prefix="/imports", tags=["imports"]) diff --git a/app/repositories/case_repository.py b/app/repositories/case_repository.py index 5e9be83..482012a 100644 --- a/app/repositories/case_repository.py +++ b/app/repositories/case_repository.py @@ -1,13 +1,11 @@ -from sqlalchemy import delete, exists, func, or_, select +from sqlalchemy import exists, select from sqlalchemy.orm import Session, selectinload -from app.models.source_case import CaseBase, CaseExamItem, ScoringRule, TeachingCase, TraditionalCase -from app.models.training import SessionOrder, SessionSubmission, TrainingSession -from app.models.training_record import TrainingRecord, TrainingScoreDetail +from app.models.source_case import CaseBase, CaseExamItem, TeachingCase, TraditionalCase class CaseRepository: - """病例仓储:基于 case_base 新表体系读取病例、扩展表和检查项目。""" + """病例只读仓储:读取已发布病例、模式扩展数据和固定检查项目。""" def __init__(self, db: Session) -> None: self.db = db @@ -18,7 +16,8 @@ class CaseRepository: training_type: str | None = None, mode: str | None = None, ) -> list[CaseBase]: - """病例列表:从 case_base 读取已发布病例,并按模式匹配传统/教学扩展表。""" + """病例列表:从 case_base 读取已发布病例,并按模式匹配扩展表。""" + normalized_mode = "practice" if mode == "novice" else mode stmt = ( select(CaseBase) .options(selectinload(CaseBase.traditional_case), selectinload(CaseBase.teaching_case)) @@ -28,14 +27,14 @@ class CaseRepository: stmt = stmt.where(CaseBase.department_id == department_id) if training_type: stmt = stmt.where(CaseBase.case_type == training_type) - if mode == "practice": + if normalized_mode == "practice": stmt = stmt.where(exists().where(TraditionalCase.case_id == CaseBase.id)) - if mode == "teaching": + if normalized_mode == "teaching": stmt = stmt.where(exists().where(TeachingCase.case_id == CaseBase.id)) return list(self.db.scalars(stmt.order_by(CaseBase.id.desc())).all()) def get_active_case(self, case_id: int) -> CaseBase | None: - """病例详情:读取 case_base,并加载传统病例、教学病例、评分规则和检查项目。""" + """病例详情:读取病例主表及训练所需的扩展表、评分规则和检查项目。""" stmt = ( select(CaseBase) .options( @@ -48,134 +47,12 @@ class CaseRepository: ) return self.db.scalar(stmt) - def get_case_by_id(self, case_id: int) -> CaseBase | None: - """病例删除:按主键读取病例,不限制发布状态,用于删除前校验。""" - return self.db.get(CaseBase, case_id) - - def get_delete_preview_counts(self, case_id: int) -> dict[str, int]: - """病例删除预览:统计删除病例会影响的业务数据数量。""" - session_ids = self._session_ids(case_id) - return { - "case_base": self._count(CaseBase, CaseBase.id == case_id), - "traditional_case": self._count(TraditionalCase, TraditionalCase.case_id == case_id), - "teaching_case": self._count(TeachingCase, TeachingCase.case_id == case_id), - "scoring_rule": self._count(ScoringRule, ScoringRule.case_id == case_id), - "case_exam_item": self._count(CaseExamItem, CaseExamItem.case_id == case_id), - "training_session": len(session_ids), - "training_order": self._count_training_orders(case_id, session_ids), - "training_submission": self._count_by_sessions(SessionSubmission, SessionSubmission.session_id, session_ids), - "training_score_detail": self._count_score_details(case_id, session_ids), - "training_record": self._count_training_records(case_id, session_ids), - } - - def delete_case_cascade(self, case_id: int) -> dict[str, int]: - """病例删除执行:按外键依赖顺序清理训练数据、检查项、评分规则和病例主表。""" - session_ids = self._session_ids(case_id) - deleted: dict[str, int] = {} - deleted["training_order"] = self._delete_training_orders(case_id, session_ids) - deleted["training_submission"] = self._delete_by_sessions( - SessionSubmission, SessionSubmission.session_id, session_ids - ) - deleted["training_score_detail"] = self._delete_score_details(case_id, session_ids) - deleted["training_record"] = self._delete_training_records(case_id, session_ids) - deleted["training_session"] = self._delete_where(TrainingSession, TrainingSession.case_id == case_id) - deleted["case_exam_item"] = self._delete_where(CaseExamItem, CaseExamItem.case_id == case_id) - deleted["scoring_rule"] = self._delete_where(ScoringRule, ScoringRule.case_id == case_id) - deleted["traditional_case"] = self._delete_where(TraditionalCase, TraditionalCase.case_id == case_id) - deleted["teaching_case"] = self._delete_where(TeachingCase, TeachingCase.case_id == case_id) - deleted["case_base"] = self._delete_where(CaseBase, CaseBase.id == case_id) - return deleted - - def _session_ids(self, case_id: int) -> list[int]: - """病例删除:读取该病例关联的训练会话 ID 集合。""" - stmt = select(TrainingSession.id).where(TrainingSession.case_id == case_id) - return [int(item) for item in self.db.scalars(stmt).all()] - - def _count(self, model: type, *criteria) -> int: - """病例删除预览:按条件统计单表记录数。""" - stmt = select(func.count()).select_from(model).where(*criteria) - return int(self.db.scalar(stmt) or 0) - - def _count_by_sessions(self, model: type, session_column, session_ids: list[int]) -> int: - """病例删除预览:按训练会话集合统计从表记录数。""" - if not session_ids: - return 0 - return self._count(model, session_column.in_(session_ids)) - - def _count_training_orders(self, case_id: int, session_ids: list[int]) -> int: - """病例删除预览:统计检查申请记录,兼容按病例和按会话两种关联。""" - if session_ids: - return self._count(SessionOrder, or_(SessionOrder.case_id == case_id, SessionOrder.session_id.in_(session_ids))) - return self._count(SessionOrder, SessionOrder.case_id == case_id) - - def _count_training_records(self, case_id: int, session_ids: list[int]) -> int: - """病例删除预览:统计完整训练记录,兼容按病例和按会话两种关联。""" - if session_ids: - return self._count( - TrainingRecord, - or_(TrainingRecord.case_id == case_id, TrainingRecord.session_id.in_(session_ids)), - ) - return self._count(TrainingRecord, TrainingRecord.case_id == case_id) - - def _count_score_details(self, case_id: int, session_ids: list[int]) -> int: - """病例删除预览:统计该病例评价记录下的评分明细。""" - record_ids = self._record_ids(case_id, session_ids) - if not record_ids: - return 0 - return self._count(TrainingScoreDetail, TrainingScoreDetail.record_id.in_(record_ids)) - - def _delete_where(self, model: type, *criteria) -> int: - """病例删除执行:按条件删除单表记录并返回影响行数。""" - result = self.db.execute(delete(model).where(*criteria)) - return int(result.rowcount or 0) - - def _delete_by_sessions(self, model: type, session_column, session_ids: list[int]) -> int: - """病例删除执行:按训练会话集合删除从表记录。""" - if not session_ids: - return 0 - return self._delete_where(model, session_column.in_(session_ids)) - - def _delete_training_orders(self, case_id: int, session_ids: list[int]) -> int: - """病例删除执行:删除该病例下所有检查申请记录,避免阻塞检查项删除。""" - if session_ids: - return self._delete_where( - SessionOrder, - or_(SessionOrder.case_id == case_id, SessionOrder.session_id.in_(session_ids)), - ) - return self._delete_where(SessionOrder, SessionOrder.case_id == case_id) - - def _delete_training_records(self, case_id: int, session_ids: list[int]) -> int: - """病例删除执行:删除该病例完整训练后沉淀的评价记录。""" - if session_ids: - return self._delete_where( - TrainingRecord, - or_(TrainingRecord.case_id == case_id, TrainingRecord.session_id.in_(session_ids)), - ) - return self._delete_where(TrainingRecord, TrainingRecord.case_id == case_id) - - def _delete_score_details(self, case_id: int, session_ids: list[int]) -> int: - """病例删除执行:先删除评价明细,避免阻塞训练记录删除。""" - record_ids = self._record_ids(case_id, session_ids) - if not record_ids: - return 0 - return self._delete_where(TrainingScoreDetail, TrainingScoreDetail.record_id.in_(record_ids)) - - def _record_ids(self, case_id: int, session_ids: list[int]) -> list[int]: - """病例删除:读取该病例关联的训练记录 ID 集合。""" - if session_ids: - stmt = select(TrainingRecord.id).where( - or_(TrainingRecord.case_id == case_id, TrainingRecord.session_id.in_(session_ids)) - ) - else: - stmt = select(TrainingRecord.id).where(TrainingRecord.case_id == case_id) - return [int(item) for item in self.db.scalars(stmt).all()] - def get_exam_items(self, case_id: int) -> list[CaseExamItem]: """检查项目:读取当前病例下全部可申请检查检验项目。""" stmt = select(CaseExamItem).where(CaseExamItem.case_id == case_id).order_by(CaseExamItem.display_order) return list(self.db.scalars(stmt).all()) def get_exam_item(self, case_id: int, item_code: str) -> CaseExamItem | None: - """检查结果:按病例和项目编码读取固定检查检验结果。""" + """检查结果:按病例和项目编码读取数据库中的固定结果。""" stmt = select(CaseExamItem).where(CaseExamItem.case_id == case_id, CaseExamItem.item_code == item_code) return self.db.scalar(stmt) diff --git a/app/repositories/source_case_repository.py b/app/repositories/source_case_repository.py index 428fa3a..dbb723c 100644 --- a/app/repositories/source_case_repository.py +++ b/app/repositories/source_case_repository.py @@ -1,56 +1,18 @@ -from __future__ import annotations - -from sqlalchemy import exists, select -from sqlalchemy.orm import Session, selectinload +from sqlalchemy import select +from sqlalchemy.orm import Session from app.models.department import Department -from app.models.source_case import CaseBase, ScoringRule, TeachingCase, TraditionalCase +from app.models.source_case import ScoringRule class SourceCaseRepository: - """源库病例仓储:读取 case_base、traditional_case、teaching_case 和 scoring_rule。""" + """病例辅助只读仓储:读取科室名称和病例评分规则。""" def __init__(self, db: Session) -> None: self.db = db - def list_active_cases( - self, - department_id: int | None = None, - case_type: str | None = None, - mode: str | None = None, - ) -> list[CaseBase]: - """源库病例列表:按科室、病例分类和训练模式读取已发布病例。""" - stmt = ( - select(CaseBase) - .options(selectinload(CaseBase.traditional_case), selectinload(CaseBase.teaching_case)) - .where(CaseBase.status == 1, CaseBase.publish_status == 1) - ) - if department_id: - stmt = stmt.where(CaseBase.department_id == department_id) - if case_type: - stmt = stmt.where(CaseBase.case_type == case_type) - normalized_mode = self.normalize_mode(mode) - if normalized_mode == "practice": - stmt = stmt.where(exists().where(TraditionalCase.case_id == CaseBase.id)) - if normalized_mode == "teaching": - stmt = stmt.where(exists().where(TeachingCase.case_id == CaseBase.id)) - return list(self.db.scalars(stmt.order_by(CaseBase.id.desc())).all()) - - def get_active_case_base(self, case_id: int) -> CaseBase | None: - """源库病例详情:读取病例主表及传统/教学扩展表。""" - stmt = ( - select(CaseBase) - .options( - selectinload(CaseBase.traditional_case), - selectinload(CaseBase.teaching_case), - selectinload(CaseBase.scoring_rules), - ) - .where(CaseBase.id == case_id, CaseBase.status == 1, CaseBase.publish_status == 1) - ) - return self.db.scalar(stmt) - def get_department_name(self, department_id: int | None) -> str: - """科室名称:按用户端 department 表读取科室名称。""" + """科室名称:按 department 表读取当前病例所属科室名称。""" if not department_id: return "" department = self.db.scalar(select(Department).where(Department.id == department_id)) @@ -60,10 +22,3 @@ class SourceCaseRepository: """评分规则:读取当前病例对应的基础评分细则。""" stmt = select(ScoringRule).where(ScoringRule.case_id == case_id).order_by(ScoringRule.id) return list(self.db.scalars(stmt).all()) - - @staticmethod - def normalize_mode(mode: str | None) -> str | None: - """模式归一:旧 novice 请求按练习模式处理,第一版只暴露 practice/teaching。""" - if mode == "novice": - return "practice" - return mode diff --git a/app/schemas/case.py b/app/schemas/case.py index 40e1780..2f79eef 100644 --- a/app/schemas/case.py +++ b/app/schemas/case.py @@ -50,27 +50,3 @@ class CaseDetailResponse(BaseModel): has_knowledge_points: bool has_quiz: bool order_item_types: list[str] - - -class CaseDeletePreviewResponse(BaseModel): - """病例删除预览:返回删除该病例会影响的业务数据数量。""" - - case_id: int - case_title: str - can_delete: bool - affected: dict[str, int] - - -class CaseDeleteRequest(BaseModel): - """病例删除请求:前端必须显式确认,并默认同时删除该病例训练数据。""" - - confirm: bool = False - delete_training_data: bool = True - - -class CaseDeleteResponse(BaseModel): - """病例删除结果:返回已删除的各表记录数量。""" - - deleted: bool - case_id: int - deleted_counts: dict[str, int] diff --git a/app/schemas/imports.py b/app/schemas/imports.py deleted file mode 100644 index ac39490..0000000 --- a/app/schemas/imports.py +++ /dev/null @@ -1,36 +0,0 @@ -from pydantic import BaseModel, Field - - -class CaseSqlPreviewCase(BaseModel): - """病例 SQL 预览病例:展示导入文件中识别到的病例摘要。""" - - id: int - title: str - case_type: str - difficulty: str - - -class CaseSqlImportPreviewResponse(BaseModel): - """病例 SQL 预检响应:只展示解析结果,不写入数据库。""" - - file_name: str - encoding: str | None = None - tables: dict[str, int] = Field(default_factory=dict) - can_import: bool = False - warnings: list[str] = Field(default_factory=list) - errors: list[str] = Field(default_factory=list) - preview_cases: list[CaseSqlPreviewCase] = Field(default_factory=list) - - -class CaseSqlImportApplyResponse(BaseModel): - """病例 SQL 导入响应:展示实际写库结果。""" - - imported: bool - file_name: str - encoding: str - inserted_or_updated_cases: int - imported_traditional_cases: int - imported_teaching_cases: int - imported_scoring_rules: int - generated_exam_items: int - warnings: list[str] = Field(default_factory=list) diff --git a/app/services/case_service.py b/app/services/case_service.py index f330d49..2d0ef64 100644 --- a/app/services/case_service.py +++ b/app/services/case_service.py @@ -1,20 +1,10 @@ from sqlalchemy.orm import Session -from app.core.context import UserContext from app.core.exceptions import AppError from app.models.source_case import CaseBase from app.repositories.case_repository import CaseRepository from app.repositories.source_case_repository import SourceCaseRepository -from app.schemas.case import ( - CaseDeletePreviewResponse, - CaseDeleteRequest, - CaseDeleteResponse, - CaseDetailResponse, - CaseListItem, - CaseListResponse, - CasePatientInfo, -) -from app.services.audit_service import AuditService +from app.schemas.case import CaseDetailResponse, CaseListItem, CaseListResponse, CasePatientInfo class CaseService: @@ -62,51 +52,6 @@ class CaseService: order_item_types=sorted({item.item_type for item in order_items}), ) - def get_delete_preview(self, case_id: int) -> CaseDeletePreviewResponse: - """病例删除预览:返回删除病例前端需要展示的影响范围。""" - case = self.repo.get_case_by_id(case_id) - if not case: - raise AppError("CASE_NOT_FOUND", "case not found", 404) - return CaseDeletePreviewResponse( - case_id=case.id, - case_title=case.title, - can_delete=True, - affected=self.repo.get_delete_preview_counts(case.id), - ) - - def delete_case(self, case_id: int, payload: CaseDeleteRequest, ctx: UserContext) -> CaseDeleteResponse: - """病例删除:级联删除病例业务数据,并保留审计日志用于追踪操作。""" - case = self.repo.get_case_by_id(case_id) - if not case: - raise AppError("CASE_NOT_FOUND", "case not found", 404) - if not payload.confirm: - raise AppError("CASE_DELETE_CONFIRM_REQUIRED", "case delete confirmation is required", 400) - - preview = self.repo.get_delete_preview_counts(case_id) - training_rows = ( - preview.get("training_session", 0) - + preview.get("training_order", 0) - + preview.get("training_submission", 0) - + preview.get("training_record", 0) - ) - if training_rows and not payload.delete_training_data: - raise AppError("CASE_DELETE_TRAINING_DATA_EXISTS", "case has training data; delete_training_data must be true", 400) - - try: - deleted_counts = self.repo.delete_case_cascade(case_id) - AuditService(self.db).log( - ctx, - action="case.delete", - resource_type="case", - resource_id=str(case_id), - metadata={"case_title": case.title, "deleted_counts": deleted_counts}, - ) - self.db.commit() - except Exception: - self.db.rollback() - raise - return CaseDeleteResponse(deleted=True, case_id=case_id, deleted_counts=deleted_counts) - def _to_list_item(self, case: CaseBase) -> CaseListItem: """病例卡片转换:把 case_base 映射为当前前端病例列表结构。""" return CaseListItem( diff --git a/app/services/case_sql_import_service.py b/app/services/case_sql_import_service.py deleted file mode 100644 index af22bb0..0000000 --- a/app/services/case_sql_import_service.py +++ /dev/null @@ -1,101 +0,0 @@ -from __future__ import annotations - -import tempfile -from pathlib import Path - -from fastapi import UploadFile - -from app.core.exceptions import AppError -from app.schemas.imports import CaseSqlImportApplyResponse, CaseSqlImportPreviewResponse, CaseSqlPreviewCase -from scripts.import_source_case_sql import ImportValidationError, import_source_sql, parse_source_dump - - -MAX_SQL_UPLOAD_BYTES = 5 * 1024 * 1024 -ALLOWED_TABLES = {"case_base", "traditional_case", "teaching_case", "scoring_rule"} - - -class CaseSqlImportService: - """病例 SQL 导入服务:解析接口 SQL 文件并安全映射到当前病例表。""" - - async def preview(self, file: UploadFile) -> CaseSqlImportPreviewResponse: - """导入预检:上传 SQL 后只解析结构和数据,不写入数据库。""" - temp_path = await self._save_upload_to_temp(file) - try: - parsed, warnings, encoding = parse_source_dump(temp_path) - table_rows = self._allowed_table_counts(parsed) - preview_cases = [ - CaseSqlPreviewCase( - id=int(row["id"]), - title=str(row.get("title") or ""), - case_type=str(row.get("case_type") or ""), - difficulty=str(row.get("difficulty") or ""), - ) - for row in parsed.get("case_base", []) - ] - return CaseSqlImportPreviewResponse( - file_name=file.filename or "case.sql", - encoding=encoding, - tables=table_rows, - can_import=True, - warnings=warnings, - errors=[], - preview_cases=preview_cases, - ) - except ImportValidationError as exc: - return CaseSqlImportPreviewResponse( - file_name=file.filename or "case.sql", - can_import=False, - errors=[str(exc)], - ) - finally: - self._remove_temp_file(temp_path) - - async def apply(self, file: UploadFile) -> CaseSqlImportApplyResponse: - """确认导入:解析通过后以事务方式写入 case_base/traditional_case/teaching_case/scoring_rule。""" - temp_path = await self._save_upload_to_temp(file) - try: - report = import_source_sql(temp_path, apply=True) - return CaseSqlImportApplyResponse( - imported=True, - file_name=file.filename or "case.sql", - encoding=report.encoding, - inserted_or_updated_cases=report.upserted_cases, - imported_traditional_cases=report.upserted_traditional_cases, - imported_teaching_cases=report.upserted_teaching_cases, - imported_scoring_rules=report.replaced_scoring_rules, - generated_exam_items=report.generated_exam_items, - warnings=report.warnings, - ) - except ImportValidationError as exc: - raise AppError("CASE_SQL_IMPORT_INVALID", str(exc), 400) from exc - finally: - self._remove_temp_file(temp_path) - - async def _save_upload_to_temp(self, file: UploadFile) -> Path: - """上传落盘:校验 SQL 后缀和大小后保存到临时文件,供解析器读取。""" - filename = file.filename or "" - if not filename.lower().endswith(".sql"): - raise AppError("CASE_SQL_FILE_INVALID", "only .sql files are supported", 400) - content = await file.read() - if not content: - raise AppError("CASE_SQL_FILE_EMPTY", "uploaded SQL file is empty", 400) - if len(content) > MAX_SQL_UPLOAD_BYTES: - raise AppError("CASE_SQL_FILE_TOO_LARGE", "SQL file is larger than 5MB", 400) - - handle = tempfile.NamedTemporaryFile(delete=False, suffix=".sql") - try: - handle.write(content) - return Path(handle.name) - finally: - handle.close() - - def _allowed_table_counts(self, parsed: dict[str, list[dict]]) -> dict[str, int]: - """表计数:只暴露当前业务允许导入的病例表。""" - return {table: len(rows) for table, rows in parsed.items() if table in ALLOWED_TABLES} - - def _remove_temp_file(self, path: Path) -> None: - """临时文件清理:解析或导入完成后删除上传副本。""" - try: - path.unlink(missing_ok=True) - except OSError: - pass diff --git a/scripts/import_source_case_sql.py b/scripts/import_source_case_sql.py deleted file mode 100644 index c25cd44..0000000 --- a/scripts/import_source_case_sql.py +++ /dev/null @@ -1,612 +0,0 @@ -from __future__ import annotations - -import argparse -import json -import re -import sys -from dataclasses import dataclass -from datetime import datetime -from decimal import Decimal -from pathlib import Path -from typing import Any - -from sqlalchemy import delete, inspect, select - -sys.path.insert(0, str(Path(__file__).resolve().parents[1])) - -from app.db.session import SessionLocal -from app.models.source_case import CaseBase, CaseExamItem, ScoringRule, TeachingCase, TraditionalCase - - -SOURCE_TABLES = ("case_base",) -OPTIONAL_SOURCE_TABLES = ("traditional_case", "teaching_case", "scoring_rule", "case_exam_item") -DANGEROUS_PATTERNS = ( - r"\bDROP\s+DATABASE\b", - r"\bDROP\s+TABLE\b", - r"\bTRUNCATE\b", - r"\bDELETE\s+FROM\b", - r"\bCREATE\s+DATABASE\b", - r"\bALTER\s+TABLE\b", -) -JSON_COLUMNS = { - "case_base": {"symptom_tags", "disease_tags", "competency_tags", "guideline_tags", "knowledge_points", "multimodal_assets"}, - "scoring_rule": {"rubric_json"}, -} -DATETIME_COLUMNS = {"created_at", "updated_at"} -DECIMAL_COLUMNS = {"score_weight"} -INT_COLUMNS = { - "id", - "difficulty_score", - "patient_age", - "estimated_minutes", - "vector_status", - "publish_status", - "status", - "created_by_id", - "department_id", - "case_id", -} -BOOL_COLUMNS = {"osce_enabled", "rag_enabled", "ai_auto_score", "osce_dimension"} - - -class ImportValidationError(Exception): - """导入校验错误:源 SQL 不满足安全导入要求时中止。""" - - -@dataclass -class ImportReport: - """导入报告:记录检查、写入和跳过情况。""" - - source_path: str - encoding: str - table_rows: dict[str, int] - warnings: list[str] - applied: bool - upserted_cases: int = 0 - upserted_traditional_cases: int = 0 - upserted_teaching_cases: int = 0 - replaced_scoring_rules: int = 0 - generated_exam_items: int = 0 - - def as_dict(self) -> dict[str, Any]: - """报告输出:转换为前端和命令行均可读的结构。""" - return { - "source_path": self.source_path, - "encoding": self.encoding, - "table_rows": self.table_rows, - "warnings": self.warnings, - "applied": self.applied, - "upserted_cases": self.upserted_cases, - "upserted_traditional_cases": self.upserted_traditional_cases, - "upserted_teaching_cases": self.upserted_teaching_cases, - "replaced_scoring_rules": self.replaced_scoring_rules, - "generated_exam_items": self.generated_exam_items, - } - - -def detect_encoding(path: Path) -> str: - """编码识别:根据 BOM 和试读结果判断 SQL 文件编码。""" - raw = path.read_bytes() - if raw.startswith(b"\xff\xfe") or raw.startswith(b"\xfe\xff"): - return "utf-16" - if raw.startswith(b"\xef\xbb\xbf"): - return "utf-8-sig" - for encoding in ("utf-8", "utf-16"): - try: - raw.decode(encoding) - return encoding - except UnicodeDecodeError: - continue - raise ImportValidationError("SQL 文件编码无法识别,请提供 UTF-8 或 UTF-16 文件。") - - -def load_sql_text(path: Path) -> tuple[str, str]: - """文件读取:读取接口提供的 SQL dump,并返回文本与编码。""" - if not path.exists(): - raise ImportValidationError(f"SQL 文件不存在:{path}") - encoding = detect_encoding(path) - return path.read_text(encoding=encoding), encoding - - -def find_dangerous_statements(sql_text: str) -> list[str]: - """安全扫描:识别源 dump 中不能在正式库直接执行的 DDL/DML。""" - hits: list[str] = [] - for pattern in DANGEROUS_PATTERNS: - if re.search(pattern, sql_text, flags=re.IGNORECASE): - hits.append(pattern.replace(r"\b", "").replace("\\s+", " ")) - return hits - - -def extract_create_columns(sql_text: str, table_name: str) -> list[str]: - """字段提取:从 CREATE TABLE 中按源顺序读取字段名。""" - match = re.search(rf"CREATE TABLE `{re.escape(table_name)}` \((.*?)\) ENGINE=", sql_text, flags=re.S | re.I) - if not match: - return [] - columns: list[str] = [] - for line in match.group(1).splitlines(): - stripped = line.strip().rstrip(",") - if stripped.startswith("`"): - columns.append(stripped.split("`", 2)[1]) - return columns - - -def extract_insert_rows(sql_text: str, table_name: str, columns: list[str]) -> list[dict[str, Any]]: - """数据提取:只解析 INSERT VALUES 数据,不执行源 SQL。""" - rows: list[dict[str, Any]] = [] - pattern = re.compile(rf"INSERT INTO `{re.escape(table_name)}` VALUES\s*(.*?);", flags=re.S | re.I) - for match in pattern.finditer(sql_text): - tuples = parse_values_clause(match.group(1)) - for values in tuples: - if len(values) != len(columns): - raise ImportValidationError( - f"{table_name} INSERT 字段数不匹配:期望 {len(columns)},实际 {len(values)}。" - ) - row = {columns[index]: normalize_value(table_name, columns[index], value) for index, value in enumerate(values)} - rows.append(row) - return rows - - -def parse_values_clause(values_clause: str) -> list[list[Any]]: - """VALUES 解析:把 SQL values 子句解析为二维数组,损坏字符串会直接报错。""" - rows: list[list[Any]] = [] - index = 0 - length = len(values_clause) - while index < length: - while index < length and values_clause[index].isspace(): - index += 1 - if index >= length: - break - if values_clause[index] == ",": - index += 1 - continue - if values_clause[index] != "(": - raise ImportValidationError(f"VALUES 子句格式错误:第 {index} 个字符不是 '('。") - row, index = _parse_tuple(values_clause, index) - rows.append(row) - return rows - - -def _parse_tuple(text: str, start: int) -> tuple[list[Any], int]: - """元组解析:解析一组括号内的 SQL 字段值。""" - values: list[Any] = [] - token: list[str] = [] - in_string = False - was_string = False - escaped = False - index = start + 1 - while index < len(text): - char = text[index] - if in_string: - if escaped: - token.append(_unescape_char(char)) - escaped = False - elif char == "\\": - escaped = True - elif char == ")": - recovered_at_end = _recover_unclosed_string_at_tuple_end(token) - if recovered_at_end: - string_value, raw_values = recovered_at_end - values.append(string_value) - values.extend(raw_values) - return values, index + 1 - token.append(char) - elif char == "'": - recovered = _recover_misplaced_quote_separator(token, text[index + 1] if index + 1 < len(text) else "") - if recovered: - string_value, raw_values = recovered - values.append(string_value) - values.extend(raw_values) - token = [] - was_string = True - index += 1 - continue - in_string = False - else: - token.append(char) - index += 1 - continue - - if char == "'": - stripped = "".join(token).strip() - if stripped: - raise ImportValidationError(f"字符串前存在非法未引用内容:{stripped[:20]}") - in_string = True - was_string = True - index += 1 - continue - if was_string and char.isspace(): - index += 1 - continue - if was_string and char not in {",", ")"}: - raise ImportValidationError(f"字符串后存在非法未引用内容:{char}{text[index + 1:index + 20]}") - if char == ",": - values.append(_coerce_raw_token("".join(token), was_string)) - token = [] - was_string = False - index += 1 - continue - if char == ")": - values.append(_coerce_raw_token("".join(token), was_string)) - return values, index + 1 - token.append(char) - index += 1 - raise ImportValidationError("VALUES 子句存在未闭合括号或未闭合字符串。") - - -def _recover_misplaced_quote_separator(token: list[str], next_char: str) -> tuple[str, list[Any]] | None: - """兼容解析:修复接口 SQL 文本字段结尾写成 `文本,'next'` 的引号/逗号错位。""" - if not next_char or next_char in {",", ")"} or next_char.isspace(): - return None - raw = "".join(token) - if not raw.endswith(","): - return None - - body = raw[:-1] - trailing_values: list[Any] = [] - while True: - head, separator, tail = body.rpartition(",") - if not separator or not _looks_like_unquoted_scalar(tail): - break - trailing_values.insert(0, _coerce_raw_token(tail, was_string=False)) - body = head - if not body: - return None - return body, trailing_values - - -def _recover_unclosed_string_at_tuple_end(token: list[str]) -> tuple[str, list[Any]] | None: - """兼容解析:修复文本字段缺少结束引号且后接 `,数字)` 的源 SQL。""" - body = "".join(token) - trailing_values: list[Any] = [] - while True: - head, separator, tail = body.rpartition(",") - if not separator or not _looks_like_unquoted_scalar(tail): - break - trailing_values.insert(0, _coerce_raw_token(tail, was_string=False)) - body = head - if not body or not trailing_values: - return None - return body, trailing_values - - -def _looks_like_unquoted_scalar(value: str) -> bool: - """兼容解析:判断错位引号前夹带的字段是否是可安全恢复的未引用标量。""" - stripped = value.strip() - if not stripped: - return False - if stripped.upper() == "NULL": - return True - return bool(re.fullmatch(r"-?\d+(?:\.\d+)?", stripped)) - - -def _unescape_char(char: str) -> str: - """SQL 转义:处理 dump 中的常见反斜杠转义。""" - return { - "0": "\0", - "b": "\b", - "n": "\n", - "r": "\r", - "t": "\t", - "Z": "\x1a", - "\\": "\\", - "'": "'", - '"': '"', - }.get(char, char) - - -def _coerce_raw_token(raw: str, was_string: bool) -> Any: - """原始字段转换:区分 SQL NULL、数字和字符串。""" - if was_string: - return raw - value = raw.strip() - if not value: - return "" - if value.upper() == "NULL": - return None - if re.fullmatch(r"-?\d+", value): - return int(value) - if re.fullmatch(r"-?\d+\.\d+", value): - return Decimal(value) - if re.search(r"[A-Za-z\u4e00-\u9fff]", value): - raise ImportValidationError(f"发现未引用文本字段:{value[:30]}") - return value - - -def normalize_value(table_name: str, column: str, value: Any) -> Any: - """字段归一:按当前 ORM 需要转换 JSON、时间、布尔和数值。""" - if value is None: - return None - if column in JSON_COLUMNS.get(table_name, set()): - if isinstance(value, (list, dict)): - return value - try: - return json.loads(value) - except (TypeError, json.JSONDecodeError) as exc: - raise ImportValidationError(f"{table_name}.{column} 不是合法 JSON。") from exc - if column in DATETIME_COLUMNS and isinstance(value, str): - try: - return datetime.fromisoformat(value) - except ValueError as exc: - raise ImportValidationError(f"{table_name}.{column} 不是合法 datetime。") from exc - if column in BOOL_COLUMNS: - return bool(int(value)) - if column in INT_COLUMNS and value is not None: - return int(value) - if column in DECIMAL_COLUMNS and value is not None: - return Decimal(str(value)) - return value - - -def parse_source_dump(path: Path) -> tuple[dict[str, list[dict[str, Any]]], list[str], str]: - """源文件解析:提取可导入表数据并返回兼容性警告。""" - sql_text, encoding = load_sql_text(path) - warnings = [] - dangerous = find_dangerous_statements(sql_text) - if dangerous: - warnings.append("源 SQL 包含 DDL/DML 覆盖语句,导入器会忽略这些语句:" + ", ".join(sorted(set(dangerous)))) - - parsed: dict[str, list[dict[str, Any]]] = {} - for table_name in SOURCE_TABLES: - columns = extract_create_columns(sql_text, table_name) - if not columns: - raise ImportValidationError(f"源 SQL 缺少必需表结构:{table_name}") - parsed[table_name] = extract_insert_rows(sql_text, table_name, columns) - - for table_name in OPTIONAL_SOURCE_TABLES: - columns = extract_create_columns(sql_text, table_name) - if not columns: - warnings.append(f"源 SQL 未包含 {table_name},导入器会按当前业务规则处理。") - continue - parsed[table_name] = extract_insert_rows(sql_text, table_name, columns) - if not parsed.get("traditional_case") and not parsed.get("teaching_case"): - warnings.append("源 SQL 未包含 traditional_case 或 teaching_case,病例导入后暂时缺少训练模式扩展数据。") - if not parsed.get("scoring_rule"): - warnings.append("源 SQL 未包含 scoring_rule,评价时将缺少接口侧基础评分规则。") - return parsed, warnings, encoding - - -def validate_target_schema(parsed: dict[str, list[dict[str, Any]]]) -> None: - """目标结构校验:确认源字段可映射到当前数据库表。""" - with SessionLocal() as db: - inspector = inspect(db.bind) - for table_name, rows in parsed.items(): - if not inspector.has_table(table_name): - raise ImportValidationError(f"当前数据库缺少目标表:{table_name}") - target_columns = {column["name"] for column in inspector.get_columns(table_name)} - for row in rows: - extra_columns = sorted(set(row) - target_columns) - if extra_columns: - raise ImportValidationError(f"{table_name} 存在当前库不支持的字段:{extra_columns}") - - -def import_source_sql(path: Path, apply: bool = False, generate_exam_items: bool = True) -> ImportReport: - """安全导入:解析源 SQL,并在显式 apply 时写入当前新表。""" - parsed, warnings, encoding = parse_source_dump(path) - validate_target_schema(parsed) - report = ImportReport( - source_path=str(path), - encoding=encoding, - table_rows={table: len(rows) for table, rows in parsed.items()}, - warnings=warnings, - applied=apply, - ) - if not apply: - return report - - with SessionLocal() as db: - try: - case_rows = parsed.get("case_base", []) - report.upserted_cases = _upsert_cases(db, case_rows) - if "traditional_case" in parsed: - report.upserted_traditional_cases = _sync_traditional_cases(db, case_rows, parsed.get("traditional_case", [])) - if "teaching_case" in parsed: - report.upserted_teaching_cases = _sync_teaching_cases(db, case_rows, parsed.get("teaching_case", [])) - if "scoring_rule" in parsed: - report.replaced_scoring_rules = _replace_scoring_rules(db, case_rows, parsed.get("scoring_rule", [])) - if generate_exam_items: - report.generated_exam_items = _upsert_generated_exam_items(db, case_rows) - db.commit() - except Exception: - db.rollback() - raise - return report - - -def _upsert_cases(db, rows: list[dict[str, Any]]) -> int: - """病例导入:按 case_base.id 更新或插入病例主表。""" - count = 0 - for row in rows: - row = dict(row) - row["status"] = 1 - row["publish_status"] = 1 - entity = db.get(CaseBase, row["id"]) - if not entity: - db.add(CaseBase(**row)) - else: - for key, value in row.items(): - setattr(entity, key, value) - count += 1 - return count - - -def _sync_traditional_cases(db, case_rows: list[dict[str, Any]], rows: list[dict[str, Any]]) -> int: - """传统病例导入:按本次 SQL 同步练习模式扩展表,避免同 case_id 旧数据残留。""" - case_ids = _imported_case_ids(case_rows) - if case_ids: - db.execute(delete(TraditionalCase).where(TraditionalCase.case_id.in_(case_ids))) - count = 0 - for row in rows: - db.add(TraditionalCase(**row)) - count += 1 - return count - - -def _sync_teaching_cases(db, case_rows: list[dict[str, Any]], rows: list[dict[str, Any]]) -> int: - """教学互动病例导入:源 SQL 明确提供 teaching_case 时,以源数据为准同步扩展表。""" - case_ids = _imported_case_ids(case_rows) - if case_ids: - db.execute(delete(TeachingCase).where(TeachingCase.case_id.in_(case_ids))) - count = 0 - for row in rows: - db.add(TeachingCase(**row)) - count += 1 - return count - - -def _replace_scoring_rules(db, case_rows: list[dict[str, Any]], rows: list[dict[str, Any]]) -> int: - """评分规则导入:按本次导入病例替换 scoring_rule,保持评分规则与源数据一致。""" - case_ids = _imported_case_ids(case_rows) - for case_id in case_ids: - db.execute(delete(ScoringRule).where(ScoringRule.case_id == case_id)) - for row in rows: - db.add(ScoringRule(**row)) - return len(rows) - - -def _upsert_generated_exam_items(db, case_rows: list[dict[str, Any]]) -> int: - """检查项目补齐:按 item_code 更新或补齐固定检查结果,避免删除历史训练引用的检查项。""" - changed = 0 - for row in case_rows: - case_id = row["id"] - items = build_exam_items_from_case(row) - for item in items: - entity = db.scalar( - select(CaseExamItem).where(CaseExamItem.case_id == case_id, CaseExamItem.item_code == item["item_code"]) - ) - if not entity: - db.add(CaseExamItem(case_id=case_id, **item)) - else: - for key, value in item.items(): - setattr(entity, key, value) - changed += 1 - return changed - - -def _imported_case_ids(case_rows: list[dict[str, Any]]) -> list[int]: - """导入范围:提取本次 SQL 涉及的病例 ID,用于同步扩展表。""" - return sorted({int(row["id"]) for row in case_rows if row.get("id") is not None}) - - -def build_exam_items_from_case(case_row: dict[str, Any]) -> list[dict[str, Any]]: - """检查项目生成:从病例文本识别常见检查结果,保障 Demo 可继续申请检查。""" - text = " ".join(str(case_row.get(key) or "") for key in ("chief_complaint", "description", "knowledge_points")) - candidates = [ - ( - "blood_routine", - "血常规", - "lab", - "实验室检查", - _find_result(text, r"(WBC[^,。;;]*[,,]\s*[^,。;;]*(?:中性|neutrophil)[^,。;;]*)") or "血常规结果见病例资料。", - {"source": "case_description"}, - "WBC" in text or "血常规" in text, - ), - ( - "crp", - "CRP", - "lab", - "实验室检查", - _find_result(text, r"(CRP\s*[^,。;;]*)") or "CRP 结果见病例资料。", - {"source": "case_description"}, - "CRP" in text.upper(), - ), - ( - "chest_xray", - "胸片", - "imaging", - "影像检查", - _find_result(text, r"(胸片[^。;;]*)") or "胸片结果见病例资料。", - {"source": "case_description"}, - "胸片" in text or "X线" in text, - ), - ( - "spo2", - "血氧饱和度", - "vital_sign", - "生命体征", - _find_result(text, r"(SpO2\s*\d+%?)") or "血氧饱和度结果见病例资料。", - {"source": "case_description"}, - "SpO2" in text or "血氧" in text, - ), - ( - "chest_auscultation", - "肺部体格检查", - "physical_exam", - "体格检查", - _find_result(text, r"(肺[^。;;]*(?:湿啰音|哮鸣音|呼吸音)[^。;;]*)") or "肺部体格检查结果见病例资料。", - {"source": "case_description"}, - "湿啰音" in text or "哮鸣音" in text or "肺" in text, - ), - ( - "mp_igm", - "肺炎支原体抗体IgM", - "lab", - "实验室检查", - _find_result(text, r"(肺炎支原体抗体IgM[^,。;;]*)") or "肺炎支原体抗体IgM结果见病例资料。", - {"source": "case_description"}, - "肺炎支原体抗体IgM" in text or "支原体" in text, - ), - ] - items: list[dict[str, Any]] = [] - for index, (code, name, item_type, category, result_text, structured, detected) in enumerate(candidates, start=1): - if detected: - items.append( - { - "item_code": code, - "item_name": name, - "item_type": item_type, - "category": category, - "result_text": result_text, - "result_structured": structured, - "is_key": True, - "is_abnormal": True, - "score_weight": Decimal("5.00"), - "display_order": index, - } - ) - if items: - return items - return [ - { - "item_code": "basic_exam", - "item_name": "基础检查", - "item_type": "physical_exam", - "category": "体格检查", - "result_text": "基础检查结果见病例资料。", - "result_structured": {"source": "case_description"}, - "is_key": True, - "is_abnormal": False, - "score_weight": Decimal("1.00"), - "display_order": 1, - } - ] - - -def _find_result(text: str, pattern: str) -> str | None: - """文本抽取:从病例描述中抽取简短检查结果。""" - match = re.search(pattern, text, flags=re.I) - return match.group(1).strip() if match else None - - -def main() -> None: - """命令入口:执行接口 SQL 的安全检查或导入。""" - parser = argparse.ArgumentParser(description="Safely import parsed case SQL into current medical agent schema.") - parser.add_argument("sql_path", type=Path, help="接口提供的 SQL dump 文件路径") - parser.add_argument("--apply", action="store_true", help="确认写入当前数据库;默认只检查不写入") - parser.add_argument("--no-generate-exam-items", action="store_true", help="不自动补齐 case_exam_item") - args = parser.parse_args() - - try: - report = import_source_sql( - args.sql_path, - apply=args.apply, - generate_exam_items=not args.no_generate_exam_items, - ) - except ImportValidationError as exc: - print(json.dumps({"ok": False, "error": str(exc)}, ensure_ascii=False, indent=2)) - raise SystemExit(2) from exc - - print(json.dumps({"ok": True, "report": report.as_dict()}, ensure_ascii=False, indent=2, default=str)) - - -if __name__ == "__main__": - main() diff --git a/scripts/init_demo_db.py b/scripts/init_demo_db.py index 63872d7..ee063f2 100644 --- a/scripts/init_demo_db.py +++ b/scripts/init_demo_db.py @@ -32,7 +32,7 @@ def init_database() -> None: def seed_demo_data(db) -> None: - """病例导入:写入儿科支气管肺炎病例、检查项目、评分规则和提示词元数据。""" + """Demo 数据初始化:仅为本地/测试库写入儿科病例和基础训练数据。""" department = _get_or_create_department(db) case = _get_or_create_case_base(db, department.id) _seed_traditional_case(db, case.id) diff --git a/tests/test_api_contract.py b/tests/test_api_contract.py index 64a62e8..4a4067d 100644 --- a/tests/test_api_contract.py +++ b/tests/test_api_contract.py @@ -1,15 +1,16 @@ import os import sys -from decimal import Decimal +import tempfile from pathlib import Path -os.environ.setdefault("DATABASE_URL", "sqlite:///./storage/test_api_contract.db") +TEST_DB_PATH = Path(tempfile.gettempdir()) / "medical_agent_test_api_contract.db" +TEST_DB_PATH.unlink(missing_ok=True) +os.environ["DATABASE_URL"] = f"sqlite:///{TEST_DB_PATH.as_posix()}" os.environ.setdefault("RUNTIME_MEMORY_BACKEND", "memory") os.environ.setdefault("LLM_MOCK_ENABLED", "true") os.environ.setdefault("AUTH_USER_ME_URL", "http://django-user-center.test/api/user/users/me/") sys.path.insert(0, str(Path(__file__).resolve().parents[1])) -Path("storage").mkdir(exist_ok=True) try: from fastapi.testclient import TestClient @@ -25,9 +26,6 @@ def run_api_contract_tests() -> None: from app.main import app from app.services.external_auth_service import AuthenticatedUser, ExternalAuthService - from app.db.session import SessionLocal - from app.models.source_case import CaseBase, CaseExamItem, ScoringRule, TraditionalCase - from app.repositories.case_repository import CaseRepository from scripts.init_demo_db import init_database async def fake_authenticate(self, request): # noqa: ARG001 @@ -87,6 +85,10 @@ def run_api_contract_tests() -> None: auth_me_operation = openapi_payload["paths"]["/api/v1/auth/me"]["get"] assert any("HTTPBearer" in item for item in auth_me_operation.get("security", [])) assert "HTTPBearer" in openapi_payload["components"]["securitySchemes"] + assert "/api/v1/imports/case-sql/preview" not in openapi_payload["paths"] + assert "/api/v1/imports/case-sql/apply" not in openapi_payload["paths"] + assert "/api/v1/cases/{case_id}/delete-preview" not in openapi_payload["paths"] + assert "delete" not in openapi_payload["paths"]["/api/v1/cases/{case_id}"] cases = client.get("/api/v1/cases", headers=headers) assert cases.status_code == 200 @@ -158,119 +160,6 @@ def run_api_contract_tests() -> None: assert llm_reason.json()["code"] == "OK" assert "total_latency_ms" in llm_reason.json()["data"] - import_preview = client.post( - "/api/v1/imports/case-sql/preview", - headers=headers, - files={"file": ("bad.sql", b"not sql", "application/sql")}, - ) - assert import_preview.status_code == 200 - assert import_preview.json()["code"] == "OK" - assert import_preview.json()["data"]["can_import"] is False - assert import_preview.json()["data"]["errors"] - - import_apply = client.post( - "/api/v1/imports/case-sql/apply", - headers=headers, - files={"file": ("bad.sql", b"not sql", "application/sql")}, - ) - assert import_apply.status_code == 400 - assert import_apply.json()["code"] == "CASE_SQL_IMPORT_INVALID" - - temp_case_id = 880001 - with SessionLocal() as db: - CaseRepository(db).delete_case_cascade(temp_case_id) - db.add( - CaseBase( - id=temp_case_id, - title="删除测试病例", - case_type="diagnosis_treatment", - difficulty="medium", - chief_complaint="发热、咳嗽", - description="用于删除接口测试的临时病例", - patient_age=6, - patient_gender="male", - tags="test", - symptom_tags=[], - disease_tags=[], - competency_tags=[], - guideline_tags=[], - knowledge_points=[], - icd_codes="", - osce_enabled=False, - rag_enabled=False, - ai_prompt_template="", - multimodal_assets=[], - vector_status=0, - publish_status=1, - status=1, - department_id=1, - ) - ) - db.add( - TraditionalCase( - id=temp_case_id, - case_id=temp_case_id, - standard_diagnosis="测试诊断", - standard_treatment="测试治疗", - guideline_reference="测试指南", - ) - ) - db.add( - ScoringRule( - id=temp_case_id, - case_id=temp_case_id, - dimension="信息采集", - competency_dimension="问诊完整性", - score_weight=Decimal("10.00"), - ai_auto_score=True, - osce_dimension=False, - scoring_standard="测试评分标准", - rubric_json={}, - ) - ) - db.add( - CaseExamItem( - id=temp_case_id, - case_id=temp_case_id, - item_code="temp_exam", - item_name="测试检查", - item_type="lab", - result_text="测试结果", - result_structured={}, - is_key=True, - is_abnormal=False, - score_weight=Decimal("1.00"), - display_order=1, - ) - ) - db.commit() - - delete_preview = client.get(f"/api/v1/cases/{temp_case_id}/delete-preview", headers=headers) - assert delete_preview.status_code == 200 - assert delete_preview.json()["data"]["affected"]["case_base"] == 1 - assert delete_preview.json()["data"]["affected"]["case_exam_item"] == 1 - - delete_without_confirm = client.request( - "DELETE", - f"/api/v1/cases/{temp_case_id}", - headers=headers, - json={"confirm": False, "delete_training_data": True}, - ) - assert delete_without_confirm.status_code == 400 - assert delete_without_confirm.json()["code"] == "CASE_DELETE_CONFIRM_REQUIRED" - - delete_case = client.request( - "DELETE", - f"/api/v1/cases/{temp_case_id}", - headers=headers, - json={"confirm": True, "delete_training_data": True}, - ) - assert delete_case.status_code == 200 - assert delete_case.json()["data"]["deleted_counts"]["case_base"] == 1 - - deleted_detail = client.get(f"/api/v1/cases/{temp_case_id}", headers=headers) - assert deleted_detail.status_code == 404 - assert deleted_detail.json()["code"] == "CASE_NOT_FOUND" if __name__ == "__main__": diff --git a/tests/test_demo_flow.py b/tests/test_demo_flow.py index 204ef4f..b518ec9 100644 --- a/tests/test_demo_flow.py +++ b/tests/test_demo_flow.py @@ -1,15 +1,17 @@ import asyncio import os import sys +import tempfile from pathlib import Path -os.environ.setdefault("DATABASE_URL", "sqlite:///./storage/test_demo_flow.db") +TEST_DB_PATH = Path(tempfile.gettempdir()) / "medical_agent_test_demo_flow.db" +TEST_DB_PATH.unlink(missing_ok=True) +os.environ["DATABASE_URL"] = f"sqlite:///{TEST_DB_PATH.as_posix()}" +os.environ["REPORT_STORAGE_DIR"] = str(Path(tempfile.gettempdir()) / "medical_agent_test_reports") os.environ.setdefault("RUNTIME_MEMORY_BACKEND", "memory") os.environ.setdefault("LLM_MOCK_ENABLED", "true") sys.path.insert(0, str(Path(__file__).resolve().parents[1])) -if os.getenv("DATABASE_URL") == "sqlite:///./storage/test_demo_flow.db": - Path("storage/test_demo_flow.db").unlink(missing_ok=True) from sqlalchemy import select diff --git a/tests/test_import_source_case_sql.py b/tests/test_import_source_case_sql.py deleted file mode 100644 index 55d6614..0000000 --- a/tests/test_import_source_case_sql.py +++ /dev/null @@ -1,75 +0,0 @@ -import tempfile -from pathlib import Path - -import os -import sys - -os.environ.setdefault("DATABASE_URL", "sqlite:///./storage/test_import.db") -os.environ.setdefault("RUNTIME_MEMORY_BACKEND", "memory") -os.environ.setdefault("LLM_MOCK_ENABLED", "true") - -sys.path.insert(0, str(Path(__file__).resolve().parents[1])) - -from scripts.import_source_case_sql import ImportValidationError, extract_create_columns, extract_insert_rows, parse_values_clause - - -def test_extract_insert_rows_maps_by_create_columns() -> None: - """导入解析:验证 INSERT VALUES 按 CREATE TABLE 字段顺序映射。""" - sql = """ - CREATE TABLE `traditional_case` ( - `created_at` datetime(6) NOT NULL, - `updated_at` datetime(6) NOT NULL, - `id` bigint NOT NULL AUTO_INCREMENT, - `standard_diagnosis` longtext NOT NULL, - `standard_treatment` longtext NOT NULL, - `guideline_reference` longtext NOT NULL, - `case_id` bigint NOT NULL - ) ENGINE=InnoDB; - INSERT INTO `traditional_case` VALUES ('2026-05-28 09:34:08','2026-05-28 09:34:09',1,'支气管肺炎','抗感染','指南',1); - """ - columns = extract_create_columns(sql, "traditional_case") - rows = extract_insert_rows(sql, "traditional_case", columns) - assert rows[0]["standard_diagnosis"] == "支气管肺炎" - assert rows[0]["case_id"] == 1 - - -def test_extract_insert_rows_rejects_broken_sql() -> None: - """导入解析:验证损坏 SQL 被拒绝,避免半导入。""" - sql = """ - CREATE TABLE `traditional_case` ( - `created_at` datetime(6) NOT NULL, - `updated_at` datetime(6) NOT NULL, - `id` bigint NOT NULL AUTO_INCREMENT, - `standard_diagnosis` longtext NOT NULL, - `standard_treatment` longtext NOT NULL, - `guideline_reference` longtext NOT NULL, - `case_id` bigint NOT NULL - ) ENGINE=InnoDB; - INSERT INTO `traditional_case` VALUES ('2026-05-28','2026-05-28',1,'支气管肺炎'bad,'抗感染','指南',1); - """ - columns = extract_create_columns(sql, "traditional_case") - try: - extract_insert_rows(sql, "traditional_case", columns) - except ImportValidationError: - return - raise AssertionError("broken SQL should be rejected") - - -def test_parse_values_clause_recovers_misplaced_quote_separator() -> None: - """导入解析:兼容接口 SQL 中文本字段写成 `文本,'next'` 的引号/逗号错位。""" - rows = parse_values_clause("(1,'标题?,'traditional','medium',NULL,'主诉?,'描述?,6,'male')") - assert rows[0] == [1, "标题?", "traditional", "medium", None, "主诉?", "描述?", 6, "male"] - - -def test_parse_values_clause_recovers_unclosed_string_before_tuple_end() -> None: - """导入解析:兼容接口 SQL 最后文本字段缺少结束引号且后接外键的情况。""" - rows = parse_values_clause("(1,'诊断?,'治疗?,'指南?,1)") - assert rows[0] == [1, "诊断?", "治疗?", "指南?", 1] - - -if __name__ == "__main__": - test_extract_insert_rows_maps_by_create_columns() - test_extract_insert_rows_rejects_broken_sql() - test_parse_values_clause_recovers_misplaced_quote_separator() - test_parse_values_clause_recovers_unclosed_string_before_tuple_end() - print("import source case sql tests passed")