Эх сурвалжийг харах

feat: optimize consumer thread pool

zhayujie 2 жил өмнө
parent
commit
af5bc73dc0

+ 19 - 14
app.py

@@ -3,6 +3,7 @@
 import os
 import os
 import signal
 import signal
 import sys
 import sys
+import time
 
 
 from channel import channel_factory
 from channel import channel_factory
 from common import const
 from common import const
@@ -24,6 +25,21 @@ def sigterm_handler_wrap(_signo):
     signal.signal(_signo, func)
     signal.signal(_signo, func)
 
 
 
 
+def start_channel(channel_name: str):
+    channel = channel_factory.create_channel(channel_name)
+    if channel_name in ["wx", "wxy", "terminal", "wechatmp", "wechatmp_service", "wechatcom_app", "wework",
+                        const.FEISHU, const.DINGTALK]:
+        PluginManager().load_plugins()
+
+    if conf().get("use_linkai"):
+        try:
+            from common import linkai_client
+            threading.Thread(target=linkai_client.start, args=(channel,)).start()
+        except Exception as e:
+            pass
+    channel.startup()
+
+
 def run():
 def run():
     try:
     try:
         # load config
         # load config
@@ -41,22 +57,11 @@ def run():
 
 
         if channel_name == "wxy":
         if channel_name == "wxy":
             os.environ["WECHATY_LOG"] = "warn"
             os.environ["WECHATY_LOG"] = "warn"
-            # os.environ['WECHATY_PUPPET_SERVICE_ENDPOINT'] = '127.0.0.1:9001'
-
-        channel = channel_factory.create_channel(channel_name)
-        if channel_name in ["wx", "wxy", "terminal", "wechatmp", "wechatmp_service", "wechatcom_app", "wework", const.FEISHU,const.DINGTALK]:
-            PluginManager().load_plugins()
-
-        if conf().get("use_linkai"):
-            try:
-                from common import linkai_client
-                threading.Thread(target=linkai_client.start, args=(channel, )).start()
-            except Exception as e:
-                pass
 
 
-        # startup channel
-        channel.startup()
+        start_channel(channel_name)
 
 
+        while True:
+            time.sleep(1)
     except Exception as e:
     except Exception as e:
         logger.error("App startup failed!")
         logger.error("App startup failed!")
         logger.exception(e)
         logger.exception(e)

+ 1 - 1
bot/linkai/link_ai_bot.py

@@ -400,7 +400,7 @@ class LinkAIBot(Bot):
                 i += 1
                 i += 1
                 if url.endswith(".mp4"):
                 if url.endswith(".mp4"):
                     reply_type = ReplyType.VIDEO_URL
                     reply_type = ReplyType.VIDEO_URL
-                elif url.endswith(".pdf") or url.endswith(".doc") or url.endswith(".docx"):
+                elif url.endswith(".pdf") or url.endswith(".doc") or url.endswith(".docx") or url.endswith(".csv"):
                     reply_type = ReplyType.FILE
                     reply_type = ReplyType.FILE
                     url = _download_file(url)
                     url = _download_file(url)
                     if not url:
                     if not url:

+ 4 - 2
channel/chat_channel.py

@@ -4,6 +4,7 @@ import threading
 import time
 import time
 from asyncio import CancelledError
 from asyncio import CancelledError
 from concurrent.futures import Future, ThreadPoolExecutor
 from concurrent.futures import Future, ThreadPoolExecutor
+from concurrent import futures
 
 
 from bridge.context import *
 from bridge.context import *
 from bridge.reply import *
 from bridge.reply import *
@@ -17,6 +18,8 @@ try:
 except Exception as e:
 except Exception as e:
     pass
     pass
 
 
+handler_pool = ThreadPoolExecutor(max_workers=8)  # 处理消息的线程池
+
 
 
 # 抽象类, 它包含了与消息通道无关的通用处理逻辑
 # 抽象类, 它包含了与消息通道无关的通用处理逻辑
 class ChatChannel(Channel):
 class ChatChannel(Channel):
@@ -25,7 +28,6 @@ class ChatChannel(Channel):
     futures = {}  # 记录每个session_id提交到线程池的future对象, 用于重置会话时把没执行的future取消掉,正在执行的不会被取消
     futures = {}  # 记录每个session_id提交到线程池的future对象, 用于重置会话时把没执行的future取消掉,正在执行的不会被取消
     sessions = {}  # 用于控制并发,每个session_id同时只能有一个context在处理
     sessions = {}  # 用于控制并发,每个session_id同时只能有一个context在处理
     lock = threading.Lock()  # 用于控制对sessions的访问
     lock = threading.Lock()  # 用于控制对sessions的访问
-    handler_pool = ThreadPoolExecutor(max_workers=8)  # 处理消息的线程池
 
 
     def __init__(self):
     def __init__(self):
         _thread = threading.Thread(target=self.consume)
         _thread = threading.Thread(target=self.consume)
@@ -339,7 +341,7 @@ class ChatChannel(Channel):
                         if not context_queue.empty():
                         if not context_queue.empty():
                             context = context_queue.get()
                             context = context_queue.get()
                             logger.debug("[WX] consume context: {}".format(context))
                             logger.debug("[WX] consume context: {}".format(context))
-                            future: Future = self.handler_pool.submit(self._handle, context)
+                            future: Future = handler_pool.submit(self._handle, context)
                             future.add_done_callback(self._thread_pool_callback(session_id, context=context))
                             future.add_done_callback(self._thread_pool_callback(session_id, context=context))
                             if session_id not in self.futures:
                             if session_id not in self.futures:
                                 self.futures[session_id] = []
                                 self.futures[session_id] = []

+ 32 - 24
channel/wechat/wechat_channel.py

@@ -15,6 +15,7 @@ import requests
 from bridge.context import *
 from bridge.context import *
 from bridge.reply import *
 from bridge.reply import *
 from channel.chat_channel import ChatChannel
 from channel.chat_channel import ChatChannel
+from channel import chat_channel
 from channel.wechat.wechat_message import *
 from channel.wechat.wechat_message import *
 from common.expired_dict import ExpiredDict
 from common.expired_dict import ExpiredDict
 from common.log import logger
 from common.log import logger
@@ -112,30 +113,39 @@ class WechatChannel(ChatChannel):
         self.auto_login_times = 0
         self.auto_login_times = 0
 
 
     def startup(self):
     def startup(self):
-        itchat.instance.receivingRetryCount = 600  # 修改断线超时时间
-        # login by scan QRCode
-        hotReload = conf().get("hot_reload", False)
-        status_path = os.path.join(get_appdata_dir(), "itchat.pkl")
-        itchat.auto_login(
-            enableCmdQR=2,
-            hotReload=hotReload,
-            statusStorageDir=status_path,
-            qrCallback=qrCallback,
-            exitCallback=self.exitCallback,
-            loginCallback=self.loginCallback
-        )
-        self.user_id = itchat.instance.storageClass.userName
-        self.name = itchat.instance.storageClass.nickName
-        logger.info("Wechat login success, user_id: {}, nickname: {}".format(self.user_id, self.name))
-        # start message listener
-        itchat.run()
+        try:
+            itchat.instance.receivingRetryCount = 600  # 修改断线超时时间
+            # login by scan QRCode
+            hotReload = conf().get("hot_reload", False)
+            status_path = os.path.join(get_appdata_dir(), "itchat.pkl")
+            itchat.auto_login(
+                enableCmdQR=2,
+                hotReload=hotReload,
+                statusStorageDir=status_path,
+                qrCallback=qrCallback,
+                exitCallback=self.exitCallback,
+                loginCallback=self.loginCallback
+            )
+            self.user_id = itchat.instance.storageClass.userName
+            self.name = itchat.instance.storageClass.nickName
+            logger.info("Wechat login success, user_id: {}, nickname: {}".format(self.user_id, self.name))
+            # start message listener
+            itchat.run()
+        except Exception as e:
+            logger.error(e)
 
 
     def exitCallback(self):
     def exitCallback(self):
-        _send_logout()
-        time.sleep(3)
-        self.auto_login_times += 1
-        if self.auto_login_times < 100:
-            self.startup()
+        try:
+            from common.linkai_client import chat_client
+            if chat_client.client_id and conf().get("use_linkai"):
+                _send_logout()
+                time.sleep(2)
+                self.auto_login_times += 1
+                if self.auto_login_times < 100:
+                    chat_channel.handler_pool._shutdown = False
+                    self.startup()
+        except Exception as e:
+            pass
 
 
     def loginCallback(self):
     def loginCallback(self):
         logger.debug("Login success")
         logger.debug("Login success")
@@ -259,7 +269,6 @@ def _send_login_success():
 def _send_logout():
 def _send_logout():
     try:
     try:
         from common.linkai_client import chat_client
         from common.linkai_client import chat_client
-        time.sleep(2)
         if chat_client.client_id:
         if chat_client.client_id:
             chat_client.send_logout()
             chat_client.send_logout()
     except Exception as e:
     except Exception as e:
@@ -268,7 +277,6 @@ def _send_logout():
 def _send_qr_code(qrcode_list: list):
 def _send_qr_code(qrcode_list: list):
     try:
     try:
         from common.linkai_client import chat_client
         from common.linkai_client import chat_client
-        time.sleep(2)
         if chat_client.client_id:
         if chat_client.client_id:
             chat_client.send_qrcode(qrcode_list)
             chat_client.send_qrcode(qrcode_list)
     except Exception as e:
     except Exception as e:

+ 26 - 1
common/linkai_client.py

@@ -2,7 +2,9 @@ from bridge.context import Context, ContextType
 from bridge.reply import Reply, ReplyType
 from bridge.reply import Reply, ReplyType
 from common.log import logger
 from common.log import logger
 from linkai import LinkAIClient, PushMsg
 from linkai import LinkAIClient, PushMsg
-from config import conf
+from config import conf, pconf, plugin_config
+from plugins import PluginManager
+
 
 
 chat_client: LinkAIClient
 chat_client: LinkAIClient
 
 
@@ -22,6 +24,29 @@ class ChatClient(LinkAIClient):
         context["isgroup"] = push_msg.is_group
         context["isgroup"] = push_msg.is_group
         self.channel.send(Reply(ReplyType.TEXT, content=msg_content), context)
         self.channel.send(Reply(ReplyType.TEXT, content=msg_content), context)
 
 
+    def on_config(self, config: dict):
+        if not self.client_id:
+            return
+        logger.info(f"从控制台加载配置: {config}")
+        local_config = conf()
+        for key in local_config.keys():
+            if config.get(key) is not None:
+                local_config[key] = config.get(key)
+        if config.get("reply_voice_mode"):
+            if config.get("reply_voice_mode") == "voice_reply_voice":
+                local_config["voice_reply_voice"] = True
+            elif config.get("reply_voice_mode") == "always_reply_voice":
+                local_config["always_reply_voice"] = True
+        # if config.get("admin_password") and plugin_config["Godcmd"]:
+        #     plugin_config["Godcmd"]["password"] = config.get("admin_password")
+        #     PluginManager().instances["Godcmd"].reload()
+        # if config.get("group_app_map") and pconf("linkai"):
+        #     local_group_map = {}
+        #     for mapping in config.get("group_app_map"):
+        #         local_group_map[mapping.get("group_name")] = mapping.get("app_code")
+        #     pconf("linkai")["group_app_map"] = local_group_map
+        #     PluginManager().instances["linkai"].reload()
+
 
 
 def start(channel):
 def start(channel):
     global chat_client
     global chat_client

+ 8 - 0
plugins/godcmd/godcmd.py

@@ -475,3 +475,11 @@ class Godcmd(Plugin):
         if model == "gpt-4-turbo":
         if model == "gpt-4-turbo":
             return const.GPT4_TURBO_PREVIEW
             return const.GPT4_TURBO_PREVIEW
         return model
         return model
+
+    def reload(self):
+        gconf = plugin_config[self.name]
+        if gconf:
+            if gconf.get("password"):
+                self.password = gconf["password"]
+            if gconf.get("admin_users"):
+                self.admin_users = gconf["admin_users"]

+ 3 - 0
plugins/plugin.py

@@ -46,3 +46,6 @@ class Plugin:
 
 
     def get_help_text(self, **kwargs):
     def get_help_text(self, **kwargs):
         return "暂无帮助信息"
         return "暂无帮助信息"
+
+    def reload(self):
+        pass