# Listing metadata: 613 lines, 23 KiB, Python
from __future__ import annotations
|
||
|
||
import argparse
|
||
import json
|
||
import re
|
||
import sys
|
||
from dataclasses import dataclass
|
||
from datetime import datetime
|
||
from decimal import Decimal
|
||
from pathlib import Path
|
||
from typing import Any
|
||
|
||
from sqlalchemy import delete, inspect, select
|
||
|
||
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
|
||
|
||
from app.db.session import SessionLocal
|
||
from app.models.source_case import CaseBase, CaseExamItem, ScoringRule, TeachingCase, TraditionalCase
|
||
|
||
|
||
# Tables whose CREATE TABLE definition must be present in the source dump.
SOURCE_TABLES = ("case_base",)

# Tables that are imported when present; a missing one only produces a warning.
OPTIONAL_SOURCE_TABLES = (
    "traditional_case",
    "teaching_case",
    "scoring_rule",
    "case_exam_item",
)

# Regex patterns for destructive or schema-changing statements that get reported.
DANGEROUS_PATTERNS = (
    r"\bDROP\s+DATABASE\b",
    r"\bDROP\s+TABLE\b",
    r"\bTRUNCATE\b",
    r"\bDELETE\s+FROM\b",
    r"\bCREATE\s+DATABASE\b",
    r"\bALTER\s+TABLE\b",
)

# Per-table columns whose dump value is JSON text that must be decoded.
JSON_COLUMNS = {
    "case_base": {
        "symptom_tags",
        "disease_tags",
        "competency_tags",
        "guideline_tags",
        "knowledge_points",
        "multimodal_assets",
    },
    "scoring_rule": {"rubric_json"},
}

# Column names converted to datetime / Decimal regardless of table.
DATETIME_COLUMNS = {"created_at", "updated_at"}
DECIMAL_COLUMNS = {"score_weight"}

# Column names coerced to int regardless of table.
INT_COLUMNS = {
    "id",
    "difficulty_score",
    "patient_age",
    "estimated_minutes",
    "vector_status",
    "publish_status",
    "status",
    "created_by_id",
    "department_id",
    "case_id",
}

# Column names coerced to bool (via int) regardless of table.
BOOL_COLUMNS = {
    "osce_enabled",
    "rag_enabled",
    "ai_auto_score",
    "osce_dimension",
}
|
||
|
||
|
||
class ImportValidationError(Exception):
    """Import validation failure: raised to abort when the source SQL is not safe to import."""
|
||
|
||
|
||
@dataclass
class ImportReport:
    """Import report: records what was inspected, written and flagged during one run."""

    # Path of the SQL dump, stored as a string.
    source_path: str
    # Encoding name used to read the dump.
    encoding: str
    # Number of parsed rows per source table.
    table_rows: dict[str, int]
    # Human-readable compatibility warnings.
    warnings: list[str]
    # Whether the caller asked for the data to be written.
    applied: bool
    # Counters filled in by the write phase; they stay 0 on a dry run.
    upserted_cases: int = 0
    upserted_traditional_cases: int = 0
    upserted_teaching_cases: int = 0
    replaced_scoring_rules: int = 0
    generated_exam_items: int = 0

    def as_dict(self) -> dict[str, Any]:
        """Return the report as a plain dict readable by both the front end and the CLI."""
        exported = (
            "source_path",
            "encoding",
            "table_rows",
            "warnings",
            "applied",
            "upserted_cases",
            "upserted_traditional_cases",
            "upserted_teaching_cases",
            "replaced_scoring_rules",
            "generated_exam_items",
        )
        # Values are passed through by reference, not copied.
        return {key: getattr(self, key) for key in exported}
|
||
|
||
|
||
def detect_encoding(path: Path) -> str:
    """Detect the encoding of a SQL file from its BOM, falling back to trial decoding.

    Args:
        path: SQL file to inspect; its bytes are read fully into memory.

    Returns:
        ``"utf-16"`` or ``"utf-8-sig"`` when the corresponding BOM is present,
        ``"utf-8"`` when the bytes decode as UTF-8, or ``"utf-16"`` for
        plausible BOM-less UTF-16 content.

    Raises:
        ImportValidationError: If the content matches none of the above.
    """
    raw = path.read_bytes()
    if raw.startswith((b"\xff\xfe", b"\xfe\xff")):
        return "utf-16"
    if raw.startswith(b"\xef\xbb\xbf"):
        return "utf-8-sig"

    try:
        raw.decode("utf-8")
    except UnicodeDecodeError:
        pass
    else:
        return "utf-8"

    # A BOM-less UTF-16 trial decode succeeds for almost any even-length byte
    # string, so legacy encodings such as GBK would be misreported as UTF-16
    # and read back as garbage. Real UTF-16 SQL always contains NUL bytes
    # (every ASCII keyword character has a zero byte), legacy multi-byte
    # encodings do not, so require one before trusting the decode.
    if b"\x00" in raw:
        try:
            raw.decode("utf-16")
        except UnicodeDecodeError:
            pass
        else:
            return "utf-16"

    raise ImportValidationError("SQL 文件编码无法识别,请提供 UTF-8 或 UTF-16 文件。")
|
||
|
||
|
||
def load_sql_text(path: Path) -> tuple[str, str]:
    """Read the supplied SQL dump and return ``(text, encoding)``.

    Raises:
        ImportValidationError: If the file does not exist (or if encoding
            detection rejects it).
    """
    if path.exists():
        detected = detect_encoding(path)
        content = path.read_text(encoding=detected)
        return content, detected
    raise ImportValidationError(f"SQL 文件不存在:{path}")
|
||
|
||
|
||
def find_dangerous_statements(sql_text: str) -> list[str]:
    """Safety scan: list the DDL/DML statement kinds found in the dump that must not run on the live DB.

    Each hit is the matching pattern rendered readable (``\\b`` removed,
    ``\\s+`` shown as a space), in ``DANGEROUS_PATTERNS`` order, at most once
    per pattern.
    """
    return [
        pattern.replace(r"\b", "").replace("\\s+", " ")
        for pattern in DANGEROUS_PATTERNS
        if re.search(pattern, sql_text, flags=re.IGNORECASE)
    ]
|
||
|
||
|
||
def extract_create_columns(sql_text: str, table_name: str) -> list[str]:
    """Column extraction: read column names from CREATE TABLE in source order.

    Returns an empty list when no ``CREATE TABLE `<table_name>` ( ... ) ENGINE=``
    statement is found.
    """
    ddl = re.search(
        rf"CREATE TABLE `{re.escape(table_name)}` \((.*?)\) ENGINE=",
        sql_text,
        flags=re.S | re.I,
    )
    if ddl is None:
        return []
    definitions = (line.strip().rstrip(",") for line in ddl.group(1).splitlines())
    # Only lines that start with a backtick are column definitions; keys and
    # constraints start with a keyword instead.
    return [definition.split("`", 2)[1] for definition in definitions if definition.startswith("`")]
|
||
|
||
|
||
def extract_insert_rows(sql_text: str, table_name: str, columns: list[str]) -> list[dict[str, Any]]:
    """Data extraction: parse ``INSERT INTO `<table>` VALUES`` data without executing any SQL.

    The statement end is found by walking the tuples with ``_parse_tuple``
    and stopping at a ``;`` that sits *between* tuples. A lazy
    "everything up to the first semicolon" regex would cut the clause at a
    semicolon inside a quoted string, which is common in free text.
    An INSERT that reaches end-of-text without ``;`` is still accepted.

    Args:
        sql_text: Full text of the dump.
        table_name: Table whose INSERT statements are collected.
        columns: Column names in source order; every tuple must have this length.

    Returns:
        One dict per tuple, mapping column name to its normalized value.

    Raises:
        ImportValidationError: On malformed VALUES syntax or when a tuple's
            field count differs from ``len(columns)``.
    """
    header = re.compile(rf"INSERT INTO `{re.escape(table_name)}` VALUES", flags=re.I)
    rows: list[dict[str, Any]] = []
    length = len(sql_text)
    position = 0
    while True:
        match = header.search(sql_text, position)
        if match is None:
            break
        cursor = match.end()
        while cursor < length and sql_text[cursor].isspace():
            cursor += 1
        # Offsets in error messages are relative to the start of the clause.
        clause_start = cursor
        while cursor < length:
            char = sql_text[cursor]
            if char == ";":
                cursor += 1
                break
            if char == "," or char.isspace():
                cursor += 1
                continue
            if char != "(":
                raise ImportValidationError(
                    f"VALUES 子句格式错误:第 {cursor - clause_start} 个字符不是 '('。"
                )
            # The tuple parser is quote-aware, so semicolons inside strings
            # are consumed as data instead of ending the statement.
            values, cursor = _parse_tuple(sql_text, cursor)
            if len(values) != len(columns):
                raise ImportValidationError(
                    f"{table_name} INSERT 字段数不匹配:期望 {len(columns)},实际 {len(values)}。"
                )
            rows.append(
                {column: normalize_value(table_name, column, value) for column, value in zip(columns, values)}
            )
        position = cursor
    return rows
|
||
|
||
|
||
def parse_values_clause(values_clause: str) -> list[list[Any]]:
    """VALUES parsing: turn a SQL VALUES clause into a list of rows of field values.

    Whitespace and commas between tuples are skipped; anything else that is
    not the start of a tuple raises ``ImportValidationError``. Damaged strings
    inside a tuple are reported by the tuple parser.
    """
    parsed: list[list[Any]] = []
    cursor = 0
    end = len(values_clause)
    while cursor < end:
        current = values_clause[cursor]
        if current.isspace() or current == ",":
            cursor += 1
        elif current == "(":
            fields, cursor = _parse_tuple(values_clause, cursor)
            parsed.append(fields)
        else:
            raise ImportValidationError(f"VALUES 子句格式错误:第 {cursor} 个字符不是 '('。")
    return parsed
|
||
|
||
|
||
def _parse_tuple(text: str, start: int) -> tuple[list[Any], int]:
    """Tuple parsing: parse the SQL field values inside one pair of parentheses.

    Args:
        text: Text containing the tuple.
        start: Index of the tuple's opening "(" in ``text``.

    Returns:
        ``(values, next_index)`` where ``next_index`` is the position just
        past the ")" that ended the tuple.

    Raises:
        ImportValidationError: On unquoted content adjacent to a string, on
            unquoted text rejected by ``_coerce_raw_token``, or when ``text``
            ends before the tuple is closed.
    """
    values: list[Any] = []
    # Characters of the field currently being read (quoted or unquoted).
    token: list[str] = []
    in_string = False
    # True once the current field has been seen to be a quoted string.
    was_string = False
    escaped = False
    # Skip the opening "(".
    index = start + 1
    while index < len(text):
        char = text[index]
        if in_string:
            if escaped:
                # Character after a backslash: translate the escape sequence.
                token.append(_unescape_char(char))
                escaped = False
            elif char == "\\":
                escaped = True
            elif char == ")":
                # A ")" inside a string is normally literal text, but it may be
                # the real end of the tuple if the closing quote is missing.
                # The helper only recovers when the text ends in `,<scalar>`.
                # NOTE(review): a legitimate string such as 'a (1,2) b' would
                # also satisfy that heuristic - confirm this is acceptable.
                recovered_at_end = _recover_unclosed_string_at_tuple_end(token)
                if recovered_at_end:
                    string_value, raw_values = recovered_at_end
                    values.append(string_value)
                    values.extend(raw_values)
                    return values, index + 1
                token.append(char)
            elif char == "'":
                # Either the closing quote, or (when the helper recovers) a
                # misplaced quote that actually opens the NEXT string field.
                recovered = _recover_misplaced_quote_separator(token, text[index + 1] if index + 1 < len(text) else "")
                if recovered:
                    string_value, raw_values = recovered
                    values.append(string_value)
                    values.extend(raw_values)
                    # Stay in string mode: the following characters belong to
                    # the next quoted field.
                    token = []
                    was_string = True
                    index += 1
                    continue
                in_string = False
            else:
                token.append(char)
            index += 1
            continue

        if char == "'":
            # An opening quote must not be preceded by unquoted content.
            # `token` still holds the content of a string that was just closed,
            # so a second quoted literal in the same field is rejected here too.
            stripped = "".join(token).strip()
            if stripped:
                raise ImportValidationError(f"字符串前存在非法未引用内容:{stripped[:20]}")
            in_string = True
            was_string = True
            index += 1
            continue
        if was_string and char.isspace():
            # Whitespace between a closed string and the next separator is ignored.
            index += 1
            continue
        if was_string and char not in {",", ")"}:
            raise ImportValidationError(f"字符串后存在非法未引用内容:{char}{text[index + 1:index + 20]}")
        if char == ",":
            # Field separator: emit the finished field and reset per-field state.
            values.append(_coerce_raw_token("".join(token), was_string))
            token = []
            was_string = False
            index += 1
            continue
        if char == ")":
            # End of tuple: emit the last field and report where parsing stopped.
            values.append(_coerce_raw_token("".join(token), was_string))
            return values, index + 1
        token.append(char)
        index += 1
    raise ImportValidationError("VALUES 子句存在未闭合括号或未闭合字符串。")
