[全量] 初始化项目代码、配置、文档及Agent协同harness
This commit is contained in:
164
langchain-chat/server/knowledge_base/kb_cache/base.py
Normal file
164
langchain-chat/server/knowledge_base/kb_cache/base.py
Normal file
@@ -0,0 +1,164 @@
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.vectorstores.faiss import FAISS
|
||||
import threading
|
||||
from configs import (EMBEDDING_MODEL, CHUNK_SIZE,
|
||||
logger, log_verbose)
|
||||
from server.utils import embedding_device, get_model_path, list_online_embed_models, resolve_embed_model_name
|
||||
from contextlib import contextmanager
|
||||
from collections import OrderedDict
|
||||
from typing import List, Any, Union, Tuple
|
||||
|
||||
|
||||
class ThreadSafeObject:
    """Pair an arbitrary object with a re-entrant lock and a "loaded" flag.

    Instances are the values stored in a ``CachePool``. The lock serialises
    access to the wrapped object, and the ``threading.Event`` lets other
    threads block until whoever created the wrapper has finished populating it.
    """

    def __init__(self, key: Union[str, Tuple], obj: Any = None, pool: "CachePool" = None):
        """Create a wrapper.

        Args:
            key: Identifier of this entry inside ``pool``.
            obj: The object to protect; may be ``None`` and assigned later
                through the ``obj`` setter.
            pool: Owning cache pool, or ``None`` for a stand-alone wrapper.
        """
        self._obj = obj
        self._key = key
        self._pool = pool
        self._lock = threading.RLock()
        # Starts unset: wait_for_loading() blocks until finish_loading() runs.
        self._loaded = threading.Event()

    def __repr__(self) -> str:
        cls = type(self).__name__
        return f"<{cls}: key: {self.key}, obj: {self._obj}>"

    @property
    def key(self):
        """Identifier this wrapper was created with."""
        return self._key

    @contextmanager
    def acquire(self, owner: str = "", msg: str = ""):
        """Context manager that holds the lock and yields the wrapped object.

        Args:
            owner: Label used in log lines; defaults to the native thread id.
            msg: Extra text appended to the log lines.

        Yields:
            The wrapped object (whatever is currently stored in ``obj``).
        """
        owner = owner or f"thread {threading.get_native_id()}"
        # Acquire before entering ``try`` so that ``finally`` can never try to
        # release a lock this thread does not hold.
        self._lock.acquire()
        try:
            if self._pool is not None:
                try:
                    # Mark this entry as most recently used.
                    self._pool._cache.move_to_end(self.key)
                except KeyError:
                    # The entry was evicted or popped from the pool while a
                    # caller still held this wrapper; the wrapper itself stays
                    # usable, there is simply nothing left to reorder.
                    pass
            if log_verbose:
                logger.info(f"{owner} 开始操作:{self.key}。{msg}")
            yield self._obj
        finally:
            if log_verbose:
                logger.info(f"{owner} 结束操作:{self.key}。{msg}")
            self._lock.release()

    def start_loading(self):
        """Clear the loaded flag so that waiters block again."""
        self._loaded.clear()

    def finish_loading(self):
        """Set the loaded flag, releasing every thread in wait_for_loading()."""
        self._loaded.set()

    def wait_for_loading(self):
        """Block until finish_loading() has been called."""
        self._loaded.wait()

    @property
    def obj(self):
        """The wrapped object (read without taking the lock)."""
        return self._obj

    @obj.setter
    def obj(self, val: Any):
        self._obj = val
