chore: finalize backend feature scope

This commit is contained in:
刘金宝
2026-06-11 16:19:07 +08:00
parent d855ecab82
commit ec515d5453
43 changed files with 680 additions and 712 deletions
-85
View File
@@ -1,85 +0,0 @@
from sqlalchemy.orm import Session
from app.core.exceptions import AppError
from app.models.source_case import CaseBase
from app.repositories.case_repository import CaseRepository
from app.repositories.source_case_repository import SourceCaseRepository
from app.schemas.case import CaseDetailResponse, CaseListItem, CaseListResponse, CasePatientInfo
class CaseService:
    """Case service: serves the case list and the training-entry detail view from `case_base` rows."""

    def __init__(self, db: Session) -> None:
        self.db = db
        self.repo = CaseRepository(db)
        self.source_repo = SourceCaseRepository(db)

    def list_cases(
        self,
        department_id: int | None = None,
        training_type: str | None = None,
        mode: str | None = None,
    ) -> CaseListResponse:
        """Return list cards for the active cases matching the optional filters.

        Filtering is delegated to `CaseRepository.list_active_cases`; every returned
        row is converted with `_to_list_item`.
        """
        active_cases = self.repo.list_active_cases(
            department_id=department_id,
            training_type=training_type,
            mode=mode,
        )
        cards = list(map(self._to_list_item, active_cases))
        return CaseListResponse(items=cards)

    def get_case_detail(self, case_id: int) -> CaseDetailResponse:
        """Return the training-entry detail of one active case.

        Only the entry-level fields built below are exposed.

        Raises:
            AppError: `CASE_NOT_FOUND` (404) when the repository returns no active case.
        """
        case = self.repo.get_active_case(case_id)
        if not case:
            raise AppError("CASE_NOT_FOUND", "case not found or inactive", 404)

        exam_items = self.repo.get_exam_items(case.id)
        department_name = self.source_repo.get_department_name(case.department_id)
        # Name and occupation are always sent as None by this view.
        patient = CasePatientInfo(
            name=None,
            age=case.patient_age,
            gender=case.patient_gender,
            occupation=None,
        )
        teaching = case.teaching_case
        # De-duplicate the item types and return them in sorted order.
        distinct_item_types = sorted(set(entry.item_type for entry in exam_items))
        return CaseDetailResponse(
            id=case.id,
            case_code=f"SRC_{case.id}",
            title=case.title,
            department=department_name,
            difficulty=case.difficulty,
            patient=patient,
            chief_complaint=case.chief_complaint,
            supported_training_type=self._training_type(case.case_type),
            supported_mode=self._supported_mode(case),
            has_teaching_video=self._has_video(case),
            has_knowledge_points=bool(case.knowledge_points),
            has_quiz=bool(teaching and teaching.discussion_questions),
            order_item_types=distinct_item_types,
        )

    def _to_list_item(self, case: CaseBase) -> CaseListItem:
        """Map one `case_base` row onto the list-card structure."""
        teaching = case.teaching_case
        return CaseListItem(
            id=case.id,
            case_code=f"SRC_{case.id}",
            # A missing department is reported as 0.
            department_id=case.department_id or 0,
            title=case.title,
            difficulty=case.difficulty,
            chief_complaint=case.chief_complaint,
            supported_training_type=self._training_type(case.case_type),
            supported_mode=self._supported_mode(case),
            has_teaching_video=self._has_video(case),
            has_knowledge_points=bool(case.knowledge_points),
            has_quiz=bool(teaching and teaching.discussion_questions),
        )

    @staticmethod
    def _supported_mode(case: CaseBase) -> str:
        """Mode flag: `interactive` when the case has a teaching extension, else `free_chat`."""
        if case.teaching_case:
            return "interactive"
        return "free_chat"

    @staticmethod
    def _has_video(case: CaseBase) -> bool:
        """Resource flag: True when `multimodal_assets` holds a dict entry whose `type` is `video`."""
        for asset in case.multimodal_assets or []:
            # Non-dict entries are ignored rather than treated as errors.
            if isinstance(asset, dict) and asset.get("type") == "video":
                return True
        return False

    @staticmethod
    def _training_type(case_type: str) -> str:
        """Training-type compatibility: unknown `case_type` values map to `diagnosis_treatment`."""
        known_types = {"case_analysis", "diagnosis_treatment", "consultation"}
        if case_type in known_types:
            return case_type
        return "diagnosis_treatment"
