Procházet zdrojové kódy

feat: support plugins

lanvent před 3 roky
rodič
revize
0fcf0824dc

+ 1 - 0
.gitignore

@@ -7,3 +7,4 @@ config.json
 QR.png
 nohup.out
 tmp
+plugins.json

+ 5 - 2
app.py

@@ -4,14 +4,17 @@ import config
 from channel import channel_factory
 from common.log import logger
 
-
+from plugins import *
 if __name__ == '__main__':
     try:
         # load config
         config.load_config()
 
         # create channel
-        channel = channel_factory.create_channel("wx")
+        channel_name='wx'
+        channel = channel_factory.create_channel(channel_name)
+        if channel_name=='wx':
+            PluginManager().load_plugins()
 
         # startup channel
         channel.startup()

+ 7 - 2
bot/chatgpt/chat_gpt_bot.py

@@ -60,12 +60,13 @@ class ChatGPTBot(Bot):
             ok, retstring = self.create_img(query, 0)
             reply = None
             if ok:
-                reply = {'type': 'IMAGE', 'content': retstring}
+                reply = {'type': 'IMAGE_URL', 'content': retstring}
             else:
                 reply = {'type': 'ERROR', 'content': retstring}
             return reply
         else:
             reply= {'type':'ERROR', 'content':'Bot不支持处理{}类型的消息'.format(context['type'])}
+            return reply
 
     def reply_text(self, session, session_id, retry_count=0) -> dict:
         '''
@@ -139,7 +140,11 @@ class ChatGPTBot(Bot):
 
 class SessionManager(object):
     def __init__(self):
-        self.sessions = {}
+        if conf().get('expires_in_seconds'):
+            sessions = ExpiredDict(conf().get('expires_in_seconds'))
+        else:
+            sessions = dict()
+        self.sessions = sessions
 
     def build_session_query(self, query, session_id):
         '''

+ 57 - 46
channel/wechat/wechat_channel.py

@@ -12,9 +12,12 @@ from concurrent.futures import ThreadPoolExecutor
 from common.log import logger
 from common.tmp_dir import TmpDir
 from config import conf
+from plugins import *
+
 import requests
 import io
 
+
 thread_pool = ThreadPoolExecutor(max_workers=8)
 
 
@@ -49,8 +52,8 @@ class WechatChannel(Channel):
 
     # handle_* 系列函数处理收到的消息后构造context,然后调用handle函数处理context
     # context是一个字典,包含了消息的所有信息,包括以下key
-    #   type: 消息类型,包括TEXT、VOICE、CMD_IMAGE_CREATE
-    #   content: 消息内容,如果是TEXT类型,content就是文本内容,如果是VOICE类型,content就是语音文件名,如果是CMD_IMAGE_CREATE类型,content就是图片生成命令
+    #   type: 消息类型,包括TEXT、VOICE、IMAGE_CREATE
+    #   content: 消息内容,如果是TEXT类型,content就是文本内容,如果是VOICE类型,content就是语音文件名,如果是IMAGE_CREATE类型,content就是图片生成命令
     #   session_id: 会话id
     #   isgroup: 是否是群聊
     #   msg: 原始消息对象
@@ -88,7 +91,7 @@ class WechatChannel(Channel):
         img_match_prefix = check_prefix(content, conf().get('image_create_prefix'))
         if img_match_prefix:
             content = content.replace(img_match_prefix, '', 1).strip()
-            context['type'] = 'CMD_IMAGE_CREATE'
+            context['type'] = 'IMAGE_CREATE'
         else:
             context['type'] = 'TEXT'
 
@@ -121,7 +124,7 @@ class WechatChannel(Channel):
             img_match_prefix = check_prefix(content, conf().get('image_create_prefix'))
             if img_match_prefix:
                 content = content.replace(img_match_prefix, '', 1).strip()
-                context['type'] = 'CMD_IMAGE_CREATE'
+                context['type'] = 'IMAGE_CREATE'
             else:
                 context['type'] = 'TEXT'
             context['content'] = content
@@ -136,8 +139,7 @@ class WechatChannel(Channel):
 
             thread_pool.submit(self.handle, context)
 
-    # 统一的发送函数,根据reply的type字段发送不同类型的消息
-
+    # 统一的发送函数,每个Channel自行实现,根据reply的type字段发送不同类型的消息
     def send(self, reply, receiver):
         if reply['type'] == 'TEXT':
             itchat.send(reply['content'], toUserName=receiver)
@@ -163,54 +165,63 @@ class WechatChannel(Channel):
             itchat.send_image(image_storage, toUserName=receiver)
             logger.info('[WX] sendImage, receiver={}'.format(receiver))
 
-    # 处理消息
+    # 处理消息 TODO: 如果wechaty解耦,此处逻辑可以放置到父类
     def handle(self, context):
-        content = context['content']
-        reply = None
+        reply = {}
 
         logger.debug('[WX] ready to handle context: {}'.format(context))
+        
         # reply的构建步骤
-        if context['type'] == 'TEXT' or context['type'] == 'CMD_IMAGE_CREATE':
-            reply = super().build_reply_content(content, context)
-        elif context['type'] == 'VOICE':
-            msg = context['msg']
-            file_name = TmpDir().path() + msg['FileName']
-            msg.download(file_name)
-            reply = super().build_voice_to_text(file_name)
-            if reply['type'] != 'ERROR' and reply['type'] != 'INFO':
-                reply = super().build_reply_content(reply['content'], context)
-                if reply['type'] == 'TEXT':
-                    if conf().get('voice_reply_voice'):
-                        reply = super().build_text_to_voice(reply['content'])
-        else:
-            logger.error('[WX] unknown context type: {}'.format(context['type']))
-            return
+        e_context = PluginManager().emit_event(EventContext(Event.ON_HANDLE_CONTEXT, {'channel' : self, 'context': context, 'reply': reply}))
+        reply=e_context['reply']
+        if not e_context.is_pass():
+            logger.debug('[WX] ready to handle context: type={}, content={}'.format(context['type'], context['content']))
+            if context['type'] == 'TEXT' or context['type'] == 'IMAGE_CREATE':
+                reply = super().build_reply_content(context['content'], context)
+            elif context['type'] == 'VOICE':
+                msg = context['msg']
+                file_name = TmpDir().path() + msg['FileName']
+                msg.download(file_name)
+                reply = super().build_voice_to_text(file_name)
+                if reply['type'] != 'ERROR' and reply['type'] != 'INFO':
+                    reply = super().build_reply_content(reply['content'], context)
+                    if reply['type'] == 'TEXT':
+                        if conf().get('voice_reply_voice'):
+                            reply = super().build_text_to_voice(reply['content'])
+            else:
+                logger.error('[WX] unknown context type: {}'.format(context['type']))
+                return
 
         logger.debug('[WX] ready to decorate reply: {}'.format(reply))
+        
         # reply的包装步骤
-        if reply:
-            if reply['type'] == 'TEXT':
-                reply_text = reply['content']
-                if context['isgroup']:
-                    reply_text = '@' + \
-                        context['msg']['ActualNickName'] + \
-                        ' ' + reply_text.strip()
-                    reply_text = conf().get("group_chat_reply_prefix", "")+reply_text
+        if reply and reply['type']:
+            e_context = PluginManager().emit_event(EventContext(Event.ON_DECORATE_REPLY, {'channel' : self, 'context': context, 'reply': reply}))
+            reply=e_context['reply']
+            if not e_context.is_pass() and reply and reply['type']:
+                if reply['type'] == 'TEXT':
+                    reply_text = reply['content']
+                    if context['isgroup']:
+                        reply_text = '@' +  context['msg']['ActualNickName'] + ' ' + reply_text.strip()
+                        reply_text = conf().get("group_chat_reply_prefix", "")+reply_text
+                    else:
+                        reply_text = conf().get("single_chat_reply_prefix", "")+reply_text
+                    reply['content'] = reply_text
+                elif reply['type'] == 'ERROR' or reply['type'] == 'INFO':
+                    reply['content'] = reply['type']+": " + reply['content']
+                elif reply['type'] == 'IMAGE_URL' or reply['type'] == 'VOICE' or reply['type'] == 'IMAGE':
+                    pass
                 else:
-                    reply_text = conf().get("single_chat_reply_prefix", "")+reply_text
-                reply['content'] = reply_text
-            elif reply['type'] == 'ERROR' or reply['type'] == 'INFO':
-                reply['content'] = reply['type']+": " + reply['content']
-            elif reply['type'] == 'IMAGE_URL' or reply['type'] == 'VOICE':
-                pass
-            else:
-                logger.error(
-                    '[WX] unknown reply type: {}'.format(reply['type']))
-                return
-        if reply:
-            logger.debug('[WX] ready to send reply: {} to {}'.format(
-                reply, context['receiver']))
-            self.send(reply, context['receiver'])
+                    logger.error('[WX] unknown reply type: {}'.format(reply['type']))
+                    return
+
+        # reply的发送步骤   
+        if reply and reply['type']:
+            e_context = PluginManager().emit_event(EventContext(Event.ON_SEND_REPLY, {'channel' : self, 'context': context, 'reply': reply}))
+            reply=e_context['reply']
+            if not e_context.is_pass() and reply and reply['type']:
+                logger.debug('[WX] ready to send reply: {} to {}'.format(reply, context['receiver']))
+                self.send(reply, context['receiver'])
 
 
 def check_prefix(content, prefix_list):

+ 9 - 0
plugins/__init__.py

@@ -0,0 +1,9 @@
+from .plugin_manager import PluginManager
+from .event import *
+from .plugin import *
+
+instance = PluginManager()
+
+register                    = instance.register
+# load_plugins                = instance.load_plugins
+# emit_event                  = instance.emit_event

+ 49 - 0
plugins/event.py

@@ -0,0 +1,49 @@
+# encoding:utf-8
+
+from enum import Enum
+
+
+class Event(Enum):
+    # ON_RECEIVE_MESSAGE = 1  # 收到消息
+
+    ON_HANDLE_CONTEXT = 2   # 处理消息前
+    """
+    e_context = {  "channel": 消息channel, "context" : 本次消息的context, "reply" : 目前的回复,初始为空  }
+    """
+
+    ON_DECORATE_REPLY = 3   # 得到回复后准备装饰
+    """
+    e_context = {  "channel": 消息channel, "context" : 本次消息的context, "reply" : 目前的回复 }
+    """
+
+    ON_SEND_REPLY = 4       # 发送回复前
+    """
+    e_context = {  "channel": 消息channel, "context" : 本次消息的context, "reply" : 目前的回复 }
+    """
+
+    # AFTER_SEND_REPLY = 5    # 发送回复后
+
+
+class EventAction(Enum):
+    CONTINUE = 1            # 事件未结束,继续交给下个插件处理,如果没有下个插件,则交付给默认的事件处理逻辑
+    BREAK = 2               # 事件结束,不再给下个插件处理,交付给默认的事件处理逻辑
+    BREAK_PASS = 3          # 事件结束,不再给下个插件处理,不交付给默认的事件处理逻辑
+
+
+class EventContext:
+    def __init__(self, event, econtext=dict()):
+        self.event = event
+        self.econtext = econtext
+        self.action = EventAction.CONTINUE
+
+    def __getitem__(self, key):
+        return self.econtext[key]
+
+    def __setitem__(self, key, value):
+        self.econtext[key] = value
+
+    def __delitem__(self, key):
+        del self.econtext[key]
+
+    def is_pass(self):
+        return self.action == EventAction.BREAK_PASS

+ 3 - 0
plugins/plugin.py

@@ -0,0 +1,3 @@
+class Plugin:
+    def __init__(self):
+        self.handlers = {}

+ 89 - 0
plugins/plugin_manager.py

@@ -0,0 +1,89 @@
+# encoding:utf-8
+
+import importlib
+import json
+import os
+from common.singleton import singleton
+from .event import *
+from .plugin import *
+from common.log import logger
+
+
+@singleton
+class PluginManager:
+    def __init__(self):
+        self.plugins = {}
+        self.listening_plugins = {}
+        self.instances = {}
+
+    def register(self, name: str, desc: str, version: str, author: str):
+        def wrapper(plugincls):
+            self.plugins[name] = plugincls
+            plugincls.name = name
+            plugincls.desc = desc
+            plugincls.version = version
+            plugincls.author = author
+            plugincls.enabled = True
+            logger.info("Plugin %s registered" % name)
+            return plugincls
+        return wrapper
+
+    def save_config(self, pconf):
+        with open("plugins/plugins.json", "w", encoding="utf-8") as f:
+            json.dump(pconf, f, indent=4, ensure_ascii=False)
+
+    def load_config(self):
+        logger.info("Loading plugins config...")
+        plugins_dir = "plugins"
+        for plugin_name in os.listdir(plugins_dir):
+            plugin_path = os.path.join(plugins_dir, plugin_name)
+            if os.path.isdir(plugin_path):
+                # 判断插件是否包含main.py文件
+                main_module_path = os.path.join(plugin_path, "main.py")
+                if os.path.isfile(main_module_path):
+                    # 导入插件的main
+                    import_path = "{}.{}.main".format(plugins_dir, plugin_name)
+                    main_module = importlib.import_module(import_path)
+
+        modified = False
+        if os.path.exists("plugins/plugins.json"):
+            with open("plugins/plugins.json", "r", encoding="utf-8") as f:
+                pconf = json.load(f)
+        else:
+            modified = True
+            pconf = {"plugins": []}
+        for name, plugincls in self.plugins.items():
+            if name not in [plugin["name"] for plugin in pconf["plugins"]]:
+                modified = True
+                logger.info("Plugin %s not found in pconfig, adding to pconfig..." % name)
+                pconf["plugins"].append({"name": name, "enabled": True})
+        if modified:
+            self.save_config(pconf)
+        return pconf
+
+    def load_plugins(self):
+        pconf = self.load_config()
+
+        for plugin in pconf["plugins"]:
+            name = plugin["name"]
+            enabled = plugin["enabled"]
+            self.plugins[name].enabled = enabled
+
+        for name, plugincls in self.plugins.items():
+            if plugincls.enabled:
+                if name not in self.instances:
+                    instance = plugincls()
+                    self.instances[name] = instance
+                    for event in instance.handlers:
+                        if event not in self.listening_plugins:
+                            self.listening_plugins[event] = []
+                        self.listening_plugins[event].append(name)
+
+    def emit_event(self, e_context: EventContext, *args, **kwargs):
+        if e_context.event in self.listening_plugins:
+            for name in self.listening_plugins[e_context.event]:
+                if e_context.action == EventAction.CONTINUE:
+                    logger.debug("Plugin %s triggered by event %s" % (name,e_context.event))
+                    instance = self.instances[name]
+                    instance.handlers[e_context.event](e_context, *args, **kwargs)
+        return e_context