Browse Source

feat: refactor handle function

lanvent 3 năm trước cách đây
mục cha
commit
83136e3142
2 tập tin đã thay đổi với 115 bổ sung98 xóa
  1. 15 0
      bridge/context.py
  2. 100 98
      channel/wechat/wechat_channel.py

+ 15 - 0
bridge/context.py

@@ -14,6 +14,15 @@ class Context:
         self.type = type
         self.content = content
         self.kwargs = kwargs
+
+    def __contains__(self, key):
+        if key == 'type':
+            return self.type is not None
+        elif key == 'content':
+            return self.content is not None
+        else:
+            return key in self.kwargs
+        
     def __getitem__(self, key):
         if key == 'type':
             return self.type
@@ -21,6 +30,12 @@ class Context:
             return self.content
         else:
             return self.kwargs[key]
+    
+    def get(self, key, default=None):
+        try:
+            return self[key]
+        except KeyError:
+            return default
 
     def __setitem__(self, key, value):
         if key == 'type':

+ 100 - 98
channel/wechat/wechat_channel.py

@@ -90,6 +90,8 @@ class WechatChannel(Channel):
     #        isgroup: 是否是群聊
     #        receiver: 需要回复的对象
     #        msg: itchat的原始消息对象
+    #        origin_ctype: 原始消息类型,用于私聊语音消息时,避免匹配前缀
+    #        desire_rtype: 希望回复类型,TEXT类型是文本回复,VOICE类型是语音回复
 
     def handle_voice(self, msg):
         if conf().get('speech_recognition') != True:
@@ -106,9 +108,9 @@ class WechatChannel(Channel):
             else:
                 other_user_id = from_user_id
         if from_user_id == other_user_id:
-            context = Context(ContextType.VOICE,msg['FileName'])
-            context.kwargs = {'isgroup': False, 'msg': msg, 'receiver': other_user_id, 'session_id': other_user_id}
-            thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback)
+            context = self._compose_context(ContextType.VOICE, msg['FileName'], isgroup=False, msg=msg, receiver=other_user_id, session_id=other_user_id)
+            if context:
+                thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback)
 
     @time_checker
     def handle_text(self, msg):
@@ -125,30 +127,16 @@ class WechatChannel(Channel):
             else:
                 other_user_id = from_user_id
         create_time = msg['CreateTime']             # 消息时间
-        match_prefix = check_prefix(content, conf().get('single_chat_prefix'))
         if conf().get('hot_reload') == True and int(create_time) < int(time.time()) - 60:  # 跳过1分钟前的历史消息
             logger.debug("[WX]history message skipped")
             return
         if "」\n- - - - - - - - - - - - - - -" in content:
             logger.debug("[WX]reference query skipped")
             return
-        if match_prefix:
-            content = content.replace(match_prefix, '', 1).strip()
-        elif match_prefix is None:
-            return
-        context = Context()
-        context.kwargs = {'isgroup': False, 'msg': msg,
-                          'receiver': other_user_id, 'session_id': other_user_id}
-
-        img_match_prefix = check_prefix(content, conf().get('image_create_prefix'))
-        if img_match_prefix:
-            content = content.replace(img_match_prefix, '', 1).strip()
-            context.type = ContextType.IMAGE_CREATE
-        else:
-            context.type = ContextType.TEXT
-
-        context.content = content
-        thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback)
+        
+        context = self._compose_context(ContextType.TEXT, content, isgroup=False, msg=msg, receiver=other_user_id, session_id=other_user_id)
+        if context:
+            thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback)
 
     @time_checker
     def handle_group(self, msg):
@@ -172,30 +160,19 @@ class WechatChannel(Channel):
         if "」\n- - - - - - - - - - - - - - -" in content:
             logger.debug("[WX]reference query skipped")
             return ""