+119 -46
View File
@@ -2,6 +2,7 @@ import json
import time
from collections.abc import AsyncIterator
from dataclasses import dataclass
from typing import Any
from sqlalchemy.orm import Session
@@ -10,8 +11,14 @@ from app.core.config import settings
from app.core.context import UserContext
from app.core.exceptions import AppError
from app.repositories.knowledge_base_repository import KnowledgeBaseRepository
from app.schemas.learning_assistant import LearningAssistantChatRequest, LearningAssistantChatResponse, LearningAssistantSource
from app.schemas.learning_assistant import (
LearningAssistantChatRequest,
LearningAssistantSessionCreateRequest,
LearningAssistantSessionResponse,
LearningAssistantSource,
)
from app.services.knowledge_space_service import KnowledgeSpaceService
from app.services.learning_assistant_session_store import LearningAssistantSessionStore, learning_assistant_session_store
from app.services.vector_search_service import RetrievedChunk, VectorSearchService
@@ -28,7 +35,7 @@ class LearningAssistantRetrieval:
class LearningAssistantService:
"""AI 学习助手服务:优先 RAG 检索,知识库不可用时降级为通用流式问答。"""
"""AI 学习助手服务:管理短期会话,并优先通过 RAG 检索生成流式学习回答。"""
def __init__(
self,
@@ -36,78 +43,115 @@ class LearningAssistantService:
*,
vector_search_service: VectorSearchService | None = None,
agent: LearningAssistantAgent | None = None,
session_store: LearningAssistantSessionStore | None = None,
) -> None:
self.db = db
self.repo = KnowledgeBaseRepository(db)
self.space_service = KnowledgeSpaceService(self.repo)
self.vector_search = vector_search_service or VectorSearchService(db)
self.agent = agent or LearningAssistantAgent()
self.session_store = session_store or learning_assistant_session_store
async def chat(self, ctx: UserContext, payload: LearningAssistantChatRequest) -> LearningAssistantChatResponse:
    """Non-streaming knowledge Q&A: retrieve sources, ask the agent once, log and return the full answer.

    The response carries the complete answer text together with the retrieval
    outcome (`sources`, `retrieval_error`) and the measured latencies.
    """
    started_at = time.perf_counter()
    retrieval = await self._retrieve_sources(ctx, payload)

    llm_started_at = time.perf_counter()
    reply = await self.agent.answer(payload.question, retrieval.sources)
    total_latency_ms = int((time.perf_counter() - started_at) * 1000)

    # Prefer the latency reported by the agent; otherwise fall back to the local measurement.
    llm_latency_ms = reply.latency_ms
    if not llm_latency_ms:
        llm_latency_ms = int((time.perf_counter() - llm_started_at) * 1000)

    self._write_query_log(
        ctx=ctx,
        payload=payload,
        retrieval=retrieval,
        answer=reply.content,
        model=reply.model,
        llm_latency_ms=llm_latency_ms,
        total_latency_ms=total_latency_ms,
    )
    return LearningAssistantChatResponse(
        answer=reply.content,
        retrieval_hit=bool(retrieval.sources),
        sources=retrieval.sources,
        retrieval_error=retrieval.retrieval_error,
        model=reply.model,
        embedding_latency_ms=retrieval.embedding_latency_ms,
        search_latency_ms=retrieval.search_latency_ms,
        llm_latency_ms=llm_latency_ms,
        total_latency_ms=total_latency_ms,
    )
def create_session(self, ctx: UserContext, payload: LearningAssistantSessionCreateRequest) -> LearningAssistantSessionResponse:
    """Create a learning-assistant session in the session store and return its response view."""
    return self._session_response(self.session_store.create(ctx, title=payload.title))
