Răsfoiți Sursa

fix: gemini session bug

zhayujie 2 ani în urmă
părinte
comite
1fafd39298

+ 1 - 1
bot/chatgpt/chat_gpt_session.py

@@ -62,7 +62,7 @@ def num_tokens_from_messages(messages, model):
 
     import tiktoken
 
-    if model in ["gpt-3.5-turbo-0301", "gpt-35-turbo", "gpt-3.5-turbo-1106"]:
+    if model in ["gpt-3.5-turbo-0301", "gpt-35-turbo", "gpt-3.5-turbo-1106", "moonshot"]:
         return num_tokens_from_messages(messages, model="gpt-3.5-turbo")
     elif model in ["gpt-4-0314", "gpt-4-0613", "gpt-4-32k", "gpt-4-32k-0613", "gpt-3.5-turbo-0613",
                    "gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k-0613", "gpt-35-turbo-16k", "gpt-4-turbo-preview",

+ 2 - 5
bot/claudeapi/claude_api_bot.py

@@ -8,8 +8,8 @@ import anthropic
 
 from bot.bot import Bot
 from bot.openai.open_ai_image import OpenAIImage
-from bot.claudeapi.claude_api_session import ClaudeAPISession
 from bot.chatgpt.chat_gpt_session import ChatGPTSession
+from bot.gemini.google_gemini_bot import GoogleGeminiBot
 from bot.session_manager import SessionManager
 from bridge.context import ContextType
 from bridge.reply import Reply, ReplyType
@@ -78,15 +78,12 @@ class ClaudeAPIBot(Bot, OpenAIImage):
 
     def reply_text(self, session: ChatGPTSession, retry_count=0):
         try:
-            if session.messages[0].get("role") == "system":
-                system = session.messages[0].get("content")
-                session.messages.pop(0)
             actual_model = self._model_mapping(conf().get("model"))
             response = self.claudeClient.messages.create(
                 model=actual_model,
                 max_tokens=1024,
                 # system=conf().get("system"),
-                messages=session.messages
+                messages=GoogleGeminiBot.filter_messages(session.messages)
             )
             # response = openai.Completion.create(prompt=str(session), **self.args)
             res_content = response.content[0].text.strip().replace("<|endoftext|>", "")

+ 0 - 74
bot/claudeapi/claude_api_session.py

@@ -1,74 +0,0 @@
-from bot.session_manager import Session
-from common.log import logger
-
-
-class ClaudeAPISession(Session):
-    def __init__(self, session_id, system_prompt=None, model="text-davinci-003"):
-        super().__init__(session_id, system_prompt)
-        self.model = model
-        self.reset()
-
-    def __str__(self):
-        # 构造对话模型的输入
-        """
-        e.g.  Q: xxx
-              A: xxx
-              Q: xxx
-        """
-        prompt = ""
-        for item in self.messages:
-            if item["role"] == "system":
-                prompt += item["content"] + "<|endoftext|>\n\n\n"
-            elif item["role"] == "user":
-                prompt += "Q: " + item["content"] + "\n"
-            elif item["role"] == "assistant":
-                prompt += "\n\nA: " + item["content"] + "<|endoftext|>\n"
-
-        if len(self.messages) > 0 and self.messages[-1]["role"] == "user":
-            prompt += "A: "
-        return prompt
-
-    def discard_exceeding(self, max_tokens, cur_tokens=None):
-        precise = True
-
-        try:
-            cur_tokens = self.calc_tokens()
-        except Exception as e:
-            precise = False
-            if cur_tokens is None:
-                raise e
-            logger.debug("Exception when counting tokens precisely for query: {}".format(e))
-        while cur_tokens > max_tokens:
-            if len(self.messages) > 1:
-                self.messages.pop(0)
-            elif len(self.messages) == 1 and self.messages[0]["role"] == "assistant":
-                self.messages.pop(0)
-                if precise:
-                    cur_tokens = self.calc_tokens()
-                else:
-                    cur_tokens = len(str(self))
-                break
-            elif len(self.messages) == 1 and self.messages[0]["role"] == "user":
-                logger.warn("user question exceed max_tokens. total_tokens={}".format(cur_tokens))
-                break
-            else:
-                logger.debug("max_tokens={}, total_tokens={}, len(conversation)={}".format(max_tokens, cur_tokens, len(self.messages)))
-                break
-            if precise:
-                cur_tokens = self.calc_tokens()
-            else:
-                cur_tokens = len(str(self))
-        return cur_tokens
-    def calc_tokens(self):
-        return num_tokens_from_string(str(self), self.model)
-
-
-# refer to https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
-def num_tokens_from_string(string: str, model: str) -> int:
-    """Returns the number of tokens in a text string."""
-    num_tokens = len(string)
-    return num_tokens
-
-
-
-

+ 3 - 2
bot/gemini/google_gemini_bot.py

@@ -33,7 +33,7 @@ class GoogleGeminiBot(Bot):
             logger.info(f"[Gemini] query={query}")
             session_id = context["session_id"]
             session = self.sessions.session_query(query, session_id)
-            gemini_messages = self._convert_to_gemini_messages(self._filter_messages(session.messages))
+            gemini_messages = self._convert_to_gemini_messages(self.filter_messages(session.messages))
             genai.configure(api_key=self.api_key)
             model = genai.GenerativeModel('gemini-pro')
             response = model.generate_content(gemini_messages)
@@ -61,7 +61,8 @@ class GoogleGeminiBot(Bot):
             })
         return res
 
