Skip to content
Snippets Groups Projects
Unverified Commit f1f09df9 authored by KevinHuSh's avatar KevinHuSh Committed by GitHub
Browse files

add local llm implementation (#119)

parent 0452a6db
No related branches found
No related tags found
No related merge requests found
FROM infiniflow/ragflow-base:v1.0
FROM swr.cn-north-4.myhuaweicloud.com/infiniflow/ragflow-base:v1.0
USER root
WORKDIR /ragflow
......
......@@ -21,7 +21,7 @@
</a>
</p>
[RAGFLOW](http://ragflow.io) is a knowledge management platform built on custom-build document understanding engine and LLM,
[RagFlow](http://ragflow.io) is a knowledge management platform built on custom-build document understanding engine and LLM,
with reasoned and well-founded answers to your question. Clone this repository, you can deploy your own knowledge management
platform to empower your business with AI.
......@@ -29,12 +29,12 @@ platform to empower your business with AI.
<img src="https://github.com/infiniflow/ragflow/assets/12318111/b24a7a5f-4d1d-4a30-90b1-7b0ec558b79d" width="1000"/>
</div>
# Features
# Key Features
- **Custom-build document understanding engine.** Our deep learning engine is made according to the needs of analyzing and searching various type of documents in different domain.
- For documents from different domain for different purpose, the engine applys different analyzing and search strategy.
- Easily intervene and manipulate the data proccessing procedure when things goes beyond expectation.
- Multi-media document understanding is supported using OCR and multi-modal LLM.
- **State-of-the-art table structure and layout recognition.** Precisely extract and understand the document including table content. [README](./deepdoc/README.md)
- **State-of-the-art table structure and layout recognition.** Precisely extract and understand the document including table content. See [README.](./deepdoc/README.md)
- For PDF files, layout and table structures including row, column and span of them are recognized.
- Put the table accrossing the pages together.
- Reconstruct the table structure components into html table.
......
......@@ -52,7 +52,7 @@ app.errorhandler(Exception)(server_error_response)
#app.config["LOGIN_DISABLED"] = True
app.config["SESSION_PERMANENT"] = False
app.config["SESSION_TYPE"] = "filesystem"
app.config['MAX_CONTENT_LENGTH'] = 64 * 1024 * 1024
app.config['MAX_CONTENT_LENGTH'] = 128 * 1024 * 1024
Session(app)
login_manager = LoginManager()
......
......@@ -85,7 +85,7 @@ def my_llms():
}
res[o["llm_factory"]]["llm"].append({
"type": o["model_type"],
"name": o["model_name"],
"name": o["llm_name"],
"used_token": o["used_tokens"]
})
return get_json_result(data=res)
......
......@@ -520,7 +520,7 @@ class Task(DataBaseModel):
begin_at = DateTimeField(null=True)
process_duation = FloatField(default=0)
progress = FloatField(default=0)
progress_msg = CharField(max_length=4096, null=True, help_text="process message", default="")
progress_msg = TextField(max_length=4096, null=True, help_text="process message", default="")
class Dialog(DataBaseModel):
......
......@@ -47,6 +47,7 @@ class KnowledgebaseService(CommonService):
Tenant.embd_id,
cls.model.avatar,
cls.model.name,
cls.model.language,
cls.model.description,
cls.model.permission,
cls.model.doc_num,
......
......@@ -42,7 +42,7 @@ ERROR_REPORT = True
ERROR_REPORT_WITH_PATH = False
MAX_TIMESTAMP_INTERVAL = 60
SESSION_VALID_PERIOD = 7 * 24 * 60 * 60 * 1000
SESSION_VALID_PERIOD = 7 * 24 * 60 * 60
REQUEST_TRY_TIMES = 3
REQUEST_WAIT_SEC = 2
......@@ -69,6 +69,12 @@ default_llm = {
"image2text_model": "glm-4v",
"asr_model": "",
},
"local": {
"chat_model": "",
"embedding_model": "",
"image2text_model": "",
"asr_model": "",
}
}
LLM = get_base_config("user_default_llm", {})
LLM_FACTORY = LLM.get("factory", "通义千问")
......@@ -134,7 +140,7 @@ USE_AUTHENTICATION = False
USE_DATA_AUTHENTICATION = False
AUTOMATIC_AUTHORIZATION_OUTPUT_DATA = True
USE_DEFAULT_TIMEOUT = False
AUTHENTICATION_DEFAULT_TIMEOUT = 30 * 24 * 60 * 60 # s
AUTHENTICATION_DEFAULT_TIMEOUT = 7 * 24 * 60 * 60 # s
PRIVILEGE_COMMAND_WHITELIST = []
CHECK_NODES_IDENTITY = False
......
......@@ -20,13 +20,27 @@ class HuExcelParser:
for i,c in enumerate(r):
if not c.value:continue
t = str(ti[i].value) if i < len(ti) else ""
t += ("" if t else "") + str(c.value)
t += ("" if t else "") + str(c.value)
l.append(t)
l = "; ".join(l)
if sheetname.lower().find("sheet") <0: l += " ——"+sheetname
res.append(l)
return res
@staticmethod
def row_number(fnm, binary):
if fnm.split(".")[-1].lower().find("xls") >= 0:
wb = load_workbook(BytesIO(binary))
total = 0
for sheetname in wb.sheetnames:
ws = wb[sheetname]
total += len(ws.rows)
return total
if fnm.split(".")[-1].lower() in ["csv", "txt"]:
txt = binary.decode("utf-8")
return len(txt.split("\n"))
if __name__ == "__main__":
psr = HuExcelParser()
......
......@@ -26,7 +26,7 @@ http {
keepalive_timeout 65;
#gzip on;
client_max_body_size 82M;
client_max_body_size 128M;
include /etc/nginx/conf.d/ragflow.conf;
}
......
......@@ -25,7 +25,7 @@ from deepdoc.parser import ExcelParser
class Excel(ExcelParser):
def __call__(self, fnm, binary=None, callback=None):
def __call__(self, fnm, binary=None, from_page=0, to_page=10000000000, callback=None):
if not binary:
wb = load_workbook(fnm)
else:
......@@ -35,6 +35,7 @@ class Excel(ExcelParser):
total += len(list(wb[sheetname].rows))
res, fails, done = [], [], 0
rn = 0
for sheetname in wb.sheetnames:
ws = wb[sheetname]
rows = list(ws.rows)
......@@ -46,6 +47,9 @@ class Excel(ExcelParser):
rows[0]) if i not in missed]
data = []
for i, r in enumerate(rows[1:]):
rn += 1
if rn-1 < from_page:continue
if rn -1>=to_page: break
row = [
cell.value for ii,
cell in enumerate(r) if ii not in missed]
......@@ -111,7 +115,7 @@ def column_data_type(arr):
return arr, ty
def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
def chunk(filename, binary=None, from_page=0, to_page=10000000000, lang="Chinese", callback=None, **kwargs):
"""
Excel and csv(txt) format files are supported.
For csv or txt file, the delimiter between columns is TAB.
......@@ -147,16 +151,15 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
headers = lines[0].split(kwargs.get("delimiter", "\t"))
rows = []
for i, line in enumerate(lines[1:]):
if from_page < from_page:continue
if i >= to_page: break
row = [l for l in line.split(kwargs.get("delimiter", "\t"))]
if len(row) != len(headers):
fails.append(str(i))
continue
rows.append(row)
if len(rows) % 999 == 0:
callback(len(rows) * 0.6 / len(lines), ("Extract records: {}".format(len(rows)) + (
f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
callback(0.6, ("Extract records: {}".format(len(rows)) + (
callback(0.3, ("Extract records: {}~{}".format(from_page, min(len(lines), to_page)) + (
f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
dfs = [pd.DataFrame(np.array(rows), columns=headers)]
......@@ -209,7 +212,7 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
KnowledgebaseService.update_parser_config(
kwargs["kb_id"], {"field_map": {k: v for k, v in clmns_map}})
callback(0.6, "")
callback(0.35, "")
return res
......
......@@ -19,22 +19,25 @@ from .cv_model import *
EmbeddingModel = {
"Infiniflow": HuEmbedding,
"local": HuEmbedding,
"OpenAI": OpenAIEmbed,
"通义千问": HuEmbedding, #QWenEmbed,
"智谱AI": ZhipuEmbed
}
CvModel = {
"OpenAI": GptV4,
"Infiniflow": GptV4,
"local": LocalCV,
"通义千问": QWenCV,
"智谱AI": Zhipu4V
}
ChatModel = {
"OpenAI": GptTurbo,
"Infiniflow": GptTurbo,
"智谱AI": ZhipuChat,
"通义千问": QWenChat,
"local": LocalLLM
}
......@@ -20,6 +20,7 @@ from openai import OpenAI
import openai
from rag.nlp import is_english
from rag.utils import num_tokens_from_string
class Base(ABC):
......@@ -86,7 +87,6 @@ class ZhipuChat(Base):
self.model_name = model_name
def chat(self, system, history, gen_conf):
from http import HTTPStatus
if system: history.insert(0, {"role": "system", "content": system})
try:
response = self.client.chat.completions.create(
......@@ -100,4 +100,42 @@ class ZhipuChat(Base):
[ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
return ans, response.usage.completion_tokens
except Exception as e:
return "**ERROR**: " + str(e), 0
\ No newline at end of file
return "**ERROR**: " + str(e), 0
class LocalLLM(Base):
class RPCProxy:
def __init__(self, host, port):
self.host = host
self.port = int(port)
self.__conn()
def __conn(self):
from multiprocessing.connection import Client
self._connection = Client((self.host, self.port), authkey=b'infiniflow-token4kevinhu')
def __getattr__(self, name):
import pickle
def do_rpc(*args, **kwargs):
for _ in range(3):
try:
self._connection.send(pickle.dumps((name, args, kwargs)))
return pickle.loads(self._connection.recv())
except Exception as e:
self.__conn()
raise Exception("RPC connection lost!")
return do_rpc
def __init__(self, key, model_name="glm-3-turbo"):
self.client = LocalLLM.RPCProxy("127.0.0.1", 7860)
def chat(self, system, history, gen_conf):
if system: history.insert(0, {"role": "system", "content": system})
try:
ans = self.client.chat(
history,
gen_conf
)
return ans, num_tokens_from_string(ans)
except Exception as e:
return "**ERROR**: " + str(e), 0
......@@ -138,3 +138,11 @@ class Zhipu4V(Base):
max_tokens=max_tokens,
)
return res.choices[0].message.content.strip(), res.usage.total_tokens
class LocalCV(Base):
def __init__(self, key, model_name="glm-4v", lang="Chinese"):
pass
def describe(self, image, max_tokens=1024):
return "", 0
import argparse
import pickle
import random
import time
from multiprocessing.connection import Listener
from threading import Thread
import torch
class RPCHandler:
def __init__(self):
self._functions = { }
def register_function(self, func):
self._functions[func.__name__] = func
def handle_connection(self, connection):
try:
while True:
# Receive a message
func_name, args, kwargs = pickle.loads(connection.recv())
# Run the RPC and send a response
try:
r = self._functions[func_name](*args,**kwargs)
connection.send(pickle.dumps(r))
except Exception as e:
connection.send(pickle.dumps(e))
except EOFError:
pass
def rpc_server(hdlr, address, authkey):
sock = Listener(address, authkey=authkey)
while True:
try:
client = sock.accept()
t = Thread(target=hdlr.handle_connection, args=(client,))
t.daemon = True
t.start()
except Exception as e:
print("【EXCEPTION】:", str(e))
models = []
tokenizer = None
def chat(messages, gen_conf):
global tokenizer
model = Model()
roles = {"system":"System", "user": "User", "assistant": "Assistant"}
line = ["{}: {}".format(roles[m["role"].lower()], m["content"]) for m in messages]
line = "\n".join(line) + "\nAssistant: "
tokens = tokenizer([line], return_tensors='pt')
tokens = {k: tokens[k].to(model.device) if isinstance(tokens[k], torch.Tensor) else tokens[k] for k in
tokens.keys()}
res = [tokenizer.decode(t) for t in model.generate(**tokens, **gen_conf)][0]
return res.split("Assistant: ")[-1]
def Model():
global models
random.seed(time.time())
return random.choice(models)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model_name", type=str, help="Model name")
parser.add_argument("--port", default=7860, type=int, help="RPC serving port")
args = parser.parse_args()
handler = RPCHandler()
handler.register_function(chat)
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.utils import GenerationConfig
models = []
for _ in range(2):
m = AutoModelForCausalLM.from_pretrained(args.model_name,
device_map="auto",
torch_dtype='auto',
trust_remote_code=True)
m.generation_config = GenerationConfig.from_pretrained(args.model_name)
m.generation_config.pad_token_id = m.generation_config.eos_token_id
models.append(m)
tokenizer = AutoTokenizer.from_pretrained(args.model_name, use_fast=False,
trust_remote_code=True)
# Run the server
rpc_server(handler, ('0.0.0.0', args.port), authkey=b'infiniflow-token4kevinhu')
......@@ -25,7 +25,7 @@ SUBPROCESS_STD_LOG_NAME = "std.log"
ES = get_base_config("es", {})
MINIO = decrypt_database_config(name="minio")
DOC_MAXIMUM_SIZE = 64 * 1024 * 1024
DOC_MAXIMUM_SIZE = 128 * 1024 * 1024
# Logger
LoggerFactory.set_directory(os.path.join(get_project_base_directory(), "logs", "rag"))
......
......@@ -22,6 +22,7 @@ from api.db.db_models import Task
from api.db.db_utils import bulk_insert_into_db
from api.db.services.task_service import TaskService
from deepdoc.parser import PdfParser
from deepdoc.parser.excel_parser import HuExcelParser
from rag.settings import cron_logger
from rag.utils import MINIO
from rag.utils import findMaxTm
......@@ -88,6 +89,13 @@ def dispatch():
task["from_page"] = p
task["to_page"] = min(p + 5, e)
tsks.append(task)
elif r["parser_id"] == "table":
rn = HuExcelParser.row_number(r["name"], MINIO.get(r["kb_id"], r["location"]))
for i in range(0, rn, 1000):
task = new_task()
task["from_page"] = i
task["to_page"] = min(i + 1000, rn)
tsks.append(task)
else:
tsks.append(new_task())
......
......@@ -184,7 +184,7 @@ def embedding(docs, mdl, parser_config={}, callback=None):
if len(cnts_) == 0: cnts_ = vts
else: cnts_ = np.concatenate((cnts_, vts), axis=0)
tk_count += c
callback(msg="")
callback(prog=0.7+0.2*(i+1)/len(cnts), msg="")
cnts = cnts_
title_w = float(parser_config.get("filename_embd_weight", 0.1))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment