瀏覽代碼

feat: 增加moonshot api集成

moonshot本来可直接使用openai sdk,
但是要求openai sdk必须在1.0以上,与本项目冲突,
故现使用http接口对接的方式集成
weishao zeng 2 年之前
父節點
當前提交
38e1db7a37
共有 5 個文件被更改,包括 203 次插入0 次删除
  1. 4 0
      bot/bot_factory.py
  2. 143 0
      bot/moonshot/moonshot_bot.py
  3. 51 0
      bot/moonshot/moonshot_session.py
  4. 3 0
      bridge/bridge.py
  5. 2 0
      config.py

+ 4 - 0
bot/bot_factory.py

@@ -59,5 +59,9 @@ def create_bot(bot_type):
         from bot.zhipuai.zhipuai_bot import ZHIPUAIBot
         return ZHIPUAIBot()
 
+    elif bot_type == const.MOONSHOT:
+        from bot.moonshot.moonshot_bot import MoonshotBot
+        return MoonshotBot()
+
 
     raise RuntimeError

+ 143 - 0
bot/moonshot/moonshot_bot.py

@@ -0,0 +1,143 @@
+# encoding:utf-8
+
+import time
+
+import openai
+import openai.error
+from bot.bot import Bot
+from bot.session_manager import SessionManager
+from bridge.context import ContextType
+from bridge.reply import Reply, ReplyType
+from common.log import logger
+from config import conf, load_config
+from .moonshot_session import MoonshotSession
+import requests
+
+
+# ZhipuAI对话模型API
+class MoonshotBot(Bot):
+    def __init__(self):
+        super().__init__()
+        self.sessions = SessionManager(MoonshotSession, model=conf().get("model") or "moonshot-v1-128k")
+        self.args = {
+            "model": conf().get("model") or "moonshot-v1-128k",  # 对话模型的名称
+            "temperature": conf().get("temperature", 0.3),  # 如果设置,值域须为 [0, 1] 我们推荐 0.3,以达到较合适的效果。
+            "top_p": conf().get("top_p", 1.0),  # 使用默认值
+        }
+        self.api_key = conf().get("moonshot_api_key")
+        self.base_url = conf().get("moonshot_base_url", "https://api.moonshot.cn/v1/chat/completions")
+
+    def reply(self, query, context=None):
+        # acquire reply content
+        if context.type == ContextType.TEXT:
+            logger.info("[MOONSHOT_AI] query={}".format(query))
+
+            session_id = context["session_id"]
+            reply = None
+            clear_memory_commands = conf().get("clear_memory_commands", ["#清除记忆"])
+            if query in clear_memory_commands:
+                self.sessions.clear_session(session_id)
+                reply = Reply(ReplyType.INFO, "记忆已清除")
+            elif query == "#清除所有":
+                self.sessions.clear_all_session()
+                reply = Reply(ReplyType.INFO, "所有人记忆已清除")
+            elif query == "#更新配置":
+                load_config()
+                reply = Reply(ReplyType.INFO, "配置已更新")
+            if reply:
+                return reply
+            session = self.sessions.session_query(query, session_id)
+            logger.debug("[MOONSHOT_AI] session query={}".format(session.messages))
+
+            model = context.get("moonshot_model")
+            new_args = self.args.copy()
+            if model:
+                new_args["model"] = model
+            # if context.get('stream'):
+            #     # reply in stream
+            #     return self.reply_text_stream(query, new_query, session_id)
+
+            reply_content = self.reply_text(session, args=new_args)
+            logger.debug(
+                "[MOONSHOT_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(
+                    session.messages,
+                    session_id,
+                    reply_content["content"],
+                    reply_content["completion_tokens"],
+                )
+            )
+            if reply_content["completion_tokens"] == 0 and len(reply_content["content"]) > 0:
+                reply = Reply(ReplyType.ERROR, reply_content["content"])
+            elif reply_content["completion_tokens"] > 0:
+                self.sessions.session_reply(reply_content["content"], session_id, reply_content["total_tokens"])
+                reply = Reply(ReplyType.TEXT, reply_content["content"])
+            else:
+                reply = Reply(ReplyType.ERROR, reply_content["content"])
+                logger.debug("[MOONSHOT_AI] reply {} used 0 tokens.".format(reply_content))
+            return reply
+        else:
+            reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type))
+            return reply
+
+    def reply_text(self, session: MoonshotSession, args=None, retry_count=0) -> dict:
+        """
+        call openai's ChatCompletion to get the answer
+        :param session: a conversation session
+        :param session_id: session id
+        :param retry_count: retry count
+        :return: {}
+        """
+        try:
+            headers = {
+                "Content-Type": "application/json",
+                "Authorization": "Bearer " + self.api_key
+            }
+            body = args
+            body["messages"] = session.messages
+            # logger.debug("[MOONSHOT_AI] response={}".format(response))
+            # logger.info("[MOONSHOT_AI] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"]))
+            res = requests.post(
+                self.base_url,
+                headers=headers,
+                json=body
+            )
+            if res.status_code == 200:
+                response = res.json()
+                return {
+                    "total_tokens": response["usage"]["total_tokens"],
+                    "completion_tokens": response["usage"]["completion_tokens"],
+                    "content": response["choices"][0]["message"]["content"]
+                }
+            else:
+                response = res.json()
+                error = response.get("error")
+                logger.error(f"[MOONSHOT_AI] chat failed, status_code={res.status_code}, "
+                             f"msg={error.get('message')}, type={error.get('type')}")
+
+                result = {"completion_tokens": 0, "content": "提问太快啦,请休息一下再问我吧"}
+                need_retry = False
+                if res.status_code >= 500:
+                    # server error, need retry
+                    logger.warn(f"[MOONSHOT_AI] do retry, times={retry_count}")
+                    need_retry = retry_count < 2
+                elif res.status_code == 401:
+                    result["content"] = "授权失败,请检查API Key是否正确"
+                elif res.status_code == 429:
+                    result["content"] = "请求过于频繁,请稍后再试"
+                    need_retry = retry_count < 2
+                else:
+                    need_retry = False
+
+                if need_retry:
+                    time.sleep(3)
+                    return self.reply_text(session, args, retry_count + 1)
+                else:
+                    return result
+        except Exception as e:
+            logger.exception(e)
+            need_retry = retry_count < 2
+            result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
+            if need_retry:
+                return self.reply_text(session, args, retry_count + 1)
+            else:
+                return result

