Files
gangyan/langchain-chat/server/api.py

579 lines
20 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from server.chat.check_language import check_language, get_supported_languages
from server.chat.chat_comparison import chat_comparison_test
from server.chat.gen_title import gen_title
from server.chat.relevant_articles import relevant_articles
from server.chat.self_kb_chat import self_kb_chat
from server.chat.stop import stop
import nltk
import sys
import os
from server.chat.chat_test import chat_test, get_image
from server.chat.gen_abstract import gen_abstract
from server.chat.gen_conclusion import gen_conclusion
from server.chat.gen_keywords import gen_keywords
from server.chat.gen_paragraph import gen_paragraph
from server.chat.knowledge_chat_test import knowledge_chat_test
from server.chat.translate import tarnslate_text
from server.chat.upload import upload_file
from server.chat.utils import download_self_doc
from server.chat.word_explain import word_explain
from server.chat.write_article import write_article
from server.knowledge_base.kb_doc_api import search_self_docs, upload_docs_new
from server.translator_service.main_api import cancel_task, download_result, get_progress, translate_file
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from configs import VERSION
from configs.model_config import NLTK_DATA_PATH
from configs.server_config import OPEN_CROSS_DOMAIN
import argparse
import uvicorn
from fastapi import Body
from fastapi.middleware.cors import CORSMiddleware
from starlette.responses import RedirectResponse
from server.chat.chat import chat
from server.chat.search_engine_chat import search_engine_chat
from server.chat.completion import completion
from server.custom.chapter_overview import chapter_overview
from server.custom.article_overview import article_overview
from server.custom.abstract_search import abstract_search
from server.custom.paper_translation import paper_translation
from server.chat.feedback import chat_feedback
from server.embeddings_api import embed_texts_endpoint
from server.llm_api import (list_running_models, list_config_models,
change_llm_model, stop_llm_model,
get_model_config, list_search_engines)
from server.utils import (BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline, get_server_configs, get_prompt_template)
from typing import List, Literal
from server.chat.rewrite import(
# con_rewrite,
# exp_write,
# abb_write,
formal_style,
party_style,
col_style,
chi_to_ens,
ens_to_chi
)
from server.chat.con_rewrite import con_rewrite
from server.chat.exp_rewrite import exp_rewrite
from server.chat.abb_rewrite import abb_rewrite
from server.chat.rew_rewrite import rew_rewrite
from server.chat.sentence_reference import sentence_reference
from contextlib import asynccontextmanager
from server.translator_service.task_manager import TaskManager
# Search the configured NLTK_DATA_PATH before NLTK's default locations.
# A fresh list is assigned (the previous list object is not mutated).
nltk.data.path = [NLTK_DATA_PATH, *nltk.data.path]
async def document():
    """Redirect requests for the root URL to the Swagger UI page at ``/docs``."""
    docs_url = "/docs"
    return RedirectResponse(url=docs_url)
