import json from pathlib import Path from typing import Any from app.agents.llm_adapter import DeepSeekClient from app.core.config import settings from app.models.source_case import CaseBase from app.models.training import SessionOrder, TrainingSession class HintAgent: """新手提示 Agent:基于病例、对话和检查结果调用快速模型生成结构化提示。""" def __init__(self, llm: DeepSeekClient | None = None) -> None: self.llm = llm or DeepSeekClient() self.template_path = Path(__file__).resolve().parents[1] / "prompts" / "hint" / "novice_case_hint.md" async def generate( self, session: TrainingSession, case: CaseBase, memory_messages: list[dict], orders: list[SessionOrder], last_user_message: str | None = None, ) -> dict: """LLM 提示:标准化输入病例上下文并要求模型返回固定 JSON 结构。""" payload = self._build_input(session, case, memory_messages, orders, last_user_message) messages = [ {"role": "system", "content": self._load_template()}, {"role": "user", "content": json.dumps(payload, ensure_ascii=False)}, ] try: response = await self.llm.chat( messages, settings.llm_fast_model, thinking_enabled=settings.llm_fast_thinking_enabled, response_format={"type": "json_object"}, max_tokens=settings.llm_hint_max_tokens, ) data = json.loads(response.content) return self._normalize_output(data, payload) except Exception: return self._fallback_output(payload) def _build_input( self, session: TrainingSession, case: CaseBase, memory_messages: list[dict], orders: list[SessionOrder], last_user_message: str | None, ) -> dict: """输入构造:把病例、会话、对话摘要和已申请检查整理为稳定 JSON。""" return { "case": { "case_id": case.id, "department": case.department.name if getattr(case, "department", None) else "", "title": case.title, "chief_complaint": case.chief_complaint, "key_symptoms": case.key_symptoms or [], "key_exams": case.key_exams or [], "key_points": case.key_points or [], }, "session": { "mode": session.mode, "status": session.status, }, "conversation_summary": [ {"role": item.get("role"), "content": str(item.get("content", ""))[:240]} for item in memory_messages[-12:] if item.get("content") ], "ordered_results": [ { "item_code": order.item_code, "item_name": order.item_name, "item_type": order.item_type, "result_text": order.result_text, "is_key": order.is_key, "is_abnormal": order.is_abnormal, } for order in orders ], "last_user_message": last_user_message or "", } def _load_template(self) -> str: """提示词读取:加载新手模式病例提示模板。""" if self.template_path.exists(): return self.template_path.read_text(encoding="utf-8") return "你是医疗问诊训练提示 Agent,只输出合法 JSON。" def _normalize_output(self, data: Any, payload: dict) -> dict: """输出校验:确保 LLM 返回结构稳定,不把原始文本透传给前端。""" if not isinstance(data, dict): return self._fallback_output(payload) normalized = { "hints": self._clean_str_list(data.get("hints"))[:4], "missing_dimensions": self._clean_str_list(data.get("missing_dimensions"))[:6], "next_questions": self._clean_str_list(data.get("next_questions"))[:5], "recommended_orders": self._clean_orders(data.get("recommended_orders"))[:5], } if not any(normalized.values()): return self._fallback_output(payload) return normalized def _fallback_output(self, payload: dict) -> dict: """提示兜底:LLM 异常或 JSON 不合法时按病例关键点生成稳定提示。""" case = payload.get("case", {}) ordered_codes = {item.get("item_code") for item in payload.get("ordered_results", [])} key_exams = case.get("key_exams") or [] recommended_orders = [] if "blood_routine" not in ordered_codes: recommended_orders.append({"item_code": "blood_routine", "reason": "用于初步判断感染及炎症反应"}) if "crp" not in ordered_codes: recommended_orders.append({"item_code": "crp", "reason": "用于辅助判断炎症程度"}) if "chest_xray" not in ordered_codes: recommended_orders.append({"item_code": "chest_xray", "reason": "用于判断肺部感染影像学证据"}) if "spo2" not in ordered_codes: recommended_orders.append({"item_code": "spo2", "reason": "用于判断氧合和病情严重程度"}) return { "hints": [ f"本病例主诉为{case.get('chief_complaint') or '当前症状'},问诊要围绕起病时间、症状演变和严重程度展开。", "当前提示来自病例关键症状、关键检查和已完成对话的结构化兜底分析。", "获得检查结果后,需要把异常结果用于诊断依据和病情严重程度判断。", ], "missing_dimensions": ["既往史", "严重程度评估", "人文沟通"], "next_questions": [ "孩子最高体温多少?退烧药后能不能降下来?", "有没有喘息、气促、口唇发绀或呼吸困难?", "以前有没有喘息、哮喘、湿疹或药物过敏史?", "精神、食欲、饮水和尿量怎么样?", "家属现在最担心什么?", ], "recommended_orders": recommended_orders or [ {"item_code": str(item), "reason": "病例关键检查,需要结合结果完善诊断依据"} for item in key_exams[:3] ], } def _clean_str_list(self, value: Any) -> list[str]: """字段清洗:把模型返回的数组字段压成字符串列表。""" if not isinstance(value, list): return [] return [str(item).strip() for item in value if str(item).strip()] def _clean_orders(self, value: Any) -> list[dict]: """推荐检查清洗:只保留 item_code 和 reason 两个前端需要的字段。""" if not isinstance(value, list): return [] cleaned = [] for item in value: if not isinstance(item, dict): continue code = str(item.get("item_code", "")).strip() reason = str(item.get("reason", "")).strip() if code: cleaned.append({"item_code": code, "reason": reason}) return cleaned