From 0429107e80de7ffb81cf50d7728126a8daaa3b58 Mon Sep 17 00:00:00 2001
From: KevinHuSh <kevinhu.sh@gmail.com>
Date: Thu, 29 Feb 2024 14:03:07 +0800
Subject: [PATCH] fix user login issue (#85)

---
 api/apps/user_app.py                         | 107 ++++++++-----------
 api/db/__init__.py                           |   1 -
 api/db/db_models.py                          |   2 +-
 api/db/init_data.py                          |   4 +-
 api/db/services/user_service.py              |  13 ++-
 api/settings.py                              |   2 +-
 deepdoc/parser/pdf_parser.py                 |   2 +-
 deepdoc/vision/layout_recognizer.py          |   3 +-
 deepdoc/vision/table_structure_recognizer.py |   3 +-
 rag/app/manual.py                            |   6 ++
 rag/app/naive.py                             |  35 ++++--
 rag/svr/task_executor.py                     |   4 +-
 12 files changed, 101 insertions(+), 81 deletions(-)

diff --git a/api/apps/user_app.py b/api/apps/user_app.py
index 8b5ba4a..da352fa 100644
--- a/api/apps/user_app.py
+++ b/api/apps/user_app.py
@@ -33,49 +33,14 @@ from api.utils.api_utils import get_json_result, cors_reponse
 
 @manager.route('/login', methods=['POST', 'GET'])
 def login():
-    userinfo = None
     login_channel = "password"
-    if session.get("access_token"):
-        login_channel = session["access_token_from"]
-        if session["access_token_from"] == "github":
-            userinfo = user_info_from_github(session["access_token"])
-    elif not request.json:
+    if not request.json:
         return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR,
                                retmsg='Unautherized!')
 
-    email = request.json.get('email') if not userinfo else userinfo["email"]
+    email = request.json.get('email', "")
     users = UserService.query(email=email)
-    if not users:
-        if request.json is not None:
-            return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg=f'This Email is not registered!')
-        avatar = ""
-        try:
-            avatar = download_img(userinfo["avatar_url"])
-        except Exception as e:
-            stat_logger.exception(e)
-        user_id = get_uuid()
-        try:
-            users = user_register(user_id, {
-                "access_token": session["access_token"],
-                "email": userinfo["email"],
-                "avatar": avatar,
-                "nickname": userinfo["login"],
-                "login_channel": login_channel,
-                "last_login_time": get_format_time(),
-                "is_superuser": False,
-            })
-            if not users: raise Exception('Register user failure.')
-            if len(users) > 1: raise Exception('Same E-mail exist!')
-            user = users[0]
-            login_user(user)
-            return cors_reponse(data=user.to_json(), auth=user.get_id(), retmsg="Welcome back!")
-        except Exception as e:
-            rollback_user_registration(user_id)
-            stat_logger.exception(e)
-            return server_error_response(e)
-    elif not request.json:
-        login_user(users[0])
-        return cors_reponse(data=users[0].to_json(), auth=users[0].get_id(), retmsg="Welcome back!")
+    if not users: return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg=f'This Email is not registered!')
 
     password = request.json.get('password')
     try:
@@ -97,28 +62,50 @@ def login():
 
 @manager.route('/github_callback', methods=['GET'])
 def github_callback():
-    try:
-        import requests
-        res = requests.post(GITHUB_OAUTH.get("url"), data={
-            "client_id": GITHUB_OAUTH.get("client_id"),
-            "client_secret":  GITHUB_OAUTH.get("secret_key"),
-            "code": request.args.get('code')
-        },headers={"Accept": "application/json"})
-        res = res.json()
-        if "error" in res:
-            return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR,
-                                   retmsg=res["error_description"])
-
-        if "user:email" not in res["scope"].split(","):
-            return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg='user:email not in scope')
-
-        session["access_token"] = res["access_token"]
-        session["access_token_from"] = "github"
-        return redirect(url_for("user.login"),  code=307)
+    import requests
+    res = requests.post(GITHUB_OAUTH.get("url"), data={  # exchange the OAuth "code" for an access token
+        "client_id": GITHUB_OAUTH.get("client_id"),
+        "client_secret":  GITHUB_OAUTH.get("secret_key"),
+        "code": request.args.get('code')
+    }, headers={"Accept": "application/json"})
+    res = res.json()
+    if "error" in res:
+        return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR,
+                               retmsg=res["error_description"])
 
