Files
fastapi/scripts/import_source_case_sql.py
T
2026-06-04 10:55:23 +08:00

613 lines
23 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations
import argparse
import json
import re
import sys
from dataclasses import dataclass
from datetime import datetime
from decimal import Decimal
from pathlib import Path
from typing import Any
from sqlalchemy import delete, inspect, select
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
from app.db.session import SessionLocal
from app.models.source_case import CaseBase, CaseExamItem, ScoringRule, TeachingCase, TraditionalCase
# Tables that must be defined in the source dump; parse_source_dump aborts without them.
SOURCE_TABLES = ("case_base",)
# Tables parsed when present in the dump; a missing one only adds a warning.
OPTIONAL_SOURCE_TABLES = ("traditional_case", "teaching_case", "scoring_rule", "case_exam_item")
# Overwriting DDL/DML statements. They are matched case-insensitively against the raw
# dump text (string data included) and only reported as a warning; the importer never
# executes the dump, it only parses INSERT VALUES data.
DANGEROUS_PATTERNS = (
    r"\bDROP\s+DATABASE\b",
    r"\bDROP\s+TABLE\b",
    r"\bTRUNCATE\b",
    r"\bDELETE\s+FROM\b",
    r"\bCREATE\s+DATABASE\b",
    r"\bALTER\s+TABLE\b",
)
# Per-table columns whose (string) values are decoded with json.loads.
JSON_COLUMNS = {
    "case_base": {"symptom_tags", "disease_tags", "competency_tags", "guideline_tags", "knowledge_points", "multimodal_assets"},
    "scoring_rule": {"rubric_json"},
}
# Columns (in any table) parsed with datetime.fromisoformat when given as text.
DATETIME_COLUMNS = {"created_at", "updated_at"}
# Columns (in any table) converted to Decimal.
DECIMAL_COLUMNS = {"score_weight"}
# Columns (in any table) converted to int.
INT_COLUMNS = {
    "id",
    "difficulty_score",
    "patient_age",
    "estimated_minutes",
    "vector_status",
    "publish_status",
    "status",
    "created_by_id",
    "department_id",
    "case_id",
}
# Columns (in any table) converted with bool(int(value)).
BOOL_COLUMNS = {"osce_enabled", "rag_enabled", "ai_auto_score", "osce_dimension"}
class ImportValidationError(Exception):
    """Abort signal: the source SQL dump fails a safe-import requirement.

    ``main()`` catches this type and prints a JSON error payload instead of a
    traceback.
    """
@dataclass
class ImportReport:
    """Outcome of one import run: source facts, warnings, and write counters."""

    # Path of the SQL dump as given by the caller.
    source_path: str
    # Text encoding detected for the dump.
    encoding: str
    # Number of parsed rows per source table.
    table_rows: dict[str, int]
    # Non-fatal compatibility warnings collected while parsing.
    warnings: list[str]
    # Whether database writes were requested (False means dry run).
    applied: bool
    upserted_cases: int = 0
    upserted_traditional_cases: int = 0
    upserted_teaching_cases: int = 0
    replaced_scoring_rules: int = 0
    generated_exam_items: int = 0

    def as_dict(self) -> dict[str, Any]:
        """Return the report as a plain dict (keys in field declaration order).

        Suitable for JSON output on the command line or an API response.
        """
        ordered_keys = (
            "source_path",
            "encoding",
            "table_rows",
            "warnings",
            "applied",
            "upserted_cases",
            "upserted_traditional_cases",
            "upserted_teaching_cases",
            "replaced_scoring_rules",
            "generated_exam_items",
        )
        return {key: getattr(self, key) for key in ordered_keys}
