172 lines
6.5 KiB
Python
172 lines
6.5 KiB
Python
|
|
import re
|
|||
|
|
from typing import List
|
|||
|
|
from pydantic import BaseModel, Field
|
|||
|
|
from typing_extensions import Literal
|
|||
|
|
|
|||
|
|
from configs.kb_config import CHUNK_SIZE, OVERLAP_SIZE
|
|||
|
|
|
|||
|
|
|
|||
|
|
class Document(BaseModel):
    """A chunk of text together with the metadata that describes it."""

    # Text of the chunk.
    page_content: str

    # Free-form metadata about the chunk; `default_factory=dict` gives every
    # instance its own empty dict when none is supplied.
    metadata: dict = Field(default_factory=dict)

    # Fixed tag: the only accepted value is the string "Document".
    type: Literal["Document"] = "Document"
|
|||
|
|
|
|||
|
|
|
|||
|
|
class MarkdownTextSplitter:
    """Split a Markdown document into header-scoped chunks of bounded length.

    The document is first cut into sections at Markdown ATX headers
    (``#`` .. ``######``); each section body is then cut into chunks of at
    most ``chunk_size`` characters, preferably at sentence boundaries.
    """

    def __init__(self, headers_to_split_on: List[str] = None, **kwargs):
        """Create a splitter configured from ``configs.kb_config``.

        Args:
            headers_to_split_on: Header markers to remember on the instance.
                Defaults to ``["#", "##", "###", "####"]``. The value is only
                stored; ``split_text_by_headers`` matches levels 1-6
                regardless of it.
            **kwargs: Accepted for call-site compatibility and ignored.
        """
        self.chunk_size = CHUNK_SIZE
        # Stored for callers; no method of this class reads it.
        self.overlap_size = OVERLAP_SIZE
        self.headers_to_split_on = headers_to_split_on or ["#", "##", "###", "####"]

    def clean_text(self, text: str) -> str:
        """Normalise whitespace control characters and drop Markdown images.

        Real newlines and tabs, as well as the literal two-character sequence
        backslash + ``n``, are each replaced by a single space and the result
        is stripped. Afterwards every ``![alt](url)`` image reference is
        removed (which may leave inner or trailing spaces behind).

        Args:
            text: Raw text fragment.

        Returns:
            The cleaned fragment.
        """
        flattened = (
            text.replace("\n", " ")
            .replace("\t", " ")
            .replace("\\n", " ")
            .strip()
        )
        # Non-greedy so that several images on one line are removed one by one.
        return re.sub(r'!\[.*?\]\(.*?\)', '', flattened)

    def split_text_by_headers(self, markdown_document: str):
        """Cut a Markdown document into sections at its headers.

        Args:
            markdown_document: Full Markdown text.

        Returns:
            A list of ``(header, level, content)`` tuples in document order.
            ``header`` is the header text without the ``#`` markers, ``level``
            is the number of ``#`` characters (1-6) and ``content`` is the
            cleaned body lines joined with ``"\\n"``. Text that appears before
            the first header (or a document without any header) is returned
            as a section with ``header == ""`` and ``level == 0`` instead of
            being discarded; such a section is only emitted when it contains
            non-blank text.
        """
        header_re = re.compile(r"^(#{1,6})\s+(.*)$")  # any header level, # .. ######
        sections = []
        header = None  # None until the first header line has been seen
        level = 0
        body = []

        def flush():
            # `is not None` (not truthiness) so a header with empty text
            # still keeps its section.
            if header is not None:
                sections.append((header, level, "\n".join(body)))
            elif any(line.strip() for line in body):
                # Preamble before the first header: keep it, level 0.
                sections.append(("", 0, "\n".join(body)))

        for line in markdown_document.split("\n"):
            found = header_re.match(line)
            if found is None:
                body.append(self.clean_text(line))
                continue
            # A new header closes the section collected so far.
            flush()
            header = found.group(2).strip()
            level = len(found.group(1))
            body = []

        flush()  # last section
        return sections

    def split_paragraphs(self, content: str) -> List[str]:
        """Cut text into chunks of at most ``chunk_size`` characters.

        The text is split into sentences at ``。``, ``!``, ``?`` and their
        full-width forms (the terminator stays attached to its sentence).
        Consecutive sentences are packed into one chunk, separated by a single
        space, as long as the chunk stays within ``chunk_size``. A single
        sentence longer than ``chunk_size`` is hard-cut into ``chunk_size``
        sized pieces.

        Args:
            content: Section body to split.

        Returns:
            The chunks in original text order; empty list for blank input.

        Raises:
            ValueError: If ``chunk_size`` is smaller than 1 (hard-cutting
                could never make progress).
        """
        limit = self.chunk_size
        if limit < 1:
            raise ValueError(f"chunk_size must be >= 1, got {limit}")

        # The capturing group keeps the terminators, so the result alternates
        # body, terminator, body, ... and always ends with a body.
        # \uff01 / \uff1f are the full-width exclamation and question marks.
        pieces = re.split(r'([。!?\uff01\uff1f])', content)
        bodies = pieces[0::2]
        terminators = pieces[1::2] + ['']
        sentences = [b.strip() + t for b, t in zip(bodies, terminators)]

        paragraphs = []
        current = ""
        for sentence in sentences:
            sentence = self.clean_text(sentence)
            if not sentence:
                # e.g. the empty tail after the final terminator; appending it
                # would only add a trailing space to the chunk.
                continue

            if current and len(current) + 1 + len(sentence) <= limit:
                current = f"{current} {sentence}"
                continue

            # The sentence does not fit: emit what has been collected so far
            # before starting a new chunk, so no text is lost.
            if current:
                paragraphs.append(current)

            # Hard-cut a sentence that is too long for a chunk on its own.
            while len(sentence) > limit:
                paragraphs.append(sentence[:limit])
                sentence = sentence[limit:]
            current = sentence

        if current:
            paragraphs.append(current)
        return paragraphs

    def split_documents(self, sections: List[tuple], doc_source) -> "List[Document]":
        """Turn header sections into ``Document`` chunks.

        Args:
            sections: ``(header, level, content)`` tuples as produced by
                ``split_text_by_headers``.
            doc_source: Value stored under the ``"source"`` metadata key.

        Returns:
            One ``Document`` per chunk of every section. Each carries
            ``{"source": doc_source, "header": header}`` plus, for levels 1-6,
            the header again under the key ``"h<level>"``.
        """
        final_splits = []
        for header, header_level, content in sections:
            metadata = {"source": doc_source, "header": header}
            if 1 <= header_level <= 6:
                metadata[f"h{header_level}"] = header

            for paragraph in self.split_paragraphs(content):
                # Copy so that chunks of one section do not share a dict.
                final_splits.append(
                    Document(page_content=paragraph, metadata=dict(metadata))
                )
        return final_splits

    def split_markdown_text(self, markdown_document: str, doc_source: str) -> "List[Document]":
        """Split a Markdown document into ``Document`` chunks.

        Args:
            markdown_document: Full Markdown text.
            doc_source: Value stored under each chunk's ``"source"`` key.

        Returns:
            The chunks of all sections, in document order.
        """
        sections = self.split_text_by_headers(markdown_document)
        return self.split_documents(sections, doc_source)
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Example usage
if __name__ == "__main__":
    markdown_text = """
# 标题 1
QQQ

## 标题 2
WWW

### 标题 3
EEE
"""
    doc_source = ""

    # Split the sample document and print every resulting chunk.
    for split in MarkdownTextSplitter().split_markdown_text(markdown_text, doc_source):
        print(f"Header: {split.metadata}, Content: {split.page_content}")
|