|
||
|
||
|
||
def _recover_misplaced_quote_separator(token: list[str], next_char: str) -> tuple[str, list[Any]] | None:
|
||
"""兼容解析:修复接口 SQL 文本字段结尾写成 `文本,'next'` 的引号/逗号错位。"""
|
||
if not next_char or next_char in {",", ")"} or next_char.isspace():
|
||
return None
|
||
raw = "".join(token)
|
||
if not raw.endswith(","):
|
||
return None
|
||
|
||
body = raw[:-1]
|
||
trailing_values: list[Any] = []
|
||
while True:
|
||
head, separator, tail = body.rpartition(",")
|
||
if not separator or not _looks_like_unquoted_scalar(tail):
|
||
break
|
||
trailing_values.insert(0, _coerce_raw_token(tail, was_string=False))
|
||
body = head
|
||
if not body:
|
||
return None
|
||
return body, trailing_values
|
||
|
||
|
||
def _recover_unclosed_string_at_tuple_end(token: list[str]) -> tuple[str, list[Any]] | None:
|
||
"""兼容解析:修复文本字段缺少结束引号且后接 `,数字)` 的源 SQL。"""
|
||
body = "".join(token)
|
||
trailing_values: list[Any] = []
|
||
while True:
|
||
head, separator, tail = body.rpartition(",")
|
||
if not separator or not _looks_like_unquoted_scalar(tail):
|
||
break
|
||
trailing_values.insert(0, _coerce_raw_token(tail, was_string=False))
|
||
body = head
|
||
if not body or not trailing_values:
|
||
return None
|
||
return body, trailing_values
|
||
|
||
|
||
def _looks_like_unquoted_scalar(value: str) -> bool:
|
||
"""兼容解析:判断错位引号前夹带的字段是否是可安全恢复的未引用标量。"""
|
||
stripped = value.strip()
|
||
if not stripped:
|
||
return False
|
||
if stripped.upper() == "NULL":
|
||
return True
|
||
return bool(re.fullmatch(r"-?\d+(?:\.\d+)?", stripped))
|
||
|
||
|
||
def _unescape_char(char: str) -> str:
|
||
"""SQL 转义:处理 dump 中的常见反斜杠转义。"""
|
||
return {
|
||
"0": "\0",
|
||
"b": "\b",
|
||
"n": "\n",
|
||
"r": "\r",
|
||
"t": "\t",
|
||
"Z": "\x1a",
|
||
"\\": "\\",
|
||
"'": "'",
|
||
'"': '"',
|
||
}.get(char, char)
|
||
|
||
|
||
def _coerce_raw_token(raw: str, was_string: bool) -> Any:
|
||
"""原始字段转换:区分 SQL NULL、数字和字符串。"""
|
||
if was_string:
|
||
return raw
|
||
value = raw.strip()
|
||
if not value:
|
||
return ""
|
||
if value.upper() == "NULL":
|
||
return None
|
||
if re.fullmatch(r"-?\d+", value):
|
||
return int(value)
|
||
if re.fullmatch(r"-?\d+\.\d+", value):
|
||
return Decimal(value)
|
||
if re.search(r"[A-Za-z\u4e00-\u9fff]", value):
|
||
raise ImportValidationError(f"发现未引用文本字段:{value[:30]}")
|
||
return value
|
||
|
||
|
||
def normalize_value(table_name: str, column: str, value: Any) -> Any:
    """Field normalization: convert JSON, datetime, boolean and numeric columns to ORM-ready values.

    Args:
        table_name: Source table; selects the table-specific JSON columns.
        column: Column name; selects the conversion to apply.
        value: Parsed field value (``None`` for SQL NULL).

    Returns:
        ``None`` unchanged; decoded JSON for JSON columns; ``datetime`` for
        string values of datetime columns; ``bool`` / ``int`` / ``Decimal``
        for the respective column sets; otherwise ``value`` as-is.

    Raises:
        ImportValidationError: If the value cannot be converted. Numeric and
            boolean failures are wrapped like the JSON and datetime ones, so
            callers that only handle this exception type see every failure.
    """
    if value is None:
        return None
    if column in JSON_COLUMNS.get(table_name, set()):
        if isinstance(value, (list, dict)):
            return value
        try:
            return json.loads(value)
        except (TypeError, json.JSONDecodeError) as exc:
            raise ImportValidationError(f"{table_name}.{column} 不是合法 JSON。") from exc
    if column in DATETIME_COLUMNS and isinstance(value, str):
        try:
            return datetime.fromisoformat(value)
        except ValueError as exc:
            raise ImportValidationError(f"{table_name}.{column} 不是合法 datetime。") from exc
    try:
        if column in BOOL_COLUMNS:
            # Booleans arrive as 0/1 integers (possibly quoted).
            return bool(int(value))
        if column in INT_COLUMNS:
            return int(value)
        if column in DECIMAL_COLUMNS:
            return Decimal(str(value))
    except (TypeError, ValueError, ArithmeticError) as exc:
        # ArithmeticError covers decimal.InvalidOperation and OverflowError.
        raise ImportValidationError(f"{table_name}.{column} 不是合法数值:{value!r}") from exc
    return value
