102 lines
4.3 KiB
Python
102 lines
4.3 KiB
Python
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
import tempfile
|
||
|
|
from pathlib import Path
|
||
|
|
|
||
|
|
from fastapi import UploadFile
|
||
|
|
|
||
|
|
from app.core.exceptions import AppError
|
||
|
|
from app.schemas.imports import CaseSqlImportApplyResponse, CaseSqlImportPreviewResponse, CaseSqlPreviewCase
|
||
|
|
from scripts.import_source_case_sql import ImportValidationError, import_source_sql, parse_source_dump
|
||
|
|
|
||
|
|
|
||
|
|
# Largest accepted upload, in bytes (5 MiB); enforced by
# CaseSqlImportService._save_upload_to_temp.
MAX_SQL_UPLOAD_BYTES = 5 * 2**20

# Tables whose row counts are exposed to the caller by
# CaseSqlImportService._allowed_table_counts; other tables in the dump are
# left out of that report.
ALLOWED_TABLES = {
    "case_base",
    "traditional_case",
    "teaching_case",
    "scoring_rule",
}
|
||
|
|
|
||
|
|
|
||
|
|
class CaseSqlImportService:
    """Case SQL import service.

    Saves an uploaded ``.sql`` dump to a temporary file and hands it to the
    parser/importer from ``scripts.import_source_case_sql``. Only the tables in
    ``ALLOWED_TABLES`` are reported back to the caller.
    """

    async def preview(self, file: UploadFile) -> CaseSqlImportPreviewResponse:
        """Dry-run an upload: parse its structure and data without writing to the database.

        Args:
            file: The uploaded SQL file.

        Returns:
            A preview response. Problems found in the dump (parser validation
            errors, ``case_base`` rows with a missing or non-integer ``id``)
            are reported via ``can_import=False`` and ``errors`` instead of
            being raised.

        Raises:
            AppError: If the upload is not a ``.sql`` file, is empty, or is
                larger than ``MAX_SQL_UPLOAD_BYTES``.
        """
        file_name = file.filename or "case.sql"
        temp_path = await self._save_upload_to_temp(file)
        try:
            parsed, warnings, encoding = parse_source_dump(temp_path)
            table_rows = self._allowed_table_counts(parsed)
            preview_cases, row_errors = self._build_preview_cases(parsed.get("case_base", []))
            if row_errors:
                # Malformed rows come from user data, so surface them as a
                # rejected preview rather than letting them escape as a 500.
                return CaseSqlImportPreviewResponse(
                    file_name=file_name,
                    encoding=encoding,
                    tables=table_rows,
                    can_import=False,
                    warnings=warnings,
                    errors=row_errors,
                )
            return CaseSqlImportPreviewResponse(
                file_name=file_name,
                encoding=encoding,
                tables=table_rows,
                can_import=True,
                warnings=warnings,
                errors=[],
                preview_cases=preview_cases,
            )
        except ImportValidationError as exc:
            return CaseSqlImportPreviewResponse(
                file_name=file_name,
                can_import=False,
                errors=[str(exc)],
            )
        finally:
            self._remove_temp_file(temp_path)

    async def apply(self, file: UploadFile) -> CaseSqlImportApplyResponse:
        """Confirmed import: run ``import_source_sql(..., apply=True)`` on the upload.

        The original design intends case_base / traditional_case /
        teaching_case / scoring_rule to be written transactionally; that
        guarantee lives in ``import_source_sql`` and is not verifiable here.

        Args:
            file: The uploaded SQL file.

        Returns:
            Counts reported by the importer, plus its detected encoding and warnings.

        Raises:
            AppError: ``CASE_SQL_IMPORT_INVALID`` (400) when the importer raises
                ``ImportValidationError``; the upload-validation errors of
                ``_save_upload_to_temp`` otherwise.
        """
        file_name = file.filename or "case.sql"
        temp_path = await self._save_upload_to_temp(file)
        try:
            # NOTE(review): import_source_sql is synchronous and runs on the
            # event-loop thread, so other requests stall during an import.
            # Offloading it (e.g. asyncio.to_thread) would also let concurrent
            # imports overlap; confirm the importer tolerates that first.
            report = import_source_sql(temp_path, apply=True)
        except ImportValidationError as exc:
            raise AppError("CASE_SQL_IMPORT_INVALID", str(exc), 400) from exc
        finally:
            self._remove_temp_file(temp_path)

        return CaseSqlImportApplyResponse(
            imported=True,
            file_name=file_name,
            encoding=report.encoding,
            inserted_or_updated_cases=report.upserted_cases,
            imported_traditional_cases=report.upserted_traditional_cases,
            imported_teaching_cases=report.upserted_teaching_cases,
            imported_scoring_rules=report.replaced_scoring_rules,
            generated_exam_items=report.generated_exam_items,
            warnings=report.warnings,
        )

    async def _save_upload_to_temp(self, file: UploadFile) -> Path:
        """Validate the upload (``.sql`` suffix, non-empty, size limit) and write it to a temp file.

        Returns:
            Path of the temporary copy. The caller owns it and must delete it
            (see ``_remove_temp_file``).

        Raises:
            AppError: ``CASE_SQL_FILE_INVALID``, ``CASE_SQL_FILE_EMPTY`` or
                ``CASE_SQL_FILE_TOO_LARGE`` (all 400).
        """
        filename = file.filename or ""
        if not filename.lower().endswith(".sql"):
            raise AppError("CASE_SQL_FILE_INVALID", "only .sql files are supported", 400)

        # Read at most one byte past the limit: enough to detect an oversized
        # upload without buffering an arbitrarily large body in memory.
        content = await file.read(MAX_SQL_UPLOAD_BYTES + 1)
        if not content:
            raise AppError("CASE_SQL_FILE_EMPTY", "uploaded SQL file is empty", 400)
        if len(content) > MAX_SQL_UPLOAD_BYTES:
            raise AppError("CASE_SQL_FILE_TOO_LARGE", "SQL file is larger than 5MB", 400)

        # delete=False: the parser opens the file by path after this handle is
        # closed, so deletion is left to the caller.
        handle = tempfile.NamedTemporaryFile(delete=False, suffix=".sql")
        temp_path = Path(handle.name)
        try:
            with handle:
                handle.write(content)
        except BaseException:
            # The caller never receives the path on failure, so clean up here
            # instead of leaking a half-written file.
            self._remove_temp_file(temp_path)
            raise
        return temp_path

    def _build_preview_cases(
        self, rows: list[dict]
    ) -> tuple[list[CaseSqlPreviewCase], list[str]]:
        """Convert ``case_base`` rows into preview entries.

        Returns:
            A pair ``(cases, errors)``: the rows that converted successfully,
            and one message per row (1-based position) whose ``id`` is missing
            or cannot be converted with ``int()``.
        """
        cases: list[CaseSqlPreviewCase] = []
        errors: list[str] = []
        for index, row in enumerate(rows, start=1):
            try:
                case_id = int(row["id"])
            except (KeyError, TypeError, ValueError):
                errors.append(f"case_base row {index}: missing or non-integer id")
                continue
            cases.append(
                CaseSqlPreviewCase(
                    id=case_id,
                    title=str(row.get("title") or ""),
                    case_type=str(row.get("case_type") or ""),
                    difficulty=str(row.get("difficulty") or ""),
                )
            )
        return cases, errors

    def _allowed_table_counts(self, parsed: dict[str, list[dict]]) -> dict[str, int]:
        """Row counts per table, restricted to the tables in ``ALLOWED_TABLES``."""
        return {table: len(rows) for table, rows in parsed.items() if table in ALLOWED_TABLES}

    def _remove_temp_file(self, path: Path) -> None:
        """Delete the temporary upload copy after parsing or importing.

        Best-effort: an ``OSError`` is ignored, so cleanup never masks the
        result or exception of the operation that used the file.
        """
        try:
            path.unlink(missing_ok=True)
        except OSError:
            pass
|