113 lines
5.1 KiB
Python
113 lines
5.1 KiB
Python
|
|
from dataclasses import dataclass
|
||
|
|
|
||
|
|
from app.core.config import settings
|
||
|
|
from app.core.exceptions import AppError
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass(frozen=True)
class VectorSearchHit:
    """A single vector-search hit.

    Holds only the chunk identifier and its similarity score; source details
    for the chunk are read from MySQL by the caller.
    """

    # Identifier of the matched knowledge chunk (used as the vector primary key).
    chunk_uid: str
    # Similarity score reported for this chunk.
    score: float
|
||
|
|
|
||
|
|
|
||
|
|
class MilvusVectorStore:
    """Milvus vector-store adapter.

    Writes and searches knowledge-chunk vectors, one collection per
    organization. When ``settings.milvus_uri`` starts with ``mock://`` every
    operation is served from an in-process dictionary instead of a Milvus
    server, so the code can be exercised without a running Milvus.
    """

    # Mock backend: collection name -> {chunk_uid -> vector}. It is defined on
    # the class, so all instances in the process share the same data.
    _mock_store: dict[str, dict[str, list[float]]] = {}

    def __init__(self) -> None:
        """Select mock or real mode from settings; the client is created lazily."""
        self.mock_enabled = settings.milvus_uri.startswith("mock://")
        self._client = None

    def ensure_collection(self, collection_name: str) -> None:
        """Create the collection if it does not exist yet.

        The real collection has a VARCHAR primary key ``id`` (max length 128)
        and a FLOAT_VECTOR field ``vector`` of ``settings.embedding_dim``
        dimensions, indexed with AUTOINDEX using the COSINE metric.

        Args:
            collection_name: Name of the collection to initialise.

        Raises:
            AppError: ``MILVUS_COLLECTION_INIT_FAILED`` (502) when any Milvus
                call fails; ``MILVUS_CLIENT_NOT_INSTALLED`` (500) when pymilvus
                cannot be imported.
        """
        if self.mock_enabled:
            self._mock_store.setdefault(collection_name, {})
            return

        client = self._client_or_raise()
        try:
            if client.has_collection(collection_name=collection_name):
                return
            schema = client.create_schema(auto_id=False, enable_dynamic_field=False)
            # Imported lazily so that pymilvus is only required on the real path.
            from pymilvus import DataType

            schema.add_field(field_name="id", datatype=DataType.VARCHAR, is_primary=True, max_length=128)
            schema.add_field(field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=settings.embedding_dim)
            index_params = client.prepare_index_params()
            index_params.add_index(field_name="vector", index_type="AUTOINDEX", metric_type="COSINE")
            client.create_collection(
                collection_name=collection_name,
                schema=schema,
                index_params=index_params,
                consistency_level="Strong",
            )
        except Exception as exc:  # pragma: no cover - real Milvus is verified in the integration environment
            raise AppError("MILVUS_COLLECTION_INIT_FAILED", "milvus collection init failed", 502) from exc

    def upsert_vectors(self, collection_name: str, vectors: list[tuple[str, list[float]]]) -> None:
        """Write vectors keyed by ``chunk_uid`` so that rebuilding overwrites old entries.

        Args:
            collection_name: Target collection; created first if missing.
            vectors: ``(chunk_uid, vector)`` pairs. An empty list is a no-op
                and does not create the collection.

        Raises:
            AppError: ``MILVUS_UPSERT_FAILED`` (502) when the Milvus upsert
                fails, plus anything raised by ``ensure_collection``.
        """
        if not vectors:
            return
        self.ensure_collection(collection_name)

        if self.mock_enabled:
            collection = self._mock_store.setdefault(collection_name, {})
            for chunk_uid, vector in vectors:
                # Store a copy: keeping the caller's list object would let a
                # later mutation by the caller silently change the stored vector.
                collection[chunk_uid] = list(vector)
            return

        client = self._client_or_raise()
        try:
            client.upsert(
                collection_name=collection_name,
                data=[{"id": chunk_uid, "vector": vector} for chunk_uid, vector in vectors],
            )
        except Exception as exc:  # pragma: no cover
            raise AppError("MILVUS_UPSERT_FAILED", "milvus vector upsert failed", 502) from exc

    def search(self, collection_name: str, query_vector: list[float], limit: int) -> list[VectorSearchHit]:
        """Return candidate chunks ranked by cosine similarity.

        No score threshold is applied here; filtering is left to the caller.
        The collection is created first if it does not exist.

        Args:
            collection_name: Collection to search.
            query_vector: Query embedding.
            limit: Maximum number of hits to return.

        Returns:
            Hits in the order reported by the backend. Results without a
            usable ``id`` are skipped.

        Raises:
            AppError: ``MILVUS_SEARCH_FAILED`` (502) when the Milvus search
                fails, plus anything raised by ``ensure_collection``.
        """
        self.ensure_collection(collection_name)
        if self.mock_enabled:
            return self._mock_search(collection_name, query_vector, limit)

        client = self._client_or_raise()
        try:
            results = client.search(
                collection_name=collection_name,
                data=[query_vector],
                anns_field="vector",
                limit=limit,
                search_params={"metric_type": "COSINE"},
                output_fields=["id"],
            )
        except Exception as exc:  # pragma: no cover
            raise AppError("MILVUS_SEARCH_FAILED", "milvus vector search failed", 502) from exc

        # One query vector was sent, so only the first result list is relevant.
        rows = results[0] if results else []
        hits: list[VectorSearchHit] = []
        for item in rows:
            entity = item.get("entity") or {}
            # Prefer the id from the returned entity, then the hit's own id.
            chunk_uid = str(entity.get("id") or item.get("id") or "")
            if not chunk_uid:
                continue
            score = float(item.get("distance") or item.get("score") or 0)
            hits.append(VectorSearchHit(chunk_uid=chunk_uid, score=score))
        return hits

    def _client_or_raise(self):
        """Return the cached Milvus client, creating it on first use.

        pymilvus is imported lazily so that code paths which never touch the
        knowledge base do not require it to be installed.

        Raises:
            AppError: ``MILVUS_CLIENT_NOT_INSTALLED`` (500) when pymilvus
                cannot be imported.
        """
        if self._client is not None:
            return self._client
        try:
            from pymilvus import MilvusClient
        except ImportError as exc:
            raise AppError("MILVUS_CLIENT_NOT_INSTALLED", "pymilvus is required for vector search", 500) from exc
        # NOTE(review): errors raised while constructing MilvusClient are not
        # wrapped in AppError here - confirm whether callers expect that.
        self._client = MilvusClient(uri=settings.milvus_uri, db_name=settings.milvus_default_db)
        return self._client

    def _mock_search(self, collection_name: str, query_vector: list[float], limit: int) -> list[VectorSearchHit]:
        """Rank mock-store vectors by dot product, for tests without Milvus.

        Each score is ``(dot + 1) / 2``. Vectors are not normalised here, so
        this equals a rescaled cosine similarity only for unit-length inputs.
        ``zip`` stops at the shorter vector when dimensions differ.

        Args:
            collection_name: Mock collection to search; unknown names yield no hits.
            query_vector: Query embedding.
            limit: Maximum number of hits; values ``<= 0`` yield an empty list.

        Returns:
            Up to ``limit`` hits, highest score first.
        """
        if limit <= 0:
            # A negative limit would otherwise slice hits off the END of the
            # ranking instead of returning nothing.
            return []
        collection = self._mock_store.get(collection_name, {})
        scored = [
            VectorSearchHit(chunk_uid=chunk_uid, score=(sum(a * b for a, b in zip(query_vector, vector)) + 1.0) / 2.0)
            for chunk_uid, vector in collection.items()
        ]
        return sorted(scored, key=lambda item: item.score, reverse=True)[:limit]
|