From f3477202fefe0438538b32cd442800485b6627fe Mon Sep 17 00:00:00 2001
From: KevinHuSh <kevinhu.sh@gmail.com>
Date: Thu, 28 Mar 2024 11:45:50 +0800
Subject: [PATCH] refine citation (#161)

---
 api/apps/conversation_app.py |  7 ++++---
 rag/app/paper.py             |  2 +-
 rag/nlp/search.py            | 29 ++++++++++++++++-------------
 3 files changed, 21 insertions(+), 17 deletions(-)

diff --git a/api/apps/conversation_app.py b/api/apps/conversation_app.py
index 5c55d5d..5521cca 100644
--- a/api/apps/conversation_app.py
+++ b/api/apps/conversation_app.py
@@ -194,7 +194,8 @@ def chat(dialog, messages, **kwargs):
     # try to use sql if field mapping is good to go
     if field_map:
         chat_logger.info("Use SQL to retrieval:{}".format(questions[-1]))
-        return use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl)
+        ans = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl)
+        if ans: return ans
 
     prompt_config = dialog.prompt_config
     for p in prompt_config["parameters"]:
@@ -305,7 +306,7 @@ def use_sql(question, field_map, tenant_id, chat_mdl):
 
     tbl, sql = get_table()
     if tbl is None:
-        return None, None
+        return None
     if tbl.get("error") and tried_times <= 2:
         user_promt = """
         表名:{};
@@ -333,7 +334,7 @@ def use_sql(question, field_map, tenant_id, chat_mdl):
     chat_logger.info("GET table: {}".format(tbl))
     print(tbl)
     if tbl.get("error") or len(tbl["rows"]) == 0:
-        return None, None
+        return None
 
     docid_idx = set([ii for ii, c in enumerate(
         tbl["columns"]) if c["name"] == "doc_id"])
diff --git a/rag/app/paper.py b/rag/app/paper.py
index 8725054..9a75bec 100644
--- a/rag/app/paper.py
+++ b/rag/app/paper.py
@@ -120,7 +120,7 @@ class Pdf(PdfParser):
         print(tbls)
 
         return {
-            "title": title if title else filename,
+            "title": title,
             "authors": " ".join(authors),
             "abstract": abstr,
             "sections": [(b["text"] + self._line_tag(b, zoomin), b.get("layoutno", "")) for b in self.boxes[i:] if
diff --git a/rag/nlp/search.py b/rag/nlp/search.py
index cc9f533..ac92853 100644
--- a/rag/nlp/search.py
+++ b/rag/nlp/search.py
@@ -246,19 +246,22 @@ class Dealer:
         chunks_tks = [huqie.qie(self.qryr.rmWWW(ck)).split(" ")
                       for ck in chunks]
         cites = {}
-        for i, a in enumerate(pieces_):
-            sim, tksim, vtsim = self.qryr.hybrid_similarity(ans_v[i],
-                                                            chunk_v,
-                                                            huqie.qie(
-                                                                self.qryr.rmWWW(pieces_[i])).split(" "),
-                                                            chunks_tks,
-                                                            tkweight, vtweight)
-            mx = np.max(sim) * 0.99
-            es_logger.info("{} SIM: {}".format(pieces_[i], mx))
-            if mx < 0.63:
-                continue
-            cites[idx[i]] = list(
-                set([str(ii) for ii in range(len(chunk_v)) if sim[ii] > mx]))[:4]
+        thr = 0.63
+        while len(cites.keys()) == 0 and pieces_ and chunks_tks:
+            for i, a in enumerate(pieces_):
+                sim, tksim, vtsim = self.qryr.hybrid_similarity(ans_v[i],
+                                                                chunk_v,
+                                                                huqie.qie(
+                                                                    self.qryr.rmWWW(pieces_[i])).split(" "),
+                                                                chunks_tks,
+                                                                tkweight, vtweight)
+                mx = np.max(sim) * 0.99
+                es_logger.info("{} SIM: {}".format(pieces_[i], mx))
+                if mx < thr:
+                    continue
+                cites[idx[i]] = list(
+                    set([str(ii) for ii in range(len(chunk_v)) if sim[ii] > mx]))[:4]
+            thr *= 0.8
 
         res = ""
         seted = set([])
-- 
GitLab