Forráskód Böngészése

formatting: run precommit on all files

lanvent 3 éve
szülő
commit
618c94edb8
40 módosított fájl, 228 hozzáadás és 646 törlés
  1. 1 0
      app.py
  2. 2 10
      bot/baidu/baidu_unit_bot.py
  3. 7 22
      bot/chatgpt/chat_gpt_bot.py
  4. 5 17
      bot/chatgpt/chat_gpt_session.py
  5. 7 21
      bot/openai/open_ai_bot.py
  6. 2 8
      bot/openai/open_ai_image.py
  7. 3 13
      bot/openai/open_ai_session.py
  8. 4 16
      bot/session_manager.py
  9. 1 3
      bridge/context.py
  10. 28 103
      channel/chat_channel.py
  11. 1 3
      channel/terminal/terminal_channel.py
  12. 8 31
      channel/wechat/wechat_channel.py
  13. 6 20
      channel/wechat/wechat_message.py
  14. 5 15
      channel/wechat/wechaty_channel.py
  15. 3 9
      channel/wechat/wechaty_message.py
  16. 8 17
      channel/wechatmp/active_reply.py
  17. 5 3
      channel/wechatmp/common.py
  18. 18 36
      channel/wechatmp/passive_reply.py
  19. 20 27
      channel/wechatmp/wechatmp_channel.py
  20. 10 11
      channel/wechatmp/wechatmp_client.py
  21. 5 14
      channel/wechatmp/wechatmp_message.py
  22. 3 11
      common/time_check.py
  23. 1 3
      config.py
  24. 3 9
      plugins/banwords/banwords.py
  25. 7 32
      plugins/bdunit/bdunit.py
  26. 2 8
      plugins/dungeon/dungeon.py
  27. 7 23
      plugins/godcmd/godcmd.py
  28. 2 6
      plugins/hello/hello.py
  29. 2 2
      plugins/keyword/config.json.template
  30. 1 3
      plugins/keyword/keyword.py
  31. 14 46
      plugins/plugin_manager.py
  32. 7 24
      plugins/role/role.py
  33. 1 1
      plugins/tool/README.md
  34. 4 10
      plugins/tool/tool.py
  35. 5 15
      voice/audio_convert.py
  36. 9 33
      voice/azure/azure_voice.py
  37. 1 3
      voice/baidu/baidu_voice.py
  38. 2 8
      voice/google/google_voice.py
  39. 1 5
      voice/openai/openai_voice.py
  40. 7 5
      voice/pytts/pytts_voice.py

+ 1 - 0
app.py

@@ -19,6 +19,7 @@ def sigterm_handler_wrap(_signo):
         if callable(old_handler):  #  check old_handler
             return old_handler(_signo, _stack_frame)
         sys.exit(0)
+
     signal.signal(_signo, func)
 
 

+ 2 - 10
bot/baidu/baidu_unit_bot.py

@@ -10,10 +10,7 @@ from bridge.reply import Reply, ReplyType
 class BaiduUnitBot(Bot):
     def reply(self, query, context=None):
         token = self.get_token()
-        url = (
-            "https://aip.baidubce.com/rpc/2.0/unit/service/v3/chat?access_token="
-            + token
-        )
+        url = "https://aip.baidubce.com/rpc/2.0/unit/service/v3/chat?access_token=" + token
         post_data = (
             '{"version":"3.0","service_id":"S73177","session_id":"","log_id":"7758521","skill_ids":["1221886"],"request":{"terminal_id":"88888","query":"'
             + query
@@ -32,12 +29,7 @@ class BaiduUnitBot(Bot):
     def get_token(self):
         access_key = "YOUR_ACCESS_KEY"
         secret_key = "YOUR_SECRET_KEY"
-        host = (
-            "https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id="
-            + access_key
-            + "&client_secret="
-            + secret_key
-        )
+        host = "https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=" + access_key + "&client_secret=" + secret_key
         response = requests.get(host)
         if response:
             print(response.json())

+ 7 - 22
bot/chatgpt/chat_gpt_bot.py

@@ -30,23 +30,15 @@ class ChatGPTBot(Bot, OpenAIImage):
         if conf().get("rate_limit_chatgpt"):
             self.tb4chatgpt = TokenBucket(conf().get("rate_limit_chatgpt", 20))
 
-        self.sessions = SessionManager(
-            ChatGPTSession, model=conf().get("model") or "gpt-3.5-turbo"
-        )
+        self.sessions = SessionManager(ChatGPTSession, model=conf().get("model") or "gpt-3.5-turbo")
         self.args = {
             "model": conf().get("model") or "gpt-3.5-turbo",  # 对话模型的名称
             "temperature": conf().get("temperature", 0.9),  # 值在[0,1]之间,越大表示回复越具有不确定性
             # "max_tokens":4096,  # 回复最大的字符数
             "top_p": 1,
-            "frequency_penalty": conf().get(
-                "frequency_penalty", 0.0
-            ),  # [-2,2]之间,该值越大则更倾向于产生不同的内容
-            "presence_penalty": conf().get(
-                "presence_penalty", 0.0
-            ),  # [-2,2]之间,该值越大则更倾向于产生不同的内容
-            "request_timeout": conf().get(
-                "request_timeout", None
-            ),  # 请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间
+            "frequency_penalty": conf().get("frequency_penalty", 0.0),  # [-2,2]之间,该值越大则更倾向于产生不同的内容
+            "presence_penalty": conf().get("presence_penalty", 0.0),  # [-2,2]之间,该值越大则更倾向于产生不同的内容
+            "request_timeout": conf().get("request_timeout", None),  # 请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间
             "timeout": conf().get("request_timeout", None),  # 重试超时时间,在这个时间内,将会自动重试
         }
 
@@ -87,15 +79,10 @@ class ChatGPTBot(Bot, OpenAIImage):
                     reply_content["completion_tokens"],
                 )
             )
-            if (
-                reply_content["completion_tokens"] == 0
-                and len(reply_content["content"]) > 0
-            ):
+            if reply_content["completion_tokens"] == 0 and len(reply_content["content"]) > 0:
                 reply = Reply(ReplyType.ERROR, reply_content["content"])
             elif reply_content["completion_tokens"] > 0:
-                self.sessions.session_reply(
-                    reply_content["content"], session_id, reply_content["total_tokens"]
-                )
+                self.sessions.session_reply(reply_content["content"], session_id, reply_content["total_tokens"])
                 reply = Reply(ReplyType.TEXT, reply_content["content"])
             else:
                 reply = Reply(ReplyType.ERROR, reply_content["content"])
@@ -126,9 +113,7 @@ class ChatGPTBot(Bot, OpenAIImage):
             if conf().get("rate_limit_chatgpt") and not self.tb4chatgpt.get_token():
                 raise openai.error.RateLimitError("RateLimitError: rate limit exceeded")
             # if api_key == None, the default openai.api_key will be used
-            response = openai.ChatCompletion.create(
-                api_key=api_key, messages=session.messages, **self.args
-            )
+            response = openai.ChatCompletion.create(api_key=api_key, messages=session.messages, **self.args)
             # logger.info("[ChatGPT] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"]))
             return {
                 "total_tokens": response["usage"]["total_tokens"],

+ 5 - 17
bot/chatgpt/chat_gpt_session.py

@@ -25,9 +25,7 @@ class ChatGPTSession(Session):
             precise = False
             if cur_tokens is None:
                 raise e
-            logger.debug(
-                "Exception when counting tokens precisely for query: {}".format(e)
-            )
+            logger.debug("Exception when counting tokens precisely for query: {}".format(e))
         while cur_tokens > max_tokens:
             if len(self.messages) > 2:
                 self.messages.pop(1)
@@ -39,16 +37,10 @@ class ChatGPTSession(Session):
                     cur_tokens = cur_tokens - max_tokens
                 break
             elif len(self.messages) == 2 and self.messages[1]["role"] == "user":
-                logger.warn(
-                    "user message exceed max_tokens. total_tokens={}".format(cur_tokens)
-                )
+                logger.warn("user message exceed max_tokens. total_tokens={}".format(cur_tokens))
                 break
             else:
-                logger.debug(
-                    "max_tokens={}, total_tokens={}, len(messages)={}".format(
-                        max_tokens, cur_tokens, len(self.messages)
-                    )
-                )
+                logger.debug("max_tokens={}, total_tokens={}, len(messages)={}".format(max_tokens, cur_tokens, len(self.messages)))
                 break
             if precise:
                 cur_tokens = self.calc_tokens()
@@ -75,17 +67,13 @@ def num_tokens_from_messages(messages, model):
     elif model == "gpt-4":
         return num_tokens_from_messages(messages, model="gpt-4-0314")
     elif model == "gpt-3.5-turbo-0301":
-        tokens_per_message = (
-            4  # every message follows <|start|>{role/name}\n{content}<|end|>\n
-        )
+        tokens_per_message = 4  # every message follows <|start|>{role/name}\n{content}<|end|>\n
         tokens_per_name = -1  # if there's a name, the role is omitted
     elif model == "gpt-4-0314":
         tokens_per_message = 3
         tokens_per_name = 1
     else:
-        logger.warn(
-            f"num_tokens_from_messages() is not implemented for model {model}. Returning num tokens assuming gpt-3.5-turbo-0301."
-        )
+        logger.warn(f"num_tokens_from_messages() is not implemented for model {model}. Returning num tokens assuming gpt-3.5-turbo-0301.")
         return num_tokens_from_messages(messages, model="gpt-3.5-turbo-0301")
     num_tokens = 0
     for message in messages:

+ 7 - 21
bot/openai/open_ai_bot.py

@@ -28,23 +28,15 @@ class OpenAIBot(Bot, OpenAIImage):
         if proxy:
             openai.proxy = proxy
 
-        self.sessions = SessionManager(
-            OpenAISession, model=conf().get("model") or "text-davinci-003"
-        )
+        self.sessions = SessionManager(OpenAISession, model=conf().get("model") or "text-davinci-003")
         self.args = {
             "model": conf().get("model") or "text-davinci-003",  # 对话模型的名称
             "temperature": conf().get("temperature", 0.9),  # 值在[0,1]之间,越大表示回复越具有不确定性
             "max_tokens": 1200,  # 回复最大的字符数
             "top_p": 1,
-            "frequency_penalty": conf().get(
-                "frequency_penalty", 0.0
-            ),  # [-2,2]之间,该值越大则更倾向于产生不同的内容
-            "presence_penalty": conf().get(
-                "presence_penalty", 0.0
-            ),  # [-2,2]之间,该值越大则更倾向于产生不同的内容
-            "request_timeout": conf().get(
-                "request_timeout", None
-            ),  # 请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间
+            "frequency_penalty": conf().get("frequency_penalty", 0.0),  # [-2,2]之间,该值越大则更倾向于产生不同的内容
+            "presence_penalty": conf().get("presence_penalty", 0.0),  # [-2,2]之间,该值越大则更倾向于产生不同的内容
+            "request_timeout": conf().get("request_timeout", None),  # 请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间
             "timeout": conf().get("request_timeout", None),  # 重试超时时间,在这个时间内,将会自动重试
             "stop": ["\n\n\n"],
         }
@@ -71,17 +63,13 @@ class OpenAIBot(Bot, OpenAIImage):
                         result["content"],
                     )
                     logger.debug(
-                        "[OPEN_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(
-                            str(session), session_id, reply_content, completion_tokens
-                        )
+                        "[OPEN_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(str(session), session_id, reply_content, completion_tokens)
                     )
 
                     if total_tokens == 0:
                         reply = Reply(ReplyType.ERROR, reply_content)
                     else:
