zhayujie 2 лет назад
Родитель
Commit
cac7a6228a

+ 2 - 1
bot/chatgpt/chat_gpt_session.py

@@ -68,7 +68,8 @@ def num_tokens_from_messages(messages, model):
                    "gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k-0613", "gpt-35-turbo-16k", "gpt-4-turbo-preview",
                    "gpt-4-1106-preview", const.GPT4_TURBO_PREVIEW, const.GPT4_VISION_PREVIEW]:
         return num_tokens_from_messages(messages, model="gpt-4")
-
+    elif model.startswith("claude-3"):
+        return num_tokens_from_messages(messages, model="gpt-3.5-turbo")
     try:
         encoding = tiktoken.encoding_for_model(model)
     except KeyError:

+ 18 - 7
bot/claudeapi/claude_api_bot.py

@@ -9,6 +9,7 @@ import anthropic
 from bot.bot import Bot
 from bot.openai.open_ai_image import OpenAIImage
 from bot.claudeapi.claude_api_session import ClaudeAPISession
+from bot.chatgpt.chat_gpt_session import ChatGPTSession
 from bot.session_manager import SessionManager
 from bridge.context import ContextType
 from bridge.reply import Reply, ReplyType
@@ -32,7 +33,7 @@ class ClaudeAPIBot(Bot, OpenAIImage):
         if proxy:
             openai.proxy = proxy
 
-        self.sessions = SessionManager(ClaudeAPISession, model=conf().get("model") or "text-davinci-003")
+        self.sessions = SessionManager(ChatGPTSession, model=conf().get("model") or "text-davinci-003")
 
     def reply(self, query, context=None):
         # acquire reply content
@@ -75,16 +76,17 @@ class ClaudeAPIBot(Bot, OpenAIImage):
                     reply = Reply(ReplyType.ERROR, retstring)
                 return reply
 
-    def reply_text(self, session: ClaudeAPISession, retry_count=0):
+    def reply_text(self, session: ChatGPTSession, retry_count=0):
         try:
-            logger.info("[CLAUDE_API] sendMessage={}".format(str(session)))
+            if session.messages[0].get("role") == "system":
+                system = session.messages[0].get("content")
+                session.messages.pop(0)
+            actual_model = self._model_mapping(conf().get("model"))
             response = self.claudeClient.messages.create(
-                model=conf().get("model"),
+                model=actual_model,
                 max_tokens=1024,
                 # system=conf().get("system"),
-                messages=[
-                    {"role": "user", "content": "{}".format(str(session))}
-                ]
+                messages=session.messages
             )
             # response = openai.Completion.create(prompt=str(session), **self.args)
             res_content = response.content[0].text.strip().replace("<|endoftext|>", "")
@@ -123,3 +125,12 @@ class ClaudeAPIBot(Bot, OpenAIImage):
                 return self.reply_text(session, retry_count + 1)
             else:
                 return result
+
+    def _model_mapping(self, model) -> str:
+        if model == "claude-3-opus":
+            return "claude-3-opus-20240229"
+        elif model == "claude-3-sonnet":
+            return "claude-3-sonnet-20240229"
+        elif model == "claude-3-haiku":
+            return "claude-3-haiku-20240307"
+        return model

+ 10 - 4
bot/linkai/link_ai_bot.py

@@ -130,9 +130,12 @@ class LinkAIBot(Bot):
                 response = res.json()
                 reply_content = response["choices"][0]["message"]["content"]
                 total_tokens = response["usage"]["total_tokens"]
-                logger.info(f"[LINKAI] reply={reply_content}, total_tokens={total_tokens}")
-                self.sessions.session_reply(reply_content, session_id, total_tokens, query=query)
-    
+                res_code = response.get('code')
+                logger.info(f"[LINKAI] reply={reply_content}, total_tokens={total_tokens}, res_code={res_code}")
+                if res_code == 429:
+                    logger.warn(f"[LINKAI] 用户访问超出限流配置,sender_id={body.get('sender_id')}")
+                else:
+                    self.sessions.session_reply(reply_content, session_id, total_tokens, query=query)
                 agent_suffix = self._fetch_agent_suffix(response)
                 if agent_suffix:
                     reply_content += agent_suffix
@@ -161,7 +164,10 @@ class LinkAIBot(Bot):
                     logger.warn(f"[LINKAI] do retry, times={retry_count}")
                     return self._chat(query, context, retry_count + 1)
 
-                return Reply(ReplyType.TEXT, "提问太快啦,请休息一下再问我吧")
+                error_reply = "提问太快啦,请休息一下再问我吧"
+                if res.status_code == 409:
+                    error_reply = "这个问题我还没有学会,请问我其它问题吧"
+                return Reply(ReplyType.TEXT, error_reply)
 
         except Exception as e:
             logger.exception(e)

+ 1 - 1
bridge/bridge.py

@@ -34,7 +34,7 @@ class Bridge(object):
             self.btype["chat"] = const.GEMINI
         if model_type in [const.ZHIPU_AI]:
             self.btype["chat"] = const.ZHIPU_AI
-        if model_type in [const.CLAUDE3]:
+        if model_type and model_type.startswith("claude-3"):
             self.btype["chat"] = const.CLAUDEAPI
 
         if conf().get("use_linkai") and conf().get("linkai_api_key"):

+ 2 - 0
requirements-optional.txt

@@ -26,6 +26,8 @@ websocket-client==1.2.0
 
 # claude bot
 curl_cffi
+# claude API
+anthropic
 
 # tongyi qwen
 broadscope_bailian

+ 1 - 2
requirements.txt

@@ -6,5 +6,4 @@ requests>=2.28.2
 chardet>=5.1.0
 Pillow
 pre-commit
-web.py
-anthropic
+web.py