172 lines
6.5 KiB
Python
172 lines
6.5 KiB
Python
import re
|
||
from typing import List
|
||
from pydantic import BaseModel, Field
|
||
from typing_extensions import Literal
|
||
|
||
from configs.kb_config import CHUNK_SIZE, OVERLAP_SIZE
|
||
|
||
|
||
class Document(BaseModel):
    """A chunk of text paired with an arbitrary metadata mapping."""

    # Text content of the chunk.
    page_content: str
    # Free-form metadata; defaults to a fresh empty dict for each instance.
    metadata: dict = Field(default_factory=dict)
    # Constant tag: the only accepted value is the string "Document".
    type: Literal["Document"] = "Document"
|
||
|
||
|
||
class MarkdownTextSplitter:
    """Split a Markdown document into header-scoped, size-bounded chunks.

    The document is first cut into sections at ATX headers (``#`` .. ``######``);
    each section body is then packed sentence by sentence into paragraphs of at
    most ``chunk_size`` characters, and every paragraph becomes a ``Document``.

    Attributes:
        chunk_size: Maximum number of characters per paragraph (from ``CHUNK_SIZE``).
        overlap_size: Taken from ``OVERLAP_SIZE``. Stored only; no method of this
            class currently applies an overlap.
        headers_to_split_on: Stored only; header detection in
            ``split_text_by_headers`` always recognises levels 1-6.
    """

    # ATX header: 1-6 '#' characters, whitespace, then the title text.
    _HEADER_PATTERN = re.compile(r"^(#{1,6})\s+(.*)$")
    # Markdown image link, e.g. ![alt](url).
    _IMAGE_PATTERN = re.compile(r'!\[.*?\]\(.*?\)')
    # Sentence terminators (captured so they can be re-attached after splitting).
    # Both ASCII and full-width '!' / '?' are recognised.
    _SENTENCE_END_PATTERN = re.compile(r'([。!?!?])')

    def __init__(self, headers_to_split_on: List[str] = None, **kwargs):
        """Create a splitter configured from the project-level size constants.

        Args:
            headers_to_split_on: Header markers to remember; defaults to
                ``["#", "##", "###", "####"]`` when omitted or empty.
            **kwargs: Accepted for call-site compatibility and ignored.
        """
        self.chunk_size = CHUNK_SIZE
        self.overlap_size = OVERLAP_SIZE
        self.headers_to_split_on = headers_to_split_on or ["#", "##", "###", "####"]

    def clean_text(self, text: str) -> str:
        r"""Flatten whitespace escapes and remove Markdown image links.

        Real newlines and tabs, as well as the literal two-character sequence
        ``\n`` (backslash followed by ``n``), are each replaced by one space.
        Image links such as ``![alt](url)`` are deleted.

        Args:
            text: Raw text fragment.

        Returns:
            The cleaned fragment with surrounding whitespace removed.
        """
        flattened = text.replace("\n", " ").replace("\t", " ").replace("\\n", " ")
        without_images = self._IMAGE_PATTERN.sub('', flattened)
        # Strip last so that whitespace left behind by a removed image is trimmed too.
        return without_images.strip()

    @staticmethod
    def _append_section(sections: list, header, level: int, lines: List[str]) -> None:
        """Append ``(header, level, body)`` to ``sections`` if there is anything to keep."""
        body = "\n".join(lines)
        if header is None:
            # Text before the first header: keep it (under an empty title at
            # level 0) only when it actually contains something.
            if body.strip():
                sections.append(("", 0, body))
            return
        sections.append((header, level, body))

    def split_text_by_headers(self, markdown_document: str) -> List[tuple]:
        """Cut a Markdown document into sections at its headers.

        Every non-header line is passed through ``clean_text`` and the lines of
        one section are joined with ``"\\n"``.

        Args:
            markdown_document: Full Markdown text.

        Returns:
            A list of ``(header, header_level, content)`` tuples in document
            order. ``header_level`` is the number of ``#`` characters (1-6).
            Non-blank text that precedes the first header is returned as a
            section with header ``""`` and level ``0`` rather than discarded.
        """
        sections = []
        current_header = None  # None means "no header seen yet"
        current_level = 0
        current_lines = []

        for line in markdown_document.split("\n"):
            match = self._HEADER_PATTERN.match(line)
            if match is None:
                current_lines.append(self.clean_text(line))
                continue

            # A new header closes whatever section was being collected.
            self._append_section(sections, current_header, current_level, current_lines)
            current_header = match.group(2).strip()
            current_level = len(match.group(1))
            current_lines = []

        self._append_section(sections, current_header, current_level, current_lines)
        return sections

    def split_paragraphs(self, content: str) -> List[str]:
        """Pack the sentences of ``content`` into paragraphs of bounded length.

        The text is split after each sentence terminator, every sentence is
        cleaned with ``clean_text``, and sentences are joined with single spaces
        for as long as the paragraph stays within ``chunk_size`` characters.
        A single sentence longer than ``chunk_size`` is cut into
        ``chunk_size``-sized pieces.

        Args:
            content: Section body to split.

        Returns:
            The paragraphs in original text order; no text is dropped.

        Raises:
            ValueError: If ``chunk_size`` is not positive (the hard-split loop
                could otherwise never terminate).
        """
        limit = self.chunk_size
        if limit <= 0:
            raise ValueError(f"chunk_size must be positive, got {limit!r}")

        # With a capturing group, re.split alternates text / terminator / text ...
        pieces = self._SENTENCE_END_PATTERN.split(content)

        paragraphs = []
        current = ""
        for index in range(0, len(pieces), 2):
            terminator = pieces[index + 1] if index + 1 < len(pieces) else ""
            sentence = self.clean_text(pieces[index].strip() + terminator)
            if not sentence:
                # e.g. the empty tail re.split yields after a final terminator
                continue

            separator = " " if current else ""
            if len(current) + len(separator) + len(sentence) <= limit:
                current += separator + sentence
                continue

            # The sentence does not fit: emit the pending paragraph first so
            # that nothing is lost and the output order follows the text.
            if current:
                paragraphs.append(current)

            # Hard-split a sentence that is longer than a whole chunk.
            while len(sentence) > limit:
                paragraphs.append(sentence[:limit])
                sentence = sentence[limit:]

            current = sentence

        if current:
            paragraphs.append(current)

        return paragraphs

    def split_documents(self, sections: List[tuple], doc_source) -> "List[Document]":
        """Turn header sections into ``Document`` objects, one per paragraph.

        Args:
            sections: ``(header, header_level, content)`` tuples as returned by
                ``split_text_by_headers``.
            doc_source: Value stored under the ``"source"`` metadata key.

        Returns:
            Documents in section order. Each carries ``"source"`` and
            ``"header"`` metadata, plus an ``"h<level>"`` key holding the header
            when the level is between 1 and 6.
        """
        final_splits = []

        for header, header_level, content in sections:
            metadata = {"source": doc_source, "header": header}
            if 1 <= header_level <= 6:
                metadata[f"h{header_level}"] = header

            for paragraph in self.split_paragraphs(content):
                # Pass a copy so documents never share one mutable dict.
                final_splits.append(
                    Document(page_content=paragraph, metadata=dict(metadata))
                )

        return final_splits

    def split_markdown_text(self, markdown_document: str, doc_source: str) -> "List[Document]":
        """Split a Markdown document into ``Document`` chunks.

        Args:
            markdown_document: Full Markdown text.
            doc_source: Value stored under each document's ``"source"`` metadata key.

        Returns:
            The documents produced by sectioning on headers and then splitting
            each section into paragraphs.
        """
        sections = self.split_text_by_headers(markdown_document)
        return self.split_documents(sections, doc_source)
|
||
|
||
|
||
# Usage example
if __name__ == "__main__":
    doc_source = ""
    # Headers start at column 0 so the header pattern can match them.
    markdown_text = """
# 标题 1
QQQ

## 标题 2
WWW

### 标题 3
EEE
"""
    # Split the sample and print the metadata and text of every chunk.
    for chunk in MarkdownTextSplitter().split_markdown_text(markdown_text, doc_source):
        print(f"Header: {chunk.metadata}, Content: {chunk.page_content}")
|