From bf2e3d7fc1e84425c7693a9cfa06cf46902366d4 Mon Sep 17 00:00:00 2001 From: KevinHuSh <kevinhu.sh@gmail.com> Date: Wed, 27 Mar 2024 17:55:45 +0800 Subject: [PATCH] refine OpenAi Api (#159) --- README.md | 2 +- api/apps/llm_app.py | 2 +- rag/app/picture.py | 2 +- rag/llm/chat_model.py | 10 +++++----- rag/llm/embedding_model.py | 12 ++++++++---- 5 files changed, 16 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index c3123cc..04f512c 100644 --- a/README.md +++ b/README.md @@ -127,7 +127,7 @@ Open your browser, enter the IP address of your server, _**Hallelujah**_ again! # System Architecture Diagram <div align="center" style="margin-top:20px;margin-bottom:20px;"> -<img src="https://github.com/infiniflow/ragflow/assets/12318111/39c8e546-51ca-4b50-a1da-83731b540cd0" width="1000"/> +<img src="https://github.com/infiniflow/ragflow/assets/12318111/d6ac5664-c237-4200-a7c2-a4a00691b485" width="1000"/> </div> # Configuration diff --git a/api/apps/llm_app.py b/api/apps/llm_app.py index 0a98ab0..e8b3dcd 100644 --- a/api/apps/llm_app.py +++ b/api/apps/llm_app.py @@ -51,7 +51,7 @@ def set_api_key(): if len(arr[0]) == 0 or tc == 0: raise Exception("Fail") except Exception as e: - msg += f"\nFail to access embedding model({llm.llm_name}) using this api key." + msg += f"\nFail to access embedding model({llm.llm_name}) using this api key." 
+ str(e) elif not chat_passed and llm.model_type == LLMType.CHAT.value: mdl = ChatModel[factory]( req["api_key"], llm.llm_name) diff --git a/rag/app/picture.py b/rag/app/picture.py index fdaccc2..fbbc9f3 100644 --- a/rag/app/picture.py +++ b/rag/app/picture.py @@ -29,7 +29,7 @@ def chunk(filename, binary, tenant_id, lang, callback=None, **kwargs): except Exception as e: callback(prog=-1, msg=str(e)) return [] - img = Image.open(io.BytesIO(binary)) + img = Image.open(io.BytesIO(binary)).convert('RGB') doc = { "docnm_kwd": filename, "image": img diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py index d5ddbc0..e44af53 100644 --- a/rag/llm/chat_model.py +++ b/rag/llm/chat_model.py @@ -43,8 +43,8 @@ class GptTurbo(Base): model=self.model_name, messages=history, **gen_conf) - ans = response.output.choices[0]['message']['content'].strip() - if response.output.choices[0].get("finish_reason", "") == "length": + ans = response.choices[0].message.content.strip() + if response.choices[0].finish_reason == "length": ans += "...\nFor the content length reason, it stopped, continue?" if is_english( [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?" return ans, response.usage.completion_tokens @@ -114,12 +114,12 @@ class ZhipuChat(Base): history.insert(0, {"role": "system", "content": system}) try: response = self.client.chat.completions.create( - self.model_name, + model=self.model_name, messages=history, **gen_conf ) - ans = response.output.choices[0]['message']['content'].strip() - if response.output.choices[0].get("finish_reason", "") == "length": + ans = response.choices[0].message.content.strip() + if response.choices[0].finish_reason == "length": ans += "...\nFor the content length reason, it stopped, continue?" if is_english( [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
return ans, response.usage.completion_tokens diff --git a/rag/llm/embedding_model.py b/rag/llm/embedding_model.py index c2fe24b..68a6e0a 100644 --- a/rag/llm/embedding_model.py +++ b/rag/llm/embedding_model.py @@ -139,12 +139,16 @@ class ZhipuEmbed(Base): self.model_name = model_name def encode(self, texts: list, batch_size=32): - res = self.client.embeddings.create(input=texts, + arr = [] + tks_num = 0 + for txt in texts: + res = self.client.embeddings.create(input=txt, model=self.model_name) - return np.array([d.embedding for d in res.data] - ), res.usage.total_tokens + arr.append(res.data[0].embedding) + tks_num += res.usage.total_tokens + return np.array(arr), tks_num def encode_queries(self, text): res = self.client.embeddings.create(input=text, model=self.model_name) - return np.array(res["data"][0]["embedding"]), res.usage.total_tokens + return np.array(res.data[0].embedding), res.usage.total_tokens -- GitLab