|
||||
|
||||
|
||||
class CachePool:
    """Ordered cache of ``ThreadSafeObject`` entries with optional size cap.

    Entries live in an ``OrderedDict``; the entry at the front is the first
    one dropped when the cap is exceeded.
    """

    def __init__(self, cache_num: int = -1):
        """Create an empty pool.

        Args:
            cache_num: Maximum number of entries. Only a positive ``int``
                enables eviction; any other value leaves the pool unbounded.
        """
        self._cache_num = cache_num
        self._cache = OrderedDict()
        # Re-entrant lock that subclasses hold while registering a new entry.
        self.atomic = threading.RLock()

    def keys(self) -> List[str]:
        """Return a snapshot list of the cached keys."""
        return list(self._cache.keys())

    def _check_count(self):
        """Drop entries from the front until the pool fits its size cap."""
        limit = self._cache_num
        if not isinstance(limit, int) or limit <= 0:
            return
        while len(self._cache) > limit:
            self._cache.popitem(last=False)

    def get(self, key: str) -> ThreadSafeObject:
        """Return the entry for ``key`` once it has finished loading.

        Blocks in ``wait_for_loading()`` until the entry is marked loaded.
        Returns ``None`` when there is no (truthy) entry under ``key``.
        """
        entry = self._cache.get(key)
        if not entry:
            return None
        entry.wait_for_loading()
        return entry

    def set(self, key: str, obj: ThreadSafeObject) -> ThreadSafeObject:
        """Store ``obj`` under ``key``, enforce the size cap, return ``obj``."""
        self._cache[key] = obj
        self._check_count()
        return obj

    def pop(self, key: str = None) -> ThreadSafeObject:
        """Remove an entry from the pool.

        With ``key`` given, returns the stored value or ``None`` if absent.
        With ``key`` omitted, removes the front entry and returns the
        ``(key, value)`` pair produced by ``OrderedDict.popitem``.
        """
        if key is not None:
            return self._cache.pop(key, None)
        return self._cache.popitem(last=False)

    def acquire(self, key: Union[str, Tuple], owner: str = "", msg: str = ""):
        """Look up ``key`` and return a context manager guarding its object.

        Raises:
            RuntimeError: If nothing is cached under ``key``.
        """
        cache = self.get(key)
        if cache is None:
            raise RuntimeError(f"请求的资源 {key} 不存在")
        if not isinstance(cache, ThreadSafeObject):
            # Entries that are not ThreadSafeObject are handed back unchanged.
            return cache
        self._cache.move_to_end(key)
        return cache.acquire(owner=owner, msg=msg)

    def load_kb_embeddings(
            self,
            kb_name: str,
            embed_device: str = embedding_device(),
            default_embed_model: str = EMBEDDING_MODEL,
    ) -> Embeddings:
        """Return the embeddings object configured for a knowledge base.

        The model name is read from the knowledge base detail record, falling
        back to ``default_embed_model`` when the record has no ``embed_model``
        key. Online models are wrapped in ``EmbeddingsFunAdapter``; all other
        models are loaded through the module-level ``embeddings_pool``.
        """
        # Imported here rather than at module level, as in the original code.
        from server.db.repository.knowledge_base_repository import get_kb_detail
        from server.knowledge_base.kb_service.base import EmbeddingsFunAdapter

        detail = get_kb_detail(kb_name)
        configured_model = detail.get("embed_model", default_embed_model)
        embed_model = resolve_embed_model_name(configured_model)

        if embed_model not in list_online_embed_models():
            return embeddings_pool.load_embeddings(model=embed_model, device=embed_device)
        return EmbeddingsFunAdapter(embed_model)
