add training configuration APIs
This commit is contained in:
@@ -20,7 +20,7 @@ class MedicalConsultationOrchestrator:
|
||||
|
||||
async def patient_reply(self, session: TrainingSession, case: CaseBase, memory_messages: list[dict], message: str) -> LLMResponse:
|
||||
"""问诊编排:调用 Patient Agent 生成 AI 病人回复。"""
|
||||
return await self.patient_agent.reply(case, memory_messages, message, session.mode)
|
||||
return await self.patient_agent.reply(case, memory_messages, message, session.mode, self._patient_config(session))
|
||||
|
||||
async def patient_stream_reply(
|
||||
self,
|
||||
@@ -30,7 +30,7 @@ class MedicalConsultationOrchestrator:
|
||||
message: str,
|
||||
) -> AsyncIterator[LLMStreamChunk]:
|
||||
"""流式问诊编排:调用 Patient Agent 并返回流式片段。"""
|
||||
async for chunk in self.patient_agent.stream_reply(case, memory_messages, message, session.mode):
|
||||
async for chunk in self.patient_agent.stream_reply(case, memory_messages, message, session.mode, self._patient_config(session)):
|
||||
yield chunk
|
||||
|
||||
async def evaluate(
|
||||
@@ -67,3 +67,9 @@ class MedicalConsultationOrchestrator:
|
||||
) -> dict:
|
||||
"""新手提示编排:基于当前会话上下文生成轻量训练提醒。"""
|
||||
return await self.hint_agent.generate(session, case, memory_messages, orders, last_user_message)
|
||||
|
||||
def _patient_config(self, session: TrainingSession) -> dict | None:
|
||||
"""病人配置:从会话 metadata 读取训练页初始化配置,传递给 Patient Agent。"""
|
||||
metadata = session.metadata_ or {}
|
||||
patient_config = metadata.get("patient_config") if isinstance(metadata, dict) else None
|
||||
return patient_config if isinstance(patient_config, dict) else None
|
||||
|
||||
@@ -11,9 +11,16 @@ class PatientAgent:
|
||||
def __init__(self, llm: DeepSeekClient | None = None) -> None:
|
||||
self.llm = llm or DeepSeekClient()
|
||||
|
||||
async def reply(self, case: CaseBase, memory_messages: list[dict], user_message: str, mode: str) -> LLMResponse:
|
||||
async def reply(
|
||||
self,
|
||||
case: CaseBase,
|
||||
memory_messages: list[dict],
|
||||
user_message: str,
|
||||
mode: str,
|
||||
patient_config: dict | None = None,
|
||||
) -> LLMResponse:
|
||||
"""问诊回复:拼接病例上下文、短期记忆和用户输入后调用 Patient Agent。"""
|
||||
messages = self._build_messages(case, memory_messages, user_message, mode)
|
||||
messages = self._build_messages(case, memory_messages, user_message, mode, patient_config)
|
||||
return await self.llm.chat(
|
||||
messages,
|
||||
settings.llm_fast_model,
|
||||
@@ -27,9 +34,10 @@ class PatientAgent:
|
||||
memory_messages: list[dict],
|
||||
user_message: str,
|
||||
mode: str,
|
||||
patient_config: dict | None = None,
|
||||
) -> AsyncIterator[LLMStreamChunk]:
|
||||
"""流式问诊:以 SSE 方式返回 AI 病人增量回复。"""
|
||||
messages = self._build_messages(case, memory_messages, user_message, mode)
|
||||
messages = self._build_messages(case, memory_messages, user_message, mode, patient_config)
|
||||
async for chunk in self.llm.stream_chat(
|
||||
messages,
|
||||
settings.llm_fast_model,
|
||||
@@ -38,10 +46,18 @@ class PatientAgent:
|
||||
):
|
||||
yield chunk
|
||||
|
||||
def _build_messages(self, case: CaseBase, memory_messages: list[dict], user_message: str, mode: str) -> list[dict]:
|
||||
def _build_messages(
|
||||
self,
|
||||
case: CaseBase,
|
||||
memory_messages: list[dict],
|
||||
user_message: str,
|
||||
mode: str,
|
||||
patient_config: dict | None = None,
|
||||
) -> list[dict]:
|
||||
"""提示词拼接:构造 AI 病人的系统提示词和对话历史。"""
|
||||
profile = case.ai_patient_profile or {}
|
||||
hidden_info = case.hidden_patient_info or {}
|
||||
config_rule = self._build_patient_config_rule(patient_config)
|
||||
mode_rule = {
|
||||
"novice": "新手模式:回答清楚,必要时可提示医生继续追问症状、既往史或检查。",
|
||||
"practice": "练习模式:只回答被问到的信息,不主动给诊断建议。",
|
||||
@@ -52,6 +68,7 @@ class PatientAgent:
|
||||
病例主诉:{case.chief_complaint}
|
||||
患者人设:{profile}
|
||||
隐藏信息:{hidden_info}
|
||||
病人初始化配置:{config_rule}
|
||||
回答规则:
|
||||
1. 不主动透露未被问到的隐藏信息。
|
||||
2. 不替医生做诊断,不提供治疗方案。
|
||||
@@ -66,6 +83,21 @@ class PatientAgent:
|
||||
messages.append({"role": "user", "content": user_message})
|
||||
return messages
|
||||
|
||||
def _build_patient_config_rule(self, patient_config: dict | None) -> str:
|
||||
"""配置提示:把训练页初始化配置转成 AI 病人表达约束。"""
|
||||
if not patient_config:
|
||||
return "使用默认门诊、青年、高等教育、平和性格的表达方式。"
|
||||
labels = patient_config.get("labels") if isinstance(patient_config, dict) else None
|
||||
values = labels or (patient_config.get("values") if isinstance(patient_config, dict) else {}) or {}
|
||||
visit_environment = values.get("visit_environment", "门诊")
|
||||
age_group = values.get("age_group", "青年")
|
||||
education_level = values.get("education_level", "高等教育")
|
||||
personality = values.get("personality", "平和")
|
||||
return (
|
||||
f"就诊环境={visit_environment};年龄段={age_group};文化程度={education_level};性格={personality}。"
|
||||
"回答时根据性格调整情绪和配合度,根据文化程度调整表达清晰度,但不得改变病例事实。"
|
||||
)
|
||||
|
||||
def _to_llm_history(self, memory_messages: list[dict]) -> list[dict]:
|
||||
"""历史转换:把业务角色 doctor/patient 转换为 LLM role。"""
|
||||
role_map = {"doctor": "user", "patient": "assistant", "system": "system", "tool": "assistant"}
|
||||
|
||||
Reference in New Issue
Block a user