Files
gangyan/langchain-chat/text_splitter/MarkdownTextSplitter.py

172 lines
6.5 KiB
Python
Raw Normal View History

import re
from typing import List, Optional, Tuple

from pydantic import BaseModel, Field
from typing_extensions import Literal

from configs.kb_config import CHUNK_SIZE, OVERLAP_SIZE
class Document(BaseModel):
    """One text chunk emitted by ``MarkdownTextSplitter`` together with its metadata.

    NOTE(review): the field layout (page_content / metadata / type="Document")
    looks like it mirrors LangChain's ``Document``; confirm whether downstream
    code relies on that interchangeability.
    """

    # The chunk text.
    page_content: str
    # Per-chunk metadata; each instance gets its own fresh empty dict by default.
    metadata: dict = Field(default_factory=dict)
    # Constant discriminator; only the literal string "Document" is accepted.
    type: Literal["Document"] = "Document"
class MarkdownTextSplitter:
    """Split a Markdown document into heading-scoped, size-limited text chunks.

    Splitting happens in two passes:

    1. ``split_text_by_headers`` cuts the document at ATX headings
       (``#`` ... ``######``), ignoring ``#`` lines inside fenced code blocks.
    2. ``split_paragraphs`` packs each section's sentences into chunks of at
       most ``chunk_size`` characters, hard-cutting sentences longer than that.
    """

    # ATX heading: 1-6 '#' at column 0, whitespace, then the title text.
    _HEADER_RE = re.compile(r"^(#{1,6})\s+(.*)$")
    # Markdown image syntax such as ![alt](path).
    _IMAGE_RE = re.compile(r"!\[.*?\]\(.*?\)")
    # Sentence terminators (full-width and ASCII); the capture group keeps the
    # terminator in re.split's output so it can be re-attached to its sentence.
    _SENTENCE_END_RE = re.compile(r"([。！？!?])")
    # Markers that open/close a fenced code block.
    _FENCE_MARKERS = ("```", "~~~")

    def __init__(self, headers_to_split_on: Optional[List[str]] = None, **kwargs):
        """Create a splitter whose sizes come from ``configs.kb_config``.

        Args:
            headers_to_split_on: Heading markers; defaults to ``#`` ... ``####``.
                NOTE(review): this value is stored but not consulted —
                ``split_text_by_headers`` splits on every heading level 1-6.
            **kwargs: Accepted for call-site compatibility and ignored.
        """
        self.chunk_size = CHUNK_SIZE
        # NOTE(review): overlap_size is stored but no method applies an overlap.
        self.overlap_size = OVERLAP_SIZE
        self.headers_to_split_on = headers_to_split_on or ["#", "##", "###", "####"]

    def clean_text(self, text: str) -> str:
        """Flatten whitespace escapes to spaces and drop Markdown images.

        Newlines, tabs and the literal two-character sequence backslash-n all
        become spaces; image links like ``![](image_path)`` are removed. The
        result is stripped last, so removing a trailing image leaves no
        dangling space behind.
        """
        # Flatten first so an image split across lines is still one regex match.
        text = text.replace("\n", " ").replace("\t", " ").replace("\\n", " ")
        text = self._IMAGE_RE.sub("", text)
        return text.strip()

    def split_text_by_headers(self, markdown_document: str) -> List[Tuple[str, int, str]]:
        """Cut ``markdown_document`` into ``(header, level, content)`` sections.

        Each ATX heading line (levels 1-6) starts a new section. Its content is
        the following non-heading lines, each passed through ``clean_text`` and
        joined with newlines. Lines inside fenced code blocks are always
        content, so e.g. ``# install deps`` in a shell snippet is not a heading.
        Text before the first heading is kept as a section with header ``""``
        and level ``0`` when it contains anything besides whitespace.
        """
        sections: List[Tuple[str, int, str]] = []
        header: Optional[str] = None  # None until the first heading is seen
        level = 0
        lines: List[str] = []
        open_fence: Optional[str] = None  # marker of the fenced block we are in

        for line in markdown_document.split("\n"):
            marker = line.lstrip()[:3]
            if marker in self._FENCE_MARKERS:
                # A fence is closed only by the same marker that opened it.
                if open_fence is None:
                    open_fence = marker
                elif marker == open_fence:
                    open_fence = None

            match = None if open_fence is not None else self._HEADER_RE.match(line)
            if match is None:
                lines.append(self.clean_text(line))
                continue

            # A new heading closes the section collected so far.
            self._append_section(sections, header, level, lines)
            header = match.group(2).strip()
            level = len(match.group(1))  # '#' -> h1, '##' -> h2, ...
            lines = []

        self._append_section(sections, header, level, lines)
        return sections

    @staticmethod
    def _append_section(sections, header, level, lines) -> None:
        """Append one finished section, dropping a blank pre-heading preamble."""
        content = "\n".join(lines)
        if header is None:
            # Text before the first heading has no title; keep it only if non-blank.
            if content.strip():
                sections.append(("", 0, content))
        else:
            # Compare against None (not truthiness) so "# " headings are kept.
            sections.append((header, level, content))

    def split_paragraphs(self, content: str) -> List[str]:
        """Pack ``content`` into chunks of at most ``chunk_size`` characters.

        The text is split after sentence terminators (。！？!?); sentences are
        cleaned and joined with single spaces as long as the chunk stays within
        ``chunk_size``. A sentence that does not fit starts a new chunk; a
        sentence longer than ``chunk_size`` on its own is hard-cut.

        Raises:
            ValueError: if ``chunk_size`` is not positive (the hard-cut loop
                could otherwise never terminate).
        """
        limit = self.chunk_size
        if limit <= 0:
            raise ValueError(f"chunk_size must be a positive integer, got {limit!r}")

        # re.split with a capturing group yields [text, end, text, end, ..., text];
        # glue every text piece back to the terminator that follows it.
        pieces = self._SENTENCE_END_RE.split(content)
        sentences = [
            pieces[i].strip() + (pieces[i + 1] if i + 1 < len(pieces) else "")
            for i in range(0, len(pieces), 2)
        ]

        paragraphs: List[str] = []
        current = ""
        for sentence in sentences:
            sentence = self.clean_text(sentence)
            if not sentence:
                continue  # empty tail after the last terminator, etc.
            # Appending to a non-empty chunk also costs one separating space.
            needed = len(current) + 1 + len(sentence) if current else len(sentence)
            if needed <= limit:
                current = f"{current} {sentence}" if current else sentence
                continue
            # The sentence does not fit: emit the pending chunk so it is not lost.
            if current:
                paragraphs.append(current)
            # An over-long single sentence is hard-cut into chunk_size pieces.
            while len(sentence) > limit:
                paragraphs.append(sentence[:limit])
                sentence = sentence[limit:]
            current = sentence

        if current:
            paragraphs.append(current)
        return paragraphs

    def split_documents(self, sections: List[Tuple[str, int, str]], doc_source) -> List["Document"]:
        """Turn ``(header, level, content)`` sections into ``Document`` chunks.

        Every chunk's metadata holds ``source`` and ``header``; sections with a
        heading level 1-6 additionally get an ``h<level>`` key holding the title.
        Sections whose content yields no chunks produce no Documents.
        """
        final_splits: List["Document"] = []
        for header, header_level, content in sections:
            metadata = {"source": doc_source, "header": header}
            if 1 <= header_level <= 6:
                metadata[f"h{header_level}"] = header
            for paragraph in self.split_paragraphs(content):
                # Each Document gets its own copy so mutating one chunk's
                # metadata cannot leak into its siblings.
                final_splits.append(Document(page_content=paragraph, metadata=dict(metadata)))
        return final_splits

    def split_markdown_text(self, markdown_document: str, doc_source: str) -> List["Document"]:
        """Split ``markdown_document`` into chunk Documents tagged with ``doc_source``."""
        sections = self.split_text_by_headers(markdown_document)
        return self.split_documents(sections, doc_source)
# Example usage
if __name__ == "__main__":
    sample_markdown = """
# 标题 1
QQQ
## 标题 2
WWW
### 标题 3
EEE
"""
    # An empty source label is enough for a local demo run.
    for chunk in MarkdownTextSplitter().split_markdown_text(sample_markdown, ""):
        print(f"Header: {chunk.metadata}, Content: {chunk.page_content}")