|
@@ -10,6 +10,7 @@ import asyncio
|
|
|
from bridge.context import ContextType
|
|
from bridge.context import ContextType
|
|
|
from plugins import EventContext, EventAction
|
|
from plugins import EventContext, EventAction
|
|
|
|
|
|
|
|
|
|
+INVALID_REQUEST = 410
|
|
|
|
|
|
|
|
class TaskType(Enum):
|
|
class TaskType(Enum):
|
|
|
GENERATE = "generate"
|
|
GENERATE = "generate"
|
|
@@ -34,7 +35,8 @@ class TaskMode(Enum):
|
|
|
|
|
|
|
|
|
|
|
|
|
class MJTask:
|
|
class MJTask:
|
|
|
- def __init__(self, id, user_id: str, task_type: TaskType, raw_prompt=None, expires: int=60*30, status=Status.PENDING):
|
|
|
|
|
|
|
+ def __init__(self, id, user_id: str, task_type: TaskType, raw_prompt=None, expires: int = 60 * 30,
|
|
|
|
|
+ status=Status.PENDING):
|
|
|
self.id = id
|
|
self.id = id
|
|
|
self.user_id = user_id
|
|
self.user_id = user_id
|
|
|
self.task_type = task_type
|
|
self.task_type = task_type
|
|
@@ -48,17 +50,18 @@ class MJTask:
|
|
|
def __str__(self):
|
|
def __str__(self):
|
|
|
return f"id={self.id}, user_id={self.user_id}, task_type={self.task_type}, status={self.status}, img_id={self.img_id}"
|
|
return f"id={self.id}, user_id={self.user_id}, task_type={self.task_type}, status={self.status}, img_id={self.img_id}"
|
|
|
|
|
|
|
|
|
|
+
|
|
|
# midjourney bot
|
|
# midjourney bot
|
|
|
class MJBot:
|
|
class MJBot:
|
|
|
def __init__(self, config):
|
|
def __init__(self, config):
|
|
|
self.base_url = "https://api.link-ai.chat/v1/img/midjourney"
|
|
self.base_url = "https://api.link-ai.chat/v1/img/midjourney"
|
|
|
|
|
+
|
|
|
self.headers = {"Authorization": "Bearer " + conf().get("linkai_api_key")}
|
|
self.headers = {"Authorization": "Bearer " + conf().get("linkai_api_key")}
|
|
|
self.config = config
|
|
self.config = config
|
|
|
self.tasks = {}
|
|
self.tasks = {}
|
|
|
self.temp_dict = {}
|
|
self.temp_dict = {}
|
|
|
self.tasks_lock = threading.Lock()
|
|
self.tasks_lock = threading.Lock()
|
|
|
self.event_loop = asyncio.new_event_loop()
|
|
self.event_loop = asyncio.new_event_loop()
|
|
|
- # threading.Thread(name="mj-check-thread", target=self._run_loop, args=(self.event_loop,)).start()
|
|
|
|
|
|
|
|
|
|
def judge_mj_task_type(self, e_context: EventContext):
|
|
def judge_mj_task_type(self, e_context: EventContext):
|
|
|
"""
|
|
"""
|
|
@@ -79,7 +82,6 @@ class MJBot:
|
|
|
elif context.type == ContextType.IMAGE_CREATE and self.config.get("use_image_create_prefix"):
|
|
elif context.type == ContextType.IMAGE_CREATE and self.config.get("use_image_create_prefix"):
|
|
|
return TaskType.GENERATE
|
|
return TaskType.GENERATE
|
|
|
|
|
|
|
|
-
|
|
|
|
|
def process_mj_task(self, mj_type: TaskType, e_context: EventContext):
|
|
def process_mj_task(self, mj_type: TaskType, e_context: EventContext):
|
|
|
"""
|
|
"""
|
|
|
处理mj任务
|
|
处理mj任务
|
|
@@ -143,13 +145,15 @@ class MJBot:
|
|
|
logger.info(f"[MJ] image generate, prompt={prompt}")
|
|
logger.info(f"[MJ] image generate, prompt={prompt}")
|
|
|
mode = self._fetch_mode(prompt)
|
|
mode = self._fetch_mode(prompt)
|
|
|
body = {"prompt": prompt, "mode": mode, "auto_translate": self.config.get("auto_translate")}
|
|
body = {"prompt": prompt, "mode": mode, "auto_translate": self.config.get("auto_translate")}
|
|
|
- res = requests.post(url=self.base_url + "/generate", json=body, headers=self.headers, timeout=(5,15))
|
|
|
|
|
|
|
+ if not self.config.get("img_proxy"):
|
|
|
|
|
+ body["img_proxy"] = False
|
|
|
|
|
+ res = requests.post(url=self.base_url + "/generate", json=body, headers=self.headers, timeout=(5, 40))
|
|
|
if res.status_code == 200:
|
|
if res.status_code == 200:
|
|
|
res = res.json()
|
|
res = res.json()
|
|
|
logger.debug(f"[MJ] image generate, res={res}")
|
|
logger.debug(f"[MJ] image generate, res={res}")
|
|
|
if res.get("code") == 200:
|
|
if res.get("code") == 200:
|
|
|
- task_id = res.get("data").get("taskId")
|
|
|
|
|
- real_prompt = res.get("data").get("realPrompt")
|
|
|
|
|
|
|
+ task_id = res.get("data").get("task_id")
|
|
|
|
|
+ real_prompt = res.get("data").get("real_prompt")
|
|
|
if mode == TaskMode.RELAX.value:
|
|
if mode == TaskMode.RELAX.value:
|
|
|
time_str = "1~10分钟"
|
|
time_str = "1~10分钟"
|
|
|
else:
|
|
else:
|
|
@@ -160,7 +164,8 @@ class MJBot:
|
|
|
else:
|
|
else:
|
|
|
content += f"prompt: {prompt}"
|
|
content += f"prompt: {prompt}"
|
|
|
reply = Reply(ReplyType.INFO, content)
|
|
reply = Reply(ReplyType.INFO, content)
|
|
|
- task = MJTask(id=task_id, status=Status.PENDING, raw_prompt=prompt, user_id=user_id, task_type=TaskType.GENERATE)
|
|
|
|
|
|
|
+ task = MJTask(id=task_id, status=Status.PENDING, raw_prompt=prompt, user_id=user_id,
|
|
|
|
|
+ task_type=TaskType.GENERATE)
|
|
|
# put to memory dict
|
|
# put to memory dict
|
|
|
self.tasks[task.id] = task
|
|
self.tasks[task.id] = task
|
|
|
# asyncio.run_coroutine_threadsafe(self.check_task(task, e_context), self.event_loop)
|
|
# asyncio.run_coroutine_threadsafe(self.check_task(task, e_context), self.event_loop)
|
|
@@ -169,18 +174,23 @@ class MJBot:
|
|
|
else:
|
|
else:
|
|
|
res_json = res.json()
|
|
res_json = res.json()
|
|
|
logger.error(f"[MJ] generate error, msg={res_json.get('message')}, status_code={res.status_code}")
|
|
logger.error(f"[MJ] generate error, msg={res_json.get('message')}, status_code={res.status_code}")
|
|
|
- reply = Reply(ReplyType.ERROR, "图片生成失败,请稍后再试")
|
|
|
|
|
|
|
+ if res.status_code == INVALID_REQUEST:
|
|
|
|
|
+ reply = Reply(ReplyType.ERROR, "图片生成失败,请检查提示词参数或内容")
|
|
|
|
|
+ else:
|
|
|
|
|
+ reply = Reply(ReplyType.ERROR, "图片生成失败,请稍后再试")
|
|
|
return reply
|
|
return reply
|
|
|
|
|
|
|
|
def upscale(self, user_id: str, img_id: str, index: int, e_context: EventContext) -> Reply:
|
|
def upscale(self, user_id: str, img_id: str, index: int, e_context: EventContext) -> Reply:
|
|
|
logger.info(f"[MJ] image upscale, img_id={img_id}, index={index}")
|
|
logger.info(f"[MJ] image upscale, img_id={img_id}, index={index}")
|
|
|
- body = {"type": TaskType.UPSCALE.name, "imgId": img_id, "index": index}
|
|
|
|
|
- res = requests.post(url=self.base_url + "/operate", json=body, headers=self.headers, timeout=(5,15))
|
|
|
|
|
|
|
+ body = {"type": TaskType.UPSCALE.name, "img_id": img_id, "index": index}
|
|
|
|
|
+ if not self.config.get("img_proxy"):
|
|
|
|
|
+ body["img_proxy"] = False
|
|
|
|
|
+ res = requests.post(url=self.base_url + "/operate", json=body, headers=self.headers, timeout=(5, 40))
|
|
|
logger.debug(res)
|
|
logger.debug(res)
|
|
|
if res.status_code == 200:
|
|
if res.status_code == 200:
|
|
|
res = res.json()
|
|
res = res.json()
|
|
|
if res.get("code") == 200:
|
|
if res.get("code") == 200:
|
|
|
- task_id = res.get("data").get("taskId")
|
|
|
|
|
|
|
+ task_id = res.get("data").get("task_id")
|
|
|
logger.info(f"[MJ] image upscale processing, task_id={task_id}")
|
|
logger.info(f"[MJ] image upscale processing, task_id={task_id}")
|
|
|
content = f"🔎图片正在放大中,请耐心等待"
|
|
content = f"🔎图片正在放大中,请耐心等待"
|
|
|
reply = Reply(ReplyType.INFO, content)
|
|
reply = Reply(ReplyType.INFO, content)
|
|
@@ -208,7 +218,7 @@ class MJBot:
|
|
|
time.sleep(10)
|
|
time.sleep(10)
|
|
|
url = f"{self.base_url}/tasks/{task.id}"
|
|
url = f"{self.base_url}/tasks/{task.id}"
|
|
|
try:
|
|
try:
|
|
|
- res = requests.get(url, headers=self.headers, timeout=5)
|
|
|
|
|
|
|
+ res = requests.get(url, headers=self.headers, timeout=8)
|
|
|
if res.status_code == 200:
|
|
if res.status_code == 200:
|
|
|
res_json = res.json()
|
|
res_json = res.json()
|
|
|
logger.debug(f"[MJ] task check res sync, task_id={task.id}, status={res.status_code}, "
|
|
logger.debug(f"[MJ] task check res sync, task_id={task.id}, status={res.status_code}, "
|
|
@@ -219,6 +229,7 @@ class MJBot:
|
|
|
self.tasks[task.id].status = Status.FINISHED
|
|
self.tasks[task.id].status = Status.FINISHED
|
|
|
self._process_success_task(task, res_json.get("data"), e_context)
|
|
self._process_success_task(task, res_json.get("data"), e_context)
|
|
|
return
|
|
return
|
|
|
|
|
+ max_retry_times -= 1
|
|
|
else:
|
|
else:
|
|
|
res_json = res.json()
|
|
res_json = res.json()
|
|
|
logger.warn(f"[MJ] image check error, status_code={res.status_code}, res={res_json}")
|
|
logger.warn(f"[MJ] image check error, status_code={res.status_code}, res={res_json}")
|
|
@@ -276,8 +287,8 @@ class MJBot:
|
|
|
"""
|
|
"""
|
|
|
# channel send img
|
|
# channel send img
|
|
|
task.status = Status.FINISHED
|
|
task.status = Status.FINISHED
|
|
|
- task.img_id = res.get("imgId")
|
|
|
|
|
- task.img_url = res.get("imgUrl")
|
|
|
|
|
|
|
+ task.img_id = res.get("img_id")
|
|
|
|
|
+ task.img_url = res.get("img_url")
|
|
|
logger.info(f"[MJ] task success, task_id={task.id}, img_id={task.img_id}, img_url={task.img_url}")
|
|
logger.info(f"[MJ] task success, task_id={task.id}, img_id={task.img_id}, img_url={task.img_url}")
|
|
|
|
|
|
|
|
# send img
|
|
# send img
|
|
@@ -338,7 +349,7 @@ class MJBot:
|
|
|
for id in self.tasks:
|
|
for id in self.tasks:
|
|
|
logger.debug(f"[MJ] current task: {self.tasks[id]}")
|
|
logger.debug(f"[MJ] current task: {self.tasks[id]}")
|
|
|
|
|
|
|
|
- def _set_reply_text(self, content: str, e_context: EventContext, level: ReplyType=ReplyType.ERROR):
|
|
|
|
|
|
|
+ def _set_reply_text(self, content: str, e_context: EventContext, level: ReplyType = ReplyType.ERROR):
|
|
|
"""
|
|
"""
|
|
|
设置回复文本
|
|
设置回复文本
|
|
|
:param content: 回复内容
|
|
:param content: 回复内容
|
|
@@ -358,7 +369,7 @@ class MJBot:
|
|
|
|
|
|
|
|
return help_text
|
|
return help_text
|
|
|
|
|
|
|
|
- def find_tasks_by_user_id(self, user_id) -> list[MJTask]:
|
|
|
|
|
|
|
+ def find_tasks_by_user_id(self, user_id) -> list:
|
|
|
result = []
|
|
result = []
|
|
|
with self.tasks_lock:
|
|
with self.tasks_lock:
|
|
|
now = time.time()
|
|
now = time.time()
|
|
@@ -377,4 +388,4 @@ def check_prefix(content, prefix_list):
|
|
|
for prefix in prefix_list:
|
|
for prefix in prefix_list:
|
|
|
if content.startswith(prefix):
|
|
if content.startswith(prefix):
|
|
|
return prefix
|
|
return prefix
|
|
|
- return None
|
|
|
|
|
|
|
+ return None
|