[全量] 初始化项目代码、配置、文档及Agent协同harness

This commit is contained in:
2026-04-02 11:36:05 +08:00
parent 0553309cdf
commit 87e571d9ec
1133 changed files with 221948 additions and 0 deletions

View File

@@ -0,0 +1,74 @@
from __future__ import annotations
import asyncio
from typing import Any, AsyncIterator, Dict, List, Literal, Union, cast
from langchain_core.outputs import LLMResult
from langchain.callbacks.base import AsyncCallbackHandler
# TODO If used by two LLM runs in parallel this won't work as expected
class AsyncIteratorCallbackHandler(AsyncCallbackHandler):
"""Callback handler that returns an async iterator."""
queue: asyncio.Queue[str]
done: asyncio.Event
@property
def always_verbose(self) -> bool:
return True
def __init__(self) -> None:
self.queue = asyncio.Queue()
self.done = asyncio.Event()
async def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
# If two calls are made in a row, this resets the state
self.done.clear()
async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
if token is not None and token != "":
self.queue.put_nowait(token)
async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
self.done.set()
async def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
self.done.set()
# TODO implement the other methods
def get_queue(self):
return self.queue._queue
async def aiter(self) -> AsyncIterator[str]:
while not self.queue.empty() or not self.done.is_set():
# Wait for the next token in the queue,
# but stop waiting if the done event is set
done, other = await asyncio.wait(
[
# NOTE: If you add other tasks here, update the code below,
# which assumes each set has exactly one task each
asyncio.ensure_future(self.queue.get()),
asyncio.ensure_future(self.done.wait()),
],
return_when=asyncio.FIRST_COMPLETED,
)
# Cancel the other task
if other:
other.pop().cancel()
# Extract the value of the first completed task
token_or_done = cast(Union[str, Literal[True]], done.pop().result())
# If the extracted value is the boolean True, the done event was set
if token_or_done is True:
break
# Otherwise, the extracted value is a token, which we yield
yield token_or_done

View File

View File

@@ -0,0 +1,120 @@
import asyncio
import json
import logging
from typing import AsyncIterable, List, Optional
from fastapi import Body, Request, HTTPException
from fastapi.concurrency import run_in_threadpool
from sse_starlette.sse import EventSourceResponse
from server.knowledge_base.kb_doc_api import search_docs
from configs import (
TEMPERATURE,
VECTOR_SEARCH_TOP_K,
SCORE_THRESHOLD,
LLM_MODELS,
JOURNAL_KNOWLEDGE_BASE,
MAX_TOKENS,
)
from server.utils import BaseResponse, get_prompt_template
from server.chat.policy_fun_iast import get_llm_model_response
from server.knowledge_base.kb_service.base import KBServiceFactory
async def abstract_search(
query: str = Body(..., description="用户输入", examples=["人工智能领域相关的文献有啥"]),
knowledge_base_name_list: List[str] = Body(..., description="知识库名称", examples=[["t_journal_article_bge_v1"]]),
# top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数", ge=1),
# fileName: List[str] = Body([], description="文件名称", examples=[["123.txt"]]),
# stream: bool = Body(False, description="流式输出"),
# model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
# temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
# max_tokens: Optional[int] = Body(MAX_TOKENS,description="限制LLM生成Token数量默认None代表模型最大值"),
# score_threshold: float = Body(SCORE_THRESHOLD,
# description=(
# "知识库匹配相关度阈值取值范围在0-1之间"
# "SCORE越小相关度越高取到1相当于不筛选建议设置在0.5左右"
# ),
# ge=0.0,
# le=1.0
# ),
# request: Request = None,
# query_rewrite_model_name: str = Body(
# 'Qwen2-72B-Instruct',
# description="Query rewrite model name."
# ),
) -> EventSourceResponse:
async def search_iterator() -> AsyncIterable[str]:
try:
# 重写查询
search_query_response = await run_in_threadpool(
get_llm_model_response,
strategy_name="query rewrite",
llm_model_name=LLM_MODELS[0],
template_prompt_name="query_rewrite",
prompt_param_dict={"query": query, "history": []},
temperature=0.01,
max_tokens=512
)
# 解析重写后的查询
try:
data = json.loads(search_query_response.strip("```json\n").strip("```"))
queries = data.get('query', [])
search_query = ' '.join(queries)
except json.JSONDecodeError:
search_query = query
logging.info(f'段落推荐输入问题: {search_query}')
# 执行文档搜索(这里只搜索期刊知识库)
journaldocs = await run_in_threadpool(
search_docs,
query=search_query,
fileName=[],
knowledge_base_name=JOURNAL_KNOWLEDGE_BASE,
top_k=VECTOR_SEARCH_TOP_K,
score_threshold=SCORE_THRESHOLD
)
logging.info(f"段落推荐召回资料: {journaldocs}")
# 处理搜索结果,去重并整理
seen_ids = set()
journallist = []
# journaldocs = []
for doc in journaldocs:
doc_id = doc.metadata.get('ID')
content = doc.metadata.get('abstract', '').strip()
#处理content不是完整段落的情况
if content:
if not content.endswith(''):
last_period_index = content.rfind('')
if last_period_index != -1:
content = content[:last_period_index + 1]
else:
content += '...'
if doc_id and content and doc_id not in seen_ids:
journallist.append({
'id': doc_id,
'title': doc.metadata.get('title', ''),
'content': content
})
seen_ids.add(doc_id)
search_results = {'journal_article': journallist}
logging.info(f"段落推荐召回资料整理结果:{journallist}")
# 如果没有找到相关文档,返回默认响应
# if not journallist:
# search_results['journal_article'].append({
# 'id': 'empty',
# 'title': 'empty',
# 'content': '暂未找到相关资料。'
# })
yield json.dumps(search_results, ensure_ascii=False)
except Exception as e:
# 捕获并返回错误信息
error_response = BaseResponse(code=500, msg=str(e))
yield json.dumps(error_response.dict(), ensure_ascii=False)
return EventSourceResponse(search_iterator())

