[全量] 初始化项目代码、配置、文档及Agent协同harness

This commit is contained in:
2026-04-02 11:36:05 +08:00
parent 0553309cdf
commit 87e571d9ec
1133 changed files with 221948 additions and 0 deletions

View File

@@ -0,0 +1,81 @@
# CSV loader that reads only a specified column
from langchain.document_loaders import CSVLoader
import csv
from io import TextIOWrapper
from typing import Dict, List, Optional
from langchain.docstore.document import Document
from langchain.document_loaders.helpers import detect_file_encodings
class FilteredCSVLoader(CSVLoader):
    """CSV loader that builds one Document per row from a single chosen column.

    Only ``columns_to_read[0]`` is used as the page content; any further
    entries in ``columns_to_read`` are ignored by this implementation.
    """

    def __init__(
        self,
        file_path: str,
        columns_to_read: List[str],
        source_column: Optional[str] = None,
        metadata_columns: Optional[List[str]] = None,
        csv_args: Optional[Dict] = None,
        encoding: Optional[str] = None,
        autodetect_encoding: bool = False,
    ):
        """Store the column selection and forward the rest to ``CSVLoader``.

        Args:
            file_path: Path of the CSV file.
            columns_to_read: Column names; the first one supplies the content.
            source_column: Column whose value becomes ``metadata["source"]``;
                the file path is used when this is ``None``.
            metadata_columns: Extra columns copied into each Document's
                metadata. ``None`` means no extra columns.
            csv_args: Keyword arguments passed on to ``csv.DictReader``.
            encoding: Encoding used for the first attempt to open the file.
            autodetect_encoding: Retry with detected encodings when the first
                attempt fails with ``UnicodeDecodeError``.
        """
        super().__init__(
            file_path=file_path,
            source_column=source_column,
            # Build a fresh list per instance instead of sharing one mutable
            # default object between every loader.
            metadata_columns=[] if metadata_columns is None else metadata_columns,
            csv_args=csv_args,
            encoding=encoding,
            autodetect_encoding=autodetect_encoding,
        )
        self.columns_to_read = columns_to_read

    def load(self) -> List[Document]:
        """Load data into document objects.

        Raises:
            RuntimeError: If the file cannot be read or parsed, including the
                case where autodetection is enabled but none of the detected
                encodings can decode the file.
        """
        docs = []
        try:
            with open(self.file_path, newline="", encoding=self.encoding) as csvfile:
                docs = self.__read_file(csvfile)
        except UnicodeDecodeError as e:
            if not self.autodetect_encoding:
                raise RuntimeError(f"Error loading {self.file_path}") from e
            for candidate in detect_file_encodings(self.file_path):
                try:
                    with open(
                        self.file_path, newline="", encoding=candidate.encoding
                    ) as csvfile:
                        docs = self.__read_file(csvfile)
                    break
                except UnicodeDecodeError:
                    continue
            else:
                # Every candidate failed: report it rather than silently
                # returning an empty document list.
                raise RuntimeError(f"Error loading {self.file_path}") from e
        except Exception as e:
            raise RuntimeError(f"Error loading {self.file_path}") from e
        return docs

    def __read_file(self, csvfile: TextIOWrapper) -> List[Document]:
        """Convert every row of an open CSV file into a Document.

        Raises:
            ValueError: If the selected column is missing from a row.
        """
        docs = []
        column = self.columns_to_read[0]
        csv_reader = csv.DictReader(csvfile, **self.csv_args)  # type: ignore
        for i, row in enumerate(csv_reader):
            if column not in row:
                raise ValueError(f"Column '{self.columns_to_read[0]}' not found in CSV file.")
            content = row[column]
            # Extract the source if available
            source = (
                row.get(self.source_column, None)
                if self.source_column is not None
                else self.file_path
            )
            metadata = {"source": source, "row": i}
            for col in self.metadata_columns:
                if col in row:
                    metadata[col] = row[col]
            docs.append(Document(page_content=content, metadata=metadata))
        return docs

View File

@@ -0,0 +1,8 @@
from .myimgloader import RapidOCRLoader
from .mydocloader import RapidOCRDocLoader
from .mypptloader import RapidOCRPPTLoader
from .mycsvloader import RapidOCRCSVLoader
from .GCYWordLoader import GCYWordLoader
from .gycWordLoader import GCYWordLoader as GCYWordLoader2
from .GCYHTMLLoader import GCYHTMLLoader
from .myexcelloader import ExcelLoader

View File

@@ -0,0 +1,204 @@
from typing import List, Optional, Union, Dict, Any
from pathlib import Path
from langchain_core.documents import Document
from langchain_community.document_loaders.base import BaseLoader
from docx import Document as DocxDocument
import os
import subprocess
import logging
import zipfile
from lxml import etree as ET
import re
logger = logging.getLogger(__name__)
class GCYWordLoader(BaseLoader):
    """Custom loader that parses a Word document at the XML level.

    The .docx archive is unpacked into a scratch directory, and paragraphs,
    tables, headers/footers, comments, tracked revisions and table-of-contents
    entries are each emitted as separate Documents tagged with a
    ``content_type`` metadata key. ``.doc`` input is first converted to
    ``.docx`` with LibreOffice (``soffice``).
    """

    # WordprocessingML namespace
    ns = {
        "w": "http://schemas.openxmlformats.org/wordprocessingml/2006/main"
    }

    def __init__(
        self,
        file_path: Union[str, Path],
        output_dir: Optional[Union[str, Path]] = None,
        *,
        keep_doc_title: bool = True,
        start_with_title: bool = False,
        max_heading_level: int = 3,
        metadata: Optional[Dict[str, Any]] = None
    ):
        """Validate the input path and set up working locations.

        Args:
            file_path: Path of the ``.doc`` / ``.docx`` file.
            output_dir: Directory for the converted file and the scratch
                directory; defaults to the directory of ``file_path``.
            keep_doc_title: Emit the core-properties title as a Document.
            start_with_title: Selects the ``heading`` metadata value of the
                title Document (``"Document Title"`` vs ``None``).
            max_heading_level: Clamped to 1..6 and stored; not otherwise used
                by the code in this class.
            metadata: Extra metadata merged into every Document.

        Raises:
            FileNotFoundError: If ``file_path`` is not an existing file.
            ValueError: If the extension is neither ``.doc`` nor ``.docx``.
        """
        self.file_path = str(file_path)
        self.output_dir = str(output_dir) if output_dir else os.path.dirname(self.file_path)
        self.keep_doc_title = keep_doc_title
        self.start_with_title = start_with_title
        self.max_heading_level = min(max(1, max_heading_level), 6)
        self.metadata = metadata or {}
        # Scratch directory the archive is unpacked into
        self._work_dir = os.path.join(self.output_dir, '_pyc_work')
        self._doc_dir = os.path.join(self._work_dir, 'word')
        # Validation
        if not os.path.isfile(self.file_path):
            raise FileNotFoundError(f"文件不存在: {self.file_path}")
        if not self.file_path.lower().endswith(('.doc', '.docx')):
            raise ValueError("仅支持 .doc 或 .docx 格式")

    def load(self) -> List[Document]:
        """Parse the document and return all extracted fragments.

        Raises:
            RuntimeError: Wraps any failure during conversion, extraction or
                parsing; the original exception is chained as the cause.
        """
        try:
            # Pre-process (.doc -> .docx)
            processed_path = self._preprocess_document()
            # Prepare the scratch directory and unpack the archive
            self._prepare_work_directories()
            self._extract_docx(processed_path)
            # Parse the main document XML
            self._parse_document_xml()
            # Build the list of fragments
            docs: List[Document] = []
            base_meta = {"source": self.file_path, "file_name": os.path.basename(self.file_path), **self.metadata}
            # Optionally keep the document title
            docx_core = DocxDocument(processed_path)
            if self.keep_doc_title and docx_core.core_properties.title:
                title = docx_core.core_properties.title
                heading = "Document Title" if self.start_with_title else None
                docs.append(Document(page_content=title, metadata={**base_meta, "heading": heading}))
            docs.extend(self._parse_paragraphs(base_meta))
            docs.extend(self._parse_tables(base_meta))
            docs.extend(self._parse_headers_and_footers(base_meta))
            docs.extend(self._parse_comments(base_meta))
            docs.extend(self._parse_revisions(base_meta))
            docs.extend(self._parse_toc(base_meta))
            return docs
        except Exception as e:
            logger.error(f"文档加载失败: {self.file_path}", exc_info=True)
            raise RuntimeError(f"无法加载文档: {e}") from e
        finally:
            # Best-effort removal of the scratch directory; a failure here is
            # logged but must not mask the result or the original error.
            if os.path.exists(self._work_dir):
                import shutil
                try:
                    shutil.rmtree(self._work_dir)
                except OSError:
                    logger.warning(f"临时目录清理失败: {self._work_dir}", exc_info=True)

    def _preprocess_document(self) -> str:
        """Return a .docx path, converting a .doc file with ``soffice`` if needed."""
        # .docx is used as-is
        if self.file_path.lower().endswith('.docx'):
            return self.file_path
        # .doc -> .docx
        output_path = os.path.join(self.output_dir, Path(self.file_path).stem + '.docx')
        subprocess.run([
            'soffice', '--headless', '--convert-to', 'docx', '--outdir', self.output_dir, self.file_path
        ], check=True, capture_output=True)
        if not os.path.exists(output_path):
            raise RuntimeError('文档格式转换失败')
        return output_path

    def _prepare_work_directories(self):
        """Recreate an empty scratch directory."""
        if os.path.exists(self._work_dir):
            import shutil
            shutil.rmtree(self._work_dir)
        os.makedirs(self._doc_dir, exist_ok=True)

    def _extract_docx(self, path: str):
        """Unpack the .docx archive into the scratch directory."""
        with zipfile.ZipFile(path, 'r') as z:
            z.extractall(self._work_dir)

    def _parse_document_xml(self):
        """Parse ``word/document.xml`` and keep its tree and root on ``self``."""
        xml_path = os.path.join(self._doc_dir, 'document.xml')
        parser = ET.XMLParser(remove_blank_text=True)
        self._doc_tree = ET.parse(xml_path, parser)
        self._doc_root = self._doc_tree.getroot()

    def _get_text_from_runs(self, parent, text_tag: str = 'w:t') -> str:
        """Concatenate the text of every run below ``parent``.

        Args:
            parent: Element whose descendant ``w:r`` runs are read.
            text_tag: Qualified name of the text element inside each run;
                ``w:t`` for normal text, ``w:delText`` for deleted text.
        """
        parts = []
        for r in parent.findall('.//w:r', self.ns):
            for t in r.findall(text_tag, self.ns):
                if t.text:
                    parts.append(t.text)
        return ''.join(parts).strip()

    def _get_run_text_with_tabs(self, parent) -> str:
        """Like ``_get_text_from_runs`` but renders ``w:tab`` run children as ``\\t``."""
        text_name = f"{{{self.ns['w']}}}t"
        tab_name = f"{{{self.ns['w']}}}tab"
        parts = []
        for r in parent.findall('.//w:r', self.ns):
            for child in r:
                if child.tag == text_name:
                    parts.append(child.text or '')
                elif child.tag == tab_name:
                    parts.append('\t')
        return ''.join(parts)

    def _parse_paragraphs(self, base_meta) -> List[Document]:
        """Emit one Document per non-empty ``w:p`` in the main document."""
        docs = []
        for p in self._doc_root.findall('.//w:p', self.ns):
            text = self._get_text_from_runs(p)
            if not text:
                continue
            docs.append(Document(page_content=text, metadata={**base_meta, 'content_type':'paragraph'}))
        return docs

    def _parse_tables(self, base_meta) -> List[Document]:
        """Emit one Document per table: cells joined by ``|``, rows by newline."""
        docs = []
        for tbl in self._doc_root.findall('.//w:tbl', self.ns):
            rows = []
            for row in tbl.findall('.//w:tr', self.ns):
                cells = [self._get_text_from_runs(cell) for cell in row.findall('.//w:tc', self.ns)]
                rows.append('|'.join(cells))
            content = '\n'.join(rows)
            docs.append(Document(page_content=content, metadata={**base_meta, 'content_type':'table'}))
        return docs

    def _parse_headers_and_footers(self, base_meta) -> List[Document]:
        """Emit paragraphs found in ``header1..9.xml`` and ``footer1..9.xml``."""
        docs = []
        for part in ['header', 'footer']:
            # Only part numbers 1 through 9 are probed.
            for i in range(1,10):
                path = os.path.join(self._doc_dir, f'{part}{i}.xml')
                if not os.path.exists(path):
                    continue
                root = ET.parse(path).getroot()
                for p in root.findall('.//w:p', self.ns):
                    text = self._get_text_from_runs(p)
                    if text:
                        docs.append(Document(page_content=text, metadata={**base_meta, 'content_type':part}))
        return docs

    def _parse_comments(self, base_meta) -> List[Document]:
        """Emit one Document per non-empty comment in ``comments.xml``, if present."""
        docs = []
        path = os.path.join(self._doc_dir, 'comments.xml')
        if os.path.exists(path):
            root = ET.parse(path).getroot()
            for c in root.findall('.//w:comment', self.ns):
                text = self._get_text_from_runs(c)
                if text:
                    docs.append(Document(page_content=text, metadata={**base_meta, 'content_type':'comment'}))
        return docs

    def _parse_revisions(self, base_meta) -> List[Document]:
        """Emit tracked deletions first, then tracked insertions."""
        docs = []
        # Deleted run text is stored in w:delText rather than w:t, so the text
        # element has to be chosen per revision kind.
        for tag, text_tag in (('del', 'w:delText'), ('ins', 'w:t')):
            for el in self._doc_root.findall(f'.//w:{tag}', self.ns):
                text = self._get_text_from_runs(el, text_tag)
                if text:
                    docs.append(Document(page_content=text, metadata={**base_meta, 'content_type':'revision'}))
        return docs

    def _parse_toc(self, base_meta) -> List[Document]:
        """Heuristically detect table-of-contents lines.

        A paragraph counts as a TOC entry when its run text is at most 255
        characters, contains a tab and ends in a digit (the page number). The
        text before the first tab is emitted as the entry title.
        """
        docs = []
        for p in self._doc_root.findall('.//w:p', self.ns):
            # w:t and w:tab live inside w:r, not directly under w:p, so the
            # text is assembled from the runs with tabs made explicit.
            full = self._get_run_text_with_tabs(p).strip()
            if not full or len(full) > 255:
                continue
            if '\t' in full and full[-1].isdigit():
                # Simple split of title and page number
                title = re.split(r'\t+', full)[0].strip()
                if title:
                    docs.append(Document(page_content=title, metadata={**base_meta,'content_type':'toc_entry'}))
        return docs

