From 944776f207fb9c5711f3911f31805130c0d70971 Mon Sep 17 00:00:00 2001
From: KevinHuSh <kevinhu.sh@gmail.com>
Date: Sun, 28 Apr 2024 09:57:40 +0800
Subject: [PATCH] fix bug about fetching file from minio (#574)

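### What problem does this PR solve?

Fixes fetching file content from MinIO. A document linked to an uploaded file
(via `File2Document`) is now fetched using the file's `parent_id` and
`location` rather than the document's `kb_id`. The address is resolved by the
new `File2DocumentService.get_minio_address`, which falls back to the
document's `kb_id`/`location` when no link exists.
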
### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
---
 api/apps/file_app.py                     |  8 ++++----
 api/db/services/file2document_service.py | 17 +++++++++++++++++
 api/db/services/file_service.py          |  2 +-
 api/db/services/task_service.py          | 10 ++++++----
 rag/svr/task_broker.py                   |  7 +++++--
 rag/svr/task_executor.py                 |  5 ++++-
 6 files changed, 37 insertions(+), 12 deletions(-)

diff --git a/api/apps/file_app.py b/api/apps/file_app.py
index 6cd9742..17944a9 100644
--- a/api/apps/file_app.py
+++ b/api/apps/file_app.py
@@ -328,12 +328,12 @@ def rename():
 # @login_required
 def get(file_id):
     try:
-        e, doc = FileService.get_by_id(file_id)
+        e, file = FileService.get_by_id(file_id)
         if not e:
             return get_data_error_result(retmsg="Document not found!")
 
-        response = flask.make_response(MINIO.get(doc.parent_id, doc.location))
-        ext = re.search(r"\.([^.]+)$", doc.name)
+        response = flask.make_response(MINIO.get(file.parent_id, file.location))
+        ext = re.search(r"\.([^.]+)$", file.name)
         if ext:
-            if doc.type == FileType.VISUAL.value:
+            if file.type == FileType.VISUAL.value:
                 response.headers.set('Content-Type', 'image/%s' % ext.group(1))
diff --git a/api/db/services/file2document_service.py b/api/db/services/file2document_service.py
index b53e0ad..18ec03d 100644
--- a/api/db/services/file2document_service.py
+++ b/api/db/services/file2document_service.py
@@ -18,6 +18,8 @@ from datetime import datetime
 from api.db.db_models import DB
 from api.db.db_models import File, Document, File2Document
 from api.db.services.common_service import CommonService
+from api.db.services.document_service import DocumentService
+from api.db.services.file_service import FileService
 from api.utils import current_timestamp, datetime_format
 
 
@@ -64,3 +66,18 @@ class File2DocumentService(CommonService):
         num = cls.model.update(obj).where(cls.model.id == file_id).execute()
         e, obj = cls.get_by_id(cls.model.id)
         return obj
+
+    @classmethod
+    @DB.connection_context()
+    def get_minio_address(cls, doc_id=None, file_id=None):
+        """Return the (bucket, object name) pair used to fetch the blob from MinIO."""
+        lookup = (File2DocumentService.get_by_document_id if doc_id
+                  else File2DocumentService.get_by_file_id)
+        links = lookup(doc_id or file_id)
+        if links:
+            e, file = FileService.get_by_id(links[0].file_id)
+            return file.parent_id, file.location
+        # No File2Document link: only a document can be resolved, via its kb_id.
+        assert doc_id, "please specify doc_id"
+        e, doc = DocumentService.get_by_id(doc_id)
+        return doc.kb_id, doc.location
diff --git a/api/db/services/file_service.py b/api/db/services/file_service.py
index abb6e56..57948d4 100644
--- a/api/db/services/file_service.py
+++ b/api/db/services/file_service.py
@@ -21,7 +21,6 @@ from api.db.db_models import DB, File2Document, Knowledgebase
 from api.db.db_models import File, Document
 from api.db.services.common_service import CommonService
 from api.utils import get_uuid
-from rag.utils import MINIO
 
 
 class FileService(CommonService):
@@ -241,3 +240,4 @@ class FileService(CommonService):
 
         dfs(folder_id)
         return size
+
diff --git a/api/db/services/task_service.py b/api/db/services/task_service.py
index 8c6bc6e..ccc837a 100644
--- a/api/db/services/task_service.py
+++ b/api/db/services/task_service.py
@@ -15,8 +15,8 @@
 #
 import random
 
-from peewee import Expression
-from api.db.db_models import DB
+from peewee import Expression, JOIN
+from api.db.db_models import DB, File2Document, File
 from api.db import StatusEnum, FileType, TaskStatus
 from api.db.db_models import Task, Document, Knowledgebase, Tenant
 from api.db.services.common_service import CommonService
@@ -75,8 +75,10 @@ class TaskService(CommonService):
     @DB.connection_context()
     def get_ongoing_doc_name(cls):
         with DB.lock("get_task", -1):
-            docs = cls.model.select(*[Document.kb_id, Document.location]) \
+            docs = cls.model.select(*[Document.id, Document.kb_id, Document.location, File.parent_id]) \
                 .join(Document, on=(cls.model.doc_id == Document.id)) \
+                .join(File2Document, on=(File2Document.document_id == Document.id), join_type=JOIN.LEFT_OUTER) \
+                .join(File, on=(File2Document.file_id == File.id), join_type=JOIN.LEFT_OUTER) \
                 .where(
                     Document.status == StatusEnum.VALID.value,
                     Document.run == TaskStatus.RUNNING.value,
@@ -88,7 +90,7 @@ class TaskService(CommonService):
             docs = list(docs.dicts())
             if not docs: return []
 
-            return list(set([(d["kb_id"], d["location"]) for d in docs]))
+            return list(set([(d["parent_id"] if d["parent_id"] else d["kb_id"], d["location"]) for d in docs]))
 
     @classmethod
     @DB.connection_context()
diff --git a/rag/svr/task_broker.py b/rag/svr/task_broker.py
index 3e43fbf..d7b57d5 100644
--- a/rag/svr/task_broker.py
+++ b/rag/svr/task_broker.py
@@ -20,6 +20,8 @@ import random
 from datetime import datetime
 from api.db.db_models import Task
 from api.db.db_utils import bulk_insert_into_db
+from api.db.services.file2document_service import File2DocumentService
+from api.db.services.file_service import FileService
 from api.db.services.task_service import TaskService
 from deepdoc.parser import PdfParser
 from deepdoc.parser.excel_parser import HuExcelParser
@@ -87,10 +89,11 @@ def dispatch():
 
         tsks = []
         try:
-            file_bin = MINIO.get(r["kb_id"], r["location"])
+            bucket, name = File2DocumentService.get_minio_address(doc_id=r["id"])
+            file_bin = MINIO.get(bucket, name)
             if REDIS_CONN.is_alive():
                 try:
-                    REDIS_CONN.set("{}/{}".format(r["kb_id"], r["location"]), file_bin, 12*60)
+                    REDIS_CONN.set("{}/{}".format(bucket, name), file_bin, 12*60)
                 except Exception as e:
                     cron_logger.warning("Put into redis[EXCEPTION]:" + str(e))
 
diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py
index b72b1c5..032d9ea 100644
--- a/rag/svr/task_executor.py
+++ b/rag/svr/task_executor.py
@@ -24,6 +24,8 @@ import sys
 import time
 import traceback
 from functools import partial
+
+from api.db.services.file2document_service import File2DocumentService
 from rag.utils import MINIO
 from api.db.db_models import close_connection
 from rag.settings import database_logger
@@ -135,7 +137,8 @@ def build(row):
     pool = Pool(processes=1)
     try:
         st = timer()
-        thr = pool.apply_async(get_minio_binary, args=(row["kb_id"], row["location"]))
+        bucket, name = File2DocumentService.get_minio_address(doc_id=row["doc_id"])
+        thr = pool.apply_async(get_minio_binary, args=(bucket, name))
         binary = thr.get(timeout=90)
         pool.terminate()
         cron_logger.info(
-- 
GitLab