feat: add streaming learning assistant and knowledge base scaffolding
This commit is contained in:
@@ -0,0 +1,224 @@
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.config import settings
|
||||
from app.models.knowledge_base import (
|
||||
KbKnowledgeChunk,
|
||||
KbKnowledgeDocument,
|
||||
KbKnowledgeIngestionTask,
|
||||
KbKnowledgeQueryLog,
|
||||
KbKnowledgeSpace,
|
||||
)
|
||||
|
||||
|
||||
class KnowledgeBaseRepository:
    """Knowledge-base repository: one place for all queries and writes on the kb_* tables.

    Every write method stages its changes with ``add``/``flush`` only; nothing in
    this class commits or rolls back, so transaction boundaries stay with the
    caller that owns the session.
    """

    # Embedding version used when looking up or creating a knowledge space.
    _EMBEDDING_VERSION = "v1"
    # Chunking parameters recorded on newly created knowledge spaces.
    _CHUNK_SIZE = 1100
    _CHUNK_OVERLAP = 180
    # Generated Milvus collection names are cut to this many characters.
    _COLLECTION_NAME_MAX_LEN = 120
    # Answer summaries stored in the query log are cut to this many characters.
    _ANSWER_SUMMARY_MAX_LEN = 1000

    def __init__(self, db: Session) -> None:
        """Bind the repository to the SQLAlchemy session used by every operation."""
        self.db = db

    def get_or_create_space(self, institution_id: int, institution_name: str | None) -> KbKnowledgeSpace:
        """Return the institution's knowledge space for the current embedding version, creating it if absent.

        A knowledge space maps an institution plus embedding version to a Milvus
        collection name. A newly created space is added to the session and
        flushed, but not committed.

        Args:
            institution_id: Institution that owns the space.
            institution_name: Display name stored on a newly created space; may be None.

        Returns:
            The existing or freshly created ``KbKnowledgeSpace``.
        """
        version = self._EMBEDDING_VERSION
        existing = self.db.scalar(
            select(KbKnowledgeSpace).where(
                KbKnowledgeSpace.institution_id == institution_id,
                KbKnowledgeSpace.embedding_version == version,
            )
        )
        if existing is not None:
            return existing

        space = KbKnowledgeSpace(
            institution_id=institution_id,
            institution_name=institution_name,
            space_code=f"institution_{institution_id}_{version}",
            collection_name=self._collection_name(institution_id, version),
            embedding_model=settings.embedding_model,
            embedding_dim=settings.embedding_dim,
            embedding_version=version,
            chunk_size=self._CHUNK_SIZE,
            chunk_overlap=self._CHUNK_OVERLAP,
            top_k_default=settings.rag_top_k,
            score_threshold=settings.rag_score_threshold,
            status="active",
        )
        self.db.add(space)
        self.db.flush()
        return space

    def get_space(self, institution_id: int) -> KbKnowledgeSpace | None:
        """Return the institution's most recently created active knowledge space, or None.

        Only spaces whose status is ``"active"`` are considered; the one with the
        highest id wins. Unlike ``get_or_create_space`` this lookup does not
        filter on the embedding version.
        """
        stmt = (
            select(KbKnowledgeSpace)
            .where(
                KbKnowledgeSpace.institution_id == institution_id,
                KbKnowledgeSpace.status == "active",
            )
            .order_by(KbKnowledgeSpace.id.desc())
            # Only the first row is used, so do not make the database return the rest.
            .limit(1)
        )
        return self.db.scalar(stmt)

    def get_document_by_hash(self, institution_id: int, file_sha256: str) -> KbKnowledgeDocument | None:
        """Find a document by institution and file SHA-256, for upload de-duplication.

        Returns:
            The matching ``KbKnowledgeDocument``, or None when that institution
            has no document with this hash.
        """
        stmt = select(KbKnowledgeDocument).where(
            KbKnowledgeDocument.institution_id == institution_id,
            KbKnowledgeDocument.file_sha256 == file_sha256,
        )
        return self.db.scalar(stmt)

    def create_document(
        self,
        *,
        institution_id: int,
        uploaded_by: str,
        file_name: str,
        file_sha256: str,
        file_size: int,
        file_path: str,
        document_title: str | None,
        document_category: str,
        version: str,
    ) -> KbKnowledgeDocument:
        """Create the metadata row for an uploaded PDF, including its storage path.

        The row starts with ``file_type="pdf"``, ``status="uploaded"``, both
        ``parse_status`` and ``embedding_status`` set to ``"pending"`` and a
        ``chunk_count`` of 0. It is added to the session and flushed, but not
        committed.

        Returns:
            The newly created ``KbKnowledgeDocument``.
        """
        document = KbKnowledgeDocument(
            institution_id=institution_id,
            uploaded_by=uploaded_by,
            file_name=file_name,
            file_sha256=file_sha256,
            file_type="pdf",
            file_size=file_size,
            file_path=file_path,
            document_title=document_title,
            document_category=document_category,
            version=version,
            status="uploaded",
            parse_status="pending",
            embedding_status="pending",
            chunk_count=0,
        )
        self.db.add(document)
        self.db.flush()
        return document

    def get_document(self, document_id: int, institution_id: int | None = None) -> KbKnowledgeDocument | None:
        """Fetch a document by id, optionally restricted to one institution.

        Args:
            document_id: Primary key of the document.
            institution_id: When given, the document must also belong to this
                institution (access isolation); when None, no institution
                filter is applied.

        Returns:
            The matching ``KbKnowledgeDocument``, or None.
        """
        stmt = select(KbKnowledgeDocument).where(KbKnowledgeDocument.id == document_id)
        if institution_id is not None:
            stmt = stmt.where(KbKnowledgeDocument.institution_id == institution_id)
        return self.db.scalar(stmt)

    def create_ingestion_task(self, document: KbKnowledgeDocument) -> KbKnowledgeIngestionTask:
        """Create the ingestion task that tracks processing progress for a document.

        The task records progress of PDF parsing, chunking, embedding and the
        write to Milvus. It starts as ``status="queued"`` /
        ``current_step="queued"`` with a progress of 0, and is added to the
        session and flushed, but not committed.

        Returns:
            The newly created ``KbKnowledgeIngestionTask``.
        """
        task = KbKnowledgeIngestionTask(
            document_id=document.id,
            institution_id=document.institution_id,
            task_type="document_ingestion",
            status="queued",
            progress=0,
            current_step="queued",
        )
        self.db.add(task)
        self.db.flush()
        return task

    def get_ingestion_task(self, task_id: int) -> KbKnowledgeIngestionTask | None:
        """Fetch an ingestion task by primary key, or None when it does not exist."""
        return self.db.get(KbKnowledgeIngestionTask, task_id)

    def update_task(
        self,
        task: KbKnowledgeIngestionTask,
        *,
        status: str | None = None,
        progress: int | None = None,
        current_step: str | None = None,
        error_message: str | None = None,
    ) -> None:
        """Update the progress fields of an ingestion task in place.

        Only the arguments that are supplied are applied, so passing None leaves
        a field untouched (and therefore cannot be used to clear it). The task
        object is mutated but the session is not flushed here.

        Args:
            task: Task to update.
            status: New status; ignored when None or empty. ``"running"`` stamps
                ``started_at`` once, ``"success"``/``"failed"`` stamp
                ``finished_at``.
            progress: New progress value.
            current_step: Name of the step currently being processed.
            error_message: Failure description to store on the task.
        """
        if status:
            task.status = status
            # Timestamps are naive datetimes holding the current UTC time.
            if status == "running" and task.started_at is None:
                task.started_at = datetime.utcnow()
            if status in {"success", "failed"}:
                task.finished_at = datetime.utcnow()
        if progress is not None:
            task.progress = progress
        if current_step is not None:
            task.current_step = current_step
        if error_message is not None:
            task.error_message = error_message

    def replace_chunks(self, document: KbKnowledgeDocument, chunks: list[KbKnowledgeChunk]) -> None:
        """Replace all stored chunks of a document with a new set.

        Used when a document is rebuilt: the existing chunk rows for the
        document are deleted first, the new chunks are added,
        ``document.chunk_count`` is set to the new total, and the session is
        flushed (not committed).
        """
        self.db.execute(delete(KbKnowledgeChunk).where(KbKnowledgeChunk.document_id == document.id))
        self.db.add_all(chunks)
        document.chunk_count = len(chunks)
        self.db.flush()

    def get_chunks_by_uids(self, institution_id: int, chunk_uids: list[str]) -> list[KbKnowledgeChunk]:
        """Load chunk rows for the chunk_uids returned by Milvus, in the order the uids were given.

        Args:
            institution_id: Only chunks belonging to this institution are returned.
            chunk_uids: Chunk uids to load; their order defines the result order.

        Returns:
            The matching chunks sorted by the position of their uid in
            ``chunk_uids``. Uids without a matching row are simply absent. An
            empty ``chunk_uids`` returns an empty list without querying.
        """
        if not chunk_uids:
            return []
        rows = self.db.scalars(
            select(KbKnowledgeChunk).where(
                KbKnowledgeChunk.institution_id == institution_id,
                KbKnowledgeChunk.chunk_uid.in_(chunk_uids),
            )
        ).all()
        rank = {chunk_uid: position for position, chunk_uid in enumerate(chunk_uids)}
        # A row whose uid is not an exact key in `rank` sorts after every ranked
        # row, however long the uid list is.
        unranked = len(chunk_uids)
        return sorted(rows, key=lambda row: rank.get(row.chunk_uid, unranked))

    def list_documents(self, institution_id: int, limit: int = 20) -> list[KbKnowledgeDocument]:
        """List an institution's documents, highest id (most recently created) first.

        Args:
            institution_id: Institution whose documents are listed.
            limit: Maximum number of documents to return.
        """
        stmt = (
            select(KbKnowledgeDocument)
            .where(KbKnowledgeDocument.institution_id == institution_id)
            .order_by(KbKnowledgeDocument.id.desc())
            .limit(limit)
        )
        return list(self.db.scalars(stmt))

    def create_query_log(
        self,
        *,
        user_id: str,
        institution_id: int,
        question: str,
        retrieval_hit: bool,
        retrieved_chunk_ids: list[str],
        answer_summary: str,
        llm_model: str | None,
        top_k: int,
        score_threshold: float,
        embedding_latency_ms: int | None,
        search_latency_ms: int | None,
        llm_latency_ms: int | None,
        total_latency_ms: int | None,
    ) -> KbKnowledgeQueryLog:
        """Record one RAG query: retrieval hit, source chunks and latencies.

        The log supports later auditing and quality analysis. ``answer_summary``
        is truncated to 1000 characters, ``retrieval_hit`` is coerced to bool,
        and the embedding model name is taken from the application settings.
        The row is added to the session and flushed, but not committed.

        Returns:
            The newly created ``KbKnowledgeQueryLog``.
        """
        log = KbKnowledgeQueryLog(
            user_id=user_id,
            institution_id=institution_id,
            question=question,
            retrieval_hit=bool(retrieval_hit),
            retrieved_chunk_ids=retrieved_chunk_ids,
            answer_summary=answer_summary[: self._ANSWER_SUMMARY_MAX_LEN],
            llm_model=llm_model,
            embedding_model=settings.embedding_model,
            top_k=top_k,
            score_threshold=score_threshold,
            embedding_latency_ms=embedding_latency_ms,
            search_latency_ms=search_latency_ms,
            llm_latency_ms=llm_latency_ms,
            total_latency_ms=total_latency_ms,
        )
        self.db.add(log)
        self.db.flush()
        return log

    def _collection_name(self, institution_id: int, version: str) -> str:
        """Build the per-institution Milvus collection name.

        The name is ``<prefix>_<institution_id>_<model>_<version>`` cut to 120
        characters, where ``<model>`` is the lower-cased embedding model name
        with every character other than an ASCII letter or digit replaced by
        ``_``.
        """
        # str.isalnum() alone is also true for non-ASCII letters and digits
        # (for example CJK characters), so require ASCII explicitly.
        model_part = "".join(
            ch if ch.isascii() and ch.isalnum() else "_" for ch in settings.embedding_model.lower()
        )
        name = f"{settings.milvus_collection_prefix}_{institution_id}_{model_part}_{version}"
        return name[: self._COLLECTION_NAME_MAX_LEN]
|
||||
Reference in New Issue
Block a user