Files
fastapi/app/services/learning_assistant_service.py
T

225 lines
9.8 KiB
Python
Raw Normal View History

import json
import logging
import time
from collections.abc import AsyncIterator
from dataclasses import dataclass

from sqlalchemy.orm import Session

from app.agents.learning_assistant_agent import LearningAssistantAgent
from app.core.config import settings
from app.core.context import UserContext
from app.core.exceptions import AppError
from app.repositories.knowledge_base_repository import KnowledgeBaseRepository
from app.schemas.learning_assistant import LearningAssistantChatRequest, LearningAssistantChatResponse, LearningAssistantSource
from app.services.knowledge_space_service import KnowledgeSpaceService
from app.services.vector_search_service import RetrievedChunk, VectorSearchService
@dataclass(frozen=True)
class LearningAssistantRetrieval:
    """Learning-assistant retrieval result: knowledge-base hits, latencies, and the degradation reason."""

    # Institution the retrieval ran for; None when the user has no institution info.
    institution_id: int | None
    # Score threshold recorded for this retrieval (written to the query log).
    score_threshold: float
    # Display-ready sources built from retrieved chunks; empty on a miss or a degraded retrieval.
    sources: list[LearningAssistantSource]
    # Embedding latency reported by the vector search; None when the search did not complete.
    embedding_latency_ms: int | None = None
    # Vector-search latency reported by the vector search; None when the search did not complete.
    search_latency_ms: int | None = None
    # User-facing explanation of why retrieval was degraded; None when retrieval succeeded.
    retrieval_error: str | None = None
