Selaa lähdekoodia

Merge pull request #442 from lanvent/dev

简易支持插件,添加sdwebui(novelai画图), godcmd(管理员指令增强)插件,Banwords(敏感词过滤)插件
zhayujie 3 vuotta sitten
vanhempi
säilyke
2cb30b5f59

+ 1 - 0
.gitignore

@@ -7,3 +7,4 @@ config.json
 QR.png
 nohup.out
 tmp
+plugins.json

+ 5 - 2
app.py

@@ -4,14 +4,17 @@ import config
 from channel import channel_factory
 from common.log import logger
 
-
+from plugins import *
 if __name__ == '__main__':
     try:
         # load config
         config.load_config()
 
         # create channel
-        channel = channel_factory.create_channel("wx")
+        channel_name='wx'
+        channel = channel_factory.create_channel(channel_name)
+        if channel_name=='wx':
+            PluginManager().load_plugins()
 
         # startup channel
         channel.startup()

+ 3 - 1
bot/baidu/baidu_unit_bot.py

@@ -2,6 +2,7 @@
 
 import requests
 from bot.bot import Bot
+from bridge.reply import Reply, ReplyType
 
 
 # Baidu Unit对话接口 (可用, 但能力较弱)
@@ -14,7 +15,8 @@ class BaiduUnitBot(Bot):
         headers = {'content-type': 'application/x-www-form-urlencoded'}
         response = requests.post(url, data=post_data.encode(), headers=headers)
         if response:
-            return response.json()['result']['context']['SYS_PRESUMED_HIST'][1]
+            reply = Reply(ReplyType.TEXT, response.json()['result']['context']['SYS_PRESUMED_HIST'][1])
+            return reply
 
     def get_token(self):
         access_key = 'YOUR_ACCESS_KEY'

+ 5 - 1
bot/bot.py

@@ -3,8 +3,12 @@ Auto-replay chat robot abstract class
 """
 
 
+from bridge.context import Context
+from bridge.reply import Reply
+
+
 class Bot(object):
-    def reply(self, query, context=None):
+    def reply(self, query, context : Context =None) -> Reply:
         """
         bot auto-reply content
         :param req: received message

+ 66 - 48
bot/chatgpt/chat_gpt_bot.py

@@ -1,41 +1,42 @@
 # encoding:utf-8
 
 from bot.bot import Bot
+from bridge.context import ContextType
+from bridge.reply import Reply, ReplyType
 from config import conf, load_config
 from common.log import logger
 from common.expired_dict import ExpiredDict
 import openai
 import time
 
-if conf().get('expires_in_seconds'):
-    all_sessions = ExpiredDict(conf().get('expires_in_seconds'))
-else:
-    all_sessions = dict()
 
 # OpenAI对话模型API (可用)
 class ChatGPTBot(Bot):
     def __init__(self):
         openai.api_key = conf().get('open_ai_api_key')
         proxy = conf().get('proxy')
+        self.sessions = SessionManager()
         if proxy:
             openai.proxy = proxy
 
     def reply(self, query, context=None):
         # acquire reply content
-        if not context or not context.get('type') or context.get('type') == 'TEXT':
+        if context.type == ContextType.TEXT:
             logger.info("[OPEN_AI] query={}".format(query))
-            session_id = context.get('session_id') or context.get('from_user_id')
+            session_id = context['session_id']
+            reply = None
             if query == '#清除记忆':
-                Session.clear_session(session_id)
-                return '记忆已清除'
+                self.sessions.clear_session(session_id)
+                reply = Reply(ReplyType.INFO, '记忆已清除')
             elif query == '#清除所有':
-                Session.clear_all_session()
-                return '所有人记忆已清除'
+                self.sessions.clear_all_session()
+                reply = Reply(ReplyType.INFO, '所有人记忆已清除')
             elif query == '#更新配置':
                 load_config()
-                return '配置已更新'
-
-            session = Session.build_session_query(query, session_id)
+                reply = Reply(ReplyType.INFO, '配置已更新')
+            if reply:
+                return reply
+            session = self.sessions.build_session_query(query, session_id)
             logger.debug("[OPEN_AI] session query={}".format(session))
 
             # if context.get('stream'):
@@ -44,14 +45,29 @@ class ChatGPTBot(Bot):
 
             reply_content = self.reply_text(session, session_id, 0)
             logger.debug("[OPEN_AI] new_query={}, session_id={}, reply_cont={}".format(session, session_id, reply_content["content"]))
-            if reply_content["completion_tokens"] > 0:
-                Session.save_session(reply_content["content"], session_id, reply_content["total_tokens"])
-            return reply_content["content"]
-
-        elif context.get('type', None) == 'IMAGE_CREATE':
-            return self.create_img(query, 0)
+            if reply_content['completion_tokens'] == 0 and len(reply_content['content']) > 0:
+                reply = Reply(ReplyType.ERROR, reply_content['content'])
+            elif reply_content["completion_tokens"] > 0:
+                self.sessions.save_session(reply_content["content"], session_id, reply_content["total_tokens"])
+                reply = Reply(ReplyType.TEXT, reply_content["content"])
+            else:
+                reply = Reply(ReplyType.ERROR, reply_content['content'])
+                logger.debug("[OPEN_AI] reply {} used 0 tokens.".format(reply_content))
+            return reply
+
+        elif context.type == ContextType.IMAGE_CREATE:
+            ok, retstring = self.create_img(query, 0)
+            reply = None
+            if ok:
+                reply = Reply(ReplyType.IMAGE_URL, retstring)
+            else:
+                reply = Reply(ReplyType.ERROR, retstring)
+            return reply
+        else:
+            reply = Reply(ReplyType.ERROR, 'Bot不支持处理{}类型的消息'.format(context.type))
+            return reply
 
-    def reply_text(self, session, session_id, retry_count=0) ->dict:
+    def reply_text(self, session, session_id, retry_count=0) -> dict:
         '''
         call openai's ChatCompletion to get the answer
         :param session: a conversation session
@@ -70,8 +86,8 @@ class ChatGPTBot(Bot):
                 presence_penalty=0.0,  # [-2,2]之间,该值越大则更倾向于产生不同的内容
             )
             # logger.info("[ChatGPT] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"]))
-            return {"total_tokens": response["usage"]["total_tokens"], 
-                    "completion_tokens": response["usage"]["completion_tokens"], 
+            return {"total_tokens": response["usage"]["total_tokens"],
+                    "completion_tokens": response["usage"]["completion_tokens"],
                     "content": response.choices[0]['message']['content']}
         except openai.error.RateLimitError as e:
             # rate limit exception
@@ -86,15 +102,15 @@ class ChatGPTBot(Bot):
             # api connection exception
             logger.warn(e)
             logger.warn("[OPEN_AI] APIConnection failed")
-            return {"completion_tokens": 0, "content":"我连接不到你的网络"}
+            return {"completion_tokens": 0, "content": "我连接不到你的网络"}
         except openai.error.Timeout as e:
             logger.warn(e)
             logger.warn("[OPEN_AI] Timeout")
-            return {"completion_tokens": 0, "content":"我没有收到你的消息"}
+            return {"completion_tokens": 0, "content": "我没有收到你的消息"}
         except Exception as e:
             # unknown exception
             logger.exception(e)
-            Session.clear_session(session_id)
+            self.sessions.clear_session(session_id)
             return {"completion_tokens": 0, "content": "请再问我一次吧"}
 
     def create_img(self, query, retry_count=0):
@@ -107,7 +123,7 @@ class ChatGPTBot(Bot):
             )
             image_url = response['data'][0]['url']
             logger.info("[OPEN_AI] image_url={}".format(image_url))
-            return image_url
+            return True, image_url
         except openai.error.RateLimitError as e:
             logger.warn(e)
             if retry_count < 1:
@@ -115,14 +131,21 @@ class ChatGPTBot(Bot):
                 logger.warn("[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(retry_count+1))
                 return self.create_img(query, retry_count+1)
             else:
-                return "提问太快啦,请休息一下再问我吧"
+                return False, "提问太快啦,请休息一下再问我吧"
         except Exception as e:
             logger.exception(e)
-            return None
+            return False, str(e)
+
+
+class SessionManager(object):
+    def __init__(self):
+        if conf().get('expires_in_seconds'):
+            sessions = ExpiredDict(conf().get('expires_in_seconds'))
+        else:
+            sessions = dict()
+        self.sessions = sessions
 
-class Session(object):
-    @staticmethod
-    def build_session_query(query, session_id):
+    def build_session_query(self, query, session_id):
         '''
         build query with conversation history
         e.g.  [
@@ -135,36 +158,33 @@ class Session(object):
         :param session_id: session id
         :return: query content with conversaction
         '''
-        session = all_sessions.get(session_id, [])
+        session = self.sessions.get(session_id, [])
         if len(session) == 0:
             system_prompt = conf().get("character_desc", "")
             system_item = {'role': 'system', 'content': system_prompt}
             session.append(system_item)
-            all_sessions[session_id] = session
+            self.sessions[session_id] = session
         user_item = {'role': 'user', 'content': query}
         session.append(user_item)
         return session
 
-    @staticmethod
-    def save_session(answer, session_id, total_tokens):
+    def save_session(self, answer, session_id, total_tokens):
         max_tokens = conf().get("conversation_max_tokens")
         if not max_tokens:
             # default 3000
             max_tokens = 1000
-        max_tokens=int(max_tokens)
+        max_tokens = int(max_tokens)
 
-        session = all_sessions.get(session_id)
+        session = self.sessions.get(session_id)
         if session:
             # append conversation
             gpt_item = {'role': 'assistant', 'content': answer}
             session.append(gpt_item)
 
         # discard exceed limit conversation
-        Session.discard_exceed_conversation(session, max_tokens, total_tokens)
-    
+        self.discard_exceed_conversation(session, max_tokens, total_tokens)
 
-    @staticmethod
-    def discard_exceed_conversation(session, max_tokens, total_tokens):
+    def discard_exceed_conversation(self, session, max_tokens, total_tokens):
         dec_tokens = int(total_tokens)
         # logger.info("prompt tokens used={},max_tokens={}".format(used_tokens,max_tokens))
         while dec_tokens > max_tokens:
@@ -173,13 +193,11 @@ class Session(object):
                 session.pop(1)
                 session.pop(1)
             else:
-                break    
+                break
             dec_tokens = dec_tokens - max_tokens
 
-    @staticmethod
-    def clear_session(session_id):
-        all_sessions[session_id] = []
+    def clear_session(self, session_id):
+        self.sessions[session_id] = []
 
-    @staticmethod
-    def clear_all_session():
-        all_sessions.clear()
+    def clear_all_session(self):
+        self.sessions.clear()

+ 25 - 22
bot/openai/open_ai_bot.py

@@ -1,6 +1,8 @@
 # encoding:utf-8
 
 from bot.bot import Bot
+from bridge.context import ContextType
+from bridge.reply import Reply, ReplyType
 from config import conf
 from common.log import logger
 import openai
@@ -13,30 +15,31 @@ class OpenAIBot(Bot):
     def __init__(self):
         openai.api_key = conf().get('open_ai_api_key')
 
-
     def reply(self, query, context=None):
         # acquire reply content
-        if not context or not context.get('type') or context.get('type') == 'TEXT':
-            logger.info("[OPEN_AI] query={}".format(query))
-            from_user_id = context.get('from_user_id') or context.get('session_id')
-            if query == '#清除记忆':
-                Session.clear_session(from_user_id)
-                return '记忆已清除'
-            elif query == '#清除所有':
-                Session.clear_all_session()
-                return '所有人记忆已清除'
-
-            new_query = Session.build_session_query(query, from_user_id)
-            logger.debug("[OPEN_AI] session query={}".format(new_query))
-
-            reply_content = self.reply_text(new_query, from_user_id, 0)
-            logger.debug("[OPEN_AI] new_query={}, user={}, reply_cont={}".format(new_query, from_user_id, reply_content))
-            if reply_content and query:
-                Session.save_session(query, reply_content, from_user_id)
-            return reply_content
-
-        elif context.get('type', None) == 'IMAGE_CREATE':
-            return self.create_img(query, 0)
+        if context and context.type:
+            if context.type == ContextType.TEXT:
+                logger.info("[OPEN_AI] query={}".format(query))
+                from_user_id = context['session_id']
+                reply = None
+                if query == '#清除记忆':
+                    Session.clear_session(from_user_id)
+                    reply = Reply(ReplyType.INFO, '记忆已清除')
+                elif query == '#清除所有':
+                    Session.clear_all_session()
+                    reply = Reply(ReplyType.INFO, '所有人记忆已清除')
+                else:
+                    new_query = Session.build_session_query(query, from_user_id)
+                    logger.debug("[OPEN_AI] session query={}".format(new_query))
+
+                    reply_content = self.reply_text(new_query, from_user_id, 0)
+                    logger.debug("[OPEN_AI] new_query={}, user={}, reply_cont={}".format(new_query, from_user_id, reply_content))
+                    if reply_content and query:
+                        Session.save_session(query, reply_content, from_user_id)
+                    reply = Reply(ReplyType.TEXT, reply_content)
+                return reply
+            elif context.type == ContextType.IMAGE_CREATE:
+                return self.create_img(query, 0)
 
     def reply_text(self, query, user_id, retry_count=0):
         try:

+ 33 - 7
bridge/bridge.py

@@ -1,16 +1,42 @@
+from bridge.context import Context
+from bridge.reply import Reply
+from common.log import logger
 from bot import bot_factory
+from common.singleton import singleton
 from voice import voice_factory
 
 
+@singleton
 class Bridge(object):
     def __init__(self):
-        pass
+        self.btype={
+            "chat": "chatGPT",
+            "voice_to_text": "openai",
+            "text_to_voice": "baidu"
+        }
+        self.bots={}
 
-    def fetch_reply_content(self, query, context):
-        return bot_factory.create_bot("chatGPT").reply(query, context)
+    def get_bot(self,typename):
+        if self.bots.get(typename) is None:
+            logger.info("create bot {} for {}".format(self.btype[typename],typename))
+            if typename == "text_to_voice":
+                self.bots[typename] = voice_factory.create_voice(self.btype[typename])
+            elif typename == "voice_to_text":
+                self.bots[typename] = voice_factory.create_voice(self.btype[typename])
+            elif typename == "chat":
+                self.bots[typename] = bot_factory.create_bot(self.btype[typename])
+        return self.bots[typename]
+    
+    def get_bot_type(self,typename):
+        return self.btype[typename]
 
-    def fetch_voice_to_text(self, voiceFile):
-        return voice_factory.create_voice("openai").voiceToText(voiceFile)
 
-    def fetch_text_to_voice(self, text):
-        return voice_factory.create_voice("baidu").textToVoice(text)
+    def fetch_reply_content(self, query, context : Context) -> Reply:
+        return self.get_bot("chat").reply(query, context)
+
+    def fetch_voice_to_text(self, voiceFile) -> Reply:
+        return self.get_bot("voice_to_text").voiceToText(voiceFile)
+
+    def fetch_text_to_voice(self, text) -> Reply:
+        return self.get_bot("text_to_voice").textToVoice(text)
+

+ 42 - 0
bridge/context.py

@@ -0,0 +1,42 @@
+# encoding:utf-8
+
+from enum import Enum
+
+class ContextType (Enum):
+    TEXT = 1         # 文本消息
+    VOICE = 2        # 音频消息
+    IMAGE_CREATE = 3 # 创建图片命令
+    
+    def __str__(self):
+        return self.name
+class Context:
+    def __init__(self, type : ContextType = None , content = None,  kwargs = dict()):
+        self.type = type
+        self.content = content
+        self.kwargs = kwargs
+    def __getitem__(self, key):
+        if key == 'type':
+            return self.type
+        elif key == 'content':
+            return self.content
+        else:
+            return self.kwargs[key]
+
+    def __setitem__(self, key, value):
+        if key == 'type':
+            self.type = value
+        elif key == 'content':
+            self.content = value
+        else:
+            self.kwargs[key] = value
+
+    def __delitem__(self, key):
+        if key == 'type':
+            self.type = None
+        elif key == 'content':
+            self.content = None
+        else:
+            del self.kwargs[key]
+    
+    def __str__(self):
+        return "Context(type={}, content={}, kwargs={})".format(self.type, self.content, self.kwargs)

+ 22 - 0
bridge/reply.py

@@ -0,0 +1,22 @@
+
+# encoding:utf-8
+
+from enum import Enum
+
+class ReplyType(Enum):
+    TEXT = 1        # 文本
+    VOICE = 2       # 音频文件
+    IMAGE = 3       # 图片文件
+    IMAGE_URL = 4   # 图片URL
+    
+    INFO = 9
+    ERROR = 10
+    def __str__(self):
+        return self.name
+
+class Reply:
+    def __init__(self, type : ReplyType = None , content = None):
+        self.type = type
+        self.content = content
+    def __str__(self):
+        return "Reply(type={}, content={})".format(self.type, self.content)

+ 5 - 3
channel/channel.py

@@ -3,6 +3,8 @@ Message sending channel abstract class
 """
 
 from bridge.bridge import Bridge
+from bridge.context import Context
+from bridge.reply import Reply
 
 class Channel(object):
     def startup(self):
@@ -27,11 +29,11 @@ class Channel(object):
         """
         raise NotImplementedError
 
-    def build_reply_content(self, query, context=None):
+    def build_reply_content(self, query, context : Context=None) -> Reply:
         return Bridge().fetch_reply_content(query, context)
 
-    def build_voice_to_text(self, voice_file):
+    def build_voice_to_text(self, voice_file) -> Reply:
         return Bridge().fetch_voice_to_text(voice_file)
     
-    def build_text_to_voice(self, text):
+    def build_text_to_voice(self, text) -> Reply:
         return Bridge().fetch_text_to_voice(text)

+ 152 - 126
channel/wechat/wechat_channel.py

@@ -7,16 +7,24 @@ wechat channel
 import itchat
 import json
 from itchat.content import *
+from bridge.reply import *
+from bridge.context import *
 from channel.channel import Channel
 from concurrent.futures import ThreadPoolExecutor
 from common.log import logger
 from common.tmp_dir import TmpDir
 from config import conf
+from plugins import *
+
 import requests
 import io
 
-thread_pool = ThreadPoolExecutor(max_workers=8)
 
+thread_pool = ThreadPoolExecutor(max_workers=8)
+def thread_pool_callback(worker):
+    worker_exception = worker.exception()
+    if worker_exception:
+        logger.exception("Worker return exception: {}".format(worker_exception))
 
 @itchat.msg_register(TEXT)
 def handler_single_msg(msg):
@@ -47,62 +55,52 @@ class WechatChannel(Channel):
         # start message listener
         itchat.run()
 
+    # handle_* 系列函数处理收到的消息后构造context,然后调用handle函数处理context
+    # context是一个字典,包含了消息的所有信息,包括以下key
+    #   type: 消息类型,包括TEXT、VOICE、IMAGE_CREATE
+    #   content: 消息内容,如果是TEXT类型,content就是文本内容,如果是VOICE类型,content就是语音文件名,如果是IMAGE_CREATE类型,content就是图片生成命令
+    #   session_id: 会话id
+    #   isgroup: 是否是群聊
+    #   msg: 原始消息对象
+    #   receiver: 需要回复的对象
+
     def handle_voice(self, msg):
-        if conf().get('speech_recognition') != True :
+        if conf().get('speech_recognition') != True:
             return
         logger.debug("[WX]receive voice msg: " + msg['FileName'])
-        thread_pool.submit(self._do_handle_voice, msg)
-
-    def _do_handle_voice(self, msg):
         from_user_id = msg['FromUserName']
         other_user_id = msg['User']['UserName']
         if from_user_id == other_user_id:
-            file_name = TmpDir().path() + msg['FileName']
-            msg.download(file_name)
-            query = super().build_voice_to_text(file_name)
-            if conf().get('voice_reply_voice'):
-                self._do_send_voice(query, from_user_id)
-            else:
-                self._do_send_text(query, from_user_id)
+            context = Context(ContextType.VOICE,msg['FileName'])
+            context.kwargs = {'isgroup': False, 'msg': msg, 'receiver': other_user_id, 'session_id': other_user_id}
+            thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback)
 
     def handle_text(self, msg):
         logger.debug("[WX]receive text msg: " + json.dumps(msg, ensure_ascii=False))
         content = msg['Text']
-        self._handle_single_msg(msg, content)
-
-    def _handle_single_msg(self, msg, content):
         from_user_id = msg['FromUserName']
         to_user_id = msg['ToUserName']              # 接收人id
         other_user_id = msg['User']['UserName']     # 对手方id
-        match_prefix = self.check_prefix(content, conf().get('single_chat_prefix'))
+        match_prefix = check_prefix(content, conf().get('single_chat_prefix'))
         if "」\n- - - - - - - - - - - - - - -" in content:
             logger.debug("[WX]reference query skipped")
             return
-        if from_user_id == other_user_id and match_prefix is not None:
-            # 好友向自己发送消息
-            if match_prefix != '':
-                str_list = content.split(match_prefix, 1)
-                if len(str_list) == 2:
-                    content = str_list[1].strip()
-
-            img_match_prefix = self.check_prefix(content, conf().get('image_create_prefix'))
-            if img_match_prefix:
-                content = content.split(img_match_prefix, 1)[1].strip()
-                thread_pool.submit(self._do_send_img, content, from_user_id)
-            else :
-                thread_pool.submit(self._do_send_text, content, from_user_id)
-        elif to_user_id == other_user_id and match_prefix:
-            # 自己给好友发送消息
-            str_list = content.split(match_prefix, 1)
-            if len(str_list) == 2:
-                content = str_list[1].strip()
-            img_match_prefix = self.check_prefix(content, conf().get('image_create_prefix'))
-            if img_match_prefix:
-                content = content.split(img_match_prefix, 1)[1].strip()
-                thread_pool.submit(self._do_send_img, content, to_user_id)
-            else:
-                thread_pool.submit(self._do_send_text, content, to_user_id)
+        if match_prefix:
+            content = content.replace(match_prefix, '', 1).strip()
+        else:
+            return
+        context = Context()
+        context.kwargs = {'isgroup': False, 'msg': msg, 'receiver': other_user_id, 'session_id': other_user_id}
+
+        img_match_prefix = check_prefix(content, conf().get('image_create_prefix'))
+        if img_match_prefix:
+            content = content.replace(img_match_prefix, '', 1).strip()
+            context.type = ContextType.IMAGE_CREATE
+        else:
+            context.type = ContextType.TEXT
 
+        context.content = content
+        thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback)
 
     def handle_group(self, msg):
         logger.debug("[WX]receive group msg: " + json.dumps(msg, ensure_ascii=False))
@@ -122,100 +120,128 @@ class WechatChannel(Channel):
             logger.debug("[WX]reference query skipped")
             return ""
         config = conf()
-        match_prefix = (msg['IsAt'] and not config.get("group_at_off", False)) or self.check_prefix(origin_content, config.get('group_chat_prefix')) \
-                       or self.check_contain(origin_content, config.get('group_chat_keyword'))
-        if ('ALL_GROUP' in config.get('group_name_white_list') or group_name in config.get('group_name_white_list') or self.check_contain(group_name, config.get('group_name_keyword_white_list'))) and match_prefix:
-            img_match_prefix = self.check_prefix(content, conf().get('image_create_prefix'))
+        match_prefix = (msg['IsAt'] and not config.get("group_at_off", False)) or check_prefix(origin_content, config.get('group_chat_prefix')) \
+                       or check_contain(origin_content, config.get('group_chat_keyword'))
+        if ('ALL_GROUP' in config.get('group_name_white_list') or group_name in config.get('group_name_white_list') or check_contain(group_name, config.get('group_name_keyword_white_list'))) and match_prefix:
+            context = Context()
+            context.kwargs = { 'isgroup': True, 'msg': msg, 'receiver': group_id}
+            
+            img_match_prefix = check_prefix(content, conf().get('image_create_prefix'))
             if img_match_prefix:
-                content = content.split(img_match_prefix, 1)[1].strip()
-                thread_pool.submit(self._do_send_img, content, group_id)
+                content = content.replace(img_match_prefix, '', 1).strip()
+                context.type = ContextType.IMAGE_CREATE
             else:
-                thread_pool.submit(self._do_send_group, content, msg)
-
-    def send(self, msg, receiver):
-        itchat.send(msg, toUserName=receiver)
-        logger.info('[WX] sendMsg={}, receiver={}'.format(msg, receiver))
-
-    def _do_send_voice(self, query, reply_user_id):
-        try:
-            if not query:
-                return
-            context = dict()
-            context['from_user_id'] = reply_user_id
-            reply_text = super().build_reply_content(query, context)
-            if reply_text:
-                replyFile = super().build_text_to_voice(reply_text)
-                itchat.send_file(replyFile, toUserName=reply_user_id)
-                logger.info('[WX] sendFile={}, receiver={}'.format(replyFile, reply_user_id))
-        except Exception as e:
-            logger.exception(e)
-
-    def _do_send_text(self, query, reply_user_id):
-        try:
-            if not query:
-                return
-            context = dict()
-            context['session_id'] = reply_user_id
-            reply_text = super().build_reply_content(query, context)
-            if reply_text:
-                self.send(conf().get("single_chat_reply_prefix") + reply_text, reply_user_id)
-        except Exception as e:
-            logger.exception(e)
-
-    def _do_send_img(self, query, reply_user_id):
-        try:
-            if not query:
-                return
-            context = dict()
-            context['type'] = 'IMAGE_CREATE'
-            img_url = super().build_reply_content(query, context)
-            if not img_url:
-                return
-
-            # 图片下载
+                context.type = ContextType.TEXT
+            context.content = content
+
+            group_chat_in_one_session = conf().get('group_chat_in_one_session', [])
+            if ('ALL_GROUP' in group_chat_in_one_session or
+                    group_name in group_chat_in_one_session or
+                    check_contain(group_name, group_chat_in_one_session)):
+                context['session_id'] = group_id
+            else:
+                context['session_id'] = msg['ActualUserName']
+
+            thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback)
+
+    # 统一的发送函数,每个Channel自行实现,根据reply的type字段发送不同类型的消息
+    def send(self, reply : Reply, receiver):
+        if reply.type == ReplyType.TEXT:
+            itchat.send(reply.content, toUserName=receiver)
+            logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver))
+        elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
+            itchat.send(reply.content, toUserName=receiver)
+            logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver))
+        elif reply.type == ReplyType.VOICE:
+            itchat.send_file(reply.content, toUserName=receiver)
+            logger.info('[WX] sendFile={}, receiver={}'.format(reply.content, receiver))
+        elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
+            img_url = reply.content
             pic_res = requests.get(img_url, stream=True)
             image_storage = io.BytesIO()
             for block in pic_res.iter_content(1024):
                 image_storage.write(block)
             image_storage.seek(0)
+            itchat.send_image(image_storage, toUserName=receiver)
+            logger.info('[WX] sendImage url=, receiver={}'.format(img_url,receiver))
+        elif reply.type == ReplyType.IMAGE: # 从文件读取图片
+            image_storage = reply.content
+            image_storage.seek(0)
+            itchat.send_image(image_storage, toUserName=receiver)
+            logger.info('[WX] sendImage, receiver={}'.format(receiver))
+
+    # 处理消息 TODO: 如果wechaty解耦,此处逻辑可以放置到父类
+    def handle(self, context):
+        reply = Reply()
+
+        logger.debug('[WX] ready to handle context: {}'.format(context))
+        
+        # reply的构建步骤
+        e_context = PluginManager().emit_event(EventContext(Event.ON_HANDLE_CONTEXT, {'channel' : self, 'context': context, 'reply': reply}))
+        reply = e_context['reply']
+        if not e_context.is_pass():
+            logger.debug('[WX] ready to handle context: type={}, content={}'.format(context.type, context.content))
+            if context.type == ContextType.TEXT or context.type == ContextType.IMAGE_CREATE:
+                reply = super().build_reply_content(context.content, context)
+            elif context.type == ContextType.VOICE:
+                msg = context['msg']
+                file_name = TmpDir().path() + context.content
+                msg.download(file_name)
+                reply = super().build_voice_to_text(file_name)
+                if reply.type != ReplyType.ERROR and reply.type != ReplyType.INFO:
+                    context.content = reply.content # 语音转文字后,将文字内容作为新的context
+                    context.type = ContextType.TEXT
+                    reply = super().build_reply_content(context.content, context)
+                    if reply.type == ReplyType.TEXT:
+                        if conf().get('voice_reply_voice'):
+                            reply = super().build_text_to_voice(reply.content)
+            else:
+                logger.error('[WX] unknown context type: {}'.format(context.type))
+                return
 
-            # 图片发送
-            itchat.send_image(image_storage, reply_user_id)
-            logger.info('[WX] sendImage, receiver={}'.format(reply_user_id))
-        except Exception as e:
-            logger.exception(e)
-
-    def _do_send_group(self, query, msg):
-        if not query:
-            return
-        context = dict()
-        group_name = msg['User']['NickName']
-        group_id = msg['User']['UserName']
-        group_chat_in_one_session = conf().get('group_chat_in_one_session', [])
-        if ('ALL_GROUP' in group_chat_in_one_session or \
-                group_name in group_chat_in_one_session or \
-                self.check_contain(group_name, group_chat_in_one_session)):
-            context['session_id'] = group_id
-        else:
-            context['session_id'] = msg['ActualUserName']
-        reply_text = super().build_reply_content(query, context)
-        if reply_text:
-            reply_text = '@' + msg['ActualNickName'] + ' ' + reply_text.strip()
-            self.send(conf().get("group_chat_reply_prefix", "") + reply_text, group_id)
-
-
-    def check_prefix(self, content, prefix_list):
-        for prefix in prefix_list:
-            if content.startswith(prefix):
-                return prefix
-        return None
+        logger.debug('[WX] ready to decorate reply: {}'.format(reply))
+        
+        # reply的包装步骤
+        if reply and reply.type:
+            e_context = PluginManager().emit_event(EventContext(Event.ON_DECORATE_REPLY, {'channel' : self, 'context': context, 'reply': reply}))
+            reply=e_context['reply']
+            if not e_context.is_pass() and reply and reply.type:
+                if reply.type == ReplyType.TEXT:
+                    reply_text = reply.content
+                    if context['isgroup']:
+                        reply_text = '@' +  context['msg']['ActualNickName'] + ' ' + reply_text.strip()
+                        reply_text = conf().get("group_chat_reply_prefix", "")+reply_text
+                    else:
+                        reply_text = conf().get("single_chat_reply_prefix", "")+reply_text
+                    reply.content = reply_text
+                elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
+                    reply.content = str(reply.type)+":\n" + reply.content
+                elif reply.type == ReplyType.IMAGE_URL or reply.type == ReplyType.VOICE or reply.type == ReplyType.IMAGE:
+                    pass
+                else:
+                    logger.error('[WX] unknown reply type: {}'.format(reply.type))
+                    return
+
+        # reply的发送步骤   
+        if reply and reply.type:
+            e_context = PluginManager().emit_event(EventContext(Event.ON_SEND_REPLY, {'channel' : self, 'context': context, 'reply': reply}))
+            reply=e_context['reply']
+            if not e_context.is_pass() and reply and reply.type:
+                logger.debug('[WX] ready to send reply: {} to {}'.format(reply, context['receiver']))
+                self.send(reply, context['receiver'])
+
+
+def check_prefix(content, prefix_list):
+    for prefix in prefix_list:
+        if content.startswith(prefix):
+            return prefix
+    return None
 
 
-    def check_contain(self, content, keyword_list):
-        if not keyword_list:
-            return None
-        for ky in keyword_list:
-            if content.find(ky) != -1:
-                return True
+def check_contain(content, keyword_list):
+    if not keyword_list:
         return None
-
+    for ky in keyword_list:
+        if content.find(ky) != -1:
+            return True
+    return None

+ 9 - 10
channel/wechat/wechaty_channel.py

@@ -11,6 +11,7 @@ import time
 import asyncio
 import requests
 from typing import Optional, Union
+from bridge.context import Context, ContextType
 from wechaty_puppet import MessageType, FileBox, ScanStatus  # type: ignore
 from wechaty import Wechaty, Contact
 from wechaty.user import Message, Room, MiniProgram, UrlLink
@@ -127,9 +128,9 @@ class WechatyChannel(Channel):
         try:
             if not query:
                 return
-            context = dict()
+            context = Context(ContextType.TEXT, query)
             context['session_id'] = reply_user_id
-            reply_text = super().build_reply_content(query, context)
+            reply_text = super().build_reply_content(query, context).content
             if reply_text:
                 await self.send(conf().get("single_chat_reply_prefix") + reply_text, reply_user_id)
         except Exception as e:
@@ -139,9 +140,8 @@ class WechatyChannel(Channel):
         try:
             if not query:
                 return
-            context = dict()
-            context['type'] = 'IMAGE_CREATE'
-            img_url = super().build_reply_content(query, context)
+            context = Context(ContextType.IMAGE_CREATE, query)
+            img_url = super().build_reply_content(query, context).content
             if not img_url:
                 return
             # 图片下载
@@ -162,7 +162,7 @@ class WechatyChannel(Channel):
     async def _do_send_group(self, query, group_id, group_name, group_user_id, group_user_name):
         if not query:
             return
-        context = dict()
+        context = Context(ContextType.TEXT, query)
         group_chat_in_one_session = conf().get('group_chat_in_one_session', [])
         if ('ALL_GROUP' in group_chat_in_one_session or \
                 group_name in group_chat_in_one_session or \
@@ -170,7 +170,7 @@ class WechatyChannel(Channel):
             context['session_id'] = str(group_id)
         else:
             context['session_id'] = str(group_id) + '-' + str(group_user_id)
-        reply_text = super().build_reply_content(query, context)
+        reply_text = super().build_reply_content(query, context).content
         if reply_text:
             reply_text = '@' + group_user_name + ' ' + reply_text.strip()
             await self.send_group(conf().get("group_chat_reply_prefix", "") + reply_text, group_id)
@@ -179,9 +179,8 @@ class WechatyChannel(Channel):
         try:
             if not query:
                 return
-            context = dict()
-            context['type'] = 'IMAGE_CREATE'
-            img_url = super().build_reply_content(query, context)
+            context = Context(ContextType.IMAGE_CREATE, query)
+            img_url = super().build_reply_content(query, context).content
             if not img_url:
                 return
             # 图片发送

+ 9 - 0
common/singleton.py

@@ -0,0 +1,9 @@
+def singleton(cls):
+    instances = {}
+
+    def get_instance(*args, **kwargs):
+        if cls not in instances:
+            instances[cls] = cls(*args, **kwargs)
+        return instances[cls]
+
+    return get_instance

+ 65 - 0
common/sorted_dict.py

@@ -0,0 +1,65 @@
+import heapq
+
+
+class SortedDict(dict):
+    def __init__(self, sort_func=lambda k, v: k, init_dict=None, reverse=False):
+        if init_dict is None:
+            init_dict = []
+        if isinstance(init_dict, dict):
+            init_dict = init_dict.items()
+        self.sort_func = sort_func
+        self.sorted_keys = None
+        self.reverse = reverse
+        self.heap = []
+        for k, v in init_dict:
+            self[k] = v
+
+    def __setitem__(self, key, value):
+        if key in self:
+            super().__setitem__(key, value)
+            for i, (priority, k) in enumerate(self.heap):
+                if k == key:
+                    self.heap[i] = (self.sort_func(key, value), key)
+                    heapq.heapify(self.heap)
+                    break
+            self.sorted_keys = None
+        else:
+            super().__setitem__(key, value)
+            heapq.heappush(self.heap, (self.sort_func(key, value), key))
+            self.sorted_keys = None
+
+    def __delitem__(self, key):
+        super().__delitem__(key)
+        for i, (priority, k) in enumerate(self.heap):
+            if k == key:
+                del self.heap[i]
+                heapq.heapify(self.heap)
+                break
+        self.sorted_keys = None
+
+    def keys(self):
+        if self.sorted_keys is None:
+            self.sorted_keys = [k for _, k in sorted(self.heap, reverse=self.reverse)]
+        return self.sorted_keys
+
+    def items(self):
+        if self.sorted_keys is None:
+            self.sorted_keys = [k for _, k in sorted(self.heap, reverse=self.reverse)]
+        sorted_items = [(k, self[k]) for k in self.sorted_keys]
+        return sorted_items
+
+    def _update_heap(self, key):
+        for i, (priority, k) in enumerate(self.heap):
+            if k == key:
+                new_priority = self.sort_func(key, self[key])
+                if new_priority != priority:
+                    self.heap[i] = (new_priority, key)
+                    heapq.heapify(self.heap)
+                    self.sorted_keys = None
+                break
+
+    def __iter__(self):
+        return iter(self.keys())
+
+    def __repr__(self):
+        return f'{type(self).__name__}({dict(self)}, sort_func={self.sort_func.__name__}, reverse={self.reverse})'

+ 9 - 0
plugins/__init__.py

@@ -0,0 +1,9 @@
+from .plugin_manager import PluginManager
+from .event import *
+from .plugin import *
+
+instance = PluginManager()
+
+register                    = instance.register
+# load_plugins                = instance.load_plugins
+# emit_event                  = instance.emit_event

+ 1 - 0
plugins/banwords/.gitignore

@@ -0,0 +1 @@
+banwords.txt

+ 9 - 0
plugins/banwords/README.md

@@ -0,0 +1,9 @@
+### 说明
+简易的敏感词插件,暂不支持分词,请自行导入词库到插件文件夹中的`banwords.txt`,每行一个词,一个参考词库是[1](https://github.com/cjh0613/tencent-sensitive-words/blob/main/sensitive_words_lines.txt)。
+
+`config.json`中能够填写默认的处理行为,目前行为有:
+- `ignore` : 无视这条消息。
+- `replace` : 将消息中的敏感词替换成"*",并回复违规。
+
+### 致谢
+搜索功能实现来自https://github.com/toolgood/ToolGood.Words

+ 250 - 0
plugins/banwords/WordsSearch.py

@@ -0,0 +1,250 @@
+#!/usr/bin/env python
+# -*- coding:utf-8 -*-
+# ToolGood.Words.WordsSearch.py
+# 2020, Lin Zhijun, https://github.com/toolgood/ToolGood.Words
+# Licensed under the Apache License 2.0
+# 更新日志
+# 2020.04.06 第一次提交
+# 2020.05.16 修改,支持大于0xffff的字符
+
+__all__ = ['WordsSearch']
+__author__ = 'Lin Zhijun'
+__date__ = '2020.05.16'
+
+class TrieNode():
+    def __init__(self):
+        self.Index = 0
+        self.Index = 0
+        self.Layer = 0
+        self.End = False
+        self.Char = ''
+        self.Results = []
+        self.m_values = {}
+        self.Failure = None
+        self.Parent = None
+
+    def Add(self,c):
+        if c in self.m_values :
+            return self.m_values[c]
+        node = TrieNode()
+        node.Parent = self
+        node.Char = c
+        self.m_values[c] = node
+        return node
+
+    def SetResults(self,index):
+        if (self.End == False):
+            self.End = True
+        self.Results.append(index)
+
+class TrieNode2():
+    def __init__(self):
+        self.End = False
+        self.Results = []
+        self.m_values = {}
+        self.minflag = 0xffff
+        self.maxflag = 0
+
+    def Add(self,c,node3):
+        if (self.minflag > c):
+            self.minflag = c
+        if (self.maxflag < c):
+             self.maxflag = c
+        self.m_values[c] = node3
+
+    def SetResults(self,index):
+        if (self.End == False) :
+            self.End = True
+        if (index in self.Results )==False : 
+            self.Results.append(index)
+
+    def HasKey(self,c):
+        return c in self.m_values
+        
+ 
+    def TryGetValue(self,c):
+        if (self.minflag <= c and self.maxflag >= c):
+            if c in self.m_values:
+                return self.m_values[c]
+        return None
+
+
+class WordsSearch():
+    def __init__(self):
+        self._first = {}
+        self._keywords = []
+        self._indexs=[]
+    
+    def SetKeywords(self,keywords):
+        self._keywords = keywords
+        self._indexs=[]
+        for i in range(len(keywords)):
+            self._indexs.append(i)
+
+        root = TrieNode()
+        allNodeLayer={}
+
+        for i in range(len(self._keywords)): # for (i = 0; i < _keywords.length; i++) 
+            p = self._keywords[i]
+            nd = root
+            for j in range(len(p)): # for (j = 0; j < p.length; j++) 
+                nd = nd.Add(ord(p[j]))
+                if (nd.Layer == 0):
+                    nd.Layer = j + 1
+                    if nd.Layer in allNodeLayer:
+                        allNodeLayer[nd.Layer].append(nd)
+                    else:
+                        allNodeLayer[nd.Layer]=[]
+                        allNodeLayer[nd.Layer].append(nd)
+            nd.SetResults(i)
+
+
+        allNode = []
+        allNode.append(root)
+        for key in allNodeLayer.keys():
+            for nd in allNodeLayer[key]:
+                allNode.append(nd)
+        allNodeLayer=None
+
+        for i in range(len(allNode)): # for (i = 0; i < allNode.length; i++) 
+            if i==0 :
+                continue
+            nd=allNode[i]
+            nd.Index = i
+            r = nd.Parent.Failure
+            c = nd.Char
+            while (r != None and (c in r.m_values)==False):
+                r = r.Failure
+            if (r == None):
+                nd.Failure = root
+            else:
+                nd.Failure = r.m_values[c]
+                for key2 in nd.Failure.Results :
+                    nd.SetResults(key2)
+        root.Failure = root
+
+        allNode2 = []
+        for i in range(len(allNode)): # for (i = 0; i < allNode.length; i++) 
+            allNode2.append( TrieNode2())
+        
+        for i in range(len(allNode2)): # for (i = 0; i < allNode2.length; i++) 
+            oldNode = allNode[i]
+            newNode = allNode2[i]
+
+            for key in oldNode.m_values :
+                index = oldNode.m_values[key].Index
+                newNode.Add(key, allNode2[index])
+            
+            for index in range(len(oldNode.Results)): # for (index = 0; index < oldNode.Results.length; index++) 
+                item = oldNode.Results[index]
+                newNode.SetResults(item)
+            
+            oldNode=oldNode.Failure
+            while oldNode != root:
+                for key in oldNode.m_values :
+                    if (newNode.HasKey(key) == False):
+                        index = oldNode.m_values[key].Index
+                        newNode.Add(key, allNode2[index])
+                for index in range(len(oldNode.Results)): 
+                    item = oldNode.Results[index]
+                    newNode.SetResults(item)
+                oldNode=oldNode.Failure
+        allNode = None
+        root = None
+
+        # first = []
+        # for index in range(65535):# for (index = 0; index < 0xffff; index++) 
+        #     first.append(None)
+        
+        # for key in allNode2[0].m_values :
+        #     first[key] = allNode2[0].m_values[key]
+        
+        self._first = allNode2[0]
+    
+
+    def FindFirst(self,text):
+        ptr = None
+        for index in range(len(text)): # for (index = 0; index < text.length; index++) 
+            t =ord(text[index]) # text.charCodeAt(index)
+            tn = None
+            if (ptr == None):
+                tn = self._first.TryGetValue(t)
+            else:
+                tn = ptr.TryGetValue(t)
+                if (tn==None):
+                    tn = self._first.TryGetValue(t)
+                
+            
+            if (tn != None):
+                if (tn.End):
+                    item = tn.Results[0]
+                    keyword = self._keywords[item]
+                    return { "Keyword": keyword, "Success": True, "End": index, "Start": index + 1 - len(keyword), "Index": self._indexs[item] }
+            ptr = tn
+        return None
+
+    def FindAll(self,text):
+        ptr = None
+        list = []
+
+        for index in range(len(text)): # for (index = 0; index < text.length; index++) 
+            t =ord(text[index]) # text.charCodeAt(index)
+            tn = None
+            if (ptr == None):
+                tn = self._first.TryGetValue(t)
+            else:
+                tn = ptr.TryGetValue(t)
+                if (tn==None):
+                    tn = self._first.TryGetValue(t)
+                
+            
+            if (tn != None):
+                if (tn.End):
+                    for j in range(len(tn.Results)): # for (j = 0; j < tn.Results.length; j++) 
+                        item = tn.Results[j]
+                        keyword = self._keywords[item]
+                        list.append({ "Keyword": keyword, "Success": True, "End": index, "Start": index + 1 - len(keyword), "Index": self._indexs[item] })
+            ptr = tn
+        return list
+
+
+    def ContainsAny(self,text):
+        ptr = None
+        for index in range(len(text)): # for (index = 0; index < text.length; index++) 
+            t =ord(text[index]) # text.charCodeAt(index)
+            tn = None
+            if (ptr == None):
+                tn = self._first.TryGetValue(t)
+            else:
+                tn = ptr.TryGetValue(t)
+                if (tn==None):
+                    tn = self._first.TryGetValue(t)
+            
+            if (tn != None):
+                if (tn.End):
+                    return True
+            ptr = tn
+        return False
+    
+    def Replace(self,text, replaceChar = '*'):
+        result = list(text) 
+
+        ptr = None
+        for i in range(len(text)): # for (i = 0; i < text.length; i++) 
+            t =ord(text[i]) # text.charCodeAt(index)
+            tn = None
+            if (ptr == None):
+                tn = self._first.TryGetValue(t)
+            else:
+                tn = ptr.TryGetValue(t)
+                if (tn==None):
+                    tn = self._first.TryGetValue(t)
+            
+            if (tn != None):
+                if (tn.End):
+                    maxLength = len( self._keywords[tn.Results[0]])
+                    start = i + 1 - maxLength
+                    for j in range(start,i+1): # for (j = start; j <= i; j++) 
+                        result[j] = replaceChar
+            ptr = tn
+        return ''.join(result) 

+ 0 - 0
plugins/banwords/__init__.py


+ 63 - 0
plugins/banwords/banwords.py

@@ -0,0 +1,63 @@
+# encoding:utf-8
+
+import json
+import os
+from bridge.context import ContextType
+from bridge.reply import Reply, ReplyType
+import plugins
+from plugins import *
+from common.log import logger
+from .WordsSearch import WordsSearch
+
+
+@plugins.register(name="Banwords", desc="判断消息中是否有敏感词、决定是否回复。", version="1.0", author="lanvent", desire_priority= 100)
+class Banwords(Plugin):
+    def __init__(self):
+        super().__init__()
+        try:
+            curdir=os.path.dirname(__file__)
+            config_path=os.path.join(curdir,"config.json")
+            conf=None
+            if not os.path.exists(config_path):
+                conf={"action":"ignore"}
+                with open(config_path,"w") as f:
+                    json.dump(conf,f,indent=4)
+            else:
+                with open(config_path,"r") as f:
+                    conf=json.load(f)
+            self.searchr = WordsSearch()
+            self.action = conf["action"]
+            banwords_path = os.path.join(curdir,"banwords.txt")
+            with open(banwords_path, 'r', encoding='utf-8') as f:
+                words=[]
+                for line in f:
+                    word = line.strip()
+                    if word:
+                        words.append(word)
+            self.searchr.SetKeywords(words)
+            self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
+            logger.info("[Banwords] inited")
+        except Exception as e:
+            logger.error("Banwords init failed: %s" % e)
+        
+
+
+    def on_handle_context(self, e_context: EventContext):
+
+        if e_context['context'].type not in [ContextType.TEXT,ContextType.IMAGE_CREATE]:
+            return
+        
+        content = e_context['context'].content
+        logger.debug("[Banwords] on_handle_context. content: %s" % content)
+        if self.action == "ignore":
+            f = self.searchr.FindFirst(content)
+            if f:
+                logger.info("Banwords: %s" % f["Keyword"])
+                e_context.action = EventAction.BREAK_PASS
+                return
+        elif self.action == "replace":
+            if self.searchr.ContainsAny(content):
+                reply = Reply(ReplyType.INFO, "发言中包含敏感词,请重试: \n"+self.searchr.Replace(content))
+                e_context['reply'] = reply
+                e_context.action = EventAction.BREAK_PASS
+                return

+ 3 - 0
plugins/banwords/banwords.txt.template

@@ -0,0 +1,3 @@
+nipples
+pennis
+法轮功

+ 3 - 0
plugins/banwords/config.json.template

@@ -0,0 +1,3 @@
+{
+    "action": "ignore"
+}

+ 49 - 0
plugins/event.py

@@ -0,0 +1,49 @@
+# encoding:utf-8
+
+from enum import Enum
+
+
+class Event(Enum):
+    # ON_RECEIVE_MESSAGE = 1  # 收到消息
+
+    ON_HANDLE_CONTEXT = 2   # 处理消息前
+    """
+    e_context = {  "channel": 消息channel, "context" : 本次消息的context, "reply" : 目前的回复,初始为空  }
+    """
+
+    ON_DECORATE_REPLY = 3   # 得到回复后准备装饰
+    """
+    e_context = {  "channel": 消息channel, "context" : 本次消息的context, "reply" : 目前的回复 }
+    """
+
+    ON_SEND_REPLY = 4       # 发送回复前
+    """
+    e_context = {  "channel": 消息channel, "context" : 本次消息的context, "reply" : 目前的回复 }
+    """
+
+    # AFTER_SEND_REPLY = 5    # 发送回复后
+
+
+class EventAction(Enum):
+    CONTINUE = 1            # 事件未结束,继续交给下个插件处理,如果没有下个插件,则交付给默认的事件处理逻辑
+    BREAK = 2               # 事件结束,不再给下个插件处理,交付给默认的事件处理逻辑
+    BREAK_PASS = 3          # 事件结束,不再给下个插件处理,不交付给默认的事件处理逻辑
+
+
+class EventContext:
+    def __init__(self, event, econtext=dict()):
+        self.event = event
+        self.econtext = econtext
+        self.action = EventAction.CONTINUE
+
+    def __getitem__(self, key):
+        return self.econtext[key]
+
+    def __setitem__(self, key, value):
+        self.econtext[key] = value
+
+    def __delitem__(self, key):
+        del self.econtext[key]
+
+    def is_pass(self):
+        return self.action == EventAction.BREAK_PASS

+ 0 - 0
plugins/godcmd/__init__.py


+ 4 - 0
plugins/godcmd/config.json.template

@@ -0,0 +1,4 @@
+{
+    "password": "",
+    "admin_users": []
+}

+ 289 - 0
plugins/godcmd/godcmd.py

@@ -0,0 +1,289 @@
+# encoding:utf-8
+
+import json
+import os
+import traceback
+from typing import Tuple
+from bridge.bridge import Bridge
+from bridge.context import ContextType
+from bridge.reply import Reply, ReplyType
+from config import load_config
+import plugins
+from plugins import *
+from common.log import logger
+
+# 定义指令集
+COMMANDS = {
+    "help": {
+        "alias": ["help", "帮助"],
+        "desc": "打印指令集合",
+    },
+    "auth": {
+        "alias": ["auth", "认证"],
+        "args": ["口令"],
+        "desc": "管理员认证",
+    },
+    # "id": {
+    #     "alias": ["id", "用户"],
+    #     "desc": "获取用户id", #目前无实际意义
+    # },
+    "reset": {
+        "alias": ["reset", "重置会话"],
+        "desc": "重置会话",
+    },
+}
+
+ADMIN_COMMANDS = {
+    "resume": {
+        "alias": ["resume", "恢复服务"],
+        "desc": "恢复服务",
+    },
+    "stop": {
+        "alias": ["stop", "暂停服务"],
+        "desc": "暂停服务",
+    },
+    "reconf": {
+        "alias": ["reconf", "重载配置"],
+        "desc": "重载配置(不包含插件配置)",
+    },
+    "resetall": {
+        "alias": ["resetall", "重置所有会话"],
+        "desc": "重置所有会话",
+    },
+    "scanp": {
+        "alias": ["scanp", "扫描插件"],
+        "desc": "扫描插件目录是否有新插件",
+    },
+    "plist": {
+        "alias": ["plist", "插件"],
+        "desc": "打印当前插件列表",
+    },
+    "setpri": {
+        "alias": ["setpri", "设置插件优先级"],
+        "args": ["插件名", "优先级"],
+        "desc": "设置指定插件的优先级,越大越优先",
+    },
+    "reloadp": {
+        "alias": ["reloadp", "重载插件"],
+        "args": ["插件名"],
+        "desc": "重载指定插件配置",
+    },
+    "enablep": {
+        "alias": ["enablep", "启用插件"],
+        "args": ["插件名"],
+        "desc": "启用指定插件",
+    },
+    "disablep": {
+        "alias": ["disablep", "禁用插件"],
+        "args": ["插件名"],
+        "desc": "禁用指定插件",
+    },
+    "debug": {
+        "alias": ["debug", "调试模式", "DEBUG"],
+        "desc": "开启机器调试日志",
+    },
+}
+# 定义帮助函数
+def get_help_text(isadmin, isgroup):
+    help_text = "可用指令:\n"
+    for cmd, info in COMMANDS.items():
+        if cmd=="auth" and (isadmin or isgroup): # 群聊不可认证
+            continue
+
+        alias=["#"+a for a in info['alias']]
+        help_text += f"{','.join(alias)} "
+        if 'args' in info:
+            args=["{"+a+"}" for a in info['args']]
+            help_text += f"{' '.join(args)} "
+        help_text += f": {info['desc']}\n"
+    if ADMIN_COMMANDS and isadmin:
+        help_text += "\n管理员指令:\n"
+        for cmd, info in ADMIN_COMMANDS.items():
+            alias=["#"+a for a in info['alias']]
+            help_text += f"{','.join(alias)} "
+            help_text += f": {info['desc']}\n"
+    return help_text
+
+@plugins.register(name="Godcmd", desc="为你的机器人添加指令集,有用户和管理员两种角色,加载顺序请放在首位,初次运行后插件目录会生成配置文件, 填充管理员密码后即可认证", version="1.0", author="lanvent", desire_priority= 999)
+class Godcmd(Plugin):
+
+    def __init__(self):
+        super().__init__()
+
+        curdir=os.path.dirname(__file__)
+        config_path=os.path.join(curdir,"config.json")
+        gconf=None
+        if not os.path.exists(config_path):
+            gconf={"password":"","admin_users":[]}
+            with open(config_path,"w") as f:
+                json.dump(gconf,f,indent=4)
+        else:
+            with open(config_path,"r") as f:
+                gconf=json.load(f)
+                
+        self.password = gconf["password"]
+        self.admin_users = gconf["admin_users"] # 预存的管理员账号,这些账号不需要认证 TODO: 用户名每次都会变,目前不可用
+        self.isrunning = True # 机器人是否运行中
+
+        self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
+        logger.info("[Godcmd] inited")
+
+    
+    def on_handle_context(self, e_context: EventContext):
+        context_type = e_context['context'].type
+        if context_type != ContextType.TEXT:
+            if not self.isrunning:
+                e_context.action = EventAction.BREAK_PASS
+            return
+        
+        content = e_context['context'].content
+        logger.debug("[Godcmd] on_handle_context. content: %s" % content)
+        if content.startswith("#"):
+            # msg = e_context['context']['msg']
+            user = e_context['context']['receiver']
+            session_id = e_context['context']['session_id']
+            isgroup = e_context['context']['isgroup']
+            bottype = Bridge().get_bot_type("chat")
+            bot = Bridge().get_bot("chat")
+            # 将命令和参数分割
+            command_parts = content[1:].split(" ")
+            cmd = command_parts[0]
+            args = command_parts[1:]
+            isadmin=False
+            if user in self.admin_users:
+                isadmin=True
+            ok=False
+            result="string"
+            if any(cmd in info['alias'] for info in COMMANDS.values()):
+                cmd = next(c for c, info in COMMANDS.items() if cmd in info['alias'])
+                if cmd == "auth":
+                    ok, result = self.authenticate(user, args, isadmin, isgroup)
+                elif cmd == "help":
+                    ok, result = True, get_help_text(isadmin, isgroup)
+                elif cmd == "id":
+                    ok, result = True, f"用户id=\n{user}"
+                elif cmd == "reset":
+                    if bottype == "chatGPT":
+                        bot.sessions.clear_session(session_id)
+                        ok, result = True, "会话已重置"
+                    else:
+                        ok, result = False, "当前对话机器人不支持重置会话"
+                logger.debug("[Godcmd] command: %s by %s" % (cmd, user))
+            elif any(cmd in info['alias'] for info in ADMIN_COMMANDS.values()):
+                if isadmin:
+                    if isgroup:
+                        ok, result = False, "群聊不可执行管理员指令"
+                    else:
+                        cmd = next(c for c, info in ADMIN_COMMANDS.items() if cmd in info['alias'])
+                        if cmd == "stop":
+                            self.isrunning = False
+                            ok, result = True, "服务已暂停"
+                        elif cmd == "resume":
+                            self.isrunning = True
+                            ok, result = True, "服务已恢复"
+                        elif cmd == "reconf":
+                            load_config()
+                            ok, result = True, "配置已重载"
+                        elif cmd == "resetall":
+                            if bottype == "chatGPT":
+                                bot.sessions.clear_all_session()
+                                ok, result = True, "重置所有会话成功"
+                            else:
+                                ok, result = False, "当前对话机器人不支持重置会话"
+                        elif cmd == "debug":
+                            logger.setLevel('DEBUG')
+                            ok, result = True, "DEBUG模式已开启"
+                        elif cmd == "plist":
+                            plugins = PluginManager().list_plugins()
+                            ok = True
+                            result = "插件列表:\n"
+                            for name,plugincls in plugins.items():
+                                result += f"{plugincls.name}_v{plugincls.version} {plugincls.priority} - "
+                                if plugincls.enabled:
+                                    result += "已启用\n"
+                                else:
+                                    result += "未启用\n"
+                        elif cmd == "scanp":
+                            new_plugins = PluginManager().scan_plugins()
+                            ok, result = True, "插件扫描完成"
+                            PluginManager().activate_plugins()
+                            if len(new_plugins) >0 :
+                                result += "\n发现新插件:\n"
+                                result += "\n".join([f"{p.name}_v{p.version}" for p in new_plugins])
+                            else :
+                                result +=", 未发现新插件"
+                        elif cmd == "setpri":
+                            if len(args) != 2:
+                                ok, result = False, "请提供插件名和优先级"
+                            else:
+                                ok = PluginManager().set_plugin_priority(args[0], int(args[1]))
+                                if ok:
+                                    result = "插件" + args[0] + "优先级已设置为" + args[1]
+                                else:
+                                    result = "插件不存在"
+                        elif cmd == "reloadp":
+                            if len(args) != 1:
+                                ok, result = False, "请提供插件名"
+                            else:
+                                ok = PluginManager().reload_plugin(args[0])
+                                if ok:
+                                    result = "插件配置已重载"
+                                else:
+                                    result = "插件不存在"
+                        elif cmd == "enablep":
+                            if len(args) != 1:
+                                ok, result = False, "请提供插件名"
+                            else:
+                                ok = PluginManager().enable_plugin(args[0])
+                                if ok:
+                                    result = "插件已启用"
+                                else:
+                                    result = "插件不存在"
+                        elif cmd == "disablep":
+                            if len(args) != 1:
+                                ok, result = False, "请提供插件名"
+                            else:
+                                ok = PluginManager().disable_plugin(args[0])
+                                if ok:
+                                    result = "插件已禁用"
+                                else:
+                                    result = "插件不存在"
+
+                        logger.debug("[Godcmd] admin command: %s by %s" % (cmd, user))
+                else:
+                    ok, result = False, "需要管理员权限才能执行该指令"
+            else:
+                ok, result = False, f"未知指令:{cmd}\n查看指令列表请输入#help \n"
+            
+            reply = Reply()
+            if ok:
+                reply.type = ReplyType.INFO
+            else:
+                reply.type = ReplyType.ERROR
+            reply.content = result
+            e_context['reply'] = reply
+
+            e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑
+        elif not self.isrunning:
+            e_context.action = EventAction.BREAK_PASS
+    
+    def authenticate(self, userid, args, isadmin, isgroup) -> Tuple[bool,str] : 
+        if isgroup:
+            return False,"请勿在群聊中认证"
+        
+        if isadmin:
+            return False,"管理员账号无需认证"
+        
+        if len(self.password) == 0:
+            return False,"未设置口令,无法认证"
+        
+        if len(args) != 1:
+            return False,"请提供口令"
+        
+        password = args[0]
+        if password == self.password:
+            self.admin_users.append(userid)
+            return True,"认证成功"
+        else:
+            return False,"认证失败"
+

+ 0 - 0
plugins/hello/__init__.py


+ 46 - 0
plugins/hello/hello.py

@@ -0,0 +1,46 @@
+# encoding:utf-8
+
+from bridge.context import ContextType
+from bridge.reply import Reply, ReplyType
+import plugins
+from plugins import *
+from common.log import logger
+
+
+@plugins.register(name="Hello", desc="A simple plugin that says hello", version="0.1", author="lanvent", desire_priority= -1)
+class Hello(Plugin):
+    def __init__(self):
+        super().__init__()
+        self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
+        logger.info("[Hello] inited")
+
+    def on_handle_context(self, e_context: EventContext):
+
+        if e_context['context'].type != ContextType.TEXT:
+            return
+        
+        content = e_context['context'].content
+        logger.debug("[Hello] on_handle_context. content: %s" % content)
+        if content == "Hello":
+            reply = Reply()
+            reply.type = ReplyType.TEXT
+            msg = e_context['context']['msg']
+            if e_context['context']['isgroup']:
+                reply.content = "Hello, " + msg['ActualNickName'] + " from " + msg['User'].get('NickName', "Group")
+            else:
+                reply.content = "Hello, " + msg['User'].get('NickName', "My friend")
+            e_context['reply'] = reply
+            e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑
+
+        if content == "Hi":
+            reply = Reply()
+            reply.type = ReplyType.TEXT
+            reply.content = "Hi"
+            e_context['reply'] = reply
+            e_context.action = EventAction.BREAK  # 事件结束,进入默认处理逻辑,一般会覆写reply
+
+        if content == "End":
+            # 如果是文本消息"End",将请求转换成"IMAGE_CREATE",并将content设置为"The World"
+            e_context['context'].type = "IMAGE_CREATE"
+            content = "The World"
+            e_context.action = EventAction.CONTINUE  # 事件继续,交付给下个插件或默认逻辑

+ 3 - 0
plugins/plugin.py

@@ -0,0 +1,3 @@
+class Plugin:
+    def __init__(self):
+        self.handlers = {}

+ 171 - 0
plugins/plugin_manager.py

@@ -0,0 +1,171 @@
+# encoding:utf-8
+
+import importlib
+import json
+import os
+from common.singleton import singleton
+from common.sorted_dict import SortedDict
+from .event import *
+from .plugin import *
+from common.log import logger
+
+
+@singleton
+class PluginManager:
+    def __init__(self):
+        self.plugins = SortedDict(lambda k,v: v.priority,reverse=True)
+        self.listening_plugins = {}
+        self.instances = {}
+        self.pconf = {}
+
+    def register(self, name: str, desc: str, version: str, author: str, desire_priority: int = 0):
+        def wrapper(plugincls):
+            plugincls.name = name
+            plugincls.desc = desc
+            plugincls.version = version
+            plugincls.author = author
+            plugincls.priority = desire_priority
+            plugincls.enabled = True
+            self.plugins[name.upper()] = plugincls
+            logger.info("Plugin %s_v%s registered" % (name, version))
+            return plugincls
+        return wrapper
+
+    def save_config(self):
+        with open("plugins/plugins.json", "w", encoding="utf-8") as f:
+            json.dump(self.pconf, f, indent=4, ensure_ascii=False)
+
+    def load_config(self):
+        logger.info("Loading plugins config...")
+
+        modified = False
+        if os.path.exists("plugins/plugins.json"):
+            with open("plugins/plugins.json", "r", encoding="utf-8") as f:
+                pconf = json.load(f)
+                pconf['plugins'] = SortedDict(lambda k,v: v["priority"],pconf['plugins'],reverse=True)
+        else:
+            modified = True
+            pconf = {"plugins": SortedDict(lambda k,v: v["priority"],reverse=True)}
+        self.pconf = pconf
+        if modified:
+            self.save_config()
+        return pconf
+
+    def scan_plugins(self):
+        logger.info("Scaning plugins ...")
+        plugins_dir = "plugins"
+        for plugin_name in os.listdir(plugins_dir):
+            plugin_path = os.path.join(plugins_dir, plugin_name)
+            if os.path.isdir(plugin_path):
+                # 判断插件是否包含同名.py文件
+                main_module_path = os.path.join(plugin_path, plugin_name+".py")
+                if os.path.isfile(main_module_path):
+                    # 导入插件
+                    import_path = "{}.{}.{}".format(plugins_dir, plugin_name, plugin_name)
+                    main_module = importlib.import_module(import_path)
+        pconf = self.pconf
+        new_plugins = []
+        modified = False
+        for name, plugincls in self.plugins.items():
+            rawname = plugincls.name
+            if rawname not in pconf["plugins"]:
+                new_plugins.append(plugincls)
+                modified = True
+                logger.info("Plugin %s not found in pconfig, adding to pconfig..." % name)
+                pconf["plugins"][rawname] = {"enabled": plugincls.enabled, "priority": plugincls.priority}
+            else:
+                self.plugins[name].enabled = pconf["plugins"][rawname]["enabled"]
+                self.plugins[name].priority = pconf["plugins"][rawname]["priority"]
+                self.plugins._update_heap(name) # 更新下plugins中的顺序
+        if modified:
+            self.save_config()
+        return new_plugins
+
+    def refresh_order(self):
+        for event in self.listening_plugins.keys():
+            self.listening_plugins[event].sort(key=lambda name: self.plugins[name].priority, reverse=True)
+
+    def activate_plugins(self): # 生成新开启的插件实例
+        for name, plugincls in self.plugins.items():
+            if plugincls.enabled:
+                if name not in self.instances:
+                    instance = plugincls()
+                    self.instances[name] = instance
+                    for event in instance.handlers:
+                        if event not in self.listening_plugins:
+                            self.listening_plugins[event] = []
+                        self.listening_plugins[event].append(name)
+        self.refresh_order()
+
+    def reload_plugin(self, name:str):
+        name = name.upper()
+        if name in self.instances:
+            for event in self.listening_plugins:
+                if name in self.listening_plugins[event]:
+                    self.listening_plugins[event].remove(name)
+            del self.instances[name]
+            self.activate_plugins()
+            return True
+        return False
+    
+    def load_plugins(self):
+        self.load_config()
+        self.scan_plugins()
+        pconf = self.pconf
+        logger.debug("plugins.json config={}".format(pconf))
+        for name,plugin in pconf["plugins"].items():
+            if name.upper() not in self.plugins:
+                logger.error("Plugin %s not found, but found in plugins.json" % name)
+        self.activate_plugins()
+
+    def emit_event(self, e_context: EventContext, *args, **kwargs):
+        if e_context.event in self.listening_plugins:
+            for name in self.listening_plugins[e_context.event]:
+                if self.plugins[name].enabled and e_context.action == EventAction.CONTINUE:
+                    logger.debug("Plugin %s triggered by event %s" % (name,e_context.event))
+                    instance = self.instances[name]
+                    instance.handlers[e_context.event](e_context, *args, **kwargs)
+        return e_context
+
+    def set_plugin_priority(self, name:str, priority:int):
+        name = name.upper()
+        if name not in self.plugins:
+            return False
+        if self.plugins[name].priority == priority:
+            return True
+        self.plugins[name].priority = priority
+        self.plugins._update_heap(name)
+        rawname = self.plugins[name].name
+        self.pconf["plugins"][rawname]["priority"] = priority
+        self.pconf["plugins"]._update_heap(rawname)
+        self.save_config()
+        self.refresh_order()
+        return True
+
+    def enable_plugin(self, name:str):
+        name = name.upper()
+        if name not in self.plugins:
+            return False
+        if not self.plugins[name].enabled :
+            self.plugins[name].enabled = True
+            rawname = self.plugins[name].name
+            self.pconf["plugins"][rawname]["enabled"] = True
+            self.save_config()
+            self.activate_plugins()
+            return True
+        return True
+    
+    def disable_plugin(self, name:str):
+        name = name.upper()
+        if name not in self.plugins:
+            return False
+        if self.plugins[name].enabled :
+            self.plugins[name].enabled = False
+            rawname = self.plugins[name].name
+            self.pconf["plugins"][rawname]["enabled"] = False
+            self.save_config()
+            return True
+        return True
+    
+    def list_plugins(self):
+        return self.plugins

+ 0 - 0
plugins/sdwebui/__init__.py


+ 70 - 0
plugins/sdwebui/config.json.template

@@ -0,0 +1,70 @@
+{
+  "start":{
+    "host" : "127.0.0.1",
+    "port" : 7860
+  },
+  "defaults": {
+    "params": {
+      "sampler_name": "DPM++ 2M Karras",
+      "steps": 20,
+      "width": 512,
+      "height": 512,
+      "cfg_scale": 7,
+      "prompt":"masterpiece, best quality",
+      "negative_prompt": "(low quality, worst quality:1.4),(bad_prompt:0.8), (monochrome:1.1), (greyscale)",
+      "enable_hr": false,
+      "hr_scale": 2,
+      "hr_upscaler": "Latent",
+      "hr_second_pass_steps": 15,
+      "denoising_strength": 0.7
+    },
+    "options": {
+      "sd_model_checkpoint": "perfectWorld_v2Baked"
+    }
+  },
+  "rules": [
+    {
+      "keywords": [
+        "横版",
+        "壁纸"
+      ],
+      "params": {
+        "width": 640,
+        "height": 384
+      },
+      "desc": "分辨率会变成640x384"
+    },
+    {
+      "keywords": [
+        "竖版"
+      ],
+      "params": {
+        "width": 384,
+        "height": 640
+      }
+    },
+    {
+      "keywords": [
+        "高清"
+      ],
+      "params": {
+        "enable_hr": true,
+        "hr_scale": 1.6
+      },
+      "desc": "出图分辨率长宽都会提高1.6倍"
+    },
+    {
+      "keywords": [
+        "二次元"
+      ],
+      "params": {
+        "negative_prompt": "(low quality, worst quality:1.4),(bad_prompt:0.8), (monochrome:1.1), (greyscale)",
+        "prompt": "masterpiece, best quality"
+      },
+      "options": {
+        "sd_model_checkpoint": "meinamix_meinaV8"
+      },
+      "desc": "使用二次元风格模型出图"
+    }
+  ]
+}

+ 69 - 0
plugins/sdwebui/readme.md

@@ -0,0 +1,69 @@
+### 插件描述
+本插件用于将画图请求转发给stable diffusion webui。
+
+### 环境要求
+使用前先安装stable diffusion webui,并在它的启动参数中添加 "--api"。
+具体信息,请参考[文章](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/API)。
+
+请**安装**本插件的依赖包```webuiapi```
+```
+    ```pip install webuiapi```
+```
+### 使用说明
+请将`config.json.template`复制为`config.json`,并修改其中的参数和规则。
+
+#### 画图请求格式
+用户的画图请求格式为:
+```
+    <画图触发词><关键词1> <关键词2> ... <关键词n>:<prompt> 
+```
+- 本插件会对画图触发词后的关键词进行逐个匹配,如果触发了规则中的关键词,则会在画图请求中重载对应的参数。
+- 规则的匹配顺序参考`config.json`中的顺序,每个关键词最多被匹配到1次,如果多个关键词触发了重复的参数,重复参数以最后一个关键词为准:
+- 关键词中包含`help`或`帮助`,会打印出帮助文档。
+第一个"**:**"号之后的内容会作为附加的**prompt**,接在最终的prompt后
+
+例如: 画横版 高清 二次元:cat
+会触发三个关键词 "横版", "高清", "二次元",prompt为"cat"
+若默认参数是:
+```
+    "width": 512,
+    "height": 512,
+    "enable_hr": false,
+    "prompt": "8k"
+    "negative_prompt": "nsfw",
+    "sd_model_checkpoint": "perfectWorld_v2Baked"
+```
+
+"横版"触发的规则参数为:
+```
+    "width": 640,
+    "height": 384,
+```
+"高清"触发的规则参数为:
+```
+    "enable_hr": true,
+    "hr_scale": 1.6,
+```
+"二次元"触发的规则参数为:
+```
+    "negative_prompt": "(low quality, worst quality:1.4),(bad_prompt:0.8), (monochrome:1.1), (greyscale)",
+    "steps": 20,
+    "prompt": "masterpiece, best quality",
+
+    "sd_model_checkpoint": "meinamix_meinaV8"
+```
+最后将第一个":"后的内容cat连接在prompt后,得到最终参数为:
+```
+    "width": 640,
+    "height": 384,
+    "enable_hr": true,
+    "hr_scale": 1.6,
+    "negative_prompt": "(low quality, worst quality:1.4),(bad_prompt:0.8), (monochrome:1.1), (greyscale)",
+    "steps": 20,
+    "prompt": "masterpiece, best quality, cat",
+    
+    "sd_model_checkpoint": "meinamix_meinaV8"
+```
+PS: 参数分为两部分:
+- 一部分是params,为画画的参数;参数名**必须**与webuiapi包中[txt2img api](https://github.com/mix1009/sdwebuiapi/blob/fb2054e149c0a4e25125c0cd7e7dca06bda839d4/webuiapi/webuiapi.py#L163)的参数名一致
+- 另一部分是options,指sdwebui的设置,使用的模型和vae需要写在里面。它和http://127.0.0.1:7860/sdapi/v1/options所返回的键一致。

+ 114 - 0
plugins/sdwebui/sdwebui.py

@@ -0,0 +1,114 @@
+# encoding:utf-8
+
+import json
+import os
+from bridge.context import ContextType
+from bridge.reply import Reply, ReplyType
+from config import conf
+import plugins
+from plugins import *
+from common.log import logger
+import webuiapi
+import io
+
+
+@plugins.register(name="sdwebui", desc="利用stable-diffusion webui来画图", version="2.0", author="lanvent")
+class SDWebUI(Plugin):
+    def __init__(self):
+        super().__init__()
+        curdir = os.path.dirname(__file__)
+        config_path = os.path.join(curdir, "config.json")
+        try:
+            with open(config_path, "r", encoding="utf-8") as f:
+                config = json.load(f)
+                self.rules = config["rules"]
+                defaults = config["defaults"]
+                self.default_params = defaults["params"]
+                self.default_options = defaults["options"]
+                self.start_args = config["start"]
+                self.api = webuiapi.WebUIApi(**self.start_args)
+            self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
+            logger.info("[SD] inited")
+        except FileNotFoundError:
+            logger.error(f"[SD] init failed, {config_path} not found")
+        except Exception as e:
+            logger.error("[SD] init failed, exception: %s" % e)
+    
+    def on_handle_context(self, e_context: EventContext):
+
+        if e_context['context'].type != ContextType.IMAGE_CREATE:
+            return
+
+        logger.debug("[SD] on_handle_context. content: %s" %e_context['context'].content)
+
+        logger.info("[SD] image_query={}".format(e_context['context'].content))
+        reply = Reply()
+        try:
+            content = e_context['context'].content[:]
+            # 解析用户输入 如"横版 高清 二次元:cat"
+            if ":" in content:
+                keywords, prompt = content.split(":", 1)
+            else:
+                keywords = content
+                prompt = ""
+
+            keywords = keywords.split()
+
+            if "help" in keywords or "帮助" in keywords:
+                reply.type = ReplyType.INFO
+                reply.content = self.get_help_text()
+            else:
+                rule_params = {}
+                rule_options = {}
+                for keyword in keywords:
+                    matched = False
+                    for rule in self.rules:
+                        if keyword in rule["keywords"]:
+                            for key in rule["params"]:
+                                rule_params[key] = rule["params"][key]
+                            if "options" in rule:
+                                for key in rule["options"]:
+                                    rule_options[key] = rule["options"][key]
+                            matched = True
+                            break  # 一个关键词只匹配一个规则
+                    if not matched:
+                        logger.warning("[SD] keyword not matched: %s" % keyword)
+                
+                params = {**self.default_params, **rule_params}
+                options = {**self.default_options, **rule_options}
+                params["prompt"] = params.get("prompt", "")+f", {prompt}"
+                if len(options) > 0:
+                    logger.info("[SD] cover options={}".format(options))
+                    self.api.set_options(options)
+                logger.info("[SD] params={}".format(params))
+                result = self.api.txt2img(
+                    **params
+                )
+                reply.type = ReplyType.IMAGE
+                b_img = io.BytesIO()
+                result.image.save(b_img, format="PNG")
+                reply.content = b_img
+            e_context.action = EventAction.BREAK_PASS  # 事件结束后,跳过处理context的默认逻辑
+        except Exception as e:
+            reply.type = ReplyType.ERROR
+            reply.content = "[SD] "+str(e)
+            logger.error("[SD] exception: %s" % e)
+            e_context.action = EventAction.CONTINUE  # 事件继续,交付给下个插件或默认逻辑
+        finally:
+            e_context['reply'] = reply
+
+    def get_help_text(self):
+        if not conf().get('image_create_prefix'):
+            return "画图功能未启用"
+        else:
+            trigger = conf()['image_create_prefix'][0]
+        help_text = f"请使用<{trigger}[关键词1] [关键词2]...:提示语>的格式作画,如\"{trigger}横版 高清:cat\"\n"
+        help_text += "目前可用关键词:\n"
+        for rule in self.rules:
+            keywords = [f"[{keyword}]" for keyword in rule['keywords']]
+            help_text += f"{','.join(keywords)}"
+            if "desc" in rule:
+                help_text += f"-{rule['desc']}\n"
+            else:
+                help_text += "\n"
+        return help_text

+ 4 - 2
voice/baidu/baidu_voice.py

@@ -4,6 +4,7 @@ baidu voice service
 """
 import time
 from aip import AipSpeech
+from bridge.reply import Reply, ReplyType
 from common.log import logger
 from common.tmp_dir import TmpDir
 from voice.voice import Voice
@@ -30,7 +31,8 @@ class BaiduVoice(Voice):
             with open(fileName, 'wb') as f:
                 f.write(result)
             logger.info('[Baidu] textToVoice text={} voice file name={}'.format(text, fileName))
-            return fileName
+            reply = Reply(ReplyType.VOICE, fileName)
         else:
             logger.error('[Baidu] textToVoice error={}'.format(result))
-            return None
+            reply = Reply(ReplyType.ERROR, "抱歉,语音合成失败")
+        return reply

+ 17 - 10
voice/google/google_voice.py

@@ -6,6 +6,7 @@ google voice service
 import pathlib
 import subprocess
 import time
+from bridge.reply import Reply, ReplyType
 import speech_recognition
 import pyttsx3
 from common.log import logger
@@ -36,16 +37,22 @@ class GoogleVoice(Voice):
             text = self.recognizer.recognize_google(audio, language='zh-CN')
             logger.info(
                 '[Google] voiceToText text={} voice file name={}'.format(text, voice_file))
-            return text
+            reply = Reply(ReplyType.TEXT, text)
         except speech_recognition.UnknownValueError:
-            return "抱歉,我听不懂。"
+            reply = Reply(ReplyType.ERROR, "抱歉,我听不懂")
         except speech_recognition.RequestError as e:
-            return "抱歉,无法连接到 Google 语音识别服务;{0}".format(e)
-
+            reply = Reply(ReplyType.ERROR, "抱歉,无法连接到 Google 语音识别服务;{0}".format(e))
+        finally:
+            return reply
     def textToVoice(self, text):
-        textFile = TmpDir().path() + '语音回复_' + str(int(time.time())) + '.mp3'
-        self.engine.save_to_file(text, textFile)
-        self.engine.runAndWait()
-        logger.info(
-            '[Google] textToVoice text={} voice file name={}'.format(text, textFile))
-        return textFile
+        try:
+            textFile = TmpDir().path() + '语音回复_' + str(int(time.time())) + '.mp3'
+            self.engine.save_to_file(text, textFile)
+            self.engine.runAndWait()
+            logger.info(
+                '[Google] textToVoice text={} voice file name={}'.format(text, textFile))
+            reply = Reply(ReplyType.VOICE, textFile)
+        except Exception as e:
+            reply = Reply(ReplyType.ERROR, str(e))
+        finally:
+            return reply

+ 12 - 6
voice/openai/openai_voice.py

@@ -4,6 +4,7 @@ google voice service
 """
 import json
 import openai
+from bridge.reply import Reply, ReplyType
 from config import conf
 from common.log import logger
 from voice.voice import Voice
@@ -16,12 +17,17 @@ class OpenaiVoice(Voice):
     def voiceToText(self, voice_file):
         logger.debug(
             '[Openai] voice file name={}'.format(voice_file))
-        file = open(voice_file, "rb")
-        reply = openai.Audio.transcribe("whisper-1", file)
-        text = reply["text"]
-        logger.info(
-            '[Openai] voiceToText text={} voice file name={}'.format(text, voice_file))
-        return text
+        try:
+            file = open(voice_file, "rb")
+            result = openai.Audio.transcribe("whisper-1", file)
+            text = result["text"]
+            reply = Reply(ReplyType.TEXT, text)
+            logger.info(
+                '[Openai] voiceToText text={} voice file name={}'.format(text, voice_file))
+        except Exception as e:
+            reply = Reply(ReplyType.ERROR, str(e))
+        finally:
+            return reply
 
     def textToVoice(self, text):
         pass