From 0429107e80de7ffb81cf50d7728126a8daaa3b58 Mon Sep 17 00:00:00 2001 From: KevinHuSh <kevinhu.sh@gmail.com> Date: Thu, 29 Feb 2024 14:03:07 +0800 Subject: [PATCH] fix user login issue (#85) --- api/apps/user_app.py | 107 ++++++++----------- api/db/__init__.py | 1 - api/db/db_models.py | 2 +- api/db/init_data.py | 4 +- api/db/services/user_service.py | 13 ++- api/settings.py | 2 +- deepdoc/parser/pdf_parser.py | 2 +- deepdoc/vision/layout_recognizer.py | 3 +- deepdoc/vision/table_structure_recognizer.py | 3 +- rag/app/manual.py | 6 ++ rag/app/naive.py | 35 ++++-- rag/svr/task_executor.py | 4 +- 12 files changed, 101 insertions(+), 81 deletions(-) diff --git a/api/apps/user_app.py b/api/apps/user_app.py index 8b5ba4a..da352fa 100644 --- a/api/apps/user_app.py +++ b/api/apps/user_app.py @@ -33,49 +33,14 @@ from api.utils.api_utils import get_json_result, cors_reponse @manager.route('/login', methods=['POST', 'GET']) def login(): - userinfo = None login_channel = "password" - if session.get("access_token"): - login_channel = session["access_token_from"] - if session["access_token_from"] == "github": - userinfo = user_info_from_github(session["access_token"]) - elif not request.json: + if not request.json: return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg='Unautherized!') - email = request.json.get('email') if not userinfo else userinfo["email"] + email = request.json.get('email', "") users = UserService.query(email=email) - if not users: - if request.json is not None: - return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg=f'This Email is not registered!') - avatar = "" - try: - avatar = download_img(userinfo["avatar_url"]) - except Exception as e: - stat_logger.exception(e) - user_id = get_uuid() - try: - users = user_register(user_id, { - "access_token": session["access_token"], - "email": userinfo["email"], - "avatar": avatar, - "nickname": userinfo["login"], - "login_channel": login_channel, - "last_login_time": get_format_time(), - "is_superuser": False, - }) - if not users: raise Exception('Register user failure.') - if len(users) > 1: raise Exception('Same E-mail exist!') - user = users[0] - login_user(user) - return cors_reponse(data=user.to_json(), auth=user.get_id(), retmsg="Welcome back!") - except Exception as e: - rollback_user_registration(user_id) - stat_logger.exception(e) - return server_error_response(e) - elif not request.json: - login_user(users[0]) - return cors_reponse(data=users[0].to_json(), auth=users[0].get_id(), retmsg="Welcome back!") + if not users: return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg=f'This Email is not registered!') password = request.json.get('password') try: @@ -97,28 +62,50 @@ def login(): @manager.route('/github_callback', methods=['GET']) def github_callback(): - try: - import requests - res = requests.post(GITHUB_OAUTH.get("url"), data={ - "client_id": GITHUB_OAUTH.get("client_id"), - "client_secret": GITHUB_OAUTH.get("secret_key"), - "code": request.args.get('code') - },headers={"Accept": "application/json"}) - res = res.json() - if "error" in res: - return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, - retmsg=res["error_description"]) - - if "user:email" not in res["scope"].split(","): - return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg='user:email not in scope') - - session["access_token"] = res["access_token"] - session["access_token_from"] = "github" - return redirect(url_for("user.login"), code=307) + import requests + res = requests.post(GITHUB_OAUTH.get("url"), data={ + "client_id": GITHUB_OAUTH.get("client_id"), + "client_secret": GITHUB_OAUTH.get("secret_key"), + "code": request.args.get('code') + }, headers={"Accept": "application/json"}) + res = res.json() + if "error" in res: + return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, + retmsg=res["error_description"]) - except Exception as e: - stat_logger.exception(e) - return server_error_response(e) + if "user:email" not in res["scope"].split(","): + return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg='user:email not in scope') + + session["access_token"] = res["access_token"] + session["access_token_from"] = "github" + userinfo = user_info_from_github(session["access_token"]) + users = UserService.query(email=userinfo["email"]) + user_id = get_uuid() + if not users: + try: + try: + avatar = download_img(userinfo["avatar_url"]) + except Exception as e: + stat_logger.exception(e) + avatar = "" + users = user_register(user_id, { + "access_token": session["access_token"], + "email": userinfo["email"], + "avatar": avatar, + "nickname": userinfo["login"], + "login_channel": "github", + "last_login_time": get_format_time(), + "is_superuser": False, + }) + if not users: raise Exception('Register user failure.') + if len(users) > 1: raise Exception('Same E-mail exist!') + user = users[0] + login_user(user) + except Exception as e: + rollback_user_registration(user_id) + stat_logger.exception(e) + + return redirect("/knowledge") def user_info_from_github(access_token): @@ -208,7 +195,7 @@ def user_register(user_id, user): for llm in LLMService.query(fid=LLM_FACTORY): tenant_llm.append({"tenant_id": user_id, "llm_factory": LLM_FACTORY, "llm_name": llm.llm_name, "model_type":llm.model_type, "api_key": API_KEY}) - if not UserService.insert(**user):return + if not UserService.save(**user):return TenantService.insert(**tenant) UserTenantService.insert(**usr_tenant) TenantLLMService.insert_many(tenant_llm) diff --git a/api/db/__init__.py b/api/db/__init__.py index 9c8a9b6..c1f5d80 100644 --- a/api/db/__init__.py +++ b/api/db/__init__.py @@ -69,7 +69,6 @@ class TaskStatus(StrEnum): class ParserType(StrEnum): - GENERAL = "general" PRESENTATION = "presentation" LAWS = "laws" MANUAL = "manual" diff --git a/api/db/db_models.py b/api/db/db_models.py index 9b2dc7f..4fb66d5 100644 --- a/api/db/db_models.py +++ b/api/db/db_models.py @@ -475,7 +475,7 @@ class Knowledgebase(DataBaseModel): similarity_threshold = FloatField(default=0.2) vector_similarity_weight = FloatField(default=0.3) - parser_id = CharField(max_length=32, null=False, help_text="default parser ID", default=ParserType.GENERAL.value) + parser_id = CharField(max_length=32, null=False, help_text="default parser ID", default=ParserType.NAIVE.value) parser_config = JSONField(null=False, default={"pages":[[0,1000000]]}) status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1") diff --git a/api/db/init_data.py b/api/db/init_data.py index b0f0b2f..92dce90 100644 --- a/api/db/init_data.py +++ b/api/db/init_data.py @@ -30,7 +30,7 @@ def init_superuser(): "password": "admin", "nickname": "admin", "is_superuser": True, - "email": "kai.hu@infiniflow.org", + "email": "admin@ragflow.io", "creator": "system", "status": "1", } @@ -61,7 +61,7 @@ def init_superuser(): TenantService.insert(**tenant) UserTenantService.insert(**usr_tenant) TenantLLMService.insert_many(tenant_llm) - print("ă€INFO】Super user initialized. \033[93muser name: admin, password: admin\033[0m. Changing the password after logining is strongly recomanded.") + print("ă€INFO】Super user initialized. \033[93memail: admin@ragflow.io, password: admin\033[0m. Changing the password after logining is strongly recomanded.") chat_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["llm_id"]) msg = chat_mdl.chat(system="", history=[{"role": "user", "content": "Hello!"}], gen_conf={}) diff --git a/api/db/services/user_service.py b/api/db/services/user_service.py index 1ddfa01..fe68783 100644 --- a/api/db/services/user_service.py +++ b/api/db/services/user_service.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # +from datetime import datetime + import peewee from werkzeug.security import generate_password_hash, check_password_hash @@ -20,7 +22,7 @@ from api.db import UserTenantRole from api.db.db_models import DB, UserTenant from api.db.db_models import User, Tenant from api.db.services.common_service import CommonService -from api.utils import get_uuid, get_format_time +from api.utils import get_uuid, get_format_time, current_timestamp, datetime_format from api.db import StatusEnum @@ -53,6 +55,11 @@ class UserService(CommonService): kwargs["id"] = get_uuid() if "password" in kwargs: kwargs["password"] = generate_password_hash(str(kwargs["password"])) + + kwargs["create_time"] = current_timestamp() + kwargs["create_date"] = datetime_format(datetime.now()) + kwargs["update_time"] = current_timestamp() + kwargs["update_date"] = datetime_format(datetime.now()) obj = cls.model(**kwargs).save(force_insert=True) return obj @@ -66,10 +73,10 @@ class UserService(CommonService): @classmethod @DB.connection_context() def update_user(cls, user_id, user_dict): - date_time = get_format_time() with DB.atomic(): if user_dict: - user_dict["update_time"] = date_time + user_dict["update_time"] = current_timestamp() + user_dict["update_date"] = datetime_format(datetime.now()) cls.model.update(user_dict).where(cls.model.id == user_id).execute() diff --git a/api/settings.py b/api/settings.py index 331a086..e0076ed 100644 --- a/api/settings.py +++ b/api/settings.py @@ -76,7 +76,7 @@ ASR_MDL = default_llm[LLM_FACTORY]["asr_model"] IMAGE2TEXT_MDL = default_llm[LLM_FACTORY]["image2text_model"] API_KEY = LLM.get("api_key", "infiniflow API Key") -PARSERS = LLM.get("parsers", "general:General,qa:Q&A,resume:Resume,naive:Naive,table:Table,laws:Laws,manual:Manual,book:Book,paper:Paper,presentation:Presentation,picture:Picture") +PARSERS = LLM.get("parsers", "naive:General,qa:Q&A,resume:Resume,table:Table,laws:Laws,manual:Manual,book:Book,paper:Paper,presentation:Presentation,picture:Picture") # distribution DEPENDENT_DISTRIBUTION = get_base_config("dependent_distribution", False) diff --git a/deepdoc/parser/pdf_parser.py b/deepdoc/parser/pdf_parser.py index f333b1d..f99aa91 100644 --- a/deepdoc/parser/pdf_parser.py +++ b/deepdoc/parser/pdf_parser.py @@ -25,7 +25,7 @@ class HuParser: def __init__(self): self.ocr = OCR() if not hasattr(self, "model_speciess"): - self.model_speciess = ParserType.GENERAL.value + self.model_speciess = ParserType.NAIVE.value self.layouter = LayoutRecognizer("layout."+self.model_speciess) self.tbl_det = TableStructureRecognizer() diff --git a/deepdoc/vision/layout_recognizer.py b/deepdoc/vision/layout_recognizer.py index 7a7791a..52feaba 100644 --- a/deepdoc/vision/layout_recognizer.py +++ b/deepdoc/vision/layout_recognizer.py @@ -34,8 +34,7 @@ class LayoutRecognizer(Recognizer): "Equation", ] def __init__(self, domain): - super().__init__(self.labels, domain, - os.path.join(get_project_base_directory(), "rag/res/deepdoc/")) + super().__init__(self.labels, domain) #, os.path.join(get_project_base_directory(), "rag/res/deepdoc/")) def __call__(self, image_list, ocr_res, scale_factor=3, thr=0.2, batch_size=16): def __is_garbage(b): diff --git a/deepdoc/vision/table_structure_recognizer.py b/deepdoc/vision/table_structure_recognizer.py index e396cbf..8e149a5 100644 --- a/deepdoc/vision/table_structure_recognizer.py +++ b/deepdoc/vision/table_structure_recognizer.py @@ -33,8 +33,7 @@ class TableStructureRecognizer(Recognizer): ] def __init__(self): - super().__init__(self.labels, "tsr", - os.path.join(get_project_base_directory(), "rag/res/deepdoc/")) + super().__init__(self.labels, "tsr")#,os.path.join(get_project_base_directory(), "rag/res/deepdoc/")) def __call__(self, images, thr=0.2): tbls = super().__call__(images, thr) diff --git a/rag/app/manual.py b/rag/app/manual.py index 14fbf0b..0ef3195 100644 --- a/rag/app/manual.py +++ b/rag/app/manual.py @@ -1,11 +1,17 @@ import copy import re + +from api.db import ParserType from rag.nlp import huqie, tokenize from deepdoc.parser import PdfParser from rag.utils import num_tokens_from_string class Pdf(PdfParser): + def __init__(self): + self.model_speciess = ParserType.MANUAL.value + super().__init__() + def __call__(self, filename, binary=None, from_page=0, to_page=100000, zoomin=3, callback=None): self.__images__( diff --git a/rag/app/naive.py b/rag/app/naive.py index 0446217..b14e7bf 100644 --- a/rag/app/naive.py +++ b/rag/app/naive.py @@ -30,11 +30,21 @@ class Pdf(PdfParser): from timeit import default_timer as timer start = timer() + start = timer() self._layouts_rec(zoomin) - callback(0.77, "Layout analysis finished") + callback(0.5, "Layout analysis finished.") + print("paddle layouts:", timer() - start) + self._table_transformer_job(zoomin) + callback(0.7, "Table analysis finished.") + self._text_merge() + self._concat_downward(concat_between_pages=False) + self._filter_forpages() + callback(0.77, "Text merging finished") + tbls = self._extract_table_figure(True, zoomin, False) + cron_logger.info("paddle layouts:".format((timer() - start) / (self.total_page + 0.1))) - self._naive_vertical_merge() - return [(b["text"], self._line_tag(b, zoomin)) for b in self.boxes] + #self._naive_vertical_merge() + return [(b["text"], self._line_tag(b, zoomin)) for b in self.boxes], tbls def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs): @@ -44,11 +54,14 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca Successive text will be sliced into pieces using 'delimiter'. Next, these successive pieces are merge into chunks whose token number is no more than 'Max token number'. """ + + eng = lang.lower() == "english"#is_english(cks) doc = { "docnm_kwd": filename, "title_tks": huqie.qie(re.sub(r"\.[a-zA-Z]+$", "", filename)) } doc["title_sm_tks"] = huqie.qieqie(doc["title_tks"]) + res = [] pdf_parser = None sections = [] if re.search(r"\.docx?$", filename, re.IGNORECASE): @@ -58,8 +71,19 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca callback(0.8, "Finish parsing.") elif re.search(r"\.pdf$", filename, re.IGNORECASE): pdf_parser = Pdf() - sections = pdf_parser(filename if not binary else binary, + sections, tbls = pdf_parser(filename if not binary else binary, from_page=from_page, to_page=to_page, callback=callback) + # add tables + for img, rows in tbls: + bs = 10 + de = ";" if eng else ";" + for i in range(0, len(rows), bs): + d = copy.deepcopy(doc) + r = de.join(rows[i:i + bs]) + r = re.sub(r"\t——(来自| in ).*”%s" % de, "", r) + tokenize(d, r, eng) + d["image"] = img + res.append(d) elif re.search(r"\.txt$", filename, re.IGNORECASE): callback(0.1, "Start to parse.") txt = "" @@ -79,8 +103,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca parser_config = kwargs.get("parser_config", {"chunk_token_num": 128, "delimiter": "\n!?。;ďĽďĽź"}) cks = naive_merge(sections, parser_config["chunk_token_num"], parser_config["delimiter"]) - eng = lang.lower() == "english"#is_english(cks) - res = [] + # wrap up to es documents for ck in cks: print("--", ck) diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py index 0d0b4e4..285ce96 100644 --- a/rag/svr/task_executor.py +++ b/rag/svr/task_executor.py @@ -37,7 +37,7 @@ from rag.nlp import search from io import BytesIO import pandas as pd -from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture +from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive from api.db import LLMType, ParserType from api.db.services.document_service import DocumentService @@ -48,7 +48,7 @@ from api.utils.file_utils import get_project_base_directory BATCH_SIZE = 64 FACTORY = { - ParserType.GENERAL.value: laws, + ParserType.NAIVE.value: naive, ParserType.PAPER.value: paper, ParserType.BOOK.value: book, ParserType.PRESENTATION.value: presentation, -- GitLab