-    def _filter_messages(self, messages: list):
+    @staticmethod
+    def filter_messages(messages: list):
         res = []
         turn = "user"
         if not messages:

+ 1 - 0
bot/linkai/link_ai_bot.py

@@ -7,6 +7,7 @@ import requests
 import config
 from bot.bot import Bot
 from bot.chatgpt.chat_gpt_session import ChatGPTSession
+from bot.gemini.google_gemini_bot import GoogleGeminiBot
 from bot.session_manager import SessionManager
 from bridge.context import Context, ContextType
 from bridge.reply import Reply, ReplyType

+ 3 - 2
common/const.py

@@ -10,10 +10,11 @@ CLAUDEAPI= "claudeAPI"
 QWEN = "qwen"
 GEMINI = "gemini"
 ZHIPU_AI = "glm-4"
+MOONSHOT = "moonshot"
 
 
 # model
-CLAUDE3="claude-3-opus-20240229"
+CLAUDE3 = "claude-3-opus-20240229"
 GPT35 = "gpt-3.5-turbo"
 GPT4 = "gpt-4"
 GPT4_TURBO_PREVIEW = "gpt-4-0125-preview"
@@ -23,7 +24,7 @@ TTS_1 = "tts-1"
 TTS_1_HD = "tts-1-hd"
 
 MODEL_LIST = ["gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "wenxin", "wenxin-4", "xunfei", "claude","claude-3-opus-20240229", "gpt-4-turbo",
-              "gpt-4-turbo-preview", "gpt-4-1106-preview", GPT4_TURBO_PREVIEW, QWEN, GEMINI, ZHIPU_AI]
+              "gpt-4-turbo-preview", "gpt-4-1106-preview", GPT4_TURBO_PREVIEW, QWEN, GEMINI, ZHIPU_AI, MOONSHOT]
 
 # channel
 FEISHU = "feishu"

+ 1 - 1
plugins/godcmd/godcmd.py

@@ -339,7 +339,7 @@ class Godcmd(Plugin):
                             ok, result = True, "配置已重载"
                         elif cmd == "resetall":
                             if bottype in [const.OPEN_AI, const.CHATGPT, const.CHATGPTONAZURE, const.LINKAI,
-                                           const.BAIDU, const.XUNFEI, const.QWEN, const.GEMINI, const.ZHIPU_AI]:
+                                           const.BAIDU, const.XUNFEI, const.QWEN, const.GEMINI, const.ZHIPU_AI, const.MOONSHOT]:
                                 channel.cancel_all_session()
                                 bot.sessions.clear_all_session()
                                 ok, result = True, "重置所有会话成功"