zwssunny пре 3 година
родитељ
комит
5de600c689
1 измењених фајлова са 22 додато и 20 уклоњено
  1. 22 20
      bot/chatgpt/chat_gpt_bot.py

+ 22 - 20
bot/chatgpt/chat_gpt_bot.py

@@ -40,21 +40,21 @@ class ChatGPTBot(Bot):
             #     return self.reply_text_stream(query, new_query, from_user_id)
 
             reply_content = self.reply_text(new_query, from_user_id, 0)
-            logger.debug("[OPEN_AI] new_query={}, user={}, reply_cont={}".format(new_query, from_user_id, reply_content))
-            if reply_content:
-                Session.save_session(query, reply_content, from_user_id)
-            return reply_content[1]
+            logger.debug("[OPEN_AI] new_query={}, user={}, reply_cont={}".format(new_query, from_user_id, reply_content["content"]))
+            if reply_content["completion_tokens"] > 0:
+                Session.save_session(reply_content["content"], from_user_id, reply_content["total_tokens"])
+            return reply_content["content"]
 
         elif context.get('type', None) == 'IMAGE_CREATE':
             return self.create_img(query, 0)
 
-    def reply_text(self, query, user_id, retry_count=0):
+    def reply_text(self, query, user_id, retry_count=0) ->dict:
         '''
         call openai's ChatCompletion to get the answer
         :param query: query content
         :param user_id: from user id
         :param retry_count: retry count
-        :return: [0]-tokens used and [1]-answer
+        :return: {}
         '''
         try:
             response = openai.ChatCompletion.create(
@@ -68,8 +68,9 @@ class ChatGPTBot(Bot):
             )
             # res_content = response.choices[0]['text'].strip().replace('<|endoftext|>', '')
             logger.info(response.choices[0]['message']['content'])
-
-            return response["usage"]["prompt_tokens"],response.choices[0]['message']['content']
+            return {"total_tokens": response["usage"]["total_tokens"], 
+                    "completion_tokens": response["usage"]["completion_tokens"], 
+                    "content": response.choices[0]['message']['content']}
         except openai.error.RateLimitError as e:
             # rate limit exception
             logger.warn(e)
@@ -78,21 +79,21 @@ class ChatGPTBot(Bot):
                 logger.warn("[OPEN_AI] RateLimit exceed, 第{}次重试".format(retry_count+1))
                 return self.reply_text(query, user_id, retry_count+1)
             else:
-                return 0,"提问太快啦,请休息一下再问我吧"
+                return {"completion_tokens": 0, "content": "提问太快啦,请休息一下再问我吧"}
         except openai.error.APIConnectionError as e:
             # api connection exception
             logger.warn(e)
             logger.warn("[OPEN_AI] APIConnection failed")
-            return 0,"我连接不到你的网络"
+            return {"completion_tokens": 0, "content":"我连接不到你的网络"}
         except openai.error.Timeout as e:
             logger.warn(e)
             logger.warn("[OPEN_AI] Timeout")
-            return 0,"我没有收到你的消息"
+            return {"completion_tokens": 0, "content":"我没有收到你的消息"}
         except Exception as e:
             # unknown exception
             logger.exception(e)
             Session.clear_session(user_id)
-            return 0,"请再问我一次吧"
+            return {"completion_tokens": 0, "content": "请再问我一次吧"}
 
     def create_img(self, query, retry_count=0):
         try:
@@ -143,7 +144,7 @@ class Session(object):
         return session
 
     @staticmethod
-    def save_session(query, answer, user_id):
+    def save_session(answer, user_id, total_tokens):
         max_tokens = conf().get("conversation_max_tokens")
         if not max_tokens:
             # default 3000
@@ -153,22 +154,23 @@ class Session(object):
         session = user_session.get(user_id)
         if session:
             # append conversation
-            gpt_item = {'role': 'assistant', 'content': answer[1]}
+            gpt_item = {'role': 'assistant', 'content': answer}
             session.append(gpt_item)
 
         # discard exceed limit conversation
-        used_tokens=int(answer[0])
-        # logger.info("prompt tokens used={},max_tokens={}".format(used_tokens,max_tokens))
+        Session.discard_exceed_conversation(session, max_tokens, total_tokens)
 
-        while used_tokens > max_tokens:
+    @staticmethod
+    def discard_exceed_conversation(session, max_tokens, total_tokens):
+        dec_tokens=int(total_tokens)
+        # logger.info("prompt tokens used={},max_tokens={}".format(used_tokens,max_tokens))
+        while dec_tokens > max_tokens:
             # pop first conversation
             if len(session) > 0:
                 session.pop(0) 
             else:
                 break    
-
-            used_tokens=used_tokens-max_tokens
-
+            dec_tokens=dec_tokens-max_tokens
 
     @staticmethod
     def clear_session(user_id):