From 9bf75d4511a9e58af0f9f106ec88e6f06cc2189a Mon Sep 17 00:00:00 2001 From: KevinHuSh <kevinhu.sh@gmail.com> Date: Wed, 17 Jan 2024 20:20:42 +0800 Subject: [PATCH] add dialog api (#33) --- api/apps/__init__.py | 22 +-- api/apps/chunk_app.py | 139 +++++++++++++++---- api/apps/dialog_app.py | 163 +++++++++++++++++++++++ api/apps/document_app.py | 48 ++++--- api/apps/kb_app.py | 20 +-- api/apps/llm_app.py | 22 +-- api/apps/user_app.py | 20 +-- api/db/__init__.py | 11 +- api/db/db_models.py | 27 ++-- api/db/db_services.py | 8 +- api/db/db_utils.py | 16 +-- api/db/init_data.py | 10 +- api/db/operatioins.py | 4 +- api/db/reload_config_base.py | 2 +- api/db/runtime_config.py | 4 +- api/db/services/__init__.py | 2 +- api/db/services/common_service.py | 6 +- api/db/services/dialog_service.py | 16 +-- api/db/services/document_service.py | 14 +- api/db/services/kb_service.py | 15 +-- api/db/services/knowledgebase_service.py | 12 +- api/db/services/llm_service.py | 15 +-- api/db/services/user_service.py | 14 +- api/errors/error_services.py | 2 +- api/errors/general_error.py | 2 +- api/hook/__init__.py | 4 +- api/hook/api/client_authentication.py | 10 +- api/hook/api/permission.py | 10 +- api/hook/api/site_authentication.py | 10 +- api/hook/common/parameters.py | 2 +- api/ragflow_server.py | 18 +-- api/settings.py | 10 +- api/utils/__init__.py | 2 +- api/utils/api_utils.py | 12 +- api/utils/file_utils.py | 4 +- api/utils/log_utils.py | 4 +- api/utils/t_crypt.py | 2 +- api/versions.py | 4 +- rag/llm/__init__.py | 2 +- rag/llm/chat_model.py | 2 +- rag/llm/cv_model.py | 2 +- rag/llm/embedding_model.py | 6 +- rag/nlp/huqie.py | 2 +- rag/nlp/query.py | 2 +- rag/nlp/search.py | 36 ++--- rag/nlp/synonym.py | 2 +- rag/nlp/term_weight.py | 2 +- rag/settings.py | 8 +- rag/svr/parse_user_docs.py | 15 +-- rag/utils/es_conn.py | 1 - 50 files changed, 512 insertions(+), 274 deletions(-) create mode 100644 api/apps/dialog_app.py diff --git a/api/apps/__init__.py b/api/apps/__init__.py index 71ab9f1..fd49e13 100644 --- a/api/apps/__init__.py +++ b/api/apps/__init__.py @@ -1,5 +1,5 @@ # -# Copyright 2019 The RAG Flow Authors. All Rights Reserved. +# Copyright 2019 The InfiniFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -21,17 +21,17 @@ from flask import Blueprint, Flask, request from werkzeug.wrappers.request import Request from flask_cors import CORS -from web_server.db import StatusEnum -from web_server.db.services import UserService -from web_server.utils import CustomJSONEncoder +from api.db import StatusEnum +from api.db.services import UserService +from api.utils import CustomJSONEncoder from flask_session import Session from flask_login import LoginManager -from web_server.settings import RetCode, SECRET_KEY, stat_logger -from web_server.hook import HookManager -from web_server.hook.common.parameters import AuthenticationParameters, ClientAuthenticationParameters -from web_server.settings import API_VERSION, CLIENT_AUTHENTICATION, SITE_AUTHENTICATION, access_logger -from web_server.utils.api_utils import get_json_result, server_error_response +from api.settings import RetCode, SECRET_KEY, stat_logger +from api.hook import HookManager +from api.hook.common.parameters import AuthenticationParameters, ClientAuthenticationParameters +from api.settings import API_VERSION, CLIENT_AUTHENTICATION, SITE_AUTHENTICATION, access_logger +from api.utils.api_utils import get_json_result, server_error_response from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer __all__ = ['app'] @@ -68,7 +68,7 @@ def search_pages_path(pages_dir): def register_page(page_path): page_name = page_path.stem.rstrip('_app') - module_name = '.'.join(page_path.parts[page_path.parts.index('web_server'):-1] + (page_name, )) + module_name = '.'.join(page_path.parts[page_path.parts.index('api'):-1] + (page_name, )) spec = spec_from_file_location(module_name, page_path) page = module_from_spec(spec) @@ -86,7 +86,7 @@ def register_page(page_path): pages_dir = [ Path(__file__).parent, - Path(__file__).parent.parent / 'web_server' / 'apps', + Path(__file__).parent.parent / 'api' / 'apps', ] client_urls_prefix = [ diff --git a/api/apps/chunk_app.py b/api/apps/chunk_app.py index f48d6db..adcee9e 100644 --- a/api/apps/chunk_app.py +++ b/api/apps/chunk_app.py @@ -1,5 +1,5 @@ # -# Copyright 2019 The RAG Flow Authors. All Rights Reserved. +# Copyright 2019 The InfiniFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,31 +13,26 @@ # See the License for the specific language governing permissions and # limitations under the License. # -import base64 import hashlib -import pathlib import re -from elasticsearch_dsl import Q +import numpy as np from flask import request from flask_login import login_required, current_user from rag.nlp import search, huqie from rag.utils import ELASTICSEARCH, rmSpace -from web_server.db import LLMType -from web_server.db.services import duplicate_name -from web_server.db.services.kb_service import KnowledgebaseService -from web_server.db.services.llm_service import TenantLLMService -from web_server.db.services.user_service import UserTenantService -from web_server.utils.api_utils import server_error_response, get_data_error_result, validate_request -from web_server.utils import get_uuid -from web_server.db.services.document_service import DocumentService -from web_server.settings import RetCode -from web_server.utils.api_utils import get_json_result -from rag.utils.minio_conn import MINIO -from web_server.utils.file_utils import filename_type - -retrival = search.Dealer(ELASTICSEARCH, None) +from api.db import LLMType +from api.db.services import duplicate_name +from api.db.services.kb_service import KnowledgebaseService +from api.db.services.llm_service import TenantLLMService +from api.db.services.user_service import UserTenantService +from api.utils.api_utils import server_error_response, get_data_error_result, validate_request +from api.db.services.document_service import DocumentService +from api.settings import RetCode +from api.utils.api_utils import get_json_result + +retrival = search.Dealer(ELASTICSEARCH) @manager.route('/list', methods=['POST']) @login_required @@ -45,16 +40,29 @@ retrival = search.Dealer(ELASTICSEARCH, None) def list(): req = request.json doc_id = req["doc_id"] - page = req.get("page", 1) - size = req.get("size", 30) + page = int(req.get("page", 1)) + size = int(req.get("size", 30)) question = req.get("keywords", "") try: - tenants = UserTenantService.query(user_id=current_user.id) - if not tenants: - return get_data_error_result(retmsg="Tenant not found!") - res = retrival.search({ + tenant_id = DocumentService.get_tenant_id(req["doc_id"]) + if not tenant_id: return get_data_error_result(retmsg="Tenant not found!") + query = { "doc_ids": [doc_id], "page": page, "size": size, "question": question - }, search.index_name(tenants[0].tenant_id)) + } + if "available_int" in req: query["available_int"] = int(req["available_int"]) + sres = retrival.search(query, search.index_name(tenant_id)) + res = {"total": sres.total, "chunks": []} + for id in sres.ids: + d = { + "chunk_id": id, + "content_ltks": rmSpace(sres.highlight[id]) if question else sres.field[id]["content_ltks"], + "doc_id": sres.field[id]["doc_id"], + "docnm_kwd": sres.field[id]["docnm_kwd"], + "important_kwd": sres.field[id].get("important_kwd", []), + "img_id": sres.field[id].get("img_id", ""), + "available_int": sres.field[id].get("available_int", 1), + } + res["chunks"].append(d) return get_json_result(data=res) except Exception as e: if str(e).find("not_found") > 0: @@ -102,6 +110,7 @@ def set(): d["content_sm_ltks"] = huqie.qieqie(d["content_ltks"]) d["important_kwd"] = req["important_kwd"] d["important_tks"] = huqie.qie(" ".join(req["important_kwd"])) + if "available_int" in req: d["available_int"] = req["available_int"] try: tenant_id = DocumentService.get_tenant_id(req["doc_id"]) @@ -116,10 +125,27 @@ def set(): return server_error_response(e) +@manager.route('/switch', methods=['POST']) +@login_required +@validate_request("chunk_ids", "available_int", "doc_id") +def switch(): + req = request.json + try: + tenant_id = DocumentService.get_tenant_id(req["doc_id"]) + if not tenant_id: return get_data_error_result(retmsg="Tenant not found!") + if not ELASTICSEARCH.upsert([{"id": i, "available_int": int(req["available_int"])} for i in req["chunk_ids"]], + search.index_name(tenant_id)): + return get_data_error_result(retmsg="Index updating failure") + return get_json_result(data=True) + except Exception as e: + return server_error_response(e) + + + @manager.route('/create', methods=['POST']) @login_required @validate_request("doc_id", "content_ltks", "important_kwd") -def set(): +def create(): req = request.json md5 = hashlib.md5() md5.update((req["content_ltks"] + req["doc_id"]).encode("utf-8")) @@ -148,3 +174,64 @@ def set(): return get_json_result(data={"chunk_id": chunck_id}) except Exception as e: return server_error_response(e) + + +@manager.route('/retrieval_test', methods=['POST']) +@login_required +@validate_request("kb_id", "question") +def retrieval_test(): + req = request.json + page = int(req.get("page", 1)) + size = int(req.get("size", 30)) + question = req["question"] + kb_id = req["kb_id"] + doc_ids = req.get("doc_ids", []) + similarity_threshold = float(req.get("similarity_threshold", 0.4)) + vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3)) + top = int(req.get("top", 1024)) + try: + e, kb = KnowledgebaseService.get_by_id(kb_id) + if not e: + return get_data_error_result(retmsg="Knowledgebase not found!") + + embd_mdl = TenantLLMService.model_instance(kb.tenant_id, LLMType.EMBEDDING.value) + sres = retrival.search({"kb_ids": [kb_id], "doc_ids": doc_ids, "size": top, + "question": question, "vector": True, + "similarity": similarity_threshold}, + search.index_name(kb.tenant_id), + embd_mdl) + + sim, tsim, vsim = retrival.rerank(sres, question, 1-vector_similarity_weight, vector_similarity_weight) + idx = np.argsort(sim*-1) + ranks = {"total": 0, "chunks": [], "doc_aggs": {}} + start_idx = (page-1)*size + for i in idx: + ranks["total"] += 1 + if sim[i] < similarity_threshold: break + start_idx -= 1 + if start_idx >= 0:continue + if len(ranks["chunks"]) == size:continue + id = sres.ids[i] + dnm = sres.field[id]["docnm_kwd"] + d = { + "chunk_id": id, + "content_ltks": sres.field[id]["content_ltks"], + "doc_id": sres.field[id]["doc_id"], + "docnm_kwd": dnm, + "kb_id": sres.field[id]["kb_id"], + "important_kwd": sres.field[id].get("important_kwd", []), + "img_id": sres.field[id].get("img_id", ""), + "similarity": sim[i], + "vector_similarity": vsim[i], + "term_similarity": tsim[i] + } + ranks["chunks"].append(d) + if dnm not in ranks["doc_aggs"]:ranks["doc_aggs"][dnm] = 0 + ranks["doc_aggs"][dnm] += 1 + + return get_json_result(data=ranks) + except Exception as e: + if str(e).find("not_found") > 0: + return get_json_result(data=False, retmsg=f'Index not found!', + retcode=RetCode.DATA_ERROR) + return server_error_response(e) \ No newline at end of file diff --git a/api/apps/dialog_app.py b/api/apps/dialog_app.py new file mode 100644 index 0000000..1593f8c --- /dev/null +++ b/api/apps/dialog_app.py @@ -0,0 +1,163 @@ +# +# Copyright 2019 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import hashlib +import re + +import numpy as np +from flask import request +from flask_login import login_required, current_user + +from api.db.services.dialog_service import DialogService +from rag.nlp import search, huqie +from rag.utils import ELASTICSEARCH, rmSpace +from api.db import LLMType, StatusEnum +from api.db.services import duplicate_name +from api.db.services.kb_service import KnowledgebaseService +from api.db.services.llm_service import TenantLLMService +from api.db.services.user_service import UserTenantService, TenantService +from api.utils.api_utils import server_error_response, get_data_error_result, validate_request +from api.utils import get_uuid +from api.db.services.document_service import DocumentService +from api.settings import RetCode, stat_logger +from api.utils.api_utils import get_json_result +from rag.utils.minio_conn import MINIO +from api.utils.file_utils import filename_type + + +@manager.route('/set', methods=['POST']) +@login_required +def set(): + req = request.json + dialog_id = req.get("dialog_id") + name = req.get("name", "New Dialog") + description = req.get("description", "A helpful Dialog") + language = req.get("language", "Chinese") + llm_setting_type = req.get("llm_setting_type", "Precise") + llm_setting = req.get("llm_setting", { + "Creative": { + "temperature": 0.9, + "top_p": 0.9, + "frequency_penalty": 0.2, + "presence_penalty": 0.4, + "max_tokens": 512 + }, + "Precise": { + "temperature": 0.1, + "top_p": 0.3, + "frequency_penalty": 0.7, + "presence_penalty": 0.4, + "max_tokens": 215 + }, + "Evenly": { + "temperature": 0.5, + "top_p": 0.5, + "frequency_penalty": 0.7, + "presence_penalty": 0.4, + "max_tokens": 215 + }, + "Custom": { + "temperature": 0.2, + "top_p": 0.3, + "frequency_penalty": 0.6, + "presence_penalty": 0.3, + "max_tokens": 215 + }, + }) + prompt_config = req.get("prompt_config", { + "system": """ä˝ ćŻä¸€ä¸Şć™şč˝ĺŠ©ć‰‹ďĽŚčŻ·ć€»ç»“知识库的内容来回ç”é—®é˘ďĽŚčŻ·ĺ—举知识库ä¸çš„数据详细回ç”。当所有知识库内容é˝ä¸Žé—®é˘ć— ĺ…łć—¶ďĽŚä˝ çš„ĺ›žç”必须包括“知识库ä¸ćśŞć‰ľĺ°ć‚¨č¦çš„ç”ćˇďĽâ€ťčż™ĺŹĄčŻťă€‚回ç”需č¦č€č™‘čŠĺ¤©ĺŽ†ĺŹ˛ă€‚ +以下ćŻçźĄčŻ†ĺş“: +{knowledge} +以上ćŻçźĄčŻ†ĺş“。""", + "prologue": "您好,ć‘ćŻć‚¨çš„助手小樱,长得可ç±ĺŹĺ–„良,can I help you?", + "parameters": [ + {"key": "knowledge", "optional": False} + ], + "empty_response": "Sorry! 知识库ä¸ćśŞć‰ľĺ°ç›¸ĺ…łĺ†…容ďĽ" + }) + + if len(prompt_config["parameters"]) < 1: + return get_data_error_result(retmsg="'knowledge' should be in parameters") + + for p in prompt_config["parameters"]: + if prompt_config["system"].find("{%s}"%p["key"]) < 0: + return get_data_error_result(retmsg="Parameter '{}' is not used".format(p["key"])) + + try: + e, tenant = TenantService.get_by_id(current_user.id) + if not e:return get_data_error_result(retmsg="Tenant not found!") + llm_id = req.get("llm_id", tenant.llm_id) + if not dialog_id: + dia = { + "id": get_uuid(), + "tenant_id": current_user.id, + "name": name, + "description": description, + "language": language, + "llm_id": llm_id, + "llm_setting_type": llm_setting_type, + "llm_setting": llm_setting, + "prompt_config": prompt_config + } + if not DialogService.save(**dia): return get_data_error_result(retmsg="Fail to new a dialog!") + e, dia = DialogService.get_by_id(dia["id"]) + if not e: return get_data_error_result(retmsg="Fail to new a dialog!") + return get_json_result(data=dia.to_json()) + else: + del req["dialog_id"] + if "kb_names" in req: del req["kb_names"] + if not DialogService.update_by_id(dialog_id, req): + return get_data_error_result(retmsg="Dialog not found!") + e, dia = DialogService.get_by_id(dialog_id) + if not e: return get_data_error_result(retmsg="Fail to update a dialog!") + dia = dia.to_dict() + dia["kb_ids"], dia["kb_names"] = get_kb_names(dia["kb_ids"]) + return get_json_result(data=dia) + except Exception as e: + return server_error_response(e) + +@manager.route('/get', methods=['GET']) +@login_required +def get(): + dialog_id = request.args["dialog_id"] + try: + e,dia = DialogService.get_by_id(dialog_id) + if not e: return get_data_error_result(retmsg="Dialog not found!") + dia = dia.to_dict() + dia["kb_ids"], dia["kb_names"] = get_kb_names(dia["kb_ids"]) + return get_json_result(data=dia) + except Exception as e: + return server_error_response(e) + +def get_kb_names(kb_ids): + ids, nms = [], [] + for kid in kb_ids: + e, kb = KnowledgebaseService.get_by_id(kid) + if not e or kb.status != StatusEnum.VALID.value:continue + ids.append(kid) + nms.append(kb.name) + return ids, nms + +@manager.route('/list', methods=['GET']) +@login_required +def list(): + try: + diags = DialogService.query(tenant_id=current_user.id, status=StatusEnum.VALID.value) + diags = [d.to_dict() for d in diags] + for d in diags: + d["kb_ids"], d["kb_names"] = get_kb_names(d["kb_ids"]) + return get_json_result(data=diags) + except Exception as e: + return server_error_response(e) \ No newline at end of file diff --git a/api/apps/document_app.py b/api/apps/document_app.py index 58094e1..48c9240 100644 --- a/api/apps/document_app.py +++ b/api/apps/document_app.py @@ -1,5 +1,5 @@ # -# Copyright 2019 The RAG Flow Authors. All Rights Reserved. +# Copyright 2019 The InfiniFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,22 +16,23 @@ import base64 import pathlib +import flask from elasticsearch_dsl import Q from flask import request from flask_login import login_required, current_user from rag.nlp import search from rag.utils import ELASTICSEARCH -from web_server.db.services import duplicate_name -from web_server.db.services.kb_service import KnowledgebaseService -from web_server.utils.api_utils import server_error_response, get_data_error_result, validate_request -from web_server.utils import get_uuid -from web_server.db import FileType -from web_server.db.services.document_service import DocumentService -from web_server.settings import RetCode -from web_server.utils.api_utils import get_json_result +from api.db.services import duplicate_name +from api.db.services.kb_service import KnowledgebaseService +from api.utils.api_utils import server_error_response, get_data_error_result, validate_request +from api.utils import get_uuid +from api.db import FileType +from api.db.services.document_service import DocumentService +from api.settings import RetCode +from api.utils.api_utils import get_json_result from rag.utils.minio_conn import MINIO -from web_server.utils.file_utils import filename_type +from api.utils.file_utils import filename_type @manager.route('/upload', methods=['POST']) @@ -163,21 +164,13 @@ def change_status(): if str(req["status"]) == "0": ELASTICSEARCH.updateScriptByQuery(Q("term", doc_id=req["doc_id"]), - scripts=""" - if(ctx._source.kb_id.contains('%s')) - ctx._source.kb_id.remove( - ctx._source.kb_id.indexOf('%s') - ); - """ % (doc.kb_id, doc.kb_id), + scripts="ctx._source.available_int=0;", idxnm=search.index_name( kb.tenant_id) ) else: ELASTICSEARCH.updateScriptByQuery(Q("term", doc_id=req["doc_id"]), - scripts=""" - if(!ctx._source.kb_id.contains('%s')) - ctx._source.kb_id.add('%s'); - """ % (doc.kb_id, doc.kb_id), + scripts="ctx._source.available_int=1;", idxnm=search.index_name( kb.tenant_id) ) @@ -195,8 +188,7 @@ def rm(): e, doc = DocumentService.get_by_id(req["doc_id"]) if not e: return get_data_error_result(retmsg="Document not found!") - if not ELASTICSEARCH.deleteByQuery(Q("match", doc_id=doc.id), idxnm=search.index_name(doc.kb_id)): - return get_json_result(data=False, retmsg='Remove from ES failure"', retcode=RetCode.SERVER_ERROR) + ELASTICSEARCH.deleteByQuery(Q("match", doc_id=doc.id), idxnm=search.index_name(doc.kb_id)) DocumentService.increment_chunk_num(doc.id, doc.kb_id, doc.token_num*-1, doc.chunk_num*-1, 0) if not DocumentService.delete_by_id(req["doc_id"]): @@ -277,3 +269,15 @@ def change_parser(): except Exception as e: return server_error_response(e) + +@manager.route('/image/<image_id>', methods=['GET']) +@login_required +def get_image(image_id): + try: + bkt, nm = image_id.split("-") + response = flask.make_response(MINIO.get(bkt, nm)) + response.headers.set('Content-Type', 'image/JPEG') + return response + except Exception as e: + return server_error_response(e) + diff --git a/api/apps/kb_app.py b/api/apps/kb_app.py index 0919896..3aae410 100644 --- a/api/apps/kb_app.py +++ b/api/apps/kb_app.py @@ -1,5 +1,5 @@ # -# Copyright 2019 The RAG Flow Authors. All Rights Reserved. +# Copyright 2019 The InfiniFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,15 +16,15 @@ from flask import request from flask_login import login_required, current_user -from web_server.db.services import duplicate_name -from web_server.db.services.user_service import TenantService, UserTenantService -from web_server.utils.api_utils import server_error_response, get_data_error_result, validate_request -from web_server.utils import get_uuid, get_format_time -from web_server.db import StatusEnum, UserTenantRole -from web_server.db.services.kb_service import KnowledgebaseService -from web_server.db.db_models import Knowledgebase -from web_server.settings import stat_logger, RetCode -from web_server.utils.api_utils import get_json_result +from api.db.services import duplicate_name +from api.db.services.user_service import TenantService, UserTenantService +from api.utils.api_utils import server_error_response, get_data_error_result, validate_request +from api.utils import get_uuid, get_format_time +from api.db import StatusEnum, UserTenantRole +from api.db.services.kb_service import KnowledgebaseService +from api.db.db_models import Knowledgebase +from api.settings import stat_logger, RetCode +from api.utils.api_utils import get_json_result @manager.route('/create', methods=['post']) diff --git a/api/apps/llm_app.py b/api/apps/llm_app.py index 12b8a63..c7dffaa 100644 --- a/api/apps/llm_app.py +++ b/api/apps/llm_app.py @@ -1,5 +1,5 @@ # -# Copyright 2019 The RAG Flow Authors. All Rights Reserved. +# Copyright 2019 The InfiniFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,16 +16,16 @@ from flask import request from flask_login import login_required, current_user -from web_server.db.services import duplicate_name -from web_server.db.services.llm_service import LLMFactoriesService, TenantLLMService, LLMService -from web_server.db.services.user_service import TenantService, UserTenantService -from web_server.utils.api_utils import server_error_response, get_data_error_result, validate_request -from web_server.utils import get_uuid, get_format_time -from web_server.db import StatusEnum, UserTenantRole -from web_server.db.services.kb_service import KnowledgebaseService -from web_server.db.db_models import Knowledgebase, TenantLLM -from web_server.settings import stat_logger, RetCode -from web_server.utils.api_utils import get_json_result +from api.db.services import duplicate_name +from api.db.services.llm_service import LLMFactoriesService, TenantLLMService, LLMService +from api.db.services.user_service import TenantService, UserTenantService +from api.utils.api_utils import server_error_response, get_data_error_result, validate_request +from api.utils import get_uuid, get_format_time +from api.db import StatusEnum, UserTenantRole +from api.db.services.kb_service import KnowledgebaseService +from api.db.db_models import Knowledgebase, TenantLLM +from api.settings import stat_logger, RetCode +from api.utils.api_utils import get_json_result @manager.route('/factories', methods=['GET']) diff --git a/api/apps/user_app.py b/api/apps/user_app.py index e7e8a5e..61d67a6 100644 --- a/api/apps/user_app.py +++ b/api/apps/user_app.py @@ -1,5 +1,5 @@ # -# Copyright 2019 The RAG Flow Authors. All Rights Reserved. +# Copyright 2019 The InfiniFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,15 +17,15 @@ from flask import request, session, redirect, url_for from werkzeug.security import generate_password_hash, check_password_hash from flask_login import login_required, current_user, login_user, logout_user -from web_server.db.db_models import TenantLLM -from web_server.db.services.llm_service import TenantLLMService -from web_server.utils.api_utils import server_error_response, validate_request -from web_server.utils import get_uuid, get_format_time, decrypt, download_img -from web_server.db import UserTenantRole, LLMType -from web_server.settings import RetCode, GITHUB_OAUTH, CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS -from web_server.db.services.user_service import UserService, TenantService, UserTenantService -from web_server.settings import stat_logger -from web_server.utils.api_utils import get_json_result, cors_reponse +from api.db.db_models import TenantLLM +from api.db.services.llm_service import TenantLLMService +from api.utils.api_utils import server_error_response, validate_request +from api.utils import get_uuid, get_format_time, decrypt, download_img +from api.db import UserTenantRole, LLMType +from api.settings import RetCode, GITHUB_OAUTH, CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS +from api.db.services.user_service import UserService, TenantService, UserTenantService +from api.settings import stat_logger +from api.utils.api_utils import get_json_result, cors_reponse @manager.route('/login', methods=['POST', 'GET']) diff --git a/api/db/__init__.py b/api/db/__init__.py index 600c144..8097d55 100644 --- a/api/db/__init__.py +++ b/api/db/__init__.py @@ -1,5 +1,5 @@ # -# Copyright 2019 The RAG Flow Authors. All Rights Reserved. +# Copyright 2019 The InfiniFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -51,4 +51,11 @@ class LLMType(StrEnum): CHAT = 'chat' EMBEDDING = 'embedding' SPEECH2TEXT = 'speech2text' - IMAGE2TEXT = 'image2text' \ No newline at end of file + IMAGE2TEXT = 'image2text' + + +class ChatStyle(StrEnum): + CREATIVE = 'Creative' + PRECISE = 'Precise' + EVENLY = 'Evenly' + CUSTOM = 'Custom' diff --git a/api/db/db_models.py b/api/db/db_models.py index 9b5886b..36a1331 100644 --- a/api/db/db_models.py +++ b/api/db/db_models.py @@ -1,5 +1,5 @@ # -# Copyright 2019 The RAG Flow Authors. All Rights Reserved. +# Copyright 2019 The InfiniFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -29,10 +29,10 @@ from peewee import ( ) from playhouse.pool import PooledMySQLDatabase -from web_server.db import SerializedType -from web_server.settings import DATABASE, stat_logger, SECRET_KEY -from web_server.utils.log_utils import getLogger -from web_server import utils +from api.db import SerializedType +from api.settings import DATABASE, stat_logger, SECRET_KEY +from api.utils.log_utils import getLogger +from api import utils LOGGER = getLogger() @@ -467,6 +467,8 @@ class Knowledgebase(DataBaseModel): doc_num = IntegerField(default=0) token_num = IntegerField(default=0) chunk_num = IntegerField(default=0) + similarity_threshold = FloatField(default=0.4) + vector_similarity_weight = FloatField(default=0.3) parser_id = CharField(max_length=32, null=False, help_text="default parser ID") status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1") @@ -516,19 +518,20 @@ class Dialog(DataBaseModel): prompt_type = CharField(max_length=16, null=False, default="simple", help_text="simple|advanced") prompt_config = JSONField(null=False, default={"system": "", "prologue": "您好,ć‘ćŻć‚¨çš„助手小樱,长得可ç±ĺŹĺ–„良,can I help you?", "parameters": [], "empty_response": "Sorry! 知识库ä¸ćśŞć‰ľĺ°ç›¸ĺ…łĺ†…容ďĽ"}) + kb_ids = JSONField(null=False, default=[]) status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1") class Meta: db_table = "dialog" -class DialogKb(DataBaseModel): - dialog_id = CharField(max_length=32, null=False, index=True) - kb_id = CharField(max_length=32, null=False) - - class Meta: - db_table = "dialog_kb" - primary_key = CompositeKey('dialog_id', 'kb_id') +# class DialogKb(DataBaseModel): +# dialog_id = CharField(max_length=32, null=False, index=True) +# kb_id = CharField(max_length=32, null=False) +# +# class Meta: +# db_table = "dialog_kb" +# primary_key = CompositeKey('dialog_id', 'kb_id') class Conversation(DataBaseModel): diff --git a/api/db/db_services.py b/api/db/db_services.py index 9dc945a..9e734f5 100644 --- a/api/db/db_services.py +++ b/api/db/db_services.py @@ -1,5 +1,5 @@ # -# Copyright 2021 The RAG Flow Authors. All Rights Reserved. +# Copyright 2021 The InfiniFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -19,10 +19,10 @@ import time from functools import wraps from shortuuid import ShortUUID -from web_server.versions import get_rag_version +from api.versions import get_rag_version -from web_server.errors.error_services import * -from web_server.settings import ( +from api.errors.error_services import * +from api.settings import ( GRPC_PORT, HOST, HTTP_PORT, RANDOM_INSTANCE_ID, stat_logger, ) diff --git a/api/db/db_utils.py b/api/db/db_utils.py index cdae322..47f8788 100644 --- a/api/db/db_utils.py +++ b/api/db/db_utils.py @@ -1,5 +1,5 @@ # -# Copyright 2019 The RAG Flow Authors. All Rights Reserved. +# Copyright 2019 The InfiniFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,11 +17,11 @@ import operator from functools import reduce from typing import Dict, Type, Union -from web_server.utils import current_timestamp, timestamp_to_date +from api.utils import current_timestamp, timestamp_to_date -from web_server.db.db_models import DB, DataBaseModel -from web_server.db.runtime_config import RuntimeConfig -from web_server.utils.log_utils import getLogger +from api.db.db_models import DB, DataBaseModel +from api.db.runtime_config import RuntimeConfig +from api.utils.log_utils import getLogger from enum import Enum @@ -123,9 +123,3 @@ def query_db(model: Type[DataBaseModel], limit: int = 0, offset: int = 0, data = data.offset(offset) return list(data), count - - -class StatusEnum(Enum): - # ć ·ćś¬ĺŹŻç”¨çŠ¶ć€ - VALID = "1" - IN_VALID = "0" \ No newline at end of file diff --git a/api/db/init_data.py b/api/db/init_data.py index f9a49ff..96d3f99 100644 --- a/api/db/init_data.py +++ b/api/db/init_data.py @@ -1,5 +1,5 @@ # -# Copyright 2019 The RAG Flow Authors. All Rights Reserved. +# Copyright 2019 The InfiniFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,10 +16,10 @@ import time import uuid -from web_server.db import LLMType -from web_server.db.db_models import init_database_tables as init_web_db -from web_server.db.services import UserService -from web_server.db.services.llm_service import LLMFactoriesService, LLMService +from api.db import LLMType +from api.db.db_models import init_database_tables as init_web_db +from api.db.services import UserService +from api.db.services.llm_service import LLMFactoriesService, LLMService def init_superuser(): diff --git a/api/db/operatioins.py b/api/db/operatioins.py index 3c04fec..79f2e3f 100644 --- a/api/db/operatioins.py +++ b/api/db/operatioins.py @@ -1,5 +1,5 @@ # -# Copyright 2019 The RAG Flow Authors. All Rights Reserved. +# Copyright 2019 The InfiniFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,5 +17,5 @@ import operator import time import typing -from web_server.utils.log_utils import sql_logger +from api.utils.log_utils import sql_logger import peewee \ No newline at end of file diff --git a/api/db/reload_config_base.py b/api/db/reload_config_base.py index 746407d..810a949 100644 --- a/api/db/reload_config_base.py +++ b/api/db/reload_config_base.py @@ -1,5 +1,5 @@ # -# Copyright 2019 The RAG Flow Authors. All Rights Reserved. +# Copyright 2019 The InfiniFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/api/db/runtime_config.py b/api/db/runtime_config.py index 1b71fe6..b977cb1 100644 --- a/api/db/runtime_config.py +++ b/api/db/runtime_config.py @@ -1,5 +1,5 @@ # -# Copyright 2019 The RAG Flow Authors. All Rights Reserved. +# Copyright 2019 The InfiniFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # -from web_server.versions import get_versions +from api.versions import get_versions from .reload_config_base import ReloadConfigBase diff --git a/api/db/services/__init__.py b/api/db/services/__init__.py index 9861347..dbcfe12 100644 --- a/api/db/services/__init__.py +++ b/api/db/services/__init__.py @@ -1,5 +1,5 @@ # -# Copyright 2019 The RAG Flow Authors. All Rights Reserved. +# Copyright 2019 The InfiniFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/api/db/services/common_service.py b/api/db/services/common_service.py index 61b90ee..bef0789 100644 --- a/api/db/services/common_service.py +++ b/api/db/services/common_service.py @@ -1,5 +1,5 @@ # -# Copyright 2019 The RAG Flow Authors. All Rights Reserved. +# Copyright 2019 The InfiniFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,8 +17,8 @@ from datetime import datetime import peewee -from web_server.db.db_models import DB -from web_server.utils import datetime_format +from api.db.db_models import DB +from api.utils import datetime_format class CommonService: diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py index a4f9bcb..1885ed1 100644 --- a/api/db/services/dialog_service.py +++ b/api/db/services/dialog_service.py @@ -1,5 +1,5 @@ # -# Copyright 2019 The RAG Flow Authors. All Rights Reserved. +# Copyright 2019 The InfiniFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,14 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # -import peewee -from werkzeug.security import generate_password_hash, check_password_hash - -from web_server.db.db_models import DB, UserTenant -from web_server.db.db_models import Dialog, Conversation, DialogKb -from web_server.db.services.common_service import CommonService -from web_server.utils import get_uuid, get_format_time -from web_server.db.db_utils import StatusEnum +from api.db.db_models import Dialog, Conversation +from api.db.services.common_service import CommonService class DialogService(CommonService): @@ -29,7 +23,3 @@ class DialogService(CommonService): class ConversationService(CommonService): model = Conversation - - -class DialogKbService(CommonService): - model = DialogKb diff --git a/api/db/services/document_service.py b/api/db/services/document_service.py index 519c7cb..6b66c14 100644 --- a/api/db/services/document_service.py +++ b/api/db/services/document_service.py @@ -1,5 +1,5 @@ # -# Copyright 2019 The RAG Flow Authors. All Rights Reserved. +# Copyright 2019 The InfiniFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,12 +15,12 @@ # from peewee import Expression -from web_server.db import TenantPermission, FileType -from web_server.db.db_models import DB, Knowledgebase, Tenant -from web_server.db.db_models import Document -from web_server.db.services.common_service import CommonService -from web_server.db.services.kb_service import KnowledgebaseService -from web_server.db.db_utils import StatusEnum +from api.db import TenantPermission, FileType +from api.db.db_models import DB, Knowledgebase, Tenant +from api.db.db_models import Document +from api.db.services.common_service import CommonService +from api.db.services.kb_service import KnowledgebaseService +from api.db import StatusEnum class DocumentService(CommonService): diff --git a/api/db/services/kb_service.py b/api/db/services/kb_service.py index 1f25d6b..d0c127a 100644 --- a/api/db/services/kb_service.py +++ b/api/db/services/kb_service.py @@ -1,5 +1,5 @@ # -# Copyright 2019 The RAG Flow Authors. All Rights Reserved. +# Copyright 2019 The InfiniFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,15 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # -import peewee -from werkzeug.security import generate_password_hash, check_password_hash -from web_server.db import TenantPermission -from web_server.db.db_models import DB, UserTenant, Tenant -from web_server.db.db_models import Knowledgebase -from web_server.db.services.common_service import CommonService -from web_server.utils import get_uuid, get_format_time -from web_server.db.db_utils import StatusEnum +from api.db import TenantPermission +from api.db.db_models import DB, Tenant +from api.db.db_models import Knowledgebase +from api.db.services.common_service import CommonService +from api.db import StatusEnum class KnowledgebaseService(CommonService): diff --git a/api/db/services/knowledgebase_service.py b/api/db/services/knowledgebase_service.py index bd7734d..4d98dcf 100644 --- a/api/db/services/knowledgebase_service.py +++ b/api/db/services/knowledgebase_service.py @@ -1,5 +1,5 @@ # -# Copyright 2019 The RAG Flow Authors. All Rights Reserved. +# Copyright 2019 The InfiniFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,14 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # -import peewee -from werkzeug.security import generate_password_hash, check_password_hash - -from web_server.db.db_models import DB, UserTenant -from web_server.db.db_models import Knowledgebase, Document -from web_server.db.services.common_service import CommonService -from web_server.utils import get_uuid, get_format_time -from web_server.db.db_utils import StatusEnum +from api.db.db_models import Knowledgebase, Document +from api.db.services.common_service import CommonService class KnowledgebaseService(CommonService): diff --git a/api/db/services/llm_service.py b/api/db/services/llm_service.py index 674f7d0..e406bef 100644 --- a/api/db/services/llm_service.py +++ b/api/db/services/llm_service.py @@ -1,5 +1,5 @@ # -# Copyright 2019 The RAG Flow Authors. All Rights Reserved. +# Copyright 2019 The InfiniFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,15 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # -import peewee -from werkzeug.security import generate_password_hash, check_password_hash - from rag.llm import EmbeddingModel, CvModel -from web_server.db import LLMType -from web_server.db.db_models import DB, UserTenant -from web_server.db.db_models import LLMFactories, LLM, TenantLLM -from web_server.db.services.common_service import CommonService -from web_server.db.db_utils import StatusEnum +from api.db import LLMType +from api.db.db_models import DB, UserTenant +from api.db.db_models import LLMFactories, LLM, TenantLLM +from api.db.services.common_service import CommonService +from api.db import StatusEnum class LLMFactoriesService(CommonService): diff --git a/api/db/services/user_service.py b/api/db/services/user_service.py index 6641335..87b510a 100644 --- a/api/db/services/user_service.py +++ b/api/db/services/user_service.py @@ -1,5 +1,5 @@ # -# Copyright 2019 The RAG Flow Authors. All Rights Reserved. +# Copyright 2019 The InfiniFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,12 +16,12 @@ import peewee from werkzeug.security import generate_password_hash, check_password_hash -from web_server.db import UserTenantRole -from web_server.db.db_models import DB, UserTenant -from web_server.db.db_models import User, Tenant -from web_server.db.services.common_service import CommonService -from web_server.utils import get_uuid, get_format_time -from web_server.db.db_utils import StatusEnum +from api.db import UserTenantRole +from api.db.db_models import DB, UserTenant +from api.db.db_models import User, Tenant +from api.db.services.common_service import CommonService +from api.utils import get_uuid, get_format_time +from api.db import StatusEnum class UserService(CommonService): diff --git a/api/errors/error_services.py b/api/errors/error_services.py index 98350ce..73b383b 100644 --- a/api/errors/error_services.py +++ b/api/errors/error_services.py @@ -1,4 +1,4 @@ -from web_server.errors import RagFlowError +from api.errors import RagFlowError __all__ = ['ServicesError', 'ServiceNotSupported', 'ZooKeeperNotConfigured', 'MissingZooKeeperUsernameOrPassword', 'ZooKeeperBackendError'] diff --git a/api/errors/general_error.py b/api/errors/general_error.py index 48fddd7..dba1915 100644 --- a/api/errors/general_error.py +++ b/api/errors/general_error.py @@ -1,5 +1,5 @@ # -# Copyright 2019 The RAG Flow Authors. All Rights Reserved. +# Copyright 2019 The InfiniFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/api/hook/__init__.py b/api/hook/__init__.py index 3c21c07..9cbec58 100644 --- a/api/hook/__init__.py +++ b/api/hook/__init__.py @@ -1,8 +1,8 @@ import importlib -from web_server.hook.common.parameters import SignatureParameters, AuthenticationParameters, \ +from api.hook.common.parameters import SignatureParameters, AuthenticationParameters, \ SignatureReturn, AuthenticationReturn, PermissionReturn, ClientAuthenticationReturn, ClientAuthenticationParameters -from web_server.settings import HOOK_MODULE, stat_logger,RetCode +from api.settings import HOOK_MODULE, stat_logger,RetCode class HookManager: diff --git a/api/hook/api/client_authentication.py b/api/hook/api/client_authentication.py index 99e9389..da126af 100644 --- a/api/hook/api/client_authentication.py +++ b/api/hook/api/client_authentication.py @@ -1,10 +1,10 @@ import requests -from web_server.db.service_registry import ServiceRegistry -from web_server.settings import RegistryServiceName -from web_server.hook import HookManager -from web_server.hook.common.parameters import ClientAuthenticationParameters, ClientAuthenticationReturn -from web_server.settings import HOOK_SERVER_NAME +from api.db.service_registry import ServiceRegistry +from api.settings import RegistryServiceName +from api.hook import HookManager +from api.hook.common.parameters import ClientAuthenticationParameters, ClientAuthenticationReturn +from api.settings import HOOK_SERVER_NAME @HookManager.register_client_authentication_hook diff --git a/api/hook/api/permission.py b/api/hook/api/permission.py index 318173d..76bfc5f 100644 --- a/api/hook/api/permission.py +++ b/api/hook/api/permission.py @@ -1,10 +1,10 @@ import requests -from web_server.db.service_registry import ServiceRegistry -from web_server.settings import RegistryServiceName -from web_server.hook import HookManager -from web_server.hook.common.parameters import PermissionCheckParameters, PermissionReturn -from web_server.settings import HOOK_SERVER_NAME +from api.db.service_registry import ServiceRegistry +from api.settings import RegistryServiceName +from api.hook import HookManager +from api.hook.common.parameters import PermissionCheckParameters, PermissionReturn +from api.settings import HOOK_SERVER_NAME @HookManager.register_permission_check_hook diff --git a/api/hook/api/site_authentication.py b/api/hook/api/site_authentication.py index bea3b77..7c751d9 100644 --- a/api/hook/api/site_authentication.py +++ b/api/hook/api/site_authentication.py @@ -1,11 +1,11 @@ import requests -from web_server.db.service_registry import ServiceRegistry -from web_server.settings import RegistryServiceName -from web_server.hook import HookManager -from web_server.hook.common.parameters import SignatureParameters, AuthenticationParameters, AuthenticationReturn,\ +from api.db.service_registry import ServiceRegistry +from api.settings import RegistryServiceName +from api.hook import HookManager +from api.hook.common.parameters import SignatureParameters, AuthenticationParameters, AuthenticationReturn,\ SignatureReturn -from web_server.settings import HOOK_SERVER_NAME, PARTY_ID +from api.settings import HOOK_SERVER_NAME, PARTY_ID @HookManager.register_site_signature_hook diff --git a/api/hook/common/parameters.py b/api/hook/common/parameters.py index 40ce4ef..22893e8 100644 --- a/api/hook/common/parameters.py +++ b/api/hook/common/parameters.py @@ -1,4 +1,4 @@ -from web_server.settings import RetCode +from api.settings import RetCode class ParametersBase: diff --git a/api/ragflow_server.py b/api/ragflow_server.py index 5f1a374..838f5a1 100644 --- a/api/ragflow_server.py +++ b/api/ragflow_server.py @@ -1,5 +1,5 @@ # -# Copyright 2019 The RAG Flow Authors. All Rights Reserved. +# Copyright 2019 The InfiniFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -23,17 +23,17 @@ import traceback from werkzeug.serving import run_simple -from web_server.apps import app -from web_server.db.runtime_config import RuntimeConfig -from web_server.hook import HookManager -from web_server.settings import ( +from api.apps import app +from api.db.runtime_config import RuntimeConfig +from api.hook import HookManager +from api.settings import ( HOST, HTTP_PORT, access_logger, database_logger, stat_logger, ) -from web_server import utils +from api import utils -from web_server.db.db_models import init_database_tables as init_web_db -from web_server.db.init_data import init_web_data -from web_server.versions import get_versions +from api.db.db_models import init_database_tables as init_web_db +from api.db.init_data import init_web_data +from api.versions import get_versions if __name__ == '__main__': stat_logger.info( diff --git a/api/settings.py b/api/settings.py index 22f289c..bad8650 100644 --- a/api/settings.py +++ b/api/settings.py @@ -1,5 +1,5 @@ # -# Copyright 2019 The RAG Flow Authors. All Rights Reserved. +# Copyright 2019 The InfiniFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,9 +17,9 @@ import os from enum import IntEnum, Enum -from web_server.utils import get_base_config,decrypt_database_config -from web_server.utils.file_utils import get_project_base_directory -from web_server.utils.log_utils import LoggerFactory, getLogger +from api.utils import get_base_config,decrypt_database_config +from api.utils.file_utils import get_project_base_directory +from api.utils.log_utils import LoggerFactory, getLogger # Server @@ -71,7 +71,7 @@ PROXY_PROTOCOL = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("protocol") DATABASE = decrypt_database_config() # Logger -LoggerFactory.set_directory(os.path.join(get_project_base_directory(), "logs", "web_server")) +LoggerFactory.set_directory(os.path.join(get_project_base_directory(), "logs", "api")) # {CRITICAL: 50, FATAL:50, ERROR:40, WARNING:30, WARN:30, INFO:20, DEBUG:10, NOTSET:0} LoggerFactory.LEVEL = 10 diff --git a/api/utils/__init__.py b/api/utils/__init__.py index 5347269..4bd9c35 100644 --- a/api/utils/__init__.py +++ b/api/utils/__init__.py @@ -1,5 +1,5 @@ # -# Copyright 2019 The RAG Flow Authors. All Rights Reserved. +# Copyright 2019 The InfiniFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/api/utils/api_utils.py b/api/utils/api_utils.py index e3031fb..fd63f7a 100644 --- a/api/utils/api_utils.py +++ b/api/utils/api_utils.py @@ -1,5 +1,5 @@ # -# Copyright 2019 The RAG Flow Authors. All Rights Reserved. +# Copyright 2019 The InfiniFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -24,16 +24,16 @@ from flask import ( ) from werkzeug.http import HTTP_STATUS_CODES -from web_server.utils import json_dumps -from web_server.versions import get_rag_version -from web_server.settings import RetCode -from web_server.settings import ( +from api.utils import json_dumps +from api.versions import get_rag_version +from api.settings import RetCode +from api.settings import ( REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC, stat_logger,CLIENT_AUTHENTICATION, HTTP_APP_KEY, SECRET_KEY ) import requests import functools -from web_server.utils import CustomJSONEncoder +from api.utils import CustomJSONEncoder from uuid import uuid1 from base64 import b64encode from hmac import HMAC diff --git a/api/utils/file_utils.py b/api/utils/file_utils.py index 60c5b1f..92f0a9c 100644 --- a/api/utils/file_utils.py +++ b/api/utils/file_utils.py @@ -1,5 +1,5 @@ # -# Copyright 2019 The RAG Flow Authors. All Rights Reserved. +# Copyright 2019 The InfiniFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -21,7 +21,7 @@ import re from cachetools import LRUCache, cached from ruamel.yaml import YAML -from web_server.db import FileType +from api.db import FileType PROJECT_BASE = os.getenv("RAG_PROJECT_BASE") or os.getenv("RAG_DEPLOY_BASE") RAG_BASE = os.getenv("RAG_BASE") diff --git a/api/utils/log_utils.py b/api/utils/log_utils.py index 45d25b2..2ebf2e5 100644 --- a/api/utils/log_utils.py +++ b/api/utils/log_utils.py @@ -1,5 +1,5 @@ # -# Copyright 2019 The RAG Flow Authors. All Rights Reserved. +# Copyright 2019 The InfiniFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -21,7 +21,7 @@ import inspect from logging.handlers import TimedRotatingFileHandler from threading import RLock -from web_server.utils import file_utils +from api.utils import file_utils class LoggerFactory(object): TYPE = "FILE" diff --git a/api/utils/t_crypt.py b/api/utils/t_crypt.py index 1d007f4..224bf22 100644 --- a/api/utils/t_crypt.py +++ b/api/utils/t_crypt.py @@ -1,7 +1,7 @@ import base64, os, sys from Cryptodome.PublicKey import RSA from Cryptodome.Cipher import PKCS1_v1_5 as Cipher_pkcs1_v1_5 -from web_server.utils import decrypt, file_utils +from api.utils import decrypt, file_utils def crypt(line): file_path = os.path.join(file_utils.get_project_base_directory(), "conf", "public.pem") diff --git a/api/versions.py b/api/versions.py index cdf1041..5d92988 100644 --- a/api/versions.py +++ b/api/versions.py @@ -1,5 +1,5 @@ # -# Copyright 2019 The RAG Flow Authors. All Rights Reserved. +# Copyright 2019 The InfiniFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,7 +18,7 @@ import os import dotenv import typing -from web_server.utils.file_utils import get_project_base_directory +from api.utils.file_utils import get_project_base_directory def get_versions() -> typing.Mapping[str, typing.Any]: diff --git a/rag/llm/__init__.py b/rag/llm/__init__.py index b14a9db..2671e1d 100644 --- a/rag/llm/__init__.py +++ b/rag/llm/__init__.py @@ -1,5 +1,5 @@ # -# Copyright 2019 The RAG Flow Authors. All Rights Reserved. +# Copyright 2019 The InfiniFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py index 5564b9c..06ac625 100644 --- a/rag/llm/chat_model.py +++ b/rag/llm/chat_model.py @@ -1,5 +1,5 @@ # -# Copyright 2019 The RAG Flow Authors. All Rights Reserved. +# Copyright 2019 The InfiniFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/rag/llm/cv_model.py b/rag/llm/cv_model.py index b3ec202..bc923a5 100644 --- a/rag/llm/cv_model.py +++ b/rag/llm/cv_model.py @@ -1,5 +1,5 @@ # -# Copyright 2019 The RAG Flow Authors. All Rights Reserved. +# Copyright 2019 The InfiniFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/rag/llm/embedding_model.py b/rag/llm/embedding_model.py index 2060872..e70dc6c 100644 --- a/rag/llm/embedding_model.py +++ b/rag/llm/embedding_model.py @@ -1,5 +1,5 @@ # -# Copyright 2019 The RAG Flow Authors. All Rights Reserved. +# Copyright 2019 The InfiniFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -60,6 +60,10 @@ class HuEmbedding(Base): res.extend(self.model.encode(texts[i:i + batch_size]).tolist()) return np.array(res), token_count + def encode_queries(self, text: str): + token_count = num_tokens_from_string(text) + return self.model.encode_queries([text]).tolist()[0], token_count + class OpenAIEmbed(Base): def __init__(self, key, model_name="text-embedding-ada-002"): diff --git a/rag/nlp/huqie.py b/rag/nlp/huqie.py index f2df67b..2bdcaf9 100644 --- a/rag/nlp/huqie.py +++ b/rag/nlp/huqie.py @@ -9,7 +9,7 @@ import string import sys from hanziconv import HanziConv -from web_server.utils.file_utils import get_project_base_directory +from api.utils.file_utils import get_project_base_directory class Huqie: diff --git a/rag/nlp/query.py b/rag/nlp/query.py index de1edd7..17364b3 100644 --- a/rag/nlp/query.py +++ b/rag/nlp/query.py @@ -147,7 +147,7 @@ class EsQueryer: atks = toDict(atks) btkss = [toDict(tks) for tks in btkss] tksim = [self.similarity(atks, btks) for btks in btkss] - return np.array(sims[0]) * vtweight + np.array(tksim) * tkweight + return np.array(sims[0]) * vtweight + np.array(tksim) * tkweight, sims[0], tksim def similarity(self, qtwt, dtwt): if isinstance(dtwt, type("")): diff --git a/rag/nlp/search.py b/rag/nlp/search.py index 442812c..66e3fe5 100644 --- a/rag/nlp/search.py +++ b/rag/nlp/search.py @@ -15,7 +15,7 @@ def index_name(uid): return f"ragflow_{uid}" class Dealer: - def __init__(self, es, emb_mdl): + def __init__(self, es): self.qryr = query.EsQueryer(es) self.qryr.flds = [ "title_tks^10", @@ -23,7 +23,6 @@ class Dealer: "content_ltks^2", "content_sm_ltks"] self.es = es - self.emb_mdl = emb_mdl @dataclass class SearchResult: @@ -36,23 +35,26 @@ class Dealer: keywords: Optional[List[str]] = None group_docs: List[List] = None - def _vector(self, txt, sim=0.8, topk=10): - qv = self.emb_mdl.encode_queries(txt) + def _vector(self, txt, emb_mdl, sim=0.8, topk=10): + qv, c = emb_mdl.encode_queries(txt) return { "field": "q_%d_vec"%len(qv), "k": topk, "similarity": sim, - "num_candidates": 1000, + "num_candidates": topk*2, "query_vector": qv } - def search(self, req, idxnm, tks_num=3): + def search(self, req, idxnm, emb_mdl=None): qst = req.get("question", "") bqry, keywords = self.qryr.question(qst) if req.get("kb_ids"): bqry.filter.append(Q("terms", kb_id=req["kb_ids"])) if req.get("doc_ids"): bqry.filter.append(Q("terms", doc_id=req["doc_ids"])) + if "available_int" in req: + if req["available_int"] == 0: bqry.filter.append(Q("range", available_int={"lt": 1})) + else: bqry.filter.append(Q("bool", must_not=Q("range", available_int={"lt": 1}))) bqry.boost = 0.05 s = Search() @@ -60,7 +62,7 @@ class Dealer: ps = int(req.get("size", 1000)) src = req.get("fields", ["docnm_kwd", "content_ltks", "kb_id","img_id", "image_id", "doc_id", "q_512_vec", "q_768_vec", - "q_1024_vec", "q_1536_vec"]) + "q_1024_vec", "q_1536_vec", "available_int"]) s = s.query(bqry)[pg * ps:(pg + 1) * ps] s = s.highlight("content_ltks") @@ -80,7 +82,8 @@ class Dealer: s = s.to_dict() q_vec = [] if req.get("vector"): - s["knn"] = self._vector(qst, req.get("similarity", 0.4), ps) + assert emb_mdl, "No embedding model selected" + s["knn"] = self._vector(qst, emb_mdl, req.get("similarity", 0.4), ps) s["knn"]["filter"] = bqry.to_dict() if "highlight" in s: del s["highlight"] q_vec = s["knn"]["query_vector"] @@ -168,7 +171,7 @@ class Dealer: def trans2floats(txt): return [float(t) for t in txt.split("\t")] - def insert_citations(self, ans, top_idx, sres, + def insert_citations(self, ans, top_idx, sres, emb_mdl, vfield="q_vec", cfield="content_ltks"): ins_embd = [Dealer.trans2floats( @@ -179,15 +182,14 @@ class Dealer: res = "" def citeit(): - nonlocal s, e, ans, res + nonlocal s, e, ans, res, emb_mdl if not ins_embd: return - embd = self.emb_mdl.encode(ans[s: e]) + embd = emb_mdl.encode(ans[s: e]) sim = self.qryr.hybrid_similarity(embd, ins_embd, huqie.qie(ans[s:e]).split(" "), ins_tw) - print(ans[s: e], sim) mx = np.max(sim) * 0.99 if mx < 0.55: return @@ -225,20 +227,18 @@ class Dealer: return res - def rerank(self, sres, query, tkweight=0.3, vtweight=0.7, - vfield="q_vec", cfield="content_ltks"): + def rerank(self, sres, query, tkweight=0.3, vtweight=0.7, cfield="content_ltks"): ins_embd = [ Dealer.trans2floats( - sres.field[i]["q_vec"]) for i in sres.ids] + sres.field[i]["q_%d_vec"%len(sres.query_vector)]) for i in sres.ids] if not ins_embd: return [] ins_tw = [sres.field[i][cfield].split(" ") for i in sres.ids] - # return CosineSimilarity([sres.query_vector], ins_embd)[0] - sim = self.qryr.hybrid_similarity(sres.query_vector, + sim, tksim, vtsim = self.qryr.hybrid_similarity(sres.query_vector, ins_embd, huqie.qie(query).split(" "), ins_tw, tkweight, vtweight) - return sim + return sim, tksim, vtsim diff --git a/rag/nlp/synonym.py b/rag/nlp/synonym.py index 895fab3..e358c91 100644 --- a/rag/nlp/synonym.py +++ b/rag/nlp/synonym.py @@ -4,7 +4,7 @@ import time import logging import re -from web_server.utils.file_utils import get_project_base_directory +from api.utils.file_utils import get_project_base_directory class Dealer: diff --git a/rag/nlp/term_weight.py b/rag/nlp/term_weight.py index 752bbf3..3c60808 100644 --- a/rag/nlp/term_weight.py +++ b/rag/nlp/term_weight.py @@ -5,7 +5,7 @@ import re import os import numpy as np from rag.nlp import huqie -from web_server.utils.file_utils import get_project_base_directory +from api.utils.file_utils import get_project_base_directory class Dealer: diff --git a/rag/settings.py b/rag/settings.py index 43deedf..9180dd3 100644 --- a/rag/settings.py +++ b/rag/settings.py @@ -1,5 +1,5 @@ # -# Copyright 2019 The RAG Flow Authors. All Rights Reserved. +# Copyright 2019 The InfiniFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,9 +14,9 @@ # limitations under the License. # import os -from web_server.utils import get_base_config,decrypt_database_config -from web_server.utils.file_utils import get_project_base_directory -from web_server.utils.log_utils import LoggerFactory, getLogger +from api.utils import get_base_config,decrypt_database_config +from api.utils.file_utils import get_project_base_directory +from api.utils.log_utils import LoggerFactory, getLogger # Server diff --git a/rag/svr/parse_user_docs.py b/rag/svr/parse_user_docs.py index 0de5d03..1712106 100644 --- a/rag/svr/parse_user_docs.py +++ b/rag/svr/parse_user_docs.py @@ -1,5 +1,5 @@ # -# Copyright 2019 The RAG Flow Authors. All Rights Reserved. +# Copyright 2019 The InfiniFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -47,12 +47,12 @@ from rag.nlp.huchunk import ( PptChunker, TextChunker ) -from web_server.db import LLMType -from web_server.db.services.document_service import DocumentService -from web_server.db.services.llm_service import TenantLLMService -from web_server.settings import database_logger -from web_server.utils import get_format_time -from web_server.utils.file_utils import get_project_base_directory +from api.db import LLMType +from api.db.services.document_service import DocumentService +from api.db.services.llm_service import TenantLLMService +from api.settings import database_logger +from api.utils import get_format_time +from api.utils.file_utils import get_project_base_directory BATCH_SIZE = 64 @@ -257,7 +257,6 @@ def main(comm, mod): cron_logger.error(str(e)) continue - set_progress(r["id"], random.randint(70, 95) / 100., "Finished embedding! Start to build index!") init_kb(r) diff --git a/rag/utils/es_conn.py b/rag/utils/es_conn.py index e036ca9..0c7f2da 100644 --- a/rag/utils/es_conn.py +++ b/rag/utils/es_conn.py @@ -66,7 +66,6 @@ class HuEs: body=d, id=id, refresh=False, - doc_type="_doc", retry_on_conflict=100) es_logger.info("Successfully upsert: %s" % id) T = True -- GitLab