|
||
|
||
|
||
def parse_source_dump(path: Path) -> tuple[dict[str, list[dict[str, Any]]], list[str], str]:
    """Source parsing: extract importable table data and return compatibility warnings.

    Returns:
        ``(parsed, warnings, encoding)`` where ``parsed`` maps each table found
        in the dump to its rows.

    Raises:
        ImportValidationError: If a required table has no CREATE TABLE
            definition, or if reading or parsing the dump fails.
    """
    sql_text, encoding = load_sql_text(path)
    warnings: list[str] = []

    # Dangerous statements are only reported; nothing from the dump is executed.
    flagged = find_dangerous_statements(sql_text)
    if flagged:
        warnings.append("源 SQL 包含 DDL/DML 覆盖语句,导入器会忽略这些语句:" + ", ".join(sorted(set(flagged))))

    parsed: dict[str, list[dict[str, Any]]] = {}
    # Required tables are processed first, then the optional ones.
    for table_name in (*SOURCE_TABLES, *OPTIONAL_SOURCE_TABLES):
        columns = extract_create_columns(sql_text, table_name)
        if columns:
            parsed[table_name] = extract_insert_rows(sql_text, table_name, columns)
        elif table_name in SOURCE_TABLES:
            raise ImportValidationError(f"源 SQL 缺少必需表结构:{table_name}")
        else:
            warnings.append(f"源 SQL 未包含 {table_name},导入器会按当前业务规则处理。")

    has_mode_data = parsed.get("traditional_case") or parsed.get("teaching_case")
    if not has_mode_data:
        warnings.append("源 SQL 未包含 traditional_case 或 teaching_case,病例导入后暂时缺少训练模式扩展数据。")
    if not parsed.get("scoring_rule"):
        warnings.append("源 SQL 未包含 scoring_rule,评价时将缺少接口侧基础评分规则。")
    return parsed, warnings, encoding