View File

@@ -0,0 +1,158 @@
from langchain_core.documents import Document
from typing import List
from langchain.document_loaders.unstructured import UnstructuredFileLoader
# 导入模块
import csv
import os
import re
import json
import pandas as pd
# TODO: 以CSV格式加载内容
# Maximum length kept for each cleaned column value; longer values are
# truncated to this many characters in remove_duplicate_pandas.
MAX_CONTENT = 18000
# Quote characters stripped from the first and last position of a field value.
SPECIAL_SYMBOL = ['\'', '"']
class RapidOCRCSVLoader(UnstructuredFileLoader):
    """Turn each row of a ``^||^``-delimited CSV export into two Documents.

    For every row with an ``HTML全文`` value the loader emits a "title"
    Document (page content is the ``标题`` column, the cleaned full text is
    kept in metadata) and a "content" Document (page content is the body text
    that follows the last ``##/title##`` marker).
    """

    def _get_elements(self) -> List:
        """Read the CSV and return all title Documents followed by all content Documents."""
        meta_path = self.file_path.split("/")[-1]
        tag = "csv"
        title_close = '##/title##'
        cleaned_columns = ['标题', 'HTML全文', '发布时间', '发布机构来源', '发布机构',
                           '适用地区', '主题', '详情地址', '资源来源分类',
                           '资源来源名称', '资源来源标识']

        def remove_duplicate_string(x):
            """Strip one leading and one trailing quote character from a string."""
            if not isinstance(x, str) or not x:
                return x
            if x[0] in SPECIAL_SYMBOL:
                x = x[1:]
            # x may have become empty (input was a lone quote character).
            if x and x[-1] in SPECIAL_SYMBOL:
                x = x[:-1]
            return x

        def remove_duplicate_pandas(row):
            """Clean and truncate the known columns of one row."""
            for col in cleaned_columns:
                value = remove_duplicate_string(row[col])
                # Missing cells are NaN floats and some columns may parse as
                # numbers; only strings have a length to limit.
                if isinstance(value, str) and len(value) > MAX_CONTENT:
                    value = value[:MAX_CONTENT]
                row[col] = value
            return row

        def read_csv_file(file_path, sep=',', encoding='utf-8'):
            """Read the CSV, clean each row and drop rows without full text."""
            # Regex separators are only supported by the python engine; naming
            # it explicitly avoids pandas' ParserWarning fallback.
            df_data = pd.read_csv(file_path, sep=sep, encoding=encoding, engine='python')
            df_data = df_data.apply(remove_duplicate_pandas, axis=1)
            return df_data.dropna(subset=['HTML全文'])

        def strip_markers(text):
            """Remove title markers and literal ``\\u3000``; turn ``#####`` into newlines."""
            text = text.replace('##title##', '').replace('##/title##', '').replace('\\u3000', '')
            return text.replace('#####', '\n')

        def build_metadata(x, doc_type, content, xml):
            """Assemble the metadata shared by title and content Documents."""
            return {"path": meta_path,
                    "_type": doc_type,
                    'title': x['标题'],
                    "content": content,
                    # 'file_number': x['发文文号'],
                    # 'keywords' : x['关键词'],
                    'release_date': x['发布时间'],
                    'sourceOrganization': x['发布机构来源'],
                    'organization': x['发布机构'],
                    'region': x['适用地区'],
                    'subject': x['主题'],
                    'source': x['详情地址'],
                    'datasource': x['资源来源分类'],
                    'datasourceclass': x['资源来源名称'],
                    'datasource_key': x['资源来源标识'],
                    "xml": xml,
                    "tag": tag}

        def apply_title_content(x):
            """Title as the query text, cleaned full text as the answer (in metadata)."""
            titles = [x['标题']] + re.findall(r'##title##(.*?)##/title##', x['HTML全文'])
            title_str = "\n".join(titles).replace('\\u3000', '')
            return Document(page_content=x['标题'],
                            metadata=build_metadata(x, 'title', strip_markers(x['HTML全文']), title_str))

        def apply_content(x):
            """Body text after the last closing title marker."""
            content = x['HTML全文']
            marker_at = content.rfind(title_close)
            if marker_at != -1:
                # Skip the whole marker so no fragment of it stays in the body.
                content = content[marker_at + len(title_close):]
            return Document(page_content=strip_markers(content),
                            metadata=build_metadata(x, 'content', '', ''))

        # Read the csv file
        pd_obj = read_csv_file(self.file_path, r'\^\|\|\^')
        # query (title) -> value (full text)
        ret_list_title = pd_obj.apply(apply_title_content, axis=1).values.tolist()
        # body content
        ret_list_content = pd_obj.apply(apply_content, axis=1).values.tolist()
        return ret_list_title + ret_list_content

    def load(self) -> List[Document]:
        """Return the Documents built by ``_get_elements`` unchanged."""
        return self._get_elements()
