from datetime import datetime, timezone

from sqlalchemy import delete, select
from sqlalchemy.orm import Session

from app.core.config import settings
from app.models.knowledge_base import (
    KbKnowledgeChunk,
    KbKnowledgeDocument,
    KbKnowledgeIngestionTask,
    KbKnowledgeQueryLog,
    KbKnowledgeSpace,
)


def _utcnow() -> datetime:
    """Return the current UTC time as a naive ``datetime``.

    Same value the deprecated ``datetime.utcnow()`` (Python 3.12+) produced:
    built from an aware UTC timestamp, then ``tzinfo`` is stripped so stored
    task timestamps stay naive UTC as before.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


class KnowledgeBaseRepository:
    """Knowledge-base repository: centralizes queries and writes for the kb_* tables.

    Methods only ``add``/``execute``/``flush`` on the session; none of them
    commit, so transaction boundaries belong to the caller owning ``db``.
    """

    def __init__(self, db: Session) -> None:
        self.db = db

    def get_or_create_space(self, institution_id: int, institution_name: str | None) -> KbKnowledgeSpace:
        """Fetch or create the Milvus collection mapping for an institution and embedding version.

        Looks up a space by ``(institution_id, embedding_version)``. When none
        exists, a new ``active`` space is created from the configured embedding
        model/dimension and retrieval defaults, then flushed so its primary key
        is populated.

        Returns:
            The existing or newly created ``KbKnowledgeSpace``.
        """
        # NOTE(review): check-then-insert; two concurrent callers may both
        # insert unless the table has a unique constraint on
        # (institution_id, embedding_version) — confirm.
        version = "v1"
        space = self.db.scalar(
            select(KbKnowledgeSpace).where(
                KbKnowledgeSpace.institution_id == institution_id,
                KbKnowledgeSpace.embedding_version == version,
            )
        )
        if space:
            return space
        collection_name = self._collection_name(institution_id, version)
        space = KbKnowledgeSpace(
            institution_id=institution_id,
            institution_name=institution_name,
            space_code=f"institution_{institution_id}_{version}",
            collection_name=collection_name,
            embedding_model=settings.embedding_model,
            embedding_dim=settings.embedding_dim,
            embedding_version=version,
            chunk_size=1100,
            chunk_overlap=180,
            top_k_default=settings.rag_top_k,
            score_threshold=settings.rag_score_threshold,
            status="active",
        )
        self.db.add(space)
        self.db.flush()
        return space

    def get_space(self, institution_id: int) -> KbKnowledgeSpace | None:
        """Return the institution's newest ``active`` knowledge space, or ``None``.

        When several active spaces exist, the one with the highest ``id`` wins.
        """
        return self.db.scalar(
            select(KbKnowledgeSpace)
            .where(KbKnowledgeSpace.institution_id == institution_id, KbKnowledgeSpace.status == "active")
            .order_by(KbKnowledgeSpace.id.desc())
        )

    def get_document_by_hash(self, institution_id: int, file_sha256: str) -> KbKnowledgeDocument | None:
        """Deduplicate uploads: find a document of this institution with the given file SHA256."""
        return self.db.scalar(
            select(KbKnowledgeDocument).where(
                KbKnowledgeDocument.institution_id == institution_id,
                KbKnowledgeDocument.file_sha256 == file_sha256,
            )
        )

    def create_document(
        self,
        *,
        institution_id: int,
        uploaded_by: str,
        file_name: str,
        file_sha256: str,
        file_size: int,
        file_path: str,
        document_title: str | None,
        document_category: str,
        version: str,
    ) -> KbKnowledgeDocument:
        """Persist metadata and the local storage path of an uploaded PDF.

        The document starts as ``uploaded`` with parse/embedding status
        ``pending`` and zero chunks; it is flushed so ``id`` is available.
        """
        document = KbKnowledgeDocument(
            institution_id=institution_id,
            uploaded_by=uploaded_by,
            file_name=file_name,
            file_sha256=file_sha256,
            file_type="pdf",
            file_size=file_size,
            file_path=file_path,
            document_title=document_title,
            document_category=document_category,
            version=version,
            status="uploaded",
            parse_status="pending",
            embedding_status="pending",
            chunk_count=0,
        )
        self.db.add(document)
        self.db.flush()
        return document

    def get_document(self, document_id: int, institution_id: int | None = None) -> KbKnowledgeDocument | None:
        """Fetch a document by ID.

        When ``institution_id`` is given, the lookup is additionally scoped to
        that institution (access isolation); otherwise any institution matches.
        """
        stmt = select(KbKnowledgeDocument).where(KbKnowledgeDocument.id == document_id)
        if institution_id is not None:
            stmt = stmt.where(KbKnowledgeDocument.institution_id == institution_id)
        return self.db.scalar(stmt)

    def create_ingestion_task(self, document: KbKnowledgeDocument) -> KbKnowledgeIngestionTask:
        """Create a ``queued`` ingestion task tracking PDF parsing, chunking, embedding and Milvus writes."""
        task = KbKnowledgeIngestionTask(
            document_id=document.id,
            institution_id=document.institution_id,
            task_type="document_ingestion",
            status="queued",
            progress=0,
            current_step="queued",
        )
        self.db.add(task)
        self.db.flush()
        return task

    def get_ingestion_task(self, task_id: int) -> KbKnowledgeIngestionTask | None:
        """Fetch an ingestion task by primary key."""
        return self.db.get(KbKnowledgeIngestionTask, task_id)

    def update_task(
        self,
        task: KbKnowledgeIngestionTask,
        *,
        status: str | None = None,
        progress: int | None = None,
        current_step: str | None = None,
        error_message: str | None = None,
    ) -> None:
        """Update an ingestion task's progress fields; arguments left as ``None`` are untouched.

        Setting ``status`` to ``running`` stamps ``started_at`` once (only if
        still unset). Setting it to ``success`` or ``failed`` stamps
        ``finished_at``. An empty ``status`` string is ignored.
        """
        if status:
            task.status = status
            if status == "running" and task.started_at is None:
                task.started_at = _utcnow()
            if status in {"success", "failed"}:
                task.finished_at = _utcnow()
        if progress is not None:
            task.progress = progress
        if current_step is not None:
            task.current_step = current_step
        if error_message is not None:
            task.error_message = error_message

    def replace_chunks(self, document: KbKnowledgeDocument, chunks: list[KbKnowledgeChunk]) -> None:
        """Rebuild a document's chunks: delete all existing rows, then add the new ones.

        Also updates ``document.chunk_count`` and flushes.
        """
        self.db.execute(delete(KbKnowledgeChunk).where(KbKnowledgeChunk.document_id == document.id))
        self.db.add_all(chunks)
        document.chunk_count = len(chunks)
        self.db.flush()

    def get_chunks_by_uids(self, institution_id: int, chunk_uids: list[str]) -> list[KbKnowledgeChunk]:
        """Batch-load MySQL chunk rows for the ``chunk_uid`` values returned by Milvus.

        Results are restricted to ``institution_id`` and returned in the order
        of ``chunk_uids``. For duplicate uids, the first position counts.

        Returns:
            The matching chunks, or ``[]`` when ``chunk_uids`` is empty.
        """
        if not chunk_uids:
            return []
        rows = self.db.scalars(
            select(KbKnowledgeChunk).where(
                KbKnowledgeChunk.institution_id == institution_id,
                KbKnowledgeChunk.chunk_uid.in_(chunk_uids),
            )
        ).all()
        order: dict[str, int] = {}
        for index, chunk_uid in enumerate(chunk_uids):
            # Keep the first (best-ranked) position of a duplicated uid.
            order.setdefault(chunk_uid, index)
        # Rows whose uid is not an exact key (e.g. matched via a
        # case-insensitive DB collation) sort after every ranked row.
        fallback = len(chunk_uids)
        return sorted(rows, key=lambda row: order.get(row.chunk_uid, fallback))

    def list_documents(self, institution_id: int, limit: int = 20) -> list[KbKnowledgeDocument]:
        """Return up to ``limit`` of the institution's documents, newest ``id`` first."""
        return list(
            self.db.scalars(
                select(KbKnowledgeDocument)
                .where(KbKnowledgeDocument.institution_id == institution_id)
                .order_by(KbKnowledgeDocument.id.desc())
                .limit(limit)
            )
        )

    def create_query_log(
        self,
        *,
        user_id: str,
        institution_id: int,
        question: str,
        retrieval_hit: bool,
        retrieved_chunk_ids: list[str],
        answer_summary: str,
        llm_model: str | None,
        top_k: int,
        score_threshold: float,
        embedding_latency_ms: int | None,
        search_latency_ms: int | None,
        llm_latency_ms: int | None,
        total_latency_ms: int | None,
    ) -> KbKnowledgeQueryLog:
        """Record a RAG query: hit flag, source chunk IDs and per-stage latencies, for auditing and analysis.

        ``answer_summary`` is truncated to 1000 characters. ``embedding_model``
        is taken from the current settings.
        """
        log = KbKnowledgeQueryLog(
            user_id=user_id,
            institution_id=institution_id,
            question=question,
            retrieval_hit=bool(retrieval_hit),
            retrieved_chunk_ids=retrieved_chunk_ids,
            answer_summary=answer_summary[:1000],
            llm_model=llm_model,
            embedding_model=settings.embedding_model,
            top_k=top_k,
            score_threshold=score_threshold,
            embedding_latency_ms=embedding_latency_ms,
            search_latency_ms=search_latency_ms,
            llm_latency_ms=llm_latency_ms,
            total_latency_ms=total_latency_ms,
        )
        self.db.add(log)
        self.db.flush()
        return log

    def _collection_name(self, institution_id: int, version: str) -> str:
        """Build the per-institution Milvus collection name, capped at 120 characters.

        Format: ``{prefix}_{institution_id}_{model}_{version}``. In the model
        part, every character other than an ASCII letter/digit becomes ``_``.
        """
        # str.isalnum() is also True for non-ASCII letters/digits (CJK,
        # full-width digits), which Milvus rejects in collection names;
        # require ASCII explicitly.
        model_part = "".join(
            ch if ch.isascii() and ch.isalnum() else "_" for ch in settings.embedding_model.lower()
        )
        # NOTE(review): truncating the whole name can cut off the version
        # suffix for very long model names, so different versions could
        # collide; the prefix itself is not sanitized — confirm it is safe.
        return f"{settings.milvus_collection_prefix}_{institution_id}_{model_part}_{version}"[:120]