def detect_encoding(path: Path) -> str:
    """Detect the text encoding of a SQL dump from its BOM or a trial decode.

    Returns:
        ``"utf-16"`` for a UTF-16 BOM (either byte order), ``"utf-8-sig"`` for a
        UTF-8 BOM, ``"utf-8"`` when the bytes decode as UTF-8, and ``"utf-16"`` for
        BOM-less data only when it decodes as UTF-16 *and* contains NUL bytes.

    Raises:
        ImportValidationError: if none of the above applies.
    """
    raw = path.read_bytes()
    if raw.startswith((b"\xff\xfe", b"\xfe\xff")):
        return "utf-16"
    if raw.startswith(b"\xef\xbb\xbf"):
        return "utf-8-sig"
    try:
        raw.decode("utf-8")
    except UnicodeDecodeError:
        pass
    else:
        return "utf-8"
    # A trial UTF-16 decode alone proves almost nothing: nearly every even-length
    # byte string "decodes", so e.g. a GBK dump used to be reported as UTF-16 and
    # turned into garbage. SQL dumps are mostly ASCII, which in UTF-16 always
    # contains NUL bytes, so require them before trusting the decode.
    if b"\x00" in raw:
        try:
            raw.decode("utf-16")
        except UnicodeDecodeError:
            pass
        else:
            return "utf-16"
    raise ImportValidationError("SQL 文件编码无法识别,请提供 UTF-8 或 UTF-16 文件。")
def load_sql_text(path: Path) -> tuple[str, str]:
    """Read the SQL dump at *path* and return ``(text, encoding)``.

    Raises:
        ImportValidationError: if *path* is not an existing regular file, or its
            encoding cannot be detected by ``detect_encoding``.
    """
    # is_file() rather than exists(): a directory passes exists() and then fails in
    # read_bytes() with an OSError, which main() does not catch.
    if not path.is_file():
        raise ImportValidationError(f"SQL 文件不存在:{path}")
    encoding = detect_encoding(path)
    return path.read_text(encoding=encoding), encoding
def find_dangerous_statements(sql_text: str) -> list[str]:
    r"""Return readable labels of the DANGEROUS_PATTERNS found in *sql_text*.

    Each label is the pattern with ``\b`` removed and ``\s+`` replaced by one
    space (e.g. ``"DROP TABLE"``), listed in DANGEROUS_PATTERNS order. Matching
    is case-insensitive and purely textual, so hits inside string data count too.
    """
    return [
        pattern.replace(r"\b", "").replace("\\s+", " ")
        for pattern in DANGEROUS_PATTERNS
        if re.search(pattern, sql_text, flags=re.IGNORECASE)
    ]
def extract_create_columns(sql_text: str, table_name: str) -> list[str]:
    """Return the column names of ``CREATE TABLE `<table_name>` (...)`` in source order.

    Only body lines that start with a backtick count as columns, so lines such
    as ``PRIMARY KEY``/``KEY``/``CONSTRAINT`` are skipped. Returns an empty list
    when no matching CREATE TABLE ... ENGINE= statement is found.
    """
    create_pattern = rf"CREATE TABLE `{re.escape(table_name)}` \((.*?)\) ENGINE="
    found = re.search(create_pattern, sql_text, flags=re.S | re.I)
    if found is None:
        return []
    body_lines = (raw_line.strip().rstrip(",") for raw_line in found.group(1).splitlines())
    return [line.split("`", 2)[1] for line in body_lines if line.startswith("`")]
def extract_insert_rows(sql_text: str, table_name: str, columns: list[str]) -> list[dict[str, Any]]:
    """Parse every ``INSERT INTO `<table_name>` VALUES ...;`` statement into row dicts.

    The dump is never executed: each VALUES tuple is parsed, its values are mapped
    to *columns* by position, and each value is passed through ``normalize_value``.

    Raises:
        ImportValidationError: if a tuple's arity differs from ``len(columns)`` or
            the VALUES clause is malformed.
    """
    # A statement ends at ")" + ";" at the end of a line (or of the input). The
    # previous pattern stopped at the first ";" anywhere, which truncated the clause
    # whenever a quoted text value contained a semicolon. Dumps are assumed to put
    # each INSERT statement's terminator at a line end (mysqldump/Navicat style).
    pattern = re.compile(
        rf"INSERT INTO `{re.escape(table_name)}` VALUES\s*(.*?\))\s*;[ \t]*(?=\r?\n|\Z)",
        flags=re.S | re.I,
    )
    rows: list[dict[str, Any]] = []
    for match in pattern.finditer(sql_text):
        for values in parse_values_clause(match.group(1)):
            if len(values) != len(columns):
                raise ImportValidationError(
                    f"{table_name} INSERT 字段数不匹配:期望 {len(columns)},实际 {len(values)}"
                )
            rows.append(
                {column: normalize_value(table_name, column, value) for column, value in zip(columns, values)}
            )
    return rows
