diff --git a/README.md b/README.md
index cd113a1d5174ccb63a6355f943d19bb61aa02f69..5acf8df98485ea49f06f69818a11f8df5fb6a430 100644
--- a/README.md
+++ b/README.md
@@ -55,6 +55,8 @@

 ## 📌 Latest Features

+- 2024-04-16 Add the embedding model 'bce-embedding-base_v1' from [QAnything](https://github.com/netease-youdao/QAnything).
+- 2024-04-16 Add [FastEmbed](https://github.com/qdrant/fastembed), which is designed for lightweight and fast embedding.
 - 2024-04-11 Support [Xinference](./docs/xinference.md) for local LLM deployment.
 - 2024-04-10 Add a new layout recognition model for analyzing Laws documentation.
 - 2024-04-08 Support [Ollama](./docs/ollama.md) for local LLM deployment.
diff --git a/README_ja.md b/README_ja.md
index 431e4fb8972a8a2f37870d50d7ea460ba432e7c4..68a5f73ee976fa5f150a7df87550a2058ab835ca 100644
--- a/README_ja.md
+++ b/README_ja.md
@@ -55,6 +55,8 @@

 ## 📌 最新の機能

+- 2024-04-16 [QAnything](https://github.com/netease-youdao/QAnything) から埋め込みモデル「bce-embedding-base_v1」を追加します。
+- 2024-04-16 [FastEmbed](https://github.com/qdrant/fastembed) は、軽量かつ高速な埋め込み用に設計されています。
 - 2024-04-11 ローカル LLM デプロイメント用に [Xinference](./docs/xinference.md) をサポートします。
 - 2024-04-10 メソッド「Laws」に新しいレイアウト認識モデルを追加します。
 - 2024-04-08 [Ollama](./docs/ollama.md) を使用した大規模モデルのローカライズされたデプロイメントをサポートします。
diff --git a/README_zh.md b/README_zh.md
index ed604c9e15d43478f2195d1e511940c3068deb44..e1f6064b0909465a72ca81f3dfc251afaddb125c 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -55,6 +55,8 @@

 ## 📌 新增功能

+- 2024-04-16 添加嵌入模型 [QAnything 的 bce-embedding-base_v1](https://github.com/netease-youdao/QAnything)。
+- 2024-04-16 添加 [FastEmbed](https://github.com/qdrant/fastembed)，专为轻型和高速嵌入而设计。
 - 2024-04-11 支持用 [Xinference](./docs/xinference.md) 本地化部署大模型。
 - 2024-04-10 为‘Laws’版面分析增加了底层模型。
 - 2024-04-08 支持用 [Ollama](./docs/ollama.md) 本地化部署大模型。
diff --git a/api/apps/chunk_app.py b/api/apps/chunk_app.py
index c0315cbabe6006453916bf2d776b17f95632a297..81f5b528538f07360201f18e90cf75a84e77f111 100644
--- a/api/apps/chunk_app.py
+++ b/api/apps/chunk_app.py
@@ -252,7 +252,7 @@ def retrieval_test():
         return get_data_error_result(retmsg="Knowledgebase not found!")

     embd_mdl = TenantLLMService.model_instance(
-        kb.tenant_id, LLMType.EMBEDDING.value)
+        kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
     ranks = retrievaler.retrieval(question, embd_mdl, kb.tenant_id, [kb_id], page, size,
                                   similarity_threshold, vector_similarity_weight, top,
                                   doc_ids)
     for c in ranks["chunks"]:
diff --git a/api/apps/document_app.py b/api/apps/document_app.py
index 033712fd24d7758f19b47e2efbad675d231f9cf7..c01f73a36e58191075656a38dec6357ea2418541 100644
--- a/api/apps/document_app.py
+++ b/api/apps/document_app.py
@@ -15,6 +15,7 @@
 #

 import base64
+import os
 import pathlib
 import re
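The `chunk_app.py` hunk above is the heart of the PR: the query must be embedded by the same model that embedded the indexed chunks, so `retrieval_test` now passes the knowledge base's own `embd_id` instead of falling back to the tenant default. A standalone toy sketch of why vectors from two different embedding models are not comparable (`fake_embed` is a hypothetical stand-in, not RAGFlow code):

```python
import numpy as np

def fake_embed(text: str, model_seed: int, dim: int = 8) -> np.ndarray:
    """Hypothetical stand-in for an embedding model: deterministic per
    (model, text), but two 'models' map the same text to unrelated vectors."""
    rng = np.random.default_rng(abs(hash((model_seed, text))) % 2**32)
    v = rng.standard_normal(dim)
    return v / np.linalg.norm(v)

doc = fake_embed("net profit rose 12%", model_seed=1)     # indexed with model A
q_same = fake_embed("net profit rose 12%", model_seed=1)  # queried with model A
q_other = fake_embed("net profit rose 12%", model_seed=2) # queried with model B

print(doc @ q_same)   # 1.0: identical text, same model
print(doc @ q_other)  # near 0 on average: same text, different model
```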
diff --git a/api/apps/llm_app.py b/api/apps/llm_app.py
index 78ff7bf7859e0e83af312d369c71c28dc551c228..b98316272181bdfacccb088203ddca6c0e5edd87 100644
--- a/api/apps/llm_app.py
+++ b/api/apps/llm_app.py
@@ -28,7 +28,7 @@ from rag.llm import EmbeddingModel, ChatModel
 def factories():
     try:
         fac = LLMFactoriesService.get_all()
-        return get_json_result(data=[f.to_dict() for f in fac])
+        return get_json_result(data=[f.to_dict() for f in fac if f.name not in ["QAnything", "FastEmbed"]])
     except Exception as e:
         return server_error_response(e)

@@ -174,7 +174,7 @@ def list():
     llms = [m.to_dict()
             for m in llms if m.status == StatusEnum.VALID.value]
     for m in llms:
-        m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding"
+        m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding" or m["fid"] in ["QAnything","FastEmbed"]

     llm_set = set([m["llm_name"] for m in llms])
     for o in objs:
diff --git a/api/db/init_data.py b/api/db/init_data.py
index 759fe6ab0dac738670798d6545a8abf80165accd..14cb414a1f5088391505a47d9a158ab105982ba3 100644
--- a/api/db/init_data.py
+++ b/api/db/init_data.py
@@ -18,7 +18,7 @@ import time
 import uuid

 from api.db import LLMType, UserTenantRole
-from api.db.db_models import init_database_tables as init_web_db, LLMFactories, LLM
+from api.db.db_models import init_database_tables as init_web_db, LLMFactories, LLM, TenantLLM
 from api.db.services import UserService
 from api.db.services.llm_service import LLMFactoriesService, LLMService, TenantLLMService, LLMBundle
 from api.db.services.user_service import TenantService, UserTenantService
@@ -114,12 +114,16 @@ factory_infos = [{
     "logo": "",
     "tags": "TEXT EMBEDDING",
     "status": "1",
-},
-    {
+}, {
     "name": "Xinference",
     "logo": "",
     "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
     "status": "1",
+},{
+    "name": "QAnything",
+    "logo": "",
+    "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
+    "status": "1",
 },
     # {
     #     "name": "文心一言",
@@ -254,12 +258,6 @@ def init_llm_factory():
         "tags": "LLM,CHAT,",
         "max_tokens": 7900,
         "model_type": LLMType.CHAT.value
-    }, {
-        "fid": factory_infos[4]["name"],
-        "llm_name": "flag-embedding",
-        "tags": "TEXT EMBEDDING,",
-        "max_tokens": 128 * 1000,
-        "model_type": LLMType.EMBEDDING.value
     }, {
         "fid": factory_infos[4]["name"],
         "llm_name": "moonshot-v1-32k",
@@ -325,6 +323,14 @@ def init_llm_factory():
         "max_tokens": 2147483648,
         "model_type": LLMType.EMBEDDING.value
     },
+    # ------------------------ QAnything -----------------------
+    {
+        "fid": factory_infos[7]["name"],
+        "llm_name": "maidalun1020/bce-embedding-base_v1",
+        "tags": "TEXT EMBEDDING,",
+        "max_tokens": 512,
+        "model_type": LLMType.EMBEDDING.value
+    },
     ]
     for info in factory_infos:
         try:
@@ -337,8 +343,10 @@ def init_llm_factory():
         except Exception as e:
             pass

-    LLMFactoriesService.filter_delete([LLMFactories.name=="Local"])
-    LLMService.filter_delete([LLM.fid=="Local"])
+    LLMFactoriesService.filter_delete([LLMFactories.name == "Local"])
+    LLMService.filter_delete([LLM.fid == "Local"])
+    LLMService.filter_delete([LLM.fid == "Moonshot", LLM.llm_name == "flag-embedding"])
+    TenantLLMService.filter_delete([TenantLLM.llm_factory == "Moonshot", TenantLLM.llm_name == "flag-embedding"])

     """
     drop table llm;
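One fragility worth flagging in `init_data.py`: new LLM rows reference their factory positionally (`factory_infos[7]["name"]`), so inserting a factory earlier in the list would silently re-parent every later model. A hedged alternative sketch, purely hypothetical and not part of this PR, keys the seed data by factory name instead:

```python
# Hypothetical restructuring: factories keyed by name, models referencing
# that name directly, so list order can never shift an existing "fid".
factory_infos = {
    "Xinference": {"logo": "", "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION", "status": "1"},
    "QAnything": {"logo": "", "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION", "status": "1"},
}

llm_infos = [{
    "fid": "QAnything",  # explicit, instead of factory_infos[7]["name"]
    "llm_name": "maidalun1020/bce-embedding-base_v1",
    "tags": "TEXT EMBEDDING,",
    "max_tokens": 512,
    "model_type": "embedding",
}]

for info in llm_infos:
    # A cheap referential-integrity check before seeding the database.
    assert info["fid"] in factory_infos, f"unknown factory: {info['fid']}"
```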
diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py
index 6e4855c4adb714ab8cc583df8f8131bc15bfad83..db1d6241f1985558430de7c5ae5133389c37c553 100644
--- a/api/db/services/dialog_service.py
+++ b/api/db/services/dialog_service.py
@@ -80,8 +80,12 @@ def chat(dialog, messages, **kwargs):
             raise LookupError("LLM(%s) not found" % dialog.llm_id)
         max_tokens = 1024
     else: max_tokens = llm[0].max_tokens
+    kbs = KnowledgebaseService.get_by_ids(dialog.kb_ids)
+    embd_nms = list(set([kb.embd_id for kb in kbs]))
+    assert len(embd_nms) == 1, "Knowledge bases use different embedding models."
+
     questions = [m["content"] for m in messages if m["role"] == "user"]
-    embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING)
+    embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING, embd_nms[0])
     chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)

     prompt_config = dialog.prompt_config
diff --git a/api/db/services/llm_service.py b/api/db/services/llm_service.py
index b14972dcf8d2b904749cbe0a9b4af96166135f23..a94701b0df2ee0b371f56838829026d6c548eefd 100644
--- a/api/db/services/llm_service.py
+++ b/api/db/services/llm_service.py
@@ -66,7 +66,7 @@ class TenantLLMService(CommonService):
             raise LookupError("Tenant not found")

         if llm_type == LLMType.EMBEDDING.value:
-            mdlnm = tenant.embd_id
+            mdlnm = tenant.embd_id if not llm_name else llm_name
         elif llm_type == LLMType.SPEECH2TEXT.value:
             mdlnm = tenant.asr_id
         elif llm_type == LLMType.IMAGE2TEXT.value:
@@ -77,9 +77,14 @@ class TenantLLMService(CommonService):
             assert False, "LLM type error"

         model_config = cls.get_api_key(tenant_id, mdlnm)
+        if model_config: model_config = model_config.to_dict()
         if not model_config:
-            raise LookupError("Model({}) not authorized".format(mdlnm))
-        model_config = model_config.to_dict()
+            if llm_type == LLMType.EMBEDDING.value:
+                llm = LLMService.query(llm_name=llm_name)
+                if llm and llm[0].fid in ["QAnything", "FastEmbed"]:
+                    model_config = {"llm_factory": llm[0].fid, "api_key": "", "llm_name": llm_name, "api_base": ""}
+            if not model_config: raise LookupError("Model({}) not authorized".format(mdlnm))
+
         if llm_type == LLMType.EMBEDDING.value:
             if model_config["llm_factory"] not in EmbeddingModel:
                 return
diff --git a/api/db/services/task_service.py b/api/db/services/task_service.py
index 360ab5bf8391edc04a84ce768947446fd3b86ec1..624f1ececfb95943a2090d0461ae2aff70bde54e 100644
--- a/api/db/services/task_service.py
+++ b/api/db/services/task_service.py
@@ -41,7 +41,7 @@ class TaskService(CommonService):
             Document.size,
             Knowledgebase.tenant_id,
             Knowledgebase.language,
-            Tenant.embd_id,
+            Knowledgebase.embd_id,
             Tenant.img2txt_id,
             Tenant.asr_id,
             cls.model.update_time]
diff --git a/rag/llm/__init__.py b/rag/llm/__init__.py
index edf1f156fef0946409b35d717bb29cfb46ed504e..7d6d7c44117d0098e6ff8ef6f75d1e6374490b07 100644
--- a/rag/llm/__init__.py
+++ b/rag/llm/__init__.py
@@ -24,8 +24,8 @@ EmbeddingModel = {
     "Xinference": XinferenceEmbed,
     "Tongyi-Qianwen": HuEmbedding, #QWenEmbed,
     "ZHIPU-AI": ZhipuEmbed,
-    "Moonshot": HuEmbedding,
-    "FastEmbed": FastEmbed
+    "FastEmbed": FastEmbed,
+    "QAnything": QAnythingEmbed
 }
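`EmbeddingModel` in `rag/llm/__init__.py` is a plain name-to-class registry, and the `llm_service.py` hunk above synthesizes an empty-credential config when the factory is one of the two that run locally (`QAnything`, `FastEmbed`), since those need no API key. A condensed sketch of that dispatch, with simplified names and `FakeEmbed` standing in for the real wrapper classes:

```python
class FakeEmbed:
    """Stand-in for wrappers like QAnythingEmbed; real ones load a model."""
    def __init__(self, key, model_name, base_url=""):
        self.model_name = model_name

EmbeddingModel = {"QAnything": FakeEmbed, "FastEmbed": FakeEmbed}

def model_instance(model_config):
    factory = model_config["llm_factory"]
    if factory not in EmbeddingModel:
        return None  # caller treats None as "model not available"
    return EmbeddingModel[factory](model_config["api_key"],   # empty for local factories
                                   model_config["llm_name"],
                                   base_url=model_config["api_base"])

# The fallback path from the hunk: no stored credentials, local factory.
cfg = {"llm_factory": "QAnything", "api_key": "",
       "llm_name": "maidalun1020/bce-embedding-base_v1", "api_base": ""}
print(model_instance(cfg).model_name)
```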
diff --git a/rag/llm/embedding_model.py b/rag/llm/embedding_model.py
index 228930a2545e2b8601f6684837fd6cb272b48114..e6ff0b117835819d34106d9f3dcedfc55560df04 100644
--- a/rag/llm/embedding_model.py
+++ b/rag/llm/embedding_model.py
@@ -20,7 +20,6 @@ from abc import ABC
 from ollama import Client
 import dashscope
 from openai import OpenAI
-from fastembed import TextEmbedding
 from FlagEmbedding import FlagModel
 import torch
 import numpy as np
@@ -28,16 +27,17 @@ import numpy as np
 from api.utils.file_utils import get_project_base_directory
 from rag.utils import num_tokens_from_string

+
 try:
     flag_model = FlagModel(os.path.join(
-            get_project_base_directory(),
-            "rag/res/bge-large-zh-v1.5"),
-            query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章：",
-            use_fp16=torch.cuda.is_available())
+        get_project_base_directory(),
+        "rag/res/bge-large-zh-v1.5"),
+        query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章：",
+        use_fp16=torch.cuda.is_available())
 except Exception as e:
     flag_model = FlagModel("BAAI/bge-large-zh-v1.5",
-            query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章：",
-            use_fp16=torch.cuda.is_available())
+                           query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章：",
+                           use_fp16=torch.cuda.is_available())


 class Base(ABC):
@@ -82,8 +82,10 @@ class HuEmbedding(Base):


 class OpenAIEmbed(Base):
-    def __init__(self, key, model_name="text-embedding-ada-002", base_url="https://api.openai.com/v1"):
-        if not base_url: base_url="https://api.openai.com/v1"
+    def __init__(self, key, model_name="text-embedding-ada-002",
+                 base_url="https://api.openai.com/v1"):
+        if not base_url:
+            base_url = "https://api.openai.com/v1"
         self.client = OpenAI(api_key=key, base_url=base_url)
         self.model_name = model_name

@@ -142,7 +144,7 @@ class ZhipuEmbed(Base):
         tks_num = 0
         for txt in texts:
             res = self.client.embeddings.create(input=txt,
-                                                    model=self.model_name)
+                                                model=self.model_name)
             arr.append(res.data[0].embedding)
             tks_num += res.usage.total_tokens
         return np.array(arr), tks_num
@@ -163,14 +165,14 @@ class OllamaEmbed(Base):
         tks_num = 0
         for txt in texts:
             res = self.client.embeddings(prompt=txt,
-                                            model=self.model_name)
+                                         model=self.model_name)
             arr.append(res["embedding"])
             tks_num += 128
         return np.array(arr), tks_num

     def encode_queries(self, text):
         res = self.client.embeddings(prompt=text,
-                                        model=self.model_name)
+                                     model=self.model_name)
         return np.array(res["embedding"]), 128


@@ -183,10 +185,12 @@ class FastEmbed(Base):
         threads: Optional[int] = None,
         **kwargs,
     ):
+        from fastembed import TextEmbedding
         self._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)

     def encode(self, texts: list, batch_size=32):
-        # Using the internal tokenizer to encode the texts and get the total number of tokens
+        # Using the internal tokenizer to encode the texts and get the total
+        # number of tokens
         encodings = self._model.model.tokenizer.encode_batch(texts)
         total_tokens = sum(len(e) for e in encodings)

@@ -195,7 +199,8 @@ class FastEmbed(Base):
         return np.array(embeddings), total_tokens

     def encode_queries(self, text: str):
-        # Using the internal tokenizer to encode the texts and get the total number of tokens
+        # Using the internal tokenizer to encode the texts and get the total
+        # number of tokens
         encoding = self._model.model.tokenizer.encode(text)
         embedding = next(self._model.query_embed(text)).tolist()

@@ -218,3 +223,33 @@ class XinferenceEmbed(Base):
                                            model=self.model_name)
         return np.array(res.data[0].embedding), res.usage.total_tokens


+class QAnythingEmbed(Base):
+    _client = None
+
+    def __init__(self, key=None, model_name="maidalun1020/bce-embedding-base_v1", **kwargs):
+        from BCEmbedding import EmbeddingModel as qanything
+        if not QAnythingEmbed._client:
+            try:
+                print("LOADING BCE...")
+                QAnythingEmbed._client = qanything(model_name_or_path=os.path.join(
+                    get_project_base_directory(),
+                    "rag/res/bce-embedding-base_v1"))
+            except Exception as e:
+                QAnythingEmbed._client = qanything(
+                    model_name_or_path=model_name.replace(
+                        "maidalun1020", "InfiniFlow"))
+
+    def encode(self, texts: list, batch_size=10):
+        res = []
+        token_count = 0
+        for t in texts:
+            token_count += num_tokens_from_string(t)
+        for i in range(0, len(texts), batch_size):
+            embds = QAnythingEmbed._client.encode(texts[i:i + batch_size])
+            res.extend(embds)
+        return np.array(res), token_count
+
+    def encode_queries(self, text):
+        embds = QAnythingEmbed._client.encode([text])
+        return np.array(embds[0]), num_tokens_from_string(text)
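`QAnythingEmbed` caches the heavy BCE model in a class attribute, so it loads once per process no matter how many instances are constructed, and the `BCEmbedding`/`fastembed` imports are deferred into `__init__` so those optional dependencies are only required when the factory is actually selected. A minimal sketch of the lazy-singleton pattern; the lock is an illustrative addition, not in the PR:

```python
import threading

class LazyEmbed:
    _client = None             # shared across all instances
    _lock = threading.Lock()   # guards the one-time load

    def __init__(self, model_name: str):
        if LazyEmbed._client is None:
            with LazyEmbed._lock:
                if LazyEmbed._client is None:  # double-checked locking
                    LazyEmbed._client = self._load(model_name)

    @staticmethod
    def _load(model_name: str):
        print(f"loading {model_name} once per process...")
        return object()  # stand-in for the real BCEmbedding model

a = LazyEmbed("bce-embedding-base_v1")
b = LazyEmbed("bce-embedding-base_v1")
assert a._client is b._client  # the second construction reused the cache
```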
diff --git a/rag/nlp/search.py b/rag/nlp/search.py
index 178466a6812760abf3c20ee66ef322e8ded5f148..422ce54e97e7d18660ecd961849a10505bed57ea 100644
--- a/rag/nlp/search.py
+++ b/rag/nlp/search.py
@@ -46,7 +46,7 @@ class Dealer:
                 "k": topk,
                 "similarity": sim,
                 "num_candidates": topk * 2,
-                "query_vector": list(qv)
+                "query_vector": [float(v) for v in qv]
             }

     def search(self, req, idxnm, emb_mdl=None):
diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py
index 6ea80d9c95bc96050d3238bb9a1509ca2d60e61d..804a151eb9b4df99660a099c951825fbe5a53ece 100644
--- a/rag/svr/task_executor.py
+++ b/rag/svr/task_executor.py
@@ -244,8 +244,9 @@ def main(comm, mod):
     for _, r in rows.iterrows():
         callback = partial(set_progress, r["id"], r["from_page"], r["to_page"])
         try:
-            embd_mdl = LLMBundle(r["tenant_id"], LLMType.EMBEDDING)
+            embd_mdl = LLMBundle(r["tenant_id"], LLMType.EMBEDDING, llm_name=r["embd_id"], lang=r["language"])
         except Exception as e:
+            traceback.print_exc()
             callback(prog=-1, msg=str(e))
             continue
diff --git a/requirements.txt b/requirements.txt
index 9cf9234b7783c662f2b1121735387370cbe43f30..f9ca516fe6c0af3ee2472c6eaa034702274a50b3 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -132,3 +132,5 @@ xpinyin==0.7.6
 xxhash==3.4.1
 yarl==1.9.4
 zhipuai==2.0.1
+BCEmbedding
+loguru==0.7.2
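Finally, the `search.py` hunk converts the query vector to native Python floats because JSON serialization of the kNN request body typically rejects numpy scalars. A standalone reproduction with the standard `json` module:

```python
import json
import numpy as np

qv = np.array([0.1, 0.2, 0.3], dtype=np.float32)  # typical embedder output

try:
    json.dumps({"query_vector": list(qv)})  # list of np.float32 scalars
except TypeError as e:
    print("old form fails:", e)

print(json.dumps({"query_vector": [float(v) for v in qv]}))  # serializes fine
```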