소스 검색

feat: check app_code dynamically

zhayujie 2 년 전
부모
커밋
2f9e5b1219
7개의 변경된 파일181개의 추가작업 그리고 35개의 파일을 삭제
  1. 6 0
      config.py
  2. 3 1
      plugins/godcmd/godcmd.py
  3. 0 0
      plugins/linkai/README.md
  4. 2 1
      plugins/linkai/config.json.template
  5. 53 13
      plugins/linkai/linkai.py
  6. 99 19
      plugins/linkai/midjourney.py
  7. 18 1
      plugins/plugin.py

+ 6 - 0
config.py

@@ -252,3 +252,9 @@ def pconf(plugin_name: str) -> dict:
     :return: 该插件的配置项
     """
     return plugin_config.get(plugin_name.lower())
+
+
+# 全局配置,用于存放全局生效的状态
+global_config = {
+    "admin_users": []
+}

+ 3 - 1
plugins/godcmd/godcmd.py

@@ -13,7 +13,7 @@ from bridge.context import ContextType
 from bridge.reply import Reply, ReplyType
 from common import const
 from common.log import logger
-from config import conf, load_config
+from config import conf, load_config, global_config
 from plugins import *
 
 # 定义指令集
@@ -426,9 +426,11 @@ class Godcmd(Plugin):
         password = args[0]
         if password == self.password:
             self.admin_users.append(userid)
+            global_config["admin_users"].append(userid)
             return True, "认证成功"
         elif password == self.temp_password:
             self.admin_users.append(userid)
+            global_config["admin_users"].append(userid)
             return True, "认证成功,请尽快设置口令"
         else:
             return False, "认证失败"

+ 0 - 0
plugins/linkai/README.md


+ 2 - 1
plugins/linkai/config.json.template

@@ -8,6 +8,7 @@
         "mode": "relax",
         "auto_translate": true,
         "max_tasks": 3,
-        "max_tasks_per_user": 1
+        "max_tasks_per_user": 1,
+        "use_image_create_prefix": true
     }
 }

+ 53 - 13
plugins/linkai/linkai.py

@@ -8,7 +8,7 @@ from bridge.context import ContextType
 from bridge.reply import Reply, ReplyType
 from channel.chat_message import ChatMessage
 from common.log import logger
-from config import conf
+from config import conf, global_config
 from plugins import *
 from .midjourney import MJBot, TaskType
 
@@ -46,14 +46,48 @@ class LinkAI(Plugin):
             self.mj_bot.process_mj_task(mj_type, e_context)
             return
 
+        if context.content.startswith(f"{_get_trigger_prefix()}linkai"):
+            # 应用管理功能
+            self._process_admin_cmd(e_context)
+            return
+
         if self._is_chat_task(e_context):
+            # 文本对话任务处理
             self._process_chat_task(e_context)
 
+    # 插件管理功能
+    def _process_admin_cmd(self, e_context: EventContext):
+        context = e_context['context']
+        cmd = context.content.split()
+        if len(cmd) == 1 or (len(cmd) == 2 and cmd[1] == "help"):
+            _set_reply_text(self.get_help_text(verbose=True), e_context, level=ReplyType.INFO)
+            return
+        if len(cmd) == 3 and cmd[1] == "app":
+            if not context.kwargs.get("isgroup"):
+                _set_reply_text("该指令需在群聊中使用", e_context, level=ReplyType.ERROR)
+                return
+            if e_context["context"]["session_id"] not in global_config["admin_users"]:
+                _set_reply_text("需要管理员权限执行", e_context, level=ReplyType.ERROR)
+                return
+            app_code = cmd[2]
+            group_name = context.kwargs.get("msg").from_user_nickname
+            group_mapping = self.config.get("group_app_map")
+            if group_mapping:
+                group_mapping[group_name] = app_code
+            else:
+                self.config["group_app_map"] = {group_name: app_code}
+            # 保存插件配置
+            super().save_config(self.config)
+            _set_reply_text(f"应用设置成功: {app_code}", e_context, level=ReplyType.INFO)
+        else:
+            _set_reply_text(f"指令错误,请输入{_get_trigger_prefix()}linkai help 获取帮助", e_context, level=ReplyType.INFO)
+            return
+
     # LinkAI 对话任务处理
     def _is_chat_task(self, e_context: EventContext):
         context = e_context['context']
         # 群聊应用管理
-        return self.config.get("knowledge_base") and context.kwargs.get("isgroup")
+        return self.config.get("group_app_map") and context.kwargs.get("isgroup")
 
     def _process_chat_task(self, e_context: EventContext):
         """