-    except Exception as e:
-        stat_logger.exception(e)
-        return server_error_response(e)
+    if "user:email" not in res["scope"].split(","):
+        return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg='user:email not in scope')
+
+    session["access_token"] = res["access_token"]
+    session["access_token_from"] = "github"
+    userinfo = user_info_from_github(session["access_token"])
+    users = UserService.query(email=userinfo["email"])
+    user_id = get_uuid()  # only used (and rolled back on failure) when a new account is created
+    if not users:
+        try:
+            try:
+                avatar = download_img(userinfo["avatar_url"])
+            except Exception as e:
+                stat_logger.exception(e)
+                avatar = ""  # avatar is best-effort; registration continues without it
+            users = user_register(user_id, {
+                "access_token": session["access_token"],
+                "email": userinfo["email"],
+                "avatar": avatar,
+                "nickname": userinfo["login"],
+                "login_channel": "github",
+                "last_login_time": get_format_time(),
+                "is_superuser": False,
+            })
+            if not users: raise Exception('Register user failure.')
+            if len(users) > 1: raise Exception('Same E-mail exist!')
+        except Exception as e:
+            rollback_user_registration(user_id)
+            stat_logger.exception(e)
+            return server_error_response(e)  # do not redirect as if sign-up had succeeded
+
+    login_user(users[0])  # log in returning GitHub users too, not only newly registered ones
+    return redirect("/knowledge")
 
 
 def user_info_from_github(access_token):
@@ -208,7 +195,7 @@ def user_register(user_id, user):
     for llm in LLMService.query(fid=LLM_FACTORY):
         tenant_llm.append({"tenant_id": user_id, "llm_factory": LLM_FACTORY, "llm_name": llm.llm_name, "model_type":llm.model_type, "api_key": API_KEY})
 
-    if not UserService.insert(**user):return
+    if not UserService.save(**user):return
     TenantService.insert(**tenant)
     UserTenantService.insert(**usr_tenant)
     TenantLLMService.insert_many(tenant_llm)
diff --git a/api/db/__init__.py b/api/db/__init__.py
index 9c8a9b6..c1f5d80 100644
--- a/api/db/__init__.py
+++ b/api/db/__init__.py
@@ -69,7 +69,6 @@ class TaskStatus(StrEnum):
 
 
 class ParserType(StrEnum):
-    GENERAL = "general"
     PRESENTATION = "presentation"
     LAWS = "laws"
     MANUAL = "manual"
diff --git a/api/db/db_models.py b/api/db/db_models.py
index 9b2dc7f..4fb66d5 100644
--- a/api/db/db_models.py
+++ b/api/db/db_models.py
@@ -475,7 +475,7 @@ class Knowledgebase(DataBaseModel):
     similarity_threshold = FloatField(default=0.2)
     vector_similarity_weight = FloatField(default=0.3)
 
-    parser_id = CharField(max_length=32, null=False, help_text="default parser ID", default=ParserType.GENERAL.value)
+    parser_id = CharField(max_length=32, null=False, help_text="default parser ID", default=ParserType.NAIVE.value)
     parser_config = JSONField(null=False, default={"pages":[[0,1000000]]})
     status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
 
diff --git a/api/db/init_data.py b/api/db/init_data.py
index b0f0b2f..92dce90 100644
--- a/api/db/init_data.py
+++ b/api/db/init_data.py
@@ -30,7 +30,7 @@ def init_superuser():
         "password": "admin",
         "nickname": "admin",
         "is_superuser": True,
-        "email": "kai.hu@infiniflow.org",
+        "email": "admin@ragflow.io",
         "creator": "system",
         "status": "1",
     }
@@ -61,7 +61,7 @@ def init_superuser():
     TenantService.insert(**tenant)
     UserTenantService.insert(**usr_tenant)
     TenantLLMService.insert_many(tenant_llm)
-    print("【INFO】Super user initialized. \033[93muser name: admin, password: admin\033[0m. Changing the password after logining is strongly recomanded.")
+    print("【INFO】Super user initialized. \033[93memail: admin@ragflow.io, password: admin\033[0m. Changing the password after logining is strongly recomanded.")
 
     chat_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["llm_id"])
     msg = chat_mdl.chat(system="", history=[{"role": "user", "content": "Hello!"}], gen_conf={})
diff --git a/api/db/services/user_service.py b/api/db/services/user_service.py
index 1ddfa01..fe68783 100644
--- a/api/db/services/user_service.py
+++ b/api/db/services/user_service.py
@@ -13,6 +13,8 @@
 #  See the License for the specific language governing permissions and
 #  limitations under the License.
 #
+from datetime import datetime
+
 import peewee
 from werkzeug.security import generate_password_hash, check_password_hash
 
@@ -20,7 +22,7 @@ from api.db import UserTenantRole
 from api.db.db_models import DB, UserTenant
 from api.db.db_models import User, Tenant
 from api.db.services.common_service import CommonService
