Procházet zdrojové kódy

Merge pull request #354 from zwssunny/master

增加处理会话超长问题
zhayujie před 3 roky
rodič
revize
6a98bc2d5a
1 změnil soubory, kde provedl 31 přidání a 0 odebrání
  1. 31 0
      bot/chatgpt/chat_gpt_bot.py

+ 31 - 0
bot/chatgpt/chat_gpt_bot.py

@@ -6,6 +6,7 @@ from common.log import logger
 from common.expired_dict import ExpiredDict
 from common.expired_dict import ExpiredDict
 import openai
 import openai
 import time
 import time
+import json
 
 
 if conf().get('expires_in_seconds'):
 if conf().get('expires_in_seconds'):
     user_session = ExpiredDict(conf().get('expires_in_seconds'))
     user_session = ExpiredDict(conf().get('expires_in_seconds'))
@@ -28,6 +29,9 @@ class ChatGPTBot(Bot):
             if query == '#清除记忆':
             if query == '#清除记忆':
                 Session.clear_session(from_user_id)
                 Session.clear_session(from_user_id)
                 return '记忆已清除'
                 return '记忆已清除'
+            elif query == '#清除所有':
+                Session.clear_all_session()
+                return '所有人记忆已清除'            
 
 
             new_query = Session.build_session_query(query, from_user_id)
             new_query = Session.build_session_query(query, from_user_id)
             logger.debug("[OPEN_AI] session query={}".format(new_query))
             logger.debug("[OPEN_AI] session query={}".format(new_query))
@@ -134,13 +138,40 @@ class Session(object):
 
 
     @staticmethod
     @staticmethod
     def save_session(query, answer, user_id):
     def save_session(query, answer, user_id):
+        max_tokens = conf().get("conversation_max_tokens")
+        if not max_tokens:
+            # default 3000
+            max_tokens = 1000
+
         session = user_session.get(user_id)
         session = user_session.get(user_id)
         if session:
         if session:
             # append conversation
             # append conversation
             gpt_item = {'role': 'assistant', 'content': answer}
             gpt_item = {'role': 'assistant', 'content': answer}
             session.append(gpt_item)
             session.append(gpt_item)
 
 
+        # discard exceed limit conversation
+        Session.discard_exceed_conversation(user_session[user_id], max_tokens) 
+
+    @staticmethod
+    def discard_exceed_conversation(session, max_tokens):
+        count = 0
+        count_list = list()
+        for i in range(len(session)-1, -1, -1):
+        # count tokens of conversation list
+            history_conv = session[i]
+            tokens=json.dumps(history_conv).split()
+            count += len(tokens)
+            count_list.append(count)
+
+        for c in count_list:
+            if c > max_tokens:
+                # pop first conversation
+                session.pop(0)
+
     @staticmethod
     @staticmethod
     def clear_session(user_id):
     def clear_session(user_id):
         user_session[user_id] = []
         user_session[user_id] = []
 
 
+    @staticmethod
+    def clear_all_session():
+        user_session.clear()