if __name__ == '__main__':
    # Manual smoke test: print the loaded documents, stopping once more than
    # 1000 have been printed.
    loader = RapidOCRCSVLoader(file_path="//home/work/project/test_result.csv")
    for printed, doc in enumerate(loader.load(), start=1):
        print(doc)
        if printed > 1000:
            break

View File

@@ -0,0 +1,71 @@
from langchain.document_loaders.unstructured import UnstructuredFileLoader
from typing import List
import tqdm
class RapidOCRDocLoader(UnstructuredFileLoader):
    """Load a .docx file as text, OCR-ing the images embedded in paragraphs."""

    def _get_elements(self) -> List:
        """Extract text (plus OCR text of inline images) and partition it."""

        def doc2text(filepath):
            """Return the body text of the document in block order."""
            from docx.table import _Cell, Table
            from docx.oxml.table import CT_Tbl
            from docx.oxml.text.paragraph import CT_P
            from docx.text.paragraph import Paragraph
            from docx.document import Document as DocxDocument
            from docx import Document, ImagePart
            from PIL import Image
            from io import BytesIO
            import numpy as np
            from rapidocr_onnxruntime import RapidOCR

            ocr = RapidOCR()
            doc = Document(filepath)
            # Collected pieces, joined once at the end.
            chunks = []

            def iter_block_items(parent):
                """Yield the paragraphs and tables of ``parent`` in document order."""
                if isinstance(parent, DocxDocument):
                    parent_elm = parent.element.body
                elif isinstance(parent, _Cell):
                    parent_elm = parent._tc
                else:
                    raise ValueError("RapidOCRDocLoader parse fail")
                for child in parent_elm.iterchildren():
                    if isinstance(child, CT_P):
                        yield Paragraph(child, parent)
                    elif isinstance(child, CT_Tbl):
                        yield Table(child, parent)

            # The context manager closes the progress bar even if parsing fails.
            with tqdm.tqdm(total=len(doc.paragraphs)+len(doc.tables),
                           desc="RapidOCRDocLoader block index: 0") as b_unit:
                for i, block in enumerate(iter_block_items(doc)):
                    b_unit.set_description(
                        "RapidOCRDocLoader block index: {}".format(i))
                    b_unit.refresh()
                    if isinstance(block, Paragraph):
                        chunks.append(block.text.strip() + "\n")
                        # All pictures in the paragraph
                        for pic in block._element.xpath('.//pic:pic'):
                            # Relationship ids of the embedded images
                            for img_id in pic.xpath('.//a:blip/@r:embed'):
                                part = doc.part.related_parts[img_id]
                                if isinstance(part, ImagePart):
                                    picture = Image.open(BytesIO(part._blob))
                                    result, _ = ocr(np.array(picture))
                                    if result:
                                        # Trailing newline keeps the OCR text
                                        # separate from the next block.
                                        chunks.append("\n".join(line[1] for line in result) + "\n")
                    elif isinstance(block, Table):
                        for row in block.rows:
                            for cell in row.cells:
                                for paragraph in cell.paragraphs:
                                    chunks.append(paragraph.text.strip() + "\n")
                    b_unit.update(1)
            return "".join(chunks)

        text = doc2text(self.file_path)
        from unstructured.partition.text import partition_text
        return partition_text(text=text, **self.unstructured_kwargs)
