feat: add streaming learning assistant and knowledge base scaffolding
This commit is contained in:
@@ -0,0 +1,154 @@
|
||||
import hashlib
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
|
||||
from app.integrations.pdf_parser import ParsedPdfPage
|
||||
from app.models.knowledge_base import KbKnowledgeChunk
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ChunkDraft:
|
||||
"""分片草稿:PDF 文本切分后的中间结构,后续写入 MySQL 和 Milvus。"""
|
||||
|
||||
chunk_index: int
|
||||
page_start: int
|
||||
page_end: int
|
||||
section_title: str | None
|
||||
text: str
|
||||
|
||||
|
||||
class DocumentChunkService:
|
||||
"""文档分片服务:面向教材/指南 PDF 的页码保留语义分片。"""
|
||||
|
||||
def __init__(self, chunk_size: int = 1100, chunk_overlap: int = 180) -> None:
|
||||
self.chunk_size = chunk_size
|
||||
self.chunk_overlap = chunk_overlap
|
||||
|
||||
def build_chunks(self, pages: list[ParsedPdfPage]) -> list[ChunkDraft]:
|
||||
"""教材分片:按页和自然段切分,超长段落使用窗口切分并保留页码。"""
|
||||
drafts: list[ChunkDraft] = []
|
||||
buffer: list[str] = []
|
||||
page_start: int | None = None
|
||||
page_end: int | None = None
|
||||
current_title: str | None = None
|
||||
|
||||
for page in pages:
|
||||
paragraphs = self._split_paragraphs(page.text)
|
||||
for paragraph in paragraphs:
|
||||
detected_title = self._detect_title(paragraph)
|
||||
if detected_title:
|
||||
current_title = detected_title
|
||||
for piece in self._split_long_text(paragraph):
|
||||
candidate = "\n".join([*buffer, piece]).strip()
|
||||
if buffer and len(candidate) > self.chunk_size:
|
||||
drafts.append(
|
||||
ChunkDraft(
|
||||
chunk_index=len(drafts),
|
||||
page_start=page_start or page.page_number,
|
||||
page_end=page_end or page.page_number,
|
||||
section_title=current_title,
|
||||
text="\n".join(buffer).strip(),
|
||||
)
|
||||
)
|
||||
buffer = self._overlap_tail(buffer)
|
||||
page_start = page.page_number if not buffer else page_start
|
||||
if not buffer:
|
||||
page_start = page.page_number
|
||||
page_end = page.page_number
|
||||
buffer.append(piece)
|
||||
|
||||
if buffer:
|
||||
drafts.append(
|
||||
ChunkDraft(
|
||||
chunk_index=len(drafts),
|
||||
page_start=page_start or pages[-1].page_number,
|
||||
page_end=page_end or pages[-1].page_number,
|
||||
section_title=current_title,
|
||||
text="\n".join(buffer).strip(),
|
||||
)
|
||||
)
|
||||
return [draft for draft in drafts if draft.text]
|
||||
|
||||
def to_models(
|
||||
self,
|
||||
*,
|
||||
institution_id: int,
|
||||
document_id: int,
|
||||
collection_name: str,
|
||||
embedding_model: str,
|
||||
drafts: list[ChunkDraft],
|
||||
) -> list[KbKnowledgeChunk]:
|
||||
"""分片落库:把分片草稿转换为 ORM 对象,chunk_uid 同时作为 Milvus vector_id。"""
|
||||
rows: list[KbKnowledgeChunk] = []
|
||||
for draft in drafts:
|
||||
chunk_hash = hashlib.sha256(draft.text.encode("utf-8")).hexdigest()
|
||||
chunk_uid = f"doc{document_id}_chunk{draft.chunk_index}_{chunk_hash[:12]}"
|
||||
rows.append(
|
||||
KbKnowledgeChunk(
|
||||
institution_id=institution_id,
|
||||
document_id=document_id,
|
||||
chunk_uid=chunk_uid,
|
||||
chunk_index=draft.chunk_index,
|
||||
page_start=draft.page_start,
|
||||
page_end=draft.page_end,
|
||||
section_title=draft.section_title,
|
||||
chunk_text=draft.text,
|
||||
chunk_hash=chunk_hash,
|
||||
token_count=max(1, len(draft.text) // 2),
|
||||
vector_id=chunk_uid,
|
||||
collection_name=collection_name,
|
||||
embedding_model=embedding_model,
|
||||
metadata_={"chunking": "page_semantic_window", "chunk_size": self.chunk_size, "overlap": self.chunk_overlap},
|
||||
)
|
||||
)
|
||||
return rows
|
||||
|
||||
def _split_paragraphs(self, text: str) -> list[str]:
|
||||
"""段落切分:优先按 PDF 自带换行和空白段落切分教材内容。"""
|
||||
parts = re.split(r"\n{1,}", text)
|
||||
return [part.strip() for part in parts if part.strip()]
|
||||
|
||||
def _split_long_text(self, text: str) -> list[str]:
|
||||
"""超长兜底:对超过窗口的段落按句末标点拆分,仍过长时按字符窗口切分。"""
|
||||
if len(text) <= self.chunk_size:
|
||||
return [text]
|
||||
sentences = re.split(r"(?<=[。!?;;.!?])", text)
|
||||
pieces: list[str] = []
|
||||
current = ""
|
||||
for sentence in sentences:
|
||||
if len(current) + len(sentence) > self.chunk_size and current:
|
||||
pieces.append(current.strip())
|
||||
current = current[-self.chunk_overlap :] if self.chunk_overlap else ""
|
||||
current += sentence
|
||||
if current.strip():
|
||||
pieces.append(current.strip())
|
||||
final: list[str] = []
|
||||
for piece in pieces:
|
||||
if len(piece) <= self.chunk_size:
|
||||
final.append(piece)
|
||||
continue
|
||||
start = 0
|
||||
while start < len(piece):
|
||||
final.append(piece[start : start + self.chunk_size])
|
||||
start += max(1, self.chunk_size - self.chunk_overlap)
|
||||
return final
|
||||
|
||||
def _overlap_tail(self, buffer: list[str]) -> list[str]:
|
||||
"""重叠窗口:保留上一片尾部少量文本,提升跨片问题召回。"""
|
||||
if not self.chunk_overlap:
|
||||
return []
|
||||
text = "\n".join(buffer).strip()
|
||||
tail = text[-self.chunk_overlap :]
|
||||
return [tail] if tail else []
|
||||
|
||||
def _detect_title(self, paragraph: str) -> str | None:
|
||||
"""标题识别:识别教材常见章、节、条目标题,作为分片元数据。"""
|
||||
compact = paragraph.strip()
|
||||
if len(compact) > 80:
|
||||
return None
|
||||
title_patterns = [
|
||||
r"^第[一二三四五六七八九十百0-9]+[章节篇]",
|
||||
r"^[一二三四五六七八九十]+[、..]",
|
||||
r"^\d+(\.\d+){0,3}\s+",
|
||||
]
|
||||
return compact if any(re.search(pattern, compact) for pattern in title_patterns) else None
|
||||
@@ -0,0 +1,177 @@
|
||||
import hashlib
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import UploadFile
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.context import UserContext
|
||||
from app.core.exceptions import AppError
|
||||
from app.integrations.milvus_adapter import MilvusVectorStore
|
||||
from app.integrations.pdf_parser import PdfParser
|
||||
from app.repositories.knowledge_base_repository import KnowledgeBaseRepository
|
||||
from app.schemas.knowledge_admin import KnowledgeDocumentUploadResponse
|
||||
from app.services.document_chunk_service import DocumentChunkService
|
||||
from app.services.embedding_service import EmbeddingService
|
||||
from app.services.knowledge_space_service import KnowledgeSpaceService
|
||||
|
||||
|
||||
class DocumentIngestionService:
    """Knowledge ingestion: PDF upload, parsing, chunking, embedding and vector writes."""

    def __init__(
        self,
        db: Session,
        *,
        parser: PdfParser | None = None,
        chunker: DocumentChunkService | None = None,
        embedding_service: EmbeddingService | None = None,
        vector_store: MilvusVectorStore | None = None,
    ) -> None:
        """Wire the repository and collaborators; any collaborator left as None gets a default instance."""
        self.db = db
        self.repo = KnowledgeBaseRepository(db)
        self.space_service = KnowledgeSpaceService(self.repo)
        self.parser = parser or PdfParser()
        self.chunker = chunker or DocumentChunkService()
        self.embedding_service = embedding_service or EmbeddingService()
        self.vector_store = vector_store or MilvusVectorStore()

    async def upload_pdf(
        self,
        ctx: UserContext,
        file: UploadFile,
        *,
        document_title: str | None,
        document_category: str,
        version: str,
    ) -> KnowledgeDocumentUploadResponse:
        """Accept a PDF upload, register the document and start its ingestion.

        A file whose SHA-256 already exists for the institution is reported as
        a duplicate; if that earlier document is in ``failed`` state a new
        ingestion task is created and dispatched so the retry actually runs.

        Raises:
            AppError: on permission, validation, enqueue or (in synchronous
                mode) ingestion failures.
        """
        self.space_service.ensure_content_admin(ctx)
        space = self.space_service.get_or_create_space(ctx)

        # Read at most one byte past the limit: enough for _validate_pdf to
        # reject an oversized upload without buffering the whole file.
        max_bytes = int(settings.knowledge_max_upload_mb * 1024 * 1024)
        content = await file.read(max_bytes + 1)
        self._validate_pdf(file, content)

        file_sha256 = hashlib.sha256(content).hexdigest()
        existing = self.repo.get_document_by_hash(space.institution_id, file_sha256)
        if existing:
            retry_task = None
            if existing.status == "failed":
                retry_task = self.repo.create_ingestion_task(existing)
                await self._dispatch_ingestion(existing.id, retry_task.id)
            return self._build_upload_response(
                existing,
                space,
                task_id=retry_task.id if retry_task else None,
                duplicate=True,
            )

        storage_path = self._save_file(space.institution_id, file.filename or "knowledge.pdf", content)
        document = self.repo.create_document(
            institution_id=space.institution_id,
            uploaded_by=ctx.user_id,
            file_name=file.filename or storage_path.name,
            file_sha256=file_sha256,
            file_size=len(content),
            file_path=str(storage_path),
            document_title=document_title or Path(file.filename or "").stem or None,
            document_category=document_category,
            version=version,
        )
        task = self.repo.create_ingestion_task(document)
        await self._dispatch_ingestion(document.id, task.id)
        return self._build_upload_response(document, space, task_id=task.id, duplicate=False)

    async def _dispatch_ingestion(self, document_id: int, task_id: int) -> None:
        """Run ingestion inline when ``settings.knowledge_ingestion_sync`` is set, otherwise enqueue it."""
        if settings.knowledge_ingestion_sync:
            await self.ingest_document(document_id, task_id)
        else:
            self._enqueue_async_task(document_id, task_id)

    def _build_upload_response(
        self,
        document,
        space,
        *,
        task_id: int | None,
        duplicate: bool,
    ) -> KnowledgeDocumentUploadResponse:
        """Build the upload response from the document's current status fields."""
        return KnowledgeDocumentUploadResponse(
            document_id=document.id,
            task_id=task_id,
            duplicate=duplicate,
            status=document.status,
            parse_status=document.parse_status,
            embedding_status=document.embedding_status,
            chunk_count=document.chunk_count,
            collection_name=space.collection_name,
        )

    async def ingest_document(self, document_id: int, task_id: int | None = None) -> None:
        """Turn an uploaded PDF into stored chunks and vectors.

        Steps: parse the PDF, build chunks, embed them, upsert the vectors into
        the space's collection, then replace the document's chunks through the
        repository. Status and progress are updated on the document and on the
        task (when one is given) along the way.

        Raises:
            AppError: ``KNOWLEDGE_DOCUMENT_NOT_FOUND`` if the document does not
                exist; any AppError raised by a step is re-raised as is; every
                other failure is wrapped as ``KNOWLEDGE_INGESTION_FAILED``.
        """
        document = self.repo.get_document(document_id)
        if not document:
            raise AppError("KNOWLEDGE_DOCUMENT_NOT_FOUND", "knowledge document not found", 404)
        task = self.repo.get_ingestion_task(task_id) if task_id else None
        try:
            if task:
                self.repo.update_task(task, status="running", progress=5, current_step="parse_pdf")
            document.status = "processing"
            document.parse_status = "running"
            self.db.flush()

            pages = self.parser.parse(document.file_path)
            space = self.repo.get_or_create_space(document.institution_id, None)
            drafts = self.chunker.build_chunks(pages)
            chunks = self.chunker.to_models(
                institution_id=document.institution_id,
                document_id=document.id,
                collection_name=space.collection_name,
                embedding_model=settings.embedding_model,
                drafts=drafts,
            )
            if task:
                self.repo.update_task(task, progress=35, current_step="embed_chunks")
            document.parse_status = "success"
            document.embedding_status = "running"
            self.db.flush()

            vectors, _embedding_latency_ms = await self.embedding_service.embed_texts([chunk.chunk_text for chunk in chunks])
            # zip() would silently drop chunks on a short result, leaving
            # stored chunks without vectors; fail loudly instead.
            if len(vectors) != len(chunks):
                raise ValueError(f"embedding service returned {len(vectors)} vectors for {len(chunks)} chunks")
            if task:
                self.repo.update_task(task, progress=75, current_step="write_vectors")
            self.vector_store.upsert_vectors(
                space.collection_name,
                [(chunk.chunk_uid, vector) for chunk, vector in zip(chunks, vectors)],
            )
            self.repo.replace_chunks(document, chunks)
            document.status = "ready"
            document.embedding_status = "success"
            if task:
                self.repo.update_task(task, status="success", progress=100, current_step="completed")
        except Exception as exc:
            failure_message = str(exc)[:2000]
            document.status = "failed"
            document.error_message = failure_message
            # Keep "success" if parsing finished before a later step failed.
            document.parse_status = document.parse_status if document.parse_status == "success" else "failed"
            document.embedding_status = "failed"
            if task:
                self.repo.update_task(task, status="failed", progress=100, current_step="failed", error_message=failure_message)
            if isinstance(exc, AppError):
                raise
            raise AppError("KNOWLEDGE_INGESTION_FAILED", "knowledge document ingestion failed", 500) from exc

    def _validate_pdf(self, file: UploadFile, content: bytes) -> None:
        """Reject empty, oversized or non-PDF uploads.

        The upload must be non-empty, within ``settings.knowledge_max_upload_mb``,
        have either a ``.pdf`` filename or an accepted content type, and start
        with the ``%PDF`` magic bytes.

        Raises:
            AppError: with a specific ``UPLOAD_FILE_*`` code for each failed check.
        """
        if not content:
            raise AppError("UPLOAD_FILE_EMPTY", "uploaded file is empty", 422)
        max_bytes = settings.knowledge_max_upload_mb * 1024 * 1024
        if len(content) > max_bytes:
            raise AppError("UPLOAD_FILE_TOO_LARGE", f"uploaded file exceeds {settings.knowledge_max_upload_mb}MB", 413)
        filename = (file.filename or "").lower()
        if not filename.endswith(".pdf") and file.content_type not in {"application/pdf", "application/octet-stream"}:
            raise AppError("UPLOAD_FILE_TYPE_INVALID", "only pdf file is supported", 422)
        if not content.startswith(b"%PDF"):
            raise AppError("UPLOAD_FILE_NOT_PDF", "uploaded file is not a valid pdf", 422)

    def _save_file(self, institution_id: int, filename: str, content: bytes) -> Path:
        """Write the raw PDF under a per-institution directory and return its path.

        Characters other than alphanumerics, ``.``, ``_`` and ``-`` are replaced
        with ``_``, and the name is prefixed with the first 16 hex characters of
        the content's SHA-256.
        """
        safe_name = "".join(ch if ch.isalnum() or ch in {".", "_", "-"} else "_" for ch in filename)
        storage_dir = Path(settings.knowledge_storage_dir) / "raw" / str(institution_id)
        storage_dir.mkdir(parents=True, exist_ok=True)
        target = storage_dir / f"{hashlib.sha256(content).hexdigest()[:16]}_{safe_name}"
        target.write_bytes(content)
        return target

    def _enqueue_async_task(self, document_id: int, task_id: int) -> None:
        """Hand the ingestion over to the Celery task ``ingest_knowledge_document``.

        Raises:
            AppError: ``KNOWLEDGE_TASK_ENQUEUE_FAILED`` if importing or
                enqueuing the task fails.
        """
        try:
            # Imported lazily so the task module is only loaded when needed.
            from app.tasks.knowledge_ingestion_tasks import ingest_knowledge_document

            ingest_knowledge_document.delay(document_id, task_id)
        except Exception as exc:  # pragma: no cover - original note: keep the task queued when Celery is not running
            raise AppError("KNOWLEDGE_TASK_ENQUEUE_FAILED", "knowledge ingestion task enqueue failed", 500) from exc
|
||||
@@ -0,0 +1,22 @@
|
||||
import time
|
||||
|
||||
from app.core.config import settings
|
||||
from app.integrations.embedding_adapter import OpenAICompatibleEmbeddingClient
|
||||
|
||||
|
||||
class EmbeddingService:
    """Embed texts in batches through the configured embedding client and time the whole call."""

    def __init__(self, client: OpenAICompatibleEmbeddingClient | None = None) -> None:
        """Use the given client, or create a default ``OpenAICompatibleEmbeddingClient``."""
        self.client = client if client else OpenAICompatibleEmbeddingClient()

    async def embed_texts(self, texts: list[str]) -> tuple[list[list[float]], int]:
        """Embed ``texts`` in slices of ``settings.embedding_batch_size`` (at least 1).

        Returns:
            The vectors of all batches concatenated in batch order, and the
            total elapsed wall time in milliseconds.
        """
        began = time.perf_counter()
        collected: list[list[float]] = []
        step = max(1, settings.embedding_batch_size)
        offset = 0
        while offset < len(texts):
            # The client returns a (vectors, usage) pair; usage is not needed here.
            batch_vectors, _ = await self.client.embed_texts(texts[offset : offset + step])
            collected.extend(batch_vectors)
            offset += step
        elapsed_ms = int((time.perf_counter() - began) * 1000)
        return collected, elapsed_ms
|
||||
@@ -0,0 +1,39 @@
|
||||
from app.core.context import UserContext
|
||||
from app.core.exceptions import AppError
|
||||
from app.models.knowledge_base import KbKnowledgeSpace
|
||||
from app.repositories.knowledge_base_repository import KnowledgeBaseRepository
|
||||
|
||||
|
||||
class KnowledgeSpaceService:
    """Locate the knowledge space for the institution the user belongs to."""

    def __init__(self, repo: KnowledgeBaseRepository) -> None:
        """Keep the repository used for all knowledge-space lookups."""
        self.repo = repo

    def require_institution_id(self, ctx: UserContext) -> int:
        """Return the caller's institution id as an int.

        Raises:
            AppError: ``INSTITUTION_REQUIRED`` (403) when the context has no
                institution id.
        """
        if ctx.institution_id is not None:
            return int(ctx.institution_id)
        raise AppError("INSTITUTION_REQUIRED", "institution_id is required for knowledge base", 403)

    def get_or_create_space(self, ctx: UserContext) -> KbKnowledgeSpace:
        """Fetch the institution's knowledge space, creating it if needed.

        The space name comes from the profile's ``institution_name``; when that
        is missing or empty, ``institution_{id}`` is used.
        """
        institution_id = self.require_institution_id(ctx)
        display_name = (ctx.profile or {}).get("institution_name")
        if not display_name:
            display_name = f"institution_{institution_id}"
        return self.repo.get_or_create_space(institution_id, display_name)

    def get_active_space(self, ctx: UserContext) -> KbKnowledgeSpace:
        """Return the institution's existing knowledge space.

        Raises:
            AppError: ``KNOWLEDGE_SPACE_NOT_FOUND`` (404) when the repository
                has no space for the institution.
        """
        found = self.repo.get_space(self.require_institution_id(ctx))
        if found:
            return found
        raise AppError("KNOWLEDGE_SPACE_NOT_FOUND", "knowledge space not initialized for institution", 404)

    def ensure_content_admin(self, ctx: UserContext) -> None:
        """Allow only admin-type roles to upload knowledge documents.

        The role is compared case-insensitively.

        Raises:
            AppError: ``KNOWLEDGE_ADMIN_FORBIDDEN`` (403) for any other role.
        """
        normalized_role = (ctx.role or "").lower()
        if normalized_role in {"content_admin", "institution_admin", "admin", "super_admin"}:
            return
        raise AppError("KNOWLEDGE_ADMIN_FORBIDDEN", "only content admin can upload knowledge documents", 403)
|
||||
@@ -0,0 +1,224 @@
|
||||
import json
|
||||
import time
|
||||
from collections.abc import AsyncIterator
|
||||
from dataclasses import dataclass
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.agents.learning_assistant_agent import LearningAssistantAgent
|
||||
from app.core.config import settings
|
||||
from app.core.context import UserContext
|
||||
from app.core.exceptions import AppError
|
||||
from app.repositories.knowledge_base_repository import KnowledgeBaseRepository
|
||||
from app.schemas.learning_assistant import LearningAssistantChatRequest, LearningAssistantChatResponse, LearningAssistantSource
|
||||
from app.services.knowledge_space_service import KnowledgeSpaceService
|
||||
from app.services.vector_search_service import RetrievedChunk, VectorSearchService
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class LearningAssistantRetrieval:
    """Outcome of one retrieval attempt: matched sources, timings and any fallback reason."""

    # Institution the retrieval ran for; None when the user has no institution.
    institution_id: int | None
    # Similarity threshold recorded for this retrieval.
    score_threshold: float
    # Sources passed on to the answer step; empty when nothing was retrieved.
    sources: list[LearningAssistantSource]
    # Stage timings in milliseconds; None when the stage produced no timing.
    embedding_latency_ms: int | None = None
    search_latency_ms: int | None = None
    # User-facing explanation when retrieval was skipped or failed.
    retrieval_error: str | None = None
|
||||
|
||||
|
||||
class LearningAssistantService:
    """AI learning assistant: answer from retrieved knowledge when possible, otherwise fall back to a plain LLM answer."""

    def __init__(
        self,
        db: Session,
        *,
        vector_search_service: VectorSearchService | None = None,
        agent: LearningAssistantAgent | None = None,
    ) -> None:
        """Wire the repository and collaborators; collaborators left as None get default instances."""
        self.db = db
        self.repo = KnowledgeBaseRepository(db)
        self.space_service = KnowledgeSpaceService(self.repo)
        self.vector_search = vector_search_service or VectorSearchService(db)
        self.agent = agent or LearningAssistantAgent()

    async def chat(self, ctx: UserContext, payload: LearningAssistantChatRequest) -> LearningAssistantChatResponse:
        """Answer a question in one response, including retrieval details and timings.

        A degraded retrieval does not block the answer; it is reported through
        ``retrieval_error`` on the response.
        """
        start = time.perf_counter()
        retrieval = await self._retrieve_sources(ctx, payload)
        llm_started = time.perf_counter()
        response = await self.agent.answer(payload.question, retrieval.sources)
        total_latency_ms = int((time.perf_counter() - start) * 1000)
        # Prefer the latency reported by the agent; otherwise use the measured one.
        llm_latency_ms = response.latency_ms or int((time.perf_counter() - llm_started) * 1000)
        self._write_query_log(
            ctx=ctx,
            payload=payload,
            retrieval=retrieval,
            answer=response.content,
            model=response.model,
            llm_latency_ms=llm_latency_ms,
            total_latency_ms=total_latency_ms,
        )
        return LearningAssistantChatResponse(
            answer=response.content,
            retrieval_hit=bool(retrieval.sources),
            sources=retrieval.sources,
            retrieval_error=retrieval.retrieval_error,
            model=response.model,
            embedding_latency_ms=retrieval.embedding_latency_ms,
            search_latency_ms=retrieval.search_latency_ms,
            llm_latency_ms=llm_latency_ms,
            total_latency_ms=total_latency_ms,
        )

    async def stream_chat(self, ctx: UserContext, payload: LearningAssistantChatRequest) -> AsyncIterator[str]:
        """Stream the answer as SSE frames.

        Emits ``retrieval_done`` first, then one ``answer_delta`` per non-empty
        delta, and finally ``answer_done`` after the query log is written and
        committed. If answer generation fails, a single ``error`` frame is
        emitted and the stream ends without writing a log.
        """
        start = time.perf_counter()
        retrieval = await self._retrieve_sources(ctx, payload)
        yield self._sse(
            "retrieval_done",
            {
                "retrieval_hit": bool(retrieval.sources),
                "sources": [source.model_dump() for source in retrieval.sources],
                "retrieval_error": retrieval.retrieval_error,
                "embedding_latency_ms": retrieval.embedding_latency_ms,
                "search_latency_ms": retrieval.search_latency_ms,
            },
        )

        answer_parts: list[str] = []
        llm_latency_ms: int | None = None
        model: str | None = None
        llm_started = time.perf_counter()
        try:
            async for chunk in self.agent.stream_answer(payload.question, retrieval.sources):
                if chunk.done:
                    llm_latency_ms = chunk.total_latency_ms
                    model = chunk.model
                    break
                if chunk.delta:
                    answer_parts.append(chunk.delta)
                    yield self._sse("answer_delta", {"delta": chunk.delta})
        except AppError as exc:
            yield self._sse("error", {"code": exc.code, "message": exc.message})
            return
        except Exception:
            yield self._sse("error", {"code": "LEARNING_ASSISTANT_LLM_FAILED", "message": "AI 学习助手回答生成失败,请稍后重试"})
            return

        if llm_latency_ms is None:
            # No latency reported by the stream: fall back to the measured
            # duration, as chat() does.
            llm_latency_ms = int((time.perf_counter() - llm_started) * 1000)

        answer = "".join(answer_parts)
        total_latency_ms = int((time.perf_counter() - start) * 1000)
        self._write_query_log(
            ctx=ctx,
            payload=payload,
            retrieval=retrieval,
            answer=answer,
            model=model,
            llm_latency_ms=llm_latency_ms,
            total_latency_ms=total_latency_ms,
            commit=True,
        )
        yield self._sse("answer_done", {"model": model, "total_latency_ms": total_latency_ms})

    async def _retrieve_sources(self, ctx: UserContext, payload: LearningAssistantChatRequest) -> LearningAssistantRetrieval:
        """Search the institution's knowledge space, degrading to no sources on known failures.

        The threshold is the request's value, else the space's, else the
        settings default; the same value is passed to the search and recorded
        on the result. A missing institution, a missing space/collection, an
        embedding failure or any non-AppError exception yields an empty result
        with ``retrieval_error`` set.

        Raises:
            AppError: any AppError whose code is not one of the degradable ones.
        """
        default_threshold = payload.score_threshold if payload.score_threshold is not None else settings.rag_score_threshold
        try:
            institution_id = self.space_service.require_institution_id(ctx)
        except AppError:
            return self._degraded_retrieval(None, default_threshold, "当前用户缺少机构信息,已转为大模型通用学习回答。")

        try:
            space = self.space_service.get_active_space(ctx)
            requested_threshold = payload.score_threshold if payload.score_threshold is not None else space.score_threshold
            retrieval = await self.vector_search.search(
                institution_id=space.institution_id,
                collection_name=space.collection_name,
                question=payload.question,
                top_k=payload.top_k,
                score_threshold=requested_threshold,
            )
            return LearningAssistantRetrieval(
                institution_id=space.institution_id,
                # With neither a request nor a space threshold, the search
                # falls back to the settings default, so record that value.
                score_threshold=requested_threshold if requested_threshold is not None else default_threshold,
                sources=self._build_sources(retrieval.chunks),
                embedding_latency_ms=retrieval.embedding_latency_ms,
                search_latency_ms=retrieval.search_latency_ms,
            )
        except AppError as exc:
            if exc.code in {"KNOWLEDGE_SPACE_NOT_FOUND", "MILVUS_COLLECTION_NOT_FOUND", "EMBEDDING_CALL_FAILED"}:
                return self._degraded_retrieval(
                    institution_id,
                    default_threshold,
                    "当前机构知识库暂未初始化或检索不可用,已转为大模型通用学习回答。",
                )
            raise
        except Exception:
            return self._degraded_retrieval(
                institution_id,
                default_threshold,
                "当前机构知识库检索暂不可用,已转为大模型通用学习回答。",
            )

    def _degraded_retrieval(
        self,
        institution_id: int | None,
        score_threshold: float,
        message: str,
    ) -> LearningAssistantRetrieval:
        """Build an empty retrieval result carrying the user-facing fallback message."""
        return LearningAssistantRetrieval(
            institution_id=institution_id,
            score_threshold=score_threshold,
            sources=[],
            retrieval_error=message,
        )

    def _write_query_log(
        self,
        *,
        ctx: UserContext,
        payload: LearningAssistantChatRequest,
        retrieval: LearningAssistantRetrieval,
        answer: str,
        model: str | None,
        llm_latency_ms: int | None,
        total_latency_ms: int | None,
        commit: bool = False,
    ) -> None:
        """Record the query, its retrieval outcome and timings.

        Nothing is written when the retrieval has no institution id. The
        session is committed only when ``commit`` is true.
        """
        if retrieval.institution_id is None:
            return
        self.repo.create_query_log(
            user_id=ctx.user_id,
            institution_id=retrieval.institution_id,
            question=payload.question,
            retrieval_hit=bool(retrieval.sources),
            retrieved_chunk_ids=[source.chunk_uid for source in retrieval.sources],
            answer_summary=answer,
            llm_model=model,
            # Log the limit requested for the search, not the number of
            # sources that happened to come back.
            top_k=payload.top_k or settings.rag_top_k,
            score_threshold=retrieval.score_threshold,
            embedding_latency_ms=retrieval.embedding_latency_ms,
            search_latency_ms=retrieval.search_latency_ms,
            llm_latency_ms=llm_latency_ms,
            total_latency_ms=total_latency_ms,
        )
        if commit:
            self.db.commit()

    def _build_sources(self, chunks: list[RetrievedChunk]) -> list[LearningAssistantSource]:
        """Convert retrieved chunks into source entries for the response.

        Each entry carries the document title and file name (``None`` / ``""``
        when the document cannot be loaded), the page range, the chunk uid,
        the score rounded to 4 decimals and the first 500 characters of the
        chunk text as the quote.
        """
        # Several chunks often come from the same document; look each one up once.
        documents: dict = {}
        sources: list[LearningAssistantSource] = []
        for item in chunks:
            chunk = item.chunk
            key = (chunk.document_id, chunk.institution_id)
            if key not in documents:
                documents[key] = self.repo.get_document(chunk.document_id, chunk.institution_id)
            document = documents[key]
            sources.append(
                LearningAssistantSource(
                    document_id=chunk.document_id,
                    document_title=document.document_title if document else None,
                    file_name=document.file_name if document else "",
                    page_start=chunk.page_start,
                    page_end=chunk.page_end,
                    chunk_uid=chunk.chunk_uid,
                    score=round(item.score, 4),
                    quote=chunk.chunk_text[:500],
                )
            )
        return sources

    def _sse(self, event: str, data: dict) -> str:
        """Format one Server-Sent Events frame: an ``event:`` line, a JSON ``data:`` line and a blank line."""
        return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
|
||||
@@ -0,0 +1,69 @@
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.config import settings
|
||||
from app.integrations.milvus_adapter import MilvusVectorStore
|
||||
from app.models.knowledge_base import KbKnowledgeChunk
|
||||
from app.repositories.knowledge_base_repository import KnowledgeBaseRepository
|
||||
from app.services.embedding_service import EmbeddingService
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class RetrievedChunk:
    """One retrieval hit: the stored chunk row together with its similarity score."""

    # Chunk record loaded from the repository.
    chunk: KbKnowledgeChunk
    # Similarity score reported by the vector store for this chunk.
    score: float
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class RetrievalResult:
    """Result bundle of a search: the matched chunks plus per-stage timings."""

    # Chunks that passed the score threshold and result limit.
    chunks: list[RetrievedChunk]
    # Time spent embedding the question, in milliseconds.
    embedding_latency_ms: int
    # Time spent in the vector store search, in milliseconds.
    search_latency_ms: int
|
||||
|
||||
|
||||
class VectorSearchService:
    """Embed a question and search the institution's vector collection for matching chunks."""

    def __init__(
        self,
        db: Session,
        *,
        embedding_service: EmbeddingService | None = None,
        vector_store: MilvusVectorStore | None = None,
    ) -> None:
        """Create the repository and collaborators; collaborators left as None get default instances."""
        self.repo = KnowledgeBaseRepository(db)
        self.embedding_service = embedding_service or EmbeddingService()
        self.vector_store = vector_store or MilvusVectorStore()

    async def search(
        self,
        *,
        institution_id: int,
        collection_name: str,
        question: str,
        top_n: int | None = None,
        top_k: int | None = None,
        score_threshold: float | None = None,
    ) -> RetrievalResult:
        """Recall ``top_n`` candidates, then keep the best ``top_k`` above the threshold.

        Args:
            institution_id: institution whose chunk rows may be returned.
            collection_name: vector collection to search.
            question: text to embed and search for.
            top_n: recall size; defaults to ``settings.rag_top_n``.
            top_k: final result limit; defaults to ``settings.rag_top_k``.
            score_threshold: minimum score; defaults to
                ``settings.rag_score_threshold`` when None.

        Returns:
            Chunks in the order the vector store ranked them, each paired with
            its score, plus the embedding and search latencies in milliseconds.
            Hits without a matching chunk row are omitted.
        """
        vectors, embedding_latency_ms = await self.embedding_service.embed_texts([question])
        started = time.perf_counter()
        hits = self.vector_store.search(collection_name, vectors[0], top_n or settings.rag_top_n)
        search_latency_ms = int((time.perf_counter() - started) * 1000)

        threshold = settings.rag_score_threshold if score_threshold is None else score_threshold
        final_limit = top_k or settings.rag_top_k
        ranked_hits = [hit for hit in hits if hit.score >= threshold][:final_limit]

        retrieved: list[RetrievedChunk] = []
        if ranked_hits:  # nothing to load when no hit passed the filter
            rows = self.repo.get_chunks_by_uids(institution_id, [hit.chunk_uid for hit in ranked_hits])
            row_by_uid = {row.chunk_uid: row for row in rows}
            # The repository does not promise any row order, so rebuild the
            # list in hit order to keep the similarity ranking.
            seen_uids: set = set()
            for hit in ranked_hits:
                row = row_by_uid.get(hit.chunk_uid)
                if row is None or hit.chunk_uid in seen_uids:
                    continue
                seen_uids.add(hit.chunk_uid)
                retrieved.append(RetrievedChunk(chunk=row, score=hit.score))

        return RetrievalResult(
            chunks=retrieved,
            embedding_latency_ms=embedding_latency_ms,
            search_latency_ms=search_latency_ms,
        )
|
||||
Reference in New Issue
Block a user