Files
fastapi/app/services/session_service.py
T
2026-06-08 15:16:07 +08:00

337 lines
18 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import asyncio
import time
import uuid
import json
import logging
from collections.abc import AsyncIterator
from datetime import datetime
from sqlalchemy.orm import Session
from app.agents.orchestrator import MedicalConsultationOrchestrator
from app.core.config import settings
from app.core.context import UserContext
from app.core.exceptions import AppError
from app.models.training import SessionSubmission, TrainingSession
from app.repositories.case_repository import CaseRepository
from app.repositories.session_repository import SessionRepository
from app.schemas.session import (
ChatResponse,
CreateSessionRequest,
CreateSessionResponse,
SessionStatusResponse,
SubmitDiagnosisRequest,
SubmitDiagnosisResponse,
SubmitTreatmentRequest,
SubmitTreatmentResponse,
HintRequest,
HintResponse,
)
from app.services.audit_service import AuditService
from app.services.runtime_memory import runtime_memory
from app.services.training_config_service import TrainingConfigService
# Module-level logger for this service module (used by the streaming chat and hint paths).
logger = logging.getLogger(__name__)
class SessionService:
    """Session service.

    Creates training sessions, runs the inquiry chat against the AI patient, drives the
    multi-stage status flow (inquiry -> diagnosis -> treatment -> evaluating) and stores
    the diagnosis / treatment submissions.
    """

    def __init__(self, db: Session) -> None:
        # Repositories and the audit service all share the caller's DB session.
        self.db = db
        self.case_repo = CaseRepository(db)
        self.session_repo = SessionRepository(db)
        self.audit = AuditService(db)
        self.orchestrator = MedicalConsultationOrchestrator()

    def create_session(self, ctx: UserContext, payload: CreateSessionRequest) -> CreateSessionResponse:
        """Create a session: validate the case and initialise its short-term memory.

        Raises:
            AppError: CASE_NOT_FOUND (404) when the case does not exist or is inactive.
        """
        case = self.case_repo.get_active_case(payload.case_id)
        if not case:
            raise AppError("CASE_NOT_FOUND", "case not found or inactive", 404)
        patient_config = TrainingConfigService(self.db).normalize_patient_config(payload.patient_config, case)
        # Timestamp plus a short random suffix keeps codes readable and avoids same-second collisions.
        session_code = f"sess_{datetime.utcnow().strftime('%Y%m%d%H%M%S')}_{uuid.uuid4().hex[:8]}"
        memory_key = f"mem:{session_code}"
        session = self.session_repo.create_session(
            TrainingSession(
                session_code=session_code,
                user_id=ctx.user_id,
                tenant_id=ctx.tenant_id,
                class_id=ctx.class_id,
                entry_scene=ctx.entry_scene,
                case_id=case.id,
                training_type=payload.training_type,
                mode=payload.mode,
                score_type=payload.score_type,
                status="inquiry",
                started_at=datetime.utcnow(),
                memory_key=memory_key,
                metadata_={"source": "demo", "patient_config": patient_config},
            )
        )
        # Fall back to a generic opening line when the case does not define one.
        patient_opening = case.patient_opening or "家长:医生,孩子这几天不舒服,想请您看看。"
        runtime_memory.create(memory_key, patient_opening)
        self.audit.log(ctx, "session.create", "training_session", str(session.id), session.id)
        return CreateSessionResponse(
            session_id=session.id,
            session_code=session.session_code,
            status=session.status,
            patient_opening=patient_opening,
            patient_config=patient_config,
        )

    async def chat(self, ctx: UserContext, session_id: int, message: str) -> ChatResponse:
        """Answer one inquiry turn via the Patient Agent.

        Combines the case context, the session's short-term memory and the doctor's
        message, calls the patient agent, and records both sides of the exchange in
        runtime memory.

        Raises:
            AppError: SESSION_STATUS_INVALID (400) outside the inquiry stage,
                CASE_NOT_FOUND (404) when the case is missing/inactive, and
                LLM_CALL_TIMEOUT (504) when the model does not answer within
                ``settings.llm_chat_timeout_seconds``.
        """
        session = self._get_session(session_id, ctx.user_id)
        if session.status != "inquiry":
            raise AppError("SESSION_STATUS_INVALID", "chat is only allowed in inquiry status", 400)
        case = self.case_repo.get_active_case(session.case_id)
        if not case:
            raise AppError("CASE_NOT_FOUND", "case not found or inactive", 404)
        start = time.perf_counter()
        # History is read before the current message is appended; the message itself is passed
        # separately (assumes get_messages returns a copy, not a live list — TODO confirm).
        memory_messages = runtime_memory.get_messages(session.memory_key)
        # NOTE(review): the doctor message is recorded before the LLM call, so a timeout leaves it
        # in memory without a matching patient reply.
        runtime_memory.add_message(session.memory_key or "", "doctor", message)
        try:
            response = await asyncio.wait_for(
                self.orchestrator.patient_reply(session, case, memory_messages, message),
                timeout=settings.llm_chat_timeout_seconds,
            )
        except asyncio.TimeoutError as exc:
            # asyncio.TimeoutError is an alias of the builtin TimeoutError only from Python 3.11;
            # catching the asyncio name works on every supported version.
            raise AppError("LLM_CALL_TIMEOUT", "AI 病人回复超时,请稍后重试或切换为普通问诊", 504) from exc
        runtime_memory.add_message(session.memory_key or "", "patient", response.content)
        self.audit.log(ctx, "session.chat", "training_session", str(session.id), session.id)
        return ChatResponse(
            reply=response.content,
            latency_ms=response.latency_ms or int((time.perf_counter() - start) * 1000),
            model=response.model,
            fallback_used=response.model.startswith("mock-fallback"),
        )

    async def stream_chat(self, ctx: UserContext, session_id: int, message: str) -> AsyncIterator[str]:
        """Streaming inquiry: return an async iterator of SSE frames for the AI patient reply.

        Ownership, status and case checks run eagerly, so AppError is raised before any
        streaming starts. Once streaming, failures are reported to the client as
        ``event: error`` frames rather than exceptions, and the stream always ends with
        exactly one ``message_done`` or ``error`` frame.

        Raises:
            AppError: SESSION_STATUS_INVALID (400) outside the inquiry stage, or
                CASE_NOT_FOUND (404) when the case is missing/inactive.
        """
        session = self._get_session(session_id, ctx.user_id)
        if session.status != "inquiry":
            raise AppError("SESSION_STATUS_INVALID", "chat is only allowed in inquiry status", 400)
        case = self.case_repo.get_active_case(session.case_id)
        if not case:
            raise AppError("CASE_NOT_FOUND", "case not found or inactive", 404)
        memory_messages = runtime_memory.get_messages(session.memory_key)
        runtime_memory.add_message(session.memory_key or "", "doctor", message)
        logger.info(
            "chat_stream.start session_id=%s user_id=%s message_len=%s",
            session.id,
            ctx.user_id,
            len(message),
        )

        async def event_generator() -> AsyncIterator[str]:
            full_reply = ""
            started_at = time.perf_counter()
            stream_iter = self.orchestrator.patient_stream_reply(session, case, memory_messages, message).__aiter__()
            try:
                while True:
                    elapsed = time.perf_counter() - started_at
                    total_remaining = settings.llm_stream_total_timeout_seconds - elapsed
                    if total_remaining <= 0:
                        logger.warning("chat_stream.total_timeout session_id=%s", session.id)
                        yield self._sse_error("AI 病人回复总耗时超限,请重试或关闭流式模式", "LLM_STREAM_TIMEOUT")
                        return
                    # Until some text has arrived, the first-token budget applies; afterwards each chunk
                    # gets the regular chat budget. Both are capped by the remaining total budget.
                    per_chunk_limit = (
                        settings.llm_stream_first_token_timeout_seconds
                        if not full_reply
                        else settings.llm_chat_timeout_seconds
                    )
                    timeout = min(total_remaining, per_chunk_limit)
                    try:
                        chunk = await asyncio.wait_for(stream_iter.__anext__(), timeout=timeout)
                    except StopAsyncIteration:
                        # Upstream ended without an explicit `done` chunk. Use the same emptiness rule
                        # as the `done` branch so whitespace-only replies are never stored.
                        if full_reply.strip():
                            runtime_memory.add_message(session.memory_key or "", "patient", full_reply)
                            yield self._sse_done(
                                latency_ms=int((time.perf_counter() - started_at) * 1000),
                                first_token_ms=0,
                                model=None,
                                fallback_used=False,
                            )
                        else:
                            logger.warning("chat_stream.empty_stop session_id=%s", session.id)
                            yield self._sse_error("AI 病人没有返回有效内容,请重试", "LLM_EMPTY_RESPONSE")
                        return
                    except asyncio.TimeoutError:
                        # asyncio.TimeoutError is the builtin TimeoutError only from Python 3.11 on.
                        if full_reply:
                            # Text already arrived, so this is a mid-stream stall, not a first-token timeout.
                            logger.warning(
                                "chat_stream.chunk_timeout session_id=%s chars=%s", session.id, len(full_reply)
                            )
                            yield self._sse_error("AI 病人回复超时,请重试或关闭流式模式", "LLM_STREAM_TIMEOUT")
                        else:
                            logger.warning("chat_stream.first_token_timeout session_id=%s", session.id)
                            yield self._sse_error("AI 病人首段回复超时,请重试或关闭流式模式", "LLM_STREAM_TIMEOUT")
                        return
                    except Exception as exc:
                        logger.exception("chat_stream.failed session_id=%s error=%s", session.id, exc.__class__.__name__)
                        yield self._sse_error("AI 病人回复失败,请检查模型配置或稍后重试", "LLM_STREAM_FAILED")
                        return
                    if chunk.done:
                        if not full_reply.strip():
                            logger.warning("chat_stream.empty_done session_id=%s", session.id)
                            yield self._sse_error("AI 病人没有返回有效内容,请重试", "LLM_EMPTY_RESPONSE")
                            return
                        runtime_memory.add_message(session.memory_key or "", "patient", full_reply)
                        logger.info(
                            "chat_stream.done session_id=%s chars=%s latency_ms=%s model=%s",
                            session.id,
                            len(full_reply),
                            chunk.total_latency_ms,
                            chunk.model,
                        )
                        yield self._sse_done(
                            latency_ms=chunk.total_latency_ms or int((time.perf_counter() - started_at) * 1000),
                            first_token_ms=chunk.first_token_ms or 0,
                            model=chunk.model,
                            fallback_used=chunk.fallback_used,
                        )
                        return
                    if chunk.delta:
                        full_reply += chunk.delta
                        logger.debug("chat_stream.delta session_id=%s delta_len=%s", session.id, len(chunk.delta))
                        delta_payload = json.dumps({"delta": chunk.delta}, ensure_ascii=False)
                        yield f"event: message_delta\ndata: {delta_payload}\n\n"
            finally:
                # Runs on normal completion, early error returns and client disconnects alike, so the
                # upstream model stream is released promptly instead of waiting for GC.
                await self._close_stream(stream_iter, session.id)

        self.audit.log(ctx, "session.chat.stream", "training_session", str(session.id), session.id)
        return event_generator()

    async def _close_stream(self, stream_iter: AsyncIterator[object], session_id: int) -> None:
        """Best-effort close of an upstream stream iterator (no-op if it has no ``aclose``)."""
        aclose = getattr(stream_iter, "aclose", None)
        if aclose is None:
            return
        try:
            await aclose()
        except Exception:
            # Cleanup must never replace the frame the client already received; just record it.
            logger.debug("chat_stream.close_failed session_id=%s", session_id, exc_info=True)

    def _sse_done(self, latency_ms: int, first_token_ms: int, model: str | None, fallback_used: bool) -> str:
        """Build the SSE ``message_done`` frame carrying stream latency and model status."""
        done_payload = json.dumps(
            {
                "latency_ms": latency_ms,
                "first_token_ms": first_token_ms,
                "model": model,
                "fallback_used": fallback_used,
            },
            ensure_ascii=False,
        )
        return f"event: message_done\ndata: {done_payload}\n\n"

    def _sse_error(self, message: str, code: str = "LLM_STREAM_TIMEOUT") -> str:
        """Build the SSE ``error`` frame so the client can leave its pending state and show ``message``."""
        payload = json.dumps({"code": code, "message": message}, ensure_ascii=False)
        return f"event: error\ndata: {payload}\n\n"

    async def generate_hints(self, ctx: UserContext, session_id: int, payload: HintRequest) -> HintResponse:
        """Generate practice-mode hints from the conversation so far, the session's orders and the case.

        Raises:
            AppError: SESSION_STATUS_INVALID (400) when the session is not in practice mode
                or not in the inquiry stage; CASE_NOT_FOUND (404) when the case is missing.
        """
        session = self._get_session(session_id, ctx.user_id)
        if session.mode != "practice":
            raise AppError("SESSION_STATUS_INVALID", "hints are only available in practice mode", 400)
        if session.status != "inquiry":
            raise AppError("SESSION_STATUS_INVALID", "hints are only available during inquiry", 400)
        case = self.case_repo.get_active_case(session.case_id)
        if not case:
            raise AppError("CASE_NOT_FOUND", "case not found or inactive", 404)
        memory_messages = runtime_memory.get_messages(session.memory_key)
        result = await self.orchestrator.generate_hints(session, case, memory_messages, session.orders, payload.last_user_message)
        self.audit.log(ctx, "session.hints", "training_session", str(session.id), session.id)
        return HintResponse(**result)

    async def stream_hints(self, ctx: UserContext, session_id: int, payload: HintRequest) -> AsyncIterator[str]:
        """Stream a one-sentence practice hint as SSE frames.

        The structured hints are generated up front and condensed into one sentence. Any
        failure is turned into a single ``event: error`` frame instead of an exception.
        """
        started_at = time.perf_counter()
        try:
            hint_result = await self.generate_hints(ctx, session_id, payload)
            sentence = self._build_hint_sentence(hint_result)
        except AppError as exc:
            # Copy the fields out: `exc` is unbound when the except block ends, but the generator
            # body only runs later.
            error_message = exc.message
            error_code = exc.code

            async def app_error_generator() -> AsyncIterator[str]:
                yield self._sse_error(error_message, error_code)

            return app_error_generator()
        except Exception:
            logger.exception("hint_stream.failed session_id=%s", session_id)

            async def error_generator() -> AsyncIterator[str]:
                yield self._sse_error("练习提示生成失败,请稍后重试", "HINT_STREAM_FAILED")

            return error_generator()

        async def event_generator() -> AsyncIterator[str]:
            if not sentence:
                yield self._sse_error("当前没有生成有效提示,请继续问诊后再试", "HINT_EMPTY")
                return
            for chunk in self._chunk_text(sentence, size=12):
                payload_text = json.dumps({"delta": chunk}, ensure_ascii=False)
                yield f"event: hint_delta\ndata: {payload_text}\n\n"
                # Yield control to the event loop between frames (presumably so each delta can be flushed).
                await asyncio.sleep(0)
            done_payload = json.dumps({"latency_ms": int((time.perf_counter() - started_at) * 1000)}, ensure_ascii=False)
            yield f"event: hint_done\ndata: {done_payload}\n\n"

        return event_generator()

    def _build_hint_sentence(self, hint_result: HintResponse) -> str:
        """Condense structured hints into one readable sentence for streaming display.

        Uses up to three missing dimensions, then either the first suggested question or
        the first generic hint, then the first recommended order. Returns "" when nothing
        is available.
        """
        parts: list[str] = []
        if hint_result.missing_dimensions:
            # Separate the dimension names; joining with "" ran them together into one word.
            dimensions = "、".join(hint_result.missing_dimensions[:3])
            parts.append(f"当前可补充:{dimensions}")
        if hint_result.next_questions:
            parts.append(f"下一步可问:{hint_result.next_questions[0]}")
        elif hint_result.hints:
            parts.append(hint_result.hints[0])
        if hint_result.recommended_orders:
            order = hint_result.recommended_orders[0]
            item_code = order.get("item_code") or order.get("item_name") or "关键检查"
            reason = order.get("reason") or "用于完善病情判断"
            parts.append(f"可考虑申请{item_code},{reason}")
        if not parts:
            return ""
        return ";".join(parts) + "。"

    def _chunk_text(self, text: str, size: int) -> list[str]:
        """Split ``text`` into consecutive pieces of at most ``size`` characters for SSE streaming.

        Raises:
            ValueError: if ``size`` is not positive (a negative step would otherwise silently
                yield no chunks).
        """
        if size <= 0:
            raise ValueError(f"chunk size must be positive, got {size}")
        return [text[index : index + size] for index in range(0, len(text), size)]

    def complete_inquiry(self, ctx: UserContext, session_id: int) -> SessionStatusResponse:
        """Finish the inquiry stage and move the session to ``diagnosis``.

        Raises:
            AppError: SESSION_STATUS_INVALID (400) if not in inquiry, or INQUIRY_REQUIRED (400)
                if the doctor has not sent at least one message.
        """
        session = self._get_session(session_id, ctx.user_id)
        if session.status != "inquiry":
            raise AppError("SESSION_STATUS_INVALID", "only inquiry status can be completed", 400)
        if not runtime_memory.has_doctor_message(session.memory_key):
            raise AppError("INQUIRY_REQUIRED", "at least one doctor message is required", 400)
        self.session_repo.update_status(session, "diagnosis")
        self.audit.log(ctx, "session.complete_inquiry", "training_session", str(session.id), session.id)
        return SessionStatusResponse(session_id=session.id, status=session.status)

    def submit_diagnosis(self, ctx: UserContext, session_id: int, payload: SubmitDiagnosisRequest) -> SubmitDiagnosisResponse:
        """Save the primary/differential diagnoses and their basis, then move to ``treatment``.

        Creates the submission row if none exists yet, otherwise overwrites its diagnosis fields.

        Raises:
            AppError: SESSION_STATUS_INVALID (400) unless the session is in ``diagnosis``.
        """
        session = self._get_session(session_id, ctx.user_id)
        if session.status != "diagnosis":
            raise AppError("SESSION_STATUS_INVALID", "diagnosis submit is not allowed", 400)
        submission = self.session_repo.get_submission(session.id) or SessionSubmission(session_id=session.id, user_id=ctx.user_id)
        submission.primary_diagnosis = payload.primary_diagnosis
        submission.differential_diagnoses = payload.differential_diagnoses
        submission.diagnosis_basis = payload.diagnosis_basis
        submission.diagnosis_submitted_at = datetime.utcnow()
        self.session_repo.upsert_submission(submission)
        self.session_repo.update_status(session, "treatment")
        self.audit.log(ctx, "session.submit_diagnosis", "training_session", str(session.id), session.id)
        return SubmitDiagnosisResponse(status=session.status)

    def submit_treatment(self, ctx: UserContext, session_id: int, payload: SubmitTreatmentRequest) -> SubmitTreatmentResponse:
        """Save the treatment plan, risk plan, communication and follow-up, then move to ``evaluating``.

        Raises:
            AppError: SESSION_STATUS_INVALID (400) unless the session is in ``treatment``;
                DIAGNOSIS_REQUIRED (400) if no diagnosis submission exists.
        """
        session = self._get_session(session_id, ctx.user_id)
        if session.status != "treatment":
            raise AppError("SESSION_STATUS_INVALID", "treatment submit is not allowed", 400)
        submission = self.session_repo.get_submission(session.id)
        if not submission:
            raise AppError("DIAGNOSIS_REQUIRED", "diagnosis submission is required", 400)
        submission.treatment_principle = payload.treatment_principle
        submission.treatment_measures = payload.treatment_measures
        submission.risk_plan = payload.risk_plan
        submission.communication = payload.communication
        submission.follow_up = payload.follow_up
        submission.treatment_submitted_at = datetime.utcnow()
        self.session_repo.upsert_submission(submission)
        self.session_repo.update_status(session, "evaluating")
        self.audit.log(ctx, "session.submit_treatment", "training_session", str(session.id), session.id)
        return SubmitTreatmentResponse(status=session.status)

    def _get_session(self, session_id: int, user_id: str) -> TrainingSession:
        """Load a session by id, enforcing that it belongs to ``user_id``.

        Raises:
            AppError: SESSION_NOT_FOUND (404) when it does not exist or is owned by someone else.
        """
        session = self.session_repo.get_owned_session(session_id, user_id)
        if not session:
            raise AppError("SESSION_NOT_FOUND", "session not found or not owned by current user", 404)
        return session