瀏覽代碼

Merge pull request #1358 from zhayujie/feat-1.3.5

feat: add midjourney variation and reset
zhayujie 2 年之前
父節點
當前提交
4da8714124
共有 2 個文件被更改,包括 64 次插入55 次删除
  1. 3 1
      plugins/linkai/linkai.py
  2. 61 54
      plugins/linkai/midjourney.py

+ 3 - 1
plugins/linkai/linkai.py

@@ -131,7 +131,9 @@ class LinkAI(Plugin):
         if not verbose:
             return help_text
         help_text += f'📖 知识库\n - 群聊中指定应用: {trigger_prefix}linkai app 应用编码\n\n例如: \n"$linkai app Kv2fXJcH"\n\n'
-        help_text += f"🎨 绘画\n - 生成: {trigger_prefix}mj 描述词1, 描述词2.. \n - 放大: {trigger_prefix}mju 图片ID 图片序号\n\n例如:\n\"{trigger_prefix}mj a little cat, white --ar 9:16\"\n\"{trigger_prefix}mju 1105592717188272288 2\""
+        help_text += f"🎨 绘画\n - 生成: {trigger_prefix}mj 描述词1, 描述词2.. \n - 放大: {trigger_prefix}mju 图片ID 图片序号\n - 变换: mjv 图片ID 图片序号\n - 重置: mjr 图片ID"
+        help_text += f"\n\n例如:\n\"{trigger_prefix}mj a little cat, white --ar 9:16\"\n\"{trigger_prefix}mju 11055927171882 2\""
+        help_text += f"\n\"{trigger_prefix}mjv 11055927171882 2\"\n\"{trigger_prefix}mjr 11055927171882\""
         return help_text
 
 

+ 61 - 54
plugins/linkai/midjourney.py

@@ -11,6 +11,9 @@ from bridge.context import ContextType
 from plugins import EventContext, EventAction
 
 INVALID_REQUEST = 410
+NOT_FOUND_ORIGIN_IMAGE = 461
+NOT_FOUND_TASK = 462
+
 
 class TaskType(Enum):
     GENERATE = "generate"
@@ -18,6 +21,9 @@ class TaskType(Enum):
     VARIATION = "variation"
     RESET = "reset"
 
+    def __str__(self):
+        return self.name
+
 
 class Status(Enum):
     PENDING = "pending"
@@ -34,6 +40,14 @@ class TaskMode(Enum):
     RELAX = "relax"
 
 
+task_name_mapping = {
+    TaskType.GENERATE.name: "生成",
+    TaskType.UPSCALE.name: "放大",
+    TaskType.VARIATION.name: "变换",
+    TaskType.RESET.name: "重新生成",
+}
+
+
 class MJTask:
     def __init__(self, id, user_id: str, task_type: TaskType, raw_prompt=None, expires: int = 60 * 30,
                  status=Status.PENDING):
@@ -79,6 +93,10 @@ class MJBot:
                 return TaskType.GENERATE
             elif cmd_list[0].lower() == f"{trigger_prefix}mju":
                 return TaskType.UPSCALE
+            elif cmd_list[0].lower() == f"{trigger_prefix}mjv":
+                return TaskType.VARIATION
+            elif cmd_list[0].lower() == f"{trigger_prefix}mjr":
+                return TaskType.RESET
         elif context.type == ContextType.IMAGE_CREATE and self.config.get("use_image_create_prefix"):
             return TaskType.GENERATE
 
@@ -127,8 +145,8 @@ class MJBot:
             e_context.action = EventAction.BREAK_PASS
             return
 
-        elif mj_type == TaskType.UPSCALE:
-            # 图片放大
+        elif mj_type == TaskType.UPSCALE or mj_type == TaskType.VARIATION:
+            # 图片放大/变换
             clist = cmd[1].split()
             if len(clist) < 2:
                 self._set_reply_text(f"{cmd[0]} 命令缺少参数", e_context)
@@ -138,16 +156,27 @@ class MJBot:
             if index < 1 or index > 4:
                 self._set_reply_text(f"图片序号 {index} 错误,应在 1 至 4 之间", e_context)
                 return
-            key = f"{TaskType.UPSCALE.name}_{img_id}_{index}"
+            key = f"{str(mj_type)}_{img_id}_{index}"
             if self.temp_dict.get(key):
-                self._set_reply_text(f"第 {index} 张图片已经放大过了", e_context)
+                self._set_reply_text(f"第 {index} 张图片已经{task_name_mapping.get(str(mj_type))}过了", e_context)
                 return
-            # 图片放大操作
-            reply = self.upscale(session_id, img_id, index, e_context)
+            # 执行图片放大/变换操作
+            reply = self.do_operate(mj_type, session_id, img_id, e_context, index)
             e_context['reply'] = reply
             e_context.action = EventAction.BREAK_PASS
             return
 
+        elif mj_type == TaskType.RESET:
+            # 图片重新生成
+            clist = cmd[1].split()
+            if len(clist) < 1:
+                self._set_reply_text(f"{cmd[0]} 命令缺少参数", e_context)
+                return
+            img_id = clist[0]
+            # 图片重新生成
+            reply = self.do_operate(mj_type, session_id, img_id, e_context)
+            e_context['reply'] = reply
+            e_context.action = EventAction.BREAK_PASS
         else:
             self._set_reply_text(f"暂不支持该命令", e_context)
 
@@ -197,9 +226,12 @@ class MJBot:
                 reply = Reply(ReplyType.ERROR, "图片生成失败,请稍后再试")
             return reply
 
-    def upscale(self, user_id: str, img_id: str, index: int, e_context: EventContext) -> Reply:
-        logger.info(f"[MJ] image upscale, img_id={img_id}, index={index}")
-        body = {"type": TaskType.UPSCALE.name, "img_id": img_id, "index": index}
+    def do_operate(self, task_type: TaskType, user_id: str, img_id: str, e_context: EventContext,
+                   index: int = None) -> Reply:
+        logger.info(f"[MJ] image operate, task_type={task_type}, img_id={img_id}, index={index}")
+        body = {"type": task_type.name, "img_id": img_id}
+        if index:
+            body["index"] = index
         if not self.config.get("img_proxy"):
             body["img_proxy"] = False
         res = requests.post(url=self.base_url + "/operate", json=body, headers=self.headers, timeout=(5, 40))
@@ -208,23 +240,24 @@ class MJBot:
             res = res.json()
             if res.get("code") == 200:
                 task_id = res.get("data").get("task_id")
-                logger.info(f"[MJ] image upscale processing, task_id={task_id}")
-                content = f"🔎图片正在放大中,请耐心等待"
+                logger.info(f"[MJ] image operate processing, task_id={task_id}")
+                icon_map = {TaskType.UPSCALE: "🔎", TaskType.VARIATION: "🪄", TaskType.RESET: "🔄"}
+                content = f"{icon_map.get(task_type)}图片正在{task_name_mapping.get(task_type.name)}中,请耐心等待"
                 reply = Reply(ReplyType.INFO, content)
-                task = MJTask(id=task_id, status=Status.PENDING, user_id=user_id, task_type=TaskType.UPSCALE)
+                task = MJTask(id=task_id, status=Status.PENDING, user_id=user_id, task_type=task_type)
                 # put to memory dict
                 self.tasks[task.id] = task
-                key = f"{TaskType.UPSCALE.name}_{img_id}_{index}"
+                key = f"{task_type.name}_{img_id}_{index}"
                 self.temp_dict[key] = True
                 # asyncio.run_coroutine_threadsafe(self.check_task(task, e_context), self.event_loop)
                 self._do_check_task(task, e_context)
                 return reply
         else:
             error_msg = ""
-            if res.status_code == 461:
+            if res.status_code == NOT_FOUND_ORIGIN_IMAGE:
                 error_msg = "请输入正确的图片ID"
             res_json = res.json()
-            logger.error(f"[MJ] upscale error, msg={res_json.get('message')}, status_code={res.status_code}")
+            logger.error(f"[MJ] operate error, msg={res_json.get('message')}, status_code={res.status_code}")
             reply = Reply(ReplyType.ERROR, error_msg or "图片生成失败,请稍后再试")
             return reply
 
@@ -258,40 +291,6 @@ class MJBot:
         if self.tasks.get(task.id):
             self.tasks[task.id].status = Status.EXPIRED
 
-    async def check_task_async(self, task: MJTask, e_context: EventContext):
-        try:
-            logger.debug(f"[MJ] start check task status, {task}")
-            max_retry_times = 90
-            while max_retry_times > 0:
-                await asyncio.sleep(10)
-                async with aiohttp.ClientSession() as session:
-                    url = f"{self.base_url}/tasks/{task.id}"
-                    try:
-                        async with session.get(url, headers=self.headers) as res:
-                            if res.status == 200:
-                                res_json = await res.json()
-                                logger.debug(f"[MJ] task check res, task_id={task.id}, status={res.status}, "
-                                             f"data={res_json.get('data')}, thread={threading.current_thread().name}")
-                                if res_json.get("data") and res_json.get("data").get("status") == Status.FINISHED.name:
-                                    # process success res
-                                    if self.tasks.get(task.id):
-                                        self.tasks[task.id].status = Status.FINISHED
-                                    self._process_success_task(task, res_json.get("data"), e_context)
-                                    return
-                            else:
-                                res_json = await res.json()
-                                logger.warn(f"[MJ] image check error, status_code={res.status}, res={res_json}")
-                                max_retry_times -= 20
-                    except Exception as e:
-                        max_retry_times -= 20
-                        logger.warn(e)
-                max_retry_times -= 1
-            logger.warn("[MJ] end from poll")
-            if self.tasks.get(task.id):
-                self.tasks[task.id].status = Status.EXPIRED
-        except Exception as e:
-            logger.error(e)
-
     def _do_check_task(self, task: MJTask, e_context: EventContext):
         threading.Thread(target=self.check_task_sync, args=(task, e_context)).start()
 
@@ -316,10 +315,17 @@ class MJBot:
         # send info
         trigger_prefix = conf().get("plugin_trigger_prefix", "$")
         text = ""
-        if task.task_type == TaskType.GENERATE:
-            text = f"🎨绘画完成!\nprompt: {task.raw_prompt}\n- - - - - - - - -\n图片ID: {task.img_id}"
-            text += f"\n\n🔎可使用 {trigger_prefix}mju 命令放大指定图片\n"
+        if task.task_type == TaskType.GENERATE or task.task_type == TaskType.VARIATION or task.task_type == TaskType.RESET:
+            text = f"🎨绘画完成!\n"
+            if task.raw_prompt:
+                text += f"prompt: {task.raw_prompt}\n"
+            text += f"- - - - - - - - -\n图片ID: {task.img_id}"
+            text += f"\n\n🔎使用 {trigger_prefix}mju 命令放大图片\n"
             text += f"例如:\n{trigger_prefix}mju {task.img_id} 1"
+            text += f"\n\n🪄使用 {trigger_prefix}mjv 命令变换图片\n"
+            text += f"例如:\n{trigger_prefix}mjv {task.img_id} 1"
+            text += f"\n\n🔄使用 {trigger_prefix}mjr 命令重新生成图片\n"
+            text += f"例如:\n{trigger_prefix}mjr {task.img_id}"
             reply = Reply(ReplyType.INFO, text)
             channel._send(reply, e_context["context"])
 
@@ -382,8 +388,9 @@ class MJBot:
         help_text = "🎨利用Midjourney进行画图\n\n"
         if not verbose:
             return help_text
-        help_text += f" - 生成: {trigger_prefix}mj 描述词1, 描述词2.. \n - 放大: {trigger_prefix}mju 图片ID 图片序号\n\n例如:\n\"{trigger_prefix}mj a little cat, white --ar 9:16\"\n\"{trigger_prefix}mju 1105592717188272288 2\""
-
+        help_text += f" - 生成: {trigger_prefix}mj 描述词1, 描述词2.. \n - 放大: {trigger_prefix}mju 图片ID 图片序号\n - 变换: mjv 图片ID 图片序号\n - 重置: mjr 图片ID"
+        help_text += f"\n\n例如:\n\"{trigger_prefix}mj a little cat, white --ar 9:16\"\n\"{trigger_prefix}mju 11055927171882 2\""
+        help_text += f"\n\"{trigger_prefix}mjv 11055927171882 2\"\n\"{trigger_prefix}mjr 11055927171882\""
         return help_text
 
     def find_tasks_by_user_id(self, user_id) -> list: