|
|
@@ -1,9 +1,13 @@
|
|
|
|
|
|
|
|
|
-
|
|
|
+from asyncio import CancelledError
|
|
|
+import queue
|
|
|
+from concurrent.futures import Future, ThreadPoolExecutor
|
|
|
import os
|
|
|
import re
|
|
|
+import threading
|
|
|
import time
|
|
|
+from channel.chat_message import ChatMessage
|
|
|
from common.expired_dict import ExpiredDict
|
|
|
from channel.channel import Channel
|
|
|
from bridge.reply import *
|
|
|
@@ -20,8 +24,16 @@ except Exception as e:
|
|
|
class ChatChannel(Channel):
|
|
|
name = None # 登录的用户名
|
|
|
user_id = None # 登录的用户id
|
|
|
+ futures = {} # 记录每个session_id提交到线程池的future对象, 用于重置会话时把没执行的future取消掉,正在执行的不会被取消
|
|
|
+ sessions = {} # 用于控制并发,每个session_id同时只能有一个context在处理
|
|
|
+ lock = threading.Lock() # 用于控制对sessions的访问
|
|
|
+ handler_pool = ThreadPoolExecutor(max_workers=8) # 处理消息的线程池
|
|
|
+
|
|
|
def __init__(self):
|
|
|
- pass
|
|
|
+ _thread = threading.Thread(target=self.consume)
|
|
|
+ _thread.setDaemon(True)
|
|
|
+ _thread.start()
|
|
|
+
|
|
|
|
|
|
# 根据消息构造context,消息内容相关的触发项写在这里
|
|
|
def _compose_context(self, ctype: ContextType, content, **kwargs):
|
|
|
@@ -215,6 +227,57 @@ class ChatChannel(Channel):
|
|
|
time.sleep(3+3*retry_cnt)
|
|
|
self._send(reply, context, retry_cnt+1)
|
|
|
|
|
|
+ def thread_pool_callback(self, session_id):
|
|
|
+ def func(worker:Future):
|
|
|
+ try:
|
|
|
+ worker_exception = worker.exception()
|
|
|
+ if worker_exception:
|
|
|
+ logger.exception("Worker return exception: {}".format(worker_exception))
|
|
|
+ except CancelledError as e:
|
|
|
+ logger.info("Worker cancelled, session_id = {}".format(session_id))
|
|
|
+ except Exception as e:
|
|
|
+ logger.exception("Worker raise exception: {}".format(e))
|
|
|
+ with self.lock:
|
|
|
+ self.sessions[session_id][1].release()
|
|
|
+ return func
|
|
|
+
|
|
|
+ def produce(self, context: Context):
|
|
|
+ session_id = context['session_id']
|
|
|
+ with self.lock:
|
|
|
+ if session_id not in self.sessions:
|
|
|
+ self.sessions[session_id] = (queue.Queue(), threading.BoundedSemaphore(1))
|
|
|
+ self.sessions[session_id][0].put(context)
|
|
|
+
|
|
|
+ # 消费者函数,单独线程,用于从消息队列中取出消息并处理
|
|
|
+ def consume(self):
|
|
|
+ while True:
|
|
|
+ with self.lock:
|
|
|
+ session_ids = list(self.sessions.keys())
|
|
|
+ for session_id in session_ids:
|
|
|
+ context_queue, semaphore = self.sessions[session_id]
|
|
|
+ if semaphore.acquire(blocking = False): # 等线程处理完毕才能删除
|
|
|
+ if not context_queue.empty():
|
|
|
+ context = context_queue.get()
|
|
|
+ logger.debug("[WX] consume context: {}".format(context))
|
|
|
+ future:Future = self.handler_pool.submit(self._handle, context)
|
|
|
+ future.add_done_callback(self.thread_pool_callback(session_id))
|
|
|
+ if session_id not in self.futures:
|
|
|
+ self.futures[session_id] = []
|
|
|
+ self.futures[session_id].append(future)
|
|
|
+ elif semaphore._initial_value == semaphore._value+1: # 除了当前,没有任务再申请到信号量,说明所有任务都处理完毕
|
|
|
+ self.futures[session_id] = [t for t in self.futures[session_id] if not t.done()]
|
|
|
+ assert len(self.futures[session_id]) == 0, "thread pool error"
|
|
|
+ del self.sessions[session_id]
|
|
|
+ else:
|
|
|
+ semaphore.release()
|
|
|
+ time.sleep(0.1)
|
|
|
+
|
|
|
+ def cancel(self, session_id):
|
|
|
+ with self.lock:
|
|
|
+ if session_id in self.sessions:
|
|
|
+ for future in self.futures[session_id]:
|
|
|
+ future.cancel()
|
|
|
+ self.sessions[session_id][0]=queue.Queue()
|
|
|
|
|
|
|
|
|
def check_prefix(content, prefix_list):
|