Files
gangyan/langchain-chat/server/chat/write_article.py

294 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import base64
import uuid
from fastapi import Body
from fastapi.responses import FileResponse
from langchain.memory import (
CombinedMemory,
ConversationBufferMemory,
ConversationSummaryMemory,
ConversationBufferWindowMemory
)
from typing import Any
from sse_starlette.sse import EventSourceResponse
from configs import LLM_MODELS, TEMPERATURE, HISTORY_LEN
from server.chat import utils
from server.chat.agent_write_test import agent_write_test
from server.chat.agent_chat_test import agent_chat_test, run_sync
from server.chat.policy_fun_iast import get_llm_model_response
from server.chat.solve_problem import solve_problem
from server.knowledge_base.kb_service.base import TextRank
from server.utils import wrap_done, get_ChatOpenAI
from langchain.chains import LLMChain, ConversationChain
from langchain.callbacks import AsyncIteratorCallbackHandler
from typing import AsyncIterable
import asyncio
import json
from langchain.prompts.chat import ChatPromptTemplate
from typing import List, Optional, Union
from server.chat.utils import History, compute_lps, remove_after_and_including,remove_before_and_including
from langchain.prompts import PromptTemplate
from server.utils import get_prompt_template, get_format_template
from server.memory.conversation_db_buffer_memory import ConversationBufferDBMemory
from server.db.repository import add_message_to_db
from server.callback_handler.conversation_callback_handler import ConversationCallbackHandler
from datetime import datetime
from langchain_core.messages import SystemMessage
import time as t
from configs.basic_config import *
async def process_task(task):
    """Drain an async iterable and return its items as a list.

    Args:
        task: Any object usable in ``async for`` (for example an async
            generator).

    Returns:
        A list holding every item produced by ``task``, in the order the
        items were yielded.
    """
    return [item async for item in task]
async def chunk_docs(docs, chunk_size=800):
    """Split ``docs`` into consecutive slices of at most ``chunk_size`` items.

    Args:
        docs: A sliceable sequence (a string in this module's callers).
        chunk_size: Maximum length of each slice.

    Returns:
        A list of slices covering ``docs`` from start to end; an empty list
        when ``docs`` is empty.
    """
    pieces = []
    for start in range(0, len(docs), chunk_size):
        stop = start + chunk_size
        pieces.append(docs[start:stop])
    return pieces
def convert_last_digits(s):
    """Parse the one- or two-digit number that ends ``s``.

    Only the final two characters are inspected: if both are decimal digits
    they are read together as one number, otherwise the final character alone
    is used.

    Args:
        s: Text expected to end with a number.

    Returns:
        A ``(value, width)`` tuple where ``value`` is the parsed integer and
        ``width`` is how many trailing characters it occupied (1 or 2), or
        ``None`` when ``s`` is empty or does not end with a decimal digit.
    """
    # An empty (or missing) string has no last character to look at.
    if not s:
        return None

    # isdecimal() rather than isdigit(): characters such as "²" count as
    # digits for isdigit() but make int() raise ValueError.
    if not s[-1].isdecimal():
        return None

    if len(s) > 1 and s[-2].isdecimal():
        return int(s[-2:]), 2
    return int(s[-1]), 1
async def _iter_reference_section(section_no, docs):
    """Yield the text pieces that make up the references ("参考文献") section.

    The first piece carries the section heading. When ``docs`` is empty a
    single "no references" piece is produced; otherwise ``docs`` is split
    with ``chunk_docs`` and the remaining chunks follow one by one, with a
    two-second pause before each so the client receives them gradually.

    Args:
        section_no: Number printed in the section heading.
        docs: Reference text collected while writing (assumed to be a
            string; verify against ``agent_write_test``'s "docs" payload).
    """
    heading = f"\n\n## {section_no}. 参考文献\n\n"
    if not docs:
        yield f"{heading}未参考任何文献"
        return

    doc_chunks = await chunk_docs(docs)
    yield f"{heading}{doc_chunks[0]}"
    for chunk in doc_chunks[1:]:
        await asyncio.sleep(2)  # pause between chunks
        yield chunk


