소스 검색

Merge branch 'master' into master

zhayujie 2 년 전
부모
커밋
995894d3aa

+ 2 - 1
.gitignore

@@ -28,4 +28,5 @@ plugins/banwords/__pycache__
 plugins/banwords/lib/__pycache__
 !plugins/hello
 !plugins/role
-!plugins/keyword
+!plugins/keyword
+!plugins/linkai

+ 2 - 1
README.md

@@ -111,7 +111,7 @@ pip3 install azure-cognitiveservices-speech
 {
   "open_ai_api_key": "YOUR API KEY",                          # 填入上面创建的 OpenAI API KEY
   "model": "gpt-3.5-turbo",                                   # 模型名称。当use_azure_chatgpt为true时,其名称为Azure上model deployment名称
-  "proxy": "127.0.0.1:7890",                                  # 代理客户端的ip和端口
+  "proxy": "",                                                # 代理客户端的ip和端口,国内环境开启代理的需要填写该项,如 "127.0.0.1:7890"
   "single_chat_prefix": ["bot", "@bot"],                      # 私聊时文本需要包含该前缀才能触发机器人回复
   "single_chat_reply_prefix": "[bot] ",                       # 私聊时自动回复的前缀,用于区分真人
   "group_chat_prefix": ["@bot"],                              # 群聊时包含该前缀则会触发机器人回复
@@ -123,6 +123,7 @@ pip3 install azure-cognitiveservices-speech
   "group_speech_recognition": false,                          # 是否开启群组语音识别
   "use_azure_chatgpt": false,                                 # 是否使用Azure ChatGPT service代替openai ChatGPT service. 当设置为true时需要设置 open_ai_api_base,如 https://xxx.openai.azure.com/
   "azure_deployment_id": "",                                  # 采用Azure ChatGPT时,模型部署名称
+  "azure_api_version": "",                                    # 采用Azure ChatGPT时,API版本
   "character_desc": "你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。",  # 人格描述
   # 订阅消息,公众号和企业微信channel中请填写,当被订阅时会自动回复,可使用特殊占位符。目前支持的占位符有{trigger_prefix},在程序中它会自动替换成bot的触发词。
   "subscribe_msg": "感谢您的关注!\n这里是ChatGPT,可以自由对话。\n支持语音对话。\n支持图片输出,画字开头的消息将按要求创作图片。\n支持角色扮演和文字冒险等丰富插件。\n输入{trigger_prefix}#help 查看详细指令。"

+ 1 - 1
bot/chatgpt/chat_gpt_bot.py

@@ -166,7 +166,7 @@ class AzureChatGPTBot(ChatGPTBot):
     def __init__(self):
         super().__init__()
         openai.api_type = "azure"
-        openai.api_version = "2023-03-15-preview"
+        openai.api_version = conf().get("azure_api_version", "2023-06-01-preview")
         self.args["deployment_id"] = conf().get("azure_deployment_id")
 
     def create_img(self, query, retry_count=0, api_key=None):

+ 32 - 25
bot/linkai/link_ai_bot.py

@@ -29,18 +29,24 @@ class LinkAIBot(Bot, OpenAIImage):
         if context.type == ContextType.TEXT:
             return self._chat(query, context)
         elif context.type == ContextType.IMAGE_CREATE:
-            ok, retstring = self.create_img(query, 0)
-            reply = None
+            ok, res = self.create_img(query, 0)
             if ok:
-                reply = Reply(ReplyType.IMAGE_URL, retstring)
+                reply = Reply(ReplyType.IMAGE_URL, res)
             else:
-                reply = Reply(ReplyType.ERROR, retstring)
+                reply = Reply(ReplyType.ERROR, res)
             return reply
         else:
             reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type))
             return reply
 
-    def _chat(self, query, context, retry_count=0):
+    def _chat(self, query, context, retry_count=0) -> Reply:
+        """
+        发起对话请求
+        :param query: 请求提示词
+        :param context: 对话上下文
+        :param retry_count: 当前递归重试次数
+        :return: 回复
+        """
         if retry_count >= 2:
             # exit from retry 2 times
             logger.warn("[LINKAI] failed after maximum number of retry times")
@@ -52,7 +58,7 @@ class LinkAIBot(Bot, OpenAIImage):
                 logger.info(f"[LINKAI] won't set appcode because a plugin ({context['generate_breaked_by']}) affected the context")
                 app_code = None
             else:
-                app_code = conf().get("linkai_app_code")
+                app_code = context.kwargs.get("app_code") or conf().get("linkai_app_code")
             linkai_api_key = conf().get("linkai_api_key")
 
             session_id = context["session_id"]
@@ -63,10 +69,8 @@ class LinkAIBot(Bot, OpenAIImage):
             if app_code and session.messages[0].get("role") == "system":
                 session.messages.pop(0)
 
-            logger.info(f"[LINKAI] query={query}, app_code={app_code}")
-
             body = {
-                "appCode": app_code,
+                "app_code": app_code,
                 "messages": session.messages,
                 "model": conf().get("model") or "gpt-3.5-turbo",  # 对话模型的名称
                 "temperature": conf().get("temperature"),
@@ -74,31 +78,34 @@ class LinkAIBot(Bot, OpenAIImage):
                 "frequency_penalty": conf().get("frequency_penalty", 0.0),  # [-2,2]之间,该值越大则更倾向于产生不同的内容
                 "presence_penalty": conf().get("presence_penalty", 0.0),  # [-2,2]之间,该值越大则更倾向于产生不同的内容
             }
+            logger.info(f"[LINKAI] query={query}, app_code={app_code}, mode={body.get('model')}")
             headers = {"Authorization": "Bearer " + linkai_api_key}
 
             # do http request
-            res = requests.post(url=self.base_url + "/chat/completion", json=body, headers=headers).json()
-
-            if not res or not res["success"]:
-                if res.get("code") == self.AUTH_FAILED_CODE:
-                    logger.exception(f"[LINKAI] please check your linkai_api_key, res={res}")
-                    return Reply(ReplyType.ERROR, "请再问我一次吧")
+            res = requests.post(url=self.base_url + "/chat/completions", json=body, headers=headers,
+                                timeout=conf().get("request_timeout", 180))
+            if res.status_code == 200:
+                # execute success
+                response = res.json()
+                reply_content = response["choices"][0]["message"]["content"]
+                total_tokens = response["usage"]["total_tokens"]
+                logger.info(f"[LINKAI] reply={reply_content}, total_tokens={total_tokens}")
+                self.sessions.session_reply(reply_content, session_id, total_tokens)
+                return Reply(ReplyType.TEXT, reply_content)
 
-                elif res.get("code") == self.NO_QUOTA_CODE:
-                    logger.exception(f"[LINKAI] please check your account quota, https://chat.link-ai.tech/console/account")
-                    return Reply(ReplyType.ERROR, "提问太快啦,请休息一下再问我吧")
+            else:
+                response = res.json()
+                error = response.get("error")
+                logger.error(f"[LINKAI] chat failed, status_code={res.status_code}, "
+                             f"msg={error.get('message')}, type={error.get('type')}")
 
-                else:
-                    # retry
+                if res.status_code >= 500:
+                    # server error, need retry
                     time.sleep(2)
                     logger.warn(f"[LINKAI] do retry, times={retry_count}")
                     return self._chat(query, context, retry_count + 1)
 
-            # execute success
-            reply_content = res["data"]["content"]
-            logger.info(f"[LINKAI] reply={reply_content}")
-            self.sessions.session_reply(reply_content, session_id)
-            return Reply(ReplyType.TEXT, reply_content)
+                return Reply(ReplyType.ERROR, "提问太快啦,请休息一下再问我吧")
 
         except Exception as e:
             logger.exception(e)

+ 6 - 0
bridge/bridge.py

@@ -56,3 +56,9 @@ class Bridge(object):
 
     def fetch_translate(self, text, from_lang="", to_lang="en") -> Reply:
         return self.get_bot("translate").translate(text, from_lang, to_lang)
+
+    def reset_bot(self):
+        """
+        重置bot路由
+        """
+        self.__init__()

+ 6 - 2
channel/chat_channel.py

@@ -108,8 +108,12 @@ class ChatChannel(Channel):
                     if not conf().get("group_at_off", False):
                         flag = True
                     pattern = f"@{re.escape(self.name)}(\u2005|\u0020)"
-                    content = re.sub(pattern, r"", content)
-
+                    subtract_res = re.sub(pattern, r"", content)
+                    if subtract_res == content and context["msg"].self_display_name:
+                        # 前缀移除后没有变化,使用群昵称再次移除
+                        pattern = f"@{re.escape(context['msg'].self_display_name)}(\u2005|\u0020)"
+                        subtract_res = re.sub(pattern, r"", content)
+                    content = subtract_res
                 if not flag:
                     if context["origin_ctype"] == ContextType.VOICE:
                         logger.info("[WX]receive group voice, but checkprefix didn't match")

+ 3 - 3
channel/chat_message.py

@@ -24,9 +24,7 @@ is_at: 是否被at
 - (群消息时,一般会存在实际发送者,是群内某个成员的id和昵称,下列项仅在群消息时存在)
 actual_user_id: 实际发送者id (群聊必填)
 actual_user_nickname:实际发送者昵称
-
-
-
+self_display_name: 自身的展示名,设置群昵称时,该字段表示群昵称
 
 _prepare_fn: 准备函数,用于准备消息的内容,比如下载图片等,
 _prepared: 是否已经调用过准备函数
@@ -48,6 +46,8 @@ class ChatMessage(object):
     to_user_nickname = None
     other_user_id = None
     other_user_nickname = None
+    my_msg = False
+    self_display_name = None
 
     is_group = False
     is_at = False

+ 3 - 0
channel/wechat/wechat_channel.py

@@ -58,6 +58,9 @@ def _check(func):
         if conf().get("hot_reload") == True and int(create_time) < int(time.time()) - 60:  # 跳过1分钟前的历史消息
             logger.debug("[WX]history message {} skipped".format(msgId))
             return
+        if cmsg.my_msg and not cmsg.is_group:
+            logger.debug("[WX]my message {} skipped".format(msgId))
+            return
         return func(self, cmsg)
 
     return wrapper

+ 7 - 1
channel/wechat/wechat_message.py

@@ -57,13 +57,19 @@ class WechatMessage(ChatMessage):
             self.from_user_nickname = nickname
         if self.to_user_id == user_id:
             self.to_user_nickname = nickname
-        try:  # 陌生人时候, 'User'字段可能不存在
+        try:  # 陌生人时候, User字段可能不存在
+            # my_msg 为True是表示是自己发送的消息
+            self.my_msg = itchat_msg["ToUserName"] == itchat_msg["User"]["UserName"] and \
+                          itchat_msg["ToUserName"] != itchat_msg["FromUserName"]
             self.other_user_id = itchat_msg["User"]["UserName"]
             self.other_user_nickname = itchat_msg["User"]["NickName"]
             if self.other_user_id == self.from_user_id:
                 self.from_user_nickname = self.other_user_nickname
             if self.other_user_id == self.to_user_id:
                 self.to_user_nickname = self.other_user_nickname
+            if itchat_msg["User"].get("Self"):
+                # 自身的展示名,当设置了群昵称时,该字段表示群昵称
+                self.self_display_name = itchat_msg["User"].get("Self").get("DisplayName")
         except KeyError as e:  # 处理偶尔没有对方信息的情况
             logger.warn("[WX]get other_user_id failed: " + str(e))
             if self.from_user_id == user_id:

+ 9 - 0
config.py

@@ -20,6 +20,7 @@ available_setting = {
     "use_azure_chatgpt": False,  # 是否使用azure的chatgpt
     "azure_deployment_id": "",  # azure 模型部署名称
     "use_baidu_wenxin": False,  # 是否使用baidu文心一言,优先级次于azure
+    "azure_api_version": "",  # azure api版本
     # Bot触发配置
     "single_chat_prefix": ["bot", "@bot"],  # 私聊时文本需要包含该前缀才能触发机器人回复
     "single_chat_reply_prefix": "[bot] ",  # 私聊时自动回复的前缀,用于区分真人
@@ -107,6 +108,8 @@ available_setting = {
     "appdata_dir": "",  # 数据目录
     # 插件配置
     "plugin_trigger_prefix": "$",  # 规范插件提供聊天相关指令的前缀,建议不要和管理员指令前缀"#"冲突
+    # 是否使用全局插件配置
+    "use_global_plugin_config": False,
     # 知识库平台配置
     "use_linkai": False,
     "linkai_api_key": "",
@@ -257,3 +260,9 @@ def pconf(plugin_name: str) -> dict:
     :return: 该插件的配置项
     """
     return plugin_config.get(plugin_name.lower())
+
+
+# 全局配置,用于存放全局生效的状态
+global_config = {
+    "admin_users": []
+}

+ 1 - 0
docker/docker-compose.yml

@@ -18,6 +18,7 @@ services:
       SPEECH_RECOGNITION: 'False'
       CHARACTER_DESC: '你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。'
       EXPIRES_IN_SECONDS: 3600
+      USE_GLOBAL_PLUGIN_CONFIG: 'True'
       USE_LINKAI: 'False'
       LINKAI_API_KEY: ''
       LINKAI_APP_CODE: ''

+ 14 - 0
plugins/config.json.template

@@ -20,5 +20,19 @@
             "no_default": false,
             "model_name": "gpt-3.5-turbo"
         }
+    },
+    "linkai": {
+        "group_app_map": {
+            "测试群1": "default",
+            "测试群2": "Kv2fXJcH"
+        },
+        "midjourney": {
+            "enabled": true,
+            "auto_translate": true,
+            "img_proxy": true,
+            "max_tasks": 3,
+            "max_tasks_per_user": 1,
+            "use_image_create_prefix": true
+        }
     }
 }

+ 3 - 1
plugins/godcmd/godcmd.py

@@ -13,7 +13,7 @@ from bridge.context import ContextType
 from bridge.reply import Reply, ReplyType
 from common import const
 from common.log import logger
-from config import conf, load_config
+from config import conf, load_config, global_config
 from plugins import *
 
 # 定义指令集
@@ -426,9 +426,11 @@ class Godcmd(Plugin):
         password = args[0]
         if password == self.password:
             self.admin_users.append(userid)
+            global_config["admin_users"].append(userid)
             return True, "认证成功"
         elif password == self.temp_password:
             self.admin_users.append(userid)
+            global_config["admin_users"].append(userid)
             return True, "认证成功,请尽快设置口令"
         else:
             return False, "认证失败"

+ 12 - 3
plugins/keyword/keyword.py

@@ -54,9 +54,18 @@ class Keyword(Plugin):
             logger.debug(f"[keyword] 匹配到关键字【{content}】")
             reply_text = self.keyword[content]
 
-            reply = Reply()
-            reply.type = ReplyType.TEXT
-            reply.content = reply_text
+            # 判断匹配内容的类型
+            if (reply_text.startswith("http://") or reply_text.startswith("https://")) and any(reply_text.endswith(ext) for ext in [".jpg", ".jpeg", ".png", ".gif", ".webp"]):
+                # 如果是以 http:// 或 https:// 开头,且.jpg/.jpeg/.png/.gif结尾,则认为是图片 URL
+                reply = Reply()
+                reply.type = ReplyType.IMAGE_URL
+                reply.content = reply_text
+            else:
+            # 否则认为是普通文本
+                reply = Reply()
+                reply.type = ReplyType.TEXT
+                reply.content = reply_text
+            
             e_context["reply"] = reply
             e_context.action = EventAction.BREAK_PASS  # 事件结束,并跳过处理context的默认逻辑
 

+ 66 - 0
plugins/linkai/README.md

@@ -0,0 +1,66 @@
+## 插件说明
+
+基于 LinkAI 提供的知识库、Midjourney绘画等能力对机器人的功能进行增强。平台地址: https://chat.link-ai.tech/console
+
+## 插件配置
+
+将 `plugins/linkai` 目录下的 `config.json.template` 配置模板复制为最终生效的 `config.json`:
+
+以下是配置项说明:
+
+```bash
+{
+    "group_app_map": {            # 群聊 和 应用编码 的映射关系
+        "测试群1": "default",      # 表示在名称为 "测试群1" 的群聊中将使用app_code 为 default 的应用
+        "测试群2": "Kv2fXJcH"
+    },
+    "midjourney": {
+        "enabled": true,          # midjourney 绘画开关
+        "auto_translate": true,   # 是否自动将提示词翻译为英文
+        "img_proxy": true,        # 是否对生成的图片使用代理,如果你是国外服务器,将这一项设置为false会获得更快的生成速度
+        "max_tasks": 3,           # 支持同时提交的总任务个数
+        "max_tasks_per_user": 1,  # 支持单个用户同时提交的任务个数
+        "use_image_create_prefix": true   # 是否使用全局的绘画触发词,如果开启将同时支持由`config.json`中的 image_create_prefix 配置触发
+    }
+}
+
+```
+注意:
+
+ - 配置项中 `group_app_map` 部分是用于映射群聊与LinkAI平台上的应用, `midjourney` 部分是 mj 画图的配置,可根据需要进行填写,未填写配置时默认不开启相应功能
+ - 实际 `config.json` 配置中应保证json格式,不应携带 '#' 及后面的注释
+ - 如果是`docker`部署,可通过映射 `plugins/config.json` 到容器中来完成插件配置,参考[文档](https://github.com/zhayujie/chatgpt-on-wechat#3-%E6%8F%92%E4%BB%B6%E4%BD%BF%E7%94%A8)
+
+## 插件使用
+
+> 使用插件中的知识库管理功能需要首先开启`linkai`对话,依赖全局 `config.json` 中的 `use_linkai` 和 `linkai_api_key` 配置;而midjourney绘画功能则只需填写 `linkai_api_key` 配置,`use_linkai` 无论是否关闭均可使用。具体可参考 [详细文档](https://link-ai.tech/platform/link-app/wechat)。
+
+完成配置后运行项目,会自动运行插件,输入 `#help linkai` 可查看插件功能。
+
+### 1.知识库管理功能
+
+提供在不同群聊使用不同应用的功能。可以在上述 `group_app_map` 配置中固定映射关系,也可以通过指令在群中快速完成切换。
+
+应用切换指令需要首先完成管理员 (`godcmd`) 插件的认证,然后按以下格式输入:
+
+`$linkai app {app_code}`
+
+例如输入 `$linkai app Kv2fXJcH`,即将当前群聊与 app_code为 Kv2fXJcH 的应用绑定。
+
+### 2.Midjourney绘画功能
+
+指令格式:
+
+```
+ - 图片生成: $mj 描述词1, 描述词2..
+ - 图片放大: $mju 图片ID 图片序号
+```
+
+例如:
+
+```
+"$mj a little cat, white --ar 9:16"
+"$mju 1105592717188272288 2"
+```
+
+注:开启 `use_image_create_prefix` 配置后可直接复用全局画图触发词,以"画"开头便可以生成图片。

+ 1 - 0
plugins/linkai/__init__.py

@@ -0,0 +1 @@
+from .linkai import *

+ 14 - 0
plugins/linkai/config.json.template

@@ -0,0 +1,14 @@
+{
+    "group_app_map": {
+        "测试群1": "default",
+        "测试群2": "Kv2fXJcH"
+    },
+    "midjourney": {
+        "enabled": true,
+        "auto_translate": true,
+        "img_proxy": true,
+        "max_tasks": 3,
+        "max_tasks_per_user": 1,
+        "use_image_create_prefix": true
+    }
+}

+ 161 - 0
plugins/linkai/linkai.py

@@ -0,0 +1,161 @@
+import plugins
+from bridge.context import ContextType
+from bridge.reply import Reply, ReplyType
+from config import global_config
+from plugins import *
+from .midjourney import MJBot
+from bridge import bridge
+
+
+@plugins.register(
+    name="linkai",
+    desc="A plugin that supports knowledge base and midjourney drawing.",
+    version="0.1.0",
+    author="https://link-ai.tech",
+)
+class LinkAI(Plugin):
+    def __init__(self):
+        super().__init__()
+        self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
+        self.config = super().load_config()
+        if self.config:
+            self.mj_bot = MJBot(self.config.get("midjourney"))
+        logger.info("[LinkAI] inited")
+
+    def on_handle_context(self, e_context: EventContext):
+        """
+        消息处理逻辑
+        :param e_context: 消息上下文
+        """
+        if not self.config:
+            return
+
+        context = e_context['context']
+        if context.type not in [ContextType.TEXT, ContextType.IMAGE, ContextType.IMAGE_CREATE]:
+            # filter content no need solve
+            return
+
+        mj_type = self.mj_bot.judge_mj_task_type(e_context)
+        if mj_type:
+            # MJ作图任务处理
+            self.mj_bot.process_mj_task(mj_type, e_context)
+            return
+
+        if context.content.startswith(f"{_get_trigger_prefix()}linkai"):
+            # 应用管理功能
+            self._process_admin_cmd(e_context)
+            return
+
+        if self._is_chat_task(e_context):
+            # 文本对话任务处理
+            self._process_chat_task(e_context)
+
+    # 插件管理功能
+    def _process_admin_cmd(self, e_context: EventContext):
+        context = e_context['context']
+        cmd = context.content.split()
+        if len(cmd) == 1 or (len(cmd) == 2 and cmd[1] == "help"):
+            _set_reply_text(self.get_help_text(verbose=True), e_context, level=ReplyType.INFO)
+            return
+
+        if len(cmd) == 2 and (cmd[1] == "open" or cmd[1] == "close"):
+            # 知识库开关指令
+            if not _is_admin(e_context):
+                _set_reply_text("需要管理员权限执行", e_context, level=ReplyType.ERROR)
+                return
+            is_open = True
+            tips_text = "开启"
+            if cmd[1] == "close":
+                tips_text = "关闭"
+                is_open = False
+            conf()["use_linkai"] = is_open
+            bridge.Bridge().reset_bot()
+            _set_reply_text(f"知识库功能已{tips_text}", e_context, level=ReplyType.INFO)
+            return
+
+        if len(cmd) == 3 and cmd[1] == "app":
+            # 知识库应用切换指令
+            if not context.kwargs.get("isgroup"):
+                _set_reply_text("该指令需在群聊中使用", e_context, level=ReplyType.ERROR)
+                return
+            if not _is_admin(e_context):
+                _set_reply_text("需要管理员权限执行", e_context, level=ReplyType.ERROR)
+                return
+            app_code = cmd[2]
+            group_name = context.kwargs.get("msg").from_user_nickname
+            group_mapping = self.config.get("group_app_map")
+            if group_mapping:
+                group_mapping[group_name] = app_code
+            else:
+                self.config["group_app_map"] = {group_name: app_code}
+            # 保存插件配置
+            super().save_config(self.config)
+            _set_reply_text(f"应用设置成功: {app_code}", e_context, level=ReplyType.INFO)
+        else:
+            _set_reply_text(f"指令错误,请输入{_get_trigger_prefix()}linkai help 获取帮助", e_context,
+                            level=ReplyType.INFO)
+            return
+
+    # LinkAI 对话任务处理
+    def _is_chat_task(self, e_context: EventContext):
+        context = e_context['context']
+        # 群聊应用管理
+        return self.config.get("group_app_map") and context.kwargs.get("isgroup")
+
+    def _process_chat_task(self, e_context: EventContext):
+        """
+        处理LinkAI对话任务
+        :param e_context: 对话上下文
+        """
+        context = e_context['context']
+        # 群聊应用管理
+        group_name = context.kwargs.get("msg").from_user_nickname
+        app_code = self._fetch_group_app_code(group_name)
+        if app_code:
+            context.kwargs['app_code'] = app_code
+
+    def _fetch_group_app_code(self, group_name: str) -> str:
+        """
+        根据群聊名称获取对应的应用code
+        :param group_name: 群聊名称
+        :return: 应用code
+        """
+        group_mapping = self.config.get("group_app_map")
+        if group_mapping:
+            app_code = group_mapping.get(group_name) or group_mapping.get("ALL_GROUP")
+            return app_code
+
+    def get_help_text(self, verbose=False, **kwargs):
+        trigger_prefix = _get_trigger_prefix()
+        help_text = "用于集成 LinkAI 提供的知识库、Midjourney绘画等能力。\n\n"
+        if not verbose:
+            return help_text
+        help_text += f'📖 知识库\n - 群聊中指定应用: {trigger_prefix}linkai app 应用编码\n\n例如: \n"$linkai app Kv2fXJcH"\n\n'
+        help_text += f"🎨 绘画\n - 生成: {trigger_prefix}mj 描述词1, 描述词2.. \n - 放大: {trigger_prefix}mju 图片ID 图片序号\n - 变换: mjv 图片ID 图片序号\n - 重置: mjr 图片ID"
+        help_text += f"\n\n例如:\n\"{trigger_prefix}mj a little cat, white --ar 9:16\"\n\"{trigger_prefix}mju 11055927171882 2\""
+        help_text += f"\n\"{trigger_prefix}mjv 11055927171882 2\"\n\"{trigger_prefix}mjr 11055927171882\""
+        return help_text
+
+
+# 静态方法
+def _is_admin(e_context: EventContext) -> bool:
+    """
+    判断消息是否由管理员用户发送
+    :param e_context: 消息上下文
+    :return: True: 是, False: 否
+    """
+    context = e_context["context"]
+    if context["isgroup"]:
+        return context.kwargs.get("msg").actual_user_id in global_config["admin_users"]
+    else:
+        return context["receiver"] in global_config["admin_users"]
+
+
+def _set_reply_text(content: str, e_context: EventContext, level: ReplyType = ReplyType.ERROR):
+    reply = Reply(level, content)
+    e_context["reply"] = reply
+    e_context.action = EventAction.BREAK_PASS
+
+
+def _get_trigger_prefix():
+    return conf().get("plugin_trigger_prefix", "$")

+ 415 - 0
plugins/linkai/midjourney.py

@@ -0,0 +1,415 @@
+from enum import Enum
+from config import conf
+from common.log import logger
+import requests
+import threading
+import time
+from bridge.reply import Reply, ReplyType
+import aiohttp
+import asyncio
+from bridge.context import ContextType
+from plugins import EventContext, EventAction
+
+INVALID_REQUEST = 410
+NOT_FOUND_ORIGIN_IMAGE = 461
+NOT_FOUND_TASK = 462
+
+
+class TaskType(Enum):
+    GENERATE = "generate"
+    UPSCALE = "upscale"
+    VARIATION = "variation"
+    RESET = "reset"
+
+    def __str__(self):
+        return self.name
+
+
+class Status(Enum):
+    PENDING = "pending"
+    FINISHED = "finished"
+    EXPIRED = "expired"
+    ABORTED = "aborted"
+
+    def __str__(self):
+        return self.name
+
+
+class TaskMode(Enum):
+    FAST = "fast"
+    RELAX = "relax"
+
+
+task_name_mapping = {
+    TaskType.GENERATE.name: "生成",
+    TaskType.UPSCALE.name: "放大",
+    TaskType.VARIATION.name: "变换",
+    TaskType.RESET.name: "重新生成",
+}
+
+
+class MJTask:
+    def __init__(self, id, user_id: str, task_type: TaskType, raw_prompt=None, expires: int = 60 * 30,
+                 status=Status.PENDING):
+        self.id = id
+        self.user_id = user_id
+        self.task_type = task_type
+        self.raw_prompt = raw_prompt
+        self.send_func = None  # send_func(img_url)
+        self.expiry_time = time.time() + expires
+        self.status = status
+        self.img_url = None  # url
+        self.img_id = None
+
+    def __str__(self):
+        return f"id={self.id}, user_id={self.user_id}, task_type={self.task_type}, status={self.status}, img_id={self.img_id}"
+
+
+# midjourney bot
+class MJBot:
+    def __init__(self, config):
+        self.base_url = "https://api.link-ai.chat/v1/img/midjourney"
+
+        self.headers = {"Authorization": "Bearer " + conf().get("linkai_api_key")}
+        self.config = config
+        self.tasks = {}
+        self.temp_dict = {}
+        self.tasks_lock = threading.Lock()
+        self.event_loop = asyncio.new_event_loop()
+
+    def judge_mj_task_type(self, e_context: EventContext):
+        """
+        判断MJ任务的类型
+        :param e_context: 上下文
+        :return: 任务类型枚举
+        """
+        if not self.config:
+            return None
+        trigger_prefix = conf().get("plugin_trigger_prefix", "$")
+        context = e_context['context']
+        if context.type == ContextType.TEXT:
+            cmd_list = context.content.split(maxsplit=1)
+            if cmd_list[0].lower() == f"{trigger_prefix}mj":
+                return TaskType.GENERATE
+            elif cmd_list[0].lower() == f"{trigger_prefix}mju":
+                return TaskType.UPSCALE
+            elif cmd_list[0].lower() == f"{trigger_prefix}mjv":
+                return TaskType.VARIATION
+            elif cmd_list[0].lower() == f"{trigger_prefix}mjr":
+                return TaskType.RESET
+        elif context.type == ContextType.IMAGE_CREATE and self.config.get("use_image_create_prefix"):
+            return TaskType.GENERATE
+
+    def process_mj_task(self, mj_type: TaskType, e_context: EventContext):
+        """
+        处理mj任务
+        :param mj_type: mj任务类型
+        :param e_context: 对话上下文
+        """
+        context = e_context['context']
+        session_id = context["session_id"]
+        cmd = context.content.split(maxsplit=1)
+        if len(cmd) == 1 and context.type == ContextType.TEXT:
+            # midjourney 帮助指令
+            self._set_reply_text(self.get_help_text(verbose=True), e_context, level=ReplyType.INFO)
+            return
+
+        if len(cmd) == 2 and (cmd[1] == "open" or cmd[1] == "close"):
+            # midjourney 开关指令
+            is_open = True
+            tips_text = "开启"
+            if cmd[1] == "close":
+                tips_text = "关闭"
+                is_open = False
+            self.config["enabled"] = is_open
+            self._set_reply_text(f"Midjourney绘画已{tips_text}", e_context, level=ReplyType.INFO)
+            return
+
+        if not self.config.get("enabled"):
+            logger.warn("Midjourney绘画未开启,请查看 plugins/linkai/config.json 中的配置")
+            self._set_reply_text(f"Midjourney绘画未开启", e_context, level=ReplyType.INFO)
+            return
+
+        if not self._check_rate_limit(session_id, e_context):
+            logger.warn("[MJ] midjourney task exceed rate limit")
+            return
+
+        if mj_type == TaskType.GENERATE:
+            if context.type == ContextType.IMAGE_CREATE:
+                raw_prompt = context.content
+            else:
+                # 图片生成
+                raw_prompt = cmd[1]
+            reply = self.generate(raw_prompt, session_id, e_context)
+            e_context['reply'] = reply
+            e_context.action = EventAction.BREAK_PASS
+            return
+
+        elif mj_type == TaskType.UPSCALE or mj_type == TaskType.VARIATION:
+            # 图片放大/变换
+            clist = cmd[1].split()
+            if len(clist) < 2:
+                self._set_reply_text(f"{cmd[0]} 命令缺少参数", e_context)
+                return
+            img_id = clist[0]
+            index = int(clist[1])
+            if index < 1 or index > 4:
+                self._set_reply_text(f"图片序号 {index} 错误,应在 1 至 4 之间", e_context)
+                return
+            key = f"{str(mj_type)}_{img_id}_{index}"
+            if self.temp_dict.get(key):
+                self._set_reply_text(f"第 {index} 张图片已经{task_name_mapping.get(str(mj_type))}过了", e_context)
+                return
+            # 执行图片放大/变换操作
+            reply = self.do_operate(mj_type, session_id, img_id, e_context, index)
+            e_context['reply'] = reply
+            e_context.action = EventAction.BREAK_PASS
+            return
+
+        elif mj_type == TaskType.RESET:
+            # 图片重新生成
+            clist = cmd[1].split()
+            if len(clist) < 1:
+                self._set_reply_text(f"{cmd[0]} 命令缺少参数", e_context)
+                return
+            img_id = clist[0]
+            # 图片重新生成
+            reply = self.do_operate(mj_type, session_id, img_id, e_context)
+            e_context['reply'] = reply
+            e_context.action = EventAction.BREAK_PASS
+        else:
+            self._set_reply_text(f"暂不支持该命令", e_context)
+
+    def generate(self, prompt: str, user_id: str, e_context: EventContext) -> Reply:
+        """
+        图片生成
+        :param prompt: 提示词
+        :param user_id: 用户id
+        :param e_context: 对话上下文
+        :return: 任务ID
+        """
+        logger.info(f"[MJ] image generate, prompt={prompt}")
+        mode = self._fetch_mode(prompt)
+        body = {"prompt": prompt, "mode": mode, "auto_translate": self.config.get("auto_translate")}
+        if not self.config.get("img_proxy"):
+            body["img_proxy"] = False
+        res = requests.post(url=self.base_url + "/generate", json=body, headers=self.headers, timeout=(5, 40))
+        if res.status_code == 200:
+            res = res.json()
+            logger.debug(f"[MJ] image generate, res={res}")
+            if res.get("code") == 200:
+                task_id = res.get("data").get("task_id")
+                real_prompt = res.get("data").get("real_prompt")
+                if mode == TaskMode.RELAX.value:
+                    time_str = "1~10分钟"
+                else:
+                    time_str = "1分钟"
+                content = f"🚀您的作品将在{time_str}左右完成,请耐心等待\n- - - - - - - - -\n"
+                if real_prompt:
+                    content += f"初始prompt: {prompt}\n转换后prompt: {real_prompt}"
+                else:
+                    content += f"prompt: {prompt}"
+                reply = Reply(ReplyType.INFO, content)
+                task = MJTask(id=task_id, status=Status.PENDING, raw_prompt=prompt, user_id=user_id,
+                              task_type=TaskType.GENERATE)
+                # put to memory dict
+                self.tasks[task.id] = task
+                # asyncio.run_coroutine_threadsafe(self.check_task(task, e_context), self.event_loop)
+                self._do_check_task(task, e_context)
+                return reply
+        else:
+            res_json = res.json()
+            logger.error(f"[MJ] generate error, msg={res_json.get('message')}, status_code={res.status_code}")
+            if res.status_code == INVALID_REQUEST:
+                reply = Reply(ReplyType.ERROR, "图片生成失败,请检查提示词参数或内容")
+            else:
+                reply = Reply(ReplyType.ERROR, "图片生成失败,请稍后再试")
+            return reply
+
+    def do_operate(self, task_type: TaskType, user_id: str, img_id: str, e_context: EventContext,
+                   index: int = None) -> Reply:
+        logger.info(f"[MJ] image operate, task_type={task_type}, img_id={img_id}, index={index}")
+        body = {"type": task_type.name, "img_id": img_id}
+        if index:
+            body["index"] = index
+        if not self.config.get("img_proxy"):
+            body["img_proxy"] = False
+        res = requests.post(url=self.base_url + "/operate", json=body, headers=self.headers, timeout=(5, 40))
+        logger.debug(res)
+        if res.status_code == 200:
+            res = res.json()
+            if res.get("code") == 200:
+                task_id = res.get("data").get("task_id")
+                logger.info(f"[MJ] image operate processing, task_id={task_id}")
+                icon_map = {TaskType.UPSCALE: "🔎", TaskType.VARIATION: "🪄", TaskType.RESET: "🔄"}
+                content = f"{icon_map.get(task_type)}图片正在{task_name_mapping.get(task_type.name)}中,请耐心等待"
+                reply = Reply(ReplyType.INFO, content)
+                task = MJTask(id=task_id, status=Status.PENDING, user_id=user_id, task_type=task_type)
+                # put to memory dict
+                self.tasks[task.id] = task
+                key = f"{task_type.name}_{img_id}_{index}"
+                self.temp_dict[key] = True
+                # asyncio.run_coroutine_threadsafe(self.check_task(task, e_context), self.event_loop)
+                self._do_check_task(task, e_context)
+                return reply
+        else:
+            error_msg = ""
+            if res.status_code == NOT_FOUND_ORIGIN_IMAGE:
+                error_msg = "请输入正确的图片ID"
+            res_json = res.json()
+            logger.error(f"[MJ] operate error, msg={res_json.get('message')}, status_code={res.status_code}")
+            reply = Reply(ReplyType.ERROR, error_msg or "图片生成失败,请稍后再试")
+            return reply
+
+    def check_task_sync(self, task: MJTask, e_context: EventContext):
+        logger.debug(f"[MJ] start check task status, {task}")
+        max_retry_times = 90
+        while max_retry_times > 0:
+            time.sleep(10)
+            url = f"{self.base_url}/tasks/{task.id}"
+            try:
+                res = requests.get(url, headers=self.headers, timeout=8)
+                if res.status_code == 200:
+                    res_json = res.json()
+                    logger.debug(f"[MJ] task check res sync, task_id={task.id}, status={res.status_code}, "
+                                 f"data={res_json.get('data')}, thread={threading.current_thread().name}")
+                    if res_json.get("data") and res_json.get("data").get("status") == Status.FINISHED.name:
+                        # process success res
+                        if self.tasks.get(task.id):
+                            self.tasks[task.id].status = Status.FINISHED
+                        self._process_success_task(task, res_json.get("data"), e_context)
+                        return
+                    max_retry_times -= 1
+                else:
+                    res_json = res.json()
+                    logger.warn(f"[MJ] image check error, status_code={res.status_code}, res={res_json}")
+                    max_retry_times -= 20
+            except Exception as e:
+                max_retry_times -= 20
+                logger.warn(e)
+        logger.warn("[MJ] end from poll")
+        if self.tasks.get(task.id):
+            self.tasks[task.id].status = Status.EXPIRED
+
+    def _do_check_task(self, task: MJTask, e_context: EventContext):
+        threading.Thread(target=self.check_task_sync, args=(task, e_context)).start()
+
+    def _process_success_task(self, task: MJTask, res: dict, e_context: EventContext):
+        """
+        处理任务成功的结果
+        :param task: MJ任务
+        :param res: 请求结果
+        :param e_context: 对话上下文
+        """
+        # channel send img
+        task.status = Status.FINISHED
+        task.img_id = res.get("img_id")
+        task.img_url = res.get("img_url")
+        logger.info(f"[MJ] task success, task_id={task.id}, img_id={task.img_id}, img_url={task.img_url}")
+
+        # send img
+        reply = Reply(ReplyType.IMAGE_URL, task.img_url)
+        channel = e_context["channel"]
+        channel._send(reply, e_context["context"])
+
+        # send info
+        trigger_prefix = conf().get("plugin_trigger_prefix", "$")
+        text = ""
+        if task.task_type == TaskType.GENERATE or task.task_type == TaskType.VARIATION or task.task_type == TaskType.RESET:
+            text = f"🎨绘画完成!\n"
+            if task.raw_prompt:
+                text += f"prompt: {task.raw_prompt}\n"
+            text += f"- - - - - - - - -\n图片ID: {task.img_id}"
+            text += f"\n\n🔎使用 {trigger_prefix}mju 命令放大图片\n"
+            text += f"例如:\n{trigger_prefix}mju {task.img_id} 1"
+            text += f"\n\n🪄使用 {trigger_prefix}mjv 命令变换图片\n"
+            text += f"例如:\n{trigger_prefix}mjv {task.img_id} 1"
+            text += f"\n\n🔄使用 {trigger_prefix}mjr 命令重新生成图片\n"
+            text += f"例如:\n{trigger_prefix}mjr {task.img_id}"
+            reply = Reply(ReplyType.INFO, text)
+            channel._send(reply, e_context["context"])
+
+        self._print_tasks()
+        return
+
+    def _check_rate_limit(self, user_id: str, e_context: EventContext) -> bool:
+        """
+        midjourney任务限流控制
+        :param user_id: 用户id
+        :param e_context: 对话上下文
+        :return: 任务是否能够生成, True:可以生成, False: 被限流
+        """
+        tasks = self.find_tasks_by_user_id(user_id)
+        task_count = len([t for t in tasks if t.status == Status.PENDING])
+        if task_count >= self.config.get("max_tasks_per_user"):
+            reply = Reply(ReplyType.INFO, "您的Midjourney作图任务数已达上限,请稍后再试")
+            e_context["reply"] = reply
+            e_context.action = EventAction.BREAK_PASS
+            return False
+        task_count = len([t for t in self.tasks.values() if t.status == Status.PENDING])
+        if task_count >= self.config.get("max_tasks"):
+            reply = Reply(ReplyType.INFO, "Midjourney作图任务数已达上限,请稍后再试")
+            e_context["reply"] = reply
+            e_context.action = EventAction.BREAK_PASS
+            return False
+        return True
+
+    def _fetch_mode(self, prompt) -> str:
+        mode = self.config.get("mode")
+        if "--relax" in prompt or mode == TaskMode.RELAX.value:
+            return TaskMode.RELAX.value
+        return mode or TaskMode.FAST.value
+
+    def _run_loop(self, loop: asyncio.BaseEventLoop):
+        """
+        运行事件循环,用于轮询任务的线程
+        :param loop: 事件循环
+        """
+        loop.run_forever()
+        loop.stop()
+
+    def _print_tasks(self):
+        for id in self.tasks:
+            logger.debug(f"[MJ] current task: {self.tasks[id]}")
+
+    def _set_reply_text(self, content: str, e_context: EventContext, level: ReplyType = ReplyType.ERROR):
+        """
+        设置回复文本
+        :param content: 回复内容
+        :param e_context: 对话上下文
+        :param level: 回复等级
+        """
+        reply = Reply(level, content)
+        e_context["reply"] = reply
+        e_context.action = EventAction.BREAK_PASS
+
+    def get_help_text(self, verbose=False, **kwargs):
+        trigger_prefix = conf().get("plugin_trigger_prefix", "$")
+        help_text = "🎨利用Midjourney进行画图\n\n"
+        if not verbose:
+            return help_text
+        help_text += f" - 生成: {trigger_prefix}mj 描述词1, 描述词2.. \n - 放大: {trigger_prefix}mju 图片ID 图片序号\n - 变换: mjv 图片ID 图片序号\n - 重置: mjr 图片ID"
+        help_text += f"\n\n例如:\n\"{trigger_prefix}mj a little cat, white --ar 9:16\"\n\"{trigger_prefix}mju 11055927171882 2\""
+        help_text += f"\n\"{trigger_prefix}mjv 11055927171882 2\"\n\"{trigger_prefix}mjr 11055927171882\""
+        return help_text
+
+    def find_tasks_by_user_id(self, user_id) -> list:
+        result = []
+        with self.tasks_lock:
+            now = time.time()
+            for task in self.tasks.values():
+                if task.status == Status.PENDING and now > task.expiry_time:
+                    task.status = Status.EXPIRED
+                    logger.info(f"[MJ] {task} expired")
+                if task.user_id == user_id:
+                    result.append(task)
+        return result
+
+
+def check_prefix(content, prefix_list):
+    if not prefix_list:
+        return None
+    for prefix in prefix_list:
+        if content.startswith(prefix):
+            return prefix
+    return None

+ 21 - 4
plugins/plugin.py

@@ -1,6 +1,6 @@
 import os
 import json
-from config import pconf
+from config import pconf, plugin_config, conf
 from common.log import logger
 
 
@@ -15,14 +15,31 @@ class Plugin:
         """
         # 优先获取 plugins/config.json 中的全局配置
         plugin_conf = pconf(self.name)
-        if not plugin_conf:
-            # 全局配置不存在,则获取插件目录下的配置
+        if not plugin_conf or not conf().get("use_global_plugin_config"):
+            # 全局配置不存在 或者 未开启全局配置开关,则获取插件目录下的配置
             plugin_config_path = os.path.join(self.path, "config.json")
             if os.path.exists(plugin_config_path):
-                with open(plugin_config_path, "r") as f:
+                with open(plugin_config_path, "r", encoding="utf-8") as f:
                     plugin_conf = json.load(f)
         logger.debug(f"loading plugin config, plugin_name={self.name}, conf={plugin_conf}")
         return plugin_conf
 
+    def save_config(self, config: dict):
+        try:
+            plugin_config[self.name] = config
+            # 写入全局配置
+            global_config_path = "./plugins/config.json"
+            if os.path.exists(global_config_path):
+                with open(global_config_path, "w", encoding='utf-8') as f:
+                    json.dump(plugin_config, f, indent=4, ensure_ascii=False)
+            # 写入插件配置
+            plugin_config_path = os.path.join(self.path, "config.json")
+            if os.path.exists(plugin_config_path):
+                with open(plugin_config_path, "w", encoding='utf-8') as f:
+                    json.dump(config, f, indent=4, ensure_ascii=False)
+
+        except Exception as e:
+            logger.warn("save plugin config failed: {}".format(e))
+
     def get_help_text(self, **kwargs):
         return "暂无帮助信息"