# Files
# gangyan/langchain-chat/text_splitter/chinese_recursive_paragraph_splitter.py
#
# 89 lines
# 3.0 KiB
# Python

import re
from typing import List, Optional, Any
from langchain.text_splitter import RecursiveCharacterTextSplitter
import logging
logger = logging.getLogger(__name__)
def _split_text_with_regex_from_end(
text: str, separator: str, keep_separator: bool
) -> List[str]:
# Now that we have the separator, split the text
if separator:
if keep_separator:
# The parentheses in the pattern keep the delimiters in the result.
_splits = re.split(f"({separator})", text)
splits = ["".join(i) for i in zip(_splits[0::2], _splits[1::2])]
if len(_splits) % 2 == 1:
splits += _splits[-1:]
# splits = [_splits[0]] + splits
else:
splits = re.split(separator, text)
else:
splits = list(text)
return [s for s in splits if s != ""]
class ChineseRecursiveParagraphSplitter(RecursiveCharacterTextSplitter):
    """Paragraph-level splitter driven by a prioritized list of separators.

    The text is split once, on the first separator in the list that occurs
    in it; the resulting pieces are trimmed and returned as-is. The
    ``_split_text`` override below does not read any chunk-size or overlap
    setting.
    """

    def __init__(
        self,
        separators: Optional[List[str]] = None,
        keep_separator: bool = True,
        is_separator_regex: bool = True,
        **kwargs: Any,
    ) -> None:
        """Configure the splitter.

        Args:
            separators: Candidate separators in priority order. Falls back
                to blank-line / newline / carriage-return variants when
                omitted or empty.
            keep_separator: Forwarded to the base class constructor.
            is_separator_regex: Whether separators are regex patterns
                (``True``) or literal strings to be escaped (``False``).
            **kwargs: Passed through unchanged to the base class constructor.
        """
        super().__init__(keep_separator=keep_separator, **kwargs)
        fallback_separators = ["\n\n", "\n", '\r\n', '\r']
        self._separators = separators or fallback_separators
        self._is_separator_regex = is_separator_regex

    def _as_pattern(self, raw_separator: str) -> str:
        """Return ``raw_separator`` as a regex pattern, escaping it if literal."""
        if self._is_separator_regex:
            return raw_separator
        return re.escape(raw_separator)

    def _split_text(self, text: str, separators: List[str]) -> List[str]:
        """Split ``text`` on the first applicable separator and tidy the pieces.

        Args:
            text: The text to split.
            separators: Candidate separators in priority order. The last one
                is used when none of them occurs in ``text``.

        Returns:
            The stripped, non-empty pieces, with every run of two or more
            newlines inside a piece collapsed to a single newline.
        """
        # Default to the lowest-priority separator, then look for a better one.
        chosen = separators[-1]
        for candidate in separators:
            # An empty separator always applies; otherwise it must occur in text.
            if candidate == "" or re.search(self._as_pattern(candidate), text):
                chosen = candidate
                break

        # NOTE(review): ``self._keep_separator`` is presumably set by the base
        # class constructor from the ``keep_separator`` argument — confirm.
        pieces = _split_text_with_regex_from_end(
            text, self._as_pattern(chosen), self._keep_separator
        )

        paragraphs = []
        for piece in pieces:
            trimmed = piece.strip()
            if trimmed != "":
                paragraphs.append(re.sub(r"\n{2,}", "\n", trimmed))
        return paragraphs
if __name__ == "__main__":
    # Manual smoke run: load a CSV through the project's loader, split every
    # document, and print each chunk's index and length.
    text_splitter = ChineseRecursiveParagraphSplitter(
        keep_separator=True,
        is_separator_regex=True,
        chunk_size=1,
        chunk_overlap=0,
    )

    import sys

    # The path must be extended before ``document_loaders`` is imported.
    # NOTE(review): presumably that module lives in this sibling checkout — confirm.
    sys.path.append('../../../GCY-RAG-LangChain-ChatChat/')
    filepath = "/home/work/project/test_result.csv"

    import document_loaders

    documents = document_loaders.RapidOCRCSVLoader(
        filepath, autodetect_encoding=True
    ).load()

    for doc_number, document in enumerate(documents):
        print(doc_number)
        pieces = text_splitter.split_text(document.page_content)
        for idx, chunk in enumerate(pieces):
            print(f'///////////////////////// idx:{idx} //////////////////////////')
            print(len(chunk))