Kaynağa Gözat

fix: custom GPT model bug

lanvent 2 yıl önce
ebeveyn
işleme
b476085110
1 değiştirilmiş dosya ile 11 ekleme ve 5 silme
  1. 11 5
      bot/chatgpt/chat_gpt_bot.py

+ 11 - 5
bot/chatgpt/chat_gpt_bot.py

@@ -66,12 +66,16 @@ class ChatGPTBot(Bot, OpenAIImage):
             logger.debug("[CHATGPT] session query={}".format(session.messages))
 
             api_key = context.get("openai_api_key")
-            self.args['model'] = context.get('gpt_model') or "gpt-3.5-turbo"
+            model = context.get("gpt_model")
+            new_args = None
+            if model:
+                new_args = self.args.copy()
+                new_args["model"] = model
             # if context.get('stream'):
             #     # reply in stream
             #     return self.reply_text_stream(query, new_query, session_id)
 
-            reply_content = self.reply_text(session, api_key)
+            reply_content = self.reply_text(session, api_key, args=new_args)
             logger.debug(
                 "[CHATGPT] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(
                     session.messages,
@@ -102,7 +106,7 @@ class ChatGPTBot(Bot, OpenAIImage):
             reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type))
             return reply
 
-    def reply_text(self, session: ChatGPTSession, api_key=None, retry_count=0) -> dict:
+    def reply_text(self, session: ChatGPTSession, api_key=None, args=None, retry_count=0) -> dict:
         """
         call openai's ChatCompletion to get the answer
         :param session: a conversation session
@@ -114,7 +118,9 @@ class ChatGPTBot(Bot, OpenAIImage):
             if conf().get("rate_limit_chatgpt") and not self.tb4chatgpt.get_token():
                 raise openai.error.RateLimitError("RateLimitError: rate limit exceeded")
             # if api_key == None, the default openai.api_key will be used
-            response = openai.ChatCompletion.create(api_key=api_key, messages=session.messages, **self.args)
+            if args is None:
+                args = self.args
+            response = openai.ChatCompletion.create(api_key=api_key, messages=session.messages, **args)
             # logger.info("[ChatGPT] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"]))
             return {
                 "total_tokens": response["usage"]["total_tokens"],
@@ -150,7 +156,7 @@ class ChatGPTBot(Bot, OpenAIImage):
 
             if need_retry:
                 logger.warn("[CHATGPT] 第{}次重试".format(retry_count + 1))
-                return self.reply_text(session, api_key, retry_count + 1)
+                return self.reply_text(session, api_key, args, retry_count + 1)
             else:
                 return result