|
@@ -58,23 +58,24 @@ class MJBot:
|
|
|
self.temp_dict = {}
|
|
self.temp_dict = {}
|
|
|
self.tasks_lock = threading.Lock()
|
|
self.tasks_lock = threading.Lock()
|
|
|
self.event_loop = asyncio.new_event_loop()
|
|
self.event_loop = asyncio.new_event_loop()
|
|
|
- threading.Thread(name="mj-check-thread", target=self._run_loop, args=(self.event_loop,)).start()
|
|
|
|
|
|
|
+ # threading.Thread(name="mj-check-thread", target=self._run_loop, args=(self.event_loop,)).start()
|
|
|
|
|
|
|
|
- def judge_mj_task_type(self, e_context: EventContext) -> TaskType:
|
|
|
|
|
|
|
+ def judge_mj_task_type(self, e_context: EventContext):
|
|
|
"""
|
|
"""
|
|
|
判断MJ任务的类型
|
|
判断MJ任务的类型
|
|
|
:param e_context: 上下文
|
|
:param e_context: 上下文
|
|
|
:return: 任务类型枚举
|
|
:return: 任务类型枚举
|
|
|
"""
|
|
"""
|
|
|
|
|
+ if not self.config or not self.config.get("enabled"):
|
|
|
|
|
+ return None
|
|
|
trigger_prefix = conf().get("plugin_trigger_prefix", "$")
|
|
trigger_prefix = conf().get("plugin_trigger_prefix", "$")
|
|
|
context = e_context['context']
|
|
context = e_context['context']
|
|
|
if context.type == ContextType.TEXT:
|
|
if context.type == ContextType.TEXT:
|
|
|
- if self.config and self.config.get("enabled"):
|
|
|
|
|
- cmd_list = context.content.split(maxsplit=1)
|
|
|
|
|
- if cmd_list[0].lower() == f"{trigger_prefix}mj":
|
|
|
|
|
- return TaskType.GENERATE
|
|
|
|
|
- elif cmd_list[0].lower() == f"{trigger_prefix}mju":
|
|
|
|
|
- return TaskType.UPSCALE
|
|
|
|
|
|
|
+ cmd_list = context.content.split(maxsplit=1)
|
|
|
|
|
+ if cmd_list[0].lower() == f"{trigger_prefix}mj":
|
|
|
|
|
+ return TaskType.GENERATE
|
|
|
|
|
+ elif cmd_list[0].lower() == f"{trigger_prefix}mju":
|
|
|
|
|
+ return TaskType.UPSCALE
|
|
|
elif context.type == ContextType.IMAGE_CREATE and self.config.get("use_image_create_prefix"):
|
|
elif context.type == ContextType.IMAGE_CREATE and self.config.get("use_image_create_prefix"):
|
|
|
return TaskType.GENERATE
|
|
return TaskType.GENERATE
|
|
|
|
|
|
|
@@ -142,7 +143,7 @@ class MJBot:
|
|
|
logger.info(f"[MJ] image generate, prompt={prompt}")
|
|
logger.info(f"[MJ] image generate, prompt={prompt}")
|
|
|
mode = self._fetch_mode(prompt)
|
|
mode = self._fetch_mode(prompt)
|
|
|
body = {"prompt": prompt, "mode": mode, "auto_translate": self.config.get("auto_translate")}
|
|
body = {"prompt": prompt, "mode": mode, "auto_translate": self.config.get("auto_translate")}
|
|
|
- res = requests.post(url=self.base_url + "/generate", json=body, headers=self.headers)
|
|
|
|
|
|
|
+ res = requests.post(url=self.base_url + "/generate", json=body, headers=self.headers, timeout=(5,15))
|
|
|
if res.status_code == 200:
|
|
if res.status_code == 200:
|
|
|
res = res.json()
|
|
res = res.json()
|
|
|
logger.debug(f"[MJ] image generate, res={res}")
|
|
logger.debug(f"[MJ] image generate, res={res}")
|
|
@@ -152,7 +153,7 @@ class MJBot:
|
|
|
if mode == TaskMode.RELAX.value:
|
|
if mode == TaskMode.RELAX.value:
|
|
|
time_str = "1~10分钟"
|
|
time_str = "1~10分钟"
|
|
|
else:
|
|
else:
|
|
|
- time_str = "1~2分钟"
|
|
|
|
|
|
|
+ time_str = "1分钟"
|
|
|
content = f"🚀您的作品将在{time_str}左右完成,请耐心等待\n- - - - - - - - -\n"
|
|
content = f"🚀您的作品将在{time_str}左右完成,请耐心等待\n- - - - - - - - -\n"
|
|
|
if real_prompt:
|
|
if real_prompt:
|
|
|
content += f"初始prompt: {prompt}\n转换后prompt: {real_prompt}"
|
|
content += f"初始prompt: {prompt}\n转换后prompt: {real_prompt}"
|
|
@@ -162,7 +163,8 @@ class MJBot:
|
|
|
task = MJTask(id=task_id, status=Status.PENDING, raw_prompt=prompt, user_id=user_id, task_type=TaskType.GENERATE)
|
|
task = MJTask(id=task_id, status=Status.PENDING, raw_prompt=prompt, user_id=user_id, task_type=TaskType.GENERATE)
|
|
|
# put to memory dict
|
|
# put to memory dict
|
|
|
self.tasks[task.id] = task
|
|
self.tasks[task.id] = task
|
|
|
- asyncio.run_coroutine_threadsafe(self.check_task(task, e_context), self.event_loop)
|
|
|
|
|
|
|
+ # asyncio.run_coroutine_threadsafe(self.check_task(task, e_context), self.event_loop)
|
|
|
|
|
+ self._do_check_task(task, e_context)
|
|
|
return reply
|
|
return reply
|
|
|
else:
|
|
else:
|
|
|
res_json = res.json()
|
|
res_json = res.json()
|
|
@@ -173,7 +175,7 @@ class MJBot:
|
|
|
def upscale(self, user_id: str, img_id: str, index: int, e_context: EventContext) -> Reply:
|
|
def upscale(self, user_id: str, img_id: str, index: int, e_context: EventContext) -> Reply:
|
|
|
logger.info(f"[MJ] image upscale, img_id={img_id}, index={index}")
|
|
logger.info(f"[MJ] image upscale, img_id={img_id}, index={index}")
|
|
|
body = {"type": TaskType.UPSCALE.name, "imgId": img_id, "index": index}
|
|
body = {"type": TaskType.UPSCALE.name, "imgId": img_id, "index": index}
|
|
|
- res = requests.post(url=self.base_url + "/operate", json=body, headers=self.headers)
|
|
|
|
|
|
|
+ res = requests.post(url=self.base_url + "/operate", json=body, headers=self.headers, timeout=(5,15))
|
|
|
logger.debug(res)
|
|
logger.debug(res)
|
|
|
if res.status_code == 200:
|
|
if res.status_code == 200:
|
|
|
res = res.json()
|
|
res = res.json()
|
|
@@ -187,7 +189,8 @@ class MJBot:
|
|
|
self.tasks[task.id] = task
|
|
self.tasks[task.id] = task
|
|
|
key = f"{TaskType.UPSCALE.name}_{img_id}_{index}"
|
|
key = f"{TaskType.UPSCALE.name}_{img_id}_{index}"
|
|
|
self.temp_dict[key] = True
|
|
self.temp_dict[key] = True
|
|
|
- asyncio.run_coroutine_threadsafe(self.check_task(task, e_context), self.event_loop)
|
|
|
|
|
|
|
+ # asyncio.run_coroutine_threadsafe(self.check_task(task, e_context), self.event_loop)
|
|
|
|
|
+ self._do_check_task(task, e_context)
|
|
|
return reply
|
|
return reply
|
|
|
else:
|
|
else:
|
|
|
error_msg = ""
|
|
error_msg = ""
|
|
@@ -198,7 +201,36 @@ class MJBot:
|
|
|
reply = Reply(ReplyType.ERROR, error_msg or "图片生成失败,请稍后再试")
|
|
reply = Reply(ReplyType.ERROR, error_msg or "图片生成失败,请稍后再试")
|
|
|
return reply
|
|
return reply
|
|
|
|
|
|
|
|
- async def check_task(self, task: MJTask, e_context: EventContext):
|
|
|
|
|
|
|
+ def check_task_sync(self, task: MJTask, e_context: EventContext):
|
|
|
|
|
+ logger.debug(f"[MJ] start check task status, {task}")
|
|
|
|
|
+ max_retry_times = 90
|
|
|
|
|
+ while max_retry_times > 0:
|
|
|
|
|
+ time.sleep(10)
|
|
|
|
|
+ url = f"{self.base_url}/tasks/{task.id}"
|
|
|
|
|
+ try:
|
|
|
|
|
+ res = requests.get(url, headers=self.headers, timeout=5)
|
|
|
|
|
+ if res.status_code == 200:
|
|
|
|
|
+ res_json = res.json()
|
|
|
|
|
+ logger.debug(f"[MJ] task check res sync, task_id={task.id}, status={res.status_code}, "
|
|
|
|
|
+ f"data={res_json.get('data')}, thread={threading.current_thread().name}")
|
|
|
|
|
+ if res_json.get("data") and res_json.get("data").get("status") == Status.FINISHED.name:
|
|
|
|
|
+ # process success res
|
|
|
|
|
+ if self.tasks.get(task.id):
|
|
|
|
|
+ self.tasks[task.id].status = Status.FINISHED
|
|
|
|
|
+ self._process_success_task(task, res_json.get("data"), e_context)
|
|
|
|
|
+ return
|
|
|
|
|
+ else:
|
|
|
|
|
+ res_json = res.json()
|
|
|
|
|
+ logger.warn(f"[MJ] image check error, status_code={res.status_code}, res={res_json}")
|
|
|
|
|
+ max_retry_times -= 20
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ max_retry_times -= 20
|
|
|
|
|
+ logger.warn(e)
|
|
|
|
|
+ logger.warn("[MJ] end from poll")
|
|
|
|
|
+ if self.tasks.get(task.id):
|
|
|
|
|
+ self.tasks[task.id].status = Status.EXPIRED
|
|
|
|
|
+
|
|
|
|
|
+ async def check_task_async(self, task: MJTask, e_context: EventContext):
|
|
|
try:
|
|
try:
|
|
|
logger.debug(f"[MJ] start check task status, {task}")
|
|
logger.debug(f"[MJ] start check task status, {task}")
|
|
|
max_retry_times = 90
|
|
max_retry_times = 90
|
|
@@ -232,6 +264,9 @@ class MJBot:
|
|
|
except Exception as e:
|
|
except Exception as e:
|
|
|
logger.error(e)
|
|
logger.error(e)
|
|
|
|
|
|
|
|
|
|
+ def _do_check_task(self, task: MJTask, e_context: EventContext):
|
|
|
|
|
+ threading.Thread(target=self.check_task_sync, args=(task, e_context)).start()
|
|
|
|
|
+
|
|
|
def _process_success_task(self, task: MJTask, res: dict, e_context: EventContext):
|
|
def _process_success_task(self, task: MJTask, res: dict, e_context: EventContext):
|
|
|
"""
|
|
"""
|
|
|
处理任务成功的结果
|
|
处理任务成功的结果
|