|
@@ -28,6 +28,11 @@ class Status(Enum):
|
|
|
return self.name
|
|
return self.name
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
+class TaskMode(Enum):
|
|
|
|
|
+ FAST = "fast"
|
|
|
|
|
+ RELAX = "relax"
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
class MJTask:
|
|
class MJTask:
|
|
|
def __init__(self, id, user_id: str, task_type: TaskType, raw_prompt=None, expires: int=60*30, status=Status.PENDING):
|
|
def __init__(self, id, user_id: str, task_type: TaskType, raw_prompt=None, expires: int=60*30, status=Status.PENDING):
|
|
|
self.id = id
|
|
self.id = id
|
|
@@ -47,7 +52,6 @@ class MJTask:
|
|
|
class MJBot:
|
|
class MJBot:
|
|
|
def __init__(self, config):
|
|
def __init__(self, config):
|
|
|
self.base_url = "https://api.link-ai.chat/v1/img/midjourney"
|
|
self.base_url = "https://api.link-ai.chat/v1/img/midjourney"
|
|
|
- # self.base_url = "http://127.0.0.1:8911/v1/img/midjourney"
|
|
|
|
|
self.headers = {"Authorization": "Bearer " + conf().get("linkai_api_key")}
|
|
self.headers = {"Authorization": "Bearer " + conf().get("linkai_api_key")}
|
|
|
self.config = config
|
|
self.config = config
|
|
|
self.tasks = {}
|
|
self.tasks = {}
|
|
@@ -71,10 +75,10 @@ class MJBot:
|
|
|
return TaskType.GENERATE
|
|
return TaskType.GENERATE
|
|
|
elif cmd_list[0].lower() == f"{trigger_prefix}mju":
|
|
elif cmd_list[0].lower() == f"{trigger_prefix}mju":
|
|
|
return TaskType.UPSCALE
|
|
return TaskType.UPSCALE
|
|
|
- # elif cmd_list[0].lower() == f"{trigger_prefix}mjv":
|
|
|
|
|
- # return TaskType.VARIATION
|
|
|
|
|
- # elif cmd_list[0].lower() == f"{trigger_prefix}mjr":
|
|
|
|
|
- # return TaskType.RESET
|
|
|
|
|
|
|
+ elif self.config.get("use_image_create_prefix") and \
|
|
|
|
|
+ check_prefix(context.content, conf().get("image_create_prefix")):
|
|
|
|
|
+ return TaskType.GENERATE
|
|
|
|
|
+
|
|
|
|
|
|
|
|
def process_mj_task(self, mj_type: TaskType, e_context: EventContext):
|
|
def process_mj_task(self, mj_type: TaskType, e_context: EventContext):
|
|
|
"""
|
|
"""
|
|
@@ -86,12 +90,20 @@ class MJBot:
|
|
|
session_id = context["session_id"]
|
|
session_id = context["session_id"]
|
|
|
cmd = context.content.split(maxsplit=1)
|
|
cmd = context.content.split(maxsplit=1)
|
|
|
if len(cmd) == 1:
|
|
if len(cmd) == 1:
|
|
|
- self._set_reply_text(self.get_help_text(verbose=True), e_context, level=ReplyType.ERROR)
|
|
|
|
|
|
|
+ self._set_reply_text(self.get_help_text(verbose=True), e_context, level=ReplyType.INFO)
|
|
|
|
|
+ return
|
|
|
|
|
+
|
|
|
|
|
+ if not self._check_rate_limit(session_id, e_context):
|
|
|
|
|
+ logger.warn("[MJ] midjourney task exceed rate limit")
|
|
|
return
|
|
return
|
|
|
|
|
|
|
|
if mj_type == TaskType.GENERATE:
|
|
if mj_type == TaskType.GENERATE:
|
|
|
- # 图片生成
|
|
|
|
|
- raw_prompt = cmd[1]
|
|
|
|
|
|
|
+ image_prefix = check_prefix(context.content, conf().get("image_create_prefix"))
|
|
|
|
|
+ if image_prefix:
|
|
|
|
|
+ raw_prompt = context.content.replace(image_prefix, "", 1)
|
|
|
|
|
+ else:
|
|
|
|
|
+ # 图片生成
|
|
|
|
|
+ raw_prompt = cmd[1]
|
|
|
reply = self.generate(raw_prompt, session_id, e_context)
|
|
reply = self.generate(raw_prompt, session_id, e_context)
|
|
|
e_context['reply'] = reply
|
|
e_context['reply'] = reply
|
|
|
e_context.action = EventAction.BREAK_PASS
|
|
e_context.action = EventAction.BREAK_PASS
|
|
@@ -126,10 +138,12 @@ class MJBot:
|
|
|
图片生成
|
|
图片生成
|
|
|
:param prompt: 提示词
|
|
:param prompt: 提示词
|
|
|
:param user_id: 用户id
|
|
:param user_id: 用户id
|
|
|
|
|
+ :param e_context: 对话上下文
|
|
|
:return: 任务ID
|
|
:return: 任务ID
|
|
|
"""
|
|
"""
|
|
|
logger.info(f"[MJ] image generate, prompt={prompt}")
|
|
logger.info(f"[MJ] image generate, prompt={prompt}")
|
|
|
- body = {"prompt": prompt}
|
|
|
|
|
|
|
+ mode = self._fetch_mode(prompt)
|
|
|
|
|
+ body = {"prompt": prompt, "mode": mode, "auto_translate": self.config.get("auto_translate")}
|
|
|
res = requests.post(url=self.base_url + "/generate", json=body, headers=self.headers)
|
|
res = requests.post(url=self.base_url + "/generate", json=body, headers=self.headers)
|
|
|
if res.status_code == 200:
|
|
if res.status_code == 200:
|
|
|
res = res.json()
|
|
res = res.json()
|
|
@@ -137,7 +151,11 @@ class MJBot:
|
|
|
if res.get("code") == 200:
|
|
if res.get("code") == 200:
|
|
|
task_id = res.get("data").get("taskId")
|
|
task_id = res.get("data").get("taskId")
|
|
|
real_prompt = res.get("data").get("realPrompt")
|
|
real_prompt = res.get("data").get("realPrompt")
|
|
|
- content = f"🚀你的作品将在1~2分钟左右完成,请耐心等待\n- - - - - - - - -\n"
|
|
|
|
|
|
|
+ if mode == TaskMode.RELAX.name:
|
|
|
|
|
+ time_str = "1~10分钟"
|
|
|
|
|
+ else:
|
|
|
|
|
+ time_str = "1~2分钟"
|
|
|
|
|
+ content = f"🚀你的作品将在{time_str}左右完成,请耐心等待\n- - - - - - - - -\n"
|
|
|
if real_prompt:
|
|
if real_prompt:
|
|
|
content += f"初始prompt: {prompt}\n转换后prompt: {real_prompt}"
|
|
content += f"初始prompt: {prompt}\n转换后prompt: {real_prompt}"
|
|
|
else:
|
|
else:
|
|
@@ -182,8 +200,9 @@ class MJBot:
|
|
|
return reply
|
|
return reply
|
|
|
|
|
|
|
|
async def check_task(self, task: MJTask, e_context: EventContext):
|
|
async def check_task(self, task: MJTask, e_context: EventContext):
|
|
|
- max_retry_time = 80
|
|
|
|
|
- while max_retry_time > 0:
|
|
|
|
|
|
|
+ max_retry_times = 90
|
|
|
|
|
+ while max_retry_times > 0:
|
|
|
|
|
+ await asyncio.sleep(10)
|
|
|
async with aiohttp.ClientSession() as session:
|
|
async with aiohttp.ClientSession() as session:
|
|
|
url = f"{self.base_url}/tasks/{task.id}"
|
|
url = f"{self.base_url}/tasks/{task.id}"
|
|
|
async with session.get(url, headers=self.headers) as res:
|
|
async with session.get(url, headers=self.headers) as res:
|
|
@@ -193,14 +212,17 @@ class MJBot:
|
|
|
f"data={res_json.get('data')}, thread={threading.current_thread().name}")
|
|
f"data={res_json.get('data')}, thread={threading.current_thread().name}")
|
|
|
if res_json.get("data") and res_json.get("data").get("status") == Status.FINISHED.name:
|
|
if res_json.get("data") and res_json.get("data").get("status") == Status.FINISHED.name:
|
|
|
# process success res
|
|
# process success res
|
|
|
|
|
+ if self.tasks.get(task.id):
|
|
|
|
|
+ self.tasks[task.id].status = Status.FINISHED
|
|
|
self._process_success_task(task, res_json.get("data"), e_context)
|
|
self._process_success_task(task, res_json.get("data"), e_context)
|
|
|
return
|
|
return
|
|
|
else:
|
|
else:
|
|
|
logger.warn(f"[MJ] image check error, status_code={res.status}")
|
|
logger.warn(f"[MJ] image check error, status_code={res.status}")
|
|
|
- max_retry_time -= 20
|
|
|
|
|
- await asyncio.sleep(10)
|
|
|
|
|
- max_retry_time -= 1
|
|
|
|
|
|
|
+ max_retry_times -= 20
|
|
|
|
|
+ max_retry_times -= 1
|
|
|
logger.warn("[MJ] end from poll")
|
|
logger.warn("[MJ] end from poll")
|
|
|
|
|
+ if self.tasks.get(task.id):
|
|
|
|
|
+ self.tasks[task.id].status = Status.EXPIRED
|
|
|
|
|
|
|
|
def _process_success_task(self, task: MJTask, res: dict, e_context: EventContext):
|
|
def _process_success_task(self, task: MJTask, res: dict, e_context: EventContext):
|
|
|
"""
|
|
"""
|
|
@@ -233,7 +255,39 @@ class MJBot:
|
|
|
self._print_tasks()
|
|
self._print_tasks()
|
|
|
return
|
|
return
|
|
|
|
|
|
|
|
|
|
+ def _check_rate_limit(self, user_id: str, e_context: EventContext) -> bool:
|
|
|
|
|
+ """
|
|
|
|
|
+ midjourney任务限流控制
|
|
|
|
|
+ :param user_id: 用户id
|
|
|
|
|
+ :param e_context: 对话上下文
|
|
|
|
|
+ :return: 任务是否能够生成, True:可以生成, False: 被限流
|
|
|
|
|
+ """
|
|
|
|
|
+ tasks = self.find_tasks_by_user_id(user_id)
|
|
|
|
|
+ task_count = len([t for t in tasks if t.status == Status.PENDING])
|
|
|
|
|
+ if task_count >= self.config.get("max_tasks_per_user"):
|
|
|
|
|
+ reply = Reply(ReplyType.INFO, "您的Midjourney作图任务数已达上限,请稍后再试")
|
|
|
|
|
+ e_context["reply"] = reply
|
|
|
|
|
+ e_context.action = EventAction.BREAK_PASS
|
|
|
|
|
+ return False
|
|
|
|
|
+ task_count = len([t for t in self.tasks.values() if t.status == Status.PENDING])
|
|
|
|
|
+ if task_count >= self.config.get("max_tasks"):
|
|
|
|
|
+ reply = Reply(ReplyType.INFO, "Midjourney服务的总任务数已达上限,请稍后再试")
|
|
|
|
|
+ e_context["reply"] = reply
|
|
|
|
|
+ e_context.action = EventAction.BREAK_PASS
|
|
|
|
|
+ return False
|
|
|
|
|
+ return True
|
|
|
|
|
+
|
|
|
|
|
+ def _fetch_mode(self, prompt) -> str:
|
|
|
|
|
+ mode = self.config.get("mode")
|
|
|
|
|
+ if "--relax" in prompt or mode == TaskMode.RELAX.name:
|
|
|
|
|
+ return TaskMode.RELAX.name
|
|
|
|
|
+ return TaskMode.FAST.name
|
|
|
|
|
+
|
|
|
def _run_loop(self, loop: asyncio.BaseEventLoop):
|
|
def _run_loop(self, loop: asyncio.BaseEventLoop):
|
|
|
|
|
+ """
|
|
|
|
|
+ 运行事件循环,用于轮询任务的线程
|
|
|
|
|
+ :param loop: 事件循环
|
|
|
|
|
+ """
|
|
|
loop.run_forever()
|
|
loop.run_forever()
|
|
|
loop.stop()
|
|
loop.stop()
|
|
|
|
|
|
|
@@ -241,6 +295,16 @@ class MJBot:
|
|
|
for id in self.tasks:
|
|
for id in self.tasks:
|
|
|
logger.debug(f"[MJ] current task: {self.tasks[id]}")
|
|
logger.debug(f"[MJ] current task: {self.tasks[id]}")
|
|
|
|
|
|
|
|
|
|
+ def _set_reply_text(self, content: str, e_context: EventContext, level: ReplyType=ReplyType.ERROR):
|
|
|
|
|
+ """
|
|
|
|
|
+ 设置回复文本
|
|
|
|
|
+ :param content: 回复内容
|
|
|
|
|
+ :param e_context: 对话上下文
|
|
|
|
|
+ :param level: 回复等级
|
|
|
|
|
+ """
|
|
|
|
|
+ reply = Reply(level, content)
|
|
|
|
|
+ e_context["reply"] = reply
|
|
|
|
|
+ e_context.action = EventAction.BREAK_PASS
|
|
|
|
|
|
|
|
def get_help_text(self, verbose=False, **kwargs):
|
|
def get_help_text(self, verbose=False, **kwargs):
|
|
|
trigger_prefix = conf().get("plugin_trigger_prefix", "$")
|
|
trigger_prefix = conf().get("plugin_trigger_prefix", "$")
|
|
@@ -250,7 +314,23 @@ class MJBot:
|
|
|
help_text += f"{trigger_prefix}mj 描述词1,描述词2 ... : 利用描述词作画,参数请放在提示词之后。\n{trigger_prefix}mjimage 描述词1,描述词2 ... : 利用描述词进行图生图,参数请放在提示词之后。\n{trigger_prefix}mjr ID: 对指定ID消息重新生成图片。\n{trigger_prefix}mju ID 图片序号: 对指定ID消息中的第x张图片进行放大。\n{trigger_prefix}mjv ID 图片序号: 对指定ID消息中的第x张图片进行变换。\n例如:\n\"{trigger_prefix}mj a little cat, white --ar 9:16\"\n\"{trigger_prefix}mjimage a white cat --ar 9:16\"\n\"{trigger_prefix}mju 1105592717188272288 2\""
|
|
help_text += f"{trigger_prefix}mj 描述词1,描述词2 ... : 利用描述词作画,参数请放在提示词之后。\n{trigger_prefix}mjimage 描述词1,描述词2 ... : 利用描述词进行图生图,参数请放在提示词之后。\n{trigger_prefix}mjr ID: 对指定ID消息重新生成图片。\n{trigger_prefix}mju ID 图片序号: 对指定ID消息中的第x张图片进行放大。\n{trigger_prefix}mjv ID 图片序号: 对指定ID消息中的第x张图片进行变换。\n例如:\n\"{trigger_prefix}mj a little cat, white --ar 9:16\"\n\"{trigger_prefix}mjimage a white cat --ar 9:16\"\n\"{trigger_prefix}mju 1105592717188272288 2\""
|
|
|
return help_text
|
|
return help_text
|
|
|
|
|
|
|
|
- def _set_reply_text(self, content: str, e_context: EventContext, level: ReplyType=ReplyType.ERROR):
|
|
|
|
|
- reply = Reply(level, content)
|
|
|
|
|
- e_context["reply"] = reply
|
|
|
|
|
- e_context.action = EventAction.BREAK_PASS
|
|
|
|
|
|
|
+ def find_tasks_by_user_id(self, user_id) -> list[MJTask]:
|
|
|
|
|
+ result = []
|
|
|
|
|
+ with self.tasks_lock:
|
|
|
|
|
+ now = time.time()
|
|
|
|
|
+ for task in self.tasks.values():
|
|
|
|
|
+ if task.status == Status.PENDING and now > task.expiry_time:
|
|
|
|
|
+ task.status = Status.EXPIRED
|
|
|
|
|
+ logger.info(f"[MJ] {task} expired")
|
|
|
|
|
+ if task.user_id == user_id:
|
|
|
|
|
+ result.append(task)
|
|
|
|
|
+ return result
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def check_prefix(content, prefix_list):
|
|
|
|
|
+ if not prefix_list:
|
|
|
|
|
+ return None
|
|
|
|
|
+ for prefix in prefix_list:
|
|
|
|
|
+ if content.startswith(prefix):
|
|
|
|
|
+ return prefix
|
|
|
|
|
+ return None
|