[全量] 初始化项目代码、配置、文档及Agent协同harness
This commit is contained in:
158
langchain-chat/document_loaders/mycsvloader.py
Normal file
158
langchain-chat/document_loaders/mycsvloader.py
Normal file
@@ -0,0 +1,158 @@
|
||||
from langchain_core.documents import Document
|
||||
from typing import List
|
||||
from langchain.document_loaders.unstructured import UnstructuredFileLoader
|
||||
# 导入模块
|
||||
import csv
|
||||
import os
|
||||
import re
|
||||
import json
|
||||
import pandas as pd
|
||||
|
||||
# TODO: 以CSV格式加载内容
|
||||
MAX_CONTENT = 18000
|
||||
SPECIAL_SYMBOL = ['\'', '"']
|
||||
class RapidOCRCSVLoader(UnstructuredFileLoader):
|
||||
|
||||
def _get_elements(self) -> List:
|
||||
meta_path = self.file_path.split("/")[-1]
|
||||
|
||||
# 读取csv文件
|
||||
def read_csv_file(file_path, sep = ',', encoding = 'utf-8'):
|
||||
df_data = pd.read_csv(file_path, sep = sep, encoding = encoding)
|
||||
df_data = df_data.apply(remove_duplicate_pandas, axis = 1)
|
||||
df_data = df_data.dropna(subset=['HTML全文'])
|
||||
|
||||
return df_data
|
||||
|
||||
def remove_duplicate_string(x):
|
||||
if not isinstance(x, str):
|
||||
return x
|
||||
if len(x) <= 0:
|
||||
return x
|
||||
|
||||
if x[0] in SPECIAL_SYMBOL:
|
||||
x = x[1:]
|
||||
|
||||
if x[-1] in SPECIAL_SYMBOL:
|
||||
x = x[:-1]
|
||||
|
||||
return x
|
||||
|
||||
def remove_duplicate_pandas(row):
|
||||
headers = ['标题', 'HTML全文','发布时间', '发布机构来源', '发布机构',
|
||||
'适用地区', '主题', '详情地址', '资源来源分类',
|
||||
'资源来源名称', '资源来源标识']
|
||||
|
||||
for col in headers:
|
||||
row[col] = remove_duplicate_string(row[col])
|
||||
|
||||
if len(row[col]) > MAX_CONTENT:
|
||||
row[col] = row[col][:MAX_CONTENT]
|
||||
|
||||
return row
|
||||
|
||||
def apply_title_content(x):
|
||||
# 处理 title相关的内容
|
||||
titles = [x['标题']]
|
||||
pattern = r'##title##(.*?)##/title##'
|
||||
title_list = re.findall(pattern, x['HTML全文'])
|
||||
titles += title_list
|
||||
|
||||
# 去除内容中标题的特殊符号
|
||||
content = x['HTML全文'].replace('##title##', '').replace('##/title##', '').replace('\\u3000', '')
|
||||
title_str = "\n".join(titles)
|
||||
tilte_str = title_str.replace('\\u3000', '')
|
||||
|
||||
content = content.replace('#####', '\n')
|
||||
|
||||
# 把内容相关提出
|
||||
title_idx = content.rfind("##/title##") + 1
|
||||
|
||||
return Document(page_content=x['标题'],
|
||||
metadata={"path": meta_path,
|
||||
"_type": 'title',
|
||||
'title' : x['标题'],
|
||||
"content": content ,
|
||||
# 'file_number': x['发文文号'],
|
||||
# 'keywords' : x['关键词'],
|
||||
'release_date' : x['发布时间'],
|
||||
'sourceOrganization' : x['发布机构来源'],
|
||||
'organization' : x['发布机构'],
|
||||
'region' : x['适用地区'],
|
||||
'subject' : x['主题'],
|
||||
'source' : x['详情地址'],
|
||||
'datasource' : x['资源来源分类'],
|
||||
'datasourceclass' : x['资源来源名称'],
|
||||
'datasource_key' : x['资源来源标识'],
|
||||
"xml": title_str,
|
||||
"tag": tag})
|
||||
|
||||
# title作为query,content作为answer
|
||||
def title_content2text(pd_obj, filepath, tag = 'csv'):
|
||||
ret_list = []
|
||||
pd_obj = pd_obj.apply(apply_title_content, axis=1)
|
||||
ret_list = pd_obj.values.tolist()
|
||||
return ret_list
|
||||
|
||||
def apply_content(x):
|
||||
# 去除内容中标题的特殊符号
|
||||
# 把内容相关提出
|
||||
content = x['HTML全文']
|
||||
title_idx = content.rfind("##/title##") + 1
|
||||
if title_idx > 0 and title_idx < len(content):
|
||||
content = content[title_idx + 1:]
|
||||
content = content.replace('##title##', '').replace('##/title##', '').replace('\\u3000', '')
|
||||
content = content.replace('#####', '\n')
|
||||
|
||||
# 对内容进行切分
|
||||
return Document(page_content=content,
|
||||
metadata={"path": meta_path,
|
||||
"_type": 'content',
|
||||
'title' : x['标题'],
|
||||
"content": '' ,
|
||||
# 'file_number': x['发文文号'],
|
||||
# 'keywords' : x['关键词'],
|
||||
'release_date' : x['发布时间'],
|
||||
'sourceOrganization' : x['发布机构来源'],
|
||||
'organization' : x['发布机构'],
|
||||
'region' : x['适用地区'],
|
||||
'subject' : x['主题'],
|
||||
'source' : x['详情地址'],
|
||||
'datasource' : x['资源来源分类'],
|
||||
'datasourceclass' : x['资源来源名称'],
|
||||
'datasource_key' : x['资源来源标识'],
|
||||
"xml": '',
|
||||
"tag": tag})
|
||||
|
||||
def content2text(pd_obj, filepath, tag = 'csv'):
|
||||
ret_list = []
|
||||
pd_obj = pd_obj.apply(apply_content, axis=1)
|
||||
ret_list = pd_obj.values.tolist()
|
||||
return ret_list
|
||||
|
||||
# 读取 csv 文件
|
||||
pd_obj = read_csv_file(self.file_path, r'\^\|\|\^')
|
||||
tag = "csv"
|
||||
|
||||
# 解析文件,获取 query(标题)-value(全文)
|
||||
ret_list_title = title_content2text(pd_obj, self.file_path, 'csv')
|
||||
# 解析文件,获取内容,并进行切割
|
||||
ret_list_content = content2text(pd_obj, self.file_path, 'csv')
|
||||
ret_list = ret_list_title + ret_list_content
|
||||
|
||||
return ret_list
|
||||
|
||||
def load(self) -> List[Document]:
|
||||
return self._get_elements()
|
||||
|
||||
if __name__ == '__main__':
|
||||
loader = RapidOCRCSVLoader(file_path="//home/work/project/test_result.csv")
|
||||
docs = loader.load()
|
||||
|
||||
i =0
|
||||
for doc in docs:
|
||||
|
||||
print(doc)
|
||||
i += 1
|
||||
if i>1000:
|
||||
break
|
||||
Reference in New Issue
Block a user