Skip to content
Snippets Groups Projects
Unverified Commit 826ad6a3 authored by Anush's avatar Anush Committed by GitHub
Browse files

feat: FastEmbed embedding support (#291)

### Description

Following up on https://github.com/infiniflow/ragflow/pull/275, this PR
adds support for FastEmbed model configurations.

The options are not exhaustive. You can find the full list
[here](https://qdrant.github.io/fastembed/examples/Supported_Models/

).

P.S. I ran into OOM issues when building the image.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: default avatarKevinHuSh <kevinhu.sh@gmail.com>
parent e5a5b820
No related branches found
No related tags found
No related merge requests found
...@@ -109,6 +109,11 @@ factory_infos = [{ ...@@ -109,6 +109,11 @@ factory_infos = [{
"logo": "", "logo": "",
"tags": "LLM,TEXT EMBEDDING", "tags": "LLM,TEXT EMBEDDING",
"status": "1", "status": "1",
}, {
"name": "FastEmbed",
"logo": "",
"tags": "TEXT EMBEDDING",
"status": "1",
}, },
{ {
"name": "Xinference", "name": "Xinference",
...@@ -268,6 +273,58 @@ def init_llm_factory(): ...@@ -268,6 +273,58 @@ def init_llm_factory():
"max_tokens": 128 * 1000, "max_tokens": 128 * 1000,
"model_type": LLMType.CHAT.value "model_type": LLMType.CHAT.value
}, },
# ------------------------ FastEmbed -----------------------
{
"fid": factory_infos[5]["name"],
"llm_name": "BAAI/bge-small-en-v1.5",
"tags": "TEXT EMBEDDING,",
"max_tokens": 512,
"model_type": LLMType.EMBEDDING.value
}, {
"fid": factory_infos[5]["name"],
"llm_name": "BAAI/bge-small-zh-v1.5",
"tags": "TEXT EMBEDDING,",
"max_tokens": 512,
"model_type": LLMType.EMBEDDING.value
}, {
}, {
"fid": factory_infos[5]["name"],
"llm_name": "BAAI/bge-base-en-v1.5",
"tags": "TEXT EMBEDDING,",
"max_tokens": 512,
"model_type": LLMType.EMBEDDING.value
}, {
}, {
"fid": factory_infos[5]["name"],
"llm_name": "BAAI/bge-large-en-v1.5",
"tags": "TEXT EMBEDDING,",
"max_tokens": 512,
"model_type": LLMType.EMBEDDING.value
}, {
"fid": factory_infos[5]["name"],
"llm_name": "sentence-transformers/all-MiniLM-L6-v2",
"tags": "TEXT EMBEDDING,",
"max_tokens": 512,
"model_type": LLMType.EMBEDDING.value
}, {
"fid": factory_infos[5]["name"],
"llm_name": "nomic-ai/nomic-embed-text-v1.5",
"tags": "TEXT EMBEDDING,",
"max_tokens": 8192,
"model_type": LLMType.EMBEDDING.value
}, {
"fid": factory_infos[5]["name"],
"llm_name": "jinaai/jina-embeddings-v2-small-en",
"tags": "TEXT EMBEDDING,",
"max_tokens": 2147483648,
"model_type": LLMType.EMBEDDING.value
}, {
"fid": factory_infos[5]["name"],
"llm_name": "jinaai/jina-embeddings-v2-base-en",
"tags": "TEXT EMBEDDING,",
"max_tokens": 2147483648,
"model_type": LLMType.EMBEDDING.value
},
] ]
for info in factory_infos: for info in factory_infos:
try: try:
......
...@@ -24,7 +24,8 @@ EmbeddingModel = { ...@@ -24,7 +24,8 @@ EmbeddingModel = {
"Xinference": XinferenceEmbed, "Xinference": XinferenceEmbed,
"Tongyi-Qianwen": HuEmbedding, #QWenEmbed, "Tongyi-Qianwen": HuEmbedding, #QWenEmbed,
"ZHIPU-AI": ZhipuEmbed, "ZHIPU-AI": ZhipuEmbed,
"Moonshot": HuEmbedding "Moonshot": HuEmbedding,
"FastEmbed": FastEmbed
} }
......
...@@ -13,12 +13,14 @@ ...@@ -13,12 +13,14 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
from typing import Optional
from zhipuai import ZhipuAI from zhipuai import ZhipuAI
import os import os
from abc import ABC from abc import ABC
from ollama import Client from ollama import Client
import dashscope import dashscope
from openai import OpenAI from openai import OpenAI
from fastembed import TextEmbedding
from FlagEmbedding import FlagModel from FlagEmbedding import FlagModel
import torch import torch
import numpy as np import numpy as np
...@@ -172,6 +174,34 @@ class OllamaEmbed(Base): ...@@ -172,6 +174,34 @@ class OllamaEmbed(Base):
return np.array(res["embedding"]), 128 return np.array(res["embedding"]), 128
class FastEmbed(Base):
def __init__(
self,
key: Optional[str] = None,
model_name: str = "BAAI/bge-small-en-v1.5",
cache_dir: Optional[str] = None,
threads: Optional[int] = None,
**kwargs,
):
self._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
def encode(self, texts: list, batch_size=32):
# Using the internal tokenizer to encode the texts and get the total number of tokens
encodings = self._model.model.tokenizer.encode_batch(texts)
total_tokens = sum(len(e) for e in encodings)
embeddings = [e.tolist() for e in self._model.embed(texts, batch_size)]
return np.array(embeddings), total_tokens
def encode_queries(self, text: str):
# Using the internal tokenizer to encode the texts and get the total number of tokens
encoding = self._model.model.tokenizer.encode(text)
embedding = next(self._model.query_embed(text)).tolist()
return np.array(embedding), len(encoding.ids)
class XinferenceEmbed(Base): class XinferenceEmbed(Base):
def __init__(self, key, model_name="", base_url=""): def __init__(self, key, model_name="", base_url=""):
self.client = OpenAI(api_key="xxx", base_url=base_url) self.client = OpenAI(api_key="xxx", base_url=base_url)
...@@ -187,3 +217,4 @@ class XinferenceEmbed(Base): ...@@ -187,3 +217,4 @@ class XinferenceEmbed(Base):
res = self.client.embeddings.create(input=[text], res = self.client.embeddings.create(input=[text],
model=self.model_name) model=self.model_name)
return np.array(res.data[0].embedding), res.usage.total_tokens return np.array(res.data[0].embedding), res.usage.total_tokens
...@@ -27,6 +27,7 @@ elasticsearch==8.12.1 ...@@ -27,6 +27,7 @@ elasticsearch==8.12.1
elasticsearch-dsl==8.12.0 elasticsearch-dsl==8.12.0
et-xmlfile==1.1.0 et-xmlfile==1.1.0
filelock==3.13.1 filelock==3.13.1
fastembed==0.2.6
FlagEmbedding==1.2.5 FlagEmbedding==1.2.5
Flask==3.0.2 Flask==3.0.2
Flask-Cors==4.0.0 Flask-Cors==4.0.0
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment