Bladeren bron

feat: 通义千问使用新版的sdk实现

现在项目使用的通义千问是旧版本的百炼sdk,
这里增加一个新版本sdk(dashscope)的实现
weishao zeng 2 jaren geleden
bovenliggende
commit
5e399c46b1
7 gewijzigde bestanden met toevoegingen van 186 en 2 verwijderingen
  1. 3 1
      bot/bot_factory.py
  2. 117 0
      bot/dashscope/dashscope_bot.py
  3. 51 0
      bot/dashscope/dashscope_session.py
  4. 2 0
      bridge/bridge.py
  5. 8 1
      common/const.py
  6. 2 0
      config.py
  7. 3 0
      requirements-optional.txt

+ 3 - 1
bot/bot_factory.py

@@ -50,7 +50,9 @@ def create_bot(bot_type):
     elif bot_type == const.QWEN:
         from bot.ali.ali_qwen_bot import AliQwenBot
         return AliQwenBot()
-
+    elif bot_type == const.QWEN_DASHSCOPE:
+        from bot.dashscope.dashscope_bot import DashscopeBot
+        return DashscopeBot()
     elif bot_type == const.GEMINI:
         from bot.gemini.google_gemini_bot import GoogleGeminiBot
         return GoogleGeminiBot()

+ 117 - 0
bot/dashscope/dashscope_bot.py

@@ -0,0 +1,117 @@
+# encoding:utf-8
+
+from bot.bot import Bot
+from bot.session_manager import SessionManager
+from bridge.context import ContextType
+from bridge.reply import Reply, ReplyType
+from common.log import logger
+from config import conf, load_config
+from .dashscope_session import DashscopeSession
+import os
+import dashscope
+from http import HTTPStatus
+
+
+
+dashscope_models = {
+    "qwen-turbo": dashscope.Generation.Models.qwen_turbo,
+    "qwen-plus": dashscope.Generation.Models.qwen_plus,
+    "qwen-max": dashscope.Generation.Models.qwen_max,
+    "qwen-bailian-v1": dashscope.Generation.Models.bailian_v1
+}
+# ZhipuAI对话模型API
+class DashscopeBot(Bot):
+    def __init__(self):
+        super().__init__()
+        self.sessions = SessionManager(DashscopeSession, model=conf().get("model") or "qwen-plus")
+        self.model_name = conf().get("model") or "qwen-plus"
+        self.api_key = conf().get("dashscope_api_key")
+        os.environ["DASHSCOPE_API_KEY"] = self.api_key
+        self.client = dashscope.Generation
+
+    def reply(self, query, context=None):
+        # acquire reply content
+        if context.type == ContextType.TEXT:
+            logger.info("[DASHSCOPE] query={}".format(query))
+
+            session_id = context["session_id"]
+            reply = None
+            clear_memory_commands = conf().get("clear_memory_commands", ["#清除记忆"])
+            if query in clear_memory_commands:
+                self.sessions.clear_session(session_id)
+                reply = Reply(ReplyType.INFO, "记忆已清除")
+            elif query == "#清除所有":
+                self.sessions.clear_all_session()
+                reply = Reply(ReplyType.INFO, "所有人记忆已清除")
+            elif query == "#更新配置":
+                load_config()
+                reply = Reply(ReplyType.INFO, "配置已更新")
+            if reply:
+                return reply
+            session = self.sessions.session_query(query, session_id)
+            logger.debug("[DASHSCOPE] session query={}".format(session.messages))
+
+            reply_content = self.reply_text(session)
+            logger.debug(
+                "[DASHSCOPE] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(
+                    session.messages,
+                    session_id,
+                    reply_content["content"],
+                    reply_content["completion_tokens"],
+                )
+            )
+            if reply_content["completion_tokens"] == 0 and len(reply_content["content"]) > 0:
+                reply = Reply(ReplyType.ERROR, reply_content["content"])
+            elif reply_content["completion_tokens"] > 0:
+                self.sessions.session_reply(reply_content["content"], session_id, reply_content["total_tokens"])
+                reply = Reply(ReplyType.TEXT, reply_content["content"])
+            else:
+                reply = Reply(ReplyType.ERROR, reply_content["content"])
+                logger.debug("[DASHSCOPE] reply {} used 0 tokens.".format(reply_content))
+            return reply
+        else:
+            reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type))
+            return reply
+
+    def reply_text(self, session: DashscopeSession, retry_count=0) -> dict:
+        """
+        call openai's ChatCompletion to get the answer
+        :param session: a conversation session
+        :param session_id: session id
+        :param retry_count: retry count
+        :return: {}
+        """
+        try:
+            dashscope.api_key = self.api_key
+            response = self.client.call(
+                dashscope_models[self.model_name],
+                messages=session.messages,
+                result_format="message"
+            )
+            if response.status_code == HTTPStatus.OK:
+                content = response.output.choices[0]["message"]["content"]
+                return {
+                    "total_tokens": response.usage["total_tokens"],
+                    "completion_tokens": response.usage["output_tokens"],
+                    "content": content,
+                }
+            else:
+                logger.error('Request id: %s, Status code: %s, error code: %s, error message: %s' % (
+                    response.request_id, response.status_code,
+                    response.code, response.message
+                ))
+                result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
+                need_retry = retry_count < 2
+                result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
+                if need_retry:
+                    return self.reply_text(session, retry_count + 1)
+                else:
+                    return result
+        except Exception as e:
+            logger.exception(e)
+            need_retry = retry_count < 2
+            result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
+            if need_retry:
+                return self.reply_text(session, retry_count + 1)
+            else:
+                return result

+ 51 - 0
bot/dashscope/dashscope_session.py

@@ -0,0 +1,51 @@
+from bot.session_manager import Session
+from common.log import logger
+
+
+class DashscopeSession(Session):
+    def __init__(self, session_id, system_prompt=None, model="qwen-turbo"):
+        super().__init__(session_id)
+        self.reset()
+
+    def discard_exceeding(self, max_tokens, cur_tokens=None):
+        precise = True
+        try:
+            cur_tokens = self.calc_tokens()
+        except Exception as e:
+            precise = False
+            if cur_tokens is None:
+                raise e
+            logger.debug("Exception when counting tokens precisely for query: {}".format(e))
+        while cur_tokens > max_tokens:
+            if len(self.messages) > 2:
+                self.messages.pop(1)
+            elif len(self.messages) == 2 and self.messages[1]["role"] == "assistant":
+                self.messages.pop(1)
+                if precise:
+                    cur_tokens = self.calc_tokens()
+                else:
+                    cur_tokens = cur_tokens - max_tokens
+                break
+            elif len(self.messages) == 2 and self.messages[1]["role"] == "user":
+                logger.warn("user message exceed max_tokens. total_tokens={}".format(cur_tokens))
+                break
+            else:
+                logger.debug("max_tokens={}, total_tokens={}, len(messages)={}".format(max_tokens, cur_tokens,
+                                                                                       len(self.messages)))
+                break
+            if precise:
+                cur_tokens = self.calc_tokens()
+            else:
+                cur_tokens = cur_tokens - max_tokens
+        return cur_tokens
+
+    def calc_tokens(self):
+        return num_tokens_from_messages(self.messages)
+
+
+def num_tokens_from_messages(messages):
+    # 只是大概,具体计算规则:https://help.aliyun.com/zh/dashscope/developer-reference/token-api?spm=a2c4g.11186623.0.0.4d8b12b0BkP3K9
+    tokens = 0
+    for msg in messages:
+        tokens += len(msg["content"])
+    return tokens

+ 2 - 0
bridge/bridge.py

@@ -30,6 +30,8 @@ class Bridge(object):
             self.btype["chat"] = const.XUNFEI
         if model_type in [const.QWEN]:
             self.btype["chat"] = const.QWEN
+        if model_type in [const.QWEN_TURBO, const.QWEN_PLUS, const.QWEN_MAX]:
+            self.btype["chat"] = const.QWEN_DASHSCOPE
         if model_type in [const.GEMINI]:
             self.btype["chat"] = const.GEMINI
         if model_type in [const.ZHIPU_AI]:

+ 8 - 1
common/const.py

@@ -8,6 +8,12 @@ LINKAI = "linkai"
 CLAUDEAI = "claude"
 CLAUDEAPI= "claudeAPI"
 QWEN = "qwen"
+
+QWEN_DASHSCOPE = "dashscope"
+QWEN_TURBO = "qwen-turbo"
+QWEN_PLUS = "qwen-plus"
+QWEN_MAX = "qwen-max"
+
 GEMINI = "gemini"
 ZHIPU_AI = "glm-4"
 MOONSHOT = "moonshot"
@@ -24,7 +30,8 @@ TTS_1 = "tts-1"
 TTS_1_HD = "tts-1-hd"
 
 MODEL_LIST = ["gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "wenxin", "wenxin-4", "xunfei", "claude","claude-3-opus-20240229", "gpt-4-turbo",
-              "gpt-4-turbo-preview", "gpt-4-1106-preview", GPT4_TURBO_PREVIEW, QWEN, GEMINI, ZHIPU_AI, MOONSHOT]
+              "gpt-4-turbo-preview", "gpt-4-1106-preview", GPT4_TURBO_PREVIEW, QWEN, GEMINI, ZHIPU_AI, MOONSHOT,
+              QWEN_TURBO, QWEN_PLUS, QWEN_MAX]
 
 # channel
 FEISHU = "feishu"

+ 2 - 0
config.py

@@ -75,6 +75,8 @@ available_setting = {
     "qwen_agent_key": "",
     "qwen_app_id": "",
     "qwen_node_id": "",  # 流程编排模型用到的id,如果没有用到qwen_node_id,请务必保持为空字符串
+    # 阿里灵积模型api key
+    "dashscope_api_key": "",
     # Google Gemini Api Key
     "gemini_api_key": "",
     # wework的通用配置

+ 3 - 0
requirements-optional.txt

@@ -43,3 +43,6 @@ dingtalk_stream
 
 # zhipuai
 zhipuai>=2.0.1
+
+# tongyi qwen new sdk
+dashscope