From 3fc700a1d4fa1cd232c55cb1555dfc6fc11995b2 Mon Sep 17 00:00:00 2001
From: KevinHuSh <kevinhu.sh@gmail.com>
Date: Tue, 26 Dec 2023 19:32:06 +0800
Subject: [PATCH] build dialog server; add thumbnail to docinfo; (#17)

---
 Cargo.toml                                    |  1 +
 .../src/m20220101_000001_create_table.rs      | 15 +++---
 python/conf/sys.cnf                           |  8 +--
 python/llm/chat_model.py                      |  1 +
 python/nlp/search.py                          |  2 +
 python/svr/dialog_svr.py                      | 24 ++++-----
 python/svr/parse_user_docs.py                 | 12 ++---
 src/api/doc_info.rs                           | 50 +++++++++++++++++--
 src/api/user_info.rs                          |  3 +-
 src/entity/doc_info.rs                        |  2 +
 src/entity/tag2_doc.rs                        | 14 +++---
 src/service/doc_info.rs                       |  4 +-
 12 files changed, 94 insertions(+), 42 deletions(-)

diff --git a/Cargo.toml b/Cargo.toml
index 883d2fa..b324cf3 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -26,6 +26,7 @@ migration = { path = "./migration" }
 minio = "0.1.0"
 futures-util = "0.3.29"
 actix-multipart-extract = "0.1.5"
+regex = "1.10.2"
 
 [[bin]]
 name = "doc_gpt"
diff --git a/migration/src/m20220101_000001_create_table.rs b/migration/src/m20220101_000001_create_table.rs
index cbd8173..41287e8 100644
--- a/migration/src/m20220101_000001_create_table.rs
+++ b/migration/src/m20220101_000001_create_table.rs
@@ -201,7 +201,8 @@ impl MigrationTrait for Migration {
                 .col(ColumnDef::new(DocInfo::Location).string().not_null())
                 .col(ColumnDef::new(DocInfo::Size).big_integer().not_null())
                 .col(ColumnDef::new(DocInfo::Type).string().not_null())
-                .comment("doc|folder")
+                .col(ColumnDef::new(DocInfo::ThumbnailBase64).string().not_null())
+                .comment("doc type|folder")
                 .col(
                     ColumnDef::new(DocInfo::CreatedAt)
                         .timestamp_with_time_zone()
@@ -249,7 +250,6 @@ impl MigrationTrait for Migration {
                 .to_owned()
         ).await?;
 
-        let tm = now();
         let root_insert = Query::insert()
             .into_table(UserInfo::Table)
             .columns([UserInfo::Email, UserInfo::Nickname, UserInfo::Password])
@@ -273,28 +273,28 @@ impl MigrationTrait for Migration {
             .columns([TagInfo::Uid, TagInfo::TagName, TagInfo::Regx, TagInfo::Color, TagInfo::Icon])
             .values_panic([
                 (1).into(),
-                "视频".into(),
+                "Video".into(),
                 ".*\\.(mpg|mpeg|avi|rm|rmvb|mov|wmv|asf|dat|asx|wvx|mpe|mpa)".into(),
                 (1).into(),
                 (1).into(),
             ])
             .values_panic([
                 (1).into(),
-                "图片".into(),
-                ".*\\.(png|tif|gif|pcx|tga|exif|fpx|svg|psd|cdr|pcd|dxf|ufo|eps|ai|raw|WMF|webp|avif|apng)".into(),
+                "Picture".into(),
+                ".*\\.(jpg|jpeg|png|tif|gif|pcx|tga|exif|fpx|svg|psd|cdr|pcd|dxf|ufo|eps|ai|raw|WMF|webp|avif|apng)".into(),
                 (2).into(),
                 (2).into(),
             ])
             .values_panic([
                 (1).into(),
-                "音乐".into(),
+                "Music".into(),
                 ".*\\.(WAV|FLAC|APE|ALAC|WavPack|WV|MP3|AAC|Ogg|Vorbis|Opus)".into(),
                 (3).into(),
                 (3).into(),
             ])
             .values_panic([
                 (1).into(),
-                "文档".into(),
+                "Document".into(),
                 ".*\\.(pdf|doc|ppt|yml|xml|htm|json|csv|txt|ini|xsl|wps|rtf|hlp)".into(),
                 (3).into(),
                 (3).into(),
@@ -419,6 +419,7 @@ enum DocInfo {
     Location,
     Size,
     Type,
+    ThumbnailBase64,
     CreatedAt,
     UpdatedAt,
     IsDeleted,
diff --git a/python/conf/sys.cnf b/python/conf/sys.cnf
index e217ad5..adff7cc 100755
--- a/python/conf/sys.cnf
+++ b/python/conf/sys.cnf
@@ -1,10 +1,10 @@
 [infiniflow]
-es=http://127.0.0.1:9200
+es=http://es01:9200
 pgdb_usr=root
 pgdb_pwd=infiniflow_docgpt
-pgdb_host=127.0.0.1
-pgdb_port=5455
-minio_host=127.0.0.1:9000
+pgdb_host=postgres
+pgdb_port=5432
+minio_host=minio:9000
 minio_usr=infiniflow
 minio_pwd=infiniflow_docgpt
 
diff --git a/python/llm/chat_model.py b/python/llm/chat_model.py
index 49c2518..34fb7d1 100644
--- a/python/llm/chat_model.py
+++ b/python/llm/chat_model.py
@@ -24,6 +24,7 @@ class QWen(Base):
         from http import HTTPStatus
         from dashscope import Generation
         from dashscope.api_entities.dashscope_response import Role
+        # export DASHSCOPE_API_KEY=YOUR_DASHSCOPE_API_KEY
         response = Generation.call(
                     Generation.Models.qwen_turbo,
                     messages=messages,
diff --git a/python/nlp/search.py b/python/nlp/search.py
index e751b66..1fbd798 100644
--- a/python/nlp/search.py
+++ b/python/nlp/search.py
@@ -9,6 +9,8 @@ from sklearn.metrics.pairwise import cosine_similarity as CosineSimilarity
 import numpy as np
 from copy import deepcopy
 
+def index_name(uid):return f"docgpt_{uid}"
+
 class Dealer:
     def __init__(self, es, emb_mdl):
         self.qryr = query.EsQueryer(es)
diff --git a/python/svr/dialog_svr.py b/python/svr/dialog_svr.py
index 5d683d6..80f9f5e 100755
--- a/python/svr/dialog_svr.py
+++ b/python/svr/dialog_svr.py
@@ -6,11 +6,10 @@ from tornado.ioloop import IOLoop
 from tornado.httpserver import HTTPServer
 from tornado.options import define,options
 from util import es_conn, setup_logging
-from svr import sec_search as search
-from svr.rpc_proxy import RPCProxy
 from sklearn.metrics.pairwise import cosine_similarity as CosineSimilarity
 from nlp import huqie
 from nlp import query as Query
+from nlp import search
 from llm import HuEmbedding, GptTurbo
 import numpy as np
 from io import BytesIO
@@ -38,7 +37,7 @@ def get_QA_pairs(hists):
 
 
 
-def get_instruction(sres, top_i, max_len=8096 fld="content_ltks"):
+def get_instruction(sres, top_i, max_len=8096, fld="content_ltks"):
     max_len //= len(top_i)
     # add instruction to prompt
     instructions = [re.sub(r"[\r\n]+", " ", sres.field[sres.ids[i]][fld]) for i in top_i]
@@ -96,10 +95,11 @@ class Handler(RequestHandler):
         try:
             question = param.get("history",[{"user": "Hi!"}])[-1]["user"]
             res = SE.search({
-                "question": question,
-                "kb_ids": param.get("kb_ids", []),
-                "size": param.get("topn", 15)
-            })
+                    "question": question,
+                    "kb_ids": param.get("kb_ids", []),
+                    "size": param.get("topn", 15)},
+               search.index_name(param["uid"]) 
+            )
 
             sim = SE.rerank(res, question)  
             rk_idx = np.argsort(sim*-1)
@@ -112,12 +112,12 @@ class Handler(RequestHandler):
             refer = OrderedDict()
             docnms = {}
             for i in rk_idx:
-                 did = res.field[res.ids[i]]["doc_id"])
-                 if did not in docnms: docnms[did] = res.field[res.ids[i]]["docnm_kwd"])
+                 did = res.field[res.ids[i]]["doc_id"]
+                 if did not in docnms: docnms[did] = res.field[res.ids[i]]["docnm_kwd"]
                  if did not in refer: refer[did] = []
                  refer[did].append({
                      "chunk_id": res.ids[i],
-                     "content": res.field[res.ids[i]]["content_ltks"]),
+                     "content": res.field[res.ids[i]]["content_ltks"],
                      "image": ""
                  })
 
@@ -128,7 +128,7 @@ class Handler(RequestHandler):
                 "data":{
                     "uid": param["uid"],
                     "dialog_id": param["dialog_id"],
-                    "assistant": ans
+                    "assistant": ans,
                     "refer": [{
                         "did": did,
                         "doc_name": docnms[did],
@@ -153,7 +153,7 @@ if __name__ == '__main__':
     parser.add_argument("--port", default=4455, type=int, help="Port used for service")
     ARGS = parser.parse_args()
     
-    SE = search.ResearchReportSearch(es_conn.HuEs("infiniflow"), EMBEDDING)
+    SE = search.Dealer(es_conn.HuEs("infiniflow"), EMBEDDING)
 
     app = Application([(r'/v1/chat/completions', Handler)],debug=False)
     http_server = HTTPServer(app)
diff --git a/python/svr/parse_user_docs.py b/python/svr/parse_user_docs.py
index 573cd1a..cad4305 100644
--- a/python/svr/parse_user_docs.py
+++ b/python/svr/parse_user_docs.py
@@ -6,7 +6,7 @@ from util.db_conn import Postgres
 from util.minio_conn import HuMinio
 from util import rmSpace, findMaxDt
 from FlagEmbedding import FlagModel
-from nlp import huchunk, huqie
+from nlp import huchunk, huqie, search
 import base64, hashlib
 from io import BytesIO
 import pandas as pd
@@ -103,7 +103,7 @@ def build(row):
                                if(!ctx._source.kb_id.contains('%s'))
                                  ctx._source.kb_id.add('%s');
                                """%(str(row["kb_id"]), str(row["kb_id"])),
-                               idxnm = index_name(row["uid"])
+                               idxnm = search.index_name(row["uid"])
                               )
         set_progress(row["kb2doc_id"], 1, "Done")
         return []
@@ -171,10 +171,8 @@ def build(row):
     return docs
 
 
-def index_name(uid):return f"docgpt_{uid}"
-
 def init_kb(row):
-    idxnm = index_name(row["uid"])
+    idxnm = search.index_name(row["uid"])
     if ES.indexExist(idxnm): return
     return ES.createIdx(idxnm, json.load(open("conf/mapping.json", "r")))
 
@@ -199,7 +197,7 @@ def rm_doc_from_kb(df):
                                      ctx._source.kb_id.indexOf('%s')
                                );
                                 """%(str(r["kb_id"]),str(r["kb_id"])),
-                               idxnm = index_name(r["uid"])
+                               idxnm = search.index_name(r["uid"])
                               )
     if len(df) == 0:return
     sql = """
@@ -233,7 +231,7 @@ def main(comm, mod):
         set_progress(r["kb2doc_id"], random.randint(70, 95)/100., 
                      "Finished embedding! Start to build index!")
         init_kb(r)
-        es_r = ES.bulk(cks, index_name(r["uid"]))
+        es_r = ES.bulk(cks, search.index_name(r["uid"]))
         if es_r:
             set_progress(r["kb2doc_id"], -1, "Index failure!")
             print(es_r)
diff --git a/src/api/doc_info.rs b/src/api/doc_info.rs
index ba8006b..9169bf0 100644
--- a/src/api/doc_info.rs
+++ b/src/api/doc_info.rs
@@ -11,6 +11,7 @@ use crate::entity::doc_info::Model;
 use crate::errors::AppError;
 use crate::service::doc_info::{ Mutation, Query };
 use serde::Deserialize;
+use regex::Regex;
 
 fn now() -> chrono::DateTime<FixedOffset> {
     Utc::now().with_timezone(&FixedOffset::east_opt(3600 * 8).unwrap())
@@ -64,6 +65,41 @@ pub struct UploadForm {
     did: i64,
 }
 
+fn file_type(filename: &String) -> String {
+    let fnm = filename.to_lowercase();
+    if
+        let Some(_) = Regex::new(r"\.(mpg|mpeg|avi|rm|rmvb|mov|wmv|asf|dat|asx|wvx|mpe|mpa)$")
+            .unwrap()
+            .captures(&fnm)
+    {
+        return "Video".to_owned();
+    }
+    if
+        let Some(_) = Regex::new(
+            r"\.(jpg|jpeg|png|tif|gif|pcx|tga|exif|fpx|svg|psd|cdr|pcd|dxf|ufo|eps|ai|raw|WMF|webp|avif|apng)$"
+        )
+            .unwrap()
+            .captures(&fnm)
+    {
+        return "Picture".to_owned();
+    }
+    if
+        let Some(_) = Regex::new(r"\.(WAV|FLAC|APE|ALAC|WavPack|WV|MP3|AAC|Ogg|Vorbis|Opus)$")
+            .unwrap()
+            .captures(&fnm)
+    {
+        return "Music".to_owned();
+    }
+    if
+        let Some(_) = Regex::new(r"\.(pdf|doc|ppt|yml|xml|htm|json|csv|txt|ini|xsl|wps|rtf|hlp)$")
+            .unwrap()
+            .captures(&fnm)
+    {
+        return "Document".to_owned();
+    }
+    "Other".to_owned()
+}
+
 #[post("/v1.0/upload")]
 async fn upload(
     payload: Multipart<UploadForm>,
@@ -114,7 +150,13 @@ async fn upload(
         print!("Existing bucket: {}", bucket_name.clone());
     }
 
-    let location = format!("/{}/{}", payload.did, fnm);
+    let location = format!("/{}/{}", payload.did, fnm)
+        .as_bytes()
+        .to_vec()
+        .iter()
+        .map(|b| format!("{:02x}", b).to_string())
+        .collect::<Vec<String>>()
+        .join("");
     print!("===>{}", location.clone());
     s3_client.put_object(
         &mut PutObjectArgs::new(
@@ -129,10 +171,11 @@ async fn upload(
     let doc = Mutation::create_doc_info(&data.conn, Model {
         did: Default::default(),
         uid: uid,
-        doc_name: fnm,
+        doc_name: fnm.clone(),
         size: payload.file_field.bytes.len() as i64,
         location,
-        r#type: "doc".to_string(),
+        r#type: file_type(&fnm),
+        thumbnail_base64: Default::default(),
         created_at: now(),
         updated_at: now(),
         is_deleted: Default::default(),
@@ -214,6 +257,7 @@ async fn new_folder(
         size: 0,
         r#type: "folder".to_string(),
         location: "".to_owned(),
+        thumbnail_base64: Default::default(),
         created_at: now(),
         updated_at: now(),
         is_deleted: Default::default(),
diff --git a/src/api/user_info.rs b/src/api/user_info.rs
index 09eeb5d..3d9f89e 100644
--- a/src/api/user_info.rs
+++ b/src/api/user_info.rs
@@ -90,12 +90,13 @@ async fn register(
         doc_name: "/".into(),
         size: 0,
         location: "".into(),
+        thumbnail_base64: "".into(),
         r#type: "folder".to_string(),
         created_at: now(),
         updated_at: now(),
         is_deleted: Default::default(),
     }).await?;
-    let tnm = vec!["视频", "图片", "音乐", "文档"];
+    let tnm = vec!["Video", "Picture", "Music", "Document"];
     let tregx = vec![
         ".*\\.(mpg|mpeg|avi|rm|rmvb|mov|wmv|asf|dat|asx|wvx|mpe|mpa)",
         ".*\\.(png|tif|gif|pcx|tga|exif|fpx|svg|psd|cdr|pcd|dxf|ufo|eps|ai|raw|WMF|webp|avif|apng)",
diff --git a/src/entity/doc_info.rs b/src/entity/doc_info.rs
index b11b67f..26c46d5 100644
--- a/src/entity/doc_info.rs
+++ b/src/entity/doc_info.rs
@@ -17,6 +17,8 @@ pub struct Model {
     #[serde(skip_deserializing)]
     pub location: String,
     #[serde(skip_deserializing)]
+    pub thumbnail_base64: String,
+    #[serde(skip_deserializing)]
     pub created_at: DateTime<FixedOffset>,
     #[serde(skip_deserializing)]
     pub updated_at: DateTime<FixedOffset>,
diff --git a/src/entity/tag2_doc.rs b/src/entity/tag2_doc.rs
index 2a37f1c..468c5fd 100644
--- a/src/entity/tag2_doc.rs
+++ b/src/entity/tag2_doc.rs
@@ -9,28 +9,28 @@ pub struct Model {
     #[sea_orm(index)]
     pub tag_id: i64,
     #[sea_orm(index)]
-    pub uid: i64,
+    pub did: i64,
 }
 
 #[derive(Debug, Clone, Copy, EnumIter)]
 pub enum Relation {
-    DocInfo,
     Tag,
+    DocInfo,
 }
 
 impl RelationTrait for Relation {
     fn def(&self) -> sea_orm::RelationDef {
         match self {
-            Self::DocInfo =>
-                Entity::belongs_to(super::doc_info::Entity)
-                    .from(Column::Uid)
-                    .to(super::doc_info::Column::Uid)
-                    .into(),
             Self::Tag =>
                 Entity::belongs_to(super::tag_info::Entity)
                     .from(Column::TagId)
                     .to(super::tag_info::Column::Tid)
                     .into(),
+            Self::DocInfo =>
+                Entity::belongs_to(super::doc_info::Entity)
+                    .from(Column::Did)
+                    .to(super::doc_info::Column::Did)
+                    .into(),
         }
     }
 }
diff --git a/src/service/doc_info.rs b/src/service/doc_info.rs
index 480669f..18fc87f 100644
--- a/src/service/doc_info.rs
+++ b/src/service/doc_info.rs
@@ -163,7 +163,7 @@ impl Query {
                 );
             }
             if tag.regx.len() > 0 {
-                cond.push_str(&format!(" and doc_name ~ '{}'", tag.regx));
+                cond.push_str(&format!(" and (type='{}' or doc_name ~ '{}') ", tag.tag_name, tag.regx));
             }
         }
 
@@ -254,6 +254,7 @@ impl Mutation {
             size: Set(form_data.size.to_owned()),
             r#type: Set(form_data.r#type.to_owned()),
             location: Set(form_data.location.to_owned()),
+            thumbnail_base64: Default::default(),
             created_at: Set(form_data.created_at.to_owned()),
             updated_at: Set(form_data.updated_at.to_owned()),
             is_deleted: Default::default(),
@@ -277,6 +278,7 @@ impl Mutation {
             size: Set(form_data.size.to_owned()),
             r#type: Set(form_data.r#type.to_owned()),
             location: Set(form_data.location.to_owned()),
+            thumbnail_base64: doc_info.thumbnail_base64,
             created_at: doc_info.created_at,
             updated_at: Set(now()),
             is_deleted: Default::default(),
-- 
GitLab