diff --git a/.env.example b/.env.example index 96f880e..8e84798 100644 --- a/.env.example +++ b/.env.example @@ -43,3 +43,29 @@ LLM_SCORING_JSON_RESPONSE=true LLM_SCORING_MAX_TOKENS=4096 REPORT_STORAGE_DIR=./storage/reports + +# Knowledge base / Milvus / RAG +MILVUS_URI=http://milvus-standalone:19530 +MILVUS_COLLECTION_PREFIX=kb_inst +MILVUS_DEFAULT_DB=default +CELERY_BROKER_URL=redis://127.0.0.1:6379/1 +CELERY_RESULT_BACKEND=redis://127.0.0.1:6379/2 +KNOWLEDGE_INGESTION_SYNC=true +KNOWLEDGE_STORAGE_DIR=./storage/knowledge +KNOWLEDGE_MAX_UPLOAD_MB=50 + +# Embedding +EMBEDDING_PROVIDER=openai_compatible +EMBEDDING_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1 +EMBEDDING_API_KEY= +EMBEDDING_MODEL=text-embedding-v4 +EMBEDDING_DIM=1024 +EMBEDDING_BATCH_SIZE=16 +EMBEDDING_TIMEOUT_SECONDS=30 + +# Learning assistant RAG +RAG_TOP_N=20 +RAG_TOP_K=5 +RAG_SCORE_THRESHOLD=0.35 +RAG_QUERY_REWRITE_ENABLED=false +RAG_RERANK_ENABLED=false diff --git a/.env.production.example b/.env.production.example index cd6ab6c..106d389 100644 --- a/.env.production.example +++ b/.env.production.example @@ -43,3 +43,29 @@ LLM_SCORING_JSON_RESPONSE=true LLM_SCORING_MAX_TOKENS=4096 REPORT_STORAGE_DIR=/app/storage/reports + +# Knowledge base / Milvus / RAG +MILVUS_URI=http://milvus-standalone:19530 +MILVUS_COLLECTION_PREFIX=kb_inst +MILVUS_DEFAULT_DB=default +CELERY_BROKER_URL=redis://redis:6379/1 +CELERY_RESULT_BACKEND=redis://redis:6379/2 +KNOWLEDGE_INGESTION_SYNC=false +KNOWLEDGE_STORAGE_DIR=/app/storage/knowledge +KNOWLEDGE_MAX_UPLOAD_MB=50 + +# Embedding +EMBEDDING_PROVIDER=openai_compatible +EMBEDDING_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1 +EMBEDDING_API_KEY=CHANGE_ME +EMBEDDING_MODEL=text-embedding-v4 +EMBEDDING_DIM=1024 +EMBEDDING_BATCH_SIZE=16 +EMBEDDING_TIMEOUT_SECONDS=30 + +# Learning assistant RAG +RAG_TOP_N=20 +RAG_TOP_K=5 +RAG_SCORE_THRESHOLD=0.35 +RAG_QUERY_REWRITE_ENABLED=false +RAG_RERANK_ENABLED=false diff --git a/README.md b/README.md index 
fb46a24..44cae92 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,10 @@ # 医疗问诊 Agent FastAPI 后端 -医疗问诊 Agent 是医疗教学平台中的训练服务。后端负责 Django 用户鉴权、病例读取、训练会话、流式问诊、练习提示、检查结果、诊断治疗提交、AI 评价、教学互动评价、训练记录和 PDF 下载。 +医疗问诊 Agent 是医疗教学平台中的 FastAPI 后端服务,负责 Django 用户鉴权、病例读取、问诊训练、教学互动、AI 评价、PDF 下载、AI 学习助手问答,以及后台预留的机构知识库构建能力。 -病例新增、病例解析、病例导入和病例删除不在本服务中实现。本服务只读取数据库中已经维护好的病例、检查项、教学题和评分规则。 +本服务不负责登录注册、病例 PDF 解析入库、病例增删改、多租户后台、HIS/LIS/PACS 对接。病例、检查项、教学题和评分规则由平台数据库维护,本服务只读取并使用。 -## 当前保留功能 +## 当前功能 训练页面: @@ -36,12 +36,27 @@ - 训练记录列表 - 训练记录详情 +AI 学习助手: + +- 普通用户通过流式接口提问 +- 后端优先检索本机构知识库 +- 未命中知识库或知识库暂不可用时,自动转为大模型通用学习回答 +- 命中知识库时返回 PDF 标题、页码、chunk_uid 和引用片段 + +后台预留能力: + +- 内容管理员上传 PDF 构建机构知识库 +- PDF 解析、分片、Embedding、Milvus 写入和 Celery 异步任务已留出接口 +- 当前阶段不要求真实 PDF 入库测试,优先保证 AI 学习助手问答可用 + 基础能力: - Django access token 鉴权 -- MySQL 数据读取和训练记录写入 +- MySQL 业务数据读取和训练记录写入 - Redis 短期会话 memory - OpenAI-compatible LLM 调用 +- OpenAI-compatible Embedding 调用预留 +- Milvus 向量检索预留 - Swagger / OpenAPI - 健康检查 @@ -50,8 +65,16 @@ ```text fastapi/ ├── app/ # FastAPI 应用、Agent、服务、模型和提示词 -├── scripts/ # 数据库初始化和检查脚本 -├── tests/ # 当前功能测试 +│ ├── api/ # API router +│ ├── agents/ # LLM Agent +│ ├── integrations/ # PDF、Embedding、Milvus 外部适配 +│ ├── models/ # SQLAlchemy ORM +│ ├── repositories/ # 数据访问层 +│ ├── schemas/ # Pydantic schema +│ ├── services/ # 业务服务 +│ └── tasks/ # Celery 异步任务预留 +├── scripts/ # 初始化和维护脚本 +├── tests/ # 自动化测试 ├── docs/03_api_design.md # 前端联调 API 文档 ├── Dockerfile ├── requirements.txt @@ -75,7 +98,39 @@ uvicorn app.main:app --host 127.0.0.1 --port 9000 http://127.0.0.1:9000/docs ``` -真实密码、LLM Key 和 access token 只写入本地 `.env` 或服务器环境变量,不提交到 Git。 +真实密码、LLM Key、Embedding Key 和 access token 只写入本地 `.env` 或服务器环境变量,不提交到 Git。 + +## 关键环境变量 + +```env +APP_ENV=local +APP_ROOT_PATH= +DATABASE_URL=mysql+pymysql://root:@127.0.0.1:3306/medical?charset=utf8mb4 +RUNTIME_MEMORY_BACKEND=redis +REDIS_URL=redis://127.0.0.1:6379/0 + +AUTH_VALIDATE_ENABLED=true +AUTH_USER_ME_URL=http://127.0.0.1:8000/api/user/users/me/ + 
+LLM_BASE_URL=https://api.deepseek.com/chat/completions +LLM_API_KEY= +LLM_MODEL=deepseek-chat +LLM_FAST_MODEL=deepseek-chat +LLM_REASON_MODEL=deepseek-reasoner + +MILVUS_URI=http://127.0.0.1:19530 +MILVUS_COLLECTION_PREFIX=kb_inst +MILVUS_DEFAULT_DB=default + +EMBEDDING_PROVIDER=openai_compatible +EMBEDDING_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1 +EMBEDDING_API_KEY= +EMBEDDING_MODEL=text-embedding-v4 +EMBEDDING_DIM=1024 + +KNOWLEDGE_INGESTION_SYNC=true +KNOWLEDGE_STORAGE_DIR=./storage/knowledge +``` ## 服务器部署 @@ -87,6 +142,8 @@ http://127.0.0.1:9000/docs ├── fastapi/ ├── vueapp/ ├── vuecms/ +├── data/ +├── logs/ └── docker-compose.yml ``` @@ -99,25 +156,7 @@ cp fastapi/.env.production.example fastapi/.env vi fastapi/.env ``` -服务器 `.env` 至少配置: - -```env -APP_ENV=production -APP_ROOT_PATH=/fastapi -DATABASE_URL=mysql+pymysql://root:1822..@mysql:3306/medical?charset=utf8mb4 -REDIS_URL=redis://redis:6379/0 -RUNTIME_MEMORY_BACKEND=redis -AUTH_VALIDATE_ENABLED=true -AUTH_USER_ME_URL=http://django:8000/api/user/users/me/ -LLM_BASE_URL=<模型服务地址> -LLM_API_KEY=<模型密钥> -LLM_MODEL=<模型名称> -LLM_FAST_MODEL=<模型名称> -LLM_REASON_MODEL=<模型名称> -CORS_ALLOW_ORIGINS=http://8.160.178.88 -``` - -构建并启动: +FastAPI 构建和启动: ```bash cd /home/code/medical-ai @@ -134,6 +173,7 @@ git pull origin main cd .. 
class LearningAssistantAgent:
    """Learning-assistant agent: answers medical study questions, citing RAG sources.

    When retrieval produced sources, the prompt asks the model to ground the
    answer in those snippets and to tag each key conclusion with a source
    number. When there are no sources, the prompt forces a fixed opening
    disclaimer and forbids invented PDF / page / guideline citations.

    Args:
        llm_client: Client used for chat calls; a default
            ``OpenAICompatibleLLMClient`` is created when omitted.
        max_tokens: Completion budget applied to both the streaming and the
            non-streaming path.
    """

    # Single source of truth for the completion budget of both answer paths.
    DEFAULT_MAX_TOKENS = 1200

    def __init__(
        self,
        llm_client: OpenAICompatibleLLMClient | None = None,
        *,
        max_tokens: int = DEFAULT_MAX_TOKENS,
    ) -> None:
        # Compare with None explicitly so an injected client is never replaced
        # just because it happens to evaluate as falsy.
        self.llm_client = OpenAICompatibleLLMClient() if llm_client is None else llm_client
        self.max_tokens = max_tokens

    def _generation_options(self) -> dict:
        """Return the model options shared by ``answer`` and ``stream_answer``."""
        return {
            "model": settings.llm_fast_model,
            "thinking_enabled": settings.llm_fast_thinking_enabled,
            "max_tokens": self.max_tokens,
        }

    async def answer(self, question: str, sources: list[LearningAssistantSource]) -> LLMResponse:
        """Non-streaming answer: build the prompt and call the fast model once."""
        return await self.llm_client.chat(
            self._messages(question, sources),
            **self._generation_options(),
        )

    async def stream_answer(self, question: str, sources: list[LearningAssistantSource]) -> AsyncIterator[LLMStreamChunk]:
        """Streaming answer: yield the model's incremental chunks unchanged."""
        async for chunk in self.llm_client.stream_chat(
            self._messages(question, sources),
            **self._generation_options(),
        ):
            yield chunk

    def _messages(self, question: str, sources: list[LearningAssistantSource]) -> list[dict]:
        """Build the system/user message pair for the chat call.

        With sources, every snippet is rendered with its title (falling back
        to the file name), page range and chunk UID so the model can cite it.
        Without sources, the system prompt requires the fixed "no
        institutional reference" opening sentence.
        """
        if sources:
            context = "\n\n".join(
                (
                    f"[来源{index}] 文档:{source.document_title or source.file_name};"
                    f"页码:{source.page_start}-{source.page_end};chunk_uid:{source.chunk_uid}\n"
                    f"{source.quote}"
                )
                for index, source in enumerate(sources, start=1)
            )
            system = (
                "你是医学学习助手,只用于医学教育学习,不替代临床诊疗。"
                "请优先依据给定知识库片段回答,回答要清晰、准确、分点。"
                "每个关键结论后标注对应来源编号,例如【来源1】。"
                "不得编造不存在的PDF、页码或指南来源。"
            )
            user = f"用户问题:{question}\n\n可用知识库片段:\n{context}\n\n请给出带来源的学习回答。"
        else:
            system = (
                "你是医学学习助手,只用于医学教育学习,不替代临床诊疗。"
                "当前没有检索到机构知识库参考,回答开头必须写:未检索到本机构知识库参考,以下为大模型通用学习回答。"
                "不得伪造PDF来源、页码或指南名称。"
            )
            user = f"用户问题:{question}\n\n请给出通用学习回答,并提醒用户以课程教材和临床规范为准。"
        return [{"role": "system", "content": system}, {"role": "user", "content": user}]
@router.get("/documents", response_model=ApiResponse[KnowledgeDocumentListResponse])
def list_knowledge_documents(
    ctx: UserContext = Depends(get_user_context),
    db: Session = Depends(get_db),
):
    """知识库文档列表:内容管理员查看本机构已上传文档。"""
    # The docstring above is kept verbatim: FastAPI publishes route docstrings
    # as the operation description in the generated OpenAPI schema.
    repo = KnowledgeBaseRepository(db)
    # One service instance serves both checks; the original built it twice.
    space_service = KnowledgeSpaceService(repo)
    # Authorise before resolving the institution so a non-admin caller is
    # rejected first, exactly as in the original call order.
    space_service.ensure_content_admin(ctx)
    institution_id = space_service.require_institution_id(ctx)
    # Only documents belonging to the caller's institution are listed.
    items = [_to_detail(item) for item in repo.list_documents(institution_id)]
    return ok(KnowledgeDocumentListResponse(items=items))
