[全量] 初始化项目代码、配置、文档及Agent协同harness
This commit is contained in:
@@ -0,0 +1,74 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Any, AsyncIterator, Dict, List, Literal, Union, cast
|
||||
|
||||
from langchain_core.outputs import LLMResult
|
||||
|
||||
from langchain.callbacks.base import AsyncCallbackHandler
|
||||
|
||||
# TODO If used by two LLM runs in parallel this won't work as expected
|
||||
|
||||
|
||||
class AsyncIteratorCallbackHandler(AsyncCallbackHandler):
    """Callback handler that returns an async iterator.

    Tokens passed to ``on_llm_new_token`` are buffered in ``queue``; ``done``
    is set by ``on_llm_end`` / ``on_llm_error`` and cleared by
    ``on_llm_start``. ``aiter`` yields the buffered tokens until the run is
    finished and the queue is empty.
    """

    # Tokens received so far that ``aiter`` has not yielded yet.
    queue: asyncio.Queue[str]

    # Set when the LLM run has ended or failed; cleared when a run starts.
    done: asyncio.Event

    @property
    def always_verbose(self) -> bool:
        """Return ``True`` unconditionally."""
        return True

    def __init__(self) -> None:
        """Create an empty token queue and an unset completion event."""
        self.queue = asyncio.Queue()
        self.done = asyncio.Event()

    async def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Clear the completion event; the token queue is left untouched."""
        # If two calls are made in a row, this resets the state
        self.done.clear()

    async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Queue ``token`` unless it is ``None`` or an empty string."""
        if token is not None and token != "":
            self.queue.put_nowait(token)

    async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Mark the run as finished."""
        self.done.set()

    async def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
        """Mark the run as finished; the error itself is not recorded here."""
        self.done.set()

    # TODO implement the other methods

    def get_queue(self):
        """Return the object stored in the queue's ``_queue`` attribute.

        NOTE(review): ``_queue`` is a private attribute of ``asyncio.Queue``;
        callers presumably only inspect it -- confirm before relying on it.
        """
        return self.queue._queue

    async def aiter(self) -> AsyncIterator[str]:
        """Yield queued tokens until the run is done and the queue is empty.

        Each round waits for either the next token or the completion event.
        A token that arrives in the same round as the completion event is
        still yielded, and any tokens left in the queue are flushed before
        the iterator stops, so no output is dropped at the end of a run.
        """
        while not self.queue.empty() or not self.done.is_set():
            # Keep explicit handles on both waiters so the result can be
            # identified by task rather than by popping from an unordered set.
            token_task = asyncio.ensure_future(self.queue.get())
            done_task = asyncio.ensure_future(self.done.wait())
            finished, pending = await asyncio.wait(
                [token_task, done_task],
                return_when=asyncio.FIRST_COMPLETED,
            )

            # Cancel whichever waiter did not complete.
            for leftover in pending:
                leftover.cancel()

            if token_task in finished:
                # A token always wins over the done signal; the loop
                # condition decides afterwards whether to keep going.
                yield token_task.result()
                continue

            # Only the done event fired: flush what is still queued, then stop.
            while not self.queue.empty():
                yield self.queue.get_nowait()
            break
|
||||
0
langchain-chat/server/custom/__init__.py
Normal file
0
langchain-chat/server/custom/__init__.py
Normal file
120
langchain-chat/server/custom/abstract_search.py
Normal file
120
langchain-chat/server/custom/abstract_search.py
Normal file
@@ -0,0 +1,120 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from typing import AsyncIterable, List, Optional
|
||||
from fastapi import Body, Request, HTTPException
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
from server.knowledge_base.kb_doc_api import search_docs
|
||||
from configs import (
|
||||
TEMPERATURE,
|
||||
VECTOR_SEARCH_TOP_K,
|
||||
SCORE_THRESHOLD,
|
||||
LLM_MODELS,
|
||||
JOURNAL_KNOWLEDGE_BASE,
|
||||
MAX_TOKENS,
|
||||
)
|
||||
from server.utils import BaseResponse, get_prompt_template
|
||||
from server.chat.policy_fun_iast import get_llm_model_response
|
||||
from server.knowledge_base.kb_service.base import KBServiceFactory
|
||||
|
||||
def _parse_rewritten_query(raw_response, fallback):
    """Extract the rewritten search query from the LLM's JSON answer.

    The answer is expected to be a JSON object with a ``query`` entry (a list
    of strings or a single string), optionally wrapped in a Markdown code
    fence. ``fallback`` is returned whenever the answer is not a string, is
    not valid JSON, is not an object, or yields an empty query.
    """
    if not isinstance(raw_response, str):
        return fallback

    text = raw_response.strip()
    # Remove the fence as a prefix/suffix. ``str.strip(chars)`` treats its
    # argument as a set of characters, so it cannot be used for this.
    if text.startswith("```"):
        text = text[3:]
        if text[:4].lower() == "json":
            text = text[4:]
    text = text.rstrip()
    if text.endswith("```"):
        text = text[:-3]

    try:
        data = json.loads(text)
    except json.JSONDecodeError:
        return fallback

    queries = data.get('query') if isinstance(data, dict) else None
    if isinstance(queries, str):
        rewritten = queries.strip()
    elif isinstance(queries, (list, tuple)):
        rewritten = ' '.join(str(item) for item in queries if item).strip()
    else:
        rewritten = ''
    return rewritten or fallback


def _trim_to_complete_sentence(content):
    """Cut ``content`` after its last '。', or append '...' if it has none."""
    if not content or content.endswith('。'):
        return content
    last_period_index = content.rfind('。')
    if last_period_index != -1:
        return content[:last_period_index + 1]
    return content + '...'


async def abstract_search(
        query: str = Body(..., description="用户输入", examples=["人工智能领域相关的文献有啥"]),
        knowledge_base_name_list: List[str] = Body(..., description="知识库名称", examples=[["t_journal_article_bge_v1"]]),
) -> EventSourceResponse:
    """Rewrite the user query with an LLM and return matching journal abstracts.

    Args:
        query: The user's input text.
        knowledge_base_name_list: Accepted in the request body but not used
            here; the search always targets ``JOURNAL_KNOWLEDGE_BASE``.

    Returns:
        An ``EventSourceResponse`` that emits one JSON string: either
        ``{"journal_article": [{"id", "title", "content"}, ...]}`` or, if
        anything fails, a serialized ``BaseResponse`` with ``code=500``.

    The remaining search settings (top_k, score threshold, model, sampling)
    are not request parameters; the values are fixed in the calls below.
    """

    async def search_iterator() -> AsyncIterable[str]:
        try:
            # Ask the LLM to rewrite the query.
            search_query_response = await run_in_threadpool(
                get_llm_model_response,
                strategy_name="query rewrite",
                llm_model_name=LLM_MODELS[0],
                template_prompt_name="query_rewrite",
                prompt_param_dict={"query": query, "history": []},
                temperature=0.01,
                max_tokens=512
            )

            # Fall back to the original query if the rewrite is unusable.
            search_query = _parse_rewritten_query(search_query_response, query)
            logging.info(f'段落推荐输入问题: {search_query}')

            # Run the document search (journal knowledge base only).
            journaldocs = await run_in_threadpool(
                search_docs,
                query=search_query,
                fileName=[],
                knowledge_base_name=JOURNAL_KNOWLEDGE_BASE,
                top_k=VECTOR_SEARCH_TOP_K,
                score_threshold=SCORE_THRESHOLD
            )
            logging.info(f"段落推荐召回资料: {journaldocs}")

            # De-duplicate by document ID and tidy up the abstracts.
            seen_ids = set()
            journallist = []
            for doc in journaldocs:
                doc_id = doc.metadata.get('ID')
                # ``or ''`` guards against an explicit None abstract.
                content = _trim_to_complete_sentence(
                    (doc.metadata.get('abstract') or '').strip()
                )
                if doc_id and content and doc_id not in seen_ids:
                    journallist.append({
                        'id': doc_id,
                        'title': doc.metadata.get('title', ''),
                        'content': content
                    })
                    seen_ids.add(doc_id)

            # When nothing matched, the list is simply empty; no placeholder
            # entry is added.
            search_results = {'journal_article': journallist}
            logging.info(f"段落推荐召回资料整理结果:{journallist}")

            yield json.dumps(search_results, ensure_ascii=False)

        except Exception as e:
            # Top-level boundary: log the failure and report it to the client.
            logging.exception("abstract_search failed")
            error_response = BaseResponse(code=500, msg=str(e))
            yield json.dumps(error_response.dict(), ensure_ascii=False)

    return EventSourceResponse(search_iterator())
|
||||
193
langchain-chat/server/custom/article_overview.py
Normal file
193
langchain-chat/server/custom/article_overview.py
Normal file
@@ -0,0 +1,193 @@
|
||||
import asyncio
|
||||
import json
|
||||
from typing import AsyncIterable, List, Optional
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from fastapi import Body, Request
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
from langchain.callbacks import AsyncIteratorCallbackHandler
|
||||
from langchain.chains import LLMChain
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain.prompts.chat import ChatPromptTemplate
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
|
||||
from configs import (TEMPERATURE,
|
||||
USE_RERANKER,
|
||||
RERANKER_MODEL,
|
||||
RERANKER_MAX_LENGTH,
|
||||
LLM_MODELS,
|
||||
MODEL_PATH,
|
||||
MAX_TOKENS)
|
||||
from server.chat.utils import History
|
||||
from server.knowledge_base.kb_service.base import KBServiceFactory
|
||||
from server.reranker.reranker import LangchainReranker
|
||||
from server.utils import BaseResponse, get_prompt_template
|
||||
from server.utils import embedding_device
|
||||
from server.utils import wrap_done, get_ChatOpenAI
|
||||
|
||||
|
||||
async def article_overview(query: str = Body("你好", description="用户输入", examples=["你好"]),
                           knowledge_base_name: str = Body(..., description="知识库名称", examples=["t_policy_total_bce_v1"]),
                           stream: bool = Body(False, description="流式输出"),
                           model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
                           temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
                           max_tokens: Optional[int] = Body(
                               MAX_TOKENS,
                               description="限制LLM生成Token数量,默认None代表模型最大值"
                           ),
                           prompt_name: str = Body(
                               "Article Overview",
                               description="使用的prompt模板名称(在configs/prompt_config.py中配置)"
                           ),
                           source_name_list: List[str] = Body([], description="资源列表"),
                           request: Request = None,
                           ):
    """Summarise the documents named in ``source_name_list`` with an LLM.

    Args:
        query: Accepted in the request body but replaced below by a fixed
            "summarise these files" instruction built from ``source_name_list``.
        knowledge_base_name: Knowledge base holding the documents.
        stream: When true, emit one SSE message per token and a final
            message with the sources; otherwise emit a single message.
        model_name: LLM to use.
        temperature: Sampling temperature.
        max_tokens: Generation limit; non-positive integers mean no limit.
        prompt_name: Replaced below: "Article Overview2" for more than one
            source, "Article Overview" otherwise.
        source_name_list: Names of the documents to summarise.
        request: Incoming request; its ``base_url`` is used to build download
            links for non-policy knowledge bases.

    Returns:
        A ``BaseResponse`` with code 404 if the knowledge base is unknown,
        otherwise an ``EventSourceResponse`` streaming JSON strings.
    """
    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
    if kb is None:
        return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
    query = "帮我对以下文件进行总结 :" + ",".join(source_name_list)
    if len(source_name_list) > 1:
        prompt_name = "Article Overview2"
    else:
        prompt_name = "Article Overview"

    # The policy knowledge base stores title/content pairs and links to an
    # external detail URL; every other knowledge base links to a download URL.
    is_policy_kb = knowledge_base_name == 't_policy_total_bce_v1'

    async def article_overview_iterator(
            query: str,
            model_name: str = model_name,
            prompt_name: str = prompt_name,
    ) -> AsyncIterable[str]:
        nonlocal max_tokens
        callback = AsyncIteratorCallbackHandler()
        if isinstance(max_tokens, int) and max_tokens <= 0:
            max_tokens = None

        model = get_ChatOpenAI(
            model_name=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
            callbacks=[callback],
        )
        docs = await run_in_threadpool(kb.get_doc_by_sources_name,
                                       source_name_list=source_name_list)

        # Optional rerank step.
        # NOTE(review): the reranker is built and run synchronously inside
        # this coroutine -- confirm whether it should move to a thread pool.
        if USE_RERANKER:
            reranker_model_path = MODEL_PATH["reranker"].get(RERANKER_MODEL, "BAAI/bge-reranker-large")
            print("-----------------model path------------------")
            print(reranker_model_path)
            reranker_model = LangchainReranker(top_n=3,
                                               device=embedding_device(),
                                               max_length=RERANKER_MAX_LENGTH,
                                               model_name_or_path=reranker_model_path
                                               )
            print("---------before rerank------------------")
            print(docs)
            docs = reranker_model.compress_documents(documents=docs,
                                                     query=query)
            print("---------after rerank------------------")
            print(docs)

        # Build the context: merge title and content for the policy base.
        if is_policy_kb:
            knowledge = []
            for doc in docs:
                # ``.get`` keeps one incomplete document from aborting the
                # whole response with a KeyError.
                doc_type = doc.metadata.get("_type")
                if doc_type == "title":
                    knowledge.append(doc.page_content + "\n" + (doc.metadata.get('content') or ''))
                elif doc_type == "content":
                    knowledge.append((doc.metadata.get('title') or '') + "\n" + doc.page_content)
            context = "\n\n".join(knowledge)
        else:
            context = "\n".join([doc.page_content for doc in docs])
        print(f"context:{context}\n")

        if len(docs) == 0:  # nothing found: use the "empty" template
            prompt_template = get_prompt_template("knowledge_base_chat", "empty")
        else:
            prompt_template = get_prompt_template("knowledge_base_chat", prompt_name)
        input_msg = History(role="user", content=prompt_template).to_msg_template(False)
        chat_prompt = ChatPromptTemplate.from_messages([input_msg])
        print(f"chat_prompt:{chat_prompt}\n")
        chain = LLMChain(prompt=chat_prompt, llm=model)

        # Begin a task that runs in the background.
        task = asyncio.create_task(wrap_done(
            chain.acall({"context": context, "question": query}),
            callback.done),
        )

        try:
            source_documents = []
            if is_policy_kb:
                for doc in docs:
                    # Title and detail-page URL of the policy document.
                    filename = doc.metadata.get("title")
                    detail_url = doc.metadata.get("source")
                    label = filename or "原文地址"
                    source_documents.append(f"""出处: [{label}]({detail_url}) \n\n""")
            else:
                # ``request`` defaults to None; avoid an AttributeError when
                # the function is called without one.
                base_url = request.base_url if request is not None else ""
                for doc in docs:
                    filename = doc.metadata.get("source")
                    parameters = urlencode({"knowledge_base_name": knowledge_base_name, "file_name": filename})
                    url = f"{base_url}knowledge_base/download_doc?" + parameters
                    label = filename or "原文地址"
                    source_documents.append(f"""出处: [{label}]({url}) \n\n""")

            if len(source_documents) == 0:  # no documents were found
                source_documents.append("<span style='color:red'>未找到相关文档,该回答为大模型自身能力解答!</span>")

            if stream:
                async for token in callback.aiter():
                    # Use server-sent-events to stream the response
                    yield json.dumps({"answer": token}, ensure_ascii=False)
                yield json.dumps({"docs": source_documents}, ensure_ascii=False)
            else:
                answer = "".join([token async for token in callback.aiter()])
                yield json.dumps({"answer": answer,
                                  "docs": source_documents},
                                 ensure_ascii=False)
            await task
        finally:
            # If the stream is closed early or an error occurs, do not leave
            # the LLM call running in the background.
            if not task.done():
                task.cancel()

    return EventSourceResponse(article_overview_iterator(query, model_name=model_name, prompt_name=prompt_name))
|
||||
|
||||
|
||||
class ArticleOverview:
    """Convenience wrapper that builds a file-summary request.

    ``PROMPT`` is a jinja2-style prompt with ``context`` and ``question``
    placeholders; ``query_out`` forwards a summarisation request to
    ``article_overview``.
    """

    query = "请给我对文件进行一下总结"

    def __init__(self):
        # Plain concatenated text. The previous version wrapped already
        # quoted fragments in a triple-quoted string, so the quotes and a
        # trailing comma ended up inside the prompt itself.
        self._PROMPT_TEMPLATE = (
            '<角色> 你是由浪潮开发的知冶大模型中所选定的文件综述助手。</角色> \n\n'
            'Your task is to write a detailed summary of the provided {{ context }} file. Ensure that your summary is '
            'longer than 300 words and captures the essence of the content. Focus on the main points, '
            'key findings, and any important implications or conclusions. Maintain an unbiased tone and avoid relying '
            'on stereotypes. Organize the summary in a clear and coherent manner, using appropriate headings or '
            'bullet points if necessary. Remember to keep the summary concise while preserving the core information. '
            'Let\'s start with a brief overview of the file\'s main topic and then delve into the specifics. '
            'PLEASE ALWAYS RESPOND IN CHINESE!\n'
            '<已知信息>{{ context }}</已知信息>\n'
            '<问题>{{ question }}</问题>\n'
        )
        # The placeholders use ``{{ name }}`` syntax, so the template must be
        # declared as jinja2 and list the variables it actually contains.
        self.PROMPT = PromptTemplate(
            input_variables=["context", "question"],
            template=self._PROMPT_TEMPLATE,
            template_format="jinja2",
        )

    def query_out(self, knowledge_base_name: str, source_name_list: list):
        """Start a non-streaming summary of ``source_name_list``.

        Args:
            knowledge_base_name: Knowledge base holding the documents.
            source_name_list: Names of the documents to summarise.

        Returns:
            The coroutine returned by ``article_overview``; the caller must
            await it to obtain the response.
        """
        self.query = "帮我对以下文件进行总结 :" + ",".join(source_name_list)
        # Pass every body parameter explicitly: when the endpoint function is
        # called directly, its defaults are FastAPI ``Body(...)`` objects
        # rather than the values they describe.
        return article_overview(self.query,
                                knowledge_base_name=knowledge_base_name,
                                stream=False,
                                model_name=LLM_MODELS[0],
                                temperature=TEMPERATURE,
                                max_tokens=MAX_TOKENS,
                                prompt_name="Article Overview",
                                source_name_list=source_name_list
                                )
|
||||
113
langchain-chat/server/custom/chapter_overview.py
Normal file
113
langchain-chat/server/custom/chapter_overview.py
Normal file
@@ -0,0 +1,113 @@
|
||||
import asyncio
|
||||
import json
|
||||
from typing import AsyncIterable, List, Optional
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from fastapi import Body, Request
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
from langchain.callbacks import AsyncIteratorCallbackHandler
|
||||
from langchain.chains import LLMChain
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain.prompts.chat import ChatPromptTemplate
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
|
||||
from configs import (TEMPERATURE,
|
||||
USE_RERANKER,
|
||||
RERANKER_MODEL,
|
||||
RERANKER_MAX_LENGTH,
|
||||
MODEL_PATH,
|
||||
MAX_TOKENS,
|
||||
MAX_CUT_TOKENS, LLM_MODELS)
|
||||
from server.chat.utils import History
|
||||
from server.knowledge_base.kb_service.base import KBServiceFactory
|
||||
from server.reranker.reranker import LangchainReranker
|
||||
from server.utils import BaseResponse, get_prompt_template
|
||||
from server.utils import embedding_device
|
||||
from server.utils import wrap_done, get_ChatOpenAI
|
||||
from collections import defaultdict
|
||||
from server.custom.custom_fun import chapter_overview_summary
|
||||
|
||||
async def task(param):
    """Summarise one chunk of text and yield the answer token by token.

    Args:
        param: Mapping with the keys ``contextk`` (text to summarise), ``i``
            (index, used only in the debug print), ``model_name``,
            ``temperature``, ``max_tokens`` and ``chat_prompt``.

    Yields:
        The tokens produced by the LLM for this chunk.
    """
    contextk = param["contextk"]
    i = param["i"]
    model_name = param["model_name"]
    temperature = param["temperature"]
    max_tokens = param["max_tokens"]
    chat_prompt = param["chat_prompt"]
    print(f"i:{i},len_context:{len(contextk)}\n")
    callback_temp = AsyncIteratorCallbackHandler()
    model_temp = get_ChatOpenAI(
        model_name=model_name,
        temperature=temperature,
        max_tokens=max_tokens,
        callbacks=[callback_temp],
    )
    chain_temp = LLMChain(prompt=chat_prompt, llm=model_temp)

    # Run the chain in the background and consume tokens while it produces
    # them. Awaiting the chain first would start iteration with the done
    # event already set and every token still queued, which is the case
    # where the handler may stop before delivering them.
    task_temp = asyncio.create_task(wrap_done(
        chain_temp.acall({"context": contextk,
                          "question": "对该部分内容进行总结"}),
        callback_temp.done))

    try:
        async for token in callback_temp.aiter():
            yield token
        await task_temp
    finally:
        # Do not leave the LLM call running if the consumer stops early.
        if not task_temp.done():
            task_temp.cancel()
|
||||
|
||||
|
||||
# Run the per-chunk summary tasks concurrently on the event loop (no threads).
async def run_tasks_concurrently(params):
    """Run ``task`` for every entry of ``params`` concurrently.

    Args:
        params: Iterable (or async iterable) of parameter mappings, one per
            call to ``task``.

    Returns:
        A list with the complete text produced for each entry, in the same
        order as ``params``.
    """
    # Accept an async iterable as well as a plain one.
    if hasattr(params, "__aiter__"):
        params = [param async for param in params]

    async def _collect(param):
        # ``task`` is an async generator, so it has to be drained inside a
        # coroutine before it can be scheduled with ``gather``.
        return ''.join([token async for token in task(param)])

    return list(await asyncio.gather(*(_collect(param) for param in params)))
|
||||
|
||||
|
||||
async def chapter_overview(query: str = Body("为我总结这些内容", description="用户输入", examples=["你好"]),
                           knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
                           stream: bool = Body(False, description="流式输出"),
                           model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
                           temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
                           max_tokens: Optional[int] = Body(
                               MAX_TOKENS,
                               description="限制LLM生成Token数量,默认None代表模型最大值"
                           ),
                           prompt_name: str = Body(
                               "Chapter Overview",
                               description="使用的prompt模板名称(在configs/prompt_config.py中配置)"
                           ),
                           source_name_list: List[str] = Body([], description="资源列表"),
                           request: Request = None,
                           ):
    """Return per-chapter summaries and a global summary for the given sources.

    ``query``, ``prompt_name`` and ``request`` are accepted but not used in
    this function. An unknown knowledge base yields a ``BaseResponse`` with
    code 404; otherwise an ``EventSourceResponse`` is returned. With
    ``stream`` enabled, it emits one message per chapter summary followed by
    the global summary. Without it, it emits a single message with both.
    """
    kb_service = KBServiceFactory.get_service_by_name(knowledge_base_name)
    if kb_service is None:
        return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")

    async def _overview_events(
            model_name: str = model_name,
    ) -> AsyncIterable[str]:
        nonlocal max_tokens
        # A non-positive integer limit means "no limit".
        if isinstance(max_tokens, int) and max_tokens <= 0:
            max_tokens = None

        documents = await run_in_threadpool(
            kb_service.get_doc_by_sources_name,
            source_name_list=source_name_list,
        )
        per_chapter, overall = await chapter_overview_summary(
            documents, model_name, temperature, max_tokens
        )

        if not stream:
            # Single message carrying everything.
            payload = {
                "chapter_summaries": per_chapter,
                "global_summary": overall,
            }
            yield json.dumps(payload, ensure_ascii=False)
            return

        # Streaming: one message per chapter summary, global summary last.
        for heading, chapter_texts in per_chapter.items():
            for chapter_text in chapter_texts:
                message = {"chapter_title": heading, "summary": chapter_text}
                yield json.dumps(message, ensure_ascii=False)
        yield json.dumps({"global_summary": overall}, ensure_ascii=False)

    return EventSourceResponse(_overview_events(model_name=model_name))
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
Binary file not shown.
74
langchain-chat/server/custom/custom_fun.py
Normal file
74
langchain-chat/server/custom/custom_fun.py
Normal file
@@ -0,0 +1,74 @@
|
||||
"""
|
||||
Custom functions for document processing.
|
||||
"""
|
||||
|
||||
|
||||
async def chapter_overview_summary(docs, model_name, temperature, max_tokens):
    """Placeholder that builds chapter summaries straight from document text.

    No model is called; ``model_name``, ``temperature`` and ``max_tokens``
    are accepted but unused.

    Args:
        docs: Sequence of documents. Items with ``metadata`` / ``page_content``
            attributes use them; anything else is converted with ``str``.
        model_name: Model name (unused).
        temperature: Sampling temperature (unused).
        max_tokens: Token limit (unused).

    Returns:
        A ``(chapter_summaries, global_summary)`` tuple. The first element
        maps each ``head1`` heading to the first 500 characters of each of
        its documents. The second is the first 1000 characters of the first
        three documents joined by spaces.
    """
    default_heading = "默认章节"

    def _text_of(item):
        # Use the document text when available, else the item's string form.
        if hasattr(item, 'page_content'):
            return item.page_content
        return str(item)

    grouped = {}
    for item in docs:
        heading = default_heading
        if hasattr(item, 'metadata'):
            heading = item.metadata.get("head1", default_heading)
        # Slicing already leaves shorter texts unchanged.
        grouped.setdefault(heading, []).append(_text_of(item)[:500])

    overall = ""
    if docs:
        # Global summary: only the first three documents contribute.
        overall = " ".join(_text_of(item) for item in docs[:3])[:1000]

    return grouped, overall
|
||||
|
||||
|
||||
async def chi_translation(docs, prompt_template, model_name, temperature, max_tokens):
    """Placeholder for Chinese translation: yields each document unchanged.

    No translation is performed; ``prompt_template``, ``model_name``,
    ``temperature`` and ``max_tokens`` are accepted but unused.

    Args:
        docs: Iterable of documents.
        prompt_template: Prompt template (unused).
        model_name: Model name (unused).
        temperature: Sampling temperature (unused).
        max_tokens: Token limit (unused).

    Yields:
        ``page_content`` of each item that has that attribute, otherwise
        ``str(item)``.
    """
    for item in docs:
        if hasattr(item, 'page_content'):
            yield item.page_content
        else:
            yield str(item)
|
||||
|
||||
|
||||
async def eng_translation(docs, prompt_template, model_name, temperature, max_tokens):
    """Placeholder for English translation: yields each document unchanged.

    No translation is performed; ``prompt_template``, ``model_name``,
    ``temperature`` and ``max_tokens`` are accepted but unused.

    Args:
        docs: Iterable of documents.
        prompt_template: Prompt template (unused).
        model_name: Model name (unused).
        temperature: Sampling temperature (unused).
        max_tokens: Token limit (unused).

    Yields:
        ``page_content`` of each item that has that attribute, otherwise
        ``str(item)``.
    """
    for item in docs:
        if hasattr(item, 'page_content'):
            yield item.page_content
        else:
            yield str(item)
|
||||
78
langchain-chat/server/custom/paper_translation.py
Normal file
78
langchain-chat/server/custom/paper_translation.py
Normal file
@@ -0,0 +1,78 @@
|
||||
import asyncio
|
||||
import json
|
||||
from typing import AsyncIterable, List, Optional
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from fastapi import Body, Request
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
from langchain.callbacks import AsyncIteratorCallbackHandler
|
||||
from langchain.chains import LLMChain
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain.prompts.chat import ChatPromptTemplate
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
|
||||
from configs import (TEMPERATURE,
|
||||
USE_RERANKER,
|
||||
RERANKER_MODEL,
|
||||
RERANKER_MAX_LENGTH,
|
||||
MODEL_PATH,
|
||||
MAX_TOKENS,
|
||||
MAX_CUT_TOKENS, LLM_MODELS)
|
||||
from server.chat.utils import History
|
||||
from server.knowledge_base.kb_service.base import KBServiceFactory
|
||||
from server.reranker.reranker import LangchainReranker
|
||||
from server.utils import BaseResponse, get_prompt_template
|
||||
from server.utils import embedding_device
|
||||
from server.utils import wrap_done, get_ChatOpenAI
|
||||
from collections import defaultdict
|
||||
from server.custom.custom_fun import chi_translation,eng_translation
|
||||
|
||||
async def paper_translation(query: str = Body("为我总结这些内容", description="用户输入", examples=["你好"]),
                            knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
                            stream: bool = Body(False, description="流式输出"),
                            model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
                            temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
                            max_tokens: Optional[int] = Body(
                                MAX_TOKENS,
                                description="限制LLM生成Token数量,默认None代表模型最大值"
                            ),
                            prompt_name: str = Body(
                                "eng_chi",
                                description="使用的prompt模板名称(在configs/prompt_config.py中配置)"
                            ),
                            source_name_list: List[str] = Body([], description="资源列表"),
                            request: Request = None,
                            ):
    """Translate the documents named in ``source_name_list``.

    ``prompt_name`` selects the direction: ``"chi_eng"`` uses
    ``chi_translation``, anything else uses ``eng_translation``. ``query``,
    ``stream`` and ``request`` are accepted but not used in this function.

    Returns:
        A ``BaseResponse`` with code 404 if the knowledge base is unknown,
        otherwise an ``EventSourceResponse`` that emits each chunk produced
        by the selected translation function as-is.
    """
    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
    if kb is None:
        return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")

    async def translation_iterator(
            query: str,
            model_name: str = model_name,
            prompt_name: str = prompt_name,
    ) -> AsyncIterable[str]:
        nonlocal max_tokens
        # A non-positive integer limit means "no limit".
        if isinstance(max_tokens, int) and max_tokens <= 0:
            max_tokens = None

        docs = await run_in_threadpool(kb.get_doc_by_sources_name,
                                       source_name_list=source_name_list)

        if len(docs) == 0:  # nothing found: use the "empty" template
            prompt_template = get_prompt_template("knowledge_base_chat", "empty")
        else:
            prompt_template = get_prompt_template("knowledge_base_chat", prompt_name)

        # Pick the translation function once, then forward its chunks.
        translate = chi_translation if prompt_name == "chi_eng" else eng_translation
        async for chunk in translate(docs, prompt_template, model_name, temperature, max_tokens):
            yield chunk

    return EventSourceResponse(translation_iterator(query, model_name=model_name, prompt_name=prompt_name))
|
||||
Reference in New Issue
Block a user