from __future__ import annotations import tempfile from pathlib import Path from fastapi import UploadFile from app.core.exceptions import AppError from app.schemas.imports import CaseSqlImportApplyResponse, CaseSqlImportPreviewResponse, CaseSqlPreviewCase from scripts.import_source_case_sql import ImportValidationError, import_source_sql, parse_source_dump MAX_SQL_UPLOAD_BYTES = 5 * 1024 * 1024 ALLOWED_TABLES = {"case_base", "traditional_case", "teaching_case", "scoring_rule"} class CaseSqlImportService: """病例 SQL 导入服务:解析接口 SQL 文件并安全映射到当前病例表。""" async def preview(self, file: UploadFile) -> CaseSqlImportPreviewResponse: """导入预检:上传 SQL 后只解析结构和数据,不写入数据库。""" temp_path = await self._save_upload_to_temp(file) try: parsed, warnings, encoding = parse_source_dump(temp_path) table_rows = self._allowed_table_counts(parsed) preview_cases = [ CaseSqlPreviewCase( id=int(row["id"]), title=str(row.get("title") or ""), case_type=str(row.get("case_type") or ""), difficulty=str(row.get("difficulty") or ""), ) for row in parsed.get("case_base", []) ] return CaseSqlImportPreviewResponse( file_name=file.filename or "case.sql", encoding=encoding, tables=table_rows, can_import=True, warnings=warnings, errors=[], preview_cases=preview_cases, ) except ImportValidationError as exc: return CaseSqlImportPreviewResponse( file_name=file.filename or "case.sql", can_import=False, errors=[str(exc)], ) finally: self._remove_temp_file(temp_path) async def apply(self, file: UploadFile) -> CaseSqlImportApplyResponse: """确认导入:解析通过后以事务方式写入 case_base/traditional_case/teaching_case/scoring_rule。""" temp_path = await self._save_upload_to_temp(file) try: report = import_source_sql(temp_path, apply=True) return CaseSqlImportApplyResponse( imported=True, file_name=file.filename or "case.sql", encoding=report.encoding, inserted_or_updated_cases=report.upserted_cases, imported_traditional_cases=report.upserted_traditional_cases, imported_teaching_cases=report.upserted_teaching_cases, imported_scoring_rules=report.replaced_scoring_rules, generated_exam_items=report.generated_exam_items, warnings=report.warnings, ) except ImportValidationError as exc: raise AppError("CASE_SQL_IMPORT_INVALID", str(exc), 400) from exc finally: self._remove_temp_file(temp_path) async def _save_upload_to_temp(self, file: UploadFile) -> Path: """上传落盘:校验 SQL 后缀和大小后保存到临时文件,供解析器读取。""" filename = file.filename or "" if not filename.lower().endswith(".sql"): raise AppError("CASE_SQL_FILE_INVALID", "only .sql files are supported", 400) content = await file.read() if not content: raise AppError("CASE_SQL_FILE_EMPTY", "uploaded SQL file is empty", 400) if len(content) > MAX_SQL_UPLOAD_BYTES: raise AppError("CASE_SQL_FILE_TOO_LARGE", "SQL file is larger than 5MB", 400) handle = tempfile.NamedTemporaryFile(delete=False, suffix=".sql") try: handle.write(content) return Path(handle.name) finally: handle.close() def _allowed_table_counts(self, parsed: dict[str, list[dict]]) -> dict[str, int]: """表计数:只暴露当前业务允许导入的病例表。""" return {table: len(rows) for table, rows in parsed.items() if table in ALLOWED_TABLES} def _remove_temp_file(self, path: Path) -> None: """临时文件清理:解析或导入完成后删除上传副本。""" try: path.unlink(missing_ok=True) except OSError: pass