import asyncio
import time
import uuid
import json
import logging
from collections.abc import AsyncIterator
from datetime import datetime

from sqlalchemy.orm import Session

from app.agents.orchestrator import MedicalConsultationOrchestrator
from app.core.config import settings
from app.core.context import UserContext
from app.core.exceptions import AppError
from app.models.training import SessionSubmission, TrainingSession
from app.repositories.case_repository import CaseRepository
from app.repositories.session_repository import SessionRepository
from app.schemas.session import (
    ChatResponse,
    CreateSessionRequest,
    CreateSessionResponse,
    SessionStatusResponse,
    SubmitDiagnosisRequest,
    SubmitDiagnosisResponse,
    SubmitTreatmentRequest,
    SubmitTreatmentResponse,
    HintRequest,
    HintResponse,
)
from app.services.audit_service import AuditService
from app.services.runtime_memory import runtime_memory
from app.services.training_config_service import TrainingConfigService

logger = logging.getLogger(__name__)


class SessionService:
    """Session service.

    Handles session creation, inquiry chat (blocking and SSE-streamed),
    practice hints, status transitions
    (``inquiry`` -> ``diagnosis`` -> ``treatment`` -> ``evaluating``) and the
    diagnosis / treatment submissions.
    """

    def __init__(self, db: Session) -> None:
        """Wire the repositories, audit service and orchestrator to ``db``."""
        self.db = db
        self.case_repo = CaseRepository(db)
        self.session_repo = SessionRepository(db)
        self.audit = AuditService(db)
        self.orchestrator = MedicalConsultationOrchestrator()

    def create_session(self, ctx: UserContext, payload: CreateSessionRequest) -> CreateSessionResponse:
        """Create a training session and initialise its short-term memory.

        Validates the case, normalises the patient configuration, persists a
        ``TrainingSession`` in ``inquiry`` status and seeds the runtime memory
        with the patient's opening line.

        Raises:
            AppError: ``CASE_NOT_FOUND`` (404) when the case lookup returns
                nothing.
        """
        case = self._get_active_case(payload.case_id)
        patient_config = TrainingConfigService(self.db).normalize_patient_config(
            payload.patient_config, case
        )
        session_code = f"sess_{datetime.utcnow().strftime('%Y%m%d%H%M%S')}_{uuid.uuid4().hex[:8]}"
        memory_key = f"mem:{session_code}"
        session = self.session_repo.create_session(
            TrainingSession(
                session_code=session_code,
                user_id=ctx.user_id,
                tenant_id=ctx.tenant_id,
                class_id=ctx.class_id,
                entry_scene=ctx.entry_scene,
                case_id=case.id,
                training_type=payload.training_type,
                mode=payload.mode,
                score_type=payload.score_type,
                status="inquiry",
                started_at=datetime.utcnow(),
                memory_key=memory_key,
                metadata_={"source": "demo", "patient_config": patient_config},
            )
        )
        # Fall back to a generic opening line when the case does not define one.
        patient_opening = case.patient_opening or "家长:医生,孩子这几天不舒服,想请您看看。"
        runtime_memory.create(memory_key, patient_opening)
        self._log_session_event(ctx, "session.create", session)
        return CreateSessionResponse(
            session_id=session.id,
            session_code=session.session_code,
            status=session.status,
            patient_opening=patient_opening,
            patient_config=patient_config,
        )

    async def chat(self, ctx: UserContext, session_id: int, message: str) -> ChatResponse:
        """Run one inquiry turn against the patient agent.

        The doctor message is appended to the runtime memory before the agent
        is called; the patient reply is appended once it arrives.

        Raises:
            AppError: ``SESSION_NOT_FOUND`` (404), ``SESSION_STATUS_INVALID``
                (400) when the session is not in ``inquiry`` status,
                ``CASE_NOT_FOUND`` (404), or ``LLM_CALL_TIMEOUT`` (504) when
                the agent does not answer within
                ``settings.llm_chat_timeout_seconds``.
        """
        session = self._get_session(session_id, ctx.user_id)
        if session.status != "inquiry":
            raise AppError("SESSION_STATUS_INVALID", "chat is only allowed in inquiry status", 400)
        case = self._get_active_case(session.case_id)

        start = time.perf_counter()
        memory_messages = runtime_memory.get_messages(session.memory_key)
        runtime_memory.add_message(session.memory_key or "", "doctor", message)
        try:
            response = await asyncio.wait_for(
                self.orchestrator.patient_reply(session, case, memory_messages, message),
                timeout=settings.llm_chat_timeout_seconds,
            )
        # asyncio.TimeoutError is only an alias of the builtin TimeoutError from
        # Python 3.11 on; catching the asyncio name works on every version.
        except asyncio.TimeoutError as exc:
            raise AppError("LLM_CALL_TIMEOUT", "AI 病人回复超时,请稍后重试或切换为普通问诊", 504) from exc

        runtime_memory.add_message(session.memory_key or "", "patient", response.content)
        self._log_session_event(ctx, "session.chat", session)
        return ChatResponse(
            reply=response.content,
            # Prefer the latency reported by the agent; otherwise measure locally.
            latency_ms=response.latency_ms or int((time.perf_counter() - start) * 1000),
            model=response.model,
            # Tolerate a missing model name instead of raising AttributeError.
            fallback_used=(response.model or "").startswith("mock-fallback"),
        )

    async def stream_chat(self, ctx: UserContext, session_id: int, message: str) -> AsyncIterator[str]:
        """Run one inquiry turn and return the reply as an SSE event stream.

        Validation happens eagerly (so ``AppError`` is raised before any event
        is produced); the returned async generator then yields
        ``message_delta`` events followed by one ``message_done`` or ``error``
        event.

        Raises:
            AppError: ``SESSION_NOT_FOUND`` (404), ``SESSION_STATUS_INVALID``
                (400) or ``CASE_NOT_FOUND`` (404).
        """
        session = self._get_session(session_id, ctx.user_id)
        if session.status != "inquiry":
            raise AppError("SESSION_STATUS_INVALID", "chat is only allowed in inquiry status", 400)
        case = self._get_active_case(session.case_id)

        memory_messages = runtime_memory.get_messages(session.memory_key)
        runtime_memory.add_message(session.memory_key or "", "doctor", message)
        logger.info(
            "chat_stream.start session_id=%s user_id=%s message_len=%s",
            session.id,
            ctx.user_id,
            len(message),
        )

        async def event_generator() -> AsyncIterator[str]:
            full_reply = ""
            started_at = time.perf_counter()
            stream_iter = self.orchestrator.patient_stream_reply(
                session, case, memory_messages, message
            ).__aiter__()
            while True:
                # Enforce the overall budget for the whole stream.
                elapsed = time.perf_counter() - started_at
                total_remaining = settings.llm_stream_total_timeout_seconds - elapsed
                if total_remaining <= 0:
                    logger.warning("chat_stream.total_timeout session_id=%s", session.id)
                    yield self._sse_error("AI 病人回复总耗时超限,请重试或关闭流式模式", "LLM_STREAM_TIMEOUT")
                    return

                # Until the first text arrives the first-token timeout applies,
                # afterwards the regular chat timeout; never exceed the budget.
                timeout = min(
                    total_remaining,
                    settings.llm_stream_first_token_timeout_seconds
                    if not full_reply
                    else settings.llm_chat_timeout_seconds,
                )
                try:
                    chunk = await asyncio.wait_for(stream_iter.__anext__(), timeout=timeout)
                except StopAsyncIteration:
                    # The stream ended without an explicit ``done`` chunk.
                    # Same emptiness rule as the ``chunk.done`` branch below:
                    # whitespace-only output is not a valid reply.
                    if full_reply.strip():
                        runtime_memory.add_message(session.memory_key or "", "patient", full_reply)
                        yield self._sse_done(
                            latency_ms=int((time.perf_counter() - started_at) * 1000),
                            first_token_ms=0,
                            model=None,
                            fallback_used=False,
                        )
                    else:
                        logger.warning("chat_stream.empty_stop session_id=%s", session.id)
                        yield self._sse_error("AI 病人没有返回有效内容,请重试", "LLM_EMPTY_RESPONSE")
                    return
                # asyncio.TimeoutError is distinct from the builtin before
                # Python 3.11; the builtin name would miss it there and the
                # generic handler below would report LLM_STREAM_FAILED instead.
                except asyncio.TimeoutError:
                    timeout_event = (
                        "chat_stream.chunk_timeout" if full_reply else "chat_stream.first_token_timeout"
                    )
                    logger.warning("%s session_id=%s", timeout_event, session.id)
                    yield self._sse_error("AI 病人首段回复超时,请重试或关闭流式模式", "LLM_STREAM_TIMEOUT")
                    return
                except Exception as exc:
                    # Stream boundary: report the failure to the client as an
                    # SSE error event rather than breaking the response.
                    logger.exception(
                        "chat_stream.failed session_id=%s error=%s",
                        session.id,
                        exc.__class__.__name__,
                    )
                    yield self._sse_error("AI 病人回复失败,请检查模型配置或稍后重试", "LLM_STREAM_FAILED")
                    return

                if chunk.done:
                    if not full_reply.strip():
                        logger.warning("chat_stream.empty_done session_id=%s", session.id)
                        yield self._sse_error("AI 病人没有返回有效内容,请重试", "LLM_EMPTY_RESPONSE")
                        return
                    runtime_memory.add_message(session.memory_key or "", "patient", full_reply)
                    logger.info(
                        "chat_stream.done session_id=%s chars=%s latency_ms=%s model=%s",
                        session.id,
                        len(full_reply),
                        chunk.total_latency_ms,
                        chunk.model,
                    )
                    yield self._sse_done(
                        latency_ms=chunk.total_latency_ms or int((time.perf_counter() - started_at) * 1000),
                        first_token_ms=chunk.first_token_ms or 0,
                        model=chunk.model,
                        fallback_used=chunk.fallback_used,
                    )
                    return

                if chunk.delta:
                    full_reply += chunk.delta
                    logger.debug(
                        "chat_stream.delta session_id=%s delta_len=%s", session.id, len(chunk.delta)
                    )
                    delta_payload = json.dumps({"delta": chunk.delta}, ensure_ascii=False)
                    yield f"event: message_delta\ndata: {delta_payload}\n\n"

        # Audited once per request, before the stream is consumed.
        self._log_session_event(ctx, "session.chat.stream", session)
        return event_generator()

    def _sse_done(self, latency_ms: int, first_token_ms: int, model: str | None, fallback_used: bool) -> str:
        """Build the ``message_done`` SSE event carrying timing and model info."""
        done_payload = json.dumps(
            {
                "latency_ms": latency_ms,
                "first_token_ms": first_token_ms,
                "model": model,
                "fallback_used": fallback_used,
            },
            ensure_ascii=False,
        )
        return f"event: message_done\ndata: {done_payload}\n\n"

    def _sse_error(self, message: str, code: str = "LLM_STREAM_TIMEOUT") -> str:
        """Build the ``error`` SSE event with an error code and a readable message."""
        payload = json.dumps({"code": code, "message": message}, ensure_ascii=False)
        return f"event: error\ndata: {payload}\n\n"

    async def generate_hints(self, ctx: UserContext, session_id: int, payload: HintRequest) -> HintResponse:
        """Generate practice hints from the session context, its orders and the case.

        Raises:
            AppError: ``SESSION_NOT_FOUND`` (404), ``SESSION_STATUS_INVALID``
                (400) when the session is not in ``practice`` mode or not in
                ``inquiry`` status, or ``CASE_NOT_FOUND`` (404).
        """
        session = self._get_session(session_id, ctx.user_id)
        if session.mode != "practice":
            raise AppError("SESSION_STATUS_INVALID", "hints are only available in practice mode", 400)
        if session.status != "inquiry":
            raise AppError("SESSION_STATUS_INVALID", "hints are only available during inquiry", 400)
        case = self._get_active_case(session.case_id)

        memory_messages = runtime_memory.get_messages(session.memory_key)
        result = await self.orchestrator.generate_hints(
            session, case, memory_messages, session.orders, payload.last_user_message
        )
        self._log_session_event(ctx, "session.hints", session)
        return HintResponse(**result)

    async def stream_hints(self, ctx: UserContext, session_id: int, payload: HintRequest) -> AsyncIterator[str]:
        """Condense the structured hints into one sentence and stream it as SSE.

        Never raises for hint-generation failures: they are converted into a
        generator that yields a single ``error`` event. On success the
        generator yields ``hint_delta`` events followed by ``hint_done``.
        """
        started_at = time.perf_counter()
        try:
            hint_result = await self.generate_hints(ctx, session_id, payload)
            sentence = self._build_hint_sentence(hint_result)
        except AppError as exc:
            # Copy the fields now: ``exc`` is unbound once the except block ends.
            error_message = exc.message
            error_code = exc.code

            async def app_error_generator() -> AsyncIterator[str]:
                yield self._sse_error(error_message, error_code)

            return app_error_generator()
        except Exception:
            logger.exception("hint_stream.failed session_id=%s", session_id)

            async def error_generator() -> AsyncIterator[str]:
                yield self._sse_error("练习提示生成失败,请稍后重试", "HINT_STREAM_FAILED")

            return error_generator()

        async def event_generator() -> AsyncIterator[str]:
            if not sentence:
                yield self._sse_error("当前没有生成有效提示,请继续问诊后再试", "HINT_EMPTY")
                return
            for chunk in self._chunk_text(sentence, size=12):
                payload_text = json.dumps({"delta": chunk}, ensure_ascii=False)
                yield f"event: hint_delta\ndata: {payload_text}\n\n"
                # Give the event loop a turn between chunks.
                await asyncio.sleep(0)
            done_payload = json.dumps(
                {"latency_ms": int((time.perf_counter() - started_at) * 1000)},
                ensure_ascii=False,
            )
            yield f"event: hint_done\ndata: {done_payload}\n\n"

        return event_generator()

    def _build_hint_sentence(self, hint_result: HintResponse) -> str:
        """Condense a structured hint result into one sentence.

        Uses up to three missing dimensions, then the first next question (or,
        failing that, the first hint), then the first recommended order.
        Returns an empty string when none of these are present.
        """
        parts: list[str] = []
        if hint_result.missing_dimensions:
            parts.append("当前可补充" + "、".join(hint_result.missing_dimensions[:3]))
        if hint_result.next_questions:
            parts.append(f"下一步可问:{hint_result.next_questions[0]}")
        elif hint_result.hints:
            parts.append(hint_result.hints[0])
        if hint_result.recommended_orders:
            order = hint_result.recommended_orders[0]
            item_code = order.get("item_code") or order.get("item_name") or "关键检查"
            reason = order.get("reason") or "用于完善病情判断"
            parts.append(f"可考虑申请{item_code},{reason}")
        return ";".join(parts) + ("。" if parts else "")

    def _chunk_text(self, text: str, size: int) -> list[str]:
        """Split ``text`` into consecutive pieces of at most ``size`` characters."""
        return [text[index : index + size] for index in range(0, len(text), size)]

    def complete_inquiry(self, ctx: UserContext, session_id: int) -> SessionStatusResponse:
        """Finish the inquiry stage and move the session to ``diagnosis``.

        Raises:
            AppError: ``SESSION_NOT_FOUND`` (404), ``SESSION_STATUS_INVALID``
                (400) when the session is not in ``inquiry`` status, or
                ``INQUIRY_REQUIRED`` (400) when the runtime memory holds no
                doctor message.
        """
        session = self._get_session(session_id, ctx.user_id)
        if session.status != "inquiry":
            raise AppError("SESSION_STATUS_INVALID", "only inquiry status can be completed", 400)
        if not runtime_memory.has_doctor_message(session.memory_key):
            raise AppError("INQUIRY_REQUIRED", "at least one doctor message is required", 400)
        self.session_repo.update_status(session, "diagnosis")
        self._log_session_event(ctx, "session.complete_inquiry", session)
        return SessionStatusResponse(session_id=session.id, status=session.status)

    def submit_diagnosis(self, ctx: UserContext, session_id: int, payload: SubmitDiagnosisRequest) -> SubmitDiagnosisResponse:
        """Store the diagnosis and move the session to ``treatment``.

        Saves the primary diagnosis, differential diagnoses and diagnosis
        basis on the session's submission, creating the submission when none
        exists yet.

        Raises:
            AppError: ``SESSION_NOT_FOUND`` (404) or ``SESSION_STATUS_INVALID``
                (400) when the session is not in ``diagnosis`` status.
        """
        session = self._get_session(session_id, ctx.user_id)
        if session.status != "diagnosis":
            raise AppError("SESSION_STATUS_INVALID", "diagnosis submit is not allowed", 400)
        submission = self.session_repo.get_submission(session.id) or SessionSubmission(
            session_id=session.id, user_id=ctx.user_id
        )
        submission.primary_diagnosis = payload.primary_diagnosis
        submission.differential_diagnoses = payload.differential_diagnoses
        submission.diagnosis_basis = payload.diagnosis_basis
        submission.diagnosis_submitted_at = datetime.utcnow()
        self.session_repo.upsert_submission(submission)
        self.session_repo.update_status(session, "treatment")
        self._log_session_event(ctx, "session.submit_diagnosis", session)
        return SubmitDiagnosisResponse(status=session.status)

    def submit_treatment(self, ctx: UserContext, session_id: int, payload: SubmitTreatmentRequest) -> SubmitTreatmentResponse:
        """Store the treatment plan and move the session to ``evaluating``.

        Saves the treatment principle, measures, risk plan, communication and
        follow-up on the existing submission.

        Raises:
            AppError: ``SESSION_NOT_FOUND`` (404), ``SESSION_STATUS_INVALID``
                (400) when the session is not in ``treatment`` status, or
                ``DIAGNOSIS_REQUIRED`` (400) when no submission exists yet.
        """
        session = self._get_session(session_id, ctx.user_id)
        if session.status != "treatment":
            raise AppError("SESSION_STATUS_INVALID", "treatment submit is not allowed", 400)
        submission = self.session_repo.get_submission(session.id)
        if not submission:
            raise AppError("DIAGNOSIS_REQUIRED", "diagnosis submission is required", 400)
        submission.treatment_principle = payload.treatment_principle
        submission.treatment_measures = payload.treatment_measures
        submission.risk_plan = payload.risk_plan
        submission.communication = payload.communication
        submission.follow_up = payload.follow_up
        submission.treatment_submitted_at = datetime.utcnow()
        self.session_repo.upsert_submission(submission)
        self.session_repo.update_status(session, "evaluating")
        self._log_session_event(ctx, "session.submit_treatment", session)
        return SubmitTreatmentResponse(status=session.status)

    def _get_session(self, session_id: int, user_id: str) -> TrainingSession:
        """Return the session looked up by ``session_id`` and ``user_id``.

        Raises:
            AppError: ``SESSION_NOT_FOUND`` (404) when the lookup returns
                nothing.
        """
        session = self.session_repo.get_owned_session(session_id, user_id)
        if not session:
            raise AppError("SESSION_NOT_FOUND", "session not found or not owned by current user", 404)
        return session

    def _get_active_case(self, case_id):
        """Return the case from ``case_repo.get_active_case``.

        Raises:
            AppError: ``CASE_NOT_FOUND`` (404) when the lookup returns nothing.
        """
        case = self.case_repo.get_active_case(case_id)
        if not case:
            raise AppError("CASE_NOT_FOUND", "case not found or inactive", 404)
        return case

    def _log_session_event(self, ctx: UserContext, action: str, session: TrainingSession) -> None:
        """Write an audit entry for ``action`` against ``session``."""
        self.audit.log(ctx, action, "training_session", str(session.id), session.id)