294 lines
16 KiB
Python
294 lines
16 KiB
Python
import base64
|
||
import uuid
|
||
from fastapi import Body
|
||
from fastapi.responses import FileResponse
|
||
from langchain.memory import (
|
||
CombinedMemory,
|
||
ConversationBufferMemory,
|
||
ConversationSummaryMemory,
|
||
ConversationBufferWindowMemory
|
||
)
|
||
from typing import Any
|
||
from sse_starlette.sse import EventSourceResponse
|
||
from configs import LLM_MODELS, TEMPERATURE, HISTORY_LEN
|
||
from server.chat import utils
|
||
from server.chat.agent_write_test import agent_write_test
|
||
from server.chat.agent_chat_test import agent_chat_test, run_sync
|
||
from server.chat.policy_fun_iast import get_llm_model_response
|
||
from server.chat.solve_problem import solve_problem
|
||
from server.knowledge_base.kb_service.base import TextRank
|
||
from server.utils import wrap_done, get_ChatOpenAI
|
||
from langchain.chains import LLMChain, ConversationChain
|
||
from langchain.callbacks import AsyncIteratorCallbackHandler
|
||
from typing import AsyncIterable
|
||
import asyncio
|
||
import json
|
||
from langchain.prompts.chat import ChatPromptTemplate
|
||
from typing import List, Optional, Union
|
||
from server.chat.utils import History, compute_lps, remove_after_and_including,remove_before_and_including
|
||
from langchain.prompts import PromptTemplate
|
||
from server.utils import get_prompt_template, get_format_template
|
||
from server.memory.conversation_db_buffer_memory import ConversationBufferDBMemory
|
||
from server.db.repository import add_message_to_db
|
||
from server.callback_handler.conversation_callback_handler import ConversationCallbackHandler
|
||
from datetime import datetime
|
||
from langchain_core.messages import SystemMessage
|
||
import time as t
|
||
from configs.basic_config import *
|
||
async def process_task(task):
    """Drain the asynchronous iterable ``task`` and return its items as a list.

    Items are collected in the order they are produced; an iterable that
    produces nothing results in an empty list.
    """
    return [item async for item in task]
|
||
async def chunk_docs(docs, chunk_size=800):
    """Split ``docs`` into consecutive slices of at most ``chunk_size`` items.

    The slices are returned in order as a list; the last one may be shorter
    than ``chunk_size``. An empty ``docs`` gives an empty list.
    """
    pieces = []
    for start in range(0, len(docs), chunk_size):
        stop = start + chunk_size
        pieces.append(docs[start:stop])
    return pieces
|
||
def convert_last_digits(s):
    """Parse the one- or two-digit number that terminates ``s``.

    Only the last two characters are ever inspected, so a string ending in
    three or more digits yields the value of its final two digits.

    Args:
        s: The text to inspect.

    Returns:
        A tuple ``(value, width)`` where ``value`` is the integer formed by the
        trailing digit(s) and ``width`` (1 or 2) is the number of characters
        they occupy, or ``None`` when ``s`` is empty or its last character is
        not a decimal digit (trailing whitespace therefore also gives ``None``).
    """
    # ``isdecimal`` rather than ``isdigit``: the latter also accepts characters
    # such as superscripts ("²") that ``int`` cannot convert.
    if not s or not s[-1].isdecimal():
        return None
    if len(s) > 1 and s[-2].isdecimal():
        return int(s[-2:]), 2
    return int(s[-1]), 1
|
||
async def _reference_fragments(docs, section_no):
    """Yield the text fragments that make up a "参考文献" (references) section.

    The first fragment carries the markdown heading numbered ``section_no``.
    When ``docs`` is empty the heading is followed by a fixed "no references"
    notice. Otherwise ``docs`` is split with ``chunk_docs`` and every chunk
    after the first is yielded on its own, preceded by a 2-second pause.
    """
    heading = f"\n\n## {section_no}. 参考文献\n\n"
    if not docs:
        yield f"{heading}未参考任何文献"
        return
    doc_chunks = await chunk_docs(docs)
    yield f"{heading}{doc_chunks[0]}"
    for chunk in doc_chunks[1:]:
        await asyncio.sleep(2)  # pause between consecutive chunks
        yield chunk


