Files
gangyan/langchain-chat/server/chat/agent_chat_new.py

75 lines
4.1 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from datetime import datetime
import uuid
from langchain.memory import ConversationBufferMemory
from langchain_experimental.plan_and_execute import PlanAndExecute, load_agent_executor, load_chat_planner
from server.agent.custom_agent.ChatGLM3Agent import initialize_glm3_agent
from server.agent.tools_select import tools, tool_names, search_tool_names
from server.agent.callbacks import CustomAsyncIteratorCallbackHandler, Status
from langchain.agents import LLMSingleActionAgent, AgentExecutor
from server.agent.custom_template import CustomOutputParser, CustomPromptTemplate
from fastapi import Body
from sse_starlette.sse import EventSourceResponse
from configs import LLM_MODELS, TEMPERATURE, HISTORY_LEN, Agent_MODEL
from server.chat import utils
from server.chat.knowledge_base_name import KnowledgeBase
from server.custom.AsyncIteratorCallbackHandlerNew import AsyncIteratorCallbackHandler
from server.utils import replace_variables, wrap_done, get_ChatOpenAI, get_prompt_template
from langchain.chains import LLMChain
from typing import AsyncIterable, Optional, Union
from server.agent.tools import rag_search
import asyncio
from typing import List
from server.chat.utils import History
import json
from server.agent import model_container
from server.knowledge_base.kb_service.base import get_kb_details
import ast
import re
from configs import kb_config
from configs.basic_config import *
async def agent_chat_new(
    query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
    conversation_id: str = Body("", description="对话框ID"),
    history_len: int = Body(-1, description="从数据库中取历史消息的数量"),
    history: Union[int, List[History]] = Body(
        [],
        description="历史对话,设为一个整数可以从数据库中读取历史消息",
        examples=[[
            {"role": "user",
             "content": "我们来玩成语接龙,我先来,生龙活虎"},
            {"role": "assistant", "content": "虎头虎脑"}]]
    ),
    stream: bool = Body(False, description="流式输出"),
    model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
    temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=2.0),
    max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量默认None代表模型最大值"),
    # NOTE: a `top_p` (nucleus sampling) body parameter is deliberately not
    # exposed; it should not be set together with `temperature`.
    prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
):
    """Answer ``query`` with a plan-and-execute agent and return it over SSE.

    A chat model is built from ``model_name``, ``temperature`` and
    ``max_tokens``; a planner and an executor (using the shared ``tools``
    list) are derived from it and combined into a ``PlanAndExecute`` chain.
    The chain is run on ``query`` and its final answer is emitted as a single
    server-sent event.

    Args:
        query: The user's question; passed to the chain as its input.
        conversation_id: Accepted for API compatibility; not read here.
        history_len: Accepted for API compatibility; not read here.
        history: Accepted for API compatibility; not read here.
        stream: Accepted for API compatibility; not read here.
        model_name: Name of the LLM used for both planning and execution.
        temperature: Sampling temperature forwarded to the model.
        max_tokens: Optional generation limit forwarded to the model.
        prompt_name: Accepted for API compatibility; not read here.

    Returns:
        An ``EventSourceResponse`` wrapping the async generator that yields
        the agent's answer.
    """

    async def chat_iterator() -> AsyncIterable[str]:
        """Build the agent, run it on ``query`` and yield its final answer."""
        callback = AsyncIteratorCallbackHandler()

        # Load the planner and the executor on top of one shared chat model.
        model = get_ChatOpenAI(
            model_name=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
            callbacks=[callback],
        )
        planner = load_chat_planner(model)
        executor = load_agent_executor(model, tools, verbose=True)
        agent = PlanAndExecute(planner=planner, executor=executor, verbose=True)

        # `Chain.run` is synchronous and returns a plain string, so it cannot
        # be consumed with `async for`; await the async entry point instead.
        # The query is passed directly so it lands under the chain's single
        # input key rather than being hidden under an unrelated dict key.
        answer = await agent.arun(query)

        # The chain produces one final string; emit it as a single event.
        yield answer

    return EventSourceResponse(chat_iterator())