294 lines
16 KiB
Python
294 lines
16 KiB
Python
|
|
import base64
|
|||
|
|
import uuid
|
|||
|
|
from fastapi import Body
|
|||
|
|
from fastapi.responses import FileResponse
|
|||
|
|
from langchain.memory import (
|
|||
|
|
CombinedMemory,
|
|||
|
|
ConversationBufferMemory,
|
|||
|
|
ConversationSummaryMemory,
|
|||
|
|
ConversationBufferWindowMemory
|
|||
|
|
)
|
|||
|
|
from typing import Any
|
|||
|
|
from sse_starlette.sse import EventSourceResponse
|
|||
|
|
from configs import LLM_MODELS, TEMPERATURE, HISTORY_LEN
|
|||
|
|
from server.chat import utils
|
|||
|
|
from server.chat.agent_write_test import agent_write_test
|
|||
|
|
from server.chat.agent_chat_test import agent_chat_test, run_sync
|
|||
|
|
from server.chat.policy_fun_iast import get_llm_model_response
|
|||
|
|
from server.chat.solve_problem import solve_problem
|
|||
|
|
from server.knowledge_base.kb_service.base import TextRank
|
|||
|
|
from server.utils import wrap_done, get_ChatOpenAI
|
|||
|
|
from langchain.chains import LLMChain, ConversationChain
|
|||
|
|
from langchain.callbacks import AsyncIteratorCallbackHandler
|
|||
|
|
from typing import AsyncIterable
|
|||
|
|
import asyncio
|
|||
|
|
import json
|
|||
|
|
from langchain.prompts.chat import ChatPromptTemplate
|
|||
|
|
from typing import List, Optional, Union
|
|||
|
|
from server.chat.utils import History, compute_lps, remove_after_and_including,remove_before_and_including
|
|||
|
|
from langchain.prompts import PromptTemplate
|
|||
|
|
from server.utils import get_prompt_template, get_format_template
|
|||
|
|
from server.memory.conversation_db_buffer_memory import ConversationBufferDBMemory
|
|||
|
|
from server.db.repository import add_message_to_db
|
|||
|
|
from server.callback_handler.conversation_callback_handler import ConversationCallbackHandler
|
|||
|
|
from datetime import datetime
|
|||
|
|
from langchain_core.messages import SystemMessage
|
|||
|
|
import time as t
|
|||
|
|
from configs.basic_config import *
|
|||
|
|
async def process_task(task):
    """Drain an asynchronous iterable and return everything it produced.

    Args:
        task: Any object that can be consumed with ``async for``.

    Returns:
        A list holding the produced items in the order they were yielded.
    """
    return [item async for item in task]
|
|||
|
|
async def chunk_docs(docs, chunk_size=800):
    """Split ``docs`` into consecutive pieces of at most ``chunk_size`` items.

    The function is declared as a coroutine (callers ``await`` it) even
    though it performs no awaiting itself.

    Args:
        docs: A sliceable sequence with a length, e.g. a string.
        chunk_size: Maximum length of each piece.

    Returns:
        A list of slices of ``docs`` in their original order; an empty list
        when ``docs`` is empty.
    """
    pieces = []
    for start in range(0, len(docs), chunk_size):
        pieces.append(docs[start:start + chunk_size])
    return pieces
|
|||
|
|
def convert_last_digits(s):
    """Parse the one- or two-digit number that terminates ``s``.

    Only the last two characters are inspected, so a longer trailing number
    such as ``"...123"`` is reported as ``(23, 2)``.

    Args:
        s: Text that is expected to end with a number.

    Returns:
        A ``(value, width)`` tuple where ``value`` is the parsed integer and
        ``width`` (1 or 2) is how many trailing characters it occupies, or
        ``None`` when ``s`` is empty or does not end with a decimal digit.
    """
    # Guard the empty string explicitly: indexing s[-1] would raise IndexError.
    # isdecimal() (rather than isdigit()) guarantees int() can parse the
    # character; isdigit() also accepts e.g. superscript digits.
    if not s or not s[-1].isdecimal():
        return None
    width = 2 if len(s) > 1 and s[-2].isdecimal() else 1
    return int(s[-width:]), width
|
|||
|
|
def _parse_part_number(token) -> Optional[int]:
    """Return the part number named by one comma-separated token, or None.

    A token that is not a plain integer falls back to the run of decimal
    digits at its end (after stripping whitespace).
    """
    try:
        return int(token)
    except ValueError:
        pass
    text = token.strip()
    start = len(text)
    while start > 0 and text[start - 1].isdecimal():
        start -= 1
    return int(text[start:]) if start < len(text) else None


