Преглед на файлове

feat: terminal support plugins

lanvent преди 3 години
родител
ревизия
484de6237b
променени са 5 файла, в които са добавени 73 реда и са изтрити 21 реда
  1. 5 1
      app.py
  2. 3 3
      channel/chat_channel.py
  3. 62 15
      channel/terminal/terminal_channel.py
  4. 1 1
      plugins/godcmd/godcmd.py
  5. 2 1
      plugins/source.json

+ 5 - 1
app.py

@@ -27,12 +27,16 @@ def run():
 
         # create channel
         channel_name=conf().get('channel_type', 'wx')
+
+        if "--cmd" in sys.argv:
+            channel_name = 'terminal'
+
         if channel_name == 'wxy':
             os.environ['WECHATY_LOG']="warn"
             # os.environ['WECHATY_PUPPET_SERVICE_ENDPOINT'] = '127.0.0.1:9001'
 
         channel = channel_factory.create_channel(channel_name)
-        if channel_name in ['wx','wxy','wechatmp']:
+        if channel_name in ['wx','wxy','wechatmp','terminal']:
             PluginManager().load_plugins()
 
         # startup channel

+ 3 - 3
channel/chat_channel.py

@@ -51,7 +51,7 @@ class ChatChannel(Channel):
             if cmsg.from_user_id == self.user_id and not config.get('trigger_by_self', True):
                 logger.debug("[WX]self message skipped")
                 return None
-            if context["isgroup"]:
+            if context.get("isgroup", False):
                 group_name = cmsg.other_user_nickname
                 group_id = cmsg.other_user_id
 
@@ -76,7 +76,7 @@ class ChatChannel(Channel):
                 logger.debug("[WX]reference query skipped")
                 return None
             
-            if context["isgroup"]: # 群聊
+            if context.get("isgroup", False): # 群聊
                 # 校验关键字
                 match_prefix = check_prefix(content, conf().get('group_chat_prefix'))
                 match_contain = check_contain(content, conf().get('group_chat_keyword'))
@@ -193,7 +193,7 @@ class ChatChannel(Channel):
                     if desire_rtype == ReplyType.VOICE and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE:
                         reply = super().build_text_to_voice(reply.content)
                         return self._decorate_reply(context, reply)
-                    if context['isgroup']:
+                    if context.get("isgroup", False):
                         reply_text = '@' +  context['msg'].actual_user_nickname + ' ' + reply_text.strip()
                         reply_text = conf().get("group_chat_reply_prefix", "") + reply_text
                     else:

+ 62 - 15
channel/terminal/terminal_channel.py

@@ -1,31 +1,78 @@
 from bridge.context import *
-from channel.channel import Channel
+from bridge.reply import Reply, ReplyType
+from channel.chat_channel import ChatChannel, check_prefix
+from channel.chat_message import ChatMessage
 import sys
 
-class TerminalChannel(Channel):
+from config import conf
+from common.log import logger
+
+class TerminalMessage(ChatMessage):
+    def __init__(self, msg_id, content, ctype = ContextType.TEXT,  from_user_id = "User", to_user_id = "Chatgpt", other_user_id = "Chatgpt"):
+        self.msg_id = msg_id
+        self.ctype = ctype
+        self.content = content
+        self.from_user_id = from_user_id
+        self.to_user_id = to_user_id
+        self.other_user_id = other_user_id
+
+class TerminalChannel(ChatChannel):
+    NOT_SUPPORT_REPLYTYPE = [ReplyType.VOICE]
+
+    def send(self, reply: Reply, context: Context):
+        print("\nBot:")
+        if reply.type == ReplyType.IMAGE:
+            from PIL import Image
+            image_storage = reply.content
+            image_storage.seek(0)
+            img = Image.open(image_storage)
+            print("<IMAGE>")
+            img.show()
+        elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
+            from PIL import Image
+            import requests,io
+            img_url = reply.content
+            pic_res = requests.get(img_url, stream=True)
+            image_storage = io.BytesIO()
+            for block in pic_res.iter_content(1024):
+                image_storage.write(block)
+            image_storage.seek(0)
+            img = Image.open(image_storage)
+            print(img_url)
+            img.show()
+        else:
+            print(reply.content)
+        print("\nUser:", end="")
+        sys.stdout.flush()
+        return
+
     def startup(self):
         context = Context()
-        print("\nPlease input your question")
+        logger.setLevel("WARN")
+        print("\nPlease input your question:\nUser:", end="")
+        sys.stdout.flush()
+        msg_id = 0
         while True:
             try:
-                prompt = self.get_input("User:\n")
+                prompt = self.get_input()
             except KeyboardInterrupt:
                 print("\nExiting...")
                 sys.exit()
+            msg_id += 1
+            trigger_prefixs = conf().get("single_chat_prefix",[""])
+            if check_prefix(prompt, trigger_prefixs) is None:
+                prompt = trigger_prefixs[0] + prompt # 给没触发的消息加上触发前缀
+                
+            context = self._compose_context(ContextType.TEXT, prompt, msg = TerminalMessage(msg_id, prompt))
+            if context:
+                self.produce(context)
+            else:
+                raise Exception("context is None")
 
-            context.type = ContextType.TEXT
-            context['session_id'] = "User"
-            context.content = prompt
-            print("Bot:")
-            sys.stdout.flush()
-            res = super().build_reply_content(prompt, context).content
-            print(res)
-
-
-    def get_input(self, prompt):
+    def get_input(self):
         """
         Multi-line input function
         """
-        print(prompt, end="")
+        sys.stdout.flush()
         line = input()
         return line

+ 1 - 1
plugins/godcmd/godcmd.py

@@ -194,7 +194,7 @@ class Godcmd(Plugin):
             channel = e_context['channel']
             user = e_context['context']['receiver']
             session_id = e_context['context']['session_id']
-            isgroup = e_context['context']['isgroup']
+            isgroup = e_context['context'].get("isgroup", False)
             bottype = Bridge().get_bot_type("chat")
             bot = Bridge().get_bot("chat")
             # 将命令和参数分割

+ 2 - 1
plugins/source.json

@@ -1,7 +1,8 @@
 {
     "repo": {
         "sdwebui": {
-            "url": "https://github.com/lanvent/plugin_sdwebui.git"
+            "url": "https://github.com/lanvent/plugin_sdwebui.git",
+            "desc": "利用stable-diffusion画图的插件"
         }
     }
 }