From 407b2523b6d19fd6883c90a3e13c7a6e87f3eae5 Mon Sep 17 00:00:00 2001 From: KevinHuSh <kevinhu.sh@gmail.com> Date: Mon, 5 Feb 2024 18:08:17 +0800 Subject: [PATCH] remove unused code, separate layout detection out as a new API. Add new RAG method 'table' (#55) --- api/apps/__init__.py | 32 ----- api/apps/chunk_app.py | 12 +- api/apps/conversation_app.py | 2 +- api/apps/document_app.py | 10 +- api/apps/kb_app.py | 2 +- api/db/__init__.py | 1 + api/db/db_models.py | 9 +- api/db/db_services.py | 157 --------------------- api/db/services/document_service.py | 2 +- api/db/services/knowledgebase_service.py | 3 +- api/db/services/task_service.py | 12 +- api/db/services/user_service.py | 6 + api/errors/__init__.py | 10 -- api/errors/error_services.py | 13 -- api/errors/general_error.py | 21 --- api/hook/__init__.py | 57 -------- api/hook/api/client_authentication.py | 29 ---- api/hook/api/permission.py | 25 ---- api/hook/api/site_authentication.py | 49 ------- api/hook/common/parameters.py | 56 -------- api/ragflow_server.py | 5 - api/settings.py | 2 +- rag/app/book.py | 12 +- rag/app/laws.py | 3 +- rag/app/naive.py | 2 +- rag/app/paper.py | 25 +++- rag/app/qa.py | 6 +- rag/app/table.py | 170 +++++++++++++++++++++++ rag/nlp/search.py | 5 +- rag/parser/__init__.py | 4 + rag/parser/pdf_parser.py | 53 ++++++- rag/svr/task_broker.py | 2 +- rag/svr/task_executor.py | 14 +- 33 files changed, 306 insertions(+), 505 deletions(-) delete mode 100644 api/db/db_services.py delete mode 100644 api/errors/__init__.py delete mode 100644 api/errors/error_services.py delete mode 100644 api/errors/general_error.py delete mode 100644 api/hook/__init__.py delete mode 100644 api/hook/api/client_authentication.py delete mode 100644 api/hook/api/permission.py delete mode 100644 api/hook/api/site_authentication.py delete mode 100644 api/hook/common/parameters.py create mode 100644 rag/app/table.py diff --git a/api/apps/__init__.py b/api/apps/__init__.py index bde66d8..a53663b 100644 --- a/api/apps/__init__.py +++ b/api/apps/__init__.py @@ -28,8 +28,6 @@ from api.utils import CustomJSONEncoder from flask_session import Session from flask_login import LoginManager from api.settings import RetCode, SECRET_KEY, stat_logger -from api.hook import HookManager -from api.hook.common.parameters import AuthenticationParameters, ClientAuthenticationParameters from api.settings import API_VERSION, CLIENT_AUTHENTICATION, SITE_AUTHENTICATION, access_logger from api.utils.api_utils import get_json_result, server_error_response from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer @@ -96,37 +94,7 @@ client_urls_prefix = [ ] -def client_authentication_before_request(): - result = HookManager.client_authentication(ClientAuthenticationParameters( - request.full_path, request.headers, - request.form, request.data, request.json, - )) - if result.code != RetCode.SUCCESS: - return get_json_result(result.code, result.message) - - -def site_authentication_before_request(): - for url_prefix in client_urls_prefix: - if request.path.startswith(url_prefix): - return - - result = HookManager.site_authentication(AuthenticationParameters( - request.headers.get('site_signature'), - request.json, - )) - - if result.code != RetCode.SUCCESS: - return get_json_result(result.code, result.message) - - -@app.before_request -def authentication_before_request(): - if CLIENT_AUTHENTICATION: - return client_authentication_before_request() - - if SITE_AUTHENTICATION: - return site_authentication_before_request() @login_manager.request_loader def 
load_user(web_request): diff --git a/api/apps/chunk_app.py b/api/apps/chunk_app.py index 53caa76..9a5a168 100644 --- a/api/apps/chunk_app.py +++ b/api/apps/chunk_app.py @@ -57,7 +57,7 @@ def list(): for id in sres.ids: d = { "chunk_id": id, - "content_ltks": rmSpace(sres.highlight[id]) if question else sres.field[id]["content_ltks"], + "content_with_weight": rmSpace(sres.highlight[id]) if question else sres.field[id]["content_with_weight"], "doc_id": sres.field[id]["doc_id"], "docnm_kwd": sres.field[id]["docnm_kwd"], "important_kwd": sres.field[id].get("important_kwd", []), @@ -134,7 +134,7 @@ def set(): q, a = rmPrefix(arr[0]), rmPrefix[arr[1]] d = beAdoc(d, arr[0], arr[1], not any([huqie.is_chinese(t) for t in q+a])) - v, c = embd_mdl.encode([doc.name, req["content_ltks"]]) + v, c = embd_mdl.encode([doc.name, req["content_with_weight"]]) v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1] d["q_%d_vec" % len(v)] = v.tolist() ELASTICSEARCH.upsert([d], search.index_name(tenant_id)) @@ -175,13 +175,13 @@ def rm(): @manager.route('/create', methods=['POST']) @login_required -@validate_request("doc_id", "content_ltks") +@validate_request("doc_id", "content_with_weight") def create(): req = request.json md5 = hashlib.md5() - md5.update((req["content_ltks"] + req["doc_id"]).encode("utf-8")) + md5.update((req["content_with_weight"] + req["doc_id"]).encode("utf-8")) chunck_id = md5.hexdigest() - d = {"id": chunck_id, "content_ltks": huqie.qie(req["content_ltks"])} + d = {"id": chunck_id, "content_ltks": huqie.qie(req["content_with_weight"])} d["content_sm_ltks"] = huqie.qieqie(d["content_ltks"]) d["important_kwd"] = req.get("important_kwd", []) d["important_tks"] = huqie.qie(" ".join(req.get("important_kwd", []))) @@ -201,7 +201,7 @@ def create(): embd_mdl = TenantLLMService.model_instance( tenant_id, LLMType.EMBEDDING.value) - v, c = embd_mdl.encode([doc.name, req["content_ltks"]]) + v, c = embd_mdl.encode([doc.name, req["content_with_weight"]]) DocumentService.increment_chunk_num(req["doc_id"], doc.kb_id, c, 1, 0) v = 0.1 * v[0] + 0.9 * v[1] d["q_%d_vec" % len(v)] = v.tolist() diff --git a/api/apps/conversation_app.py b/api/apps/conversation_app.py index b26fe7e..c5c8466 100644 --- a/api/apps/conversation_app.py +++ b/api/apps/conversation_app.py @@ -175,7 +175,7 @@ def chat(dialog, messages, **kwargs): chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id) kbinfos = retrievaler.retrieval(question, embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n, dialog.similarity_threshold, dialog.vector_similarity_weight, top=1024, aggs=False) - knowledges = [ck["content_ltks"] for ck in kbinfos["chunks"]] + knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]] if not knowledges and prompt_config["empty_response"]: return {"answer": prompt_config["empty_response"], "retrieval": kbinfos} diff --git a/api/apps/document_app.py b/api/apps/document_app.py index e43bfc7..32b8d12 100644 --- a/api/apps/document_app.py +++ b/api/apps/document_app.py @@ -73,6 +73,7 @@ def upload(): "id": get_uuid(), "kb_id": kb.id, "parser_id": kb.parser_id, + "parser_config": kb.parser_config, "created_by": current_user.id, "type": filename_type(filename), "name": filename, @@ -108,6 +109,7 @@ def create(): "id": get_uuid(), "kb_id": kb.id, "parser_id": kb.parser_id, + "parser_config": kb.parser_config, "created_by": current_user.id, "type": FileType.VIRTUAL, "name": req["name"], @@ -128,8 +130,8 @@ def list(): data=False, retmsg='Lack of "KB ID"', 
retcode=RetCode.ARGUMENT_ERROR) keywords = request.args.get("keywords", "") - page_number = request.args.get("page", 1) - items_per_page = request.args.get("page_size", 15) + page_number = int(request.args.get("page", 1)) + items_per_page = int(request.args.get("page_size", 15)) orderby = request.args.get("orderby", "create_time") desc = request.args.get("desc", True) try: @@ -214,7 +216,9 @@ def run(): req = request.json try: for id in req["doc_ids"]: - DocumentService.update_by_id(id, {"run": str(req["run"]), "progress": 0}) + info = {"run": str(req["run"]), "progress": 0} + if str(req["run"]) == TaskStatus.RUNNING.value:info["progress_msg"] = "" + DocumentService.update_by_id(id, info) if str(req["run"]) == TaskStatus.CANCEL.value: tenant_id = DocumentService.get_tenant_id(id) if not tenant_id: diff --git a/api/apps/kb_app.py b/api/apps/kb_app.py index 15e6be8..a7f4009 100644 --- a/api/apps/kb_app.py +++ b/api/apps/kb_app.py @@ -29,7 +29,7 @@ from api.utils.api_utils import get_json_result @manager.route('/create', methods=['post']) @login_required -@validate_request("name", "description", "permission", "parser_id") +@validate_request("name") def create(): req = request.json req["name"] = req["name"].strip() diff --git a/api/db/__init__.py b/api/db/__init__.py index de37613..c657dee 100644 --- a/api/db/__init__.py +++ b/api/db/__init__.py @@ -77,3 +77,4 @@ class ParserType(StrEnum): RESUME = "resume" BOOK = "book" QA = "qa" + TABLE = "table" diff --git a/api/db/db_models.py b/api/db/db_models.py index 3b7b5bc..db641cf 100644 --- a/api/db/db_models.py +++ b/api/db/db_models.py @@ -29,7 +29,7 @@ from peewee import ( ) from playhouse.pool import PooledMySQLDatabase -from api.db import SerializedType +from api.db import SerializedType, ParserType from api.settings import DATABASE, stat_logger, SECRET_KEY from api.utils.log_utils import getLogger from api import utils @@ -381,7 +381,8 @@ class Tenant(DataBaseModel): embd_id = CharField(max_length=128, null=False, help_text="default embedding model ID") asr_id = CharField(max_length=128, null=False, help_text="default ASR model ID") img2txt_id = CharField(max_length=128, null=False, help_text="default image to text model ID") - parser_ids = CharField(max_length=128, null=False, help_text="default image to text model ID") + parser_ids = CharField(max_length=128, null=False, help_text="document processors") + credit = IntegerField(default=512) status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1") class Meta: @@ -472,7 +473,8 @@ class Knowledgebase(DataBaseModel): similarity_threshold = FloatField(default=0.2) vector_similarity_weight = FloatField(default=0.3) - parser_id = CharField(max_length=32, null=False, help_text="default parser ID") + parser_id = CharField(max_length=32, null=False, help_text="default parser ID", default=ParserType.GENERAL.value) + parser_config = JSONField(null=False, default={"from_page":0, "to_page": 100000}) status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1") def __str__(self): @@ -487,6 +489,7 @@ class Document(DataBaseModel): thumbnail = TextField(null=True, help_text="thumbnail base64 string") kb_id = CharField(max_length=256, null=False, index=True) parser_id = CharField(max_length=32, null=False, help_text="default parser ID") + parser_config = JSONField(null=False, default={"from_page":0, "to_page": 100000}) source_type = CharField(max_length=128, null=False, default="local", help_text="where dose this 
document from") type = CharField(max_length=32, null=False, help_text="file extension") created_by = CharField(max_length=32, null=False, help_text="who created it") diff --git a/api/db/db_services.py b/api/db/db_services.py deleted file mode 100644 index 9e734f5..0000000 --- a/api/db/db_services.py +++ /dev/null @@ -1,157 +0,0 @@ -# -# Copyright 2021 The InfiniFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -import abc -import json -import time -from functools import wraps -from shortuuid import ShortUUID - -from api.versions import get_rag_version - -from api.errors.error_services import * -from api.settings import ( - GRPC_PORT, HOST, HTTP_PORT, - RANDOM_INSTANCE_ID, stat_logger, -) - - -instance_id = ShortUUID().random(length=8) if RANDOM_INSTANCE_ID else f'flow-{HOST}-{HTTP_PORT}' -server_instance = ( - f'{HOST}:{GRPC_PORT}', - json.dumps({ - 'instance_id': instance_id, - 'timestamp': round(time.time() * 1000), - 'version': get_rag_version() or '', - 'host': HOST, - 'grpc_port': GRPC_PORT, - 'http_port': HTTP_PORT, - }), -) - - -def check_service_supported(method): - """Decorator to check if `service_name` is supported. - The attribute `supported_services` MUST be defined in class. - The first and second arguments of `method` MUST be `self` and `service_name`. - - :param Callable method: The class method. - :return: The inner wrapper function. - :rtype: Callable - """ - @wraps(method) - def magic(self, service_name, *args, **kwargs): - if service_name not in self.supported_services: - raise ServiceNotSupported(service_name=service_name) - return method(self, service_name, *args, **kwargs) - return magic - - -class ServicesDB(abc.ABC): - """Database for storage service urls. - Abstract base class for the real backends. - - """ - @property - @abc.abstractmethod - def supported_services(self): - """The names of supported services. - The returned list SHOULD contain `ragflow` (model download) and `servings` (RAG-Serving). - - :return: The service names. - :rtype: list - """ - pass - - @abc.abstractmethod - def _get_serving(self): - pass - - def get_serving(self): - - try: - return self._get_serving() - except ServicesError as e: - stat_logger.exception(e) - return [] - - @abc.abstractmethod - def _insert(self, service_name, service_url, value=''): - pass - - @check_service_supported - def insert(self, service_name, service_url, value=''): - """Insert a service url to database. - - :param str service_name: The service name. - :param str service_url: The service url. - :return: None - """ - try: - self._insert(service_name, service_url, value) - except ServicesError as e: - stat_logger.exception(e) - - @abc.abstractmethod - def _delete(self, service_name, service_url): - pass - - @check_service_supported - def delete(self, service_name, service_url): - """Delete a service url from database. - - :param str service_name: The service name. - :param str service_url: The service url. 
- :return: None - """ - try: - self._delete(service_name, service_url) - except ServicesError as e: - stat_logger.exception(e) - - def register_flow(self): - """Call `self.insert` for insert the flow server address to databae. - - :return: None - """ - self.insert('flow-server', *server_instance) - - def unregister_flow(self): - """Call `self.delete` for delete the flow server address from databae. - - :return: None - """ - self.delete('flow-server', server_instance[0]) - - @abc.abstractmethod - def _get_urls(self, service_name, with_values=False): - pass - - @check_service_supported - def get_urls(self, service_name, with_values=False): - """Query service urls from database. The urls may belong to other nodes. - Currently, only `ragflow` (model download) urls and `servings` (RAG-Serving) urls are supported. - `ragflow` is a url containing scheme, host, port and path, - while `servings` only contains host and port. - - :param str service_name: The service name. - :return: The service urls. - :rtype: list - """ - try: - return self._get_urls(service_name, with_values) - except ServicesError as e: - stat_logger.exception(e) - return [] diff --git a/api/db/services/document_service.py b/api/db/services/document_service.py index b17ee89..d4d00c1 100644 --- a/api/db/services/document_service.py +++ b/api/db/services/document_service.py @@ -63,7 +63,7 @@ class DocumentService(CommonService): @classmethod @DB.connection_context() def get_newly_uploaded(cls, tm, mod=0, comm=1, items_per_page=64): - fields = [cls.model.id, cls.model.kb_id, cls.model.parser_id, cls.model.name, cls.model.type, cls.model.location, cls.model.size, Knowledgebase.tenant_id, Tenant.embd_id, Tenant.img2txt_id, Tenant.asr_id, cls.model.update_time] + fields = [cls.model.id, cls.model.kb_id, cls.model.parser_id, cls.model.parser_config, cls.model.name, cls.model.type, cls.model.location, cls.model.size, Knowledgebase.tenant_id, Tenant.embd_id, Tenant.img2txt_id, Tenant.asr_id, cls.model.update_time] docs = cls.model.select(*fields) \ .join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id)) \ .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))\ diff --git a/api/db/services/knowledgebase_service.py b/api/db/services/knowledgebase_service.py index 63de6c2..a99346a 100644 --- a/api/db/services/knowledgebase_service.py +++ b/api/db/services/knowledgebase_service.py @@ -52,7 +52,8 @@ class KnowledgebaseService(CommonService): cls.model.doc_num, cls.model.token_num, cls.model.chunk_num, - cls.model.parser_id] + cls.model.parser_id, + cls.model.parser_config] kbs = cls.model.select(*fields).join(Tenant, on=((Tenant.id == cls.model.tenant_id)&(Tenant.status== StatusEnum.VALID.value))).where( (cls.model.id == kb_id), (cls.model.status == StatusEnum.VALID.value) diff --git a/api/db/services/task_service.py b/api/db/services/task_service.py index 6cc62b2..87e84a1 100644 --- a/api/db/services/task_service.py +++ b/api/db/services/task_service.py @@ -27,7 +27,7 @@ class TaskService(CommonService): @classmethod @DB.connection_context() def get_tasks(cls, tm, mod=0, comm=1, items_per_page=64): - fields = [cls.model.id, cls.model.doc_id, cls.model.from_page,cls.model.to_page, Document.kb_id, Document.parser_id, Document.name, Document.type, Document.location, Document.size, Knowledgebase.tenant_id, Tenant.embd_id, Tenant.img2txt_id, Tenant.asr_id, cls.model.update_time] + fields = [cls.model.id, cls.model.doc_id, cls.model.from_page,cls.model.to_page, Document.kb_id, Document.parser_id, Document.parser_config, Document.name, 
Document.type, Document.location, Document.size, Knowledgebase.tenant_id, Tenant.embd_id, Tenant.img2txt_id, Tenant.asr_id, cls.model.update_time] docs = cls.model.select(*fields) \ .join(Document, on=(cls.model.doc_id == Document.id)) \ .join(Knowledgebase, on=(Document.kb_id == Knowledgebase.id)) \ @@ -53,3 +53,13 @@ class TaskService(CommonService): except Exception as e: pass return True + + + @classmethod + @DB.connection_context() + def update_progress(cls, id, info): + cls.model.update(progress_msg=cls.model.progress_msg + "\n"+info["progress_msg"]).where( + cls.model.id == id).execute() + if "progress" in info: + cls.model.update(progress=info["progress"]).where( + cls.model.id == id).execute() diff --git a/api/db/services/user_service.py b/api/db/services/user_service.py index 3764acc..1ddfa01 100644 --- a/api/db/services/user_service.py +++ b/api/db/services/user_service.py @@ -92,6 +92,12 @@ class TenantService(CommonService): .join(UserTenant, on=((cls.model.id == UserTenant.tenant_id) & (UserTenant.user_id==user_id) & (UserTenant.status == StatusEnum.VALID.value) & (UserTenant.role==UserTenantRole.NORMAL.value)))\ .where(cls.model.status == StatusEnum.VALID.value).dicts()) + @classmethod + @DB.connection_context() + def decrease(cls, user_id, num): + num = cls.model.update(credit=cls.model.credit - num).where( + cls.model.id == user_id).execute() + if num == 0: raise LookupError("Tenant not found which is supposed to be there") class UserTenantService(CommonService): model = UserTenant diff --git a/api/errors/__init__.py b/api/errors/__init__.py deleted file mode 100644 index 10a03cc..0000000 --- a/api/errors/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -from .general_error import * - - -class RagFlowError(Exception): - message = 'Unknown Rag Flow Error' - - def __init__(self, message=None, *args, **kwargs): - message = str(message) if message is not None else self.message - message = message.format(*args, **kwargs) - super().__init__(message) \ No newline at end of file diff --git a/api/errors/error_services.py b/api/errors/error_services.py deleted file mode 100644 index 73b383b..0000000 --- a/api/errors/error_services.py +++ /dev/null @@ -1,13 +0,0 @@ -from api.errors import RagFlowError - -__all__ = ['ServicesError', 'ServiceNotSupported', 'ZooKeeperNotConfigured', - 'MissingZooKeeperUsernameOrPassword', 'ZooKeeperBackendError'] - - -class ServicesError(RagFlowError): - message = 'Unknown services error' - - -class ServiceNotSupported(ServicesError): - message = 'The service {service_name} is not supported' - diff --git a/api/errors/general_error.py b/api/errors/general_error.py deleted file mode 100644 index e87e54f..0000000 --- a/api/errors/general_error.py +++ /dev/null @@ -1,21 +0,0 @@ -# -# Copyright 2024 The InfiniFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# -class ParameterError(Exception): - pass - - -class PassError(Exception): - pass \ No newline at end of file diff --git a/api/hook/__init__.py b/api/hook/__init__.py deleted file mode 100644 index 9cbec58..0000000 --- a/api/hook/__init__.py +++ /dev/null @@ -1,57 +0,0 @@ -import importlib - -from api.hook.common.parameters import SignatureParameters, AuthenticationParameters, \ - SignatureReturn, AuthenticationReturn, PermissionReturn, ClientAuthenticationReturn, ClientAuthenticationParameters -from api.settings import HOOK_MODULE, stat_logger,RetCode - - -class HookManager: - SITE_SIGNATURE = [] - SITE_AUTHENTICATION = [] - CLIENT_AUTHENTICATION = [] - PERMISSION_CHECK = [] - - @staticmethod - def init(): - if HOOK_MODULE is not None: - for modules in HOOK_MODULE.values(): - for module in modules.split(";"): - try: - importlib.import_module(module) - except Exception as e: - stat_logger.exception(e) - - @staticmethod - def register_site_signature_hook(func): - HookManager.SITE_SIGNATURE.append(func) - - @staticmethod - def register_site_authentication_hook(func): - HookManager.SITE_AUTHENTICATION.append(func) - - @staticmethod - def register_client_authentication_hook(func): - HookManager.CLIENT_AUTHENTICATION.append(func) - - @staticmethod - def register_permission_check_hook(func): - HookManager.PERMISSION_CHECK.append(func) - - @staticmethod - def client_authentication(parm: ClientAuthenticationParameters) -> ClientAuthenticationReturn: - if HookManager.CLIENT_AUTHENTICATION: - return HookManager.CLIENT_AUTHENTICATION[0](parm) - return ClientAuthenticationReturn() - - @staticmethod - def site_signature(parm: SignatureParameters) -> SignatureReturn: - if HookManager.SITE_SIGNATURE: - return HookManager.SITE_SIGNATURE[0](parm) - return SignatureReturn() - - @staticmethod - def site_authentication(parm: AuthenticationParameters) -> AuthenticationReturn: - if HookManager.SITE_AUTHENTICATION: - return HookManager.SITE_AUTHENTICATION[0](parm) - return AuthenticationReturn() - diff --git a/api/hook/api/client_authentication.py b/api/hook/api/client_authentication.py deleted file mode 100644 index da126af..0000000 --- a/api/hook/api/client_authentication.py +++ /dev/null @@ -1,29 +0,0 @@ -import requests - -from api.db.service_registry import ServiceRegistry -from api.settings import RegistryServiceName -from api.hook import HookManager -from api.hook.common.parameters import ClientAuthenticationParameters, ClientAuthenticationReturn -from api.settings import HOOK_SERVER_NAME - - -@HookManager.register_client_authentication_hook -def authentication(parm: ClientAuthenticationParameters) -> ClientAuthenticationReturn: - service_list = ServiceRegistry.load_service( - server_name=HOOK_SERVER_NAME, - service_name=RegistryServiceName.CLIENT_AUTHENTICATION.value - ) - if not service_list: - raise Exception(f"client authentication error: no found server" - f" {HOOK_SERVER_NAME} service client_authentication") - service = service_list[0] - response = getattr(requests, service.f_method.lower(), None)( - url=service.f_url, - json=parm.to_dict() - ) - if response.status_code != 200: - raise Exception( - f"client authentication error: request authentication url failed, status code {response.status_code}") - elif response.json().get("code") != 0: - return ClientAuthenticationReturn(code=response.json().get("code"), message=response.json().get("msg")) - return ClientAuthenticationReturn() \ No newline at end of file diff --git a/api/hook/api/permission.py b/api/hook/api/permission.py deleted file mode 
100644 index 76bfc5f..0000000 --- a/api/hook/api/permission.py +++ /dev/null @@ -1,25 +0,0 @@ -import requests - -from api.db.service_registry import ServiceRegistry -from api.settings import RegistryServiceName -from api.hook import HookManager -from api.hook.common.parameters import PermissionCheckParameters, PermissionReturn -from api.settings import HOOK_SERVER_NAME - - -@HookManager.register_permission_check_hook -def permission(parm: PermissionCheckParameters) -> PermissionReturn: - service_list = ServiceRegistry.load_service(server_name=HOOK_SERVER_NAME, service_name=RegistryServiceName.PERMISSION_CHECK.value) - if not service_list: - raise Exception(f"permission check error: no found server {HOOK_SERVER_NAME} service permission") - service = service_list[0] - response = getattr(requests, service.f_method.lower(), None)( - url=service.f_url, - json=parm.to_dict() - ) - if response.status_code != 200: - raise Exception( - f"permission check error: request permission url failed, status code {response.status_code}") - elif response.json().get("code") != 0: - return PermissionReturn(code=response.json().get("code"), message=response.json().get("msg")) - return PermissionReturn() diff --git a/api/hook/api/site_authentication.py b/api/hook/api/site_authentication.py deleted file mode 100644 index 7c751d9..0000000 --- a/api/hook/api/site_authentication.py +++ /dev/null @@ -1,49 +0,0 @@ -import requests - -from api.db.service_registry import ServiceRegistry -from api.settings import RegistryServiceName -from api.hook import HookManager -from api.hook.common.parameters import SignatureParameters, AuthenticationParameters, AuthenticationReturn,\ - SignatureReturn -from api.settings import HOOK_SERVER_NAME, PARTY_ID - - -@HookManager.register_site_signature_hook -def signature(parm: SignatureParameters) -> SignatureReturn: - service_list = ServiceRegistry.load_service(server_name=HOOK_SERVER_NAME, service_name=RegistryServiceName.SIGNATURE.value) - if not service_list: - raise Exception(f"signature error: no found server {HOOK_SERVER_NAME} service signature") - service = service_list[0] - response = getattr(requests, service.f_method.lower(), None)( - url=service.f_url, - json=parm.to_dict() - ) - if response.status_code == 200: - if response.json().get("code") == 0: - return SignatureReturn(site_signature=response.json().get("data")) - else: - raise Exception(f"signature error: request signature url failed, result: {response.json()}") - else: - raise Exception(f"signature error: request signature url failed, status code {response.status_code}") - - -@HookManager.register_site_authentication_hook -def authentication(parm: AuthenticationParameters) -> AuthenticationReturn: - if not parm.src_party_id or str(parm.src_party_id) == "0": - parm.src_party_id = PARTY_ID - service_list = ServiceRegistry.load_service(server_name=HOOK_SERVER_NAME, - service_name=RegistryServiceName.SITE_AUTHENTICATION.value) - if not service_list: - raise Exception( - f"site authentication error: no found server {HOOK_SERVER_NAME} service site_authentication") - service = service_list[0] - response = getattr(requests, service.f_method.lower(), None)( - url=service.f_url, - json=parm.to_dict() - ) - if response.status_code != 200: - raise Exception( - f"site authentication error: request site_authentication url failed, status code {response.status_code}") - elif response.json().get("code") != 0: - return AuthenticationReturn(code=response.json().get("code"), message=response.json().get("msg")) - return 
AuthenticationReturn() \ No newline at end of file diff --git a/api/hook/common/parameters.py b/api/hook/common/parameters.py deleted file mode 100644 index 22893e8..0000000 --- a/api/hook/common/parameters.py +++ /dev/null @@ -1,56 +0,0 @@ -from api.settings import RetCode - - -class ParametersBase: - def to_dict(self): - d = {} - for k, v in self.__dict__.items(): - d[k] = v - return d - - -class ClientAuthenticationParameters(ParametersBase): - def __init__(self, full_path, headers, form, data, json): - self.full_path = full_path - self.headers = headers - self.form = form - self.data = data - self.json = json - - -class ClientAuthenticationReturn(ParametersBase): - def __init__(self, code=RetCode.SUCCESS, message="success"): - self.code = code - self.message = message - - -class SignatureParameters(ParametersBase): - def __init__(self, party_id, body): - self.party_id = party_id - self.body = body - - -class SignatureReturn(ParametersBase): - def __init__(self, code=RetCode.SUCCESS, site_signature=None): - self.code = code - self.site_signature = site_signature - - -class AuthenticationParameters(ParametersBase): - def __init__(self, site_signature, body): - self.site_signature = site_signature - self.body = body - - -class AuthenticationReturn(ParametersBase): - def __init__(self, code=RetCode.SUCCESS, message="success"): - self.code = code - self.message = message - - -class PermissionReturn(ParametersBase): - def __init__(self, code=RetCode.SUCCESS, message="success"): - self.code = code - self.message = message - - diff --git a/api/ragflow_server.py b/api/ragflow_server.py index 62ef541..f322b4e 100644 --- a/api/ragflow_server.py +++ b/api/ragflow_server.py @@ -20,12 +20,9 @@ import os import signal import sys import traceback - from werkzeug.serving import run_simple - from api.apps import app from api.db.runtime_config import RuntimeConfig -from api.hook import HookManager from api.settings import ( HOST, HTTP_PORT, access_logger, database_logger, stat_logger, ) @@ -60,8 +57,6 @@ if __name__ == '__main__': RuntimeConfig.init_env() RuntimeConfig.init_config(JOB_SERVER_HOST=HOST, HTTP_PORT=HTTP_PORT) - HookManager.init() - peewee_logger = logging.getLogger('peewee') peewee_logger.propagate = False # rag_arch.common.log.ROpenHandler diff --git a/api/settings.py b/api/settings.py index e42f393..23c7592 100644 --- a/api/settings.py +++ b/api/settings.py @@ -47,7 +47,7 @@ LLM = get_base_config("llm", {}) CHAT_MDL = LLM.get("chat_model", "gpt-3.5-turbo") EMBEDDING_MDL = LLM.get("embedding_model", "text-embedding-ada-002") ASR_MDL = LLM.get("asr_model", "whisper-1") -PARSERS = LLM.get("parsers", "General,Resume,Laws,Product Instructions,Books,Paper,Q&A,Programming Code,Power Point,Research Report") +PARSERS = LLM.get("parsers", "general:General,resume:Resume,laws:Laws,manual:Manual,book:Book,paper:Paper,qa:Q&A,presentation:Presentation") IMAGE2TEXT_MDL = LLM.get("image2text_model", "gpt-4-vision-preview") # distribution diff --git a/rag/app/book.py b/rag/app/book.py index a478f17..4b38d09 100644 --- a/rag/app/book.py +++ b/rag/app/book.py @@ -3,7 +3,7 @@ import random import re import numpy as np from rag.parser import bullets_category, BULLET_PATTERN, is_english, tokenize, remove_contents_table, \ - hierarchical_merge, make_colon_as_title, naive_merge + hierarchical_merge, make_colon_as_title, naive_merge, random_choices from rag.nlp import huqie from rag.parser.docx_parser import HuDocxParser from rag.parser.pdf_parser import HuParser @@ -51,7 +51,7 @@ def chunk(filename, binary=None, 
from_page=0, to_page=100000, callback=None, **k doc_parser = HuDocxParser() # TODO: table of contents need to be removed sections, tbls = doc_parser(binary if binary else filename, from_page=from_page, to_page=to_page) - remove_contents_table(sections, eng=is_english(random.choices([t for t,_ in sections], k=200))) + remove_contents_table(sections, eng=is_english(random_choices([t for t,_ in sections], k=200))) callback(0.8, "Finish parsing.") elif re.search(r"\.pdf$", filename, re.IGNORECASE): pdf_parser = Pdf() @@ -67,20 +67,20 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None, **k l = f.readline() if not l:break txt += l - sections = txt.split("\n") + sections = txt.split("\n") sections = [(l,"") for l in sections if l] - remove_contents_table(sections, eng = is_english(random.choices([t for t,_ in sections], k=200))) + remove_contents_table(sections, eng = is_english(random_choices([t for t,_ in sections], k=200))) callback(0.8, "Finish parsing.") else: raise NotImplementedError("file type not supported yet(docx, pdf, txt supported)") make_colon_as_title(sections) - bull = bullets_category([t for t in random.choices([t for t,_ in sections], k=100)]) + bull = bullets_category([t for t in random_choices([t for t,_ in sections], k=100)]) if bull >= 0: cks = hierarchical_merge(bull, sections, 3) else: cks = naive_merge(sections, kwargs.get("chunk_token_num", 256), kwargs.get("delimer", "\n。;！？")) sections = [t for t, _ in sections] # is it English - eng = is_english(random.choices(sections, k=218)) + eng = is_english(random_choices(sections, k=218)) res = [] # add tables diff --git a/rag/app/laws.py b/rag/app/laws.py index 7e9a964..ebeebf0 100644 --- a/rag/app/laws.py +++ b/rag/app/laws.py @@ -86,7 +86,8 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None, **k l = f.readline() if not l:break txt += l - sections = txt.split("\n") + sections = txt.split("\n") + sections = txt.split("\n") sections = [l for l in sections if l] callback(0.8, "Finish parsing.") else: raise NotImplementedError("file type not supported yet(docx, pdf, txt supported)") diff --git a/rag/app/naive.py b/rag/app/naive.py index 14bc1f8..178e016 100644 --- a/rag/app/naive.py +++ b/rag/app/naive.py @@ -52,7 +52,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None, **k l = f.readline() if not l:break txt += l - sections = txt.split("\n") + sections = txt.split("\n") sections = [(l,"") for l in sections if l] callback(0.8, "Finish parsing.") else: raise NotImplementedError("file type not supported yet(docx, pdf, txt supported)") diff --git a/rag/app/paper.py b/rag/app/paper.py index 131582f..3fca9a9 100644 --- a/rag/app/paper.py +++ b/rag/app/paper.py @@ -1,6 +1,9 @@ import copy import re from collections import Counter + +from api.db import ParserType +from rag.cv.ppdetection import PPDet from rag.parser import tokenize from rag.nlp import huqie from rag.parser.pdf_parser import HuParser @@ -9,6 +12,10 @@ from rag.utils import num_tokens_from_string class Pdf(HuParser): + def __init__(self): + self.model_speciess = ParserType.PAPER.value + super().__init__() + def __call__(self, filename, binary=None, from_page=0, to_page=100000, zoomin=3, callback=None): self.__images__( @@ -63,6 +70,15 @@ class Pdf(HuParser): "[0-9. 
一、i]*(introduction|abstract|摘要|引言|keywords|key words|关键词|background|背景|目录|前言|contents)", txt.lower().strip()) + if from_page > 0: + return { + "title":"", + "authors": "", + "abstract": "", + "lines": [(b["text"] + self._line_tag(b, zoomin), b.get("layoutno", "")) for b in self.boxes[i:] if + re.match(r"(text|title)", b.get("layoutno", "text"))], + "tables": tbls + } # get title and authors title = "" authors = [] @@ -115,18 +131,13 @@ class Pdf(HuParser): def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None, **kwargs): pdf_parser = None - paper = {} - if re.search(r"\.pdf$", filename, re.IGNORECASE): pdf_parser = Pdf() paper = pdf_parser(filename if not binary else binary, from_page=from_page, to_page=to_page, callback=callback) else: raise NotImplementedError("file type not supported yet(pdf supported)") - doc = { - "docnm_kwd": paper["title"] if paper["title"] else filename, - "authors_tks": paper["authors"] - } - doc["title_tks"] = huqie.qie(re.sub(r"\.[a-zA-Z]+$", "", doc["docnm_kwd"])) + doc = {"docnm_kwd": filename, "authors_tks": paper["authors"], + "title_tks": huqie.qie(paper["title"] if paper["title"] else filename)} doc["title_sm_tks"] = huqie.qieqie(doc["title_tks"]) doc["authors_sm_tks"] = huqie.qieqie(doc["authors_tks"]) # is it English diff --git a/rag/app/qa.py b/rag/app/qa.py index 4012984..fd4f568 100644 --- a/rag/app/qa.py +++ b/rag/app/qa.py @@ -3,7 +3,7 @@ import re from io import BytesIO from nltk import word_tokenize from openpyxl import load_workbook -from rag.parser import is_english +from rag.parser import is_english, random_choices from rag.nlp import huqie, stemmer @@ -33,9 +33,9 @@ class Excel(object): if len(res) % 999 == 0: callback(len(res)*0.6/total, ("Extract Q&A: {}".format(len(res)) + (f"{len(fails)} failure, line: %s..."%(",".join(fails[:3])) if fails else ""))) - callback(0.6, ("Extract Q&A: {}".format(len(res)) + ( + callback(0.6, ("Extract Q&A: {}. ".format(len(res)) + ( f"{len(fails)} failure, line: %s..." 
% (",".join(fails[:3])) if fails else ""))) - self.is_english = is_english([rmPrefix(q) for q, _ in random.choices(res, k=30) if len(q)>1]) + self.is_english = is_english([rmPrefix(q) for q, _ in random_choices(res, k=30) if len(q)>1]) return res diff --git a/rag/app/table.py b/rag/app/table.py new file mode 100644 index 0000000..aba795f --- /dev/null +++ b/rag/app/table.py @@ -0,0 +1,170 @@ +import copy +import random +import re +from io import BytesIO +from xpinyin import Pinyin +import numpy as np +import pandas as pd +from nltk import word_tokenize +from openpyxl import load_workbook +from dateutil.parser import parse as datetime_parse +from rag.parser import is_english, tokenize +from rag.nlp import huqie, stemmer + + +class Excel(object): + def __call__(self, fnm, binary=None, callback=None): + if not binary: + wb = load_workbook(fnm) + else: + wb = load_workbook(BytesIO(binary)) + total = 0 + for sheetname in wb.sheetnames: + total += len(list(wb[sheetname].rows)) + + res, fails, done = [], [], 0 + for sheetname in wb.sheetnames: + ws = wb[sheetname] + rows = list(ws.rows) + headers = [cell.value for cell in rows[0]] + missed = set([i for i,h in enumerate(headers) if h is None]) + headers = [cell.value for i,cell in enumerate(rows[0]) if i not in missed] + data = [] + for i, r in enumerate(rows[1:]): + row = [cell.value for ii,cell in enumerate(r) if ii not in missed] + if len(row) != len(headers): + fails.append(str(i)) + continue + data.append(row) + done += 1 + if done % 999 == 0: + callback(done * 0.6/total, ("Extract records: {}".format(len(res)) + (f"{len(fails)} failure({sheetname}), line: %s..."%(",".join(fails[:3])) if fails else ""))) + res.append(pd.DataFrame(np.array(data), columns=headers)) + + callback(0.6, ("Extract records: {}. ".format(done) + ( + f"{len(fails)} failure, line: %s..." 
% (",".join(fails[:3])) if fails else ""))) + return res + + +def trans_datatime(s): + try: + return datetime_parse(s.strip()).strftime("%Y-%m-%dT%H:%M:%S") + except Exception as e: + pass + + +def trans_bool(s): + if re.match(r"(true|yes|是)$", str(s).strip(), flags=re.IGNORECASE): return ["yes", "是"] + if re.match(r"(false|no|否)$", str(s).strip(), flags=re.IGNORECASE): return ["no", "否"] + + +def column_data_type(arr): + uni = len(set([a for a in arr if a is not None])) + counts = {"int": 0, "float": 0, "text": 0, "datetime": 0, "bool": 0} + trans = {t:f for f,t in [(int, "int"), (float, "float"), (trans_datatime, "datetime"), (trans_bool, "bool"), (str, "text")]} + for a in arr: + if a is None:continue + if re.match(r"[+-]?[0-9]+(\.0+)?$", str(a).replace("%%", "")): + counts["int"] += 1 + elif re.match(r"[+-]?[0-9.]+$", str(a).replace("%%", "")): + counts["float"] += 1 + elif re.match(r"(true|false|yes|no|是|否)$", str(a), flags=re.IGNORECASE): + counts["bool"] += 1 + elif trans_datatime(str(a)): + counts["datetime"] += 1 + else: counts["text"] += 1 + counts = sorted(counts.items(), key=lambda x: x[1]*-1) + ty = counts[0][0] + for i in range(len(arr)): + if arr[i] is None:continue + try: + arr[i] = trans[ty](str(arr[i])) + except Exception as e: + arr[i] = None + if ty == "text": + if len(arr) > 128 and uni/len(arr) < 0.1: + ty = "keyword" + return arr, ty + + +def chunk(filename, binary=None, callback=None, **kwargs): + dfs = [] + if re.search(r"\.xlsx?$", filename, re.IGNORECASE): + callback(0.1, "Start to parse.") + excel_parser = Excel() + dfs = excel_parser(filename, binary, callback) + elif re.search(r"\.(txt|csv)$", filename, re.IGNORECASE): + callback(0.1, "Start to parse.") + txt = "" + if binary: + txt = binary.decode("utf-8") + else: + with open(filename, "r") as f: + while True: + l = f.readline() + if not l: break + txt += l + lines = txt.split("\n") + fails = [] + headers = lines[0].split(kwargs.get("delimiter", "\t")) + rows = [] + for i, line in enumerate(lines[1:]): + row = [l for l in line.split(kwargs.get("delimiter", "\t"))] + if len(row) != len(headers): + fails.append(str(i)) + continue + rows.append(row) + if len(rows) % 999 == 0: + callback(len(rows) * 0.6 / len(lines), ("Extract records: {}".format(len(rows)) + ( + f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else ""))) + + callback(0.6, ("Extract records: {}".format(len(rows)) + ( + f"{len(fails)} failure, line: %s..." 
% (",".join(fails[:3])) if fails else ""))) + + dfs = [pd.DataFrame(np.array(rows), columns=headers)] + + else: raise NotImplementedError("file type not supported yet(excel, text, csv supported)") + + res = [] + PY = Pinyin() + fieds_map = {"text": "_tks", "int": "_int", "keyword": "_kwd", "float": "_flt", "datetime": "_dt", "bool": "_kwd"} + for df in dfs: + for n in ["id", "_id", "index", "idx"]: + if n in df.columns:del df[n] + clmns = df.columns.values + txts = list(copy.deepcopy(clmns)) + py_clmns = [PY.get_pinyins(n)[0].replace("-", "_") for n in clmns] + clmn_tys = [] + for j in range(len(clmns)): + cln,ty = column_data_type(df[clmns[j]]) + clmn_tys.append(ty) + df[clmns[j]] = cln + if ty == "text": txts.extend([str(c) for c in cln if c]) + clmns_map = [(py_clmns[j] + fieds_map[clmn_tys[j]], clmns[j]) for i in range(len(clmns))] + # TODO: set this column map to KB parser configuration + + eng = is_english(txts) + for ii,row in df.iterrows(): + d = {} + row_txt = [] + for j in range(len(clmns)): + if row[clmns[j]] is None:continue + fld = clmns_map[j][0] + d[fld] = row[clmns[j]] if clmn_tys[j] != "text" else huqie.qie(row[clmns[j]]) + row_txt.append("{}:{}".format(clmns[j], row[clmns[j]])) + if not row_txt:continue + tokenize(d, "; ".join(row_txt), eng) + print(d) + res.append(d) + callback(0.6, "") + + return res + + + +if __name__== "__main__": + import sys + def dummy(a, b): + pass + chunk(sys.argv[1], callback=dummy) + diff --git a/rag/nlp/search.py b/rag/nlp/search.py index d42909f..4c8c215 100644 --- a/rag/nlp/search.py +++ b/rag/nlp/search.py @@ -67,7 +67,7 @@ class Dealer: ps = int(req.get("size", 1000)) src = req.get("fields", ["docnm_kwd", "content_ltks", "kb_id", "img_id", "image_id", "doc_id", "q_512_vec", "q_768_vec", - "q_1024_vec", "q_1536_vec", "available_int"]) + "q_1024_vec", "q_1536_vec", "available_int", "content_with_weight"]) s = s.query(bqry)[pg * ps:(pg + 1) * ps] s = s.highlight("content_ltks") @@ -234,7 +234,7 @@ class Dealer: sres.field[i].get("q_%d_vec" % len(sres.query_vector), "\t".join(["0"] * len(sres.query_vector)))) for i in sres.ids] if not ins_embd: return [], [], [] - ins_tw = [huqie.qie(sres.field[i][cfield]).split(" ") + ins_tw = [sres.field[i][cfield].split(" ") for i in sres.ids] sim, tksim, vtsim = self.qryr.hybrid_similarity(sres.query_vector, ins_embd, @@ -281,6 +281,7 @@ class Dealer: d = { "chunk_id": id, "content_ltks": sres.field[id]["content_ltks"], + "content_with_weight": sres.field[id]["content_with_weight"], "doc_id": sres.field[id]["doc_id"], "docnm_kwd": dnm, "kb_id": sres.field[id]["kb_id"], diff --git a/rag/parser/__init__.py b/rag/parser/__init__.py index ed98049..d2b499c 100644 --- a/rag/parser/__init__.py +++ b/rag/parser/__init__.py @@ -1,4 +1,5 @@ import copy +import random from .pdf_parser import HuParser as PdfParser from .docx_parser import HuDocxParser as DocxParser @@ -38,6 +39,9 @@ BULLET_PATTERN = [[ ] ] +def random_choices(arr, k): + k = min(len(arr), k) + return random.choices(arr, k=k) def bullets_category(sections): global BULLET_PATTERN diff --git a/rag/parser/pdf_parser.py b/rag/parser/pdf_parser.py index 9cc3451..32cd636 100644 --- a/rag/parser/pdf_parser.py +++ b/rag/parser/pdf_parser.py @@ -1,7 +1,10 @@ # -*- coding: utf-8 -*- +import os import random +from functools import partial import fitz +import requests import xgboost as xgb from io import BytesIO import torch @@ -10,13 +13,14 @@ import pdfplumber import logging from PIL import Image import numpy as np + +from api.db import ParserType from rag.nlp 
import huqie from collections import Counter from copy import deepcopy -from rag.cv.table_recognize import TableTransformer -from rag.cv.ppdetection import PPDet from huggingface_hub import hf_hub_download + logging.getLogger("pdfminer").setLevel(logging.WARNING) @@ -25,8 +29,10 @@ class HuParser: from paddleocr import PaddleOCR logging.getLogger("ppocr").setLevel(logging.ERROR) self.ocr = PaddleOCR(use_angle_cls=False, lang="ch") - self.layouter = PPDet("/data/newpeak/medical-gpt/res/ppdet") - self.tbl_det = PPDet("/data/newpeak/medical-gpt/res/ppdet.tbl") + if not hasattr(self, "model_speciess"): + self.model_speciess = ParserType.GENERAL.value + self.layouter = partial(self.__remote_call, self.model_speciess) + self.tbl_det = partial(self.__remote_call, "table_component") self.updown_cnt_mdl = xgb.Booster() if torch.cuda.is_available(): @@ -45,6 +51,38 @@ class HuParser: """ + def __remote_call(self, species, images, thr=0.7): + url = os.environ.get("INFINIFLOW_SERVER") + if not url:raise EnvironmentError("Please set environment variable: 'INFINIFLOW_SERVER'") + token = os.environ.get("INFINIFLOW_TOKEN") + if not token:raise EnvironmentError("Please set environment variable: 'INFINIFLOW_TOKEN'") + + def convert_image_to_bytes(PILimage): + image = BytesIO() + PILimage.save(image, format='png') + image.seek(0) + return image.getvalue() + + images = [convert_image_to_bytes(img) for img in images] + + def remote_call(): + nonlocal images, thr + res = requests.post(url+"/v1/layout/detect/"+species, files=[("image", img) for img in images], data={"threashold": thr}, + headers={"Authorization": token}, timeout=len(images) * 10) + res = res.json() + if res["retcode"] != 0: raise RuntimeError(res["retmsg"]) + return res["data"] + + for _ in range(3): + try: + return remote_call() + except RuntimeError as e: + raise e + except Exception as e: + logging.error("layout_predict:"+str(e)) + return remote_call() + + def __char_width(self, c): return (c["x1"] - c["x0"]) // len(c["text"]) @@ -344,7 +382,7 @@ class HuParser: return layouts def __table_paddle(self, images): - tbls = self.tbl_det([np.array(img) for img in images], thr=0.5) + tbls = self.tbl_det(images, thr=0.5) res = [] # align left&right for rows, align top&bottom for columns for tbl in tbls: @@ -522,7 +560,7 @@ class HuParser: assert len(self.page_images) == len(self.boxes) # Tag layout type boxes = [] - layouts = self.layouter([np.array(img) for img in self.page_images]) + layouts = self.layouter(self.page_images) assert len(self.page_images) == len(layouts) for pn, lts in enumerate(layouts): bxs = self.boxes[pn] @@ -1705,7 +1743,8 @@ class HuParser: self.__ocr_paddle(i + 1, img, chars, zoomin) if not self.is_english and not any([c for c in self.page_chars]) and self.boxes: - self.is_english = re.search(r"[\na-zA-Z0-9,/¸;:'\[\]\(\)!@#$%^&*\"?<>._-]{30,}", "".join([b["text"] for b in random.choices([b for bxs in self.boxes for b in bxs], k=30)])) + bxes = [b for bxs in self.boxes for b in bxs] + self.is_english = re.search(r"[\na-zA-Z0-9,/¸;:'\[\]\(\)!@#$%^&*\"?<>._-]{30,}", "".join([b["text"] for b in random.choices(bxes, k=min(30, len(bxes)))])) logging.info("Is it English:", self.is_english) diff --git a/rag/svr/task_broker.py b/rag/svr/task_broker.py index 8b52e14..76aa7f9 100644 --- a/rag/svr/task_broker.py +++ b/rag/svr/task_broker.py @@ -134,5 +134,5 @@ if __name__ == "__main__": while True: dispatch() - time.sleep(3) + time.sleep(1) update_progress() diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py index 
4cc348f..e945dad 100644 --- a/rag/svr/task_executor.py +++ b/rag/svr/task_executor.py @@ -36,7 +36,7 @@ from rag.nlp import search from io import BytesIO import pandas as pd -from rag.app import laws, paper, presentation, manual, qa +from rag.app import laws, paper, presentation, manual, qa, table,book from api.db import LLMType, ParserType from api.db.services.document_service import DocumentService @@ -49,10 +49,12 @@ BATCH_SIZE = 64 FACTORY = { ParserType.GENERAL.value: laws, ParserType.PAPER.value: paper, + ParserType.BOOK.value: book, ParserType.PRESENTATION.value: presentation, ParserType.MANUAL.value: manual, ParserType.LAWS.value: laws, ParserType.QA.value: qa, + ParserType.TABLE.value: table, } @@ -66,7 +68,7 @@ def set_progress(task_id, from_page, to_page, prog=None, msg="Processing..."): d = {"progress_msg": msg} if prog is not None: d["progress"] = prog try: - TaskService.update_by_id(task_id, d) + TaskService.update_progress(task_id, d) except Exception as e: cron_logger.error("set_progress:({}), {}".format(task_id, str(e))) @@ -113,7 +115,7 @@ def build(row, cvmdl): return [] callback = partial(set_progress, row["id"], row["from_page"], row["to_page"]) - chunker = FACTORY[row["parser_id"]] + chunker = FACTORY[row["parser_id"].lower()] try: cron_logger.info("Chunkking {}/{}".format(row["location"], row["name"])) cks = chunker.chunk(row["name"], MINIO.get(row["kb_id"], row["location"]), row["from_page"], row["to_page"], @@ -154,6 +156,7 @@ def build(row, cvmdl): MINIO.put(row["kb_id"], d["_id"], output_buffer.getvalue()) d["img_id"] = "{}-{}".format(row["kb_id"], d["_id"]) + del d["image"] docs.append(d) return docs @@ -168,7 +171,7 @@ def init_kb(row): def embedding(docs, mdl): - tts, cnts = [d["docnm_kwd"] for d in docs if d.get("docnm_kwd")], [d["content_with_weight"] for d in docs] + tts, cnts = [rmSpace(d["title_tks"]) for d in docs if d.get("title_tks")], [d["content_with_weight"] for d in docs] tk_count = 0 if len(tts) == len(cnts): tts, c = mdl.encode(tts) @@ -207,6 +210,7 @@ def main(comm, mod): cks = build(r, cv_mdl) if not cks: tmf.write(str(r["update_time"]) + "\n") + callback(1., "No chunk! Done!") continue # TODO: exception handler ## set_progress(r["did"], -1, "ERROR: ") @@ -215,7 +219,6 @@ def main(comm, mod): except Exception as e: callback(-1, "Embedding error:{}".format(str(e))) cron_logger.error(str(e)) - continue callback(msg="Finished embedding! Start to build index!") init_kb(r) @@ -227,6 +230,7 @@ def main(comm, mod): else: if TaskService.do_cancel(r["id"]): ELASTICSEARCH.deleteByQuery(Q("match", doc_id=r["doc_id"]), idxnm=search.index_name(r["tenant_id"])) + continue callback(1., "Done!") DocumentService.increment_chunk_num(r["doc_id"], r["kb_id"], tk_count, chunk_count, 0) cron_logger.info("Chunk doc({}), token({}), chunks({})".format(r["id"], tk_count, len(cks))) -- GitLab
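
Note for reviewers: a minimal sketch of calling the new remote layout-detection API outside of HuParser, useful for testing the endpoint in isolation. The endpoint path /v1/layout/detect/<species>, the multipart "image" parts, the "threashold" form field (spelled as in the patch), the Authorization header, and the INFINIFLOW_SERVER / INFINIFLOW_TOKEN environment variables all come from __remote_call in rag/parser/pdf_parser.py above; the standalone function name and the sample file names are made up for illustration.

import os
import requests

def detect_layout(image_paths, species="general", thr=0.7):
    # Endpoint and credentials, read the same way rag/parser/pdf_parser.py reads them.
    url = os.environ["INFINIFLOW_SERVER"] + "/v1/layout/detect/" + species
    token = os.environ["INFINIFLOW_TOKEN"]
    # One "image" multipart part per page image, mirroring the files= list in __remote_call.
    files = [("image", open(p, "rb").read()) for p in image_paths]
    res = requests.post(url, files=files,
                        data={"threashold": thr},  # field name spelled as in the patch
                        headers={"Authorization": token},
                        timeout=len(image_paths) * 10)
    res = res.json()
    if res["retcode"] != 0:
        raise RuntimeError(res["retmsg"])
    return res["data"]  # per-image layout boxes

# Usage (file names hypothetical):
#   boxes = detect_layout(["page-1.png", "page-2.png"], species="paper")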