[全量] 初始化项目代码、配置、文档及Agent协同harness
This commit is contained in:
293
langchain-chat/server/chat/write_article.py
Normal file
293
langchain-chat/server/chat/write_article.py
Normal file
@@ -0,0 +1,293 @@
|
||||
import base64
|
||||
import uuid
|
||||
from fastapi import Body
|
||||
from fastapi.responses import FileResponse
|
||||
from langchain.memory import (
|
||||
CombinedMemory,
|
||||
ConversationBufferMemory,
|
||||
ConversationSummaryMemory,
|
||||
ConversationBufferWindowMemory
|
||||
)
|
||||
from typing import Any
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
from configs import LLM_MODELS, TEMPERATURE, HISTORY_LEN
|
||||
from server.chat import utils
|
||||
from server.chat.agent_write_test import agent_write_test
|
||||
from server.chat.agent_chat_test import agent_chat_test, run_sync
|
||||
from server.chat.policy_fun_iast import get_llm_model_response
|
||||
from server.chat.solve_problem import solve_problem
|
||||
from server.knowledge_base.kb_service.base import TextRank
|
||||
from server.utils import wrap_done, get_ChatOpenAI
|
||||
from langchain.chains import LLMChain, ConversationChain
|
||||
from langchain.callbacks import AsyncIteratorCallbackHandler
|
||||
from typing import AsyncIterable
|
||||
import asyncio
|
||||
import json
|
||||
from langchain.prompts.chat import ChatPromptTemplate
|
||||
from typing import List, Optional, Union
|
||||
from server.chat.utils import History, compute_lps, remove_after_and_including,remove_before_and_including
|
||||
from langchain.prompts import PromptTemplate
|
||||
from server.utils import get_prompt_template, get_format_template
|
||||
from server.memory.conversation_db_buffer_memory import ConversationBufferDBMemory
|
||||
from server.db.repository import add_message_to_db
|
||||
from server.callback_handler.conversation_callback_handler import ConversationCallbackHandler
|
||||
from datetime import datetime
|
||||
from langchain_core.messages import SystemMessage
|
||||
import time as t
|
||||
from configs.basic_config import *
|
||||
async def process_task(task):
    """Drain the async iterable ``task`` and return its items as a list.

    Items are kept in the order in which ``task`` produced them.
    """
    return [item async for item in task]
|
||||
async def chunk_docs(docs, chunk_size=800):
    """Split ``docs`` into consecutive pieces of at most ``chunk_size`` items.

    The pieces are slices of ``docs`` taken in order; the last one may be
    shorter. An empty ``docs`` yields an empty list.
    """
    pieces = []
    for start in range(0, len(docs), chunk_size):
        pieces.append(docs[start:start + chunk_size])
    return pieces
|
||||
def convert_last_digits(s):
    """Parse the one- or two-digit number that terminates ``s``.

    Args:
        s: Text expected to end in a decimal number.

    Returns:
        A ``(value, width)`` tuple, where ``value`` is the integer formed by
        the last one or two characters and ``width`` (1 or 2) is how many
        characters it occupies. Returns ``None`` when ``s`` is empty or its
        last character is not a decimal digit; trailing whitespace is not
        skipped, so ``"12\\n"`` yields ``None``.
    """
    # Empty input has no last character to inspect.
    if not s:
        return None

    last = s[-1]
    # isdecimal() matches exactly the characters int() can parse; isdigit()
    # would also accept e.g. superscript digits, on which int() raises.
    if not last.isdecimal():
        return None

    # At most two trailing digits are consumed, even if more precede them.
    if len(s) > 1 and s[-2].isdecimal():
        return int(s[-2:]), 2
    return int(last), 1
|
||||
# Maximum number of outline-planning attempts before the stream gives up.
_MAX_OUTLINE_ATTEMPTS = 5

# Log line written when the shared "status" flag reports that the run was stopped.
_STOPPED_BANNER = "\n==============================STOPPED==============================\n"


def _parse_section_number(token):
    """Return the section number named by ``token``, or ``None`` if it has none.

    The whole token is tried first; if that is not an integer, the last
    character alone is used when it is a decimal digit.
    """
    token = token.strip()
    if not token:
        return None
    try:
        return int(token)
    except ValueError:
        pass
    tail = token[-1]
    return int(tail) if tail.isdecimal() else None


