import time
from dataclasses import dataclass

from sqlalchemy.orm import Session

from app.core.config import settings
from app.integrations.milvus_adapter import MilvusVectorStore
from app.models.knowledge_base import KbKnowledgeChunk
from app.repositories.knowledge_base_repository import KnowledgeBaseRepository
from app.services.embedding_service import EmbeddingService


@dataclass(frozen=True)
class RetrievedChunk:
    """RAG retrieval hit: the MySQL chunk record plus its Milvus similarity score."""

    chunk: KbKnowledgeChunk
    score: float


@dataclass(frozen=True)
class RetrievalResult:
    """Retrieval result bundle: the matched chunks and per-stage latencies."""

    chunks: list[RetrievedChunk]
    embedding_latency_ms: int
    search_latency_ms: int


class VectorSearchService:
    """Vector search service.

    Embeds the user's question and searches the institution's Milvus
    collection for matching knowledge chunks.
    """

    def __init__(
        self,
        db: Session,
        *,
        embedding_service: EmbeddingService | None = None,
        vector_store: MilvusVectorStore | None = None,
    ) -> None:
        """Wire up the repository, embedding service and vector store.

        Args:
            db: SQLAlchemy session handed to ``KnowledgeBaseRepository``.
            embedding_service: Optional override; a default
                ``EmbeddingService()`` is created when omitted.
            vector_store: Optional override; a default
                ``MilvusVectorStore()`` is created when omitted.
        """
        self.repo = KnowledgeBaseRepository(db)
        self.embedding_service = embedding_service or EmbeddingService()
        self.vector_store = vector_store or MilvusVectorStore()

    async def search(
        self,
        *,
        institution_id: int,
        collection_name: str,
        question: str,
        top_n: int | None = None,
        top_k: int | None = None,
        score_threshold: float | None = None,
    ) -> RetrievalResult:
        """Retrieve knowledge chunks relevant to ``question``.

        Recalls up to ``top_n`` hits from the vector store, drops hits whose
        score is below the threshold, keeps at most ``top_k`` of the rest and
        loads the matching chunk records through the repository.

        Args:
            institution_id: Passed to the repository when loading chunks.
            collection_name: Vector-store collection to search.
            question: Text to embed and search for.
            top_n: Recall size; falls back to ``settings.rag_top_n`` when
                ``None`` or ``0``.
            top_k: Maximum number of hits kept after filtering; falls back to
                ``settings.rag_top_k`` when ``None`` or ``0``.
            score_threshold: Minimum score (inclusive); falls back to
                ``settings.rag_score_threshold`` only when ``None``, so an
                explicit ``0.0`` is honoured.

        Returns:
            A ``RetrievalResult`` whose chunks follow the order of the
            filtered vector-store hits, together with the embedding and
            search latencies in milliseconds.
        """
        vectors, embedding_latency_ms = await self.embedding_service.embed_texts([question])

        recall_limit = top_n or settings.rag_top_n
        started = time.perf_counter()
        # NOTE(review): this call is not awaited, so it runs synchronously
        # inside the coroutine. If it performs blocking network I/O it will
        # stall the event loop -- confirm and consider offloading it.
        hits = self.vector_store.search(collection_name, vectors[0], recall_limit)
        search_latency_ms = int((time.perf_counter() - started) * 1000)

        threshold = settings.rag_score_threshold if score_threshold is None else score_threshold
        final_limit = top_k or settings.rag_top_k
        filtered = [hit for hit in hits if hit.score >= threshold][:final_limit]

        if not filtered:
            # Nothing passed the threshold: skip the pointless repository lookup.
            return RetrievalResult(
                chunks=[],
                embedding_latency_ms=embedding_latency_ms,
                search_latency_ms=search_latency_ms,
            )

        score_by_uid = {hit.chunk_uid: hit.score for hit in filtered}
        # Remember each uid's first position in the filtered hits so the
        # vector store's ranking can be restored after the repository lookup.
        rank_by_uid: dict = {}
        for position, hit in enumerate(filtered):
            rank_by_uid.setdefault(hit.chunk_uid, position)

        chunks = self.repo.get_chunks_by_uids(institution_id, [hit.chunk_uid for hit in filtered])
        # The repository is not known to preserve the requested uid order, so
        # re-order explicitly. sorted() is stable; chunks with an unknown uid
        # (which keep the 0.0 fallback score below) are placed last.
        unknown_rank = len(rank_by_uid)
        ordered_chunks = sorted(
            chunks,
            key=lambda chunk: rank_by_uid.get(chunk.chunk_uid, unknown_rank),
        )

        return RetrievalResult(
            chunks=[
                RetrievedChunk(chunk=chunk, score=score_by_uid.get(chunk.chunk_uid, 0.0))
                for chunk in ordered_chunks
            ],
            embedding_latency_ms=embedding_latency_ms,
            search_latency_ms=search_latency_ms,
        )