|
||
|
||
|
||
def validate_target_schema(parsed: dict[str, list[dict[str, Any]]]) -> None:
    """Target schema check: confirm every source column maps onto the current database tables.

    Raises:
        ImportValidationError: If a target table is missing, or a row carries
            a column the target table does not have.
    """
    with SessionLocal() as db:
        inspector = inspect(db.bind)
        for table_name, rows in parsed.items():
            if not inspector.has_table(table_name):
                raise ImportValidationError(f"当前数据库缺少目标表:{table_name}")
            known = {column["name"] for column in inspector.get_columns(table_name)}
            for row in rows:
                unsupported = sorted(key for key in row if key not in known)
                if unsupported:
                    raise ImportValidationError(f"{table_name} 存在当前库不支持的字段:{unsupported}")
|
||
|
||
|
||
def import_source_sql(path: Path, apply: bool = False, generate_exam_items: bool = True) -> ImportReport:
    """Safe import: parse the source SQL and write to the current tables only when ``apply`` is set.

    Args:
        path: SQL dump to import.
        apply: When False (default) the dump is only parsed and validated.
        generate_exam_items: When writing, also upsert generated exam items.

    Returns:
        The report for this run; write counters are filled only when applied.

    Raises:
        ImportValidationError: If parsing or schema validation fails. Any
            exception during the write phase is re-raised after a rollback.
    """
    parsed, warnings, encoding = parse_source_dump(path)
    validate_target_schema(parsed)
    row_counts = {table: len(rows) for table, rows in parsed.items()}
    report = ImportReport(
        source_path=str(path),
        encoding=encoding,
        table_rows=row_counts,
        warnings=warnings,
        applied=apply,
    )

    if apply:
        with SessionLocal() as db:
            try:
                case_rows = parsed.get("case_base", [])
                report.upserted_cases = _upsert_cases(db, case_rows)
                # Extension tables are synced only when the dump defines them.
                if "traditional_case" in parsed:
                    report.upserted_traditional_cases = _sync_traditional_cases(
                        db, case_rows, parsed["traditional_case"]
                    )
                if "teaching_case" in parsed:
                    report.upserted_teaching_cases = _sync_teaching_cases(db, case_rows, parsed["teaching_case"])
                if "scoring_rule" in parsed:
                    report.replaced_scoring_rules = _replace_scoring_rules(db, case_rows, parsed["scoring_rule"])
                if generate_exam_items:
                    report.generated_exam_items = _upsert_generated_exam_items(db, case_rows)
                db.commit()
            except Exception:
                # All-or-nothing: undo every pending write before propagating.
                db.rollback()
                raise
    return report
|
||
|
||
|
||
def _upsert_cases(db, rows: list[dict[str, Any]]) -> int:
    """Case import: update or insert rows of the main case table keyed by ``case_base.id``.

    Every imported case has ``status`` and ``publish_status`` forced to 1.
    Returns the number of rows processed.
    """
    processed = 0
    for processed, source_row in enumerate(rows, start=1):
        # Work on a copy so the caller's row is not modified.
        payload = {**source_row, "status": 1, "publish_status": 1}
        existing = db.get(CaseBase, payload["id"])
        if existing:
            for field, value in payload.items():
                setattr(existing, field, value)
        else:
            db.add(CaseBase(**payload))
    return processed
|
||
|
||
|
||
def _sync_traditional_cases(db, case_rows: list[dict[str, Any]], rows: list[dict[str, Any]]) -> int:
    """Traditional case import: sync the practice-mode extension table from this SQL.

    Existing rows for the imported case ids are deleted first so no stale data
    for the same ``case_id`` survives. Returns the number of rows added.
    """
    imported_ids = _imported_case_ids(case_rows)
    if imported_ids:
        stale = delete(TraditionalCase).where(TraditionalCase.case_id.in_(imported_ids))
        db.execute(stale)
    inserted = 0
    for inserted, row in enumerate(rows, start=1):
        db.add(TraditionalCase(**row))
    return inserted
