feat: add streaming learning assistant and knowledge base scaffolding
This commit is contained in:
@@ -0,0 +1 @@
|
||||
"""外部能力适配层:封装 PDF 解析、Embedding 和 Milvus 等可替换基础设施。"""
|
||||
@@ -0,0 +1,77 @@
|
||||
import hashlib
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
|
||||
import httpx
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.exceptions import AppError
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class EmbeddingUsage:
|
||||
"""Embedding 调用指标:记录批量向量化的模型和 token 用量。"""
|
||||
|
||||
model: str
|
||||
total_tokens: int | None = None
|
||||
|
||||
|
||||
class OpenAICompatibleEmbeddingClient:
|
||||
"""Embedding Adapter:封装 OpenAI-compatible embeddings 接口,并提供稳定 mock。"""
|
||||
|
||||
@property
|
||||
def is_mock_mode(self) -> bool:
|
||||
"""模式判断:没有 API Key 或显式 mock provider 时使用确定性本地向量。"""
|
||||
return settings.embedding_provider.lower() == "mock" or not settings.embedding_api_key
|
||||
|
||||
async def embed_texts(self, texts: list[str]) -> tuple[list[list[float]], EmbeddingUsage]:
|
||||
"""文本向量化:对文本批量生成 embedding,返回与输入顺序一致的向量列表。"""
|
||||
if not texts:
|
||||
return [], EmbeddingUsage(model=settings.embedding_model, total_tokens=0)
|
||||
if self.is_mock_mode:
|
||||
return [self._mock_vector(text) for text in texts], EmbeddingUsage(model=f"mock-{settings.embedding_model}")
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=settings.embedding_timeout_seconds) as client:
|
||||
resp = await client.post(
|
||||
self._embeddings_url(),
|
||||
headers={"Authorization": f"Bearer {settings.embedding_api_key}"},
|
||||
json={"model": settings.embedding_model, "input": texts},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
payload = resp.json()
|
||||
except (httpx.TimeoutException, httpx.HTTPError, ValueError) as exc:
|
||||
raise AppError("EMBEDDING_CALL_FAILED", "embedding service call failed", 502) from exc
|
||||
|
||||
try:
|
||||
vectors = [item["embedding"] for item in sorted(payload["data"], key=lambda item: item.get("index", 0))]
|
||||
self._validate_vectors(vectors)
|
||||
usage = payload.get("usage") or {}
|
||||
return vectors, EmbeddingUsage(model=settings.embedding_model, total_tokens=usage.get("total_tokens"))
|
||||
except (KeyError, TypeError, ValueError) as exc:
|
||||
raise AppError("EMBEDDING_RESPONSE_INVALID", "embedding response format invalid", 502) from exc
|
||||
|
||||
def _embeddings_url(self) -> str:
|
||||
"""接口地址:兼容 base URL 和完整 /embeddings URL 两种写法。"""
|
||||
base_url = settings.embedding_base_url.rstrip("/")
|
||||
if base_url.endswith("/embeddings"):
|
||||
return base_url
|
||||
return f"{base_url}/embeddings"
|
||||
|
||||
def _validate_vectors(self, vectors: list[list[float]]) -> None:
|
||||
"""向量校验:确保向量维度与 Milvus collection 维度一致。"""
|
||||
for vector in vectors:
|
||||
if len(vector) != settings.embedding_dim:
|
||||
raise ValueError(f"embedding dimension mismatch: {len(vector)} != {settings.embedding_dim}")
|
||||
|
||||
def _mock_vector(self, text: str) -> list[float]:
|
||||
"""Mock向量:基于文本哈希生成稳定归一化向量,便于本地和CI测试。"""
|
||||
values: list[float] = []
|
||||
seed = hashlib.sha256(f"{settings.embedding_model}:{text}".encode("utf-8")).digest()
|
||||
current = seed
|
||||
while len(values) < settings.embedding_dim:
|
||||
current = hashlib.sha256(current).digest()
|
||||
values.extend((byte / 127.5) - 1.0 for byte in current)
|
||||
vector = values[: settings.embedding_dim]
|
||||
norm = math.sqrt(sum(item * item for item in vector)) or 1.0
|
||||
return [item / norm for item in vector]
|
||||
@@ -0,0 +1,112 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.exceptions import AppError
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class VectorSearchHit:
|
||||
"""向量检索命中:只保存 chunk_uid 和相似度,来源详情从 MySQL 读取。"""
|
||||
|
||||
chunk_uid: str
|
||||
score: float
|
||||
|
||||
|
||||
class MilvusVectorStore:
|
||||
"""Milvus 向量库适配器:按机构 collection 写入和检索知识分片向量。"""
|
||||
|
||||
_mock_store: dict[str, dict[str, list[float]]] = {}
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.mock_enabled = settings.milvus_uri.startswith("mock://")
|
||||
self._client = None
|
||||
|
||||
def ensure_collection(self, collection_name: str) -> None:
|
||||
"""集合初始化:不存在时创建 VARCHAR 主键 + FLOAT_VECTOR 的 Milvus collection。"""
|
||||
if self.mock_enabled:
|
||||
self._mock_store.setdefault(collection_name, {})
|
||||
return
|
||||
client = self._client_or_raise()
|
||||
try:
|
||||
if client.has_collection(collection_name=collection_name):
|
||||
return
|
||||
schema = client.create_schema(auto_id=False, enable_dynamic_field=False)
|
||||
from pymilvus import DataType
|
||||
|
||||
schema.add_field(field_name="id", datatype=DataType.VARCHAR, is_primary=True, max_length=128)
|
||||
schema.add_field(field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=settings.embedding_dim)
|
||||
index_params = client.prepare_index_params()
|
||||
index_params.add_index(field_name="vector", index_type="AUTOINDEX", metric_type="COSINE")
|
||||
client.create_collection(
|
||||
collection_name=collection_name,
|
||||
schema=schema,
|
||||
index_params=index_params,
|
||||
consistency_level="Strong",
|
||||
)
|
||||
except Exception as exc: # pragma: no cover - 真实 Milvus 由联调环境验证
|
||||
raise AppError("MILVUS_COLLECTION_INIT_FAILED", "milvus collection init failed", 502) from exc
|
||||
|
||||
def upsert_vectors(self, collection_name: str, vectors: list[tuple[str, list[float]]]) -> None:
|
||||
"""向量写入:使用 chunk_uid 作为 Milvus 主键,保证重复构建可覆盖。"""
|
||||
if not vectors:
|
||||
return
|
||||
self.ensure_collection(collection_name)
|
||||
if self.mock_enabled:
|
||||
collection = self._mock_store.setdefault(collection_name, {})
|
||||
for chunk_uid, vector in vectors:
|
||||
collection[chunk_uid] = vector
|
||||
return
|
||||
client = self._client_or_raise()
|
||||
try:
|
||||
client.upsert(
|
||||
collection_name=collection_name,
|
||||
data=[{"id": chunk_uid, "vector": vector} for chunk_uid, vector in vectors],
|
||||
)
|
||||
except Exception as exc: # pragma: no cover
|
||||
raise AppError("MILVUS_UPSERT_FAILED", "milvus vector upsert failed", 502) from exc
|
||||
|
||||
def search(self, collection_name: str, query_vector: list[float], limit: int) -> list[VectorSearchHit]:
|
||||
"""向量检索:按余弦相似度返回候选 chunk_uid,后续由业务层过滤阈值。"""
|
||||
self.ensure_collection(collection_name)
|
||||
if self.mock_enabled:
|
||||
return self._mock_search(collection_name, query_vector, limit)
|
||||
client = self._client_or_raise()
|
||||
try:
|
||||
results = client.search(
|
||||
collection_name=collection_name,
|
||||
data=[query_vector],
|
||||
anns_field="vector",
|
||||
limit=limit,
|
||||
search_params={"metric_type": "COSINE"},
|
||||
output_fields=["id"],
|
||||
)
|
||||
except Exception as exc: # pragma: no cover
|
||||
raise AppError("MILVUS_SEARCH_FAILED", "milvus vector search failed", 502) from exc
|
||||
|
||||
hits: list[VectorSearchHit] = []
|
||||
for item in results[0] if results else []:
|
||||
entity = item.get("entity") or {}
|
||||
chunk_uid = str(entity.get("id") or item.get("id") or "")
|
||||
if chunk_uid:
|
||||
hits.append(VectorSearchHit(chunk_uid=chunk_uid, score=float(item.get("distance") or item.get("score") or 0)))
|
||||
return hits
|
||||
|
||||
def _client_or_raise(self):
|
||||
"""客户端获取:懒加载 pymilvus,避免未使用知识库时影响现有训练接口。"""
|
||||
if self._client is not None:
|
||||
return self._client
|
||||
try:
|
||||
from pymilvus import MilvusClient
|
||||
except ImportError as exc:
|
||||
raise AppError("MILVUS_CLIENT_NOT_INSTALLED", "pymilvus is required for vector search", 500) from exc
|
||||
self._client = MilvusClient(uri=settings.milvus_uri, db_name=settings.milvus_default_db)
|
||||
return self._client
|
||||
|
||||
def _mock_search(self, collection_name: str, query_vector: list[float], limit: int) -> list[VectorSearchHit]:
|
||||
"""Mock检索:用向量点积模拟余弦排序,便于无 Milvus 环境测试。"""
|
||||
collection = self._mock_store.get(collection_name, {})
|
||||
scored = [
|
||||
VectorSearchHit(chunk_uid=chunk_uid, score=(sum(a * b for a, b in zip(query_vector, vector)) + 1.0) / 2.0)
|
||||
for chunk_uid, vector in collection.items()
|
||||
]
|
||||
return sorted(scored, key=lambda item: item.score, reverse=True)[:limit]
|
||||
@@ -0,0 +1,44 @@
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
from app.core.exceptions import AppError
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ParsedPdfPage:
|
||||
"""PDF 页文本:保留页码,支撑 RAG 回答中的来源页码引用。"""
|
||||
|
||||
page_number: int
|
||||
text: str
|
||||
|
||||
|
||||
class PdfParser:
|
||||
"""PDF 解析器:使用 PyMuPDF 逐页提取教材、指南等 PDF 文本。"""
|
||||
|
||||
def parse(self, file_path: str | Path) -> list[ParsedPdfPage]:
|
||||
"""PDF解析:逐页读取文本并过滤空页,失败时返回统一业务异常。"""
|
||||
path = Path(file_path)
|
||||
if not path.exists():
|
||||
raise AppError("PDF_FILE_NOT_FOUND", "uploaded pdf file not found", 404)
|
||||
try:
|
||||
import fitz # PyMuPDF
|
||||
except ImportError as exc:
|
||||
raise AppError("PDF_PARSER_NOT_INSTALLED", "PyMuPDF is required for pdf parsing", 500) from exc
|
||||
|
||||
pages: list[ParsedPdfPage] = []
|
||||
try:
|
||||
with fitz.open(path) as doc:
|
||||
for index, page in enumerate(doc, start=1):
|
||||
text = self._clean_text(page.get_text("text") or "")
|
||||
if text:
|
||||
pages.append(ParsedPdfPage(page_number=index, text=text))
|
||||
except Exception as exc: # pragma: no cover - PyMuPDF 异常类型较多,统一转换即可
|
||||
raise AppError("PDF_PARSE_FAILED", "pdf parse failed", 422) from exc
|
||||
if not pages:
|
||||
raise AppError("PDF_PARSE_EMPTY", "pdf text content is empty", 422)
|
||||
return pages
|
||||
|
||||
def _clean_text(self, text: str) -> str:
|
||||
"""文本清洗:压缩空白并保留自然换行,便于后续教材分片。"""
|
||||
lines = [" ".join(line.strip().split()) for line in text.splitlines()]
|
||||
return "\n".join(line for line in lines if line)
|
||||
Reference in New Issue
Block a user