def parse_values_clause(values_clause: str) -> list[list[Any]]:
    """Split a SQL VALUES clause into one value list per parenthesised tuple.

    Whitespace and commas between tuples are skipped; each tuple is parsed by
    ``_parse_tuple``.

    Raises:
        ImportValidationError: if anything other than a separator appears where a
            tuple should start, or a tuple itself is malformed.
    """
    parsed_rows: list[list[Any]] = []
    index = 0
    end = len(values_clause)
    while True:
        while index < end and (values_clause[index].isspace() or values_clause[index] == ","):
            index += 1
        if index >= end:
            return parsed_rows
        if values_clause[index] != "(":
            raise ImportValidationError(f"VALUES 子句格式错误:第 {index} 个字符不是 '('")
        row_values, index = _parse_tuple(values_clause, index)
        parsed_rows.append(row_values)
def _closes_tuple(text: str, index: int) -> bool:
    """Return True if the ")" at ``text[index]`` is followed by what can follow a VALUES tuple.

    Accepted continuations, after optional whitespace: end of *text*, ``;``, or
    ``,`` followed (after optional whitespace) by ``(``.
    """
    length = len(text)
    position = index + 1
    while position < length and text[position].isspace():
        position += 1
    if position >= length or text[position] == ";":
        return True
    if text[position] != ",":
        return False
    position += 1
    while position < length and text[position].isspace():
        position += 1
    return position < length and text[position] == "("


def _parse_tuple(text: str, start: int) -> tuple[list[Any], int]:
    """Parse one parenthesised tuple of SQL values; ``text[start]`` is its "(".

    Returns ``(values, next_index)`` where ``next_index`` is just past the closing
    ")". Quoted values are backslash-unescaped and kept as ``str``; unquoted tokens
    are converted by ``_coerce_raw_token``. Two malformed-dump patterns are
    repaired heuristically: a misplaced quote/comma
    (``_recover_misplaced_quote_separator``) and a string missing its closing
    quote before ``,<number>)`` (``_recover_unclosed_string_at_tuple_end``).

    Raises:
        ImportValidationError: on unquoted text directly before/after a string,
            unquoted textual tokens, or an unterminated tuple/string.
    """
    values: list[Any] = []
    token: list[str] = []
    in_string = False
    was_string = False  # the current field was written as a quoted string
    escaped = False  # the previous character inside a string was a backslash
    index = start + 1
    while index < len(text):
        char = text[index]
        if in_string:
            if escaped:
                token.append(_unescape_char(char))
                escaped = False
            elif char == "\\":
                escaped = True
            elif char == ")":
                # Only treat ")" inside a string as a possible tuple end when it is
                # followed by what can follow a tuple. Before this check, a valid
                # string such as '参考(4,10)' was split into "参考(4" and 10.
                recovered_at_end = None
                if _closes_tuple(text, index):
                    recovered_at_end = _recover_unclosed_string_at_tuple_end(token)
                if recovered_at_end:
                    string_value, raw_values = recovered_at_end
                    values.append(string_value)
                    values.extend(raw_values)
                    return values, index + 1
                token.append(char)
            elif char == "'":
                next_char = text[index + 1] if index + 1 < len(text) else ""
                recovered = _recover_misplaced_quote_separator(token, next_char)
                if recovered:
                    # The quote actually opens the next field: emit the repaired
                    # values and keep reading a new string.
                    string_value, raw_values = recovered
                    values.append(string_value)
                    values.extend(raw_values)
                    token = []
                    was_string = True
                    index += 1
                    continue
                in_string = False
            else:
                token.append(char)
            index += 1
            continue
        if char == "'":
            stripped = "".join(token).strip()
            if stripped:
                raise ImportValidationError(f"字符串前存在非法未引用内容:{stripped[:20]}")
            in_string = True
            was_string = True
            index += 1
            continue
        if was_string and char.isspace():
            index += 1
            continue
        if was_string and char not in {",", ")"}:
            raise ImportValidationError(f"字符串后存在非法未引用内容:{char}{text[index + 1:index + 20]}")
        if char == ",":
            values.append(_coerce_raw_token("".join(token), was_string))
            token = []
            was_string = False
            index += 1
            continue
        if char == ")":
            values.append(_coerce_raw_token("".join(token), was_string))
            return values, index + 1
        token.append(char)
        index += 1
    raise ImportValidationError("VALUES 子句存在未闭合括号或未闭合字符串。")