-from api.utils import get_uuid, get_format_time
+from api.utils import get_uuid, get_format_time, current_timestamp, datetime_format
 from api.db import StatusEnum
 
 
@@ -53,6 +55,11 @@ class UserService(CommonService):
             kwargs["id"] = get_uuid()
         if "password" in kwargs:
             kwargs["password"] = generate_password_hash(str(kwargs["password"]))
+
+        kwargs["create_time"] = current_timestamp()
+        kwargs["create_date"] = datetime_format(datetime.now())
+        kwargs["update_time"] = current_timestamp()
+        kwargs["update_date"] = datetime_format(datetime.now())
         obj = cls.model(**kwargs).save(force_insert=True)
         return obj
 
@@ -66,10 +73,10 @@ class UserService(CommonService):
     @classmethod
     @DB.connection_context()
     def update_user(cls, user_id, user_dict):
-        date_time = get_format_time()
         with DB.atomic():
             if user_dict:
-                user_dict["update_time"] = date_time
+                user_dict["update_time"] = current_timestamp()  # NOTE: writes into the caller's dict in place
+                user_dict["update_date"] = datetime_format(datetime.now())  # separate clock read from update_time
                 cls.model.update(user_dict).where(cls.model.id == user_id).execute()
 
 
diff --git a/api/settings.py b/api/settings.py
index 331a086..e0076ed 100644
--- a/api/settings.py
+++ b/api/settings.py
@@ -76,7 +76,7 @@ ASR_MDL = default_llm[LLM_FACTORY]["asr_model"]
 IMAGE2TEXT_MDL = default_llm[LLM_FACTORY]["image2text_model"]
 
 API_KEY = LLM.get("api_key", "infiniflow API Key")
-PARSERS = LLM.get("parsers", "general:General,qa:Q&A,resume:Resume,naive:Naive,table:Table,laws:Laws,manual:Manual,book:Book,paper:Paper,presentation:Presentation,picture:Picture")
+PARSERS = LLM.get("parsers", "naive:General,qa:Q&A,resume:Resume,table:Table,laws:Laws,manual:Manual,book:Book,paper:Paper,presentation:Presentation,picture:Picture")
 
 # distribution
 DEPENDENT_DISTRIBUTION = get_base_config("dependent_distribution", False)
diff --git a/deepdoc/parser/pdf_parser.py b/deepdoc/parser/pdf_parser.py
index f333b1d..f99aa91 100644
--- a/deepdoc/parser/pdf_parser.py
+++ b/deepdoc/parser/pdf_parser.py
@@ -25,7 +25,7 @@ class HuParser:
     def __init__(self):
         self.ocr = OCR()
         if not hasattr(self, "model_speciess"):
-            self.model_speciess = ParserType.GENERAL.value
+            self.model_speciess = ParserType.NAIVE.value
         self.layouter = LayoutRecognizer("layout."+self.model_speciess)
         self.tbl_det = TableStructureRecognizer()
 
diff --git a/deepdoc/vision/layout_recognizer.py b/deepdoc/vision/layout_recognizer.py
index 7a7791a..52feaba 100644
--- a/deepdoc/vision/layout_recognizer.py
+++ b/deepdoc/vision/layout_recognizer.py
@@ -34,8 +34,7 @@ class LayoutRecognizer(Recognizer):
              "Equation",
         ]
     def __init__(self, domain):
-        super().__init__(self.labels, domain,
-                         os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
+        super().__init__(self.labels, domain)  # NOTE(review): explicit model dir os.path.join(get_project_base_directory(), "rag/res/deepdoc/") no longer passed; presumably Recognizer supplies a default - confirm
 
     def __call__(self, image_list, ocr_res, scale_factor=3, thr=0.2, batch_size=16):
         def __is_garbage(b):
diff --git a/deepdoc/vision/table_structure_recognizer.py b/deepdoc/vision/table_structure_recognizer.py
index e396cbf..8e149a5 100644
--- a/deepdoc/vision/table_structure_recognizer.py
+++ b/deepdoc/vision/table_structure_recognizer.py
@@ -33,8 +33,7 @@ class TableStructureRecognizer(Recognizer):
     ]
 
     def __init__(self):
