Explorar el Código

fix: midjourney check task thread

zhayujie hace 2 años
padre
commit
e027286b6d
Se han modificado 2 ficheros con 54 adiciones y 15 borrados
  1. 5 1
      plugins/linkai/linkai.py
  2. 49 14
      plugins/linkai/midjourney.py

+ 5 - 1
plugins/linkai/linkai.py

@@ -27,7 +27,8 @@ class LinkAI(Plugin):
         super().__init__()
         self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
         self.config = super().load_config()
-        self.mj_bot = MJBot(self.config.get("midjourney"))
+        if self.config:
+            self.mj_bot = MJBot(self.config.get("midjourney"))
         logger.info("[LinkAI] inited")
 
     def on_handle_context(self, e_context: EventContext):
@@ -35,6 +36,9 @@ class LinkAI(Plugin):
         消息处理逻辑
         :param e_context: 消息上下文
         """
+        if not self.config:
+            return
+
         context = e_context['context']
         if context.type not in [ContextType.TEXT, ContextType.IMAGE, ContextType.IMAGE_CREATE]:
             # filter content no need solve

+ 49 - 14
plugins/linkai/midjourney.py

@@ -58,23 +58,24 @@ class MJBot:
         self.temp_dict = {}
         self.tasks_lock = threading.Lock()
         self.event_loop = asyncio.new_event_loop()
-        threading.Thread(name="mj-check-thread", target=self._run_loop, args=(self.event_loop,)).start()
+        # threading.Thread(name="mj-check-thread", target=self._run_loop, args=(self.event_loop,)).start()
 
-    def judge_mj_task_type(self, e_context: EventContext) -> TaskType:
+    def judge_mj_task_type(self, e_context: EventContext):
         """
         判断MJ任务的类型
         :param e_context: 上下文
         :return: 任务类型枚举
         """
+        if not self.config or not self.config.get("enabled"):
+            return None
         trigger_prefix = conf().get("plugin_trigger_prefix", "$")
         context = e_context['context']
         if context.type == ContextType.TEXT:
-            if self.config and self.config.get("enabled"):
-                cmd_list = context.content.split(maxsplit=1)
-                if cmd_list[0].lower() == f"{trigger_prefix}mj":
-                    return TaskType.GENERATE
-                elif cmd_list[0].lower() == f"{trigger_prefix}mju":
-                    return TaskType.UPSCALE
+            cmd_list = context.content.split(maxsplit=1)
+            if cmd_list[0].lower() == f"{trigger_prefix}mj":
+                return TaskType.GENERATE
+            elif cmd_list[0].lower() == f"{trigger_prefix}mju":
+                return TaskType.UPSCALE
         elif context.type == ContextType.IMAGE_CREATE and self.config.get("use_image_create_prefix"):
             return TaskType.GENERATE
 
@@ -142,7 +143,7 @@ class MJBot:
         logger.info(f"[MJ] image generate, prompt={prompt}")
         mode = self._fetch_mode(prompt)
         body = {"prompt": prompt, "mode": mode, "auto_translate": self.config.get("auto_translate")}
-        res = requests.post(url=self.base_url + "/generate", json=body, headers=self.headers)
+        res = requests.post(url=self.base_url + "/generate", json=body, headers=self.headers, timeout=(5,15))
         if res.status_code == 200:
             res = res.json()
             logger.debug(f"[MJ] image generate, res={res}")
@@ -152,7 +153,7 @@ class MJBot:
                 if mode == TaskMode.RELAX.value:
                     time_str = "1~10分钟"
                 else:
-                    time_str = "1~2分钟"
+                    time_str = "1分钟"
                 content = f"🚀您的作品将在{time_str}左右完成,请耐心等待\n- - - - - - - - -\n"
                 if real_prompt:
                     content += f"初始prompt: {prompt}\n转换后prompt: {real_prompt}"
@@ -162,7 +163,8 @@ class MJBot:
                 task = MJTask(id=task_id, status=Status.PENDING, raw_prompt=prompt, user_id=user_id, task_type=TaskType.GENERATE)
                 # put to memory dict
                 self.tasks[task.id] = task
-                asyncio.run_coroutine_threadsafe(self.check_task(task, e_context), self.event_loop)
+                # asyncio.run_coroutine_threadsafe(self.check_task(task, e_context), self.event_loop)
+                self._do_check_task(task, e_context)
                 return reply
         else:
             res_json = res.json()
@@ -173,7 +175,7 @@ class MJBot:
     def upscale(self, user_id: str, img_id: str, index: int, e_context: EventContext) -> Reply:
         logger.info(f"[MJ] image upscale, img_id={img_id}, index={index}")
         body = {"type": TaskType.UPSCALE.name, "imgId": img_id, "index": index}
-        res = requests.post(url=self.base_url + "/operate", json=body, headers=self.headers)
+        res = requests.post(url=self.base_url + "/operate", json=body, headers=self.headers, timeout=(5,15))
         logger.debug(res)
         if res.status_code == 200:
             res = res.json()
@@ -187,7 +189,8 @@ class MJBot:
                 self.tasks[task.id] = task
                 key = f"{TaskType.UPSCALE.name}_{img_id}_{index}"
                 self.temp_dict[key] = True
-                asyncio.run_coroutine_threadsafe(self.check_task(task, e_context), self.event_loop)
+                # asyncio.run_coroutine_threadsafe(self.check_task(task, e_context), self.event_loop)
+                self._do_check_task(task, e_context)
                 return reply
         else:
             error_msg = ""
@@ -198,7 +201,36 @@ class MJBot:
             reply = Reply(ReplyType.ERROR, error_msg or "图片生成失败,请稍后再试")
             return reply
 
-    async def check_task(self, task: MJTask, e_context: EventContext):
+    def check_task_sync(self, task: MJTask, e_context: EventContext):
+        logger.debug(f"[MJ] start check task status, {task}")
+        max_retry_times = 90
+        while max_retry_times > 0:
+            time.sleep(10)
+            url = f"{self.base_url}/tasks/{task.id}"
+            try:
+                res = requests.get(url, headers=self.headers, timeout=5)
+                if res.status_code == 200:
+                    res_json = res.json()
+                    logger.debug(f"[MJ] task check res sync, task_id={task.id}, status={res.status_code}, "
+                                 f"data={res_json.get('data')}, thread={threading.current_thread().name}")
+                    if res_json.get("data") and res_json.get("data").get("status") == Status.FINISHED.name:
+                        # process success res
+                        if self.tasks.get(task.id):
+                            self.tasks[task.id].status = Status.FINISHED
+                        self._process_success_task(task, res_json.get("data"), e_context)
+                        return
+                else:
+                    res_json = res.json()
+                    logger.warn(f"[MJ] image check error, status_code={res.status_code}, res={res_json}")
+                    max_retry_times -= 20
+            except Exception as e:
+                max_retry_times -= 20
+                logger.warn(e)
+        logger.warn("[MJ] end from poll")
+        if self.tasks.get(task.id):
+            self.tasks[task.id].status = Status.EXPIRED
+
+    async def check_task_async(self, task: MJTask, e_context: EventContext):
         try:
             logger.debug(f"[MJ] start check task status, {task}")
             max_retry_times = 90
@@ -232,6 +264,9 @@ class MJBot:
         except Exception as e:
             logger.error(e)
 
+    def _do_check_task(self, task: MJTask, e_context: EventContext):
+        threading.Thread(target=self.check_task_sync, args=(task, e_context)).start()
+
     def _process_success_task(self, task: MJTask, res: dict, e_context: EventContext):
         """
         处理任务成功的结果