Browse Source

plugins: add sdwebui(stable diffusion) plugin

lanvent 3 years ago
parent
commit
e6d148e729

+ 0 - 0
plugins/sdwebui/__init__.py


+ 67 - 0
plugins/sdwebui/config.json.template

@@ -0,0 +1,67 @@
+{
+  "start":{
+    "host" : "127.0.0.1",
+    "port" : 7860
+  },
+  "defaults": {
+    "params": {
+      "sampler_name": "DPM++ 2M Karras",
+      "steps": 20,
+      "width": 512,
+      "height": 512,
+      "cfg_scale": 7,
+      "prompt":"masterpiece, best quality",
+      "negative_prompt": "(low quality, worst quality:1.4),(bad_prompt:0.8), (monochrome:1.1), (greyscale)",
+      "enable_hr": false,
+      "hr_scale": 2,
+      "hr_upscaler": "Latent",
+      "hr_second_pass_steps": 15,
+      "denoising_strength": 0.7
+    },
+    "options": {
+      "sd_model_checkpoint": "perfectWorld_v2Baked"
+    }
+  },
+  "rules": [
+    {
+      "keywords": [
+        "横版",
+        "壁纸"
+      ],
+      "params": {
+        "width": 640,
+        "height": 384
+      }
+    },
+    {
+      "keywords": [
+        "竖版"
+      ],
+      "params": {
+        "width": 384,
+        "height": 640
+      }
+    },
+    {
+      "keywords": [
+        "高清"
+      ],
+      "params": {
+        "enable_hr": true,
+        "hr_scale": 1.6
+      }
+    },
+    {
+      "keywords": [
+        "二次元"
+      ],
+      "params": {
+        "negative_prompt": "(low quality, worst quality:1.4),(bad_prompt:0.8), (monochrome:1.1), (greyscale)",
+        "prompt": "masterpiece, best quality"
+      },
+      "options": {
+        "sd_model_checkpoint": "meinamix_meinaV8"
+      }
+    }
+  ]
+}

+ 63 - 0
plugins/sdwebui/readme.md

@@ -0,0 +1,63 @@
+本插件用于将画图请求转发给stable diffusion webui
+使用前先安装stable diffusion webui,并在它的启动参数中添加 "--api"
+具体参考(https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/API)
+
+请**安装**本插件的依赖包```webuiapi```
+```
+    ```pip install webuiapi```
+```
+请将```config.json.template```复制为```config.json```,并修改其中的参数和规则
+
+用户的画图请求格式为:
+```
+    <画图触发词><关键词1> <关键词2> ... <关键词n>:<prompt> 
+```
+本插件会对画图触发词后的关键词进行逐个匹配,如果触发了规则中的关键词,则会在画图请求中重载对应的参数。
+规则会按顺序匹配,每个关键词最多匹配到1次,如果有重复的参数,则以最后一个为准:
+第一个"**:**"号之后的内容会作为附加的**prompt**,接在最终的prompt后
+
+例如: 画横版 高清 二次元:cat
+会触发三个关键词 "横版", "高清", "二次元",prompt为"cat"
+若默认参数是:
+```
+    "width": 512,
+    "height": 512,
+    "enable_hr": false,
+    "prompt": "8k"
+    "negative_prompt": "nsfw",
+    "sd_model_checkpoint": "perfectWorld_v2Baked"
+```
+
+"横版"触发的规则参数为:
+```
+    "width": 640,
+    "height": 384,
+```
+"高清"触发的规则参数为:
+```
+    "enable_hr": true,
+    "hr_scale": 1.6,
+```
+"二次元"触发的规则参数为:
+```
+    "negative_prompt": "(low quality, worst quality:1.4),(bad_prompt:0.8), (monochrome:1.1), (greyscale)",
+    "steps": 20,
+    "prompt": "masterpiece, best quality",
+
+    "sd_model_checkpoint": "meinamix_meinaV8"
+```
+最后将第一个":"后的内容cat连接在prompt后,得到最终参数为:
+```
+    "width": 640,
+    "height": 384,
+    "enable_hr": true,
+    "hr_scale": 1.6,
+    "negative_prompt": "(low quality, worst quality:1.4),(bad_prompt:0.8), (monochrome:1.1), (greyscale)",
+    "steps": 20,
+    "prompt": "masterpiece, best quality, cat",
+    
+    "sd_model_checkpoint": "meinamix_meinaV8"
+```
+PS: 参数分为两部分,
+一部分是params,为画画的参数;参数名**必须**与webuiapi包中[txt2img api](https://github.com/mix1009/sdwebuiapi/blob/fb2054e149c0a4e25125c0cd7e7dca06bda839d4/webuiapi/webuiapi.py#L163)的参数名一致
+另一部分是options,指sdwebui的设置,使用的模型和vae需要写在里面。它和http://127.0.0.1:7860/sdapi/v1/options所返回的键一致。

+ 88 - 0
plugins/sdwebui/sdwebui.py

@@ -0,0 +1,88 @@
+# encoding:utf-8
+
+import json
+import os
+from bridge.context import ContextType
+from bridge.reply import Reply, ReplyType
+import plugins
+from plugins import *
+from common.log import logger
+import webuiapi
+import io
+
+
+@plugins.register(name="sdwebui", desc="利用stable-diffusion webui来画图", version="2.0", author="lanvent")
+class SDWebUI(Plugin):
+    def __init__(self):
+        super().__init__()
+        curdir = os.path.dirname(__file__)
+        config_path = os.path.join(curdir, "config.json")
+        try:
+            with open(config_path, "r", encoding="utf-8") as f:
+                config = json.load(f)
+                self.rules = config["rules"]
+                defaults = config["defaults"]
+                self.default_params = defaults["params"]
+                self.default_options = defaults["options"]
+                self.start_args = config["start"]
+                self.api = webuiapi.WebUIApi(**self.start_args)
+            self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
+            logger.info("[SD] inited")
+        except FileNotFoundError:
+            logger.error(f"[SD] init failed, {config_path} not found")
+        except Exception as e:
+            logger.error("[SD] init failed, exception: %s" % e)
+    
+    def on_handle_context(self, e_context: EventContext):
+
+        if e_context['context'].type != ContextType.IMAGE_CREATE:
+            return
+
+        logger.debug("[SD] on_handle_context. content: %s" %e_context['context'].content)
+
+        logger.info("[SD] image_query={}".format(e_context['context'].content))
+        reply = Reply()
+        try:
+            content = e_context['context'].content[:]
+            # 解析用户输入 如"横版 高清 二次元:cat"
+            keywords, prompt = content.split(":", 1)
+            keywords = keywords.split()
+
+            rule_params = {}
+            rule_options = {}
+            for keyword in keywords:
+                matched = False
+                for rule in self.rules:
+                    if keyword in rule["keywords"]:
+                        for key in rule["params"]:
+                            rule_params[key] = rule["params"][key]
+                        if "options" in rule:
+                            for key in rule["options"]:
+                                rule_options[key] = rule["options"][key]
+                        matched = True
+                        break  # 一个关键词只匹配一个规则
+                if not matched:
+                    logger.warning("[SD] keyword not matched: %s" % keyword)
+            
+            params = {**self.default_params, **rule_params}
+            options = {**self.default_options, **rule_options}
+            params["prompt"] = params.get("prompt", "")+f", {prompt}"
+            if len(options) > 0:
+                logger.info("[SD] cover rule_options={}".format(rule_options))
+                self.api.set_options(options)
+            logger.info("[SD] params={}".format(params))
+            result = self.api.txt2img(
+                **params
+            )
+            reply.type = ReplyType.IMAGE
+            b_img = io.BytesIO()
+            result.image.save(b_img, format="PNG")
+            reply.content = b_img
+            e_context.action = EventAction.BREAK_PASS  # 事件结束后,不跳过处理context的默认逻辑
+        except Exception as e:
+            reply.type = ReplyType.ERROR
+            reply.content = "[SD] "+str(e)
+            logger.error("[SD] exception: %s" % e)
+            e_context.action = EventAction.CONTINUE  # 事件继续,交付给下个插件或默认逻辑
+        finally:
+            e_context['reply'] = reply