zhipuai_bot.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. # encoding:utf-8
  2. import time
  3. import openai
  4. from bot.bot import Bot
  5. from bot.zhipuai.zhipu_ai_session import ZhipuAISession
  6. from bot.zhipuai.zhipu_ai_image import ZhipuAIImage
  7. from bot.session_manager import SessionManager
  8. from bridge.context import ContextType
  9. from bridge.reply import Reply, ReplyType
  10. from common.log import logger
  11. from config import conf, load_config
  12. from zhipuai import ZhipuAI
  13. # ZhipuAI对话模型API
  14. class ZHIPUAIBot(Bot, ZhipuAIImage):
  def __init__(self):
      """Set up the session store, default request arguments and the ZhipuAI client from the global config."""
      super().__init__()
      cfg = conf()
      configured_model = cfg.get("model")
      # Conversation history per session; "ZHIPU_AI" is used as the model label when none is configured.
      self.sessions = SessionManager(ZhipuAISession, model=configured_model or "ZHIPU_AI")
      # Default parameters sent with every chat completion request.
      self.args = dict(
          # Name of the chat model.
          model=configured_model or "glm-4",
          # Must lie strictly inside (0, 1): ZhipuAI rejects a temperature of exactly 0 or 1.
          temperature=cfg.get("temperature", 0.9),
          # Must lie strictly inside (0, 1): ZhipuAI rejects a top_p of exactly 0 or 1.
          top_p=cfg.get("top_p", 0.7),
      )
      self.client = ZhipuAI(api_key=cfg.get("zhipu_ai_api_key"))
  def reply(self, query, context=None):
      """Build a Reply for an incoming message.

      TEXT messages are first matched against the built-in commands (clear this
      session's memory, clear every session, reload the config). Any other text
      is appended to the session and answered through ``reply_text``.
      IMAGE_CREATE messages are delegated to ``create_img``. Every other context
      type yields an ERROR reply.

      :param query: the user's message text (or the image prompt)
      :param context: bridge context; ``type``, ``["session_id"]``,
          ``"openai_api_key"`` and ``"gpt_model"`` are read from it
      :return: a Reply instance
      """
      kind = context.type
      if kind != ContextType.TEXT:
          if kind == ContextType.IMAGE_CREATE:
              ok, payload = self.create_img(query, 0)
              return Reply(ReplyType.IMAGE_URL if ok else ReplyType.ERROR, payload)
          return Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(kind))

      logger.info("[ZHIPU_AI] query={}".format(query))
      session_id = context["session_id"]

      # Built-in commands short-circuit the model call.
      if query in conf().get("clear_memory_commands", ["#清除记忆"]):
          self.sessions.clear_session(session_id)
          return Reply(ReplyType.INFO, "记忆已清除")
      if query == "#清除所有":
          self.sessions.clear_all_session()
          return Reply(ReplyType.INFO, "所有人记忆已清除")
      if query == "#更新配置":
          load_config()
          return Reply(ReplyType.INFO, "配置已更新")

      session = self.sessions.session_query(query, session_id)
      logger.debug("[ZHIPU_AI] session query={}".format(session.messages))
      api_key = context.get("openai_api_key") or openai.api_key
      # A per-context model overrides the configured one without mutating self.args;
      # None tells reply_text to use self.args unchanged.
      override_model = context.get("gpt_model")
      request_args = {**self.args, "model": override_model} if override_model else None

      answer = self.reply_text(session, api_key, args=request_args)
      content = answer["content"]
      used_tokens = answer["completion_tokens"]
      logger.debug(
          "[ZHIPU_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(
              session.messages, session_id, content, used_tokens
          )
      )

      if used_tokens > 0:
          # Only successful completions are recorded in the conversation history.
          self.sessions.session_reply(content, session_id, answer["total_tokens"])
          return Reply(ReplyType.TEXT, content)
      if used_tokens == 0 and len(content) > 0:
          # A zero-token result that still carries text is an error message to relay.
          return Reply(ReplyType.ERROR, content)
      failure = Reply(ReplyType.ERROR, content)
      logger.debug("[ZHIPU_AI] reply {} used 0 tokens.".format(answer))
      return failure
  def reply_text(self, session: "ZhipuAISession", api_key=None, args=None, retry_count=0) -> dict:
      """Ask the ZhipuAI chat completion endpoint for the next reply in *session*.

      :param session: conversation session; its ``messages`` are sent as the prompt
          and its ``session_id`` is cleared on an unexpected error
      :param api_key: accepted for interface compatibility; not used here (the
          request goes through ``self.client``) and only forwarded on retries
      :param args: request parameters (model, temperature, top_p, ...);
          defaults to ``self.args``
      :param retry_count: retries already attempted; at most two retries are made
      :return: on success ``{"total_tokens", "completion_tokens", "content"}``;
          on failure ``{"completion_tokens": 0, "content": <user-facing message>}``
      """
      if args is None:
          args = self.args
      try:
          response = self.client.chat.completions.create(messages=session.messages, **args)
          return {
              "total_tokens": response.usage.total_tokens,
              "completion_tokens": response.usage.completion_tokens,
              "content": response.choices[0].message.content,
          }
      except Exception as e:
          need_retry = retry_count < 2
          result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
          # In openai>=1.0, APITimeoutError subclasses APIConnectionError, which subclasses
          # APIError (as does RateLimitError), so the specific classes must be tested
          # before APIError or their branches can never run.
          # NOTE(review): the ZhipuAI SDK presumably raises its own exception types rather
          # than openai's; such errors fall through to the generic branch — confirm.
          if isinstance(e, openai.RateLimitError):
              logger.warning("[ZHIPU_AI] RateLimitError: {}".format(e))
              result["content"] = "提问太快啦,请休息一下再问我吧"
              delay = 20
          elif isinstance(e, openai.APITimeoutError):
              logger.warning("[ZHIPU_AI] Timeout: {}".format(e))
              result["content"] = "我没有收到你的消息"
              delay = 5
          elif isinstance(e, openai.APIConnectionError):
              logger.warning("[ZHIPU_AI] APIConnectionError: {}".format(e))
              result["content"] = "我连接不到你的网络"
              delay = 5
          elif isinstance(e, openai.APIError):
              logger.warning("[ZHIPU_AI] Bad Gateway: {}".format(e))
              result["content"] = "请再问我一次"
              delay = 10
          else:
              # Unknown failure: log the traceback, do not retry, and drop the session.
              logger.exception("[ZHIPU_AI] Exception: {}".format(e))
              self.sessions.clear_session(session.session_id)
              return result
          if not need_retry:
              return result
          # Back off before retrying; the delay depends on the error class above.
          time.sleep(delay)
          logger.warning("[ZHIPU_AI] 第{}次重试".format(retry_count + 1))
          return self.reply_text(session, api_key, args, retry_count + 1)