diff --git a/Cargo.toml b/Cargo.toml
index d4cd561b816759ec5329e7927385bcd27c9cd805..aed22e1b16e2c23523b144253d6fcd9f494fadc1 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -24,6 +24,7 @@ listenfd = "1.0.1"
 chrono = "0.4.31"
 migration = { path = "./migration" }
 futures-util = "0.3.29"
+actix-multipart-extract = "0.1.5"
 
 [[bin]]
 name = "doc_gpt"
diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml
index 6bdc0f472bfb30544d64e2d0bbe8330c09987b4f..249e5ab5547292e153f67fe03e45d33351e6790e 100644
--- a/docker/docker-compose.yml
+++ b/docker/docker-compose.yml
@@ -1,7 +1,7 @@
 version: '2.2'
 services:
   es01:
-    container_name: docass-es-01
+    container_name: docgpt-es-01
     image: docker.elastic.co/elasticsearch/elasticsearch:${STACK_VERSION}
     volumes:
       - esdata01:/usr/share/elasticsearch/data
@@ -20,14 +20,14 @@ services:
         soft: -1
         hard: -1
     networks:
      - docass
+      - docgpt
     restart: always

   kibana:
     depends_on:
       - es01
     image: docker.elastic.co/kibana/kibana:${STACK_VERSION}
-    container_name: docass-kibana
+    container_name: docgpt-kibana
     volumes:
       - kibanadata:/usr/share/kibana/data
     ports:
@@ -37,21 +37,21 @@ services:
       - ELASTICSEARCH_HOSTS=http://es01:9200
     mem_limit: ${MEM_LIMIT}
     networks:
-      - docass
+      - docgpt

   postgres:
     image: postgres
-    container_name: docass-postgres
+    container_name: docgpt-postgres
     environment:
       - POSTGRES_USER=${POSTGRES_USER}
       - POSTGRES_PASSWORD=${POSTGRES_PASSWORD}
       - POSTGRES_DB=${POSTGRES_DB}
     ports:
-      - 5455:5455
+      - 5455:5432
     volumes:
-      - pg_data:/usr/share/elasticsearch/data
+      - pg_data:/var/lib/postgresql/data
     networks:
-      - docass
+      - docgpt
     restart: always

@@ -64,5 +64,5 @@ volumes:
     driver: local

 networks:
-  docass:
+  docgpt:
     driver: bridge
diff --git a/migration/src/m20220101_000001_create_table.rs b/migration/src/m20220101_000001_create_table.rs
index 3c0859dbbcaa8d34f0c6d7311c90d7fafe1981c7..f310a29081b9123b4a0fd07d99cd488ee2430dfd 100644
--- a/migration/src/m20220101_000001_create_table.rs
+++ b/migration/src/m20220101_000001_create_table.rs
@@ -47,8 +47,8 @@ impl MigrationTrait for Migration {
                     .col(ColumnDef::new(TagInfo::Uid).big_integer().not_null())
                     .col(ColumnDef::new(TagInfo::TagName).string().not_null())
                     .col(ColumnDef::new(TagInfo::Regx).string())
-                    .col(ColumnDef::new(TagInfo::Color).big_integer().default(1))
-                    .col(ColumnDef::new(TagInfo::Icon).big_integer().default(1))
+                    .col(ColumnDef::new(TagInfo::Color).tiny_unsigned().default(1))
+                    .col(ColumnDef::new(TagInfo::Icon).tiny_unsigned().default(1))
                     .col(ColumnDef::new(TagInfo::Dir).string())
                     .col(ColumnDef::new(TagInfo::CreatedAt).date().not_null())
                     .col(ColumnDef::new(TagInfo::UpdatedAt).date().not_null())
@@ -62,6 +62,13 @@ impl MigrationTrait for Migration {
                 Table::create()
                     .table(Tag2Doc::Table)
                     .if_not_exists()
+                    .col(
+                        ColumnDef::new(Tag2Doc::Id)
+                            .big_integer()
+                            .not_null()
+                            .auto_increment()
+                            .primary_key(),
+                    )
                     .col(ColumnDef::new(Tag2Doc::TagId).big_integer())
                     .col(ColumnDef::new(Tag2Doc::Did).big_integer())
                     .to_owned(),
@@ -73,6 +80,13 @@ impl MigrationTrait for Migration {
                 Table::create()
                     .table(Kb2Doc::Table)
                     .if_not_exists()
+                    .col(
+                        ColumnDef::new(Kb2Doc::Id)
+                            .big_integer()
+                            .not_null()
+                            .auto_increment()
+                            .primary_key(),
+                    )
                     .col(ColumnDef::new(Kb2Doc::KbId).big_integer())
                     .col(ColumnDef::new(Kb2Doc::Did).big_integer())
                     .to_owned(),
@@ -84,6 +98,13 @@ impl MigrationTrait for Migration {
                 Table::create()
                     .table(Dialog2Kb::Table)
                     .if_not_exists()
+                    .col(
+                        ColumnDef::new(Dialog2Kb::Id)
+                            .big_integer()
+                            .not_null()
+                            .auto_increment()
+                            .primary_key(),
+                    )
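+                    // NOTE: an editor's reading, not stated in the patch -- SeaORM
+                    // entities need a primary key, and these link tables previously
+                    // carried composite keys declared `auto_increment = false`; a
+                    // surrogate auto-increment `Id` (here and on the sibling link
+                    // tables) gives each row a single updatable key.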
                    .col(ColumnDef::new(Dialog2Kb::DialogId).big_integer())
                    .col(ColumnDef::new(Dialog2Kb::KbId).big_integer())
                    .to_owned(),
@@ -95,6 +116,13 @@ impl MigrationTrait for Migration {
                 Table::create()
                     .table(Doc2Doc::Table)
                     .if_not_exists()
+                    .col(
+                        ColumnDef::new(Doc2Doc::Id)
+                            .big_integer()
+                            .not_null()
+                            .auto_increment()
+                            .primary_key(),
+                    )
                     .col(ColumnDef::new(Doc2Doc::ParentId).big_integer())
                     .col(ColumnDef::new(Doc2Doc::Did).big_integer())
                     .to_owned(),
@@ -112,7 +140,7 @@ impl MigrationTrait for Migration {
                     .primary_key())
                     .col(ColumnDef::new(KbInfo::Uid).big_integer().not_null())
                     .col(ColumnDef::new(KbInfo::KbName).string().not_null())
-                    .col(ColumnDef::new(KbInfo::Icon).big_integer().default(1))
+                    .col(ColumnDef::new(KbInfo::Icon).tiny_unsigned().default(1))
                     .col(ColumnDef::new(KbInfo::CreatedAt).date().not_null())
                     .col(ColumnDef::new(KbInfo::UpdatedAt).date().not_null())
                     .col(ColumnDef::new(KbInfo::IsDeleted).boolean().default(false))
@@ -135,6 +163,7 @@ impl MigrationTrait for Migration {
                     .col(ColumnDef::new(DocInfo::Size).big_integer().not_null())
                     .col(ColumnDef::new(DocInfo::Type).string().not_null()).comment("doc|folder")
                     .col(ColumnDef::new(DocInfo::KbProgress).float().default(0))
+                    .col(ColumnDef::new(DocInfo::KbProgressMsg).string().default(""))
                     .col(ColumnDef::new(DocInfo::CreatedAt).date().not_null())
                     .col(ColumnDef::new(DocInfo::UpdatedAt).date().not_null())
                     .col(ColumnDef::new(DocInfo::IsDeleted).boolean().default(false))
@@ -148,7 +177,7 @@ impl MigrationTrait for Migration {
                     .table(DialogInfo::Table)
                     .if_not_exists()
                     .col(ColumnDef::new(DialogInfo::DialogId)
-                        .big_integer()
+                        .big_integer()
                         .not_null()
                         .auto_increment()
                         .primary_key())
@@ -240,6 +269,7 @@ enum TagInfo {
 #[derive(DeriveIden)]
 enum Tag2Doc {
     Table,
+    Id,
     TagId,
     Did,
 }
@@ -247,6 +277,7 @@ enum Tag2Doc {
 #[derive(DeriveIden)]
 enum Kb2Doc {
     Table,
+    Id,
     KbId,
     Did,
 }
@@ -254,6 +285,7 @@ enum Kb2Doc {
 #[derive(DeriveIden)]
 enum Dialog2Kb {
     Table,
+    Id,
     DialogId,
     KbId,
 }
@@ -261,6 +293,7 @@ enum Dialog2Kb {
 #[derive(DeriveIden)]
 enum Doc2Doc {
     Table,
+    Id,
     ParentId,
     Did,
 }
@@ -287,6 +320,7 @@ enum DocInfo {
     Size,
     Type,
     KbProgress,
+    KbProgressMsg,
     CreatedAt,
     UpdatedAt,
     IsDeleted,
@@ -302,4 +336,4 @@ enum DialogInfo {
     CreatedAt,
     UpdatedAt,
     IsDeleted,
-}
\ No newline at end of file
+}
diff --git a/python/conf/sys.cnf b/python/conf/sys.cnf
index fc0d64c414082eb6610298c5a8f709e15332e37f..b8d3268dd05cfb17126cbe18e11007f56d8182b6 100755
--- a/python/conf/sys.cnf
+++ b/python/conf/sys.cnf
@@ -1,8 +1,7 @@
-[online]
+[infiniflow]
 es=127.0.0.1:9200
-idx_nm=toxic
 pgdb_usr=root
-pgdb_pwd=infiniflow_docass
+pgdb_pwd=infiniflow_docgpt
 pgdb_host=127.0.0.1
-pgdb_port=5432
+pgdb_port=5455
diff --git a/python/nlp/huchunk.py b/python/nlp/huchunk.py
index 619640227800c61fced22812d713a303b3581efe..6164375d9dcbb34fe56e10d916d7fb0feffbaaf6 100644
--- a/python/nlp/huchunk.py
+++ b/python/nlp/huchunk.py
@@ -359,6 +359,47 @@ class ExcelChunker(HuChunker):
         return flds
 
 
+class PptChunker(HuChunker):
+
+    @dataclass
+    class Fields:
+        text_chunks: List = None
+        table_chunks: List = None
+
+    def __init__(self):
+        super().__init__()
+
+    def __call__(self, fnm):
+        from pptx import Presentation
+        ppt = Presentation(fnm)
+        flds = self.Fields()
+        # Fields defaults both lists to None, so initialize them before appending.
+        flds.text_chunks = []
+        flds.table_chunks = []
+        for slide in ppt.slides:
+            for shape in slide.shapes:
+                if hasattr(shape, "text"):
+                    flds.text_chunks.append((shape.text, None))
+        return flds
+
+
+class TextChunker(HuChunker):
+
+    @dataclass
+    class Fields:
+        text_chunks: List = None
+        table_chunks: List = None
+
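+    # Chunker contract, as the new parse_user_docs.py below consumes it:
+    # calling a chunker with a file path yields Fields, where text_chunks is a
+    # list of (text, image) pairs and table_chunks a list of (rows, image)
+    # pairs; image may be None.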
+    def __init__(self):
+        super().__init__()
+
+    def __call__(self, fnm):
+        flds = self.Fields()
+        with open(fnm, "r") as f:
+            txt = f.read()
+        flds.text_chunks = self.naive_text_chunk(txt)
+        flds.table_chunks = []
+        return flds
+
+
 if __name__ == "__main__":
     import sys
     sys.path.append(os.path.dirname(__file__) + "/../")
diff --git a/python/svr/parse_user_docs.py b/python/svr/parse_user_docs.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ff14c40ed508f1b228d89085df78dd49bc9f943
--- /dev/null
+++ b/python/svr/parse_user_docs.py
@@ -0,0 +1,171 @@
+import json, re, sys, os, hashlib, copy, glob, util, time, random
+from util.es_conn import HuEs, Postgres
+from util import rmSpace, findMaxDt
+from FlagEmbedding import FlagModel
+from nlp import huchunk, huqie
+import base64
+from io import BytesIO
+from elasticsearch_dsl import Q
+from parser import (
+    PdfParser,
+    DocxParser,
+    ExcelParser
+)
+from nlp.huchunk import (
+    PdfChunker,
+    DocxChunker,
+    ExcelChunker,
+    PptChunker,
+    TextChunker
+)
+
+ES = HuEs("infiniflow")
+BATCH_SIZE = 64
+PG = Postgres("infiniflow", "docgpt")
+
+PDF = PdfChunker(PdfParser())
+DOC = DocxChunker(DocxParser())
+EXC = ExcelChunker(ExcelParser())
+PPT = PptChunker()
+
+
+def chuck_doc(name):
+    # Dispatch on the file extension without shadowing the path itself;
+    # the ppt/pdf branches were previously swapped, and TextChunker needs
+    # to be instantiated before it is called.
+    suff = os.path.split(name)[-1].lower().split(".")[-1]
+    if suff.find("pdf") >= 0: return PDF(name)
+    if suff.find("doc") >= 0: return DOC(name)
+    if suff.find("xlsx") >= 0: return EXC(name)
+    if suff.find("ppt") >= 0: return PPT(name)
+    if re.match(r"(txt|csv)", suff): return TextChunker()(name)
+
+
+def collect(comm, mod, tm):
+    # size and is_deleted are read downstream (build() and main()), so they
+    # must be part of the selection.
+    sql = f"""
+    select
+    did,
+    uid,
+    doc_name,
+    size,
+    location,
+    updated_at,
+    is_deleted
+    from docinfo
+    where
+    updated_at >= '{tm}'
+    and kb_progress = 0
+    and type = 'doc'
+    and MOD(uid, {comm}) = {mod}
+    order by updated_at asc
+    limit 1000
+    """
+    df = PG.select(sql)
+    df = df.fillna("")
+    mtm = str(df["updated_at"].max())[:19]
+    print("TOTAL:", len(df), "To: ", mtm)
+    return df, mtm
+
+
+def set_progress(did, prog, msg=""):
+    # msg defaults to empty so progress-only updates are legal.
+    sql = f"""
+    update docinfo set kb_progress={prog}, kb_progress_msg='{msg}' where did={did}
+    """
+    PG.update(sql)
+
+
+def build(row):
+    if row["size"] > 256000000:
+        set_progress(row["did"], -1, "File size exceeds the limit (<= 256MB).")
+        return []
+    doc = {
+        "doc_id": row["did"],
+        "title_tks": huqie.qie(os.path.split(row["location"])[-1]),
+        "updated_at": row["updated_at"]
+    }
+    random.seed(time.time())
+    set_progress(row["did"], random.randint(0, 20)/100., "Finished preparing! Start to slice the file!")
+    obj = chuck_doc(row["location"])
+    if not obj:
+        set_progress(row["did"], -1, "Unsupported file type.")
+        return []
+
+    set_progress(row["did"], random.randint(20, 60)/100.)
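+    # What the rest of build() does: every chunk becomes one ES document whose
+    # _id comes from an MD5 hasher rolled over (chunk text + doc id), so
+    # re-parsing the same file reproduces the same ids instead of duplicating
+    # documents; chunk images are attached base64-encoded as JPEG.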
+
+    docs = []
+    md5 = hashlib.md5()
+    for txt, img in obj.text_chunks:
+        d = copy.deepcopy(doc)
+        md5.update((txt + str(d["doc_id"])).encode("utf-8"))
+        d["_id"] = md5.hexdigest()
+        d["content_ltks"] = huqie.qie(txt)
+        d["docnm_kwd"] = rmSpace(d["title_tks"])
+        if not img:
+            docs.append(d)
+            continue
+        # A fresh buffer per image; a single shared buffer would keep
+        # accumulating every previous image.
+        output_buffer = BytesIO()
+        img.save(output_buffer, format='JPEG')
+        d["img_bin"] = base64.b64encode(output_buffer.getvalue())
+        docs.append(d)
+
+    for arr, img in obj.table_chunks:
+        for i, txt in enumerate(arr):
+            d = copy.deepcopy(doc)
+            d["content_ltks"] = huqie.qie(txt)
+            md5.update((txt + str(d["doc_id"])).encode("utf-8"))
+            d["_id"] = md5.hexdigest()
+            if not img:
+                docs.append(d)
+                continue
+            output_buffer = BytesIO()
+            img.save(output_buffer, format='JPEG')
+            d["img_bin"] = base64.b64encode(output_buffer.getvalue())
+            docs.append(d)
+    set_progress(row["did"], random.randint(60, 70)/100., "Finished slicing. Start to embed the content.")
+
+    return docs
+
+
+def index_name(uid): return f"docgpt_{uid}"
+
+
+def init_kb(row):
+    idxnm = index_name(row["uid"])
+    if ES.indexExist(idxnm): return
+    return ES.createIdx(idxnm, json.load(open("res/mapping.json", "r")))
+
+
+# NOTE: the checkpoint below is an editor's placeholder -- the patch leaves
+# `model` set to None, which would crash embedding(); substitute whatever
+# FlagEmbedding checkpoint the deployment actually uses.
+model = FlagModel("BAAI/bge-large-zh-v1.5")
+
+
+def embedding(docs):
+    tts = model.encode([rmSpace(d["title_tks"]) for d in docs])
+    cnts = model.encode([rmSpace(d["content_ltks"]) for d in docs])
+    # Title and body vectors are mixed 1:9, so the content dominates.
+    vects = 0.1 * tts + 0.9 * cnts
+    assert len(vects) == len(docs)
+    for i, d in enumerate(docs):
+        d["q_vec"] = vects[i].tolist()
+    for d in docs:
+        set_progress(d["doc_id"], random.randint(70, 95)/100.,
+                     "Finished embedding! Start to build the index!")
+
+
+def main(comm, mod):
+    tm_fnm = f"res/{comm}-{mod}.tm"
+    tmf = open(tm_fnm, "a+")
+    tm = findMaxDt(tm_fnm)
+    rows, tm = collect(comm, mod, tm)
+    # collect() returns a DataFrame, so iterate over rows, not column names.
+    for _, r in rows.iterrows():
+        if r["is_deleted"]:
+            ES.deleteByQuery(Q("term", doc_id=r["did"]), index_name(r["uid"]))
+            continue
+
+        cks = build(r)
+        ## TODO: exception handler
+        ## set_progress(r["did"], -1, "ERROR: ")
+        if cks:
+            embedding(cks)
+            init_kb(r)
+            ES.bulk(cks, index_name(r["uid"]))
+        tmf.write(str(r["updated_at"]) + "\n")
+    tmf.close()
+
+
+if __name__ == "__main__":
+    from mpi4py import MPI
+    comm = MPI.COMM_WORLD
+    rank = comm.Get_rank()
+    main(comm, rank)
diff --git a/python/util/__init__.py b/python/util/__init__.py
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..c7080b766dc8d313c457a1354c7d1770db0fc760 100644
--- a/python/util/__init__.py
+++ b/python/util/__init__.py
@@ -0,0 +1,19 @@
+import re
+
+def rmSpace(txt):
+    txt = re.sub(r"([^a-z0-9.,]) +([^ ])", r"\1\2", txt)
+    return re.sub(r"([^ ]) +([^a-z0-9.,])", r"\1\2", txt)
+
+def findMaxDt(fnm):
+    m = "1970-01-01 00:00:00"
+    try:
+        with open(fnm, "r") as f:
+            while True:
+                l = f.readline()
+                if not l: break
+                l = l.strip("\n")
+                if l == 'nan': continue
+                if l > m: m = l
+    except Exception as e:
+        print("WARNING: can't find " + fnm)
+    return m
diff --git a/python/util/config.py b/python/util/config.py
index 868855d120537a8f912e8f5480a7e7a78213ace5..78429e570ab76dfd94434e44333a33e5c84ce763 100755
--- a/python/util/config.py
+++ b/python/util/config.py
@@ -9,7 +9,6 @@ if not os.path.exists(__fnm):
     __fnm = "./sys.cnf"
 CF.read(__fnm)
 
-
 class Config:
     def __init__(self, env):
         self.env = env
diff --git a/python/util/db_conn.py b/python/util/db_conn.py
index b67e13e926c390fdf11d5d0f9134f2cddf34ecad..ca9e4baedf54f8e10bebf4eb02da5ce94171994e 100644
--- a/python/util/db_conn.py
+++ b/python/util/db_conn.py
@@ -3,7 +3,7 @@ import time
 from util import config
 import pandas as pd
 
-class Postgre(object):
+class Postgres(object):
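+    # Shared pattern for select() and the new update(): attempt the statement,
+    # and on failure log it, reopen the connection, sleep 1s, and retry
+    # (update() gives up after 10 attempts and reports 0 affected rows).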
     def __init__(self, env, dbnm):
         self.config = config.init(env)
         self.conn = None
@@ -36,9 +36,28 @@ class Postgres(object):
             try:
                 return pd.read_sql(sql, self.conn)
             except Exception as e:
-                logging.error(f"Fail to exec {sql}l "+str(e))
+                logging.error(f"Fail to exec {sql} "+str(e))
                 self.__open__()
                 time.sleep(1)
         return pd.DataFrame()
+
+    def update(self, sql):
+        for _ in range(10):
+            try:
+                cur = self.conn.cursor()
+                cur.execute(sql)
+                updated_rows = cur.rowcount
+                self.conn.commit()
+                cur.close()
+                return updated_rows
+            except Exception as e:
+                logging.error(f"Fail to exec {sql} "+str(e))
+                self.__open__()
+                time.sleep(1)
+        return 0
+
+
+if __name__ == "__main__":
+    Postgres("infiniflow", "docgpt")
+
diff --git a/python/util/es_conn.py b/python/util/es_conn.py
index 3e41ab5022950e3cfe01940e335c5ed5df458ffc..ea917a7238b9e3af3fb00f46f2d30e34a08e7dd5 100755
--- a/python/util/es_conn.py
+++ b/python/util/es_conn.py
@@ -31,7 +31,7 @@ class HuEs:
         self.info = {}
         self.config = config.init(env)
         self.conn()
-        self.idxnm = self.config.get("idx_nm")
+        self.idxnm = self.config.get("idx_nm", "")
 
         if not self.es.ping():
             raise Exception("Can't connect to ES cluster")
diff --git a/src/api/doc_info.rs b/src/api/doc_info.rs
index d9589c02e0f106b085294d6f77c025e2f7e4507b..df7a2d923d15c49fad051b55856b2587994569eb 100644
--- a/src/api/doc_info.rs
+++ b/src/api/doc_info.rs
@@ -1,15 +1,20 @@
 use std::collections::HashMap;
-use actix_multipart::Multipart;
+use std::io::Write;
+use actix_multipart_extract::{File, Multipart, MultipartForm};
 use actix_web::{get, HttpResponse, post, web};
 use chrono::Local;
 use futures_util::StreamExt;
-use serde::Deserialize;
-use std::io::Write;
+use sea_orm::DbConn;
 use crate::api::JsonResponse;
 use crate::AppState;
 use crate::entity::doc_info::Model;
 use crate::errors::AppError;
 use crate::service::doc_info::{Mutation, Query};
+use serde::Deserialize;
 
 #[derive(Debug, Deserialize)]
 pub struct Params {
@@ -53,41 +58,54 @@ async fn list(params: web::Json<Params>, data: web::Data<AppState>) -> Result<Ht
         .body(serde_json::to_string(&json_response)?))
 }
 
-#[post("/v1.0/upload")]
-async fn upload(mut payload: Multipart, filename: web::Data<String>, did: web::Data<i64>, uid: web::Data<i64>, data: web::Data<AppState>) -> Result<HttpResponse, AppError> {
-    let mut size = 0;
-
-    while let Some(item) = payload.next().await {
-        let mut field = item.unwrap();
-
-        let filepath = format!("./uploads/{}", filename.as_str());
-
-        let mut file = web::block(|| std::fs::File::create(filepath))
-            .await
-            .unwrap()?;
+#[derive(Deserialize, MultipartForm, Debug)]
+pub struct UploadForm {
+    #[multipart(max_size = 512MB)]
+    file_field: File,
+    uid: i64,
+    did: i64
+}
 
-        while let Some(chunk) = field.next().await {
-            let data = chunk.unwrap();
-            size += data.len() as u64;
-            file = web::block(move || file.write_all(&data).map(|_| file))
-                .await
-                .unwrap()?;
+#[post("/v1.0/upload")]
+async fn upload(payload: Multipart<UploadForm>, data: web::Data<AppState>) -> Result<HttpResponse, AppError> {
+    let uid = payload.uid;
+    async fn add_number_to_filename(file_name: String, conn: &DbConn, uid: i64) -> String {
+        let mut i = 0;
+        let mut new_file_name = file_name.to_string();
+        let arr: Vec<&str> = file_name.split(".").collect();
+        let suffix = String::from(arr[arr.len() - 1]);
+        let prefix = arr[..arr.len() - 1].join(".");
+        let mut docs = Query::find_doc_infos_by_name(conn, uid, new_file_name.clone()).await.unwrap();
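+        // Collision handling: probe doc_name with an increasing numeric suffix
+        // (file.pdf, file_1.pdf, ...) until no document with that name exists
+        // for this uid; linear probing is fine at per-user upload volumes.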
+        while docs.len() > 0 {
+            i += 1;
+            new_file_name = format!("{}_{}.{}", prefix, i, suffix);
+            docs = Query::find_doc_infos_by_name(conn, uid, new_file_name.clone()).await.unwrap();
         }
+        new_file_name
+    }
-
-    let _ = Mutation::create_doc_info(&data.conn, Model {
-        did: *did.into_inner(),
-        uid: *uid.into_inner(),
-        doc_name: filename.to_string(),
-        size,
+    let fnm = add_number_to_filename(payload.file_field.name.clone(), &data.conn, uid).await;
+
+    std::fs::create_dir_all(format!("./upload/{}/", uid))?;
+    let filepath = format!("./upload/{}/{}-{}", payload.uid, payload.did, fnm.clone());
+    let mut f = std::fs::File::create(&filepath)?;
+    f.write_all(&payload.file_field.bytes)?;
+
+    let doc = Mutation::create_doc_info(&data.conn, Model {
+        did: Default::default(),
+        uid,
+        doc_name: fnm,
+        size: payload.file_field.bytes.len() as i64,
         kb_infos: Vec::new(),
         kb_progress: 0.0,
-        location: "".to_string(),
-        r#type: "".to_string(),
+        kb_progress_msg: "".to_string(),
+        location: filepath,
+        r#type: "doc".to_string(),
         created_at: Local::now().date_naive(),
         updated_at: Local::now().date_naive(),
     }).await?;
 
+    let _ = Mutation::place_doc(&data.conn, payload.did, doc.did.unwrap()).await?;
+
     Ok(HttpResponse::Ok().body("File uploaded successfully"))
 }
 
@@ -121,4 +139,4 @@ async fn mv(params: web::Json<MvParams>, data: web::Data<AppState>) -> Result<Ht
     Ok(HttpResponse::Ok()
         .content_type("application/json")
         .body(serde_json::to_string(&json_response)?))
-}
\ No newline at end of file
+}
diff --git a/src/api/kb_info.rs b/src/api/kb_info.rs
index d0545e3cbf28eb95710a852a89c19e91e8ee457e..80f75e01f55150bc1b6103691891b6ba306d3c1d 100644
--- a/src/api/kb_info.rs
+++ b/src/api/kb_info.rs
@@ -1,23 +1,58 @@
 use std::collections::HashMap;
 use actix_web::{get, HttpResponse, post, web};
+use serde::Serialize;
 use crate::api::JsonResponse;
 use crate::AppState;
 use crate::entity::kb_info;
 use crate::errors::AppError;
 use crate::service::kb_info::Mutation;
 use crate::service::kb_info::Query;
+use serde::Deserialize;
 
+#[derive(Clone, Debug, Serialize, Deserialize)]
+pub struct AddDocs2KbParams {
+    pub uid: i64,
+    pub dids: Vec<i64>,
+    pub kb_id: i64,
+}
 
 #[post("/v1.0/create_kb")]
 async fn create(model: web::Json<kb_info::Model>, data: web::Data<AppState>) -> Result<HttpResponse, AppError> {
-    let model = Mutation::create_kb_info(&data.conn, model.into_inner()).await?;
+    let docs = Query::find_kb_infos_by_name(&data.conn, model.kb_name.to_owned()).await?;
+    if docs.len() > 0 {
+        let json_response = JsonResponse {
+            code: 201,
+            err: "Duplicated name.".to_owned(),
+            data: ()
+        };
+        Ok(HttpResponse::Ok()
+            .content_type("application/json")
+            .body(serde_json::to_string(&json_response)?))
+    } else {
+        let model = Mutation::create_kb_info(&data.conn, model.into_inner()).await?;
 
-    let mut result = HashMap::new();
-    result.insert("kb_id", model.kb_id.unwrap());
+        let mut result = HashMap::new();
+        result.insert("kb_id", model.kb_id.unwrap());
+
+        let json_response = JsonResponse {
+            code: 200,
+            err: "".to_owned(),
+            data: result,
+        };
+
+        Ok(HttpResponse::Ok()
+            .content_type("application/json")
+            .body(serde_json::to_string(&json_response)?))
+    }
+}
+
+#[post("/v1.0/add_docs_to_kb")]
+async fn add_docs_to_kb(param: web::Json<AddDocs2KbParams>, data: web::Data<AppState>) -> Result<HttpResponse, AppError> {
+    let _ = Mutation::add_docs(&data.conn, param.kb_id, param.dids.to_owned()).await?;
 
     let json_response = JsonResponse {
         code: 200,
         err: "".to_owned(),
-        data: result,
+        data: (),
     };
 
     Ok(HttpResponse::Ok()
diff --git a/src/api/tag.rs b/src/api/tag.rs
new file mode 100644
index 0000000000000000000000000000000000000000..b902f3d3a3c19c820202a7bdec2797abcb0a750b
--- /dev/null
+++ b/src/api/tag.rs
@@ -0,0 +1,58 @@
+use std::collections::HashMap;
+use actix_web::{get, HttpResponse, post, web};
+use actix_web::http::Error;
+use crate::api::JsonResponse;
+use crate::AppState;
+use crate::entity::tag_info;
+use crate::service::tag_info::{Mutation, Query};
+
+#[post("/v1.0/create_tag")]
+async fn create(model: web::Json<tag_info::Model>, data: web::Data<AppState>) -> Result<HttpResponse, Error> {
+    let model = Mutation::create_tag(&data.conn, model.into_inner()).await.unwrap();
+
+    let mut result = HashMap::new();
+    result.insert("tid", model.tid.unwrap());
+
+    let json_response = JsonResponse {
+        code: 200,
+        err: "".to_owned(),
+        data: result,
+    };
+
+    Ok(HttpResponse::Ok()
+        .content_type("application/json")
+        .body(serde_json::to_string(&json_response).unwrap()))
+}
+
+#[post("/v1.0/delete_tag")]
+async fn delete(model: web::Json<tag_info::Model>, data: web::Data<AppState>) -> Result<HttpResponse, Error> {
+    let _ = Mutation::delete_tag(&data.conn, model.tid).await.unwrap();
+
+    let json_response = JsonResponse {
+        code: 200,
+        err: "".to_owned(),
+        data: (),
+    };
+
+    Ok(HttpResponse::Ok()
+        .content_type("application/json")
+        .body(serde_json::to_string(&json_response).unwrap()))
+}
+
+#[get("/v1.0/tags")]
+async fn list(data: web::Data<AppState>) -> Result<HttpResponse, Error> {
+    let tags = Query::find_tag_infos(&data.conn).await.unwrap();
+
+    let mut result = HashMap::new();
+    result.insert("tags", tags);
+
+    let json_response = JsonResponse {
+        code: 200,
+        err: "".to_owned(),
+        data: result,
+    };
+
+    Ok(HttpResponse::Ok()
+        .content_type("application/json")
+        .body(serde_json::to_string(&json_response).unwrap()))
+}
\ No newline at end of file
diff --git a/src/entity/dialog2_kb.rs b/src/entity/dialog2_kb.rs
index 18420f9eb07a8b84d677edf2208a897c4b5ef2a3..5a5f0f2cd4d8d3779cb32e344009e1c65b629de1 100644
--- a/src/entity/dialog2_kb.rs
+++ b/src/entity/dialog2_kb.rs
@@ -4,10 +4,11 @@ use serde::{Deserialize, Serialize};
 #[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel, Deserialize, Serialize)]
 #[sea_orm(table_name = "dialog2_kb")]
 pub struct Model {
-    #[sea_orm(primary_key, auto_increment = false)]
+    #[sea_orm(primary_key, auto_increment = true)]
+    pub id: i64,
     #[sea_orm(index)]
     pub dialog_id: i64,
-    #[sea_orm(primary_key, auto_increment = false)]
+    #[sea_orm(index)]
     pub kb_id: i64,
 }
diff --git a/src/entity/doc2_doc.rs b/src/entity/doc2_doc.rs
index f987a2d557082eafd8b983a87d38057c4fe6a972..dff8daf95aa515d2a089ff5163d98bd19e9471be 100644
--- a/src/entity/doc2_doc.rs
+++ b/src/entity/doc2_doc.rs
@@ -4,10 +4,11 @@ use serde::{Deserialize, Serialize};
 #[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel, Deserialize, Serialize)]
 #[sea_orm(table_name = "doc2_doc")]
 pub struct Model {
-    #[sea_orm(primary_key, auto_increment = false)]
+    #[sea_orm(primary_key, auto_increment = true)]
+    pub id: i64,
     #[sea_orm(index)]
     pub parent_id: i64,
-    #[sea_orm(primary_key, auto_increment = false)]
+    #[sea_orm(index)]
     pub did: i64,
 }
diff --git a/src/entity/doc_info.rs b/src/entity/doc_info.rs
index 61e0a371cfcff33e85f0cee77640206e3a484ecd..d2ccd411041b48c8c663f873b48b44e87b469d68 100644
--- a/src/entity/doc_info.rs
+++ b/src/entity/doc_info.rs
@@ -10,10 +10,11 @@ pub struct Model {
     #[sea_orm(index)]
     pub uid: i64,
     pub doc_name: String,
-    pub size: u64,
+    pub size: i64,
     #[sea_orm(column_name = "type")]
     pub r#type: String,
-    pub kb_progress: f64,
+    pub kb_progress: f32,
+    pub kb_progress_msg: String,
     pub location: String,
     #[sea_orm(ignore)]
     pub kb_infos: Vec<kb_info::Model>,
@@ -57,4 +58,4 @@ impl Related<Entity> for Entity {
     }
 }
 
-impl ActiveModelBehavior for ActiveModel {}
\ No newline at end of file
+impl ActiveModelBehavior for ActiveModel {}
diff --git a/src/entity/kb2_doc.rs b/src/entity/kb2_doc.rs
index f9d923430430f3f5ea35589f151c5ba2a4e0cec4..1d82756f0270a3812949b12a7414bd07d6040b19 100644
--- a/src/entity/kb2_doc.rs
+++ b/src/entity/kb2_doc.rs
@@ -4,11 +4,12 @@ use serde::{Deserialize, Serialize};
 #[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel, Deserialize, Serialize)]
 #[sea_orm(table_name = "kb2_doc")]
 pub struct Model {
-    #[sea_orm(primary_key, auto_increment = false)]
+    #[sea_orm(primary_key, auto_increment = true)]
+    pub id: i64,
     #[sea_orm(index)]
     pub kb_id: i64,
-    #[sea_orm(primary_key, auto_increment = false)]
-    pub uid: i64,
+    #[sea_orm(index)]
+    pub did: i64,
 }
 
 #[derive(Debug, Clone, Copy, EnumIter)]
@@ -21,8 +22,8 @@ impl RelationTrait for Relation {
     fn def(&self) -> RelationDef {
         match self {
             Self::DocInfo => Entity::belongs_to(super::doc_info::Entity)
-                .from(Column::Uid)
-                .to(super::doc_info::Column::Uid)
+                .from(Column::Did)
+                .to(super::doc_info::Column::Did)
                 .into(),
             Self::KbInfo => Entity::belongs_to(super::kb_info::Entity)
                 .from(Column::KbId)
diff --git a/src/entity/kb_info.rs b/src/entity/kb_info.rs
index 46ce903b19eed601789217d89232c9f3f77eb41d..97ea2bb9d097218edded7c0e152b2aa6f25c757a 100644
--- a/src/entity/kb_info.rs
+++ b/src/entity/kb_info.rs
@@ -8,8 +8,8 @@ pub struct Model {
     pub kb_id: i64,
     #[sea_orm(index)]
     pub uid: i64,
-    pub kn_name: String,
-    pub icon: i64,
+    pub kb_name: String,
+    pub icon: i16,
 
     #[serde(skip_deserializing)]
     pub created_at: Date,
diff --git a/src/entity/mod.rs b/src/entity/mod.rs
index 6fbdee3c1354f77873cf3d7ff4c1bbfe5bf71099..d0abbd31d091cf502d48bf7ab0fa77b4ff91f28c 100644
--- a/src/entity/mod.rs
+++ b/src/entity/mod.rs
@@ -1,8 +1,8 @@
 pub(crate) mod user_info;
 pub(crate) mod tag_info;
-mod tag2_doc;
-mod kb2_doc;
-mod dialog2_kb;
+pub(crate) mod tag2_doc;
+pub(crate) mod kb2_doc;
+pub(crate) mod dialog2_kb;
 pub(crate) mod doc2_doc;
 pub(crate) mod kb_info;
 pub(crate) mod doc_info;
diff --git a/src/entity/tag2_doc.rs b/src/entity/tag2_doc.rs
index 61453c4f3c409eee021ad86b9dca9da16ca8dfc0..3825fe5f33d4e0442c686a46e4ce3e78105c15fa 100644
--- a/src/entity/tag2_doc.rs
+++ b/src/entity/tag2_doc.rs
@@ -4,10 +4,11 @@ use serde::{Deserialize, Serialize};
 #[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel, Deserialize, Serialize)]
 #[sea_orm(table_name = "tag2_doc")]
 pub struct Model {
-    #[sea_orm(primary_key, auto_increment = false)]
+    #[sea_orm(primary_key, auto_increment = true)]
+    pub id: i64,
     #[sea_orm(index)]
     pub tag_id: i64,
-    #[sea_orm(primary_key, auto_increment = false)]
+    #[sea_orm(index)]
     pub uid: i64,
 }
diff --git a/src/entity/tag_info.rs b/src/entity/tag_info.rs
index 34d7e5975091cad1d3c9acbde50eaf050e38180f..b6c1a4ad5c9834aa56ccc9184d4ef4a0caffcb12 100644
--- a/src/entity/tag_info.rs
+++ b/src/entity/tag_info.rs
@@ -10,8 +10,8 @@ pub struct Model {
     pub uid: i64,
     pub tag_name: String,
     pub regx: Option<String>,
-    pub color: i64,
-    pub icon: i64,
+    pub color: u16,
+    pub icon: u16,
     pub dir: Option<String>,
 
     #[serde(skip_deserializing)]
diff --git a/src/main.rs b/src/main.rs
index 9a990735897975b345b6f4c4571f8358541d0c91..e1149163e76f4546dc85e608459c9ca8acec6bca 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -97,6 +97,7 @@ fn init(cfg: &mut web::ServiceConfig) {
     cfg.service(api::kb_info::create);
     cfg.service(api::kb_info::delete);
     cfg.service(api::kb_info::list);
+    cfg.service(api::kb_info::add_docs_to_kb);
 
     cfg.service(api::doc_info::list);
     cfg.service(api::doc_info::delete);
diff --git a/src/service/doc_info.rs b/src/service/doc_info.rs
index ed5e99caa2a631acb220436fb15ebe7f15d5e9b7..a97d79581e305f3fd28c0322333db23dce3750db 100644
--- a/src/service/doc_info.rs
+++ b/src/service/doc_info.rs
@@ -1,5 +1,5 @@
 use chrono::Local;
-use sea_orm::{ActiveModelTrait, ColumnTrait, DbConn, DbErr, DeleteResult, EntityTrait, PaginatorTrait, QueryOrder};
+use sea_orm::{ActiveModelTrait, ColumnTrait, DbConn, DbErr, DeleteResult, EntityTrait, PaginatorTrait, QueryOrder, Unset, Unchanged, ConditionalStatement};
 use sea_orm::ActiveValue::Set;
 use sea_orm::QueryFilter;
 use crate::api::doc_info::Params;
@@ -24,6 +24,14 @@ impl Query {
             .await
     }
 
+    pub async fn find_doc_infos_by_name(db: &DbConn, uid: i64, name: String) -> Result<Vec<doc_info::Model>, DbErr> {
+        Entity::find()
+            .filter(doc_info::Column::DocName.eq(name))
+            .filter(doc_info::Column::Uid.eq(uid))
+            .all(db)
+            .await
+    }
+
     pub async fn find_doc_infos_by_params(db: &DbConn, params: Params) -> Result<Vec<doc_info::Model>, DbErr> {
         // Setup paginator
         let paginator = Entity::find();
@@ -80,18 +88,34 @@ impl Mutation {
         dids: &[i64]
     ) -> Result<(), DbErr> {
         for did in dids {
+            let d = doc2_doc::Entity::find().filter(doc2_doc::Column::Did.eq(did.to_owned())).all(db).await?;
+            // Skip documents that have no placement row yet; indexing d[0]
+            // directly would panic on an empty result.
+            if d.is_empty() {
+                continue;
+            }
+
             let _ = doc2_doc::ActiveModel {
-                parent_id: Set(dest_did),
-                did: Set(*did),
+                id: Set(d[0].id),
+                did: Set(did.to_owned()),
+                parent_id: Set(dest_did)
             }
-            .save(db)
-            .await
-            .unwrap();
+            .update(db)
+            .await?;
         }
 
         Ok(())
     }
 
+    pub async fn place_doc(
+        db: &DbConn,
+        dest_did: i64,
+        did: i64
+    ) -> Result<doc2_doc::ActiveModel, DbErr> {
+        doc2_doc::ActiveModel {
+            id: Default::default(),
+            parent_id: Set(dest_did),
+            did: Set(did),
+        }
+        .save(db)
+        .await
+    }
+
     pub async fn create_doc_info(
         db: &DbConn,
         form_data: doc_info::Model,
@@ -103,6 +127,7 @@ impl Mutation {
             size: Set(form_data.size.to_owned()),
             r#type: Set(form_data.r#type.to_owned()),
             kb_progress: Set(form_data.kb_progress.to_owned()),
+            kb_progress_msg: Set(form_data.kb_progress_msg.to_owned()),
             location: Set(form_data.location.to_owned()),
             created_at: Set(Local::now().date_naive()),
             updated_at: Set(Local::now().date_naive()),
@@ -129,6 +154,7 @@ impl Mutation {
             size: Set(form_data.size.to_owned()),
             r#type: Set(form_data.r#type.to_owned()),
             kb_progress: Set(form_data.kb_progress.to_owned()),
+            kb_progress_msg: Set(form_data.kb_progress_msg.to_owned()),
             location: Set(form_data.location.to_owned()),
             created_at: Default::default(),
             updated_at: Set(Local::now().date_naive()),
@@ -150,4 +176,4 @@ impl Mutation {
     pub async fn delete_all_doc_infos(db: &DbConn) -> Result<DeleteResult, DbErr> {
         Entity::delete_many().exec(db).await
     }
-}
\ No newline at end of file
+}
diff --git a/src/service/kb_info.rs b/src/service/kb_info.rs
index 7580596e282132190ddb05d159247bab8d30ad6f..c20edce41e3753b64fa5a966fee7da1e0a50dd06 100644
--- a/src/service/kb_info.rs
+++ b/src/service/kb_info.rs
@@ -2,6 +2,7 @@ use chrono::Local;
 use sea_orm::{ActiveModelTrait, ColumnTrait, DbConn, DbErr, DeleteResult, EntityTrait, PaginatorTrait, QueryFilter, QueryOrder};
 use sea_orm::ActiveValue::Set;
 use crate::entity::kb_info;
+use crate::entity::kb2_doc;
 use crate::entity::kb_info::Entity;
 
 pub struct Query;
@@ -21,6 +22,13 @@ impl Query {
             .all(db)
             .await
     }
+
+    pub async fn find_kb_infos_by_name(db: &DbConn, name: String) -> Result<Vec<kb_info::Model>, DbErr> {
+        Entity::find()
+            .filter(kb_info::Column::KbName.eq(name))
+            .all(db)
+            .await
+    }
 
     pub async fn find_kb_infos_in_page(
         db: &DbConn,
@@ -48,7 +56,7 @@ impl Mutation {
         kb_info::ActiveModel {
             kb_id: Default::default(),
             uid: Set(form_data.uid.to_owned()),
-            kn_name: Set(form_data.kn_name.to_owned()),
+            kb_name: Set(form_data.kb_name.to_owned()),
             icon: Set(form_data.icon.to_owned()),
             created_at: Set(Local::now().date_naive()),
             updated_at: Set(Local::now().date_naive()),
@@ -57,6 +65,24 @@ impl Mutation {
         .await
     }
 
+    pub async fn add_docs(
+        db: &DbConn,
+        kb_id: i64,
+        doc_ids: Vec<i64>
+    ) -> Result<(), DbErr> {
+        for did in doc_ids {
+            let _ = kb2_doc::ActiveModel {
+                id: Default::default(),
+                kb_id: Set(kb_id),
+                did: Set(did),
+            }
+            .save(db)
+            .await?;
+        }
+
+        Ok(())
+    }
+
     pub async fn update_kb_info_by_id(
         db: &DbConn,
         id: i64,
@@ -71,7 +97,7 @@ impl Mutation {
         kb_info::ActiveModel {
             kb_id: kb_info.kb_id,
             uid: kb_info.uid,
-            kn_name: Set(form_data.kn_name.to_owned()),
+            kb_name: Set(form_data.kb_name.to_owned()),
             icon: Set(form_data.icon.to_owned()),
             created_at: Default::default(),
             updated_at: Set(Local::now().date_naive()),