+ 51 - 0
bot/moonshot/moonshot_session.py

@@ -0,0 +1,51 @@
+from bot.session_manager import Session
+from common.log import logger
+
+
+class MoonshotSession(Session):
+    def __init__(self, session_id, system_prompt=None, model="moonshot-v1-128k"):
+        super().__init__(session_id, system_prompt)
+        self.model = model
+        self.reset()
+
+    def discard_exceeding(self, max_tokens, cur_tokens=None):
+        precise = True
+        try:
+            cur_tokens = self.calc_tokens()
+        except Exception as e:
+            precise = False
+            if cur_tokens is None:
+                raise e
+            logger.debug("Exception when counting tokens precisely for query: {}".format(e))
+        while cur_tokens > max_tokens:
+            if len(self.messages) > 2:
+                self.messages.pop(1)
+            elif len(self.messages) == 2 and self.messages[1]["role"] == "assistant":
+                self.messages.pop(1)
+                if precise:
+                    cur_tokens = self.calc_tokens()
+                else:
+                    cur_tokens = cur_tokens - max_tokens
+                break
+            elif len(self.messages) == 2 and self.messages[1]["role"] == "user":
+                logger.warn("user message exceed max_tokens. total_tokens={}".format(cur_tokens))
+                break
+            else:
+                logger.debug("max_tokens={}, total_tokens={}, len(messages)={}".format(max_tokens, cur_tokens,
+                                                                                       len(self.messages)))
+                break
+            if precise:
+                cur_tokens = self.calc_tokens()
+            else:
+                cur_tokens = cur_tokens - max_tokens
+        return cur_tokens
+
+    def calc_tokens(self):
+        return num_tokens_from_messages(self.messages, self.model)
+
+
+def num_tokens_from_messages(messages, model):
+    tokens = 0
+    for msg in messages:
+        tokens += len(msg["content"])
+    return tokens

+ 3 - 0
bridge/bridge.py

@@ -46,6 +46,9 @@ class Bridge(object):
         if model_type in ["claude"]:
             self.btype["chat"] = const.CLAUDEAI
 
+        if model_type in ["moonshot-v1-8k", "moonshot-v1-32k", "moonshot-v1-128k"]:
+            self.btype["chat"] = const.MOONSHOT
+
         self.bots = {}
         self.chat_bots = {}
     # 模型对应的接口

+ 2 - 0
config.py

@@ -155,6 +155,8 @@ available_setting = {
     # 智谱AI 平台配置
     "zhipu_ai_api_key": "",
     "zhipu_ai_api_base": "https://open.bigmodel.cn/api/paas/v4",
+    "moonshot_api_key": "",
+    "moonshot_base_url":"https://api.moonshot.cn/v1/chat/completions",
     # LinkAI平台配置
     "use_linkai": False,
     "linkai_api_key": "",