Files
fastapi/app/agents/learning_assistant_agent.py
T
2026-06-11 16:19:50 +08:00

80 lines
3.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from collections.abc import AsyncIterator
from app.agents.llm_adapter import LLMStreamChunk, OpenAICompatibleLLMClient
from app.core.config import settings
from app.schemas.learning_assistant import LearningAssistantSource
class LearningAssistantAgent:
"""AI 学习助手 Agent:根据 RAG 来源和短期上下文生成带循证出处的医学学习回答。"""
def __init__(self, llm_client: OpenAICompatibleLLMClient | None = None) -> None:
self.llm_client = llm_client or OpenAICompatibleLLMClient()
async def stream_answer(
self,
question: str,
sources: list[LearningAssistantSource],
history: list[dict] | None = None,
) -> AsyncIterator[LLMStreamChunk]:
"""流式回答:输出 AI 学习助手增量文本,前端可直接渲染。"""
async for chunk in self.llm_client.stream_chat(
self._messages(question, sources, history or []),
model=settings.llm_fast_model,
thinking_enabled=settings.llm_fast_thinking_enabled,
max_tokens=1200,
):
yield chunk
def _messages(self, question: str, sources: list[LearningAssistantSource], history: list[dict]) -> list[dict]:
"""提示词拼接:命中知识库时强制引用来源,未命中时必须声明未找到机构参考。"""
history_text = self._history_text(history)
if sources:
context = "\n\n".join(
(
f"[来源{index}] 文档:{source.document_title or source.file_name}"
f"页码:{source.page_start}-{source.page_end}chunk_uid{source.chunk_uid}\n"
f"{source.quote}"
)
for index, source in enumerate(sources, start=1)
)
system = (
"你是医学学习助手,用于医学教育、课程学习和临床思维训练,不替代临床诊疗。"
"优先依据给定知识库片段回答,回答要清晰、准确、分点。"
"每个关键结论后标注对应来源编号,例如【来源1】。"
"不得编造不存在的 PDF、页码或指南来源。"
)
user = (
f"{history_text}"
f"用户当前问题:{question}\n\n"
f"可用知识库片段:\n{context}\n\n"
"请给出带来源的学习回答。"
)
else:
system = (
"你是医学学习助手,用于医学教育、课程学习和临床思维训练,不替代临床诊疗。"
"当前没有检索到机构知识库参考,回答开头必须写:"
"未检索到本机构知识库参考,以下为大模型通用学习回答。"
"不得伪造 PDF 来源、页码或指南名称。"
)
user = (
f"{history_text}"
f"用户当前问题:{question}\n\n"
"请给出通用学习回答,并提醒用户以课程教材、指南和临床医生判断为准。"
)
return [{"role": "system", "content": system}, {"role": "user", "content": user}]
def _history_text(self, history: list[dict]) -> str:
"""上下文摘要:把当前学习助手会话最近几轮问答压缩为提示词上下文。"""
if not history:
return ""
lines: list[str] = []
for item in history[-settings.learning_assistant_history_limit :]:
role = "用户" if item.get("role") == "user" else "助手"
content = str(item.get("content") or "").strip()
if content:
lines.append(f"{role}{content[:500]}")
if not lines:
return ""
return "当前会话最近上下文:\n" + "\n".join(lines) + "\n\n"