View File

@@ -0,0 +1,193 @@
import asyncio
import json
from typing import AsyncIterable, List, Optional
from urllib.parse import urlencode
from fastapi import Body, Request
from fastapi.concurrency import run_in_threadpool
from langchain.callbacks import AsyncIteratorCallbackHandler
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain.prompts.chat import ChatPromptTemplate
from sse_starlette.sse import EventSourceResponse
from configs import (TEMPERATURE,
USE_RERANKER,
RERANKER_MODEL,
RERANKER_MAX_LENGTH,
LLM_MODELS,
MODEL_PATH,
MAX_TOKENS)
from server.chat.utils import History
from server.knowledge_base.kb_service.base import KBServiceFactory
from server.reranker.reranker import LangchainReranker
from server.utils import BaseResponse, get_prompt_template
from server.utils import embedding_device
from server.utils import wrap_done, get_ChatOpenAI
async def article_overview(query: str = Body("你好", description="用户输入", examples=["你好"]),
knowledge_base_name: str = Body(..., description="知识库名称", examples=["t_policy_total_bce_v1"]),
stream: bool = Body(False, description="流式输出"),
model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
max_tokens: Optional[int] = Body(
MAX_TOKENS,
description="限制LLM生成Token数量默认None代表模型最大值"
),
prompt_name: str = Body(
"Article Overview",
description="使用的prompt模板名称(在configs/prompt_config.py中配置)"
),
source_name_list: List[str] = Body([], description="资源列表"),
request: Request = None,
):
kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
if kb is None:
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
query = "帮我对以下文件进行总结 " + ",".join(source_name_list)
if len(source_name_list) > 1:
prompt_name = "Article Overview2"
else:
prompt_name = "Article Overview"
async def article_overview_iterator(
query: str,
model_name: str = model_name,
prompt_name: str = prompt_name,
) -> AsyncIterable[str]:
nonlocal max_tokens
callback = AsyncIteratorCallbackHandler()
if isinstance(max_tokens, int) and max_tokens <= 0:
max_tokens = None
model = get_ChatOpenAI(
model_name=model_name,
temperature=temperature,
max_tokens=max_tokens,
callbacks=[callback],
)
docs = []
docs = await run_in_threadpool(kb.get_doc_by_sources_name,
source_name_list=source_name_list)
# 加入reranker
if USE_RERANKER:
reranker_model_path = MODEL_PATH["reranker"].get(RERANKER_MODEL, "BAAI/bge-reranker-large")
print("-----------------model path------------------")
print(reranker_model_path)
reranker_model = LangchainReranker(top_n=3,
device=embedding_device(),
max_length=RERANKER_MAX_LENGTH,
model_name_or_path=reranker_model_path
)
print("---------before rerank------------------")
print(docs)
docs = reranker_model.compress_documents(documents=docs,
query=query)
print("---------after rerank------------------")
print(docs)
# context = "\n".join([doc.page_content for doc in docs])
# 相关信息把标题和内容进行整合
if knowledge_base_name == 't_policy_total_bce_v1':
knowledge = []
for doc in docs:
if doc.metadata["_type"] == "title":
knowledge.append(doc.page_content + "\n" + doc.metadata['content'])
if doc.metadata["_type"] == "content":
knowledge.append(doc.metadata['title'] + "\n" + doc.page_content)
context = "\n\n".join(knowledge)
# 非政策知识库
else:
context = "\n".join([doc.page_content for doc in docs])
print(f"context:{context}\n")
if len(docs) == 0: # 如果没有找到相关文档使用empty模板
prompt_template = get_prompt_template("knowledge_base_chat", "empty")
else:
prompt_template = get_prompt_template("knowledge_base_chat", prompt_name)
input_msg = History(role="user", content=prompt_template).to_msg_template(False)
chat_prompt = ChatPromptTemplate.from_messages([input_msg])
print(f"chat_prompt:{chat_prompt}\n")
chain = LLMChain(prompt=chat_prompt, llm=model)
# Begin a task that runs in the background.
task = asyncio.create_task(wrap_done(
chain.acall({"context": context, "question": query}),
callback.done),
)
source_documents = []
# 政策知识库
if knowledge_base_name == 't_policy_total_bce_v1':
for inum, doc in enumerate(docs):
# 获取标题以及详情地址url
filename = doc.metadata.get("title")
detail_url = doc.metadata.get("source")
# parameters = urlencode({"knowledge_base_name": knowledge_base_name, "file_name": filename})
# base_url = request.base_url
# url = f"{base_url}knowledge_base/download_doc?" + parameters
# text = f"""出处 [{inum + 1}] [{filename}]({url}) \n\n{doc.page_content}\n\n"""
if filename:
text = f"""出处: [{filename}]({detail_url}) \n\n"""
else:
text = f"""出处: [{"原文地址"}]({detail_url}) \n\n"""
source_documents.append(text)
# 非政策知识库
else:
for inum, doc in enumerate(docs):
filename = doc.metadata.get("source")
parameters = urlencode({"knowledge_base_name": knowledge_base_name, "file_name": filename})
base_url = request.base_url
url = f"{base_url}knowledge_base/download_doc?" + parameters
if filename:
text = f"""出处: [{filename}]({url}) \n\n"""
else:
text = f"""出处: [{"原文地址"}]({url}) \n\n"""
source_documents.append(text)
if len(source_documents) == 0: # 没有找到相关文档
source_documents.append(f"<span style='color:red'>未找到相关文档,该回答为大模型自身能力解答!</span>")
if stream:
async for token in callback.aiter():
# Use server-sent-events to stream the response
yield json.dumps({"answer": token}, ensure_ascii=False)
yield json.dumps({"docs": source_documents}, ensure_ascii=False)
else:
answer = ""
async for token in callback.aiter():
answer += token
yield json.dumps({"answer": answer,
"docs": source_documents},
ensure_ascii=False)
await task
return EventSourceResponse(article_overview_iterator(query, model_name=model_name, prompt_name=prompt_name))
class ArticleOverview:
query = "请给我对文件进行一下总结"
def __init__(self):
self._PROMPT_TEMPLATE = """
'<角色> 你是由浪潮开发的知冶大模型中所选定的文件综述助手。</角色> \n\n'
'Your task is to write a detailed summary of the provided {{context}} file. Ensure that your summary is '
'longer than 300 words and captures the essence of the content. Focus on the main points, '
'key findings, and any important implications or conclusions. Maintain an unbiased tone and avoid relying '
'on stereotypes. Organize the summary in a clear and coherent manner, using appropriate headings or '
'bullet points if necessary. Remember to keep the summary concise while preserving the core information. '
'Let\'s start with a brief overview of the file\'s main topic and then delve into the specifics.'
'PLEASE ALWAYS RESPOND IN CHINESE!\n'
'<已知信息>{{ context }}</已知信息>\n'
'<问题>{{ question }}</问题>\n',
"""
self.PROMPT = PromptTemplate(
input_variables=["question", "database_names"],
template=self._PROMPT_TEMPLATE,
)
def query_out(self, knowledge_base_name: str, source_name_list: list):
self.query = "帮我对以下文件进行总结 " + ",".join(source_name_list)
return article_overview(self.query,
knowledge_base_name=knowledge_base_name,
source_name_list=source_name_list
)

