From 5e0a689c4340b8c7a060aaaf0a913925cd4221cf Mon Sep 17 00:00:00 2001
From: KevinHuSh <kevinhu.sh@gmail.com>
Date: Thu, 8 Feb 2024 17:01:01 +0800
Subject: [PATCH] refactor retieval_test, add SQl retrieval methods (#61)

---
 api/apps/chunk_app.py        |  5 +++-
 api/apps/conversation_app.py |  2 ++
 api/apps/document_app.py     | 17 ++++++++---
 api/db/__init__.py           |  2 ++
 api/db/db_models.py          |  2 +-
 api/db/init_data.py          | 39 ++++++++++++++++++++++--
 api/settings.py              |  2 +-
 rag/app/naive.py             |  3 +-
 rag/app/qa.py                | 48 +++++++++++++++++++----------
 rag/app/resume.py            | 55 +++++++++++++++++++++++-----------
 rag/app/table.py             | 58 +++++++++++++++++++++++++-----------
 rag/llm/chat_model.py        | 18 +++++++++++
 rag/llm/cv_model.py          | 21 ++++++++++++-
 rag/llm/embedding_model.py   | 20 +++++++++++--
 rag/nlp/search.py            |  8 +++--
 rag/svr/task_executor.py     | 12 ++++----
 16 files changed, 238 insertions(+), 74 deletions(-)

diff --git a/api/apps/chunk_app.py b/api/apps/chunk_app.py
index 11d60d5..11d9c7b 100644
--- a/api/apps/chunk_app.py
+++ b/api/apps/chunk_app.py
@@ -227,7 +227,7 @@ def retrieval_test():
     doc_ids = req.get("doc_ids", [])
     similarity_threshold = float(req.get("similarity_threshold", 0.2))
     vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
-    top = int(req.get("top", 1024))
+    top = int(req.get("top_k", 1024))
     try:
         e, kb = KnowledgebaseService.get_by_id(kb_id)
         if not e:
@@ -237,6 +237,9 @@ def retrieval_test():
             kb.tenant_id, LLMType.EMBEDDING.value)
         ranks = retrievaler.retrieval(question, embd_mdl, kb.tenant_id, [kb_id], page, size, similarity_threshold,
                                       vector_similarity_weight, top, doc_ids)
+        for c in ranks["chunks"]:
+            if "vector" in c:
+                del c["vector"]
 
         return get_json_result(data=ranks)
     except Exception as e:
diff --git a/api/apps/conversation_app.py b/api/apps/conversation_app.py
index c48b586..30c3ee9 100644
--- a/api/apps/conversation_app.py
+++ b/api/apps/conversation_app.py
@@ -229,6 +229,7 @@ def use_sql(question,field_map, tenant_id, chat_mdl):
     sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_promt}], {"temperature": 0.1})
     sql = re.sub(r".*?select ", "select ", sql, flags=re.IGNORECASE)
     sql = re.sub(r" +", " ", sql)
+    sql = re.sub(r"[;;].*", "", sql)
     if sql[:len("select ")].lower() != "select ":
         return None, None
     if sql[:len("select *")].lower() != "select *":
@@ -241,6 +242,7 @@ def use_sql(question,field_map, tenant_id, chat_mdl):
     docnm_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "docnm_kwd"])
     clmn_idx = [ii for ii in range(len(tbl["columns"])) if ii not in (docid_idx|docnm_idx)]
 
+    # compose markdown table
     clmns = "|".join([re.sub(r"/.*", "", field_map.get(tbl["columns"][i]["name"], f"C{i}")) for i in clmn_idx]) + "|原文"
     line = "|".join(["------" for _ in range(len(clmn_idx))]) + "|------"
     rows = ["|".join([str(r[i]) for i in clmn_idx])+"|" for r in tbl["rows"]]
diff --git a/api/apps/document_app.py b/api/apps/document_app.py
index 65b480d..0b949cc 100644
--- a/api/apps/document_app.py
+++ b/api/apps/document_app.py
@@ -13,9 +13,10 @@
 #  See the License for the specific language governing permissions and
 #  limitations under the License
 #
-#
+
 import base64
 import pathlib
+import re
 
 import flask
 from elasticsearch_dsl import Q
@@ -27,7 +28,7 @@ from api.db.services import duplicate_name
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
 from api.utils import get_uuid
-from api.db import FileType, TaskStatus
+from api.db import FileType, TaskStatus, ParserType
 from api.db.services.document_service import DocumentService
 from api.settings import RetCode
 from api.utils.api_utils import get_json_result
@@ -66,7 +67,7 @@ def upload():
             location += "_"
         blob = request.files['file'].read()
         MINIO.put(kb_id, location, blob)
-        doc = DocumentService.insert({
+        doc = {
             "id": get_uuid(),
             "kb_id": kb.id,
             "parser_id": kb.parser_id,
@@ -77,7 +78,12 @@ def upload():
             "location": location,
             "size": len(blob),
             "thumbnail": thumbnail(filename, blob)
-        })
+        }
+        if doc["type"] == FileType.VISUAL:
+            doc["parser_id"] = ParserType.PICTURE.value
+        if re.search(r"\.(ppt|pptx|pages)$", filename):
+            doc["parser_id"] = ParserType.PRESENTATION.value
+        doc = DocumentService.insert(doc)
         return get_json_result(data=doc.to_json())
     except Exception as e:
         return server_error_response(e)
@@ -283,6 +289,9 @@ def change_parser():
         if doc.parser_id.lower() == req["parser_id"].lower():
             return get_json_result(data=True)
 
+        if doc.type == FileType.VISUAL or re.search(r"\.(ppt|pptx|pages)$", doc.name):
+            return get_data_error_result(retmsg="Not supported yet!")
+
         e = DocumentService.update_by_id(doc.id, {"parser_id": req["parser_id"], "progress":0, "progress_msg": ""})
         if not e:
             return get_data_error_result(retmsg="Document not found!")
diff --git a/api/db/__init__.py b/api/db/__init__.py
index c657dee..9c8a9b6 100644
--- a/api/db/__init__.py
+++ b/api/db/__init__.py
@@ -78,3 +78,5 @@ class ParserType(StrEnum):
     BOOK = "book"
     QA = "qa"
     TABLE = "table"
+    NAIVE = "naive"
+    PICTURE = "picture"
diff --git a/api/db/db_models.py b/api/db/db_models.py
index 09fe499..210b83b 100644
--- a/api/db/db_models.py
+++ b/api/db/db_models.py
@@ -381,7 +381,7 @@ class Tenant(DataBaseModel):
     embd_id = CharField(max_length=128, null=False, help_text="default embedding model ID")
     asr_id = CharField(max_length=128, null=False, help_text="default ASR model ID")
     img2txt_id = CharField(max_length=128, null=False, help_text="default image to text model ID")
-    parser_ids = CharField(max_length=128, null=False, help_text="document processors")
+    parser_ids = CharField(max_length=256, null=False, help_text="document processors")
     credit = IntegerField(default=512)
     status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
 
diff --git a/api/db/init_data.py b/api/db/init_data.py
index 593468f..319524d 100644
--- a/api/db/init_data.py
+++ b/api/db/init_data.py
@@ -63,7 +63,9 @@ def init_llm_factory():
             "status": "1",
         },
     ]
-    llm_infos = [{
+    llm_infos = [
+        # ---------------------- OpenAI ------------------------
+        {
             "fid": factory_infos[0]["name"],
             "llm_name": "gpt-3.5-turbo",
             "tags": "LLM,CHAT,4K",
@@ -105,7 +107,9 @@ def init_llm_factory():
             "tags": "LLM,CHAT,IMAGE2TEXT",
             "max_tokens": 765,
             "model_type": LLMType.IMAGE2TEXT.value
-        },{
+        },
+        # ----------------------- Qwen -----------------------
+        {
             "fid": factory_infos[1]["name"],
             "llm_name": "qwen-turbo",
             "tags": "LLM,CHAT,8K",
@@ -135,7 +139,9 @@ def init_llm_factory():
             "tags": "LLM,CHAT,IMAGE2TEXT",
             "max_tokens": 765,
             "model_type": LLMType.IMAGE2TEXT.value
-        },{
+        },
+        # ----------------------- Infiniflow -----------------------
+        {
             "fid": factory_infos[2]["name"],
             "llm_name": "gpt-3.5-turbo",
             "tags": "LLM,CHAT,4K",
@@ -160,6 +166,33 @@ def init_llm_factory():
             "max_tokens": 765,
             "model_type": LLMType.IMAGE2TEXT.value
         },
+        # ---------------------- ZhipuAI ----------------------
+        {
+            "fid": factory_infos[3]["name"],
+            "llm_name": "glm-3-turbo",
+            "tags": "LLM,CHAT,",
+            "max_tokens": 128 * 1000,
+            "model_type": LLMType.CHAT.value
+        }, {
+            "fid": factory_infos[3]["name"],
+            "llm_name": "glm-4",
+            "tags": "LLM,CHAT,",
+            "max_tokens": 128 * 1000,
+            "model_type": LLMType.CHAT.value
+        }, {
+            "fid": factory_infos[3]["name"],
+            "llm_name": "glm-4v",
+            "tags": "LLM,CHAT,IMAGE2TEXT",
+            "max_tokens": 2000,
+            "model_type": LLMType.IMAGE2TEXT.value
+        },
+        {
+            "fid": factory_infos[3]["name"],
+            "llm_name": "embedding-2",
+            "tags": "TEXT EMBEDDING",
+            "max_tokens": 512,
+            "model_type": LLMType.SPEECH2TEXT.value
+        },
     ]
     for info in factory_infos:
         LLMFactoriesService.save(**info)
diff --git a/api/settings.py b/api/settings.py
index 23c7592..4493bdb 100644
--- a/api/settings.py
+++ b/api/settings.py
@@ -47,7 +47,7 @@ LLM = get_base_config("llm", {})
 CHAT_MDL = LLM.get("chat_model", "gpt-3.5-turbo")
 EMBEDDING_MDL = LLM.get("embedding_model", "text-embedding-ada-002")
 ASR_MDL = LLM.get("asr_model", "whisper-1")
-PARSERS = LLM.get("parsers", "general:General,resume:esume,laws:Laws,manual:Manual,book:Book,paper:Paper,qa:Q&A,presentation:Presentation")
+PARSERS = LLM.get("parsers", "general:General,qa:Q&A,resume:Resume,naive:Naive,table:Table,laws:Laws,manual:Manual,book:Book,paper:Paper,presentation:Presentation,picture:Picture")
 IMAGE2TEXT_MDL = LLM.get("image2text_model", "gpt-4-vision-preview")
 
 # distribution
diff --git a/rag/app/naive.py b/rag/app/naive.py
index 178e016..b6a26f9 100644
--- a/rag/app/naive.py
+++ b/rag/app/naive.py
@@ -57,7 +57,8 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None, **k
         callback(0.8, "Finish parsing.")
     else: raise NotImplementedError("file type not supported yet(docx, pdf, txt supported)")
 
-    cks = naive_merge(sections, kwargs.get("chunk_token_num", 128), kwargs.get("delimer", "\n。;!?"))
+    parser_config = kwargs.get("parser_config", {"chunk_token_num": 128, "delimer": "\n。;!?"})
+    cks = naive_merge(sections, parser_config["chunk_token_num"], parser_config["delimer"])
     eng = is_english(cks)
     res = []
     # wrap up to es documents
diff --git a/rag/app/qa.py b/rag/app/qa.py
index fd4f568..75ebd94 100644
--- a/rag/app/qa.py
+++ b/rag/app/qa.py
@@ -24,31 +24,45 @@ class Excel(object):
             for i, r in enumerate(rows):
                 q, a = "", ""
                 for cell in r:
-                    if not cell.value: continue
-                    if not q: q = str(cell.value)
-                    elif not a: a = str(cell.value)
-                    else: break
-                if q and a: res.append((q, a))
-                else: fails.append(str(i+1))
+                    if not cell.value:
+                        continue
+                    if not q:
+                        q = str(cell.value)
+                    elif not a:
+                        a = str(cell.value)
+                    else:
+                        break
+                if q and a:
+                    res.append((q, a))
+                else:
+                    fails.append(str(i + 1))
                 if len(res) % 999 == 0:
-                    callback(len(res)*0.6/total, ("Extract Q&A: {}".format(len(res)) + (f"{len(fails)} failure, line: %s..."%(",".join(fails[:3])) if fails else "")))
+                    callback(len(res) *
+                             0.6 /
+                             total, ("Extract Q&A: {}".format(len(res)) +
+                                     (f"{len(fails)} failure, line: %s..." %
+                                      (",".join(fails[:3])) if fails else "")))
 
         callback(0.6, ("Extract Q&A: {}. ".format(len(res)) + (
             f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
-        self.is_english = is_english([rmPrefix(q) for q, _ in random_choices(res, k=30) if len(q)>1])
+        self.is_english = is_english(
+            [rmPrefix(q) for q, _ in random_choices(res, k=30) if len(q) > 1])
         return res
 
 
 def rmPrefix(txt):
-    return re.sub(r"^(问题|答案|回答|user|assistant|Q|A|Question|Answer|问|答)[\t:: ]+", "", txt.strip(), flags=re.IGNORECASE)
+    return re.sub(
+        r"^(问题|答案|回答|user|assistant|Q|A|Question|Answer|问|答)[\t:: ]+", "", txt.strip(), flags=re.IGNORECASE)
 
 
 def beAdoc(d, q, a, eng):
     qprefix = "Question: " if eng else "问题:"
     aprefix = "Answer: " if eng else "回答:"
-    d["content_with_weight"] = "\t".join([qprefix+rmPrefix(q), aprefix+rmPrefix(a)])
+    d["content_with_weight"] = "\t".join(
+        [qprefix + rmPrefix(q), aprefix + rmPrefix(a)])
     if eng:
-        d["content_ltks"] = " ".join([stemmer.stem(w) for w in word_tokenize(q)])
+        d["content_ltks"] = " ".join([stemmer.stem(w)
+                                     for w in word_tokenize(q)])
     else:
         d["content_ltks"] = huqie.qie(q)
         d["content_sm_ltks"] = huqie.qieqie(d["content_ltks"])
@@ -61,7 +75,7 @@ def chunk(filename, binary=None, callback=None, **kwargs):
     if re.search(r"\.xlsx?$", filename, re.IGNORECASE):
         callback(0.1, "Start to parse.")
         excel_parser = Excel()
-        for q,a in excel_parser(filename, binary, callback):
+        for q, a in excel_parser(filename, binary, callback):
             res.append(beAdoc({}, q, a, excel_parser.is_english))
         return res
     elif re.search(r"\.(txt|csv)$", filename, re.IGNORECASE):
@@ -73,7 +87,8 @@ def chunk(filename, binary=None, callback=None, **kwargs):
             with open(filename, "r") as f:
                 while True:
                     l = f.readline()
-                    if not l: break
+                    if not l:
+                        break
                     txt += l
         lines = txt.split("\n")
         eng = is_english([rmPrefix(l) for l in lines[:100]])
@@ -93,12 +108,13 @@ def chunk(filename, binary=None, callback=None, **kwargs):
 
         return res
 
-    raise NotImplementedError("file type not supported yet(pptx, pdf supported)")
+    raise NotImplementedError(
+        "file type not supported yet(pptx, pdf supported)")
 
 
-if __name__== "__main__":
+if __name__ == "__main__":
     import sys
+
     def dummy(a, b):
         pass
     chunk(sys.argv[1], callback=dummy)
-
diff --git a/rag/app/resume.py b/rag/app/resume.py
index f62d2c5..14649bb 100644
--- a/rag/app/resume.py
+++ b/rag/app/resume.py
@@ -11,15 +11,22 @@ from rag.utils import rmSpace
 
 
 def chunk(filename, binary=None, callback=None, **kwargs):
-    if not re.search(r"\.(pdf|doc|docx|txt)$", filename, flags=re.IGNORECASE): raise NotImplementedError("file type not supported yet(pdf supported)")
+    if not re.search(r"\.(pdf|doc|docx|txt)$", filename, flags=re.IGNORECASE):
+        raise NotImplementedError("file type not supported yet(pdf supported)")
 
     url = os.environ.get("INFINIFLOW_SERVER")
-    if not url:raise EnvironmentError("Please set environment variable: 'INFINIFLOW_SERVER'")
+    if not url:
+        raise EnvironmentError(
+            "Please set environment variable: 'INFINIFLOW_SERVER'")
     token = os.environ.get("INFINIFLOW_TOKEN")
-    if not token:raise EnvironmentError("Please set environment variable: 'INFINIFLOW_TOKEN'")
+    if not token:
+        raise EnvironmentError(
+            "Please set environment variable: 'INFINIFLOW_TOKEN'")
 
     if not binary:
-        with open(filename, "rb") as f: binary = f.read()
+        with open(filename, "rb") as f:
+            binary = f.read()
+
     def remote_call():
         nonlocal filename, binary
         for _ in range(3):
@@ -27,14 +34,17 @@ def chunk(filename, binary=None, callback=None, **kwargs):
                 res = requests.post(url + "/v1/layout/resume/", files=[(filename, binary)],
                                     headers={"Authorization": token}, timeout=180)
                 res = res.json()
-                if res["retcode"] != 0: raise RuntimeError(res["retmsg"])
+                if res["retcode"] != 0:
+                    raise RuntimeError(res["retmsg"])
                 return res["data"]
             except RuntimeError as e:
                 raise e
             except Exception as e:
                 cron_logger.error("resume parsing:" + str(e))
 
+    callback(0.2, "Resume parsing is going on...")
     resume = remote_call()
+    callback(0.6, "Done parsing. Chunking...")
     print(json.dumps(resume, ensure_ascii=False, indent=2))
 
     field_map = {
@@ -45,19 +55,19 @@ def chunk(filename, binary=None, callback=None, **kwargs):
         "email_tks": "email/e-mail/邮箱",
         "position_name_tks": "职位/职能/岗位/职责",
         "expect_position_name_tks": "期望职位/期望职能/期望岗位",
-    
+
         "hightest_degree_kwd": "最高学历(高中,职高,硕士,本科,博士,初中,中技,中专,专科,专升本,MPA,MBA,EMBA)",
         "first_degree_kwd": "第一学历(高中,职高,硕士,本科,博士,初中,中技,中专,专科,专升本,MPA,MBA,EMBA)",
         "first_major_tks": "第一学历专业",
         "first_school_name_tks": "第一学历毕业学校",
         "edu_first_fea_kwd": "第一学历标签(211,留学,双一流,985,海外知名,重点大学,中专,专升本,专科,本科,大专)",
-    
+
         "degree_kwd": "过往学历(高中,职高,硕士,本科,博士,初中,中技,中专,专科,专升本,MPA,MBA,EMBA)",
         "major_tks": "学过的专业/过往专业",
         "school_name_tks": "学校/毕业院校",
         "sch_rank_kwd": "学校标签(顶尖学校,精英学校,优质学校,一般学校)",
         "edu_fea_kwd": "教育标签(211,留学,双一流,985,海外知名,重点大学,中专,专升本,专科,本科,大专)",
-    
+
         "work_exp_flt": "工作年限/工作年份/N年经验/毕业了多少年",
         "birth_dt": "生日/出生年份",
         "corp_nm_tks": "就职过的公司/之前的公司/上过班的公司",
@@ -69,34 +79,43 @@ def chunk(filename, binary=None, callback=None, **kwargs):
     titles = []
     for n in ["name_kwd", "gender_kwd", "position_name_tks", "age_int"]:
         v = resume.get(n, "")
-        if isinstance(v, list):v = v[0]
-        if n.find("tks") > 0: v = rmSpace(v)
+        if isinstance(v, list):
+            v = v[0]
+        if n.find("tks") > 0:
+            v = rmSpace(v)
         titles.append(str(v))
     doc = {
         "docnm_kwd": filename,
-        "title_tks": huqie.qie("-".join(titles)+"-简历")
+        "title_tks": huqie.qie("-".join(titles) + "-简历")
     }
     doc["title_sm_tks"] = huqie.qieqie(doc["title_tks"])
     pairs = []
-    for n,m in field_map.items():
-        if not resume.get(n):continue
+    for n, m in field_map.items():
+        if not resume.get(n):
+            continue
         v = resume[n]
-        if isinstance(v, list):v = " ".join(v)
-        if n.find("tks") > 0: v = rmSpace(v)
+        if isinstance(v, list):
+            v = " ".join(v)
+        if n.find("tks") > 0:
+            v = rmSpace(v)
         pairs.append((m, str(v)))
 
-    doc["content_with_weight"] = "\n".join(["{}: {}".format(re.sub(r"([^()]+)", "", k), v) for k,v in pairs])
+    doc["content_with_weight"] = "\n".join(
+        ["{}: {}".format(re.sub(r"([^()]+)", "", k), v) for k, v in pairs])
     doc["content_ltks"] = huqie.qie(doc["content_with_weight"])
     doc["content_sm_ltks"] = huqie.qieqie(doc["content_ltks"])
-    for n, _ in field_map.items(): doc[n] = resume[n]
+    for n, _ in field_map.items():
+        doc[n] = resume[n]
 
     print(doc)
-    KnowledgebaseService.update_parser_config(kwargs["kb_id"], {"field_map": field_map})
+    KnowledgebaseService.update_parser_config(
+        kwargs["kb_id"], {"field_map": field_map})
     return [doc]
 
 
 if __name__ == "__main__":
     import sys
+
     def dummy(a, b):
         pass
     chunk(sys.argv[1], callback=dummy)
diff --git a/rag/app/table.py b/rag/app/table.py
index ee66356..a078a1a 100644
--- a/rag/app/table.py
+++ b/rag/app/table.py
@@ -28,10 +28,15 @@ class Excel(object):
             rows = list(ws.rows)
             headers = [cell.value for cell in rows[0]]
             missed = set([i for i, h in enumerate(headers) if h is None])
-            headers = [cell.value for i, cell in enumerate(rows[0]) if i not in missed]
+            headers = [
+                cell.value for i,
+                cell in enumerate(
+                    rows[0]) if i not in missed]
             data = []
             for i, r in enumerate(rows[1:]):
-                row = [cell.value for ii, cell in enumerate(r) if ii not in missed]
+                row = [
+                    cell.value for ii,
+                    cell in enumerate(r) if ii not in missed]
                 if len(row) != len(headers):
                     fails.append(str(i))
                     continue
@@ -55,8 +60,10 @@ def trans_datatime(s):
 
 
 def trans_bool(s):
-    if re.match(r"(true|yes|是)$", str(s).strip(), flags=re.IGNORECASE): return ["yes", "是"]
-    if re.match(r"(false|no|否)$", str(s).strip(), flags=re.IGNORECASE): return ["no", "否"]
+    if re.match(r"(true|yes|是)$", str(s).strip(), flags=re.IGNORECASE):
+        return ["yes", "是"]
+    if re.match(r"(false|no|否)$", str(s).strip(), flags=re.IGNORECASE):
+        return ["no", "否"]
 
 
 def column_data_type(arr):
@@ -65,7 +72,8 @@ def column_data_type(arr):
     trans = {t: f for f, t in
              [(int, "int"), (float, "float"), (trans_datatime, "datetime"), (trans_bool, "bool"), (str, "text")]}
     for a in arr:
-        if a is None: continue
+        if a is None:
+            continue
         if re.match(r"[+-]?[0-9]+(\.0+)?$", str(a).replace("%%", "")):
             counts["int"] += 1
         elif re.match(r"[+-]?[0-9.]+$", str(a).replace("%%", "")):
@@ -79,7 +87,8 @@ def column_data_type(arr):
     counts = sorted(counts.items(), key=lambda x: x[1] * -1)
     ty = counts[0][0]
     for i in range(len(arr)):
-        if arr[i] is None: continue
+        if arr[i] is None:
+            continue
         try:
             arr[i] = trans[ty](str(arr[i]))
         except Exception as e:
@@ -105,7 +114,8 @@ def chunk(filename, binary=None, callback=None, **kwargs):
             with open(filename, "r") as f:
                 while True:
                     l = f.readline()
-                    if not l: break
+                    if not l:
+                        break
                     txt += l
         lines = txt.split("\n")
         fails = []
@@ -127,14 +137,22 @@ def chunk(filename, binary=None, callback=None, **kwargs):
         dfs = [pd.DataFrame(np.array(rows), columns=headers)]
 
     else:
-        raise NotImplementedError("file type not supported yet(excel, text, csv supported)")
+        raise NotImplementedError(
+            "file type not supported yet(excel, text, csv supported)")
 
     res = []
     PY = Pinyin()
-    fieds_map = {"text": "_tks", "int": "_int", "keyword": "_kwd", "float": "_flt", "datetime": "_dt", "bool": "_kwd"}
+    fieds_map = {
+        "text": "_tks",
+        "int": "_int",
+        "keyword": "_kwd",
+        "float": "_flt",
+        "datetime": "_dt",
+        "bool": "_kwd"}
     for df in dfs:
         for n in ["id", "_id", "index", "idx"]:
-            if n in df.columns: del df[n]
+            if n in df.columns:
+                del df[n]
         clmns = df.columns.values
         txts = list(copy.deepcopy(clmns))
         py_clmns = [PY.get_pinyins(n)[0].replace("-", "_") for n in clmns]
@@ -143,23 +161,29 @@ def chunk(filename, binary=None, callback=None, **kwargs):
             cln, ty = column_data_type(df[clmns[j]])
             clmn_tys.append(ty)
             df[clmns[j]] = cln
-            if ty == "text": txts.extend([str(c) for c in cln if c])
-        clmns_map = [(py_clmns[j] + fieds_map[clmn_tys[j]], clmns[j]) for i in range(len(clmns))]
+            if ty == "text":
+                txts.extend([str(c) for c in cln if c])
+        clmns_map = [(py_clmns[j] + fieds_map[clmn_tys[j]], clmns[j])
+                     for i in range(len(clmns))]
 
         eng = is_english(txts)
         for ii, row in df.iterrows():
             d = {}
             row_txt = []
             for j in range(len(clmns)):
-                if row[clmns[j]] is None: continue
+                if row[clmns[j]] is None:
+                    continue
                 fld = clmns_map[j][0]
-                d[fld] = row[clmns[j]] if clmn_tys[j] != "text" else huqie.qie(row[clmns[j]])
+                d[fld] = row[clmns[j]] if clmn_tys[j] != "text" else huqie.qie(
+                    row[clmns[j]])
                 row_txt.append("{}:{}".format(clmns[j], row[clmns[j]]))
-            if not row_txt: continue
+            if not row_txt:
+                continue
             tokenize(d, "; ".join(row_txt), eng)
             res.append(d)
 
-        KnowledgebaseService.update_parser_config(kwargs["kb_id"], {"field_map": {k: v for k, v in clmns_map}})
+        KnowledgebaseService.update_parser_config(
+            kwargs["kb_id"], {"field_map": {k: v for k, v in clmns_map}})
     callback(0.6, "")
 
     return res
@@ -168,9 +192,7 @@ def chunk(filename, binary=None, callback=None, **kwargs):
 if __name__ == "__main__":
     import sys
 
-
     def dummy(a, b):
         pass
 
-
     chunk(sys.argv[1], callback=dummy)
diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py
index 3162636..7868eb2 100644
--- a/rag/llm/chat_model.py
+++ b/rag/llm/chat_model.py
@@ -58,3 +58,21 @@ class QWenChat(Base):
         if response.status_code == HTTPStatus.OK:
             return response.output.choices[0]['message']['content'], response.usage.output_tokens
         return response.message, 0
+
+
+from zhipuai import ZhipuAI
+class ZhipuChat(Base):
+    def __init__(self, key, model_name="glm-3-turbo"):
+        self.client = ZhipuAI(api_key=key)
+        self.model_name = model_name
+
+    def chat(self, system, history, gen_conf):
+        from http import HTTPStatus
+        history.insert(0, {"role": "system", "content": system})
+        response = self.client.chat.completions.create(
+            self.model_name,
+            messages=history
+        )
+        if response.status_code == HTTPStatus.OK:
+            return response.output.choices[0]['message']['content'], response.usage.completion_tokens
+        return response.message, 0
\ No newline at end of file
diff --git a/rag/llm/cv_model.py b/rag/llm/cv_model.py
index 67816a1..d663f90 100644
--- a/rag/llm/cv_model.py
+++ b/rag/llm/cv_model.py
@@ -61,7 +61,7 @@ class Base(ABC):
 
 class GptV4(Base):
     def __init__(self, key, model_name="gpt-4-vision-preview"):
-        self.client = OpenAI(api_key = key)
+        self.client = OpenAI(api_key=key)
         self.model_name = model_name
 
     def describe(self, image, max_tokens=300):
@@ -89,3 +89,22 @@ class QWenCV(Base):
         if response.status_code == HTTPStatus.OK:
             return response.output.choices[0]['message']['content'], response.usage.output_tokens
         return response.message, 0
+
+
+from zhipuai import ZhipuAI
+
+
+class Zhipu4V(Base):
+    def __init__(self, key, model_name="glm-4v"):
+        self.client = ZhipuAI(api_key=key)
+        self.model_name = model_name
+
+    def describe(self, image, max_tokens=1024):
+        b64 = self.image2base64(image)
+
+        res = self.client.chat.completions.create(
+            model=self.model_name,
+            messages=self.prompt(b64),
+            max_tokens=max_tokens,
+        )
+        return res.choices[0].message.content.strip(), res.usage.total_tokens
diff --git a/rag/llm/embedding_model.py b/rag/llm/embedding_model.py
index be914e6..94277d9 100644
--- a/rag/llm/embedding_model.py
+++ b/rag/llm/embedding_model.py
@@ -19,7 +19,6 @@ import dashscope
 from openai import OpenAI
 from FlagEmbedding import FlagModel
 import torch
-import os
 import numpy as np
 
 from rag.utils import num_tokens_from_string
@@ -114,4 +113,21 @@ class QWenEmbed(Base):
                 input=text[:2048],
                 text_type="query"
             )
-        return np.array(resp["output"]["embeddings"][0]["embedding"]), resp["usage"]["input_tokens"]
\ No newline at end of file
+        return np.array(resp["output"]["embeddings"][0]["embedding"]), resp["usage"]["input_tokens"]
+
+
+from zhipuai import ZhipuAI
+class ZhipuEmbed(Base):
+    def __init__(self, key, model_name="embedding-2"):
+        self.client = ZhipuAI(api_key=key)
+        self.model_name = model_name
+
+    def encode(self, texts: list, batch_size=32):
+        res = self.client.embeddings.create(input=texts,
+                                            model=self.model_name)
+        return np.array([d.embedding for d in res.data]), res.usage.total_tokens
+
+    def encode_queries(self, text):
+        res = self.client.embeddings.create(input=text,
+                                            model=self.model_name)
+        return np.array(res["data"][0]["embedding"]), res.usage.total_tokens
\ No newline at end of file
diff --git a/rag/nlp/search.py b/rag/nlp/search.py
index 9a23781..85f2449 100644
--- a/rag/nlp/search.py
+++ b/rag/nlp/search.py
@@ -268,9 +268,9 @@ class Dealer:
         dim = len(sres.query_vector)
         start_idx = (page - 1) * page_size
         for i in idx:
-            ranks["total"] += 1
             if sim[i] < similarity_threshold:
                 break
+            ranks["total"] += 1
             start_idx -= 1
             if start_idx >= 0:
                 continue
@@ -280,6 +280,7 @@ class Dealer:
                 break
             id = sres.ids[i]
             dnm = sres.field[id]["docnm_kwd"]
+            did = sres.field[id]["doc_id"]
             d = {
                 "chunk_id": id,
                 "content_ltks": sres.field[id]["content_ltks"],
@@ -296,8 +297,9 @@ class Dealer:
             }
             ranks["chunks"].append(d)
             if dnm not in ranks["doc_aggs"]:
-                ranks["doc_aggs"][dnm] = 0
-            ranks["doc_aggs"][dnm] += 1
+                ranks["doc_aggs"][dnm] = {"doc_id": did, "count": 0}
+            ranks["doc_aggs"][dnm]["count"] += 1
+        ranks["doc_aggs"] = [{"doc_name": k, "doc_id": v["doc_id"], "count": v["count"]} for k,v in sorted(ranks["doc_aggs"].items(), key=lambda x:x[1]["count"]*-1)]
 
         return ranks
 
diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py
index def9889..ed86987 100644
--- a/rag/svr/task_executor.py
+++ b/rag/svr/task_executor.py
@@ -36,7 +36,7 @@ from rag.nlp import search
 from io import BytesIO
 import pandas as pd
 
-from rag.app import laws, paper, presentation, manual, qa, table,book
+from rag.app import laws, paper, presentation, manual, qa, table, book, resume
 
 from api.db import LLMType, ParserType
 from api.db.services.document_service import DocumentService
@@ -55,6 +55,7 @@ FACTORY = {
     ParserType.LAWS.value: laws,
     ParserType.QA.value: qa,
     ParserType.TABLE.value: table,
+    ParserType.RESUME.value: resume,
 }
 
 
@@ -119,7 +120,7 @@ def build(row, cvmdl):
     try:
         cron_logger.info("Chunkking {}/{}".format(row["location"], row["name"]))
         cks = chunker.chunk(row["name"], MINIO.get(row["kb_id"], row["location"]), row["from_page"], row["to_page"],
-                            callback, kb_id=row["kb_id"])
+                            callback, kb_id=row["kb_id"], parser_config=row["parser_config"])
     except Exception as e:
         if re.search("(No such file|not found)", str(e)):
             callback(-1, "Can not find file <%s>" % row["doc_name"])
@@ -171,7 +172,7 @@ def init_kb(row):
         open(os.path.join(get_project_base_directory(), "conf", "mapping.json"), "r")))
 
 
-def embedding(docs, mdl):
+def embedding(docs, mdl, parser_config={}):
     tts, cnts = [rmSpace(d["title_tks"]) for d in docs if d.get("title_tks")], [d["content_with_weight"] for d in docs]
     tk_count = 0
     if len(tts) == len(cnts):
@@ -180,7 +181,8 @@ def embedding(docs, mdl):
 
     cnts, c = mdl.encode(cnts)
     tk_count += c
-    vects = (0.1 * tts + 0.9 * cnts) if len(tts) == len(cnts) else cnts
+    title_w = float(parser_config.get("filename_embd_weight", 0.1))
+    vects = (title_w * tts + (1-title_w) * cnts) if len(tts) == len(cnts) else cnts
 
     assert len(vects) == len(docs)
     for i, d in enumerate(docs):
@@ -216,7 +218,7 @@ def main(comm, mod):
         # TODO: exception handler
         ## set_progress(r["did"], -1, "ERROR: ")
         try:
-            tk_count = embedding(cks, embd_mdl)
+            tk_count = embedding(cks, embd_mdl, r["parser_config"])
         except Exception as e:
             callback(-1, "Embedding error:{}".format(str(e)))
             cron_logger.error(str(e))
-- 
GitLab