diff --git a/rag/llm/embedding_model.py b/rag/llm/embedding_model.py index 551f9c60a903610166f3993f8cf5d54906bc44df..e148560c2f5c8b0182a33418e86e0b5ae8f4adf1 100644 --- a/rag/llm/embedding_model.py +++ b/rag/llm/embedding_model.py @@ -35,7 +35,7 @@ class Base(ABC): class HuEmbedding(Base): - def __init__(self): + def __init__(self, key="", model_name=""): """ If you have trouble downloading HuggingFace models, -_^ this might help!! diff --git a/rag/nlp/huchunk.py b/rag/nlp/huchunk.py index fb28a915cf7766cfb9c293e4362a9df473c557bd..cc93f5faf2cf429a6f96eb1acaa2d43b41cd3228 100644 --- a/rag/nlp/huchunk.py +++ b/rag/nlp/huchunk.py @@ -411,9 +411,12 @@ class TextChunker(HuChunker): flds = self.Fields() if self.is_binary_file(fnm): return flds - with open(fnm, "r") as f: - txt = f.read() - flds.text_chunks = [(c, None) for c in self.naive_text_chunk(txt)] + txt = "" + if isinstance(fnm, str): + with open(fnm, "r") as f: + txt = f.read() + else: txt = fnm.decode("utf-8") + flds.text_chunks = [(c, None) for c in self.naive_text_chunk(txt)] flds.table_chunks = [] return flds diff --git a/rag/nlp/search.py b/rag/nlp/search.py index 05ce6276fbd9b7877c70e50627c6d35946efd320..d79640b4c1e60ea94fc63a6bfca8ae471b98edf0 100644 --- a/rag/nlp/search.py +++ b/rag/nlp/search.py @@ -8,7 +8,7 @@ from rag.nlp import huqie, query import numpy as np -def index_name(uid): return f"docgpt_{uid}" +def index_name(uid): return f"ragflow_{uid}" class Dealer: diff --git a/rag/svr/parse_user_docs.py b/rag/svr/parse_user_docs.py index 29d2c28764e93d654330478dd2cd4d7f52ce1808..188662e6b2a31c2f4e045ba6e2da9242f088c495 100644 --- a/rag/svr/parse_user_docs.py +++ b/rag/svr/parse_user_docs.py @@ -14,6 +14,7 @@ # limitations under the License. # import json +import logging import os import hashlib import copy @@ -24,9 +25,10 @@ from timeit import default_timer as timer from rag.llm import EmbeddingModel, CvModel from rag.settings import cron_logger, DOC_MAXIMUM_SIZE -from rag.utils import ELASTICSEARCH, num_tokens_from_string +from rag.utils import ELASTICSEARCH from rag.utils import MINIO -from rag.utils import rmSpace, findMaxDt +from rag.utils import rmSpace, findMaxTm + from rag.nlp import huchunk, huqie, search from io import BytesIO import pandas as pd @@ -47,6 +49,7 @@ from rag.nlp.huchunk import ( from web_server.db import LLMType from web_server.db.services.document_service import DocumentService from web_server.db.services.llm_service import TenantLLMService +from web_server.settings import database_logger from web_server.utils import get_format_time from web_server.utils.file_utils import get_project_base_directory @@ -83,7 +86,7 @@ def collect(comm, mod, tm): if len(docs) == 0: return pd.DataFrame() docs = pd.DataFrame(docs) - mtm = str(docs["update_time"].max())[:19] + mtm = docs["update_time"].max() cron_logger.info("TOTAL:{}, To:{}".format(len(docs), mtm)) return docs @@ -99,11 +102,12 @@ def set_progress(docid, prog, msg="Processing...", begin=False): cron_logger.error("set_progress:({}), {}".format(docid, str(e))) -def build(row): +def build(row, cvmdl): if row["size"] > DOC_MAXIMUM_SIZE: set_progress(row["id"], -1, "File size exceeds( <= %dMb )" % (int(DOC_MAXIMUM_SIZE / 1024 / 1024))) return [] + res = ELASTICSEARCH.search(Q("term", doc_id=row["id"])) if ELASTICSEARCH.getTotal(res) > 0: ELASTICSEARCH.updateScriptByQuery(Q("term", doc_id=row["id"]), @@ -120,7 +124,8 @@ def build(row): set_progress(row["id"], random.randint(0, 20) / 100., "Finished preparing! Start to slice file!", True) try: - obj = chuck_doc(row["name"], MINIO.get(row["kb_id"], row["location"])) + cron_logger.info("Chunkking {}/{}".format(row["location"], row["name"])) + obj = chuck_doc(row["name"], MINIO.get(row["kb_id"], row["location"]), cvmdl) except Exception as e: if re.search("(No such file|not found)", str(e)): set_progress( @@ -131,6 +136,9 @@ def build(row): row["id"], -1, f"Internal server error: %s" % str(e).replace( "'", "")) + + cron_logger.warn("Chunkking {}/{}: {}".format(row["location"], row["name"], str(e))) + return [] if not obj.text_chunks and not obj.table_chunks: @@ -144,7 +152,7 @@ def build(row): "Finished slicing files. Start to embedding the content.") doc = { - "doc_id": row["did"], + "doc_id": row["id"], "kb_id": [str(row["kb_id"])], "docnm_kwd": os.path.split(row["location"])[-1], "title_tks": huqie.qie(row["name"]), @@ -164,10 +172,10 @@ def build(row): docs.append(d) continue - if isinstance(img, Image): - img.save(output_buffer, format='JPEG') - else: + if isinstance(img, bytes): output_buffer = BytesIO(img) + else: + img.save(output_buffer, format='JPEG') MINIO.put(row["kb_id"], d["_id"], output_buffer.getvalue()) d["img_id"] = "{}-{}".format(row["kb_id"], d["_id"]) @@ -215,15 +223,16 @@ def embedding(docs, mdl): def model_instance(tenant_id, llm_type): - model_config = TenantLLMService.query(tenant_id=tenant_id, model_type=LLMType.EMBEDDING) - if not model_config:return - model_config = model_config[0] + model_config = TenantLLMService.get_api_key(tenant_id, model_type=LLMType.EMBEDDING) + if not model_config: + model_config = {"llm_factory": "local", "api_key": "", "llm_name": ""} + else: model_config = model_config[0].to_dict() if llm_type == LLMType.EMBEDDING: - if model_config.llm_factory not in EmbeddingModel: return - return EmbeddingModel[model_config.llm_factory](model_config.api_key, model_config.llm_name) + if model_config["llm_factory"] not in EmbeddingModel: return + return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"]) if llm_type == LLMType.IMAGE2TEXT: - if model_config.llm_factory not in CvModel: return - return CvModel[model_config.llm_factory](model_config.api_key, model_config.llm_name) + if model_config["llm_factory"] not in CvModel: return + return CvModel[model_config.llm_factory](model_config["api_key"], model_config["llm_name"]) def main(comm, mod): @@ -231,7 +240,7 @@ def main(comm, mod): from rag.llm import HuEmbedding model = HuEmbedding() tm_fnm = os.path.join(get_project_base_directory(), "rag/res", f"{comm}-{mod}.tm") - tm = findMaxDt(tm_fnm) + tm = findMaxTm(tm_fnm) rows = collect(comm, mod, tm) if len(rows) == 0: return @@ -247,7 +256,7 @@ def main(comm, mod): st_tm = timer() cks = build(r, cv_mdl) if not cks: - tmf.write(str(r["updated_at"]) + "\n") + tmf.write(str(r["update_time"]) + "\n") continue # TODO: exception handler ## set_progress(r["did"], -1, "ERROR: ") @@ -268,12 +277,19 @@ def main(comm, mod): cron_logger.error(str(es_r)) else: set_progress(r["id"], 1., "Done!") - DocumentService.update_by_id(r["id"], {"token_num": tk_count, "chunk_num": len(cks), "process_duation": timer()-st_tm}) + DocumentService.increment_chunk_num(r["id"], r["kb_id"], tk_count, len(cks), timer()-st_tm) + cron_logger.info("Chunk doc({}), token({}), chunks({})".format(r["id"], tk_count, len(cks))) + tmf.write(str(r["update_time"]) + "\n") tmf.close() if __name__ == "__main__": + peewee_logger = logging.getLogger('peewee') + peewee_logger.propagate = False + peewee_logger.addHandler(database_logger.handlers[0]) + peewee_logger.setLevel(database_logger.level) + from mpi4py import MPI comm = MPI.COMM_WORLD main(comm.Get_size(), comm.Get_rank()) diff --git a/rag/utils/__init__.py b/rag/utils/__init__.py index d3f1632334f390724cead41352cc14491d24f845..9898d19d5040fe2857bd0c3645488c051ef2267d 100644 --- a/rag/utils/__init__.py +++ b/rag/utils/__init__.py @@ -40,6 +40,25 @@ def findMaxDt(fnm): print("WARNING: can't find " + fnm) return m + +def findMaxTm(fnm): + m = 0 + try: + with open(fnm, "r") as f: + while True: + l = f.readline() + if not l: + break + l = l.strip("\n") + if l == 'nan': + continue + if int(l) > m: + m = int(l) + except Exception as e: + print("WARNING: can't find " + fnm) + return m + + def num_tokens_from_string(string: str) -> int: """Returns the number of tokens in a text string.""" encoding = tiktoken.get_encoding('cl100k_base') diff --git a/rag/utils/es_conn.py b/rag/utils/es_conn.py index f8337c01db0599f65c61556299a20985efcf2f2c..632b01d6eab1c00dbf9c57a708ae678da9de0b4c 100644 --- a/rag/utils/es_conn.py +++ b/rag/utils/es_conn.py @@ -294,6 +294,7 @@ class HuEs: except Exception as e: es_logger.error("ES updateByQuery deleteByQuery: " + str(e) + "ă€Q】:" + str(query.to_dict())) + if str(e).find("NotFoundError") > 0: return True if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0: continue diff --git a/web_server/apps/document_app.py b/web_server/apps/document_app.py index d14d69ab1d67c820c841dd8388b30950bb89622f..9be9cfde9cb73fbc6a9894b2600837fd4a336e4c 100644 --- a/web_server/apps/document_app.py +++ b/web_server/apps/document_app.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import base64 import pathlib from elasticsearch_dsl import Q @@ -195,11 +196,15 @@ def rm(): e, doc = DocumentService.get_by_id(req["doc_id"]) if not e: return get_data_error_result(retmsg="Document not found!") + if not ELASTICSEARCH.deleteByQuery(Q("match", doc_id=doc.id), idxnm=search.index_name(doc.kb_id)): + return get_json_result(data=False, retmsg='Remove from ES failure"', retcode=RetCode.SERVER_ERROR) + + DocumentService.increment_chunk_num(doc.id, doc.kb_id, doc.token_num*-1, doc.chunk_num*-1, 0) if not DocumentService.delete_by_id(req["doc_id"]): return get_data_error_result( retmsg="Database error (Document removal)!") - e, kb = KnowledgebaseService.get_by_id(doc.kb_id) - MINIO.rm(kb.id, doc.location) + + MINIO.rm(doc.kb_id, doc.location) return get_json_result(data=True) except Exception as e: return server_error_response(e) @@ -233,3 +238,43 @@ def rename(): return get_json_result(data=True) except Exception as e: return server_error_response(e) + + +@manager.route('/get', methods=['GET']) +@login_required +def get(): + doc_id = request.args["doc_id"] + try: + e, doc = DocumentService.get_by_id(doc_id) + if not e: + return get_data_error_result(retmsg="Document not found!") + + blob = MINIO.get(doc.kb_id, doc.location) + return get_json_result(data={"base64": base64.b64decode(blob)}) + except Exception as e: + return server_error_response(e) + + +@manager.route('/change_parser', methods=['POST']) +@login_required +@validate_request("doc_id", "parser_id") +def change_parser(): + req = request.json + try: + e, doc = DocumentService.get_by_id(req["doc_id"]) + if not e: + return get_data_error_result(retmsg="Document not found!") + if doc.parser_id.lower() == req["parser_id"].lower(): + return get_json_result(data=True) + + e = DocumentService.update_by_id(doc.id, {"parser_id": req["parser_id"], "progress":0, "progress_msg": ""}) + if not e: + return get_data_error_result(retmsg="Document not found!") + e = DocumentService.increment_chunk_num(doc.id, doc.kb_id, doc.token_num*-1, doc.chunk_num*-1, doc.process_duation*-1) + if not e: + return get_data_error_result(retmsg="Document not found!") + + return get_json_result(data=True) + except Exception as e: + return server_error_response(e) + diff --git a/web_server/apps/kb_app.py b/web_server/apps/kb_app.py index c035cb6375e2e0a03d949f33743e768ca1011523..054f97e000911f6e22c5ff453ff113e27af8d0a9 100644 --- a/web_server/apps/kb_app.py +++ b/web_server/apps/kb_app.py @@ -29,7 +29,7 @@ from web_server.utils.api_utils import get_json_result @manager.route('/create', methods=['post']) @login_required -@validate_request("name", "description", "permission", "embd_id", "parser_id") +@validate_request("name", "description", "permission", "parser_id") def create(): req = request.json req["name"] = req["name"].strip() @@ -46,7 +46,7 @@ def create(): @manager.route('/update', methods=['post']) @login_required -@validate_request("kb_id", "name", "description", "permission", "embd_id", "parser_id") +@validate_request("kb_id", "name", "description", "permission", "parser_id") def update(): req = request.json req["name"] = req["name"].strip() @@ -72,6 +72,18 @@ def update(): return server_error_response(e) +@manager.route('/detail', methods=['GET']) +@login_required +def detail(): + kb_id = request.args["kb_id"] + try: + kb = KnowledgebaseService.get_detail(kb_id) + if not kb: return get_data_error_result(retmsg="Can't find this knowledgebase!") + return get_json_result(data=kb) + except Exception as e: + return server_error_response(e) + + @manager.route('/list', methods=['GET']) @login_required def list(): diff --git a/web_server/apps/llm_app.py b/web_server/apps/llm_app.py new file mode 100644 index 0000000000000000000000000000000000000000..0877a1977fe7d42efa95c63cc28374b1e4ca1e49 --- /dev/null +++ b/web_server/apps/llm_app.py @@ -0,0 +1,95 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from flask import request +from flask_login import login_required, current_user + +from web_server.db.services import duplicate_name +from web_server.db.services.llm_service import LLMFactoriesService, TenantLLMService, LLMService +from web_server.db.services.user_service import TenantService, UserTenantService +from web_server.utils.api_utils import server_error_response, get_data_error_result, validate_request +from web_server.utils import get_uuid, get_format_time +from web_server.db import StatusEnum, UserTenantRole +from web_server.db.services.kb_service import KnowledgebaseService +from web_server.db.db_models import Knowledgebase, TenantLLM +from web_server.settings import stat_logger, RetCode +from web_server.utils.api_utils import get_json_result + + +@manager.route('/factories', methods=['GET']) +@login_required +def factories(): + try: + fac = LLMFactoriesService.get_all() + return get_json_result(data=fac.to_json()) + except Exception as e: + return server_error_response(e) + + +@manager.route('/set_api_key', methods=['POST']) +@login_required +@validate_request("llm_factory", "api_key") +def set_api_key(): + req = request.json + llm = { + "tenant_id": current_user.id, + "llm_factory": req["llm_factory"], + "api_key": req["api_key"] + } + # TODO: Test api_key + for n in ["model_type", "llm_name"]: + if n in req: llm[n] = req[n] + + TenantLLM.insert(**llm).on_conflict("replace").execute() + return get_json_result(data=True) + + +@manager.route('/my_llms', methods=['GET']) +@login_required +def my_llms(): + try: + objs = TenantLLMService.query(tenant_id=current_user.id) + objs = [o.to_dict() for o in objs] + for o in objs: del o["api_key"] + return get_json_result(data=objs) + except Exception as e: + return server_error_response(e) + + +@manager.route('/list', methods=['GET']) +@login_required +def list(): + try: + objs = TenantLLMService.query(tenant_id=current_user.id) + objs = [o.to_dict() for o in objs if o.api_key] + fct = {} + for o in objs: + if o["llm_factory"] not in fct: fct[o["llm_factory"]] = [] + if o["llm_name"]: fct[o["llm_factory"]].append(o["llm_name"]) + + llms = LLMService.get_all() + llms = [m.to_dict() for m in llms if m.status == StatusEnum.VALID.value] + for m in llms: + m["available"] = False + if m["fid"] in fct and (not fct[m["fid"]] or m["llm_name"] in fct[m["fid"]]): + m["available"] = True + res = {} + for m in llms: + if m["fid"] not in res: res[m["fid"]] = [] + res[m["fid"]].append(m) + + return get_json_result(data=res) + except Exception as e: + return server_error_response(e) \ No newline at end of file diff --git a/web_server/apps/user_app.py b/web_server/apps/user_app.py index aa2ba43cec4058d2ff257421947a652777be067f..81946074e6c3b78c55a3fad6724b08ab9c0d3587 100644 --- a/web_server/apps/user_app.py +++ b/web_server/apps/user_app.py @@ -16,9 +16,12 @@ from flask import request, session, redirect, url_for from werkzeug.security import generate_password_hash, check_password_hash from flask_login import login_required, current_user, login_user, logout_user + +from web_server.db.db_models import TenantLLM +from web_server.db.services.llm_service import TenantLLMService from web_server.utils.api_utils import server_error_response, validate_request from web_server.utils import get_uuid, get_format_time, decrypt, download_img -from web_server.db import UserTenantRole +from web_server.db import UserTenantRole, LLMType from web_server.settings import RetCode, GITHUB_OAUTH, CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS from web_server.db.services.user_service import UserService, TenantService, UserTenantService from web_server.settings import stat_logger @@ -47,8 +50,9 @@ def login(): avatar = download_img(userinfo["avatar_url"]) except Exception as e: stat_logger.exception(e) + user_id = get_uuid() try: - users = user_register({ + users = user_register(user_id, { "access_token": session["access_token"], "email": userinfo["email"], "avatar": avatar, @@ -63,6 +67,7 @@ def login(): login_user(user) return cors_reponse(data=user.to_json(), auth=user.get_id(), retmsg="Welcome back!") except Exception as e: + rollback_user_registration(user_id) stat_logger.exception(e) return server_error_response(e) elif not request.json: @@ -162,7 +167,25 @@ def user_info(): return get_json_result(data=current_user.to_dict()) -def user_register(user): +def rollback_user_registration(user_id): + try: + TenantService.delete_by_id(user_id) + except Exception as e: + pass + try: + u = UserTenantService.query(tenant_id=user_id) + if u: + UserTenantService.delete_by_id(u[0].id) + except Exception as e: + pass + try: + TenantLLM.delete().where(TenantLLM.tenant_id==user_id).excute() + except Exception as e: + pass + + +def user_register(user_id, user): + user_id = get_uuid() user["id"] = user_id tenant = { @@ -180,10 +203,12 @@ def user_register(user): "invited_by": user_id, "role": UserTenantRole.OWNER } + tenant_llm = {"tenant_id": user_id, "llm_factory": "OpenAI", "api_key": "infiniflow API Key"} if not UserService.save(**user):return TenantService.save(**tenant) UserTenantService.save(**usr_tenant) + TenantLLMService.save(**tenant_llm) return UserService.query(email=user["email"]) @@ -203,14 +228,17 @@ def user_add(): "last_login_time": get_format_time(), "is_superuser": False, } + + user_id = get_uuid() try: - users = user_register(user_dict) + users = user_register(user_id, user_dict) if not users: raise Exception('Register user failure.') if len(users) > 1: raise Exception('Same E-mail exist!') user = users[0] login_user(user) return cors_reponse(data=user.to_json(), auth=user.get_id(), retmsg="Welcome aboard!") except Exception as e: + rollback_user_registration(user_id) stat_logger.exception(e) return get_json_result(data=False, retmsg='User registration failure!', retcode=RetCode.EXCEPTION_ERROR) @@ -220,7 +248,7 @@ def user_add(): @login_required def tenant_info(): try: - tenants = TenantService.get_by_user_id(current_user.id) + tenants = TenantService.get_by_user_id(current_user.id)[0] return get_json_result(data=tenants) except Exception as e: return server_error_response(e) diff --git a/web_server/db/db_models.py b/web_server/db/db_models.py index b6761680369220103a3a4c1ad33bfaa2a057c8d2..62d92b4753208908969da6bd6499d96e11d0e737 100644 --- a/web_server/db/db_models.py +++ b/web_server/db/db_models.py @@ -428,6 +428,7 @@ class LLMFactories(DataBaseModel): class LLM(DataBaseModel): # defautlt LLMs for every users llm_name = CharField(max_length=128, null=False, help_text="LLM name", primary_key=True) + model_type = CharField(max_length=128, null=False, help_text="LLM, Text Embedding, Image2Text, ASR") fid = CharField(max_length=128, null=False, help_text="LLM factory id") tags = CharField(max_length=255, null=False, help_text="LLM, Text Embedding, Image2Text, Chat, 32k...") status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1") @@ -442,8 +443,8 @@ class LLM(DataBaseModel): class TenantLLM(DataBaseModel): tenant_id = CharField(max_length=32, null=False) llm_factory = CharField(max_length=128, null=False, help_text="LLM factory name") - model_type = CharField(max_length=128, null=False, help_text="LLM, Text Embedding, Image2Text, ASR") - llm_name = CharField(max_length=128, null=False, help_text="LLM name") + model_type = CharField(max_length=128, null=True, help_text="LLM, Text Embedding, Image2Text, ASR") + llm_name = CharField(max_length=128, null=True, help_text="LLM name", default="") api_key = CharField(max_length=255, null=True, help_text="API KEY") api_base = CharField(max_length=255, null=True, help_text="API Base") @@ -452,7 +453,7 @@ class TenantLLM(DataBaseModel): class Meta: db_table = "tenant_llm" - primary_key = CompositeKey('tenant_id', 'llm_factory') + primary_key = CompositeKey('tenant_id', 'llm_factory', 'llm_name') class Knowledgebase(DataBaseModel): @@ -464,7 +465,9 @@ class Knowledgebase(DataBaseModel): permission = CharField(max_length=16, null=False, help_text="me|team") created_by = CharField(max_length=32, null=False) doc_num = IntegerField(default=0) - embd_id = CharField(max_length=32, null=False, help_text="default embedding model ID") + token_num = IntegerField(default=0) + chunk_num = IntegerField(default=0) + parser_id = CharField(max_length=32, null=False, help_text="default parser ID") status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1") diff --git a/web_server/db/services/document_service.py b/web_server/db/services/document_service.py index e8746a4ac19f8f054dde515f90c894592251dac4..38b1cd559aef42ba2ffa3821d81026d52858df14 100644 --- a/web_server/db/services/document_service.py +++ b/web_server/db/services/document_service.py @@ -13,12 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. # +from peewee import Expression + from web_server.db import TenantPermission, FileType -from web_server.db.db_models import DB, Knowledgebase +from web_server.db.db_models import DB, Knowledgebase, Tenant from web_server.db.db_models import Document from web_server.db.services.common_service import CommonService from web_server.db.services.kb_service import KnowledgebaseService -from web_server.utils import get_uuid, get_format_time from web_server.db.db_utils import StatusEnum @@ -61,15 +62,28 @@ class DocumentService(CommonService): @classmethod @DB.connection_context() def get_newly_uploaded(cls, tm, mod, comm, items_per_page=64): - fields = [cls.model.id, cls.model.kb_id, cls.model.parser_id, cls.model.name, cls.model.location, Knowledgebase.tenant_id] - docs = cls.model.select(fields).join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id)).where( - cls.model.status == StatusEnum.VALID.value, - cls.model.type != FileType.VIRTUAL, - cls.model.progress == 0, - cls.model.update_time >= tm, - cls.model.create_time % - comm == mod).order_by( - cls.model.update_time.asc()).paginate( - 1, - items_per_page) + fields = [cls.model.id, cls.model.kb_id, cls.model.parser_id, cls.model.name, cls.model.location, cls.model.size, Knowledgebase.tenant_id, Tenant.embd_id, Tenant.img2txt_id, cls.model.update_time] + docs = cls.model.select(*fields) \ + .join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id)) \ + .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))\ + .where( + cls.model.status == StatusEnum.VALID.value, + ~(cls.model.type == FileType.VIRTUAL.value), + cls.model.progress == 0, + cls.model.update_time >= tm, + (Expression(cls.model.create_time, "%%", comm) == mod))\ + .order_by(cls.model.update_time.asc())\ + .paginate(1, items_per_page) return list(docs.dicts()) + + @classmethod + @DB.connection_context() + def increment_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duation): + num = cls.model.update(token_num=cls.model.token_num + token_num, + chunk_num=cls.model.chunk_num + chunk_num, + process_duation=cls.model.process_duation+duation).where( + cls.model.id == doc_id).execute() + if num == 0:raise LookupError("Document not found which is supposed to be there") + num = Knowledgebase.update(token_num=Knowledgebase.token_num+token_num, chunk_num=Knowledgebase.chunk_num+chunk_num).where(Knowledgebase.id==kb_id).execute() + return num + diff --git a/web_server/db/services/kb_service.py b/web_server/db/services/kb_service.py index 84b2e4f93d9cb09bb7d19a97779bd6db089a5625..a8ca96a2ae944dcd83ca34afd39436796a745b15 100644 --- a/web_server/db/services/kb_service.py +++ b/web_server/db/services/kb_service.py @@ -17,7 +17,7 @@ import peewee from werkzeug.security import generate_password_hash, check_password_hash from web_server.db import TenantPermission -from web_server.db.db_models import DB, UserTenant +from web_server.db.db_models import DB, UserTenant, Tenant from web_server.db.db_models import Knowledgebase from web_server.db.services.common_service import CommonService from web_server.utils import get_uuid, get_format_time @@ -29,15 +29,42 @@ class KnowledgebaseService(CommonService): @classmethod @DB.connection_context() - def get_by_tenant_ids(cls, joined_tenant_ids, user_id, page_number, items_per_page, orderby, desc): + def get_by_tenant_ids(cls, joined_tenant_ids, user_id, + page_number, items_per_page, orderby, desc): kbs = cls.model.select().where( - ((cls.model.tenant_id.in_(joined_tenant_ids) & (cls.model.permission == TenantPermission.TEAM.value)) | (cls.model.tenant_id == user_id)) - & (cls.model.status==StatusEnum.VALID.value) + ((cls.model.tenant_id.in_(joined_tenant_ids) & (cls.model.permission == + TenantPermission.TEAM.value)) | (cls.model.tenant_id == user_id)) + & (cls.model.status == StatusEnum.VALID.value) ) - if desc: kbs = kbs.order_by(cls.model.getter_by(orderby).desc()) - else: kbs = kbs.order_by(cls.model.getter_by(orderby).asc()) + if desc: + kbs = kbs.order_by(cls.model.getter_by(orderby).desc()) + else: + kbs = kbs.order_by(cls.model.getter_by(orderby).asc()) kbs = kbs.paginate(page_number, items_per_page) return list(kbs.dicts()) + @classmethod + @DB.connection_context() + def get_detail(cls, kb_id): + fields = [ + cls.model.id, + Tenant.embd_id, + cls.model.avatar, + cls.model.name, + cls.model.description, + cls.model.permission, + cls.model.doc_num, + cls.model.token_num, + cls.model.chunk_num, + cls.model.parser_id] + kbs = cls.model.select(*fields).join(Tenant, on=((Tenant.id == cls.model.tenant_id)&(Tenant.status== StatusEnum.VALID.value))).where( + (cls.model.id == kb_id), + (cls.model.status == StatusEnum.VALID.value) + ) + if not kbs: + return + d = kbs[0].to_dict() + d["embd_id"] = kbs[0].tenant.embd_id + return d diff --git a/web_server/db/services/llm_service.py b/web_server/db/services/llm_service.py index 7d6b575fea608b8c21f749d22535fe3ca1e57272..350106e36e32a06ec372e0d459546a775cbc49fa 100644 --- a/web_server/db/services/llm_service.py +++ b/web_server/db/services/llm_service.py @@ -33,3 +33,21 @@ class LLMService(CommonService): class TenantLLMService(CommonService): model = TenantLLM + + @classmethod + @DB.connection_context() + def get_api_key(cls, tenant_id, model_type): + objs = cls.query(tenant_id=tenant_id, model_type=model_type) + if objs and len(objs)>0 and objs[0].llm_name: + return objs[0] + + fields = [LLM.llm_name, cls.model.llm_factory, cls.model.api_key] + objs = cls.model.select(*fields).join(LLM, on=(LLM.fid == cls.model.llm_factory)).where( + (cls.model.tenant_id == tenant_id), + (cls.model.model_type == model_type), + (LLM.status == StatusEnum.VALID) + ) + + if not objs:return + return objs[0] + diff --git a/web_server/db/services/user_service.py b/web_server/db/services/user_service.py index 42e0b5c11ad60bb1486fa41cb2d72b81937c57e0..f4ed4b58c27745c46e97f7da7b44b6eb4d9f119e 100644 --- a/web_server/db/services/user_service.py +++ b/web_server/db/services/user_service.py @@ -79,7 +79,7 @@ class TenantService(CommonService): @classmethod @DB.connection_context() def get_by_user_id(cls, user_id): - fields = [cls.model.id.alias("tenant_id"), cls.model.name, cls.model.llm_id, cls.model.embd_id, cls.model.asr_id, cls.model.img2txt_id, UserTenant.role] + fields = [cls.model.id.alias("tenant_id"), cls.model.name, cls.model.llm_id, cls.model.embd_id, cls.model.asr_id, cls.model.img2txt_id, cls.model.parser_ids, UserTenant.role] return list(cls.model.select(*fields)\ .join(UserTenant, on=((cls.model.id == UserTenant.tenant_id) & (UserTenant.user_id==user_id) & (UserTenant.status == StatusEnum.VALID.value)))\ .where(cls.model.status == StatusEnum.VALID.value).dicts()) diff --git a/web_server/utils/file_utils.py b/web_server/utils/file_utils.py index 54b1514ecf85b70b5ebdf0155e9c68ffb92f5c6f..442ab19bf494c23e7f93ad4b217c0f3806cd9436 100644 --- a/web_server/utils/file_utils.py +++ b/web_server/utils/file_utils.py @@ -143,7 +143,7 @@ def filename_type(filename): if re.match(r".*\.pdf$", filename): return FileType.PDF.value - if re.match(r".*\.(doc|ppt|yml|xml|htm|json|csv|txt|ini|xsl|wps|rtf|hlp|pages|numbers|key)$", filename): + if re.match(r".*\.(doc|ppt|yml|xml|htm|json|csv|txt|ini|xsl|wps|rtf|hlp|pages|numbers|key|md)$", filename): return FileType.DOC.value if re.match(r".*\.(wav|flac|ape|alac|wavpack|wv|mp3|aac|ogg|vorbis|opus|mp3)$", filename):