import json
import logging
import time
from collections.abc import AsyncIterator
from dataclasses import dataclass

from sqlalchemy.orm import Session

from app.agents.learning_assistant_agent import LearningAssistantAgent
from app.core.config import settings
from app.core.context import UserContext
from app.core.exceptions import AppError
from app.repositories.knowledge_base_repository import KnowledgeBaseRepository
from app.schemas.learning_assistant import (
    LearningAssistantChatRequest,
    LearningAssistantChatResponse,
    LearningAssistantSource,
)
from app.services.knowledge_space_service import KnowledgeSpaceService
from app.services.vector_search_service import RetrievedChunk, VectorSearchService

logger = logging.getLogger(__name__)


@dataclass(frozen=True)
class LearningAssistantRetrieval:
    """Retrieval outcome for the learning assistant.

    Bundles the knowledge-base hits, the retrieval timings and, when retrieval
    was degraded, the reason shown to the user.
    """

    # Institution the retrieval ran for; None when the user has no institution.
    institution_id: int | None
    # Score threshold recorded for this retrieval (request, space or settings value).
    score_threshold: float
    # Sources built from the retrieved chunks; empty when retrieval was degraded.
    sources: list[LearningAssistantSource]
    # Timings copied from the vector search result; None when no search completed.
    embedding_latency_ms: int | None = None
    search_latency_ms: int | None = None
    # User-facing degradation message; None when retrieval succeeded.
    retrieval_error: str | None = None


