159 lines
7.2 KiB
Python
159 lines
7.2 KiB
Python
|
|
import json
|
||
|
|
from pathlib import Path
|
||
|
|
from typing import Any
|
||
|
|
|
||
|
|
from app.agents.llm_adapter import DeepSeekClient
|
||
|
|
from app.core.config import settings
|
||
|
|
from app.models.source_case import CaseBase
|
||
|
|
from app.models.training import SessionOrder, TrainingSession
|
||
|
|
|
||
|
|
|
||
|
|
class HintAgent:
|
||
|
|
"""新手提示 Agent:基于病例、对话和检查结果调用快速模型生成结构化提示。"""
|
||
|
|
|
||
|
|
def __init__(self, llm: DeepSeekClient | None = None) -> None:
|
||
|
|
self.llm = llm or DeepSeekClient()
|
||
|
|
self.template_path = Path(__file__).resolve().parents[1] / "prompts" / "hint" / "novice_case_hint.md"
|
||
|
|
|
||
|
|
async def generate(
|
||
|
|
self,
|
||
|
|
session: TrainingSession,
|
||
|
|
case: CaseBase,
|
||
|
|
memory_messages: list[dict],
|
||
|
|
orders: list[SessionOrder],
|
||
|
|
last_user_message: str | None = None,
|
||
|
|
) -> dict:
|
||
|
|
"""LLM 提示:标准化输入病例上下文并要求模型返回固定 JSON 结构。"""
|
||
|
|
payload = self._build_input(session, case, memory_messages, orders, last_user_message)
|
||
|
|
messages = [
|
||
|
|
{"role": "system", "content": self._load_template()},
|
||
|
|
{"role": "user", "content": json.dumps(payload, ensure_ascii=False)},
|
||
|
|
]
|
||
|
|
try:
|
||
|
|
response = await self.llm.chat(
|
||
|
|
messages,
|
||
|
|
settings.llm_fast_model,
|
||
|
|
thinking_enabled=settings.llm_fast_thinking_enabled,
|
||
|
|
response_format={"type": "json_object"},
|
||
|
|
max_tokens=settings.llm_hint_max_tokens,
|
||
|
|
)
|
||
|
|
data = json.loads(response.content)
|
||
|
|
return self._normalize_output(data, payload)
|
||
|
|
except Exception:
|
||
|
|
return self._fallback_output(payload)
|
||
|
|
|
||
|
|
def _build_input(
|
||
|
|
self,
|
||
|
|
session: TrainingSession,
|
||
|
|
case: CaseBase,
|
||
|
|
memory_messages: list[dict],
|
||
|
|
orders: list[SessionOrder],
|
||
|
|
last_user_message: str | None,
|
||
|
|
) -> dict:
|
||
|
|
"""输入构造:把病例、会话、对话摘要和已申请检查整理为稳定 JSON。"""
|
||
|
|
return {
|
||
|
|
"case": {
|
||
|
|
"case_id": case.id,
|
||
|
|
"department": case.department.name if getattr(case, "department", None) else "",
|
||
|
|
"title": case.title,
|
||
|
|
"chief_complaint": case.chief_complaint,
|
||
|
|
"key_symptoms": case.key_symptoms or [],
|
||
|
|
"key_exams": case.key_exams or [],
|
||
|
|
"key_points": case.key_points or [],
|
||
|
|
},
|
||
|
|
"session": {
|
||
|
|
"mode": session.mode,
|
||
|
|
"status": session.status,
|
||
|
|
},
|
||
|
|
"conversation_summary": [
|
||
|
|
{"role": item.get("role"), "content": str(item.get("content", ""))[:240]}
|
||
|
|
for item in memory_messages[-12:]
|
||
|
|
if item.get("content")
|
||
|
|
],
|
||
|
|
"ordered_results": [
|
||
|
|
{
|
||
|
|
"item_code": order.item_code,
|
||
|
|
"item_name": order.item_name,
|
||
|
|
"item_type": order.item_type,
|
||
|
|
"result_text": order.result_text,
|
||
|
|
"is_key": order.is_key,
|
||
|
|
"is_abnormal": order.is_abnormal,
|
||
|
|
}
|
||
|
|
for order in orders
|
||
|
|
],
|
||
|
|
"last_user_message": last_user_message or "",
|
||
|
|
}
|
||
|
|
|
||
|
|
def _load_template(self) -> str:
|
||
|
|
"""提示词读取:加载新手模式病例提示模板。"""
|
||
|
|
if self.template_path.exists():
|
||
|
|
return self.template_path.read_text(encoding="utf-8")
|
||
|
|
return "你是医疗问诊训练提示 Agent,只输出合法 JSON。"
|
||
|
|
|
||
|
|
def _normalize_output(self, data: Any, payload: dict) -> dict:
|
||
|
|
"""输出校验:确保 LLM 返回结构稳定,不把原始文本透传给前端。"""
|
||
|
|
if not isinstance(data, dict):
|
||
|
|
return self._fallback_output(payload)
|
||
|
|
normalized = {
|
||
|
|
"hints": self._clean_str_list(data.get("hints"))[:4],
|
||
|
|
"missing_dimensions": self._clean_str_list(data.get("missing_dimensions"))[:6],
|
||
|
|
"next_questions": self._clean_str_list(data.get("next_questions"))[:5],
|
||
|
|
"recommended_orders": self._clean_orders(data.get("recommended_orders"))[:5],
|
||
|
|
}
|
||
|
|
if not any(normalized.values()):
|
||
|
|
return self._fallback_output(payload)
|
||
|
|
return normalized
|
||
|
|
|
||
|
|
def _fallback_output(self, payload: dict) -> dict:
|
||
|
|
"""提示兜底:LLM 异常或 JSON 不合法时按病例关键点生成稳定提示。"""
|
||
|
|
case = payload.get("case", {})
|
||
|
|
ordered_codes = {item.get("item_code") for item in payload.get("ordered_results", [])}
|
||
|
|
key_exams = case.get("key_exams") or []
|
||
|
|
recommended_orders = []
|
||
|
|
if "blood_routine" not in ordered_codes:
|
||
|
|
recommended_orders.append({"item_code": "blood_routine", "reason": "用于初步判断感染及炎症反应"})
|
||
|
|
if "crp" not in ordered_codes:
|
||
|
|
recommended_orders.append({"item_code": "crp", "reason": "用于辅助判断炎症程度"})
|
||
|
|
if "chest_xray" not in ordered_codes:
|
||
|
|
recommended_orders.append({"item_code": "chest_xray", "reason": "用于判断肺部感染影像学证据"})
|
||
|
|
if "spo2" not in ordered_codes:
|
||
|
|
recommended_orders.append({"item_code": "spo2", "reason": "用于判断氧合和病情严重程度"})
|
||
|
|
return {
|
||
|
|
"hints": [
|
||
|
|
f"本病例主诉为{case.get('chief_complaint') or '当前症状'},问诊要围绕起病时间、症状演变和严重程度展开。",
|
||
|
|
"当前提示来自病例关键症状、关键检查和已完成对话的结构化兜底分析。",
|
||
|
|
"获得检查结果后,需要把异常结果用于诊断依据和病情严重程度判断。",
|
||
|
|
],
|
||
|
|
"missing_dimensions": ["既往史", "严重程度评估", "人文沟通"],
|
||
|
|
"next_questions": [
|
||
|
|
"孩子最高体温多少?退烧药后能不能降下来?",
|
||
|
|
"有没有喘息、气促、口唇发绀或呼吸困难?",
|
||
|
|
"以前有没有喘息、哮喘、湿疹或药物过敏史?",
|
||
|
|
"精神、食欲、饮水和尿量怎么样?",
|
||
|
|
"家属现在最担心什么?",
|
||
|
|
],
|
||
|
|
"recommended_orders": recommended_orders or [
|
||
|
|
{"item_code": str(item), "reason": "病例关键检查,需要结合结果完善诊断依据"} for item in key_exams[:3]
|
||
|
|
],
|
||
|
|
}
|
||
|
|
|
||
|
|
def _clean_str_list(self, value: Any) -> list[str]:
|
||
|
|
"""字段清洗:把模型返回的数组字段压成字符串列表。"""
|
||
|
|
if not isinstance(value, list):
|
||
|
|
return []
|
||
|
|
return [str(item).strip() for item in value if str(item).strip()]
|
||
|
|
|
||
|
|
def _clean_orders(self, value: Any) -> list[dict]:
|
||
|
|
"""推荐检查清洗:只保留 item_code 和 reason 两个前端需要的字段。"""
|
||
|
|
if not isinstance(value, list):
|
||
|
|
return []
|
||
|
|
cleaned = []
|
||
|
|
for item in value:
|
||
|
|
if not isinstance(item, dict):
|
||
|
|
continue
|
||
|
|
code = str(item.get("item_code", "")).strip()
|
||
|
|
reason = str(item.get("reason", "")).strip()
|
||
|
|
if code:
|
||
|
|
cleaned.append({"item_code": code, "reason": reason})
|
||
|
|
return cleaned
|