78 lines
3.5 KiB
Python
78 lines
3.5 KiB
Python
|
|
import hashlib
|
||
|
|
import math
|
||
|
|
from dataclasses import dataclass
|
||
|
|
|
||
|
|
import httpx
|
||
|
|
|
||
|
|
from app.core.config import settings
|
||
|
|
from app.core.exceptions import AppError
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass(frozen=True)
|
||
|
|
class EmbeddingUsage:
|
||
|
|
"""Embedding 调用指标:记录批量向量化的模型和 token 用量。"""
|
||
|
|
|
||
|
|
model: str
|
||
|
|
total_tokens: int | None = None
|
||
|
|
|
||
|
|
|
||
|
|
class OpenAICompatibleEmbeddingClient:
    """Embedding adapter for OpenAI-compatible ``/embeddings`` endpoints.

    Falls back to deterministic, locally generated vectors when running in
    mock mode, so callers get stable results without a remote service.
    """

    @property
    def is_mock_mode(self) -> bool:
        """Whether embeddings are generated locally instead of via HTTP.

        True when the configured provider is ``"mock"`` (case-insensitive)
        or when no embedding API key is configured.
        """
        return settings.embedding_provider.lower() == "mock" or not settings.embedding_api_key

    async def embed_texts(self, texts: list[str]) -> tuple[list[list[float]], EmbeddingUsage]:
        """Embed a batch of texts, returning one vector per text in input order.

        Args:
            texts: Texts to embed. An empty list returns immediately without
                any remote call.

        Returns:
            A ``(vectors, usage)`` pair. In mock mode the usage model name is
            prefixed with ``mock-`` and no token count is reported.

        Raises:
            AppError: ``EMBEDDING_CALL_FAILED`` when the HTTP request fails,
                times out, returns an error status, or its body cannot be
                decoded; ``EMBEDDING_RESPONSE_INVALID`` when the payload is
                malformed, does not contain exactly one embedding per input
                text, or a vector has the wrong dimension.
        """
        if not texts:
            return [], EmbeddingUsage(model=settings.embedding_model, total_tokens=0)
        if self.is_mock_mode:
            return [self._mock_vector(text) for text in texts], EmbeddingUsage(model=f"mock-{settings.embedding_model}")

        try:
            async with httpx.AsyncClient(timeout=settings.embedding_timeout_seconds) as client:
                resp = await client.post(
                    self._embeddings_url(),
                    headers={"Authorization": f"Bearer {settings.embedding_api_key}"},
                    json={"model": settings.embedding_model, "input": texts},
                )
                resp.raise_for_status()
                payload = resp.json()
        except (httpx.TimeoutException, httpx.HTTPError, ValueError) as exc:
            raise AppError("EMBEDDING_CALL_FAILED", "embedding service call failed", 502) from exc

        try:
            vectors, total_tokens = self._parse_embeddings_payload(payload, len(texts))
        except (KeyError, TypeError, ValueError) as exc:
            raise AppError("EMBEDDING_RESPONSE_INVALID", "embedding response format invalid", 502) from exc
        return vectors, EmbeddingUsage(model=settings.embedding_model, total_tokens=total_tokens)

    def _parse_embeddings_payload(self, payload: dict, expected_count: int) -> tuple[list[list[float]], int | None]:
        """Extract index-ordered vectors and the reported token count from a response.

        Args:
            payload: Decoded JSON body of the embeddings response.
            expected_count: Number of input texts that were sent.

        Returns:
            ``(vectors, total_tokens)``; ``total_tokens`` is None when the
            payload carries no usable ``usage`` object.

        Raises:
            KeyError: ``data`` or an item's ``embedding`` key is missing.
            TypeError: The payload or its ``data`` items have the wrong shape.
            ValueError: The item count or a vector dimension does not match.
        """
        # Explicit shape checks: without them a non-dict payload or item would
        # raise AttributeError from ``.get`` and bypass the caller's handler.
        if not isinstance(payload, dict):
            raise TypeError("embedding payload must be a JSON object")
        items = payload["data"]
        if not isinstance(items, list) or not all(isinstance(item, dict) for item in items):
            raise TypeError("embedding payload 'data' must be a list of objects")
        # A partial or oversized response would silently misalign vectors with
        # their source texts, so reject it outright.
        if len(items) != expected_count:
            raise ValueError(f"embedding count mismatch: {len(items)} != {expected_count}")

        # Items without an "index" sort as 0; sorted() is stable, so their
        # response order is kept.
        ordered = sorted(items, key=lambda item: item.get("index", 0))
        vectors = [item["embedding"] for item in ordered]
        self._validate_vectors(vectors)

        usage = payload.get("usage")
        total_tokens = usage.get("total_tokens") if isinstance(usage, dict) else None
        return vectors, total_tokens

    def _embeddings_url(self) -> str:
        """Return the endpoint URL, accepting a base URL or a full ``/embeddings`` URL."""
        base_url = settings.embedding_base_url.rstrip("/")
        if base_url.endswith("/embeddings"):
            return base_url
        return f"{base_url}/embeddings"

    def _validate_vectors(self, vectors: list[list[float]]) -> None:
        """Check that every vector has exactly ``settings.embedding_dim`` entries.

        Raises:
            ValueError: On the first vector whose length differs.
        """
        expected_dim = settings.embedding_dim
        for vector in vectors:
            if len(vector) != expected_dim:
                raise ValueError(f"embedding dimension mismatch: {len(vector)} != {expected_dim}")

    def _mock_vector(self, text: str) -> list[float]:
        """Build a deterministic unit-length vector from a hash of the model name and text.

        The same ``(model, text)`` pair always yields the same vector, which
        keeps mock-mode results stable across runs.
        """
        dim = settings.embedding_dim
        values: list[float] = []
        current = hashlib.sha256(f"{settings.embedding_model}:{text}".encode("utf-8")).digest()
        # Re-hash the previous digest repeatedly; each round contributes 32
        # bytes, each mapped from 0..255 onto [-1.0, 1.0].
        while len(values) < dim:
            current = hashlib.sha256(current).digest()
            values.extend((byte / 127.5) - 1.0 for byte in current)
        vector = values[:dim]
        # ``or 1.0`` avoids dividing by zero for an empty or all-zero vector.
        norm = math.sqrt(sum(item * item for item in vector)) or 1.0
        return [item / norm for item in vector]
|