Ver Fonte

Merge pull request #1810 from FB208/master

增加了claude api的调用方法
zhayujie há 2 anos atrás
pai
commit
674fbc3f69

+ 2 - 1
README.md

@@ -118,7 +118,8 @@ pip3 install -r requirements-optional.txt
 # config.json文件内容示例
 {
   "open_ai_api_key": "YOUR API KEY",                          # 填入上面创建的 OpenAI API KEY
-  "model": "gpt-3.5-turbo",                                   # 模型名称, 支持 gpt-3.5-turbo, gpt-3.5-turbo-16k, gpt-4, wenxin, xunfei
+  "model": "gpt-3.5-turbo",                                   # 模型名称, 支持 gpt-3.5-turbo, gpt-3.5-turbo-16k, gpt-4, wenxin, xunfei, claude-3-opus-20240229
+  "claude_api_key":"YOUR API KEY"                             # 如果选用claude3模型的话,配置这个key,同时如想使用生图,语音等功能,仍需配置open_ai_api_key
   "proxy": "",                                                # 代理客户端的ip和端口,国内环境开启代理的需要填写该项,如 "127.0.0.1:7890"
   "single_chat_prefix": ["bot", "@bot"],                      # 私聊时文本需要包含该前缀才能触发机器人回复
   "single_chat_reply_prefix": "[bot] ",                       # 私聊时自动回复的前缀,用于区分真人

+ 4 - 1
bot/bot_factory.py

@@ -2,6 +2,7 @@
 channel factory
 """
 from common import const
+from common.log import logger
 
 
 def create_bot(bot_type):
@@ -43,7 +44,9 @@ def create_bot(bot_type):
     elif bot_type == const.CLAUDEAI:
         from bot.claude.claude_ai_bot import ClaudeAIBot
         return ClaudeAIBot()
-
+    elif bot_type == const.CLAUDEAPI:
+        from bot.claudeapi.claude_api_bot import ClaudeAPIBot
+        return ClaudeAPIBot()
     elif bot_type == const.QWEN:
         from bot.ali.ali_qwen_bot import AliQwenBot
         return AliQwenBot()

+ 125 - 0
bot/claudeapi/claude_api_bot.py

@@ -0,0 +1,125 @@
+# encoding:utf-8
+
+import time
+
+import openai
+import openai.error
+import anthropic
+
+from bot.bot import Bot
+from bot.openai.open_ai_image import OpenAIImage
+from bot.claudeapi.claude_api_session import ClaudeAPISession
+from bot.session_manager import SessionManager
+from bridge.context import ContextType
+from bridge.reply import Reply, ReplyType
+from common.log import logger
+from config import conf
+
+user_session = dict()
+
+
+# OpenAI对话模型API (可用)
+class ClaudeAPIBot(Bot, OpenAIImage):
+    def __init__(self):
+        super().__init__()
+        self.claudeClient = anthropic.Anthropic(
+            api_key=conf().get("claude_api_key")
+        )
+        openai.api_key = conf().get("open_ai_api_key")
+        if conf().get("open_ai_api_base"):
+            openai.api_base = conf().get("open_ai_api_base")
+        proxy = conf().get("proxy")
+        if proxy:
+            openai.proxy = proxy
+
+        self.sessions = SessionManager(ClaudeAPISession, model=conf().get("model") or "text-davinci-003")
+
+    def reply(self, query, context=None):
+        # acquire reply content
+        if context and context.type:
+            if context.type == ContextType.TEXT:
+                logger.info("[CLAUDE_API] query={}".format(query))
+                session_id = context["session_id"]
+                reply = None
+                if query == "#清除记忆":
+                    self.sessions.clear_session(session_id)
+                    reply = Reply(ReplyType.INFO, "记忆已清除")
+                elif query == "#清除所有":
+                    self.sessions.clear_all_session()
+                    reply = Reply(ReplyType.INFO, "所有人记忆已清除")
+                else:
+                    session = self.sessions.session_query(query, session_id)
+                    result = self.reply_text(session)
+                    logger.info(result)
+                    total_tokens, completion_tokens, reply_content = (
+                        result["total_tokens"],
+                        result["completion_tokens"],
+                        result["content"],
+                    )
+                    logger.debug(
+                        "[CLAUDE_API] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(str(session), session_id, reply_content, completion_tokens)
+                    )
+
+                    if total_tokens == 0:
+                        reply = Reply(ReplyType.ERROR, reply_content)
+                    else:
+                        self.sessions.session_reply(reply_content, session_id, total_tokens)
+                        reply = Reply(ReplyType.TEXT, reply_content)
+                return reply
+            elif context.type == ContextType.IMAGE_CREATE:
+                ok, retstring = self.create_img(query, 0)
+                reply = None
+                if ok:
+                    reply = Reply(ReplyType.IMAGE_URL, retstring)
+                else:
+                    reply = Reply(ReplyType.ERROR, retstring)
+                return reply
+
+    def reply_text(self, session: ClaudeAPISession, retry_count=0):
+        try:
+            logger.info("[CLAUDE_API] sendMessage={}".format(str(session)))
+            response = self.claudeClient.messages.create(
+                model=conf().get("model"),
+                max_tokens=1024,
+                # system=conf().get("system"),
+                messages=[
+                    {"role": "user", "content": "{}".format(str(session))}
+                ]
+            )
+            # response = openai.Completion.create(prompt=str(session), **self.args)
+            res_content = response.content[0].text.strip().replace("<|endoftext|>", "")
+            total_tokens = response.usage.input_tokens+response.usage.output_tokens
+            completion_tokens = response.usage.output_tokens
+            logger.info("[CLAUDE_API] reply={}".format(res_content))
+            return {
+                "total_tokens": total_tokens,
+                "completion_tokens": completion_tokens,
+                "content": res_content,
+            }
+        except Exception as e:
+            need_retry = retry_count < 2
+            result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
+            if isinstance(e, openai.error.RateLimitError):
+                logger.warn("[CLAUDE_API] RateLimitError: {}".format(e))
+                result["content"] = "提问太快啦,请休息一下再问我吧"
+                if need_retry:
+                    time.sleep(20)
+            elif isinstance(e, openai.error.Timeout):
+                logger.warn("[CLAUDE_API] Timeout: {}".format(e))
+                result["content"] = "我没有收到你的消息"
+                if need_retry:
+                    time.sleep(5)
+            elif isinstance(e, openai.error.APIConnectionError):
+                logger.warn("[CLAUDE_API] APIConnectionError: {}".format(e))
+                need_retry = False
+                result["content"] = "我连接不到你的网络"
+            else:
+                logger.warn("[CLAUDE_API] Exception: {}".format(e))
+                need_retry = False
+                self.sessions.clear_session(session.session_id)
+
+            if need_retry:
+                logger.warn("[CLAUDE_API] 第{}次重试".format(retry_count + 1))
+                return self.reply_text(session, retry_count + 1)
+            else:
+                return result

+ 74 - 0
bot/claudeapi/claude_api_session.py

@@ -0,0 +1,74 @@
+from bot.session_manager import Session
+from common.log import logger
+
+
+class ClaudeAPISession(Session):
+    def __init__(self, session_id, system_prompt=None, model="text-davinci-003"):
+        super().__init__(session_id, system_prompt)
+        self.model = model
+        self.reset()
+
+    def __str__(self):
+        # 构造对话模型的输入
+        """
+        e.g.  Q: xxx
+              A: xxx
+              Q: xxx
+        """
+        prompt = ""
+        for item in self.messages:
+            if item["role"] == "system":
+                prompt += item["content"] + "<|endoftext|>\n\n\n"
+            elif item["role"] == "user":
+                prompt += "Q: " + item["content"] + "\n"
+            elif item["role"] == "assistant":
+                prompt += "\n\nA: " + item["content"] + "<|endoftext|>\n"
+
+        if len(self.messages) > 0 and self.messages[-1]["role"] == "user":
+            prompt += "A: "
+        return prompt
+
+    def discard_exceeding(self, max_tokens, cur_tokens=None):
+        precise = True
+
+        try:
+            cur_tokens = self.calc_tokens()
+        except Exception as e:
+            precise = False
+            if cur_tokens is None:
+                raise e
+            logger.debug("Exception when counting tokens precisely for query: {}".format(e))
+        while cur_tokens > max_tokens:
+            if len(self.messages) > 1:
+                self.messages.pop(0)
+            elif len(self.messages) == 1 and self.messages[0]["role"] == "assistant":
+                self.messages.pop(0)
+                if precise:
+                    cur_tokens = self.calc_tokens()
+                else:
+                    cur_tokens = len(str(self))
+                break
+            elif len(self.messages) == 1 and self.messages[0]["role"] == "user":
+                logger.warn("user question exceed max_tokens. total_tokens={}".format(cur_tokens))
+                break
+            else:
+                logger.debug("max_tokens={}, total_tokens={}, len(conversation)={}".format(max_tokens, cur_tokens, len(self.messages)))
+                break
+            if precise:
+                cur_tokens = self.calc_tokens()
+            else:
+                cur_tokens = len(str(self))
+        return cur_tokens
+    def calc_tokens(self):
+        return num_tokens_from_string(str(self), self.model)
+
+
+# refer to https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
+def num_tokens_from_string(string: str, model: str) -> int:
+    """Returns the number of tokens in a text string."""
+    num_tokens = len(string)
+    return num_tokens
+
+
+
+

+ 5 - 2
bridge/bridge.py

@@ -18,6 +18,7 @@ class Bridge(object):
             "text_to_voice": conf().get("text_to_voice", "google"),
             "translate": conf().get("translate", "baidu"),
         }
+        # 这边取配置的模型
         model_type = conf().get("model") or const.GPT35
         if model_type in ["text-davinci-003"]:
             self.btype["chat"] = const.OPEN_AI
@@ -33,6 +34,8 @@ class Bridge(object):
             self.btype["chat"] = const.GEMINI
         if model_type in [const.ZHIPU_AI]:
             self.btype["chat"] = const.ZHIPU_AI
+        if model_type in [const.CLAUDE3]:
+            self.btype["chat"] = const.CLAUDEAPI
 
         if conf().get("use_linkai") and conf().get("linkai_api_key"):
             self.btype["chat"] = const.LINKAI
@@ -40,12 +43,12 @@ class Bridge(object):
                 self.btype["voice_to_text"] = const.LINKAI
             if not conf().get("text_to_voice") or conf().get("text_to_voice") in ["openai", const.TTS_1, const.TTS_1_HD]:
                 self.btype["text_to_voice"] = const.LINKAI
-
         if model_type in ["claude"]:
             self.btype["chat"] = const.CLAUDEAI
+
         self.bots = {}
         self.chat_bots = {}
-
+    # 模型对应的接口
     def get_bot(self, typename):
         if self.bots.get(typename) is None:
             logger.info("create bot {} for {}".format(self.btype[typename], typename))

+ 3 - 1
common/const.py

@@ -6,12 +6,14 @@ XUNFEI = "xunfei"
 CHATGPTONAZURE = "chatGPTOnAzure"
 LINKAI = "linkai"
 CLAUDEAI = "claude"
+CLAUDEAPI= "claudeAPI"
 QWEN = "qwen"
 GEMINI = "gemini"
 ZHIPU_AI = "glm-4"
 
 
 # model
+CLAUDE3="claude-3-opus-20240229"
 GPT35 = "gpt-3.5-turbo"
 GPT4 = "gpt-4"
 GPT4_TURBO_PREVIEW = "gpt-4-0125-preview"
@@ -20,7 +22,7 @@ WHISPER_1 = "whisper-1"
 TTS_1 = "tts-1"
 TTS_1_HD = "tts-1-hd"
 
-MODEL_LIST = ["gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "wenxin", "wenxin-4", "xunfei", "claude", "gpt-4-turbo",
+MODEL_LIST = ["gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "wenxin", "wenxin-4", "xunfei", "claude","claude-3-opus-20240229", "gpt-4-turbo",
               "gpt-4-turbo-preview", "gpt-4-1106-preview", GPT4_TURBO_PREVIEW, QWEN, GEMINI, ZHIPU_AI]
 
 # channel

+ 1 - 0
config-template.json

@@ -2,6 +2,7 @@
   "channel_type": "wx",
   "model": "",
   "open_ai_api_key": "YOUR API KEY",
+  "claude_api_key": "YOUR API KEY",
   "text_to_image": "dall-e-2",
   "voice_to_text": "openai",
   "text_to_voice": "openai",

+ 2 - 0
config.py

@@ -67,6 +67,8 @@ available_setting = {
     # claude 配置
     "claude_api_cookie": "",
     "claude_uuid": "",
+    # claude api key
+    "claude_api_key":"",
     # 通义千问API, 获取方式查看文档 https://help.aliyun.com/document_detail/2587494.html
     "qwen_access_key_id": "",
     "qwen_access_key_secret": "",

+ 1 - 0
requirements.txt

@@ -7,3 +7,4 @@ chardet>=5.1.0
 Pillow
 pre-commit
 web.py
+anthropic