chore: initialize medical consultation agent demo
This commit is contained in:
@@ -0,0 +1,612 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import re
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import delete, inspect, select
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
|
||||
|
||||
from app.db.session import SessionLocal
|
||||
from app.models.source_case import CaseBase, CaseExamItem, ScoringRule, TeachingCase, TraditionalCase
|
||||
|
||||
|
||||
SOURCE_TABLES = ("case_base",)
|
||||
OPTIONAL_SOURCE_TABLES = ("traditional_case", "teaching_case", "scoring_rule", "case_exam_item")
|
||||
DANGEROUS_PATTERNS = (
|
||||
r"\bDROP\s+DATABASE\b",
|
||||
r"\bDROP\s+TABLE\b",
|
||||
r"\bTRUNCATE\b",
|
||||
r"\bDELETE\s+FROM\b",
|
||||
r"\bCREATE\s+DATABASE\b",
|
||||
r"\bALTER\s+TABLE\b",
|
||||
)
|
||||
JSON_COLUMNS = {
|
||||
"case_base": {"symptom_tags", "disease_tags", "competency_tags", "guideline_tags", "knowledge_points", "multimodal_assets"},
|
||||
"scoring_rule": {"rubric_json"},
|
||||
}
|
||||
DATETIME_COLUMNS = {"created_at", "updated_at"}
|
||||
DECIMAL_COLUMNS = {"score_weight"}
|
||||
INT_COLUMNS = {
|
||||
"id",
|
||||
"difficulty_score",
|
||||
"patient_age",
|
||||
"estimated_minutes",
|
||||
"vector_status",
|
||||
"publish_status",
|
||||
"status",
|
||||
"created_by_id",
|
||||
"department_id",
|
||||
"case_id",
|
||||
}
|
||||
BOOL_COLUMNS = {"osce_enabled", "rag_enabled", "ai_auto_score", "osce_dimension"}
|
||||
|
||||
|
||||
class ImportValidationError(Exception):
|
||||
"""导入校验错误:源 SQL 不满足安全导入要求时中止。"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class ImportReport:
|
||||
"""导入报告:记录检查、写入和跳过情况。"""
|
||||
|
||||
source_path: str
|
||||
encoding: str
|
||||
table_rows: dict[str, int]
|
||||
warnings: list[str]
|
||||
applied: bool
|
||||
upserted_cases: int = 0
|
||||
upserted_traditional_cases: int = 0
|
||||
upserted_teaching_cases: int = 0
|
||||
replaced_scoring_rules: int = 0
|
||||
generated_exam_items: int = 0
|
||||
|
||||
def as_dict(self) -> dict[str, Any]:
|
||||
"""报告输出:转换为前端和命令行均可读的结构。"""
|
||||
return {
|
||||
"source_path": self.source_path,
|
||||
"encoding": self.encoding,
|
||||
"table_rows": self.table_rows,
|
||||
"warnings": self.warnings,
|
||||
"applied": self.applied,
|
||||
"upserted_cases": self.upserted_cases,
|
||||
"upserted_traditional_cases": self.upserted_traditional_cases,
|
||||
"upserted_teaching_cases": self.upserted_teaching_cases,
|
||||
"replaced_scoring_rules": self.replaced_scoring_rules,
|
||||
"generated_exam_items": self.generated_exam_items,
|
||||
}
|
||||
|
||||
|
||||
def detect_encoding(path: Path) -> str:
|
||||
"""编码识别:根据 BOM 和试读结果判断 SQL 文件编码。"""
|
||||
raw = path.read_bytes()
|
||||
if raw.startswith(b"\xff\xfe") or raw.startswith(b"\xfe\xff"):
|
||||
return "utf-16"
|
||||
if raw.startswith(b"\xef\xbb\xbf"):
|
||||
return "utf-8-sig"
|
||||
for encoding in ("utf-8", "utf-16"):
|
||||
try:
|
||||
raw.decode(encoding)
|
||||
return encoding
|
||||
except UnicodeDecodeError:
|
||||
continue
|
||||
raise ImportValidationError("SQL 文件编码无法识别,请提供 UTF-8 或 UTF-16 文件。")
|
||||
|
||||
|
||||
def load_sql_text(path: Path) -> tuple[str, str]:
|
||||
"""文件读取:读取接口提供的 SQL dump,并返回文本与编码。"""
|
||||
if not path.exists():
|
||||
raise ImportValidationError(f"SQL 文件不存在:{path}")
|
||||
encoding = detect_encoding(path)
|
||||
return path.read_text(encoding=encoding), encoding
|
||||
|
||||
|
||||
def find_dangerous_statements(sql_text: str) -> list[str]:
|
||||
"""安全扫描:识别源 dump 中不能在正式库直接执行的 DDL/DML。"""
|
||||
hits: list[str] = []
|
||||
for pattern in DANGEROUS_PATTERNS:
|
||||
if re.search(pattern, sql_text, flags=re.IGNORECASE):
|
||||
hits.append(pattern.replace(r"\b", "").replace("\\s+", " "))
|
||||
return hits
|
||||
|
||||
|
||||
def extract_create_columns(sql_text: str, table_name: str) -> list[str]:
|
||||
"""字段提取:从 CREATE TABLE 中按源顺序读取字段名。"""
|
||||
match = re.search(rf"CREATE TABLE `{re.escape(table_name)}` \((.*?)\) ENGINE=", sql_text, flags=re.S | re.I)
|
||||
if not match:
|
||||
return []
|
||||
columns: list[str] = []
|
||||
for line in match.group(1).splitlines():
|
||||
stripped = line.strip().rstrip(",")
|
||||
if stripped.startswith("`"):
|
||||
columns.append(stripped.split("`", 2)[1])
|
||||
return columns
|
||||
|
||||
|
||||
def extract_insert_rows(sql_text: str, table_name: str, columns: list[str]) -> list[dict[str, Any]]:
|
||||
"""数据提取:只解析 INSERT VALUES 数据,不执行源 SQL。"""
|
||||
rows: list[dict[str, Any]] = []
|
||||
pattern = re.compile(rf"INSERT INTO `{re.escape(table_name)}` VALUES\s*(.*?);", flags=re.S | re.I)
|
||||
for match in pattern.finditer(sql_text):
|
||||
tuples = parse_values_clause(match.group(1))
|
||||
for values in tuples:
|
||||
if len(values) != len(columns):
|
||||
raise ImportValidationError(
|
||||
f"{table_name} INSERT 字段数不匹配:期望 {len(columns)},实际 {len(values)}。"
|
||||
)
|
||||
row = {columns[index]: normalize_value(table_name, columns[index], value) for index, value in enumerate(values)}
|
||||
rows.append(row)
|
||||
return rows
|
||||
|
||||
|
||||
def parse_values_clause(values_clause: str) -> list[list[Any]]:
|
||||
"""VALUES 解析:把 SQL values 子句解析为二维数组,损坏字符串会直接报错。"""
|
||||
rows: list[list[Any]] = []
|
||||
index = 0
|
||||
length = len(values_clause)
|
||||
while index < length:
|
||||
while index < length and values_clause[index].isspace():
|
||||
index += 1
|
||||
if index >= length:
|
||||
break
|
||||
if values_clause[index] == ",":
|
||||
index += 1
|
||||
continue
|
||||
if values_clause[index] != "(":
|
||||
raise ImportValidationError(f"VALUES 子句格式错误:第 {index} 个字符不是 '('。")
|
||||
row, index = _parse_tuple(values_clause, index)
|
||||
rows.append(row)
|
||||
return rows
|
||||
|
||||
|
||||
def _parse_tuple(text: str, start: int) -> tuple[list[Any], int]:
|
||||
"""元组解析:解析一组括号内的 SQL 字段值。"""
|
||||
values: list[Any] = []
|
||||
token: list[str] = []
|
||||
in_string = False
|
||||
was_string = False
|
||||
escaped = False
|
||||
index = start + 1
|
||||
while index < len(text):
|
||||
char = text[index]
|
||||
if in_string:
|
||||
if escaped:
|
||||
token.append(_unescape_char(char))
|
||||
escaped = False
|
||||
elif char == "\\":
|
||||
escaped = True
|
||||
elif char == ")":
|
||||
recovered_at_end = _recover_unclosed_string_at_tuple_end(token)
|
||||
if recovered_at_end:
|
||||
string_value, raw_values = recovered_at_end
|
||||
values.append(string_value)
|
||||
values.extend(raw_values)
|
||||
return values, index + 1
|
||||
token.append(char)
|
||||
elif char == "'":
|
||||
recovered = _recover_misplaced_quote_separator(token, text[index + 1] if index + 1 < len(text) else "")
|
||||
if recovered:
|
||||
string_value, raw_values = recovered
|
||||
values.append(string_value)
|
||||
values.extend(raw_values)
|
||||
token = []
|
||||
was_string = True
|
||||
index += 1
|
||||
continue
|
||||
in_string = False
|
||||
else:
|
||||
token.append(char)
|
||||
index += 1
|
||||
continue
|
||||
|
||||
if char == "'":
|
||||
stripped = "".join(token).strip()
|
||||
if stripped:
|
||||
raise ImportValidationError(f"字符串前存在非法未引用内容:{stripped[:20]}")
|
||||
in_string = True
|
||||
was_string = True
|
||||
index += 1
|
||||
continue
|
||||
if was_string and char.isspace():
|
||||
index += 1
|
||||
continue
|
||||
if was_string and char not in {",", ")"}:
|
||||
raise ImportValidationError(f"字符串后存在非法未引用内容:{char}{text[index + 1:index + 20]}")
|
||||
if char == ",":
|
||||
values.append(_coerce_raw_token("".join(token), was_string))
|
||||
token = []
|
||||
was_string = False
|
||||
index += 1
|
||||
continue
|
||||
if char == ")":
|
||||
values.append(_coerce_raw_token("".join(token), was_string))
|
||||
return values, index + 1
|
||||
token.append(char)
|
||||
index += 1
|
||||
raise ImportValidationError("VALUES 子句存在未闭合括号或未闭合字符串。")
|
||||
|
||||
|
||||
def _recover_misplaced_quote_separator(token: list[str], next_char: str) -> tuple[str, list[Any]] | None:
|
||||
"""兼容解析:修复接口 SQL 文本字段结尾写成 `文本,'next'` 的引号/逗号错位。"""
|
||||
if not next_char or next_char in {",", ")"} or next_char.isspace():
|
||||
return None
|
||||
raw = "".join(token)
|
||||
if not raw.endswith(","):
|
||||
return None
|
||||
|
||||
body = raw[:-1]
|
||||
trailing_values: list[Any] = []
|
||||
while True:
|
||||
head, separator, tail = body.rpartition(",")
|
||||
if not separator or not _looks_like_unquoted_scalar(tail):
|
||||
break
|
||||
trailing_values.insert(0, _coerce_raw_token(tail, was_string=False))
|
||||
body = head
|
||||
if not body:
|
||||
return None
|
||||
return body, trailing_values
|
||||
|
||||
|
||||
def _recover_unclosed_string_at_tuple_end(token: list[str]) -> tuple[str, list[Any]] | None:
|
||||
"""兼容解析:修复文本字段缺少结束引号且后接 `,数字)` 的源 SQL。"""
|
||||
body = "".join(token)
|
||||
trailing_values: list[Any] = []
|
||||
while True:
|
||||
head, separator, tail = body.rpartition(",")
|
||||
if not separator or not _looks_like_unquoted_scalar(tail):
|
||||
break
|
||||
trailing_values.insert(0, _coerce_raw_token(tail, was_string=False))
|
||||
body = head
|
||||
if not body or not trailing_values:
|
||||
return None
|
||||
return body, trailing_values
|
||||
|
||||
|
||||
def _looks_like_unquoted_scalar(value: str) -> bool:
|
||||
"""兼容解析:判断错位引号前夹带的字段是否是可安全恢复的未引用标量。"""
|
||||
stripped = value.strip()
|
||||
if not stripped:
|
||||
return False
|
||||
if stripped.upper() == "NULL":
|
||||
return True
|
||||
return bool(re.fullmatch(r"-?\d+(?:\.\d+)?", stripped))
|
||||
|
||||
|
||||
def _unescape_char(char: str) -> str:
|
||||
"""SQL 转义:处理 dump 中的常见反斜杠转义。"""
|
||||
return {
|
||||
"0": "\0",
|
||||
"b": "\b",
|
||||
"n": "\n",
|
||||
"r": "\r",
|
||||
"t": "\t",
|
||||
"Z": "\x1a",
|
||||
"\\": "\\",
|
||||
"'": "'",
|
||||
'"': '"',
|
||||
}.get(char, char)
|
||||
|
||||
|
||||
def _coerce_raw_token(raw: str, was_string: bool) -> Any:
|
||||
"""原始字段转换:区分 SQL NULL、数字和字符串。"""
|
||||
if was_string:
|
||||
return raw
|
||||
value = raw.strip()
|
||||
if not value:
|
||||
return ""
|
||||
if value.upper() == "NULL":
|
||||
return None
|
||||
if re.fullmatch(r"-?\d+", value):
|
||||
return int(value)
|
||||
if re.fullmatch(r"-?\d+\.\d+", value):
|
||||
return Decimal(value)
|
||||
if re.search(r"[A-Za-z\u4e00-\u9fff]", value):
|
||||
raise ImportValidationError(f"发现未引用文本字段:{value[:30]}")
|
||||
return value
|
||||
|
||||
|
||||
def normalize_value(table_name: str, column: str, value: Any) -> Any:
|
||||
"""字段归一:按当前 ORM 需要转换 JSON、时间、布尔和数值。"""
|
||||
if value is None:
|
||||
return None
|
||||
if column in JSON_COLUMNS.get(table_name, set()):
|
||||
if isinstance(value, (list, dict)):
|
||||
return value
|
||||
try:
|
||||
return json.loads(value)
|
||||
except (TypeError, json.JSONDecodeError) as exc:
|
||||
raise ImportValidationError(f"{table_name}.{column} 不是合法 JSON。") from exc
|
||||
if column in DATETIME_COLUMNS and isinstance(value, str):
|
||||
try:
|
||||
return datetime.fromisoformat(value)
|
||||
except ValueError as exc:
|
||||
raise ImportValidationError(f"{table_name}.{column} 不是合法 datetime。") from exc
|
||||
if column in BOOL_COLUMNS:
|
||||
return bool(int(value))
|
||||
if column in INT_COLUMNS and value is not None:
|
||||
return int(value)
|
||||
if column in DECIMAL_COLUMNS and value is not None:
|
||||
return Decimal(str(value))
|
||||
return value
|
||||
|
||||
|
||||
def parse_source_dump(path: Path) -> tuple[dict[str, list[dict[str, Any]]], list[str], str]:
|
||||
"""源文件解析:提取可导入表数据并返回兼容性警告。"""
|
||||
sql_text, encoding = load_sql_text(path)
|
||||
warnings = []
|
||||
dangerous = find_dangerous_statements(sql_text)
|
||||
if dangerous:
|
||||
warnings.append("源 SQL 包含 DDL/DML 覆盖语句,导入器会忽略这些语句:" + ", ".join(sorted(set(dangerous))))
|
||||
|
||||
parsed: dict[str, list[dict[str, Any]]] = {}
|
||||
for table_name in SOURCE_TABLES:
|
||||
columns = extract_create_columns(sql_text, table_name)
|
||||
if not columns:
|
||||
raise ImportValidationError(f"源 SQL 缺少必需表结构:{table_name}")
|
||||
parsed[table_name] = extract_insert_rows(sql_text, table_name, columns)
|
||||
|
||||
for table_name in OPTIONAL_SOURCE_TABLES:
|
||||
columns = extract_create_columns(sql_text, table_name)
|
||||
if not columns:
|
||||
warnings.append(f"源 SQL 未包含 {table_name},导入器会按当前业务规则处理。")
|
||||
continue
|
||||
parsed[table_name] = extract_insert_rows(sql_text, table_name, columns)
|
||||
if not parsed.get("traditional_case") and not parsed.get("teaching_case"):
|
||||
warnings.append("源 SQL 未包含 traditional_case 或 teaching_case,病例导入后暂时缺少训练模式扩展数据。")
|
||||
if not parsed.get("scoring_rule"):
|
||||
warnings.append("源 SQL 未包含 scoring_rule,评价时将缺少接口侧基础评分规则。")
|
||||
return parsed, warnings, encoding
|
||||
|
||||
|
||||
def validate_target_schema(parsed: dict[str, list[dict[str, Any]]]) -> None:
|
||||
"""目标结构校验:确认源字段可映射到当前数据库表。"""
|
||||
with SessionLocal() as db:
|
||||
inspector = inspect(db.bind)
|
||||
for table_name, rows in parsed.items():
|
||||
if not inspector.has_table(table_name):
|
||||
raise ImportValidationError(f"当前数据库缺少目标表:{table_name}")
|
||||
target_columns = {column["name"] for column in inspector.get_columns(table_name)}
|
||||
for row in rows:
|
||||
extra_columns = sorted(set(row) - target_columns)
|
||||
if extra_columns:
|
||||
raise ImportValidationError(f"{table_name} 存在当前库不支持的字段:{extra_columns}")
|
||||
|
||||
|
||||
def import_source_sql(path: Path, apply: bool = False, generate_exam_items: bool = True) -> ImportReport:
|
||||
"""安全导入:解析源 SQL,并在显式 apply 时写入当前新表。"""
|
||||
parsed, warnings, encoding = parse_source_dump(path)
|
||||
validate_target_schema(parsed)
|
||||
report = ImportReport(
|
||||
source_path=str(path),
|
||||
encoding=encoding,
|
||||
table_rows={table: len(rows) for table, rows in parsed.items()},
|
||||
warnings=warnings,
|
||||
applied=apply,
|
||||
)
|
||||
if not apply:
|
||||
return report
|
||||
|
||||
with SessionLocal() as db:
|
||||
try:
|
||||
case_rows = parsed.get("case_base", [])
|
||||
report.upserted_cases = _upsert_cases(db, case_rows)
|
||||
if "traditional_case" in parsed:
|
||||
report.upserted_traditional_cases = _sync_traditional_cases(db, case_rows, parsed.get("traditional_case", []))
|
||||
if "teaching_case" in parsed:
|
||||
report.upserted_teaching_cases = _sync_teaching_cases(db, case_rows, parsed.get("teaching_case", []))
|
||||
if "scoring_rule" in parsed:
|
||||
report.replaced_scoring_rules = _replace_scoring_rules(db, case_rows, parsed.get("scoring_rule", []))
|
||||
if generate_exam_items:
|
||||
report.generated_exam_items = _upsert_generated_exam_items(db, case_rows)
|
||||
db.commit()
|
||||
except Exception:
|
||||
db.rollback()
|
||||
raise
|
||||
return report
|
||||
|
||||
|
||||
def _upsert_cases(db, rows: list[dict[str, Any]]) -> int:
|
||||
"""病例导入:按 case_base.id 更新或插入病例主表。"""
|
||||
count = 0
|
||||
for row in rows:
|
||||
row = dict(row)
|
||||
row["status"] = 1
|
||||
row["publish_status"] = 1
|
||||
entity = db.get(CaseBase, row["id"])
|
||||
if not entity:
|
||||
db.add(CaseBase(**row))
|
||||
else:
|
||||
for key, value in row.items():
|
||||
setattr(entity, key, value)
|
||||
count += 1
|
||||
return count
|
||||
|
||||
|
||||
def _sync_traditional_cases(db, case_rows: list[dict[str, Any]], rows: list[dict[str, Any]]) -> int:
|
||||
"""传统病例导入:按本次 SQL 同步练习模式扩展表,避免同 case_id 旧数据残留。"""
|
||||
case_ids = _imported_case_ids(case_rows)
|
||||
if case_ids:
|
||||
db.execute(delete(TraditionalCase).where(TraditionalCase.case_id.in_(case_ids)))
|
||||
count = 0
|
||||
for row in rows:
|
||||
db.add(TraditionalCase(**row))
|
||||
count += 1
|
||||
return count
|
||||
|
||||
|
||||
def _sync_teaching_cases(db, case_rows: list[dict[str, Any]], rows: list[dict[str, Any]]) -> int:
|
||||
"""教学互动病例导入:源 SQL 明确提供 teaching_case 时,以源数据为准同步扩展表。"""
|
||||
case_ids = _imported_case_ids(case_rows)
|
||||
if case_ids:
|
||||
db.execute(delete(TeachingCase).where(TeachingCase.case_id.in_(case_ids)))
|
||||
count = 0
|
||||
for row in rows:
|
||||
db.add(TeachingCase(**row))
|
||||
count += 1
|
||||
return count
|
||||
|
||||
|
||||
def _replace_scoring_rules(db, case_rows: list[dict[str, Any]], rows: list[dict[str, Any]]) -> int:
|
||||
"""评分规则导入:按本次导入病例替换 scoring_rule,保持评分规则与源数据一致。"""
|
||||
case_ids = _imported_case_ids(case_rows)
|
||||
for case_id in case_ids:
|
||||
db.execute(delete(ScoringRule).where(ScoringRule.case_id == case_id))
|
||||
for row in rows:
|
||||
db.add(ScoringRule(**row))
|
||||
return len(rows)
|
||||
|
||||
|
||||
def _upsert_generated_exam_items(db, case_rows: list[dict[str, Any]]) -> int:
|
||||
"""检查项目补齐:按 item_code 更新或补齐固定检查结果,避免删除历史训练引用的检查项。"""
|
||||
changed = 0
|
||||
for row in case_rows:
|
||||
case_id = row["id"]
|
||||
items = build_exam_items_from_case(row)
|
||||
for item in items:
|
||||
entity = db.scalar(
|
||||
select(CaseExamItem).where(CaseExamItem.case_id == case_id, CaseExamItem.item_code == item["item_code"])
|
||||
)
|
||||
if not entity:
|
||||
db.add(CaseExamItem(case_id=case_id, **item))
|
||||
else:
|
||||
for key, value in item.items():
|
||||
setattr(entity, key, value)
|
||||
changed += 1
|
||||
return changed
|
||||
|
||||
|
||||
def _imported_case_ids(case_rows: list[dict[str, Any]]) -> list[int]:
|
||||
"""导入范围:提取本次 SQL 涉及的病例 ID,用于同步扩展表。"""
|
||||
return sorted({int(row["id"]) for row in case_rows if row.get("id") is not None})
|
||||
|
||||
|
||||
def build_exam_items_from_case(case_row: dict[str, Any]) -> list[dict[str, Any]]:
|
||||
"""检查项目生成:从病例文本识别常见检查结果,保障 Demo 可继续申请检查。"""
|
||||
text = " ".join(str(case_row.get(key) or "") for key in ("chief_complaint", "description", "knowledge_points"))
|
||||
candidates = [
|
||||
(
|
||||
"blood_routine",
|
||||
"血常规",
|
||||
"lab",
|
||||
"实验室检查",
|
||||
_find_result(text, r"(WBC[^,。;;]*[,,]\s*[^,。;;]*(?:中性|neutrophil)[^,。;;]*)") or "血常规结果见病例资料。",
|
||||
{"source": "case_description"},
|
||||
"WBC" in text or "血常规" in text,
|
||||
),
|
||||
(
|
||||
"crp",
|
||||
"CRP",
|
||||
"lab",
|
||||
"实验室检查",
|
||||
_find_result(text, r"(CRP\s*[^,。;;]*)") or "CRP 结果见病例资料。",
|
||||
{"source": "case_description"},
|
||||
"CRP" in text.upper(),
|
||||
),
|
||||
(
|
||||
"chest_xray",
|
||||
"胸片",
|
||||
"imaging",
|
||||
"影像检查",
|
||||
_find_result(text, r"(胸片[^。;;]*)") or "胸片结果见病例资料。",
|
||||
{"source": "case_description"},
|
||||
"胸片" in text or "X线" in text,
|
||||
),
|
||||
(
|
||||
"spo2",
|
||||
"血氧饱和度",
|
||||
"vital_sign",
|
||||
"生命体征",
|
||||
_find_result(text, r"(SpO2\s*\d+%?)") or "血氧饱和度结果见病例资料。",
|
||||
{"source": "case_description"},
|
||||
"SpO2" in text or "血氧" in text,
|
||||
),
|
||||
(
|
||||
"chest_auscultation",
|
||||
"肺部体格检查",
|
||||
"physical_exam",
|
||||
"体格检查",
|
||||
_find_result(text, r"(肺[^。;;]*(?:湿啰音|哮鸣音|呼吸音)[^。;;]*)") or "肺部体格检查结果见病例资料。",
|
||||
{"source": "case_description"},
|
||||
"湿啰音" in text or "哮鸣音" in text or "肺" in text,
|
||||
),
|
||||
(
|
||||
"mp_igm",
|
||||
"肺炎支原体抗体IgM",
|
||||
"lab",
|
||||
"实验室检查",
|
||||
_find_result(text, r"(肺炎支原体抗体IgM[^,。;;]*)") or "肺炎支原体抗体IgM结果见病例资料。",
|
||||
{"source": "case_description"},
|
||||
"肺炎支原体抗体IgM" in text or "支原体" in text,
|
||||
),
|
||||
]
|
||||
items: list[dict[str, Any]] = []
|
||||
for index, (code, name, item_type, category, result_text, structured, detected) in enumerate(candidates, start=1):
|
||||
if detected:
|
||||
items.append(
|
||||
{
|
||||
"item_code": code,
|
||||
"item_name": name,
|
||||
"item_type": item_type,
|
||||
"category": category,
|
||||
"result_text": result_text,
|
||||
"result_structured": structured,
|
||||
"is_key": True,
|
||||
"is_abnormal": True,
|
||||
"score_weight": Decimal("5.00"),
|
||||
"display_order": index,
|
||||
}
|
||||
)
|
||||
if items:
|
||||
return items
|
||||
return [
|
||||
{
|
||||
"item_code": "basic_exam",
|
||||
"item_name": "基础检查",
|
||||
"item_type": "physical_exam",
|
||||
"category": "体格检查",
|
||||
"result_text": "基础检查结果见病例资料。",
|
||||
"result_structured": {"source": "case_description"},
|
||||
"is_key": True,
|
||||
"is_abnormal": False,
|
||||
"score_weight": Decimal("1.00"),
|
||||
"display_order": 1,
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
def _find_result(text: str, pattern: str) -> str | None:
|
||||
"""文本抽取:从病例描述中抽取简短检查结果。"""
|
||||
match = re.search(pattern, text, flags=re.I)
|
||||
return match.group(1).strip() if match else None
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""命令入口:执行接口 SQL 的安全检查或导入。"""
|
||||
parser = argparse.ArgumentParser(description="Safely import parsed case SQL into current medical agent schema.")
|
||||
parser.add_argument("sql_path", type=Path, help="接口提供的 SQL dump 文件路径")
|
||||
parser.add_argument("--apply", action="store_true", help="确认写入当前数据库;默认只检查不写入")
|
||||
parser.add_argument("--no-generate-exam-items", action="store_true", help="不自动补齐 case_exam_item")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
report = import_source_sql(
|
||||
args.sql_path,
|
||||
apply=args.apply,
|
||||
generate_exam_items=not args.no_generate_exam_items,
|
||||
)
|
||||
except ImportValidationError as exc:
|
||||
print(json.dumps({"ok": False, "error": str(exc)}, ensure_ascii=False, indent=2))
|
||||
raise SystemExit(2) from exc
|
||||
|
||||
print(json.dumps({"ok": True, "report": report.as_dict()}, ensure_ascii=False, indent=2, default=str))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user