Files
gangyan/langchain-chat/text_splitter/MarkdownTextSplitter.py

172 lines
6.5 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import re
from typing import List, Optional, Tuple

from pydantic import BaseModel, Field
from typing_extensions import Literal

from configs.kb_config import CHUNK_SIZE, OVERLAP_SIZE
class Document(BaseModel):
    """A single text chunk produced by the splitter, with its metadata.

    Attributes:
        page_content: The chunk text (required).
        metadata: Arbitrary per-chunk metadata; each instance gets its own
            fresh dict when none is supplied.
        type: Constant tag whose only allowed value is ``"Document"``.
    """

    page_content: str = Field(...)
    metadata: dict = Field(default_factory=dict)
    type: Literal["Document"] = Field(default="Document")
class MarkdownTextSplitter:
    """Split a Markdown document into ``Document`` chunks.

    The document is first cut into sections at ATX headings (``#`` through
    ``######``). Each section's body is then packed into chunks of at most
    ``chunk_size`` characters, breaking at sentence ends where possible.
    """

    # ATX heading: 1-6 '#' characters, whitespace, then the heading text.
    _HEADER_RE = re.compile(r"^(#{1,6})\s+(.*)$")
    # Opening/closing line of a fenced code block (``` or ~~~, up to 3 spaces of indent).
    _FENCE_RE = re.compile(r"^ {0,3}(`{3,}|~{3,})")
    # Sentence boundaries: CJK full stop, ASCII and full-width '!' / '?', and newline.
    # The delimiter is captured so it can be glued back onto its sentence.
    _SENTENCE_END_RE = re.compile("([\u3002!?\uff01\uff1f]|\n)")
    # Markdown image syntax, e.g. ![alt](path).
    _IMAGE_RE = re.compile(r"!\[.*?\]\(.*?\)")

    def __init__(self, headers_to_split_on: Optional[List[str]] = None, *,
                 chunk_size: Optional[int] = None, **kwargs):
        """Configure the splitter.

        Args:
            headers_to_split_on: Heading markers to record; defaults to
                ``["#", "##", "###", "####"]``. Stored only: splitting is done
                at every ATX level 1-6 regardless of this list.
            chunk_size: Maximum characters per chunk; defaults to the
                configured ``CHUNK_SIZE``.
            **kwargs: Accepted for caller compatibility and ignored.
        """
        self.chunk_size = CHUNK_SIZE if chunk_size is None else chunk_size
        # Read from config; not used by any method of this class.
        self.overlap_size = OVERLAP_SIZE
        self.headers_to_split_on = headers_to_split_on or ["#", "##", "###", "####"]

    def clean_text(self, text: str) -> str:
        r"""Flatten whitespace escapes in *text* and drop Markdown images.

        Real newlines and tabs, and the literal two-character sequence ``\n``,
        each become a single space. Image links such as ``![](image_path)``
        are removed. The result has no leading or trailing whitespace.
        """
        text = text.replace("\n", " ").replace("\t", " ").replace("\\n", " ")
        text = self._IMAGE_RE.sub("", text)
        # Strip last so whitespace left next to a removed image is trimmed too.
        return text.strip()

    @staticmethod
    def _build_section(header: Optional[str], level: int,
                       lines: List[str]) -> Optional[Tuple[str, int, str]]:
        """Return the section tuple for *lines*, or None if it should be skipped."""
        content = "\n".join(lines)
        if header is not None:
            return (header, level, content)
        # Text before the first heading: keep it unless it is only whitespace.
        if content.strip():
            return ("", 0, content)
        return None

    def split_text_by_headers(self, markdown_document: str) -> List[Tuple[str, int, str]]:
        """Cut *markdown_document* into sections at ATX headings.

        Returns:
            ``(header_text, header_level, content)`` tuples in document order.
            ``content`` holds the section's non-heading lines, each passed
            through :meth:`clean_text` and joined with newlines. Text before
            the first heading, or the whole document if it has no headings, is
            returned as a section with header ``""`` and level ``0`` when it
            contains any non-whitespace. Lines inside fenced code blocks are
            never treated as headings.
        """
        sections: List[Tuple[str, int, str]] = []
        header: Optional[str] = None  # None until the first heading is seen
        level = 0
        body: List[str] = []
        open_fence: Optional[str] = None  # marker of the currently open code fence

        for line in markdown_document.split("\n"):
            fence = self._FENCE_RE.match(line)
            if fence:
                marker = fence.group(1)
                if open_fence is None:
                    open_fence = marker
                elif marker[0] == open_fence[0] and len(marker) >= len(open_fence):
                    # A closing fence uses the same character and is at least as long.
                    open_fence = None
            elif open_fence is None:
                match = self._HEADER_RE.match(line)
                if match:
                    section = self._build_section(header, level, body)
                    if section is not None:
                        sections.append(section)
                    header = match.group(2).strip()
                    level = len(match.group(1))  # '#' -> 1, '##' -> 2, ...
                    body = []
                    continue
            body.append(self.clean_text(line))

        section = self._build_section(header, level, body)
        if section is not None:
            sections.append(section)
        return sections

    def split_paragraphs(self, content: str) -> List[str]:
        """Pack *content* into chunks of at most ``self.chunk_size`` characters.

        The text is split into sentences after each CJK full stop, each ASCII
        or full-width ``!``/``?``, and each newline. Every sentence is cleaned
        with :meth:`clean_text`, and empty ones are skipped. Sentences are
        joined greedily with single spaces while the chunk fits. When the
        next sentence does not fit, the current chunk is emitted and a new one
        is started. A sentence longer than ``chunk_size`` is hard-cut into
        ``chunk_size``-character pieces, and its remainder starts the next
        chunk.

        Raises:
            ValueError: If ``self.chunk_size`` is not positive; hard-cutting
                would otherwise never terminate.
        """
        limit = self.chunk_size
        if limit <= 0:
            raise ValueError(f"chunk_size must be positive, got {limit}")

        pieces = self._SENTENCE_END_RE.split(content)
        # re.split with a capturing group alternates text, delimiter, text, ...;
        # glue each delimiter back onto the text that precedes it.
        sentences = [
            pieces[i].strip() + (pieces[i + 1] if i + 1 < len(pieces) else "")
            for i in range(0, len(pieces), 2)
        ]

        paragraphs: List[str] = []
        current = ""
        for sentence in sentences:
            sentence = self.clean_text(sentence)
            if not sentence:
                # Empty fragments (e.g. after a trailing delimiter) would only add stray spaces.
                continue
            needed = len(current) + len(sentence) + (1 if current else 0)
            if needed <= limit:
                current = f"{current} {sentence}" if current else sentence
                continue
            # The sentence does not fit: emit the finished chunk so no text is lost.
            if current:
                paragraphs.append(current)
            while len(sentence) > limit:
                paragraphs.append(sentence[:limit])
                sentence = sentence[limit:]
            current = sentence

        if current:
            paragraphs.append(current)
        return paragraphs

    def split_documents(self, sections: List[Tuple[str, int, str]], doc_source) -> "List[Document]":
        """Turn header sections into ``Document`` chunks.

        Args:
            sections: ``(header_text, header_level, content)`` tuples as
                returned by :meth:`split_text_by_headers`.
            doc_source: Stored as ``metadata["source"]`` on every chunk.

        Returns:
            One ``Document`` per chunk from :meth:`split_paragraphs`. Each
            chunk's metadata holds ``source`` and ``header``. For heading
            levels 1-6 it also has an ``h<level>`` key set to the header
            text. Only the section's own heading is recorded, not its
            parents.
        """
        final_splits = []
        for header, header_level, content in sections:
            metadata = {"source": doc_source, "header": header}
            if 1 <= header_level <= 6:
                metadata[f"h{header_level}"] = header
            for paragraph in self.split_paragraphs(content):
                # Each chunk gets its own dict so editing one chunk's metadata
                # cannot leak into its siblings.
                final_splits.append(Document(page_content=paragraph, metadata=dict(metadata)))
        return final_splits

    def split_markdown_text(self, markdown_document: str, doc_source: str) -> "List[Document]":
        """Split *markdown_document* into ``Document`` chunks tagged with *doc_source*."""
        sections = self.split_text_by_headers(markdown_document)
        return self.split_documents(sections, doc_source)
# 示例使用
if __name__ == "__main__":
    # Demo: split a small three-level document and print every chunk with its metadata.
    sample_markdown = """
# 标题 1
QQQ
## 标题 2
WWW
### 标题 3
EEE
"""
    demo_splitter = MarkdownTextSplitter()
    for chunk in demo_splitter.split_markdown_text(sample_markdown, ""):
        print(f"Header: {chunk.metadata}, Content: {chunk.page_content}")