if __name__ == '__main__':
    # Manual smoke test against the bundled sample document.
    sample_path = "../tests/samples/ocr_test.docx"
    print(RapidOCRDocLoader(file_path=sample_path).load())

View File

@@ -0,0 +1,71 @@
import os
import pandas as pd
from pathlib import Path
import logging
from typing import List, Optional, Union, Dict, Any
from langchain_core.documents import Document
from langchain_community.document_loaders.base import BaseLoader
logger = logging.getLogger(__name__)
class ExcelLoader(BaseLoader):
    """Loader for Excel workbooks (.xls / .xlsx).

    Every worksheet is parsed with pandas and flattened into one Document
    whose content is the sheet's non-empty cell values joined by commas.
    """

    def __init__(
        self,
        file_path: Union[str, Path],
        *,
        metadata: Optional[Dict[str, Any]] = None
    ):
        """Validate the path and remember the extra metadata.

        Args:
            file_path: Path of the Excel file; .xls and .xlsx are accepted.
            metadata: Additional document-level metadata.

        Raises:
            ValueError: If the extension is not .xls or .xlsx.
            FileNotFoundError: If the path is not an existing file.
        """
        self.file_path = str(file_path)
        self.metadata = metadata or {}
        if Path(self.file_path).suffix.lower() not in (".xls", ".xlsx"):
            raise ValueError(f"ExcelLoader 仅支持 .xls/.xlsx 文件: {self.file_path}")
        if not os.path.isfile(self.file_path):
            raise FileNotFoundError(f"文件不存在: {self.file_path}")

    def load(self) -> List[Document]:
        """Read every worksheet and return one Document per sheet.

        Raises:
            RuntimeError: If pandas fails to read the workbook.
        """
        try:
            # sheet_name=None yields a dict: {sheet_name: DataFrame}
            workbook: Dict[str, pd.DataFrame] = pd.read_excel(
                self.file_path,
                sheet_name=None
            )
        except Exception as e:
            logger.error(f"读取 Excel 文件失败: {self.file_path}", exc_info=True)
            raise RuntimeError(f"无法加载 Excel 文件: {e}") from e
        return [self._sheet_to_document(name, frame) for name, frame in workbook.items()]

    def _sheet_to_document(self, sheet_name, df) -> Document:
        """Flatten one worksheet into a comma-separated Document."""
        stripped_cells = (
            str(cell).strip()
            for row in df.itertuples(index=False, name=None)
            for cell in row
            if not pd.isna(cell)
        )
        # Cells that are blank after stripping are skipped as well.
        content = ",".join(text for text in stripped_cells if text)
        md: Dict[str, Any] = {
            "source": self.file_path,
            "sheet_name": sheet_name,
            **self.metadata
        }
        return Document(page_content=content, metadata=md)