@router.post("/chat/stream", response_class=StreamingResponse)
async def learning_assistant_stream_chat(
    payload: LearningAssistantChatRequest,
    ctx: UserContext = Depends(get_user_context),
    db: Session = Depends(get_db),
):
    """AI 学习助手流式问答:返回 retrieval_done、answer_delta、answer_done 事件。"""
    # The docstring above is kept verbatim: FastAPI publishes route docstrings
    # as the operation description in the generated OpenAPI schema.
    stream = LearningAssistantService(db).stream_chat(ctx, payload)
    # Kept from the original: persists anything stream_chat did eagerly before
    # handing back the iterator.
    db.commit()

    async def _stream_then_commit():
        """Forward every event, then persist what the service wrote while streaming.

        The response body is only consumed after this handler returns, so the
        commit above cannot cover work done during streaming; without a
        second commit those writes are never committed by this route.
        """
        try:
            async for event in stream:
                yield event
        except Exception:
            # Do not leave a half-finished transaction on the session.
            db.rollback()
            raise
        else:
            db.commit()

    # NOTE(review): whether the get_db session is still open while the body
    # streams depends on when the framework runs yield-dependency teardown;
    # confirm against the FastAPI version pinned in requirements.txt.
    return StreamingResponse(
        _stream_then_commit(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"},
    )
"300"))) + milvus_uri: str = Field(default_factory=lambda: os.getenv("MILVUS_URI", "http://milvus-standalone:19530")) + milvus_collection_prefix: str = Field(default_factory=lambda: os.getenv("MILVUS_COLLECTION_PREFIX", "kb_inst")) + milvus_default_db: str = Field(default_factory=lambda: os.getenv("MILVUS_DEFAULT_DB", "default")) + celery_broker_url: str = Field(default_factory=lambda: os.getenv("CELERY_BROKER_URL", "redis://redis:6379/1")) + celery_result_backend: str = Field(default_factory=lambda: os.getenv("CELERY_RESULT_BACKEND", "redis://redis:6379/2")) + knowledge_ingestion_sync: bool = Field(default_factory=lambda: _env_bool("KNOWLEDGE_INGESTION_SYNC", False)) + knowledge_storage_dir: str = Field(default_factory=lambda: os.getenv("KNOWLEDGE_STORAGE_DIR", "./storage/knowledge")) + knowledge_max_upload_mb: int = Field(default_factory=lambda: int(os.getenv("KNOWLEDGE_MAX_UPLOAD_MB", "50"))) + embedding_provider: str = Field(default_factory=lambda: os.getenv("EMBEDDING_PROVIDER", "openai_compatible")) + embedding_base_url: str = Field(default_factory=lambda: _env_first("EMBEDDING_BASE_URL", "LLM_BASE_URL", default="")) + embedding_api_key: str = Field(default_factory=lambda: _env_first("EMBEDDING_API_KEY", "LLM_API_KEY", default="")) + embedding_model: str = Field(default_factory=lambda: os.getenv("EMBEDDING_MODEL", "text-embedding-v4")) + embedding_dim: int = Field(default_factory=lambda: int(os.getenv("EMBEDDING_DIM", "1024"))) + embedding_batch_size: int = Field(default_factory=lambda: int(os.getenv("EMBEDDING_BATCH_SIZE", "16"))) + embedding_timeout_seconds: int = Field(default_factory=lambda: int(os.getenv("EMBEDDING_TIMEOUT_SECONDS", "30"))) + rag_top_n: int = Field(default_factory=lambda: int(os.getenv("RAG_TOP_N", "20"))) + rag_top_k: int = Field(default_factory=lambda: int(os.getenv("RAG_TOP_K", "5"))) + rag_score_threshold: float = Field(default_factory=lambda: float(os.getenv("RAG_SCORE_THRESHOLD", "0.35"))) + rag_query_rewrite_enabled: bool = 
Field(default_factory=lambda: _env_bool("RAG_QUERY_REWRITE_ENABLED", False)) + rag_rerank_enabled: bool = Field(default_factory=lambda: _env_bool("RAG_RERANK_ENABLED", False)) @property def is_production(self) -> bool: @@ -137,6 +157,8 @@ class Settings(BaseModel): errors.append("AUTH_USER_ME_URL is required") if self.llm_api_key in {"", "CHANGE_ME"} and not self.llm_mock_enabled: errors.append("LLM_API_KEY is required when mock mode is disabled") + if self.embedding_api_key in {"", "CHANGE_ME"} and not self.llm_mock_enabled: + errors.append("EMBEDDING_API_KEY is required when mock mode is disabled") return errors def as_public_dict(self) -> dict[str, Any]: @@ -157,6 +179,10 @@ class Settings(BaseModel): "llm_reasoning_effort": self.llm_reasoning_effort, "llm_fast_max_tokens": self.llm_fast_max_tokens, "runtime_memory_backend": self.runtime_memory_backend, + "learning_assistant": True, + "knowledge_admin": True, + "milvus_enabled": bool(self.milvus_uri), + "embedding_model": self.embedding_model, "auth_validate_enabled": True, "auth_source": "django_user_center", } diff --git a/app/integrations/__init__.py b/app/integrations/__init__.py new file mode 100644 index 0000000..de8e180 --- /dev/null +++ b/app/integrations/__init__.py @@ -0,0 +1 @@ +"""外部能力适配层:封装 PDF 解析、Embedding 和 Milvus 等可替换基础设施。""" diff --git a/app/integrations/embedding_adapter.py b/app/integrations/embedding_adapter.py new file mode 100644 index 0000000..32164eb --- /dev/null +++ b/app/integrations/embedding_adapter.py @@ -0,0 +1,77 @@ +import hashlib +import math +from dataclasses import dataclass + +import httpx + +from app.core.config import settings +from app.core.exceptions import AppError + + +@dataclass(frozen=True) +class EmbeddingUsage: + """Embedding 调用指标:记录批量向量化的模型和 token 用量。""" + + model: str + total_tokens: int | None = None + + +class OpenAICompatibleEmbeddingClient: + """Embedding Adapter:封装 OpenAI-compatible embeddings 接口,并提供稳定 mock。""" + + @property + def is_mock_mode(self) -> bool: 
async def embed_texts(self, texts: list[str]) -> tuple[list[list[float]], EmbeddingUsage]:
    """Embed a batch of texts and return vectors in input order.

    Args:
        texts: Texts to embed; an empty list short-circuits without any call.

    Returns:
        ``(vectors, usage)`` where ``vectors[i]`` belongs to ``texts[i]``.

    Raises:
        AppError: ``EMBEDDING_CALL_FAILED`` (502) when the HTTP call fails or
            the body is not JSON; ``EMBEDDING_RESPONSE_INVALID`` (502) when
            the payload is malformed, has the wrong number of vectors, or a
            vector has the wrong dimension.
    """
    if not texts:
        return [], EmbeddingUsage(model=settings.embedding_model, total_tokens=0)
    if self.is_mock_mode:
        return [self._mock_vector(text) for text in texts], EmbeddingUsage(model=f"mock-{settings.embedding_model}")

    try:
        async with httpx.AsyncClient(timeout=settings.embedding_timeout_seconds) as client:
            resp = await client.post(
                self._embeddings_url(),
                headers={"Authorization": f"Bearer {settings.embedding_api_key}"},
                json={"model": settings.embedding_model, "input": texts},
            )
            resp.raise_for_status()
            payload = resp.json()
    # httpx.HTTPError already covers timeouts and status errors; ValueError
    # covers a body that is not valid JSON.
    except (httpx.HTTPError, ValueError) as exc:
        raise AppError("EMBEDDING_CALL_FAILED", "embedding service call failed", 502) from exc

    try:
        # Restore input order from the per-item "index" field.
        ordered = sorted(payload["data"], key=lambda item: item.get("index", 0))
        vectors = [item["embedding"] for item in ordered]
        # A short or long response would silently pair texts with the wrong
        # vectors downstream, so reject it here.
        if len(vectors) != len(texts):
            raise ValueError(f"embedding count mismatch: {len(vectors)} != {len(texts)}")
        self._validate_vectors(vectors)
        usage = payload.get("usage") or {}
        return vectors, EmbeddingUsage(model=settings.embedding_model, total_tokens=usage.get("total_tokens"))
    # AttributeError: a non-dict entry in "data" has no .get / item access.
    except (KeyError, TypeError, ValueError, AttributeError) as exc:
        raise AppError("EMBEDDING_RESPONSE_INVALID", "embedding response format invalid", 502) from exc
@dataclass(frozen=True)
class VectorSearchHit:
    """Vector search hit: carries only the chunk UID and its similarity score."""

    chunk_uid: str
    score: float


class MilvusVectorStore:
    """Milvus adapter: writes and searches chunk vectors, one collection per caller-supplied name.

    A ``mock://`` Milvus URI switches every operation to an in-memory store so
    the code can run without a Milvus server.
    """

    # In-memory stand-in for Milvus in mock mode. Declared on the class, so
    # all instances in the process share the same data.
    _mock_store: dict[str, dict[str, list[float]]] = {}

    def __init__(self) -> None:
        self.mock_enabled = settings.milvus_uri.startswith("mock://")
        # The real client is created lazily by _client_or_raise().
        self._client = None

    def ensure_collection(self, collection_name: str) -> None:
        """Create the collection if missing (VARCHAR primary key + FLOAT_VECTOR, COSINE index).

        Raises:
            AppError: ``MILVUS_COLLECTION_INIT_FAILED`` (502) on any Milvus
                error, plus whatever ``_client_or_raise`` raises.
        """
        if self.mock_enabled:
            self._mock_store.setdefault(collection_name, {})
            return
        client = self._client_or_raise()
        try:
            if client.has_collection(collection_name=collection_name):
                return
            schema = client.create_schema(auto_id=False, enable_dynamic_field=False)
            from pymilvus import DataType

            schema.add_field(field_name="id", datatype=DataType.VARCHAR, is_primary=True, max_length=128)
            schema.add_field(field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=settings.embedding_dim)
            index_params = client.prepare_index_params()
            index_params.add_index(field_name="vector", index_type="AUTOINDEX", metric_type="COSINE")
            client.create_collection(
                collection_name=collection_name,
                schema=schema,
                index_params=index_params,
                consistency_level="Strong",
            )
        except Exception as exc:  # pragma: no cover - real Milvus is exercised in integration environments
            raise AppError("MILVUS_COLLECTION_INIT_FAILED", "milvus collection init failed", 502) from exc

    def upsert_vectors(self, collection_name: str, vectors: list[tuple[str, list[float]]]) -> None:
        """Upsert ``(chunk_uid, vector)`` pairs; the chunk UID is the primary key, so rebuilds overwrite.

        Raises:
            AppError: ``MILVUS_UPSERT_FAILED`` (502) when the write fails.
        """
        if not vectors:
            return
        self.ensure_collection(collection_name)
        if self.mock_enabled:
            collection = self._mock_store.setdefault(collection_name, {})
            for chunk_uid, vector in vectors:
                collection[chunk_uid] = vector
            return
        client = self._client_or_raise()
        try:
            client.upsert(
                collection_name=collection_name,
                data=[{"id": chunk_uid, "vector": vector} for chunk_uid, vector in vectors],
            )
        except Exception as exc:  # pragma: no cover
            raise AppError("MILVUS_UPSERT_FAILED", "milvus vector upsert failed", 502) from exc

    def search(self, collection_name: str, query_vector: list[float], limit: int) -> list[VectorSearchHit]:
        """Return up to ``limit`` candidate hits by cosine similarity; no score threshold is applied here.

        Raises:
            AppError: ``MILVUS_SEARCH_FAILED`` (502) when the search fails.
        """
        # Note: this creates the collection when it does not exist yet, so a
        # search against a never-populated collection yields no hits instead
        # of an error.
        self.ensure_collection(collection_name)
        if self.mock_enabled:
            return self._mock_search(collection_name, query_vector, limit)
        client = self._client_or_raise()
        try:
            results = client.search(
                collection_name=collection_name,
                data=[query_vector],
                anns_field="vector",
                limit=limit,
                search_params={"metric_type": "COSINE"},
                output_fields=["id"],
            )
        except Exception as exc:  # pragma: no cover
            raise AppError("MILVUS_SEARCH_FAILED", "milvus vector search failed", 502) from exc

        hits: list[VectorSearchHit] = []
        # One query vector was sent, so only the first result list matters.
        for item in results[0] if results else []:
            entity = item.get("entity") or {}
            chunk_uid = str(entity.get("id") or item.get("id") or "")
            if chunk_uid:
                hits.append(VectorSearchHit(chunk_uid=chunk_uid, score=float(item.get("distance") or item.get("score") or 0)))
        return hits

    def _client_or_raise(self):
        """Return the cached Milvus client, creating it on first use.

        pymilvus is imported lazily so code paths that never touch the
        knowledge base do not need it.

        Raises:
            AppError: ``MILVUS_CLIENT_NOT_INSTALLED`` (500) when pymilvus is
                missing; ``MILVUS_CONNECT_FAILED`` (502) when the client
                cannot be constructed.
        """
        if self._client is not None:
            return self._client
        try:
            from pymilvus import MilvusClient
        except ImportError as exc:
            raise AppError("MILVUS_CLIENT_NOT_INSTALLED", "pymilvus is required for vector search", 500) from exc
        try:
            self._client = MilvusClient(uri=settings.milvus_uri, db_name=settings.milvus_default_db)
        except Exception as exc:  # pragma: no cover - needs an unreachable Milvus to trigger
            # Surface construction failures as AppError, like every other
            # Milvus failure in this adapter, instead of a raw driver error.
            raise AppError("MILVUS_CONNECT_FAILED", "milvus connection failed", 502) from exc
        return self._client

    def _mock_search(self, collection_name: str, query_vector: list[float], limit: int) -> list[VectorSearchHit]:
        """Mock search: rank stored vectors by dot product rescaled as ``(dot + 1) / 2``."""
        collection = self._mock_store.get(collection_name, {})
        scored = [
            VectorSearchHit(chunk_uid=chunk_uid, score=(sum(a * b for a, b in zip(query_vector, vector)) + 1.0) / 2.0)
            for chunk_uid, vector in collection.items()
        ]
        return sorted(scored, key=lambda item: item.score, reverse=True)[:limit]