-        super().__init__(self.labels, "tsr",
-                         os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
+        super().__init__(self.labels, "tsr")  # NOTE(review): explicit model dir os.path.join(get_project_base_directory(), "rag/res/deepdoc/") no longer passed; presumably Recognizer supplies a default - confirm
 
     def __call__(self, images, thr=0.2):
         tbls = super().__call__(images, thr)
diff --git a/rag/app/manual.py b/rag/app/manual.py
index 14fbf0b..0ef3195 100644
--- a/rag/app/manual.py
+++ b/rag/app/manual.py
@@ -1,11 +1,17 @@
 import copy
 import re
+
+from api.db import ParserType
 from rag.nlp import huqie, tokenize
 from deepdoc.parser import PdfParser
 from rag.utils import num_tokens_from_string
 
 
 class Pdf(PdfParser):
+    def __init__(self):  # selects the "manual" model variant for this parser
+        self.model_speciess = ParserType.MANUAL.value  # set before super().__init__(); presumably the base parser only defaults it when unset (cf. the hasattr check in HuParser.__init__) - confirm
+        super().__init__()
+
     def __call__(self, filename, binary=None, from_page=0,
                  to_page=100000, zoomin=3, callback=None):
         self.__images__(
diff --git a/rag/app/naive.py b/rag/app/naive.py
index 0446217..b14e7bf 100644
--- a/rag/app/naive.py
+++ b/rag/app/naive.py
@@ -30,11 +30,21 @@ class Pdf(PdfParser):
 
         from timeit import default_timer as timer
         start = timer()
+        start = timer()
         self._layouts_rec(zoomin)
-        callback(0.77, "Layout analysis finished")
+        callback(0.5, "Layout analysis finished.")
+        print("paddle layouts:", timer() - start)
+        self._table_transformer_job(zoomin)
+        callback(0.7, "Table analysis finished.")
+        self._text_merge()
+        self._concat_downward(concat_between_pages=False)
+        self._filter_forpages()
+        callback(0.77, "Text merging finished")
+        tbls = self._extract_table_figure(True, zoomin, False)
+
         cron_logger.info("paddle layouts:".format((timer() - start) / (self.total_page + 0.1)))
-        self._naive_vertical_merge()
-        return [(b["text"], self._line_tag(b, zoomin)) for b in self.boxes]
+        #self._naive_vertical_merge()
+        return [(b["text"], self._line_tag(b, zoomin)) for b in self.boxes], tbls
 
 
 def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs):
@@ -44,11 +54,14 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
         Successive text will be sliced into pieces using 'delimiter'.
         Next, these successive pieces are merge into chunks whose token number is no more than 'Max token number'.
     """
+
+    eng = lang.lower() == "english"#is_english(cks)
     doc = {
         "docnm_kwd": filename,
         "title_tks": huqie.qie(re.sub(r"\.[a-zA-Z]+$", "", filename))
     }
     doc["title_sm_tks"] = huqie.qieqie(doc["title_tks"])
+    res = []
     pdf_parser = None
     sections = []
     if re.search(r"\.docx?$", filename, re.IGNORECASE):
@@ -58,8 +71,19 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
         callback(0.8, "Finish parsing.")
     elif re.search(r"\.pdf$", filename, re.IGNORECASE):
         pdf_parser = Pdf()
-        sections = pdf_parser(filename if not binary else binary,
+        sections, tbls = pdf_parser(filename if not binary else binary,
                               from_page=from_page, to_page=to_page, callback=callback)
+        # add tables
+        for img, rows in tbls:
+            bs = 10
+            de = ";" if eng else ";"
+            for i in range(0, len(rows), bs):
+                d = copy.deepcopy(doc)
+                r = de.join(rows[i:i + bs])
+                r = re.sub(r"\t——(来自| in ).*”%s" % de, "", r)
+                tokenize(d, r, eng)
+                d["image"] = img
+                res.append(d)
     elif re.search(r"\.txt$", filename, re.IGNORECASE):
         callback(0.1, "Start to parse.")
         txt = ""
@@ -79,8 +103,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
 
     parser_config = kwargs.get("parser_config", {"chunk_token_num": 128, "delimiter": "\n!?。;!?"})
     cks = naive_merge(sections, parser_config["chunk_token_num"], parser_config["delimiter"])
-    eng = lang.lower() == "english"#is_english(cks)
-    res = []
+
     # wrap up to es documents
     for ck in cks:
         print("--", ck)
diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py
index 0d0b4e4..285ce96 100644
--- a/rag/svr/task_executor.py
+++ b/rag/svr/task_executor.py
@@ -37,7 +37,7 @@ from rag.nlp import search
 from io import BytesIO
 import pandas as pd
 
-from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture
+from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive
 
 from api.db import LLMType, ParserType
 from api.db.services.document_service import DocumentService
@@ -48,7 +48,7 @@ from api.utils.file_utils import get_project_base_directory
 BATCH_SIZE = 64
 
 FACTORY = {
-    ParserType.GENERAL.value: laws,
+    ParserType.NAIVE.value: naive,
     ParserType.PAPER.value: paper,
     ParserType.BOOK.value: book,
     ParserType.PRESENTATION.value: presentation,
-- 
GitLab