|
@@ -66,12 +66,16 @@ class ChatGPTBot(Bot, OpenAIImage):
|
|
|
logger.debug("[CHATGPT] session query={}".format(session.messages))
|
|
logger.debug("[CHATGPT] session query={}".format(session.messages))
|
|
|
|
|
|
|
|
api_key = context.get("openai_api_key")
|
|
api_key = context.get("openai_api_key")
|
|
|
- self.args['model'] = context.get('gpt_model') or "gpt-3.5-turbo"
|
|
|
|
|
|
|
+ model = context.get("gpt_model")
|
|
|
|
|
+ new_args = None
|
|
|
|
|
+ if model:
|
|
|
|
|
+ new_args = self.args.copy()
|
|
|
|
|
+ new_args["model"] = model
|
|
|
# if context.get('stream'):
|
|
# if context.get('stream'):
|
|
|
# # reply in stream
|
|
# # reply in stream
|
|
|
# return self.reply_text_stream(query, new_query, session_id)
|
|
# return self.reply_text_stream(query, new_query, session_id)
|
|
|
|
|
|
|
|
- reply_content = self.reply_text(session, api_key)
|
|
|
|
|
|
|
+ reply_content = self.reply_text(session, api_key, args=new_args)
|
|
|
logger.debug(
|
|
logger.debug(
|
|
|
"[CHATGPT] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(
|
|
"[CHATGPT] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(
|
|
|
session.messages,
|
|
session.messages,
|
|
@@ -102,7 +106,7 @@ class ChatGPTBot(Bot, OpenAIImage):
|
|
|
reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type))
|
|
reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type))
|
|
|
return reply
|
|
return reply
|
|
|
|
|
|
|
|
- def reply_text(self, session: ChatGPTSession, api_key=None, retry_count=0) -> dict:
|
|
|
|
|
|
|
+ def reply_text(self, session: ChatGPTSession, api_key=None, args=None, retry_count=0) -> dict:
|
|
|
"""
|
|
"""
|
|
|
call openai's ChatCompletion to get the answer
|
|
call openai's ChatCompletion to get the answer
|
|
|
:param session: a conversation session
|
|
:param session: a conversation session
|
|
@@ -114,7 +118,9 @@ class ChatGPTBot(Bot, OpenAIImage):
|
|
|
if conf().get("rate_limit_chatgpt") and not self.tb4chatgpt.get_token():
|
|
if conf().get("rate_limit_chatgpt") and not self.tb4chatgpt.get_token():
|
|
|
raise openai.error.RateLimitError("RateLimitError: rate limit exceeded")
|
|
raise openai.error.RateLimitError("RateLimitError: rate limit exceeded")
|
|
|
# if api_key == None, the default openai.api_key will be used
|
|
# if api_key == None, the default openai.api_key will be used
|
|
|
- response = openai.ChatCompletion.create(api_key=api_key, messages=session.messages, **self.args)
|
|
|
|
|
|
|
+ if args is None:
|
|
|
|
|
+ args = self.args
|
|
|
|
|
+ response = openai.ChatCompletion.create(api_key=api_key, messages=session.messages, **args)
|
|
|
# logger.info("[ChatGPT] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"]))
|
|
# logger.info("[ChatGPT] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"]))
|
|
|
return {
|
|
return {
|
|
|
"total_tokens": response["usage"]["total_tokens"],
|
|
"total_tokens": response["usage"]["total_tokens"],
|
|
@@ -150,7 +156,7 @@ class ChatGPTBot(Bot, OpenAIImage):
|
|
|
|
|
|
|
|
if need_retry:
|
|
if need_retry:
|
|
|
logger.warn("[CHATGPT] 第{}次重试".format(retry_count + 1))
|
|
logger.warn("[CHATGPT] 第{}次重试".format(retry_count + 1))
|
|
|
- return self.reply_text(session, api_key, retry_count + 1)
|
|
|
|
|
|
|
+ return self.reply_text(session, api_key, args, retry_count + 1)
|
|
|
else:
|
|
else:
|
|
|
return result
|
|
return result
|
|
|
|
|
|