class LearningAssistantService:
    """AI learning-assistant service.

    Answers questions with RAG over the caller's institution knowledge space
    when possible. When the knowledge base is unavailable it degrades to a
    plain (non-RAG) LLM answer instead of failing the request.
    """

    # Logger for degraded retrievals and for failures that are handled rather than raised.
    _logger = logging.getLogger(__name__)

    # AppError codes from the space lookup / vector search that mean "knowledge base
    # not ready" and therefore degrade to a non-RAG answer instead of failing.
    _DEGRADABLE_RETRIEVAL_CODES = frozenset(
        {"KNOWLEDGE_SPACE_NOT_FOUND", "MILVUS_COLLECTION_NOT_FOUND", "EMBEDDING_CALL_FAILED"}
    )

    def __init__(
        self,
        db: Session,
        *,
        vector_search_service: VectorSearchService | None = None,
        agent: LearningAssistantAgent | None = None,
    ) -> None:
        """Wire the repository and collaborators onto one DB session.

        Args:
            db: SQLAlchemy session used by the repository and for committing query logs.
            vector_search_service: Optional injected vector search; defaults to one built on ``db``.
            agent: Optional injected LLM agent; defaults to a new ``LearningAssistantAgent``.
        """
        self.db = db
        self.repo = KnowledgeBaseRepository(db)
        self.space_service = KnowledgeSpaceService(self.repo)
        self.vector_search = vector_search_service or VectorSearchService(db)
        self.agent = agent or LearningAssistantAgent()

    async def chat(self, ctx: UserContext, payload: LearningAssistantChatRequest) -> LearningAssistantChatResponse:
        """Non-streaming Q&A (debug endpoint).

        Retrieval failures never block the answer. The full answer text is
        returned together with retrieval-degradation details and timings.
        The query log (when written) is not committed here.

        Raises:
            AppError: non-degradable retrieval errors, or errors raised by the agent.
        """
        start = time.perf_counter()
        retrieval = await self._retrieve_sources(ctx, payload)
        llm_started = time.perf_counter()
        response = await self.agent.answer(payload.question, retrieval.sources)
        total_latency_ms = self._elapsed_ms(start)
        # Prefer the agent-reported latency; fall back to wall-clock time around the call.
        llm_latency_ms = response.latency_ms or self._elapsed_ms(llm_started)
        self._write_query_log(
            ctx=ctx,
            payload=payload,
            retrieval=retrieval,
            answer=response.content,
            model=response.model,
            llm_latency_ms=llm_latency_ms,
            total_latency_ms=total_latency_ms,
        )
        return LearningAssistantChatResponse(
            answer=response.content,
            retrieval_hit=bool(retrieval.sources),
            sources=retrieval.sources,
            retrieval_error=retrieval.retrieval_error,
            model=response.model,
            embedding_latency_ms=retrieval.embedding_latency_ms,
            search_latency_ms=retrieval.search_latency_ms,
            llm_latency_ms=llm_latency_ms,
            total_latency_ms=total_latency_ms,
        )

    async def stream_chat(self, ctx: UserContext, payload: LearningAssistantChatRequest) -> AsyncIterator[str]:
        """Streaming Q&A as Server-Sent Events.

        Event order: ``retrieval_done`` (hit flag, sources, timings), then zero
        or more ``answer_delta`` events, then ``answer_done``. A failure is
        reported as one ``error`` event and ends the stream. This is necessary
        because once the response has started streaming, an exception can no
        longer become an HTTP error response.
        """
        start = time.perf_counter()
        try:
            retrieval = await self._retrieve_sources(ctx, payload)
        except AppError as exc:
            # _retrieve_sources re-raises non-degradable AppErrors; report them in-band
            # exactly like the LLM AppErrors below instead of aborting the stream.
            yield self._sse("error", {"code": exc.code, "message": exc.message})
            return
        yield self._sse(
            "retrieval_done",
            {
                "retrieval_hit": bool(retrieval.sources),
                "sources": [source.model_dump() for source in retrieval.sources],
                "retrieval_error": retrieval.retrieval_error,
                "embedding_latency_ms": retrieval.embedding_latency_ms,
                "search_latency_ms": retrieval.search_latency_ms,
            },
        )

        answer_parts: list[str] = []
        llm_latency_ms: int | None = None
        model: str | None = None
        llm_started = time.perf_counter()
        try:
            async for chunk in self.agent.stream_answer(payload.question, retrieval.sources):
                if chunk.done:
                    # The terminal chunk carries metadata, not answer text.
                    llm_latency_ms = chunk.total_latency_ms
                    model = chunk.model
                    break
                if chunk.delta:
                    answer_parts.append(chunk.delta)
                    yield self._sse("answer_delta", {"delta": chunk.delta})
        except AppError as exc:
            yield self._sse("error", {"code": exc.code, "message": exc.message})
            return
        except Exception:
            self._logger.exception("Learning assistant LLM streaming failed")
            yield self._sse("error", {"code": "LEARNING_ASSISTANT_LLM_FAILED", "message": "AI 学习助手回答生成失败,请稍后重试"})
            return

        # Same fallback as chat(): measure wall-clock time when the agent did not
        # report a latency (e.g. the stream ended without a done chunk).
        llm_latency_ms = llm_latency_ms or self._elapsed_ms(llm_started)
        answer = "".join(answer_parts)
        total_latency_ms = self._elapsed_ms(start)
        try:
            self._write_query_log(
                ctx=ctx,
                payload=payload,
                retrieval=retrieval,
                answer=answer,
                model=model,
                llm_latency_ms=llm_latency_ms,
                total_latency_ms=total_latency_ms,
                commit=True,
            )
        except Exception:
            # The client already has the full answer; a failed query-log write must not
            # break the stream. Roll back so the session is usable again.
            self.db.rollback()
            self._logger.exception("Failed to persist learning assistant query log")
        yield self._sse("answer_done", {"model": model, "total_latency_ms": total_latency_ms})

    async def _retrieve_sources(self, ctx: UserContext, payload: LearningAssistantChatRequest) -> LearningAssistantRetrieval:
        """Retrieve knowledge-base sources for the question, degrading instead of failing.

        The method reads the caller's institution and active knowledge space,
        then runs a vector search. The following cases all return an
        empty-source result with a user-facing ``retrieval_error``:

        - missing institution info;
        - an AppError whose code is in ``_DEGRADABLE_RETRIEVAL_CODES``;
        - any non-AppError exception.

        Raises:
            AppError: from the space lookup / search when its code is not degradable.
        """
        fallback_threshold = (
            payload.score_threshold if payload.score_threshold is not None else settings.rag_score_threshold
        )
        try:
            institution_id = self.space_service.require_institution_id(ctx)
        except AppError:
            return LearningAssistantRetrieval(
                institution_id=None,
                score_threshold=fallback_threshold,
                sources=[],
                retrieval_error="当前用户缺少机构信息,已转为大模型通用学习回答。",
            )
        try:
            space = self.space_service.get_active_space(ctx)
            # NOTE(review): the search receives the raw request threshold (possibly None),
            # while the recorded threshold falls back to space.score_threshold. Confirm the
            # search applies the same fallback; otherwise the query log misreports it.
            retrieval = await self.vector_search.search(
                institution_id=space.institution_id,
                collection_name=space.collection_name,
                question=payload.question,
                top_k=payload.top_k,
                score_threshold=payload.score_threshold,
            )
            return LearningAssistantRetrieval(
                institution_id=space.institution_id,
                score_threshold=payload.score_threshold if payload.score_threshold is not None else space.score_threshold,
                sources=self._build_sources(retrieval.chunks),
                embedding_latency_ms=retrieval.embedding_latency_ms,
                search_latency_ms=retrieval.search_latency_ms,
            )
        except AppError as exc:
            if exc.code not in self._DEGRADABLE_RETRIEVAL_CODES:
                raise
            self._logger.warning(
                "Knowledge retrieval degraded (institution_id=%s, code=%s)", institution_id, exc.code
            )
            return LearningAssistantRetrieval(
                institution_id=institution_id,
                score_threshold=fallback_threshold,
                sources=[],
                retrieval_error="当前机构知识库暂未初始化或检索不可用,已转为大模型通用学习回答。",
            )
        except Exception:
            # Deliberate best-effort: any unexpected retrieval failure degrades to a
            # non-RAG answer, but is logged so it is not silently lost.
            self._logger.exception("Knowledge retrieval failed unexpectedly (institution_id=%s)", institution_id)
            return LearningAssistantRetrieval(
                institution_id=institution_id,
                score_threshold=fallback_threshold,
                sources=[],
                retrieval_error="当前机构知识库检索暂不可用,已转为大模型通用学习回答。",
            )

    def _write_query_log(
        self,
        *,
        ctx: UserContext,
        payload: LearningAssistantChatRequest,
        retrieval: LearningAssistantRetrieval,
        answer: str,
        model: str | None,
        llm_latency_ms: int | None,
        total_latency_ms: int | None,
        commit: bool = False,
    ) -> None:
        """Record a query-log row (RAG hit, source chunk ids, timings) for this question.

        Skipped entirely when the retrieval has no institution id. The session
        is committed only when ``commit`` is true.
        """
        if retrieval.institution_id is None:
            return
        self.repo.create_query_log(
            user_id=ctx.user_id,
            institution_id=retrieval.institution_id,
            question=payload.question,
            retrieval_hit=bool(retrieval.sources),
            retrieved_chunk_ids=[source.chunk_uid for source in retrieval.sources],
            answer_summary=answer,
            llm_model=model,
            # NOTE(review): without a request top_k this logs the number of sources
            # returned, not the top_k the search used — confirm this is intended.
            top_k=payload.top_k or len(retrieval.sources) or settings.rag_top_k,
            score_threshold=retrieval.score_threshold,
            embedding_latency_ms=retrieval.embedding_latency_ms,
            search_latency_ms=retrieval.search_latency_ms,
            llm_latency_ms=llm_latency_ms,
            total_latency_ms=total_latency_ms,
        )
        if commit:
            self.db.commit()

    def _build_sources(self, chunks: list[RetrievedChunk]) -> list[LearningAssistantSource]:
        """Convert retrieved chunks into PDF source entries the frontend can display.

        The method looks up each ``(document_id, institution_id)`` pair once and
        reuses the result, so several chunks from the same PDF cost one repository
        call. A missing document yields ``document_title=None`` and
        ``file_name=""``. Quotes are the first 500 characters of the chunk text,
        and scores are rounded to 4 decimal places.
        """
        documents: dict = {}
        sources: list[LearningAssistantSource] = []
        for item in chunks:
            chunk = item.chunk
            key = (chunk.document_id, chunk.institution_id)
            if key not in documents:
                documents[key] = self.repo.get_document(chunk.document_id, chunk.institution_id)
            document = documents[key]
            sources.append(
                LearningAssistantSource(
                    document_id=chunk.document_id,
                    document_title=document.document_title if document else None,
                    file_name=document.file_name if document else "",
                    page_start=chunk.page_start,
                    page_end=chunk.page_end,
                    chunk_uid=chunk.chunk_uid,
                    score=round(item.score, 4),
                    quote=chunk.chunk_text[:500],
                )
            )
        return sources

    @staticmethod
    def _elapsed_ms(started: float) -> int:
        """Return whole milliseconds elapsed since a ``time.perf_counter()`` reading."""
        return int((time.perf_counter() - started) * 1000)

    def _sse(self, event: str, data: dict) -> str:
        """Format one Server-Sent Event: an ``event:`` line plus a JSON ``data:`` line."""
        return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"