diff --git a/README.md b/README.md
index 686f6bb427f99a8d8e9b68709c8cccd04e40982a..28b9692b9353ce832ff7c7e47c1230b784a49460 100644
--- a/README.md
+++ b/README.md
@@ -172,6 +172,7 @@ $ docker compose up -d
 
 ## 🆕 Latest Features
 
+- 2024-04-11 Support [Xinference](./docs/xinference.md) for local LLM deployment.
 - 2024-04-10 Add a new layout recognition model for the 'Laws' method.
 - 2024-04-08 Support [Ollama](./docs/ollama.md) for local LLM deployment.
 - 2024-04-07 Support Chinese UI.
diff --git a/README_ja.md b/README_ja.md
index d3074bd14df699b35e32bc286c312ecdc2670a97..3197279e712041ebee8d5914211b38f9b1045637 100644
--- a/README_ja.md
+++ b/README_ja.md
@@ -171,6 +171,8 @@ $ docker compose up -d
 ```
 
 ## 🆕 最新の新機能
+
+- 2024-04-11 ローカル LLM デプロイメント用に [Xinference](./docs/xinference.md) をサポートします。
 - 2024-04-10 メソッド「Laws」に新しいレイアウト認識モデルを追加します。
 - 2024-04-08 [Ollama](./docs/ollama.md) を使用した大規模モデルのローカライズされたデプロイメントをサポートします。
 - 2024-04-07 中国語インターフェースをサポートします。
diff --git a/README_zh.md b/README_zh.md
index 7c7d571b4833ab9d162bfdcc845d136d43242176..143fcc769244aa7073d3c4561b03849cea14cedd 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -172,6 +172,7 @@ $ docker compose up -d
 
 ## 🆕 最近新特性
 
+- 2024-04-11 支持用 [Xinference](./docs/xinference.md) 对大模型进行本地化部署。
 - 2024-04-10 为‘Laws’版面分析增加了新模型。
 - 2024-04-08 支持用 [Ollama](./docs/ollama.md) 对大模型进行本地化部署。
 - 2024-04-07 支持中文界面。
diff --git a/api/apps/__init__.py b/api/apps/__init__.py
index 02f629ac54365da82236bd621b344271ed283f43..5ee940f46282f07297a68c93d9125ac15e1de1e5 100644
--- a/api/apps/__init__.py
+++ b/api/apps/__init__.py
@@ -22,6 +22,7 @@ from werkzeug.wrappers.request import Request
 from flask_cors import CORS
 
 from api.db import StatusEnum
+from api.db.db_models import close_connection
 from api.db.services import UserService
 from api.utils import CustomJSONEncoder
 
@@ -42,7 +43,7 @@ for h in access_logger.handlers:
 Request.json = property(lambda self: self.get_json(force=True, silent=True))
 
 app = Flask(__name__)
-CORS(app, supports_credentials=True,max_age = 2592000)
+CORS(app, supports_credentials=True, max_age=2592000)
 app.url_map.strict_slashes = False
 app.json_encoder = CustomJSONEncoder
 app.errorhandler(Exception)(server_error_response)
@@ -94,8 +95,6 @@ client_urls_prefix = [
 ]
 
 
-
-
 @login_manager.request_loader
 def load_user(web_request):
     jwt = Serializer(secret_key=SECRET_KEY)
@@ -112,4 +111,9 @@
             stat_logger.exception(e)
             return None
     else:
-        return None
\ No newline at end of file
+        return None
+
+
+@app.teardown_request
+def _db_close(exc):
+    close_connection()
\ No newline at end of file
diff --git a/api/apps/conversation_app.py b/api/apps/conversation_app.py
index fd9c266ea4edec7e5cc4d506739cb5819dcf5697..8c42c804bffb857f34c880518807019f152499b9 100644
--- a/api/apps/conversation_app.py
+++ b/api/apps/conversation_app.py
@@ -360,6 +360,7 @@ def use_sql(question, field_map, tenant_id, chat_mdl):
              "|" for r in tbl["rows"]]
     rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
     rows = re.sub(r"T[0-9]{2}:[0-9]{2}:[0-9]{2}(\.[0-9]+Z)?\|", "|", rows)
+
     if not docid_idx or not docnm_idx:
         chat_logger.warning("SQL missing field: " + sql)
         return {
diff --git a/api/db/init_data.py b/api/db/init_data.py
index 4cc72a2d5ed596fa39b6667d7b7b3f30dc79904c..2e5026af4d9446cb916ba6017d69c6e50c0f5b57 100644
--- a/api/db/init_data.py
+++ b/api/db/init_data.py
@@ -109,6 +109,12 @@ factory_infos = [{
     "logo": "",
     "tags": "LLM,TEXT EMBEDDING",
     "status": "1",
+},
+{
+    "name": "Xinference",
+    "logo": "",
+    "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
+    "status": "1",
 },
 # {
 #     "name": "文心一言",
diff --git a/docker/docker-compose-CN.yml b/docker/docker-compose-CN.yml
index 2621634208b1b46db57f88da74d2403b72a9609c..a4f3f77c3aff6c07f1858d85ef834df6d253e6a9 100644
--- a/docker/docker-compose-CN.yml
+++ b/docker/docker-compose-CN.yml
@@ -20,7 +20,6 @@ services:
       - 443:443
     volumes:
       - ./service_conf.yaml:/ragflow/conf/service_conf.yaml
-      - ./entrypoint.sh:/ragflow/entrypoint.sh
       - ./ragflow-logs:/ragflow/logs
       - ./nginx/ragflow.conf:/etc/nginx/conf.d/ragflow.conf
       - ./nginx/proxy.conf:/etc/nginx/proxy.conf
diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml
index f5ad8f8b27f5b74af8a69367ccdab79885308abc..312b5329e52f679c20a5700288de1db4ad36a061 100644
--- a/docker/docker-compose.yml
+++ b/docker/docker-compose.yml
@@ -19,7 +19,6 @@ services:
       - 443:443
     volumes:
       - ./service_conf.yaml:/ragflow/conf/service_conf.yaml
-      - ./entrypoint.sh:/ragflow/entrypoint.sh
       - ./ragflow-logs:/ragflow/logs
       - ./nginx/ragflow.conf:/etc/nginx/conf.d/ragflow.conf
       - ./nginx/proxy.conf:/etc/nginx/proxy.conf
diff --git a/rag/llm/__init__.py b/rag/llm/__init__.py
index c3fc7db816a441f22a32ca0813435b9414c7c98f..c088a7f94443deea005c81952ec1be88e0a9f785 100644
--- a/rag/llm/__init__.py
+++ b/rag/llm/__init__.py
@@ -21,6 +21,7 @@ from .cv_model import *
 EmbeddingModel = {
     "Ollama": OllamaEmbed,
     "OpenAI": OpenAIEmbed,
+    "Xinference": XinferenceEmbed,
     "Tongyi-Qianwen": HuEmbedding, #QWenEmbed,
     "ZHIPU-AI": ZhipuEmbed,
     "Moonshot": HuEmbedding
 }
@@ -30,6 +31,7 @@
 CvModel = {
     "OpenAI": GptV4,
     "Ollama": OllamaCV,
+    "Xinference": XinferenceCV,
     "Tongyi-Qianwen": QWenCV,
     "ZHIPU-AI": Zhipu4V,
     "Moonshot": LocalCV
 }
@@ -41,6 +43,7 @@
 ChatModel = {
     "OpenAI": ChatGPT,
     "ZHIPU-AI": ZhipuChat,
     "Tongyi-Qianwen": QWenChat,
     "Ollama": OllamaChat,
+    "Xinference": XinferenceChat,
     "Moonshot": MoonshotChat
 }
diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py
index d4f0e7b64b6074cd6da488e0b180a2ab62561d07..b9bb36d736d1ee4a395d80cd375f38070acab242 100644
--- a/rag/llm/chat_model.py
+++ b/rag/llm/chat_model.py
@@ -158,6 +158,30 @@ class OllamaChat(Base):
             return "**ERROR**: " + str(e), 0
 
 
+class XinferenceChat(Base):
+    def __init__(self, key=None, model_name="", base_url=""):
+        # Xinference serves an OpenAI-compatible API and does not validate
+        # the key, but the OpenAI client rejects an empty one, hence "xxx".
+        self.client = OpenAI(api_key="xxx", base_url=base_url)
+        self.model_name = model_name
+
+    def chat(self, system, history, gen_conf):
+        if system:
+            history.insert(0, {"role": "system", "content": system})
+        try:
+            response = self.client.chat.completions.create(
+                model=self.model_name,
+                messages=history,
+                **gen_conf)
+            ans = response.choices[0].message.content.strip()
+            if response.choices[0].finish_reason == "length":
+                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
+                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
+            return ans, response.usage.completion_tokens
+        except openai.APIError as e:
+            return "**ERROR**: " + str(e), 0
+
+
 class LocalLLM(Base):
     class RPCProxy:
         def __init__(self, host, port):
diff --git a/rag/llm/cv_model.py b/rag/llm/cv_model.py
index d764bc873009fd1cf3910cd2acbfd2bfb309b92a..4b966991bf8329f1fcf82594351e70df530fd9fc 100644
--- a/rag/llm/cv_model.py
+++ b/rag/llm/cv_model.py
@@ -161,6 +161,22 @@ class OllamaCV(Base):
         except Exception as e:
             return "**ERROR**: " + str(e), 0
 
+
+class XinferenceCV(Base):
+    def __init__(self, key, model_name="", lang="Chinese", base_url=""):
+        self.client = OpenAI(api_key=key, base_url=base_url)
+        self.model_name = model_name
+        self.lang = lang
+
+    def describe(self, image, max_tokens=300):
+        b64 = self.image2base64(image)
+
+        res = self.client.chat.completions.create(
+            model=self.model_name,
+            messages=self.prompt(b64),
+            max_tokens=max_tokens,
+        )
+        return res.choices[0].message.content.strip(), res.usage.total_tokens
 
 class LocalCV(Base):
     def __init__(self, key, model_name="glm-4v", lang="Chinese", **kwargs):
diff --git a/rag/llm/embedding_model.py b/rag/llm/embedding_model.py
index d5b763d1855f53039340cdd1e87e7665e44005d6..aa6b565b87491734e8e1e225e012ca883a6e8297 100644
--- a/rag/llm/embedding_model.py
+++ b/rag/llm/embedding_model.py
@@ -170,3 +170,20 @@ class OllamaEmbed(Base):
         res = self.client.embeddings(prompt=text,
                                      model=self.model_name)
         return np.array(res["embedding"]), 128
+
+
+class XinferenceEmbed(Base):
+    def __init__(self, key, model_name="", base_url=""):
+        # Xinference does not validate the key; the client only needs a placeholder.
+        self.client = OpenAI(api_key="xxx", base_url=base_url)
+        self.model_name = model_name
+
+    def encode(self, texts: list, batch_size=32):
+        res = self.client.embeddings.create(input=texts,
+                                            model=self.model_name)
+        return np.array([d.embedding for d in res.data]), res.usage.total_tokens
+
+    def encode_queries(self, text):
+        res = self.client.embeddings.create(input=[text],
+                                            model=self.model_name)
+        return np.array(res.data[0].embedding), res.usage.total_tokens
diff --git a/rag/settings.py b/rag/settings.py
index f84831df8095adc593ce3711282861adf071a1b6..da022628f49c0a589ee0004cda689ea13a5e4d66 100644
--- a/rag/settings.py
+++ b/rag/settings.py
@@ -34,7 +34,7 @@ LoggerFactory.set_directory(
         "logs",
         "rag"))
 # {CRITICAL: 50, FATAL:50, ERROR:40, WARNING:30, WARN:30, INFO:20, DEBUG:10, NOTSET:0}
-LoggerFactory.LEVEL = 10
+LoggerFactory.LEVEL = 30
 
 es_logger = getLogger("es")
 minio_logger = getLogger("minio")
diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py
index 1f5e37af26704ffd12eeefb47cc6062748dc5c21..6ea80d9c95bc96050d3238bb9a1509ca2d60e61d 100644
--- a/rag/svr/task_executor.py
+++ b/rag/svr/task_executor.py
@@ -24,6 +24,8 @@ import sys
 import time
 import traceback
 from functools import partial
+
+from api.db.db_models import close_connection
 from rag.settings import database_logger
 from rag.settings import cron_logger, DOC_MAXIMUM_SIZE
 from multiprocessing import Pool
@@ -302,3 +304,4 @@ if __name__ == "__main__":
     comm = MPI.COMM_WORLD
     while True:
         main(int(sys.argv[2]), int(sys.argv[1]))
+        close_connection()
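
For reviewers who want to smoke-test this change end to end, the sketch below drives the new `Xinference` entries registered in `ChatModel` and `EmbeddingModel`. It is a minimal sketch, not part of the patch: the base URL assumes Xinference's default port 9997 and its OpenAI-compatible `/v1` prefix, and the model names (`qwen-chat`, `bge-large-zh-v1.5`) are placeholders for whatever models have actually been launched on the server.

```python
# Smoke test for the new Xinference classes (sketch; assumes a local Xinference
# server is running on its default port 9997 with one chat model and one
# embedding model already launched; the model names below are placeholders).
from rag.llm import ChatModel, EmbeddingModel

BASE_URL = "http://127.0.0.1:9997/v1"  # assumed OpenAI-compatible endpoint

# XinferenceChat never sends a real key, so any value is accepted here.
chat_mdl = ChatModel["Xinference"](key=None, model_name="qwen-chat", base_url=BASE_URL)
ans, used_tokens = chat_mdl.chat(
    system="You are a concise assistant.",
    history=[{"role": "user", "content": "Say hello in one sentence."}],
    gen_conf={"temperature": 0.7, "max_tokens": 64},
)
print(ans, used_tokens)

emb_mdl = EmbeddingModel["Xinference"](key=None, model_name="bge-large-zh-v1.5", base_url=BASE_URL)
vecs, used_tokens = emb_mdl.encode(["hello", "world"])
print(vecs.shape, used_tokens)  # (2, <dim>) plus the server-reported token count
```

Note that `XinferenceChat.chat` catches `openai.APIError` and returns an `**ERROR**:` string, mirroring `OllamaChat`, while `XinferenceEmbed.encode` has no such guard, so a connection failure there surfaces as an exception.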