def _split_trailing_scalars(body: str) -> tuple[str, list[Any]]:
    """Peel comma-separated unquoted scalars (numbers/NULL) off the end of *body*.

    Returns the remaining text and the peeled values in their original order,
    e.g. ``"abc,1,NULL"`` -> ``("abc", [1, None])``.
    """
    trailing_values: list[Any] = []
    while True:
        head, separator, tail = body.rpartition(",")
        if not separator or not _looks_like_unquoted_scalar(tail):
            return body, trailing_values
        trailing_values.insert(0, _coerce_raw_token(tail, was_string=False))
        body = head


def _recover_misplaced_quote_separator(token: list[str], next_char: str) -> tuple[str, list[Any]] | None:
    """Repair a quote/comma mix-up of the form ``text,'next'`` inside a string.

    Called for a ``'`` found inside a string. If *next_char* (the character after
    that quote) is not a separator or whitespace and the string so far ends with
    ",", the quote is taken to open the following field. Returns
    ``(text, trailing_values)``: the string before the final comma, with any
    trailing unquoted numbers/NULLs split off into ``trailing_values``. Returns
    None when the pattern does not apply or no text would remain.
    """
    if not next_char or next_char in {",", ")"} or next_char.isspace():
        return None
    raw = "".join(token)
    if not raw.endswith(","):
        return None
    body, trailing_values = _split_trailing_scalars(raw[:-1])
    if not body:
        return None
    return body, trailing_values


def _recover_unclosed_string_at_tuple_end(token: list[str]) -> tuple[str, list[Any]] | None:
    """Repair a string that lost its closing quote before ``,<number>)``.

    Called for a ``)`` found inside a string. Splits trailing unquoted
    numbers/NULLs off the string so far and returns ``(text, trailing_values)``
    when at least one was found and some text remains; otherwise None.
    """
    body, trailing_values = _split_trailing_scalars("".join(token))
    if not body or not trailing_values:
        return None
    return body, trailing_values
def _looks_like_unquoted_scalar(value: str) -> bool:
"""兼容解析:判断错位引号前夹带的字段是否是可安全恢复的未引用标量。"""
stripped = value.strip()
if not stripped:
return False
if stripped.upper() == "NULL":
return True
return bool(re.fullmatch(r"-?\d+(?:\.\d+)?", stripped))
def _unescape_char(char: str) -> str:
"""SQL 转义:处理 dump 中的常见反斜杠转义。"""
return {
"0": "\0",
"b": "\b",
"n": "\n",
"r": "\r",
"t": "\t",
"Z": "\x1a",
"\\": "\\",
"'": "'",
'"': '"',
}.get(char, char)
def _coerce_raw_token(raw: str, was_string: bool) -> Any:
"""原始字段转换:区分 SQL NULL、数字和字符串。"""
if was_string:
return raw
value = raw.strip()
if not value:
return ""
if value.upper() == "NULL":
return None
if re.fullmatch(r"-?\d+", value):
return int(value)
if re.fullmatch(r"-?\d+\.\d+", value):
return Decimal(value)
if re.search(r"[A-Za-z\u4e00-\u9fff]", value):
raise ImportValidationError(f"发现未引用文本字段:{value[:30]}")
return value
def normalize_value(table_name: str, column: str, value: Any) -> Any:
    """Convert one parsed value into the Python type its ORM column expects.

    The column kind is looked up in JSON_COLUMNS (per table), DATETIME_COLUMNS
    (string values only), BOOL_COLUMNS, INT_COLUMNS and DECIMAL_COLUMNS, in that
    order. ``None`` and values of unlisted columns are returned unchanged.

    Raises:
        ImportValidationError: if a value cannot be converted to its column type.
    """
    if value is None:
        return None
    if column in JSON_COLUMNS.get(table_name, set()):
        if isinstance(value, (list, dict)):
            return value
        try:
            return json.loads(value)
        except (TypeError, json.JSONDecodeError) as exc:
            raise ImportValidationError(f"{table_name}.{column} 不是合法 JSON。") from exc
    if column in DATETIME_COLUMNS and isinstance(value, str):
        try:
            return datetime.fromisoformat(value)
        except ValueError as exc:
            raise ImportValidationError(f"{table_name}.{column} 不是合法 datetime。") from exc
    # int()/Decimal() raise ValueError, TypeError or decimal.InvalidOperation (an
    # ArithmeticError) on inputs such as "" or "abc". Report them as validation
    # errors so main() prints its JSON error instead of an uncaught traceback.
    try:
        if column in BOOL_COLUMNS:
            return bool(int(value))
        if column in INT_COLUMNS:
            return int(value)
        if column in DECIMAL_COLUMNS:
            return Decimal(str(value))
    except (TypeError, ValueError, ArithmeticError) as exc:
        raise ImportValidationError(f"{table_name}.{column} 不是合法数值。") from exc
    return value
def parse_source_dump(path: Path) -> tuple[dict[str, list[dict[str, Any]]], list[str], str]:
    """Parse the importable tables of a SQL dump without executing it.

    Returns:
        ``(parsed, warnings, encoding)`` where ``parsed`` maps table name to its
        normalized row dicts. Tables in SOURCE_TABLES must have a CREATE TABLE
        statement. Missing OPTIONAL_SOURCE_TABLES, dangerous statements, and
        absent extension or scoring data only add warnings.

    Raises:
        ImportValidationError: if a required table definition is missing or the
            file cannot be read or parsed.
    """
    sql_text, encoding = load_sql_text(path)
    warnings: list[str] = []

    dangerous = find_dangerous_statements(sql_text)
    if dangerous:
        labels = ", ".join(sorted(set(dangerous)))
        warnings.append("源 SQL 包含 DDL/DML 覆盖语句,导入器会忽略这些语句:" + labels)

    parsed: dict[str, list[dict[str, Any]]] = {}
    # Required tables are visited first, so a missing one aborts before any optional work.
    for table_name in (*SOURCE_TABLES, *OPTIONAL_SOURCE_TABLES):
        columns = extract_create_columns(sql_text, table_name)
        if columns:
            parsed[table_name] = extract_insert_rows(sql_text, table_name, columns)
        elif table_name in SOURCE_TABLES:
            raise ImportValidationError(f"源 SQL 缺少必需表结构:{table_name}")
        else:
            warnings.append(f"源 SQL 未包含 {table_name},导入器会按当前业务规则处理。")

    if not (parsed.get("traditional_case") or parsed.get("teaching_case")):
        warnings.append("源 SQL 未包含 traditional_case 或 teaching_case,病例导入后暂时缺少训练模式扩展数据。")
    if not parsed.get("scoring_rule"):
        warnings.append("源 SQL 未包含 scoring_rule,评价时将缺少接口侧基础评分规则。")
    return parsed, warnings, encoding
def validate_target_schema(parsed: dict[str, list[dict[str, Any]]]) -> None:
    """Check that every parsed table and row column exists in the target database.

    Raises:
        ImportValidationError: if a target table is missing, or a row carries a
            column the target table does not define.
    """
    with SessionLocal() as db:
        inspector = inspect(db.bind)
        for table_name, rows in parsed.items():
            if not inspector.has_table(table_name):
                raise ImportValidationError(f"当前数据库缺少目标表:{table_name}")
            known_columns = {column_info["name"] for column_info in inspector.get_columns(table_name)}
            for row in rows:
                extra_columns = sorted(row.keys() - known_columns)
                if extra_columns:
                    raise ImportValidationError(f"{table_name} 存在当前库不支持的字段:{extra_columns}")
def import_source_sql(path: Path, apply: bool = False, generate_exam_items: bool = True) -> ImportReport:
    """Parse a source SQL dump and write it to the database only when *apply* is true.

    The dump is always parsed and checked against the target schema. With
    ``apply=False`` nothing is written (dry run). With ``apply=True``, cases are
    upserted and the optional extension tables present in the dump are synced.
    Exam items are generated unless *generate_exam_items* is false. The session is
    committed once; any exception rolls it back and propagates.

    Returns:
        ImportReport with per-table row counts, warnings, and write counters.
    """
    parsed, warnings, encoding = parse_source_dump(path)
    validate_target_schema(parsed)
    row_counts = {table: len(table_rows) for table, table_rows in parsed.items()}
    report = ImportReport(
        source_path=str(path),
        encoding=encoding,
        table_rows=row_counts,
        warnings=warnings,
        applied=apply,
    )
    if not apply:
        return report

    case_rows = parsed.get("case_base", [])
    with SessionLocal() as db:
        try:
            report.upserted_cases = _upsert_cases(db, case_rows)
            if "traditional_case" in parsed:
                report.upserted_traditional_cases = _sync_traditional_cases(db, case_rows, parsed["traditional_case"])
            if "teaching_case" in parsed:
                report.upserted_teaching_cases = _sync_teaching_cases(db, case_rows, parsed["teaching_case"])
            if "scoring_rule" in parsed:
                report.replaced_scoring_rules = _replace_scoring_rules(db, case_rows, parsed["scoring_rule"])
            if generate_exam_items:
                report.generated_exam_items = _upsert_generated_exam_items(db, case_rows)
            db.commit()
        except Exception:
            db.rollback()
            raise
    return report
