Explorar o código

feat: avoid disorder by producer-consumer model

lanvent %!s(int64=3) %!d(string=hai) anos
pai
achega
5a221848e9
Modificáronse 3 ficheiros con 81 adicións e 27 borrados
  1. 65 2
      channel/chat_channel.py
  2. 8 13
      channel/wechat/wechat_channel.py
  3. 8 12
      channel/wechat/wechaty_channel.py

+ 65 - 2
channel/chat_channel.py

@@ -1,9 +1,13 @@
 
 
-
+from asyncio import CancelledError
+import queue
+from concurrent.futures import Future, ThreadPoolExecutor
 import os
 import re
+import threading
 import time
+from channel.chat_message import ChatMessage
 from common.expired_dict import ExpiredDict
 from channel.channel import Channel
 from bridge.reply import *
@@ -20,8 +24,16 @@ except Exception as e:
 class ChatChannel(Channel):
     name = None # 登录的用户名
     user_id = None # 登录的用户id
+    futures = {} # 记录每个session_id提交到线程池的future对象, 用于重置会话时把没执行的future取消掉,正在执行的不会被取消
+    sessions = {} # 用于控制并发,每个session_id同时只能有一个context在处理
+    lock = threading.Lock() # 用于控制对sessions的访问
+    handler_pool = ThreadPoolExecutor(max_workers=8)  # 处理消息的线程池
+
     def __init__(self):
-        pass
+        _thread = threading.Thread(target=self.consume)
+        _thread.setDaemon(True)
+        _thread.start()
+        
 
     # 根据消息构造context,消息内容相关的触发项写在这里
     def _compose_context(self, ctype: ContextType, content, **kwargs):
@@ -215,6 +227,57 @@ class ChatChannel(Channel):
                 time.sleep(3+3*retry_cnt)
                 self._send(reply, context, retry_cnt+1)
 
+    def thread_pool_callback(self, session_id):
+        def func(worker:Future):
+            try:
+                worker_exception = worker.exception()
+                if worker_exception:
+                    logger.exception("Worker return exception: {}".format(worker_exception))
+            except CancelledError as e:
+                logger.info("Worker cancelled, session_id = {}".format(session_id))
+            except Exception as e:
+                logger.exception("Worker raise exception: {}".format(e))
+            with self.lock:
+                self.sessions[session_id][1].release()
+        return func
+
+    def produce(self, context: Context):
+        session_id = context['session_id']
+        with self.lock:
+            if session_id not in self.sessions:
+                self.sessions[session_id] = (queue.Queue(), threading.BoundedSemaphore(1))
+            self.sessions[session_id][0].put(context)
+
+    # 消费者函数,单独线程,用于从消息队列中取出消息并处理
+    def consume(self):
+        while True:
+            with self.lock:
+                session_ids = list(self.sessions.keys())
+                for session_id in session_ids:
+                    context_queue, semaphore = self.sessions[session_id]
+                    if semaphore.acquire(blocking = False): # 等线程处理完毕才能删除
+                        if not context_queue.empty():
+                            context = context_queue.get()
+                            logger.debug("[WX] consume context: {}".format(context))
+                            future:Future = self.handler_pool.submit(self._handle, context)
+                            future.add_done_callback(self.thread_pool_callback(session_id))
+                            if session_id not in self.futures:
+                                self.futures[session_id] = []
+                            self.futures[session_id].append(future)
+                        elif semaphore._initial_value == semaphore._value+1: # 除了当前,没有任务再申请到信号量,说明所有任务都处理完毕
+                            self.futures[session_id] = [t for t in self.futures[session_id] if not t.done()]
+                            assert len(self.futures[session_id]) == 0, "thread pool error"
+                            del self.sessions[session_id]
+                        else:
+                            semaphore.release()
+            time.sleep(0.1)
+
+    def cancel(self, session_id):
+        with self.lock:
+            if session_id in self.sessions:
+                for future in self.futures[session_id]:
+                    future.cancel()
+                self.sessions[session_id][0]=queue.Queue()
     
 
 def check_prefix(content, prefix_list):

+ 8 - 13
channel/wechat/wechat_channel.py

@@ -5,6 +5,7 @@ wechat channel
 """
 
 import os
+import threading
 import requests
 import io
 import time
@@ -17,18 +18,10 @@ from lib import itchat
 from lib.itchat.content import *
 from bridge.reply import *
 from bridge.context import *
-from concurrent.futures import ThreadPoolExecutor
 from config import conf
 from common.time_check import time_checker
 from common.expired_dict import ExpiredDict
 from plugins import *
-thread_pool = ThreadPoolExecutor(max_workers=8)
-
-def thread_pool_callback(worker):
-    worker_exception = worker.exception()
-    if worker_exception:
-        logger.exception("Worker return exception: {}".format(worker_exception))
-
 
 @itchat.msg_register(TEXT)
 def handler_single_msg(msg):
@@ -73,7 +66,9 @@ def qrCallback(uuid,status,qrcode):
         try:
             from PIL import Image
             img = Image.open(io.BytesIO(qrcode))
-            thread_pool.submit(img.show,"QRCode")
+            _thread = threading.Thread(target=img.show, args=("QRCode",))
+            _thread.setDaemon(True)
+            _thread.start()
         except Exception as e:
             pass
 
@@ -142,7 +137,7 @@ class WechatChannel(ChatChannel):
         logger.debug("[WX]receive voice msg: {}".format(cmsg.content))
         context = self._compose_context(ContextType.VOICE, cmsg.content, isgroup=False, msg=cmsg)
         if context:
-            thread_pool.submit(self._handle, context).add_done_callback(thread_pool_callback)
+            self.produce(context)
 
     @time_checker
     @_check
@@ -150,7 +145,7 @@ class WechatChannel(ChatChannel):
         logger.debug("[WX]receive text msg: {}, cmsg={}".format(json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg))
         context = self._compose_context(ContextType.TEXT, cmsg.content, isgroup=False, msg=cmsg)
         if context:
-            thread_pool.submit(self._handle, context).add_done_callback(thread_pool_callback)
+            self.produce(context)
 
     @time_checker
     @_check
@@ -158,7 +153,7 @@ class WechatChannel(ChatChannel):
         logger.debug("[WX]receive group msg: {}, cmsg={}".format(json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg))
         context = self._compose_context(ContextType.TEXT, cmsg.content, isgroup=True, msg=cmsg)
         if context:
-            thread_pool.submit(self._handle, context).add_done_callback(thread_pool_callback)
+            self.produce(context)
     
     @time_checker
     @_check
@@ -168,7 +163,7 @@ class WechatChannel(ChatChannel):
         logger.debug("[WX]receive voice for group msg: {}".format(cmsg.content))
         context = self._compose_context(ContextType.VOICE, cmsg.content, isgroup=True, msg=cmsg)
         if context:
-            thread_pool.submit(self._handle, context).add_done_callback(thread_pool_callback)
+            self.produce(context)
     
     # 统一的发送函数,每个Channel自行实现,根据reply的type字段发送不同类型的消息
     def send(self, reply: Reply, context: Context):

+ 8 - 12
channel/wechat/wechaty_channel.py

@@ -5,7 +5,6 @@ wechaty channel
 Python Wechaty - https://github.com/wechaty/python-wechaty
 """
 import base64
-from concurrent.futures import ThreadPoolExecutor
 import os
 import time
 import asyncio
@@ -18,21 +17,18 @@ from bridge.context import *
 from channel.chat_channel import ChatChannel
 from channel.wechat.wechaty_message import WechatyMessage
 from common.log import logger
+from common.singleton import singleton
 from config import conf
 try:
     from voice.audio_convert import any_to_sil
 except Exception as e:
     pass
 
-thread_pool = ThreadPoolExecutor(max_workers=8)
-def thread_pool_callback(worker):
-    worker_exception = worker.exception()
-    if worker_exception:
-        logger.exception("Worker return exception: {}".format(worker_exception))
+@singleton
 class WechatyChannel(ChatChannel):
 
     def __init__(self):
-        pass
+        super().__init__()
 
     def startup(self):
         config = conf()
@@ -41,6 +37,10 @@ class WechatyChannel(ChatChannel):
         asyncio.run(self.main())
 
     async def main(self):
+        
+        loop = asyncio.get_event_loop()
+        #将asyncio的loop传入处理线程
+        self.handler_pool._initializer= lambda: asyncio.set_event_loop(loop)
         self.bot = Wechaty()
         self.bot.on('login', self.on_login)
         self.bot.on('message', self.on_message)
@@ -122,8 +122,4 @@ class WechatyChannel(ChatChannel):
         context = self._compose_context(ctype, cmsg.content, isgroup=isgroup, msg=cmsg)
         if context:
             logger.info('[WX] receiveMsg={}, context={}'.format(cmsg, context))
-            thread_pool.submit(self._handle_loop, context, asyncio.get_event_loop()).add_done_callback(thread_pool_callback)
-
-    def _handle_loop(self,context,loop):
-        asyncio.set_event_loop(loop)
-        self._handle(context)
+            self.produce(context)