From 1567e881ded135d3be6e1271907128ad35df6f4a Mon Sep 17 00:00:00 2001 From: KevinHuSh <kevinhu.sh@gmail.com> Date: Tue, 27 Feb 2024 17:51:54 +0800 Subject: [PATCH] fix bug of inserting cites (#76) --- api/apps/user_app.py | 6 +++--- api/db/init_data.py | 8 ++++---- api/db/services/common_service.py | 25 ++++++++++++++++++++----- conf/service_conf.yaml | 6 +++--- rag/nlp/search.py | 2 +- 5 files changed, 31 insertions(+), 16 deletions(-) diff --git a/api/apps/user_app.py b/api/apps/user_app.py index 0f08e0e..8b5ba4a 100644 --- a/api/apps/user_app.py +++ b/api/apps/user_app.py @@ -208,9 +208,9 @@ def user_register(user_id, user): for llm in LLMService.query(fid=LLM_FACTORY): tenant_llm.append({"tenant_id": user_id, "llm_factory": LLM_FACTORY, "llm_name": llm.llm_name, "model_type":llm.model_type, "api_key": API_KEY}) - if not UserService.save(**user):return - TenantService.save(**tenant) - UserTenantService.save(**usr_tenant) + if not UserService.insert(**user):return + TenantService.insert(**tenant) + UserTenantService.insert(**usr_tenant) TenantLLMService.insert_many(tenant_llm) return UserService.query(email=user["email"]) diff --git a/api/db/init_data.py b/api/db/init_data.py index ee91fd8..531abf4 100644 --- a/api/db/init_data.py +++ b/api/db/init_data.py @@ -58,16 +58,16 @@ def init_superuser(): if not UserService.save(**user_info): print("ă€ERROR】can't init admin.") return - TenantService.save(**tenant) - UserTenantService.save(**usr_tenant) + TenantService.insert(**tenant) + UserTenantService.insert(**usr_tenant) TenantLLMService.insert_many(tenant_llm) - UserService.save(**user_info) + print("ă€INFO】Super user initialized. user name: admin, password: admin. Changing the password after logining is strongly recomanded.") chat_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["llm_id"]) msg = chat_mdl.chat(system="", history=[{"role": "user", "content": "Hello!"}], gen_conf={}) if msg.find("ERROR: ") == 0: print("ă€ERROR】: '{}' dosen't work. {}".format(tenant["llm_id"]), msg) - embd_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["embd_id"]) + embd_mdl = LLMBundle(tenant["id"], LLMType.EMBEDDING, tenant["embd_id"]) v,c = embd_mdl.encode(["Hello!"]) if c == 0: print("ă€ERROR】: '{}' dosen't work...".format(tenant["embd_id"])) diff --git a/api/db/services/common_service.py b/api/db/services/common_service.py index 6ae1c35..fbbb645 100644 --- a/api/db/services/common_service.py +++ b/api/db/services/common_service.py @@ -18,7 +18,7 @@ from datetime import datetime import peewee from api.db.db_models import DB -from api.utils import datetime_format +from api.utils import datetime_format, current_timestamp, get_uuid class CommonService: @@ -66,27 +66,42 @@ class CommonService: sample_obj = cls.model(**kwargs).save(force_insert=True) return sample_obj + @classmethod + @DB.connection_context() + def insert(cls, **kwargs): + if "id" not in kwargs: + kwargs["id"] = get_uuid() + kwargs["create_time"] = current_timestamp() + kwargs["create_date"] = datetime_format(datetime.now()) + kwargs["update_time"] = current_timestamp() + kwargs["update_date"] = datetime_format(datetime.now()) + sample_obj = cls.model(**kwargs).save(force_insert=True) + return sample_obj + @classmethod @DB.connection_context() def insert_many(cls, data_list, batch_size=100): with DB.atomic(): - for d in data_list: d["create_time"] = datetime_format(datetime.now()) + for d in data_list: + d["create_time"] = current_timestamp() + d["create_date"] = datetime_format(datetime.now()) for i in range(0, len(data_list), batch_size): cls.model.insert_many(data_list[i:i + batch_size]).execute() @classmethod @DB.connection_context() def update_many_by_id(cls, data_list): - cur = datetime_format(datetime.now()) with DB.atomic(): for data in data_list: - data["update_time"] = cur + data["update_time"] = current_timestamp() + data["update_date"] = datetime_format(datetime.now()) cls.model.update(data).where(cls.model.id == data["id"]).execute() @classmethod @DB.connection_context() def update_by_id(cls, pid, data): - data["update_time"] = datetime_format(datetime.now()) + data["update_time"] = current_timestamp() + data["update_date"] = datetime_format(datetime.now()) num = cls.model.update(data).where(cls.model.id == pid).execute() return num diff --git a/conf/service_conf.yaml b/conf/service_conf.yaml index 1aa8cb8..34b357c 100644 --- a/conf/service_conf.yaml +++ b/conf/service_conf.yaml @@ -17,16 +17,16 @@ database: name: 'rag_flow' user: 'root' passwd: 'infini_rag_flow' - host: '123.60.95.134' + host: '127.0.0.1' port: 5455 max_connections: 100 stale_timeout: 30 minio: user: 'rag_flow' passwd: 'infini_rag_flow' - host: '123.60.95.134:9000' + host: '127.0.0.1:9000' es: - hosts: 'http://123.60.95.134:9200' + hosts: 'http://127.0.0.1:9200' user_default_llm: factory: '通义ĺŤé—®' chat_model: 'qwen-plus' diff --git a/rag/nlp/search.py b/rag/nlp/search.py index 5f9fb70..e031888 100644 --- a/rag/nlp/search.py +++ b/rag/nlp/search.py @@ -226,7 +226,7 @@ class Dealer: continue if i not in cites: continue - assert int(cites[i]) < len(chunk_v) + for c in cites[i]: assert int(c) < len(chunk_v) res += "##%s$$" % "$".join(cites[i]) return res -- GitLab