import asyncio
import time
import uuid
import json
import logging
from collections.abc import AsyncIterator
from datetime import datetime

from sqlalchemy.orm import Session

from app.agents.orchestrator import MedicalConsultationOrchestrator
from app.core.config import settings
from app.core.context import UserContext
from app.core.exceptions import AppError
from app.models.training import SessionSubmission, TrainingSession
from app.repositories.case_repository import CaseRepository
from app.repositories.session_repository import SessionRepository
from app.schemas.session import (
    ChatResponse,
    CreateSessionRequest,
    CreateSessionResponse,
    SessionStatusResponse,
    SubmitDiagnosisRequest,
    SubmitDiagnosisResponse,
    SubmitTreatmentRequest,
    SubmitTreatmentResponse,
    HintRequest,
    HintResponse,
)
from app.services.audit_service import AuditService
from app.services.runtime_memory import runtime_memory

logger = logging.getLogger(__name__)


class SessionService:
    """Session service for the training consultation workflow.

    Creates sessions, runs the (plain and streaming) inquiry chat, drives the
    multi-stage status flow and stores diagnosis / treatment submissions.

    Status transitions performed by this class:
    ``inquiry`` -> ``diagnosis`` -> ``treatment`` -> ``evaluating``.
    """

    def __init__(self, db: Session) -> None:
        """Wire the repositories, audit service and orchestrator to ``db``."""
        self.db = db
        self.case_repo = CaseRepository(db)
        self.session_repo = SessionRepository(db)
        self.audit = AuditService(db)
        self.orchestrator = MedicalConsultationOrchestrator()

    def create_session(self, ctx: UserContext, payload: CreateSessionRequest) -> CreateSessionResponse:
        """Create a session: validate the case and initialise short-term memory.

        Args:
            ctx: Caller context; its user/tenant/class/entry-scene fields are
                copied onto the new session.
            payload: Requested case id, training type, mode and score type.

        Returns:
            The new session's id, code, status and the patient's opening line.

        Raises:
            AppError: ``CASE_NOT_FOUND`` when the case is missing or inactive.
        """
        case = self.case_repo.get_active_case(payload.case_id)
        if not case:
            raise AppError("CASE_NOT_FOUND", "case not found or inactive", 404)

        # Timestamp plus a random suffix keeps codes readable yet distinct.
        session_code = f"sess_{datetime.utcnow().strftime('%Y%m%d%H%M%S')}_{uuid.uuid4().hex[:8]}"
        memory_key = f"mem:{session_code}"
        session = self.session_repo.create_session(
            TrainingSession(
                session_code=session_code,
                user_id=ctx.user_id,
                tenant_id=ctx.tenant_id,
                class_id=ctx.class_id,
                entry_scene=ctx.entry_scene,
                case_id=case.id,
                training_type=payload.training_type,
                mode=payload.mode,
                score_type=payload.score_type,
                status="inquiry",
                started_at=datetime.utcnow(),
                memory_key=memory_key,
                metadata_={"source": "demo"},
            )
        )

        # Fall back to a generic opening line when the case defines none.
        patient_opening = case.patient_opening or "家长:医生,孩子这几天不舒服,想请您看看。"
        runtime_memory.create(memory_key, patient_opening)
        self.audit.log(ctx, "session.create", "training_session", str(session.id), session.id)
        return CreateSessionResponse(
            session_id=session.id,
            session_code=session.session_code,
            status=session.status,
            patient_opening=patient_opening,
        )

    async def chat(self, ctx: UserContext, session_id: int, message: str) -> ChatResponse:
        """Inquiry chat: send case context, short-term memory and the user's
        message to the patient agent and return its reply.

        Args:
            ctx: Caller context; the session must be owned by ``ctx.user_id``.
            session_id: Id of the session to chat in.
            message: The doctor's (user's) message.

        Returns:
            The reply text with latency, model name and fallback flag.

        Raises:
            AppError: ``SESSION_NOT_FOUND``, ``SESSION_STATUS_INVALID`` (session
                not in ``inquiry``), ``CASE_NOT_FOUND`` or ``LLM_CALL_TIMEOUT``.
        """
        session = self._get_session(session_id, ctx.user_id)
        if session.status != "inquiry":
            raise AppError("SESSION_STATUS_INVALID", "chat is only allowed in inquiry status", 400)
        case = self.case_repo.get_active_case(session.case_id)
        if not case:
            raise AppError("CASE_NOT_FOUND", "case not found or inactive", 404)

        start = time.perf_counter()
        # History is read before the new doctor message is stored; the message
        # itself is handed to the orchestrator as a separate argument.
        memory_messages = runtime_memory.get_messages(session.memory_key)
        runtime_memory.add_message(session.memory_key or "", "doctor", message)
        try:
            response = await asyncio.wait_for(
                self.orchestrator.patient_reply(session, case, memory_messages, message),
                timeout=settings.llm_chat_timeout_seconds,
            )
        except asyncio.TimeoutError as exc:
            # asyncio.TimeoutError is what wait_for raises; it only became an
            # alias of the builtin TimeoutError in Python 3.11, so catching the
            # asyncio name works on every supported version.
            raise AppError("LLM_CALL_TIMEOUT", "AI 病人回复超时,请稍后重试或切换为普通问诊", 504) from exc

        runtime_memory.add_message(session.memory_key or "", "patient", response.content)
        self.audit.log(ctx, "session.chat", "training_session", str(session.id), session.id)
        return ChatResponse(
            reply=response.content,
            # Prefer the latency reported by the agent; otherwise measure it here.
            latency_ms=response.latency_ms or int((time.perf_counter() - start) * 1000),
            model=response.model,
            fallback_used=response.model.startswith("mock-fallback"),
        )

    async def stream_chat(self, ctx: UserContext, session_id: int, message: str) -> AsyncIterator[str]:
        """Streaming inquiry: return the patient reply as SSE-formatted events.

        Validation happens eagerly (before any event is produced), so the
        ``AppError`` cases surface when this coroutine is awaited. Failures
        during streaming are reported as SSE ``error`` events instead.

        Args:
            ctx: Caller context; the session must be owned by ``ctx.user_id``.
            session_id: Id of the session to chat in.
            message: The doctor's (user's) message.

        Returns:
            An async iterator yielding ``message_delta`` events followed by one
            ``message_done`` event, or a single ``error`` event on failure.

        Raises:
            AppError: ``SESSION_NOT_FOUND``, ``SESSION_STATUS_INVALID`` (session
                not in ``inquiry``) or ``CASE_NOT_FOUND``.
        """
        session = self._get_session(session_id, ctx.user_id)
        if session.status != "inquiry":
            raise AppError("SESSION_STATUS_INVALID", "chat is only allowed in inquiry status", 400)
        case = self.case_repo.get_active_case(session.case_id)
        if not case:
            raise AppError("CASE_NOT_FOUND", "case not found or inactive", 404)

        memory_messages = runtime_memory.get_messages(session.memory_key)
        runtime_memory.add_message(session.memory_key or "", "doctor", message)
        logger.info(
            "chat_stream.start session_id=%s user_id=%s message_len=%s",
            session.id,
            ctx.user_id,
            len(message),
        )

        async def event_generator() -> AsyncIterator[str]:
            """Pull chunks from the orchestrator and translate them to SSE."""
            full_reply = ""
            started_at = time.perf_counter()
            stream_iter = self.orchestrator.patient_stream_reply(session, case, memory_messages, message).__aiter__()
            try:
                while True:
                    # Enforce the overall budget before waiting for more data.
                    elapsed = time.perf_counter() - started_at
                    total_remaining = settings.llm_stream_total_timeout_seconds - elapsed
                    if total_remaining <= 0:
                        logger.warning("chat_stream.total_timeout session_id=%s", session.id)
                        yield self._sse_error("AI 病人回复总耗时超限,请重试或关闭流式模式", "LLM_STREAM_TIMEOUT")
                        return

                    # The first chunk has its own limit; later chunks use the
                    # regular chat timeout. Neither may exceed what is left of
                    # the total budget.
                    timeout = min(
                        total_remaining,
                        settings.llm_stream_first_token_timeout_seconds
                        if not full_reply
                        else settings.llm_chat_timeout_seconds,
                    )
                    try:
                        chunk = await asyncio.wait_for(stream_iter.__anext__(), timeout=timeout)
                    except StopAsyncIteration:
                        # Stream ended without an explicit "done" chunk.
                        if full_reply:
                            runtime_memory.add_message(session.memory_key or "", "patient", full_reply)
                            yield self._sse_done(
                                latency_ms=int((time.perf_counter() - started_at) * 1000),
                                first_token_ms=0,
                                model=None,
                                fallback_used=False,
                            )
                        else:
                            logger.warning("chat_stream.empty_stop session_id=%s", session.id)
                            yield self._sse_error("AI 病人没有返回有效内容,请重试", "LLM_EMPTY_RESPONSE")
                        return
                    except asyncio.TimeoutError:
                        # asyncio.TimeoutError (not the builtin) is what wait_for
                        # raises before Python 3.11; on 3.11+ they are the same.
                        # Log which wait actually expired. The user-facing
                        # payload is kept unchanged for front-end compatibility.
                        timeout_kind = "first_token_timeout" if not full_reply else "chunk_timeout"
                        logger.warning("chat_stream.%s session_id=%s", timeout_kind, session.id)
                        yield self._sse_error("AI 病人首段回复超时,请重试或关闭流式模式", "LLM_STREAM_TIMEOUT")
                        return
                    except Exception as exc:
                        logger.exception(
                            "chat_stream.failed session_id=%s error=%s", session.id, exc.__class__.__name__
                        )
                        yield self._sse_error("AI 病人回复失败,请检查模型配置或稍后重试", "LLM_STREAM_FAILED")
                        return

                    if chunk.done:
                        if not full_reply.strip():
                            logger.warning("chat_stream.empty_done session_id=%s", session.id)
                            yield self._sse_error("AI 病人没有返回有效内容,请重试", "LLM_EMPTY_RESPONSE")
                            return
                        # Only a complete reply is written to short-term memory.
                        runtime_memory.add_message(session.memory_key or "", "patient", full_reply)
                        logger.info(
                            "chat_stream.done session_id=%s chars=%s latency_ms=%s model=%s",
                            session.id,
                            len(full_reply),
                            chunk.total_latency_ms,
                            chunk.model,
                        )
                        yield self._sse_done(
                            latency_ms=chunk.total_latency_ms or int((time.perf_counter() - started_at) * 1000),
                            first_token_ms=chunk.first_token_ms or 0,
                            model=chunk.model,
                            fallback_used=chunk.fallback_used,
                        )
                        return

                    if chunk.delta:
                        full_reply += chunk.delta
                        logger.debug("chat_stream.delta session_id=%s delta_len=%s", session.id, len(chunk.delta))
                        delta_payload = json.dumps({"delta": chunk.delta}, ensure_ascii=False)
                        yield f"event: message_delta\ndata: {delta_payload}\n\n"
            finally:
                # Release the upstream model stream on every exit path (done,
                # timeout, failure, or the client going away) instead of
                # leaving it suspended until garbage collection. Plain async
                # iterators may not provide aclose(), hence the getattr.
                aclose = getattr(stream_iter, "aclose", None)
                if aclose is not None:
                    try:
                        await aclose()
                    except Exception:
                        logger.debug("chat_stream.close_failed session_id=%s", session.id, exc_info=True)

        # The audit entry is written when the stream is set up, not when it ends.
        self.audit.log(ctx, "session.chat.stream", "training_session", str(session.id), session.id)
        return event_generator()

    def _sse_done(self, latency_ms: int, first_token_ms: int, model: str | None, fallback_used: bool) -> str:
        """Build the SSE ``message_done`` event carrying timing and model info."""
        done_payload = json.dumps(
            {
                "latency_ms": latency_ms,
                "first_token_ms": first_token_ms,
                "model": model,
                "fallback_used": fallback_used,
            },
            ensure_ascii=False,
        )
        return f"event: message_done\ndata: {done_payload}\n\n"

    def _sse_error(self, message: str, code: str = "LLM_STREAM_TIMEOUT") -> str:
        """Build the SSE ``error`` event so the front end can leave its pending
        state and show a user-readable message."""
        payload = json.dumps({"code": code, "message": message}, ensure_ascii=False)
        return f"event: error\ndata: {payload}\n\n"

    async def generate_hints(self, ctx: UserContext, session_id: int, payload: HintRequest) -> HintResponse:
        """Beginner hints: generate reminders from the session context, the
        orders already on the session and the case information.

        Raises:
            AppError: ``SESSION_NOT_FOUND``, ``SESSION_STATUS_INVALID`` (mode is
                not ``practice`` or status is not ``inquiry``) or
                ``CASE_NOT_FOUND``.
        """
        session = self._get_session(session_id, ctx.user_id)
        if session.mode != "practice":
            raise AppError("SESSION_STATUS_INVALID", "hints are only available in practice mode", 400)
        if session.status != "inquiry":
            raise AppError("SESSION_STATUS_INVALID", "hints are only available during inquiry", 400)
        case = self.case_repo.get_active_case(session.case_id)
        if not case:
            raise AppError("CASE_NOT_FOUND", "case not found or inactive", 404)

        memory_messages = runtime_memory.get_messages(session.memory_key)
        result = await self.orchestrator.generate_hints(
            session, case, memory_messages, session.orders, payload.last_user_message
        )
        self.audit.log(ctx, "session.hints", "training_session", str(session.id), session.id)
        # The orchestrator result is unpacked as keyword arguments, so it is
        # expected to be a mapping matching HintResponse's fields.
        return HintResponse(**result)

    def complete_inquiry(self, ctx: UserContext, session_id: int) -> SessionStatusResponse:
        """Finish the inquiry: require at least one doctor message, then move
        the session to the ``diagnosis`` stage.

        Raises:
            AppError: ``SESSION_NOT_FOUND``, ``SESSION_STATUS_INVALID`` or
                ``INQUIRY_REQUIRED``.
        """
        session = self._get_session(session_id, ctx.user_id)
        if session.status != "inquiry":
            raise AppError("SESSION_STATUS_INVALID", "only inquiry status can be completed", 400)
        if not runtime_memory.has_doctor_message(session.memory_key):
            raise AppError("INQUIRY_REQUIRED", "at least one doctor message is required", 400)

        self.session_repo.update_status(session, "diagnosis")
        self.audit.log(ctx, "session.complete_inquiry", "training_session", str(session.id), session.id)
        return SessionStatusResponse(session_id=session.id, status=session.status)

    def submit_diagnosis(
        self, ctx: UserContext, session_id: int, payload: SubmitDiagnosisRequest
    ) -> SubmitDiagnosisResponse:
        """Submit the diagnosis: store the primary diagnosis, differential
        diagnoses and diagnosis basis, then move to the ``treatment`` stage.

        Raises:
            AppError: ``SESSION_NOT_FOUND`` or ``SESSION_STATUS_INVALID``
                (session not in ``diagnosis``).
        """
        session = self._get_session(session_id, ctx.user_id)
        if session.status != "diagnosis":
            raise AppError("SESSION_STATUS_INVALID", "diagnosis submit is not allowed", 400)

        # Reuse the session's submission row when one exists; otherwise start one.
        submission = self.session_repo.get_submission(session.id) or SessionSubmission(
            session_id=session.id, user_id=ctx.user_id
        )
        submission.primary_diagnosis = payload.primary_diagnosis
        submission.differential_diagnoses = payload.differential_diagnoses
        submission.diagnosis_basis = payload.diagnosis_basis
        submission.diagnosis_submitted_at = datetime.utcnow()
        self.session_repo.upsert_submission(submission)

        self.session_repo.update_status(session, "treatment")
        self.audit.log(ctx, "session.submit_diagnosis", "training_session", str(session.id), session.id)
        return SubmitDiagnosisResponse(status=session.status)

    def submit_treatment(
        self, ctx: UserContext, session_id: int, payload: SubmitTreatmentRequest
    ) -> SubmitTreatmentResponse:
        """Submit the treatment: store the treatment plan, risk plan,
        communication and follow-up, then move to the ``evaluating`` stage.

        Raises:
            AppError: ``SESSION_NOT_FOUND``, ``SESSION_STATUS_INVALID`` (session
                not in ``treatment``) or ``DIAGNOSIS_REQUIRED`` (no submission
                record exists yet).
        """
        session = self._get_session(session_id, ctx.user_id)
        if session.status != "treatment":
            raise AppError("SESSION_STATUS_INVALID", "treatment submit is not allowed", 400)
        submission = self.session_repo.get_submission(session.id)
        if not submission:
            raise AppError("DIAGNOSIS_REQUIRED", "diagnosis submission is required", 400)

        submission.treatment_principle = payload.treatment_principle
        submission.treatment_measures = payload.treatment_measures
        submission.risk_plan = payload.risk_plan
        submission.communication = payload.communication
        submission.follow_up = payload.follow_up
        submission.treatment_submitted_at = datetime.utcnow()
        self.session_repo.upsert_submission(submission)

        self.session_repo.update_status(session, "evaluating")
        self.audit.log(ctx, "session.submit_treatment", "training_session", str(session.id), session.id)
        return SubmitTreatmentResponse(status=session.status)

    def _get_session(self, session_id: int, user_id: str) -> TrainingSession:
        """Session ownership: look the session up by ``session_id`` and
        ``user_id`` so users can only reach their own sessions.

        Raises:
            AppError: ``SESSION_NOT_FOUND`` when no owned session matches.
        """
        session = self.session_repo.get_owned_session(session_id, user_id)
        if not session:
            raise AppError("SESSION_NOT_FOUND", "session not found or not owned by current user", 404)
        return session