579 lines
20 KiB
Python
579 lines
20 KiB
Python
from server.chat.check_language import check_language, get_supported_languages
|
||
from server.chat.chat_comparison import chat_comparison_test
|
||
from server.chat.gen_title import gen_title
|
||
from server.chat.relevant_articles import relevant_articles
|
||
from server.chat.self_kb_chat import self_kb_chat
|
||
from server.chat.stop import stop
|
||
import nltk
|
||
import sys
|
||
import os
|
||
|
||
from server.chat.chat_test import chat_test, get_image
|
||
from server.chat.gen_abstract import gen_abstract
|
||
from server.chat.gen_conclusion import gen_conclusion
|
||
from server.chat.gen_keywords import gen_keywords
|
||
from server.chat.gen_paragraph import gen_paragraph
|
||
from server.chat.knowledge_chat_test import knowledge_chat_test
|
||
from server.chat.translate import tarnslate_text
|
||
from server.chat.upload import upload_file
|
||
from server.chat.utils import download_self_doc
|
||
from server.chat.word_explain import word_explain
|
||
from server.chat.write_article import write_article
|
||
from server.knowledge_base.kb_doc_api import search_self_docs, upload_docs_new
|
||
from server.translator_service.main_api import cancel_task, download_result, get_progress, translate_file
|
||
|
||
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
|
||
|
||
from configs import VERSION
|
||
from configs.model_config import NLTK_DATA_PATH
|
||
from configs.server_config import OPEN_CROSS_DOMAIN
|
||
import argparse
|
||
import uvicorn
|
||
from fastapi import Body
|
||
from fastapi.middleware.cors import CORSMiddleware
|
||
from starlette.responses import RedirectResponse
|
||
from server.chat.chat import chat
|
||
from server.chat.search_engine_chat import search_engine_chat
|
||
from server.chat.completion import completion
|
||
from server.custom.chapter_overview import chapter_overview
|
||
from server.custom.article_overview import article_overview
|
||
from server.custom.abstract_search import abstract_search
|
||
from server.custom.paper_translation import paper_translation
|
||
from server.chat.feedback import chat_feedback
|
||
from server.embeddings_api import embed_texts_endpoint
|
||
from server.llm_api import (list_running_models, list_config_models,
|
||
change_llm_model, stop_llm_model,
|
||
get_model_config, list_search_engines)
|
||
from server.utils import (BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline, get_server_configs, get_prompt_template)
|
||
from typing import List, Literal
|
||
from server.chat.rewrite import(
|
||
# con_rewrite,
|
||
# exp_write,
|
||
# abb_write,
|
||
formal_style,
|
||
party_style,
|
||
col_style,
|
||
chi_to_ens,
|
||
ens_to_chi
|
||
)
|
||
from server.chat.con_rewrite import con_rewrite
|
||
from server.chat.exp_rewrite import exp_rewrite
|
||
from server.chat.abb_rewrite import abb_rewrite
|
||
from server.chat.rew_rewrite import rew_rewrite
|
||
from server.chat.sentence_reference import sentence_reference
|
||
from contextlib import asynccontextmanager
|
||
from server.translator_service.task_manager import TaskManager
|
||
|
||
# Put the configured NLTK_DATA_PATH (configs.model_config) at the front of
# NLTK's resource search path so it is consulted before NLTK's default
# locations. A new list is built and rebound rather than mutated in place.
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
|
||
|
||
|
||
|
||
async def document():
    """Send clients hitting the root URL to the Swagger UI page at ``/docs``."""
    docs_location = "/docs"
    return RedirectResponse(url=docs_location)