cover - PyMuPDF 异常类型较多,统一转换即可 + raise AppError("PDF_PARSE_FAILED", "pdf parse failed", 422) from exc + if not pages: + raise AppError("PDF_PARSE_EMPTY", "pdf text content is empty", 422) + return pages + + def _clean_text(self, text: str) -> str: + """文本清洗:压缩空白并保留自然换行,便于后续教材分片。""" + lines = [" ".join(line.strip().split()) for line in text.splitlines()] + return "\n".join(line for line in lines if line) diff --git a/app/models/__init__.py b/app/models/__init__.py index a8ed056..27fffe7 100644 --- a/app/models/__init__.py +++ b/app/models/__init__.py @@ -1,8 +1,15 @@ -"""ORM 模型导出:初始化数据库时只导入当前新表体系需要的模型。""" +"""ORM 模型导出:初始化数据库时只导入当前新表体系需要的模型。""" from app.models.audit import AuditLog from app.models.department import Department from app.models.knowledge import KnowledgeChunk, KnowledgeDocument, KnowledgeSource +from app.models.knowledge_base import ( + KbKnowledgeChunk, + KbKnowledgeDocument, + KbKnowledgeIngestionTask, + KbKnowledgeQueryLog, + KbKnowledgeSpace, +) from app.models.prompt import PromptTemplate from app.models.source_case import CaseBase, CaseExamItem, ScoringRule, TeachingCase, TraditionalCase from app.models.training import SessionOrder, SessionSubmission, TrainingSession @@ -15,6 +22,11 @@ __all__ = [ "KnowledgeChunk", "KnowledgeDocument", "KnowledgeSource", + "KbKnowledgeSpace", + "KbKnowledgeDocument", + "KbKnowledgeChunk", + "KbKnowledgeIngestionTask", + "KbKnowledgeQueryLog", "PromptTemplate", "CaseBase", "CaseExamItem", diff --git a/app/models/knowledge_base.py b/app/models/knowledge_base.py new file mode 100644 index 0000000..e7b5783 --- /dev/null +++ b/app/models/knowledge_base.py @@ -0,0 +1,116 @@ +from datetime import datetime + +from sqlalchemy import DateTime, Float, Integer, JSON, String, Text, UniqueConstraint +from sqlalchemy.orm import Mapped, mapped_column + +from app.db.base import Base +from app.models.mixins import TimestampMixin + + +class KbKnowledgeSpace(TimestampMixin, Base): + """知识空间模型:记录机构与 Milvus collection、embedding 
class KbKnowledgeDocument(TimestampMixin, Base):
    """Knowledge document: one uploaded PDF and its processing state."""

    __tablename__ = "kb_documents"
    # The same file content (by SHA-256) can be stored only once per institution.
    __table_args__ = (UniqueConstraint("institution_id", "file_sha256", name="uq_kb_document_institution_sha"),)

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    # Owning institution; indexed.
    institution_id: Mapped[int] = mapped_column(Integer, nullable=False, index=True, comment="机构ID")
    # ID of the uploading user, stored as a string.
    uploaded_by: Mapped[str] = mapped_column(String(128), nullable=False, index=True, comment="上传用户ID")
    # Original file name as uploaded.
    file_name: Mapped[str] = mapped_column(String(255), nullable=False, comment="原始文件名")
    # 64 characters: hex-encoded SHA-256 of the file.
    file_sha256: Mapped[str] = mapped_column(String(64), nullable=False, index=True, comment="文件SHA256")
    file_type: Mapped[str] = mapped_column(String(32), nullable=False, default="pdf", comment="文件类型")
    # presumably bytes — TODO confirm against the upload service
    file_size: Mapped[int] = mapped_column(Integer, nullable=False, comment="文件大小")
    # Where the uploaded file was saved.
    file_path: Mapped[str] = mapped_column(String(512), nullable=False, comment="文件保存路径")
    document_title: Mapped[str | None] = mapped_column(String(255), comment="文档标题")
    # NOTE(review): the ORM default is "other" while the upload endpoint's form
    # default is "textbook" — confirm which one is intended.
    document_category: Mapped[str] = mapped_column(String(64), nullable=False, default="other", comment="文档分类")
    version: Mapped[str] = mapped_column(String(32), nullable=False, default="v1", comment="文档版本")
    # Overall processing status; starts as "uploaded".
    status: Mapped[str] = mapped_column(String(32), nullable=False, default="uploaded", index=True, comment="处理状态")
    # Parsing and embedding are tracked separately; both start as "pending".
    parse_status: Mapped[str] = mapped_column(String(32), nullable=False, default="pending", comment="解析状态")
    embedding_status: Mapped[str] = mapped_column(String(32), nullable=False, default="pending", comment="向量状态")
    # Number of chunks produced for this document.
    chunk_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0, comment="分片数量")
    # Error text for a failed run; nullable.
    error_message: Mapped[str | None] = mapped_column(Text, comment="错误信息")
comment="结束页") + section_title: Mapped[str | None] = mapped_column(String(255), comment="章节标题") + chunk_text: Mapped[str] = mapped_column(Text, nullable=False, comment="分片文本") + chunk_hash: Mapped[str] = mapped_column(String(64), nullable=False, index=True, comment="分片Hash") + token_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0, comment="估算token数") + vector_id: Mapped[str] = mapped_column(String(128), nullable=False, index=True, comment="Milvus向量ID") + collection_name: Mapped[str] = mapped_column(String(128), nullable=False, index=True, comment="Milvus集合名") + embedding_model: Mapped[str] = mapped_column(String(128), nullable=False, comment="向量模型") + metadata_: Mapped[dict | None] = mapped_column("metadata", JSON, comment="扩展元数据") + created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow, index=True) + + +class KbKnowledgeIngestionTask(TimestampMixin, Base): + """知识入库任务模型:记录 PDF 解析、分片、向量化和入 Milvus 的异步进度。""" + + __tablename__ = "kb_ingestion_tasks" + + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + document_id: Mapped[int] = mapped_column(Integer, nullable=False, index=True, comment="文档ID") + institution_id: Mapped[int] = mapped_column(Integer, nullable=False, index=True, comment="机构ID") + task_type: Mapped[str] = mapped_column(String(64), nullable=False, default="document_ingestion", comment="任务类型") + status: Mapped[str] = mapped_column(String(32), nullable=False, default="queued", index=True, comment="任务状态") + progress: Mapped[int] = mapped_column(Integer, nullable=False, default=0, comment="进度百分比") + current_step: Mapped[str | None] = mapped_column(String(255), comment="当前步骤") + error_message: Mapped[str | None] = mapped_column(Text, comment="错误信息") + started_at: Mapped[datetime | None] = mapped_column(DateTime) + finished_at: Mapped[datetime | None] = mapped_column(DateTime) + + +class KbKnowledgeQueryLog(Base): + """学习助手问答日志:记录 RAG 检索命中、来源和耗时,用于审计和效果分析。""" + + __tablename__ = 
"kb_query_logs" + + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + user_id: Mapped[str] = mapped_column(String(128), nullable=False, index=True, comment="用户ID") + institution_id: Mapped[int] = mapped_column(Integer, nullable=False, index=True, comment="机构ID") + question: Mapped[str] = mapped_column(Text, nullable=False, comment="用户问题") + retrieval_hit: Mapped[bool] = mapped_column(Integer, nullable=False, default=0, comment="是否命中知识库") + retrieved_chunk_ids: Mapped[list | None] = mapped_column(JSON, comment="命中分片ID") + answer_summary: Mapped[str | None] = mapped_column(Text, comment="回答摘要") + llm_model: Mapped[str | None] = mapped_column(String(128), comment="LLM模型") + embedding_model: Mapped[str | None] = mapped_column(String(128), comment="向量模型") + top_k: Mapped[int | None] = mapped_column(Integer, comment="最终返回片段数") + score_threshold: Mapped[float | None] = mapped_column(Float, comment="检索阈值") + embedding_latency_ms: Mapped[int | None] = mapped_column(Integer) + search_latency_ms: Mapped[int | None] = mapped_column(Integer) + llm_latency_ms: Mapped[int | None] = mapped_column(Integer) + total_latency_ms: Mapped[int | None] = mapped_column(Integer) + created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow, index=True) diff --git a/app/prompts/learning_assistant/no_reference_answer.md b/app/prompts/learning_assistant/no_reference_answer.md new file mode 100644 index 0000000..0a6c775 --- /dev/null +++ b/app/prompts/learning_assistant/no_reference_answer.md @@ -0,0 +1,24 @@ +--- +template_code: learning_assistant_no_reference_answer +agent_type: learning_assistant +version: v1 +scene: no_reference +model_type: fast +output_format: text +--- + +# Role +你是医学学习助手,用于医学教育、课程学习和临床思维训练。 + +# Task +当机构知识库没有检索到可用参考时,给出通用学习回答。 + +# Rules +- 回答开头必须写:未检索到本机构知识库参考,以下为大模型通用学习回答。 +- 不得伪造 PDF 来源、页码、教材或指南。 +- 回答应简洁、准确、分点。 + +# Safety Boundaries +- 不替代真实临床诊疗。 +- 不针对具体患者给出最终诊疗指令。 +- 提醒用户以课程教材、指南和临床医生判断为准。 diff --git 
class KnowledgeBaseRepository:
    """Repository for the ``kb_*`` tables: all knowledge-base reads and writes go through here.

    Methods only ``flush``; committing the surrounding transaction is left to the caller.
    """

    # Upper bound applied to generated Milvus collection names.
    _MAX_COLLECTION_NAME_LENGTH = 120

    def __init__(self, db: Session) -> None:
        self.db = db

    def get_or_create_space(self, institution_id: int, institution_name: str | None) -> KbKnowledgeSpace:
        """Return the institution's knowledge space for embedding version ``v1``, creating it if missing.

        A new space takes its embedding model, dimension, default top-k and score
        threshold from ``settings``.
        """
        version = "v1"
        space = self.db.scalar(
            select(KbKnowledgeSpace).where(
                KbKnowledgeSpace.institution_id == institution_id,
                KbKnowledgeSpace.embedding_version == version,
            )
        )
        if space:
            return space
        collection_name = self._collection_name(institution_id, version)
        space = KbKnowledgeSpace(
            institution_id=institution_id,
            institution_name=institution_name,
            space_code=f"institution_{institution_id}_{version}",
            collection_name=collection_name,
            embedding_model=settings.embedding_model,
            embedding_dim=settings.embedding_dim,
            embedding_version=version,
            chunk_size=1100,
            chunk_overlap=180,
            top_k_default=settings.rag_top_k,
            score_threshold=settings.rag_score_threshold,
            status="active",
        )
        self.db.add(space)
        self.db.flush()
        return space

    def get_space(self, institution_id: int) -> KbKnowledgeSpace | None:
        """Return the most recently created active space of the institution, or ``None``."""
        return self.db.scalar(
            select(KbKnowledgeSpace)
            .where(KbKnowledgeSpace.institution_id == institution_id, KbKnowledgeSpace.status == "active")
            .order_by(KbKnowledgeSpace.id.desc())
            # Only the newest row is used, so do not fetch the rest.
            .limit(1)
        )

    def get_document_by_hash(self, institution_id: int, file_sha256: str) -> KbKnowledgeDocument | None:
        """Return the institution's document with this SHA-256, used for upload de-duplication."""
        return self.db.scalar(
            select(KbKnowledgeDocument).where(
                KbKnowledgeDocument.institution_id == institution_id,
                KbKnowledgeDocument.file_sha256 == file_sha256,
            )
        )

    def create_document(
        self,
        *,
        institution_id: int,
        uploaded_by: str,
        file_name: str,
        file_sha256: str,
        file_size: int,
        file_path: str,
        document_title: str | None,
        document_category: str,
        version: str,
    ) -> KbKnowledgeDocument:
        """Persist metadata of a freshly uploaded PDF in the initial ``uploaded``/``pending`` state."""
        document = KbKnowledgeDocument(
            institution_id=institution_id,
            uploaded_by=uploaded_by,
            file_name=file_name,
            file_sha256=file_sha256,
            file_type="pdf",
            file_size=file_size,
            file_path=file_path,
            document_title=document_title,
            document_category=document_category,
            version=version,
            status="uploaded",
            parse_status="pending",
            embedding_status="pending",
            chunk_count=0,
        )
        self.db.add(document)
        self.db.flush()
        return document

    def get_document(self, document_id: int, institution_id: int | None = None) -> KbKnowledgeDocument | None:
        """Return a document by ID; pass ``institution_id`` to enforce tenant isolation."""
        stmt = select(KbKnowledgeDocument).where(KbKnowledgeDocument.id == document_id)
        if institution_id is not None:
            stmt = stmt.where(KbKnowledgeDocument.institution_id == institution_id)
        return self.db.scalar(stmt)

    def create_ingestion_task(self, document: KbKnowledgeDocument) -> KbKnowledgeIngestionTask:
        """Create a queued ingestion task for ``document``."""
        task = KbKnowledgeIngestionTask(
            document_id=document.id,
            institution_id=document.institution_id,
            task_type="document_ingestion",
            status="queued",
            progress=0,
            current_step="queued",
        )
        self.db.add(task)
        self.db.flush()
        return task

    def get_ingestion_task(self, task_id: int) -> KbKnowledgeIngestionTask | None:
        """Return an ingestion task by primary key, or ``None``."""
        return self.db.get(KbKnowledgeIngestionTask, task_id)

    def update_task(
        self,
        task: KbKnowledgeIngestionTask,
        *,
        status: str | None = None,
        progress: int | None = None,
        current_step: str | None = None,
        error_message: str | None = None,
    ) -> None:
        """Apply the given fields to ``task``; arguments left as ``None`` are not touched.

        ``started_at`` is stamped the first time the task enters ``running`` and
        ``finished_at`` whenever it reaches ``success`` or ``failed``. Nothing is flushed here.
        """
        if status:
            task.status = status
            if status == "running" and task.started_at is None:
                task.started_at = datetime.utcnow()
            if status in {"success", "failed"}:
                task.finished_at = datetime.utcnow()
        if progress is not None:
            task.progress = progress
        if current_step is not None:
            task.current_step = current_step
        if error_message is not None:
            task.error_message = error_message

    def replace_chunks(self, document: KbKnowledgeDocument, chunks: list[KbKnowledgeChunk]) -> None:
        """Delete the document's existing chunks, insert ``chunks`` and update ``chunk_count``."""
        self.db.execute(delete(KbKnowledgeChunk).where(KbKnowledgeChunk.document_id == document.id))
        self.db.add_all(chunks)
        document.chunk_count = len(chunks)
        self.db.flush()

    def get_chunks_by_uids(self, institution_id: int, chunk_uids: list[str]) -> list[KbKnowledgeChunk]:
        """Load chunks by ``chunk_uid`` within one institution, ordered like ``chunk_uids``."""
        if not chunk_uids:
            return []
        rows = self.db.scalars(
            select(KbKnowledgeChunk).where(
                KbKnowledgeChunk.institution_id == institution_id,
                KbKnowledgeChunk.chunk_uid.in_(chunk_uids),
            )
        ).all()
        order = {chunk_uid: index for index, chunk_uid in enumerate(chunk_uids)}
        # Unknown uids cannot come back from the IN query; the fallback just sorts them last.
        return sorted(rows, key=lambda row: order.get(row.chunk_uid, len(chunk_uids)))

    def list_documents(self, institution_id: int, limit: int = 20) -> list[KbKnowledgeDocument]:
        """Return up to ``limit`` documents of the institution, newest ID first."""
        return list(
            self.db.scalars(
                select(KbKnowledgeDocument)
                .where(KbKnowledgeDocument.institution_id == institution_id)
                .order_by(KbKnowledgeDocument.id.desc())
                .limit(limit)
            )
        )

    def create_query_log(
        self,
        *,
        user_id: str,
        institution_id: int,
        question: str,
        retrieval_hit: bool,
        retrieved_chunk_ids: list[str],
        answer_summary: str,
        llm_model: str | None,
        top_k: int,
        score_threshold: float,
        embedding_latency_ms: int | None,
        search_latency_ms: int | None,
        llm_latency_ms: int | None,
        total_latency_ms: int | None,
    ) -> KbKnowledgeQueryLog:
        """Record one learning-assistant query; the answer is truncated to 1000 characters."""
        log = KbKnowledgeQueryLog(
            user_id=user_id,
            institution_id=institution_id,
            question=question,
            retrieval_hit=bool(retrieval_hit),
            retrieved_chunk_ids=retrieved_chunk_ids,
            answer_summary=answer_summary[:1000],
            llm_model=llm_model,
            embedding_model=settings.embedding_model,
            top_k=top_k,
            score_threshold=score_threshold,
            embedding_latency_ms=embedding_latency_ms,
            search_latency_ms=search_latency_ms,
            llm_latency_ms=llm_latency_ms,
            total_latency_ms=total_latency_ms,
        )
        self.db.add(log)
        self.db.flush()
        return log

    @staticmethod
    def _safe_identifier(value: str) -> str:
        """Replace every character that is not an ASCII letter or digit with ``_``.

        ``str.isalnum`` alone is Unicode-aware and would let characters such as CJK
        ideographs through, hence the additional ``isascii`` check.
        """
        return "".join(ch if ch.isascii() and ch.isalnum() else "_" for ch in value)

    def _collection_name(self, institution_id: int, version: str) -> str:
        """Build the per-institution Milvus collection name from ASCII letters, digits and ``_`` only.

        Layout: ``<prefix>_<institution_id>_<embedding_model>_<version>``. When the name
        would exceed the length limit, the model part is shortened so the version
        suffix, which distinguishes collections, is preserved.
        """
        limit = self._MAX_COLLECTION_NAME_LENGTH
        prefix = self._safe_identifier(f"{settings.milvus_collection_prefix}_{institution_id}")
        model_part = self._safe_identifier(settings.embedding_model.lower())
        suffix = self._safe_identifier(version)
        # Two separators join the three parts.
        budget = max(0, limit - len(prefix) - len(suffix) - 2)
        return f"{prefix}_{model_part[:budget]}_{suffix}"[:limit]
class KnowledgeDocumentUploadResponse(BaseModel):
    """Upload response: the document, its ingestion task and the target Milvus collection."""

    document_id: int
    # None when no ingestion task was created for this request.
    task_id: int | None = None
    # True when the same file (by SHA-256) already existed for the institution.
    duplicate: bool = False
    status: str
    parse_status: str
    embedding_status: str
    chunk_count: int
    collection_name: str


class KnowledgeDocumentDetailResponse(BaseModel):
    """Document detail: lets a content admin inspect the build result of one PDF."""

    document_id: int
    institution_id: int
    file_name: str
    document_title: str | None = None
    document_category: str
    version: str
    status: str
    parse_status: str
    embedding_status: str
    chunk_count: int
    error_message: str | None = None
    created_at: datetime | None = None
    updated_at: datetime | None = None


class KnowledgeDocumentListResponse(BaseModel):
    """Document list: recently uploaded documents of one institution."""

    items: list[KnowledgeDocumentDetailResponse] = Field(default_factory=list)


class LearningAssistantChatRequest(BaseModel):
    """Learning-assistant request: a medical study question asked against the institution's knowledge base."""

    question: str = Field(..., min_length=2, max_length=1000, description="用户问题")
    # Optional per-request overrides; None is passed through so the service can apply its defaults.
    top_k: int | None = Field(default=None, ge=1, le=10, description="最终返回给 LLM 的来源片段数")
    score_threshold: float | None = Field(default=None, ge=0, le=1, description="向量相似度过滤阈值")


class LearningAssistantSource(BaseModel):
    """Learning-assistant source: the PDF document, page range and quoted snippet behind an answer."""

    document_id: int
    document_title: str | None = None
    file_name: str
    page_start: int
    page_end: int
    chunk_uid: str
    score: float
    quote: str


class LearningAssistantChatResponse(BaseModel):
    """Learning-assistant answer: text, knowledge-base hit flag, sources and latencies."""

    answer: str
    retrieval_hit: bool
    sources: list[LearningAssistantSource] = Field(default_factory=list)
    # Human-readable reason when retrieval was skipped or degraded; None otherwise.
    retrieval_error: str | None = None
    model: str | None = None
    embedding_latency_ms: int | None = None
    search_latency_ms: int | None = None
    llm_latency_ms: int | None = None
    total_latency_ms: int | None = None
@dataclass(frozen=True)
class ChunkDraft:
    """Chunk draft: intermediate result of splitting PDF text, later written to MySQL and Milvus."""

    chunk_index: int
    page_start: int
    page_end: int
    section_title: str | None
    text: str


class DocumentChunkService:
    """Page-aware chunker for textbook/guideline PDFs.

    Paragraphs are packed into chunks of roughly ``chunk_size`` characters; each new
    chunk starts with the last ``chunk_overlap`` characters of the previous one.
    """

    def __init__(self, chunk_size: int = 1100, chunk_overlap: int = 180) -> None:
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap

    def build_chunks(self, pages: list[ParsedPdfPage]) -> list[ChunkDraft]:
        """Split parsed pages into chunk drafts that keep page range and section title.

        Args:
            pages: Parsed pages in reading order; each exposes ``text`` and ``page_number``.

        Returns:
            Drafts with non-empty text, in document order.
        """
        drafts: list[ChunkDraft] = []
        buffer: list[str] = []
        page_start: int | None = None
        page_end: int | None = None
        current_title: str | None = None  # latest heading seen while scanning
        buffer_title: str | None = None  # heading in effect for the text already buffered

        for page in pages:
            for paragraph in self._split_paragraphs(page.text):
                detected_title = self._detect_title(paragraph)
                if detected_title:
                    current_title = detected_title
                for piece in self._split_long_text(paragraph):
                    candidate = "\n".join([*buffer, piece]).strip()
                    if buffer and len(candidate) > self.chunk_size:
                        # Label the flushed chunk with the heading of its own content, not
                        # with a heading that only arrives in the paragraph being added.
                        drafts.append(self._make_draft(len(drafts), page_start, page_end, buffer_title, buffer))
                        buffer = self._overlap_tail(buffer)
                        # The overlap tail is the end of the flushed chunk, so the next
                        # chunk starts on that chunk's last page, not on its first.
                        page_start = page_end if buffer else None
                    if not buffer:
                        page_start = page.page_number
                    page_end = page.page_number
                    buffer_title = current_title
                    buffer.append(piece)

        if buffer:
            drafts.append(self._make_draft(len(drafts), page_start, page_end, buffer_title, buffer))
        return [draft for draft in drafts if draft.text]

    def _make_draft(
        self,
        chunk_index: int,
        page_start: int | None,
        page_end: int | None,
        section_title: str | None,
        buffer: list[str],
    ) -> ChunkDraft:
        """Join the buffered pieces into one draft; callers only flush a non-empty buffer."""
        return ChunkDraft(
            chunk_index=chunk_index,
            page_start=page_start,
            page_end=page_end,
            section_title=section_title,
            text="\n".join(buffer).strip(),
        )

    def to_models(
        self,
        *,
        institution_id: int,
        document_id: int,
        collection_name: str,
        embedding_model: str,
        drafts: list[ChunkDraft],
    ) -> list[KbKnowledgeChunk]:
        """Convert drafts to ORM rows; ``chunk_uid`` doubles as the Milvus ``vector_id``.

        ``chunk_uid`` combines document ID, chunk index and the first 12 hex digits of
        the text's SHA-256. ``token_count`` is a rough estimate of half the character count.
        """
        rows: list[KbKnowledgeChunk] = []
        for draft in drafts:
            chunk_hash = hashlib.sha256(draft.text.encode("utf-8")).hexdigest()
            chunk_uid = f"doc{document_id}_chunk{draft.chunk_index}_{chunk_hash[:12]}"
            rows.append(
                KbKnowledgeChunk(
                    institution_id=institution_id,
                    document_id=document_id,
                    chunk_uid=chunk_uid,
                    chunk_index=draft.chunk_index,
                    page_start=draft.page_start,
                    page_end=draft.page_end,
                    section_title=draft.section_title,
                    chunk_text=draft.text,
                    chunk_hash=chunk_hash,
                    token_count=max(1, len(draft.text) // 2),
                    vector_id=chunk_uid,
                    collection_name=collection_name,
                    embedding_model=embedding_model,
                    metadata_={"chunking": "page_semantic_window", "chunk_size": self.chunk_size, "overlap": self.chunk_overlap},
                )
            )
        return rows

    def _split_paragraphs(self, text: str) -> list[str]:
        """Split page text on line breaks and drop blank paragraphs."""
        parts = re.split(r"\n{1,}", text)
        return [part.strip() for part in parts if part.strip()]

    def _split_long_text(self, text: str) -> list[str]:
        """Split a paragraph longer than ``chunk_size`` into pieces.

        Sentences (split after sentence-ending punctuation) are packed greedily. Any
        piece that is still too long is cut with a sliding character window that
        advances by ``chunk_size - chunk_overlap``.
        """
        if len(text) <= self.chunk_size:
            return [text]
        # NOTE(review): the character class repeats "!?;." - full-width variants may have
        # been lost in an encoding step; confirm against the original file.
        sentences = re.split(r"(?<=[。!?;;.!?])", text)
        pieces: list[str] = []
        current = ""
        for sentence in sentences:
            if len(current) + len(sentence) > self.chunk_size and current:
                pieces.append(current.strip())
                current = current[-self.chunk_overlap :] if self.chunk_overlap else ""
            current += sentence
        if current.strip():
            pieces.append(current.strip())

        step = max(1, self.chunk_size - self.chunk_overlap)
        final: list[str] = []
        for piece in pieces:
            if len(piece) <= self.chunk_size:
                final.append(piece)
                continue
            start = 0
            while start < len(piece):
                final.append(piece[start : start + self.chunk_size])
                # Stop once a window reaches the end; another step would only emit a
                # fragment already fully contained in this window.
                if start + self.chunk_size >= len(piece):
                    break
                start += step
        return final

    def _overlap_tail(self, buffer: list[str]) -> list[str]:
        """Return the last ``chunk_overlap`` characters of the buffer as the seed of the next chunk."""
        if not self.chunk_overlap:
            return []
        text = "\n".join(buffer).strip()
        tail = text[-self.chunk_overlap :]
        return [tail] if tail else []

    def _detect_title(self, paragraph: str) -> str | None:
        """Return the paragraph if it looks like a chapter/section/numbered heading, else ``None``.

        Paragraphs longer than 80 characters are never treated as headings.
        """
        compact = paragraph.strip()
        if len(compact) > 80:
            return None
        title_patterns = [
            r"^第[一二三四五六七八九十百0-9]+[章节篇]",
            r"^[一二三四五六七八九十]+[、..]",
            r"^\d+(\.\d+){0,3}\s+",
        ]
        return compact if any(re.search(pattern, compact) for pattern in title_patterns) else None
class DocumentIngestionService:
    """Knowledge ingestion: PDF upload, parsing, chunking, embedding and Milvus write."""

    def __init__(
        self,
        db: Session,
        *,
        parser: PdfParser | None = None,
        chunker: DocumentChunkService | None = None,
        embedding_service: EmbeddingService | None = None,
        vector_store: MilvusVectorStore | None = None,
    ) -> None:
        self.db = db
        self.repo = KnowledgeBaseRepository(db)
        self.space_service = KnowledgeSpaceService(self.repo)
        # Collaborators are injectable; defaults are built when omitted.
        self.parser = parser or PdfParser()
        self.chunker = chunker or DocumentChunkService()
        self.embedding_service = embedding_service or EmbeddingService()
        self.vector_store = vector_store or MilvusVectorStore()

    async def upload_pdf(
        self,
        ctx: UserContext,
        file: UploadFile,
        *,
        document_title: str | None,
        document_category: str,
        version: str,
    ) -> KnowledgeDocumentUploadResponse:
        """Store an uploaded PDF as a knowledge document and start its ingestion.

        A re-upload of a file the institution already has returns the existing
        document with ``duplicate=True``; if that document previously failed, a new
        ingestion task is created and started for it.

        Raises:
            AppError: caller is not a content admin, has no institution, the file is
                invalid, or ingestion/enqueueing fails.
        """
        self.space_service.ensure_content_admin(ctx)
        space = self.space_service.get_or_create_space(ctx)
        # Read at most one byte past the limit so an oversized upload is rejected by
        # _validate_pdf without first being buffered in full.
        max_bytes = settings.knowledge_max_upload_mb * 1024 * 1024
        content = await file.read(max_bytes + 1)
        self._validate_pdf(file, content)
        file_sha256 = hashlib.sha256(content).hexdigest()
        existing = self.repo.get_document_by_hash(space.institution_id, file_sha256)
        if existing:
            task = None
            if existing.status == "failed":
                # Retry: a task that is only created would stay "queued" forever.
                task = self.repo.create_ingestion_task(existing)
                await self._dispatch(existing.id, task.id)
            return self._upload_response(existing, space.collection_name, task_id=task.id if task else None, duplicate=True)

        storage_path = self._save_file(space.institution_id, file.filename or "knowledge.pdf", content)
        document = self.repo.create_document(
            institution_id=space.institution_id,
            uploaded_by=ctx.user_id,
            file_name=file.filename or storage_path.name,
            file_sha256=file_sha256,
            file_size=len(content),
            file_path=str(storage_path),
            document_title=document_title or Path(file.filename or "").stem or None,
            document_category=document_category,
            version=version,
        )
        task = self.repo.create_ingestion_task(document)
        await self._dispatch(document.id, task.id)
        return self._upload_response(document, space.collection_name, task_id=task.id, duplicate=False)

    async def _dispatch(self, document_id: int, task_id: int) -> None:
        """Run ingestion inline or hand it to Celery, depending on ``settings.knowledge_ingestion_sync``."""
        if settings.knowledge_ingestion_sync:
            await self.ingest_document(document_id, task_id)
        else:
            self._enqueue_async_task(document_id, task_id)

    def _upload_response(
        self,
        document,
        collection_name: str,
        *,
        task_id: int | None,
        duplicate: bool,
    ) -> KnowledgeDocumentUploadResponse:
        """Build the upload response from the document's current state."""
        return KnowledgeDocumentUploadResponse(
            document_id=document.id,
            task_id=task_id,
            duplicate=duplicate,
            status=document.status,
            parse_status=document.parse_status,
            embedding_status=document.embedding_status,
            chunk_count=document.chunk_count,
            collection_name=collection_name,
        )

    async def ingest_document(self, document_id: int, task_id: int | None = None) -> None:
        """Turn a stored PDF into MySQL chunks and Milvus vectors, tracking progress on the task.

        Raises:
            AppError: document not found, no text could be extracted, or any step
                failed (non-``AppError`` failures are wrapped as ``KNOWLEDGE_INGESTION_FAILED``).
        """
        document = self.repo.get_document(document_id)
        if not document:
            raise AppError("KNOWLEDGE_DOCUMENT_NOT_FOUND", "knowledge document not found", 404)
        task = self.repo.get_ingestion_task(task_id) if task_id else None
        try:
            if task:
                self.repo.update_task(task, status="running", progress=5, current_step="parse_pdf")
            document.status = "processing"
            document.parse_status = "running"
            self.db.flush()

            pages = self.parser.parse(document.file_path)
            space = self.repo.get_or_create_space(document.institution_id, None)
            drafts = self.chunker.build_chunks(pages)
            chunks = self.chunker.to_models(
                institution_id=document.institution_id,
                document_id=document.id,
                collection_name=space.collection_name,
                embedding_model=settings.embedding_model,
                drafts=drafts,
            )
            if not chunks:
                # E.g. a scanned PDF without a text layer: marking it "ready" with zero
                # chunks would look like a successful but unsearchable document.
                raise AppError("KNOWLEDGE_DOCUMENT_EMPTY", "no extractable text found in knowledge document", 422)
            if task:
                self.repo.update_task(task, progress=35, current_step="embed_chunks")
            document.parse_status = "success"
            document.embedding_status = "running"
            self.db.flush()

            vectors, _embedding_latency_ms = await self.embedding_service.embed_texts([chunk.chunk_text for chunk in chunks])
            if task:
                self.repo.update_task(task, progress=75, current_step="write_vectors")
            # strict=True: a short embedding response must fail the ingestion instead of
            # silently leaving chunks without vectors.
            payload = [(chunk.chunk_uid, vector) for chunk, vector in zip(chunks, vectors, strict=True)]
            # NOTE(review): vectors from an earlier build of this document are not removed
            # from Milvus here; confirm whether the vector store needs an explicit delete.
            self.vector_store.upsert_vectors(space.collection_name, payload)
            self.repo.replace_chunks(document, chunks)
            document.status = "ready"
            document.embedding_status = "success"
            # Clear the message left behind by an earlier failed attempt.
            document.error_message = None
            if task:
                self.repo.update_task(task, status="success", progress=100, current_step="completed")
        except Exception as exc:
            document.status = "failed"
            document.error_message = str(exc)[:2000]
            document.parse_status = document.parse_status if document.parse_status == "success" else "failed"
            document.embedding_status = "failed"
            if task:
                self.repo.update_task(task, status="failed", progress=100, current_step="failed", error_message=str(exc)[:2000])
            if isinstance(exc, AppError):
                raise
            raise AppError("KNOWLEDGE_INGESTION_FAILED", "knowledge document ingestion failed", 500) from exc

    def _validate_pdf(self, file: UploadFile, content: bytes) -> None:
        """Reject empty, oversized or non-PDF uploads.

        The file passes the type check when its name ends in ``.pdf`` or its content
        type is PDF/octet-stream; the ``%PDF`` magic bytes are always required.
        """
        if not content:
            raise AppError("UPLOAD_FILE_EMPTY", "uploaded file is empty", 422)
        max_bytes = settings.knowledge_max_upload_mb * 1024 * 1024
        if len(content) > max_bytes:
            raise AppError("UPLOAD_FILE_TOO_LARGE", f"uploaded file exceeds {settings.knowledge_max_upload_mb}MB", 413)
        filename = (file.filename or "").lower()
        if not filename.endswith(".pdf") and file.content_type not in {"application/pdf", "application/octet-stream"}:
            raise AppError("UPLOAD_FILE_TYPE_INVALID", "only pdf file is supported", 422)
        if not content.startswith(b"%PDF"):
            raise AppError("UPLOAD_FILE_NOT_PDF", "uploaded file is not a valid pdf", 422)

    def _save_file(self, institution_id: int, filename: str, content: bytes) -> Path:
        """Write the raw PDF under ``<storage>/raw/<institution_id>/`` and return its path.

        The stored name is the first 16 hex digits of the content hash plus the
        sanitized original name, so path separators in ``filename`` cannot escape
        the directory.
        """
        safe_name = "".join(ch if ch.isalnum() or ch in {".", "_", "-"} else "_" for ch in filename)
        storage_dir = Path(settings.knowledge_storage_dir) / "raw" / str(institution_id)
        storage_dir.mkdir(parents=True, exist_ok=True)
        target = storage_dir / f"{hashlib.sha256(content).hexdigest()[:16]}_{safe_name}"
        target.write_bytes(content)
        return target

    def _enqueue_async_task(self, document_id: int, task_id: int) -> None:
        """Hand the ingestion to the Celery worker.

        Raises:
            AppError: ``KNOWLEDGE_TASK_ENQUEUE_FAILED`` when the task cannot be imported or enqueued.
        """
        try:
            # Imported lazily so the Celery task module is only needed in async mode.
            from app.tasks.knowledge_ingestion_tasks import ingest_knowledge_document

            ingest_knowledge_document.delay(document_id, task_id)
        except Exception as exc:  # pragma: no cover - requires a running Celery broker
            raise AppError("KNOWLEDGE_TASK_ENQUEUE_FAILED", "knowledge ingestion task enqueue failed", 500) from exc
class EmbeddingService:
    """Embedding service: sends texts to the embedding client in batches and times the whole call."""

    def __init__(self, client: OpenAICompatibleEmbeddingClient | None = None) -> None:
        self.client = client or OpenAICompatibleEmbeddingClient()

    async def embed_texts(self, texts: list[str]) -> tuple[list[list[float]], int]:
        """Embed ``texts`` in batches of ``settings.embedding_batch_size`` (at least 1).

        Returns:
            The vectors in input order and the total elapsed time in milliseconds.
        """
        started_at = time.perf_counter()
        step = max(1, settings.embedding_batch_size)
        collected: list[list[float]] = []
        offset = 0
        while offset < len(texts):
            # The client also returns usage information, which is not needed here.
            batch_vectors, _usage = await self.client.embed_texts(texts[offset : offset + step])
            collected.extend(batch_vectors)
            offset += step
        elapsed_ms = int((time.perf_counter() - started_at) * 1000)
        return collected, elapsed_ms
class KnowledgeSpaceService:
    """Knowledge space service: resolves the knowledge-base collection of the user's institution."""

    def __init__(self, repo: KnowledgeBaseRepository) -> None:
        self.repo = repo

    def require_institution_id(self, ctx: UserContext) -> int:
        """Return the user's institution ID as ``int``; knowledge-base features require one.

        Raises:
            AppError: ``INSTITUTION_REQUIRED`` (403) when the context has no institution ID.
        """
        if ctx.institution_id is not None:
            return int(ctx.institution_id)
        raise AppError("INSTITUTION_REQUIRED", "institution_id is required for knowledge base", 403)

    def get_or_create_space(self, ctx: UserContext) -> KbKnowledgeSpace:
        """Return the institution's knowledge space, creating it on first upload.

        The institution name comes from the user profile, falling back to
        ``institution_<id>`` when the profile has none.
        """
        tenant_id = self.require_institution_id(ctx)
        display_name = (ctx.profile or {}).get("institution_name")
        if not display_name:
            display_name = f"institution_{tenant_id}"
        return self.repo.get_or_create_space(tenant_id, display_name)

    def get_active_space(self, ctx: UserContext) -> KbKnowledgeSpace:
        """Return the institution's active knowledge space for question answering.

        Raises:
            AppError: ``KNOWLEDGE_SPACE_NOT_FOUND`` (404) when none exists yet.
        """
        found = self.repo.get_space(self.require_institution_id(ctx))
        if found:
            return found
        raise AppError("KNOWLEDGE_SPACE_NOT_FOUND", "knowledge space not initialized for institution", 404)

    def ensure_content_admin(self, ctx: UserContext) -> None:
        """Allow only content/institution/system administrators to upload and build the knowledge base.

        Raises:
            AppError: ``KNOWLEDGE_ADMIN_FORBIDDEN`` (403) for any other role.
        """
        normalized_role = (ctx.role or "").lower()
        if normalized_role in ("content_admin", "institution_admin", "admin", "super_admin"):
            return
        raise AppError("KNOWLEDGE_ADMIN_FORBIDDEN", "only content admin can upload knowledge documents", 403)
@dataclass(frozen=True)
class LearningAssistantRetrieval:
    """Retrieval outcome for the learning assistant: sources, latencies and the degradation reason."""

    institution_id: int | None
    score_threshold: float
    sources: list[LearningAssistantSource]
    embedding_latency_ms: int | None = None
    search_latency_ms: int | None = None
    retrieval_error: str | None = None


class LearningAssistantService:
    """AI learning assistant: answers from the institution knowledge base first and
    falls back to a general LLM answer when retrieval is unavailable."""

    def __init__(
        self,
        db: Session,
        *,
        vector_search_service: VectorSearchService | None = None,
        agent: LearningAssistantAgent | None = None,
    ) -> None:
        self.db = db
        self.repo = KnowledgeBaseRepository(db)
        self.space_service = KnowledgeSpaceService(self.repo)
        self.vector_search = vector_search_service or VectorSearchService(db)
        self.agent = agent or LearningAssistantAgent()

    async def chat(self, ctx: UserContext, payload: LearningAssistantChatRequest) -> LearningAssistantChatResponse:
        """Non-streaming Q&A: return the full answer plus retrieval and latency details.

        A degraded retrieval does not block the answer; its reason is reported in
        ``retrieval_error``. The query log is written but not committed here.
        """
        start = time.perf_counter()
        retrieval = await self._retrieve_sources(ctx, payload)
        llm_started = time.perf_counter()
        response = await self.agent.answer(payload.question, retrieval.sources)
        total_latency_ms = int((time.perf_counter() - start) * 1000)
        # Prefer the latency reported by the agent; otherwise measure it locally.
        llm_latency_ms = response.latency_ms or int((time.perf_counter() - llm_started) * 1000)
        self._write_query_log(
            ctx=ctx,
            payload=payload,
            retrieval=retrieval,
            answer=response.content,
            model=response.model,
            llm_latency_ms=llm_latency_ms,
            total_latency_ms=total_latency_ms,
        )
        return LearningAssistantChatResponse(
            answer=response.content,
            retrieval_hit=bool(retrieval.sources),
            sources=retrieval.sources,
            retrieval_error=retrieval.retrieval_error,
            model=response.model,
            embedding_latency_ms=retrieval.embedding_latency_ms,
            search_latency_ms=retrieval.search_latency_ms,
            llm_latency_ms=llm_latency_ms,
            total_latency_ms=total_latency_ms,
        )

    async def stream_chat(self, ctx: UserContext, payload: LearningAssistantChatRequest) -> AsyncIterator[str]:
        """Streaming Q&A as server-sent events.

        Event order: ``retrieval_done``, zero or more ``answer_delta``, then either
        ``answer_done`` or a single ``error``.
        """
        # Function-scope import: the module's import block is not modified by this change.
        import logging

        start = time.perf_counter()
        retrieval = await self._retrieve_sources(ctx, payload)
        yield self._sse(
            "retrieval_done",
            {
                "retrieval_hit": bool(retrieval.sources),
                "sources": [source.model_dump() for source in retrieval.sources],
                "retrieval_error": retrieval.retrieval_error,
                "embedding_latency_ms": retrieval.embedding_latency_ms,
                "search_latency_ms": retrieval.search_latency_ms,
            },
        )

        answer_parts: list[str] = []
        llm_latency_ms: int | None = None
        model: str | None = None
        try:
            async for chunk in self.agent.stream_answer(payload.question, retrieval.sources):
                if chunk.done:
                    llm_latency_ms = chunk.total_latency_ms
                    model = chunk.model
                    break
                if chunk.delta:
                    answer_parts.append(chunk.delta)
                    yield self._sse("answer_delta", {"delta": chunk.delta})
        except AppError as exc:
            yield self._sse("error", {"code": exc.code, "message": exc.message})
            return
        except Exception:
            yield self._sse("error", {"code": "LEARNING_ASSISTANT_LLM_FAILED", "message": "AI 学习助手回答生成失败,请稍后重试"})
            return

        answer = "".join(answer_parts)
        total_latency_ms = int((time.perf_counter() - start) * 1000)
        try:
            self._write_query_log(
                ctx=ctx,
                payload=payload,
                retrieval=retrieval,
                answer=answer,
                model=model,
                llm_latency_ms=llm_latency_ms,
                total_latency_ms=total_latency_ms,
                commit=True,
            )
        except Exception:
            # The answer has already been streamed; an audit-log failure must not cut
            # the stream off before "answer_done". Log it and reset the session.
            logging.getLogger(__name__).exception("failed to persist learning assistant query log")
            self.db.rollback()
        yield self._sse("answer_done", {"model": model, "total_latency_ms": total_latency_ms})

    async def _retrieve_sources(self, ctx: UserContext, payload: LearningAssistantChatRequest) -> LearningAssistantRetrieval:
        """Search the institution's knowledge space and degrade to "no sources" on known failures.

        Degrades when the user has no institution, the space/collection does not exist,
        the embedding call fails, or any non-``AppError`` exception occurs. Other
        ``AppError`` codes are re-raised.
        """
        score_threshold = payload.score_threshold if payload.score_threshold is not None else settings.rag_score_threshold
        try:
            institution_id = self.space_service.require_institution_id(ctx)
        except AppError:
            return LearningAssistantRetrieval(
                institution_id=None,
                score_threshold=score_threshold,
                sources=[],
                retrieval_error="当前用户缺少机构信息,已转为大模型通用学习回答。",
            )

        try:
            space = self.space_service.get_active_space(ctx)
            retrieval = await self.vector_search.search(
                institution_id=space.institution_id,
                collection_name=space.collection_name,
                question=payload.question,
                top_k=payload.top_k,
                score_threshold=payload.score_threshold,
            )
            return LearningAssistantRetrieval(
                institution_id=space.institution_id,
                score_threshold=payload.score_threshold if payload.score_threshold is not None else space.score_threshold,
                sources=self._build_sources(retrieval.chunks),
                embedding_latency_ms=retrieval.embedding_latency_ms,
                search_latency_ms=retrieval.search_latency_ms,
            )
        except AppError as exc:
            if exc.code in {"KNOWLEDGE_SPACE_NOT_FOUND", "MILVUS_COLLECTION_NOT_FOUND", "EMBEDDING_CALL_FAILED"}:
                return LearningAssistantRetrieval(
                    institution_id=institution_id,
                    score_threshold=score_threshold,
                    sources=[],
                    retrieval_error="当前机构知识库暂未初始化或检索不可用,已转为大模型通用学习回答。",
                )
            raise
        except Exception:
            return LearningAssistantRetrieval(
                institution_id=institution_id,
                score_threshold=score_threshold,
                sources=[],
                retrieval_error="当前机构知识库检索暂不可用,已转为大模型通用学习回答。",
            )

    def _write_query_log(
        self,
        *,
        ctx: UserContext,
        payload: LearningAssistantChatRequest,
        retrieval: LearningAssistantRetrieval,
        answer: str,
        model: str | None,
        llm_latency_ms: int | None,
        total_latency_ms: int | None,
        commit: bool = False,
    ) -> None:
        """Record the query; nothing is written when the retrieval carries no institution ID.

        Args:
            commit: Commit the session after writing (used by the streaming path).
        """
        if retrieval.institution_id is None:
            return
        self.repo.create_query_log(
            user_id=ctx.user_id,
            institution_id=retrieval.institution_id,
            question=payload.question,
            retrieval_hit=bool(retrieval.sources),
            retrieved_chunk_ids=[source.chunk_uid for source in retrieval.sources],
            answer_summary=answer,
            llm_model=model,
            top_k=payload.top_k or len(retrieval.sources) or settings.rag_top_k,
            score_threshold=retrieval.score_threshold,
            embedding_latency_ms=retrieval.embedding_latency_ms,
            search_latency_ms=retrieval.search_latency_ms,
            llm_latency_ms=llm_latency_ms,
            total_latency_ms=total_latency_ms,
        )
        if commit:
            self.db.commit()

    def _build_sources(self, chunks: list[RetrievedChunk]) -> list[LearningAssistantSource]:
        """Convert retrieved chunks into displayable sources; quotes are capped at 500 characters."""
        # Several chunks often belong to the same PDF; look each document up once.
        documents: dict = {}
        sources: list[LearningAssistantSource] = []
        for item in chunks:
            key = (item.chunk.document_id, item.chunk.institution_id)
            if key not in documents:
                documents[key] = self.repo.get_document(*key)
            document = documents[key]
            sources.append(
                LearningAssistantSource(
                    document_id=item.chunk.document_id,
                    document_title=document.document_title if document else None,
                    file_name=document.file_name if document else "",
                    page_start=item.chunk.page_start,
                    page_end=item.chunk.page_end,
                    chunk_uid=item.chunk.chunk_uid,
                    score=round(item.score, 4),
                    quote=item.chunk.chunk_text[:500],
                )
            )
        return sources

    def _sse(self, event: str, data: dict) -> str:
        """Format one server-sent event as ``event:`` + JSON ``data:`` (non-ASCII kept as-is)."""
        return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
"""检索结果包:返回命中分片和各阶段耗时。""" + + chunks: list[RetrievedChunk] + embedding_latency_ms: int + search_latency_ms: int + + +class VectorSearchService: + """向量检索服务:把用户问题向量化后在机构 Milvus collection 中检索。""" + + def __init__( + self, + db: Session, + *, + embedding_service: EmbeddingService | None = None, + vector_store: MilvusVectorStore | None = None, + ) -> None: + self.repo = KnowledgeBaseRepository(db) + self.embedding_service = embedding_service or EmbeddingService() + self.vector_store = vector_store or MilvusVectorStore() + + async def search( + self, + *, + institution_id: int, + collection_name: str, + question: str, + top_n: int | None = None, + top_k: int | None = None, + score_threshold: float | None = None, + ) -> RetrievalResult: + """知识检索:先召回 top_n,再按阈值和 top_k 过滤最终上下文。""" + vectors, embedding_latency_ms = await self.embedding_service.embed_texts([question]) + started = time.perf_counter() + hits = self.vector_store.search(collection_name, vectors[0], top_n or settings.rag_top_n) + search_latency_ms = int((time.perf_counter() - started) * 1000) + + threshold = settings.rag_score_threshold if score_threshold is None else score_threshold + final_limit = top_k or settings.rag_top_k + filtered = [hit for hit in hits if hit.score >= threshold][:final_limit] + chunks = self.repo.get_chunks_by_uids(institution_id, [hit.chunk_uid for hit in filtered]) + score_by_uid = {hit.chunk_uid: hit.score for hit in filtered} + return RetrievalResult( + chunks=[RetrievedChunk(chunk=chunk, score=score_by_uid.get(chunk.chunk_uid, 0.0)) for chunk in chunks], + embedding_latency_ms=embedding_latency_ms, + search_latency_ms=search_latency_ms, + ) diff --git a/app/tasks/__init__.py b/app/tasks/__init__.py new file mode 100644 index 0000000..b2059ea --- /dev/null +++ b/app/tasks/__init__.py @@ -0,0 +1 @@ +"""异步任务模块:生产环境通过 Celery worker 执行耗时任务。""" diff --git a/app/tasks/celery_app.py b/app/tasks/celery_app.py new file mode 100644 index 0000000..fbbe93b --- /dev/null +++ 
b/app/tasks/celery_app.py @@ -0,0 +1,16 @@ +from celery import Celery + +from app.core.config import settings + +celery_app = Celery( + "medical_agent", + broker=settings.celery_broker_url, + backend=settings.celery_result_backend, +) +celery_app.conf.update( + task_serializer="json", + accept_content=["json"], + result_serializer="json", + timezone="Asia/Shanghai", + enable_utc=False, +) diff --git a/app/tasks/knowledge_ingestion_tasks.py b/app/tasks/knowledge_ingestion_tasks.py new file mode 100644 index 0000000..48142e0 --- /dev/null +++ b/app/tasks/knowledge_ingestion_tasks.py @@ -0,0 +1,13 @@ +import asyncio + +from app.db.session import SessionLocal +from app.services.document_ingestion_service import DocumentIngestionService +from app.tasks.celery_app import celery_app + + +@celery_app.task(name="knowledge.ingest_document") +def ingest_knowledge_document(document_id: int, task_id: int) -> None: + """知识库异步任务:在 Celery worker 中执行 PDF 解析和向量入库。""" + with SessionLocal() as db: + asyncio.run(DocumentIngestionService(db).ingest_document(document_id, task_id)) + db.commit() diff --git a/docs/03_api_design.md b/docs/03_api_design.md index 8f8e224..c732452 100644 --- a/docs/03_api_design.md +++ b/docs/03_api_design.md @@ -1,19 +1,20 @@ # 医疗问诊 Agent API 文档 -> 当前文档只描述前端联调需要的后端能力。 +本文档面向前端联调,描述当前 FastAPI 后端已保留和可调用的接口。 -## 1. 联调地址 +公网基础地址: -| 项目 | 地址 | -|---|---| -| 公网网关 | `http://8.160.178.88/fastapi` | -| API Base URL | `http://8.160.178.88/fastapi/api/v1` | -| Swagger | `http://8.160.178.88/fastapi/docs` | -| OpenAPI JSON | `http://8.160.178.88/fastapi/openapi.json` | -| 存活检查 | `http://8.160.178.88/fastapi/health/live` | -| 就绪检查 | `http://8.160.178.88/fastapi/health/ready` | +```text +http://8.160.178.88/fastapi +``` -## 2. 通用规则 +本地调试地址: + +```text +http://127.0.0.1:9000 +``` + +## 1. 
通用规则 除健康检查外,所有业务接口都需要携带: @@ -23,6 +24,8 @@ X-Entry-Scene: vue_frontend X-Request-Id: <可选> ``` +`Authorization` 使用 Django 用户中心签发的 access token。FastAPI 会转发 token 到 Django `/api/user/users/me/`,返回 200 后以 Django 用户 `id` 作为本服务的 `user_id`。 + 普通 JSON 接口统一返回: ```json @@ -33,31 +36,34 @@ X-Request-Id: <可选> } ``` -前端判断规则: +前端判断成功规则: - HTTP 状态码为 `2xx` - `code` 等于 `OK` - 业务数据从 `data` 读取 -SSE 流式接口不返回上述 JSON 包装,而是返回 `event + data` 事件流。 +SSE 流式接口不返回上述 JSON 包装,而是返回: -## 3. 前置接口 +```text +event: +data: {"key":"value"} +``` -### 3.1 当前用户 +## 2. 用户鉴权 | 接口名称 | url | api | methods | params(入参) | response(返回参数) | |---|---|---|---|---|---| -| 当前用户 | `http://8.160.178.88/fastapi/api/v1/auth/me` | `/api/v1/auth/me` | `GET` | Header:`Authorization` 必填,格式 `Bearer `;`X-Entry-Scene` 建议传。 | `data.user_id`、`data.username`、`data.display_name`、`data.role`、`data.phone`、`data.institution_id`、`data.department_id`、`data.status` 等 Django 用户中心字段。 | +| 当前用户信息 | `http://8.160.178.88/fastapi/api/v1/auth/me` | `/api/v1/auth/me` | `GET` | Header:`Authorization` 必填,格式 `Bearer `;`X-Entry-Scene` 可选,建议 `vue_frontend` | `data.user_id` 用户ID;`data.username` 用户名;`data.display_name` 显示名;`data.role` 用户角色;`data.institution_id` 机构ID;`data.institution_name` 机构名称;`data.department_id` 科室ID;`data.department_name` 科室名称;`data.status` 用户状态 | 请求示例: -```http -GET /api/v1/auth/me -Authorization: Bearer -X-Entry-Scene: vue_frontend +```bash +curl -X GET "http://8.160.178.88/fastapi/api/v1/auth/me" \ + -H "Authorization: Bearer " \ + -H "X-Entry-Scene: vue_frontend" ``` -成功返回示例: +返回示例: ```json { @@ -67,66 +73,37 @@ X-Entry-Scene: vue_frontend "user_id": "37", "source": "django_user_center", "username": "13700000099", - "display_name": "测试用户", + "display_name": "Swagger测试", "tenant_id": "1", "role": "student", "phone": "13700000099", "avatar": "", "gender": 0, "institution_id": 1, - "institution_name": "测试机构", - "department_id": 2, + "institution_name": "某医院", + "department_id": 1, "department_name": "儿科", "status": 1 } } ``` -### 
3.2 病例读取 +## 3. 训练页面接口 + +### 3.1 推荐配置信息 | 接口名称 | url | api | methods | params(入参) | response(返回参数) | |---|---|---|---|---|---| -| 病例列表 | `http://8.160.178.88/fastapi/api/v1/cases` | `/api/v1/cases` | `GET` | Query:`department_id` 选填,科室 ID;`training_type` 选填;`mode` 选填,允许 `practice`、`teaching`。 | `data.items[]`,包含 `id`、`title`、`department_id`、`difficulty`、`chief_complaint`、`has_teaching_video`、`has_quiz`。 | -| 病例详情 | `http://8.160.178.88/fastapi/api/v1/cases/{case_id}` | `/api/v1/cases/{case_id}` | `GET` | Path:`case_id` 必填,病例 ID。 | 病例入口展示信息,不返回标准答案、隐藏病史、检查结果和评分规则。 | +| 推荐配置信息 | `http://8.160.178.88/fastapi/api/v1/training-config/recommended` | `/api/v1/training-config/recommended` | `GET` | Query:`case_id` 必填,病例ID | `data.recommended` 推荐配置值;`data.recommended_labels` 推荐配置中文名;`data.options` 可选项 | -## 4. 训练页面接口 - -| 接口名称 | url | api | methods | params(入参) | response(返回参数) | -|---|---|---|---|---|---| -| 推荐配置信息 | `http://8.160.178.88/fastapi/api/v1/training-config/recommended?case_id={case_id}` | `/api/v1/training-config/recommended` | `GET` | Query:`case_id` 必填,病例 ID。 | `data.recommended` 默认配置;`data.recommended_labels` 中文标签。 | -| 训练配置信息 | `http://8.160.178.88/fastapi/api/v1/training-config/options?case_id={case_id}` | `/api/v1/training-config/options` | `GET` | Query:`case_id` 必填,病例 ID。 | `data.options` 全部可选项;`data.recommended` 推荐默认值。 | -| 新建会话 | `http://8.160.178.88/fastapi/api/v1/sessions` | `/api/v1/sessions` | `POST` | Body:`case_id` 必填;`training_type` 必填;`mode` 必填,当前训练填 `practice`;`score_type` 选填,允许 `percentage`、`five_point`;`patient_config` 选填。 | `data.session_id`、`data.session_code`、`data.status`、`data.patient_opening`、`data.patient_config`。 | -| 流式会话 | `http://8.160.178.88/fastapi/api/v1/sessions/{session_id}/chat/stream` | `/api/v1/sessions/{session_id}/chat/stream` | `POST` | Path:`session_id` 必填;Body:`message` 必填,医学生问句。 | SSE:`message_delta` 返回患者回复增量;`message_done` 返回耗时;`error` 返回错误。 | -| 王主任练习提示 | 
`http://8.160.178.88/fastapi/api/v1/sessions/{session_id}/hints/stream` | `/api/v1/sessions/{session_id}/hints/stream` | `POST` | Path:`session_id` 必填;Body:`last_user_message` 选填;`scope` 选填,默认 `current_conversation`。 | SSE:`hint_delta` 返回一句话提示增量;`hint_done` 返回结束事件。 | -| 体格检查列表获取 | `http://8.160.178.88/fastapi/api/v1/sessions/{session_id}/physical-exams` | `/api/v1/sessions/{session_id}/physical-exams` | `GET` | Path:`session_id` 必填。 | `data.items[]`,包含体格检查项 `item_code`、`item_name`、`item_type`。不返回结果。 | -| 辅助检查列表获取 | `http://8.160.178.88/fastapi/api/v1/sessions/{session_id}/auxiliary-exams` | `/api/v1/sessions/{session_id}/auxiliary-exams` | `GET` | Path:`session_id` 必填。 | `data.items[]`,包含辅助检查项 `item_code`、`item_name`、`item_type`。不返回结果。 | -| 体格检查某项结果 | `http://8.160.178.88/fastapi/api/v1/sessions/{session_id}/physical-exams/{item_code}` | `/api/v1/sessions/{session_id}/physical-exams/{item_code}` | `POST` | Path:`session_id` 必填;`item_code` 必填,必须属于体格检查。 | `data.result_text`、`data.result_structured`、`data.context_written`、`data.already_ordered`。 | -| 辅助检查某项结果 | `http://8.160.178.88/fastapi/api/v1/sessions/{session_id}/auxiliary-exams/{item_code}` | `/api/v1/sessions/{session_id}/auxiliary-exams/{item_code}` | `POST` | Path:`session_id` 必填;`item_code` 必填,必须属于辅助检查。 | `data.result_text`、`data.result_structured`、`data.context_written`、`data.already_ordered`。 | -| 完成问诊 | `http://8.160.178.88/fastapi/api/v1/sessions/{session_id}/complete-inquiry` | `/api/v1/sessions/{session_id}/complete-inquiry` | `POST` | Path:`session_id` 必填。至少完成一轮医生问诊。 | `data.session_id`、`data.status=diagnosis`。 | -| 提交诊断 | `http://8.160.178.88/fastapi/api/v1/sessions/{session_id}/diagnosis` | `/api/v1/sessions/{session_id}/diagnosis` | `POST` | Path:`session_id` 必填;Body:`primary_diagnosis` 必填;`diagnosis_basis` 必填;`differential_diagnoses` 选填数组。 | `data.status=treatment`。 | -| 提交治疗 | `http://8.160.178.88/fastapi/api/v1/sessions/{session_id}/treatment` | `/api/v1/sessions/{session_id}/treatment` | `POST` 
| Path:`session_id` 必填;Body:`treatment_principle`、`treatment_measures` 必填;`risk_plan`、`communication`、`follow_up` 选填。 | `data.status=evaluating`。 | -| 生成评价 | `http://8.160.178.88/fastapi/api/v1/sessions/{session_id}/evaluation` | `/api/v1/sessions/{session_id}/evaluation` | `POST` | Path:`session_id` 必填;Body:`score_type` 选填,允许 `percentage`、`five_point`。必须已提交治疗。 | `data.evaluation_id`、`data.total_score`、`data.dimension_scores[]`、`data.score_details[]`、`data.overall_comment`。 | -| 获取评价(详情) | `http://8.160.178.88/fastapi/api/v1/evaluations/{evaluation_id}` | `/api/v1/evaluations/{evaluation_id}` | `GET` | Path:`evaluation_id` 必填。只允许当前用户访问自己的评价。 | 评价详情,包含病例、总分、维度评分、评分明细、改进建议、PDF 路径。 | -| 下载 PDF | `http://8.160.178.88/fastapi/api/v1/evaluations/{evaluation_id}/download-pdf` | `/api/v1/evaluations/{evaluation_id}/download-pdf` | `GET` | Path:`evaluation_id` 必填。Header:`Authorization` 必填。 | 成功时返回 `application/pdf` 文件流,浏览器可下载;失败时返回统一错误 JSON。 | - -### 4.1 训练配置入参和返回示例 - -请求: - -```http -GET /api/v1/training-config/options?case_id=2 -Authorization: Bearer -X-Entry-Scene: vue_frontend -``` - -返回: +返回示例: ```json { "code": "OK", "message": "success", "data": { - "case_id": 2, + "case_id": 1, "recommended": { "visit_environment": "outpatient", "age_group": "child", @@ -139,42 +116,28 @@ X-Entry-Scene: vue_frontend "education_level": "高等教育", "personality": "平和" }, - "options": { - "visit_environment": [ - {"value": "outpatient", "label": "门诊"}, - {"value": "emergency", "label": "急诊"}, - {"value": "ward", "label": "病房"} - ], - "age_group": [ - {"value": "child", "label": "儿童"}, - {"value": "youth", "label": "青年"}, - {"value": "middle_aged", "label": "中年"}, - {"value": "elderly", "label": "老年"} - ], - "education_level": [ - {"value": "primary_or_below", "label": "小学及以下"}, - {"value": "secondary", "label": "中等教育"}, - {"value": "higher", "label": "高等教育"} - ], - "personality": [ - {"value": "calm", "label": "平和"}, - {"value": "anxious", "label": "焦虑"}, - {"value": "impatient", "label": 
"急躁"}, - {"value": "cooperative", "label": "配合"}, - {"value": "suspicious", "label": "多疑"} - ] - } + "options": {} } } ``` -### 4.2 新建会话入参和返回示例 +### 3.2 训练配置信息 -请求: +| 接口名称 | url | api | methods | params(入参) | response(返回参数) | +|---|---|---|---|---|---| +| 训练配置信息 | `http://8.160.178.88/fastapi/api/v1/training-config/options` | `/api/v1/training-config/options` | `GET` | Query:`case_id` 必填,病例ID | `data.options.visit_environment` 就诊环境;`data.options.age_group` 年龄段;`data.options.education_level` 文化程度;`data.options.personality` 性格 | + +### 3.3 新建会话 + +| 接口名称 | url | api | methods | params(入参) | response(返回参数) | +|---|---|---|---|---|---| +| 新建会话 | `http://8.160.178.88/fastapi/api/v1/sessions` | `/api/v1/sessions` | `POST` | Body:`case_id` 必填;`training_type` 必填,当前用 `diagnosis_treatment`;`mode` 必填,允许 `practice`、`teaching`;`score_type` 必填,允许 `percentage`、`five_point`;`patient_config` 可选,自定义病人配置 | `data.session_id` 会话ID;`data.case_id` 病例ID;`data.status` 当前阶段;`data.patient_config` 实际使用的病人配置 | + +请求示例: ```json { - "case_id": 2, + "case_id": 1, "training_type": "diagnosis_treatment", "mode": "practice", "score_type": "percentage", @@ -187,249 +150,151 @@ X-Entry-Scene: vue_frontend } ``` -返回: +### 3.4 流式会话 -```json -{ - "code": "OK", - "message": "success", - "data": { - "session_id": 12, - "session_code": "sess_20260608120000_xxxx", - "status": "inquiry", - "patient_opening": "家长:医生,孩子发热咳嗽好几天了,昨天开始喘得厉害,精神也不太好。", - "patient_config": { - "values": { - "visit_environment": "outpatient", - "age_group": "child", - "education_level": "higher", - "personality": "calm" - }, - "labels": { - "visit_environment": "门诊", - "age_group": "儿童", - "education_level": "高等教育", - "personality": "平和" - } - } - } -} -``` +| 接口名称 | url | api | methods | params(入参) | response(返回参数) | +|---|---|---|---|---|---| +| 流式会话 | `http://8.160.178.88/fastapi/api/v1/sessions/{session_id}/chat/stream` | `/api/v1/sessions/{session_id}/chat/stream` | `POST` | Path:`session_id` 必填;Body:`message` 必填,医学生问诊内容 | 
SSE:`message_delta` AI病人回复增量;`message_done` 完成;`error` 错误 | -### 4.3 SSE 流式会话返回格式 - -请求: - -```json -{ - "message": "孩子发热几天了?最高体温多少?" -} -``` - -返回: +SSE 返回示例: ```text event: message_delta -data: {"delta":"发热有4天了,"} - -event: message_delta -data: {"delta":"最高烧到39度多。"} +data: {"delta":"发热有4天了,最高39度多。"} event: message_done -data: {"latency_ms":1800,"first_token_ms":500,"model":"deepseek-chat","fallback_used":false} +data: {"latency_ms":1200,"first_token_ms":300,"model":"deepseek-chat","fallback_used":false} ``` -### 4.4 检查结果返回示例 - -```json -{ - "code": "OK", - "message": "success", - "data": { - "item_code": "blood_routine", - "item_name": "血常规", - "item_type": "lab", - "result_text": "WBC 12.5×10^9/L,中性粒细胞比例72%,提示感染及炎症反应。", - "result_structured": { - "wbc": "12.5×10^9/L", - "neutrophil": "72%" - }, - "is_key": true, - "is_abnormal": true, - "context_written": true, - "already_ordered": false - } -} -``` - -检查规则: - -- 检查结果只来自数据库,不由 LLM 生成。 -- 检查结果会写入本次会话短期 memory 和评分依据。 -- 同一会话重复申请相同 `item_code` 时返回已有结果,`already_ordered=true`。 - -### 4.5 诊断和治疗提交示例 - -提交诊断: - -```json -{ - "primary_diagnosis": "支气管肺炎", - "differential_diagnoses": ["支气管哮喘急性发作", "上呼吸道感染"], - "diagnosis_basis": "结合发热、咳嗽、喘息、肺部湿啰音、胸片异常、炎症指标升高和血氧情况,符合儿童支气管肺炎表现。" -} -``` - -提交治疗: - -```json -{ - "treatment_principle": "抗感染、止咳平喘、改善氧合并严密观察病情变化。", - "treatment_measures": "根据病情进行抗感染治疗,必要时雾化吸入缓解喘息,监测体温、呼吸和血氧。", - "risk_plan": "关注低氧、呼吸困难加重、持续高热、精神反应差和脱水。", - "communication": "向家属说明病情、用药注意事项、危险信号和复诊指征。", - "follow_up": "治疗后复查体温、呼吸、血氧和必要炎症指标。" -} -``` - -## 5. 
教学互动接口 +### 3.5 王主任练习提示 | 接口名称 | url | api | methods | params(入参) | response(返回参数) | |---|---|---|---|---|---| -| 获取教学列表(题目 选项 答案 解析文本 视频) | `http://8.160.178.88/fastapi/api/v1/teaching/cases/{case_id}/items` | `/api/v1/teaching/cases/{case_id}/items` | `GET` | Path:`case_id` 必填,必须存在 `teaching_case` 数据。 | `data.case` 病例摘要;`data.questions[]` 题目列表,包含 `stem`、`options[]`、`answer`、`analysis`、`video`。 | -| 生成评价 | `http://8.160.178.88/fastapi/api/v1/teaching/evaluation` | `/api/v1/teaching/evaluation` | `POST` | Body:`case_id` 必填;`answers[]` 必填;`answers[].question_id` 必填;`answers[].selected_answer` 必填;`score_type` 选填。 | `data.session_id`、`data.evaluation_id`、`data.total_score`、`data.dimension_scores[]`、`data.score_details[]`、`data.overall_comment`。 | -| 获取评价(详情) | `http://8.160.178.88/fastapi/api/v1/evaluations/{evaluation_id}` | `/api/v1/evaluations/{evaluation_id}` | `GET` | Path:`evaluation_id` 必填。 | 教学互动评价详情,结构与训练评价一致。 | -| 下载 PDF | `http://8.160.178.88/fastapi/api/v1/evaluations/{evaluation_id}/download-pdf` | `/api/v1/evaluations/{evaluation_id}/download-pdf` | `GET` | Path:`evaluation_id` 必填。 | 返回 `application/pdf` 文件流。 | +| 王主任练习提示 | `http://8.160.178.88/fastapi/api/v1/sessions/{session_id}/hints/stream` | `/api/v1/sessions/{session_id}/hints/stream` | `POST` | Path:`session_id` 必填;Body:`scope` 可选,默认 `current_conversation`;`last_user_message` 可选 | SSE:`hint_delta` 提示文本增量;`hint_done` 完成;`error` 错误 | -获取教学列表返回示例: - -```json -{ - "code": "OK", - "message": "success", - "data": { - "case": { - "case_id": 2, - "title": "支气管肺炎 - 6岁男性患儿", - "chief_complaint": "发热、咳嗽4天,喘息1天" - }, - "teaching_goal": "围绕儿科肺炎问诊、检查选择、诊断依据、治疗决策和医患沟通完成互动训练。", - "questions": [ - { - "question_id": "q1", - "question_type": "single_choice", - "stem": "该患儿最需要优先关注的病情严重程度指标是?", - "options": [ - {"key": "A", "text": "体温峰值"}, - {"key": "B", "text": "血氧饱和度"} - ], - "answer": "B", - "analysis": "血氧饱和度能帮助判断低氧和肺炎严重程度。", - "video": {"title": "儿童肺炎教学示例视频", "url": ""}, - "knowledge_points": ["严重程度评估", "血氧判断"] 
- } - ] - } -} -``` - -生成教学评价请求示例: - -```json -{ - "case_id": 2, - "score_type": "percentage", - "answers": [ - {"question_id": "q1", "selected_answer": "B"}, - {"question_id": "q2", "selected_answer": "A"} - ] -} -``` - -## 6. 个人中心接口 +### 3.6 检查列表和结果 | 接口名称 | url | api | methods | params(入参) | response(返回参数) | |---|---|---|---|---|---| -| 训练记录列表 | `http://8.160.178.88/fastapi/api/v1/evaluations` | `/api/v1/evaluations` | `GET` | Header:`Authorization` 必填。 | `data.items[]` 当前用户完整训练后的评价记录,包含 `evaluation_id`、`case_title`、`score_type`、`total_score`、`created_at`、`pdf_exported`。 | -| 训练记录详情(评价详情) | `http://8.160.178.88/fastapi/api/v1/evaluations/{evaluation_id}` | `/api/v1/evaluations/{evaluation_id}` | `GET` | Path:`evaluation_id` 必填。 | 完整评价详情,训练和教学互动评价共用。 | +| 体格检查列表获取 | `http://8.160.178.88/fastapi/api/v1/sessions/{session_id}/physical-exams` | `/api/v1/sessions/{session_id}/physical-exams` | `GET` | Path:`session_id` 必填 | `data.items[]` 当前病例可用体格检查项 | +| 辅助检查列表获取 | `http://8.160.178.88/fastapi/api/v1/sessions/{session_id}/auxiliary-exams` | `/api/v1/sessions/{session_id}/auxiliary-exams` | `GET` | Path:`session_id` 必填 | `data.items[]` 当前病例可用辅助检查项 | +| 体格检查某项结果 | `http://8.160.178.88/fastapi/api/v1/sessions/{session_id}/physical-exams/{item_code}` | `/api/v1/sessions/{session_id}/physical-exams/{item_code}` | `POST` | Path:`session_id` 必填;`item_code` 必填,如 `lung_auscultation` | `data.item_code`;`data.item_name`;`data.result_text`;`data.already_ordered`;`data.context_written` | +| 辅助检查某项结果 | `http://8.160.178.88/fastapi/api/v1/sessions/{session_id}/auxiliary-exams/{item_code}` | `/api/v1/sessions/{session_id}/auxiliary-exams/{item_code}` | `POST` | Path:`session_id` 必填;`item_code` 必填,如 `blood_routine`、`crp`、`chest_xray`、`oxygen_saturation` | 同上 | -训练记录列表返回示例: +### 3.7 阶段提交 -```json -{ - "code": "OK", - "message": "success", - "data": { - "items": [ - { - "evaluation_id": 101, - "case_title": "支气管肺炎 - 6岁男性患儿", - "score_type": "percentage", - "total_score": 82, - 
"created_at": "2026-06-08T12:00:00", - "pdf_exported": true - } - ] - } -} -``` +| 接口名称 | url | api | methods | params(入参) | response(返回参数) | +|---|---|---|---|---|---| +| 完成问诊 | `http://8.160.178.88/fastapi/api/v1/sessions/{session_id}/complete-inquiry` | `/api/v1/sessions/{session_id}/complete-inquiry` | `POST` | Path:`session_id` 必填 | `data.status=diagnosis` | +| 提交诊断 | `http://8.160.178.88/fastapi/api/v1/sessions/{session_id}/diagnosis` | `/api/v1/sessions/{session_id}/diagnosis` | `POST` | Path:`session_id` 必填;Body:`primary_diagnosis` 必填;`differential_diagnoses` 可选数组;`diagnosis_basis` 必填 | `data.status=treatment` | +| 提交治疗 | `http://8.160.178.88/fastapi/api/v1/sessions/{session_id}/treatment` | `/api/v1/sessions/{session_id}/treatment` | `POST` | Path:`session_id` 必填;Body:`treatment_principle`、`treatment_measures`、`risk_plan`、`communication`、`follow_up` 均必填 | `data.status=evaluating` | +| 生成评价 | `http://8.160.178.88/fastapi/api/v1/sessions/{session_id}/evaluation` | `/api/v1/sessions/{session_id}/evaluation` | `POST` | Path:`session_id` 必填;Body:`score_type` 可选,允许 `percentage`、`five_point` | `data.evaluation_id`;`data.total_score`;`data.score_details[]`;`data.report_summary` | -## 7. PDF 下载前端写法 +## 4. 教学互动接口 -因为下载接口需要 Bearer token,前端不能直接使用普通 ``。使用 `fetch` 获取 blob 后触发下载: +| 接口名称 | url | api | methods | params(入参) | response(返回参数) | +|---|---|---|---|---|---| +| 获取教学列表 | `http://8.160.178.88/fastapi/api/v1/teaching/cases/{case_id}/items` | `/api/v1/teaching/cases/{case_id}/items` | `GET` | Path:`case_id` 必填 | `data.case` 病例信息;`data.questions[]` 题目、选项、答案、解析文本、视频 | +| 教学互动生成评价 | `http://8.160.178.88/fastapi/api/v1/teaching/evaluation` | `/api/v1/teaching/evaluation` | `POST` | Body:`case_id` 必填;`answers[]` 必填,包含 `question_id` 和 `selected_answer`;`score_type` 可选 | `data.evaluation_id`;`data.session_id`;`data.total_score`;`data.score_details[]` | + +## 5. 
评价和个人中心接口 + +| 接口名称 | url | api | methods | params(入参) | response(返回参数) | +|---|---|---|---|---|---| +| 训练记录列表 | `http://8.160.178.88/fastapi/api/v1/evaluations` | `/api/v1/evaluations` | `GET` | Query:`limit` 可选,默认20;`offset` 可选,默认0 | `data.items[]` 当前用户完整训练后的评价记录 | +| 训练记录详情 / 评价详情 | `http://8.160.178.88/fastapi/api/v1/evaluations/{evaluation_id}` | `/api/v1/evaluations/{evaluation_id}` | `GET` | Path:`evaluation_id` 必填 | 完整评价详情,训练和教学互动共用 | +| 导出 PDF | `http://8.160.178.88/fastapi/api/v1/evaluations/{evaluation_id}/export-pdf` | `/api/v1/evaluations/{evaluation_id}/export-pdf` | `POST` | Path:`evaluation_id` 必填 | `data.file_path`;`data.exported_at` | +| 下载 PDF | `http://8.160.178.88/fastapi/api/v1/evaluations/{evaluation_id}/download-pdf` | `/api/v1/evaluations/{evaluation_id}/download-pdf` | `GET` | Path:`evaluation_id` 必填 | `application/pdf` 文件流,浏览器可直接下载 | + +PDF 下载前端写法: ```ts async function downloadEvaluationPdf(baseUrl: string, token: string, evaluationId: number) { - const response = await fetch(`${baseUrl}/evaluations/${evaluationId}/download-pdf`, { + const response = await fetch(`${baseUrl}/api/v1/evaluations/${evaluationId}/download-pdf`, { method: "GET", headers: { Authorization: `Bearer ${token}`, "X-Entry-Scene": "vue_frontend", }, }); - - if (!response.ok) { - const error = await response.json().catch(() => null); - throw new Error(error?.message || `PDF 下载失败:${response.status}`); - } - + if (!response.ok) throw new Error("PDF 下载失败"); const blob = await response.blob(); const url = URL.createObjectURL(blob); - const a = document.createElement("a"); - a.href = url; - a.download = `evaluation_${evaluationId}.pdf`; - document.body.appendChild(a); - a.click(); - a.remove(); + const link = document.createElement("a"); + link.href = url; + link.download = `evaluation-${evaluationId}.pdf`; + link.click(); URL.revokeObjectURL(url); } ``` +## 6. 
AI 学习助手接口 + +该接口用于普通用户医学知识问答。后端优先检索本机构知识库;如果机构知识库未初始化、Milvus / embedding 暂不可用或未命中来源,接口仍会继续调用 LLM,回答开头会声明“未检索到本机构知识库参考,以下为大模型通用学习回答。” + +| 接口名称 | url | api | methods | params(入参) | response(返回参数) | +|---|---|---|---|---|---| +| AI学习助手流式问答 | `http://8.160.178.88/fastapi/api/v1/learning-assistant/chat/stream` | `/api/v1/learning-assistant/chat/stream` | `POST` | Body:`question` 必填,2-1000字;`top_k` 可选,1-10;`score_threshold` 可选,0-1 | SSE:`retrieval_done` 检索状态;`answer_delta` 回答增量;`answer_done` 完成;`error` 错误 | + +请求示例: + +```json +{ + "question": "支气管肺炎的典型临床表现有哪些?", + "top_k": 5, + "score_threshold": 0.35 +} +``` + +命中知识库时: + +```text +event: retrieval_done +data: {"retrieval_hit":true,"sources":[{"document_id":1,"document_title":"诊断学第十版","file_name":"19.《诊断学》十版(1).pdf","page_start":123,"page_end":124,"chunk_uid":"doc1_chunk35_abcd1234","score":0.78,"quote":"原文片段..."}],"retrieval_error":null,"embedding_latency_ms":320,"search_latency_ms":40} + +event: answer_delta +data: {"delta":"支气管肺炎常见表现包括发热、咳嗽、喘息...【来源1】"} + +event: answer_done +data: {"model":"deepseek-chat","total_latency_ms":2300} +``` + +未命中或知识库不可用时: + +```text +event: retrieval_done +data: {"retrieval_hit":false,"sources":[],"retrieval_error":"当前机构知识库暂未初始化或检索不可用,已转为大模型通用学习回答。","embedding_latency_ms":null,"search_latency_ms":null} + +event: answer_delta +data: {"delta":"未检索到本机构知识库参考,以下为大模型通用学习回答。..."} + +event: answer_done +data: {"model":"deepseek-chat","total_latency_ms":1800} +``` + +## 7. 
后台预留:内容管理员知识库接口 + +该组接口是后台内容管理员能力,学生端不展示上传入口。当前阶段保留接口和数据结构,后续生产环境接入完整 PDF 解析、分片、embedding、Milvus 构建和异步任务。 + +| 接口名称 | url | api | methods | params(入参) | response(返回参数) | +|---|---|---|---|---|---| +| 上传知识库 PDF | `http://8.160.178.88/fastapi/api/v1/knowledge-admin/documents/upload` | `/api/v1/knowledge-admin/documents/upload` | `POST` | `multipart/form-data`:`file` 必填,PDF 文件;`document_title` 可选;`document_category` 可选,允许 `textbook`、`guideline`、`manual`、`other`;`version` 可选,默认 `v1` | `data.document_id`;`data.task_id`;`data.status`;`data.parse_status`;`data.embedding_status`;`data.chunk_count`;`data.collection_name` | +| 知识库文档列表 | `http://8.160.178.88/fastapi/api/v1/knowledge-admin/documents` | `/api/v1/knowledge-admin/documents` | `GET` | Header:内容管理员 token | `data.items[]` 本机构上传文档 | +| 知识库文档详情 | `http://8.160.178.88/fastapi/api/v1/knowledge-admin/documents/{document_id}` | `/api/v1/knowledge-admin/documents/{document_id}` | `GET` | Path:`document_id` 必填;Header:内容管理员 token | 文档构建状态、分片数、错误信息 | + ## 8. 常见错误码 -| HTTP | code | 说明 | +| HTTP状态 | code | 含义 | |---:|---|---| -| 401 | `AUTH_CREDENTIAL_REQUIRED` | 缺少 Authorization。 | -| 401 | `AUTH_USER_INVALID` | token 无效、过期或 Django 返回非 200。 | -| 403 | `AUTH_USER_DISABLED` | Django 用户状态被禁用。 | -| 503 | `AUTH_USER_CENTER_UNAVAILABLE` | Django 用户中心超时或不可达。 | -| 404 | `CASE_NOT_FOUND` | 病例不存在、未发布或已停用。 | -| 404 | `SESSION_NOT_FOUND` | 会话不存在或不属于当前用户。 | -| 400 | `SESSION_STATUS_INVALID` | 当前状态不允许执行该操作。 | -| 400 | `INQUIRY_REQUIRED` | 完成问诊前没有医生提问。 | -| 400 | `DIAGNOSIS_REQUIRED` | 提交治疗前没有提交诊断。 | -| 400 | `TREATMENT_REQUIRED` | 生成评价前没有提交治疗。 | -| 404 | `ORDER_ITEM_NOT_FOUND` | 当前病例不存在该检查项。 | -| 400 | `ORDER_ITEM_TYPE_MISMATCH` | 检查接口类型和检查项类型不匹配。 | -| 404 | `EVALUATION_NOT_FOUND` | 评价不存在或不属于当前用户。 | -| SSE error | `LLM_STREAM_TIMEOUT` | 流式问诊首段或总耗时超时。 | -| SSE error | `LLM_STREAM_FAILED` | 流式模型调用失败。 | -| 500 | `PDF_EXPORT_FAILED` | PDF 生成失败。 | - -## 9. 
当前保留接口清单 - -当前后端业务模块: - -- 训练页面:推荐配置、训练配置、新建会话、流式会话、王主任练习提示、检查列表、检查结果、完成问诊、诊断、治疗、评价、评价详情、PDF 下载。 -- 教学互动:教学列表、教学评价、评价详情、PDF 下载。 -- 个人中心:训练记录列表、训练记录详情。 +| 401 | `AUTH_CREDENTIAL_REQUIRED` | 缺少 `Authorization` | +| 401 | `AUTH_USER_INVALID` | token 无效、过期或 Django 返回非 200 | +| 403 | `AUTH_USER_DISABLED` | Django 用户状态被禁用 | +| 403 | `KNOWLEDGE_ADMIN_FORBIDDEN` | 当前用户不是内容管理员,不能使用后台知识库上传接口 | +| 404 | `CASE_NOT_FOUND` | 病例不存在、未发布或已停用 | +| 404 | `SESSION_NOT_FOUND` | 会话不存在或不属于当前用户 | +| 400 | `SESSION_STATUS_INVALID` | 当前状态不允许执行该操作 | +| 404 | `ORDER_ITEM_NOT_FOUND` | 当前病例不存在该检查项 | +| 400 | `ORDER_ITEM_TYPE_MISMATCH` | 体格检查 / 辅助检查接口与检查项类型不匹配 | +| 404 | `EVALUATION_NOT_FOUND` | 评价不存在或不属于当前用户 | +| 502 | `LLM_STREAM_FAILED` | LLM 流式调用失败 | +| 503 | `AUTH_USER_CENTER_UNAVAILABLE` | Django 用户中心超时或不可达 | diff --git a/requirements.txt b/requirements.txt index 773ecba..2614425 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,3 +8,7 @@ httpx>=0.27.0 python-dotenv>=1.0.0 reportlab>=4.2.0 redis>=5.0.0 +pymilvus>=2.6,<2.7 +celery>=5.4,<6 +PyMuPDF>=1.24,<2 +python-multipart>=0.0.9 diff --git a/tests/test_api_contract.py b/tests/test_api_contract.py index d45f4ee..c770b8d 100644 --- a/tests/test_api_contract.py +++ b/tests/test_api_contract.py @@ -10,6 +10,10 @@ os.environ["REPORT_STORAGE_DIR"] = str(Path(tempfile.gettempdir()) / "medical_ag os.environ.setdefault("RUNTIME_MEMORY_BACKEND", "memory") os.environ.setdefault("LLM_MOCK_ENABLED", "true") os.environ.setdefault("AUTH_USER_ME_URL", "http://django-user-center.test/api/user/users/me/") +os.environ.setdefault("MILVUS_URI", "mock://milvus") +os.environ.setdefault("EMBEDDING_PROVIDER", "mock") +os.environ.setdefault("KNOWLEDGE_INGESTION_SYNC", "true") +os.environ.setdefault("KNOWLEDGE_STORAGE_DIR", str(Path(tempfile.gettempdir()) / "medical_agent_test_knowledge")) sys.path.insert(0, str(Path(__file__).resolve().parents[1])) @@ -36,20 +40,33 @@ def run_api_contract_tests() -> None: from app.core.exceptions import AppError raise 
AppError("AUTH_CREDENTIAL_REQUIRED", "Authorization header is required", 401) - user_id = "api_user_002" if "api_user_002_token" in authorization else "api_user_001" + if "content_admin_token" in authorization: + user_id = "content_admin_001" + role = "content_admin" + else: + user_id = "api_user_002" if "api_user_002_token" in authorization else "api_user_001" + role = "student" return AuthenticatedUser( user_id=user_id, username=f"{user_id}_name", display_name="Swagger测试", - role="student", + role=role, tenant_id="1", + institution_id=1, + institution_name="test institution", + department_id=1, + department_name="pediatrics", status=1, profile={ "id": user_id, "username": f"{user_id}_name", "real_name": "Swagger测试", - "role_type": "student", + "role_type": role, "institution": 1, + "institution_id": 1, + "department": 1, + "department_id": 1, + "department_name": "pediatrics", "institution_name": "测试机构", "status": 1, }, @@ -97,6 +114,65 @@ def run_api_contract_tests() -> None: assert "/api/v1/evaluations/{evaluation_id}/download-pdf" in openapi_payload["paths"] assert "/api/v1/teaching/cases/{case_id}/items" in openapi_payload["paths"] assert "/api/v1/teaching/evaluation" in openapi_payload["paths"] + assert "/api/v1/knowledge-admin/documents/upload" in openapi_payload["paths"] + assert "/api/v1/learning-assistant/chat" not in openapi_payload["paths"] + assert "/api/v1/learning-assistant/chat/stream" in openapi_payload["paths"] + + with client.stream( + "POST", + "/api/v1/learning-assistant/chat/stream", + headers=headers, + json={"question": "支气管肺炎有哪些常见表现?", "top_k": 1}, + ) as no_kb_stream: + assert no_kb_stream.status_code == 200 + no_kb_stream_text = "".join(no_kb_stream.iter_text()) + assert "event: retrieval_done" in no_kb_stream_text + assert '"retrieval_hit": false' in no_kb_stream_text + assert "event: answer_delta" in no_kb_stream_text + assert "event: answer_done" in no_kb_stream_text + assert "event: error" not in no_kb_stream_text + + from 
app.integrations.pdf_parser import ParsedPdfPage, PdfParser + + def fake_pdf_parse(self, file_path): # noqa: ARG001 + """知识库测试:用稳定页文本替代真实 PDF 解析,避免测试依赖外部文件。""" + return [ + ParsedPdfPage( + page_number=12, + text="支气管肺炎常见表现包括发热、咳嗽、喘息和肺部湿啰音。血氧饱和度下降提示病情可能加重。", + ) + ] + + PdfParser.parse = fake_pdf_parse + admin_headers = {"Authorization": "Bearer content_admin_token", "X-Entry-Scene": "api_test"} + upload = client.post( + "/api/v1/knowledge-admin/documents/upload", + headers=admin_headers, + data={"document_title": "诊断学第十版", "document_category": "textbook", "version": "v1"}, + files={"file": ("diagnostics.pdf", b"%PDF-1.4\n%%EOF", "application/pdf")}, + ) + assert upload.status_code == 200 + assert upload.json()["data"]["status"] == "ready" + assert upload.json()["data"]["chunk_count"] >= 1 + document_id = upload.json()["data"]["document_id"] + + document_detail = client.get(f"/api/v1/knowledge-admin/documents/{document_id}", headers=admin_headers) + assert document_detail.status_code == 200 + assert document_detail.json()["data"]["document_title"] == "诊断学第十版" + + with client.stream( + "POST", + "/api/v1/learning-assistant/chat/stream", + headers=headers, + json={"question": "血氧下降说明什么?", "top_k": 1}, + ) as rag_stream: + assert rag_stream.status_code == 200 + rag_stream_text = "".join(rag_stream.iter_text()) + assert "event: retrieval_done" in rag_stream_text + assert '"retrieval_hit": true' in rag_stream_text + assert '"page_start": 12' in rag_stream_text + assert "event: answer_delta" in rag_stream_text + assert "event: answer_done" in rag_stream_text cases = client.get("/api/v1/cases", headers=headers) assert cases.status_code == 200