async def _stream_references(section_no, docs, chunk_delay=2) -> AsyncIterable[str]:
    """Yield the text pieces that make up the references section.

    The first piece carries the section heading. When ``docs`` is empty a
    single "no references" piece is produced; otherwise ``docs`` is split
    with ``chunk_docs`` and ``chunk_delay`` seconds are awaited before every
    piece after the first.
    """
    heading = f"\n\n## {section_no}. 参考文献\n\n"
    if docs == "":
        yield f"{heading}未参考任何文献"
        return
    doc_chunks = await chunk_docs(docs)
    yield f"{heading}{doc_chunks[0]}"
    for chunk in doc_chunks[1:]:
        await asyncio.sleep(chunk_delay)
        yield chunk


async def write_article(
        messageId: Optional[Any] = Body(None, description="消息ID", examples=[""]),
        uid: Optional[Any] = Body(None, description="用户ID"),
        styles: Optional[Any] = Body(None, description="风格"),
        query: str = Body(..., description="大纲", examples=["恼羞成怒"]),
        knowledge_base_list: Optional[List[str]] = Body(None, description="个人知识库名称列表"),
        model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
        temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=2.0),
        max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
        prompt_name: str = Body("complete_outline", description="大纲补全必须使用这个提示词"),
):
    """Stream an article written part by part from an outline.

    The outline in ``query`` is first passed to ``solve_problem``; the number
    at the end of that output is taken as the count of parts to write. Each
    part is then either emitted as the references section or written through
    ``agent_write_test``, and every produced text fragment is sent to the
    client as a JSON object of the form ``{"text": ...}``.

    Progress state is published through ``utils.set_shared_variable`` under
    ``messageId`` (or a generated id); when its ``"status"`` entry becomes
    falsy the stream stops.

    Args:
        messageId: Key for the shared progress state; generated when empty.
        uid: Accepted for interface compatibility; not used in this function.
        styles: Name of a style prompt template, or the style text itself.
        query: The article outline.
        knowledge_base_list: Stored in the shared state under ``"database"``.
        model_name: Model used for outline decomposition and writing.
        temperature: Sampling temperature for those calls.
        max_tokens: Token limit forwarded to those calls.
        prompt_name: Forwarded to ``agent_write_test`` as ``user_prompt_name``.

    Returns:
        An ``EventSourceResponse`` wrapping the generated stream.
    """
    stopped_banner = "\n==============================STOPPED==============================\n"

    async def chat_iterator() -> AsyncIterable[str]:
        time_based_uuid = messageId if messageId else str(uuid.uuid1()) + "q"
        tip = {"type": "article"}
        if knowledge_base_list:
            tip["database"] = knowledge_base_list
        tip["status"] = True
        tip["source_docs"] = []
        utils.set_shared_variable(time_based_uuid, tip)
        logger.info(f"生成正文的入参信息:\nquery:{query}\nstyles:{styles}\nmessageId: {messageId}\nknowledge_base_list:{knowledge_base_list}\nprompt_name:{prompt_name}\n")

        # Step 1: decompose the outline. Retried until the output ends with a
        # number that convert_last_digits can parse.
        while True:
            # Start every attempt from an empty buffer so that partial output
            # of a failed attempt cannot leak into the next one.
            outline_text = ""
            try:
                async for response in solve_problem(
                        query=query,
                        conversation_id="",
                        history=[],
                        model_name=model_name,
                        temperature=temperature,
                        max_tokens=max_tokens,
                        prompt_name="solve_problem_outline",
                        stream=True,
                ):
                    outline_text += json.loads(response)["text"]
                    if not utils.get_shared_variable(time_based_uuid)["status"]:
                        logger.info(stopped_banner)
                        return
                parsed = convert_last_digits(outline_text)
            except Exception:
                # Exception (not a bare except) lets task cancellation and
                # generator shutdown propagate instead of being retried.
                logger.exception("outline decomposition failed, retrying")
                parsed = None
            if parsed is not None:
                step, num = parsed
                break
            # Honour a stop request between attempts and hand control back to
            # the event loop so a persistently failing call cannot starve it.
            if not utils.get_shared_variable(time_based_uuid)["status"]:
                logger.info(stopped_banner)
                return
            await asyncio.sleep(1)

        # Drop the trailing part count from the outline text.
        outline_text = outline_text[:-num]

        docs = ""                  # reference text reported by the writing agent
        finish = []                # text of every finished part, in order
        include_knowledge = False  # True once a references section was emitted
        i = 0
        nums = 0
        while i < int(step):
            history_summary = ""
            finish_content = ""
            style = ""
            tip["num"] = nums
            tip["END"] = ""
            tip['title'] = []
            utils.set_shared_variable(time_based_uuid, tip)
            chapter = f"当前撰写的是第{i+1}章"

            # Step 2: ask the model whether this part is the references
            # chapter (the answer "1" selects the references branch).
            try:
                determine_chapter = await run_sync(
                    get_llm_model_response,
                    strategy_name="default_code",
                    llm_model_name=LLM_MODELS[0],
                    template_prompt_name="identify_chapters",
                    prompt_param_dict={"chapter": chapter, "outline": query},
                    temperature=0.01,
                    max_tokens=512
                )
                is_reference_chapter = int(determine_chapter.strip()) == 1
            except Exception:
                # NOTE(review): a failed or unparsable classification is
                # treated as the references chapter, as the original handler
                # did - confirm this is the intended fallback.
                logger.exception("chapter classification failed")
                is_reference_chapter = True

            if is_reference_chapter:
                i += 1
                await asyncio.sleep(2)
                async for piece in _stream_references(i, docs):
                    history_summary += piece
                    yield json.dumps({"text": piece}, ensure_ascii=False)
                include_knowledge = True
                finish.append(history_summary)
                continue

            # Step 3: resolve the writing style; fall back to the raw value
            # when no template with that name can be loaded.
            if isinstance(styles, str) and styles.strip():
                try:
                    result = get_prompt_template("llm_chat", styles)
                    style = result if result is not None else styles
                except Exception:
                    style = styles
            logger.info(f"当前的语言风格为: {style}")

            # Step 4: ask which earlier parts this one should refer to. Run
            # through run_sync, like the classification call above, so the
            # synchronous model call does not block the event loop.
            search_query = await run_sync(
                get_llm_model_response,
                strategy_name="find_relation",
                llm_model_name=LLM_MODELS[0],
                template_prompt_name="find_relation",
                prompt_param_dict={"outline": outline_text, "i": i + 1},
                temperature=0.01,
                max_tokens=15120
            )
            if "无" not in search_query:
                for token in search_query.split(","):
                    part_no = _parse_part_number(token)
                    # Only parts that are already written can be quoted.
                    if part_no is not None and 1 <= part_no <= min(i, len(finish)):
                        finish_content += f"已写【第{part_no}部分】内容:\n{finish[part_no-1]}\n"

            # Step 5: write the part and relay the answer fragments.
            writing_request = (
                "撰写中间部分章节禁止输出综上所述之类的影响文风的话,!!!\n注意严格按照提示1的部分来,你绝对不能省略任何一个部分.当前需要撰写的是的【第"
                + str(i + 1)
                + "部分】内容。大纲如下:\n"
                + query
                + "\n"
                + (("可能需要参考的其他部分内容:\n" + finish_content) if finish_content else "")
            )
            async for response in agent_write_test(
                    user_prompt_name=prompt_name,
                    style=style,
                    query=writing_request,
                    uuid=time_based_uuid,
                    history=[],
                    model_name=model_name,
                    temperature=temperature,
                    max_tokens=max_tokens,
                    prompt_name="Write Test",
                    think_content=outline_text + "\n你只需完成大纲的第" + str(i + 1) + "部分撰写补全\n",
            ):
                if not utils.get_shared_variable(time_based_uuid)["status"]:
                    logger.info(stopped_banner)
                    return
                payload = json.loads(response)
                if "answer" in payload:
                    answer = payload["answer"]
                    history_summary += answer
                    yield json.dumps({"text": answer}, ensure_ascii=False)
                elif "tools" in payload or "search_answer" in payload:
                    # These events are not forwarded; they are still tested
                    # before "docs" to keep the original key precedence.
                    continue
                elif "docs" in payload:
                    docs = payload["docs"]
                # Any other event (e.g. "pic") is ignored.

            yield json.dumps({"text": "\n"}, ensure_ascii=False)
            if history_summary == "":
                # Nothing was written for this part: try the same part again.
                continue
            temp1 = utils.get_shared_variable(time_based_uuid)
            temp1["END"] = ""
            utils.set_shared_variable(time_based_uuid, temp1)
            i += 1
            nums = utils.get_shared_variable(time_based_uuid)["num"]
            finish.append(history_summary)
            await asyncio.sleep(0)

        # Step 6: append the references if no references chapter was emitted
        # although the agent reported reference material.
        if not include_knowledge and docs != "":
            async for piece in _stream_references(i + 1, docs):
                yield json.dumps({"text": piece}, ensure_ascii=False)

    return EventSourceResponse(chat_iterator())
|