|
||||
|
||||
|
||||
class EmbeddingsPool(CachePool):
    """Cache pool of embedding model objects keyed by ``(model, device)``."""

    def load_embeddings(self, model: str = None, device: str = None) -> Embeddings:
        """Return the embeddings object for ``model`` on ``device``, loading it once.

        Args:
            model: Embedding model name; defaults to ``EMBEDDING_MODEL``.
            device: Device to load local models on; defaults to the value
                returned by ``embedding_device()``.

        Returns:
            The cached (or freshly constructed) embeddings object.

        Raises:
            Exception: Whatever the model constructor raises. In that case the
                placeholder entry is removed from the pool, so a later call
                starts a fresh load instead of waiting forever on it.
        """
        model = model or EMBEDDING_MODEL
        # Honour an explicitly requested device; previously the argument was
        # overwritten unconditionally.
        device = device or embedding_device()
        key = (model, device)

        self.atomic.acquire()
        if not self.get(key):
            item = ThreadSafeObject(key, pool=self)
            self.set(key, item)
            # The item lock is taken before the pool lock is released, so the
            # placeholder is already locked when other threads can see it.
            with item.acquire(msg="初始化"):
                self.atomic.release()
                loaded = False
                try:
                    item.obj = self._build_embeddings(model, device)
                    loaded = True
                finally:
                    if not loaded:
                        # Do not leave a never-loaded placeholder in the cache.
                        self.pop(key)
                    # Always wake threads blocked in wait_for_loading().
                    item.finish_loading()
        else:
            self.atomic.release()
        return self.get(key).obj

    @staticmethod
    def _build_embeddings(model: str, device: str) -> Embeddings:
        """Construct the embeddings object matching the model name."""
        if model == "text-embedding-ada-002":  # openai text-embedding-ada-002
            from langchain.embeddings.openai import OpenAIEmbeddings
            return OpenAIEmbeddings(model=model,
                                    openai_api_key=get_model_path(model),
                                    chunk_size=CHUNK_SIZE)

        if 'bge-' in model:
            from langchain.embeddings import HuggingFaceBgeEmbeddings
            if 'zh' in model:
                # for chinese model
                query_instruction = "为这个句子生成表示以用于检索相关文章:"
            elif 'en' in model:
                # for english model
                query_instruction = "Represent this sentence for searching relevant passages:"
            else:
                # maybe ReRanker or else, just use empty string instead
                query_instruction = ""
            embeddings = HuggingFaceBgeEmbeddings(model_name=get_model_path(model),
                                                  model_kwargs={'device': device},
                                                  query_instruction=query_instruction)
            if model == "bge-large-zh-noinstruct":  # bge large -noinstruct embedding
                embeddings.query_instruction = ""
            return embeddings

        from langchain.embeddings.huggingface import HuggingFaceEmbeddings
        return HuggingFaceEmbeddings(model_name=get_model_path(model),
                                     model_kwargs={'device': device})
|
||||
|
||||
|
||||
# Module-level embeddings pool shared by importers of this module.
# cache_num=1 caps the pool at a single cached entry.
embeddings_pool = EmbeddingsPool(cache_num=1)
|
||||
175
langchain-chat/server/knowledge_base/kb_cache/faiss_cache.py
Normal file
175
langchain-chat/server/knowledge_base/kb_cache/faiss_cache.py
Normal file
@@ -0,0 +1,175 @@
|
||||
from configs import CACHED_VS_NUM, CACHED_MEMO_VS_NUM
|
||||
from server.knowledge_base.kb_cache.base import *
|
||||
from server.knowledge_base.kb_service.base import EmbeddingsFunAdapter
|
||||
from server.utils import load_local_embeddings
|
||||
from server.knowledge_base.utils import get_vs_path
|
||||
from langchain.vectorstores.faiss import FAISS
|
||||
from langchain.docstore.in_memory import InMemoryDocstore
|
||||
from langchain.schema import Document
|
||||
import os
|
||||
from langchain.schema import Document
|
||||
|
||||
|
||||
# Monkeypatch InMemoryDocstore.search so that every Document it returns
# carries its docstore id in Document.metadata["id"].
def _new_ds_search(self, search: str) -> Union[str, Document]:
    """Look up ``search`` in the docstore and tag the result with its id.

    Returns the stored entry, or a "not found" message string when the id is
    unknown. When the entry is a ``Document`` its metadata is updated in place.
    """
    if search not in self._dict:
        return f"ID {search} not found."

    found = self._dict[search]
    if isinstance(found, Document):
        found.metadata["id"] = search
    return found


InMemoryDocstore.search = _new_ds_search
|
||||
|
||||
|
||||
class ThreadSafeFaiss(ThreadSafeObject):
    """``ThreadSafeObject`` specialised for a FAISS vector store."""

    def __repr__(self) -> str:
        cls = type(self).__name__
        return f"<{cls}: key: {self.key}, obj: {self._obj}, docs_count: {self.docs_count()}>"

    def docs_count(self) -> int:
        """Return the number of documents in the wrapped store's docstore.

        Returns 0 while no store has been assigned yet, so that ``repr()`` of
        a wrapper that is still loading does not raise.
        """
        if self._obj is None:
            return 0
        return len(self._obj.docstore._dict)

    def save(self, path: str, create_path: bool = True):
        """Write the vector store to ``path`` while holding the lock.

        Args:
            path: Target directory passed to ``save_local``.
            create_path: Create ``path`` first when it is not already a
                directory.

        Returns:
            Whatever the store's ``save_local`` returns.
        """
        with self.acquire():
            if create_path:
                # exist_ok avoids failing if the directory appears between a
                # separate existence check and the creation call.
                os.makedirs(path, exist_ok=True)
            ret = self._obj.save_local(path)
            logger.info(f"已将向量库 {self.key} 保存到磁盘")
        return ret

    def clear(self):
        """Delete every document from the store while holding the lock.

        Returns:
            The result of the store's ``delete`` call, or an empty list when
            the store was already empty.
        """
        ret = []
        with self.acquire():
            ids = list(self._obj.docstore._dict.keys())
            if ids:
                ret = self._obj.delete(ids)
                # Internal sanity check that the delete really emptied it.
                assert len(self._obj.docstore._dict) == 0
            logger.info(f"已将向量库 {self.key} 清空")
        return ret
|
||||
|
||||
|
||||
class _FaissPool(CachePool):
    """Shared helpers for the FAISS-backed cache pools."""

    def new_vector_store(
            self,
            embed_model: str = EMBEDDING_MODEL,
            embed_device: str = embedding_device(),
    ) -> FAISS:
        """Build an empty FAISS store for ``embed_model``.

        The store is created from a single placeholder document, which is
        deleted again before the store is returned. ``embed_device`` is
        accepted but not used by this method.
        """
        adapter = EmbeddingsFunAdapter(embed_model)
        placeholder = Document(page_content="init", metadata={})
        store = FAISS.from_documents(
            [placeholder],
            adapter,
            normalize_L2=True,
            distance_strategy="METRIC_INNER_PRODUCT",
        )
        placeholder_ids = list(store.docstore._dict.keys())
        store.delete(placeholder_ids)
        return store

    def save_vector_store(self, kb_name: str, path: str=None):
        """Save the cached store registered under ``kb_name`` to ``path``.

        Returns the result of the entry's ``save``; returns ``None`` when
        nothing is cached under ``kb_name``.
        """
        cache = self.get(kb_name)
        if not cache:
            return None
        return cache.save(path)

    def unload_vector_store(self, kb_name: str):
        """Drop the cached store registered under ``kb_name``, if any."""
        if not self.get(kb_name):
            return
        self.pop(kb_name)
        logger.info(f"成功释放向量库:{kb_name}")
|
||||
|
||||
|
||||
class KBFaissPool(_FaissPool):
    """Pool of on-disk knowledge-base FAISS stores keyed by ``(kb_name, vector_name)``."""

    def load_vector_store(
            self,
            kb_name: str,
            vector_name: str = None,
            create: bool = True,
            embed_model: str = EMBEDDING_MODEL,
            embed_device: str = embedding_device(),
    ) -> ThreadSafeFaiss:
        """Return the cached store for a knowledge base, loading it on first use.

        Args:
            kb_name: Knowledge base name.
            vector_name: Name of the vector store; defaults to ``embed_model``.
            create: When no ``index.faiss`` exists on disk, create and save an
                empty store instead of raising.
            embed_model: Embedding model used for loading or creating.
            embed_device: Device passed on to the embedding loader.

        Returns:
            The pool entry wrapping the FAISS store.

        Raises:
            RuntimeError: If the store is not on disk and ``create`` is false.
                The placeholder entry is removed from the pool in that case
                (and on any other load failure), so later calls for the same
                key start over instead of blocking forever.
        """
        self.atomic.acquire()
        vector_name = vector_name or embed_model
        key = (kb_name, vector_name)  # a tuple key instead of a joined string
        cache = self.get(key)
        if cache is None:
            item = ThreadSafeFaiss(key, pool=self)
            self.set(key, item)
            # The item lock is taken before the pool lock is released, so the
            # placeholder is already locked when other threads can see it.
            with item.acquire(msg="初始化"):
                self.atomic.release()
                loaded = False
                try:
                    logger.info(f"loading vector store in '{kb_name}/vector_store/{vector_name}' from disk.")
                    vs_path = get_vs_path(kb_name, vector_name)

                    if os.path.isfile(os.path.join(vs_path, "index.faiss")):
                        embeddings = self.load_kb_embeddings(kb_name=kb_name, embed_device=embed_device, default_embed_model=embed_model)
                        vector_store = FAISS.load_local(vs_path, embeddings, normalize_L2=True, distance_strategy="METRIC_INNER_PRODUCT")
                    elif create:
                        # create an empty vector store
                        os.makedirs(vs_path, exist_ok=True)
                        vector_store = self.new_vector_store(embed_model=embed_model, embed_device=embed_device)
                        vector_store.save_local(vs_path)
                    else:
                        raise RuntimeError(f"knowledge base {kb_name} not exist.")
                    item.obj = vector_store
                    loaded = True
                finally:
                    if not loaded:
                        # Do not leave a never-loaded placeholder in the cache.
                        self.pop(key)
                    # Always wake threads blocked in wait_for_loading().
                    item.finish_loading()
        else:
            self.atomic.release()
        return self.get(key)
