# Listing info: 158 lines, 6.2 KiB, Python
# Imports
import csv
import json
import os
import re
from typing import List

import pandas as pd
from langchain.document_loaders.unstructured import UnstructuredFileLoader
from langchain_core.documents import Document

# TODO: load the content in CSV format

# Maximum number of characters kept in a cleaned text cell; longer cells are
# truncated to this length.
MAX_CONTENT = 18000

# Quote characters trimmed from the first and last position of a text cell.
SPECIAL_SYMBOL = ['\'', '"']


class RapidOCRCSVLoader(UnstructuredFileLoader):
    """Load a delimiter-separated CSV export as a list of ``Document`` objects.

    Each row that has a value in the ``HTML全文`` column produces two
    documents:

    * a "title" document whose ``page_content`` is the row title and whose
      metadata carries the cleaned full text (``content``) and all titles
      found in the text (``xml``);
    * a "content" document whose ``page_content`` is the text following the
      last ``##/title##`` marker.

    All title documents come first in the returned list, followed by all
    content documents.
    """

    # Columns that are cleaned (quote-stripped and truncated) after loading.
    _TEXT_COLUMNS = (
        '标题', 'HTML全文', '发布时间', '发布机构来源', '发布机构',
        '适用地区', '主题', '详情地址', '资源来源分类',
        '资源来源名称', '资源来源标识',
    )
    # Field separator of the input file, written as a regex: the literal
    # four-character sequence ^||^.
    _FIELD_SEPARATOR = r'\^\|\|\^'
    _TITLE_OPEN = '##title##'
    _TITLE_CLOSE = '##/title##'
    _TITLE_PATTERN = re.compile(r'##title##(.*?)##/title##')
    # NOTE: this is the six-character text backslash-u-3-0-0-0 as it appears
    # literally in the data, not the U+3000 character itself.
    _ESCAPED_WIDE_SPACE = '\\u3000'
    _LINE_BREAK_MARKER = '#####'

    @staticmethod
    def _clean_cell(value):
        """Strip one leading/trailing quote from a text cell and truncate it.

        Non-string values (for example NaN for an empty cell) and empty
        strings are returned unchanged.
        """
        if not isinstance(value, str) or not value:
            return value
        if value[0] in SPECIAL_SYMBOL:
            value = value[1:]
        # Re-check for emptiness: a cell consisting of a single quote is
        # already empty here and must not be indexed again.
        if value and value[-1] in SPECIAL_SYMBOL:
            value = value[:-1]
        return value[:MAX_CONTENT]

    @classmethod
    def _read_csv(cls, file_path, sep=',', encoding='utf-8') -> pd.DataFrame:
        """Read the file, clean the known text columns and drop empty rows.

        Rows without a value in ``HTML全文`` are removed. A ``KeyError`` is
        raised if one of the expected columns is missing from the file.
        """
        # Multi-character/regex separators are only supported by the Python
        # parsing engine; requesting it explicitly avoids the fallback warning.
        frame = pd.read_csv(file_path, sep=sep, encoding=encoding,
                            engine='python')
        for column in cls._TEXT_COLUMNS:
            frame[column] = frame[column].map(cls._clean_cell)
        return frame.dropna(subset=['HTML全文'])

    @staticmethod
    def _build_metadata(row, meta_path, doc_type, content, xml, tag) -> dict:
        """Assemble the metadata dict shared by title and content documents."""
        # The 'file_number' (发文文号) and 'keywords' (关键词) fields are
        # currently not exported.
        return {
            "path": meta_path,
            "_type": doc_type,
            'title': row['标题'],
            "content": content,
            'release_date': row['发布时间'],
            'sourceOrganization': row['发布机构来源'],
            'organization': row['发布机构'],
            'region': row['适用地区'],
            'subject': row['主题'],
            'source': row['详情地址'],
            'datasource': row['资源来源分类'],
            'datasourceclass': row['资源来源名称'],
            'datasource_key': row['资源来源标识'],
            "xml": xml,
            "tag": tag,
        }

    @classmethod
    def _strip_markup(cls, text: str) -> str:
        """Remove title markers and escaped wide spaces; expand line markers."""
        text = (text.replace(cls._TITLE_OPEN, '')
                    .replace(cls._TITLE_CLOSE, '')
                    .replace(cls._ESCAPED_WIDE_SPACE, ''))
        return text.replace(cls._LINE_BREAK_MARKER, '\n')

    @classmethod
    def _title_document(cls, row, meta_path, tag):
        """Build the "title" document for one row.

        The row title is the page content (the query side); the cleaned full
        text is stored in ``metadata["content"]`` (the answer side). Assumes
        the ``标题`` cell is a string.
        """
        raw_text = row['HTML全文']
        # Row title first, then every title embedded in the full text.
        titles = [row['标题']] + cls._TITLE_PATTERN.findall(raw_text)
        title_str = "\n".join(titles).replace(cls._ESCAPED_WIDE_SPACE, '')
        return Document(
            page_content=row['标题'],
            metadata=cls._build_metadata(
                row, meta_path, 'title',
                content=cls._strip_markup(raw_text),
                xml=title_str,
                tag=tag,
            ),
        )

    @classmethod
    def _content_document(cls, row, meta_path, tag):
        """Build the "content" document for one row.

        The page content is the text after the last ``##/title##`` marker, or
        the whole text when no marker is present.
        """
        body = row['HTML全文']
        marker_pos = body.rfind(cls._TITLE_CLOSE)
        if marker_pos != -1:
            # Skip the complete marker so no fragment of it is left behind.
            body = body[marker_pos + len(cls._TITLE_CLOSE):]
        return Document(
            page_content=cls._strip_markup(body),
            metadata=cls._build_metadata(
                row, meta_path, 'content', content='', xml='', tag=tag,
            ),
        )

    def _get_elements(self) -> List[Document]:
        """Parse ``self.file_path`` and return title and content documents.

        Returns:
            All title documents (one per row) followed by all content
            documents (one per row), in file order.
        """
        meta_path = os.path.basename(self.file_path)
        tag = "csv"

        frame = self._read_csv(self.file_path, self._FIELD_SEPARATOR)
        rows = [row for _, row in frame.iterrows()]

        # Title as query, full text as answer.
        title_docs = [self._title_document(row, meta_path, tag)
                      for row in rows]
        # Body text only.
        content_docs = [self._content_document(row, meta_path, tag)
                        for row in rows]
        return title_docs + content_docs

    def load(self) -> List[Document]:
        """Return the documents produced by ``_get_elements`` unchanged."""
        return self._get_elements()


if __name__ == '__main__':
    # Manual smoke test: load a local CSV file and print the parsed documents.
    loader = RapidOCRCSVLoader(file_path="//home/work/project/test_result.csv")
    docs = loader.load()

    # Print only the first 1001 documents.
    for doc in docs[:1001]:
        print(doc)