chore: initialize medical consultation agent demo

This commit is contained in:
刘金宝
2026-06-01 09:25:26 +08:00
commit a7733243b2
139 changed files with 15764 additions and 0 deletions
+227
View File
@@ -0,0 +1,227 @@
import os
import sys
from decimal import Decimal
from pathlib import Path
os.environ.setdefault("DATABASE_URL", "sqlite:///./storage/test_api_contract.db")
os.environ.setdefault("RUNTIME_MEMORY_BACKEND", "memory")
os.environ.setdefault("LLM_MOCK_ENABLED", "true")
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
Path("storage").mkdir(exist_ok=True)
try:
from fastapi.testclient import TestClient
except ImportError:
TestClient = None
def run_api_contract_tests() -> None:
"""API 合约:验证统一响应、user_id 校验、核心接口和跨用户隔离。"""
if TestClient is None:
print("api contract tests skipped: fastapi is not installed")
return
from app.main import app
from app.db.session import SessionLocal
from app.models.source_case import CaseBase, CaseExamItem, ScoringRule, TraditionalCase
from app.repositories.case_repository import CaseRepository
from scripts.init_demo_db import init_database
init_database()
client = TestClient(app)
headers = {"X-User-Id": "api_user_001", "X-Entry-Scene": "api_test"}
missing_user = client.get("/api/v1/agent/hello")
assert missing_user.status_code == 401
assert missing_user.json()["code"] == "USER_ID_REQUIRED"
hello = client.get("/api/v1/agent/hello", headers=headers)
assert hello.status_code == 200
assert hello.json()["code"] == "OK"
cases = client.get("/api/v1/cases", headers=headers)
assert cases.status_code == 200
case_id = cases.json()["data"]["items"][0]["id"]
created = client.post(
"/api/v1/sessions",
headers=headers,
json={"case_id": case_id, "training_type": "diagnosis_treatment", "mode": "practice", "score_type": "percentage"},
)
assert created.status_code == 200
session_id = created.json()["data"]["session_id"]
cross_user = client.get(f"/api/v1/sessions/{session_id}/order-items", headers={"X-User-Id": "api_user_002"})
assert cross_user.status_code == 404
assert cross_user.json()["code"] == "SESSION_NOT_FOUND"
invalid_order = client.post(f"/api/v1/sessions/{session_id}/orders", headers=headers, json={"item_code": "bad_item"})
assert invalid_order.status_code == 404
assert invalid_order.json()["code"] == "ORDER_ITEM_NOT_FOUND"
order_one = client.post(f"/api/v1/sessions/{session_id}/orders", headers=headers, json={"item_code": "blood_routine"})
assert order_one.status_code == 200
assert order_one.json()["data"]["already_ordered"] is False
order_two = client.post(f"/api/v1/sessions/{session_id}/orders", headers=headers, json={"item_code": "blood_routine"})
assert order_two.status_code == 200
assert order_two.json()["data"]["already_ordered"] is True
practice_hint_session = client.post(
"/api/v1/sessions",
headers=headers,
json={"case_id": case_id, "training_type": "diagnosis_treatment", "mode": "practice", "score_type": "percentage"},
)
assert practice_hint_session.status_code == 200
practice_hint_session_id = practice_hint_session.json()["data"]["session_id"]
hint = client.post(
f"/api/v1/sessions/{practice_hint_session_id}/hints",
headers=headers,
json={"scope": "current_conversation"},
)
assert hint.status_code == 200
assert hint.json()["data"]["hints"]
assert "recommended_orders" in hint.json()["data"]
teaching = client.post(
"/api/v1/sessions",
headers=headers,
json={"case_id": case_id, "training_type": "diagnosis_treatment", "mode": "teaching", "score_type": "percentage"},
)
assert teaching.status_code == 200
teaching_hint = client.post(
f"/api/v1/sessions/{teaching.json()['data']['session_id']}/hints",
headers=headers,
json={"scope": "current_conversation"},
)
assert teaching_hint.status_code == 400
assert teaching_hint.json()["code"] == "SESSION_STATUS_INVALID"
llm_fast = client.post("/api/v1/llm/test/deepseek-fast", headers=headers, json={"message": "hello"})
assert llm_fast.status_code == 200
assert llm_fast.json()["code"] == "OK"
assert llm_fast.json()["data"]["stream"] is False
llm_reason = client.post("/api/v1/llm/test/deepseek-reason", headers=headers, json={"message": "hello"})
assert llm_reason.status_code == 200
assert llm_reason.json()["code"] == "OK"
assert "total_latency_ms" in llm_reason.json()["data"]
import_preview = client.post(
"/api/v1/imports/case-sql/preview",
headers=headers,
files={"file": ("bad.sql", b"not sql", "application/sql")},
)
assert import_preview.status_code == 200
assert import_preview.json()["code"] == "OK"
assert import_preview.json()["data"]["can_import"] is False
assert import_preview.json()["data"]["errors"]
import_apply = client.post(
"/api/v1/imports/case-sql/apply",
headers=headers,
files={"file": ("bad.sql", b"not sql", "application/sql")},
)
assert import_apply.status_code == 400
assert import_apply.json()["code"] == "CASE_SQL_IMPORT_INVALID"
temp_case_id = 880001
with SessionLocal() as db:
CaseRepository(db).delete_case_cascade(temp_case_id)
db.add(
CaseBase(
id=temp_case_id,
title="删除测试病例",
case_type="diagnosis_treatment",
difficulty="medium",
chief_complaint="发热、咳嗽",
description="用于删除接口测试的临时病例",
patient_age=6,
patient_gender="male",
tags="test",
symptom_tags=[],
disease_tags=[],
competency_tags=[],
guideline_tags=[],
knowledge_points=[],
icd_codes="",
osce_enabled=False,
rag_enabled=False,
ai_prompt_template="",
multimodal_assets=[],
vector_status=0,
publish_status=1,
status=1,
department_id=1,
)
)
db.add(
TraditionalCase(
id=temp_case_id,
case_id=temp_case_id,
standard_diagnosis="测试诊断",
standard_treatment="测试治疗",
guideline_reference="测试指南",
)
)
db.add(
ScoringRule(
id=temp_case_id,
case_id=temp_case_id,
dimension="信息采集",
competency_dimension="问诊完整性",
score_weight=Decimal("10.00"),
ai_auto_score=True,
osce_dimension=False,
scoring_standard="测试评分标准",
rubric_json={},
)
)
db.add(
CaseExamItem(
id=temp_case_id,
case_id=temp_case_id,
item_code="temp_exam",
item_name="测试检查",
item_type="lab",
result_text="测试结果",
result_structured={},
is_key=True,
is_abnormal=False,
score_weight=Decimal("1.00"),
display_order=1,
)
)
db.commit()
delete_preview = client.get(f"/api/v1/cases/{temp_case_id}/delete-preview", headers=headers)
assert delete_preview.status_code == 200
assert delete_preview.json()["data"]["affected"]["case_base"] == 1
assert delete_preview.json()["data"]["affected"]["case_exam_item"] == 1
delete_without_confirm = client.request(
"DELETE",
f"/api/v1/cases/{temp_case_id}",
headers=headers,
json={"confirm": False, "delete_training_data": True},
)
assert delete_without_confirm.status_code == 400
assert delete_without_confirm.json()["code"] == "CASE_DELETE_CONFIRM_REQUIRED"
delete_case = client.request(
"DELETE",
f"/api/v1/cases/{temp_case_id}",
headers=headers,
json={"confirm": True, "delete_training_data": True},
)
assert delete_case.status_code == 200
assert delete_case.json()["data"]["deleted_counts"]["case_base"] == 1
deleted_detail = client.get(f"/api/v1/cases/{temp_case_id}", headers=headers)
assert deleted_detail.status_code == 404
assert deleted_detail.json()["code"] == "CASE_NOT_FOUND"
if __name__ == "__main__":
run_api_contract_tests()
print("api contract tests passed")
+87
View File
@@ -0,0 +1,87 @@
import sys
import os
from pathlib import Path
os.environ.setdefault("DATABASE_URL", "sqlite:///./storage/test_core.db")
os.environ.setdefault("RUNTIME_MEMORY_BACKEND", "memory")
os.environ.setdefault("LLM_MOCK_ENABLED", "true")
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
from app.agents.scoring_agent import ScoringAgent
from app.agents.hint_agent import HintAgent
from app.agents.llm_adapter import OpenAICompatibleLLMClient
from app.core.config import settings
from app.services.runtime_memory import InMemoryRuntimeMemoryService
def test_runtime_memory_lifecycle() -> None:
"""短期 memory:验证创建、写入、读取和释放流程。"""
memory = InMemoryRuntimeMemoryService()
memory.create("test-memory", "开场白")
memory.add_message("test-memory", "doctor", "孩子发热几天?")
assert memory.has_doctor_message("test-memory") is True
assert len(memory.get_messages("test-memory")) == 2
memory.release("test-memory")
assert memory.get_messages("test-memory") == []
def test_score_convert_to_five_point() -> None:
"""评分转换:验证百分制到五分制的结构转换。"""
agent = ScoringAgent()
result = agent._convert_to_five_point(
{
"score_type": "percentage",
"total_score": 80,
"dimension_scores": [{"dimension": "信息获取", "score": 20, "max_score": 25, "comment": "ok"}],
}
)
assert result["score_type"] == "five_point"
assert result["total_score"] == 4.0
assert result["dimension_scores"][0]["max_score"] == 5
def test_public_settings() -> None:
"""配置输出:验证 Demo 前端可读取功能开关。"""
public = settings.as_public_dict()
assert "score_types" in public
assert "percentage" in public["score_types"]
def test_reasoning_effort_disabled_when_thinking_off() -> None:
"""LLM 参数构造:thinking 关闭时不发送 reasoning_effort,避免 reason 测试流式 400。"""
client = OpenAICompatibleLLMClient()
payload = client._build_payload(
model="deepseek-v4-pro",
messages=[{"role": "user", "content": "hello"}],
stream=True,
thinking_enabled=False,
reasoning_effort="low",
max_tokens=128,
)
assert "reasoning_effort" not in payload
def test_hint_agent_invalid_json_fallback() -> None:
"""新手提示:验证模型输出结构不匹配时使用稳定 fallback。"""
agent = HintAgent()
payload = {
"case": {
"chief_complaint": "发热、咳嗽4天,喘息1天",
"key_exams": ["blood_routine", "chest_xray"],
},
"ordered_results": [],
}
result = agent._normalize_output({"score_type": "percentage"}, payload)
assert result["hints"]
assert result["next_questions"]
assert result["recommended_orders"]
if __name__ == "__main__":
test_runtime_memory_lifecycle()
test_score_convert_to_five_point()
test_public_settings()
test_reasoning_effort_disabled_when_thinking_off()
test_hint_agent_invalid_json_fallback()
print("core logic tests passed")
+148
View File
@@ -0,0 +1,148 @@
import asyncio
import os
import sys
from pathlib import Path
os.environ.setdefault("DATABASE_URL", "sqlite:///./storage/test_demo_flow.db")
os.environ.setdefault("RUNTIME_MEMORY_BACKEND", "memory")
os.environ.setdefault("LLM_MOCK_ENABLED", "true")
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
from sqlalchemy import select
from app.core.context import UserContext
from app.core.exceptions import AppError
from app.db.session import SessionLocal
from app.models.source_case import CaseBase
from app.models.training_record import TrainingRecord
from app.schemas.evaluation import CreateEvaluationRequest
from app.schemas.session import (
ChatRequest,
CreateOrderRequest,
CreateSessionRequest,
SubmitDiagnosisRequest,
SubmitTreatmentRequest,
)
from app.services.evaluation_service import EvaluationService
from app.services.order_service import OrderService
from app.services.pdf_export_service import PdfExportService
from app.services.runtime_memory import runtime_memory
from app.services.session_service import SessionService
from scripts.init_demo_db import init_database
async def run_demo_flow() -> None:
"""完整闭环:验证第一版 Demo 的核心训练链路可跑通。"""
init_database()
ctx = UserContext(user_id="demo_flow_user", tenant_id="demo_tenant", entry_scene="service_test")
with SessionLocal() as db:
case = db.scalar(select(CaseBase).where(CaseBase.title == "支气管肺炎 - 6岁男性患儿"))
assert case is not None
session_service = SessionService(db)
order_service = OrderService(db)
evaluation_service = EvaluationService(db)
pdf_service = PdfExportService(db)
created = session_service.create_session(
ctx,
CreateSessionRequest(
case_id=case.id,
training_type="diagnosis_treatment",
mode="practice",
score_type="percentage",
),
)
db.commit()
assert created.status == "inquiry"
chat = await session_service.chat(ctx, created.session_id, ChatRequest(message="孩子最高体温多少?").message)
db.commit()
assert chat.reply
order = order_service.create_order(created.session_id, ctx.user_id, CreateOrderRequest(item_code="chest_xray").item_code)
db.commit()
assert order.is_key is True
tool_count_before = len([item for item in runtime_memory.get_messages(f"mem:{created.session_code}") if item.get("role") == "tool"])
duplicate_order = order_service.create_order(created.session_id, ctx.user_id, "chest_xray")
db.commit()
tool_count_after = len([item for item in runtime_memory.get_messages(f"mem:{created.session_code}") if item.get("role") == "tool"])
assert duplicate_order.already_ordered is True
assert tool_count_after == tool_count_before
try:
order_service.create_order(created.session_id, ctx.user_id, "not_exists")
except AppError as exc:
assert exc.code == "ORDER_ITEM_NOT_FOUND"
else:
raise AssertionError("invalid order item should raise AppError")
try:
order_service.list_order_items(created.session_id, "another_user")
except AppError as exc:
assert exc.code == "SESSION_NOT_FOUND"
else:
raise AssertionError("cross user access should raise AppError")
status = session_service.complete_inquiry(ctx, created.session_id)
db.commit()
assert status.status == "diagnosis"
diagnosis = session_service.submit_diagnosis(
ctx,
created.session_id,
SubmitDiagnosisRequest(
primary_diagnosis="支气管肺炎",
differential_diagnoses=["毛细支气管炎", "支气管哮喘急性发作"],
diagnosis_basis="发热咳嗽伴喘息,肺部湿啰音,胸片异常,炎症指标升高。",
),
)
db.commit()
assert diagnosis.status == "treatment"
treatment = session_service.submit_treatment(
ctx,
created.session_id,
SubmitTreatmentRequest(
treatment_principle="抗感染、止咳平喘、改善氧合、严密观察病情变化。",
treatment_measures="根据病情选择抗感染治疗,必要时雾化吸入,监测体温、呼吸和血氧。",
risk_plan="关注低氧、呼吸困难加重、持续高热和精神反应差。",
communication="向家属说明病情、用药注意事项和复诊指征。",
follow_up="治疗后复查体温、呼吸情况和必要炎症指标。",
),
)
db.commit()
assert treatment.status == "evaluating"
try:
await session_service.chat(ctx, created.session_id, "治疗后还能问诊吗?")
except AppError as exc:
assert exc.code == "SESSION_STATUS_INVALID"
else:
raise AssertionError("chat after treatment submission should raise AppError")
evaluation = await evaluation_service.create_evaluation(
ctx,
created.session_id,
CreateEvaluationRequest(score_type="percentage"),
)
db.commit()
assert evaluation.total_score > 0
training_record = db.scalar(select(TrainingRecord).where(TrainingRecord.session_id == created.session_id))
assert training_record is not None
assert training_record.external_user_id == ctx.user_id
export = pdf_service.export(evaluation.evaluation_id, ctx.user_id)
db.commit()
assert Path(export.file_path).exists()
assert Path(export.file_path).stat().st_size > 1000
history = evaluation_service.list_history(ctx.user_id)
assert history.items
if __name__ == "__main__":
asyncio.run(run_demo_flow())
print("demo flow test passed")
@@ -0,0 +1,75 @@
import tempfile
from pathlib import Path
import os
import sys
os.environ.setdefault("DATABASE_URL", "sqlite:///./storage/test_import.db")
os.environ.setdefault("RUNTIME_MEMORY_BACKEND", "memory")
os.environ.setdefault("LLM_MOCK_ENABLED", "true")
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
from scripts.import_source_case_sql import ImportValidationError, extract_create_columns, extract_insert_rows, parse_values_clause
def test_extract_insert_rows_maps_by_create_columns() -> None:
"""导入解析:验证 INSERT VALUES 按 CREATE TABLE 字段顺序映射。"""
sql = """
CREATE TABLE `traditional_case` (
`created_at` datetime(6) NOT NULL,
`updated_at` datetime(6) NOT NULL,
`id` bigint NOT NULL AUTO_INCREMENT,
`standard_diagnosis` longtext NOT NULL,
`standard_treatment` longtext NOT NULL,
`guideline_reference` longtext NOT NULL,
`case_id` bigint NOT NULL
) ENGINE=InnoDB;
INSERT INTO `traditional_case` VALUES ('2026-05-28 09:34:08','2026-05-28 09:34:09',1,'支气管肺炎','抗感染','指南',1);
"""
columns = extract_create_columns(sql, "traditional_case")
rows = extract_insert_rows(sql, "traditional_case", columns)
assert rows[0]["standard_diagnosis"] == "支气管肺炎"
assert rows[0]["case_id"] == 1
def test_extract_insert_rows_rejects_broken_sql() -> None:
"""导入解析:验证损坏 SQL 被拒绝,避免半导入。"""
sql = """
CREATE TABLE `traditional_case` (
`created_at` datetime(6) NOT NULL,
`updated_at` datetime(6) NOT NULL,
`id` bigint NOT NULL AUTO_INCREMENT,
`standard_diagnosis` longtext NOT NULL,
`standard_treatment` longtext NOT NULL,
`guideline_reference` longtext NOT NULL,
`case_id` bigint NOT NULL
) ENGINE=InnoDB;
INSERT INTO `traditional_case` VALUES ('2026-05-28','2026-05-28',1,'支气管肺炎'bad,'抗感染','指南',1);
"""
columns = extract_create_columns(sql, "traditional_case")
try:
extract_insert_rows(sql, "traditional_case", columns)
except ImportValidationError:
return
raise AssertionError("broken SQL should be rejected")
def test_parse_values_clause_recovers_misplaced_quote_separator() -> None:
"""导入解析:兼容接口 SQL 中文本字段写成 `文本,'next'` 的引号/逗号错位。"""
rows = parse_values_clause("(1,'标题?,'traditional','medium',NULL,'主诉?,'描述?,6,'male')")
assert rows[0] == [1, "标题?", "traditional", "medium", None, "主诉?", "描述?", 6, "male"]
def test_parse_values_clause_recovers_unclosed_string_before_tuple_end() -> None:
"""导入解析:兼容接口 SQL 最后文本字段缺少结束引号且后接外键的情况。"""
rows = parse_values_clause("(1,'诊断?,'治疗?,'指南?,1)")
assert rows[0] == [1, "诊断?", "治疗?", "指南?", 1]
if __name__ == "__main__":
test_extract_insert_rows_maps_by_create_columns()
test_extract_insert_rows_rejects_broken_sql()
test_parse_values_clause_recovers_misplaced_quote_separator()
test_parse_values_clause_recovers_unclosed_string_before_tuple_end()
print("import source case sql tests passed")