@@ -73,21 +107,27 @@ class LinkAI(Plugin):
         :param group_name: 群聊名称
         :return: 应用code
         """
-        knowledge_base_config = self.config.get("knowledge_base")
-        if knowledge_base_config and knowledge_base_config.get("group_mapping"):
-            app_code = knowledge_base_config.get("group_mapping").get(group_name) \
-                       or knowledge_base_config.get("group_mapping").get("ALL_GROUP")
+        group_mapping = self.config.get("group_app_map")
+        if group_mapping:
+            app_code = group_mapping.get(group_name) or group_mapping.get("ALL_GROUP")
             return app_code
 
     def get_help_text(self, verbose=False, **kwargs):
-        trigger_prefix = conf().get("plugin_trigger_prefix", "$")
-        help_text = "利用midjourney来画图。\n"
+        trigger_prefix = _get_trigger_prefix()
+        help_text = "用于集成 LinkAI 提供的文本对话、知识库、绘画等能力。\n"
         if not verbose:
             return help_text
-        help_text += f"{trigger_prefix}mj 描述词1,描述词2 ... : 利用描述词作画,参数请放在提示词之后。\n{trigger_prefix}mjimage 描述词1,描述词2 ... : 利用描述词进行图生图,参数请放在提示词之后。\n{trigger_prefix}mjr ID: 对指定ID消息重新生成图片。\n{trigger_prefix}mju ID 图片序号: 对指定ID消息中的第x张图片进行放大。\n{trigger_prefix}mjv ID 图片序号: 对指定ID消息中的第x张图片进行变换。\n例如:\n\"{trigger_prefix}mj a little cat, white --ar 9:16\"\n\"{trigger_prefix}mjimage a white cat --ar 9:16\"\n\"{trigger_prefix}mju 1105592717188272288 2\""
+        help_text += ""
+        help_text += f"{trigger_prefix}mj 描述词1,描述词2 ... : 利用描述词作画,参数请放在提示词之后。\n\n{trigger_prefix}mju ID 图片序号: 对指定ID消息中的第x张图片进行放大。\n例如:\n\"{trigger_prefix}mj a little cat, white --ar 9:16\"\n\"{trigger_prefix}mjimage a white cat --ar 9:16\"\n\"{trigger_prefix}mju 1105592717188272288 2\""
         return help_text
 
-    def _set_reply_text(self, content: str, e_context: EventContext, level: ReplyType=ReplyType.ERROR):
-        reply = Reply(level, content)
-        e_context["reply"] = reply
-        e_context.action = EventAction.BREAK_PASS
+
+# 静态方法
+def _set_reply_text(content: str, e_context: EventContext, level: ReplyType = ReplyType.ERROR):
+    reply = Reply(level, content)
+    e_context["reply"] = reply
+    e_context.action = EventAction.BREAK_PASS
+
+
+def _get_trigger_prefix():
+    return conf().get("plugin_trigger_prefix", "$")

+ 99 - 19
plugins/linkai/midjourney.py

@@ -28,6 +28,11 @@ class Status(Enum):
         return self.name
 
 
+class TaskMode(Enum):
+    FAST = "fast"
+    RELAX = "relax"
+
+
 class MJTask:
     def __init__(self, id, user_id: str, task_type: TaskType, raw_prompt=None, expires: int=60*30, status=Status.PENDING):
         self.id = id
@@ -47,7 +52,6 @@ class MJTask:
 class MJBot:
     def __init__(self, config):
         self.base_url = "https://api.link-ai.chat/v1/img/midjourney"
-        # self.base_url = "http://127.0.0.1:8911/v1/img/midjourney"
         self.headers = {"Authorization": "Bearer " + conf().get("linkai_api_key")}
         self.config = config
         self.tasks = {}
@@ -71,10 +75,10 @@ class MJBot:
                     return TaskType.GENERATE
                 elif cmd_list[0].lower() == f"{trigger_prefix}mju":
                     return TaskType.UPSCALE
-                # elif cmd_list[0].lower() == f"{trigger_prefix}mjv":
-                #     return TaskType.VARIATION
-                # elif cmd_list[0].lower() == f"{trigger_prefix}mjr":
-                #     return TaskType.RESET
+                elif self.config.get("use_image_create_prefix") and \
+                        check_prefix(context.content, conf().get("image_create_prefix")):
+                    return TaskType.GENERATE
+
 
     def process_mj_task(self, mj_type: TaskType, e_context: EventContext):
         """
