import hashlib
import math
from dataclasses import dataclass

import httpx

from app.core.config import settings
from app.core.exceptions import AppError


@dataclass(frozen=True)
class EmbeddingUsage:
    """Embedding call metrics: the model used and the token usage of one batch request."""

    model: str
    # Value of the provider's `usage.total_tokens`; None when not reported (e.g. in mock mode).
    total_tokens: int | None = None


class OpenAICompatibleEmbeddingClient:
    """Embedding adapter for OpenAI-compatible ``/embeddings`` endpoints, with a deterministic mock mode."""

    @property
    def is_mock_mode(self) -> bool:
        """Return True when deterministic local vectors should be used instead of a remote call.

        Mock mode is active when ``settings.embedding_provider`` equals ``"mock"``
        (case-insensitive) or when ``settings.embedding_api_key`` is empty/unset.
        """
        return settings.embedding_provider.lower() == "mock" or not settings.embedding_api_key

    async def embed_texts(self, texts: list[str]) -> tuple[list[list[float]], EmbeddingUsage]:
        """Embed a batch of texts, returning vectors in the same order as ``texts``.

        Args:
            texts: Texts to embed. An empty list returns immediately without any call.

        Returns:
            A ``(vectors, usage)`` tuple where ``vectors[i]`` is the embedding of ``texts[i]``.

        Raises:
            AppError: ``EMBEDDING_CALL_FAILED`` (502) when the HTTP call fails, times out,
                returns an error status, or returns a non-JSON body.
                ``EMBEDDING_RESPONSE_INVALID`` (502) when the response body has an unexpected
                shape, a different number of embeddings than inputs, or vectors whose
                dimension differs from ``settings.embedding_dim``.
        """
        if not texts:
            return [], EmbeddingUsage(model=settings.embedding_model, total_tokens=0)
        if self.is_mock_mode:
            mock_vectors = [self._mock_vector(text) for text in texts]
            return mock_vectors, EmbeddingUsage(model=f"mock-{settings.embedding_model}")

        try:
            async with httpx.AsyncClient(timeout=settings.embedding_timeout_seconds) as client:
                resp = await client.post(
                    self._embeddings_url(),
                    headers={"Authorization": f"Bearer {settings.embedding_api_key}"},
                    json={"model": settings.embedding_model, "input": texts},
                )
                resp.raise_for_status()
                payload = resp.json()
        # httpx.TimeoutException is a subclass of httpx.HTTPError; ValueError covers JSON decode errors.
        except (httpx.HTTPError, ValueError) as exc:
            raise AppError("EMBEDDING_CALL_FAILED", "embedding service call failed", 502) from exc

        try:
            # Items may arrive out of order; sort by `index` so vectors line up with `texts`.
            items = sorted(payload["data"], key=lambda item: item.get("index", 0))
            vectors = [item["embedding"] for item in items]
            # A short or long batch would silently pair vectors with the wrong source texts.
            if len(vectors) != len(texts):
                raise ValueError(f"embedding count mismatch: {len(vectors)} != {len(texts)}")
            self._validate_vectors(vectors)
            usage = payload.get("usage") or {}
            total_tokens = usage.get("total_tokens")
        # AttributeError covers non-dict `data` items or `usage` values (e.g. `.get` on a list).
        except (KeyError, TypeError, ValueError, AttributeError) as exc:
            raise AppError("EMBEDDING_RESPONSE_INVALID", "embedding response format invalid", 502) from exc
        return vectors, EmbeddingUsage(model=settings.embedding_model, total_tokens=total_tokens)

    def _embeddings_url(self) -> str:
        """Build the endpoint URL, accepting either a base URL or a full ``.../embeddings`` URL."""
        base_url = settings.embedding_base_url.rstrip("/")
        if base_url.endswith("/embeddings"):
            return base_url
        return f"{base_url}/embeddings"

    def _validate_vectors(self, vectors: list[list[float]]) -> None:
        """Check that every vector is a numeric list of length ``settings.embedding_dim``.

        The dimension is expected to match the vector-store (Milvus) collection dimension.

        Raises:
            ValueError: If a vector is not a list, contains a non-numeric element,
                or has the wrong length.
        """
        expected_dim = settings.embedding_dim
        for vector in vectors:
            if not isinstance(vector, list):
                raise ValueError(f"embedding must be a list, got {type(vector).__name__}")
            if len(vector) != expected_dim:
                raise ValueError(f"embedding dimension mismatch: {len(vector)} != {expected_dim}")
            # bool is a subclass of int; reject it explicitly so JSON true/false is not accepted.
            if any(isinstance(value, bool) or not isinstance(value, (int, float)) for value in vector):
                raise ValueError("embedding contains non-numeric values")

    def _mock_vector(self, text: str) -> list[float]:
        """Return a deterministic, L2-normalized pseudo-embedding for ``text``.

        The vector is derived from a chain of SHA-256 digests seeded with
        ``"{embedding_model}:{text}"``, so it is stable across runs (useful for local and CI
        tests). Each digest byte is mapped to [-1, 1].
        """
        values: list[float] = []
        seed = hashlib.sha256(f"{settings.embedding_model}:{text}".encode("utf-8")).digest()
        current = seed
        while len(values) < settings.embedding_dim:
            current = hashlib.sha256(current).digest()
            values.extend((byte / 127.5) - 1.0 for byte in current)
        vector = values[: settings.embedding_dim]
        # Fall back to 1.0 so an all-zero (or empty) vector is returned unchanged instead of dividing by zero.
        norm = math.sqrt(sum(item * item for item in vector)) or 1.0
        return [item / norm for item in vector]