async def stream_chat(self, ctx: UserContext, payload: LearningAssistantChatRequest) -> AsyncIterator[str]:
"""流式知识问答:先返回检索状态,再流式输出 LLM 回答"""
def validate_session(self, ctx: UserContext, assistant_session_id: str) -> dict[str, Any]:
    """Return the stored session state for the current user, or raise if it is unusable.

    Raises:
        AppError: `LEARNING_ASSISTANT_SESSION_NOT_FOUND` (404) when the store returns nothing
            for this id and user; `LEARNING_ASSISTANT_SESSION_INVALID` (400) when the stored
            status is not `active`.
    """
    state = self.session_store.get(assistant_session_id, ctx.user_id)
    if state and state.get("status") == "active":
        return state
    if not state:
        raise AppError("LEARNING_ASSISTANT_SESSION_NOT_FOUND", "learning assistant session not found", 404)
    raise AppError("LEARNING_ASSISTANT_SESSION_INVALID", "learning assistant session is not active", 400)
async def stream_session_chat(
    self,
    ctx: UserContext,
    payload: LearningAssistantChatRequest,
    assistant_session: dict[str, Any],
) -> AsyncIterator[str]:
    """Session-bound streaming chat: emit `session_ready`, then relay every event from `_stream_answer`."""
    ready_payload = {
        "assistant_session_id": assistant_session["assistant_session_id"],
        "status": assistant_session["status"],
        # A missing or empty message list counts as zero history entries.
        "history_count": len(assistant_session.get("messages") or []),
    }
    yield self._sse("session_ready", ready_payload)

    answer_events = self._stream_answer(ctx, payload, assistant_session=assistant_session)
    async for sse_event in answer_events:
        yield sse_event
async def _stream_answer(
self,
ctx: UserContext,
payload: LearningAssistantChatRequest,
*,
assistant_session: dict[str, Any] | None,
) -> AsyncIterator[str]:
"""学习助手流式核心流程:检索知识库、调用 LLM、写入查询日志和短期会话上下文。"""
start = time.perf_counter()
assistant_session_id = assistant_session.get("assistant_session_id") if assistant_session else None
history = (
self.session_store.get_messages(assistant_session_id, ctx.user_id, settings.learning_assistant_history_limit)
if assistant_session_id
else []
)
if assistant_session_id:
self.session_store.append_message(assistant_session_id, ctx.user_id, "user", payload.question)
retrieval = await self._retrieve_sources(ctx, payload)
yield self._sse(
"retrieval_done",
{
"retrieval_hit": bool(retrieval.sources),
"sources": [source.model_dump() for source in retrieval.sources],
"retrieval_error": retrieval.retrieval_error,
"embedding_latency_ms": retrieval.embedding_latency_ms,
"search_latency_ms": retrieval.search_latency_ms,
},
self._with_session(
assistant_session_id,
{
"retrieval_hit": bool(retrieval.sources),
"sources": [source.model_dump() for source in retrieval.sources],
"retrieval_error": retrieval.retrieval_error,
"embedding_latency_ms": retrieval.embedding_latency_ms,
"search_latency_ms": retrieval.search_latency_ms,
},
),
)
answer_parts: list[str] = []
llm_latency_ms: int | None = None
model: str | None = None
try:
async for chunk in self.agent.stream_answer(payload.question, retrieval.sources):
async for chunk in self.agent.stream_answer(payload.question, retrieval.sources, history=history):
if chunk.done:
llm_latency_ms = chunk.total_latency_ms
model = chunk.model
break
if chunk.delta:
answer_parts.append(chunk.delta)
yield self._sse("answer_delta", {"delta": chunk.delta})
yield self._sse("answer_delta", self._with_session(assistant_session_id, {"delta": chunk.delta}))
except AppError as exc:
yield self._sse("error", {"code": exc.code, "message": exc.message})
yield self._sse("error", self._with_session(assistant_session_id, {"code": exc.code, "message": exc.message}))
return
except Exception:
yield self._sse("error", {"code": "LEARNING_ASSISTANT_LLM_FAILED", "message": "AI 学习助手回答生成失败,请稍后重试"})
yield self._sse(
"error",
self._with_session(
assistant_session_id,
{"code": "LEARNING_ASSISTANT_LLM_FAILED", "message": "AI 学习助手回答生成失败,请稍后重试"},
),
)
return
answer = "".join(answer_parts)
total_latency_ms = int((time.perf_counter() - start) * 1000)
if assistant_session_id:
self.session_store.append_message(
assistant_session_id,
ctx.user_id,
"assistant",
answer,
metadata={"retrieval_hit": bool(retrieval.sources), "source_count": len(retrieval.sources), "model": model},
)
self._write_query_log(
ctx=ctx,
payload=payload,
@@ -118,7 +162,17 @@ class LearningAssistantService:
total_latency_ms=total_latency_ms,
commit=True,
)
yield self._sse("answer_done", {"model": model, "total_latency_ms": total_latency_ms})
yield self._sse(
"answer_done",
self._with_session(
assistant_session_id,
{
"model": model,
"total_latency_ms": total_latency_ms,
"llm_latency_ms": llm_latency_ms,
},
),
)
async def _retrieve_sources(self, ctx: UserContext, payload: LearningAssistantChatRequest) -> LearningAssistantRetrieval:
"""知识检索:按机构读取知识空间;无空间、Milvus 或 embedding 异常时降级为空来源。"""
@@ -204,7 +258,6 @@ class LearningAssistantService:
sources: list[LearningAssistantSource] = []
for item in chunks:
document = self.repo.get_document(item.chunk.document_id, item.chunk.institution_id)
quote = item.chunk.chunk_text[:500]
sources.append(
LearningAssistantSource(
document_id=item.chunk.document_id,
@@ -214,11 +267,31 @@ class LearningAssistantService:
page_end=item.chunk.page_end,
chunk_uid=item.chunk.chunk_uid,
score=round(item.score, 4),
quote=quote,
quote=item.chunk.chunk_text[:500],
)
)
return sources
def _session_response(self, state: dict[str, Any]) -> LearningAssistantSessionResponse:
    """Convert stored session state into the response model, copying only the fields listed below."""
    required_keys = (
        "assistant_session_id",
        "user_id",
        "title",
        "status",
        "created_at",
        "updated_at",
        "expires_in_seconds",
    )
    # Required keys raise KeyError when absent; the institution fields are optional.
    fields = {key: state[key] for key in required_keys}
    fields["institution_id"] = state.get("institution_id")
    fields["institution_name"] = state.get("institution_name")
    return LearningAssistantSessionResponse(**fields)
