Răsfoiți Sursa

fix: some image bug

zhayujie 2 ani în urmă
părinte
comite
b22994c2d2
2 a modificat fișierele cu 29 adăugiri și 17 ștergeri
  1. 1 0
      plugins/linkai/config.json.template
  2. 28 17
      plugins/linkai/midjourney.py

+ 1 - 0
plugins/linkai/config.json.template

@@ -7,6 +7,7 @@
         "enabled": true,
         "mode": "relax",
         "auto_translate": true,
+        "img_proxy": true,
         "max_tasks": 3,
         "max_tasks_per_user": 1,
         "use_image_create_prefix": true

+ 28 - 17
plugins/linkai/midjourney.py

@@ -10,6 +10,7 @@ import asyncio
 from bridge.context import ContextType
 from plugins import EventContext, EventAction
 
+INVALID_REQUEST = 410
 
 class TaskType(Enum):
     GENERATE = "generate"
@@ -34,7 +35,8 @@ class TaskMode(Enum):
 
 
 class MJTask:
-    def __init__(self, id, user_id: str, task_type: TaskType, raw_prompt=None, expires: int=60*30, status=Status.PENDING):
+    def __init__(self, id, user_id: str, task_type: TaskType, raw_prompt=None, expires: int = 60 * 30,
+                 status=Status.PENDING):
         self.id = id
         self.user_id = user_id
         self.task_type = task_type
@@ -48,17 +50,18 @@ class MJTask:
     def __str__(self):
         return f"id={self.id}, user_id={self.user_id}, task_type={self.task_type}, status={self.status}, img_id={self.img_id}"
 
+
 # midjourney bot
 class MJBot:
     def __init__(self, config):
         self.base_url = "https://api.link-ai.chat/v1/img/midjourney"
+
         self.headers = {"Authorization": "Bearer " + conf().get("linkai_api_key")}
         self.config = config
         self.tasks = {}
         self.temp_dict = {}
         self.tasks_lock = threading.Lock()
         self.event_loop = asyncio.new_event_loop()
-        # threading.Thread(name="mj-check-thread", target=self._run_loop, args=(self.event_loop,)).start()
 
     def judge_mj_task_type(self, e_context: EventContext):
         """
@@ -79,7 +82,6 @@ class MJBot:
         elif context.type == ContextType.IMAGE_CREATE and self.config.get("use_image_create_prefix"):
             return TaskType.GENERATE
 
-
     def process_mj_task(self, mj_type: TaskType, e_context: EventContext):
         """
         处理mj任务
@@ -143,13 +145,15 @@ class MJBot:
         logger.info(f"[MJ] image generate, prompt={prompt}")
         mode = self._fetch_mode(prompt)
         body = {"prompt": prompt, "mode": mode, "auto_translate": self.config.get("auto_translate")}
-        res = requests.post(url=self.base_url + "/generate", json=body, headers=self.headers, timeout=(5,15))
+        if not self.config.get("img_proxy"):
+            body["img_proxy"] = False
+        res = requests.post(url=self.base_url + "/generate", json=body, headers=self.headers, timeout=(5, 40))
         if res.status_code == 200:
             res = res.json()
             logger.debug(f"[MJ] image generate, res={res}")
             if res.get("code") == 200:
-                task_id = res.get("data").get("taskId")
-                real_prompt = res.get("data").get("realPrompt")
+                task_id = res.get("data").get("task_id")
+                real_prompt = res.get("data").get("real_prompt")
                 if mode == TaskMode.RELAX.value:
                     time_str = "1~10分钟"
                 else:
@@ -160,7 +164,8 @@ class MJBot:
                 else:
                     content += f"prompt: {prompt}"
                 reply = Reply(ReplyType.INFO, content)
-                task = MJTask(id=task_id, status=Status.PENDING, raw_prompt=prompt, user_id=user_id, task_type=TaskType.GENERATE)
+                task = MJTask(id=task_id, status=Status.PENDING, raw_prompt=prompt, user_id=user_id,
+                              task_type=TaskType.GENERATE)
                 # put to memory dict
                 self.tasks[task.id] = task
                 # asyncio.run_coroutine_threadsafe(self.check_task(task, e_context), self.event_loop)
