From 4608cccd0552c06b54cdd000dab854f59ac86175 Mon Sep 17 00:00:00 2001 From: KevinHuSh <kevinhu.sh@gmail.com> Date: Mon, 15 Apr 2024 08:58:42 +0800 Subject: [PATCH] add new model gpt-4-turbo (#352) ### What problem does this PR solve? Issue link:#351 ### Type of change - [x] New Feature (non-breaking change which adds functionality) --- api/apps/conversation_app.py | 13 +++++++++---- api/db/init_data.py | 6 ++++++ 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/api/apps/conversation_app.py b/api/apps/conversation_app.py index 1d98943..3ee9439 100644 --- a/api/apps/conversation_app.py +++ b/api/apps/conversation_app.py @@ -193,14 +193,14 @@ def chat(dialog, messages, **kwargs): embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING) chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id) + prompt_config = dialog.prompt_config field_map = KnowledgebaseService.get_field_map(dialog.kb_ids) # try to use sql if field mapping is good to go if field_map: chat_logger.info("Use SQL to retrieval:{}".format(questions[-1])) - ans = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl) + ans = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True)) if ans: return ans - prompt_config = dialog.prompt_config for p in prompt_config["parameters"]: if p["key"] == "knowledge": continue @@ -255,6 +255,7 @@ def chat(dialog, messages, **kwargs): d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx] if not recall_docs: recall_docs = kbinfos["doc_aggs"] kbinfos["doc_aggs"] = recall_docs + for c in kbinfos["chunks"]: if c.get("vector"): del c["vector"] @@ -263,7 +264,7 @@ def chat(dialog, messages, **kwargs): return {"answer": answer, "reference": kbinfos} -def use_sql(question, field_map, tenant_id, chat_mdl): +def use_sql(question, field_map, tenant_id, chat_mdl, quota=True): sys_prompt = "ä˝ ćŻä¸€ä¸ŞDBAă€‚ä˝ éś€č¦čż™ĺŻąä»Ąä¸‹čˇ¨çš„ĺ—ć®µç»“ćž„ďĽŚć ąćŤ®ç”¨ć·çš„é—®é˘ĺ—表,写出最ĺŽä¸€ä¸Şé—®é˘ĺŻąĺş”çš„SQL。" user_promt = 
""" 表ĺŤďĽš{}; @@ -353,12 +354,16 @@ def use_sql(question, field_map, tenant_id, chat_mdl): # compose markdown table clmns = "|" + "|".join([re.sub(r"(/.*|ďĽ[^ďĽďĽ‰]+))", "", field_map.get(tbl["columns"][i]["name"], tbl["columns"][i]["name"])) for i in clmn_idx]) + ("|Source|" if docid_idx and docid_idx else "|") + line = "|" + "|".join(["------" for _ in range(len(clmn_idx))]) + \ ("|------|" if docid_idx and docid_idx else "") + rows = ["|" + "|".join([rmSpace(str(r[i])) for i in clmn_idx]).replace("None", " ") + "|" for r in tbl["rows"]] - rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)]) + if quota: + rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)]) + else: rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)]) rows = re.sub(r"T[0-9]{2}:[0-9]{2}:[0-9]{2}(\.[0-9]+Z)?\|", "|", rows) if not docid_idx or not docnm_idx: diff --git a/api/db/init_data.py b/api/db/init_data.py index 2e5026a..b813b29 100644 --- a/api/db/init_data.py +++ b/api/db/init_data.py @@ -159,6 +159,12 @@ def init_llm_factory(): "max_tokens": 8191, "model_type": LLMType.CHAT.value }, { + "fid": factory_infos[0]["name"], + "llm_name": "gpt-4-turbo", + "tags": "LLM,CHAT,8K", + "max_tokens": 8191, + "model_type": LLMType.CHAT.value + },{ "fid": factory_infos[0]["name"], "llm_name": "gpt-4-32k", "tags": "LLM,CHAT,32K", -- GitLab