Files
fastapi/app/agents/patient_agent.py
T

77 lines
3.6 KiB
Python
Raw Normal View History

from collections.abc import AsyncIterator
from app.agents.llm_adapter import DeepSeekClient, LLMResponse, LLMStreamChunk
from app.core.config import settings
from app.models.source_case import CaseBase
class PatientAgent:
"""AI 病人:根据病例资料、隐藏信息和短期 memory 回复医生问诊。"""
def __init__(self, llm: DeepSeekClient | None = None) -> None:
self.llm = llm or DeepSeekClient()
async def reply(self, case: CaseBase, memory_messages: list[dict], user_message: str, mode: str) -> LLMResponse:
"""问诊回复:拼接病例上下文、短期记忆和用户输入后调用 Patient Agent。"""
messages = self._build_messages(case, memory_messages, user_message, mode)
return await self.llm.chat(
messages,
settings.llm_fast_model,
thinking_enabled=settings.llm_fast_thinking_enabled,
max_tokens=settings.llm_fast_max_tokens,
)
async def stream_reply(
self,
case: CaseBase,
memory_messages: list[dict],
user_message: str,
mode: str,
) -> AsyncIterator[LLMStreamChunk]:
"""流式问诊:以 SSE 方式返回 AI 病人增量回复。"""
messages = self._build_messages(case, memory_messages, user_message, mode)
async for chunk in self.llm.stream_chat(
messages,
settings.llm_fast_model,
thinking_enabled=settings.llm_fast_thinking_enabled,
max_tokens=settings.llm_fast_max_tokens,
):
yield chunk
def _build_messages(self, case: CaseBase, memory_messages: list[dict], user_message: str, mode: str) -> list[dict]:
"""提示词拼接:构造 AI 病人的系统提示词和对话历史。"""
profile = case.ai_patient_profile or {}
hidden_info = case.hidden_patient_info or {}
mode_rule = {
"novice": "新手模式:回答清楚,必要时可提示医生继续追问症状、既往史或检查。",
"practice": "练习模式:只回答被问到的信息,不主动给诊断建议。",
"teaching": "教学模式:保持患者身份,允许在回答后补充简短学习提示。",
}.get(mode, "只回答被问到的信息。")
system = f"""
你是一名标准化 AI 病人或患儿家属,只能基于病例资料回答。
病例主诉:{case.chief_complaint}
患者人设:{profile}
隐藏信息:{hidden_info}
回答规则:
1. 不主动透露未被问到的隐藏信息。
2. 不替医生做诊断,不提供治疗方案。
3. 不编造病例外检查检验结果。
4. 每次回答控制在1到3句话,使用患儿家属口吻,不输出分析过程。
5. 只输出给医生看的家属回答纯文本,不输出 JSON、Markdown、标题、解释或思考过程。
6. 如果医生一次问多个问题,按问题顺序简短回答,不扩展病例外信息。
7. {mode_rule}
"""
messages = [{"role": "system", "content": system.strip()}]
messages.extend(self._to_llm_history(memory_messages[-12:]))
messages.append({"role": "user", "content": user_message})
return messages
def _to_llm_history(self, memory_messages: list[dict]) -> list[dict]:
"""历史转换:把业务角色 doctor/patient 转换为 LLM role。"""
role_map = {"doctor": "user", "patient": "assistant", "system": "system", "tool": "assistant"}
return [
{"role": role_map.get(item.get("role"), "user"), "content": item.get("content", "")}
for item in memory_messages
if item.get("content")
]