70 lines
2.5 KiB
Python
70 lines
2.5 KiB
Python
|
|
import time
|
||
|
|
from dataclasses import dataclass
|
||
|
|
|
||
|
|
from sqlalchemy.orm import Session
|
||
|
|
|
||
|
|
from app.core.config import settings
|
||
|
|
from app.integrations.milvus_adapter import MilvusVectorStore
|
||
|
|
from app.models.knowledge_base import KbKnowledgeChunk
|
||
|
|
from app.repositories.knowledge_base_repository import KnowledgeBaseRepository
|
||
|
|
from app.services.embedding_service import EmbeddingService
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass(frozen=True)
class RetrievedChunk:
    """A single RAG retrieval hit.

    Pairs the chunk record loaded from the relational store (MySQL, per the
    original note) with the similarity score reported by Milvus.

    Attributes:
        chunk: The knowledge-base chunk row for this hit.
        score: Similarity score of the hit as reported by the vector search.
    """

    chunk: KbKnowledgeChunk
    score: float
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass(frozen=True)
class RetrievalResult:
    """Bundle returned by a retrieval call: matched chunks plus stage timings.

    Attributes:
        chunks: The retrieved chunks, each carrying its similarity score.
        embedding_latency_ms: Time spent embedding the question, in
            milliseconds.
        search_latency_ms: Time spent in the vector-store search, in
            milliseconds.
    """

    chunks: list[RetrievedChunk]
    embedding_latency_ms: int
    search_latency_ms: int
|
||
|
|
|
||
|
|
|
||
|
|
class VectorSearchService:
    """Vector search service.

    Embeds the user's question and searches the institution's Milvus
    collection, then loads the matching chunk rows through the repository.
    """

    def __init__(
        self,
        db: Session,
        *,
        embedding_service: EmbeddingService | None = None,
        vector_store: MilvusVectorStore | None = None,
    ) -> None:
        """Set up the repository and the (optionally injected) collaborators.

        Args:
            db: SQLAlchemy session handed to ``KnowledgeBaseRepository``.
            embedding_service: Embedding backend; a default
                ``EmbeddingService()`` is created when omitted.
            vector_store: Vector store adapter; a default
                ``MilvusVectorStore()`` is created when omitted.
        """
        self.repo = KnowledgeBaseRepository(db)
        self.embedding_service = embedding_service or EmbeddingService()
        self.vector_store = vector_store or MilvusVectorStore()

    async def search(
        self,
        *,
        institution_id: int,
        collection_name: str,
        question: str,
        top_n: int | None = None,
        top_k: int | None = None,
        score_threshold: float | None = None,
    ) -> RetrievalResult:
        """Retrieve knowledge chunks relevant to ``question``.

        Recalls up to ``top_n`` candidates from the vector store, keeps those
        whose score is at least the threshold, truncates to ``top_k``, and
        loads the corresponding chunk rows. The returned chunks follow the
        ranking order of the vector-store hits.

        Args:
            institution_id: Institution whose chunks may be returned; passed
                to the repository lookup.
            collection_name: Name of the collection to search.
            question: The user question to embed and search with.
            top_n: Number of candidates to recall. ``None`` or ``0`` falls
                back to ``settings.rag_top_n``.
            top_k: Maximum number of chunks to return. ``None`` or ``0`` falls
                back to ``settings.rag_top_k``.
            score_threshold: Minimum score a hit must reach. ``None`` falls
                back to ``settings.rag_score_threshold``; an explicit ``0.0``
                is honoured.

        Returns:
            A ``RetrievalResult`` with the scored chunks and the embedding and
            search latencies in milliseconds.
        """
        vectors, embedding_latency_ms = await self.embedding_service.embed_texts([question])

        # NOTE(review): the store call is not awaited, so it looks synchronous
        # and would hold the event loop for the duration of the search.
        # Confirm against MilvusVectorStore before offloading it to a thread.
        started = time.perf_counter()
        hits = self.vector_store.search(collection_name, vectors[0], top_n or settings.rag_top_n)
        search_latency_ms = int((time.perf_counter() - started) * 1000)

        threshold = settings.rag_score_threshold if score_threshold is None else score_threshold
        final_limit = top_k or settings.rag_top_k
        # Truncation keeps the first ``final_limit`` qualifying hits; this
        # presumes the store returns hits best-first -- TODO confirm.
        filtered = [hit for hit in hits if hit.score >= threshold][:final_limit]

        if not filtered:
            # Nothing qualified: skip the database round-trip entirely.
            return RetrievalResult(
                chunks=[],
                embedding_latency_ms=embedding_latency_ms,
                search_latency_ms=search_latency_ms,
            )

        chunks = self.repo.get_chunks_by_uids(institution_id, [hit.chunk_uid for hit in filtered])
        score_by_uid = {hit.chunk_uid: hit.score for hit in filtered}

        # The repository does not necessarily return rows in the order the
        # uids were given, so restore the hit ranking. The sort is stable and
        # keeps every row; rows with an unknown uid sort last.
        rank_by_uid: dict = {}
        for position, hit in enumerate(filtered):
            rank_by_uid.setdefault(hit.chunk_uid, position)
        unranked = len(filtered)
        ordered_chunks = sorted(chunks, key=lambda chunk: rank_by_uid.get(chunk.chunk_uid, unranked))

        return RetrievalResult(
            chunks=[
                RetrievedChunk(chunk=chunk, score=score_by_uid.get(chunk.chunk_uid, 0.0))
                for chunk in ordered_chunks
            ],
            embedding_latency_ms=embedding_latency_ms,
            search_latency_ms=search_latency_ms,
        )
|