chore: finalize backend feature scope
This commit is contained in:
@@ -9,11 +9,11 @@ from app.models.training import SessionOrder, TrainingSession
|
||||
|
||||
|
||||
class HintAgent:
|
||||
"""新手提示 Agent:基于病例、对话和检查结果调用快速模型生成结构化提示。"""
|
||||
"""练习提示 Agent:基于病例、对话和检查结果调用快速模型生成结构化提示。"""
|
||||
|
||||
def __init__(self, llm: DeepSeekClient | None = None) -> None:
|
||||
self.llm = llm or DeepSeekClient()
|
||||
self.template_path = Path(__file__).resolve().parents[1] / "prompts" / "hint" / "novice_case_hint.md"
|
||||
self.template_path = Path(__file__).resolve().parents[1] / "prompts" / "hint" / "practice_case_hint.md"
|
||||
|
||||
async def generate(
|
||||
self,
|
||||
@@ -85,7 +85,7 @@ class HintAgent:
|
||||
}
|
||||
|
||||
def _load_template(self) -> str:
|
||||
"""提示词读取:加载新手模式病例提示模板。"""
|
||||
"""提示词读取:加载练习模式病例提示模板。"""
|
||||
if self.template_path.exists():
|
||||
return self.template_path.read_text(encoding="utf-8")
|
||||
return "你是医疗问诊训练提示 Agent,只输出合法 JSON。"
|
||||
|
||||
@@ -1,37 +1,34 @@
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
from app.agents.llm_adapter import LLMResponse, LLMStreamChunk, OpenAICompatibleLLMClient
|
||||
from app.agents.llm_adapter import LLMStreamChunk, OpenAICompatibleLLMClient
|
||||
from app.core.config import settings
|
||||
from app.schemas.learning_assistant import LearningAssistantSource
|
||||
|
||||
|
||||
class LearningAssistantAgent:
|
||||
"""AI学习助手 Agent:根据 RAG 来源生成带循证出处的医学学习回答。"""
|
||||
"""AI 学习助手 Agent:根据 RAG 来源和短期上下文生成带循证出处的医学学习回答。"""
|
||||
|
||||
def __init__(self, llm_client: OpenAICompatibleLLMClient | None = None) -> None:
|
||||
self.llm_client = llm_client or OpenAICompatibleLLMClient()
|
||||
|
||||
async def answer(self, question: str, sources: list[LearningAssistantSource]) -> LLMResponse:
|
||||
"""非流式回答:把问题和检索来源拼接后调用快速模型生成标准回答。"""
|
||||
return await self.llm_client.chat(
|
||||
self._messages(question, sources),
|
||||
model=settings.llm_fast_model,
|
||||
thinking_enabled=settings.llm_fast_thinking_enabled,
|
||||
max_tokens=1200,
|
||||
)
|
||||
|
||||
async def stream_answer(self, question: str, sources: list[LearningAssistantSource]) -> AsyncIterator[LLMStreamChunk]:
|
||||
async def stream_answer(
|
||||
self,
|
||||
question: str,
|
||||
sources: list[LearningAssistantSource],
|
||||
history: list[dict] | None = None,
|
||||
) -> AsyncIterator[LLMStreamChunk]:
|
||||
"""流式回答:输出 AI 学习助手增量文本,前端可直接渲染。"""
|
||||
async for chunk in self.llm_client.stream_chat(
|
||||
self._messages(question, sources),
|
||||
self._messages(question, sources, history or []),
|
||||
model=settings.llm_fast_model,
|
||||
thinking_enabled=settings.llm_fast_thinking_enabled,
|
||||
max_tokens=1200,
|
||||
):
|
||||
yield chunk
|
||||
|
||||
def _messages(self, question: str, sources: list[LearningAssistantSource]) -> list[dict]:
|
||||
"""提示词拼接:命中知识库时必须引用来源,未命中时必须声明未找到参考。"""
|
||||
def _messages(self, question: str, sources: list[LearningAssistantSource], history: list[dict]) -> list[dict]:
|
||||
"""提示词拼接:命中知识库时强制引用来源,未命中时必须声明未找到机构参考。"""
|
||||
history_text = self._history_text(history)
|
||||
if sources:
|
||||
context = "\n\n".join(
|
||||
(
|
||||
@@ -42,17 +39,41 @@ class LearningAssistantAgent:
|
||||
for index, source in enumerate(sources, start=1)
|
||||
)
|
||||
system = (
|
||||
"你是医学学习助手,只用于医学教育学习,不替代临床诊疗。"
|
||||
"请优先依据给定知识库片段回答,回答要清晰、准确、分点。"
|
||||
"你是医学学习助手,用于医学教育、课程学习和临床思维训练,不替代临床诊疗。"
|
||||
"优先依据给定知识库片段回答,回答要清晰、准确、分点。"
|
||||
"每个关键结论后标注对应来源编号,例如【来源1】。"
|
||||
"不得编造不存在的PDF、页码或指南来源。"
|
||||
"不得编造不存在的 PDF、页码或指南来源。"
|
||||
)
|
||||
user = (
|
||||
f"{history_text}"
|
||||
f"用户当前问题:{question}\n\n"
|
||||
f"可用知识库片段:\n{context}\n\n"
|
||||
"请给出带来源的学习回答。"
|
||||
)
|
||||
user = f"用户问题:{question}\n\n可用知识库片段:\n{context}\n\n请给出带来源的学习回答。"
|
||||
else:
|
||||
system = (
|
||||
"你是医学学习助手,只用于医学教育学习,不替代临床诊疗。"
|
||||
"当前没有检索到机构知识库参考,回答开头必须写:未检索到本机构知识库参考,以下为大模型通用学习回答。"
|
||||
"不得伪造PDF来源、页码或指南名称。"
|
||||
"你是医学学习助手,用于医学教育、课程学习和临床思维训练,不替代临床诊疗。"
|
||||
"当前没有检索到机构知识库参考,回答开头必须写:"
|
||||
"未检索到本机构知识库参考,以下为大模型通用学习回答。"
|
||||
"不得伪造 PDF 来源、页码或指南名称。"
|
||||
)
|
||||
user = (
|
||||
f"{history_text}"
|
||||
f"用户当前问题:{question}\n\n"
|
||||
"请给出通用学习回答,并提醒用户以课程教材、指南和临床医生判断为准。"
|
||||
)
|
||||
user = f"用户问题:{question}\n\n请给出通用学习回答,并提醒用户以课程教材和临床规范为准。"
|
||||
return [{"role": "system", "content": system}, {"role": "user", "content": user}]
|
||||
|
||||
def _history_text(self, history: list[dict]) -> str:
|
||||
"""上下文摘要:把当前学习助手会话最近几轮问答压缩为提示词上下文。"""
|
||||
if not history:
|
||||
return ""
|
||||
lines: list[str] = []
|
||||
for item in history[-settings.learning_assistant_history_limit :]:
|
||||
role = "用户" if item.get("role") == "user" else "助手"
|
||||
content = str(item.get("content") or "").strip()
|
||||
if content:
|
||||
lines.append(f"{role}:{content[:500]}")
|
||||
if not lines:
|
||||
return ""
|
||||
return "当前会话最近上下文:\n" + "\n".join(lines) + "\n\n"
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
from app.agents.llm_adapter import LLMResponse, LLMStreamChunk
|
||||
from app.agents.llm_adapter import LLMStreamChunk
|
||||
from app.agents.hint_agent import HintAgent
|
||||
from app.agents.patient_agent import PatientAgent
|
||||
from app.agents.report_agent import ReportAgent
|
||||
@@ -18,10 +18,6 @@ class MedicalConsultationOrchestrator:
|
||||
self.scoring_agent = ScoringAgent()
|
||||
self.report_agent = ReportAgent()
|
||||
|
||||
async def patient_reply(self, session: TrainingSession, case: CaseBase, memory_messages: list[dict], message: str) -> LLMResponse:
|
||||
"""问诊编排:调用 Patient Agent 生成 AI 病人回复。"""
|
||||
return await self.patient_agent.reply(case, memory_messages, message, session.mode, self._patient_config(session))
|
||||
|
||||
async def patient_stream_reply(
|
||||
self,
|
||||
session: TrainingSession,
|
||||
@@ -84,7 +80,7 @@ class MedicalConsultationOrchestrator:
|
||||
orders: list[SessionOrder],
|
||||
last_user_message: str | None = None,
|
||||
) -> dict:
|
||||
"""新手提示编排:基于当前会话上下文生成轻量训练提醒。"""
|
||||
"""练习提示编排:基于当前会话上下文生成轻量训练提醒。"""
|
||||
return await self.hint_agent.generate(session, case, memory_messages, orders, last_user_message)
|
||||
|
||||
def _patient_config(self, session: TrainingSession) -> dict | None:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
from app.agents.llm_adapter import DeepSeekClient, LLMResponse, LLMStreamChunk
|
||||
from app.agents.llm_adapter import DeepSeekClient, LLMStreamChunk
|
||||
from app.core.config import settings
|
||||
from app.models.source_case import CaseBase
|
||||
|
||||
@@ -11,23 +11,6 @@ class PatientAgent:
|
||||
def __init__(self, llm: DeepSeekClient | None = None) -> None:
|
||||
self.llm = llm or DeepSeekClient()
|
||||
|
||||
async def reply(
|
||||
self,
|
||||
case: CaseBase,
|
||||
memory_messages: list[dict],
|
||||
user_message: str,
|
||||
mode: str,
|
||||
patient_config: dict | None = None,
|
||||
) -> LLMResponse:
|
||||
"""问诊回复:拼接病例上下文、短期记忆和用户输入后调用 Patient Agent。"""
|
||||
messages = self._build_messages(case, memory_messages, user_message, mode, patient_config)
|
||||
return await self.llm.chat(
|
||||
messages,
|
||||
settings.llm_fast_model,
|
||||
thinking_enabled=settings.llm_fast_thinking_enabled,
|
||||
max_tokens=settings.llm_fast_max_tokens,
|
||||
)
|
||||
|
||||
async def stream_reply(
|
||||
self,
|
||||
case: CaseBase,
|
||||
@@ -58,11 +41,7 @@ class PatientAgent:
|
||||
profile = case.ai_patient_profile or {}
|
||||
hidden_info = case.hidden_patient_info or {}
|
||||
config_rule = self._build_patient_config_rule(patient_config)
|
||||
mode_rule = {
|
||||
"novice": "新手模式:回答清楚,必要时可提示医生继续追问症状、既往史或检查。",
|
||||
"practice": "练习模式:只回答被问到的信息,不主动给诊断建议。",
|
||||
"teaching": "教学模式:保持患者身份,允许在回答后补充简短学习提示。",
|
||||
}.get(mode, "只回答被问到的信息。")
|
||||
mode_rule = "练习模式:只回答被问到的信息,不主动给诊断建议。"
|
||||
system = f"""
|
||||
你是一名标准化 AI 病人或患儿家属,只能基于病例资料回答。
|
||||
病例主诉:{case.chief_complaint}
|
||||
|
||||
@@ -1,32 +0,0 @@
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.response import ApiResponse, ok
|
||||
from app.core.user_context import UserContext, get_user_context
|
||||
from app.db.session import get_db
|
||||
from app.schemas.case import CaseDetailResponse, CaseListResponse
|
||||
from app.services.case_service import CaseService
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("", response_model=ApiResponse[CaseListResponse])
|
||||
def list_cases(
|
||||
_: UserContext = Depends(get_user_context),
|
||||
db: Session = Depends(get_db),
|
||||
department_id: int | None = Query(default=None),
|
||||
training_type: str | None = Query(default=None),
|
||||
mode: str | None = Query(default=None),
|
||||
):
|
||||
"""病例列表:返回当前可用的激活病例,不暴露标准答案。"""
|
||||
return ok(CaseService(db).list_cases(department_id=department_id, training_type=training_type, mode=mode))
|
||||
|
||||
|
||||
@router.get("/{case_id}", response_model=ApiResponse[CaseDetailResponse])
|
||||
def get_case_detail(
|
||||
case_id: int,
|
||||
_: UserContext = Depends(get_user_context),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""病例详情:返回训练入口信息和可申请检查类型。"""
|
||||
return ok(CaseService(db).get_case_detail(case_id))
|
||||
+1
-13
@@ -7,7 +7,7 @@ from sqlalchemy.orm import Session
|
||||
from app.core.response import ApiResponse, ok
|
||||
from app.core.user_context import UserContext, get_user_context
|
||||
from app.db.session import get_db
|
||||
from app.schemas.evaluation import EvaluationDetailResponse, EvaluationListResponse, ExportPdfResponse
|
||||
from app.schemas.evaluation import EvaluationDetailResponse, EvaluationListResponse
|
||||
from app.services.evaluation_service import EvaluationService
|
||||
from app.services.pdf_export_service import PdfExportService
|
||||
|
||||
@@ -35,18 +35,6 @@ def get_evaluation_detail(
|
||||
return ok(EvaluationService(db).get_detail(evaluation_id, ctx.user_id))
|
||||
|
||||
|
||||
@router.post("/{evaluation_id}/export-pdf", response_model=ApiResponse[ExportPdfResponse])
|
||||
def export_pdf(
|
||||
evaluation_id: int,
|
||||
ctx: UserContext = Depends(get_user_context),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""PDF 导出:生成评价报告 PDF 并保存导出记录。"""
|
||||
export = PdfExportService(db).export(evaluation_id, ctx.user_id)
|
||||
db.commit()
|
||||
return ok(ExportPdfResponse(export_id=export.id, file_path=export.file_path))
|
||||
|
||||
|
||||
@router.get("/{evaluation_id}/download-pdf", response_class=FileResponse)
|
||||
def download_pdf(
|
||||
evaluation_id: int,
|
||||
|
||||
@@ -5,33 +5,43 @@ from starlette.responses import StreamingResponse
|
||||
from app.core.response import ApiResponse, ok
|
||||
from app.core.user_context import UserContext, get_user_context
|
||||
from app.db.session import get_db
|
||||
from app.schemas.learning_assistant import LearningAssistantChatRequest, LearningAssistantChatResponse
|
||||
from app.schemas.learning_assistant import (
|
||||
LearningAssistantChatRequest,
|
||||
LearningAssistantSessionCreateRequest,
|
||||
LearningAssistantSessionResponse,
|
||||
)
|
||||
from app.services.learning_assistant_service import LearningAssistantService
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/chat", response_model=ApiResponse[LearningAssistantChatResponse], include_in_schema=False)
|
||||
async def learning_assistant_chat(
|
||||
payload: LearningAssistantChatRequest,
|
||||
@router.post("/sessions", response_model=ApiResponse[LearningAssistantSessionResponse])
|
||||
async def create_learning_assistant_session(
|
||||
payload: LearningAssistantSessionCreateRequest,
|
||||
ctx: UserContext = Depends(get_user_context),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""AI 学习助手调试接口:非流式返回回答,正式前端联调使用流式接口。"""
|
||||
result = await LearningAssistantService(db).chat(ctx, payload)
|
||||
db.commit()
|
||||
"""学习助手会话创建:进入 AI 学习助手页面时生成短期会话 ID。"""
|
||||
result = LearningAssistantService(db).create_session(ctx, payload)
|
||||
return ok(result)
|
||||
|
||||
|
||||
@router.post("/chat/stream", response_class=StreamingResponse)
|
||||
async def learning_assistant_stream_chat(
|
||||
@router.post("/sessions/{assistant_session_id}/chat/stream", response_class=StreamingResponse)
|
||||
async def stream_learning_assistant_session_chat(
|
||||
assistant_session_id: str,
|
||||
payload: LearningAssistantChatRequest,
|
||||
ctx: UserContext = Depends(get_user_context),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""AI 学习助手流式问答:返回 retrieval_done、answer_delta、answer_done 事件。"""
|
||||
stream = LearningAssistantService(db).stream_chat(ctx, payload)
|
||||
db.commit()
|
||||
"""学习助手会话式流式问答:绑定短期会话上下文,返回 SSE 增量回答。"""
|
||||
service = LearningAssistantService(db)
|
||||
assistant_session = service.validate_session(ctx, assistant_session_id)
|
||||
stream = service.stream_session_chat(ctx, payload, assistant_session)
|
||||
return _sse_response(stream)
|
||||
|
||||
|
||||
def _sse_response(stream) -> StreamingResponse:
|
||||
"""SSE 响应封装:关闭代理缓冲,避免前端长时间看不到增量内容。"""
|
||||
return StreamingResponse(
|
||||
stream,
|
||||
media_type="text/event-stream",
|
||||
|
||||
+1
-2
@@ -1,11 +1,10 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
from app.api import agent, auth, cases, evaluations, knowledge_admin, learning_assistant, sessions, teaching, training_config
|
||||
from app.api import agent, auth, evaluations, knowledge_admin, learning_assistant, sessions, teaching, training_config
|
||||
|
||||
api_router = APIRouter()
|
||||
api_router.include_router(agent.router, tags=["agent"])
|
||||
api_router.include_router(auth.router, prefix="/auth", tags=["auth"])
|
||||
api_router.include_router(cases.router, prefix="/cases", tags=["cases"])
|
||||
api_router.include_router(training_config.router, prefix="/training-config", tags=["training-config"])
|
||||
api_router.include_router(sessions.router, prefix="/sessions", tags=["sessions"])
|
||||
api_router.include_router(teaching.router, prefix="/teaching", tags=["teaching"])
|
||||
|
||||
+1
-15
@@ -8,7 +8,6 @@ from app.db.session import get_db
|
||||
from app.schemas.evaluation import CreateEvaluationRequest, EvaluationResponse
|
||||
from app.schemas.session import (
|
||||
ChatRequest,
|
||||
ChatResponse,
|
||||
CreateOrderRequest,
|
||||
CreateOrderResponse,
|
||||
CreateSessionRequest,
|
||||
@@ -41,19 +40,6 @@ def create_session(
|
||||
return ok(result)
|
||||
|
||||
|
||||
@router.post("/{session_id}/chat", response_model=ApiResponse[ChatResponse])
|
||||
async def chat(
|
||||
session_id: int,
|
||||
payload: ChatRequest,
|
||||
ctx: UserContext = Depends(get_user_context),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""非流式问诊:发送医生问题并返回 AI 病人回复。"""
|
||||
result = await SessionService(db).chat(ctx, session_id, payload.message)
|
||||
db.commit()
|
||||
return ok(result)
|
||||
|
||||
|
||||
@router.post("/{session_id}/chat/stream", response_class=StreamingResponse)
|
||||
async def chat_stream(
|
||||
session_id: int,
|
||||
@@ -117,7 +103,7 @@ async def generate_hints(
|
||||
ctx: UserContext = Depends(get_user_context),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""新手模式提示:根据当前问诊上下文生成缺失维度和下一步问题。"""
|
||||
"""练习提示:根据当前问诊上下文生成缺失维度和下一步问题。"""
|
||||
result = await SessionService(db).generate_hints(ctx, session_id, payload)
|
||||
db.commit()
|
||||
return ok(result)
|
||||
|
||||
@@ -111,6 +111,10 @@ class Settings(BaseModel):
|
||||
runtime_memory_fallback_enabled: bool = Field(
|
||||
default_factory=lambda: _env_bool("RUNTIME_MEMORY_FALLBACK_ENABLED", True)
|
||||
)
|
||||
learning_assistant_session_ttl_seconds: int = Field(
|
||||
default_factory=lambda: int(os.getenv("LEARNING_ASSISTANT_SESSION_TTL_SECONDS", os.getenv("RUNTIME_MEMORY_TTL_SECONDS", "7200")))
|
||||
)
|
||||
learning_assistant_history_limit: int = Field(default_factory=lambda: int(os.getenv("LEARNING_ASSISTANT_HISTORY_LIMIT", "6")))
|
||||
redis_url: str = Field(default_factory=lambda: os.getenv("REDIS_URL", "redis://redis:6379/0"))
|
||||
auth_validate_enabled: bool = Field(default_factory=lambda: _env_bool("AUTH_VALIDATE_ENABLED", True))
|
||||
auth_user_me_url: str = Field(default_factory=lambda: os.getenv("AUTH_USER_ME_URL", ""))
|
||||
|
||||
@@ -43,7 +43,7 @@ class MilvusVectorStore:
|
||||
index_params=index_params,
|
||||
consistency_level="Strong",
|
||||
)
|
||||
except Exception as exc: # pragma: no cover - 真实 Milvus 由联调环境验证
|
||||
except Exception as exc: # pragma: no cover - 真实 Milvus 由集成环境验证
|
||||
raise AppError("MILVUS_COLLECTION_INIT_FAILED", "milvus collection init failed", 502) from exc
|
||||
|
||||
def upsert_vectors(self, collection_name: str, vectors: list[tuple[str, list[float]]]) -> None:
|
||||
|
||||
@@ -1,36 +0,0 @@
|
||||
---
|
||||
template_code: novice_hint
|
||||
agent_type: hint
|
||||
version: v1
|
||||
scene: novice
|
||||
model_type: fast
|
||||
output_format: text
|
||||
---
|
||||
|
||||
# Role
|
||||
|
||||
你是临床问诊教学提示 Agent。
|
||||
|
||||
# Task
|
||||
|
||||
在新手模式下生成下一步问诊提示,帮助用户补齐问诊框架。
|
||||
|
||||
# Inputs
|
||||
|
||||
- 当前病例基础信息。
|
||||
- 已完成问诊内容。
|
||||
- 缺失的关键症状、病史或风险点。
|
||||
|
||||
# Rules
|
||||
|
||||
- 只提示问诊方向。
|
||||
- 不直接给出诊断结论。
|
||||
- 不替用户完成问诊。
|
||||
|
||||
# Output Format
|
||||
|
||||
输出 1 条简短提示。
|
||||
|
||||
# Safety Boundaries
|
||||
|
||||
提示仅用于教学训练,不构成真实医疗建议。
|
||||
@@ -1,19 +1,19 @@
|
||||
---
|
||||
template_code: novice_case_hint
|
||||
template_code: practice_case_hint
|
||||
agent_type: hint
|
||||
version: v1
|
||||
scene: novice
|
||||
scene: practice
|
||||
model_type: fast
|
||||
output_format: json
|
||||
---
|
||||
|
||||
# Role
|
||||
|
||||
你是医疗问诊训练系统的新手提示 Agent。你的任务是帮助医学生在当前病例训练中发现问诊缺口、下一步问题和必要检查,而不是替学生完成诊断。
|
||||
你是医疗问诊训练系统的练习提示 Agent。你的任务是帮助医学生在当前病例训练中发现问诊缺口、下一步问题和必要检查,而不是替学生完成诊断。
|
||||
|
||||
# Task
|
||||
|
||||
根据输入的病例信息、当前会话状态、短期对话摘要、已申请检查和最后一句医生问题,生成新手模式下可展示的结构化提示。
|
||||
根据输入的病例信息、当前会话状态、短期对话摘要、已申请检查和最后一句医生问题,生成练习模式下可展示的结构化提示。
|
||||
|
||||
# Inputs
|
||||
|
||||
@@ -1,38 +0,0 @@
|
||||
---
|
||||
template_code: patient_novice
|
||||
agent_type: patient
|
||||
version: v1
|
||||
scene: novice
|
||||
model_type: fast
|
||||
output_format: text
|
||||
---
|
||||
|
||||
# Role
|
||||
|
||||
你是医疗问诊训练中的 AI 标准化病人或患儿家属。
|
||||
|
||||
# Task
|
||||
|
||||
在新手模式下回答医生问题,并用更清晰的表达帮助用户建立问诊框架。
|
||||
|
||||
# Inputs
|
||||
|
||||
- 病例资料和 AI 病人人设。
|
||||
- 当前短期 memory。
|
||||
- 医生最新问题。
|
||||
- 新手模式问诊引导规则。
|
||||
|
||||
# Rules
|
||||
|
||||
- 只基于病例内信息回答。
|
||||
- 不直接给出诊断或治疗方案。
|
||||
- 医生问题过宽时,允许用家属口吻提示一个继续追问方向。
|
||||
- 不输出检查结果,除非医生明确申请并由系统工具返回。
|
||||
|
||||
# Output Format
|
||||
|
||||
先回答问题,再补充一句温和引导,例如“医生,您还想了解哪方面?”。
|
||||
|
||||
# Safety Boundaries
|
||||
|
||||
本输出仅用于教学训练,不构成真实医疗建议,不替代临床医生判断。
|
||||
@@ -1,40 +0,0 @@
|
||||
---
|
||||
template_code: patient_practice
|
||||
agent_type: patient
|
||||
version: v1
|
||||
scene: practice
|
||||
model_type: fast
|
||||
output_format: text
|
||||
---
|
||||
|
||||
# Role
|
||||
|
||||
你是医疗问诊训练中的 AI 标准化病人或患儿家属。
|
||||
|
||||
# Task
|
||||
|
||||
在练习模式下根据病例资料回答医生问题,保持真实患者沟通风格。
|
||||
|
||||
# Inputs
|
||||
|
||||
- 病例基础信息。
|
||||
- AI 病人人设。
|
||||
- 隐藏信息。
|
||||
- 当前会话短期 memory。
|
||||
- 医生最新问题。
|
||||
|
||||
# Rules
|
||||
|
||||
- 只回答被问到的内容。
|
||||
- 不主动给出未被追问的隐藏信息。
|
||||
- 不评价医生表现。
|
||||
- 不输出诊断指导。
|
||||
- 不编造病例外检查结果。
|
||||
|
||||
# Output Format
|
||||
|
||||
使用自然、简短的患者或家属口吻回答。
|
||||
|
||||
# Safety Boundaries
|
||||
|
||||
本输出仅用于医学模拟训练,不构成真实医疗建议。
|
||||
@@ -1,39 +0,0 @@
|
||||
---
|
||||
template_code: patient_teaching
|
||||
agent_type: patient
|
||||
version: v1
|
||||
scene: teaching
|
||||
model_type: fast
|
||||
output_format: text
|
||||
---
|
||||
|
||||
# Role
|
||||
|
||||
你是医疗问诊训练中的 AI 标准化病人或患儿家属,同时支持教学互动。
|
||||
|
||||
# Task
|
||||
|
||||
回答医生问题,并在不泄露标准答案的前提下给出简短学习提示。
|
||||
|
||||
# Inputs
|
||||
|
||||
- 病例资料。
|
||||
- 教学互动配置。
|
||||
- 当前短期 memory。
|
||||
- 医生最新问题。
|
||||
|
||||
# Rules
|
||||
|
||||
- 患者回答和教学提示必须分开。
|
||||
- 教学提示只能提示问诊方向,不直接给出诊断结论。
|
||||
- 不编造病例外检查结果。
|
||||
|
||||
# Output Format
|
||||
|
||||
输出格式:
|
||||
患者回答:...
|
||||
学习提示:...
|
||||
|
||||
# Safety Boundaries
|
||||
|
||||
本输出仅用于教学训练,不替代真实临床诊疗。
|
||||
@@ -1,36 +0,0 @@
|
||||
---
|
||||
template_code: doctor_question_polish
|
||||
agent_type: polish
|
||||
version: v1
|
||||
scene: doctor_question
|
||||
model_type: fast
|
||||
output_format: text
|
||||
---
|
||||
|
||||
# Role
|
||||
|
||||
你是临床问诊表达润色 Agent。
|
||||
|
||||
# Task
|
||||
|
||||
将用户问题改写为更规范、清晰、符合临床问诊习惯的表达。
|
||||
|
||||
# Inputs
|
||||
|
||||
- 用户原始问题。
|
||||
- 当前病例场景。
|
||||
- 当前训练模式。
|
||||
|
||||
# Rules
|
||||
|
||||
- 保留原始意图。
|
||||
- 不改变医学含义。
|
||||
- 不加入用户没有表达的新问题。
|
||||
|
||||
# Output Format
|
||||
|
||||
输出润色后的单句问诊问题。
|
||||
|
||||
# Safety Boundaries
|
||||
|
||||
润色结果仅用于教学训练,不作为真实医疗建议。
|
||||
@@ -1,7 +1,7 @@
|
||||
from sqlalchemy import exists, select
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, selectinload
|
||||
|
||||
from app.models.source_case import CaseBase, CaseExamItem, TeachingCase, TraditionalCase
|
||||
from app.models.source_case import CaseBase, CaseExamItem
|
||||
|
||||
|
||||
class CaseRepository:
|
||||
@@ -10,31 +10,8 @@ class CaseRepository:
|
||||
def __init__(self, db: Session) -> None:
|
||||
self.db = db
|
||||
|
||||
def list_active_cases(
|
||||
self,
|
||||
department_id: int | None = None,
|
||||
training_type: str | None = None,
|
||||
mode: str | None = None,
|
||||
) -> list[CaseBase]:
|
||||
"""病例列表:从 case_base 读取已发布病例,并按模式匹配扩展表。"""
|
||||
normalized_mode = "practice" if mode == "novice" else mode
|
||||
stmt = (
|
||||
select(CaseBase)
|
||||
.options(selectinload(CaseBase.traditional_case), selectinload(CaseBase.teaching_case))
|
||||
.where(CaseBase.status == 1, CaseBase.publish_status == 1)
|
||||
)
|
||||
if department_id:
|
||||
stmt = stmt.where(CaseBase.department_id == department_id)
|
||||
if training_type:
|
||||
stmt = stmt.where(CaseBase.case_type == training_type)
|
||||
if normalized_mode == "practice":
|
||||
stmt = stmt.where(exists().where(TraditionalCase.case_id == CaseBase.id))
|
||||
if normalized_mode == "teaching":
|
||||
stmt = stmt.where(exists().where(TeachingCase.case_id == CaseBase.id))
|
||||
return list(self.db.scalars(stmt.order_by(CaseBase.id.desc())).all())
|
||||
|
||||
def get_active_case(self, case_id: int) -> CaseBase | None:
|
||||
"""病例详情:读取病例主表及训练所需的扩展表、评分规则和检查项目。"""
|
||||
"""病例读取:读取病例主表及训练所需的扩展表、评分规则和检查项目。"""
|
||||
stmt = (
|
||||
select(CaseBase)
|
||||
.options(
|
||||
|
||||
@@ -1,52 +0,0 @@
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
|
||||
class CaseListItem(BaseModel):
|
||||
"""病例列表项:不暴露标准答案和隐藏信息。"""
|
||||
|
||||
id: int
|
||||
case_code: str
|
||||
department_id: int
|
||||
title: str
|
||||
difficulty: str
|
||||
chief_complaint: str | None = None
|
||||
supported_training_type: str
|
||||
supported_mode: str
|
||||
has_teaching_video: bool
|
||||
has_knowledge_points: bool
|
||||
has_quiz: bool
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class CaseListResponse(BaseModel):
|
||||
"""病例列表响应:返回激活病例集合。"""
|
||||
|
||||
items: list[CaseListItem]
|
||||
|
||||
|
||||
class CasePatientInfo(BaseModel):
|
||||
"""患者展示信息:用于病例详情页。"""
|
||||
|
||||
name: str | None = None
|
||||
age: int | None = None
|
||||
gender: str | None = None
|
||||
occupation: str | None = None
|
||||
|
||||
|
||||
class CaseDetailResponse(BaseModel):
|
||||
"""病例详情响应:展示训练入口需要的信息。"""
|
||||
|
||||
id: int
|
||||
case_code: str
|
||||
title: str
|
||||
department: str
|
||||
difficulty: str
|
||||
patient: CasePatientInfo
|
||||
chief_complaint: str | None = None
|
||||
supported_training_type: str
|
||||
supported_mode: str
|
||||
has_teaching_video: bool
|
||||
has_knowledge_points: bool
|
||||
has_quiz: bool
|
||||
order_item_types: list[str]
|
||||
@@ -79,13 +79,6 @@ class EvaluationListResponse(BaseModel):
|
||||
pagination: PaginationMeta
|
||||
|
||||
|
||||
class ExportPdfResponse(BaseModel):
|
||||
"""PDF 导出响应:返回导出记录和本地文件路径。"""
|
||||
|
||||
export_id: int
|
||||
file_path: str
|
||||
|
||||
|
||||
class EvaluationDetailResponse(EvaluationResponse):
|
||||
"""评价详情响应:在报告详情页使用。"""
|
||||
|
||||
|
||||
@@ -1,10 +1,30 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class LearningAssistantChatRequest(BaseModel):
|
||||
"""学习助手请求:普通用户面向机构知识库提出医学学习问题。"""
|
||||
class LearningAssistantSessionCreateRequest(BaseModel):
|
||||
"""学习助手会话创建请求:进入 AI 学习助手页面时初始化短期问答会话。"""
|
||||
|
||||
question: str = Field(..., min_length=2, max_length=1000, description="用户问题")
|
||||
title: str | None = Field(default=None, max_length=100, description="会话标题,前端可不传")
|
||||
|
||||
|
||||
class LearningAssistantSessionResponse(BaseModel):
|
||||
"""学习助手会话响应:返回前端后续流式问答需要使用的会话 ID。"""
|
||||
|
||||
assistant_session_id: str
|
||||
user_id: str
|
||||
institution_id: int | None = None
|
||||
institution_name: str | None = None
|
||||
title: str
|
||||
status: str
|
||||
created_at: str
|
||||
updated_at: str
|
||||
expires_in_seconds: int
|
||||
|
||||
|
||||
class LearningAssistantChatRequest(BaseModel):
|
||||
"""学习助手问答请求:普通用户面向机构知识库提出医学学习问题。"""
|
||||
|
||||
question: str = Field(..., min_length=1, max_length=1000, description="用户问题")
|
||||
top_k: int | None = Field(default=None, ge=1, le=10, description="最终返回给 LLM 的来源片段数")
|
||||
score_threshold: float | None = Field(default=None, ge=0, le=1, description="向量相似度过滤阈值")
|
||||
|
||||
@@ -20,17 +40,3 @@ class LearningAssistantSource(BaseModel):
|
||||
chunk_uid: str
|
||||
score: float
|
||||
quote: str
|
||||
|
||||
|
||||
class LearningAssistantChatResponse(BaseModel):
|
||||
"""学习助手回答:返回答案、知识库命中状态、循证来源和耗时。"""
|
||||
|
||||
answer: str
|
||||
retrieval_hit: bool
|
||||
sources: list[LearningAssistantSource] = Field(default_factory=list)
|
||||
retrieval_error: str | None = None
|
||||
model: str | None = None
|
||||
embedding_latency_ms: int | None = None
|
||||
search_latency_ms: int | None = None
|
||||
llm_latency_ms: int | None = None
|
||||
total_latency_ms: int | None = None
|
||||
|
||||
+3
-18
@@ -1,4 +1,4 @@
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.schemas.training_config import PatientConfig
|
||||
|
||||
@@ -8,16 +8,10 @@ class CreateSessionRequest(BaseModel):
|
||||
|
||||
case_id: int
|
||||
training_type: str = Field(pattern="^(case_analysis|diagnosis_treatment|consultation)$")
|
||||
mode: str = Field(pattern="^(novice|practice|teaching)$")
|
||||
mode: str = Field(pattern="^practice$")
|
||||
score_type: str = Field(default="percentage", pattern="^(percentage|five_point)$")
|
||||
patient_config: PatientConfig | None = None
|
||||
|
||||
@field_validator("mode")
|
||||
@classmethod
|
||||
def normalize_mode(cls, value: str) -> str:
|
||||
"""训练模式:兼容旧 novice 请求,实际按 practice 练习模式处理。"""
|
||||
return "practice" if value == "novice" else value
|
||||
|
||||
|
||||
class CreateSessionResponse(BaseModel):
|
||||
"""创建会话响应:返回会话标识和 AI 病人开场白。"""
|
||||
@@ -35,15 +29,6 @@ class ChatRequest(BaseModel):
|
||||
message: str = Field(min_length=1, max_length=2000)
|
||||
|
||||
|
||||
class ChatResponse(BaseModel):
|
||||
"""问诊消息响应:返回 AI 病人的非流式回复。"""
|
||||
|
||||
reply: str
|
||||
latency_ms: int
|
||||
model: str
|
||||
fallback_used: bool = False
|
||||
|
||||
|
||||
class OrderItemResponse(BaseModel):
|
||||
"""可申请检查项:只返回名称和类型,不返回结果。"""
|
||||
|
||||
@@ -116,7 +101,7 @@ class SubmitTreatmentResponse(BaseModel):
|
||||
|
||||
|
||||
class HintRequest(BaseModel):
|
||||
"""会话提示入参:基于当前会话上下文生成新手模式提醒。"""
|
||||
"""会话提示入参:基于当前会话上下文生成练习提示。"""
|
||||
|
||||
last_user_message: str | None = Field(default=None, max_length=2000)
|
||||
scope: str = Field(default="current_conversation", pattern="^current_conversation$")
|
||||
|
||||
@@ -1,85 +0,0 @@
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.exceptions import AppError
|
||||
from app.models.source_case import CaseBase
|
||||
from app.repositories.case_repository import CaseRepository
|
||||
from app.repositories.source_case_repository import SourceCaseRepository
|
||||
from app.schemas.case import CaseDetailResponse, CaseListItem, CaseListResponse, CasePatientInfo
|
||||
|
||||
|
||||
class CaseService:
|
||||
"""病例服务:基于 case_base 新表体系提供病例列表和训练入口详情。"""
|
||||
|
||||
def __init__(self, db: Session) -> None:
|
||||
self.db = db
|
||||
self.repo = CaseRepository(db)
|
||||
self.source_repo = SourceCaseRepository(db)
|
||||
|
||||
def list_cases(
|
||||
self,
|
||||
department_id: int | None = None,
|
||||
training_type: str | None = None,
|
||||
mode: str | None = None,
|
||||
) -> CaseListResponse:
|
||||
"""病例列表:从 case_base 读取已发布病例,并按模式匹配传统/教学互动扩展表。"""
|
||||
cases = self.repo.list_active_cases(department_id=department_id, training_type=training_type, mode=mode)
|
||||
return CaseListResponse(items=[self._to_list_item(case) for case in cases])
|
||||
|
||||
def get_case_detail(self, case_id: int) -> CaseDetailResponse:
|
||||
"""病例详情:展示训练入口信息,不返回标准答案、隐藏病情和评分细则。"""
|
||||
case = self.repo.get_active_case(case_id)
|
||||
if not case:
|
||||
raise AppError("CASE_NOT_FOUND", "case not found or inactive", 404)
|
||||
order_items = self.repo.get_exam_items(case.id)
|
||||
return CaseDetailResponse(
|
||||
id=case.id,
|
||||
case_code=f"SRC_{case.id}",
|
||||
title=case.title,
|
||||
department=self.source_repo.get_department_name(case.department_id),
|
||||
difficulty=case.difficulty,
|
||||
patient=CasePatientInfo(
|
||||
name=None,
|
||||
age=case.patient_age,
|
||||
gender=case.patient_gender,
|
||||
occupation=None,
|
||||
),
|
||||
chief_complaint=case.chief_complaint,
|
||||
supported_training_type=self._training_type(case.case_type),
|
||||
supported_mode=self._supported_mode(case),
|
||||
has_teaching_video=self._has_video(case),
|
||||
has_knowledge_points=bool(case.knowledge_points),
|
||||
has_quiz=bool(case.teaching_case and case.teaching_case.discussion_questions),
|
||||
order_item_types=sorted({item.item_type for item in order_items}),
|
||||
)
|
||||
|
||||
def _to_list_item(self, case: CaseBase) -> CaseListItem:
|
||||
"""病例卡片转换:把 case_base 映射为当前前端病例列表结构。"""
|
||||
return CaseListItem(
|
||||
id=case.id,
|
||||
case_code=f"SRC_{case.id}",
|
||||
department_id=case.department_id or 0,
|
||||
title=case.title,
|
||||
difficulty=case.difficulty,
|
||||
chief_complaint=case.chief_complaint,
|
||||
supported_training_type=self._training_type(case.case_type),
|
||||
supported_mode=self._supported_mode(case),
|
||||
has_teaching_video=self._has_video(case),
|
||||
has_knowledge_points=bool(case.knowledge_points),
|
||||
has_quiz=bool(case.teaching_case and case.teaching_case.discussion_questions),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _supported_mode(case: CaseBase) -> str:
|
||||
"""模式标识:教学互动病例显示 interactive,其余显示 free_chat。"""
|
||||
return "interactive" if case.teaching_case else "free_chat"
|
||||
|
||||
@staticmethod
|
||||
def _has_video(case: CaseBase) -> bool:
|
||||
"""资源标识:根据 source 表 multimodal_assets 判断是否存在视频资源。"""
|
||||
assets = case.multimodal_assets or []
|
||||
return any(isinstance(item, dict) and item.get("type") == "video" for item in assets)
|
||||
|
||||
@staticmethod
|
||||
def _training_type(case_type: str) -> str:
|
||||
"""训练类别兼容:源库 case_type 不在当前枚举内时按诊断治疗训练处理。"""
|
||||
return case_type if case_type in {"case_analysis", "diagnosis_treatment", "consultation"} else "diagnosis_treatment"
|
||||
@@ -2,6 +2,7 @@ import json
|
||||
import time
|
||||
from collections.abc import AsyncIterator
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -10,8 +11,14 @@ from app.core.config import settings
|
||||
from app.core.context import UserContext
|
||||
from app.core.exceptions import AppError
|
||||
from app.repositories.knowledge_base_repository import KnowledgeBaseRepository
|
||||
from app.schemas.learning_assistant import LearningAssistantChatRequest, LearningAssistantChatResponse, LearningAssistantSource
|
||||
from app.schemas.learning_assistant import (
|
||||
LearningAssistantChatRequest,
|
||||
LearningAssistantSessionCreateRequest,
|
||||
LearningAssistantSessionResponse,
|
||||
LearningAssistantSource,
|
||||
)
|
||||
from app.services.knowledge_space_service import KnowledgeSpaceService
|
||||
from app.services.learning_assistant_session_store import LearningAssistantSessionStore, learning_assistant_session_store
|
||||
from app.services.vector_search_service import RetrievedChunk, VectorSearchService
|
||||
|
||||
|
||||
@@ -28,7 +35,7 @@ class LearningAssistantRetrieval:
|
||||
|
||||
|
||||
class LearningAssistantService:
|
||||
"""AI 学习助手服务:优先 RAG 检索,知识库不可用时降级为通用流式问答。"""
|
||||
"""AI 学习助手服务:管理短期会话,并优先通过 RAG 检索生成流式学习回答。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -36,78 +43,115 @@ class LearningAssistantService:
|
||||
*,
|
||||
vector_search_service: VectorSearchService | None = None,
|
||||
agent: LearningAssistantAgent | None = None,
|
||||
session_store: LearningAssistantSessionStore | None = None,
|
||||
) -> None:
|
||||
self.db = db
|
||||
self.repo = KnowledgeBaseRepository(db)
|
||||
self.space_service = KnowledgeSpaceService(self.repo)
|
||||
self.vector_search = vector_search_service or VectorSearchService(db)
|
||||
self.agent = agent or LearningAssistantAgent()
|
||||
self.session_store = session_store or learning_assistant_session_store
|
||||
|
||||
async def chat(self, ctx: UserContext, payload: LearningAssistantChatRequest) -> LearningAssistantChatResponse:
|
||||
"""知识问答调试:检索失败不阻断回答,返回完整文本和检索降级信息。"""
|
||||
start = time.perf_counter()
|
||||
retrieval = await self._retrieve_sources(ctx, payload)
|
||||
llm_started = time.perf_counter()
|
||||
response = await self.agent.answer(payload.question, retrieval.sources)
|
||||
total_latency_ms = int((time.perf_counter() - start) * 1000)
|
||||
llm_latency_ms = response.latency_ms or int((time.perf_counter() - llm_started) * 1000)
|
||||
self._write_query_log(
|
||||
ctx=ctx,
|
||||
payload=payload,
|
||||
retrieval=retrieval,
|
||||
answer=response.content,
|
||||
model=response.model,
|
||||
llm_latency_ms=llm_latency_ms,
|
||||
total_latency_ms=total_latency_ms,
|
||||
)
|
||||
return LearningAssistantChatResponse(
|
||||
answer=response.content,
|
||||
retrieval_hit=bool(retrieval.sources),
|
||||
sources=retrieval.sources,
|
||||
retrieval_error=retrieval.retrieval_error,
|
||||
model=response.model,
|
||||
embedding_latency_ms=retrieval.embedding_latency_ms,
|
||||
search_latency_ms=retrieval.search_latency_ms,
|
||||
llm_latency_ms=llm_latency_ms,
|
||||
total_latency_ms=total_latency_ms,
|
||||
)
|
||||
def create_session(self, ctx: UserContext, payload: LearningAssistantSessionCreateRequest) -> LearningAssistantSessionResponse:
|
||||
"""学习助手会话创建:进入 AI 学习助手页面时初始化短期上下文容器。"""
|
||||
state = self.session_store.create(ctx, title=payload.title)
|
||||
return self._session_response(state)
|
||||
|
||||
async def stream_chat(self, ctx: UserContext, payload: LearningAssistantChatRequest) -> AsyncIterator[str]:
|
||||
"""流式知识问答:先返回检索状态,再流式输出 LLM 回答。"""
|
||||
def validate_session(self, ctx: UserContext, assistant_session_id: str) -> dict[str, Any]:
|
||||
"""学习助手会话校验:确保会话存在、未过期且属于当前用户。"""
|
||||
state = self.session_store.get(assistant_session_id, ctx.user_id)
|
||||
if not state:
|
||||
raise AppError("LEARNING_ASSISTANT_SESSION_NOT_FOUND", "learning assistant session not found", 404)
|
||||
if state.get("status") != "active":
|
||||
raise AppError("LEARNING_ASSISTANT_SESSION_INVALID", "learning assistant session is not active", 400)
|
||||
return state
|
||||
|
||||
async def stream_session_chat(
|
||||
self,
|
||||
ctx: UserContext,
|
||||
payload: LearningAssistantChatRequest,
|
||||
assistant_session: dict[str, Any],
|
||||
) -> AsyncIterator[str]:
|
||||
"""会话式流式问答:绑定学习助手会话,记录最近问答并参与后续提示词拼接。"""
|
||||
yield self._sse(
|
||||
"session_ready",
|
||||
{
|
||||
"assistant_session_id": assistant_session["assistant_session_id"],
|
||||
"status": assistant_session["status"],
|
||||
"history_count": len(assistant_session.get("messages") or []),
|
||||
},
|
||||
)
|
||||
async for event in self._stream_answer(ctx, payload, assistant_session=assistant_session):
|
||||
yield event
|
||||
|
||||
async def _stream_answer(
|
||||
self,
|
||||
ctx: UserContext,
|
||||
payload: LearningAssistantChatRequest,
|
||||
*,
|
||||
assistant_session: dict[str, Any] | None,
|
||||
) -> AsyncIterator[str]:
|
||||
"""学习助手流式核心流程:检索知识库、调用 LLM、写入查询日志和短期会话上下文。"""
|
||||
start = time.perf_counter()
|
||||
assistant_session_id = assistant_session.get("assistant_session_id") if assistant_session else None
|
||||
history = (
|
||||
self.session_store.get_messages(assistant_session_id, ctx.user_id, settings.learning_assistant_history_limit)
|
||||
if assistant_session_id
|
||||
else []
|
||||
)
|
||||
if assistant_session_id:
|
||||
self.session_store.append_message(assistant_session_id, ctx.user_id, "user", payload.question)
|
||||
|
||||
retrieval = await self._retrieve_sources(ctx, payload)
|
||||
yield self._sse(
|
||||
"retrieval_done",
|
||||
{
|
||||
"retrieval_hit": bool(retrieval.sources),
|
||||
"sources": [source.model_dump() for source in retrieval.sources],
|
||||
"retrieval_error": retrieval.retrieval_error,
|
||||
"embedding_latency_ms": retrieval.embedding_latency_ms,
|
||||
"search_latency_ms": retrieval.search_latency_ms,
|
||||
},
|
||||
self._with_session(
|
||||
assistant_session_id,
|
||||
{
|
||||
"retrieval_hit": bool(retrieval.sources),
|
||||
"sources": [source.model_dump() for source in retrieval.sources],
|
||||
"retrieval_error": retrieval.retrieval_error,
|
||||
"embedding_latency_ms": retrieval.embedding_latency_ms,
|
||||
"search_latency_ms": retrieval.search_latency_ms,
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
answer_parts: list[str] = []
|
||||
llm_latency_ms: int | None = None
|
||||
model: str | None = None
|
||||
try:
|
||||
async for chunk in self.agent.stream_answer(payload.question, retrieval.sources):
|
||||
async for chunk in self.agent.stream_answer(payload.question, retrieval.sources, history=history):
|
||||
if chunk.done:
|
||||
llm_latency_ms = chunk.total_latency_ms
|
||||
model = chunk.model
|
||||
break
|
||||
if chunk.delta:
|
||||
answer_parts.append(chunk.delta)
|
||||
yield self._sse("answer_delta", {"delta": chunk.delta})
|
||||
yield self._sse("answer_delta", self._with_session(assistant_session_id, {"delta": chunk.delta}))
|
||||
except AppError as exc:
|
||||
yield self._sse("error", {"code": exc.code, "message": exc.message})
|
||||
yield self._sse("error", self._with_session(assistant_session_id, {"code": exc.code, "message": exc.message}))
|
||||
return
|
||||
except Exception:
|
||||
yield self._sse("error", {"code": "LEARNING_ASSISTANT_LLM_FAILED", "message": "AI 学习助手回答生成失败,请稍后重试"})
|
||||
yield self._sse(
|
||||
"error",
|
||||
self._with_session(
|
||||
assistant_session_id,
|
||||
{"code": "LEARNING_ASSISTANT_LLM_FAILED", "message": "AI 学习助手回答生成失败,请稍后重试"},
|
||||
),
|
||||
)
|
||||
return
|
||||
|
||||
answer = "".join(answer_parts)
|
||||
total_latency_ms = int((time.perf_counter() - start) * 1000)
|
||||
if assistant_session_id:
|
||||
self.session_store.append_message(
|
||||
assistant_session_id,
|
||||
ctx.user_id,
|
||||
"assistant",
|
||||
answer,
|
||||
metadata={"retrieval_hit": bool(retrieval.sources), "source_count": len(retrieval.sources), "model": model},
|
||||
)
|
||||
self._write_query_log(
|
||||
ctx=ctx,
|
||||
payload=payload,
|
||||
@@ -118,7 +162,17 @@ class LearningAssistantService:
|
||||
total_latency_ms=total_latency_ms,
|
||||
commit=True,
|
||||
)
|
||||
yield self._sse("answer_done", {"model": model, "total_latency_ms": total_latency_ms})
|
||||
yield self._sse(
|
||||
"answer_done",
|
||||
self._with_session(
|
||||
assistant_session_id,
|
||||
{
|
||||
"model": model,
|
||||
"total_latency_ms": total_latency_ms,
|
||||
"llm_latency_ms": llm_latency_ms,
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
async def _retrieve_sources(self, ctx: UserContext, payload: LearningAssistantChatRequest) -> LearningAssistantRetrieval:
|
||||
"""知识检索:按机构读取知识空间;无空间、Milvus 或 embedding 异常时降级为空来源。"""
|
||||
@@ -204,7 +258,6 @@ class LearningAssistantService:
|
||||
sources: list[LearningAssistantSource] = []
|
||||
for item in chunks:
|
||||
document = self.repo.get_document(item.chunk.document_id, item.chunk.institution_id)
|
||||
quote = item.chunk.chunk_text[:500]
|
||||
sources.append(
|
||||
LearningAssistantSource(
|
||||
document_id=item.chunk.document_id,
|
||||
@@ -214,11 +267,31 @@ class LearningAssistantService:
|
||||
page_end=item.chunk.page_end,
|
||||
chunk_uid=item.chunk.chunk_uid,
|
||||
score=round(item.score, 4),
|
||||
quote=quote,
|
||||
quote=item.chunk.chunk_text[:500],
|
||||
)
|
||||
)
|
||||
return sources
|
||||
|
||||
def _session_response(self, state: dict[str, Any]) -> LearningAssistantSessionResponse:
|
||||
"""会话响应转换:只返回前端需要展示和后续调用的字段。"""
|
||||
return LearningAssistantSessionResponse(
|
||||
assistant_session_id=state["assistant_session_id"],
|
||||
user_id=state["user_id"],
|
||||
institution_id=state.get("institution_id"),
|
||||
institution_name=state.get("institution_name"),
|
||||
title=state["title"],
|
||||
status=state["status"],
|
||||
created_at=state["created_at"],
|
||||
updated_at=state["updated_at"],
|
||||
expires_in_seconds=state["expires_in_seconds"],
|
||||
)
|
||||
|
||||
def _sse(self, event: str, data: dict) -> str:
|
||||
"""SSE 封装:统一输出 event + data 格式。"""
|
||||
return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
|
||||
|
||||
def _with_session(self, assistant_session_id: str | None, data: dict) -> dict:
|
||||
"""SSE 数据增强:会话式接口返回 assistant_session_id,旧接口保持兼容。"""
|
||||
if assistant_session_id:
|
||||
return {"assistant_session_id": assistant_session_id, **data}
|
||||
return data
|
||||
|
||||
@@ -0,0 +1,141 @@
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from threading import Lock
|
||||
from typing import Any
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.context import UserContext
|
||||
|
||||
|
||||
class LearningAssistantSessionStore:
|
||||
"""AI 学习助手短期会话存储:使用 Redis 保存会话状态,测试或降级时使用进程内存。"""
|
||||
|
||||
key_prefix = "learning_assistant:session:"
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._lock = Lock()
|
||||
self._memory_store: dict[str, dict[str, Any]] = {}
|
||||
self._redis_client = self._create_redis_client()
|
||||
|
||||
def create(self, ctx: UserContext, title: str | None = None) -> dict[str, Any]:
|
||||
"""学习助手会话创建:按当前用户和机构初始化一个短期问答会话。"""
|
||||
now = self._now()
|
||||
session_id = f"las_{uuid.uuid4().hex}"
|
||||
state: dict[str, Any] = {
|
||||
"assistant_session_id": session_id,
|
||||
"user_id": ctx.user_id,
|
||||
"institution_id": ctx.institution_id,
|
||||
"institution_name": self._profile_value(ctx, "institution_name"),
|
||||
"title": title or "AI 学习助手",
|
||||
"status": "active",
|
||||
"messages": [],
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
"expires_in_seconds": settings.learning_assistant_session_ttl_seconds,
|
||||
}
|
||||
self._save(state)
|
||||
return state
|
||||
|
||||
def get(self, assistant_session_id: str, user_id: str) -> dict[str, Any] | None:
|
||||
"""学习助手会话读取:只返回属于当前用户且未过期的会话。"""
|
||||
state = self._load(assistant_session_id)
|
||||
if not state or state.get("user_id") != user_id:
|
||||
return None
|
||||
return state
|
||||
|
||||
def get_messages(self, assistant_session_id: str, user_id: str, limit: int | None = None) -> list[dict[str, Any]]:
|
||||
"""学习助手上下文读取:返回当前会话最近若干轮问答,用于提示词拼接。"""
|
||||
state = self.get(assistant_session_id, user_id)
|
||||
if not state:
|
||||
return []
|
||||
messages = list(state.get("messages") or [])
|
||||
if limit is None:
|
||||
limit = settings.learning_assistant_history_limit
|
||||
return messages[-limit:]
|
||||
|
||||
def append_message(
|
||||
self,
|
||||
assistant_session_id: str,
|
||||
user_id: str,
|
||||
role: str,
|
||||
content: str,
|
||||
metadata: dict | None = None,
|
||||
) -> dict[str, Any] | None:
|
||||
"""学习助手上下文写入:记录用户问题和 AI 回答,按 TTL 自动过期。"""
|
||||
state = self.get(assistant_session_id, user_id)
|
||||
if not state:
|
||||
return None
|
||||
messages = list(state.get("messages") or [])
|
||||
messages.append(
|
||||
{
|
||||
"role": role,
|
||||
"content": content,
|
||||
"metadata": metadata or {},
|
||||
"created_at": self._now(),
|
||||
}
|
||||
)
|
||||
max_messages = max(settings.learning_assistant_history_limit * 2, 2)
|
||||
state["messages"] = messages[-max_messages:]
|
||||
state["updated_at"] = self._now()
|
||||
self._save(state)
|
||||
return state
|
||||
|
||||
def _create_redis_client(self):
|
||||
"""Redis 客户端创建:遵循 runtime memory 配置,失败时按配置降级。"""
|
||||
if settings.runtime_memory_backend.lower() != "redis":
|
||||
return None
|
||||
try:
|
||||
import redis
|
||||
|
||||
client = redis.Redis.from_url(settings.redis_url, decode_responses=True)
|
||||
client.ping()
|
||||
return client
|
||||
except Exception:
|
||||
if settings.is_production and not settings.runtime_memory_fallback_enabled:
|
||||
raise RuntimeError("Redis is required for learning assistant sessions")
|
||||
return None
|
||||
|
||||
def _save(self, state: dict[str, Any]) -> None:
|
||||
"""会话保存:Redis 使用 setex,内存 fallback 使用过期时间戳。"""
|
||||
session_id = state["assistant_session_id"]
|
||||
ttl = settings.learning_assistant_session_ttl_seconds
|
||||
if self._redis_client is not None:
|
||||
self._redis_client.setex(self._key(session_id), ttl, json.dumps(state, ensure_ascii=False))
|
||||
return
|
||||
with self._lock:
|
||||
self._memory_store[session_id] = {"expires_at": time.time() + ttl, "state": state}
|
||||
|
||||
def _load(self, assistant_session_id: str) -> dict[str, Any] | None:
|
||||
"""会话加载:读取并校验短期会话是否仍然有效。"""
|
||||
if self._redis_client is not None:
|
||||
raw = self._redis_client.get(self._key(assistant_session_id))
|
||||
if not raw:
|
||||
return None
|
||||
return json.loads(raw)
|
||||
with self._lock:
|
||||
item = self._memory_store.get(assistant_session_id)
|
||||
if not item:
|
||||
return None
|
||||
if item["expires_at"] < time.time():
|
||||
self._memory_store.pop(assistant_session_id, None)
|
||||
return None
|
||||
return dict(item["state"])
|
||||
|
||||
def _key(self, assistant_session_id: str) -> str:
|
||||
"""Redis key 生成:与训练短期 memory 隔离。"""
|
||||
return f"{self.key_prefix}{assistant_session_id}"
|
||||
|
||||
def _profile_value(self, ctx: UserContext, key: str) -> Any:
|
||||
"""用户资料读取:从 Django `/me` 标准化 profile 中提取扩展字段。"""
|
||||
if not ctx.profile:
|
||||
return None
|
||||
return ctx.profile.get(key)
|
||||
|
||||
def _now(self) -> str:
|
||||
"""时间格式化:返回 ISO 字符串,便于前端展示和日志排查。"""
|
||||
return datetime.utcnow().isoformat(timespec="seconds") + "Z"
|
||||
|
||||
|
||||
learning_assistant_session_store = LearningAssistantSessionStore()
|
||||
@@ -454,7 +454,7 @@ class PdfExportService:
|
||||
|
||||
def _mode_label(self, mode: str) -> str:
|
||||
"""训练模式标签:转换内部枚举为中文显示。"""
|
||||
return {"practice": "练习模式", "teaching": "教学互动模式", "novice": "练习模式"}.get(mode, mode)
|
||||
return {"practice": "练习模式", "teaching": "教学互动模式"}.get(mode, mode)
|
||||
|
||||
def _format_datetime(self, value: datetime | None) -> str:
|
||||
"""时间格式化:统一报告中的时间展示。"""
|
||||
|
||||
@@ -15,7 +15,6 @@ from app.models.training import SessionSubmission, TrainingSession
|
||||
from app.repositories.case_repository import CaseRepository
|
||||
from app.repositories.session_repository import SessionRepository
|
||||
from app.schemas.session import (
|
||||
ChatResponse,
|
||||
CreateSessionRequest,
|
||||
CreateSessionResponse,
|
||||
SessionStatusResponse,
|
||||
@@ -80,34 +79,6 @@ class SessionService:
|
||||
patient_config=patient_config,
|
||||
)
|
||||
|
||||
async def chat(self, ctx: UserContext, session_id: int, message: str) -> ChatResponse:
|
||||
"""问诊对话:拼接病例上下文、短期记忆和用户输入后调用 Patient Agent。"""
|
||||
session = self._get_session(session_id, ctx.user_id)
|
||||
if session.status != "inquiry":
|
||||
raise AppError("SESSION_STATUS_INVALID", "chat is only allowed in inquiry status", 400)
|
||||
case = self.case_repo.get_active_case(session.case_id)
|
||||
if not case:
|
||||
raise AppError("CASE_NOT_FOUND", "case not found or inactive", 404)
|
||||
|
||||
start = time.perf_counter()
|
||||
memory_messages = runtime_memory.get_messages(session.memory_key)
|
||||
runtime_memory.add_message(session.memory_key or "", "doctor", message)
|
||||
try:
|
||||
response = await asyncio.wait_for(
|
||||
self.orchestrator.patient_reply(session, case, memory_messages, message),
|
||||
timeout=settings.llm_chat_timeout_seconds,
|
||||
)
|
||||
except TimeoutError as exc:
|
||||
raise AppError("LLM_CALL_TIMEOUT", "AI 病人回复超时,请稍后重试或切换为普通问诊", 504) from exc
|
||||
runtime_memory.add_message(session.memory_key or "", "patient", response.content)
|
||||
self.audit.log(ctx, "session.chat", "training_session", str(session.id), session.id)
|
||||
return ChatResponse(
|
||||
reply=response.content,
|
||||
latency_ms=response.latency_ms or int((time.perf_counter() - start) * 1000),
|
||||
model=response.model,
|
||||
fallback_used=response.model.startswith("mock-fallback"),
|
||||
)
|
||||
|
||||
async def stream_chat(self, ctx: UserContext, session_id: int, message: str) -> AsyncIterator[str]:
|
||||
"""流式问诊:返回 SSE 格式的 AI 病人回复。"""
|
||||
session = self._get_session(session_id, ctx.user_id)
|
||||
@@ -214,7 +185,7 @@ class SessionService:
|
||||
return f"event: error\ndata: {payload}\n\n"
|
||||
|
||||
async def generate_hints(self, ctx: UserContext, session_id: int, payload: HintRequest) -> HintResponse:
|
||||
"""新手提示:基于当前会话上下文、已申请检查和病例信息生成提醒。"""
|
||||
"""练习提示:基于当前会话上下文、已申请检查和病例信息生成提醒。"""
|
||||
session = self._get_session(session_id, ctx.user_id)
|
||||
if session.mode != "practice":
|
||||
raise AppError("SESSION_STATUS_INVALID", "hints are only available in practice mode", 400)
|
||||
|
||||
Reference in New Issue
Block a user