# chat_gpt_bot.py

# encoding:utf-8
from bot.bot import Bot
from bot.openai.open_ai_image import OpenAIImage
from bridge.context import ContextType
from bridge.reply import Reply, ReplyType
from config import conf, load_config
from common.log import logger
from common.token_bucket import TokenBucket
from common.expired_dict import ExpiredDict
import openai
import time


# OpenAI chat model API (working)
class ChatGPTBot(Bot, OpenAIImage):
    def __init__(self):
        openai.api_key = conf().get('open_ai_api_key')
        if conf().get('open_ai_api_base'):
            openai.api_base = conf().get('open_ai_api_base')
        proxy = conf().get('proxy')
        self.sessions = SessionManager(model=conf().get("model") or "gpt-3.5-turbo")
        if proxy:
            openai.proxy = proxy
        if conf().get('rate_limit_chatgpt'):
            self.tb4chatgpt = TokenBucket(conf().get('rate_limit_chatgpt', 20))
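
    # For reference, a sketch of the config keys consumed in __init__ above, as
    # they might appear in this project's JSON config; the key names come from
    # the code, the values are illustrative only:
    #
    #   {
    #     "open_ai_api_key": "sk-...",
    #     "open_ai_api_base": "https://api.openai.com/v1",  # optional override
    #     "proxy": "http://127.0.0.1:7890",                 # optional
    #     "model": "gpt-3.5-turbo",
    #     "rate_limit_chatgpt": 20                          # passed to TokenBucket
    #   }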

    def reply(self, query, context=None):
        # acquire reply content
        if context.type == ContextType.TEXT:
            logger.info("[OPEN_AI] query={}".format(query))
            session_id = context['session_id']
            reply = None
            clear_memory_commands = conf().get('clear_memory_commands', ['#清除记忆'])
            if query in clear_memory_commands:
                self.sessions.clear_session(session_id)
                reply = Reply(ReplyType.INFO, '记忆已清除')  # "Memory cleared"
            elif query == '#清除所有':  # "#clear all"
                self.sessions.clear_all_session()
                reply = Reply(ReplyType.INFO, '所有人记忆已清除')  # "All users' memory cleared"
            elif query == '#更新配置':  # "#reload config"
                load_config()
                reply = Reply(ReplyType.INFO, '配置已更新')  # "Config reloaded"
            if reply:
                return reply
            session = self.sessions.build_session_query(query, session_id)
            logger.debug("[OPEN_AI] session query={}".format(session))

            # if context.get('stream'):
            #     # reply in stream
            #     return self.reply_text_stream(query, new_query, session_id)

            reply_content = self.reply_text(session, session_id, 0)
            logger.debug("[OPEN_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(session, session_id, reply_content["content"], reply_content["completion_tokens"]))
            if reply_content['completion_tokens'] == 0 and len(reply_content['content']) > 0:
                reply = Reply(ReplyType.ERROR, reply_content['content'])
            elif reply_content["completion_tokens"] > 0:
                self.sessions.save_session(reply_content["content"], session_id, reply_content["total_tokens"])
                reply = Reply(ReplyType.TEXT, reply_content["content"])
            else:
                reply = Reply(ReplyType.ERROR, reply_content['content'])
                logger.debug("[OPEN_AI] reply {} used 0 tokens.".format(reply_content))
            return reply
        elif context.type == ContextType.IMAGE_CREATE:
            ok, retstring = self.create_img(query, 0)
            reply = None
            if ok:
                reply = Reply(ReplyType.IMAGE_URL, retstring)
            else:
                reply = Reply(ReplyType.ERROR, retstring)
            return reply
        else:
            reply = Reply(ReplyType.ERROR, 'Bot不支持处理{}类型的消息'.format(context.type))  # "Bot does not support messages of type {}"
            return reply

    def compose_args(self):
        return {
            "model": conf().get("model") or "gpt-3.5-turbo",  # name of the chat model
            "temperature": conf().get('temperature', 0.9),  # in [0, 2]; higher values make replies more random
            # "max_tokens": 4096,  # maximum number of tokens in the reply
            "top_p": 1,
            "frequency_penalty": conf().get('frequency_penalty', 0.0),  # in [-2, 2]; higher values penalize tokens by how often they have appeared so far
            "presence_penalty": conf().get('presence_penalty', 0.0),  # in [-2, 2]; higher values penalize tokens that have appeared at all, encouraging new topics
        }
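
    # For illustration: with an empty config, compose_args() yields roughly
    #   {"model": "gpt-3.5-turbo", "temperature": 0.9, "top_p": 1,
    #    "frequency_penalty": 0.0, "presence_penalty": 0.0}
    # and reply_text() below splats these into openai.ChatCompletion.create()
    # together with the messages list built by SessionManager.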

    def reply_text(self, session, session_id, retry_count=0) -> dict:
        '''
        Call OpenAI's ChatCompletion API to get the answer.
        :param session: a conversation session (list of messages)
        :param session_id: session id
        :param retry_count: retry count
        :return: {}
        '''
        try:
            if conf().get('rate_limit_chatgpt') and not self.tb4chatgpt.get_token():
                return {"completion_tokens": 0, "content": "提问太快啦,请休息一下再问我吧"}  # "You're asking too fast, take a break first"
            response = openai.ChatCompletion.create(
                messages=session, **self.compose_args()
            )
            # logger.info("[ChatGPT] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"]))
            return {"total_tokens": response["usage"]["total_tokens"],
                    "completion_tokens": response["usage"]["completion_tokens"],
                    "content": response.choices[0]['message']['content']}
        except openai.error.RateLimitError as e:
            # rate limit exception
            logger.warn(e)
            if retry_count < 1:
                time.sleep(5)
                logger.warn("[OPEN_AI] RateLimit exceeded, retry #{}".format(retry_count + 1))
                return self.reply_text(session, session_id, retry_count + 1)
            else:
                return {"completion_tokens": 0, "content": "提问太快啦,请休息一下再问我吧"}  # "You're asking too fast, take a break first"
        except openai.error.APIConnectionError as e:
            # api connection exception
            logger.warn(e)
            logger.warn("[OPEN_AI] APIConnection failed")
            return {"completion_tokens": 0, "content": "我连接不到你的网络"}  # "I can't reach your network"
        except openai.error.Timeout as e:
            logger.warn(e)
            logger.warn("[OPEN_AI] Timeout")
            return {"completion_tokens": 0, "content": "我没有收到你的消息"}  # "I didn't receive your message"
        except Exception as e:
            # unknown exception: log it and reset the session
            logger.exception(e)
            self.sessions.clear_session(session_id)
            return {"completion_tokens": 0, "content": "请再问我一次吧"}  # "Please ask me again"
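
    # A minimal usage sketch. The Context construction here is assumed for
    # illustration (the real bridge layer builds it); reply() above only needs
    # the `type` attribute and dict-style access to 'session_id':
    #
    #   bot = ChatGPTBot()
    #   context = Context(ContextType.TEXT)  # hypothetical constructor
    #   context['session_id'] = 'user-123'
    #   reply = bot.reply("Hello!", context)
    #   print(reply.type, reply.content)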


class AzureChatGPTBot(ChatGPTBot):
    def __init__(self):
        super().__init__()
        openai.api_type = "azure"
        openai.api_version = "2023-03-15-preview"

    def compose_args(self):
        args = super().compose_args()
        args["engine"] = args["model"]
        del args["model"]
        return args
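
    # Note: with openai.api_type = "azure", ChatCompletion.create() expects an
    # `engine` argument naming the Azure deployment instead of `model`, so this
    # subclass assumes a deployment named after the configured model exists.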


class SessionManager(object):
    def __init__(self, model="gpt-3.5-turbo-0301"):
        if conf().get('expires_in_seconds'):
            sessions = ExpiredDict(conf().get('expires_in_seconds'))
        else:
            sessions = dict()
        self.sessions = sessions
        self.model = model

    def build_session(self, session_id, system_prompt=None):
        session = self.sessions.get(session_id, [])
        if len(session) == 0:
            if system_prompt is None:
                system_prompt = conf().get("character_desc", "")
            system_item = {'role': 'system', 'content': system_prompt}
            session.append(system_item)
            self.sessions[session_id] = session
        return session

    def build_session_query(self, query, session_id):
        '''
        Build the query together with the conversation history, e.g.
        [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Who won the world series in 2020?"},
            {"role": "assistant", "content": "The Los Angeles Dodgers won the World Series in 2020."},
            {"role": "user", "content": "Where was it played?"}
        ]
        :param query: query content
        :param session_id: session id
        :return: query content with conversation history
        '''
        session = self.build_session(session_id)
        user_item = {'role': 'user', 'content': query}
        session.append(user_item)
        try:
            total_tokens = num_tokens_from_messages(session, self.model)
            max_tokens = conf().get("conversation_max_tokens", 1000)
            total_tokens = self.discard_exceed_conversation(session, max_tokens, total_tokens)
            logger.debug("prompt tokens used={}".format(total_tokens))
        except Exception as e:
            logger.debug("Exception when counting tokens precisely for prompt: {}".format(str(e)))
        return session

    def save_session(self, answer, session_id, total_tokens):
        max_tokens = conf().get("conversation_max_tokens", 1000)
        session = self.sessions.get(session_id)
        if session:
            # append the assistant's answer to the conversation
            gpt_item = {'role': 'assistant', 'content': answer}
            session.append(gpt_item)
            # discard the oldest turns once the token limit is exceeded
            tokens_cnt = self.discard_exceed_conversation(session, max_tokens, total_tokens)
            logger.debug("raw total_tokens={}, save_session tokens={}".format(total_tokens, tokens_cnt))

    def discard_exceed_conversation(self, session, max_tokens, total_tokens):
        dec_tokens = int(total_tokens)
        # logger.info("prompt tokens used={}, max_tokens={}".format(dec_tokens, max_tokens))
        while dec_tokens > max_tokens:
            # pop the oldest turn, keeping the system prompt at index 0
            if len(session) > 2:
                session.pop(1)
            elif len(session) == 2 and session[1]["role"] == "assistant":
                session.pop(1)
                break
            elif len(session) == 2 and session[1]["role"] == "user":
                logger.warn("user message exceeds max_tokens. total_tokens={}".format(dec_tokens))
                break
            else:
                logger.debug("max_tokens={}, total_tokens={}, len(sessions)={}".format(max_tokens, dec_tokens, len(session)))
                break
            try:
                cur_tokens = num_tokens_from_messages(session, self.model)
                dec_tokens = cur_tokens
            except Exception as e:
                logger.debug("Exception when counting tokens precisely for query: {}".format(e))
                # fall back to a rough estimate so the loop still terminates
                dec_tokens = dec_tokens - max_tokens
        return dec_tokens

    def clear_session(self, session_id):
        self.sessions[session_id] = []

    def clear_all_session(self):
        self.sessions.clear()
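
    # A minimal sketch of the SessionManager round trip; the session id and
    # texts are made up for illustration:
    #
    #   sm = SessionManager(model="gpt-3.5-turbo")
    #   messages = sm.build_session_query("Hello!", "user-123")
    #   # -> [{'role': 'system', 'content': <character_desc>},
    #   #     {'role': 'user', 'content': 'Hello!'}]
    #   # ... call the model with `messages`, then persist the answer:
    #   sm.save_session("Hi, how can I help?", "user-123", total_tokens=42)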


# refer to https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
def num_tokens_from_messages(messages, model):
    """Returns the number of tokens used by a list of messages."""
    import tiktoken
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        logger.debug("Warning: model not found. Using cl100k_base encoding.")
        encoding = tiktoken.get_encoding("cl100k_base")
    if model == "gpt-3.5-turbo":
        return num_tokens_from_messages(messages, model="gpt-3.5-turbo-0301")
    elif model == "gpt-4":
        return num_tokens_from_messages(messages, model="gpt-4-0314")
    elif model == "gpt-3.5-turbo-0301":
        tokens_per_message = 4  # every message follows <|start|>{role/name}\n{content}<|end|>\n
        tokens_per_name = -1  # if there's a name, the role is omitted
    elif model == "gpt-4-0314":
        tokens_per_message = 3
        tokens_per_name = 1
    else:
        logger.warn(f"num_tokens_from_messages() is not implemented for model {model}. Returning num tokens assuming gpt-3.5-turbo-0301.")
        return num_tokens_from_messages(messages, model="gpt-3.5-turbo-0301")
    num_tokens = 0
    for message in messages:
        num_tokens += tokens_per_message
        for key, value in message.items():
            num_tokens += len(encoding.encode(value))
            if key == "name":
                num_tokens += tokens_per_name
    num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>
    return num_tokens
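

# A worked sanity check for the gpt-3.5-turbo-0301 accounting above, assuming
# "user" and "Hello" each encode to a single cl100k_base token:
#
#   msgs = [{"role": "user", "content": "Hello"}]
#   num_tokens_from_messages(msgs, "gpt-3.5-turbo-0301")
#   # = 4 (per-message overhead) + 1 ("user") + 1 ("Hello") + 3 (reply priming)
#   # = 9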