diff --git a/api/apps/conversation_app.py b/api/apps/conversation_app.py
index 13e02aab10eefc7e0e05723b8950d0ff99ea519c..1f7e6cef6e0cc7fad0fa834e264b64c2ef17c14d 100644
--- a/api/apps/conversation_app.py
+++ b/api/apps/conversation_app.py
@@ -176,7 +176,7 @@ def chat(dialog, messages, **kwargs):
     if not llm:
         raise LookupError("LLM(%s) not found" % dialog.llm_id)
     llm = llm[0]
-    question = messages[-1]["content"]
+    questions = [m["content"] for m in messages if m["role"] == "user"]
     embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING)
     chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
 
@@ -184,7 +184,7 @@ def chat(dialog, messages, **kwargs):
     ## try to use sql if field mapping is good to go
     if field_map:
         stat_logger.info("Use SQL to retrieval.")
-        markdown_tbl, chunks = use_sql(question, field_map, dialog.tenant_id, chat_mdl)
+        markdown_tbl, chunks = use_sql("\n".join(questions), field_map, dialog.tenant_id, chat_mdl)
         if markdown_tbl:
             return {"answer": markdown_tbl, "retrieval": {"chunks": chunks}}
 
@@ -195,7 +195,9 @@ def chat(dialog, messages, **kwargs):
         if p["key"] not in kwargs:
             prompt_config["system"] = prompt_config["system"].replace("{%s}" % p["key"], " ")
 
-    kbinfos = retrievaler.retrieval(question, embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n,
+    for _ in range(len(questions)//2):
+        questions.append(questions[-1])
+    kbinfos = retrievaler.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n,
                                     dialog.similarity_threshold,
                                     dialog.vector_similarity_weight, top=1024, aggs=False)
     knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
@@ -224,13 +226,14 @@ def chat(dialog, messages, **kwargs):
 
 
 def use_sql(question, field_map, tenant_id, chat_mdl):
-    sys_prompt = "你是一个DBA。你需要这对以下表的字段结构,根据我的问题写出sql。"
+    sys_prompt = "你是一个DBA。你需要这对以下表的字段结构,根据用户的问题列表,写出最后一个问题对应的SQL。"
     user_promt = """
 表名:{};
 数据库表字段说明如下:
 {}
 
-问题:{}
+问题如下:
+{}
 请写出SQL,且只要SQL,不要有其他说明及文字。
 """.format(
         index_name(tenant_id),
diff --git a/api/apps/user_app.py b/api/apps/user_app.py
index 2c1c814e5b9942f66b9e301d1cf529de9436c344..7ff2364f62df315fc41cd504e2fe1f0901deed5c 100644
--- a/api/apps/user_app.py
+++ b/api/apps/user_app.py
@@ -100,12 +100,14 @@ def github_callback():
         if len(users) > 1:
             raise Exception('Same E-mail exist!')
         user = users[0]
         login_user(user)
+        return redirect("/?auth=%s"%user.get_id())
     except Exception as e:
         rollback_user_registration(user_id)
         stat_logger.exception(e)
         return redirect("/?error=%s"%str(e))
-
-    return redirect("/?auth=%s"%user_id)
+    user = users[0]
+    login_user(user)
+    return redirect("/?auth=%s" % user.get_id())
 
 def user_info_from_github(access_token):
diff --git a/deepdoc/vision/t_recognizer.py b/deepdoc/vision/t_recognizer.py
index 7358c4e0aa47a3e9b2e9bfe86361956667c8a41f..23033e23b0c1fc399a9df47f92dfb94fbbb52cfc 100644
--- a/deepdoc/vision/t_recognizer.py
+++ b/deepdoc/vision/t_recognizer.py
@@ -28,7 +28,7 @@ def main(args):
     images, outputs = init_in_out(args)
     if args.mode.lower() == "layout":
         labels = LayoutRecognizer.labels
-        detr = Recognizer(labels, "layout.paper", os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
+        detr = Recognizer(labels, "layout", os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
     if args.mode.lower() == "tsr":
         labels = TableStructureRecognizer.labels
         detr = TableStructureRecognizer()
diff --git a/rag/app/presentation.py b/rag/app/presentation.py
index 002dc252f800cf1d643adf5fb610be26606586af..16c11bd76efe36a0a4c6d749270644e26f4ae79a 100644
--- a/rag/app/presentation.py
+++ b/rag/app/presentation.py
@@ -73,12 +73,13 @@ class Pdf(PdfParser):
         return res
 
 
-def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None, **kwargs):
+def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs):
     """
     The supported file formats are pdf, pptx.
     Every page will be treated as a chunk. And the thumbnail of every page will be stored.
     PPT file will be parsed by using this method automatically, setting-up for every PPT file is not necessary.
     """
+    eng = lang.lower() == "english"
    doc = {
        "docnm_kwd": filename,
        "title_tks": huqie.qie(re.sub(r"\.[a-zA-Z]+$", "", filename))
@@ -98,8 +99,10 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None, **k
         for pn, (txt,img) in enumerate(pdf_parser(filename if not binary else binary,
                                                   from_page=from_page, to_page=to_page, callback=callback)):
             d = copy.deepcopy(doc)
             d["image"] = img
-            d["page_num_obj"] = [pn+1]
-            tokenize(d, txt, pdf_parser.is_english)
+            d["page_num_int"] = [pn+1]
+            d["top_int"] = [0]
+            d["position_int"] = [(pn + 1, 0, img.size[0], 0, img.size[1])]
+            tokenize(d, txt, eng)
             res.append(d)
         return res
diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py
index 7b95cd68c0e0461f35c8089fbc1398ae3bd6c309..2389561966e1071f0d50482e7cefc6dcf577b5fd 100644
--- a/rag/llm/chat_model.py
+++ b/rag/llm/chat_model.py
@@ -14,9 +14,13 @@
 #  limitations under the License.
 #
 from abc import ABC
+from copy import deepcopy
+
 from openai import OpenAI
 import openai
 
+from rag.nlp import is_english
+
 
 class Base(ABC):
     def __init__(self, key, model_name):
@@ -34,13 +38,17 @@ class GptTurbo(Base):
     def chat(self, system, history, gen_conf):
         if system: history.insert(0, {"role": "system", "content": system})
         try:
-            res = self.client.chat.completions.create(
+            response = self.client.chat.completions.create(
                 model=self.model_name,
                 messages=history,
                 **gen_conf)
-            return res.choices[0].message.content.strip(), res.usage.completion_tokens
+            ans = response.choices[0].message.content.strip()
+            if response.choices[0].finish_reason == "length":
+                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
+                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
+            return ans, response.usage.completion_tokens
         except openai.APIError as e:
-            return "ERROR: "+str(e), 0
+            return "**ERROR**: "+str(e), 0
 
 
 from dashscope import Generation
@@ -59,9 +67,16 @@ class QWenChat(Base):
             result_format='message',
             **gen_conf
         )
+        ans = ""
+        tk_count = 0
         if response.status_code == HTTPStatus.OK:
-            return response.output.choices[0]['message']['content'], response.usage.output_tokens
-        return "ERROR: " + response.message, 0
+            ans += response.output.choices[0]['message']['content']
+            tk_count += response.usage.output_tokens
+            if response.output.choices[0].get("finish_reason", "") == "length":
+                ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
+            return ans, tk_count
+
+        return "**ERROR**: " + response.message, tk_count
 
 
 from zhipuai import ZhipuAI
@@ -73,11 +88,16 @@ class ZhipuChat(Base):
     def chat(self, system, history, gen_conf):
         from http import HTTPStatus
         if system: history.insert(0, {"role": "system", "content": system})
-        response = self.client.chat.completions.create(
-            self.model_name,
-            messages=history,
-            **gen_conf
-        )
-        if response.status_code == HTTPStatus.OK:
-            return response.output.choices[0]['message']['content'], response.usage.completion_tokens
-        return "ERROR: " + response.message, 0
\ No newline at end of file
+        try:
+            response = self.client.chat.completions.create(
+                model=self.model_name,
+                messages=history,
+                **gen_conf
+            )
+            ans = response.choices[0].message.content.strip()
+            if response.choices[0].finish_reason == "length":
+                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
+                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
+            return ans, response.usage.completion_tokens
+        except Exception as e:
+            return "**ERROR**: " + str(e), 0
\ No newline at end of file
diff --git a/rag/nlp/search.py b/rag/nlp/search.py
index 361387daa5a282db330962636785e3dfee1c2266..bce7db41692541ae94da0a1fa274eedc473e1ef5 100644
--- a/rag/nlp/search.py
+++ b/rag/nlp/search.py
@@ -224,12 +224,13 @@ class Dealer:
                                                          chunks_tks,
                                                          tkweight, vtweight)
             mx = np.max(sim) * 0.99
-            if mx < 0.35:
+            if mx < 0.66:
                 continue
             cites[idx[i]] = list(
                 set([str(ii) for ii in range(len(chunk_v)) if sim[ii] > mx]))[:4]
 
         res = ""
+        seted = set([])
         for i, p in enumerate(pieces):
             res += p
             if i not in idx:
@@ -237,7 +238,10 @@ class Dealer:
                 continue
             if i not in cites:
                 continue
             for c in cites[i]:
                 assert int(c) < len(chunk_v)
-            for c in cites[i]: res += f" ##{c}$$"
+            for c in cites[i]:
+                if c in seted: continue
+                res += f" ##{c}$$"
+                seted.add(c)
         return res
@@ -318,7 +322,7 @@ class Dealer:
             if dnm not in ranks["doc_aggs"]:
                 ranks["doc_aggs"][dnm] = {"doc_id": did, "count": 0}
             ranks["doc_aggs"][dnm]["count"] += 1
-        ranks["doc_aggs"] = [{"doc_name": k, "doc_id": v["doc_id"], "count": v["count"]} for k, v in sorted(ranks["doc_aggs"].items(), key=lambda x: x[1]["count"] * -1)]
+        ranks["doc_aggs"] = []  # [{"doc_name": k, "doc_id": v["doc_id"], "count": v["count"]} for k, v in sorted(ranks["doc_aggs"].items(), key=lambda x: x[1]["count"] * -1)]
 
         return ranks
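
Note on api/apps/conversation_app.py: the patch now collects every user turn for retrieval and then appends the newest turn len(questions)//2 more times, biasing the fused query embedding toward the most recent question while older turns still contribute context. A minimal sketch of that weighting, assuming a hypothetical build_query helper (not part of the patch):

    def build_query(questions):
        weighted = list(questions)
        # Repeat the latest user turn so it dominates the averaged
        # query embedding; earlier turns remain as background context.
        for _ in range(len(questions) // 2):
            weighted.append(weighted[-1])
        return " ".join(weighted)

    assert build_query(["a", "b", "c", "d"]) == "a b c d d d"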
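Note on rag/nlp/search.py: insert_citations previously appended a ##N$$ marker every time chunk N was matched, so one chunk could be cited after several answer pieces; the new seted guard emits each chunk id at most once across the whole answer. A self-contained sketch of just that de-duplication step (pieces/cites mirror the patch's shapes; render is a hypothetical name):

    def render(pieces, cites):
        res, seen = "", set()
        for i, p in enumerate(pieces):
            res += p
            for c in cites.get(i, []):
                if c in seen:  # each chunk id is cited at most once
                    continue
                res += f" ##{c}$$"
                seen.add(c)
        return res

    # Prints "x. ##1$$ y. ##2$$" -- the repeat citation of chunk 1 is suppressed.
    print(render(["x.", " y."], {0: ["1"], 1: ["1", "2"]}))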