diff --git a/api/db/init_data.py b/api/db/init_data.py
index b813b293634dc16e44256140e1329977ca0a28c8..759fe6ab0dac738670798d6545a8abf80165accd 100644
--- a/api/db/init_data.py
+++ b/api/db/init_data.py
@@ -109,6 +109,11 @@ factory_infos = [{
     "logo": "",
     "tags": "LLM,TEXT EMBEDDING",
     "status": "1",
+}, {
+    "name": "FastEmbed",
+    "logo": "",
+    "tags": "TEXT EMBEDDING",
+    "status": "1",
 }, {
     "name": "Xinference",
     "logo": "",
@@ -268,6 +273,56 @@ def init_llm_factory():
             "max_tokens": 128 * 1000,
             "model_type": LLMType.CHAT.value
         },
+        # ------------------------ FastEmbed -----------------------
+        {
+            "fid": factory_infos[5]["name"],
+            "llm_name": "BAAI/bge-small-en-v1.5",
+            "tags": "TEXT EMBEDDING,",
+            "max_tokens": 512,
+            "model_type": LLMType.EMBEDDING.value
+        }, {
+            "fid": factory_infos[5]["name"],
+            "llm_name": "BAAI/bge-small-zh-v1.5",
+            "tags": "TEXT EMBEDDING,",
+            "max_tokens": 512,
+            "model_type": LLMType.EMBEDDING.value
+        }, {
+            "fid": factory_infos[5]["name"],
+            "llm_name": "BAAI/bge-base-en-v1.5",
+            "tags": "TEXT EMBEDDING,",
+            "max_tokens": 512,
+            "model_type": LLMType.EMBEDDING.value
+        }, {
+            "fid": factory_infos[5]["name"],
+            "llm_name": "BAAI/bge-large-en-v1.5",
+            "tags": "TEXT EMBEDDING,",
+            "max_tokens": 512,
+            "model_type": LLMType.EMBEDDING.value
+        }, {
+            "fid": factory_infos[5]["name"],
+            "llm_name": "sentence-transformers/all-MiniLM-L6-v2",
+            "tags": "TEXT EMBEDDING,",
+            "max_tokens": 512,
+            "model_type": LLMType.EMBEDDING.value
+        }, {
+            "fid": factory_infos[5]["name"],
+            "llm_name": "nomic-ai/nomic-embed-text-v1.5",
+            "tags": "TEXT EMBEDDING,",
+            "max_tokens": 8192,
+            "model_type": LLMType.EMBEDDING.value
+        }, {
+            "fid": factory_infos[5]["name"],
+            "llm_name": "jinaai/jina-embeddings-v2-small-en",
+            "tags": "TEXT EMBEDDING,",
+            "max_tokens": 2147483648,
+            "model_type": LLMType.EMBEDDING.value
+        }, {
+            "fid": factory_infos[5]["name"],
+            "llm_name": "jinaai/jina-embeddings-v2-base-en",
+            "tags": "TEXT EMBEDDING,",
+            "max_tokens": 2147483648,
+            "model_type": LLMType.EMBEDDING.value
+        },
     ]
     for info in factory_infos:
         try:
diff --git a/rag/llm/__init__.py b/rag/llm/__init__.py
index c088a7f94443deea005c81952ec1be88e0a9f785..edf1f156fef0946409b35d717bb29cfb46ed504e 100644
--- a/rag/llm/__init__.py
+++ b/rag/llm/__init__.py
@@ -24,7 +24,8 @@ EmbeddingModel = {
     "Xinference": XinferenceEmbed,
     "Tongyi-Qianwen": HuEmbedding, #QWenEmbed,
     "ZHIPU-AI": ZhipuEmbed,
-    "Moonshot": HuEmbedding
+    "Moonshot": HuEmbedding,
+    "FastEmbed": FastEmbed
 }
 
 
diff --git a/rag/llm/embedding_model.py b/rag/llm/embedding_model.py
index aa6b565b87491734e8e1e225e012ca883a6e8297..228930a2545e2b8601f6684837fd6cb272b48114 100644
--- a/rag/llm/embedding_model.py
+++ b/rag/llm/embedding_model.py
@@ -13,12 +13,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+from typing import Optional
 from zhipuai import ZhipuAI
 import os
 from abc import ABC
 from ollama import Client
 import dashscope
 from openai import OpenAI
+from fastembed import TextEmbedding
 from FlagEmbedding import FlagModel
 import torch
 import numpy as np
@@ -172,6 +174,34 @@
         return np.array(res["embedding"]), 128
 
 
+class FastEmbed(Base):
+    def __init__(
+        self,
+        key: Optional[str] = None,
+        model_name: str = "BAAI/bge-small-en-v1.5",
+        cache_dir: Optional[str] = None,
+        threads: Optional[int] = None,
+        **kwargs,
+    ):
+        self._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
+
+    def encode(self, texts: list, batch_size=32):
+        # Using the internal tokenizer to encode the texts and get the total number of tokens
+        encodings = self._model.model.tokenizer.encode_batch(texts)
+        total_tokens = sum(len(e) for e in encodings)
+
+        embeddings = [e.tolist() for e in self._model.embed(texts, batch_size)]
+
+        return np.array(embeddings), total_tokens
+
+    def encode_queries(self, text: str):
+        # Using the internal tokenizer to encode the text and get the number of tokens
+        encoding = self._model.model.tokenizer.encode(text)
+        embedding = next(self._model.query_embed(text)).tolist()
+
+        return np.array(embedding), len(encoding.ids)
+
+
 class XinferenceEmbed(Base):
     def __init__(self, key, model_name="", base_url=""):
         self.client = OpenAI(api_key="xxx", base_url=base_url)
@@ -187,3 +217,4 @@
         res = self.client.embeddings.create(input=[text],
                                             model=self.model_name)
         return np.array(res.data[0].embedding), res.usage.total_tokens
+
diff --git a/requirements.txt b/requirements.txt
index ac885fefb535a3c58817058f5205c05288ba6ef6..9e3296956c8cf41345ff9e971c65ba5240d61aa5 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -27,6 +27,7 @@ elasticsearch==8.12.1
 elasticsearch-dsl==8.12.0
 et-xmlfile==1.1.0
 filelock==3.13.1
+fastembed==0.2.6
 FlagEmbedding==1.2.5
 Flask==3.0.2
 Flask-Cors==4.0.0
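
As a sanity check, here is a minimal sketch (not part of the patch) of how the new FastEmbed wrapper can be exercised once fastembed==0.2.6 is installed; it assumes the default BAAI/bge-small-en-v1.5 weights can be downloaded on first use:

    # Minimal smoke test for the FastEmbed wrapper added above (hypothetical
    # usage, not part of the diff).
    from rag.llm.embedding_model import FastEmbed

    model = FastEmbed()  # defaults to model_name="BAAI/bge-small-en-v1.5"

    # encode() batch-embeds documents and returns (embeddings, token_count),
    # where embeddings is an ndarray of shape (n_texts, dim).
    vectors, n_tokens = model.encode(["hello world", "fastembed runs on ONNX"])
    print(vectors.shape, n_tokens)

    # encode_queries() embeds a single query string via fastembed's
    # query-specific path and also returns its token count.
    qvec, q_tokens = model.encode_queries("what is fastembed?")
    print(qvec.shape, q_tokens)

Note the design choice in the wrapper: token usage is counted with the model's own tokenizer (self._model.model.tokenizer) rather than estimated from character counts, so the totals reported by encode()/encode_queries() stay consistent with what the model actually consumes.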