@@ -169,18 +174,23 @@ class MJBot:
         else:
             res_json = res.json()
             logger.error(f"[MJ] generate error, msg={res_json.get('message')}, status_code={res.status_code}")
-            reply = Reply(ReplyType.ERROR, "图片生成失败,请稍后再试")
+            if res.status_code == INVALID_REQUEST:
+                reply = Reply(ReplyType.ERROR, "图片生成失败,请检查提示词参数或内容")
+            else:
+                reply = Reply(ReplyType.ERROR, "图片生成失败,请稍后再试")
             return reply
 
     def upscale(self, user_id: str, img_id: str, index: int, e_context: EventContext) -> Reply:
         logger.info(f"[MJ] image upscale, img_id={img_id}, index={index}")
-        body = {"type": TaskType.UPSCALE.name, "imgId": img_id, "index": index}
-        res = requests.post(url=self.base_url + "/operate", json=body, headers=self.headers, timeout=(5,15))
+        body = {"type": TaskType.UPSCALE.name, "img_id": img_id, "index": index}
+        if not self.config.get("img_proxy"):
+            body["img_proxy"] = False
+        res = requests.post(url=self.base_url + "/operate", json=body, headers=self.headers, timeout=(5, 40))
         logger.debug(res)
         if res.status_code == 200:
             res = res.json()
             if res.get("code") == 200:
-                task_id = res.get("data").get("taskId")
+                task_id = res.get("data").get("task_id")
                 logger.info(f"[MJ] image upscale processing, task_id={task_id}")
                 content = f"🔎图片正在放大中,请耐心等待"
                 reply = Reply(ReplyType.INFO, content)
@@ -208,7 +218,7 @@ class MJBot:
             time.sleep(10)
             url = f"{self.base_url}/tasks/{task.id}"
             try:
-                res = requests.get(url, headers=self.headers, timeout=5)
+                res = requests.get(url, headers=self.headers, timeout=8)
                 if res.status_code == 200:
                     res_json = res.json()
                     logger.debug(f"[MJ] task check res sync, task_id={task.id}, status={res.status_code}, "
@@ -219,6 +229,7 @@ class MJBot:
                             self.tasks[task.id].status = Status.FINISHED
                         self._process_success_task(task, res_json.get("data"), e_context)
                         return
+                    max_retry_times -= 1
                 else:
                     res_json = res.json()
                     logger.warn(f"[MJ] image check error, status_code={res.status_code}, res={res_json}")
@@ -276,8 +287,8 @@ class MJBot:
         """
         # channel send img
         task.status = Status.FINISHED
-        task.img_id = res.get("imgId")
-        task.img_url = res.get("imgUrl")
+        task.img_id = res.get("img_id")
+        task.img_url = res.get("img_url")
         logger.info(f"[MJ] task success, task_id={task.id}, img_id={task.img_id}, img_url={task.img_url}")
 
         # send img
@@ -338,7 +349,7 @@ class MJBot:
         for id in self.tasks:
             logger.debug(f"[MJ] current task: {self.tasks[id]}")
 
-    def _set_reply_text(self, content: str, e_context: EventContext, level: ReplyType=ReplyType.ERROR):
+    def _set_reply_text(self, content: str, e_context: EventContext, level: ReplyType = ReplyType.ERROR):
         """
         设置回复文本
         :param content: 回复内容
@@ -358,7 +369,7 @@ class MJBot:
 
         return help_text
 
-    def find_tasks_by_user_id(self, user_id) -> list[MJTask]:
+    def find_tasks_by_user_id(self, user_id) -> list:
         result = []
         with self.tasks_lock:
             now = time.time()
@@ -377,4 +388,4 @@ def check_prefix(content, prefix_list):
     for prefix in prefix_list:
         if content.startswith(prefix):
             return prefix
-    return None
+    return None