|
|
@@ -26,6 +26,9 @@ class ChatGPTBot(Bot):
|
|
|
if query == '#清除记忆':
|
|
|
Session.clear_session(from_user_id)
|
|
|
return '记忆已清除'
|
|
|
+ elif query == '#清除所有':
|
|
|
+ Session.clear_all_session()
|
|
|
+ return '所有人记忆已清除'
|
|
|
|
|
|
new_query = Session.build_session_query(query, from_user_id)
|
|
|
logger.debug("[OPEN_AI] session query={}".format(new_query))
|
|
|
@@ -132,13 +135,40 @@ class Session(object):
|
|
|
|
|
|
@staticmethod
|
|
|
def save_session(query, answer, user_id):
|
|
|
+ max_tokens = conf().get("conversation_max_tokens")
|
|
|
+ if not max_tokens:
|
|
|
+ # default 3000
|
|
|
+ max_tokens = 1000
|
|
|
+
|
|
|
session = user_session.get(user_id)
|
|
|
if session:
|
|
|
# append conversation
|
|
|
gpt_item = {'role': 'assistant', 'content': answer}
|
|
|
session.append(gpt_item)
|
|
|
|
|
|
+ # discard exceed limit conversation
|
|
|
+ Session.discard_exceed_conversation(user_session[user_id], max_tokens)
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def discard_exceed_conversation(session, max_tokens):
|
|
|
+ count = 0
|
|
|
+ count_list = list()
|
|
|
+ for i in range(len(session)-1, -1, -1):
|
|
|
+ # count tokens of conversation list
|
|
|
+ history_conv = session[i]
|
|
|
+ tokens=history_conv.split()
|
|
|
+ count += len(tokens)
|
|
|
+ count_list.append(count)
|
|
|
+
|
|
|
+ for c in count_list:
|
|
|
+ if c > max_tokens:
|
|
|
+ # pop first conversation
|
|
|
+ session.pop(0)
|
|
|
+
|
|
|
@staticmethod
|
|
|
def clear_session(user_id):
|
|
|
user_session[user_id] = []
|
|
|
|
|
|
+ @staticmethod
|
|
|
+ def clear_all_session():
|
|
|
+ user_session.clear()
|