Unverified commit 826ad6a3 authored by Anush, committed by GitHub

feat: FastEmbed embedding support (#291)

### Description

Following up on https://github.com/infiniflow/ragflow/pull/275, this PR
adds support for FastEmbed model configurations.

The options are not exhaustive. You can find the full list
[here](https://qdrant.github.io/fastembed/examples/Supported_Models/).

P.S. I ran into OOM issues when building the image.
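For context, a minimal sketch of the fastembed API this PR wraps (assuming fastembed 0.2.x; the model name is illustrative — any entry from the supported-models list linked above should work):

```python
from fastembed import TextEmbedding

# Model name is an example; see the supported-models list above.
model = TextEmbedding("BAAI/bge-small-en-v1.5")

# embed() returns a generator of numpy arrays, one vector per document.
vectors = list(model.embed(["hello world", "fastembed runs on ONNX"]))
print(len(vectors), vectors[0].shape)  # 2 (384,)
```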

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: KevinHuSh <kevinhu.sh@gmail.com>
parent e5a5b820
@@ -109,6 +109,11 @@ factory_infos = [{
"logo": "",
"tags": "LLM,TEXT EMBEDDING",
"status": "1",
}, {
"name": "FastEmbed",
"logo": "",
"tags": "TEXT EMBEDDING",
"status": "1",
},
{
"name": "Xinference",
@@ -268,6 +273,58 @@ def init_llm_factory():
"max_tokens": 128 * 1000,
"model_type": LLMType.CHAT.value
},
# ------------------------ FastEmbed -----------------------
{
"fid": factory_infos[5]["name"],
"llm_name": "BAAI/bge-small-en-v1.5",
"tags": "TEXT EMBEDDING,",
"max_tokens": 512,
"model_type": LLMType.EMBEDDING.value
}, {
"fid": factory_infos[5]["name"],
"llm_name": "BAAI/bge-small-zh-v1.5",
"tags": "TEXT EMBEDDING,",
"max_tokens": 512,
"model_type": LLMType.EMBEDDING.value
}, {
"fid": factory_infos[5]["name"],
"llm_name": "BAAI/bge-base-en-v1.5",
"tags": "TEXT EMBEDDING,",
"max_tokens": 512,
"model_type": LLMType.EMBEDDING.value
}, {
"fid": factory_infos[5]["name"],
"llm_name": "BAAI/bge-large-en-v1.5",
"tags": "TEXT EMBEDDING,",
"max_tokens": 512,
"model_type": LLMType.EMBEDDING.value
}, {
"fid": factory_infos[5]["name"],
"llm_name": "sentence-transformers/all-MiniLM-L6-v2",
"tags": "TEXT EMBEDDING,",
"max_tokens": 512,
"model_type": LLMType.EMBEDDING.value
}, {
"fid": factory_infos[5]["name"],
"llm_name": "nomic-ai/nomic-embed-text-v1.5",
"tags": "TEXT EMBEDDING,",
"max_tokens": 8192,
"model_type": LLMType.EMBEDDING.value
}, {
"fid": factory_infos[5]["name"],
"llm_name": "jinaai/jina-embeddings-v2-small-en",
"tags": "TEXT EMBEDDING,",
"max_tokens": 2147483648,
"model_type": LLMType.EMBEDDING.value
}, {
"fid": factory_infos[5]["name"],
"llm_name": "jinaai/jina-embeddings-v2-base-en",
"tags": "TEXT EMBEDDING,",
"max_tokens": 2147483648,
"model_type": LLMType.EMBEDDING.value
},
]
for info in factory_infos:
try:
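Worth noting: each new model row points back at its factory by list position, so the index used in the rows above is only valid while the `factory_infos` list keeps its current order. A hypothetical sanity check:

```python
# Hypothetical check of the assumption baked into the rows above:
# index 5 of factory_infos must be the FastEmbed entry.
assert factory_infos[5]["name"] == "FastEmbed"
```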
@@ -24,7 +24,8 @@ EmbeddingModel = {
"Xinference": XinferenceEmbed,
"Tongyi-Qianwen": HuEmbedding, #QWenEmbed,
"ZHIPU-AI": ZhipuEmbed,
"Moonshot": HuEmbedding
"Moonshot": HuEmbedding,
"FastEmbed": FastEmbed
}
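This mapping is how a tenant's configured factory name resolves to an embedding class at runtime. A hypothetical lookup, assuming the dict above and the `FastEmbed` class added later in this diff:

```python
# Hypothetical illustration of the runtime lookup (names as registered above).
mdl_class = EmbeddingModel["FastEmbed"]
mdl = mdl_class(key=None, model_name="BAAI/bge-small-en-v1.5")
vectors, token_count = mdl.encode(["some text to embed"])
```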
@@ -13,12 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Optional
from zhipuai import ZhipuAI
import os
from abc import ABC
from ollama import Client
import dashscope
from openai import OpenAI
from fastembed import TextEmbedding
from FlagEmbedding import FlagModel
import torch
import numpy as np
@@ -172,6 +174,34 @@ class OllamaEmbed(Base):
return np.array(res["embedding"]), 128
class FastEmbed(Base):
    def __init__(
        self,
        key: Optional[str] = None,
        model_name: str = "BAAI/bge-small-en-v1.5",
        cache_dir: Optional[str] = None,
        threads: Optional[int] = None,
        **kwargs,
    ):
        self._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)

    def encode(self, texts: list, batch_size=32):
        # Use the model's internal tokenizer to count the total number of tokens.
        encodings = self._model.model.tokenizer.encode_batch(texts)
        total_tokens = sum(len(e) for e in encodings)
        embeddings = [e.tolist() for e in self._model.embed(texts, batch_size)]
        return np.array(embeddings), total_tokens

    def encode_queries(self, text: str):
        # Use the model's internal tokenizer to count this query's tokens.
        encoding = self._model.model.tokenizer.encode(text)
        embedding = next(self._model.query_embed(text)).tolist()
        return np.array(embedding), len(encoding.ids)
class XinferenceEmbed(Base):
def __init__(self, key, model_name="", base_url=""):
self.client = OpenAI(api_key="xxx", base_url=base_url)
@@ -187,3 +217,4 @@ class XinferenceEmbed(Base):
res = self.client.embeddings.create(input=[text],
model=self.model_name)
return np.array(res.data[0].embedding), res.usage.total_tokens
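A minimal sketch of how the new wrapper could be exercised, assuming the `FastEmbed` class as added above is importable from this module:

```python
mdl = FastEmbed(model_name="BAAI/bge-small-en-v1.5")

# Document embedding: returns an (n, dim) matrix plus the total token count.
vecs, n_tokens = mdl.encode(["first chunk", "second chunk"])
assert vecs.shape[0] == 2

# Query embedding: goes through fastembed's query_embed() and
# returns a single vector plus that query's token count.
qvec, q_tokens = mdl.encode_queries("what is fastembed?")
assert qvec.ndim == 1
```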
@@ -27,6 +27,7 @@ elasticsearch==8.12.1
elasticsearch-dsl==8.12.0
et-xmlfile==1.1.0
filelock==3.13.1
fastembed==0.2.6
FlagEmbedding==1.2.5
Flask==3.0.2
Flask-Cors==4.0.0