279 lines
14 KiB
Python
279 lines
14 KiB
Python
|
|
import asyncio
|
||
|
|
import time
|
||
|
|
import uuid
|
||
|
|
import json
|
||
|
|
import logging
|
||
|
|
from collections.abc import AsyncIterator
|
||
|
|
from datetime import datetime
|
||
|
|
|
||
|
|
from sqlalchemy.orm import Session
|
||
|
|
from app.agents.orchestrator import MedicalConsultationOrchestrator
|
||
|
|
from app.core.config import settings
|
||
|
|
from app.core.context import UserContext
|
||
|
|
from app.core.exceptions import AppError
|
||
|
|
from app.models.training import SessionSubmission, TrainingSession
|
||
|
|
from app.repositories.case_repository import CaseRepository
|
||
|
|
from app.repositories.session_repository import SessionRepository
|
||
|
|
from app.schemas.session import (
|
||
|
|
ChatResponse,
|
||
|
|
CreateSessionRequest,
|
||
|
|
CreateSessionResponse,
|
||
|
|
SessionStatusResponse,
|
||
|
|
SubmitDiagnosisRequest,
|
||
|
|
SubmitDiagnosisResponse,
|
||
|
|
SubmitTreatmentRequest,
|
||
|
|
SubmitTreatmentResponse,
|
||
|
|
HintRequest,
|
||
|
|
HintResponse,
|
||
|
|
)
|
||
|
|
from app.services.audit_service import AuditService
|
||
|
|
from app.services.runtime_memory import runtime_memory
|
||
|
|
|
||
|
|
# Module-level logger; the streaming chat path reports its start, completion,
# timeout and failure events through it.
logger: logging.Logger = logging.getLogger(__name__)
|
||
|
|
|
||
|
|
|
||
|
|
class SessionService:
    """Session service: creates sessions, runs the inquiry chat, drives the
    multi-stage status flow and stores diagnosis/treatment submissions.

    Status transitions performed by this class:
    ``inquiry`` -> ``diagnosis`` (complete_inquiry) -> ``treatment``
    (submit_diagnosis) -> ``evaluating`` (submit_treatment).
    """

    def __init__(self, db: Session) -> None:
        """Bind the repositories, audit service and agent orchestrator to ``db``."""
        self.db = db
        self.case_repo = CaseRepository(db)
        self.session_repo = SessionRepository(db)
        self.audit = AuditService(db)
        self.orchestrator = MedicalConsultationOrchestrator()

    def create_session(self, ctx: UserContext, payload: CreateSessionRequest) -> CreateSessionResponse:
        """Create a session for an active case and initialise its short-term memory.

        The session starts in ``inquiry`` status and its runtime memory is seeded
        with the case's patient opening line (or a default opening).

        Raises:
            AppError: ``CASE_NOT_FOUND`` (404) if the case is missing or inactive.
        """
        case = self._get_active_case(payload.case_id)

        # Take a single timestamp so the session code and ``started_at`` agree.
        now = datetime.utcnow()
        session_code = f"sess_{now.strftime('%Y%m%d%H%M%S')}_{uuid.uuid4().hex[:8]}"
        memory_key = f"mem:{session_code}"
        session = self.session_repo.create_session(
            TrainingSession(
                session_code=session_code,
                user_id=ctx.user_id,
                tenant_id=ctx.tenant_id,
                class_id=ctx.class_id,
                entry_scene=ctx.entry_scene,
                case_id=case.id,
                training_type=payload.training_type,
                mode=payload.mode,
                score_type=payload.score_type,
                status="inquiry",
                started_at=now,
                memory_key=memory_key,
                metadata_={"source": "demo"},
            )
        )
        patient_opening = case.patient_opening or "家长:医生,孩子这几天不舒服,想请您看看。"
        runtime_memory.create(memory_key, patient_opening)
        self.audit.log(ctx, "session.create", "training_session", str(session.id), session.id)
        return CreateSessionResponse(
            session_id=session.id,
            session_code=session.session_code,
            status=session.status,
            patient_opening=patient_opening,
        )

    async def chat(self, ctx: UserContext, session_id: int, message: str) -> ChatResponse:
        """Inquiry chat: send case context, short-term memory and the doctor's
        message to the patient agent and return its reply.

        The doctor message is written to memory before the model call, so it
        remains there even if the call times out.

        Raises:
            AppError: ``SESSION_NOT_FOUND`` / ``SESSION_STATUS_INVALID`` /
                ``CASE_NOT_FOUND`` from the pre-checks, or ``LLM_CALL_TIMEOUT``
                (504) when the reply exceeds ``settings.llm_chat_timeout_seconds``.
        """
        session, case = self._get_chat_context(ctx, session_id)

        start = time.perf_counter()
        # Memory is read before the new doctor message is appended; the current
        # message is passed to the agent separately.
        memory_messages = runtime_memory.get_messages(session.memory_key)
        runtime_memory.add_message(session.memory_key or "", "doctor", message)
        try:
            response = await asyncio.wait_for(
                self.orchestrator.patient_reply(session, case, memory_messages, message),
                timeout=settings.llm_chat_timeout_seconds,
            )
        except (TimeoutError, asyncio.TimeoutError) as exc:
            # Before Python 3.11 ``asyncio.wait_for`` raises ``asyncio.TimeoutError``,
            # which is not the builtin ``TimeoutError``; catch both.
            raise AppError("LLM_CALL_TIMEOUT", "AI 病人回复超时,请稍后重试或切换为普通问诊", 504) from exc
        runtime_memory.add_message(session.memory_key or "", "patient", response.content)
        self.audit.log(ctx, "session.chat", "training_session", str(session.id), session.id)
        return ChatResponse(
            reply=response.content,
            latency_ms=response.latency_ms or int((time.perf_counter() - start) * 1000),
            model=response.model,
            # Guard against a missing model name instead of raising AttributeError.
            fallback_used=(response.model or "").startswith("mock-fallback"),
        )

    async def stream_chat(self, ctx: UserContext, session_id: int, message: str) -> AsyncIterator[str]:
        """Streaming inquiry: return an async iterator of SSE-formatted patient replies.

        Validation errors are raised when this coroutine is awaited, before any
        event is produced. The returned iterator yields zero or more
        ``message_delta`` events followed by exactly one ``message_done`` or
        ``error`` event.

        Timeouts: the first chunk may take up to
        ``settings.llm_stream_first_token_timeout_seconds``, each later chunk up
        to ``settings.llm_chat_timeout_seconds``, and the whole stream is capped
        by ``settings.llm_stream_total_timeout_seconds``.

        Raises:
            AppError: ``SESSION_NOT_FOUND`` / ``SESSION_STATUS_INVALID`` /
                ``CASE_NOT_FOUND`` from the pre-checks.
        """
        session, case = self._get_chat_context(ctx, session_id)

        memory_messages = runtime_memory.get_messages(session.memory_key)
        runtime_memory.add_message(session.memory_key or "", "doctor", message)
        logger.info(
            "chat_stream.start session_id=%s user_id=%s message_len=%s",
            session.id,
            ctx.user_id,
            len(message),
        )

        async def event_generator() -> AsyncIterator[str]:
            full_reply = ""
            first_token_ms: int | None = None
            started_at = time.perf_counter()

            def elapsed_ms() -> int:
                return int((time.perf_counter() - started_at) * 1000)

            stream_iter = self.orchestrator.patient_stream_reply(session, case, memory_messages, message).__aiter__()
            try:
                while True:
                    total_remaining = settings.llm_stream_total_timeout_seconds - (time.perf_counter() - started_at)
                    if total_remaining <= 0:
                        logger.warning("chat_stream.total_timeout session_id=%s", session.id)
                        yield self._sse_error("AI 病人回复总耗时超限,请重试或关闭流式模式", "LLM_STREAM_TIMEOUT")
                        return
                    chunk_limit = (
                        settings.llm_stream_first_token_timeout_seconds
                        if not full_reply
                        else settings.llm_chat_timeout_seconds
                    )
                    timeout = min(total_remaining, chunk_limit)
                    try:
                        chunk = await asyncio.wait_for(stream_iter.__anext__(), timeout=timeout)
                    except StopAsyncIteration:
                        # Upstream ended without a ``done`` chunk; apply the same
                        # empty-content rule as the ``done`` branch below.
                        if not full_reply.strip():
                            logger.warning("chat_stream.empty_stop session_id=%s", session.id)
                            yield self._sse_error("AI 病人没有返回有效内容,请重试", "LLM_EMPTY_RESPONSE")
                            return
                        runtime_memory.add_message(session.memory_key or "", "patient", full_reply)
                        latency_ms = elapsed_ms()
                        logger.info(
                            "chat_stream.done session_id=%s chars=%s latency_ms=%s model=%s",
                            session.id,
                            len(full_reply),
                            latency_ms,
                            None,
                        )
                        yield self._sse_done(
                            latency_ms=latency_ms,
                            first_token_ms=first_token_ms or 0,
                            model=None,
                            fallback_used=False,
                        )
                        return
                    except (TimeoutError, asyncio.TimeoutError):
                        if chunk_limit >= total_remaining:
                            # The wait was clipped to the remaining total budget, so
                            # the overall limit, not the per-chunk one, expired.
                            logger.warning("chat_stream.total_timeout session_id=%s", session.id)
                            yield self._sse_error("AI 病人回复总耗时超限,请重试或关闭流式模式", "LLM_STREAM_TIMEOUT")
                        else:
                            stage = "first_token" if not full_reply else "chunk"
                            logger.warning("chat_stream.%s_timeout session_id=%s", stage, session.id)
                            yield self._sse_error("AI 病人首段回复超时,请重试或关闭流式模式", "LLM_STREAM_TIMEOUT")
                        return
                    except Exception as exc:
                        logger.exception("chat_stream.failed session_id=%s error=%s", session.id, exc.__class__.__name__)
                        yield self._sse_error("AI 病人回复失败,请检查模型配置或稍后重试", "LLM_STREAM_FAILED")
                        return

                    if chunk.done:
                        if not full_reply.strip():
                            logger.warning("chat_stream.empty_done session_id=%s", session.id)
                            yield self._sse_error("AI 病人没有返回有效内容,请重试", "LLM_EMPTY_RESPONSE")
                            return
                        runtime_memory.add_message(session.memory_key or "", "patient", full_reply)
                        logger.info(
                            "chat_stream.done session_id=%s chars=%s latency_ms=%s model=%s",
                            session.id,
                            len(full_reply),
                            chunk.total_latency_ms,
                            chunk.model,
                        )
                        yield self._sse_done(
                            latency_ms=chunk.total_latency_ms or elapsed_ms(),
                            first_token_ms=chunk.first_token_ms or first_token_ms or 0,
                            model=chunk.model,
                            fallback_used=chunk.fallback_used,
                        )
                        return

                    if chunk.delta:
                        if first_token_ms is None:
                            first_token_ms = elapsed_ms()
                        full_reply += chunk.delta
                        logger.debug("chat_stream.delta session_id=%s delta_len=%s", session.id, len(chunk.delta))
                        delta_payload = json.dumps({"delta": chunk.delta}, ensure_ascii=False)
                        yield f"event: message_delta\ndata: {delta_payload}\n\n"
            finally:
                # Release the upstream model stream on every exit path (done,
                # timeout, error, or the client abandoning this generator).
                await self._close_stream(stream_iter, session.id)

        # Audited when the stream is created, before any chunk is produced.
        self.audit.log(ctx, "session.chat.stream", "training_session", str(session.id), session.id)
        return event_generator()

    @staticmethod
    async def _close_stream(stream_iter: object, session_id: int) -> None:
        """Best-effort ``aclose()`` of an upstream async iterator; no-op if it has none."""
        aclose = getattr(stream_iter, "aclose", None)
        if aclose is None:
            return
        try:
            await aclose()
        except Exception:
            # Cleanup must never mask the event already sent to the client.
            logger.debug("chat_stream.close_failed session_id=%s", session_id, exc_info=True)

    def _sse_done(self, latency_ms: int, first_token_ms: int, model: str | None, fallback_used: bool) -> str:
        """Build the SSE ``message_done`` event carrying latency and model status."""
        done_payload = json.dumps(
            {
                "latency_ms": latency_ms,
                "first_token_ms": first_token_ms,
                "model": model,
                "fallback_used": fallback_used,
            },
            ensure_ascii=False,
        )
        return f"event: message_done\ndata: {done_payload}\n\n"

    def _sse_error(self, message: str, code: str = "LLM_STREAM_TIMEOUT") -> str:
        """Build the SSE ``error`` event so the client can leave its pending state
        and show a user-readable message."""
        payload = json.dumps({"code": code, "message": message}, ensure_ascii=False)
        return f"event: error\ndata: {payload}\n\n"

    async def generate_hints(self, ctx: UserContext, session_id: int, payload: HintRequest) -> HintResponse:
        """Beginner hints: ask the orchestrator for reminders based on the session
        memory, the session's orders and the case.

        Only allowed for ``practice`` mode sessions that are still in ``inquiry``.

        Raises:
            AppError: ``SESSION_NOT_FOUND``, ``SESSION_STATUS_INVALID`` or
                ``CASE_NOT_FOUND``.
        """
        session = self._get_session(session_id, ctx.user_id)
        if session.mode != "practice":
            raise AppError("SESSION_STATUS_INVALID", "hints are only available in practice mode", 400)
        if session.status != "inquiry":
            raise AppError("SESSION_STATUS_INVALID", "hints are only available during inquiry", 400)
        case = self._get_active_case(session.case_id)
        memory_messages = runtime_memory.get_messages(session.memory_key)
        result = await self.orchestrator.generate_hints(
            session, case, memory_messages, session.orders, payload.last_user_message
        )
        self.audit.log(ctx, "session.hints", "training_session", str(session.id), session.id)
        return HintResponse(**result)

    def complete_inquiry(self, ctx: UserContext, session_id: int) -> SessionStatusResponse:
        """Finish the inquiry stage and move the session to ``diagnosis``.

        Raises:
            AppError: ``SESSION_STATUS_INVALID`` if not in ``inquiry``, or
                ``INQUIRY_REQUIRED`` if memory holds no doctor message.
        """
        session = self._get_session(session_id, ctx.user_id)
        if session.status != "inquiry":
            raise AppError("SESSION_STATUS_INVALID", "only inquiry status can be completed", 400)
        if not runtime_memory.has_doctor_message(session.memory_key):
            raise AppError("INQUIRY_REQUIRED", "at least one doctor message is required", 400)
        self.session_repo.update_status(session, "diagnosis")
        self.audit.log(ctx, "session.complete_inquiry", "training_session", str(session.id), session.id)
        return SessionStatusResponse(session_id=session.id, status=session.status)

    def submit_diagnosis(self, ctx: UserContext, session_id: int, payload: SubmitDiagnosisRequest) -> SubmitDiagnosisResponse:
        """Save primary/differential diagnoses and their basis, then move to ``treatment``.

        Reuses the existing submission row if one exists, otherwise creates one.

        Raises:
            AppError: ``SESSION_STATUS_INVALID`` if the session is not in ``diagnosis``.
        """
        session = self._get_session(session_id, ctx.user_id)
        if session.status != "diagnosis":
            raise AppError("SESSION_STATUS_INVALID", "diagnosis submit is not allowed", 400)
        submission = self.session_repo.get_submission(session.id) or SessionSubmission(
            session_id=session.id, user_id=ctx.user_id
        )
        submission.primary_diagnosis = payload.primary_diagnosis
        submission.differential_diagnoses = payload.differential_diagnoses
        submission.diagnosis_basis = payload.diagnosis_basis
        submission.diagnosis_submitted_at = datetime.utcnow()
        self.session_repo.upsert_submission(submission)
        self.session_repo.update_status(session, "treatment")
        self.audit.log(ctx, "session.submit_diagnosis", "training_session", str(session.id), session.id)
        return SubmitDiagnosisResponse(status=session.status)

    def submit_treatment(self, ctx: UserContext, session_id: int, payload: SubmitTreatmentRequest) -> SubmitTreatmentResponse:
        """Save treatment plan, risk plan, communication and follow-up, then move
        the session to ``evaluating``.

        Raises:
            AppError: ``SESSION_STATUS_INVALID`` if not in ``treatment``, or
                ``DIAGNOSIS_REQUIRED`` if no submission row exists yet.
        """
        session = self._get_session(session_id, ctx.user_id)
        if session.status != "treatment":
            raise AppError("SESSION_STATUS_INVALID", "treatment submit is not allowed", 400)
        submission = self.session_repo.get_submission(session.id)
        if not submission:
            raise AppError("DIAGNOSIS_REQUIRED", "diagnosis submission is required", 400)
        submission.treatment_principle = payload.treatment_principle
        submission.treatment_measures = payload.treatment_measures
        submission.risk_plan = payload.risk_plan
        submission.communication = payload.communication
        submission.follow_up = payload.follow_up
        submission.treatment_submitted_at = datetime.utcnow()
        self.session_repo.upsert_submission(submission)
        self.session_repo.update_status(session, "evaluating")
        self.audit.log(ctx, "session.submit_treatment", "training_session", str(session.id), session.id)
        return SubmitTreatmentResponse(status=session.status)

    def _get_chat_context(self, ctx: UserContext, session_id: int):
        """Return ``(session, case)`` for a chat turn: owned session in ``inquiry``
        status plus its active case.

        Raises:
            AppError: ``SESSION_NOT_FOUND``, ``SESSION_STATUS_INVALID`` or ``CASE_NOT_FOUND``.
        """
        session = self._get_session(session_id, ctx.user_id)
        if session.status != "inquiry":
            raise AppError("SESSION_STATUS_INVALID", "chat is only allowed in inquiry status", 400)
        return session, self._get_active_case(session.case_id)

    def _get_active_case(self, case_id):
        """Fetch an active case or raise ``CASE_NOT_FOUND`` (404)."""
        case = self.case_repo.get_active_case(case_id)
        if not case:
            raise AppError("CASE_NOT_FOUND", "case not found or inactive", 404)
        return case

    def _get_session(self, session_id: int, user_id: str) -> TrainingSession:
        """Fetch a session only if it belongs to ``user_id`` (per-user isolation).

        Raises:
            AppError: ``SESSION_NOT_FOUND`` (404) if missing or owned by someone else.
        """
        session = self.session_repo.get_owned_session(session_id, user_id)
        if not session:
            raise AppError("SESSION_NOT_FOUND", "session not found or not owned by current user", 404)
        return session
|