Browse Source

formatting code

lanvent 3 years ago
parent
commit
8f72e8c3e6
92 changed files with 1843 additions and 1181 deletions
  1. 1 1
      .github/ISSUE_TEMPLATE.md
  2. 2 2
      .github/workflows/deploy-image.yml
  3. 5 5
      README.md
  4. 17 11
      app.py
  5. 24 8
      bot/baidu/baidu_unit_bot.py
  6. 1 1
      bot/bot.py
  7. 4 0
      bot/bot_factory.py
  8. 72 48
      bot/chatgpt/chat_gpt_bot.py
  9. 28 12
      bot/chatgpt/chat_gpt_session.py
  10. 62 39
      bot/openai/open_ai_bot.py
  11. 20 13
      bot/openai/open_ai_image.py
  12. 31 17
      bot/openai/open_ai_session.py
  13. 31 18
      bot/session_manager.py
  14. 13 16
      bridge/bridge.py
  15. 24 19
      bridge/context.py
  16. 11 8
      bridge/reply.py
  17. 5 3
      channel/channel.py
  18. 13 7
      channel/channel_factory.py
  19. 213 112
      channel/chat_channel.py
  20. 10 10
      channel/chat_message.py
  21. 26 10
      channel/terminal/terminal_channel.py
  22. 80 47
      channel/wechat/wechat_channel.py
  23. 28 28
      channel/wechat/wechat_message.py
  24. 54 40
      channel/wechat/wechaty_channel.py
  25. 28 18
      channel/wechat/wechaty_message.py
  26. 2 2
      channel/wechatmp/README.md
  27. 35 16
      channel/wechatmp/ServiceAccount.py
  28. 98 38
      channel/wechatmp/SubscribeAccount.py
  29. 15 10
      channel/wechatmp/common.py
  30. 20 18
      channel/wechatmp/receive.py
  31. 12 9
      channel/wechatmp/reply.py
  32. 47 38
      channel/wechatmp/wechatmp_channel.py
  33. 1 1
      common/const.py
  34. 2 2
      common/dequeue.py
  35. 1 1
      common/expired_dict.py
  36. 16 7
      common/log.py
  37. 11 5
      common/package_manager.py
  38. 1 1
      common/sorted_dict.py
  39. 22 10
      common/time_check.py
  40. 5 7
      common/tmp_dir.py
  41. 20 6
      config-template.json
  42. 21 32
      config.py
  43. 1 1
      docker/Dockerfile.debian
  44. 1 1
      docker/Dockerfile.debian.latest
  45. 1 2
      docker/build.alpine.sh
  46. 1 1
      docker/build.debian.sh
  47. 1 1
      docker/chatgpt-on-wechat-voice-reply/Dockerfile.alpine
  48. 1 1
      docker/chatgpt-on-wechat-voice-reply/Dockerfile.debian
  49. 2 2
      docker/sample-chatgpt-on-wechat/Makefile
  50. 14 14
      plugins/README.md
  51. 2 2
      plugins/__init__.py
  52. 1 1
      plugins/banwords/__init__.py
  53. 49 35
      plugins/banwords/banwords.py
  54. 4 4
      plugins/banwords/config.json.template
  55. 1 1
      plugins/bdunit/README.md
  56. 1 1
      plugins/bdunit/__init__.py
  57. 30 42
      plugins/bdunit/bdunit.py
  58. 4 4
      plugins/bdunit/config.json.template
  59. 1 1
      plugins/dungeon/__init__.py
  60. 47 28
      plugins/dungeon/dungeon.py
  61. 6 6
      plugins/event.py
  62. 1 1
      plugins/finish/__init__.py
  63. 15 9
      plugins/finish/finish.py
  64. 1 1
      plugins/godcmd/__init__.py
  65. 3 3
      plugins/godcmd/config.json.template
  66. 96 70
      plugins/godcmd/godcmd.py
  67. 1 1
      plugins/hello/__init__.py
  68. 22 14
      plugins/hello/hello.py
  69. 1 1
      plugins/plugin.py
  70. 104 52
      plugins/plugin_manager.py
  71. 1 1
      plugins/role/__init__.py
  72. 66 35
      plugins/role/role.py
  73. 1 1
      plugins/role/roles.json
  74. 14 14
      plugins/source.json
  75. 16 16
      plugins/tool/README.md
  76. 1 1
      plugins/tool/__init__.py
  77. 10 5
      plugins/tool/config.json.template
  78. 40 21
      plugins/tool/tool.py
  79. 1 0
      requirements.txt
  80. 1 1
      scripts/start.sh
  81. 1 1
      scripts/tout.sh
  82. 25 8
      voice/audio_convert.py
  83. 37 17
      voice/azure/azure_voice.py
  84. 3 3
      voice/azure/config.json.template
  85. 4 4
      voice/baidu/README.md
  86. 26 23
      voice/baidu/baidu_voice.py
  87. 8 8
      voice/baidu/config.json.template
  88. 13 7
      voice/google/google_voice.py
  89. 9 6
      voice/openai/openai_voice.py
  90. 9 7
      voice/pytts/pytts_voice.py
  91. 2 1
      voice/voice.py
  92. 11 5
      voice/voice_factory.py

+ 1 - 1
.github/ISSUE_TEMPLATE.md

@@ -27,5 +27,5 @@
 ### 环境
 
  - 操作系统类型  (Mac/Windows/Linux):
- - Python版本  ( 执行 `python3 -V` ):                      
+ - Python版本  ( 执行 `python3 -V` ):  
  - pip版本  ( 依赖问题此项必填,执行 `pip3 -V`):

+ 2 - 2
.github/workflows/deploy-image.yml

@@ -49,9 +49,9 @@ jobs:
           file: ./docker/Dockerfile.latest
           tags: ${{ steps.meta.outputs.tags }}
           labels: ${{ steps.meta.outputs.labels }}
-      
+
       - uses: actions/delete-package-versions@v4
-        with: 
+        with:
           package-name: 'chatgpt-on-wechat'
           package-type: 'container'
           min-versions-to-keep: 10

+ 5 - 5
README.md

@@ -120,7 +120,7 @@ pip3 install azure-cognitiveservices-speech
 
 ```bash
 # config.json文件内容示例
-{ 
+{
   "open_ai_api_key": "YOUR API KEY",                          # 填入上面创建的 OpenAI API KEY
   "model": "gpt-3.5-turbo",                                   # 模型名称。当use_azure_chatgpt为true时,其名称为Azure上model deployment名称
   "proxy": "127.0.0.1:7890",                                  # 代理客户端的ip和端口
@@ -128,7 +128,7 @@ pip3 install azure-cognitiveservices-speech
   "single_chat_reply_prefix": "[bot] ",                       # 私聊时自动回复的前缀,用于区分真人
   "group_chat_prefix": ["@bot"],                              # 群聊时包含该前缀则会触发机器人回复
   "group_name_white_list": ["ChatGPT测试群", "ChatGPT测试群2"], # 开启自动回复的群名称列表
-  "group_chat_in_one_session": ["ChatGPT测试群"],              # 支持会话上下文共享的群名称       
+  "group_chat_in_one_session": ["ChatGPT测试群"],              # 支持会话上下文共享的群名称  
   "image_create_prefix": ["画", "看", "找"],                   # 开启图片回复的前缀
   "conversation_max_tokens": 1000,                            # 支持上下文记忆的最多字符数
   "speech_recognition": false,                                # 是否开启语音识别
@@ -160,7 +160,7 @@ pip3 install azure-cognitiveservices-speech
 **4.其他配置**
 
 + `model`: 模型名称,目前支持 `gpt-3.5-turbo`, `text-davinci-003`, `gpt-4`, `gpt-4-32k`  (其中gpt-4 api暂未开放)
-+ `temperature`,`frequency_penalty`,`presence_penalty`: Chat API接口参数,详情参考[OpenAI官方文档。](https://platform.openai.com/docs/api-reference/chat) 
++ `temperature`,`frequency_penalty`,`presence_penalty`: Chat API接口参数,详情参考[OpenAI官方文档。](https://platform.openai.com/docs/api-reference/chat)
 + `proxy`:由于目前 `openai` 接口国内无法访问,需配置代理客户端的地址,详情参考  [#351](https://github.com/zhayujie/chatgpt-on-wechat/issues/351)
 + 对于图像生成,在满足个人或群组触发条件外,还需要额外的关键词前缀来触发,对应配置 `image_create_prefix `
 + 关于OpenAI对话及图片接口的参数配置(内容自由度、回复字数限制、图片大小等),可以参考 [对话接口](https://beta.openai.com/docs/api-reference/completions) 和 [图像接口](https://beta.openai.com/docs/api-reference/completions)  文档直接在 [代码](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/bot/openai/open_ai_bot.py) `bot/openai/open_ai_bot.py` 中进行调整。
@@ -181,7 +181,7 @@ pip3 install azure-cognitiveservices-speech
 ```bash
 python3 app.py
 ```
-终端输出二维码后,使用微信进行扫码,当输出 "Start auto replying" 时表示自动回复程序已经成功运行了(注意:用于登录的微信需要在支付处已完成实名认证)。扫码登录后你的账号就成为机器人了,可以在微信手机端通过配置的关键词触发自动回复 (任意好友发送消息给你,或是自己发消息给好友),参考[#142](https://github.com/zhayujie/chatgpt-on-wechat/issues/142)。 
+终端输出二维码后,使用微信进行扫码,当输出 "Start auto replying" 时表示自动回复程序已经成功运行了(注意:用于登录的微信需要在支付处已完成实名认证)。扫码登录后你的账号就成为机器人了,可以在微信手机端通过配置的关键词触发自动回复 (任意好友发送消息给你,或是自己发消息给好友),参考[#142](https://github.com/zhayujie/chatgpt-on-wechat/issues/142)。
 
 
 ### 2.服务器部署
@@ -189,7 +189,7 @@ python3 app.py
 使用nohup命令在后台运行程序:
 
 ```bash
-touch nohup.out                                   # 首次运行需要新建日志文件                     
+touch nohup.out                                   # 首次运行需要新建日志文件  
 nohup python3 app.py & tail -f nohup.out          # 在后台运行程序并通过日志输出二维码
 ```
 扫码登录后程序即可运行于服务器后台,此时可通过 `ctrl+c` 关闭日志,不会影响后台程序的运行。使用 `ps -ef | grep app.py | grep -v grep` 命令可查看运行于后台的进程,如果想要重新启动程序可以先 `kill` 掉对应的进程。日志关闭后如果想要再次打开只需输入 `tail -f nohup.out`。此外,`scripts` 目录下有一键运行、关闭程序的脚本供使用。

+ 17 - 11
app.py

@@ -1,23 +1,28 @@
 # encoding:utf-8
 
 import os
-from config import conf, load_config
+import signal
+import sys
+
 from channel import channel_factory
 from common.log import logger
+from config import conf, load_config
 from plugins import *
-import signal
-import sys
+
 
 def sigterm_handler_wrap(_signo):
     old_handler = signal.getsignal(_signo)
+
     def func(_signo, _stack_frame):
         logger.info("signal {} received, exiting...".format(_signo))
         conf().save_user_datas()
-        if callable(old_handler): #  check old_handler
+        if callable(old_handler):  #  check old_handler
             return old_handler(_signo, _stack_frame)
         sys.exit(0)
+
     signal.signal(_signo, func)
 
+
 def run():
     try:
         # load config
@@ -28,17 +33,17 @@ def run():
         sigterm_handler_wrap(signal.SIGTERM)
 
         # create channel
-        channel_name=conf().get('channel_type', 'wx')
+        channel_name = conf().get("channel_type", "wx")
 
         if "--cmd" in sys.argv:
-            channel_name = 'terminal'
+            channel_name = "terminal"
 
-        if channel_name == 'wxy':
-            os.environ['WECHATY_LOG']="warn"
+        if channel_name == "wxy":
+            os.environ["WECHATY_LOG"] = "warn"
             # os.environ['WECHATY_PUPPET_SERVICE_ENDPOINT'] = '127.0.0.1:9001'
 
         channel = channel_factory.create_channel(channel_name)
-        if channel_name in ['wx','wxy','terminal','wechatmp','wechatmp_service']:
+        if channel_name in ["wx", "wxy", "terminal", "wechatmp", "wechatmp_service"]:
             PluginManager().load_plugins()
 
         # startup channel
@@ -47,5 +52,6 @@ def run():
         logger.error("App startup failed!")
         logger.exception(e)
 
-if __name__ == '__main__':
-    run()
+
+if __name__ == "__main__":
+    run()

+ 24 - 8
bot/baidu/baidu_unit_bot.py

@@ -1,6 +1,7 @@
 # encoding:utf-8
 
 import requests
+
 from bot.bot import Bot
 from bridge.reply import Reply, ReplyType
 
@@ -9,20 +10,35 @@ from bridge.reply import Reply, ReplyType
 class BaiduUnitBot(Bot):
     def reply(self, query, context=None):
         token = self.get_token()
-        url = 'https://aip.baidubce.com/rpc/2.0/unit/service/v3/chat?access_token=' + token
-        post_data = "{\"version\":\"3.0\",\"service_id\":\"S73177\",\"session_id\":\"\",\"log_id\":\"7758521\",\"skill_ids\":[\"1221886\"],\"request\":{\"terminal_id\":\"88888\",\"query\":\"" + query + "\", \"hyper_params\": {\"chat_custom_bot_profile\": 1}}}"
+        url = (
+            "https://aip.baidubce.com/rpc/2.0/unit/service/v3/chat?access_token="
+            + token
+        )
+        post_data = (
+            '{"version":"3.0","service_id":"S73177","session_id":"","log_id":"7758521","skill_ids":["1221886"],"request":{"terminal_id":"88888","query":"'
+            + query
+            + '", "hyper_params": {"chat_custom_bot_profile": 1}}}'
+        )
         print(post_data)
-        headers = {'content-type': 'application/x-www-form-urlencoded'}
+        headers = {"content-type": "application/x-www-form-urlencoded"}
         response = requests.post(url, data=post_data.encode(), headers=headers)
         if response:
-            reply = Reply(ReplyType.TEXT, response.json()['result']['context']['SYS_PRESUMED_HIST'][1])
+            reply = Reply(
+                ReplyType.TEXT,
+                response.json()["result"]["context"]["SYS_PRESUMED_HIST"][1],
+            )
             return reply
 
     def get_token(self):
-        access_key = 'YOUR_ACCESS_KEY'
-        secret_key = 'YOUR_SECRET_KEY'
-        host = 'https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=' + access_key + '&client_secret=' + secret_key
+        access_key = "YOUR_ACCESS_KEY"
+        secret_key = "YOUR_SECRET_KEY"
+        host = (
+            "https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id="
+            + access_key
+            + "&client_secret="
+            + secret_key
+        )
         response = requests.get(host)
         if response:
             print(response.json())
-            return response.json()['access_token']
+            return response.json()["access_token"]

+ 1 - 1
bot/bot.py

@@ -8,7 +8,7 @@ from bridge.reply import Reply
 
 
 class Bot(object):
-    def reply(self, query, context : Context =None) -> Reply:
+    def reply(self, query, context: Context = None) -> Reply:
         """
         bot auto-reply content
         :param req: received message

+ 4 - 0
bot/bot_factory.py

@@ -13,20 +13,24 @@ def create_bot(bot_type):
     if bot_type == const.BAIDU:
         # Baidu Unit对话接口
         from bot.baidu.baidu_unit_bot import BaiduUnitBot
+
         return BaiduUnitBot()
 
     elif bot_type == const.CHATGPT:
         # ChatGPT 网页端web接口
         from bot.chatgpt.chat_gpt_bot import ChatGPTBot
+
         return ChatGPTBot()
 
     elif bot_type == const.OPEN_AI:
         # OpenAI 官方对话模型API
         from bot.openai.open_ai_bot import OpenAIBot
+
         return OpenAIBot()
 
     elif bot_type == const.CHATGPTONAZURE:
         # Azure chatgpt service https://azure.microsoft.com/en-in/products/cognitive-services/openai-service/
         from bot.chatgpt.chat_gpt_bot import AzureChatGPTBot
+
         return AzureChatGPTBot()
     raise RuntimeError

+ 72 - 48
bot/chatgpt/chat_gpt_bot.py

@@ -1,42 +1,53 @@
 # encoding:utf-8
 
+import time
+
+import openai
+import openai.error
+
 from bot.bot import Bot
 from bot.chatgpt.chat_gpt_session import ChatGPTSession
 from bot.openai.open_ai_image import OpenAIImage
 from bot.session_manager import SessionManager
 from bridge.context import ContextType
 from bridge.reply import Reply, ReplyType
-from config import conf, load_config
 from common.log import logger
 from common.token_bucket import TokenBucket
-import openai
-import openai.error
-import time
+from config import conf, load_config
+
 
 # OpenAI对话模型API (可用)
-class ChatGPTBot(Bot,OpenAIImage):
+class ChatGPTBot(Bot, OpenAIImage):
     def __init__(self):
         super().__init__()
         # set the default api_key
-        openai.api_key = conf().get('open_ai_api_key')
-        if conf().get('open_ai_api_base'):
-            openai.api_base = conf().get('open_ai_api_base')
-        proxy = conf().get('proxy')
+        openai.api_key = conf().get("open_ai_api_key")
+        if conf().get("open_ai_api_base"):
+            openai.api_base = conf().get("open_ai_api_base")
+        proxy = conf().get("proxy")
         if proxy:
             openai.proxy = proxy
-        if conf().get('rate_limit_chatgpt'):
-            self.tb4chatgpt = TokenBucket(conf().get('rate_limit_chatgpt', 20))
-        
-        self.sessions = SessionManager(ChatGPTSession, model= conf().get("model") or "gpt-3.5-turbo")
-        self.args ={
+        if conf().get("rate_limit_chatgpt"):
+            self.tb4chatgpt = TokenBucket(conf().get("rate_limit_chatgpt", 20))
+
+        self.sessions = SessionManager(
+            ChatGPTSession, model=conf().get("model") or "gpt-3.5-turbo"
+        )
+        self.args = {
             "model": conf().get("model") or "gpt-3.5-turbo",  # 对话模型的名称
-            "temperature":conf().get('temperature', 0.9),  # 值在[0,1]之间,越大表示回复越具有不确定性
+            "temperature": conf().get("temperature", 0.9),  # 值在[0,1]之间,越大表示回复越具有不确定性
             # "max_tokens":4096,  # 回复最大的字符数
-            "top_p":1,
-            "frequency_penalty":conf().get('frequency_penalty', 0.0),  # [-2,2]之间,该值越大则更倾向于产生不同的内容
-            "presence_penalty":conf().get('presence_penalty', 0.0),  # [-2,2]之间,该值越大则更倾向于产生不同的内容
-            "request_timeout": conf().get('request_timeout', None),  # 请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间
-            "timeout": conf().get('request_timeout', None), #重试超时时间,在这个时间内,将会自动重试
+            "top_p": 1,
+            "frequency_penalty": conf().get(
+                "frequency_penalty", 0.0
+            ),  # [-2,2]之间,该值越大则更倾向于产生不同的内容
+            "presence_penalty": conf().get(
+                "presence_penalty", 0.0
+            ),  # [-2,2]之间,该值越大则更倾向于产生不同的内容
+            "request_timeout": conf().get(
+                "request_timeout", None
+            ),  # 请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间
+            "timeout": conf().get("request_timeout", None),  # 重试超时时间,在这个时间内,将会自动重试
         }
 
     def reply(self, query, context=None):
@@ -44,39 +55,50 @@ class ChatGPTBot(Bot,OpenAIImage):
         if context.type == ContextType.TEXT:
             logger.info("[CHATGPT] query={}".format(query))
 
-
-            session_id = context['session_id']
+            session_id = context["session_id"]
             reply = None
-            clear_memory_commands = conf().get('clear_memory_commands', ['#清除记忆'])
+            clear_memory_commands = conf().get("clear_memory_commands", ["#清除记忆"])
             if query in clear_memory_commands:
                 self.sessions.clear_session(session_id)
-                reply = Reply(ReplyType.INFO, '记忆已清除')
-            elif query == '#清除所有':
+                reply = Reply(ReplyType.INFO, "记忆已清除")
+            elif query == "#清除所有":
                 self.sessions.clear_all_session()
-                reply = Reply(ReplyType.INFO, '所有人记忆已清除')
-            elif query == '#更新配置':
+                reply = Reply(ReplyType.INFO, "所有人记忆已清除")
+            elif query == "#更新配置":
                 load_config()
-                reply = Reply(ReplyType.INFO, '配置已更新')
+                reply = Reply(ReplyType.INFO, "配置已更新")
             if reply:
                 return reply
             session = self.sessions.session_query(query, session_id)
             logger.debug("[CHATGPT] session query={}".format(session.messages))
 
-            api_key = context.get('openai_api_key')
+            api_key = context.get("openai_api_key")
 
             # if context.get('stream'):
             #     # reply in stream
             #     return self.reply_text_stream(query, new_query, session_id)
 
             reply_content = self.reply_text(session, api_key)
-            logger.debug("[CHATGPT] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(session.messages, session_id, reply_content["content"], reply_content["completion_tokens"]))
-            if reply_content['completion_tokens'] == 0 and len(reply_content['content']) > 0:
-                reply = Reply(ReplyType.ERROR, reply_content['content'])
+            logger.debug(
+                "[CHATGPT] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(
+                    session.messages,
+                    session_id,
+                    reply_content["content"],
+                    reply_content["completion_tokens"],
+                )
+            )
+            if (
+                reply_content["completion_tokens"] == 0
+                and len(reply_content["content"]) > 0
+            ):
+                reply = Reply(ReplyType.ERROR, reply_content["content"])
             elif reply_content["completion_tokens"] > 0:
-                self.sessions.session_reply(reply_content["content"], session_id, reply_content["total_tokens"])
+                self.sessions.session_reply(
+                    reply_content["content"], session_id, reply_content["total_tokens"]
+                )
                 reply = Reply(ReplyType.TEXT, reply_content["content"])
             else:
-                reply = Reply(ReplyType.ERROR, reply_content['content'])
+                reply = Reply(ReplyType.ERROR, reply_content["content"])
                 logger.debug("[CHATGPT] reply {} used 0 tokens.".format(reply_content))
             return reply
 
@@ -89,53 +111,55 @@ class ChatGPTBot(Bot,OpenAIImage):
                 reply = Reply(ReplyType.ERROR, retstring)
             return reply
         else:
-            reply = Reply(ReplyType.ERROR, 'Bot不支持处理{}类型的消息'.format(context.type))
+            reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type))
             return reply
 
-    def reply_text(self, session:ChatGPTSession, api_key=None, retry_count=0) -> dict:
-        '''
+    def reply_text(self, session: ChatGPTSession, api_key=None, retry_count=0) -> dict:
+        """
         call openai's ChatCompletion to get the answer
         :param session: a conversation session
         :param session_id: session id
         :param retry_count: retry count
         :return: {}
-        '''
+        """
         try:
-            if conf().get('rate_limit_chatgpt') and not self.tb4chatgpt.get_token():
+            if conf().get("rate_limit_chatgpt") and not self.tb4chatgpt.get_token():
                 raise openai.error.RateLimitError("RateLimitError: rate limit exceeded")
             # if api_key == None, the default openai.api_key will be used
             response = openai.ChatCompletion.create(
                 api_key=api_key, messages=session.messages, **self.args
             )
             # logger.info("[ChatGPT] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"]))
-            return {"total_tokens": response["usage"]["total_tokens"],
-                    "completion_tokens": response["usage"]["completion_tokens"],
-                    "content": response.choices[0]['message']['content']}
+            return {
+                "total_tokens": response["usage"]["total_tokens"],
+                "completion_tokens": response["usage"]["completion_tokens"],
+                "content": response.choices[0]["message"]["content"],
+            }
         except Exception as e:
             need_retry = retry_count < 2
             result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
             if isinstance(e, openai.error.RateLimitError):
                 logger.warn("[CHATGPT] RateLimitError: {}".format(e))
-                result['content'] = "提问太快啦,请休息一下再问我吧"
+                result["content"] = "提问太快啦,请休息一下再问我吧"
                 if need_retry:
                     time.sleep(5)
             elif isinstance(e, openai.error.Timeout):
                 logger.warn("[CHATGPT] Timeout: {}".format(e))
-                result['content'] = "我没有收到你的消息"
+                result["content"] = "我没有收到你的消息"
                 if need_retry:
                     time.sleep(5)
             elif isinstance(e, openai.error.APIConnectionError):
                 logger.warn("[CHATGPT] APIConnectionError: {}".format(e))
                 need_retry = False
-                result['content'] = "我连接不到你的网络"
+                result["content"] = "我连接不到你的网络"
             else:
                 logger.warn("[CHATGPT] Exception: {}".format(e))
                 need_retry = False
                 self.sessions.clear_session(session.session_id)
 
             if need_retry:
-                logger.warn("[CHATGPT] 第{}次重试".format(retry_count+1))
-                return self.reply_text(session, api_key, retry_count+1)
+                logger.warn("[CHATGPT] 第{}次重试".format(retry_count + 1))
+                return self.reply_text(session, api_key, retry_count + 1)
             else:
                 return result
 
@@ -145,4 +169,4 @@ class AzureChatGPTBot(ChatGPTBot):
         super().__init__()
         openai.api_type = "azure"
         openai.api_version = "2023-03-15-preview"
-        self.args["deployment_id"] = conf().get("azure_deployment_id")
+        self.args["deployment_id"] = conf().get("azure_deployment_id")

+ 28 - 12
bot/chatgpt/chat_gpt_session.py

@@ -1,20 +1,23 @@
 from bot.session_manager import Session
 from common.log import logger
-'''
+
+"""
     e.g.  [
         {"role": "system", "content": "You are a helpful assistant."},
         {"role": "user", "content": "Who won the world series in 2020?"},
         {"role": "assistant", "content": "The Los Angeles Dodgers won the World Series in 2020."},
         {"role": "user", "content": "Where was it played?"}
     ]
-'''
+"""
+
+
 class ChatGPTSession(Session):
-    def __init__(self, session_id, system_prompt=None, model= "gpt-3.5-turbo"):
+    def __init__(self, session_id, system_prompt=None, model="gpt-3.5-turbo"):
         super().__init__(session_id, system_prompt)
         self.model = model
         self.reset()
-    
-    def discard_exceeding(self, max_tokens, cur_tokens= None):
+
+    def discard_exceeding(self, max_tokens, cur_tokens=None):
         precise = True
         try:
             cur_tokens = self.calc_tokens()
@@ -22,7 +25,9 @@ class ChatGPTSession(Session):
             precise = False
             if cur_tokens is None:
                 raise e
-            logger.debug("Exception when counting tokens precisely for query: {}".format(e))
+            logger.debug(
+                "Exception when counting tokens precisely for query: {}".format(e)
+            )
         while cur_tokens > max_tokens:
             if len(self.messages) > 2:
                 self.messages.pop(1)
@@ -34,25 +39,32 @@ class ChatGPTSession(Session):
                     cur_tokens = cur_tokens - max_tokens
                 break
             elif len(self.messages) == 2 and self.messages[1]["role"] == "user":
-                logger.warn("user message exceed max_tokens. total_tokens={}".format(cur_tokens))
+                logger.warn(
+                    "user message exceed max_tokens. total_tokens={}".format(cur_tokens)
+                )
                 break
             else:
-                logger.debug("max_tokens={}, total_tokens={}, len(messages)={}".format(max_tokens, cur_tokens, len(self.messages)))
+                logger.debug(
+                    "max_tokens={}, total_tokens={}, len(messages)={}".format(
+                        max_tokens, cur_tokens, len(self.messages)
+                    )
+                )
                 break
             if precise:
                 cur_tokens = self.calc_tokens()
             else:
                 cur_tokens = cur_tokens - max_tokens
         return cur_tokens
-    
+
     def calc_tokens(self):
         return num_tokens_from_messages(self.messages, self.model)
-    
+
 
 # refer to https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
 def num_tokens_from_messages(messages, model):
     """Returns the number of tokens used by a list of messages."""
     import tiktoken
+
     try:
         encoding = tiktoken.encoding_for_model(model)
     except KeyError:
@@ -63,13 +75,17 @@ def num_tokens_from_messages(messages, model):
     elif model == "gpt-4":
         return num_tokens_from_messages(messages, model="gpt-4-0314")
     elif model == "gpt-3.5-turbo-0301":
-        tokens_per_message = 4  # every message follows <|start|>{role/name}\n{content}<|end|>\n
+        tokens_per_message = (
+            4  # every message follows <|start|>{role/name}\n{content}<|end|>\n
+        )
         tokens_per_name = -1  # if there's a name, the role is omitted
     elif model == "gpt-4-0314":
         tokens_per_message = 3
         tokens_per_name = 1
     else:
-        logger.warn(f"num_tokens_from_messages() is not implemented for model {model}. Returning num tokens assuming gpt-3.5-turbo-0301.")
+        logger.warn(
+            f"num_tokens_from_messages() is not implemented for model {model}. Returning num tokens assuming gpt-3.5-turbo-0301."
+        )
         return num_tokens_from_messages(messages, model="gpt-3.5-turbo-0301")
     num_tokens = 0
     for message in messages:

+ 62 - 39
bot/openai/open_ai_bot.py

@@ -1,41 +1,52 @@
 # encoding:utf-8
 
+import time
+
+import openai
+import openai.error
+
 from bot.bot import Bot
 from bot.openai.open_ai_image import OpenAIImage
 from bot.openai.open_ai_session import OpenAISession
 from bot.session_manager import SessionManager
 from bridge.context import ContextType
 from bridge.reply import Reply, ReplyType
-from config import conf
 from common.log import logger
-import openai
-import openai.error
-import time
+from config import conf
 
 user_session = dict()
 
+
 # OpenAI对话模型API (可用)
 class OpenAIBot(Bot, OpenAIImage):
     def __init__(self):
         super().__init__()
-        openai.api_key = conf().get('open_ai_api_key')
-        if conf().get('open_ai_api_base'):
-            openai.api_base = conf().get('open_ai_api_base')
-        proxy = conf().get('proxy')
+        openai.api_key = conf().get("open_ai_api_key")
+        if conf().get("open_ai_api_base"):
+            openai.api_base = conf().get("open_ai_api_base")
+        proxy = conf().get("proxy")
         if proxy:
             openai.proxy = proxy
 
-        self.sessions = SessionManager(OpenAISession, model= conf().get("model") or "text-davinci-003")
+        self.sessions = SessionManager(
+            OpenAISession, model=conf().get("model") or "text-davinci-003"
+        )
         self.args = {
             "model": conf().get("model") or "text-davinci-003",  # 对话模型的名称
-            "temperature":conf().get('temperature', 0.9),  # 值在[0,1]之间,越大表示回复越具有不确定性
-            "max_tokens":1200,  # 回复最大的字符数
-            "top_p":1,
-            "frequency_penalty":conf().get('frequency_penalty', 0.0),  # [-2,2]之间,该值越大则更倾向于产生不同的内容
-            "presence_penalty":conf().get('presence_penalty', 0.0),  # [-2,2]之间,该值越大则更倾向于产生不同的内容
-            "request_timeout": conf().get('request_timeout', None),  # 请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间
-            "timeout": conf().get('request_timeout', None), #重试超时时间,在这个时间内,将会自动重试
-            "stop":["\n\n\n"]
+            "temperature": conf().get("temperature", 0.9),  # 值在[0,1]之间,越大表示回复越具有不确定性
+            "max_tokens": 1200,  # 回复最大的字符数
+            "top_p": 1,
+            "frequency_penalty": conf().get(
+                "frequency_penalty", 0.0
+            ),  # [-2,2]之间,该值越大则更倾向于产生不同的内容
+            "presence_penalty": conf().get(
+                "presence_penalty", 0.0
+            ),  # [-2,2]之间,该值越大则更倾向于产生不同的内容
+            "request_timeout": conf().get(
+                "request_timeout", None
+            ),  # 请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间
+            "timeout": conf().get("request_timeout", None),  # 重试超时时间,在这个时间内,将会自动重试
+            "stop": ["\n\n\n"],
         }
 
     def reply(self, query, context=None):
@@ -43,24 +54,34 @@ class OpenAIBot(Bot, OpenAIImage):
         if context and context.type:
             if context.type == ContextType.TEXT:
                 logger.info("[OPEN_AI] query={}".format(query))
-                session_id = context['session_id']
+                session_id = context["session_id"]
                 reply = None
-                if query == '#清除记忆':
+                if query == "#清除记忆":
                     self.sessions.clear_session(session_id)
-                    reply = Reply(ReplyType.INFO, '记忆已清除')
-                elif query == '#清除所有':
+                    reply = Reply(ReplyType.INFO, "记忆已清除")
+                elif query == "#清除所有":
                     self.sessions.clear_all_session()
-                    reply = Reply(ReplyType.INFO, '所有人记忆已清除')
+                    reply = Reply(ReplyType.INFO, "所有人记忆已清除")
                 else:
                     session = self.sessions.session_query(query, session_id)
                     result = self.reply_text(session)
-                    total_tokens, completion_tokens, reply_content = result['total_tokens'], result['completion_tokens'], result['content']
-                    logger.debug("[OPEN_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(str(session), session_id, reply_content, completion_tokens))
+                    total_tokens, completion_tokens, reply_content = (
+                        result["total_tokens"],
+                        result["completion_tokens"],
+                        result["content"],
+                    )
+                    logger.debug(
+                        "[OPEN_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(
+                            str(session), session_id, reply_content, completion_tokens
+                        )
+                    )
 
-                    if total_tokens == 0 :
+                    if total_tokens == 0:
                         reply = Reply(ReplyType.ERROR, reply_content)
                     else:
-                        self.sessions.session_reply(reply_content, session_id, total_tokens)
+                        self.sessions.session_reply(
+                            reply_content, session_id, total_tokens
+                        )
                         reply = Reply(ReplyType.TEXT, reply_content)
                 return reply
             elif context.type == ContextType.IMAGE_CREATE:
@@ -72,42 +93,44 @@ class OpenAIBot(Bot, OpenAIImage):
                     reply = Reply(ReplyType.ERROR, retstring)
                 return reply
 
-    def reply_text(self, session:OpenAISession, retry_count=0):
+    def reply_text(self, session: OpenAISession, retry_count=0):
         try:
-            response = openai.Completion.create(
-                prompt=str(session), **self.args
+            response = openai.Completion.create(prompt=str(session), **self.args)
+            res_content = (
+                response.choices[0]["text"].strip().replace("<|endoftext|>", "")
             )
-            res_content = response.choices[0]['text'].strip().replace('<|endoftext|>', '')
             total_tokens = response["usage"]["total_tokens"]
             completion_tokens = response["usage"]["completion_tokens"]
             logger.info("[OPEN_AI] reply={}".format(res_content))
-            return {"total_tokens": total_tokens,
-                    "completion_tokens": completion_tokens,
-                    "content": res_content}
+            return {
+                "total_tokens": total_tokens,
+                "completion_tokens": completion_tokens,
+                "content": res_content,
+            }
         except Exception as e:
             need_retry = retry_count < 2
             result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
             if isinstance(e, openai.error.RateLimitError):
                 logger.warn("[OPEN_AI] RateLimitError: {}".format(e))
-                result['content'] = "提问太快啦,请休息一下再问我吧"
+                result["content"] = "提问太快啦,请休息一下再问我吧"
                 if need_retry:
                     time.sleep(5)
             elif isinstance(e, openai.error.Timeout):
                 logger.warn("[OPEN_AI] Timeout: {}".format(e))
-                result['content'] = "我没有收到你的消息"
+                result["content"] = "我没有收到你的消息"
                 if need_retry:
                     time.sleep(5)
             elif isinstance(e, openai.error.APIConnectionError):
                 logger.warn("[OPEN_AI] APIConnectionError: {}".format(e))
                 need_retry = False
-                result['content'] = "我连接不到你的网络"
+                result["content"] = "我连接不到你的网络"
             else:
                 logger.warn("[OPEN_AI] Exception: {}".format(e))
                 need_retry = False
                 self.sessions.clear_session(session.session_id)
 
             if need_retry:
-                logger.warn("[OPEN_AI] 第{}次重试".format(retry_count+1))
-                return self.reply_text(session, retry_count+1)
+                logger.warn("[OPEN_AI] 第{}次重试".format(retry_count + 1))
+                return self.reply_text(session, retry_count + 1)
             else:
-                return result
+                return result

+ 20 - 13
bot/openai/open_ai_image.py

@@ -1,38 +1,45 @@
 import time
+
 import openai
 import openai.error
-from common.token_bucket import TokenBucket
+
 from common.log import logger
+from common.token_bucket import TokenBucket
 from config import conf
 
+
 # OPENAI提供的画图接口
 class OpenAIImage(object):
     def __init__(self):
-        openai.api_key = conf().get('open_ai_api_key')
-        if conf().get('rate_limit_dalle'):
-            self.tb4dalle = TokenBucket(conf().get('rate_limit_dalle', 50))
-            
+        openai.api_key = conf().get("open_ai_api_key")
+        if conf().get("rate_limit_dalle"):
+            self.tb4dalle = TokenBucket(conf().get("rate_limit_dalle", 50))
+
     def create_img(self, query, retry_count=0):
         try:
-            if conf().get('rate_limit_dalle') and not self.tb4dalle.get_token():
+            if conf().get("rate_limit_dalle") and not self.tb4dalle.get_token():
                 return False, "请求太快了,请休息一下再问我吧"
             logger.info("[OPEN_AI] image_query={}".format(query))
             response = openai.Image.create(
-                prompt=query,    #图片描述
-                n=1,             #每次生成图片的数量
-                size="256x256"   #图片大小,可选有 256x256, 512x512, 1024x1024
+                prompt=query,  # 图片描述
+                n=1,  # 每次生成图片的数量
+                size="256x256",  # 图片大小,可选有 256x256, 512x512, 1024x1024
             )
-            image_url = response['data'][0]['url']
+            image_url = response["data"][0]["url"]
             logger.info("[OPEN_AI] image_url={}".format(image_url))
             return True, image_url
         except openai.error.RateLimitError as e:
             logger.warn(e)
             if retry_count < 1:
                 time.sleep(5)
-                logger.warn("[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(retry_count+1))
-                return self.create_img(query, retry_count+1)
+                logger.warn(
+                    "[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(
+                        retry_count + 1
+                    )
+                )
+                return self.create_img(query, retry_count + 1)
             else:
                 return False, "提问太快啦,请休息一下再问我吧"
         except Exception as e:
             logger.exception(e)
-            return False, str(e)
+            return False, str(e)

+ 31 - 17
bot/openai/open_ai_session.py

@@ -1,32 +1,34 @@
 from bot.session_manager import Session
 from common.log import logger
+
+
 class OpenAISession(Session):
-    def __init__(self, session_id, system_prompt=None, model= "text-davinci-003"):
+    def __init__(self, session_id, system_prompt=None, model="text-davinci-003"):
         super().__init__(session_id, system_prompt)
         self.model = model
         self.reset()
 
     def __str__(self):
         # 构造对话模型的输入
-        '''
+        """
         e.g.  Q: xxx
               A: xxx
               Q: xxx
-        '''
+        """
         prompt = ""
         for item in self.messages:
-            if item['role'] == 'system':
-                prompt += item['content'] + "<|endoftext|>\n\n\n"
-            elif item['role'] == 'user':
-                prompt += "Q: " + item['content'] + "\n"
-            elif item['role'] == 'assistant':
-                prompt += "\n\nA: " + item['content'] + "<|endoftext|>\n"
+            if item["role"] == "system":
+                prompt += item["content"] + "<|endoftext|>\n\n\n"
+            elif item["role"] == "user":
+                prompt += "Q: " + item["content"] + "\n"
+            elif item["role"] == "assistant":
+                prompt += "\n\nA: " + item["content"] + "<|endoftext|>\n"
 
-        if len(self.messages) > 0 and self.messages[-1]['role'] == 'user':
+        if len(self.messages) > 0 and self.messages[-1]["role"] == "user":
             prompt += "A: "
         return prompt
 
-    def discard_exceeding(self, max_tokens, cur_tokens= None):
+    def discard_exceeding(self, max_tokens, cur_tokens=None):
         precise = True
         try:
             cur_tokens = self.calc_tokens()
@@ -34,7 +36,9 @@ class OpenAISession(Session):
             precise = False
             if cur_tokens is None:
                 raise e
-            logger.debug("Exception when counting tokens precisely for query: {}".format(e))
+            logger.debug(
+                "Exception when counting tokens precisely for query: {}".format(e)
+            )
         while cur_tokens > max_tokens:
             if len(self.messages) > 1:
                 self.messages.pop(0)
@@ -46,24 +50,34 @@ class OpenAISession(Session):
                     cur_tokens = len(str(self))
                 break
             elif len(self.messages) == 1 and self.messages[0]["role"] == "user":
-                logger.warn("user question exceed max_tokens. total_tokens={}".format(cur_tokens))
+                logger.warn(
+                    "user question exceed max_tokens. total_tokens={}".format(
+                        cur_tokens
+                    )
+                )
                 break
             else:
-                logger.debug("max_tokens={}, total_tokens={}, len(conversation)={}".format(max_tokens, cur_tokens, len(self.messages)))
+                logger.debug(
+                    "max_tokens={}, total_tokens={}, len(conversation)={}".format(
+                        max_tokens, cur_tokens, len(self.messages)
+                    )
+                )
                 break
             if precise:
                 cur_tokens = self.calc_tokens()
             else:
                 cur_tokens = len(str(self))
         return cur_tokens
-    
+
     def calc_tokens(self):
         return num_tokens_from_string(str(self), self.model)
 
+
 # refer to https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
 def num_tokens_from_string(string: str, model: str) -> int:
     """Returns the number of tokens in a text string."""
     import tiktoken
+
     encoding = tiktoken.encoding_for_model(model)
-    num_tokens = len(encoding.encode(string,disallowed_special=()))
-    return num_tokens
+    num_tokens = len(encoding.encode(string, disallowed_special=()))
+    return num_tokens

+ 31 - 18
bot/session_manager.py

@@ -2,6 +2,7 @@ from common.expired_dict import ExpiredDict
 from common.log import logger
 from config import conf
 
+
 class Session(object):
     def __init__(self, session_id, system_prompt=None):
         self.session_id = session_id
@@ -13,7 +14,7 @@ class Session(object):
 
     # 重置会话
     def reset(self):
-        system_item = {'role': 'system', 'content': self.system_prompt}
+        system_item = {"role": "system", "content": self.system_prompt}
         self.messages = [system_item]
 
     def set_system_prompt(self, system_prompt):
@@ -21,13 +22,13 @@ class Session(object):
         self.reset()
 
     def add_query(self, query):
-        user_item = {'role': 'user', 'content': query}
+        user_item = {"role": "user", "content": query}
         self.messages.append(user_item)
 
     def add_reply(self, reply):
-        assistant_item = {'role': 'assistant', 'content': reply}
+        assistant_item = {"role": "assistant", "content": reply}
         self.messages.append(assistant_item)
-    
+
     def discard_exceeding(self, max_tokens=None, cur_tokens=None):
         raise NotImplementedError
 
@@ -37,8 +38,8 @@ class Session(object):
 
 class SessionManager(object):
     def __init__(self, sessioncls, **session_args):
-        if conf().get('expires_in_seconds'):
-            sessions = ExpiredDict(conf().get('expires_in_seconds'))
+        if conf().get("expires_in_seconds"):
+            sessions = ExpiredDict(conf().get("expires_in_seconds"))
         else:
             sessions = dict()
         self.sessions = sessions
@@ -46,20 +47,22 @@ class SessionManager(object):
         self.session_args = session_args
 
     def build_session(self, session_id, system_prompt=None):
-        '''
-            如果session_id不在sessions中,创建一个新的session并添加到sessions中
-            如果system_prompt不会空,会更新session的system_prompt并重置session
-        '''
+        """
+        如果session_id不在sessions中,创建一个新的session并添加到sessions中
+        如果system_prompt不会空,会更新session的system_prompt并重置session
+        """
         if session_id is None:
             return self.sessioncls(session_id, system_prompt, **self.session_args)
-        
+
         if session_id not in self.sessions:
-            self.sessions[session_id] = self.sessioncls(session_id, system_prompt, **self.session_args)
+            self.sessions[session_id] = self.sessioncls(
+                session_id, system_prompt, **self.session_args
+            )
         elif system_prompt is not None:  # 如果有新的system_prompt,更新并重置session
             self.sessions[session_id].set_system_prompt(system_prompt)
         session = self.sessions[session_id]
         return session
-    
+
     def session_query(self, query, session_id):
         session = self.build_session(session_id)
         session.add_query(query)
@@ -68,23 +71,33 @@ class SessionManager(object):
             total_tokens = session.discard_exceeding(max_tokens, None)
             logger.debug("prompt tokens used={}".format(total_tokens))
         except Exception as e:
-            logger.debug("Exception when counting tokens precisely for prompt: {}".format(str(e)))
+            logger.debug(
+                "Exception when counting tokens precisely for prompt: {}".format(str(e))
+            )
         return session
 
-    def session_reply(self, reply, session_id, total_tokens = None):
+    def session_reply(self, reply, session_id, total_tokens=None):
         session = self.build_session(session_id)
         session.add_reply(reply)
         try:
             max_tokens = conf().get("conversation_max_tokens", 1000)
             tokens_cnt = session.discard_exceeding(max_tokens, total_tokens)
-            logger.debug("raw total_tokens={}, savesession tokens={}".format(total_tokens, tokens_cnt))
+            logger.debug(
+                "raw total_tokens={}, savesession tokens={}".format(
+                    total_tokens, tokens_cnt
+                )
+            )
         except Exception as e:
-            logger.debug("Exception when counting tokens precisely for session: {}".format(str(e)))
+            logger.debug(
+                "Exception when counting tokens precisely for session: {}".format(
+                    str(e)
+                )
+            )
         return session
 
     def clear_session(self, session_id):
         if session_id in self.sessions:
-            del(self.sessions[session_id])
+            del self.sessions[session_id]
 
     def clear_all_session(self):
         self.sessions.clear()

+ 13 - 16
bridge/bridge.py

@@ -1,31 +1,31 @@
+from bot import bot_factory
 from bridge.context import Context
 from bridge.reply import Reply
+from common import const
 from common.log import logger
-from bot import bot_factory
 from common.singleton import singleton
-from voice import voice_factory
 from config import conf
-from common import const
+from voice import voice_factory
 
 
 @singleton
 class Bridge(object):
     def __init__(self):
-        self.btype={
+        self.btype = {
             "chat": const.CHATGPT,
             "voice_to_text": conf().get("voice_to_text", "openai"),
-            "text_to_voice": conf().get("text_to_voice", "google")
+            "text_to_voice": conf().get("text_to_voice", "google"),
         }
         model_type = conf().get("model")
         if model_type in ["text-davinci-003"]:
-            self.btype['chat'] = const.OPEN_AI
+            self.btype["chat"] = const.OPEN_AI
         if conf().get("use_azure_chatgpt", False):
-            self.btype['chat'] = const.CHATGPTONAZURE
-        self.bots={}
+            self.btype["chat"] = const.CHATGPTONAZURE
+        self.bots = {}
 
-    def get_bot(self,typename):
+    def get_bot(self, typename):
         if self.bots.get(typename) is None:
-            logger.info("create bot {} for {}".format(self.btype[typename],typename))
+            logger.info("create bot {} for {}".format(self.btype[typename], typename))
             if typename == "text_to_voice":
                 self.bots[typename] = voice_factory.create_voice(self.btype[typename])
             elif typename == "voice_to_text":
@@ -33,18 +33,15 @@ class Bridge(object):
             elif typename == "chat":
                 self.bots[typename] = bot_factory.create_bot(self.btype[typename])
         return self.bots[typename]
-    
-    def get_bot_type(self,typename):
-        return self.btype[typename]
 
+    def get_bot_type(self, typename):
+        return self.btype[typename]
 
-    def fetch_reply_content(self, query, context : Context) -> Reply:
+    def fetch_reply_content(self, query, context: Context) -> Reply:
         return self.get_bot("chat").reply(query, context)
 
-
     def fetch_voice_to_text(self, voiceFile) -> Reply:
         return self.get_bot("voice_to_text").voiceToText(voiceFile)
 
     def fetch_text_to_voice(self, text) -> Reply:
         return self.get_bot("text_to_voice").textToVoice(text)
-

+ 24 - 19
bridge/context.py

@@ -2,36 +2,39 @@
 
 from enum import Enum
 
-class ContextType (Enum):
-    TEXT = 1         # 文本消息
-    VOICE = 2        # 音频消息
-    IMAGE = 3        # 图片消息
-    IMAGE_CREATE = 10 # 创建图片命令
-    
+
+class ContextType(Enum):
+    TEXT = 1  # 文本消息
+    VOICE = 2  # 音频消息
+    IMAGE = 3  # 图片消息
+    IMAGE_CREATE = 10  # 创建图片命令
+
     def __str__(self):
         return self.name
+
+
 class Context:
-    def __init__(self, type : ContextType = None , content = None,  kwargs = dict()):
+    def __init__(self, type: ContextType = None, content=None, kwargs=dict()):
         self.type = type
         self.content = content
         self.kwargs = kwargs
 
     def __contains__(self, key):
-        if key == 'type':
+        if key == "type":
             return self.type is not None
-        elif key == 'content':
+        elif key == "content":
             return self.content is not None
         else:
             return key in self.kwargs
-        
+
     def __getitem__(self, key):
-        if key == 'type':
+        if key == "type":
             return self.type
-        elif key == 'content':
+        elif key == "content":
             return self.content
         else:
             return self.kwargs[key]
-    
+
     def get(self, key, default=None):
         try:
             return self[key]
@@ -39,20 +42,22 @@ class Context:
             return default
 
     def __setitem__(self, key, value):
-        if key == 'type':
+        if key == "type":
             self.type = value
-        elif key == 'content':
+        elif key == "content":
             self.content = value
         else:
             self.kwargs[key] = value
 
     def __delitem__(self, key):
-        if key == 'type':
+        if key == "type":
             self.type = None
-        elif key == 'content':
+        elif key == "content":
             self.content = None
         else:
             del self.kwargs[key]
-    
+
     def __str__(self):
-        return "Context(type={}, content={}, kwargs={})".format(self.type, self.content, self.kwargs)
+        return "Context(type={}, content={}, kwargs={})".format(
+            self.type, self.content, self.kwargs
+        )

+ 11 - 8
bridge/reply.py

@@ -1,22 +1,25 @@
-
 # encoding:utf-8
 
 from enum import Enum
 
+
 class ReplyType(Enum):
-    TEXT = 1        # 文本
-    VOICE = 2       # 音频文件
-    IMAGE = 3       # 图片文件
-    IMAGE_URL = 4   # 图片URL
-    
+    TEXT = 1  # 文本
+    VOICE = 2  # 音频文件
+    IMAGE = 3  # 图片文件
+    IMAGE_URL = 4  # 图片URL
+
     INFO = 9
     ERROR = 10
+
     def __str__(self):
         return self.name
 
+
 class Reply:
-    def __init__(self, type : ReplyType = None , content = None):
+    def __init__(self, type: ReplyType = None, content=None):
         self.type = type
         self.content = content
+
     def __str__(self):
-        return "Reply(type={}, content={})".format(self.type, self.content)
+        return "Reply(type={}, content={})".format(self.type, self.content)

+ 5 - 3
channel/channel.py

@@ -6,8 +6,10 @@ from bridge.bridge import Bridge
 from bridge.context import Context
 from bridge.reply import *
 
+
 class Channel(object):
     NOT_SUPPORT_REPLYTYPE = [ReplyType.VOICE, ReplyType.IMAGE]
+
     def startup(self):
         """
         init channel
@@ -27,15 +29,15 @@ class Channel(object):
         send message to user
         :param msg: message content
         :param receiver: receiver channel account
-        :return: 
+        :return:
         """
         raise NotImplementedError
 
-    def build_reply_content(self, query, context : Context=None) -> Reply:
+    def build_reply_content(self, query, context: Context = None) -> Reply:
         return Bridge().fetch_reply_content(query, context)
 
     def build_voice_to_text(self, voice_file) -> Reply:
         return Bridge().fetch_voice_to_text(voice_file)
-    
+
     def build_text_to_voice(self, text) -> Reply:
         return Bridge().fetch_text_to_voice(text)

+ 13 - 7
channel/channel_factory.py

@@ -2,25 +2,31 @@
 channel factory
 """
 
+
 def create_channel(channel_type):
     """
     create a channel instance
     :param channel_type: channel type code
     :return: channel instance
     """
-    if channel_type == 'wx':
+    if channel_type == "wx":
         from channel.wechat.wechat_channel import WechatChannel
+
         return WechatChannel()
-    elif channel_type == 'wxy':
+    elif channel_type == "wxy":
         from channel.wechat.wechaty_channel import WechatyChannel
+
         return WechatyChannel()
-    elif channel_type == 'terminal':
+    elif channel_type == "terminal":
         from channel.terminal.terminal_channel import TerminalChannel
+
         return TerminalChannel()
-    elif channel_type == 'wechatmp':
+    elif channel_type == "wechatmp":
         from channel.wechatmp.wechatmp_channel import WechatMPChannel
-        return WechatMPChannel(passive_reply = True)
-    elif channel_type == 'wechatmp_service':
+
+        return WechatMPChannel(passive_reply=True)
+    elif channel_type == "wechatmp_service":
         from channel.wechatmp.wechatmp_channel import WechatMPChannel
-        return WechatMPChannel(passive_reply = False)
+
+        return WechatMPChannel(passive_reply=False)
     raise RuntimeError

+ 213 - 112
channel/chat_channel.py

@@ -1,137 +1,172 @@
-
-
-from asyncio import CancelledError
-from concurrent.futures import Future, ThreadPoolExecutor
 import os
 import re
 import threading
 import time
-from common.dequeue import Dequeue
-from channel.channel import Channel
-from bridge.reply import *
+from asyncio import CancelledError
+from concurrent.futures import Future, ThreadPoolExecutor
+
 from bridge.context import *
-from config import conf
+from bridge.reply import *
+from channel.channel import Channel
+from common.dequeue import Dequeue
 from common.log import logger
+from config import conf
 from plugins import *
+
 try:
     from voice.audio_convert import any_to_wav
 except Exception as e:
     pass
 
+
 # 抽象类, 它包含了与消息通道无关的通用处理逻辑
 class ChatChannel(Channel):
-    name = None # 登录的用户名
-    user_id = None # 登录的用户id
-    futures = {} # 记录每个session_id提交到线程池的future对象, 用于重置会话时把没执行的future取消掉,正在执行的不会被取消
-    sessions = {} # 用于控制并发,每个session_id同时只能有一个context在处理
-    lock = threading.Lock() # 用于控制对sessions的访问
+    name = None  # 登录的用户名
+    user_id = None  # 登录的用户id
+    futures = {}  # 记录每个session_id提交到线程池的future对象, 用于重置会话时把没执行的future取消掉,正在执行的不会被取消
+    sessions = {}  # 用于控制并发,每个session_id同时只能有一个context在处理
+    lock = threading.Lock()  # 用于控制对sessions的访问
     handler_pool = ThreadPoolExecutor(max_workers=8)  # 处理消息的线程池
 
     def __init__(self):
         _thread = threading.Thread(target=self.consume)
         _thread.setDaemon(True)
         _thread.start()
-        
 
     # 根据消息构造context,消息内容相关的触发项写在这里
     def _compose_context(self, ctype: ContextType, content, **kwargs):
         context = Context(ctype, content)
         context.kwargs = kwargs
-        # context首次传入时,origin_ctype是None, 
+        # context首次传入时,origin_ctype是None,
         # 引入的起因是:当输入语音时,会嵌套生成两个context,第一步语音转文本,第二步通过文本生成文字回复。
         # origin_ctype用于第二步文本回复时,判断是否需要匹配前缀,如果是私聊的语音,就不需要匹配前缀
-        if 'origin_ctype' not in context:  
-            context['origin_ctype'] = ctype
+        if "origin_ctype" not in context:
+            context["origin_ctype"] = ctype
         # context首次传入时,receiver是None,根据类型设置receiver
-        first_in = 'receiver' not in context
+        first_in = "receiver" not in context
         # 群名匹配过程,设置session_id和receiver
-        if first_in: # context首次传入时,receiver是None,根据类型设置receiver
+        if first_in:  # context首次传入时,receiver是None,根据类型设置receiver
             config = conf()
-            cmsg = context['msg']
+            cmsg = context["msg"]
             if context.get("isgroup", False):
                 group_name = cmsg.other_user_nickname
                 group_id = cmsg.other_user_id
 
-                group_name_white_list = config.get('group_name_white_list', [])
-                group_name_keyword_white_list = config.get('group_name_keyword_white_list', [])
-                if any([group_name in group_name_white_list, 'ALL_GROUP' in group_name_white_list, check_contain(group_name, group_name_keyword_white_list)]):
-                    group_chat_in_one_session = conf().get('group_chat_in_one_session', [])
+                group_name_white_list = config.get("group_name_white_list", [])
+                group_name_keyword_white_list = config.get(
+                    "group_name_keyword_white_list", []
+                )
+                if any(
+                    [
+                        group_name in group_name_white_list,
+                        "ALL_GROUP" in group_name_white_list,
+                        check_contain(group_name, group_name_keyword_white_list),
+                    ]
+                ):
+                    group_chat_in_one_session = conf().get(
+                        "group_chat_in_one_session", []
+                    )
                     session_id = cmsg.actual_user_id
-                    if any([group_name in group_chat_in_one_session, 'ALL_GROUP' in group_chat_in_one_session]):
+                    if any(
+                        [
+                            group_name in group_chat_in_one_session,
+                            "ALL_GROUP" in group_chat_in_one_session,
+                        ]
+                    ):
                         session_id = group_id
                 else:
                     return None
-                context['session_id'] = session_id
-                context['receiver'] = group_id
+                context["session_id"] = session_id
+                context["receiver"] = group_id
             else:
-                context['session_id'] = cmsg.other_user_id
-                context['receiver'] = cmsg.other_user_id
-            e_context = PluginManager().emit_event(EventContext(Event.ON_RECEIVE_MESSAGE, {'channel': self, 'context': context}))
-            context = e_context['context']
+                context["session_id"] = cmsg.other_user_id
+                context["receiver"] = cmsg.other_user_id
+            e_context = PluginManager().emit_event(
+                EventContext(
+                    Event.ON_RECEIVE_MESSAGE, {"channel": self, "context": context}
+                )
+            )
+            context = e_context["context"]
             if e_context.is_pass() or context is None:
                 return context
-            if cmsg.from_user_id == self.user_id and not config.get('trigger_by_self', True):
+            if cmsg.from_user_id == self.user_id and not config.get(
+                "trigger_by_self", True
+            ):
                 logger.debug("[WX]self message skipped")
                 return None
 
         # 消息内容匹配过程,并处理content
         if ctype == ContextType.TEXT:
-            if first_in and "」\n- - - - - - -" in content: # 初次匹配 过滤引用消息
+            if first_in and "」\n- - - - - - -" in content:  # 初次匹配 过滤引用消息
                 logger.debug("[WX]reference query skipped")
                 return None
-            
-            if context.get("isgroup", False): # 群聊
+
+            if context.get("isgroup", False):  # 群聊
                 # 校验关键字
-                match_prefix = check_prefix(content, conf().get('group_chat_prefix'))
-                match_contain = check_contain(content, conf().get('group_chat_keyword'))
+                match_prefix = check_prefix(content, conf().get("group_chat_prefix"))
+                match_contain = check_contain(content, conf().get("group_chat_keyword"))
                 flag = False
                 if match_prefix is not None or match_contain is not None:
                     flag = True
                     if match_prefix:
-                        content = content.replace(match_prefix, '', 1).strip()
-                if context['msg'].is_at:
+                        content = content.replace(match_prefix, "", 1).strip()
+                if context["msg"].is_at:
                     logger.info("[WX]receive group at")
                     if not conf().get("group_at_off", False):
                         flag = True
-                    pattern = f'@{self.name}(\u2005|\u0020)'
-                    content = re.sub(pattern, r'', content)
-                
+                    pattern = f"@{self.name}(\u2005|\u0020)"
+                    content = re.sub(pattern, r"", content)
+
                 if not flag:
                     if context["origin_ctype"] == ContextType.VOICE:
-                        logger.info("[WX]receive group voice, but checkprefix didn't match")
+                        logger.info(
+                            "[WX]receive group voice, but checkprefix didn't match"
+                        )
                     return None
-            else: # 单聊
-                match_prefix = check_prefix(content, conf().get('single_chat_prefix',['']))
-                if match_prefix is not None: # 判断如果匹配到自定义前缀,则返回过滤掉前缀+空格后的内容
-                    content = content.replace(match_prefix, '', 1).strip()
-                elif context["origin_ctype"] == ContextType.VOICE: # 如果源消息是私聊的语音消息,允许不匹配前缀,放宽条件
+            else:  # 单聊
+                match_prefix = check_prefix(
+                    content, conf().get("single_chat_prefix", [""])
+                )
+                if match_prefix is not None:  # 判断如果匹配到自定义前缀,则返回过滤掉前缀+空格后的内容
+                    content = content.replace(match_prefix, "", 1).strip()
+                elif (
+                    context["origin_ctype"] == ContextType.VOICE
+                ):  # 如果源消息是私聊的语音消息,允许不匹配前缀,放宽条件
                     pass
                 else:
-                    return None     
-                                                  
-            img_match_prefix = check_prefix(content, conf().get('image_create_prefix'))
+                    return None
+
+            img_match_prefix = check_prefix(content, conf().get("image_create_prefix"))
             if img_match_prefix:
-                content = content.replace(img_match_prefix, '', 1)
+                content = content.replace(img_match_prefix, "", 1)
                 context.type = ContextType.IMAGE_CREATE
             else:
                 context.type = ContextType.TEXT
             context.content = content.strip()
-            if 'desire_rtype' not in context and conf().get('always_reply_voice') and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE:
-                context['desire_rtype'] = ReplyType.VOICE
-        elif context.type == ContextType.VOICE: 
-            if 'desire_rtype' not in context and conf().get('voice_reply_voice') and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE:
-                context['desire_rtype'] = ReplyType.VOICE
+            if (
+                "desire_rtype" not in context
+                and conf().get("always_reply_voice")
+                and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE
+            ):
+                context["desire_rtype"] = ReplyType.VOICE
+        elif context.type == ContextType.VOICE:
+            if (
+                "desire_rtype" not in context
+                and conf().get("voice_reply_voice")
+                and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE
+            ):
+                context["desire_rtype"] = ReplyType.VOICE
 
         return context
 
     def _handle(self, context: Context):
         if context is None or not context.content:
             return
-        logger.debug('[WX] ready to handle context: {}'.format(context))
+        logger.debug("[WX] ready to handle context: {}".format(context))
         # reply的构建步骤
         reply = self._generate_reply(context)
 
-        logger.debug('[WX] ready to decorate reply: {}'.format(reply))
+        logger.debug("[WX] ready to decorate reply: {}".format(reply))
         # reply的包装步骤
         reply = self._decorate_reply(context, reply)
 
@@ -139,20 +174,31 @@ class ChatChannel(Channel):
         self._send_reply(context, reply)
 
     def _generate_reply(self, context: Context, reply: Reply = Reply()) -> Reply:
-        e_context = PluginManager().emit_event(EventContext(Event.ON_HANDLE_CONTEXT, {
-            'channel': self, 'context': context, 'reply': reply}))
-        reply = e_context['reply']
+        e_context = PluginManager().emit_event(
+            EventContext(
+                Event.ON_HANDLE_CONTEXT,
+                {"channel": self, "context": context, "reply": reply},
+            )
+        )
+        reply = e_context["reply"]
         if not e_context.is_pass():
-            logger.debug('[WX] ready to handle context: type={}, content={}'.format(context.type, context.content))
-            if context.type == ContextType.TEXT or context.type == ContextType.IMAGE_CREATE:  # 文字和图片消息
+            logger.debug(
+                "[WX] ready to handle context: type={}, content={}".format(
+                    context.type, context.content
+                )
+            )
+            if (
+                context.type == ContextType.TEXT
+                or context.type == ContextType.IMAGE_CREATE
+            ):  # 文字和图片消息
                 reply = super().build_reply_content(context.content, context)
             elif context.type == ContextType.VOICE:  # 语音消息
-                cmsg = context['msg']
+                cmsg = context["msg"]
                 cmsg.prepare()
                 file_path = context.content
-                wav_path = os.path.splitext(file_path)[0] + '.wav'
+                wav_path = os.path.splitext(file_path)[0] + ".wav"
                 try:
-                    any_to_wav(file_path, wav_path) 
+                    any_to_wav(file_path, wav_path)
                 except Exception as e:  # 转换失败,直接使用mp3,对于某些api,mp3也可以识别
                     logger.warning("[WX]any to wav error, use raw path. " + str(e))
                     wav_path = file_path
@@ -169,7 +215,8 @@ class ChatChannel(Channel):
 
                 if reply.type == ReplyType.TEXT:
                     new_context = self._compose_context(
-                        ContextType.TEXT, reply.content, **context.kwargs)
+                        ContextType.TEXT, reply.content, **context.kwargs
+                    )
                     if new_context:
                         reply = self._generate_reply(new_context)
                     else:
@@ -177,18 +224,21 @@ class ChatChannel(Channel):
             elif context.type == ContextType.IMAGE:  # 图片消息,当前无默认逻辑
                 pass
             else:
-                logger.error('[WX] unknown context type: {}'.format(context.type))
+                logger.error("[WX] unknown context type: {}".format(context.type))
                 return
         return reply
 
     def _decorate_reply(self, context: Context, reply: Reply) -> Reply:
         if reply and reply.type:
-            e_context = PluginManager().emit_event(EventContext(Event.ON_DECORATE_REPLY, {
-                'channel': self, 'context': context, 'reply': reply}))
-            reply = e_context['reply']
-            desire_rtype = context.get('desire_rtype')
+            e_context = PluginManager().emit_event(
+                EventContext(
+                    Event.ON_DECORATE_REPLY,
+                    {"channel": self, "context": context, "reply": reply},
+                )
+            )
+            reply = e_context["reply"]
+            desire_rtype = context.get("desire_rtype")
             if not e_context.is_pass() and reply and reply.type:
-                
                 if reply.type in self.NOT_SUPPORT_REPLYTYPE:
                     logger.error("[WX]reply type not support: " + str(reply.type))
                     reply.type = ReplyType.ERROR
@@ -196,59 +246,91 @@ class ChatChannel(Channel):
 
                 if reply.type == ReplyType.TEXT:
                     reply_text = reply.content
-                    if desire_rtype == ReplyType.VOICE and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE:
+                    if (
+                        desire_rtype == ReplyType.VOICE
+                        and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE
+                    ):
                         reply = super().build_text_to_voice(reply.content)
                         return self._decorate_reply(context, reply)
                     if context.get("isgroup", False):
-                        reply_text = '@' +  context['msg'].actual_user_nickname + ' ' + reply_text.strip()
-                        reply_text = conf().get("group_chat_reply_prefix", "") + reply_text
+                        reply_text = (
+                            "@"
+                            + context["msg"].actual_user_nickname
+                            + " "
+                            + reply_text.strip()
+                        )
+                        reply_text = (
+                            conf().get("group_chat_reply_prefix", "") + reply_text
+                        )
                     else:
-                        reply_text = conf().get("single_chat_reply_prefix", "") + reply_text
+                        reply_text = (
+                            conf().get("single_chat_reply_prefix", "") + reply_text
+                        )
                     reply.content = reply_text
                 elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
-                    reply.content = "["+str(reply.type)+"]\n" + reply.content
-                elif reply.type == ReplyType.IMAGE_URL or reply.type == ReplyType.VOICE or reply.type == ReplyType.IMAGE:
+                    reply.content = "[" + str(reply.type) + "]\n" + reply.content
+                elif (
+                    reply.type == ReplyType.IMAGE_URL
+                    or reply.type == ReplyType.VOICE
+                    or reply.type == ReplyType.IMAGE
+                ):
                     pass
                 else:
-                    logger.error('[WX] unknown reply type: {}'.format(reply.type))
+                    logger.error("[WX] unknown reply type: {}".format(reply.type))
                     return
-            if desire_rtype and desire_rtype != reply.type and reply.type not in [ReplyType.ERROR, ReplyType.INFO]:
-                logger.warning('[WX] desire_rtype: {}, but reply type: {}'.format(context.get('desire_rtype'), reply.type))
+            if (
+                desire_rtype
+                and desire_rtype != reply.type
+                and reply.type not in [ReplyType.ERROR, ReplyType.INFO]
+            ):
+                logger.warning(
+                    "[WX] desire_rtype: {}, but reply type: {}".format(
+                        context.get("desire_rtype"), reply.type
+                    )
+                )
             return reply
 
     def _send_reply(self, context: Context, reply: Reply):
         if reply and reply.type:
-            e_context = PluginManager().emit_event(EventContext(Event.ON_SEND_REPLY, {
-                'channel': self, 'context': context, 'reply': reply}))
-            reply = e_context['reply']
+            e_context = PluginManager().emit_event(
+                EventContext(
+                    Event.ON_SEND_REPLY,
+                    {"channel": self, "context": context, "reply": reply},
+                )
+            )
+            reply = e_context["reply"]
             if not e_context.is_pass() and reply and reply.type:
-                logger.debug('[WX] ready to send reply: {}, context: {}'.format(reply, context))
+                logger.debug(
+                    "[WX] ready to send reply: {}, context: {}".format(reply, context)
+                )
                 self._send(reply, context)
 
-    def _send(self, reply: Reply, context: Context, retry_cnt = 0):
+    def _send(self, reply: Reply, context: Context, retry_cnt=0):
         try:
             self.send(reply, context)
         except Exception as e:
-            logger.error('[WX] sendMsg error: {}'.format(str(e)))
+            logger.error("[WX] sendMsg error: {}".format(str(e)))
             if isinstance(e, NotImplementedError):
                 return
             logger.exception(e)
             if retry_cnt < 2:
-                time.sleep(3+3*retry_cnt)
-                self._send(reply, context, retry_cnt+1)
+                time.sleep(3 + 3 * retry_cnt)
+                self._send(reply, context, retry_cnt + 1)
 
-    def _success_callback(self, session_id, **kwargs):# 线程正常结束时的回调函数
+    def _success_callback(self, session_id, **kwargs):  # 线程正常结束时的回调函数
         logger.debug("Worker return success, session_id = {}".format(session_id))
 
-    def _fail_callback(self, session_id, exception, **kwargs): # 线程异常结束时的回调函数
+    def _fail_callback(self, session_id, exception, **kwargs):  # 线程异常结束时的回调函数
         logger.exception("Worker return exception: {}".format(exception))
 
     def _thread_pool_callback(self, session_id, **kwargs):
-        def func(worker:Future):
+        def func(worker: Future):
             try:
                 worker_exception = worker.exception()
                 if worker_exception:
-                    self._fail_callback(session_id, exception = worker_exception, **kwargs)
+                    self._fail_callback(
+                        session_id, exception=worker_exception, **kwargs
+                    )
                 else:
                     self._success_callback(session_id, **kwargs)
             except CancelledError as e:
@@ -257,15 +339,19 @@ class ChatChannel(Channel):
                 logger.exception("Worker raise exception: {}".format(e))
             with self.lock:
                 self.sessions[session_id][1].release()
+
         return func
 
     def produce(self, context: Context):
-        session_id = context['session_id']
+        session_id = context["session_id"]
         with self.lock:
             if session_id not in self.sessions:
-                self.sessions[session_id] = [Dequeue(), threading.BoundedSemaphore(conf().get("concurrency_in_session", 4))]
-            if context.type == ContextType.TEXT and context.content.startswith("#"): 
-                self.sessions[session_id][0].putleft(context) # 优先处理管理命令
+                self.sessions[session_id] = [
+                    Dequeue(),
+                    threading.BoundedSemaphore(conf().get("concurrency_in_session", 4)),
+                ]
+            if context.type == ContextType.TEXT and context.content.startswith("#"):
+                self.sessions[session_id][0].putleft(context)  # 优先处理管理命令
             else:
                 self.sessions[session_id][0].put(context)
 
@@ -276,44 +362,58 @@ class ChatChannel(Channel):
                 session_ids = list(self.sessions.keys())
                 for session_id in session_ids:
                     context_queue, semaphore = self.sessions[session_id]
-                    if semaphore.acquire(blocking = False): # 等线程处理完毕才能删除
+                    if semaphore.acquire(blocking=False):  # 等线程处理完毕才能删除
                         if not context_queue.empty():
                             context = context_queue.get()
                             logger.debug("[WX] consume context: {}".format(context))
-                            future:Future = self.handler_pool.submit(self._handle, context)
-                            future.add_done_callback(self._thread_pool_callback(session_id, context = context))
+                            future: Future = self.handler_pool.submit(
+                                self._handle, context
+                            )
+                            future.add_done_callback(
+                                self._thread_pool_callback(session_id, context=context)
+                            )
                             if session_id not in self.futures:
                                 self.futures[session_id] = []
                             self.futures[session_id].append(future)
-                        elif semaphore._initial_value == semaphore._value+1: # 除了当前,没有任务再申请到信号量,说明所有任务都处理完毕
-                            self.futures[session_id] = [t for t in self.futures[session_id] if not t.done()]
-                            assert len(self.futures[session_id]) == 0, "thread pool error"
+                        elif (
+                            semaphore._initial_value == semaphore._value + 1
+                        ):  # 除了当前,没有任务再申请到信号量,说明所有任务都处理完毕
+                            self.futures[session_id] = [
+                                t for t in self.futures[session_id] if not t.done()
+                            ]
+                            assert (
+                                len(self.futures[session_id]) == 0
+                            ), "thread pool error"
                             del self.sessions[session_id]
                         else:
                             semaphore.release()
             time.sleep(0.1)
 
     # 取消session_id对应的所有任务,只能取消排队的消息和已提交线程池但未执行的任务
-    def cancel_session(self, session_id): 
+    def cancel_session(self, session_id):
         with self.lock:
             if session_id in self.sessions:
                 for future in self.futures[session_id]:
                     future.cancel()
                 cnt = self.sessions[session_id][0].qsize()
-                if cnt>0:
-                    logger.info("Cancel {} messages in session {}".format(cnt, session_id))
+                if cnt > 0:
+                    logger.info(
+                        "Cancel {} messages in session {}".format(cnt, session_id)
+                    )
                 self.sessions[session_id][0] = Dequeue()
-    
+
     def cancel_all_session(self):
         with self.lock:
             for session_id in self.sessions:
                 for future in self.futures[session_id]:
                     future.cancel()
                 cnt = self.sessions[session_id][0].qsize()
-                if cnt>0:
-                    logger.info("Cancel {} messages in session {}".format(cnt, session_id))
+                if cnt > 0:
+                    logger.info(
+                        "Cancel {} messages in session {}".format(cnt, session_id)
+                    )
                 self.sessions[session_id][0] = Dequeue()
-    
+
 
 def check_prefix(content, prefix_list):
     if not prefix_list:
@@ -323,6 +423,7 @@ def check_prefix(content, prefix_list):
             return prefix
     return None
 
+
 def check_contain(content, keyword_list):
     if not keyword_list:
         return None

+ 10 - 10
channel/chat_message.py

@@ -1,5 +1,4 @@
-
-""" 
+"""
 本类表示聊天消息,用于对itchat和wechaty的消息进行统一的封装。
 
 填好必填项(群聊6个,非群聊8个),即可接入ChatChannel,并支持插件,参考TerminalChannel
@@ -20,7 +19,7 @@ other_user_id: 对方的id,如果你是发送者,那这个就是接收者id
 other_user_nickname: 同上
 
 is_group: 是否是群消息 (群聊必填)
-is_at: 是否被at 
+is_at: 是否被at
 
 - (群消息时,一般会存在实际发送者,是群内某个成员的id和昵称,下列项仅在群消息时存在)
 actual_user_id: 实际发送者id (群聊必填)
@@ -34,20 +33,22 @@ _prepared: 是否已经调用过准备函数
 _rawmsg: 原始消息对象
 
 """
+
+
 class ChatMessage(object):
     msg_id = None
     create_time = None
-    
+
     ctype = None
     content = None
-    
+
     from_user_id = None
     from_user_nickname = None
     to_user_id = None
     to_user_nickname = None
     other_user_id = None
     other_user_nickname = None
-    
+
     is_group = False
     is_at = False
     actual_user_id = None
@@ -57,8 +58,7 @@ class ChatMessage(object):
     _prepared = False
     _rawmsg = None
 
-
-    def __init__(self,_rawmsg):
+    def __init__(self, _rawmsg):
         self._rawmsg = _rawmsg
 
     def prepare(self):
@@ -67,7 +67,7 @@ class ChatMessage(object):
             self._prepare_fn()
 
     def __str__(self):
-        return 'ChatMessage: id={}, create_time={}, ctype={}, content={}, from_user_id={}, from_user_nickname={}, to_user_id={}, to_user_nickname={}, other_user_id={}, other_user_nickname={}, is_group={}, is_at={}, actual_user_id={}, actual_user_nickname={}'.format(
+        return "ChatMessage: id={}, create_time={}, ctype={}, content={}, from_user_id={}, from_user_nickname={}, to_user_id={}, to_user_nickname={}, other_user_id={}, other_user_nickname={}, is_group={}, is_at={}, actual_user_id={}, actual_user_nickname={}".format(
             self.msg_id,
             self.create_time,
             self.ctype,
@@ -82,4 +82,4 @@ class ChatMessage(object):
             self.is_at,
             self.actual_user_id,
             self.actual_user_nickname,
-        )
+        )

+ 26 - 10
channel/terminal/terminal_channel.py

@@ -1,14 +1,23 @@
+import sys
+
 from bridge.context import *
 from bridge.reply import Reply, ReplyType
 from channel.chat_channel import ChatChannel, check_prefix
 from channel.chat_message import ChatMessage
-import sys
-
-from config import conf
 from common.log import logger
+from config import conf
+
 
 class TerminalMessage(ChatMessage):
-    def __init__(self, msg_id, content, ctype = ContextType.TEXT,  from_user_id = "User", to_user_id = "Chatgpt", other_user_id = "Chatgpt"):
+    def __init__(
+        self,
+        msg_id,
+        content,
+        ctype=ContextType.TEXT,
+        from_user_id="User",
+        to_user_id="Chatgpt",
+        other_user_id="Chatgpt",
+    ):
         self.msg_id = msg_id
         self.ctype = ctype
         self.content = content
@@ -16,6 +25,7 @@ class TerminalMessage(ChatMessage):
         self.to_user_id = to_user_id
         self.other_user_id = other_user_id
 
+
 class TerminalChannel(ChatChannel):
     NOT_SUPPORT_REPLYTYPE = [ReplyType.VOICE]
 
@@ -23,14 +33,18 @@ class TerminalChannel(ChatChannel):
         print("\nBot:")
         if reply.type == ReplyType.IMAGE:
             from PIL import Image
+
             image_storage = reply.content
             image_storage.seek(0)
             img = Image.open(image_storage)
             print("<IMAGE>")
             img.show()
-        elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
+        elif reply.type == ReplyType.IMAGE_URL:  # 从网络下载图片
+            import io
+
+            import requests
             from PIL import Image
-            import requests,io
+
             img_url = reply.content
             pic_res = requests.get(img_url, stream=True)
             image_storage = io.BytesIO()
@@ -59,11 +73,13 @@ class TerminalChannel(ChatChannel):
                 print("\nExiting...")
                 sys.exit()
             msg_id += 1
-            trigger_prefixs = conf().get("single_chat_prefix",[""])
+            trigger_prefixs = conf().get("single_chat_prefix", [""])
             if check_prefix(prompt, trigger_prefixs) is None:
-                prompt = trigger_prefixs[0] + prompt # 给没触发的消息加上触发前缀
-                
-            context = self._compose_context(ContextType.TEXT, prompt, msg = TerminalMessage(msg_id, prompt))
+                prompt = trigger_prefixs[0] + prompt  # 给没触发的消息加上触发前缀
+
+            context = self._compose_context(
+                ContextType.TEXT, prompt, msg=TerminalMessage(msg_id, prompt)
+            )
             if context:
                 self.produce(context)
             else:

+ 80 - 47
channel/wechat/wechat_channel.py

@@ -4,40 +4,45 @@
 wechat channel
 """
 
+import io
+import json
 import os
 import threading
-import requests
-import io
 import time
-import json
+
+import requests
+
+from bridge.context import *
+from bridge.reply import *
 from channel.chat_channel import ChatChannel
 from channel.wechat.wechat_message import *
-from common.singleton import singleton
+from common.expired_dict import ExpiredDict
 from common.log import logger
+from common.singleton import singleton
+from common.time_check import time_checker
+from config import conf
 from lib import itchat
 from lib.itchat.content import *
-from bridge.reply import *
-from bridge.context import *
-from config import conf
-from common.time_check import time_checker
-from common.expired_dict import ExpiredDict
 from plugins import *
 
-@itchat.msg_register([TEXT,VOICE,PICTURE])
+
+@itchat.msg_register([TEXT, VOICE, PICTURE])
 def handler_single_msg(msg):
     # logger.debug("handler_single_msg: {}".format(msg))
-    if msg['Type'] == PICTURE and msg['MsgType'] == 47:
+    if msg["Type"] == PICTURE and msg["MsgType"] == 47:
         return None
     WechatChannel().handle_single(WeChatMessage(msg))
     return None
 
-@itchat.msg_register([TEXT,VOICE,PICTURE], isGroupChat=True)
+
+@itchat.msg_register([TEXT, VOICE, PICTURE], isGroupChat=True)
 def handler_group_msg(msg):
-    if msg['Type'] == PICTURE and msg['MsgType'] == 47:
+    if msg["Type"] == PICTURE and msg["MsgType"] == 47:
         return None
-    WechatChannel().handle_group(WeChatMessage(msg,True))
+    WechatChannel().handle_group(WeChatMessage(msg, True))
     return None
 
+
 def _check(func):
     def wrapper(self, cmsg: ChatMessage):
         msgId = cmsg.msg_id
@@ -45,21 +50,27 @@ def _check(func):
             logger.info("Wechat message {} already received, ignore".format(msgId))
             return
         self.receivedMsgs[msgId] = cmsg
-        create_time = cmsg.create_time            # 消息时间戳
-        if conf().get('hot_reload') == True and int(create_time) < int(time.time()) - 60:  # 跳过1分钟前的历史消息
+        create_time = cmsg.create_time  # 消息时间戳
+        if (
+            conf().get("hot_reload") == True
+            and int(create_time) < int(time.time()) - 60
+        ):  # 跳过1分钟前的历史消息
             logger.debug("[WX]history message {} skipped".format(msgId))
             return
         return func(self, cmsg)
+
     return wrapper
 
-#可用的二维码生成接口
-#https://api.qrserver.com/v1/create-qr-code/?size=400×400&data=https://www.abc.com
-#https://api.isoyu.com/qr/?m=1&e=L&p=20&url=https://www.abc.com
-def qrCallback(uuid,status,qrcode):
+
+# 可用的二维码生成接口
+# https://api.qrserver.com/v1/create-qr-code/?size=400×400&data=https://www.abc.com
+# https://api.isoyu.com/qr/?m=1&e=L&p=20&url=https://www.abc.com
+def qrCallback(uuid, status, qrcode):
     # logger.debug("qrCallback: {} {}".format(uuid,status))
-    if status == '0':
+    if status == "0":
         try:
             from PIL import Image
+
             img = Image.open(io.BytesIO(qrcode))
             _thread = threading.Thread(target=img.show, args=("QRCode",))
             _thread.setDaemon(True)
@@ -68,35 +79,43 @@ def qrCallback(uuid,status,qrcode):
             pass
 
         import qrcode
+
         url = f"https://login.weixin.qq.com/l/{uuid}"
-        
-        qr_api1="https://api.isoyu.com/qr/?m=1&e=L&p=20&url={}".format(url)
-        qr_api2="https://api.qrserver.com/v1/create-qr-code/?size=400×400&data={}".format(url)
-        qr_api3="https://api.pwmqr.com/qrcode/create/?url={}".format(url)
-        qr_api4="https://my.tv.sohu.com/user/a/wvideo/getQRCode.do?text={}".format(url)
+
+        qr_api1 = "https://api.isoyu.com/qr/?m=1&e=L&p=20&url={}".format(url)
+        qr_api2 = (
+            "https://api.qrserver.com/v1/create-qr-code/?size=400×400&data={}".format(
+                url
+            )
+        )
+        qr_api3 = "https://api.pwmqr.com/qrcode/create/?url={}".format(url)
+        qr_api4 = "https://my.tv.sohu.com/user/a/wvideo/getQRCode.do?text={}".format(
+            url
+        )
         print("You can also scan QRCode in any website below:")
         print(qr_api3)
         print(qr_api4)
         print(qr_api2)
         print(qr_api1)
-        
+
         qr = qrcode.QRCode(border=1)
         qr.add_data(url)
         qr.make(fit=True)
         qr.print_ascii(invert=True)
 
+
 @singleton
 class WechatChannel(ChatChannel):
     NOT_SUPPORT_REPLYTYPE = []
+
     def __init__(self):
         super().__init__()
-        self.receivedMsgs = ExpiredDict(60*60*24) 
+        self.receivedMsgs = ExpiredDict(60 * 60 * 24)
 
     def startup(self):
-
         itchat.instance.receivingRetryCount = 600  # 修改断线超时时间
         # login by scan QRCode
-        hotReload = conf().get('hot_reload', False)
+        hotReload = conf().get("hot_reload", False)
         try:
             itchat.auto_login(enableCmdQR=2, hotReload=hotReload, qrCallback=qrCallback)
         except Exception as e:
@@ -104,12 +123,18 @@ class WechatChannel(ChatChannel):
                 logger.error("Hot reload failed, try to login without hot reload")
                 itchat.logout()
                 os.remove("itchat.pkl")
-                itchat.auto_login(enableCmdQR=2, hotReload=hotReload, qrCallback=qrCallback)
+                itchat.auto_login(
+                    enableCmdQR=2, hotReload=hotReload, qrCallback=qrCallback
+                )
             else:
                 raise e
         self.user_id = itchat.instance.storageClass.userName
         self.name = itchat.instance.storageClass.nickName
-        logger.info("Wechat login success, user_id: {}, nickname: {}".format(self.user_id, self.name))
+        logger.info(
+            "Wechat login success, user_id: {}, nickname: {}".format(
+                self.user_id, self.name
+            )
+        )
         # start message listener
         itchat.run()
 
@@ -127,24 +152,30 @@ class WechatChannel(ChatChannel):
 
     @time_checker
     @_check
-    def handle_single(self, cmsg : ChatMessage):
+    def handle_single(self, cmsg: ChatMessage):
         if cmsg.ctype == ContextType.VOICE:
-            if conf().get('speech_recognition') != True:
+            if conf().get("speech_recognition") != True:
                 return
             logger.debug("[WX]receive voice msg: {}".format(cmsg.content))
         elif cmsg.ctype == ContextType.IMAGE:
             logger.debug("[WX]receive image msg: {}".format(cmsg.content))
         else:
-            logger.debug("[WX]receive text msg: {}, cmsg={}".format(json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg))
-        context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=False, msg=cmsg)
+            logger.debug(
+                "[WX]receive text msg: {}, cmsg={}".format(
+                    json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg
+                )
+            )
+        context = self._compose_context(
+            cmsg.ctype, cmsg.content, isgroup=False, msg=cmsg
+        )
         if context:
             self.produce(context)
 
     @time_checker
     @_check
-    def handle_group(self, cmsg : ChatMessage):
+    def handle_group(self, cmsg: ChatMessage):
         if cmsg.ctype == ContextType.VOICE:
-            if conf().get('speech_recognition') != True:
+            if conf().get("speech_recognition") != True:
                 return
             logger.debug("[WX]receive voice for group msg: {}".format(cmsg.content))
         elif cmsg.ctype == ContextType.IMAGE:
@@ -152,23 +183,25 @@ class WechatChannel(ChatChannel):
         else:
             # logger.debug("[WX]receive group msg: {}, cmsg={}".format(json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg))
             pass
-        context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=True, msg=cmsg)
+        context = self._compose_context(
+            cmsg.ctype, cmsg.content, isgroup=True, msg=cmsg
+        )
         if context:
             self.produce(context)
-    
+
     # 统一的发送函数,每个Channel自行实现,根据reply的type字段发送不同类型的消息
     def send(self, reply: Reply, context: Context):
         receiver = context["receiver"]
         if reply.type == ReplyType.TEXT:
             itchat.send(reply.content, toUserName=receiver)
-            logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver))
+            logger.info("[WX] sendMsg={}, receiver={}".format(reply, receiver))
         elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
             itchat.send(reply.content, toUserName=receiver)
-            logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver))
+            logger.info("[WX] sendMsg={}, receiver={}".format(reply, receiver))
         elif reply.type == ReplyType.VOICE:
             itchat.send_file(reply.content, toUserName=receiver)
-            logger.info('[WX] sendFile={}, receiver={}'.format(reply.content, receiver))
-        elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
+            logger.info("[WX] sendFile={}, receiver={}".format(reply.content, receiver))
+        elif reply.type == ReplyType.IMAGE_URL:  # 从网络下载图片
             img_url = reply.content
             pic_res = requests.get(img_url, stream=True)
             image_storage = io.BytesIO()
@@ -176,9 +209,9 @@ class WechatChannel(ChatChannel):
                 image_storage.write(block)
             image_storage.seek(0)
             itchat.send_image(image_storage, toUserName=receiver)
-            logger.info('[WX] sendImage url={}, receiver={}'.format(img_url,receiver))
-        elif reply.type == ReplyType.IMAGE: # 从文件读取图片
+            logger.info("[WX] sendImage url={}, receiver={}".format(img_url, receiver))
+        elif reply.type == ReplyType.IMAGE:  # 从文件读取图片
             image_storage = reply.content
             image_storage.seek(0)
             itchat.send_image(image_storage, toUserName=receiver)
-            logger.info('[WX] sendImage, receiver={}'.format(receiver))
+            logger.info("[WX] sendImage, receiver={}".format(receiver))

+ 28 - 28
channel/wechat/wechat_message.py

@@ -1,54 +1,54 @@
-
-
 from bridge.context import ContextType
 from channel.chat_message import ChatMessage
-from common.tmp_dir import TmpDir
 from common.log import logger
-from lib.itchat.content import *
+from common.tmp_dir import TmpDir
 from lib import itchat
+from lib.itchat.content import *
 
-class WeChatMessage(ChatMessage):
 
+class WeChatMessage(ChatMessage):
     def __init__(self, itchat_msg, is_group=False):
-        super().__init__( itchat_msg)
-        self.msg_id = itchat_msg['MsgId']
-        self.create_time = itchat_msg['CreateTime']
+        super().__init__(itchat_msg)
+        self.msg_id = itchat_msg["MsgId"]
+        self.create_time = itchat_msg["CreateTime"]
         self.is_group = is_group
-        
-        if itchat_msg['Type'] == TEXT:
+
+        if itchat_msg["Type"] == TEXT:
             self.ctype = ContextType.TEXT
-            self.content = itchat_msg['Text']
-        elif itchat_msg['Type'] == VOICE:
+            self.content = itchat_msg["Text"]
+        elif itchat_msg["Type"] == VOICE:
             self.ctype = ContextType.VOICE
-            self.content = TmpDir().path() + itchat_msg['FileName']  # content直接存临时目录路径
+            self.content = TmpDir().path() + itchat_msg["FileName"]  # content直接存临时目录路径
             self._prepare_fn = lambda: itchat_msg.download(self.content)
-        elif itchat_msg['Type'] == PICTURE and itchat_msg['MsgType'] == 3:
+        elif itchat_msg["Type"] == PICTURE and itchat_msg["MsgType"] == 3:
             self.ctype = ContextType.IMAGE
-            self.content = TmpDir().path() + itchat_msg['FileName']  # content直接存临时目录路径
+            self.content = TmpDir().path() + itchat_msg["FileName"]  # content直接存临时目录路径
             self._prepare_fn = lambda: itchat_msg.download(self.content)
         else:
-            raise NotImplementedError("Unsupported message type: {}".format(itchat_msg['Type']))
-        
-        self.from_user_id = itchat_msg['FromUserName']
-        self.to_user_id = itchat_msg['ToUserName']
-        
+            raise NotImplementedError(
+                "Unsupported message type: {}".format(itchat_msg["Type"])
+            )
+
+        self.from_user_id = itchat_msg["FromUserName"]
+        self.to_user_id = itchat_msg["ToUserName"]
+
         user_id = itchat.instance.storageClass.userName
         nickname = itchat.instance.storageClass.nickName
-        
+
         # 虽然from_user_id和to_user_id用的少,但是为了保持一致性,还是要填充一下
         # 以下很繁琐,一句话总结:能填的都填了。
         if self.from_user_id == user_id:
             self.from_user_nickname = nickname
         if self.to_user_id == user_id:
             self.to_user_nickname = nickname
-        try: # 陌生人时候, 'User'字段可能不存在
-            self.other_user_id = itchat_msg['User']['UserName']
-            self.other_user_nickname = itchat_msg['User']['NickName']
+        try:  # 陌生人时候, 'User'字段可能不存在
+            self.other_user_id = itchat_msg["User"]["UserName"]
+            self.other_user_nickname = itchat_msg["User"]["NickName"]
             if self.other_user_id == self.from_user_id:
                 self.from_user_nickname = self.other_user_nickname
             if self.other_user_id == self.to_user_id:
                 self.to_user_nickname = self.other_user_nickname
-        except KeyError as e: # 处理偶尔没有对方信息的情况
+        except KeyError as e:  # 处理偶尔没有对方信息的情况
             logger.warn("[WX]get other_user_id failed: " + str(e))
             if self.from_user_id == user_id:
                 self.other_user_id = self.to_user_id
@@ -56,6 +56,6 @@ class WeChatMessage(ChatMessage):
                 self.other_user_id = self.from_user_id
 
         if self.is_group:
-            self.is_at = itchat_msg['IsAt']
-            self.actual_user_id = itchat_msg['ActualUserName']
-            self.actual_user_nickname = itchat_msg['ActualNickName']
+            self.is_at = itchat_msg["IsAt"]
+            self.actual_user_id = itchat_msg["ActualUserName"]
+            self.actual_user_nickname = itchat_msg["ActualNickName"]

+ 54 - 40
channel/wechat/wechaty_channel.py

@@ -4,104 +4,118 @@
 wechaty channel
 Python Wechaty - https://github.com/wechaty/python-wechaty
 """
+import asyncio
 import base64
 import os
 import time
-import asyncio
-from bridge.context import Context
-from wechaty_puppet import FileBox
-from wechaty import Wechaty, Contact
+
+from wechaty import Contact, Wechaty
 from wechaty.user import Message
-from bridge.reply import *
+from wechaty_puppet import FileBox
+
 from bridge.context import *
+from bridge.context import Context
+from bridge.reply import *
 from channel.chat_channel import ChatChannel
 from channel.wechat.wechaty_message import WechatyMessage
 from common.log import logger
 from common.singleton import singleton
 from config import conf
+
 try:
     from voice.audio_convert import any_to_sil
 except Exception as e:
     pass
 
+
 @singleton
 class WechatyChannel(ChatChannel):
     NOT_SUPPORT_REPLYTYPE = []
+
     def __init__(self):
         super().__init__()
 
     def startup(self):
         config = conf()
-        token = config.get('wechaty_puppet_service_token')
-        os.environ['WECHATY_PUPPET_SERVICE_TOKEN'] = token
+        token = config.get("wechaty_puppet_service_token")
+        os.environ["WECHATY_PUPPET_SERVICE_TOKEN"] = token
         asyncio.run(self.main())
 
     async def main(self):
-        
         loop = asyncio.get_event_loop()
-        #将asyncio的loop传入处理线程
-        self.handler_pool._initializer= lambda: asyncio.set_event_loop(loop)
+        # 将asyncio的loop传入处理线程
+        self.handler_pool._initializer = lambda: asyncio.set_event_loop(loop)
         self.bot = Wechaty()
-        self.bot.on('login', self.on_login)
-        self.bot.on('message', self.on_message)
+        self.bot.on("login", self.on_login)
+        self.bot.on("message", self.on_message)
         await self.bot.start()
 
     async def on_login(self, contact: Contact):
         self.user_id = contact.contact_id
         self.name = contact.name
-        logger.info('[WX] login user={}'.format(contact))
+        logger.info("[WX] login user={}".format(contact))
 
     # 统一的发送函数,每个Channel自行实现,根据reply的type字段发送不同类型的消息
     def send(self, reply: Reply, context: Context):
-        receiver_id = context['receiver']
+        receiver_id = context["receiver"]
         loop = asyncio.get_event_loop()
-        if context['isgroup']:
-            receiver = asyncio.run_coroutine_threadsafe(self.bot.Room.find(receiver_id),loop).result()
+        if context["isgroup"]:
+            receiver = asyncio.run_coroutine_threadsafe(
+                self.bot.Room.find(receiver_id), loop
+            ).result()
         else:
-            receiver = asyncio.run_coroutine_threadsafe(self.bot.Contact.find(receiver_id),loop).result()
+            receiver = asyncio.run_coroutine_threadsafe(
+                self.bot.Contact.find(receiver_id), loop
+            ).result()
         msg = None
         if reply.type == ReplyType.TEXT:
             msg = reply.content
-            asyncio.run_coroutine_threadsafe(receiver.say(msg),loop).result()
-            logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver))
+            asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result()
+            logger.info("[WX] sendMsg={}, receiver={}".format(reply, receiver))
         elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
             msg = reply.content
-            asyncio.run_coroutine_threadsafe(receiver.say(msg),loop).result()
-            logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver))
+            asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result()
+            logger.info("[WX] sendMsg={}, receiver={}".format(reply, receiver))
         elif reply.type == ReplyType.VOICE:
             voiceLength = None
             file_path = reply.content
-            sil_file = os.path.splitext(file_path)[0] + '.sil'
+            sil_file = os.path.splitext(file_path)[0] + ".sil"
             voiceLength = int(any_to_sil(file_path, sil_file))
             if voiceLength >= 60000:
                 voiceLength = 60000
-                logger.info('[WX] voice too long, length={}, set to 60s'.format(voiceLength))
+                logger.info(
+                    "[WX] voice too long, length={}, set to 60s".format(voiceLength)
+                )
             # 发送语音
             t = int(time.time())
-            msg = FileBox.from_file(sil_file, name=str(t) + '.sil')
+            msg = FileBox.from_file(sil_file, name=str(t) + ".sil")
             if voiceLength is not None:
-                msg.metadata['voiceLength'] = voiceLength
-            asyncio.run_coroutine_threadsafe(receiver.say(msg),loop).result()
+                msg.metadata["voiceLength"] = voiceLength
+            asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result()
             try:
                 os.remove(file_path)
                 if sil_file != file_path:
                     os.remove(sil_file)
             except Exception as e:
                 pass
-            logger.info('[WX] sendVoice={}, receiver={}'.format(reply.content, receiver))
-        elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
+            logger.info(
+                "[WX] sendVoice={}, receiver={}".format(reply.content, receiver)
+            )
+        elif reply.type == ReplyType.IMAGE_URL:  # 从网络下载图片
             img_url = reply.content
             t = int(time.time())
-            msg = FileBox.from_url(url=img_url, name=str(t) + '.png')
-            asyncio.run_coroutine_threadsafe(receiver.say(msg),loop).result()
-            logger.info('[WX] sendImage url={}, receiver={}'.format(img_url,receiver))
-        elif reply.type == ReplyType.IMAGE: # 从文件读取图片
+            msg = FileBox.from_url(url=img_url, name=str(t) + ".png")
+            asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result()
+            logger.info("[WX] sendImage url={}, receiver={}".format(img_url, receiver))
+        elif reply.type == ReplyType.IMAGE:  # 从文件读取图片
             image_storage = reply.content
             image_storage.seek(0)
             t = int(time.time())
-            msg = FileBox.from_base64(base64.b64encode(image_storage.read()), str(t) + '.png')
-            asyncio.run_coroutine_threadsafe(receiver.say(msg),loop).result()
-            logger.info('[WX] sendImage, receiver={}'.format(receiver))
+            msg = FileBox.from_base64(
+                base64.b64encode(image_storage.read()), str(t) + ".png"
+            )
+            asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result()
+            logger.info("[WX] sendImage, receiver={}".format(receiver))
 
     async def on_message(self, msg: Message):
         """
@@ -110,16 +124,16 @@ class WechatyChannel(ChatChannel):
         try:
             cmsg = await WechatyMessage(msg)
         except NotImplementedError as e:
-            logger.debug('[WX] {}'.format(e))
+            logger.debug("[WX] {}".format(e))
             return
         except Exception as e:
-            logger.exception('[WX] {}'.format(e))
+            logger.exception("[WX] {}".format(e))
             return
-        logger.debug('[WX] message:{}'.format(cmsg))
+        logger.debug("[WX] message:{}".format(cmsg))
         room = msg.room()  # 获取消息来自的群聊. 如果消息不是来自群聊, 则返回None
         isgroup = room is not None
         ctype = cmsg.ctype
         context = self._compose_context(ctype, cmsg.content, isgroup=isgroup, msg=cmsg)
         if context:
-            logger.info('[WX] receiveMsg={}, context={}'.format(cmsg, context))
-            self.produce(context)
+            logger.info("[WX] receiveMsg={}, context={}".format(cmsg, context))
+            self.produce(context)

+ 28 - 18
channel/wechat/wechaty_message.py

@@ -1,17 +1,21 @@
 import asyncio
 import re
+
 from wechaty import MessageType
+from wechaty.user import Message
+
 from bridge.context import ContextType
 from channel.chat_message import ChatMessage
-from common.tmp_dir import TmpDir
 from common.log import logger
-from wechaty.user import Message
+from common.tmp_dir import TmpDir
+
 
 class aobject(object):
     """Inheriting this class allows you to define an async __init__.
 
     So you can create objects by doing something like `await MyClass(params)`
     """
+
     async def __new__(cls, *a, **kw):
         instance = super().__new__(cls)
         await instance.__init__(*a, **kw)
@@ -19,17 +23,18 @@ class aobject(object):
 
     async def __init__(self):
         pass
-class WechatyMessage(ChatMessage, aobject):
 
+
+class WechatyMessage(ChatMessage, aobject):
     async def __init__(self, wechaty_msg: Message):
         super().__init__(wechaty_msg)
-        
+
         room = wechaty_msg.room()
 
         self.msg_id = wechaty_msg.message_id
         self.create_time = wechaty_msg.payload.timestamp
         self.is_group = room is not None
-        
+
         if wechaty_msg.type() == MessageType.MESSAGE_TYPE_TEXT:
             self.ctype = ContextType.TEXT
             self.content = wechaty_msg.text()
@@ -40,12 +45,17 @@ class WechatyMessage(ChatMessage, aobject):
 
             def func():
                 loop = asyncio.get_event_loop()
-                asyncio.run_coroutine_threadsafe(voice_file.to_file(self.content),loop).result()
+                asyncio.run_coroutine_threadsafe(
+                    voice_file.to_file(self.content), loop
+                ).result()
+
             self._prepare_fn = func
-            
+
         else:
-            raise NotImplementedError("Unsupported message type: {}".format(wechaty_msg.type()))
-        
+            raise NotImplementedError(
+                "Unsupported message type: {}".format(wechaty_msg.type())
+            )
+
         from_contact = wechaty_msg.talker()  # 获取消息的发送者
         self.from_user_id = from_contact.contact_id
         self.from_user_nickname = from_contact.name
@@ -54,7 +64,7 @@ class WechatyMessage(ChatMessage, aobject):
         # wecahty: from是消息实际发送者, to:所在群
         # itchat: 如果是你发送群消息,from和to是你自己和所在群,如果是别人发群消息,from和to是所在群和你自己
         # 但这个差别不影响逻辑,group中只使用到:1.用from来判断是否是自己发的,2.actual_user_id来判断实际发送用户
-        
+
         if self.is_group:
             self.to_user_id = room.room_id
             self.to_user_nickname = await room.topic()
@@ -63,22 +73,22 @@ class WechatyMessage(ChatMessage, aobject):
             self.to_user_id = to_contact.contact_id
             self.to_user_nickname = to_contact.name
 
-        if self.is_group or wechaty_msg.is_self(): # 如果是群消息,other_user设置为群,如果是私聊消息,而且自己发的,就设置成对方。
+        if (
+            self.is_group or wechaty_msg.is_self()
+        ):  # 如果是群消息,other_user设置为群,如果是私聊消息,而且自己发的,就设置成对方。
             self.other_user_id = self.to_user_id
             self.other_user_nickname = self.to_user_nickname
         else:
             self.other_user_id = self.from_user_id
             self.other_user_nickname = self.from_user_nickname
 
-        
-
-        if self.is_group: # wechaty群聊中,实际发送用户就是from_user
+        if self.is_group:  # wechaty群聊中,实际发送用户就是from_user
             self.is_at = await wechaty_msg.mention_self()
-            if not self.is_at: # 有时候复制粘贴的消息,不算做@,但是内容里面会有@xxx,这里做一下兼容
+            if not self.is_at:  # 有时候复制粘贴的消息,不算做@,但是内容里面会有@xxx,这里做一下兼容
                 name = wechaty_msg.wechaty.user_self().name
-                pattern = f'@{name}(\u2005|\u0020)'
-                if re.search(pattern,self.content):
-                    logger.debug(f'wechaty message {self.msg_id} include at')
+                pattern = f"@{name}(\u2005|\u0020)"
+                if re.search(pattern, self.content):
+                    logger.debug(f"wechaty message {self.msg_id} include at")
                     self.is_at = True
 
             self.actual_user_id = self.from_user_id

+ 2 - 2
channel/wechatmp/README.md

@@ -21,12 +21,12 @@ pip3 install web.py
 
 相关的服务器验证代码已经写好,你不需要再添加任何代码。你只需要在本项目根目录的`config.json`中添加
 ```
-"channel_type": "wechatmp", 
+"channel_type": "wechatmp",
 "wechatmp_token": "Token",  # 微信公众平台的Token
 "wechatmp_port": 8080,      # 微信公众平台的端口,需要端口转发到80或443
 "wechatmp_app_id": "",      # 微信公众平台的appID,仅服务号需要
 "wechatmp_app_secret": "",  # 微信公众平台的appsecret,仅服务号需要
-``` 
+```
 然后运行`python3 app.py`启动web服务器。这里会默认监听8080端口,但是微信公众号的服务器配置只支持80/443端口,有两种方法来解决这个问题。第一个是推荐的方法,使用端口转发命令将80端口转发到8080端口(443同理,注意需要支持SSL,也就是https的访问,在`wechatmp_channel.py`需要修改相应的证书路径):
 ```
 sudo iptables -t nat -A PREROUTING -p tcp --dport 80 -j REDIRECT --to-port 8080

+ 35 - 16
channel/wechatmp/ServiceAccount.py

@@ -1,46 +1,66 @@
-import web
 import time
-import channel.wechatmp.reply as reply
+
+import web
+
 import channel.wechatmp.receive as receive
-from config import conf
-from common.log import logger
+import channel.wechatmp.reply as reply
 from bridge.context import *
-from channel.wechatmp.common import * 
+from channel.wechatmp.common import *
 from channel.wechatmp.wechatmp_channel import WechatMPChannel
+from common.log import logger
+from config import conf
 
-# This class is instantiated once per query
-class Query():
 
+# This class is instantiated once per query
+class Query:
     def GET(self):
         return verify_server(web.input())
 
     def POST(self):
-        # Make sure to return the instance that first created, @singleton will do that. 
+        # Make sure to return the instance that first created, @singleton will do that.
         channel = WechatMPChannel()
         try:
             webData = web.data()
             # logger.debug("[wechatmp] Receive request:\n" + webData.decode("utf-8"))
             wechatmp_msg = receive.parse_xml(webData)
-            if wechatmp_msg.msg_type == 'text' or wechatmp_msg.msg_type == 'voice':
+            if wechatmp_msg.msg_type == "text" or wechatmp_msg.msg_type == "voice":
                 from_user = wechatmp_msg.from_user_id
                 message = wechatmp_msg.content.decode("utf-8")
                 message_id = wechatmp_msg.msg_id
 
-                logger.info("[wechatmp] {}:{} Receive post query {} {}: {}".format(web.ctx.env.get('REMOTE_ADDR'), web.ctx.env.get('REMOTE_PORT'), from_user, message_id, message))
-                context = channel._compose_context(ContextType.TEXT, message, isgroup=False, msg=wechatmp_msg)
+                logger.info(
+                    "[wechatmp] {}:{} Receive post query {} {}: {}".format(
+                        web.ctx.env.get("REMOTE_ADDR"),
+                        web.ctx.env.get("REMOTE_PORT"),
+                        from_user,
+                        message_id,
+                        message,
+                    )
+                )
+                context = channel._compose_context(
+                    ContextType.TEXT, message, isgroup=False, msg=wechatmp_msg
+                )
                 if context:
                     # set private openai_api_key
                     # if from_user is not changed in itchat, this can be placed at chat_channel
                     user_data = conf().get_user_data(from_user)
-                    context['openai_api_key'] = user_data.get('openai_api_key') # None or user openai_api_key
+                    context["openai_api_key"] = user_data.get(
+                        "openai_api_key"
+                    )  # None or user openai_api_key
                     channel.produce(context)
                 # The reply will be sent by channel.send() in another thread
                 return "success"
 
-            elif wechatmp_msg.msg_type == 'event':
-                logger.info("[wechatmp] Event {} from {}".format(wechatmp_msg.Event, wechatmp_msg.from_user_id))
+            elif wechatmp_msg.msg_type == "event":
+                logger.info(
+                    "[wechatmp] Event {} from {}".format(
+                        wechatmp_msg.Event, wechatmp_msg.from_user_id
+                    )
+                )
                 content = subscribe_msg()
-                replyMsg = reply.TextMsg(wechatmp_msg.from_user_id, wechatmp_msg.to_user_id, content)
+                replyMsg = reply.TextMsg(
+                    wechatmp_msg.from_user_id, wechatmp_msg.to_user_id, content
+                )
                 return replyMsg.send()
             else:
                 logger.info("暂且不处理")
@@ -48,4 +68,3 @@ class Query():
         except Exception as exc:
             logger.exception(exc)
             return exc
-

+ 98 - 38
channel/wechatmp/SubscribeAccount.py

@@ -1,81 +1,117 @@
-import web
 import time
-import channel.wechatmp.reply as reply
+
+import web
+
 import channel.wechatmp.receive as receive
-from config import conf
-from common.log import logger
+import channel.wechatmp.reply as reply
 from bridge.context import *
-from channel.wechatmp.common import * 
+from channel.wechatmp.common import *
 from channel.wechatmp.wechatmp_channel import WechatMPChannel
+from common.log import logger
+from config import conf
 
-# This class is instantiated once per query
-class Query():
 
+# This class is instantiated once per query
+class Query:
     def GET(self):
         return verify_server(web.input())
 
     def POST(self):
-        # Make sure to return the instance that first created, @singleton will do that. 
+        # Make sure to return the instance that first created, @singleton will do that.
         channel = WechatMPChannel()
         try:
             query_time = time.time()
             webData = web.data()
             logger.debug("[wechatmp] Receive request:\n" + webData.decode("utf-8"))
             wechatmp_msg = receive.parse_xml(webData)
-            if wechatmp_msg.msg_type == 'text' or wechatmp_msg.msg_type == 'voice':
+            if wechatmp_msg.msg_type == "text" or wechatmp_msg.msg_type == "voice":
                 from_user = wechatmp_msg.from_user_id
                 to_user = wechatmp_msg.to_user_id
                 message = wechatmp_msg.content.decode("utf-8")
                 message_id = wechatmp_msg.msg_id
 
-                logger.info("[wechatmp] {}:{} Receive post query {} {}: {}".format(web.ctx.env.get('REMOTE_ADDR'), web.ctx.env.get('REMOTE_PORT'), from_user, message_id, message))
+                logger.info(
+                    "[wechatmp] {}:{} Receive post query {} {}: {}".format(
+                        web.ctx.env.get("REMOTE_ADDR"),
+                        web.ctx.env.get("REMOTE_PORT"),
+                        from_user,
+                        message_id,
+                        message,
+                    )
+                )
                 supported = True
                 if "【收到不支持的消息类型,暂无法显示】" in message:
-                    supported = False # not supported, used to refresh
+                    supported = False  # not supported, used to refresh
                 cache_key = from_user
 
                 reply_text = ""
                 # New request
-                if cache_key not in channel.cache_dict and cache_key not in channel.running:
+                if (
+                    cache_key not in channel.cache_dict
+                    and cache_key not in channel.running
+                ):
                     # The first query begin, reset the cache
-                    context = channel._compose_context(ContextType.TEXT, message, isgroup=False, msg=wechatmp_msg)
-                    logger.debug("[wechatmp] context: {} {}".format(context, wechatmp_msg))
-                    if message_id in channel.received_msgs: # received and finished
+                    context = channel._compose_context(
+                        ContextType.TEXT, message, isgroup=False, msg=wechatmp_msg
+                    )
+                    logger.debug(
+                        "[wechatmp] context: {} {}".format(context, wechatmp_msg)
+                    )
+                    if message_id in channel.received_msgs:  # received and finished
                         # no return because of bandwords or other reasons
                         return "success"
                     if supported and context:
                         # set private openai_api_key
                         # if from_user is not changed in itchat, this can be placed at chat_channel
                         user_data = conf().get_user_data(from_user)
-                        context['openai_api_key'] = user_data.get('openai_api_key') # None or user openai_api_key
+                        context["openai_api_key"] = user_data.get(
+                            "openai_api_key"
+                        )  # None or user openai_api_key
                         channel.received_msgs[message_id] = wechatmp_msg
                         channel.running.add(cache_key)
                         channel.produce(context)
                     else:
-                        trigger_prefix = conf().get('single_chat_prefix',[''])[0]
+                        trigger_prefix = conf().get("single_chat_prefix", [""])[0]
                         if trigger_prefix or not supported:
                             if trigger_prefix:
-                                content = textwrap.dedent(f"""\
+                                content = textwrap.dedent(
+                                    f"""\
                                     请输入'{trigger_prefix}'接你想说的话跟我说话。
                                     例如:
-                                    {trigger_prefix}你好,很高兴见到你。""")
+                                    {trigger_prefix}你好,很高兴见到你。"""
+                                )
                             else:
-                                content = textwrap.dedent("""\
+                                content = textwrap.dedent(
+                                    """\
                                     你好,很高兴见到你。
-                                    请跟我说话吧。""")
+                                    请跟我说话吧。"""
+                                )
                         else:
                             logger.error(f"[wechatmp] unknown error")
-                            content = textwrap.dedent("""\
-                                未知错误,请稍后再试""")
-                        replyMsg = reply.TextMsg(wechatmp_msg.from_user_id, wechatmp_msg.to_user_id, content)
+                            content = textwrap.dedent(
+                                """\
+                                未知错误,请稍后再试"""
+                            )
+                        replyMsg = reply.TextMsg(
+                            wechatmp_msg.from_user_id, wechatmp_msg.to_user_id, content
+                        )
                         return replyMsg.send()
                     channel.query1[cache_key] = False
                     channel.query2[cache_key] = False
                     channel.query3[cache_key] = False
                 # User request again, and the answer is not ready
-                elif cache_key in channel.running and channel.query1.get(cache_key) == True and channel.query2.get(cache_key) == True and channel.query3.get(cache_key) == True:
-                    channel.query1[cache_key] = False  #To improve waiting experience, this can be set to True.
-                    channel.query2[cache_key] = False  #To improve waiting experience, this can be set to True.
+                elif (
+                    cache_key in channel.running
+                    and channel.query1.get(cache_key) == True
+                    and channel.query2.get(cache_key) == True
+                    and channel.query3.get(cache_key) == True
+                ):
+                    channel.query1[
+                        cache_key
+                    ] = False  # To improve waiting experience, this can be set to True.
+                    channel.query2[
+                        cache_key
+                    ] = False  # To improve waiting experience, this can be set to True.
                     channel.query3[cache_key] = False
                 # User request again, and the answer is ready
                 elif cache_key in channel.cache_dict:
@@ -84,7 +120,9 @@ class Query():
                     channel.query2[cache_key] = True
                     channel.query3[cache_key] = True
 
-                assert not (cache_key in channel.cache_dict and cache_key in channel.running)
+                assert not (
+                    cache_key in channel.cache_dict and cache_key in channel.running
+                )
 
                 if channel.query1.get(cache_key) == False:
                     # The first query from wechat official server
@@ -128,14 +166,20 @@ class Query():
                         # Have waiting for 3x5 seconds
                         # return timeout message
                         reply_text = "【正在思考中,回复任意文字尝试获取回复】"
-                        logger.info("[wechatmp] Three queries has finished For {}: {}".format(from_user, message_id))
+                        logger.info(
+                            "[wechatmp] Three queries has finished For {}: {}".format(
+                                from_user, message_id
+                            )
+                        )
                         replyPost = reply.TextMsg(from_user, to_user, reply_text).send()
                         return replyPost
                     else:
                         pass
 
-
-                if cache_key not in channel.cache_dict and cache_key not in channel.running:
+                if (
+                    cache_key not in channel.cache_dict
+                    and cache_key not in channel.running
+                ):
                     # no return because of bandwords or other reasons
                     return "success"
 
@@ -147,26 +191,42 @@ class Query():
 
                 if cache_key in channel.cache_dict:
                     content = channel.cache_dict[cache_key]
-                    if len(content.encode('utf8'))<=MAX_UTF8_LEN:
+                    if len(content.encode("utf8")) <= MAX_UTF8_LEN:
                         reply_text = channel.cache_dict[cache_key]
                         channel.cache_dict.pop(cache_key)
                     else:
                         continue_text = "\n【未完待续,回复任意文字以继续】"
-                        splits = split_string_by_utf8_length(content, MAX_UTF8_LEN - len(continue_text.encode('utf-8')), max_split= 1)
+                        splits = split_string_by_utf8_length(
+                            content,
+                            MAX_UTF8_LEN - len(continue_text.encode("utf-8")),
+                            max_split=1,
+                        )
                         reply_text = splits[0] + continue_text
                         channel.cache_dict[cache_key] = splits[1]
-                logger.info("[wechatmp] {}:{} Do send {}".format(web.ctx.env.get('REMOTE_ADDR'), web.ctx.env.get('REMOTE_PORT'), reply_text))
+                logger.info(
+                    "[wechatmp] {}:{} Do send {}".format(
+                        web.ctx.env.get("REMOTE_ADDR"),
+                        web.ctx.env.get("REMOTE_PORT"),
+                        reply_text,
+                    )
+                )
                 replyPost = reply.TextMsg(from_user, to_user, reply_text).send()
                 return replyPost
 
-            elif wechatmp_msg.msg_type == 'event':
-                logger.info("[wechatmp] Event {} from {}".format(wechatmp_msg.content, wechatmp_msg.from_user_id))
+            elif wechatmp_msg.msg_type == "event":
+                logger.info(
+                    "[wechatmp] Event {} from {}".format(
+                        wechatmp_msg.content, wechatmp_msg.from_user_id
+                    )
+                )
                 content = subscribe_msg()
-                replyMsg = reply.TextMsg(wechatmp_msg.from_user_id, wechatmp_msg.to_user_id, content)
+                replyMsg = reply.TextMsg(
+                    wechatmp_msg.from_user_id, wechatmp_msg.to_user_id, content
+                )
                 return replyMsg.send()
             else:
                 logger.info("暂且不处理")
                 return "success"
         except Exception as exc:
             logger.exception(exc)
-            return exc
+            return exc

+ 15 - 10
channel/wechatmp/common.py

@@ -1,9 +1,11 @@
-from config import conf
 import hashlib
 import textwrap
 
+from config import conf
+
 MAX_UTF8_LEN = 2048
 
+
 class WeChatAPIException(Exception):
     pass
 
@@ -16,13 +18,13 @@ def verify_server(data):
         timestamp = data.timestamp
         nonce = data.nonce
         echostr = data.echostr
-        token = conf().get('wechatmp_token') #请按照公众平台官网\基本配置中信息填写
+        token = conf().get("wechatmp_token")  # 请按照公众平台官网\基本配置中信息填写
 
         data_list = [token, timestamp, nonce]
         data_list.sort()
         sha1 = hashlib.sha1()
         # map(sha1.update, data_list) #python2
-        sha1.update("".join(data_list).encode('utf-8'))
+        sha1.update("".join(data_list).encode("utf-8"))
         hashcode = sha1.hexdigest()
         print("handle/GET func: hashcode, signature: ", hashcode, signature)
         if hashcode == signature:
@@ -32,9 +34,11 @@ def verify_server(data):
     except Exception as Argument:
         return Argument
 
+
 def subscribe_msg():
-    trigger_prefix = conf().get('single_chat_prefix',[''])[0]
-    msg = textwrap.dedent(f"""\
+    trigger_prefix = conf().get("single_chat_prefix", [""])[0]
+    msg = textwrap.dedent(
+        f"""\
                     感谢您的关注!
                     这里是ChatGPT,可以自由对话。
                     资源有限,回复较慢,请勿着急。
@@ -42,22 +46,23 @@ def subscribe_msg():
                     暂时不支持图片输入。
                     支持图片输出,画字开头的问题将回复图片链接。
                     支持角色扮演和文字冒险两种定制模式对话。
-                    输入'{trigger_prefix}#帮助' 查看详细指令。""")
+                    输入'{trigger_prefix}#帮助' 查看详细指令。"""
+    )
     return msg
 
 
 def split_string_by_utf8_length(string, max_length, max_split=0):
-    encoded = string.encode('utf-8')
+    encoded = string.encode("utf-8")
     start, end = 0, 0
     result = []
     while end < len(encoded):
         if max_split > 0 and len(result) >= max_split:
-            result.append(encoded[start:].decode('utf-8'))
+            result.append(encoded[start:].decode("utf-8"))
             break
         end = start + max_length
         # 如果当前字节不是 UTF-8 编码的开始字节,则向前查找直到找到开始字节为止
         while end < len(encoded) and (encoded[end] & 0b11000000) == 0b10000000:
             end -= 1
-        result.append(encoded[start:end].decode('utf-8'))
+        result.append(encoded[start:end].decode("utf-8"))
         start = end
-    return result
+    return result

+ 20 - 18
channel/wechatmp/receive.py

@@ -1,6 +1,7 @@
 # -*- coding: utf-8 -*-#
 # filename: receive.py
 import xml.etree.ElementTree as ET
+
 from bridge.context import ContextType
 from channel.chat_message import ChatMessage
 from common.log import logger
@@ -12,34 +13,35 @@ def parse_xml(web_data):
     xmlData = ET.fromstring(web_data)
     return WeChatMPMessage(xmlData)
 
+
 class WeChatMPMessage(ChatMessage):
     def __init__(self, xmlData):
         super().__init__(xmlData)
-        self.to_user_id = xmlData.find('ToUserName').text
-        self.from_user_id = xmlData.find('FromUserName').text
-        self.create_time = xmlData.find('CreateTime').text
-        self.msg_type = xmlData.find('MsgType').text
+        self.to_user_id = xmlData.find("ToUserName").text
+        self.from_user_id = xmlData.find("FromUserName").text
+        self.create_time = xmlData.find("CreateTime").text
+        self.msg_type = xmlData.find("MsgType").text
         try:
-            self.msg_id = xmlData.find('MsgId').text
+            self.msg_id = xmlData.find("MsgId").text
         except:
-            self.msg_id = self.from_user_id+self.create_time
+            self.msg_id = self.from_user_id + self.create_time
         self.is_group = False
-        
+
         # reply to other_user_id
         self.other_user_id = self.from_user_id
 
-        if self.msg_type == 'text':
+        if self.msg_type == "text":
             self.ctype = ContextType.TEXT
-            self.content = xmlData.find('Content').text.encode("utf-8")
-        elif self.msg_type == 'voice':
+            self.content = xmlData.find("Content").text.encode("utf-8")
+        elif self.msg_type == "voice":
             self.ctype = ContextType.TEXT
-            self.content = xmlData.find('Recognition').text.encode("utf-8")  # 接收语音识别结果
-        elif self.msg_type == 'image':
+            self.content = xmlData.find("Recognition").text.encode("utf-8")  # 接收语音识别结果
+        elif self.msg_type == "image":
             # not implemented
-            self.pic_url = xmlData.find('PicUrl').text
-            self.media_id = xmlData.find('MediaId').text
-        elif self.msg_type == 'event':
-            self.content = xmlData.find('Event').text
-        else: # video, shortvideo, location, link
+            self.pic_url = xmlData.find("PicUrl").text
+            self.media_id = xmlData.find("MediaId").text
+        elif self.msg_type == "event":
+            self.content = xmlData.find("Event").text
+        else:  # video, shortvideo, location, link
             # not implemented
-            pass
+            pass

+ 12 - 9
channel/wechatmp/reply.py

@@ -2,6 +2,7 @@
 # filename: reply.py
 import time
 
+
 class Msg(object):
     def __init__(self):
         pass
@@ -9,13 +10,14 @@ class Msg(object):
     def send(self):
         return "success"
 
+
 class TextMsg(Msg):
     def __init__(self, toUserName, fromUserName, content):
         self.__dict = dict()
-        self.__dict['ToUserName'] = toUserName
-        self.__dict['FromUserName'] = fromUserName
-        self.__dict['CreateTime'] = int(time.time())
-        self.__dict['Content'] = content
+        self.__dict["ToUserName"] = toUserName
+        self.__dict["FromUserName"] = fromUserName
+        self.__dict["CreateTime"] = int(time.time())
+        self.__dict["Content"] = content
 
     def send(self):
         XmlForm = """
@@ -29,13 +31,14 @@ class TextMsg(Msg):
             """
         return XmlForm.format(**self.__dict)
 
+
 class ImageMsg(Msg):
     def __init__(self, toUserName, fromUserName, mediaId):
         self.__dict = dict()
-        self.__dict['ToUserName'] = toUserName
-        self.__dict['FromUserName'] = fromUserName
-        self.__dict['CreateTime'] = int(time.time())
-        self.__dict['MediaId'] = mediaId
+        self.__dict["ToUserName"] = toUserName
+        self.__dict["FromUserName"] = fromUserName
+        self.__dict["CreateTime"] = int(time.time())
+        self.__dict["MediaId"] = mediaId
 
     def send(self):
         XmlForm = """
@@ -49,4 +52,4 @@ class ImageMsg(Msg):
                 </Image>
             </xml>
             """
-        return XmlForm.format(**self.__dict)
+        return XmlForm.format(**self.__dict)

+ 47 - 38
channel/wechatmp/wechatmp_channel.py

@@ -1,17 +1,19 @@
 # -*- coding: utf-8 -*-
-import web
-import time
 import json
-import requests
 import threading
-from common.singleton import singleton
-from common.log import logger
-from common.expired_dict import ExpiredDict
-from config import conf
-from bridge.reply import *
+import time
+
+import requests
+import web
+
 from bridge.context import *
+from bridge.reply import *
 from channel.chat_channel import ChatChannel
-from channel.wechatmp.common import * 
+from channel.wechatmp.common import *
+from common.expired_dict import ExpiredDict
+from common.log import logger
+from common.singleton import singleton
+from config import conf
 
 # If using SSL, uncomment the following lines, and modify the certificate path.
 # from cheroot.server import HTTPServer
@@ -20,13 +22,14 @@ from channel.wechatmp.common import *
 #         certificate='/ssl/cert.pem',
 #         private_key='/ssl/cert.key')
 
+
 @singleton
 class WechatMPChannel(ChatChannel):
-    def __init__(self, passive_reply = True):
+    def __init__(self, passive_reply=True):
         super().__init__()
         self.passive_reply = passive_reply
         self.running = set()
-        self.received_msgs = ExpiredDict(60*60*24)
+        self.received_msgs = ExpiredDict(60 * 60 * 24)
         if self.passive_reply:
             self.NOT_SUPPORT_REPLYTYPE = [ReplyType.IMAGE, ReplyType.VOICE]
             self.cache_dict = dict()
@@ -36,8 +39,8 @@ class WechatMPChannel(ChatChannel):
         else:
             # TODO support image
             self.NOT_SUPPORT_REPLYTYPE = [ReplyType.IMAGE, ReplyType.VOICE]
-            self.app_id = conf().get('wechatmp_app_id')
-            self.app_secret = conf().get('wechatmp_app_secret')
+            self.app_id = conf().get("wechatmp_app_id")
+            self.app_secret = conf().get("wechatmp_app_secret")
             self.access_token = None
             self.access_token_expires_time = 0
             self.access_token_lock = threading.Lock()
@@ -45,13 +48,12 @@ class WechatMPChannel(ChatChannel):
 
     def startup(self):
         if self.passive_reply:
-            urls = ('/wx', 'channel.wechatmp.SubscribeAccount.Query')
+            urls = ("/wx", "channel.wechatmp.SubscribeAccount.Query")
         else:
-            urls = ('/wx', 'channel.wechatmp.ServiceAccount.Query')
+            urls = ("/wx", "channel.wechatmp.ServiceAccount.Query")
         app = web.application(urls, globals(), autoreload=False)
-        port = conf().get('wechatmp_port', 8080)
-        web.httpserver.runsimple(app.wsgifunc(), ('0.0.0.0', port))
-
+        port = conf().get("wechatmp_port", 8080)
+        web.httpserver.runsimple(app.wsgifunc(), ("0.0.0.0", port))
 
     def wechatmp_request(self, method, url, **kwargs):
         r = requests.request(method=method, url=url, **kwargs)
@@ -63,7 +65,6 @@ class WechatMPChannel(ChatChannel):
         return ret
 
     def get_access_token(self):
-
         # return the access_token
         if self.access_token:
             if self.access_token_expires_time - time.time() > 60:
@@ -76,15 +77,15 @@ class WechatMPChannel(ChatChannel):
             # This happens every 2 hours, so it doesn't affect the experience very much
             time.sleep(1)
             self.access_token = None
-            url="https://api.weixin.qq.com/cgi-bin/token"
-            params={
+            url = "https://api.weixin.qq.com/cgi-bin/token"
+            params = {
                 "grant_type": "client_credential",
                 "appid": self.app_id,
-                "secret": self.app_secret
+                "secret": self.app_secret,
             }
-            data = self.wechatmp_request(method='get', url=url, params=params)
-            self.access_token = data['access_token']
-            self.access_token_expires_time = int(time.time()) + data['expires_in']
+            data = self.wechatmp_request(method="get", url=url, params=params)
+            self.access_token = data["access_token"]
+            self.access_token_expires_time = int(time.time()) + data["expires_in"]
             logger.info("[wechatmp] access_token: {}".format(self.access_token))
             self.access_token_lock.release()
         else:
@@ -101,29 +102,37 @@ class WechatMPChannel(ChatChannel):
         else:
             receiver = context["receiver"]
             reply_text = reply.content
-            url="https://api.weixin.qq.com/cgi-bin/message/custom/send"
-            params = {
-                "access_token": self.get_access_token()
-            }
+            url = "https://api.weixin.qq.com/cgi-bin/message/custom/send"
+            params = {"access_token": self.get_access_token()}
             json_data = {
                 "touser": receiver,
                 "msgtype": "text",
-                "text": {"content": reply_text}
+                "text": {"content": reply_text},
             }
-            self.wechatmp_request(method='post', url=url, params=params, data=json.dumps(json_data, ensure_ascii=False).encode('utf8'))
+            self.wechatmp_request(
+                method="post",
+                url=url,
+                params=params,
+                data=json.dumps(json_data, ensure_ascii=False).encode("utf8"),
+            )
             logger.info("[send] Do send to {}: {}".format(receiver, reply_text))
         return
 
-
-    def _success_callback(self, session_id, context, **kwargs): # 线程异常结束时的回调函数
-        logger.debug("[wechatmp] Success to generate reply, msgId={}".format(context['msg'].msg_id))
+    def _success_callback(self, session_id, context, **kwargs):  # 线程异常结束时的回调函数
+        logger.debug(
+            "[wechatmp] Success to generate reply, msgId={}".format(
+                context["msg"].msg_id
+            )
+        )
         if self.passive_reply:
             self.running.remove(session_id)
 
-
-    def _fail_callback(self, session_id, exception, context, **kwargs): # 线程异常结束时的回调函数
-        logger.exception("[wechatmp] Fail to generate reply to user, msgId={}, exception={}".format(context['msg'].msg_id, exception))
+    def _fail_callback(self, session_id, exception, context, **kwargs):  # 线程异常结束时的回调函数
+        logger.exception(
+            "[wechatmp] Fail to generate reply to user, msgId={}, exception={}".format(
+                context["msg"].msg_id, exception
+            )
+        )
         if self.passive_reply:
             assert session_id not in self.cache_dict
             self.running.remove(session_id)
-

+ 1 - 1
common/const.py

@@ -2,4 +2,4 @@
 OPEN_AI = "openAI"
 CHATGPT = "chatGPT"
 BAIDU = "baidu"
-CHATGPTONAZURE = "chatGPTOnAzure"
+CHATGPTONAZURE = "chatGPTOnAzure"

+ 2 - 2
common/dequeue.py

@@ -1,7 +1,7 @@
-
 from queue import Full, Queue
 from time import monotonic as time
 
+
 # add implementation of putleft to Queue
 class Dequeue(Queue):
     def putleft(self, item, block=True, timeout=None):
@@ -30,4 +30,4 @@ class Dequeue(Queue):
         return self.putleft(item, block=False)
 
     def _putleft(self, item):
-        self.queue.appendleft(item)
+        self.queue.appendleft(item)

+ 1 - 1
common/expired_dict.py

@@ -39,4 +39,4 @@ class ExpiredDict(dict):
         return [(key, self[key]) for key in self.keys()]
 
     def __iter__(self):
-        return self.keys().__iter__()
+        return self.keys().__iter__()

+ 16 - 7
common/log.py

@@ -10,20 +10,29 @@ def _reset_logger(log):
     log.handlers.clear()
     log.propagate = False
     console_handle = logging.StreamHandler(sys.stdout)
-    console_handle.setFormatter(logging.Formatter('[%(levelname)s][%(asctime)s][%(filename)s:%(lineno)d] - %(message)s',
-                                                  datefmt='%Y-%m-%d %H:%M:%S'))
-    file_handle = logging.FileHandler('run.log', encoding='utf-8')
-    file_handle.setFormatter(logging.Formatter('[%(levelname)s][%(asctime)s][%(filename)s:%(lineno)d] - %(message)s',
-                                                  datefmt='%Y-%m-%d %H:%M:%S'))
+    console_handle.setFormatter(
+        logging.Formatter(
+            "[%(levelname)s][%(asctime)s][%(filename)s:%(lineno)d] - %(message)s",
+            datefmt="%Y-%m-%d %H:%M:%S",
+        )
+    )
+    file_handle = logging.FileHandler("run.log", encoding="utf-8")
+    file_handle.setFormatter(
+        logging.Formatter(
+            "[%(levelname)s][%(asctime)s][%(filename)s:%(lineno)d] - %(message)s",
+            datefmt="%Y-%m-%d %H:%M:%S",
+        )
+    )
     log.addHandler(file_handle)
     log.addHandler(console_handle)
 
+
 def _get_logger():
-    log = logging.getLogger('log')
+    log = logging.getLogger("log")
     _reset_logger(log)
     log.setLevel(logging.INFO)
     return log
 
 
 # 日志句柄
-logger = _get_logger()
+logger = _get_logger()

+ 11 - 5
common/package_manager.py

@@ -1,15 +1,20 @@
 import time
+
 import pip
 from pip._internal import main as pipmain
-from common.log import logger,_reset_logger
+
+from common.log import _reset_logger, logger
+
 
 def install(package):
-    pipmain(['install', package])
+    pipmain(["install", package])
+
 
 def install_requirements(file):
-    pipmain(['install', '-r', file, "--upgrade"])
+    pipmain(["install", "-r", file, "--upgrade"])
     _reset_logger(logger)
 
+
 def check_dulwich():
     needwait = False
     for i in range(2):
@@ -18,13 +23,14 @@ def check_dulwich():
             needwait = False
         try:
             import dulwich
+
             return
         except ImportError:
             try:
-                install('dulwich')
+                install("dulwich")
             except:
                 needwait = True
     try:
         import dulwich
     except ImportError:
-        raise ImportError("Unable to import dulwich")
+        raise ImportError("Unable to import dulwich")

+ 1 - 1
common/sorted_dict.py

@@ -62,4 +62,4 @@ class SortedDict(dict):
         return iter(self.keys())
 
     def __repr__(self):
-        return f'{type(self).__name__}({dict(self)}, sort_func={self.sort_func.__name__}, reverse={self.reverse})'
+        return f"{type(self).__name__}({dict(self)}, sort_func={self.sort_func.__name__}, reverse={self.reverse})"

+ 22 - 10
common/time_check.py

@@ -1,7 +1,11 @@
-import time,re,hashlib
+import hashlib
+import re
+import time
+
 import config
 from common.log import logger
 
+
 def time_checker(f):
     def _time_checker(self, *args, **kwargs):
         _config = config.conf()
@@ -9,17 +13,25 @@ def time_checker(f):
         if chat_time_module:
             chat_start_time = _config.get("chat_start_time", "00:00")
             chat_stopt_time = _config.get("chat_stop_time", "24:00")
-            time_regex = re.compile(r'^([01]?[0-9]|2[0-4])(:)([0-5][0-9])$')  #时间匹配,包含24:00
+            time_regex = re.compile(
+                r"^([01]?[0-9]|2[0-4])(:)([0-5][0-9])$"
+            )  # 时间匹配,包含24:00
 
             starttime_format_check = time_regex.match(chat_start_time)  # 检查停止时间格式
             stoptime_format_check = time_regex.match(chat_stopt_time)  # 检查停止时间格式
-            chat_time_check = chat_start_time < chat_stopt_time # 确定启动时间<停止时间
+            chat_time_check = chat_start_time < chat_stopt_time  # 确定启动时间<停止时间
 
             # 时间格式检查
-            if not (starttime_format_check and stoptime_format_check and chat_time_check):
-                logger.warn('时间格式不正确,请在config.json中修改您的CHAT_START_TIME/CHAT_STOP_TIME,否则可能会影响您正常使用,开始({})-结束({})'.format(starttime_format_check,stoptime_format_check))
-            if chat_start_time>"23:59":
-                logger.error('启动时间可能存在问题,请修改!')
+            if not (
+                starttime_format_check and stoptime_format_check and chat_time_check
+            ):
+                logger.warn(
+                    "时间格式不正确,请在config.json中修改您的CHAT_START_TIME/CHAT_STOP_TIME,否则可能会影响您正常使用,开始({})-结束({})".format(
+                        starttime_format_check, stoptime_format_check
+                    )
+                )
+            if chat_start_time > "23:59":
+                logger.error("启动时间可能存在问题,请修改!")
 
             # 服务时间检查
             now_time = time.strftime("%H:%M", time.localtime())
@@ -27,12 +39,12 @@ def time_checker(f):
                 f(self, *args, **kwargs)
                 return None
             else:
-                if args[0]['Content'] == "#更新配置":  # 不在服务时间内也可以更新配置
+                if args[0]["Content"] == "#更新配置":  # 不在服务时间内也可以更新配置
                     f(self, *args, **kwargs)
                 else:
-                    logger.info('非服务时间内,不接受访问')
+                    logger.info("非服务时间内,不接受访问")
                     return None
         else:
             f(self, *args, **kwargs)  # 未开启时间模块则直接回答
-    return _time_checker
 
+    return _time_checker

+ 5 - 7
common/tmp_dir.py

@@ -1,20 +1,18 @@
-
 import os
 import pathlib
+
 from config import conf
 
 
 class TmpDir(object):
-    """A temporary directory that is deleted when the object is destroyed.
-    """
+    """A temporary directory that is deleted when the object is destroyed."""
+
+    tmpFilePath = pathlib.Path("./tmp/")
 
-    tmpFilePath = pathlib.Path('./tmp/')
-    
     def __init__(self):
         pathExists = os.path.exists(self.tmpFilePath)
         if not pathExists:
             os.makedirs(self.tmpFilePath)
 
     def path(self):
-        return str(self.tmpFilePath) + '/'
-    
+        return str(self.tmpFilePath) + "/"

+ 20 - 6
config-template.json

@@ -2,16 +2,30 @@
   "open_ai_api_key": "YOUR API KEY",
   "model": "gpt-3.5-turbo",
   "proxy": "",
-  "single_chat_prefix": ["bot", "@bot"],
+  "single_chat_prefix": [
+    "bot",
+    "@bot"
+  ],
   "single_chat_reply_prefix": "[bot] ",
-  "group_chat_prefix": ["@bot"],
-  "group_name_white_list": ["ChatGPT测试群", "ChatGPT测试群2"],
-  "group_chat_in_one_session": ["ChatGPT测试群"],
-  "image_create_prefix": ["画", "看", "找"],
+  "group_chat_prefix": [
+    "@bot"
+  ],
+  "group_name_white_list": [
+    "ChatGPT测试群",
+    "ChatGPT测试群2"
+  ],
+  "group_chat_in_one_session": [
+    "ChatGPT测试群"
+  ],
+  "image_create_prefix": [
+    "画",
+    "看",
+    "找"
+  ],
   "speech_recognition": false,
   "group_speech_recognition": false,
   "voice_reply_voice": false,
   "conversation_max_tokens": 1000,
   "expires_in_seconds": 3600,
   "character_desc": "你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。"
-}
+}

+ 21 - 32
config.py

@@ -3,9 +3,10 @@
 import json
 import logging
 import os
-from common.log import logger
 import pickle
 
+from common.log import logger
+
 # 将所有可用的配置项写在字典里, 请使用小写字母
 available_setting = {
     # openai api配置
@@ -16,8 +17,7 @@ available_setting = {
     # chatgpt模型, 当use_azure_chatgpt为true时,其名称为Azure上model deployment名称
     "model": "gpt-3.5-turbo",
     "use_azure_chatgpt": False,  # 是否使用azure的chatgpt
-    "azure_deployment_id": "", #azure 模型部署名称
-
+    "azure_deployment_id": "",  # azure 模型部署名称
     # Bot触发配置
     "single_chat_prefix": ["bot", "@bot"],  # 私聊时文本需要包含该前缀才能触发机器人回复
     "single_chat_reply_prefix": "[bot] ",  # 私聊时自动回复的前缀,用于区分真人
@@ -30,25 +30,21 @@ available_setting = {
     "group_chat_in_one_session": ["ChatGPT测试群"],  # 支持会话上下文共享的群名称
     "trigger_by_self": False,  # 是否允许机器人触发
     "image_create_prefix": ["画", "看", "找"],  # 开启图片回复的前缀
-    "concurrency_in_session": 1, # 同一会话最多有多少条消息在处理中,大于1可能乱序
-
+    "concurrency_in_session": 1,  # 同一会话最多有多少条消息在处理中,大于1可能乱序
     # chatgpt会话参数
     "expires_in_seconds": 3600,  # 无操作会话的过期时间
     "character_desc": "你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。",  # 人格描述
     "conversation_max_tokens": 1000,  # 支持上下文记忆的最多字符数
-
     # chatgpt限流配置
     "rate_limit_chatgpt": 20,  # chatgpt的调用频率限制
     "rate_limit_dalle": 50,  # openai dalle的调用频率限制
-
     # chatgpt api参数 参考https://platform.openai.com/docs/api-reference/chat/create
     "temperature": 0.9,
     "top_p": 1,
     "frequency_penalty": 0,
     "presence_penalty": 0,
-    "request_timeout": 60, # chatgpt请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间
-    "timeout": 120,         # chatgpt重试超时时间,在这个时间内,将会自动重试
-
+    "request_timeout": 60,  # chatgpt请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间
+    "timeout": 120,  # chatgpt重试超时时间,在这个时间内,将会自动重试
     # 语音设置
     "speech_recognition": False,  # 是否开启语音识别
     "group_speech_recognition": False,  # 是否开启群组语音识别
@@ -56,50 +52,40 @@ available_setting = {
     "always_reply_voice": False,  # 是否一直使用语音回复
     "voice_to_text": "openai",  # 语音识别引擎,支持openai,baidu,google,azure
     "text_to_voice": "baidu",  # 语音合成引擎,支持baidu,google,pytts(offline),azure
-
     # baidu 语音api配置, 使用百度语音识别和语音合成时需要
     "baidu_app_id": "",
     "baidu_api_key": "",
     "baidu_secret_key": "",
     # 1536普通话(支持简单的英文识别) 1737英语 1637粤语 1837四川话 1936普通话远场
     "baidu_dev_pid": "1536",
-
     # azure 语音api配置, 使用azure语音识别和语音合成时需要
     "azure_voice_api_key": "",
     "azure_voice_region": "japaneast",
-
     # 服务时间限制,目前支持itchat
     "chat_time_module": False,  # 是否开启服务时间限制
     "chat_start_time": "00:00",  # 服务开始时间
     "chat_stop_time": "24:00",  # 服务结束时间
-
     # itchat的配置
     "hot_reload": False,  # 是否开启热重载
-
     # wechaty的配置
     "wechaty_puppet_service_token": "",  # wechaty的token
-
     # wechatmp的配置
-    "wechatmp_token": "",       # 微信公众平台的Token
-    "wechatmp_port": 8080,      # 微信公众平台的端口,需要端口转发到80或443
-    "wechatmp_app_id": "",      # 微信公众平台的appID,仅服务号需要
+    "wechatmp_token": "",  # 微信公众平台的Token
+    "wechatmp_port": 8080,  # 微信公众平台的端口,需要端口转发到80或443
+    "wechatmp_app_id": "",  # 微信公众平台的appID,仅服务号需要
     "wechatmp_app_secret": "",  # 微信公众平台的appsecret,仅服务号需要
-
     # chatgpt指令自定义触发词
-    "clear_memory_commands": ['#清除记忆'],  # 重置会话指令,必须以#开头
-
+    "clear_memory_commands": ["#清除记忆"],  # 重置会话指令,必须以#开头
     # channel配置
-    "channel_type": "wx", # 通道类型,支持:{wx,wxy,terminal,wechatmp,wechatmp_service}
-
+    "channel_type": "wx",  # 通道类型,支持:{wx,wxy,terminal,wechatmp,wechatmp_service}
     "debug": False,  # 是否开启debug模式,开启后会打印更多日志
-
     # 插件配置
     "plugin_trigger_prefix": "$",  # 规范插件提供聊天相关指令的前缀,建议不要和管理员指令前缀"#"冲突
 }
 
 
 class Config(dict):
-    def __init__(self, d:dict={}):
+    def __init__(self, d: dict = {}):
         super().__init__(d)
         # user_datas: 用户数据,key为用户名,value为用户数据,也是dict
         self.user_datas = {}
@@ -130,7 +116,7 @@ class Config(dict):
 
     def load_user_datas(self):
         try:
-            with open('user_datas.pkl', 'rb') as f:
+            with open("user_datas.pkl", "rb") as f:
                 self.user_datas = pickle.load(f)
                 logger.info("[Config] User datas loaded.")
         except FileNotFoundError as e:
@@ -141,12 +127,13 @@ class Config(dict):
 
     def save_user_datas(self):
         try:
-            with open('user_datas.pkl', 'wb') as f:
+            with open("user_datas.pkl", "wb") as f:
                 pickle.dump(self.user_datas, f)
                 logger.info("[Config] User datas saved.")
         except Exception as e:
             logger.info("[Config] User datas error: {}".format(e))
 
+
 config = Config()
 
 
@@ -154,7 +141,7 @@ def load_config():
     global config
     config_path = "./config.json"
     if not os.path.exists(config_path):
-        logger.info('配置文件不存在,将使用config-template.json模板')
+        logger.info("配置文件不存在,将使用config-template.json模板")
         config_path = "./config-template.json"
 
     config_str = read_file(config_path)
@@ -169,7 +156,8 @@ def load_config():
         name = name.lower()
         if name in available_setting:
             logger.info(
-                "[INIT] override config by environ args: {}={}".format(name, value))
+                "[INIT] override config by environ args: {}={}".format(name, value)
+            )
             try:
                 config[name] = eval(value)
             except:
@@ -182,18 +170,19 @@ def load_config():
 
     if config.get("debug", False):
         logger.setLevel(logging.DEBUG)
-        logger.debug("[INIT] set log level to DEBUG")        
+        logger.debug("[INIT] set log level to DEBUG")
 
     logger.info("[INIT] load config: {}".format(config))
 
     config.load_user_datas()
 
+
 def get_root():
     return os.path.dirname(os.path.abspath(__file__))
 
 
 def read_file(path):
-    with open(path, mode='r', encoding='utf-8') as f:
+    with open(path, mode="r", encoding="utf-8") as f:
         return f.read()
 
 

+ 1 - 1
docker/Dockerfile.debian

@@ -33,7 +33,7 @@ ADD ./entrypoint.sh /entrypoint.sh
 RUN chmod +x /entrypoint.sh \
     && groupadd -r noroot \
     && useradd -r -g noroot -s /bin/bash -d /home/noroot noroot \
-    && chown -R noroot:noroot ${BUILD_PREFIX} 
+    && chown -R noroot:noroot ${BUILD_PREFIX}
 
 USER noroot
 

+ 1 - 1
docker/Dockerfile.debian.latest

@@ -18,7 +18,7 @@ RUN apt-get update \
     && pip install --no-cache -r requirements.txt \
     && pip install --no-cache -r requirements-optional.txt \
     && pip install azure-cognitiveservices-speech
-    
+
 WORKDIR ${BUILD_PREFIX}
 
 ADD docker/entrypoint.sh /entrypoint.sh

+ 1 - 2
docker/build.alpine.sh

@@ -11,6 +11,5 @@ docker build -f Dockerfile.alpine \
              -t zhayujie/chatgpt-on-wechat .
 
 # tag image
-docker tag zhayujie/chatgpt-on-wechat zhayujie/chatgpt-on-wechat:alpine 
+docker tag zhayujie/chatgpt-on-wechat zhayujie/chatgpt-on-wechat:alpine
 docker tag zhayujie/chatgpt-on-wechat zhayujie/chatgpt-on-wechat:$CHATGPT_ON_WECHAT_TAG-alpine
-           

+ 1 - 1
docker/build.debian.sh

@@ -11,5 +11,5 @@ docker build -f Dockerfile.debian \
              -t zhayujie/chatgpt-on-wechat .
 
 # tag image
-docker tag zhayujie/chatgpt-on-wechat zhayujie/chatgpt-on-wechat:debian 
+docker tag zhayujie/chatgpt-on-wechat zhayujie/chatgpt-on-wechat:debian
 docker tag zhayujie/chatgpt-on-wechat zhayujie/chatgpt-on-wechat:$CHATGPT_ON_WECHAT_TAG-debian

+ 1 - 1
docker/chatgpt-on-wechat-voice-reply/Dockerfile.alpine

@@ -9,7 +9,7 @@ RUN apk add --no-cache \
         ffmpeg  \
         espeak \
     && pip install --no-cache \
-        baidu-aip \ 
+        baidu-aip \
         chardet \
         SpeechRecognition
 

+ 1 - 1
docker/chatgpt-on-wechat-voice-reply/Dockerfile.debian

@@ -10,7 +10,7 @@ RUN apt-get update  \
         ffmpeg \
         espeak  \
     && pip install --no-cache \
-        baidu-aip \ 
+        baidu-aip \
         chardet \
         SpeechRecognition
 

+ 2 - 2
docker/sample-chatgpt-on-wechat/Makefile

@@ -11,13 +11,13 @@ run_d:
 	docker rm $(CONTAINER_NAME) || echo
 	docker run -dt  --name $(CONTAINER_NAME) $(PORT_MAP) \
 			--env-file=$(DOTENV) \
-			$(MOUNT) $(IMG) 
+			$(MOUNT) $(IMG)
 
 run_i:
 	docker rm $(CONTAINER_NAME) || echo
 	docker run -it  --name $(CONTAINER_NAME) $(PORT_MAP) \
 			--env-file=$(DOTENV) \
-			$(MOUNT) $(IMG) 
+			$(MOUNT) $(IMG)
 
 stop:
 	docker stop $(CONTAINER_NAME)

+ 14 - 14
plugins/README.md

@@ -24,17 +24,17 @@
 在本仓库中预置了一些插件,如果要安装其他仓库的插件,有两种方法。
 
 - 第一种方法是在将下载的插件文件都解压到"plugins"文件夹的一个单独的文件夹,最终插件的代码都位于"plugins/PLUGIN_NAME/*"中。启动程序后,如果插件的目录结构正确,插件会自动被扫描加载。除此以外,注意你还需要安装文件夹中`requirements.txt`中的依赖。
-    
+
 - 第二种方法是`Godcmd`插件,它是预置的管理员插件,能够让程序在运行时就能安装插件,它能够自动安装依赖。
-    
+
     安装插件的命令是"#installp [仓库源](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/plugins/source.json)记录的插件名/仓库地址"。这是管理员命令,认证方法在[这里](https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/godcmd)。
-    
+
     - 安装[仓库源](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/plugins/source.json)记录的插件:#installp sdwebui
 
     - 安装指定仓库的插件:#installp https://github.com/lanvent/plugin_sdwebui.git
-    
+
     在安装之后,需要执行"#scanp"命令来扫描加载新安装的插件(或者重新启动程序)。
-    
+
 安装插件后需要注意有些插件有自己的配置模板,一般要去掉".template"新建一个配置文件。
 
 ## 插件化实现
@@ -107,14 +107,14 @@
 ```
 
 回复`Reply`的定义如下所示,它允许Bot可以回复多类不同的消息。同时也加入了`INFO`和`ERROR`消息类型区分系统提示和系统错误。