def _related_sections_context(search_query, finished_sections):
    """Build the "already written" context for the sections named in ``search_query``.

    Args:
        search_query: Comma-separated section numbers returned by the
            "find_relation" prompt. Anything that is not a non-blank string
            yields an empty context.
        finished_sections: Text of the sections written so far; element
            ``k - 1`` is section ``k``.

    Returns:
        The concatenated context string; empty when nothing usable was named.
    """
    if not isinstance(search_query, str) or not search_query.strip():
        return ""

    parts = []
    # Accept both ASCII and full-width commas as separators.
    for token in search_query.replace(",", ",").split(","):
        token = token.strip()
        if not token:
            continue
        try:
            section_no = int(token)
        except ValueError:
            # Fall back to a trailing digit (e.g. "section 3"); skip the
            # token entirely when there is none.
            if not token[-1].isdecimal():
                continue
            section_no = int(token[-1])
        # Only sections that have already been written can be quoted.
        if 1 <= section_no <= len(finished_sections):
            parts.append(
                f"已写【第{section_no}部分】内容:\n{finished_sections[section_no - 1]}\n"
            )
    return "".join(parts)


async def write_article(
    messageId: Optional[Any] = Body(None, description="消息ID", examples=[""]),
    uid: Optional[Any] = Body(None, description="用户ID"),
    styles: Optional[Any] = Body(None, description="风格"),
    query: str = Body(..., description="大纲", examples=["恼羞成怒"]),
    knowledge_base_list: Optional[List[str]] = Body(None, description="个人知识库名称列表"),
    model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
    temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=2.0),
    max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量默认None代表模型最大值"),
    prompt_name: str = Body("complete_outline", description="大纲补全必须使用这个提示词"),
):
    """Stream an article written section by section from an outline.

    The handler first asks ``solve_problem`` to decompose the outline; the
    streamed answer must end with the number of sections. It then walks the
    sections in order: a section classified as the references chapter is
    emitted from the documents collected so far, every other section is
    written by ``agent_write_test``. If no references chapter was emitted and
    documents were collected, a references section is appended at the end.

    Generation state is published through ``utils.set_shared_variable`` under
    the message id; the stream ends early once the shared ``"status"`` flag
    is falsy.

    Args:
        messageId: Message id used as the shared-state key; a fresh
            UUID-based key is generated when it is falsy.
        uid: User id (not used by this handler).
        styles: Name of a language-style prompt template, or the style text
            itself.
        query: The article outline.
        knowledge_base_list: Personal knowledge-base names, stored in the
            shared state under ``"database"``.
        model_name: LLM used for outline decomposition and writing.
        temperature: Sampling temperature for those calls.
        max_tokens: Optional generation limit for those calls.
        prompt_name: User prompt template passed to ``agent_write_test``.

    Returns:
        An ``EventSourceResponse`` whose events are JSON objects of the form
        ``{"text": ...}``.
    """

    async def chat_iterator() -> AsyncIterable[str]:
        time_based_uuid = messageId if messageId else str(uuid.uuid1()) + "q"

        tip = {"type": "article"}
        if knowledge_base_list:
            tip["database"] = knowledge_base_list
        tip["status"] = True
        tip["source_docs"] = []
        utils.set_shared_variable(time_based_uuid, tip)
        logger.info(f"生成正文的入参信息:\nquery{query}\nstyles{styles}\nmessageId: {messageId}\nknowledge_base_list{knowledge_base_list}\nprompt_name{prompt_name}\n")

        # Step 1: decompose the outline. Retry until the streamed answer ends
        # with a parseable section count.
        while True:
            if not utils.get_shared_variable(time_based_uuid)["status"]:
                logging.info("\n==============================STOPPED==============================\n")
                return
            # Start every attempt from scratch so a failed attempt cannot
            # leak partial text into the next one.
            outline_text = ""
            try:
                async for response in solve_problem(query=query, conversation_id="", history=[], model_name=model_name, temperature=temperature, max_tokens=max_tokens, prompt_name="solve_problem_outline", stream=True):
                    outline_text += json.loads(response)["text"]
                    if not utils.get_shared_variable(time_based_uuid)["status"]:
                        logging.info("\n==============================STOPPED==============================\n")
                        return
                # Raises TypeError when no trailing number was found (None).
                step, num = convert_last_digits(outline_text)
                break
            except Exception as exc:
                # `Exception`, not a bare except: task cancellation
                # (asyncio.CancelledError) must propagate so a client
                # disconnect actually stops the stream.
                logger.error(f"outline decomposition failed, retrying: {exc!r}")
                await asyncio.sleep(1)  # back off instead of spinning

        # Drop the trailing section count from the decomposed outline.
        outline_text = outline_text[:-num]

        docs = ""
        finish = []  # text of every completed section, in order
        include_knowldge = False  # True once a references chapter was emitted
        i = 0
        nums = 0
        while i < int(step):
            history_summary = ""
            tip["num"] = nums
            tip["END"] = ""
            tip['title'] = []
            utils.set_shared_variable(time_based_uuid, tip)

            # Step 2a: is the current section the references chapter?
            chapter = f"当前撰写的是第{i+1}"
            try:
                determine_chapter = await run_sync(
                    get_llm_model_response,
                    strategy_name="default_code",
                    llm_model_name=LLM_MODELS[0],
                    template_prompt_name="identify_chapters",
                    prompt_param_dict={"chapter": chapter, "outline": query},
                    temperature=0.01,
                    max_tokens=512
                )
                is_reference_chapter = int(determine_chapter.strip()) == 1
            except Exception as exc:
                # As before, a failed classification is handled like a
                # references chapter.
                logger.error(f"chapter classification failed, treating section {i + 1} as references: {exc!r}")
                is_reference_chapter = True

            if is_reference_chapter:
                i += 1
                await asyncio.sleep(2)
                async for piece in _iter_reference_section(i, docs):
                    history_summary += piece
                    yield json.dumps({"text": piece}, ensure_ascii=False)
                include_knowldge = True
                finish.append(history_summary)
                continue

            # Step 2b: resolve the language style (template lookup, falling
            # back to the raw `styles` value).
            style = ""
            if isinstance(styles, str) and styles.strip():
                try:
                    result = get_prompt_template("llm_chat", styles)
                    style = result if result is not None else styles
                except Exception:
                    style = styles
            logger.info(f"当前的语言风格为: {style}")

            # Step 2c: ask which earlier sections this one depends on. Run
            # through run_sync, like the classification call above, so the
            # event loop is not blocked by the synchronous helper.
            search_query = await run_sync(
                get_llm_model_response,
                strategy_name="find_relation",
                llm_model_name=LLM_MODELS[0],
                template_prompt_name="find_relation",
                prompt_param_dict={"outline": outline_text, "i": i + 1},
                temperature=0.01,
                max_tokens=15120
            )
            # NOTE(review): the previous guard (`"" in search_query`) was true
            # for every string, so earlier sections were never passed on.
            finish_content = _related_sections_context(search_query, finish)

            # Step 2d: write the section.
            related = ("可能需要参考的其他部分内容:\n" + finish_content) if not finish_content == "" else ""
            section_query = "撰写中间部分章节禁止输出综上所述之类的影响文风的话,!!!\n注意严格按照提示1的部分来你绝对不能省略任何一个部分.当前需要撰写的是的【第" + str(i + 1) + "部分】内容。大纲如下:\n" + query + "\n" + related
            think_content = outline_text + "\n你只需完成大纲的第" + str(i + 1) + "部分撰写补全\n"
            async for response in agent_write_test(user_prompt_name=prompt_name, style=style, query=section_query, uuid=time_based_uuid, history=[], model_name=model_name, temperature=temperature, max_tokens=max_tokens, prompt_name="Write Test", think_content=think_content):
                if not utils.get_shared_variable(time_based_uuid)["status"]:
                    logging.info("\n==============================STOPPED==============================\n")
                    return
                payload = json.loads(response)  # parse once per event
                if "answer" in payload:
                    answer = payload["answer"]
                    history_summary += answer
                    yield json.dumps({"text": answer}, ensure_ascii=False)
                elif "tools" in payload or "search_answer" in payload:
                    # Tool calls and search answers are not forwarded.
                    pass
                elif "docs" in payload:
                    # Remember the sources for the references section.
                    docs = payload["docs"]
                # Any other event (e.g. "pic") is ignored.

            yield json.dumps({"text": "\n"}, ensure_ascii=False)
            if history_summary == "":
                # Nothing was written: retry the same section.
                continue

            shared = utils.get_shared_variable(time_based_uuid)
            shared["END"] = ""
            utils.set_shared_variable(time_based_uuid, shared)
            i += 1
            nums = utils.get_shared_variable(time_based_uuid)["num"]
            finish.append(history_summary)
            await asyncio.sleep(0)

        # Step 3: append the references if the outline had no such chapter
        # but documents were collected along the way.
        if not include_knowldge and docs:
            async for piece in _iter_reference_section(i + 1, docs):
                yield json.dumps({"text": piece}, ensure_ascii=False)

    return EventSourceResponse(chat_iterator())