|
||
|
||
|
||
def _sync_teaching_cases(db, case_rows: list[dict[str, Any]], rows: list[dict[str, Any]]) -> int:
    """Teaching case import: when the source SQL provides teaching_case, the source data wins.

    Existing rows for the imported case ids are deleted, then the source rows
    are added. Returns the number of rows added.
    """
    imported_ids = _imported_case_ids(case_rows)
    if imported_ids:
        stale = delete(TeachingCase).where(TeachingCase.case_id.in_(imported_ids))
        db.execute(stale)
    inserted = 0
    for inserted, row in enumerate(rows, start=1):
        db.add(TeachingCase(**row))
    return inserted
|
||
|
||
|
||
def _replace_scoring_rules(db, case_rows: list[dict[str, Any]], rows: list[dict[str, Any]]) -> int:
    """Scoring rule import: replace scoring_rule rows for the imported cases with the source data.

    Args:
        db: Open database session; the caller commits or rolls back.
        case_rows: Imported ``case_base`` rows; their ids define what is replaced.
        rows: Source ``scoring_rule`` rows to insert.

    Returns:
        The number of scoring rules added.
    """
    case_ids = _imported_case_ids(case_rows)
    if case_ids:
        # One IN-delete instead of one DELETE per case id, matching the other
        # sync helpers; the guard avoids issuing an empty IN ().
        db.execute(delete(ScoringRule).where(ScoringRule.case_id.in_(case_ids)))
    for row in rows:
        db.add(ScoringRule(**row))
    return len(rows)
|
||
|
||
|
||
def _upsert_generated_exam_items(db, case_rows: list[dict[str, Any]]) -> int:
    """Exam item backfill: update or add fixed exam results keyed by ``item_code``.

    Items are upserted, not deleted and recreated, to avoid removing exam
    items referenced by historical training records. Returns the number of
    items added or updated.
    """
    touched = 0
    for case_row in case_rows:
        case_id = case_row["id"]
        for item in build_exam_items_from_case(case_row):
            lookup = select(CaseExamItem).where(
                CaseExamItem.case_id == case_id,
                CaseExamItem.item_code == item["item_code"],
            )
            existing = db.scalar(lookup)
            if existing:
                for field, value in item.items():
                    setattr(existing, field, value)
            else:
                db.add(CaseExamItem(case_id=case_id, **item))
            touched += 1
    return touched
