feat: add streaming learning assistant and knowledge base scaffolding

This commit is contained in:
刘金宝
2026-06-10 09:32:36 +08:00
parent f0cdc454b3
commit 89258ab448
31 changed files with 2021 additions and 330 deletions
+154
View File
@@ -0,0 +1,154 @@
import hashlib
import re
from dataclasses import dataclass
from app.integrations.pdf_parser import ParsedPdfPage
from app.models.knowledge_base import KbKnowledgeChunk
@dataclass(frozen=True)
class ChunkDraft:
    """Chunk draft: intermediate result of splitting PDF text, later persisted to MySQL and Milvus."""

    chunk_index: int  # position of the chunk within its document (assigned when the draft is created)
    page_start: int  # page number where the chunk's text begins
    page_end: int  # page number of the last piece added to the chunk
    section_title: str | None  # heading in effect for the chunk's text, if one was detected
    text: str  # chunk text; buffered pieces joined with "\n"


class DocumentChunkService:
    """Document chunking service: page-preserving chunking for textbook / guideline PDFs."""

    def __init__(self, chunk_size: int = 1100, chunk_overlap: int = 180) -> None:
        """Store the chunking window.

        Args:
            chunk_size: Character budget used to decide when a chunk is closed.
            chunk_overlap: Number of trailing characters carried into the next chunk
                (0 disables the overlap).
        """
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap

    def build_chunks(self, pages: list[ParsedPdfPage]) -> list[ChunkDraft]:
        """Split parsed pages into chunk drafts while tracking page ranges and headings.

        Paragraphs are accumulated into a buffer; when adding the next piece would
        exceed ``chunk_size`` the buffer is emitted as a chunk and (optionally) its
        tail is carried over as overlap. Because the size check runs before the
        piece is appended, a chunk that starts with an overlap tail can exceed
        ``chunk_size`` by roughly the overlap length.

        Args:
            pages: Parsed pages; each must expose ``text`` and ``page_number``.

        Returns:
            Drafts in document order; drafts whose text is empty are dropped.
        """
        drafts: list[ChunkDraft] = []
        buffer: list[str] = []
        page_start: int | None = None
        page_end: int | None = None
        current_title: str | None = None  # latest heading seen while scanning
        buffer_title: str | None = None  # heading in effect for the text already buffered

        for page in pages:
            for paragraph in self._split_paragraphs(page.text):
                detected_title = self._detect_title(paragraph)
                if detected_title:
                    current_title = detected_title
                for piece in self._split_long_text(paragraph):
                    candidate = "\n".join([*buffer, piece]).strip()
                    if buffer and len(candidate) > self.chunk_size:
                        drafts.append(
                            ChunkDraft(
                                chunk_index=len(drafts),
                                page_start=page_start if page_start is not None else page.page_number,
                                page_end=page_end if page_end is not None else page.page_number,
                                # Use the heading that applied to the buffered text, not the
                                # heading of the paragraph that triggered the flush.
                                section_title=buffer_title,
                                text="\n".join(buffer).strip(),
                            )
                        )
                        buffer = self._overlap_tail(buffer)
                        # The overlap tail is taken from the end of the flushed chunk, so
                        # the next chunk begins on the page where that chunk ended.
                        page_start = page_end if buffer else None
                    if not buffer:
                        page_start = page.page_number
                    page_end = page.page_number
                    buffer.append(piece)
                    buffer_title = current_title

        if buffer:
            last_page = pages[-1].page_number
            drafts.append(
                ChunkDraft(
                    chunk_index=len(drafts),
                    page_start=page_start if page_start is not None else last_page,
                    page_end=page_end if page_end is not None else last_page,
                    section_title=buffer_title,
                    text="\n".join(buffer).strip(),
                )
            )
        return [draft for draft in drafts if draft.text]

    def to_models(
        self,
        *,
        institution_id: int,
        document_id: int,
        collection_name: str,
        embedding_model: str,
        drafts: list[ChunkDraft],
    ) -> list[KbKnowledgeChunk]:
        """Convert chunk drafts into ORM rows.

        ``chunk_uid`` is built from the document id, the chunk index and the first
        12 hex characters of the text's SHA-256, and is reused as ``vector_id``.

        Args:
            institution_id: Owning institution.
            document_id: Source document.
            collection_name: Vector collection the chunk belongs to.
            embedding_model: Name of the embedding model recorded on each row.
            drafts: Drafts produced by ``build_chunks``.

        Returns:
            One ``KbKnowledgeChunk`` per draft, in the same order.
        """
        rows: list[KbKnowledgeChunk] = []
        for draft in drafts:
            chunk_hash = hashlib.sha256(draft.text.encode("utf-8")).hexdigest()
            chunk_uid = f"doc{document_id}_chunk{draft.chunk_index}_{chunk_hash[:12]}"
            rows.append(
                KbKnowledgeChunk(
                    institution_id=institution_id,
                    document_id=document_id,
                    chunk_uid=chunk_uid,
                    chunk_index=draft.chunk_index,
                    page_start=draft.page_start,
                    page_end=draft.page_end,
                    section_title=draft.section_title,
                    chunk_text=draft.text,
                    chunk_hash=chunk_hash,
                    # Rough estimate: half the character count, never below 1.
                    token_count=max(1, len(draft.text) // 2),
                    vector_id=chunk_uid,
                    collection_name=collection_name,
                    embedding_model=embedding_model,
                    metadata_={"chunking": "page_semantic_window", "chunk_size": self.chunk_size, "overlap": self.chunk_overlap},
                )
            )
        return rows

    def _split_paragraphs(self, text: str) -> list[str]:
        """Split page text on runs of newlines and drop blank fragments."""
        parts = re.split(r"\n{1,}", text)
        return [part.strip() for part in parts if part.strip()]

    def _split_long_text(self, text: str) -> list[str]:
        """Break an over-long paragraph into pieces.

        Text that fits in ``chunk_size`` is returned unchanged. Otherwise it is
        split after sentence-ending punctuation and sentences are packed into
        pieces (each new piece starts with the previous piece's overlap tail).
        Pieces that are still too long fall back to fixed character windows that
        advance by ``chunk_size - chunk_overlap``.
        """
        if len(text) <= self.chunk_size:
            return [text]

        sentences = re.split(r"(?<=[。!?;;.!?])", text)
        pieces: list[str] = []
        current = ""
        for sentence in sentences:
            if len(current) + len(sentence) > self.chunk_size and current:
                pieces.append(current.strip())
                current = current[-self.chunk_overlap :] if self.chunk_overlap else ""
            current += sentence
        if current.strip():
            pieces.append(current.strip())

        step = max(1, self.chunk_size - self.chunk_overlap)
        final: list[str] = []
        for piece in pieces:
            if len(piece) <= self.chunk_size:
                final.append(piece)
                continue
            start = 0
            while start < len(piece):
                final.append(piece[start : start + self.chunk_size])
                if start + self.chunk_size >= len(piece):
                    # This window already reaches the end of the piece; another
                    # step would only repeat text that is fully contained in it.
                    break
                start += step
        return final

    def _overlap_tail(self, buffer: list[str]) -> list[str]:
        """Return the last ``chunk_overlap`` characters of the buffer as a one-item list.

        Returns an empty list when overlap is disabled or the buffer text is empty.
        """
        if not self.chunk_overlap:
            return []
        text = "\n".join(buffer).strip()
        tail = text[-self.chunk_overlap :]
        return [tail] if tail else []

    def _detect_title(self, paragraph: str) -> str | None:
        """Return the stripped paragraph if it looks like a chapter/section heading.

        Paragraphs longer than 80 characters are never treated as headings.
        """
        compact = paragraph.strip()
        if len(compact) > 80:
            return None
        title_patterns = [
            r"^第[一二三四五六七八九十百0-9]+[章节篇]",
            r"^[一二三四五六七八九十]+[、..]",
            r"^\d+(\.\d+){0,3}\s+",
        ]
        return compact if any(re.search(pattern, compact) for pattern in title_patterns) else None
+177
View File
@@ -0,0 +1,177 @@
import hashlib
from pathlib import Path
from fastapi import UploadFile
from sqlalchemy.orm import Session
from app.core.config import settings
from app.core.context import UserContext
from app.core.exceptions import AppError
from app.integrations.milvus_adapter import MilvusVectorStore
from app.integrations.pdf_parser import PdfParser
from app.repositories.knowledge_base_repository import KnowledgeBaseRepository
from app.schemas.knowledge_admin import KnowledgeDocumentUploadResponse
from app.services.document_chunk_service import DocumentChunkService
from app.services.embedding_service import EmbeddingService
from app.services.knowledge_space_service import KnowledgeSpaceService
class DocumentIngestionService:
    """Knowledge ingestion service: PDF upload, parsing, chunking, embedding and Milvus writes."""

    def __init__(
        self,
        db: Session,
        *,
        parser: PdfParser | None = None,
        chunker: DocumentChunkService | None = None,
        embedding_service: EmbeddingService | None = None,
        vector_store: MilvusVectorStore | None = None,
    ) -> None:
        """Build the repository and collaborators; each collaborator may be injected."""
        self.db = db
        self.repo = KnowledgeBaseRepository(db)
        self.space_service = KnowledgeSpaceService(self.repo)
        self.parser = parser or PdfParser()
        self.chunker = chunker or DocumentChunkService()
        self.embedding_service = embedding_service or EmbeddingService()
        self.vector_store = vector_store or MilvusVectorStore()

    async def upload_pdf(
        self,
        ctx: UserContext,
        file: UploadFile,
        *,
        document_title: str | None,
        document_category: str,
        version: str,
    ) -> KnowledgeDocumentUploadResponse:
        """Accept a PDF upload, register the document and trigger its ingestion task.

        A file whose SHA-256 already exists for the institution is reported as a
        duplicate. If that earlier document is in ``failed`` state, a new
        ingestion task is created *and dispatched* so the retry actually runs.

        Args:
            ctx: Caller context; must carry an allowed role and an institution id.
            file: Uploaded file.
            document_title: Optional title; defaults to the file name's stem.
            document_category: Category stored on the document.
            version: Version label stored on the document.

        Returns:
            Upload response describing the document, its task and current statuses.

        Raises:
            AppError: On permission, validation, enqueue or (sync mode) ingestion failure.
        """
        self.space_service.ensure_content_admin(ctx)
        space = self.space_service.get_or_create_space(ctx)
        content = await file.read()
        self._validate_pdf(file, content)
        file_sha256 = hashlib.sha256(content).hexdigest()

        existing = self.repo.get_document_by_hash(space.institution_id, file_sha256)
        if existing:
            retry_task_id: int | None = None
            if existing.status == "failed":
                retry_task = self.repo.create_ingestion_task(existing)
                retry_task_id = retry_task.id
                # Without dispatching, the retry task would be created but never executed.
                await self._dispatch_ingestion(existing.id, retry_task_id)
            return self._upload_response(existing, retry_task_id, space.collection_name, duplicate=True)

        storage_path = self._save_file(space.institution_id, file.filename or "knowledge.pdf", content)
        document = self.repo.create_document(
            institution_id=space.institution_id,
            uploaded_by=ctx.user_id,
            file_name=file.filename or storage_path.name,
            file_sha256=file_sha256,
            file_size=len(content),
            file_path=str(storage_path),
            document_title=document_title or Path(file.filename or "").stem or None,
            document_category=document_category,
            version=version,
        )
        task = self.repo.create_ingestion_task(document)
        await self._dispatch_ingestion(document.id, task.id)
        return self._upload_response(document, task.id, space.collection_name, duplicate=False)

    async def _dispatch_ingestion(self, document_id: int, task_id: int) -> None:
        """Run ingestion inline or hand it to the async worker, depending on settings."""
        if settings.knowledge_ingestion_sync:
            await self.ingest_document(document_id, task_id)
        else:
            self._enqueue_async_task(document_id, task_id)

    @staticmethod
    def _upload_response(
        document,
        task_id: int | None,
        collection_name: str,
        *,
        duplicate: bool,
    ) -> KnowledgeDocumentUploadResponse:
        """Build the upload response from the document's current state."""
        return KnowledgeDocumentUploadResponse(
            document_id=document.id,
            task_id=task_id,
            duplicate=duplicate,
            status=document.status,
            parse_status=document.parse_status,
            embedding_status=document.embedding_status,
            chunk_count=document.chunk_count,
            collection_name=collection_name,
        )

    async def ingest_document(self, document_id: int, task_id: int | None = None) -> None:
        """Turn an uploaded PDF into stored chunks and vectors.

        Steps: parse the PDF, build chunks, embed them, upsert the vectors, then
        replace the document's chunk rows. Task progress is updated along the way
        when a task is supplied.

        Args:
            document_id: Document to ingest.
            task_id: Optional ingestion task to update with progress.

        Raises:
            AppError: ``KNOWLEDGE_DOCUMENT_NOT_FOUND`` when the document is missing;
                any ``AppError`` raised by a step is re-raised as is; any other
                failure is wrapped as ``KNOWLEDGE_INGESTION_FAILED``.
        """
        document = self.repo.get_document(document_id)
        if not document:
            raise AppError("KNOWLEDGE_DOCUMENT_NOT_FOUND", "knowledge document not found", 404)
        task = self.repo.get_ingestion_task(task_id) if task_id else None
        try:
            if task:
                self.repo.update_task(task, status="running", progress=5, current_step="parse_pdf")
            document.status = "processing"
            document.parse_status = "running"
            self.db.flush()

            pages = self.parser.parse(document.file_path)
            space = self.repo.get_or_create_space(document.institution_id, None)
            drafts = self.chunker.build_chunks(pages)
            chunks = self.chunker.to_models(
                institution_id=document.institution_id,
                document_id=document.id,
                collection_name=space.collection_name,
                embedding_model=settings.embedding_model,
                drafts=drafts,
            )

            if task:
                self.repo.update_task(task, progress=35, current_step="embed_chunks")
            document.parse_status = "success"
            document.embedding_status = "running"
            self.db.flush()

            vectors, _embedding_latency_ms = await self.embedding_service.embed_texts([chunk.chunk_text for chunk in chunks])
            if len(vectors) != len(chunks):
                # zip() would silently drop the unmatched chunks and leave rows
                # without vectors; fail the ingestion instead.
                raise ValueError(f"embedding returned {len(vectors)} vectors for {len(chunks)} chunks")

            if task:
                self.repo.update_task(task, progress=75, current_step="write_vectors")
            self.vector_store.upsert_vectors(
                space.collection_name,
                [(chunk.chunk_uid, vector) for chunk, vector in zip(chunks, vectors)],
            )
            self.repo.replace_chunks(document, chunks)
            document.status = "ready"
            document.embedding_status = "success"
            if task:
                self.repo.update_task(task, status="success", progress=100, current_step="completed")
        except Exception as exc:
            # NOTE(review): no commit is issued here; whether the failure state is
            # persisted depends on the caller's transaction handling - confirm.
            document.status = "failed"
            document.error_message = str(exc)[:2000]
            # Keep a successful parse status; only the later stage failed in that case.
            document.parse_status = document.parse_status if document.parse_status == "success" else "failed"
            document.embedding_status = "failed"
            if task:
                self.repo.update_task(task, status="failed", progress=100, current_step="failed", error_message=str(exc)[:2000])
            if isinstance(exc, AppError):
                raise
            raise AppError("KNOWLEDGE_INGESTION_FAILED", "knowledge document ingestion failed", 500) from exc

    def _validate_pdf(self, file: UploadFile, content: bytes) -> None:
        """Validate an upload: non-empty, within the size limit, and a PDF.

        The type check passes when either the file name ends with ``.pdf`` or the
        content type is ``application/pdf`` / ``application/octet-stream``; the
        ``%PDF`` magic-bytes check is applied in all cases.

        Raises:
            AppError: 422 for empty, wrong-type or non-PDF content; 413 when too large.
        """
        if not content:
            raise AppError("UPLOAD_FILE_EMPTY", "uploaded file is empty", 422)
        max_bytes = settings.knowledge_max_upload_mb * 1024 * 1024
        if len(content) > max_bytes:
            raise AppError("UPLOAD_FILE_TOO_LARGE", f"uploaded file exceeds {settings.knowledge_max_upload_mb}MB", 413)
        filename = (file.filename or "").lower()
        if not filename.endswith(".pdf") and file.content_type not in {"application/pdf", "application/octet-stream"}:
            raise AppError("UPLOAD_FILE_TYPE_INVALID", "only pdf file is supported", 422)
        if not content.startswith(b"%PDF"):
            raise AppError("UPLOAD_FILE_NOT_PDF", "uploaded file is not a valid pdf", 422)

    def _save_file(self, institution_id: int, filename: str, content: bytes) -> Path:
        """Write the raw PDF under ``<storage_dir>/raw/<institution_id>/`` and return its path.

        The stored name is the first 16 hex characters of the content's SHA-256,
        an underscore, and the file name with every character other than
        alphanumerics, ``.``, ``_`` and ``-`` replaced by ``_``.
        """
        safe_name = "".join(ch if ch.isalnum() or ch in {".", "_", "-"} else "_" for ch in filename)
        storage_dir = Path(settings.knowledge_storage_dir) / "raw" / str(institution_id)
        storage_dir.mkdir(parents=True, exist_ok=True)
        target = storage_dir / f"{hashlib.sha256(content).hexdigest()[:16]}_{safe_name}"
        target.write_bytes(content)
        return target

    def _enqueue_async_task(self, document_id: int, task_id: int) -> None:
        """Hand the ingestion job to the Celery task.

        Raises:
            AppError: ``KNOWLEDGE_TASK_ENQUEUE_FAILED`` when the task cannot be imported or queued.
        """
        try:
            # Imported lazily so this module loads even when the task module is unavailable.
            from app.tasks.knowledge_ingestion_tasks import ingest_knowledge_document

            ingest_knowledge_document.delay(document_id, task_id)
        except Exception as exc:  # pragma: no cover - enqueue failed; the task row is not modified here
            raise AppError("KNOWLEDGE_TASK_ENQUEUE_FAILED", "knowledge ingestion task enqueue failed", 500) from exc
+22
View File
@@ -0,0 +1,22 @@
import time
from app.core.config import settings
from app.integrations.embedding_adapter import OpenAICompatibleEmbeddingClient
class EmbeddingService:
    """Embedding service: calls the embedding client in configured batches and times the whole run."""

    def __init__(self, client: OpenAICompatibleEmbeddingClient | None = None) -> None:
        """Use the supplied client, or build the default one when none (or a falsy one) is given."""
        if not client:
            client = OpenAICompatibleEmbeddingClient()
        self.client = client

    async def embed_texts(self, texts: list[str]) -> tuple[list[list[float]], int]:
        """Embed ``texts`` batch by batch.

        The batch size comes from ``settings.embedding_batch_size`` (at least 1).
        Batches are sent one after another, in input order.

        Returns:
            ``(vectors, elapsed_ms)`` - all vectors returned by the client, in
            batch order, and the total wall-clock time in milliseconds.
        """
        began = time.perf_counter()
        collected: list[list[float]] = []
        step = max(1, settings.embedding_batch_size)
        total = len(texts)
        offset = 0
        while offset < total:
            window = texts[offset : offset + step]
            embedded, _usage = await self.client.embed_texts(window)
            collected.extend(embedded)
            offset += step
        elapsed_ms = int((time.perf_counter() - began) * 1000)
        return collected, elapsed_ms
+39
View File
@@ -0,0 +1,39 @@
from app.core.context import UserContext
from app.core.exceptions import AppError
from app.models.knowledge_base import KbKnowledgeSpace
from app.repositories.knowledge_base_repository import KnowledgeBaseRepository
class KnowledgeSpaceService:
    """Knowledge space service: resolves the knowledge-base space for the caller's institution."""

    def __init__(self, repo: KnowledgeBaseRepository) -> None:
        """Keep a reference to the knowledge-base repository."""
        self.repo = repo

    def require_institution_id(self, ctx: UserContext) -> int:
        """Return the caller's institution id as an ``int``.

        Raises:
            AppError: ``INSTITUTION_REQUIRED`` (403) when the context has no institution id.
        """
        if ctx.institution_id is not None:
            return int(ctx.institution_id)
        raise AppError("INSTITUTION_REQUIRED", "institution_id is required for knowledge base", 403)

    def get_or_create_space(self, ctx: UserContext) -> KbKnowledgeSpace:
        """Fetch the institution's space, creating it through the repository if needed.

        The institution name is read from ``ctx.profile["institution_name"]`` and
        falls back to ``institution_<id>`` when missing or empty.
        """
        institution_id = self.require_institution_id(ctx)
        display_name = (ctx.profile or {}).get("institution_name")
        if not display_name:
            display_name = f"institution_{institution_id}"
        return self.repo.get_or_create_space(institution_id, display_name)

    def get_active_space(self, ctx: UserContext) -> KbKnowledgeSpace:
        """Return the institution's existing space.

        Raises:
            AppError: ``KNOWLEDGE_SPACE_NOT_FOUND`` (404) when the repository has no space.
        """
        found = self.repo.get_space(self.require_institution_id(ctx))
        if found:
            return found
        raise AppError("KNOWLEDGE_SPACE_NOT_FOUND", "knowledge space not initialized for institution", 404)

    def ensure_content_admin(self, ctx: UserContext) -> None:
        """Allow only content/institution/system admin roles (case-insensitive).

        Raises:
            AppError: ``KNOWLEDGE_ADMIN_FORBIDDEN`` (403) for any other role.
        """
        normalized_role = (ctx.role or "").lower()
        if normalized_role in {"content_admin", "institution_admin", "admin", "super_admin"}:
            return
        raise AppError("KNOWLEDGE_ADMIN_FORBIDDEN", "only content admin can upload knowledge documents", 403)
+224
View File
@@ -0,0 +1,224 @@
import json
import time
from collections.abc import AsyncIterator
from dataclasses import dataclass
from sqlalchemy.orm import Session
from app.agents.learning_assistant_agent import LearningAssistantAgent
from app.core.config import settings
from app.core.context import UserContext
from app.core.exceptions import AppError
from app.repositories.knowledge_base_repository import KnowledgeBaseRepository
from app.schemas.learning_assistant import LearningAssistantChatRequest, LearningAssistantChatResponse, LearningAssistantSource
from app.services.knowledge_space_service import KnowledgeSpaceService
from app.services.vector_search_service import RetrievedChunk, VectorSearchService
@dataclass(frozen=True)
class LearningAssistantRetrieval:
    """Retrieval outcome for the learning assistant: hits, timings and the degradation reason."""

    institution_id: int | None  # None when the caller has no institution
    score_threshold: float  # threshold recorded for this retrieval
    sources: list[LearningAssistantSource]  # empty when retrieval missed or was degraded
    embedding_latency_ms: int | None = None
    search_latency_ms: int | None = None
    retrieval_error: str | None = None  # user-facing reason when retrieval was degraded


class LearningAssistantService:
    """AI learning assistant: tries RAG retrieval first and degrades to plain LLM answering."""

    def __init__(
        self,
        db: Session,
        *,
        vector_search_service: VectorSearchService | None = None,
        agent: LearningAssistantAgent | None = None,
    ) -> None:
        """Build the repository and collaborators; search service and agent may be injected."""
        self.db = db
        self.repo = KnowledgeBaseRepository(db)
        self.space_service = KnowledgeSpaceService(self.repo)
        self.vector_search = vector_search_service or VectorSearchService(db)
        self.agent = agent or LearningAssistantAgent()

    async def chat(self, ctx: UserContext, payload: LearningAssistantChatRequest) -> LearningAssistantChatResponse:
        """Answer a question in one shot.

        Retrieval failures that are degraded by ``_retrieve_sources`` do not block
        the answer; the response carries the full text plus retrieval details.
        """
        start = time.perf_counter()
        retrieval = await self._retrieve_sources(ctx, payload)
        llm_started = time.perf_counter()
        response = await self.agent.answer(payload.question, retrieval.sources)
        total_latency_ms = int((time.perf_counter() - start) * 1000)
        # Prefer the latency reported by the agent; otherwise measure it locally.
        llm_latency_ms = response.latency_ms or int((time.perf_counter() - llm_started) * 1000)
        self._write_query_log(
            ctx=ctx,
            payload=payload,
            retrieval=retrieval,
            answer=response.content,
            model=response.model,
            llm_latency_ms=llm_latency_ms,
            total_latency_ms=total_latency_ms,
        )
        return LearningAssistantChatResponse(
            answer=response.content,
            retrieval_hit=bool(retrieval.sources),
            sources=retrieval.sources,
            retrieval_error=retrieval.retrieval_error,
            model=response.model,
            embedding_latency_ms=retrieval.embedding_latency_ms,
            search_latency_ms=retrieval.search_latency_ms,
            llm_latency_ms=llm_latency_ms,
            total_latency_ms=total_latency_ms,
        )

    async def stream_chat(self, ctx: UserContext, payload: LearningAssistantChatRequest) -> AsyncIterator[str]:
        """Stream an answer as SSE frames.

        Emits ``retrieval_done`` first, then one ``answer_delta`` per text delta,
        and finally ``answer_done``. On an LLM failure a single ``error`` frame is
        emitted and the stream ends without writing a query log.
        """
        start = time.perf_counter()
        retrieval = await self._retrieve_sources(ctx, payload)
        yield self._sse(
            "retrieval_done",
            {
                "retrieval_hit": bool(retrieval.sources),
                "sources": [source.model_dump() for source in retrieval.sources],
                "retrieval_error": retrieval.retrieval_error,
                "embedding_latency_ms": retrieval.embedding_latency_ms,
                "search_latency_ms": retrieval.search_latency_ms,
            },
        )

        answer_parts: list[str] = []
        llm_latency_ms: int | None = None
        model: str | None = None
        try:
            async for chunk in self.agent.stream_answer(payload.question, retrieval.sources):
                if chunk.done:
                    # The terminal chunk carries the model name and total latency.
                    llm_latency_ms = chunk.total_latency_ms
                    model = chunk.model
                    break
                if chunk.delta:
                    answer_parts.append(chunk.delta)
                    yield self._sse("answer_delta", {"delta": chunk.delta})
        except AppError as exc:
            yield self._sse("error", {"code": exc.code, "message": exc.message})
            return
        except Exception:
            yield self._sse("error", {"code": "LEARNING_ASSISTANT_LLM_FAILED", "message": "AI 学习助手回答生成失败,请稍后重试"})
            return

        answer = "".join(answer_parts)
        total_latency_ms = int((time.perf_counter() - start) * 1000)
        self._write_query_log(
            ctx=ctx,
            payload=payload,
            retrieval=retrieval,
            answer=answer,
            model=model,
            llm_latency_ms=llm_latency_ms,
            total_latency_ms=total_latency_ms,
            commit=True,
        )
        yield self._sse("answer_done", {"model": model, "total_latency_ms": total_latency_ms})

    def _fallback_threshold(self, payload: LearningAssistantChatRequest) -> float:
        """Threshold recorded on degraded paths: the request's value, else the global setting."""
        if payload.score_threshold is not None:
            return payload.score_threshold
        return settings.rag_score_threshold

    async def _retrieve_sources(self, ctx: UserContext, payload: LearningAssistantChatRequest) -> LearningAssistantRetrieval:
        """Retrieve knowledge sources for the caller's institution.

        The effective score threshold is the request's value when given,
        otherwise the space's ``score_threshold``. The same value is passed to
        the vector search and recorded on the result, so the logged threshold
        always matches the one applied.

        Degrades to an empty source list (with ``retrieval_error`` set) when the
        caller has no institution, when the space / collection / embedding call is
        unavailable, or on any non-``AppError`` failure. Other ``AppError`` codes
        are re-raised.
        """
        try:
            institution_id = self.space_service.require_institution_id(ctx)
        except AppError:
            return LearningAssistantRetrieval(
                institution_id=None,
                score_threshold=self._fallback_threshold(payload),
                sources=[],
                retrieval_error="当前用户缺少机构信息,已转为大模型通用学习回答。",
            )

        try:
            space = self.space_service.get_active_space(ctx)
            threshold = payload.score_threshold
            if threshold is None:
                threshold = space.score_threshold
            retrieval = await self.vector_search.search(
                institution_id=space.institution_id,
                collection_name=space.collection_name,
                question=payload.question,
                top_k=payload.top_k,
                score_threshold=threshold,
            )
            return LearningAssistantRetrieval(
                institution_id=space.institution_id,
                # When neither the request nor the space defines a threshold the
                # search falls back to the global setting, so record that value.
                score_threshold=threshold if threshold is not None else settings.rag_score_threshold,
                sources=self._build_sources(retrieval.chunks),
                embedding_latency_ms=retrieval.embedding_latency_ms,
                search_latency_ms=retrieval.search_latency_ms,
            )
        except AppError as exc:
            if exc.code in {"KNOWLEDGE_SPACE_NOT_FOUND", "MILVUS_COLLECTION_NOT_FOUND", "EMBEDDING_CALL_FAILED"}:
                return LearningAssistantRetrieval(
                    institution_id=institution_id,
                    score_threshold=self._fallback_threshold(payload),
                    sources=[],
                    retrieval_error="当前机构知识库暂未初始化或检索不可用,已转为大模型通用学习回答。",
                )
            raise
        except Exception:
            return LearningAssistantRetrieval(
                institution_id=institution_id,
                score_threshold=self._fallback_threshold(payload),
                sources=[],
                retrieval_error="当前机构知识库检索暂不可用,已转为大模型通用学习回答。",
            )

    def _write_query_log(
        self,
        *,
        ctx: UserContext,
        payload: LearningAssistantChatRequest,
        retrieval: LearningAssistantRetrieval,
        answer: str,
        model: str | None,
        llm_latency_ms: int | None,
        total_latency_ms: int | None,
        commit: bool = False,
    ) -> None:
        """Record the query (hit flag, source ids, timings) when an institution id is known.

        Nothing is written when ``retrieval.institution_id`` is ``None``. The
        session is committed only when ``commit`` is true.
        """
        if retrieval.institution_id is None:
            return
        self.repo.create_query_log(
            user_id=ctx.user_id,
            institution_id=retrieval.institution_id,
            question=payload.question,
            retrieval_hit=bool(retrieval.sources),
            retrieved_chunk_ids=[source.chunk_uid for source in retrieval.sources],
            answer_summary=answer,
            llm_model=model,
            top_k=payload.top_k or len(retrieval.sources) or settings.rag_top_k,
            score_threshold=retrieval.score_threshold,
            embedding_latency_ms=retrieval.embedding_latency_ms,
            search_latency_ms=retrieval.search_latency_ms,
            llm_latency_ms=llm_latency_ms,
            total_latency_ms=total_latency_ms,
        )
        if commit:
            self.db.commit()

    def _build_sources(self, chunks: list[RetrievedChunk]) -> list[LearningAssistantSource]:
        """Convert retrieved chunks into displayable source entries.

        Each entry carries the document title / file name (when the document can
        be loaded), the page range, the score rounded to 4 decimals and the first
        500 characters of the chunk as a quote.
        """
        # Several chunks often come from the same document; load each one once.
        documents: dict = {}
        sources: list[LearningAssistantSource] = []
        for item in chunks:
            chunk = item.chunk
            key = (chunk.document_id, chunk.institution_id)
            if key not in documents:
                documents[key] = self.repo.get_document(chunk.document_id, chunk.institution_id)
            document = documents[key]
            sources.append(
                LearningAssistantSource(
                    document_id=chunk.document_id,
                    document_title=document.document_title if document else None,
                    file_name=document.file_name if document else "",
                    page_start=chunk.page_start,
                    page_end=chunk.page_end,
                    chunk_uid=chunk.chunk_uid,
                    score=round(item.score, 4),
                    quote=chunk.chunk_text[:500],
                )
            )
        return sources

    def _sse(self, event: str, data: dict) -> str:
        """Format one SSE frame: ``event:`` line, ``data:`` line with JSON, blank line."""
        return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
+69
View File
@@ -0,0 +1,69 @@
import time
from dataclasses import dataclass
from sqlalchemy.orm import Session
from app.core.config import settings
from app.integrations.milvus_adapter import MilvusVectorStore
from app.models.knowledge_base import KbKnowledgeChunk
from app.repositories.knowledge_base_repository import KnowledgeBaseRepository
from app.services.embedding_service import EmbeddingService
@dataclass(frozen=True)
class RetrievedChunk:
    """RAG hit: the stored chunk row together with its similarity score."""

    chunk: KbKnowledgeChunk
    score: float


@dataclass(frozen=True)
class RetrievalResult:
    """Retrieval bundle: the ranked chunks plus per-stage latency."""

    chunks: list[RetrievedChunk]
    embedding_latency_ms: int
    search_latency_ms: int


class VectorSearchService:
    """Vector search service: embeds the question and searches the institution's collection."""

    def __init__(
        self,
        db: Session,
        *,
        embedding_service: EmbeddingService | None = None,
        vector_store: MilvusVectorStore | None = None,
    ) -> None:
        """Build the repository; embedding service and vector store may be injected."""
        self.repo = KnowledgeBaseRepository(db)
        self.embedding_service = embedding_service or EmbeddingService()
        self.vector_store = vector_store or MilvusVectorStore()

    async def search(
        self,
        *,
        institution_id: int,
        collection_name: str,
        question: str,
        top_n: int | None = None,
        top_k: int | None = None,
        score_threshold: float | None = None,
    ) -> RetrievalResult:
        """Recall ``top_n`` candidates, then keep at most ``top_k`` that meet the threshold.

        Unset (or zero) ``top_n`` / ``top_k`` and an unset ``score_threshold`` fall
        back to the values in ``settings``. The returned chunks follow the order
        of the vector-store hits; hits without a matching chunk row for the
        institution are skipped, and a chunk uid is returned at most once.

        Args:
            institution_id: Institution whose chunk rows may be returned.
            collection_name: Vector collection to search.
            question: User question to embed.
            top_n: Number of candidates to recall from the vector store.
            top_k: Maximum number of chunks to return.
            score_threshold: Minimum score a hit must reach.

        Returns:
            The ranked chunks with embedding and search latency in milliseconds.
        """
        vectors, embedding_latency_ms = await self.embedding_service.embed_texts([question])
        started = time.perf_counter()
        hits = self.vector_store.search(collection_name, vectors[0], top_n or settings.rag_top_n)
        search_latency_ms = int((time.perf_counter() - started) * 1000)

        threshold = settings.rag_score_threshold if score_threshold is None else score_threshold
        final_limit = top_k or settings.rag_top_k
        filtered = [hit for hit in hits if hit.score >= threshold][:final_limit]

        ranked: list[RetrievedChunk] = []
        if filtered:  # nothing passed the threshold -> no database round-trip needed
            rows = self.repo.get_chunks_by_uids(institution_id, [hit.chunk_uid for hit in filtered])
            row_by_uid = {row.chunk_uid: row for row in rows}
            emitted: set = set()
            # The repository gives no ordering guarantee, so walk the hits to
            # keep the vector-store ranking in the final result.
            for hit in filtered:
                row = row_by_uid.get(hit.chunk_uid)
                if row is None or hit.chunk_uid in emitted:
                    continue
                emitted.add(hit.chunk_uid)
                ranked.append(RetrievedChunk(chunk=row, score=hit.score))

        return RetrievalResult(
            chunks=ranked,
            embedding_latency_ms=embedding_latency_ms,
            search_latency_ms=search_latency_ms,
        )