View File

@@ -0,0 +1,113 @@
import asyncio
import json
from typing import AsyncIterable, List, Optional
from urllib.parse import urlencode
from fastapi import Body, Request
from fastapi.concurrency import run_in_threadpool
from langchain.callbacks import AsyncIteratorCallbackHandler
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain.prompts.chat import ChatPromptTemplate
from sse_starlette.sse import EventSourceResponse
from configs import (TEMPERATURE,
USE_RERANKER,
RERANKER_MODEL,
RERANKER_MAX_LENGTH,
MODEL_PATH,
MAX_TOKENS,
MAX_CUT_TOKENS, LLM_MODELS)
from server.chat.utils import History
from server.knowledge_base.kb_service.base import KBServiceFactory
from server.reranker.reranker import LangchainReranker
from server.utils import BaseResponse, get_prompt_template
from server.utils import embedding_device
from server.utils import wrap_done, get_ChatOpenAI
from collections import defaultdict
from server.custom.custom_fun import chapter_overview_summary
async def task(param):
contextk = param["contextk"]
i = param["i"]
model_name = param["model_name"]
temperature = param["temperature"]
max_tokens = param["max_tokens"]
chat_prompt = param["chat_prompt"]
print(f"i:{i}len_context:{len(contextk)}\n")
callback_temp = AsyncIteratorCallbackHandler()
model_temp = get_ChatOpenAI(
model_name=model_name,
temperature=temperature,
max_tokens=max_tokens,
callbacks=[callback_temp],
)
chain_temp = LLMChain(prompt=chat_prompt, llm=model_temp)
task_temp = wrap_done(chain_temp.acall({"context": contextk,
"question": "对该部分内容进行总结"}),
callback_temp.done)
await task_temp
async for token in callback_temp.aiter():
yield token
# 使用多线程执行任务
async def run_tasks_concurrently(params):
result = []
async for data in asyncio.as_completed([task(param) async for param in params]):
result.append(''.join([token async for token in data]))
return result
async def chapter_overview(query: str = Body("为我总结这些内容", description="用户输入", examples=["你好"]),
knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
stream: bool = Body(False, description="流式输出"),
model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
max_tokens: Optional[int] = Body(
MAX_TOKENS,
description="限制LLM生成Token数量默认None代表模型最大值"
),
prompt_name: str = Body(
"Chapter Overview",
description="使用的prompt模板名称(在configs/prompt_config.py中配置)"
),
source_name_list: List[str] = Body([], description="资源列表"),
request: Request = None,
):
kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
if kb is None:
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
async def chapter_overview_iterator(
model_name: str = model_name,
) -> AsyncIterable[str]:
nonlocal max_tokens
if isinstance(max_tokens, int) and max_tokens <= 0:
max_tokens = None
docs = await run_in_threadpool(kb.get_doc_by_sources_name,source_name_list=source_name_list)
chapter_summaries, global_summary = await chapter_overview_summary(docs, model_name, temperature, max_tokens)
if stream:
for h1, summaries in chapter_summaries.items():
for summary in summaries:
yield json.dumps({"chapter_title": h1, "summary": summary}, ensure_ascii=False)
yield json.dumps({"global_summary": global_summary}, ensure_ascii=False)
# yield json.dumps({"docs": source_documents}, ensure_ascii=False)
else:
result = {
"chapter_summaries": chapter_summaries,
"global_summary": global_summary,
}
yield json.dumps(result, ensure_ascii=False)
return EventSourceResponse(chapter_overview_iterator(model_name=model_name))

