225 lines
8.6 KiB
Python
225 lines
8.6 KiB
Python
from datetime import datetime
|
|
|
|
from sqlalchemy import delete, select
|
|
from sqlalchemy.orm import Session
|
|
|
|
from app.core.config import settings
|
|
from app.models.knowledge_base import (
|
|
KbKnowledgeChunk,
|
|
KbKnowledgeDocument,
|
|
KbKnowledgeIngestionTask,
|
|
KbKnowledgeQueryLog,
|
|
KbKnowledgeSpace,
|
|
)
|
|
|
|
|
|
class KnowledgeBaseRepository:
    """Knowledge-base repository: centralizes queries and writes against the kb_* tables.

    Methods only ``flush`` the session; committing or rolling back is left to the caller.
    """

    # Upper bound applied to generated Milvus collection names.
    _MAX_COLLECTION_NAME_LEN = 120

    def __init__(self, db: Session) -> None:
        """Bind the repository to an existing SQLAlchemy session."""
        self.db = db

    @staticmethod
    def _utcnow() -> datetime:
        """Return the current UTC time as a naive datetime.

        Same representation as the deprecated ``datetime.utcnow()`` (Python 3.12+),
        so timestamps written by this class keep their previous naive-UTC form.
        """
        from datetime import timezone

        return datetime.now(timezone.utc).replace(tzinfo=None)

    def get_or_create_space(self, institution_id: int, institution_name: str | None) -> KbKnowledgeSpace:
        """Get or create the knowledge space (Milvus collection mapping) for an institution.

        Looks up the space for ``institution_id`` at the current embedding version
        ("v1"). If none exists, creates one with a generated collection name and
        chunking/retrieval defaults taken from ``settings``. Unlike ``get_space``,
        this lookup does not filter on ``status``.

        Args:
            institution_id: Owning institution.
            institution_name: Display name stored only when a new space is created.

        Returns:
            The existing space, or the newly created and flushed one.
        """
        version = "v1"
        space = self.db.scalar(
            select(KbKnowledgeSpace).where(
                KbKnowledgeSpace.institution_id == institution_id,
                KbKnowledgeSpace.embedding_version == version,
            )
        )
        if space is not None:
            return space

        collection_name = self._collection_name(institution_id, version)
        space = KbKnowledgeSpace(
            institution_id=institution_id,
            institution_name=institution_name,
            space_code=f"institution_{institution_id}_{version}",
            collection_name=collection_name,
            embedding_model=settings.embedding_model,
            embedding_dim=settings.embedding_dim,
            embedding_version=version,
            chunk_size=1100,
            chunk_overlap=180,
            top_k_default=settings.rag_top_k,
            score_threshold=settings.rag_score_threshold,
            status="active",
        )
        self.db.add(space)
        # Flush so the INSERT is issued and DB-generated columns are populated before commit.
        self.db.flush()
        return space

    def get_space(self, institution_id: int) -> KbKnowledgeSpace | None:
        """Return the institution's most recent active knowledge space (highest id), or None."""
        return self.db.scalar(
            select(KbKnowledgeSpace)
            .where(KbKnowledgeSpace.institution_id == institution_id, KbKnowledgeSpace.status == "active")
            .order_by(KbKnowledgeSpace.id.desc())
            # Only the first row is used, so don't ask the database for more.
            .limit(1)
        )

    def get_document_by_hash(self, institution_id: int, file_sha256: str) -> KbKnowledgeDocument | None:
        """De-duplication lookup: find a document already uploaded by this institution with the same SHA-256."""
        return self.db.scalar(
            select(KbKnowledgeDocument).where(
                KbKnowledgeDocument.institution_id == institution_id,
                KbKnowledgeDocument.file_sha256 == file_sha256,
            )
        )

    def create_document(
        self,
        *,
        institution_id: int,
        uploaded_by: str,
        file_name: str,
        file_sha256: str,
        file_size: int,
        file_path: str,
        document_title: str | None,
        document_category: str,
        version: str,
    ) -> KbKnowledgeDocument:
        """Persist metadata and the local storage path of an uploaded PDF.

        The document starts as ``status="uploaded"``, with parse/embedding status
        ``"pending"`` and zero chunks. It is flushed before being returned.
        """
        document = KbKnowledgeDocument(
            institution_id=institution_id,
            uploaded_by=uploaded_by,
            file_name=file_name,
            file_sha256=file_sha256,
            file_type="pdf",
            file_size=file_size,
            file_path=file_path,
            document_title=document_title,
            document_category=document_category,
            version=version,
            status="uploaded",
            parse_status="pending",
            embedding_status="pending",
            chunk_count=0,
        )
        self.db.add(document)
        self.db.flush()
        return document

    def get_document(self, document_id: int, institution_id: int | None = None) -> KbKnowledgeDocument | None:
        """Fetch a document by id.

        When ``institution_id`` is given, the document must also belong to that
        institution (access isolation). Otherwise the result is None.
        """
        stmt = select(KbKnowledgeDocument).where(KbKnowledgeDocument.id == document_id)
        if institution_id is not None:
            stmt = stmt.where(KbKnowledgeDocument.institution_id == institution_id)
        return self.db.scalar(stmt)

    def create_ingestion_task(self, document: KbKnowledgeDocument) -> KbKnowledgeIngestionTask:
        """Create a queued ingestion task that tracks parsing, chunking, embedding and the Milvus write for a document."""
        task = KbKnowledgeIngestionTask(
            document_id=document.id,
            institution_id=document.institution_id,
            task_type="document_ingestion",
            status="queued",
            progress=0,
            current_step="queued",
        )
        self.db.add(task)
        self.db.flush()
        return task

    def get_ingestion_task(self, task_id: int) -> KbKnowledgeIngestionTask | None:
        """Fetch an ingestion task by primary key, or None if it does not exist."""
        return self.db.get(KbKnowledgeIngestionTask, task_id)

    def update_task(
        self,
        task: KbKnowledgeIngestionTask,
        *,
        status: str | None = None,
        progress: int | None = None,
        current_step: str | None = None,
        error_message: str | None = None,
    ) -> None:
        """Update an ingestion task's progress fields so its state can be viewed (e.g. by a UI or operators).

        Only arguments that are not None are applied; ``status`` is applied only
        when it is truthy. ``started_at`` is stamped on the first transition to
        ``"running"``; ``finished_at`` is stamped on ``"success"`` or ``"failed"``.
        Timestamps are naive UTC. The session is not flushed here.
        """
        if status:
            task.status = status
            if status == "running" and task.started_at is None:
                task.started_at = self._utcnow()
            if status in {"success", "failed"}:
                task.finished_at = self._utcnow()
        if progress is not None:
            task.progress = progress
        if current_step is not None:
            task.current_step = current_step
        if error_message is not None:
            task.error_message = error_message

    def replace_chunks(self, document: KbKnowledgeDocument, chunks: list[KbKnowledgeChunk]) -> None:
        """Replace all chunks of a document: delete the existing rows, add the new ones, update ``chunk_count``.

        NOTE(review): assumes each new chunk already has ``document_id`` set by the caller.
        """
        self.db.execute(delete(KbKnowledgeChunk).where(KbKnowledgeChunk.document_id == document.id))
        self.db.add_all(chunks)
        document.chunk_count = len(chunks)
        self.db.flush()

    def get_chunks_by_uids(self, institution_id: int, chunk_uids: list[str]) -> list[KbKnowledgeChunk]:
        """Load chunk rows for the chunk_uids returned by Milvus, preserving Milvus' ranking order.

        Rows are restricted to ``institution_id``. Uids with no matching row are
        skipped. When a uid appears more than once, its first (best-ranked)
        position wins.
        """
        if not chunk_uids:
            return []

        # First occurrence wins so duplicated uids keep their best rank
        # (a plain dict comprehension would keep the last index instead).
        order: dict[str, int] = {}
        for index, chunk_uid in enumerate(chunk_uids):
            order.setdefault(chunk_uid, index)

        rows = self.db.scalars(
            select(KbKnowledgeChunk).where(
                KbKnowledgeChunk.institution_id == institution_id,
                KbKnowledgeChunk.chunk_uid.in_(list(order)),
            )
        ).all()
        # Fallback sorts any unexpected row after every known uid.
        return sorted(rows, key=lambda row: order.get(row.chunk_uid, len(order)))

    def list_documents(self, institution_id: int, limit: int = 20) -> list[KbKnowledgeDocument]:
        """Return up to ``limit`` of the institution's documents, newest (highest id) first."""
        return list(
            self.db.scalars(
                select(KbKnowledgeDocument)
                .where(KbKnowledgeDocument.institution_id == institution_id)
                .order_by(KbKnowledgeDocument.id.desc())
                .limit(limit)
            )
        )

    def create_query_log(
        self,
        *,
        user_id: str,
        institution_id: int,
        question: str,
        retrieval_hit: bool,
        retrieved_chunk_ids: list[str],
        answer_summary: str,
        llm_model: str | None,
        top_k: int,
        score_threshold: float,
        embedding_latency_ms: int | None,
        search_latency_ms: int | None,
        llm_latency_ms: int | None,
        total_latency_ms: int | None,
    ) -> KbKnowledgeQueryLog:
        """Record a RAG query: whether retrieval hit, the retrieved sources and per-stage latencies.

        The log supports later auditing and quality analysis. ``answer_summary``
        is truncated to 1000 characters, and the embedding model is taken from
        ``settings``. The log is flushed before being returned.
        """
        log = KbKnowledgeQueryLog(
            user_id=user_id,
            institution_id=institution_id,
            question=question,
            retrieval_hit=bool(retrieval_hit),
            retrieved_chunk_ids=retrieved_chunk_ids,
            answer_summary=answer_summary[:1000],
            llm_model=llm_model,
            embedding_model=settings.embedding_model,
            top_k=top_k,
            score_threshold=score_threshold,
            embedding_latency_ms=embedding_latency_ms,
            search_latency_ms=search_latency_ms,
            llm_latency_ms=llm_latency_ms,
            total_latency_ms=total_latency_ms,
        )
        self.db.add(log)
        self.db.flush()
        return log

    def _collection_name(self, institution_id: int, version: str) -> str:
        """Build a per-institution Milvus collection name.

        The name has the form ``{prefix}_{institution_id}_{model}_{version}``. It
        is capped at ``_MAX_COLLECTION_NAME_LEN`` characters. In the model segment,
        every character that is not an ASCII letter or digit becomes ``_``.
        """
        # isalnum() alone accepts non-ASCII letters/digits (e.g. "é", "中"), which are
        # not safe in a collection name, so require ASCII as well.
        model_part = "".join(
            ch if ch.isascii() and ch.isalnum() else "_"
            for ch in settings.embedding_model.lower()
        )
        max_len = self._MAX_COLLECTION_NAME_LEN
        head = f"{settings.milvus_collection_prefix}_{institution_id}_"
        tail = f"_{version}"
        # Trim only the model segment so the version suffix survives truncation.
        # Otherwise a long model name could map different versions to the same collection.
        room = max(max_len - len(head) - len(tail), 0)
        return f"{head}{model_part[:room]}{tail}"[:max_len]