Files
fastapi/backend/app/services/session_service.py
T
2026-06-01 09:25:26 +08:00

279 lines
14 KiB
Python

import asyncio
import time
import uuid
import json
import logging
from collections.abc import AsyncIterator
from datetime import datetime
from sqlalchemy.orm import Session
from app.agents.orchestrator import MedicalConsultationOrchestrator
from app.core.config import settings
from app.core.context import UserContext
from app.core.exceptions import AppError
from app.models.training import SessionSubmission, TrainingSession
from app.repositories.case_repository import CaseRepository
from app.repositories.session_repository import SessionRepository
from app.schemas.session import (
ChatResponse,
CreateSessionRequest,
CreateSessionResponse,
SessionStatusResponse,
SubmitDiagnosisRequest,
SubmitDiagnosisResponse,
SubmitTreatmentRequest,
SubmitTreatmentResponse,
HintRequest,
HintResponse,
)
from app.services.audit_service import AuditService
from app.services.runtime_memory import runtime_memory
# Module-level logger, used by SessionService for chat-stream diagnostics.
logger = logging.getLogger(__name__)
class SessionService:
    """Session service: creates sessions, handles inquiry chat, drives the status flow and stores submissions.

    Status flow enforced by the methods below:
    ``inquiry`` -> ``diagnosis`` -> ``treatment`` -> ``evaluating``.
    """

    def __init__(self, db: Session) -> None:
        self.db = db
        self.case_repo = CaseRepository(db)
        self.session_repo = SessionRepository(db)
        self.audit = AuditService(db)
        self.orchestrator = MedicalConsultationOrchestrator()

    def create_session(self, ctx: UserContext, payload: CreateSessionRequest) -> CreateSessionResponse:
        """Create a training session for an active case and seed its short-term memory.

        Raises:
            AppError: ``CASE_NOT_FOUND`` (404) when the case is missing or inactive.
        """
        case = self._get_active_case(payload.case_id)
        # NOTE(review): naive UTC timestamps via datetime.utcnow() (deprecated since 3.12);
        # switching to aware datetimes needs a check of the DB column types first.
        # Second-resolution timestamp plus a random suffix to avoid collisions within one second.
        session_code = f"sess_{datetime.utcnow().strftime('%Y%m%d%H%M%S')}_{uuid.uuid4().hex[:8]}"
        memory_key = f"mem:{session_code}"
        session = self.session_repo.create_session(
            TrainingSession(
                session_code=session_code,
                user_id=ctx.user_id,
                tenant_id=ctx.tenant_id,
                class_id=ctx.class_id,
                entry_scene=ctx.entry_scene,
                case_id=case.id,
                training_type=payload.training_type,
                mode=payload.mode,
                score_type=payload.score_type,
                status="inquiry",
                started_at=datetime.utcnow(),
                memory_key=memory_key,
                metadata_={"source": "demo"},
            )
        )
        # Fall back to a generic parent opening line when the case defines none.
        patient_opening = case.patient_opening or "家长:医生,孩子这几天不舒服,想请您看看。"
        runtime_memory.create(memory_key, patient_opening)
        self.audit.log(ctx, "session.create", "training_session", str(session.id), session.id)
        return CreateSessionResponse(
            session_id=session.id,
            session_code=session.session_code,
            status=session.status,
            patient_opening=patient_opening,
        )

    async def chat(self, ctx: UserContext, session_id: int, message: str) -> ChatResponse:
        """Inquiry chat: combine case context, short-term memory and the doctor's input, then ask the patient agent.

        Raises:
            AppError: ``SESSION_NOT_FOUND`` (404), ``SESSION_STATUS_INVALID`` (400),
                ``CASE_NOT_FOUND`` (404), or ``LLM_CALL_TIMEOUT`` (504) when the
                patient reply exceeds ``settings.llm_chat_timeout_seconds``.
        """
        session, case = self._load_chat_context(ctx, session_id)
        start = time.perf_counter()
        # History is read before the new doctor message is recorded; the message itself is
        # passed separately. NOTE(review): assumes get_messages returns a snapshot, not the live list.
        memory_messages = runtime_memory.get_messages(session.memory_key)
        runtime_memory.add_message(session.memory_key or "", "doctor", message)
        try:
            response = await asyncio.wait_for(
                self.orchestrator.patient_reply(session, case, memory_messages, message),
                timeout=settings.llm_chat_timeout_seconds,
            )
        except asyncio.TimeoutError as exc:
            # asyncio.TimeoutError is only an alias of the builtin TimeoutError from 3.11 on;
            # catching the asyncio name works on every supported version.
            raise AppError("LLM_CALL_TIMEOUT", "AI 病人回复超时,请稍后重试或切换为普通问诊", 504) from exc
        runtime_memory.add_message(session.memory_key or "", "patient", response.content)
        self.audit.log(ctx, "session.chat", "training_session", str(session.id), session.id)
        return ChatResponse(
            reply=response.content,
            latency_ms=response.latency_ms or int((time.perf_counter() - start) * 1000),
            model=response.model,
            fallback_used=self._is_fallback_model(response.model),
        )

    async def stream_chat(self, ctx: UserContext, session_id: int, message: str) -> AsyncIterator[str]:
        """Streaming inquiry: return an async iterator of SSE events carrying the AI patient's reply.

        Validation errors (session ownership, status, case) are raised as ``AppError`` when this
        coroutine is awaited, before any event is produced. Failures while streaming are reported
        to the client as SSE ``error`` events instead of exceptions.
        """
        session, case = self._load_chat_context(ctx, session_id)
        memory_messages = runtime_memory.get_messages(session.memory_key)
        runtime_memory.add_message(session.memory_key or "", "doctor", message)
        logger.info(
            "chat_stream.start session_id=%s user_id=%s message_len=%s",
            session.id,
            ctx.user_id,
            len(message),
        )

        async def event_generator() -> AsyncIterator[str]:
            full_reply = ""
            first_token_ms = 0
            started_at = time.perf_counter()
            stream_iter = self.orchestrator.patient_stream_reply(session, case, memory_messages, message).__aiter__()
            try:
                while True:
                    elapsed = time.perf_counter() - started_at
                    total_remaining = settings.llm_stream_total_timeout_seconds - elapsed
                    if total_remaining <= 0:
                        logger.warning("chat_stream.total_timeout session_id=%s", session.id)
                        yield self._sse_error("AI 病人回复总耗时超限,请重试或关闭流式模式", "LLM_STREAM_TIMEOUT")
                        return
                    # Before any text arrives the first-token timeout applies; afterwards each chunk
                    # may take up to the regular chat timeout. Never wait past the total budget.
                    per_chunk_timeout = (
                        settings.llm_stream_first_token_timeout_seconds
                        if not full_reply
                        else settings.llm_chat_timeout_seconds
                    )
                    timeout = min(total_remaining, per_chunk_timeout)
                    try:
                        chunk = await asyncio.wait_for(stream_iter.__anext__(), timeout=timeout)
                    except StopAsyncIteration:
                        # Upstream ended without an explicit ``done`` chunk.
                        if full_reply.strip():
                            runtime_memory.add_message(session.memory_key or "", "patient", full_reply)
                            yield self._sse_done(
                                latency_ms=int((time.perf_counter() - started_at) * 1000),
                                first_token_ms=first_token_ms,
                                model=None,
                                fallback_used=False,
                            )
                        else:
                            logger.warning("chat_stream.empty_stop session_id=%s", session.id)
                            yield self._sse_error("AI 病人没有返回有效内容,请重试", "LLM_EMPTY_RESPONSE")
                        return
                    except asyncio.TimeoutError:
                        if full_reply:
                            logger.warning("chat_stream.chunk_timeout session_id=%s", session.id)
                            yield self._sse_error("AI 病人回复超时,请重试或关闭流式模式", "LLM_STREAM_TIMEOUT")
                        else:
                            logger.warning("chat_stream.first_token_timeout session_id=%s", session.id)
                            yield self._sse_error("AI 病人首段回复超时,请重试或关闭流式模式", "LLM_STREAM_TIMEOUT")
                        return
                    except Exception as exc:
                        logger.exception("chat_stream.failed session_id=%s error=%s", session.id, exc.__class__.__name__)
                        yield self._sse_error("AI 病人回复失败,请检查模型配置或稍后重试", "LLM_STREAM_FAILED")
                        return
                    if chunk.done:
                        if not full_reply.strip():
                            logger.warning("chat_stream.empty_done session_id=%s", session.id)
                            yield self._sse_error("AI 病人没有返回有效内容,请重试", "LLM_EMPTY_RESPONSE")
                            return
                        runtime_memory.add_message(session.memory_key or "", "patient", full_reply)
                        logger.info(
                            "chat_stream.done session_id=%s chars=%s latency_ms=%s model=%s",
                            session.id,
                            len(full_reply),
                            chunk.total_latency_ms,
                            chunk.model,
                        )
                        yield self._sse_done(
                            latency_ms=chunk.total_latency_ms or int((time.perf_counter() - started_at) * 1000),
                            first_token_ms=chunk.first_token_ms or first_token_ms,
                            model=chunk.model,
                            fallback_used=chunk.fallback_used,
                        )
                        return
                    if chunk.delta:
                        if not full_reply:
                            # Locally measured first-token latency, used when upstream does not report one.
                            first_token_ms = int((time.perf_counter() - started_at) * 1000)
                        full_reply += chunk.delta
                        logger.debug("chat_stream.delta session_id=%s delta_len=%s", session.id, len(chunk.delta))
                        yield self._sse_delta(chunk.delta)
            finally:
                # Release the upstream stream on every exit path, including client disconnects.
                await self._close_stream(stream_iter)

        # Audited when the stream is set up; the generator itself runs after this method returns.
        self.audit.log(ctx, "session.chat.stream", "training_session", str(session.id), session.id)
        return event_generator()

    def _sse_done(self, latency_ms: int, first_token_ms: int, model: str | None, fallback_used: bool) -> str:
        """Build the SSE ``message_done`` event reporting stream latency and model status."""
        done_payload = json.dumps(
            {
                "latency_ms": latency_ms,
                "first_token_ms": first_token_ms,
                "model": model,
                "fallback_used": fallback_used,
            },
            ensure_ascii=False,
        )
        return f"event: message_done\ndata: {done_payload}\n\n"

    def _sse_error(self, message: str, code: str = "LLM_STREAM_TIMEOUT") -> str:
        """Build the SSE ``error`` event so the frontend can leave its pending state and show a readable error."""
        payload = json.dumps({"code": code, "message": message}, ensure_ascii=False)
        return f"event: error\ndata: {payload}\n\n"

    @staticmethod
    def _sse_delta(delta: str) -> str:
        """Build the SSE ``message_delta`` event carrying one chunk of reply text."""
        payload = json.dumps({"delta": delta}, ensure_ascii=False)
        return f"event: message_delta\ndata: {payload}\n\n"

    @staticmethod
    def _is_fallback_model(model: str | None) -> bool:
        """Return True when the model name marks a mock fallback reply; tolerates a missing name."""
        return (model or "").startswith("mock-fallback")

    @staticmethod
    async def _close_stream(stream_iter: object) -> None:
        """Best-effort ``aclose()`` of an upstream async iterator; iterators without ``aclose`` are ignored."""
        aclose = getattr(stream_iter, "aclose", None)
        if aclose is None:
            return
        try:
            await aclose()
        except Exception:
            # Closing is cleanup only; never let it replace the event already sent to the client.
            logger.debug("chat_stream.close_failed", exc_info=True)

    async def generate_hints(self, ctx: UserContext, session_id: int, payload: HintRequest) -> HintResponse:
        """Beginner hints: generate reminders from session history, ordered examinations and the case.

        Raises:
            AppError: ``SESSION_STATUS_INVALID`` (400) outside practice mode or outside inquiry,
                ``SESSION_NOT_FOUND`` / ``CASE_NOT_FOUND`` (404).
        """
        session = self._get_session(session_id, ctx.user_id)
        if session.mode != "practice":
            raise AppError("SESSION_STATUS_INVALID", "hints are only available in practice mode", 400)
        if session.status != "inquiry":
            raise AppError("SESSION_STATUS_INVALID", "hints are only available during inquiry", 400)
        case = self._get_active_case(session.case_id)
        memory_messages = runtime_memory.get_messages(session.memory_key)
        result = await self.orchestrator.generate_hints(session, case, memory_messages, session.orders, payload.last_user_message)
        self.audit.log(ctx, "session.hints", "training_session", str(session.id), session.id)
        return HintResponse(**result)

    def complete_inquiry(self, ctx: UserContext, session_id: int) -> SessionStatusResponse:
        """Finish the inquiry stage: require at least one doctor message, then move to ``diagnosis``."""
        session = self._get_session(session_id, ctx.user_id)
        if session.status != "inquiry":
            raise AppError("SESSION_STATUS_INVALID", "only inquiry status can be completed", 400)
        if not runtime_memory.has_doctor_message(session.memory_key):
            raise AppError("INQUIRY_REQUIRED", "at least one doctor message is required", 400)
        self.session_repo.update_status(session, "diagnosis")
        self.audit.log(ctx, "session.complete_inquiry", "training_session", str(session.id), session.id)
        return SessionStatusResponse(session_id=session.id, status=session.status)

    def submit_diagnosis(self, ctx: UserContext, session_id: int, payload: SubmitDiagnosisRequest) -> SubmitDiagnosisResponse:
        """Save primary diagnosis, differential diagnoses and diagnostic basis, then move to ``treatment``."""
        session = self._get_session(session_id, ctx.user_id)
        if session.status != "diagnosis":
            raise AppError("SESSION_STATUS_INVALID", "diagnosis submit is not allowed", 400)
        # Reuse an existing submission row if present, otherwise start a new one.
        submission = self.session_repo.get_submission(session.id) or SessionSubmission(session_id=session.id, user_id=ctx.user_id)
        submission.primary_diagnosis = payload.primary_diagnosis
        submission.differential_diagnoses = payload.differential_diagnoses
        submission.diagnosis_basis = payload.diagnosis_basis
        submission.diagnosis_submitted_at = datetime.utcnow()
        self.session_repo.upsert_submission(submission)
        self.session_repo.update_status(session, "treatment")
        self.audit.log(ctx, "session.submit_diagnosis", "training_session", str(session.id), session.id)
        return SubmitDiagnosisResponse(status=session.status)

    def submit_treatment(self, ctx: UserContext, session_id: int, payload: SubmitTreatmentRequest) -> SubmitTreatmentResponse:
        """Save treatment plan, risk plan, communication and follow-up, then move to ``evaluating``.

        Raises:
            AppError: ``DIAGNOSIS_REQUIRED`` (400) when no diagnosis submission exists yet.
        """
        session = self._get_session(session_id, ctx.user_id)
        if session.status != "treatment":
            raise AppError("SESSION_STATUS_INVALID", "treatment submit is not allowed", 400)
        submission = self.session_repo.get_submission(session.id)
        if not submission:
            raise AppError("DIAGNOSIS_REQUIRED", "diagnosis submission is required", 400)
        submission.treatment_principle = payload.treatment_principle
        submission.treatment_measures = payload.treatment_measures
        submission.risk_plan = payload.risk_plan
        submission.communication = payload.communication
        submission.follow_up = payload.follow_up
        submission.treatment_submitted_at = datetime.utcnow()
        self.session_repo.upsert_submission(submission)
        self.session_repo.update_status(session, "evaluating")
        self.audit.log(ctx, "session.submit_treatment", "training_session", str(session.id), session.id)
        return SubmitTreatmentResponse(status=session.status)

    def _load_chat_context(self, ctx: UserContext, session_id: int):
        """Return ``(session, case)`` for a chat call, enforcing ownership, inquiry status and an active case."""
        session = self._get_session(session_id, ctx.user_id)
        if session.status != "inquiry":
            raise AppError("SESSION_STATUS_INVALID", "chat is only allowed in inquiry status", 400)
        return session, self._get_active_case(session.case_id)

    def _get_active_case(self, case_id):
        """Return the active case for ``case_id`` or raise ``CASE_NOT_FOUND`` (404)."""
        case = self.case_repo.get_active_case(case_id)
        if not case:
            raise AppError("CASE_NOT_FOUND", "case not found or inactive", 404)
        return case

    def _get_session(self, session_id: int, user_id: str) -> TrainingSession:
        """Ownership check: load the session by ``session_id`` and ``user_id`` so users only reach their own sessions."""
        session = self.session_repo.get_owned_session(session_id, user_id)
        if not session:
            raise AppError("SESSION_NOT_FOUND", "session not found or not owned by current user", 404)
        return session