Skip to content
Snippets Groups Projects
Unverified Commit b89ac3c4 authored by KevinHuSh's avatar KevinHuSh Committed by GitHub
Browse files

chage tas execution logic (#103)

parent 16eade4c
No related branches found
No related tags found
No related merge requests found
...@@ -15,6 +15,8 @@ import re ...@@ -15,6 +15,8 @@ import re
from collections import Counter from collections import Counter
from copy import deepcopy from copy import deepcopy
import numpy as np import numpy as np
from api.db import ParserType
from api.utils.file_utils import get_project_base_directory from api.utils.file_utils import get_project_base_directory
from deepdoc.vision import Recognizer from deepdoc.vision import Recognizer
...@@ -35,6 +37,7 @@ class LayoutRecognizer(Recognizer): ...@@ -35,6 +37,7 @@ class LayoutRecognizer(Recognizer):
] ]
def __init__(self, domain): def __init__(self, domain):
super().__init__(self.labels, domain, os.path.join(get_project_base_directory(), "rag/res/deepdoc/")) super().__init__(self.labels, domain, os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
self.garbage_layouts = ["footer", "header", "reference"]
def __call__(self, image_list, ocr_res, scale_factor=3, thr=0.2, batch_size=16): def __call__(self, image_list, ocr_res, scale_factor=3, thr=0.2, batch_size=16):
def __is_garbage(b): def __is_garbage(b):
...@@ -85,7 +88,7 @@ class LayoutRecognizer(Recognizer): ...@@ -85,7 +88,7 @@ class LayoutRecognizer(Recognizer):
i += 1 i += 1
continue continue
lts_[ii]["visited"] = True lts_[ii]["visited"] = True
if lts_[ii]["type"] in ["footer", "header", "reference"]: if lts_[ii]["type"] in self.garbage_layouts:
if lts_[ii]["type"] not in garbages: if lts_[ii]["type"] not in garbages:
garbages[lts_[ii]["type"]] = [] garbages[lts_[ii]["type"]] = []
garbages[lts_[ii]["type"]].append(bxs[i]["text"]) garbages[lts_[ii]["type"]].append(bxs[i]["text"])
......
...@@ -6,11 +6,10 @@ export LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu/ ...@@ -6,11 +6,10 @@ export LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu/
PY=/root/miniconda3/envs/py11/bin/python PY=/root/miniconda3/envs/py11/bin/python
function task_exe(){ function task_exe(){
sleep 60; while [ 1 -eq 1 ];do
while [ 1 -eq 1 ];do mpirun -n 4 --allow-run-as-root $PY rag/svr/task_executor.py ; done $PY rag/svr/task_executor.py $1 $2;
done
} }
function watch_broker(){ function watch_broker(){
...@@ -29,7 +28,12 @@ function task_bro(){ ...@@ -29,7 +28,12 @@ function task_bro(){
} }
task_bro & task_bro &
task_exe &
WS=8
for ((i=0;i<WS;i++))
do
task_exe $i $WS &
done
$PY api/ragflow_server.py $PY api/ragflow_server.py
......
...@@ -119,7 +119,6 @@ def add_positions(d, poss): ...@@ -119,7 +119,6 @@ def add_positions(d, poss):
d["page_num_int"].append(pn + 1) d["page_num_int"].append(pn + 1)
d["top_int"].append(top) d["top_int"].append(top)
d["position_int"].append((pn + 1, left, right, top, bottom)) d["position_int"].append((pn + 1, left, right, top, bottom))
d["top_int"] = d["top_int"][:1]
def remove_contents_table(sections, eng=False): def remove_contents_table(sections, eng=False):
......
...@@ -157,11 +157,11 @@ class EsQueryer: ...@@ -157,11 +157,11 @@ class EsQueryer:
s = 1e-9 s = 1e-9
for k, v in qtwt.items(): for k, v in qtwt.items():
if k in dtwt: if k in dtwt:
s += v * dtwt[k] s += v# * dtwt[k]
q = 1e-9 q = 1e-9
for k, v in qtwt.items(): for k, v in qtwt.items():
q += v * v q += v * v
d = 1e-9 d = 1e-9
for k, v in dtwt.items(): for k, v in dtwt.items():
d += v * v d += v * v
return s / math.sqrt(q) / math.sqrt(d) return s / q#math.sqrt(q) / math.sqrt(d)
...@@ -192,7 +192,7 @@ class Dealer: ...@@ -192,7 +192,7 @@ class Dealer:
return [float(t) for t in txt.split("\t")] return [float(t) for t in txt.split("\t")]
def insert_citations(self, answer, chunks, chunk_v, def insert_citations(self, answer, chunks, chunk_v,
embd_mdl, tkweight=0.3, vtweight=0.7): embd_mdl, tkweight=0.7, vtweight=0.3):
assert len(chunks) == len(chunk_v) assert len(chunks) == len(chunk_v)
pieces = re.split(r"([;。?!!\n]|[a-z][.?;!][ \n])", answer) pieces = re.split(r"([;。?!!\n]|[a-z][.?;!][ \n])", answer)
for i in range(1, len(pieces)): for i in range(1, len(pieces)):
...@@ -224,7 +224,7 @@ class Dealer: ...@@ -224,7 +224,7 @@ class Dealer:
chunks_tks, chunks_tks,
tkweight, vtweight) tkweight, vtweight)
mx = np.max(sim) * 0.99 mx = np.max(sim) * 0.99
if mx < 0.55: if mx < 0.35:
continue continue
cites[idx[i]] = list( cites[idx[i]] = list(
set([str(ii) for ii in range(len(chunk_v)) if sim[ii] > mx]))[:4] set([str(ii) for ii in range(len(chunk_v)) if sim[ii] > mx]))[:4]
...@@ -237,7 +237,7 @@ class Dealer: ...@@ -237,7 +237,7 @@ class Dealer:
if i not in cites: if i not in cites:
continue continue
for c in cites[i]: assert int(c) < len(chunk_v) for c in cites[i]: assert int(c) < len(chunk_v)
res += "##%s$$" % "$".join(cites[i]) for c in cites[i]: res += f" ##{c}$$"
return res return res
......
...@@ -152,6 +152,7 @@ class Dealer: ...@@ -152,6 +152,7 @@ class Dealer:
def ner(t): def ner(t):
if not self.ne or t not in self.ne: if not self.ne or t not in self.ne:
return 1 return 1
if re.match(r"[0-9,.]+$", t): return 2
m = {"toxic": 2, "func": 1, "corp": 3, "loca": 3, "sch": 3, "stock": 3, m = {"toxic": 2, "func": 1, "corp": 3, "loca": 3, "sch": 3, "stock": 3,
"firstnm": 1} "firstnm": 1}
return m[self.ne[t]] return m[self.ne[t]]
......
...@@ -36,3 +36,5 @@ es_logger = getLogger("es") ...@@ -36,3 +36,5 @@ es_logger = getLogger("es")
minio_logger = getLogger("minio") minio_logger = getLogger("minio")
cron_logger = getLogger("cron_logger") cron_logger = getLogger("cron_logger")
chunk_logger = getLogger("chunk_logger") chunk_logger = getLogger("chunk_logger")
database_logger = getLogger("database")
...@@ -23,13 +23,14 @@ import re ...@@ -23,13 +23,14 @@ import re
import sys import sys
import traceback import traceback
from functools import partial from functools import partial
from timeit import default_timer as timer
from rag.settings import database_logger
from rag.settings import cron_logger, DOC_MAXIMUM_SIZE
import numpy as np import numpy as np
from elasticsearch_dsl import Q from elasticsearch_dsl import Q
from api.db.services.task_service import TaskService from api.db.services.task_service import TaskService
from rag.settings import cron_logger, DOC_MAXIMUM_SIZE
from rag.utils import ELASTICSEARCH from rag.utils import ELASTICSEARCH
from rag.utils import MINIO from rag.utils import MINIO
from rag.utils import rmSpace, findMaxTm from rag.utils import rmSpace, findMaxTm
...@@ -43,7 +44,6 @@ from rag.app import laws, paper, presentation, manual, qa, table, book, resume, ...@@ -43,7 +44,6 @@ from rag.app import laws, paper, presentation, manual, qa, table, book, resume,
from api.db import LLMType, ParserType from api.db import LLMType, ParserType
from api.db.services.document_service import DocumentService from api.db.services.document_service import DocumentService
from api.db.services.llm_service import LLMBundle from api.db.services.llm_service import LLMBundle
from api.settings import database_logger
from api.utils.file_utils import get_project_base_directory from api.utils.file_utils import get_project_base_directory
BATCH_SIZE = 64 BATCH_SIZE = 64
...@@ -267,4 +267,4 @@ if __name__ == "__main__": ...@@ -267,4 +267,4 @@ if __name__ == "__main__":
from mpi4py import MPI from mpi4py import MPI
comm = MPI.COMM_WORLD comm = MPI.COMM_WORLD
main(comm.Get_size(), comm.Get_rank()) main(int(sys.argv[2]), int(sys.argv[1]))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment