From f1f09df901c2729419489b68ca93b493426080e7 Mon Sep 17 00:00:00 2001
From: KevinHuSh <kevinhu.sh@gmail.com>
Date: Tue, 12 Mar 2024 11:57:08 +0800
Subject: [PATCH] add local llm implementation (#119)

---
 Dockerfile                               |  2 +-
 README.md                                |  6 +-
 api/apps/__init__.py                     |  2 +-
 api/apps/llm_app.py                      |  2 +-
 api/db/db_models.py                      |  2 +-
 api/db/services/knowledgebase_service.py |  1 +
 api/settings.py                          | 10 ++-
 deepdoc/parser/excel_parser.py           | 16 ++++-
 docker/nginx/nginx.conf                  |  2 +-
 rag/app/table.py                         | 17 +++--
 rag/llm/__init__.py                      |  9 ++-
 rag/llm/chat_model.py                    | 42 ++++++++++-
 rag/llm/cv_model.py                      |  8 +++
 rag/llm/rpc_server.py                    | 90 ++++++++++++++++++++++++
 rag/settings.py                          |  2 +-
 rag/svr/task_broker.py                   |  8 +++
 rag/svr/task_executor.py                 |  2 +-
 17 files changed, 196 insertions(+), 25 deletions(-)
 create mode 100644 rag/llm/rpc_server.py

diff --git a/Dockerfile b/Dockerfile
index b49434f..c174ccb 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,4 +1,4 @@
-FROM infiniflow/ragflow-base:v1.0
+FROM swr.cn-north-4.myhuaweicloud.com/infiniflow/ragflow-base:v1.0
 USER root

 WORKDIR /ragflow

diff --git a/README.md b/README.md
index 6793e65..b59a4ad 100644
--- a/README.md
+++ b/README.md
@@ -21,7 +21,7 @@
 </a>
 </p>

-[RAGFLOW](http://ragflow.io) is a knowledge management platform built on custom-build document understanding engine and LLM,
+[RagFlow](http://ragflow.io) is a knowledge management platform built on a custom-built document understanding engine and LLMs,
 providing reasoned and well-founded answers to your questions. By cloning this repository, you can deploy your own knowledge management
 platform to empower your business with AI.

@@ -29,12 +29,12 @@ platform to empower your business with AI.
 <img src="https://github.com/infiniflow/ragflow/assets/12318111/b24a7a5f-4d1d-4a30-90b1-7b0ec558b79d" width="1000"/>
 </div>

-# Features
+# Key Features
 - **Custom-built document understanding engine.** Our deep learning engine is built for analyzing and searching various types of documents across different domains.
   - For documents from different domains and for different purposes, the engine applies different analysis and search strategies.
   - Easily intervene in and manipulate the data processing procedure when things go beyond expectations.
   - Multi-media document understanding is supported using OCR and multi-modal LLM.
-- **State-of-the-art table structure and layout recognition.** Precisely extract and understand the document including table content. [README](./deepdoc/README.md)
+- **State-of-the-art table structure and layout recognition.** Precisely extract and understand documents, including table content. See [README.](./deepdoc/README.md)
   - For PDF files, layout and table structures, including rows, columns and their spans, are recognized.
   - Tables spanning multiple pages are put together.
   - Table structure components are reconstructed into an HTML table.
diff --git a/api/apps/__init__.py b/api/apps/__init__.py
index a53663b..02f629a 100644
--- a/api/apps/__init__.py
+++ b/api/apps/__init__.py
@@ -52,7 +52,7 @@ app.errorhandler(Exception)(server_error_response)
 #app.config["LOGIN_DISABLED"] = True
 app.config["SESSION_PERMANENT"] = False
 app.config["SESSION_TYPE"] = "filesystem"
-app.config['MAX_CONTENT_LENGTH'] = 64 * 1024 * 1024
+app.config['MAX_CONTENT_LENGTH'] = 128 * 1024 * 1024

 Session(app)
 login_manager = LoginManager()

diff --git a/api/apps/llm_app.py b/api/apps/llm_app.py
index e0e213f..e89c855 100644
--- a/api/apps/llm_app.py
+++ b/api/apps/llm_app.py
@@ -85,7 +85,7 @@ def my_llms():
             }
         res[o["llm_factory"]]["llm"].append({
             "type": o["model_type"],
-            "name": o["model_name"],
+            "name": o["llm_name"],
             "used_token": o["used_tokens"]
         })
     return get_json_result(data=res)

diff --git a/api/db/db_models.py b/api/db/db_models.py
index f28d37b..b02ea2c 100644
--- a/api/db/db_models.py
+++ b/api/db/db_models.py
@@ -520,7 +520,7 @@ class Task(DataBaseModel):
     begin_at = DateTimeField(null=True)
     process_duation = FloatField(default=0)
     progress = FloatField(default=0)
-    progress_msg = CharField(max_length=4096, null=True, help_text="process message", default="")
+    progress_msg = TextField(max_length=4096, null=True, help_text="process message", default="")


 class Dialog(DataBaseModel):

diff --git a/api/db/services/knowledgebase_service.py b/api/db/services/knowledgebase_service.py
index 236b8d0..46fe4bc 100644
--- a/api/db/services/knowledgebase_service.py
+++ b/api/db/services/knowledgebase_service.py
@@ -47,6 +47,7 @@ class KnowledgebaseService(CommonService):
             Tenant.embd_id,
             cls.model.avatar,
             cls.model.name,
+            cls.model.language,
             cls.model.description,
             cls.model.permission,
             cls.model.doc_num,

diff --git a/api/settings.py b/api/settings.py
index 3200763..956e2f3 100644
--- a/api/settings.py
+++ b/api/settings.py
@@ -42,7 +42,7 @@ ERROR_REPORT = True
 ERROR_REPORT_WITH_PATH = False

 MAX_TIMESTAMP_INTERVAL = 60
-SESSION_VALID_PERIOD = 7 * 24 * 60 * 60 * 1000
+SESSION_VALID_PERIOD = 7 * 24 * 60 * 60

 REQUEST_TRY_TIMES = 3
 REQUEST_WAIT_SEC = 2
@@ -69,6 +69,12 @@
         "image2text_model": "glm-4v",
         "asr_model": "",
     },
+    "local": {
+        "chat_model": "",
+        "embedding_model": "",
+        "image2text_model": "",
+        "asr_model": "",
+    }
 }
 LLM = get_base_config("user_default_llm", {})
 LLM_FACTORY = LLM.get("factory", "通义千问")
@@ -134,7 +140,7 @@
 USE_AUTHENTICATION = False
 USE_DATA_AUTHENTICATION = False
 AUTOMATIC_AUTHORIZATION_OUTPUT_DATA = True
 USE_DEFAULT_TIMEOUT = False
-AUTHENTICATION_DEFAULT_TIMEOUT = 30 * 24 * 60 * 60  # s
+AUTHENTICATION_DEFAULT_TIMEOUT = 7 * 24 * 60 * 60  # s
 PRIVILEGE_COMMAND_WHITELIST = []
 CHECK_NODES_IDENTITY = False

diff --git a/deepdoc/parser/excel_parser.py b/deepdoc/parser/excel_parser.py
index 10f3b28..d2054f1 100644
--- a/deepdoc/parser/excel_parser.py
+++ b/deepdoc/parser/excel_parser.py
@@ -20,13 +20,27 @@ class HuExcelParser:
             for i, c in enumerate(r):
                 if not c.value: continue
                 t = str(ti[i].value) if i < len(ti) else ""
-                t += (":" if t else "") + str(c.value)
+                t += (":" if t else "") + str(c.value)
                 l.append(t)
             l = "; ".join(l)
             if sheetname.lower().find("sheet") < 0:
                 l += " ——" + sheetname
             res.append(l)
         return res

+    @staticmethod
+    def row_number(fnm, binary):
+        if fnm.split(".")[-1].lower().find("xls") >= 0:
+            wb = load_workbook(BytesIO(binary))
+            total = 0
+            for sheetname in wb.sheetnames:
+                ws = wb[sheetname]
+                total += len(list(ws.rows))
+            return total
+
+        if fnm.split(".")[-1].lower() in ["csv", "txt"]:
+            txt = binary.decode("utf-8")
+            return len(txt.split("\n"))
+
 if __name__ == "__main__":
     psr = HuExcelParser()
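For reference, `row_number` feeds the task broker later in this patch: it lets the broker size paging tasks without fully parsing the sheet. A minimal usage sketch follows; the file name and contents are hypothetical, and openpyxl must be installed.

from deepdoc.parser.excel_parser import HuExcelParser

# Hypothetical workbook; any .xlsx/.xls (or TAB-delimited .csv/.txt) works.
with open("sales.xlsx", "rb") as f:
    binary = f.read()

# For spreadsheets the per-sheet row counts are summed; for csv/txt it is
# simply the number of newline-separated lines. Other extensions return None.
print(HuExcelParser.row_number("sales.xlsx", binary))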
diff --git a/docker/nginx/nginx.conf b/docker/nginx/nginx.conf
index 05c5029..8933a8d 100644
--- a/docker/nginx/nginx.conf
+++ b/docker/nginx/nginx.conf
@@ -26,7 +26,7 @@ http {
     keepalive_timeout  65;

     #gzip  on;
-    client_max_body_size 82M;
+    client_max_body_size 128M;

     include /etc/nginx/conf.d/ragflow.conf;
 }

diff --git a/rag/app/table.py b/rag/app/table.py
index 68b2d33..4cf1c1c 100644
--- a/rag/app/table.py
+++ b/rag/app/table.py
@@ -25,7 +25,7 @@ from deepdoc.parser import ExcelParser


 class Excel(ExcelParser):
-    def __call__(self, fnm, binary=None, callback=None):
+    def __call__(self, fnm, binary=None, from_page=0, to_page=10000000000, callback=None):
         if not binary:
             wb = load_workbook(fnm)
         else:
@@ -35,6 +35,7 @@ class Excel(ExcelParser):
             total += len(list(wb[sheetname].rows))

         res, fails, done = [], [], 0
+        rn = 0
         for sheetname in wb.sheetnames:
             ws = wb[sheetname]
             rows = list(ws.rows)
@@ -46,6 +47,9 @@
                 rows[0]) if i not in missed]
             data = []
             for i, r in enumerate(rows[1:]):
+                rn += 1
+                if rn - 1 < from_page: continue
+                if rn - 1 >= to_page: break
                 row = [
                     cell.value for ii, cell in enumerate(r) if ii not in missed]
@@ -111,7 +115,7 @@ def column_data_type(arr):
     return arr, ty


-def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
+def chunk(filename, binary=None, from_page=0, to_page=10000000000, lang="Chinese", callback=None, **kwargs):
     """
         Excel and csv(txt) format files are supported.
         For csv or txt file, the delimiter between columns is TAB.
@@ -147,16 +151,15 @@
         headers = lines[0].split(kwargs.get("delimiter", "\t"))
         rows = []
         for i, line in enumerate(lines[1:]):
+            if i < from_page: continue
+            if i >= to_page: break
             row = [l for l in line.split(kwargs.get("delimiter", "\t"))]
             if len(row) != len(headers):
                 fails.append(str(i))
                 continue
             rows.append(row)
-            if len(rows) % 999 == 0:
-                callback(len(rows) * 0.6 / len(lines), ("Extract records: {}".format(len(rows)) + (
-                    f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))

-        callback(0.6, ("Extract records: {}".format(len(rows)) + (
+        callback(0.3, ("Extract records: {}~{}".format(from_page, min(len(lines), to_page)) + (
             f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))

         dfs = [pd.DataFrame(np.array(rows), columns=headers)]
@@ -209,7 +212,7 @@
         KnowledgebaseService.update_parser_config(
             kwargs["kb_id"], {"field_map": {k: v for k, v in clmns_map}})

-    callback(0.6, "")
+    callback(0.35, "")

     return res
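The new `from_page`/`to_page` window is half-open over data rows (the header row is excluded), so a task covering rows [from_page, to_page) touches exactly `to_page - from_page` rows. A standalone sketch of the same filter, using hypothetical data rather than a real file:

# Mirrors the paging filter chunk() applies to TAB-delimited lines.
lines = ["col_a\tcol_b"] + [f"a{i}\tb{i}" for i in range(5000)]  # header + 5000 data rows
from_page, to_page = 1000, 2000   # this task owns data rows 1000..1999
rows = []
for i, line in enumerate(lines[1:]):
    if i < from_page:
        continue
    if i >= to_page:
        break
    rows.append(line.split("\t"))
assert len(rows) == to_page - from_page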
% (",".join(fails[:3])) if fails else ""))) dfs = [pd.DataFrame(np.array(rows), columns=headers)] @@ -209,7 +212,7 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs): KnowledgebaseService.update_parser_config( kwargs["kb_id"], {"field_map": {k: v for k, v in clmns_map}}) - callback(0.6, "") + callback(0.35, "") return res diff --git a/rag/llm/__init__.py b/rag/llm/__init__.py index 462313d..c2c2726 100644 --- a/rag/llm/__init__.py +++ b/rag/llm/__init__.py @@ -19,22 +19,25 @@ from .cv_model import * EmbeddingModel = { - "Infiniflow": HuEmbedding, + "local": HuEmbedding, "OpenAI": OpenAIEmbed, "通义ĺŤé—®": HuEmbedding, #QWenEmbed, + "智谱AI": ZhipuEmbed } CvModel = { "OpenAI": GptV4, - "Infiniflow": GptV4, + "local": LocalCV, "通义ĺŤé—®": QWenCV, + "智谱AI": Zhipu4V } ChatModel = { "OpenAI": GptTurbo, - "Infiniflow": GptTurbo, + "智谱AI": ZhipuChat, "通义ĺŤé—®": QWenChat, + "local": LocalLLM } diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py index 2389561..88a682f 100644 --- a/rag/llm/chat_model.py +++ b/rag/llm/chat_model.py @@ -20,6 +20,7 @@ from openai import OpenAI import openai from rag.nlp import is_english +from rag.utils import num_tokens_from_string class Base(ABC): @@ -86,7 +87,6 @@ class ZhipuChat(Base): self.model_name = model_name def chat(self, system, history, gen_conf): - from http import HTTPStatus if system: history.insert(0, {"role": "system", "content": system}) try: response = self.client.chat.completions.create( @@ -100,4 +100,42 @@ class ZhipuChat(Base): [ans]) else "······\nç”±äşŽé•żĺş¦çš„ĺŽźĺ› ďĽŚĺ›žç”被ćŞć–了,č¦ç»§ç»ĺ—?" return ans, response.usage.completion_tokens except Exception as e: - return "**ERROR**: " + str(e), 0 \ No newline at end of file + return "**ERROR**: " + str(e), 0 + +class LocalLLM(Base): + class RPCProxy: + def __init__(self, host, port): + self.host = host + self.port = int(port) + self.__conn() + + def __conn(self): + from multiprocessing.connection import Client + self._connection = Client((self.host, self.port), authkey=b'infiniflow-token4kevinhu') + + def __getattr__(self, name): + import pickle + def do_rpc(*args, **kwargs): + for _ in range(3): + try: + self._connection.send(pickle.dumps((name, args, kwargs))) + return pickle.loads(self._connection.recv()) + except Exception as e: + self.__conn() + raise Exception("RPC connection lost!") + + return do_rpc + + def __init__(self, key, model_name="glm-3-turbo"): + self.client = LocalLLM.RPCProxy("127.0.0.1", 7860) + + def chat(self, system, history, gen_conf): + if system: history.insert(0, {"role": "system", "content": system}) + try: + ans = self.client.chat( + history, + gen_conf + ) + return ans, num_tokens_from_string(ans) + except Exception as e: + return "**ERROR**: " + str(e), 0 diff --git a/rag/llm/cv_model.py b/rag/llm/cv_model.py index 02298c6..6e139c7 100644 --- a/rag/llm/cv_model.py +++ b/rag/llm/cv_model.py @@ -138,3 +138,11 @@ class Zhipu4V(Base): max_tokens=max_tokens, ) return res.choices[0].message.content.strip(), res.usage.total_tokens + + +class LocalCV(Base): + def __init__(self, key, model_name="glm-4v", lang="Chinese"): + pass + + def describe(self, image, max_tokens=1024): + return "", 0 diff --git a/rag/llm/rpc_server.py b/rag/llm/rpc_server.py new file mode 100644 index 0000000..7cfd0e8 --- /dev/null +++ b/rag/llm/rpc_server.py @@ -0,0 +1,90 @@ +import argparse +import pickle +import random +import time +from multiprocessing.connection import Listener +from threading import Thread +import torch + + +class RPCHandler: + def __init__(self): + 
diff --git a/rag/llm/rpc_server.py b/rag/llm/rpc_server.py
new file mode 100644
index 0000000..7cfd0e8
--- /dev/null
+++ b/rag/llm/rpc_server.py
@@ -0,0 +1,90 @@
+import argparse
+import pickle
+import random
+import time
+from multiprocessing.connection import Listener
+from threading import Thread
+
+import torch
+
+
+class RPCHandler:
+    def __init__(self):
+        self._functions = {}
+
+    def register_function(self, func):
+        self._functions[func.__name__] = func
+
+    def handle_connection(self, connection):
+        try:
+            while True:
+                # Receive a message
+                func_name, args, kwargs = pickle.loads(connection.recv())
+                # Run the RPC and send a response
+                try:
+                    r = self._functions[func_name](*args, **kwargs)
+                    connection.send(pickle.dumps(r))
+                except Exception as e:
+                    connection.send(pickle.dumps(e))
+        except EOFError:
+            pass
+
+
+def rpc_server(hdlr, address, authkey):
+    sock = Listener(address, authkey=authkey)
+    while True:
+        try:
+            client = sock.accept()
+            t = Thread(target=hdlr.handle_connection, args=(client,))
+            t.daemon = True
+            t.start()
+        except Exception as e:
+            print("【EXCEPTION】:", str(e))
+
+
+models = []
+tokenizer = None
+
+
+def chat(messages, gen_conf):
+    global tokenizer
+    model = Model()
+    roles = {"system": "System", "user": "User", "assistant": "Assistant"}
+    line = ["{}: {}".format(roles[m["role"].lower()], m["content"]) for m in messages]
+    line = "\n".join(line) + "\nAssistant: "
+    tokens = tokenizer([line], return_tensors='pt')
+    tokens = {k: tokens[k].to(model.device) if isinstance(tokens[k], torch.Tensor) else tokens[k]
+              for k in tokens.keys()}
+    res = [tokenizer.decode(t) for t in model.generate(**tokens, **gen_conf)][0]
+    return res.split("Assistant: ")[-1]
+
+
+def Model():
+    global models
+    random.seed(time.time())
+    return random.choice(models)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model_name", type=str, help="Model name")
+    parser.add_argument("--port", default=7860, type=int, help="RPC serving port")
+    args = parser.parse_args()
+
+    handler = RPCHandler()
+    handler.register_function(chat)
+
+    from transformers import AutoModelForCausalLM, AutoTokenizer
+    from transformers.generation.utils import GenerationConfig
+
+    models = []
+    for _ in range(2):
+        m = AutoModelForCausalLM.from_pretrained(args.model_name,
+                                                 device_map="auto",
+                                                 torch_dtype='auto',
+                                                 trust_remote_code=True)
+        m.generation_config = GenerationConfig.from_pretrained(args.model_name)
+        m.generation_config.pad_token_id = m.generation_config.eos_token_id
+        models.append(m)
+    tokenizer = AutoTokenizer.from_pretrained(args.model_name, use_fast=False,
+                                              trust_remote_code=True)
+
+    # Run the server
+    rpc_server(handler, ('0.0.0.0', args.port),
+               authkey=b'infiniflow-token4kevinhu')

diff --git a/rag/settings.py b/rag/settings.py
index 2906ea3..7c2257a 100644
--- a/rag/settings.py
+++ b/rag/settings.py
@@ -25,7 +25,7 @@ SUBPROCESS_STD_LOG_NAME = "std.log"
 ES = get_base_config("es", {})
 MINIO = decrypt_database_config(name="minio")
-DOC_MAXIMUM_SIZE = 64 * 1024 * 1024
+DOC_MAXIMUM_SIZE = 128 * 1024 * 1024

 # Logger
 LoggerFactory.set_directory(os.path.join(get_project_base_directory(), "logs", "rag"))

diff --git a/rag/svr/task_broker.py b/rag/svr/task_broker.py
index 38a2b44..e79ac38 100644
--- a/rag/svr/task_broker.py
+++ b/rag/svr/task_broker.py
@@ -22,6 +22,7 @@ from api.db.db_models import Task
 from api.db.db_utils import bulk_insert_into_db
 from api.db.services.task_service import TaskService
 from deepdoc.parser import PdfParser
+from deepdoc.parser.excel_parser import HuExcelParser
 from rag.settings import cron_logger
 from rag.utils import MINIO
 from rag.utils import findMaxTm
@@ -88,6 +89,13 @@ def dispatch():
                 task["from_page"] = p
                 task["to_page"] = min(p + 5, e)
                 tsks.append(task)
+        elif r["parser_id"] == "table":
+            rn = HuExcelParser.row_number(r["name"], MINIO.get(r["kb_id"], r["location"]))
+            for i in range(0, rn, 1000):
+                task = new_task()
+                task["from_page"] = i
+                task["to_page"] = min(i + 1000, rn)
+                tsks.append(task)
         else:
             tsks.append(new_task())
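Combined with `row_number`, the dispatcher turns one large table into fixed-size windows of 1000 rows, clamping the last window to the actual row count. Worked arithmetic for a hypothetical 2,500-row sheet:

rn = 2500  # hypothetical row count returned by HuExcelParser.row_number
pages = [(i, min(i + 1000, rn)) for i in range(0, rn, 1000)]
assert pages == [(0, 1000), (1000, 2000), (2000, 2500)]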
diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py
index 9ba17ef..9765957 100644
--- a/rag/svr/task_executor.py
+++ b/rag/svr/task_executor.py
@@ -184,7 +184,7 @@ def embedding(docs, mdl, parser_config={}, callback=None):
         if len(cnts_) == 0: cnts_ = vts
         else: cnts_ = np.concatenate((cnts_, vts), axis=0)
         tk_count += c
-        callback(msg="")
+        callback(prog=0.7 + 0.2 * (i + 1) / len(cnts), msg="")
     cnts = cnts_

     title_w = float(parser_config.get("filename_embd_weight", 0.1))
--
GitLab
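To use the new `local` factory end to end, the RPC server has to be launched separately, e.g. `python rag/llm/rpc_server.py --model_name <hf-model-id> --port 7860`, where the model id is whatever local Hugging Face checkpoint you want to serve (a placeholder here, not something this patch pins down) and 7860 is both the default port and the one `LocalLLM` dials. The server loads two copies of the model with `device_map="auto"` and answers each request from a randomly chosen copy as a crude load balancer; the `authkey` is hard-coded identically on both sides, so client and server must come from the same build.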