import json
import logging
import time
from collections.abc import AsyncIterator
from dataclasses import dataclass
from typing import Any

from sqlalchemy.orm import Session

from app.agents.learning_assistant_agent import LearningAssistantAgent
from app.core.config import settings
from app.core.context import UserContext
from app.core.exceptions import AppError
from app.repositories.knowledge_base_repository import KnowledgeBaseRepository
from app.schemas.learning_assistant import (
    LearningAssistantChatRequest,
    LearningAssistantSessionCreateRequest,
    LearningAssistantSessionResponse,
    LearningAssistantSource,
)
from app.services.knowledge_space_service import KnowledgeSpaceService
from app.services.learning_assistant_session_store import LearningAssistantSessionStore, learning_assistant_session_store
from app.services.vector_search_service import RetrievedChunk, VectorSearchService

logger = logging.getLogger(__name__)


@dataclass(frozen=True)
class LearningAssistantRetrieval:
    """Result of one knowledge-base retrieval for the learning assistant.

    Bundles the matched sources with the latencies reported by the vector
    search and, when retrieval was skipped or failed, a user-facing reason
    in ``retrieval_error``.
    """

    # None when the institution could not be resolved for the user.
    institution_id: int | None
    # Threshold that is recorded in the query log for this retrieval.
    score_threshold: float
    sources: list[LearningAssistantSource]
    embedding_latency_ms: int | None = None
    search_latency_ms: int | None = None
    # Set only on the degraded paths (empty ``sources``).
    retrieval_error: str | None = None


class LearningAssistantService:
    """AI learning assistant service.

    Manages short-lived assistant sessions and streams answers as SSE
    events, retrieving knowledge-base sources first and passing them to the
    agent together with the recent session history.
    """

    def __init__(
        self,
        db: Session,
        *,
        vector_search_service: VectorSearchService | None = None,
        agent: LearningAssistantAgent | None = None,
        session_store: LearningAssistantSessionStore | None = None,
    ) -> None:
        """Wire up collaborators; each optional argument falls back to a default instance."""
        self.db = db
        self.repo = KnowledgeBaseRepository(db)
        self.space_service = KnowledgeSpaceService(self.repo)
        self.vector_search = vector_search_service or VectorSearchService(db)
        self.agent = agent or LearningAssistantAgent()
        self.session_store = session_store or learning_assistant_session_store

    def create_session(
        self,
        ctx: UserContext,
        payload: LearningAssistantSessionCreateRequest,
    ) -> LearningAssistantSessionResponse:
        """Create a new assistant session in the session store and return its response model."""
        state = self.session_store.create(ctx, title=payload.title)
        return self._session_response(state)

    def validate_session(self, ctx: UserContext, assistant_session_id: str) -> dict[str, Any]:
        """Return the session state for ``assistant_session_id`` looked up with the current user id.

        Raises:
            AppError: ``LEARNING_ASSISTANT_SESSION_NOT_FOUND`` (404) when the
                store returns nothing for this id/user pair, or
                ``LEARNING_ASSISTANT_SESSION_INVALID`` (400) when the stored
                status is not ``"active"``.
        """
        state = self.session_store.get(assistant_session_id, ctx.user_id)
        if not state:
            raise AppError("LEARNING_ASSISTANT_SESSION_NOT_FOUND", "learning assistant session not found", 404)
        if state.get("status") != "active":
            raise AppError("LEARNING_ASSISTANT_SESSION_INVALID", "learning assistant session is not active", 400)
        return state

    async def stream_session_chat(
        self,
        ctx: UserContext,
        payload: LearningAssistantChatRequest,
        assistant_session: dict[str, Any],
    ) -> AsyncIterator[str]:
        """Stream a session-bound chat turn as SSE strings.

        Emits a ``session_ready`` event describing the session, then relays
        every event produced by ``_stream_answer`` for the same session.
        """
        yield self._sse(
            "session_ready",
            {
                "assistant_session_id": assistant_session["assistant_session_id"],
                "status": assistant_session["status"],
                "history_count": len(assistant_session.get("messages") or []),
            },
        )
        async for event in self._stream_answer(ctx, payload, assistant_session=assistant_session):
            yield event

    async def _stream_answer(
        self,
        ctx: UserContext,
        payload: LearningAssistantChatRequest,
        *,
        assistant_session: dict[str, Any] | None,
    ) -> AsyncIterator[str]:
        """Core streaming flow: retrieve sources, stream the LLM answer, persist log and history.

        Yields SSE strings in this order: ``retrieval_done``, zero or more
        ``answer_delta``, then ``answer_done``. On failure an ``error`` event
        is emitted instead and the stream ends.

        Args:
            ctx: Current user context.
            payload: Chat request carrying the question and retrieval options.
            assistant_session: Session state, or ``None`` for a session-less
                call (no history is read or written in that case).
        """
        start = time.perf_counter()
        assistant_session_id = assistant_session.get("assistant_session_id") if assistant_session else None

        # History is read before the new question is appended to the store.
        history = (
            self.session_store.get_messages(assistant_session_id, ctx.user_id, settings.learning_assistant_history_limit)
            if assistant_session_id
            else []
        )
        if assistant_session_id:
            self.session_store.append_message(assistant_session_id, ctx.user_id, "user", payload.question)

        # _retrieve_sources re-raises AppErrors it does not degrade on. Report
        # them as an SSE error event, like LLM errors below, instead of letting
        # the generator abort mid-stream.
        try:
            retrieval = await self._retrieve_sources(ctx, payload)
        except AppError as exc:
            yield self._sse("error", self._with_session(assistant_session_id, {"code": exc.code, "message": exc.message}))
            return

        yield self._sse(
            "retrieval_done",
            self._with_session(
                assistant_session_id,
                {
                    "retrieval_hit": bool(retrieval.sources),
                    "sources": [source.model_dump() for source in retrieval.sources],
                    "retrieval_error": retrieval.retrieval_error,
                    "embedding_latency_ms": retrieval.embedding_latency_ms,
                    "search_latency_ms": retrieval.search_latency_ms,
                },
            ),
        )

        answer_parts: list[str] = []
        llm_latency_ms: int | None = None
        model: str | None = None
        try:
            async for chunk in self.agent.stream_answer(payload.question, retrieval.sources, history=history):
                if chunk.done:
                    # The terminal chunk carries the model name and total latency.
                    llm_latency_ms = chunk.total_latency_ms
                    model = chunk.model
                    break
                if chunk.delta:
                    answer_parts.append(chunk.delta)
                    yield self._sse("answer_delta", self._with_session(assistant_session_id, {"delta": chunk.delta}))
        except AppError as exc:
            yield self._sse("error", self._with_session(assistant_session_id, {"code": exc.code, "message": exc.message}))
            return
        except Exception:
            # Keep the traceback; the client only receives a generic message.
            logger.exception("learning assistant answer generation failed")
            yield self._sse(
                "error",
                self._with_session(
                    assistant_session_id,
                    {"code": "LEARNING_ASSISTANT_LLM_FAILED", "message": "AI 学习助手回答生成失败,请稍后重试"},
                ),
            )
            return

        answer = "".join(answer_parts)
        total_latency_ms = int((time.perf_counter() - start) * 1000)

        if assistant_session_id:
            self.session_store.append_message(
                assistant_session_id,
                ctx.user_id,
                "assistant",
                answer,
                metadata={"retrieval_hit": bool(retrieval.sources), "source_count": len(retrieval.sources), "model": model},
            )

        # The answer has already been streamed to the client; a failure to
        # persist the query log must not prevent the closing answer_done event.
        try:
            self._write_query_log(
                ctx=ctx,
                payload=payload,
                retrieval=retrieval,
                answer=answer,
                model=model,
                llm_latency_ms=llm_latency_ms,
                total_latency_ms=total_latency_ms,
                commit=True,
            )
        except Exception:
            logger.exception("failed to persist learning assistant query log")
            self.db.rollback()

        yield self._sse(
            "answer_done",
            self._with_session(
                assistant_session_id,
                {
                    "model": model,
                    "total_latency_ms": total_latency_ms,
                    "llm_latency_ms": llm_latency_ms,
                },
            ),
        )

    async def _retrieve_sources(self, ctx: UserContext, payload: LearningAssistantChatRequest) -> LearningAssistantRetrieval:
        """Retrieve knowledge-base sources for the question, degrading to no sources on known failures.

        Returns an empty-source result with ``retrieval_error`` set when the
        institution id cannot be resolved, when an ``AppError`` with code
        ``KNOWLEDGE_SPACE_NOT_FOUND``, ``MILVUS_COLLECTION_NOT_FOUND`` or
        ``EMBEDDING_CALL_FAILED`` is raised, or when any non-``AppError``
        exception occurs.

        Raises:
            AppError: Any other ``AppError`` from the space lookup or search.
        """
        score_threshold = payload.score_threshold if payload.score_threshold is not None else settings.rag_score_threshold

        try:
            institution_id = self.space_service.require_institution_id(ctx)
        except AppError:
            return LearningAssistantRetrieval(
                institution_id=None,
                score_threshold=score_threshold,
                sources=[],
                retrieval_error="当前用户缺少机构信息,已转为大模型通用学习回答。",
            )

        try:
            space = self.space_service.get_active_space(ctx)
            retrieval = await self.vector_search.search(
                institution_id=space.institution_id,
                collection_name=space.collection_name,
                question=payload.question,
                top_k=payload.top_k,
                score_threshold=payload.score_threshold,
            )
            return LearningAssistantRetrieval(
                institution_id=space.institution_id,
                # An explicit request threshold wins; otherwise record the space's own threshold.
                score_threshold=payload.score_threshold if payload.score_threshold is not None else space.score_threshold,
                sources=self._build_sources(retrieval.chunks),
                embedding_latency_ms=retrieval.embedding_latency_ms,
                search_latency_ms=retrieval.search_latency_ms,
            )
        except AppError as exc:
            if exc.code in {"KNOWLEDGE_SPACE_NOT_FOUND", "MILVUS_COLLECTION_NOT_FOUND", "EMBEDDING_CALL_FAILED"}:
                logger.warning("knowledge retrieval degraded to empty sources: %s", exc.code)
                return LearningAssistantRetrieval(
                    institution_id=institution_id,
                    score_threshold=score_threshold,
                    sources=[],
                    retrieval_error="当前机构知识库暂未初始化或检索不可用,已转为大模型通用学习回答。",
                )
            raise
        except Exception:
            # Deliberate best-effort fallback; log so the failure stays visible.
            logger.exception("knowledge retrieval failed unexpectedly; degrading to empty sources")
            return LearningAssistantRetrieval(
                institution_id=institution_id,
                score_threshold=score_threshold,
                sources=[],
                retrieval_error="当前机构知识库检索暂不可用,已转为大模型通用学习回答。",
            )

    def _write_query_log(
        self,
        *,
        ctx: UserContext,
        payload: LearningAssistantChatRequest,
        retrieval: LearningAssistantRetrieval,
        answer: str,
        model: str | None,
        llm_latency_ms: int | None,
        total_latency_ms: int | None,
        commit: bool = False,
    ) -> None:
        """Record a query log row for this turn; does nothing when the retrieval has no institution id.

        Args:
            commit: When true, commit the database session after creating the row.
        """
        if retrieval.institution_id is None:
            return
        self.repo.create_query_log(
            user_id=ctx.user_id,
            institution_id=retrieval.institution_id,
            question=payload.question,
            retrieval_hit=bool(retrieval.sources),
            retrieved_chunk_ids=[source.chunk_uid for source in retrieval.sources],
            answer_summary=answer,
            llm_model=model,
            # Falls back to the number of hits, then to the configured default.
            top_k=payload.top_k or len(retrieval.sources) or settings.rag_top_k,
            score_threshold=retrieval.score_threshold,
            embedding_latency_ms=retrieval.embedding_latency_ms,
            search_latency_ms=retrieval.search_latency_ms,
            llm_latency_ms=llm_latency_ms,
            total_latency_ms=total_latency_ms,
        )
        if commit:
            self.db.commit()

    def _build_sources(self, chunks: list[RetrievedChunk]) -> list[LearningAssistantSource]:
        """Convert retrieved chunks into source entries, in the same order as ``chunks``.

        Each entry carries the document title/file name (``None`` / ``""``
        when the document lookup returns nothing), the page range, the score
        rounded to 4 decimals and the first 500 characters of the chunk text.
        """
        # Several chunks often belong to one document; look each document up once.
        documents: dict[tuple[Any, Any], Any] = {}
        sources: list[LearningAssistantSource] = []
        for item in chunks:
            chunk = item.chunk
            key = (chunk.document_id, chunk.institution_id)
            if key not in documents:
                documents[key] = self.repo.get_document(chunk.document_id, chunk.institution_id)
            document = documents[key]
            sources.append(
                LearningAssistantSource(
                    document_id=chunk.document_id,
                    document_title=document.document_title if document else None,
                    file_name=document.file_name if document else "",
                    page_start=chunk.page_start,
                    page_end=chunk.page_end,
                    chunk_uid=chunk.chunk_uid,
                    score=round(item.score, 4),
                    quote=chunk.chunk_text[:500],
                )
            )
        return sources

    def _session_response(self, state: dict[str, Any]) -> LearningAssistantSessionResponse:
        """Build the session response model from the stored session state dict.

        ``institution_id`` and ``institution_name`` are optional in ``state``;
        all other fields are required keys.
        """
        return LearningAssistantSessionResponse(
            assistant_session_id=state["assistant_session_id"],
            user_id=state["user_id"],
            institution_id=state.get("institution_id"),
            institution_name=state.get("institution_name"),
            title=state["title"],
            status=state["status"],
            created_at=state["created_at"],
            updated_at=state["updated_at"],
            expires_in_seconds=state["expires_in_seconds"],
        )

    def _sse(self, event: str, data: dict) -> str:
        """Format one SSE message: an ``event:`` line, a JSON ``data:`` line and a blank line."""
        return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"

    def _with_session(self, assistant_session_id: str | None, data: dict) -> dict:
        """Return ``data`` with ``assistant_session_id`` prepended, or ``data`` unchanged when there is no session id."""
        if assistant_session_id:
            return {"assistant_session_id": assistant_session_id, **data}
        return data