From 453c29170f305cc840b92d9065b94fbe15421ef7 Mon Sep 17 00:00:00 2001
From: KevinHuSh <kevinhu.sh@gmail.com>
Date: Thu, 18 Apr 2024 09:37:23 +0800
Subject: [PATCH] make sure the models will not be load twice (#422)

### What problem does this PR solve?

#381
### Type of change

- [x] Refactoring
---
 api/apps/api_app.py                          | 9 ++++-----
 api/db/db_models.py                          | 2 +-
 deepdoc/parser/pdf_parser.py                 | 4 +++-
 deepdoc/vision/layout_recognizer.py          | 4 +++-
 deepdoc/vision/ocr.py                        | 4 +++-
 deepdoc/vision/recognizer.py                 | 4 +++-
 deepdoc/vision/table_structure_recognizer.py | 4 +++-
 rag/llm/embedding_model.py                   | 7 ++++++-
 8 files changed, 26 insertions(+), 12 deletions(-)

diff --git a/api/apps/api_app.py b/api/apps/api_app.py
index f294272..cc6f646 100644
--- a/api/apps/api_app.py
+++ b/api/apps/api_app.py
@@ -105,8 +105,8 @@ def stats():
         res = {
             "pv": [(o["dt"], o["pv"]) for o in objs],
             "uv": [(o["dt"], o["uv"]) for o in objs],
-            "speed": [(o["dt"], o["tokens"]/o["duration"]) for o in objs],
-            "tokens": [(o["dt"], o["tokens"]/1000.) for o in objs],
+            "speed": [(o["dt"], float(o["tokens"])/float(o["duration"])) for o in objs],
+            "tokens": [(o["dt"], float(o["tokens"])/1000.) for o in objs],
             "round": [(o["dt"], o["round"]) for o in objs],
             "thumb_up": [(o["dt"], o["thumb_up"]) for o in objs]
         }
@@ -115,8 +115,7 @@ def stats():
         return server_error_response(e)
 
 
-@manager.route('/new_conversation', methods=['POST'])
-@validate_request("user_id")
+@manager.route('/new_conversation', methods=['GET'])
 def set_conversation():
     token = request.headers.get('Authorization').split()[1]
     objs = APIToken.query(token=token)
@@ -131,7 +130,7 @@ def set_conversation():
         conv = {
             "id": get_uuid(),
             "dialog_id": dia.id,
-            "user_id": req["user_id"],
+            "user_id": request.args.get("user_id", ""),
             "message": [{"role": "assistant", "content": dia.prompt_config["prologue"]}]
         }
         API4ConversationService.save(**conv)
diff --git a/api/db/db_models.py b/api/db/db_models.py
index 27ad80f..e6f2d28 100644
--- a/api/db/db_models.py
+++ b/api/db/db_models.py
@@ -629,7 +629,7 @@ class Document(DataBaseModel):
         max_length=128,
         null=False,
         default="local",
-        help_text="where dose this document from")
+        help_text="where dose this document come from")
     type = CharField(max_length=32, null=False, help_text="file extension")
     created_by = CharField(
         max_length=32,
diff --git a/deepdoc/parser/pdf_parser.py b/deepdoc/parser/pdf_parser.py
index dfd3756..6c33245 100644
--- a/deepdoc/parser/pdf_parser.py
+++ b/deepdoc/parser/pdf_parser.py
@@ -43,7 +43,9 @@ class HuParser:
                 model_dir, "updown_concat_xgb.model"))
         except Exception as e:
             model_dir = snapshot_download(
-                repo_id="InfiniFlow/text_concat_xgb_v1.0")
+                repo_id="InfiniFlow/text_concat_xgb_v1.0",
+                local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"),
+                local_dir_use_symlinks=False)
             self.updown_cnt_mdl.load_model(os.path.join(
                 model_dir, "updown_concat_xgb.model"))
 