-                        self.sessions.session_reply(
-                            reply_content, session_id, total_tokens
-                        )
+                        self.sessions.session_reply(reply_content, session_id, total_tokens)
                         reply = Reply(ReplyType.TEXT, reply_content)
                 return reply
             elif context.type == ContextType.IMAGE_CREATE:
@@ -96,9 +84,7 @@ class OpenAIBot(Bot, OpenAIImage):
     def reply_text(self, session: OpenAISession, retry_count=0):
         try:
             response = openai.Completion.create(prompt=str(session), **self.args)
-            res_content = (
-                response.choices[0]["text"].strip().replace("<|endoftext|>", "")
-            )
+            res_content = response.choices[0]["text"].strip().replace("<|endoftext|>", "")
             total_tokens = response["usage"]["total_tokens"]
             completion_tokens = response["usage"]["completion_tokens"]
             logger.info("[OPEN_AI] reply={}".format(res_content))

+ 2 - 8
bot/openai/open_ai_image.py

@@ -23,9 +23,7 @@ class OpenAIImage(object):
             response = openai.Image.create(
                 prompt=query,  # 图片描述
                 n=1,  # 每次生成图片的数量
-                size=conf().get(
-                    "image_create_size", "256x256"
-                ),  # 图片大小,可选有 256x256, 512x512, 1024x1024
+                size=conf().get("image_create_size", "256x256"),  # 图片大小,可选有 256x256, 512x512, 1024x1024
             )
             image_url = response["data"][0]["url"]
             logger.info("[OPEN_AI] image_url={}".format(image_url))
@@ -34,11 +32,7 @@ class OpenAIImage(object):
             logger.warn(e)
             if retry_count < 1:
                 time.sleep(5)
-                logger.warn(
-                    "[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(
-                        retry_count + 1
-                    )
-                )
+                logger.warn("[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(retry_count + 1))
                 return self.create_img(query, retry_count + 1)
             else:
                 return False, "提问太快啦,请休息一下再问我吧"

+ 3 - 13
bot/openai/open_ai_session.py

@@ -36,9 +36,7 @@ class OpenAISession(Session):
             precise = False
             if cur_tokens is None:
                 raise e
-            logger.debug(
-                "Exception when counting tokens precisely for query: {}".format(e)
-            )
+            logger.debug("Exception when counting tokens precisely for query: {}".format(e))
         while cur_tokens > max_tokens:
             if len(self.messages) > 1:
                 self.messages.pop(0)
@@ -50,18 +48,10 @@ class OpenAISession(Session):
                     cur_tokens = len(str(self))
                 break
             elif len(self.messages) == 1 and self.messages[0]["role"] == "user":
-                logger.warn(
-                    "user question exceed max_tokens. total_tokens={}".format(
-                        cur_tokens
-                    )
-                )
+                logger.warn("user question exceed max_tokens. total_tokens={}".format(cur_tokens))
                 break
             else:
-                logger.debug(
-                    "max_tokens={}, total_tokens={}, len(conversation)={}".format(
-                        max_tokens, cur_tokens, len(self.messages)
-                    )
-                )
+                logger.debug("max_tokens={}, total_tokens={}, len(conversation)={}".format(max_tokens, cur_tokens, len(self.messages)))
                 break
             if precise:
                 cur_tokens = self.calc_tokens()

+ 4 - 16
bot/session_manager.py

@@ -55,9 +55,7 @@ class SessionManager(object):
             return self.sessioncls(session_id, system_prompt, **self.session_args)
 
         if session_id not in self.sessions:
-            self.sessions[session_id] = self.sessioncls(
-                session_id, system_prompt, **self.session_args
-            )
+            self.sessions[session_id] = self.sessioncls(session_id, system_prompt, **self.session_args)
         elif system_prompt is not None:  # 如果有新的system_prompt,更新并重置session
             self.sessions[session_id].set_system_prompt(system_prompt)
         session = self.sessions[session_id]
@@ -71,9 +69,7 @@ class SessionManager(object):
             total_tokens = session.discard_exceeding(max_tokens, None)
             logger.debug("prompt tokens used={}".format(total_tokens))
         except Exception as e:
-            logger.debug(
-                "Exception when counting tokens precisely for prompt: {}".format(str(e))
-            )
+            logger.debug("Exception when counting tokens precisely for prompt: {}".format(str(e)))
         return session
 
     def session_reply(self, reply, session_id, total_tokens=None):
@@ -82,17 +78,9 @@ class SessionManager(object):
         try:
             max_tokens = conf().get("conversation_max_tokens", 1000)
             tokens_cnt = session.discard_exceeding(max_tokens, total_tokens)
-            logger.debug(
-                "raw total_tokens={}, savesession tokens={}".format(
-                    total_tokens, tokens_cnt
-                )
-            )
+            logger.debug("raw total_tokens={}, savesession tokens={}".format(total_tokens, tokens_cnt))
         except Exception as e:
-            logger.debug(
-                "Exception when counting tokens precisely for session: {}".format(
-                    str(e)
-                )
-            )
+            logger.debug("Exception when counting tokens precisely for session: {}".format(str(e)))
         return session
 
     def clear_session(self, session_id):

+ 1 - 3
bridge/context.py

@@ -60,6 +60,4 @@ class Context:
             del self.kwargs[key]
 
     def __str__(self):
-        return "Context(type={}, content={}, kwargs={})".format(
-            self.type, self.content, self.kwargs
-        )
+        return "Context(type={}, content={}, kwargs={})".format(self.type, self.content, self.kwargs)

+ 28 - 103
channel/chat_channel.py

@@ -53,9 +53,7 @@ class ChatChannel(Channel):
                 group_id = cmsg.other_user_id
 
                 group_name_white_list = config.get("group_name_white_list", [])
-                group_name_keyword_white_list = config.get(
-                    "group_name_keyword_white_list", []
-                )
+                group_name_keyword_white_list = config.get("group_name_keyword_white_list", [])
                 if any(
                     [
                         group_name in group_name_white_list,
@@ -63,9 +61,7 @@ class ChatChannel(Channel):
                         check_contain(group_name, group_name_keyword_white_list),
                     ]
                 ):
-                    group_chat_in_one_session = conf().get(
-                        "group_chat_in_one_session", []
-                    )
+                    group_chat_in_one_session = conf().get("group_chat_in_one_session", [])
                     session_id = cmsg.actual_user_id
                     if any(
                         [
@@ -81,17 +77,11 @@ class ChatChannel(Channel):
             else:
                 context["session_id"] = cmsg.other_user_id
                 context["receiver"] = cmsg.other_user_id
-            e_context = PluginManager().emit_event(
-                EventContext(
-                    Event.ON_RECEIVE_MESSAGE, {"channel": self, "context": context}
-                )
-            )
+            e_context = PluginManager().emit_event(EventContext(Event.ON_RECEIVE_MESSAGE, {"channel": self, "context": context}))
             context = e_context["context"]
             if e_context.is_pass() or context is None:
                 return context
-            if cmsg.from_user_id == self.user_id and not config.get(
-                "trigger_by_self", True
-            ):
+            if cmsg.from_user_id == self.user_id and not config.get("trigger_by_self", True):
                 logger.debug("[WX]self message skipped")
                 return None
 
@@ -119,19 +109,13 @@ class ChatChannel(Channel):
 
                 if not flag:
                     if context["origin_ctype"] == ContextType.VOICE:
-                        logger.info(
-                            "[WX]receive group voice, but checkprefix didn't match"
-                        )
+                        logger.info("[WX]receive group voice, but checkprefix didn't match")
                     return None
             else:  # 单聊
-                match_prefix = check_prefix(
-                    content, conf().get("single_chat_prefix", [""])
-                )
+                match_prefix = check_prefix(content, conf().get("single_chat_prefix", [""]))
                 if match_prefix is not None:  # 判断如果匹配到自定义前缀,则返回过滤掉前缀+空格后的内容
                     content = content.replace(match_prefix, "", 1).strip()
-                elif (
-                    context["origin_ctype"] == ContextType.VOICE
-                ):  # 如果源消息是私聊的语音消息,允许不匹配前缀,放宽条件
+                elif context["origin_ctype"] == ContextType.VOICE:  # 如果源消息是私聊的语音消息,允许不匹配前缀,放宽条件
                     pass
                 else:
                     return None
@@ -143,18 +127,10 @@ class ChatChannel(Channel):
             else:
                 context.type = ContextType.TEXT
             context.content = content.strip()
-            if (
-                "desire_rtype" not in context
-                and conf().get("always_reply_voice")
-                and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE
-            ):
+            if "desire_rtype" not in context and conf().get("always_reply_voice") and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE:
                 context["desire_rtype"] = ReplyType.VOICE
         elif context.type == ContextType.VOICE:
-            if (
-                "desire_rtype" not in context
-                and conf().get("voice_reply_voice")
-                and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE
-            ):
+            if "desire_rtype" not in context and conf().get("voice_reply_voice") and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE:
                 context["desire_rtype"] = ReplyType.VOICE
 
         return context
@@ -182,15 +158,8 @@ class ChatChannel(Channel):
         )
         reply = e_context["reply"]
         if not e_context.is_pass():
-            logger.debug(
-                "[WX] ready to handle context: type={}, content={}".format(
-                    context.type, context.content
-                )
-            )
-            if (
-                context.type == ContextType.TEXT
-                or context.type == ContextType.IMAGE_CREATE
-            ):  # 文字和图片消息
+            logger.debug("[WX] ready to handle context: type={}, content={}".format(context.type, context.content))
+            if context.type == ContextType.TEXT or context.type == ContextType.IMAGE_CREATE:  # 文字和图片消息
                 reply = super().build_reply_content(context.content, context)
             elif context.type == ContextType.VOICE:  # 语音消息
                 cmsg = context["msg"]
@@ -214,9 +183,7 @@ class ChatChannel(Channel):
                     # logger.warning("[WX]delete temp file error: " + str(e))
 
                 if reply.type == ReplyType.TEXT:
-                    new_context = self._compose_context(
-                        ContextType.TEXT, reply.content, **context.kwargs
-                    )
+                    new_context = self._compose_context(ContextType.TEXT, reply.content, **context.kwargs)
                     if new_context:
                         reply = self._generate_reply(new_context)
                     else:
@@ -246,48 +213,24 @@ class ChatChannel(Channel):
 
                 if reply.type == ReplyType.TEXT:
                     reply_text = reply.content
-                    if (
-                        desire_rtype == ReplyType.VOICE
-                        and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE
-                    ):
+                    if desire_rtype == ReplyType.VOICE and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE:
                         reply = super().build_text_to_voice(reply.content)
                         return self._decorate_reply(context, reply)
                     if context.get("isgroup", False):
-                        reply_text = (
-                            "@"
-                            + context["msg"].actual_user_nickname
-                            + " "
-                            + reply_text.strip()
-                        )
-                        reply_text = (
-                            conf().get("group_chat_reply_prefix", "") + reply_text
-                        )
+                        reply_text = "@" + context["msg"].actual_user_nickname + " " + reply_text.strip()
+                        reply_text = conf().get("group_chat_reply_prefix", "") + reply_text
                     else:
-                        reply_text = (
-                            conf().get("single_chat_reply_prefix", "") + reply_text
-                        )
+                        reply_text = conf().get("single_chat_reply_prefix", "") + reply_text
                     reply.content = reply_text
                 elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
                     reply.content = "[" + str(reply.type) + "]\n" + reply.content
-                elif (
-                    reply.type == ReplyType.IMAGE_URL
-                    or reply.type == ReplyType.VOICE
-                    or reply.type == ReplyType.IMAGE
-                ):
+                elif reply.type == ReplyType.IMAGE_URL or reply.type == ReplyType.VOICE or reply.type == ReplyType.IMAGE:
                     pass
                 else:
                     logger.error("[WX] unknown reply type: {}".format(reply.type))
                     return
-            if (
-                desire_rtype
-                and desire_rtype != reply.type
-                and reply.type not in [ReplyType.ERROR, ReplyType.INFO]
-            ):
-                logger.warning(
-                    "[WX] desire_rtype: {}, but reply type: {}".format(
-                        context.get("desire_rtype"), reply.type
-                    )
-                )
+            if desire_rtype and desire_rtype != reply.type and reply.type not in [ReplyType.ERROR, ReplyType.INFO]:
+                logger.warning("[WX] desire_rtype: {}, but reply type: {}".format(context.get("desire_rtype"), reply.type))
             return reply
 
     def _send_reply(self, context: Context, reply: Reply):
@@ -300,9 +243,7 @@ class ChatChannel(Channel):
             )
             reply = e_context["reply"]
             if not e_context.is_pass() and reply and reply.type:
-                logger.debug(
-                    "[WX] ready to send reply: {}, context: {}".format(reply, context)
-                )
+                logger.debug("[WX] ready to send reply: {}, context: {}".format(reply, context))
                 self._send(reply, context)
 
     def _send(self, reply: Reply, context: Context, retry_cnt=0):
@@ -328,9 +269,7 @@ class ChatChannel(Channel):
             try:
                 worker_exception = worker.exception()
                 if worker_exception:
-                    self._fail_callback(
-                        session_id, exception=worker_exception, **kwargs
-                    )
+                    self._fail_callback(session_id, exception=worker_exception, **kwargs)
                 else:
                     self._success_callback(session_id, **kwargs)
             except CancelledError as e:
@@ -366,24 +305,14 @@ class ChatChannel(Channel):
                         if not context_queue.empty():
                             context = context_queue.get()
                             logger.debug("[WX] consume context: {}".format(context))
-                            future: Future = self.handler_pool.submit(
-                                self._handle, context
-                            )
-                            future.add_done_callback(
-                                self._thread_pool_callback(session_id, context=context)
-                            )
+                            future: Future = self.handler_pool.submit(self._handle, context)
+                            future.add_done_callback(self._thread_pool_callback(session_id, context=context))
                             if session_id not in self.futures:
                                 self.futures[session_id] = []
                             self.futures[session_id].append(future)
-                        elif (
-                            semaphore._initial_value == semaphore._value + 1
-                        ):  # 除了当前,没有任务再申请到信号量,说明所有任务都处理完毕
-                            self.futures[session_id] = [
-                                t for t in self.futures[session_id] if not t.done()
-                            ]
-                            assert (
-                                len(self.futures[session_id]) == 0
-                            ), "thread pool error"
+                        elif semaphore._initial_value == semaphore._value + 1:  # 除了当前,没有任务再申请到信号量,说明所有任务都处理完毕
+                            self.futures[session_id] = [t for t in self.futures[session_id] if not t.done()]
+                            assert len(self.futures[session_id]) == 0, "thread pool error"
                             del self.sessions[session_id]
                         else:
                             semaphore.release()
@@ -397,9 +326,7 @@ class ChatChannel(Channel):
                     future.cancel()
                 cnt = self.sessions[session_id][0].qsize()
                 if cnt > 0:
-                    logger.info(
-                        "Cancel {} messages in session {}".format(cnt, session_id)
-                    )
+                    logger.info("Cancel {} messages in session {}".format(cnt, session_id))
                 self.sessions[session_id][0] = Dequeue()
 
     def cancel_all_session(self):
@@ -409,9 +336,7 @@ class ChatChannel(Channel):
                     future.cancel()
                 cnt = self.sessions[session_id][0].qsize()
                 if cnt > 0:
-                    logger.info(
-                        "Cancel {} messages in session {}".format(cnt, session_id)
-                    )
+                    logger.info("Cancel {} messages in session {}".format(cnt, session_id))
                 self.sessions[session_id][0] = Dequeue()
 
 

+ 1 - 3
channel/terminal/terminal_channel.py

@@ -77,9 +77,7 @@ class TerminalChannel(ChatChannel):
             if check_prefix(prompt, trigger_prefixs) is None:
                 prompt = trigger_prefixs[0] + prompt  # 给没触发的消息加上触发前缀
 
-            context = self._compose_context(
-                ContextType.TEXT, prompt, msg=TerminalMessage(msg_id, prompt)
-            )
+            context = self._compose_context(ContextType.TEXT, prompt, msg=TerminalMessage(msg_id, prompt))
             if context:
                 self.produce(context)
             else:

+ 8 - 31
channel/wechat/wechat_channel.py

@@ -56,10 +56,7 @@ def _check(func):
             return
         self.receivedMsgs[msgId] = cmsg
         create_time = cmsg.create_time  # 消息时间戳
-        if (
-            conf().get("hot_reload") == True
-            and int(create_time) < int(time.time()) - 60
-        ):  # 跳过1分钟前的历史消息
+        if conf().get("hot_reload") == True and int(create_time) < int(time.time()) - 60:  # 跳过1分钟前的历史消息
             logger.debug("[WX]history message {} skipped".format(msgId))
             return
         return func(self, cmsg)
@@ -88,15 +85,9 @@ def qrCallback(uuid, status, qrcode):
         url = f"https://login.weixin.qq.com/l/{uuid}"
 
         qr_api1 = "https://api.isoyu.com/qr/?m=1&e=L&p=20&url={}".format(url)
-        qr_api2 = (
-            "https://api.qrserver.com/v1/create-qr-code/?size=400×400&data={}".format(
-                url
-            )
-        )
+        qr_api2 = "https://api.qrserver.com/v1/create-qr-code/?size=400×400&data={}".format(url)
         qr_api3 = "https://api.pwmqr.com/qrcode/create/?url={}".format(url)
-        qr_api4 = "https://my.tv.sohu.com/user/a/wvideo/getQRCode.do?text={}".format(
-            url
-        )
+        qr_api4 = "https://my.tv.sohu.com/user/a/wvideo/getQRCode.do?text={}".format(url)
         print("You can also scan QRCode in any website below:")
         print(qr_api3)
         print(qr_api4)
@@ -134,18 +125,12 @@ class WechatChannel(ChatChannel):
                 logger.error("Hot reload failed, try to login without hot reload")
                 itchat.logout()
                 os.remove(status_path)
-                itchat.auto_login(
-                    enableCmdQR=2, hotReload=hotReload, qrCallback=qrCallback
-                )
+                itchat.auto_login(enableCmdQR=2, hotReload=hotReload, qrCallback=qrCallback)
             else:
                 raise e
         self.user_id = itchat.instance.storageClass.userName
         self.name = itchat.instance.storageClass.nickName
-        logger.info(
-            "Wechat login success, user_id: {}, nickname: {}".format(
-                self.user_id, self.name
-            )
-        )
+        logger.info("Wechat login success, user_id: {}, nickname: {}".format(self.user_id, self.name))
         # start message listener
         itchat.run()
 
@@ -173,16 +158,10 @@ class WechatChannel(ChatChannel):
         elif cmsg.ctype == ContextType.PATPAT:
             logger.debug("[WX]receive patpat msg: {}".format(cmsg.content))
         elif cmsg.ctype == ContextType.TEXT:
-            logger.debug(
-                "[WX]receive text msg: {}, cmsg={}".format(
-                    json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg
-                )
-            )
+            logger.debug("[WX]receive text msg: {}, cmsg={}".format(json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg))
         else:
             logger.debug("[WX]receive msg: {}, cmsg={}".format(cmsg.content, cmsg))
-        context = self._compose_context(
-            cmsg.ctype, cmsg.content, isgroup=False, msg=cmsg
-        )
+        context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=False, msg=cmsg)
         if context:
             self.produce(context)
 
