|
|
@@ -26,21 +26,24 @@ class GoogleGeminiBot(Bot):
|
|
|
self.sessions = SessionManager(BaiduWenxinSession, model=conf().get("model") or "gpt-3.5-turbo")
|
|
|
|
|
|
def reply(self, query, context: Context = None) -> Reply:
|
|
|
- if context.type != ContextType.TEXT:
|
|
|
- logger.warn(f"[Gemini] Unsupported message type, type={context.type}")
|
|
|
- return Reply(ReplyType.TEXT, None)
|
|
|
- logger.info(f"[Gemini] query={query}")
|
|
|
- session_id = context["session_id"]
|
|
|
- session = self.sessions.session_query(query, session_id)
|
|
|
- gemini_messages = self._convert_to_gemini_messages(session.messages)
|
|
|
- genai.configure(api_key=self.api_key)
|
|
|
- model = genai.GenerativeModel('gemini-pro')
|
|
|
- response = model.generate_content(gemini_messages)
|
|
|
- reply_text = response.text
|
|
|
- self.sessions.session_reply(reply_text, session_id)
|
|
|
- logger.info(f"[Gemini] reply={reply_text}")
|
|
|
- return Reply(ReplyType.TEXT, reply_text)
|
|
|
-
|
|
|
+ try:
|
|
|
+ if context.type != ContextType.TEXT:
|
|
|
+ logger.warn(f"[Gemini] Unsupported message type, type={context.type}")
|
|
|
+ return Reply(ReplyType.TEXT, None)
|
|
|
+ logger.info(f"[Gemini] query={query}")
|
|
|
+ session_id = context["session_id"]
|
|
|
+ session = self.sessions.session_query(query, session_id)
|
|
|
+ gemini_messages = self._convert_to_gemini_messages(self._filter_messages(session.messages))
|
|
|
+ genai.configure(api_key=self.api_key)
|
|
|
+ model = genai.GenerativeModel('gemini-pro')
|
|
|
+ response = model.generate_content(gemini_messages)
|
|
|
+ reply_text = response.text
|
|
|
+ self.sessions.session_reply(reply_text, session_id)
|
|
|
+ logger.info(f"[Gemini] reply={reply_text}")
|
|
|
+ return Reply(ReplyType.TEXT, reply_text)
|
|
|
+ except Exception as e:
|
|
|
+ logger.error("[Gemini] fetch reply error, may contain unsafe content")
|
|
|
+ logger.error(e)
|
|
|
|
|
|
def _convert_to_gemini_messages(self, messages: list):
|
|
|
res = []
|
|
|
@@ -56,3 +59,17 @@ class GoogleGeminiBot(Bot):
|
|
|
"parts": [{"text": msg.get("content")}]
|
|
|
})
|
|
|
return res
|
|
|
+
|
|
|
+ def _filter_messages(self, messages: list):
|
|
|
+ res = []
|
|
|
+ turn = "user"
|
|
|
+ for i in range(len(messages) - 1, -1, -1):
|
|
|
+ message = messages[i]
|
|
|
+ if message.get("role") != turn:
|
|
|
+ continue
|
|
|
+ res.insert(0, message)
|
|
|
+ if turn == "user":
|
|
|
+ turn = "assistant"
|
|
|
+ elif turn == "assistant":
|
|
|
+ turn = "user"
|
|
|
+ return res
|