Explorar o código

feat: support gemini model

zhayujie %!s(int64=2) %!d(string=hai) anos
pai
achega
23a237074e

+ 5 - 0
bot/bot_factory.py

@@ -47,4 +47,9 @@ def create_bot(bot_type):
     elif bot_type == const.QWEN:
         from bot.tongyi.tongyi_qwen_bot import TongyiQwenBot
         return TongyiQwenBot()
+
+    elif bot_type == const.GEMINI:
+        from bot.gemini.google_gemini_bot import GoogleGeminiBot
+        return GoogleGeminiBot()
+
     raise RuntimeError

+ 1 - 1
bot/chatgpt/chat_gpt_session.py

@@ -57,7 +57,7 @@ class ChatGPTSession(Session):
 def num_tokens_from_messages(messages, model):
     """Returns the number of tokens used by a list of messages."""
 
-    if model in ["wenxin", "xunfei"]:
+    if model in ["wenxin", "xunfei", const.GEMINI]:
         return num_tokens_by_character(messages)
 
     import tiktoken

+ 58 - 0
bot/gemini/google_gemini_bot.py

@@ -0,0 +1,58 @@
+"""
+Google gemini bot
+
+@author zhayujie
+@Date 2023/12/15
+"""
+# encoding:utf-8
+
+from bot.bot import Bot
+import google.generativeai as genai
+from bot.session_manager import SessionManager
+from bridge.context import ContextType, Context
+from bridge.reply import Reply, ReplyType
+from common.log import logger
+from config import conf
+from bot.baidu.baidu_wenxin_session import BaiduWenxinSession
+
+
+# OpenAI对话模型API (可用)
+class GoogleGeminiBot(Bot):
+
+    def __init__(self):
+        super().__init__()
+        self.api_key = conf().get("gemini_api_key")
+        # 复用文心的token计算方式
+        self.sessions = SessionManager(BaiduWenxinSession, model=conf().get("model") or "gpt-3.5-turbo")
+
+    def reply(self, query, context: Context = None) -> Reply:
+        if context.type != ContextType.TEXT:
+            logger.warn(f"[Gemini] Unsupported message type, type={context.type}")
+            return Reply(ReplyType.TEXT, None)
+        logger.info(f"[Gemini] query={query}")
+        session_id = context["session_id"]
+        session = self.sessions.session_query(query, session_id)
+        gemini_messages = self._convert_to_gemini_messages(session.messages)
+        genai.configure(api_key=self.api_key)
+        model = genai.GenerativeModel('gemini-pro')
+        response = model.generate_content(gemini_messages)
+        reply_text = response.text
+        self.sessions.session_reply(reply_text, session_id)
+        logger.info(f"[Gemini] reply={reply_text}")
+        return Reply(ReplyType.TEXT, reply_text)
+
+
+    def _convert_to_gemini_messages(self, messages: list):
+        res = []
+        for msg in messages:
+            if msg.get("role") == "user":
+                role = "user"
+            elif msg.get("role") == "assistant":
+                role = "model"
+            else:
+                continue
+            res.append({
+                "role": role,
+                "parts": [{"text": msg.get("content")}]
+            })
+        return res

+ 4 - 0
bridge/bridge.py

@@ -29,12 +29,16 @@ class Bridge(object):
             self.btype["chat"] = const.XUNFEI
         if model_type in [const.QWEN]:
             self.btype["chat"] = const.QWEN
+        if model_type in [const.GEMINI]:
+            self.btype["chat"] = const.GEMINI
+
         if conf().get("use_linkai") and conf().get("linkai_api_key"):
             self.btype["chat"] = const.LINKAI
             if not conf().get("voice_to_text") or conf().get("voice_to_text") in ["openai"]:
                 self.btype["voice_to_text"] = const.LINKAI
             if not conf().get("text_to_voice") or conf().get("text_to_voice") in ["openai", const.TTS_1, const.TTS_1_HD]:
                 self.btype["text_to_voice"] = const.LINKAI
+
         if model_type in ["claude"]:
             self.btype["chat"] = const.CLAUDEAI
         self.bots = {}

+ 2 - 1
common/const.py

@@ -7,6 +7,7 @@ CHATGPTONAZURE = "chatGPTOnAzure"
 LINKAI = "linkai"
 CLAUDEAI = "claude"
 QWEN = "qwen"
+GEMINI = "gemini"
 
 # model
 GPT35 = "gpt-3.5-turbo"
@@ -17,7 +18,7 @@ WHISPER_1 = "whisper-1"
 TTS_1 = "tts-1"
 TTS_1_HD = "tts-1-hd"
 
-MODEL_LIST = ["gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "wenxin", "wenxin-4", "xunfei", "claude", "gpt-4-turbo", GPT4_TURBO_PREVIEW, QWEN]
+MODEL_LIST = ["gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "wenxin", "wenxin-4", "xunfei", "claude", "gpt-4-turbo", GPT4_TURBO_PREVIEW, QWEN, GEMINI]
 
 # channel
 FEISHU = "feishu"

+ 2 - 0
config.py

@@ -73,6 +73,8 @@ available_setting = {
     "qwen_agent_key": "",
     "qwen_app_id": "",
     "qwen_node_id": "",  # 流程编排模型用到的id,如果没有用到qwen_node_id,请务必保持为空字符串
+    # Google Gemini Api Key
+    "gemini_api_key": "",
     # wework的通用配置
     "wework_smart": True,  # 配置wework是否使用已登录的企业微信,False为多开
     # 语音设置

+ 1 - 1
plugins/godcmd/godcmd.py

@@ -313,7 +313,7 @@ class Godcmd(Plugin):
                     except Exception as e:
                         ok, result = False, "你没有设置私有GPT模型"
                 elif cmd == "reset":
-                    if bottype in [const.OPEN_AI, const.CHATGPT, const.CHATGPTONAZURE, const.LINKAI, const.BAIDU, const.XUNFEI]:
+                    if bottype in [const.OPEN_AI, const.CHATGPT, const.CHATGPTONAZURE, const.LINKAI, const.BAIDU, const.XUNFEI, const.GEMINI]:
                         bot.sessions.clear_session(session_id)
                         if Bridge().chat_bots.get(bottype):
                             Bridge().chat_bots.get(bottype).sessions.clear_session(session_id)