[全量] 初始化项目代码、配置、文档及Agent协同harness
This commit is contained in:
81
langchain-chat/document_loaders/FilteredCSVloader.py
Normal file
81
langchain-chat/document_loaders/FilteredCSVloader.py
Normal file
@@ -0,0 +1,81 @@
|
||||
## 指定制定列的csv文件加载器
|
||||
|
||||
from langchain.document_loaders import CSVLoader
|
||||
import csv
|
||||
from io import TextIOWrapper
|
||||
from typing import Dict, List, Optional
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.document_loaders.helpers import detect_file_encodings
|
||||
|
||||
|
||||
class FilteredCSVLoader(CSVLoader):
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str,
|
||||
columns_to_read: List[str],
|
||||
source_column: Optional[str] = None,
|
||||
metadata_columns: List[str] = [],
|
||||
csv_args: Optional[Dict] = None,
|
||||
encoding: Optional[str] = None,
|
||||
autodetect_encoding: bool = False,
|
||||
):
|
||||
super().__init__(
|
||||
file_path=file_path,
|
||||
source_column=source_column,
|
||||
metadata_columns=metadata_columns,
|
||||
csv_args=csv_args,
|
||||
encoding=encoding,
|
||||
autodetect_encoding=autodetect_encoding,
|
||||
)
|
||||
self.columns_to_read = columns_to_read
|
||||
|
||||
def load(self) -> List[Document]:
|
||||
"""Load data into document objects."""
|
||||
|
||||
docs = []
|
||||
try:
|
||||
with open(self.file_path, newline="", encoding=self.encoding) as csvfile:
|
||||
docs = self.__read_file(csvfile)
|
||||
except UnicodeDecodeError as e:
|
||||
if self.autodetect_encoding:
|
||||
detected_encodings = detect_file_encodings(self.file_path)
|
||||
for encoding in detected_encodings:
|
||||
try:
|
||||
with open(
|
||||
self.file_path, newline="", encoding=encoding.encoding
|
||||
) as csvfile:
|
||||
docs = self.__read_file(csvfile)
|
||||
break
|
||||
except UnicodeDecodeError:
|
||||
continue
|
||||
else:
|
||||
raise RuntimeError(f"Error loading {self.file_path}") from e
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Error loading {self.file_path}") from e
|
||||
|
||||
return docs
|
||||
|
||||
def __read_file(self, csvfile: TextIOWrapper) -> List[Document]:
|
||||
docs = []
|
||||
csv_reader = csv.DictReader(csvfile, **self.csv_args) # type: ignore
|
||||
for i, row in enumerate(csv_reader):
|
||||
if self.columns_to_read[0] in row:
|
||||
content = row[self.columns_to_read[0]]
|
||||
# Extract the source if available
|
||||
source = (
|
||||
row.get(self.source_column, None)
|
||||
if self.source_column is not None
|
||||
else self.file_path
|
||||
)
|
||||
metadata = {"source": source, "row": i}
|
||||
|
||||
for col in self.metadata_columns:
|
||||
if col in row:
|
||||
metadata[col] = row[col]
|
||||
|
||||
doc = Document(page_content=content, metadata=metadata)
|
||||
docs.append(doc)
|
||||
else:
|
||||
raise ValueError(f"Column '{self.columns_to_read[0]}' not found in CSV file.")
|
||||
|
||||
return docs
|
||||
Binary file not shown.
Binary file not shown.
8
langchain-chat/document_loaders/__init__.py
Normal file
8
langchain-chat/document_loaders/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
from .myimgloader import RapidOCRLoader
|
||||
from .mydocloader import RapidOCRDocLoader
|
||||
from .mypptloader import RapidOCRPPTLoader
|
||||
from .mycsvloader import RapidOCRCSVLoader
|
||||
from .GCYWordLoader import GCYWordLoader
|
||||
from .gycWordLoader import GCYWordLoader as GCYWordLoader2
|
||||
from .GCYHTMLLoader import GCYHTMLLoader
|
||||
from .myexcelloader import ExcelLoader
|
||||
204
langchain-chat/document_loaders/gycWordLoader.py
Normal file
204
langchain-chat/document_loaders/gycWordLoader.py
Normal file
@@ -0,0 +1,204 @@
|
||||
from typing import List, Optional, Union, Dict, Any
|
||||
from pathlib import Path
|
||||
from langchain_core.documents import Document
|
||||
from langchain_community.document_loaders.base import BaseLoader
|
||||
from docx import Document as DocxDocument
|
||||
import os
|
||||
import subprocess
|
||||
import logging
|
||||
import zipfile
|
||||
from lxml import etree as ET
|
||||
import re
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class GCYWordLoader(BaseLoader):
|
||||
"""用于加载和解析 Word 文档的自定义加载器,支持标题层级结构解析及XML级别内容(段落、表格、页眉页脚、批注、修订、目录)。"""
|
||||
|
||||
# WordprocessingML命名空间
|
||||
ns = {
|
||||
"w": "http://schemas.openxmlformats.org/wordprocessingml/2006/main"
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: Union[str, Path],
|
||||
output_dir: Optional[Union[str, Path]] = None,
|
||||
*,
|
||||
keep_doc_title: bool = True,
|
||||
start_with_title: bool = False,
|
||||
max_heading_level: int = 3,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
):
|
||||
self.file_path = str(file_path)
|
||||
self.output_dir = str(output_dir) if output_dir else os.path.dirname(self.file_path)
|
||||
self.keep_doc_title = keep_doc_title
|
||||
self.start_with_title = start_with_title
|
||||
self.max_heading_level = min(max(1, max_heading_level), 6)
|
||||
self.metadata = metadata or {}
|
||||
|
||||
# 临时解压目录
|
||||
self._work_dir = os.path.join(self.output_dir, '_pyc_work')
|
||||
self._doc_dir = os.path.join(self._work_dir, 'word')
|
||||
|
||||
# 验证
|
||||
if not os.path.isfile(self.file_path):
|
||||
raise FileNotFoundError(f"文件不存在: {self.file_path}")
|
||||
if not self.file_path.lower().endswith(('.doc', '.docx')):
|
||||
raise ValueError("仅支持 .doc 或 .docx 格式")
|
||||
|
||||
def load(self) -> List[Document]:
|
||||
try:
|
||||
# 预处理(.doc 转 .docx)
|
||||
processed_path = self._preprocess_document()
|
||||
|
||||
# 准备工作目录并解压
|
||||
self._prepare_work_directories()
|
||||
self._extract_docx(processed_path)
|
||||
|
||||
# 解析主文档 XML
|
||||
self._parse_document_xml()
|
||||
|
||||
# 构建文档片段列表
|
||||
docs: List[Document] = []
|
||||
base_meta = {"source": self.file_path, "file_name": os.path.basename(self.file_path), **self.metadata}
|
||||
|
||||
# 支持保留文档标题
|
||||
docx_core = DocxDocument(processed_path)
|
||||
if self.keep_doc_title and docx_core.core_properties.title:
|
||||
title = docx_core.core_properties.title
|
||||
if self.start_with_title:
|
||||
docs.append(Document(page_content=title, metadata={**base_meta, "heading": "Document Title"}))
|
||||
else:
|
||||
docs.append(Document(page_content=title, metadata={**base_meta, "heading": None}))
|
||||
|
||||
# 段落
|
||||
docs.extend(self._parse_paragraphs(base_meta))
|
||||
# 表格
|
||||
docs.extend(self._parse_tables(base_meta))
|
||||
# 页眉与页脚
|
||||
docs.extend(self._parse_headers_and_footers(base_meta))
|
||||
# 批注
|
||||
docs.extend(self._parse_comments(base_meta))
|
||||
# 修订
|
||||
docs.extend(self._parse_revisions(base_meta))
|
||||
# 目录项
|
||||
docs.extend(self._parse_toc(base_meta))
|
||||
|
||||
return docs
|
||||
except Exception as e:
|
||||
logger.error(f"文档加载失败: {self.file_path}", exc_info=True)
|
||||
raise RuntimeError(f"无法加载文档: {e}")
|
||||
finally:
|
||||
# 清理临时
|
||||
if os.path.exists(self._work_dir):
|
||||
try:
|
||||
import shutil; shutil.rmtree(self._work_dir)
|
||||
except:
|
||||
pass
|
||||
|
||||
def _preprocess_document(self) -> str:
|
||||
# .docx 直接返回
|
||||
if self.file_path.lower().endswith('.docx'):
|
||||
return self.file_path
|
||||
# .doc 转 .docx
|
||||
output_path = os.path.join(self.output_dir, Path(self.file_path).stem + '.docx')
|
||||
subprocess.run([
|
||||
'soffice', '--headless', '--convert-to', 'docx', '--outdir', self.output_dir, self.file_path
|
||||
], check=True, capture_output=True)
|
||||
if not os.path.exists(output_path):
|
||||
raise RuntimeError('文档格式转换失败')
|
||||
return output_path
|
||||
|
||||
def _prepare_work_directories(self):
|
||||
if os.path.exists(self._work_dir):
|
||||
import shutil; shutil.rmtree(self._work_dir)
|
||||
os.makedirs(self._doc_dir, exist_ok=True)
|
||||
|
||||
def _extract_docx(self, path: str):
|
||||
with zipfile.ZipFile(path, 'r') as z:
|
||||
z.extractall(self._work_dir)
|
||||
|
||||
def _parse_document_xml(self):
|
||||
xml_path = os.path.join(self._doc_dir, 'document.xml')
|
||||
parser = ET.XMLParser(remove_blank_text=True)
|
||||
self._doc_tree = ET.parse(xml_path, parser)
|
||||
self._doc_root = self._doc_tree.getroot()
|
||||
|
||||
def _get_text_from_runs(self, parent) -> str:
|
||||
parts = []
|
||||
for r in parent.findall('.//w:r', self.ns):
|
||||
for t in r.findall('w:t', self.ns):
|
||||
if t.text:
|
||||
parts.append(t.text)
|
||||
return ''.join(parts).strip()
|
||||
|
||||
def _parse_paragraphs(self, base_meta) -> List[Document]:
|
||||
docs = []
|
||||
for p in self._doc_root.findall('.//w:p', self.ns):
|
||||
text = self._get_text_from_runs(p)
|
||||
if not text:
|
||||
continue
|
||||
docs.append(Document(page_content=text, metadata={**base_meta, 'content_type':'paragraph'}))
|
||||
return docs
|
||||
|
||||
def _parse_tables(self, base_meta) -> List[Document]:
|
||||
docs = []
|
||||
for tbl in self._doc_root.findall('.//w:tbl', self.ns):
|
||||
# 简单按行为 \n 分段
|
||||
rows = []
|
||||
for row in tbl.findall('.//w:tr', self.ns):
|
||||
cells = []
|
||||
for cell in row.findall('.//w:tc', self.ns):
|
||||
cells.append(self._get_text_from_runs(cell))
|
||||
rows.append('|'.join(cells))
|
||||
content = '\n'.join(rows)
|
||||
docs.append(Document(page_content=content, metadata={**base_meta, 'content_type':'table'}))
|
||||
return docs
|
||||
|
||||
def _parse_headers_and_footers(self, base_meta) -> List[Document]:
|
||||
docs = []
|
||||
for part in ['header', 'footer']:
|
||||
for i in range(1,10):
|
||||
path = os.path.join(self._doc_dir, f'{part}{i}.xml')
|
||||
if not os.path.exists(path):
|
||||
continue
|
||||
tree = ET.parse(path)
|
||||
root = tree.getroot()
|
||||
for p in root.findall('.//w:p', self.ns):
|
||||
text = self._get_text_from_runs(p)
|
||||
if text:
|
||||
docs.append(Document(page_content=text, metadata={**base_meta, 'content_type':part}))
|
||||
return docs
|
||||
|
||||
def _parse_comments(self, base_meta) -> List[Document]:
|
||||
docs = []
|
||||
path = os.path.join(self._doc_dir, 'comments.xml')
|
||||
if os.path.exists(path):
|
||||
tree = ET.parse(path)
|
||||
root = tree.getroot()
|
||||
for c in root.findall('.//w:comment', self.ns):
|
||||
text = self._get_text_from_runs(c)
|
||||
if text:
|
||||
docs.append(Document(page_content=text, metadata={**base_meta, 'content_type':'comment'}))
|
||||
return docs
|
||||
|
||||
def _parse_revisions(self, base_meta) -> List[Document]:
|
||||
docs = []
|
||||
for tag in ['del','ins']:
|
||||
for el in self._doc_root.findall(f'.//w:{tag}', self.ns):
|
||||
text = self._get_text_from_runs(el)
|
||||
if text:
|
||||
docs.append(Document(page_content=text, metadata={**base_meta, 'content_type':'revision'}))
|
||||
return docs
|
||||
|
||||
def _parse_toc(self, base_meta) -> List[Document]:
|
||||
docs = []
|
||||
for p in self._doc_root.findall('.//w:p', self.ns):
|
||||
full = ''.join(t.text or '' for t in p.findall('w:t', self.ns)).strip()
|
||||
if len(full)>255: continue
|
||||
if p.find('.//w:tab', self.ns) is not None and full and full[-1].isdigit():
|
||||
# 简单拆分标题和页码
|
||||
parts = re.split(r'\t+', full)
|
||||
title = parts[0]
|
||||
docs.append(Document(page_content=title, metadata={**base_meta,'content_type':'toc_entry'}))
|
||||
return docs
|
||||
158
langchain-chat/document_loaders/mycsvloader.py
Normal file
158
langchain-chat/document_loaders/mycsvloader.py
Normal file
@@ -0,0 +1,158 @@
|
||||
from langchain_core.documents import Document
|
||||
from typing import List
|
||||
from langchain.document_loaders.unstructured import UnstructuredFileLoader
|
||||
# 导入模块
|
||||
import csv
|
||||
import os
|
||||
import re
|
||||
import json
|
||||
import pandas as pd
|
||||
|
||||
# TODO: 以CSV格式加载内容
|
||||
MAX_CONTENT = 18000
|
||||
SPECIAL_SYMBOL = ['\'', '"']
|
||||
class RapidOCRCSVLoader(UnstructuredFileLoader):
|
||||
|
||||
def _get_elements(self) -> List:
|
||||
meta_path = self.file_path.split("/")[-1]
|
||||
|
||||
# 读取csv文件
|
||||
def read_csv_file(file_path, sep = ',', encoding = 'utf-8'):
|
||||
df_data = pd.read_csv(file_path, sep = sep, encoding = encoding)
|
||||
df_data = df_data.apply(remove_duplicate_pandas, axis = 1)
|
||||
df_data = df_data.dropna(subset=['HTML全文'])
|
||||
|
||||
return df_data
|
||||
|
||||
def remove_duplicate_string(x):
|
||||
if not isinstance(x, str):
|
||||
return x
|
||||
if len(x) <= 0:
|
||||
return x
|
||||
|
||||
if x[0] in SPECIAL_SYMBOL:
|
||||
x = x[1:]
|
||||
|
||||
if x[-1] in SPECIAL_SYMBOL:
|
||||
x = x[:-1]
|
||||
|
||||
return x
|
||||
|
||||
def remove_duplicate_pandas(row):
|
||||
headers = ['标题', 'HTML全文','发布时间', '发布机构来源', '发布机构',
|
||||
'适用地区', '主题', '详情地址', '资源来源分类',
|
||||
'资源来源名称', '资源来源标识']
|
||||
|
||||
for col in headers:
|
||||
row[col] = remove_duplicate_string(row[col])
|
||||
|
||||
if len(row[col]) > MAX_CONTENT:
|
||||
row[col] = row[col][:MAX_CONTENT]
|
||||
|
||||
return row
|
||||
|
||||
def apply_title_content(x):
|
||||
# 处理 title相关的内容
|
||||
titles = [x['标题']]
|
||||
pattern = r'##title##(.*?)##/title##'
|
||||
title_list = re.findall(pattern, x['HTML全文'])
|
||||
titles += title_list
|
||||
|
||||
# 去除内容中标题的特殊符号
|
||||
content = x['HTML全文'].replace('##title##', '').replace('##/title##', '').replace('\\u3000', '')
|
||||
title_str = "\n".join(titles)
|
||||
tilte_str = title_str.replace('\\u3000', '')
|
||||
|
||||
content = content.replace('#####', '\n')
|
||||
|
||||
# 把内容相关提出
|
||||
title_idx = content.rfind("##/title##") + 1
|
||||
|
||||
return Document(page_content=x['标题'],
|
||||
metadata={"path": meta_path,
|
||||
"_type": 'title',
|
||||
'title' : x['标题'],
|
||||
"content": content ,
|
||||
# 'file_number': x['发文文号'],
|
||||
# 'keywords' : x['关键词'],
|
||||
'release_date' : x['发布时间'],
|
||||
'sourceOrganization' : x['发布机构来源'],
|
||||
'organization' : x['发布机构'],
|
||||
'region' : x['适用地区'],
|
||||
'subject' : x['主题'],
|
||||
'source' : x['详情地址'],
|
||||
'datasource' : x['资源来源分类'],
|
||||
'datasourceclass' : x['资源来源名称'],
|
||||
'datasource_key' : x['资源来源标识'],
|
||||
"xml": title_str,
|
||||
"tag": tag})
|
||||
|
||||
# title作为query,content作为answer
|
||||
def title_content2text(pd_obj, filepath, tag = 'csv'):
|
||||
ret_list = []
|
||||
pd_obj = pd_obj.apply(apply_title_content, axis=1)
|
||||
ret_list = pd_obj.values.tolist()
|
||||
return ret_list
|
||||
|
||||
def apply_content(x):
|
||||
# 去除内容中标题的特殊符号
|
||||
# 把内容相关提出
|
||||
content = x['HTML全文']
|
||||
title_idx = content.rfind("##/title##") + 1
|
||||
if title_idx > 0 and title_idx < len(content):
|
||||
content = content[title_idx + 1:]
|
||||
content = content.replace('##title##', '').replace('##/title##', '').replace('\\u3000', '')
|
||||
content = content.replace('#####', '\n')
|
||||
|
||||
# 对内容进行切分
|
||||
return Document(page_content=content,
|
||||
metadata={"path": meta_path,
|
||||
"_type": 'content',
|
||||
'title' : x['标题'],
|
||||
"content": '' ,
|
||||
# 'file_number': x['发文文号'],
|
||||
# 'keywords' : x['关键词'],
|
||||
'release_date' : x['发布时间'],
|
||||
'sourceOrganization' : x['发布机构来源'],
|
||||
'organization' : x['发布机构'],
|
||||
'region' : x['适用地区'],
|
||||
'subject' : x['主题'],
|
||||
'source' : x['详情地址'],
|
||||
'datasource' : x['资源来源分类'],
|
||||
'datasourceclass' : x['资源来源名称'],
|
||||
'datasource_key' : x['资源来源标识'],
|
||||
"xml": '',
|
||||
"tag": tag})
|
||||
|
||||
def content2text(pd_obj, filepath, tag = 'csv'):
|
||||
ret_list = []
|
||||
pd_obj = pd_obj.apply(apply_content, axis=1)
|
||||
ret_list = pd_obj.values.tolist()
|
||||
return ret_list
|
||||
|
||||
# 读取 csv 文件
|
||||
pd_obj = read_csv_file(self.file_path, r'\^\|\|\^')
|
||||
tag = "csv"
|
||||
|
||||
# 解析文件,获取 query(标题)-value(全文)
|
||||
ret_list_title = title_content2text(pd_obj, self.file_path, 'csv')
|
||||
# 解析文件,获取内容,并进行切割
|
||||
ret_list_content = content2text(pd_obj, self.file_path, 'csv')
|
||||
ret_list = ret_list_title + ret_list_content
|
||||
|
||||
return ret_list
|
||||
|
||||
def load(self) -> List[Document]:
|
||||
return self._get_elements()
|
||||
|
||||
if __name__ == '__main__':
|
||||
loader = RapidOCRCSVLoader(file_path="//home/work/project/test_result.csv")
|
||||
docs = loader.load()
|
||||
|
||||
i =0
|
||||
for doc in docs:
|
||||
|
||||
print(doc)
|
||||
i += 1
|
||||
if i>1000:
|
||||
break
|
||||
71
langchain-chat/document_loaders/mydocloader.py
Normal file
71
langchain-chat/document_loaders/mydocloader.py
Normal file
@@ -0,0 +1,71 @@
|
||||
from langchain.document_loaders.unstructured import UnstructuredFileLoader
|
||||
from typing import List
|
||||
import tqdm
|
||||
|
||||
|
||||
class RapidOCRDocLoader(UnstructuredFileLoader):
|
||||
def _get_elements(self) -> List:
|
||||
def doc2text(filepath):
|
||||
from docx.table import _Cell, Table
|
||||
from docx.oxml.table import CT_Tbl
|
||||
from docx.oxml.text.paragraph import CT_P
|
||||
from docx.text.paragraph import Paragraph
|
||||
from docx import Document, ImagePart
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
import numpy as np
|
||||
from rapidocr_onnxruntime import RapidOCR
|
||||
ocr = RapidOCR()
|
||||
doc = Document(filepath)
|
||||
resp = ""
|
||||
|
||||
def iter_block_items(parent):
|
||||
from docx.document import Document
|
||||
if isinstance(parent, Document):
|
||||
parent_elm = parent.element.body
|
||||
elif isinstance(parent, _Cell):
|
||||
parent_elm = parent._tc
|
||||
else:
|
||||
raise ValueError("RapidOCRDocLoader parse fail")
|
||||
|
||||
for child in parent_elm.iterchildren():
|
||||
if isinstance(child, CT_P):
|
||||
yield Paragraph(child, parent)
|
||||
elif isinstance(child, CT_Tbl):
|
||||
yield Table(child, parent)
|
||||
|
||||
b_unit = tqdm.tqdm(total=len(doc.paragraphs)+len(doc.tables),
|
||||
desc="RapidOCRDocLoader block index: 0")
|
||||
for i, block in enumerate(iter_block_items(doc)):
|
||||
b_unit.set_description(
|
||||
"RapidOCRDocLoader block index: {}".format(i))
|
||||
b_unit.refresh()
|
||||
if isinstance(block, Paragraph):
|
||||
resp += block.text.strip() + "\n"
|
||||
images = block._element.xpath('.//pic:pic') # 获取所有图片
|
||||
for image in images:
|
||||
for img_id in image.xpath('.//a:blip/@r:embed'): # 获取图片id
|
||||
part = doc.part.related_parts[img_id] # 根据图片id获取对应的图片
|
||||
if isinstance(part, ImagePart):
|
||||
image = Image.open(BytesIO(part._blob))
|
||||
result, _ = ocr(np.array(image))
|
||||
if result:
|
||||
ocr_result = [line[1] for line in result]
|
||||
resp += "\n".join(ocr_result)
|
||||
elif isinstance(block, Table):
|
||||
for row in block.rows:
|
||||
for cell in row.cells:
|
||||
for paragraph in cell.paragraphs:
|
||||
resp += paragraph.text.strip() + "\n"
|
||||
b_unit.update(1)
|
||||
return resp
|
||||
|
||||
text = doc2text(self.file_path)
|
||||
from unstructured.partition.text import partition_text
|
||||
return partition_text(text=text, **self.unstructured_kwargs)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
loader = RapidOCRDocLoader(file_path="../tests/samples/ocr_test.docx")
|
||||
docs = loader.load()
|
||||
print(docs)
|
||||
71
langchain-chat/document_loaders/myexcelloader.py
Normal file
71
langchain-chat/document_loaders/myexcelloader.py
Normal file
@@ -0,0 +1,71 @@
|
||||
import os
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
import logging
|
||||
|
||||
from typing import List, Optional, Union, Dict, Any
|
||||
from langchain_core.documents import Document
|
||||
from langchain_community.document_loaders.base import BaseLoader
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ExcelLoader(BaseLoader):
|
||||
"""
|
||||
用于加载 Excel 文件(.xls/.xlsx)的 Loader。
|
||||
使用 pandas 解析所有工作表,并将非空单元格内容展平为按逗号分隔的文段。
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
file_path: Union[str, Path],
|
||||
*,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
file_path: Excel 文件路径,支持 .xls 和 .xlsx
|
||||
metadata: 附加的文档级元数据
|
||||
"""
|
||||
self.file_path = str(file_path)
|
||||
self.metadata = metadata or {}
|
||||
|
||||
suffix = Path(self.file_path).suffix.lower()
|
||||
if suffix not in (".xls", ".xlsx"):
|
||||
raise ValueError(f"ExcelLoader 仅支持 .xls/.xlsx 文件: {self.file_path}")
|
||||
if not os.path.isfile(self.file_path):
|
||||
raise FileNotFoundError(f"文件不存在: {self.file_path}")
|
||||
|
||||
def load(self) -> List[Document]:
|
||||
"""
|
||||
读取 Excel 中的所有工作表,返回每个表格中所有非空单元格按逗号分隔的 Document 列表。
|
||||
"""
|
||||
try:
|
||||
# sheet_name=None 返回 dict: {sheet_name: DataFrame}
|
||||
sheets: Dict[str, pd.DataFrame] = pd.read_excel(
|
||||
self.file_path,
|
||||
sheet_name=None
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"读取 Excel 文件失败: {self.file_path}", exc_info=True)
|
||||
raise RuntimeError(f"无法加载 Excel 文件: {e}") from e
|
||||
|
||||
documents: List[Document] = []
|
||||
for sheet_name, df in sheets.items():
|
||||
segments: List[str] = []
|
||||
# 遍历所有单元格
|
||||
for row in df.itertuples(index=False, name=None):
|
||||
for cell in row:
|
||||
if pd.isna(cell):
|
||||
continue
|
||||
text = str(cell).strip()
|
||||
if text:
|
||||
segments.append(text)
|
||||
# 用英文逗号分隔所有文段
|
||||
content = ",".join(segments)
|
||||
md: Dict[str, Any] = {
|
||||
"source": self.file_path,
|
||||
"sheet_name": sheet_name,
|
||||
**self.metadata
|
||||
}
|
||||
documents.append(Document(page_content=content, metadata=md))
|
||||
|
||||
return documents
|
||||
26
langchain-chat/document_loaders/myimgloader.py
Normal file
26
langchain-chat/document_loaders/myimgloader.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from typing import List
|
||||
from langchain.document_loaders.unstructured import UnstructuredFileLoader
|
||||
from document_loaders.ocr import get_ocr
|
||||
|
||||
|
||||
class RapidOCRLoader(UnstructuredFileLoader):
|
||||
def _get_elements(self) -> List:
|
||||
def img2text(filepath):
|
||||
resp = ""
|
||||
ocr = get_ocr()
|
||||
result, _ = ocr(filepath)
|
||||
if result:
|
||||
ocr_result = [line[1] for line in result]
|
||||
resp += "\n".join(ocr_result)
|
||||
return resp
|
||||
|
||||
text = img2text(self.file_path)
|
||||
return text
|
||||
# from unstructured.partition.text import partition_text
|
||||
# return partition_text(text=text, **self.unstructured_kwargs)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
loader = RapidOCRLoader(file_path="../tests/samples/ocr_test.jpg")
|
||||
docs = loader.load()
|
||||
print(docs)
|
||||
87
langchain-chat/document_loaders/mypdfloader.py
Normal file
87
langchain-chat/document_loaders/mypdfloader.py
Normal file
@@ -0,0 +1,87 @@
|
||||
# from typing import List
|
||||
# from langchain.document_loaders.unstructured import UnstructuredFileLoader
|
||||
# import cv2
|
||||
# from PIL import Image
|
||||
# import numpy as np
|
||||
# from configs import PDF_OCR_THRESHOLD
|
||||
# from document_loaders.ocr import get_ocr
|
||||
# import tqdm
|
||||
|
||||
|
||||
# class RapidOCRPDFLoader(UnstructuredFileLoader):
|
||||
# def _get_elements(self) -> List:
|
||||
# def rotate_img(img, angle):
|
||||
# '''
|
||||
# img --image
|
||||
# angle --rotation angle
|
||||
# return--rotated img
|
||||
# '''
|
||||
|
||||
# h, w = img.shape[:2]
|
||||
# rotate_center = (w/2, h/2)
|
||||
# #获取旋转矩阵
|
||||
# # 参数1为旋转中心点;
|
||||
# # 参数2为旋转角度,正值-逆时针旋转;负值-顺时针旋转
|
||||
# # 参数3为各向同性的比例因子,1.0原图,2.0变成原来的2倍,0.5变成原来的0.5倍
|
||||
# M = cv2.getRotationMatrix2D(rotate_center, angle, 1.0)
|
||||
# #计算图像新边界
|
||||
# new_w = int(h * np.abs(M[0, 1]) + w * np.abs(M[0, 0]))
|
||||
# new_h = int(h * np.abs(M[0, 0]) + w * np.abs(M[0, 1]))
|
||||
# #调整旋转矩阵以考虑平移
|
||||
# M[0, 2] += (new_w - w) / 2
|
||||
# M[1, 2] += (new_h - h) / 2
|
||||
|
||||
# rotated_img = cv2.warpAffine(img, M, (new_w, new_h))
|
||||
# return rotated_img
|
||||
|
||||
# def pdf2text(filepath):
|
||||
# import fitz # pyMuPDF里面的fitz包,不要与pip install fitz混淆
|
||||
# import numpy as np
|
||||
# ocr = get_ocr()
|
||||
# doc = fitz.open(filepath)
|
||||
# resp = ""
|
||||
|
||||
# b_unit = tqdm.tqdm(total=doc.page_count, desc="RapidOCRPDFLoader context page index: 0")
|
||||
# for i, page in enumerate(doc):
|
||||
# b_unit.set_description("RapidOCRPDFLoader context page index: {}".format(i))
|
||||
# b_unit.refresh()
|
||||
# text = page.get_text("")
|
||||
# resp += text + "\n"
|
||||
|
||||
# img_list = page.get_image_info(xrefs=True)
|
||||
# for img in img_list:
|
||||
# if xref := img.get("xref"):
|
||||
# bbox = img["bbox"]
|
||||
# # 检查图片尺寸是否超过设定的阈值
|
||||
# if ((bbox[2] - bbox[0]) / (page.rect.width) < PDF_OCR_THRESHOLD[0]
|
||||
# or (bbox[3] - bbox[1]) / (page.rect.height) < PDF_OCR_THRESHOLD[1]):
|
||||
# continue
|
||||
# pix = fitz.Pixmap(doc, xref)
|
||||
# samples = pix.samples
|
||||
# if int(page.rotation)!=0: #如果Page有旋转角度,则旋转图片
|
||||
# img_array = np.frombuffer(pix.samples, dtype=np.uint8).reshape(pix.height, pix.width, -1)
|
||||
# tmp_img = Image.fromarray(img_array);
|
||||
# ori_img = cv2.cvtColor(np.array(tmp_img),cv2.COLOR_RGB2BGR)
|
||||
# rot_img = rotate_img(img=ori_img, angle=360-page.rotation)
|
||||
# img_array = cv2.cvtColor(rot_img, cv2.COLOR_RGB2BGR)
|
||||
# else:
|
||||
# img_array = np.frombuffer(pix.samples, dtype=np.uint8).reshape(pix.height, pix.width, -1)
|
||||
|
||||
# result, _ = ocr(img_array)
|
||||
# if result:
|
||||
# ocr_result = [line[1] for line in result]
|
||||
# resp += "\n".join(ocr_result)
|
||||
|
||||
# # 更新进度
|
||||
# b_unit.update(1)
|
||||
# return resp
|
||||
|
||||
# text = pdf2text(self.file_path)
|
||||
# from unstructured.partition.text import partition_text
|
||||
# return partition_text(text=text, **self.unstructured_kwargs)
|
||||
|
||||
|
||||
# if __name__ == "__main__":
|
||||
# loader = RapidOCRPDFLoader(file_path="/Users/tonysong/Desktop/test.pdf")
|
||||
# docs = loader.load()
|
||||
# print(docs)
|
||||
59
langchain-chat/document_loaders/mypptloader.py
Normal file
59
langchain-chat/document_loaders/mypptloader.py
Normal file
@@ -0,0 +1,59 @@
|
||||
from langchain.document_loaders.unstructured import UnstructuredFileLoader
|
||||
from typing import List
|
||||
import tqdm
|
||||
|
||||
|
||||
class RapidOCRPPTLoader(UnstructuredFileLoader):
|
||||
def _get_elements(self) -> List:
|
||||
def ppt2text(filepath):
|
||||
from pptx import Presentation
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
from io import BytesIO
|
||||
from rapidocr_onnxruntime import RapidOCR
|
||||
ocr = RapidOCR()
|
||||
prs = Presentation(filepath)
|
||||
resp = ""
|
||||
|
||||
def extract_text(shape):
|
||||
nonlocal resp
|
||||
if shape.has_text_frame:
|
||||
resp += shape.text.strip() + "\n"
|
||||
if shape.has_table:
|
||||
for row in shape.table.rows:
|
||||
for cell in row.cells:
|
||||
for paragraph in cell.text_frame.paragraphs:
|
||||
resp += paragraph.text.strip() + "\n"
|
||||
if shape.shape_type == 13: # 13 表示图片
|
||||
image = Image.open(BytesIO(shape.image.blob))
|
||||
result, _ = ocr(np.array(image))
|
||||
if result:
|
||||
ocr_result = [line[1] for line in result]
|
||||
resp += "\n".join(ocr_result)
|
||||
elif shape.shape_type == 6: # 6 表示组合
|
||||
for child_shape in shape.shapes:
|
||||
extract_text(child_shape)
|
||||
|
||||
b_unit = tqdm.tqdm(total=len(prs.slides),
|
||||
desc="RapidOCRPPTLoader slide index: 1")
|
||||
# 遍历所有幻灯片
|
||||
for slide_number, slide in enumerate(prs.slides, start=1):
|
||||
b_unit.set_description(
|
||||
"RapidOCRPPTLoader slide index: {}".format(slide_number))
|
||||
b_unit.refresh()
|
||||
sorted_shapes = sorted(slide.shapes,
|
||||
key=lambda x: (x.top, x.left)) # 从上到下、从左到右遍历
|
||||
for shape in sorted_shapes:
|
||||
extract_text(shape)
|
||||
b_unit.update(1)
|
||||
return resp
|
||||
|
||||
text = ppt2text(self.file_path)
|
||||
from unstructured.partition.text import partition_text
|
||||
return partition_text(text=text, **self.unstructured_kwargs)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
loader = RapidOCRPPTLoader(file_path="../tests/samples/ocr_test.pptx")
|
||||
docs = loader.load()
|
||||
print(docs)
|
||||
18
langchain-chat/document_loaders/ocr.py
Normal file
18
langchain-chat/document_loaders/ocr.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
try:
|
||||
from rapidocr_paddle import RapidOCR
|
||||
except ImportError:
|
||||
from rapidocr_onnxruntime import RapidOCR
|
||||
|
||||
|
||||
def get_ocr(use_cuda: bool = True) -> "RapidOCR":
|
||||
try:
|
||||
from rapidocr_paddle import RapidOCR
|
||||
ocr = RapidOCR(det_use_cuda=use_cuda, cls_use_cuda=use_cuda, rec_use_cuda=use_cuda)
|
||||
except ImportError:
|
||||
from rapidocr_onnxruntime import RapidOCR
|
||||
ocr = RapidOCR()
|
||||
return ocr
|
||||
Reference in New Issue
Block a user