View File

@@ -0,0 +1,26 @@
from typing import List
from langchain.document_loaders.unstructured import UnstructuredFileLoader
from document_loaders.ocr import get_ocr
class RapidOCRLoader(UnstructuredFileLoader):
    """Run OCR on an image file and return the recognised text."""

    def _get_elements(self) -> List:
        """Return the OCR text of ``self.file_path``, one recognised line per row.

        NOTE(review): the value returned is a plain ``str`` even though the
        annotation says ``List``; the ``partition_text`` call kept below as a
        comment would return a list of elements instead. Confirm which form
        the callers expect.
        """

        def img2text(filepath):
            """OCR one image and join the recognised lines with newlines."""
            engine = get_ocr()
            result, _ = engine(filepath)
            if not result:
                return ""
            return "\n".join([line[1] for line in result])

        return img2text(self.file_path)
        # from unstructured.partition.text import partition_text
        # return partition_text(text=text, **self.unstructured_kwargs)
if __name__ == "__main__":
    # Manual smoke test against the bundled sample image.
    sample_path = "../tests/samples/ocr_test.jpg"
    print(RapidOCRLoader(file_path=sample_path).load())

View File

@@ -0,0 +1,87 @@
# from typing import List
# from langchain.document_loaders.unstructured import UnstructuredFileLoader
# import cv2
# from PIL import Image
# import numpy as np
# from configs import PDF_OCR_THRESHOLD
# from document_loaders.ocr import get_ocr
# import tqdm
# class RapidOCRPDFLoader(UnstructuredFileLoader):
# def _get_elements(self) -> List:
# def rotate_img(img, angle):
# '''
# img --image
# angle --rotation angle
# return--rotated img
# '''
# h, w = img.shape[:2]
# rotate_center = (w/2, h/2)
# #获取旋转矩阵
# # 参数1为旋转中心点;
# # 参数2为旋转角度,正值-逆时针旋转;负值-顺时针旋转
# # 参数3为各向同性的比例因子,1.0原图2.0变成原来的2倍0.5变成原来的0.5倍
# M = cv2.getRotationMatrix2D(rotate_center, angle, 1.0)
# #计算图像新边界
# new_w = int(h * np.abs(M[0, 1]) + w * np.abs(M[0, 0]))
# new_h = int(h * np.abs(M[0, 0]) + w * np.abs(M[0, 1]))
# #调整旋转矩阵以考虑平移
# M[0, 2] += (new_w - w) / 2
# M[1, 2] += (new_h - h) / 2
# rotated_img = cv2.warpAffine(img, M, (new_w, new_h))
# return rotated_img
# def pdf2text(filepath):
# import fitz # pyMuPDF里面的fitz包不要与pip install fitz混淆
# import numpy as np
# ocr = get_ocr()
# doc = fitz.open(filepath)
# resp = ""
# b_unit = tqdm.tqdm(total=doc.page_count, desc="RapidOCRPDFLoader context page index: 0")
# for i, page in enumerate(doc):
# b_unit.set_description("RapidOCRPDFLoader context page index: {}".format(i))
# b_unit.refresh()
# text = page.get_text("")
# resp += text + "\n"
# img_list = page.get_image_info(xrefs=True)
# for img in img_list:
# if xref := img.get("xref"):
# bbox = img["bbox"]
# # 检查图片尺寸是否超过设定的阈值
# if ((bbox[2] - bbox[0]) / (page.rect.width) < PDF_OCR_THRESHOLD[0]
# or (bbox[3] - bbox[1]) / (page.rect.height) < PDF_OCR_THRESHOLD[1]):
# continue
# pix = fitz.Pixmap(doc, xref)
# samples = pix.samples
# if int(page.rotation)!=0: #如果Page有旋转角度则旋转图片
# img_array = np.frombuffer(pix.samples, dtype=np.uint8).reshape(pix.height, pix.width, -1)
# tmp_img = Image.fromarray(img_array);
# ori_img = cv2.cvtColor(np.array(tmp_img),cv2.COLOR_RGB2BGR)
# rot_img = rotate_img(img=ori_img, angle=360-page.rotation)
# img_array = cv2.cvtColor(rot_img, cv2.COLOR_RGB2BGR)
# else:
# img_array = np.frombuffer(pix.samples, dtype=np.uint8).reshape(pix.height, pix.width, -1)
# result, _ = ocr(img_array)
# if result:
# ocr_result = [line[1] for line in result]
# resp += "\n".join(ocr_result)
# # 更新进度
# b_unit.update(1)
# return resp
# text = pdf2text(self.file_path)
# from unstructured.partition.text import partition_text
# return partition_text(text=text, **self.unstructured_kwargs)
# if __name__ == "__main__":
# loader = RapidOCRPDFLoader(file_path="/Users/tonysong/Desktop/test.pdf")
# docs = loader.load()
# print(docs)

