diff --git a/Dockerfile b/Dockerfile
index c174ccb800294ca4c19f61a41af52457f70c6152..49d3045cc1543199469e06b4283dd1d2b90c1148 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -14,6 +14,7 @@ ADD ./rag ./rag
 
 ENV PYTHONPATH=/ragflow/
 ENV HF_ENDPOINT=https://hf-mirror.com
+RUN /root/miniconda3/envs/py11/bin/pip install peewee==3.17.1
 
 ADD docker/entrypoint.sh ./entrypoint.sh
 RUN chmod +x ./entrypoint.sh
diff --git a/Dockerfile.cuda b/Dockerfile.cuda
new file mode 100644
index 0000000000000000000000000000000000000000..a5db6177dc3a0be4036e7ad421d1389dfff2e9ea
--- /dev/null
+++ b/Dockerfile.cuda
@@ -0,0 +1,26 @@
+FROM swr.cn-north-4.myhuaweicloud.com/infiniflow/ragflow-base:v1.0
+USER root
+
+WORKDIR /ragflow
+
+## for cuda > 12.0
+RUN /root/miniconda3/envs/py11/bin/pip uninstall -y onnxruntime-gpu
+RUN /root/miniconda3/envs/py11/bin/pip install onnxruntime-gpu --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/
+
+
+ADD ./web ./web
+RUN cd ./web && npm i && npm run build
+
+ADD ./api ./api
+ADD ./conf ./conf
+ADD ./deepdoc ./deepdoc
+ADD ./rag ./rag
+
+ENV PYTHONPATH=/ragflow/
+ENV HF_ENDPOINT=https://hf-mirror.com
+
+RUN /root/miniconda3/envs/py11/bin/pip install peewee==3.17.1
+ADD docker/entrypoint.sh ./entrypoint.sh
+RUN chmod +x ./entrypoint.sh
+
+ENTRYPOINT ["./entrypoint.sh"]
diff --git a/api/apps/conversation_app.py b/api/apps/conversation_app.py
index a0c2369979d9c1a3e9f0246b7cbd6baa5056cb1c..ef99a162c683a485d04f3bb83f7913eb7b376a84 100644
--- a/api/apps/conversation_app.py
+++ b/api/apps/conversation_app.py
@@ -21,7 +21,7 @@ from api.db.services.dialog_service import DialogService, ConversationService
 from api.db import LLMType
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.llm_service import LLMService, LLMBundle
-from api.settings import access_logger, stat_logger, retrievaler
+from api.settings import access_logger, stat_logger, retrievaler, chat_logger
 from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
 from api.utils import get_uuid
 from api.utils.api_utils import get_json_result
@@ -183,10 +183,10 @@ def chat(dialog, messages, **kwargs):
     field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
     ## try to use sql if field mapping is good to go
     if field_map:
-        stat_logger.info("Use SQL to retrieval.")
-        markdown_tbl, chunks = use_sql("\n".join(questions), field_map, dialog.tenant_id, chat_mdl)
+        chat_logger.info("Use SQL to retrieve: {}".format(questions[-1]))
+        markdown_tbl, chunks = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl)
         if markdown_tbl:
-            return {"answer": markdown_tbl, "retrieval": {"chunks": chunks}}
+            return {"answer": markdown_tbl, "reference": {"chunks": chunks, "doc_aggs": []}}
 
     prompt_config = dialog.prompt_config
     for p in prompt_config["parameters"]:
@@ -201,6 +201,7 @@ def chat(dialog, messages, **kwargs):
                                         dialog.similarity_threshold,
                                         dialog.vector_similarity_weight, top=1024, aggs=False)
     knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
+    chat_logger.info("{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
 
     if not knowledges and prompt_config.get("empty_response"):
         return {"answer": prompt_config["empty_response"], "reference": kbinfos}
@@ -212,7 +213,7 @@ def chat(dialog, messages, **kwargs):
     if "max_tokens" in gen_conf:
         gen_conf["max_tokens"] = min(gen_conf["max_tokens"], llm.max_tokens - used_token_count)
     answer = chat_mdl.chat(prompt_config["system"].format(**kwargs), msg,
                           gen_conf)
-    stat_logger.info("User: {}|Assistant: {}".format(msg[-1]["content"], answer))
+    chat_logger.info("User: {}|Assistant: {}".format(msg[-1]["content"], answer))
     if knowledges:
         answer, idx = retrievaler.insert_citations(answer,
@@ -237,47 +238,83 @@ def use_sql(question, field_map, tenant_id, chat_mdl):
 问题如下:
 {}
 
-请写出SQL,且只要SQL,不要有其他说明及文字。
+请写出SQL, 且只要SQL,不要有其他说明及文字。
 """.format(
         index_name(tenant_id),
         "\n".join([f"{k}: {v}" for k, v in field_map.items()]),
         question
     )
-    sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_promt}], {"temperature": 0.06})
-    stat_logger.info(f"“{question}” get SQL: {sql}")
-    sql = re.sub(r"[\r\n]+", " ", sql.lower())
-    sql = re.sub(r".*?select ", "select ", sql.lower())
-    sql = re.sub(r" +", " ", sql)
-    sql = re.sub(r"([;;]|```).*", "", sql)
-    if sql[:len("select ")] != "select ":
-        return None, None
-    if sql[:len("select *")] != "select *":
-        sql = "select doc_id,docnm_kwd," + sql[6:]
-    else:
-        flds = []
-        for k in field_map.keys():
-            if k in forbidden_select_fields4resume:continue
-            if len(flds) > 11:break
-            flds.append(k)
-        sql = "select doc_id,docnm_kwd," + ",".join(flds) + sql[8:]
-
-    stat_logger.info(f"“{question}” get SQL(refined): {sql}")
-    tbl = retrievaler.sql_retrieval(sql, format="json")
-    if not tbl or len(tbl["rows"]) == 0: return None, None
+    tried_times = 0
+    def get_table():
+        nonlocal sys_prompt, user_promt, question, tried_times
+        sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_promt}], {"temperature": 0.06})
+        print(user_promt, sql)
+        chat_logger.info(f"“{question}”==>{user_promt} get SQL: {sql}")
+        sql = re.sub(r"[\r\n]+", " ", sql.lower())
+        sql = re.sub(r".*select ", "select ", sql.lower())
+        sql = re.sub(r" +", " ", sql)
+        sql = re.sub(r"([;;]|```).*", "", sql)
+        if sql[:len("select ")] != "select ":
+            return None, None
+        if not re.search(r"((sum|avg|max|min)\(|group by )", sql.lower()):
+            if sql[:len("select *")] != "select *":
+                sql = "select doc_id,docnm_kwd," + sql[6:]
+            else:
+                flds = []
+                for k in field_map.keys():
+                    if k in forbidden_select_fields4resume:continue
+                    if len(flds) > 11:break
+                    flds.append(k)
+                sql = "select doc_id,docnm_kwd," + ",".join(flds) + sql[8:]
+
+        print(f"“{question}” get SQL(refined): {sql}")
+
+        chat_logger.info(f"“{question}” get SQL(refined): {sql}")
+        tried_times += 1
+        return retrievaler.sql_retrieval(sql, format="json"), sql
+
+    tbl, sql = get_table()
+    if tbl is not None and tbl.get("error") and tried_times <= 2:
+        user_promt = """
+        表名:{};
+        数据库表字段说明如下:
+        {}
+
+        问题如下:
+        {}
+
+        你上一次给出的错误SQL如下:
+        {}
+
+        后台报错如下:
+        {}
+
+        请纠正SQL中的错误再写一遍,且只要SQL,不要有其他说明及文字。
+        """.format(
+            index_name(tenant_id),
+            "\n".join([f"{k}: {v}" for k, v in field_map.items()]),
+            question, sql, tbl["error"]
+        )
+        tbl, sql = get_table()
+        chat_logger.info("TRY it again: {}".format(sql))
+
+    chat_logger.info("GET table: {}".format(tbl))
+    print(tbl)
+    if not tbl or tbl.get("error") or len(tbl["rows"]) == 0: return None, None
 
     docid_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "doc_id"])
     docnm_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "docnm_kwd"])
     clmn_idx = [ii for ii in range(len(tbl["columns"])) if ii not in (docid_idx | docnm_idx)]
 
     # compose markdown table
-    clmns = "|".join([re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"], f"C{i}")) for i in clmn_idx]) + "|原文"
-    line = "|".join(["------" for _ in range(len(clmn_idx))]) + "|------"
-    rows = ["|".join([rmSpace(str(r[i])) for i in clmn_idx]).replace("None", " ") + "|" for r in tbl["rows"]]
["|".join([rmSpace(str(r[i])) for i in clmn_idx]).replace("None", " ") + "|" for r in tbl["rows"]] + clmns = "|"+"|".join([re.sub(r"(/.*|ďĽ[^ďĽďĽ‰]+))", "", field_map.get(tbl["columns"][i]["name"], tbl["columns"][i]["name"])) for i in clmn_idx]) + ("|原文|" if docid_idx and docid_idx else "|") + line = "|"+"|".join(["------" for _ in range(len(clmn_idx))]) + ("|------|" if docid_idx and docid_idx else "") + rows = ["|"+"|".join([rmSpace(str(r[i])) for i in clmn_idx]).replace("None", " ") + "|" for r in tbl["rows"]] if not docid_idx or not docnm_idx: - access_logger.error("SQL missing field: " + sql) + chat_logger.warning("SQL missing field: " + sql) return "\n".join([clmns, line, "\n".join(rows)]), [] - rows = "\n".join([r + f"##{ii}$$" for ii, r in enumerate(rows)]) + rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)]) docid_idx = list(docid_idx)[0] docnm_idx = list(docnm_idx)[0] return "\n".join([clmns, line, rows]), [{"doc_id": r[docid_idx], "docnm_kwd": r[docnm_idx]} for r in tbl["rows"]] diff --git a/api/db/db_models.py b/api/db/db_models.py index b02ea2c860cebb5b19502f8aeb0c708229598e8a..020899198683efc8408a97b7a68c69dbe87d7298 100644 --- a/api/db/db_models.py +++ b/api/db/db_models.py @@ -502,7 +502,7 @@ class Document(DataBaseModel): token_num = IntegerField(default=0) chunk_num = IntegerField(default=0) progress = FloatField(default=0) - progress_msg = CharField(max_length=4096, null=True, help_text="process message", default="") + progress_msg = TextField(null=True, help_text="process message", default="") process_begin_at = DateTimeField(null=True) process_duation = FloatField(default=0) run = CharField(max_length=1, null=True, help_text="start to run processing or cancel.(1: run it; 2: cancel)", default="0") @@ -520,7 +520,7 @@ class Task(DataBaseModel): begin_at = DateTimeField(null=True) process_duation = FloatField(default=0) progress = FloatField(default=0) - progress_msg = TextField(max_length=4096, null=True, help_text="process message", default="") + progress_msg = TextField(null=True, help_text="process message", default="") class Dialog(DataBaseModel): diff --git a/api/db/init_data.py b/api/db/init_data.py index b3ef43cdac1cb7a0e39c98e20348fc144ca2244b..de201d3d0f6e46ddecc321278270c4cc86802fb4 100644 --- a/api/db/init_data.py +++ b/api/db/init_data.py @@ -90,6 +90,17 @@ def init_llm_factory(): "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION", "status": "1", }, + { + "name": "Local", + "logo": "", + "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION", + "status": "0", + },{ + "name": "Moonshot", + "logo": "", + "tags": "LLM,TEXT EMBEDDING", + "status": "1", + } # { # "name": "ć–‡ĺżä¸€č¨€", # "logo": "", @@ -155,6 +166,12 @@ def init_llm_factory(): "tags": "LLM,CHAT,32K", "max_tokens": 32768, "model_type": LLMType.CHAT.value + },{ + "fid": factory_infos[1]["name"], + "llm_name": "qwen-max-1201", + "tags": "LLM,CHAT,6K", + "max_tokens": 5899, + "model_type": LLMType.CHAT.value },{ "fid": factory_infos[1]["name"], "llm_name": "text-embedding-v2", @@ -201,6 +218,46 @@ def init_llm_factory(): "max_tokens": 512, "model_type": LLMType.EMBEDDING.value }, + # ---------------------- 本地 ---------------------- + { + "fid": factory_infos[3]["name"], + "llm_name": "qwen-14B-chat", + "tags": "LLM,CHAT,", + "max_tokens": 8191, + "model_type": LLMType.CHAT.value + }, { + "fid": factory_infos[3]["name"], + "llm_name": "flag-enbedding", + "tags": "TEXT EMBEDDING,", + "max_tokens": 128 * 1000, + "model_type": LLMType.EMBEDDING.value + }, + # ------------------------ 
+    # ------------------------ Moonshot -----------------------
+    {
+        "fid": factory_infos[4]["name"],
+        "llm_name": "moonshot-v1-8k",
+        "tags": "LLM,CHAT,",
+        "max_tokens": 7900,
+        "model_type": LLMType.CHAT.value
+    }, {
+        "fid": factory_infos[4]["name"],
+        "llm_name": "flag-enbedding",
+        "tags": "TEXT EMBEDDING,",
+        "max_tokens": 128 * 1000,
+        "model_type": LLMType.EMBEDDING.value
+    },{
+        "fid": factory_infos[4]["name"],
+        "llm_name": "moonshot-v1-32k",
+        "tags": "LLM,CHAT,",
+        "max_tokens": 32768,
+        "model_type": LLMType.CHAT.value
+    },{
+        "fid": factory_infos[4]["name"],
+        "llm_name": "moonshot-v1-128k",
+        "tags": "LLM,CHAT",
+        "max_tokens": 128 * 1000,
+        "model_type": LLMType.CHAT.value
+    },
     ]
     for info in factory_infos:
         LLMFactoriesService.save(**info)
diff --git a/api/settings.py b/api/settings.py
index 956e2f363cadff757508fd601c7ecf3f4d6eb7e4..ee3445cbbca2994abcb7339939b92c0d8aa01be0 100644
--- a/api/settings.py
+++ b/api/settings.py
@@ -29,6 +29,7 @@ LoggerFactory.LEVEL = 10
 stat_logger = getLogger("stat")
 access_logger = getLogger("access")
 database_logger = getLogger("database")
+chat_logger = getLogger("chat")
 
 API_VERSION = "v1"
 RAG_FLOW_SERVICE_NAME = "ragflow"
@@ -69,9 +70,15 @@
         "image2text_model": "glm-4v",
         "asr_model": "",
     },
-    "local": {
-        "chat_model": "",
-        "embedding_model": "",
+    "Local": {
+        "chat_model": "qwen-14B-chat",
+        "embedding_model": "flag-enbedding",
+        "image2text_model": "",
+        "asr_model": "",
+    },
+    "Moonshot": {
+        "chat_model": "moonshot-v1-8k",
+        "embedding_model": "flag-enbedding",
         "image2text_model": "",
         "asr_model": "",
     }
@@ -86,7 +93,7 @@
 EMBEDDING_MDL = default_llm[LLM_FACTORY]["embedding_model"]
 ASR_MDL = default_llm[LLM_FACTORY]["asr_model"]
 IMAGE2TEXT_MDL = default_llm[LLM_FACTORY]["image2text_model"]
 
-API_KEY = LLM.get("api_key", "infiniflow API Key")
+API_KEY = LLM.get("api_key", "")
 PARSERS = LLM.get("parsers", "naive:General,qa:Q&A,resume:Resume,table:Table,laws:Laws,manual:Manual,book:Book,paper:Paper,presentation:Presentation,picture:Picture")
 
 # distribution
diff --git a/deepdoc/parser/excel_parser.py b/deepdoc/parser/excel_parser.py
index d2054f1f4d6f38cf5de30a0bc69525dc93f213b1..79c45e805d6d350fdf3d68aecc2874c75cf7a3f5 100644
--- a/deepdoc/parser/excel_parser.py
+++ b/deepdoc/parser/excel_parser.py
@@ -34,7 +34,7 @@ class HuExcelParser:
         total = 0
         for sheetname in wb.sheetnames:
             ws = wb[sheetname]
-            total += len(ws.rows)
+            total += len(list(ws.rows))
         return total
 
         if fnm.split(".")[-1].lower() in ["csv", "txt"]:
diff --git a/deepdoc/parser/pdf_parser.py b/deepdoc/parser/pdf_parser.py
index 9ac5ecc0c8c824426cb695bb5bb10d8f8a7f1645..fdfe6a28ebef21fc7feb50f43d68d8124a7f215a 100644
--- a/deepdoc/parser/pdf_parser.py
+++ b/deepdoc/parser/pdf_parser.py
@@ -655,14 +655,14 @@ class HuParser:
             #if min(tv, fv) > 2000:
             #    i += 1
             #    continue
-            if tv < fv:
+            if tv < fv and tk:
                 tables[tk].insert(0, c)
                 logging.debug(
                     "TABLE:" +
                     self.boxes[i]["text"] +
                     "; Cap: " + tk)
-            else:
+            elif fk:
                 figures[fk].insert(0, c)
                 logging.debug(
                     "FIGURE:" +
diff --git a/deepdoc/parser/ppt_parser.py b/deepdoc/parser/ppt_parser.py
index 222899de175c684c36fc77d6486869e1a39e999c..899103a0d98e030638bde240a939e7cc4a29ac6b 100644
--- a/deepdoc/parser/ppt_parser.py
+++ b/deepdoc/parser/ppt_parser.py
@@ -31,7 +31,7 @@ class HuPptParser(object):
         if shape.shape_type == 6:
             texts = []
-            for p in shape.shapes:
+            for p in sorted(shape.shapes, key=lambda x: (x.top//10, x.left)):
                 t = self.__extract(p)
                 if t: texts.append(t)
             return "\n".join(texts)
 
@@ -46,7 +46,7 @@
             if i < from_page: continue
             if i >= to_page:break
             texts = []
-            for shape in slide.shapes:
+            for shape in sorted(slide.shapes, key=lambda x: (x.top//10, x.left)):
                 txt = self.__extract(shape)
                 if txt: texts.append(txt)
             txts.append("\n".join(texts))
diff --git a/deepdoc/vision/ocr.py b/deepdoc/vision/ocr.py
index 13043e4591941a8835fba92516fb11719129ab88..5025219de037d52dc9a7f36da1e187de85573f0e 100644
--- a/deepdoc/vision/ocr.py
+++ b/deepdoc/vision/ocr.py
@@ -64,10 +64,15 @@
         raise ValueError("not find model file path {}".format(
             model_file_path))
 
+    options = ort.SessionOptions()
+    options.enable_cpu_mem_arena = False
+    options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
+    options.intra_op_num_threads = 2
+    options.inter_op_num_threads = 2
     if ort.get_device() == "GPU":
-        sess = ort.InferenceSession(model_file_path, providers=['CUDAExecutionProvider'])
+        sess = ort.InferenceSession(model_file_path, options=options, providers=['CUDAExecutionProvider'])
     else:
-        sess = ort.InferenceSession(model_file_path, providers=['CPUExecutionProvider'])
+        sess = ort.InferenceSession(model_file_path, options=options, providers=['CPUExecutionProvider'])
     return sess, sess.get_inputs()[0]
@@ -325,7 +330,13 @@
             input_dict = {}
             input_dict[self.input_tensor.name] = norm_img_batch
-            outputs = self.predictor.run(None, input_dict)
+            for i in range(4):  # retry a few times on transient inference failures
+                try:
+                    outputs = self.predictor.run(None, input_dict)
+                    break
+                except Exception as e:
+                    if i >= 3: raise e
+                    time.sleep(5)
             preds = outputs[0]
             rec_result = self.postprocess_op(preds)
             for rno in range(len(rec_result)):
@@ -430,7 +441,13 @@
         img = img.copy()
         input_dict = {}
         input_dict[self.input_tensor.name] = img
-        outputs = self.predictor.run(None, input_dict)
+        for i in range(4):  # retry a few times on transient inference failures
+            try:
+                outputs = self.predictor.run(None, input_dict)
+                break
+            except Exception as e:
+                if i >= 3: raise e
+                time.sleep(5)
         post_result = self.postprocess_op({"maps": outputs[0]}, shape_list)
         dt_boxes = post_result[0]['points']
diff --git a/deepdoc/vision/recognizer.py b/deepdoc/vision/recognizer.py
index 2f9123b99214a2ae79b5e9c1f3d19e54dd3a6bd8..33adff27adf9eaf2907b1af4dd3e66c284cea842 100644
--- a/deepdoc/vision/recognizer.py
+++ b/deepdoc/vision/recognizer.py
@@ -42,7 +42,9 @@
             raise ValueError("not find model file path {}".format(
                 model_file_path))
         if ort.get_device() == "GPU":
-            self.ort_sess = ort.InferenceSession(model_file_path, providers=['CUDAExecutionProvider'])
+            options = ort.SessionOptions()
+            options.enable_cpu_mem_arena = False
+            self.ort_sess = ort.InferenceSession(model_file_path, options=options, providers=['CUDAExecutionProvider'])
         else:
             self.ort_sess = ort.InferenceSession(model_file_path, providers=['CPUExecutionProvider'])
         self.input_names = [node.name for node in self.ort_sess.get_inputs()]
diff --git a/rag/app/table.py b/rag/app/table.py
index 3a69fe66533245690c6d9f38846dfd4a7e69e0ff..3b6cff38ea57a924b40c7e2a7cea3c8d69b145d5 100644
--- a/rag/app/table.py
+++ b/rag/app/table.py
@@ -67,7 +67,7 @@ class Excel(ExcelParser):
 
 
 def trans_datatime(s):
     try:
-        return datetime_parse(s.strip()).strftime("%Y-%m-%dT%H:%M:%S")
+        return datetime_parse(s.strip()).strftime("%Y-%m-%d %H:%M:%S")
     except Exception as e:
         pass
 
@@ -80,6 +80,7 @@ def trans_bool(s):
 
 
 def column_data_type(arr):
+    arr = list(arr)
     uni = len(set([a for a in arr if a is not None]))
     counts = {"int": 0, "float": 0, "text": 0, "datetime": 0, "bool": 0}
     trans = {t: f for f, t in
@@ -130,7 +131,7 @@ def chunk(filename, binary=None, from_page=0, to_page=10000000000, lang="Chinese
     if re.search(r"\.xlsx?$", filename, re.IGNORECASE):
         callback(0.1, "Start to parse.")
         excel_parser = Excel()
-        dfs = excel_parser(filename, binary, callback)
+        dfs = excel_parser(filename, binary, from_page=from_page, to_page=to_page, callback=callback)
     elif re.search(r"\.(txt|csv)$", filename, re.IGNORECASE):
         callback(0.1, "Start to parse.")
         txt = ""
@@ -188,7 +189,7 @@ def chunk(filename, binary=None, from_page=0, to_page=10000000000, lang="Chinese
             df[clmns[j]] = cln
             if ty == "text":
                 txts.extend([str(c) for c in cln if c])
-        clmns_map = [(py_clmns[i] + fieds_map[clmn_tys[i]], clmns[i])
+        clmns_map = [(py_clmns[i] + fieds_map[clmn_tys[i]], clmns[i].replace("_", " "))
                      for i in range(len(clmns))]
 
         eng = lang.lower() == "english"#is_english(txts)
@@ -201,6 +202,8 @@ def chunk(filename, binary=None, from_page=0, to_page=10000000000, lang="Chinese
         for j in range(len(clmns)):
             if row[clmns[j]] is None:
                 continue
+            if not str(row[clmns[j]]):
+                continue
             fld = clmns_map[j][0]
             d[fld] = row[clmns[j]] if clmn_tys[j] != "text" else huqie.qie(
                 row[clmns[j]])
diff --git a/rag/llm/__init__.py b/rag/llm/__init__.py
index c2c27265f93d2bd1104ee4562cb482ce531e4e78..cc4e46269a8ec6018f9f7b452fe4d1819bb148d6 100644
--- a/rag/llm/__init__.py
+++ b/rag/llm/__init__.py
@@ -19,18 +19,20 @@ from .cv_model import *
 
 
 EmbeddingModel = {
-    "local": HuEmbedding,
+    "Local": HuEmbedding,
     "OpenAI": OpenAIEmbed,
     "通义千问": HuEmbedding, #QWenEmbed,
-    "智谱AI": ZhipuEmbed
+    "智谱AI": ZhipuEmbed,
+    "Moonshot": HuEmbedding
 }
 
 
 CvModel = {
     "OpenAI": GptV4,
-    "local": LocalCV,
+    "Local": LocalCV,
     "通义千问": QWenCV,
-    "智谱AI": Zhipu4V
+    "智谱AI": Zhipu4V,
+    "Moonshot": LocalCV
 }
 
 
@@ -38,6 +40,7 @@
 ChatModel = {
     "OpenAI": GptTurbo,
     "智谱AI": ZhipuChat,
     "通义千问": QWenChat,
-    "local": LocalLLM
+    "Local": LocalLLM,
+    "Moonshot": MoonshotChat
 }
diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py
index 88a682fb9fb9c06070a7ae612aba24987c21687a..57d6480c51cd2bb4ed157a8ca1358609be92acb8 100644
--- a/rag/llm/chat_model.py
+++ b/rag/llm/chat_model.py
@@ -14,11 +14,8 @@
 #  limitations under the License.
 #
 from abc import ABC
-from copy import deepcopy
-
 from openai import OpenAI
 import openai
-
 from rag.nlp import is_english
 from rag.utils import num_tokens_from_string
 
@@ -52,6 +49,12 @@
         return "**ERROR**: "+str(e), 0
 
 
+class MoonshotChat(GptTurbo):
+    def __init__(self, key, model_name="moonshot-v1-8k"):
+        self.client = OpenAI(api_key=key, base_url="https://api.moonshot.cn/v1")
+        self.model_name = model_name
+
+
 from dashscope import Generation
 class QWenChat(Base):
     def __init__(self, key, model_name=Generation.Models.qwen_turbo):
diff --git a/rag/llm/rpc_server.py b/rag/llm/rpc_server.py
index 7cfd0e846ad074699ca2486d1c1d976aeed700e0..e1e8b829291815dbeaf1cf47298ed893208052e6 100644
--- a/rag/llm/rpc_server.py
+++ b/rag/llm/rpc_server.py
@@ -4,7 +4,7 @@ import random
 import time
 from multiprocessing.connection import Listener
 from threading import Thread
-import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
 
 
 class RPCHandler:
@@ -47,14 +47,27 @@ tokenizer = None
 def chat(messages, gen_conf):
     global tokenizer
     model = Model()
-    roles = {"system":"System", "user": "User", "assistant": "Assistant"}
-    line = ["{}: {}".format(roles[m["role"].lower()], m["content"]) for m in messages]
-    line = "\n".join(line) + "\nAssistant: "
-    tokens = tokenizer([line], return_tensors='pt')
-    tokens = {k: tokens[k].to(model.device) if isinstance(tokens[k], torch.Tensor) else tokens[k] for k in
-              tokens.keys()}
-    res = [tokenizer.decode(t) for t in model.generate(**tokens, **gen_conf)][0]
-    return res.split("Assistant: ")[-1]
+    try:
+        conf = {"max_new_tokens": int(gen_conf.get("max_tokens", 256)), "temperature": float(gen_conf.get("temperature", 0.1))}
+        print(messages, conf)
+        text = tokenizer.apply_chat_template(
+            messages,
+            tokenize=False,
+            add_generation_prompt=True
+        )
+        model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
+
+        generated_ids = model.generate(
+            model_inputs.input_ids,
+            **conf
+        )
+        generated_ids = [
+            output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
+        ]
+
+        return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
+    except Exception as e:
+        return str(e)
 
 
 def Model():
@@ -71,20 +84,13 @@
 if __name__ == "__main__":
     handler = RPCHandler()
     handler.register_function(chat)
-    from transformers import AutoModelForCausalLM, AutoTokenizer
-    from transformers.generation.utils import GenerationConfig
-
     models = []
-    for _ in range(2):
+    for _ in range(1):
         m = AutoModelForCausalLM.from_pretrained(args.model_name,
                                                  device_map="auto",
-                                                 torch_dtype='auto',
-                                                 trust_remote_code=True)
-        m.generation_config = GenerationConfig.from_pretrained(args.model_name)
-        m.generation_config.pad_token_id = m.generation_config.eos_token_id
+                                                 torch_dtype='auto')
         models.append(m)
-    tokenizer = AutoTokenizer.from_pretrained(args.model_name, use_fast=False,
-                                              trust_remote_code=True)
+    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
 
     # Run the server
     rpc_server(handler, ('0.0.0.0', args.port), authkey=b'infiniflow-token4kevinhu')
diff --git a/rag/nlp/search.py b/rag/nlp/search.py
index 3603970125a6a318a6d56fd91101c33bc9690541..422d23f110ac638ed7952fe18fc5e9815260165a 100644
--- a/rag/nlp/search.py
+++ b/rag/nlp/search.py
@@ -7,6 +7,7 @@ from elasticsearch_dsl import Q, Search
 from typing import List, Optional, Dict, Union
 from dataclasses import dataclass
 
+from api.settings import chat_logger
 from rag.settings import es_logger
 from rag.utils import rmSpace
 from rag.nlp import huqie, query
@@ -333,15 +334,16 @@ class Dealer:
         replaces = []
         for r in re.finditer(r" ([a-z_]+_l?tks)( like | ?= ?)'([^']+)'", sql):
             fld, v = r.group(1), r.group(3)
-            match = " MATCH({}, '{}', 'operator=OR;fuzziness=AUTO:1,3;minimum_should_match=30%') ".format(fld, huqie.qieqie(huqie.qie(v)))
+            match = " MATCH({}, '{}', 'operator=OR;minimum_should_match=30%') ".format(fld, huqie.qieqie(huqie.qie(v)))
             replaces.append(("{}{}'{}'".format(r.group(1), r.group(2), r.group(3)), match))
 
         for p, r in replaces:
             sql = sql.replace(p, r, 1)
-        es_logger.info(f"To es: {sql}")
+        chat_logger.info(f"To es: {sql}")
 
         try:
             tbl = self.es.sql(sql, fetch_size, format)
             return tbl
         except Exception as e:
-            es_logger.error(f"SQL failure: {sql} =>" + str(e))
+            chat_logger.error(f"SQL failure: {sql} =>" + str(e))
+            return {"error": str(e)}
diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py
index 4f0c28a9483c232f8d4e1f417912774d26a9c028..f8438e18ed3a99a388c30ff48f37cda8aaa29768 100644
--- a/rag/svr/task_executor.py
+++ b/rag/svr/task_executor.py
@@ -169,16 +169,25 @@ def init_kb(row):
 def embedding(docs, mdl, parser_config={}, callback=None):
+    batch_size = 32
     tts, cnts = [rmSpace(d["title_tks"]) for d in docs if d.get("title_tks")], [
         d["content_with_weight"] for d in docs]
     tk_count = 0
     if len(tts) == len(cnts):
-        tts, c = mdl.encode(tts)
-        tk_count += c
+        tts_ = np.array([])
+        for i in range(0, len(tts), batch_size):
+            vts, c = mdl.encode(tts[i: i + batch_size])
+            if len(tts_) == 0:
+                tts_ = vts
+            else:
+                tts_ = np.concatenate((tts_, vts), axis=0)
+            tk_count += c
+            callback(prog=0.6 + 0.1 * (i + 1) / len(tts), msg="")
+        tts = tts_
 
     cnts_ = np.array([])
-    for i in range(0, len(cnts), 8):
-        vts, c = mdl.encode(cnts[i: i+8])
+    for i in range(0, len(cnts), batch_size):
+        vts, c = mdl.encode(cnts[i: i+batch_size])
         if len(cnts_) == 0:
             cnts_ = vts
         else:
             cnts_ = np.concatenate((cnts_, vts), axis=0)
         tk_count += c
diff --git a/rag/utils/es_conn.py b/rag/utils/es_conn.py
index 515574c835e6240e094a3ae2c48da97025e86a15..2dd02bd68d0282e9adf9a1cd67603307a6afb40a 100644
--- a/rag/utils/es_conn.py
+++ b/rag/utils/es_conn.py
@@ -249,6 +249,8 @@ class HuEs:
             except ConnectionTimeout as e:
                 es_logger.error("Timeout【Q】:" + sql)
                 continue
+            except Exception as e:
+                raise e
 
         es_logger.error("ES search timeout for 3 times!")
         raise ConnectionTimeout()
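
The `use_sql` rewrite in api/apps/conversation_app.py is a retry-with-error-feedback loop: the model's SQL is executed via `retrievaler.sql_retrieval`, and when the result carries an `error`, the failing SQL plus the backend error message are folded into a second prompt for one corrective attempt. A minimal standalone sketch of that pattern follows; `ask_llm` and `run_sql` are hypothetical stand-ins for `chat_mdl.chat` and `retrievaler.sql_retrieval`, not RAGFlow APIs.

import re

def sql_with_retry(question, schema, ask_llm, run_sql, max_tries=2):
    # ask_llm(prompt) -> str and run_sql(sql) -> {"rows": [...]} or
    # {"error": msg} are assumed interfaces for this sketch.
    prompt = f"Schema:\n{schema}\n\nQuestion:\n{question}\n\nReturn SQL only."
    for _ in range(max_tries):
        sql = ask_llm(prompt)
        # Same normalization as the patch: collapse newlines, drop any
        # preamble before SELECT, and cut trailing fences or semicolons.
        sql = re.sub(r"[\r\n]+", " ", sql.lower())
        sql = re.sub(r".*select ", "select ", sql)
        sql = re.sub(r"([;;]|```).*", "", sql)
        if not sql.startswith("select "):
            return None
        result = run_sql(sql)
        if not result.get("error"):
            return result
        # Feed the failing SQL and the backend error back for one correction.
        prompt = (f"Schema:\n{schema}\n\nQuestion:\n{question}\n\n"
                  f"Your previous SQL failed:\n{sql}\n\n"
                  f"Error:\n{result['error']}\n\nFix it and return SQL only.")
    return None

The patch bounds this to a single correction round (`tried_times <= 2`), which keeps worst-case latency at two model calls per question while still recovering from the most common failure mode, a field name the model guessed wrong.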