298 lines
13 KiB
Python
298 lines
13 KiB
Python
import json
import logging
import time
from collections.abc import AsyncIterator
from dataclasses import dataclass
from typing import Any

from sqlalchemy.orm import Session

from app.agents.learning_assistant_agent import LearningAssistantAgent
from app.core.config import settings
from app.core.context import UserContext
from app.core.exceptions import AppError
from app.repositories.knowledge_base_repository import KnowledgeBaseRepository
from app.schemas.learning_assistant import (
    LearningAssistantChatRequest,
    LearningAssistantSessionCreateRequest,
    LearningAssistantSessionResponse,
    LearningAssistantSource,
)
from app.services.knowledge_space_service import KnowledgeSpaceService
from app.services.learning_assistant_session_store import LearningAssistantSessionStore, learning_assistant_session_store
from app.services.vector_search_service import RetrievedChunk, VectorSearchService
|
|
|
|
|
|
@dataclass(frozen=True)
class LearningAssistantRetrieval:
    """Retrieval outcome for the learning assistant.

    Bundles the knowledge-base hits together with the timing figures and,
    when retrieval was degraded, the reason shown to the user.
    """

    # Institution the retrieval ran for; None when the user has no institution.
    institution_id: int | None
    # Score threshold that applied to this retrieval.
    score_threshold: float
    # Knowledge-base hits converted to the client-facing source structure.
    sources: list[LearningAssistantSource]
    # Embedding latency in milliseconds, when a search actually ran.
    embedding_latency_ms: int | None = None
    # Vector search latency in milliseconds, when a search actually ran.
    search_latency_ms: int | None = None
    # Human-readable degradation reason; None when retrieval succeeded.
    retrieval_error: str | None = None
|
|
|
|
|
|
class LearningAssistantService:
    """AI learning assistant service.

    Manages short-lived assistant sessions and streams answers as
    Server-Sent Events, preferring knowledge-base retrieval (RAG) and
    falling back to a plain LLM answer when retrieval is unavailable.
    """

    # Shared logger so that failures which are deliberately swallowed
    # (degraded retrieval, LLM errors, audit-log errors) still leave a trace.
    _logger = logging.getLogger(__name__)

    def __init__(
        self,
        db: Session,
        *,
        vector_search_service: VectorSearchService | None = None,
        agent: LearningAssistantAgent | None = None,
        session_store: LearningAssistantSessionStore | None = None,
    ) -> None:
        """Wire up collaborators; each optional one falls back to a default.

        Args:
            db: SQLAlchemy session used by the repository and for commits.
            vector_search_service: Optional override for vector search.
            agent: Optional override for the answer-generating agent.
            session_store: Optional override for the short-term session store.
        """
        self.db = db
        self.repo = KnowledgeBaseRepository(db)
        self.space_service = KnowledgeSpaceService(self.repo)
        self.vector_search = vector_search_service or VectorSearchService(db)
        self.agent = agent or LearningAssistantAgent()
        self.session_store = session_store or learning_assistant_session_store

    def create_session(self, ctx: UserContext, payload: LearningAssistantSessionCreateRequest) -> LearningAssistantSessionResponse:
        """Create an assistant session: the short-term context container
        initialised when the user enters the learning assistant page."""
        state = self.session_store.create(ctx, title=payload.title)
        return self._session_response(state)

    def validate_session(self, ctx: UserContext, assistant_session_id: str) -> dict[str, Any]:
        """Ensure the session exists, has not expired and belongs to the user.

        Returns:
            The stored session state.

        Raises:
            AppError: 404 when the store has no such session for this user,
                400 when the session status is not ``"active"``.
        """
        state = self.session_store.get(assistant_session_id, ctx.user_id)
        if not state:
            raise AppError("LEARNING_ASSISTANT_SESSION_NOT_FOUND", "learning assistant session not found", 404)
        if state.get("status") != "active":
            raise AppError("LEARNING_ASSISTANT_SESSION_INVALID", "learning assistant session is not active", 400)
        return state

    async def stream_session_chat(
        self,
        ctx: UserContext,
        payload: LearningAssistantChatRequest,
        assistant_session: dict[str, Any],
    ) -> AsyncIterator[str]:
        """Session-bound streaming chat.

        Emits a ``session_ready`` event first, then delegates to the core
        streaming flow, which records the exchange in the session so that it
        takes part in later prompts.
        """
        yield self._sse(
            "session_ready",
            {
                "assistant_session_id": assistant_session["assistant_session_id"],
                "status": assistant_session["status"],
                "history_count": len(assistant_session.get("messages") or []),
            },
        )
        async for event in self._stream_answer(ctx, payload, assistant_session=assistant_session):
            yield event

    async def _stream_answer(
        self,
        ctx: UserContext,
        payload: LearningAssistantChatRequest,
        *,
        assistant_session: dict[str, Any] | None,
    ) -> AsyncIterator[str]:
        """Core streaming flow: retrieve, call the LLM, log, update the session.

        Yields SSE strings in this order: ``retrieval_done``, zero or more
        ``answer_delta``, then either ``answer_done`` or a single ``error``.
        """
        start = time.perf_counter()
        assistant_session_id = assistant_session.get("assistant_session_id") if assistant_session else None

        # Read the history BEFORE appending the new question so the current
        # question is not passed to the agent twice.
        history = (
            self.session_store.get_messages(assistant_session_id, ctx.user_id, settings.learning_assistant_history_limit)
            if assistant_session_id
            else []
        )
        if assistant_session_id:
            self.session_store.append_message(assistant_session_id, ctx.user_id, "user", payload.question)

        retrieval = await self._retrieve_sources(ctx, payload)
        yield self._sse(
            "retrieval_done",
            self._with_session(
                assistant_session_id,
                {
                    "retrieval_hit": bool(retrieval.sources),
                    "sources": [source.model_dump() for source in retrieval.sources],
                    "retrieval_error": retrieval.retrieval_error,
                    "embedding_latency_ms": retrieval.embedding_latency_ms,
                    "search_latency_ms": retrieval.search_latency_ms,
                },
            ),
        )

        answer_parts: list[str] = []
        llm_latency_ms: int | None = None
        model: str | None = None
        try:
            async for chunk in self.agent.stream_answer(payload.question, retrieval.sources, history=history):
                if chunk.done:
                    # The terminal chunk carries the latency and model name.
                    llm_latency_ms = chunk.total_latency_ms
                    model = chunk.model
                    break
                if chunk.delta:
                    answer_parts.append(chunk.delta)
                    yield self._sse("answer_delta", self._with_session(assistant_session_id, {"delta": chunk.delta}))
        except AppError as exc:
            yield self._sse("error", self._with_session(assistant_session_id, {"code": exc.code, "message": exc.message}))
            return
        except Exception:
            # The client only gets a generic message; keep the real cause in
            # the server log instead of discarding it.
            self._logger.exception("learning assistant answer generation failed")
            yield self._sse(
                "error",
                self._with_session(
                    assistant_session_id,
                    {"code": "LEARNING_ASSISTANT_LLM_FAILED", "message": "AI 学习助手回答生成失败,请稍后重试"},
                ),
            )
            return

        answer = "".join(answer_parts)
        total_latency_ms = int((time.perf_counter() - start) * 1000)
        if assistant_session_id:
            self.session_store.append_message(
                assistant_session_id,
                ctx.user_id,
                "assistant",
                answer,
                metadata={"retrieval_hit": bool(retrieval.sources), "source_count": len(retrieval.sources), "model": model},
            )

        try:
            self._write_query_log(
                ctx=ctx,
                payload=payload,
                retrieval=retrieval,
                answer=answer,
                model=model,
                llm_latency_ms=llm_latency_ms,
                total_latency_ms=total_latency_ms,
                commit=True,
            )
        except Exception:
            # The answer has already been streamed in full; a failed audit-log
            # write must not cut the stream before ``answer_done``. Roll back
            # so the SQLAlchemy session stays usable after the failure.
            self._logger.exception("failed to persist learning assistant query log")
            self.db.rollback()

        yield self._sse(
            "answer_done",
            self._with_session(
                assistant_session_id,
                {
                    "model": model,
                    "total_latency_ms": total_latency_ms,
                    "llm_latency_ms": llm_latency_ms,
                },
            ),
        )

    async def _retrieve_sources(self, ctx: UserContext, payload: LearningAssistantChatRequest) -> LearningAssistantRetrieval:
        """Search the institution's knowledge space for the question.

        Degrades to an empty source list (with ``retrieval_error`` set) when
        the user has no institution, when the knowledge space / Milvus
        collection / embedding call is unavailable, or on any unexpected
        error.

        Raises:
            AppError: any application error whose code is not one of the
                known "retrieval unavailable" codes.
        """
        score_threshold = payload.score_threshold if payload.score_threshold is not None else settings.rag_score_threshold
        try:
            institution_id = self.space_service.require_institution_id(ctx)
        except AppError:
            return LearningAssistantRetrieval(
                institution_id=None,
                score_threshold=score_threshold,
                sources=[],
                retrieval_error="当前用户缺少机构信息,已转为大模型通用学习回答。",
            )

        try:
            space = self.space_service.get_active_space(ctx)
            retrieval = await self.vector_search.search(
                institution_id=space.institution_id,
                collection_name=space.collection_name,
                question=payload.question,
                top_k=payload.top_k,
                score_threshold=payload.score_threshold,
            )
            return LearningAssistantRetrieval(
                institution_id=space.institution_id,
                score_threshold=payload.score_threshold if payload.score_threshold is not None else space.score_threshold,
                sources=self._build_sources(retrieval.chunks),
                embedding_latency_ms=retrieval.embedding_latency_ms,
                search_latency_ms=retrieval.search_latency_ms,
            )
        except AppError as exc:
            if exc.code in {"KNOWLEDGE_SPACE_NOT_FOUND", "MILVUS_COLLECTION_NOT_FOUND", "EMBEDDING_CALL_FAILED"}:
                self._logger.warning("knowledge retrieval degraded for institution %s: %s", institution_id, exc.code)
                return LearningAssistantRetrieval(
                    institution_id=institution_id,
                    score_threshold=score_threshold,
                    sources=[],
                    retrieval_error="当前机构知识库暂未初始化或检索不可用,已转为大模型通用学习回答。",
                )
            raise
        except Exception:
            # Deliberate best-effort fallback, but record why it happened.
            self._logger.exception("knowledge retrieval failed unexpectedly for institution %s", institution_id)
            return LearningAssistantRetrieval(
                institution_id=institution_id,
                score_threshold=score_threshold,
                sources=[],
                retrieval_error="当前机构知识库检索暂不可用,已转为大模型通用学习回答。",
            )

    def _write_query_log(
        self,
        *,
        ctx: UserContext,
        payload: LearningAssistantChatRequest,
        retrieval: LearningAssistantRetrieval,
        answer: str,
        model: str | None,
        llm_latency_ms: int | None,
        total_latency_ms: int | None,
        commit: bool = False,
    ) -> None:
        """Record the RAG hit, sources and latencies for this query.

        Nothing is written when the retrieval has no institution id. The
        database session is committed only when ``commit`` is true.
        """
        if retrieval.institution_id is None:
            return
        self.repo.create_query_log(
            user_id=ctx.user_id,
            institution_id=retrieval.institution_id,
            question=payload.question,
            retrieval_hit=bool(retrieval.sources),
            retrieved_chunk_ids=[source.chunk_uid for source in retrieval.sources],
            answer_summary=answer,
            llm_model=model,
            # Fall back from the requested top_k to the hit count, then to the
            # configured default.
            top_k=payload.top_k or len(retrieval.sources) or settings.rag_top_k,
            score_threshold=retrieval.score_threshold,
            embedding_latency_ms=retrieval.embedding_latency_ms,
            search_latency_ms=retrieval.search_latency_ms,
            llm_latency_ms=llm_latency_ms,
            total_latency_ms=total_latency_ms,
        )
        if commit:
            self.db.commit()

    def _build_sources(self, chunks: list[RetrievedChunk]) -> list[LearningAssistantSource]:
        """Convert retrieved chunks into the client-facing source structure.

        Several chunks commonly come from the same document, so each document
        is looked up at most once per call.
        """
        # Memoise lookups (including misses) per (document_id, institution_id).
        documents: dict[tuple[Any, Any], Any] = {}
        sources: list[LearningAssistantSource] = []
        for item in chunks:
            chunk = item.chunk
            key = (chunk.document_id, chunk.institution_id)
            if key not in documents:
                documents[key] = self.repo.get_document(chunk.document_id, chunk.institution_id)
            document = documents[key]
            sources.append(
                LearningAssistantSource(
                    document_id=chunk.document_id,
                    document_title=document.document_title if document else None,
                    file_name=document.file_name if document else "",
                    page_start=chunk.page_start,
                    page_end=chunk.page_end,
                    chunk_uid=chunk.chunk_uid,
                    score=round(item.score, 4),
                    # Cap the quoted excerpt at 500 characters.
                    quote=chunk.chunk_text[:500],
                )
            )
        return sources

    def _session_response(self, state: dict[str, Any]) -> LearningAssistantSessionResponse:
        """Convert stored session state into the response schema, exposing
        only the fields the client displays or needs for later calls."""
        return LearningAssistantSessionResponse(
            assistant_session_id=state["assistant_session_id"],
            user_id=state["user_id"],
            institution_id=state.get("institution_id"),
            institution_name=state.get("institution_name"),
            title=state["title"],
            status=state["status"],
            created_at=state["created_at"],
            updated_at=state["updated_at"],
            expires_in_seconds=state["expires_in_seconds"],
        )

    def _sse(self, event: str, data: dict) -> str:
        """Format one SSE frame: an ``event:`` line, a JSON ``data:`` line
        and a terminating blank line. Non-ASCII text is emitted unescaped."""
        return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"

    def _with_session(self, assistant_session_id: str | None, data: dict) -> dict:
        """Prefix the payload with ``assistant_session_id`` for session-bound
        calls; return ``data`` unchanged when there is no session id."""
        if assistant_session_id:
            return {"assistant_session_id": assistant_session_id, **data}
        return data
|