-        config = conf()
-        match_prefix = (msg['IsAt'] and not config.get("group_at_off", False)) or check_prefix(origin_content, config.get('group_chat_prefix')) \
-            or check_contain(origin_content, config.get('group_chat_keyword'))
-        if ('ALL_GROUP' in config.get('group_name_white_list') or group_name in config.get('group_name_white_list') or check_contain(group_name, config.get('group_name_keyword_white_list'))) and match_prefix:
-            context = Context()
-            context.kwargs = { 'isgroup': True, 'msg': msg, 'receiver': group_id}
 
-            img_match_prefix = check_prefix(content, conf().get('image_create_prefix'))
-            if img_match_prefix:
-                content = content.replace(img_match_prefix, '', 1).strip()
-                context.type = ContextType.IMAGE_CREATE
-            else:
-                context.type = ContextType.TEXT
-            context.content = content
+        config = conf()
+        group_name_white_list = config.get('group_name_white_list', [])
+        group_name_keyword_white_list = config.get('group_name_keyword_white_list', [])
 
+        if any([group_name in group_name_white_list, 'ALL_GROUP' in group_name_white_list, check_contain(group_name, group_name_keyword_white_list), msg['IsAt'] and not config.get("group_at_off", False)]):
             group_chat_in_one_session = conf().get('group_chat_in_one_session', [])
-            if ('ALL_GROUP' in group_chat_in_one_session or
-                    group_name in group_chat_in_one_session or
-                    check_contain(group_name, group_chat_in_one_session)):
-                context['session_id'] = group_id
-            else:
-                context['session_id'] = msg['ActualUserName']
-
-            thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback)
+            session_id = msg['ActualUserName']
+            if any([group_name in group_chat_in_one_session, 'ALL_GROUP' in group_chat_in_one_session]):
+                session_id = group_id
+            context = self._compose_context(ContextType.TEXT, content, isgroup=True, msg=msg, receiver=group_id, session_id=session_id)
+            if context:
+                thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback)
 
     def handle_group_voice(self, msg):
         if conf().get('group_speech_recognition', False) != True:
@@ -210,20 +187,57 @@ class WechatChannel(Channel):
         # 验证群名
         if not group_name:
             return ""
-        if ('ALL_GROUP' in conf().get('group_name_white_list') or group_name in conf().get('group_name_white_list') or check_contain(group_name, conf().get('group_name_keyword_white_list'))):
-            context = Context(ContextType.VOICE,msg['FileName'])
-            context.kwargs = {'isgroup': True, 'msg': msg, 'receiver': group_id}
-
+        
+        config = conf()
+        group_name_white_list = config.get('group_name_white_list', [])
+        group_name_keyword_white_list = config.get('group_name_keyword_white_list', [])
+        if any([group_name in group_name_white_list, 'ALL_GROUP' in group_name_white_list, check_contain(group_name, group_name_keyword_white_list)]):
             group_chat_in_one_session = conf().get('group_chat_in_one_session', [])
-            if ('ALL_GROUP' in group_chat_in_one_session or
-                    group_name in group_chat_in_one_session or
-                    check_contain(group_name, group_chat_in_one_session)):
-                context['session_id'] = group_id
+            session_id =msg['ActualUserName']
+            if any([group_name in group_chat_in_one_session, 'ALL_GROUP' in group_chat_in_one_session]):
+                session_id = group_id
+            context = self._compose_context(ContextType.VOICE, msg['FileName'], isgroup=True, msg=msg, receiver=group_id, session_id=session_id)
+            if context:
+                thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback)
+
+    def _compose_context(self, ctype: ContextType, content, **kwargs):
+        context = Context(ctype, content)
+        context.kwargs = kwargs
+        if 'origin_ctype' not in context:
+            context['origin_ctype'] = ctype
+
+        if ctype == ContextType.TEXT:
+            if context["isgroup"]: # 群聊
+                # 校验关键字
+                match_prefix = check_prefix(content, conf().get('group_chat_prefix'))
+                match_contain = check_contain(content, conf().get('group_chat_keyword'))
+                if match_prefix is not None or match_contain is not None:
+                    # 判断如果匹配到自定义前缀,则返回过滤掉前缀+空格后的内容,用于实现类似自定义+前缀触发生成AI图片的功能
+                    if match_prefix:
+                        content = content.replace(match_prefix, '', 1).strip()
+                elif context["origin_ctype"] == ContextType.VOICE:
+                    logger.info("[WX]receive group voice, checkprefix didn't match")
+                    return None
+            else: # 单聊
+                match_prefix = check_prefix(content, conf().get('single_chat_prefix'))  
+                if match_prefix: # 判断如果匹配到自定义前缀,则返回过滤掉前缀+空格后的内容
+                    content = content.replace(match_prefix, '', 1).strip()
+                elif context["origin_ctype"] == ContextType.VOICE: # 如果源消息是私聊的语音消息,不匹配前缀,直接返回
+                    pass
+                else:
+                    return None                                       
+            img_match_prefix = check_prefix(content, conf().get('image_create_prefix'))
+            if img_match_prefix:
+                content = content.replace(img_match_prefix, '', 1).strip()
+                context.type = ContextType.IMAGE_CREATE
             else:
