From 5e0a689c4340b8c7a060aaaf0a913925cd4221cf Mon Sep 17 00:00:00 2001 From: KevinHuSh <kevinhu.sh@gmail.com> Date: Thu, 8 Feb 2024 17:01:01 +0800 Subject: [PATCH] refactor retieval_test, add SQl retrieval methods (#61) --- api/apps/chunk_app.py | 5 +++- api/apps/conversation_app.py | 2 ++ api/apps/document_app.py | 17 ++++++++--- api/db/__init__.py | 2 ++ api/db/db_models.py | 2 +- api/db/init_data.py | 39 ++++++++++++++++++++++-- api/settings.py | 2 +- rag/app/naive.py | 3 +- rag/app/qa.py | 48 +++++++++++++++++++---------- rag/app/resume.py | 55 +++++++++++++++++++++++----------- rag/app/table.py | 58 +++++++++++++++++++++++++----------- rag/llm/chat_model.py | 18 +++++++++++ rag/llm/cv_model.py | 21 ++++++++++++- rag/llm/embedding_model.py | 20 +++++++++++-- rag/nlp/search.py | 8 +++-- rag/svr/task_executor.py | 12 ++++---- 16 files changed, 238 insertions(+), 74 deletions(-) diff --git a/api/apps/chunk_app.py b/api/apps/chunk_app.py index 11d60d5..11d9c7b 100644 --- a/api/apps/chunk_app.py +++ b/api/apps/chunk_app.py @@ -227,7 +227,7 @@ def retrieval_test(): doc_ids = req.get("doc_ids", []) similarity_threshold = float(req.get("similarity_threshold", 0.2)) vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3)) - top = int(req.get("top", 1024)) + top = int(req.get("top_k", 1024)) try: e, kb = KnowledgebaseService.get_by_id(kb_id) if not e: @@ -237,6 +237,9 @@ def retrieval_test(): kb.tenant_id, LLMType.EMBEDDING.value) ranks = retrievaler.retrieval(question, embd_mdl, kb.tenant_id, [kb_id], page, size, similarity_threshold, vector_similarity_weight, top, doc_ids) + for c in ranks["chunks"]: + if "vector" in c: + del c["vector"] return get_json_result(data=ranks) except Exception as e: diff --git a/api/apps/conversation_app.py b/api/apps/conversation_app.py index c48b586..30c3ee9 100644 --- a/api/apps/conversation_app.py +++ b/api/apps/conversation_app.py @@ -229,6 +229,7 @@ def use_sql(question,field_map, tenant_id, chat_mdl): sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_promt}], {"temperature": 0.1}) sql = re.sub(r".*?select ", "select ", sql, flags=re.IGNORECASE) sql = re.sub(r" +", " ", sql) + sql = re.sub(r"[;;].*", "", sql) if sql[:len("select ")].lower() != "select ": return None, None if sql[:len("select *")].lower() != "select *": @@ -241,6 +242,7 @@ def use_sql(question,field_map, tenant_id, chat_mdl): docnm_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "docnm_kwd"]) clmn_idx = [ii for ii in range(len(tbl["columns"])) if ii not in (docid_idx|docnm_idx)] + # compose markdown table clmns = "|".join([re.sub(r"/.*", "", field_map.get(tbl["columns"][i]["name"], f"C{i}")) for i in clmn_idx]) + "|原文" line = "|".join(["------" for _ in range(len(clmn_idx))]) + "|------" rows = ["|".join([str(r[i]) for i in clmn_idx])+"|" for r in tbl["rows"]] diff --git a/api/apps/document_app.py b/api/apps/document_app.py index 65b480d..0b949cc 100644 --- a/api/apps/document_app.py +++ b/api/apps/document_app.py @@ -13,9 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License # -# + import base64 import pathlib +import re import flask from elasticsearch_dsl import Q @@ -27,7 +28,7 @@ from api.db.services import duplicate_name from api.db.services.knowledgebase_service import KnowledgebaseService from api.utils.api_utils import server_error_response, get_data_error_result, validate_request from api.utils import get_uuid -from api.db import FileType, TaskStatus +from api.db import FileType, TaskStatus, ParserType from api.db.services.document_service import DocumentService from api.settings import RetCode from api.utils.api_utils import get_json_result @@ -66,7 +67,7 @@ def upload(): location += "_" blob = request.files['file'].read() MINIO.put(kb_id, location, blob) - doc = DocumentService.insert({ + doc = { "id": get_uuid(), "kb_id": kb.id, "parser_id": kb.parser_id, @@ -77,7 +78,12 @@ def upload(): "location": location, "size": len(blob), "thumbnail": thumbnail(filename, blob) - }) + } + if doc["type"] == FileType.VISUAL: + doc["parser_id"] = ParserType.PICTURE.value + if re.search(r"\.(ppt|pptx|pages)$", filename): + doc["parser_id"] = ParserType.PRESENTATION.value + doc = DocumentService.insert(doc) return get_json_result(data=doc.to_json()) except Exception as e: return server_error_response(e) @@ -283,6 +289,9 @@ def change_parser(): if doc.parser_id.lower() == req["parser_id"].lower(): return get_json_result(data=True) + if doc.type == FileType.VISUAL or re.search(r"\.(ppt|pptx|pages)$", doc.name): + return get_data_error_result(retmsg="Not supported yet!") + e = DocumentService.update_by_id(doc.id, {"parser_id": req["parser_id"], "progress":0, "progress_msg": ""}) if not e: return get_data_error_result(retmsg="Document not found!") diff --git a/api/db/__init__.py b/api/db/__init__.py index c657dee..9c8a9b6 100644 --- a/api/db/__init__.py +++ b/api/db/__init__.py @@ -78,3 +78,5 @@ class ParserType(StrEnum): BOOK = "book" QA = "qa" TABLE = "table" + NAIVE = "naive" + PICTURE = "picture" diff --git a/api/db/db_models.py b/api/db/db_models.py index 09fe499..210b83b 100644 --- a/api/db/db_models.py +++ b/api/db/db_models.py @@ -381,7 +381,7 @@ class Tenant(DataBaseModel): embd_id = CharField(max_length=128, null=False, help_text="default embedding model ID") asr_id = CharField(max_length=128, null=False, help_text="default ASR model ID") img2txt_id = CharField(max_length=128, null=False, help_text="default image to text model ID") - parser_ids = CharField(max_length=128, null=False, help_text="document processors") + parser_ids = CharField(max_length=256, null=False, help_text="document processors") credit = IntegerField(default=512) status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1") diff --git a/api/db/init_data.py b/api/db/init_data.py index 593468f..319524d 100644 --- a/api/db/init_data.py +++ b/api/db/init_data.py @@ -63,7 +63,9 @@ def init_llm_factory(): "status": "1", }, ] - llm_infos = [{ + llm_infos = [ + # ---------------------- OpenAI ------------------------ + { "fid": factory_infos[0]["name"], "llm_name": "gpt-3.5-turbo", "tags": "LLM,CHAT,4K", @@ -105,7 +107,9 @@ def init_llm_factory(): "tags": "LLM,CHAT,IMAGE2TEXT", "max_tokens": 765, "model_type": LLMType.IMAGE2TEXT.value - },{ + }, + # ----------------------- Qwen ----------------------- + { "fid": factory_infos[1]["name"], "llm_name": "qwen-turbo", "tags": "LLM,CHAT,8K", @@ -135,7 +139,9 @@ def init_llm_factory(): "tags": "LLM,CHAT,IMAGE2TEXT", "max_tokens": 765, "model_type": LLMType.IMAGE2TEXT.value - },{ + }, + # ----------------------- Infiniflow ----------------------- + { "fid": factory_infos[2]["name"], "llm_name": "gpt-3.5-turbo", "tags": "LLM,CHAT,4K", @@ -160,6 +166,33 @@ def init_llm_factory(): "max_tokens": 765, "model_type": LLMType.IMAGE2TEXT.value }, + # ---------------------- ZhipuAI ---------------------- + { + "fid": factory_infos[3]["name"], + "llm_name": "glm-3-turbo", + "tags": "LLM,CHAT,", + "max_tokens": 128 * 1000, + "model_type": LLMType.CHAT.value + }, { + "fid": factory_infos[3]["name"], + "llm_name": "glm-4", + "tags": "LLM,CHAT,", + "max_tokens": 128 * 1000, + "model_type": LLMType.CHAT.value + }, { + "fid": factory_infos[3]["name"], + "llm_name": "glm-4v", + "tags": "LLM,CHAT,IMAGE2TEXT", + "max_tokens": 2000, + "model_type": LLMType.IMAGE2TEXT.value + }, + { + "fid": factory_infos[3]["name"], + "llm_name": "embedding-2", + "tags": "TEXT EMBEDDING", + "max_tokens": 512, + "model_type": LLMType.SPEECH2TEXT.value + }, ] for info in factory_infos: LLMFactoriesService.save(**info) diff --git a/api/settings.py b/api/settings.py index 23c7592..4493bdb 100644 --- a/api/settings.py +++ b/api/settings.py @@ -47,7 +47,7 @@ LLM = get_base_config("llm", {}) CHAT_MDL = LLM.get("chat_model", "gpt-3.5-turbo") EMBEDDING_MDL = LLM.get("embedding_model", "text-embedding-ada-002") ASR_MDL = LLM.get("asr_model", "whisper-1") -PARSERS = LLM.get("parsers", "general:General,resume:esume,laws:Laws,manual:Manual,book:Book,paper:Paper,qa:Q&A,presentation:Presentation") +PARSERS = LLM.get("parsers", "general:General,qa:Q&A,resume:Resume,naive:Naive,table:Table,laws:Laws,manual:Manual,book:Book,paper:Paper,presentation:Presentation,picture:Picture") IMAGE2TEXT_MDL = LLM.get("image2text_model", "gpt-4-vision-preview") # distribution diff --git a/rag/app/naive.py b/rag/app/naive.py index 178e016..b6a26f9 100644 --- a/rag/app/naive.py +++ b/rag/app/naive.py @@ -57,7 +57,8 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None, **k callback(0.8, "Finish parsing.") else: raise NotImplementedError("file type not supported yet(docx, pdf, txt supported)") - cks = naive_merge(sections, kwargs.get("chunk_token_num", 128), kwargs.get("delimer", "\n。;ďĽďĽź")) + parser_config = kwargs.get("parser_config", {"chunk_token_num": 128, "delimer": "\n。;ďĽďĽź"}) + cks = naive_merge(sections, parser_config["chunk_token_num"], parser_config["delimer"]) eng = is_english(cks) res = [] # wrap up to es documents diff --git a/rag/app/qa.py b/rag/app/qa.py index fd4f568..75ebd94 100644 --- a/rag/app/qa.py +++ b/rag/app/qa.py @@ -24,31 +24,45 @@ class Excel(object): for i, r in enumerate(rows): q, a = "", "" for cell in r: - if not cell.value: continue - if not q: q = str(cell.value) - elif not a: a = str(cell.value) - else: break - if q and a: res.append((q, a)) - else: fails.append(str(i+1)) + if not cell.value: + continue + if not q: + q = str(cell.value) + elif not a: + a = str(cell.value) + else: + break + if q and a: + res.append((q, a)) + else: + fails.append(str(i + 1)) if len(res) % 999 == 0: - callback(len(res)*0.6/total, ("Extract Q&A: {}".format(len(res)) + (f"{len(fails)} failure, line: %s..."%(",".join(fails[:3])) if fails else ""))) + callback(len(res) * + 0.6 / + total, ("Extract Q&A: {}".format(len(res)) + + (f"{len(fails)} failure, line: %s..." % + (",".join(fails[:3])) if fails else ""))) callback(0.6, ("Extract Q&A: {}. ".format(len(res)) + ( f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else ""))) - self.is_english = is_english([rmPrefix(q) for q, _ in random_choices(res, k=30) if len(q)>1]) + self.is_english = is_english( + [rmPrefix(q) for q, _ in random_choices(res, k=30) if len(q) > 1]) return res def rmPrefix(txt): - return re.sub(r"^(é—®é˘|ç”ćˇ|回ç”|user|assistant|Q|A|Question|Answer|é—®|ç”)[\t:: ]+", "", txt.strip(), flags=re.IGNORECASE) + return re.sub( + r"^(é—®é˘|ç”ćˇ|回ç”|user|assistant|Q|A|Question|Answer|é—®|ç”)[\t:: ]+", "", txt.strip(), flags=re.IGNORECASE) def beAdoc(d, q, a, eng): qprefix = "Question: " if eng else "é—®é˘ďĽš" aprefix = "Answer: " if eng else "回ç”:" - d["content_with_weight"] = "\t".join([qprefix+rmPrefix(q), aprefix+rmPrefix(a)]) + d["content_with_weight"] = "\t".join( + [qprefix + rmPrefix(q), aprefix + rmPrefix(a)]) if eng: - d["content_ltks"] = " ".join([stemmer.stem(w) for w in word_tokenize(q)]) + d["content_ltks"] = " ".join([stemmer.stem(w) + for w in word_tokenize(q)]) else: d["content_ltks"] = huqie.qie(q) d["content_sm_ltks"] = huqie.qieqie(d["content_ltks"]) @@ -61,7 +75,7 @@ def chunk(filename, binary=None, callback=None, **kwargs): if re.search(r"\.xlsx?$", filename, re.IGNORECASE): callback(0.1, "Start to parse.") excel_parser = Excel() - for q,a in excel_parser(filename, binary, callback): + for q, a in excel_parser(filename, binary, callback): res.append(beAdoc({}, q, a, excel_parser.is_english)) return res elif re.search(r"\.(txt|csv)$", filename, re.IGNORECASE): @@ -73,7 +87,8 @@ def chunk(filename, binary=None, callback=None, **kwargs): with open(filename, "r") as f: while True: l = f.readline() - if not l: break + if not l: + break txt += l lines = txt.split("\n") eng = is_english([rmPrefix(l) for l in lines[:100]]) @@ -93,12 +108,13 @@ def chunk(filename, binary=None, callback=None, **kwargs): return res - raise NotImplementedError("file type not supported yet(pptx, pdf supported)") + raise NotImplementedError( + "file type not supported yet(pptx, pdf supported)") -if __name__== "__main__": +if __name__ == "__main__": import sys + def dummy(a, b): pass chunk(sys.argv[1], callback=dummy) - diff --git a/rag/app/resume.py b/rag/app/resume.py index f62d2c5..14649bb 100644 --- a/rag/app/resume.py +++ b/rag/app/resume.py @@ -11,15 +11,22 @@ from rag.utils import rmSpace def chunk(filename, binary=None, callback=None, **kwargs): - if not re.search(r"\.(pdf|doc|docx|txt)$", filename, flags=re.IGNORECASE): raise NotImplementedError("file type not supported yet(pdf supported)") + if not re.search(r"\.(pdf|doc|docx|txt)$", filename, flags=re.IGNORECASE): + raise NotImplementedError("file type not supported yet(pdf supported)") url = os.environ.get("INFINIFLOW_SERVER") - if not url:raise EnvironmentError("Please set environment variable: 'INFINIFLOW_SERVER'") + if not url: + raise EnvironmentError( + "Please set environment variable: 'INFINIFLOW_SERVER'") token = os.environ.get("INFINIFLOW_TOKEN") - if not token:raise EnvironmentError("Please set environment variable: 'INFINIFLOW_TOKEN'") + if not token: + raise EnvironmentError( + "Please set environment variable: 'INFINIFLOW_TOKEN'") if not binary: - with open(filename, "rb") as f: binary = f.read() + with open(filename, "rb") as f: + binary = f.read() + def remote_call(): nonlocal filename, binary for _ in range(3): @@ -27,14 +34,17 @@ def chunk(filename, binary=None, callback=None, **kwargs): res = requests.post(url + "/v1/layout/resume/", files=[(filename, binary)], headers={"Authorization": token}, timeout=180) res = res.json() - if res["retcode"] != 0: raise RuntimeError(res["retmsg"]) + if res["retcode"] != 0: + raise RuntimeError(res["retmsg"]) return res["data"] except RuntimeError as e: raise e except Exception as e: cron_logger.error("resume parsing:" + str(e)) + callback(0.2, "Resume parsing is going on...") resume = remote_call() + callback(0.6, "Done parsing. Chunking...") print(json.dumps(resume, ensure_ascii=False, indent=2)) field_map = { @@ -45,19 +55,19 @@ def chunk(filename, binary=None, callback=None, **kwargs): "email_tks": "email/e-mail/邮箱", "position_name_tks": "čŚä˝Ť/čŚč˝/岗位/čŚč´Ł", "expect_position_name_tks": "ćśźćś›čŚä˝Ť/ćśźćś›čŚč˝/期望岗位", - + "hightest_degree_kwd": "最é«ĺ¦ĺŽ†ďĽé«ä¸ďĽŚčŚé«ďĽŚçˇ•ĺŁ«ďĽŚćś¬ç§‘,博士,ĺťä¸ďĽŚä¸ćŠ€ďĽŚä¸ä¸“,专科,专升本,MPA,MBA,EMBA)", "first_degree_kwd": "第一ĺ¦ĺŽ†ďĽé«ä¸ďĽŚčŚé«ďĽŚçˇ•ĺŁ«ďĽŚćś¬ç§‘,博士,ĺťä¸ďĽŚä¸ćŠ€ďĽŚä¸ä¸“,专科,专升本,MPA,MBA,EMBA)", "first_major_tks": "第一ĺ¦ĺŽ†ä¸“业", "first_school_name_tks": "第一ĺ¦ĺŽ†ćŻ•ä¸šĺ¦ć ˇ", "edu_first_fea_kwd": "第一ĺ¦ĺŽ†ć ‡çľďĽ211,留ĺ¦ďĽŚĺŹŚä¸€ćµďĽŚ985,海外知ĺŤďĽŚé‡Ťç‚ąĺ¤§ĺ¦ďĽŚä¸ä¸“,专升本,专科,本科,大专)", - + "degree_kwd": "过往ĺ¦ĺŽ†ďĽé«ä¸ďĽŚčŚé«ďĽŚçˇ•ĺŁ«ďĽŚćś¬ç§‘,博士,ĺťä¸ďĽŚä¸ćŠ€ďĽŚä¸ä¸“,专科,专升本,MPA,MBA,EMBA)", "major_tks": "ĺ¦čż‡çš„专业/过往专业", "school_name_tks": "ĺ¦ć ˇ/ćŻ•ä¸šé™˘ć ˇ", "sch_rank_kwd": "ĺ¦ć ˇć ‡çľďĽéˇ¶ĺ°–ĺ¦ć ˇďĽŚç˛ľč‹±ĺ¦ć ˇďĽŚäĽč´¨ĺ¦ć ˇďĽŚä¸€č¬ĺ¦ć ˇďĽ‰", "edu_fea_kwd": "ć•™č‚˛ć ‡çľďĽ211,留ĺ¦ďĽŚĺŹŚä¸€ćµďĽŚ985,海外知ĺŤďĽŚé‡Ťç‚ąĺ¤§ĺ¦ďĽŚä¸ä¸“,专升本,专科,本科,大专)", - + "work_exp_flt": "工作年é™/工作年份/N年经验/毕业了多少年", "birth_dt": "生日/出生年份", "corp_nm_tks": "ĺ°±čŚčż‡çš„公司/之前的公司/上过çŹçš„公司", @@ -69,34 +79,43 @@ def chunk(filename, binary=None, callback=None, **kwargs): titles = [] for n in ["name_kwd", "gender_kwd", "position_name_tks", "age_int"]: v = resume.get(n, "") - if isinstance(v, list):v = v[0] - if n.find("tks") > 0: v = rmSpace(v) + if isinstance(v, list): + v = v[0] + if n.find("tks") > 0: + v = rmSpace(v) titles.append(str(v)) doc = { "docnm_kwd": filename, - "title_tks": huqie.qie("-".join(titles)+"-简历") + "title_tks": huqie.qie("-".join(titles) + "-简历") } doc["title_sm_tks"] = huqie.qieqie(doc["title_tks"]) pairs = [] - for n,m in field_map.items(): - if not resume.get(n):continue + for n, m in field_map.items(): + if not resume.get(n): + continue v = resume[n] - if isinstance(v, list):v = " ".join(v) - if n.find("tks") > 0: v = rmSpace(v) + if isinstance(v, list): + v = " ".join(v) + if n.find("tks") > 0: + v = rmSpace(v) pairs.append((m, str(v))) - doc["content_with_weight"] = "\n".join(["{}: {}".format(re.sub(r"ďĽ[^ďĽďĽ‰]+)", "", k), v) for k,v in pairs]) + doc["content_with_weight"] = "\n".join( + ["{}: {}".format(re.sub(r"ďĽ[^ďĽďĽ‰]+)", "", k), v) for k, v in pairs]) doc["content_ltks"] = huqie.qie(doc["content_with_weight"]) doc["content_sm_ltks"] = huqie.qieqie(doc["content_ltks"]) - for n, _ in field_map.items(): doc[n] = resume[n] + for n, _ in field_map.items(): + doc[n] = resume[n] print(doc) - KnowledgebaseService.update_parser_config(kwargs["kb_id"], {"field_map": field_map}) + KnowledgebaseService.update_parser_config( + kwargs["kb_id"], {"field_map": field_map}) return [doc] if __name__ == "__main__": import sys + def dummy(a, b): pass chunk(sys.argv[1], callback=dummy) diff --git a/rag/app/table.py b/rag/app/table.py index ee66356..a078a1a 100644 --- a/rag/app/table.py +++ b/rag/app/table.py @@ -28,10 +28,15 @@ class Excel(object): rows = list(ws.rows) headers = [cell.value for cell in rows[0]] missed = set([i for i, h in enumerate(headers) if h is None]) - headers = [cell.value for i, cell in enumerate(rows[0]) if i not in missed] + headers = [ + cell.value for i, + cell in enumerate( + rows[0]) if i not in missed] data = [] for i, r in enumerate(rows[1:]): - row = [cell.value for ii, cell in enumerate(r) if ii not in missed] + row = [ + cell.value for ii, + cell in enumerate(r) if ii not in missed] if len(row) != len(headers): fails.append(str(i)) continue @@ -55,8 +60,10 @@ def trans_datatime(s): def trans_bool(s): - if re.match(r"(true|yes|ćŻ)$", str(s).strip(), flags=re.IGNORECASE): return ["yes", "ćŻ"] - if re.match(r"(false|no|ĺ¦)$", str(s).strip(), flags=re.IGNORECASE): return ["no", "ĺ¦"] + if re.match(r"(true|yes|ćŻ)$", str(s).strip(), flags=re.IGNORECASE): + return ["yes", "ćŻ"] + if re.match(r"(false|no|ĺ¦)$", str(s).strip(), flags=re.IGNORECASE): + return ["no", "ĺ¦"] def column_data_type(arr): @@ -65,7 +72,8 @@ def column_data_type(arr): trans = {t: f for f, t in [(int, "int"), (float, "float"), (trans_datatime, "datetime"), (trans_bool, "bool"), (str, "text")]} for a in arr: - if a is None: continue + if a is None: + continue if re.match(r"[+-]?[0-9]+(\.0+)?$", str(a).replace("%%", "")): counts["int"] += 1 elif re.match(r"[+-]?[0-9.]+$", str(a).replace("%%", "")): @@ -79,7 +87,8 @@ def column_data_type(arr): counts = sorted(counts.items(), key=lambda x: x[1] * -1) ty = counts[0][0] for i in range(len(arr)): - if arr[i] is None: continue + if arr[i] is None: + continue try: arr[i] = trans[ty](str(arr[i])) except Exception as e: @@ -105,7 +114,8 @@ def chunk(filename, binary=None, callback=None, **kwargs): with open(filename, "r") as f: while True: l = f.readline() - if not l: break + if not l: + break txt += l lines = txt.split("\n") fails = [] @@ -127,14 +137,22 @@ def chunk(filename, binary=None, callback=None, **kwargs): dfs = [pd.DataFrame(np.array(rows), columns=headers)] else: - raise NotImplementedError("file type not supported yet(excel, text, csv supported)") + raise NotImplementedError( + "file type not supported yet(excel, text, csv supported)") res = [] PY = Pinyin() - fieds_map = {"text": "_tks", "int": "_int", "keyword": "_kwd", "float": "_flt", "datetime": "_dt", "bool": "_kwd"} + fieds_map = { + "text": "_tks", + "int": "_int", + "keyword": "_kwd", + "float": "_flt", + "datetime": "_dt", + "bool": "_kwd"} for df in dfs: for n in ["id", "_id", "index", "idx"]: - if n in df.columns: del df[n] + if n in df.columns: + del df[n] clmns = df.columns.values txts = list(copy.deepcopy(clmns)) py_clmns = [PY.get_pinyins(n)[0].replace("-", "_") for n in clmns] @@ -143,23 +161,29 @@ def chunk(filename, binary=None, callback=None, **kwargs): cln, ty = column_data_type(df[clmns[j]]) clmn_tys.append(ty) df[clmns[j]] = cln - if ty == "text": txts.extend([str(c) for c in cln if c]) - clmns_map = [(py_clmns[j] + fieds_map[clmn_tys[j]], clmns[j]) for i in range(len(clmns))] + if ty == "text": + txts.extend([str(c) for c in cln if c]) + clmns_map = [(py_clmns[j] + fieds_map[clmn_tys[j]], clmns[j]) + for i in range(len(clmns))] eng = is_english(txts) for ii, row in df.iterrows(): d = {} row_txt = [] for j in range(len(clmns)): - if row[clmns[j]] is None: continue + if row[clmns[j]] is None: + continue fld = clmns_map[j][0] - d[fld] = row[clmns[j]] if clmn_tys[j] != "text" else huqie.qie(row[clmns[j]]) + d[fld] = row[clmns[j]] if clmn_tys[j] != "text" else huqie.qie( + row[clmns[j]]) row_txt.append("{}:{}".format(clmns[j], row[clmns[j]])) - if not row_txt: continue + if not row_txt: + continue tokenize(d, "; ".join(row_txt), eng) res.append(d) - KnowledgebaseService.update_parser_config(kwargs["kb_id"], {"field_map": {k: v for k, v in clmns_map}}) + KnowledgebaseService.update_parser_config( + kwargs["kb_id"], {"field_map": {k: v for k, v in clmns_map}}) callback(0.6, "") return res @@ -168,9 +192,7 @@ def chunk(filename, binary=None, callback=None, **kwargs): if __name__ == "__main__": import sys - def dummy(a, b): pass - chunk(sys.argv[1], callback=dummy) diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py index 3162636..7868eb2 100644 --- a/rag/llm/chat_model.py +++ b/rag/llm/chat_model.py @@ -58,3 +58,21 @@ class QWenChat(Base): if response.status_code == HTTPStatus.OK: return response.output.choices[0]['message']['content'], response.usage.output_tokens return response.message, 0 + + +from zhipuai import ZhipuAI +class ZhipuChat(Base): + def __init__(self, key, model_name="glm-3-turbo"): + self.client = ZhipuAI(api_key=key) + self.model_name = model_name + + def chat(self, system, history, gen_conf): + from http import HTTPStatus + history.insert(0, {"role": "system", "content": system}) + response = self.client.chat.completions.create( + self.model_name, + messages=history + ) + if response.status_code == HTTPStatus.OK: + return response.output.choices[0]['message']['content'], response.usage.completion_tokens + return response.message, 0 \ No newline at end of file diff --git a/rag/llm/cv_model.py b/rag/llm/cv_model.py index 67816a1..d663f90 100644 --- a/rag/llm/cv_model.py +++ b/rag/llm/cv_model.py @@ -61,7 +61,7 @@ class Base(ABC): class GptV4(Base): def __init__(self, key, model_name="gpt-4-vision-preview"): - self.client = OpenAI(api_key = key) + self.client = OpenAI(api_key=key) self.model_name = model_name def describe(self, image, max_tokens=300): @@ -89,3 +89,22 @@ class QWenCV(Base): if response.status_code == HTTPStatus.OK: return response.output.choices[0]['message']['content'], response.usage.output_tokens return response.message, 0 + + +from zhipuai import ZhipuAI + + +class Zhipu4V(Base): + def __init__(self, key, model_name="glm-4v"): + self.client = ZhipuAI(api_key=key) + self.model_name = model_name + + def describe(self, image, max_tokens=1024): + b64 = self.image2base64(image) + + res = self.client.chat.completions.create( + model=self.model_name, + messages=self.prompt(b64), + max_tokens=max_tokens, + ) + return res.choices[0].message.content.strip(), res.usage.total_tokens diff --git a/rag/llm/embedding_model.py b/rag/llm/embedding_model.py index be914e6..94277d9 100644 --- a/rag/llm/embedding_model.py +++ b/rag/llm/embedding_model.py @@ -19,7 +19,6 @@ import dashscope from openai import OpenAI from FlagEmbedding import FlagModel import torch -import os import numpy as np from rag.utils import num_tokens_from_string @@ -114,4 +113,21 @@ class QWenEmbed(Base): input=text[:2048], text_type="query" ) - return np.array(resp["output"]["embeddings"][0]["embedding"]), resp["usage"]["input_tokens"] \ No newline at end of file + return np.array(resp["output"]["embeddings"][0]["embedding"]), resp["usage"]["input_tokens"] + + +from zhipuai import ZhipuAI +class ZhipuEmbed(Base): + def __init__(self, key, model_name="embedding-2"): + self.client = ZhipuAI(api_key=key) + self.model_name = model_name + + def encode(self, texts: list, batch_size=32): + res = self.client.embeddings.create(input=texts, + model=self.model_name) + return np.array([d.embedding for d in res.data]), res.usage.total_tokens + + def encode_queries(self, text): + res = self.client.embeddings.create(input=text, + model=self.model_name) + return np.array(res["data"][0]["embedding"]), res.usage.total_tokens \ No newline at end of file diff --git a/rag/nlp/search.py b/rag/nlp/search.py index 9a23781..85f2449 100644 --- a/rag/nlp/search.py +++ b/rag/nlp/search.py @@ -268,9 +268,9 @@ class Dealer: dim = len(sres.query_vector) start_idx = (page - 1) * page_size for i in idx: - ranks["total"] += 1 if sim[i] < similarity_threshold: break + ranks["total"] += 1 start_idx -= 1 if start_idx >= 0: continue @@ -280,6 +280,7 @@ class Dealer: break id = sres.ids[i] dnm = sres.field[id]["docnm_kwd"] + did = sres.field[id]["doc_id"] d = { "chunk_id": id, "content_ltks": sres.field[id]["content_ltks"], @@ -296,8 +297,9 @@ class Dealer: } ranks["chunks"].append(d) if dnm not in ranks["doc_aggs"]: - ranks["doc_aggs"][dnm] = 0 - ranks["doc_aggs"][dnm] += 1 + ranks["doc_aggs"][dnm] = {"doc_id": did, "count": 0} + ranks["doc_aggs"][dnm]["count"] += 1 + ranks["doc_aggs"] = [{"doc_name": k, "doc_id": v["doc_id"], "count": v["count"]} for k,v in sorted(ranks["doc_aggs"].items(), key=lambda x:x[1]["count"]*-1)] return ranks diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py index def9889..ed86987 100644 --- a/rag/svr/task_executor.py +++ b/rag/svr/task_executor.py @@ -36,7 +36,7 @@ from rag.nlp import search from io import BytesIO import pandas as pd -from rag.app import laws, paper, presentation, manual, qa, table,book +from rag.app import laws, paper, presentation, manual, qa, table, book, resume from api.db import LLMType, ParserType from api.db.services.document_service import DocumentService @@ -55,6 +55,7 @@ FACTORY = { ParserType.LAWS.value: laws, ParserType.QA.value: qa, ParserType.TABLE.value: table, + ParserType.RESUME.value: resume, } @@ -119,7 +120,7 @@ def build(row, cvmdl): try: cron_logger.info("Chunkking {}/{}".format(row["location"], row["name"])) cks = chunker.chunk(row["name"], MINIO.get(row["kb_id"], row["location"]), row["from_page"], row["to_page"], - callback, kb_id=row["kb_id"]) + callback, kb_id=row["kb_id"], parser_config=row["parser_config"]) except Exception as e: if re.search("(No such file|not found)", str(e)): callback(-1, "Can not find file <%s>" % row["doc_name"]) @@ -171,7 +172,7 @@ def init_kb(row): open(os.path.join(get_project_base_directory(), "conf", "mapping.json"), "r"))) -def embedding(docs, mdl): +def embedding(docs, mdl, parser_config={}): tts, cnts = [rmSpace(d["title_tks"]) for d in docs if d.get("title_tks")], [d["content_with_weight"] for d in docs] tk_count = 0 if len(tts) == len(cnts): @@ -180,7 +181,8 @@ def embedding(docs, mdl): cnts, c = mdl.encode(cnts) tk_count += c - vects = (0.1 * tts + 0.9 * cnts) if len(tts) == len(cnts) else cnts + title_w = float(parser_config.get("filename_embd_weight", 0.1)) + vects = (title_w * tts + (1-title_w) * cnts) if len(tts) == len(cnts) else cnts assert len(vects) == len(docs) for i, d in enumerate(docs): @@ -216,7 +218,7 @@ def main(comm, mod): # TODO: exception handler ## set_progress(r["did"], -1, "ERROR: ") try: - tk_count = embedding(cks, embd_mdl) + tk_count = embedding(cks, embd_mdl, r["parser_config"]) except Exception as e: callback(-1, "Embedding error:{}".format(str(e))) cron_logger.error(str(e)) -- GitLab