From 4568a4b2cbecde585e0e4ba9b2a1f78f5a10d2d9 Mon Sep 17 00:00:00 2001
From: KevinHuSh <kevinhu.sh@gmail.com>
Date: Tue, 27 Feb 2024 14:57:34 +0800
Subject: [PATCH] refine admin initialization (#75)

---
 api/apps/chunk_app.py                        |  4 +-
 api/apps/conversation_app.py                 |  4 +-
 api/db/init_data.py                          | 46 ++++++++++++++++++--
 api/settings.py                              |  6 ++-
 deepdoc/parser/pdf_parser.py                 |  2 +-
 deepdoc/vision/layout_recognizer.py          |  2 +-
 deepdoc/vision/postprocess.py                |  5 +--
 deepdoc/vision/recognizer.py                 | 12 +++++
 deepdoc/vision/t_recognizer.py               |  4 +-
 deepdoc/vision/table_structure_recognizer.py | 10 ++---
 rag/llm/chat_model.py                        | 19 ++++----
 rag/nlp/__init__.py                          |  5 +--
 rag/nlp/search.py                            |  6 ++-
 13 files changed, 91 insertions(+), 34 deletions(-)

diff --git a/api/apps/chunk_app.py b/api/apps/chunk_app.py
index 11d9c7b..1ead645 100644
--- a/api/apps/chunk_app.py
+++ b/api/apps/chunk_app.py
@@ -20,7 +20,7 @@ from flask_login import login_required, current_user
 from elasticsearch_dsl import Q
 
 from rag.app.qa import rmPrefix, beAdoc
-from rag.nlp import search, huqie, retrievaler
+from rag.nlp import search, huqie
 from rag.utils import ELASTICSEARCH, rmSpace
 from api.db import LLMType, ParserType
 from api.db.services.knowledgebase_service import KnowledgebaseService
@@ -28,7 +28,7 @@ from api.db.services.llm_service import TenantLLMService
 from api.db.services.user_service import UserTenantService
 from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
 from api.db.services.document_service import DocumentService
-from api.settings import RetCode
+from api.settings import RetCode, retrievaler
 from api.utils.api_utils import get_json_result
 import hashlib
 import re
diff --git a/api/apps/conversation_app.py b/api/apps/conversation_app.py
index fe1047b..13e02aa 100644
--- a/api/apps/conversation_app.py
+++ b/api/apps/conversation_app.py
@@ -21,13 +21,11 @@ from api.db.services.dialog_service import DialogService, ConversationService
 from api.db import LLMType
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.llm_service import LLMService, LLMBundle
-from api.settings import access_logger, stat_logger
+from api.settings import access_logger, stat_logger, retrievaler
 from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
 from api.utils import get_uuid
 from api.utils.api_utils import get_json_result
 from rag.app.resume import forbidden_select_fields4resume
-from rag.llm import ChatModel
-from rag.nlp import retrievaler
 from rag.nlp.search import index_name
 from rag.utils import num_tokens_from_string, encoder, rmSpace
 
diff --git a/api/db/init_data.py b/api/db/init_data.py
index 5e4a812..ee91fd8 100644
--- a/api/db/init_data.py
+++ b/api/db/init_data.py
@@ -16,10 +16,12 @@
 import time
 import uuid
 
-from api.db import LLMType
+from api.db import LLMType, UserTenantRole
 from api.db.db_models import init_database_tables as init_web_db
 from api.db.services import UserService
-from api.db.services.llm_service import LLMFactoriesService, LLMService
+from api.db.services.llm_service import LLMFactoriesService, LLMService, TenantLLMService, LLMBundle
+from api.db.services.user_service import TenantService, UserTenantService
+from api.settings import CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS, LLM_FACTORY, API_KEY
 
 
 def init_superuser():
@@ -32,8 +34,44 @@ def init_superuser():
         "creator": "system",
         "status": "1",
     }
+    tenant = {
+        "id": user_info["id"],
+        "name": user_info["nickname"] + "’s Kingdom",
"llm_id": CHAT_MDL, + "embd_id": EMBEDDING_MDL, + "asr_id": ASR_MDL, + "parser_ids": PARSERS, + "img2txt_id": IMAGE2TEXT_MDL + } + usr_tenant = { + "tenant_id": user_info["id"], + "user_id": user_info["id"], + "invited_by": user_info["id"], + "role": UserTenantRole.OWNER + } + tenant_llm = [] + for llm in LLMService.query(fid=LLM_FACTORY): + tenant_llm.append( + {"tenant_id": user_info["id"], "llm_factory": LLM_FACTORY, "llm_name": llm.llm_name, "model_type": llm.model_type, + "api_key": API_KEY}) + + if not UserService.save(**user_info): + print("ă€ERROR】can't init admin.") + return + TenantService.save(**tenant) + UserTenantService.save(**usr_tenant) + TenantLLMService.insert_many(tenant_llm) UserService.save(**user_info) + chat_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["llm_id"]) + msg = chat_mdl.chat(system="", history=[{"role": "user", "content": "Hello!"}], gen_conf={}) + if msg.find("ERROR: ") == 0: + print("ă€ERROR】: '{}' dosen't work. {}".format(tenant["llm_id"]), msg) + embd_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["embd_id"]) + v,c = embd_mdl.encode(["Hello!"]) + if c == 0: + print("ă€ERROR】: '{}' dosen't work...".format(tenant["embd_id"])) + def init_llm_factory(): factory_infos = [{ @@ -171,10 +209,10 @@ def init_llm_factory(): def init_web_data(): start_time = time.time() - if not UserService.get_all().count(): - init_superuser() if not LLMService.get_all().count():init_llm_factory() + if not UserService.get_all().count(): + init_superuser() print("init web data success:{}".format(time.time() - start_time)) diff --git a/api/settings.py b/api/settings.py index a5882a5..08f7dc7 100644 --- a/api/settings.py +++ b/api/settings.py @@ -21,8 +21,10 @@ from api.utils import get_base_config,decrypt_database_config from api.utils.file_utils import get_project_base_directory from api.utils.log_utils import LoggerFactory, getLogger +from rag.nlp import search +from rag.utils import ELASTICSEARCH + -# Server API_VERSION = "v1" RAG_FLOW_SERVICE_NAME = "ragflow" SERVER_MODULE = "rag_flow_server.py" @@ -116,6 +118,8 @@ AUTHENTICATION_DEFAULT_TIMEOUT = 30 * 24 * 60 * 60 # s PRIVILEGE_COMMAND_WHITELIST = [] CHECK_NODES_IDENTITY = False +retrievaler = search.Dealer(ELASTICSEARCH) + class CustomEnum(Enum): @classmethod def valid(cls, value): diff --git a/deepdoc/parser/pdf_parser.py b/deepdoc/parser/pdf_parser.py index c5d2e6a..f333b1d 100644 --- a/deepdoc/parser/pdf_parser.py +++ b/deepdoc/parser/pdf_parser.py @@ -230,7 +230,7 @@ class HuParser: b["H_right"] = headers[ii]["x1"] b["H"] = ii - ii = Recognizer.find_overlapped_with_threashold(b, clmns, thr=0.3) + ii = Recognizer.find_horizontally_tightest_fit(b, clmns) if ii is not None: b["C"] = ii b["C_left"] = clmns[ii]["x0"] diff --git a/deepdoc/vision/layout_recognizer.py b/deepdoc/vision/layout_recognizer.py index 7d0abb9..7a7791a 100644 --- a/deepdoc/vision/layout_recognizer.py +++ b/deepdoc/vision/layout_recognizer.py @@ -37,7 +37,7 @@ class LayoutRecognizer(Recognizer): super().__init__(self.labels, domain, os.path.join(get_project_base_directory(), "rag/res/deepdoc/")) - def __call__(self, image_list, ocr_res, scale_factor=3, thr=0.7, batch_size=16): + def __call__(self, image_list, ocr_res, scale_factor=3, thr=0.2, batch_size=16): def __is_garbage(b): patt = [r"^•+$", r"(ç‰ćťĺ˝’©|免责条款|地址[::])", r"\.{3,}", "^[0-9]{1,2} / ?[0-9]{1,2}$", r"^[0-9]{1,2} of [0-9]{1,2}$", "^http://[^ ]{12,}", diff --git a/deepdoc/vision/postprocess.py b/deepdoc/vision/postprocess.py index b3d1ac7..ec6f69d 100644 --- 
+++ b/deepdoc/vision/postprocess.py
@@ -2,7 +2,6 @@ import copy
 
 import numpy as np
 import cv2
-import paddle
 from shapely.geometry import Polygon
 import pyclipper
 
@@ -215,7 +214,7 @@ class DBPostProcess(object):
 
     def __call__(self, outs_dict, shape_list):
         pred = outs_dict['maps']
-        if isinstance(pred, paddle.Tensor):
+        if not isinstance(pred, np.ndarray):
             pred = pred.numpy()
         pred = pred[:, 0, :, :]
         segmentation = pred > self.thresh
@@ -339,7 +338,7 @@ class CTCLabelDecode(BaseRecLabelDecode):
     def __call__(self, preds, label=None, *args, **kwargs):
         if isinstance(preds, tuple) or isinstance(preds, list):
             preds = preds[-1]
-        if isinstance(preds, paddle.Tensor):
+        if not isinstance(preds, np.ndarray):
             preds = preds.numpy()
         preds_idx = preds.argmax(axis=2)
         preds_prob = preds.max(axis=2)
diff --git a/deepdoc/vision/recognizer.py b/deepdoc/vision/recognizer.py
index 7a13019..b0b7aea 100644
--- a/deepdoc/vision/recognizer.py
+++ b/deepdoc/vision/recognizer.py
@@ -259,6 +259,18 @@ class Recognizer(object):
 
         return max_overlaped_i
 
+    @staticmethod
+    def find_horizontally_tightest_fit(box, boxes):
+        if not boxes:
+            return
+        min_dis, min_i = 1000000, None
+        for i, b in enumerate(boxes):
+            dis = min(abs(box["x0"] - b["x0"]), abs(box["x1"] - b["x1"]), abs(box["x0"] + box["x1"] - b["x1"] - b["x0"]) / 2)
+            if dis < min_dis:
+                min_i = i
+                min_dis = dis
+        return min_i
+
     @staticmethod
     def find_overlapped_with_threashold(box, boxes, thr=0.3):
         if not boxes:
diff --git a/deepdoc/vision/t_recognizer.py b/deepdoc/vision/t_recognizer.py
index 4ec1f29..7358c4e 100644
--- a/deepdoc/vision/t_recognizer.py
+++ b/deepdoc/vision/t_recognizer.py
@@ -74,6 +74,7 @@ def get_table_html(img, tb_cpns, ocr):
     clmns = sorted([r for r in tb_cpns if re.match(
         r"table column$", r["label"])], key=lambda x: x["x0"])
     clmns = Recognizer.layouts_cleanup(boxes, clmns, 5, 0.5)
+
     for b in boxes:
         ii = Recognizer.find_overlapped_with_threashold(b, rows, thr=0.3)
         if ii is not None:
@@ -89,7 +90,7 @@ def get_table_html(img, tb_cpns, ocr):
             b["H_right"] = headers[ii]["x1"]
             b["H"] = ii
 
-        ii = Recognizer.find_overlapped_with_threashold(b, clmns, thr=0.3)
+        ii = Recognizer.find_horizontally_tightest_fit(b, clmns)
         if ii is not None:
             b["C"] = ii
             b["C_left"] = clmns[ii]["x0"]
@@ -102,6 +103,7 @@ def get_table_html(img, tb_cpns, ocr):
             b["H_left"] = spans[ii]["x0"]
             b["H_right"] = spans[ii]["x1"]
             b["SP"] = ii
+
     html = """
     <html>
     <head>
diff --git a/deepdoc/vision/table_structure_recognizer.py b/deepdoc/vision/table_structure_recognizer.py
index 26b4a1e..e396cbf 100644
--- a/deepdoc/vision/table_structure_recognizer.py
+++ b/deepdoc/vision/table_structure_recognizer.py
@@ -14,7 +14,6 @@ import logging
 import os
 import re
 from collections import Counter
-from copy import deepcopy
 
 import numpy as np
 
@@ -37,7 +36,7 @@ class TableStructureRecognizer(Recognizer):
         super().__init__(self.labels, "tsr", os.path.join(
             get_project_base_directory(), "rag/res/deepdoc/"))
 
-    def __call__(self, images, thr=0.5):
+    def __call__(self, images, thr=0.2):
         tbls = super().__call__(images, thr)
         res = []
         # align left&right for rows, align top&bottom for columns
@@ -56,8 +55,8 @@ class TableStructureRecognizer(Recognizer):
                           "row") > 0 or b["label"].find("header") > 0]
             if not left:
                 continue
-            left = np.median(left) if len(left) > 4 else np.min(left)
-            right = np.median(right) if len(right) > 4 else np.max(right)
+            left = np.mean(left) if len(left) > 4 else np.min(left)
+            right = np.mean(right) if len(right) > 4 else np.max(right)
             for b in lts:
b["label"].find("row") > 0 or b["label"].find("header") > 0: if b["x0"] > left: @@ -129,6 +128,7 @@ class TableStructureRecognizer(Recognizer): i = 0 while i < len(boxes): if TableStructureRecognizer.is_caption(boxes[i]): + if is_english: cap + " " cap += boxes[i]["text"] boxes.pop(i) i -= 1 @@ -398,7 +398,7 @@ class TableStructureRecognizer(Recognizer): for i in range(clmno): if not tbl[r][i]: continue - txt = "".join([a["text"].strip() for a in tbl[r][i]]) + txt = " ".join([a["text"].strip() for a in tbl[r][i]]) headers[r][i] = txt hdrset.add(txt) if all([not t for t in headers[r]]): diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py index cdc2c8b..7b95cd6 100644 --- a/rag/llm/chat_model.py +++ b/rag/llm/chat_model.py @@ -15,7 +15,7 @@ # from abc import ABC from openai import OpenAI -import os +import openai class Base(ABC): @@ -33,11 +33,14 @@ class GptTurbo(Base): def chat(self, system, history, gen_conf): if system: history.insert(0, {"role": "system", "content": system}) - res = self.client.chat.completions.create( - model=self.model_name, - messages=history, - **gen_conf) - return res.choices[0].message.content.strip(), res.usage.completion_tokens + try: + res = self.client.chat.completions.create( + model=self.model_name, + messages=history, + **gen_conf) + return res.choices[0].message.content.strip(), res.usage.completion_tokens + except openai.APIError as e: + return "ERROR: "+str(e), 0 from dashscope import Generation @@ -58,7 +61,7 @@ class QWenChat(Base): ) if response.status_code == HTTPStatus.OK: return response.output.choices[0]['message']['content'], response.usage.output_tokens - return response.message, 0 + return "ERROR: " + response.message, 0 from zhipuai import ZhipuAI @@ -77,4 +80,4 @@ class ZhipuChat(Base): ) if response.status_code == HTTPStatus.OK: return response.output.choices[0]['message']['content'], response.usage.completion_tokens - return response.message, 0 \ No newline at end of file + return "ERROR: " + response.message, 0 \ No newline at end of file diff --git a/rag/nlp/__init__.py b/rag/nlp/__init__.py index 2d4bb2d..fcb306e 100644 --- a/rag/nlp/__init__.py +++ b/rag/nlp/__init__.py @@ -1,7 +1,4 @@ -from . 
-from . import search
-from rag.utils import ELASTICSEARCH
-retrievaler = search.Dealer(ELASTICSEARCH)
 
 from nltk.stem import PorterStemmer
 stemmer = PorterStemmer()
 
@@ -39,10 +36,12 @@ BULLET_PATTERN = [[
 ]
 ]
 
+
 def random_choices(arr, k):
     k = min(len(arr), k)
     return random.choices(arr, k=k)
 
+
 def bullets_category(sections):
     global BULLET_PATTERN
     hits = [0] * len(BULLET_PATTERN)
diff --git a/rag/nlp/search.py b/rag/nlp/search.py
index 580caaf..5f9fb70 100644
--- a/rag/nlp/search.py
+++ b/rag/nlp/search.py
@@ -1,7 +1,7 @@
 # -*- coding: utf-8 -*-
 import json
 import re
-from elasticsearch_dsl import Q, Search, A
+from elasticsearch_dsl import Q, Search
 from typing import List, Optional, Dict, Union
 from dataclasses import dataclass
 
@@ -183,6 +183,7 @@ class Dealer:
 
     def insert_citations(self, answer, chunks, chunk_v,
                          embd_mdl, tkweight=0.3, vtweight=0.7):
+        assert len(chunks) == len(chunk_v)
         pieces = re.split(r"([;。?!！\n]|[a-z][.?;!][ \n])", answer)
         for i in range(1, len(pieces)):
             if re.match(r"[a-z][.?;!][ \n]", pieces[i]):
@@ -216,7 +217,7 @@ class Dealer:
             if mx < 0.55:
                 continue
             cites[idx[i]] = list(
-                set([str(i) for i in range(len(chunk_v)) if sim[i] > mx]))[:4]
+                set([str(ii) for ii in range(len(chunk_v)) if sim[ii] > mx]))[:4]
 
         res = ""
         for i, p in enumerate(pieces):
@@ -225,6 +226,7 @@ class Dealer:
                 continue
             if i not in cites:
                 continue
+            assert all([int(t) < len(chunk_v) for t in cites[i]])
             res += "##%s$$" % "$".join(cites[i])
 
         return res
-- 
GitLab