import hashlib
import re
from dataclasses import dataclass

from app.integrations.pdf_parser import ParsedPdfPage
from app.models.knowledge_base import KbKnowledgeChunk


@dataclass(frozen=True)
class ChunkDraft:
    """Chunk draft: intermediate result of splitting PDF text.

    Drafts are turned into ORM rows by ``DocumentChunkService.to_models``
    and are later written to MySQL and Milvus.
    """

    # Position of the chunk in the list returned by ``build_chunks``.
    chunk_index: int
    # Page number on which the chunk's text starts.
    page_start: int
    # Page number of the last piece of text added to the chunk.
    page_end: int
    # Most recent heading detected up to the end of this chunk, if any.
    section_title: str | None
    # Chunk text; the buffered pieces joined with newlines and stripped.
    text: str


class DocumentChunkService:
    """Document chunking service for textbook / guideline PDFs.

    Splits parsed pages into semantic chunks while keeping track of the page
    range and the section title each chunk belongs to.
    """

    # Zero-width split point right after a sentence terminator. Covers the
    # ideographic full stop, ASCII "! ? ; ." and the full-width "! ? ;"
    # (written as escapes so the pattern survives text normalisation).
    _SENTENCE_END = re.compile(r"(?<=[。!?;.\uff01\uff1f\uff1b])")
    # One or more newlines separate paragraphs.
    _PARAGRAPH_BREAK = re.compile(r"\n+")
    # Heading shapes: "第3章" style, "一、" style enumerations (ASCII or
    # full-width dot), and "1.2.3 Title" style numeric headings.
    _TITLE_PATTERNS = (
        re.compile(r"^第[一二三四五六七八九十百0-9]+[章节篇]"),
        re.compile(r"^[一二三四五六七八九十]+[、.\uff0e]"),
        re.compile(r"^\d+(\.\d+){0,3}\s+"),
    )
    # Paragraphs longer than this are never treated as headings.
    _MAX_TITLE_LENGTH = 80

    def __init__(self, chunk_size: int = 1100, chunk_overlap: int = 180) -> None:
        """Configure the chunker.

        Args:
            chunk_size: Target maximum chunk length in characters. This is a
                soft limit: a chunk that starts with a carried-over overlap
                tail may exceed it by up to ``chunk_overlap + 1`` characters.
            chunk_overlap: Number of trailing characters of a chunk that are
                repeated at the start of the next one (0 disables overlap).

        Raises:
            ValueError: If ``chunk_size`` is not positive or ``chunk_overlap``
                is negative; both would make the slicing logic produce
                meaningless output.
        """
        if chunk_size < 1:
            raise ValueError("chunk_size must be a positive integer")
        if chunk_overlap < 0:
            raise ValueError("chunk_overlap must not be negative")
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap

    def build_chunks(self, pages: list[ParsedPdfPage]) -> list[ChunkDraft]:
        """Split pages into chunk drafts, preserving page numbers.

        Pages are split into paragraphs; paragraphs longer than
        ``chunk_size`` are further split by ``_split_long_text``. Pieces are
        accumulated in a buffer that is flushed into a draft whenever adding
        the next piece would exceed ``chunk_size``.

        Args:
            pages: Parsed pages; only ``page.text`` and ``page.page_number``
                are read.

        Returns:
            Drafts with contiguous ``chunk_index`` values starting at 0.
            Drafts whose text would be empty are never created.
        """
        drafts: list[ChunkDraft] = []
        buffer: list[str] = []
        # Placeholders: both are assigned before the buffer becomes non-empty.
        page_start = 0
        page_end = 0
        current_title: str | None = None
        # Title in effect for the text already in the buffer. Kept separate
        # from ``current_title`` so that a heading which triggers a flush does
        # not label the chunk of the *previous* section.
        buffer_title: str | None = None

        for page in pages:
            page_number = page.page_number
            for paragraph in self._split_paragraphs(page.text):
                detected_title = self._detect_title(paragraph)
                if detected_title:
                    current_title = detected_title

                pieces = self._split_long_text(paragraph)
                for piece_index, piece in enumerate(pieces):
                    # Continuation pieces of a long paragraph already begin
                    # with the overlap of their predecessor, so they always
                    # start a new chunk and must not get a second tail.
                    carries_overlap = piece_index > 0 and self.chunk_overlap > 0
                    candidate = "\n".join([*buffer, piece]).strip()
                    if buffer and (carries_overlap or len(candidate) > self.chunk_size):
                        self._append_draft(drafts, buffer, page_start, page_end, buffer_title)
                        if carries_overlap:
                            buffer = []
                        else:
                            buffer = self._overlap_tail(buffer)
                            # The tail is the end of the flushed chunk, so the
                            # new chunk starts on that chunk's last page.
                            page_start = page_end
                    if not buffer:
                        page_start = page_number
                    page_end = page_number
                    buffer_title = current_title
                    buffer.append(piece)

        if buffer:
            self._append_draft(drafts, buffer, page_start, page_end, buffer_title)
        return drafts

    @staticmethod
    def _append_draft(
        drafts: list[ChunkDraft],
        buffer: list[str],
        page_start: int,
        page_end: int,
        section_title: str | None,
    ) -> None:
        """Append a draft built from ``buffer`` to ``drafts`` unless its text is empty."""
        text = "\n".join(buffer).strip()
        if not text:
            # Skipping here (rather than filtering afterwards) keeps
            # chunk_index contiguous.
            return
        drafts.append(
            ChunkDraft(
                chunk_index=len(drafts),
                page_start=page_start,
                page_end=page_end,
                section_title=section_title,
                text=text,
            )
        )

    def to_models(
        self,
        *,
        institution_id: int,
        document_id: int,
        collection_name: str,
        embedding_model: str,
        drafts: list[ChunkDraft],
    ) -> list[KbKnowledgeChunk]:
        """Convert chunk drafts into ORM objects.

        ``chunk_uid`` is built from the document id, the chunk index and the
        first 12 hex digits of the text's SHA-256; the same value is stored as
        ``vector_id`` (the Milvus vector id).

        Args:
            institution_id: Stored on every row.
            document_id: Stored on every row and embedded in ``chunk_uid``.
            collection_name: Stored on every row.
            embedding_model: Stored on every row.
            drafts: Drafts produced by ``build_chunks``.

        Returns:
            One ``KbKnowledgeChunk`` per draft, in the same order.
        """
        rows: list[KbKnowledgeChunk] = []
        for draft in drafts:
            chunk_hash = hashlib.sha256(draft.text.encode("utf-8")).hexdigest()
            chunk_uid = f"doc{document_id}_chunk{draft.chunk_index}_{chunk_hash[:12]}"
            rows.append(
                KbKnowledgeChunk(
                    institution_id=institution_id,
                    document_id=document_id,
                    chunk_uid=chunk_uid,
                    chunk_index=draft.chunk_index,
                    page_start=draft.page_start,
                    page_end=draft.page_end,
                    section_title=draft.section_title,
                    chunk_text=draft.text,
                    chunk_hash=chunk_hash,
                    # Rough estimate: half the character count, at least 1.
                    token_count=max(1, len(draft.text) // 2),
                    vector_id=chunk_uid,
                    collection_name=collection_name,
                    embedding_model=embedding_model,
                    metadata_={
                        "chunking": "page_semantic_window",
                        "chunk_size": self.chunk_size,
                        "overlap": self.chunk_overlap,
                    },
                )
            )
        return rows

    def _split_paragraphs(self, text: str) -> list[str]:
        """Split page text into non-empty, stripped paragraphs on newlines.

        Returns an empty list for empty (or missing) page text.
        """
        if not text:
            return []
        return [
            stripped
            for part in self._PARAGRAPH_BREAK.split(text)
            if (stripped := part.strip())
        ]

    def _split_long_text(self, text: str) -> list[str]:
        """Split a paragraph that exceeds ``chunk_size`` into smaller pieces.

        The text is first cut after sentence-ending punctuation and sentences
        are packed into pieces of at most ``chunk_size`` characters, each new
        piece starting with the last ``chunk_overlap`` characters of the
        previous one. Pieces that are still too long (a single very long
        sentence) are cut with a sliding character window that advances by
        ``chunk_size - chunk_overlap`` (at least 1).

        Returns:
            ``[text]`` unchanged when it already fits; otherwise the list of
            non-blank pieces in reading order.
        """
        size = self.chunk_size
        overlap = self.chunk_overlap
        if len(text) <= size:
            return [text]

        pieces: list[str] = []
        current = ""
        for sentence in self._SENTENCE_END.split(text):
            if current and len(current) + len(sentence) > size:
                pieces.append(current.strip())
                current = current[-overlap:] if overlap else ""
            current += sentence
        if current.strip():
            pieces.append(current.strip())

        step = max(1, size - overlap)
        final: list[str] = []
        for piece in pieces:
            if len(piece) <= size:
                if piece:
                    final.append(piece)
                continue
            start = 0
            while True:
                window = piece[start : start + size]
                if window.strip():
                    final.append(window)
                # Stop once the window reaches the end of the piece; another
                # step would only emit a suffix of the window just added.
                if start + size >= len(piece):
                    break
                start += step
        return final

    def _overlap_tail(self, buffer: list[str]) -> list[str]:
        """Return the last ``chunk_overlap`` characters of the buffered text.

        The result is a one-element list used to seed the next chunk, or an
        empty list when overlap is disabled or the buffer has no text.
        """
        if not self.chunk_overlap:
            return []
        tail = "\n".join(buffer).strip()[-self.chunk_overlap :]
        return [tail] if tail else []

    def _detect_title(self, paragraph: str) -> str | None:
        """Return the stripped paragraph if it looks like a heading, else ``None``.

        A heading is at most ``_MAX_TITLE_LENGTH`` characters long and matches
        one of ``_TITLE_PATTERNS`` at its start.
        """
        compact = paragraph.strip()
        if len(compact) > self._MAX_TITLE_LENGTH:
            return None
        if any(pattern.search(compact) for pattern in self._TITLE_PATTERNS):
            return compact
        return None