From 4608cccd0552c06b54cdd000dab854f59ac86175 Mon Sep 17 00:00:00 2001
From: KevinHuSh <kevinhu.sh@gmail.com>
Date: Mon, 15 Apr 2024 08:58:42 +0800
Subject: [PATCH] add new model gpt-4-turbo (#352)

### What problem does this PR solve?

Adds the `gpt-4-turbo` model to the built-in OpenAI model list and passes the dialog's `quote` setting through to `use_sql`, so SQL-based answers can omit the source markers when quoting is disabled (see the sketch below).

Issue link: #351
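
A minimal sketch of the intended behavior; the helper name and sample rows are illustrative, not taken from the codebase. When quoting is enabled, each markdown row gets a ` ##N$$ |` marker that the frontend resolves to a source document; when disabled, rows are emitted plainly.

```python
# Illustrative only: mirrors the row-marker logic added to use_sql(),
# using hypothetical sample rows.
def render_rows(rows, quote=True):
    if quote:
        # append a " ##N$$ |" citation marker per row
        return "\n".join(r + f" ##{i}$$ |" for i, r in enumerate(rows))
    return "\n".join(rows)

sample = ["|Alice|30|", "|Bob|25|"]
print(render_rows(sample, quote=True))
print(render_rows(sample, quote=False))
```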

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
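
For reference, the flag is read from the dialog's `prompt_config` with a default of `True`, so quoting could be turned off per dialog roughly like this (hypothetical fragment; only the `quote` key is consumed by this change):

```python
# Hypothetical prompt_config fragment; "parameters" already exists in the
# schema, and "quote": False suppresses the " ##N$$ |" markers in SQL answers.
prompt_config = {
    "quote": False,
    "parameters": [],
}
```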
---
 api/apps/conversation_app.py | 13 +++++++++----
 api/db/init_data.py          |  6 ++++++
 2 files changed, 15 insertions(+), 4 deletions(-)

diff --git a/api/apps/conversation_app.py b/api/apps/conversation_app.py
index 1d98943..3ee9439 100644
--- a/api/apps/conversation_app.py
+++ b/api/apps/conversation_app.py
@@ -193,14 +193,14 @@ def chat(dialog, messages, **kwargs):
     embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING)
     chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
 
+    prompt_config = dialog.prompt_config
     field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
     # try to use sql if field mapping is good to go
     if field_map:
         chat_logger.info("Use SQL to retrieval:{}".format(questions[-1]))
-        ans = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl)
+        ans = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True))
         if ans: return ans
 
-    prompt_config = dialog.prompt_config
     for p in prompt_config["parameters"]:
         if p["key"] == "knowledge":
             continue
@@ -255,6 +255,7 @@ def chat(dialog, messages, **kwargs):
             d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
         if not recall_docs: recall_docs = kbinfos["doc_aggs"]
         kbinfos["doc_aggs"] = recall_docs
+
     for c in kbinfos["chunks"]:
         if c.get("vector"):
             del c["vector"]
@@ -263,7 +264,7 @@ def chat(dialog, messages, **kwargs):
     return {"answer": answer, "reference": kbinfos}
 
 
-def use_sql(question, field_map, tenant_id, chat_mdl):
+def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
     sys_prompt = "你是一个DBA。你需要这对以下表的字段结构,根据用户的问题列表,写出最后一个问题对应的SQL。"
     user_promt = """
 表名:{};
@@ -353,12 +354,16 @@ def use_sql(question, field_map, tenant_id, chat_mdl):
     # compose markdown table
     clmns = "|" + "|".join([re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"],
                            tbl["columns"][i]["name"])) for i in clmn_idx]) + ("|Source|" if docid_idx and docid_idx else "|")
+
     line = "|" + "|".join(["------" for _ in range(len(clmn_idx))]) + \
         ("|------|" if docid_idx and docid_idx else "")
+
     rows = ["|" +
             "|".join([rmSpace(str(r[i])) for i in clmn_idx]).replace("None", " ") +
             "|" for r in tbl["rows"]]
-    rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
+    if quota:
+        rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
+    else: rows = "\n".join(rows)
     rows = re.sub(r"T[0-9]{2}:[0-9]{2}:[0-9]{2}(\.[0-9]+Z)?\|", "|", rows)
 
     if not docid_idx or not docnm_idx:
diff --git a/api/db/init_data.py b/api/db/init_data.py
index 2e5026a..b813b29 100644
--- a/api/db/init_data.py
+++ b/api/db/init_data.py
@@ -159,6 +159,12 @@ def init_llm_factory():
             "max_tokens": 8191,
             "model_type": LLMType.CHAT.value
         }, {
+            "fid": factory_infos[0]["name"],
+            "llm_name": "gpt-4-turbo",
+            "tags": "LLM,CHAT,8K",
+            "max_tokens": 8191,
+            "model_type": LLMType.CHAT.value
+        },{
             "fid": factory_infos[0]["name"],
             "llm_name": "gpt-4-32k",
             "tags": "LLM,CHAT,32K",
-- 
GitLab