From bf2e3d7fc1e84425c7693a9cfa06cf46902366d4 Mon Sep 17 00:00:00 2001
From: KevinHuSh <kevinhu.sh@gmail.com>
Date: Wed, 27 Mar 2024 17:55:45 +0800
Subject: [PATCH] refine OpenAi Api (#159)

---
 README.md                  |  2 +-
 api/apps/llm_app.py        |  2 +-
 rag/app/picture.py         |  2 +-
 rag/llm/chat_model.py      | 10 +++++-----
 rag/llm/embedding_model.py | 12 ++++++++----
 5 files changed, 16 insertions(+), 12 deletions(-)

diff --git a/README.md b/README.md
index c3123cc..04f512c 100644
--- a/README.md
+++ b/README.md
@@ -127,7 +127,7 @@ Open your browser, enter the IP address of your server, _**Hallelujah**_ again!
 # System Architecture Diagram
 
 <div align="center" style="margin-top:20px;margin-bottom:20px;">
-<img src="https://github.com/infiniflow/ragflow/assets/12318111/39c8e546-51ca-4b50-a1da-83731b540cd0" width="1000"/>
+<img src="https://github.com/infiniflow/ragflow/assets/12318111/d6ac5664-c237-4200-a7c2-a4a00691b485" width="1000"/>
 </div>
 
 # Configuration
diff --git a/api/apps/llm_app.py b/api/apps/llm_app.py
index 0a98ab0..e8b3dcd 100644
--- a/api/apps/llm_app.py
+++ b/api/apps/llm_app.py
@@ -51,7 +51,7 @@ def set_api_key():
                 if len(arr[0]) == 0 or tc == 0:
                     raise Exception("Fail")
             except Exception as e:
-                msg += f"\nFail to access embedding model({llm.llm_name}) using this api key."
+                msg += f"\nFail to access embedding model({llm.llm_name}) using this api key." + str(e)
         elif not chat_passed and llm.model_type == LLMType.CHAT.value:
             mdl = ChatModel[factory](
                 req["api_key"], llm.llm_name)
diff --git a/rag/app/picture.py b/rag/app/picture.py
index fdaccc2..fbbc9f3 100644
--- a/rag/app/picture.py
+++ b/rag/app/picture.py
@@ -29,7 +29,7 @@ def chunk(filename, binary, tenant_id, lang, callback=None, **kwargs):
     except Exception as e:
         callback(prog=-1, msg=str(e))
         return []
-    img = Image.open(io.BytesIO(binary))
+    img = Image.open(io.BytesIO(binary)).convert('RGB')
     doc = {
         "docnm_kwd": filename,
         "image": img
diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py
index d5ddbc0..e44af53 100644
--- a/rag/llm/chat_model.py
+++ b/rag/llm/chat_model.py
@@ -43,8 +43,8 @@ class GptTurbo(Base):
                 model=self.model_name,
                 messages=history,
                 **gen_conf)
-            ans = response.output.choices[0]['message']['content'].strip()
-            if response.output.choices[0].get("finish_reason", "") == "length":
+            ans = response.choices[0].message.content.strip()
+            if response.choices[0].finish_reason == "length":
                 ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                     [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
             return ans, response.usage.completion_tokens
@@ -114,12 +114,12 @@ class ZhipuChat(Base):
             history.insert(0, {"role": "system", "content": system})
         try:
             response = self.client.chat.completions.create(
-                self.model_name,
+                model=self.model_name,
                 messages=history,
                 **gen_conf
             )
-            ans = response.output.choices[0]['message']['content'].strip()
-            if response.output.choices[0].get("finish_reason", "") == "length":
+            ans = response.choices[0].message.content.strip()
+            if response.choices[0].finish_reason == "length":
                 ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                     [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
             return ans, response.usage.completion_tokens
diff --git a/rag/llm/embedding_model.py b/rag/llm/embedding_model.py
index c2fe24b..68a6e0a 100644
--- a/rag/llm/embedding_model.py
+++ b/rag/llm/embedding_model.py
@@ -139,12 +139,16 @@ class ZhipuEmbed(Base):
         self.model_name = model_name
 
     def encode(self, texts: list, batch_size=32):
-        res = self.client.embeddings.create(input=texts,
+        arr = []
+        tks_num = 0
+        for txt in texts:
+            res = self.client.embeddings.create(input=txt,
                                             model=self.model_name)
-        return np.array([d.embedding for d in res.data]
-                        ), res.usage.total_tokens
+            arr.append(res.data[0].embedding)
+            tks_num += res.usage.total_tokens
+        return np.array(arr), tks_num
 
     def encode_queries(self, text):
         res = self.client.embeddings.create(input=text,
                                             model=self.model_name)
-        return np.array(res["data"][0]["embedding"]), res.usage.total_tokens
+        return np.array(res.data[0].embedding), res.usage.total_tokens
-- 
GitLab