Files
fastapi/app/services/vector_search_service.py
T

70 lines
2.5 KiB
Python
Raw Normal View History

import time
from dataclasses import dataclass
from sqlalchemy.orm import Session
from app.core.config import settings
from app.integrations.milvus_adapter import MilvusVectorStore
from app.models.knowledge_base import KbKnowledgeChunk
from app.repositories.knowledge_base_repository import KnowledgeBaseRepository
from app.services.embedding_service import EmbeddingService
@dataclass(frozen=True)
class RetrievedChunk:
    """A single RAG retrieval hit.

    Pairs the chunk record loaded from the relational store (MySQL) with the
    similarity score that Milvus reported for it. Instances are immutable.
    """

    # Chunk record carrying the stored chunk details.
    chunk: KbKnowledgeChunk
    # Similarity score reported by the vector search for this chunk.
    score: float
@dataclass(frozen=True)
class RetrievalResult:
    """Outcome of one retrieval call: the matched chunks plus stage timings.

    Instances are immutable.
    """

    # Chunks that were hit, each paired with its similarity score.
    chunks: list[RetrievedChunk]
    # Time spent producing the question embedding, in milliseconds.
    embedding_latency_ms: int
    # Time spent in the vector-store search, in milliseconds.
    search_latency_ms: int
class VectorSearchService:
    """Vector retrieval service.

    Embeds a user question and searches the given Milvus collection, then
    loads the matching chunk records through the knowledge-base repository.
    """

    def __init__(
        self,
        db: Session,
        *,
        embedding_service: EmbeddingService | None = None,
        vector_store: MilvusVectorStore | None = None,
    ) -> None:
        """Wire up the repository, embedding service and vector store.

        Args:
            db: SQLAlchemy session handed to the knowledge-base repository.
            embedding_service: Optional override; a default ``EmbeddingService``
                is created when omitted.
            vector_store: Optional override; a default ``MilvusVectorStore``
                is created when omitted.
        """
        self.repo = KnowledgeBaseRepository(db)
        self.embedding_service = embedding_service or EmbeddingService()
        self.vector_store = vector_store or MilvusVectorStore()

    async def search(
        self,
        *,
        institution_id: int,
        collection_name: str,
        question: str,
        top_n: int | None = None,
        top_k: int | None = None,
        score_threshold: float | None = None,
    ) -> RetrievalResult:
        """Retrieve knowledge chunks relevant to ``question``.

        Recalls up to ``top_n`` candidates from the vector store, keeps those
        whose score is at least the threshold, truncates to ``top_k``, and
        loads the corresponding chunk records. The returned chunks follow the
        order of the vector-store hits.

        Args:
            institution_id: Passed to the repository when loading chunks.
            collection_name: Collection to search in the vector store.
            question: Text to embed and search with.
            top_n: Recall size; ``None`` or ``0`` falls back to
                ``settings.rag_top_n``.
            top_k: Final result cap; ``None`` or ``0`` falls back to
                ``settings.rag_top_k``.
            score_threshold: Minimum score to keep a hit; ``None`` falls back
                to ``settings.rag_score_threshold`` (``0.0`` is honoured).

        Returns:
            A ``RetrievalResult`` with the matched chunks and the embedding
            and search latencies in milliseconds.
        """
        vectors, embedding_latency_ms = await self.embedding_service.embed_texts([question])

        # NOTE(review): the vector-store search and the repository lookup
        # below are plain synchronous calls inside a coroutine; if they block
        # on network/DB I/O they will stall the event loop - confirm whether
        # they should be off-loaded to a worker thread.
        started = time.perf_counter()
        hits = self.vector_store.search(collection_name, vectors[0], top_n or settings.rag_top_n)
        search_latency_ms = int((time.perf_counter() - started) * 1000)

        # An explicit 0.0 threshold must be respected, hence the ``is None`` test.
        threshold = settings.rag_score_threshold if score_threshold is None else score_threshold
        final_limit = top_k or settings.rag_top_k
        filtered = [hit for hit in hits if hit.score >= threshold][:final_limit]

        if not filtered:
            # Nothing passed the threshold: skip the repository round-trip.
            return RetrievalResult(
                chunks=[],
                embedding_latency_ms=embedding_latency_ms,
                search_latency_ms=search_latency_ms,
            )

        score_by_uid = {hit.chunk_uid: hit.score for hit in filtered}
        # Position of each uid among the kept hits (dicts keep insertion order).
        rank_by_uid = {uid: rank for rank, uid in enumerate(score_by_uid)}

        chunks = self.repo.get_chunks_by_uids(institution_id, [hit.chunk_uid for hit in filtered])
        # The repository is not known to return rows in the requested order,
        # so restore the hit order explicitly. Chunks with an unexpected uid
        # sort last (stable sort) and keep the 0.0 fallback score.
        unknown_rank = len(rank_by_uid)
        ordered = sorted(chunks, key=lambda chunk: rank_by_uid.get(chunk.chunk_uid, unknown_rank))

        return RetrievalResult(
            chunks=[
                RetrievedChunk(chunk=chunk, score=score_by_uid.get(chunk.chunk_uid, 0.0))
                for chunk in ordered
            ],
            embedding_latency_ms=embedding_latency_ms,
            search_latency_ms=search_latency_ms,
        )