feat: add streaming learning assistant and knowledge base scaffolding

This commit is contained in:
刘金宝
2026-06-10 09:32:36 +08:00
parent f0cdc454b3
commit 89258ab448
31 changed files with 2021 additions and 330 deletions
+79 -3
View File
@@ -10,6 +10,10 @@ os.environ["REPORT_STORAGE_DIR"] = str(Path(tempfile.gettempdir()) / "medical_ag
# Test-run defaults: setdefault only writes a value when the variable is not
# already present in the environment, so externally provided settings win.
os.environ.setdefault("RUNTIME_MEMORY_BACKEND", "memory")
os.environ.setdefault("LLM_MOCK_ENABLED", "true")
os.environ.setdefault("AUTH_USER_ME_URL", "http://django-user-center.test/api/user/users/me/")
# Knowledge-base settings for the test run; the "mock" values presumably select
# in-process stand-ins for Milvus and the embedding provider — TODO confirm.
os.environ.setdefault("MILVUS_URI", "mock://milvus")
os.environ.setdefault("EMBEDDING_PROVIDER", "mock")
# NOTE(review): looks like this makes knowledge ingestion run synchronously — confirm.
os.environ.setdefault("KNOWLEDGE_INGESTION_SYNC", "true")
# Default the knowledge storage directory to a folder under the system temp directory.
os.environ.setdefault("KNOWLEDGE_STORAGE_DIR", str(Path(tempfile.gettempdir()) / "medical_agent_test_knowledge"))
# Put the parent of this file's directory at the front of the import path
# (presumably the project root, so the `app` package resolves — verify).
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
@@ -36,20 +40,33 @@ def run_api_contract_tests() -> None:
from app.core.exceptions import AppError
raise AppError("AUTH_CREDENTIAL_REQUIRED", "Authorization header is required", 401)
user_id = "api_user_002" if "api_user_002_token" in authorization else "api_user_001"
if "content_admin_token" in authorization:
user_id = "content_admin_001"
role = "content_admin"
else:
user_id = "api_user_002" if "api_user_002_token" in authorization else "api_user_001"
role = "student"
return AuthenticatedUser(
user_id=user_id,
username=f"{user_id}_name",
display_name="Swagger测试",
role="student",
role=role,
tenant_id="1",
institution_id=1,
institution_name="test institution",
department_id=1,
department_name="pediatrics",
status=1,
profile={
"id": user_id,
"username": f"{user_id}_name",
"real_name": "Swagger测试",
"role_type": "student",
"role_type": role,
"institution": 1,
"institution_id": 1,
"department": 1,
"department_id": 1,
"department_name": "pediatrics",
"institution_name": "测试机构",
"status": 1,
},
@@ -97,6 +114,65 @@ def run_api_contract_tests() -> None:
assert "/api/v1/evaluations/{evaluation_id}/download-pdf" in openapi_payload["paths"]
assert "/api/v1/teaching/cases/{case_id}/items" in openapi_payload["paths"]
assert "/api/v1/teaching/evaluation" in openapi_payload["paths"]
assert "/api/v1/knowledge-admin/documents/upload" in openapi_payload["paths"]
assert "/api/v1/learning-assistant/chat" not in openapi_payload["paths"]
assert "/api/v1/learning-assistant/chat/stream" in openapi_payload["paths"]
with client.stream(
"POST",
"/api/v1/learning-assistant/chat/stream",
headers=headers,
json={"question": "支气管肺炎有哪些常见表现?", "top_k": 1},
) as no_kb_stream:
assert no_kb_stream.status_code == 200
no_kb_stream_text = "".join(no_kb_stream.iter_text())
assert "event: retrieval_done" in no_kb_stream_text
assert '"retrieval_hit": false' in no_kb_stream_text
assert "event: answer_delta" in no_kb_stream_text
assert "event: answer_done" in no_kb_stream_text
assert "event: error" not in no_kb_stream_text
from app.integrations.pdf_parser import ParsedPdfPage, PdfParser
def fake_pdf_parse(self, file_path):  # noqa: ARG001
    """Knowledge-base test stub: return fixed page text instead of parsing a real PDF.

    Keeps the test independent of external files. Both arguments are
    ignored; the result is always a one-element list holding a single
    ParsedPdfPage for page 12.
    """
    stub_page = ParsedPdfPage(
        page_number=12,
        text="支气管肺炎常见表现包括发热、咳嗽、喘息和肺部湿啰音。血氧饱和度下降提示病情可能加重。",
    )
    return [stub_page]
PdfParser.parse = fake_pdf_parse
admin_headers = {"Authorization": "Bearer content_admin_token", "X-Entry-Scene": "api_test"}
upload = client.post(
"/api/v1/knowledge-admin/documents/upload",
headers=admin_headers,
data={"document_title": "诊断学第十版", "document_category": "textbook", "version": "v1"},
files={"file": ("diagnostics.pdf", b"%PDF-1.4\n%%EOF", "application/pdf")},
)
assert upload.status_code == 200
assert upload.json()["data"]["status"] == "ready"
assert upload.json()["data"]["chunk_count"] >= 1
document_id = upload.json()["data"]["document_id"]
document_detail = client.get(f"/api/v1/knowledge-admin/documents/{document_id}", headers=admin_headers)
assert document_detail.status_code == 200
assert document_detail.json()["data"]["document_title"] == "诊断学第十版"
with client.stream(
"POST",
"/api/v1/learning-assistant/chat/stream",
headers=headers,
json={"question": "血氧下降说明什么?", "top_k": 1},
) as rag_stream:
assert rag_stream.status_code == 200
rag_stream_text = "".join(rag_stream.iter_text())
assert "event: retrieval_done" in rag_stream_text
assert '"retrieval_hit": true' in rag_stream_text
assert '"page_start": 12' in rag_stream_text
assert "event: answer_delta" in rag_stream_text
assert "event: answer_done" in rag_stream_text
cases = client.get("/api/v1/cases", headers=headers)
assert cases.status_code == 200