@@ -202,9 +181,7 @@ class WechatChannel(ChatChannel):
             pass
         else:
             logger.debug("[WX]receive group msg: {}".format(cmsg.content))
-        context = self._compose_context(
-            cmsg.ctype, cmsg.content, isgroup=True, msg=cmsg
-        )
+        context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=True, msg=cmsg)
         if context:
             self.produce(context)
 

+ 6 - 20
channel/wechat/wechat_message.py

@@ -27,37 +27,23 @@ class WeChatMessage(ChatMessage):
             self.content = TmpDir().path() + itchat_msg["FileName"]  # content直接存临时目录路径
             self._prepare_fn = lambda: itchat_msg.download(self.content)
         elif itchat_msg["Type"] == NOTE and itchat_msg["MsgType"] == 10000:
-            if is_group and (
-                "加入群聊" in itchat_msg["Content"] or "加入了群聊" in itchat_msg["Content"]
-            ):
+            if is_group and ("加入群聊" in itchat_msg["Content"] or "加入了群聊" in itchat_msg["Content"]):
                 self.ctype = ContextType.JOIN_GROUP
                 self.content = itchat_msg["Content"]
                 # 这里只能得到nickname, actual_user_id还是机器人的id
                 if "加入了群聊" in itchat_msg["Content"]:
-                    self.actual_user_nickname = re.findall(
-                        r"\"(.*?)\"", itchat_msg["Content"]
-                    )[-1]
+                    self.actual_user_nickname = re.findall(r"\"(.*?)\"", itchat_msg["Content"])[-1]
                 elif "加入群聊" in itchat_msg["Content"]:
-                    self.actual_user_nickname = re.findall(
-                        r"\"(.*?)\"", itchat_msg["Content"]
-                    )[0]
+                    self.actual_user_nickname = re.findall(r"\"(.*?)\"", itchat_msg["Content"])[0]
             elif "拍了拍我" in itchat_msg["Content"]:
                 self.ctype = ContextType.PATPAT
                 self.content = itchat_msg["Content"]
                 if is_group:
-                    self.actual_user_nickname = re.findall(
-                        r"\"(.*?)\"", itchat_msg["Content"]
-                    )[0]
+                    self.actual_user_nickname = re.findall(r"\"(.*?)\"", itchat_msg["Content"])[0]
             else:
-                raise NotImplementedError(
-                    "Unsupported note message: " + itchat_msg["Content"]
-                )
+                raise NotImplementedError("Unsupported note message: " + itchat_msg["Content"])
         else:
-            raise NotImplementedError(
-                "Unsupported message type: Type:{} MsgType:{}".format(
-                    itchat_msg["Type"], itchat_msg["MsgType"]
-                )
-            )
+            raise NotImplementedError("Unsupported message type: Type:{} MsgType:{}".format(itchat_msg["Type"], itchat_msg["MsgType"]))
 
         self.from_user_id = itchat_msg["FromUserName"]
         self.to_user_id = itchat_msg["ToUserName"]

+ 5 - 15
channel/wechat/wechaty_channel.py

@@ -60,13 +60,9 @@ class WechatyChannel(ChatChannel):
         receiver_id = context["receiver"]
         loop = asyncio.get_event_loop()
         if context["isgroup"]:
-            receiver = asyncio.run_coroutine_threadsafe(
-                self.bot.Room.find(receiver_id), loop
-            ).result()
+            receiver = asyncio.run_coroutine_threadsafe(self.bot.Room.find(receiver_id), loop).result()
         else:
-            receiver = asyncio.run_coroutine_threadsafe(
-                self.bot.Contact.find(receiver_id), loop
-            ).result()
+            receiver = asyncio.run_coroutine_threadsafe(self.bot.Contact.find(receiver_id), loop).result()
         msg = None
         if reply.type == ReplyType.TEXT:
             msg = reply.content
@@ -83,9 +79,7 @@ class WechatyChannel(ChatChannel):
             voiceLength = int(any_to_sil(file_path, sil_file))
             if voiceLength >= 60000:
                 voiceLength = 60000
-                logger.info(
-                    "[WX] voice too long, length={}, set to 60s".format(voiceLength)
-                )
+                logger.info("[WX] voice too long, length={}, set to 60s".format(voiceLength))
             # 发送语音
             t = int(time.time())
             msg = FileBox.from_file(sil_file, name=str(t) + ".sil")
@@ -98,9 +92,7 @@ class WechatyChannel(ChatChannel):
                     os.remove(sil_file)
             except Exception as e:
                 pass
-            logger.info(
-                "[WX] sendVoice={}, receiver={}".format(reply.content, receiver)
-            )
+            logger.info("[WX] sendVoice={}, receiver={}".format(reply.content, receiver))
         elif reply.type == ReplyType.IMAGE_URL:  # 从网络下载图片
             img_url = reply.content
             t = int(time.time())
@@ -111,9 +103,7 @@ class WechatyChannel(ChatChannel):
             image_storage = reply.content
             image_storage.seek(0)
             t = int(time.time())
-            msg = FileBox.from_base64(
-                base64.b64encode(image_storage.read()), str(t) + ".png"
-            )
+            msg = FileBox.from_base64(base64.b64encode(image_storage.read()), str(t) + ".png")
             asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result()
             logger.info("[WX] sendImage, receiver={}".format(receiver))
 

+ 3 - 9
channel/wechat/wechaty_message.py

@@ -45,16 +45,12 @@ class WechatyMessage(ChatMessage, aobject):
 
             def func():
                 loop = asyncio.get_event_loop()
-                asyncio.run_coroutine_threadsafe(
-                    voice_file.to_file(self.content), loop
-                ).result()
+                asyncio.run_coroutine_threadsafe(voice_file.to_file(self.content), loop).result()
 
             self._prepare_fn = func
 
         else:
-            raise NotImplementedError(
-                "Unsupported message type: {}".format(wechaty_msg.type())
-            )
+            raise NotImplementedError("Unsupported message type: {}".format(wechaty_msg.type()))
 
         from_contact = wechaty_msg.talker()  # 获取消息的发送者
         self.from_user_id = from_contact.contact_id
@@ -73,9 +69,7 @@ class WechatyMessage(ChatMessage, aobject):
             self.to_user_id = to_contact.contact_id
             self.to_user_nickname = to_contact.name
 
-        if (
-            self.is_group or wechaty_msg.is_self()
-        ):  # 如果是群消息,other_user设置为群,如果是私聊消息,而且自己发的,就设置成对方。
+        if self.is_group or wechaty_msg.is_self():  # 如果是群消息,other_user设置为群,如果是私聊消息,而且自己发的,就设置成对方。
             self.other_user_id = self.to_user_id
             self.other_user_nickname = self.to_user_nickname
         else:

+ 8 - 17
channel/wechatmp/active_reply.py

@@ -1,16 +1,17 @@
 import time
 
 import web
+from wechatpy import parse_message
+from wechatpy.replies import create_reply
 
-from channel.wechatmp.wechatmp_message import WeChatMPMessage
 from bridge.context import *
 from bridge.reply import *
 from channel.wechatmp.common import *
 from channel.wechatmp.wechatmp_channel import WechatMPChannel
-from wechatpy import parse_message
+from channel.wechatmp.wechatmp_message import WeChatMPMessage
 from common.log import logger
 from config import conf
-from wechatpy.replies import create_reply
+
 
 # This class is instantiated once per query
 class Query:
@@ -50,29 +51,19 @@ class Query:
                     )
                 )
                 if msg.type == "voice" and wechatmp_msg.ctype == ContextType.TEXT and conf().get("voice_reply_voice", False):
-                    context = channel._compose_context(
-                        wechatmp_msg.ctype, content, isgroup=False, desire_rtype=ReplyType.VOICE, msg=wechatmp_msg
-                    )
+                    context = channel._compose_context(wechatmp_msg.ctype, content, isgroup=False, desire_rtype=ReplyType.VOICE, msg=wechatmp_msg)
                 else:
-                    context = channel._compose_context(
-                        wechatmp_msg.ctype, content, isgroup=False, msg=wechatmp_msg
-                    )
+                    context = channel._compose_context(wechatmp_msg.ctype, content, isgroup=False, msg=wechatmp_msg)
                 if context:
                     # set private openai_api_key
                     # if from_user is not changed in itchat, this can be placed at chat_channel
                     user_data = conf().get_user_data(from_user)
-                    context["openai_api_key"] = user_data.get(
-                        "openai_api_key"
-                    )  # None or user openai_api_key
+                    context["openai_api_key"] = user_data.get("openai_api_key")  # None or user openai_api_key
                     channel.produce(context)
                 # The reply will be sent by channel.send() in another thread
                 return "success"
             elif msg.type == "event":
-                logger.info(
-                    "[wechatmp] Event {} from {}".format(
-                        msg.event, msg.source
-                    )
-                )
+                logger.info("[wechatmp] Event {} from {}".format(msg.event, msg.source))
                 if msg.event in ["subscribe", "subscribe_scan"]:
                     reply_text = subscribe_msg()
                     replyPost = create_reply(reply_text, msg)

+ 5 - 3
channel/wechatmp/common.py

@@ -1,10 +1,12 @@
 import textwrap
-import web
 
-from config import conf
-from wechatpy.utils import check_signature
+import web
 from wechatpy.crypto import WeChatCrypto
 from wechatpy.exceptions import InvalidSignatureException
+from wechatpy.utils import check_signature
+
+from config import conf
+
 MAX_UTF8_LEN = 2048
 
 

+ 18 - 36
channel/wechatmp/passive_reply.py

@@ -1,17 +1,18 @@
-import time
 import asyncio
+import time
 
 import web
+from wechatpy import parse_message
+from wechatpy.replies import ImageReply, VoiceReply, create_reply
 
-from channel.wechatmp.wechatmp_message import WeChatMPMessage
 from bridge.context import *
 from bridge.reply import *
 from channel.wechatmp.common import *
 from channel.wechatmp.wechatmp_channel import WechatMPChannel
+from channel.wechatmp.wechatmp_message import WeChatMPMessage
 from common.log import logger
 from config import conf
-from wechatpy import parse_message
-from wechatpy.replies import create_reply, ImageReply, VoiceReply
+
 
 # This class is instantiated once per query
 class Query:
@@ -49,21 +50,15 @@ class Query:
                 if (
                     from_user not in channel.cache_dict
                     and from_user not in channel.running
-                    or content.startswith("#") 
-                    and message_id not in channel.request_cnt # insert the godcmd
+                    or content.startswith("#")
+                    and message_id not in channel.request_cnt  # insert the godcmd
                 ):
                     # The first query begin
                     if msg.type == "voice" and wechatmp_msg.ctype == ContextType.TEXT and conf().get("voice_reply_voice", False):
-                        context = channel._compose_context(
-                            wechatmp_msg.ctype, content, isgroup=False, desire_rtype=ReplyType.VOICE, msg=wechatmp_msg
-                        )
+                        context = channel._compose_context(wechatmp_msg.ctype, content, isgroup=False, desire_rtype=ReplyType.VOICE, msg=wechatmp_msg)
                     else:
-                        context = channel._compose_context(
-                            wechatmp_msg.ctype, content, isgroup=False, msg=wechatmp_msg
-                        )
-                    logger.debug(
-                        "[wechatmp] context: {} {} {}".format(context, wechatmp_msg, supported)
-                    )
+                        context = channel._compose_context(wechatmp_msg.ctype, content, isgroup=False, msg=wechatmp_msg)
+                    logger.debug("[wechatmp] context: {} {} {}".format(context, wechatmp_msg, supported))
 
                     if supported and context:
                         # set private openai_api_key
@@ -94,23 +89,17 @@ class Query:
                                 """\
                                 未知错误,请稍后再试"""
                             )
-                        
+
                         replyPost = create_reply(reply_text, msg)
                         return encrypt_func(replyPost.render())
 
-
                 # Wechat official server will request 3 times (5 seconds each), with the same message_id.
                 # Because the interval is 5 seconds, here assumed that do not have multithreading problems.
                 request_cnt = channel.request_cnt.get(message_id, 0) + 1
                 channel.request_cnt[message_id] = request_cnt
                 logger.info(
                     "[wechatmp] Request {} from {} {} {}:{}\n{}".format(
-                        request_cnt,
-                        from_user,
-                        message_id,
-                        web.ctx.env.get("REMOTE_ADDR"),
-                        web.ctx.env.get("REMOTE_PORT"),
-                        content
+                        request_cnt, from_user, message_id, web.ctx.env.get("REMOTE_ADDR"), web.ctx.env.get("REMOTE_PORT"), content
                     )
                 )
 
@@ -130,7 +119,7 @@ class Query:
                         time.sleep(2)
                         # and do nothing, waiting for the next request
                         return "success"
-                    else: # request_cnt == 3:
+                    else:  # request_cnt == 3:
                         # return timeout message
                         reply_text = "【正在思考中,回复任意文字尝试获取回复】"
                         replyPost = create_reply(reply_text, msg)
@@ -140,10 +129,7 @@ class Query:
                 channel.request_cnt.pop(message_id)
 
                 # no return because of bandwords or other reasons
-                if (
-                    from_user not in channel.cache_dict
-                    and from_user not in channel.running
-                ):
+                if from_user not in channel.cache_dict and from_user not in channel.running:
                     return "success"
 
                 # Only one request can access to the cached data
@@ -152,7 +138,7 @@ class Query:
                 except KeyError:
                     return "success"
 
-                if (reply_type == "text"):
+                if reply_type == "text":
                     if len(reply_content.encode("utf8")) <= MAX_UTF8_LEN:
                         reply_text = reply_content
                     else:
@@ -177,7 +163,7 @@ class Query:
                     replyPost = create_reply(reply_text, msg)
                     return encrypt_func(replyPost.render())
 