View File

@@ -0,0 +1,59 @@
from langchain.document_loaders.unstructured import UnstructuredFileLoader
from typing import List
import tqdm
class RapidOCRPPTLoader(UnstructuredFileLoader):
    """Load a .pptx file as text, OCR-ing the pictures on each slide."""

    def _get_elements(self) -> List:
        """Extract slide text (plus OCR text of pictures) and partition it."""

        def ppt2text(filepath):
            """Return the text of all slides, shapes visited top-to-bottom, left-to-right."""
            from pptx import Presentation
            from PIL import Image
            import numpy as np
            from io import BytesIO
            from rapidocr_onnxruntime import RapidOCR

            ocr = RapidOCR()
            prs = Presentation(filepath)
            # Collected pieces, joined once at the end.
            chunks = []

            def extract_text(shape):
                """Append the text of ``shape`` (recursing into groups) to ``chunks``."""
                if shape.has_text_frame:
                    chunks.append(shape.text.strip() + "\n")
                if shape.has_table:
                    for row in shape.table.rows:
                        for cell in row.cells:
                            for paragraph in cell.text_frame.paragraphs:
                                chunks.append(paragraph.text.strip() + "\n")
                if shape.shape_type == 13:  # 13 means picture
                    picture = Image.open(BytesIO(shape.image.blob))
                    result, _ = ocr(np.array(picture))
                    if result:
                        # Trailing newline keeps the OCR text separate from
                        # the text of the next shape.
                        chunks.append("\n".join(line[1] for line in result) + "\n")
                elif shape.shape_type == 6:  # 6 means group
                    for child_shape in shape.shapes:
                        extract_text(child_shape)

            def position(shape):
                """Sort key; a missing offset is treated as 0 so None never meets an int."""
                return (shape.top or 0, shape.left or 0)

            # The context manager closes the progress bar even if parsing fails.
            with tqdm.tqdm(total=len(prs.slides),
                           desc="RapidOCRPPTLoader slide index: 1") as b_unit:
                # Walk every slide
                for slide_number, slide in enumerate(prs.slides, start=1):
                    b_unit.set_description(
                        "RapidOCRPPTLoader slide index: {}".format(slide_number))
                    b_unit.refresh()
                    # Top-to-bottom, then left-to-right
                    for shape in sorted(slide.shapes, key=position):
                        extract_text(shape)
                    b_unit.update(1)
            return "".join(chunks)

        text = ppt2text(self.file_path)
        from unstructured.partition.text import partition_text
        return partition_text(text=text, **self.unstructured_kwargs)