|
||
|
||
|
||
def _imported_case_ids(case_rows: list[dict[str, Any]]) -> list[int]:
|
||
"""导入范围:提取本次 SQL 涉及的病例 ID,用于同步扩展表。"""
|
||
return sorted({int(row["id"]) for row in case_rows if row.get("id") is not None})
|
||
|
||
|
||
def build_exam_items_from_case(case_row: dict[str, Any]) -> list[dict[str, Any]]:
    """Exam item generation: recognise common exam results in the case text so the demo can keep ordering exams.

    The chief complaint, description and knowledge points are joined into one
    text. Each known exam is emitted when its keywords occur in that text,
    with a result extracted by regex or a fixed fallback sentence. When
    nothing is detected a single generic "basic_exam" item is returned.
    """
    text = " ".join(str(case_row.get(key) or "") for key in ("chief_complaint", "description", "knowledge_points"))

    def extract(pattern: str) -> str | None:
        # Case-insensitive search; returns the stripped first group or None.
        found = re.search(pattern, text, flags=re.I)
        return found.group(1).strip() if found else None

    # (code, name, type, category, result regex, fallback text, detected?)
    specs = [
        (
            "blood_routine",
            "血常规",
            "lab",
            "实验室检查",
            r"(WBC[^,。;;]*[,,]\s*[^,。;;]*(?:中性|neutrophil)[^,。;;]*)",
            "血常规结果见病例资料。",
            "WBC" in text or "血常规" in text,
        ),
        (
            "crp",
            "CRP",
            "lab",
            "实验室检查",
            r"(CRP\s*[^,。;;]*)",
            "CRP 结果见病例资料。",
            "CRP" in text.upper(),
        ),
        (
            "chest_xray",
            "胸片",
            "imaging",
            "影像检查",
            r"(胸片[^。;;]*)",
            "胸片结果见病例资料。",
            "胸片" in text or "X线" in text,
        ),
        (
            "spo2",
            "血氧饱和度",
            "vital_sign",
            "生命体征",
            r"(SpO2\s*\d+%?)",
            "血氧饱和度结果见病例资料。",
            "SpO2" in text or "血氧" in text,
        ),
        (
            "chest_auscultation",
            "肺部体格检查",
            "physical_exam",
            "体格检查",
            r"(肺[^。;;]*(?:湿啰音|哮鸣音|呼吸音)[^。;;]*)",
            "肺部体格检查结果见病例资料。",
            "湿啰音" in text or "哮鸣音" in text or "肺" in text,
        ),
        (
            "mp_igm",
            "肺炎支原体抗体IgM",
            "lab",
            "实验室检查",
            r"(肺炎支原体抗体IgM[^,。;;]*)",
            "肺炎支原体抗体IgM结果见病例资料。",
            "肺炎支原体抗体IgM" in text or "支原体" in text,
        ),
    ]

    items: list[dict[str, Any]] = []
    # display_order is the position in the spec table, so gaps appear when
    # some exams are not detected.
    for position, (code, name, item_type, category, pattern, fallback, detected) in enumerate(specs, start=1):
        if not detected:
            continue
        items.append(
            {
                "item_code": code,
                "item_name": name,
                "item_type": item_type,
                "category": category,
                "result_text": extract(pattern) or fallback,
                "result_structured": {"source": "case_description"},
                "is_key": True,
                "is_abnormal": True,
                "score_weight": Decimal("5.00"),
                "display_order": position,
            }
        )
    if items:
        return items

    # Nothing recognised: fall back to one generic, non-abnormal item.
    return [
        {
            "item_code": "basic_exam",
            "item_name": "基础检查",
            "item_type": "physical_exam",
            "category": "体格检查",
            "result_text": "基础检查结果见病例资料。",
            "result_structured": {"source": "case_description"},
            "is_key": True,
            "is_abnormal": False,
            "score_weight": Decimal("1.00"),
            "display_order": 1,
        }
    ]
|
||
|
||
|
||
def _find_result(text: str, pattern: str) -> str | None:
|
||
"""文本抽取:从病例描述中抽取简短检查结果。"""
|
||
match = re.search(pattern, text, flags=re.I)
|
||
return match.group(1).strip() if match else None
|
||
|
||
|
||
def main() -> None:
    """Command entry point: run the safety check, or the import, for a supplied SQL dump.

    Prints a JSON document to stdout. On a validation failure the JSON carries
    ``ok: false`` and the process exits with status 2.
    """
    parser = argparse.ArgumentParser(description="Safely import parsed case SQL into current medical agent schema.")
    parser.add_argument("sql_path", type=Path, help="接口提供的 SQL dump 文件路径")
    parser.add_argument("--apply", action="store_true", help="确认写入当前数据库;默认只检查不写入")
    parser.add_argument("--no-generate-exam-items", action="store_true", help="不自动补齐 case_exam_item")
    options = parser.parse_args()

    try:
        report = import_source_sql(
            options.sql_path,
            apply=options.apply,
            generate_exam_items=not options.no_generate_exam_items,
        )
    except ImportValidationError as exc:
        failure = {"ok": False, "error": str(exc)}
        print(json.dumps(failure, ensure_ascii=False, indent=2))
        raise SystemExit(2) from exc

    success = {"ok": True, "report": report.as_dict()}
    # default=str stringifies values json cannot serialise natively.
    print(json.dumps(success, ensure_ascii=False, indent=2, default=str))


if __name__ == "__main__":
    main()
|