async def _call_event_handlers(handlers):
    """Call each legacy Starlette event handler in registration order.

    Handlers registered via ``on_startup`` / ``on_shutdown`` may be plain
    functions or coroutine functions. Awaiting the return value of a plain
    function (normally ``None``) raises ``TypeError``, so only awaitable
    results are awaited.
    """
    import inspect

    for handler in handlers:
        result = handler()
        if inspect.isawaitable(result):
            await result


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: start the translation task manager and run hooks.

    Startup: create a ``TaskManager`` around ``translate_file``, start it,
    store it on ``app.state.tm``, then run every handler registered in
    ``app.router.on_startup`` (including ones injected by the launcher).
    Starlette does not run those legacy handlers itself when a custom
    ``lifespan`` is supplied, hence the manual calls.

    Shutdown: stop the task manager, then run the ``on_shutdown`` handlers.
    """
    tm = TaskManager(translate_fn=translate_file)
    tm.start()
    app.state.tm = tm
    try:
        await _call_event_handlers(app.router.on_startup)
        yield
    finally:
        # Stop the task manager even if a startup hook raised or the app is
        # torn down with an error, so whatever tm.start() set up is released.
        tm.shutdown()
    await _call_event_handlers(app.router.on_shutdown)
def create_app(run_mode: str = None):
    """Build and configure the FastAPI application.

    In order: construct the app with this module's ``lifespan`` handler,
    apply ``MakeFastAPIOffline`` to it, optionally install a CORS middleware,
    then register all routes.

    :param run_mode: forwarded unchanged to ``mount_app_routes``.
    :return: the configured ``FastAPI`` instance.
    """
    application = FastAPI(
        title="Langchain-Chatchat API Server",
        version=VERSION,
        lifespan=lifespan,
    )
    MakeFastAPIOffline(application)

    # Cross-domain access is enabled by OPEN_CROSS_DOMAIN in
    # configs.server_config. When it is on, the policy is fully open: any
    # origin, method and header, with credentials allowed.
    if OPEN_CROSS_DOMAIN:
        application.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )

    mount_app_routes(application, run_mode=run_mode)
    return application
def mount_app_routes(app: FastAPI, run_mode: str = None):
    """Register the core HTTP routes of the API server on *app*.

    Registers the docs redirect, chat, file-translation, writing/reading
    assistance, style-rewrite, LLM model management, server-state and
    miscellaneous routes. Knowledge-base and knowledge-base-summary routes
    are delegated to ``mount_knowledge_routes`` and
    ``mount_filename_summary_routes``.

    :param app: the application to register routes on (mutated in place).
    :param run_mode: accepted for interface compatibility; not read by this
        function.
    """
    app.get("/",
            response_model=BaseResponse,
            summary="swagger 文档")(document)

    app.get("/chat/get_image",
            tags=["Chat"],
            summary="获取图片",
            )(get_image)

    app.get("/chat/get_self_file",
            tags=["Chat"],
            summary="获取个人知识库文件",
            )(download_self_doc)

    # Tag: Chat
    app.post("/chat/chat_comparison",
             tags=["Chat"],
             summary="生成文献对比报告"
             )(chat_comparison_test)

    # NOTE(review): both "/chat/chat" and "/chat/chat_test" (further below)
    # are served by chat_test; the chat() endpoint is mounted at "/chat".
    app.post("/chat/chat",
             tags=["Chat"],
             summary="与llm模型对话(通过LLMChain)",
             )(chat_test)

    app.post("/chat",
             tags=["Chat"],
             summary="与llm模型对话(通过LLMChain)",
             )(chat)

    app.post("/chat/translate_text",
             tags=["Chat"],
             summary="翻译",
             )(tarnslate_text)

    # Tag: translate (file translation, result download, progress, cancel)
    app.post("/translate/translate_file",
             tags=["translate"],
             summary="文件翻译",
             )(translate_file)

    app.get("/translate/download_file",
            tags=["translate"],
            summary="下载译文",
            )(download_result)

    app.get("/translate/progress",
            tags=["translate"],
            summary="获取翻译任务进度",
            )(get_progress)

    app.post("/translate/cancel",
             tags=["translate"],
             summary="取消翻译任务",
             )(cancel_task)

    app.post("/chat/check_language",
             tags=["Chat"],
             summary="语种检测接口",
             )(check_language)

    app.get("/chat/get_language",
            tags=["Chat"],
            summary="获取当前支持的语种",
            )(get_supported_languages)

    # Tag: Chat
    app.post("/chat/stop",
             tags=["Chat"],
             summary="中断模型请求",
             )(stop)

    app.post("/chat/upload_Allfile",
             tags=["Chat"],
             summary="上传文件",
             )(upload_file)

    app.post("/chat/outlines",
             tags=["Chat"],
             summary="与llm模型对话生成大纲(通过LLMChain)",
             )(knowledge_chat_test)

    app.post("/chat/finsh_outlines",
             tags=["Chat"],
             summary="与llm模型对话生成全文(通过LLMChain)",
             )(write_article)

    app.post("/chat/chat_test",
             tags=["Chat"],
             summary="与llm模型对话(通过LLMChain)",
             )(chat_test)

    app.post("/chat/search_engine_chat",
             tags=["Chat"],
             summary="与搜索引擎对话",
             )(search_engine_chat)

    app.post("/chat/feedback",
             tags=["Chat"],
             summary="返回llm模型对话评分",
             )(chat_feedback)

    app.post("/chat/gen_title",
             tags=["Chat"],
             summary="生成当前对话的标题",
             )(gen_title)

    # Tag: Write (continue / expand / abbreviate / rewrite text, generation)
    app.post("/rewrite/con_rewrite",
             tags=["Write"],
             summary="续写文本",
             )(con_rewrite)

    app.post("/rewrite/exp_write",
             tags=["Write"],
             summary="扩写文本",
             )(exp_rewrite)

    app.post("/rewrite/abb_write",
             tags=["Write"],
             summary="缩写文本",
             )(abb_rewrite)

    app.post("/rewrite/rew_rewrite",
             tags=["Write"],
             summary="重写文本",
             )(rew_rewrite)

    app.post("/sentence_reference",
             tags=["Write"],
             summary="好句子提示",
             )(sentence_reference)

    app.post("/gen_abstract",
             tags=["Write"],
             summary="摘要生成",
             )(gen_abstract)

    app.post("/gen_conclusion",
             tags=["Write"],
             summary="结论生成",
             )(gen_conclusion)

    # Tag: Read
    app.post("/gen_keywords",
             tags=["Read"],
             summary="关键词生成",
             )(gen_keywords)

    app.post("/gen_paragraph",
             tags=["Read"],
             summary="章节速览",
             )(gen_paragraph)

    app.post("/word_explain",
             tags=["Read"],
             summary="名词解释",
             )(word_explain)

    app.post("/relevant_articles",
             tags=["Read"],
             summary="相关文献",
             )(relevant_articles)

    # Tag: IastStrategy (style rewrites and Chinese/English translation).
    # The older generic rewrite routes from server.chat.rewrite were retired;
    # continuation/expansion/abbreviation are served by the "Write" routes.
    app.post("/rewrite/formal_style",
             tags=["IastStrategy"],
             summary="正式风格",
             )(formal_style)

    app.post("/rewrite/party_style",
             tags=["IastStrategy"],
             summary="党政风格",
             )(party_style)

    app.post("/rewrite/col_style",
             tags=["IastStrategy"],
             summary="口语风格",
             )(col_style)

    app.post("/rewrite/chi_to_ens",
             tags=["IastStrategy"],
             summary="中译英",
             )(chi_to_ens)

    app.post("/rewrite/ens_to_chi",
             tags=["IastStrategy"],
             summary="英译中",
             )(ens_to_chi)

    # Knowledge-base routes
    mount_knowledge_routes(app)
    # Knowledge-base summary routes
    mount_filename_summary_routes(app)

    # LLM model management routes
    app.post("/llm_model/list_running_models",
             tags=["LLM Model Management"],
             summary="列出当前已加载的模型",
             )(list_running_models)

    app.post("/llm_model/list_config_models",
             tags=["LLM Model Management"],
             summary="列出configs已配置的模型",
             )(list_config_models)

    app.post("/llm_model/get_model_config",
             tags=["LLM Model Management"],
             summary="获取模型配置(合并后)",
             )(get_model_config)

    app.post("/llm_model/stop",
             tags=["LLM Model Management"],
             summary="停止指定的LLM模型(Model Worker)",
             )(stop_llm_model)

    app.post("/llm_model/change",
             tags=["LLM Model Management"],
             summary="切换指定的LLM模型(Model Worker)",
             )(change_llm_model)

    # Server state routes
    app.post("/server/configs",
             tags=["Server State"],
             summary="获取服务器原始配置信息",
             )(get_server_configs)

    app.post("/server/list_search_engines",
             tags=["Server State"],
             summary="获取服务器支持的搜索引擎",
             )(list_search_engines)

    @app.post("/server/get_prompt_template",
              tags=["Server State"],
              summary="获取服务器配置的 prompt 模板")
    def get_server_prompt_template(
        type: Literal[
            "llm_chat",
            "knowledge_base_chat",
            "report_chat",
            "search_engine_chat",
            "agent_chat"
        ] = Body("llm_chat",
                 description="模板类型,可选值:llm_chat,knowledge_base_chat,report_chat,search_engine_chat,agent_chat"),
        name: str = Body("default", description="模板名称"),
    ) -> str:
        """Return the prompt template of the given type and name."""
        return get_prompt_template(type=type, name=name)

    # Miscellaneous routes
    app.post("/other/completion",
             tags=["Other"],
             summary="要求llm模型补全(通过LLMChain)",
             )(completion)

    app.post("/other/embed_texts",
             tags=["Other"],
             summary="将文本向量化,支持本地模型和在线模型",
             )(embed_texts_endpoint)

    app.post("/knowledge_base/chapter_overview",
             tags=["Other"],
             summary="文件速览"
             )(chapter_overview)

    app.post("/knowledge_base/abstract_search",
             tags=["Other"],
             summary="相似摘要搜索"
             )(abstract_search)

    app.post("/knowledge_base/article_overview",
             tags=["Other"],
             summary="文件综述"
             )(article_overview)

    app.post("/knowledge_base/paper_translation",
             tags=["Other"],
             summary="论文翻译"
             )(paper_translation)
def mount_knowledge_routes(app: FastAPI):
    """Register knowledge-base chat and knowledge-base management routes.

    Most endpoint implementations are imported lazily here, at mount time.
    ``self_kb_chat``, ``search_self_docs`` and ``upload_docs_new`` come from
    this module's top-level imports instead.

    :param app: the application to register routes on (mutated in place).
    """
    from server.chat.knowledge_base_chat import knowledge_base_chat
    from server.chat.knowledge_base_chat_old import knowledge_base_chat_old
    from server.chat.report_chat import report_chat
    from server.chat.file_chat import file_chat, upload_temp_docs
    from server.chat.agent_chat import agent_chat
    from server.knowledge_base.kb_api import create_kb, delete_kb, list_kbs
    from server.knowledge_base.kb_doc_api import (
        DocumentWithVSId,
        delete_docs,
        download_doc,
        list_files,
        recreate_vector_store,
        search_docs,
        update_docs,
        update_docs_by_id,
        update_info,
        upload_docs,
    )

    get, post = app.get, app.post
    chat_tag = "Chat"
    kb_tag = "Knowledge Base Management"
    doc_list_model = List[DocumentWithVSId]

    # --- Chat endpoints backed by knowledge bases -------------------------
    post("/chat/knowledge_base_chat", tags=[chat_tag],
         summary="与知识库对话")(knowledge_base_chat)
    post("/chat/self_kb_chat", tags=[chat_tag],
         summary="与个人知识库对话")(self_kb_chat)
    post("/chat/knowledge_base_chat_old", tags=[chat_tag],
         summary="旧版与知识库对话")(knowledge_base_chat_old)
    post("/chat/report_chat", tags=[chat_tag],
         summary="与报告知识库对话")(report_chat)
    post("/chat/file_chat", tags=[kb_tag],
         summary="文件对话")(file_chat)
    post("/chat/agent_chat", tags=[chat_tag],
         summary="与agent对话")(agent_chat)

    # --- Knowledge base management ----------------------------------------
    get("/knowledge_base/list_knowledge_bases", tags=[kb_tag],
        response_model=ListResponse,
        summary="获取知识库列表")(list_kbs)
    post("/knowledge_base/create_knowledge_base", tags=[kb_tag],
         response_model=BaseResponse,
         summary="创建知识库")(create_kb)
    post("/knowledge_base/delete_knowledge_base", tags=[kb_tag],
         response_model=BaseResponse,
         summary="删除知识库")(delete_kb)
    get("/knowledge_base/list_files", tags=[kb_tag],
        response_model=ListResponse,
        summary="获取知识库内的文件列表")(list_files)
    post("/knowledge_base/search_docs", tags=[kb_tag],
         response_model=doc_list_model,
         summary="搜索知识库")(search_docs)
    post("/knowledge_base/search_self_docs", tags=[kb_tag],
         response_model=doc_list_model,
         summary="搜索个人知识库")(search_self_docs)
    post("/knowledge_base/update_docs_by_id", tags=[kb_tag],
         response_model=BaseResponse,
         summary="直接更新知识库文档")(update_docs_by_id)
    post("/knowledge_base/upload_docs", tags=[kb_tag],
         response_model=BaseResponse,
         summary="上传文件到知识库,并/或进行向量化")(upload_docs)
    post("/knowledge_base/upload_docs_new", tags=[kb_tag],
         response_model=BaseResponse,
         summary="上传文件到知识库,并/或进行向量化,并获取解析结果")(upload_docs_new)
    post("/knowledge_base/delete_docs", tags=[kb_tag],
         response_model=BaseResponse,
         summary="删除知识库内指定文件")(delete_docs)
    post("/knowledge_base/update_info", tags=[kb_tag],
         response_model=BaseResponse,
         summary="更新知识库介绍")(update_info)
    post("/knowledge_base/update_docs", tags=[kb_tag],
         response_model=BaseResponse,
         summary="更新现有文件到知识库")(update_docs)
    get("/knowledge_base/download_doc", tags=[kb_tag],
        summary="下载对应的知识文件")(download_doc)
    post("/knowledge_base/recreate_vector_store", tags=[kb_tag],
         summary="根据content中文档重建向量库流式输出处理进度。")(recreate_vector_store)
    post("/knowledge_base/upload_temp_docs", tags=[kb_tag],
         summary="上传文件到临时目录,用于文件对话。")(upload_temp_docs)
def mount_filename_summary_routes(app: FastAPI):
    """Register the knowledge-base summary (``kb_summary_api``) routes.

    The three endpoints are imported lazily, when routes are mounted.

    :param app: the application to register routes on (mutated in place).
    """
    from server.knowledge_base.kb_summary_api import (
        recreate_summary_vector_store,
        summary_doc_ids_to_vector_store,
        summary_file_to_vector_store,
    )

    summary_tag = "Knowledge kb_summary_api Management"

    app.post(
        "/knowledge_base/kb_summary_api/summary_file_to_vector_store",
        tags=[summary_tag],
        summary="单个知识库根据文件名称摘要",
    )(summary_file_to_vector_store)

    app.post(
        "/knowledge_base/kb_summary_api/summary_doc_ids_to_vector_store",
        tags=[summary_tag],
        summary="单个知识库根据doc_ids摘要",
        response_model=BaseResponse,
    )(summary_doc_ids_to_vector_store)

    app.post(
        "/knowledge_base/kb_summary_api/recreate_summary_vector_store",
        tags=[summary_tag],
        summary="重建单个知识库文件摘要",
    )(recreate_summary_vector_store)
def run_api(host, port, **kwargs):
    """Serve the API with uvicorn on ``host:port``.

    Application selection, first match wins:

    1. ``kwargs["app"]`` if supplied and not None;
    2. the module-level ``app`` (assigned by this file's ``__main__`` block);
    3. a fresh ``create_app()``.

    The third case lets importers call ``run_api`` without raising
    ``NameError``, because the global ``app`` only exists when this file runs
    as a script.

    HTTPS is enabled only when both ``ssl_keyfile`` and ``ssl_certfile`` are
    given and truthy; otherwise the server runs plain HTTP.
    """
    application = kwargs.get("app")
    if application is None:
        application = globals().get("app")
    if application is None:
        application = create_app()

    ssl_keyfile = kwargs.get("ssl_keyfile")
    ssl_certfile = kwargs.get("ssl_certfile")
    ssl_options = {}
    if ssl_keyfile and ssl_certfile:
        ssl_options = {"ssl_keyfile": ssl_keyfile, "ssl_certfile": ssl_certfile}

    uvicorn.run(application, host=host, port=port, **ssl_options)
if __name__ == "__main__":
    parser = argparse.ArgumentParser(prog='langchain-ChatGLM',
                                     description='About langchain-ChatGLM, local knowledge based ChatGLM with langchain'
                                                 ' 基于本地知识库的 ChatGLM 问答')
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int, default=7861)
    parser.add_argument("--ssl_keyfile", type=str)
    parser.add_argument("--ssl_certfile", type=str)

    args = parser.parse_args()

    # Kept as a module-level global: run_api() serves the global `app`.
    app = create_app()
    run_api(host=args.host,
            port=args.port,
            ssl_keyfile=args.ssl_keyfile,
            ssl_certfile=args.ssl_certfile,
            )