def _upsert_cases(db, rows: list[dict[str, Any]]) -> int:
    """Insert or update CaseBase rows keyed by ``id``; return how many rows were processed.

    Every imported case is forced to ``status = 1`` and ``publish_status = 1``.
    The input dicts are not modified.
    """
    for source_row in rows:
        values = {**source_row, "status": 1, "publish_status": 1}
        existing = db.get(CaseBase, values["id"])
        if not existing:
            db.add(CaseBase(**values))
            continue
        for attribute, attribute_value in values.items():
            setattr(existing, attribute, attribute_value)
    return len(rows)
def _sync_traditional_cases(db, case_rows: list[dict[str, Any]], rows: list[dict[str, Any]]) -> int:
    """Replace the TraditionalCase rows of the imported cases with the dump's rows.

    Existing rows whose ``case_id`` belongs to an imported case are deleted first,
    so no stale rows remain. Then every source row is inserted; source rows are not
    filtered by ``case_id``. Returns the number of inserted rows.
    """
    imported_ids = _imported_case_ids(case_rows)
    if imported_ids:
        db.execute(delete(TraditionalCase).where(TraditionalCase.case_id.in_(imported_ids)))
    db.add_all(TraditionalCase(**source_row) for source_row in rows)
    return len(rows)
def _sync_teaching_cases(db, case_rows: list[dict[str, Any]], rows: list[dict[str, Any]]) -> int:
    """Replace the TeachingCase rows of the imported cases with the dump's rows.

    The dump is authoritative when it provides teaching_case. Existing rows for the
    imported case ids are deleted, then every source row is inserted (not filtered
    by ``case_id``). Returns the number of inserted rows.
    """
    imported_ids = _imported_case_ids(case_rows)
    if imported_ids:
        db.execute(delete(TeachingCase).where(TeachingCase.case_id.in_(imported_ids)))
    db.add_all(TeachingCase(**source_row) for source_row in rows)
    return len(rows)
def _replace_scoring_rules(db, case_rows: list[dict[str, Any]], rows: list[dict[str, Any]]) -> int:
    """Replace the ScoringRule rows of the imported cases with the dump's rows.

    Existing rules of the imported cases are removed with a single
    ``case_id IN (...)`` DELETE instead of one DELETE per case, the same way the
    traditional/teaching case sync does it. Then every source rule is inserted.
    Returns the number of inserted rules.
    """
    case_ids = _imported_case_ids(case_rows)
    if case_ids:
        db.execute(delete(ScoringRule).where(ScoringRule.case_id.in_(case_ids)))
    for row in rows:
        db.add(ScoringRule(**row))
    return len(rows)
def _upsert_generated_exam_items(db, case_rows: list[dict[str, Any]]) -> int:
    """Insert or update the generated CaseExamItem rows for each imported case.

    Items come from ``build_exam_items_from_case`` and are matched on
    ``(case_id, item_code)``. Existing items that are not regenerated are left in
    place, so exam items referenced by past training are not deleted. Returns the
    number of items inserted or updated.
    """
    total = 0
    for case_row in case_rows:
        case_id = case_row["id"]
        for item in build_exam_items_from_case(case_row):
            lookup = select(CaseExamItem).where(
                CaseExamItem.case_id == case_id,
                CaseExamItem.item_code == item["item_code"],
            )
            existing = db.scalar(lookup)
            if existing:
                for field_name, field_value in item.items():
                    setattr(existing, field_name, field_value)
            else:
                db.add(CaseExamItem(case_id=case_id, **item))
            total += 1
    return total
