169 lines
7.1 KiB
Python
169 lines
7.1 KiB
Python
import asyncio
import os
import sys
import tempfile
from contextlib import contextmanager
from pathlib import Path
|
|
|
|
# Environment bootstrap for this script.
# NOTE(review): this must run before the `app.*` imports further down — those
# modules presumably read DATABASE_URL etc. at import time; confirm before
# moving any of this.
_TMP_ROOT = Path(tempfile.gettempdir())

# Throwaway SQLite file for this run; remove any leftover from a previous run
# so the flow always starts from an empty database.
TEST_DB_PATH = _TMP_ROOT / "medical_agent_test_demo_flow.db"
TEST_DB_PATH.unlink(missing_ok=True)

# Always point storage at the temp locations (overrides any existing values).
os.environ.update(
    DATABASE_URL=f"sqlite:///{TEST_DB_PATH.as_posix()}",
    REPORT_STORAGE_DIR=str(_TMP_ROOT / "medical_agent_test_reports"),
)

# Only fill these in when the caller has not chosen a value already.
os.environ.setdefault("RUNTIME_MEMORY_BACKEND", "memory")
os.environ.setdefault("LLM_MOCK_ENABLED", "true")

# Make the directory two levels above this file importable (it provides the
# `app` and `scripts` packages imported below).
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
|
|
|
|
from sqlalchemy import select
|
|
|
|
from app.core.context import UserContext
|
|
from app.core.exceptions import AppError
|
|
from app.db.session import SessionLocal
|
|
from app.models.source_case import CaseBase
|
|
from app.models.training_record import TrainingRecord, TrainingScoreDetail
|
|
from app.schemas.evaluation import CreateEvaluationRequest
|
|
from app.schemas.session import (
|
|
ChatRequest,
|
|
CreateOrderRequest,
|
|
CreateSessionRequest,
|
|
SubmitDiagnosisRequest,
|
|
SubmitTreatmentRequest,
|
|
)
|
|
from app.schemas.training_config import PatientConfig
|
|
from app.services.evaluation_service import EvaluationService
|
|
from app.services.order_service import OrderService
|
|
from app.services.pdf_export_service import PdfExportService
|
|
from app.services.runtime_memory import runtime_memory
|
|
from app.services.session_service import SessionService
|
|
from scripts.init_demo_db import init_database
|
|
|
|
|
|
def _count_tool_messages(session_code: str) -> int:
    """Return how many entries with ``role == "tool"`` runtime memory holds for a session.

    The ``mem:<session_code>`` key is the one this script has always queried;
    presumably it matches the key the services write to — confirm against
    ``runtime_memory`` if that format changes.
    """
    messages = runtime_memory.get_messages(f"mem:{session_code}")
    return sum(1 for item in messages if item.get("role") == "tool")


@contextmanager
def _expect_app_error(expected_code: str, description: str):
    """Assert that the wrapped statements raise ``AppError`` carrying ``expected_code``.

    Works around both sync calls and ``await`` expressions inside an async function.

    Raises:
        AssertionError: if nothing is raised, or if the ``AppError`` has a different code.
        Any non-``AppError`` exception from the wrapped code propagates unchanged.
    """
    try:
        yield
    except AppError as exc:
        assert exc.code == expected_code, f"{description}: expected {expected_code!r}, got {exc.code!r}"
    else:
        raise AssertionError(f"{description} should raise AppError")


async def run_demo_flow() -> None:
    """End-to-end check that the first-version demo's core training pipeline runs.

    Drives one session through every stage against the seeded demo database:
    create session -> inquiry chat -> exam orders -> complete inquiry ->
    diagnosis -> treatment -> evaluation -> PDF export -> history listing.
    Each mutating service call is followed by ``db.commit()`` so later steps
    observe persisted state. Any failed check raises ``AssertionError``.
    """
    init_database()
    ctx = UserContext(user_id="demo_flow_user", tenant_id="demo_tenant", entry_scene="service_test")

    with SessionLocal() as db:
        # Presumably seeded by init_database(); the whole flow runs against this case.
        case = db.scalar(select(CaseBase).where(CaseBase.title == "支气管肺炎 - 6岁男性患儿"))
        assert case is not None, "seeded demo case not found"

        session_service = SessionService(db)
        order_service = OrderService(db)
        evaluation_service = EvaluationService(db)
        pdf_service = PdfExportService(db)

        # --- Stage 1: create a practice session -------------------------------
        created = session_service.create_session(
            ctx,
            CreateSessionRequest(
                case_id=case.id,
                training_type="diagnosis_treatment",
                mode="practice",
                score_type="percentage",
                patient_config=PatientConfig(
                    visit_environment="outpatient",
                    age_group="youth",
                    education_level="higher",
                    personality="calm",
                ),
            ),
        )
        db.commit()
        assert created.status == "inquiry"
        # The "outpatient" option is expected to come back with the display label "门诊".
        assert created.patient_config["labels"]["visit_environment"] == "门诊"

        # --- Stage 2: inquiry chat --------------------------------------------
        # The message goes through the request schema so its validation is exercised too.
        chat = await session_service.chat(ctx, created.session_id, ChatRequest(message="孩子最高体温多少?").message)
        db.commit()
        assert chat.reply

        # --- Stage 3: exam orders ---------------------------------------------
        order = order_service.create_order(created.session_id, ctx.user_id, CreateOrderRequest(item_code="chest_xray").item_code)
        db.commit()
        assert order.is_key is True

        auxiliary_items = order_service.list_auxiliary_exam_items(created.session_id, ctx.user_id)
        assert any(item.item_code == "chest_xray" for item in auxiliary_items.items)
        # chest_xray must not also show up under physical exams (trivially true when that list is empty).
        physical_items = order_service.list_physical_exam_items(created.session_id, ctx.user_id)
        assert all(item.item_code != "chest_xray" for item in physical_items.items)

        # Ordering the same item again must be flagged as a duplicate and must
        # not append another tool message to runtime memory.
        tool_count_before = _count_tool_messages(created.session_code)
        duplicate_order = order_service.create_order(created.session_id, ctx.user_id, "chest_xray")
        db.commit()
        tool_count_after = _count_tool_messages(created.session_code)
        assert duplicate_order.already_ordered is True
        assert tool_count_after == tool_count_before

        with _expect_app_error("ORDER_ITEM_NOT_FOUND", "invalid order item"):
            order_service.create_order(created.session_id, ctx.user_id, "not_exists")

        # A different user must not be able to reach this session.
        with _expect_app_error("SESSION_NOT_FOUND", "cross user access"):
            order_service.list_order_items(created.session_id, "another_user")

        # --- Stage 4: inquiry -> diagnosis -> treatment -----------------------
        status = session_service.complete_inquiry(ctx, created.session_id)
        db.commit()
        assert status.status == "diagnosis"

        diagnosis = session_service.submit_diagnosis(
            ctx,
            created.session_id,
            SubmitDiagnosisRequest(
                primary_diagnosis="支气管肺炎",
                differential_diagnoses=["毛细支气管炎", "支气管哮喘急性发作"],
                diagnosis_basis="发热咳嗽伴喘息,肺部湿啰音,胸片异常,炎症指标升高。",
            ),
        )
        db.commit()
        assert diagnosis.status == "treatment"

        treatment = session_service.submit_treatment(
            ctx,
            created.session_id,
            SubmitTreatmentRequest(
                treatment_principle="抗感染、止咳平喘、改善氧合、严密观察病情变化。",
                treatment_measures="根据病情选择抗感染治疗,必要时雾化吸入,监测体温、呼吸和血氧。",
                risk_plan="关注低氧、呼吸困难加重、持续高热和精神反应差。",
                communication="向家属说明病情、用药注意事项和复诊指征。",
                follow_up="治疗后复查体温、呼吸情况和必要炎症指标。",
            ),
        )
        db.commit()
        assert treatment.status == "evaluating"

        # Once treatment is submitted, the session must reject further chat.
        with _expect_app_error("SESSION_STATUS_INVALID", "chat after treatment submission"):
            await session_service.chat(ctx, created.session_id, "治疗后还能问诊吗?")

        # --- Stage 5: evaluation and the persisted training record ------------
        evaluation = await evaluation_service.create_evaluation(
            ctx,
            created.session_id,
            CreateEvaluationRequest(score_type="percentage"),
        )
        db.commit()
        assert evaluation.total_score > 0
        assert evaluation.score_details

        training_record = db.scalar(select(TrainingRecord).where(TrainingRecord.session_id == created.session_id))
        assert training_record is not None
        assert training_record.external_user_id == ctx.user_id
        score_details = list(db.scalars(select(TrainingScoreDetail).where(TrainingScoreDetail.record_id == training_record.id)).all())
        assert score_details
        assert score_details[0].dimension

        # --- Stage 6: PDF export and history ----------------------------------
        export = pdf_service.export(evaluation.evaluation_id, ctx.user_id)
        db.commit()
        report_path = Path(export.file_path)
        assert report_path.exists()
        # Cheap sanity floor (from the original script) that a non-trivial document was written.
        assert report_path.stat().st_size > 1000

        history = evaluation_service.list_history(ctx.user_id)
        assert history.items
|
|
|
|
|
if __name__ == "__main__":
    # Build the coroutine, then drive it to completion on a fresh event loop.
    # A failing check raises out of asyncio.run before the banner is printed.
    flow = run_demo_flow()
    asyncio.run(flow)
    print("demo flow test passed")
|