75 lines
4.1 KiB
Python
75 lines
4.1 KiB
Python
from datetime import datetime
|
||
import uuid
|
||
from langchain.memory import ConversationBufferMemory
|
||
from langchain_experimental.plan_and_execute import PlanAndExecute, load_agent_executor, load_chat_planner
|
||
from server.agent.custom_agent.ChatGLM3Agent import initialize_glm3_agent
|
||
from server.agent.tools_select import tools, tool_names, search_tool_names
|
||
from server.agent.callbacks import CustomAsyncIteratorCallbackHandler, Status
|
||
from langchain.agents import LLMSingleActionAgent, AgentExecutor
|
||
from server.agent.custom_template import CustomOutputParser, CustomPromptTemplate
|
||
from fastapi import Body
|
||
from sse_starlette.sse import EventSourceResponse
|
||
from configs import LLM_MODELS, TEMPERATURE, HISTORY_LEN, Agent_MODEL
|
||
from server.chat import utils
|
||
from server.chat.knowledge_base_name import KnowledgeBase
|
||
from server.custom.AsyncIteratorCallbackHandlerNew import AsyncIteratorCallbackHandler
|
||
from server.utils import replace_variables, wrap_done, get_ChatOpenAI, get_prompt_template
|
||
from langchain.chains import LLMChain
|
||
from typing import AsyncIterable, Optional, Union
|
||
from server.agent.tools import rag_search
|
||
import asyncio
|
||
from typing import List
|
||
from server.chat.utils import History
|
||
import json
|
||
from server.agent import model_container
|
||
from server.knowledge_base.kb_service.base import get_kb_details
|
||
import ast
|
||
import re
|
||
from configs import kb_config
|
||
from configs.basic_config import *
|
||
|
||
async def agent_chat_new(
        query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
        conversation_id: str = Body("", description="对话框ID"),
        history_len: int = Body(-1, description="从数据库中取历史消息的数量"),
        history: Union[int, List[History]] = Body([],
                                                  description="历史对话,设为一个整数可以从数据库中读取历史消息",
                                                  examples=[[
                                                      {"role": "user",
                                                       "content": "我们来玩成语接龙,我先来,生龙活虎"},
                                                      {"role": "assistant", "content": "虎头虎脑"}]]
                                                  ),
        stream: bool = Body(False, description="流式输出"),
        model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
        temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=2.0),
        max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
        prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
):
    """Answer ``query`` with a LangChain plan-and-execute agent, returned as SSE.

    The agent is built from ``load_chat_planner`` (planner) and
    ``load_agent_executor`` (executor over ``tools``). Both share one chat
    model configured by ``model_name``, ``temperature`` and ``max_tokens``.
    The agent's final answer is sent to the client as a single
    server-sent event.

    Returns:
        An ``EventSourceResponse`` wrapping an async generator that yields
        the agent's final answer text once.

    NOTE(review): ``conversation_id``, ``history_len``, ``history``,
    ``stream`` and ``prompt_name`` are accepted (kept for API
    compatibility) but are not used by this implementation. Confirm
    whether they should be wired in.
    """

    async def chat_iterator() -> AsyncIterable[str]:
        # NOTE(review): this handler is attached to the model, but nothing
        # below drains it, so intermediate tokens are not forwarded to the
        # client. Confirm whether token streaming through it is intended.
        callback = AsyncIteratorCallbackHandler()

        # Build the planner and the executor on top of the same chat model.
        model = get_ChatOpenAI(
            model_name=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
            callbacks=[callback],
        )
        planner = load_chat_planner(model)
        executor = load_agent_executor(model, tools, verbose=True)
        agent = PlanAndExecute(planner=planner, executor=executor, verbose=True)

        # PlanAndExecute reads the user prompt from its own ``input_key``.
        # The query must be passed under that key, or the agent sees an
        # empty prompt. ``acall`` runs the chain without blocking the event
        # loop and returns a dict keyed by ``output_key``. ``run`` is
        # synchronous and returns a plain str, which cannot be consumed
        # with ``async for``.
        result = await agent.acall({agent.input_key: query})
        yield result[agent.output_key]

    return EventSourceResponse(chat_iterator())