From fd7fcb5baff9b9bbf0324056715f3a2591ae9967 Mon Sep 17 00:00:00 2001
From: KevinHuSh <kevinhu.sh@gmail.com>
Date: Wed, 27 Mar 2024 11:33:46 +0800
Subject: [PATCH] apply pep8 formalize (#155)

---
 api/apps/chunk_app.py                        |  19 +-
 api/apps/conversation_app.py                 | 124 +++++---
 api/apps/dialog_app.py                       |  43 ++-
 api/apps/document_app.py                     |  29 +-
 api/apps/kb_app.py                           |  56 +++-
 api/apps/llm_app.py                          |  38 ++-
 api/apps/user_app.py                         |  86 ++++--
 api/db/db_models.py                          | 307 +++++++++++++++----
 api/db/db_utils.py                           |  15 +-
 api/db/init_data.py                          | 100 +++---
 api/db/operatioins.py                        |   2 +-
 api/db/reload_config_base.py                 |   5 +-
 api/db/runtime_config.py                     |   2 +-
 api/db/services/common_service.py            |  38 ++-
 api/db/services/dialog_service.py            |   1 -
 api/db/services/document_service.py          |  61 +++-
 api/db/services/knowledgebase_service.py     |  15 +-
 api/db/services/llm_service.py               |  33 +-
 api/db/services/user_service.py              |  45 ++-
 api/settings.py                              |  54 +++-
 api/utils/__init__.py                        |  53 +++-
 api/utils/api_utils.py                       |  60 +++-
 api/utils/file_utils.py                      |  23 +-
 api/utils/log_utils.py                       |  60 ++--
 api/utils/t_crypt.py                         |  15 +-
 deepdoc/parser/__init__.py                   |   2 -
 deepdoc/parser/docx_parser.py                |   9 +-
 deepdoc/parser/excel_parser.py               |  19 +-
 deepdoc/parser/pdf_parser.py                 | 158 ++++++----
 deepdoc/parser/ppt_parser.py                 |  20 +-
 deepdoc/vision/layout_recognizer.py          |  53 ++--
 deepdoc/vision/operators.py                  |   3 +-
 deepdoc/vision/t_ocr.py                      |  34 +-
 deepdoc/vision/t_recognizer.py               |  51 +--
 deepdoc/vision/table_structure_recognizer.py |  23 +-
 rag/app/book.py                              |  54 ++--
 rag/app/laws.py                              |  44 ++-
 rag/app/manual.py                            |  49 +--
 rag/app/naive.py                             |  33 +-
 rag/app/one.py                               |  26 +-
 rag/app/paper.py                             |  40 ++-
 rag/app/presentation.py                      |  64 ++--
 rag/app/resume.py                            |  25 +-
 rag/app/table.py                             |  42 ++-
 rag/llm/chat_model.py                        |  37 ++-
 rag/llm/cv_model.py                          |  17 +-
 rag/llm/embedding_model.py                   |  29 +-
 rag/llm/rpc_server.py                        |  27 +-
 rag/nlp/huchunk.py                           |  19 +-
 rag/nlp/query.py                             |  11 +-
 rag/nlp/search.py                            |  61 ++--
 rag/nlp/term_weight.py                       |   6 +-
 rag/settings.py                              |   9 +-
 rag/svr/task_broker.py                       |  64 ++--
 rag/svr/task_executor.py                     |  18 +-
 55 files changed, 1573 insertions(+), 758 deletions(-)

diff --git a/api/apps/chunk_app.py b/api/apps/chunk_app.py
index 886a02f..247627a 100644
--- a/api/apps/chunk_app.py
+++ b/api/apps/chunk_app.py
@@ -121,7 +121,9 @@ def get():
     "important_kwd")
 def set():
     req = request.json
-    d = {"id": req["chunk_id"], "content_with_weight": req["content_with_weight"]}
+    d = {
+        "id": req["chunk_id"],
+        "content_with_weight": req["content_with_weight"]}
     d["content_ltks"] = huqie.qie(req["content_with_weight"])
     d["content_sm_ltks"] = huqie.qieqie(d["content_ltks"])
     d["important_kwd"] = req["important_kwd"]
@@ -140,10 +142,16 @@ def set():
             return get_data_error_result(retmsg="Document not found!")

         if doc.parser_id == ParserType.QA:
-            arr = [t for t in re.split(r"[\n\t]", req["content_with_weight"]) if len(t) > 1]
-            if len(arr) != 2: return get_data_error_result(retmsg="Q&A must be separated by TAB/ENTER key.")
+            arr = [
+                t for t in re.split(
+                    r"[\n\t]",
+                    req["content_with_weight"]) if len(t) > 1]
+            if len(arr) != 2:
+                return get_data_error_result(
+                    retmsg="Q&A must be separated by TAB/ENTER key.")
             q, a = rmPrefix(arr[0]), rmPrefix(arr[1])
-            d = beAdoc(d, arr[0], arr[1], not any([huqie.is_chinese(t) for t in q + a]))
+            d = beAdoc(d, arr[0], arr[1], not any(
+                [huqie.is_chinese(t) for t in q + a]))

         v, c = embd_mdl.encode([doc.name, req["content_with_weight"]])
         v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
@@ -177,7 +185,8 @@ def switch():
 def rm():
     req = request.json
     try:
-        if not 
ELASTICSEARCH.deleteByQuery(Q("ids", values=req["chunk_ids"]), search.index_name(current_user.id)):
+        if not ELASTICSEARCH.deleteByQuery(
+                Q("ids", values=req["chunk_ids"]), search.index_name(current_user.id)):
             return get_data_error_result(retmsg="Index updating failure")
         return get_json_result(data=True)
     except Exception as e:
diff --git a/api/apps/conversation_app.py b/api/apps/conversation_app.py
index 5a23efb..5c55d5d 100644
--- a/api/apps/conversation_app.py
+++ b/api/apps/conversation_app.py
@@ -100,7 +100,10 @@ def rm():
 def list_convsersation():
     dialog_id = request.args["dialog_id"]
     try:
-        convs = ConversationService.query(dialog_id=dialog_id, order_by=ConversationService.model.create_time, reverse=True)
+        convs = ConversationService.query(
+            dialog_id=dialog_id,
+            order_by=ConversationService.model.create_time,
+            reverse=True)
         convs = [d.to_dict() for d in convs]
         return get_json_result(data=convs)
     except Exception as e:
@@ -111,19 +114,24 @@ def message_fit_in(msg, max_length=4000):
     def count():
         nonlocal msg
         tks_cnts = []
-        for m in msg: tks_cnts.append({"role": m["role"], "count": num_tokens_from_string(m["content"])})
+        for m in msg:
+            tks_cnts.append(
+                {"role": m["role"], "count": num_tokens_from_string(m["content"])})
         total = 0
-        for m in tks_cnts: total += m["count"]
+        for m in tks_cnts:
+            total += m["count"]
         return total

     c = count()
-    if c < max_length: return c, msg
+    if c < max_length:
+        return c, msg

     msg_ = [m for m in msg[:-1] if m["role"] == "system"]
     msg_.append(msg[-1])
     msg = msg_
     c = count()
-    if c < max_length: return c, msg
+    if c < max_length:
+        return c, msg

     ll = num_tokens_from_string(msg_[0]["content"])
     l = num_tokens_from_string(msg_[-1]["content"])
@@ -146,8 +154,10 @@ def completion():
     req = request.json
     msg = []
     for m in req["messages"]:
-        if m["role"] == "system": continue
-        if m["role"] == "assistant" and not msg: continue
+        if m["role"] == "system":
+            continue
+        if m["role"] == "assistant" and not msg:
+            continue
         msg.append({"role": m["role"], "content": m["content"]})
     try:
         e, conv = ConversationService.get_by_id(req["conversation_id"])
@@ -160,7 +170,8 @@ def completion():
         del req["conversation_id"]
         del req["messages"]
         ans = chat(dia, msg, **req)
-        if not conv.reference: conv.reference = []
+        if not conv.reference:
+            conv.reference = []
         conv.reference.append(ans["reference"])
         conv.message.append({"role": "assistant", "content": ans["answer"]})
         ConversationService.update_by_id(conv.id, conv.to_dict())
@@ -180,52 +191,67 @@ def chat(dialog, messages, **kwargs):
     chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)

     field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
-    ## try to use sql if field mapping is good to go
+    # try to use sql if field mapping is good to go
     if field_map:
         chat_logger.info("Use SQL to retrieval:{}".format(questions[-1]))
         return use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl)

     prompt_config = dialog.prompt_config
     for p in prompt_config["parameters"]:
-        if p["key"] == "knowledge": continue
-        if p["key"] not in kwargs and not p["optional"]: raise KeyError("Miss parameter: " + p["key"])
+        if p["key"] == "knowledge":
+            continue
+        if p["key"] not in kwargs and not p["optional"]:
+            raise KeyError("Miss parameter: " + p["key"])
         if p["key"] not in kwargs:
-            prompt_config["system"] = prompt_config["system"].replace("{%s}" % p["key"], " ")
+            prompt_config["system"] = prompt_config["system"].replace(
+                "{%s}" % p["key"], " ")

-    for _ in range(len(questions)//2):
+    for _ in range(len(questions) // 2):
 
questions.append(questions[-1]) if "knowledge" not in [p["key"] for p in prompt_config["parameters"]]: - kbinfos = {"total":0, "chunks":[],"doc_aggs":[]} + kbinfos = {"total": 0, "chunks": [], "doc_aggs": []} else: kbinfos = retrievaler.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n, - dialog.similarity_threshold, - dialog.vector_similarity_weight, top=1024, aggs=False) + dialog.similarity_threshold, + dialog.vector_similarity_weight, top=1024, aggs=False) knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]] - chat_logger.info("{}->{}".format(" ".join(questions), "\n->".join(knowledges))) + chat_logger.info( + "{}->{}".format(" ".join(questions), "\n->".join(knowledges))) if not knowledges and prompt_config.get("empty_response"): - return {"answer": prompt_config["empty_response"], "reference": kbinfos} + return { + "answer": prompt_config["empty_response"], "reference": kbinfos} kwargs["knowledge"] = "\n".join(knowledges) gen_conf = dialog.llm_setting - msg = [{"role": m["role"], "content": m["content"]} for m in messages if m["role"] != "system"] + msg = [{"role": m["role"], "content": m["content"]} + for m in messages if m["role"] != "system"] used_token_count, msg = message_fit_in(msg, int(llm.max_tokens * 0.97)) if "max_tokens" in gen_conf: - gen_conf["max_tokens"] = min(gen_conf["max_tokens"], llm.max_tokens - used_token_count) - answer = chat_mdl.chat(prompt_config["system"].format(**kwargs), msg, gen_conf) - chat_logger.info("User: {}|Assistant: {}".format(msg[-1]["content"], answer)) + gen_conf["max_tokens"] = min( + gen_conf["max_tokens"], + llm.max_tokens - used_token_count) + answer = chat_mdl.chat( + prompt_config["system"].format( + **kwargs), msg, gen_conf) + chat_logger.info("User: {}|Assistant: {}".format( + msg[-1]["content"], answer)) if knowledges: answer, idx = retrievaler.insert_citations(answer, - [ck["content_ltks"] for ck in kbinfos["chunks"]], - [ck["vector"] for ck in kbinfos["chunks"]], - embd_mdl, - tkweight=1 - dialog.vector_similarity_weight, - vtweight=dialog.vector_similarity_weight) + [ck["content_ltks"] + for ck in kbinfos["chunks"]], + [ck["vector"] + for ck in kbinfos["chunks"]], + embd_mdl, + tkweight=1 - dialog.vector_similarity_weight, + vtweight=dialog.vector_similarity_weight) idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx]) - kbinfos["doc_aggs"] = [d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx] + kbinfos["doc_aggs"] = [ + d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx] for c in kbinfos["chunks"]: - if c.get("vector"): del c["vector"] + if c.get("vector"): + del c["vector"] return {"answer": answer, "reference": kbinfos} @@ -245,9 +271,11 @@ def use_sql(question, field_map, tenant_id, chat_mdl): question ) tried_times = 0 + def get_table(): nonlocal sys_prompt, user_promt, question, tried_times - sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_promt}], {"temperature": 0.06}) + sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_promt}], { + "temperature": 0.06}) print(user_promt, sql) chat_logger.info(f"“{question}”==>{user_promt} get SQL: {sql}") sql = re.sub(r"[\r\n]+", " ", sql.lower()) @@ -262,8 +290,10 @@ def use_sql(question, field_map, tenant_id, chat_mdl): else: flds = [] for k in field_map.keys(): - if k in forbidden_select_fields4resume:continue - if len(flds) > 11:break + if k in forbidden_select_fields4resume: + continue + if len(flds) > 11: + break flds.append(k) sql = "select doc_id,docnm_kwd," + ",".join(flds) + 
sql[8:]
@@ -284,13 +314,13 @@ def use_sql(question, field_map, tenant_id, chat_mdl):
         问题如下:
         {}
-        
+
         你上一次给出的错误SQL如下:
         {}
-        
+
         后台报错如下:
         {}
-        
+
         请纠正SQL中的错误再写一遍,且只要SQL,不要有其他说明及文字。
 """.format(
         index_name(tenant_id),
@@ -302,16 +332,24 @@ def use_sql(question, field_map, tenant_id, chat_mdl):
     chat_logger.info("GET table: {}".format(tbl))
     print(tbl)
-    if tbl.get("error") or len(tbl["rows"]) == 0: return None, None
+    if tbl.get("error") or len(tbl["rows"]) == 0:
+        return None, None

-    docid_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "doc_id"])
-    docnm_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "docnm_kwd"])
-    clmn_idx = [ii for ii in range(len(tbl["columns"])) if ii not in (docid_idx | docnm_idx)]
+    docid_idx = set([ii for ii, c in enumerate(
+        tbl["columns"]) if c["name"] == "doc_id"])
+    docnm_idx = set([ii for ii, c in enumerate(
+        tbl["columns"]) if c["name"] == "docnm_kwd"])
+    clmn_idx = [ii for ii in range(
+        len(tbl["columns"])) if ii not in (docid_idx | docnm_idx)]

     # compose markdown table
-    clmns = "|"+"|".join([re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"], tbl["columns"][i]["name"])) for i in clmn_idx]) + ("|Source|" if docid_idx and docid_idx else "|")
-    line = "|"+"|".join(["------" for _ in range(len(clmn_idx))]) + ("|------|" if docid_idx and docid_idx else "")
-    rows = ["|"+"|".join([rmSpace(str(r[i])) for i in clmn_idx]).replace("None", " ") + "|" for r in tbl["rows"]]
+    clmns = "|" + "|".join([re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"],
+                                                                            tbl["columns"][i]["name"])) for i in clmn_idx]) + ("|Source|" if docid_idx and docid_idx else "|")
+    line = "|" + "|".join(["------" for _ in range(len(clmn_idx))]) + \
+        ("|------|" if docid_idx and docid_idx else "")
+    rows = ["|" +
+            "|".join([rmSpace(str(r[i])) for i in clmn_idx]).replace("None", " ") +
+            "|" for r in tbl["rows"]]
     if not docid_idx or not docnm_idx:
         chat_logger.warning("SQL missing field: " + sql)
         return "\n".join([clmns, line, "\n".join(rows)]), []
@@ -328,5 +366,5 @@ def use_sql(question, field_map, tenant_id, chat_mdl):
     return {
         "answer": "\n".join([clmns, line, rows]),
         "reference": {"chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[docnm_idx]} for r in tbl["rows"]],
-                      "doc_aggs":[{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in doc_aggs.items()]}
+                      "doc_aggs": [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in doc_aggs.items()]}
     }
diff --git a/api/apps/dialog_app.py b/api/apps/dialog_app.py
index fccc0ec..50db453 100644
--- a/api/apps/dialog_app.py
+++ b/api/apps/dialog_app.py
@@ -55,7 +55,8 @@ def set_dialog():
     }
     prompt_config = req.get("prompt_config", default_prompt)

-    if not prompt_config["system"]: prompt_config["system"] = default_prompt["system"]
+    if not prompt_config["system"]:
+        prompt_config["system"] = default_prompt["system"]
     # if len(prompt_config["parameters"]) < 1:
     #    prompt_config["parameters"] = default_prompt["parameters"]
     # for p in prompt_config["parameters"]:
     #    else: prompt_config["parameters"].append(default_prompt["parameters"][0])

     for p in prompt_config["parameters"]:
-        if p["optional"]: continue
+        if p["optional"]:
+            continue
         if prompt_config["system"].find("{%s}" % p["key"]) < 0:
-            return get_data_error_result(retmsg="Parameter '{}' is not used".format(p["key"]))
+            return get_data_error_result(
+                retmsg="Parameter '{}' is not 
used".format(p["key"])) try: e, tenant = TenantService.get_by_id(current_user.id) - if not e: return get_data_error_result(retmsg="Tenant not found!") + if not e: + return get_data_error_result(retmsg="Tenant not found!") llm_id = req.get("llm_id", tenant.llm_id) if not dialog_id: - if not req.get("kb_ids"):return get_data_error_result(retmsg="Fail! Please select knowledgebase!") + if not req.get("kb_ids"): + return get_data_error_result( + retmsg="Fail! Please select knowledgebase!") dia = { "id": get_uuid(), "tenant_id": current_user.id, @@ -86,17 +92,21 @@ def set_dialog(): "similarity_threshold": similarity_threshold, "vector_similarity_weight": vector_similarity_weight } - if not DialogService.save(**dia): return get_data_error_result(retmsg="Fail to new a dialog!") + if not DialogService.save(**dia): + return get_data_error_result(retmsg="Fail to new a dialog!") e, dia = DialogService.get_by_id(dia["id"]) - if not e: return get_data_error_result(retmsg="Fail to new a dialog!") + if not e: + return get_data_error_result(retmsg="Fail to new a dialog!") return get_json_result(data=dia.to_json()) else: del req["dialog_id"] - if "kb_names" in req: del req["kb_names"] + if "kb_names" in req: + del req["kb_names"] if not DialogService.update_by_id(dialog_id, req): return get_data_error_result(retmsg="Dialog not found!") e, dia = DialogService.get_by_id(dialog_id) - if not e: return get_data_error_result(retmsg="Fail to update a dialog!") + if not e: + return get_data_error_result(retmsg="Fail to update a dialog!") dia = dia.to_dict() dia["kb_ids"], dia["kb_names"] = get_kb_names(dia["kb_ids"]) return get_json_result(data=dia) @@ -110,7 +120,8 @@ def get(): dialog_id = request.args["dialog_id"] try: e, dia = DialogService.get_by_id(dialog_id) - if not e: return get_data_error_result(retmsg="Dialog not found!") + if not e: + return get_data_error_result(retmsg="Dialog not found!") dia = dia.to_dict() dia["kb_ids"], dia["kb_names"] = get_kb_names(dia["kb_ids"]) return get_json_result(data=dia) @@ -122,7 +133,8 @@ def get_kb_names(kb_ids): ids, nms = [], [] for kid in kb_ids: e, kb = KnowledgebaseService.get_by_id(kid) - if not e or kb.status != StatusEnum.VALID.value: continue + if not e or kb.status != StatusEnum.VALID.value: + continue ids.append(kid) nms.append(kb.name) return ids, nms @@ -132,7 +144,11 @@ def get_kb_names(kb_ids): @login_required def list(): try: - diags = DialogService.query(tenant_id=current_user.id, status=StatusEnum.VALID.value, reverse=True, order_by=DialogService.model.create_time) + diags = DialogService.query( + tenant_id=current_user.id, + status=StatusEnum.VALID.value, + reverse=True, + order_by=DialogService.model.create_time) diags = [d.to_dict() for d in diags] for d in diags: d["kb_ids"], d["kb_names"] = get_kb_names(d["kb_ids"]) @@ -147,7 +163,8 @@ def list(): def rm(): req = request.json try: - DialogService.update_many_by_id([{"id": id, "status": StatusEnum.INVALID.value} for id in req["dialog_ids"]]) + DialogService.update_many_by_id( + [{"id": id, "status": StatusEnum.INVALID.value} for id in req["dialog_ids"]]) return get_json_result(data=True) except Exception as e: return server_error_response(e) diff --git a/api/apps/document_app.py b/api/apps/document_app.py index b944076..ea06a3c 100644 --- a/api/apps/document_app.py +++ b/api/apps/document_app.py @@ -57,6 +57,9 @@ def upload(): if not e: return get_data_error_result( retmsg="Can't find this knowledgebase!") + if DocumentService.get_doc_count(kb.tenant_id) >= 128: + return get_data_error_result( + 
retmsg="Exceed the maximum file number of a free user!") filename = duplicate_name( DocumentService.query, @@ -215,9 +218,11 @@ def rm(): tenant_id = DocumentService.get_tenant_id(req["doc_id"]) if not tenant_id: return get_data_error_result(retmsg="Tenant not found!") - ELASTICSEARCH.deleteByQuery(Q("match", doc_id=doc.id), idxnm=search.index_name(tenant_id)) + ELASTICSEARCH.deleteByQuery( + Q("match", doc_id=doc.id), idxnm=search.index_name(tenant_id)) - DocumentService.increment_chunk_num(doc.id, doc.kb_id, doc.token_num * -1, doc.chunk_num * -1, 0) + DocumentService.increment_chunk_num( + doc.id, doc.kb_id, doc.token_num * -1, doc.chunk_num * -1, 0) if not DocumentService.delete(doc): return get_data_error_result( retmsg="Database error (Document removal)!") @@ -245,7 +250,8 @@ def run(): tenant_id = DocumentService.get_tenant_id(id) if not tenant_id: return get_data_error_result(retmsg="Tenant not found!") - ELASTICSEARCH.deleteByQuery(Q("match", doc_id=id), idxnm=search.index_name(tenant_id)) + ELASTICSEARCH.deleteByQuery( + Q("match", doc_id=id), idxnm=search.index_name(tenant_id)) return get_json_result(data=True) except Exception as e: @@ -261,7 +267,8 @@ def rename(): e, doc = DocumentService.get_by_id(req["doc_id"]) if not e: return get_data_error_result(retmsg="Document not found!") - if pathlib.Path(req["name"].lower()).suffix != pathlib.Path(doc.name.lower()).suffix: + if pathlib.Path(req["name"].lower()).suffix != pathlib.Path( + doc.name.lower()).suffix: return get_json_result( data=False, retmsg="The extension of file can't be changed", @@ -294,7 +301,10 @@ def get(doc_id): if doc.type == FileType.VISUAL.value: response.headers.set('Content-Type', 'image/%s' % ext.group(1)) else: - response.headers.set('Content-Type', 'application/%s' % ext.group(1)) + response.headers.set( + 'Content-Type', + 'application/%s' % + ext.group(1)) return response except Exception as e: return server_error_response(e) @@ -313,9 +323,11 @@ def change_parser(): if "parser_config" in req: if req["parser_config"] == doc.parser_config: return get_json_result(data=True) - else: return get_json_result(data=True) + else: + return get_json_result(data=True) - if doc.type == FileType.VISUAL or re.search(r"\.(ppt|pptx|pages)$", doc.name): + if doc.type == FileType.VISUAL or re.search( + r"\.(ppt|pptx|pages)$", doc.name): return get_data_error_result(retmsg="Not supported yet!") e = DocumentService.update_by_id(doc.id, @@ -332,7 +344,8 @@ def change_parser(): tenant_id = DocumentService.get_tenant_id(req["doc_id"]) if not tenant_id: return get_data_error_result(retmsg="Tenant not found!") - ELASTICSEARCH.deleteByQuery(Q("match", doc_id=doc.id), idxnm=search.index_name(tenant_id)) + ELASTICSEARCH.deleteByQuery( + Q("match", doc_id=doc.id), idxnm=search.index_name(tenant_id)) return get_json_result(data=True) except Exception as e: diff --git a/api/apps/kb_app.py b/api/apps/kb_app.py index b0ba165..bcffbc8 100644 --- a/api/apps/kb_app.py +++ b/api/apps/kb_app.py @@ -33,15 +33,21 @@ from api.utils.api_utils import get_json_result def create(): req = request.json req["name"] = req["name"].strip() - req["name"] = duplicate_name(KnowledgebaseService.query, name=req["name"], tenant_id=current_user.id, status=StatusEnum.VALID.value) + req["name"] = duplicate_name( + KnowledgebaseService.query, + name=req["name"], + tenant_id=current_user.id, + status=StatusEnum.VALID.value) try: req["id"] = get_uuid() req["tenant_id"] = current_user.id req["created_by"] = current_user.id e, t = 
TenantService.get_by_id(current_user.id) - if not e: return get_data_error_result(retmsg="Tenant not found.") + if not e: + return get_data_error_result(retmsg="Tenant not found.") req["embd_id"] = t.embd_id - if not KnowledgebaseService.save(**req): return get_data_error_result() + if not KnowledgebaseService.save(**req): + return get_data_error_result() return get_json_result(data={"kb_id": req["id"]}) except Exception as e: return server_error_response(e) @@ -54,21 +60,29 @@ def update(): req = request.json req["name"] = req["name"].strip() try: - if not KnowledgebaseService.query(created_by=current_user.id, id=req["kb_id"]): - return get_json_result(data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.', retcode=RetCode.OPERATING_ERROR) + if not KnowledgebaseService.query( + created_by=current_user.id, id=req["kb_id"]): + return get_json_result( + data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.', retcode=RetCode.OPERATING_ERROR) e, kb = KnowledgebaseService.get_by_id(req["kb_id"]) - if not e: return get_data_error_result(retmsg="Can't find this knowledgebase!") + if not e: + return get_data_error_result( + retmsg="Can't find this knowledgebase!") if req["name"].lower() != kb.name.lower() \ - and len(KnowledgebaseService.query(name=req["name"], tenant_id=current_user.id, status=StatusEnum.VALID.value))>1: - return get_data_error_result(retmsg="Duplicated knowledgebase name.") + and len(KnowledgebaseService.query(name=req["name"], tenant_id=current_user.id, status=StatusEnum.VALID.value)) > 1: + return get_data_error_result( + retmsg="Duplicated knowledgebase name.") del req["kb_id"] - if not KnowledgebaseService.update_by_id(kb.id, req): return get_data_error_result() + if not KnowledgebaseService.update_by_id(kb.id, req): + return get_data_error_result() e, kb = KnowledgebaseService.get_by_id(kb.id) - if not e: return get_data_error_result(retmsg="Database error (Knowledgebase rename)!") + if not e: + return get_data_error_result( + retmsg="Database error (Knowledgebase rename)!") return get_json_result(data=kb.to_json()) except Exception as e: @@ -81,7 +95,9 @@ def detail(): kb_id = request.args["kb_id"] try: kb = KnowledgebaseService.get_detail(kb_id) - if not kb: return get_data_error_result(retmsg="Can't find this knowledgebase!") + if not kb: + return get_data_error_result( + retmsg="Can't find this knowledgebase!") return get_json_result(data=kb) except Exception as e: return server_error_response(e) @@ -96,7 +112,8 @@ def list(): desc = request.args.get("desc", True) try: tenants = TenantService.get_joined_tenants_by_user_id(current_user.id) - kbs = KnowledgebaseService.get_by_tenant_ids([m["tenant_id"] for m in tenants], current_user.id, page_number, items_per_page, orderby, desc) + kbs = KnowledgebaseService.get_by_tenant_ids( + [m["tenant_id"] for m in tenants], current_user.id, page_number, items_per_page, orderby, desc) return get_json_result(data=kbs) except Exception as e: return server_error_response(e) @@ -108,10 +125,15 @@ def list(): def rm(): req = request.json try: - if not KnowledgebaseService.query(created_by=current_user.id, id=req["kb_id"]): - return get_json_result(data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.', retcode=RetCode.OPERATING_ERROR) - - if not KnowledgebaseService.update_by_id(req["kb_id"], {"status": StatusEnum.INVALID.value}): return get_data_error_result(retmsg="Database error (Knowledgebase removal)!") + if not KnowledgebaseService.query( + 
created_by=current_user.id, id=req["kb_id"]): + return get_json_result( + data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.', retcode=RetCode.OPERATING_ERROR) + + if not KnowledgebaseService.update_by_id( + req["kb_id"], {"status": StatusEnum.INVALID.value}): + return get_data_error_result( + retmsg="Database error (Knowledgebase removal)!") return get_json_result(data=True) except Exception as e: - return server_error_response(e) \ No newline at end of file + return server_error_response(e) diff --git a/api/apps/llm_app.py b/api/apps/llm_app.py index 6ed3bc3..0a98ab0 100644 --- a/api/apps/llm_app.py +++ b/api/apps/llm_app.py @@ -48,30 +48,42 @@ def set_api_key(): req["api_key"], llm.llm_name) try: arr, tc = mdl.encode(["Test if the api key is available"]) - if len(arr[0]) == 0 or tc ==0: raise Exception("Fail") + if len(arr[0]) == 0 or tc == 0: + raise Exception("Fail") except Exception as e: msg += f"\nFail to access embedding model({llm.llm_name}) using this api key." elif not chat_passed and llm.model_type == LLMType.CHAT.value: mdl = ChatModel[factory]( req["api_key"], llm.llm_name) try: - m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9}) - if not tc: raise Exception(m) + m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], { + "temperature": 0.9}) + if not tc: + raise Exception(m) chat_passed = True except Exception as e: - msg += f"\nFail to access model({llm.llm_name}) using this api key." + str(e) + msg += f"\nFail to access model({llm.llm_name}) using this api key." + str( + e) - if msg: return get_data_error_result(retmsg=msg) + if msg: + return get_data_error_result(retmsg=msg) llm = { "api_key": req["api_key"] } for n in ["model_type", "llm_name"]: - if n in req: llm[n] = req[n] + if n in req: + llm[n] = req[n] - if not TenantLLMService.filter_update([TenantLLM.tenant_id==current_user.id, TenantLLM.llm_factory==factory], llm): + if not TenantLLMService.filter_update( + [TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == factory], llm): for llm in LLMService.query(fid=factory): - TenantLLMService.save(tenant_id=current_user.id, llm_factory=factory, llm_name=llm.llm_name, model_type=llm.model_type, api_key=req["api_key"]) + TenantLLMService.save( + tenant_id=current_user.id, + llm_factory=factory, + llm_name=llm.llm_name, + model_type=llm.model_type, + api_key=req["api_key"]) return get_json_result(data=True) @@ -105,17 +117,19 @@ def list(): objs = TenantLLMService.query(tenant_id=current_user.id) facts = set([o.to_dict()["llm_factory"] for o in objs if o.api_key]) llms = LLMService.get_all() - llms = [m.to_dict() for m in llms if m.status == StatusEnum.VALID.value] + llms = [m.to_dict() + for m in llms if m.status == StatusEnum.VALID.value] for m in llms: m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding" res = {} for m in llms: - if model_type and m["model_type"] != model_type: continue - if m["fid"] not in res: res[m["fid"]] = [] + if model_type and m["model_type"] != model_type: + continue + if m["fid"] not in res: + res[m["fid"]] = [] res[m["fid"]].append(m) return get_json_result(data=res) except Exception as e: return server_error_response(e) - diff --git a/api/apps/user_app.py b/api/apps/user_app.py index c3dd68e..857c506 100644 --- a/api/apps/user_app.py +++ b/api/apps/user_app.py @@ -40,13 +40,16 @@ def login(): email = request.json.get('email', "") users = UserService.query(email=email) - if not users: return 
get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg=f'This Email is not registered!') + if not users: + return get_json_result( + data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg=f'This Email is not registered!') password = request.json.get('password') try: password = decrypt(password) - except: - return get_json_result(data=False, retcode=RetCode.SERVER_ERROR, retmsg='Fail to crypt password') + except BaseException: + return get_json_result( + data=False, retcode=RetCode.SERVER_ERROR, retmsg='Fail to crypt password') user = UserService.query_user(email, password) if user: @@ -57,7 +60,8 @@ def login(): msg = "Welcome back!" return cors_reponse(data=response_data, auth=user.get_id(), retmsg=msg) else: - return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg='Email and Password do not match!') + return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, + retmsg='Email and Password do not match!') @manager.route('/github_callback', methods=['GET']) @@ -65,7 +69,7 @@ def github_callback(): import requests res = requests.post(GITHUB_OAUTH.get("url"), data={ "client_id": GITHUB_OAUTH.get("client_id"), - "client_secret": GITHUB_OAUTH.get("secret_key"), + "client_secret": GITHUB_OAUTH.get("secret_key"), "code": request.args.get('code') }, headers={"Accept": "application/json"}) res = res.json() @@ -96,15 +100,17 @@ def github_callback(): "last_login_time": get_format_time(), "is_superuser": False, }) - if not users: raise Exception('Register user failure.') - if len(users) > 1: raise Exception('Same E-mail exist!') + if not users: + raise Exception('Register user failure.') + if len(users) > 1: + raise Exception('Same E-mail exist!') user = users[0] login_user(user) - return redirect("/?auth=%s"%user.get_id()) + return redirect("/?auth=%s" % user.get_id()) except Exception as e: rollback_user_registration(user_id) stat_logger.exception(e) - return redirect("/?error=%s"%str(e)) + return redirect("/?error=%s" % str(e)) user = users[0] user.access_token = get_uuid() login_user(user) @@ -114,11 +120,18 @@ def github_callback(): def user_info_from_github(access_token): import requests - headers = {"Accept": "application/json", 'Authorization': f"token {access_token}"} - res = requests.get(f"https://api.github.com/user?access_token={access_token}", headers=headers) + headers = {"Accept": "application/json", + 'Authorization': f"token {access_token}"} + res = requests.get( + f"https://api.github.com/user?access_token={access_token}", + headers=headers) user_info = res.json() - email_info = requests.get(f"https://api.github.com/user/emails?access_token={access_token}", headers=headers).json() - user_info["email"] = next((email for email in email_info if email['primary'] == True), None)["email"] + email_info = requests.get( + f"https://api.github.com/user/emails?access_token={access_token}", + headers=headers).json() + user_info["email"] = next( + (email for email in email_info if email['primary'] == True), + None)["email"] return user_info @@ -138,13 +151,18 @@ def setting_user(): request_data = request.json if request_data.get("password"): new_password = request_data.get("new_password") - if not check_password_hash(current_user.password, decrypt(request_data["password"])): - return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg='Password error!') + if not check_password_hash( + current_user.password, decrypt(request_data["password"])): + return get_json_result( + data=False, 
retcode=RetCode.AUTHENTICATION_ERROR, retmsg='Password error!')

-        if new_password: update_dict["password"] = generate_password_hash(decrypt(new_password))
+        if new_password:
+            update_dict["password"] = generate_password_hash(
+                decrypt(new_password))

     for k in request_data.keys():
-        if k in ["password", "new_password"]:continue
+        if k in ["password", "new_password"]:
+            continue
         update_dict[k] = request_data[k]

     try:
@@ -152,7 +170,8 @@ def setting_user():
         return get_json_result(data=True)
     except Exception as e:
         stat_logger.exception(e)
-        return get_json_result(data=False, retmsg='Update failure!', retcode=RetCode.EXCEPTION_ERROR)
+        return get_json_result(
+            data=False, retmsg='Update failure!', retcode=RetCode.EXCEPTION_ERROR)

 @manager.route("/info", methods=["GET"])
@@ -173,11 +192,11 @@ def rollback_user_registration(user_id):
     except Exception as e:
         pass
     try:
-        TenantLLM.delete().where(TenantLLM.tenant_id==user_id).execute()
+        TenantLLM.delete().where(TenantLLM.tenant_id == user_id).execute()
     except Exception as e:
         pass
-    
+

 def user_register(user_id, user):
     user["id"] = user_id
     tenant = {
@@ -197,9 +216,14 @@ def user_register(user_id, user):
     }
     tenant_llm = []
     for llm in LLMService.query(fid=LLM_FACTORY):
-        tenant_llm.append({"tenant_id": user_id, "llm_factory": LLM_FACTORY, "llm_name": llm.llm_name, "model_type":llm.model_type, "api_key": API_KEY})
-
-    if not UserService.save(**user):return
+        tenant_llm.append({"tenant_id": user_id,
+                           "llm_factory": LLM_FACTORY,
+                           "llm_name": llm.llm_name,
+                           "model_type": llm.model_type,
+                           "api_key": API_KEY})
+
+    if not UserService.save(**user):
+        return
     TenantService.insert(**tenant)
     UserTenantService.insert(**usr_tenant)
     TenantLLMService.insert_many(tenant_llm)
@@ -211,7 +235,8 @@ def user_register(user_id, user):
 def user_add():
     req = request.json
     if UserService.query(email=req["email"]):
-        return get_json_result(data=False, retmsg=f'Email: {req["email"]} has already registered!', retcode=RetCode.OPERATING_ERROR)
+        return get_json_result(
+            data=False, retmsg=f'Email: {req["email"]} has already registered!', retcode=RetCode.OPERATING_ERROR)
     if not re.match(r"^[\w\._-]+@([\w_-]+\.)+[\w-]{2,4}$", req["email"]):
         return get_json_result(data=False, retmsg=f'Invalid e-mail: {req["email"]}!',
                                retcode=RetCode.OPERATING_ERROR)
@@ -229,16 +254,19 @@ def user_add():
     user_id = get_uuid()
     try:
         users = user_register(user_id, user_dict)
-        if not users: raise Exception('Register user failure.')
-        if len(users) > 1: raise Exception('Same E-mail exist!')
+        if not users:
+            raise Exception('Register user failure.')
+        if len(users) > 1:
+            raise Exception('Same E-mail exist!')
         user = users[0]
         login_user(user)
-        return cors_reponse(data=user.to_json(), auth=user.get_id(), retmsg="Welcome aboard!")
+        return cors_reponse(data=user.to_json(),
+                            auth=user.get_id(), retmsg="Welcome aboard!")
     except Exception as e:
         rollback_user_registration(user_id)
         stat_logger.exception(e)
-        return get_json_result(data=False, retmsg='User registration failure!', retcode=RetCode.EXCEPTION_ERROR)
-
+        return get_json_result(
+            data=False, retmsg='User registration failure!', retcode=RetCode.EXCEPTION_ERROR)

 @manager.route("/tenant_info", methods=["GET"])
diff --git a/api/db/db_models.py b/api/db/db_models.py
index caa756c..96620db 100644
--- a/api/db/db_models.py
+++ b/api/db/db_models.py
@@ -50,7 +50,13 @@ def singleton(cls, *args, **kw):

 CONTINUOUS_FIELD_TYPE = {IntegerField, FloatField, DateTimeField}
-AUTO_DATE_TIMESTAMP_FIELD_PREFIX = {"create", "start", "end", "update", "read_access", 
"write_access"} +AUTO_DATE_TIMESTAMP_FIELD_PREFIX = { + "create", + "start", + "end", + "update", + "read_access", + "write_access"} class LongTextField(TextField): @@ -73,7 +79,8 @@ class JSONField(LongTextField): def python_value(self, value): if not value: return self.default_value - return utils.json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook) + return utils.json_loads( + value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook) class ListField(JSONField): @@ -81,7 +88,8 @@ class ListField(JSONField): class SerializedField(LongTextField): - def __init__(self, serialized_type=SerializedType.PICKLE, object_hook=None, object_pairs_hook=None, **kwargs): + def __init__(self, serialized_type=SerializedType.PICKLE, + object_hook=None, object_pairs_hook=None, **kwargs): self._serialized_type = serialized_type self._object_hook = object_hook self._object_pairs_hook = object_pairs_hook @@ -95,7 +103,8 @@ class SerializedField(LongTextField): return None return utils.json_dumps(value, with_type=True) else: - raise ValueError(f"the serialized type {self._serialized_type} is not supported") + raise ValueError( + f"the serialized type {self._serialized_type} is not supported") def python_value(self, value): if self._serialized_type == SerializedType.PICKLE: @@ -103,9 +112,11 @@ class SerializedField(LongTextField): elif self._serialized_type == SerializedType.JSON: if value is None: return {} - return utils.json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook) + return utils.json_loads( + value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook) else: - raise ValueError(f"the serialized type {self._serialized_type} is not supported") + raise ValueError( + f"the serialized type {self._serialized_type} is not supported") def is_continuous_field(cls: typing.Type) -> bool: @@ -150,7 +161,8 @@ class BaseModel(Model): model_dict = self.__dict__['__data__'] if not only_primary_with: - return {remove_field_name_prefix(k): v for k, v in model_dict.items()} + return {remove_field_name_prefix( + k): v for k, v in model_dict.items()} human_model_dict = {} for k in self._meta.primary_key.field_names: @@ -184,17 +196,22 @@ class BaseModel(Model): if is_continuous_field(type(getattr(cls, attr_name))): if len(f_v) == 2: for i, v in enumerate(f_v): - if isinstance(v, str) and f_n in auto_date_timestamp_field(): + if isinstance( + v, str) and f_n in auto_date_timestamp_field(): # time type: %Y-%m-%d %H:%M:%S f_v[i] = utils.date_string_to_timestamp(v) lt_value = f_v[0] gt_value = f_v[1] if lt_value is not None and gt_value is not None: - filters.append(cls.getter_by(attr_name).between(lt_value, gt_value)) + filters.append( + cls.getter_by(attr_name).between( + lt_value, gt_value)) elif lt_value is not None: - filters.append(operator.attrgetter(attr_name)(cls) >= lt_value) + filters.append( + operator.attrgetter(attr_name)(cls) >= lt_value) elif gt_value is not None: - filters.append(operator.attrgetter(attr_name)(cls) <= gt_value) + filters.append( + operator.attrgetter(attr_name)(cls) <= gt_value) else: filters.append(operator.attrgetter(attr_name)(cls) << f_v) else: @@ -205,9 +222,11 @@ class BaseModel(Model): if not order_by or not hasattr(cls, f"{order_by}"): order_by = "create_time" if reverse is True: - query_records = query_records.order_by(cls.getter_by(f"{order_by}").desc()) + query_records = query_records.order_by( + cls.getter_by(f"{order_by}").desc()) elif reverse is False: - 
query_records = query_records.order_by(cls.getter_by(f"{order_by}").asc())
+                query_records = query_records.order_by(
+                    cls.getter_by(f"{order_by}").asc())
             return [query_record for query_record in query_records]
         else:
             return []
@@ -215,7 +234,8 @@ class BaseModel(Model):
     @classmethod
     def insert(cls, __data=None, **insert):
         if isinstance(__data, dict) and __data:
-            __data[cls._meta.combined["create_time"]] = utils.current_timestamp()
+            __data[cls._meta.combined["create_time"]
+                   ] = utils.current_timestamp()
         if insert:
             insert["create_time"] = utils.current_timestamp()

@@ -228,7 +248,8 @@ class BaseModel(Model):
         if not normalized:
             return {}

-        normalized[cls._meta.combined["update_time"]] = utils.current_timestamp()
+        normalized[cls._meta.combined["update_time"]
+                   ] = utils.current_timestamp()

         for f_n in AUTO_DATE_TIMESTAMP_FIELD_PREFIX:
             if {f"{f_n}_time", f"{f_n}_date"}.issubset(cls._meta.combined.keys()) and \
@@ -241,7 +262,8 @@ class BaseModel(Model):

 class JsonSerializedField(SerializedField):
-    def __init__(self, object_hook=utils.from_dict_hook, object_pairs_hook=None, **kwargs):
+    def __init__(self, object_hook=utils.from_dict_hook,
+                 object_pairs_hook=None, **kwargs):
         super(JsonSerializedField, self).__init__(serialized_type=SerializedType.JSON, object_hook=object_hook,
                                                   object_pairs_hook=object_pairs_hook, **kwargs)

@@ -251,7 +273,8 @@ class BaseDataBase:
     def __init__(self):
         database_config = DATABASE.copy()
         db_name = database_config.pop("name")
-        self.database_connection = PooledMySQLDatabase(db_name, **database_config)
+        self.database_connection = PooledMySQLDatabase(
+            db_name, **database_config)
         stat_logger.info('init mysql database on cluster mode successfully')

@@ -263,7 +286,8 @@ class DatabaseLock:

     def lock(self):
         # SQL parameters only support %s format placeholders
-        cursor = self.db.execute_sql("SELECT GET_LOCK(%s, %s)", (self.lock_name, self.timeout))
+        cursor = self.db.execute_sql(
+            "SELECT GET_LOCK(%s, %s)", (self.lock_name, self.timeout))
         ret = cursor.fetchone()
         if ret[0] == 0:
             raise Exception(f'acquire mysql lock {self.lock_name} timeout')
@@ -273,10 +297,12 @@ class DatabaseLock:
             raise Exception(f'failed to acquire lock {self.lock_name}')

     def unlock(self):
-        cursor = self.db.execute_sql("SELECT RELEASE_LOCK(%s)", (self.lock_name,))
+        cursor = self.db.execute_sql(
+            "SELECT RELEASE_LOCK(%s)", (self.lock_name,))
         ret = cursor.fetchone()
         if ret[0] == 0:
-            raise Exception(f'mysql lock {self.lock_name} was not established by this thread')
+            raise Exception(
+                f'mysql lock {self.lock_name} was not established by this thread')
         elif ret[0] == 1:
             return True
         else:
@@ -350,17 +376,37 @@ class User(DataBaseModel, UserMixin):
     access_token = CharField(max_length=255, null=True)
     nickname = CharField(max_length=100, null=False, help_text="nickname")
     password = CharField(max_length=255, null=True, help_text="password")
-    email = CharField(max_length=255, null=False, help_text="email", index=True)
+    email = CharField(
+        max_length=255,
+        null=False,
+        help_text="email",
+        index=True)
     avatar = TextField(null=True, help_text="avatar base64 string")
-    language = CharField(max_length=32, null=True, help_text="English|Chinese", default="Chinese")
-    color_schema = CharField(max_length=32, null=True, help_text="Bright|Dark", default="Bright")
-    timezone = CharField(max_length=64, null=True, help_text="Timezone", default="UTC+8\tAsia/Shanghai")
+    language = CharField(
+        max_length=32,
+        null=True,
+        help_text="English|Chinese",
+        default="Chinese")
+    color_schema = CharField(
+        max_length=32, 
+ null=True, + help_text="Bright|Dark", + default="Bright") + timezone = CharField( + max_length=64, + null=True, + help_text="Timezone", + default="UTC+8\tAsia/Shanghai") last_login_time = DateTimeField(null=True) is_authenticated = CharField(max_length=1, null=False, default="1") is_active = CharField(max_length=1, null=False, default="1") is_anonymous = CharField(max_length=1, null=False, default="0") login_channel = CharField(null=True, help_text="from which user login") - status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1") + status = CharField( + max_length=1, + null=True, + help_text="is it validate(0: wasted,1: validate)", + default="1") is_superuser = BooleanField(null=True, help_text="is root", default=False) def __str__(self): @@ -379,12 +425,28 @@ class Tenant(DataBaseModel): name = CharField(max_length=100, null=True, help_text="Tenant name") public_key = CharField(max_length=255, null=True) llm_id = CharField(max_length=128, null=False, help_text="default llm ID") - embd_id = CharField(max_length=128, null=False, help_text="default embedding model ID") - asr_id = CharField(max_length=128, null=False, help_text="default ASR model ID") - img2txt_id = CharField(max_length=128, null=False, help_text="default image to text model ID") - parser_ids = CharField(max_length=256, null=False, help_text="document processors") + embd_id = CharField( + max_length=128, + null=False, + help_text="default embedding model ID") + asr_id = CharField( + max_length=128, + null=False, + help_text="default ASR model ID") + img2txt_id = CharField( + max_length=128, + null=False, + help_text="default image to text model ID") + parser_ids = CharField( + max_length=256, + null=False, + help_text="document processors") credit = IntegerField(default=512) - status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1") + status = CharField( + max_length=1, + null=True, + help_text="is it validate(0: wasted,1: validate)", + default="1") class Meta: db_table = "tenant" @@ -396,7 +458,11 @@ class UserTenant(DataBaseModel): tenant_id = CharField(max_length=32, null=False) role = CharField(max_length=32, null=False, help_text="UserTenantRole") invited_by = CharField(max_length=32, null=False) - status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1") + status = CharField( + max_length=1, + null=True, + help_text="is it validate(0: wasted,1: validate)", + default="1") class Meta: db_table = "user_tenant" @@ -408,17 +474,32 @@ class InvitationCode(DataBaseModel): visit_time = DateTimeField(null=True) user_id = CharField(max_length=32, null=True) tenant_id = CharField(max_length=32, null=True) - status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1") + status = CharField( + max_length=1, + null=True, + help_text="is it validate(0: wasted,1: validate)", + default="1") class Meta: db_table = "invitation_code" class LLMFactories(DataBaseModel): - name = CharField(max_length=128, null=False, help_text="LLM factory name", primary_key=True) + name = CharField( + max_length=128, + null=False, + help_text="LLM factory name", + primary_key=True) logo = TextField(null=True, help_text="llm logo base64") - tags = CharField(max_length=255, null=False, help_text="LLM, Text Embedding, Image2Text, ASR") - status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1") + tags = CharField( 
+ max_length=255, + null=False, + help_text="LLM, Text Embedding, Image2Text, ASR") + status = CharField( + max_length=1, + null=True, + help_text="is it validate(0: wasted,1: validate)", + default="1") def __str__(self): return self.name @@ -429,12 +510,27 @@ class LLMFactories(DataBaseModel): class LLM(DataBaseModel): # LLMs dictionary - llm_name = CharField(max_length=128, null=False, help_text="LLM name", index=True, primary_key=True) - model_type = CharField(max_length=128, null=False, help_text="LLM, Text Embedding, Image2Text, ASR") + llm_name = CharField( + max_length=128, + null=False, + help_text="LLM name", + index=True, + primary_key=True) + model_type = CharField( + max_length=128, + null=False, + help_text="LLM, Text Embedding, Image2Text, ASR") fid = CharField(max_length=128, null=False, help_text="LLM factory id") max_tokens = IntegerField(default=0) - tags = CharField(max_length=255, null=False, help_text="LLM, Text Embedding, Image2Text, Chat, 32k...") - status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1") + tags = CharField( + max_length=255, + null=False, + help_text="LLM, Text Embedding, Image2Text, Chat, 32k...") + status = CharField( + max_length=1, + null=True, + help_text="is it validate(0: wasted,1: validate)", + default="1") def __str__(self): return self.llm_name @@ -445,9 +541,19 @@ class LLM(DataBaseModel): class TenantLLM(DataBaseModel): tenant_id = CharField(max_length=32, null=False) - llm_factory = CharField(max_length=128, null=False, help_text="LLM factory name") - model_type = CharField(max_length=128, null=True, help_text="LLM, Text Embedding, Image2Text, ASR") - llm_name = CharField(max_length=128, null=True, help_text="LLM name", default="") + llm_factory = CharField( + max_length=128, + null=False, + help_text="LLM factory name") + model_type = CharField( + max_length=128, + null=True, + help_text="LLM, Text Embedding, Image2Text, ASR") + llm_name = CharField( + max_length=128, + null=True, + help_text="LLM name", + default="") api_key = CharField(max_length=255, null=True, help_text="API KEY") api_base = CharField(max_length=255, null=True, help_text="API Base") used_tokens = IntegerField(default=0) @@ -464,11 +570,26 @@ class Knowledgebase(DataBaseModel): id = CharField(max_length=32, primary_key=True) avatar = TextField(null=True, help_text="avatar base64 string") tenant_id = CharField(max_length=32, null=False) - name = CharField(max_length=128, null=False, help_text="KB name", index=True) - language = CharField(max_length=32, null=True, default="Chinese", help_text="English|Chinese") + name = CharField( + max_length=128, + null=False, + help_text="KB name", + index=True) + language = CharField( + max_length=32, + null=True, + default="Chinese", + help_text="English|Chinese") description = TextField(null=True, help_text="KB description") - embd_id = CharField(max_length=128, null=False, help_text="default embedding model ID") - permission = CharField(max_length=16, null=False, help_text="me|team", default="me") + embd_id = CharField( + max_length=128, + null=False, + help_text="default embedding model ID") + permission = CharField( + max_length=16, + null=False, + help_text="me|team", + default="me") created_by = CharField(max_length=32, null=False) doc_num = IntegerField(default=0) token_num = IntegerField(default=0) @@ -476,9 +597,17 @@ class Knowledgebase(DataBaseModel): similarity_threshold = FloatField(default=0.2) vector_similarity_weight = FloatField(default=0.3) - parser_id = 
CharField(max_length=32, null=False, help_text="default parser ID", default=ParserType.NAIVE.value)
-    parser_config = JSONField(null=False, default={"pages":[[1,1000000]]})
-    status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
+    parser_id = CharField(
+        max_length=32,
+        null=False,
+        help_text="default parser ID",
+        default=ParserType.NAIVE.value)
+    parser_config = JSONField(null=False, default={"pages": [[1, 1000000]]})
+    status = CharField(
+        max_length=1,
+        null=True,
+        help_text="is it validate(0: wasted,1: validate)",
+        default="1")

     def __str__(self):
         return self.name
@@ -491,22 +620,50 @@ class Document(DataBaseModel):
     id = CharField(max_length=32, primary_key=True)
     thumbnail = TextField(null=True, help_text="thumbnail base64 string")
     kb_id = CharField(max_length=256, null=False, index=True)
-    parser_id = CharField(max_length=32, null=False, help_text="default parser ID")
-    parser_config = JSONField(null=False, default={"pages":[[1,1000000]]})
-    source_type = CharField(max_length=128, null=False, default="local", help_text="where does this document come from")
+    parser_id = CharField(
+        max_length=32,
+        null=False,
+        help_text="default parser ID")
+    parser_config = JSONField(null=False, default={"pages": [[1, 1000000]]})
+    source_type = CharField(
+        max_length=128,
+        null=False,
+        default="local",
+        help_text="where does this document come from")
     type = CharField(max_length=32, null=False, help_text="file extension")
-    created_by = CharField(max_length=32, null=False, help_text="who created it")
-    name = CharField(max_length=255, null=True, help_text="file name", index=True)
-    location = CharField(max_length=255, null=True, help_text="where is it stored")
+    created_by = CharField(
+        max_length=32,
+        null=False,
+        help_text="who created it")
+    name = CharField(
+        max_length=255,
+        null=True,
+        help_text="file name",
+        index=True)
+    location = CharField(
+        max_length=255,
+        null=True,
+        help_text="where is it stored")
     size = IntegerField(default=0)
     token_num = IntegerField(default=0)
     chunk_num = IntegerField(default=0)
     progress = FloatField(default=0)
-    progress_msg = TextField(null=True, help_text="process message", default="")
+    progress_msg = TextField(
+        null=True,
+        help_text="process message",
+        default="")
     process_begin_at = DateTimeField(null=True)
     process_duation = FloatField(default=0)
-    run = CharField(max_length=1, null=True, help_text="start to run processing or cancel.(1: run it; 2: cancel)", default="0")
-    status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
+    run = CharField(
+        max_length=1,
+        null=True,
+        help_text="start to run processing or cancel.(1: run it; 2: cancel)",
+        default="0")
+    status = CharField(
+        max_length=1,
+        null=True,
+        help_text="is it validate(0: wasted,1: validate)",
+        default="1")

     class Meta:
         db_table = "document"
@@ -520,30 +677,52 @@ class Task(DataBaseModel):
     begin_at = DateTimeField(null=True)
     process_duation = FloatField(default=0)
     progress = FloatField(default=0)
-    progress_msg = TextField(null=True, help_text="process message", default="")
+    progress_msg = TextField(
+        null=True,
+        help_text="process message",
+        default="")


 class Dialog(DataBaseModel):
     id = CharField(max_length=32, primary_key=True)
     tenant_id = CharField(max_length=32, null=False)
-    name = CharField(max_length=255, null=True, help_text="dialog application name")
+    name = CharField(
+        max_length=255,
+        null=True,
+        help_text="dialog application name")
-    description = 
TextField(null=True, help_text="Dialog description") icon = TextField(null=True, help_text="icon base64 string") - language = CharField(max_length=32, null=True, default="Chinese", help_text="English|Chinese") + language = CharField( + max_length=32, + null=True, + default="Chinese", + help_text="English|Chinese") llm_id = CharField(max_length=32, null=False, help_text="default llm ID") llm_setting = JSONField(null=False, default={"temperature": 0.1, "top_p": 0.3, "frequency_penalty": 0.7, "presence_penalty": 0.4, "max_tokens": 215}) - prompt_type = CharField(max_length=16, null=False, default="simple", help_text="simple|advanced") + prompt_type = CharField( + max_length=16, + null=False, + default="simple", + help_text="simple|advanced") prompt_config = JSONField(null=False, default={"system": "", "prologue": "您好,ć‘ćŻć‚¨çš„助手小樱,长得可ç±ĺŹĺ–„良,can I help you?", "parameters": [], "empty_response": "Sorry! 知识库ä¸ćśŞć‰ľĺ°ç›¸ĺ…łĺ†…容ďĽ"}) similarity_threshold = FloatField(default=0.2) vector_similarity_weight = FloatField(default=0.3) top_n = IntegerField(default=6) - do_refer = CharField(max_length=1, null=False, help_text="it needs to insert reference index into answer or not", default="1") + do_refer = CharField( + max_length=1, + null=False, + help_text="it needs to insert reference index into answer or not", + default="1") kb_ids = JSONField(null=False, default=[]) - status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1") + status = CharField( + max_length=1, + null=True, + help_text="is it validate(0: wasted,1: validate)", + default="1") class Meta: db_table = "dialog" diff --git a/api/db/db_utils.py b/api/db/db_utils.py index 1e5a384..144cc1f 100644 --- a/api/db/db_utils.py +++ b/api/db/db_utils.py @@ -32,8 +32,7 @@ LOGGER = getLogger() def bulk_insert_into_db(model, data_source, replace_on_conflict=False): DB.create_tables([model]) - - for i,data in enumerate(data_source): + for i, data in enumerate(data_source): current_time = current_timestamp() + i current_date = timestamp_to_date(current_time) if 'create_time' not in data: @@ -55,7 +54,8 @@ def bulk_insert_into_db(model, data_source, replace_on_conflict=False): def get_dynamic_db_model(base, job_id): - return type(base.model(table_index=get_dynamic_tracking_table_index(job_id=job_id))) + return type(base.model( + table_index=get_dynamic_tracking_table_index(job_id=job_id))) def get_dynamic_tracking_table_index(job_id): @@ -86,7 +86,9 @@ supported_operators = { '~': operator.inv, } -def query_dict2expression(model: Type[DataBaseModel], query: Dict[str, Union[bool, int, str, list, tuple]]): + +def query_dict2expression( + model: Type[DataBaseModel], query: Dict[str, Union[bool, int, str, list, tuple]]): expression = [] for field, value in query.items(): @@ -95,7 +97,10 @@ def query_dict2expression(model: Type[DataBaseModel], query: Dict[str, Union[boo op, *val = value field = getattr(model, f'f_{field}') - value = supported_operators[op](field, val[0]) if op in supported_operators else getattr(field, op)(*val) + value = supported_operators[op]( + field, val[0]) if op in supported_operators else getattr( + field, op)( + *val) expression.append(value) return reduce(operator.iand, expression) diff --git a/api/db/init_data.py b/api/db/init_data.py index 89a94e6..9c735de 100644 --- a/api/db/init_data.py +++ b/api/db/init_data.py @@ -61,45 +61,54 @@ def init_superuser(): TenantService.insert(**tenant) UserTenantService.insert(**usr_tenant) TenantLLMService.insert_many(tenant_llm) - 
print("ă€INFO】Super user initialized. \033[93memail: admin@ragflow.io, password: admin\033[0m. Changing the password after logining is strongly recomanded.") + print( + "ă€INFO】Super user initialized. \033[93memail: admin@ragflow.io, password: admin\033[0m. Changing the password after logining is strongly recomanded.") chat_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["llm_id"]) - msg = chat_mdl.chat(system="", history=[{"role": "user", "content": "Hello!"}], gen_conf={}) + msg = chat_mdl.chat(system="", history=[ + {"role": "user", "content": "Hello!"}], gen_conf={}) if msg.find("ERROR: ") == 0: - print("\33[91mă€ERROR】\33[0m: ", "'{}' dosen't work. {}".format(tenant["llm_id"], msg)) + print( + "\33[91mă€ERROR】\33[0m: ", + "'{}' dosen't work. {}".format( + tenant["llm_id"], + msg)) embd_mdl = LLMBundle(tenant["id"], LLMType.EMBEDDING, tenant["embd_id"]) v, c = embd_mdl.encode(["Hello!"]) if c == 0: - print("\33[91mă€ERROR】\33[0m:", " '{}' dosen't work!".format(tenant["embd_id"])) + print( + "\33[91mă€ERROR】\33[0m:", + " '{}' dosen't work!".format( + tenant["embd_id"])) factory_infos = [{ - "name": "OpenAI", - "logo": "", - "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION", + "name": "OpenAI", + "logo": "", + "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION", + "status": "1", +}, { + "name": "Tongyi-Qianwen", + "logo": "", + "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION", + "status": "1", +}, { + "name": "ZHIPU-AI", + "logo": "", + "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION", + "status": "1", +}, + { + "name": "Local", + "logo": "", + "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION", "status": "1", - },{ - "name": "Tongyi-Qianwen", - "logo": "", - "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION", - "status": "1", - },{ - "name": "ZHIPU-AI", - "logo": "", - "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION", - "status": "1", - }, - { - "name": "Local", - "logo": "", - "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION", - "status": "1", - },{ +}, { "name": "Moonshot", - "logo": "", - "tags": "LLM,TEXT EMBEDDING", - "status": "1", - } + "logo": "", + "tags": "LLM,TEXT EMBEDDING", + "status": "1", +} # { # "name": "ć–‡ĺżä¸€č¨€", # "logo": "", @@ -107,6 +116,8 @@ factory_infos = [{ # "status": "1", # }, ] + + def init_llm_factory(): llm_infos = [ # ---------------------- OpenAI ------------------------ @@ -116,37 +127,37 @@ def init_llm_factory(): "tags": "LLM,CHAT,4K", "max_tokens": 4096, "model_type": LLMType.CHAT.value - },{ + }, { "fid": factory_infos[0]["name"], "llm_name": "gpt-3.5-turbo-16k-0613", "tags": "LLM,CHAT,16k", "max_tokens": 16385, "model_type": LLMType.CHAT.value - },{ + }, { "fid": factory_infos[0]["name"], "llm_name": "text-embedding-ada-002", "tags": "TEXT EMBEDDING,8K", "max_tokens": 8191, "model_type": LLMType.EMBEDDING.value - },{ + }, { "fid": factory_infos[0]["name"], "llm_name": "whisper-1", "tags": "SPEECH2TEXT", - "max_tokens": 25*1024*1024, + "max_tokens": 25 * 1024 * 1024, "model_type": LLMType.SPEECH2TEXT.value - },{ + }, { "fid": factory_infos[0]["name"], "llm_name": "gpt-4", "tags": "LLM,CHAT,8K", "max_tokens": 8191, "model_type": LLMType.CHAT.value - },{ + }, { "fid": factory_infos[0]["name"], "llm_name": "gpt-4-32k", "tags": "LLM,CHAT,32K", "max_tokens": 32768, "model_type": LLMType.CHAT.value - },{ + }, { "fid": factory_infos[0]["name"], "llm_name": "gpt-4-vision-preview", "tags": "LLM,CHAT,IMAGE2TEXT", @@ -160,31 +171,31 @@ def init_llm_factory(): "tags": "LLM,CHAT,8K", "max_tokens": 8191, "model_type": LLMType.CHAT.value - 
},{ + }, { "fid": factory_infos[1]["name"], "llm_name": "qwen-plus", "tags": "LLM,CHAT,32K", "max_tokens": 32768, "model_type": LLMType.CHAT.value - },{ + }, { "fid": factory_infos[1]["name"], "llm_name": "qwen-max-1201", "tags": "LLM,CHAT,6K", "max_tokens": 5899, "model_type": LLMType.CHAT.value - },{ + }, { "fid": factory_infos[1]["name"], "llm_name": "text-embedding-v2", "tags": "TEXT EMBEDDING,2K", "max_tokens": 2048, "model_type": LLMType.EMBEDDING.value - },{ + }, { "fid": factory_infos[1]["name"], "llm_name": "paraformer-realtime-8k-v1", "tags": "SPEECH2TEXT", - "max_tokens": 25*1024*1024, + "max_tokens": 25 * 1024 * 1024, "model_type": LLMType.SPEECH2TEXT.value - },{ + }, { "fid": factory_infos[1]["name"], "llm_name": "qwen-vl-max", "tags": "LLM,CHAT,IMAGE2TEXT", @@ -245,13 +256,13 @@ def init_llm_factory(): "tags": "TEXT EMBEDDING,", "max_tokens": 128 * 1000, "model_type": LLMType.EMBEDDING.value - },{ + }, { "fid": factory_infos[4]["name"], "llm_name": "moonshot-v1-32k", "tags": "LLM,CHAT,", "max_tokens": 32768, "model_type": LLMType.CHAT.value - },{ + }, { "fid": factory_infos[4]["name"], "llm_name": "moonshot-v1-128k", "tags": "LLM,CHAT", @@ -294,7 +305,6 @@ def init_web_data(): print("init web data success:{}".format(time.time() - start_time)) - if __name__ == '__main__': init_web_db() - init_web_data() \ No newline at end of file + init_web_data() diff --git a/api/db/operatioins.py b/api/db/operatioins.py index 10d58df..cc13a42 100644 --- a/api/db/operatioins.py +++ b/api/db/operatioins.py @@ -18,4 +18,4 @@ import operator import time import typing from api.utils.log_utils import sql_logger -import peewee \ No newline at end of file +import peewee diff --git a/api/db/reload_config_base.py b/api/db/reload_config_base.py index b6df15c..fff9b59 100644 --- a/api/db/reload_config_base.py +++ b/api/db/reload_config_base.py @@ -18,10 +18,11 @@ class ReloadConfigBase: def get_all(cls): configs = {} for k, v in cls.__dict__.items(): - if not callable(getattr(cls, k)) and not k.startswith("__") and not k.startswith("_"): + if not callable(getattr(cls, k)) and not k.startswith( + "__") and not k.startswith("_"): configs[k] = v return configs @classmethod def get(cls, config_name): - return getattr(cls, config_name) if hasattr(cls, config_name) else None \ No newline at end of file + return getattr(cls, config_name) if hasattr(cls, config_name) else None diff --git a/api/db/runtime_config.py b/api/db/runtime_config.py index de7ab3a..ad488dc 100644 --- a/api/db/runtime_config.py +++ b/api/db/runtime_config.py @@ -51,4 +51,4 @@ class RuntimeConfig(ReloadConfigBase): @classmethod def set_service_db(cls, service_db): - cls.SERVICE_DB = service_db \ No newline at end of file + cls.SERVICE_DB = service_db diff --git a/api/db/services/common_service.py b/api/db/services/common_service.py index fbbb645..ad87d65 100644 --- a/api/db/services/common_service.py +++ b/api/db/services/common_service.py @@ -27,7 +27,8 @@ class CommonService: @classmethod @DB.connection_context() def query(cls, cols=None, reverse=None, order_by=None, **kwargs): - return cls.model.query(cols=cols, reverse=reverse, order_by=order_by, **kwargs) + return cls.model.query(cols=cols, reverse=reverse, + order_by=order_by, **kwargs) @classmethod @DB.connection_context() @@ -40,9 +41,11 @@ class CommonService: if not order_by or not hasattr(cls, order_by): order_by = "create_time" if reverse is True: - query_records = query_records.order_by(cls.model.getter_by(order_by).desc()) + query_records = query_records.order_by( + 
cls.model.getter_by(order_by).desc()) elif reverse is False: - query_records = query_records.order_by(cls.model.getter_by(order_by).asc()) + query_records = query_records.order_by( + cls.model.getter_by(order_by).asc()) return query_records @classmethod @@ -61,7 +64,7 @@ class CommonService: @classmethod @DB.connection_context() def save(cls, **kwargs): - #if "id" not in kwargs: + # if "id" not in kwargs: # kwargs["id"] = get_uuid() sample_obj = cls.model(**kwargs).save(force_insert=True) return sample_obj @@ -95,7 +98,8 @@ class CommonService: for data in data_list: data["update_time"] = current_timestamp() data["update_date"] = datetime_format(datetime.now()) - cls.model.update(data).where(cls.model.id == data["id"]).execute() + cls.model.update(data).where( + cls.model.id == data["id"]).execute() @classmethod @DB.connection_context() @@ -128,7 +132,6 @@ class CommonService: def delete_by_id(cls, pid): return cls.model.delete().where(cls.model.id == pid).execute() - @classmethod @DB.connection_context() def filter_delete(cls, filters): @@ -151,19 +154,30 @@ class CommonService: @classmethod @DB.connection_context() - def filter_scope_list(cls, in_key, in_filters_list, filters=None, cols=None): + def filter_scope_list(cls, in_key, in_filters_list, + filters=None, cols=None): in_filters_tuple_list = cls.cut_list(in_filters_list, 20) if not filters: filters = [] res_list = [] if cols: for i in in_filters_tuple_list: - query_records = cls.model.select(*cols).where(getattr(cls.model, in_key).in_(i), *filters) + query_records = cls.model.select( + * + cols).where( + getattr( + cls.model, + in_key).in_(i), + * + filters) if query_records: - res_list.extend([query_record for query_record in query_records]) + res_list.extend( + [query_record for query_record in query_records]) else: for i in in_filters_tuple_list: - query_records = cls.model.select().where(getattr(cls.model, in_key).in_(i), *filters) + query_records = cls.model.select().where( + getattr(cls.model, in_key).in_(i), *filters) if query_records: - res_list.extend([query_record for query_record in query_records]) - return res_list \ No newline at end of file + res_list.extend( + [query_record for query_record in query_records]) + return res_list diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py index 2864e4f..bb770eb 100644 --- a/api/db/services/dialog_service.py +++ b/api/db/services/dialog_service.py @@ -21,6 +21,5 @@ class DialogService(CommonService): model = Dialog - class ConversationService(CommonService): model = Conversation diff --git a/api/db/services/document_service.py b/api/db/services/document_service.py index a01798f..bd85a94 100644 --- a/api/db/services/document_service.py +++ b/api/db/services/document_service.py @@ -72,7 +72,20 @@ class DocumentService(CommonService): @classmethod @DB.connection_context() def get_newly_uploaded(cls, tm, mod=0, comm=1, items_per_page=64): - fields = [cls.model.id, cls.model.kb_id, cls.model.parser_id, cls.model.parser_config, cls.model.name, cls.model.type, cls.model.location, cls.model.size, Knowledgebase.tenant_id, Tenant.embd_id, Tenant.img2txt_id, Tenant.asr_id, cls.model.update_time] + fields = [ + cls.model.id, + cls.model.kb_id, + cls.model.parser_id, + cls.model.parser_config, + cls.model.name, + cls.model.type, + cls.model.location, + cls.model.size, + Knowledgebase.tenant_id, + Tenant.embd_id, + Tenant.img2txt_id, + Tenant.asr_id, + cls.model.update_time] docs = cls.model.select(*fields) \ .join(Knowledgebase, on=(cls.model.kb_id == 
Knowledgebase.id)) \ .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))\ @@ -103,40 +116,64 @@ class DocumentService(CommonService): @DB.connection_context() def increment_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duation): num = cls.model.update(token_num=cls.model.token_num + token_num, - chunk_num=cls.model.chunk_num + chunk_num, - process_duation=cls.model.process_duation+duation).where( + chunk_num=cls.model.chunk_num + chunk_num, + process_duation=cls.model.process_duation + duation).where( cls.model.id == doc_id).execute() - if num == 0:raise LookupError("Document not found which is supposed to be there") - num = Knowledgebase.update(token_num=Knowledgebase.token_num+token_num, chunk_num=Knowledgebase.chunk_num+chunk_num).where(Knowledgebase.id==kb_id).execute() + if num == 0: + raise LookupError( + "Document not found which is supposed to be there") + num = Knowledgebase.update( + token_num=Knowledgebase.token_num + + token_num, + chunk_num=Knowledgebase.chunk_num + + chunk_num).where( + Knowledgebase.id == kb_id).execute() return num @classmethod @DB.connection_context() def get_tenant_id(cls, doc_id): - docs = cls.model.select(Knowledgebase.tenant_id).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(cls.model.id == doc_id, Knowledgebase.status==StatusEnum.VALID.value) + docs = cls.model.select( + Knowledgebase.tenant_id).join( + Knowledgebase, on=( + Knowledgebase.id == cls.model.kb_id)).where( + cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value) docs = docs.dicts() - if not docs:return + if not docs: + return return docs[0]["tenant_id"] @classmethod @DB.connection_context() def get_thumbnails(cls, docids): fields = [cls.model.id, cls.model.thumbnail] - return list(cls.model.select(*fields).where(cls.model.id.in_(docids)).dicts()) + return list(cls.model.select( + *fields).where(cls.model.id.in_(docids)).dicts()) @classmethod @DB.connection_context() def update_parser_config(cls, id, config): e, d = cls.get_by_id(id) - if not e:raise LookupError(f"Document({id}) not found.") + if not e: + raise LookupError(f"Document({id}) not found.") + def dfs_update(old, new): - for k,v in new.items(): + for k, v in new.items(): if k not in old: old[k] = v continue if isinstance(v, dict): assert isinstance(old[k], dict) dfs_update(old[k], v) - else: old[k] = v + else: + old[k] = v dfs_update(d.parser_config, config) - cls.update_by_id(id, {"parser_config": d.parser_config}) \ No newline at end of file + cls.update_by_id(id, {"parser_config": d.parser_config}) + + @classmethod + @DB.connection_context() + def get_doc_count(cls, tenant_id): + docs = cls.model.select(cls.model.id).join(Knowledgebase, + on=(Knowledgebase.id == cls.model.kb_id)).where( + Knowledgebase.tenant_id == tenant_id) + return len(docs) diff --git a/api/db/services/knowledgebase_service.py b/api/db/services/knowledgebase_service.py index be2c964..365f8ed 100644 --- a/api/db/services/knowledgebase_service.py +++ b/api/db/services/knowledgebase_service.py @@ -55,7 +55,7 @@ class KnowledgebaseService(CommonService): cls.model.chunk_num, cls.model.parser_id, cls.model.parser_config] - kbs = cls.model.select(*fields).join(Tenant, on=((Tenant.id == cls.model.tenant_id)&(Tenant.status== StatusEnum.VALID.value))).where( + kbs = cls.model.select(*fields).join(Tenant, on=((Tenant.id == cls.model.tenant_id) & (Tenant.status == StatusEnum.VALID.value))).where( (cls.model.id == kb_id), (cls.model.status == StatusEnum.VALID.value) ) @@ -69,9 +69,11 @@ class 
KnowledgebaseService(CommonService): @DB.connection_context() def update_parser_config(cls, id, config): e, m = cls.get_by_id(id) - if not e:raise LookupError(f"knowledgebase({id}) not found.") + if not e: + raise LookupError(f"knowledgebase({id}) not found.") + def dfs_update(old, new): - for k,v in new.items(): + for k, v in new.items(): if k not in old: old[k] = v continue @@ -80,12 +82,12 @@ class KnowledgebaseService(CommonService): dfs_update(old[k], v) elif isinstance(v, list): assert isinstance(old[k], list) - old[k] = list(set(old[k]+v)) - else: old[k] = v + old[k] = list(set(old[k] + v)) + else: + old[k] = v dfs_update(m.parser_config, config) cls.update_by_id(id, {"parser_config": m.parser_config}) - @classmethod @DB.connection_context() def get_field_map(cls, ids): @@ -94,4 +96,3 @@ class KnowledgebaseService(CommonService): if k.parser_config and "field_map" in k.parser_config: conf.update(k.parser_config["field_map"]) return conf - diff --git a/api/db/services/llm_service.py b/api/db/services/llm_service.py index 5bb54b1..f4e2a41 100644 --- a/api/db/services/llm_service.py +++ b/api/db/services/llm_service.py @@ -59,7 +59,8 @@ class TenantLLMService(CommonService): @classmethod @DB.connection_context() - def model_instance(cls, tenant_id, llm_type, llm_name=None, lang="Chinese"): + def model_instance(cls, tenant_id, llm_type, + llm_name=None, lang="Chinese"): e, tenant = TenantService.get_by_id(tenant_id) if not e: raise LookupError("Tenant not found") @@ -126,29 +127,39 @@ class LLMBundle(object): self.tenant_id = tenant_id self.llm_type = llm_type self.llm_name = llm_name - self.mdl = TenantLLMService.model_instance(tenant_id, llm_type, llm_name, lang=lang) - assert self.mdl, "Can't find mole for {}/{}/{}".format(tenant_id, llm_type, llm_name) + self.mdl = TenantLLMService.model_instance( + tenant_id, llm_type, llm_name, lang=lang) + assert self.mdl, "Can't find mole for {}/{}/{}".format( + tenant_id, llm_type, llm_name) def encode(self, texts: list, batch_size=32): emd, used_tokens = self.mdl.encode(texts, batch_size) - if TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens): - database_logger.error("Can't update token usage for {}/EMBEDDING".format(self.tenant_id)) + if TenantLLMService.increase_usage( + self.tenant_id, self.llm_type, used_tokens): + database_logger.error( + "Can't update token usage for {}/EMBEDDING".format(self.tenant_id)) return emd, used_tokens def encode_queries(self, query: str): emd, used_tokens = self.mdl.encode_queries(query) - if TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens): - database_logger.error("Can't update token usage for {}/EMBEDDING".format(self.tenant_id)) + if TenantLLMService.increase_usage( + self.tenant_id, self.llm_type, used_tokens): + database_logger.error( + "Can't update token usage for {}/EMBEDDING".format(self.tenant_id)) return emd, used_tokens def describe(self, image, max_tokens=300): txt, used_tokens = self.mdl.describe(image, max_tokens) - if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens): - database_logger.error("Can't update token usage for {}/IMAGE2TEXT".format(self.tenant_id)) + if not TenantLLMService.increase_usage( + self.tenant_id, self.llm_type, used_tokens): + database_logger.error( + "Can't update token usage for {}/IMAGE2TEXT".format(self.tenant_id)) return txt def chat(self, system, history, gen_conf): txt, used_tokens = self.mdl.chat(system, history, gen_conf) - if TenantLLMService.increase_usage(self.tenant_id, 
self.llm_type, used_tokens, self.llm_name): - database_logger.error("Can't update token usage for {}/CHAT".format(self.tenant_id)) + if TenantLLMService.increase_usage( + self.tenant_id, self.llm_type, used_tokens, self.llm_name): + database_logger.error( + "Can't update token usage for {}/CHAT".format(self.tenant_id)) return txt diff --git a/api/db/services/user_service.py b/api/db/services/user_service.py index fe68783..4194ff5 100644 --- a/api/db/services/user_service.py +++ b/api/db/services/user_service.py @@ -54,7 +54,8 @@ class UserService(CommonService): if "id" not in kwargs: kwargs["id"] = get_uuid() if "password" in kwargs: - kwargs["password"] = generate_password_hash(str(kwargs["password"])) + kwargs["password"] = generate_password_hash( + str(kwargs["password"])) kwargs["create_time"] = current_timestamp() kwargs["create_date"] = datetime_format(datetime.now()) @@ -63,12 +64,12 @@ class UserService(CommonService): obj = cls.model(**kwargs).save(force_insert=True) return obj - @classmethod @DB.connection_context() def delete_user(cls, user_ids, update_user_dict): with DB.atomic(): - cls.model.update({"status": 0}).where(cls.model.id.in_(user_ids)).execute() + cls.model.update({"status": 0}).where( + cls.model.id.in_(user_ids)).execute() @classmethod @DB.connection_context() @@ -77,7 +78,8 @@ class UserService(CommonService): if user_dict: user_dict["update_time"] = current_timestamp() user_dict["update_date"] = datetime_format(datetime.now()) - cls.model.update(user_dict).where(cls.model.id == user_id).execute() + cls.model.update(user_dict).where( + cls.model.id == user_id).execute() class TenantService(CommonService): @@ -86,25 +88,42 @@ class TenantService(CommonService): @classmethod @DB.connection_context() def get_by_user_id(cls, user_id): - fields = [cls.model.id.alias("tenant_id"), cls.model.name, cls.model.llm_id, cls.model.embd_id, cls.model.asr_id, cls.model.img2txt_id, cls.model.parser_ids, UserTenant.role] - return list(cls.model.select(*fields)\ - .join(UserTenant, on=((cls.model.id == UserTenant.tenant_id) & (UserTenant.user_id==user_id) & (UserTenant.status == StatusEnum.VALID.value)))\ - .where(cls.model.status == StatusEnum.VALID.value).dicts()) + fields = [ + cls.model.id.alias("tenant_id"), + cls.model.name, + cls.model.llm_id, + cls.model.embd_id, + cls.model.asr_id, + cls.model.img2txt_id, + cls.model.parser_ids, + UserTenant.role] + return list(cls.model.select(*fields) + .join(UserTenant, on=((cls.model.id == UserTenant.tenant_id) & (UserTenant.user_id == user_id) & (UserTenant.status == StatusEnum.VALID.value))) + .where(cls.model.status == StatusEnum.VALID.value).dicts()) @classmethod @DB.connection_context() def get_joined_tenants_by_user_id(cls, user_id): - fields = [cls.model.id.alias("tenant_id"), cls.model.name, cls.model.llm_id, cls.model.embd_id, cls.model.asr_id, cls.model.img2txt_id, UserTenant.role] - return list(cls.model.select(*fields)\ - .join(UserTenant, on=((cls.model.id == UserTenant.tenant_id) & (UserTenant.user_id==user_id) & (UserTenant.status == StatusEnum.VALID.value) & (UserTenant.role==UserTenantRole.NORMAL.value)))\ - .where(cls.model.status == StatusEnum.VALID.value).dicts()) + fields = [ + cls.model.id.alias("tenant_id"), + cls.model.name, + cls.model.llm_id, + cls.model.embd_id, + cls.model.asr_id, + cls.model.img2txt_id, + UserTenant.role] + return list(cls.model.select(*fields) + .join(UserTenant, on=((cls.model.id == UserTenant.tenant_id) & (UserTenant.user_id == user_id) & (UserTenant.status == StatusEnum.VALID.value) & 
(UserTenant.role == UserTenantRole.NORMAL.value))) + .where(cls.model.status == StatusEnum.VALID.value).dicts()) @classmethod @DB.connection_context() def decrease(cls, user_id, num): num = cls.model.update(credit=cls.model.credit - num).where( cls.model.id == user_id).execute() - if num == 0: raise LookupError("Tenant not found which is supposed to be there") + if num == 0: + raise LookupError("Tenant not found which is supposed to be there") + class UserTenantService(CommonService): model = UserTenant diff --git a/api/settings.py b/api/settings.py index 93c5906..0142fd0 100644 --- a/api/settings.py +++ b/api/settings.py @@ -13,16 +13,22 @@ # See the License for the specific language governing permissions and # limitations under the License. # +from rag.utils import ELASTICSEARCH +from rag.nlp import search import os from enum import IntEnum, Enum -from api.utils import get_base_config,decrypt_database_config +from api.utils import get_base_config, decrypt_database_config from api.utils.file_utils import get_project_base_directory from api.utils.log_utils import LoggerFactory, getLogger # Logger -LoggerFactory.set_directory(os.path.join(get_project_base_directory(), "logs", "api")) +LoggerFactory.set_directory( + os.path.join( + get_project_base_directory(), + "logs", + "api")) # {CRITICAL: 50, FATAL:50, ERROR:40, WARNING:30, WARN:30, INFO:20, DEBUG:10, NOTSET:0} LoggerFactory.LEVEL = 10 @@ -86,7 +92,9 @@ default_llm = { LLM = get_base_config("user_default_llm", {}) LLM_FACTORY = LLM.get("factory", "Tongyi-Qianwen") if LLM_FACTORY not in default_llm: - print("\33[91mă€ERROR】\33[0m:", f"LLM factory {LLM_FACTORY} has not supported yet, switch to 'Tongyi-Qianwen/QWen' automatically, and please check the API_KEY in service_conf.yaml.") + print( + "\33[91mă€ERROR】\33[0m:", + f"LLM factory {LLM_FACTORY} has not supported yet, switch to 'Tongyi-Qianwen/QWen' automatically, and please check the API_KEY in service_conf.yaml.") LLM_FACTORY = "Tongyi-Qianwen" CHAT_MDL = default_llm[LLM_FACTORY]["chat_model"] EMBEDDING_MDL = default_llm[LLM_FACTORY]["embedding_model"] @@ -94,7 +102,9 @@ ASR_MDL = default_llm[LLM_FACTORY]["asr_model"] IMAGE2TEXT_MDL = default_llm[LLM_FACTORY]["image2text_model"] API_KEY = LLM.get("api_key", "") -PARSERS = LLM.get("parsers", "naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One") +PARSERS = LLM.get( + "parsers", + "naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One") # distribution DEPENDENT_DISTRIBUTION = get_base_config("dependent_distribution", False) @@ -103,13 +113,25 @@ RAG_FLOW_UPDATE_CHECK = False HOST = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("host", "127.0.0.1") HTTP_PORT = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("http_port") -SECRET_KEY = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("secret_key", "infiniflow") -TOKEN_EXPIRE_IN = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("token_expires_in", 3600) - -NGINX_HOST = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("nginx", {}).get("host") or HOST -NGINX_HTTP_PORT = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("nginx", {}).get("http_port") or HTTP_PORT - -RANDOM_INSTANCE_ID = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("random_instance_id", False) +SECRET_KEY = get_base_config( + RAG_FLOW_SERVICE_NAME, + {}).get( + "secret_key", + "infiniflow") +TOKEN_EXPIRE_IN = get_base_config( + RAG_FLOW_SERVICE_NAME, 
{}).get( + "token_expires_in", 3600) + +NGINX_HOST = get_base_config( + RAG_FLOW_SERVICE_NAME, {}).get( + "nginx", {}).get("host") or HOST +NGINX_HTTP_PORT = get_base_config( + RAG_FLOW_SERVICE_NAME, {}).get( + "nginx", {}).get("http_port") or HTTP_PORT + +RANDOM_INSTANCE_ID = get_base_config( + RAG_FLOW_SERVICE_NAME, {}).get( + "random_instance_id", False) PROXY = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("proxy") PROXY_PROTOCOL = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("protocol") @@ -124,7 +146,9 @@ UPLOAD_DATA_FROM_CLIENT = True AUTHENTICATION_CONF = get_base_config("authentication", {}) # client -CLIENT_AUTHENTICATION = AUTHENTICATION_CONF.get("client", {}).get("switch", False) +CLIENT_AUTHENTICATION = AUTHENTICATION_CONF.get( + "client", {}).get( + "switch", False) HTTP_APP_KEY = AUTHENTICATION_CONF.get("client", {}).get("http_app_key") GITHUB_OAUTH = get_base_config("oauth", {}).get("github") WECHAT_OAUTH = get_base_config("oauth", {}).get("wechat") @@ -147,12 +171,10 @@ USE_AUTHENTICATION = False USE_DATA_AUTHENTICATION = False AUTOMATIC_AUTHORIZATION_OUTPUT_DATA = True USE_DEFAULT_TIMEOUT = False -AUTHENTICATION_DEFAULT_TIMEOUT = 7 * 24 * 60 * 60 # s +AUTHENTICATION_DEFAULT_TIMEOUT = 7 * 24 * 60 * 60 # s PRIVILEGE_COMMAND_WHITELIST = [] CHECK_NODES_IDENTITY = False -from rag.nlp import search -from rag.utils import ELASTICSEARCH retrievaler = search.Dealer(ELASTICSEARCH) @@ -162,7 +184,7 @@ class CustomEnum(Enum): try: cls(value) return True - except: + except BaseException: return False @classmethod diff --git a/api/utils/__init__.py b/api/utils/__init__.py index 9ae6e0c..65c6b31 100644 --- a/api/utils/__init__.py +++ b/api/utils/__init__.py @@ -34,10 +34,12 @@ from . import file_utils SERVICE_CONF = "service_conf.yaml" + def conf_realpath(conf_name): conf_path = f"conf/{conf_name}" return os.path.join(file_utils.get_project_base_directory(), conf_path) + def get_base_config(key, default=None, conf_name=SERVICE_CONF) -> dict: local_config = {} local_path = conf_realpath(f'local.{conf_name}') @@ -62,7 +64,8 @@ def get_base_config(key, default=None, conf_name=SERVICE_CONF) -> dict: return config.get(key, default) if key is not None else config -use_deserialize_safe_module = get_base_config('use_deserialize_safe_module', False) +use_deserialize_safe_module = get_base_config( + 'use_deserialize_safe_module', False) class CoordinationCommunicationProtocol(object): @@ -93,7 +96,8 @@ class BaseType: data[_k] = _dict(vv) else: data = obj - return {"type": obj.__class__.__name__, "data": data, "module": module} + return {"type": obj.__class__.__name__, + "data": data, "module": module} return _dict(self) @@ -129,7 +133,8 @@ def rag_uuid(): def string_to_bytes(string): - return string if isinstance(string, bytes) else string.encode(encoding="utf-8") + return string if isinstance( + string, bytes) else string.encode(encoding="utf-8") def bytes_to_string(byte): @@ -137,7 +142,11 @@ def bytes_to_string(byte): def json_dumps(src, byte=False, indent=None, with_type=False): - dest = json.dumps(src, indent=indent, cls=CustomJSONEncoder, with_type=with_type) + dest = json.dumps( + src, + indent=indent, + cls=CustomJSONEncoder, + with_type=with_type) if byte: dest = string_to_bytes(dest) return dest @@ -146,7 +155,8 @@ def json_dumps(src, byte=False, indent=None, with_type=False): def json_loads(src, object_hook=None, object_pairs_hook=None): if isinstance(src, bytes): src = bytes_to_string(src) - return json.loads(src, object_hook=object_hook, object_pairs_hook=object_pairs_hook) + 
return json.loads(src, object_hook=object_hook, + object_pairs_hook=object_pairs_hook) def current_timestamp(): @@ -177,7 +187,9 @@ def serialize_b64(src, to_str=False): def deserialize_b64(src): - src = base64.b64decode(string_to_bytes(src) if isinstance(src, str) else src) + src = base64.b64decode( + string_to_bytes(src) if isinstance( + src, str) else src) if use_deserialize_safe_module: return restricted_loads(src) return pickle.loads(src) @@ -237,12 +249,14 @@ def get_lan_ip(): pass return ip or '' + def from_dict_hook(in_dict: dict): if "type" in in_dict and "data" in in_dict: if in_dict["module"] is None: return in_dict["data"] else: - return getattr(importlib.import_module(in_dict["module"]), in_dict["type"])(**in_dict["data"]) + return getattr(importlib.import_module( + in_dict["module"]), in_dict["type"])(**in_dict["data"]) else: return in_dict @@ -259,12 +273,16 @@ def decrypt_database_password(password): raise ValueError("No private key") module_fun = encrypt_module.split("#") - pwdecrypt_fun = getattr(importlib.import_module(module_fun[0]), module_fun[1]) + pwdecrypt_fun = getattr( + importlib.import_module( + module_fun[0]), + module_fun[1]) return pwdecrypt_fun(private_key, password) -def decrypt_database_config(database=None, passwd_key="password", name="database"): +def decrypt_database_config( + database=None, passwd_key="password", name="database"): if not database: database = get_base_config(name, {}) @@ -275,7 +293,8 @@ def decrypt_database_config(database=None, passwd_key="password", name="database def update_config(key, value, conf_name=SERVICE_CONF): conf_path = conf_realpath(conf_name=conf_name) if not os.path.isabs(conf_path): - conf_path = os.path.join(file_utils.get_project_base_directory(), conf_path) + conf_path = os.path.join( + file_utils.get_project_base_directory(), conf_path) with FileLock(os.path.join(os.path.dirname(conf_path), ".lock")): config = file_utils.load_yaml_conf(conf_path=conf_path) or {} @@ -288,7 +307,8 @@ def get_uuid(): def datetime_format(date_time: datetime.datetime) -> datetime.datetime: - return datetime.datetime(date_time.year, date_time.month, date_time.day, date_time.hour, date_time.minute, date_time.second) + return datetime.datetime(date_time.year, date_time.month, date_time.day, + date_time.hour, date_time.minute, date_time.second) def get_format_time() -> datetime.datetime: @@ -307,14 +327,19 @@ def elapsed2time(elapsed): def decrypt(line): - file_path = os.path.join(file_utils.get_project_base_directory(), "conf", "private.pem") + file_path = os.path.join( + file_utils.get_project_base_directory(), + "conf", + "private.pem") rsa_key = RSA.importKey(open(file_path).read(), "Welcome") cipher = Cipher_pkcs1_v1_5.new(rsa_key) - return cipher.decrypt(base64.b64decode(line), "Fail to decrypt password!").decode('utf-8') + return cipher.decrypt(base64.b64decode( + line), "Fail to decrypt password!").decode('utf-8') def download_img(url): - if not url: return "" + if not url: + return "" response = requests.get(url) return "data:" + \ response.headers.get('Content-Type', 'image/jpg') + ";" + \ diff --git a/api/utils/api_utils.py b/api/utils/api_utils.py index 08d386d..bf6ddb3 100644 --- a/api/utils/api_utils.py +++ b/api/utils/api_utils.py @@ -19,7 +19,7 @@ import time from functools import wraps from io import BytesIO from flask import ( - Response, jsonify, send_file,make_response, + Response, jsonify, send_file, make_response, request as flask_request, ) from werkzeug.http import HTTP_STATUS_CODES @@ -29,7 +29,7 @@ from api.versions 
import get_rag_version from api.settings import RetCode from api.settings import ( REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC, - stat_logger,CLIENT_AUTHENTICATION, HTTP_APP_KEY, SECRET_KEY + stat_logger, CLIENT_AUTHENTICATION, HTTP_APP_KEY, SECRET_KEY ) import requests import functools @@ -40,14 +40,21 @@ from hmac import HMAC from urllib.parse import quote, urlencode -requests.models.complexjson.dumps = functools.partial(json.dumps, cls=CustomJSONEncoder) +requests.models.complexjson.dumps = functools.partial( + json.dumps, cls=CustomJSONEncoder) def request(**kwargs): sess = requests.Session() stream = kwargs.pop('stream', sess.stream) timeout = kwargs.pop('timeout', None) - kwargs['headers'] = {k.replace('_', '-').upper(): v for k, v in kwargs.get('headers', {}).items()} + kwargs['headers'] = { + k.replace( + '_', + '-').upper(): v for k, + v in kwargs.get( + 'headers', + {}).items()} prepped = requests.Request(**kwargs).prepare() if CLIENT_AUTHENTICATION and HTTP_APP_KEY and SECRET_KEY: @@ -59,7 +66,11 @@ def request(**kwargs): HTTP_APP_KEY.encode('ascii'), prepped.path_url.encode('ascii'), prepped.body if kwargs.get('json') else b'', - urlencode(sorted(kwargs['data'].items()), quote_via=quote, safe='-._~').encode('ascii') + urlencode( + sorted( + kwargs['data'].items()), + quote_via=quote, + safe='-._~').encode('ascii') if kwargs.get('data') and isinstance(kwargs['data'], dict) else b'', ]), 'sha1').digest()).decode('ascii') @@ -88,11 +99,12 @@ def get_exponential_backoff_interval(retries, full_jitter=False): return max(0, countdown) -def get_json_result(retcode=RetCode.SUCCESS, retmsg='success', data=None, job_id=None, meta=None): +def get_json_result(retcode=RetCode.SUCCESS, retmsg='success', + data=None, job_id=None, meta=None): import re result_dict = { "retcode": retcode, - "retmsg":retmsg, + "retmsg": retmsg, # "retmsg": re.sub(r"rag", "seceum", retmsg, flags=re.IGNORECASE), "data": data, "jobId": job_id, @@ -107,9 +119,17 @@ def get_json_result(retcode=RetCode.SUCCESS, retmsg='success', data=None, job_id response[key] = value return jsonify(response) -def get_data_error_result(retcode=RetCode.DATA_ERROR, retmsg='Sorry! Data missing!'): + +def get_data_error_result(retcode=RetCode.DATA_ERROR, + retmsg='Sorry! Data missing!'): import re - result_dict = {"retcode": retcode, "retmsg": re.sub(r"rag", "seceum", retmsg, flags=re.IGNORECASE)} + result_dict = { + "retcode": retcode, + "retmsg": re.sub( + r"rag", + "seceum", + retmsg, + flags=re.IGNORECASE)} response = {} for key, value in result_dict.items(): if value is None and key != "retcode": @@ -118,15 +138,17 @@ def get_data_error_result(retcode=RetCode.DATA_ERROR, retmsg='Sorry! 
Data missin response[key] = value return jsonify(response) + def server_error_response(e): stat_logger.exception(e) try: - if e.code==401: + if e.code == 401: return get_json_result(retcode=401, retmsg=repr(e)) - except: + except BaseException: pass if len(e.args) > 1: - return get_json_result(retcode=RetCode.EXCEPTION_ERROR, retmsg=repr(e.args[0]), data=e.args[1]) + return get_json_result( + retcode=RetCode.EXCEPTION_ERROR, retmsg=repr(e.args[0]), data=e.args[1]) return get_json_result(retcode=RetCode.EXCEPTION_ERROR, retmsg=repr(e)) @@ -162,10 +184,13 @@ def validate_request(*args, **kwargs): if no_arguments or error_arguments: error_string = "" if no_arguments: - error_string += "required argument are missing: {}; ".format(",".join(no_arguments)) + error_string += "required argument are missing: {}; ".format( + ",".join(no_arguments)) if error_arguments: - error_string += "required argument values: {}".format(",".join(["{}={}".format(a[0], a[1]) for a in error_arguments])) - return get_json_result(retcode=RetCode.ARGUMENT_ERROR, retmsg=error_string) + error_string += "required argument values: {}".format( + ",".join(["{}={}".format(a[0], a[1]) for a in error_arguments])) + return get_json_result( + retcode=RetCode.ARGUMENT_ERROR, retmsg=error_string) return func(*_args, **_kwargs) return decorated_function return wrapper @@ -193,7 +218,8 @@ def get_json_result(retcode=RetCode.SUCCESS, retmsg='success', data=None): return jsonify(response) -def cors_reponse(retcode=RetCode.SUCCESS, retmsg='success', data=None, auth=None): +def cors_reponse(retcode=RetCode.SUCCESS, + retmsg='success', data=None, auth=None): result_dict = {"retcode": retcode, "retmsg": retmsg, "data": data} response_dict = {} for key, value in result_dict.items(): @@ -209,4 +235,4 @@ def cors_reponse(retcode=RetCode.SUCCESS, retmsg='success', data=None, auth=None response.headers["Access-Control-Allow-Headers"] = "*" response.headers["Access-Control-Allow-Headers"] = "*" response.headers["Access-Control-Expose-Headers"] = "Authorization" - return response \ No newline at end of file + return response diff --git a/api/utils/file_utils.py b/api/utils/file_utils.py index 7f81459..159090f 100644 --- a/api/utils/file_utils.py +++ b/api/utils/file_utils.py @@ -29,6 +29,7 @@ from api.db import FileType PROJECT_BASE = os.getenv("RAG_PROJECT_BASE") or os.getenv("RAG_DEPLOY_BASE") RAG_BASE = os.getenv("RAG_BASE") + def get_project_base_directory(*args): global PROJECT_BASE if PROJECT_BASE is None: @@ -65,7 +66,6 @@ def get_rag_python_directory(*args): return get_rag_directory("python", *args) - @cached(cache=LRUCache(maxsize=10)) def load_json_conf(conf_path): if os.path.isabs(conf_path): @@ -146,10 +146,12 @@ def filename_type(filename): if re.match(r".*\.pdf$", filename): return FileType.PDF.value - if re.match(r".*\.(docx|doc|ppt|pptx|yml|xml|htm|json|csv|txt|ini|xls|xlsx|wps|rtf|hlp|pages|numbers|key|md)$", filename): + if re.match( + r".*\.(docx|doc|ppt|pptx|yml|xml|htm|json|csv|txt|ini|xls|xlsx|wps|rtf|hlp|pages|numbers|key|md)$", filename): return FileType.DOC.value - if re.match(r".*\.(wav|flac|ape|alac|wavpack|wv|mp3|aac|ogg|vorbis|opus|mp3)$", filename): + if re.match( + r".*\.(wav|flac|ape|alac|wavpack|wv|mp3|aac|ogg|vorbis|opus|mp3)$", filename): return FileType.AURAL.value if re.match(r".*\.(jpg|jpeg|png|tif|gif|pcx|tga|exif|fpx|svg|psd|cdr|pcd|dxf|ufo|eps|ai|raw|WMF|webp|avif|apng|icon|ico|mpg|mpeg|avi|rm|rmvb|mov|wmv|asf|dat|asx|wvx|mpe|mpa|mp4)$", filename): @@ -164,14 +166,16 @@ def thumbnail(filename, blob): 
buffered = BytesIO() Image.frombytes("RGB", [pix.width, pix.height], pix.samples).save(buffered, format="png") - return "data:image/png;base64," + base64.b64encode(buffered.getvalue()).decode("utf-8") + return "data:image/png;base64," + \ + base64.b64encode(buffered.getvalue()).decode("utf-8") if re.match(r".*\.(jpg|jpeg|png|tif|gif|icon|ico|webp)$", filename): image = Image.open(BytesIO(blob)) image.thumbnail((30, 30)) buffered = BytesIO() image.save(buffered, format="png") - return "data:image/png;base64," + base64.b64encode(buffered.getvalue()).decode("utf-8") + return "data:image/png;base64," + \ + base64.b64encode(buffered.getvalue()).decode("utf-8") if re.match(r".*\.(ppt|pptx)$", filename): import aspose.slides as slides @@ -179,8 +183,10 @@ def thumbnail(filename, blob): try: with slides.Presentation(BytesIO(blob)) as presentation: buffered = BytesIO() - presentation.slides[0].get_thumbnail(0.03, 0.03).save(buffered, drawing.imaging.ImageFormat.png) - return "data:image/png;base64," + base64.b64encode(buffered.getvalue()).decode("utf-8") + presentation.slides[0].get_thumbnail(0.03, 0.03).save( + buffered, drawing.imaging.ImageFormat.png) + return "data:image/png;base64," + \ + base64.b64encode(buffered.getvalue()).decode("utf-8") except Exception as e: pass @@ -190,6 +196,3 @@ def traversal_files(base): for f in fs: fullname = os.path.join(root, f) yield fullname - - - diff --git a/api/utils/log_utils.py b/api/utils/log_utils.py index ee59e47..528bd99 100644 --- a/api/utils/log_utils.py +++ b/api/utils/log_utils.py @@ -23,6 +23,7 @@ from threading import RLock from api.utils import file_utils + class LoggerFactory(object): TYPE = "FILE" LOG_FORMAT = "[%(levelname)s] [%(asctime)s] [jobId] [%(process)s:%(thread)s] - [%(module)s.%(funcName)s] [line:%(lineno)d]: %(message)s" @@ -49,7 +50,8 @@ class LoggerFactory(object): schedule_logger_dict = {} @staticmethod - def set_directory(directory=None, parent_log_dir=None, append_to_parent_log=None, force=False): + def set_directory(directory=None, parent_log_dir=None, + append_to_parent_log=None, force=False): if parent_log_dir: LoggerFactory.PARENT_LOG_DIR = parent_log_dir if append_to_parent_log: @@ -66,11 +68,13 @@ class LoggerFactory(object): else: os.makedirs(LoggerFactory.LOG_DIR, exist_ok=True) for loggerName, ghandler in LoggerFactory.global_handler_dict.items(): - for className, (logger, handler) in LoggerFactory.logger_dict.items(): + for className, (logger, + handler) in LoggerFactory.logger_dict.items(): logger.removeHandler(ghandler) ghandler.close() LoggerFactory.global_handler_dict = {} - for className, (logger, handler) in LoggerFactory.logger_dict.items(): + for className, (logger, + handler) in LoggerFactory.logger_dict.items(): logger.removeHandler(handler) _handler = None if handler: @@ -111,19 +115,23 @@ class LoggerFactory(object): if logger_name_key not in LoggerFactory.global_handler_dict: with LoggerFactory.lock: if logger_name_key not in LoggerFactory.global_handler_dict: - handler = LoggerFactory.get_handler(logger_name, level, log_dir) + handler = LoggerFactory.get_handler( + logger_name, level, log_dir) LoggerFactory.global_handler_dict[logger_name_key] = handler return LoggerFactory.global_handler_dict[logger_name_key] @staticmethod - def get_handler(class_name, level=None, log_dir=None, log_type=None, job_id=None): + def get_handler(class_name, level=None, log_dir=None, + log_type=None, job_id=None): if not log_type: if not LoggerFactory.LOG_DIR or not class_name: return logging.StreamHandler() # return 
Diy_StreamHandler() if not log_dir: - log_file = os.path.join(LoggerFactory.LOG_DIR, "{}.log".format(class_name)) + log_file = os.path.join( + LoggerFactory.LOG_DIR, + "{}.log".format(class_name)) else: log_file = os.path.join(log_dir, "{}.log".format(class_name)) else: @@ -133,16 +141,16 @@ class LoggerFactory(object): os.makedirs(os.path.dirname(log_file), exist_ok=True) if LoggerFactory.log_share: handler = ROpenHandler(log_file, - when='D', - interval=1, - backupCount=14, - delay=True) + when='D', + interval=1, + backupCount=14, + delay=True) else: handler = TimedRotatingFileHandler(log_file, - when='D', - interval=1, - backupCount=14, - delay=True) + when='D', + interval=1, + backupCount=14, + delay=True) if level: handler.level = level @@ -170,7 +178,9 @@ class LoggerFactory(object): for level in LoggerFactory.levels: if level >= LoggerFactory.LEVEL: level_logger_name = logging._levelToName[level] - logger.addHandler(LoggerFactory.get_global_handler(level_logger_name, level)) + logger.addHandler( + LoggerFactory.get_global_handler( + level_logger_name, level)) if LoggerFactory.append_to_parent_log and LoggerFactory.PARENT_LOG_DIR: for level in LoggerFactory.levels: if level >= LoggerFactory.LEVEL: @@ -224,22 +234,26 @@ def start_log(msg, job=None, task=None, role=None, party_id=None, detail=None): return f"{prefix}start to {msg}{suffix}" -def successful_log(msg, job=None, task=None, role=None, party_id=None, detail=None): +def successful_log(msg, job=None, task=None, role=None, + party_id=None, detail=None): prefix, suffix = base_msg(job, task, role, party_id, detail) return f"{prefix}{msg} successfully{suffix}" -def warning_log(msg, job=None, task=None, role=None, party_id=None, detail=None): +def warning_log(msg, job=None, task=None, role=None, + party_id=None, detail=None): prefix, suffix = base_msg(job, task, role, party_id, detail) return f"{prefix}{msg} is not effective{suffix}" -def failed_log(msg, job=None, task=None, role=None, party_id=None, detail=None): +def failed_log(msg, job=None, task=None, role=None, + party_id=None, detail=None): prefix, suffix = base_msg(job, task, role, party_id, detail) return f"{prefix}failed to {msg}{suffix}" -def base_msg(job=None, task=None, role: str = None, party_id: typing.Union[str, int] = None, detail=None): +def base_msg(job=None, task=None, role: str = None, + party_id: typing.Union[str, int] = None, detail=None): if detail: detail_msg = f" detail: \n{detail}" else: @@ -285,10 +299,14 @@ def get_job_logger(job_id, log_type): for job_log_dir in log_dirs: handler = LoggerFactory.get_handler(class_name=None, level=LoggerFactory.LEVEL, log_dir=job_log_dir, log_type=log_type, job_id=job_id) - error_handler = LoggerFactory.get_handler(class_name=None, level=logging.ERROR, log_dir=job_log_dir, log_type=log_type, job_id=job_id) + error_handler = LoggerFactory.get_handler( + class_name=None, + level=logging.ERROR, + log_dir=job_log_dir, + log_type=log_type, + job_id=job_id) logger.addHandler(handler) logger.addHandler(error_handler) with LoggerFactory.lock: LoggerFactory.schedule_logger_dict[job_id + log_type] = logger return logger - diff --git a/api/utils/t_crypt.py b/api/utils/t_crypt.py index 224bf22..6defa22 100644 --- a/api/utils/t_crypt.py +++ b/api/utils/t_crypt.py @@ -1,18 +1,23 @@ -import base64, os, sys +import base64 +import os +import sys from Cryptodome.PublicKey import RSA from Cryptodome.Cipher import PKCS1_v1_5 as Cipher_pkcs1_v1_5 from api.utils import decrypt, file_utils + def crypt(line): - file_path = 
os.path.join(file_utils.get_project_base_directory(), "conf", "public.pem") + file_path = os.path.join( + file_utils.get_project_base_directory(), + "conf", + "public.pem") rsa_key = RSA.importKey(open(file_path).read()) cipher = Cipher_pkcs1_v1_5.new(rsa_key) - return base64.b64encode(cipher.encrypt(line.encode('utf-8'))).decode("utf-8") - + return base64.b64encode(cipher.encrypt( + line.encode('utf-8'))).decode("utf-8") if __name__ == "__main__": pswd = crypt(sys.argv[1]) print(pswd) print(decrypt(pswd)) - diff --git a/deepdoc/parser/__init__.py b/deepdoc/parser/__init__.py index 8204287..30353b3 100644 --- a/deepdoc/parser/__init__.py +++ b/deepdoc/parser/__init__.py @@ -4,5 +4,3 @@ from .pdf_parser import HuParser as PdfParser, PlainParser from .docx_parser import HuDocxParser as DocxParser from .excel_parser import HuExcelParser as ExcelParser from .ppt_parser import HuPptParser as PptParser - - diff --git a/deepdoc/parser/docx_parser.py b/deepdoc/parser/docx_parser.py index 2ee0edb..10a84d5 100644 --- a/deepdoc/parser/docx_parser.py +++ b/deepdoc/parser/docx_parser.py @@ -99,12 +99,15 @@ class HuDocxParser: return ["\n".join(lines)] def __call__(self, fnm, from_page=0, to_page=100000): - self.doc = Document(fnm) if isinstance(fnm, str) else Document(BytesIO(fnm)) + self.doc = Document(fnm) if isinstance( + fnm, str) else Document(BytesIO(fnm)) pn = 0 secs = [] for p in self.doc.paragraphs: - if pn > to_page: break - if from_page <= pn < to_page and p.text.strip(): secs.append((p.text, p.style.name)) + if pn > to_page: + break + if from_page <= pn < to_page and p.text.strip(): + secs.append((p.text, p.style.name)) for run in p.runs: if 'lastRenderedPageBreak' in run._element.xml: pn += 1 diff --git a/deepdoc/parser/excel_parser.py b/deepdoc/parser/excel_parser.py index 4b436ec..7d470f3 100644 --- a/deepdoc/parser/excel_parser.py +++ b/deepdoc/parser/excel_parser.py @@ -15,13 +15,16 @@ class HuExcelParser: ws = wb[sheetname] rows = list(ws.rows) tb += f"<table><caption>{sheetname}</caption><tr>" - for t in list(rows[0]): tb += f"<th>{t.value}</th>" + for t in list(rows[0]): + tb += f"<th>{t.value}</th>" tb += "</tr>" for r in list(rows[1:]): tb += "<tr>" - for i,c in enumerate(r): - if c.value is None: tb += "<td></td>" - else: tb += f"<td>{c.value}</td>" + for i, c in enumerate(r): + if c.value is None: + tb += "<td></td>" + else: + tb += f"<td>{c.value}</td>" tb += "</tr>" tb += "</table>\n" return tb @@ -38,13 +41,15 @@ class HuExcelParser: ti = list(rows[0]) for r in list(rows[1:]): l = [] - for i,c in enumerate(r): - if not c.value:continue + for i, c in enumerate(r): + if not c.value: + continue t = str(ti[i].value) if i < len(ti) else "" t += (":" if t else "") + str(c.value) l.append(t) l = "; ".join(l) - if sheetname.lower().find("sheet") <0: l += " ——"+sheetname + if sheetname.lower().find("sheet") < 0: + l += " ——" + sheetname res.append(l) return res diff --git a/deepdoc/parser/pdf_parser.py b/deepdoc/parser/pdf_parser.py index ed6edea..10257ec 100644 --- a/deepdoc/parser/pdf_parser.py +++ b/deepdoc/parser/pdf_parser.py @@ -43,9 +43,11 @@ class HuParser: "rag/res/deepdoc"), local_files_only=True) except Exception as e: - model_dir = snapshot_download(repo_id="InfiniFlow/text_concat_xgb_v1.0") + model_dir = snapshot_download( + repo_id="InfiniFlow/text_concat_xgb_v1.0") - self.updown_cnt_mdl.load_model(os.path.join(model_dir, "updown_concat_xgb.model")) + self.updown_cnt_mdl.load_model(os.path.join( + model_dir, "updown_concat_xgb.model")) self.page_from = 0 """ If you have 
trouble downloading HuggingFace models, -_^ this might help!!
@@ -72,7 +74,7 @@ class HuParser:
     def _y_dis(
             self, a, b):
         return (
-                b["top"] + b["bottom"] - a["top"] - a["bottom"]) / 2
+            b["top"] + b["bottom"] - a["top"] - a["bottom"]) / 2
 
     def _match_proj(self, b):
         proj_patt = [
@@ -95,9 +97,9 @@ class HuParser:
         tks_down = huqie.qie(down["text"][:LEN]).split(" ")
         tks_up = huqie.qie(up["text"][-LEN:]).split(" ")
         tks_all = up["text"][-LEN:].strip() \
-                  + (" " if re.match(r"[a-zA-Z0-9]+",
-                                     up["text"][-1] + down["text"][0]) else "") \
-                  + down["text"][:LEN].strip()
+            + (" " if re.match(r"[a-zA-Z0-9]+",
+                               up["text"][-1] + down["text"][0]) else "") \
+            + down["text"][:LEN].strip()
         tks_all = huqie.qie(tks_all).split(" ")
         fea = [
             up.get("R", -1) == down.get("R", -1),
@@ -119,7 +121,7 @@ class HuParser:
             True if re.search(r"[,,][^。.]+$", up["text"]) else False,
             True if re.search(r"[,,][^。.]+$", up["text"]) else False,
             True if re.search(r"[\((][^\))]+$", up["text"])
-            and re.search(r"[\))]", down["text"]) else False,
+                    and re.search(r"[\))]", down["text"]) else False,
             self._match_proj(down),
             True if re.match(r"[A-Z]", down["text"]) else False,
             True if re.match(r"[A-Z]", up["text"][-1]) else False,
@@ -181,7 +183,7 @@ class HuParser:
                 continue
             for tb in tbls:  # for table
                 left, top, right, bott = tb["x0"] - MARGIN, tb["top"] - MARGIN, \
-                                         tb["x1"] + MARGIN, tb["bottom"] + MARGIN
+                    tb["x1"] + MARGIN, tb["bottom"] + MARGIN
                 left *= ZM
                 top *= ZM
                 right *= ZM
@@ -235,7 +237,8 @@ class HuParser:
                 b["R_top"] = rows[ii]["top"]
                 b["R_bott"] = rows[ii]["bottom"]
 
-            ii = Recognizer.find_overlapped_with_threashold(b, headers, thr=0.3)
+            ii = Recognizer.find_overlapped_with_threashold(
+                b, headers, thr=0.3)
             if ii is not None:
                 b["H_top"] = headers[ii]["top"]
                 b["H_bott"] = headers[ii]["bottom"]
@@ -272,7 +275,8 @@ class HuParser:
         )
 
         # merge chars in the same rect
-        for c in Recognizer.sort_X_firstly(chars, self.mean_width[pagenum - 1] // 4):
+        for c in Recognizer.sort_X_firstly(
+                chars, self.mean_width[pagenum - 1] // 4):
             ii = Recognizer.find_overlapped(c, bxs)
             if ii is None:
                 self.lefted_chars.append(c)
@@ -283,13 +287,15 @@ class HuParser:
                 self.lefted_chars.append(c)
                 continue
             if c["text"] == " " and bxs[ii]["text"]:
-                if re.match(r"[0-9a-zA-Z,.?;:!%%]", bxs[ii]["text"][-1]): bxs[ii]["text"] += " "
+                if re.match(r"[0-9a-zA-Z,.?;:!%%]", bxs[ii]["text"][-1]):
+                    bxs[ii]["text"] += " "
             else:
                 bxs[ii]["text"] += c["text"]
 
         for b in bxs:
             if not b["text"]:
-                left, right, top, bott = b["x0"] * ZM, b["x1"] * ZM, b["top"] * ZM, b["bottom"] * ZM
+                left, right, top, bott = b["x0"] * ZM, b["x1"] * \
+                    ZM, b["top"] * ZM, b["bottom"] * ZM
                 b["text"] = self.ocr.recognize(np.array(img),
                                                np.array([[left, top], [right, top], [right, bott], [left, bott]],
                                                         dtype=np.float32))
@@ -302,7 +308,8 @@ class HuParser:
 
     def _layouts_rec(self, ZM, drop=True):
         assert len(self.page_images) == len(self.boxes)
-        self.boxes, self.page_layout = self.layouter(self.page_images, self.boxes, ZM, drop=drop)
+        self.boxes, self.page_layout = self.layouter(
+            self.page_images, self.boxes, ZM, drop=drop)
         # cumlative Y
         for i in range(len(self.boxes)):
             self.boxes[i]["top"] += \
@@ -332,7 +339,8 @@ class HuParser:
                                  "equation"]:
                 i += 1
                 continue
-            if abs(self._y_dis(b, b_)) < self.mean_height[bxs[i]["page_number"] - 1] / 3:
+            if abs(self._y_dis(b, b_)
+                   ) < self.mean_height[bxs[i]["page_number"] - 1] / 3:
                 # merge
                 bxs[i]["x1"] = b_["x1"]
                 bxs[i]["top"] = (b["top"] + b_["top"]) / 2
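Context for the `_y_dis`, `_updown_concat_features`, and `_text_merge` hunks above: the parser decides whether two vertically adjacent boxes belong together by scoring cheap geometry and punctuation cues. Here is a rough, dependency-free sketch of that decision shape; the feature set and voting threshold are illustrative stand-ins for the XGBoost model (`updown_cnt_mdl`) the real code consults:

```python
import re

# Two adjacent text boxes in the shape HuParser uses (values made up).
up = {"text": "The quick brown fox jumps over", "page_number": 1, "bottom": 100.0}
down = {"text": "the lazy dog.", "page_number": 1, "top": 104.0}


def concat_features(up, down, mean_height=12.0):
    """Cheap cues in the spirit of _updown_concat_features (not the real list)."""
    return [
        up["page_number"] == down["page_number"],          # same page
        (down["top"] - up["bottom"]) < mean_height * 1.5,  # vertically close
        not re.search(r"[。.!?!?]$", up["text"]),          # up is mid-sentence
        bool(re.match(r"[a-z,]", down["text"])),           # down continues it
    ]


def should_concat(feats, thr=0.5):
    # Stand-in for updown_cnt_mdl.predict(); a simple majority vote.
    return sum(map(bool, feats)) / len(feats) >= thr


if should_concat(concat_features(up, down)):
    print(up["text"] + " " + down["text"])  # -> one merged line
```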
@@ -366,12 +374,15 @@ class HuParser:
         self.boxes = bxs
 
     def _naive_vertical_merge(self):
-        bxs = Recognizer.sort_Y_firstly(self.boxes, np.median(self.mean_height) / 3)
+        bxs = Recognizer.sort_Y_firstly(
+            self.boxes, np.median(
+                self.mean_height) / 3)
         i = 0
         while i + 1 < len(bxs):
             b = bxs[i]
             b_ = bxs[i + 1]
-            if b["page_number"] < b_["page_number"] and re.match(r"[0-9  •一—-]+$", b["text"]):
+            if b["page_number"] < b_["page_number"] and re.match(
+                    r"[0-9  •一—-]+$", b["text"]):
                 bxs.pop(i)
                 continue
             if not b["text"].strip():
@@ -379,7 +390,8 @@ class HuParser:
                 continue
             concatting_feats = [
                 b["text"].strip()[-1] in ",;:'\",、”“;:-",
-                len(b["text"].strip()) > 1 and b["text"].strip()[-2] in ",;:'\",”“、;:",
+                len(b["text"].strip()) > 1 and b["text"].strip(
+                )[-2] in ",;:'\",”“、;:",
                 b["text"].strip()[0] in "。;?!?”)),,、:",
             ]
             # features for not concating
             feats = [
                 b.get("layoutno", 0) != b.get("layoutno", 0),
                 b["text"].strip()[-1] in "。?!?",
                 self.is_english and b["text"].strip()[-1] in ".!?",
-                b["page_number"] == b_["page_number"] and b_["top"] - \
+                b["page_number"] == b_["page_number"] and b_["top"] -
                 b["bottom"] > self.mean_height[b["page_number"] - 1] * 1.5,
                 b["page_number"] < b_["page_number"] and abs(
                     b["x0"] - b_["x0"]) > self.mean_width[b["page_number"] - 1] * 4,
@@ -396,7 +408,12 @@ class HuParser:
             detach_feats = [b["x1"] < b_["x0"],
                            b["x0"] > b_["x1"]]
             if (any(feats) and not any(concatting_feats)) or any(detach_feats):
-                print(b["text"], b_["text"], any(feats), any(concatting_feats), any(detach_feats))
+                print(
+                    b["text"],
+                    b_["text"],
+                    any(feats),
+                    any(concatting_feats),
+                    any(detach_feats))
                 i += 1
                 continue
             # merge up and down
@@ -526,31 +543,39 @@ class HuParser:
                 i += 1
                 continue
             findit = True
-            eng = re.match(r"[0-9a-zA-Z :'.-]{5,}", self.boxes[i]["text"].strip())
+            eng = re.match(
+                r"[0-9a-zA-Z :'.-]{5,}",
+                self.boxes[i]["text"].strip())
             self.boxes.pop(i)
-            if i >= len(self.boxes): break
+            if i >= len(self.boxes):
+                break
             prefix = self.boxes[i]["text"].strip()[:3] if not eng else " ".join(
                 self.boxes[i]["text"].strip().split(" ")[:2])
             while not prefix:
                 self.boxes.pop(i)
-                if i >= len(self.boxes): break
+                if i >= len(self.boxes):
+                    break
                 prefix = self.boxes[i]["text"].strip()[:3] if not eng else " ".join(
                     self.boxes[i]["text"].strip().split(" ")[:2])
             self.boxes.pop(i)
-            if i >= len(self.boxes) or not prefix: break
+            if i >= len(self.boxes) or not prefix:
+                break
             for j in range(i, min(i + 128, len(self.boxes))):
                 if not re.match(prefix, self.boxes[j]["text"]):
                     continue
-                for k in range(i, j): self.boxes.pop(i)
+                for k in range(i, j):
+                    self.boxes.pop(i)
                 break
-        if findit: return
+        if findit:
+            return
 
         page_dirty = [0] * len(self.page_images)
         for b in self.boxes:
             if re.search(r"(··|··|··)", b["text"]):
                 page_dirty[b["page_number"] - 1] += 1
         page_dirty = set([i + 1 for i, t in enumerate(page_dirty) if t > 3])
-        if not page_dirty: return
+        if not page_dirty:
+            return
         i = 0
         while i < len(self.boxes):
             if self.boxes[i]["page_number"] in page_dirty:
@@ -582,7 +607,8 @@ class HuParser:
                 b_["top"] = b["top"]
                 self.boxes.pop(i)
 
-    def _extract_table_figure(self, need_image, ZM, return_html, need_position):
+    def _extract_table_figure(self, need_image, ZM,
+                              return_html, need_position):
         tables = {}
         figures = {}
         # extract figure and table boxes
@@ -594,7 +620,7 @@ class HuParser:
                 i += 1
                 continue
             lout_no = str(self.boxes[i]["page_number"]) + \
-                      "-" + str(self.boxes[i]["layoutno"])
+                "-" + str(self.boxes[i]["layoutno"])
             if TableStructureRecognizer.is_caption(self.boxes[i]) or \
                     self.boxes[i]["layout_type"] in ["table caption",
                                                      "title",
                                                      "figure
caption", @@ -761,7 +787,8 @@ class HuParser: for k, bxs in tables.items(): if not bxs: continue - bxs = Recognizer.sort_Y_firstly(bxs, np.mean([(b["bottom"] - b["top"]) / 2 for b in bxs])) + bxs = Recognizer.sort_Y_firstly(bxs, np.mean( + [(b["bottom"] - b["top"]) / 2 for b in bxs])) poss = [] res.append((cropout(bxs, "table", poss), self.tbl_det.construct_table(bxs, html=return_html, is_english=self.is_english))) @@ -769,7 +796,8 @@ class HuParser: assert len(positions) == len(res) - if need_position: return list(zip(res, positions)) + if need_position: + return list(zip(res, positions)) return res def proj_match(self, line): @@ -873,7 +901,8 @@ class HuParser: boxes.pop(0) mw = np.mean(widths) if mj or mw / pw >= 0.35 or mw > 200: - res.append("\n".join([c["text"] + self._line_tag(c, ZM) for c in lines])) + res.append( + "\n".join([c["text"] + self._line_tag(c, ZM) for c in lines])) else: logging.debug("REMOVED: " + "<<".join([c["text"] for c in lines])) @@ -883,13 +912,16 @@ class HuParser: @staticmethod def total_page_number(fnm, binary=None): try: - pdf = pdfplumber.open(fnm) if not binary else pdfplumber.open(BytesIO(binary)) + pdf = pdfplumber.open( + fnm) if not binary else pdfplumber.open(BytesIO(binary)) return len(pdf.pages) except Exception as e: - pdf = fitz.open(fnm) if not binary else fitz.open(stream=fnm, filetype="pdf") + pdf = fitz.open(fnm) if not binary else fitz.open( + stream=fnm, filetype="pdf") return len(pdf) - def __images__(self, fnm, zoomin=3, page_from=0, page_to=299, callback=None): + def __images__(self, fnm, zoomin=3, page_from=0, + page_to=299, callback=None): self.lefted_chars = [] self.mean_height = [] self.mean_width = [] @@ -899,21 +931,26 @@ class HuParser: self.page_layout = [] self.page_from = page_from try: - self.pdf = pdfplumber.open(fnm) if isinstance(fnm, str) else pdfplumber.open(BytesIO(fnm)) + self.pdf = pdfplumber.open(fnm) if isinstance( + fnm, str) else pdfplumber.open(BytesIO(fnm)) self.page_images = [p.to_image(resolution=72 * zoomin).annotated for i, p in enumerate(self.pdf.pages[page_from:page_to])] self.page_chars = [[c for c in page.chars if self._has_color(c)] for page in self.pdf.pages[page_from:page_to]] self.total_page = len(self.pdf.pages) except Exception as e: - self.pdf = fitz.open(fnm) if isinstance(fnm, str) else fitz.open(stream=fnm, filetype="pdf") + self.pdf = fitz.open(fnm) if isinstance( + fnm, str) else fitz.open( + stream=fnm, filetype="pdf") self.page_images = [] self.page_chars = [] mat = fitz.Matrix(zoomin, zoomin) self.total_page = len(self.pdf) for i, page in enumerate(self.pdf): - if i < page_from: continue - if i >= page_to: break + if i < page_from: + continue + if i >= page_to: + break pix = page.get_pixmap(matrix=mat) img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples) @@ -930,7 +967,7 @@ class HuParser: if isinstance(a, dict): self.outlines.append((a["/Title"], depth)) continue - dfs(a, depth+1) + dfs(a, depth + 1) dfs(outlines, 0) except Exception as e: logging.warning(f"Outlines exception: {e}") @@ -940,8 +977,9 @@ class HuParser: logging.info("Images converted.") self.is_english = [re.search(r"[a-zA-Z0-9,/¸;:'\[\]\(\)!@#$%^&*\"?<>._-]{30,}", "".join( random.choices([c["text"] for c in self.page_chars[i]], k=min(100, len(self.page_chars[i]))))) for i in - range(len(self.page_chars))] - if sum([1 if e else 0 for e in self.is_english]) > len(self.page_images) / 2: + range(len(self.page_chars))] + if sum([1 if e else 0 for e in self.is_english]) > len( + self.page_images) / 2: 
self.is_english = True else: self.is_english = False @@ -970,9 +1008,11 @@ class HuParser: # self.page_cum_height.append( # np.max([c["bottom"] for c in chars])) self.__ocr(i + 1, img, chars, zoomin) - if callback: callback(prog=(i + 1) * 0.6 / len(self.page_images), msg="") + if callback: + callback(prog=(i + 1) * 0.6 / len(self.page_images), msg="") - if not self.is_english and not any([c for c in self.page_chars]) and self.boxes: + if not self.is_english and not any( + [c for c in self.page_chars]) and self.boxes: bxes = [b for bxs in self.boxes for b in bxs] self.is_english = re.search(r"[\na-zA-Z0-9,/¸;:'\[\]\(\)!@#$%^&*\"?<>._-]{30,}", "".join([b["text"] for b in random.choices(bxes, k=min(30, len(bxes)))])) @@ -989,7 +1029,8 @@ class HuParser: self._text_merge() self._concat_downward() self._filter_forpages() - tbls = self._extract_table_figure(need_image, zoomin, return_html, False) + tbls = self._extract_table_figure( + need_image, zoomin, return_html, False) return self.__filterout_scraps(deepcopy(self.boxes), zoomin), tbls def remove_tag(self, txt): @@ -1003,15 +1044,19 @@ class HuParser: "#").strip("@").split("\t") left, right, top, bottom = float(left), float( right), float(top), float(bottom) - poss.append(([int(p) - 1 for p in pn.split("-")], left, right, top, bottom)) + poss.append(([int(p) - 1 for p in pn.split("-")], + left, right, top, bottom)) if not poss: - if need_position: return None, None + if need_position: + return None, None return - max_width = max(np.max([right - left for (_, left, right, _, _) in poss]), 6) + max_width = max( + np.max([right - left for (_, left, right, _, _) in poss]), 6) GAP = 6 pos = poss[0] - poss.insert(0, ([pos[0][0]], pos[1], pos[2], max(0, pos[3] - 120), max(pos[3] - GAP, 0))) + poss.insert(0, ([pos[0][0]], pos[1], pos[2], max( + 0, pos[3] - 120), max(pos[3] - GAP, 0))) pos = poss[-1] poss.append(([pos[0][-1]], pos[1], pos[2], min(self.page_images[pos[0][-1]].size[1] / ZM, pos[4] + GAP), min(self.page_images[pos[0][-1]].size[1] / ZM, pos[4] + 120))) @@ -1026,7 +1071,7 @@ class HuParser: self.page_images[pns[0]].crop((left * ZM, top * ZM, right * ZM, min( - bottom, self.page_images[pns[0]].size[1]) + bottom, self.page_images[pns[0]].size[1]) )) ) if 0 < ii < len(poss) - 1: @@ -1047,7 +1092,8 @@ class HuParser: bottom -= self.page_images[pn].size[1] if not imgs: - if need_position: return None, None + if need_position: + return None, None return height = 0 for img in imgs: @@ -1076,12 +1122,14 @@ class HuParser: pn = bx["page_number"] top = bx["top"] - self.page_cum_height[pn - 1] bott = bx["bottom"] - self.page_cum_height[pn - 1] - poss.append((pn, bx["x0"], bx["x1"], top, min(bott, self.page_images[pn - 1].size[1] / ZM))) + poss.append((pn, bx["x0"], bx["x1"], top, min( + bott, self.page_images[pn - 1].size[1] / ZM))) while bott * ZM > self.page_images[pn - 1].size[1]: bott -= self.page_images[pn - 1].size[1] / ZM top = 0 pn += 1 - poss.append((pn, bx["x0"], bx["x1"], top, min(bott, self.page_images[pn - 1].size[1] / ZM))) + poss.append((pn, bx["x0"], bx["x1"], top, min( + bott, self.page_images[pn - 1].size[1] / ZM))) return poss @@ -1090,11 +1138,14 @@ class PlainParser(object): self.outlines = [] lines = [] try: - self.pdf = pdf2_read(filename if isinstance(filename, str) else BytesIO(filename)) + self.pdf = pdf2_read( + filename if isinstance( + filename, str) else BytesIO(filename)) for page in self.pdf.pages[from_page:to_page]: lines.extend([t for t in page.extract_text().split("\n")]) outlines = self.pdf.outline + def dfs(arr, 
depth): for a in arr: if isinstance(a, dict): @@ -1117,5 +1168,6 @@ class PlainParser(object): def remove_tag(txt): raise NotImplementedError + if __name__ == "__main__": pass diff --git a/deepdoc/parser/ppt_parser.py b/deepdoc/parser/ppt_parser.py index 899103a..7266112 100644 --- a/deepdoc/parser/ppt_parser.py +++ b/deepdoc/parser/ppt_parser.py @@ -23,7 +23,8 @@ class HuPptParser(object): tb = shape.table rows = [] for i in range(1, len(tb.rows)): - rows.append("; ".join([tb.cell(0, j).text + ": " + tb.cell(i, j).text for j in range(len(tb.columns)) if tb.cell(i, j)])) + rows.append("; ".join([tb.cell( + 0, j).text + ": " + tb.cell(i, j).text for j in range(len(tb.columns)) if tb.cell(i, j)])) return "\n".join(rows) if shape.has_text_frame: @@ -31,9 +32,10 @@ class HuPptParser(object): if shape.shape_type == 6: texts = [] - for p in sorted(shape.shapes, key=lambda x: (x.top//10, x.left)): + for p in sorted(shape.shapes, key=lambda x: (x.top // 10, x.left)): t = self.__extract(p) - if t: texts.append(t) + if t: + texts.append(t) return "\n".join(texts) def __call__(self, fnm, from_page, to_page, callback=None): @@ -43,12 +45,16 @@ class HuPptParser(object): txts = [] self.total_page = len(ppt.slides) for i, slide in enumerate(ppt.slides): - if i < from_page: continue - if i >= to_page:break + if i < from_page: + continue + if i >= to_page: + break texts = [] - for shape in sorted(slide.shapes, key=lambda x: (x.top//10, x.left)): + for shape in sorted( + slide.shapes, key=lambda x: (x.top // 10, x.left)): txt = self.__extract(shape) - if txt: texts.append(txt) + if txt: + texts.append(txt) txts.append("\n".join(texts)) return txts diff --git a/deepdoc/vision/layout_recognizer.py b/deepdoc/vision/layout_recognizer.py index e107e07..7b87622 100644 --- a/deepdoc/vision/layout_recognizer.py +++ b/deepdoc/vision/layout_recognizer.py @@ -24,18 +24,19 @@ from deepdoc.vision import Recognizer class LayoutRecognizer(Recognizer): labels = [ - "_background_", - "Text", - "Title", - "Figure", - "Figure caption", - "Table", - "Table caption", - "Header", - "Footer", - "Reference", - "Equation", - ] + "_background_", + "Text", + "Title", + "Figure", + "Figure caption", + "Table", + "Table caption", + "Header", + "Footer", + "Reference", + "Equation", + ] + def __init__(self, domain): try: model_dir = snapshot_download( @@ -47,10 +48,12 @@ class LayoutRecognizer(Recognizer): except Exception as e: model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc") - super().__init__(self.labels, domain, model_dir)#os.path.join(get_project_base_directory(), "rag/res/deepdoc/")) + # os.path.join(get_project_base_directory(), "rag/res/deepdoc/")) + super().__init__(self.labels, domain, model_dir) self.garbage_layouts = ["footer", "header", "reference"] - def __call__(self, image_list, ocr_res, scale_factor=3, thr=0.2, batch_size=16, drop=True): + def __call__(self, image_list, ocr_res, scale_factor=3, + thr=0.2, batch_size=16, drop=True): def __is_garbage(b): patt = [r"^•+$", r"(版权归©|免责条款|地址[::])", r"\.{3,}", "^[0-9]{1,2} / ?[0-9]{1,2}$", r"^[0-9]{1,2} of [0-9]{1,2}$", "^http://[^ ]{12,}", @@ -75,7 +78,8 @@ class LayoutRecognizer(Recognizer): "top": b["bbox"][1] / scale_factor, "bottom": b["bbox"][-1] / scale_factor, "page_number": pn, } for b in lts] - lts = self.sort_Y_firstly(lts, np.mean([l["bottom"]-l["top"] for l in lts]) / 2) + lts = self.sort_Y_firstly(lts, np.mean( + [l["bottom"] - l["top"] for l in lts]) / 2) lts = self.layouts_cleanup(bxs, lts) page_layout.append(lts) @@ -93,17 +97,20 @@ class 
LayoutRecognizer(Recognizer): continue ii = self.find_overlapped_with_threashold(bxs[i], lts_, - thr=0.4) + thr=0.4) if ii is None: # belong to nothing bxs[i]["layout_type"] = "" i += 1 continue lts_[ii]["visited"] = True keep_feats = [ - lts_[ii]["type"] == "footer" and bxs[i]["bottom"] < image_list[pn].size[1]*0.9/scale_factor, - lts_[ii]["type"] == "header" and bxs[i]["top"] > image_list[pn].size[1]*0.1/scale_factor, + lts_[ + ii]["type"] == "footer" and bxs[i]["bottom"] < image_list[pn].size[1] * 0.9 / scale_factor, + lts_[ + ii]["type"] == "header" and bxs[i]["top"] > image_list[pn].size[1] * 0.1 / scale_factor, ] - if drop and lts_[ii]["type"] in self.garbage_layouts and not any(keep_feats): + if drop and lts_[ + ii]["type"] in self.garbage_layouts and not any(keep_feats): if lts_[ii]["type"] not in garbages: garbages[lts_[ii]["type"]] = [] garbages[lts_[ii]["type"]].append(bxs[i]["text"]) @@ -111,7 +118,8 @@ class LayoutRecognizer(Recognizer): continue bxs[i]["layoutno"] = f"{ty}-{ii}" - bxs[i]["layout_type"] = lts_[ii]["type"] if lts_[ii]["type"]!="equation" else "figure" + bxs[i]["layout_type"] = lts_[ii]["type"] if lts_[ + ii]["type"] != "equation" else "figure" i += 1 for lt in ["footer", "header", "reference", "figure caption", @@ -120,7 +128,7 @@ class LayoutRecognizer(Recognizer): # add box to figure layouts which has not text box for i, lt in enumerate( - [lt for lt in lts if lt["type"] in ["figure","equation"]]): + [lt for lt in lts if lt["type"] in ["figure", "equation"]]): if lt.get("visited"): continue lt = deepcopy(lt) @@ -143,6 +151,3 @@ class LayoutRecognizer(Recognizer): ocr_res = [b for b in ocr_res if b["text"].strip() not in garbag_set] return ocr_res, page_layout - - - diff --git a/deepdoc/vision/operators.py b/deepdoc/vision/operators.py index a4ef57c..382fe36 100644 --- a/deepdoc/vision/operators.py +++ b/deepdoc/vision/operators.py @@ -63,6 +63,7 @@ class DecodeImage(object): data['image'] = img return data + class StandardizeImage(object): """normalize image Args: @@ -707,4 +708,4 @@ def preprocess(im, preprocess_ops): im, im_info = decode_image(im, im_info) for operator in preprocess_ops: im, im_info = operator(im, im_info) - return im, im_info \ No newline at end of file + return im, im_info diff --git a/deepdoc/vision/t_ocr.py b/deepdoc/vision/t_ocr.py index 79b3cdb..d30f3c2 100644 --- a/deepdoc/vision/t_ocr.py +++ b/deepdoc/vision/t_ocr.py @@ -11,12 +11,20 @@ # limitations under the License. # -import os, sys -sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), '../../'))) -import numpy as np -import argparse -from deepdoc.vision import OCR, init_in_out from deepdoc.vision.seeit import draw_box +from deepdoc.vision import OCR, init_in_out +import argparse +import numpy as np +import os +import sys +sys.path.insert( + 0, + os.path.abspath( + os.path.join( + os.path.dirname( + os.path.abspath(__file__)), + '../../'))) + def main(args): ocr = OCR() @@ -26,14 +34,14 @@ def main(args): bxs = ocr(np.array(img)) bxs = [(line[0], line[1][0]) for line in bxs] bxs = [{ - "text": t, - "bbox": [b[0][0], b[0][1], b[1][0], b[-1][1]], - "type": "ocr", - "score": 1} for b, t in bxs if b[0][0] <= b[1][0] and b[0][1] <= b[-1][1]] + "text": t, + "bbox": [b[0][0], b[0][1], b[1][0], b[-1][1]], + "type": "ocr", + "score": 1} for b, t in bxs if b[0][0] <= b[1][0] and b[0][1] <= b[-1][1]] img = draw_box(images[i], bxs, ["ocr"], 1.) 
img.save(outputs[i], quality=95) - with open(outputs[i] + ".txt", "w+") as f: f.write("\n".join([o["text"] for o in bxs])) - + with open(outputs[i] + ".txt", "w+") as f: + f.write("\n".join([o["text"] for o in bxs])) if __name__ == "__main__": @@ -42,6 +50,6 @@ if __name__ == "__main__": help="Directory where to store images or PDFs, or a file path to a single image or PDF", required=True) parser.add_argument('--output_dir', help="Directory where to store the output images. Default: './ocr_outputs'", - default="./ocr_outputs") + default="./ocr_outputs") args = parser.parse_args() - main(args) \ No newline at end of file + main(args) diff --git a/deepdoc/vision/t_recognizer.py b/deepdoc/vision/t_recognizer.py index 23033e2..a04afa4 100644 --- a/deepdoc/vision/t_recognizer.py +++ b/deepdoc/vision/t_recognizer.py @@ -11,24 +11,35 @@ # limitations under the License. # -import os, sys +from deepdoc.vision.seeit import draw_box +from deepdoc.vision import Recognizer, LayoutRecognizer, TableStructureRecognizer, OCR, init_in_out +from api.utils.file_utils import get_project_base_directory +import argparse +import os +import sys import re import numpy as np -sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), '../../'))) - -import argparse -from api.utils.file_utils import get_project_base_directory -from deepdoc.vision import Recognizer, LayoutRecognizer, TableStructureRecognizer, OCR, init_in_out -from deepdoc.vision.seeit import draw_box +sys.path.insert( + 0, + os.path.abspath( + os.path.join( + os.path.dirname( + os.path.abspath(__file__)), + '../../'))) def main(args): images, outputs = init_in_out(args) if args.mode.lower() == "layout": labels = LayoutRecognizer.labels - detr = Recognizer(labels, "layout", os.path.join(get_project_base_directory(), "rag/res/deepdoc/")) + detr = Recognizer( + labels, + "layout", + os.path.join( + get_project_base_directory(), + "rag/res/deepdoc/")) if args.mode.lower() == "tsr": labels = TableStructureRecognizer.labels detr = TableStructureRecognizer() @@ -39,7 +50,8 @@ def main(args): if args.mode.lower() == "tsr": #lyt = [t for t in lyt if t["type"] == "table column"] html = get_table_html(images[i], lyt, ocr) - with open(outputs[i]+".html", "w+") as f: f.write(html) + with open(outputs[i] + ".html", "w+") as f: + f.write(html) lyt = [{ "type": t["label"], "bbox": [t["x0"], t["top"], t["x1"], t["bottom"]], @@ -58,7 +70,7 @@ def get_table_html(img, tb_cpns, ocr): "bottom": b[-1][1], "layout_type": "table", "page_number": 0} for b, t in boxes if b[0][0] <= b[1][0] and b[0][1] <= b[-1][1]], - np.mean([b[-1][1]-b[0][1] for b,_ in boxes]) / 3 + np.mean([b[-1][1] - b[0][1] for b, _ in boxes]) / 3 ) def gather(kwd, fzy=10, ption=0.6): @@ -117,7 +129,7 @@ def get_table_html(img, tb_cpns, ocr): margin-bottom: 50px; border: 1px solid #e1e1e1; } - + caption { color: #6ac1ca; font-size: 20px; @@ -126,25 +138,25 @@ def get_table_html(img, tb_cpns, ocr): font-weight: 600; margin-bottom: 10px; } - + ._table_1nkzy_11 table { width: 100%%; border-collapse: collapse; } - + th { color: #fff; background-color: #6ac1ca; } - + td:hover { background: #c1e8e8; } - + tr:nth-child(even) { background-color: #f2f2f2; } - + ._table_1nkzy_11 th, ._table_1nkzy_11 td { text-align: center; @@ -157,7 +169,7 @@ def get_table_html(img, tb_cpns, ocr): %s </body> </html> -"""% TableStructureRecognizer.construct_table(boxes, html=True) +""" % TableStructureRecognizer.construct_table(boxes, html=True) return html @@ -168,7 +180,10 @@ if __name__ == 
"__main__": required=True) parser.add_argument('--output_dir', help="Directory where to store the output images. Default: './layouts_outputs'", default="./layouts_outputs") - parser.add_argument('--threshold', help="A threshold to filter out detections. Default: 0.5", default=0.5) + parser.add_argument( + '--threshold', + help="A threshold to filter out detections. Default: 0.5", + default=0.5) parser.add_argument('--mode', help="Task mode: layout recognition or table structure recognition", choices=["layout", "tsr"], default="layout") args = parser.parse_args() diff --git a/deepdoc/vision/table_structure_recognizer.py b/deepdoc/vision/table_structure_recognizer.py index 022d558..ebd57a6 100644 --- a/deepdoc/vision/table_structure_recognizer.py +++ b/deepdoc/vision/table_structure_recognizer.py @@ -44,7 +44,8 @@ class TableStructureRecognizer(Recognizer): except Exception as e: model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc") - super().__init__(self.labels, "tsr", model_dir)#os.path.join(get_project_base_directory(), "rag/res/deepdoc/")) + # os.path.join(get_project_base_directory(), "rag/res/deepdoc/")) + super().__init__(self.labels, "tsr", model_dir) def __call__(self, images, thr=0.2): tbls = super().__call__(images, thr) @@ -138,7 +139,8 @@ class TableStructureRecognizer(Recognizer): i = 0 while i < len(boxes): if TableStructureRecognizer.is_caption(boxes[i]): - if is_english: cap + " " + if is_english: + cap + " " cap += boxes[i]["text"] boxes.pop(i) i -= 1 @@ -164,7 +166,7 @@ class TableStructureRecognizer(Recognizer): lst_r = rows[-1] if lst_r[-1].get("R", "") != b.get("R", "") \ or (b["top"] >= btm - 3 and lst_r[-1].get("R", "-1") != b.get("R", "-2") - ): # new row + ): # new row btm = b["bottom"] b["rn"] += 1 rows.append([b]) @@ -214,9 +216,9 @@ class TableStructureRecognizer(Recognizer): j += 1 continue f = (j > 0 and tbl[ii][j - 1] and tbl[ii] - [j - 1][0].get("text")) or j == 0 + [j - 1][0].get("text")) or j == 0 ff = (j + 1 < len(tbl[ii]) and tbl[ii][j + 1] and tbl[ii] - [j + 1][0].get("text")) or j + 1 >= len(tbl[ii]) + [j + 1][0].get("text")) or j + 1 >= len(tbl[ii]) if f and ff: j += 1 continue @@ -277,9 +279,9 @@ class TableStructureRecognizer(Recognizer): i += 1 continue f = (i > 0 and tbl[i - 1][jj] and tbl[i - 1] - [jj][0].get("text")) or i == 0 + [jj][0].get("text")) or i == 0 ff = (i + 1 < len(tbl) and tbl[i + 1][jj] and tbl[i + 1] - [jj][0].get("text")) or i + 1 >= len(tbl) + [jj][0].get("text")) or i + 1 >= len(tbl) if f and ff: i += 1 continue @@ -366,7 +368,8 @@ class TableStructureRecognizer(Recognizer): continue txt = "" if arr: - h = min(np.min([c["bottom"] - c["top"] for c in arr]) / 2, 10) + h = min(np.min([c["bottom"] - c["top"] + for c in arr]) / 2, 10) txt = " ".join([c["text"] for c in Recognizer.sort_Y_firstly(arr, h)]) txts.append(txt) @@ -438,8 +441,8 @@ class TableStructureRecognizer(Recognizer): else "") + headers[j - 1][k] else: headers[j][k] = headers[j - 1][k] \ - + (de if headers[j - 1][k] else "") \ - + headers[j][k] + + (de if headers[j - 1][k] else "") \ + + headers[j][k] logging.debug( f">>>>>>>>>>>>>>>>>{cap}:SIZE:{rowno}X{clmno} Header: {hdr_rowno}") diff --git a/rag/app/book.py b/rag/app/book.py index 6f51a95..2ea850a 100644 --- a/rag/app/book.py +++ b/rag/app/book.py @@ -48,10 +48,12 @@ class Pdf(PdfParser): callback(0.8, "Text extraction finished") - return [(b["text"] + self._line_tag(b, zoomin), b.get("layoutno","")) for b in self.boxes], tbls + return [(b["text"] + self._line_tag(b, zoomin), b.get("layoutno", "")) + for b 
in self.boxes], tbls -def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs): +def chunk(filename, binary=None, from_page=0, to_page=100000, + lang="Chinese", callback=None, **kwargs): """ Supported file formats are docx, pdf, txt. Since a book is long and not all the parts are useful, if it's a PDF, @@ -63,48 +65,63 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca } doc["title_sm_tks"] = huqie.qieqie(doc["title_tks"]) pdf_parser = None - sections,tbls = [], [] + sections, tbls = [], [] if re.search(r"\.docx?$", filename, re.IGNORECASE): callback(0.1, "Start to parse.") doc_parser = DocxParser() # TODO: table of contents need to be removed - sections, tbls = doc_parser(binary if binary else filename, from_page=from_page, to_page=to_page) - remove_contents_table(sections, eng=is_english(random_choices([t for t,_ in sections], k=200))) + sections, tbls = doc_parser( + binary if binary else filename, from_page=from_page, to_page=to_page) + remove_contents_table(sections, eng=is_english( + random_choices([t for t, _ in sections], k=200))) callback(0.8, "Finish parsing.") elif re.search(r"\.pdf$", filename, re.IGNORECASE): - pdf_parser = Pdf() if kwargs.get("parser_config",{}).get("layout_recognize", True) else PlainParser() + pdf_parser = Pdf() if kwargs.get( + "parser_config", {}).get( + "layout_recognize", True) else PlainParser() sections, tbls = pdf_parser(filename if not binary else binary, - from_page=from_page, to_page=to_page, callback=callback) + from_page=from_page, to_page=to_page, callback=callback) elif re.search(r"\.txt$", filename, re.IGNORECASE): callback(0.1, "Start to parse.") txt = "" - if binary:txt = binary.decode("utf-8") + if binary: + txt = binary.decode("utf-8") else: with open(filename, "r") as f: while True: l = f.readline() - if not l:break + if not l: + break txt += l sections = txt.split("\n") - sections = [(l,"") for l in sections if l] - remove_contents_table(sections, eng = is_english(random_choices([t for t,_ in sections], k=200))) + sections = [(l, "") for l in sections if l] + remove_contents_table(sections, eng=is_english( + random_choices([t for t, _ in sections], k=200))) callback(0.8, "Finish parsing.") - else: raise NotImplementedError("file type not supported yet(docx, pdf, txt supported)") + else: + raise NotImplementedError( + "file type not supported yet(docx, pdf, txt supported)") make_colon_as_title(sections) - bull = bullets_category([t for t in random_choices([t for t,_ in sections], k=100)]) + bull = bullets_category( + [t for t in random_choices([t for t, _ in sections], k=100)]) if bull >= 0: - chunks = ["\n".join(ck) for ck in hierarchical_merge(bull, sections, 3)] + chunks = ["\n".join(ck) + for ck in hierarchical_merge(bull, sections, 3)] else: - sections = [s.split("@") for s,_ in sections] - sections = [(pr[0], "@"+pr[1]) for pr in sections if len(pr)==2] - chunks = naive_merge(sections, kwargs.get("chunk_token_num", 256), kwargs.get("delimer", "\n。;!?")) + sections = [s.split("@") for s, _ in sections] + sections = [(pr[0], "@" + pr[1]) for pr in sections if len(pr) == 2] + chunks = naive_merge( + sections, kwargs.get( + "chunk_token_num", 256), kwargs.get( + "delimer", "\n。;!?")) # is it English - eng = lang.lower() == "english"#is_english(random_choices([t for t, _ in sections], k=218)) + # is_english(random_choices([t for t, _ in sections], k=218)) + eng = lang.lower() == "english" res = tokenize_table(tbls, doc, eng) 
res.extend(tokenize_chunks(chunks, doc, eng, pdf_parser)) @@ -114,6 +131,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca if __name__ == "__main__": import sys + def dummy(prog=None, msg=""): pass chunk(sys.argv[1], from_page=1, to_page=10, callback=dummy) diff --git a/rag/app/laws.py b/rag/app/laws.py index 94a1e7a..1c99479 100644 --- a/rag/app/laws.py +++ b/rag/app/laws.py @@ -35,8 +35,10 @@ class Docx(DocxParser): pn = 0 lines = [] for p in self.doc.paragraphs: - if pn > to_page:break - if from_page <= pn < to_page and p.text.strip(): lines.append(self.__clean(p.text)) + if pn > to_page: + break + if from_page <= pn < to_page and p.text.strip(): + lines.append(self.__clean(p.text)) for run in p.runs: if 'lastRenderedPageBreak' in run._element.xml: pn += 1 @@ -63,15 +65,18 @@ class Pdf(PdfParser): start = timer() self._layouts_rec(zoomin) callback(0.67, "Layout analysis finished") - cron_logger.info("paddle layouts:".format((timer()-start)/(self.total_page+0.1))) + cron_logger.info("paddle layouts:".format( + (timer() - start) / (self.total_page + 0.1))) self._naive_vertical_merge() callback(0.8, "Text extraction finished") - return [(b["text"], self._line_tag(b, zoomin)) for b in self.boxes], None + return [(b["text"], self._line_tag(b, zoomin)) + for b in self.boxes], None -def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs): +def chunk(filename, binary=None, from_page=0, to_page=100000, + lang="Chinese", callback=None, **kwargs): """ Supported file formats are docx, pdf, txt. """ @@ -89,41 +94,50 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca callback(0.8, "Finish parsing.") elif re.search(r"\.pdf$", filename, re.IGNORECASE): - pdf_parser = Pdf() if kwargs.get("parser_config",{}).get("layout_recognize", True) else PlainParser() - for txt, poss in pdf_parser(filename if not binary else binary, - from_page=from_page, to_page=to_page, callback=callback)[0]: - sections.append(txt + poss) + pdf_parser = Pdf() if kwargs.get( + "parser_config", {}).get( + "layout_recognize", True) else PlainParser() + for txt, poss in pdf_parser(filename if not binary else binary, + from_page=from_page, to_page=to_page, callback=callback)[0]: + sections.append(txt + poss) elif re.search(r"\.txt$", filename, re.IGNORECASE): callback(0.1, "Start to parse.") txt = "" - if binary:txt = binary.decode("utf-8") + if binary: + txt = binary.decode("utf-8") else: with open(filename, "r") as f: while True: l = f.readline() - if not l:break + if not l: + break txt += l sections = txt.split("\n") sections = [l for l in sections if l] callback(0.8, "Finish parsing.") - else: raise NotImplementedError("file type not supported yet(docx, pdf, txt supported)") + else: + raise NotImplementedError( + "file type not supported yet(docx, pdf, txt supported)") # is it English - eng = lang.lower() == "english"#is_english(sections) + eng = lang.lower() == "english" # is_english(sections) # Remove 'Contents' part remove_contents_table(sections, eng) make_colon_as_title(sections) bull = bullets_category(sections) chunks = hierarchical_merge(bull, sections, 3) - if not chunks: callback(0.99, "No chunk parsed out.") + if not chunks: + callback(0.99, "No chunk parsed out.") - return tokenize_chunks(["\n".join(ck) for ck in chunks], doc, eng, pdf_parser) + return tokenize_chunks(["\n".join(ck) + for ck in chunks], doc, eng, pdf_parser) if __name__ == "__main__": import sys + def dummy(prog=None, msg=""): pass 
chunk(sys.argv[1], callback=dummy) diff --git a/rag/app/manual.py b/rag/app/manual.py index d829e3a..234ff9c 100644 --- a/rag/app/manual.py +++ b/rag/app/manual.py @@ -25,10 +25,10 @@ class Pdf(PdfParser): callback ) callback(msg="OCR finished.") - #for bb in self.boxes: + # for bb in self.boxes: # for b in bb: # print(b) - print("OCR:", timer()-start) + print("OCR:", timer() - start) self._layouts_rec(zoomin) callback(0.65, "Layout analysis finished.") @@ -45,30 +45,35 @@ class Pdf(PdfParser): for b in self.boxes: b["text"] = re.sub(r"([\t  ]|\u3000){2,}", " ", b["text"].strip()) - return [(b["text"], b.get("layout_no", ""), self.get_position(b, zoomin)) for i, b in enumerate(self.boxes)], tbls + return [(b["text"], b.get("layout_no", ""), self.get_position(b, zoomin)) + for i, b in enumerate(self.boxes)], tbls - -def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs): +def chunk(filename, binary=None, from_page=0, to_page=100000, + lang="Chinese", callback=None, **kwargs): """ Only pdf is supported. """ pdf_parser = None if re.search(r"\.pdf$", filename, re.IGNORECASE): - pdf_parser = Pdf() if kwargs.get("parser_config",{}).get("layout_recognize", True) else PlainParser() + pdf_parser = Pdf() if kwargs.get( + "parser_config", {}).get( + "layout_recognize", True) else PlainParser() sections, tbls = pdf_parser(filename if not binary else binary, - from_page=from_page, to_page=to_page, callback=callback) - if sections and len(sections[0])<3: sections = [(t, l, [[0]*5]) for t, l in sections] + from_page=from_page, to_page=to_page, callback=callback) + if sections and len(sections[0]) < 3: + sections = [(t, l, [[0] * 5]) for t, l in sections] - else: raise NotImplementedError("file type not supported yet(pdf supported)") + else: + raise NotImplementedError("file type not supported yet(pdf supported)") doc = { "docnm_kwd": filename } doc["title_tks"] = huqie.qie(re.sub(r"\.[a-zA-Z]+$", "", doc["docnm_kwd"])) doc["title_sm_tks"] = huqie.qieqie(doc["title_tks"]) # is it English - eng = lang.lower() == "english"#pdf_parser.is_english + eng = lang.lower() == "english" # pdf_parser.is_english # set pivot using the most frequent type of title, # then merge between 2 pivot @@ -79,7 +84,8 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca for txt, _, _ in sections: for t, lvl in pdf_parser.outlines: tks = set([t[i] + t[i + 1] for i in range(len(t) - 1)]) - tks_ = set([txt[i] + txt[i + 1] for i in range(min(len(t), len(txt) - 1))]) + tks_ = set([txt[i] + txt[i + 1] + for i in range(min(len(t), len(txt) - 1))]) if len(set(tks & tks_)) / max([len(tks), len(tks_), 1]) > 0.8: levels.append(lvl) break @@ -87,24 +93,27 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca levels.append(max_lvl + 1) else: - bull = bullets_category([txt for txt,_,_ in sections]) - most_level, levels = title_frequency(bull, [(txt, l) for txt, l, poss in sections]) + bull = bullets_category([txt for txt, _, _ in sections]) + most_level, levels = title_frequency( + bull, [(txt, l) for txt, l, poss in sections]) assert len(sections) == len(levels) sec_ids = [] sid = 0 for i, lvl in enumerate(levels): - if lvl <= most_level and i > 0 and lvl != levels[i - 1]: sid += 1 + if lvl <= most_level and i > 0 and lvl != levels[i - 1]: + sid += 1 sec_ids.append(sid) # print(lvl, self.boxes[i]["text"], most_level, sid) - sections = [(txt, sec_ids[i], poss) for i, (txt, _, poss) in enumerate(sections)] + sections = [(txt, sec_ids[i], 
poss) + for i, (txt, _, poss) in enumerate(sections)] for (img, rows), poss in tbls: sections.append((rows if isinstance(rows, str) else rows[0], -1, [(p[0] + 1 - from_page, p[1], p[2], p[3], p[4]) for p in poss])) def tag(pn, left, right, top, bottom): - if pn+left+right+top+bottom == 0: + if pn + left + right + top + bottom == 0: return "" return "@@{}\t{:.1f}\t{:.1f}\t{:.1f}\t{:.1f}##" \ .format(pn, left, right, top, bottom) @@ -112,7 +121,8 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca chunks = [] last_sid = -2 tk_cnt = 0 - for txt, sec_id, poss in sorted(sections, key=lambda x: (x[-1][0][0], x[-1][0][3], x[-1][0][1])): + for txt, sec_id, poss in sorted(sections, key=lambda x: ( + x[-1][0][0], x[-1][0][3], x[-1][0][1])): poss = "\t".join([tag(*pos) for pos in poss]) if tk_cnt < 2048 and (sec_id == last_sid or sec_id == -1): if chunks: @@ -121,16 +131,17 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca continue chunks.append(txt + poss) tk_cnt = num_tokens_from_string(txt) - if sec_id > -1: last_sid = sec_id + if sec_id > -1: + last_sid = sec_id res = tokenize_table(tbls, doc, eng) res.extend(tokenize_chunks(chunks, doc, eng, pdf_parser)) return res - if __name__ == "__main__": import sys + def dummy(prog=None, msg=""): pass chunk(sys.argv[1], callback=dummy) diff --git a/rag/app/naive.py b/rag/app/naive.py index a92f2e3..6dad8a2 100644 --- a/rag/app/naive.py +++ b/rag/app/naive.py @@ -44,11 +44,14 @@ class Pdf(PdfParser): tbls = self._extract_table_figure(True, zoomin, True, True) self._naive_vertical_merge() - cron_logger.info("paddle layouts:".format((timer() - start) / (self.total_page + 0.1))) - return [(b["text"], self._line_tag(b, zoomin)) for b in self.boxes], tbls + cron_logger.info("paddle layouts:".format( + (timer() - start) / (self.total_page + 0.1))) + return [(b["text"], self._line_tag(b, zoomin)) + for b in self.boxes], tbls -def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs): +def chunk(filename, binary=None, from_page=0, to_page=100000, + lang="Chinese", callback=None, **kwargs): """ Supported file formats are docx, pdf, excel, txt. This method apply the naive ways to chunk files. @@ -56,8 +59,10 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca Next, these successive pieces are merge into chunks whose token number is no more than 'Max token number'. 
""" - eng = lang.lower() == "english"#is_english(cks) - parser_config = kwargs.get("parser_config", {"chunk_token_num": 128, "delimiter": "\n!?。;ďĽďĽź", "layout_recognize": True}) + eng = lang.lower() == "english" # is_english(cks) + parser_config = kwargs.get( + "parser_config", { + "chunk_token_num": 128, "delimiter": "\n!?。;ďĽďĽź", "layout_recognize": True}) doc = { "docnm_kwd": filename, "title_tks": huqie.qie(re.sub(r"\.[a-zA-Z]+$", "", filename)) @@ -73,9 +78,10 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca callback(0.8, "Finish parsing.") elif re.search(r"\.pdf$", filename, re.IGNORECASE): - pdf_parser = Pdf() if parser_config["layout_recognize"] else PlainParser() + pdf_parser = Pdf( + ) if parser_config["layout_recognize"] else PlainParser() sections, tbls = pdf_parser(filename if not binary else binary, - from_page=from_page, to_page=to_page, callback=callback) + from_page=from_page, to_page=to_page, callback=callback) res = tokenize_table(tbls, doc, eng) elif re.search(r"\.xlsx?$", filename, re.IGNORECASE): @@ -92,16 +98,21 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca with open(filename, "r") as f: while True: l = f.readline() - if not l: break + if not l: + break txt += l sections = txt.split("\n") sections = [(l, "") for l in sections if l] callback(0.8, "Finish parsing.") else: - raise NotImplementedError("file type not supported yet(docx, pdf, txt supported)") + raise NotImplementedError( + "file type not supported yet(docx, pdf, txt supported)") - chunks = naive_merge(sections, parser_config.get("chunk_token_num", 128), parser_config.get("delimiter", "\n!?。;ďĽďĽź")) + chunks = naive_merge( + sections, parser_config.get( + "chunk_token_num", 128), parser_config.get( + "delimiter", "\n!?。;ďĽďĽź")) res.extend(tokenize_chunks(chunks, doc, eng, pdf_parser)) return res @@ -110,9 +121,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca if __name__ == "__main__": import sys - def dummy(prog=None, msg=""): pass - chunk(sys.argv[1], from_page=0, to_page=10, callback=dummy) diff --git a/rag/app/one.py b/rag/app/one.py index 998cc56..bd90461 100644 --- a/rag/app/one.py +++ b/rag/app/one.py @@ -41,20 +41,23 @@ class Pdf(PdfParser): tbls = self._extract_table_figure(True, zoomin, True, True) self._concat_downward() - sections = [(b["text"], self.get_position(b, zoomin)) for i, b in enumerate(self.boxes)] + sections = [(b["text"], self.get_position(b, zoomin)) + for i, b in enumerate(self.boxes)] for (img, rows), poss in tbls: sections.append((rows if isinstance(rows, str) else rows[0], [(p[0] + 1 - from_page, p[1], p[2], p[3], p[4]) for p in poss])) - return [(txt, "") for txt, _ in sorted(sections, key=lambda x: (x[-1][0][0], x[-1][0][3], x[-1][0][1]))], None + return [(txt, "") for txt, _ in sorted(sections, key=lambda x: ( + x[-1][0][0], x[-1][0][3], x[-1][0][1]))], None -def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs): +def chunk(filename, binary=None, from_page=0, to_page=100000, + lang="Chinese", callback=None, **kwargs): """ Supported file formats are docx, pdf, excel, txt. One file forms a chunk which maintains original text order. 
""" - eng = lang.lower() == "english"#is_english(cks) + eng = lang.lower() == "english" # is_english(cks) if re.search(r"\.docx?$", filename, re.IGNORECASE): callback(0.1, "Start to parse.") @@ -62,8 +65,11 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca callback(0.8, "Finish parsing.") elif re.search(r"\.pdf$", filename, re.IGNORECASE): - pdf_parser = Pdf() if kwargs.get("parser_config",{}).get("layout_recognize", True) else PlainParser() - sections, _ = pdf_parser(filename if not binary else binary, to_page=to_page, callback=callback) + pdf_parser = Pdf() if kwargs.get( + "parser_config", {}).get( + "layout_recognize", True) else PlainParser() + sections, _ = pdf_parser( + filename if not binary else binary, to_page=to_page, callback=callback) sections = [s for s, _ in sections if s] elif re.search(r"\.xlsx?$", filename, re.IGNORECASE): @@ -80,14 +86,16 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca with open(filename, "r") as f: while True: l = f.readline() - if not l: break + if not l: + break txt += l sections = txt.split("\n") sections = [s for s in sections if s] callback(0.8, "Finish parsing.") else: - raise NotImplementedError("file type not supported yet(docx, pdf, txt supported)") + raise NotImplementedError( + "file type not supported yet(docx, pdf, txt supported)") doc = { "docnm_kwd": filename, @@ -101,9 +109,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca if __name__ == "__main__": import sys - def dummy(prog=None, msg=""): pass - chunk(sys.argv[1], from_page=0, to_page=10, callback=dummy) diff --git a/rag/app/paper.py b/rag/app/paper.py index c3cb298..8725054 100644 --- a/rag/app/paper.py +++ b/rag/app/paper.py @@ -67,11 +67,11 @@ class Pdf(PdfParser): if from_page > 0: return { - "title":"", + "title": "", "authors": "", "abstract": "", "sections": [(b["text"] + self._line_tag(b, zoomin), b.get("layoutno", "")) for b in self.boxes if - re.match(r"(text|title)", b.get("layoutno", "text"))], + re.match(r"(text|title)", b.get("layoutno", "text"))], "tables": tbls } # get title and authors @@ -87,7 +87,8 @@ class Pdf(PdfParser): title = "" break for j in range(3): - if _begin(self.boxes[i + j]["text"]): break + if _begin(self.boxes[i + j]["text"]): + break authors.append(self.boxes[i + j]["text"]) break break @@ -107,10 +108,15 @@ class Pdf(PdfParser): abstr = txt + self._line_tag(self.boxes[i], zoomin) i += 1 break - if not abstr: i = 0 + if not abstr: + i = 0 - callback(0.8, "Page {}~{}: Text merging finished".format(from_page, min(to_page, self.total_page))) - for b in self.boxes: print(b["text"], b.get("layoutno")) + callback( + 0.8, "Page {}~{}: Text merging finished".format( + from_page, min( + to_page, self.total_page))) + for b in self.boxes: + print(b["text"], b.get("layoutno")) print(tbls) return { @@ -118,19 +124,20 @@ class Pdf(PdfParser): "authors": " ".join(authors), "abstract": abstr, "sections": [(b["text"] + self._line_tag(b, zoomin), b.get("layoutno", "")) for b in self.boxes[i:] if - re.match(r"(text|title)", b.get("layoutno", "text"))], + re.match(r"(text|title)", b.get("layoutno", "text"))], "tables": tbls } -def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs): +def chunk(filename, binary=None, from_page=0, to_page=100000, + lang="Chinese", callback=None, **kwargs): """ Only pdf is supported. The abstract of the paper will be sliced as an entire chunk, and will not be sliced partly. 
""" pdf_parser = None if re.search(r"\.pdf$", filename, re.IGNORECASE): - if not kwargs.get("parser_config",{}).get("layout_recognize", True): + if not kwargs.get("parser_config", {}).get("layout_recognize", True): pdf_parser = PlainParser() paper = { "title": filename, @@ -143,14 +150,15 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca pdf_parser = Pdf() paper = pdf_parser(filename if not binary else binary, from_page=from_page, to_page=to_page, callback=callback) - else: raise NotImplementedError("file type not supported yet(pdf supported)") + else: + raise NotImplementedError("file type not supported yet(pdf supported)") doc = {"docnm_kwd": filename, "authors_tks": huqie.qie(paper["authors"]), "title_tks": huqie.qie(paper["title"] if paper["title"] else filename)} doc["title_sm_tks"] = huqie.qieqie(doc["title_tks"]) doc["authors_sm_tks"] = huqie.qieqie(doc["authors_tks"]) # is it English - eng = lang.lower() == "english"#pdf_parser.is_english + eng = lang.lower() == "english" # pdf_parser.is_english print("It's English.....", eng) res = tokenize_table(paper["tables"], doc, eng) @@ -160,7 +168,8 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca txt = pdf_parser.remove_tag(paper["abstract"]) d["important_kwd"] = ["abstract", "总结", "概括", "summary", "summarize"] d["important_tks"] = " ".join(d["important_kwd"]) - d["image"], poss = pdf_parser.crop(paper["abstract"], need_position=True) + d["image"], poss = pdf_parser.crop( + paper["abstract"], need_position=True) add_positions(d, poss) tokenize(d, txt, eng) res.append(d) @@ -174,7 +183,8 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca sec_ids = [] sid = 0 for i, lvl in enumerate(levels): - if lvl <= most_level and i > 0 and lvl != levels[i-1]: sid += 1 + if lvl <= most_level and i > 0 and lvl != levels[i - 1]: + sid += 1 sec_ids.append(sid) print(lvl, sorted_sections[i][0], most_level, sid) @@ -190,6 +200,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca res.extend(tokenize_chunks(chunks, doc, eng, pdf_parser)) return res + """ readed = [0] * len(paper["lines"]) # find colon firstly @@ -212,7 +223,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca for k in range(j, i): readed[k] = True txt = txt[::-1] if eng: - r = re.search(r"(.*?) ([\.;?!]|$)", txt) + r = re.search(r"(.*?) ([\\.;?!]|$)", txt) txt = r.group(1)[::-1] if r else txt[::-1] else: r = re.search(r"(.*?) ([。?;ďĽ]|$)", txt) @@ -270,6 +281,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca if __name__ == "__main__": import sys + def dummy(prog=None, msg=""): pass chunk(sys.argv[1], callback=dummy) diff --git a/rag/app/presentation.py b/rag/app/presentation.py index 1d1c38e..be4525b 100644 --- a/rag/app/presentation.py +++ b/rag/app/presentation.py @@ -33,9 +33,12 @@ class Ppt(PptParser): with slides.Presentation(BytesIO(fnm)) as presentation: for i, slide in enumerate(presentation.slides[from_page: to_page]): buffered = BytesIO() - slide.get_thumbnail(0.5, 0.5).save(buffered, drawing.imaging.ImageFormat.jpeg) + slide.get_thumbnail( + 0.5, 0.5).save( + buffered, drawing.imaging.ImageFormat.jpeg) imgs.append(Image.open(buffered)) - assert len(imgs) == len(txts), "Slides text and image do not match: {} vs. {}".format(len(imgs), len(txts)) + assert len(imgs) == len( + txts), "Slides text and image do not match: {} vs. 
{}".format(len(imgs), len(txts)) callback(0.9, "Image extraction finished") self.is_english = is_english(txts) return [(txts[i], imgs[i]) for i in range(len(txts))] @@ -47,25 +50,34 @@ class Pdf(PdfParser): def __garbage(self, txt): txt = txt.lower().strip() - if re.match(r"[0-9\.,%/-]+$", txt): return True - if len(txt) < 3:return True + if re.match(r"[0-9\.,%/-]+$", txt): + return True + if len(txt) < 3: + return True return False - def __call__(self, filename, binary=None, from_page=0, to_page=100000, zoomin=3, callback=None): + def __call__(self, filename, binary=None, from_page=0, + to_page=100000, zoomin=3, callback=None): callback(msg="OCR is running...") - self.__images__(filename if not binary else binary, zoomin, from_page, to_page, callback) - callback(0.8, "Page {}~{}: OCR finished".format(from_page, min(to_page, self.total_page))) - assert len(self.boxes) == len(self.page_images), "{} vs. {}".format(len(self.boxes), len(self.page_images)) + self.__images__(filename if not binary else binary, + zoomin, from_page, to_page, callback) + callback(0.8, "Page {}~{}: OCR finished".format( + from_page, min(to_page, self.total_page))) + assert len(self.boxes) == len(self.page_images), "{} vs. {}".format( + len(self.boxes), len(self.page_images)) res = [] for i in range(len(self.boxes)): - lines = "\n".join([b["text"] for b in self.boxes[i] if not self.__garbage(b["text"])]) + lines = "\n".join([b["text"] for b in self.boxes[i] + if not self.__garbage(b["text"])]) res.append((lines, self.page_images[i])) - callback(0.9, "Page {}~{}: Parsing finished".format(from_page, min(to_page, self.total_page))) + callback(0.9, "Page {}~{}: Parsing finished".format( + from_page, min(to_page, self.total_page))) return res class PlainPdf(PlainParser): - def __call__(self, filename, binary=None, from_page=0, to_page=100000, callback=None, **kwargs): + def __call__(self, filename, binary=None, from_page=0, + to_page=100000, callback=None, **kwargs): self.pdf = pdf2_read(filename if not binary else BytesIO(binary)) page_txt = [] for page in self.pdf.pages[from_page: to_page]: @@ -74,7 +86,8 @@ class PlainPdf(PlainParser): return [(txt, None) for txt in page_txt] -def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs): +def chunk(filename, binary=None, from_page=0, to_page=100000, + lang="Chinese", callback=None, **kwargs): """ The supported file formats are pdf, pptx. Every page will be treated as a chunk. And the thumbnail of every page will be stored. 
@@ -89,35 +102,42 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca res = [] if re.search(r"\.pptx?$", filename, re.IGNORECASE): ppt_parser = Ppt() - for pn, (txt,img) in enumerate(ppt_parser(filename if not binary else binary, from_page, 1000000, callback)): + for pn, (txt, img) in enumerate(ppt_parser( + filename if not binary else binary, from_page, 1000000, callback)): d = copy.deepcopy(doc) pn += from_page d["image"] = img - d["page_num_int"] = [pn+1] + d["page_num_int"] = [pn + 1] d["top_int"] = [0] d["position_int"] = [(pn + 1, 0, img.size[0], 0, img.size[1])] tokenize(d, txt, eng) res.append(d) return res elif re.search(r"\.pdf$", filename, re.IGNORECASE): - pdf_parser = Pdf() if kwargs.get("parser_config",{}).get("layout_recognize", True) else PlainPdf() - for pn, (txt,img) in enumerate(pdf_parser(filename, binary, from_page=from_page, to_page=to_page, callback=callback)): + pdf_parser = Pdf() if kwargs.get( + "parser_config", {}).get( + "layout_recognize", True) else PlainPdf() + for pn, (txt, img) in enumerate(pdf_parser(filename, binary, + from_page=from_page, to_page=to_page, callback=callback)): d = copy.deepcopy(doc) pn += from_page - if img: d["image"] = img - d["page_num_int"] = [pn+1] + if img: + d["image"] = img + d["page_num_int"] = [pn + 1] d["top_int"] = [0] - d["position_int"] = [(pn + 1, 0, img.size[0] if img else 0, 0, img.size[1] if img else 0)] + d["position_int"] = [ + (pn + 1, 0, img.size[0] if img else 0, 0, img.size[1] if img else 0)] tokenize(d, txt, eng) res.append(d) return res - raise NotImplementedError("file type not supported yet(pptx, pdf supported)") + raise NotImplementedError( + "file type not supported yet(pptx, pdf supported)") -if __name__== "__main__": +if __name__ == "__main__": import sys + def dummy(a, b): pass chunk(sys.argv[1], callback=dummy) - diff --git a/rag/app/resume.py b/rag/app/resume.py index cced856..d341bcb 100644 --- a/rag/app/resume.py +++ b/rag/app/resume.py @@ -27,6 +27,8 @@ from rag.utils import rmSpace forbidden_select_fields4resume = [ "name_pinyin_kwd", "edu_first_fea_kwd", "degree_kwd", "sch_rank_kwd", "edu_fea_kwd" ] + + def remote_call(filename, binary): q = { "header": { @@ -48,18 +50,22 @@ def remote_call(filename, binary): } for _ in range(3): try: - resume = requests.post("http://127.0.0.1:61670/tog", data=json.dumps(q)) + resume = requests.post( + "http://127.0.0.1:61670/tog", + data=json.dumps(q)) resume = resume.json()["response"]["results"] resume = refactor(resume) - for k in ["education", "work", "project", "training", "skill", "certificate", "language"]: - if not resume.get(k) and k in resume: del resume[k] + for k in ["education", "work", "project", + "training", "skill", "certificate", "language"]: + if not resume.get(k) and k in resume: + del resume[k] resume = step_one.refactor(pd.DataFrame([{"resume_content": json.dumps(resume), "tob_resume_id": "x", - "updated_at": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}])) + "updated_at": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}])) resume = step_two.parse(resume) return resume except Exception as e: - cron_logger.error("Resume parser error: "+str(e)) + cron_logger.error("Resume parser error: " + str(e)) return {} @@ -144,10 +150,13 @@ def chunk(filename, binary=None, callback=None, **kwargs): doc["content_ltks"] = huqie.qie(doc["content_with_weight"]) doc["content_sm_ltks"] = huqie.qieqie(doc["content_ltks"]) for n, _ in field_map.items(): - if n not in resume:continue - if isinstance(resume[n], list) and 
(len(resume[n]) == 1 or n not in forbidden_select_fields4resume): resume[n] = resume[n][0] - if n.find("_tks")>0: resume[n] = huqie.qieqie(resume[n]) + if n.find("_tks") > 0: + resume[n] = huqie.qieqie(resume[n]) doc[n] = resume[n] print(doc) diff --git a/rag/app/table.py b/rag/app/table.py index 9512e9f..3d10527 100644 --- a/rag/app/table.py +++ b/rag/app/table.py @@ -25,7 +25,8 @@ from deepdoc.parser import ExcelParser class Excel(ExcelParser): - def __call__(self, fnm, binary=None, from_page=0, to_page=10000000000, callback=None): + def __call__(self, fnm, binary=None, from_page=0, + to_page=10000000000, callback=None): if not binary: wb = load_workbook(fnm) else: @@ -48,8 +49,10 @@ class Excel(ExcelParser): data = [] for i, r in enumerate(rows[1:]): rn += 1 - if rn-1 < from_page:continue - if rn -1>=to_page: break + if rn - 1 < from_page: + continue + if rn - 1 >= to_page: + break row = [ cell.value for ii, cell in enumerate(r) if ii not in missed] @@ -60,7 +63,7 @@ class Excel(ExcelParser): done += 1 res.append(pd.DataFrame(np.array(data), columns=headers)) - callback(0.3, ("Extract records: {}~{}".format(from_page+1, min(to_page, from_page+rn)) + ( + callback(0.3, ("Extract records: {}~{}".format(from_page + 1, min(to_page, from_page + rn)) + ( f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else ""))) return res @@ -73,7 +76,8 @@ def trans_datatime(s): def trans_bool(s): - if re.match(r"(true|yes|是|\*|✓|✔|☑|✅|âš)$", + str(s).strip(), flags=re.IGNORECASE): return "yes" if re.match(r"(false|no|否|⍻|×)$", str(s).strip(), flags=re.IGNORECASE): return "no" @@ -107,13 +111,14 @@ def column_data_type(arr): arr[i] = trans[ty](str(arr[i])) except Exception as e: arr[i] = None - #if ty == "text": + # if ty == "text": # if len(arr) > 128 and uni / len(arr) < 0.1: # ty = "keyword" return arr, ty -def chunk(filename, binary=None, from_page=0, to_page=10000000000, lang="Chinese", callback=None, **kwargs): +def chunk(filename, binary=None, from_page=0, to_page=10000000000, + lang="Chinese", callback=None, **kwargs): """ Excel and csv(txt) format files are supported. For csv or txt file, the delimiter between columns is TAB. 
@@ -131,7 +136,12 @@ def chunk(filename, binary=None, from_page=0, to_page=10000000000, lang="Chinese if re.search(r"\.xlsx?$", filename, re.IGNORECASE): callback(0.1, "Start to parse.") excel_parser = Excel() - dfs = excel_parser(filename, binary, from_page=from_page, to_page=to_page, callback=callback) + dfs = excel_parser( + filename, + binary, + from_page=from_page, + to_page=to_page, + callback=callback) elif re.search(r"\.(txt|csv)$", filename, re.IGNORECASE): callback(0.1, "Start to parse.") txt = "" @@ -149,8 +159,10 @@ def chunk(filename, binary=None, from_page=0, to_page=10000000000, lang="Chinese headers = lines[0].split(kwargs.get("delimiter", "\t")) rows = [] for i, line in enumerate(lines[1:]): - if i < from_page:continue - if i >= to_page: break + if i < from_page: + continue + if i >= to_page: + break row = [l for l in line.split(kwargs.get("delimiter", "\t"))] if len(row) != len(headers): fails.append(str(i)) @@ -181,7 +193,13 @@ def chunk(filename, binary=None, from_page=0, to_page=10000000000, lang="Chinese del df[n] clmns = df.columns.values txts = list(copy.deepcopy(clmns)) - py_clmns = [PY.get_pinyins(re.sub(r"(/.*|([^()]+?)|\([^()]+?\))", "", n), '_')[0] for n in clmns] + py_clmns = [ + PY.get_pinyins( + re.sub( + r"(/.*|([^()]+?)|\([^()]+?\))", + "", + n), + '_')[0] for n in clmns] clmn_tys = [] for j in range(len(clmns)): cln, ty = column_data_type(df[clmns[j]]) @@ -192,7 +210,7 @@ def chunk(filename, binary=None, from_page=0, to_page=10000000000, lang="Chinese clmns_map = [(py_clmns[i].lower() + fieds_map[clmn_tys[i]], clmns[i].replace("_", " ")) for i in range(len(clmns))] - eng = lang.lower() == "english"#is_english(txts) + eng = lang.lower() == "english" # is_english(txts) for ii, row in df.iterrows(): d = { "docnm_kwd": filename, diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py index 2d9d9f5..d5ddbc0 100644 --- a/rag/llm/chat_model.py +++ b/rag/llm/chat_model.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # +from zhipuai import ZhipuAI +from dashscope import Generation from abc import ABC from openai import OpenAI import openai @@ -34,7 +36,8 @@ class GptTurbo(Base): self.model_name = model_name def chat(self, system, history, gen_conf): - if system: history.insert(0, {"role": "system", "content": system}) + if system: + history.insert(0, {"role": "system", "content": system}) try: response = self.client.chat.completions.create( model=self.model_name, @@ -46,16 +49,18 @@ class GptTurbo(Base): [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?" return ans, response.usage.completion_tokens except openai.APIError as e: - return "**ERROR**: "+str(e), 0 + return "**ERROR**: " + str(e), 0 class MoonshotChat(GptTurbo): def __init__(self, key, model_name="moonshot-v1-8k"): - self.client = OpenAI(api_key=key, base_url="https://api.moonshot.cn/v1",) + self.client = OpenAI( + api_key=key, base_url="https://api.moonshot.cn/v1",) self.model_name = model_name def chat(self, system, history, gen_conf): - if system: history.insert(0, {"role": "system", "content": system}) + if system: + history.insert(0, {"role": "system", "content": system}) try: response = self.client.chat.completions.create( model=self.model_name, @@ -67,10 +72,9 @@ class MoonshotChat(GptTurbo): [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?" 
return ans, response.usage.completion_tokens except openai.APIError as e: - return "**ERROR**: "+str(e), 0 + return "**ERROR**: " + str(e), 0 -from dashscope import Generation class QWenChat(Base): def __init__(self, key, model_name=Generation.Models.qwen_turbo): import dashscope @@ -79,7 +83,8 @@ class QWenChat(Base): def chat(self, system, history, gen_conf): from http import HTTPStatus - if system: history.insert(0, {"role": "system", "content": system}) + if system: + history.insert(0, {"role": "system", "content": system}) response = Generation.call( self.model_name, messages=history, @@ -92,20 +97,21 @@ class QWenChat(Base): ans += response.output.choices[0]['message']['content'] tk_count += response.usage.output_tokens if response.output.choices[0].get("finish_reason", "") == "length": - ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?" + ans += "...\nFor the content length reason, it stopped, continue?" if is_english( + [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?" return ans, tk_count return "**ERROR**: " + response.message, tk_count -from zhipuai import ZhipuAI class ZhipuChat(Base): def __init__(self, key, model_name="glm-3-turbo"): self.client = ZhipuAI(api_key=key) self.model_name = model_name def chat(self, system, history, gen_conf): - if system: history.insert(0, {"role": "system", "content": system}) + if system: + history.insert(0, {"role": "system", "content": system}) try: response = self.client.chat.completions.create( self.model_name, @@ -120,6 +126,7 @@ class ZhipuChat(Base): except Exception as e: return "**ERROR**: " + str(e), 0 + class LocalLLM(Base): class RPCProxy: def __init__(self, host, port): @@ -129,14 +136,17 @@ class LocalLLM(Base): def __conn(self): from multiprocessing.connection import Client - self._connection = Client((self.host, self.port), authkey=b'infiniflow-token4kevinhu') + self._connection = Client( + (self.host, self.port), authkey=b'infiniflow-token4kevinhu') def __getattr__(self, name): import pickle + def do_rpc(*args, **kwargs): for _ in range(3): try: - self._connection.send(pickle.dumps((name, args, kwargs))) + self._connection.send( + pickle.dumps((name, args, kwargs))) return pickle.loads(self._connection.recv()) except Exception as e: self.__conn() @@ -148,7 +158,8 @@ class LocalLLM(Base): self.client = LocalLLM.RPCProxy("127.0.0.1", 7860) def chat(self, system, history, gen_conf): - if system: history.insert(0, {"role": "system", "content": system}) + if system: + history.insert(0, {"role": "system", "content": system}) try: ans = self.client.chat( history, diff --git a/rag/llm/cv_model.py b/rag/llm/cv_model.py index 6e139c7..cb89509 100644 --- a/rag/llm/cv_model.py +++ b/rag/llm/cv_model.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # +from zhipuai import ZhipuAI import io from abc import ABC @@ -57,8 +58,8 @@ class Base(ABC): }, }, { - "text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else \ "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.", + "text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else + "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.", }, ], } @@ -92,8 +93,9 @@ class QWenCV(Base): def prompt(self, binary): # stupid as hell tmp_dir = get_project_base_directory("tmp") - if not os.path.exists(tmp_dir): os.mkdir(tmp_dir) - path = os.path.join(tmp_dir, "%s.jpg"%get_uuid()) + if not os.path.exists(tmp_dir): + os.mkdir(tmp_dir) + path = os.path.join(tmp_dir, "%s.jpg" % get_uuid()) Image.open(io.BytesIO(binary)).save(path) return [ { @@ -103,8 +105,8 @@ class QWenCV(Base): "image": f"file://{path}" }, { - "text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else \ "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.", + "text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else + "Please describe the content of this picture, like where, when, who, what happen. 
If it has number data, please extract them out.", }, ], } @@ -120,9 +122,6 @@ class QWenCV(Base): return response.message, 0 -from zhipuai import ZhipuAI - - class Zhipu4V(Base): def __init__(self, key, model_name="glm-4v", lang="Chinese"): self.client = ZhipuAI(api_key=key) diff --git a/rag/llm/embedding_model.py b/rag/llm/embedding_model.py index d5dc2dd..c2fe24b 100644 --- a/rag/llm/embedding_model.py +++ b/rag/llm/embedding_model.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # +from zhipuai import ZhipuAI import os from abc import ABC @@ -40,11 +41,11 @@ flag_model = FlagModel(model_dir, query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:", use_fp16=torch.cuda.is_available()) + class Base(ABC): def __init__(self, key, model_name): pass - def encode(self, texts: list, batch_size=32): raise NotImplementedError("Please implement encode method!") @@ -67,11 +68,11 @@ class HuEmbedding(Base): """ self.model = flag_model - def encode(self, texts: list, batch_size=32): texts = [t[:2000] for t in texts] token_count = 0 - for t in texts: token_count += num_tokens_from_string(t) + for t in texts: + token_count += num_tokens_from_string(t) res = [] for i in range(0, len(texts), batch_size): res.extend(self.model.encode(texts[i:i + batch_size]).tolist()) @@ -90,7 +91,8 @@ class OpenAIEmbed(Base): def encode(self, texts: list, batch_size=32): res = self.client.embeddings.create(input=texts, model=self.model_name) - return np.array([d.embedding for d in res.data]), res.usage.total_tokens + return np.array([d.embedding for d in res.data] + ), res.usage.total_tokens def encode_queries(self, text): res = self.client.embeddings.create(input=[text], @@ -111,7 +113,7 @@ class QWenEmbed(Base): for i in range(0, len(texts), batch_size): resp = dashscope.TextEmbedding.call( model=self.model_name, - input=texts[i:i+batch_size], + input=texts[i:i + batch_size], text_type="document" ) embds = [[] for _ in range(len(resp["output"]["embeddings"]))] @@ -123,14 +125,14 @@ class QWenEmbed(Base): def encode_queries(self, text): resp = dashscope.TextEmbedding.call( - model=self.model_name, - input=text[:2048], - text_type="query" - ) - return np.array(resp["output"]["embeddings"][0]["embedding"]), resp["usage"]["total_tokens"] + model=self.model_name, + input=text[:2048], + text_type="query" + ) + return np.array(resp["output"]["embeddings"][0] + ["embedding"]), resp["usage"]["total_tokens"] -from zhipuai import ZhipuAI class 
ZhipuEmbed(Base): def __init__(self, key, model_name="embedding-2"): self.client = ZhipuAI(api_key=key) @@ -139,9 +141,10 @@ class ZhipuEmbed(Base): def encode(self, texts: list, batch_size=32): res = self.client.embeddings.create(input=texts, model=self.model_name) - return np.array([d.embedding for d in res.data]), res.usage.total_tokens + return np.array([d.embedding for d in res.data] + ), res.usage.total_tokens def encode_queries(self, text): res = self.client.embeddings.create(input=text, model=self.model_name) - return np.array(res["data"][0]["embedding"]), res.usage.total_tokens \ No newline at end of file + return np.array(res["data"][0]["embedding"]), res.usage.total_tokens diff --git a/rag/llm/rpc_server.py b/rag/llm/rpc_server.py index e1e8b82..ce15d74 100644 --- a/rag/llm/rpc_server.py +++ b/rag/llm/rpc_server.py @@ -9,7 +9,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer class RPCHandler: def __init__(self): - self._functions = { } + self._functions = {} def register_function(self, func): self._functions[func.__name__] = func @@ -21,12 +21,12 @@ class RPCHandler: func_name, args, kwargs = pickle.loads(connection.recv()) # Run the RPC and send a response try: - r = self._functions[func_name](*args,**kwargs) + r = self._functions[func_name](*args, **kwargs) connection.send(pickle.dumps(r)) except Exception as e: connection.send(pickle.dumps(e)) except EOFError: - pass + pass def rpc_server(hdlr, address, authkey): @@ -44,11 +44,17 @@ def rpc_server(hdlr, address, authkey): models = [] tokenizer = None + def chat(messages, gen_conf): global tokenizer model = Model() try: - conf = {"max_new_tokens": int(gen_conf.get("max_tokens", 256)), "temperature": float(gen_conf.get("temperature", 0.1))} + conf = { + "max_new_tokens": int( + gen_conf.get( + "max_tokens", 256)), "temperature": float( + gen_conf.get( + "temperature", 0.1))} print(messages, conf) text = tokenizer.apply_chat_template( messages, @@ -65,7 +71,8 @@ def chat(messages, gen_conf): output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids) ] - return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0] + return tokenizer.batch_decode( + generated_ids, skip_special_tokens=True)[0] except Exception as e: return str(e) @@ -75,10 +82,15 @@ def Model(): random.seed(time.time()) return random.choice(models) + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--model_name", type=str, help="Model name") - parser.add_argument("--port", default=7860, type=int, help="RPC serving port") + parser.add_argument( + "--port", + default=7860, + type=int, + help="RPC serving port") args = parser.parse_args() handler = RPCHandler() @@ -93,4 +105,5 @@ if __name__ == "__main__": tokenizer = AutoTokenizer.from_pretrained(args.model_name) # Run the server - rpc_server(handler, ('0.0.0.0', args.port), authkey=b'infiniflow-token4kevinhu') + rpc_server(handler, ('0.0.0.0', args.port), + authkey=b'infiniflow-token4kevinhu') diff --git a/rag/nlp/huchunk.py b/rag/nlp/huchunk.py index bb2d46f..8c4c6fc 100644 --- a/rag/nlp/huchunk.py +++ b/rag/nlp/huchunk.py @@ -372,7 +372,8 @@ class PptChunker(HuChunker): tb = shape.table rows = [] for i in range(1, len(tb.rows)): - rows.append("; ".join([tb.cell(0, j).text + ": " + tb.cell(i, j).text for j in range(len(tb.columns)) if tb.cell(i, j)])) + rows.append("; ".join([tb.cell( + 0, j).text + ": " + tb.cell(i, j).text for j in range(len(tb.columns)) if tb.cell(i, j)])) return "\n".join(rows) if 
             if shape.has_text_frame:
@@ -382,7 +383,8 @@ class PptChunker(HuChunker):
             texts = []
             for p in shape.shapes:
                 t = self.__extract(p)
-                if t: texts.append(t)
+                if t:
+                    texts.append(t)
             return "\n".join(texts)
 
     def __call__(self, fnm):
@@ -395,7 +397,8 @@ class PptChunker(HuChunker):
             texts = []
             for shape in slide.shapes:
                 txt = self.__extract(shape)
-                if txt: texts.append(txt)
+                if txt:
+                    texts.append(txt)
             txts.append("\n".join(texts))
 
         import aspose.slides as slides
@@ -404,9 +407,12 @@ class PptChunker(HuChunker):
         with slides.Presentation(BytesIO(fnm)) as presentation:
             for slide in presentation.slides:
                 buffered = BytesIO()
-                slide.get_thumbnail(0.5, 0.5).save(buffered, drawing.imaging.ImageFormat.jpeg)
+                slide.get_thumbnail(
+                    0.5, 0.5).save(
+                    buffered, drawing.imaging.ImageFormat.jpeg)
                 imgs.append(buffered.getvalue())
-        assert len(imgs) == len(txts), "Slides text and image do not match: {} vs. {}".format(len(imgs), len(txts))
+        assert len(imgs) == len(
+            txts), "Slides text and image do not match: {} vs. {}".format(len(imgs), len(txts))
 
         flds = self.Fields()
         flds.text_chunks = [(txts[i], imgs[i]) for i in range(len(txts))]
@@ -445,7 +451,8 @@ class TextChunker(HuChunker):
         if isinstance(fnm, str):
             with open(fnm, "r") as f:
                 txt = f.read()
-        else: txt = fnm.decode("utf-8")
+        else:
+            txt = fnm.decode("utf-8")
         flds.text_chunks = [(c, None) for c in self.naive_text_chunk(txt)]
         flds.table_chunks = []
         return flds
diff --git a/rag/nlp/query.py b/rag/nlp/query.py
index aac5d2a..ea5b5cb 100644
--- a/rag/nlp/query.py
+++ b/rag/nlp/query.py
@@ -149,7 +149,8 @@ class EsQueryer:
         atks = toDict(atks)
         btkss = [toDict(tks) for tks in btkss]
         tksim = [self.similarity(atks, btks) for btks in btkss]
-        return np.array(sims[0]) * vtweight + np.array(tksim) * tkweight, tksim, sims[0]
+        return np.array(sims[0]) * vtweight + \
+            np.array(tksim) * tkweight, tksim, sims[0]
 
     def similarity(self, qtwt, dtwt):
         if isinstance(dtwt, type("")):
@@ -159,11 +160,11 @@ class EsQueryer:
         s = 1e-9
         for k, v in qtwt.items():
             if k in dtwt:
-                s += v# * dtwt[k]
+                s += v  # * dtwt[k]
         q = 1e-9
         for k, v in qtwt.items():
-            q += v #* v
+            q += v  # * v
         #d = 1e-9
-        #for k, v in dtwt.items():
+        # for k, v in dtwt.items():
         #    d += v * v
-        return s / q #math.sqrt(q) / math.sqrt(d)
+        return s / q  # math.sqrt(q) / math.sqrt(d)
diff --git a/rag/nlp/search.py b/rag/nlp/search.py
index 94fbe8e..a3a644f 100644
--- a/rag/nlp/search.py
+++ b/rag/nlp/search.py
@@ -80,14 +80,18 @@ class Dealer:
         if not req.get("sort"):
             s = s.sort(
                 {"create_time": {"order": "desc", "unmapped_type": "date"}},
-                {"create_timestamp_flt": {"order": "desc", "unmapped_type": "float"}}
+                {"create_timestamp_flt": {
+                    "order": "desc", "unmapped_type": "float"}}
             )
         else:
             s = s.sort(
-                {"page_num_int": {"order": "asc", "unmapped_type": "float", "mode": "avg", "numeric_type": "double"}},
-                {"top_int": {"order": "asc", "unmapped_type": "float", "mode": "avg", "numeric_type": "double"}},
+                {"page_num_int": {"order": "asc", "unmapped_type": "float",
+                                  "mode": "avg", "numeric_type": "double"}},
+                {"top_int": {"order": "asc", "unmapped_type": "float",
+                             "mode": "avg", "numeric_type": "double"}},
                 {"create_time": {"order": "desc", "unmapped_type": "date"}},
-                {"create_timestamp_flt": {"order": "desc", "unmapped_type": "float"}}
+                {"create_timestamp_flt": {
+                    "order": "desc", "unmapped_type": "float"}}
             )
 
         if qst:
@@ -180,11 +184,13 @@ class Dealer:
             m = {n: d.get(n) for n in flds if d.get(n) is not None}
             for n, v in m.items():
                 if isinstance(v, type([])):
-                    m[n] = "\t".join([str(vv) if not isinstance(vv, list) else "\t".join([str(vvv) for vvv in vv]) for vv in v])
else "\t".join([str(vvv) for vvv in vv]) for vv in v]) + m[n] = "\t".join([str(vv) if not isinstance( + vv, list) else "\t".join([str(vvv) for vvv in vv]) for vv in v]) continue if not isinstance(v, type("")): m[n] = str(m[n]) - if n.find("tks")>0: m[n] = rmSpace(m[n]) + if n.find("tks") > 0: + m[n] = rmSpace(m[n]) if m: res[d["id"]] = m @@ -205,12 +211,16 @@ class Dealer: if pieces[i] == "```": st = i i += 1 - while i<len(pieces) and pieces[i] != "```": + while i < len(pieces) and pieces[i] != "```": i += 1 - if i < len(pieces): i += 1 - pieces_.append("".join(pieces[st: i])+"\n") + if i < len(pieces): + i += 1 + pieces_.append("".join(pieces[st: i]) + "\n") else: - pieces_.extend(re.split(r"([^\|][;。?!ďĽ\n]|[a-z][.?;!][ \n])", pieces[i])) + pieces_.extend( + re.split( + r"([^\|][;。?!ďĽ\n]|[a-z][.?;!][ \n])", + pieces[i])) i += 1 pieces = pieces_ else: @@ -234,7 +244,8 @@ class Dealer: assert len(ans_v[0]) == len(chunk_v[0]), "The dimension of query and chunk do not match: {} vs. {}".format( len(ans_v[0]), len(chunk_v[0])) - chunks_tks = [huqie.qie(self.qryr.rmWWW(ck)).split(" ") for ck in chunks] + chunks_tks = [huqie.qie(self.qryr.rmWWW(ck)).split(" ") + for ck in chunks] cites = {} for i, a in enumerate(pieces_): sim, tksim, vtsim = self.qryr.hybrid_similarity(ans_v[i], @@ -258,9 +269,11 @@ class Dealer: continue if i not in cites: continue - for c in cites[i]: assert int(c) < len(chunk_v) for c in cites[i]: - if c in seted:continue + assert int(c) < len(chunk_v) + for c in cites[i]: + if c in seted: + continue res += f" ##{c}$$" seted.add(c) @@ -343,7 +356,11 @@ class Dealer: if dnm not in ranks["doc_aggs"]: ranks["doc_aggs"][dnm] = {"doc_id": did, "count": 0} ranks["doc_aggs"][dnm]["count"] += 1 - ranks["doc_aggs"] = [{"doc_name": k, "doc_id": v["doc_id"], "count": v["count"]} for k,v in sorted(ranks["doc_aggs"].items(), key=lambda x:x[1]["count"]*-1)] + ranks["doc_aggs"] = [{"doc_name": k, + "doc_id": v["doc_id"], + "count": v["count"]} for k, + v in sorted(ranks["doc_aggs"].items(), + key=lambda x:x[1]["count"] * -1)] return ranks @@ -354,10 +371,17 @@ class Dealer: replaces = [] for r in re.finditer(r" ([a-z_]+_l?tks)( like | ?= ?)'([^']+)'", sql): fld, v = r.group(1), r.group(3) - match = " MATCH({}, '{}', 'operator=OR;minimum_should_match=30%') ".format(fld, huqie.qieqie(huqie.qie(v))) - replaces.append(("{}{}'{}'".format(r.group(1), r.group(2), r.group(3)), match)) - - for p, r in replaces: sql = sql.replace(p, r, 1) + match = " MATCH({}, '{}', 'operator=OR;minimum_should_match=30%') ".format( + fld, huqie.qieqie(huqie.qie(v))) + replaces.append( + ("{}{}'{}'".format( + r.group(1), + r.group(2), + r.group(3)), + match)) + + for p, r in replaces: + sql = sql.replace(p, r, 1) chat_logger.info(f"To es: {sql}") try: @@ -366,4 +390,3 @@ class Dealer: except Exception as e: chat_logger.error(f"SQL failure: {sql} =>" + str(e)) return {"error": str(e)} - diff --git a/rag/nlp/term_weight.py b/rag/nlp/term_weight.py index 7be9d55..f1b3604 100644 --- a/rag/nlp/term_weight.py +++ b/rag/nlp/term_weight.py @@ -150,8 +150,10 @@ class Dealer: return 6 def ner(t): - if re.match(r"[0-9,.]{2,}$", t): return 2 - if re.match(r"[a-z]{1,2}$", t): return 0.01 + if re.match(r"[0-9,.]{2,}$", t): + return 2 + if re.match(r"[a-z]{1,2}$", t): + return 0.01 if not self.ne or t not in self.ne: return 1 m = {"toxic": 2, "func": 1, "corp": 3, "loca": 3, "sch": 3, "stock": 3, diff --git a/rag/settings.py b/rag/settings.py index 7c2257a..f84831d 100644 --- a/rag/settings.py +++ b/rag/settings.py @@ -14,7 +14,7 
@@ -14,7 +14,7 @@
 # limitations under the License.
 #
 import os
-from api.utils import get_base_config,decrypt_database_config
+from api.utils import get_base_config, decrypt_database_config
 from api.utils.file_utils import get_project_base_directory
 from api.utils.log_utils import LoggerFactory, getLogger
 
@@ -28,7 +28,11 @@ MINIO = decrypt_database_config(name="minio")
 DOC_MAXIMUM_SIZE = 128 * 1024 * 1024
 
 # Logger
-LoggerFactory.set_directory(os.path.join(get_project_base_directory(), "logs", "rag"))
+LoggerFactory.set_directory(
+    os.path.join(
+        get_project_base_directory(),
+        "logs",
+        "rag"))
 # {CRITICAL: 50, FATAL:50, ERROR:40, WARNING:30, WARN:30, INFO:20, DEBUG:10, NOTSET:0}
 LoggerFactory.LEVEL = 10
 
@@ -37,4 +41,3 @@ minio_logger = getLogger("minio")
 cron_logger = getLogger("cron_logger")
 chunk_logger = getLogger("chunk_logger")
 database_logger = getLogger("database")
-
diff --git a/rag/svr/task_broker.py b/rag/svr/task_broker.py
index 87a2965..0892a6c 100644
--- a/rag/svr/task_broker.py
+++ b/rag/svr/task_broker.py
@@ -47,7 +47,7 @@ def collect(tm):
 def set_dispatching(docid):
     try:
         DocumentService.update_by_id(
-            docid, {"progress": random.random()*1 / 100.,
+            docid, {"progress": random.random() * 1 / 100.,
                     "progress_msg": "Task dispatched...",
                     "process_begin_at": get_format_time()
                     })
@@ -56,7 +56,10 @@ def set_dispatching(docid):
 
 
 def dispatch():
-    tm_fnm = os.path.join(get_project_base_directory(), "rag/res", f"broker.tm")
+    tm_fnm = os.path.join(
+        get_project_base_directory(),
+        "rag/res",
+        f"broker.tm")
     tm = findMaxTm(tm_fnm)
     rows = collect(tm)
     if len(rows) == 0:
@@ -82,17 +85,22 @@ def dispatch():
         tsks = []
         if r["type"] == FileType.PDF.value:
             do_layout = r["parser_config"].get("layout_recognize", True)
-            pages = PdfParser.total_page_number(r["name"], MINIO.get(r["kb_id"], r["location"]))
+            pages = PdfParser.total_page_number(
+                r["name"], MINIO.get(r["kb_id"], r["location"]))
             page_size = r["parser_config"].get("task_page_size", 12)
-            if r["parser_id"] == "paper": page_size = r["parser_config"].get("task_page_size", 22)
-            if r["parser_id"] == "one": page_size = 1000000000
-            if not do_layout: page_size = 1000000000
+            if r["parser_id"] == "paper":
+                page_size = r["parser_config"].get("task_page_size", 22)
+            if r["parser_id"] == "one":
+                page_size = 1000000000
+            if not do_layout:
+                page_size = 1000000000
             page_ranges = r["parser_config"].get("pages")
-            if not page_ranges: page_ranges = [(1, 100000)]
-            for s,e in page_ranges:
+            if not page_ranges:
+                page_ranges = [(1, 100000)]
+            for s, e in page_ranges:
                 s -= 1
                 s = max(0, s)
-                e = min(e-1, pages)
+                e = min(e - 1, pages)
                 for p in range(s, e, page_size):
                     task = new_task()
                     task["from_page"] = p
@@ -100,12 +108,14 @@ def dispatch():
                     tsks.append(task)
 
         elif r["parser_id"] == "table":
-            rn = HuExcelParser.row_number(r["name"], MINIO.get(r["kb_id"], r["location"]))
-            for i in range(0, rn, 3000):
-                task = new_task()
-                task["from_page"] = i
-                task["to_page"] = min(i + 3000, rn)
-                tsks.append(task)
+            rn = HuExcelParser.row_number(
+                r["name"], MINIO.get(
+                    r["kb_id"], r["location"]))
+            for i in range(0, rn, 3000):
+                task = new_task()
+                task["from_page"] = i
+                task["to_page"] = min(i + 3000, rn)
+                tsks.append(task)
         else:
             tsks.append(new_task())
 
@@ -120,27 +130,37 @@ def update_progress():
         for d in docs:
             try:
                 tsks = TaskService.query(doc_id=d["id"], order_by=Task.create_time)
-                if not tsks:continue
+                if not tsks:
+                    continue
                 msg = []
                 prg = 0
                 finished = True
                 bad = 0
                 status = TaskStatus.RUNNING.value
                 for t in tsks:
-                    if 0 <= t.progress < 1: finished = False
+                    if 0 <= t.progress < 1:
+                        finished = False
                     prg += t.progress if t.progress >= 0 else 0
                     msg.append(t.progress_msg)
-                    if t.progress == -1: bad += 1
+                    if t.progress == -1:
+                        bad += 1
                 prg /= len(tsks)
                 if finished and bad:
                     prg = -1
                     status = TaskStatus.FAIL.value
-                elif finished: status = TaskStatus.DONE.value
+                elif finished:
+                    status = TaskStatus.DONE.value
 
                 msg = "\n".join(msg)
-                info = {"process_duation": datetime.timestamp(datetime.now())-d["process_begin_at"].timestamp(), "run": status}
-                if prg !=0 : info["progress"] = prg
-                if msg: info["progress_msg"] = msg
+                info = {
+                    "process_duation": datetime.timestamp(
+                        datetime.now()) -
+                    d["process_begin_at"].timestamp(),
+                    "run": status}
+                if prg != 0:
+                    info["progress"] = prg
+                if msg:
+                    info["progress_msg"] = msg
                 DocumentService.update_by_id(d["id"], info)
             except Exception as e:
                 cron_logger.error("fetch task exception:" + str(e))
diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py
index 9036fce..517d8a2 100644
--- a/rag/svr/task_executor.py
+++ b/rag/svr/task_executor.py
@@ -67,7 +67,7 @@ FACTORY = {
 
 def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing..."):
     if prog is not None and prog < 0:
-        msg = "[ERROR]"+msg
+        msg = "[ERROR]" + msg
     cancel = TaskService.do_cancel(task_id)
     if cancel:
         msg += " [Canceled]"
@@ -188,11 +188,13 @@ def embedding(docs, mdl, parser_config={}, callback=None):
 
     cnts_ = np.array([])
     for i in range(0, len(cnts), batch_size):
-        vts, c = mdl.encode(cnts[i: i+batch_size])
-        if len(cnts_) == 0: cnts_ = vts
-        else: cnts_ = np.concatenate((cnts_, vts), axis=0)
+        vts, c = mdl.encode(cnts[i: i + batch_size])
+        if len(cnts_) == 0:
+            cnts_ = vts
+        else:
+            cnts_ = np.concatenate((cnts_, vts), axis=0)
         tk_count += c
-        callback(prog=0.7+0.2*(i+1)/len(cnts), msg="")
+        callback(prog=0.7 + 0.2 * (i + 1) / len(cnts), msg="")
     cnts = cnts_
 
     title_w = float(parser_config.get("filename_embd_weight", 0.1))
@@ -234,7 +236,9 @@ def main(comm, mod):
             continue
         # TODO: exception handler
         ## set_progress(r["did"], -1, "ERROR: ")
-        callback(msg="Finished slicing files(%d). Start to embedding the content."%len(cks))
+        callback(
+            msg="Finished slicing files(%d). Start to embedding the content." %
+            len(cks))
         try:
             tk_count = embedding(cks, embd_mdl, r["parser_config"], callback)
         except Exception as e:
@@ -249,7 +253,7 @@ def main(comm, mod):
         if es_r:
             callback(-1, "Index failure!")
             ELASTICSEARCH.deleteByQuery(
-                    Q("match", doc_id=r["doc_id"]), idxnm=search.index_name(r["tenant_id"]))
+                Q("match", doc_id=r["doc_id"]), idxnm=search.index_name(r["tenant_id"]))
             cron_logger.error(str(es_r))
         else:
             if TaskService.do_cancel(r["id"]):
-- 
GitLab