chore: initialize medical consultation agent demo

This commit is contained in:
刘金宝
2026-06-01 09:25:26 +08:00
commit a7733243b2
139 changed files with 15764 additions and 0 deletions
+1
View File
@@ -0,0 +1 @@
"""后端脚本包。"""
+53
View File
@@ -0,0 +1,53 @@
import asyncio
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
from app.agents.llm_adapter import OpenAICompatibleLLMClient
from app.agents.patient_agent import PatientAgent
from app.core.config import settings
from app.db.session import SessionLocal
from app.repositories.case_repository import CaseRepository
async def main() -> None:
"""本地调试:直接调用 Patient Agent 流式回复,绕过前端和 FastAPI。"""
client = OpenAICompatibleLLMClient()
print(f"mock_mode={client.is_mock_mode}")
print(f"fast_model={settings.llm_fast_model}")
print(f"fast_thinking={settings.llm_fast_thinking_enabled}")
print(f"stream_first_token_timeout={settings.llm_stream_first_token_timeout_seconds}")
print(f"stream_total_timeout={settings.llm_stream_total_timeout_seconds}")
db = SessionLocal()
try:
case = CaseRepository(db).list_active_cases()[0]
text = ""
first_token_ms = None
done_seen = False
async for chunk in PatientAgent().stream_reply(case, [], "孩子发热几天了?最高体温多少?", "novice"):
if first_token_ms is None and chunk.first_token_ms is not None:
first_token_ms = chunk.first_token_ms
if chunk.done:
done_seen = True
print(f"done_seen={done_seen}")
print(f"first_token_ms={first_token_ms}")
print(f"total_latency_ms={chunk.total_latency_ms}")
print(f"model={chunk.model}")
print(f"fallback_used={chunk.fallback_used}")
print(f"text_len={len(text)}")
print(f"text_preview={text[:30]}")
break
text += chunk.delta
if not done_seen:
print("done_seen=False")
print(f"text_len={len(text)}")
print(f"text_preview={text[:30]}")
raise SystemExit(1)
finally:
db.close()
if __name__ == "__main__":
asyncio.run(main())
+43
View File
@@ -0,0 +1,43 @@
from __future__ import annotations
import sys
from pathlib import Path
from sqlalchemy import inspect, text
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
from app.db.session import SessionLocal
LEGACY_TABLES = [
"evaluation_report_exports",
"evaluation_records",
"session_submissions",
"session_orders",
"session_runtime_messages",
"training_sessions",
"case_exam_items",
"rubric_templates",
"cases",
]
def main() -> None:
"""旧表清理:在新表链路验证通过后删除不再被业务依赖的旧表。"""
with SessionLocal() as db:
existing = set(inspect(db.bind).get_table_names()) if db.bind else set()
dialect = db.bind.dialect.name if db.bind else ""
if dialect == "mysql":
db.execute(text("SET FOREIGN_KEY_CHECKS=0"))
for table_name in LEGACY_TABLES:
if table_name in existing:
db.execute(text(f"DROP TABLE `{table_name}`"))
print(f"dropped legacy table: {table_name}")
if dialect == "mysql":
db.execute(text("SET FOREIGN_KEY_CHECKS=1"))
db.commit()
print("legacy table cleanup completed")
if __name__ == "__main__":
main()
+612
View File
@@ -0,0 +1,612 @@
from __future__ import annotations
import argparse
import json
import re
import sys
from dataclasses import dataclass
from datetime import datetime
from decimal import Decimal
from pathlib import Path
from typing import Any
from sqlalchemy import delete, inspect, select
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
from app.db.session import SessionLocal
from app.models.source_case import CaseBase, CaseExamItem, ScoringRule, TeachingCase, TraditionalCase
SOURCE_TABLES = ("case_base",)
OPTIONAL_SOURCE_TABLES = ("traditional_case", "teaching_case", "scoring_rule", "case_exam_item")
DANGEROUS_PATTERNS = (
r"\bDROP\s+DATABASE\b",
r"\bDROP\s+TABLE\b",
r"\bTRUNCATE\b",
r"\bDELETE\s+FROM\b",
r"\bCREATE\s+DATABASE\b",
r"\bALTER\s+TABLE\b",
)
JSON_COLUMNS = {
"case_base": {"symptom_tags", "disease_tags", "competency_tags", "guideline_tags", "knowledge_points", "multimodal_assets"},
"scoring_rule": {"rubric_json"},
}
DATETIME_COLUMNS = {"created_at", "updated_at"}
DECIMAL_COLUMNS = {"score_weight"}
INT_COLUMNS = {
"id",
"difficulty_score",
"patient_age",
"estimated_minutes",
"vector_status",
"publish_status",
"status",
"created_by_id",
"department_id",
"case_id",
}
BOOL_COLUMNS = {"osce_enabled", "rag_enabled", "ai_auto_score", "osce_dimension"}
class ImportValidationError(Exception):
"""导入校验错误:源 SQL 不满足安全导入要求时中止。"""
@dataclass
class ImportReport:
"""导入报告:记录检查、写入和跳过情况。"""
source_path: str
encoding: str
table_rows: dict[str, int]
warnings: list[str]
applied: bool
upserted_cases: int = 0
upserted_traditional_cases: int = 0
upserted_teaching_cases: int = 0
replaced_scoring_rules: int = 0
generated_exam_items: int = 0
def as_dict(self) -> dict[str, Any]:
"""报告输出:转换为前端和命令行均可读的结构。"""
return {
"source_path": self.source_path,
"encoding": self.encoding,
"table_rows": self.table_rows,
"warnings": self.warnings,
"applied": self.applied,
"upserted_cases": self.upserted_cases,
"upserted_traditional_cases": self.upserted_traditional_cases,
"upserted_teaching_cases": self.upserted_teaching_cases,
"replaced_scoring_rules": self.replaced_scoring_rules,
"generated_exam_items": self.generated_exam_items,
}
def detect_encoding(path: Path) -> str:
"""编码识别:根据 BOM 和试读结果判断 SQL 文件编码。"""
raw = path.read_bytes()
if raw.startswith(b"\xff\xfe") or raw.startswith(b"\xfe\xff"):
return "utf-16"
if raw.startswith(b"\xef\xbb\xbf"):
return "utf-8-sig"
for encoding in ("utf-8", "utf-16"):
try:
raw.decode(encoding)
return encoding
except UnicodeDecodeError:
continue
raise ImportValidationError("SQL 文件编码无法识别,请提供 UTF-8 或 UTF-16 文件。")
def load_sql_text(path: Path) -> tuple[str, str]:
"""文件读取:读取接口提供的 SQL dump,并返回文本与编码。"""
if not path.exists():
raise ImportValidationError(f"SQL 文件不存在:{path}")
encoding = detect_encoding(path)
return path.read_text(encoding=encoding), encoding
def find_dangerous_statements(sql_text: str) -> list[str]:
"""安全扫描:识别源 dump 中不能在正式库直接执行的 DDL/DML。"""
hits: list[str] = []
for pattern in DANGEROUS_PATTERNS:
if re.search(pattern, sql_text, flags=re.IGNORECASE):
hits.append(pattern.replace(r"\b", "").replace("\\s+", " "))
return hits
def extract_create_columns(sql_text: str, table_name: str) -> list[str]:
"""字段提取:从 CREATE TABLE 中按源顺序读取字段名。"""
match = re.search(rf"CREATE TABLE `{re.escape(table_name)}` \((.*?)\) ENGINE=", sql_text, flags=re.S | re.I)
if not match:
return []
columns: list[str] = []
for line in match.group(1).splitlines():
stripped = line.strip().rstrip(",")
if stripped.startswith("`"):
columns.append(stripped.split("`", 2)[1])
return columns
def extract_insert_rows(sql_text: str, table_name: str, columns: list[str]) -> list[dict[str, Any]]:
"""数据提取:只解析 INSERT VALUES 数据,不执行源 SQL。"""
rows: list[dict[str, Any]] = []
pattern = re.compile(rf"INSERT INTO `{re.escape(table_name)}` VALUES\s*(.*?);", flags=re.S | re.I)
for match in pattern.finditer(sql_text):
tuples = parse_values_clause(match.group(1))
for values in tuples:
if len(values) != len(columns):
raise ImportValidationError(
f"{table_name} INSERT 字段数不匹配:期望 {len(columns)},实际 {len(values)}"
)
row = {columns[index]: normalize_value(table_name, columns[index], value) for index, value in enumerate(values)}
rows.append(row)
return rows
def parse_values_clause(values_clause: str) -> list[list[Any]]:
"""VALUES 解析:把 SQL values 子句解析为二维数组,损坏字符串会直接报错。"""
rows: list[list[Any]] = []
index = 0
length = len(values_clause)
while index < length:
while index < length and values_clause[index].isspace():
index += 1
if index >= length:
break
if values_clause[index] == ",":
index += 1
continue
if values_clause[index] != "(":
raise ImportValidationError(f"VALUES 子句格式错误:第 {index} 个字符不是 '('")
row, index = _parse_tuple(values_clause, index)
rows.append(row)
return rows
def _parse_tuple(text: str, start: int) -> tuple[list[Any], int]:
"""元组解析:解析一组括号内的 SQL 字段值。"""
values: list[Any] = []
token: list[str] = []
in_string = False
was_string = False
escaped = False
index = start + 1
while index < len(text):
char = text[index]
if in_string:
if escaped:
token.append(_unescape_char(char))
escaped = False
elif char == "\\":
escaped = True
elif char == ")":
recovered_at_end = _recover_unclosed_string_at_tuple_end(token)
if recovered_at_end:
string_value, raw_values = recovered_at_end
values.append(string_value)
values.extend(raw_values)
return values, index + 1
token.append(char)
elif char == "'":
recovered = _recover_misplaced_quote_separator(token, text[index + 1] if index + 1 < len(text) else "")
if recovered:
string_value, raw_values = recovered
values.append(string_value)
values.extend(raw_values)
token = []
was_string = True
index += 1
continue
in_string = False
else:
token.append(char)
index += 1
continue
if char == "'":
stripped = "".join(token).strip()
if stripped:
raise ImportValidationError(f"字符串前存在非法未引用内容:{stripped[:20]}")
in_string = True
was_string = True
index += 1
continue
if was_string and char.isspace():
index += 1
continue
if was_string and char not in {",", ")"}:
raise ImportValidationError(f"字符串后存在非法未引用内容:{char}{text[index + 1:index + 20]}")
if char == ",":
values.append(_coerce_raw_token("".join(token), was_string))
token = []
was_string = False
index += 1
continue
if char == ")":
values.append(_coerce_raw_token("".join(token), was_string))
return values, index + 1
token.append(char)
index += 1
raise ImportValidationError("VALUES 子句存在未闭合括号或未闭合字符串。")
def _recover_misplaced_quote_separator(token: list[str], next_char: str) -> tuple[str, list[Any]] | None:
"""兼容解析:修复接口 SQL 文本字段结尾写成 `文本,'next'` 的引号/逗号错位。"""
if not next_char or next_char in {",", ")"} or next_char.isspace():
return None
raw = "".join(token)
if not raw.endswith(","):
return None
body = raw[:-1]
trailing_values: list[Any] = []
while True:
head, separator, tail = body.rpartition(",")
if not separator or not _looks_like_unquoted_scalar(tail):
break
trailing_values.insert(0, _coerce_raw_token(tail, was_string=False))
body = head
if not body:
return None
return body, trailing_values
def _recover_unclosed_string_at_tuple_end(token: list[str]) -> tuple[str, list[Any]] | None:
"""兼容解析:修复文本字段缺少结束引号且后接 `,数字)` 的源 SQL。"""
body = "".join(token)
trailing_values: list[Any] = []
while True:
head, separator, tail = body.rpartition(",")
if not separator or not _looks_like_unquoted_scalar(tail):
break
trailing_values.insert(0, _coerce_raw_token(tail, was_string=False))
body = head
if not body or not trailing_values:
return None
return body, trailing_values
def _looks_like_unquoted_scalar(value: str) -> bool:
"""兼容解析:判断错位引号前夹带的字段是否是可安全恢复的未引用标量。"""
stripped = value.strip()
if not stripped:
return False
if stripped.upper() == "NULL":
return True
return bool(re.fullmatch(r"-?\d+(?:\.\d+)?", stripped))
def _unescape_char(char: str) -> str:
"""SQL 转义:处理 dump 中的常见反斜杠转义。"""
return {
"0": "\0",
"b": "\b",
"n": "\n",
"r": "\r",
"t": "\t",
"Z": "\x1a",
"\\": "\\",
"'": "'",
'"': '"',
}.get(char, char)
def _coerce_raw_token(raw: str, was_string: bool) -> Any:
"""原始字段转换:区分 SQL NULL、数字和字符串。"""
if was_string:
return raw
value = raw.strip()
if not value:
return ""
if value.upper() == "NULL":
return None
if re.fullmatch(r"-?\d+", value):
return int(value)
if re.fullmatch(r"-?\d+\.\d+", value):
return Decimal(value)
if re.search(r"[A-Za-z\u4e00-\u9fff]", value):
raise ImportValidationError(f"发现未引用文本字段:{value[:30]}")
return value
def normalize_value(table_name: str, column: str, value: Any) -> Any:
"""字段归一:按当前 ORM 需要转换 JSON、时间、布尔和数值。"""
if value is None:
return None
if column in JSON_COLUMNS.get(table_name, set()):
if isinstance(value, (list, dict)):
return value
try:
return json.loads(value)
except (TypeError, json.JSONDecodeError) as exc:
raise ImportValidationError(f"{table_name}.{column} 不是合法 JSON。") from exc
if column in DATETIME_COLUMNS and isinstance(value, str):
try:
return datetime.fromisoformat(value)
except ValueError as exc:
raise ImportValidationError(f"{table_name}.{column} 不是合法 datetime。") from exc
if column in BOOL_COLUMNS:
return bool(int(value))
if column in INT_COLUMNS and value is not None:
return int(value)
if column in DECIMAL_COLUMNS and value is not None:
return Decimal(str(value))
return value
def parse_source_dump(path: Path) -> tuple[dict[str, list[dict[str, Any]]], list[str], str]:
"""源文件解析:提取可导入表数据并返回兼容性警告。"""
sql_text, encoding = load_sql_text(path)
warnings = []
dangerous = find_dangerous_statements(sql_text)
if dangerous:
warnings.append("源 SQL 包含 DDL/DML 覆盖语句,导入器会忽略这些语句:" + ", ".join(sorted(set(dangerous))))
parsed: dict[str, list[dict[str, Any]]] = {}
for table_name in SOURCE_TABLES:
columns = extract_create_columns(sql_text, table_name)
if not columns:
raise ImportValidationError(f"源 SQL 缺少必需表结构:{table_name}")
parsed[table_name] = extract_insert_rows(sql_text, table_name, columns)
for table_name in OPTIONAL_SOURCE_TABLES:
columns = extract_create_columns(sql_text, table_name)
if not columns:
warnings.append(f"源 SQL 未包含 {table_name},导入器会按当前业务规则处理。")
continue
parsed[table_name] = extract_insert_rows(sql_text, table_name, columns)
if not parsed.get("traditional_case") and not parsed.get("teaching_case"):
warnings.append("源 SQL 未包含 traditional_case 或 teaching_case,病例导入后暂时缺少训练模式扩展数据。")
if not parsed.get("scoring_rule"):
warnings.append("源 SQL 未包含 scoring_rule,评价时将缺少接口侧基础评分规则。")
return parsed, warnings, encoding
def validate_target_schema(parsed: dict[str, list[dict[str, Any]]]) -> None:
"""目标结构校验:确认源字段可映射到当前数据库表。"""
with SessionLocal() as db:
inspector = inspect(db.bind)
for table_name, rows in parsed.items():
if not inspector.has_table(table_name):
raise ImportValidationError(f"当前数据库缺少目标表:{table_name}")
target_columns = {column["name"] for column in inspector.get_columns(table_name)}
for row in rows:
extra_columns = sorted(set(row) - target_columns)
if extra_columns:
raise ImportValidationError(f"{table_name} 存在当前库不支持的字段:{extra_columns}")
def import_source_sql(path: Path, apply: bool = False, generate_exam_items: bool = True) -> ImportReport:
"""安全导入:解析源 SQL,并在显式 apply 时写入当前新表。"""
parsed, warnings, encoding = parse_source_dump(path)
validate_target_schema(parsed)
report = ImportReport(
source_path=str(path),
encoding=encoding,
table_rows={table: len(rows) for table, rows in parsed.items()},
warnings=warnings,
applied=apply,
)
if not apply:
return report
with SessionLocal() as db:
try:
case_rows = parsed.get("case_base", [])
report.upserted_cases = _upsert_cases(db, case_rows)
if "traditional_case" in parsed:
report.upserted_traditional_cases = _sync_traditional_cases(db, case_rows, parsed.get("traditional_case", []))
if "teaching_case" in parsed:
report.upserted_teaching_cases = _sync_teaching_cases(db, case_rows, parsed.get("teaching_case", []))
if "scoring_rule" in parsed:
report.replaced_scoring_rules = _replace_scoring_rules(db, case_rows, parsed.get("scoring_rule", []))
if generate_exam_items:
report.generated_exam_items = _upsert_generated_exam_items(db, case_rows)
db.commit()
except Exception:
db.rollback()
raise
return report
def _upsert_cases(db, rows: list[dict[str, Any]]) -> int:
"""病例导入:按 case_base.id 更新或插入病例主表。"""
count = 0
for row in rows:
row = dict(row)
row["status"] = 1
row["publish_status"] = 1
entity = db.get(CaseBase, row["id"])
if not entity:
db.add(CaseBase(**row))
else:
for key, value in row.items():
setattr(entity, key, value)
count += 1
return count
def _sync_traditional_cases(db, case_rows: list[dict[str, Any]], rows: list[dict[str, Any]]) -> int:
"""传统病例导入:按本次 SQL 同步练习模式扩展表,避免同 case_id 旧数据残留。"""
case_ids = _imported_case_ids(case_rows)
if case_ids:
db.execute(delete(TraditionalCase).where(TraditionalCase.case_id.in_(case_ids)))
count = 0
for row in rows:
db.add(TraditionalCase(**row))
count += 1
return count
def _sync_teaching_cases(db, case_rows: list[dict[str, Any]], rows: list[dict[str, Any]]) -> int:
"""教学互动病例导入:源 SQL 明确提供 teaching_case 时,以源数据为准同步扩展表。"""
case_ids = _imported_case_ids(case_rows)
if case_ids:
db.execute(delete(TeachingCase).where(TeachingCase.case_id.in_(case_ids)))
count = 0
for row in rows:
db.add(TeachingCase(**row))
count += 1
return count
def _replace_scoring_rules(db, case_rows: list[dict[str, Any]], rows: list[dict[str, Any]]) -> int:
"""评分规则导入:按本次导入病例替换 scoring_rule,保持评分规则与源数据一致。"""
case_ids = _imported_case_ids(case_rows)
for case_id in case_ids:
db.execute(delete(ScoringRule).where(ScoringRule.case_id == case_id))
for row in rows:
db.add(ScoringRule(**row))
return len(rows)
def _upsert_generated_exam_items(db, case_rows: list[dict[str, Any]]) -> int:
"""检查项目补齐:按 item_code 更新或补齐固定检查结果,避免删除历史训练引用的检查项。"""
changed = 0
for row in case_rows:
case_id = row["id"]
items = build_exam_items_from_case(row)
for item in items:
entity = db.scalar(
select(CaseExamItem).where(CaseExamItem.case_id == case_id, CaseExamItem.item_code == item["item_code"])
)
if not entity:
db.add(CaseExamItem(case_id=case_id, **item))
else:
for key, value in item.items():
setattr(entity, key, value)
changed += 1
return changed
def _imported_case_ids(case_rows: list[dict[str, Any]]) -> list[int]:
"""导入范围:提取本次 SQL 涉及的病例 ID,用于同步扩展表。"""
return sorted({int(row["id"]) for row in case_rows if row.get("id") is not None})
def build_exam_items_from_case(case_row: dict[str, Any]) -> list[dict[str, Any]]:
"""检查项目生成:从病例文本识别常见检查结果,保障 Demo 可继续申请检查。"""
text = " ".join(str(case_row.get(key) or "") for key in ("chief_complaint", "description", "knowledge_points"))
candidates = [
(
"blood_routine",
"血常规",
"lab",
"实验室检查",
_find_result(text, r"(WBC[^,。;;]*[,]\s*[^,。;;]*(?:中性|neutrophil)[^,。;;]*)") or "血常规结果见病例资料。",
{"source": "case_description"},
"WBC" in text or "血常规" in text,
),
(
"crp",
"CRP",
"lab",
"实验室检查",
_find_result(text, r"(CRP\s*[^,。;;]*)") or "CRP 结果见病例资料。",
{"source": "case_description"},
"CRP" in text.upper(),
),
(
"chest_xray",
"胸片",
"imaging",
"影像检查",
_find_result(text, r"(胸片[^。;;]*)") or "胸片结果见病例资料。",
{"source": "case_description"},
"胸片" in text or "X线" in text,
),
(
"spo2",
"血氧饱和度",
"vital_sign",
"生命体征",
_find_result(text, r"(SpO2\s*\d+%?)") or "血氧饱和度结果见病例资料。",
{"source": "case_description"},
"SpO2" in text or "血氧" in text,
),
(
"chest_auscultation",
"肺部体格检查",
"physical_exam",
"体格检查",
_find_result(text, r"(肺[^。;;]*(?:湿啰音|哮鸣音|呼吸音)[^。;;]*)") or "肺部体格检查结果见病例资料。",
{"source": "case_description"},
"湿啰音" in text or "哮鸣音" in text or "" in text,
),
(
"mp_igm",
"肺炎支原体抗体IgM",
"lab",
"实验室检查",
_find_result(text, r"(肺炎支原体抗体IgM[^,。;;]*)") or "肺炎支原体抗体IgM结果见病例资料。",
{"source": "case_description"},
"肺炎支原体抗体IgM" in text or "支原体" in text,
),
]
items: list[dict[str, Any]] = []
for index, (code, name, item_type, category, result_text, structured, detected) in enumerate(candidates, start=1):
if detected:
items.append(
{
"item_code": code,
"item_name": name,
"item_type": item_type,
"category": category,
"result_text": result_text,
"result_structured": structured,
"is_key": True,
"is_abnormal": True,
"score_weight": Decimal("5.00"),
"display_order": index,
}
)
if items:
return items
return [
{
"item_code": "basic_exam",
"item_name": "基础检查",
"item_type": "physical_exam",
"category": "体格检查",
"result_text": "基础检查结果见病例资料。",
"result_structured": {"source": "case_description"},
"is_key": True,
"is_abnormal": False,
"score_weight": Decimal("1.00"),
"display_order": 1,
}
]
def _find_result(text: str, pattern: str) -> str | None:
"""文本抽取:从病例描述中抽取简短检查结果。"""
match = re.search(pattern, text, flags=re.I)
return match.group(1).strip() if match else None
def main() -> None:
"""命令入口:执行接口 SQL 的安全检查或导入。"""
parser = argparse.ArgumentParser(description="Safely import parsed case SQL into current medical agent schema.")
parser.add_argument("sql_path", type=Path, help="接口提供的 SQL dump 文件路径")
parser.add_argument("--apply", action="store_true", help="确认写入当前数据库;默认只检查不写入")
parser.add_argument("--no-generate-exam-items", action="store_true", help="不自动补齐 case_exam_item")
args = parser.parse_args()
try:
report = import_source_sql(
args.sql_path,
apply=args.apply,
generate_exam_items=not args.no_generate_exam_items,
)
except ImportValidationError as exc:
print(json.dumps({"ok": False, "error": str(exc)}, ensure_ascii=False, indent=2))
raise SystemExit(2) from exc
print(json.dumps({"ok": True, "report": report.as_dict()}, ensure_ascii=False, indent=2, default=str))
if __name__ == "__main__":
main()
+303
View File
@@ -0,0 +1,303 @@
from __future__ import annotations
import sys
from pathlib import Path
from sqlalchemy import select
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
from app.db.base import Base
from app.db.session import SessionLocal, engine
from app.models import (
CaseBase,
CaseExamItem,
Department,
KnowledgeChunk,
KnowledgeDocument,
KnowledgeSource,
PromptTemplate,
ScoringRule,
TeachingCase,
TraditionalCase,
User,
)
def init_database() -> None:
"""数据库初始化:创建当前新表体系并写入第一版 Demo 种子数据。"""
Base.metadata.create_all(bind=engine)
with SessionLocal() as db:
seed_demo_data(db)
db.commit()
def seed_demo_data(db) -> None:
"""病例导入:写入儿科支气管肺炎病例、检查项目、评分规则和提示词元数据。"""
department = _get_or_create_department(db)
user = _get_or_create_seed_user(db)
case = _get_or_create_case_base(db, department.id, user.id)
_seed_traditional_case(db, case.id)
_seed_teaching_case(db, case.id)
_seed_exam_items(db, case.id)
_seed_scoring_rules(db, case.id)
_seed_knowledge(db, department.id)
_seed_prompts(db)
def _get_or_create_department(db) -> Department:
"""科室种子:写入儿科科室。"""
department = db.scalar(select(Department).where(Department.code == "PEDIATRICS"))
if department:
return department
department = Department(name="儿科", code="PEDIATRICS", sort_order=1, is_active=True)
db.add(department)
db.flush()
return department
def _get_or_create_seed_user(db) -> User:
"""用户占位:写入系统种子用户,不承担登录职责。"""
user = db.scalar(select(User).where(User.external_user_id == "system_seed"))
if user:
return user
user = User(external_user_id="system_seed", display_name="系统种子数据")
db.add(user)
db.flush()
return user
def _get_or_create_case_base(db, department_id: int, user_id: int) -> CaseBase:
"""病例主表种子:以 case_base 作为病例唯一主表。"""
case = db.scalar(select(CaseBase).where(CaseBase.title == "支气管肺炎 - 6岁男性患儿"))
if case:
return case
case = CaseBase(
title="支气管肺炎 - 6岁男性患儿",
case_type="diagnosis_treatment",
difficulty="medium",
difficulty_score=2,
chief_complaint="发热、咳嗽4天,喘息1天",
description=(
"患儿4天前无明显诱因出现发热,最高体温39.2℃,伴阵发性咳嗽,后有少量白色黏痰。"
"1天前出现喘息,夜间明显,活动后加重。精神较差,食欲下降,小便略少。"
),
patient_age=6,
patient_gender="male",
tags="pediatrics,pneumonia,demo",
symptom_tags=["发热", "咳嗽", "喘息", "精神食纳差"],
disease_tags=["支气管肺炎"],
competency_tags=["问诊完整性", "儿科查体规范", "关键症状识别", "诊断准确性", "治疗计划合理性"],
guideline_tags=["CAP_2019", "HUMANISTIC_CARE"],
knowledge_points=["血常规", "CRP", "胸片", "血氧饱和度", "肺部湿啰音"],
icd_codes="",
estimated_minutes=20,
osce_enabled=False,
rag_enabled=True,
ai_prompt_template="app/prompts/patient/practice.md",
multimodal_assets=[],
vector_status=0,
publish_status=1,
status=1,
created_by_id=user_id,
department_id=department_id,
)
db.add(case)
db.flush()
return case
def _seed_traditional_case(db, case_id: int) -> None:
"""传统病例种子:练习模式读取 case_base + traditional_case。"""
if db.scalar(select(TraditionalCase).where(TraditionalCase.case_id == case_id)):
return
db.add(
TraditionalCase(
case_id=case_id,
standard_diagnosis="支气管肺炎",
standard_treatment=(
"抗感染、止咳平喘、改善氧合、严密观察病情变化;必要时雾化吸入缓解喘息,"
"监测体温、呼吸、血氧、精神反应和饮水尿量,出现低氧或呼吸困难加重时及时升级处理。"
),
guideline_reference=(
"诊断依据:发热、咳嗽、喘息,肺部湿啰音,炎症指标升高,胸片提示右下肺片状模糊影,"
"符合儿童社区获得性肺炎/支气管肺炎诊断思路。严重程度需结合呼吸频率、SpO2、意识、循环和进食饮水情况。"
),
)
)
def _seed_teaching_case(db, case_id: int) -> None:
"""教学互动病例种子:教学互动模式读取 case_base + teaching_case。"""
if db.scalar(select(TeachingCase).where(TeachingCase.case_id == case_id)):
return
db.add(
TeachingCase(
case_id=case_id,
teaching_goal="围绕儿科肺炎问诊、检查选择、诊断依据、治疗决策和医患沟通完成互动训练。",
discussion_questions="如何判断病情严重程度?哪些检查是关键检查?治疗方案如何兼顾抗感染、平喘和氧合监测?",
teacher_guide="观察学生是否完整追问发热、咳嗽、喘息、既往史、接触史,并能解释胸片、炎症指标和血氧。",
scoring_focus="问诊完整性、检查合理性、诊断准确性、治疗计划、风险预案、人文沟通。",
)
)
def _seed_exam_items(db, case_id: int) -> None:
"""检查项目种子:写入病例可申请检查和固定返回结果。"""
if db.scalar(select(CaseExamItem).where(CaseExamItem.case_id == case_id)):
return
items = [
("blood_routine", "血常规", "lab", "WBC 12.5×10^9/L,中性粒细胞比例72%,提示感染及炎症反应。", {"wbc": "12.5×10^9/L", "neutrophil": "72%"}, True, True, 1),
("crp", "CRP", "lab", "CRP 28 mg/L,提示炎症反应升高。", {"crp": "28 mg/L"}, True, True, 2),
("chest_xray", "胸片", "imaging", "双下肺纹理增多,右下肺片状模糊影,支持肺部感染。", {"finding": "右下肺片状模糊影"}, True, True, 3),
("spo2", "血氧饱和度", "vital_sign", "室内空气 SpO2 94%,处于临界偏低范围。", {"spo2": "94%"}, True, True, 4),
("mp_igm", "肺炎支原体IgM", "lab", "肺炎支原体IgM阴性。", {"mp_igm": "negative"}, False, False, 5),
]
for code, name, item_type, result_text, structured, is_key, abnormal, order in items:
db.add(
CaseExamItem(
case_id=case_id,
item_code=code,
item_name=name,
item_type=item_type,
category=item_type,
result_text=result_text,
result_structured=structured,
is_key=is_key,
is_abnormal=abnormal,
score_weight=5.0 if is_key else 1.0,
display_order=order,
)
)
def _seed_scoring_rules(db, case_id: int) -> None:
"""评分规则种子:写入 scoring_rule,评价时作为基础评分细则。"""
if db.scalar(select(ScoringRule).where(ScoringRule.case_id == case_id)):
return
rules = [
("信息获取", "问诊完整性", 25, "覆盖现病史、既往史、个人史、家族史、儿科特异性症状与家属担忧。"),
("分析推理", "诊断与鉴别诊断", 25, "结合症状、体征、胸片、炎症指标和血氧支持支气管肺炎诊断,并列出合理鉴别诊断。"),
("处置决策", "检查与治疗方案", 20, "检查申请合理,治疗原则覆盖抗感染、止咳平喘、改善氧合、风险预案和随访。"),
("沟通人文", "家属沟通", 15, "向家属说明病情、用药注意事项、危险信号、复诊或住院指征,并回应焦虑。"),
("临床整合", "流程与整体思维", 15, "流程连贯,把问诊、检查、诊断、治疗和沟通整合成完整临床决策。"),
]
for dimension, competency, weight, standard in rules:
db.add(
ScoringRule(
case_id=case_id,
dimension=dimension,
competency_dimension=competency,
score_weight=weight,
ai_auto_score=True,
osce_dimension=False,
scoring_standard=standard,
rubric_json={"max_score": weight, "criteria": standard},
)
)
def _seed_knowledge(db, department_id: int) -> None:
"""知识库种子:写入评分参考指南和人文沟通片段。"""
if db.scalar(select(KnowledgeSource).where(KnowledgeSource.source_code == "CAP_2019")):
return
source = KnowledgeSource(
source_code="CAP_2019",
source_name="儿童社区获得性肺炎诊疗规范(2019年版)",
source_type="clinical_guideline",
authority_level=5,
is_active=True,
)
human = KnowledgeSource(
source_code="HUMANISTIC_CARE",
source_name="问诊沟通与人文关怀要求",
source_type="humanistic_care",
authority_level=4,
is_active=True,
)
db.add_all([source, human])
db.flush()
doc = KnowledgeDocument(
source_id=source.id,
department_id=department_id,
title="儿童社区获得性肺炎诊疗规范摘要",
task_type="diagnosis_treatment",
summary="用于肺炎病例诊断、严重程度评估和治疗评分参考。",
file_path="docs/knowledge/cap_2019.md",
is_active=True,
)
human_doc = KnowledgeDocument(
source_id=human.id,
department_id=department_id,
title="儿科问诊人文关怀要点",
task_type="diagnosis_treatment",
summary="用于评价家属沟通、知情告知和健康教育。",
file_path="docs/knowledge/humanistic_care.md",
is_active=True,
)
db.add_all([doc, human_doc])
db.flush()
db.add_all(
[
KnowledgeChunk(
document_id=doc.id,
department_id=department_id,
task_type="diagnosis_treatment",
chunk_text="儿童肺炎诊断需综合发热、咳嗽、喘息、肺部湿啰音、炎症指标和胸部影像学新发浸润影。",
keywords=["发热", "咳嗽", "喘息", "胸片异常"],
weight=5.0,
is_active=True,
),
KnowledgeChunk(
document_id=doc.id,
department_id=department_id,
task_type="diagnosis_treatment",
chunk_text="严重程度评估应关注呼吸频率、血氧饱和度、意识状态、循环状态和进食饮水情况。",
keywords=["血氧饱和度下降", "呼吸", "严重程度"],
weight=4.5,
is_active=True,
),
KnowledgeChunk(
document_id=human_doc.id,
department_id=department_id,
task_type="diagnosis_treatment",
chunk_text="儿科问诊需要向家属说明病情观察指标、用药注意事项、复诊指征,并给予情绪安抚。",
keywords=["沟通", "健康教育", "家属"],
weight=4.0,
is_active=True,
),
]
)
def _seed_prompts(db) -> None:
"""提示词种子:写入 Markdown 模板元数据,正文保存在 prompts 目录。"""
templates = [
("patient_practice", "patient", "practice", "v1", "fast", "text", "app/prompts/patient/practice.md"),
("patient_teaching", "patient", "teaching", "v1", "fast", "text", "app/prompts/patient/teaching.md"),
("novice_case_hint", "hint", "novice", "v1", "fast", "json", "app/prompts/hint/novice_case_hint.md"),
("scoring_pediatrics_pneumonia", "scoring", "pediatrics_pneumonia", "v1", "fast", "json", "app/prompts/scoring/pediatrics_pneumonia.md"),
("report_evaluation", "report", "evaluation", "v1", "fast", "json", "app/prompts/report/evaluation_report.md"),
]
for code, agent_type, scene, version, model_type, output_format, file_path in templates:
template = db.scalar(select(PromptTemplate).where(PromptTemplate.template_code == code, PromptTemplate.version_no == version))
if not template:
template = PromptTemplate(template_code=code, version_no=version)
template.agent_type = agent_type
template.scene = scene
template.model_type = model_type
template.output_format = output_format
template.file_path = file_path
template.is_active = True
db.add(template)
def ensure_storage_dirs() -> None:
"""目录初始化:创建报告导出目录。"""
Path("storage/reports").mkdir(parents=True, exist_ok=True)
if __name__ == "__main__":
ensure_storage_dirs()
init_database()
print("Demo database initialized.")
+45
View File
@@ -0,0 +1,45 @@
from __future__ import annotations
import sys
from pathlib import Path
from sqlalchemy import text
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
from app.db.session import SessionLocal
from scripts.init_demo_db import init_database
def main() -> None:
"""新表迁移:创建并补齐 case_base、traditional_case、teaching_case、case_exam_item、training_* 和 training_record。"""
init_database()
with SessionLocal() as db:
_apply_table_comments(db)
db.commit()
print("new schema migration completed")
def _apply_table_comments(db) -> None:
"""表注释补齐:为当前业务表写入中文说明,便于数据库工具查看。"""
comments = {
"case_base": "病例主表",
"traditional_case": "传统病例扩展表",
"teaching_case": "教学互动病例扩展表",
"scoring_rule": "评分规则表",
"case_exam_item": "病例检查检验项目表",
"training_session": "训练会话表",
"training_order": "训练检查申请表",
"training_submission": "训练诊断治疗提交表",
"training_record": "训练记录表",
}
dialect = db.bind.dialect.name if db.bind else ""
if dialect != "mysql":
return
for table_name, comment in comments.items():
safe_comment = comment.replace("'", "''")
db.execute(text(f"ALTER TABLE `{table_name}` COMMENT='{safe_comment}'"))
if __name__ == "__main__":
main()