# File-viewer header (not part of the module): fastapi/app/api/sessions.py
# Timestamp: 2026-06-04 10:55:23 +08:00 | 163 lines | 5.3 KiB | Python

from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
from starlette.responses import StreamingResponse
from app.core.response import ApiResponse, ok
from app.core.user_context import UserContext, get_user_context
from app.db.session import get_db
from app.schemas.evaluation import CreateEvaluationRequest, EvaluationResponse
from app.schemas.session import (
ChatRequest,
ChatResponse,
CreateOrderRequest,
CreateOrderResponse,
CreateSessionRequest,
CreateSessionResponse,
OrderItemsResponse,
SessionStatusResponse,
SubmitDiagnosisRequest,
SubmitDiagnosisResponse,
SubmitTreatmentRequest,
SubmitTreatmentResponse,
HintRequest,
HintResponse,
)
from app.services.evaluation_service import EvaluationService
from app.services.order_service import OrderService
from app.services.session_service import SessionService
# Router that collects the training-session endpoints declared in this module.
# NOTE(review): the route paths here are relative ("" and "/{session_id}/..."),
# so the URL prefix is presumably applied where this router is included — confirm.
router = APIRouter()
@router.post("", response_model=ApiResponse[CreateSessionResponse])
def create_session(
    payload: CreateSessionRequest,
    ctx: UserContext = Depends(get_user_context),
    db: Session = Depends(get_db),
):
    """Create a training session: initialise a training session isolated by user_id, plus its short-term memory."""
    service = SessionService(db)
    created = service.create_session(ctx, payload)
    # Commit only once the service call has returned without raising.
    db.commit()
    return ok(created)
@router.post("/{session_id}/chat", response_model=ApiResponse[ChatResponse])
async def chat(
    session_id: int,
    payload: ChatRequest,
    ctx: UserContext = Depends(get_user_context),
    db: Session = Depends(get_db),
):
    """Non-streaming inquiry: send the doctor's question and return the AI patient's reply."""
    service = SessionService(db)
    # Only the message text of the request body is forwarded to the service.
    reply = await service.chat(ctx, session_id, payload.message)
    db.commit()
    return ok(reply)
@router.post("/{session_id}/chat/stream", response_class=StreamingResponse)
async def chat_stream(
    session_id: int,
    payload: ChatRequest,
    ctx: UserContext = Depends(get_user_context),
    db: Session = Depends(get_db),
):
    """Streaming inquiry: return the AI patient's incremental reply in SSE format."""
    service = SessionService(db)
    event_source = await service.stream_chat(ctx, session_id, payload.message)
    # NOTE(review): this commit runs before the response body is streamed, so
    # anything written while the stream is being consumed would not be covered
    # by it — confirm against stream_chat.
    db.commit()
    # Headers sent with the event stream: disable caching, keep the connection
    # open, and ask proxies that honour X-Accel-Buffering not to buffer.
    sse_headers = {
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        "X-Accel-Buffering": "no",
    }
    return StreamingResponse(
        event_source,
        media_type="text/event-stream",
        headers=sse_headers,
    )
@router.get("/{session_id}/order-items", response_model=ApiResponse[OrderItemsResponse])
def list_order_items(
    session_id: int,
    ctx: UserContext = Depends(get_user_context),
    db: Session = Depends(get_db),
):
    """List orderable items: return the items that can be requested for the current case, without the exam results."""
    # No commit is issued in this handler.
    service = OrderService(db)
    items = service.list_order_items(session_id, ctx.user_id)
    return ok(items)
@router.post("/{session_id}/orders", response_model=ApiResponse[CreateOrderResponse])
def create_order(
    session_id: int,
    payload: CreateOrderRequest,
    ctx: UserContext = Depends(get_user_context),
    db: Session = Depends(get_db),
):
    """Request an examination/lab test: read the structured result from the database and return it."""
    service = OrderService(db)
    # The order service is given the caller's user id rather than the whole context.
    order = service.create_order(session_id, ctx.user_id, payload.item_code)
    db.commit()
    return ok(order)
@router.post("/{session_id}/complete-inquiry", response_model=ApiResponse[SessionStatusResponse])
def complete_inquiry(
    session_id: int,
    ctx: UserContext = Depends(get_user_context),
    db: Session = Depends(get_db),
):
    """Complete the inquiry: move from the inquiry stage to the diagnosis stage."""
    service = SessionService(db)
    status = service.complete_inquiry(ctx, session_id)
    # Commit after the service call has returned.
    db.commit()
    return ok(status)
@router.post("/{session_id}/hints", response_model=ApiResponse[HintResponse])
async def generate_hints(
    session_id: int,
    payload: HintRequest,
    ctx: UserContext = Depends(get_user_context),
    db: Session = Depends(get_db),
):
    """Beginner-mode hints: generate the missing dimensions and suggested next questions from the current inquiry context."""
    service = SessionService(db)
    # The full request payload is handed to the service unchanged.
    hints = await service.generate_hints(ctx, session_id, payload)
    db.commit()
    return ok(hints)
@router.post("/{session_id}/diagnosis", response_model=ApiResponse[SubmitDiagnosisResponse])
def submit_diagnosis(
    session_id: int,
    payload: SubmitDiagnosisRequest,
    ctx: UserContext = Depends(get_user_context),
    db: Session = Depends(get_db),
):
    """Submit a diagnosis: save the primary diagnosis, differential diagnoses and diagnostic basis."""
    service = SessionService(db)
    saved = service.submit_diagnosis(ctx, session_id, payload)
    # Commit after the service call has returned.
    db.commit()
    return ok(saved)
@router.post("/{session_id}/treatment", response_model=ApiResponse[SubmitTreatmentResponse])
def submit_treatment(
    session_id: int,
    payload: SubmitTreatmentRequest,
    ctx: UserContext = Depends(get_user_context),
    db: Session = Depends(get_db),
):
    """Submit a treatment plan: save the treatment, risk, communication and follow-up content."""
    service = SessionService(db)
    saved = service.submit_treatment(ctx, session_id, payload)
    # Commit after the service call has returned.
    db.commit()
    return ok(saved)
@router.post("/{session_id}/evaluation", response_model=ApiResponse[EvaluationResponse])
async def create_evaluation(
    session_id: int,
    payload: CreateEvaluationRequest,
    ctx: UserContext = Depends(get_user_context),
    db: Session = Depends(get_db),
):
    """Generate an evaluation report: retrieve guidelines and call the Scoring Agent to produce a structured evaluation."""
    service = EvaluationService(db)
    # The evaluation is awaited to completion before anything is committed.
    report = await service.create_evaluation(ctx, session_id, payload)
    db.commit()
    return ok(report)