From 453c29170f305cc840b92d9065b94fbe15421ef7 Mon Sep 17 00:00:00 2001 From: KevinHuSh <kevinhu.sh@gmail.com> Date: Thu, 18 Apr 2024 09:37:23 +0800 Subject: [PATCH] make sure the models will not be load twice (#422) ### What problem does this PR solve? #381 ### Type of change - [x] Refactoring --- api/apps/api_app.py | 9 ++++----- api/db/db_models.py | 2 +- deepdoc/parser/pdf_parser.py | 4 +++- deepdoc/vision/layout_recognizer.py | 4 +++- deepdoc/vision/ocr.py | 4 +++- deepdoc/vision/recognizer.py | 4 +++- deepdoc/vision/table_structure_recognizer.py | 4 +++- rag/llm/embedding_model.py | 7 ++++++- 8 files changed, 26 insertions(+), 12 deletions(-) diff --git a/api/apps/api_app.py b/api/apps/api_app.py index f294272..cc6f646 100644 --- a/api/apps/api_app.py +++ b/api/apps/api_app.py @@ -105,8 +105,8 @@ def stats(): res = { "pv": [(o["dt"], o["pv"]) for o in objs], "uv": [(o["dt"], o["uv"]) for o in objs], - "speed": [(o["dt"], o["tokens"]/o["duration"]) for o in objs], - "tokens": [(o["dt"], o["tokens"]/1000.) for o in objs], + "speed": [(o["dt"], float(o["tokens"])/float(o["duration"])) for o in objs], + "tokens": [(o["dt"], float(o["tokens"])/1000.) for o in objs], "round": [(o["dt"], o["round"]) for o in objs], "thumb_up": [(o["dt"], o["thumb_up"]) for o in objs] } @@ -115,8 +115,7 @@ def stats(): return server_error_response(e) -@manager.route('/new_conversation', methods=['POST']) -@validate_request("user_id") +@manager.route('/new_conversation', methods=['GET']) def set_conversation(): token = request.headers.get('Authorization').split()[1] objs = APIToken.query(token=token) @@ -131,7 +130,7 @@ def set_conversation(): conv = { "id": get_uuid(), "dialog_id": dia.id, - "user_id": req["user_id"], + "user_id": request.args.get("user_id", ""), "message": [{"role": "assistant", "content": dia.prompt_config["prologue"]}] } API4ConversationService.save(**conv) diff --git a/api/db/db_models.py b/api/db/db_models.py index 27ad80f..e6f2d28 100644 --- a/api/db/db_models.py +++ b/api/db/db_models.py @@ -629,7 +629,7 @@ class Document(DataBaseModel): max_length=128, null=False, default="local", - help_text="where dose this document from") + help_text="where dose this document come from") type = CharField(max_length=32, null=False, help_text="file extension") created_by = CharField( max_length=32, diff --git a/deepdoc/parser/pdf_parser.py b/deepdoc/parser/pdf_parser.py index dfd3756..6c33245 100644 --- a/deepdoc/parser/pdf_parser.py +++ b/deepdoc/parser/pdf_parser.py @@ -43,7 +43,9 @@ class HuParser: model_dir, "updown_concat_xgb.model")) except Exception as e: model_dir = snapshot_download( - repo_id="InfiniFlow/text_concat_xgb_v1.0") + repo_id="InfiniFlow/text_concat_xgb_v1.0", + local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"), + local_dir_use_symlinks=False) self.updown_cnt_mdl.load_model(os.path.join( model_dir, "updown_concat_xgb.model")) diff --git a/deepdoc/vision/layout_recognizer.py b/deepdoc/vision/layout_recognizer.py index 917ee6e..58ddcdb 100644 --- a/deepdoc/vision/layout_recognizer.py +++ b/deepdoc/vision/layout_recognizer.py @@ -43,7 +43,9 @@ class LayoutRecognizer(Recognizer): "rag/res/deepdoc") super().__init__(self.labels, domain, model_dir) except Exception as e: - model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc") + model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc", + local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"), + local_dir_use_symlinks=False) super().__init__(self.labels, domain, model_dir) self.garbage_layouts = ["footer", "header", "reference"] diff --git a/deepdoc/vision/ocr.py b/deepdoc/vision/ocr.py index b55024e..d602da0 100644 --- a/deepdoc/vision/ocr.py +++ b/deepdoc/vision/ocr.py @@ -486,7 +486,9 @@ class OCR(object): self.text_detector = TextDetector(model_dir) self.text_recognizer = TextRecognizer(model_dir) except Exception as e: - model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc") + model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc", + local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"), + local_dir_use_symlinks=False) self.text_detector = TextDetector(model_dir) self.text_recognizer = TextRecognizer(model_dir) diff --git a/deepdoc/vision/recognizer.py b/deepdoc/vision/recognizer.py index 67e096e..1ca7c44 100644 --- a/deepdoc/vision/recognizer.py +++ b/deepdoc/vision/recognizer.py @@ -41,7 +41,9 @@ class Recognizer(object): "rag/res/deepdoc") model_file_path = os.path.join(model_dir, task_name + ".onnx") if not os.path.exists(model_file_path): - model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc") + model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc", + local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"), + local_dir_use_symlinks=False) model_file_path = os.path.join(model_dir, task_name + ".onnx") else: model_file_path = os.path.join(model_dir, task_name + ".onnx") diff --git a/deepdoc/vision/table_structure_recognizer.py b/deepdoc/vision/table_structure_recognizer.py index 6779137..548eb62 100644 --- a/deepdoc/vision/table_structure_recognizer.py +++ b/deepdoc/vision/table_structure_recognizer.py @@ -39,7 +39,9 @@ class TableStructureRecognizer(Recognizer): get_project_base_directory(), "rag/res/deepdoc")) except Exception as e: - super().__init__(self.labels, "tsr", snapshot_download(repo_id="InfiniFlow/deepdoc")) + super().__init__(self.labels, "tsr", snapshot_download(repo_id="InfiniFlow/deepdoc", + local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"), + local_dir_use_symlinks=False)) def __call__(self, images, thr=0.2): tbls = super().__call__(images, thr) diff --git a/rag/llm/embedding_model.py b/rag/llm/embedding_model.py index e6ff0b1..e6e18fb 100644 --- a/rag/llm/embedding_model.py +++ b/rag/llm/embedding_model.py @@ -14,6 +14,8 @@ # limitations under the License. # from typing import Optional + +from huggingface_hub import snapshot_download from zhipuai import ZhipuAI import os from abc import ABC @@ -35,7 +37,10 @@ try: query_instruction_for_retrieval="为这个句ĺ生ćčˇ¨ç¤şä»Ąç”¨äşŽćŁ€ç´˘ç›¸ĺ…łć–‡ç« ďĽš", use_fp16=torch.cuda.is_available()) except Exception as e: - flag_model = FlagModel("BAAI/bge-large-zh-v1.5", + model_dir = snapshot_download(repo_id="BAAI/bge-large-zh-v1.5", + local_dir=os.path.join(get_project_base_directory(), "rag/res/bge-large-zh-v1.5"), + local_dir_use_symlinks=False) + flag_model = FlagModel(model_dir, query_instruction_for_retrieval="为这个句ĺ生ćčˇ¨ç¤şä»Ąç”¨äşŽćŁ€ç´˘ç›¸ĺ…łć–‡ç« ďĽš", use_fp16=torch.cuda.is_available()) -- GitLab