Procházet zdrojové kódy

sdwebui : add help reply

lanvent před 3 roky
rodič
revize
e6b65437e4
2 změnil soubory, kde provedl 64 přidání a 35 odebrání
  1. 6 3
      plugins/sdwebui/config.json.template
  2. 58 32
      plugins/sdwebui/sdwebui.py

+ 6 - 3
plugins/sdwebui/config.json.template

@@ -31,7 +31,8 @@
       "params": {
         "width": 640,
         "height": 384
-      }
+      },
+      "desc": "分辨率会变成640x384"
     },
     {
       "keywords": [
@@ -49,7 +50,8 @@
       "params": {
         "enable_hr": true,
         "hr_scale": 1.6
-      }
+      },
+      "desc": "出图分辨率长宽都会提高1.6倍"
     },
     {
       "keywords": [
@@ -61,7 +63,8 @@
       },
       "options": {
         "sd_model_checkpoint": "meinamix_meinaV8"
-      }
+      },
+      "desc": "使用二次元风格模型出图"
     }
   ]
 }

+ 58 - 32
plugins/sdwebui/sdwebui.py

@@ -4,6 +4,7 @@ import json
 import os
 from bridge.context import ContextType
 from bridge.reply import Reply, ReplyType
+from config import conf
 import plugins
 from plugins import *
 from common.log import logger
@@ -45,40 +46,49 @@ class SDWebUI(Plugin):
         try:
             content = e_context['context'].content[:]
             # 解析用户输入 如"横版 高清 二次元:cat"
-            keywords, prompt = content.split(":", 1)
+            if ":" in content:
+                keywords, prompt = content.split(":", 1)
+            else:
+                keywords = content
+                prompt = ""
+
             keywords = keywords.split()
 
-            rule_params = {}
-            rule_options = {}
-            for keyword in keywords:
-                matched = False
-                for rule in self.rules:
-                    if keyword in rule["keywords"]:
-                        for key in rule["params"]:
-                            rule_params[key] = rule["params"][key]
-                        if "options" in rule:
-                            for key in rule["options"]:
-                                rule_options[key] = rule["options"][key]
-                        matched = True
-                        break  # 一个关键词只匹配一个规则
-                if not matched:
-                    logger.warning("[SD] keyword not matched: %s" % keyword)
-            
-            params = {**self.default_params, **rule_params}
-            options = {**self.default_options, **rule_options}
-            params["prompt"] = params.get("prompt", "")+f", {prompt}"
-            if len(options) > 0:
-                logger.info("[SD] cover rule_options={}".format(rule_options))
-                self.api.set_options(options)
-            logger.info("[SD] params={}".format(params))
-            result = self.api.txt2img(
-                **params
-            )
-            reply.type = ReplyType.IMAGE
-            b_img = io.BytesIO()
-            result.image.save(b_img, format="PNG")
-            reply.content = b_img
-            e_context.action = EventAction.BREAK_PASS  # 事件结束后,不跳过处理context的默认逻辑
+            if "help" in keywords or "帮助" in keywords:
+                reply.type = ReplyType.INFO
+                reply.content = self.get_help_text()
+            else:
+                rule_params = {}
+                rule_options = {}
+                for keyword in keywords:
+                    matched = False
+                    for rule in self.rules:
+                        if keyword in rule["keywords"]:
+                            for key in rule["params"]:
+                                rule_params[key] = rule["params"][key]
+                            if "options" in rule:
+                                for key in rule["options"]:
+                                    rule_options[key] = rule["options"][key]
+                            matched = True
+                            break  # 一个关键词只匹配一个规则
+                    if not matched:
+                        logger.warning("[SD] keyword not matched: %s" % keyword)
+                
+                params = {**self.default_params, **rule_params}
+                options = {**self.default_options, **rule_options}
+                params["prompt"] = params.get("prompt", "")+f", {prompt}"
+                if len(options) > 0:
+                    logger.info("[SD] cover rule_options={}".format(rule_options))
+                    self.api.set_options(options)
+                logger.info("[SD] params={}".format(params))
+                result = self.api.txt2img(
+                    **params
+                )
+                reply.type = ReplyType.IMAGE
+                b_img = io.BytesIO()
+                result.image.save(b_img, format="PNG")
+                reply.content = b_img
+            e_context.action = EventAction.BREAK_PASS  # 事件结束后,跳过处理context的默认逻辑
         except Exception as e:
             reply.type = ReplyType.ERROR
             reply.content = "[SD] "+str(e)
@@ -86,3 +96,19 @@ class SDWebUI(Plugin):
             e_context.action = EventAction.CONTINUE  # 事件继续,交付给下个插件或默认逻辑
         finally:
             e_context['reply'] = reply
+
+    def get_help_text(self):
+        if not conf().get('image_create_prefix'):
+            return "画图功能未启用"
+        else:
+            trigger = conf()['image_create_prefix'][0]
+        help_text = f"请使用<{trigger}[关键词1] [关键词2]...:提示语>的格式作画,如\"{trigger}横版 高清:cat\"\n"
+        help_text += "目前可用关键词:\n"
+        for rule in self.rules:
+            keywords = [f"[{keyword}]" for keyword in rule['keywords']]
+            help_text += f"{','.join(keywords)}"
+            if "desc" in rule:
+                help_text += f"-{rule['desc']}\n"
+            else:
+                help_text += "\n"
+        return help_text