diff --git a/api/apps/conversation_app.py b/api/apps/conversation_app.py index 1d9894383b07c1edca90ff5c8f71c4713952448d..3ee943944e9288cc3d0c0dfa610418b5d95b3183 100644 --- a/api/apps/conversation_app.py +++ b/api/apps/conversation_app.py @@ -193,14 +193,14 @@ def chat(dialog, messages, **kwargs): embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING) chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id) + prompt_config = dialog.prompt_config field_map = KnowledgebaseService.get_field_map(dialog.kb_ids) # try to use sql if field mapping is good to go if field_map: chat_logger.info("Use SQL to retrieval:{}".format(questions[-1])) - ans = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl) + ans = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True)) if ans: return ans - prompt_config = dialog.prompt_config for p in prompt_config["parameters"]: if p["key"] == "knowledge": continue @@ -255,6 +255,7 @@ def chat(dialog, messages, **kwargs): d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx] if not recall_docs: recall_docs = kbinfos["doc_aggs"] kbinfos["doc_aggs"] = recall_docs + for c in kbinfos["chunks"]: if c.get("vector"): del c["vector"] @@ -263,7 +264,7 @@ def chat(dialog, messages, **kwargs): return {"answer": answer, "reference": kbinfos} -def use_sql(question, field_map, tenant_id, chat_mdl): +def use_sql(question, field_map, tenant_id, chat_mdl, quota=True): sys_prompt = "ä˝ ćŻä¸€ä¸ŞDBAă€‚ä˝ éś€č¦čż™ĺŻąä»Ąä¸‹čˇ¨çš„ĺ—ć®µç»“ćž„ďĽŚć ąćŤ®ç”¨ć·çš„é—®é˘ĺ—表,写出最ĺŽä¸€ä¸Şé—®é˘ĺŻąĺş”çš„SQL。" user_promt = """ 表ĺŤďĽš{}; @@ -353,12 +354,16 @@ def use_sql(question, field_map, tenant_id, chat_mdl): # compose markdown table clmns = "|" + "|".join([re.sub(r"(/.*|ďĽ[^ďĽďĽ‰]+))", "", field_map.get(tbl["columns"][i]["name"], tbl["columns"][i]["name"])) for i in clmn_idx]) + ("|Source|" if docid_idx and docid_idx else "|") + line = "|" + "|".join(["------" for _ in range(len(clmn_idx))]) + \ ("|------|" if docid_idx and docid_idx else "") + rows = ["|" + "|".join([rmSpace(str(r[i])) for i in clmn_idx]).replace("None", " ") + "|" for r in tbl["rows"]] - rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)]) + if quota: + rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)]) + else: rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)]) rows = re.sub(r"T[0-9]{2}:[0-9]{2}:[0-9]{2}(\.[0-9]+Z)?\|", "|", rows) if not docid_idx or not docnm_idx: diff --git a/api/db/init_data.py b/api/db/init_data.py index 2e5026af4d9446cb916ba6017d69c6e50c0f5b57..b813b293634dc16e44256140e1329977ca0a28c8 100644 --- a/api/db/init_data.py +++ b/api/db/init_data.py @@ -159,6 +159,12 @@ def init_llm_factory(): "max_tokens": 8191, "model_type": LLMType.CHAT.value }, { + "fid": factory_infos[0]["name"], + "llm_name": "gpt-4-turbo", + "tags": "LLM,CHAT,8K", + "max_tokens": 8191, + "model_type": LLMType.CHAT.value + },{ "fid": factory_infos[0]["name"], "llm_name": "gpt-4-32k", "tags": "LLM,CHAT,32K",