-                context['session_id'] = msg['ActualUserName']
-
-            thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback)
-
+                context.type = ContextType.TEXT
+            context.content = content
+        elif context.type == ContextType.VOICE:
+            if 'desire_rtype' not in context and conf().get('voice_reply_voice'):
+                context['desire_rtype'] = ReplyType.VOICE
+        return context
+    
     # 统一的发送函数,每个Channel自行实现,根据reply的type字段发送不同类型的消息
     def send(self, reply: Reply, receiver, retry_cnt = 0):
         try:
@@ -257,23 +271,29 @@ class WechatChannel(Channel):
                 self.send(reply, receiver, retry_cnt + 1)
 
     # 处理消息 TODO: 如果wechaty解耦,此处逻辑可以放置到父类
-    def handle(self, context):
-        if not context.content:
-            return 
-        
-        reply = Reply()
-
+    def handle(self, context: Context):
+        if context is None or not context.content:
+            return
         logger.debug('[WX] ready to handle context: {}'.format(context))
-
         # reply的构建步骤
+        reply = self._generate_reply(context)
+
+        logger.debug('[WX] ready to decorate reply: {}'.format(reply))
+        # reply的包装步骤
+        reply = self._decorate_reply(context, reply)
+
+        # reply的发送步骤
+        self._send_reply(context, reply)
+
+    def _generate_reply(self, context: Context, reply: Reply = Reply()) -> Reply:
         e_context = PluginManager().emit_event(EventContext(Event.ON_HANDLE_CONTEXT, {
             'channel': self, 'context': context, 'reply': reply}))
         reply = e_context['reply']
         if not e_context.is_pass():
             logger.debug('[WX] ready to handle context: type={}, content={}'.format(context.type, context.content))
-            if context.type == ContextType.TEXT or context.type == ContextType.IMAGE_CREATE: # 文字和图片消息
+            if context.type == ContextType.TEXT or context.type == ContextType.IMAGE_CREATE:  # 文字和图片消息
                 reply = super().build_reply_content(context.content, context)
-            elif context.type == ContextType.VOICE: # 语音消息
+            elif context.type == ContextType.VOICE:  # 语音消息
                 msg = context['msg']
                 mp3_path = TmpDir().path() + context.content
                 msg.download(mp3_path)
@@ -281,7 +301,7 @@ class WechatChannel(Channel):
                 wav_path = os.path.splitext(mp3_path)[0] + '.wav'
                 try:
                     mp3_to_wav(mp3_path=mp3_path, wav_path=wav_path)
-                except Exception as e: # 转换失败,直接使用mp3,对于某些api,mp3也可以识别
+                except Exception as e:  # 转换失败,直接使用mp3,对于某些api,mp3也可以识别
                     logger.warning("[WX]mp3 to wav error, use mp3 path. " + str(e))
                     wav_path = mp3_path
                 # 语音识别
@@ -293,50 +313,28 @@ class WechatChannel(Channel):
                 except Exception as e:
                     logger.warning("[WX]delete temp file error: " + str(e))
 
-                if reply.type != ReplyType.ERROR and reply.type != ReplyType.INFO:
-                    content = reply.content  # 语音转文字后,将文字内容作为新的context
-                    context.type = ContextType.TEXT
-                    if context["isgroup"]: # 群聊
-                        # 校验关键字
-                        match_prefix = check_prefix(content, conf().get('group_chat_prefix'))
-                        match_contain = check_contain(content, conf().get('group_chat_keyword'))
-                        if match_prefix is not None or match_contain is not None:
-                            # 判断如果匹配到自定义前缀,则返回过滤掉前缀+空格后的内容,用于实现类似自定义+前缀触发生成AI图片的功能
-                            if match_prefix:
-                                content = content.replace(match_prefix, '', 1).strip()
-                        else:
-                            logger.info("[WX]receive voice, checkprefix didn't match")
-                            return
-                    else: # 单聊
-                        match_prefix = check_prefix(content, conf().get('single_chat_prefix'))  
-                        if match_prefix: # 判断如果匹配到自定义前缀,则返回过滤掉前缀+空格后的内容
-                            content = content.replace(match_prefix, '', 1).strip()
-                                               
-                    img_match_prefix = check_prefix(content, conf().get('image_create_prefix'))
-                    if img_match_prefix:
-                        content = content.replace(img_match_prefix, '', 1).strip()
-                        context.type = ContextType.IMAGE_CREATE
-                    else:
-                        context.type = ContextType.TEXT
-                    context.content = content
-                    reply = super().build_reply_content(context.content, context)
-                    if reply.type == ReplyType.TEXT:
-                        if conf().get('voice_reply_voice'):
-                            reply = super().build_text_to_voice(reply.content)
+                if reply.type == ReplyType.TEXT:
+                    new_context = self._compose_context(
+                        ContextType.TEXT, reply.content, **context.kwargs)
+                    if new_context:
+                        reply = self._generate_reply(new_context)
             else:
                 logger.error('[WX] unknown context type: {}'.format(context.type))
                 return
+        return reply
 
-        logger.debug('[WX] ready to decorate reply: {}'.format(reply))
-
-        # reply的包装步骤
+    def _decorate_reply(self, context: Context, reply: Reply) -> Reply:
         if reply and reply.type:
             e_context = PluginManager().emit_event(EventContext(Event.ON_DECORATE_REPLY, {
                 'channel': self, 'context': context, 'reply': reply}))
             reply = e_context['reply']
+            desire_rtype = context.get('desire_rtype')
             if not e_context.is_pass() and reply and reply.type:
                 if reply.type == ReplyType.TEXT:
                     reply_text = reply.content
+                    if desire_rtype == ReplyType.VOICE:
+                        reply = super().build_text_to_voice(reply.content)
+                        return self._decorate_reply(context, reply)
                     if context['isgroup']:
                         reply_text = '@' +  context['msg']['ActualNickName'] + ' ' + reply_text.strip()
                         reply_text = conf().get("group_chat_reply_prefix", "")+reply_text
@@ -350,8 +348,11 @@ class WechatChannel(Channel):
                 else:
                     logger.error('[WX] unknown reply type: {}'.format(reply.type))
                     return
+            if desire_rtype and desire_rtype != reply.type and reply.type not in [ReplyType.ERROR, ReplyType.INFO]:
+                logger.warning('[WX] desire_rtype: {}, but reply type: {}'.format(context.get('desire_rtype'), reply.type))
+            return reply
 
-        # reply的发送步骤
+    def _send_reply(self, context: Context, reply: Reply):
         if reply and reply.type:
             e_context = PluginManager().emit_event(EventContext(Event.ON_SEND_REPLY, {
                 'channel': self, 'context': context, 'reply': reply}))
@@ -360,6 +361,7 @@ class WechatChannel(Channel):
                 logger.debug('[WX] ready to send reply: {} to {}'.format(reply, context['receiver']))
                 self.send(reply, context['receiver'])
 
+
 def check_prefix(content, prefix_list):
     for prefix in prefix_list:
         if content.startswith(prefix):