From f1f09df901c2729419489b68ca93b493426080e7 Mon Sep 17 00:00:00 2001
From: KevinHuSh <kevinhu.sh@gmail.com>
Date: Tue, 12 Mar 2024 11:57:08 +0800
Subject: [PATCH] add local llm implementation (#119)
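
Adds a "local" LLM factory: a new rag/llm/rpc_server.py serves a HuggingFace
causal LM over multiprocessing.connection, and new LocalLLM/LocalCV clients
plug it into the existing ChatModel/CvModel/EmbeddingModel factory maps.
"table" parsing is now dispatched in 1000-row tasks, upload limits are raised
to 128MB across Flask, nginx and rag settings, and /my_llms now returns the
llm_name field.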

---
 Dockerfile                               |  2 +-
 README.md                                |  6 +-
 api/apps/__init__.py                     |  2 +-
 api/apps/llm_app.py                      |  2 +-
 api/db/db_models.py                      |  2 +-
 api/db/services/knowledgebase_service.py |  1 +
 api/settings.py                          | 10 ++-
 deepdoc/parser/excel_parser.py           | 16 ++++-
 docker/nginx/nginx.conf                  |  2 +-
 rag/app/table.py                         | 17 +++--
 rag/llm/__init__.py                      |  9 ++-
 rag/llm/chat_model.py                    | 42 ++++++++++-
 rag/llm/cv_model.py                      |  8 +++
 rag/llm/rpc_server.py                    | 90 ++++++++++++++++++++++++
 rag/settings.py                          |  2 +-
 rag/svr/task_broker.py                   |  8 +++
 rag/svr/task_executor.py                 |  2 +-
 17 files changed, 196 insertions(+), 25 deletions(-)
 create mode 100644 rag/llm/rpc_server.py

diff --git a/Dockerfile b/Dockerfile
index b49434f..c174ccb 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,4 +1,4 @@
-FROM infiniflow/ragflow-base:v1.0
+FROM swr.cn-north-4.myhuaweicloud.com/infiniflow/ragflow-base:v1.0
 USER  root
 
 WORKDIR /ragflow
diff --git a/README.md b/README.md
index 6793e65..b59a4ad 100644
--- a/README.md
+++ b/README.md
@@ -21,7 +21,7 @@
   </a>
 </p>
 
-[RAGFLOW](http://ragflow.io) is a knowledge management platform built on custom-build document understanding engine and LLM, 
+[RagFlow](http://ragflow.io) is a knowledge management platform built on a custom-built document understanding engine and LLMs, 
 giving reasoned and well-founded answers to your questions. Clone this repository and you can deploy your own knowledge management 
 platform to empower your business with AI.
     
@@ -29,12 +29,12 @@ platform to empower your business with AI.
 <img src="https://github.com/infiniflow/ragflow/assets/12318111/b24a7a5f-4d1d-4a30-90b1-7b0ec558b79d" width="1000"/>
 </div>
 
-# Features
+# Key Features
 - **Custom-built document understanding engine.** Our deep-learning engine is built to analyze and search various types of documents across different domains.
   - For documents from different domains and purposes, the engine applies different analysis and search strategies.
   - Easily intervene in and adjust the data processing procedure when things go beyond expectation.
   - Multi-media document understanding is supported using OCR and multi-modal LLMs.
-- **State-of-the-art table structure and layout recognition.** Precisely extract and understand the document including table content. [README](./deepdoc/README.md)
+- **State-of-the-art table structure and layout recognition.** Precisely extract and understand documents, including table content. See [README](./deepdoc/README.md).
   - For PDF files, layouts and table structures, including their rows, columns and spans, are recognized.
   - Tables that cross page boundaries are put back together.
   - Table structure components are reconstructed into HTML tables.
diff --git a/api/apps/__init__.py b/api/apps/__init__.py
index a53663b..02f629a 100644
--- a/api/apps/__init__.py
+++ b/api/apps/__init__.py
@@ -52,7 +52,7 @@ app.errorhandler(Exception)(server_error_response)
 #app.config["LOGIN_DISABLED"] = True
 app.config["SESSION_PERMANENT"] = False
 app.config["SESSION_TYPE"] = "filesystem"
-app.config['MAX_CONTENT_LENGTH'] = 64 * 1024 * 1024
+app.config['MAX_CONTENT_LENGTH'] = 128 * 1024 * 1024
 
 Session(app)
 login_manager = LoginManager()
diff --git a/api/apps/llm_app.py b/api/apps/llm_app.py
index e0e213f..e89c855 100644
--- a/api/apps/llm_app.py
+++ b/api/apps/llm_app.py
@@ -85,7 +85,7 @@ def my_llms():
                 }
             res[o["llm_factory"]]["llm"].append({
                 "type": o["model_type"],
-                "name": o["model_name"],
+                "name": o["llm_name"],
                 "used_token": o["used_tokens"]
             })
         return get_json_result(data=res)
diff --git a/api/db/db_models.py b/api/db/db_models.py
index f28d37b..b02ea2c 100644
--- a/api/db/db_models.py
+++ b/api/db/db_models.py
@@ -520,7 +520,7 @@ class Task(DataBaseModel):
     begin_at = DateTimeField(null=True)
     process_duation = FloatField(default=0)
     progress = FloatField(default=0)
-    progress_msg = CharField(max_length=4096, null=True, help_text="process message", default="")
+    progress_msg = TextField(null=True, help_text="process message", default="")
 
 
 class Dialog(DataBaseModel):
diff --git a/api/db/services/knowledgebase_service.py b/api/db/services/knowledgebase_service.py
index 236b8d0..46fe4bc 100644
--- a/api/db/services/knowledgebase_service.py
+++ b/api/db/services/knowledgebase_service.py
@@ -47,6 +47,7 @@ class KnowledgebaseService(CommonService):
             Tenant.embd_id,
             cls.model.avatar,
             cls.model.name,
+            cls.model.language,
             cls.model.description,
             cls.model.permission,
             cls.model.doc_num,
diff --git a/api/settings.py b/api/settings.py
index 3200763..956e2f3 100644
--- a/api/settings.py
+++ b/api/settings.py
@@ -42,7 +42,7 @@ ERROR_REPORT = True
 ERROR_REPORT_WITH_PATH = False
 
 MAX_TIMESTAMP_INTERVAL = 60
-SESSION_VALID_PERIOD = 7 * 24 * 60 * 60 * 1000
+SESSION_VALID_PERIOD = 7 * 24 * 60 * 60
 
 REQUEST_TRY_TIMES = 3
 REQUEST_WAIT_SEC = 2
@@ -69,6 +69,12 @@ default_llm = {
         "image2text_model": "glm-4v",
         "asr_model": "",
     },
+    "local": {
+        "chat_model": "",
+        "embedding_model": "",
+        "image2text_model": "",
+        "asr_model": "",
+    }
 }
 LLM = get_base_config("user_default_llm", {})
 LLM_FACTORY = LLM.get("factory", "通义千问")
@@ -134,7 +140,7 @@ USE_AUTHENTICATION = False
 USE_DATA_AUTHENTICATION = False
 AUTOMATIC_AUTHORIZATION_OUTPUT_DATA = True
 USE_DEFAULT_TIMEOUT = False
-AUTHENTICATION_DEFAULT_TIMEOUT = 30 * 24 * 60 * 60 # s
+AUTHENTICATION_DEFAULT_TIMEOUT = 7 * 24 * 60 * 60 # s
 PRIVILEGE_COMMAND_WHITELIST = []
 CHECK_NODES_IDENTITY = False
 
diff --git a/deepdoc/parser/excel_parser.py b/deepdoc/parser/excel_parser.py
index 10f3b28..d2054f1 100644
--- a/deepdoc/parser/excel_parser.py
+++ b/deepdoc/parser/excel_parser.py
@@ -20,13 +20,27 @@ class HuExcelParser:
                 for i,c in enumerate(r):
                     if not c.value:continue
                     t = str(ti[i].value) if i < len(ti) else ""
-                    t += (":"  if t else "") + str(c.value)
+                    t += (":" if t else "") + str(c.value)
                     l.append(t)
                 l = "; ".join(l)
                 if sheetname.lower().find("sheet") <0: l += " ——"+sheetname
                 res.append(l)
         return res
 
+    @staticmethod
+    def row_number(fnm, binary):
+        if fnm.split(".")[-1].lower().find("xls") >= 0:
+            wb = load_workbook(BytesIO(binary))
+            total = 0
+            for sheetname in wb.sheetnames:
+                ws = wb[sheetname]
+                total += len(list(ws.rows))
+            return total
+
+        if fnm.split(".")[-1].lower() in ["csv", "txt"]:
+            txt = binary.decode("utf-8")
+            return len(txt.split("\n"))
+
 
 if __name__ == "__main__":
     psr = HuExcelParser()
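
A minimal sketch of driving the new row_number helper (the file name is
hypothetical; assumes the xlsx and csv/txt branches above):

    from deepdoc.parser.excel_parser import HuExcelParser

    with open("records.xlsx", "rb") as f:  # hypothetical workbook
        binary = f.read()
    # Workbooks are counted sheet by sheet; csv/txt fall back to
    # counting newline-separated lines.
    print(HuExcelParser.row_number("records.xlsx", binary))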
diff --git a/docker/nginx/nginx.conf b/docker/nginx/nginx.conf
index 05c5029..8933a8d 100644
--- a/docker/nginx/nginx.conf
+++ b/docker/nginx/nginx.conf
@@ -26,7 +26,7 @@ http {
     keepalive_timeout  65;
 
     #gzip  on;
-    client_max_body_size 82M;
+    client_max_body_size 128M;
 
     include /etc/nginx/conf.d/ragflow.conf;
 }
diff --git a/rag/app/table.py b/rag/app/table.py
index 68b2d33..4cf1c1c 100644
--- a/rag/app/table.py
+++ b/rag/app/table.py
@@ -25,7 +25,7 @@ from deepdoc.parser import ExcelParser
 
 
 class Excel(ExcelParser):
-    def __call__(self, fnm, binary=None, callback=None):
+    def __call__(self, fnm, binary=None, from_page=0, to_page=10000000000, callback=None):
         if not binary:
             wb = load_workbook(fnm)
         else:
@@ -35,6 +35,7 @@ class Excel(ExcelParser):
             total += len(list(wb[sheetname].rows))
 
         res, fails, done = [], [], 0
+        rn = 0
         for sheetname in wb.sheetnames:
             ws = wb[sheetname]
             rows = list(ws.rows)
@@ -46,6 +47,9 @@ class Excel(ExcelParser):
                     rows[0]) if i not in missed]
             data = []
             for i, r in enumerate(rows[1:]):
+                rn += 1
+                if rn - 1 < from_page: continue
+                if rn - 1 >= to_page: break
                 row = [
                     cell.value for ii,
                     cell in enumerate(r) if ii not in missed]
@@ -111,7 +115,7 @@ def column_data_type(arr):
     return arr, ty
 
 
-def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
+def chunk(filename, binary=None, from_page=0, to_page=10000000000, lang="Chinese", callback=None, **kwargs):
     """
         Excel and csv(txt) format files are supported.
         For csv or txt file, the delimiter between columns is TAB.
@@ -147,16 +151,15 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
         headers = lines[0].split(kwargs.get("delimiter", "\t"))
         rows = []
         for i, line in enumerate(lines[1:]):
+            if i < from_page: continue
+            if i >= to_page: break
             row = [l for l in line.split(kwargs.get("delimiter", "\t"))]
             if len(row) != len(headers):
                 fails.append(str(i))
                 continue
             rows.append(row)
-            if len(rows) % 999 == 0:
-                callback(len(rows) * 0.6 / len(lines), ("Extract records: {}".format(len(rows)) + (
-                    f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
 
-        callback(0.6, ("Extract records: {}".format(len(rows)) + (
+        callback(0.3, ("Extract records: {}~{}".format(from_page, min(len(lines), to_page)) + (
             f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
 
         dfs = [pd.DataFrame(np.array(rows), columns=headers)]
@@ -209,7 +212,7 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
 
         KnowledgebaseService.update_parser_config(
             kwargs["kb_id"], {"field_map": {k: v for k, v in clmns_map}})
-    callback(0.6, "")
+    callback(0.35, "")
 
     return res
 
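A sketch of how the new from_page/to_page window is meant to be driven; it
mirrors the 1000-row task split added in rag/svr/task_broker.py below (the
kb_id and file name are hypothetical, and a real call needs the DB services
behind KnowledgebaseService):

    from rag.app import table

    def progress(prog=None, msg=""):
        print(prog, msg)

    with open("records.xlsx", "rb") as f:  # hypothetical workbook
        chunks = table.chunk("records.xlsx", binary=f.read(),
                             from_page=0, to_page=1000,
                             callback=progress, kb_id="kb0")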
diff --git a/rag/llm/__init__.py b/rag/llm/__init__.py
index 462313d..c2c2726 100644
--- a/rag/llm/__init__.py
+++ b/rag/llm/__init__.py
@@ -19,22 +19,25 @@ from .cv_model import *
 
 
 EmbeddingModel = {
-    "Infiniflow": HuEmbedding,
+    "local": HuEmbedding,
     "OpenAI": OpenAIEmbed,
     "通义千问": HuEmbedding, #QWenEmbed,
+    "智谱AI": ZhipuEmbed
 }
 
 
 CvModel = {
     "OpenAI": GptV4,
-    "Infiniflow": GptV4,
+    "local": LocalCV,
     "通义千问": QWenCV,
+    "智谱AI": Zhipu4V
 }
 
 
 ChatModel = {
     "OpenAI": GptTurbo,
-    "Infiniflow": GptTurbo,
+    "智谱AI": ZhipuChat,
     "通义千问": QWenChat,
+    "local": LocalLLM
 }
 
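A sketch of how these maps are consumed: callers index by factory name and
construct the model class. For "local" the API key is ignored, the model_name
is illustrative, and the RPC server below must already be listening:

    from rag.llm import ChatModel

    chat_mdl = ChatModel["local"]("", model_name="glm-3-turbo")
    ans, used_tokens = chat_mdl.chat(
        "You are a helpful assistant.",
        [{"role": "user", "content": "hi"}],
        {"temperature": 0.7},
    )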
diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py
index 2389561..88a682f 100644
--- a/rag/llm/chat_model.py
+++ b/rag/llm/chat_model.py
@@ -20,6 +20,7 @@ from openai import OpenAI
 import openai
 
 from rag.nlp import is_english
+from rag.utils import num_tokens_from_string
 
 
 class Base(ABC):
@@ -86,7 +87,6 @@ class ZhipuChat(Base):
         self.model_name = model_name
 
     def chat(self, system, history, gen_conf):
-        from http import HTTPStatus
         if system: history.insert(0, {"role": "system", "content": system})
         try:
             response = self.client.chat.completions.create(
@@ -100,4 +100,42 @@ class ZhipuChat(Base):
                     [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
             return ans, response.usage.completion_tokens
         except Exception as e:
-            return "**ERROR**: " + str(e), 0
\ No newline at end of file
+            return "**ERROR**: " + str(e), 0
+
+
+class LocalLLM(Base):
+    class RPCProxy:
+        def __init__(self, host, port):
+            self.host = host
+            self.port = int(port)
+            self.__conn()
+
+        def __conn(self):
+            from multiprocessing.connection import Client
+            self._connection = Client((self.host, self.port), authkey=b'infiniflow-token4kevinhu')
+
+        def __getattr__(self, name):
+            import pickle
+
+            def do_rpc(*args, **kwargs):
+                # Retry up to 3 times, reconnecting after each failed round trip.
+                for _ in range(3):
+                    try:
+                        self._connection.send(pickle.dumps((name, args, kwargs)))
+                        return pickle.loads(self._connection.recv())
+                    except Exception:
+                        self.__conn()
+                raise Exception("RPC connection lost!")
+
+            return do_rpc
+
+    def __init__(self, key, model_name="glm-3-turbo"):
+        # key and model_name are accepted for interface parity with the
+        # other factories but are unused: the proxy targets the local RPC server.
+        self.client = LocalLLM.RPCProxy("127.0.0.1", 7860)
+
+    def chat(self, system, history, gen_conf):
+        if system: history.insert(0, {"role": "system", "content": system})
+        try:
+            ans = self.client.chat(
+                history,
+                gen_conf
+            )
+            return ans, num_tokens_from_string(ans)
+        except Exception as e:
+            return "**ERROR**: " + str(e), 0
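
The proxy above speaks a tiny pickle protocol: each call ships a pickled
(method_name, args, kwargs) tuple and reads back one pickled reply (the
result, or the exception the server caught). A hand-rolled round trip,
assuming rpc_server.py below is listening on 127.0.0.1:7860:

    import pickle
    from multiprocessing.connection import Client

    conn = Client(("127.0.0.1", 7860), authkey=b'infiniflow-token4kevinhu')
    history = [{"role": "user", "content": "hi"}]
    conn.send(pickle.dumps(("chat", (history, {"max_new_tokens": 64}), {})))
    print(pickle.loads(conn.recv()))  # the generated answer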
diff --git a/rag/llm/cv_model.py b/rag/llm/cv_model.py
index 02298c6..6e139c7 100644
--- a/rag/llm/cv_model.py
+++ b/rag/llm/cv_model.py
@@ -138,3 +138,11 @@ class Zhipu4V(Base):
             max_tokens=max_tokens,
         )
         return res.choices[0].message.content.strip(), res.usage.total_tokens
+
+
+class LocalCV(Base):
+    def __init__(self, key, model_name="glm-4v", lang="Chinese"):
+        pass
+
+    def describe(self, image, max_tokens=1024):
+        return "", 0
diff --git a/rag/llm/rpc_server.py b/rag/llm/rpc_server.py
new file mode 100644
index 0000000..7cfd0e8
--- /dev/null
+++ b/rag/llm/rpc_server.py
@@ -0,0 +1,90 @@
+import argparse
+import pickle
+import random
+import time
+from multiprocessing.connection import Listener
+from threading import Thread
+import torch
+
+
+class RPCHandler:
+    def __init__(self):
+        self._functions = {}
+
+    def register_function(self, func):
+        self._functions[func.__name__] = func
+
+    def handle_connection(self, connection):
+        try:
+            while True:
+                # Receive a message
+                func_name, args, kwargs = pickle.loads(connection.recv())
+                # Run the RPC and send a response
+                try:
+                    r = self._functions[func_name](*args, **kwargs)
+                    connection.send(pickle.dumps(r))
+                except Exception as e:
+                    connection.send(pickle.dumps(e))
+        except EOFError:
+            pass
+
+
+def rpc_server(hdlr, address, authkey):
+    sock = Listener(address, authkey=authkey)
+    while True:
+        try:
+            client = sock.accept()
+            t = Thread(target=hdlr.handle_connection, args=(client,))
+            t.daemon = True
+            t.start()
+        except Exception as e:
+            print("[EXCEPTION]:", str(e))
+
+
+models = []
+tokenizer = None
+
+def chat(messages, gen_conf):
+    global tokenizer
+    model = Model()
+    # Flatten the chat history into a plain prompt: one "Role: content" line
+    # per message, ending with an open "Assistant: " for the model to complete.
+    roles = {"system": "System", "user": "User", "assistant": "Assistant"}
+    line = ["{}: {}".format(roles[m["role"].lower()], m["content"]) for m in messages]
+    line = "\n".join(line) + "\nAssistant: "
+    tokens = tokenizer([line], return_tensors='pt')
+    # Move the input tensors onto the same device as the chosen model replica.
+    tokens = {k: tokens[k].to(model.device) if isinstance(tokens[k], torch.Tensor) else tokens[k] for k in
+              tokens.keys()}
+    res = [tokenizer.decode(t) for t in model.generate(**tokens, **gen_conf)][0]
+    return res.split("Assistant: ")[-1]
+
+
+def Model():
+    # Pick one of the loaded model replicas at random to spread requests.
+    global models
+    random.seed(time.time())
+    return random.choice(models)
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model_name", type=str, help="Model name")
+    parser.add_argument("--port", default=7860, type=int, help="RPC serving port")
+    args = parser.parse_args()
+
+    handler = RPCHandler()
+    handler.register_function(chat)
+
+    from transformers import AutoModelForCausalLM, AutoTokenizer
+    from transformers.generation.utils import GenerationConfig
+
+    models = []
+    # Load two replicas so concurrent connections can be spread across them.
+    for _ in range(2):
+        m = AutoModelForCausalLM.from_pretrained(args.model_name,
+                                                 device_map="auto",
+                                                 torch_dtype='auto',
+                                                 trust_remote_code=True)
+        m.generation_config = GenerationConfig.from_pretrained(args.model_name)
+        m.generation_config.pad_token_id = m.generation_config.eos_token_id
+        models.append(m)
+    tokenizer = AutoTokenizer.from_pretrained(args.model_name, use_fast=False,
+                                              trust_remote_code=True)
+
+    # Run the server
+    rpc_server(handler, ('0.0.0.0', args.port), authkey=b'infiniflow-token4kevinhu')
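
The server is launched directly, e.g. "python rag/llm/rpc_server.py
--model_name <hf-model-id> --port 7860". Its chat() flattens the history
positionally: one "Role: content" line per message, then an open
"Assistant: " marker whose completion is returned. A worked example of the
formatting only:

    messages = [{"role": "system", "content": "Be terse."},
                {"role": "user", "content": "hi"}]
    roles = {"system": "System", "user": "User", "assistant": "Assistant"}
    line = "\n".join("{}: {}".format(roles[m["role"].lower()], m["content"])
                     for m in messages) + "\nAssistant: "
    assert line == "System: Be terse.\nUser: hi\nAssistant: "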
diff --git a/rag/settings.py b/rag/settings.py
index 2906ea3..7c2257a 100644
--- a/rag/settings.py
+++ b/rag/settings.py
@@ -25,7 +25,7 @@ SUBPROCESS_STD_LOG_NAME = "std.log"
 
 ES = get_base_config("es", {})
 MINIO = decrypt_database_config(name="minio")
-DOC_MAXIMUM_SIZE = 64 * 1024 * 1024
+DOC_MAXIMUM_SIZE = 128 * 1024 * 1024
 
 # Logger
 LoggerFactory.set_directory(os.path.join(get_project_base_directory(), "logs", "rag"))
diff --git a/rag/svr/task_broker.py b/rag/svr/task_broker.py
index 38a2b44..e79ac38 100644
--- a/rag/svr/task_broker.py
+++ b/rag/svr/task_broker.py
@@ -22,6 +22,7 @@ from api.db.db_models import Task
 from api.db.db_utils import bulk_insert_into_db
 from api.db.services.task_service import TaskService
 from deepdoc.parser import PdfParser
+from deepdoc.parser.excel_parser import HuExcelParser
 from rag.settings import cron_logger
 from rag.utils import MINIO
 from rag.utils import findMaxTm
@@ -88,6 +89,13 @@ def dispatch():
                     task["from_page"] = p
                     task["to_page"] = min(p + 5, e)
                     tsks.append(task)
+        elif r["parser_id"] == "table":
+            rn = HuExcelParser.row_number(r["name"], MINIO.get(r["kb_id"], r["location"]))
+            for i in range(0, rn, 1000):
+                task = new_task()
+                task["from_page"] = i
+                task["to_page"] = min(i + 1000, rn)
+                tsks.append(task)
         else:
             tsks.append(new_task())
 
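The 1000-row windows above are easy to check by hand; a 2,500-row sheet, for
instance, yields three tasks:

    rn = 2500  # hypothetical row count
    windows = [(i, min(i + 1000, rn)) for i in range(0, rn, 1000)]
    assert windows == [(0, 1000), (1000, 2000), (2000, 2500)]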
diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py
index 9ba17ef..9765957 100644
--- a/rag/svr/task_executor.py
+++ b/rag/svr/task_executor.py
@@ -184,7 +184,7 @@ def embedding(docs, mdl, parser_config={}, callback=None):
         if len(cnts_) == 0: cnts_ = vts
         else: cnts_ = np.concatenate((cnts_, vts), axis=0)
         tk_count += c
-        callback(msg="")
+        callback(prog=0.7 + 0.2 * (i + 1) / len(cnts), msg="")
     cnts = cnts_
 
     title_w = float(parser_config.get("filename_embd_weight", 0.1))
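
The new progress expression walks callbacks linearly through the 0.7-0.9
band as embedding batches finish; with 10 batches, for example:

    n = 10  # hypothetical batch count
    progs = [0.7 + 0.2 * (i + 1) / n for i in range(n)]
    assert abs(progs[0] - 0.72) < 1e-9 and abs(progs[-1] - 0.9) < 1e-9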
-- 
GitLab