diff --git a/deepdoc/vision/layout_recognizer.py b/deepdoc/vision/layout_recognizer.py
index 917ee6e..58ddcdb 100644
--- a/deepdoc/vision/layout_recognizer.py
+++ b/deepdoc/vision/layout_recognizer.py
@@ -43,7 +43,9 @@ class LayoutRecognizer(Recognizer):
                     "rag/res/deepdoc")
             super().__init__(self.labels, domain, model_dir)
         except Exception as e:
-            model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc")
+            model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc",
+                                          local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"),
+                                          local_dir_use_symlinks=False)
             super().__init__(self.labels, domain, model_dir)
 
         self.garbage_layouts = ["footer", "header", "reference"]
diff --git a/deepdoc/vision/ocr.py b/deepdoc/vision/ocr.py
index b55024e..d602da0 100644
--- a/deepdoc/vision/ocr.py
+++ b/deepdoc/vision/ocr.py
@@ -486,7 +486,9 @@ class OCR(object):
                 self.text_detector = TextDetector(model_dir)
                 self.text_recognizer = TextRecognizer(model_dir)
             except Exception as e:
-                model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc")
+                model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc",
+                                              local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"),
+                                              local_dir_use_symlinks=False)
                 self.text_detector = TextDetector(model_dir)
                 self.text_recognizer = TextRecognizer(model_dir)
 
diff --git a/deepdoc/vision/recognizer.py b/deepdoc/vision/recognizer.py
index 67e096e..1ca7c44 100644
--- a/deepdoc/vision/recognizer.py
+++ b/deepdoc/vision/recognizer.py
@@ -41,7 +41,9 @@ class Recognizer(object):
                         "rag/res/deepdoc")
             model_file_path = os.path.join(model_dir, task_name + ".onnx")
             if not os.path.exists(model_file_path):
-                model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc")
+                model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc",
+                                              local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"),
+                                              local_dir_use_symlinks=False)
                 model_file_path = os.path.join(model_dir, task_name + ".onnx")
         else:
             model_file_path = os.path.join(model_dir, task_name + ".onnx")
diff --git a/deepdoc/vision/table_structure_recognizer.py b/deepdoc/vision/table_structure_recognizer.py
index 6779137..548eb62 100644
--- a/deepdoc/vision/table_structure_recognizer.py
+++ b/deepdoc/vision/table_structure_recognizer.py
@@ -39,7 +39,9 @@ class TableStructureRecognizer(Recognizer):
                     get_project_base_directory(),
                     "rag/res/deepdoc"))
         except Exception as e:
-            super().__init__(self.labels, "tsr", snapshot_download(repo_id="InfiniFlow/deepdoc"))
+            super().__init__(self.labels, "tsr", snapshot_download(repo_id="InfiniFlow/deepdoc",
+                                              local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"),
+                                              local_dir_use_symlinks=False))
 
     def __call__(self, images, thr=0.2):
         tbls = super().__call__(images, thr)
diff --git a/rag/llm/embedding_model.py b/rag/llm/embedding_model.py
index e6ff0b1..e6e18fb 100644
--- a/rag/llm/embedding_model.py
+++ b/rag/llm/embedding_model.py
@@ -14,6 +14,8 @@
 #  limitations under the License.
 #
 from typing import Optional
+
+from huggingface_hub import snapshot_download
 from zhipuai import ZhipuAI
 import os
 from abc import ABC
@@ -35,7 +37,10 @@ try:
         query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
         use_fp16=torch.cuda.is_available())
 except Exception as e:
-    flag_model = FlagModel("BAAI/bge-large-zh-v1.5",
+    model_dir = snapshot_download(repo_id="BAAI/bge-large-zh-v1.5",
+                                  local_dir=os.path.join(get_project_base_directory(), "rag/res/bge-large-zh-v1.5"),
+                                  local_dir_use_symlinks=False)
+    flag_model = FlagModel(model_dir,
                            query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
                            use_fp16=torch.cuda.is_available())
 
-- 
GitLab