فهرست منبع

feat: reset will cancel unprocessed messages

lanvent 3 سال پیش
والد
کامیت
28eb67bc24
2فایلهای تغییر یافته به همراه21 افزوده شده و 4 حذف شده
  1. 18 4
      channel/chat_channel.py
  2. 3 0
      plugins/godcmd/godcmd.py

+ 18 - 4
channel/chat_channel.py

@@ -243,9 +243,9 @@ class ChatChannel(Channel):
         session_id = context['session_id']
         with self.lock:
             if session_id not in self.sessions:
-                self.sessions[session_id] = (Dequeue(), threading.BoundedSemaphore(conf().get("concurrency_in_session", 1)))
+                self.sessions[session_id] = [Dequeue(), threading.BoundedSemaphore(conf().get("concurrency_in_session", 1))]
             if context.type == ContextType.TEXT and context.content.startswith("#"): 
-                self.sessions[session_id][0].putleft(context) # 优先处理命令
+                self.sessions[session_id][0].putleft(context) # 优先处理管理命令
             else:
                 self.sessions[session_id][0].put(context)
 
@@ -273,12 +273,26 @@ class ChatChannel(Channel):
                             semaphore.release()
             time.sleep(0.1)
 
-    def cancel(self, session_id):
+    # 取消session_id对应的所有任务,只能取消排队的消息和已提交线程池但未执行的任务
+    def cancel_session(self, session_id): 
         with self.lock:
             if session_id in self.sessions:
                 for future in self.futures[session_id]:
                     future.cancel()
-                self.sessions[session_id][0]=Dequeue()
+                cnt = self.sessions[session_id][0].qsize()
+                if cnt>0:
+                    logger.info("Cancel {} messages in session {}".format(cnt, session_id))
+                self.sessions[session_id][0] = Dequeue()
+    
+    def cancel_all_session(self):
+        with self.lock:
+            for session_id in self.sessions:
+                for future in self.futures[session_id]:
+                    future.cancel()
+                cnt = self.sessions[session_id][0].qsize()
+                if cnt>0:
+                    logger.info("Cancel {} messages in session {}".format(cnt, session_id))
+                self.sessions[session_id][0] = Dequeue()
     
 
 def check_prefix(content, prefix_list):

+ 3 - 0
plugins/godcmd/godcmd.py

@@ -146,6 +146,7 @@ class Godcmd(Plugin):
         logger.debug("[Godcmd] on_handle_context. content: %s" % content)
         if content.startswith("#"):
             # msg = e_context['context']['msg']
+            channel = e_context['channel']
             user = e_context['context']['receiver']
             session_id = e_context['context']['session_id']
             isgroup = e_context['context']['isgroup']
@@ -181,6 +182,7 @@ class Godcmd(Plugin):
                 elif cmd == "reset":
                     if bottype in (const.CHATGPT, const.OPEN_AI):
                         bot.sessions.clear_session(session_id)
+                        channel.cancel_session(session_id)
                         ok, result = True, "会话已重置"
                     else:
                         ok, result = False, "当前对话机器人不支持重置会话"
@@ -202,6 +204,7 @@ class Godcmd(Plugin):
                             ok, result = True, "配置已重载"
                         elif cmd == "resetall":
                             if bottype in (const.CHATGPT, const.OPEN_AI):
+                                channel.cancel_all_session()
                                 bot.sessions.clear_all_session()
                                 ok, result = True, "重置所有会话成功"
                             else: