Ver Fonte

fix: gemini no content bug

zhayujie há 2 anos atrás
pai
commit
eca1892e2a
1 ficheiros alterados com 32 adições e 15 exclusões
  1. 32 15
      bot/gemini/google_gemini_bot.py

+ 32 - 15
bot/gemini/google_gemini_bot.py

@@ -26,21 +26,24 @@ class GoogleGeminiBot(Bot):
         self.sessions = SessionManager(BaiduWenxinSession, model=conf().get("model") or "gpt-3.5-turbo")
 
     def reply(self, query, context: Context = None) -> Reply:
-        if context.type != ContextType.TEXT:
-            logger.warn(f"[Gemini] Unsupported message type, type={context.type}")
-            return Reply(ReplyType.TEXT, None)
-        logger.info(f"[Gemini] query={query}")
-        session_id = context["session_id"]
-        session = self.sessions.session_query(query, session_id)
-        gemini_messages = self._convert_to_gemini_messages(session.messages)
-        genai.configure(api_key=self.api_key)
-        model = genai.GenerativeModel('gemini-pro')
-        response = model.generate_content(gemini_messages)
-        reply_text = response.text
-        self.sessions.session_reply(reply_text, session_id)
-        logger.info(f"[Gemini] reply={reply_text}")
-        return Reply(ReplyType.TEXT, reply_text)
-
+        try:
+            if context.type != ContextType.TEXT:
+                logger.warn(f"[Gemini] Unsupported message type, type={context.type}")
+                return Reply(ReplyType.TEXT, None)
+            logger.info(f"[Gemini] query={query}")
+            session_id = context["session_id"]
+            session = self.sessions.session_query(query, session_id)
+            gemini_messages = self._convert_to_gemini_messages(self._filter_messages(session.messages))
+            genai.configure(api_key=self.api_key)
+            model = genai.GenerativeModel('gemini-pro')
+            response = model.generate_content(gemini_messages)
+            reply_text = response.text
+            self.sessions.session_reply(reply_text, session_id)
+            logger.info(f"[Gemini] reply={reply_text}")
+            return Reply(ReplyType.TEXT, reply_text)
+        except Exception as e:
+            logger.error("[Gemini] fetch reply error, may contain unsafe content")
+            logger.error(e)
 
     def _convert_to_gemini_messages(self, messages: list):
         res = []
@@ -56,3 +59,17 @@ class GoogleGeminiBot(Bot):
                 "parts": [{"text": msg.get("content")}]
             })
         return res
+
+    def _filter_messages(self, messages: list):
+        res = []
+        turn = "user"
+        for i in range(len(messages) - 1, -1, -1):
+            message = messages[i]
+            if message.get("role") != turn:
+                continue
+            res.insert(0, message)
+            if turn == "user":
+                turn = "assistant"
+            elif turn == "assistant":
+                turn = "user"
+        return res