def _referenced_part_number(token):
    """Return the part number named by one comma-separated item, or ``None``.

    The whole item is tried as an integer first; if that fails, its last
    character is used when it is a decimal digit.
    """
    token = token.strip()
    try:
        return int(token)
    except ValueError:
        pass
    if token and token[-1].isdecimal():
        return int(token[-1])
    return None


async def write_article(
        messageId: Optional[Any] = Body(None, description="消息ID", examples=[""]),
        uid: Optional[Any] = Body(None, description="用户ID"),
        styles: Optional[Any] = Body(None, description="风格"),
        query: str = Body(..., description="大纲", examples=["恼羞成怒"]),
        knowledge_base_list: Optional[List[str]] = Body(None, description="个人知识库名称列表"),
        model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
        temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=2.0),
        max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
        prompt_name: str = Body("complete_outline", description="大纲补全必须使用这个提示词"),
):
    """Stream an article written part by part from the outline in ``query``.

    The response is a server-sent-event stream whose events are JSON strings
    of the form ``{"text": ...}``.

    Flow of the stream:
      1. ``solve_problem`` is run with the ``solve_problem_outline`` prompt.
         The number that ends its output is taken as the count of parts and is
         stripped from the text, which is then reused as planning context.
      2. For each part, an LLM call (``identify_chapters``) decides whether the
         part is the references chapter. If so, the documents collected so far
         are emitted as a "参考文献" section; otherwise ``agent_write_test``
         writes the part and its ``answer`` events are forwarded.
      3. If no references section was emitted and documents were collected,
         one is appended at the end.

    Generation stops silently as soon as the shared state stored under the
    message id has a falsy ``"status"``.

    Args:
        messageId: Key for the shared state; a fresh UUID-based key is
            generated when it is falsy.
        uid: Accepted but not used by this function.
        styles: Style name or free text. When it is a non-blank string it is
            looked up with ``get_prompt_template("llm_chat", styles)``; the raw
            value is used if the lookup fails or returns ``None``.
        query: The article outline.
        knowledge_base_list: Stored in the shared state under ``"database"``
            when non-empty.
        model_name: Model passed to ``solve_problem`` and ``agent_write_test``.
        temperature: Sampling temperature passed to the same two calls.
        max_tokens: Token limit passed to the same two calls.
        prompt_name: Forwarded to ``agent_write_test`` as ``user_prompt_name``.

    Returns:
        An ``EventSourceResponse`` wrapping the generator described above.
    """

    async def chat_iterator() -> AsyncIterable[str]:
        time_based_uuid = messageId if messageId else str(uuid.uuid1()) + "q"

        # Shared state other components use to observe / stop this generation.
        tip = {"type": "article"}
        if knowledge_base_list:
            tip["database"] = knowledge_base_list
        tip["status"] = True
        tip["source_docs"] = []
        utils.set_shared_variable(time_based_uuid, tip)
        logger.info(f"生成正文的入参信息:\nquery:{query}\nstyles:{styles}\nmessageId: {messageId}\nknowledge_base_list:{knowledge_base_list}\nprompt_name:{prompt_name}\n")

        # Step 1: decompose the outline. Retried until the output ends with a
        # parsable part count.
        while True:
            # Start every attempt from scratch so that a failed attempt cannot
            # leak partial text into the plan.
            plan_text = ""
            try:
                async for response in solve_problem(
                        query=query, conversation_id="", history=[],
                        model_name=model_name, temperature=temperature,
                        max_tokens=max_tokens,
                        prompt_name="solve_problem_outline", stream=True):
                    plan_text += json.loads(response)["text"]
                    if not utils.get_shared_variable(time_based_uuid)["status"]:
                        logging.info("\n==============================STOPPED==============================\n")
                        return
                parsed = convert_last_digits(plan_text)
            except Exception:
                # ``Exception`` (not a bare ``except``) so that cancellation
                # and generator shutdown still propagate.
                logger.exception("outline decomposition failed; retrying")
                parsed = None
            if parsed is not None:
                step, num = parsed
                break
            await asyncio.sleep(1)  # brief back-off instead of a hot retry loop
        plan_text = plan_text[:-num]

        docs = ""                  # latest reference material reported by the agent
        finish = []                # text of every completed part, in order
        include_knowldge = False   # True once a references section was emitted
        i = 0
        nums = 0
        while i < step:
            history_summary = ""
            tip["num"] = nums
            tip["END"] = ""
            tip['title'] = []
            utils.set_shared_variable(time_based_uuid, tip)
            finish_content = ""
            style = ""

            # Step 2a: ask the LLM whether this part is the references chapter.
            chapter = f"当前撰写的是第{i+1}章"
            try:
                determine_chapter = await run_sync(
                    get_llm_model_response,
                    strategy_name="default_code",
                    llm_model_name=LLM_MODELS[0],
                    template_prompt_name="identify_chapters",
                    prompt_param_dict={"chapter": chapter, "outline": query},
                    temperature=0.01,
                    max_tokens=512
                )
                is_reference_chapter = int(determine_chapter.strip()) == 1
            except Exception:
                # NOTE(review): the original fallback also emitted the
                # references section when this check failed; kept as is.
                logger.exception("chapter identification failed; emitting references")
                is_reference_chapter = True

            if is_reference_chapter:
                i += 1
                await asyncio.sleep(2)
                async for fragment in _reference_fragments(docs, i):
                    history_summary += fragment
                    yield json.dumps({"text": fragment}, ensure_ascii=False)
                include_knowldge = True
                finish.append(history_summary)
                continue

            # Step 2b: resolve the writing style.
            if isinstance(styles, str) and styles.strip():
                try:
                    result = get_prompt_template("llm_chat", styles)
                    style = result if result is not None else styles
                except Exception:
                    # Fall back to the raw value supplied by the caller.
                    style = styles
            logger.info(f"当前的语言风格为: {style}")

            # Step 2c: find earlier parts this one should refer to. Run through
            # ``run_sync`` like the call above so the event loop is not blocked.
            search_query = await run_sync(
                get_llm_model_response,
                strategy_name="find_relation",
                llm_model_name=LLM_MODELS[0],
                template_prompt_name="find_relation",
                prompt_param_dict={"outline": plan_text, "i": i + 1},
                temperature=0.01,
                max_tokens=15120
            )
            if "无" not in search_query:
                for token in search_query.split(","):
                    part_no = _referenced_part_number(token)
                    # Only parts that are already written can be quoted.
                    if part_no is not None and 1 <= part_no <= min(i, len(finish)):
                        finish_content += f"已写【第{part_no}部分】内容:\n{finish[part_no - 1]}\n"

            # Step 2d: write the part and forward the generated text.
            writing_query = (
                "撰写中间部分章节禁止输出综上所述之类的影响文风的话,!!!\n注意严格按照提示1的部分来,你绝对不能省略任何一个部分.当前需要撰写的是的【第"
                + str(i + 1) + "部分】内容。大纲如下:\n" + query + "\n"
                + (("可能需要参考的其他部分内容:\n" + finish_content) if finish_content != "" else "")
            )
            think_content = plan_text + "\n你只需完成大纲的第" + str(i + 1) + "部分撰写补全\n"
            async for response in agent_write_test(
                    user_prompt_name=prompt_name, style=style,
                    query=writing_query, uuid=time_based_uuid, history=[],
                    model_name=model_name, temperature=temperature,
                    max_tokens=max_tokens, prompt_name="Write Test",
                    think_content=think_content):
                if not utils.get_shared_variable(time_based_uuid)["status"]:
                    logging.info("\n==============================STOPPED==============================\n")
                    return
                event = json.loads(response)  # decode each event once
                if "answer" in event:
                    answer = event["answer"]
                    history_summary += answer
                    yield json.dumps({"text": answer}, ensure_ascii=False)
                elif "tools" in event or "search_answer" in event:
                    # Tool traces and intermediate search answers are not
                    # streamed; checked before "docs" to keep the original
                    # key precedence.
                    pass
                elif "docs" in event:
                    docs = event["docs"]
                # Any other event (e.g. "pic") is ignored.

            yield json.dumps({"text": "\n"}, ensure_ascii=False)
            if history_summary == "":
                # Nothing was written: retry the same part.
                continue
            temp1 = utils.get_shared_variable(time_based_uuid)
            temp1["END"] = ""
            utils.set_shared_variable(time_based_uuid, temp1)
            i += 1
            nums = utils.get_shared_variable(time_based_uuid)["num"]
            finish.append(history_summary)
            await asyncio.sleep(0)

        # Step 3: append the references if the outline had no such chapter.
        if not include_knowldge and docs:
            async for fragment in _reference_fragments(docs, i + 1):
                yield json.dumps({"text": fragment}, ensure_ascii=False)

    return EventSourceResponse(chat_iterator())
|