158 lines
6.2 KiB
Python
158 lines
6.2 KiB
Python
|
|
from langchain_core.documents import Document
|
|||
|
|
from typing import List
|
|||
|
|
from langchain.document_loaders.unstructured import UnstructuredFileLoader
|
|||
|
|
# Import modules
|
|||
|
|
import csv
|
|||
|
|
import os
|
|||
|
|
import re
|
|||
|
|
import json
|
|||
|
|
import pandas as pd
|
|||
|
|
|
|||
|
|
# TODO: load the content in CSV format

# Maximum number of characters kept per cleaned column value; longer string
# values are cut down to this length.
MAX_CONTENT = 18000
# Quote characters that are stripped from the start and end of column values.
SPECIAL_SYMBOL = ['\'', '"']
|
|||
|
|
class RapidOCRCSVLoader(UnstructuredFileLoader):
    """Load a delimited document export and turn every row into two Documents.

    The input file is read with pandas using the regular-expression field
    separator ``^||^``. Every row must provide the columns listed in
    ``_TEXT_COLUMNS``. For each row that has a non-null ``HTML全文`` value the
    loader emits:

    * a "title" Document whose ``page_content`` is the row title, with the
      cleaned full text stored in ``metadata["content"]`` and all titles
      (row title plus inline ``##title##...##/title##`` titles) stored in
      ``metadata["xml"]``;
    * a "content" Document whose ``page_content`` is the text that follows the
      last ``##/title##`` marker.
    """

    # Columns that are quote-stripped and truncated to MAX_CONTENT characters.
    _TEXT_COLUMNS = (
        '标题', 'HTML全文', '发布时间', '发布机构来源', '发布机构',
        '适用地区', '主题', '详情地址', '资源来源分类',
        '资源来源名称', '资源来源标识',
    )

    # (metadata key, source column) pairs copied verbatim into every Document.
    # The columns '发文文号' (file number) and '关键词' (keywords) are
    # deliberately not exported.
    _METADATA_COLUMNS = (
        ('release_date', '发布时间'),
        ('sourceOrganization', '发布机构来源'),
        ('organization', '发布机构'),
        ('region', '适用地区'),
        ('subject', '主题'),
        ('source', '详情地址'),
        ('datasource', '资源来源分类'),
        ('datasourceclass', '资源来源名称'),
        ('datasource_key', '资源来源标识'),
    )

    # Inline markers that wrap titles inside the full text.
    _TITLE_OPEN = '##title##'
    _TITLE_CLOSE = '##/title##'
    # Marker that stands for a line break inside the full text.
    _LINE_BREAK = '#####'
    # This is the six-character escaped text (backslash, "u3000"), not the
    # U+3000 character itself.
    # NOTE(review): presumably the export stores ideographic spaces in
    # escaped form - confirm against real data.
    _ESCAPED_IDEOGRAPHIC_SPACE = '\\u3000'

    @staticmethod
    def _strip_wrapping_quotes(value):
        """Remove one leading and one trailing quote character from a string.

        Non-string values (for example NaN for a missing cell) and empty
        strings are returned unchanged.
        """
        if not isinstance(value, str) or not value:
            return value
        if value[0] in SPECIAL_SYMBOL:
            value = value[1:]
        # The string may be empty by now (the value was a lone quote), so the
        # last character must only be inspected when one still exists.
        if value and value[-1] in SPECIAL_SYMBOL:
            value = value[:-1]
        return value

    @classmethod
    def _clean_row(cls, row):
        """Strip wrapping quotes and truncate every text column of one row.

        Args:
            row: A pandas Series holding one record; it is modified in place.

        Returns:
            The same row object.

        Raises:
            KeyError: If one of the ``_TEXT_COLUMNS`` is missing from the row.
        """
        for col in cls._TEXT_COLUMNS:
            value = cls._strip_wrapping_quotes(row[col])
            # Only strings can be truncated; missing cells are NaN floats and
            # have no length.
            if isinstance(value, str) and len(value) > MAX_CONTENT:
                value = value[:MAX_CONTENT]
            row[col] = value
        return row

    @classmethod
    def _read_csv_file(cls, file_path, sep=',', encoding='utf-8'):
        """Read the export file and return the cleaned rows as a DataFrame.

        Args:
            file_path: Path of the file to read.
            sep: Field separator; values longer than one character are
                treated by pandas as regular expressions.
            encoding: Text encoding of the file.

        Returns:
            A DataFrame with cleaned text columns and without the rows whose
            ``HTML全文`` value is missing.
        """
        read_kwargs = {'sep': sep, 'encoding': encoding}
        # Regex separators are only supported by the python parser; asking for
        # it explicitly avoids the silent fallback and its ParserWarning.
        if len(sep) > 1 and sep != r'\s+':
            read_kwargs['engine'] = 'python'
        df_data = pd.read_csv(file_path, **read_kwargs)
        df_data = df_data.apply(cls._clean_row, axis=1)
        return df_data.dropna(subset=['HTML全文'])

    @classmethod
    def _build_metadata(cls, row, meta_path, doc_type, content, xml, tag):
        """Assemble the metadata dictionary shared by both Document kinds."""
        metadata = {
            "path": meta_path,
            "_type": doc_type,
            'title': row['标题'],
            "content": content,
        }
        for key, col in cls._METADATA_COLUMNS:
            metadata[key] = row[col]
        metadata["xml"] = xml
        metadata["tag"] = tag
        return metadata

    @classmethod
    def _title_document(cls, row, meta_path, tag):
        """Build the "title" Document (title as query, full text as answer)."""
        full_text = row['HTML全文']

        # Row title first, followed by every inline title found in the text.
        pattern = r'##title##(.*?)##/title##'
        titles = [row['标题']] + re.findall(pattern, full_text)
        title_str = "\n".join(titles).replace(cls._ESCAPED_IDEOGRAPHIC_SPACE, '')

        # Drop the title markers and escaped spaces, then restore line breaks.
        content = (
            full_text
            .replace(cls._TITLE_OPEN, '')
            .replace(cls._TITLE_CLOSE, '')
            .replace(cls._ESCAPED_IDEOGRAPHIC_SPACE, '')
            .replace(cls._LINE_BREAK, '\n')
        )

        return Document(
            page_content=row['标题'],
            metadata=cls._build_metadata(
                row, meta_path, 'title', content, title_str, tag),
        )

    @classmethod
    def _content_document(cls, row, meta_path, tag):
        """Build the "content" Document from the text after the last title."""
        content = row['HTML全文']

        # Keep only what follows the last closing title marker; the whole
        # marker has to be skipped, otherwise part of it stays in the text.
        marker_pos = content.rfind(cls._TITLE_CLOSE)
        if marker_pos != -1:
            content = content[marker_pos + len(cls._TITLE_CLOSE):]

        content = (
            content
            .replace(cls._TITLE_OPEN, '')
            .replace(cls._TITLE_CLOSE, '')
            .replace(cls._ESCAPED_IDEOGRAPHIC_SPACE, '')
            .replace(cls._LINE_BREAK, '\n')
        )

        return Document(
            page_content=content,
            metadata=cls._build_metadata(
                row, meta_path, 'content', '', '', tag),
        )

    def _get_elements(self) -> List:
        """Parse ``self.file_path`` and return title and content Documents.

        Returns:
            A list holding the "title" Documents of all rows followed by the
            "content" Documents of all rows.
        """
        # Only the file name (without directories) is stored in the metadata.
        meta_path = os.path.basename(self.file_path)
        tag = "csv"

        # Fields in the export are separated by the literal text ^||^.
        pd_obj = self._read_csv_file(self.file_path, r'\^\|\|\^')

        # Rows are iterated explicitly so the resulting list does not depend
        # on how DataFrame.apply infers the shape of its result.
        ret_list_title = [
            self._title_document(row, meta_path, tag)
            for _, row in pd_obj.iterrows()
        ]
        ret_list_content = [
            self._content_document(row, meta_path, tag)
            for _, row in pd_obj.iterrows()
        ]
        return ret_list_title + ret_list_content

    def load(self) -> List[Document]:
        """Return the Documents produced by ``_get_elements``.

        NOTE(review): this overrides the base-class ``load``; any extra
        processing the base loader performs is presumably skipped - confirm.
        """
        return self._get_elements()
|
|||
|
|
|
|||
|
|
if __name__ == '__main__':
    # Manual smoke test: load one export file and print its first documents.
    import sys

    # An optional first command-line argument overrides the sample path.
    csv_path = (sys.argv[1] if len(sys.argv) > 1
                else "//home/work/project/test_result.csv")
    loader = RapidOCRCSVLoader(file_path=csv_path)
    docs = loader.load()

    # Cap the preview at 1001 documents so large files do not flood the
    # terminal.
    for doc in docs[:1001]:
        print(doc)
|