From 6224edcd1bb67d0c75c156ca31c1622cd772dd2e Mon Sep 17 00:00:00 2001
From: KevinHuSh <kevinhu.sh@gmail.com>
Date: Wed, 31 Jan 2024 19:57:45 +0800
Subject: [PATCH] Add task module, and pipeline tasks through every parser (#49)

---
 api/apps/document_app.py                     |  24 ++-
 api/db/__init__.py                           |  11 ++
 api/db/db_models.py                          |  80 ++------
 api/db/db_utils.py                           |  18 +-
 api/db/services/common_service.py            |   1 +
 api/db/services/document_service.py          |  16 +-
 api/db/services/task_service.py              |  53 +++++
 rag/app/__init__.py                          |   4 +-
 rag/app/laws.py                              |  18 +-
 rag/app/manual.py                            |  23 +--
 rag/app/paper.py                             |  14 +-
 rag/app/presentation.py                      |  17 +-
 rag/parser/pdf_parser.py                     |   9 +
 rag/svr/task_broker.py                       | 130 ++++++++++++
 .../{parse_user_docs.py => task_executor.py} | 186 +++++++-----------
 15 files changed, 368 insertions(+), 236 deletions(-)
 create mode 100644 api/db/services/task_service.py
 create mode 100644 rag/svr/task_broker.py
 rename rag/svr/{parse_user_docs.py => task_executor.py} (53%)

diff --git a/api/apps/document_app.py b/api/apps/document_app.py
index c8f8384..1a57aa0 100644
--- a/api/apps/document_app.py
+++ b/api/apps/document_app.py
@@ -22,6 +22,8 @@ from elasticsearch_dsl import Q
 from flask import request
 from flask_login import login_required, current_user
 
+from api.db.db_models import Task
+from api.db.services.task_service import TaskService
 from rag.nlp import search
 from rag.utils import ELASTICSEARCH
 from api.db.services import duplicate_name
@@ -205,6 +207,26 @@ def rm():
         return server_error_response(e)
 
 
+@manager.route('/run', methods=['POST'])
+@login_required
+@validate_request("doc_ids", "run")
+def run():
+    req = request.json
+    try:
+        for id in req["doc_ids"]:
+            DocumentService.update_by_id(id, {"run": str(req["run"])})
+            if req["run"] == "2":
+                TaskService.filter_delete([Task.doc_id == id])
+            tenant_id = DocumentService.get_tenant_id(id)
+            if not tenant_id:
+                return get_data_error_result(retmsg="Tenant not found!")
+            ELASTICSEARCH.deleteByQuery(Q("match", doc_id=id), idxnm=search.index_name(tenant_id))
+
+        return get_json_result(data=True)
+    except Exception as e:
+        return server_error_response(e)
+
+
 @manager.route('/rename', methods=['POST'])
 @login_required
 @validate_request("doc_id", "name", "old_name")
@@ -262,7 +284,7 @@ def change_parser():
         if doc.parser_id.lower() == req["parser_id"].lower():
             return get_json_result(data=True)
 
-        e = DocumentService.update_by_id(doc.id, {"parser_id": req["parser_id"], "progress": 0, "progress_msg": ""})
+        e = DocumentService.update_by_id(doc.id, {"parser_id": req["parser_id"], "progress": 0, "progress_msg": "", "run": "1"})
         if not e:
             return get_data_error_result(retmsg="Document not found!")
         e = DocumentService.increment_chunk_num(doc.id, doc.kb_id, doc.token_num*-1, doc.chunk_num*-1, doc.process_duation*-1)

diff --git a/api/db/__init__.py b/api/db/__init__.py
index 4979634..c84ee15 100644
--- a/api/db/__init__.py
+++ b/api/db/__init__.py
@@ -59,3 +59,14 @@ class ChatStyle(StrEnum):
     PRECISE = 'Precise'
     EVENLY = 'Evenly'
     CUSTOM = 'Custom'
+
+
+class ParserType(StrEnum):
+    GENERAL = "general"
+    PRESENTATION = "presentation"
+    LAWS = "laws"
+    MANUAL = "manual"
+    PAPER = "paper"
+    RESUME = ""  # placeholder, parser not implemented yet
+    BOOK = ""    # placeholder, parser not implemented yet
+    QA = ""      # placeholder, parser not implemented yet
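The new `/run` endpoint toggles processing per document: "1" starts it, "2" cancels, deletes the document's pending tasks, and wipes its indexed chunks. A minimal client sketch; the base URL, route prefix, and authentication are hypothetical here (only the `doc_ids`/`run` payload keys come from the patch):

```python
import requests  # any HTTP client works; requests is just an example

session = requests.Session()
# authentication elided: @login_required guards the route, so a valid
# session cookie is assumed to be present on `session`.
resp = session.post(
    "http://127.0.0.1:9380/v1/document/run",  # hypothetical host and prefix
    json={"doc_ids": ["<doc-id-1>", "<doc-id-2>"],
          "run": "2"},                        # "1" = run, "2" = cancel
)
print(resp.json())  # {"data": true} on success
```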
help_text="process message", default="") + progress_msg = CharField(max_length=512, null=True, help_text="process message", default="") process_begin_at = DateTimeField(null=True) process_duation = FloatField(default=0) + run = CharField(max_length=1, null=True, help_text="start to run processing or cancel.(1: run it; 2: cancel)", default="0") status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1") class Meta: db_table = "document" +class Task(DataBaseModel): + id = CharField(max_length=32, primary_key=True) + doc_id = CharField(max_length=32, null=False, index=True) + from_page = IntegerField(default=0) + to_page = IntegerField(default=-1) + begin_at = DateTimeField(null=True) + process_duation = FloatField(default=0) + progress = FloatField(default=0) + progress_msg = CharField(max_length=255, null=True, help_text="process message", default="") + + class Dialog(DataBaseModel): id = CharField(max_length=32, primary_key=True) tenant_id = CharField(max_length=32, null=False) @@ -553,72 +565,6 @@ class Conversation(DataBaseModel): """ -class Job(DataBaseModel): - # multi-party common configuration - f_user_id = CharField(max_length=25, null=True) - f_job_id = CharField(max_length=25, index=True) - f_name = CharField(max_length=500, null=True, default='') - f_description = TextField(null=True, default='') - f_tag = CharField(max_length=50, null=True, default='') - f_dsl = JSONField() - f_runtime_conf = JSONField() - f_runtime_conf_on_party = JSONField() - f_train_runtime_conf = JSONField(null=True) - f_roles = JSONField() - f_initiator_role = CharField(max_length=50) - f_initiator_party_id = CharField(max_length=50) - f_status = CharField(max_length=50) - f_status_code = IntegerField(null=True) - f_user = JSONField() - # this party configuration - f_role = CharField(max_length=50, index=True) - f_party_id = CharField(max_length=10, index=True) - f_is_initiator = BooleanField(null=True, default=False) - f_progress = IntegerField(null=True, default=0) - f_ready_signal = BooleanField(default=False) - f_ready_time = BigIntegerField(null=True) - f_cancel_signal = BooleanField(default=False) - f_cancel_time = BigIntegerField(null=True) - f_rerun_signal = BooleanField(default=False) - f_end_scheduling_updates = IntegerField(null=True, default=0) - - f_engine_name = CharField(max_length=50, null=True) - f_engine_type = CharField(max_length=10, null=True) - f_cores = IntegerField(default=0) - f_memory = IntegerField(default=0) # MB - f_remaining_cores = IntegerField(default=0) - f_remaining_memory = IntegerField(default=0) # MB - f_resource_in_use = BooleanField(default=False) - f_apply_resource_time = BigIntegerField(null=True) - f_return_resource_time = BigIntegerField(null=True) - - f_inheritance_info = JSONField(null=True) - f_inheritance_status = CharField(max_length=50, null=True) - - f_start_time = BigIntegerField(null=True) - f_start_date = DateTimeField(null=True) - f_end_time = BigIntegerField(null=True) - f_end_date = DateTimeField(null=True) - f_elapsed = BigIntegerField(null=True) - - class Meta: - db_table = "t_job" - primary_key = CompositeKey('f_job_id', 'f_role', 'f_party_id') - - - -class PipelineComponentMeta(DataBaseModel): - f_model_id = CharField(max_length=100, index=True) - f_model_version = CharField(max_length=100, index=True) - f_role = CharField(max_length=50, index=True) - f_party_id = CharField(max_length=10, index=True) - f_component_name = CharField(max_length=100, index=True) - f_component_module_name = 
diff --git a/api/db/db_utils.py b/api/db/db_utils.py
index f049eab..c5ad024 100644
--- a/api/db/db_utils.py
+++ b/api/db/db_utils.py
@@ -32,19 +32,19 @@ LOGGER = getLogger()
 def bulk_insert_into_db(model, data_source, replace_on_conflict=False):
     DB.create_tables([model])
 
-    current_time = current_timestamp()
-    current_date = timestamp_to_date(current_time)
     for data in data_source:
-        if 'f_create_time' not in data:
-            data['f_create_time'] = current_time
-        data['f_create_date'] = timestamp_to_date(data['f_create_time'])
-        data['f_update_time'] = current_time
-        data['f_update_date'] = current_date
+        current_time = current_timestamp()
+        current_date = timestamp_to_date(current_time)
+        if 'create_time' not in data:
+            data['create_time'] = current_time
+        data['create_date'] = timestamp_to_date(data['create_time'])
+        data['update_time'] = current_time
+        data['update_date'] = current_date
 
-    preserve = tuple(data_source[0].keys() - {'f_create_time', 'f_create_date'})
+    preserve = tuple(data_source[0].keys() - {'create_time', 'create_date'})
 
-    batch_size = 50 if RuntimeConfig.USE_LOCAL_DATABASE else 1000
+    batch_size = 1000
 
     for i in range(0, len(data_source), batch_size):
         with DB.atomic():

diff --git a/api/db/services/common_service.py b/api/db/services/common_service.py
index a168e0f..6ae1c35 100644
--- a/api/db/services/common_service.py
+++ b/api/db/services/common_service.py
@@ -70,6 +70,7 @@ class CommonService:
     @DB.connection_context()
     def insert_many(cls, data_list, batch_size=100):
         with DB.atomic():
+            for d in data_list: d["create_time"] = datetime_format(datetime.now())
             for i in range(0, len(data_list), batch_size):
                 cls.model.insert_many(data_list[i:i + batch_size]).execute()

diff --git a/api/db/services/document_service.py b/api/db/services/document_service.py
index b4afdbe..aba5afd 100644
--- a/api/db/services/document_service.py
+++ b/api/db/services/document_service.py
@@ -61,8 +61,8 @@ class DocumentService(CommonService):
 
     @classmethod
     @DB.connection_context()
-    def get_newly_uploaded(cls, tm, mod, comm, items_per_page=64):
-        fields = [cls.model.id, cls.model.kb_id, cls.model.parser_id, cls.model.name, cls.model.location, cls.model.size, Knowledgebase.tenant_id, Tenant.embd_id, Tenant.img2txt_id, cls.model.update_time]
+    def get_newly_uploaded(cls, tm, mod=0, comm=1, items_per_page=64):
+        fields = [cls.model.id, cls.model.kb_id, cls.model.parser_id, cls.model.name, cls.model.type, cls.model.location, cls.model.size, Knowledgebase.tenant_id, Tenant.embd_id, Tenant.img2txt_id, Tenant.asr_id, cls.model.update_time]
         docs = cls.model.select(*fields) \
             .join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id)) \
             .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))\
@@ -76,6 +76,18 @@ class DocumentService(CommonService):
                 .paginate(1, items_per_page)
         return list(docs.dicts())
 
+    @classmethod
+    @DB.connection_context()
+    def get_unfinished_docs(cls):
+        fields = [cls.model.id, cls.model.process_begin_at]
+        docs = cls.model.select(*fields) \
+            .where(
+                cls.model.status == StatusEnum.VALID.value,
+                ~(cls.model.type == FileType.VIRTUAL.value),
+                cls.model.progress < 1,
+                cls.model.progress > 0)
+        return list(docs.dicts())
+
     @classmethod
     @DB.connection_context()
     def increment_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duation):
diff --git a/api/db/services/task_service.py b/api/db/services/task_service.py
new file mode 100644
index 0000000..e73759e
--- /dev/null
+++ b/api/db/services/task_service.py
@@ -0,0 +1,53 @@
+#
+#  Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
+#
+#  Licensed under the Apache License, Version 2.0 (the "License");
+#  you may not use this file except in compliance with the License.
+#  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+#
+from peewee import Expression
+from api.db.db_models import DB
+from api.db import StatusEnum, FileType
+from api.db.db_models import Task, Document, Knowledgebase, Tenant
+from api.db.services.common_service import CommonService
+
+
+class TaskService(CommonService):
+    model = Task
+
+    @classmethod
+    @DB.connection_context()
+    def get_tasks(cls, tm, mod=0, comm=1, items_per_page=64):
+        fields = [cls.model.id, cls.model.doc_id, cls.model.from_page, cls.model.to_page, Document.kb_id, Document.parser_id, Document.name, Document.type, Document.location, Document.size, Knowledgebase.tenant_id, Tenant.embd_id, Tenant.img2txt_id, Tenant.asr_id, cls.model.update_time]
+        docs = cls.model.select(*fields) \
+            .join(Document, on=(cls.model.doc_id == Document.id)) \
+            .join(Knowledgebase, on=(Document.kb_id == Knowledgebase.id)) \
+            .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))\
+            .where(
+                Document.status == StatusEnum.VALID.value,
+                ~(Document.type == FileType.VIRTUAL.value),
+                cls.model.progress == 0,
+                cls.model.update_time >= tm,
+                (Expression(cls.model.create_time, "%%", comm) == mod))\
+            .order_by(cls.model.update_time.asc())\
+            .paginate(1, items_per_page)
+        return list(docs.dicts())
+
+    @classmethod
+    @DB.connection_context()
+    def do_cancel(cls, id):
+        # the /run cancel handler deletes a document's tasks, so a missing row means "canceled"
+        try:
+            cls.model.get_by_id(id)
+            return False
+        except Exception:
+            pass
+        return True

diff --git a/rag/app/__init__.py b/rag/app/__init__.py
index 6d390ca..4b06f20 100644
--- a/rag/app/__init__.py
+++ b/rag/app/__init__.py
@@ -67,4 +67,6 @@ def tokenize(d, t, eng):
         d["content_ltks"] = " ".join([stemmer.stem(w) for w in word_tokenize(t)])
     else:
         d["content_ltks"] = huqie.qie(t)
-    d["content_sm_ltks"] = huqie.qieqie(d["content_ltks"])
\ No newline at end of file
+    d["content_sm_ltks"] = huqie.qieqie(d["content_ltks"])
+
+
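The `Expression(cls.model.create_time, "%%", comm) == mod` clause in `get_tasks` is how multiple executors share one task table without locking: each of `comm` processes claims only the tasks whose `create_time` falls in its residue class. A runnable sketch of the rule (the task dicts are illustrative):

```python
# Each rank claims a disjoint slice of the pending tasks.
tasks = [{"id": i, "create_time": 1706700000000 + i} for i in range(8)]

comm = 4  # number of executors (the MPI world size in task_executor.py)
for mod in range(comm):  # each executor's rank
    mine = [t["id"] for t in tasks if t["create_time"] % comm == mod]
    print(f"rank {mod} handles tasks {mine}")
# No two ranks ever select the same task, so no inter-process locking is needed.
```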
diff --git a/rag/app/laws.py b/rag/app/laws.py
index 465213e..34f12a3 100644
--- a/rag/app/laws.py
+++ b/rag/app/laws.py
@@ -32,14 +32,12 @@ class Pdf(HuParser):
                             zoomin,
                             from_page,
                             to_page)
-        callback__((min(to_page, self.total_page) - from_page) / self.total_page / 2,
-                   "Page {}~{}: OCR finished".format(from_page, min(to_page, self.total_page)), callback)
+        callback__(0.1, "OCR finished", callback)
 
         from timeit import default_timer as timer
         start = timer()
         self._layouts_paddle(zoomin)
-        callback__((min(to_page, self.total_page) - from_page) / self.total_page / 2,
-                   "Page {}~{}: Layout analysis finished".format(from_page, min(to_page, self.total_page)), callback)
+        callback__(0.77, "Layout analysis finished", callback)
         print("paddle layouts:", timer()-start)
         bxs = self.sort_Y_firstly(self.boxes, np.median(self.mean_height) / 3)
         # is it English
@@ -77,8 +75,7 @@ class Pdf(HuParser):
                 b["x1"] = max(b["x1"], b_["x1"])
                 bxs.pop(i + 1)
 
-        callback__((min(to_page, self.total_page) - from_page) / self.total_page / 2,
-                   "Page {}~{}: Text extraction finished".format(from_page, min(to_page, self.total_page)), callback)
+        callback__(0.8, "Text extraction finished", callback)
 
         return [b["text"] + self._line_tag(b, zoomin) for b in bxs]
 
@@ -92,14 +89,17 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None):
     pdf_parser = None
     sections = []
     if re.search(r"\.docx?$", filename, re.IGNORECASE):
+        callback__(0.1, "Start to parse.", callback)
         for txt in Docx()(filename, binary):
             sections.append(txt)
-    if re.search(r"\.pdf$", filename, re.IGNORECASE):
+        callback__(0.8, "Finish parsing.", callback)
+    elif re.search(r"\.pdf$", filename, re.IGNORECASE):
         pdf_parser = Pdf()
         for txt in pdf_parser(filename if not binary else binary,
                               from_page=from_page, to_page=to_page, callback=callback):
            sections.append(txt)
-    if re.search(r"\.txt$", filename, re.IGNORECASE):
+    elif re.search(r"\.txt$", filename, re.IGNORECASE):
+        callback__(0.1, "Start to parse.", callback)
         txt = ""
         if binary:txt = binary.decode("utf-8")
         else:
@@ -110,6 +110,8 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None):
                 txt += l
         sections = txt.split("\n")
         sections = [l for l in sections if l]
+        callback__(0.8, "Finish parsing.", callback)
+    else: raise NotImplementedError("file type not supported yet(docx, pdf, txt supported)")
     # is it English
     eng = is_english(sections)
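Every `rag/app` module exposes the same entry point, which is what lets the executor dispatch on `parser_id` alone. A minimal, hypothetical parser honoring that contract (the field names match what `tokenize`/the executor expect; everything else is a sketch):

```python
def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None):
    """Same signature as laws/paper/presentation/manual.chunk in this patch."""
    if callback:
        callback(0.1, "Start to parse.")
    text = (binary or open(filename, "rb").read()).decode("utf-8", "ignore")
    if callback:
        callback(0.8, "Finish parsing.")
    # each chunk is a dict of ES fields; the executor hashes content_with_weight
    # into the chunk id and tokenize() fills the *_ltks fields
    return [{"docnm_kwd": filename, "content_with_weight": line}
            for line in text.splitlines() if line.strip()]
```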
diff --git a/rag/app/manual.py b/rag/app/manual.py
index 420b678..43195d6 100644
--- a/rag/app/manual.py
+++ b/rag/app/manual.py
@@ -1,12 +1,8 @@
 import copy
 import re
-from collections import Counter
-from rag.app import callback__, bullets_category, BULLET_PATTERN, is_english, tokenize
-from rag.nlp import huqie, stemmer
-from rag.parser.docx_parser import HuDocxParser
+from rag.app import callback__, tokenize
+from rag.nlp import huqie
 from rag.parser.pdf_parser import HuParser
-from nltk.tokenize import word_tokenize
-import numpy as np
 from rag.utils import num_tokens_from_string
 
 
@@ -18,24 +14,19 @@ class Pdf(HuParser):
                             zoomin,
                             from_page,
                             to_page)
-        callback__((min(to_page, self.total_page) - from_page) / self.total_page / 4,
-                   "Page {}~{}: OCR finished".format(from_page, min(to_page, self.total_page)), callback)
+        callback__(0.2, "OCR finished.", callback)
 
         from timeit import default_timer as timer
         start = timer()
         self._layouts_paddle(zoomin)
-        callback__((min(to_page, self.total_page) - from_page) / self.total_page / 4,
-                   "Page {}~{}: Layout analysis finished".format(from_page, min(to_page, self.total_page)), callback)
+        callback__(0.5, "Layout analysis finished.", callback)
         print("paddle layouts:", timer() - start)
         self._table_transformer_job(zoomin)
-        callback__((min(to_page, self.total_page) - from_page) / self.total_page / 4,
-                   "Page {}~{}: Table analysis finished".format(from_page, min(to_page, self.total_page)), callback)
+        callback__(0.7, "Table analysis finished.", callback)
         self._text_merge()
-        column_width = np.median([b["x1"] - b["x0"] for b in self.boxes])
         self._concat_downward(concat_between_pages=False)
         self._filter_forpages()
-        callback__((min(to_page, self.total_page) - from_page) / self.total_page / 4,
-                   "Page {}~{}: Text merging finished".format(from_page, min(to_page, self.total_page)), callback)
+        callback__(0.77, "Text merging finished", callback)
         tbls = self._extract_table_figure(True, zoomin, False)
 
         # clean mess
@@ -71,6 +62,7 @@ class Pdf(HuParser):
                 b_["top"] = b["top"]
                 self.boxes.pop(i)
+        callback__(0.8, "Parsing finished", callback)
         for b in self.boxes:
             print(b["text"], b.get("layoutno"))
         print(tbls)
@@ -85,6 +77,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None):
         pdf_parser = Pdf()
         cks, tbls = pdf_parser(filename if not binary else binary,
                                from_page=from_page, to_page=to_page, callback=callback)
+    else: raise NotImplementedError("file type not supported yet(pdf supported)")
     doc = {
         "docnm_kwd": filename
     }

diff --git a/rag/app/paper.py b/rag/app/paper.py
index b9c4aed..eacbd15 100644
--- a/rag/app/paper.py
+++ b/rag/app/paper.py
@@ -18,24 +18,20 @@ class Pdf(HuParser):
                             zoomin,
                             from_page,
                             to_page)
-        callback__((min(to_page, self.total_page) - from_page) / self.total_page / 4,
-                   "Page {}~{}: OCR finished".format(from_page, min(to_page, self.total_page)), callback)
+        callback__(0.2, "OCR finished.", callback)
 
         from timeit import default_timer as timer
         start = timer()
         self._layouts_paddle(zoomin)
-        callback__((min(to_page, self.total_page) - from_page) / self.total_page / 4,
-                   "Page {}~{}: Layout analysis finished".format(from_page, min(to_page, self.total_page)), callback)
+        callback__(0.47, "Layout analysis finished", callback)
         print("paddle layouts:", timer() - start)
         self._table_transformer_job(zoomin)
-        callback__((min(to_page, self.total_page) - from_page) / self.total_page / 4,
-                   "Page {}~{}: Table analysis finished".format(from_page, min(to_page, self.total_page)), callback)
+        callback__(0.68, "Table analysis finished", callback)
         self._text_merge()
         column_width = np.median([b["x1"] - b["x0"] for b in self.boxes])
         self._concat_downward(concat_between_pages=False)
         self._filter_forpages()
-        callback__((min(to_page, self.total_page) - from_page) / self.total_page / 4,
-                   "Page {}~{}: Text merging finished".format(from_page, min(to_page, self.total_page)), callback)
+        callback__(0.75, "Text merging finished.", callback)
         tbls = self._extract_table_figure(True, zoomin, False)
 
         # clean mess
@@ -105,6 +101,7 @@ class Pdf(HuParser):
                     break
         if not abstr:
             i = 0
+        callback__(0.8, "Page {}~{}: Text merging finished".format(from_page, min(to_page, self.total_page)), callback)
         for b in self.boxes:
             print(b["text"], b.get("layoutno"))
         print(tbls)
@@ -126,6 +123,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None):
         pdf_parser = Pdf()
         paper = pdf_parser(filename if not binary else binary,
                            from_page=from_page, to_page=to_page, callback=callback)
+    else: raise NotImplementedError("file type not supported yet(pdf supported)")
     doc = {
         "docnm_kwd": paper["title"] if paper["title"] else filename,
         "authors_tks": paper["authors"]
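Note that the parsers now report fixed progress fractions (0.2 OCR, 0.5 layout, ...) instead of per-page ratios, because page ranges are owned by the task, not the parser. `callback__` itself lives in `rag/app/__init__.py` and is not part of this patch; from its call sites it presumably reduces to a None-safe forwarder, roughly:

```python
def callback__(prog, msg, callback=None):
    # a guess at the helper's shape, inferred from how the parsers call it
    if callback:
        callback(prog, msg)  # e.g. the executor's partial(set_progress, ...)
```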
diff --git a/rag/app/presentation.py b/rag/app/presentation.py
index 303af34..69c8777 100644
--- a/rag/app/presentation.py
+++ b/rag/app/presentation.py
@@ -42,10 +42,8 @@ class Ppt(object):
                     txt = self.__extract(shape)
                     if txt: texts.append(txt)
                 txts.append("\n".join(texts))
-            callback__((i+1)/self.total_page/2, "", callback)
 
-        callback__((min(to_page, self.total_page) - from_page) / self.total_page,
-                   "Page {}~{}: Text extraction finished".format(from_page, min(to_page, self.total_page)), callback)
+        callback__(0.5, "Text extraction finished.", callback)
         import aspose.slides as slides
         import aspose.pydrawing as drawing
         imgs = []
@@ -55,8 +53,7 @@ class Ppt(object):
                     slide.get_thumbnail(0.5, 0.5).save(buffered, drawing.imaging.ImageFormat.jpeg)
                     imgs.append(buffered.getvalue())
         assert len(imgs) == len(txts), "Slides text and image do not match: {} vs. {}".format(len(imgs), len(txts))
-        callback__((min(to_page, self.total_page) - from_page) / self.total_page,
-                   "Page {}~{}: Image extraction finished".format(from_page, min(to_page, self.total_page)), callback)
+        callback__(0.9, "Image extraction finished", callback)
         self.is_english = is_english(txts)
         return [(txts[i], imgs[i]) for i in range(len(txts))]
 
@@ -73,7 +70,7 @@ class Pdf(HuParser):
     def __call__(self, filename, binary=None, from_page=0, to_page=100000, zoomin=3, callback=None):
         self.__images__(filename if not binary else binary, zoomin, from_page, to_page)
-        callback__((min(to_page, self.total_page)-from_page) / self.total_page, "Page {}~{}: OCR finished".format(from_page, min(to_page, self.total_page)), callback)
+        callback__(0.8, "Page {}~{}: OCR finished".format(from_page, min(to_page, self.total_page)), callback)
         assert len(self.boxes) == len(self.page_images), "{} vs. {}".format(len(self.boxes), len(self.page_images))
         res = []
         #################### More precisely ###################
@@ -92,6 +89,7 @@ class Pdf(HuParser):
         for i in range(len(self.boxes)):
             lines = "\n".join([b["text"] for b in self.boxes[i] if not self.__garbage(b["text"])])
             res.append((lines, self.page_images[i]))
+        callback__(0.9, "Page {}~{}: Parsing finished".format(from_page, min(to_page, self.total_page)), callback)
         return res
 
 
@@ -104,13 +102,13 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None):
     res = []
     if re.search(r"\.pptx?$", filename, re.IGNORECASE):
         ppt_parser = Ppt()
-        for txt,img in ppt_parser(filename if not binary else binary, from_page, to_page, callback):
+        for txt,img in ppt_parser(filename if not binary else binary, from_page, 1000000, callback):
             d = copy.deepcopy(doc)
             d["image"] = img
             tokenize(d, txt, ppt_parser.is_english)
             res.append(d)
         return res
-    if re.search(r"\.pdf$", filename, re.IGNORECASE):
+    elif re.search(r"\.pdf$", filename, re.IGNORECASE):
         pdf_parser = Pdf()
         for txt,img in pdf_parser(filename if not binary else binary, from_page=from_page, to_page=to_page, callback=callback):
             d = copy.deepcopy(doc)
@@ -118,7 +116,8 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None):
             tokenize(d, txt, pdf_parser.is_english)
             res.append(d)
         return res
-    callback__(-1, "This kind of presentation document did not support yet!", callback)
+
+    raise NotImplementedError("file type not supported yet(pptx, pdf supported)")
 
 
 if __name__== "__main__":

diff --git a/rag/parser/pdf_parser.py b/rag/parser/pdf_parser.py
index 53cfbdb..0f9def0 100644
--- a/rag/parser/pdf_parser.py
+++ b/rag/parser/pdf_parser.py
@@ -1559,6 +1559,15 @@ class HuParser:
 
         return "\n\n".join(res)
 
+    @staticmethod
+    def total_page_number(fnm, binary=None):
+        try:
+            pdf = pdfplumber.open(fnm) if not binary else pdfplumber.open(BytesIO(binary))
+            return len(pdf.pages)
+        except Exception:
+            pdf = fitz.open(fnm) if not binary else fitz.open(stream=binary, filetype="pdf")  # stream takes the bytes, not the filename
+            return len(pdf)
+
     def __images__(self, fnm, zoomin=3, page_from=0, page_to=299):
         self.lefted_chars = []
         self.mean_height = []
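`total_page_number` feeds the broker (next file), which slices each PDF into 10-page tasks. A worked example of that slicing, runnable on its own:

```python
# A 23-page document yields page ranges [0,10), [10,20), [20,23): three Task rows.
pages = 23  # in dispatch() this comes from HuParser.total_page_number(...)
tasks = [{"from_page": p, "to_page": min(p + 10, pages)}
         for p in range(0, pages, 10)]
print(tasks)
# [{'from_page': 0, 'to_page': 10}, {'from_page': 10, 'to_page': 20},
#  {'from_page': 20, 'to_page': 23}]
```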
diff --git a/rag/svr/task_broker.py b/rag/svr/task_broker.py
new file mode 100644
index 0000000..f5ea4f4
--- /dev/null
+++ b/rag/svr/task_broker.py
@@ -0,0 +1,130 @@
+#
+#  Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
+#
+#  Licensed under the Apache License, Version 2.0 (the "License");
+#  you may not use this file except in compliance with the License.
+#  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+#
+import logging
+import os
+import time
+import random
+from timeit import default_timer as timer
+from api.db.db_models import Task
+from api.db.db_utils import bulk_insert_into_db
+from api.db.services.task_service import TaskService
+from rag.parser.pdf_parser import HuParser
+from rag.settings import cron_logger
+from rag.utils import MINIO
+from rag.utils import findMaxTm
+import pandas as pd
+from api.db import FileType
+from api.db.services.document_service import DocumentService
+from api.settings import database_logger
+from api.utils import get_format_time, get_uuid
+from api.utils.file_utils import get_project_base_directory
+
+
+def collect(tm):
+    docs = DocumentService.get_newly_uploaded(tm)
+    if len(docs) == 0:
+        return pd.DataFrame()
+    docs = pd.DataFrame(docs)
+    mtm = docs["update_time"].max()
+    cron_logger.info("TOTAL:{}, To:{}".format(len(docs), mtm))
+    return docs
+
+
+def set_dispatching(docid):
+    try:
+        DocumentService.update_by_id(
+            docid, {"progress": random.randint(0, 3) / 100.,
+                    "progress_msg": "Task dispatched...",
+                    "process_begin_at": get_format_time()
+                    })
+    except Exception as e:
+        cron_logger.error("set_dispatching:({}), {}".format(docid, str(e)))
+
+
+def dispatch():
+    tm_fnm = os.path.join(get_project_base_directory(), "rag/res", "broker.tm")
+    tm = findMaxTm(tm_fnm)
+    rows = collect(tm)
+    if len(rows) == 0:
+        return
+
+    tmf = open(tm_fnm, "a+")
+    for _, r in rows.iterrows():
+        try:
+            tsks = TaskService.query(doc_id=r["id"])
+            if tsks:
+                for t in tsks:
+                    TaskService.delete_by_id(t.id)
+        except Exception as e:
+            cron_logger.error("delete task exception:" + str(e))
+
+        def new_task():
+            nonlocal r
+            return {
+                "id": get_uuid(),
+                "doc_id": r["id"]
+            }
+
+        tsks = []
+        if r["type"] == FileType.PDF.value:
+            # split PDFs into 10-page tasks so they can be parsed in parallel
+            pages = HuParser.total_page_number(r["name"], MINIO.get(r["kb_id"], r["location"]))
+            for p in range(0, pages, 10):
+                task = new_task()
+                task["from_page"] = p
+                task["to_page"] = min(p + 10, pages)
+                tsks.append(task)
+        else:
+            tsks.append(new_task())
+        print(tsks)
+        bulk_insert_into_db(Task, tsks, True)
+        set_dispatching(r["id"])
+        tmf.write(str(r["update_time"]) + "\n")
+    tmf.close()
+
+
+def update_progress():
+    docs = DocumentService.get_unfinished_docs()
+    for d in docs:
+        try:
+            tsks = TaskService.query(doc_id=d["id"], order_by=Task.create_time)
+            if not tsks: continue
+            msg = []
+            prg = 0
+            finished = True
+            bad = 0
+            for t in tsks:
+                if 0 <= t.progress < 1: finished = False
+                prg += t.progress if t.progress >= 0 else 0
+                msg.append(t.progress_msg)
+                if t.progress == -1: bad += 1
+            prg /= len(tsks)
+            if finished and bad: prg = -1
+            msg = "\n".join(msg)
+            DocumentService.update_by_id(d["id"], {"progress": prg, "progress_msg": msg,
+                                                   "process_duation": time.time() - d["process_begin_at"].timestamp()})
+        except Exception as e:
+            cron_logger.error("fetch task exception:" + str(e))
+
+
+if __name__ == "__main__":
+    peewee_logger = logging.getLogger('peewee')
+    peewee_logger.propagate = False
+    peewee_logger.addHandler(database_logger.handlers[0])
+    peewee_logger.setLevel(database_logger.level)
+
+    while True:
+        dispatch()
+        time.sleep(3)
+        update_progress()
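The `broker.tm` file is a restart-safe watermark: the broker appends each dispatched row's `update_time` and on startup resumes from the largest one, so already-dispatched documents are not re-dispatched. `findMaxTm` (from `rag/utils`, not shown in this patch) presumably reduces to something like:

```python
def find_max_tm(path):
    # a guess at findMaxTm's shape: scan the watermark file for the largest timestamp
    m = 0
    try:
        with open(path, "r") as f:
            for line in f:
                line = line.strip()
                if line:
                    m = max(m, int(float(line)))
    except FileNotFoundError:
        pass  # first run: no watermark yet, start from 0
    return m
```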
diff --git a/rag/svr/parse_user_docs.py b/rag/svr/task_executor.py
similarity index 53%
rename from rag/svr/parse_user_docs.py
rename to rag/svr/task_executor.py
index 88bc585..ef40b3b 100644
--- a/rag/svr/parse_user_docs.py
+++ b/rag/svr/task_executor.py
@@ -19,49 +19,59 @@ import logging
 import os
 import hashlib
 import copy
-import time
-import random
 import re
+import sys
+from functools import partial
 from timeit import default_timer as timer
 
+from api.db.services.task_service import TaskService
 from rag.llm import EmbeddingModel, CvModel
 from rag.settings import cron_logger, DOC_MAXIMUM_SIZE
 from rag.utils import ELASTICSEARCH
 from rag.utils import MINIO
 from rag.utils import rmSpace, findMaxTm
-from rag.nlp import huchunk, huqie, search
+from rag.nlp import search
 from io import BytesIO
 import pandas as pd
-from elasticsearch_dsl import Q
-from PIL import Image
-from rag.parser import (
-    PdfParser,
-    DocxParser,
-    ExcelParser
-)
-from rag.nlp.huchunk import (
-    PdfChunker,
-    DocxChunker,
-    ExcelChunker,
-    PptChunker,
-    TextChunker
-)
-from api.db import LLMType
+
+from rag.app import laws, paper, presentation, manual
+
+from api.db import LLMType, ParserType
 from api.db.services.document_service import DocumentService
-from api.db.services.llm_service import TenantLLMService, LLMBundle
+from api.db.services.llm_service import LLMBundle
 from api.settings import database_logger
-from api.utils import get_format_time
 from api.utils.file_utils import get_project_base_directory
 
 BATCH_SIZE = 64
 
-PDF = PdfChunker(PdfParser())
-DOC = DocxChunker(DocxParser())
-EXC = ExcelChunker(ExcelParser())
-PPT = PptChunker()
+FACTORY = {
+    ParserType.GENERAL.value: laws,  # no dedicated general chunker yet; falls back to laws
+    ParserType.PAPER.value: paper,
+    ParserType.PRESENTATION.value: presentation,
+    ParserType.MANUAL.value: manual,
+    ParserType.LAWS.value: laws,
+}
+
+
+def set_progress(task_id, from_page, to_page, prog=None, msg="Processing..."):
+    cancel = TaskService.do_cancel(task_id)
+    if cancel:
+        msg = "Canceled."
+        prog = -1
+
+    if to_page > 0: msg = f"Page({from_page}~{to_page}): " + msg
+    d = {"progress_msg": msg}
+    if prog is not None: d["progress"] = prog
+    try:
+        TaskService.update_by_id(task_id, d)
+    except Exception as e:
+        cron_logger.error("set_progress:({}), {}".format(task_id, str(e)))
+
+    if cancel: sys.exit()
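The executor binds a task's identity into `set_progress` once via `functools.partial`; after that, every parser just reports `(fraction, message)` without knowing which task it serves. A self-contained sketch of the pattern (the print stands in for the real DB update):

```python
from functools import partial

def set_progress(task_id, from_page, to_page, prog=None, msg="Processing..."):
    print(f"task {task_id} p{from_page}-{to_page}: {prog} {msg}")

callback = partial(set_progress, "task-42", 0, 10)  # bound once per task
callback(0.2, "OCR finished.")         # -> task task-42 p0-10: 0.2 OCR finished.
callback(msg="Embedding the content")  # message only, progress value unchanged
```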
+
+
+"""
 def chuck_doc(name, binary, tenant_id, cvmdl=None):
     suff = os.path.split(name)[-1].lower().split(".")[-1]
     if suff.find("pdf") >= 0:
@@ -81,27 +91,17 @@ def chuck_doc(name, binary, tenant_id, cvmdl=None):
             return field
 
     return TextChunker()(binary)
+"""
 
 
 def collect(comm, mod, tm):
-    docs = DocumentService.get_newly_uploaded(tm, mod, comm)
-    if len(docs) == 0:
+    tasks = TaskService.get_tasks(tm, mod, comm)
+    if len(tasks) == 0:
         return pd.DataFrame()
-    docs = pd.DataFrame(docs)
-    mtm = docs["update_time"].max()
-    cron_logger.info("TOTAL:{}, To:{}".format(len(docs), mtm))
-    return docs
-
-
-def set_progress(docid, prog, msg="Processing...", begin=False):
-    d = {"progress": prog, "progress_msg": msg}
-    if begin:
-        d["process_begin_at"] = get_format_time()
-    try:
-        DocumentService.update_by_id(
-            docid, {"progress": prog, "progress_msg": msg})
-    except Exception as e:
-        cron_logger.error("set_progress:({}), {}".format(docid, str(e)))
+    tasks = pd.DataFrame(tasks)
+    mtm = tasks["update_time"].max()
+    cron_logger.info("TOTAL:{}, To:{}".format(len(tasks), mtm))
+    return tasks
 
 
 def build(row, cvmdl):
@@ -110,97 +110,50 @@ def build(row, cvmdl):
                      (int(DOC_MAXIMUM_SIZE / 1024 / 1024)))
         return []
 
-    # res = ELASTICSEARCH.search(Q("term", doc_id=row["id"]))
-    # if ELASTICSEARCH.getTotal(res) > 0:
-    #     ELASTICSEARCH.updateScriptByQuery(Q("term", doc_id=row["id"]),
-    #                                       scripts="""
-    #                                       if(!ctx._source.kb_id.contains('%s'))
-    #                                           ctx._source.kb_id.add('%s');
-    #                                       """ % (str(row["kb_id"]), str(row["kb_id"])),
-    #                                       idxnm=search.index_name(row["tenant_id"])
-    #                                       )
-    #     set_progress(row["id"], 1, "Done")
-    #     return []
-
-    random.seed(time.time())
-    set_progress(row["id"], random.randint(0, 20) /
-                 100., "Finished preparing! Start to slice file!", True)
+    callback = partial(set_progress, row["id"], row["from_page"], row["to_page"])
+    chunker = FACTORY[row["parser_id"]]
     try:
         cron_logger.info("Chunking {}/{}".format(row["location"], row["name"]))
-        obj = chuck_doc(row["name"], MINIO.get(row["kb_id"], row["location"]), row["tenant_id"], cvmdl)
+        cks = chunker.chunk(row["name"], MINIO.get(row["kb_id"], row["location"]), row["from_page"], row["to_page"],
+                            callback)
     except Exception as e:
         if re.search("(No such file|not found)", str(e)):
-            set_progress(
-                row["id"], -1, "Can not find file <%s>" %
-                row["doc_name"])
+            callback(-1, "Can not find file <%s>" % row["doc_name"])
         else:
-            set_progress(
-                row["id"], -1, f"Internal server error: %s" %
-                str(e).replace(
-                    "'", ""))
+            callback(-1, "Internal server error: %s" % str(e).replace("'", ""))
         cron_logger.warn("Chunking {}/{}: {}".format(row["location"], row["name"], str(e)))
         return []
 
-    if not obj.text_chunks and not obj.table_chunks:
-        set_progress(
-            row["id"],
-            1,
-            "Nothing added! Mostly, file type unsupported yet.")
-        return []
-
-    set_progress(row["id"], random.randint(20, 60) / 100.,
-                 "Finished slicing files. Start to embedding the content.")
+    callback(msg="Finished slicing files. Start to embedding the content.")
 
+    docs = []
     doc = {
-        "doc_id": row["id"],
-        "kb_id": [str(row["kb_id"])],
-        "docnm_kwd": os.path.split(row["location"])[-1],
-        "title_tks": huqie.qie(row["name"])
+        "doc_id": row["doc_id"],
+        "kb_id": [str(row["kb_id"])]
     }
-    doc["title_sm_tks"] = huqie.qieqie(doc["title_tks"])
-    output_buffer = BytesIO()
-    docs = []
-    for txt, img in obj.text_chunks:
+    for ck in cks:
         d = copy.deepcopy(doc)
+        d.update(ck)
         md5 = hashlib.md5()
-        md5.update((txt + str(d["doc_id"])).encode("utf-8"))
+        md5.update((ck["content_with_weight"] + str(d["doc_id"])).encode("utf-8"))
         d["_id"] = md5.hexdigest()
-        d["content_ltks"] = huqie.qie(txt)
-        d["content_sm_ltks"] = huqie.qieqie(d["content_ltks"])
-        if not img:
+        d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
+        if not d.get("image"):
             docs.append(d)
             continue
 
-        if isinstance(img, bytes):
-            output_buffer = BytesIO(img)
+        output_buffer = BytesIO()
+        if isinstance(d["image"], bytes):
+            output_buffer = BytesIO(d["image"])
         else:
-            img.save(output_buffer, format='JPEG')
+            d["image"].save(output_buffer, format='JPEG')
 
         MINIO.put(row["kb_id"], d["_id"], output_buffer.getvalue())
         d["img_id"] = "{}-{}".format(row["kb_id"], d["_id"])
-        d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
         docs.append(d)
 
-    for arr, img in obj.table_chunks:
-        for i, txt in enumerate(arr):
-            d = copy.deepcopy(doc)
-            d["content_ltks"] = huqie.qie(txt)
-            md5 = hashlib.md5()
-            md5.update((txt + str(d["doc_id"])).encode("utf-8"))
-            d["_id"] = md5.hexdigest()
-            if not img:
-                docs.append(d)
-                continue
-            img.save(output_buffer, format='JPEG')
-            MINIO.put(row["kb_id"], d["_id"], output_buffer.getvalue())
-            d["img_id"] = "{}-{}".format(row["kb_id"], d["_id"])
-            d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
-            docs.append(d)
-
-    set_progress(row["id"], random.randint(60, 70) /
-                 100., "Continue embedding the content.")
-
     return docs
 
@@ -213,7 +166,7 @@ def init_kb(row):
 
 
 def embedding(docs, mdl):
-    tts, cnts = [rmSpace(d["title_tks"]) for d in docs], [rmSpace(d["content_ltks"]) for d in docs]
+    tts, cnts = [d["docnm_kwd"] for d in docs], [d["content_with_weight"] for d in docs]
     tk_count = 0
     tts, c = mdl.encode(tts)
     tk_count += c
@@ -223,7 +176,7 @@ def embedding(docs, mdl):
     assert len(vects) == len(docs)
     for i, d in enumerate(docs):
         v = vects[i].tolist()
-        d["q_%d_vec"%len(v)] = v
+        d["q_%d_vec" % len(v)] = v
     return tk_count
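Each chunk carries its embedding in a dimension-tagged field, so one ES index can host embedding models of different sizes. A tiny illustration of the naming scheme:

```python
v = [0.0] * 768                     # stand-in for one row of mdl.encode(...)
d = {"q_%d_vec" % len(v): v}
print(list(d))                      # ['q_768_vec']
```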
@@ -239,11 +192,12 @@ def main(comm, mod):
+        callback = partial(set_progress, r["id"], r["from_page"], r["to_page"])
         try:
             embd_mdl = LLMBundle(r["tenant_id"], LLMType.EMBEDDING)
             cv_mdl = LLMBundle(r["tenant_id"], LLMType.IMAGE2TEXT)
-            #TODO: sequence2text model
+            # TODO: sequence2text model
         except Exception as e:
-            set_progress(r["id"], -1, str(e))
+            callback(-1, str(e))  # the old positional call no longer matches set_progress's signature
             continue
 
         st_tm = timer()
         cks = build(r, cv_mdl)
         if not cks:
@@ -254,21 +208,20 @@ def main(comm, mod):
         try:
             tk_count = embedding(cks, embd_mdl)
         except Exception as e:
-            set_progress(r["id"], -1, "Embedding error:{}".format(str(e)))
+            callback(-1, "Embedding error:{}".format(str(e)))
             cron_logger.error(str(e))
             continue
 
-        set_progress(r["id"], random.randint(70, 95) / 100.,
-                     "Finished embedding! Start to build index!")
+        callback(msg="Finished embedding! Start to build index!")
         init_kb(r)
         chunk_count = len(set([c["_id"] for c in cks]))
         es_r = ELASTICSEARCH.bulk(cks, search.index_name(r["tenant_id"]))
         if es_r:
-            set_progress(r["id"], -1, "Index failure!")
+            callback(-1, "Index failure!")
             cron_logger.error(str(es_r))
         else:
-            set_progress(r["id"], 1., "Done!")
-            DocumentService.increment_chunk_num(r["id"], r["kb_id"], tk_count, chunk_count, timer()-st_tm)
+            callback(1., "Done!")  # report completion only after indexing succeeded
+            DocumentService.increment_chunk_num(r["doc_id"], r["kb_id"], tk_count, chunk_count, 0)
         cron_logger.info("Chunk doc({}), token({}), chunks({})".format(r["id"], tk_count, len(cks)))
 
         tmf.write(str(r["update_time"]) + "\n")
@@ -282,5 +235,6 @@ if __name__ == "__main__":
     peewee_logger.setLevel(database_logger.level)
 
     from mpi4py import MPI
+
     comm = MPI.COMM_WORLD
     main(comm.Get_size(), comm.Get_rank())
-- 
GitLab
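The executor shards tasks by MPI rank (see `TaskService.get_tasks` above), while the broker runs as a single looping process. A plausible local launch, assuming mpi4py and an MPI runtime are installed (the patch does not prescribe invocation commands):

```python
# Typical launch (paths assumed):
#   python rag/svr/task_broker.py                  # one broker
#   mpirun -n 4 python rag/svr/task_executor.py    # 4 executors, ranks 0..3
from mpi4py import MPI

comm = MPI.COMM_WORLD
# main(comm.Get_size(), comm.Get_rank()) polls collect(size, rank, tm) forever;
# each rank only ever sees tasks with create_time % size == rank.
print("rank", comm.Get_rank(), "of", comm.Get_size())
```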