class LearningAssistantService:
    """AI learning assistant service.

    Prefers RAG retrieval against the institution's knowledge base and degrades
    to a general LLM answer when the knowledge base is unavailable.
    """

    def __init__(
        self,
        db: Session,
        *,
        vector_search_service: VectorSearchService | None = None,
        agent: LearningAssistantAgent | None = None,
    ) -> None:
        """Wire up the repository, knowledge-space service, vector search and agent.

        Args:
            db: Database session shared by the repository and the default vector search.
            vector_search_service: Optional override; defaults to ``VectorSearchService(db)``.
            agent: Optional override; defaults to a new ``LearningAssistantAgent``.
        """
        self.db = db
        self.repo = KnowledgeBaseRepository(db)
        self.space_service = KnowledgeSpaceService(self.repo)
        self.vector_search = vector_search_service or VectorSearchService(db)
        self.agent = agent or LearningAssistantAgent()

    async def chat(self, ctx: UserContext, payload: LearningAssistantChatRequest) -> LearningAssistantChatResponse:
        """Answer a question in one shot (non-streaming).

        A degraded retrieval does not block the answer: the response carries the
        full answer text together with the retrieval degradation information.

        Args:
            ctx: Context of the calling user.
            payload: Chat request (question, optional top_k / score_threshold).

        Returns:
            The answer, its sources, the retrieval error (if any) and latency figures.
        """
        start = time.perf_counter()
        retrieval = await self._retrieve_sources(ctx, payload)

        llm_started = time.perf_counter()
        response = await self.agent.answer(payload.question, retrieval.sources)
        total_latency_ms = int((time.perf_counter() - start) * 1000)
        # Prefer the latency reported by the agent; otherwise use our own measurement.
        llm_latency_ms = response.latency_ms or int((time.perf_counter() - llm_started) * 1000)

        # No commit here (commit defaults to False); this method never calls db.commit itself.
        self._write_query_log(
            ctx=ctx,
            payload=payload,
            retrieval=retrieval,
            answer=response.content,
            model=response.model,
            llm_latency_ms=llm_latency_ms,
            total_latency_ms=total_latency_ms,
        )
        return LearningAssistantChatResponse(
            answer=response.content,
            retrieval_hit=bool(retrieval.sources),
            sources=retrieval.sources,
            retrieval_error=retrieval.retrieval_error,
            model=response.model,
            embedding_latency_ms=retrieval.embedding_latency_ms,
            search_latency_ms=retrieval.search_latency_ms,
            llm_latency_ms=llm_latency_ms,
            total_latency_ms=total_latency_ms,
        )

    async def stream_chat(self, ctx: UserContext, payload: LearningAssistantChatRequest) -> AsyncIterator[str]:
        """Answer a question as a stream of SSE frames.

        Emits ``retrieval_done`` first, then one ``answer_delta`` per non-empty
        chunk, and finally ``answer_done``. If the LLM call fails, a single
        ``error`` frame is emitted and the stream ends without writing a query log.

        Args:
            ctx: Context of the calling user.
            payload: Chat request (question, optional top_k / score_threshold).

        Yields:
            SSE-formatted strings produced by ``_sse``.
        """
        start = time.perf_counter()
        retrieval = await self._retrieve_sources(ctx, payload)
        yield self._sse(
            "retrieval_done",
            {
                "retrieval_hit": bool(retrieval.sources),
                "sources": [source.model_dump() for source in retrieval.sources],
                "retrieval_error": retrieval.retrieval_error,
                "embedding_latency_ms": retrieval.embedding_latency_ms,
                "search_latency_ms": retrieval.search_latency_ms,
            },
        )

        answer_parts: list[str] = []
        llm_latency_ms: int | None = None
        model: str | None = None
        llm_started = time.perf_counter()
        try:
            async for chunk in self.agent.stream_answer(payload.question, retrieval.sources):
                if chunk.done:
                    # The terminal chunk carries the model name and the reported latency.
                    llm_latency_ms = chunk.total_latency_ms
                    model = chunk.model
                    break
                if chunk.delta:
                    answer_parts.append(chunk.delta)
                    yield self._sse("answer_delta", {"delta": chunk.delta})
        except AppError as exc:
            yield self._sse("error", {"code": exc.code, "message": exc.message})
            return
        except Exception:
            # Keep the traceback: the client only receives a generic message.
            logger.exception("Learning assistant streaming answer failed")
            yield self._sse(
                "error",
                {"code": "LEARNING_ASSISTANT_LLM_FAILED", "message": "AI 学习助手回答生成失败,请稍后重试"},
            )
            return

        if not llm_latency_ms:
            # Same fallback as chat(): measure ourselves when the agent reports no
            # latency (or the stream ended without a terminal chunk).
            llm_latency_ms = int((time.perf_counter() - llm_started) * 1000)

        answer = "".join(answer_parts)
        total_latency_ms = int((time.perf_counter() - start) * 1000)
        # commit=True: the log is committed here, inside the generator.
        self._write_query_log(
            ctx=ctx,
            payload=payload,
            retrieval=retrieval,
            answer=answer,
            model=model,
            llm_latency_ms=llm_latency_ms,
            total_latency_ms=total_latency_ms,
            commit=True,
        )
        yield self._sse("answer_done", {"model": model, "total_latency_ms": total_latency_ms})

    async def _retrieve_sources(
        self,
        ctx: UserContext,
        payload: LearningAssistantChatRequest,
    ) -> LearningAssistantRetrieval:
        """Retrieve knowledge-base sources for the question, degrading on failure.

        Reads the caller's active knowledge space and runs a vector search in it.
        A missing institution, a missing knowledge space or Milvus collection, an
        embedding failure, or any unexpected non-``AppError`` exception all degrade
        to an empty source list with a user-facing ``retrieval_error``.

        Args:
            ctx: Context of the calling user.
            payload: Chat request (question, optional top_k / score_threshold).

        Returns:
            The retrieval result, possibly degraded (empty sources plus error text).

        Raises:
            AppError: When the search raises an ``AppError`` whose code is not one
                of the degradable codes.
        """
        # Threshold reported on degraded results: the request value, else the global setting.
        score_threshold = (
            payload.score_threshold if payload.score_threshold is not None else settings.rag_score_threshold
        )

        try:
            institution_id = self.space_service.require_institution_id(ctx)
        except AppError:
            return LearningAssistantRetrieval(
                institution_id=None,
                score_threshold=score_threshold,
                sources=[],
                retrieval_error="当前用户缺少机构信息,已转为大模型通用学习回答。",
            )

        try:
            space = self.space_service.get_active_space(ctx)
            retrieval = await self.vector_search.search(
                institution_id=space.institution_id,
                collection_name=space.collection_name,
                question=payload.question,
                top_k=payload.top_k,
                score_threshold=payload.score_threshold,
            )
            return LearningAssistantRetrieval(
                institution_id=space.institution_id,
                # On success the space's own threshold is the fallback.
                score_threshold=(
                    payload.score_threshold if payload.score_threshold is not None else space.score_threshold
                ),
                sources=self._build_sources(retrieval.chunks),
                embedding_latency_ms=retrieval.embedding_latency_ms,
                search_latency_ms=retrieval.search_latency_ms,
            )
        except AppError as exc:
            if exc.code in {"KNOWLEDGE_SPACE_NOT_FOUND", "MILVUS_COLLECTION_NOT_FOUND", "EMBEDDING_CALL_FAILED"}:
                logger.warning(
                    "Knowledge retrieval degraded for institution %s: %s",
                    institution_id,
                    exc.code,
                )
                return LearningAssistantRetrieval(
                    institution_id=institution_id,
                    score_threshold=score_threshold,
                    sources=[],
                    retrieval_error="当前机构知识库暂未初始化或检索不可用,已转为大模型通用学习回答。",
                )
            raise
        except Exception:
            # Deliberate best-effort: still degrade, but record why retrieval failed.
            logger.exception("Knowledge retrieval failed for institution %s", institution_id)
            return LearningAssistantRetrieval(
                institution_id=institution_id,
                score_threshold=score_threshold,
                sources=[],
                retrieval_error="当前机构知识库检索暂不可用,已转为大模型通用学习回答。",
            )

    def _write_query_log(
        self,
        *,
        ctx: UserContext,
        payload: LearningAssistantChatRequest,
        retrieval: LearningAssistantRetrieval,
        answer: str,
        model: str | None,
        llm_latency_ms: int | None,
        total_latency_ms: int | None,
        commit: bool = False,
    ) -> None:
        """Record a query log entry (hit flag, chunk ids, answer and timings).

        Nothing is written when the retrieval has no institution id.

        Args:
            ctx: Context of the calling user; supplies ``user_id``.
            payload: Original chat request.
            retrieval: Retrieval result the answer was based on.
            answer: Full answer text, stored as ``answer_summary``.
            model: Name of the model that produced the answer, if known.
            llm_latency_ms: LLM latency in milliseconds, if known.
            total_latency_ms: End-to-end latency in milliseconds, if known.
            commit: When True, commit the session after creating the entry.
        """
        if retrieval.institution_id is None:
            return

        self.repo.create_query_log(
            user_id=ctx.user_id,
            institution_id=retrieval.institution_id,
            question=payload.question,
            retrieval_hit=bool(retrieval.sources),
            retrieved_chunk_ids=[source.chunk_uid for source in retrieval.sources],
            answer_summary=answer,
            llm_model=model,
            # Requested top_k, else the number of sources returned, else the configured default.
            top_k=payload.top_k or len(retrieval.sources) or settings.rag_top_k,
            score_threshold=retrieval.score_threshold,
            embedding_latency_ms=retrieval.embedding_latency_ms,
            search_latency_ms=retrieval.search_latency_ms,
            llm_latency_ms=llm_latency_ms,
            total_latency_ms=total_latency_ms,
        )
        if commit:
            self.db.commit()

    def _build_sources(self, chunks: list[RetrievedChunk]) -> list[LearningAssistantSource]:
        """Convert retrieved chunks into source entries for the response.

        Each document is looked up at most once per call, even when several
        chunks come from the same document.

        Args:
            chunks: Retrieved chunks with their scores.

        Returns:
            One source per chunk, in the same order. The quote is the first 500
            characters of the chunk text and the score is rounded to 4 decimals.
            When the document cannot be found, the title is None and the file
            name is an empty string.
        """
        # Memoise lookups (including misses, stored as None) per (document_id, institution_id).
        documents: dict = {}
        sources: list[LearningAssistantSource] = []
        for item in chunks:
            chunk = item.chunk
            key = (chunk.document_id, chunk.institution_id)
            if key not in documents:
                documents[key] = self.repo.get_document(chunk.document_id, chunk.institution_id)
            document = documents[key]
            sources.append(
                LearningAssistantSource(
                    document_id=chunk.document_id,
                    document_title=document.document_title if document else None,
                    file_name=document.file_name if document else "",
                    page_start=chunk.page_start,
                    page_end=chunk.page_end,
                    chunk_uid=chunk.chunk_uid,
                    score=round(item.score, 4),
                    quote=chunk.chunk_text[:500],
                )
            )
        return sources

    def _sse(self, event: str, data: dict) -> str:
        """Format one Server-Sent Events frame: an ``event:`` line plus a JSON ``data:`` line.

        Args:
            event: Event name.
            data: JSON-serialisable payload; non-ASCII characters are emitted unescaped.

        Returns:
            The frame text, terminated by a blank line.
        """
        return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"