-    
+
 ```python
     class ReplyType(Enum):
         TEXT = 1        # 文本
         VOICE = 2       # 音频文件
         IMAGE = 3       # 图片文件
         IMAGE_URL = 4   # 图片URL
-        
+
         INFO = 9
         ERROR = 10
     class Reply:
@@ -159,12 +159,12 @@
 
 目前支持三类触发事件:
 ```
-1.收到消息 
----> `ON_HANDLE_CONTEXT` 
-2.产生回复 
----> `ON_DECORATE_REPLY` 
-3.装饰回复 
----> `ON_SEND_REPLY` 
+1.收到消息
+---> `ON_HANDLE_CONTEXT`
+2.产生回复
+---> `ON_DECORATE_REPLY`
+3.装饰回复
+---> `ON_SEND_REPLY`
 4.发送回复
 ```
 
@@ -268,6 +268,6 @@ class Hello(Plugin):
 - 一个插件目录建议只注册一个插件类。建议使用单独的仓库维护插件,便于更新。
 
   在测试调试好后提交`PR`,把自己的仓库加入到[仓库源](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/plugins/source.json)中。
-  
+
 - 插件的config文件、使用说明`README.md`、`requirement.txt`等放置在插件目录中。
 - 默认优先级不要超过管理员插件`Godcmd`的优先级(999),`Godcmd`插件提供了配置管理、插件管理等功能。

+ 2 - 2
plugins/__init__.py

@@ -1,9 +1,9 @@
-from .plugin_manager import PluginManager
 from .event import *
 from .plugin import *
+from .plugin_manager import PluginManager
 
 instance = PluginManager()
 
-register                    = instance.register
+register = instance.register
 # load_plugins                = instance.load_plugins
 # emit_event                  = instance.emit_event

+ 1 - 1
plugins/banwords/__init__.py

@@ -1 +1 @@
-from .banwords import *
+from .banwords import *

+ 49 - 35
plugins/banwords/banwords.py

@@ -2,56 +2,67 @@
 
 import json
 import os
+
+import plugins
 from bridge.context import ContextType
 from bridge.reply import Reply, ReplyType
-import plugins
-from plugins import *
 from common.log import logger
+from plugins import *
+
 from .lib.WordsSearch import WordsSearch
 
 
-@plugins.register(name="Banwords", desire_priority=100, hidden=True, desc="判断消息中是否有敏感词、决定是否回复。", version="1.0", author="lanvent")
+@plugins.register(
+    name="Banwords",
+    desire_priority=100,
+    hidden=True,
+    desc="判断消息中是否有敏感词、决定是否回复。",
+    version="1.0",
+    author="lanvent",
+)
 class Banwords(Plugin):
     def __init__(self):
         super().__init__()
         try:
-            curdir=os.path.dirname(__file__)
-            config_path=os.path.join(curdir,"config.json")
-            conf=None
+            curdir = os.path.dirname(__file__)
+            config_path = os.path.join(curdir, "config.json")
+            conf = None
             if not os.path.exists(config_path):
-                conf={"action":"ignore"}
-                with open(config_path,"w") as f:
-                    json.dump(conf,f,indent=4)
+                conf = {"action": "ignore"}
+                with open(config_path, "w") as f:
+                    json.dump(conf, f, indent=4)
             else:
-                with open(config_path,"r") as f:
-                    conf=json.load(f)
+                with open(config_path, "r") as f:
+                    conf = json.load(f)
             self.searchr = WordsSearch()
             self.action = conf["action"]
-            banwords_path = os.path.join(curdir,"banwords.txt")
-            with open(banwords_path, 'r', encoding='utf-8') as f:
-                words=[]
+            banwords_path = os.path.join(curdir, "banwords.txt")
+            with open(banwords_path, "r", encoding="utf-8") as f:
+                words = []
                 for line in f:
                     word = line.strip()
                     if word:
                         words.append(word)
             self.searchr.SetKeywords(words)
             self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
-            if conf.get("reply_filter",True):
+            if conf.get("reply_filter", True):
                 self.handlers[Event.ON_DECORATE_REPLY] = self.on_decorate_reply
-                self.reply_action = conf.get("reply_action","ignore")
+                self.reply_action = conf.get("reply_action", "ignore")
             logger.info("[Banwords] inited")
         except Exception as e:
-            logger.warn("[Banwords] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/banwords .")
+            logger.warn(
+                "[Banwords] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/banwords ."
+            )
             raise e
-        
-
 
     def on_handle_context(self, e_context: EventContext):
-
-        if e_context['context'].type not in [ContextType.TEXT,ContextType.IMAGE_CREATE]:
+        if e_context["context"].type not in [
+            ContextType.TEXT,
+            ContextType.IMAGE_CREATE,
+        ]:
             return
-        
-        content = e_context['context'].content
+
+        content = e_context["context"].content
         logger.debug("[Banwords] on_handle_context. content: %s" % content)
         if self.action == "ignore":
             f = self.searchr.FindFirst(content)
@@ -61,31 +72,34 @@ class Banwords(Plugin):
                 return
         elif self.action == "replace":
             if self.searchr.ContainsAny(content):
-                reply = Reply(ReplyType.INFO, "发言中包含敏感词,请重试: \n"+self.searchr.Replace(content))
-                e_context['reply'] = reply
+                reply = Reply(
+                    ReplyType.INFO, "发言中包含敏感词,请重试: \n" + self.searchr.Replace(content)
+                )
+                e_context["reply"] = reply
                 e_context.action = EventAction.BREAK_PASS
                 return
-            
-    def on_decorate_reply(self, e_context: EventContext):
 
-        if e_context['reply'].type not in [ReplyType.TEXT]:
+    def on_decorate_reply(self, e_context: EventContext):
+        if e_context["reply"].type not in [ReplyType.TEXT]:
             return
-        
-        reply = e_context['reply']
+
+        reply = e_context["reply"]
         content = reply.content
         if self.reply_action == "ignore":
             f = self.searchr.FindFirst(content)
             if f:
                 logger.info("[Banwords] %s in reply" % f["Keyword"])
-                e_context['reply'] = None
+                e_context["reply"] = None
                 e_context.action = EventAction.BREAK_PASS
                 return
         elif self.reply_action == "replace":
             if self.searchr.ContainsAny(content):
-                reply = Reply(ReplyType.INFO, "已替换回复中的敏感词: \n"+self.searchr.Replace(content))
-                e_context['reply'] = reply
+                reply = Reply(
+                    ReplyType.INFO, "已替换回复中的敏感词: \n" + self.searchr.Replace(content)
+                )
+                e_context["reply"] = reply
                 e_context.action = EventAction.CONTINUE
                 return
-    
+
     def get_help_text(self, **kwargs):
-        return Banwords.desc
+        return Banwords.desc

+ 4 - 4
plugins/banwords/config.json.template

@@ -1,5 +1,5 @@
 {
-    "action": "replace",
-    "reply_filter": true,
-    "reply_action": "ignore"
-}
+  "action": "replace",
+  "reply_filter": true,
+  "reply_action": "ignore"
+}

+ 1 - 1
plugins/bdunit/README.md

@@ -24,7 +24,7 @@ see https://ai.baidu.com/unit/home#/home?track=61fe1b0d3407ce3face1d92cb5c291087
 ``` json
     {
     "service_id": "s...", #"机器人ID"
-    "api_key": "", 
+    "api_key": "",
     "secret_key": ""
     }
 ```

+ 1 - 1
plugins/bdunit/__init__.py

@@ -1 +1 @@
-from .bdunit import *
+from .bdunit import *

+ 30 - 42
plugins/bdunit/bdunit.py

@@ -2,21 +2,29 @@
 import json
 import os
 import uuid
+from uuid import getnode as get_mac
+
 import requests
+
+import plugins
 from bridge.context import ContextType
 from bridge.reply import Reply, ReplyType
 from common.log import logger
-import plugins
 from plugins import *
-from uuid import getnode as get_mac
-
 
 """利用百度UNIT实现智能对话
     如果命中意图,返回意图对应的回复,否则返回继续交付给下个插件处理
 """
 
 
-@plugins.register(name="BDunit", desire_priority=0, hidden=True, desc="Baidu unit bot system", version="0.1", author="jackson")
+@plugins.register(
+    name="BDunit",
+    desire_priority=0,
+    hidden=True,
+    desc="Baidu unit bot system",
+    version="0.1",
+    author="jackson",
+)
 class BDunit(Plugin):
     def __init__(self):
         super().__init__()
@@ -40,11 +48,10 @@ class BDunit(Plugin):
             raise e
 
     def on_handle_context(self, e_context: EventContext):
-
-        if e_context['context'].type != ContextType.TEXT:
+        if e_context["context"].type != ContextType.TEXT:
             return
 
-        content = e_context['context'].content
+        content = e_context["context"].content
         logger.debug("[BDunit] on_handle_context. content: %s" % content)
         parsed = self.getUnit2(content)
         intent = self.getIntent(parsed)
@@ -53,7 +60,7 @@ class BDunit(Plugin):
             reply = Reply()
             reply.type = ReplyType.TEXT
             reply.content = self.getSay(parsed)
-            e_context['reply'] = reply
+            e_context["reply"] = reply
             e_context.action = EventAction.BREAK_PASS  # 事件结束,并跳过处理context的默认逻辑
         else:
             e_context.action = EventAction.CONTINUE  # 事件继续,交付给下个插件或默认逻辑
@@ -70,17 +77,15 @@ class BDunit(Plugin):
             string: access_token
         """
         url = "https://aip.baidubce.com/oauth/2.0/token?client_id={}&client_secret={}&grant_type=client_credentials".format(
-            self.api_key, self.secret_key)
+            self.api_key, self.secret_key
+        )
         payload = ""
-        headers = {
-            'Content-Type': 'application/json',
-            'Accept': 'application/json'
-        }
+        headers = {"Content-Type": "application/json", "Accept": "application/json"}
 
         response = requests.request("POST", url, headers=headers, data=payload)
 
         # print(response.text)
-        return response.json()['access_token']
+        return response.json()["access_token"]
 
     def getUnit(self, query):
         """
@@ -90,11 +95,14 @@ class BDunit(Plugin):
         """
 
         url = (
-            'https://aip.baidubce.com/rpc/2.0/unit/service/v3/chat?access_token='
+            "https://aip.baidubce.com/rpc/2.0/unit/service/v3/chat?access_token="
             + self.access_token
         )
-        request = {"query": query, "user_id": str(
-            get_mac())[:32], "terminal_id": "88888"}
+        request = {
+            "query": query,
+            "user_id": str(get_mac())[:32],
+            "terminal_id": "88888",
+        }
         body = {
             "log_id": str(uuid.uuid1()),
             "version": "3.0",
@@ -142,11 +150,7 @@ class BDunit(Plugin):
         :param parsed: UNIT 解析结果
         :returns: 意图数组
         """
-        if (
-            parsed
-            and "result" in parsed
-            and "response_list" in parsed["result"]
-        ):
+        if parsed and "result" in parsed and "response_list" in parsed["result"]:
             try:
                 return parsed["result"]["response_list"][0]["schema"]["intent"]
             except Exception as e:
@@ -163,11 +167,7 @@ class BDunit(Plugin):
         :param intent: 意图的名称
         :returns: True: 包含; False: 不包含
         """
-        if (
-            parsed
-            and "result" in parsed
-            and "response_list" in parsed["result"]
-        ):
+        if parsed and "result" in parsed and "response_list" in parsed["result"]:
             response_list = parsed["result"]["response_list"]
             for response in response_list:
                 if (
@@ -189,11 +189,7 @@ class BDunit(Plugin):
             :returns: 词槽列表。你可以通过 name 属性筛选词槽,
         再通过 normalized_word 属性取出相应的值
         """
-        if (
-            parsed
-            and "result" in parsed
-            and "response_list" in parsed["result"]
-        ):
+        if parsed and "result" in parsed and "response_list" in parsed["result"]:
             response_list = parsed["result"]["response_list"]
             if intent == "":
                 try:
@@ -236,11 +232,7 @@ class BDunit(Plugin):
         :param parsed: UNIT 解析结果
         :returns: UNIT 的回复文本
         """
-        if (
-            parsed
-            and "result" in parsed
-            and "response_list" in parsed["result"]
-        ):
+        if parsed and "result" in parsed and "response_list" in parsed["result"]:
             response_list = parsed["result"]["response_list"]
             answer = {}
             for response in response_list:
@@ -266,11 +258,7 @@ class BDunit(Plugin):
         :param intent: 意图的名称
         :returns: UNIT 的回复文本
         """
-        if (
-            parsed
-            and "result" in parsed
-            and "response_list" in parsed["result"]
-        ):
+        if parsed and "result" in parsed and "response_list" in parsed["result"]:
             response_list = parsed["result"]["response_list"]
             if intent == "":
                 try:

+ 4 - 4
plugins/bdunit/config.json.template

@@ -1,5 +1,5 @@
 {
-    "service_id": "s...",
-    "api_key": "",
-    "secret_key": ""
-}
+  "service_id": "s...",
+  "api_key": "",
+  "secret_key": ""
+}

+ 1 - 1
plugins/dungeon/__init__.py

@@ -1 +1 @@
-from .dungeon import *
+from .dungeon import *

+ 47 - 28
plugins/dungeon/dungeon.py

@@ -1,17 +1,18 @@
 # encoding:utf-8
 
+import plugins
 from bridge.bridge import Bridge
 from bridge.context import ContextType
 from bridge.reply import Reply, ReplyType
+from common import const
 from common.expired_dict import ExpiredDict
+from common.log import logger
 from config import conf
-import plugins
 from plugins import *
-from common.log import logger
-from common import const
+
 
 # https://github.com/bupticybee/ChineseAiDungeonChatGPT
-class StoryTeller():
+class StoryTeller:
     def __init__(self, bot, sessionid, story):
         self.bot = bot
         self.sessionid = sessionid
@@ -27,67 +28,85 @@ class StoryTeller():
         if user_action[-1] != "。":
             user_action = user_action + "。"
         if self.first_interact:
-            prompt = """现在来充当一个文字冒险游戏,描述时候注意节奏,不要太快,仔细描述各个人物的心情和周边环境。一次只需写四到六句话。
-            开头是,""" + self.story + " " + user_action
+            prompt = (
+                """现在来充当一个文字冒险游戏,描述时候注意节奏,不要太快,仔细描述各个人物的心情和周边环境。一次只需写四到六句话。
+            开头是,"""
+                + self.story
+                + " "
+                + user_action
+            )
             self.first_interact = False
         else:
             prompt = """继续,一次只需要续写四到六句话,总共就只讲5分钟内发生的事情。""" + user_action
         return prompt
 
 
-@plugins.register(name="Dungeon", desire_priority=0, namecn="文字冒险", desc="A plugin to play dungeon game", version="1.0", author="lanvent")
+@plugins.register(
+    name="Dungeon",
+    desire_priority=0,
+    namecn="文字冒险",
+    desc="A plugin to play dungeon game",
+    version="1.0",
+    author="lanvent",
+)
 class Dungeon(Plugin):
     def __init__(self):
         super().__init__()
         self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
         logger.info("[Dungeon] inited")
         # 目前没有设计session过期事件,这里先暂时使用过期字典
-        if conf().get('expires_in_seconds'):
-            self.games = ExpiredDict(conf().get('expires_in_seconds'))
+        if conf().get("expires_in_seconds"):
+            self.games = ExpiredDict(conf().get("expires_in_seconds"))
         else:
             self.games = dict()
 
     def on_handle_context(self, e_context: EventContext):
-
-        if e_context['context'].type != ContextType.TEXT:
+        if e_context["context"].type != ContextType.TEXT:
             return
         bottype = Bridge().get_bot_type("chat")
         if bottype not in (const.CHATGPT, const.OPEN_AI):
             return
         bot = Bridge().get_bot("chat")
-        content = e_context['context'].content[:]
-        clist = e_context['context'].content.split(maxsplit=1)
-        sessionid = e_context['context']['session_id']
+        content = e_context["context"].content[:]
+        clist = e_context["context"].content.split(maxsplit=1)
+        sessionid = e_context["context"]["session_id"]
         logger.debug("[Dungeon] on_handle_context. content: %s" % clist)
-        trigger_prefix = conf().get('plugin_trigger_prefix', "$")
+        trigger_prefix = conf().get("plugin_trigger_prefix", "$")
         if clist[0] == f"{trigger_prefix}停止冒险":
             if sessionid in self.games:
                 self.games[sessionid].reset()
                 del self.games[sessionid]
                 reply = Reply(ReplyType.INFO, "冒险结束!")
-                e_context['reply'] = reply
+                e_context["reply"] = reply
                 e_context.action = EventAction.BREAK_PASS
         elif clist[0] == f"{trigger_prefix}开始冒险" or sessionid in self.games:
             if sessionid not in self.games or clist[0] == f"{trigger_prefix}开始冒险":
-                if len(clist)>1 :
+                if len(clist) > 1:
                     story = clist[1]
                 else:
-                    story = "你在树林里冒险,指不定会从哪里蹦出来一些奇怪的东西,你握紧手上的手枪,希望这次冒险能够找到一些值钱的东西,你往树林深处走去。"
+                    story = (
+                        "你在树林里冒险,指不定会从哪里蹦出来一些奇怪的东西,你握紧手上的手枪,希望这次冒险能够找到一些值钱的东西,你往树林深处走去。"
+                    )
                 self.games[sessionid] = StoryTeller(bot, sessionid, story)
                 reply = Reply(ReplyType.INFO, "冒险开始,你可以输入任意内容,让故事继续下去。故事背景是:" + story)
-                e_context['reply'] = reply
-                e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑
+                e_context["reply"] = reply
+                e_context.action = EventAction.BREAK_PASS  # 事件结束,并跳过处理context的默认逻辑
             else:
                 prompt = self.games[sessionid].action(content)
-                e_context['context'].type = ContextType.TEXT
-                e_context['context'].content = prompt
-                e_context.action = EventAction.BREAK # 事件结束,不跳过处理context的默认逻辑
+                e_context["context"].type = ContextType.TEXT
+                e_context["context"].content = prompt
+                e_context.action = EventAction.BREAK  # 事件结束,不跳过处理context的默认逻辑
+
     def get_help_text(self, **kwargs):
         help_text = "可以和机器人一起玩文字冒险游戏。\n"
-        if kwargs.get('verbose') != True:
+        if kwargs.get("verbose") != True:
             return help_text
-        trigger_prefix = conf().get('plugin_trigger_prefix', "$")
-        help_text = f"{trigger_prefix}开始冒险 "+"背景故事: 开始一个基于{背景故事}的文字冒险,之后你的所有消息会协助完善这个故事。\n"+f"{trigger_prefix}停止冒险: 结束游戏。\n"
-        if kwargs.get('verbose') == True:
+        trigger_prefix = conf().get("plugin_trigger_prefix", "$")
+        help_text = (
+            f"{trigger_prefix}开始冒险 "
+            + "背景故事: 开始一个基于{背景故事}的文字冒险,之后你的所有消息会协助完善这个故事。\n"
+            + f"{trigger_prefix}停止冒险: 结束游戏。\n"
+        )
+        if kwargs.get("verbose") == True:
             help_text += f"\n命令例子: '{trigger_prefix}开始冒险 你在树林里冒险,指不定会从哪里蹦出来一些奇怪的东西,你握紧手上的手枪,希望这次冒险能够找到一些值钱的东西,你往树林深处走去。'"
-        return help_text
+        return help_text

+ 6 - 6
plugins/event.py

@@ -9,17 +9,17 @@ class Event(Enum):
     e_context = {  "channel": 消息channel, "context" : 本次消息的context}
     """
 
-    ON_HANDLE_CONTEXT = 2   # 处理消息前
+    ON_HANDLE_CONTEXT = 2  # 处理消息前
     """
     e_context = {  "channel": 消息channel, "context" : 本次消息的context, "reply" : 目前的回复,初始为空  }
     """
 
-    ON_DECORATE_REPLY = 3   # 得到回复后准备装饰
+    ON_DECORATE_REPLY = 3  # 得到回复后准备装饰
     """
     e_context = {  "channel": 消息channel, "context" : 本次消息的context, "reply" : 目前的回复 }
     """
 
-    ON_SEND_REPLY = 4       # 发送回复前
+    ON_SEND_REPLY = 4  # 发送回复前
     """
     e_context = {  "channel": 消息channel, "context" : 本次消息的context, "reply" : 目前的回复 }
     """
@@ -28,9 +28,9 @@ class Event(Enum):
 
 
 class EventAction(Enum):
-    CONTINUE = 1            # 事件未结束,继续交给下个插件处理,如果没有下个插件,则交付给默认的事件处理逻辑
-    BREAK = 2               # 事件结束,不再给下个插件处理,交付给默认的事件处理逻辑
-    BREAK_PASS = 3          # 事件结束,不再给下个插件处理,不交付给默认的事件处理逻辑
+    CONTINUE = 1  # 事件未结束,继续交给下个插件处理,如果没有下个插件,则交付给默认的事件处理逻辑
+    BREAK = 2  # 事件结束,不再给下个插件处理,交付给默认的事件处理逻辑
+    BREAK_PASS = 3  # 事件结束,不再给下个插件处理,不交付给默认的事件处理逻辑
 
 
 class EventContext:

+ 1 - 1
plugins/finish/__init__.py

@@ -1 +1 @@
-from .finish import *
+from .finish import *

+ 15 - 9
plugins/finish/finish.py

@@ -1,14 +1,21 @@
 # encoding:utf-8
 
+import plugins
 from bridge.context import ContextType
 from bridge.reply import Reply, ReplyType
+from common.log import logger
 from config import conf
-import plugins
 from plugins import *
-from common.log import logger
 
 
-@plugins.register(name="Finish", desire_priority=-999, hidden=True, desc="A plugin that check unknown command", version="1.0", author="js00000")
+@plugins.register(
+    name="Finish",
+    desire_priority=-999,
+    hidden=True,
+    desc="A plugin that check unknown command",
+    version="1.0",
+    author="js00000",
+)
 class Finish(Plugin):
     def __init__(self):
         super().__init__()
@@ -16,19 +23,18 @@ class Finish(Plugin):
         logger.info("[Finish] inited")
 
     def on_handle_context(self, e_context: EventContext):
-
-        if e_context['context'].type != ContextType.TEXT:
+        if e_context["context"].type != ContextType.TEXT:
             return
 
-        content = e_context['context'].content
+        content = e_context["context"].content
         logger.debug("[Finish] on_handle_context. content: %s" % content)
-        trigger_prefix = conf().get('plugin_trigger_prefix',"$")
+        trigger_prefix = conf().get("plugin_trigger_prefix", "$")
         if content.startswith(trigger_prefix):
             reply = Reply()
             reply.type = ReplyType.ERROR
             reply.content = "未知插件命令\n查看插件命令列表请输入#help 插件名\n"
-            e_context['reply'] = reply
-            e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑
+            e_context["reply"] = reply
+            e_context.action = EventAction.BREAK_PASS  # 事件结束,并跳过处理context的默认逻辑
 
     def get_help_text(self, **kwargs):
         return ""

+ 1 - 1
plugins/godcmd/__init__.py

@@ -1 +1 @@
-from .godcmd import *
+from .godcmd import *

+ 3 - 3
plugins/godcmd/config.json.template

@@ -1,4 +1,4 @@
 {
-    "password": "",
-    "admin_users": []
-}
+  "password": "",
+  "admin_users": []
+}

+ 96 - 70
plugins/godcmd/godcmd.py

@@ -6,14 +6,16 @@ import random
 import string
 import traceback
 from typing import Tuple
+
+import plugins
 from bridge.bridge import Bridge
 from bridge.context import ContextType
 from bridge.reply import Reply, ReplyType
-from config import conf, load_config
-import plugins
-from plugins import *
 from common import const
 from common.log import logger
+from config import conf, load_config
+from plugins import *
+
 # 定义指令集
 COMMANDS = {
     "help": {
@@ -41,7 +43,7 @@ COMMANDS = {
     },
     "id": {
         "alias": ["id", "用户"],
-        "desc": "获取用户id", # wechaty和wechatmp的用户id不会变化,可用于绑定管理员
+        "desc": "获取用户id",  # wechaty和wechatmp的用户id不会变化,可用于绑定管理员
     },
     "reset": {
         "alias": ["reset", "重置会话"],
@@ -114,18 +116,20 @@ ADMIN_COMMANDS = {
         "desc": "开启机器调试日志",
     },
 }
+
+
 # 定义帮助函数
 def get_help_text(isadmin, isgroup):
     help_text = "通用指令:\n"
     for cmd, info in COMMANDS.items():
-        if cmd=="auth": #不提示认证指令
+        if cmd == "auth":  # 不提示认证指令
             continue
-        if cmd=="id" and conf().get("channel_type","wx") not in ["wxy","wechatmp"]:
+        if cmd == "id" and conf().get("channel_type", "wx") not in ["wxy", "wechatmp"]:
             continue
-        alias=["#"+a for a in info['alias'][:1]]
+        alias = ["#" + a for a in info["alias"][:1]]
         help_text += f"{','.join(alias)} "
-        if 'args' in info:
-            args=[a for a in info['args']]
+        if "args" in info:
+            args = [a for a in info["args"]]
             help_text += f"{' '.join(args)}"
         help_text += f": {info['desc']}\n"
 
@@ -135,39 +139,48 @@ def get_help_text(isadmin, isgroup):
     for plugin in plugins:
         if plugins[plugin].enabled and not plugins[plugin].hidden:
             namecn = plugins[plugin].namecn
-            help_text += "\n%s:"%namecn
-            help_text += PluginManager().instances[plugin].get_help_text(verbose=False).strip()
+            help_text += "\n%s:" % namecn
+            help_text += (
+                PluginManager().instances[plugin].get_help_text(verbose=False).strip()
+            )
 
     if ADMIN_COMMANDS and isadmin:
         help_text += "\n\n管理员指令:\n"
         for cmd, info in ADMIN_COMMANDS.items():
-            alias=["#"+a for a in info['alias'][:1]]
+            alias = ["#" + a for a in info["alias"][:1]]
             help_text += f"{','.join(alias)} "
-            if 'args' in info:
-                args=[a for a in info['args']]
+            if "args" in info:
+                args = [a for a in info["args"]]
                 help_text += f"{' '.join(args)}"
             help_text += f": {info['desc']}\n"
     return help_text
 
-@plugins.register(name="Godcmd", desire_priority=999, hidden=True, desc="为你的机器人添加指令集,有用户和管理员两种角色,加载顺序请放在首位,初次运行后插件目录会生成配置文件, 填充管理员密码后即可认证", version="1.0", author="lanvent")
-class Godcmd(Plugin):
 
+@plugins.register(
+    name="Godcmd",
+    desire_priority=999,
+    hidden=True,
+    desc="为你的机器人添加指令集,有用户和管理员两种角色,加载顺序请放在首位,初次运行后插件目录会生成配置文件, 填充管理员密码后即可认证",
+    version="1.0",
+    author="lanvent",
+)
+class Godcmd(Plugin):
     def __init__(self):
         super().__init__()
 
-        curdir=os.path.dirname(__file__)
-        config_path=os.path.join(curdir,"config.json")
-        gconf=None
+        curdir = os.path.dirname(__file__)
+        config_path = os.path.join(curdir, "config.json")
+        gconf = None
         if not os.path.exists(config_path):
-            gconf={"password":"","admin_users":[]}
-            with open(config_path,"w") as f:
-                json.dump(gconf,f,indent=4)
+            gconf = {"password": "", "admin_users": []}
+            with open(config_path, "w") as f:
+                json.dump(gconf, f, indent=4)
         else:
-            with open(config_path,"r") as f:
-                gconf=json.load(f)
+            with open(config_path, "r") as f:
+                gconf = json.load(f)
         if gconf["password"] == "":
             self.temp_password = "".join(random.sample(string.digits, 4))
-            logger.info("[Godcmd] 因未设置口令,本次的临时口令为%s。"%self.temp_password)
+            logger.info("[Godcmd] 因未设置口令,本次的临时口令为%s。" % self.temp_password)
         else:
             self.temp_password = None
         custom_commands = conf().get("clear_memory_commands", [])
@@ -178,41 +191,42 @@ class Godcmd(Plugin):
                     COMMANDS["reset"]["alias"].append(custom_command)
 
         self.password = gconf["password"]
-        self.admin_users = gconf["admin_users"] # 预存的管理员账号,这些账号不需要认证。itchat的用户名每次都会变,不可用
-        self.isrunning = True # 机器人是否运行中
+        self.admin_users = gconf[
+            "admin_users"
+        ]  # 预存的管理员账号,这些账号不需要认证。itchat的用户名每次都会变,不可用
+        self.isrunning = True  # 机器人是否运行中
 
         self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
         logger.info("[Godcmd] inited")
 
-
     def on_handle_context(self, e_context: EventContext):
-        context_type = e_context['context'].type
+        context_type = e_context["context"].type
         if context_type != ContextType.TEXT:
             if not self.isrunning:
                 e_context.action = EventAction.BREAK_PASS
             return
 
-        content = e_context['context'].content
+        content = e_context["context"].content
         logger.debug("[Godcmd] on_handle_context. content: %s" % content)
         if content.startswith("#"):
             # msg = e_context['context']['msg']
-            channel = e_context['channel']
-            user = e_context['context']['receiver']
-            session_id = e_context['context']['session_id']
-            isgroup = e_context['context'].get("isgroup", False)
+            channel = e_context["channel"]
+            user = e_context["context"]["receiver"]
+            session_id = e_context["context"]["session_id"]
+            isgroup = e_context["context"].get("isgroup", False)
             bottype = Bridge().get_bot_type("chat")
             bot = Bridge().get_bot("chat")
             # 将命令和参数分割
             command_parts = content[1:].strip().split()
             cmd = command_parts[0]
             args = command_parts[1:]
-            isadmin=False
+            isadmin = False
             if user in self.admin_users:
-                isadmin=True
-            ok=False
-            result="string"
-            if any(cmd in info['alias'] for info in COMMANDS.values()):
-                cmd = next(c for c, info in COMMANDS.items() if cmd in info['alias'])
+                isadmin = True
+            ok = False
+            result = "string"
+            if any(cmd in info["alias"] for info in COMMANDS.values()):
+                cmd = next(c for c, info in COMMANDS.items() if cmd in info["alias"])
                 if cmd == "auth":
                     ok, result = self.authenticate(user, args, isadmin, isgroup)
                 elif cmd == "help" or cmd == "helpp":
@@ -224,10 +238,14 @@ class Godcmd(Plugin):
                         query_name = args[0].upper()
                         # search name and namecn
                         for name, plugincls in plugins.items():
-                            if not plugincls.enabled :
+                            if not plugincls.enabled:
                                 continue
                             if query_name == name or query_name == plugincls.namecn:
-                                ok, result = True, PluginManager().instances[name].get_help_text(isgroup=isgroup, isadmin=isadmin, verbose=True)
+                                ok, result = True, PluginManager().instances[
+                                    name
+                                ].get_help_text(
+                                    isgroup=isgroup, isadmin=isadmin, verbose=True
+                                )
                                 break
                         if not ok:
                             result = "插件不存在或未启用"
@@ -236,14 +254,14 @@ class Godcmd(Plugin):
                 elif cmd == "set_openai_api_key":
                     if len(args) == 1:
                         user_data = conf().get_user_data(user)
-                        user_data['openai_api_key'] = args[0]
+                        user_data["openai_api_key"] = args[0]
                         ok, result = True, "你的OpenAI私有api_key已设置为" + args[0]
                     else:
                         ok, result = False, "请提供一个api_key"
                 elif cmd == "reset_openai_api_key":
                     try:
                         user_data = conf().get_user_data(user)
-                        user_data.pop('openai_api_key')
+                        user_data.pop("openai_api_key")
                         ok, result = True, "你的OpenAI私有api_key已清除"
                     except Exception as e:
                         ok, result = False, "你没有设置私有api_key"
@@ -255,12 +273,16 @@ class Godcmd(Plugin):
                     else:
                         ok, result = False, "当前对话机器人不支持重置会话"
                 logger.debug("[Godcmd] command: %s by %s" % (cmd, user))
-            elif any(cmd in info['alias'] for info in ADMIN_COMMANDS.values()):
+            elif any(cmd in info["alias"] for info in ADMIN_COMMANDS.values()):
                 if isadmin:
                     if isgroup:
                         ok, result = False, "群聊不可执行管理员指令"
                     else:
-                        cmd = next(c for c, info in ADMIN_COMMANDS.items() if cmd in info['alias'])
+                        cmd = next(
+                            c
+                            for c, info in ADMIN_COMMANDS.items()
+                            if cmd in info["alias"]
+                        )
                         if cmd == "stop":
                             self.isrunning = False
                             ok, result = True, "服务已暂停"
@@ -278,13 +300,13 @@ class Godcmd(Plugin):
                             else:
                                 ok, result = False, "当前对话机器人不支持重置会话"
                         elif cmd == "debug":
-                            logger.setLevel('DEBUG')
+                            logger.setLevel("DEBUG")
                             ok, result = True, "DEBUG模式已开启"
                         elif cmd == "plist":
                             plugins = PluginManager().list_plugins()
                             ok = True
                             result = "插件列表:\n"
-                            for name,plugincls in plugins.items():
+                            for name, plugincls in plugins.items():
                                 result += f"{plugincls.name}_v{plugincls.version} {plugincls.priority} - "
                                 if plugincls.enabled:
                                     result += "已启用\n"
@@ -294,16 +316,20 @@ class Godcmd(Plugin):
                             new_plugins = PluginManager().scan_plugins()
                             ok, result = True, "插件扫描完成"
                             PluginManager().activate_plugins()
-                            if len(new_plugins) >0 :
+                            if len(new_plugins) > 0:
                                 result += "\n发现新插件:\n"
-                                result += "\n".join([f"{p.name}_v{p.version}" for p in new_plugins])
-                            else :
-                                result +=", 未发现新插件"
+                                result += "\n".join(
+                                    [f"{p.name}_v{p.version}" for p in new_plugins]
+                                )
+                            else:
+                                result += ", 未发现新插件"
                         elif cmd == "setpri":
                             if len(args) != 2:
                                 ok, result = False, "请提供插件名和优先级"
                             else:
-                                ok = PluginManager().set_plugin_priority(args[0], int(args[1]))
+                                ok = PluginManager().set_plugin_priority(
+                                    args[0], int(args[1])
+                                )
                                 if ok:
                                     result = "插件" + args[0] + "优先级已设置为" + args[1]
                                 else:
@@ -350,42 +376,42 @@ class Godcmd(Plugin):
                 else:
                     ok, result = False, "需要管理员权限才能执行该指令"
             else:
-                trigger_prefix = conf().get('plugin_trigger_prefix',"$")
-                if trigger_prefix == "#": # 跟插件聊天指令前缀相同,继续递交
+                trigger_prefix = conf().get("plugin_trigger_prefix", "$")
+                if trigger_prefix == "#":  # 跟插件聊天指令前缀相同,继续递交
                     return
                 ok, result = False, f"未知指令:{cmd}\n查看指令列表请输入#help \n"
-            
+
             reply = Reply()
             if ok:
                 reply.type = ReplyType.INFO
             else:
                 reply.type = ReplyType.ERROR
             reply.content = result
-            e_context['reply'] = reply
+            e_context["reply"] = reply
 
-            e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑
+            e_context.action = EventAction.BREAK_PASS  # 事件结束,并跳过处理context的默认逻辑
         elif not self.isrunning:
             e_context.action = EventAction.BREAK_PASS
 
-    def authenticate(self, userid, args, isadmin, isgroup) -> Tuple[bool,str] : 
+    def authenticate(self, userid, args, isadmin, isgroup) -> Tuple[bool, str]:
         if isgroup:
-            return False,"请勿在群聊中认证"
-        
+            return False, "请勿在群聊中认证"
+
         if isadmin:
-            return False,"管理员账号无需认证"
-        
+            return False, "管理员账号无需认证"
+
         if len(args) != 1:
-            return False,"请提供口令"
-        
+            return False, "请提供口令"
+
         password = args[0]
         if password == self.password:
             self.admin_users.append(userid)
-            return True,"认证成功"
+            return True, "认证成功"
         elif password == self.temp_password:
             self.admin_users.append(userid)
-            return True,"认证成功,请尽快设置口令"
+            return True, "认证成功,请尽快设置口令"
         else:
-            return False,"认证失败"
+            return False, "认证失败"
 
-    def get_help_text(self, isadmin = False, isgroup = False, **kwargs):
-        return get_help_text(isadmin, isgroup)
+    def get_help_text(self, isadmin=False, isgroup=False, **kwargs):
+        return get_help_text(isadmin, isgroup)

+ 1 - 1
plugins/hello/__init__.py

@@ -1 +1 @@
-from .hello import *
+from .hello import *

+ 22 - 14
plugins/hello/hello.py

@@ -1,14 +1,21 @@
 # encoding:utf-8
 
+import plugins
 from bridge.context import ContextType
 from bridge.reply import Reply, ReplyType
 from channel.chat_message import ChatMessage
-import plugins
-from plugins import *
 from common.log import logger
+from plugins import *
 
 
-@plugins.register(name="Hello", desire_priority=-1, hidden=True, desc="A simple plugin that says hello", version="0.1", author="lanvent")
+@plugins.register(
+    name="Hello",
+    desire_priority=-1,
+    hidden=True,
+    desc="A simple plugin that says hello",
+    version="0.1",
+    author="lanvent",
+)
 class Hello(Plugin):
     def __init__(self):
         super().__init__()
@@ -16,33 +23,34 @@ class Hello(Plugin):
         logger.info("[Hello] inited")
 
     def on_handle_context(self, e_context: EventContext):
-
-        if e_context['context'].type != ContextType.TEXT:
+        if e_context["context"].type != ContextType.TEXT:
             return
-        
-        content = e_context['context'].content
+
+        content = e_context["context"].content
         logger.debug("[Hello] on_handle_context. content: %s" % content)
         if content == "Hello":
             reply = Reply()
             reply.type = ReplyType.TEXT
-            msg:ChatMessage = e_context['context']['msg']
-            if e_context['context']['isgroup']:
-                reply.content = f"Hello, {msg.actual_user_nickname} from {msg.from_user_nickname}"
+            msg: ChatMessage = e_context["context"]["msg"]
+            if e_context["context"]["isgroup"]:
+                reply.content = (
+                    f"Hello, {msg.actual_user_nickname} from {msg.from_user_nickname}"
+                )
             else:
                 reply.content = f"Hello, {msg.from_user_nickname}"
-            e_context['reply'] = reply
-            e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑
+            e_context["reply"] = reply
+            e_context.action = EventAction.BREAK_PASS  # 事件结束,并跳过处理context的默认逻辑
 
         if content == "Hi":
             reply = Reply()
             reply.type = ReplyType.TEXT
             reply.content = "Hi"
-            e_context['reply'] = reply
+            e_context["reply"] = reply
             e_context.action = EventAction.BREAK  # 事件结束,进入默认处理逻辑,一般会覆写reply
 
         if content == "End":
             # 如果是文本消息"End",将请求转换成"IMAGE_CREATE",并将content设置为"The World"
-            e_context['context'].type = ContextType.IMAGE_CREATE
+            e_context["context"].type = ContextType.IMAGE_CREATE
             content = "The World"
             e_context.action = EventAction.CONTINUE  # 事件继续,交付给下个插件或默认逻辑
 

+ 1 - 1
plugins/plugin.py

@@ -3,4 +3,4 @@ class Plugin:
         self.handlers = {}
 
     def get_help_text(self, **kwargs):
-        return "暂无帮助信息"
+        return "暂无帮助信息"

+ 104 - 52
plugins/plugin_manager.py

@@ -5,17 +5,19 @@ import importlib.util
 import json
 import os
 import sys
+
+from common.log import logger
 from common.singleton import singleton
 from common.sorted_dict import SortedDict
-from .event import *
-from common.log import logger
 from config import conf
 
+from .event import *
+
 
 @singleton
 class PluginManager:
     def __init__(self):
-        self.plugins = SortedDict(lambda k,v: v.priority,reverse=True)
+        self.plugins = SortedDict(lambda k, v: v.priority, reverse=True)
         self.listening_plugins = {}
         self.instances = {}
         self.pconf = {}
@@ -26,17 +28,27 @@ class PluginManager:
         def wrapper(plugincls):
             plugincls.name = name
             plugincls.priority = desire_priority
-            plugincls.desc = kwargs.get('desc')
-            plugincls.author = kwargs.get('author')
+            plugincls.desc = kwargs.get("desc")
+            plugincls.author = kwargs.get("author")
             plugincls.path = self.current_plugin_path
-            plugincls.version = kwargs.get('version') if kwargs.get('version') != None else "1.0"
-            plugincls.namecn = kwargs.get('namecn') if kwargs.get('namecn') != None else name
-            plugincls.hidden = kwargs.get('hidden') if kwargs.get('hidden') != None else False
+            plugincls.version = (
+                kwargs.get("version") if kwargs.get("version") != None else "1.0"
+            )
+            plugincls.namecn = (
+                kwargs.get("namecn") if kwargs.get("namecn") != None else name
+            )
+            plugincls.hidden = (
+                kwargs.get("hidden") if kwargs.get("hidden") != None else False
+            )
             plugincls.enabled = True
             if self.current_plugin_path == None:
                 raise Exception("Plugin path not set")
             self.plugins[name.upper()] = plugincls
-            logger.info("Plugin %s_v%s registered, path=%s" % (name, plugincls.version, plugincls.path))
+            logger.info(
+                "Plugin %s_v%s registered, path=%s"
+                % (name, plugincls.version, plugincls.path)
+            )
+
         return wrapper
 
     def save_config(self):
@@ -50,10 +62,12 @@ class PluginManager:
         if os.path.exists("./plugins/plugins.json"):
             with open("./plugins/plugins.json", "r", encoding="utf-8") as f:
                 pconf = json.load(f)
-                pconf['plugins'] = SortedDict(lambda k,v: v["priority"],pconf['plugins'],reverse=True)
+                pconf["plugins"] = SortedDict(
+                    lambda k, v: v["priority"], pconf["plugins"], reverse=True
+                )
         else:
             modified = True
-            pconf = {"plugins": SortedDict(lambda k,v: v["priority"],reverse=True)}
+            pconf = {"plugins": SortedDict(lambda k, v: v["priority"], reverse=True)}
         self.pconf = pconf
         if modified:
             self.save_config()
@@ -67,7 +81,7 @@ class PluginManager:
             plugin_path = os.path.join(plugins_dir, plugin_name)
             if os.path.isdir(plugin_path):
                 # 判断插件是否包含同名__init__.py文件
-                main_module_path = os.path.join(plugin_path,"__init__.py")
+                main_module_path = os.path.join(plugin_path, "__init__.py")
                 if os.path.isfile(main_module_path):
                     # 导入插件
                     import_path = "plugins.{}".format(plugin_name)
@@ -76,16 +90,26 @@ class PluginManager:
                         if plugin_path in self.loaded:
                             if self.loaded[plugin_path] == None:
                                 logger.info("reload module %s" % plugin_name)
-                                self.loaded[plugin_path] = importlib.reload(sys.modules[import_path])
-                                dependent_module_names = [name for name in sys.modules.keys() if name.startswith( import_path+ '.')]
+                                self.loaded[plugin_path] = importlib.reload(
+                                    sys.modules[import_path]
+                                )
+                                dependent_module_names = [
+                                    name
+                                    for name in sys.modules.keys()
+                                    if name.startswith(import_path + ".")
+                                ]
                                 for name in dependent_module_names:
                                     logger.info("reload module %s" % name)
                                     importlib.reload(sys.modules[name])
                         else:
-                            self.loaded[plugin_path] = importlib.import_module(import_path)
+                            self.loaded[plugin_path] = importlib.import_module(
+                                import_path
+                            )
                         self.current_plugin_path = None
                     except Exception as e:
-                        logger.exception("Failed to import plugin %s: %s" % (plugin_name, e))
+                        logger.exception(
+                            "Failed to import plugin %s: %s" % (plugin_name, e)
+                        )
                         continue
         pconf = self.pconf
         news = [self.plugins[name] for name in self.plugins]
@@ -95,21 +119,28 @@ class PluginManager:
             rawname = plugincls.name
             if rawname not in pconf["plugins"]:
                 modified = True
-                logger.info("Plugin %s not found in pconfig, adding to pconfig..." % name)
-                pconf["plugins"][rawname] = {"enabled": plugincls.enabled, "priority": plugincls.priority}
+                logger.info(
+                    "Plugin %s not found in pconfig, adding to pconfig..." % name
+                )
+                pconf["plugins"][rawname] = {
+                    "enabled": plugincls.enabled,
+                    "priority": plugincls.priority,
+                }
             else:
                 self.plugins[name].enabled = pconf["plugins"][rawname]["enabled"]
                 self.plugins[name].priority = pconf["plugins"][rawname]["priority"]
-                self.plugins._update_heap(name) # 更新下plugins中的顺序
+                self.plugins._update_heap(name)  # 更新下plugins中的顺序
         if modified:
             self.save_config()
         return new_plugins
 
     def refresh_order(self):
         for event in self.listening_plugins.keys():
-            self.listening_plugins[event].sort(key=lambda name: self.plugins[name].priority, reverse=True)
+            self.listening_plugins[event].sort(
+                key=lambda name: self.plugins[name].priority, reverse=True
+            )
 
-    def activate_plugins(self): # 生成新开启的插件实例
+    def activate_plugins(self):  # 生成新开启的插件实例
         failed_plugins = []
         for name, plugincls in self.plugins.items():
             if plugincls.enabled:
@@ -129,7 +160,7 @@ class PluginManager:
         self.refresh_order()
         return failed_plugins
 
-    def reload_plugin(self, name:str):
+    def reload_plugin(self, name: str):
         name = name.upper()
         if name in self.instances:
             for event in self.listening_plugins:
@@ -139,13 +170,13 @@ class PluginManager:
             self.activate_plugins()
             return True
         return False
-    
+
     def load_plugins(self):
         self.load_config()
         self.scan_plugins()
         pconf = self.pconf
         logger.debug("plugins.json config={}".format(pconf))
-        for name,plugin in pconf["plugins"].items():
+        for name, plugin in pconf["plugins"].items():
             if name.upper() not in self.plugins:
                 logger.error("Plugin %s not found, but found in plugins.json" % name)
         self.activate_plugins()
@@ -153,13 +184,18 @@ class PluginManager:
     def emit_event(self, e_context: EventContext, *args, **kwargs):
         if e_context.event in self.listening_plugins:
             for name in self.listening_plugins[e_context.event]:
-                if self.plugins[name].enabled and e_context.action == EventAction.CONTINUE:
-                    logger.debug("Plugin %s triggered by event %s" % (name,e_context.event))
+                if (
+                    self.plugins[name].enabled
+                    and e_context.action == EventAction.CONTINUE
+                ):
+                    logger.debug(
+                        "Plugin %s triggered by event %s" % (name, e_context.event)
+                    )
                     instance = self.instances[name]
                     instance.handlers[e_context.event](e_context, *args, **kwargs)
         return e_context
 
-    def set_plugin_priority(self, name:str, priority:int):
+    def set_plugin_priority(self, name: str, priority: int):
         name = name.upper()
         if name not in self.plugins:
             return False
@@ -174,11 +210,11 @@ class PluginManager:
         self.refresh_order()
         return True
 
-    def enable_plugin(self, name:str):
+    def enable_plugin(self, name: str):
         name = name.upper()
         if name not in self.plugins:
             return False, "插件不存在"
-        if not self.plugins[name].enabled :
+        if not self.plugins[name].enabled:
             self.plugins[name].enabled = True
             rawname = self.plugins[name].name
             self.pconf["plugins"][rawname]["enabled"] = True
@@ -188,43 +224,47 @@ class PluginManager:
                 return False, "插件开启失败"
             return True, "插件已开启"
         return True, "插件已开启"
-    
-    def disable_plugin(self, name:str):
+
+    def disable_plugin(self, name: str):
         name = name.upper()
         if name not in self.plugins:
             return False
-        if self.plugins[name].enabled :
+        if self.plugins[name].enabled:
             self.plugins[name].enabled = False
             rawname = self.plugins[name].name
             self.pconf["plugins"][rawname]["enabled"] = False
             self.save_config()
             return True
         return True
-    
+
     def list_plugins(self):
         return self.plugins
-    
-    def install_plugin(self, repo:str):
+
+    def install_plugin(self, repo: str):
         try:
             import common.package_manager as pkgmgr
+
             pkgmgr.check_dulwich()
         except Exception as e:
             logger.error("Failed to install plugin, {}".format(e))
             return False, "无法导入dulwich,安装插件失败"
         import re
+
         from dulwich import porcelain
 
         logger.info("clone git repo: {}".format(repo))
-        
+
         match = re.match(r"^(https?:\/\/|git@)([^\/:]+)[\/:]([^\/:]+)\/(.+).git$", repo)
-        
+
         if not match:
             try:
-                with open("./plugins/source.json","r", encoding="utf-8") as f:
+                with open("./plugins/source.json", "r", encoding="utf-8") as f:
                     source = json.load(f)
                 if repo in source["repo"]:
                     repo = source["repo"][repo]["url"]
-                    match = re.match(r"^(https?:\/\/|git@)([^\/:]+)[\/:]([^\/:]+)\/(.+).git$", repo)
+                    match = re.match(
+                        r"^(https?:\/\/|git@)([^\/:]+)[\/:]([^\/:]+)\/(.+).git$", repo
+                    )
                     if not match:
                         return False, "安装插件失败,source中的仓库地址不合法"
                 else:
@@ -232,42 +272,53 @@ class PluginManager:
             except Exception as e:
                 logger.error("Failed to install plugin, {}".format(e))
                 return False, "安装插件失败,请检查仓库地址是否正确"
-        dirname = os.path.join("./plugins",match.group(4))
+        dirname = os.path.join("./plugins", match.group(4))
         try:
             repo = porcelain.clone(repo, dirname, checkout=True)
-            if os.path.exists(os.path.join(dirname,"requirements.txt")):
+            if os.path.exists(os.path.join(dirname, "requirements.txt")):
                 logger.info("detect requirements.txt,installing...")
-            pkgmgr.install_requirements(os.path.join(dirname,"requirements.txt"))
+            pkgmgr.install_requirements(os.path.join(dirname, "requirements.txt"))
             return True, "安装插件成功,请使用 #scanp 命令扫描插件或重启程序,开启前请检查插件是否需要配置"
         except Exception as e:
             logger.error("Failed to install plugin, {}".format(e))
-            return False, "安装插件失败,"+str(e)
-        
-    def update_plugin(self, name:str):
+            return False, "安装插件失败," + str(e)
+
+    def update_plugin(self, name: str):
         try:
             import common.package_manager as pkgmgr
+
             pkgmgr.check_dulwich()
         except Exception as e:
             logger.error("Failed to install plugin, {}".format(e))
             return False, "无法导入dulwich,更新插件失败"
         from dulwich import porcelain
+
         name = name.upper()
         if name not in self.plugins:
             return False, "插件不存在"
-        if name in ["HELLO","GODCMD","ROLE","TOOL","BDUNIT","BANWORDS","FINISH","DUNGEON"]:
+        if name in [
+            "HELLO",
+            "GODCMD",
+            "ROLE",
+            "TOOL",
+            "BDUNIT",
+            "BANWORDS",
+            "FINISH",
+            "DUNGEON",
+        ]:
             return False, "预置插件无法更新,请更新主程序仓库"
         dirname = self.plugins[name].path
         try:
             porcelain.pull(dirname, "origin")
-            if os.path.exists(os.path.join(dirname,"requirements.txt")):
+            if os.path.exists(os.path.join(dirname, "requirements.txt")):
                 logger.info("detect requirements.txt,installing...")
-            pkgmgr.install_requirements(os.path.join(dirname,"requirements.txt"))
+            pkgmgr.install_requirements(os.path.join(dirname, "requirements.txt"))
             return True, "更新插件成功,请重新运行程序"
         except Exception as e:
             logger.error("Failed to update plugin, {}".format(e))
-            return False, "更新插件失败,"+str(e)
-        
-    def uninstall_plugin(self, name:str):
+            return False, "更新插件失败," + str(e)
+
+    def uninstall_plugin(self, name: str):
         name = name.upper()
         if name not in self.plugins:
             return False, "插件不存在"
@@ -276,6 +327,7 @@ class PluginManager:
         dirname = self.plugins[name].path
         try:
             import shutil
+
             shutil.rmtree(dirname)
             rawname = self.plugins[name].name
             for event in self.listening_plugins:
@@ -288,4 +340,4 @@ class PluginManager:
             return True, "卸载插件成功"
         except Exception as e:
             logger.error("Failed to uninstall plugin, {}".format(e))
-            return False, "卸载插件失败,请手动删除文件夹完成卸载,"+str(e)
+            return False, "卸载插件失败,请手动删除文件夹完成卸载," + str(e)

+ 1 - 1
plugins/role/__init__.py

@@ -1 +1 @@
-from .role import *
+from .role import *

+ 66 - 35
plugins/role/role.py

@@ -2,17 +2,18 @@
 
 import json
 import os
+
+import plugins
 from bridge.bridge import Bridge
 from bridge.context import ContextType
 from bridge.reply import Reply, ReplyType
 from common import const
+from common.log import logger
 from config import conf
-import plugins
 from plugins import *
-from common.log import logger
 
 
-class RolePlay():
+class RolePlay:
     def __init__(self, bot, sessionid, desc, wrapper=None):
         self.bot = bot
         self.sessionid = sessionid
@@ -25,12 +26,20 @@ class RolePlay():
 
     def action(self, user_action):
         session = self.bot.sessions.build_session(self.sessionid)
-        if session.system_prompt != self.desc: # 目前没有触发session过期事件,这里先简单判断,然后重置
+        if session.system_prompt != self.desc:  # 目前没有触发session过期事件,这里先简单判断,然后重置
             session.set_system_prompt(self.desc)
         prompt = self.wrapper % user_action
         return prompt
 
-@plugins.register(name="Role", desire_priority=0, namecn="角色扮演", desc="为你的Bot设置预设角色", version="1.0", author="lanvent")
+
+@plugins.register(
+    name="Role",
+    desire_priority=0,
+    namecn="角色扮演",
+    desc="为你的Bot设置预设角色",
+    version="1.0",
+    author="lanvent",
+)
 class Role(Plugin):
     def __init__(self):
         super().__init__()
@@ -39,7 +48,7 @@ class Role(Plugin):
         try:
             with open(config_path, "r", encoding="utf-8") as f:
                 config = json.load(f)
-                self.tags = { tag:(desc,[]) for tag,desc in config["tags"].items()}
+                self.tags = {tag: (desc, []) for tag, desc in config["tags"].items()}
                 self.roles = {}
                 for role in config["roles"]:
                     self.roles[role["title"].lower()] = role
@@ -60,12 +69,16 @@ class Role(Plugin):
             logger.info("[Role] inited")
         except Exception as e:
             if isinstance(e, FileNotFoundError):
-                logger.warn(f"[Role] init failed, {config_path} not found, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/role .")
+                logger.warn(
+                    f"[Role] init failed, {config_path} not found, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/role ."
+                )
             else:
-                logger.warn("[Role] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/role .")
+                logger.warn(
+                    "[Role] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/role ."
+                )
             raise e
 
-    def get_role(self, name, find_closest=True, min_sim = 0.35):
+    def get_role(self, name, find_closest=True, min_sim=0.35):
         name = name.lower()
         found_role = None
         if name in self.roles:
@@ -75,6 +88,7 @@ class Role(Plugin):
 
             def str_simularity(a, b):
                 return difflib.SequenceMatcher(None, a, b).ratio()
+
             max_sim = min_sim
             max_role = None
             for role in self.roles:
@@ -86,25 +100,24 @@ class Role(Plugin):
         return found_role
 
     def on_handle_context(self, e_context: EventContext):
-
-        if e_context['context'].type != ContextType.TEXT:
+        if e_context["context"].type != ContextType.TEXT:
             return
         bottype = Bridge().get_bot_type("chat")
         if bottype not in (const.CHATGPT, const.OPEN_AI):
             return
         bot = Bridge().get_bot("chat")
-        content = e_context['context'].content[:]
-        clist = e_context['context'].content.split(maxsplit=1)
+        content = e_context["context"].content[:]
+        clist = e_context["context"].content.split(maxsplit=1)
         desckey = None
         customize = False
-        sessionid = e_context['context']['session_id']
-        trigger_prefix = conf().get('plugin_trigger_prefix', "$")
+        sessionid = e_context["context"]["session_id"]
+        trigger_prefix = conf().get("plugin_trigger_prefix", "$")
         if clist[0] == f"{trigger_prefix}停止扮演":
             if sessionid in self.roleplays:
                 self.roleplays[sessionid].reset()
                 del self.roleplays[sessionid]
             reply = Reply(ReplyType.INFO, "角色扮演结束!")
-            e_context['reply'] = reply
+            e_context["reply"] = reply
             e_context.action = EventAction.BREAK_PASS
             return
         elif clist[0] == f"{trigger_prefix}角色":
@@ -114,10 +127,10 @@ class Role(Plugin):
         elif clist[0] == f"{trigger_prefix}设定扮演":
             customize = True
         elif clist[0] == f"{trigger_prefix}角色类型":
-            if len(clist) >1:
+            if len(clist) > 1:
                 tag = clist[1].strip()
                 help_text = "角色列表:\n"
-                for key,value in self.tags.items():
+                for key, value in self.tags.items():
                     if value[0] == tag:
                         tag = key
                         break
@@ -130,57 +143,75 @@ class Role(Plugin):
                 else:
                     help_text = f"未知角色类型。\n"
                     help_text += "目前的角色类型有: \n"
-                    help_text += ",".join([self.tags[tag][0] for tag in self.tags])+"\n"
+                    help_text += (
+                        ",".join([self.tags[tag][0] for tag in self.tags]) + "\n"
+                    )
             else:
                 help_text = f"请输入角色类型。\n"
                 help_text += "目前的角色类型有: \n"
-                help_text += ",".join([self.tags[tag][0] for tag in self.tags])+"\n"
+                help_text += ",".join([self.tags[tag][0] for tag in self.tags]) + "\n"
             reply = Reply(ReplyType.INFO, help_text)
-            e_context['reply'] = reply
+            e_context["reply"] = reply
             e_context.action = EventAction.BREAK_PASS
             return
         elif sessionid not in self.roleplays:
             return
         logger.debug("[Role] on_handle_context. content: %s" % content)
         if desckey is not None:
-            if len(clist) == 1 or (len(clist) > 1 and clist[1].lower() in ["help", "帮助"]):
+            if len(clist) == 1 or (
+                len(clist) > 1 and clist[1].lower() in ["help", "帮助"]
+            ):
                 reply = Reply(ReplyType.INFO, self.get_help_text(verbose=True))
-                e_context['reply'] = reply
+                e_context["reply"] = reply
                 e_context.action = EventAction.BREAK_PASS
                 return
             role = self.get_role(clist[1])
             if role is None:
                 reply = Reply(ReplyType.ERROR, "角色不存在")
-                e_context['reply'] = reply
+                e_context["reply"] = reply
                 e_context.action = EventAction.BREAK_PASS
                 return
             else:
-                self.roleplays[sessionid] = RolePlay(bot, sessionid, self.roles[role][desckey], self.roles[role].get("wrapper","%s"))
-                reply = Reply(ReplyType.INFO, f"预设角色为 {role}:\n"+self.roles[role][desckey])
-                e_context['reply'] = reply
+                self.roleplays[sessionid] = RolePlay(
+                    bot,
+                    sessionid,
+                    self.roles[role][desckey],
+                    self.roles[role].get("wrapper", "%s"),
+                )
+                reply = Reply(
+                    ReplyType.INFO, f"预设角色为 {role}:\n" + self.roles[role][desckey]
+                )
+                e_context["reply"] = reply
                 e_context.action = EventAction.BREAK_PASS
         elif customize == True:
             self.roleplays[sessionid] = RolePlay(bot, sessionid, clist[1], "%s")
             reply = Reply(ReplyType.INFO, f"角色设定为:\n{clist[1]}")
-            e_context['reply'] = reply
+            e_context["reply"] = reply
             e_context.action = EventAction.BREAK_PASS
         else:
             prompt = self.roleplays[sessionid].action(content)
-            e_context['context'].type = ContextType.TEXT
-            e_context['context'].content = prompt
+            e_context["context"].type = ContextType.TEXT
+            e_context["context"].content = prompt
             e_context.action = EventAction.BREAK
 
     def get_help_text(self, verbose=False, **kwargs):
         help_text = "让机器人扮演不同的角色。\n"
         if not verbose:
             return help_text
-        trigger_prefix = conf().get('plugin_trigger_prefix', "$")
-        help_text = f"使用方法:\n{trigger_prefix}角色"+" 预设角色名: 设定角色为{预设角色名}。\n"+f"{trigger_prefix}role"+" 预设角色名: 同上,但使用英文设定。\n"
-        help_text += f"{trigger_prefix}设定扮演"+" 角色设定: 设定自定义角色人设为{角色设定}。\n"
+        trigger_prefix = conf().get("plugin_trigger_prefix", "$")
+        help_text = (
+            f"使用方法:\n{trigger_prefix}角色"
+            + " 预设角色名: 设定角色为{预设角色名}。\n"
+            + f"{trigger_prefix}role"
+            + " 预设角色名: 同上,但使用英文设定。\n"
+        )
+        help_text += f"{trigger_prefix}设定扮演" + " 角色设定: 设定自定义角色人设为{角色设定}。\n"
         help_text += f"{trigger_prefix}停止扮演: 清除设定的角色。\n"
-        help_text += f"{trigger_prefix}角色类型"+" 角色类型: 查看某类{角色类型}的所有预设角色,为所有时输出所有预设角色。\n"
+        help_text += (
+            f"{trigger_prefix}角色类型" + " 角色类型: 查看某类{角色类型}的所有预设角色,为所有时输出所有预设角色。\n"
+        )
         help_text += "\n目前的角色类型有: \n"
-        help_text += ",".join([self.tags[tag][0] for tag in self.tags])+"。\n"
+        help_text += ",".join([self.tags[tag][0] for tag in self.tags]) + "。\n"
         help_text += f"\n命令例子: \n{trigger_prefix}角色 写作助理\n"
         help_text += f"{trigger_prefix}角色类型 所有\n"
         help_text += f"{trigger_prefix}停止扮演\n"

+ 1 - 1
plugins/role/roles.json

@@ -428,4 +428,4 @@
       ]
     }
   ]
-}
+}

+ 14 - 14
plugins/source.json

@@ -1,16 +1,16 @@
 {
-    "repo": {
-        "sdwebui": {
-            "url": "https://github.com/lanvent/plugin_sdwebui.git",
-            "desc": "利用stable-diffusion画图的插件"
-        },
-        "replicate": {
-            "url": "https://github.com/lanvent/plugin_replicate.git",
-            "desc": "利用replicate api画图的插件"
-        },
-        "summary": {
-            "url": "https://github.com/lanvent/plugin_summary.git",
-            "desc": "总结聊天记录的插件"
-        }
+  "repo": {
+    "sdwebui": {
+      "url": "https://github.com/lanvent/plugin_sdwebui.git",
+      "desc": "利用stable-diffusion画图的插件"
+    },
+    "replicate": {
+      "url": "https://github.com/lanvent/plugin_replicate.git",
+      "desc": "利用replicate api画图的插件"
+    },
+    "summary": {
+      "url": "https://github.com/lanvent/plugin_summary.git",
+      "desc": "总结聊天记录的插件"
     }
-}
+  }
+}

+ 16 - 16
plugins/tool/README.md

@@ -1,14 +1,14 @@
 ## 插件描述
-一个能让chatgpt联网,搜索,数字运算的插件,将赋予强大且丰富的扩展能力   
+一个能让chatgpt联网,搜索,数字运算的插件,将赋予强大且丰富的扩展能力  
 使用该插件需在机器人回复你的前提下,在对话内容前加$tool;仅输入$tool将返回tool插件帮助信息,用于测试插件是否加载成功  
 ### 本插件所有工具同步存放至专用仓库:[chatgpt-tool-hub](https://github.com/goldfishh/chatgpt-tool-hub)
-  
-  
+
+
 ## 使用说明
-使用该插件后将默认使用4个工具, 无需额外配置长期生效: 
-### 1. python 
+使用该插件后将默认使用4个工具, 无需额外配置长期生效:
+### 1. python
 ###### python解释器,使用它来解释执行python指令,可以配合你想要chatgpt生成的代码输出结果或执行事务
-  
+
 ### 2. url-get
 ###### 往往用来获取某个网站具体内容,结果可能会被反爬策略影响
 
@@ -23,16 +23,16 @@
 
 > meteo调优记录:https://github.com/zhayujie/chatgpt-on-wechat/issues/776#issuecomment-1500771334
 
-## 使用本插件对话(prompt)技巧 
-### 1. 有指引的询问 
+## 使用本插件对话(prompt)技巧
+### 1. 有指引的询问
 #### 例如:
-- 总结这个链接的内容 https://github.com/goldfishh/chatgpt-tool-hub 
+- 总结这个链接的内容 https://github.com/goldfishh/chatgpt-tool-hub
 - 使用Terminal执行curl cip.cc
 - 使用python查询今天日期
-  
+
 ### 2. 使用搜索引擎工具
 - 如果有搜索工具就能让chatgpt获取到你的未传达清楚的上下文信息,比如chatgpt不知道你的地理位置,现在时间等,所以无法查询到天气
-  
+
 ## 其他工具
 
 ### 5. wikipedia
@@ -55,9 +55,9 @@
 ### 10. google-search *
 ###### google搜索引擎,申请流程较bing-search繁琐
 
-###### 注1:带*工具需要获取api-key才能使用,部分工具需要外网支持   
+###### 注1:带*工具需要获取api-key才能使用,部分工具需要外网支持  
 #### [申请方法](https://github.com/goldfishh/chatgpt-tool-hub/blob/master/docs/apply_optional_tool.md)
-  
+
 ## config.json 配置说明
 ###### 默认工具无需配置,其它工具需手动配置,一个例子:
 ```json
