moonshot_bot.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
  1. # encoding:utf-8
  2. import time
  3. import requests
  4. from bot.bot import Bot
  5. from bot.session_manager import SessionManager
  6. from bridge.context import ContextType
  7. from bridge.reply import Reply, ReplyType
  8. from common.log import logger
  9. from config import conf, load_config
  10. from .moonshot_session import MoonshotSession
# Moonshot AI chat-completion bot
  12. class MoonshotBot(Bot):
  13. def __init__(self):
  14. super().__init__()
  15. self.sessions = SessionManager(MoonshotSession, model=conf().get("model") or "moonshot-v1-128k")
  16. self.args = {
  17. "model": conf().get("model") or "moonshot-v1-128k", # 对话模型的名称
  18. "temperature": conf().get("temperature", 0.3), # 如果设置,值域须为 [0, 1] 我们推荐 0.3,以达到较合适的效果。
  19. "top_p": conf().get("top_p", 1.0), # 使用默认值
  20. }
  21. self.api_key = conf().get("moonshot_api_key")
  22. self.base_url = conf().get("moonshot_base_url", "https://api.moonshot.cn/v1/chat/completions")
  23. def reply(self, query, context=None):
  24. # acquire reply content
  25. if context.type == ContextType.TEXT:
  26. logger.info("[MOONSHOT_AI] query={}".format(query))
  27. session_id = context["session_id"]
  28. reply = None
  29. clear_memory_commands = conf().get("clear_memory_commands", ["#清除记忆"])
  30. if query in clear_memory_commands:
  31. self.sessions.clear_session(session_id)
  32. reply = Reply(ReplyType.INFO, "记忆已清除")
  33. elif query == "#清除所有":
  34. self.sessions.clear_all_session()
  35. reply = Reply(ReplyType.INFO, "所有人记忆已清除")
  36. elif query == "#更新配置":
  37. load_config()
  38. reply = Reply(ReplyType.INFO, "配置已更新")
  39. if reply:
  40. return reply
  41. session = self.sessions.session_query(query, session_id)
  42. logger.debug("[MOONSHOT_AI] session query={}".format(session.messages))
  43. model = context.get("moonshot_model")
  44. new_args = self.args.copy()
  45. if model:
  46. new_args["model"] = model
  47. # if context.get('stream'):
  48. # # reply in stream
  49. # return self.reply_text_stream(query, new_query, session_id)
  50. reply_content = self.reply_text(session, args=new_args)
  51. logger.debug(
  52. "[MOONSHOT_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(
  53. session.messages,
  54. session_id,
  55. reply_content["content"],
  56. reply_content["completion_tokens"],
  57. )
  58. )
  59. if reply_content["completion_tokens"] == 0 and len(reply_content["content"]) > 0:
  60. reply = Reply(ReplyType.ERROR, reply_content["content"])
  61. elif reply_content["completion_tokens"] > 0:
  62. self.sessions.session_reply(reply_content["content"], session_id, reply_content["total_tokens"])
  63. reply = Reply(ReplyType.TEXT, reply_content["content"])
  64. else:
  65. reply = Reply(ReplyType.ERROR, reply_content["content"])
  66. logger.debug("[MOONSHOT_AI] reply {} used 0 tokens.".format(reply_content))
  67. return reply
  68. else:
  69. reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type))
  70. return reply
  71. def reply_text(self, session: MoonshotSession, args=None, retry_count=0) -> dict:
  72. """
  73. call openai's ChatCompletion to get the answer
  74. :param session: a conversation session
  75. :param session_id: session id
  76. :param retry_count: retry count
  77. :return: {}
  78. """
  79. try:
  80. headers = {
  81. "Content-Type": "application/json",
  82. "Authorization": "Bearer " + self.api_key
  83. }
  84. body = args
  85. body["messages"] = session.messages
  86. # logger.debug("[MOONSHOT_AI] response={}".format(response))
  87. # logger.info("[MOONSHOT_AI] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"]))
  88. res = requests.post(
  89. self.base_url,
  90. headers=headers,
  91. json=body
  92. )
  93. if res.status_code == 200:
  94. response = res.json()
  95. return {
  96. "total_tokens": response["usage"]["total_tokens"],
  97. "completion_tokens": response["usage"]["completion_tokens"],
  98. "content": response["choices"][0]["message"]["content"]
  99. }
  100. else:
  101. response = res.json()
  102. error = response.get("error")
  103. logger.error(f"[MOONSHOT_AI] chat failed, status_code={res.status_code}, "
  104. f"msg={error.get('message')}, type={error.get('type')}")
  105. result = {"completion_tokens": 0, "content": "提问太快啦,请休息一下再问我吧"}
  106. need_retry = False
  107. if res.status_code >= 500:
  108. # server error, need retry
  109. logger.warn(f"[MOONSHOT_AI] do retry, times={retry_count}")
  110. need_retry = retry_count < 2
  111. elif res.status_code == 401:
  112. result["content"] = "授权失败,请检查API Key是否正确"
  113. elif res.status_code == 429:
  114. result["content"] = "请求过于频繁,请稍后再试"
  115. need_retry = retry_count < 2
  116. else:
  117. need_retry = False
  118. if need_retry:
  119. time.sleep(3)
  120. return self.reply_text(session, args, retry_count + 1)
  121. else:
  122. return result
  123. except Exception as e:
  124. logger.exception(e)
  125. need_retry = retry_count < 2
  126. result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
  127. if need_retry:
  128. return self.reply_text(session, args, retry_count + 1)
  129. else:
  130. return result