|
||||
|
||||
|
||||
class MemoFaissPool(_FaissPool):
    """Pool of in-memory FAISS stores keyed by ``kb_name``."""

    def load_vector_store(
            self,
            kb_name: str,
            embed_model: str = EMBEDDING_MODEL,
            embed_device: str = embedding_device(),
    ) -> ThreadSafeFaiss:
        """Return the cached in-memory store for ``kb_name``, creating it on first use.

        Args:
            kb_name: Key of the store inside this pool.
            embed_model: Embedding model used to create a new empty store.
            embed_device: Device passed on to ``new_vector_store``.

        Returns:
            The pool entry wrapping the FAISS store.

        Raises:
            Exception: Whatever ``new_vector_store`` raises. The placeholder
                entry is removed from the pool in that case, so later calls
                for the same key start over instead of blocking forever.
        """
        self.atomic.acquire()
        cache = self.get(kb_name)
        if cache is None:
            item = ThreadSafeFaiss(kb_name, pool=self)
            self.set(kb_name, item)
            # The item lock is taken before the pool lock is released, so the
            # placeholder is already locked when other threads can see it.
            with item.acquire(msg="初始化"):
                self.atomic.release()
                loaded = False
                try:
                    logger.info(f"loading vector store in '{kb_name}' to memory.")
                    # create an empty vector store
                    item.obj = self.new_vector_store(embed_model=embed_model, embed_device=embed_device)
                    loaded = True
                finally:
                    if not loaded:
                        # Do not leave a never-loaded placeholder in the cache.
                        self.pop(kb_name)
                    # Always wake threads blocked in wait_for_loading().
                    item.finish_loading()
        else:
            self.atomic.release()
        return self.get(kb_name)
|
||||
|
||||
|
||||
# Module-level pool of knowledge-base FAISS stores; CACHED_VS_NUM is passed as its size cap.
kb_faiss_pool = KBFaissPool(cache_num=CACHED_VS_NUM)
|
||||
# Module-level pool of in-memory FAISS stores; CACHED_MEMO_VS_NUM is passed as its size cap.
memo_faiss_pool = MemoFaissPool(cache_num=CACHED_MEMO_VS_NUM)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Manual stress test: many threads concurrently add to, search and clear
    # one knowledge-base vector store through the shared pool.
    import random
    import time
    from pprint import pprint

    kb_names = ["vs1", "vs2", "vs3"]

    def worker(vs_name: str, name: str):
        """Perform one randomly chosen operation on the "samples" store."""
        # The requested name is deliberately overridden so that every worker
        # contends for the same store.
        vs_name = "samples"
        time.sleep(random.randint(1, 5))
        embeddings = load_local_embeddings()
        r = random.randint(1, 3)

        # Keep the pool entry itself: the pool is keyed by
        # (kb_name, vector_name) tuples, so looking the store up again by the
        # bare name would return None.
        store = kb_faiss_pool.load_vector_store(vs_name)
        with store.acquire(name) as vs:
            if r == 1:  # add docs
                ids = vs.add_texts([f"text added by {name}"], embeddings=embeddings)
                pprint(ids)
            elif r == 2:  # search docs
                docs = vs.similarity_search_with_score(f"{name}", k=3, score_threshold=1.0)
                pprint(docs)
        if r == 3:  # delete docs
            logger.warning(f"清除 {vs_name} by {name}")
            store.clear()

    threads = []
    for n in range(1, 30):
        t = threading.Thread(target=worker,
                             kwargs={"vs_name": random.choice(kb_names), "name": f"worker {n}"},
                             daemon=True)
        t.start()
        threads.append(t)

    for t in threads:
        t.join()
|
||||
Reference in New Issue
Block a user