[全量] 初始化项目代码、配置、文档及Agent协同harness
This commit is contained in:
780
langchain-chat/server/utils.py
Normal file
780
langchain-chat/server/utils.py
Normal file
@@ -0,0 +1,780 @@
|
||||
import pydantic
|
||||
from pydantic import BaseModel
|
||||
from typing import List
|
||||
from fastapi import FastAPI
|
||||
from pathlib import Path
|
||||
import asyncio
|
||||
from configs import (LLM_MODELS, LLM_DEVICE, EMBEDDING_DEVICE,
|
||||
MODEL_PATH, MODEL_ROOT_PATH, ONLINE_LLM_MODEL, logger, log_verbose,
|
||||
FSCHAT_MODEL_WORKERS, HTTPX_DEFAULT_TIMEOUT)
|
||||
import os
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from langchain_community.chat_models import ChatOpenAI
|
||||
from langchain_community.llms import OpenAI
|
||||
import httpx
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Literal,
|
||||
Optional,
|
||||
Callable,
|
||||
Generator,
|
||||
Dict,
|
||||
Any,
|
||||
Awaitable,
|
||||
Union,
|
||||
Tuple
|
||||
)
|
||||
import logging
|
||||
import torch
|
||||
|
||||
from server.minx_chat_openai import MinxChatOpenAI
|
||||
|
||||
|
||||
async def wrap_done(fn: Awaitable, event: asyncio.Event, queue: Optional[asyncio.Queue]=None):
|
||||
"""Wrap an awaitable with a event to signal when it's done or an exception is raised."""
|
||||
try:
|
||||
await fn
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
msg = f"Caught exception: {e}"
|
||||
logger.error(f'{e.__class__.__name__}: {msg}',
|
||||
exc_info=e if log_verbose else None)
|
||||
if queue is not None:
|
||||
await queue.put(e)
|
||||
finally:
|
||||
# Signal the aiter to stop.
|
||||
event.set()
|
||||
|
||||
|
||||
def get_ChatOpenAI(
|
||||
model_name: str,
|
||||
temperature: float,
|
||||
max_tokens: int = None,
|
||||
streaming: bool = True,
|
||||
callbacks: List[Callable] = [],
|
||||
verbose: bool = True,
|
||||
# extra_body: dict = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatOpenAI:
|
||||
config = get_model_worker_config(model_name)
|
||||
if model_name == "openai-api":
|
||||
model_name = config.get("model_name")
|
||||
ChatOpenAI._get_encoding_model = MinxChatOpenAI.get_encoding_model
|
||||
|
||||
if model_name == 'Qwen3-8B':
|
||||
kwargs["extra_body"] = {"chat_template_kwargs": {"enable_thinking": False}}
|
||||
|
||||
model = ChatOpenAI(
|
||||
streaming=streaming,
|
||||
verbose=verbose,
|
||||
callbacks=callbacks,
|
||||
openai_api_key=config.get("api_key", "EMPTY"),
|
||||
openai_api_base=config.get("api_base_url", fschat_openai_api_address()),
|
||||
model_name=model_name,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
openai_proxy=config.get("openai_proxy"),
|
||||
**kwargs
|
||||
)
|
||||
return model
|
||||
|
||||
|
||||
def get_OpenAI(
|
||||
model_name: str,
|
||||
temperature: float,
|
||||
max_tokens: int = None,
|
||||
streaming: bool = True,
|
||||
echo: bool = True,
|
||||
callbacks: List[Callable] = [],
|
||||
verbose: bool = True,
|
||||
**kwargs: Any,
|
||||
) -> OpenAI:
|
||||
config = get_model_worker_config(model_name)
|
||||
if model_name == "openai-api":
|
||||
model_name = config.get("model_name")
|
||||
model = OpenAI(
|
||||
streaming=streaming,
|
||||
verbose=verbose,
|
||||
callbacks=callbacks,
|
||||
openai_api_key=config.get("api_key", "EMPTY"),
|
||||
openai_api_base=config.get("api_base_url", fschat_openai_api_address()),
|
||||
model_name=model_name,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
openai_proxy=config.get("openai_proxy"),
|
||||
echo=echo,
|
||||
**kwargs
|
||||
)
|
||||
return model
|
||||
|
||||
|
||||
class BaseResponse(BaseModel):
|
||||
code: int = pydantic.Field(200, description="API status code")
|
||||
msg: str = pydantic.Field("success", description="API status message")
|
||||
data: Any = pydantic.Field(None, description="API data")
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
"example": {
|
||||
"code": 200,
|
||||
"msg": "success",
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class ListResponse(BaseResponse):
|
||||
data: List[str] = pydantic.Field(..., description="List of names")
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
"example": {
|
||||
"code": 200,
|
||||
"msg": "success",
|
||||
"data": ["doc1.docx", "doc2.pdf", "doc3.txt"],
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
question: str = pydantic.Field(..., description="Question text")
|
||||
response: str = pydantic.Field(..., description="Response text")
|
||||
history: List[List[str]] = pydantic.Field(..., description="History text")
|
||||
source_documents: List[str] = pydantic.Field(
|
||||
..., description="List of source documents and their scores"
|
||||
)
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
"example": {
|
||||
"question": "工伤保险如何办理?",
|
||||
"response": "根据已知信息,可以总结如下:\n\n1. 参保单位为员工缴纳工伤保险费,以保障员工在发生工伤时能够获得相应的待遇。\n"
|
||||
"2. 不同地区的工伤保险缴费规定可能有所不同,需要向当地社保部门咨询以了解具体的缴费标准和规定。\n"
|
||||
"3. 工伤从业人员及其近亲属需要申请工伤认定,确认享受的待遇资格,并按时缴纳工伤保险费。\n"
|
||||
"4. 工伤保险待遇包括工伤医疗、康复、辅助器具配置费用、伤残待遇、工亡待遇、一次性工亡补助金等。\n"
|
||||
"5. 工伤保险待遇领取资格认证包括长期待遇领取人员认证和一次性待遇领取人员认证。\n"
|
||||
"6. 工伤保险基金支付的待遇项目包括工伤医疗待遇、康复待遇、辅助器具配置费用、一次性工亡补助金、丧葬补助金等。",
|
||||
"history": [
|
||||
[
|
||||
"工伤保险是什么?",
|
||||
"工伤保险是指用人单位按照国家规定,为本单位的职工和用人单位的其他人员,缴纳工伤保险费,"
|
||||
"由保险机构按照国家规定的标准,给予工伤保险待遇的社会保险制度。",
|
||||
]
|
||||
],
|
||||
"source_documents": [
|
||||
"出处 [1] 广州市单位从业的特定人员参加工伤保险办事指引.docx:\n\n\t"
|
||||
"( 一) 从业单位 (组织) 按“自愿参保”原则, 为未建 立劳动关系的特定从业人员单项参加工伤保险 、缴纳工伤保 险费。",
|
||||
"出处 [2] ...",
|
||||
"出处 [3] ...",
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def torch_gc():
|
||||
try:
|
||||
import torch
|
||||
if torch.cuda.is_available():
|
||||
# with torch.cuda.device(DEVICE):
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.ipc_collect()
|
||||
elif torch.backends.mps.is_available():
|
||||
try:
|
||||
from torch.mps import empty_cache
|
||||
empty_cache()
|
||||
except Exception as e:
|
||||
msg = ("如果您使用的是 macOS 建议将 pytorch 版本升级至 2.0.0 或更高版本,"
|
||||
"以支持及时清理 torch 产生的内存占用。")
|
||||
logger.error(f'{e.__class__.__name__}: {msg}',
|
||||
exc_info=e if log_verbose else None)
|
||||
except Exception:
|
||||
...
|
||||
|
||||
|
||||
def run_async(cor):
|
||||
'''
|
||||
在同步环境中运行异步代码.
|
||||
'''
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
except:
|
||||
loop = asyncio.new_event_loop()
|
||||
return loop.run_until_complete(cor)
|
||||
|
||||
|
||||
def iter_over_async(ait, loop=None):
|
||||
'''
|
||||
将异步生成器封装成同步生成器.
|
||||
'''
|
||||
ait = ait.__aiter__()
|
||||
|
||||
async def get_next():
|
||||
try:
|
||||
obj = await ait.__anext__()
|
||||
return False, obj
|
||||
except StopAsyncIteration:
|
||||
return True, None
|
||||
|
||||
if loop is None:
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
except:
|
||||
loop = asyncio.new_event_loop()
|
||||
|
||||
while True:
|
||||
done, obj = loop.run_until_complete(get_next())
|
||||
if done:
|
||||
break
|
||||
yield obj
|
||||
|
||||
|
||||
def MakeFastAPIOffline(
|
||||
app: FastAPI,
|
||||
static_dir=Path(__file__).parent / "static",
|
||||
static_url="/static-offline-docs",
|
||||
docs_url: Optional[str] = "/docs",
|
||||
redoc_url: Optional[str] = "/redoc",
|
||||
) -> None:
|
||||
"""patch the FastAPI obj that doesn't rely on CDN for the documentation page"""
|
||||
from fastapi import Request
|
||||
from fastapi.openapi.docs import (
|
||||
get_redoc_html,
|
||||
get_swagger_ui_html,
|
||||
get_swagger_ui_oauth2_redirect_html,
|
||||
)
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from starlette.responses import HTMLResponse
|
||||
import secrets
|
||||
from typing import Annotated
|
||||
from fastapi import Depends, FastAPI, HTTPException, status
|
||||
from fastapi.security import HTTPBasic, HTTPBasicCredentials
|
||||
from starlette.responses import RedirectResponse
|
||||
|
||||
openapi_url = app.openapi_url
|
||||
swagger_ui_oauth2_redirect_url = app.swagger_ui_oauth2_redirect_url
|
||||
|
||||
security = HTTPBasic()
|
||||
def get_current_username(
|
||||
credentials: Annotated[HTTPBasicCredentials, Depends(security)],
|
||||
):
|
||||
current_username_bytes = credentials.username.encode("utf8")
|
||||
correct_username_bytes = b"SwaggerAdmin"
|
||||
is_correct_username = secrets.compare_digest(
|
||||
current_username_bytes, correct_username_bytes
|
||||
)
|
||||
current_password_bytes = credentials.password.encode("utf8")
|
||||
correct_password_bytes = b"G6@PJZh*"
|
||||
is_correct_password = secrets.compare_digest(
|
||||
current_password_bytes, correct_password_bytes
|
||||
)
|
||||
if not (is_correct_username and is_correct_password):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Incorrect username or password",
|
||||
headers={"WWW-Authenticate": "Basic"},
|
||||
)
|
||||
return credentials.username
|
||||
def remove_route(url: str) -> None:
|
||||
'''
|
||||
remove original route from app
|
||||
'''
|
||||
index = None
|
||||
for i, r in enumerate(app.routes):
|
||||
if r.path.lower() == url.lower():
|
||||
index = i
|
||||
break
|
||||
if isinstance(index, int):
|
||||
app.routes.pop(index)
|
||||
|
||||
# Set up static file mount
|
||||
app.mount(
|
||||
static_url,
|
||||
StaticFiles(directory=Path(static_dir).as_posix()),
|
||||
name="static-offline-docs",
|
||||
)
|
||||
|
||||
if docs_url is not None:
|
||||
remove_route(docs_url)
|
||||
remove_route(swagger_ui_oauth2_redirect_url)
|
||||
|
||||
# Define the doc and redoc pages, pointing at the right files
|
||||
@app.get(docs_url, include_in_schema=False)
|
||||
async def custom_swagger_ui_html(request: Request, username: str = Depends(get_current_username)) -> HTMLResponse:
|
||||
root = request.scope.get("root_path")
|
||||
favicon = f"{root}{static_url}/favicon.png"
|
||||
return get_swagger_ui_html(
|
||||
openapi_url=f"{root}{openapi_url}",
|
||||
title=app.title + " - Swagger UI",
|
||||
oauth2_redirect_url=swagger_ui_oauth2_redirect_url,
|
||||
swagger_js_url=f"{root}{static_url}/swagger-ui-bundle.js",
|
||||
swagger_css_url=f"{root}{static_url}/swagger-ui.css",
|
||||
swagger_favicon_url=favicon,
|
||||
)
|
||||
|
||||
@app.get(swagger_ui_oauth2_redirect_url, include_in_schema=False)
|
||||
async def swagger_ui_redirect(username: str = Depends(get_current_username)) -> HTMLResponse:
|
||||
return get_swagger_ui_oauth2_redirect_html()
|
||||
|
||||
if redoc_url is not None:
|
||||
remove_route(redoc_url)
|
||||
|
||||
@app.get(redoc_url, include_in_schema=False)
|
||||
async def redoc_html(request: Request,username: str = Depends(get_current_username)) -> HTMLResponse:
|
||||
root = request.scope.get("root_path")
|
||||
favicon = f"{root}{static_url}/favicon.png"
|
||||
|
||||
return get_redoc_html(
|
||||
openapi_url=f"{root}{openapi_url}",
|
||||
title=app.title + " - ReDoc",
|
||||
redoc_js_url=f"{root}{static_url}/redoc.standalone.js",
|
||||
with_google_fonts=False,
|
||||
redoc_favicon_url=favicon,
|
||||
)
|
||||
|
||||
|
||||
# 从model_config中获取模型信息
|
||||
|
||||
def list_embed_models() -> List[str]:
|
||||
'''
|
||||
get names of configured embedding models
|
||||
'''
|
||||
return list(MODEL_PATH["embed_model"])
|
||||
|
||||
|
||||
def resolve_embed_model_name(model_name: Optional[str]) -> str:
|
||||
"""
|
||||
将数据库/历史配置中的嵌入名映射为当前配置(如 bge_m3 -> bge-m3-api),避免仍走本地 HuggingFace。
|
||||
"""
|
||||
from configs.model_config import EMBED_MODEL_ALIASES, EMBEDDING_MODEL
|
||||
|
||||
if model_name is None or str(model_name).strip() == "":
|
||||
return EMBEDDING_MODEL
|
||||
name = str(model_name).strip()
|
||||
return EMBED_MODEL_ALIASES.get(name, name)
|
||||
|
||||
|
||||
def resolve_llm_model_name(model_name: Optional[str]) -> str:
|
||||
"""
|
||||
将历史/前端传入的模型名映射到当前可用模型名,避免因网关权限导致 403。
|
||||
"""
|
||||
from configs.model_config import LLM_MODEL_ALIASES, LLM_MODELS
|
||||
|
||||
if model_name is None or str(model_name).strip() == "":
|
||||
return LLM_MODELS[0]
|
||||
name = str(model_name).strip()
|
||||
return LLM_MODEL_ALIASES.get(name, name)
|
||||
|
||||
|
||||
def list_config_llm_models() -> Dict[str, Dict]:
|
||||
'''
|
||||
get configured llm models with different types.
|
||||
return {config_type: {model_name: config}, ...}
|
||||
'''
|
||||
workers = FSCHAT_MODEL_WORKERS.copy()
|
||||
workers.pop("default", None)
|
||||
|
||||
return {
|
||||
"local": MODEL_PATH["llm_model"].copy(),
|
||||
"online": ONLINE_LLM_MODEL.copy(),
|
||||
"worker": workers,
|
||||
}
|
||||
|
||||
|
||||
def get_model_path(model_name: str, type: str = None) -> Optional[str]:
|
||||
if type in MODEL_PATH:
|
||||
paths = MODEL_PATH[type]
|
||||
else:
|
||||
paths = {}
|
||||
for v in MODEL_PATH.values():
|
||||
paths.update(v)
|
||||
|
||||
if path_str := paths.get(model_name): # 以 "chatglm-6b": "THUDM/chatglm-6b-new" 为例,以下都是支持的路径
|
||||
path = Path(path_str)
|
||||
if path.is_dir(): # 任意绝对路径
|
||||
return str(path)
|
||||
|
||||
root_path = Path(MODEL_ROOT_PATH)
|
||||
if root_path.is_dir():
|
||||
path = root_path / model_name
|
||||
if path.is_dir(): # use key, {MODEL_ROOT_PATH}/chatglm-6b
|
||||
return str(path)
|
||||
path = root_path / path_str
|
||||
if path.is_dir(): # use value, {MODEL_ROOT_PATH}/THUDM/chatglm-6b-new
|
||||
return str(path)
|
||||
path = root_path / path_str.split("/")[-1]
|
||||
if path.is_dir(): # use value split by "/", {MODEL_ROOT_PATH}/chatglm-6b-new
|
||||
return str(path)
|
||||
return path_str # THUDM/chatglm06b
|
||||
|
||||
|
||||
# 从server_config中获取服务信息
|
||||
|
||||
def get_model_worker_config(model_name: str = None) -> dict:
|
||||
'''
|
||||
加载model worker的配置项。
|
||||
优先级:FSCHAT_MODEL_WORKERS[model_name] > ONLINE_LLM_MODEL[model_name] > FSCHAT_MODEL_WORKERS["default"]
|
||||
'''
|
||||
from configs.model_config import ONLINE_LLM_MODEL, MODEL_PATH
|
||||
from configs.server_config import FSCHAT_MODEL_WORKERS
|
||||
from server import model_workers
|
||||
|
||||
model_name = resolve_llm_model_name(model_name)
|
||||
config = FSCHAT_MODEL_WORKERS.get("default", {}).copy()
|
||||
config.update(ONLINE_LLM_MODEL.get(model_name, {}).copy())
|
||||
config.update(FSCHAT_MODEL_WORKERS.get(model_name, {}).copy())
|
||||
|
||||
if model_name in ONLINE_LLM_MODEL:
|
||||
config["online_api"] = True
|
||||
if provider := config.get("provider"):
|
||||
try:
|
||||
config["worker_class"] = getattr(model_workers, provider)
|
||||
except Exception as e:
|
||||
msg = f"在线模型 ‘{model_name}’ 的provider没有正确配置"
|
||||
logger.error(f'{e.__class__.__name__}: {msg}',
|
||||
exc_info=e if log_verbose else None)
|
||||
# 本地模型
|
||||
if model_name in MODEL_PATH["llm_model"]:
|
||||
path = get_model_path(model_name)
|
||||
config["model_path"] = path
|
||||
if path and os.path.isdir(path):
|
||||
config["model_path_exists"] = True
|
||||
config["device"] = llm_device(config.get("device"))
|
||||
return config
|
||||
|
||||
|
||||
def get_all_model_worker_configs() -> dict:
|
||||
result = {}
|
||||
model_names = set(FSCHAT_MODEL_WORKERS.keys())
|
||||
for name in model_names:
|
||||
if name != "default":
|
||||
result[name] = get_model_worker_config(name)
|
||||
return result
|
||||
|
||||
|
||||
def fschat_controller_address() -> str:
|
||||
from configs.server_config import FSCHAT_CONTROLLER
|
||||
|
||||
host = FSCHAT_CONTROLLER["host"]
|
||||
if host == "0.0.0.0":
|
||||
host = "127.0.0.1"
|
||||
port = FSCHAT_CONTROLLER["port"]
|
||||
return f"http://{host}:{port}"
|
||||
|
||||
|
||||
def fschat_model_worker_address(model_name: str = LLM_MODELS[0]) -> str:
|
||||
if model := get_model_worker_config(model_name):
|
||||
host = model["host"]
|
||||
if host == "0.0.0.0":
|
||||
host = "127.0.0.1"
|
||||
port = model["port"]
|
||||
return f"http://{host}:{port}"
|
||||
return ""
|
||||
|
||||
|
||||
def fschat_openai_api_address() -> str:
|
||||
from configs.server_config import FSCHAT_OPENAI_API
|
||||
|
||||
host = FSCHAT_OPENAI_API["host"]
|
||||
if host == "0.0.0.0":
|
||||
host = "127.0.0.1"
|
||||
port = FSCHAT_OPENAI_API["port"]
|
||||
return f"http://{host}:{port}/v1"
|
||||
|
||||
|
||||
def api_address() -> str:
|
||||
from configs.server_config import API_SERVER
|
||||
|
||||
host = API_SERVER["host"]
|
||||
if host == "0.0.0.0":
|
||||
host = "0.0.0.0"
|
||||
port = API_SERVER["port"]
|
||||
return f"http://{host}:{port}"
|
||||
|
||||
|
||||
def webui_address() -> str:
|
||||
from configs.server_config import WEBUI_SERVER
|
||||
|
||||
host = WEBUI_SERVER["host"]
|
||||
port = WEBUI_SERVER["port"]
|
||||
return f"http://{host}:{port}"
|
||||
|
||||
|
||||
def get_prompt_template(type: str, name: str) -> Optional[str]:
|
||||
'''
|
||||
从prompt_config中加载模板内容
|
||||
type: "llm_chat","agent_chat","knowledge_base_chat","search_engine_chat"的其中一种,如果有新功能,应该进行加入。
|
||||
'''
|
||||
|
||||
from configs import prompt_config
|
||||
import importlib
|
||||
importlib.reload(prompt_config)
|
||||
return prompt_config.PROMPT_TEMPLATES[type].get(name)
|
||||
def replace_variables(prompt_template: str, replace_content: str, replace_param: str) -> Optional[str]:
|
||||
try:
|
||||
# 使用字典解包和.format()方法来替换占位符
|
||||
formatted_string = prompt_template.replace(replace_param,replace_content)
|
||||
return formatted_string
|
||||
except KeyError as e:
|
||||
# 如果模板中的占位符在字典中找不到对应的键,.format()将抛出KeyError
|
||||
# 这里我们捕获异常并返回None,但您也可以选择抛出更具体的异常或进行其他错误处理
|
||||
print(f"Error replacing variable: {e} not found in the dictionary.")
|
||||
return None
|
||||
def get_strategy_prompt_template(type: str, name: str) -> Optional[str]:
|
||||
'''
|
||||
从prompt_config中加载模板内容
|
||||
type: "llm_chat","agent_chat","knowledge_base_chat","search_engine_chat"的其中一种,如果有新功能,应该进行加入。
|
||||
'''
|
||||
|
||||
from configs import strategy_prompt_config
|
||||
import importlib
|
||||
importlib.reload(strategy_prompt_config)
|
||||
return strategy_prompt_config.PROMPT_TEMPLATES[type].get(name)
|
||||
|
||||
|
||||
def get_format_template(type: str, name: str) -> Optional[str]:
|
||||
'''
|
||||
从format_config中加载模板内容
|
||||
type: "llm_chat","agent_chat","knowledge_base_chat","search_engine_chat"
|
||||
'''
|
||||
|
||||
from configs import format_config
|
||||
import importlib
|
||||
importlib.reload(format_config)
|
||||
return format_config.FORMAT_TEMPLATES[type].get(name)
|
||||
|
||||
|
||||
def set_httpx_config(
|
||||
timeout: float = HTTPX_DEFAULT_TIMEOUT,
|
||||
proxy: Union[str, Dict] = None,
|
||||
):
|
||||
'''
|
||||
设置httpx默认timeout。httpx默认timeout是5秒,在请求LLM回答时不够用。
|
||||
将本项目相关服务加入无代理列表,避免fastchat的服务器请求错误。(windows下无效)
|
||||
对于chatgpt等在线API,如要使用代理需要手动配置。搜索引擎的代理如何处置还需考虑。
|
||||
'''
|
||||
|
||||
import httpx
|
||||
import os
|
||||
|
||||
httpx._config.DEFAULT_TIMEOUT_CONFIG.connect = timeout
|
||||
httpx._config.DEFAULT_TIMEOUT_CONFIG.read = timeout
|
||||
httpx._config.DEFAULT_TIMEOUT_CONFIG.write = timeout
|
||||
|
||||
# 在进程范围内设置系统级代理
|
||||
proxies = {}
|
||||
if isinstance(proxy, str):
|
||||
for n in ["http", "https", "all"]:
|
||||
proxies[n + "_proxy"] = proxy
|
||||
elif isinstance(proxy, dict):
|
||||
for n in ["http", "https", "all"]:
|
||||
if p := proxy.get(n):
|
||||
proxies[n + "_proxy"] = p
|
||||
elif p := proxy.get(n + "_proxy"):
|
||||
proxies[n + "_proxy"] = p
|
||||
|
||||
for k, v in proxies.items():
|
||||
os.environ[k] = v
|
||||
|
||||
# set host to bypass proxy
|
||||
no_proxy = [x.strip() for x in os.environ.get("no_proxy", "").split(",") if x.strip()]
|
||||
no_proxy += [
|
||||
# do not use proxy for locahost
|
||||
"http://127.0.0.1",
|
||||
"http://localhost",
|
||||
]
|
||||
# do not use proxy for user deployed fastchat servers
|
||||
for x in [
|
||||
fschat_controller_address(),
|
||||
fschat_model_worker_address(),
|
||||
fschat_openai_api_address(),
|
||||
]:
|
||||
host = ":".join(x.split(":")[:2])
|
||||
if host not in no_proxy:
|
||||
no_proxy.append(host)
|
||||
os.environ["NO_PROXY"] = ",".join(no_proxy)
|
||||
|
||||
def _get_proxies():
|
||||
return proxies
|
||||
|
||||
import urllib.request
|
||||
urllib.request.getproxies = _get_proxies
|
||||
|
||||
|
||||
def detect_device() -> Literal["cuda", "mps", "cpu"]:
|
||||
try:
|
||||
import torch
|
||||
if torch.cuda.is_available():
|
||||
return "cuda"
|
||||
if torch.backends.mps.is_available():
|
||||
return "mps"
|
||||
except:
|
||||
pass
|
||||
return "cpu"
|
||||
|
||||
|
||||
def llm_device(device: str = None) -> Literal["cuda", "mps", "cpu"]:
|
||||
device = device or LLM_DEVICE
|
||||
if device not in ["cuda", "mps", "cpu"]:
|
||||
device = detect_device()
|
||||
return device
|
||||
|
||||
|
||||
def embedding_device(device: str = None) -> Literal["cuda", "mps", "cpu"]:
|
||||
device = device or EMBEDDING_DEVICE
|
||||
if device not in ["cuda", "mps", "cpu"]:
|
||||
device = detect_device()
|
||||
return device
|
||||
|
||||
|
||||
def run_in_thread_pool(
|
||||
func: Callable,
|
||||
params: List[Dict] = [],
|
||||
) -> Generator:
|
||||
'''
|
||||
在线程池中批量运行任务,并将运行结果以生成器的形式返回。
|
||||
请确保任务中的所有操作是线程安全的,任务函数请全部使用关键字参数。
|
||||
'''
|
||||
tasks = []
|
||||
with ThreadPoolExecutor() as pool:
|
||||
for kwargs in params:
|
||||
thread = pool.submit(func, **kwargs)
|
||||
tasks.append(thread)
|
||||
|
||||
for obj in as_completed(tasks):
|
||||
yield obj.result()
|
||||
|
||||
|
||||
def get_httpx_client(
|
||||
use_async: bool = False,
|
||||
proxies: Union[str, Dict] = None,
|
||||
timeout: float = HTTPX_DEFAULT_TIMEOUT,
|
||||
**kwargs,
|
||||
) -> Union[httpx.Client, httpx.AsyncClient]:
|
||||
'''
|
||||
helper to get httpx client with default proxies that bypass local addesses.
|
||||
'''
|
||||
default_proxies = {
|
||||
# do not use proxy for locahost
|
||||
"all://127.0.0.1": None,
|
||||
"all://localhost": None,
|
||||
}
|
||||
# do not use proxy for user deployed fastchat servers
|
||||
for x in [
|
||||
fschat_controller_address(),
|
||||
fschat_model_worker_address(),
|
||||
fschat_openai_api_address(),
|
||||
]:
|
||||
host = ":".join(x.split(":")[:2])
|
||||
default_proxies.update({host: None})
|
||||
|
||||
# get proxies from system envionrent
|
||||
# proxy not str empty string, None, False, 0, [] or {}
|
||||
default_proxies.update({
|
||||
"http://": (os.environ.get("http_proxy")
|
||||
if os.environ.get("http_proxy") and len(os.environ.get("http_proxy").strip())
|
||||
else None),
|
||||
"https://": (os.environ.get("https_proxy")
|
||||
if os.environ.get("https_proxy") and len(os.environ.get("https_proxy").strip())
|
||||
else None),
|
||||
"all://": (os.environ.get("all_proxy")
|
||||
if os.environ.get("all_proxy") and len(os.environ.get("all_proxy").strip())
|
||||
else None),
|
||||
})
|
||||
for host in os.environ.get("no_proxy", "").split(","):
|
||||
if host := host.strip():
|
||||
# default_proxies.update({host: None}) # Origin code
|
||||
default_proxies.update({'all://' + host: None}) # PR 1838 fix, if not add 'all://', httpx will raise error
|
||||
|
||||
# merge default proxies with user provided proxies
|
||||
if isinstance(proxies, str):
|
||||
proxies = {"all://": proxies}
|
||||
|
||||
if isinstance(proxies, dict):
|
||||
default_proxies.update(proxies)
|
||||
|
||||
# construct Client
|
||||
kwargs.update(timeout=timeout, proxies=default_proxies)
|
||||
|
||||
if log_verbose:
|
||||
logger.info(f'{get_httpx_client.__class__.__name__}:kwargs: {kwargs}')
|
||||
|
||||
if use_async:
|
||||
return httpx.AsyncClient(**kwargs)
|
||||
else:
|
||||
return httpx.Client(**kwargs)
|
||||
|
||||
|
||||
def get_server_configs() -> Dict:
|
||||
'''
|
||||
获取configs中的原始配置项,供前端使用
|
||||
'''
|
||||
from configs.kb_config import (
|
||||
DEFAULT_KNOWLEDGE_BASE,
|
||||
DEFAULT_SEARCH_ENGINE,
|
||||
DEFAULT_VS_TYPE,
|
||||
CHUNK_SIZE,
|
||||
OVERLAP_SIZE,
|
||||
SCORE_THRESHOLD,
|
||||
VECTOR_SEARCH_TOP_K,
|
||||
SEARCH_ENGINE_TOP_K,
|
||||
ZH_TITLE_ENHANCE,
|
||||
text_splitter_dict,
|
||||
TEXT_SPLITTER_NAME,
|
||||
)
|
||||
from configs.model_config import (
|
||||
LLM_MODELS,
|
||||
HISTORY_LEN,
|
||||
TEMPERATURE,
|
||||
)
|
||||
from configs.prompt_config import PROMPT_TEMPLATES
|
||||
from configs.prompt_config import PROMPT_ABSTRACT
|
||||
|
||||
_custom = {
|
||||
"controller_address": fschat_controller_address(),
|
||||
"openai_api_address": fschat_openai_api_address(),
|
||||
"api_address": api_address(),
|
||||
}
|
||||
|
||||
return {**{k: v for k, v in locals().items() if k[0] != "_"}, **_custom}
|
||||
|
||||
|
||||
def list_online_embed_models() -> List[str]:
|
||||
from server import model_workers
|
||||
|
||||
ret = []
|
||||
for k, v in list_config_llm_models()["online"].items():
|
||||
if provider := v.get("provider"):
|
||||
worker_class = getattr(model_workers, provider, None)
|
||||
if worker_class is not None and worker_class.can_embedding():
|
||||
ret.append(k)
|
||||
return ret
|
||||
|
||||
|
||||
def load_local_embeddings(model: str = None, device: str = embedding_device()):
|
||||
'''
|
||||
从缓存中加载embeddings,可以避免多线程时竞争加载。
|
||||
'''
|
||||
from server.knowledge_base.kb_cache.base import embeddings_pool
|
||||
from configs import EMBEDDING_MODEL
|
||||
|
||||
model = model or EMBEDDING_MODEL
|
||||
return embeddings_pool.load_embeddings(model=model, device=device)
|
||||
|
||||
|
||||
def get_temp_dir(id: str = None) -> Tuple[str, str]:
|
||||
'''
|
||||
创建一个临时目录,返回(路径,文件夹名称)
|
||||
'''
|
||||
from configs.basic_config import BASE_TEMP_DIR
|
||||
import tempfile
|
||||
|
||||
if id is not None: # 如果指定的临时目录已存在,直接返回
|
||||
path = os.path.join(BASE_TEMP_DIR, id)
|
||||
if os.path.isdir(path):
|
||||
return path, id
|
||||
|
||||
path = tempfile.mkdtemp(dir=BASE_TEMP_DIR)
|
||||
return path, os.path.basename(path)
|
||||
|
||||
def flatten(lst):
|
||||
"""将嵌套列表展平为一维列表"""
|
||||
return [item for sublist in lst for item in sublist]
|
||||
Reference in New Issue
Block a user