diff --git a/api/apps/conversation_app.py b/api/apps/conversation_app.py
index 648ecd3d91d3ee5849eb564093b11b97572e6baf..9ee201f9043b47fdfe6f3203132cceacffe3e01b 100644
--- a/api/apps/conversation_app.py
+++ b/api/apps/conversation_app.py
@@ -309,13 +309,13 @@ def use_sql(question, field_map, tenant_id, chat_mdl):
     # compose markdown table
     clmns = "|"+"|".join([re.sub(r"(/.*|（[^（）]+）)", "", field_map.get(tbl["columns"][i]["name"], tbl["columns"][i]["name"])) for i in clmn_idx]) + ("|原文|" if docid_idx and docid_idx else "|")
     line = "|"+"|".join(["------" for _ in range(len(clmn_idx))]) + ("|------|" if docid_idx and docid_idx else "")
-    line = re.sub(r"T[0-9]{2}:[0-9]{2}:[0-9]{2}\|", "|", line)
     rows = ["|"+"|".join([rmSpace(str(r[i])) for i in clmn_idx]).replace("None", " ") + "|" for r in tbl["rows"]]
     if not docid_idx or not docnm_idx:
         chat_logger.warning("SQL missing field: " + sql)
         return "\n".join([clmns, line, "\n".join(rows)]), []
 
     rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
+    rows = re.sub(r"T[0-9]{2}:[0-9]{2}:[0-9]{2}(\.[0-9]+Z)?\|", "|", rows)
     docid_idx = list(docid_idx)[0]
     docnm_idx = list(docnm_idx)[0]
     return "\n".join([clmns, line, rows]), [{"doc_id": r[docid_idx], "docnm_kwd": r[docnm_idx]} for r in tbl["rows"]]
diff --git a/api/apps/llm_app.py b/api/apps/llm_app.py
index e89c85592104812d08c273d4d99d279b32c04930..d6a7b8176127e7002517de1ac1b8a1802877b417 100644
--- a/api/apps/llm_app.py
+++ b/api/apps/llm_app.py
@@ -39,36 +39,40 @@ def factories():
 def set_api_key():
     req = request.json
     # test if api key works
+    chat_passed = False
+    factory = req["llm_factory"]
     msg = ""
-    for llm in LLMService.query(fid=req["llm_factory"]):
+    for llm in LLMService.query(fid=factory):
         if llm.model_type == LLMType.EMBEDDING.value:
-            mdl = EmbeddingModel[req["llm_factory"]](
+            mdl = EmbeddingModel[factory](
                 req["api_key"], llm.llm_name)
             try:
                 arr, tc = mdl.encode(["Test if the api key is available"])
                 if len(arr[0]) == 0 or tc ==0:
                     raise Exception("Fail")
             except Exception as e:
                 msg += f"\nFail to access embedding model({llm.llm_name}) using this api key."
-        elif llm.model_type == LLMType.CHAT.value:
-            mdl = ChatModel[req["llm_factory"]](
+        elif not chat_passed and llm.model_type == LLMType.CHAT.value:
+            mdl = ChatModel[factory](
                 req["api_key"], llm.llm_name)
             try:
                 m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9})
                 if not tc:
                     raise Exception(m)
+                chat_passed = True
             except Exception as e:
                 msg += f"\nFail to access model({llm.llm_name}) using this api key." + str(e)
 
     if msg:
         return get_data_error_result(retmsg=msg)
 
     llm = {
-        "tenant_id": current_user.id,
-        "llm_factory": req["llm_factory"],
         "api_key": req["api_key"]
     }
     for n in ["model_type", "llm_name"]:
         if n in req:
             llm[n] = req[n]
-    TenantLLMService.filter_update([TenantLLM.tenant_id==llm["tenant_id"], TenantLLM.llm_factory==llm["llm_factory"]], llm)
+    if not TenantLLMService.filter_update([TenantLLM.tenant_id==current_user.id, TenantLLM.llm_factory==factory], llm):
+        for llm in LLMService.query(fid=factory):
+            TenantLLMService.save(tenant_id=current_user.id, llm_factory=factory, llm_name=llm.llm_name, model_type=llm.model_type, api_key=req["api_key"])
+
     return get_json_result(data=True)
diff --git a/api/db/db_models.py b/api/db/db_models.py
index 020899198683efc8408a97b7a68c69dbe87d7298..c32b4d56b5df9a372edb12e73d064e6e218ef2a0 100644
--- a/api/db/db_models.py
+++ b/api/db/db_models.py
@@ -429,7 +429,7 @@ class LLMFactories(DataBaseModel):
 
 class LLM(DataBaseModel):
     # LLMs dictionary
-    llm_name = CharField(max_length=128, null=False, help_text="LLM name", index=True)
+    llm_name = CharField(max_length=128, null=False, help_text="LLM name", index=True, primary_key=True)
     model_type = CharField(max_length=128, null=False, help_text="LLM, Text Embedding, Image2Text, ASR")
     fid = CharField(max_length=128, null=False, help_text="LLM factory id")
     max_tokens = IntegerField(default=0)
diff --git a/api/db/init_data.py b/api/db/init_data.py
index 189e5437be463660e103e627cf80f145e0757aee..5696cd8d094f55615a2a3f3be16e1c4d08caedbd 100644
--- a/api/db/init_data.py
+++ b/api/db/init_data.py
@@ -73,41 +73,41 @@ def init_superuser():
         print("\33[91m【ERROR】\33[0m:", " '{}' dosen't work!".format(tenant["embd_id"]))
 
 
+factory_infos = [{
+    "name": "OpenAI",
+    "logo": "",
+    "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
+    "status": "1",
+},{
+    "name": "通义千问",
+    "logo": "",
+    "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
+    "status": "1",
+},{
+    "name": "智谱AI",
+    "logo": "",
+    "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
+    "status": "1",
+},
+{
+    "name": "Local",
+    "logo": "",
+    "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
+    "status": "1",
+},{
+    "name": "Moonshot",
+    "logo": "",
+    "tags": "LLM,TEXT EMBEDDING",
+    "status": "1",
+}
+# {
+#     "name": "文心一言",
+#     "logo": "",
+#     "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
+#     "status": "1",
+# },
+]
 def init_llm_factory():
-    factory_infos = [{
-        "name": "OpenAI",
-        "logo": "",
-        "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
-        "status": "1",
-    },{
-        "name": "通义千问",
-        "logo": "",
-        "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
-        "status": "1",
-    },{
-        "name": "智谱AI",
-        "logo": "",
-        "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
-        "status": "1",
-    },
-    {
-        "name": "Local",
-        "logo": "",
-        "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
-        "status": "1",
-    },{
-        "name": "Moonshot",
-        "logo": "",
-        "tags": "LLM,TEXT EMBEDDING",
-        "status": "1",
-    }
-    # {
-    #     "name": "文心一言",
-    #     "logo": "",
-    #     "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
-    #     "status": "1",
-    # },
-    ]
     llm_infos = [
         # ---------------------- OpenAI ------------------------
         {
@@ -260,21 +260,30 @@ def init_llm_factory():
         },
     ]
     for info in factory_infos:
-        LLMFactoriesService.save(**info)
+        try:
+            LLMFactoriesService.save(**info)
+        except Exception as e:
+            pass
     for info in llm_infos:
-        LLMService.save(**info)
+        try:
+            LLMService.save(**info)
+        except Exception as e:
+            pass
 
 
 def init_web_data():
     start_time = time.time()
 
-    if not LLMService.get_all().count():init_llm_factory()
+    if LLMFactoriesService.get_all().count() != len(factory_infos):
+        init_llm_factory()
 
     if not UserService.get_all().count():
         init_superuser()
 
     print("init web data success:{}".format(time.time() - start_time))
 
+
 if __name__ == '__main__':
     init_web_db()
-    init_web_data()
\ No newline at end of file
+    init_web_data()
+    add_tenant_llm()
\ No newline at end of file
diff --git a/api/db/services/llm_service.py b/api/db/services/llm_service.py
index 67dd70cae938c37fb68caf26e6aec94fb73aca44..5bb54b138e52aba0341ad758049be3fbc8671869 100644
--- a/api/db/services/llm_service.py
+++ b/api/db/services/llm_service.py
@@ -53,7 +53,7 @@ class TenantLLMService(CommonService):
             cls.model.used_tokens
         ]
         objs = cls.model.select(*fields).join(LLMFactories, on=(cls.model.llm_factory == LLMFactories.name)).where(
-            cls.model.tenant_id == tenant_id).dicts()
+            cls.model.tenant_id == tenant_id, ~cls.model.api_key.is_null()).dicts()
         return list(objs)
 
 
diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py
index 57d6480c51cd2bb4ed157a8ca1358609be92acb8..2d9d9f5df4667fc92cb989b2d3fbc8da815920f0 100644
--- a/rag/llm/chat_model.py
+++ b/rag/llm/chat_model.py
@@ -54,6 +54,21 @@ class MoonshotChat(GptTurbo):
         self.client = OpenAI(api_key=key, base_url="https://api.moonshot.cn/v1",)
         self.model_name = model_name
 
+    def chat(self, system, history, gen_conf):
+        if system: history.insert(0, {"role": "system", "content": system})
+        try:
+            response = self.client.chat.completions.create(
+                model=self.model_name,
+                messages=history,
+                **gen_conf)
+            ans = response.choices[0].message.content.strip()
+            if response.choices[0].finish_reason == "length":
+                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
+                    [ans]) else "······\n由于长度的原因，回答被截断了，要继续吗？"
+            return ans, response.usage.completion_tokens
+        except openai.APIError as e:
+            return "**ERROR**: "+str(e), 0
+
 
 from dashscope import Generation
 class QWenChat(Base):
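
A minimal usage sketch (not part of the patch) of the new MoonshotChat.chat override added in rag/llm/chat_model.py; the API key and the "moonshot-v1-8k" model name below are placeholders, and the signature follows the method as defined above.

    # Sketch: exercising the Moonshot-specific chat() added by this patch.
    # The key and model name are placeholders, not values from the diff.
    from rag.llm.chat_model import MoonshotChat

    mdl = MoonshotChat("sk-...", "moonshot-v1-8k")    # placeholder key/model
    ans, used_tokens = mdl.chat(
        "You are a helpful assistant.",               # prepended as the system message
        [{"role": "user", "content": "Hello!"}],      # OpenAI-style message history
        {"temperature": 0.7})                         # forwarded via **gen_conf
    # On success: (answer text, completion token count); on APIError: ("**ERROR**: ...", 0)
    print(ans, used_tokens)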
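A short sketch of why the try/except wrappers in init_llm_factory make re-seeding idempotent, assuming CommonService.save issues a plain INSERT: with llm_name now the primary key on LLM, inserting an already-seeded model raises an integrity error instead of silently duplicating the row. The field values here are illustrative, not taken from the diff.

    # Sketch: a duplicate seed now fails fast on the primary key and is absorbed,
    # so init_web_data() can safely re-run init_llm_factory() whenever the
    # factory count drifts from len(factory_infos).
    try:
        LLMService.save(fid="Moonshot", llm_name="moonshot-v1-8k",
                        model_type="chat", max_tokens=7168, tags="LLM,CHAT")
    except Exception:
        pass  # row already seeded by a previous init run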