# File: fastapi/app/repositories/knowledge_base_repository.py
# (Copied from a file-browser listing: 225 lines, 8.6 KiB, Python.)

from datetime import datetime
from sqlalchemy import delete, select
from sqlalchemy.orm import Session
from app.core.config import settings
from app.models.knowledge_base import (
KbKnowledgeChunk,
KbKnowledgeDocument,
KbKnowledgeIngestionTask,
KbKnowledgeQueryLog,
KbKnowledgeSpace,
)
class KnowledgeBaseRepository:
    """Knowledge-base repository: centralizes queries and writes for the kb_* tables.

    Write methods add objects to the injected SQLAlchemy session and usually
    ``flush()`` them (so generated primary keys are available), but no method
    commits; transaction boundaries are left to the caller.
    """

    def __init__(self, db: Session) -> None:
        self.db = db

    def get_or_create_space(self, institution_id: int, institution_name: str | None) -> KbKnowledgeSpace:
        """Return the institution's knowledge space for the current embedding version, creating it if absent.

        A knowledge space maps (institution, embedding version) to a Milvus
        collection. The lookup does not filter on ``status``, so an existing
        space is returned regardless of its status.

        Args:
            institution_id: Owning institution.
            institution_name: Display name, stored only when a new space is created.

        Returns:
            The existing space, or a newly created and flushed one.
        """
        # The embedding version is hard-coded; it is part of both the lookup key
        # and the generated collection name.
        version = "v1"
        space = self.db.scalar(
            select(KbKnowledgeSpace).where(
                KbKnowledgeSpace.institution_id == institution_id,
                KbKnowledgeSpace.embedding_version == version,
            )
        )
        if space:
            return space
        collection_name = self._collection_name(institution_id, version)
        space = KbKnowledgeSpace(
            institution_id=institution_id,
            institution_name=institution_name,
            space_code=f"institution_{institution_id}_{version}",
            collection_name=collection_name,
            embedding_model=settings.embedding_model,
            embedding_dim=settings.embedding_dim,
            embedding_version=version,
            # Chunking parameters recorded for newly created spaces.
            chunk_size=1100,
            chunk_overlap=180,
            top_k_default=settings.rag_top_k,
            score_threshold=settings.rag_score_threshold,
            status="active",
        )
        self.db.add(space)
        self.db.flush()
        return space

    def get_space(self, institution_id: int) -> KbKnowledgeSpace | None:
        """Return the institution's newest active knowledge space (collection mapping), or None."""
        return self.db.scalar(
            select(KbKnowledgeSpace)
            .where(KbKnowledgeSpace.institution_id == institution_id, KbKnowledgeSpace.status == "active")
            .order_by(KbKnowledgeSpace.id.desc())
            .limit(1)
        )

    def get_document_by_hash(self, institution_id: int, file_sha256: str) -> KbKnowledgeDocument | None:
        """Deduplication: return the institution's document with this file SHA-256, or None if not uploaded yet."""
        return self.db.scalar(
            select(KbKnowledgeDocument).where(
                KbKnowledgeDocument.institution_id == institution_id,
                KbKnowledgeDocument.file_sha256 == file_sha256,
            )
        )

    def create_document(
        self,
        *,
        institution_id: int,
        uploaded_by: str,
        file_name: str,
        file_sha256: str,
        file_size: int,
        file_path: str,
        document_title: str | None,
        document_category: str,
        version: str,
    ) -> KbKnowledgeDocument:
        """Create a document record holding an uploaded PDF's metadata and local storage path.

        The file type is always recorded as ``"pdf"``. The document starts in the
        ``uploaded`` state, with parsing and embedding ``pending`` and zero chunks.

        Returns:
            The new, flushed document (its ``id`` is populated).
        """
        document = KbKnowledgeDocument(
            institution_id=institution_id,
            uploaded_by=uploaded_by,
            file_name=file_name,
            file_sha256=file_sha256,
            file_type="pdf",
            file_size=file_size,
            file_path=file_path,
            document_title=document_title,
            document_category=document_category,
            version=version,
            status="uploaded",
            parse_status="pending",
            embedding_status="pending",
            chunk_count=0,
        )
        self.db.add(document)
        self.db.flush()
        return document

    def get_document(self, document_id: int, institution_id: int | None = None) -> KbKnowledgeDocument | None:
        """Return a document by ID, or None.

        Args:
            document_id: Document primary key.
            institution_id: When given, the document must belong to this
                institution (access isolation); otherwise None is returned.
        """
        stmt = select(KbKnowledgeDocument).where(KbKnowledgeDocument.id == document_id)
        if institution_id is not None:
            stmt = stmt.where(KbKnowledgeDocument.institution_id == institution_id)
        return self.db.scalar(stmt)

    def create_ingestion_task(self, document: KbKnowledgeDocument) -> KbKnowledgeIngestionTask:
        """Create a queued ingestion task that tracks parsing, chunking, embedding and Milvus writes for a document.

        Returns:
            The new, flushed task (its ``id`` is populated).
        """
        task = KbKnowledgeIngestionTask(
            document_id=document.id,
            institution_id=document.institution_id,
            task_type="document_ingestion",
            status="queued",
            progress=0,
            current_step="queued",
        )
        self.db.add(task)
        self.db.flush()
        return task

    def get_ingestion_task(
        self, task_id: int, institution_id: int | None = None
    ) -> KbKnowledgeIngestionTask | None:
        """Return an ingestion task by ID, or None.

        Args:
            task_id: Task primary key.
            institution_id: When given, the task must belong to this institution
                (access isolation, mirroring ``get_document``); otherwise None is
                returned. Omitting it keeps the original unrestricted lookup.
        """
        task = self.db.get(KbKnowledgeIngestionTask, task_id)
        if task is not None and institution_id is not None and task.institution_id != institution_id:
            return None
        return task

    def update_task(
        self,
        task: KbKnowledgeIngestionTask,
        *,
        status: str | None = None,
        progress: int | None = None,
        current_step: str | None = None,
        error_message: str | None = None,
    ) -> None:
        """Update an ingestion task's progress fields in place.

        Only arguments that are provided are applied (``status`` is skipped when
        empty). Moving to ``running`` stamps ``started_at`` once; moving to
        ``success`` or ``failed`` stamps ``finished_at``. Changes are not flushed
        here; they are persisted when the caller's session flushes or commits.
        """
        if status:
            task.status = status
            # NOTE(review): naive UTC timestamps. datetime.utcnow() is deprecated
            # since Python 3.12; switching to aware datetimes needs the column
            # type confirmed first.
            if status == "running" and task.started_at is None:
                task.started_at = datetime.utcnow()
            if status in {"success", "failed"}:
                task.finished_at = datetime.utcnow()
        if progress is not None:
            task.progress = progress
        if current_step is not None:
            task.current_step = current_step
        if error_message is not None:
            task.error_message = error_message

    def replace_chunks(self, document: KbKnowledgeDocument, chunks: list[KbKnowledgeChunk]) -> None:
        """Replace all chunks of a document: delete the existing rows, add ``chunks``, update ``chunk_count`` and flush."""
        self.db.execute(delete(KbKnowledgeChunk).where(KbKnowledgeChunk.document_id == document.id))
        self.db.add_all(chunks)
        document.chunk_count = len(chunks)
        self.db.flush()

    def get_chunks_by_uids(self, institution_id: int, chunk_uids: list[str]) -> list[KbKnowledgeChunk]:
        """Batch-load the institution's chunk rows for ``chunk_uids`` returned by Milvus.

        Returns:
            The matching chunks, ordered as in ``chunk_uids``. Uids with no row
            are skipped. An empty input returns an empty list without querying.
        """
        if not chunk_uids:
            return []
        rows = self.db.scalars(
            select(KbKnowledgeChunk).where(
                KbKnowledgeChunk.institution_id == institution_id,
                KbKnowledgeChunk.chunk_uid.in_(chunk_uids),
            )
        ).all()
        # Restore the caller's ordering (presumably Milvus relevance order). Keep
        # the FIRST position of a duplicated uid so a repeat cannot demote it.
        rank: dict[str, int] = {}
        for position, chunk_uid in enumerate(chunk_uids):
            rank.setdefault(chunk_uid, position)
        return sorted(rows, key=lambda row: rank.get(row.chunk_uid, len(rank)))

    def list_documents(self, institution_id: int, limit: int = 20) -> list[KbKnowledgeDocument]:
        """Return up to ``limit`` of the institution's documents, newest (highest ID) first."""
        return list(
            self.db.scalars(
                select(KbKnowledgeDocument)
                .where(KbKnowledgeDocument.institution_id == institution_id)
                .order_by(KbKnowledgeDocument.id.desc())
                .limit(limit)
            )
        )

    def create_query_log(
        self,
        *,
        user_id: str,
        institution_id: int,
        question: str,
        retrieval_hit: bool,
        retrieved_chunk_ids: list[str],
        answer_summary: str,
        llm_model: str | None,
        top_k: int,
        score_threshold: float,
        embedding_latency_ms: int | None,
        search_latency_ms: int | None,
        llm_latency_ms: int | None,
        total_latency_ms: int | None,
    ) -> KbKnowledgeQueryLog:
        """Record one RAG query: retrieval hit, source chunk IDs and per-stage latencies, for auditing and analysis.

        ``answer_summary`` is truncated to 1000 characters. The embedding model
        is taken from settings.

        Returns:
            The new, flushed log row.
        """
        log = KbKnowledgeQueryLog(
            user_id=user_id,
            institution_id=institution_id,
            question=question,
            retrieval_hit=bool(retrieval_hit),
            retrieved_chunk_ids=retrieved_chunk_ids,
            answer_summary=answer_summary[:1000],
            llm_model=llm_model,
            embedding_model=settings.embedding_model,
            top_k=top_k,
            score_threshold=score_threshold,
            embedding_latency_ms=embedding_latency_ms,
            search_latency_ms=search_latency_ms,
            llm_latency_ms=llm_latency_ms,
            total_latency_ms=total_latency_ms,
        )
        self.db.add(log)
        self.db.flush()
        return log

    def _collection_name(self, institution_id: int, version: str) -> str:
        """Build the per-institution Milvus collection name.

        The embedding model name is lower-cased, and every character that is not
        an ASCII letter or digit becomes ``_``. The full name
        ``{prefix}_{institution_id}_{model}_{version}`` is truncated to 120
        characters. The configured prefix and ``version`` are not sanitized here.
        """
        # str.isalnum() is Unicode-aware ("²", CJK letters, ... return True),
        # so ASCII is required as well; Milvus names allow only ASCII letters,
        # digits and underscores.
        model_part = "".join(
            ch if ch.isascii() and ch.isalnum() else "_" for ch in settings.embedding_model.lower()
        )
        return f"{settings.milvus_collection_prefix}_{institution_id}_{model_part}_{version}"[:120]