def _imported_case_ids(case_rows: list[dict[str, Any]]) -> list[int]:
"""导入范围:提取本次 SQL 涉及的病例 ID,用于同步扩展表。"""
return sorted({int(row["id"]) for row in case_rows if row.get("id") is not None})
def build_exam_items_from_case(case_row: dict[str, Any]) -> list[dict[str, Any]]:
"""检查项目生成:从病例文本识别常见检查结果,保障 Demo 可继续申请检查。"""
text = " ".join(str(case_row.get(key) or "") for key in ("chief_complaint", "description", "knowledge_points"))
candidates = [
(
"blood_routine",
"血常规",
"lab",
"实验室检查",
_find_result(text, r"(WBC[^,。;;]*[,]\s*[^,。;;]*(?:中性|neutrophil)[^,。;;]*)") or "血常规结果见病例资料。",
{"source": "case_description"},
"WBC" in text or "血常规" in text,
),
(
"crp",
"CRP",
"lab",
"实验室检查",
_find_result(text, r"(CRP\s*[^,。;;]*)") or "CRP 结果见病例资料。",
{"source": "case_description"},
"CRP" in text.upper(),
),
(
"chest_xray",
"胸片",
"imaging",
"影像检查",
_find_result(text, r"(胸片[^。;;]*)") or "胸片结果见病例资料。",
{"source": "case_description"},
"胸片" in text or "X线" in text,
),
(
"spo2",
"血氧饱和度",
"vital_sign",
"生命体征",
_find_result(text, r"(SpO2\s*\d+%?)") or "血氧饱和度结果见病例资料。",
{"source": "case_description"},
"SpO2" in text or "血氧" in text,
),
(
"chest_auscultation",
"肺部体格检查",
"physical_exam",
"体格检查",
_find_result(text, r"(肺[^。;;]*(?:湿啰音|哮鸣音|呼吸音)[^。;;]*)") or "肺部体格检查结果见病例资料。",
{"source": "case_description"},
"湿啰音" in text or "哮鸣音" in text or "" in text,
),
(
"mp_igm",
"肺炎支原体抗体IgM",
"lab",
"实验室检查",
_find_result(text, r"(肺炎支原体抗体IgM[^,。;;]*)") or "肺炎支原体抗体IgM结果见病例资料。",
{"source": "case_description"},
"肺炎支原体抗体IgM" in text or "支原体" in text,
),
]
items: list[dict[str, Any]] = []
for index, (code, name, item_type, category, result_text, structured, detected) in enumerate(candidates, start=1):
if detected:
items.append(
{
"item_code": code,
"item_name": name,
"item_type": item_type,
"category": category,
"result_text": result_text,
"result_structured": structured,
"is_key": True,
"is_abnormal": True,
"score_weight": Decimal("5.00"),
"display_order": index,
}
)
if items:
return items
return [
{
"item_code": "basic_exam",
"item_name": "基础检查",
"item_type": "physical_exam",
"category": "体格检查",
"result_text": "基础检查结果见病例资料。",
"result_structured": {"source": "case_description"},
"is_key": True,
"is_abnormal": False,
"score_weight": Decimal("1.00"),
"display_order": 1,
}
]
def _find_result(text: str, pattern: str) -> str | None:
"""文本抽取:从病例描述中抽取简短检查结果。"""
match = re.search(pattern, text, flags=re.I)
return match.group(1).strip() if match else None
def main() -> None:
    """Command-line entry point: check (default) or import (``--apply``) a source SQL dump.

    Prints a JSON document containing ``ok`` plus either ``report`` or ``error``.
    Exits with status 2 when an ImportValidationError is raised.
    """
    parser = argparse.ArgumentParser(description="Safely import parsed case SQL into current medical agent schema.")
    parser.add_argument("sql_path", type=Path, help="接口提供的 SQL dump 文件路径")
    parser.add_argument("--apply", action="store_true", help="确认写入当前数据库;默认只检查不写入")
    parser.add_argument("--no-generate-exam-items", action="store_true", help="不自动补齐 case_exam_item")
    options = parser.parse_args()
    try:
        report = import_source_sql(
            options.sql_path,
            apply=options.apply,
            generate_exam_items=not options.no_generate_exam_items,
        )
    except ImportValidationError as exc:
        failure = {"ok": False, "error": str(exc)}
        print(json.dumps(failure, ensure_ascii=False, indent=2))
        raise SystemExit(2) from exc
    success = {"ok": True, "report": report.as_dict()}
    # default=str renders datetime/Decimal and other non-JSON values as text.
    print(json.dumps(success, ensure_ascii=False, indent=2, default=str))


if __name__ == "__main__":
    main()