From 3708b97db905922631244aef03d400bebe6ace85 Mon Sep 17 00:00:00 2001
From: KevinHuSh <kevinhu.sh@gmail.com>
Date: Mon, 8 Apr 2024 19:20:57 +0800
Subject: [PATCH] Support Ollama (#261)

### What problem does this PR solve?

Issue link:#221

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
---
 README.md                    | 13 +++++---
 README_ja.md                 |  9 ++++--
 README_zh.md                 | 15 ++++++----
 api/apps/conversation_app.py |  2 +-
 api/apps/document_app.py     |  2 +-
 api/apps/llm_app.py          | 57 ++++++++++++++++++++++++++++++++++++
 api/apps/user_app.py         |  4 +++
 api/db/init_data.py          | 28 +++++-------------
 docker/docker-compose-CN.yml |  1 +
 docs/ollama.md               | 40 +++++++++++++++++++++++++
 rag/llm/__init__.py          |  6 ++--
 rag/llm/chat_model.py        | 27 +++++++++++++++++
 rag/llm/cv_model.py          | 24 ++++++++++++++-
 rag/llm/embedding_model.py   | 24 +++++++++++++--
 rag/svr/task_executor.py     | 25 +++++++++++++---
 15 files changed, 234 insertions(+), 43 deletions(-)
 create mode 100644 docs/ollama.md

diff --git a/README.md b/README.md
index 8f9827a..9ef33c0 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,6 @@
 <div align="center">
 <a href="https://demo.ragflow.io/">
-<img src="web/src/assets/logo-with-text.png" width="350" alt="ragflow logo">
+<img src="web/src/assets/logo-with-text.png" width="520" alt="ragflow logo">
 </a>
 </div>
 
@@ -124,12 +124,12 @@
 
     * Running on all addresses (0.0.0.0)
     * Running on http://127.0.0.1:9380
-    * Running on http://172.22.0.5:9380
+    * Running on http://x.x.x.x:9380
     INFO:werkzeug:Press CTRL+C to quit
    ```
 
-5. In your web browser, enter the IP address of your server as prompted and log in to RAGFlow.
-   > In the given scenario, you only need to enter `http://IP_of_RAGFlow ` (sans port number) as the default HTTP serving port `80` can be omitted when using the default configurations.
+5. In your web browser, enter the IP address of your server and log in to RAGFlow.
+   > In the given scenario, you only need to enter `http://IP_OF_YOUR_MACHINE` (sans port number) as the default HTTP serving port `80` can be omitted when using the default configurations.
 6. In [service_conf.yaml](./docker/service_conf.yaml), select the desired LLM factory in `user_default_llm` and update the `API_KEY` field with the corresponding API key.
 
    > See [./docs/llm_api_key_setup.md](./docs/llm_api_key_setup.md) for more information.
@@ -168,6 +168,11 @@ $ cd ragflow/docker
 $ docker compose up -d
 ```
 
+## 🆕 Latest Features
+
+- Support [Ollam](./docs/ollama.md) for local LLM deployment.
+- Support Chinese UI.
+
 ## 📜 Roadmap
 
 See the [RAGFlow Roadmap 2024](https://github.com/infiniflow/ragflow/issues/162)
diff --git a/README_ja.md b/README_ja.md
index e6e2ed3..8437beb 100644
--- a/README_ja.md
+++ b/README_ja.md
@@ -124,12 +124,12 @@
 
     * Running on all addresses (0.0.0.0)
     * Running on http://127.0.0.1:9380
-    * Running on http://172.22.0.5:9380
+    * Running on http://x.x.x.x:9380
     INFO:werkzeug:Press CTRL+C to quit
    ```
 
 5. ウェブブラウザで、プロンプトに従ってサーバーの IP アドレスを入力し、RAGFlow にログインします。
-   > デフォルトの設定を使用する場合、デフォルトの HTTP サービングポート `80` は省略できるので、与えられたシナリオでは、`http://172.22.0.5`(ポート番号は省略)だけを入力すればよい。
+   > デフォルトの設定を使用する場合、デフォルトの HTTP サービングポート `80` は省略できるので、与えられたシナリオでは、`http://IP_OF_YOUR_MACHINE`(ポート番号は省略)だけを入力すればよい。
 6. [service_conf.yaml](./docker/service_conf.yaml) で、`user_default_llm` で希望の LLM ファクトリを選択し、`API_KEY` フィールドを対応する API キーで更新する。
 
    > 詳しくは [./docs/llm_api_key_setup.md](./docs/llm_api_key_setup.md) を参照してください。
@@ -168,6 +168,11 @@ $ cd ragflow/docker
 $ docker compose up -d
 ```
 
+## 🆕 最新の新機能
+
+- [Ollam](./docs/ollama.md) を使用した大規模モデルのローカライズされたデプロイメントをサポートします。
+- 中国語インターフェースをサポートします。
+
 ## 📜 ロードマップ
 
 [RAGFlow ロードマップ 2024](https://github.com/infiniflow/ragflow/issues/162) を参照
diff --git a/README_zh.md b/README_zh.md
index d7452bd..eec642e 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -124,12 +124,12 @@
 
     * Running on all addresses (0.0.0.0)
     * Running on http://127.0.0.1:9380
-    * Running on http://172.22.0.5:9380
+    * Running on http://x.x.x.x:9380
     INFO:werkzeug:Press CTRL+C to quit
    ```
 
-5. 根据刚才的界面提示在你的浏览器中输入你的服务器对应的 IP 地址并登录 RAGFlow。
-   > 上面这个例子中,您只需输入 http://172.22.0.5 即可:未改动过配置则无需输入端口(默认的 HTTP 服务端口 80)。
+5. 在你的浏览器中输入你的服务器对应的 IP 地址并登录 RAGFlow。
+   > 上面这个例子中,您只需输入 http://IP_OF_YOUR_MACHINE 即可:未改动过配置则无需输入端口(默认的 HTTP 服务端口 80)。
 6. 在 [service_conf.yaml](./docker/service_conf.yaml) 文件的 `user_default_llm` 栏配置 LLM factory,并在 `API_KEY` 栏填写和你选择的大模型相对应的 API key。
 
    > 详见 [./docs/llm_api_key_setup.md](./docs/llm_api_key_setup.md)。
@@ -168,9 +168,14 @@ $ cd ragflow/docker
 $ docker compose up -d
 ```
 
+## 🆕 最近新特性
+
+- 支持用 [Ollam](./docs/ollama.md) 对大模型进行本地化部署。
+- 支持中文界面。
+
 ## 📜 路线图
 
-详见 [RAGFlow Roadmap 2024](https://github.com/infiniflow/ragflow/issues/162)。
+详见 [RAGFlow Roadmap 2024](https://github.com/infiniflow/ragflow/issues/162) 。
 
 ## 🏄 开源社区
 
@@ -179,7 +184,7 @@ $ docker compose up -d
 
 ## 🙌 贡献指南
 
-RAGFlow 只有通过开源协作才能蓬勃发展。秉持这一精神,我们欢迎来自社区的各种贡献。如果您有意参与其中,请查阅我们的[贡献者指南](https://github.com/infiniflow/ragflow/blob/main/docs/CONTRIBUTING.md)。
+RAGFlow 只有通过开源协作才能蓬勃发展。秉持这一精神,我们欢迎来自社区的各种贡献。如果您有意参与其中,请查阅我们的[贡献者指南](https://github.com/infiniflow/ragflow/blob/main/docs/CONTRIBUTING.md) 。
 
 ## 👥 加入社区
 
diff --git a/api/apps/conversation_app.py b/api/apps/conversation_app.py
index 49374b0..42339e1 100644
--- a/api/apps/conversation_app.py
+++ b/api/apps/conversation_app.py
@@ -126,7 +126,7 @@ def message_fit_in(msg, max_length=4000):
     if c < max_length:
         return c, msg
 
-    msg_ = [m for m in msg[:-1] if m.role == "system"]
+    msg_ = [m for m in msg[:-1] if m["role"] == "system"]
     msg_.append(msg[-1])
     msg = msg_
     c = count()
diff --git a/api/apps/document_app.py b/api/apps/document_app.py
index 6b6715a..29e1686 100644
--- a/api/apps/document_app.py
+++ b/api/apps/document_app.py
@@ -81,7 +81,7 @@ def upload():
             "parser_id": kb.parser_id,
             "parser_config": kb.parser_config,
             "created_by": current_user.id,
-            "type": filename_type(filename),
+            "type": filetype,
             "name": filename,
             "location": location,
             "size": len(blob),
diff --git a/api/apps/llm_app.py b/api/apps/llm_app.py
index a0eb80a..78ff7bf 100644
--- a/api/apps/llm_app.py
+++ b/api/apps/llm_app.py
@@ -91,6 +91,57 @@ def set_api_key():
     return get_json_result(data=True)
 
 
+@manager.route('/add_llm', methods=['POST'])
+@login_required
+@validate_request("llm_factory", "llm_name", "model_type")
+def add_llm():
+    req = request.json
+    llm = {
+        "tenant_id": current_user.id,
+        "llm_factory": req["llm_factory"],
+        "model_type": req["model_type"],
+        "llm_name": req["llm_name"],
+        "api_base": req.get("api_base", ""),
+        "api_key": "xxxxxxxxxxxxxxx"
+    }
+
+    factory = req["llm_factory"]
+    msg = ""
+    if llm["model_type"] == LLMType.EMBEDDING.value:
+        mdl = EmbeddingModel[factory](
+            key=None, model_name=llm["llm_name"], base_url=llm["api_base"])
+        try:
+            arr, tc = mdl.encode(["Test if the api key is available"])
+            if len(arr[0]) == 0 or tc == 0:
+                raise Exception("Fail")
+        except Exception as e:
+            msg += f"\nFail to access embedding model({llm['llm_name']})." + str(e)
+    elif llm["model_type"] == LLMType.CHAT.value:
+        mdl = ChatModel[factory](
+            key=None, model_name=llm["llm_name"], base_url=llm["api_base"])
+        try:
+            m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {
+                             "temperature": 0.9})
+            if not tc:
+                raise Exception(m)
+        except Exception as e:
+            msg += f"\nFail to access model({llm['llm_name']})." + str(
+                e)
+    else:
+        # TODO: check other type of models
+        pass
+
+    if msg:
+        return get_data_error_result(retmsg=msg)
+
+
+    if not TenantLLMService.filter_update(
+            [TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == factory, TenantLLM.llm_name == llm["llm_name"]], llm):
+        TenantLLMService.save(**llm)
+
+    return get_json_result(data=True)
+
+
 @manager.route('/my_llms', methods=['GET'])
 @login_required
 def my_llms():
@@ -125,6 +176,12 @@ def list():
         for m in llms:
             m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding"
 
+        llm_set = set([m["llm_name"] for m in llms])
+        for o in objs:
+            if not o.api_key:continue
+            if o.llm_name in llm_set:continue
+            llms.append({"llm_name": o.llm_name, "model_type": o.model_type, "fid": o.llm_factory, "available": True})
+
         res = {}
         for m in llms:
             if model_type and m["model_type"] != model_type:
diff --git a/api/apps/user_app.py b/api/apps/user_app.py
index 31fc685..47ed1c9 100644
--- a/api/apps/user_app.py
+++ b/api/apps/user_app.py
@@ -181,6 +181,10 @@ def user_info():
 
 
 def rollback_user_registration(user_id):
+    try:
+        UserService.delete_by_id(user_id)
+    except Exception as e:
+        pass
     try:
         TenantService.delete_by_id(user_id)
     except Exception as e:
diff --git a/api/db/init_data.py b/api/db/init_data.py
index 5f34328..4cc72a2 100644
--- a/api/db/init_data.py
+++ b/api/db/init_data.py
@@ -18,7 +18,7 @@ import time
 import uuid
 
 from api.db import LLMType, UserTenantRole
-from api.db.db_models import init_database_tables as init_web_db
+from api.db.db_models import init_database_tables as init_web_db, LLMFactories, LLM
 from api.db.services import UserService
 from api.db.services.llm_service import LLMFactoriesService, LLMService, TenantLLMService, LLMBundle
 from api.db.services.user_service import TenantService, UserTenantService
@@ -100,16 +100,16 @@ factory_infos = [{
     "status": "1",
 },
     {
-    "name": "Local",
+    "name": "Ollama",
     "logo": "",
     "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
         "status": "1",
 }, {
-        "name": "Moonshot",
+    "name": "Moonshot",
     "logo": "",
     "tags": "LLM,TEXT EMBEDDING",
     "status": "1",
-}
+},
     # {
     #     "name": "文心一言",
     #     "logo": "",
@@ -230,20 +230,6 @@ def init_llm_factory():
             "max_tokens": 512,
             "model_type": LLMType.EMBEDDING.value
         },
-        # ---------------------- 本地 ----------------------
-        {
-            "fid": factory_infos[3]["name"],
-            "llm_name": "qwen-14B-chat",
-            "tags": "LLM,CHAT,",
-            "max_tokens": 4096,
-            "model_type": LLMType.CHAT.value
-        }, {
-            "fid": factory_infos[3]["name"],
-            "llm_name": "flag-embedding",
-            "tags": "TEXT EMBEDDING,",
-            "max_tokens": 128 * 1000,
-            "model_type": LLMType.EMBEDDING.value
-        },
         # ------------------------ Moonshot -----------------------
         {
             "fid": factory_infos[4]["name"],
@@ -282,6 +268,9 @@ def init_llm_factory():
         except Exception as e:
             pass
 
+    LLMFactoriesService.filter_delete([LLMFactories.name=="Local"])
+    LLMService.filter_delete([LLM.fid=="Local"])
+
     """
     drop table llm;
     drop table llm_factories;
@@ -295,8 +284,7 @@ def init_llm_factory():
 def init_web_data():
     start_time = time.time()
 
-    if LLMFactoriesService.get_all().count() != len(factory_infos):
-        init_llm_factory()
+    init_llm_factory()
     if not UserService.get_all().count():
         init_superuser()
 
diff --git a/docker/docker-compose-CN.yml b/docker/docker-compose-CN.yml
index a4f3f77..2621634 100644
--- a/docker/docker-compose-CN.yml
+++ b/docker/docker-compose-CN.yml
@@ -20,6 +20,7 @@ services:
       - 443:443
     volumes:
       - ./service_conf.yaml:/ragflow/conf/service_conf.yaml
+      - ./entrypoint.sh:/ragflow/entrypoint.sh
       - ./ragflow-logs:/ragflow/logs
       - ./nginx/ragflow.conf:/etc/nginx/conf.d/ragflow.conf
       - ./nginx/proxy.conf:/etc/nginx/proxy.conf
diff --git a/docs/ollama.md b/docs/ollama.md
new file mode 100644
index 0000000..c226d86
--- /dev/null
+++ b/docs/ollama.md
@@ -0,0 +1,40 @@
+# Ollama
+
+<div align="center" style="margin-top:20px;margin-bottom:20px;">
+<img src="https://github.com/infiniflow/ragflow/assets/12318111/2019e7ee-1e8a-412e-9349-11bbf702e549" width="130"/>
+</div>
+
+One-click deployment of local LLMs, that is [Ollama](https://github.com/ollama/ollama).
+
+## Install
+
+- [Ollama on Linux](https://github.com/ollama/ollama/blob/main/docs/linux.md)
+- [Ollama Windows Preview](https://github.com/ollama/ollama/blob/main/docs/windows.md)
+- [Docker](https://hub.docker.com/r/ollama/ollama)
+
+## Launch Ollama
+
+Decide which LLM you want to deploy ([here's a list for supported LLM](https://ollama.com/library)), say, **mistral**:
+```bash
+$ ollama run mistral
+```
+Or,
+```bash
+$ docker exec -it ollama ollama run mistral
+```
+
+## Use Ollama in RAGFlow
+
+- Go to 'Settings > Model Providers > Models to be added > Ollama'.
+    
+<div align="center" style="margin-top:20px;margin-bottom:20px;">
+<img src="https://github.com/infiniflow/ragflow/assets/12318111/2019e7ee-1e8a-412e-9349-11bbf702e549" width="130"/>
+</div>
+
+> Base URL: Enter the base URL where the Ollama service is accessible, like, http://<your-ollama-endpoint-domain>:11434
+
+- Use Ollama Models.
+
+<div align="center" style="margin-top:20px;margin-bottom:20px;">
+<img src="https://github.com/infiniflow/ragflow/assets/12318111/2019e7ee-1e8a-412e-9349-11bbf702e549" width="130"/>
+</div>
\ No newline at end of file
diff --git a/rag/llm/__init__.py b/rag/llm/__init__.py
index 74a8dbf..c3fc7db 100644
--- a/rag/llm/__init__.py
+++ b/rag/llm/__init__.py
@@ -19,7 +19,7 @@ from .cv_model import *
 
 
 EmbeddingModel = {
-    "Local": HuEmbedding,
+    "Ollama": OllamaEmbed,
     "OpenAI": OpenAIEmbed,
     "Tongyi-Qianwen": HuEmbedding, #QWenEmbed,
     "ZHIPU-AI": ZhipuEmbed,
@@ -29,7 +29,7 @@ EmbeddingModel = {
 
 CvModel = {
     "OpenAI": GptV4,
-    "Local": LocalCV,
+    "Ollama": OllamaCV,
     "Tongyi-Qianwen": QWenCV,
     "ZHIPU-AI": Zhipu4V,
     "Moonshot": LocalCV
@@ -40,7 +40,7 @@ ChatModel = {
     "OpenAI": GptTurbo,
     "ZHIPU-AI": ZhipuChat,
     "Tongyi-Qianwen": QWenChat,
-    "Local": LocalLLM,
+    "Ollama": OllamaChat,
     "Moonshot": MoonshotChat
 }
 
diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py
index c0379a8..4da841b 100644
--- a/rag/llm/chat_model.py
+++ b/rag/llm/chat_model.py
@@ -18,6 +18,7 @@ from dashscope import Generation
 from abc import ABC
 from openai import OpenAI
 import openai
+from ollama import Client
 from rag.nlp import is_english
 from rag.utils import num_tokens_from_string
 
@@ -129,6 +130,32 @@ class ZhipuChat(Base):
             return "**ERROR**: " + str(e), 0
 
 
+class OllamaChat(Base):
+    def __init__(self, key, model_name, **kwargs):
+        self.client = Client(host=kwargs["base_url"])
+        self.model_name = model_name
+
+    def chat(self, system, history, gen_conf):
+        if system:
+            history.insert(0, {"role": "system", "content": system})
+        try:
+            options = {"temperature": gen_conf.get("temperature", 0.1),
+                       "num_predict": gen_conf.get("max_tokens", 128),
+                       "top_k": gen_conf.get("top_p", 0.3),
+                       "presence_penalty": gen_conf.get("presence_penalty", 0.4),
+                       "frequency_penalty": gen_conf.get("frequency_penalty", 0.7),
+                       }
+            response = self.client.chat(
+                model=self.model_name,
+                messages=history,
+                options=options
+            )
+            ans = response["message"]["content"].strip()
+            return ans, response["eval_count"]
+        except Exception as e:
+            return "**ERROR**: " + str(e), 0
+
+
 class LocalLLM(Base):
     class RPCProxy:
         def __init__(self, host, port):
diff --git a/rag/llm/cv_model.py b/rag/llm/cv_model.py
index 61b942c..d764bc8 100644
--- a/rag/llm/cv_model.py
+++ b/rag/llm/cv_model.py
@@ -16,7 +16,7 @@
 from zhipuai import ZhipuAI
 import io
 from abc import ABC
-
+from ollama import Client
 from PIL import Image
 from openai import OpenAI
 import os
@@ -140,6 +140,28 @@ class Zhipu4V(Base):
         return res.choices[0].message.content.strip(), res.usage.total_tokens
 
 
+class OllamaCV(Base):
+    def __init__(self, key, model_name, lang="Chinese", **kwargs):
+        self.client = Client(host=kwargs["base_url"])
+        self.model_name = model_name
+        self.lang = lang
+
+    def describe(self, image, max_tokens=1024):
+        prompt = self.prompt("")
+        try:
+            options = {"num_predict": max_tokens}
+            response = self.client.generate(
+                model=self.model_name,
+                prompt=prompt[0]["content"][1]["text"],
+                images=[image],
+                options=options
+            )
+            ans = response["response"].strip()
+            return ans, 128
+        except Exception as e:
+            return "**ERROR**: " + str(e), 0
+
+
 class LocalCV(Base):
     def __init__(self, key, model_name="glm-4v", lang="Chinese", **kwargs):
         pass
diff --git a/rag/llm/embedding_model.py b/rag/llm/embedding_model.py
index 6ee3a58..d5b763d 100644
--- a/rag/llm/embedding_model.py
+++ b/rag/llm/embedding_model.py
@@ -16,13 +16,12 @@
 from zhipuai import ZhipuAI
 import os
 from abc import ABC
-
+from ollama import Client
 import dashscope
 from openai import OpenAI
 from FlagEmbedding import FlagModel
 import torch
 import numpy as np
-from huggingface_hub import snapshot_download
 
 from api.utils.file_utils import get_project_base_directory
 from rag.utils import num_tokens_from_string
@@ -150,3 +149,24 @@ class ZhipuEmbed(Base):
         res = self.client.embeddings.create(input=text,
                                             model=self.model_name)
         return np.array(res.data[0].embedding), res.usage.total_tokens
+
+
+class OllamaEmbed(Base):
+    def __init__(self, key, model_name, **kwargs):
+        self.client = Client(host=kwargs["base_url"])
+        self.model_name = model_name
+
+    def encode(self, texts: list, batch_size=32):
+        arr = []
+        tks_num = 0
+        for txt in texts:
+            res = self.client.embeddings(prompt=txt,
+                                            model=self.model_name)
+            arr.append(res["embedding"])
+            tks_num += 128
+        return np.array(arr), tks_num
+
+    def encode_queries(self, text):
+        res = self.client.embeddings(prompt=text,
+                                            model=self.model_name)
+        return np.array(res["embedding"]), 128
diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py
index 7918648..799d252 100644
--- a/rag/svr/task_executor.py
+++ b/rag/svr/task_executor.py
@@ -23,7 +23,8 @@ import re
 import sys
 import traceback
 from functools import partial
-
+import signal
+from contextlib import contextmanager
 from rag.settings import database_logger
 from rag.settings import cron_logger, DOC_MAXIMUM_SIZE
 
@@ -97,8 +98,21 @@ def collect(comm, mod, tm):
     cron_logger.info("TOTAL:{}, To:{}".format(len(tasks), mtm))
     return tasks
 
+@contextmanager
+def timeout(time):
+    # Register a function to raise a TimeoutError on the signal.
+    signal.signal(signal.SIGALRM, raise_timeout)
+    # Schedule the signal to be sent after ``time``.
+    signal.alarm(time)
+    yield
+
+
+def raise_timeout(signum, frame):
+    raise TimeoutError
+
 
 def build(row):
+    from timeit import default_timer as timer
     if row["size"] > DOC_MAXIMUM_SIZE:
         set_progress(row["id"], prog=-1, msg="File size exceeds( <= %dMb )" %
                      (int(DOC_MAXIMUM_SIZE / 1024 / 1024)))
@@ -111,11 +125,14 @@ def build(row):
         row["to_page"])
     chunker = FACTORY[row["parser_id"].lower()]
     try:
-        cron_logger.info(
-            "Chunkking {}/{}".format(row["location"], row["name"]))
-        cks = chunker.chunk(row["name"], binary=MINIO.get(row["kb_id"], row["location"]), from_page=row["from_page"],
+        st = timer()
+        with timeout(30):
+            binary = MINIO.get(row["kb_id"], row["location"])
+        cks = chunker.chunk(row["name"], binary=binary, from_page=row["from_page"],
                             to_page=row["to_page"], lang=row["language"], callback=callback,
                             kb_id=row["kb_id"], parser_config=row["parser_config"], tenant_id=row["tenant_id"])
+        cron_logger.info(
+            "Chunkking({}) {}/{}".format(timer()-st, row["location"], row["name"]))
     except Exception as e:
         if re.search("(No such file|not found)", str(e)):
             callback(-1, "Can not find file <%s>" % row["name"])
-- 
GitLab