From d32322c08144009b15326604dbffabe4422309a5 Mon Sep 17 00:00:00 2001 From: KevinHuSh <kevinhu.sh@gmail.com> Date: Thu, 22 Feb 2024 19:11:37 +0800 Subject: [PATCH] rename vision, add layout and tsr recognizer (#70) * rename vision, add layout and tsr recognizer * trivial fixing --- api/apps/conversation_app.py | 26 +- api/apps/llm_app.py | 4 +- api/db/db_models.py | 5 +- api/db/services/llm_service.py | 2 +- deepdoc/parser/pdf_parser.py | 902 +------------------ deepdoc/vision/__init__.py | 4 + deepdoc/vision/layout_recognizer.py | 119 +++ deepdoc/{visual => vision}/ocr.py | 4 +- deepdoc/{visual => vision}/ocr.res | 0 deepdoc/{visual => vision}/operators.py | 0 deepdoc/{visual => vision}/postprocess.py | 0 deepdoc/vision/recognizer.py | 327 +++++++ deepdoc/{visual => vision}/seeit.py | 0 deepdoc/vision/table_structure_recognizer.py | 556 ++++++++++++ deepdoc/visual/__init__.py | 2 - deepdoc/visual/recognizer.py | 139 --- rag/svr/task_broker.py | 4 +- 17 files changed, 1055 insertions(+), 1039 deletions(-) create mode 100644 deepdoc/vision/__init__.py create mode 100644 deepdoc/vision/layout_recognizer.py rename deepdoc/{visual => vision}/ocr.py (99%) rename deepdoc/{visual => vision}/ocr.res (100%) rename deepdoc/{visual => vision}/operators.py (100%) rename deepdoc/{visual => vision}/postprocess.py (100%) create mode 100644 deepdoc/vision/recognizer.py rename deepdoc/{visual => vision}/seeit.py (100%) create mode 100644 deepdoc/vision/table_structure_recognizer.py delete mode 100644 deepdoc/visual/__init__.py delete mode 100644 deepdoc/visual/recognizer.py diff --git a/api/apps/conversation_app.py b/api/apps/conversation_app.py index e6e33d0..0d392bc 100644 --- a/api/apps/conversation_app.py +++ b/api/apps/conversation_app.py @@ -34,7 +34,6 @@ from rag.utils import num_tokens_from_string, encoder, rmSpace @manager.route('/set', methods=['POST']) @login_required -@validate_request("dialog_id") def set_conversation(): req = request.json conv_id = req.get("conversation_id") @@ -145,7 +144,7 @@ def message_fit_in(msg, max_length=4000): @manager.route('/completion', methods=['POST']) @login_required -@validate_request("dialog_id", "messages") +@validate_request("conversation_id", "messages") def completion(): req = request.json msg = [] if m["role"] == "assistant" and not msg: continue msg.append({"role": m["role"], "content": m["content"]}) try: - e, dia = DialogService.get_by_id(req["dialog_id"]) + e, conv = ConversationService.get_by_id(req["conversation_id"]) + if not e: + return get_data_error_result(retmsg="Conversation not found!") + conv.message.append(msg[-1]) + e, dia = DialogService.get_by_id(conv.dialog_id) if not e: return get_data_error_result(retmsg="Dialog not found!") - del req["dialog_id"] + del req["conversation_id"] del req["messages"] - return get_json_result(data=chat(dia, msg, **req)) + ans = chat(dia, msg, **req) + conv.reference.append(ans["reference"]) + conv.message.append({"role": "assistant", "content": ans["answer"]}) + ConversationService.update_by_id(conv.id, conv.to_dict()) + return get_json_result(data=ans) except Exception as e: return server_error_response(e) @@ -194,8 +201,8 @@ def chat(dialog, messages, **kwargs): dialog.vector_similarity_weight, top=1024, aggs=False) knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]] - if not knowledges and prompt_config["empty_response"]: - return {"answer": prompt_config["empty_response"], "retrieval": kbinfos} + if not knowledges and 
prompt_config.get("empty_response"): + return {"answer": prompt_config["empty_response"], "reference": kbinfos} kwargs["knowledge"] = "\n".join(knowledges) gen_conf = dialog.llm_setting @@ -205,7 +212,8 @@ def chat(dialog, messages, **kwargs): gen_conf["max_tokens"] = min(gen_conf["max_tokens"], llm.max_tokens - used_token_count) answer = chat_mdl.chat(prompt_config["system"].format(**kwargs), msg, gen_conf) - answer = retrievaler.insert_citations(answer, + if knowledges: + answer = retrievaler.insert_citations(answer, [ck["content_ltks"] for ck in kbinfos["chunks"]], [ck["vector"] for ck in kbinfos["chunks"]], embd_mdl, @@ -213,7 +221,7 @@ def chat(dialog, messages, **kwargs): vtweight=dialog.vector_similarity_weight) for c in kbinfos["chunks"]: if c.get("vector"): del c["vector"] - return {"answer": answer, "retrieval": kbinfos} + return {"answer": answer, "reference": kbinfos} def use_sql(question, field_map, tenant_id, chat_mdl): diff --git a/api/apps/llm_app.py b/api/apps/llm_app.py index c70f7ea..0a1167b 100644 --- a/api/apps/llm_app.py +++ b/api/apps/llm_app.py @@ -94,11 +94,11 @@ def list(): model_type = request.args.get("model_type") try: objs = TenantLLMService.query(tenant_id=current_user.id) - mdlnms = set([o.to_dict()["llm_name"] for o in objs if o.api_key]) + facts = set([o.to_dict()["llm_factory"] for o in objs if o.api_key]) llms = LLMService.get_all() llms = [m.to_dict() for m in llms if m.status == StatusEnum.VALID.value] for m in llms: - m["available"] = m["llm_name"] in mdlnms + m["available"] = m["fid"] in facts res = {} for m in llms: diff --git a/api/db/db_models.py b/api/db/db_models.py index 0e032fc..49aa169 100644 --- a/api/db/db_models.py +++ b/api/db/db_models.py @@ -500,7 +500,7 @@ class Document(DataBaseModel): token_num = IntegerField(default=0) chunk_num = IntegerField(default=0) progress = FloatField(default=0) - progress_msg = CharField(max_length=512, null=True, help_text="process message", default="") + progress_msg = CharField(max_length=4096, null=True, help_text="process message", default="") process_begin_at = DateTimeField(null=True) process_duation = FloatField(default=0) run = CharField(max_length=1, null=True, help_text="start to run processing or cancel.(1: run it; 2: cancel)", default="0") @@ -518,7 +518,7 @@ class Task(DataBaseModel): begin_at = DateTimeField(null=True) process_duation = FloatField(default=0) progress = FloatField(default=0) - progress_msg = CharField(max_length=255, null=True, help_text="process message", default="") + progress_msg = CharField(max_length=4096, null=True, help_text="process message", default="") class Dialog(DataBaseModel): @@ -561,6 +561,7 @@ class Conversation(DataBaseModel): dialog_id = CharField(max_length=32, null=False, index=True) name = CharField(max_length=255, null=True, help_text="converastion name") message = JSONField(null=True) + reference = JSONField(null=True, default=[]) class Meta: db_table = "conversation" diff --git a/api/db/services/llm_service.py b/api/db/services/llm_service.py index 0fb10b0..6bc1150 100644 --- a/api/db/services/llm_service.py +++ b/api/db/services/llm_service.py @@ -75,7 +75,7 @@ class TenantLLMService(CommonService): model_config = cls.get_api_key(tenant_id, mdlnm) if not model_config: - raise LookupError("Model({}) not found".format(mdlnm)) + raise LookupError("Model({}) not authorized".format(mdlnm)) model_config = model_config.to_dict() if llm_type == LLMType.EMBEDDING.value: if model_config["llm_factory"] not in EmbeddingModel: diff --git 
a/deepdoc/parser/pdf_parser.py b/deepdoc/parser/pdf_parser.py index 576687b..c5d2e6a 100644 --- a/deepdoc/parser/pdf_parser.py +++ b/deepdoc/parser/pdf_parser.py @@ -1,9 +1,7 @@ # -*- coding: utf-8 -*- -import os import random import fitz -import requests import xgboost as xgb from io import BytesIO import torch @@ -14,9 +12,8 @@ from PIL import Image import numpy as np from api.db import ParserType -from deepdoc.visual import OCR, Recognizer +from deepdoc.vision import OCR, Recognizer, LayoutRecognizer, TableStructureRecognizer from rag.nlp import huqie -from collections import Counter from copy import deepcopy from huggingface_hub import hf_hub_download @@ -29,29 +26,8 @@ class HuParser: self.ocr = OCR() if not hasattr(self, "model_speciess"): self.model_speciess = ParserType.GENERAL.value - self.layout_labels = [ - "_background_", - "Text", - "Title", - "Figure", - "Figure caption", - "Table", - "Table caption", - "Header", - "Footer", - "Reference", - "Equation", - ] - self.tsr_labels = [ - "table", - "table column", - "table row", - "table column header", - "table projected row header", - "table spanning cell", - ] - self.layouter = Recognizer(self.layout_labels, "layout", "/data/newpeak/medical-gpt/res/ppdet/") - self.tbl_det = Recognizer(self.tsr_labels, "tsr", "/data/newpeak/medical-gpt/res/ppdet.tbl/") + self.layouter = LayoutRecognizer("layout."+self.model_speciess) + self.tbl_det = TableStructureRecognizer() self.updown_cnt_mdl = xgb.Booster() if torch.cuda.is_available(): @@ -70,39 +46,6 @@ class HuParser: """ - def __remote_call(self, species, images, thr=0.7): - url = os.environ.get("INFINIFLOW_SERVER") - token = os.environ.get("INFINIFLOW_TOKEN") - if not url or not token: - logging.warning("INFINIFLOW_SERVER is not specified. To maximize the effectiveness, please visit https://github.com/infiniflow/ragflow, and sign in the our demo web site to get token. It's FREE! 
Using 'export' to set both environment variables: INFINIFLOW_SERVER and INFINIFLOW_TOKEN.") - return [[] for _ in range(len(images))] - - def convert_image_to_bytes(PILimage): - image = BytesIO() - PILimage.save(image, format='png') - image.seek(0) - return image.getvalue() - - images = [convert_image_to_bytes(img) for img in images] - - def remote_call(): - nonlocal images, thr - res = requests.post(url+"/v1/layout/detect/"+species, files=[("image", img) for img in images], data={"threashold": thr}, - headers={"Authorization": token}, timeout=len(images) * 10) - res = res.json() - if res["retcode"] != 0: raise RuntimeError(res["retmsg"]) - return res["data"] - - for _ in range(3): - try: - return remote_call() - except RuntimeError as e: - raise e - except Exception as e: - logging.error("layout_predict:"+str(e)) - return remote_call() - - def __char_width(self, c): return (c["x1"] - c["x0"]) // len(c["text"]) @@ -188,20 +131,6 @@ class HuParser: ] return fea - @staticmethod - def sort_Y_firstly(arr, threashold): - # sort using y1 first and then x1 - arr = sorted(arr, key=lambda r: (r["top"], r["x0"])) - for i in range(len(arr) - 1): - for j in range(i, -1, -1): - # restore the order using th - if abs(arr[j + 1]["top"] - arr[j]["top"]) < threashold \ - and arr[j + 1]["x0"] < arr[j]["x0"]: - tmp = deepcopy(arr[j]) - arr[j] = deepcopy(arr[j + 1]) - arr[j + 1] = deepcopy(tmp) - return arr - @staticmethod def sort_X_by_page(arr, threashold): # sort using y1 first and then x1 @@ -217,61 +146,6 @@ class HuParser: arr[j + 1] = tmp return arr - @staticmethod - def sort_R_firstly(arr, thr=0): - # sort using y1 first and then x1 - # sorted(arr, key=lambda r: (r["top"], r["x0"])) - arr = HuParser.sort_Y_firstly(arr, thr) - for i in range(len(arr) - 1): - for j in range(i, -1, -1): - if "R" not in arr[j] or "R" not in arr[j + 1]: - continue - if arr[j + 1]["R"] < arr[j]["R"] \ - or ( - arr[j + 1]["R"] == arr[j]["R"] - and arr[j + 1]["x0"] < arr[j]["x0"] - ): - tmp = arr[j] - arr[j] = arr[j + 1] - arr[j + 1] = tmp - return arr - - @staticmethod - def sort_X_firstly(arr, threashold, copy=True): - # sort using y1 first and then x1 - arr = sorted(arr, key=lambda r: (r["x0"], r["top"])) - for i in range(len(arr) - 1): - for j in range(i, -1, -1): - # restore the order using th - if abs(arr[j + 1]["x0"] - arr[j]["x0"]) < threashold \ - and arr[j + 1]["top"] < arr[j]["top"]: - tmp = deepcopy(arr[j]) if copy else arr[j] - arr[j] = deepcopy(arr[j + 1]) if copy else arr[j + 1] - arr[j + 1] = deepcopy(tmp) if copy else tmp - return arr - - @staticmethod - def sort_C_firstly(arr, thr=0): - # sort using y1 first and then x1 - # sorted(arr, key=lambda r: (r["x0"], r["top"])) - arr = HuParser.sort_X_firstly(arr, thr) - for i in range(len(arr) - 1): - for j in range(i, -1, -1): - # restore the order using th - if "C" not in arr[j] or "C" not in arr[j + 1]: - continue - if arr[j + 1]["C"] < arr[j]["C"] \ - or ( - arr[j + 1]["C"] == arr[j]["C"] - and arr[j + 1]["top"] < arr[j]["top"] - ): - tmp = arr[j] - arr[j] = arr[j + 1] - arr[j + 1] = tmp - return arr - - return sorted(arr, key=lambda r: (r.get("C", r["x0"]), r["top"])) - def _has_color(self, o): if o.get("ncs", "") == "DeviceGray": if o["stroking_color"] and o["stroking_color"][0] == 1 and o["non_stroking_color"] and \ @@ -280,172 +154,6 @@ class HuParser: return False return True - def __overlapped_area(self, a, b, ratio=True): - tp, btm, x0, x1 = a["top"], a["bottom"], a["x0"], a["x1"] - if b["x0"] > x1 or b["x1"] < x0: - return 0 - if b["bottom"] < tp or 
b["top"] > btm: - return 0 - x0_ = max(b["x0"], x0) - x1_ = min(b["x1"], x1) - assert x0_ <= x1_, "Fuckedup! T:{},B:{},X0:{},X1:{} ==> {}".format( - tp, btm, x0, x1, b) - tp_ = max(b["top"], tp) - btm_ = min(b["bottom"], btm) - assert tp_ <= btm_, "Fuckedup! T:{},B:{},X0:{},X1:{} => {}".format( - tp, btm, x0, x1, b) - ov = (btm_ - tp_) * (x1_ - x0_) if x1 - \ - x0 != 0 and btm - tp != 0 else 0 - if ov > 0 and ratio: - ov /= (x1 - x0) * (btm - tp) - return ov - - def __find_overlapped_with_threashold(self, box, boxes, thr=0.3): - if not boxes: - return - max_overlaped_i, max_overlaped, _max_overlaped = None, thr, 0 - s, e = 0, len(boxes) - for i in range(s, e): - ov = self.__overlapped_area(box, boxes[i]) - _ov = self.__overlapped_area(boxes[i], box) - if (ov, _ov) < (max_overlaped, _max_overlaped): - continue - max_overlaped_i = i - max_overlaped = ov - _max_overlaped = _ov - - return max_overlaped_i - - def __find_overlapped(self, box, boxes_sorted_by_y, naive=False): - if not boxes_sorted_by_y: - return - bxs = boxes_sorted_by_y - s, e, ii = 0, len(bxs), 0 - while s < e and not naive: - ii = (e + s) // 2 - pv = bxs[ii] - if box["bottom"] < pv["top"]: - e = ii - continue - if box["top"] > pv["bottom"]: - s = ii + 1 - continue - break - while s < ii: - if box["top"] > bxs[s]["bottom"]: - s += 1 - break - while e - 1 > ii: - if box["bottom"] < bxs[e - 1]["top"]: - e -= 1 - break - - max_overlaped_i, max_overlaped = None, 0 - for i in range(s, e): - ov = self.__overlapped_area(bxs[i], box) - if ov <= max_overlaped: - continue - max_overlaped_i = i - max_overlaped = ov - - return max_overlaped_i - - def _is_garbage(self, b): - patt = [r"^•+$", r"(版权归©|免责条款|地址[::])", r"\.{3,}", "^[0-9]{1,2} / ?[0-9]{1,2}$", - r"^[0-9]{1,2} of [0-9]{1,2}$", "^http://[^ ]{12,}", - "(资料|数据)来源[::]", "[0-9a-z._-]+@[a-z0-9-]+\\.[a-z]{2,3}", - "\\(cid *: *[0-9]+ *\\)" - ] - return any([re.search(p, b["text"]) for p in patt]) - - def __layouts_cleanup(self, boxes, layouts, far=2, thr=0.7): - def notOverlapped(a, b): - return any([a["x1"] < b["x0"], - a["x0"] > b["x1"], - a["bottom"] < b["top"], - a["top"] > b["bottom"]]) - - i = 0 - while i + 1 < len(layouts): - j = i + 1 - while j < min(i + far, len(layouts)) \ - and (layouts[i].get("type", "") != layouts[j].get("type", "") - or notOverlapped(layouts[i], layouts[j])): - j += 1 - if j >= min(i + far, len(layouts)): - i += 1 - continue - if self.__overlapped_area(layouts[i], layouts[j]) < thr \ - and self.__overlapped_area(layouts[j], layouts[i]) < thr: - i += 1 - continue - - if layouts[i].get("score") and layouts[j].get("score"): - if layouts[i]["score"] > layouts[j]["score"]: - layouts.pop(j) - else: - layouts.pop(i) - continue - - area_i, area_i_1 = 0, 0 - for b in boxes: - if not notOverlapped(b, layouts[i]): - area_i += self.__overlapped_area(b, layouts[i], False) - if not notOverlapped(b, layouts[j]): - area_i_1 += self.__overlapped_area(b, layouts[j], False) - - if area_i > area_i_1: - layouts.pop(j) - else: - layouts.pop(i) - - return layouts - - def __table_tsr(self, images): - tbls = self.tbl_det(images, thr=0.5) - res = [] - # align left&right for rows, align top&bottom for columns - for tbl in tbls: - lts = [{"label": b["type"], - "score": b["score"], - "x0": b["bbox"][0], "x1": b["bbox"][2], - "top": b["bbox"][1], "bottom": b["bbox"][-1] - } for b in tbl] - if not lts: - continue - - left = [b["x0"] for b in lts if b["label"].find( - "row") > 0 or b["label"].find("header") > 0] - right = [b["x1"] for b in lts if b["label"].find( - "row") > 
0 or b["label"].find("header") > 0] - if not left: - continue - left = np.median(left) if len(left) > 4 else np.min(left) - right = np.median(right) if len(right) > 4 else np.max(right) - for b in lts: - if b["label"].find("row") > 0 or b["label"].find("header") > 0: - if b["x0"] > left: - b["x0"] = left - if b["x1"] < right: - b["x1"] = right - - top = [b["top"] for b in lts if b["label"] == "table column"] - bottom = [b["bottom"] for b in lts if b["label"] == "table column"] - if not top: - res.append(lts) - continue - top = np.median(top) if len(top) > 4 else np.min(top) - bottom = np.median(bottom) if len(bottom) > 4 else np.max(bottom) - for b in lts: - if b["label"] == "table column": - if b["top"] > top: - b["top"] = top - if b["bottom"] < bottom: - b["bottom"] = bottom - - res.append(lts) - return res - def _table_transformer_job(self, ZM): logging.info("Table processing...") imgs, pos = [], [] @@ -471,7 +179,7 @@ class HuParser: assert len(self.page_images) == len(tbcnt) - 1 if not imgs: return - recos = self.__table_tsr(imgs) + recos = self.tbl_det(imgs) tbcnt = np.cumsum(tbcnt) for i in range(len(tbcnt) - 1): # for page pg = [] @@ -493,10 +201,10 @@ class HuParser: self.tb_cpns.extend(pg) def gather(kwd, fzy=10, ption=0.6): - eles = self.sort_Y_firstly( + eles = Recognizer.sort_Y_firstly( [r for r in self.tb_cpns if re.match(kwd, r["label"])], fzy) - eles = self.__layouts_cleanup(self.boxes, eles, 5, ption) - return self.sort_Y_firstly(eles, 0) + eles = Recognizer.layouts_cleanup(self.boxes, eles, 5, ption) + return Recognizer.sort_Y_firstly(eles, 0) # add R,H,C,SP tag to boxes within table layout headers = gather(r".*header$") @@ -504,17 +212,17 @@ class HuParser: spans = gather(r".*spanning") clmns = sorted([r for r in self.tb_cpns if re.match( r"table column$", r["label"])], key=lambda x: (x["pn"], x["layoutno"], x["x0"])) - clmns = self.__layouts_cleanup(self.boxes, clmns, 5, 0.5) + clmns = Recognizer.layouts_cleanup(self.boxes, clmns, 5, 0.5) for b in self.boxes: if b.get("layout_type", "") != "table": continue - ii = self.__find_overlapped_with_threashold(b, rows, thr=0.3) + ii = Recognizer.find_overlapped_with_threashold(b, rows, thr=0.3) if ii is not None: b["R"] = ii b["R_top"] = rows[ii]["top"] b["R_bott"] = rows[ii]["bottom"] - ii = self.__find_overlapped_with_threashold(b, headers, thr=0.3) + ii = Recognizer.find_overlapped_with_threashold(b, headers, thr=0.3) if ii is not None: b["H_top"] = headers[ii]["top"] b["H_bott"] = headers[ii]["bottom"] @@ -522,13 +230,13 @@ class HuParser: b["H_right"] = headers[ii]["x1"] b["H"] = ii - ii = self.__find_overlapped_with_threashold(b, clmns, thr=0.3) + ii = Recognizer.find_overlapped_with_threashold(b, clmns, thr=0.3) if ii is not None: b["C"] = ii b["C_left"] = clmns[ii]["x0"] b["C_right"] = clmns[ii]["x1"] - ii = self.__find_overlapped_with_threashold(b, spans, thr=0.3) + ii = Recognizer.find_overlapped_with_threashold(b, spans, thr=0.3) if ii is not None: b["H_top"] = spans[ii]["top"] b["H_bott"] = spans[ii]["bottom"] @@ -542,7 +250,7 @@ class HuParser: self.boxes.append([]) return bxs = [(line[0], line[1][0]) for line in bxs] - bxs = self.sort_Y_firstly( + bxs = Recognizer.sort_Y_firstly( [{"x0": b[0][0] / ZM, "x1": b[1][0] / ZM, "top": b[0][1] / ZM, "text": "", "txt": t, "bottom": b[-1][1] / ZM, @@ -551,8 +259,8 @@ class HuParser: ) # merge chars in the same rect - for c in self.sort_X_firstly(chars, self.mean_width[pagenum - 1] // 4): - ii = self.__find_overlapped(c, bxs) + for c in Recognizer.sort_X_firstly(chars, 
self.mean_width[pagenum - 1] // 4): + ii = Recognizer.find_overlapped(c, bxs) if ii is None: self.lefted_chars.append(c) continue @@ -573,91 +281,11 @@ class HuParser: if self.mean_height[-1] == 0: self.mean_height[-1] = np.median([b["bottom"] - b["top"] for b in bxs]) - self.boxes.append(bxs) def _layouts_rec(self, ZM): assert len(self.page_images) == len(self.boxes) - # Tag layout type - boxes = [] - layouts = self.layouter(self.page_images) - #save_results(self.page_images, layouts, self.layout_labels, output_dir='output/', threshold=0.7) - assert len(self.page_images) == len(layouts) - for pn, lts in enumerate(layouts): - bxs = self.boxes[pn] - lts = [{"type": b["type"], - "score": float(b["score"]), - "x0": b["bbox"][0] / ZM, "x1": b["bbox"][2] / ZM, - "top": b["bbox"][1] / ZM, "bottom": b["bbox"][-1] / ZM, - "page_number": pn, - } for b in lts] - lts = self.sort_Y_firstly(lts, self.mean_height[pn] / 2) - lts = self.__layouts_cleanup(bxs, lts) - self.page_layout.append(lts) - - # Tag layout type, layouts are ready - def findLayout(ty): - nonlocal bxs, lts - lts_ = [lt for lt in lts if lt["type"] == ty] - i = 0 - while i < len(bxs): - if bxs[i].get("layout_type"): - i += 1 - continue - if self._is_garbage(bxs[i]): - logging.debug("GARBAGE: " + bxs[i]["text"]) - bxs.pop(i) - continue - - ii = self.__find_overlapped_with_threashold(bxs[i], lts_, - thr=0.4) - if ii is None: # belong to nothing - bxs[i]["layout_type"] = "" - i += 1 - continue - lts_[ii]["visited"] = True - if lts_[ii]["type"] in ["footer", "header", "reference"]: - if lts_[ii]["type"] not in self.garbages: - self.garbages[lts_[ii]["type"]] = [] - self.garbages[lts_[ii]["type"]].append(bxs[i]["text"]) - logging.debug("GARBAGE: " + bxs[i]["text"]) - bxs.pop(i) - continue - - bxs[i]["layoutno"] = f"{ty}-{ii}" - bxs[i]["layout_type"] = lts_[ii]["type"] - i += 1 - - for lt in ["footer", "header", "reference", "figure caption", - "table caption", "title", "text", "table", "figure"]: - findLayout(lt) - - # add box to figure layouts which has not text box - for i, lt in enumerate( - [lt for lt in lts if lt["type"] == "figure"]): - if lt.get("visited"): - continue - lt = deepcopy(lt) - del lt["type"] - lt["text"] = "" - lt["layout_type"] = "figure" - lt["layoutno"] = f"figure-{i}" - bxs.append(lt) - - boxes.extend(bxs) - - self.boxes = boxes - - garbage = set() - for k in self.garbages.keys(): - self.garbages[k] = Counter(self.garbages[k]) - for g, c in self.garbages[k].items(): - if c > 1: - garbage.add(g) - - logging.debug("GARBAGE:" + ",".join(garbage)) - self.boxes = [b for b in self.boxes if b["text"].strip() not in garbage] - + self.boxes, self.page_layout = self.layouter(self.page_images, self.boxes, ZM) # cumlative Y for i in range(len(self.boxes)): self.boxes[i]["top"] += \ @@ -710,7 +338,7 @@ class HuParser: self.boxes = bxs def _naive_vertical_merge(self): - bxs = self.sort_Y_firstly(self.boxes, np.median(self.mean_height) / 3) + bxs = Recognizer.sort_Y_firstly(self.boxes, np.median(self.mean_height) / 3) i = 0 while i + 1 < len(bxs): b = bxs[i] @@ -850,7 +478,7 @@ class HuParser: t["layout_type"] = c["layout_type"] boxes.append(t) - self.boxes = self.sort_Y_firstly(boxes, 0) + self.boxes = Recognizer.sort_Y_firstly(boxes, 0) def _filter_forpages(self): if not self.boxes: @@ -916,492 +544,6 @@ class HuParser: b_["top"] = b["top"] self.boxes.pop(i) - def _blockType(self, b): - patt = [ - ("^(20|19)[0-9]{2}[年/-][0-9]{1,2}[月/-][0-9]{1,2}日*$", "Dt"), - (r"^(20|19)[0-9]{2}年$", "Dt"), - 
(r"^(20|19)[0-9]{2}[年-][0-9]{1,2}月*$", "Dt"), - ("^[0-9]{1,2}[月-][0-9]{1,2}日*$", "Dt"), - (r"^第*[一二三四1-4]季度$", "Dt"), - (r"^(20|19)[0-9]{2}年*[一二三四1-4]季度$", "Dt"), - (r"^(20|19)[0-9]{2}[ABCDE]$", "Dt"), - ("^[0-9.,+%/ -]+$", "Nu"), - (r"^[0-9A-Z/\._~-]+$", "Ca"), - (r"^[A-Z]*[a-z' -]+$", "En"), - (r"^[0-9.,+-]+[0-9A-Za-z/$￥%<>()()' -]+$", "NE"), - (r"^.{1}$", "Sg") - ] - for p, n in patt: - if re.search(p, b["text"].strip()): - return n - tks = [t for t in huqie.qie(b["text"]).split(" ") if len(t) > 1] - if len(tks) > 3: - if len(tks) < 12: - return "Tx" - else: - return "Lx" - - if len(tks) == 1 and huqie.tag(tks[0]) == "nr": - return "Nr" - - return "Ot" - - def __cal_spans(self, boxes, rows, cols, tbl, html=True): - # caculate span - clft = [np.mean([c.get("C_left", c["x0"]) for c in cln]) - for cln in cols] - crgt = [np.mean([c.get("C_right", c["x1"]) for c in cln]) - for cln in cols] - rtop = [np.mean([c.get("R_top", c["top"]) for c in row]) - for row in rows] - rbtm = [np.mean([c.get("R_btm", c["bottom"]) - for c in row]) for row in rows] - for b in boxes: - if "SP" not in b: - continue - b["colspan"] = [b["cn"]] - b["rowspan"] = [b["rn"]] - # col span - for j in range(0, len(clft)): - if j == b["cn"]: - continue - if clft[j] + (crgt[j] - clft[j]) / 2 < b["H_left"]: - continue - if crgt[j] - (crgt[j] - clft[j]) / 2 > b["H_right"]: - continue - b["colspan"].append(j) - # row span - for j in range(0, len(rtop)): - if j == b["rn"]: - continue - if rtop[j] + (rbtm[j] - rtop[j]) / 2 < b["H_top"]: - continue - if rbtm[j] - (rbtm[j] - rtop[j]) / 2 > b["H_bott"]: - continue - b["rowspan"].append(j) - - def join(arr): - if not arr: - return "" - return "".join([t["text"] for t in arr]) - - # rm the spaning cells - for i in range(len(tbl)): - for j, arr in enumerate(tbl[i]): - if not arr: - continue - if all(["rowspan" not in a and "colspan" not in a for a in arr]): - continue - rowspan, colspan = [], [] - for a in arr: - if isinstance(a.get("rowspan", 0), list): - rowspan.extend(a["rowspan"]) - if isinstance(a.get("colspan", 0), list): - colspan.extend(a["colspan"]) - rowspan, colspan = set(rowspan), set(colspan) - if len(rowspan) < 2 and len(colspan) < 2: - for a in arr: - if "rowspan" in a: - del a["rowspan"] - if "colspan" in a: - del a["colspan"] - continue - rowspan, colspan = sorted(rowspan), sorted(colspan) - rowspan = list(range(rowspan[0], rowspan[-1] + 1)) - colspan = list(range(colspan[0], colspan[-1] + 1)) - assert i in rowspan, rowspan - assert j in colspan, colspan - arr = [] - for r in rowspan: - for c in colspan: - arr_txt = join(arr) - if tbl[r][c] and join(tbl[r][c]) != arr_txt: - arr.extend(tbl[r][c]) - tbl[r][c] = None if html else arr - for a in arr: - if len(rowspan) > 1: - a["rowspan"] = len(rowspan) - elif "rowspan" in a: - del a["rowspan"] - if len(colspan) > 1: - a["colspan"] = len(colspan) - elif "colspan" in a: - del a["colspan"] - tbl[rowspan[0]][colspan[0]] = arr - - return tbl - - def __construct_table(self, boxes, html=False): - cap = "" - i = 0 - while i < len(boxes): - if self.is_caption(boxes[i]): - cap += boxes[i]["text"] - boxes.pop(i) - i -= 1 - i += 1 - - if not boxes: - return [] - for b in boxes: - b["btype"] = self._blockType(b) - max_type = Counter([b["btype"] for b in boxes]).items() - max_type = max(max_type, key=lambda x: x[1])[0] if max_type else "" - logging.debug("MAXTYPE: " + max_type) - - rowh = [b["R_bott"] - b["R_top"] for b in boxes if "R" in b] - rowh = np.min(rowh) if rowh else 0 - # boxes = self.sort_Y_firstly(boxes, 
rowh/5) - boxes = self.sort_R_firstly(boxes, rowh / 2) - boxes[0]["rn"] = 0 - rows = [[boxes[0]]] - btm = boxes[0]["bottom"] - for b in boxes[1:]: - b["rn"] = len(rows) - 1 - lst_r = rows[-1] - if lst_r[-1].get("R", "") != b.get("R", "") \ - or (b["top"] >= btm - 3 and lst_r[-1].get("R", "-1") != b.get("R", "-2") - ): # new row - btm = b["bottom"] - b["rn"] += 1 - rows.append([b]) - continue - btm = (btm + b["bottom"]) / 2. - rows[-1].append(b) - - colwm = [b["C_right"] - b["C_left"] for b in boxes if "C" in b] - colwm = np.min(colwm) if colwm else 0 - crosspage = len(set([b["page_number"] for b in boxes])) > 1 - if crosspage: - boxes = self.sort_X_firstly(boxes, colwm / 2, False) - else: - boxes = self.sort_C_firstly(boxes, colwm / 2) - boxes[0]["cn"] = 0 - cols = [[boxes[0]]] - right = boxes[0]["x1"] - for b in boxes[1:]: - b["cn"] = len(cols) - 1 - lst_c = cols[-1] - if (int(b.get("C", "1")) - int(lst_c[-1].get("C", "1")) == 1 and b["page_number"] == lst_c[-1][ - "page_number"]) \ - or (b["x0"] >= right and lst_c[-1].get("C", "-1") != b.get("C", "-2")): # new col - right = b["x1"] - b["cn"] += 1 - cols.append([b]) - continue - right = (right + b["x1"]) / 2. - cols[-1].append(b) - - tbl = [[[] for _ in range(len(cols))] for _ in range(len(rows))] - for b in boxes: - tbl[b["rn"]][b["cn"]].append(b) - - if len(rows) >= 4: - # remove single in column - j = 0 - while j < len(tbl[0]): - e, ii = 0, 0 - for i in range(len(tbl)): - if tbl[i][j]: - e += 1 - ii = i - if e > 1: - break - if e > 1: - j += 1 - continue - f = (j > 0 and tbl[ii][j - 1] and tbl[ii] - [j - 1][0].get("text")) or j == 0 - ff = (j + 1 < len(tbl[ii]) and tbl[ii][j + 1] and tbl[ii] - [j + 1][0].get("text")) or j + 1 >= len(tbl[ii]) - if f and ff: - j += 1 - continue - bx = tbl[ii][j][0] - logging.debug("Relocate column single: " + bx["text"]) - # j column only has one value - left, right = 100000, 100000 - if j > 0 and not f: - for i in range(len(tbl)): - if tbl[i][j - 1]: - left = min(left, np.min( - [bx["x0"] - a["x1"] for a in tbl[i][j - 1]])) - if j + 1 < len(tbl[0]) and not ff: - for i in range(len(tbl)): - if tbl[i][j + 1]: - right = min(right, np.min( - [a["x0"] - bx["x1"] for a in tbl[i][j + 1]])) - assert left < 100000 or right < 100000 - if left < right: - for jj in range(j, len(tbl[0])): - for i in range(len(tbl)): - for a in tbl[i][jj]: - a["cn"] -= 1 - if tbl[ii][j - 1]: - tbl[ii][j - 1].extend(tbl[ii][j]) - else: - tbl[ii][j - 1] = tbl[ii][j] - for i in range(len(tbl)): - tbl[i].pop(j) - - else: - for jj in range(j + 1, len(tbl[0])): - for i in range(len(tbl)): - for a in tbl[i][jj]: - a["cn"] -= 1 - if tbl[ii][j + 1]: - tbl[ii][j + 1].extend(tbl[ii][j]) - else: - tbl[ii][j + 1] = tbl[ii][j] - for i in range(len(tbl)): - tbl[i].pop(j) - cols.pop(j) - assert len(cols) == len(tbl[0]), "Column NO. 
miss matched: %d vs %d" % ( - len(cols), len(tbl[0])) - - if len(cols) >= 4: - # remove single in row - i = 0 - while i < len(tbl): - e, jj = 0, 0 - for j in range(len(tbl[i])): - if tbl[i][j]: - e += 1 - jj = j - if e > 1: - break - if e > 1: - i += 1 - continue - f = (i > 0 and tbl[i - 1][jj] and tbl[i - 1] - [jj][0].get("text")) or i == 0 - ff = (i + 1 < len(tbl) and tbl[i + 1][jj] and tbl[i + 1] - [jj][0].get("text")) or i + 1 >= len(tbl) - if f and ff: - i += 1 - continue - - bx = tbl[i][jj][0] - logging.debug("Relocate row single: " + bx["text"]) - # i row only has one value - up, down = 100000, 100000 - if i > 0 and not f: - for j in range(len(tbl[i - 1])): - if tbl[i - 1][j]: - up = min(up, np.min( - [bx["top"] - a["bottom"] for a in tbl[i - 1][j]])) - if i + 1 < len(tbl) and not ff: - for j in range(len(tbl[i + 1])): - if tbl[i + 1][j]: - down = min(down, np.min( - [a["top"] - bx["bottom"] for a in tbl[i + 1][j]])) - assert up < 100000 or down < 100000 - if up < down: - for ii in range(i, len(tbl)): - for j in range(len(tbl[ii])): - for a in tbl[ii][j]: - a["rn"] -= 1 - if tbl[i - 1][jj]: - tbl[i - 1][jj].extend(tbl[i][jj]) - else: - tbl[i - 1][jj] = tbl[i][jj] - tbl.pop(i) - - else: - for ii in range(i + 1, len(tbl)): - for j in range(len(tbl[ii])): - for a in tbl[ii][j]: - a["rn"] -= 1 - if tbl[i + 1][jj]: - tbl[i + 1][jj].extend(tbl[i][jj]) - else: - tbl[i + 1][jj] = tbl[i][jj] - tbl.pop(i) - rows.pop(i) - - # which rows are headers - hdset = set([]) - for i in range(len(tbl)): - cnt, h = 0, 0 - for j, arr in enumerate(tbl[i]): - if not arr: - continue - cnt += 1 - if max_type == "Nu" and arr[0]["btype"] == "Nu": - continue - if any([a.get("H") for a in arr]) \ - or (max_type == "Nu" and arr[0]["btype"] != "Nu"): - h += 1 - if h / cnt > 0.5: - hdset.add(i) - - if html: - return [self.__html_table(cap, hdset, - self.__cal_spans(boxes, rows, - cols, tbl, True) - )] - - return self.__desc_table(cap, hdset, - self.__cal_spans(boxes, rows, cols, tbl, False)) - - def __html_table(self, cap, hdset, tbl): - # constrcut HTML - html = "<table>" - if cap: - html += f"<caption>{cap}</caption>" - for i in range(len(tbl)): - row = "<tr>" - txts = [] - for j, arr in enumerate(tbl[i]): - if arr is None: - continue - if not arr: - row += "<td></td>" if i not in hdset else "<th></th>" - continue - txt = "" - if arr: - h = min(np.min([c["bottom"] - c["top"] for c in arr]) / 2, - self.mean_height[arr[0]["page_number"] - 1] / 2) - txt = "".join([c["text"] - for c in self.sort_Y_firstly(arr, h)]) - txts.append(txt) - sp = "" - if arr[0].get("colspan"): - sp = "colspan={}".format(arr[0]["colspan"]) - if arr[0].get("rowspan"): - sp += " rowspan={}".format(arr[0]["rowspan"]) - if i in hdset: - row += f"<th {sp} >" + txt + "</th>" - else: - row += f"<td {sp} >" + txt + "</td>" - - if i in hdset: - if all([t in hdset for t in txts]): - continue - for t in txts: - hdset.add(t) - - if row != "<tr>": - row += "</tr>" - else: - row = "" - html += "\n" + row - html += "\n</table>" - return html - - def __desc_table(self, cap, hdr_rowno, tbl): - # get text of every colomn in header row to become header text - clmno = len(tbl[0]) - rowno = len(tbl) - headers = {} - hdrset = set() - lst_hdr = [] - de = "的" if not self.is_english else " for " - for r in sorted(list(hdr_rowno)): - headers[r] = ["" for _ in range(clmno)] - for i in range(clmno): - if not tbl[r][i]: - continue - txt = "".join([a["text"].strip() for a in tbl[r][i]]) - headers[r][i] = txt - hdrset.add(txt) - if all([not t for t in headers[r]]): - 
del headers[r] - hdr_rowno.remove(r) - continue - for j in range(clmno): - if headers[r][j]: - continue - if j >= len(lst_hdr): - break - headers[r][j] = lst_hdr[j] - lst_hdr = headers[r] - for i in range(rowno): - if i not in hdr_rowno: - continue - for j in range(i + 1, rowno): - if j not in hdr_rowno: - break - for k in range(clmno): - if not headers[j - 1][k]: - continue - if headers[j][k].find(headers[j - 1][k]) >= 0: - continue - if len(headers[j][k]) > len(headers[j - 1][k]): - headers[j][k] += (de if headers[j][k] - else "") + headers[j - 1][k] - else: - headers[j][k] = headers[j - 1][k] \ - + (de if headers[j - 1][k] else "") \ - + headers[j][k] - - logging.debug( - f">>>>>>>>>>>>>>>>>{cap}:SIZE:{rowno}X{clmno} Header: {hdr_rowno}") - row_txt = [] - for i in range(rowno): - if i in hdr_rowno: - continue - rtxt = [] - - def append(delimer): - nonlocal rtxt, row_txt - rtxt = delimer.join(rtxt) - if row_txt and len(row_txt[-1]) + len(rtxt) < 64: - row_txt[-1] += "\n" + rtxt - else: - row_txt.append(rtxt) - - r = 0 - if len(headers.items()): - _arr = [(i - r, r) for r, _ in headers.items() if r < i] - if _arr: - _, r = min(_arr, key=lambda x: x[0]) - - if r not in headers and clmno <= 2: - for j in range(clmno): - if not tbl[i][j]: - continue - txt = "".join([a["text"].strip() for a in tbl[i][j]]) - if txt: - rtxt.append(txt) - if rtxt: - append(":") - continue - - for j in range(clmno): - if not tbl[i][j]: - continue - txt = "".join([a["text"].strip() for a in tbl[i][j]]) - if not txt: - continue - ctt = headers[r][j] if r in headers else "" - if ctt: - ctt += ":" - ctt += txt - if ctt: - rtxt.append(ctt) - - if rtxt: - row_txt.append("; ".join(rtxt)) - - if cap: - if self.is_english: - from_ = " in " - else: - from_ = "来自" - row_txt = [t + f"\t——{from_}“{cap}”" for t in row_txt] - return row_txt - - @staticmethod - def is_caption(bx): - patt = [ - r"[图表]+[ 0-9::]{2,}" - ] - if any([re.match(p, bx["text"].strip()) for p in patt]) \ - or bx["layout_type"].find("caption") >= 0: - return True - return False - def _extract_table_figure(self, need_image, ZM, return_html): tables = {} figures = {} @@ -1415,7 +557,7 @@ class HuParser: continue lout_no = str(self.boxes[i]["page_number"]) + \ "-" + str(self.boxes[i]["layoutno"]) - if self.is_caption(self.boxes[i]) or self.boxes[i]["layout_type"] in ["table caption", "title", + if TableStructureRecognizer.is_caption(self.boxes[i]) or self.boxes[i]["layout_type"] in ["table caption", "title", "figure caption", "reference"]: nomerge_lout_no.append(lst_lout_no) if self.boxes[i]["layout_type"] == "table": @@ -1470,7 +612,7 @@ class HuParser: while i < len(self.boxes): c = self.boxes[i] # mh = self.mean_height[c["page_number"]-1] - if not self.is_caption(c): + if not TableStructureRecognizer.is_caption(c): i += 1 continue @@ -1529,7 +671,7 @@ class HuParser: "bottom": np.max([b["bottom"] for b in bxs]) - ht } louts = [l for l in self.page_layout[pn] if l["type"] == ltype] - ii = self.__find_overlapped(b, louts, naive=True) + ii = Recognizer.find_overlapped(b, louts, naive=True) if ii is not None: b = louts[ii] else: @@ -1581,7 +723,7 @@ class HuParser: if not bxs: continue res.append((cropout(bxs, "table"), - self.__construct_table(bxs, html=return_html))) + self.tbl_det.construct_table(bxs, html=return_html, is_english=self.is_english))) return res diff --git a/deepdoc/vision/__init__.py b/deepdoc/vision/__init__.py new file mode 100644 index 0000000..93eea13 --- /dev/null +++ b/deepdoc/vision/__init__.py @@ -0,0 +1,4 @@ +from .ocr import OCR 
+from .recognizer import Recognizer +from .layout_recognizer import LayoutRecognizer +from .table_structure_recognizer import TableStructureRecognizer diff --git a/deepdoc/vision/layout_recognizer.py b/deepdoc/vision/layout_recognizer.py new file mode 100644 index 0000000..1a5795a --- /dev/null +++ b/deepdoc/vision/layout_recognizer.py @@ -0,0 +1,119 @@ +import os +import re +from collections import Counter +from copy import deepcopy + +import numpy as np + +from api.utils.file_utils import get_project_base_directory +from .recognizer import Recognizer + + +class LayoutRecognizer(Recognizer): + def __init__(self, domain): + self.layout_labels = [ + "_background_", + "Text", + "Title", + "Figure", + "Figure caption", + "Table", + "Table caption", + "Header", + "Footer", + "Reference", + "Equation", + ] + super().__init__(self.layout_labels, domain, + os.path.join(get_project_base_directory(), "rag/res/deepdoc/")) + + def __call__(self, image_list, ocr_res, scale_factor=3, thr=0.7, batch_size=16): + def __is_garbage(b): + patt = [r"^•+$", r"(版权归©|免责条款|地址[::])", r"\.{3,}", "^[0-9]{1,2} / ?[0-9]{1,2}$", + r"^[0-9]{1,2} of [0-9]{1,2}$", "^http://[^ ]{12,}", + "(资料|数据)来源[::]", "[0-9a-z._-]+@[a-z0-9-]+\\.[a-z]{2,3}", + "\\(cid *: *[0-9]+ *\\)" + ] + return any([re.search(p, b["text"]) for p in patt]) + + layouts = super().__call__(image_list, thr, batch_size) + # save_results(image_list, layouts, self.layout_labels, output_dir='output/', threshold=0.7) + assert len(image_list) == len(ocr_res) + # Tag layout type + boxes = [] + assert len(image_list) == len(layouts) + garbages = {} + page_layout = [] + for pn, lts in enumerate(layouts): + bxs = ocr_res[pn] + lts = [{"type": b["type"], + "score": float(b["score"]), + "x0": b["bbox"][0] / scale_factor, "x1": b["bbox"][2] / scale_factor, + "top": b["bbox"][1] / scale_factor, "bottom": b["bbox"][-1] / scale_factor, + "page_number": pn, + } for b in lts] + lts = self.sort_Y_firstly(lts, np.mean([l["bottom"]-l["top"] for l in lts]) / 2) + lts = self.layouts_cleanup(bxs, lts) + page_layout.append(lts) + + # Tag layout type, layouts are ready + def findLayout(ty): + nonlocal bxs, lts, self + lts_ = [lt for lt in lts if lt["type"] == ty] + i = 0 + while i < len(bxs): + if bxs[i].get("layout_type"): + i += 1 + continue + if __is_garbage(bxs[i]): + bxs.pop(i) + continue + + ii = self.find_overlapped_with_threashold(bxs[i], lts_, + thr=0.4) + if ii is None: # belong to nothing + bxs[i]["layout_type"] = "" + i += 1 + continue + lts_[ii]["visited"] = True + if lts_[ii]["type"] in ["footer", "header", "reference"]: + if lts_[ii]["type"] not in garbages: + garbages[lts_[ii]["type"]] = [] + garbages[lts_[ii]["type"]].append(bxs[i]["text"]) + bxs.pop(i) + continue + + bxs[i]["layoutno"] = f"{ty}-{ii}" + bxs[i]["layout_type"] = lts_[ii]["type"] + i += 1 + + for lt in ["footer", "header", "reference", "figure caption", + "table caption", "title", "text", "table", "figure", "equation"]: + findLayout(lt) + + # add box to figure layouts which has not text box + for i, lt in enumerate( + [lt for lt in lts if lt["type"] == "figure"]): + if lt.get("visited"): + continue + lt = deepcopy(lt) + del lt["type"] + lt["text"] = "" + lt["layout_type"] = "figure" + lt["layoutno"] = f"figure-{i}" + bxs.append(lt) + + boxes.extend(bxs) + + ocr_res = boxes + + garbag_set = set() + for k in garbages.keys(): + garbages[k] = Counter(garbages[k]) + for g, c in garbages[k].items(): + if c > 1: + garbag_set.add(g) + + ocr_res = [b for b in ocr_res if b["text"].strip() not in 
garbag_set] + return ocr_res, page_layout + diff --git a/deepdoc/visual/ocr.py b/deepdoc/vision/ocr.py similarity index 99% rename from deepdoc/visual/ocr.py rename to deepdoc/vision/ocr.py index 65b2c2d..09c8a7a 100644 --- a/deepdoc/visual/ocr.py +++ b/deepdoc/vision/ocr.py @@ -74,7 +74,7 @@ class TextRecognizer(object): self.rec_batch_num = 16 postprocess_params = { 'name': 'CTCLabelDecode', - "character_dict_path": os.path.join(get_project_base_directory(), "rag/res", "ocr.res"), + "character_dict_path": os.path.join(os.path.dirname(os.path.realpath(__file__)), "ocr.res"), "use_space_char": True } self.postprocess_op = build_post_process(postprocess_params) @@ -450,7 +450,7 @@ class OCR(object): """ if not model_dir: - model_dir = snapshot_download(repo_id="InfiniFlow/ocr") + model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc") self.text_detector = TextDetector(model_dir) self.text_recognizer = TextRecognizer(model_dir) diff --git a/deepdoc/visual/ocr.res b/deepdoc/vision/ocr.res similarity index 100% rename from deepdoc/visual/ocr.res rename to deepdoc/vision/ocr.res diff --git a/deepdoc/visual/operators.py b/deepdoc/vision/operators.py similarity index 100% rename from deepdoc/visual/operators.py rename to deepdoc/vision/operators.py diff --git a/deepdoc/visual/postprocess.py b/deepdoc/vision/postprocess.py similarity index 100% rename from deepdoc/visual/postprocess.py rename to deepdoc/vision/postprocess.py diff --git a/deepdoc/vision/recognizer.py b/deepdoc/vision/recognizer.py new file mode 100644 index 0000000..9329234 --- /dev/null +++ b/deepdoc/vision/recognizer.py @@ -0,0 +1,327 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +from copy import deepcopy + +import onnxruntime as ort +from huggingface_hub import snapshot_download + +from . import seeit +from .operators import * +from rag.settings import cron_logger + + +class Recognizer(object): + def __init__(self, label_list, task_name, model_dir=None): + """ + If you have trouble downloading HuggingFace models, -_^ this might help!! 
+ + For Linux: + export HF_ENDPOINT=https://hf-mirror.com + + For Windows: + Good luck + ^_- + + """ + if not model_dir: + model_dir = snapshot_download(repo_id="InfiniFlow/ocr") + + model_file_path = os.path.join(model_dir, task_name + ".onnx") + if not os.path.exists(model_file_path): + raise ValueError("not find model file path {}".format( + model_file_path)) + if ort.get_device() == "GPU": + self.ort_sess = ort.InferenceSession(model_file_path, providers=['CUDAExecutionProvider']) + else: + self.ort_sess = ort.InferenceSession(model_file_path, providers=['CPUExecutionProvider']) + self.label_list = label_list + + @staticmethod + def sort_Y_firstly(arr, threashold): + # sort using y1 first and then x1 + arr = sorted(arr, key=lambda r: (r["top"], r["x0"])) + for i in range(len(arr) - 1): + for j in range(i, -1, -1): + # restore the order using th + if abs(arr[j + 1]["top"] - arr[j]["top"]) < threashold \ + and arr[j + 1]["x0"] < arr[j]["x0"]: + tmp = deepcopy(arr[j]) + arr[j] = deepcopy(arr[j + 1]) + arr[j + 1] = deepcopy(tmp) + return arr + + @staticmethod + def sort_X_firstly(arr, threashold, copy=True): + # sort using y1 first and then x1 + arr = sorted(arr, key=lambda r: (r["x0"], r["top"])) + for i in range(len(arr) - 1): + for j in range(i, -1, -1): + # restore the order using th + if abs(arr[j + 1]["x0"] - arr[j]["x0"]) < threashold \ + and arr[j + 1]["top"] < arr[j]["top"]: + tmp = deepcopy(arr[j]) if copy else arr[j] + arr[j] = deepcopy(arr[j + 1]) if copy else arr[j + 1] + arr[j + 1] = deepcopy(tmp) if copy else tmp + return arr + + @staticmethod + def sort_C_firstly(arr, thr=0): + # sort using y1 first and then x1 + # sorted(arr, key=lambda r: (r["x0"], r["top"])) + arr = Recognizer.sort_X_firstly(arr, thr) + for i in range(len(arr) - 1): + for j in range(i, -1, -1): + # restore the order using th + if "C" not in arr[j] or "C" not in arr[j + 1]: + continue + if arr[j + 1]["C"] < arr[j]["C"] \ + or ( + arr[j + 1]["C"] == arr[j]["C"] + and arr[j + 1]["top"] < arr[j]["top"] + ): + tmp = arr[j] + arr[j] = arr[j + 1] + arr[j + 1] = tmp + return arr + + return sorted(arr, key=lambda r: (r.get("C", r["x0"]), r["top"])) + + @staticmethod + def sort_R_firstly(arr, thr=0): + # sort using y1 first and then x1 + # sorted(arr, key=lambda r: (r["top"], r["x0"])) + arr = Recognizer.sort_Y_firstly(arr, thr) + for i in range(len(arr) - 1): + for j in range(i, -1, -1): + if "R" not in arr[j] or "R" not in arr[j + 1]: + continue + if arr[j + 1]["R"] < arr[j]["R"] \ + or ( + arr[j + 1]["R"] == arr[j]["R"] + and arr[j + 1]["x0"] < arr[j]["x0"] + ): + tmp = arr[j] + arr[j] = arr[j + 1] + arr[j + 1] = tmp + return arr + + @staticmethod + def overlapped_area(a, b, ratio=True): + tp, btm, x0, x1 = a["top"], a["bottom"], a["x0"], a["x1"] + if b["x0"] > x1 or b["x1"] < x0: + return 0 + if b["bottom"] < tp or b["top"] > btm: + return 0 + x0_ = max(b["x0"], x0) + x1_ = min(b["x1"], x1) + assert x0_ <= x1_, "Fuckedup! T:{},B:{},X0:{},X1:{} ==> {}".format( + tp, btm, x0, x1, b) + tp_ = max(b["top"], tp) + btm_ = min(b["bottom"], btm) + assert tp_ <= btm_, "Fuckedup! 
T:{},B:{},X0:{},X1:{} => {}".format( + tp, btm, x0, x1, b) + ov = (btm_ - tp_) * (x1_ - x0_) if x1 - \ + x0 != 0 and btm - tp != 0 else 0 + if ov > 0 and ratio: + ov /= (x1 - x0) * (btm - tp) + return ov + + @staticmethod + def layouts_cleanup(boxes, layouts, far=2, thr=0.7): + def notOverlapped(a, b): + return any([a["x1"] < b["x0"], + a["x0"] > b["x1"], + a["bottom"] < b["top"], + a["top"] > b["bottom"]]) + + i = 0 + while i + 1 < len(layouts): + j = i + 1 + while j < min(i + far, len(layouts)) \ + and (layouts[i].get("type", "") != layouts[j].get("type", "") + or notOverlapped(layouts[i], layouts[j])): + j += 1 + if j >= min(i + far, len(layouts)): + i += 1 + continue + if Recognizer.overlapped_area(layouts[i], layouts[j]) < thr \ + and Recognizer.overlapped_area(layouts[j], layouts[i]) < thr: + i += 1 + continue + + if layouts[i].get("score") and layouts[j].get("score"): + if layouts[i]["score"] > layouts[j]["score"]: + layouts.pop(j) + else: + layouts.pop(i) + continue + + area_i, area_i_1 = 0, 0 + for b in boxes: + if not notOverlapped(b, layouts[i]): + area_i += Recognizer.overlapped_area(b, layouts[i], False) + if not notOverlapped(b, layouts[j]): + area_i_1 += Recognizer.overlapped_area(b, layouts[j], False) + + if area_i > area_i_1: + layouts.pop(j) + else: + layouts.pop(i) + + return layouts + + def create_inputs(self, imgs, im_info): + """generate input for different model type + Args: + imgs (list(numpy)): list of images (np.ndarray) + im_info (list(dict)): list of image info + Returns: + inputs (dict): input of model + """ + inputs = {} + + im_shape = [] + scale_factor = [] + if len(imgs) == 1: + inputs['image'] = np.array((imgs[0],)).astype('float32') + inputs['im_shape'] = np.array( + (im_info[0]['im_shape'],)).astype('float32') + inputs['scale_factor'] = np.array( + (im_info[0]['scale_factor'],)).astype('float32') + return inputs + + for e in im_info: + im_shape.append(np.array((e['im_shape'],)).astype('float32')) + scale_factor.append(np.array((e['scale_factor'],)).astype('float32')) + + inputs['im_shape'] = np.concatenate(im_shape, axis=0) + inputs['scale_factor'] = np.concatenate(scale_factor, axis=0) + + imgs_shape = [[e.shape[1], e.shape[2]] for e in imgs] + max_shape_h = max([e[0] for e in imgs_shape]) + max_shape_w = max([e[1] for e in imgs_shape]) + padding_imgs = [] + for img in imgs: + im_c, im_h, im_w = img.shape[:] + padding_im = np.zeros( + (im_c, max_shape_h, max_shape_w), dtype=np.float32) + padding_im[:, :im_h, :im_w] = img + padding_imgs.append(padding_im) + inputs['image'] = np.stack(padding_imgs, axis=0) + return inputs + + @staticmethod + def find_overlapped(box, boxes_sorted_by_y, naive=False): + if not boxes_sorted_by_y: + return + bxs = boxes_sorted_by_y + s, e, ii = 0, len(bxs), 0 + while s < e and not naive: + ii = (e + s) // 2 + pv = bxs[ii] + if box["bottom"] < pv["top"]: + e = ii + continue + if box["top"] > pv["bottom"]: + s = ii + 1 + continue + break + while s < ii: + if box["top"] > bxs[s]["bottom"]: + s += 1 + break + while e - 1 > ii: + if box["bottom"] < bxs[e - 1]["top"]: + e -= 1 + break + + max_overlaped_i, max_overlaped = None, 0 + for i in range(s, e): + ov = Recognizer.overlapped_area(bxs[i], box) + if ov <= max_overlaped: + continue + max_overlaped_i = i + max_overlaped = ov + + return max_overlaped_i + + @staticmethod + def find_overlapped_with_threashold(box, boxes, thr=0.3): + if not boxes: + return + max_overlaped_i, max_overlaped, _max_overlaped = None, thr, 0 + s, e = 0, len(boxes) + for i in range(s, e): + ov = 
Recognizer.overlapped_area(box, boxes[i]) + _ov = Recognizer.overlapped_area(boxes[i], box) + if (ov, _ov) < (max_overlaped, _max_overlaped): + continue + max_overlaped_i = i + max_overlaped = ov + _max_overlaped = _ov + + return max_overlaped_i + + def preprocess(self, image_list): + preprocess_ops = [] + for op_info in [ + {'interp': 2, 'keep_ratio': False, 'target_size': [800, 608], 'type': 'LinearResize'}, + {'is_scale': True, 'mean': [0.485, 0.456, 0.406], 'std': [0.229, 0.224, 0.225], 'type': 'StandardizeImage'}, + {'type': 'Permute'}, + {'stride': 32, 'type': 'PadStride'} + ]: + new_op_info = op_info.copy() + op_type = new_op_info.pop('type') + preprocess_ops.append(eval(op_type)(**new_op_info)) + + inputs = [] + for im_path in image_list: + im, im_info = preprocess(im_path, preprocess_ops) + inputs.append({"image": np.array((im,)).astype('float32'), "scale_factor": np.array((im_info["scale_factor"],)).astype('float32')}) + return inputs + + def __call__(self, image_list, thr=0.7, batch_size=16): + res = [] + imgs = [] + for i in range(len(image_list)): + if not isinstance(image_list[i], np.ndarray): + imgs.append(np.array(image_list[i])) + else: imgs.append(image_list[i]) + + batch_loop_cnt = math.ceil(float(len(imgs)) / batch_size) + for i in range(batch_loop_cnt): + start_index = i * batch_size + end_index = min((i + 1) * batch_size, len(imgs)) + batch_image_list = imgs[start_index:end_index] + inputs = self.preprocess(batch_image_list) + for ins in inputs: + bb = [] + for b in self.ort_sess.run(None, ins)[0]: + clsid, bbox, score = int(b[0]), b[2:], b[1] + if score < thr: + continue + if clsid >= len(self.label_list): + cron_logger.warning(f"bad category id") + continue + bb.append({ + "type": self.label_list[clsid].lower(), + "bbox": [float(t) for t in bbox.tolist()], + "score": float(score) + }) + res.append(bb) + + #seeit.save_results(image_list, res, self.label_list, threshold=thr) + + return res diff --git a/deepdoc/visual/seeit.py b/deepdoc/vision/seeit.py similarity index 100% rename from deepdoc/visual/seeit.py rename to deepdoc/vision/seeit.py diff --git a/deepdoc/vision/table_structure_recognizer.py b/deepdoc/vision/table_structure_recognizer.py new file mode 100644 index 0000000..40366b1 --- /dev/null +++ b/deepdoc/vision/table_structure_recognizer.py @@ -0,0 +1,556 @@ +import logging +import os +import re +from collections import Counter +from copy import deepcopy + +import numpy as np + +from api.utils.file_utils import get_project_base_directory +from rag.nlp import huqie +from .recognizer import Recognizer + + +class TableStructureRecognizer(Recognizer): + def __init__(self): + self.labels = [ + "table", + "table column", + "table row", + "table column header", + "table projected row header", + "table spanning cell", + ] + super().__init__(self.labels, "tsr", + os.path.join(get_project_base_directory(), "rag/res/deepdoc/")) + + def __call__(self, images, thr=0.5): + tbls = super().__call__(images, thr) + res = [] + # align left&right for rows, align top&bottom for columns + for tbl in tbls: + lts = [{"label": b["type"], + "score": b["score"], + "x0": b["bbox"][0], "x1": b["bbox"][2], + "top": b["bbox"][1], "bottom": b["bbox"][-1] + } for b in tbl] + if not lts: + continue + + left = [b["x0"] for b in lts if b["label"].find( + "row") > 0 or b["label"].find("header") > 0] + right = [b["x1"] for b in lts if b["label"].find( + "row") > 0 or b["label"].find("header") > 0] + if not left: + continue + left = np.median(left) if len(left) > 4 else np.min(left) + right = 
np.median(right) if len(right) > 4 else np.max(right) + for b in lts: + if b["label"].find("row") > 0 or b["label"].find("header") > 0: + if b["x0"] > left: + b["x0"] = left + if b["x1"] < right: + b["x1"] = right + + top = [b["top"] for b in lts if b["label"] == "table column"] + bottom = [b["bottom"] for b in lts if b["label"] == "table column"] + if not top: + res.append(lts) + continue + top = np.median(top) if len(top) > 4 else np.min(top) + bottom = np.median(bottom) if len(bottom) > 4 else np.max(bottom) + for b in lts: + if b["label"] == "table column": + if b["top"] > top: + b["top"] = top + if b["bottom"] < bottom: + b["bottom"] = bottom + + res.append(lts) + return res + + @staticmethod + def is_caption(bx): + patt = [ + r"[图表]+[ 0-9::]{2,}" + ] + if any([re.match(p, bx["text"].strip()) for p in patt]) \ + or bx["layout_type"].find("caption") >= 0: + return True + return False + + def __blockType(self, b): + patt = [ + ("^(20|19)[0-9]{2}[年/-][0-9]{1,2}[月/-][0-9]{1,2}日*$", "Dt"), + (r"^(20|19)[0-9]{2}年$", "Dt"), + (r"^(20|19)[0-9]{2}[年-][0-9]{1,2}月*$", "Dt"), + ("^[0-9]{1,2}[月-][0-9]{1,2}日*$", "Dt"), + (r"^第*[一二三四1-4]季度$", "Dt"), + (r"^(20|19)[0-9]{2}年*[一二三四1-4]季度$", "Dt"), + (r"^(20|19)[0-9]{2}[ABCDE]$", "Dt"), + ("^[0-9.,+%/ -]+$", "Nu"), + (r"^[0-9A-Z/\._~-]+$", "Ca"), + (r"^[A-Z]*[a-z' -]+$", "En"), + (r"^[0-9.,+-]+[0-9A-Za-z/$￥%<>()()' -]+$", "NE"), + (r"^.{1}$", "Sg") + ] + for p, n in patt: + if re.search(p, b["text"].strip()): + return n + tks = [t for t in huqie.qie(b["text"]).split(" ") if len(t) > 1] + if len(tks) > 3: + if len(tks) < 12: + return "Tx" + else: + return "Lx" + + if len(tks) == 1 and huqie.tag(tks[0]) == "nr": + return "Nr" + + return "Ot" + + def construct_table(self, boxes, is_english=False, html=False): + cap = "" + i = 0 + while i < len(boxes): + if self.is_caption(boxes[i]): + cap += boxes[i]["text"] + boxes.pop(i) + i -= 1 + i += 1 + + if not boxes: + return [] + for b in boxes: + b["btype"] = self.__blockType(b) + max_type = Counter([b["btype"] for b in boxes]).items() + max_type = max(max_type, key=lambda x: x[1])[0] if max_type else "" + logging.debug("MAXTYPE: " + max_type) + + rowh = [b["R_bott"] - b["R_top"] for b in boxes if "R" in b] + rowh = np.min(rowh) if rowh else 0 + boxes = self.sort_R_firstly(boxes, rowh / 2) + boxes[0]["rn"] = 0 + rows = [[boxes[0]]] + btm = boxes[0]["bottom"] + for b in boxes[1:]: + b["rn"] = len(rows) - 1 + lst_r = rows[-1] + if lst_r[-1].get("R", "") != b.get("R", "") \ + or (b["top"] >= btm - 3 and lst_r[-1].get("R", "-1") != b.get("R", "-2") + ): # new row + btm = b["bottom"] + b["rn"] += 1 + rows.append([b]) + continue + btm = (btm + b["bottom"]) / 2. + rows[-1].append(b) + + colwm = [b["C_right"] - b["C_left"] for b in boxes if "C" in b] + colwm = np.min(colwm) if colwm else 0 + crosspage = len(set([b["page_number"] for b in boxes])) > 1 + if crosspage: + boxes = self.sort_X_firstly(boxes, colwm / 2, False) + else: + boxes = self.sort_C_firstly(boxes, colwm / 2) + boxes[0]["cn"] = 0 + cols = [[boxes[0]]] + right = boxes[0]["x1"] + for b in boxes[1:]: + b["cn"] = len(cols) - 1 + lst_c = cols[-1] + if (int(b.get("C", "1")) - int(lst_c[-1].get("C", "1")) == 1 and b["page_number"] == lst_c[-1][ + "page_number"]) \ + or (b["x0"] >= right and lst_c[-1].get("C", "-1") != b.get("C", "-2")): # new col + right = b["x1"] + b["cn"] += 1 + cols.append([b]) + continue + right = (right + b["x1"]) / 2. 
+ cols[-1].append(b) + + tbl = [[[] for _ in range(len(cols))] for _ in range(len(rows))] + for b in boxes: + tbl[b["rn"]][b["cn"]].append(b) + + if len(rows) >= 4: + # remove single in column + j = 0 + while j < len(tbl[0]): + e, ii = 0, 0 + for i in range(len(tbl)): + if tbl[i][j]: + e += 1 + ii = i + if e > 1: + break + if e > 1: + j += 1 + continue + f = (j > 0 and tbl[ii][j - 1] and tbl[ii] + [j - 1][0].get("text")) or j == 0 + ff = (j + 1 < len(tbl[ii]) and tbl[ii][j + 1] and tbl[ii] + [j + 1][0].get("text")) or j + 1 >= len(tbl[ii]) + if f and ff: + j += 1 + continue + bx = tbl[ii][j][0] + logging.debug("Relocate column single: " + bx["text"]) + # j column only has one value + left, right = 100000, 100000 + if j > 0 and not f: + for i in range(len(tbl)): + if tbl[i][j - 1]: + left = min(left, np.min( + [bx["x0"] - a["x1"] for a in tbl[i][j - 1]])) + if j + 1 < len(tbl[0]) and not ff: + for i in range(len(tbl)): + if tbl[i][j + 1]: + right = min(right, np.min( + [a["x0"] - bx["x1"] for a in tbl[i][j + 1]])) + assert left < 100000 or right < 100000 + if left < right: + for jj in range(j, len(tbl[0])): + for i in range(len(tbl)): + for a in tbl[i][jj]: + a["cn"] -= 1 + if tbl[ii][j - 1]: + tbl[ii][j - 1].extend(tbl[ii][j]) + else: + tbl[ii][j - 1] = tbl[ii][j] + for i in range(len(tbl)): + tbl[i].pop(j) + + else: + for jj in range(j + 1, len(tbl[0])): + for i in range(len(tbl)): + for a in tbl[i][jj]: + a["cn"] -= 1 + if tbl[ii][j + 1]: + tbl[ii][j + 1].extend(tbl[ii][j]) + else: + tbl[ii][j + 1] = tbl[ii][j] + for i in range(len(tbl)): + tbl[i].pop(j) + cols.pop(j) + assert len(cols) == len(tbl[0]), "Column NO. miss matched: %d vs %d" % ( + len(cols), len(tbl[0])) + + if len(cols) >= 4: + # remove single in row + i = 0 + while i < len(tbl): + e, jj = 0, 0 + for j in range(len(tbl[i])): + if tbl[i][j]: + e += 1 + jj = j + if e > 1: + break + if e > 1: + i += 1 + continue + f = (i > 0 and tbl[i - 1][jj] and tbl[i - 1] + [jj][0].get("text")) or i == 0 + ff = (i + 1 < len(tbl) and tbl[i + 1][jj] and tbl[i + 1] + [jj][0].get("text")) or i + 1 >= len(tbl) + if f and ff: + i += 1 + continue + + bx = tbl[i][jj][0] + logging.debug("Relocate row single: " + bx["text"]) + # i row only has one value + up, down = 100000, 100000 + if i > 0 and not f: + for j in range(len(tbl[i - 1])): + if tbl[i - 1][j]: + up = min(up, np.min( + [bx["top"] - a["bottom"] for a in tbl[i - 1][j]])) + if i + 1 < len(tbl) and not ff: + for j in range(len(tbl[i + 1])): + if tbl[i + 1][j]: + down = min(down, np.min( + [a["top"] - bx["bottom"] for a in tbl[i + 1][j]])) + assert up < 100000 or down < 100000 + if up < down: + for ii in range(i, len(tbl)): + for j in range(len(tbl[ii])): + for a in tbl[ii][j]: + a["rn"] -= 1 + if tbl[i - 1][jj]: + tbl[i - 1][jj].extend(tbl[i][jj]) + else: + tbl[i - 1][jj] = tbl[i][jj] + tbl.pop(i) + + else: + for ii in range(i + 1, len(tbl)): + for j in range(len(tbl[ii])): + for a in tbl[ii][j]: + a["rn"] -= 1 + if tbl[i + 1][jj]: + tbl[i + 1][jj].extend(tbl[i][jj]) + else: + tbl[i + 1][jj] = tbl[i][jj] + tbl.pop(i) + rows.pop(i) + + # which rows are headers + hdset = set([]) + for i in range(len(tbl)): + cnt, h = 0, 0 + for j, arr in enumerate(tbl[i]): + if not arr: + continue + cnt += 1 + if max_type == "Nu" and arr[0]["btype"] == "Nu": + continue + if any([a.get("H") for a in arr]) \ + or (max_type == "Nu" and arr[0]["btype"] != "Nu"): + h += 1 + if h / cnt > 0.5: + hdset.add(i) + + if html: + return [self.__html_table(cap, hdset, + self.__cal_spans(boxes, rows, + cols, tbl, True) + )] + 
+    def __html_table(self, cap, hdset, tbl):
+        # construct HTML
+        html = "<table>"
+        if cap:
+            html += f"<caption>{cap}</caption>"
+        for i in range(len(tbl)):
+            row = "<tr>"
+            txts = []
+            for j, arr in enumerate(tbl[i]):
+                if arr is None:
+                    continue
+                if not arr:
+                    row += "<td></td>" if i not in hdset else "<th></th>"
+                    continue
+                txt = ""
+                if arr:
+                    h = min(np.min([c["bottom"] - c["top"] for c in arr]) / 2, 10)
+                    txt = "".join([c["text"]
+                                   for c in self.sort_Y_firstly(arr, h)])
+                txts.append(txt)
+                sp = ""
+                if arr[0].get("colspan"):
+                    sp = "colspan={}".format(arr[0]["colspan"])
+                if arr[0].get("rowspan"):
+                    sp += " rowspan={}".format(arr[0]["rowspan"])
+                if i in hdset:
+                    row += f"<th {sp} >" + txt + "</th>"
+                else:
+                    row += f"<td {sp} >" + txt + "</td>"
+
+            if i in hdset:
+                # skip header rows whose texts have all been emitted already
+                if all([t in hdset for t in txts]):
+                    continue
+                for t in txts:
+                    hdset.add(t)
+
+            if row != "<tr>":
+                row += "</tr>"
+            else:
+                row = ""
+            html += "\n" + row
+        html += "\n</table>"
+        return html
+
+    def __desc_table(self, cap, hdr_rowno, tbl, is_english):
+        # get the text of every column in the header rows to build header text
+        clmno = len(tbl[0])
+        rowno = len(tbl)
+        headers = {}
+        hdrset = set()
+        lst_hdr = []
+        # "的" joins stacked header levels for non-English (Chinese) tables
+        de = "的" if not is_english else " for "
+        for r in sorted(list(hdr_rowno)):
+            headers[r] = ["" for _ in range(clmno)]
+            for i in range(clmno):
+                if not tbl[r][i]:
+                    continue
+                txt = "".join([a["text"].strip() for a in tbl[r][i]])
+                headers[r][i] = txt
+                hdrset.add(txt)
+            if all([not t for t in headers[r]]):
+                del headers[r]
+                hdr_rowno.remove(r)
+                continue
+            for j in range(clmno):
+                if headers[r][j]:
+                    continue
+                if j >= len(lst_hdr):
+                    break
+                headers[r][j] = lst_hdr[j]
+            lst_hdr = headers[r]
+        for i in range(rowno):
+            if i not in hdr_rowno:
+                continue
+            for j in range(i + 1, rowno):
+                if j not in hdr_rowno:
+                    break
+                for k in range(clmno):
+                    if not headers[j - 1][k]:
+                        continue
+                    if headers[j][k].find(headers[j - 1][k]) >= 0:
+                        continue
+                    if len(headers[j][k]) > len(headers[j - 1][k]):
+                        headers[j][k] += (de if headers[j][k]
+                                          else "") + headers[j - 1][k]
+                    else:
+                        headers[j][k] = headers[j - 1][k] \
+                            + (de if headers[j - 1][k] else "") \
+                            + headers[j][k]
+
+        logging.debug(
+            f">>>>>>>>>>>>>>>>>{cap}:SIZE:{rowno}X{clmno} Header: {hdr_rowno}")
+        row_txt = []
+        for i in range(rowno):
+            if i in hdr_rowno:
+                continue
+            rtxt = []
+
+            def append(delimiter):
+                nonlocal rtxt, row_txt
+                rtxt = delimiter.join(rtxt)
+                if row_txt and len(row_txt[-1]) + len(rtxt) < 64:
+                    row_txt[-1] += "\n" + rtxt
+                else:
+                    row_txt.append(rtxt)
+
+            r = 0
+            if len(headers.items()):
+                _arr = [(i - r, r) for r, _ in headers.items() if r < i]
+                if _arr:
+                    _, r = min(_arr, key=lambda x: x[0])
+
+            if r not in headers and clmno <= 2:
+                for j in range(clmno):
+                    if not tbl[i][j]:
+                        continue
+                    txt = "".join([a["text"].strip() for a in tbl[i][j]])
+                    if txt:
+                        rtxt.append(txt)
+                if rtxt:
+                    append(":")
+                continue
+
+            for j in range(clmno):
+                if not tbl[i][j]:
+                    continue
+                txt = "".join([a["text"].strip() for a in tbl[i][j]])
+                if not txt:
+                    continue
+                ctt = headers[r][j] if r in headers else ""
+                if ctt:
+                    ctt += ":"
+                ctt += txt
+                if ctt:
+                    rtxt.append(ctt)
+
+            if rtxt:
+                row_txt.append("; ".join(rtxt))
+
+        if cap:
+            if is_english:
+                from_ = " in "
+            else:
+                from_ = "来自"
+            row_txt = [t + f"\t——{from_}“{cap}”" for t in row_txt]
+        return row_txt
+
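To make the __desc_table output concrete: each body row becomes "header:value" pairs joined by "; ", with an optional caption suffix. The toy flatten_rows below is a hypothetical simplification over plain strings (one header row, no box dicts); only the pairing, joining, and caption-suffix conventions come from the code above.

    def flatten_rows(headers, rows, cap=None, is_english=True):
        out = []
        for row in rows:
            # prefix each non-empty cell with its column header
            cells = [f"{h}:{v}" for h, v in zip(headers, row) if v]
            if cells:
                out.append("; ".join(cells))
        if cap:
            from_ = " in " if is_english else "来自"
            out = [t + f"\t——{from_}“{cap}”" for t in out]
        return out

    # flatten_rows(["Name", "Age"], [["Alice", "30"], ["Bob", ""]], cap="Users")
    # -> ['Name:Alice; Age:30\t—— in “Users”', 'Name:Bob\t—— in “Users”']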
+    def __cal_spans(self, boxes, rows, cols, tbl, html=True):
+        # calculate spans
+        clft = [np.mean([c.get("C_left", c["x0"]) for c in cln])
+                for cln in cols]
+        crgt = [np.mean([c.get("C_right", c["x1"]) for c in cln])
+                for cln in cols]
+        rtop = [np.mean([c.get("R_top", c["top"]) for c in row])
+                for row in rows]
+        rbtm = [np.mean([c.get("R_btm", c["bottom"])
+                         for c in row]) for row in rows]
+        for b in boxes:
+            if "SP" not in b:
+                continue
+            b["colspan"] = [b["cn"]]
+            b["rowspan"] = [b["rn"]]
+            # col span
+            for j in range(0, len(clft)):
+                if j == b["cn"]:
+                    continue
+                if clft[j] + (crgt[j] - clft[j]) / 2 < b["H_left"]:
+                    continue
+                if crgt[j] - (crgt[j] - clft[j]) / 2 > b["H_right"]:
+                    continue
+                b["colspan"].append(j)
+            # row span
+            for j in range(0, len(rtop)):
+                if j == b["rn"]:
+                    continue
+                if rtop[j] + (rbtm[j] - rtop[j]) / 2 < b["H_top"]:
+                    continue
+                if rbtm[j] - (rbtm[j] - rtop[j]) / 2 > b["H_bott"]:
+                    continue
+                b["rowspan"].append(j)
+
+        def join(arr):
+            if not arr:
+                return ""
+            return "".join([t["text"] for t in arr])
+
+        # remove the spanning cells: merge their contents into the top-left cell
+        for i in range(len(tbl)):
+            for j, arr in enumerate(tbl[i]):
+                if not arr:
+                    continue
+                if all(["rowspan" not in a and "colspan" not in a for a in arr]):
+                    continue
+                rowspan, colspan = [], []
+                for a in arr:
+                    if isinstance(a.get("rowspan", 0), list):
+                        rowspan.extend(a["rowspan"])
+                    if isinstance(a.get("colspan", 0), list):
+                        colspan.extend(a["colspan"])
+                rowspan, colspan = set(rowspan), set(colspan)
+                if len(rowspan) < 2 and len(colspan) < 2:
+                    for a in arr:
+                        if "rowspan" in a:
+                            del a["rowspan"]
+                        if "colspan" in a:
+                            del a["colspan"]
+                    continue
+                rowspan, colspan = sorted(rowspan), sorted(colspan)
+                rowspan = list(range(rowspan[0], rowspan[-1] + 1))
+                colspan = list(range(colspan[0], colspan[-1] + 1))
+                assert i in rowspan, rowspan
+                assert j in colspan, colspan
+                arr = []
+                for r in rowspan:
+                    for c in colspan:
+                        arr_txt = join(arr)
+                        if tbl[r][c] and join(tbl[r][c]) != arr_txt:
+                            arr.extend(tbl[r][c])
+                        tbl[r][c] = None if html else arr
+                for a in arr:
+                    if len(rowspan) > 1:
+                        a["rowspan"] = len(rowspan)
+                    elif "rowspan" in a:
+                        del a["rowspan"]
+                    if len(colspan) > 1:
+                        a["colspan"] = len(colspan)
+                    elif "colspan" in a:
+                        del a["colspan"]
+                tbl[rowspan[0]][colspan[0]] = arr
+
+        return tbl
+
diff --git a/deepdoc/visual/__init__.py b/deepdoc/visual/__init__.py
deleted file mode 100644
index e53762a..0000000
--- a/deepdoc/visual/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-from .ocr import OCR
-from .recognizer import Recognizer
\ No newline at end of file
diff --git a/deepdoc/visual/recognizer.py b/deepdoc/visual/recognizer.py
deleted file mode 100644
index 09ccbb3..0000000
--- a/deepdoc/visual/recognizer.py
+++ /dev/null
@@ -1,139 +0,0 @@
-#  Licensed under the Apache License, Version 2.0 (the "License");
-#  you may not use this file except in compliance with the License.
-#  You may obtain a copy of the License at
-#
-#      http://www.apache.org/licenses/LICENSE-2.0
-#
-#  Unless required by applicable law or agreed to in writing, software
-#  distributed under the License is distributed on an "AS IS" BASIS,
-#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-#  See the License for the specific language governing permissions and
-#  limitations under the License.
-#
-
-import os
-import onnxruntime as ort
-from huggingface_hub import snapshot_download
-
-from .operators import *
-from rag.settings import cron_logger
-
-
-class Recognizer(object):
-    def __init__(self, label_list, task_name, model_dir=None):
-        """
-        If you have trouble downloading HuggingFace models, -_^ this might help!!
- - For Linux: - export HF_ENDPOINT=https://hf-mirror.com - - For Windows: - Good luck - ^_- - - """ - if not model_dir: - model_dir = snapshot_download(repo_id="InfiniFlow/ocr") - - model_file_path = os.path.join(model_dir, task_name + ".onnx") - if not os.path.exists(model_file_path): - raise ValueError("not find model file path {}".format( - model_file_path)) - if ort.get_device() == "GPU": - self.ort_sess = ort.InferenceSession(model_file_path, providers=['CUDAExecutionProvider']) - else: - self.ort_sess = ort.InferenceSession(model_file_path, providers=['CPUExecutionProvider']) - self.label_list = label_list - - def create_inputs(self, imgs, im_info): - """generate input for different model type - Args: - imgs (list(numpy)): list of images (np.ndarray) - im_info (list(dict)): list of image info - Returns: - inputs (dict): input of model - """ - inputs = {} - - im_shape = [] - scale_factor = [] - if len(imgs) == 1: - inputs['image'] = np.array((imgs[0],)).astype('float32') - inputs['im_shape'] = np.array( - (im_info[0]['im_shape'],)).astype('float32') - inputs['scale_factor'] = np.array( - (im_info[0]['scale_factor'],)).astype('float32') - return inputs - - for e in im_info: - im_shape.append(np.array((e['im_shape'],)).astype('float32')) - scale_factor.append(np.array((e['scale_factor'],)).astype('float32')) - - inputs['im_shape'] = np.concatenate(im_shape, axis=0) - inputs['scale_factor'] = np.concatenate(scale_factor, axis=0) - - imgs_shape = [[e.shape[1], e.shape[2]] for e in imgs] - max_shape_h = max([e[0] for e in imgs_shape]) - max_shape_w = max([e[1] for e in imgs_shape]) - padding_imgs = [] - for img in imgs: - im_c, im_h, im_w = img.shape[:] - padding_im = np.zeros( - (im_c, max_shape_h, max_shape_w), dtype=np.float32) - padding_im[:, :im_h, :im_w] = img - padding_imgs.append(padding_im) - inputs['image'] = np.stack(padding_imgs, axis=0) - return inputs - - def preprocess(self, image_list): - preprocess_ops = [] - for op_info in [ - {'interp': 2, 'keep_ratio': False, 'target_size': [800, 608], 'type': 'LinearResize'}, - {'is_scale': True, 'mean': [0.485, 0.456, 0.406], 'std': [0.229, 0.224, 0.225], 'type': 'StandardizeImage'}, - {'type': 'Permute'}, - {'stride': 32, 'type': 'PadStride'} - ]: - new_op_info = op_info.copy() - op_type = new_op_info.pop('type') - preprocess_ops.append(eval(op_type)(**new_op_info)) - - inputs = [] - for im_path in image_list: - im, im_info = preprocess(im_path, preprocess_ops) - inputs.append({"image": np.array((im,)).astype('float32'), "scale_factor": np.array((im_info["scale_factor"],)).astype('float32')}) - return inputs - - - def __call__(self, image_list, thr=0.7, batch_size=16): - res = [] - imgs = [] - for i in range(len(image_list)): - if not isinstance(image_list[i], np.ndarray): - imgs.append(np.array(image_list[i])) - else: imgs.append(image_list[i]) - - batch_loop_cnt = math.ceil(float(len(imgs)) / batch_size) - for i in range(batch_loop_cnt): - start_index = i * batch_size - end_index = min((i + 1) * batch_size, len(imgs)) - batch_image_list = imgs[start_index:end_index] - inputs = self.preprocess(batch_image_list) - for ins in inputs: - bb = [] - for b in self.ort_sess.run(None, ins)[0]: - clsid, bbox, score = int(b[0]), b[2:], b[1] - if score < thr: - continue - if clsid >= len(self.label_list): - cron_logger.warning(f"bad category id") - continue - bb.append({ - "type": self.label_list[clsid].lower(), - "bbox": [float(t) for t in bbox.tolist()], - "score": float(score) - }) - res.append(bb) - - #seeit.save_results(image_list, res, 
self.label_list, threshold=thr) - - return res diff --git a/rag/svr/task_broker.py b/rag/svr/task_broker.py index 1204713..cd08b9f 100644 --- a/rag/svr/task_broker.py +++ b/rag/svr/task_broker.py @@ -21,7 +21,7 @@ from datetime import datetime from api.db.db_models import Task from api.db.db_utils import bulk_insert_into_db from api.db.services.task_service import TaskService -from deepdoc.parser import HuParser +from deepdoc.parser import PdfParser from rag.settings import cron_logger from rag.utils import MINIO from rag.utils import findMaxTm @@ -80,7 +80,7 @@ def dispatch(): tsks = [] if r["type"] == FileType.PDF.value: - pages = HuParser.total_page_number(r["name"], MINIO.get(r["kb_id"], r["location"])) + pages = PdfParser.total_page_number(r["name"], MINIO.get(r["kb_id"], r["location"])) for s,e in r["parser_config"].get("pages", [(0,100000)]): e = min(e, pages) for p in range(s, e, 10): -- GitLab
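
After the HuParser -> PdfParser rename, dispatch() still splits each PDF into tasks of at most ten pages via the loop above. A quick sketch of that chunking arithmetic follows; page_chunks and the from_page/to_page keys are illustrative assumptions, while only range(s, e, 10) and the min(e, pages) clamp appear in the hunk.

    def page_chunks(pages, ranges=((0, 100000),), step=10):
        """Split a page count into (from, to) windows, mirroring dispatch()."""
        tasks = []
        for s, e in ranges:
            e = min(e, pages)  # clamp the configured range to the real page count
            for p in range(s, e, step):
                tasks.append({"from_page": p, "to_page": min(e, p + step)})
        return tasks

    # page_chunks(23) -> [{'from_page': 0, 'to_page': 10},
    #                     {'from_page': 10, 'to_page': 20},
    #                     {'from_page': 20, 'to_page': 23}]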