|
||
|
||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Own the application-wide resources for the lifetime of *app*.

    Startup:
        * create a ``TaskManager`` around ``translate_file``, start it and
          expose it as ``app.state.tm``;
        * invoke every handler in ``app.router.on_startup`` (including hooks
          injected by the launcher).

    Shutdown:
        * stop the task manager -- guaranteed via ``finally`` even if a
          startup hook, the running app, or cancellation raises;
        * invoke every handler in ``app.router.on_shutdown``.

    The startup/shutdown handlers are run by hand because this function is
    installed as ``lifespan=`` on the app (see ``create_app``).
    """
    import inspect  # local: only needed to accept both sync and async hooks

    async def _run_hooks(hooks):
        # Starlette event handlers may be plain functions or coroutine
        # functions; awaiting the None returned by a sync hook would raise
        # TypeError, so only await when the call produced an awaitable.
        for hook in hooks:
            result = hook()
            if inspect.isawaitable(result):
                await result

    tm = TaskManager(translate_fn=translate_file)
    tm.start()
    app.state.tm = tm

    try:
        # Run every registered startup hook, including launcher-injected ones.
        await _run_hooks(app.router.on_startup)
        yield
    finally:
        # Always release the task manager, even on errors or cancellation.
        tm.shutdown()

    # Reached only on a clean shutdown, matching the original ordering:
    # task manager first, then the registered shutdown hooks.
    await _run_hooks(app.router.on_shutdown)
|
||
|
||
def create_app(run_mode: str = None):
    """Construct and fully wire the Langchain-Chatchat API application.

    Args:
        run_mode: passed through unchanged to ``mount_app_routes``.

    Returns:
        A ``FastAPI`` instance using the module's ``lifespan``, processed by
        ``MakeFastAPIOffline``, optionally wrapped in a permissive
        ``CORSMiddleware``, and with every API route mounted.
    """
    application = FastAPI(
        title="Langchain-Chatchat API Server",
        version=VERSION,
        lifespan=lifespan,
    )
    # Presumably serves the docs UI assets locally -- see server.utils.
    MakeFastAPIOffline(application)

    # Cross-origin access is opt-in: set OPEN_CROSS_DOMAIN = True in
    # configs.server_config to allow requests from any origin.
    if OPEN_CROSS_DOMAIN:
        # NOTE(review): wildcard origins together with credentials trusts
        # every site; restrict allow_origins for production deployments.
        cors_options = dict(
            allow_origins=["*"],
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )
        application.add_middleware(CORSMiddleware, **cors_options)

    mount_app_routes(application, run_mode=run_mode)
    return application
|
||
|
||
|
||
def mount_app_routes(app: FastAPI, run_mode: str = None):
    """Register every HTTP route of the API server on *app*.

    Routes are registered in a fixed order (FastAPI matches and documents
    them in registration order). Knowledge-base and summary routes are
    delegated to ``mount_knowledge_routes`` / ``mount_filename_summary_routes``.

    Args:
        app: the application to mount the routes on.
        run_mode: accepted for API compatibility; not used by this function.
    """
    app.get("/",
            response_model=BaseResponse,
            summary="swagger 文档")(document)

    app.get("/chat/get_image",
            tags=["Chat"],
            summary="获取图片",
            )(get_image)

    app.get("/chat/get_self_file",
            tags=["Chat"],
            summary="获取个人知识库文件",
            )(download_self_doc)

    # Tag: Chat
    app.post("/chat/chat_comparison",
             tags=["Chat"],
             summary="生成文献对比报告"
             )(chat_comparison_test)

    # NOTE: /chat/chat and /chat/chat_test both serve chat_test.
    app.post("/chat/chat",
             tags=["Chat"],
             summary="与llm模型对话(通过LLMChain)",
             )(chat_test)

    app.post("/chat",
             tags=["Chat"],
             summary="与llm模型对话(通过LLMChain)",
             )(chat)

    app.post("/chat/translate_text",
             tags=["Chat"],
             summary="翻译",
             )(tarnslate_text)

    # Tag: translate -- file translation task lifecycle.
    app.post("/translate/translate_file",
             tags=["translate"],
             summary="文件翻译",
             )(translate_file)

    app.get("/translate/download_file",
            tags=["translate"],
            summary="下载译文",
            )(download_result)

    app.get("/translate/progress",
            tags=["translate"],
            summary="获取翻译任务进度",
            )(get_progress)

    app.post("/translate/cancel",
             tags=["translate"],
             summary="取消翻译任务",
             )(cancel_task)

    app.post("/chat/check_language",
             tags=["Chat"],
             summary="语种检测接口",
             )(check_language)

    app.get("/chat/get_language",
            tags=["Chat"],
            summary="获取当前支持的语种",
            )(get_supported_languages)

    # Tag: Chat
    app.post("/chat/stop",
             tags=["Chat"],
             summary="中断模型请求",
             )(stop)

    app.post("/chat/upload_Allfile",
             tags=["Chat"],
             summary="上传文件",
             )(upload_file)

    app.post("/chat/outlines",
             tags=["Chat"],
             summary="与llm模型对话生成大纲(通过LLMChain)",
             )(knowledge_chat_test)

    # The "finsh" spelling is part of the public path; do not rename.
    app.post("/chat/finsh_outlines",
             tags=["Chat"],
             summary="与llm模型对话生成全文(通过LLMChain)",
             )(write_article)

    app.post("/chat/chat_test",
             tags=["Chat"],
             summary="与llm模型对话(通过LLMChain)",
             )(chat_test)

    app.post("/chat/search_engine_chat",
             tags=["Chat"],
             summary="与搜索引擎对话",
             )(search_engine_chat)

    app.post("/chat/feedback",
             tags=["Chat"],
             summary="返回llm模型对话评分",
             )(chat_feedback)

    app.post("/chat/gen_title",
             tags=["Chat"],
             summary="生成当前对话的标题",
             )(gen_title)

    # Tag: Write -- text rewriting / generation helpers.
    app.post("/rewrite/con_rewrite",
             tags=["Write"],
             summary="续写文本",
             )(con_rewrite)

    app.post("/rewrite/exp_write",
             tags=["Write"],
             summary="扩写文本",
             )(exp_rewrite)

    app.post("/rewrite/abb_write",
             tags=["Write"],
             summary="缩写文本",
             )(abb_rewrite)

    app.post("/rewrite/rew_rewrite",
             tags=["Write"],
             summary="重写文本",
             )(rew_rewrite)

    app.post("/sentence_reference",
             tags=["Write"],
             summary="好句子提示",
             )(sentence_reference)

    app.post("/gen_abstract",
             tags=["Write"],
             summary="摘要生成",
             )(gen_abstract)

    app.post("/gen_conclusion",
             tags=["Write"],
             summary="结论生成",
             )(gen_conclusion)

    # Tag: Read -- reading-assistance helpers.
    app.post("/gen_keywords",
             tags=["Read"],
             summary="关键词生成",
             )(gen_keywords)

    app.post("/gen_paragraph",
             tags=["Read"],
             summary="章节速览",
             )(gen_paragraph)

    app.post("/word_explain",
             tags=["Read"],
             summary="名词解释",
             )(word_explain)

    app.post("/relevant_articles",
             tags=["Read"],
             summary="相关文献",
             )(relevant_articles)

    # Tag: IastStrategy -- style conversion and zh/en translation rewrites.
    app.post("/rewrite/formal_style",
             tags=["IastStrategy"],
             summary="正式风格",
             )(formal_style)

    app.post("/rewrite/party_style",
             tags=["IastStrategy"],
             summary="党政风格",
             )(party_style)

    app.post("/rewrite/col_style",
             tags=["IastStrategy"],
             summary="口语风格",
             )(col_style)

    app.post("/rewrite/chi_to_ens",
             tags=["IastStrategy"],
             summary="中译英",
             )(chi_to_ens)

    app.post("/rewrite/ens_to_chi",
             tags=["IastStrategy"],
             summary="英译中",
             )(ens_to_chi)

    # Knowledge-base endpoints
    mount_knowledge_routes(app)
    # Summary endpoints
    mount_filename_summary_routes(app)

    # LLM model management endpoints
    app.post("/llm_model/list_running_models",
             tags=["LLM Model Management"],
             summary="列出当前已加载的模型",
             )(list_running_models)

    app.post("/llm_model/list_config_models",
             tags=["LLM Model Management"],
             summary="列出configs已配置的模型",
             )(list_config_models)

    app.post("/llm_model/get_model_config",
             tags=["LLM Model Management"],
             summary="获取模型配置(合并后)",
             )(get_model_config)

    app.post("/llm_model/stop",
             tags=["LLM Model Management"],
             summary="停止指定的LLM模型(Model Worker)",
             )(stop_llm_model)

    app.post("/llm_model/change",
             tags=["LLM Model Management"],
             summary="切换指定的LLM模型(Model Worker)",
             )(change_llm_model)

    # Server state endpoints
    app.post("/server/configs",
             tags=["Server State"],
             summary="获取服务器原始配置信息",
             )(get_server_configs)

    app.post("/server/list_search_engines",
             tags=["Server State"],
             summary="获取服务器支持的搜索引擎",
             )(list_search_engines)

    @app.post("/server/get_prompt_template",
              tags=["Server State"],
              summary="获取服务器配置的 prompt 模板")
    def get_server_prompt_template(
        # `type` shadows the builtin, but it is the request-body field name
        # clients send, so it cannot be renamed without breaking the API.
        type: Literal[
            "llm_chat",
            "knowledge_base_chat",
            "report_chat",
            "search_engine_chat",
            "agent_chat"
        ] = Body("llm_chat",
                 description="模板类型,可选值:llm_chat,knowledge_base_chat,report_chat,search_engine_chat,agent_chat"),
        name: str = Body("default", description="模板名称"),
    ) -> str:
        """Return the configured prompt template of the given *type* and *name*."""
        return get_prompt_template(type=type, name=name)

    # Other endpoints
    app.post("/other/completion",
             tags=["Other"],
             summary="要求llm模型补全(通过LLMChain)",
             )(completion)

    app.post("/other/embed_texts",
             tags=["Other"],
             summary="将文本向量化,支持本地模型和在线模型",
             )(embed_texts_endpoint)

    app.post("/knowledge_base/chapter_overview",
             tags=["Other"],
             summary="文件速览"
             )(chapter_overview)

    app.post("/knowledge_base/abstract_search",
             tags=["Other"],
             summary="相似摘要搜索"
             )(abstract_search)

    app.post("/knowledge_base/article_overview",
             tags=["Other"],
             summary="文件综述"
             )(article_overview)

    app.post("/knowledge_base/paper_translation",
             tags=["Other"],
             summary="论文翻译"
             )(paper_translation)
|
||
|
||
def mount_knowledge_routes(app: FastAPI):
    """Attach the knowledge-base chat routes and KB management routes to *app*.

    The handler modules are imported inside the function, so they are only
    loaded when routes are actually mounted. Routes are registered in a fixed
    order; ``response_model`` is supplied only for the routes that declare one.
    """
    from server.chat.knowledge_base_chat import knowledge_base_chat
    from server.chat.knowledge_base_chat_old import knowledge_base_chat_old
    from server.chat.report_chat import report_chat
    from server.chat.file_chat import upload_temp_docs, file_chat
    from server.chat.agent_chat import agent_chat
    from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
    from server.knowledge_base.kb_doc_api import (
        list_files,
        upload_docs,
        delete_docs,
        update_docs,
        download_doc,
        recreate_vector_store,
        search_docs,
        DocumentWithVSId,
        update_info,
        update_docs_by_id,
    )

    def register(method, path, endpoint, tag, summary, **route_options):
        # Build a fresh single-element tags list per route, exactly like the
        # literal ``tags=[...]`` used at each call site before.
        decorator = getattr(app, method)(path, tags=[tag], summary=summary, **route_options)
        decorator(endpoint)

    chat_tag = "Chat"
    kb_tag = "Knowledge Base Management"

    # Knowledge-base flavoured chat endpoints (file_chat is grouped under KB management).
    register("post", "/chat/knowledge_base_chat", knowledge_base_chat, chat_tag, "与知识库对话")
    register("post", "/chat/self_kb_chat", self_kb_chat, chat_tag, "与个人知识库对话")
    register("post", "/chat/knowledge_base_chat_old", knowledge_base_chat_old, chat_tag, "旧版与知识库对话")
    register("post", "/chat/report_chat", report_chat, chat_tag, "与报告知识库对话")
    register("post", "/chat/file_chat", file_chat, kb_tag, "文件对话")
    register("post", "/chat/agent_chat", agent_chat, chat_tag, "与agent对话")

    # Tag: Knowledge Base Management
    register("get", "/knowledge_base/list_knowledge_bases", list_kbs, kb_tag,
             "获取知识库列表", response_model=ListResponse)
    register("post", "/knowledge_base/create_knowledge_base", create_kb, kb_tag,
             "创建知识库", response_model=BaseResponse)
    register("post", "/knowledge_base/delete_knowledge_base", delete_kb, kb_tag,
             "删除知识库", response_model=BaseResponse)
    register("get", "/knowledge_base/list_files", list_files, kb_tag,
             "获取知识库内的文件列表", response_model=ListResponse)
    register("post", "/knowledge_base/search_docs", search_docs, kb_tag,
             "搜索知识库", response_model=List[DocumentWithVSId])
    register("post", "/knowledge_base/search_self_docs", search_self_docs, kb_tag,
             "搜索个人知识库", response_model=List[DocumentWithVSId])
    register("post", "/knowledge_base/update_docs_by_id", update_docs_by_id, kb_tag,
             "直接更新知识库文档", response_model=BaseResponse)
    register("post", "/knowledge_base/upload_docs", upload_docs, kb_tag,
             "上传文件到知识库,并/或进行向量化", response_model=BaseResponse)
    register("post", "/knowledge_base/upload_docs_new", upload_docs_new, kb_tag,
             "上传文件到知识库,并/或进行向量化,并获取解析结果", response_model=BaseResponse)
    register("post", "/knowledge_base/delete_docs", delete_docs, kb_tag,
             "删除知识库内指定文件", response_model=BaseResponse)
    register("post", "/knowledge_base/update_info", update_info, kb_tag,
             "更新知识库介绍", response_model=BaseResponse)
    register("post", "/knowledge_base/update_docs", update_docs, kb_tag,
             "更新现有文件到知识库", response_model=BaseResponse)

    # These three routes declare no response_model, so none is passed.
    register("get", "/knowledge_base/download_doc", download_doc, kb_tag,
             "下载对应的知识文件")
    register("post", "/knowledge_base/recreate_vector_store", recreate_vector_store, kb_tag,
             "根据content中文档重建向量库,流式输出处理进度。")
    register("post", "/knowledge_base/upload_temp_docs", upload_temp_docs, kb_tag,
             "上传文件到临时目录,用于文件对话。")
|
||
|
||
|
||
def mount_filename_summary_routes(app: FastAPI):
    """Attach the per-knowledge-base summary endpoints to *app*.

    All three are POST routes under ``/knowledge_base/kb_summary_api/``,
    registered in table order; only the doc-ids route declares a
    ``response_model``.
    """
    from server.knowledge_base.kb_summary_api import (
        summary_file_to_vector_store,
        recreate_summary_vector_store,
        summary_doc_ids_to_vector_store,
    )

    prefix = "/knowledge_base/kb_summary_api"
    summary_tag = "Knowledge kb_summary_api Management"

    # (path suffix, handler, OpenAPI summary, extra route options)
    route_table = (
        ("summary_file_to_vector_store", summary_file_to_vector_store,
         "单个知识库根据文件名称摘要", {}),
        ("summary_doc_ids_to_vector_store", summary_doc_ids_to_vector_store,
         "单个知识库根据doc_ids摘要", {"response_model": BaseResponse}),
        ("recreate_summary_vector_store", recreate_summary_vector_store,
         "重建单个知识库文件摘要", {}),
    )

    for suffix, handler, route_summary, extra_options in route_table:
        app.post(
            f"{prefix}/{suffix}",
            tags=[summary_tag],
            summary=route_summary,
            **extra_options,
        )(handler)
|
||
|
||
|
||
def run_api(host, port, **kwargs):
    """Serve the API with uvicorn on *host*:*port*.

    Args:
        host: interface to bind.
        port: TCP port to bind.
        **kwargs: optional settings.
            * ``app``: application to serve. If omitted, the module-level
              ``app`` (created by the ``__main__`` block) is used, and if that
              does not exist a new one is built with ``create_app()``.
              Previously the global was required, so calling ``run_api``
              after a plain import raised ``NameError``.
            * ``ssl_keyfile`` / ``ssl_certfile``: TLS is enabled only when
              both are provided and non-empty; otherwise plain HTTP is served.
    """
    target_app = kwargs.get("app")
    if target_app is None:
        target_app = globals().get("app")
    if target_app is None:
        target_app = create_app()

    ssl_options = {}
    if kwargs.get("ssl_keyfile") and kwargs.get("ssl_certfile"):
        ssl_options = {
            "ssl_keyfile": kwargs.get("ssl_keyfile"),
            "ssl_certfile": kwargs.get("ssl_certfile"),
        }

    uvicorn.run(target_app, host=host, port=port, **ssl_options)
|
||
|
||
|
||
if __name__ == "__main__":
    parser = argparse.ArgumentParser(prog='langchain-ChatGLM',
                                     description='About langchain-ChatGLM, local knowledge based ChatGLM with langchain'
                                                 ' | 基于本地知识库的 ChatGLM 问答')
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int, default=7861)
    parser.add_argument("--ssl_keyfile", type=str)
    parser.add_argument("--ssl_certfile", type=str)

    # Parse the command-line options.
    args = parser.parse_args()

    # Bound at module level: run_api() looks up the global `app`.
    app = create_app()

    run_api(host=args.host,
            port=args.port,
            ssl_keyfile=args.ssl_keyfile,
            ssl_certfile=args.ssl_certfile,
            )
|