-                elif (reply_type == "voice"):
+                elif reply_type == "voice":
                     media_id = reply_content
                     asyncio.run_coroutine_threadsafe(channel.delete_media(media_id), channel.delete_media_loop)
                     logger.info(
@@ -193,7 +179,7 @@ class Query:
                     replyPost.media_id = media_id
                     return encrypt_func(replyPost.render())
 
-                elif (reply_type == "image"):
+                elif reply_type == "image":
                     media_id = reply_content
                     asyncio.run_coroutine_threadsafe(channel.delete_media(media_id), channel.delete_media_loop)
                     logger.info(
@@ -210,11 +196,7 @@ class Query:
                     return encrypt_func(replyPost.render())
 
             elif msg.type == "event":
-                logger.info(
-                    "[wechatmp] Event {} from {}".format(
-                        msg.event, msg.source
-                    )
-                )
+                logger.info("[wechatmp] Event {} from {}".format(msg.event, msg.source))
                 if msg.event in ["subscribe", "subscribe_scan"]:
                     reply_text = subscribe_msg()
                     replyPost = create_reply(reply_text, msg)

+ 20 - 27
channel/wechatmp/wechatmp_channel.py

@@ -1,24 +1,26 @@
 # -*- coding: utf-8 -*-
+import asyncio
+import imghdr
 import io
 import os
+import threading
 import time
-import imghdr
+
 import requests
-import asyncio
-import threading
-from config import conf
+import web
+from wechatpy.crypto import WeChatCrypto
+from wechatpy.exceptions import WeChatClientException
+
 from bridge.context import *
 from bridge.reply import *
-from common.log import logger
-from common.singleton import singleton
-from voice.audio_convert import any_to_mp3
 from channel.chat_channel import ChatChannel
 from channel.wechatmp.common import *
 from channel.wechatmp.wechatmp_client import WechatMPClient
-from wechatpy.exceptions import WeChatClientException
-from wechatpy.crypto import WeChatCrypto
+from common.log import logger
+from common.singleton import singleton
+from config import conf
+from voice.audio_convert import any_to_mp3
 
-import web
 # If using SSL, uncomment the following lines, and modify the certificate path.
 # from cheroot.server import HTTPServer
 # from cheroot.ssl.builtin import BuiltinSSLAdapter
@@ -54,7 +56,6 @@ class WechatMPChannel(ChatChannel):
             t.setDaemon(True)
             t.start()
 
-
     def startup(self):
         if self.passive_reply:
             urls = ("/wx", "channel.wechatmp.passive_reply.Query")
@@ -84,7 +85,7 @@ class WechatMPChannel(ChatChannel):
             elif reply.type == ReplyType.VOICE:
                 try:
                     voice_file_path = reply.content
-                    with open(voice_file_path, 'rb') as f:
+                    with open(voice_file_path, "rb") as f:
                         # support: <2M, <60s, mp3/wma/wav/amr
                         response = self.client.material.add("voice", f)
                         logger.debug("[wechatmp] upload voice response: {}".format(response))
@@ -107,7 +108,7 @@ class WechatMPChannel(ChatChannel):
                     image_storage.write(block)
                 image_storage.seek(0)
                 image_type = imghdr.what(image_storage)
-                filename = receiver + "-" + str(context['msg'].msg_id) + "." + image_type
+                filename = receiver + "-" + str(context["msg"].msg_id) + "." + image_type
                 content_type = "image/" + image_type
                 try:
                     response = self.client.material.add("image", (filename, image_storage, content_type))
@@ -122,7 +123,7 @@ class WechatMPChannel(ChatChannel):
                 image_storage = reply.content
                 image_storage.seek(0)
                 image_type = imghdr.what(image_storage)
-                filename = receiver + "-" + str(context['msg'].msg_id) + "." + image_type
+                filename = receiver + "-" + str(context["msg"].msg_id) + "." + image_type
                 content_type = "image/" + image_type
                 try:
                     response = self.client.material.add("image", (filename, image_storage, content_type))
@@ -137,7 +138,7 @@ class WechatMPChannel(ChatChannel):
             if reply.type == ReplyType.TEXT or reply.type == ReplyType.INFO or reply.type == ReplyType.ERROR:
                 reply_text = reply.content
                 texts = split_string_by_utf8_length(reply_text, MAX_UTF8_LEN)
-                if len(texts)>1:
+                if len(texts) > 1:
                     logger.info("[wechatmp] text too long, split into {} parts".format(len(texts)))
                 for text in texts:
                     self.client.message.send_text(receiver, text)
@@ -174,7 +175,7 @@ class WechatMPChannel(ChatChannel):
                     image_storage.write(block)
                 image_storage.seek(0)
                 image_type = imghdr.what(image_storage)
-                filename = receiver + "-" + str(context['msg'].msg_id) + "." + image_type
+                filename = receiver + "-" + str(context["msg"].msg_id) + "." + image_type
                 content_type = "image/" + image_type
                 try:
                     response = self.client.media.upload("image", (filename, image_storage, content_type))
@@ -188,7 +189,7 @@ class WechatMPChannel(ChatChannel):
                 image_storage = reply.content
                 image_storage.seek(0)
                 image_type = imghdr.what(image_storage)
-                filename = receiver + "-" + str(context['msg'].msg_id) + "." + image_type
+                filename = receiver + "-" + str(context["msg"].msg_id) + "." + image_type
                 content_type = "image/" + image_type
                 try:
                     response = self.client.media.upload("image", (filename, image_storage, content_type))
@@ -201,20 +202,12 @@ class WechatMPChannel(ChatChannel):
         return
 
     def _success_callback(self, session_id, context, **kwargs):  # 线程异常结束时的回调函数
-        logger.debug(
-            "[wechatmp] Success to generate reply, msgId={}".format(
-                context["msg"].msg_id
-            )
-        )
+        logger.debug("[wechatmp] Success to generate reply, msgId={}".format(context["msg"].msg_id))
         if self.passive_reply:
             self.running.remove(session_id)
 
     def _fail_callback(self, session_id, exception, context, **kwargs):  # 线程异常结束时的回调函数
-        logger.exception(
-            "[wechatmp] Fail to generate reply to user, msgId={}, exception={}".format(
-                context["msg"].msg_id, exception
-            )
-        )
+        logger.exception("[wechatmp] Fail to generate reply to user, msgId={}, exception={}".format(context["msg"].msg_id, exception))
         if self.passive_reply:
             assert session_id not in self.cache_dict
             self.running.remove(session_id)

+ 10 - 11
channel/wechatmp/wechatmp_client.py

@@ -1,17 +1,16 @@
-import time
 import threading
-from channel.wechatmp.common import *
+import time
+
 from wechatpy.client import WeChatClient
-from common.log import logger
 from wechatpy.exceptions import APILimitedException
 
+from channel.wechatmp.common import *
+from common.log import logger
+
 
 class WechatMPClient(WeChatClient):
-    def __init__(self, appid, secret, access_token=None,
-                 session=None, timeout=None, auto_retry=True):
-        super(WechatMPClient, self).__init__(
-            appid, secret, access_token, session, timeout, auto_retry
-        )
+    def __init__(self, appid, secret, access_token=None, session=None, timeout=None, auto_retry=True):
+        super(WechatMPClient, self).__init__(appid, secret, access_token, session, timeout, auto_retry)
         self.fetch_access_token_lock = threading.Lock()
 
     def clear_quota(self):
@@ -20,7 +19,7 @@ class WechatMPClient(WeChatClient):
     def clear_quota_v2(self):
         return self.post("clear_quota/v2", params={"appid": self.appid, "appsecret": self.secret})
 
-    def fetch_access_token(self): # 重载父类方法,加锁避免多线程重复获取access_token
+    def fetch_access_token(self):  # 重载父类方法,加锁避免多线程重复获取access_token
         with self.fetch_access_token_lock:
             access_token = self.session.get(self.access_token_key)
             if access_token:
@@ -31,11 +30,11 @@ class WechatMPClient(WeChatClient):
                     return access_token
             return super().fetch_access_token()
 
-    def _request(self, method, url_or_endpoint, **kwargs): # 重载父类方法,遇到API限流时,清除quota后重试
+    def _request(self, method, url_or_endpoint, **kwargs):  # 重载父类方法,遇到API限流时,清除quota后重试
         try:
             return super()._request(method, url_or_endpoint, **kwargs)
         except APILimitedException as e:
             logger.error("[wechatmp] API quata has been used up. {}".format(e))
             response = self.clear_quota_v2()
             logger.debug("[wechatmp] API quata has been cleard, {}".format(response))
-            return super()._request(method, url_or_endpoint, **kwargs)
+            return super()._request(method, url_or_endpoint, **kwargs)

+ 5 - 14
channel/wechatmp/wechatmp_message.py

@@ -6,7 +6,6 @@ from common.log import logger
 from common.tmp_dir import TmpDir
 
 
-
 class WeChatMPMessage(ChatMessage):
     def __init__(self, msg, client=None):
         super().__init__(msg)
@@ -18,12 +17,9 @@ class WeChatMPMessage(ChatMessage):
             self.ctype = ContextType.TEXT
             self.content = msg.content
         elif msg.type == "voice":
-            
             if msg.recognition == None:
                 self.ctype = ContextType.VOICE
-                self.content = (
-                    TmpDir().path() + msg.media_id + "." + msg.format
-                )  # content直接存临时目录路径
+                self.content = TmpDir().path() + msg.media_id + "." + msg.format  # content直接存临时目录路径
 
                 def download_voice():
                     # 如果响应状态码是200,则将响应内容写入本地文件
@@ -32,9 +28,7 @@ class WeChatMPMessage(ChatMessage):
                         with open(self.content, "wb") as f:
                             f.write(response.content)
                     else:
-                        logger.info(
-                            f"[wechatmp] Failed to download voice file, {response.content}"
-                        )
+                        logger.info(f"[wechatmp] Failed to download voice file, {response.content}")
 
                 self._prepare_fn = download_voice
             else:
@@ -43,6 +37,7 @@ class WeChatMPMessage(ChatMessage):
         elif msg.type == "image":
             self.ctype = ContextType.IMAGE
             self.content = TmpDir().path() + msg.media_id + ".png"  # content直接存临时目录路径
+
             def download_image():
                 # 如果响应状态码是200,则将响应内容写入本地文件
                 response = client.media.download(msg.media_id)
@@ -50,15 +45,11 @@ class WeChatMPMessage(ChatMessage):
                     with open(self.content, "wb") as f:
                         f.write(response.content)
                 else:
-                    logger.info(
-                        f"[wechatmp] Failed to download image file, {response.content}"
-                    )
+                    logger.info(f"[wechatmp] Failed to download image file, {response.content}")
 
             self._prepare_fn = download_image
         else:
-            raise NotImplementedError(
-                "Unsupported message type: Type:{} ".format(msg.type)
-            )
+            raise NotImplementedError("Unsupported message type: Type:{} ".format(msg.type))
 
         self.from_user_id = msg.source
         self.to_user_id = msg.target

+ 3 - 11
common/time_check.py

@@ -13,23 +13,15 @@ def time_checker(f):
         if chat_time_module:
             chat_start_time = _config.get("chat_start_time", "00:00")
             chat_stopt_time = _config.get("chat_stop_time", "24:00")
-            time_regex = re.compile(
-                r"^([01]?[0-9]|2[0-4])(:)([0-5][0-9])$"
-            )  # 时间匹配,包含24:00
+            time_regex = re.compile(r"^([01]?[0-9]|2[0-4])(:)([0-5][0-9])$")  # 时间匹配,包含24:00
 
             starttime_format_check = time_regex.match(chat_start_time)  # 检查停止时间格式
             stoptime_format_check = time_regex.match(chat_stopt_time)  # 检查停止时间格式
             chat_time_check = chat_start_time < chat_stopt_time  # 确定启动时间<停止时间
 
             # 时间格式检查
-            if not (
-                starttime_format_check and stoptime_format_check and chat_time_check
-            ):
-                logger.warn(
-                    "时间格式不正确,请在config.json中修改您的CHAT_START_TIME/CHAT_STOP_TIME,否则可能会影响您正常使用,开始({})-结束({})".format(
-                        starttime_format_check, stoptime_format_check
-                    )
-                )
+            if not (starttime_format_check and stoptime_format_check and chat_time_check):
+                logger.warn("时间格式不正确,请在config.json中修改您的CHAT_START_TIME/CHAT_STOP_TIME,否则可能会影响您正常使用,开始({})-结束({})".format(starttime_format_check, stoptime_format_check))
             if chat_start_time > "23:59":
                 logger.error("启动时间可能存在问题,请修改!")
 

+ 1 - 3
config.py

@@ -158,9 +158,7 @@ def load_config():
     for name, value in os.environ.items():
         name = name.lower()
         if name in available_setting:
-            logger.info(
-                "[INIT] override config by environ args: {}={}".format(name, value)
-            )
+            logger.info("[INIT] override config by environ args: {}={}".format(name, value))
             try:
                 config[name] = eval(value)
             except:

+ 3 - 9
plugins/banwords/banwords.py

@@ -50,9 +50,7 @@ class Banwords(Plugin):
                 self.reply_action = conf.get("reply_action", "ignore")
             logger.info("[Banwords] inited")
         except Exception as e:
-            logger.warn(
-                "[Banwords] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/banwords ."
-            )
+            logger.warn("[Banwords] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/banwords .")
             raise e
 
     def on_handle_context(self, e_context: EventContext):
@@ -72,9 +70,7 @@ class Banwords(Plugin):
                 return
         elif self.action == "replace":
             if self.searchr.ContainsAny(content):
-                reply = Reply(
-                    ReplyType.INFO, "发言中包含敏感词,请重试: \n" + self.searchr.Replace(content)
-                )
+                reply = Reply(ReplyType.INFO, "发言中包含敏感词,请重试: \n" + self.searchr.Replace(content))
                 e_context["reply"] = reply
                 e_context.action = EventAction.BREAK_PASS
                 return
@@ -94,9 +90,7 @@ class Banwords(Plugin):
                 return
         elif self.reply_action == "replace":
             if self.searchr.ContainsAny(content):
-                reply = Reply(
-                    ReplyType.INFO, "已替换回复中的敏感词: \n" + self.searchr.Replace(content)
-                )
+                reply = Reply(ReplyType.INFO, "已替换回复中的敏感词: \n" + self.searchr.Replace(content))
                 e_context["reply"] = reply
                 e_context.action = EventAction.CONTINUE
                 return

+ 7 - 32
plugins/bdunit/bdunit.py

@@ -76,9 +76,7 @@ class BDunit(Plugin):
         Returns:
             string: access_token
         """
-        url = "https://aip.baidubce.com/oauth/2.0/token?client_id={}&client_secret={}&grant_type=client_credentials".format(
-            self.api_key, self.secret_key
-        )
+        url = "https://aip.baidubce.com/oauth/2.0/token?client_id={}&client_secret={}&grant_type=client_credentials".format(self.api_key, self.secret_key)
         payload = ""
         headers = {"Content-Type": "application/json", "Accept": "application/json"}
 
@@ -94,10 +92,7 @@ class BDunit(Plugin):
         :returns: UNIT 解析结果。如果解析失败,返回 None
         """
 
-        url = (
-            "https://aip.baidubce.com/rpc/2.0/unit/service/v3/chat?access_token="
-            + self.access_token
-        )
+        url = "https://aip.baidubce.com/rpc/2.0/unit/service/v3/chat?access_token=" + self.access_token
         request = {
             "query": query,
             "user_id": str(get_mac())[:32],
@@ -124,10 +119,7 @@ class BDunit(Plugin):
         :param query: 用户的指令字符串
         :returns: UNIT 解析结果。如果解析失败,返回 None
         """
-        url = (
-            "https://aip.baidubce.com/rpc/2.0/unit/service/chat?access_token="
-            + self.access_token
-        )
+        url = "https://aip.baidubce.com/rpc/2.0/unit/service/chat?access_token=" + self.access_token
         request = {"query": query, "user_id": str(get_mac())[:32]}
         body = {
             "log_id": str(uuid.uuid1()),
@@ -170,11 +162,7 @@ class BDunit(Plugin):
         if parsed and "result" in parsed and "response_list" in parsed["result"]:
             response_list = parsed["result"]["response_list"]
             for response in response_list:
-                if (
-                    "schema" in response
-                    and "intent" in response["schema"]
-                    and response["schema"]["intent"] == intent
-                ):
+                if "schema" in response and "intent" in response["schema"] and response["schema"]["intent"] == intent:
                     return True
             return False
         else:
@@ -198,12 +186,7 @@ class BDunit(Plugin):
                     logger.warning(e)
                     return []
             for response in response_list:
-                if (
-                    "schema" in response
-                    and "intent" in response["schema"]
-                    and "slots" in response["schema"]
-                    and response["schema"]["intent"] == intent
-                ):
+                if "schema" in response and "intent" in response["schema"] and "slots" in response["schema"] and response["schema"]["intent"] == intent:
                     return response["schema"]["slots"]
             return []
         else:
@@ -239,11 +222,7 @@ class BDunit(Plugin):
                 if (
                     "schema" in response
                     and "intent_confidence" in response["schema"]
-                    and (
-                        not answer
-                        or response["schema"]["intent_confidence"]
-                        > answer["schema"]["intent_confidence"]
-                    )
+                    and (not answer or response["schema"]["intent_confidence"] > answer["schema"]["intent_confidence"])
                 ):
                     answer = response
             return answer["action_list"][0]["say"]
@@ -267,11 +246,7 @@ class BDunit(Plugin):
                     logger.warning(e)
                     return ""
             for response in response_list:
-                if (
-                    "schema" in response
-                    and "intent" in response["schema"]
-                    and response["schema"]["intent"] == intent
-                ):
+                if "schema" in response and "intent" in response["schema"] and response["schema"]["intent"] == intent:
                     try:
                         return response["action_list"][0]["say"]
                     except Exception as e:

+ 2 - 8
plugins/dungeon/dungeon.py

@@ -84,9 +84,7 @@ class Dungeon(Plugin):
                 if len(clist) > 1:
                     story = clist[1]
                 else:
-                    story = (
-                        "你在树林里冒险,指不定会从哪里蹦出来一些奇怪的东西,你握紧手上的手枪,希望这次冒险能够找到一些值钱的东西,你往树林深处走去。"
-                    )
+                    story = "你在树林里冒险,指不定会从哪里蹦出来一些奇怪的东西,你握紧手上的手枪,希望这次冒险能够找到一些值钱的东西,你往树林深处走去。"
                 self.games[sessionid] = StoryTeller(bot, sessionid, story)
                 reply = Reply(ReplyType.INFO, "冒险开始,你可以输入任意内容,让故事继续下去。故事背景是:" + story)
                 e_context["reply"] = reply
@@ -102,11 +100,7 @@ class Dungeon(Plugin):
         if kwargs.get("verbose") != True:
             return help_text
         trigger_prefix = conf().get("plugin_trigger_prefix", "$")
-        help_text = (
-            f"{trigger_prefix}开始冒险 "
-            + "背景故事: 开始一个基于{背景故事}的文字冒险,之后你的所有消息会协助完善这个故事。\n"
-            + f"{trigger_prefix}停止冒险: 结束游戏。\n"
-        )
+        help_text = f"{trigger_prefix}开始冒险 " + "背景故事: 开始一个基于{背景故事}的文字冒险,之后你的所有消息会协助完善这个故事。\n" + f"{trigger_prefix}停止冒险: 结束游戏。\n"
         if kwargs.get("verbose") == True:
             help_text += f"\n命令例子: '{trigger_prefix}开始冒险 你在树林里冒险,指不定会从哪里蹦出来一些奇怪的东西,你握紧手上的手枪,希望这次冒险能够找到一些值钱的东西,你往树林深处走去。'"
         return help_text

+ 7 - 23
plugins/godcmd/godcmd.py

@@ -140,9 +140,7 @@ def get_help_text(isadmin, isgroup):
         if plugins[plugin].enabled and not plugins[plugin].hidden:
             namecn = plugins[plugin].namecn
             help_text += "\n%s:" % namecn
-            help_text += (
-                PluginManager().instances[plugin].get_help_text(verbose=False).strip()
-            )
+            help_text += PluginManager().instances[plugin].get_help_text(verbose=False).strip()
 
     if ADMIN_COMMANDS and isadmin:
         help_text += "\n\n管理员指令:\n"
@@ -191,9 +189,7 @@ class Godcmd(Plugin):
                     COMMANDS["reset"]["alias"].append(custom_command)
 
         self.password = gconf["password"]
-        self.admin_users = gconf[
-            "admin_users"
-        ]  # 预存的管理员账号,这些账号不需要认证。itchat的用户名每次都会变,不可用
+        self.admin_users = gconf["admin_users"]  # 预存的管理员账号,这些账号不需要认证。itchat的用户名每次都会变,不可用
         self.isrunning = True  # 机器人是否运行中
 
         self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
@@ -215,7 +211,7 @@ class Godcmd(Plugin):
                 reply.content = f"空指令,输入#help查看指令列表\n"
                 e_context["reply"] = reply
                 e_context.action = EventAction.BREAK_PASS
-                return 
+                return
             # msg = e_context['context']['msg']
             channel = e_context["channel"]
             user = e_context["context"]["receiver"]
@@ -248,11 +244,7 @@ class Godcmd(Plugin):
                             if not plugincls.enabled:
                                 continue
                             if query_name == name or query_name == plugincls.namecn:
-                                ok, result = True, PluginManager().instances[
-                                    name
-                                ].get_help_text(
-                                    isgroup=isgroup, isadmin=isadmin, verbose=True
-                                )
+                                ok, result = True, PluginManager().instances[name].get_help_text(isgroup=isgroup, isadmin=isadmin, verbose=True)
                                 break
                         if not ok:
                             result = "插件不存在或未启用"
@@ -285,11 +277,7 @@ class Godcmd(Plugin):
                     if isgroup:
                         ok, result = False, "群聊不可执行管理员指令"
                     else:
-                        cmd = next(
-                            c
-                            for c, info in ADMIN_COMMANDS.items()
-                            if cmd in info["alias"]
-                        )
+                        cmd = next(c for c, info in ADMIN_COMMANDS.items() if cmd in info["alias"])
                         if cmd == "stop":
                             self.isrunning = False
                             ok, result = True, "服务已暂停"
@@ -325,18 +313,14 @@ class Godcmd(Plugin):
                             PluginManager().activate_plugins()
                             if len(new_plugins) > 0:
                                 result += "\n发现新插件:\n"
-                                result += "\n".join(
-                                    [f"{p.name}_v{p.version}" for p in new_plugins]
-                                )
+                                result += "\n".join([f"{p.name}_v{p.version}" for p in new_plugins])
                             else:
                                 result += ", 未发现新插件"
                         elif cmd == "setpri":
                             if len(args) != 2:
                                 ok, result = False, "请提供插件名和优先级"
                             else:
-                                ok = PluginManager().set_plugin_priority(
-                                    args[0], int(args[1])
-                                )
+                                ok = PluginManager().set_plugin_priority(args[0], int(args[1]))
                                 if ok:
                                     result = "插件" + args[0] + "优先级已设置为" + args[1]
                                 else:

+ 2 - 6
plugins/hello/hello.py

@@ -33,9 +33,7 @@ class Hello(Plugin):
         if e_context["context"].type == ContextType.JOIN_GROUP:
             e_context["context"].type = ContextType.TEXT
             msg: ChatMessage = e_context["context"]["msg"]
-            e_context[
-                "context"
-            ].content = f'请你随机使用一种风格说一句问候语来欢迎新用户"{msg.actual_user_nickname}"加入群聊。'
+            e_context["context"].content = f'请你随机使用一种风格说一句问候语来欢迎新用户"{msg.actual_user_nickname}"加入群聊。'
             e_context.action = EventAction.CONTINUE  # 事件继续,交付给下个插件或默认逻辑
             return
 
@@ -53,9 +51,7 @@ class Hello(Plugin):
             reply.type = ReplyType.TEXT
             msg: ChatMessage = e_context["context"]["msg"]
             if e_context["context"]["isgroup"]:
-                reply.content = (
-                    f"Hello, {msg.actual_user_nickname} from {msg.from_user_nickname}"
-                )
+                reply.content = f"Hello, {msg.actual_user_nickname} from {msg.from_user_nickname}"
             else:
                 reply.content = f"Hello, {msg.from_user_nickname}"
             e_context["reply"] = reply

+ 2 - 2
plugins/keyword/config.json.template

@@ -1,5 +1,5 @@
 {
   "keyword": {
-      "关键字匹配": "测试成功"
+    "关键字匹配": "测试成功"
   }
-}
+}

+ 1 - 3
plugins/keyword/keyword.py

@@ -41,9 +41,7 @@ class Keyword(Plugin):
             self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
             logger.info("[keyword] inited.")
         except Exception as e:
-            logger.warn(
-                "[keyword] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/keyword ."
-            )
+            logger.warn("[keyword] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/keyword .")
             raise e
 
     def on_handle_context(self, e_context: EventContext):

+ 14 - 46
plugins/plugin_manager.py

@@ -31,23 +31,14 @@ class PluginManager:
             plugincls.desc = kwargs.get("desc")
             plugincls.author = kwargs.get("author")
             plugincls.path = self.current_plugin_path
-            plugincls.version = (
-                kwargs.get("version") if kwargs.get("version") != None else "1.0"
-            )
-            plugincls.namecn = (
-                kwargs.get("namecn") if kwargs.get("namecn") != None else name
-            )
-            plugincls.hidden = (
-                kwargs.get("hidden") if kwargs.get("hidden") != None else False
-            )
+            plugincls.version = kwargs.get("version") if kwargs.get("version") != None else "1.0"
+            plugincls.namecn = kwargs.get("namecn") if kwargs.get("namecn") != None else name
+            plugincls.hidden = kwargs.get("hidden") if kwargs.get("hidden") != None else False
             plugincls.enabled = True
             if self.current_plugin_path == None:
                 raise Exception("Plugin path not set")
             self.plugins[name.upper()] = plugincls
-            logger.info(
-                "Plugin %s_v%s registered, path=%s"
-                % (name, plugincls.version, plugincls.path)
-            )
+            logger.info("Plugin %s_v%s registered, path=%s" % (name, plugincls.version, plugincls.path))
 
         return wrapper
 
@@ -62,9 +53,7 @@ class PluginManager:
         if os.path.exists("./plugins/plugins.json"):
             with open("./plugins/plugins.json", "r", encoding="utf-8") as f:
                 pconf = json.load(f)
-                pconf["plugins"] = SortedDict(
-                    lambda k, v: v["priority"], pconf["plugins"], reverse=True
-                )
+                pconf["plugins"] = SortedDict(lambda k, v: v["priority"], pconf["plugins"], reverse=True)
         else:
             modified = True
             pconf = {"plugins": SortedDict(lambda k, v: v["priority"], reverse=True)}
@@ -90,26 +79,16 @@ class PluginManager:
                         if plugin_path in self.loaded:
                             if self.loaded[plugin_path] == None:
                                 logger.info("reload module %s" % plugin_name)
-                                self.loaded[plugin_path] = importlib.reload(
-                                    sys.modules[import_path]
-                                )
-                                dependent_module_names = [
-                                    name
-                                    for name in sys.modules.keys()
-                                    if name.startswith(import_path + ".")
-                                ]
+                                self.loaded[plugin_path] = importlib.reload(sys.modules[import_path])
+                                dependent_module_names = [name for name in sys.modules.keys() if name.startswith(import_path + ".")]
                                 for name in dependent_module_names:
                                     logger.info("reload module %s" % name)
                                     importlib.reload(sys.modules[name])
                         else:
-                            self.loaded[plugin_path] = importlib.import_module(
-                                import_path
-                            )
+                            self.loaded[plugin_path] = importlib.import_module(import_path)
                         self.current_plugin_path = None
                     except Exception as e:
-                        logger.exception(
-                            "Failed to import plugin %s: %s" % (plugin_name, e)
-                        )
+                        logger.exception("Failed to import plugin %s: %s" % (plugin_name, e))
                         continue
         pconf = self.pconf
         news = [self.plugins[name] for name in self.plugins]
@@ -119,9 +98,7 @@ class PluginManager:
             rawname = plugincls.name
             if rawname not in pconf["plugins"]:
                 modified = True
-                logger.info(
-                    "Plugin %s not found in pconfig, adding to pconfig..." % name
-                )
+                logger.info("Plugin %s not found in pconfig, adding to pconfig..." % name)
                 pconf["plugins"][rawname] = {
                     "enabled": plugincls.enabled,
                     "priority": plugincls.priority,
@@ -136,9 +113,7 @@ class PluginManager:
 
     def refresh_order(self):
         for event in self.listening_plugins.keys():
-            self.listening_plugins[event].sort(
-                key=lambda name: self.plugins[name].priority, reverse=True
-            )
+            self.listening_plugins[event].sort(key=lambda name: self.plugins[name].priority, reverse=True)
 
     def activate_plugins(self):  # 生成新开启的插件实例
         failed_plugins = []
@@ -184,13 +159,8 @@ class PluginManager:
     def emit_event(self, e_context: EventContext, *args, **kwargs):
         if e_context.event in self.listening_plugins:
             for name in self.listening_plugins[e_context.event]:
-                if (
-                    self.plugins[name].enabled
-                    and e_context.action == EventAction.CONTINUE
-                ):
-                    logger.debug(
-                        "Plugin %s triggered by event %s" % (name, e_context.event)
-                    )
+                if self.plugins[name].enabled and e_context.action == EventAction.CONTINUE:
+                    logger.debug("Plugin %s triggered by event %s" % (name, e_context.event))
                     instance = self.instances[name]
                     instance.handlers[e_context.event](e_context, *args, **kwargs)
         return e_context
@@ -262,9 +232,7 @@ class PluginManager:
                     source = json.load(f)
                 if repo in source["repo"]:
                     repo = source["repo"][repo]["url"]
-                    match = re.match(
-                        r"^(https?:\/\/|git@)([^\/:]+)[\/:]([^\/:]+)\/(.+).git$", repo
-                    )
+                    match = re.match(r"^(https?:\/\/|git@)([^\/:]+)[\/:]([^\/:]+)\/(.+).git$", repo)
                     if not match:
                         return False, "安装插件失败,source中的仓库地址不合法"
                 else:

+ 7 - 24
plugins/role/role.py

@@ -69,13 +69,9 @@ class Role(Plugin):
             logger.info("[Role] inited")
         except Exception as e:
             if isinstance(e, FileNotFoundError):
-                logger.warn(
-                    f"[Role] init failed, {config_path} not found, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/role ."
-                )
+                logger.warn(f"[Role] init failed, {config_path} not found, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/role .")
             else:
-                logger.warn(
-                    "[Role] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/role ."
-                )
+                logger.warn("[Role] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/role .")
             raise e
 
     def get_role(self, name, find_closest=True, min_sim=0.35):
@@ -143,9 +139,7 @@ class Role(Plugin):
                 else:
                     help_text = f"未知角色类型。\n"
                     help_text += "目前的角色类型有: \n"
-                    help_text += (
-                        ",".join([self.tags[tag][0] for tag in self.tags]) + "\n"
-                    )
+                    help_text += ",".join([self.tags[tag][0] for tag in self.tags]) + "\n"
             else:
                 help_text = f"请输入角色类型。\n"
                 help_text += "目前的角色类型有: \n"
@@ -158,9 +152,7 @@ class Role(Plugin):
             return
         logger.debug("[Role] on_handle_context. content: %s" % content)
         if desckey is not None:
-            if len(clist) == 1 or (
-                len(clist) > 1 and clist[1].lower() in ["help", "帮助"]
-            ):
+            if len(clist) == 1 or (len(clist) > 1 and clist[1].lower() in ["help", "帮助"]):
                 reply = Reply(ReplyType.INFO, self.get_help_text(verbose=True))
                 e_context["reply"] = reply
                 e_context.action = EventAction.BREAK_PASS
@@ -178,9 +170,7 @@ class Role(Plugin):
                     self.roles[role][desckey],
                     self.roles[role].get("wrapper", "%s"),
                 )
-                reply = Reply(
-                    ReplyType.INFO, f"预设角色为 {role}:\n" + self.roles[role][desckey]
-                )
+                reply = Reply(ReplyType.INFO, f"预设角色为 {role}:\n" + self.roles[role][desckey])
                 e_context["reply"] = reply
                 e_context.action = EventAction.BREAK_PASS
         elif customize == True:
@@ -199,17 +189,10 @@ class Role(Plugin):
         if not verbose:
             return help_text
         trigger_prefix = conf().get("plugin_trigger_prefix", "$")
-        help_text = (
-            f"使用方法:\n{trigger_prefix}角色"
-            + " 预设角色名: 设定角色为{预设角色名}。\n"
-            + f"{trigger_prefix}role"
-            + " 预设角色名: 同上,但使用英文设定。\n"
-        )
+        help_text = f"使用方法:\n{trigger_prefix}角色" + " 预设角色名: 设定角色为{预设角色名}。\n" + f"{trigger_prefix}role" + " 预设角色名: 同上,但使用英文设定。\n"
         help_text += f"{trigger_prefix}设定扮演" + " 角色设定: 设定自定义角色人设为{角色设定}。\n"
         help_text += f"{trigger_prefix}停止扮演: 清除设定的角色。\n"
-        help_text += (
-            f"{trigger_prefix}角色类型" + " 角色类型: 查看某类{角色类型}的所有预设角色,为所有时输出所有预设角色。\n"
-        )
+        help_text += f"{trigger_prefix}角色类型" + " 角色类型: 查看某类{角色类型}的所有预设角色,为所有时输出所有预设角色。\n"
         help_text += "\n目前的角色类型有: \n"
         help_text += ",".join([self.tags[tag][0] for tag in self.tags]) + "。\n"
         help_text += f"\n命令例子: \n{trigger_prefix}角色 写作助理\n"

+ 1 - 1
plugins/tool/README.md

@@ -60,7 +60,7 @@
 
 > 该tool每天返回内容相同
 
-#### 6.3. finance-news 
+#### 6.3. finance-news
 ###### 获取实时的金融财政新闻
 
 > 该工具需要解决browser tool 的google-chrome依赖安装

+ 4 - 10
plugins/tool/tool.py

@@ -82,9 +82,7 @@ class Tool(Plugin):
                     return
                 elif content_list[1].startswith("reset"):
                     logger.debug("[tool]: remind")
-                    e_context[
-                        "context"
-                    ].content = "请你随机用一种聊天风格,提醒用户:如果想重置tool插件,reset之后不要加任何字符"
+                    e_context["context"].content = "请你随机用一种聊天风格,提醒用户:如果想重置tool插件,reset之后不要加任何字符"
 
                     e_context.action = EventAction.BREAK
                     return
@@ -93,18 +91,14 @@ class Tool(Plugin):
 
                 # Don't modify bot name
                 all_sessions = Bridge().get_bot("chat").sessions
-                user_session = all_sessions.session_query(
-                    query, e_context["context"]["session_id"]
-                ).messages
+                user_session = all_sessions.session_query(query, e_context["context"]["session_id"]).messages
 
                 # chatgpt-tool-hub will reply you with many tools
                 logger.debug("[tool]: just-go")
                 try:
                     _reply = self.app.ask(query, user_session)
                     e_context.action = EventAction.BREAK_PASS
-                    all_sessions.session_reply(
-                        _reply, e_context["context"]["session_id"]
-                    )
+                    all_sessions.session_reply(_reply, e_context["context"]["session_id"])
                 except Exception as e:
                     logger.exception(e)
                     logger.error(str(e))
@@ -178,4 +172,4 @@ class Tool(Plugin):
         # filter not support tool
         tool_list = self._filter_tool_list(tool_config.get("tools", []))
 
-        return app.create_app(tools_list=tool_list, **app_kwargs)
+        return app.create_app(tools_list=tool_list, **app_kwargs)

+ 5 - 15
voice/audio_convert.py

@@ -33,6 +33,7 @@ def get_pcm_from_wav(wav_path):
     wav = wave.open(wav_path, "rb")
     return wav.readframes(wav.getnframes())
 
+
 def any_to_mp3(any_path, mp3_path):
     """
     把任意格式转成mp3文件
@@ -40,16 +41,13 @@ def any_to_mp3(any_path, mp3_path):
     if any_path.endswith(".mp3"):
         shutil.copy2(any_path, mp3_path)
         return
-    if (
-        any_path.endswith(".sil")
-        or any_path.endswith(".silk")
-        or any_path.endswith(".slk")
-    ):
+    if any_path.endswith(".sil") or any_path.endswith(".silk") or any_path.endswith(".slk"):
         sil_to_wav(any_path, any_path)
         any_path = mp3_path
     audio = AudioSegment.from_file(any_path)
     audio.export(mp3_path, format="mp3")
 
+
 def any_to_wav(any_path, wav_path):
     """
     把任意格式转成wav文件
@@ -57,11 +55,7 @@ def any_to_wav(any_path, wav_path):
     if any_path.endswith(".wav"):
         shutil.copy2(any_path, wav_path)
         return
-    if (
-        any_path.endswith(".sil")
-        or any_path.endswith(".silk")
-        or any_path.endswith(".slk")
-    ):
+    if any_path.endswith(".sil") or any_path.endswith(".silk") or any_path.endswith(".slk"):
         return sil_to_wav(any_path, wav_path)
     audio = AudioSegment.from_file(any_path)
     audio.export(wav_path, format="wav")
@@ -71,11 +65,7 @@ def any_to_sil(any_path, sil_path):
     """
     把任意格式转成sil文件
     """
-    if (
-        any_path.endswith(".sil")
-        or any_path.endswith(".silk")
-        or any_path.endswith(".slk")
-    ):
+    if any_path.endswith(".sil") or any_path.endswith(".silk") or any_path.endswith(".slk"):
         shutil.copy2(any_path, sil_path)
         return 10000
     audio = AudioSegment.from_file(any_path)

+ 9 - 33
voice/azure/azure_voice.py

@@ -40,57 +40,33 @@ class AzureVoice(Voice):
                     config = json.load(fr)
             self.api_key = conf().get("azure_voice_api_key")
             self.api_region = conf().get("azure_voice_region")
-            self.speech_config = speechsdk.SpeechConfig(
-                subscription=self.api_key, region=self.api_region
-            )
-            self.speech_config.speech_synthesis_voice_name = config[
-                "speech_synthesis_voice_name"
-            ]
-            self.speech_config.speech_recognition_language = config[
-                "speech_recognition_language"
-            ]
+            self.speech_config = speechsdk.SpeechConfig(subscription=self.api_key, region=self.api_region)
+            self.speech_config.speech_synthesis_voice_name = config["speech_synthesis_voice_name"]
+            self.speech_config.speech_recognition_language = config["speech_recognition_language"]
         except Exception as e:
             logger.warn("AzureVoice init failed: %s, ignore " % e)
 
     def voiceToText(self, voice_file):
         audio_config = speechsdk.AudioConfig(filename=voice_file)
-        speech_recognizer = speechsdk.SpeechRecognizer(
-            speech_config=self.speech_config, audio_config=audio_config
-        )
+        speech_recognizer = speechsdk.SpeechRecognizer(speech_config=self.speech_config, audio_config=audio_config)
         result = speech_recognizer.recognize_once()
         if result.reason == speechsdk.ResultReason.RecognizedSpeech:
-            logger.info(
-                "[Azure] voiceToText voice file name={} text={}".format(
-                    voice_file, result.text
-                )
-            )
+            logger.info("[Azure] voiceToText voice file name={} text={}".format(voice_file, result.text))
             reply = Reply(ReplyType.TEXT, result.text)
         else:
-            logger.error(
-                "[Azure] voiceToText error, result={}, canceldetails={}".format(
-                    result, result.cancellation_details
-                )
-            )
+            logger.error("[Azure] voiceToText error, result={}, canceldetails={}".format(result, result.cancellation_details))
             reply = Reply(ReplyType.ERROR, "抱歉,语音识别失败")
         return reply
 
     def textToVoice(self, text):
         fileName = TmpDir().path() + "reply-" + str(int(time.time())) + ".wav"
         audio_config = speechsdk.AudioConfig(filename=fileName)
-        speech_synthesizer = speechsdk.SpeechSynthesizer(
-            speech_config=self.speech_config, audio_config=audio_config
-        )
+        speech_synthesizer = speechsdk.SpeechSynthesizer(speech_config=self.speech_config, audio_config=audio_config)
         result = speech_synthesizer.speak_text(text)
         if result.reason == speechsdk.ResultReason.SynthesizingAudioCompleted:
-            logger.info(
-                "[Azure] textToVoice text={} voice file name={}".format(text, fileName)
-            )
+            logger.info("[Azure] textToVoice text={} voice file name={}".format(text, fileName))
             reply = Reply(ReplyType.VOICE, fileName)
         else:
-            logger.error(
-                "[Azure] textToVoice error, result={}, canceldetails={}".format(
-                    result, result.cancellation_details
-                )
-            )
+            logger.error("[Azure] textToVoice error, result={}, canceldetails={}".format(result, result.cancellation_details))
             reply = Reply(ReplyType.ERROR, "抱歉,语音合成失败")
         return reply

+ 1 - 3
voice/baidu/baidu_voice.py

@@ -85,9 +85,7 @@ class BaiduVoice(Voice):
             fileName = TmpDir().path() + "reply-" + str(int(time.time())) + ".mp3"
             with open(fileName, "wb") as f:
                 f.write(result)
-            logger.info(
-                "[Baidu] textToVoice text={} voice file name={}".format(text, fileName)
-            )
+            logger.info("[Baidu] textToVoice text={} voice file name={}".format(text, fileName))
             reply = Reply(ReplyType.VOICE, fileName)
         else:
             logger.error("[Baidu] textToVoice error={}".format(result))

+ 2 - 8
voice/google/google_voice.py

@@ -24,11 +24,7 @@ class GoogleVoice(Voice):
             audio = self.recognizer.record(source)
         try:
             text = self.recognizer.recognize_google(audio, language="zh-CN")
-            logger.info(
-                "[Google] voiceToText text={} voice file name={}".format(
-                    text, voice_file
-                )
-            )
+            logger.info("[Google] voiceToText text={} voice file name={}".format(text, voice_file))
             reply = Reply(ReplyType.TEXT, text)
         except speech_recognition.UnknownValueError:
             reply = Reply(ReplyType.ERROR, "抱歉,我听不懂")
@@ -42,9 +38,7 @@ class GoogleVoice(Voice):
             mp3File = TmpDir().path() + "reply-" + str(int(time.time())) + ".mp3"
             tts = gTTS(text=text, lang="zh")
             tts.save(mp3File)
-            logger.info(
-                "[Google] textToVoice text={} voice file name={}".format(text, mp3File)
-            )
+            logger.info("[Google] textToVoice text={} voice file name={}".format(text, mp3File))
             reply = Reply(ReplyType.VOICE, mp3File)
         except Exception as e:
             reply = Reply(ReplyType.ERROR, str(e))

+ 1 - 5
voice/openai/openai_voice.py

@@ -22,11 +22,7 @@ class OpenaiVoice(Voice):
             result = openai.Audio.transcribe("whisper-1", file)
             text = result["text"]
             reply = Reply(ReplyType.TEXT, text)
-            logger.info(
-                "[Openai] voiceToText text={} voice file name={}".format(
-                    text, voice_file
-                )
-            )
+            logger.info("[Openai] voiceToText text={} voice file name={}".format(text, voice_file))
         except Exception as e:
             reply = Reply(ReplyType.ERROR, str(e))
         finally:

+ 7 - 5
voice/pytts/pytts_voice.py

@@ -5,6 +5,7 @@ pytts voice service (offline)
 import os
 import sys
 import time
+
 import pyttsx3
 
 from bridge.reply import Reply, ReplyType
@@ -12,6 +13,7 @@ from common.log import logger
 from common.tmp_dir import TmpDir
 from voice.voice import Voice
 
+
 class PyttsVoice(Voice):
     engine = pyttsx3.init()
 
@@ -20,7 +22,7 @@ class PyttsVoice(Voice):
         self.engine.setProperty("rate", 125)
         # 音量
         self.engine.setProperty("volume", 1.0)
-        if sys.platform == 'win32':
+        if sys.platform == "win32":
             for voice in self.engine.getProperty("voices"):
                 if "Chinese" in voice.name:
                     self.engine.setProperty("voice", voice.id)
@@ -33,23 +35,23 @@ class PyttsVoice(Voice):
     def textToVoice(self, text):
         try:
             # avoid the same filename
-            wavFileName = "reply-" + str(int(time.time())) + "-" + str(hash(text) & 0x7fffffff) + ".wav"
+            wavFileName = "reply-" + str(int(time.time())) + "-" + str(hash(text) & 0x7FFFFFFF) + ".wav"
             wavFile = TmpDir().path() + wavFileName
             logger.info("[Pytts] textToVoice text={} voice file name={}".format(text, wavFile))
 
             self.engine.save_to_file(text, wavFile)
 
-            if sys.platform == 'win32':
+            if sys.platform == "win32":
                 self.engine.runAndWait()
             else:
-                # In ubuntu, runAndWait do not really wait until the file created. 
+                # In ubuntu, runAndWait do not really wait until the file created.
                 # It will return once the task queue is empty, but the task is still running in coroutine.
                 # And if you call runAndWait() and time.sleep() twice, it will stuck, so do not use this.
                 # If you want to fix this, add self._proxy.setBusy(True) in line 127 in espeak.py, at the beginning of the function save_to_file.
                 # self.engine.runAndWait()
 
                 # Before espeak fix this problem, we iterate the generator and control the waiting by ourself.
-                # But this is not the canonical way to use it, for example if the file already exists it also cannot wait. 
+                # But this is not the canonical way to use it, for example if the file already exists it also cannot wait.
                 self.engine.iterate()
                 while self.engine.isBusy() or wavFileName not in os.listdir(TmpDir().path()):
                     time.sleep(0.1)