async def _iter_reference_section(section_no, docs):
    """Yield the text pieces of the references section numbered ``section_no``.

    With no ``docs`` a single piece stating that nothing was referenced is
    produced. Otherwise ``docs`` is split with ``chunk_docs``; the first piece
    carries the heading and later pieces follow after a 2-second pause each.
    """
    heading = f"\n\n## {section_no}. 参考文献\n\n"
    if not docs:
        yield f"{heading}未参考任何文献"
        return
    chunks = await chunk_docs(docs)
    yield f"{heading}{chunks[0]}"
    for chunk in chunks[1:]:
        await asyncio.sleep(2)  # pause between consecutive chunks
        yield chunk


async def write_article(
        messageId: Optional[Any] = Body(None, description="消息ID", examples=[""]),
        uid: Optional[Any] = Body(None, description="用户ID"),
        styles: Optional[Any] = Body(None, description="风格"),
        query: str = Body(..., description="大纲", examples=["恼羞成怒"]),
        knowledge_base_list: Optional[List[str]] = Body(None, description="个人知识库名称列表"),
        model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
        temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=2.0),
        max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
        prompt_name: str = Body("complete_outline", description="大纲补全必须使用这个提示词"),
):
    """Stream an article written section by section from the outline ``query``.

    The response is an ``EventSourceResponse`` whose events are JSON strings of
    the form ``{"text": ...}``. The generator works in three phases:

    1. Ask ``solve_problem`` for a plan whose text must end in a number; that
       number (parsed by ``convert_last_digits``) is the count of sections.
    2. For each section, ask the LLM whether it is the references section. If
       so, emit the collected ``docs``; otherwise write the section through
       ``agent_write_test`` and forward its ``answer`` messages.
    3. If no references section was emitted and ``docs`` were collected,
       append one at the end.

    Between steps the shared variable stored under the message id is polled;
    when its ``"status"`` is falsy the stream ends early.

    Args:
        messageId: Key for the shared state; a fresh UUID-based key is
            generated when it is falsy.
        uid: Accepted but not used by this function.
        styles: Language style name or text; a non-blank string is resolved
            through ``get_prompt_template("llm_chat", styles)`` and used as-is
            when that lookup fails or returns ``None``.
        query: The article outline.
        knowledge_base_list: Stored in the shared state under ``"database"``
            when non-empty.
        model_name: Model passed to ``solve_problem`` and ``agent_write_test``.
        temperature: Sampling temperature passed to the same two calls.
        max_tokens: Token limit passed to the same two calls.
        prompt_name: Passed to ``agent_write_test`` as ``user_prompt_name``.

    Returns:
        EventSourceResponse wrapping the async generator described above.
    """

    async def chat_iterator() -> AsyncIterable[str]:
        time_based_uuid = messageId if messageId else str(uuid.uuid1()) + "q"

        # Shared state that other components read and update through ``utils``.
        tip = {"type": "article"}
        if knowledge_base_list:
            tip["database"] = knowledge_base_list
        tip["status"] = True
        tip["source_docs"] = []
        utils.set_shared_variable(time_based_uuid, tip)
        logger.info(f"生成正文的入参信息:\nquery:{query}\nstyles:{styles}\nmessageId: {messageId}\nknowledge_base_list:{knowledge_base_list}\nprompt_name:{prompt_name}\n")

        def stopped():
            return not utils.get_shared_variable(time_based_uuid)["status"]

        # ---- Phase 1: plan the article ---------------------------------
        # NOTE(review): the original retried without limit on any error; the
        # cap below ends the stream instead of spinning forever. Confirm the
        # limit suits production.
        outline = ""
        plan = None
        for attempt in range(1, _MAX_OUTLINE_ATTEMPTS + 1):
            outline = ""  # start clean so a failed attempt does not leak into the next
            try:
                async for response in solve_problem(query=query, conversation_id="", history=[], model_name=model_name, temperature=temperature, max_tokens=max_tokens, prompt_name="solve_problem_outline", stream=True):
                    outline += json.loads(response)["text"]
                    if stopped():
                        logger.info(_STOPPED_BANNER)
                        return
                plan = convert_last_digits(outline)
            except Exception:
                # ``Exception`` (not a bare except) lets cancellation and
                # generator shutdown propagate.
                logger.exception("outline planning attempt %d failed", attempt)
                plan = None
            if plan is not None:
                break
            logger.warning("outline planning attempt %d produced no section count", attempt)
            await asyncio.sleep(1)
        if plan is None:
            logger.error("outline planning failed after %d attempts", _MAX_OUTLINE_ATTEMPTS)
            return

        step, num = plan
        outline = outline[:-num]  # drop the trailing section count from the plan text

        # ---- Phase 2: write the sections -------------------------------
        docs = ""  # reference material reported by the writing agent
        finish = []  # text of every finished section, in order
        include_references = False
        i = 0
        nums = 0
        while i < int(step):
            history_summary = ""
            tip["num"] = nums
            tip["END"] = ""
            tip['title'] = []
            utils.set_shared_variable(time_based_uuid, tip)
            finish_content = ""
            style = ""
            chapter = f"当前撰写的是第{i+1}章"

            # Only the classification call is guarded, so a failure while
            # emitting the section cannot advance ``i`` a second time.
            try:
                determine_chapter = await run_sync(
                    get_llm_model_response,
                    strategy_name="default_code",
                    llm_model_name=LLM_MODELS[0],
                    template_prompt_name="identify_chapters",
                    prompt_param_dict={"chapter": chapter, "outline": query},
                    temperature=0.01,
                    max_tokens=512
                )
                is_reference_chapter = int(determine_chapter.strip()) == 1
            except Exception:
                # NOTE(review): as in the original, a failed classification is
                # handled like a references section — confirm this is intended.
                logger.exception("chapter classification failed for section %d", i + 1)
                is_reference_chapter = True

            if is_reference_chapter:
                i += 1
                await asyncio.sleep(2)
                async for piece in _iter_reference_section(i, docs):
                    history_summary += piece
                    yield json.dumps({"text": piece}, ensure_ascii=False)
                include_references = True
                finish.append(history_summary)
                continue

            if isinstance(styles, str) and styles.strip():
                try:
                    # Look up the language-style template by name.
                    result = get_prompt_template("llm_chat", styles)
                    style = result if result is not None else styles
                except Exception:
                    # Fall back to the raw ``styles`` text as the style.
                    style = styles
            logger.info(f"当前的语言风格为: {style}")

            # Ask which finished sections the current one depends on. Routed
            # through run_sync, like the classification call above, so the
            # call is not made directly inside the async generator.
            search_query = await run_sync(
                get_llm_model_response,
                strategy_name="find_relation",
                llm_model_name=LLM_MODELS[0],
                template_prompt_name="find_relation",
                prompt_param_dict={"outline": outline, "i": i + 1},
                temperature=0.01,
                max_tokens=15120
            )
            if "无" not in search_query:
                # Accept both ASCII and full-width commas as separators.
                for token in search_query.replace("，", ",").split(","):
                    section_no = _parse_section_number(token)
                    if section_no is None:
                        continue
                    # Only sections that are already written can be quoted.
                    if 1 <= section_no < i + 1 and section_no <= len(finish):
                        finish_content += f"已写【第{section_no}部分】内容:\n{finish[section_no-1]}\n"

            async for response in agent_write_test(
                    user_prompt_name=prompt_name,
                    style=style,
                    query="撰写中间部分章节禁止输出综上所述之类的影响文风的话,!!!\n注意严格按照提示1的部分来,你绝对不能省略任何一个部分.当前需要撰写的是的【第"+str(i+1)+"部分】内容。大纲如下:\n"+query+"\n"+(("可能需要参考的其他部分内容:\n"+finish_content) if not finish_content == "" else ""),
                    uuid=time_based_uuid,
                    history=[],
                    model_name=model_name,
                    temperature=temperature,
                    max_tokens=max_tokens,
                    prompt_name="Write Test",
                    think_content=outline+"\n你只需完成大纲的第"+str(i+1)+"部分撰写补全\n"):
                if stopped():
                    logger.info(_STOPPED_BANNER)
                    return
                payload = json.loads(response)
                if "answer" in payload:
                    answer = payload["answer"]
                    history_summary += answer
                    yield json.dumps({"text": answer}, ensure_ascii=False)
                elif "tools" in payload or "search_answer" in payload:
                    # Not forwarded; checked before "docs" to keep the
                    # original key precedence.
                    continue
                elif "docs" in payload:
                    docs = payload["docs"]
                # Any other message (e.g. "pic") is ignored.

            yield json.dumps({"text": "\n"}, ensure_ascii=False)

            if history_summary == "":
                # Nothing was written: try the same section again.
                continue

            temp1 = utils.get_shared_variable(time_based_uuid)
            temp1["END"] = ""
            utils.set_shared_variable(time_based_uuid, temp1)
            i += 1
            nums = utils.get_shared_variable(time_based_uuid)["num"]
            finish.append(history_summary)
            await asyncio.sleep(0)  # give other tasks a chance to run

        # ---- Phase 3: trailing references ------------------------------
        # NOTE(review): assumed to run once after the loop; the scraped source
        # lost its indentation, so confirm the placement.
        if not include_references and docs:
            async for piece in _iter_reference_section(i + 1, docs):
                yield json.dumps({"text": piece}, ensure_ascii=False)

    return EventSourceResponse(chat_iterator())
|
||||
Reference in New Issue
Block a user