def _sse(self, event: str, data: dict) -> str:
"""SSE 封装:统一输出 event + data 格式。"""
return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
def _with_session(self, assistant_session_id: str | None, data: dict) -> dict:
"""SSE 数据增强:会话式接口返回 assistant_session_id,旧接口保持兼容。"""
if assistant_session_id:
return {"assistant_session_id": assistant_session_id, **data}
return data
@@ -0,0 +1,141 @@
import json
import time
import uuid
from datetime import datetime
from threading import Lock
from typing import Any
from app.core.config import settings
from app.core.context import UserContext
class LearningAssistantSessionStore:
    """Short-lived session store for the AI learning assistant.

    State is kept in Redis when `settings.runtime_memory_backend` is `"redis"` and the
    server answers a ping; otherwise an in-process dict guarded by a lock is used.
    Every save refreshes the session TTL (`settings.learning_assistant_session_ttl_seconds`).
    """

    key_prefix = "learning_assistant:session:"

    def __init__(self) -> None:
        self._lock = Lock()
        # In-memory fallback: session id -> {"expires_at": epoch seconds, "state": session dict}.
        self._memory_store: dict[str, dict[str, Any]] = {}
        self._redis_client = self._create_redis_client()

    def create(self, ctx: UserContext, title: str | None = None) -> dict[str, Any]:
        """Create, persist and return a new active session for the current user and institution."""
        now = self._now()
        session_id = f"las_{uuid.uuid4().hex}"
        state: dict[str, Any] = {
            "assistant_session_id": session_id,
            "user_id": ctx.user_id,
            "institution_id": ctx.institution_id,
            "institution_name": self._profile_value(ctx, "institution_name"),
            "title": title or "AI 学习助手",
            "status": "active",
            "messages": [],
            "created_at": now,
            "updated_at": now,
            "expires_in_seconds": settings.learning_assistant_session_ttl_seconds,
        }
        self._save(state)
        return state

    def get(self, assistant_session_id: str, user_id: str) -> dict[str, Any] | None:
        """Return the session state, or None when it is missing, expired or owned by another user."""
        state = self._load(assistant_session_id)
        if not state or state.get("user_id") != user_id:
            return None
        return state

    def get_messages(self, assistant_session_id: str, user_id: str, limit: int | None = None) -> list[dict[str, Any]]:
        """Return up to `limit` most recent messages of the session.

        Args:
            limit: Maximum number of messages; defaults to
                `settings.learning_assistant_history_limit`. Zero or a negative value
                yields an empty list.

        Returns:
            A new list (oldest first); empty when the session is not available.
        """
        state = self.get(assistant_session_id, user_id)
        if not state:
            return []
        if limit is None:
            limit = settings.learning_assistant_history_limit
        # `messages[-0:]` would be the WHOLE list, so a non-positive limit must short-circuit.
        if limit <= 0:
            return []
        messages = list(state.get("messages") or [])
        return messages[-limit:]

    def append_message(
        self,
        assistant_session_id: str,
        user_id: str,
        role: str,
        content: str,
        metadata: dict | None = None,
    ) -> dict[str, Any] | None:
        """Append one message to the session and persist it, refreshing the TTL.

        Returns:
            The updated state, or None when the session is not available for this user.
        """
        state = self.get(assistant_session_id, user_id)
        if not state:
            return None
        now = self._now()
        messages = list(state.get("messages") or [])
        messages.append(
            {
                "role": role,
                "content": content,
                "metadata": metadata or {},
                "created_at": now,
            }
        )
        # Keep twice the history limit, and never fewer than two messages.
        max_messages = max(settings.learning_assistant_history_limit * 2, 2)
        state["messages"] = messages[-max_messages:]
        state["updated_at"] = now
        self._save(state)
        return state

    def _create_redis_client(self):
        """Build and ping a Redis client, or return None to use the in-memory fallback.

        Raises:
            RuntimeError: when Redis is unreachable in production and the fallback is disabled.
        """
        if settings.runtime_memory_backend.lower() != "redis":
            return None
        try:
            import redis

            client = redis.Redis.from_url(settings.redis_url, decode_responses=True)
            client.ping()
            return client
        except Exception as exc:
            # Deliberate best-effort: any import/connection failure falls back to memory
            # unless production forbids it.
            if settings.is_production and not settings.runtime_memory_fallback_enabled:
                raise RuntimeError("Redis is required for learning assistant sessions") from exc
            import logging

            logging.getLogger(__name__).warning(
                "Redis unavailable for learning assistant sessions; using in-memory store",
                exc_info=True,
            )
            return None

    def _save(self, state: dict[str, Any]) -> None:
        """Persist the state: Redis `setex` with the TTL, or the memory dict with an expiry timestamp."""
        session_id = state["assistant_session_id"]
        ttl = settings.learning_assistant_session_ttl_seconds
        if self._redis_client is not None:
            self._redis_client.setex(self._key(session_id), ttl, json.dumps(state, ensure_ascii=False))
            return
        now = time.time()
        with self._lock:
            # Evict expired sessions here as well; otherwise sessions that are never
            # read again would stay in memory for the life of the process.
            expired_ids = [key for key, item in self._memory_store.items() if item["expires_at"] < now]
            for key in expired_ids:
                del self._memory_store[key]
            self._memory_store[session_id] = {"expires_at": now + ttl, "state": state}

    def _load(self, assistant_session_id: str) -> dict[str, Any] | None:
        """Load the state, returning None when it is absent or (memory mode) past its expiry."""
        if self._redis_client is not None:
            raw = self._redis_client.get(self._key(assistant_session_id))
            if not raw:
                return None
            return json.loads(raw)
        with self._lock:
            item = self._memory_store.get(assistant_session_id)
            if not item:
                return None
            if item["expires_at"] < time.time():
                self._memory_store.pop(assistant_session_id, None)
                return None
            # Shallow copy so callers do not replace top-level keys of the stored dict.
            return dict(item["state"])

    def _key(self, assistant_session_id: str) -> str:
        """Build the Redis key for a session from the dedicated `key_prefix`."""
        return f"{self.key_prefix}{assistant_session_id}"

    def _profile_value(self, ctx: UserContext, key: str) -> Any:
        """Read an optional field from `ctx.profile`; None when the profile is empty."""
        if not ctx.profile:
            return None
        return ctx.profile.get(key)

    def _now(self) -> str:
        """Return the current UTC time as `YYYY-MM-DDTHH:MM:SSZ`."""
        # time.gmtime() replaces the deprecated datetime.utcnow(); the output format is unchanged.
        return time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
# Shared module-level store instance. Constructing it calls `_create_redis_client()`,
# so the Redis probe (or the decision to fall back to memory) runs when this module
# is first imported.
learning_assistant_session_store = LearningAssistantSessionStore()
+1 -1
View File
@@ -454,7 +454,7 @@ class PdfExportService:
def _mode_label(self, mode: str) -> str:
"""训练模式标签:转换内部枚举为中文显示。"""
return {"practice": "练习模式", "teaching": "教学互动模式", "novice": "练习模式"}.get(mode, mode)
return {"practice": "练习模式", "teaching": "教学互动模式"}.get(mode, mode)
def _format_datetime(self, value: datetime | None) -> str:
"""时间格式化:统一报告中的时间展示。"""
+1 -30
View File
@@ -15,7 +15,6 @@ from app.models.training import SessionSubmission, TrainingSession
from app.repositories.case_repository import CaseRepository
from app.repositories.session_repository import SessionRepository
from app.schemas.session import (
ChatResponse,
CreateSessionRequest,
CreateSessionResponse,
SessionStatusResponse,
@@ -80,34 +79,6 @@ class SessionService:
patient_config=patient_config,
)
async def chat(self, ctx: UserContext, session_id: int, message: str) -> ChatResponse:
    """Run one non-streaming inquiry turn against the Patient Agent.

    The doctor's message is appended to the runtime memory before the agent is
    called; the patient's reply is appended after a successful call.

    Raises:
        AppError: `SESSION_STATUS_INVALID` (400) when the session is not in `inquiry`
            status, `CASE_NOT_FOUND` (404) when the case is missing or inactive, and
            `LLM_CALL_TIMEOUT` (504) when the agent exceeds
            `settings.llm_chat_timeout_seconds`.
    """
    session = self._get_session(session_id, ctx.user_id)
    if session.status != "inquiry":
        raise AppError("SESSION_STATUS_INVALID", "chat is only allowed in inquiry status", 400)
    case = self.case_repo.get_active_case(session.case_id)
    if not case:
        raise AppError("CASE_NOT_FOUND", "case not found or inactive", 404)

    start = time.perf_counter()
    # Snapshot the history BEFORE adding the new message; the message itself is passed separately.
    memory_messages = runtime_memory.get_messages(session.memory_key)
    runtime_memory.add_message(session.memory_key or "", "doctor", message)
    try:
        response = await asyncio.wait_for(
            self.orchestrator.patient_reply(session, case, memory_messages, message),
            timeout=settings.llm_chat_timeout_seconds,
        )
    except asyncio.TimeoutError as exc:
        # asyncio.TimeoutError only became an alias of the builtin TimeoutError in Python 3.11;
        # catching the asyncio name handles the timeout on older interpreters too.
        raise AppError("LLM_CALL_TIMEOUT", "AI 病人回复超时,请稍后重试或切换为普通问诊", 504) from exc

    runtime_memory.add_message(session.memory_key or "", "patient", response.content)
    self.audit.log(ctx, "session.chat", "training_session", str(session.id), session.id)
    return ChatResponse(
        reply=response.content,
        # Prefer the latency reported by the agent; otherwise use the local measurement.
        latency_ms=response.latency_ms or int((time.perf_counter() - start) * 1000),
        model=response.model,
        fallback_used=response.model.startswith("mock-fallback"),
    )
async def stream_chat(self, ctx: UserContext, session_id: int, message: str) -> AsyncIterator[str]:
"""流式问诊:返回 SSE 格式的 AI 病人回复。"""
session = self._get_session(session_id, ctx.user_id)
@@ -214,7 +185,7 @@ class SessionService:
return f"event: error\ndata: {payload}\n\n"
async def generate_hints(self, ctx: UserContext, session_id: int, payload: HintRequest) -> HintResponse:
"""新手提示:基于当前会话上下文、已申请检查和病例信息生成提醒。"""
"""练习提示:基于当前会话上下文、已申请检查和病例信息生成提醒。"""
session = self._get_session(session_id, ctx.user_id)
if session.mode != "practice":
raise AppError("SESSION_STATUS_INVALID", "hints are only available in practice mode", 400)