From 4568a4b2cbecde585e0e4ba9b2a1f78f5a10d2d9 Mon Sep 17 00:00:00 2001
From: KevinHuSh <kevinhu.sh@gmail.com>
Date: Tue, 27 Feb 2024 14:57:34 +0800
Subject: [PATCH] refine admin initialization (#75)

---
 api/apps/chunk_app.py                        |  4 +-
 api/apps/conversation_app.py                 |  4 +-
 api/db/init_data.py                          | 46 ++++++++++++++++++--
 api/settings.py                              |  6 ++-
 deepdoc/parser/pdf_parser.py                 |  2 +-
 deepdoc/vision/layout_recognizer.py          |  2 +-
 deepdoc/vision/postprocess.py                |  5 +--
 deepdoc/vision/recognizer.py                 | 12 +++++
 deepdoc/vision/t_recognizer.py               |  4 +-
 deepdoc/vision/table_structure_recognizer.py | 10 ++---
 rag/llm/chat_model.py                        | 19 ++++----
 rag/nlp/__init__.py                          |  5 +--
 rag/nlp/search.py                            |  6 ++-
 13 files changed, 91 insertions(+), 34 deletions(-)

diff --git a/api/apps/chunk_app.py b/api/apps/chunk_app.py
index 11d9c7b..1ead645 100644
--- a/api/apps/chunk_app.py
+++ b/api/apps/chunk_app.py
@@ -20,7 +20,7 @@ from flask_login import login_required, current_user
 from elasticsearch_dsl import Q
 
 from rag.app.qa import rmPrefix, beAdoc
-from rag.nlp import search, huqie, retrievaler
+from rag.nlp import search, huqie
 from rag.utils import ELASTICSEARCH, rmSpace
 from api.db import LLMType, ParserType
 from api.db.services.knowledgebase_service import KnowledgebaseService
@@ -28,7 +28,7 @@ from api.db.services.llm_service import TenantLLMService
 from api.db.services.user_service import UserTenantService
 from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
 from api.db.services.document_service import DocumentService
-from api.settings import RetCode
+from api.settings import RetCode, retrievaler
 from api.utils.api_utils import get_json_result
 import hashlib
 import re
diff --git a/api/apps/conversation_app.py b/api/apps/conversation_app.py
index fe1047b..13e02aa 100644
--- a/api/apps/conversation_app.py
+++ b/api/apps/conversation_app.py
@@ -21,13 +21,11 @@ from api.db.services.dialog_service import DialogService, ConversationService
 from api.db import LLMType
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.llm_service import LLMService, LLMBundle
-from api.settings import access_logger, stat_logger
+from api.settings import access_logger, stat_logger, retrievaler
 from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
 from api.utils import get_uuid
 from api.utils.api_utils import get_json_result
 from rag.app.resume import forbidden_select_fields4resume
-from rag.llm import ChatModel
-from rag.nlp import retrievaler
 from rag.nlp.search import index_name
 from rag.utils import num_tokens_from_string, encoder, rmSpace
 
diff --git a/api/db/init_data.py b/api/db/init_data.py
index 5e4a812..ee91fd8 100644
--- a/api/db/init_data.py
+++ b/api/db/init_data.py
@@ -16,10 +16,12 @@
 import time
 import uuid
 
-from api.db import LLMType
+from api.db import LLMType, UserTenantRole
 from api.db.db_models import init_database_tables as init_web_db
 from api.db.services import UserService
-from api.db.services.llm_service import LLMFactoriesService, LLMService
+from api.db.services.llm_service import LLMFactoriesService, LLMService, TenantLLMService, LLMBundle
+from api.db.services.user_service import TenantService, UserTenantService
+from api.settings import CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS, LLM_FACTORY, API_KEY
 
 
 def init_superuser():
@@ -32,8 +34,44 @@ def init_superuser():
         "creator": "system",
         "status": "1",
     }
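+    # Give the admin user its own tenant, wired to the default chat, embedding, ASR and image-to-text models and parsers.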
+    tenant = {
+        "id": user_info["id"],
+        "name": user_info["nickname"] + "'s Kingdom",
+        "llm_id": CHAT_MDL,
+        "embd_id": EMBEDDING_MDL,
+        "asr_id": ASR_MDL,
+        "parser_ids": PARSERS,
+        "img2txt_id": IMAGE2TEXT_MDL
+    }
+    usr_tenant = {
+        "tenant_id": user_info["id"],
+        "user_id": user_info["id"],
+        "invited_by": user_info["id"],
+        "role": UserTenantRole.OWNER
+    }
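+    # Pre-register every model of the default LLM factory for the admin tenant, all sharing the configured API_KEY.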
+    tenant_llm = []
+    for llm in LLMService.query(fid=LLM_FACTORY):
+        tenant_llm.append(
+            {"tenant_id": user_info["id"], "llm_factory": LLM_FACTORY, "llm_name": llm.llm_name, "model_type": llm.model_type,
+             "api_key": API_KEY})
+
+    if not UserService.save(**user_info):
+        print("【ERROR】can't init admin.")
+        return
+    TenantService.save(**tenant)
+    UserTenantService.save(**usr_tenant)
+    TenantLLMService.insert_many(tenant_llm)
-    UserService.save(**user_info)
 
+    # Smoke-test the default chat and embedding models so a broken LLM_FACTORY or API_KEY is reported at startup.
+    chat_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["llm_id"])
+    msg = chat_mdl.chat(system="", history=[{"role": "user", "content": "Hello!"}], gen_conf={})
+    if msg.find("ERROR: ") == 0:
+        print("【ERROR】: '{}' doesn't work. {}".format(tenant["llm_id"], msg))
+    embd_mdl = LLMBundle(tenant["id"], LLMType.EMBEDDING, tenant["embd_id"])
+    v, c = embd_mdl.encode(["Hello!"])
+    if c == 0:
+        print("【ERROR】: '{}' doesn't work...".format(tenant["embd_id"]))
+
 
 def init_llm_factory():
     factory_infos = [{
@@ -171,10 +209,10 @@ def init_llm_factory():
 
 def init_web_data():
     start_time = time.time()
-    if not UserService.get_all().count():
-        init_superuser()
 
     if not LLMService.get_all().count():init_llm_factory()
+    if not UserService.get_all().count():
+        init_superuser()
 
     print("init web data success:{}".format(time.time() - start_time))
 
diff --git a/api/settings.py b/api/settings.py
index a5882a5..08f7dc7 100644
--- a/api/settings.py
+++ b/api/settings.py
@@ -21,8 +21,10 @@ from api.utils import get_base_config,decrypt_database_config
 from api.utils.file_utils import get_project_base_directory
 from api.utils.log_utils import LoggerFactory, getLogger
 
+from rag.nlp import search
+from rag.utils import ELASTICSEARCH
+
 
-# Server
 API_VERSION = "v1"
 RAG_FLOW_SERVICE_NAME = "ragflow"
 SERVER_MODULE = "rag_flow_server.py"
@@ -116,6 +118,8 @@ AUTHENTICATION_DEFAULT_TIMEOUT = 30 * 24 * 60 * 60 # s
 PRIVILEGE_COMMAND_WHITELIST = []
 CHECK_NODES_IDENTITY = False
 
+# Shared retriever instance, moved here from rag/nlp/__init__.py so other modules import it from api.settings.
+retrievaler = search.Dealer(ELASTICSEARCH)
+
 class CustomEnum(Enum):
     @classmethod
     def valid(cls, value):
diff --git a/deepdoc/parser/pdf_parser.py b/deepdoc/parser/pdf_parser.py
index c5d2e6a..f333b1d 100644
--- a/deepdoc/parser/pdf_parser.py
+++ b/deepdoc/parser/pdf_parser.py
@@ -230,7 +230,7 @@ class HuParser:
                 b["H_right"] = headers[ii]["x1"]
                 b["H"] = ii
 
-            ii = Recognizer.find_overlapped_with_threashold(b, clmns, thr=0.3)
+            ii = Recognizer.find_horizontally_tightest_fit(b, clmns)
             if ii is not None:
                 b["C"] = ii
                 b["C_left"] = clmns[ii]["x0"]
diff --git a/deepdoc/vision/layout_recognizer.py b/deepdoc/vision/layout_recognizer.py
index 7d0abb9..7a7791a 100644
--- a/deepdoc/vision/layout_recognizer.py
+++ b/deepdoc/vision/layout_recognizer.py
@@ -37,7 +37,7 @@ class LayoutRecognizer(Recognizer):
         super().__init__(self.labels, domain,
                          os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
 
-    def __call__(self, image_list, ocr_res, scale_factor=3, thr=0.7, batch_size=16):
+    def __call__(self, image_list, ocr_res, scale_factor=3, thr=0.2, batch_size=16):
         def __is_garbage(b):
             patt = [r"^•+$", r"(版权归©|免责条款|地址[::])", r"\.{3,}", "^[0-9]{1,2} / ?[0-9]{1,2}$",
                     r"^[0-9]{1,2} of [0-9]{1,2}$", "^http://[^ ]{12,}",
diff --git a/deepdoc/vision/postprocess.py b/deepdoc/vision/postprocess.py
index b3d1ac7..ec6f69d 100644
--- a/deepdoc/vision/postprocess.py
+++ b/deepdoc/vision/postprocess.py
@@ -2,7 +2,6 @@ import copy
 
 import numpy as np
 import cv2
-import paddle
 from shapely.geometry import Polygon
 import pyclipper
 
@@ -215,7 +214,7 @@ class DBPostProcess(object):
 
     def __call__(self, outs_dict, shape_list):
         pred = outs_dict['maps']
-        if isinstance(pred, paddle.Tensor):
+        if not isinstance(pred, np.ndarray):
             pred = pred.numpy()
         pred = pred[:, 0, :, :]
         segmentation = pred > self.thresh
@@ -339,7 +338,7 @@ class CTCLabelDecode(BaseRecLabelDecode):
     def __call__(self, preds, label=None, *args, **kwargs):
         if isinstance(preds, tuple) or isinstance(preds, list):
             preds = preds[-1]
-        if isinstance(preds, paddle.Tensor):
+        if not isinstance(preds, np.ndarray):
             preds = preds.numpy()
         preds_idx = preds.argmax(axis=2)
         preds_prob = preds.max(axis=2)
diff --git a/deepdoc/vision/recognizer.py b/deepdoc/vision/recognizer.py
index 7a13019..b0b7aea 100644
--- a/deepdoc/vision/recognizer.py
+++ b/deepdoc/vision/recognizer.py
@@ -259,6 +259,18 @@ class Recognizer(object):
 
         return max_overlaped_i
 
+    @staticmethod
+    def find_horizontally_tightest_fit(box, boxes):
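+        # Return the index of the candidate in boxes whose left edge, right edge, or horizontal center is closest to box.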
+        if not boxes:
+            return
+        min_dis, min_i = 1000000, None
+        for i,b in enumerate(boxes):
+            dis = min(abs(box["x0"] - b["x0"]), abs(box["x1"] - b["x1"]), abs(box["x0"]+box["x1"] - b["x1"] - b["x0"])/2)
+            if dis < min_dis:
+                min_i = i
+                min_dis = dis
+        return min_i
+
     @staticmethod
     def find_overlapped_with_threashold(box, boxes, thr=0.3):
         if not boxes:
diff --git a/deepdoc/vision/t_recognizer.py b/deepdoc/vision/t_recognizer.py
index 4ec1f29..7358c4e 100644
--- a/deepdoc/vision/t_recognizer.py
+++ b/deepdoc/vision/t_recognizer.py
@@ -74,6 +74,7 @@ def get_table_html(img, tb_cpns, ocr):
     clmns = sorted([r for r in tb_cpns if re.match(
         r"table column$", r["label"])], key=lambda x: x["x0"])
     clmns = Recognizer.layouts_cleanup(boxes, clmns, 5, 0.5)
+
     for b in boxes:
         ii = Recognizer.find_overlapped_with_threashold(b, rows, thr=0.3)
         if ii is not None:
@@ -89,7 +90,7 @@ def get_table_html(img, tb_cpns, ocr):
             b["H_right"] = headers[ii]["x1"]
             b["H"] = ii
 
-        ii = Recognizer.find_overlapped_with_threashold(b, clmns, thr=0.3)
+        ii = Recognizer.find_horizontally_tightest_fit(b, clmns)
         if ii is not None:
             b["C"] = ii
             b["C_left"] = clmns[ii]["x0"]
@@ -102,6 +103,7 @@ def get_table_html(img, tb_cpns, ocr):
             b["H_left"] = spans[ii]["x0"]
             b["H_right"] = spans[ii]["x1"]
             b["SP"] = ii
+
     html = """
     <html>
     <head>
diff --git a/deepdoc/vision/table_structure_recognizer.py b/deepdoc/vision/table_structure_recognizer.py
index 26b4a1e..e396cbf 100644
--- a/deepdoc/vision/table_structure_recognizer.py
+++ b/deepdoc/vision/table_structure_recognizer.py
@@ -14,7 +14,6 @@ import logging
 import os
 import re
 from collections import Counter
-from copy import deepcopy
 
 import numpy as np
 
@@ -37,7 +36,7 @@ class TableStructureRecognizer(Recognizer):
         super().__init__(self.labels, "tsr",
                          os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
 
-    def __call__(self, images, thr=0.5):
+    def __call__(self, images, thr=0.2):
         tbls = super().__call__(images, thr)
         res = []
         # align left&right for rows, align top&bottom for columns
@@ -56,8 +55,8 @@ class TableStructureRecognizer(Recognizer):
                 "row") > 0 or b["label"].find("header") > 0]
             if not left:
                 continue
-            left = np.median(left) if len(left) > 4 else np.min(left)
-            right = np.median(right) if len(right) > 4 else np.max(right)
+            left = np.mean(left) if len(left) > 4 else np.min(left)
+            right = np.mean(right) if len(right) > 4 else np.max(right)
             for b in lts:
                 if b["label"].find("row") > 0 or b["label"].find("header") > 0:
                     if b["x0"] > left:
@@ -129,6 +128,7 @@ class TableStructureRecognizer(Recognizer):
         i = 0
         while i < len(boxes):
             if TableStructureRecognizer.is_caption(boxes[i]):
+                if is_english: cap += " "
                 cap += boxes[i]["text"]
                 boxes.pop(i)
                 i -= 1
@@ -398,7 +398,7 @@ class TableStructureRecognizer(Recognizer):
             for i in range(clmno):
                 if not tbl[r][i]:
                     continue
-                txt = "".join([a["text"].strip() for a in tbl[r][i]])
+                txt = " ".join([a["text"].strip() for a in tbl[r][i]])
                 headers[r][i] = txt
                 hdrset.add(txt)
             if all([not t for t in headers[r]]):
diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py
index cdc2c8b..7b95cd6 100644
--- a/rag/llm/chat_model.py
+++ b/rag/llm/chat_model.py
@@ -15,7 +15,7 @@
 #
 from abc import ABC
 from openai import OpenAI
-import os
+import openai
 
 
 class Base(ABC):
@@ -33,11 +33,14 @@ class GptTurbo(Base):
 
     def chat(self, system, history, gen_conf):
         if system: history.insert(0, {"role": "system", "content": system})
-        res = self.client.chat.completions.create(
-            model=self.model_name,
-            messages=history,
-            **gen_conf)
-        return res.choices[0].message.content.strip(), res.usage.completion_tokens
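+        # Surface API failures as an "ERROR: "-prefixed string so callers (e.g. init_superuser) can detect them.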
+        try:
+            res = self.client.chat.completions.create(
+                model=self.model_name,
+                messages=history,
+                **gen_conf)
+            return res.choices[0].message.content.strip(), res.usage.completion_tokens
+        except openai.APIError as e:
+            return "ERROR: "+str(e), 0
 
 
 from dashscope import Generation
@@ -58,7 +61,7 @@ class QWenChat(Base):
         )
         if response.status_code == HTTPStatus.OK:
             return response.output.choices[0]['message']['content'], response.usage.output_tokens
-        return response.message, 0
+        return "ERROR: " + response.message, 0
 
 
 from zhipuai import ZhipuAI
@@ -77,4 +80,4 @@ class ZhipuChat(Base):
         )
         if response.status_code == HTTPStatus.OK:
             return response.output.choices[0]['message']['content'], response.usage.completion_tokens
-        return response.message, 0
\ No newline at end of file
+        return "ERROR: " + response.message, 0
\ No newline at end of file
diff --git a/rag/nlp/__init__.py b/rag/nlp/__init__.py
index 2d4bb2d..fcb306e 100644
--- a/rag/nlp/__init__.py
+++ b/rag/nlp/__init__.py
@@ -1,7 +1,4 @@
-from . import search
-from rag.utils import ELASTICSEARCH
 
-retrievaler = search.Dealer(ELASTICSEARCH)
 
 from nltk.stem import PorterStemmer
 stemmer = PorterStemmer()
@@ -39,10 +36,12 @@ BULLET_PATTERN = [[
 ]
 ]
 
+
 def random_choices(arr, k):
     k = min(len(arr), k)
     return random.choices(arr, k=k)
 
+
 def bullets_category(sections):
     global BULLET_PATTERN
     hits = [0] * len(BULLET_PATTERN)
diff --git a/rag/nlp/search.py b/rag/nlp/search.py
index 580caaf..5f9fb70 100644
--- a/rag/nlp/search.py
+++ b/rag/nlp/search.py
@@ -1,7 +1,7 @@
 # -*- coding: utf-8 -*-
 import json
 import re
-from elasticsearch_dsl import Q, Search, A
+from elasticsearch_dsl import Q, Search
 from typing import List, Optional, Dict, Union
 from dataclasses import dataclass
 
@@ -183,6 +183,7 @@ class Dealer:
 
     def insert_citations(self, answer, chunks, chunk_v,
                          embd_mdl, tkweight=0.3, vtweight=0.7):
+        assert len(chunks) == len(chunk_v)
         pieces = re.split(r"([;。?!!\n]|[a-z][.?;!][ \n])", answer)
         for i in range(1, len(pieces)):
             if re.match(r"[a-z][.?;!][ \n]", pieces[i]):
@@ -216,7 +217,7 @@ class Dealer:
             if mx < 0.55:
                 continue
             cites[idx[i]] = list(
-                set([str(i) for i in range(len(chunk_v)) if sim[i] > mx]))[:4]
+                set([str(ii) for ii in range(len(chunk_v)) if sim[ii] > mx]))[:4]
 
         res = ""
         for i, p in enumerate(pieces):
@@ -225,6 +226,7 @@ class Dealer:
                 continue
             if i not in cites:
                 continue
+            assert all(int(c) < len(chunk_v) for c in cites[i])
             res += "##%s$$" % "$".join(cites[i])
 
         return res
-- 
GitLab