89 lines
3.0 KiB
Python
89 lines
3.0 KiB
Python
import re
|
|
from typing import List, Optional, Any
|
|
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
|
import logging
|
|
|
|
# Module-level logger named after this module.
# NOTE(review): not referenced by any code in this module as shown.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def _split_text_with_regex_from_end(
|
|
text: str, separator: str, keep_separator: bool
|
|
) -> List[str]:
|
|
# Now that we have the separator, split the text
|
|
if separator:
|
|
if keep_separator:
|
|
# The parentheses in the pattern keep the delimiters in the result.
|
|
_splits = re.split(f"({separator})", text)
|
|
splits = ["".join(i) for i in zip(_splits[0::2], _splits[1::2])]
|
|
if len(_splits) % 2 == 1:
|
|
splits += _splits[-1:]
|
|
# splits = [_splits[0]] + splits
|
|
else:
|
|
splits = re.split(separator, text)
|
|
else:
|
|
splits = list(text)
|
|
return [s for s in splits if s != ""]
|
|
|
|
|
|
class ChineseRecursiveParagraphSplitter(RecursiveCharacterTextSplitter):
    """Paragraph-level splitter built on LangChain's recursive splitter.

    Despite the name, ``_split_text`` performs a single, non-recursive split:
    it picks the first separator that occurs in the text, splits on it, and
    cleans up the pieces. It never descends into finer separators and never
    merges pieces, so ``chunk_size`` / ``chunk_overlap`` passed through
    ``**kwargs`` do not affect its output.
    """

    def __init__(
        self,
        separators: Optional[List[str]] = None,
        keep_separator: bool = True,
        is_separator_regex: bool = True,
        **kwargs: Any,
    ) -> None:
        """Create a new paragraph splitter.

        Args:
            separators: Candidate separators tried in order by
                ``_split_text``. ``None`` or an empty list selects the
                default: blank line (two LFs), LF, CRLF, CR.
            keep_separator: Passed to the base-class constructor;
                ``_split_text`` reads it back as ``self._keep_separator`` and,
                when true, keeps each separator on the end of the preceding
                piece (before the piece is stripped).
            is_separator_regex: When true, separators are used as regular
                expressions; when false they are escaped with ``re.escape``.
            **kwargs: Forwarded unchanged to
                ``RecursiveCharacterTextSplitter.__init__``.
        """
        super().__init__(keep_separator=keep_separator, **kwargs)
        # NOTE(review): LF is tried before CRLF, so CRLF text is split on the
        # bare LF; the leftover CR is removed by strip() in _split_text.
        self._separators = separators or ["\n\n", "\n", "\r\n", "\r"]
        # Set here rather than passed to super().__init__; _split_text reads it.
        self._is_separator_regex = is_separator_regex

    def _pattern_for(self, separator: str) -> str:
        """Return ``separator`` as a regex, escaping it unless it already is one."""
        return separator if self._is_separator_regex else re.escape(separator)

    def _split_text(self, text: str, separators: List[str]) -> List[str]:
        """Split ``text`` on the first separator from ``separators`` found in it.

        The first entry whose pattern matches somewhere in ``text`` is used.
        An empty-string entry stops the search and selects character-level
        splitting. If nothing matches, the last entry is used, which (having
        not matched) leaves the text as at most one piece.

        Each resulting piece is stripped, runs of two or more newline
        characters inside it are collapsed to one, and pieces that are empty
        after stripping are dropped. Finer separators are NOT applied to the
        pieces (no recursion).
        """
        separator = separators[-1]
        for candidate in separators:
            if candidate == "":
                separator = candidate
                break
            if re.search(self._pattern_for(candidate), text):
                separator = candidate
                break

        pieces = _split_text_with_regex_from_end(
            text, self._pattern_for(separator), self._keep_separator
        )
        # Strip each piece once, drop blanks, then collapse blank-line runs.
        stripped = (piece.strip() for piece in pieces)
        return [re.sub(r"\n{2,}", "\n", chunk) for chunk in stripped if chunk != ""]
|
|
|
|
|
|
if __name__ == "__main__":
    # Ad-hoc manual check: load documents with the project's
    # document_loaders.RapidOCRCSVLoader, split each one, and print the index
    # of every document followed by the length of each of its chunks.
    # Usage: python <this file> [path/to/file.csv]
    import sys

    text_splitter = ChineseRecursiveParagraphSplitter(
        keep_separator=True,
        is_separator_regex=True,
        # NOTE: chunk_size / chunk_overlap go to the base class; this
        # splitter's _split_text does not consult them.
        chunk_size=1,
        chunk_overlap=0,
    )

    # NOTE(review): this relative path is resolved against the current working
    # directory, not this file's location; run from the expected directory.
    sys.path.append('../../../GCY-RAG-LangChain-ChatChat/')
    import document_loaders  # project-local; importable only after the path tweak

    # Allow the CSV to be given on the command line; fall back to the
    # previously hard-coded location.
    filepath = sys.argv[1] if len(sys.argv) > 1 else "/home/work/project/test_result.csv"

    loader = document_loaders.RapidOCRCSVLoader(filepath, autodetect_encoding=True)
    docs = loader.load()

    for doc_index, doc in enumerate(docs):
        print(doc_index)
        chunks = text_splitter.split_text(doc.page_content)
        for idx, chunk in enumerate(chunks):
            print(f'///////////////////////// idx:{idx} //////////////////////////')
            print(len(chunk))
|