Files
fastapi/app/integrations/embedding_adapter.py
T

78 lines
3.5 KiB
Python

import hashlib
import math
from dataclasses import dataclass
import httpx
from app.core.config import settings
from app.core.exceptions import AppError
@dataclass(frozen=True)
class EmbeddingUsage:
"""Embedding 调用指标:记录批量向量化的模型和 token 用量。"""
model: str
total_tokens: int | None = None
class OpenAICompatibleEmbeddingClient:
    """Embedding adapter for OpenAI-compatible ``/embeddings`` endpoints.

    In mock mode no network call is made; vectors are derived deterministically
    from a hash of the input text instead.
    """

    @property
    def is_mock_mode(self) -> bool:
        """Return True when the provider is ``"mock"`` or no API key is configured."""
        return settings.embedding_provider.lower() == "mock" or not settings.embedding_api_key

    async def embed_texts(self, texts: list[str]) -> tuple[list[list[float]], EmbeddingUsage]:
        """Embed a batch of texts, returning vectors in the same order as ``texts``.

        Args:
            texts: Input strings to embed. An empty list short-circuits without
                any remote call.

        Returns:
            A ``(vectors, usage)`` pair with one vector per input text.

        Raises:
            AppError: ``EMBEDDING_CALL_FAILED`` when the HTTP call fails or the
                body is not valid JSON; ``EMBEDDING_RESPONSE_INVALID`` when the
                JSON body does not have the expected shape, vector count, or
                vector dimension.
        """
        if not texts:
            return [], EmbeddingUsage(model=settings.embedding_model, total_tokens=0)
        if self.is_mock_mode:
            mock_vectors = [self._mock_vector(text) for text in texts]
            return mock_vectors, EmbeddingUsage(model=f"mock-{settings.embedding_model}")
        try:
            async with httpx.AsyncClient(timeout=settings.embedding_timeout_seconds) as client:
                resp = await client.post(
                    self._embeddings_url(),
                    headers={"Authorization": f"Bearer {settings.embedding_api_key}"},
                    json={"model": settings.embedding_model, "input": texts},
                )
                resp.raise_for_status()
                # resp.json() raises a ValueError subclass on a non-JSON body.
                payload = resp.json()
        except (httpx.TimeoutException, httpx.HTTPError, ValueError) as exc:
            raise AppError("EMBEDDING_CALL_FAILED", "embedding service call failed", 502) from exc
        try:
            vectors = self._extract_vectors(payload, len(texts))
            self._validate_vectors(vectors)
            # _extract_vectors has already verified that payload is a dict.
            usage = payload.get("usage")
            # Tolerate a missing or malformed usage section instead of crashing.
            total_tokens = usage.get("total_tokens") if isinstance(usage, dict) else None
        except (KeyError, TypeError, ValueError) as exc:
            raise AppError("EMBEDDING_RESPONSE_INVALID", "embedding response format invalid", 502) from exc
        return vectors, EmbeddingUsage(model=settings.embedding_model, total_tokens=total_tokens)

    @staticmethod
    def _extract_vectors(payload: object, expected_count: int) -> list[list[float]]:
        """Pull the embedding vectors out of a decoded response body.

        Items are ordered by their ``index`` field (missing indexes count as 0,
        and the stable sort keeps such items in response order).

        Args:
            payload: Decoded JSON response body.
            expected_count: Number of input texts that were sent.

        Returns:
            The ``embedding`` value of every ``data`` item, in index order.

        Raises:
            KeyError: ``data`` or an item's ``embedding`` key is missing.
            TypeError: The payload, ``data``, an item, or an embedding has the
                wrong JSON type.
            ValueError: The number of vectors differs from ``expected_count``.
        """
        if not isinstance(payload, dict):
            raise TypeError("embedding response must be a JSON object")
        items = payload["data"]
        if not isinstance(items, list) or not all(isinstance(item, dict) for item in items):
            raise TypeError("embedding response data must be a list of objects")
        ordered = sorted(items, key=lambda item: item.get("index", 0))
        vectors = [item["embedding"] for item in ordered]
        if not all(isinstance(vector, list) for vector in vectors):
            raise TypeError("embedding values must be lists")
        # A short or long response would silently misalign vectors with inputs.
        if len(vectors) != expected_count:
            raise ValueError(f"embedding count mismatch: {len(vectors)} != {expected_count}")
        return vectors

    def _embeddings_url(self) -> str:
        """Build the endpoint URL from either a base URL or a full ``/embeddings`` URL."""
        base_url = settings.embedding_base_url.rstrip("/")
        if base_url.endswith("/embeddings"):
            return base_url
        return f"{base_url}/embeddings"

    def _validate_vectors(self, vectors: list[list[float]]) -> None:
        """Check that every vector has exactly ``settings.embedding_dim`` entries.

        The original note states this dimension must match the Milvus
        collection dimension; that coupling is not visible in this file.

        Raises:
            ValueError: A vector's length differs from the configured dimension.
        """
        for vector in vectors:
            if len(vector) != settings.embedding_dim:
                raise ValueError(f"embedding dimension mismatch: {len(vector)} != {settings.embedding_dim}")

    def _mock_vector(self, text: str) -> list[float]:
        """Generate a stable, L2-normalized vector from a hash of the model name and text.

        The same model/text pair always yields the same vector, which keeps
        local and CI runs reproducible.
        """
        values: list[float] = []
        seed = hashlib.sha256(f"{settings.embedding_model}:{text}".encode("utf-8")).digest()
        current = seed
        # Re-hash repeatedly; each digest contributes 32 values until the dimension is filled.
        while len(values) < settings.embedding_dim:
            current = hashlib.sha256(current).digest()
            # Map each byte from 0..255 onto -1.0..1.0.
            values.extend((byte / 127.5) - 1.0 for byte in current)
        vector = values[: settings.embedding_dim]
        # Fall back to 1.0 so a zero norm never causes a division by zero.
        norm = math.sqrt(sum(item * item for item in vector)) or 1.0
        return [item / norm for item in vector]