View File

@@ -0,0 +1,74 @@
"""
Custom functions for document processing.
"""
async def chapter_overview_summary(docs, model_name, temperature, max_tokens):
"""
为文档生成章节概述摘要的占位函数。
Args:
docs: 文档列表
model_name: 模型名称
temperature: 温度参数
max_tokens: 最大token数
Returns:
chapter_summaries: 章节摘要字典
global_summary: 全局摘要字符串
"""
chapter_summaries = {}
global_summary = ""
# 简单的占位实现:直接使用文档内容
for doc in docs:
h1 = doc.metadata.get("head1", "默认章节") if hasattr(doc, 'metadata') else "默认章节"
if h1 not in chapter_summaries:
chapter_summaries[h1] = []
content = doc.page_content if hasattr(doc, 'page_content') else str(doc)
chapter_summaries[h1].append(content[:500] if len(content) > 500 else content)
# 生成全局摘要
if docs:
all_content = " ".join([doc.page_content if hasattr(doc, 'page_content') else str(doc) for doc in docs[:3]])
global_summary = all_content[:1000] if len(all_content) > 1000 else all_content
return chapter_summaries, global_summary
async def chi_translation(docs, prompt_template, model_name, temperature, max_tokens):
"""
中文翻译占位函数。
Args:
docs: 文档列表
prompt_template: 提示模板
model_name: 模型名称
temperature: 温度参数
max_tokens: 最大token数
Yields:
翻译结果字符串
"""
for doc in docs:
content = doc.page_content if hasattr(doc, 'page_content') else str(doc)
yield content
async def eng_translation(docs, prompt_template, model_name, temperature, max_tokens):
"""
英文翻译占位函数。
Args:
docs: 文档列表
prompt_template: 提示模板
model_name: 模型名称
temperature: 温度参数
max_tokens: 最大token数
Yields:
翻译结果字符串
"""
for doc in docs:
content = doc.page_content if hasattr(doc, 'page_content') else str(doc)
yield content

View File

@@ -0,0 +1,78 @@
import asyncio
import json
from typing import AsyncIterable, List, Optional
from urllib.parse import urlencode
from fastapi import Body, Request
from fastapi.concurrency import run_in_threadpool
from langchain.callbacks import AsyncIteratorCallbackHandler
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain.prompts.chat import ChatPromptTemplate
from sse_starlette.sse import EventSourceResponse
from configs import (TEMPERATURE,
USE_RERANKER,
RERANKER_MODEL,
RERANKER_MAX_LENGTH,
MODEL_PATH,
MAX_TOKENS,
MAX_CUT_TOKENS, LLM_MODELS)
from server.chat.utils import History
from server.knowledge_base.kb_service.base import KBServiceFactory
from server.reranker.reranker import LangchainReranker
from server.utils import BaseResponse, get_prompt_template
from server.utils import embedding_device
from server.utils import wrap_done, get_ChatOpenAI
from collections import defaultdict
from server.custom.custom_fun import chi_translation,eng_translation
async def paper_translation(query: str = Body("为我总结这些内容", description="用户输入", examples=["你好"]),
knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
stream: bool = Body(False, description="流式输出"),
model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
max_tokens: Optional[int] = Body(
MAX_TOKENS,
description="限制LLM生成Token数量默认None代表模型最大值"
),
prompt_name: str = Body(
"eng_chi",
description="使用的prompt模板名称(在configs/prompt_config.py中配置)"
),
source_name_list: List[str] = Body([], description="资源列表"),
request: Request = None,
):
kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
if kb is None:
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
async def translation_iterator(
query: str,
model_name: str = model_name,
prompt_name: str = prompt_name,
) -> AsyncIterable[str]:
nonlocal max_tokens
callback = AsyncIteratorCallbackHandler()
if isinstance(max_tokens, int) and max_tokens <= 0:
max_tokens = None
docs = []
docs = await run_in_threadpool(kb.get_doc_by_sources_name,
source_name_list=source_name_list)
if len(docs) == 0: # 如果没有找到相关文档使用empty模板
prompt_template = get_prompt_template("knowledge_base_chat", "empty")
else:
prompt_template = get_prompt_template("knowledge_base_chat", prompt_name)
# input_msg = History(role="user", content=prompt_template).to_msg_template(False)
# chat_prompt = ChatPromptTemplate.from_messages([input_msg])
# chain = LLMChain(prompt=chat_prompt, llm=model)
if prompt_name == "chi_eng":
async for chunk in chi_translation(docs, prompt_template, model_name, temperature, max_tokens):
yield chunk
else:
async for chunk in eng_translation(docs, prompt_template, model_name, temperature, max_tokens):
yield chunk
return EventSourceResponse(translation_iterator(query, model_name=model_name, prompt_name=prompt_name))