Files
fastapi/app/services/document_chunk_service.py

155 lines
6.3 KiB
Python

import hashlib
import re
from dataclasses import dataclass
from app.integrations.pdf_parser import ParsedPdfPage
from app.models.knowledge_base import KbKnowledgeChunk
@dataclass(frozen=True)
class ChunkDraft:
"""分片草稿:PDF 文本切分后的中间结构,后续写入 MySQL 和 Milvus。"""
chunk_index: int
page_start: int
page_end: int
section_title: str | None
text: str
class DocumentChunkService:
    """Split textbook / guideline PDF text into page-aware chunks.

    Chunks follow paragraph boundaries where possible, remember the page
    range their text came from, and carry the most recently detected section
    heading as metadata.
    """

    # Paragraph boundaries: any run of newlines in the extracted page text.
    _PARAGRAPH_BREAK = re.compile(r"\n+")
    # Zero-width split point right after sentence-ending punctuation.
    _SENTENCE_END = re.compile(r"(?<=[。!?;;.!?])")
    # Heading shapes: "第三章"-style chapter/section/part headings,
    # "二、"-style enumerated headings, and "2.1 "-style numeric headings.
    _TITLE_PATTERNS = tuple(
        re.compile(pattern)
        for pattern in (
            r"^第[一二三四五六七八九十百0-9]+[章节篇]",
            r"^[一二三四五六七八九十]+[、..]",
            r"^\d+(\.\d+){0,3}\s+",
        )
    )
    # Paragraphs longer than this are never treated as headings.
    _MAX_TITLE_LENGTH = 80

    def __init__(self, chunk_size: int = 1100, chunk_overlap: int = 180) -> None:
        """Configure the chunk window.

        Args:
            chunk_size: Target maximum chunk length in characters.
            chunk_overlap: Number of trailing characters of a flushed chunk
                repeated at the start of the next chunk.

        Raises:
            ValueError: If ``chunk_size`` is not positive, or ``chunk_overlap``
                is negative or not smaller than ``chunk_size`` (which would
                collapse the window step to a single character).
        """
        if chunk_size <= 0:
            raise ValueError(f"chunk_size must be positive, got {chunk_size}")
        if not 0 <= chunk_overlap < chunk_size:
            raise ValueError(
                f"chunk_overlap must satisfy 0 <= chunk_overlap < chunk_size, "
                f"got {chunk_overlap} (chunk_size={chunk_size})"
            )
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap

    def build_chunks(self, pages: list[ParsedPdfPage]) -> list[ChunkDraft]:
        """Split parsed pages into chunk drafts, preserving page numbers.

        Paragraph pieces are accumulated until adding the next piece would
        push the chunk past ``chunk_size``. The buffer is then flushed as a
        draft, and the next chunk starts with the last ``chunk_overlap``
        characters of the flushed one. Because that overlap is prepended
        unconditionally, a chunk can exceed ``chunk_size`` by up to
        ``chunk_overlap + 1`` characters.

        Args:
            pages: Parsed pages in document order; only ``text`` and
                ``page_number`` are read.

        Returns:
            Non-empty drafts with contiguous ``chunk_index`` values starting
            at 0. ``page_start`` is the page where the draft's first character
            (including carried overlap) came from. ``section_title`` is the
            latest heading seen up to the draft's last piece.
        """
        drafts: list[ChunkDraft] = []
        buffer: list[str] = []
        buffer_pages: list[int] = []  # page number of each entry in `buffer`
        buffer_title: str | None = None  # heading in effect for `buffer`
        current_title: str | None = None

        for page in pages:
            for paragraph in self._split_paragraphs(page.text):
                detected_title = self._detect_title(paragraph)
                if detected_title:
                    current_title = detected_title
                for piece in self._split_long_text(paragraph):
                    candidate = "\n".join([*buffer, piece]).strip()
                    if buffer and len(candidate) > self.chunk_size:
                        # Flush with the title that applied to the buffered
                        # text, not a heading that starts the next chunk.
                        self._append_draft(drafts, buffer, buffer_pages, buffer_title)
                        tail = self._overlap_tail(buffer)
                        buffer_pages = (
                            [self._overlap_tail_page(buffer, buffer_pages, tail[0])]
                            if tail
                            else []
                        )
                        buffer = tail
                    buffer.append(piece)
                    buffer_pages.append(page.page_number)
                    buffer_title = current_title

        if buffer:
            self._append_draft(drafts, buffer, buffer_pages, buffer_title)
        return drafts

    def to_models(
        self,
        *,
        institution_id: int,
        document_id: int,
        collection_name: str,
        embedding_model: str,
        drafts: list[ChunkDraft],
    ) -> list[KbKnowledgeChunk]:
        """Convert chunk drafts into ORM rows.

        ``chunk_uid`` is built from the document id, chunk index and a prefix
        of the text's SHA-256. It is also used as the Milvus ``vector_id``.

        Returns:
            One ``KbKnowledgeChunk`` per draft, in the same order.
        """
        rows: list[KbKnowledgeChunk] = []
        for draft in drafts:
            chunk_hash = hashlib.sha256(draft.text.encode("utf-8")).hexdigest()
            chunk_uid = f"doc{document_id}_chunk{draft.chunk_index}_{chunk_hash[:12]}"
            rows.append(
                KbKnowledgeChunk(
                    institution_id=institution_id,
                    document_id=document_id,
                    chunk_uid=chunk_uid,
                    chunk_index=draft.chunk_index,
                    page_start=draft.page_start,
                    page_end=draft.page_end,
                    section_title=draft.section_title,
                    chunk_text=draft.text,
                    chunk_hash=chunk_hash,
                    # Rough heuristic: one token per two characters, at least 1.
                    token_count=max(1, len(draft.text) // 2),
                    vector_id=chunk_uid,
                    collection_name=collection_name,
                    embedding_model=embedding_model,
                    # A fresh dict per row so rows never share mutable metadata.
                    metadata_={"chunking": "page_semantic_window", "chunk_size": self.chunk_size, "overlap": self.chunk_overlap},
                )
            )
        return rows

    def _append_draft(
        self,
        drafts: list[ChunkDraft],
        buffer: list[str],
        buffer_pages: list[int],
        title: str | None,
    ) -> None:
        """Append the buffered pieces to ``drafts`` as one draft, skipping blank text."""
        text = "\n".join(buffer).strip()
        if not text:
            # Skipping here (rather than filtering afterwards) keeps
            # chunk_index contiguous.
            return
        drafts.append(
            ChunkDraft(
                chunk_index=len(drafts),
                page_start=buffer_pages[0],
                page_end=buffer_pages[-1],
                section_title=title,
                text=text,
            )
        )

    def _split_paragraphs(self, text: str) -> list[str]:
        """Split page text on newlines into stripped, non-empty paragraphs."""
        return [part.strip() for part in self._PARAGRAPH_BREAK.split(text) if part.strip()]

    def _split_long_text(self, text: str) -> list[str]:
        """Break an over-long paragraph into pieces of at most ``chunk_size`` characters.

        The paragraph is first split after sentence-ending punctuation, and
        sentences are packed into pieces. Consecutive pieces share up to
        ``chunk_overlap`` characters. Any piece still longer than
        ``chunk_size`` is cut into overlapping fixed-size character windows.
        """
        if len(text) <= self.chunk_size:
            return [text]

        pieces: list[str] = []
        current = ""
        carried = 0  # length of the overlap prefix copied from the previous piece
        for sentence in self._SENTENCE_END.split(text):
            if not sentence:
                # split() yields "" after trailing punctuation; it must not
                # trigger a flush that leaves only duplicated overlap behind.
                continue
            # Only flush when `current` holds new text beyond the carried
            # overlap; otherwise the piece would be pure duplicate.
            if len(current) > carried and len(current) + len(sentence) > self.chunk_size:
                stripped = current.strip()
                if stripped:
                    pieces.append(stripped)
                current = current[-self.chunk_overlap :] if self.chunk_overlap else ""
                carried = len(current)
            current += sentence
        if current[carried:].strip():
            pieces.append(current.strip())

        final: list[str] = []
        step = max(1, self.chunk_size - self.chunk_overlap)
        for piece in pieces:
            if len(piece) <= self.chunk_size:
                final.append(piece)
                continue
            start = 0
            while True:
                final.append(piece[start : start + self.chunk_size])
                if start + self.chunk_size >= len(piece):
                    # This window reached the end; a further window would be
                    # fully contained in this one.
                    break
                start += step
        return final

    def _overlap_tail(self, buffer: list[str]) -> list[str]:
        """Return the last ``chunk_overlap`` characters of the buffer as a one-item list.

        Repeating this tail at the start of the next chunk helps retrieval of
        content that straddles a chunk boundary. Returns ``[]`` when overlap
        is disabled or the buffer text is blank.
        """
        if not self.chunk_overlap:
            return []
        text = "\n".join(buffer).strip()
        tail = text[-self.chunk_overlap :]
        return [tail] if tail else []

    @staticmethod
    def _overlap_tail_page(buffer: list[str], buffer_pages: list[int], tail: str) -> int:
        """Return the page number of the buffered piece in which ``tail`` begins."""
        raw = "\n".join(buffer)
        # `tail` is a suffix of raw.strip(), so it ends exactly where
        # raw.rstrip() ends.
        tail_start = len(raw.rstrip()) - len(tail)
        offset = 0
        for piece, page_number in zip(buffer, buffer_pages):
            end = offset + len(piece)
            if tail_start < end:
                return page_number
            offset = end + 1  # skip the "\n" separator
        return buffer_pages[-1]

    def _detect_title(self, paragraph: str) -> str | None:
        """Return the paragraph if it looks like a chapter/section heading, else None."""
        compact = paragraph.strip()
        if len(compact) > self._MAX_TITLE_LENGTH:
            return None
        if any(pattern.match(compact) for pattern in self._TITLE_PATTERNS):
            return compact
        return None