Przeglądaj źródła

decouple message processing process

lanvent 3 lat temu
rodzic
commit
d6037422ac
3 zmienionych plików z 167 dodań i 136 usunięć
  1. 32 13
      bot/chatgpt/chat_gpt_bot.py
  2. 5 0
      bridge/bridge.py
  3. 130 123
      channel/wechat/wechat_channel.py

+ 32 - 13
bot/chatgpt/chat_gpt_bot.py

@@ -19,19 +19,24 @@ class ChatGPTBot(Bot):
 
     def reply(self, query, context=None):
         # acquire reply content
-        if not context or not context.get('type') or context.get('type') == 'TEXT':
+        if context['type']=='TEXT':
             logger.info("[OPEN_AI] query={}".format(query))
-            session_id = context.get('session_id') or context.get('from_user_id')
+            session_id = context['session_id']
+            reply=None
             if query == '#清除记忆':
                 self.sessions.clear_session(session_id)
-                return '记忆已清除'
+                reply={'type':'INFO', 'content':'记忆已清除'}
             elif query == '#清除所有':
                 self.sessions.clear_all_session()
-                return '所有人记忆已清除'
+                reply={'type':'INFO', 'content':'所有人记忆已清除'}
             elif query == '#更新配置':
                 load_config()
-                return '配置已更新'
-
+                reply={'type':'INFO', 'content':'配置已更新'}
+            elif query == '#DEBUG':
+                logger.setLevel('DEBUG')
+                reply={'type':'INFO', 'content':'DEBUG模式已开启'}
+            if reply:
+                return reply
             session = self.sessions.build_session_query(query, session_id)
             logger.debug("[OPEN_AI] session query={}".format(session))
 
@@ -41,12 +46,26 @@ class ChatGPTBot(Bot):
 
             reply_content = self.reply_text(session, session_id, 0)
             logger.debug("[OPEN_AI] new_query={}, session_id={}, reply_cont={}".format(session, session_id, reply_content["content"]))
-            if reply_content["completion_tokens"] > 0:
+            if reply_content['completion_tokens']==0 and len(reply_content['content'])>0:
+                reply={'type':'ERROR', 'content':reply_content['content']}
+            elif reply_content["completion_tokens"] > 0:
                 self.sessions.save_session(reply_content["content"], session_id, reply_content["total_tokens"])
-            return reply_content["content"]
+                reply={'type':'TEXT', 'content':reply_content["content"]}
+            else:
+                reply={'type':'ERROR', 'content':reply_content['content']}
+                logger.debug("[OPEN_AI] reply {} used 0 tokens.".format(reply_content))
+            return reply
 
-        elif context.get('type', None) == 'IMAGE_CREATE':
-            return self.create_img(query, 0)
+        elif context['type'] == 'IMAGE_CREATE':
+            ok, retstring=self.create_img(query, 0)
+            reply=None
+            if ok:
+                reply = {'type':'IMAGE', 'content':retstring}
+            else:
+                reply = {'type':'ERROR', 'content':retstring}
+            return reply
+        else:
+            reply= {'type':'ERROR', 'content':'Bot不支持处理{}类型的消息'.format(context['type'])}
 
     def reply_text(self, session, session_id, retry_count=0) ->dict:
         '''
@@ -104,7 +123,7 @@ class ChatGPTBot(Bot):
             )
             image_url = response['data'][0]['url']
             logger.info("[OPEN_AI] image_url={}".format(image_url))
-            return image_url
+            return True,image_url
         except openai.error.RateLimitError as e:
             logger.warn(e)
             if retry_count < 1:
@@ -112,10 +131,10 @@ class ChatGPTBot(Bot):
                 logger.warn("[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(retry_count+1))
                 return self.create_img(query, retry_count+1)
             else:
-                return "提问太快啦,请休息一下再问我吧"
+                return False,"提问太快啦,请休息一下再问我吧"
         except Exception as e:
             logger.exception(e)
-            return None
+            return False,str(e)
         
 class SessionManager(object):
     def __init__(self):

+ 5 - 0
bridge/bridge.py

@@ -15,6 +15,11 @@ class Bridge(object):
         except ModuleNotFoundError as e:
             print(e)
 
+
+    # 以下所有函数需要得到一个reply字典,格式如下:
+    # reply["type"] = "ERROR" / "TEXT" / "VOICE" / ...
+    # reply["content"] = reply的内容
+
     def fetch_reply_content(self, query, context):
         return self.bots["chat"].reply(query, context)
 

+ 130 - 123
channel/wechat/wechat_channel.py

@@ -46,62 +46,55 @@ class WechatChannel(Channel):
 
         # start message listener
         itchat.run()
+    
+    # handle_* 系列函数处理收到的消息后构造context,然后调用handle函数处理context
+    # context是一个字典,包含了消息的所有信息,包括以下key
+    #   type: 消息类型,包括TEXT、VOICE、CMD_IMAGE_CREATE
+    #   content: 消息内容,如果是TEXT类型,content就是文本内容,如果是VOICE类型,content就是语音文件名,如果是CMD_IMAGE_CREATE类型,content就是图片生成命令
+    #   session_id: 会话id
+    #   isgroup: 是否是群聊
+    #   msg: 原始消息对象
+    #   receiver: 需要回复的对象
 
     def handle_voice(self, msg):
         if conf().get('speech_recognition') != True :
             return
         logger.debug("[WX]receive voice msg: " + msg['FileName'])
-        thread_pool.submit(self._do_handle_voice, msg)
-
-    def _do_handle_voice(self, msg):
         from_user_id = msg['FromUserName']
         other_user_id = msg['User']['UserName']
         if from_user_id == other_user_id:
-            file_name = TmpDir().path() + msg['FileName']
-            msg.download(file_name)
-            query = super().build_voice_to_text(file_name)
-            if conf().get('voice_reply_voice'):
-                self._do_send_voice(query, from_user_id)
-            else:
-                self._do_send_text(query, from_user_id)
+            context = { 'isgroup': False, 'msg': msg, 'receiver': other_user_id}
+            context['type']='VOICE'
+            context['session_id']=other_user_id
+            thread_pool.submit(self.handle, context)
+
 
     def handle_text(self, msg):
         logger.debug("[WX]receive text msg: " + json.dumps(msg, ensure_ascii=False))
         content = msg['Text']
-        self._handle_single_msg(msg, content)
-
-    def _handle_single_msg(self, msg, content):
         from_user_id = msg['FromUserName']
         to_user_id = msg['ToUserName']              # 接收人id
         other_user_id = msg['User']['UserName']     # 对手方id
-        match_prefix = self.check_prefix(content, conf().get('single_chat_prefix'))
+        match_prefix = check_prefix(content, conf().get('single_chat_prefix'))
         if "」\n- - - - - - - - - - - - - - -" in content:
             logger.debug("[WX]reference query skipped")
             return
-        if from_user_id == other_user_id and match_prefix is not None:
-            # 好友向自己发送消息
-            if match_prefix != '':
-                str_list = content.split(match_prefix, 1)
-                if len(str_list) == 2:
-                    content = str_list[1].strip()
-
-            img_match_prefix = self.check_prefix(content, conf().get('image_create_prefix'))
-            if img_match_prefix:
-                content = content.split(img_match_prefix, 1)[1].strip()
-                thread_pool.submit(self._do_send_img, content, from_user_id)
-            else :
-                thread_pool.submit(self._do_send_text, content, from_user_id)
-        elif to_user_id == other_user_id and match_prefix:
-            # 自己给好友发送消息
-            str_list = content.split(match_prefix, 1)
-            if len(str_list) == 2:
-                content = str_list[1].strip()
-            img_match_prefix = self.check_prefix(content, conf().get('image_create_prefix'))
-            if img_match_prefix:
-                content = content.split(img_match_prefix, 1)[1].strip()
-                thread_pool.submit(self._do_send_img, content, to_user_id)
-            else:
-                thread_pool.submit(self._do_send_text, content, to_user_id)
+        if match_prefix:
+            content=content.replace(match_prefix,'',1).strip()
+        else:
+            return
+        context = { 'isgroup': False, 'msg': msg, 'receiver': other_user_id}
+        context['session_id']=other_user_id
+        
+        img_match_prefix = check_prefix(content, conf().get('image_create_prefix'))
+        if img_match_prefix:
+            content=content.replace(img_match_prefix,'',1).strip()
+            context['type']='CMD_IMAGE_CREATE'
+        else:
+            context['type']='TEXT'
+        
+        context['content']=content
+        thread_pool.submit(self.handle, context)
 
 
     def handle_group(self, msg):
@@ -122,100 +115,114 @@ class WechatChannel(Channel):
             logger.debug("[WX]reference query skipped")
             return ""
         config = conf()
-        match_prefix = (msg['IsAt'] and not config.get("group_at_off", False)) or self.check_prefix(origin_content, config.get('group_chat_prefix')) \
-                       or self.check_contain(origin_content, config.get('group_chat_keyword'))
-        if ('ALL_GROUP' in config.get('group_name_white_list') or group_name in config.get('group_name_white_list') or self.check_contain(group_name, config.get('group_name_keyword_white_list'))) and match_prefix:
-            img_match_prefix = self.check_prefix(content, conf().get('image_create_prefix'))
+        match_prefix = (msg['IsAt'] and not config.get("group_at_off", False)) or check_prefix(origin_content, config.get('group_chat_prefix')) \
+                       or check_contain(origin_content, config.get('group_chat_keyword'))
+        if ('ALL_GROUP' in config.get('group_name_white_list') or group_name in config.get('group_name_white_list') or check_contain(group_name, config.get('group_name_keyword_white_list'))) and match_prefix:
+            context = { 'isgroup': True, 'msg': msg, 'receiver': group_id}
+            
+            img_match_prefix = check_prefix(content, conf().get('image_create_prefix'))
             if img_match_prefix:
-                content = content.split(img_match_prefix, 1)[1].strip()
-                thread_pool.submit(self._do_send_img, content, group_id)
+                content=content.replace(img_match_prefix,'',1).strip()
+                context['type']='CMD_IMAGE_CREATE'
             else:
-                thread_pool.submit(self._do_send_group, content, msg)
-
-    def send(self, msg, receiver):
-        itchat.send(msg, toUserName=receiver)
-        logger.info('[WX] sendMsg={}, receiver={}'.format(msg, receiver))
-
-    def _do_send_voice(self, query, reply_user_id):
-        try:
-            if not query:
-                return
-            context = dict()
-            context['from_user_id'] = reply_user_id
-            reply_text = super().build_reply_content(query, context)
-            if reply_text:
-                replyFile = super().build_text_to_voice(reply_text)
-                itchat.send_file(replyFile, toUserName=reply_user_id)
-                logger.info('[WX] sendFile={}, receiver={}'.format(replyFile, reply_user_id))
-        except Exception as e:
-            logger.exception(e)
-
-    def _do_send_text(self, query, reply_user_id):
-        try:
-            if not query:
-                return
-            context = dict()
-            context['session_id'] = reply_user_id
-            reply_text = super().build_reply_content(query, context)
-            if reply_text:
-                self.send(conf().get("single_chat_reply_prefix") + reply_text, reply_user_id)
-        except Exception as e:
-            logger.exception(e)
-
-    def _do_send_img(self, query, reply_user_id):
-        try:
-            if not query:
-                return
-            context = dict()
-            context['type'] = 'IMAGE_CREATE'
-            img_url = super().build_reply_content(query, context)
-            if not img_url:
-                return
-
-            # 图片下载
+                context['type']='TEXT'
+            context['content']=content
+
+            group_chat_in_one_session = conf().get('group_chat_in_one_session', [])
+            if ('ALL_GROUP' in group_chat_in_one_session or \
+                    group_name in group_chat_in_one_session or \
+                    check_contain(group_name, group_chat_in_one_session)):
+                context['session_id'] = group_id
+            else:
+                context['session_id'] = msg['ActualUserName']
+            
+            thread_pool.submit(self.handle, context)
+    
+    # 统一的发送函数,根据reply的type字段发送不同类型的消息
+
+    def send(self, reply, receiver):
+        if reply['type']=='TEXT':
+            itchat.send(reply['content'], toUserName=receiver)
+            logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver))
+        elif reply['type']=='ERROR' or reply['type']=='INFO':
+            itchat.send(reply['content'], toUserName=receiver)
+            logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver))
+        elif reply['type']=='VOICE':
+            itchat.send_file(reply['content'], toUserName=receiver)
+            logger.info('[WX] sendFile={}, receiver={}'.format(reply['content'], receiver))
+        elif reply['type']=='IMAGE_URL': # 从网络下载图片
+            img_url = reply['content']
             pic_res = requests.get(img_url, stream=True)
             image_storage = io.BytesIO()
             for block in pic_res.iter_content(1024):
                 image_storage.write(block)
             image_storage.seek(0)
-
-            # 图片发送
-            itchat.send_image(image_storage, reply_user_id)
-            logger.info('[WX] sendImage, receiver={}'.format(reply_user_id))
-        except Exception as e:
-            logger.exception(e)
-
-    def _do_send_group(self, query, msg):
-        if not query:
-            return
-        context = dict()
-        group_name = msg['User']['NickName']
-        group_id = msg['User']['UserName']
-        group_chat_in_one_session = conf().get('group_chat_in_one_session', [])
-        if ('ALL_GROUP' in group_chat_in_one_session or \
-                group_name in group_chat_in_one_session or \
-                self.check_contain(group_name, group_chat_in_one_session)):
-            context['session_id'] = group_id
+            itchat.send_image(image_storage, toUserName=receiver)
+            logger.info('[WX] sendImage url=, receiver={}'.format(img_url,receiver))
+        elif reply['type']=='IMAGE': # 从文件读取图片
+            image_storage = reply['content']
+            image_storage.seek(0)
+            itchat.send_image(image_storage, toUserName=receiver)
+            logger.info('[WX] sendImage, receiver={}'.format(receiver))
+        
+    # 处理消息
+    def handle(self, context):
+        content=context['content']
+        reply=None
+
+        logger.debug('[WX] ready to handle context: {}'.format(context))
+        # reply的构建步骤
+        if context['type']=='TEXT' or context['type']=='CMD_IMAGE_CREATE':
+            reply = super().build_reply_content(content,context)
+        elif context['type']=='VOICE':
+            msg=context['msg']
+            file_name = TmpDir().path() + msg['FileName']
+            msg.download(file_name)
+            reply = super().build_voice_to_text(file_name)
+            if reply['type'] != 'ERROR' and reply['type'] != 'INFO':
+                reply = super().build_reply_content(reply['content'],context)
+                if reply['type']=='TEXT': 
+                    if conf().get('voice_reply_voice'):
+                        reply = super().build_text_to_voice(reply['content'])
         else:
-            context['session_id'] = msg['ActualUserName']
-        reply_text = super().build_reply_content(query, context)
-        if reply_text:
-            reply_text = '@' + msg['ActualNickName'] + ' ' + reply_text.strip()
-            self.send(conf().get("group_chat_reply_prefix", "") + reply_text, group_id)
-
-
-    def check_prefix(self, content, prefix_list):
-        for prefix in prefix_list:
-            if content.startswith(prefix):
-                return prefix
-        return None
+            logger.error('[WX] unknown context type: {}'.format(context['type']))
+            return
+        
+        logger.debug('[WX] ready to decorate reply: {}'.format(reply))
+        # reply的包装步骤
+        if reply:
+            if reply['type']=='TEXT':
+                reply_text=reply['content']
+                if context['isgroup']:
+                    reply_text = '@' + context['msg']['ActualNickName'] + ' ' + reply_text.strip()
+                    reply_text=conf().get("group_chat_reply_prefix","")+reply_text
+                else:
+                    reply_text=conf().get("single_chat_reply_prefix","")+reply_text
+                reply['content']=reply_text
+            elif reply['type']=='ERROR' or reply['type']=='INFO':
+                reply['content']=reply['type']+": "+ reply['content']
+            elif reply['type']=='IMAGE_URL' or reply['type']=='VOICE':
+                pass
+            else:
+                logger.error('[WX] unknown reply type: {}'.format(reply['type']))
+                return
+        if reply:
+            logger.debug('[WX] ready to send reply: {} to {}'.format(reply,context['receiver']))
+            self.send(reply, context['receiver'])
+        
+
+def check_prefix(content, prefix_list):
+    for prefix in prefix_list:
+        if content.startswith(prefix):
+            return prefix
+    return None
 
 
-    def check_contain(self, content, keyword_list):
-        if not keyword_list:
-            return None
-        for ky in keyword_list:
-            if content.find(ky) != -1:
-                return True
+def check_contain(content, keyword_list):
+    if not keyword_list:
         return None
+    for ky in keyword_list:
+        if content.find(ky) != -1:
+            return True
+    return None