@@ -71,15 +71,15 @@
 }
 
 ```
-注:config.json文件非必须,未创建仍可使用本tool;带*工具需在kwargs填入对应api-key键值对    
+注:config.json文件非必须,未创建仍可使用本tool;带*工具需在kwargs填入对应api-key键值对  
 - `tools`:本插件初始化时加载的工具, 目前可选集:["wikipedia", "wolfram-alpha", "bing-search", "google-search", "news", "morning-news"] & 默认工具,除wikipedia工具之外均需要申请api-key
 - `kwargs`:工具执行时的配置,一般在这里存放**api-key**,或环境配置
   - `request_timeout`: 访问openai接口的超时时间,默认与wechat-on-chatgpt配置一致,可单独配置
   - `no_default`: 用于配置默认加载4个工具的行为,如果为true则仅使用tools列表工具,不加载默认工具
   - `top_k_results`: 控制所有有关搜索的工具返回条目数,数字越高则参考信息越多,但无用信息可能干扰判断,该值一般为2
   - `model_name`: 用于控制tool插件底层使用的llm模型,目前暂未测试3.5以外的模型,一般保持默认
-  
-  
+
+
 ## 备注
 - 强烈建议申请搜索工具搭配使用,推荐bing-search
 - 虽然我会有意加入一些限制,但请不要使用本插件做危害他人的事情,请提前了解清楚某些内容是否会违反相关规定,建议提前做好过滤

+ 1 - 1
plugins/tool/__init__.py

@@ -1 +1 @@
-from .tool import *
+from .tool import *

+ 10 - 5
plugins/tool/config.json.template

@@ -1,8 +1,13 @@
 {
-  "tools": ["python", "url-get", "terminal", "meteo-weather"],
+  "tools": [
+    "python",
+    "url-get",
+    "terminal",
+    "meteo-weather"
+  ],
   "kwargs": {
-      "top_k_results": 2,
-      "no_default": false,
-      "model_name": "gpt-3.5-turbo"
+    "top_k_results": 2,
+    "no_default": false,
+    "model_name": "gpt-3.5-turbo"
   }
-}
+}

+ 40 - 21
plugins/tool/tool.py

@@ -4,6 +4,7 @@ import os
 from chatgpt_tool_hub.apps import load_app
 from chatgpt_tool_hub.apps.app import App
 from chatgpt_tool_hub.tools.all_tool_list import get_all_tool_names
+
 import plugins
 from bridge.bridge import Bridge
 from bridge.context import ContextType
@@ -14,7 +15,13 @@ from config import conf
 from plugins import *
 
 
-@plugins.register(name="tool", desc="Arming your ChatGPT bot with various tools", version="0.3", author="goldfishh", desire_priority=0)
+@plugins.register(
+    name="tool",
+    desc="Arming your ChatGPT bot with various tools",
+    version="0.3",
+    author="goldfishh",
+    desire_priority=0,
+)
 class Tool(Plugin):
     def __init__(self):
         super().__init__()
@@ -28,22 +35,26 @@ class Tool(Plugin):
         help_text = "这是一个能让chatgpt联网,搜索,数字运算的插件,将赋予强大且丰富的扩展能力。"
         if not verbose:
             return help_text
-        trigger_prefix = conf().get('plugin_trigger_prefix', "$")
+        trigger_prefix = conf().get("plugin_trigger_prefix", "$")
         help_text += "使用说明:\n"
-        help_text += f"{trigger_prefix}tool "+"命令: 根据给出的{命令}使用一些可用工具尽力为你得到结果。\n"
+        help_text += f"{trigger_prefix}tool " + "命令: 根据给出的{命令}使用一些可用工具尽力为你得到结果。\n"
         help_text += f"{trigger_prefix}tool reset: 重置工具。\n"
         return help_text
 
     def on_handle_context(self, e_context: EventContext):
-        if e_context['context'].type != ContextType.TEXT:
+        if e_context["context"].type != ContextType.TEXT:
             return
 
         # 暂时不支持未来扩展的bot
-        if Bridge().get_bot_type("chat") not in (const.CHATGPT, const.OPEN_AI, const.CHATGPTONAZURE):
+        if Bridge().get_bot_type("chat") not in (
+            const.CHATGPT,
+            const.OPEN_AI,
+            const.CHATGPTONAZURE,
+        ):
             return
 
-        content = e_context['context'].content
-        content_list = e_context['context'].content.split(maxsplit=1)
+        content = e_context["context"].content
+        content_list = e_context["context"].content.split(maxsplit=1)
 
         if not content or len(content_list) < 1:
             e_context.action = EventAction.CONTINUE
@@ -52,13 +63,13 @@ class Tool(Plugin):
         logger.debug("[tool] on_handle_context. content: %s" % content)
         reply = Reply()
         reply.type = ReplyType.TEXT
-        trigger_prefix = conf().get('plugin_trigger_prefix', "$")
+        trigger_prefix = conf().get("plugin_trigger_prefix", "$")
         # todo: 有些工具必须要api-key,需要修改config文件,所以这里没有实现query增删tool的功能
         if content.startswith(f"{trigger_prefix}tool"):
             if len(content_list) == 1:
                 logger.debug("[tool]: get help")
                 reply.content = self.get_help_text()
-                e_context['reply'] = reply
+                e_context["reply"] = reply
                 e_context.action = EventAction.BREAK_PASS
                 return
             elif len(content_list) > 1:
@@ -66,12 +77,14 @@ class Tool(Plugin):
                     logger.debug("[tool]: reset config")
                     self.app = self._reset_app()
                     reply.content = "重置工具成功"
-                    e_context['reply'] = reply
+                    e_context["reply"] = reply
                     e_context.action = EventAction.BREAK_PASS
                     return
                 elif content_list[1].startswith("reset"):
                     logger.debug("[tool]: remind")
-                    e_context['context'].content = "请你随机用一种聊天风格,提醒用户:如果想重置tool插件,reset之后不要加任何字符"
+                    e_context[
+                        "context"
+                    ].content = "请你随机用一种聊天风格,提醒用户:如果想重置tool插件,reset之后不要加任何字符"
 
                     e_context.action = EventAction.BREAK
                     return
@@ -80,34 +93,35 @@ class Tool(Plugin):
 
                 # Don't modify bot name
                 all_sessions = Bridge().get_bot("chat").sessions
-                user_session = all_sessions.session_query(query, e_context['context']['session_id']).messages
+                user_session = all_sessions.session_query(
+                    query, e_context["context"]["session_id"]
+                ).messages
 
                 # chatgpt-tool-hub will reply you with many tools
                 logger.debug("[tool]: just-go")
                 try:
                     _reply = self.app.ask(query, user_session)
                     e_context.action = EventAction.BREAK_PASS
-                    all_sessions.session_reply(_reply, e_context['context']['session_id'])
+                    all_sessions.session_reply(
+                        _reply, e_context["context"]["session_id"]
+                    )
                 except Exception as e:
                     logger.exception(e)
                     logger.error(str(e))
 
-                    e_context['context'].content = "请你随机用一种聊天风格,提醒用户:这个问题tool插件暂时无法处理"
+                    e_context["context"].content = "请你随机用一种聊天风格,提醒用户:这个问题tool插件暂时无法处理"
                     reply.type = ReplyType.ERROR
                     e_context.action = EventAction.BREAK
                     return
 
                 reply.content = _reply
-                e_context['reply'] = reply
+                e_context["reply"] = reply
         return
 
     def _read_json(self) -> dict:
         curdir = os.path.dirname(__file__)
         config_path = os.path.join(curdir, "config.json")
-        tool_config = {
-            "tools": [],
-            "kwargs": {}
-        }
+        tool_config = {"tools": [], "kwargs": {}}
         if not os.path.exists(config_path):
             return tool_config
         else:
@@ -123,7 +137,9 @@ class Tool(Plugin):
             "proxy": conf().get("proxy", ""),
             "request_timeout": conf().get("request_timeout", 60),
             # note: 目前tool暂未对其他模型测试,但这里仍对配置来源做了优先级区分,一般插件配置可覆盖全局配置
-            "model_name": tool_model_name if tool_model_name else conf().get("model", "gpt-3.5-turbo"),
+            "model_name": tool_model_name
+            if tool_model_name
+            else conf().get("model", "gpt-3.5-turbo"),
             "no_default": kwargs.get("no_default", False),
             "top_k_results": kwargs.get("top_k_results", 2),
             # for news tool
@@ -160,4 +176,7 @@ class Tool(Plugin):
         # filter not support tool
         tool_list = self._filter_tool_list(tool_config.get("tools", []))
 
-        return load_app(tools_list=tool_list, **self._build_tool_kwargs(tool_config.get("kwargs", {})))
+        return load_app(
+            tools_list=tool_list,
+            **self._build_tool_kwargs(tool_config.get("kwargs", {})),
+        )

+ 1 - 0
requirements.txt

@@ -4,3 +4,4 @@ PyQRCode>=1.2.1
 qrcode>=7.4.2
 requests>=2.28.2
 chardet>=5.1.0
+pre-commit

+ 1 - 1
scripts/start.sh

@@ -8,7 +8,7 @@ echo $BASE_DIR
 # check the nohup.out log output file
 if [ ! -f "${BASE_DIR}/nohup.out" ]; then
   touch "${BASE_DIR}/nohup.out"
-echo "create file  ${BASE_DIR}/nohup.out"  
+echo "create file  ${BASE_DIR}/nohup.out"
 fi
 
 nohup python3 "${BASE_DIR}/app.py" & tail -f "${BASE_DIR}/nohup.out"

+ 1 - 1
scripts/tout.sh

@@ -7,7 +7,7 @@ echo $BASE_DIR
 
 # check the nohup.out log output file
 if [ ! -f "${BASE_DIR}/nohup.out" ]; then
-   echo "No file  ${BASE_DIR}/nohup.out"  
+   echo "No file  ${BASE_DIR}/nohup.out"
    exit -1;
 fi
 

+ 25 - 8
voice/audio_convert.py

@@ -1,9 +1,12 @@
 import shutil
 import wave
+
 import pysilk
 from pydub import AudioSegment
 
-sil_supports=[8000, 12000, 16000, 24000, 32000, 44100, 48000] # slk转wav时,支持的采样率
+sil_supports = [8000, 12000, 16000, 24000, 32000, 44100, 48000]  # slk转wav时,支持的采样率
+
+
 def find_closest_sil_supports(sample_rate):
     """
     找到最接近的支持的采样率
@@ -19,6 +22,7 @@ def find_closest_sil_supports(sample_rate):
             mindiff = diff
     return closest
 
+
 def get_pcm_from_wav(wav_path):
     """
     从 wav 文件中读取 pcm
@@ -29,31 +33,42 @@ def get_pcm_from_wav(wav_path):
     wav = wave.open(wav_path, "rb")
     return wav.readframes(wav.getnframes())
 
+
 def any_to_wav(any_path, wav_path):
     """
     把任意格式转成wav文件
     """
-    if any_path.endswith('.wav'):
+    if any_path.endswith(".wav"):
         shutil.copy2(any_path, wav_path)
         return
-    if any_path.endswith('.sil') or any_path.endswith('.silk') or any_path.endswith('.slk'):
+    if (
+        any_path.endswith(".sil")
+        or any_path.endswith(".silk")
+        or any_path.endswith(".slk")
+    ):
         return sil_to_wav(any_path, wav_path)
     audio = AudioSegment.from_file(any_path)
     audio.export(wav_path, format="wav")
 
+
 def any_to_sil(any_path, sil_path):
     """
     把任意格式转成sil文件
     """
-    if any_path.endswith('.sil') or any_path.endswith('.silk') or any_path.endswith('.slk'):
+    if (
+        any_path.endswith(".sil")
+        or any_path.endswith(".silk")
+        or any_path.endswith(".slk")
+    ):
         shutil.copy2(any_path, sil_path)
         return 10000
-    if any_path.endswith('.wav'):
+    if any_path.endswith(".wav"):
         return pcm_to_sil(any_path, sil_path)
-    if any_path.endswith('.mp3'):
+    if any_path.endswith(".mp3"):
         return mp3_to_sil(any_path, sil_path)
     raise NotImplementedError("Not support file type: {}".format(any_path))
 
+
 def mp3_to_wav(mp3_path, wav_path):
     """
     把mp3格式转成pcm文件
@@ -61,6 +76,7 @@ def mp3_to_wav(mp3_path, wav_path):
     audio = AudioSegment.from_mp3(mp3_path)
     audio.export(wav_path, format="wav")
 
+
 def pcm_to_sil(pcm_path, silk_path):
     """
     wav 文件转成 silk
@@ -72,12 +88,12 @@ def pcm_to_sil(pcm_path, silk_path):
     pcm_s16 = audio.set_sample_width(2)
     pcm_s16 = pcm_s16.set_frame_rate(rate)
     wav_data = pcm_s16.raw_data
-    silk_data = pysilk.encode(
-        wav_data, data_rate=rate, sample_rate=rate)
+    silk_data = pysilk.encode(wav_data, data_rate=rate, sample_rate=rate)
     with open(silk_path, "wb") as f:
         f.write(silk_data)
     return audio.duration_seconds * 1000
 
+
 def mp3_to_sil(mp3_path, silk_path):
     """
     mp3 文件转成 silk
@@ -95,6 +111,7 @@ def mp3_to_sil(mp3_path, silk_path):
         f.write(silk_data)
     return audio.duration_seconds * 1000
 
+
 def sil_to_wav(silk_path, wav_path, rate: int = 24000):
     """
     silk 文件转 wav

+ 37 - 17
voice/azure/azure_voice.py

@@ -1,16 +1,18 @@
-
 """
 azure voice service
 """
 import json
 import os
 import time
+
 import azure.cognitiveservices.speech as speechsdk
+
 from bridge.reply import Reply, ReplyType
 from common.log import logger
 from common.tmp_dir import TmpDir
-from voice.voice import Voice
 from config import conf
+from voice.voice import Voice
+
 """
 Azure voice
 主目录设置文件中需填写azure_voice_api_key和azure_voice_region
@@ -19,50 +21,68 @@ Azure voice
 
 """
 
-class AzureVoice(Voice):
 
+class AzureVoice(Voice):
     def __init__(self):
         try:
             curdir = os.path.dirname(__file__)
             config_path = os.path.join(curdir, "config.json")
             config = None
-            if not os.path.exists(config_path): #如果没有配置文件,创建本地配置文件
-                config = { "speech_synthesis_voice_name": "zh-CN-XiaoxiaoNeural", "speech_recognition_language": "zh-CN"}
+            if not os.path.exists(config_path):  # 如果没有配置文件,创建本地配置文件
+                config = {
+                    "speech_synthesis_voice_name": "zh-CN-XiaoxiaoNeural",
+                    "speech_recognition_language": "zh-CN",
+                }
                 with open(config_path, "w") as fw:
                     json.dump(config, fw, indent=4)
             else:
                 with open(config_path, "r") as fr:
                     config = json.load(fr)
-            self.api_key = conf().get('azure_voice_api_key')
-            self.api_region = conf().get('azure_voice_region')
-            self.speech_config = speechsdk.SpeechConfig(subscription=self.api_key, region=self.api_region)
-            self.speech_config.speech_synthesis_voice_name = config["speech_synthesis_voice_name"]
-            self.speech_config.speech_recognition_language = config["speech_recognition_language"]
+            self.api_key = conf().get("azure_voice_api_key")
+            self.api_region = conf().get("azure_voice_region")
+            self.speech_config = speechsdk.SpeechConfig(
+                subscription=self.api_key, region=self.api_region
+            )
+            self.speech_config.speech_synthesis_voice_name = config[
+                "speech_synthesis_voice_name"
+            ]
+            self.speech_config.speech_recognition_language = config[
+                "speech_recognition_language"
+            ]
         except Exception as e:
             logger.warn("AzureVoice init failed: %s, ignore " % e)
 
     def voiceToText(self, voice_file):
         audio_config = speechsdk.AudioConfig(filename=voice_file)
-        speech_recognizer = speechsdk.SpeechRecognizer(speech_config=self.speech_config, audio_config=audio_config)
+        speech_recognizer = speechsdk.SpeechRecognizer(
+            speech_config=self.speech_config, audio_config=audio_config
+        )
         result = speech_recognizer.recognize_once()
         if result.reason == speechsdk.ResultReason.RecognizedSpeech:
-            logger.info('[Azure] voiceToText voice file name={} text={}'.format(voice_file, result.text))
+            logger.info(
+                "[Azure] voiceToText voice file name={} text={}".format(
+                    voice_file, result.text
+                )
+            )
             reply = Reply(ReplyType.TEXT, result.text)
         else:
-            logger.error('[Azure] voiceToText error, result={}'.format(result))
+            logger.error("[Azure] voiceToText error, result={}".format(result))
             reply = Reply(ReplyType.ERROR, "抱歉,语音识别失败")
         return reply
 
     def textToVoice(self, text):
-        fileName = TmpDir().path() + 'reply-' + str(int(time.time())) + '.wav'
+        fileName = TmpDir().path() + "reply-" + str(int(time.time())) + ".wav"
         audio_config = speechsdk.AudioConfig(filename=fileName)
-        speech_synthesizer = speechsdk.SpeechSynthesizer(speech_config=self.speech_config, audio_config=audio_config)
+        speech_synthesizer = speechsdk.SpeechSynthesizer(
+            speech_config=self.speech_config, audio_config=audio_config
+        )
         result = speech_synthesizer.speak_text(text)
         if result.reason == speechsdk.ResultReason.SynthesizingAudioCompleted:
             logger.info(
-                '[Azure] textToVoice text={} voice file name={}'.format(text, fileName))
+                "[Azure] textToVoice text={} voice file name={}".format(text, fileName)
+            )
             reply = Reply(ReplyType.VOICE, fileName)
         else:
-            logger.error('[Azure] textToVoice error, result={}'.format(result))
+            logger.error("[Azure] textToVoice error, result={}".format(result))
             reply = Reply(ReplyType.ERROR, "抱歉,语音合成失败")
         return reply

+ 3 - 3
voice/azure/config.json.template

@@ -1,4 +1,4 @@
 {
-    "speech_synthesis_voice_name": "zh-CN-XiaoxiaoNeural",
-    "speech_recognition_language": "zh-CN"
-}
+  "speech_synthesis_voice_name": "zh-CN-XiaoxiaoNeural",
+  "speech_recognition_language": "zh-CN"
+}

+ 4 - 4
voice/baidu/README.md

@@ -29,7 +29,7 @@ dev_pid	    必填	语言选择,填写语言对应的dev_pid值
 
 2、对于def textToVoice(self, text)函数中调用的百度语音合成API,中接口调用synthesis(参数)在本目录下的`config.json`文件中进行配置。
 参数	    可需	描述
-tex	        必填	合成的文本,使用UTF-8编码,请注意文本长度必须小于1024字节   
+tex	        必填	合成的文本,使用UTF-8编码,请注意文本长度必须小于1024字节  
 lan	        必填	固定值zh。语言选择,目前只有中英文混合模式,填写固定值zh
 spd	        选填	语速,取值0-15,默认为5中语速
 pit	        选填	音调,取值0-15,默认为5中语调
@@ -40,14 +40,14 @@ aue	        选填	3为mp3格式(默认); 4为pcm-16k;5为pcm-8k;6为wav
 
 关于per参数的说明,注意您购买的哪个音库,就填写哪个音库的参数,否则会报错。如果您购买的是基础音库,那么per参数只能填写0到4,如果您购买的是精品音库,那么per参数只能填写5003,5118,106,110,111,103,5其他的都会报错。
 ### 配置文件
-  
+
 将文件夹中`config.json.template`复制为`config.json`。
 
 ``` json
     {
-    "lang": "zh", 
+    "lang": "zh",
     "ctp": 1,
-    "spd": 5, 
+    "spd": 5,
     "pit": 5,
     "vol": 5,
     "per": 0

+ 26 - 23
voice/baidu/baidu_voice.py

@@ -1,17 +1,19 @@
-
 """
 baidu voice service
 """
 import json
 import os
 import time
+
 from aip import AipSpeech
+
 from bridge.reply import Reply, ReplyType
 from common.log import logger
 from common.tmp_dir import TmpDir
-from voice.voice import Voice
-from voice.audio_convert import get_pcm_from_wav
 from config import conf
+from voice.audio_convert import get_pcm_from_wav
+from voice.voice import Voice
+
 """
     百度的语音识别API.
     dev_pid:
@@ -28,40 +30,37 @@ from config import conf
 
 
 class BaiduVoice(Voice):
-
     def __init__(self):
         try:
             curdir = os.path.dirname(__file__)
             config_path = os.path.join(curdir, "config.json")
             bconf = None
-            if not os.path.exists(config_path): #如果没有配置文件,创建本地配置文件
-                bconf = { "lang": "zh", "ctp": 1, "spd": 5,
-                         "pit": 5, "vol": 5, "per": 0}
+            if not os.path.exists(config_path):  # 如果没有配置文件,创建本地配置文件
+                bconf = {"lang": "zh", "ctp": 1, "spd": 5, "pit": 5, "vol": 5, "per": 0}
                 with open(config_path, "w") as fw:
                     json.dump(bconf, fw, indent=4)
             else:
                 with open(config_path, "r") as fr:
                     bconf = json.load(fr)
-                    
-            self.app_id = conf().get('baidu_app_id')
-            self.api_key = conf().get('baidu_api_key')
-            self.secret_key = conf().get('baidu_secret_key')
-            self.dev_id = conf().get('baidu_dev_pid')
+
+            self.app_id = conf().get("baidu_app_id")
+            self.api_key = conf().get("baidu_api_key")
+            self.secret_key = conf().get("baidu_secret_key")
+            self.dev_id = conf().get("baidu_dev_pid")
             self.lang = bconf["lang"]
             self.ctp = bconf["ctp"]
             self.spd = bconf["spd"]
             self.pit = bconf["pit"]
             self.vol = bconf["vol"]
             self.per = bconf["per"]
-            
+
             self.client = AipSpeech(self.app_id, self.api_key, self.secret_key)
         except Exception as e:
             logger.warn("BaiduVoice init failed: %s, ignore " % e)
 
-        
     def voiceToText(self, voice_file):
         # 识别本地文件
-        logger.debug('[Baidu] voice file name={}'.format(voice_file))
+        logger.debug("[Baidu] voice file name={}".format(voice_file))
         pcm = get_pcm_from_wav(voice_file)
         res = self.client.asr(pcm, "pcm", 16000, {"dev_pid": self.dev_id})
         if res["err_no"] == 0:
@@ -72,21 +71,25 @@ class BaiduVoice(Voice):
             logger.info("百度语音识别出错了: {}".format(res["err_msg"]))
             if res["err_msg"] == "request pv too much":
                 logger.info("  出现这个原因很可能是你的百度语音服务调用量超出限制,或未开通付费")
-            reply = Reply(ReplyType.ERROR,
-                          "百度语音识别出错了;{0}".format(res["err_msg"]))
+            reply = Reply(ReplyType.ERROR, "百度语音识别出错了;{0}".format(res["err_msg"]))
         return reply
 
     def textToVoice(self, text):
-        result = self.client.synthesis(text, self.lang, self.ctp, {
-            'spd': self.spd, 'pit': self.pit, 'vol': self.vol, 'per': self.per})
+        result = self.client.synthesis(
+            text,
+            self.lang,
+            self.ctp,
+            {"spd": self.spd, "pit": self.pit, "vol": self.vol, "per": self.per},
+        )
         if not isinstance(result, dict):
-            fileName = TmpDir().path() + 'reply-' + str(int(time.time())) + '.mp3'
-            with open(fileName, 'wb') as f:
+            fileName = TmpDir().path() + "reply-" + str(int(time.time())) + ".mp3"
+            with open(fileName, "wb") as f:
                 f.write(result)
             logger.info(
-                '[Baidu] textToVoice text={} voice file name={}'.format(text, fileName))
+                "[Baidu] textToVoice text={} voice file name={}".format(text, fileName)
+            )
             reply = Reply(ReplyType.VOICE, fileName)
         else:
-            logger.error('[Baidu] textToVoice error={}'.format(result))
+            logger.error("[Baidu] textToVoice error={}".format(result))
             reply = Reply(ReplyType.ERROR, "抱歉,语音合成失败")
         return reply

+ 8 - 8
voice/baidu/config.json.template

@@ -1,8 +1,8 @@
-    {
-    "lang": "zh", 
-    "ctp": 1,
-    "spd": 5, 
-    "pit": 5,
-    "vol": 5,
-    "per": 0
-    }
+{
+  "lang": "zh",
+  "ctp": 1,
+  "spd": 5,
+  "pit": 5,
+  "vol": 5,
+  "per": 0
+}

+ 13 - 7
voice/google/google_voice.py

@@ -1,11 +1,12 @@
-
 """
 google voice service
 """
 
 import time
+
 import speech_recognition
 from gtts import gTTS
+
 from bridge.reply import Reply, ReplyType
 from common.log import logger
 from common.tmp_dir import TmpDir
@@ -22,9 +23,12 @@ class GoogleVoice(Voice):
         with speech_recognition.AudioFile(voice_file) as source:
             audio = self.recognizer.record(source)
         try:
-            text = self.recognizer.recognize_google(audio, language='zh-CN')
+            text = self.recognizer.recognize_google(audio, language="zh-CN")
             logger.info(
-                '[Google] voiceToText text={} voice file name={}'.format(text, voice_file))
+                "[Google] voiceToText text={} voice file name={}".format(
+                    text, voice_file
+                )
+            )
             reply = Reply(ReplyType.TEXT, text)
         except speech_recognition.UnknownValueError:
             reply = Reply(ReplyType.ERROR, "抱歉,我听不懂")
@@ -32,13 +36,15 @@ class GoogleVoice(Voice):
             reply = Reply(ReplyType.ERROR, "抱歉,无法连接到 Google 语音识别服务;{0}".format(e))
         finally:
             return reply
+
     def textToVoice(self, text):
         try:
-            mp3File = TmpDir().path() + 'reply-' + str(int(time.time())) + '.mp3'
-            tts = gTTS(text=text, lang='zh')
-            tts.save(mp3File)            
+            mp3File = TmpDir().path() + "reply-" + str(int(time.time())) + ".mp3"
+            tts = gTTS(text=text, lang="zh")
+            tts.save(mp3File)
             logger.info(
-                '[Google] textToVoice text={} voice file name={}'.format(text, mp3File))
+                "[Google] textToVoice text={} voice file name={}".format(text, mp3File)
+            )
             reply = Reply(ReplyType.VOICE, mp3File)
         except Exception as e:
             reply = Reply(ReplyType.ERROR, str(e))

+ 9 - 6
voice/openai/openai_voice.py

@@ -1,29 +1,32 @@
-
 """
 google voice service
 """
 import json
+
 import openai
+
 from bridge.reply import Reply, ReplyType
-from config import conf
 from common.log import logger
+from config import conf
 from voice.voice import Voice
 
 
 class OpenaiVoice(Voice):
     def __init__(self):
-        openai.api_key = conf().get('open_ai_api_key')
+        openai.api_key = conf().get("open_ai_api_key")
 
     def voiceToText(self, voice_file):
-        logger.debug(
-            '[Openai] voice file name={}'.format(voice_file))
+        logger.debug("[Openai] voice file name={}".format(voice_file))
         try:
             file = open(voice_file, "rb")
             result = openai.Audio.transcribe("whisper-1", file)
             text = result["text"]
             reply = Reply(ReplyType.TEXT, text)
             logger.info(
-                '[Openai] voiceToText text={} voice file name={}'.format(text, voice_file))
+                "[Openai] voiceToText text={} voice file name={}".format(
+                    text, voice_file
+                )
+            )
         except Exception as e:
             reply = Reply(ReplyType.ERROR, str(e))
         finally:

+ 9 - 7
voice/pytts/pytts_voice.py

@@ -1,10 +1,11 @@
-
 """
 pytts voice service (offline)
 """
 
 import time
+
 import pyttsx3
+
 from bridge.reply import Reply, ReplyType
 from common.log import logger
 from common.tmp_dir import TmpDir
@@ -16,20 +17,21 @@ class PyttsVoice(Voice):
 
     def __init__(self):
         # 语速
-        self.engine.setProperty('rate', 125)
+        self.engine.setProperty("rate", 125)
         # 音量
-        self.engine.setProperty('volume', 1.0)
-        for voice in self.engine.getProperty('voices'):
+        self.engine.setProperty("volume", 1.0)
+        for voice in self.engine.getProperty("voices"):
             if "Chinese" in voice.name:
-                self.engine.setProperty('voice', voice.id)
+                self.engine.setProperty("voice", voice.id)
 
     def textToVoice(self, text):
         try:
-            wavFile = TmpDir().path() + 'reply-' + str(int(time.time())) + '.wav'
+            wavFile = TmpDir().path() + "reply-" + str(int(time.time())) + ".wav"
             self.engine.save_to_file(text, wavFile)
             self.engine.runAndWait()
             logger.info(
-                '[Pytts] textToVoice text={} voice file name={}'.format(text, wavFile))
+                "[Pytts] textToVoice text={} voice file name={}".format(text, wavFile)
+            )
             reply = Reply(ReplyType.VOICE, wavFile)
         except Exception as e:
             reply = Reply(ReplyType.ERROR, str(e))

+ 2 - 1
voice/voice.py

@@ -2,6 +2,7 @@
 Voice service abstract class
 """
 
+
 class Voice(object):
     def voiceToText(self, voice_file):
         """
@@ -13,4 +14,4 @@ class Voice(object):
         """
         Send text to voice service and get voice
         """
-        raise NotImplementedError
+        raise NotImplementedError

+ 11 - 5
voice/voice_factory.py

@@ -2,25 +2,31 @@
 voice factory
 """
 
+
 def create_voice(voice_type):
     """
     create a voice instance
     :param voice_type: voice type code
     :return: voice instance
     """
-    if voice_type == 'baidu':
+    if voice_type == "baidu":
         from voice.baidu.baidu_voice import BaiduVoice
+
         return BaiduVoice()
-    elif voice_type == 'google':
+    elif voice_type == "google":
         from voice.google.google_voice import GoogleVoice
+
         return GoogleVoice()
-    elif voice_type == 'openai':
+    elif voice_type == "openai":
         from voice.openai.openai_voice import OpenaiVoice
+
         return OpenaiVoice()
-    elif voice_type == 'pytts':
+    elif voice_type == "pytts":
         from voice.pytts.pytts_voice import PyttsVoice
+
         return PyttsVoice()
-    elif voice_type == 'azure':
+    elif voice_type == "azure":
         from voice.azure.azure_voice import AzureVoice
+
         return AzureVoice()
     raise RuntimeError