add training configuration APIs
This commit is contained in:
+2
-1
@@ -37,7 +37,8 @@ reports/
|
||||
uploads/
|
||||
|
||||
# Demo-only or temporary files
|
||||
docs/
|
||||
docs/*
|
||||
!docs/03_api_design.md
|
||||
demo_frontend/
|
||||
scripts/check_mysql_demo.ps1
|
||||
scripts/init_mysql_demo.ps1
|
||||
|
||||
@@ -20,7 +20,7 @@ class MedicalConsultationOrchestrator:
|
||||
|
||||
async def patient_reply(self, session: TrainingSession, case: CaseBase, memory_messages: list[dict], message: str) -> LLMResponse:
|
||||
"""问诊编排:调用 Patient Agent 生成 AI 病人回复。"""
|
||||
return await self.patient_agent.reply(case, memory_messages, message, session.mode)
|
||||
return await self.patient_agent.reply(case, memory_messages, message, session.mode, self._patient_config(session))
|
||||
|
||||
async def patient_stream_reply(
|
||||
self,
|
||||
@@ -30,7 +30,7 @@ class MedicalConsultationOrchestrator:
|
||||
message: str,
|
||||
) -> AsyncIterator[LLMStreamChunk]:
|
||||
"""流式问诊编排:调用 Patient Agent 并返回流式片段。"""
|
||||
async for chunk in self.patient_agent.stream_reply(case, memory_messages, message, session.mode):
|
||||
async for chunk in self.patient_agent.stream_reply(case, memory_messages, message, session.mode, self._patient_config(session)):
|
||||
yield chunk
|
||||
|
||||
async def evaluate(
|
||||
@@ -67,3 +67,9 @@ class MedicalConsultationOrchestrator:
|
||||
) -> dict:
|
||||
"""新手提示编排:基于当前会话上下文生成轻量训练提醒。"""
|
||||
return await self.hint_agent.generate(session, case, memory_messages, orders, last_user_message)
|
||||
|
||||
def _patient_config(self, session: TrainingSession) -> dict | None:
|
||||
"""病人配置:从会话 metadata 读取训练页初始化配置,传递给 Patient Agent。"""
|
||||
metadata = session.metadata_ or {}
|
||||
patient_config = metadata.get("patient_config") if isinstance(metadata, dict) else None
|
||||
return patient_config if isinstance(patient_config, dict) else None
|
||||
|
||||
@@ -11,9 +11,16 @@ class PatientAgent:
|
||||
def __init__(self, llm: DeepSeekClient | None = None) -> None:
|
||||
self.llm = llm or DeepSeekClient()
|
||||
|
||||
async def reply(self, case: CaseBase, memory_messages: list[dict], user_message: str, mode: str) -> LLMResponse:
|
||||
async def reply(
|
||||
self,
|
||||
case: CaseBase,
|
||||
memory_messages: list[dict],
|
||||
user_message: str,
|
||||
mode: str,
|
||||
patient_config: dict | None = None,
|
||||
) -> LLMResponse:
|
||||
"""问诊回复:拼接病例上下文、短期记忆和用户输入后调用 Patient Agent。"""
|
||||
messages = self._build_messages(case, memory_messages, user_message, mode)
|
||||
messages = self._build_messages(case, memory_messages, user_message, mode, patient_config)
|
||||
return await self.llm.chat(
|
||||
messages,
|
||||
settings.llm_fast_model,
|
||||
@@ -27,9 +34,10 @@ class PatientAgent:
|
||||
memory_messages: list[dict],
|
||||
user_message: str,
|
||||
mode: str,
|
||||
patient_config: dict | None = None,
|
||||
) -> AsyncIterator[LLMStreamChunk]:
|
||||
"""流式问诊:以 SSE 方式返回 AI 病人增量回复。"""
|
||||
messages = self._build_messages(case, memory_messages, user_message, mode)
|
||||
messages = self._build_messages(case, memory_messages, user_message, mode, patient_config)
|
||||
async for chunk in self.llm.stream_chat(
|
||||
messages,
|
||||
settings.llm_fast_model,
|
||||
@@ -38,10 +46,18 @@ class PatientAgent:
|
||||
):
|
||||
yield chunk
|
||||
|
||||
def _build_messages(self, case: CaseBase, memory_messages: list[dict], user_message: str, mode: str) -> list[dict]:
|
||||
def _build_messages(
|
||||
self,
|
||||
case: CaseBase,
|
||||
memory_messages: list[dict],
|
||||
user_message: str,
|
||||
mode: str,
|
||||
patient_config: dict | None = None,
|
||||
) -> list[dict]:
|
||||
"""提示词拼接:构造 AI 病人的系统提示词和对话历史。"""
|
||||
profile = case.ai_patient_profile or {}
|
||||
hidden_info = case.hidden_patient_info or {}
|
||||
config_rule = self._build_patient_config_rule(patient_config)
|
||||
mode_rule = {
|
||||
"novice": "新手模式:回答清楚,必要时可提示医生继续追问症状、既往史或检查。",
|
||||
"practice": "练习模式:只回答被问到的信息,不主动给诊断建议。",
|
||||
@@ -52,6 +68,7 @@ class PatientAgent:
|
||||
病例主诉:{case.chief_complaint}
|
||||
患者人设:{profile}
|
||||
隐藏信息:{hidden_info}
|
||||
病人初始化配置:{config_rule}
|
||||
回答规则:
|
||||
1. 不主动透露未被问到的隐藏信息。
|
||||
2. 不替医生做诊断,不提供治疗方案。
|
||||
@@ -66,6 +83,21 @@ class PatientAgent:
|
||||
messages.append({"role": "user", "content": user_message})
|
||||
return messages
|
||||
|
||||
def _build_patient_config_rule(self, patient_config: dict | None) -> str:
|
||||
"""配置提示:把训练页初始化配置转成 AI 病人表达约束。"""
|
||||
if not patient_config:
|
||||
return "使用默认门诊、青年、高等教育、平和性格的表达方式。"
|
||||
labels = patient_config.get("labels") if isinstance(patient_config, dict) else None
|
||||
values = labels or (patient_config.get("values") if isinstance(patient_config, dict) else {}) or {}
|
||||
visit_environment = values.get("visit_environment", "门诊")
|
||||
age_group = values.get("age_group", "青年")
|
||||
education_level = values.get("education_level", "高等教育")
|
||||
personality = values.get("personality", "平和")
|
||||
return (
|
||||
f"就诊环境={visit_environment};年龄段={age_group};文化程度={education_level};性格={personality}。"
|
||||
"回答时根据性格调整情绪和配合度,根据文化程度调整表达清晰度,但不得改变病例事实。"
|
||||
)
|
||||
|
||||
def _to_llm_history(self, memory_messages: list[dict]) -> list[dict]:
|
||||
"""历史转换:把业务角色 doctor/patient 转换为 LLM role。"""
|
||||
role_map = {"doctor": "user", "patient": "assistant", "system": "system", "tool": "assistant"}
|
||||
|
||||
+2
-1
@@ -1,11 +1,12 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
from app.api import agent, auth, cases, evaluations, knowledge, llm_test, sessions
|
||||
from app.api import agent, auth, cases, evaluations, knowledge, llm_test, sessions, training_config
|
||||
|
||||
api_router = APIRouter()
|
||||
api_router.include_router(agent.router, tags=["agent"])
|
||||
api_router.include_router(auth.router, prefix="/auth", tags=["auth"])
|
||||
api_router.include_router(cases.router, prefix="/cases", tags=["cases"])
|
||||
api_router.include_router(training_config.router, prefix="/training-config", tags=["training-config"])
|
||||
api_router.include_router(sessions.router, prefix="/sessions", tags=["sessions"])
|
||||
api_router.include_router(evaluations.router, prefix="/evaluations", tags=["evaluations"])
|
||||
api_router.include_router(knowledge.router, prefix="/knowledge", tags=["knowledge"])
|
||||
|
||||
@@ -123,6 +123,73 @@ async def generate_hints(
|
||||
return ok(result)
|
||||
|
||||
|
||||
@router.post("/{session_id}/hints/stream", response_class=StreamingResponse)
|
||||
async def stream_hints(
|
||||
session_id: int,
|
||||
payload: HintRequest,
|
||||
ctx: UserContext = Depends(get_user_context),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""流式练习提示:返回一句话形式的 SSE 提示。"""
|
||||
response = await SessionService(db).stream_hints(ctx, session_id, payload)
|
||||
db.commit()
|
||||
return StreamingResponse(
|
||||
response,
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{session_id}/physical-exams", response_model=ApiResponse[OrderItemsResponse])
|
||||
def list_physical_exam_items(
|
||||
session_id: int,
|
||||
ctx: UserContext = Depends(get_user_context),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""体格检查列表:返回当前病例可申请的体格检查项目。"""
|
||||
return ok(OrderService(db).list_physical_exam_items(session_id, ctx.user_id))
|
||||
|
||||
|
||||
@router.get("/{session_id}/auxiliary-exams", response_model=ApiResponse[OrderItemsResponse])
|
||||
def list_auxiliary_exam_items(
|
||||
session_id: int,
|
||||
ctx: UserContext = Depends(get_user_context),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""辅助检查列表:返回当前病例可申请的辅助检查项目。"""
|
||||
return ok(OrderService(db).list_auxiliary_exam_items(session_id, ctx.user_id))
|
||||
|
||||
|
||||
@router.post("/{session_id}/physical-exams/{item_code}", response_model=ApiResponse[CreateOrderResponse])
|
||||
def create_physical_exam_order(
|
||||
session_id: int,
|
||||
item_code: str,
|
||||
ctx: UserContext = Depends(get_user_context),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""体格检查结果:按项目编码返回数据库固定结果。"""
|
||||
result = OrderService(db).create_physical_exam_order(session_id, ctx.user_id, item_code)
|
||||
db.commit()
|
||||
return ok(result)
|
||||
|
||||
|
||||
@router.post("/{session_id}/auxiliary-exams/{item_code}", response_model=ApiResponse[CreateOrderResponse])
|
||||
def create_auxiliary_exam_order(
|
||||
session_id: int,
|
||||
item_code: str,
|
||||
ctx: UserContext = Depends(get_user_context),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""辅助检查结果:按项目编码返回数据库固定结果。"""
|
||||
result = OrderService(db).create_auxiliary_exam_order(session_id, ctx.user_id, item_code)
|
||||
db.commit()
|
||||
return ok(result)
|
||||
|
||||
|
||||
@router.post("/{session_id}/diagnosis", response_model=ApiResponse[SubmitDiagnosisResponse])
|
||||
def submit_diagnosis(
|
||||
session_id: int,
|
||||
|
||||
@@ -0,0 +1,30 @@
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.response import ApiResponse, ok
|
||||
from app.core.user_context import UserContext, get_user_context
|
||||
from app.db.session import get_db
|
||||
from app.schemas.training_config import TrainingConfigOptionsResponse, TrainingConfigRecommendedResponse
|
||||
from app.services.training_config_service import TrainingConfigService
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/recommended", response_model=ApiResponse[TrainingConfigRecommendedResponse])
|
||||
def get_recommended_training_config(
|
||||
case_id: int = Query(..., ge=1),
|
||||
_: UserContext = Depends(get_user_context),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""推荐配置信息:返回训练页默认病人初始化配置。"""
|
||||
return ok(TrainingConfigService(db).get_recommended(case_id))
|
||||
|
||||
|
||||
@router.get("/options", response_model=ApiResponse[TrainingConfigOptionsResponse])
|
||||
def get_training_config_options(
|
||||
case_id: int = Query(..., ge=1),
|
||||
_: UserContext = Depends(get_user_context),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""训练配置信息:返回训练页自定义病人初始化配置选项。"""
|
||||
return ok(TrainingConfigService(db).get_options(case_id))
|
||||
@@ -1,5 +1,7 @@
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from app.schemas.training_config import PatientConfig
|
||||
|
||||
|
||||
class CreateSessionRequest(BaseModel):
|
||||
"""创建会话入参:选择病例、训练类别、模式和分数类型。"""
|
||||
@@ -8,6 +10,7 @@ class CreateSessionRequest(BaseModel):
|
||||
training_type: str = Field(pattern="^(case_analysis|diagnosis_treatment|consultation)$")
|
||||
mode: str = Field(pattern="^(novice|practice|teaching)$")
|
||||
score_type: str = Field(default="percentage", pattern="^(percentage|five_point)$")
|
||||
patient_config: PatientConfig | None = None
|
||||
|
||||
@field_validator("mode")
|
||||
@classmethod
|
||||
@@ -23,6 +26,7 @@ class CreateSessionResponse(BaseModel):
|
||||
session_code: str
|
||||
status: str
|
||||
patient_opening: str
|
||||
patient_config: dict | None = None
|
||||
|
||||
|
||||
class ChatRequest(BaseModel):
|
||||
|
||||
@@ -0,0 +1,36 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ConfigOption(BaseModel):
|
||||
"""训练配置选项:用于前端渲染单个可选项。"""
|
||||
|
||||
value: str
|
||||
label: str
|
||||
description: str | None = None
|
||||
|
||||
|
||||
class PatientConfig(BaseModel):
|
||||
"""病人初始化配置:控制 AI 病人的就诊场景、年龄段、文化程度和性格。"""
|
||||
|
||||
visit_environment: str = "outpatient"
|
||||
age_group: str = "youth"
|
||||
education_level: str = "higher"
|
||||
personality: str = "calm"
|
||||
|
||||
|
||||
class TrainingConfigOptionsResponse(BaseModel):
|
||||
"""训练配置响应:返回默认配置和全部可选项。"""
|
||||
|
||||
case_id: int
|
||||
recommended: PatientConfig
|
||||
recommended_labels: dict[str, str]
|
||||
options: dict[str, list[ConfigOption]]
|
||||
|
||||
|
||||
class TrainingConfigRecommendedResponse(BaseModel):
|
||||
"""推荐训练配置响应:用于训练页进入时初始化默认病人信息。"""
|
||||
|
||||
case_id: int
|
||||
recommended: PatientConfig
|
||||
recommended_labels: dict[str, str]
|
||||
options: dict[str, list[ConfigOption]]
|
||||
@@ -11,6 +11,9 @@ from app.services.runtime_memory import runtime_memory
|
||||
class OrderService:
|
||||
"""检查检验服务:提供可申请项目和数据库固定结果返回。"""
|
||||
|
||||
PHYSICAL_TYPES = {"physical", "physical_exam", "inspection", "palpation", "percussion", "auscultation"}
|
||||
PHYSICAL_KEYWORDS = ("体格", "体征", "查体", "听诊", "叩诊", "触诊")
|
||||
|
||||
def __init__(self, db: Session) -> None:
|
||||
self.db = db
|
||||
self.case_repo = CaseRepository(db)
|
||||
@@ -20,6 +23,30 @@ class OrderService:
|
||||
"""检查项目列表:按会话病例返回可申请项目,不返回结果。"""
|
||||
session = self._get_session(session_id, user_id)
|
||||
items = self.case_repo.get_exam_items(session.case_id)
|
||||
return self._items_response(items)
|
||||
|
||||
def list_physical_exam_items(self, session_id: int, user_id: str) -> OrderItemsResponse:
|
||||
"""体格检查列表:从当前病例检查项中筛选体格检查类项目。"""
|
||||
session = self._get_session(session_id, user_id)
|
||||
items = [item for item in self.case_repo.get_exam_items(session.case_id) if self._is_physical_item(item)]
|
||||
return self._items_response(items)
|
||||
|
||||
def list_auxiliary_exam_items(self, session_id: int, user_id: str) -> OrderItemsResponse:
|
||||
"""辅助检查列表:从当前病例检查项中筛选非体格检查类项目。"""
|
||||
session = self._get_session(session_id, user_id)
|
||||
items = [item for item in self.case_repo.get_exam_items(session.case_id) if not self._is_physical_item(item)]
|
||||
return self._items_response(items)
|
||||
|
||||
def create_physical_exam_order(self, session_id: int, user_id: str, item_code: str) -> CreateOrderResponse:
|
||||
"""体格检查结果:复用检查申请逻辑,结果仍只来自数据库。"""
|
||||
return self.create_order(session_id, user_id, item_code)
|
||||
|
||||
def create_auxiliary_exam_order(self, session_id: int, user_id: str, item_code: str) -> CreateOrderResponse:
|
||||
"""辅助检查结果:复用检查申请逻辑,结果仍只来自数据库。"""
|
||||
return self.create_order(session_id, user_id, item_code)
|
||||
|
||||
def _items_response(self, items) -> OrderItemsResponse:
|
||||
"""检查列表响应:把 ORM 检查项转换成前端列表结构。"""
|
||||
return OrderItemsResponse(
|
||||
items=[
|
||||
OrderItemResponse(item_code=item.item_code, item_name=item.item_name, item_type=item.item_type)
|
||||
@@ -27,6 +54,14 @@ class OrderService:
|
||||
]
|
||||
)
|
||||
|
||||
def _is_physical_item(self, item) -> bool:
|
||||
"""检查分类:按 item_type 和 category 识别体格检查,其他归入辅助检查。"""
|
||||
item_type = (item.item_type or "").lower()
|
||||
category = item.category or ""
|
||||
if item_type in self.PHYSICAL_TYPES:
|
||||
return True
|
||||
return any(keyword in category or keyword in item.item_name for keyword in self.PHYSICAL_KEYWORDS)
|
||||
|
||||
def create_order(self, session_id: int, user_id: str, item_code: str) -> CreateOrderResponse:
|
||||
"""检查申请:从数据库读取检查结果并写入当前会话记录。"""
|
||||
session = self._get_session(session_id, user_id)
|
||||
|
||||
@@ -28,6 +28,7 @@ from app.schemas.session import (
|
||||
)
|
||||
from app.services.audit_service import AuditService
|
||||
from app.services.runtime_memory import runtime_memory
|
||||
from app.services.training_config_service import TrainingConfigService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -48,6 +49,7 @@ class SessionService:
|
||||
if not case:
|
||||
raise AppError("CASE_NOT_FOUND", "case not found or inactive", 404)
|
||||
|
||||
patient_config = TrainingConfigService(self.db).normalize_patient_config(payload.patient_config)
|
||||
session_code = f"sess_{datetime.utcnow().strftime('%Y%m%d%H%M%S')}_{uuid.uuid4().hex[:8]}"
|
||||
memory_key = f"mem:{session_code}"
|
||||
session = self.session_repo.create_session(
|
||||
@@ -64,7 +66,7 @@ class SessionService:
|
||||
status="inquiry",
|
||||
started_at=datetime.utcnow(),
|
||||
memory_key=memory_key,
|
||||
metadata_={"source": "demo"},
|
||||
metadata_={"source": "demo", "patient_config": patient_config},
|
||||
)
|
||||
)
|
||||
patient_opening = case.patient_opening or "家长:医生,孩子这几天不舒服,想请您看看。"
|
||||
@@ -75,6 +77,7 @@ class SessionService:
|
||||
session_code=session.session_code,
|
||||
status=session.status,
|
||||
patient_opening=patient_opening,
|
||||
patient_config=patient_config,
|
||||
)
|
||||
|
||||
async def chat(self, ctx: UserContext, session_id: int, message: str) -> ChatResponse:
|
||||
@@ -225,6 +228,61 @@ class SessionService:
|
||||
self.audit.log(ctx, "session.hints", "training_session", str(session.id), session.id)
|
||||
return HintResponse(**result)
|
||||
|
||||
async def stream_hints(self, ctx: UserContext, session_id: int, payload: HintRequest) -> AsyncIterator[str]:
|
||||
"""流式练习提示:把结构化提示压缩成一句话并用 SSE 返回给前端。"""
|
||||
started_at = time.perf_counter()
|
||||
try:
|
||||
hint_result = await self.generate_hints(ctx, session_id, payload)
|
||||
sentence = self._build_hint_sentence(hint_result)
|
||||
except AppError as exc:
|
||||
error_message = exc.message
|
||||
error_code = exc.code
|
||||
|
||||
async def app_error_generator() -> AsyncIterator[str]:
|
||||
yield self._sse_error(error_message, error_code)
|
||||
|
||||
return app_error_generator()
|
||||
except Exception:
|
||||
logger.exception("hint_stream.failed session_id=%s", session_id)
|
||||
|
||||
async def error_generator() -> AsyncIterator[str]:
|
||||
yield self._sse_error("练习提示生成失败,请稍后重试", "HINT_STREAM_FAILED")
|
||||
|
||||
return error_generator()
|
||||
|
||||
async def event_generator() -> AsyncIterator[str]:
|
||||
if not sentence:
|
||||
yield self._sse_error("当前没有生成有效提示,请继续问诊后再试", "HINT_EMPTY")
|
||||
return
|
||||
for chunk in self._chunk_text(sentence, size=12):
|
||||
payload_text = json.dumps({"delta": chunk}, ensure_ascii=False)
|
||||
yield f"event: hint_delta\ndata: {payload_text}\n\n"
|
||||
await asyncio.sleep(0)
|
||||
done_payload = json.dumps({"latency_ms": int((time.perf_counter() - started_at) * 1000)}, ensure_ascii=False)
|
||||
yield f"event: hint_done\ndata: {done_payload}\n\n"
|
||||
|
||||
return event_generator()
|
||||
|
||||
def _build_hint_sentence(self, hint_result: HintResponse) -> str:
|
||||
"""提示压缩:从结构化提示中提炼适合前端流式展示的一句话。"""
|
||||
parts: list[str] = []
|
||||
if hint_result.missing_dimensions:
|
||||
parts.append(f"当前可补充{ '、'.join(hint_result.missing_dimensions[:3]) }")
|
||||
if hint_result.next_questions:
|
||||
parts.append(f"下一步可问:{hint_result.next_questions[0]}")
|
||||
elif hint_result.hints:
|
||||
parts.append(hint_result.hints[0])
|
||||
if hint_result.recommended_orders:
|
||||
order = hint_result.recommended_orders[0]
|
||||
item_code = order.get("item_code") or order.get("item_name") or "关键检查"
|
||||
reason = order.get("reason") or "用于完善病情判断"
|
||||
parts.append(f"可考虑申请{item_code},{reason}")
|
||||
return ";".join(parts) + ("。" if parts else "")
|
||||
|
||||
def _chunk_text(self, text: str, size: int) -> list[str]:
|
||||
"""文本切片:把一句练习提示拆成短片段,便于前端按 SSE 渐进展示。"""
|
||||
return [text[index : index + size] for index in range(0, len(text), size)]
|
||||
|
||||
def complete_inquiry(self, ctx: UserContext, session_id: int) -> SessionStatusResponse:
|
||||
"""完成问诊:校验至少一轮医生问诊后进入诊断阶段。"""
|
||||
session = self._get_session(session_id, ctx.user_id)
|
||||
|
||||
@@ -0,0 +1,104 @@
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.exceptions import AppError
|
||||
from app.repositories.case_repository import CaseRepository
|
||||
from app.schemas.training_config import (
|
||||
ConfigOption,
|
||||
PatientConfig,
|
||||
TrainingConfigOptionsResponse,
|
||||
TrainingConfigRecommendedResponse,
|
||||
)
|
||||
|
||||
|
||||
class TrainingConfigService:
|
||||
"""训练配置服务:提供训练页病人初始化配置,不写数据库。"""
|
||||
|
||||
def __init__(self, db: Session) -> None:
|
||||
self.db = db
|
||||
self.case_repo = CaseRepository(db)
|
||||
|
||||
def get_recommended(self, case_id: int) -> TrainingConfigRecommendedResponse:
|
||||
"""推荐配置:根据病例返回训练页默认病人初始化配置。"""
|
||||
self._ensure_case(case_id)
|
||||
recommended = self.default_patient_config()
|
||||
return TrainingConfigRecommendedResponse(
|
||||
case_id=case_id,
|
||||
recommended=recommended,
|
||||
recommended_labels=self.config_labels(recommended),
|
||||
options=self.config_options(),
|
||||
)
|
||||
|
||||
def get_options(self, case_id: int) -> TrainingConfigOptionsResponse:
|
||||
"""配置选项:返回训练页自定义配置的全部可选项。"""
|
||||
self._ensure_case(case_id)
|
||||
recommended = self.default_patient_config()
|
||||
return TrainingConfigOptionsResponse(
|
||||
case_id=case_id,
|
||||
recommended=recommended,
|
||||
recommended_labels=self.config_labels(recommended),
|
||||
options=self.config_options(),
|
||||
)
|
||||
|
||||
def default_patient_config(self) -> PatientConfig:
|
||||
"""默认配置:按当前产品原型初始化病人信息。"""
|
||||
return PatientConfig(
|
||||
visit_environment="outpatient",
|
||||
age_group="youth",
|
||||
education_level="higher",
|
||||
personality="calm",
|
||||
)
|
||||
|
||||
def normalize_patient_config(self, config: PatientConfig | None) -> dict:
|
||||
"""配置归一:校验并补齐前端传入的病人初始化配置。"""
|
||||
selected = config or self.default_patient_config()
|
||||
values = selected.model_dump()
|
||||
allowed = {key: {item.value for item in items} for key, items in self.config_options().items()}
|
||||
for key, value in values.items():
|
||||
if value not in allowed.get(key, set()):
|
||||
raise AppError("TRAINING_CONFIG_INVALID", f"invalid patient config field: {key}", 400)
|
||||
return {
|
||||
"values": values,
|
||||
"labels": self.config_labels(selected),
|
||||
}
|
||||
|
||||
def config_labels(self, config: PatientConfig) -> dict[str, str]:
|
||||
"""配置标签:把配置值转换为前端和提示词可读的中文标签。"""
|
||||
option_map = {
|
||||
key: {item.value: item.label for item in items}
|
||||
for key, items in self.config_options().items()
|
||||
}
|
||||
values = config.model_dump()
|
||||
return {key: option_map.get(key, {}).get(value, value) for key, value in values.items()}
|
||||
|
||||
def config_options(self) -> dict[str, list[ConfigOption]]:
|
||||
"""配置选项:训练页可选病人初始化配置。"""
|
||||
return {
|
||||
"visit_environment": [
|
||||
ConfigOption(value="outpatient", label="门诊", description="适合常规问诊训练"),
|
||||
ConfigOption(value="emergency", label="急诊", description="病情紧急、沟通节奏更快"),
|
||||
ConfigOption(value="ward", label="病房", description="适合住院病情追踪和处置沟通"),
|
||||
],
|
||||
"age_group": [
|
||||
ConfigOption(value="child", label="儿童", description="由家属代述为主"),
|
||||
ConfigOption(value="youth", label="青年", description="表达清楚,能配合问诊"),
|
||||
ConfigOption(value="middle_aged", label="中年", description="关注工作、家庭和慢病背景"),
|
||||
ConfigOption(value="elderly", label="老年", description="表达较慢,需关注基础病和用药史"),
|
||||
],
|
||||
"education_level": [
|
||||
ConfigOption(value="primary_or_below", label="小学及以下", description="医学术语理解弱"),
|
||||
ConfigOption(value="secondary", label="中等教育", description="能理解常见健康解释"),
|
||||
ConfigOption(value="higher", label="高等教育", description="理解能力强,能描述细节"),
|
||||
],
|
||||
"personality": [
|
||||
ConfigOption(value="calm", label="平和", description="情绪稳定,按问题回答"),
|
||||
ConfigOption(value="anxious", label="焦虑", description="更担心病情和治疗风险"),
|
||||
ConfigOption(value="impatient", label="急躁", description="希望快速获得结论"),
|
||||
ConfigOption(value="cooperative", label="配合", description="愿意补充细节"),
|
||||
ConfigOption(value="suspicious", label="多疑", description="会追问检查和用药依据"),
|
||||
],
|
||||
}
|
||||
|
||||
def _ensure_case(self, case_id: int) -> None:
|
||||
"""病例校验:确认配置请求对应已发布病例。"""
|
||||
if not self.case_repo.get_active_case(case_id):
|
||||
raise AppError("CASE_NOT_FOUND", "case not found or inactive", 404)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -89,18 +89,46 @@ def run_api_contract_tests() -> None:
|
||||
assert "/api/v1/imports/case-sql/apply" not in openapi_payload["paths"]
|
||||
assert "/api/v1/cases/{case_id}/delete-preview" not in openapi_payload["paths"]
|
||||
assert "delete" not in openapi_payload["paths"]["/api/v1/cases/{case_id}"]
|
||||
assert "/api/v1/training-config/recommended" in openapi_payload["paths"]
|
||||
assert "/api/v1/training-config/options" in openapi_payload["paths"]
|
||||
assert "/api/v1/sessions/{session_id}/hints/stream" in openapi_payload["paths"]
|
||||
assert "/api/v1/sessions/{session_id}/physical-exams" in openapi_payload["paths"]
|
||||
assert "/api/v1/sessions/{session_id}/auxiliary-exams" in openapi_payload["paths"]
|
||||
assert "/api/v1/sessions/{session_id}/physical-exams/{item_code}" in openapi_payload["paths"]
|
||||
assert "/api/v1/sessions/{session_id}/auxiliary-exams/{item_code}" in openapi_payload["paths"]
|
||||
|
||||
cases = client.get("/api/v1/cases", headers=headers)
|
||||
assert cases.status_code == 200
|
||||
case_id = cases.json()["data"]["items"][0]["id"]
|
||||
|
||||
recommended_config = client.get(f"/api/v1/training-config/recommended?case_id={case_id}", headers=headers)
|
||||
assert recommended_config.status_code == 200
|
||||
assert recommended_config.json()["data"]["recommended"]["visit_environment"] == "outpatient"
|
||||
assert recommended_config.json()["data"]["recommended_labels"]["visit_environment"] == "门诊"
|
||||
|
||||
config_options = client.get(f"/api/v1/training-config/options?case_id={case_id}", headers=headers)
|
||||
assert config_options.status_code == 200
|
||||
assert config_options.json()["data"]["options"]["personality"]
|
||||
|
||||
created = client.post(
|
||||
"/api/v1/sessions",
|
||||
headers=headers,
|
||||
json={"case_id": case_id, "training_type": "diagnosis_treatment", "mode": "practice", "score_type": "percentage"},
|
||||
json={
|
||||
"case_id": case_id,
|
||||
"training_type": "diagnosis_treatment",
|
||||
"mode": "practice",
|
||||
"score_type": "percentage",
|
||||
"patient_config": {
|
||||
"visit_environment": "outpatient",
|
||||
"age_group": "youth",
|
||||
"education_level": "higher",
|
||||
"personality": "calm",
|
||||
},
|
||||
},
|
||||
)
|
||||
assert created.status_code == 200
|
||||
session_id = created.json()["data"]["session_id"]
|
||||
assert created.json()["data"]["patient_config"]["labels"]["personality"] == "平和"
|
||||
|
||||
cross_user = client.get(
|
||||
f"/api/v1/sessions/{session_id}/order-items",
|
||||
@@ -120,6 +148,18 @@ def run_api_contract_tests() -> None:
|
||||
assert order_two.status_code == 200
|
||||
assert order_two.json()["data"]["already_ordered"] is True
|
||||
|
||||
physical_list = client.get(f"/api/v1/sessions/{session_id}/physical-exams", headers=headers)
|
||||
assert physical_list.status_code == 200
|
||||
assert "items" in physical_list.json()["data"]
|
||||
|
||||
auxiliary_list = client.get(f"/api/v1/sessions/{session_id}/auxiliary-exams", headers=headers)
|
||||
assert auxiliary_list.status_code == 200
|
||||
assert any(item["item_code"] == "blood_routine" for item in auxiliary_list.json()["data"]["items"])
|
||||
|
||||
auxiliary_result = client.post(f"/api/v1/sessions/{session_id}/auxiliary-exams/blood_routine", headers=headers)
|
||||
assert auxiliary_result.status_code == 200
|
||||
assert auxiliary_result.json()["data"]["already_ordered"] is True
|
||||
|
||||
practice_hint_session = client.post(
|
||||
"/api/v1/sessions",
|
||||
headers=headers,
|
||||
@@ -136,6 +176,17 @@ def run_api_contract_tests() -> None:
|
||||
assert hint.json()["data"]["hints"]
|
||||
assert "recommended_orders" in hint.json()["data"]
|
||||
|
||||
with client.stream(
|
||||
"POST",
|
||||
f"/api/v1/sessions/{practice_hint_session_id}/hints/stream",
|
||||
headers=headers,
|
||||
json={"scope": "current_conversation"},
|
||||
) as hint_stream:
|
||||
assert hint_stream.status_code == 200
|
||||
hint_stream_text = "".join(hint_stream.iter_text())
|
||||
assert "event: hint_delta" in hint_stream_text
|
||||
assert "event: hint_done" in hint_stream_text
|
||||
|
||||
teaching = client.post(
|
||||
"/api/v1/sessions",
|
||||
headers=headers,
|
||||
|
||||
@@ -28,6 +28,7 @@ from app.schemas.session import (
|
||||
SubmitDiagnosisRequest,
|
||||
SubmitTreatmentRequest,
|
||||
)
|
||||
from app.schemas.training_config import PatientConfig
|
||||
from app.services.evaluation_service import EvaluationService
|
||||
from app.services.order_service import OrderService
|
||||
from app.services.pdf_export_service import PdfExportService
|
||||
@@ -56,10 +57,17 @@ async def run_demo_flow() -> None:
|
||||
training_type="diagnosis_treatment",
|
||||
mode="practice",
|
||||
score_type="percentage",
|
||||
patient_config=PatientConfig(
|
||||
visit_environment="outpatient",
|
||||
age_group="youth",
|
||||
education_level="higher",
|
||||
personality="calm",
|
||||
),
|
||||
),
|
||||
)
|
||||
db.commit()
|
||||
assert created.status == "inquiry"
|
||||
assert created.patient_config["labels"]["visit_environment"] == "门诊"
|
||||
|
||||
chat = await session_service.chat(ctx, created.session_id, ChatRequest(message="孩子最高体温多少?").message)
|
||||
db.commit()
|
||||
@@ -68,6 +76,10 @@ async def run_demo_flow() -> None:
|
||||
order = order_service.create_order(created.session_id, ctx.user_id, CreateOrderRequest(item_code="chest_xray").item_code)
|
||||
db.commit()
|
||||
assert order.is_key is True
|
||||
auxiliary_items = order_service.list_auxiliary_exam_items(created.session_id, ctx.user_id)
|
||||
assert any(item.item_code == "chest_xray" for item in auxiliary_items.items)
|
||||
physical_items = order_service.list_physical_exam_items(created.session_id, ctx.user_id)
|
||||
assert physical_items.items == [] or all(item.item_code != "chest_xray" for item in physical_items.items)
|
||||
tool_count_before = len([item for item in runtime_memory.get_messages(f"mem:{created.session_code}") if item.get("role") == "tool"])
|
||||
|
||||
duplicate_order = order_service.create_order(created.session_id, ctx.user_id, "chest_xray")
|
||||
|
||||
Reference in New Issue
Block a user