@@ -86,12 +90,20 @@ class MJBot:
         session_id = context["session_id"]
         cmd = context.content.split(maxsplit=1)
         if len(cmd) == 1:
-            self._set_reply_text(self.get_help_text(verbose=True), e_context, level=ReplyType.ERROR)
+            self._set_reply_text(self.get_help_text(verbose=True), e_context, level=ReplyType.INFO)
+            return
+
+        if not self._check_rate_limit(session_id, e_context):
+            logger.warn("[MJ] midjourney task exceed rate limit")
             return
 
         if mj_type == TaskType.GENERATE:
-            # 图片生成
-            raw_prompt = cmd[1]
+            image_prefix = check_prefix(context.content, conf().get("image_create_prefix"))
+            if image_prefix:
+                raw_prompt = context.content.replace(image_prefix, "", 1)
+            else:
+                # 图片生成
+                raw_prompt = cmd[1]
             reply = self.generate(raw_prompt, session_id, e_context)
             e_context['reply'] = reply
             e_context.action = EventAction.BREAK_PASS
@@ -126,10 +138,12 @@ class MJBot:
         图片生成
         :param prompt: 提示词
         :param user_id: 用户id
+        :param e_context: 对话上下文
         :return: 任务ID
         """
         logger.info(f"[MJ] image generate, prompt={prompt}")
-        body = {"prompt": prompt}
+        mode = self._fetch_mode(prompt)
+        body = {"prompt": prompt, "mode": mode, "auto_translate": self.config.get("auto_translate")}
         res = requests.post(url=self.base_url + "/generate", json=body, headers=self.headers)
         if res.status_code == 200:
             res = res.json()
@@ -137,7 +151,11 @@ class MJBot:
             if res.get("code") == 200:
                 task_id = res.get("data").get("taskId")
                 real_prompt = res.get("data").get("realPrompt")
-                content = f"🚀你的作品将在1~2分钟左右完成,请耐心等待\n- - - - - - - - -\n"
+                if mode == TaskMode.RELAX.name:
+                    time_str = "1~10分钟"
+                else:
+                    time_str = "1~2分钟"
+                content = f"🚀你的作品将在{time_str}左右完成,请耐心等待\n- - - - - - - - -\n"
                 if real_prompt:
                     content += f"初始prompt: {prompt}\n转换后prompt: {real_prompt}"
                 else:
@@ -182,8 +200,9 @@ class MJBot:
             return reply
 
     async def check_task(self, task: MJTask, e_context: EventContext):
-        max_retry_time = 80
-        while max_retry_time > 0:
+        max_retry_times = 90
+        while max_retry_times > 0:
+            await asyncio.sleep(10)
             async with aiohttp.ClientSession() as session:
                 url = f"{self.base_url}/tasks/{task.id}"
                 async with session.get(url, headers=self.headers) as res:
@@ -193,14 +212,17 @@ class MJBot:
                                      f"data={res_json.get('data')}, thread={threading.current_thread().name}")
                         if res_json.get("data") and res_json.get("data").get("status") == Status.FINISHED.name:
                             # process success res
+                            if self.tasks.get(task.id):
+                                self.tasks[task.id].status = Status.FINISHED
                             self._process_success_task(task, res_json.get("data"), e_context)
                             return
                     else:
                         logger.warn(f"[MJ] image check error, status_code={res.status}")
-                        max_retry_time -= 20
-            await asyncio.sleep(10)
-            max_retry_time -= 1
+                        max_retry_times -= 20
+            max_retry_times -= 1
         logger.warn("[MJ] end from poll")
+        if self.tasks.get(task.id):
+            self.tasks[task.id].status = Status.EXPIRED
 
     def _process_success_task(self, task: MJTask, res: dict, e_context: EventContext):
         """
@@ -233,7 +255,39 @@ class MJBot:
         self._print_tasks()
         return
 
