diff --git a/api/apps/chunk_app.py b/api/apps/chunk_app.py index ef1704cd6976ee742a55a86a718bacce407a39b5..53caa76036ad6c01dd18fc546bfc99bf84360f5e 100644 --- a/api/apps/chunk_app.py +++ b/api/apps/chunk_app.py @@ -18,10 +18,12 @@ import datetime from flask import request from flask_login import login_required, current_user from elasticsearch_dsl import Q + +from rag.app.qa import rmPrefix, beAdoc from rag.nlp import search, huqie, retrievaler from rag.utils import ELASTICSEARCH, rmSpace -from api.db import LLMType -from api.db.services.kb_service import KnowledgebaseService +from api.db import LLMType, ParserType +from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.llm_service import TenantLLMService from api.db.services.user_service import UserTenantService from api.utils.api_utils import server_error_response, get_data_error_result, validate_request @@ -89,10 +91,8 @@ def get(): res["chunk_id"] = id k = [] for n in res.keys(): - if re.search(r"(_vec$|_sm_)", n): + if re.search(r"(_vec$|_sm_|_tks|_ltks)", n): k.append(n) - if re.search(r"(_tks|_ltks)", n): - res[n] = rmSpace(res[n]) for n in k: del res[n] @@ -106,12 +106,12 @@ def get(): @manager.route('/set', methods=['POST']) @login_required -@validate_request("doc_id", "chunk_id", "content_ltks", +@validate_request("doc_id", "chunk_id", "content_with_weight", "important_kwd") def set(): req = request.json d = {"id": req["chunk_id"]} - d["content_ltks"] = huqie.qie(req["content_ltks"]) + d["content_ltks"] = huqie.qie(req["content_with_weight"]) d["content_sm_ltks"] = huqie.qieqie(d["content_ltks"]) d["important_kwd"] = req["important_kwd"] d["important_tks"] = huqie.qie(" ".join(req["important_kwd"])) @@ -127,8 +127,15 @@ def set(): e, doc = DocumentService.get_by_id(req["doc_id"]) if not e: return get_data_error_result(retmsg="Document not found!") + + if doc.parser_id == ParserType.QA: + arr = [t for t in re.split(r"[\n\t]", req["content_with_weight"]) if len(t)>1] + if len(arr) != 2: return get_data_error_result(retmsg="Q&A must be separated by TAB/ENTER key.") + q, a = rmPrefix(arr[0]), rmPrefix(arr[1]) + d = beAdoc(d, arr[0], arr[1], not any([huqie.is_chinese(t) for t in q+a])) + v, c = embd_mdl.encode([doc.name, req["content_with_weight"]]) - v = 0.1 * v[0] + 0.9 * v[1] + v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1] d["q_%d_vec" % len(v)] = v.tolist() ELASTICSEARCH.upsert([d], search.index_name(tenant_id)) return get_json_result(data=True) diff --git a/api/apps/dialog_app.py b/api/apps/dialog_app.py index b78ef7307149790569c778e6d801dbc70731e393..d46a892e185ca11964749373fb1bbcfe4f2701bb 100644 --- a/api/apps/dialog_app.py +++ b/api/apps/dialog_app.py @@ -18,7 +18,7 @@ from flask import request from flask_login import login_required, current_user from api.db.services.dialog_service import DialogService from api.db import StatusEnum -from api.db.services.kb_service import KnowledgebaseService +from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.user_service import TenantService from api.utils.api_utils import server_error_response, get_data_error_result, validate_request from api.utils import get_uuid diff --git a/api/apps/document_app.py b/api/apps/document_app.py index 1a57aa0108173b75b124b6d06a2b0f1c3e939405..207ae84fc9704f4c2737b4f401dbc463d8c54948 100644 --- a/api/apps/document_app.py +++ b/api/apps/document_app.py @@ -27,10 +27,10 @@ from api.db.services.task_service import TaskService from rag.nlp import search from rag.utils
import ELASTICSEARCH from api.db.services import duplicate_name -from api.db.services.kb_service import KnowledgebaseService +from api.db.services.knowledgebase_service import KnowledgebaseService from api.utils.api_utils import server_error_response, get_data_error_result, validate_request from api.utils import get_uuid -from api.db import FileType +from api.db import FileType, TaskStatus from api.db.services.document_service import DocumentService from api.settings import RetCode from api.utils.api_utils import get_json_result @@ -210,13 +210,12 @@ def rm(): @manager.route('/run', methods=['POST']) @login_required @validate_request("doc_ids", "run") -def rm(): +def run(): req = request.json try: for id in req["doc_ids"]: - DocumentService.update_by_id(id, {"run": str(req["run"])}) - if req["run"] == "2": - TaskService.filter_delete([Task.doc_id == id]) + DocumentService.update_by_id(id, {"run": str(req["run"]), "progress": 0}) + if str(req["run"]) == TaskStatus.CANCEL.value: tenant_id = DocumentService.get_tenant_id(id) if not tenant_id: return get_data_error_result(retmsg="Tenant not found!") @@ -284,12 +283,13 @@ def change_parser(): if doc.parser_id.lower() == req["parser_id"].lower(): return get_json_result(data=True) - e = DocumentService.update_by_id(doc.id, {"parser_id": req["parser_id"], "progress":0, "progress_msg": "", "run": 1}) - if not e: - return get_data_error_result(retmsg="Document not found!") - e = DocumentService.increment_chunk_num(doc.id, doc.kb_id, doc.token_num*-1, doc.chunk_num*-1, doc.process_duation*-1) + e = DocumentService.update_by_id(doc.id, {"parser_id": req["parser_id"], "progress":0, "progress_msg": ""}) if not e: return get_data_error_result(retmsg="Document not found!") + if doc.token_num>0: + e = DocumentService.increment_chunk_num(doc.id, doc.kb_id, doc.token_num*-1, doc.chunk_num*-1, doc.process_duation*-1) + if not e: + return get_data_error_result(retmsg="Document not found!") return get_json_result(data=True) except Exception as e: diff --git a/api/apps/kb_app.py b/api/apps/kb_app.py index d2236db2e91a928c278634569ae23163cc5264bf..15e6be8cc0545ba51504d2a6b34a1143a130e36d 100644 --- a/api/apps/kb_app.py +++ b/api/apps/kb_app.py @@ -21,7 +21,7 @@ from api.db.services.user_service import TenantService, UserTenantService from api.utils.api_utils import server_error_response, get_data_error_result, validate_request from api.utils import get_uuid, get_format_time from api.db import StatusEnum, UserTenantRole -from api.db.services.kb_service import KnowledgebaseService +from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.db_models import Knowledgebase from api.settings import stat_logger, RetCode from api.utils.api_utils import get_json_result diff --git a/api/apps/llm_app.py b/api/apps/llm_app.py index 614a902e0c437b05a172323b47b20e839e5ed11c..43f31090ce65bc21517ff25d3bfee8772efc2b8d 100644 --- a/api/apps/llm_app.py +++ b/api/apps/llm_app.py @@ -22,7 +22,7 @@ from api.db.services.user_service import TenantService, UserTenantService from api.utils.api_utils import server_error_response, get_data_error_result, validate_request from api.utils import get_uuid, get_format_time from api.db import StatusEnum, UserTenantRole -from api.db.services.kb_service import KnowledgebaseService +from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.db_models import Knowledgebase, TenantLLM from api.settings import stat_logger, RetCode from api.utils.api_utils import get_json_result diff --git 
a/api/db/__init__.py b/api/db/__init__.py index c84ee15c56ba87c72c8452945e9bf26377d49917..de376134e3e1dc1a8a6d61b7fdabfc770c80c691 100644 --- a/api/db/__init__.py +++ b/api/db/__init__.py @@ -61,12 +61,19 @@ class ChatStyle(StrEnum): CUSTOM = 'Custom' +class TaskStatus(StrEnum): + RUNNING = "1" + CANCEL = "2" + DONE = "3" + FAIL = "4" + + class ParserType(StrEnum): GENERAL = "general" PRESENTATION = "presentation" LAWS = "laws" MANUAL = "manual" PAPER = "paper" - RESUME = "" - BOOK = "" - QA = "" + RESUME = "resume" + BOOK = "book" + QA = "qa" diff --git a/api/db/db_utils.py b/api/db/db_utils.py index c5ad0240e3a370de6bee5286605d75ad1f1d95e3..1e5a384baa8338e7604a86746fdaeab5d9a02e25 100644 --- a/api/db/db_utils.py +++ b/api/db/db_utils.py @@ -33,8 +33,8 @@ def bulk_insert_into_db(model, data_source, replace_on_conflict=False): DB.create_tables([model]) - for data in data_source: - current_time = current_timestamp() + for i,data in enumerate(data_source): + current_time = current_timestamp() + i current_date = timestamp_to_date(current_time) if 'create_time' not in data: data['create_time'] = current_time diff --git a/api/db/services/document_service.py b/api/db/services/document_service.py index aba5afd694af479b9da07cbe9ef663614861d46f..50b54abb5abb86e87b21c70176e043f8deaaf1b5 100644 --- a/api/db/services/document_service.py +++ b/api/db/services/document_service.py @@ -15,11 +15,11 @@ # from peewee import Expression -from api.db import TenantPermission, FileType +from api.db import TenantPermission, FileType, TaskStatus from api.db.db_models import DB, Knowledgebase, Tenant from api.db.db_models import Document from api.db.services.common_service import CommonService -from api.db.services.kb_service import KnowledgebaseService +from api.db.services.knowledgebase_service import KnowledgebaseService from api.db import StatusEnum @@ -71,6 +71,7 @@ class DocumentService(CommonService): ~(cls.model.type == FileType.VIRTUAL.value), cls.model.progress == 0, cls.model.update_time >= tm, + cls.model.run == TaskStatus.RUNNING.value, (Expression(cls.model.create_time, "%%", comm) == mod))\ .order_by(cls.model.update_time.asc())\ .paginate(1, items_per_page) diff --git a/api/db/services/knowledgebase_service.py b/api/db/services/knowledgebase_service.py index e849973c5cc8f5d7859a87d348ab3d4af9c785b4..63de6c2dada0e4a797ad669ee649614d8fca1da1 100644 --- a/api/db/services/knowledgebase_service.py +++ b/api/db/services/knowledgebase_service.py @@ -13,13 +13,52 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -from api.db.db_models import Knowledgebase, Document +from api.db import StatusEnum, TenantPermission +from api.db.db_models import Knowledgebase, DB, Tenant from api.db.services.common_service import CommonService class KnowledgebaseService(CommonService): model = Knowledgebase + @classmethod + @DB.connection_context() + def get_by_tenant_ids(cls, joined_tenant_ids, user_id, + page_number, items_per_page, orderby, desc): + kbs = cls.model.select().where( + ((cls.model.tenant_id.in_(joined_tenant_ids) & (cls.model.permission == + TenantPermission.TEAM.value)) | (cls.model.tenant_id == user_id)) + & (cls.model.status == StatusEnum.VALID.value) + ) + if desc: + kbs = kbs.order_by(cls.model.getter_by(orderby).desc()) + else: + kbs = kbs.order_by(cls.model.getter_by(orderby).asc()) -class DocumentService(CommonService): - model = Document + kbs = kbs.paginate(page_number, items_per_page) + + return list(kbs.dicts()) + + @classmethod + @DB.connection_context() + def get_detail(cls, kb_id): + fields = [ + cls.model.id, + Tenant.embd_id, + cls.model.avatar, + cls.model.name, + cls.model.description, + cls.model.permission, + cls.model.doc_num, + cls.model.token_num, + cls.model.chunk_num, + cls.model.parser_id] + kbs = cls.model.select(*fields).join(Tenant, on=((Tenant.id == cls.model.tenant_id)&(Tenant.status== StatusEnum.VALID.value))).where( + (cls.model.id == kb_id), + (cls.model.status == StatusEnum.VALID.value) + ) + if not kbs: + return + d = kbs[0].to_dict() + d["embd_id"] = kbs[0].tenant.embd_id + return d diff --git a/api/db/services/task_service.py b/api/db/services/task_service.py index e73759ebd929fb5da99279448cb413c6f06b0bdd..6cc62b2c92cb8437271099b41e7fb8bc41568a85 100644 --- a/api/db/services/task_service.py +++ b/api/db/services/task_service.py @@ -1,53 +1,55 @@ -# -# Copyright 2024 The InfiniFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# -from peewee import Expression -from api.db.db_models import DB -from api.db import StatusEnum, FileType -from api.db.db_models import Task, Document, Knowledgebase, Tenant -from api.db.services.common_service import CommonService - - -class TaskService(CommonService): - model = Task - - @classmethod - @DB.connection_context() - def get_tasks(cls, tm, mod=0, comm=1, items_per_page=64): - fields = [cls.model.id, cls.model.doc_id, cls.model.from_page,cls.model.to_page, Document.kb_id, Document.parser_id, Document.name, Document.type, Document.location, Document.size, Knowledgebase.tenant_id, Tenant.embd_id, Tenant.img2txt_id, Tenant.asr_id, cls.model.update_time] - docs = cls.model.select(*fields) \ - .join(Document, on=(cls.model.doc_id == Document.id)) \ - .join(Knowledgebase, on=(Document.kb_id == Knowledgebase.id)) \ - .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))\ - .where( - Document.status == StatusEnum.VALID.value, - ~(Document.type == FileType.VIRTUAL.value), - cls.model.progress == 0, - cls.model.update_time >= tm, - (Expression(cls.model.create_time, "%%", comm) == mod))\ - .order_by(cls.model.update_time.asc())\ - .paginate(1, items_per_page) - return list(docs.dicts()) - - - @classmethod - @DB.connection_context() - def do_cancel(cls, id): - try: - cls.model.get_by_id(id) - return False - except Exception as e: - pass - return True +# +# Copyright 2024 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +from peewee import Expression +from api.db.db_models import DB +from api.db import StatusEnum, FileType, TaskStatus +from api.db.db_models import Task, Document, Knowledgebase, Tenant +from api.db.services.common_service import CommonService +from api.db.services.document_service import DocumentService + + +class TaskService(CommonService): + model = Task + + @classmethod + @DB.connection_context() + def get_tasks(cls, tm, mod=0, comm=1, items_per_page=64): + fields = [cls.model.id, cls.model.doc_id, cls.model.from_page,cls.model.to_page, Document.kb_id, Document.parser_id, Document.name, Document.type, Document.location, Document.size, Knowledgebase.tenant_id, Tenant.embd_id, Tenant.img2txt_id, Tenant.asr_id, cls.model.update_time] + docs = cls.model.select(*fields) \ + .join(Document, on=(cls.model.doc_id == Document.id)) \ + .join(Knowledgebase, on=(Document.kb_id == Knowledgebase.id)) \ + .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))\ + .where( + Document.status == StatusEnum.VALID.value, + ~(Document.type == FileType.VIRTUAL.value), + cls.model.progress == 0, + cls.model.update_time >= tm, + (Expression(cls.model.create_time, "%%", comm) == mod))\ + .order_by(cls.model.update_time.asc())\ + .paginate(1, items_per_page) + return list(docs.dicts()) + + + @classmethod + @DB.connection_context() + def do_cancel(cls, id): + try: + task = cls.model.get_by_id(id) + _, doc = DocumentService.get_by_id(task.doc_id) + return doc.run == TaskStatus.CANCEL.value + except Exception as e: + pass + return True diff --git a/api/utils/file_utils.py b/api/utils/file_utils.py index 14a2e3c2d18f65fcb1276ece88e890035edfead0..92771f457cf2d0aaff2be3734b51c0536c0284d5 100644 --- a/api/utils/file_utils.py +++ b/api/utils/file_utils.py @@ -143,7 +143,7 @@ def filename_type(filename): if re.match(r".*\.pdf$", filename): return FileType.PDF.value - if re.match(r".*\.(docx|doc|ppt|pptx|yml|xml|htm|json|csv|txt|ini|xsl|wps|rtf|hlp|pages|numbers|key|md)$", filename): + if re.match(r".*\.(docx|doc|ppt|pptx|yml|xml|htm|json|csv|txt|ini|xls|xlsx|wps|rtf|hlp|pages|numbers|key|md)$", filename): return FileType.DOC.value if re.match(r".*\.(wav|flac|ape|alac|wavpack|wv|mp3|aac|ogg|vorbis|opus|mp3)$", filename): diff --git a/rag/app/__init__.py b/rag/app/__init__.py index 4b06f20a326fcc4ab51045c52c64c81fd6963677..06787b87c8e8c429aea4b631b2c77d380bff4c38 100644 --- a/rag/app/__init__.py +++ b/rag/app/__init__.py @@ -4,14 +4,8 @@ from nltk import word_tokenize from rag.nlp import stemmer, huqie - -def callback__(progress, msg, func): - if not func :return - func(progress, msg) - - BULLET_PATTERN = [[ - r"第[零一二三四五六七八九十百]+编", + r"第[零一二三四五六七八九十百]+(编|部分)", r"第[零一二三四五六七八九十百]+章", r"第[零一二三四五六七八九十百]+节", r"第[零一二三四五六七八九十百]+条", @@ -22,6 +16,8 @@ BULLET_PATTERN = [[ r"[0-9]{,2}\.[0-9]{,2}\.[0-9]{,2}", r"[0-9]{,2}\.[0-9]{,2}\.[0-9]{,2}\.[0-9]{,2}", ], [ + r"第[零一二三四五六七八九十百]+章", + r"第[零一二三四五六七八九十百]+节", r"[零一二三四五六七八九十百]+[ 、]", r"[\(（][零一二三四五六七八九十百]+[\)）]", r"[\(（][0-9]{,2}[\)）]", @@ -54,7 +50,7 @@ def bullets_category(sections): def is_english(texts): eng = 0 for t in texts: - if re.match(r"[a-zA-Z]", t.strip()): + if re.match(r"[a-zA-Z]{2,}", t.strip()): eng += 1 if eng / len(texts) > 0.8: return True @@ -70,3 +66,26 @@ def tokenize(d, t, eng): d["content_sm_ltks"] = huqie.qieqie(d["content_ltks"]) +def remove_contents_table(sections, eng=False): + i = 0 + while i < len(sections): + def get(i): + nonlocal sections + return (sections[i] if
type(sections[i]) == type("") else sections[i][0]).strip() + if not re.match(r"(contents|目录|目次|table of contents|致谢|acknowledge)$", re.sub(r"( | |\u3000)+", "", get(i).split("@@")[0]), re.IGNORECASE): + i += 1 + continue + sections.pop(i) + if i >= len(sections): break + prefix = get(i)[:3] if not eng else " ".join(get(i).split(" ")[:2]) + while not prefix: + sections.pop(i) + if i >= len(sections): break + prefix = get(i)[:3] if not eng else " ".join(get(i).split(" ")[:2]) + sections.pop(i) + if i >= len(sections) or not prefix: break + for j in range(i, min(i+128, len(sections))): + if not re.match(prefix, get(j)): + continue + for _ in range(i, j):sections.pop(i) + break \ No newline at end of file diff --git a/rag/app/book.py b/rag/app/book.py new file mode 100644 index 0000000000000000000000000000000000000000..59948ef2d3a17d1946097aba98688226b932eec0 --- /dev/null +++ b/rag/app/book.py @@ -0,0 +1,156 @@ +import copy +import random +import re +from io import BytesIO +from docx import Document +import numpy as np +from rag.app import bullets_category, BULLET_PATTERN, is_english, tokenize, remove_contents_table +from rag.nlp import huqie +from rag.parser.docx_parser import HuDocxParser +from rag.parser.pdf_parser import HuParser + + +class Pdf(HuParser): + def __call__(self, filename, binary=None, from_page=0, + to_page=100000, zoomin=3, callback=None): + self.__images__( + filename if not binary else binary, + zoomin, + from_page, + to_page) + callback(0.1, "OCR finished") + + from timeit import default_timer as timer + start = timer() + self._layouts_paddle(zoomin) + callback(0.47, "Layout analysis finished") + print("paddle layouts:", timer() - start) + self._table_transformer_job(zoomin) + callback(0.68, "Table analysis finished") + self._text_merge() + column_width = np.median([b["x1"] - b["x0"] for b in self.boxes]) + self._concat_downward(concat_between_pages=False) + self._filter_forpages() + self._merge_with_same_bullet() + callback(0.75, "Text merging finished.") + tbls = self._extract_table_figure(True, zoomin, False) + + callback(0.8, "Text extraction finished") + + return [(b["text"] + self._line_tag(b, zoomin), b.get("layoutno","")) for b in self.boxes] + + +def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None): + doc = { + "docnm_kwd": filename, + "title_tks": huqie.qie(re.sub(r"\.[a-zA-Z]+$", "", filename)) + } + doc["title_sm_tks"] = huqie.qieqie(doc["title_tks"]) + pdf_parser = None + sections,tbls = [], [] + if re.search(r"\.docx?$", filename, re.IGNORECASE): + callback(0.1, "Start to parse.") + doc_parser = HuDocxParser() + # TODO: table of contents need to be removed + sections, tbls = doc_parser(binary if binary else filename) + remove_contents_table(sections, eng = is_english(random.choices([t for t,_ in sections], k=200))) + callback(0.8, "Finish parsing.") + elif re.search(r"\.pdf$", filename, re.IGNORECASE): + pdf_parser = Pdf() + sections,tbls = pdf_parser(filename if not binary else binary, + from_page=from_page, to_page=to_page, callback=callback) + elif re.search(r"\.txt$", filename, re.IGNORECASE): + callback(0.1, "Start to parse.") + txt = "" + if binary:txt = binary.decode("utf-8") + else: + with open(filename, "r") as f: + while True: + l = f.readline() + if not l:break + txt += l + sections = txt.split("\n") + sections = [(l,"") for l in sections if l] + remove_contents_table(sections, eng = is_english(random.choices([t for t,_ in sections], k=200))) + callback(0.8, "Finish parsing.") + else: raise NotImplementedError("file type
not supported yet(docx, pdf, txt supported)") + + bull = bullets_category([b["text"] for b in random.choices([t for t,_ in sections], k=100)]) + projs = [len(BULLET_PATTERN[bull]) + 1] * len(sections) + levels = [[]] * len(BULLET_PATTERN[bull]) + 2 + for i, (txt, layout) in enumerate(sections): + for j, p in enumerate(BULLET_PATTERN[bull]): + if re.match(p, txt.strip()): + projs[i] = j + levels[j].append(i) + break + else: + if re.search(r"(title|head)", layout): + projs[i] = BULLET_PATTERN[bull] + levels[BULLET_PATTERN[bull]].append(i) + else: + levels[BULLET_PATTERN[bull] + 1].append(i) + sections = [t for t,_ in sections] + + def binary_search(arr, target): + if target > arr[-1]: return len(arr) - 1 + if target > arr[0]: return -1 + s, e = 0, len(arr) + while e - s > 1: + i = (e + s) // 2 + if target > arr[i]: + s = i + continue + elif target < arr[i]: + e = i + continue + else: + assert False + return s + + cks = [] + readed = [False] * len(sections) + levels = levels[::-1] + for i, arr in enumerate(levels): + for j in arr: + if readed[j]: continue + readed[j] = True + cks.append([j]) + if i + 1 == len(levels) - 1: continue + for ii in range(i + 1, len(levels)): + jj = binary_search(levels[ii], j) + if jj < 0: break + if jj > cks[-1][-1]: cks[-1].pop(-1) + cks[-1].append(levels[ii][jj]) + + # is it English + eng = is_english(random.choices(sections, k=218)) + + res = [] + # add tables + for img, rows in tbls: + bs = 10 + de = ";" if eng else ";" + for i in range(0, len(rows), bs): + d = copy.deepcopy(doc) + r = de.join(rows[i:i + bs]) + r = re.sub(r"\t——(来自| in ).*”%s" % de, "", r) + tokenize(d, r, eng) + d["image"] = img + res.append(d) + # wrap up to es documents + for ck in cks: + print("\n-".join(ck[::-1])) + ck = "\n".join(ck[::-1]) + d = copy.deepcopy(doc) + if pdf_parser: + d["image"] = pdf_parser.crop(ck) + ck = pdf_parser.remove_tag(ck) + tokenize(d, ck, eng) + res.append(d) + return res + + +if __name__ == "__main__": + import sys + chunk(sys.argv[1]) diff --git a/rag/app/laws.py b/rag/app/laws.py index 34f12a33890fb597442f512736227c8386f42576..c68d3b85bb07c125c49c0b6610788983486e9df8 100644 --- a/rag/app/laws.py +++ b/rag/app/laws.py @@ -3,7 +3,7 @@ import re from io import BytesIO from docx import Document import numpy as np -from rag.app import callback__, bullets_category, BULLET_PATTERN, is_english, tokenize +from rag.app import bullets_category, BULLET_PATTERN, is_english, tokenize from rag.nlp import huqie from rag.parser.docx_parser import HuDocxParser from rag.parser.pdf_parser import HuParser @@ -32,12 +32,12 @@ class Pdf(HuParser): zoomin, from_page, to_page) - callback__(0.1, "OCR finished", callback) + callback(0.1, "OCR finished") from timeit import default_timer as timer start = timer() self._layouts_paddle(zoomin) - callback__(0.77, "Layout analysis finished", callback) + callback(0.77, "Layout analysis finished") print("paddle layouts:", timer()-start) bxs = self.sort_Y_firstly(self.boxes, np.median(self.mean_height) / 3) # is it English @@ -75,7 +75,7 @@ class Pdf(HuParser): b["x1"] = max(b["x1"], b_["x1"]) bxs.pop(i + 1) - callback__(0.8, "Text extraction finished", callback) + callback(0.8, "Text extraction finished") return [b["text"] + self._line_tag(b, zoomin) for b in bxs] @@ -89,17 +89,17 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None): pdf_parser = None sections = [] if re.search(r"\.docx?$", filename, re.IGNORECASE): - callback__(0.1, "Start to parse.", callback) + callback(0.1, "Start to parse.") for txt in 
Docx()(filename, binary): sections.append(txt) - callback__(0.8, "Finish parsing.", callback) + callback(0.8, "Finish parsing.") elif re.search(r"\.pdf$", filename, re.IGNORECASE): pdf_parser = Pdf() for txt in pdf_parser(filename if not binary else binary, from_page=from_page, to_page=to_page, callback=callback): sections.append(txt) elif re.search(r"\.txt$", filename, re.IGNORECASE): - callback__(0.1, "Start to parse.", callback) + callback(0.1, "Start to parse.") txt = "" if binary:txt = binary.decode("utf-8") else: @@ -110,7 +110,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None): txt += l sections = txt.split("\n") sections = [l for l in sections if l] - callback__(0.8, "Finish parsing.", callback) + callback(0.8, "Finish parsing.") else: raise NotImplementedError("file type not supported yet(docx, pdf, txt supported)") # is it English @@ -118,7 +118,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None): # Remove 'Contents' part i = 0 while i < len(sections): - if not re.match(r"(Contents|目录|目次)$", re.sub(r"( | |\u3000)+", "", sections[i].split("@@")[0])): + if not re.match(r"(contents|目录|目次|table of contents)$", re.sub(r"( | |\u3000)+", "", sections[i].split("@@")[0]), re.IGNORECASE): i += 1 continue sections.pop(i) @@ -133,7 +133,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None): for j in range(i, min(i+128, len(sections))): if not re.match(prefix, sections[j]): continue - for k in range(i, j):sections.pop(i) + for _ in range(i, j):sections.pop(i) break bull = bullets_category(sections) diff --git a/rag/app/manual.py b/rag/app/manual.py index 43195d65a59b311ae253794fc9ca803ace1f361d..241fdd17934aa0472fd7adeeccd52954d6476346 100644 --- a/rag/app/manual.py +++ b/rag/app/manual.py @@ -1,6 +1,6 @@ import copy import re -from rag.app import callback__, tokenize +from rag.app import tokenize from rag.nlp import huqie from rag.parser.pdf_parser import HuParser from rag.utils import num_tokens_from_string @@ -14,19 +14,19 @@ class Pdf(HuParser): zoomin, from_page, to_page) - callback__(0.2, "OCR finished.", callback) + callback(0.2, "OCR finished.") from timeit import default_timer as timer start = timer() self._layouts_paddle(zoomin) - callback__(0.5, "Layout analysis finished.", callback) + callback(0.5, "Layout analysis finished.") print("paddle layouts:", timer() - start) self._table_transformer_job(zoomin) - callback__(0.7, "Table analysis finished.", callback) + callback(0.7, "Table analysis finished.") self._text_merge() self._concat_downward(concat_between_pages=False) self._filter_forpages() - callback__(0.77, "Text merging finished", callback) + callback(0.77, "Text merging finished") tbls = self._extract_table_figure(True, zoomin, False) # clean mess @@ -34,20 +34,8 @@ class Pdf(HuParser): b["text"] = re.sub(r"([\t  ]|\u3000){2,}", " ", b["text"].strip()) # merge chunks with the same bullets - i = 0 - while i + 1 < len(self.boxes): - b = self.boxes[i] - b_ = self.boxes[i + 1] - if b["text"].strip()[0] != b_["text"].strip()[0] \ - or b["page_number"]!=b_["page_number"] \ - or b["top"] > b_["bottom"]: - i += 1 - continue - b_["text"] = b["text"] + "\n" + b_["text"] - b_["x0"] = min(b["x0"], b_["x0"]) - b_["x1"] = max(b["x1"], b_["x1"]) - b_["top"] = b["top"] - self.boxes.pop(i) + self._merge_with_same_bullet() + # merge title with decent chunk i = 0 while i + 1 < len(self.boxes): @@ -62,7 +50,7 @@ class Pdf(HuParser): b_["top"] = b["top"] self.boxes.pop(i) - callback__(0.8, "Parsing
finished", callback) + callback(0.8, "Parsing finished") for b in self.boxes: print(b["text"], b.get("layoutno")) print(tbls) diff --git a/rag/app/paper.py b/rag/app/paper.py index eacbd151bb10dca8ed373d15bed5bac6f528bf1a..220852cc5d6be76d94ad25e93663146d659ffe31 100644 --- a/rag/app/paper.py +++ b/rag/app/paper.py @@ -1,11 +1,9 @@ import copy import re from collections import Counter -from rag.app import callback__, bullets_category, BULLET_PATTERN, is_english, tokenize -from rag.nlp import huqie, stemmer -from rag.parser.docx_parser import HuDocxParser +from rag.app import tokenize +from rag.nlp import huqie from rag.parser.pdf_parser import HuParser -from nltk.tokenize import word_tokenize import numpy as np from rag.utils import num_tokens_from_string @@ -18,20 +16,20 @@ class Pdf(HuParser): zoomin, from_page, to_page) - callback__(0.2, "OCR finished.", callback) + callback(0.2, "OCR finished.") from timeit import default_timer as timer start = timer() self._layouts_paddle(zoomin) - callback__(0.47, "Layout analysis finished", callback) + callback(0.47, "Layout analysis finished") print("paddle layouts:", timer() - start) self._table_transformer_job(zoomin) - callback__(0.68, "Table analysis finished", callback) + callback(0.68, "Table analysis finished") self._text_merge() column_width = np.median([b["x1"] - b["x0"] for b in self.boxes]) self._concat_downward(concat_between_pages=False) self._filter_forpages() - callback__(0.75, "Text merging finished.", callback) + callback(0.75, "Text merging finished.") tbls = self._extract_table_figure(True, zoomin, False) # clean mess @@ -101,7 +99,7 @@ class Pdf(HuParser): break if not abstr: i = 0 - callback__(0.8, "Page {}~{}: Text merging finished".format(from_page, min(to_page, self.total_page)), callback) + callback(0.8, "Page {}~{}: Text merging finished".format(from_page, min(to_page, self.total_page))) for b in self.boxes: print(b["text"], b.get("layoutno")) print(tbls) diff --git a/rag/app/presentation.py b/rag/app/presentation.py index 69c87778f6fae79c572bf9ff034bc6a9c90954dd..0495adb7242c290a0778405c16c87491a69547f4 100644 --- a/rag/app/presentation.py +++ b/rag/app/presentation.py @@ -3,7 +3,7 @@ import re from io import BytesIO from pptx import Presentation -from rag.app import callback__, tokenize, is_english +from rag.app import tokenize, is_english from rag.nlp import huqie from rag.parser.pdf_parser import HuParser @@ -43,7 +43,7 @@ class Ppt(object): if txt: texts.append(txt) txts.append("\n".join(texts)) - callback__(0.5, "Text extraction finished.", callback) + callback(0.5, "Text extraction finished.") import aspose.slides as slides import aspose.pydrawing as drawing imgs = [] @@ -53,7 +53,7 @@ class Ppt(object): slide.get_thumbnail(0.5, 0.5).save(buffered, drawing.imaging.ImageFormat.jpeg) imgs.append(buffered.getvalue()) assert len(imgs) == len(txts), "Slides text and image do not match: {} vs. 
{}".format(len(imgs), len(txts)) - callback__(0.9, "Image extraction finished", callback) + callback(0.9, "Image extraction finished") self.is_english = is_english(txts) return [(txts[i], imgs[i]) for i in range(len(txts))] @@ -70,7 +70,7 @@ class Pdf(HuParser): def __call__(self, filename, binary=None, from_page=0, to_page=100000, zoomin=3, callback=None): self.__images__(filename if not binary else binary, zoomin, from_page, to_page) - callback__(0.8, "Page {}~{}: OCR finished".format(from_page, min(to_page, self.total_page)), callback) + callback(0.8, "Page {}~{}: OCR finished".format(from_page, min(to_page, self.total_page))) assert len(self.boxes) == len(self.page_images), "{} vs. {}".format(len(self.boxes), len(self.page_images)) res = [] #################### More precisely ################### @@ -89,7 +89,7 @@ class Pdf(HuParser): for i in range(len(self.boxes)): lines = "\n".join([b["text"] for b in self.boxes[i] if not self.__garbage(b["text"])]) res.append((lines, self.page_images[i])) - callback__(0.9, "Page {}~{}: Parsing finished".format(from_page, min(to_page, self.total_page)), callback) + callback(0.9, "Page {}~{}: Parsing finished".format(from_page, min(to_page, self.total_page))) return res diff --git a/rag/app/qa.py b/rag/app/qa.py new file mode 100644 index 0000000000000000000000000000000000000000..b9ce8432278ba0294aa09c9e558052a556325dac --- /dev/null +++ b/rag/app/qa.py @@ -0,0 +1,104 @@ +import random +import re +from io import BytesIO +from nltk import word_tokenize +from openpyxl import load_workbook +from rag.app import is_english +from rag.nlp import huqie, stemmer + + +class Excel(object): + def __call__(self, fnm, binary=None, callback=None): + if not binary: + wb = load_workbook(fnm) + else: + wb = load_workbook(BytesIO(binary)) + total = 0 + for sheetname in wb.sheetnames: + total += len(list(wb[sheetname].rows)) + + res, fails = [], [] + for sheetname in wb.sheetnames: + ws = wb[sheetname] + rows = list(ws.rows) + for i, r in enumerate(rows): + q, a = "", "" + for cell in r: + if not cell.value: continue + if not q: q = str(cell.value) + elif not a: a = str(cell.value) + else: break + if q and a: res.append((q, a)) + else: fails.append(str(i+1)) + if len(res) % 999 == 0: + callback(len(res)*0.6/total, ("Extract Q&A: {}".format(len(res)) + (f"{len(fails)} failure, line: %s..."%(",".join(fails[:3])) if fails else ""))) + + callback(0.6, ("Extract Q&A: {}".format(len(res)) + ( + f"{len(fails)} failure, line: %s..." 
% (",".join(fails[:3])) if fails else ""))) + self.is_english = is_english([rmPrefix(q) for q, _ in random.choices(res, k=30) if len(q)>1]) + return res + + +def rmPrefix(txt): + return re.sub(r"^(é—®é˘|ç”ćˇ|回ç”|user|assistant|Q|A|Question|Answer|é—®|ç”)[\t:: ]+", "", txt.strip(), flags=re.IGNORECASE) + + +def beAdoc(d, q, a, eng): + qprefix = "Question: " if eng else "é—®é˘ďĽš" + aprefix = "Answer: " if eng else "回ç”:" + d["content_with_weight"] = "\t".join([qprefix+rmPrefix(q), aprefix+rmPrefix(a)]) + if eng: + d["content_ltks"] = " ".join([stemmer.stem(w) for w in word_tokenize(q)]) + else: + d["content_ltks"] = huqie.qie(q) + d["content_sm_ltks"] = huqie.qieqie(d["content_ltks"]) + return d + + +def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None): + + res = [] + if re.search(r"\.xlsx?$", filename, re.IGNORECASE): + callback(0.1, "Start to parse.") + excel_parser = Excel() + for q,a in excel_parser(filename, binary, callback): + res.append(beAdoc({}, q, a, excel_parser.is_english)) + return res + elif re.search(r"\.(txt|csv)$", filename, re.IGNORECASE): + callback(0.1, "Start to parse.") + txt = "" + if binary: + txt = binary.decode("utf-8") + else: + with open(filename, "r") as f: + while True: + l = f.readline() + if not l: break + txt += l + lines = txt.split("\n") + eng = is_english([rmPrefix(l) for l in lines[:100]]) + fails = [] + for i, line in enumerate(lines): + arr = [l for l in line.split("\t") if len(l) > 1] + if len(arr) != 2: + fails.append(str(i)) + continue + res.append(beAdoc({}, arr[0], arr[1], eng)) + if len(res) % 999 == 0: + callback(len(res) * 0.6 / len(lines), ("Extract Q&A: {}".format(len(res)) + ( + f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else ""))) + + callback(0.6, ("Extract Q&A: {}".format(len(res)) + ( + f"{len(fails)} failure, line: %s..." 
% (",".join(fails[:3])) if fails else ""))) + + return res + + raise NotImplementedError("file type not supported yet(pptx, pdf supported)") + + +if __name__== "__main__": + import sys + def kk(rat, ss): + pass + print(chunk(sys.argv[1], callback=kk)) + diff --git a/rag/parser/pdf_parser.py b/rag/parser/pdf_parser.py index 0f9def05194a608c1ee8b2450ce79be52882b71a..5935580a847c2c48aee0a197b8bef5d0b938641b 100644 --- a/rag/parser/pdf_parser.py +++ b/rag/parser/pdf_parser.py @@ -763,7 +763,7 @@ class HuParser: return i = 0 while i < len(self.boxes): - if not re.match(r"(contents|目录|目次|table of contents)$", re.sub(r"( | |\u3000)+", "", self.boxes[i]["text"].lower())): + if not re.match(r"(contents|目录|目次|table of contents|致谢|acknowledge)$", re.sub(r"( | |\u3000)+", "", self.boxes[i]["text"].lower())): i += 1 continue eng = re.match(r"[0-9a-zA-Z :'.-]{5,}", self.boxes[i]["text"].strip()) @@ -782,6 +782,22 @@ class HuParser: for k in range(i, j): self.boxes.pop(i) break + def _merge_with_same_bullet(self): + i = 0 + while i + 1 < len(self.boxes): + b = self.boxes[i] + b_ = self.boxes[i + 1] + if b["text"].strip()[0] != b_["text"].strip()[0] \ + or b["text"].strip()[0].lower() in set("qwertyuopasdfghjklzxcvbnm") \ + or b["top"] > b_["bottom"]: + i += 1 + continue + b_["text"] = b["text"] + "\n" + b_["text"] + b_["x0"] = min(b["x0"], b_["x0"]) + b_["x1"] = max(b["x1"], b_["x1"]) + b_["top"] = b["top"] + self.boxes.pop(i) + def _blockType(self, b): patt = [ ("^(20|19)[0-9]{2}[ĺą´/-][0-9]{1,2}[ćś/-][0-9]{1,2}ć—Ą*$", "Dt"), diff --git a/rag/svr/task_broker.py b/rag/svr/task_broker.py index f5ea4f4691758effb609e8a84a058506a521a40a..8b52e14495bd84a0b798b07e0b6c1b62d2fb943d 100644 --- a/rag/svr/task_broker.py +++ b/rag/svr/task_broker.py @@ -1,130 +1,138 @@ -# -# Copyright 2024 The InfiniFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# -import logging -import os -import time -import random -from timeit import default_timer as timer -from api.db.db_models import Task -from api.db.db_utils import bulk_insert_into_db -from api.db.services.task_service import TaskService -from rag.parser.pdf_parser import HuParser -from rag.settings import cron_logger -from rag.utils import MINIO -from rag.utils import findMaxTm -import pandas as pd -from api.db import FileType -from api.db.services.document_service import DocumentService -from api.settings import database_logger -from api.utils import get_format_time, get_uuid -from api.utils.file_utils import get_project_base_directory - - -def collect(tm): - docs = DocumentService.get_newly_uploaded(tm) - if len(docs) == 0: - return pd.DataFrame() - docs = pd.DataFrame(docs) - mtm = docs["update_time"].max() - cron_logger.info("TOTAL:{}, To:{}".format(len(docs), mtm)) - return docs - - -def set_dispatching(docid): - try: - DocumentService.update_by_id( - docid, {"progress": random.randint(0, 3) / 100., - "progress_msg": "Task dispatched...", - "process_begin_at": get_format_time() - }) - except Exception as e: - cron_logger.error("set_dispatching:({}), {}".format(docid, str(e))) - - -def dispatch(): - tm_fnm = os.path.join(get_project_base_directory(), "rag/res", f"broker.tm") - tm = findMaxTm(tm_fnm) - rows = collect(tm) - if len(rows) == 0: - return - - tmf = open(tm_fnm, "a+") - for _, r in rows.iterrows(): - try: - tsks = TaskService.query(doc_id=r["id"]) - if tsks: - for t in tsks: - TaskService.delete_by_id(t.id) - except Exception as e: - cron_logger.error("delete task exception:" + str(e)) - - def new_task(): - nonlocal r - return { - "id": get_uuid(), - "doc_id": r["id"] - } - - tsks = [] - if r["type"] == FileType.PDF.value: - pages = HuParser.total_page_number(r["name"], MINIO.get(r["kb_id"], r["location"])) - for p in range(0, pages, 10): - task = new_task() - task["from_page"] = p - task["to_page"] = min(p + 10, pages) - tsks.append(task) - else: - tsks.append(new_task()) - print(tsks) - bulk_insert_into_db(Task, tsks, True) - set_dispatching(r["id"]) - tmf.write(str(r["update_time"]) + "\n") - tmf.close() - - -def update_progress(): - docs = DocumentService.get_unfinished_docs() - for d in docs: - try: - tsks = TaskService.query(doc_id=d["id"], order_by=Task.create_time) - if not tsks:continue - msg = [] - prg = 0 - finished = True - bad = 0 - for t in tsks: - if 0 <= t.progress < 1: finished = False - prg += t.progress if t.progress >= 0 else 0 - msg.append(t.progress_msg) - if t.progress == -1: bad += 1 - prg /= len(tsks) - if finished and bad: prg = -1 - msg = "\n".join(msg) - DocumentService.update_by_id(d["id"], {"progress": prg, "progress_msg": msg, "process_duation": timer()-d["process_begin_at"].timestamp()}) - except Exception as e: - cron_logger.error("fetch task exception:" + str(e)) - - -if __name__ == "__main__": - peewee_logger = logging.getLogger('peewee') - peewee_logger.propagate = False - peewee_logger.addHandler(database_logger.handlers[0]) - peewee_logger.setLevel(database_logger.level) - - while True: - dispatch() - time.sleep(3) - update_progress() +# +# Copyright 2024 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import logging +import os +import time +import random +from datetime import datetime +from api.db.db_models import Task +from api.db.db_utils import bulk_insert_into_db +from api.db.services.task_service import TaskService +from rag.parser.pdf_parser import HuParser +from rag.settings import cron_logger +from rag.utils import MINIO +from rag.utils import findMaxTm +import pandas as pd +from api.db import FileType, TaskStatus +from api.db.services.document_service import DocumentService +from api.settings import database_logger +from api.utils import get_format_time, get_uuid +from api.utils.file_utils import get_project_base_directory + + +def collect(tm): + docs = DocumentService.get_newly_uploaded(tm) + if len(docs) == 0: + return pd.DataFrame() + docs = pd.DataFrame(docs) + mtm = docs["update_time"].max() + cron_logger.info("TOTAL:{}, To:{}".format(len(docs), mtm)) + return docs + + +def set_dispatching(docid): + try: + DocumentService.update_by_id( + docid, {"progress": random.randint(0, 3) / 100., + "progress_msg": "Task dispatched...", + "process_begin_at": get_format_time() + }) + except Exception as e: + cron_logger.error("set_dispatching:({}), {}".format(docid, str(e))) + + +def dispatch(): + tm_fnm = os.path.join(get_project_base_directory(), "rag/res", f"broker.tm") + tm = findMaxTm(tm_fnm) + rows = collect(tm) + if len(rows) == 0: + return + + tmf = open(tm_fnm, "a+") + for _, r in rows.iterrows(): + try: + tsks = TaskService.query(doc_id=r["id"]) + if tsks: + for t in tsks: + TaskService.delete_by_id(t.id) + except Exception as e: + cron_logger.error("delete task exception:" + str(e)) + + def new_task(): + nonlocal r + return { + "id": get_uuid(), + "doc_id": r["id"] + } + + tsks = [] + if r["type"] == FileType.PDF.value: + pages = HuParser.total_page_number(r["name"], MINIO.get(r["kb_id"], r["location"])) + for p in range(0, pages, 10): + task = new_task() + task["from_page"] = p + task["to_page"] = min(p + 10, pages) + tsks.append(task) + else: + tsks.append(new_task()) + print(tsks) + bulk_insert_into_db(Task, tsks, True) + set_dispatching(r["id"]) + tmf.write(str(r["update_time"]) + "\n") + tmf.close() + + +def update_progress(): + docs = DocumentService.get_unfinished_docs() + for d in docs: + try: + tsks = TaskService.query(doc_id=d["id"], order_by=Task.create_time) + if not tsks:continue + msg = [] + prg = 0 + finished = True + bad = 0 + status = TaskStatus.RUNNING.value + for t in tsks: + if 0 <= t.progress < 1: finished = False + prg += t.progress if t.progress >= 0 else 0 + msg.append(t.progress_msg) + if t.progress == -1: bad += 1 + prg /= len(tsks) + if finished and bad: + prg = -1 + status = TaskStatus.FAIL.value + elif finished: status = TaskStatus.DONE.value + + msg = "\n".join(msg) + info = {"process_duation": datetime.timestamp(datetime.now())-d["process_begin_at"].timestamp(), "run": status} + if prg !=0 : info["progress"] = prg + if msg: info["progress_msg"] = msg + DocumentService.update_by_id(d["id"], info) + except Exception as e: + cron_logger.error("fetch task exception:" + str(e)) + + +if __name__ == "__main__": + peewee_logger = 
logging.getLogger('peewee') + peewee_logger.propagate = False + peewee_logger.addHandler(database_logger.handlers[0]) + peewee_logger.setLevel(database_logger.level) + + while True: + dispatch() + time.sleep(3) + update_progress() diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py index ef40b3b0d91385abc789f014f649e3833adb0bbc..4cc348f7bd6bafedcce03110321a3e3d7e0530f3 100644 --- a/rag/svr/task_executor.py +++ b/rag/svr/task_executor.py @@ -24,8 +24,9 @@ import sys from functools import partial from timeit import default_timer as timer +from elasticsearch_dsl import Q + from api.db.services.task_service import TaskService -from rag.llm import EmbeddingModel, CvModel from rag.settings import cron_logger, DOC_MAXIMUM_SIZE from rag.utils import ELASTICSEARCH from rag.utils import MINIO @@ -35,7 +36,7 @@ from rag.nlp import search from io import BytesIO import pandas as pd -from rag.app import laws, paper, presentation, manual +from rag.app import laws, paper, presentation, manual, qa from api.db import LLMType, ParserType from api.db.services.document_service import DocumentService @@ -51,13 +52,14 @@ FACTORY = { ParserType.PRESENTATION.value: presentation, ParserType.MANUAL.value: manual, ParserType.LAWS.value: laws, + ParserType.QA.value: qa, } def set_progress(task_id, from_page, to_page, prog=None, msg="Processing..."): cancel = TaskService.do_cancel(task_id) if cancel: - msg = "Canceled." + msg += " [Canceled]" prog = -1 if to_page > 0: msg = f"Page({from_page}~{to_page}): " + msg @@ -166,13 +168,16 @@ def init_kb(row): def embedding(docs, mdl): - tts, cnts = [d["docnm_kwd"] for d in docs], [d["content_with_weight"] for d in docs] + tts, cnts = [d["docnm_kwd"] for d in docs if d.get("docnm_kwd")], [d["content_with_weight"] for d in docs] tk_count = 0 - tts, c = mdl.encode(tts) - tk_count += c + if len(tts) == len(cnts): + tts, c = mdl.encode(tts) + tk_count += c + cnts, c = mdl.encode(cnts) tk_count += c - vects = 0.1 * tts + 0.9 * cnts + vects = (0.1 * tts + 0.9 * cnts) if len(tts) == len(cnts) else cnts + assert len(vects) == len(docs) for i, d in enumerate(docs): v = vects[i].tolist() @@ -215,12 +220,14 @@ def main(comm, mod): callback(msg="Finished embedding! Start to build index!") init_kb(r) chunk_count = len(set([c["_id"] for c in cks])) - callback(1., "Done!") es_r = ELASTICSEARCH.bulk(cks, search.index_name(r["tenant_id"])) if es_r: callback(-1, "Index failure!") cron_logger.error(str(es_r)) else: + if TaskService.do_cancel(r["id"]): + ELASTICSEARCH.deleteByQuery(Q("match", doc_id=r["doc_id"]), idxnm=search.index_name(r["tenant_id"])) + callback(1., "Done!") DocumentService.increment_chunk_num(r["doc_id"], r["kb_id"], tk_count, chunk_count, 0) cron_logger.info("Chunk doc({}), token({}), chunks({})".format(r["id"], tk_count, len(cks)))