Files
gangyan/langchain-chat/server/agent/agent_chat.py

101 lines
4.0 KiB
Python

import uuid
from fastapi import Body
from langchain.memory import (
CombinedMemory,
ConversationBufferMemory,
ConversationSummaryMemory,
ConversationBufferWindowMemory
)
from sse_starlette.sse import EventSourceResponse
from configs import LLM_MODELS, TEMPERATURE, HISTORY_LEN
from server.utils import wrap_done, get_ChatOpenAI
from langchain.chains import LLMChain, ConversationChain
from langchain.callbacks import AsyncIteratorCallbackHandler
from typing import AsyncIterable
import asyncio
import json
from langchain.prompts.chat import ChatPromptTemplate
from typing import List, Optional, Union
from server.chat.utils import History
from langchain.prompts import PromptTemplate
from server.utils import get_prompt_template, get_format_template
from server.memory.conversation_db_buffer_memory import ConversationBufferDBMemory
from server.db.repository import add_message_to_db
from server.callback_handler.conversation_callback_handler import ConversationCallbackHandler
from datetime import datetime
from langchain_core.messages import SystemMessage
import time as t
from server.utils import replace_variables
from configs.basic_config import *
from configs.outline_config import outlines
def _build_chat_memory(history: List[History], current_date: str) -> CombinedMemory:
    """Build the chain memory: a "time" buffer holding the current date plus a "history" buffer of prior turns."""
    history_memory = ConversationBufferMemory(
        human_prefix='user', ai_prefix='assistant', memory_key="history", input_key="input"
    )
    if history:
        for item in history:
            if item.role == 'user':
                history_memory.chat_memory.add_user_message(item.content)
            elif item.role == 'assistant':
                history_memory.chat_memory.add_ai_message(item.content)
            # Messages with any other role (e.g. "system") are skipped.
    else:
        # No prior turns: seed one empty user/assistant exchange, as the original code did.
        history_memory.chat_memory.add_user_message("")
        history_memory.chat_memory.add_ai_message("")

    time_memory = ConversationBufferMemory(
        human_prefix='user', ai_prefix='assistant', memory_key="time", input_key="input"
    )
    time_memory.chat_memory.add_message(SystemMessage(content=f'当前的时间是:{current_date}'))
    return CombinedMemory(memories=[time_memory, history_memory])


async def agent_chat_new(
    user_prompt_name: Optional[str] = Body(None, description="用户输入"),
    query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
    conversation_id: str = Body("", description="对话框ID"),
    history: Union[int, List[History]] = Body([], description="历史对话"),
    model_name: str = Body("default_model", description="LLM 模型名称。"),
    temperature: float = Body(0.7, description="LLM 采样温度", ge=0.0, le=2.0),
    max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量"),
    prompt_template: str = Body("default", description="使用的prompt模板内容"),
    stream: bool = Body(False, description="流式输出")
) -> AsyncIterable[str]:
    """Run one conversational turn through a LangChain ``ConversationChain`` and yield JSON output.

    Args:
        user_prompt_name: Accepted for API compatibility; not used here.
        query: The user's message for this turn (fed to the chain as ``input``).
        conversation_id: Accepted for API compatibility; not used here.
        history: Prior turns as ``History``-convertible items. An int is accepted by the
            signature but no DB-backed history lookup is wired up in this function, so it
            is treated as "no prior turns".
        model_name, temperature, max_tokens: Passed to ``get_ChatOpenAI``; a non-positive
            ``max_tokens`` is treated as "no limit" (``None``).
        prompt_template: Template text for ``PromptTemplate.from_template``. ConversationChain
            checks the prompt's variables against the memory keys plus the input key, so the
            template presumably has to reference ``{time}``, ``{history}`` and ``{input}``.
        stream: Selects the output mode (see Yields).

    Yields:
        If ``stream`` is true: one ``{"text": token}`` JSON string per generated token.
        Otherwise: a single ``{"text": answer, "message_id": message_id}`` JSON string
        once generation has finished.
    """
    callback = AsyncIteratorCallbackHandler()
    current_date = datetime.now().strftime("%Y年%m月%d日")
    message_id = str(uuid.uuid1()) + "q"

    if isinstance(max_tokens, int) and max_tokens <= 0:
        max_tokens = None
    model = get_ChatOpenAI(
        model_name=model_name,
        temperature=temperature,
        max_tokens=max_tokens,
        callbacks=[callback],
    )

    # An int history cannot be iterated; without a DB lookup here, fall back to no history
    # instead of crashing with TypeError.
    if isinstance(history, int):
        history = []
    history = [History.from_data(h) for h in history]
    chat_prompt = PromptTemplate.from_template(prompt_template)
    memory = _build_chat_memory(history, current_date)
    chain = ConversationChain(llm=model, verbose=True, memory=memory, prompt=chat_prompt)

    # Run the chain in the background; tokens arrive through the callback's async iterator.
    # NOTE(review): the "time" memory variable likely takes precedence over the "time" key
    # passed here (LangChain merges memory variables over inputs) — kept as in the original.
    task = asyncio.create_task(wrap_done(
        chain.acall({"input": query, "time": current_date}),
        callback.done),
    )

    chunks: List[str] = []
    try:
        async for token in callback.aiter():
            chunks.append(token)
            if stream:
                yield json.dumps({"text": token}, ensure_ascii=False)
        answer = "".join(chunks)
        # Log the complete answer once per request (in both modes), not once per token.
        logger.info(f'solve_problem: {answer}')
        await task
    finally:
        # If the consumer stops iterating early (e.g. client disconnect), don't leave the
        # LLM call running in the background.
        if not task.done():
            task.cancel()

    if not stream:
        yield json.dumps(
            {"text": answer, "message_id": message_id},
            ensure_ascii=False
        )