if __name__ == '__main__':
    # Manual smoke test against the bundled sample presentation.
    sample_path = "../tests/samples/ocr_test.pptx"
    print(RapidOCRPPTLoader(file_path=sample_path).load())

View File

@@ -0,0 +1,18 @@
from typing import TYPE_CHECKING
if TYPE_CHECKING:
try:
from rapidocr_paddle import RapidOCR
except ImportError:
from rapidocr_onnxruntime import RapidOCR
def get_ocr(use_cuda: bool = True) -> "RapidOCR":
    """Create a RapidOCR engine, preferring the paddle backend.

    Args:
        use_cuda: Passed as the detection, classification and recognition
            CUDA flag of the paddle backend; not used by the onnxruntime
            fallback.

    Returns:
        A ``rapidocr_paddle.RapidOCR`` instance, or a default-constructed
        ``rapidocr_onnxruntime.RapidOCR`` when importing or constructing the
        paddle engine raises ``ImportError``.
    """
    cuda_flags = dict.fromkeys(("det_use_cuda", "cls_use_cuda", "rec_use_cuda"), use_cuda)
    try:
        from rapidocr_paddle import RapidOCR as PaddleOCR
        return PaddleOCR(**cuda_flags)
    except ImportError:
        from rapidocr_onnxruntime import RapidOCR as OnnxOCR
        return OnnxOCR()