+    def _check_rate_limit(self, user_id: str, e_context: EventContext) -> bool:
+        """
+        midjourney任务限流控制
+        :param user_id: 用户id
+        :param e_context: 对话上下文
+        :return: 任务是否能够生成, True:可以生成, False: 被限流
+        """
+        tasks = self.find_tasks_by_user_id(user_id)
+        task_count = len([t for t in tasks if t.status == Status.PENDING])
+        if task_count >= self.config.get("max_tasks_per_user"):
+            reply = Reply(ReplyType.INFO, "您的Midjourney作图任务数已达上限,请稍后再试")
+            e_context["reply"] = reply
+            e_context.action = EventAction.BREAK_PASS
+            return False
+        task_count = len([t for t in self.tasks.values() if t.status == Status.PENDING])
+        if task_count >= self.config.get("max_tasks"):
+            reply = Reply(ReplyType.INFO, "Midjourney服务的总任务数已达上限,请稍后再试")
+            e_context["reply"] = reply
+            e_context.action = EventAction.BREAK_PASS
+            return False
+        return True
+
+    def _fetch_mode(self, prompt) -> str:
+        mode = self.config.get("mode")
+        if "--relax" in prompt or mode == TaskMode.RELAX.name:
+            return TaskMode.RELAX.name
+        return TaskMode.FAST.name
+
     def _run_loop(self, loop: asyncio.BaseEventLoop):
+        """
+        运行事件循环,用于轮询任务的线程
+        :param loop: 事件循环
+        """
         loop.run_forever()
         loop.stop()
 
@@ -241,6 +295,16 @@ class MJBot:
         for id in self.tasks:
             logger.debug(f"[MJ] current task: {self.tasks[id]}")
 
+    def _set_reply_text(self, content: str, e_context: EventContext, level: ReplyType=ReplyType.ERROR):
+        """
+        设置回复文本
+        :param content: 回复内容
+        :param e_context: 对话上下文
+        :param level: 回复等级
+        """
+        reply = Reply(level, content)
+        e_context["reply"] = reply
+        e_context.action = EventAction.BREAK_PASS
 
     def get_help_text(self, verbose=False, **kwargs):
         trigger_prefix = conf().get("plugin_trigger_prefix", "$")
@@ -250,7 +314,23 @@ class MJBot:
         help_text += f"{trigger_prefix}mj 描述词1,描述词2 ... : 利用描述词作画,参数请放在提示词之后。\n{trigger_prefix}mjimage 描述词1,描述词2 ... : 利用描述词进行图生图,参数请放在提示词之后。\n{trigger_prefix}mjr ID: 对指定ID消息重新生成图片。\n{trigger_prefix}mju ID 图片序号: 对指定ID消息中的第x张图片进行放大。\n{trigger_prefix}mjv ID 图片序号: 对指定ID消息中的第x张图片进行变换。\n例如:\n\"{trigger_prefix}mj a little cat, white --ar 9:16\"\n\"{trigger_prefix}mjimage a white cat --ar 9:16\"\n\"{trigger_prefix}mju 1105592717188272288 2\""
         return help_text
 
-    def _set_reply_text(self, content: str, e_context: EventContext, level: ReplyType=ReplyType.ERROR):
-        reply = Reply(level, content)
-        e_context["reply"] = reply
-        e_context.action = EventAction.BREAK_PASS
+    def find_tasks_by_user_id(self, user_id) -> list[MJTask]:
+        result = []
+        with self.tasks_lock:
+            now = time.time()
+            for task in self.tasks.values():
+                if task.status == Status.PENDING and now > task.expiry_time:
+                    task.status = Status.EXPIRED
+                    logger.info(f"[MJ] {task} expired")
+                if task.user_id == user_id:
+                    result.append(task)
+        return result
+
+
+def check_prefix(content, prefix_list):
+    if not prefix_list:
+        return None
+    for prefix in prefix_list:
+        if content.startswith(prefix):
+            return prefix
+    return None

+ 18 - 1
plugins/plugin.py

@@ -1,6 +1,6 @@
 import os
 import json
-from config import pconf
+from config import pconf, plugin_config
 from common.log import logger
 
 
@@ -24,5 +24,22 @@ class Plugin:
         logger.debug(f"loading plugin config, plugin_name={self.name}, conf={plugin_conf}")
         return plugin_conf
 
+    def save_config(self, config: dict):
+        try:
+            plugin_config[self.name] = config
+            # 写入全局配置
+            global_config_path = "./plugins/config.json"
+            if os.path.exists(global_config_path):
+                with open(global_config_path, "w", encoding='utf-8') as f:
+                    json.dump(plugin_config, f, indent=4, ensure_ascii=False)
+            # 写入插件配置
+            plugin_config_path = os.path.join(self.path, "config.json")
+            if os.path.exists(plugin_config_path):
+                with open(plugin_config_path, "w", encoding='utf-8') as f:
+                    json.dump(config, f, indent=4, ensure_ascii=False)
+
+        except Exception as e:
+            logger.warn("save plugin config failed: {}".format(e))
+
     def get_help_text(self, **kwargs):
         return "暂无帮助信息"